mirror of
https://github.com/steveiliop56/tinyauth.git
synced 2026-01-26 09:12:30 +00:00
439 lines
11 KiB
Go
439 lines
11 KiB
Go
package service
|
|
|
|
import (
|
|
"crypto"
|
|
"crypto/rand"
|
|
"crypto/rsa"
|
|
"crypto/x509"
|
|
"database/sql"
|
|
"encoding/pem"
|
|
"errors"
|
|
"fmt"
|
|
"net/url"
|
|
"os"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/steveiliop56/tinyauth/internal/config"
|
|
"github.com/steveiliop56/tinyauth/internal/repository"
|
|
"github.com/steveiliop56/tinyauth/internal/utils"
|
|
"github.com/steveiliop56/tinyauth/internal/utils/tlog"
|
|
"golang.org/x/exp/slices"
|
|
|
|
// Should probably switch to another package but for now this works
|
|
"golang.org/x/oauth2/jws"
|
|
)
|
|
|
|
var (
	// SupportedScopes are the scope values this provider acts on. Other
	// requested scopes are only logged by ValidateAuthorizeParams and are
	// dropped by filterScopes before an authorization code is stored.
	SupportedScopes = []string{"openid", "profile", "email", "groups"}

	// SupportedResponseTypes are the accepted authorization response types;
	// only the authorization code flow is implemented.
	SupportedResponseTypes = []string{"code"}

	// SupportedGrantTypes are the grant types accepted by ValidateGrantType.
	SupportedGrantTypes = []string{"authorization_code"}
)
|
|
|
|
// Sentinel errors returned by the code and token lookup helpers. Compare
// against them with errors.Is.
var (
	// ErrCodeExpired is returned by GetCodeEntry when the stored code is past
	// its expiry; the code and the subject's userinfo are deleted first.
	ErrCodeExpired = errors.New("code_expired")

	// ErrCodeNotFound is returned by GetCodeEntry when the code is unknown.
	ErrCodeNotFound = errors.New("code_not_found")

	// ErrTokenNotFound is returned by GetAccessToken when the token is unknown.
	ErrTokenNotFound = errors.New("token_not_found")

	// ErrTokenExpired is returned by GetAccessToken when the stored token is
	// past its expiry; the token and the subject's userinfo are deleted first.
	ErrTokenExpired = errors.New("token_expired")
)
|
|
|
|
// UserinfoResponse holds the user claims assembled by CompileUserinfo. Sub and
// UpdatedAt (Unix seconds) are always set; the remaining fields are filled in
// only when the corresponding scope was granted.
type UserinfoResponse struct {
	Sub               string   `json:"sub"`
	Name              string   `json:"name"`
	Email             string   `json:"email"`
	PreferredUsername string   `json:"preferred_username"`
	Groups            []string `json:"groups"`
	UpdatedAt         int64    `json:"updated_at"`
}
|
|
|
|
// TokenResponse is the JSON payload built by GenerateAccessToken. ExpiresIn is
// expressed in seconds and Scope is a space-separated list.
type TokenResponse struct {
	AccessToken string `json:"access_token"`
	TokenType   string `json:"token_type"`
	ExpiresIn   int64  `json:"expires_in"`
	IDToken     string `json:"id_token"`
	Scope       string `json:"scope"`
}
|
|
|
|
// AuthorizeRequest holds the parameters of an authorization request. It is
// checked by ValidateAuthorizeParams and persisted by StoreCode. Scope is a
// list of scope values separated by single spaces.
//
// NOTE(review): only `json` tags are declared. If the caller binds this from
// query parameters with gin, gin's form binding looks at `form` tags instead —
// confirm how the struct is populated.
type AuthorizeRequest struct {
	Scope        string `json:"scope" binding:"required"`
	ResponseType string `json:"response_type" binding:"required"`
	ClientID     string `json:"client_id" binding:"required"`
	RedirectURI  string `json:"redirect_uri" binding:"required"`
	State        string `json:"state" binding:"required"`
}
|
|
|
|
// OIDCServiceConfig is the static configuration handed to NewOIDCService.
type OIDCServiceConfig struct {
	// Clients maps a configuration name (copied into each client's ID by
	// Init) to the client settings; Init re-keys the clients by ClientID.
	Clients map[string]config.OIDCClientConfig

	// PrivateKeyPath and PublicKeyPath locate PEM-encoded PKCS#1 RSA keys.
	// Init generates and writes them when the files do not exist.
	PrivateKeyPath string
	PublicKeyPath  string

	// Issuer must be an https URL; Init derives the token issuer from its
	// scheme and host.
	Issuer string
}
|
|
|
|
// OIDCService implements the OpenID Connect provider logic: signing-key
// management, authorization request validation, and persistence of codes,
// tokens and userinfo through the repository queries. Init must be called
// before use; it populates clients, the keys and issuer.
type OIDCService struct {
	config  OIDCServiceConfig
	queries *repository.Queries

	// clients holds the configured clients keyed by ClientID (set by Init).
	clients map[string]config.OIDCClientConfig

	// privateKey signs ID tokens; publicKey is the key loaded from, or
	// written to, PublicKeyPath by Init.
	privateKey *rsa.PrivateKey
	publicKey  crypto.PublicKey

	// issuer is the scheme://host form of config.Issuer, used as the "iss"
	// claim of ID tokens.
	issuer string
}
|
|
|
|
func NewOIDCService(config OIDCServiceConfig, queries *repository.Queries) *OIDCService {
|
|
return &OIDCService{
|
|
config: config,
|
|
queries: queries,
|
|
}
|
|
}
|
|
|
|
// TODO: A cleanup routine is needed to clean up expired tokens/code/userinfo
|
|
|
|
func (service *OIDCService) Init() error {
|
|
// Ensure issuer is https
|
|
uissuer, err := url.Parse(service.config.Issuer)
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if uissuer.Scheme != "https" {
|
|
return errors.New("issuer must be https")
|
|
}
|
|
|
|
service.issuer = fmt.Sprintf("%s://%s", uissuer.Scheme, uissuer.Host)
|
|
|
|
// Create/load private and public keys
|
|
if strings.TrimSpace(service.config.PrivateKeyPath) == "" ||
|
|
strings.TrimSpace(service.config.PublicKeyPath) == "" {
|
|
return errors.New("private key path and public key path are required")
|
|
}
|
|
|
|
var privateKey *rsa.PrivateKey
|
|
|
|
fprivateKey, err := os.ReadFile(service.config.PrivateKeyPath)
|
|
|
|
if err != nil && !errors.Is(err, os.ErrNotExist) {
|
|
return err
|
|
}
|
|
|
|
if errors.Is(err, os.ErrNotExist) {
|
|
privateKey, err = rsa.GenerateKey(rand.Reader, 2048)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
der := x509.MarshalPKCS1PrivateKey(privateKey)
|
|
encoded := pem.EncodeToMemory(&pem.Block{
|
|
Type: "RSA PRIVATE KEY",
|
|
Bytes: der,
|
|
})
|
|
err = os.WriteFile(service.config.PrivateKeyPath, encoded, 0600)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
service.privateKey = privateKey
|
|
} else {
|
|
block, _ := pem.Decode(fprivateKey)
|
|
privateKey, err = x509.ParsePKCS1PrivateKey(block.Bytes)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
service.privateKey = privateKey
|
|
}
|
|
|
|
fpublicKey, err := os.ReadFile(service.config.PublicKeyPath)
|
|
|
|
if err != nil && !errors.Is(err, os.ErrNotExist) {
|
|
return err
|
|
}
|
|
|
|
if errors.Is(err, os.ErrNotExist) {
|
|
publicKey := service.privateKey.Public()
|
|
der := x509.MarshalPKCS1PublicKey(publicKey.(*rsa.PublicKey))
|
|
encoded := pem.EncodeToMemory(&pem.Block{
|
|
Type: "RSA PUBLIC KEY",
|
|
Bytes: der,
|
|
})
|
|
err = os.WriteFile(service.config.PublicKeyPath, encoded, 0644)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
service.publicKey = publicKey
|
|
} else {
|
|
block, _ := pem.Decode(fpublicKey)
|
|
publicKey, err := x509.ParsePKCS1PublicKey(block.Bytes)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
service.publicKey = publicKey
|
|
}
|
|
|
|
// We will reorganize the client into a map with the client ID as the key
|
|
service.clients = make(map[string]config.OIDCClientConfig)
|
|
|
|
for id, client := range service.config.Clients {
|
|
client.ID = id
|
|
service.clients[client.ClientID] = client
|
|
}
|
|
|
|
// Load the client secrets from files if they exist
|
|
for id, client := range service.clients {
|
|
secret := utils.GetSecret(client.ClientSecret, client.ClientSecretFile)
|
|
if secret != "" {
|
|
client.ClientSecret = secret
|
|
}
|
|
client.ClientSecretFile = ""
|
|
service.clients[id] = client
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (service *OIDCService) GetIssuer() string {
|
|
return service.config.Issuer
|
|
}
|
|
|
|
func (service *OIDCService) GetClient(id string) (config.OIDCClientConfig, bool) {
|
|
client, ok := service.clients[id]
|
|
return client, ok
|
|
}
|
|
|
|
func (service *OIDCService) ValidateAuthorizeParams(req AuthorizeRequest) error {
|
|
// Validate client ID
|
|
client, ok := service.GetClient(req.ClientID)
|
|
if !ok {
|
|
return errors.New("access_denied")
|
|
}
|
|
|
|
// Scopes
|
|
scopes := strings.Split(req.Scope, " ")
|
|
|
|
if len(scopes) == 0 || strings.TrimSpace(req.Scope) == "" {
|
|
return errors.New("invalid_scope")
|
|
}
|
|
|
|
for _, scope := range scopes {
|
|
if strings.TrimSpace(scope) == "" {
|
|
return errors.New("invalid_scope")
|
|
}
|
|
if !slices.Contains(SupportedScopes, scope) {
|
|
tlog.App.Warn().Str("scope", scope).Msg("Unsupported OIDC scope, will be ignored")
|
|
}
|
|
}
|
|
|
|
// Response type
|
|
if !slices.Contains(SupportedResponseTypes, req.ResponseType) {
|
|
return errors.New("unsupported_response_type")
|
|
}
|
|
|
|
// Redirect URI
|
|
if !slices.Contains(client.TrustedRedirectURIs, req.RedirectURI) {
|
|
return errors.New("invalid_request_uri")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (service *OIDCService) filterScopes(scopes []string) []string {
|
|
return utils.Filter(scopes, func(scope string) bool {
|
|
return slices.Contains(SupportedScopes, scope)
|
|
})
|
|
}
|
|
|
|
func (service *OIDCService) StoreCode(c *gin.Context, sub string, code string, req AuthorizeRequest) error {
|
|
// Fixed 10 minutes
|
|
expiresAt := time.Now().Add(time.Minute * time.Duration(10)).Unix()
|
|
|
|
// Insert the code into the database
|
|
_, err := service.queries.CreateOidcCode(c, repository.CreateOidcCodeParams{
|
|
Sub: sub,
|
|
Code: code,
|
|
// Here it's safe to split and trust the output since, we validated the scopes before
|
|
Scope: strings.Join(service.filterScopes(strings.Split(req.Scope, " ")), ","),
|
|
RedirectURI: req.RedirectURI,
|
|
ClientID: req.ClientID,
|
|
ExpiresAt: expiresAt,
|
|
})
|
|
|
|
return err
|
|
}
|
|
|
|
func (service *OIDCService) StoreUserinfo(c *gin.Context, sub string, userContext config.UserContext, req AuthorizeRequest) error {
|
|
userInfoParams := repository.CreateOidcUserInfoParams{
|
|
Sub: sub,
|
|
Name: userContext.Name,
|
|
Email: userContext.Email,
|
|
PreferredUsername: userContext.Username,
|
|
UpdatedAt: time.Now().Unix(),
|
|
}
|
|
|
|
// Tinyauth will pass through the groups it got from an LDAP or an OIDC server
|
|
if userContext.Provider == "ldap" {
|
|
userInfoParams.Groups = userContext.LdapGroups
|
|
}
|
|
|
|
if userContext.OAuth && len(userContext.OAuthGroups) > 0 {
|
|
userInfoParams.Groups = userContext.OAuthGroups
|
|
}
|
|
|
|
_, err := service.queries.CreateOidcUserInfo(c, userInfoParams)
|
|
|
|
return err
|
|
}
|
|
|
|
func (service *OIDCService) ValidateGrantType(grantType string) error {
|
|
if !slices.Contains(SupportedGrantTypes, grantType) {
|
|
return errors.New("unsupported_response_type")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (service *OIDCService) GetCodeEntry(c *gin.Context, code string) (repository.OidcCode, error) {
|
|
oidcCode, err := service.queries.GetOidcCode(c, code)
|
|
|
|
if err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return repository.OidcCode{}, ErrCodeNotFound
|
|
}
|
|
return repository.OidcCode{}, err
|
|
}
|
|
|
|
if time.Now().Unix() > oidcCode.ExpiresAt {
|
|
err = service.queries.DeleteOidcCode(c, code)
|
|
if err != nil {
|
|
return repository.OidcCode{}, err
|
|
}
|
|
err = service.DeleteUserinfo(c, oidcCode.Sub)
|
|
if err != nil {
|
|
return repository.OidcCode{}, err
|
|
}
|
|
return repository.OidcCode{}, ErrCodeExpired
|
|
}
|
|
|
|
return oidcCode, nil
|
|
}
|
|
|
|
func (service *OIDCService) generateIDToken(client config.OIDCClientConfig, sub string) (string, error) {
|
|
createdAt := time.Now().Unix()
|
|
|
|
// TODO: This should probably be user-configured if refresh logic does not exist
|
|
expiresAt := time.Now().Add(time.Duration(1) * time.Hour).Unix()
|
|
|
|
claims := jws.ClaimSet{
|
|
Iss: service.issuer,
|
|
Aud: client.ClientID,
|
|
Sub: sub,
|
|
Iat: createdAt,
|
|
Exp: expiresAt,
|
|
}
|
|
|
|
header := jws.Header{
|
|
Algorithm: "RS256",
|
|
Typ: "JWT",
|
|
}
|
|
|
|
token, err := jws.Encode(&header, &claims, service.privateKey)
|
|
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
return token, nil
|
|
}
|
|
|
|
func (service *OIDCService) GenerateAccessToken(c *gin.Context, client config.OIDCClientConfig, sub string, scope string) (TokenResponse, error) {
|
|
idToken, err := service.generateIDToken(client, sub)
|
|
|
|
if err != nil {
|
|
return TokenResponse{}, err
|
|
}
|
|
|
|
accessToken := rand.Text()
|
|
expiresAt := time.Now().Add(time.Duration(1) * time.Hour).Unix()
|
|
|
|
tokenResponse := TokenResponse{
|
|
AccessToken: accessToken,
|
|
TokenType: "Bearer",
|
|
ExpiresIn: int64(time.Hour.Seconds()),
|
|
IDToken: idToken,
|
|
Scope: strings.ReplaceAll(scope, ",", " "),
|
|
}
|
|
|
|
_, err = service.queries.CreateOidcToken(c, repository.CreateOidcTokenParams{
|
|
Sub: sub,
|
|
AccessToken: accessToken,
|
|
Scope: scope,
|
|
ExpiresAt: expiresAt,
|
|
})
|
|
|
|
if err != nil {
|
|
return TokenResponse{}, err
|
|
}
|
|
|
|
return tokenResponse, nil
|
|
}
|
|
|
|
func (service *OIDCService) DeleteCodeEntry(c *gin.Context, code string) error {
|
|
return service.queries.DeleteOidcCode(c, code)
|
|
}
|
|
|
|
func (service *OIDCService) DeleteUserinfo(c *gin.Context, sub string) error {
|
|
return service.queries.DeleteOidcUserInfo(c, sub)
|
|
}
|
|
|
|
func (service *OIDCService) DeleteToken(c *gin.Context, token string) error {
|
|
return service.queries.DeleteOidcToken(c, token)
|
|
}
|
|
|
|
func (service *OIDCService) GetAccessToken(c *gin.Context, token string) (repository.OidcToken, error) {
|
|
entry, err := service.queries.GetOidcToken(c, token)
|
|
|
|
if err != nil {
|
|
if err == sql.ErrNoRows {
|
|
return repository.OidcToken{}, ErrTokenNotFound
|
|
}
|
|
return repository.OidcToken{}, err
|
|
}
|
|
|
|
if entry.ExpiresAt < time.Now().Unix() {
|
|
err := service.DeleteToken(c, token)
|
|
if err != nil {
|
|
return repository.OidcToken{}, err
|
|
}
|
|
err = service.DeleteUserinfo(c, entry.Sub)
|
|
if err != nil {
|
|
return repository.OidcToken{}, err
|
|
}
|
|
return repository.OidcToken{}, ErrTokenExpired
|
|
}
|
|
|
|
return entry, nil
|
|
}
|
|
|
|
func (service *OIDCService) GetUserinfo(c *gin.Context, sub string) (repository.OidcUserinfo, error) {
|
|
return service.queries.GetOidcUserInfo(c, sub)
|
|
}
|
|
|
|
func (service *OIDCService) CompileUserinfo(user repository.OidcUserinfo, scope string) UserinfoResponse {
|
|
scopes := strings.Split(scope, ",") // split by comma since it's a db entry
|
|
userInfo := UserinfoResponse{
|
|
Sub: user.Sub,
|
|
UpdatedAt: user.UpdatedAt,
|
|
}
|
|
|
|
if slices.Contains(scopes, "profile") {
|
|
userInfo.Name = user.Name
|
|
userInfo.PreferredUsername = user.PreferredUsername
|
|
}
|
|
|
|
if slices.Contains(scopes, "email") {
|
|
userInfo.Email = user.Email
|
|
}
|
|
|
|
if slices.Contains(scopes, "groups") {
|
|
userInfo.Groups = strings.Split(user.Groups, ",")
|
|
}
|
|
|
|
return userInfo
|
|
}
|