Files
tinyauth/internal/service/oidc_service.go

439 lines
11 KiB
Go

package service
import (
"crypto"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"database/sql"
"encoding/pem"
"errors"
"fmt"
"net/url"
"os"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/steveiliop56/tinyauth/internal/config"
"github.com/steveiliop56/tinyauth/internal/repository"
"github.com/steveiliop56/tinyauth/internal/utils"
"github.com/steveiliop56/tinyauth/internal/utils/tlog"
"golang.org/x/exp/slices"
// Should probably switch to another package but for now this works
"golang.org/x/oauth2/jws"
)
// SupportedScopes lists the scope values this provider acts on. Other
// requested scopes are logged by ValidateAuthorizeParams and then dropped by
// filterScopes.
var SupportedScopes = []string{"openid", "profile", "email", "groups"}

// SupportedResponseTypes lists the accepted response_type values; only the
// authorization code flow is implemented.
var SupportedResponseTypes = []string{"code"}

// SupportedGrantTypes lists the grant_type values accepted by
// ValidateGrantType.
var SupportedGrantTypes = []string{"authorization_code"}
// ErrCodeExpired is returned by GetCodeEntry when the authorization code is
// past its expiry (the code is deleted first).
var ErrCodeExpired = errors.New("code_expired")

// ErrCodeNotFound is returned by GetCodeEntry when no stored code matches.
var ErrCodeNotFound = errors.New("code_not_found")

// ErrTokenNotFound is returned by GetAccessToken when no stored token matches.
var ErrTokenNotFound = errors.New("token_not_found")

// ErrTokenExpired is returned by GetAccessToken when the access token is past
// its expiry (the token is deleted first).
var ErrTokenExpired = errors.New("token_expired")
// UserinfoResponse is the JSON body served by the userinfo endpoint.
//
// Claims that are only released for a particular scope carry omitempty: per
// OpenID Connect Core §5.3.2 a claim that is not returned SHOULD be omitted
// rather than sent as an empty string or null. sub is always present.
type UserinfoResponse struct {
	Sub string `json:"sub"`
	// Name and PreferredUsername are filled for the "profile" scope.
	Name string `json:"name,omitempty"`
	// Email is filled for the "email" scope.
	Email             string `json:"email,omitempty"`
	PreferredUsername string `json:"preferred_username,omitempty"`
	// Groups is filled for the "groups" scope.
	Groups    []string `json:"groups,omitempty"`
	UpdatedAt int64    `json:"updated_at"`
}
// TokenResponse is the JSON body produced by GenerateAccessToken for a
// successful code exchange.
type TokenResponse struct {
	// AccessToken is the random opaque bearer token that is also persisted.
	AccessToken string `json:"access_token"`
	// TokenType is set to "Bearer" by GenerateAccessToken.
	TokenType string `json:"token_type"`
	// ExpiresIn is the access token lifetime in seconds.
	ExpiresIn int64 `json:"expires_in"`
	// IDToken is the signed JWT from generateIDToken.
	IDToken string `json:"id_token"`
	// Scope is the granted scope list, space-separated.
	Scope string `json:"scope"`
}
// AuthorizeRequest carries the parameters of an authorization request, as
// checked by ValidateAuthorizeParams. Every field is marked required for
// gin's binding.
type AuthorizeRequest struct {
	// Scope is the requested scope list, separated by single spaces.
	Scope string `json:"scope" binding:"required"`
	// ResponseType must be one of SupportedResponseTypes.
	ResponseType string `json:"response_type" binding:"required"`
	// ClientID selects the configured client.
	ClientID string `json:"client_id" binding:"required"`
	// RedirectURI must be listed in the client's TrustedRedirectURIs.
	RedirectURI string `json:"redirect_uri" binding:"required"`
	// State is an opaque value supplied by the client.
	State string `json:"state" binding:"required"`
}
// OIDCServiceConfig is the input to NewOIDCService.
type OIDCServiceConfig struct {
	// Clients maps a configuration name (copied into the client's ID field by
	// Init) to the client's settings.
	Clients map[string]config.OIDCClientConfig
	// PrivateKeyPath is the PEM file holding the RSA signing key; Init
	// creates it if missing.
	PrivateKeyPath string
	// PublicKeyPath is the PEM file holding the matching public key; Init
	// creates it if missing.
	PublicKeyPath string
	// Issuer must be an https URL; Init rejects anything else.
	Issuer string
}
// OIDCService implements the provider side of the OIDC authorization code
// flow: authorize validation, code/token/userinfo storage and ID token
// signing. Construct it with NewOIDCService and call Init before use.
type OIDCService struct {
	// config is the configuration passed to NewOIDCService.
	config OIDCServiceConfig
	// queries is the database access layer for codes, tokens and userinfo.
	queries *repository.Queries
	// clients is built by Init and keyed by each client's ClientID.
	clients map[string]config.OIDCClientConfig
	// privateKey signs ID tokens in generateIDToken.
	privateKey *rsa.PrivateKey
	// publicKey is the public half loaded or derived by Init.
	publicKey crypto.PublicKey
	// issuer is the scheme://host form of config.Issuer computed by Init and
	// used as the "iss" claim.
	issuer string
}
// NewOIDCService stores the configuration and database queries in a new
// service. Keys, issuer and the client table are only populated by Init, so
// the returned value must not be used until Init has succeeded.
func NewOIDCService(config OIDCServiceConfig, queries *repository.Queries) *OIDCService {
	service := new(OIDCService)
	service.config = config
	service.queries = queries
	return service
}
// Init prepares the service for use. It:
//   - checks that the configured issuer is an https URL with a host and keeps
//     its scheme://host form for the "iss" claim,
//   - loads the RSA key pair from PrivateKeyPath/PublicKeyPath, generating
//     and writing any file that does not exist yet, and verifies that the two
//     keys belong together,
//   - builds the client table keyed by OAuth client ID, resolving secrets.
//
// TODO: A cleanup routine is needed to clean up expired tokens/code/userinfo
func (service *OIDCService) Init() error {
	uissuer, err := url.Parse(service.config.Issuer)
	if err != nil {
		return fmt.Errorf("parsing issuer: %w", err)
	}
	if uissuer.Scheme != "https" {
		return errors.New("issuer must be https")
	}
	if uissuer.Host == "" {
		return errors.New("issuer must include a host")
	}
	// Only scheme and host are kept; any path, query or trailing slash in the
	// configured issuer is dropped.
	service.issuer = fmt.Sprintf("%s://%s", uissuer.Scheme, uissuer.Host)

	if strings.TrimSpace(service.config.PrivateKeyPath) == "" ||
		strings.TrimSpace(service.config.PublicKeyPath) == "" {
		return errors.New("private key path and public key path are required")
	}

	privateKey, err := service.loadOrCreatePrivateKey(service.config.PrivateKeyPath)
	if err != nil {
		return err
	}
	publicKey, err := service.loadOrCreatePublicKey(service.config.PublicKeyPath, privateKey)
	if err != nil {
		return err
	}
	service.privateKey = privateKey
	service.publicKey = publicKey

	return service.loadClients()
}

