package service import ( "crypto" "crypto/rand" "crypto/rsa" "crypto/x509" "database/sql" "encoding/pem" "errors" "fmt" "net/url" "os" "strings" "time" "github.com/gin-gonic/gin" "github.com/steveiliop56/tinyauth/internal/config" "github.com/steveiliop56/tinyauth/internal/repository" "github.com/steveiliop56/tinyauth/internal/utils" "github.com/steveiliop56/tinyauth/internal/utils/tlog" "golang.org/x/exp/slices" // Should probably switch to another package but for now this works "golang.org/x/oauth2/jws" ) var ( SupportedScopes = []string{"openid", "profile", "email", "groups"} SupportedResponseTypes = []string{"code"} SupportedGrantTypes = []string{"authorization_code"} ) var ( ErrCodeExpired = errors.New("code_expired") ErrCodeNotFound = errors.New("code_not_found") ErrTokenNotFound = errors.New("token_not_found") ErrTokenExpired = errors.New("token_expired") ) type UserinfoResponse struct { Sub string `json:"sub"` Name string `json:"name"` Email string `json:"email"` PreferredUsername string `json:"preferred_username"` Groups []string `json:"groups"` UpdatedAt int64 `json:"updated_at"` } type TokenResponse struct { AccessToken string `json:"access_token"` TokenType string `json:"token_type"` ExpiresIn int64 `json:"expires_in"` IDToken string `json:"id_token"` Scope string `json:"scope"` } type AuthorizeRequest struct { Scope string `json:"scope" binding:"required"` ResponseType string `json:"response_type" binding:"required"` ClientID string `json:"client_id" binding:"required"` RedirectURI string `json:"redirect_uri" binding:"required"` State string `json:"state" binding:"required"` } type OIDCServiceConfig struct { Clients map[string]config.OIDCClientConfig PrivateKeyPath string PublicKeyPath string Issuer string } type OIDCService struct { config OIDCServiceConfig queries *repository.Queries clients map[string]config.OIDCClientConfig privateKey *rsa.PrivateKey publicKey crypto.PublicKey issuer string } func NewOIDCService(config OIDCServiceConfig, queries *repository.Queries) *OIDCService { return &OIDCService{ config: config, queries: queries, } } // TODO: A cleanup routine is needed to clean up expired tokens/code/userinfo func (service *OIDCService) Init() error { // Ensure issuer is https uissuer, err := url.Parse(service.config.Issuer) if err != nil { return err } if uissuer.Scheme != "https" { return errors.New("issuer must be https") } service.issuer = fmt.Sprintf("%s://%s", uissuer.Scheme, uissuer.Host) // Create/load private and public keys if strings.TrimSpace(service.config.PrivateKeyPath) == "" || strings.TrimSpace(service.config.PublicKeyPath) == "" { return errors.New("private key path and public key path are required") } var privateKey *rsa.PrivateKey fprivateKey, err := os.ReadFile(service.config.PrivateKeyPath) if err != nil && !errors.Is(err, os.ErrNotExist) { return err } if errors.Is(err, os.ErrNotExist) { privateKey, err = rsa.GenerateKey(rand.Reader, 2048) if err != nil { return err } der := x509.MarshalPKCS1PrivateKey(privateKey) encoded := pem.EncodeToMemory(&pem.Block{ Type: "RSA PRIVATE KEY", Bytes: der, }) err = os.WriteFile(service.config.PrivateKeyPath, encoded, 0600) if err != nil { return err } service.privateKey = privateKey } else { block, _ := pem.Decode(fprivateKey) privateKey, err = x509.ParsePKCS1PrivateKey(block.Bytes) if err != nil { return err } service.privateKey = privateKey } fpublicKey, err := os.ReadFile(service.config.PublicKeyPath) if err != nil && !errors.Is(err, os.ErrNotExist) { return err } if errors.Is(err, os.ErrNotExist) { publicKey := service.privateKey.Public() der := x509.MarshalPKCS1PublicKey(publicKey.(*rsa.PublicKey)) encoded := pem.EncodeToMemory(&pem.Block{ Type: "RSA PUBLIC KEY", Bytes: der, }) err = os.WriteFile(service.config.PublicKeyPath, encoded, 0644) if err != nil { return err } service.publicKey = publicKey } else { block, _ := pem.Decode(fpublicKey) publicKey, err := x509.ParsePKCS1PublicKey(block.Bytes) if err != nil { return err } service.publicKey = publicKey } // We will reorganize the client into a map with the client ID as the key service.clients = make(map[string]config.OIDCClientConfig) for id, client := range service.config.Clients { client.ID = id service.clients[client.ClientID] = client } // Load the client secrets from files if they exist for id, client := range service.clients { secret := utils.GetSecret(client.ClientSecret, client.ClientSecretFile) if secret != "" { client.ClientSecret = secret } client.ClientSecretFile = "" service.clients[id] = client } return nil } func (service *OIDCService) GetIssuer() string { return service.config.Issuer } func (service *OIDCService) GetClient(id string) (config.OIDCClientConfig, bool) { client, ok := service.clients[id] return client, ok } func (service *OIDCService) ValidateAuthorizeParams(req AuthorizeRequest) error { // Validate client ID client, ok := service.GetClient(req.ClientID) if !ok { return errors.New("access_denied") } // Scopes scopes := strings.Split(req.Scope, " ") if len(scopes) == 0 || strings.TrimSpace(req.Scope) == "" { return errors.New("invalid_scope") } for _, scope := range scopes { if strings.TrimSpace(scope) == "" { return errors.New("invalid_scope") } if !slices.Contains(SupportedScopes, scope) { tlog.App.Warn().Str("scope", scope).Msg("Unsupported OIDC scope, will be ignored") } } // Response type if !slices.Contains(SupportedResponseTypes, req.ResponseType) { return errors.New("unsupported_response_type") } // Redirect URI if !slices.Contains(client.TrustedRedirectURIs, req.RedirectURI) { return errors.New("invalid_request_uri") } return nil } func (service *OIDCService) filterScopes(scopes []string) []string { return utils.Filter(scopes, func(scope string) bool { return slices.Contains(SupportedScopes, scope) }) } func (service *OIDCService) StoreCode(c *gin.Context, sub string, code string, req AuthorizeRequest) error { // Fixed 10 minutes expiresAt := time.Now().Add(time.Minute * time.Duration(10)).Unix() // Insert the code into the database _, err := service.queries.CreateOidcCode(c, repository.CreateOidcCodeParams{ Sub: sub, Code: code, // Here it's safe to split and trust the output since, we validated the scopes before Scope: strings.Join(service.filterScopes(strings.Split(req.Scope, " ")), ","), RedirectURI: req.RedirectURI, ClientID: req.ClientID, ExpiresAt: expiresAt, }) return err } func (service *OIDCService) StoreUserinfo(c *gin.Context, sub string, userContext config.UserContext, req AuthorizeRequest) error { userInfoParams := repository.CreateOidcUserInfoParams{ Sub: sub, Name: userContext.Name, Email: userContext.Email, PreferredUsername: userContext.Username, UpdatedAt: time.Now().Unix(), } // Tinyauth will pass through the groups it got from an LDAP or an OIDC server if userContext.Provider == "ldap" { userInfoParams.Groups = userContext.LdapGroups } if userContext.OAuth && len(userContext.OAuthGroups) > 0 { userInfoParams.Groups = userContext.OAuthGroups } _, err := service.queries.CreateOidcUserInfo(c, userInfoParams) return err } func (service *OIDCService) ValidateGrantType(grantType string) error { if !slices.Contains(SupportedGrantTypes, grantType) { return errors.New("unsupported_response_type") } return nil } func (service *OIDCService) GetCodeEntry(c *gin.Context, code string) (repository.OidcCode, error) { oidcCode, err := service.queries.GetOidcCode(c, code) if err != nil { if errors.Is(err, sql.ErrNoRows) { return repository.OidcCode{}, ErrCodeNotFound } return repository.OidcCode{}, err } if time.Now().Unix() > oidcCode.ExpiresAt { err = service.queries.DeleteOidcCode(c, code) if err != nil { return repository.OidcCode{}, err } err = service.DeleteUserinfo(c, oidcCode.Sub) if err != nil { return repository.OidcCode{}, err } return repository.OidcCode{}, ErrCodeExpired } return oidcCode, nil } func (service *OIDCService) generateIDToken(client config.OIDCClientConfig, sub string) (string, error) { createdAt := time.Now().Unix() // TODO: This should probably be user-configured if refresh logic does not exist expiresAt := time.Now().Add(time.Duration(1) * time.Hour).Unix() claims := jws.ClaimSet{ Iss: service.issuer, Aud: client.ClientID, Sub: sub, Iat: createdAt, Exp: expiresAt, } header := jws.Header{ Algorithm: "RS256", Typ: "JWT", } token, err := jws.Encode(&header, &claims, service.privateKey) if err != nil { return "", err } return token, nil } func (service *OIDCService) GenerateAccessToken(c *gin.Context, client config.OIDCClientConfig, sub string, scope string) (TokenResponse, error) { idToken, err := service.generateIDToken(client, sub) if err != nil { return TokenResponse{}, err } accessToken := rand.Text() expiresAt := time.Now().Add(time.Duration(1) * time.Hour).Unix() tokenResponse := TokenResponse{ AccessToken: accessToken, TokenType: "Bearer", ExpiresIn: int64(time.Hour.Seconds()), IDToken: idToken, Scope: strings.ReplaceAll(scope, ",", " "), } _, err = service.queries.CreateOidcToken(c, repository.CreateOidcTokenParams{ Sub: sub, AccessToken: accessToken, Scope: scope, ExpiresAt: expiresAt, }) if err != nil { return TokenResponse{}, err } return tokenResponse, nil } func (service *OIDCService) DeleteCodeEntry(c *gin.Context, code string) error { return service.queries.DeleteOidcCode(c, code) } func (service *OIDCService) DeleteUserinfo(c *gin.Context, sub string) error { return service.queries.DeleteOidcUserInfo(c, sub) } func (service *OIDCService) DeleteToken(c *gin.Context, token string) error { return service.queries.DeleteOidcToken(c, token) } func (service *OIDCService) GetAccessToken(c *gin.Context, token string) (repository.OidcToken, error) { entry, err := service.queries.GetOidcToken(c, token) if err != nil { if err == sql.ErrNoRows { return repository.OidcToken{}, ErrTokenNotFound } return repository.OidcToken{}, err } if entry.ExpiresAt < time.Now().Unix() { err := service.DeleteToken(c, token) if err != nil { return repository.OidcToken{}, err } err = service.DeleteUserinfo(c, entry.Sub) if err != nil { return repository.OidcToken{}, err } return repository.OidcToken{}, ErrTokenExpired } return entry, nil } func (service *OIDCService) GetUserinfo(c *gin.Context, sub string) (repository.OidcUserinfo, error) { return service.queries.GetOidcUserInfo(c, sub) } func (service *OIDCService) CompileUserinfo(user repository.OidcUserinfo, scope string) UserinfoResponse { scopes := strings.Split(scope, ",") // split by comma since it's a db entry userInfo := UserinfoResponse{ Sub: user.Sub, UpdatedAt: user.UpdatedAt, } if slices.Contains(scopes, "profile") { userInfo.Name = user.Name userInfo.PreferredUsername = user.PreferredUsername } if slices.Contains(scopes, "email") { userInfo.Email = user.Email } if slices.Contains(scopes, "groups") { userInfo.Groups = strings.Split(user.Groups, ",") } return userInfo }