From 97e90ea56028cf6ec70b9daf4c1510280da0a27d Mon Sep 17 00:00:00 2001 From: Stavros Date: Thu, 22 Jan 2026 22:30:23 +0200 Subject: [PATCH] feat: implement basic oidc functionality --- frontend/src/pages/authorize-page.tsx | 52 ++- .../migrations/000005_oidc_session.down.sql | 3 + .../migrations/000005_oidc_session.up.sql | 25 + internal/bootstrap/app_bootstrap.go | 2 +- internal/bootstrap/router_bootstrap.go | 6 +- internal/controller/oidc_controller.go | 438 +++++++++++++++++- internal/middleware/context_middleware.go | 10 + internal/repository/models.go | 26 ++ internal/repository/oidc_queries.sql.go | 224 +++++++++ ...{queries.sql.go => session_queries.sql.go} | 4 +- internal/utils/security_utils.go | 28 ++ internal/utils/security_utils_test.go | 23 + sql/oidc_queries.sql | 61 +++ sql/oidc_schemas.sql | 25 + sql/{queries.sql => session_queries.sql} | 2 +- sql/{schema.sql => session_schemas.sql} | 0 sqlc.yml | 5 +- 17 files changed, 916 insertions(+), 18 deletions(-) create mode 100644 internal/assets/migrations/000005_oidc_session.down.sql create mode 100644 internal/assets/migrations/000005_oidc_session.up.sql create mode 100644 internal/repository/oidc_queries.sql.go rename internal/repository/{queries.sql.go => session_queries.sql.go} (98%) create mode 100644 sql/oidc_queries.sql create mode 100644 sql/oidc_schemas.sql rename sql/{queries.sql => session_queries.sql} (96%) rename sql/{schema.sql => session_schemas.sql} (100%) diff --git a/frontend/src/pages/authorize-page.tsx b/frontend/src/pages/authorize-page.tsx index 6befa96..7ada730 100644 --- a/frontend/src/pages/authorize-page.tsx +++ b/frontend/src/pages/authorize-page.tsx @@ -1,6 +1,6 @@ import { useUserContext } from "@/context/user-context"; -import { useQuery } from "@tanstack/react-query"; -import { Navigate } from "react-router"; +import { useMutation, useQuery } from "@tanstack/react-query"; +import { Navigate, useNavigate } from "react-router"; import { useLocation } from "react-router"; import { Card, @@ -11,6 +11,8 @@ import { } from "@/components/ui/card"; import { getOidcClientInfoScehma } from "@/schemas/oidc-schemas"; import { Button } from "@/components/ui/button"; +import axios from "axios"; +import { toast } from "sonner"; type AuthorizePageProps = { scope: string; @@ -25,6 +27,7 @@ const optionalAuthorizeProps = ["state"]; export const AuthorizePage = () => { const { isLoggedIn } = useUserContext(); const { search } = useLocation(); + const navigate = useNavigate(); const searchParams = new URLSearchParams(search); @@ -46,12 +49,38 @@ export const AuthorizePage = () => { }, }); + const authorizeMutation = useMutation({ + mutationFn: () => { + return axios.post("/api/oidc/authorize", { + scope: props.scope, + response_type: props.responseType, + client_id: props.clientId, + redirect_uri: props.redirectUri, + state: props.state, + }); + }, + mutationKey: ["authorize", props.clientId], + onSuccess: (data) => { + toast.info("Authorized", { + description: "You will be soon redirected to your application", + }); + window.location.replace( + `${data.data.redirect_uri}?code=${encodeURIComponent(data.data.code)}&state=${encodeURIComponent(data.data.state)}`, + ); + }, + onError: (error) => { + window.location.replace( + `/error?error=${encodeURIComponent(error.message)}`, + ); + }, + }); + if (!isLoggedIn) { // TODO: Pass the params to the login page, so user can login -> authorize return ; } - for (const key in Object.keys(props)) { + Object.keys(props).forEach((key) => { if ( !props[key as keyof AuthorizePageProps] && !optionalAuthorizeProps.includes(key) @@ -59,7 +88,7 @@ export const AuthorizePage = () => { // TODO: Add reason for error return ; } - } + }); if (getClientInfo.isLoading) { return ( @@ -91,8 +120,19 @@ export const AuthorizePage = () => { - - + + ); diff --git a/internal/assets/migrations/000005_oidc_session.down.sql b/internal/assets/migrations/000005_oidc_session.down.sql new file mode 100644 index 0000000..68a3248 --- /dev/null +++ b/internal/assets/migrations/000005_oidc_session.down.sql @@ -0,0 +1,3 @@ +DROP TABLE IF EXISTS "oidc_tokens"; +DROP TABLE IF EXISTS "oidc_userinfo"; +DROP TABLE IF EXISTS "oidc_codes"; diff --git a/internal/assets/migrations/000005_oidc_session.up.sql b/internal/assets/migrations/000005_oidc_session.up.sql new file mode 100644 index 0000000..01fa8a3 --- /dev/null +++ b/internal/assets/migrations/000005_oidc_session.up.sql @@ -0,0 +1,25 @@ +CREATE TABLE IF NOT EXISTS "oidc_codes" ( + "sub" TEXT NOT NULL UNIQUE, + "code" TEXT NOT NULL PRIMARY KEY UNIQUE, + "scope" TEXT NOT NULL, + "redirect_uri" TEXT NOT NULL, + "client_id" TEXT NOT NULL, + "expires_at" INTEGER NOT NULL +); + +CREATE TABLE IF NOT EXISTS "oidc_tokens" ( + "sub" TEXT NOT NULL UNIQUE, + "access_token" TEXT NOT NULL PRIMARY KEY UNIQUE, + "scope" TEXT NOT NULL, + "client_id" TEXT NOT NULL, + "expires_at" INTEGER NOT NULL +); + +CREATE TABLE IF NOT EXISTS "oidc_userinfo" ( + "sub" TEXT NOT NULL UNIQUE PRIMARY KEY, + "name" TEXT NOT NULL, + "preferred_username" TEXT NOT NULL, + "email" TEXT NOT NULL, + "groups" TEXT NOT NULL, + "updated_at" INTEGER NOT NULL +); diff --git a/internal/bootstrap/app_bootstrap.go b/internal/bootstrap/app_bootstrap.go index e9cdd5a..31473c9 100644 --- a/internal/bootstrap/app_bootstrap.go +++ b/internal/bootstrap/app_bootstrap.go @@ -176,7 +176,7 @@ func (app *BootstrapApp) Setup() error { app.context.configuredProviders = configuredProviders // Setup router - router, err := app.setupRouter() + router, err := app.setupRouter(queries) if err != nil { return fmt.Errorf("failed to setup routes: %w", err) diff --git a/internal/bootstrap/router_bootstrap.go b/internal/bootstrap/router_bootstrap.go index c854c45..f6747c8 100644 --- a/internal/bootstrap/router_bootstrap.go +++ b/internal/bootstrap/router_bootstrap.go @@ -7,13 +7,14 @@ import ( "github.com/steveiliop56/tinyauth/internal/config" "github.com/steveiliop56/tinyauth/internal/controller" "github.com/steveiliop56/tinyauth/internal/middleware" + "github.com/steveiliop56/tinyauth/internal/repository" "github.com/gin-gonic/gin" ) var DEV_MODES = []string{"main", "test", "development"} -func (app *BootstrapApp) setupRouter() (*gin.Engine, error) { +func (app *BootstrapApp) setupRouter(queries *repository.Queries) (*gin.Engine, error) { if !slices.Contains(DEV_MODES, config.Version) { gin.SetMode(gin.ReleaseMode) } @@ -88,7 +89,8 @@ func (app *BootstrapApp) setupRouter() (*gin.Engine, error) { oidcController := controller.NewOIDCController(controller.OIDCControllerConfig{ Clients: app.context.oidcClients, - }, apiRouter) + AppURL: app.config.AppURL, + }, apiRouter, queries) oidcController.SetupRoutes() diff --git a/internal/controller/oidc_controller.go b/internal/controller/oidc_controller.go index 8fbf2ce..26b6966 100644 --- a/internal/controller/oidc_controller.go +++ b/internal/controller/oidc_controller.go @@ -1,30 +1,71 @@ package controller import ( + "fmt" + "slices" + "strconv" + "strings" + "time" + "github.com/gin-gonic/gin" + "github.com/google/go-querystring/query" "github.com/steveiliop56/tinyauth/internal/config" + "github.com/steveiliop56/tinyauth/internal/repository" + "github.com/steveiliop56/tinyauth/internal/utils" "github.com/steveiliop56/tinyauth/internal/utils/tlog" ) +var ( + SupportedResponseTypes = []string{"code"} + SupportedScopes = []string{"openid", "profile", "email", "groups"} + SupportedGrantTypes = []string{"authorization_code"} +) + type OIDCControllerConfig struct { Clients []config.OIDCClientConfig + AppURL string } type OIDCController struct { - clients []config.OIDCClientConfig + config OIDCControllerConfig router *gin.RouterGroup + queries *repository.Queries } -func NewOIDCController(config OIDCControllerConfig, router *gin.RouterGroup) *OIDCController { +type AuthorizeRequest struct { + Scope string `json:"scope" binding:"required"` + ResponseType string `json:"response_type" binding:"required"` + ClientID string `json:"client_id" binding:"required"` + RedirectURI string `json:"redirect_uri" binding:"required"` + State string `json:"state" binding:"required"` +} + +type TokenRequest struct { + GrantType string `form:"grant_type" binding:"required"` + Code string `form:"code" binding:"required"` + RedirectURI string `form:"redirect_uri" binding:"required"` +} + +type CallbackError struct { + Error string `url:"error"` + ErrorDescription string `url:"error_description"` + State string `url:"state"` +} + +func NewOIDCController(config OIDCControllerConfig, router *gin.RouterGroup, queries *repository.Queries) *OIDCController { return &OIDCController{ - clients: config.Clients, + config: config, router: router, + queries: queries, } } func (controller *OIDCController) SetupRoutes() { oidcGroup := controller.router.Group("/oidc") oidcGroup.GET("/clients/:id", controller.GetClientInfo) + oidcGroup.POST("/authorize", controller.Authorize) + oidcGroup.POST("/token", controller.Token) + oidcGroup.GET("/userinfo", controller.Userinfo) } type ClientRequest struct { @@ -47,7 +88,7 @@ func (controller *OIDCController) GetClientInfo(c *gin.Context) { var client *config.OIDCClientConfig // Inefficient yeah, but it will be good until we have thousands of clients - for _, clientCfg := range controller.clients { + for _, clientCfg := range controller.config.Clients { if clientCfg.ClientID == req.ClientID { client = &clientCfg break @@ -69,3 +110,392 @@ func (controller *OIDCController) GetClientInfo(c *gin.Context) { "name": &client.Name, }) } + +func (controller *OIDCController) Authorize(c *gin.Context) { + // Check if we are logged in + userContext, err := utils.GetContext(c) + + if err != nil { + tlog.App.Error().Err(err).Msg("Failed to get user context") + c.JSON(401, gin.H{ + "status": 401, + "message": "Unauthorized", + }) + return + } + + // OIDC stuff + var req AuthorizeRequest + + err = c.BindJSON(&req) + if err != nil { + tlog.App.Error().Err(err).Msg("Failed to bind JSON") + c.JSON(400, gin.H{ + "status": 400, + "message": "Bad Request", + }) + return + } + + // TODO: All these errors should redirect to the error page with an explanation + + // Validate client ID + var client *config.OIDCClientConfig + + for _, clientCfg := range controller.config.Clients { + if clientCfg.ClientID == req.ClientID { + client = &clientCfg + break + } + } + + if client == nil { + tlog.App.Warn().Str("client_id", req.ClientID).Msg("Client not found") + c.JSON(404, gin.H{ + "status": 404, + "message": "Client not found", + }) + return + } + + // Validate redirect URI + if !slices.Contains(client.TrustedRedirectURLs, req.RedirectURI) { + tlog.App.Warn().Str("redirect_uri", req.RedirectURI).Msg("Redirect URI not trusted") + c.JSON(400, gin.H{ + "status": 400, + "message": "Bad Request", + }) + return + } + + // Validate scopes + reqScopes := strings.Split(req.Scope, " ") + keptScopes := make([]string, 0) + + if len(reqScopes) == 0 || strings.TrimSpace(req.Scope) == "" { + queries, err := query.Values(CallbackError{ + Error: "invalid_request", + ErrorDescription: "Missing scope parameter", + State: req.State, + }) + + if err != nil { + tlog.App.Error().Err(err).Msg("Failed to build query") + c.Redirect(302, fmt.Sprintf("%s/error", controller.config.AppURL)) + return + } + + c.Redirect(302, fmt.Sprintf("%s/callback?%s", req.RedirectURI, queries.Encode())) + return + } + + for _, scope := range reqScopes { + if slices.Contains(SupportedScopes, scope) { + keptScopes = append(keptScopes, scope) + continue + } + tlog.App.Warn().Str("scope", scope).Msg("Scope not supported, ignoring") + } + + // Generate a code and a sub + code, err := utils.GetRandomString(32) + + if err != nil { + tlog.App.Error().Err(err).Msg("Failed to generate random string") + c.Redirect(302, fmt.Sprintf("%s/error", controller.config.AppURL)) + return + } + + sub, err := utils.GetRandomInt(10) + + if err != nil { + tlog.App.Error().Err(err).Msg("Failed to generate random integer") + c.Redirect(302, fmt.Sprintf("%s/error", controller.config.AppURL)) + return + } + + expiresAt := time.Now().Add(time.Minute * time.Duration(10)).Unix() + + // Insert the code into the database + _, err = controller.queries.CreateOidcCode(c, repository.CreateOidcCodeParams{ + Code: code, + Sub: strconv.Itoa(int(sub)), + Scope: strings.Join(keptScopes, ","), + RedirectURI: req.RedirectURI, + ClientID: client.ClientID, + ExpiresAt: expiresAt, + }) + + if err != nil { + tlog.App.Error().Err(err).Msg("Failed to insert code into database") + c.Redirect(302, fmt.Sprintf("%s/error", controller.config.AppURL)) + return + } + + // We also need a snapshot of the user that authorized this + userInfoParams := repository.CreateOidcUserInfoParams{ + Sub: strconv.Itoa(int(sub)), + Name: userContext.Name, + Email: userContext.Email, + PreferredUsername: userContext.Username, + UpdatedAt: time.Now().Unix(), + } + + if userContext.Provider == "ldap" { + userInfoParams.Groups = userContext.LdapGroups + } + + if userContext.OAuth && len(userContext.OAuthGroups) > 0 { + userInfoParams.Groups = userContext.OAuthGroups + } + + _, err = controller.queries.CreateOidcUserInfo(c, userInfoParams) + + if err != nil { + tlog.App.Error().Err(err).Msg("Failed to insert user info into database") + c.Redirect(302, fmt.Sprintf("%s/error", controller.config.AppURL)) + return + } + + // Return code and done + c.JSON(200, gin.H{ + "status": 200, + "message": "Authorized", + "code": code, + "state": req.State, + "redirect_uri": req.RedirectURI, + }) +} + +func (controller *OIDCController) Token(c *gin.Context) { + // Get basic auth + clientId, clientSecret, ok := c.Request.BasicAuth() + + if !ok { + tlog.App.Error().Msg("Missing token verifier") + c.JSON(400, gin.H{ + "error": "invalid_request", + }) + return + } + + // Ensure client exists + var client *config.OIDCClientConfig + + for _, clientCfg := range controller.config.Clients { + if clientCfg.ClientID == clientId { + client = &clientCfg + break + } + } + + if client == nil { + tlog.App.Warn().Str("client_id", clientId).Msg("Client not found") + c.JSON(400, gin.H{ + "error": "invalid_request", + }) + return + } + + if client.ClientSecret != clientSecret { + tlog.App.Warn().Str("client_id", clientId).Msg("Invalid client secret") + c.JSON(400, gin.H{ + "error": "invalid_client", + }) + return + } + + // Get token + var req TokenRequest + + err := c.Bind(&req) + if err != nil { + tlog.App.Error().Err(err).Msg("Failed to bind token request") + c.JSON(400, gin.H{ + "error": "invalid_request", + }) + return + } + + // Validate grant type + if !slices.Contains(SupportedGrantTypes, req.GrantType) { + tlog.App.Warn().Str("grant_type", req.GrantType).Msg("Unsupported grant type") + c.JSON(400, gin.H{ + "error": "unsupported_grant_type", + }) + return + } + + // Find pending code entry + entry, err := controller.queries.GetOidcCode(c, req.Code) + + if err != nil { + tlog.App.Error().Err(err).Msg("Failed to find code in database") + c.JSON(400, gin.H{ + "error": "invalid_request", + }) + return + } + + // Ensure redirect URIs match + if entry.RedirectURI != req.RedirectURI { + tlog.App.Warn().Str("redirect_uri", req.RedirectURI).Msg("Redirect URI mismatch") + c.JSON(400, gin.H{ + "error": "invalid_request", + }) + return + } + + // Generate access token + genToken, err := utils.GetRandomString(29) + + if err != nil { + tlog.App.Error().Err(err).Msg("Failed to generate access token") + c.JSON(400, gin.H{ + "error": "invalid_request", + }) + return + } + + // Add tinyauth prefix + token := fmt.Sprintf("ta-%s", genToken) + + // TODO: either add a refresh token or customize token expiry + expiresAt := time.Now().Add(time.Duration(3600) * time.Second).Unix() + + // Create token entry + _, err = controller.queries.CreateOidcToken(c, repository.CreateOidcTokenParams{ + Sub: entry.Sub, + AccessToken: token, + Scope: entry.Scope, + ClientID: client.ClientID, + ExpiresAt: expiresAt, + }) + + if err != nil { + tlog.App.Error().Err(err).Msg("Failed to create token in database") + c.JSON(400, gin.H{ + "error": "invalid_request", + }) + return + } + + // Delete code entry + err = controller.queries.DeleteOidcCode(c, entry.Code) + + if err != nil { + tlog.App.Error().Err(err).Msg("Failed to delete code in database") + c.JSON(400, gin.H{ + "error": "invalid_request", + }) + return + } + + // Respond with token + c.JSON(200, gin.H{ + "access_token": token, + "token_type": "bearer", + "expires_in": 3600, + }) +} + +func (controller *OIDCController) Userinfo(c *gin.Context) { + // Get bearer + authorizationHeader := c.GetHeader("Authorization") + + tokenType, token, ok := strings.Cut(authorizationHeader, " ") + + if !ok { + tlog.App.Warn().Msg("OIDC userinfo accessed without authorization header") + c.JSON(401, gin.H{ + "error": "invalid_request", + }) + return + } + + if strings.ToLower(tokenType) != "bearer" { + tlog.App.Warn().Msg("OIDC userinfo accessed with invalid token type") + c.JSON(401, gin.H{ + "error": "invalid_request", + }) + return + } + + // Get token entry + entry, err := controller.queries.GetOidcToken(c, token) + + if err != nil { + tlog.App.Err(err).Msg("Failed to get token entry") + c.JSON(401, gin.H{ + "error": "invalid_request", + }) + return + } + + // Get scopes + scopes := strings.Split(entry.Scope, ",") + + // Check if token is expired + if time.Now().Unix() > entry.ExpiresAt { + tlog.App.Warn().Msg("OIDC userinfo accessed with expired token") + + err = controller.queries.DeleteOidcToken(c, entry.AccessToken) + if err != nil { + tlog.App.Err(err).Msg("Failed to delete expired token") + } + + err = controller.queries.DeleteOidcUserInfo(c, entry.Sub) + if err != nil { + tlog.App.Err(err).Msg("Failed to delete oidc user info") + } + + c.JSON(401, gin.H{ + "error": "invalid_request", + }) + return + } + + // Get user info + user, err := controller.queries.GetOidcUserInfo(c, entry.Sub) + + if err != nil { + tlog.App.Err(err).Msg("Failed to get user entry") + c.JSON(401, gin.H{ + "error": "invalid_request", + }) + return + } + + // If we don't have the openid scope, return an error + if !slices.Contains(scopes, "openid") { + tlog.App.Warn().Msg("OIDC userinfo accessed without openid scope") + c.JSON(401, gin.H{ + "error": "invalid_request", + }) + return + } + + // Let's build the response + res := map[string]any{ + "sub": user.Sub, + "updated_at": user.UpdatedAt, + } + + // If we have the profile scope, add the profile stuff + if slices.Contains(scopes, "profile") { + res["name"] = user.Name + res["preferred_username"] = user.PreferredUsername + } + + // If we have the email scope, add the email stuff + if slices.Contains(scopes, "email") { + res["email"] = user.Email + } + + // If we have the groups scope, add the groups stuff + if slices.Contains(scopes, "groups") { + res["groups"] = user.Groups + } + + c.JSON(200, res) +} diff --git a/internal/middleware/context_middleware.go b/internal/middleware/context_middleware.go index 4d392c8..fc71c05 100644 --- a/internal/middleware/context_middleware.go +++ b/internal/middleware/context_middleware.go @@ -2,6 +2,7 @@ package middleware import ( "fmt" + "slices" "strings" "time" @@ -13,6 +14,8 @@ import ( "github.com/gin-gonic/gin" ) +var OIDCIgnorePaths = []string{"/api/oidc/token", "/api/oidc/userinfo"} + type ContextMiddlewareConfig struct { CookieDomain string } @@ -37,6 +40,13 @@ func (m *ContextMiddleware) Init() error { func (m *ContextMiddleware) Middleware() gin.HandlerFunc { return func(c *gin.Context) { + // There is no point in trying to get credentials if it's an OIDC endpoint + path := c.Request.URL.Path + if slices.Contains(OIDCIgnorePaths, path) { + c.Next() + return + } + cookie, err := m.auth.GetSessionCookie(c) if err != nil { diff --git a/internal/repository/models.go b/internal/repository/models.go index 61f7f80..3380645 100644 --- a/internal/repository/models.go +++ b/internal/repository/models.go @@ -4,6 +4,32 @@ package repository +type OidcCode struct { + Sub string + Code string + Scope string + RedirectURI string + ClientID string + ExpiresAt int64 +} + +type OidcToken struct { + Sub string + AccessToken string + Scope string + ClientID string + ExpiresAt int64 +} + +type OidcUserinfo struct { + Sub string + Name string + PreferredUsername string + Email string + Groups string + UpdatedAt int64 +} + type Session struct { UUID string Username string diff --git a/internal/repository/oidc_queries.sql.go b/internal/repository/oidc_queries.sql.go new file mode 100644 index 0000000..510981f --- /dev/null +++ b/internal/repository/oidc_queries.sql.go @@ -0,0 +1,224 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.30.0 +// source: oidc_queries.sql + +package repository + +import ( + "context" +) + +const createOidcCode = `-- name: CreateOidcCode :one +INSERT INTO "oidc_codes" ( + "sub", + "code", + "scope", + "redirect_uri", + "client_id", + "expires_at" +) VALUES ( + ?, ?, ?, ?, ?, ? +) +RETURNING sub, code, scope, redirect_uri, client_id, expires_at +` + +type CreateOidcCodeParams struct { + Sub string + Code string + Scope string + RedirectURI string + ClientID string + ExpiresAt int64 +} + +func (q *Queries) CreateOidcCode(ctx context.Context, arg CreateOidcCodeParams) (OidcCode, error) { + row := q.db.QueryRowContext(ctx, createOidcCode, + arg.Sub, + arg.Code, + arg.Scope, + arg.RedirectURI, + arg.ClientID, + arg.ExpiresAt, + ) + var i OidcCode + err := row.Scan( + &i.Sub, + &i.Code, + &i.Scope, + &i.RedirectURI, + &i.ClientID, + &i.ExpiresAt, + ) + return i, err +} + +const createOidcToken = `-- name: CreateOidcToken :one +INSERT INTO "oidc_tokens" ( + "sub", + "access_token", + "scope", + "client_id", + "expires_at" +) VALUES ( + ?, ?, ?, ?, ? +) +RETURNING sub, access_token, scope, client_id, expires_at +` + +type CreateOidcTokenParams struct { + Sub string + AccessToken string + Scope string + ClientID string + ExpiresAt int64 +} + +func (q *Queries) CreateOidcToken(ctx context.Context, arg CreateOidcTokenParams) (OidcToken, error) { + row := q.db.QueryRowContext(ctx, createOidcToken, + arg.Sub, + arg.AccessToken, + arg.Scope, + arg.ClientID, + arg.ExpiresAt, + ) + var i OidcToken + err := row.Scan( + &i.Sub, + &i.AccessToken, + &i.Scope, + &i.ClientID, + &i.ExpiresAt, + ) + return i, err +} + +const createOidcUserInfo = `-- name: CreateOidcUserInfo :one +INSERT INTO "oidc_userinfo" ( + "sub", + "name", + "preferred_username", + "email", + "groups", + "updated_at" +) VALUES ( + ?, ?, ?, ?, ?, ? +) +RETURNING sub, name, preferred_username, email, "groups", updated_at +` + +type CreateOidcUserInfoParams struct { + Sub string + Name string + PreferredUsername string + Email string + Groups string + UpdatedAt int64 +} + +func (q *Queries) CreateOidcUserInfo(ctx context.Context, arg CreateOidcUserInfoParams) (OidcUserinfo, error) { + row := q.db.QueryRowContext(ctx, createOidcUserInfo, + arg.Sub, + arg.Name, + arg.PreferredUsername, + arg.Email, + arg.Groups, + arg.UpdatedAt, + ) + var i OidcUserinfo + err := row.Scan( + &i.Sub, + &i.Name, + &i.PreferredUsername, + &i.Email, + &i.Groups, + &i.UpdatedAt, + ) + return i, err +} + +const deleteOidcCode = `-- name: DeleteOidcCode :exec +DELETE FROM "oidc_codes" +WHERE "code" = ? +` + +func (q *Queries) DeleteOidcCode(ctx context.Context, code string) error { + _, err := q.db.ExecContext(ctx, deleteOidcCode, code) + return err +} + +const deleteOidcToken = `-- name: DeleteOidcToken :exec +DELETE FROM "oidc_tokens" +WHERE "access_token" = ? +` + +func (q *Queries) DeleteOidcToken(ctx context.Context, accessToken string) error { + _, err := q.db.ExecContext(ctx, deleteOidcToken, accessToken) + return err +} + +const deleteOidcUserInfo = `-- name: DeleteOidcUserInfo :exec +DELETE FROM "oidc_userinfo" +WHERE "sub" = ? +` + +func (q *Queries) DeleteOidcUserInfo(ctx context.Context, sub string) error { + _, err := q.db.ExecContext(ctx, deleteOidcUserInfo, sub) + return err +} + +const getOidcCode = `-- name: GetOidcCode :one +SELECT sub, code, scope, redirect_uri, client_id, expires_at FROM "oidc_codes" +WHERE "code" = ? +` + +func (q *Queries) GetOidcCode(ctx context.Context, code string) (OidcCode, error) { + row := q.db.QueryRowContext(ctx, getOidcCode, code) + var i OidcCode + err := row.Scan( + &i.Sub, + &i.Code, + &i.Scope, + &i.RedirectURI, + &i.ClientID, + &i.ExpiresAt, + ) + return i, err +} + +const getOidcToken = `-- name: GetOidcToken :one +SELECT sub, access_token, scope, client_id, expires_at FROM "oidc_tokens" +WHERE "access_token" = ? +` + +func (q *Queries) GetOidcToken(ctx context.Context, accessToken string) (OidcToken, error) { + row := q.db.QueryRowContext(ctx, getOidcToken, accessToken) + var i OidcToken + err := row.Scan( + &i.Sub, + &i.AccessToken, + &i.Scope, + &i.ClientID, + &i.ExpiresAt, + ) + return i, err +} + +const getOidcUserInfo = `-- name: GetOidcUserInfo :one +SELECT sub, name, preferred_username, email, "groups", updated_at FROM "oidc_userinfo" +WHERE "sub" = ? +` + +func (q *Queries) GetOidcUserInfo(ctx context.Context, sub string) (OidcUserinfo, error) { + row := q.db.QueryRowContext(ctx, getOidcUserInfo, sub) + var i OidcUserinfo + err := row.Scan( + &i.Sub, + &i.Name, + &i.PreferredUsername, + &i.Email, + &i.Groups, + &i.UpdatedAt, + ) + return i, err +} diff --git a/internal/repository/queries.sql.go b/internal/repository/session_queries.sql.go similarity index 98% rename from internal/repository/queries.sql.go rename to internal/repository/session_queries.sql.go index e171b7a..c846c3f 100644 --- a/internal/repository/queries.sql.go +++ b/internal/repository/session_queries.sql.go @@ -1,7 +1,7 @@ // Code generated by sqlc. DO NOT EDIT. // versions: // sqlc v1.30.0 -// source: queries.sql +// source: session_queries.sql package repository @@ -10,7 +10,7 @@ import ( ) const createSession = `-- name: CreateSession :one -INSERT INTO sessions ( +INSERT INTO "sessions" ( "uuid", "username", "email", diff --git a/internal/utils/security_utils.go b/internal/utils/security_utils.go index 40fe713..0cc539d 100644 --- a/internal/utils/security_utils.go +++ b/internal/utils/security_utils.go @@ -1,8 +1,11 @@ package utils import ( + "crypto/rand" "encoding/base64" "errors" + "math" + "math/big" "net" "regexp" "strings" @@ -105,3 +108,28 @@ func GenerateUUID(str string) string { uuid := uuid.NewSHA1(uuid.NameSpaceURL, []byte(str)) return uuid.String() } + +// These could definitely be improved A LOT but at least they are cryptographically secure +func GetRandomString(length int) (string, error) { + if length < 1 { + return "", errors.New("length must be greater than 0") + } + b := make([]byte, length) + _, err := rand.Read(b) + if err != nil { + return "", err + } + state := base64.RawURLEncoding.EncodeToString(b) + return state[:length], nil +} + +func GetRandomInt(length int) (int64, error) { + if length < 1 { + return 0, errors.New("length must be greater than 0") + } + a, err := rand.Int(rand.Reader, big.NewInt(int64(math.Pow(10, float64(length))))) + if err != nil { + return 0, err + } + return a.Int64(), nil +} diff --git a/internal/utils/security_utils_test.go b/internal/utils/security_utils_test.go index 3ebd681..6e74c99 100644 --- a/internal/utils/security_utils_test.go +++ b/internal/utils/security_utils_test.go @@ -2,6 +2,7 @@ package utils_test import ( "os" + "strconv" "testing" "github.com/steveiliop56/tinyauth/internal/utils" @@ -147,3 +148,25 @@ func TestGenerateUUID(t *testing.T) { id3 := utils.GenerateUUID("differentstring") assert.Assert(t, id1 != id3) } + +func TestGetRandomString(t *testing.T) { + // Test with normal length + state, err := utils.GetRandomString(16) + assert.NilError(t, err) + assert.Equal(t, 16, len(state)) + + // Test with zero length + state, err = utils.GetRandomString(0) + assert.Error(t, err, "length must be greater than 0") +} + +func TestGetRandomInt(t *testing.T) { + // Test with normal length + state, err := utils.GetRandomInt(16) + assert.NilError(t, err) + assert.Equal(t, 16, len(strconv.Itoa(int(state)))) + + // Test with zero length + state, err = utils.GetRandomInt(0) + assert.Error(t, err, "length must be greater than 0") +} diff --git a/sql/oidc_queries.sql b/sql/oidc_queries.sql new file mode 100644 index 0000000..c99c788 --- /dev/null +++ b/sql/oidc_queries.sql @@ -0,0 +1,61 @@ +-- name: CreateOidcCode :one +INSERT INTO "oidc_codes" ( + "sub", + "code", + "scope", + "redirect_uri", + "client_id", + "expires_at" +) VALUES ( + ?, ?, ?, ?, ?, ? +) +RETURNING *; + +-- name: DeleteOidcCode :exec +DELETE FROM "oidc_codes" +WHERE "code" = ?; + +-- name: GetOidcCode :one +SELECT * FROM "oidc_codes" +WHERE "code" = ?; + +-- name: CreateOidcToken :one +INSERT INTO "oidc_tokens" ( + "sub", + "access_token", + "scope", + "client_id", + "expires_at" +) VALUES ( + ?, ?, ?, ?, ? +) +RETURNING *; + +-- name: DeleteOidcToken :exec +DELETE FROM "oidc_tokens" +WHERE "access_token" = ?; + +-- name: GetOidcToken :one +SELECT * FROM "oidc_tokens" +WHERE "access_token" = ?; + +-- name: CreateOidcUserInfo :one +INSERT INTO "oidc_userinfo" ( + "sub", + "name", + "preferred_username", + "email", + "groups", + "updated_at" +) VALUES ( + ?, ?, ?, ?, ?, ? +) +RETURNING *; + +-- name: DeleteOidcUserInfo :exec +DELETE FROM "oidc_userinfo" +WHERE "sub" = ?; + +-- name: GetOidcUserInfo :one +SELECT * FROM "oidc_userinfo" +WHERE "sub" = ?; diff --git a/sql/oidc_schemas.sql b/sql/oidc_schemas.sql new file mode 100644 index 0000000..01fa8a3 --- /dev/null +++ b/sql/oidc_schemas.sql @@ -0,0 +1,25 @@ +CREATE TABLE IF NOT EXISTS "oidc_codes" ( + "sub" TEXT NOT NULL UNIQUE, + "code" TEXT NOT NULL PRIMARY KEY UNIQUE, + "scope" TEXT NOT NULL, + "redirect_uri" TEXT NOT NULL, + "client_id" TEXT NOT NULL, + "expires_at" INTEGER NOT NULL +); + +CREATE TABLE IF NOT EXISTS "oidc_tokens" ( + "sub" TEXT NOT NULL UNIQUE, + "access_token" TEXT NOT NULL PRIMARY KEY UNIQUE, + "scope" TEXT NOT NULL, + "client_id" TEXT NOT NULL, + "expires_at" INTEGER NOT NULL +); + +CREATE TABLE IF NOT EXISTS "oidc_userinfo" ( + "sub" TEXT NOT NULL UNIQUE PRIMARY KEY, + "name" TEXT NOT NULL, + "preferred_username" TEXT NOT NULL, + "email" TEXT NOT NULL, + "groups" TEXT NOT NULL, + "updated_at" INTEGER NOT NULL +); diff --git a/sql/queries.sql b/sql/session_queries.sql similarity index 96% rename from sql/queries.sql rename to sql/session_queries.sql index 9fde4e2..da93126 100644 --- a/sql/queries.sql +++ b/sql/session_queries.sql @@ -1,5 +1,5 @@ -- name: CreateSession :one -INSERT INTO sessions ( +INSERT INTO "sessions" ( "uuid", "username", "email", diff --git a/sql/schema.sql b/sql/session_schemas.sql similarity index 100% rename from sql/schema.sql rename to sql/session_schemas.sql diff --git a/sqlc.yml b/sqlc.yml index b9cf1ea..2c0f170 100644 --- a/sqlc.yml +++ b/sqlc.yml @@ -1,8 +1,8 @@ version: "2" sql: - engine: "sqlite" - queries: "sql/queries.sql" - schema: "sql/schema.sql" + queries: "sql/*_queries.sql" + schema: "sql/*_schemas.sql" gen: go: package: "repository" @@ -12,6 +12,7 @@ sql: oauth_groups: "OAuthGroups" oauth_name: "OAuthName" oauth_sub: "OAuthSub" + redirect_uri: "RedirectURI" overrides: - column: "sessions.oauth_groups" go_type: "string"