// loadOrCreatePrivateKey reads a PEM-encoded PKCS#1 RSA private key from path.
// If the file does not exist, a new 2048-bit key is generated, written to
// path with mode 0600, and returned.
func (service *OIDCService) loadOrCreatePrivateKey(path string) (*rsa.PrivateKey, error) {
	data, err := os.ReadFile(path)
	switch {
	case errors.Is(err, os.ErrNotExist):
		privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
		if err != nil {
			return nil, fmt.Errorf("generating private key: %w", err)
		}
		encoded := pem.EncodeToMemory(&pem.Block{
			Type:  "RSA PRIVATE KEY",
			Bytes: x509.MarshalPKCS1PrivateKey(privateKey),
		})
		if err := os.WriteFile(path, encoded, 0600); err != nil {
			return nil, err
		}
		return privateKey, nil
	case err != nil:
		return nil, err
	}

	// pem.Decode returns a nil block when no PEM data is found; dereferencing
	// it would panic.
	block, _ := pem.Decode(data)
	if block == nil {
		return nil, fmt.Errorf("private key %q: no PEM data found", path)
	}
	privateKey, err := x509.ParsePKCS1PrivateKey(block.Bytes)
	if err != nil {
		return nil, fmt.Errorf("parsing private key %q: %w", path, err)
	}
	return privateKey, nil
}

