mirror of
https://github.com/steveiliop56/tinyauth.git
synced 2026-01-24 08:12:30 +00:00
502 lines
12 KiB
Go
502 lines
12 KiB
Go
package controller
|
|
|
|
import (
|
|
"fmt"
|
|
"slices"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/google/go-querystring/query"
|
|
"github.com/steveiliop56/tinyauth/internal/config"
|
|
"github.com/steveiliop56/tinyauth/internal/repository"
|
|
"github.com/steveiliop56/tinyauth/internal/utils"
|
|
"github.com/steveiliop56/tinyauth/internal/utils/tlog"
|
|
)
|
|
|
|
// Capabilities advertised/accepted by the OIDC controller.
var (
	// SupportedResponseTypes lists the OAuth 2.0 response types accepted.
	SupportedResponseTypes = []string{
		"code",
	}
	// SupportedScopes lists the scopes that may be granted to a client.
	SupportedScopes = []string{
		"openid",
		"profile",
		"email",
		"groups",
	}
	// SupportedGrantTypes lists the grant types accepted by the token endpoint.
	SupportedGrantTypes = []string{
		"authorization_code",
	}
)
|
|
|
|
type OIDCControllerConfig struct {
|
|
Clients []config.OIDCClientConfig
|
|
AppURL string
|
|
}
|
|
|
|
type OIDCController struct {
|
|
config OIDCControllerConfig
|
|
router *gin.RouterGroup
|
|
queries *repository.Queries
|
|
}
|
|
|
|
// AuthorizeRequest is the JSON body accepted by the authorize endpoint.
// Every field is mandatory (enforced by the "required" binding).
type AuthorizeRequest struct {
	// Scope is the space separated list of requested scopes.
	Scope string `json:"scope" binding:"required"`
	// ResponseType is the requested OAuth 2.0 response type.
	ResponseType string `json:"response_type" binding:"required"`
	// ClientID identifies the client asking for authorization.
	ClientID string `json:"client_id" binding:"required"`
	// RedirectURI is where the client wants the result delivered.
	RedirectURI string `json:"redirect_uri" binding:"required"`
	// State is the opaque client value echoed back in the response.
	State string `json:"state" binding:"required"`
}
|
|
|
|
// TokenRequest is the form payload accepted by the token endpoint.
// Every field is mandatory (enforced by the "required" binding).
type TokenRequest struct {
	// GrantType is the OAuth 2.0 grant type being used.
	GrantType string `form:"grant_type" binding:"required"`
	// Code is the authorization code being exchanged.
	Code string `form:"code" binding:"required"`
	// RedirectURI must equal the redirect URI stored with the code.
	RedirectURI string `form:"redirect_uri" binding:"required"`
}
|
|
|
|
// CallbackError holds the query parameters used to report an authorization
// failure back to a client; the `url` tags drive the query string encoding.
type CallbackError struct {
	// Error is the machine readable error code.
	Error string `url:"error"`
	// ErrorDescription is the human readable explanation.
	ErrorDescription string `url:"error_description"`
	// State echoes the state value supplied in the authorize request.
	State string `url:"state"`
}
|
|
|
|
func NewOIDCController(config OIDCControllerConfig, router *gin.RouterGroup, queries *repository.Queries) *OIDCController {
|
|
return &OIDCController{
|
|
config: config,
|
|
router: router,
|
|
queries: queries,
|
|
}
|
|
}
|
|
|
|
func (controller *OIDCController) SetupRoutes() {
|
|
oidcGroup := controller.router.Group("/oidc")
|
|
oidcGroup.GET("/clients/:id", controller.GetClientInfo)
|
|
oidcGroup.POST("/authorize", controller.Authorize)
|
|
oidcGroup.POST("/token", controller.Token)
|
|
oidcGroup.GET("/userinfo", controller.Userinfo)
|
|
}
|
|
|
|
// ClientRequest captures the mandatory ":id" path parameter of the client
// info route.
type ClientRequest struct {
	// ClientID is bound from the "id" URI segment.
	ClientID string `uri:"id" binding:"required"`
}
|
|
|
|
func (controller *OIDCController) GetClientInfo(c *gin.Context) {
|
|
var req ClientRequest
|
|
|
|
err := c.BindUri(&req)
|
|
if err != nil {
|
|
tlog.App.Error().Err(err).Msg("Failed to bind URI")
|
|
c.JSON(400, gin.H{
|
|
"status": 400,
|
|
"message": "Bad Request",
|
|
})
|
|
return
|
|
}
|
|
|
|
var client *config.OIDCClientConfig
|
|
|
|
// Inefficient yeah, but it will be good until we have thousands of clients
|
|
for _, clientCfg := range controller.config.Clients {
|
|
if clientCfg.ClientID == req.ClientID {
|
|
client = &clientCfg
|
|
break
|
|
}
|
|
}
|
|
|
|
if client == nil {
|
|
tlog.App.Warn().Str("client_id", req.ClientID).Msg("Client not found")
|
|
c.JSON(404, gin.H{
|
|
"status": 404,
|
|
"message": "Client not found",
|
|
})
|
|
return
|
|
}
|
|
|
|
c.JSON(200, gin.H{
|
|
"status": 200,
|
|
"client": &client.ClientID,
|
|
"name": &client.Name,
|
|
})
|
|
}
|
|
|
|
func (controller *OIDCController) Authorize(c *gin.Context) {
|
|
// Check if we are logged in
|
|
userContext, err := utils.GetContext(c)
|
|
|
|
if err != nil {
|
|
tlog.App.Error().Err(err).Msg("Failed to get user context")
|
|
c.JSON(401, gin.H{
|
|
"status": 401,
|
|
"message": "Unauthorized",
|
|
})
|
|
return
|
|
}
|
|
|
|
// OIDC stuff
|
|
var req AuthorizeRequest
|
|
|
|
err = c.BindJSON(&req)
|
|
if err != nil {
|
|
tlog.App.Error().Err(err).Msg("Failed to bind JSON")
|
|
c.JSON(400, gin.H{
|
|
"status": 400,
|
|
"message": "Bad Request",
|
|
})
|
|
return
|
|
}
|
|
|
|
// TODO: All these errors should redirect to the error page with an explanation
|
|
|
|
// Validate client ID
|
|
var client *config.OIDCClientConfig
|
|
|
|
for _, clientCfg := range controller.config.Clients {
|
|
if clientCfg.ClientID == req.ClientID {
|
|
client = &clientCfg
|
|
break
|
|
}
|
|
}
|
|
|
|
if client == nil {
|
|
tlog.App.Warn().Str("client_id", req.ClientID).Msg("Client not found")
|
|
c.JSON(404, gin.H{
|
|
"status": 404,
|
|
"message": "Client not found",
|
|
})
|
|
return
|
|
}
|
|
|
|
// Validate redirect URI
|
|
if !slices.Contains(client.TrustedRedirectURLs, req.RedirectURI) {
|
|
tlog.App.Warn().Str("redirect_uri", req.RedirectURI).Msg("Redirect URI not trusted")
|
|
c.JSON(400, gin.H{
|
|
"status": 400,
|
|
"message": "Bad Request",
|
|
})
|
|
return
|
|
}
|
|
|
|
// Validate scopes
|
|
reqScopes := strings.Split(req.Scope, " ")
|
|
keptScopes := make([]string, 0)
|
|
|
|
if len(reqScopes) == 0 || strings.TrimSpace(req.Scope) == "" {
|
|
queries, err := query.Values(CallbackError{
|
|
Error: "invalid_request",
|
|
ErrorDescription: "Missing scope parameter",
|
|
State: req.State,
|
|
})
|
|
|
|
if err != nil {
|
|
tlog.App.Error().Err(err).Msg("Failed to build query")
|
|
c.Redirect(302, fmt.Sprintf("%s/error", controller.config.AppURL))
|
|
return
|
|
}
|
|
|
|
c.Redirect(302, fmt.Sprintf("%s/callback?%s", req.RedirectURI, queries.Encode()))
|
|
return
|
|
}
|
|
|
|
for _, scope := range reqScopes {
|
|
if slices.Contains(SupportedScopes, scope) {
|
|
keptScopes = append(keptScopes, scope)
|
|
continue
|
|
}
|
|
tlog.App.Warn().Str("scope", scope).Msg("Scope not supported, ignoring")
|
|
}
|
|
|
|
// Generate a code and a sub
|
|
code, err := utils.GetRandomString(32)
|
|
|
|
if err != nil {
|
|
tlog.App.Error().Err(err).Msg("Failed to generate random string")
|
|
c.Redirect(302, fmt.Sprintf("%s/error", controller.config.AppURL))
|
|
return
|
|
}
|
|
|
|
sub, err := utils.GetRandomInt(10)
|
|
|
|
if err != nil {
|
|
tlog.App.Error().Err(err).Msg("Failed to generate random integer")
|
|
c.Redirect(302, fmt.Sprintf("%s/error", controller.config.AppURL))
|
|
return
|
|
}
|
|
|
|
expiresAt := time.Now().Add(time.Minute * time.Duration(10)).Unix()
|
|
|
|
// Insert the code into the database
|
|
_, err = controller.queries.CreateOidcCode(c, repository.CreateOidcCodeParams{
|
|
Code: code,
|
|
Sub: strconv.Itoa(int(sub)),
|
|
Scope: strings.Join(keptScopes, ","),
|
|
RedirectURI: req.RedirectURI,
|
|
ClientID: client.ClientID,
|
|
ExpiresAt: expiresAt,
|
|
})
|
|
|
|
if err != nil {
|
|
tlog.App.Error().Err(err).Msg("Failed to insert code into database")
|
|
c.Redirect(302, fmt.Sprintf("%s/error", controller.config.AppURL))
|
|
return
|
|
}
|
|
|
|
// We also need a snapshot of the user that authorized this
|
|
userInfoParams := repository.CreateOidcUserInfoParams{
|
|
Sub: strconv.Itoa(int(sub)),
|
|
Name: userContext.Name,
|
|
Email: userContext.Email,
|
|
PreferredUsername: userContext.Username,
|
|
UpdatedAt: time.Now().Unix(),
|
|
}
|
|
|
|
if userContext.Provider == "ldap" {
|
|
userInfoParams.Groups = userContext.LdapGroups
|
|
}
|
|
|
|
if userContext.OAuth && len(userContext.OAuthGroups) > 0 {
|
|
userInfoParams.Groups = userContext.OAuthGroups
|
|
}
|
|
|
|
_, err = controller.queries.CreateOidcUserInfo(c, userInfoParams)
|
|
|
|
if err != nil {
|
|
tlog.App.Error().Err(err).Msg("Failed to insert user info into database")
|
|
c.Redirect(302, fmt.Sprintf("%s/error", controller.config.AppURL))
|
|
return
|
|
}
|
|
|
|
// Return code and done
|
|
c.JSON(200, gin.H{
|
|
"status": 200,
|
|
"message": "Authorized",
|
|
"code": code,
|
|
"state": req.State,
|
|
"redirect_uri": req.RedirectURI,
|
|
})
|
|
}
|
|
|
|
func (controller *OIDCController) Token(c *gin.Context) {
|
|
// Get basic auth
|
|
clientId, clientSecret, ok := c.Request.BasicAuth()
|
|
|
|
if !ok {
|
|
tlog.App.Error().Msg("Missing token verifier")
|
|
c.JSON(400, gin.H{
|
|
"error": "invalid_request",
|
|
})
|
|
return
|
|
}
|
|
|
|
// Ensure client exists
|
|
var client *config.OIDCClientConfig
|
|
|
|
for _, clientCfg := range controller.config.Clients {
|
|
if clientCfg.ClientID == clientId {
|
|
client = &clientCfg
|
|
break
|
|
}
|
|
}
|
|
|
|
if client == nil {
|
|
tlog.App.Warn().Str("client_id", clientId).Msg("Client not found")
|
|
c.JSON(400, gin.H{
|
|
"error": "invalid_request",
|
|
})
|
|
return
|
|
}
|
|
|
|
if client.ClientSecret != clientSecret {
|
|
tlog.App.Warn().Str("client_id", clientId).Msg("Invalid client secret")
|
|
c.JSON(400, gin.H{
|
|
"error": "invalid_client",
|
|
})
|
|
return
|
|
}
|
|
|
|
// Get token
|
|
var req TokenRequest
|
|
|
|
err := c.Bind(&req)
|
|
if err != nil {
|
|
tlog.App.Error().Err(err).Msg("Failed to bind token request")
|
|
c.JSON(400, gin.H{
|
|
"error": "invalid_request",
|
|
})
|
|
return
|
|
}
|
|
|
|
// Validate grant type
|
|
if !slices.Contains(SupportedGrantTypes, req.GrantType) {
|
|
tlog.App.Warn().Str("grant_type", req.GrantType).Msg("Unsupported grant type")
|
|
c.JSON(400, gin.H{
|
|
"error": "unsupported_grant_type",
|
|
})
|
|
return
|
|
}
|
|
|
|
// Find pending code entry
|
|
entry, err := controller.queries.GetOidcCode(c, req.Code)
|
|
|
|
if err != nil {
|
|
tlog.App.Error().Err(err).Msg("Failed to find code in database")
|
|
c.JSON(400, gin.H{
|
|
"error": "invalid_request",
|
|
})
|
|
return
|
|
}
|
|
|
|
// Ensure redirect URIs match
|
|
if entry.RedirectURI != req.RedirectURI {
|
|
tlog.App.Warn().Str("redirect_uri", req.RedirectURI).Msg("Redirect URI mismatch")
|
|
c.JSON(400, gin.H{
|
|
"error": "invalid_request",
|
|
})
|
|
return
|
|
}
|
|
|
|
// Generate access token
|
|
genToken, err := utils.GetRandomString(29)
|
|
|
|
if err != nil {
|
|
tlog.App.Error().Err(err).Msg("Failed to generate access token")
|
|
c.JSON(400, gin.H{
|
|
"error": "invalid_request",
|
|
})
|
|
return
|
|
}
|
|
|
|
// Add tinyauth prefix
|
|
token := fmt.Sprintf("ta-%s", genToken)
|
|
|
|
// TODO: either add a refresh token or customize token expiry
|
|
expiresAt := time.Now().Add(time.Duration(3600) * time.Second).Unix()
|
|
|
|
// Create token entry
|
|
_, err = controller.queries.CreateOidcToken(c, repository.CreateOidcTokenParams{
|
|
Sub: entry.Sub,
|
|
AccessToken: token,
|
|
Scope: entry.Scope,
|
|
ClientID: client.ClientID,
|
|
ExpiresAt: expiresAt,
|
|
})
|
|
|
|
if err != nil {
|
|
tlog.App.Error().Err(err).Msg("Failed to create token in database")
|
|
c.JSON(400, gin.H{
|
|
"error": "invalid_request",
|
|
})
|
|
return
|
|
}
|
|
|
|
// Delete code entry
|
|
err = controller.queries.DeleteOidcCode(c, entry.Code)
|
|
|
|
if err != nil {
|
|
tlog.App.Error().Err(err).Msg("Failed to delete code in database")
|
|
c.JSON(400, gin.H{
|
|
"error": "invalid_request",
|
|
})
|
|
return
|
|
}
|
|
|
|
// Respond with token
|
|
c.JSON(200, gin.H{
|
|
"access_token": token,
|
|
"token_type": "bearer",
|
|
"expires_in": 3600,
|
|
})
|
|
}
|
|
|
|
func (controller *OIDCController) Userinfo(c *gin.Context) {
|
|
// Get bearer
|
|
authorizationHeader := c.GetHeader("Authorization")
|
|
|
|
tokenType, token, ok := strings.Cut(authorizationHeader, " ")
|
|
|
|
if !ok {
|
|
tlog.App.Warn().Msg("OIDC userinfo accessed without authorization header")
|
|
c.JSON(401, gin.H{
|
|
"error": "invalid_request",
|
|
})
|
|
return
|
|
}
|
|
|
|
if strings.ToLower(tokenType) != "bearer" {
|
|
tlog.App.Warn().Msg("OIDC userinfo accessed with invalid token type")
|
|
c.JSON(401, gin.H{
|
|
"error": "invalid_request",
|
|
})
|
|
return
|
|
}
|
|
|
|
// Get token entry
|
|
entry, err := controller.queries.GetOidcToken(c, token)
|
|
|
|
if err != nil {
|
|
tlog.App.Err(err).Msg("Failed to get token entry")
|
|
c.JSON(401, gin.H{
|
|
"error": "invalid_request",
|
|
})
|
|
return
|
|
}
|
|
|
|
// Get scopes
|
|
scopes := strings.Split(entry.Scope, ",")
|
|
|
|
// Check if token is expired
|
|
if time.Now().Unix() > entry.ExpiresAt {
|
|
tlog.App.Warn().Msg("OIDC userinfo accessed with expired token")
|
|
|
|
err = controller.queries.DeleteOidcToken(c, entry.AccessToken)
|
|
if err != nil {
|
|
tlog.App.Err(err).Msg("Failed to delete expired token")
|
|
}
|
|
|
|
err = controller.queries.DeleteOidcUserInfo(c, entry.Sub)
|
|
if err != nil {
|
|
tlog.App.Err(err).Msg("Failed to delete oidc user info")
|
|
}
|
|
|
|
c.JSON(401, gin.H{
|
|
"error": "invalid_request",
|
|
})
|
|
return
|
|
}
|
|
|
|
// Get user info
|
|
user, err := controller.queries.GetOidcUserInfo(c, entry.Sub)
|
|
|
|
if err != nil {
|
|
tlog.App.Err(err).Msg("Failed to get user entry")
|
|
c.JSON(401, gin.H{
|
|
"error": "invalid_request",
|
|
})
|
|
return
|
|
}
|
|
|
|
// If we don't have the openid scope, return an error
|
|
if !slices.Contains(scopes, "openid") {
|
|
tlog.App.Warn().Msg("OIDC userinfo accessed without openid scope")
|
|
c.JSON(401, gin.H{
|
|
"error": "invalid_request",
|
|
})
|
|
return
|
|
}
|
|
|
|
// Let's build the response
|
|
res := map[string]any{
|
|
"sub": user.Sub,
|
|
"updated_at": user.UpdatedAt,
|
|
}
|
|
|
|
// If we have the profile scope, add the profile stuff
|
|
if slices.Contains(scopes, "profile") {
|
|
res["name"] = user.Name
|
|
res["preferred_username"] = user.PreferredUsername
|
|
}
|
|
|
|
// If we have the email scope, add the email stuff
|
|
if slices.Contains(scopes, "email") {
|
|
res["email"] = user.Email
|
|
}
|
|
|
|
// If we have the groups scope, add the groups stuff
|
|
if slices.Contains(scopes, "groups") {
|
|
res["groups"] = user.Groups
|
|
}
|
|
|
|
c.JSON(200, res)
|
|
}
|