package service import ( "context" "crypto" "crypto/rand" "crypto/rsa" "crypto/sha256" "crypto/x509" "database/sql" "encoding/json" "encoding/pem" "errors" "fmt" "net/url" "os" "strings" "time" "github.com/gin-gonic/gin" "github.com/go-jose/go-jose/v4" "github.com/steveiliop56/tinyauth/internal/config" "github.com/steveiliop56/tinyauth/internal/repository" "github.com/steveiliop56/tinyauth/internal/utils" "github.com/steveiliop56/tinyauth/internal/utils/tlog" "golang.org/x/exp/slices" ) var ( SupportedScopes = []string{"openid", "profile", "email", "groups"} SupportedResponseTypes = []string{"code"} SupportedGrantTypes = []string{"authorization_code", "refresh_token"} ) var ( ErrCodeExpired = errors.New("code_expired") ErrCodeNotFound = errors.New("code_not_found") ErrTokenNotFound = errors.New("token_not_found") ErrTokenExpired = errors.New("token_expired") ErrInvalidClient = errors.New("invalid_client") ) type ClaimSet struct { Iss string `json:"iss"` Aud string `json:"aud"` Sub string `json:"sub"` Iat int64 `json:"iat"` Exp int64 `json:"exp"` } type UserinfoResponse struct { Sub string `json:"sub"` Name string `json:"name"` Email string `json:"email"` PreferredUsername string `json:"preferred_username"` Groups []string `json:"groups"` UpdatedAt int64 `json:"updated_at"` } type TokenResponse struct { AccessToken string `json:"access_token"` RefreshToken string `json:"refresh_token"` TokenType string `json:"token_type"` ExpiresIn int64 `json:"expires_in"` IDToken string `json:"id_token"` Scope string `json:"scope"` } type AuthorizeRequest struct { Scope string `json:"scope" binding:"required"` ResponseType string `json:"response_type" binding:"required"` ClientID string `json:"client_id" binding:"required"` RedirectURI string `json:"redirect_uri" binding:"required"` State string `json:"state" binding:"required"` } type OIDCServiceConfig struct { Clients map[string]config.OIDCClientConfig PrivateKeyPath string PublicKeyPath string Issuer string SessionExpiry int } type OIDCService struct { config OIDCServiceConfig queries *repository.Queries clients map[string]config.OIDCClientConfig privateKey *rsa.PrivateKey publicKey crypto.PublicKey issuer string isConfigured bool } func NewOIDCService(config OIDCServiceConfig, queries *repository.Queries) *OIDCService { return &OIDCService{ config: config, queries: queries, } } func (service *OIDCService) IsConfigured() bool { return service.isConfigured } func (service *OIDCService) Init() error { // If not configured, skip init if len(service.config.Clients) == 0 { service.isConfigured = false return nil } service.isConfigured = true // Ensure issuer is https uissuer, err := url.Parse(service.config.Issuer) if err != nil { return err } if uissuer.Scheme != "https" { return errors.New("issuer must be https") } service.issuer = fmt.Sprintf("%s://%s", uissuer.Scheme, uissuer.Host) // Create/load private and public keys if strings.TrimSpace(service.config.PrivateKeyPath) == "" || strings.TrimSpace(service.config.PublicKeyPath) == "" { return errors.New("private key path and public key path are required") } var privateKey *rsa.PrivateKey fprivateKey, err := os.ReadFile(service.config.PrivateKeyPath) if err != nil && !errors.Is(err, os.ErrNotExist) { return err } if errors.Is(err, os.ErrNotExist) { privateKey, err = rsa.GenerateKey(rand.Reader, 2048) if err != nil { return err } der := x509.MarshalPKCS1PrivateKey(privateKey) if der == nil { return errors.New("failed to marshal private key") } encoded := pem.EncodeToMemory(&pem.Block{ Type: "RSA PRIVATE KEY", Bytes: der, }) err = os.WriteFile(service.config.PrivateKeyPath, encoded, 0600) if err != nil { return err } service.privateKey = privateKey } else { block, _ := pem.Decode(fprivateKey) if block == nil { return errors.New("failed to decode private key") } privateKey, err = x509.ParsePKCS1PrivateKey(block.Bytes) if err != nil { return err } service.privateKey = privateKey } fpublicKey, err := os.ReadFile(service.config.PublicKeyPath) if err != nil && !errors.Is(err, os.ErrNotExist) { return err } if errors.Is(err, os.ErrNotExist) { publicKey := service.privateKey.Public() der := x509.MarshalPKCS1PublicKey(publicKey.(*rsa.PublicKey)) if der == nil { return errors.New("failed to marshal public key") } encoded := pem.EncodeToMemory(&pem.Block{ Type: "RSA PUBLIC KEY", Bytes: der, }) err = os.WriteFile(service.config.PublicKeyPath, encoded, 0644) if err != nil { return err } service.publicKey = publicKey } else { block, _ := pem.Decode(fpublicKey) if block == nil { return errors.New("failed to decode public key") } publicKey, err := x509.ParsePKCS1PublicKey(block.Bytes) if err != nil { return err } service.publicKey = publicKey } // We will reorganize the client into a map with the client ID as the key service.clients = make(map[string]config.OIDCClientConfig) for id, client := range service.config.Clients { client.ID = id service.clients[client.ClientID] = client } // Load the client secrets from files if they exist for id, client := range service.clients { secret := utils.GetSecret(client.ClientSecret, client.ClientSecretFile) if secret != "" { client.ClientSecret = secret } client.ClientSecretFile = "" service.clients[id] = client tlog.App.Info().Str("id", client.ID).Msg("Registered OIDC client") } return nil } func (service *OIDCService) GetIssuer() string { return service.issuer } func (service *OIDCService) GetClient(id string) (config.OIDCClientConfig, bool) { client, ok := service.clients[id] return client, ok } func (service *OIDCService) ValidateAuthorizeParams(req AuthorizeRequest) error { // Validate client ID client, ok := service.GetClient(req.ClientID) if !ok { return errors.New("access_denied") } // Scopes scopes := strings.Split(req.Scope, " ") if len(scopes) == 0 || strings.TrimSpace(req.Scope) == "" { return errors.New("invalid_scope") } for _, scope := range scopes { if strings.TrimSpace(scope) == "" { return errors.New("invalid_scope") } if !slices.Contains(SupportedScopes, scope) { tlog.App.Warn().Str("scope", scope).Msg("Unsupported OIDC scope, will be ignored") } } // Response type if !slices.Contains(SupportedResponseTypes, req.ResponseType) { return errors.New("unsupported_response_type") } // Redirect URI if !slices.Contains(client.TrustedRedirectURIs, req.RedirectURI) { return errors.New("invalid_request_uri") } return nil } func (service *OIDCService) filterScopes(scopes []string) []string { return utils.Filter(scopes, func(scope string) bool { return slices.Contains(SupportedScopes, scope) }) } func (service *OIDCService) StoreCode(c *gin.Context, sub string, code string, req AuthorizeRequest) error { // Fixed 10 minutes expiresAt := time.Now().Add(time.Minute * time.Duration(10)).Unix() // Insert the code into the database _, err := service.queries.CreateOidcCode(c, repository.CreateOidcCodeParams{ Sub: sub, CodeHash: service.Hash(code), // Here it's safe to split and trust the output since, we validated the scopes before Scope: strings.Join(service.filterScopes(strings.Split(req.Scope, " ")), ","), RedirectURI: req.RedirectURI, ClientID: req.ClientID, ExpiresAt: expiresAt, }) return err } func (service *OIDCService) StoreUserinfo(c *gin.Context, sub string, userContext config.UserContext, req AuthorizeRequest) error { userInfoParams := repository.CreateOidcUserInfoParams{ Sub: sub, Name: userContext.Name, Email: userContext.Email, PreferredUsername: userContext.Username, UpdatedAt: time.Now().Unix(), } // Tinyauth will pass through the groups it got from an LDAP or an OIDC server if userContext.Provider == "ldap" { userInfoParams.Groups = userContext.LdapGroups } if userContext.OAuth && len(userContext.OAuthGroups) > 0 { userInfoParams.Groups = userContext.OAuthGroups } _, err := service.queries.CreateOidcUserInfo(c, userInfoParams) return err } func (service *OIDCService) ValidateGrantType(grantType string) error { if !slices.Contains(SupportedGrantTypes, grantType) { return errors.New("unsupported_grant_type") } return nil } func (service *OIDCService) GetCodeEntry(c *gin.Context, codeHash string) (repository.OidcCode, error) { oidcCode, err := service.queries.GetOidcCode(c, codeHash) if err != nil { if errors.Is(err, sql.ErrNoRows) { return repository.OidcCode{}, ErrCodeNotFound } return repository.OidcCode{}, err } if time.Now().Unix() > oidcCode.ExpiresAt { err = service.queries.DeleteOidcCode(c, codeHash) if err != nil { return repository.OidcCode{}, err } err = service.DeleteUserinfo(c, oidcCode.Sub) if err != nil { return repository.OidcCode{}, err } return repository.OidcCode{}, ErrCodeExpired } return oidcCode, nil } func (service *OIDCService) generateIDToken(client config.OIDCClientConfig, sub string) (string, error) { createdAt := time.Now().Unix() expiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry) * time.Second).Unix() signer, err := jose.NewSigner(jose.SigningKey{ Algorithm: jose.RS256, Key: service.privateKey, }, &jose.SignerOptions{ ExtraHeaders: map[jose.HeaderKey]any{ "typ": "jwt", "jku": fmt.Sprintf("%s/.well-known/jwks.json", service.issuer), }, }) if err != nil { return "", err } claims := ClaimSet{ Iss: service.issuer, Aud: client.ClientID, Sub: sub, Iat: createdAt, Exp: expiresAt, } payload, err := json.Marshal(claims) if err != nil { return "", err } object, err := signer.Sign(payload) if err != nil { return "", err } token, err := object.CompactSerialize() if err != nil { return "", err } return token, nil } func (service *OIDCService) GenerateAccessToken(c *gin.Context, client config.OIDCClientConfig, sub string, scope string) (TokenResponse, error) { idToken, err := service.generateIDToken(client, sub) if err != nil { return TokenResponse{}, err } accessToken := rand.Text() refreshToken := rand.Text() tokenExpiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry) * time.Second).Unix() // Refresh token lives double the time of an access token but can't be used to access userinfo refrshTokenExpiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry*2) * time.Second).Unix() tokenResponse := TokenResponse{ AccessToken: accessToken, RefreshToken: refreshToken, TokenType: "Bearer", ExpiresIn: int64(service.config.SessionExpiry), IDToken: idToken, Scope: strings.ReplaceAll(scope, ",", " "), } _, err = service.queries.CreateOidcToken(c, repository.CreateOidcTokenParams{ Sub: sub, AccessTokenHash: service.Hash(accessToken), RefreshTokenHash: service.Hash(refreshToken), ClientID: client.ClientID, Scope: scope, TokenExpiresAt: tokenExpiresAt, RefreshTokenExpiresAt: refrshTokenExpiresAt, }) if err != nil { return TokenResponse{}, err } return tokenResponse, nil } func (service *OIDCService) RefreshAccessToken(c *gin.Context, refreshToken string, reqClientId string) (TokenResponse, error) { entry, err := service.queries.GetOidcTokenByRefreshToken(c, service.Hash(refreshToken)) if err != nil { if err == sql.ErrNoRows { return TokenResponse{}, ErrTokenNotFound } return TokenResponse{}, err } if entry.RefreshTokenExpiresAt < time.Now().Unix() { return TokenResponse{}, ErrTokenExpired } // Ensure the client ID in the request matches the client ID in the token if entry.ClientID != reqClientId { return TokenResponse{}, ErrInvalidClient } idToken, err := service.generateIDToken(config.OIDCClientConfig{ ClientID: entry.ClientID, }, entry.Sub) if err != nil { return TokenResponse{}, err } accessToken := rand.Text() newRefreshToken := rand.Text() tokenExpiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry) * time.Second).Unix() refrshTokenExpiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry*2) * time.Second).Unix() tokenResponse := TokenResponse{ AccessToken: accessToken, RefreshToken: newRefreshToken, TokenType: "Bearer", ExpiresIn: int64(service.config.SessionExpiry), IDToken: idToken, Scope: strings.ReplaceAll(entry.Scope, ",", " "), } _, err = service.queries.UpdateOidcTokenByRefreshToken(c, repository.UpdateOidcTokenByRefreshTokenParams{ AccessTokenHash: service.Hash(accessToken), RefreshTokenHash: service.Hash(newRefreshToken), TokenExpiresAt: tokenExpiresAt, RefreshTokenExpiresAt: refrshTokenExpiresAt, RefreshTokenHash_2: service.Hash(refreshToken), // that's the selector, it's not stored in the db }) if err != nil { return TokenResponse{}, err } return tokenResponse, nil } func (service *OIDCService) DeleteCodeEntry(c *gin.Context, codeHash string) error { return service.queries.DeleteOidcCode(c, codeHash) } func (service *OIDCService) DeleteUserinfo(c *gin.Context, sub string) error { return service.queries.DeleteOidcUserInfo(c, sub) } func (service *OIDCService) DeleteToken(c *gin.Context, tokenHash string) error { return service.queries.DeleteOidcToken(c, tokenHash) } func (service *OIDCService) GetAccessToken(c *gin.Context, tokenHash string) (repository.OidcToken, error) { entry, err := service.queries.GetOidcToken(c, tokenHash) if err != nil { if err == sql.ErrNoRows { return repository.OidcToken{}, ErrTokenNotFound } return repository.OidcToken{}, err } if entry.TokenExpiresAt < time.Now().Unix() { // If refresh token is expired, delete the token and userinfo since there is no way for the client to access anything anymore if entry.RefreshTokenExpiresAt < time.Now().Unix() { err := service.DeleteToken(c, tokenHash) if err != nil { return repository.OidcToken{}, err } err = service.DeleteUserinfo(c, entry.Sub) if err != nil { return repository.OidcToken{}, err } } return repository.OidcToken{}, ErrTokenExpired } return entry, nil } func (service *OIDCService) GetUserinfo(c *gin.Context, sub string) (repository.OidcUserinfo, error) { return service.queries.GetOidcUserInfo(c, sub) } func (service *OIDCService) CompileUserinfo(user repository.OidcUserinfo, scope string) UserinfoResponse { scopes := strings.Split(scope, ",") // split by comma since it's a db entry userInfo := UserinfoResponse{ Sub: user.Sub, UpdatedAt: user.UpdatedAt, } if slices.Contains(scopes, "profile") { userInfo.Name = user.Name userInfo.PreferredUsername = user.PreferredUsername } if slices.Contains(scopes, "email") { userInfo.Email = user.Email } if slices.Contains(scopes, "groups") { if user.Groups != "" { userInfo.Groups = strings.Split(user.Groups, ",") } else { userInfo.Groups = []string{} } } return userInfo } func (service *OIDCService) Hash(token string) string { hasher := sha256.New() hasher.Write([]byte(token)) return fmt.Sprintf("%x", hasher.Sum(nil)) } func (service *OIDCService) DeleteOldSession(ctx context.Context, sub string) error { err := service.queries.DeleteOidcCodeBySub(ctx, sub) if err != nil && !errors.Is(err, sql.ErrNoRows) { return err } err = service.queries.DeleteOidcTokenBySub(ctx, sub) if err != nil && !errors.Is(err, sql.ErrNoRows) { return err } err = service.queries.DeleteOidcUserInfo(ctx, sub) if err != nil && !errors.Is(err, sql.ErrNoRows) { return err } return nil } // Cleanup routine - Resource heavy due to the linked tables func (service *OIDCService) Cleanup() { // We need a context for the routine ctx := context.Background() ticker := time.NewTicker(time.Duration(30) * time.Minute) defer ticker.Stop() for range ticker.C { currentTime := time.Now().Unix() // For the OIDC tokens, if they are expired we delete the userinfo and codes expiredTokens, err := service.queries.DeleteExpiredOidcTokens(ctx, repository.DeleteExpiredOidcTokensParams{ TokenExpiresAt: currentTime, RefreshTokenExpiresAt: currentTime, }) if err != nil { tlog.App.Warn().Err(err).Msg("Failed to delete expired tokens") } for _, expiredToken := range expiredTokens { err := service.DeleteOldSession(ctx, expiredToken.Sub) if err != nil { tlog.App.Warn().Err(err).Msg("Failed to delete old session") } } // For expired codes, we need to get the sub, check if tokens are expired and if they are remove everything expiredCodes, err := service.queries.DeleteExpiredOidcCodes(ctx, currentTime) if err != nil { tlog.App.Warn().Err(err).Msg("Failed to delete expired codes") } for _, expiredCode := range expiredCodes { token, err := service.queries.GetOidcTokenBySub(ctx, expiredCode.Sub) if err != nil { if err == sql.ErrNoRows { continue } tlog.App.Warn().Err(err).Msg("Failed to get OIDC token by sub") } if token.TokenExpiresAt < currentTime && token.RefreshTokenExpiresAt < currentTime { err := service.DeleteOldSession(ctx, expiredCode.Sub) if err != nil { tlog.App.Warn().Err(err).Msg("Failed to delete session") } } } } } func (service *OIDCService) GetJWK() ([]byte, error) { jwk := jose.JSONWebKey{ Key: service.privateKey, Algorithm: string(jose.RS256), Use: "sig", } return jwk.Public().MarshalJSON() }