// loadOrCreatePublicKey reads a PEM-encoded PKCS#1 RSA public key from path
// and checks that it is the public half of privateKey. If the file does not
// exist, the public half of privateKey is written to path with mode 0644 and
// returned.
func (service *OIDCService) loadOrCreatePublicKey(path string, privateKey *rsa.PrivateKey) (*rsa.PublicKey, error) {
	data, err := os.ReadFile(path)
	switch {
	case errors.Is(err, os.ErrNotExist):
		publicKey := &privateKey.PublicKey
		encoded := pem.EncodeToMemory(&pem.Block{
			Type:  "RSA PUBLIC KEY",
			Bytes: x509.MarshalPKCS1PublicKey(publicKey),
		})
		if err := os.WriteFile(path, encoded, 0644); err != nil {
			return nil, err
		}
		return publicKey, nil
	case err != nil:
		return nil, err
	}

	block, _ := pem.Decode(data)
	if block == nil {
		return nil, fmt.Errorf("public key %q: no PEM data found", path)
	}
	publicKey, err := x509.ParsePKCS1PublicKey(block.Bytes)
	if err != nil {
		return nil, fmt.Errorf("parsing public key %q: %w", path, err)
	}
	// A stale public key (e.g. left behind after the private key file was
	// removed and regenerated) cannot verify signatures made with privateKey.
	if !publicKey.Equal(&privateKey.PublicKey) {
		return nil, fmt.Errorf("public key %q does not match the private key", path)
	}
	return publicKey, nil
}

