refactor: implement oidc following tinyauth patterns

This commit is contained in:
Stavros
2026-01-24 14:31:03 +02:00
parent 97e90ea560
commit c817e353f6
8 changed files with 621 additions and 290 deletions

View File

@@ -1,43 +1,31 @@
package controller
import (
"crypto/rand"
"errors"
"fmt"
"net/http"
"slices"
"strconv"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/google/go-querystring/query"
"github.com/steveiliop56/tinyauth/internal/config"
"github.com/steveiliop56/tinyauth/internal/repository"
"github.com/steveiliop56/tinyauth/internal/service"
"github.com/steveiliop56/tinyauth/internal/utils"
"github.com/steveiliop56/tinyauth/internal/utils/tlog"
)
var (
SupportedResponseTypes = []string{"code"}
SupportedScopes = []string{"openid", "profile", "email", "groups"}
SupportedGrantTypes = []string{"authorization_code"}
)
type OIDCControllerConfig struct {
Clients []config.OIDCClientConfig
AppURL string
}
type OIDCControllerConfig struct{}
type OIDCController struct {
config OIDCControllerConfig
router *gin.RouterGroup
queries *repository.Queries
config OIDCControllerConfig
router *gin.RouterGroup
oidc *service.OIDCService
}
type AuthorizeRequest struct {
Scope string `json:"scope" binding:"required"`
ResponseType string `json:"response_type" binding:"required"`
ClientID string `json:"client_id" binding:"required"`
RedirectURI string `json:"redirect_uri" binding:"required"`
State string `json:"state" binding:"required"`
type AuthorizeCallback struct {
Code string `url:"code"`
State string `url:"state"`
}
type TokenRequest struct {
@@ -52,11 +40,19 @@ type CallbackError struct {
State string `url:"state"`
}
func NewOIDCController(config OIDCControllerConfig, router *gin.RouterGroup, queries *repository.Queries) *OIDCController {
type ErrorScreen struct {
Error string `url:"error"`
}
type ClientRequest struct {
ClientID string `uri:"id" binding:"required"`
}
func NewOIDCController(config OIDCControllerConfig, oidcService *service.OIDCService, router *gin.RouterGroup) *OIDCController {
return &OIDCController{
config: config,
router: router,
queries: queries,
config: config,
oidc: oidcService,
router: router,
}
}
@@ -68,10 +64,6 @@ func (controller *OIDCController) SetupRoutes() {
oidcGroup.GET("/userinfo", controller.Userinfo)
}
type ClientRequest struct {
ClientID string `uri:"id" binding:"required"`
}
func (controller *OIDCController) GetClientInfo(c *gin.Context) {
var req ClientRequest
@@ -85,17 +77,9 @@ func (controller *OIDCController) GetClientInfo(c *gin.Context) {
return
}
var client *config.OIDCClientConfig
client, ok := controller.oidc.GetClient(req.ClientID)
// Inefficient yeah, but it will be good until we have thousands of clients
for _, clientCfg := range controller.config.Clients {
if clientCfg.ClientID == req.ClientID {
client = &clientCfg
break
}
}
if client == nil {
if !ok {
tlog.App.Warn().Str("client_id", req.ClientID).Msg("Client not found")
c.JSON(404, gin.H{
"status": 404,
@@ -106,206 +90,111 @@ func (controller *OIDCController) GetClientInfo(c *gin.Context) {
c.JSON(200, gin.H{
"status": 200,
"client": &client.ClientID,
"name": &client.Name,
"client": client.ClientID,
"name": client.Name,
})
}
func (controller *OIDCController) Authorize(c *gin.Context) {
// Check if we are logged in
userContext, err := utils.GetContext(c)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to get user context")
c.JSON(401, gin.H{
"status": 401,
"message": "Unauthorized",
})
controller.authorizeError(c, err, "Failed to get user context", "User is not logged in or the session is invalid", "", "", "")
return
}
// OIDC stuff
var req AuthorizeRequest
var req service.AuthorizeRequest
err = c.BindJSON(&req)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to bind JSON")
c.JSON(400, gin.H{
"status": 400,
"message": "Bad Request",
})
controller.authorizeError(c, err, "Failed to bind JSON", "The client provided an invalid authorization request", "", "", "")
return
}
// TODO: All these errors should redirect to the error page with an explanation
_, ok := controller.oidc.GetClient(req.ClientID)
// Validate client ID
var client *config.OIDCClientConfig
for _, clientCfg := range controller.config.Clients {
if clientCfg.ClientID == req.ClientID {
client = &clientCfg
break
}
}
if client == nil {
tlog.App.Warn().Str("client_id", req.ClientID).Msg("Client not found")
c.JSON(404, gin.H{
"status": 404,
"message": "Client not found",
})
if !ok {
controller.authorizeError(c, err, "Client not found", "The client ID is invalid", "", "", "")
return
}
// Validate redirect URI
if !slices.Contains(client.TrustedRedirectURLs, req.RedirectURI) {
tlog.App.Warn().Str("redirect_uri", req.RedirectURI).Msg("Redirect URI not trusted")
c.JSON(400, gin.H{
"status": 400,
"message": "Bad Request",
})
return
}
err = controller.oidc.ValidateAuthorizeParams(req)
// Validate scopes
reqScopes := strings.Split(req.Scope, " ")
keptScopes := make([]string, 0)
if len(reqScopes) == 0 || strings.TrimSpace(req.Scope) == "" {
queries, err := query.Values(CallbackError{
Error: "invalid_request",
ErrorDescription: "Missing scope parameter",
State: req.State,
})
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to build query")
c.Redirect(302, fmt.Sprintf("%s/error", controller.config.AppURL))
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to validate authorize params")
if err.Error() != "invalid_request_uri" {
controller.authorizeError(c, err, "Failed validate authorize params", "Invalid request parameters", req.RedirectURI, err.Error(), req.State)
return
}
c.Redirect(302, fmt.Sprintf("%s/callback?%s", req.RedirectURI, queries.Encode()))
controller.authorizeError(c, err, "Redirect URI not trusted", "The provided redirect URI is not trusted", "", "", "")
return
}
for _, scope := range reqScopes {
if slices.Contains(SupportedScopes, scope) {
keptScopes = append(keptScopes, scope)
continue
}
tlog.App.Warn().Str("scope", scope).Msg("Scope not supported, ignoring")
}
// WARNING: Since Tinyauth is stateless, we cannot have a sub that never changes. We will just create a uuid out of the username which remains stable, but if username changes then sub changes too.
sub := utils.GenerateUUID(userContext.Username)
code := rand.Text()
// Generate a code and a sub
code, err := utils.GetRandomString(32)
err = controller.oidc.StoreCode(c, sub, code, req)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to generate random string")
c.Redirect(302, fmt.Sprintf("%s/error", controller.config.AppURL))
return
}
sub, err := utils.GetRandomInt(10)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to generate random integer")
c.Redirect(302, fmt.Sprintf("%s/error", controller.config.AppURL))
return
}
expiresAt := time.Now().Add(time.Minute * time.Duration(10)).Unix()
// Insert the code into the database
_, err = controller.queries.CreateOidcCode(c, repository.CreateOidcCodeParams{
Code: code,
Sub: strconv.Itoa(int(sub)),
Scope: strings.Join(keptScopes, ","),
RedirectURI: req.RedirectURI,
ClientID: client.ClientID,
ExpiresAt: expiresAt,
})
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to insert code into database")
c.Redirect(302, fmt.Sprintf("%s/error", controller.config.AppURL))
controller.authorizeError(c, err, "Failed to store code", "Failed to store code", req.RedirectURI, "server_error", req.State)
return
}
// We also need a snapshot of the user that authorized this
userInfoParams := repository.CreateOidcUserInfoParams{
Sub: strconv.Itoa(int(sub)),
Name: userContext.Name,
Email: userContext.Email,
PreferredUsername: userContext.Username,
UpdatedAt: time.Now().Unix(),
}
if userContext.Provider == "ldap" {
userInfoParams.Groups = userContext.LdapGroups
}
if userContext.OAuth && len(userContext.OAuthGroups) > 0 {
userInfoParams.Groups = userContext.OAuthGroups
}
_, err = controller.queries.CreateOidcUserInfo(c, userInfoParams)
err = controller.oidc.StoreUserinfo(c, sub, userContext, req)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to insert user info into database")
c.Redirect(302, fmt.Sprintf("%s/error", controller.config.AppURL))
controller.authorizeError(c, err, "Failed to store user info", "Failed to store user info", req.RedirectURI, "server_error", req.State)
return
}
queries, err := query.Values(AuthorizeCallback{
Code: code,
State: req.State,
})
if err != nil {
controller.authorizeError(c, err, "Failed to build query", "Failed to build query", req.RedirectURI, "server_error", req.State)
return
}
// Return code and done
c.JSON(200, gin.H{
"status": 200,
"message": "Authorized",
"code": code,
"state": req.State,
"redirect_uri": req.RedirectURI,
"redirect_uri": fmt.Sprintf("%s?%s", req.RedirectURI, queries.Encode()),
})
}
func (controller *OIDCController) Token(c *gin.Context) {
// Get basic auth
clientId, clientSecret, ok := c.Request.BasicAuth()
rclientId, rclientSecret, ok := c.Request.BasicAuth()
if !ok {
tlog.App.Error().Msg("Missing token verifier")
tlog.App.Error().Msg("Missing authorization header")
c.JSON(400, gin.H{
"error": "invalid_request",
})
return
}
// Ensure client exists
var client *config.OIDCClientConfig
client, ok := controller.oidc.GetClient(rclientId)
for _, clientCfg := range controller.config.Clients {
if clientCfg.ClientID == clientId {
client = &clientCfg
break
}
}
if client == nil {
tlog.App.Warn().Str("client_id", clientId).Msg("Client not found")
if !ok {
tlog.App.Warn().Str("client_id", rclientId).Msg("Client not found")
c.JSON(400, gin.H{
"error": "invalid_request",
"error": "access_denied",
})
return
}
if client.ClientSecret != clientSecret {
tlog.App.Warn().Str("client_id", clientId).Msg("Invalid client secret")
if client.ClientSecret != rclientSecret {
tlog.App.Warn().Str("client_id", rclientId).Msg("Invalid client secret")
c.JSON(400, gin.H{
"error": "invalid_client",
"error": "access_denied",
})
return
}
// Get token
var req TokenRequest
err := c.Bind(&req)
@@ -317,93 +206,73 @@ func (controller *OIDCController) Token(c *gin.Context) {
return
}
// Validate grant type
if !slices.Contains(SupportedGrantTypes, req.GrantType) {
err = controller.oidc.ValidateGrantType(req.GrantType)
if err != nil {
tlog.App.Warn().Str("grant_type", req.GrantType).Msg("Unsupported grant type")
c.JSON(400, gin.H{
"error": "unsupported_grant_type",
"error": err.Error(),
})
return
}
// Find pending code entry
entry, err := controller.queries.GetOidcCode(c, req.Code)
entry, err := controller.oidc.GetCodeEntry(c, req.Code)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to find code in database")
if errors.Is(err, service.ErrCodeExpired) {
tlog.App.Warn().Str("code", req.Code).Msg("Code expired")
c.JSON(400, gin.H{
"error": "access_denied",
})
return
}
if errors.Is(err, service.ErrCodeNotFound) {
tlog.App.Warn().Str("code", req.Code).Msg("Code not found")
c.JSON(400, gin.H{
"error": "access_denied",
})
return
}
tlog.App.Warn().Err(err).Msg("Failed to get OIDC code entry")
c.JSON(400, gin.H{
"error": "invalid_request",
"error": "server_error",
})
return
}
// Ensure redirect URIs match
if entry.RedirectURI != req.RedirectURI {
tlog.App.Warn().Str("redirect_uri", req.RedirectURI).Msg("Redirect URI mismatch")
c.JSON(400, gin.H{
"error": "invalid_request",
"error": "invalid_request_uri",
})
return
}
// Generate access token
genToken, err := utils.GetRandomString(29)
accessToken, err := controller.oidc.GenerateAccessToken(c, client, entry.Sub, entry.Scope)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to generate access token")
c.JSON(400, gin.H{
"error": "invalid_request",
"error": "server_error",
})
return
}
// Add tinyauth prefix
token := fmt.Sprintf("ta-%s", genToken)
// TODO: either add a refresh token or customize token expiry
expiresAt := time.Now().Add(time.Duration(3600) * time.Second).Unix()
// Create token entry
_, err = controller.queries.CreateOidcToken(c, repository.CreateOidcTokenParams{
Sub: entry.Sub,
AccessToken: token,
Scope: entry.Scope,
ClientID: client.ClientID,
ExpiresAt: expiresAt,
})
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to create token in database")
c.JSON(400, gin.H{
"error": "invalid_request",
})
return
}
// Delete code entry
err = controller.queries.DeleteOidcCode(c, entry.Code)
err = controller.oidc.DeleteCodeEntry(c, entry.Code)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to delete code in database")
c.JSON(400, gin.H{
"error": "invalid_request",
"error": "server_error",
})
return
}
// Respond with token
c.JSON(200, gin.H{
"access_token": token,
"token_type": "bearer",
"expires_in": 3600,
})
c.JSON(200, accessToken)
}
func (controller *OIDCController) Userinfo(c *gin.Context) {
// Get bearer
authorizationHeader := c.GetHeader("Authorization")
authorization := c.GetHeader("Authorization")
tokenType, token, ok := strings.Cut(authorizationHeader, " ")
tokenType, token, ok := strings.Cut(authorization, " ")
if !ok {
tlog.App.Warn().Msg("OIDC userinfo accessed without authorization header")
@@ -421,53 +290,36 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
return
}
// Get token entry
entry, err := controller.queries.GetOidcToken(c, token)
entry, err := controller.oidc.GetAccessToken(c, token)
if err != nil {
if err == service.ErrTokenNotFound {
tlog.App.Warn().Msg("OIDC userinfo accessed with invalid token")
c.JSON(401, gin.H{
"error": "invalid_request",
})
return
}
tlog.App.Err(err).Msg("Failed to get token entry")
c.JSON(401, gin.H{
"error": "invalid_request",
"error": "server_error",
})
return
}
// Get scopes
scopes := strings.Split(entry.Scope, ",")
// Check if token is expired
if time.Now().Unix() > entry.ExpiresAt {
tlog.App.Warn().Msg("OIDC userinfo accessed with expired token")
err = controller.queries.DeleteOidcToken(c, entry.AccessToken)
if err != nil {
tlog.App.Err(err).Msg("Failed to delete expired token")
}
err = controller.queries.DeleteOidcUserInfo(c, entry.Sub)
if err != nil {
tlog.App.Err(err).Msg("Failed to delete oidc user info")
}
c.JSON(401, gin.H{
"error": "invalid_request",
})
return
}
// Get user info
user, err := controller.queries.GetOidcUserInfo(c, entry.Sub)
user, err := controller.oidc.GetUserinfo(c, entry.Sub)
if err != nil {
tlog.App.Err(err).Msg("Failed to get user entry")
c.JSON(401, gin.H{
"error": "invalid_request",
"error": "server_error",
})
return
}
// If we don't have the openid scope, return an error
if !slices.Contains(scopes, "openid") {
if !slices.Contains(strings.Split(entry.Scope, ","), "openid") {
tlog.App.Warn().Msg("OIDC userinfo accessed without openid scope")
c.JSON(401, gin.H{
"error": "invalid_request",
@@ -475,27 +327,52 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
return
}
// Let's build the response
res := map[string]any{
"sub": user.Sub,
"updated_at": user.UpdatedAt,
}
// If we have the profile scope, add the profile stuff
if slices.Contains(scopes, "profile") {
res["name"] = user.Name
res["preferred_username"] = user.PreferredUsername
}
// If we have the email scope, add the email stuff
if slices.Contains(scopes, "email") {
res["email"] = user.Email
}
// If we have the groups scope, add the groups stuff
if slices.Contains(scopes, "groups") {
res["groups"] = user.Groups
}
c.JSON(200, res)
c.JSON(200, controller.oidc.CompileUserinfo(user, entry.Scope))
}
func (controller *OIDCController) authorizeError(c *gin.Context, err error, reason string, reasonUser string, callback string, callbackError string, state string) {
tlog.App.Error().Err(err).Msg(reason)
if callback != "" {
errorQueries := CallbackError{
Error: callbackError,
}
if reasonUser != "" {
errorQueries.ErrorDescription = reasonUser
}
if state != "" {
errorQueries.State = state
}
queries, err := query.Values(errorQueries)
if err != nil {
c.AbortWithStatus(http.StatusInternalServerError)
return
}
c.JSON(200, gin.H{
"status": 200,
"redirect_uri": fmt.Sprintf("%s/?%s", callback, queries.Encode()),
})
return
}
errorQueries := ErrorScreen{
Error: reasonUser,
}
queries, err := query.Values(errorQueries)
if err != nil {
c.AbortWithStatus(http.StatusInternalServerError)
return
}
c.JSON(200, gin.H{
"status": 200,
"redirect_uri": fmt.Sprintf("%s/error?%s", controller.oidc.GetIssuer(), queries.Encode()),
})
}