diff --git a/frontend/src/pages/authorize-page.tsx b/frontend/src/pages/authorize-page.tsx
index 6befa96..7ada730 100644
--- a/frontend/src/pages/authorize-page.tsx
+++ b/frontend/src/pages/authorize-page.tsx
@@ -1,6 +1,6 @@
import { useUserContext } from "@/context/user-context";
-import { useQuery } from "@tanstack/react-query";
-import { Navigate } from "react-router";
+import { useMutation, useQuery } from "@tanstack/react-query";
+import { Navigate, useNavigate } from "react-router";
import { useLocation } from "react-router";
import {
Card,
@@ -11,6 +11,8 @@ import {
} from "@/components/ui/card";
import { getOidcClientInfoScehma } from "@/schemas/oidc-schemas";
import { Button } from "@/components/ui/button";
+import axios from "axios";
+import { toast } from "sonner";
type AuthorizePageProps = {
scope: string;
@@ -25,6 +27,7 @@ const optionalAuthorizeProps = ["state"];
export const AuthorizePage = () => {
const { isLoggedIn } = useUserContext();
const { search } = useLocation();
+ const navigate = useNavigate();
const searchParams = new URLSearchParams(search);
@@ -46,12 +49,38 @@ export const AuthorizePage = () => {
},
});
+ const authorizeMutation = useMutation({
+ mutationFn: () => {
+ return axios.post("/api/oidc/authorize", {
+ scope: props.scope,
+ response_type: props.responseType,
+ client_id: props.clientId,
+ redirect_uri: props.redirectUri,
+ state: props.state,
+ });
+ },
+ mutationKey: ["authorize", props.clientId],
+ onSuccess: (data) => {
+ toast.info("Authorized", {
+ description: "You will be soon redirected to your application",
+ });
+ window.location.replace(
+ `${data.data.redirect_uri}?code=${encodeURIComponent(data.data.code)}&state=${encodeURIComponent(data.data.state)}`,
+ );
+ },
+ onError: (error) => {
+ window.location.replace(
+ `/error?error=${encodeURIComponent(error.message)}`,
+ );
+ },
+ });
+
if (!isLoggedIn) {
// TODO: Pass the params to the login page, so user can login -> authorize
return ;
}
- for (const key in Object.keys(props)) {
+ Object.keys(props).forEach((key) => {
if (
!props[key as keyof AuthorizePageProps] &&
!optionalAuthorizeProps.includes(key)
@@ -59,7 +88,7 @@ export const AuthorizePage = () => {
// TODO: Add reason for error
return ;
}
- }
+ });
if (getClientInfo.isLoading) {
return (
@@ -91,8 +120,19 @@ export const AuthorizePage = () => {
-
-
+
+
);
diff --git a/internal/assets/migrations/000005_oidc_session.down.sql b/internal/assets/migrations/000005_oidc_session.down.sql
new file mode 100644
index 0000000..68a3248
--- /dev/null
+++ b/internal/assets/migrations/000005_oidc_session.down.sql
@@ -0,0 +1,3 @@
+DROP TABLE IF EXISTS "oidc_tokens";
+DROP TABLE IF EXISTS "oidc_userinfo";
+DROP TABLE IF EXISTS "oidc_codes";
diff --git a/internal/assets/migrations/000005_oidc_session.up.sql b/internal/assets/migrations/000005_oidc_session.up.sql
new file mode 100644
index 0000000..01fa8a3
--- /dev/null
+++ b/internal/assets/migrations/000005_oidc_session.up.sql
@@ -0,0 +1,25 @@
+CREATE TABLE IF NOT EXISTS "oidc_codes" (
+ "sub" TEXT NOT NULL UNIQUE,
+ "code" TEXT NOT NULL PRIMARY KEY UNIQUE,
+ "scope" TEXT NOT NULL,
+ "redirect_uri" TEXT NOT NULL,
+ "client_id" TEXT NOT NULL,
+ "expires_at" INTEGER NOT NULL
+);
+
+CREATE TABLE IF NOT EXISTS "oidc_tokens" (
+ "sub" TEXT NOT NULL UNIQUE,
+ "access_token" TEXT NOT NULL PRIMARY KEY UNIQUE,
+ "scope" TEXT NOT NULL,
+ "client_id" TEXT NOT NULL,
+ "expires_at" INTEGER NOT NULL
+);
+
+CREATE TABLE IF NOT EXISTS "oidc_userinfo" (
+ "sub" TEXT NOT NULL UNIQUE PRIMARY KEY,
+ "name" TEXT NOT NULL,
+ "preferred_username" TEXT NOT NULL,
+ "email" TEXT NOT NULL,
+ "groups" TEXT NOT NULL,
+ "updated_at" INTEGER NOT NULL
+);
diff --git a/internal/bootstrap/app_bootstrap.go b/internal/bootstrap/app_bootstrap.go
index e9cdd5a..31473c9 100644
--- a/internal/bootstrap/app_bootstrap.go
+++ b/internal/bootstrap/app_bootstrap.go
@@ -176,7 +176,7 @@ func (app *BootstrapApp) Setup() error {
app.context.configuredProviders = configuredProviders
// Setup router
- router, err := app.setupRouter()
+ router, err := app.setupRouter(queries)
if err != nil {
return fmt.Errorf("failed to setup routes: %w", err)
diff --git a/internal/bootstrap/router_bootstrap.go b/internal/bootstrap/router_bootstrap.go
index c854c45..f6747c8 100644
--- a/internal/bootstrap/router_bootstrap.go
+++ b/internal/bootstrap/router_bootstrap.go
@@ -7,13 +7,14 @@ import (
"github.com/steveiliop56/tinyauth/internal/config"
"github.com/steveiliop56/tinyauth/internal/controller"
"github.com/steveiliop56/tinyauth/internal/middleware"
+ "github.com/steveiliop56/tinyauth/internal/repository"
"github.com/gin-gonic/gin"
)
var DEV_MODES = []string{"main", "test", "development"}
-func (app *BootstrapApp) setupRouter() (*gin.Engine, error) {
+func (app *BootstrapApp) setupRouter(queries *repository.Queries) (*gin.Engine, error) {
if !slices.Contains(DEV_MODES, config.Version) {
gin.SetMode(gin.ReleaseMode)
}
@@ -88,7 +89,8 @@ func (app *BootstrapApp) setupRouter() (*gin.Engine, error) {
oidcController := controller.NewOIDCController(controller.OIDCControllerConfig{
Clients: app.context.oidcClients,
- }, apiRouter)
+ AppURL: app.config.AppURL,
+ }, apiRouter, queries)
oidcController.SetupRoutes()
diff --git a/internal/controller/oidc_controller.go b/internal/controller/oidc_controller.go
index 8fbf2ce..26b6966 100644
--- a/internal/controller/oidc_controller.go
+++ b/internal/controller/oidc_controller.go
@@ -1,30 +1,71 @@
package controller
import (
+ "fmt"
+ "slices"
+ "strconv"
+ "strings"
+ "time"
+
"github.com/gin-gonic/gin"
+ "github.com/google/go-querystring/query"
"github.com/steveiliop56/tinyauth/internal/config"
+ "github.com/steveiliop56/tinyauth/internal/repository"
+ "github.com/steveiliop56/tinyauth/internal/utils"
"github.com/steveiliop56/tinyauth/internal/utils/tlog"
)
+var (
+ SupportedResponseTypes = []string{"code"}
+ SupportedScopes = []string{"openid", "profile", "email", "groups"}
+ SupportedGrantTypes = []string{"authorization_code"}
+)
+
type OIDCControllerConfig struct {
Clients []config.OIDCClientConfig
+ AppURL string
}
type OIDCController struct {
- clients []config.OIDCClientConfig
+ config OIDCControllerConfig
router *gin.RouterGroup
+ queries *repository.Queries
}
-func NewOIDCController(config OIDCControllerConfig, router *gin.RouterGroup) *OIDCController {
+type AuthorizeRequest struct {
+ Scope string `json:"scope" binding:"required"`
+ ResponseType string `json:"response_type" binding:"required"`
+ ClientID string `json:"client_id" binding:"required"`
+ RedirectURI string `json:"redirect_uri" binding:"required"`
+ State string `json:"state" binding:"required"`
+}
+
+type TokenRequest struct {
+ GrantType string `form:"grant_type" binding:"required"`
+ Code string `form:"code" binding:"required"`
+ RedirectURI string `form:"redirect_uri" binding:"required"`
+}
+
+type CallbackError struct {
+ Error string `url:"error"`
+ ErrorDescription string `url:"error_description"`
+ State string `url:"state"`
+}
+
+func NewOIDCController(config OIDCControllerConfig, router *gin.RouterGroup, queries *repository.Queries) *OIDCController {
return &OIDCController{
- clients: config.Clients,
+ config: config,
router: router,
+ queries: queries,
}
}
func (controller *OIDCController) SetupRoutes() {
oidcGroup := controller.router.Group("/oidc")
oidcGroup.GET("/clients/:id", controller.GetClientInfo)
+ oidcGroup.POST("/authorize", controller.Authorize)
+ oidcGroup.POST("/token", controller.Token)
+ oidcGroup.GET("/userinfo", controller.Userinfo)
}
type ClientRequest struct {
@@ -47,7 +88,7 @@ func (controller *OIDCController) GetClientInfo(c *gin.Context) {
var client *config.OIDCClientConfig
// Inefficient yeah, but it will be good until we have thousands of clients
- for _, clientCfg := range controller.clients {
+ for _, clientCfg := range controller.config.Clients {
if clientCfg.ClientID == req.ClientID {
client = &clientCfg
break
@@ -69,3 +110,392 @@ func (controller *OIDCController) GetClientInfo(c *gin.Context) {
"name": &client.Name,
})
}
+
+func (controller *OIDCController) Authorize(c *gin.Context) {
+ // Check if we are logged in
+ userContext, err := utils.GetContext(c)
+
+ if err != nil {
+ tlog.App.Error().Err(err).Msg("Failed to get user context")
+ c.JSON(401, gin.H{
+ "status": 401,
+ "message": "Unauthorized",
+ })
+ return
+ }
+
+ // OIDC stuff
+ var req AuthorizeRequest
+
+ err = c.BindJSON(&req)
+ if err != nil {
+ tlog.App.Error().Err(err).Msg("Failed to bind JSON")
+ c.JSON(400, gin.H{
+ "status": 400,
+ "message": "Bad Request",
+ })
+ return
+ }
+
+ // TODO: All these errors should redirect to the error page with an explanation
+
+ // Validate client ID
+ var client *config.OIDCClientConfig
+
+ for _, clientCfg := range controller.config.Clients {
+ if clientCfg.ClientID == req.ClientID {
+ client = &clientCfg
+ break
+ }
+ }
+
+ if client == nil {
+ tlog.App.Warn().Str("client_id", req.ClientID).Msg("Client not found")
+ c.JSON(404, gin.H{
+ "status": 404,
+ "message": "Client not found",
+ })
+ return
+ }
+
+ // Validate redirect URI
+ if !slices.Contains(client.TrustedRedirectURLs, req.RedirectURI) {
+ tlog.App.Warn().Str("redirect_uri", req.RedirectURI).Msg("Redirect URI not trusted")
+ c.JSON(400, gin.H{
+ "status": 400,
+ "message": "Bad Request",
+ })
+ return
+ }
+
+ // Validate scopes
+ reqScopes := strings.Split(req.Scope, " ")
+ keptScopes := make([]string, 0)
+
+ if len(reqScopes) == 0 || strings.TrimSpace(req.Scope) == "" {
+ queries, err := query.Values(CallbackError{
+ Error: "invalid_request",
+ ErrorDescription: "Missing scope parameter",
+ State: req.State,
+ })
+
+ if err != nil {
+ tlog.App.Error().Err(err).Msg("Failed to build query")
+ c.Redirect(302, fmt.Sprintf("%s/error", controller.config.AppURL))
+ return
+ }
+
+ c.Redirect(302, fmt.Sprintf("%s/callback?%s", req.RedirectURI, queries.Encode()))
+ return
+ }
+
+ for _, scope := range reqScopes {
+ if slices.Contains(SupportedScopes, scope) {
+ keptScopes = append(keptScopes, scope)
+ continue
+ }
+ tlog.App.Warn().Str("scope", scope).Msg("Scope not supported, ignoring")
+ }
+
+ // Generate a code and a sub
+ code, err := utils.GetRandomString(32)
+
+ if err != nil {
+ tlog.App.Error().Err(err).Msg("Failed to generate random string")
+ c.Redirect(302, fmt.Sprintf("%s/error", controller.config.AppURL))
+ return
+ }
+
+ sub, err := utils.GetRandomInt(10)
+
+ if err != nil {
+ tlog.App.Error().Err(err).Msg("Failed to generate random integer")
+ c.Redirect(302, fmt.Sprintf("%s/error", controller.config.AppURL))
+ return
+ }
+
+ expiresAt := time.Now().Add(time.Minute * time.Duration(10)).Unix()
+
+ // Insert the code into the database
+ _, err = controller.queries.CreateOidcCode(c, repository.CreateOidcCodeParams{
+ Code: code,
+ Sub: strconv.Itoa(int(sub)),
+ Scope: strings.Join(keptScopes, ","),
+ RedirectURI: req.RedirectURI,
+ ClientID: client.ClientID,
+ ExpiresAt: expiresAt,
+ })
+
+ if err != nil {
+ tlog.App.Error().Err(err).Msg("Failed to insert code into database")
+ c.Redirect(302, fmt.Sprintf("%s/error", controller.config.AppURL))
+ return
+ }
+
+ // We also need a snapshot of the user that authorized this
+ userInfoParams := repository.CreateOidcUserInfoParams{
+ Sub: strconv.Itoa(int(sub)),
+ Name: userContext.Name,
+ Email: userContext.Email,
+ PreferredUsername: userContext.Username,
+ UpdatedAt: time.Now().Unix(),
+ }
+
+ if userContext.Provider == "ldap" {
+ userInfoParams.Groups = userContext.LdapGroups
+ }
+
+ if userContext.OAuth && len(userContext.OAuthGroups) > 0 {
+ userInfoParams.Groups = userContext.OAuthGroups
+ }
+
+ _, err = controller.queries.CreateOidcUserInfo(c, userInfoParams)
+
+ if err != nil {
+ tlog.App.Error().Err(err).Msg("Failed to insert user info into database")
+ c.Redirect(302, fmt.Sprintf("%s/error", controller.config.AppURL))
+ return
+ }
+
+ // Return code and done
+ c.JSON(200, gin.H{
+ "status": 200,
+ "message": "Authorized",
+ "code": code,
+ "state": req.State,
+ "redirect_uri": req.RedirectURI,
+ })
+}
+
+func (controller *OIDCController) Token(c *gin.Context) {
+ // Get basic auth
+ clientId, clientSecret, ok := c.Request.BasicAuth()
+
+ if !ok {
+ tlog.App.Error().Msg("Missing token verifier")
+ c.JSON(400, gin.H{
+ "error": "invalid_request",
+ })
+ return
+ }
+
+ // Ensure client exists
+ var client *config.OIDCClientConfig
+
+ for _, clientCfg := range controller.config.Clients {
+ if clientCfg.ClientID == clientId {
+ client = &clientCfg
+ break
+ }
+ }
+
+ if client == nil {
+ tlog.App.Warn().Str("client_id", clientId).Msg("Client not found")
+ c.JSON(400, gin.H{
+ "error": "invalid_request",
+ })
+ return
+ }
+
+ if client.ClientSecret != clientSecret {
+ tlog.App.Warn().Str("client_id", clientId).Msg("Invalid client secret")
+ c.JSON(400, gin.H{
+ "error": "invalid_client",
+ })
+ return
+ }
+
+ // Get token
+ var req TokenRequest
+
+ err := c.Bind(&req)
+ if err != nil {
+ tlog.App.Error().Err(err).Msg("Failed to bind token request")
+ c.JSON(400, gin.H{
+ "error": "invalid_request",
+ })
+ return
+ }
+
+ // Validate grant type
+ if !slices.Contains(SupportedGrantTypes, req.GrantType) {
+ tlog.App.Warn().Str("grant_type", req.GrantType).Msg("Unsupported grant type")
+ c.JSON(400, gin.H{
+ "error": "unsupported_grant_type",
+ })
+ return
+ }
+
+ // Find pending code entry
+ entry, err := controller.queries.GetOidcCode(c, req.Code)
+
+ if err != nil {
+ tlog.App.Error().Err(err).Msg("Failed to find code in database")
+ c.JSON(400, gin.H{
+ "error": "invalid_request",
+ })
+ return
+ }
+
+ // Ensure redirect URIs match
+ if entry.RedirectURI != req.RedirectURI {
+ tlog.App.Warn().Str("redirect_uri", req.RedirectURI).Msg("Redirect URI mismatch")
+ c.JSON(400, gin.H{
+ "error": "invalid_request",
+ })
+ return
+ }
+
+ // Generate access token
+ genToken, err := utils.GetRandomString(29)
+
+ if err != nil {
+ tlog.App.Error().Err(err).Msg("Failed to generate access token")
+ c.JSON(400, gin.H{
+ "error": "invalid_request",
+ })
+ return
+ }
+
+ // Add tinyauth prefix
+ token := fmt.Sprintf("ta-%s", genToken)
+
+ // TODO: either add a refresh token or customize token expiry
+ expiresAt := time.Now().Add(time.Duration(3600) * time.Second).Unix()
+
+ // Create token entry
+ _, err = controller.queries.CreateOidcToken(c, repository.CreateOidcTokenParams{
+ Sub: entry.Sub,
+ AccessToken: token,
+ Scope: entry.Scope,
+ ClientID: client.ClientID,
+ ExpiresAt: expiresAt,
+ })
+
+ if err != nil {
+ tlog.App.Error().Err(err).Msg("Failed to create token in database")
+ c.JSON(400, gin.H{
+ "error": "invalid_request",
+ })
+ return
+ }
+
+ // Delete code entry
+ err = controller.queries.DeleteOidcCode(c, entry.Code)
+
+ if err != nil {
+ tlog.App.Error().Err(err).Msg("Failed to delete code in database")
+ c.JSON(400, gin.H{
+ "error": "invalid_request",
+ })
+ return
+ }
+
+ // Respond with token
+ c.JSON(200, gin.H{
+ "access_token": token,
+ "token_type": "bearer",
+ "expires_in": 3600,
+ })
+}
+
+func (controller *OIDCController) Userinfo(c *gin.Context) {
+ // Get bearer
+ authorizationHeader := c.GetHeader("Authorization")
+
+ tokenType, token, ok := strings.Cut(authorizationHeader, " ")
+
+ if !ok {
+ tlog.App.Warn().Msg("OIDC userinfo accessed without authorization header")
+ c.JSON(401, gin.H{
+ "error": "invalid_request",
+ })
+ return
+ }
+
+ if strings.ToLower(tokenType) != "bearer" {
+ tlog.App.Warn().Msg("OIDC userinfo accessed with invalid token type")
+ c.JSON(401, gin.H{
+ "error": "invalid_request",
+ })
+ return
+ }
+
+ // Get token entry
+ entry, err := controller.queries.GetOidcToken(c, token)
+
+ if err != nil {
+ tlog.App.Err(err).Msg("Failed to get token entry")
+ c.JSON(401, gin.H{
+ "error": "invalid_request",
+ })
+ return
+ }
+
+ // Get scopes
+ scopes := strings.Split(entry.Scope, ",")
+
+ // Check if token is expired
+ if time.Now().Unix() > entry.ExpiresAt {
+ tlog.App.Warn().Msg("OIDC userinfo accessed with expired token")
+
+ err = controller.queries.DeleteOidcToken(c, entry.AccessToken)
+ if err != nil {
+ tlog.App.Err(err).Msg("Failed to delete expired token")
+ }
+
+ err = controller.queries.DeleteOidcUserInfo(c, entry.Sub)
+ if err != nil {
+ tlog.App.Err(err).Msg("Failed to delete oidc user info")
+ }
+
+ c.JSON(401, gin.H{
+ "error": "invalid_request",
+ })
+ return
+ }
+
+ // Get user info
+ user, err := controller.queries.GetOidcUserInfo(c, entry.Sub)
+
+ if err != nil {
+ tlog.App.Err(err).Msg("Failed to get user entry")
+ c.JSON(401, gin.H{
+ "error": "invalid_request",
+ })
+ return
+ }
+
+ // If we don't have the openid scope, return an error
+ if !slices.Contains(scopes, "openid") {
+ tlog.App.Warn().Msg("OIDC userinfo accessed without openid scope")
+ c.JSON(401, gin.H{
+ "error": "invalid_request",
+ })
+ return
+ }
+
+ // Let's build the response
+ res := map[string]any{
+ "sub": user.Sub,
+ "updated_at": user.UpdatedAt,
+ }
+
+ // If we have the profile scope, add the profile stuff
+ if slices.Contains(scopes, "profile") {
+ res["name"] = user.Name
+ res["preferred_username"] = user.PreferredUsername
+ }
+
+ // If we have the email scope, add the email stuff
+ if slices.Contains(scopes, "email") {
+ res["email"] = user.Email
+ }
+
+ // If we have the groups scope, add the groups stuff
+ if slices.Contains(scopes, "groups") {
+ res["groups"] = user.Groups
+ }
+
+ c.JSON(200, res)
+}
diff --git a/internal/middleware/context_middleware.go b/internal/middleware/context_middleware.go
index 4d392c8..fc71c05 100644
--- a/internal/middleware/context_middleware.go
+++ b/internal/middleware/context_middleware.go
@@ -2,6 +2,7 @@ package middleware
import (
"fmt"
+ "slices"
"strings"
"time"
@@ -13,6 +14,8 @@ import (
"github.com/gin-gonic/gin"
)
+var OIDCIgnorePaths = []string{"/api/oidc/token", "/api/oidc/userinfo"}
+
type ContextMiddlewareConfig struct {
CookieDomain string
}
@@ -37,6 +40,13 @@ func (m *ContextMiddleware) Init() error {
func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
return func(c *gin.Context) {
+ // There is no point in trying to get credentials if it's an OIDC endpoint
+ path := c.Request.URL.Path
+ if slices.Contains(OIDCIgnorePaths, path) {
+ c.Next()
+ return
+ }
+
cookie, err := m.auth.GetSessionCookie(c)
if err != nil {
diff --git a/internal/repository/models.go b/internal/repository/models.go
index 61f7f80..3380645 100644
--- a/internal/repository/models.go
+++ b/internal/repository/models.go
@@ -4,6 +4,32 @@
package repository
+type OidcCode struct {
+ Sub string
+ Code string
+ Scope string
+ RedirectURI string
+ ClientID string
+ ExpiresAt int64
+}
+
+type OidcToken struct {
+ Sub string
+ AccessToken string
+ Scope string
+ ClientID string
+ ExpiresAt int64
+}
+
+type OidcUserinfo struct {
+ Sub string
+ Name string
+ PreferredUsername string
+ Email string
+ Groups string
+ UpdatedAt int64
+}
+
type Session struct {
UUID string
Username string
diff --git a/internal/repository/oidc_queries.sql.go b/internal/repository/oidc_queries.sql.go
new file mode 100644
index 0000000..510981f
--- /dev/null
+++ b/internal/repository/oidc_queries.sql.go
@@ -0,0 +1,224 @@
+// Code generated by sqlc. DO NOT EDIT.
+// versions:
+// sqlc v1.30.0
+// source: oidc_queries.sql
+
+package repository
+
+import (
+ "context"
+)
+
+const createOidcCode = `-- name: CreateOidcCode :one
+INSERT INTO "oidc_codes" (
+ "sub",
+ "code",
+ "scope",
+ "redirect_uri",
+ "client_id",
+ "expires_at"
+) VALUES (
+ ?, ?, ?, ?, ?, ?
+)
+RETURNING sub, code, scope, redirect_uri, client_id, expires_at
+`
+
+type CreateOidcCodeParams struct {
+ Sub string
+ Code string
+ Scope string
+ RedirectURI string
+ ClientID string
+ ExpiresAt int64
+}
+
+func (q *Queries) CreateOidcCode(ctx context.Context, arg CreateOidcCodeParams) (OidcCode, error) {
+ row := q.db.QueryRowContext(ctx, createOidcCode,
+ arg.Sub,
+ arg.Code,
+ arg.Scope,
+ arg.RedirectURI,
+ arg.ClientID,
+ arg.ExpiresAt,
+ )
+ var i OidcCode
+ err := row.Scan(
+ &i.Sub,
+ &i.Code,
+ &i.Scope,
+ &i.RedirectURI,
+ &i.ClientID,
+ &i.ExpiresAt,
+ )
+ return i, err
+}
+
+const createOidcToken = `-- name: CreateOidcToken :one
+INSERT INTO "oidc_tokens" (
+ "sub",
+ "access_token",
+ "scope",
+ "client_id",
+ "expires_at"
+) VALUES (
+ ?, ?, ?, ?, ?
+)
+RETURNING sub, access_token, scope, client_id, expires_at
+`
+
+type CreateOidcTokenParams struct {
+ Sub string
+ AccessToken string
+ Scope string
+ ClientID string
+ ExpiresAt int64
+}
+
+func (q *Queries) CreateOidcToken(ctx context.Context, arg CreateOidcTokenParams) (OidcToken, error) {
+ row := q.db.QueryRowContext(ctx, createOidcToken,
+ arg.Sub,
+ arg.AccessToken,
+ arg.Scope,
+ arg.ClientID,
+ arg.ExpiresAt,
+ )
+ var i OidcToken
+ err := row.Scan(
+ &i.Sub,
+ &i.AccessToken,
+ &i.Scope,
+ &i.ClientID,
+ &i.ExpiresAt,
+ )
+ return i, err
+}
+
+const createOidcUserInfo = `-- name: CreateOidcUserInfo :one
+INSERT INTO "oidc_userinfo" (
+ "sub",
+ "name",
+ "preferred_username",
+ "email",
+ "groups",
+ "updated_at"
+) VALUES (
+ ?, ?, ?, ?, ?, ?
+)
+RETURNING sub, name, preferred_username, email, "groups", updated_at
+`
+
+type CreateOidcUserInfoParams struct {
+ Sub string
+ Name string
+ PreferredUsername string
+ Email string
+ Groups string
+ UpdatedAt int64
+}
+
+func (q *Queries) CreateOidcUserInfo(ctx context.Context, arg CreateOidcUserInfoParams) (OidcUserinfo, error) {
+ row := q.db.QueryRowContext(ctx, createOidcUserInfo,
+ arg.Sub,
+ arg.Name,
+ arg.PreferredUsername,
+ arg.Email,
+ arg.Groups,
+ arg.UpdatedAt,
+ )
+ var i OidcUserinfo
+ err := row.Scan(
+ &i.Sub,
+ &i.Name,
+ &i.PreferredUsername,
+ &i.Email,
+ &i.Groups,
+ &i.UpdatedAt,
+ )
+ return i, err
+}
+
+const deleteOidcCode = `-- name: DeleteOidcCode :exec
+DELETE FROM "oidc_codes"
+WHERE "code" = ?
+`
+
+func (q *Queries) DeleteOidcCode(ctx context.Context, code string) error {
+ _, err := q.db.ExecContext(ctx, deleteOidcCode, code)
+ return err
+}
+
+const deleteOidcToken = `-- name: DeleteOidcToken :exec
+DELETE FROM "oidc_tokens"
+WHERE "access_token" = ?
+`
+
+func (q *Queries) DeleteOidcToken(ctx context.Context, accessToken string) error {
+ _, err := q.db.ExecContext(ctx, deleteOidcToken, accessToken)
+ return err
+}
+
+const deleteOidcUserInfo = `-- name: DeleteOidcUserInfo :exec
+DELETE FROM "oidc_userinfo"
+WHERE "sub" = ?
+`
+
+func (q *Queries) DeleteOidcUserInfo(ctx context.Context, sub string) error {
+ _, err := q.db.ExecContext(ctx, deleteOidcUserInfo, sub)
+ return err
+}
+
+const getOidcCode = `-- name: GetOidcCode :one
+SELECT sub, code, scope, redirect_uri, client_id, expires_at FROM "oidc_codes"
+WHERE "code" = ?
+`
+
+func (q *Queries) GetOidcCode(ctx context.Context, code string) (OidcCode, error) {
+ row := q.db.QueryRowContext(ctx, getOidcCode, code)
+ var i OidcCode
+ err := row.Scan(
+ &i.Sub,
+ &i.Code,
+ &i.Scope,
+ &i.RedirectURI,
+ &i.ClientID,
+ &i.ExpiresAt,
+ )
+ return i, err
+}
+
+const getOidcToken = `-- name: GetOidcToken :one
+SELECT sub, access_token, scope, client_id, expires_at FROM "oidc_tokens"
+WHERE "access_token" = ?
+`
+
+func (q *Queries) GetOidcToken(ctx context.Context, accessToken string) (OidcToken, error) {
+ row := q.db.QueryRowContext(ctx, getOidcToken, accessToken)
+ var i OidcToken
+ err := row.Scan(
+ &i.Sub,
+ &i.AccessToken,
+ &i.Scope,
+ &i.ClientID,
+ &i.ExpiresAt,
+ )
+ return i, err
+}
+
+const getOidcUserInfo = `-- name: GetOidcUserInfo :one
+SELECT sub, name, preferred_username, email, "groups", updated_at FROM "oidc_userinfo"
+WHERE "sub" = ?
+`
+
+func (q *Queries) GetOidcUserInfo(ctx context.Context, sub string) (OidcUserinfo, error) {
+ row := q.db.QueryRowContext(ctx, getOidcUserInfo, sub)
+ var i OidcUserinfo
+ err := row.Scan(
+ &i.Sub,
+ &i.Name,
+ &i.PreferredUsername,
+ &i.Email,
+ &i.Groups,
+ &i.UpdatedAt,
+ )
+ return i, err
+}
diff --git a/internal/repository/queries.sql.go b/internal/repository/session_queries.sql.go
similarity index 98%
rename from internal/repository/queries.sql.go
rename to internal/repository/session_queries.sql.go
index e171b7a..c846c3f 100644
--- a/internal/repository/queries.sql.go
+++ b/internal/repository/session_queries.sql.go
@@ -1,7 +1,7 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.30.0
-// source: queries.sql
+// source: session_queries.sql
package repository
@@ -10,7 +10,7 @@ import (
)
const createSession = `-- name: CreateSession :one
-INSERT INTO sessions (
+INSERT INTO "sessions" (
"uuid",
"username",
"email",
diff --git a/internal/utils/security_utils.go b/internal/utils/security_utils.go
index 40fe713..0cc539d 100644
--- a/internal/utils/security_utils.go
+++ b/internal/utils/security_utils.go
@@ -1,8 +1,11 @@
package utils
import (
+ "crypto/rand"
"encoding/base64"
"errors"
+ "math"
+ "math/big"
"net"
"regexp"
"strings"
@@ -105,3 +108,28 @@ func GenerateUUID(str string) string {
uuid := uuid.NewSHA1(uuid.NameSpaceURL, []byte(str))
return uuid.String()
}
+
+// These could definitely be improved A LOT but at least they are cryptographically secure
+func GetRandomString(length int) (string, error) {
+ if length < 1 {
+ return "", errors.New("length must be greater than 0")
+ }
+ b := make([]byte, length)
+ _, err := rand.Read(b)
+ if err != nil {
+ return "", err
+ }
+ state := base64.RawURLEncoding.EncodeToString(b)
+ return state[:length], nil
+}
+
+func GetRandomInt(length int) (int64, error) {
+ if length < 1 {
+ return 0, errors.New("length must be greater than 0")
+ }
+ a, err := rand.Int(rand.Reader, big.NewInt(int64(math.Pow(10, float64(length)))))
+ if err != nil {
+ return 0, err
+ }
+ return a.Int64(), nil
+}
diff --git a/internal/utils/security_utils_test.go b/internal/utils/security_utils_test.go
index 3ebd681..6e74c99 100644
--- a/internal/utils/security_utils_test.go
+++ b/internal/utils/security_utils_test.go
@@ -2,6 +2,7 @@ package utils_test
import (
"os"
+ "strconv"
"testing"
"github.com/steveiliop56/tinyauth/internal/utils"
@@ -147,3 +148,25 @@ func TestGenerateUUID(t *testing.T) {
id3 := utils.GenerateUUID("differentstring")
assert.Assert(t, id1 != id3)
}
+
+func TestGetRandomString(t *testing.T) {
+ // Test with normal length
+ state, err := utils.GetRandomString(16)
+ assert.NilError(t, err)
+ assert.Equal(t, 16, len(state))
+
+ // Test with zero length
+ state, err = utils.GetRandomString(0)
+ assert.Error(t, err, "length must be greater than 0")
+}
+
+func TestGetRandomInt(t *testing.T) {
+ // Test with normal length
+ state, err := utils.GetRandomInt(16)
+ assert.NilError(t, err)
+ assert.Equal(t, 16, len(strconv.Itoa(int(state))))
+
+ // Test with zero length
+ state, err = utils.GetRandomInt(0)
+ assert.Error(t, err, "length must be greater than 0")
+}
diff --git a/sql/oidc_queries.sql b/sql/oidc_queries.sql
new file mode 100644
index 0000000..c99c788
--- /dev/null
+++ b/sql/oidc_queries.sql
@@ -0,0 +1,61 @@
+-- name: CreateOidcCode :one
+INSERT INTO "oidc_codes" (
+ "sub",
+ "code",
+ "scope",
+ "redirect_uri",
+ "client_id",
+ "expires_at"
+) VALUES (
+ ?, ?, ?, ?, ?, ?
+)
+RETURNING *;
+
+-- name: DeleteOidcCode :exec
+DELETE FROM "oidc_codes"
+WHERE "code" = ?;
+
+-- name: GetOidcCode :one
+SELECT * FROM "oidc_codes"
+WHERE "code" = ?;
+
+-- name: CreateOidcToken :one
+INSERT INTO "oidc_tokens" (
+ "sub",
+ "access_token",
+ "scope",
+ "client_id",
+ "expires_at"
+) VALUES (
+ ?, ?, ?, ?, ?
+)
+RETURNING *;
+
+-- name: DeleteOidcToken :exec
+DELETE FROM "oidc_tokens"
+WHERE "access_token" = ?;
+
+-- name: GetOidcToken :one
+SELECT * FROM "oidc_tokens"
+WHERE "access_token" = ?;
+
+-- name: CreateOidcUserInfo :one
+INSERT INTO "oidc_userinfo" (
+ "sub",
+ "name",
+ "preferred_username",
+ "email",
+ "groups",
+ "updated_at"
+) VALUES (
+ ?, ?, ?, ?, ?, ?
+)
+RETURNING *;
+
+-- name: DeleteOidcUserInfo :exec
+DELETE FROM "oidc_userinfo"
+WHERE "sub" = ?;
+
+-- name: GetOidcUserInfo :one
+SELECT * FROM "oidc_userinfo"
+WHERE "sub" = ?;
diff --git a/sql/oidc_schemas.sql b/sql/oidc_schemas.sql
new file mode 100644
index 0000000..01fa8a3
--- /dev/null
+++ b/sql/oidc_schemas.sql
@@ -0,0 +1,25 @@
+CREATE TABLE IF NOT EXISTS "oidc_codes" (
+ "sub" TEXT NOT NULL UNIQUE,
+ "code" TEXT NOT NULL PRIMARY KEY UNIQUE,
+ "scope" TEXT NOT NULL,
+ "redirect_uri" TEXT NOT NULL,
+ "client_id" TEXT NOT NULL,
+ "expires_at" INTEGER NOT NULL
+);
+
+CREATE TABLE IF NOT EXISTS "oidc_tokens" (
+ "sub" TEXT NOT NULL UNIQUE,
+ "access_token" TEXT NOT NULL PRIMARY KEY UNIQUE,
+ "scope" TEXT NOT NULL,
+ "client_id" TEXT NOT NULL,
+ "expires_at" INTEGER NOT NULL
+);
+
+CREATE TABLE IF NOT EXISTS "oidc_userinfo" (
+ "sub" TEXT NOT NULL UNIQUE PRIMARY KEY,
+ "name" TEXT NOT NULL,
+ "preferred_username" TEXT NOT NULL,
+ "email" TEXT NOT NULL,
+ "groups" TEXT NOT NULL,
+ "updated_at" INTEGER NOT NULL
+);
diff --git a/sql/queries.sql b/sql/session_queries.sql
similarity index 96%
rename from sql/queries.sql
rename to sql/session_queries.sql
index 9fde4e2..da93126 100644
--- a/sql/queries.sql
+++ b/sql/session_queries.sql
@@ -1,5 +1,5 @@
-- name: CreateSession :one
-INSERT INTO sessions (
+INSERT INTO "sessions" (
"uuid",
"username",
"email",
diff --git a/sql/schema.sql b/sql/session_schemas.sql
similarity index 100%
rename from sql/schema.sql
rename to sql/session_schemas.sql
diff --git a/sqlc.yml b/sqlc.yml
index b9cf1ea..2c0f170 100644
--- a/sqlc.yml
+++ b/sqlc.yml
@@ -1,8 +1,8 @@
version: "2"
sql:
- engine: "sqlite"
- queries: "sql/queries.sql"
- schema: "sql/schema.sql"
+ queries: "sql/*_queries.sql"
+ schema: "sql/*_schemas.sql"
gen:
go:
package: "repository"
@@ -12,6 +12,7 @@ sql:
oauth_groups: "OAuthGroups"
oauth_name: "OAuthName"
oauth_sub: "OAuthSub"
+ redirect_uri: "RedirectURI"
overrides:
- column: "sessions.oauth_groups"
go_type: "string"