// loadClients rebuilds service.clients keyed by each client's ClientID. The
// configuration map key is copied into client.ID, secrets are resolved with
// utils.GetSecret, and ClientSecretFile is cleared afterwards. A missing or
// duplicated ClientID is rejected: the map is iterated in random order, so a
// duplicate would otherwise make the winning client nondeterministic.
func (service *OIDCService) loadClients() error {
	clients := make(map[string]config.OIDCClientConfig, len(service.config.Clients))
	for id, client := range service.config.Clients {
		if client.ClientID == "" {
			return fmt.Errorf("oidc client %q has no client id", id)
		}
		if existing, dup := clients[client.ClientID]; dup {
			return fmt.Errorf("oidc clients %q and %q share client id %q", existing.ID, id, client.ClientID)
		}
		client.ID = id
		if secret := utils.GetSecret(client.ClientSecret, client.ClientSecretFile); secret != "" {
			client.ClientSecret = secret
		}
		client.ClientSecretFile = ""
		clients[client.ClientID] = client
	}
	service.clients = clients
	return nil
}
// GetIssuer returns the issuer identifier that generateIDToken places in the
// "iss" claim (the scheme://host form computed by Init), so that anything
// published from this value agrees exactly with issued tokens. Before Init has
// run it falls back to the raw configured issuer.
func (service *OIDCService) GetIssuer() string {
	if service.issuer != "" {
		return service.issuer
	}
	return service.config.Issuer
}
// GetClient looks up a configured client by its ClientID (the key Init uses
// for the client table). The boolean reports whether the client exists; when
// it is false the returned config is the zero value.
func (service *OIDCService) GetClient(id string) (config.OIDCClientConfig, bool) {
	if found, exists := service.clients[id]; exists {
		return found, true
	}
	return config.OIDCClientConfig{}, false
}
// ValidateAuthorizeParams checks an authorization request and returns an
// OAuth-style error code as the error text:
//   - "access_denied" for an unknown client,
//   - "invalid_scope" for a blank scope or an empty token between spaces,
//   - "unsupported_response_type" for anything outside SupportedResponseTypes,
//   - "invalid_request_uri" when the redirect URI is not trusted by the client.
//
// Unsupported (but non-empty) scopes are only logged; they are later dropped
// by filterScopes.
func (service *OIDCService) ValidateAuthorizeParams(req AuthorizeRequest) error {
	client, known := service.GetClient(req.ClientID)
	if !known {
		return errors.New("access_denied")
	}

	if strings.TrimSpace(req.Scope) == "" {
		return errors.New("invalid_scope")
	}
	for _, requested := range strings.Split(req.Scope, " ") {
		switch {
		case strings.TrimSpace(requested) == "":
			return errors.New("invalid_scope")
		case !slices.Contains(SupportedScopes, requested):
			tlog.App.Warn().Str("scope", requested).Msg("Unsupported OIDC scope, will be ignored")
		}
	}

	if !slices.Contains(SupportedResponseTypes, req.ResponseType) {
		return errors.New("unsupported_response_type")
	}

	if !slices.Contains(client.TrustedRedirectURIs, req.RedirectURI) {
		return errors.New("invalid_request_uri")
	}
	return nil
}
// filterScopes keeps only the entries of scopes that appear in
// SupportedScopes, delegating the filtering to utils.Filter.
func (service *OIDCService) filterScopes(scopes []string) []string {
	isSupported := func(candidate string) bool {
		return slices.Contains(SupportedScopes, candidate)
	}
	return utils.Filter(scopes, isSupported)
}
// StoreCode persists an authorization code issued to sub for the client and
// redirect URI in req. The code expires after a fixed ten minutes. Only
// supported scopes are stored, joined with commas.
func (service *OIDCService) StoreCode(c *gin.Context, sub string, code string, req AuthorizeRequest) error {
	validity := 10 * time.Minute
	deadline := time.Now().Add(validity).Unix()

	// Splitting on single spaces mirrors ValidateAuthorizeParams, which is
	// expected to have rejected empty scope tokens already; unsupported
	// scopes are dropped here.
	granted := service.filterScopes(strings.Split(req.Scope, " "))

	params := repository.CreateOidcCodeParams{
		Sub:         sub,
		Code:        code,
		Scope:       strings.Join(granted, ","),
		RedirectURI: req.RedirectURI,
		ClientID:    req.ClientID,
		ExpiresAt:   deadline,
	}
	if _, err := service.queries.CreateOidcCode(c, params); err != nil {
		return err
	}
	return nil
}
// StoreUserinfo snapshots the user's profile under sub so it can be served by
// the userinfo endpoint later. Groups are passed through from the upstream
// provider: non-empty OAuth groups win, otherwise LDAP groups are used for
// LDAP users. req is currently unused.
func (service *OIDCService) StoreUserinfo(c *gin.Context, sub string, userContext config.UserContext, req AuthorizeRequest) error {
	params := repository.CreateOidcUserInfoParams{
		Sub:               sub,
		Name:              userContext.Name,
		Email:             userContext.Email,
		PreferredUsername: userContext.Username,
		UpdatedAt:         time.Now().Unix(),
	}

	switch {
	case userContext.OAuth && len(userContext.OAuthGroups) > 0:
		params.Groups = userContext.OAuthGroups
	case userContext.Provider == "ldap":
		params.Groups = userContext.LdapGroups
	}

	if _, err := service.queries.CreateOidcUserInfo(c, params); err != nil {
		return err
	}
	return nil
}
// ValidateGrantType checks grantType against SupportedGrantTypes. On failure
// the error text is "unsupported_grant_type", the token endpoint error code
// defined by RFC 6749 §5.2 (previously it wrongly reported
// "unsupported_response_type", which belongs to the authorization endpoint).
func (service *OIDCService) ValidateGrantType(grantType string) error {
	if !slices.Contains(SupportedGrantTypes, grantType) {
		return errors.New("unsupported_grant_type")
	}
	return nil
}
// GetCodeEntry loads a stored authorization code.
//
// It returns ErrCodeNotFound when no row matches and ErrCodeExpired when the
// code's expiry has passed. In the expired case the code row and the
// subject's stored userinfo are deleted before returning. Other database
// errors are returned unchanged.
func (service *OIDCService) GetCodeEntry(c *gin.Context, code string) (repository.OidcCode, error) {
	entry, err := service.queries.GetOidcCode(c, code)
	switch {
	case errors.Is(err, sql.ErrNoRows):
		return repository.OidcCode{}, ErrCodeNotFound
	case err != nil:
		return repository.OidcCode{}, err
	}

	if entry.ExpiresAt >= time.Now().Unix() {
		return entry, nil
	}

	// Expired: purge the code and the userinfo tied to it.
	if err := service.queries.DeleteOidcCode(c, code); err != nil {
		return repository.OidcCode{}, err
	}
	if err := service.DeleteUserinfo(c, entry.Sub); err != nil {
		return repository.OidcCode{}, err
	}
	return repository.OidcCode{}, ErrCodeExpired
}
// generateIDToken returns an RS256-signed JWT for sub, with the service issuer
// as "iss", the client's ClientID as "aud", and a one-hour lifetime. On
// failure the token string is empty.
func (service *OIDCService) generateIDToken(client config.OIDCClientConfig, sub string) (string, error) {
	issuedAt := time.Now().Unix()
	// TODO: This should probably be user-configured if refresh logic does not exist
	expiry := time.Now().Add(time.Hour).Unix()

	return jws.Encode(
		&jws.Header{Algorithm: "RS256", Typ: "JWT"},
		&jws.ClaimSet{
			Iss: service.issuer,
			Aud: client.ClientID,
			Sub: sub,
			Iat: issuedAt,
			Exp: expiry,
		},
		service.privateKey,
	)
}
// GenerateAccessToken issues an ID token plus a random opaque access token for
// sub, persists the access token with a one-hour expiry, and returns the token
// endpoint response. scope is the comma-separated list as stored with the
// code; it is saved as-is and returned space-separated.
func (service *OIDCService) GenerateAccessToken(c *gin.Context, client config.OIDCClientConfig, sub string, scope string) (TokenResponse, error) {
	idToken, err := service.generateIDToken(client, sub)
	if err != nil {
		return TokenResponse{}, err
	}

	bearer := rand.Text()
	lifetime := time.Hour

	record := repository.CreateOidcTokenParams{
		Sub:         sub,
		AccessToken: bearer,
		Scope:       scope,
		ExpiresAt:   time.Now().Add(lifetime).Unix(),
	}
	if _, err := service.queries.CreateOidcToken(c, record); err != nil {
		return TokenResponse{}, err
	}

	return TokenResponse{
		AccessToken: bearer,
		TokenType:   "Bearer",
		ExpiresIn:   int64(lifetime / time.Second),
		IDToken:     idToken,
		Scope:       strings.ReplaceAll(scope, ",", " "),
	}, nil
}
// DeleteCodeEntry removes the stored authorization code matching code.
func (service *OIDCService) DeleteCodeEntry(c *gin.Context, code string) error {
	err := service.queries.DeleteOidcCode(c, code)
	return err
}
// DeleteUserinfo removes the userinfo snapshot stored for sub.
func (service *OIDCService) DeleteUserinfo(c *gin.Context, sub string) error {
	err := service.queries.DeleteOidcUserInfo(c, sub)
	return err
}
// DeleteToken removes the stored access token matching token.
func (service *OIDCService) DeleteToken(c *gin.Context, token string) error {
	err := service.queries.DeleteOidcToken(c, token)
	return err
}
// GetAccessToken loads a stored access token.
//
// It returns ErrTokenNotFound when no row matches and ErrTokenExpired when the
// token's expiry has passed. In the expired case the token row and the
// subject's stored userinfo are deleted before returning. Other database
// errors are returned unchanged.
func (service *OIDCService) GetAccessToken(c *gin.Context, token string) (repository.OidcToken, error) {
	entry, err := service.queries.GetOidcToken(c, token)
	if err != nil {
		// errors.Is (unlike ==) also matches a wrapped sql.ErrNoRows, and is
		// consistent with GetCodeEntry.
		if errors.Is(err, sql.ErrNoRows) {
			return repository.OidcToken{}, ErrTokenNotFound
		}
		return repository.OidcToken{}, err
	}

	if entry.ExpiresAt < time.Now().Unix() {
		if err := service.DeleteToken(c, token); err != nil {
			return repository.OidcToken{}, err
		}
		if err := service.DeleteUserinfo(c, entry.Sub); err != nil {
			return repository.OidcToken{}, err
		}
		return repository.OidcToken{}, ErrTokenExpired
	}

	return entry, nil
}
// GetUserinfo returns the userinfo snapshot stored for sub. Errors from the
// query (presumably including a missing row) are passed through unchanged.
func (service *OIDCService) GetUserinfo(c *gin.Context, sub string) (repository.OidcUserinfo, error) {
	snapshot, err := service.queries.GetOidcUserInfo(c, sub)
	return snapshot, err
}
// CompileUserinfo builds the userinfo response for user, releasing only the
// claims covered by scope (a comma-separated list, as stored in the database):
//   - "profile" adds name and preferred_username,
//   - "email" adds email,
//   - "groups" adds the comma-separated stored groups.
//
// sub and updated_at are always included.
func (service *OIDCService) CompileUserinfo(user repository.OidcUserinfo, scope string) UserinfoResponse {
	scopes := strings.Split(scope, ",") // split by comma since it's a db entry
	userInfo := UserinfoResponse{
		Sub:       user.Sub,
		UpdatedAt: user.UpdatedAt,
	}
	if slices.Contains(scopes, "profile") {
		userInfo.Name = user.Name
		userInfo.PreferredUsername = user.PreferredUsername
	}
	if slices.Contains(scopes, "email") {
		userInfo.Email = user.Email
	}
	if slices.Contains(scopes, "groups") {
		// strings.Split("", ",") returns [""], which would report a single
		// empty-named group for users without groups; skip empty entries.
		userInfo.Groups = []string{}
		for _, group := range strings.Split(user.Groups, ",") {
			if group != "" {
				userInfo.Groups = append(userInfo.Groups, group)
			}
		}
	}
	return userInfo
}