diff --git a/internal/middleware/context_middleware.go b/internal/middleware/context_middleware.go index 651d9d8..5dd98f9 100644 --- a/internal/middleware/context_middleware.go +++ b/internal/middleware/context_middleware.go @@ -1,10 +1,13 @@ package middleware import ( + "context" + "fmt" + "net/http" "strings" "time" - "github.com/tinyauthapp/tinyauth/internal/config" + "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils/tlog" @@ -33,7 +36,8 @@ var ( ) type ContextMiddlewareConfig struct { - CookieDomain string + CookieDomain string + SessionCookieName string } type ContextMiddleware struct { @@ -61,194 +65,42 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc { return } - cookie, err := m.auth.GetSessionCookie(c) + uuid, err := c.Cookie(m.config.SessionCookieName) - if err != nil { - tlog.App.Debug().Err(err).Msg("No valid session cookie found") - goto basic - } - - if cookie.TotpPending { - c.Set("context", &config.UserContext{ - Username: cookie.Username, - Name: cookie.Name, - Email: cookie.Email, - Provider: "local", - TotpPending: true, - TotpEnabled: true, - }) - c.Next() - return - } - - switch cookie.Provider { - case "local", "ldap": - userSearch := m.auth.SearchUser(cookie.Username) - - if userSearch.Type == "unknown" { - tlog.App.Debug().Msg("User from session cookie not found") - m.auth.DeleteSessionCookie(c) - goto basic - } - - if userSearch.Type != cookie.Provider { - tlog.App.Warn().Msg("User type from session cookie does not match user search type") - m.auth.DeleteSessionCookie(c) - c.Next() - return - } - - var ldapGroups []string - var localAttributes config.UserAttributes - - if cookie.Provider == "ldap" { - ldapUser, err := m.auth.GetLdapUser(userSearch.Username) - - if err != nil { - tlog.App.Error().Err(err).Msg("Error retrieving LDAP user details") - c.Next() - return - } - - ldapGroups = ldapUser.Groups - } - - if cookie.Provider == "local" { - localUser := m.auth.GetLocalUser(cookie.Username) - localAttributes = localUser.Attributes - } - - m.auth.RefreshSessionCookie(c) - c.Set("context", &config.UserContext{ - Username: cookie.Username, - Name: cookie.Name, - Email: cookie.Email, - Provider: cookie.Provider, - IsLoggedIn: true, - LdapGroups: strings.Join(ldapGroups, ","), - Attributes: localAttributes, - }) - c.Next() - return - default: - _, exists := m.broker.GetService(cookie.Provider) - - if !exists { - tlog.App.Debug().Msg("OAuth provider from session cookie not found") - m.auth.DeleteSessionCookie(c) - goto basic - } - - if !m.auth.IsEmailWhitelisted(cookie.Email) { - tlog.App.Debug().Msg("Email from session cookie not whitelisted") - m.auth.DeleteSessionCookie(c) - goto basic - } - - m.auth.RefreshSessionCookie(c) - c.Set("context", &config.UserContext{ - Username: cookie.Username, - Name: cookie.Name, - Email: cookie.Email, - Provider: cookie.Provider, - OAuthGroups: cookie.OAuthGroups, - OAuthName: cookie.OAuthName, - OAuthSub: cookie.OAuthSub, - IsLoggedIn: true, - OAuth: true, - }) - c.Next() - return - } - - basic: - basic := m.auth.GetBasicAuth(c) - - if basic == nil { - tlog.App.Debug().Msg("No basic auth provided") - c.Next() - return - } - - locked, remaining := m.auth.IsAccountLocked(basic.Username) - - if locked { - tlog.App.Debug().Msgf("Account for user %s is locked for %d seconds, denying auth", basic.Username, remaining) - c.Writer.Header().Add("x-tinyauth-lock-locked", "true") - c.Writer.Header().Add("x-tinyauth-lock-reset", time.Now().Add(time.Duration(remaining)*time.Second).Format(time.RFC3339)) - c.Next() - return - } - - userSearch := m.auth.SearchUser(basic.Username) - - if userSearch.Type == "unknown" || userSearch.Type == "error" { - m.auth.RecordLoginAttempt(basic.Username, false) - tlog.App.Debug().Msg("User from basic auth not found") - c.Next() - return - } - - if !m.auth.VerifyUser(userSearch, basic.Password) { - m.auth.RecordLoginAttempt(basic.Username, false) - tlog.App.Debug().Msg("Invalid password for basic auth user") - c.Next() - return - } - - m.auth.RecordLoginAttempt(basic.Username, true) - - switch userSearch.Type { - case "local": - tlog.App.Debug().Msg("Basic auth user is local") - - user := m.auth.GetLocalUser(basic.Username) - - if user.TotpSecret != "" { - tlog.App.Debug().Msg("User with TOTP not allowed to login via basic auth") - return - } - - name := utils.Capitalize(user.Username) - if user.Attributes.Name != "" { - name = user.Attributes.Name - } - email := utils.CompileUserEmail(user.Username, m.config.CookieDomain) - if user.Attributes.Email != "" { - email = user.Attributes.Email - } - - c.Set("context", &config.UserContext{ - Username: user.Username, - Name: name, - Email: email, - Provider: "local", - IsLoggedIn: true, - IsBasicAuth: true, - Attributes: user.Attributes, - }) - c.Next() - return - case "ldap": - tlog.App.Debug().Msg("Basic auth user is LDAP") - - ldapUser, err := m.auth.GetLdapUser(basic.Username) + if err == nil { + userContext, cookie, err := m.cookieAuth(c.Request.Context(), uuid) if err != nil { - tlog.App.Debug().Err(err).Msg("Error retrieving LDAP user details") + tlog.App.Error().Msgf("Error authenticating session cookie: %v", err) c.Next() return } - c.Set("context", &config.UserContext{ - Username: basic.Username, - Name: utils.Capitalize(basic.Username), - Email: utils.CompileUserEmail(basic.Username, m.config.CookieDomain), - Provider: "ldap", - IsLoggedIn: true, - LdapGroups: strings.Join(ldapUser.Groups, ","), - IsBasicAuth: true, - }) + if cookie != nil { + http.SetCookie(c.Writer, cookie) + } + + c.Set("context", userContext) + c.Next() + return + } + + basic, err := m.auth.GetBasicAuth(c.Request) + + if err == nil { + userContext, headers, err := m.basicAuth(c.Request.Context(), basic) + + if err != nil { + tlog.App.Error().Msgf("Error authenticating basic auth: %v", err) + c.Next() + return + } + + for k, v := range headers { + c.Header(k, v) + } + + c.Set("context", userContext) c.Next() return } @@ -257,6 +109,150 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc { } } +func (m *ContextMiddleware) cookieAuth(ctx context.Context, uuid string) (*model.UserContext, *http.Cookie, error) { + session, err := m.auth.GetSession(ctx, uuid) + + if err != nil { + return nil, nil, fmt.Errorf("error retrieving session: %w", err) + } + + userContext, err := new(model.UserContext).NewFromSession(session) + + if err != nil { + return nil, nil, fmt.Errorf("error creating user context from session: %w", err) + } + + if userContext.Provider == model.ProviderLocal && + userContext.Local.TOTPPending { + userContext.Local.TOTPEnabled = true + return userContext, nil, nil + } + + switch userContext.Provider { + case model.ProviderLocal: + user := m.auth.GetLocalUser(userContext.Local.Username) + + if user == nil { + return nil, nil, fmt.Errorf("local user not found") + } + + userContext.Local.Attributes = user.Attributes + + if userContext.Local.Attributes.Name == "" { + userContext.Local.Attributes.Name = utils.Capitalize(user.Username) + } + + if userContext.Local.Attributes.Email == "" { + userContext.Local.Attributes.Email = utils.CompileUserEmail(user.Username, m.config.CookieDomain) + } + case model.ProviderLDAP: + search, err := m.auth.SearchUser(userContext.LDAP.Username) + + if err != nil { + return nil, nil, fmt.Errorf("error searching for ldap user: %w", err) + } + + if search.Type != model.UserLDAP { + return nil, nil, fmt.Errorf("user from session cookie is not ldap") + } + + user, err := m.auth.GetLDAPUser(search.Username) + + if err != nil { + return nil, nil, fmt.Errorf("error retrieving ldap user details: %w", err) + } + + userContext.LDAP.Groups = user.Groups + userContext.LDAP.Name = utils.Capitalize(userContext.LDAP.Username) + userContext.LDAP.Email = utils.CompileUserEmail(userContext.LDAP.Username, m.config.CookieDomain) + case model.ProviderOAuth: + _, exists := m.broker.GetService(userContext.OAuth.ID) + + if !exists { + return nil, nil, fmt.Errorf("oauth provider from session cookie not found: %s", userContext.OAuth.ID) + } + + if !m.auth.IsEmailWhitelisted(userContext.OAuth.Email) { + m.auth.DeleteSession(ctx, uuid) + return nil, nil, fmt.Errorf("email from session cookie not whitelisted: %s", userContext.OAuth.Email) + } + } + + cookie, err := m.auth.RefreshSession(ctx, uuid) + + if err != nil { + return nil, nil, fmt.Errorf("error refreshing session: %w", err) + } + + return userContext, cookie, nil +} + +func (m *ContextMiddleware) basicAuth(ctx context.Context, basic *model.LocalUser) (*model.UserContext, map[string]string, error) { + headers := make(map[string]string) + userContext := new(model.UserContext) + locked, remaining := m.auth.IsAccountLocked(basic.Username) + + if locked { + tlog.App.Debug().Msgf("Account for user %s is locked for %d seconds, denying auth", basic.Username, remaining) + headers["x-tinyauth-lock-locked"] = "true" + headers["x-tinyauth-lock-reset"] = time.Now().Add(time.Duration(remaining) * time.Second).Format(time.RFC3339) + return nil, headers, nil + } + + search, err := m.auth.SearchUser(basic.Username) + + if err != nil { + return nil, nil, fmt.Errorf("error searching for user: %w", err) + } + + err = m.auth.CheckUserPassword(*search, basic.Password) + + if err != nil { + m.auth.RecordLoginAttempt(basic.Username, false) + return nil, nil, fmt.Errorf("invalid password for basic auth user: %w", err) + } + + m.auth.RecordLoginAttempt(basic.Username, true) + + switch search.Type { + case model.UserLocal: + user := m.auth.GetLocalUser(basic.Username) + + if user.TOTPSecret != "" { + return nil, nil, fmt.Errorf("user with totp not allowed to login via basic auth: %s", basic.Username) + } + + userContext.Local = &model.LocalContext{ + BaseContext: model.BaseContext{ + Username: user.Username, + Name: utils.Capitalize(user.Username), + Email: utils.CompileUserEmail(user.Username, m.config.CookieDomain), + }, + Attributes: user.Attributes, + } + userContext.Provider = model.ProviderLocal + case model.UserLDAP: + user, err := m.auth.GetLDAPUser(basic.Username) + + if err != nil { + return nil, nil, fmt.Errorf("error retrieving ldap user details: %w", err) + } + + userContext.LDAP = &model.LDAPContext{ + BaseContext: model.BaseContext{ + Username: basic.Username, + Name: utils.Capitalize(basic.Username), + Email: utils.CompileUserEmail(basic.Username, m.config.CookieDomain), + }, + Groups: user.Groups, + } + userContext.Provider = model.ProviderLDAP + } + + userContext.Authenticated = true + return userContext, nil, nil +} + func (m *ContextMiddleware) isIgnorePath(path string) bool { for _, prefix := range contextSkipPathsPrefix { if strings.HasPrefix(path, prefix) { diff --git a/internal/config/config.go b/internal/model/config.go similarity index 89% rename from internal/config/config.go rename to internal/model/config.go index e364b45..28b0881 100644 --- a/internal/config/config.go +++ b/internal/model/config.go @@ -1,4 +1,4 @@ -package config +package model // Default configuration func NewDefaultConfiguration() *Config { @@ -29,7 +29,7 @@ func NewDefaultConfiguration() *Config { BackgroundImage: "/background.jpg", WarningsEnabled: true, }, - Ldap: LdapConfig{ + LDAP: LDAPConfig{ Insecure: false, SearchFilter: "(uid=%s)", GroupCacheTTL: 900, // 15 minutes @@ -63,20 +63,6 @@ func NewDefaultConfiguration() *Config { } } -// Version information, set at build time - -var Version = "development" -var CommitHash = "development" -var BuildTimestamp = "0000-00-00T00:00:00Z" - -// Cookie name templates - -var SessionCookieName = "tinyauth-session" -var CSRFCookieName = "tinyauth-csrf" -var RedirectCookieName = "tinyauth-redirect" -var OAuthSessionCookieName = "tinyauth-oauth" - -// Main app config type Config struct { AppURL string `description:"The base URL where the app is hosted." yaml:"appUrl"` Database DatabaseConfig `description:"Database configuration." yaml:"database"` @@ -88,7 +74,7 @@ type Config struct { OAuth OAuthConfig `description:"OAuth configuration." yaml:"oauth"` OIDC OIDCConfig `description:"OIDC configuration." yaml:"oidc"` UI UIConfig `description:"UI customization." yaml:"ui"` - Ldap LdapConfig `description:"LDAP configuration." yaml:"ldap"` + LDAP LDAPConfig `description:"LDAP configuration." yaml:"ldap"` Experimental ExperimentalConfig `description:"Experimental features, use with caution." yaml:"experimental"` LabelProvider string `description:"Label provider to use for ACLs (auto, docker, or kubernetes). auto detects the environment." yaml:"labelProvider"` Log LogConfig `description:"Logging configuration." yaml:"log"` @@ -177,7 +163,7 @@ type UIConfig struct { WarningsEnabled bool `description:"Enable UI warnings." yaml:"warningsEnabled"` } -type LdapConfig struct { +type LDAPConfig struct { Address string `description:"LDAP server address." yaml:"address"` BindDN string `description:"Bind DN for LDAP authentication." yaml:"bindDn"` BindPassword string `description:"Bind password for LDAP authentication." yaml:"bindPassword"` @@ -210,20 +196,6 @@ type ExperimentalConfig struct { ConfigFile string `description:"Path to config file." yaml:"-"` } -// Config loader options - -const DefaultNamePrefix = "TINYAUTH_" - -// OAuth/OIDC config - -type Claims struct { - Sub string `json:"sub"` - Name string `json:"name"` - Email string `json:"email"` - PreferredUsername string `json:"preferred_username"` - Groups any `json:"groups"` -} - type OAuthServiceConfig struct { ClientID string `description:"OAuth client ID." yaml:"clientId"` ClientSecret string `description:"OAuth client secret." yaml:"clientSecret"` @@ -246,47 +218,6 @@ type OIDCClientConfig struct { Name string `description:"Client name in UI." yaml:"name"` } -var OverrideProviders = map[string]string{ - "google": "Google", - "github": "GitHub", -} - -// User/session related stuff - -type User struct { - Username string - Password string - TotpSecret string - Attributes UserAttributes -} - -type LdapUser struct { - DN string - Groups []string -} - -type UserSearch struct { - Username string - Type string // local, ldap or unknown -} - -type UserContext struct { - Username string - Name string - Email string - IsLoggedIn bool - IsBasicAuth bool - OAuth bool - Provider string - TotpPending bool - OAuthGroups string - TotpEnabled bool - OAuthName string - OAuthSub string - LdapGroups string - Attributes UserAttributes -} - // API responses and queries type UnauthorizedQuery struct { @@ -355,7 +286,3 @@ type AppPath struct { Allow string `description:"Comma-separated list of allowed paths." yaml:"allow"` Block string `description:"Comma-separated list of blocked paths." yaml:"block"` } - -// API server - -var ApiServer = "https://api.tinyauth.app" diff --git a/internal/model/constants.go b/internal/model/constants.go new file mode 100644 index 0000000..d9e85e5 --- /dev/null +++ b/internal/model/constants.go @@ -0,0 +1,23 @@ +package model + +const DefaultNamePrefix = "TINYAUTH_" + +const APIServer = "https://api.tinyauth.app" + +type Claims struct { + Sub string `json:"sub"` + Name string `json:"name"` + Email string `json:"email"` + PreferredUsername string `json:"preferred_username"` + Groups any `json:"groups"` +} + +var OverrideProviders = map[string]string{ + "google": "Google", + "github": "GitHub", +} + +const SessionCookieName = "tinyauth-session" +const CSRFCookieName = "tinyauth-csrf" +const RedirectCookieName = "tinyauth-redirect" +const OAuthSessionCookieName = "tinyauth-oauth" diff --git a/internal/model/context.go b/internal/model/context.go new file mode 100644 index 0000000..ad75b4f --- /dev/null +++ b/internal/model/context.go @@ -0,0 +1,179 @@ +package model + +import ( + "errors" + "strings" + + "github.com/gin-gonic/gin" + "github.com/tinyauthapp/tinyauth/internal/repository" +) + +type ProviderType int + +const ( + ProviderLocal ProviderType = iota + ProviderBasicAuth + ProviderOAuth + ProviderLDAP +) + +type UserContext struct { + Authenticated bool + Provider ProviderType + Local *LocalContext + OAuth *OAuthContext + LDAP *LDAPContext +} + +type BaseContext struct { + Username string + Name string + Email string +} + +type LocalContext struct { + BaseContext + TOTPPending bool + TOTPEnabled bool + Attributes UserAttributes +} + +type OAuthContext struct { + BaseContext + Groups []string + Sub string + DisplayName string + ID string +} + +type LDAPContext struct { + BaseContext + Groups []string +} + +func (c *UserContext) IsAuthenticated() bool { + return c.Authenticated +} + +func (c *UserContext) IsLocal() bool { + return c.Provider == ProviderLocal +} + +func (c *UserContext) IsOAuth() bool { + return c.Provider == ProviderOAuth +} + +func (c *UserContext) IsLDAP() bool { + return c.Provider == ProviderLDAP +} + +func (c *UserContext) IsBasicAuth() bool { + return c.Provider == ProviderBasicAuth +} + +func (c *UserContext) NewFromGin(ginctx *gin.Context) (*UserContext, error) { + userContextValue, exists := ginctx.Get("context") + + if !exists { + return nil, errors.New("failed to get user context") + } + + userContext, ok := userContextValue.(*UserContext) + + if !ok { + return nil, errors.New("invalid user context type") + } + + *c = *userContext + return c, nil +} + +// Compatability layer until we get an excuse to drop in database migrations +func (c *UserContext) NewFromSession(session *repository.Session) (*UserContext, error) { + switch session.Provider { + case "local": + c.Provider = ProviderLocal + c.Local = &LocalContext{ + BaseContext: BaseContext{ + Username: session.Username, + Name: session.Name, + Email: session.Email, + }, + TOTPPending: session.TotpPending, + } + case "ldap": + c.Provider = ProviderLDAP + c.LDAP = &LDAPContext{ + BaseContext: BaseContext{ + Username: session.Username, + Name: session.Name, + Email: session.Email, + }, + } + // By default we assume an unkown name which is oauth + default: + c.Provider = ProviderOAuth + c.OAuth = &OAuthContext{ + BaseContext: BaseContext{ + Username: session.Username, + Name: session.Name, + Email: session.Email, + }, + Groups: strings.Split(session.OAuthGroups, ","), + Sub: session.OAuthSub, + DisplayName: session.OAuthName, + ID: session.Provider, + } + } + + if !session.TotpPending { + c.Authenticated = true + } + + return c, nil +} + +func (c *UserContext) GetUsername() string { + switch c.Provider { + case ProviderLocal: + return c.Local.Username + case ProviderLDAP: + return c.LDAP.Username + case ProviderBasicAuth: + return c.Local.Username + case ProviderOAuth: + return c.OAuth.Username + default: + return "" + } +} + +func (c *UserContext) GetEmail() string { + switch c.Provider { + case ProviderLocal: + return c.Local.Email + case ProviderLDAP: + return c.LDAP.Email + case ProviderBasicAuth: + return c.Local.Email + case ProviderOAuth: + return c.OAuth.Email + default: + return "" + } +} + +func (c *UserContext) GetName() string { + switch c.Provider { + case ProviderLocal: + return c.Local.Name + case ProviderLDAP: + return c.LDAP.Name + case ProviderBasicAuth: + return c.Local.Name + case ProviderOAuth: + return c.OAuth.Name + default: + return "" + } +} diff --git a/internal/model/users.go b/internal/model/users.go new file mode 100644 index 0000000..48826fd --- /dev/null +++ b/internal/model/users.go @@ -0,0 +1,25 @@ +package model + +type UserSearchType int + +const ( + UserLocal UserSearchType = iota + UserLDAP +) + +type LDAPUser struct { + DN string + Groups []string +} + +type LocalUser struct { + Username string + Password string + TOTPSecret string + Attributes UserAttributes +} + +type UserSearch struct { + Username string + Type UserSearchType +} diff --git a/internal/model/version.go b/internal/model/version.go new file mode 100644 index 0000000..cd8bc13 --- /dev/null +++ b/internal/model/version.go @@ -0,0 +1,5 @@ +package model + +var Version = "development" +var CommitHash = "development" +var BuildTimestamp = "0000-00-00T00:00:00Z" diff --git a/internal/service/access_controls_service.go b/internal/service/access_controls_service.go index d054b5f..065117e 100644 --- a/internal/service/access_controls_service.go +++ b/internal/service/access_controls_service.go @@ -4,20 +4,20 @@ import ( "errors" "strings" - "github.com/tinyauthapp/tinyauth/internal/config" + "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/utils/tlog" ) type LabelProvider interface { - GetLabels(appDomain string) (config.App, error) + GetLabels(appDomain string) (*model.App, error) } type AccessControlsService struct { labelProvider LabelProvider - static map[string]config.App + static map[string]model.App } -func NewAccessControlsService(labelProvider LabelProvider, static map[string]config.App) *AccessControlsService { +func NewAccessControlsService(labelProvider LabelProvider, static map[string]model.App) *AccessControlsService { return &AccessControlsService{ labelProvider: labelProvider, static: static, @@ -28,22 +28,22 @@ func (acls *AccessControlsService) Init() error { return nil // No initialization needed } -func (acls *AccessControlsService) lookupStaticACLs(domain string) (config.App, error) { +func (acls *AccessControlsService) lookupStaticACLs(domain string) (*model.App, error) { for app, config := range acls.static { if config.Config.Domain == domain { tlog.App.Debug().Str("name", app).Msg("Found matching container by domain") - return config, nil + return &config, nil } if strings.SplitN(domain, ".", 2)[0] == app { tlog.App.Debug().Str("name", app).Msg("Found matching container by app name") - return config, nil + return &config, nil } } - return config.App{}, errors.New("no results") + return nil, errors.New("no results") } -func (acls *AccessControlsService) GetAccessControls(domain string) (config.App, error) { +func (acls *AccessControlsService) GetAccessControls(domain string) (*model.App, error) { // First check in the static config app, err := acls.lookupStaticACLs(domain) diff --git a/internal/service/auth_service.go b/internal/service/auth_service.go index 0311229..148340f 100644 --- a/internal/service/auth_service.go +++ b/internal/service/auth_service.go @@ -5,12 +5,13 @@ import ( "database/sql" "errors" "fmt" + "net/http" "regexp" "strings" "sync" "time" - "github.com/tinyauthapp/tinyauth/internal/config" + "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils/tlog" @@ -68,7 +69,7 @@ type Lockdown struct { } type AuthServiceConfig struct { - Users []config.User + LocalUsers []model.LocalUser OauthWhitelist []string SessionExpiry int SessionMaxLifetime int @@ -77,7 +78,7 @@ type AuthServiceConfig struct { LoginTimeout int LoginMaxRetries int SessionCookieName string - IP config.IPConfig + IP model.IPConfig LDAPGroupsCacheTTL int } @@ -106,7 +107,7 @@ func NewAuthService(config AuthServiceConfig, ldap *LdapService, queries *reposi ldap: ldap, queries: queries, oauthBroker: oauthBroker, -} + } } func (auth *AuthService) Init() error { @@ -114,79 +115,67 @@ func (auth *AuthService) Init() error { return nil } -func (auth *AuthService) SearchUser(username string) config.UserSearch { +func (auth *AuthService) SearchUser(username string) (*model.UserSearch, error) { if auth.GetLocalUser(username).Username != "" { - return config.UserSearch{ + return &model.UserSearch{ Username: username, - Type: "local", - } + Type: model.UserLocal, + }, nil } if auth.ldap.IsConfigured() { userDN, err := auth.ldap.GetUserDN(username) if err != nil { - tlog.App.Warn().Err(err).Str("username", username).Msg("Failed to search for user in LDAP") - return config.UserSearch{ - Type: "unknown", - } + return nil, fmt.Errorf("failed to get ldap user: %w", err) } - return config.UserSearch{ + return &model.UserSearch{ Username: userDN, - Type: "ldap", - } + Type: model.UserLDAP, + }, nil } - return config.UserSearch{ - Type: "unknown", - } + return nil, fmt.Errorf("user not found") } -func (auth *AuthService) VerifyUser(search config.UserSearch, password string) bool { +func (auth *AuthService) CheckUserPassword(search model.UserSearch, password string) error { switch search.Type { - case "local": + case model.UserLocal: user := auth.GetLocalUser(search.Username) - return auth.CheckPassword(user, password) - case "ldap": + return bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)) + case model.UserLDAP: if auth.ldap.IsConfigured() { err := auth.ldap.Bind(search.Username, password) if err != nil { - tlog.App.Warn().Err(err).Str("username", search.Username).Msg("Failed to bind to LDAP") - return false + return fmt.Errorf("failed to bind to ldap user: %w", err) } err = auth.ldap.BindService(true) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to rebind with service account after user authentication") - return false + return fmt.Errorf("failed to bind to ldap service account: %w", err) } - return true + return nil } default: - tlog.App.Debug().Str("type", search.Type).Msg("Unknown user type for authentication") - return false + return errors.New("unknown user search type") } - - tlog.App.Warn().Str("username", search.Username).Msg("User authentication failed") - return false + return errors.New("user authentication failed") } -func (auth *AuthService) GetLocalUser(username string) config.User { - for _, user := range auth.config.Users { +func (auth *AuthService) GetLocalUser(username string) *model.LocalUser { + for _, user := range auth.config.LocalUsers { if user.Username == username { - return user + return &user } } - - tlog.App.Warn().Str("username", username).Msg("Local user not found") - return config.User{} + return nil } -func (auth *AuthService) GetLdapUser(userDN string) (config.LdapUser, error) { +func (auth *AuthService) GetLDAPUser(userDN string) (*model.LDAPUser, error) { if !auth.ldap.IsConfigured() { - return config.LdapUser{}, errors.New("LDAP service not initialized") + return nil, errors.New("ldap service not configured") } auth.ldapGroupsMutex.RLock() @@ -194,7 +183,7 @@ func (auth *AuthService) GetLdapUser(userDN string) (config.LdapUser, error) { auth.ldapGroupsMutex.RUnlock() if exists && time.Now().Before(entry.Expires) { - return config.LdapUser{ + return &model.LDAPUser{ DN: userDN, Groups: entry.Groups, }, nil @@ -203,7 +192,7 @@ func (auth *AuthService) GetLdapUser(userDN string) (config.LdapUser, error) { groups, err := auth.ldap.GetUserGroups(userDN) if err != nil { - return config.LdapUser{}, err + return nil, fmt.Errorf("failed to get ldap groups: %w", err) } auth.ldapGroupsMutex.Lock() @@ -213,16 +202,12 @@ func (auth *AuthService) GetLdapUser(userDN string) (config.LdapUser, error) { } auth.ldapGroupsMutex.Unlock() - return config.LdapUser{ + return &model.LDAPUser{ DN: userDN, Groups: groups, }, nil } -func (auth *AuthService) CheckPassword(user config.User, password string) bool { - return bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)) == nil -} - func (auth *AuthService) IsAccountLocked(identifier string) (bool, int) { auth.loginMutex.RLock() defer auth.loginMutex.RUnlock() @@ -291,11 +276,11 @@ func (auth *AuthService) IsEmailWhitelisted(email string) bool { return utils.CheckFilter(strings.Join(auth.config.OauthWhitelist, ","), email) } -func (auth *AuthService) CreateSessionCookie(c *gin.Context, data *repository.Session) error { +func (auth *AuthService) CreateSession(ctx context.Context, data repository.Session) (*http.Cookie, error) { uuid, err := uuid.NewRandom() if err != nil { - return err + return nil, fmt.Errorf("failed to generate session uuid: %w", err) } var expiry int @@ -320,28 +305,30 @@ func (auth *AuthService) CreateSessionCookie(c *gin.Context, data *repository.Se OAuthSub: data.OAuthSub, } - _, err = auth.queries.CreateSession(c, session) + _, err = auth.queries.CreateSession(ctx, session) if err != nil { - return err + return nil, fmt.Errorf("failed to create session entry: %w", err) } - c.SetCookie(auth.config.SessionCookieName, session.UUID, expiry, "/", fmt.Sprintf(".%s", auth.config.CookieDomain), auth.config.SecureCookie, true) - - return nil + return &http.Cookie{ + Name: auth.config.SessionCookieName, + Value: session.UUID, + Path: "/", + Domain: fmt.Sprintf(".%s", auth.config.CookieDomain), + Expires: time.Now().Add(time.Duration(expiry) * time.Second), + MaxAge: expiry, + Secure: auth.config.SecureCookie, + HttpOnly: true, + SameSite: http.SameSiteLaxMode, + }, nil } -func (auth *AuthService) RefreshSessionCookie(c *gin.Context) error { - cookie, err := c.Cookie(auth.config.SessionCookieName) +func (auth *AuthService) RefreshSession(ctx context.Context, uuid string) (*http.Cookie, error) { + session, err := auth.queries.GetSession(ctx, uuid) if err != nil { - return err - } - - session, err := auth.queries.GetSession(c, cookie) - - if err != nil { - return err + return nil, fmt.Errorf("failed to retrieve session: %w", err) } currentTime := time.Now().Unix() @@ -355,12 +342,12 @@ func (auth *AuthService) RefreshSessionCookie(c *gin.Context) error { } if session.Expiry-currentTime > refreshThreshold { - return nil + return nil, fmt.Errorf("session not eligible for refresh yet") } newExpiry := session.Expiry + refreshThreshold - _, err = auth.queries.UpdateSession(c, repository.UpdateSessionParams{ + _, err = auth.queries.UpdateSession(ctx, repository.UpdateSessionParams{ Username: session.Username, Email: session.Email, Name: session.Name, @@ -374,120 +361,117 @@ func (auth *AuthService) RefreshSessionCookie(c *gin.Context) error { }) if err != nil { - return err + return nil, fmt.Errorf("failed to update session expiry: %w", err) } - c.SetCookie(auth.config.SessionCookieName, cookie, int(newExpiry-currentTime), "/", fmt.Sprintf(".%s", auth.config.CookieDomain), auth.config.SecureCookie, true) - tlog.App.Trace().Str("username", session.Username).Msg("Session cookie refreshed") + return &http.Cookie{ + Name: auth.config.SessionCookieName, + Value: session.UUID, + Path: "/", + Domain: fmt.Sprintf(".%s", auth.config.CookieDomain), + Expires: time.Now().Add(time.Duration(newExpiry-currentTime) * time.Second), + MaxAge: auth.config.SessionExpiry, + Secure: auth.config.SecureCookie, + HttpOnly: true, + SameSite: http.SameSiteLaxMode, + }, nil - return nil } -func (auth *AuthService) DeleteSessionCookie(c *gin.Context) error { - cookie, err := c.Cookie(auth.config.SessionCookieName) +func (auth *AuthService) DeleteSession(ctx context.Context, uuid string) (*http.Cookie, error) { + err := auth.queries.DeleteSession(ctx, uuid) if err != nil { - return err + tlog.App.Warn().Err(err).Msg("Failed to delete session from database, proceeding to clear cookie anyway") } - err = auth.queries.DeleteSession(c, cookie) - - if err != nil { - return err - } - - c.SetCookie(auth.config.SessionCookieName, "", -1, "/", fmt.Sprintf(".%s", auth.config.CookieDomain), auth.config.SecureCookie, true) - - return nil + return &http.Cookie{ + Name: auth.config.SessionCookieName, + Value: "", + Path: "/", + Domain: fmt.Sprintf(".%s", auth.config.CookieDomain), + Expires: time.Now(), + MaxAge: -1, + Secure: auth.config.SecureCookie, + HttpOnly: true, + SameSite: http.SameSiteLaxMode, + }, nil } -func (auth *AuthService) GetSessionCookie(c *gin.Context) (repository.Session, error) { - cookie, err := c.Cookie(auth.config.SessionCookieName) - - if err != nil { - return repository.Session{}, err - } - - session, err := auth.queries.GetSession(c, cookie) +func (auth *AuthService) GetSession(ctx context.Context, uuid string) (*repository.Session, error) { + session, err := auth.queries.GetSession(ctx, uuid) if err != nil { if errors.Is(err, sql.ErrNoRows) { - return repository.Session{}, fmt.Errorf("session not found") + return nil, errors.New("session not found") } - return repository.Session{}, err + return nil, err } currentTime := time.Now().Unix() if auth.config.SessionMaxLifetime != 0 && session.CreatedAt != 0 { if currentTime-session.CreatedAt > int64(auth.config.SessionMaxLifetime) { - err = auth.queries.DeleteSession(c, cookie) + err = auth.queries.DeleteSession(ctx, uuid) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to delete session exceeding max lifetime") + return nil, fmt.Errorf("failed to delete expired session: %w", err) } - return repository.Session{}, fmt.Errorf("session expired due to max lifetime exceeded") + return nil, fmt.Errorf("session max lifetime exceeded") } } if currentTime > session.Expiry { - err = auth.queries.DeleteSession(c, cookie) + err = auth.queries.DeleteSession(ctx, uuid) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to delete expired session") + return nil, fmt.Errorf("failed to delete expired session: %w", err) } - return repository.Session{}, fmt.Errorf("session expired") + return nil, fmt.Errorf("session expired") } - return repository.Session{ - UUID: session.UUID, - Username: session.Username, - Email: session.Email, - Name: session.Name, - Provider: session.Provider, - TotpPending: session.TotpPending, - OAuthGroups: session.OAuthGroups, - OAuthName: session.OAuthName, - OAuthSub: session.OAuthSub, - }, nil + return &session, nil } func (auth *AuthService) LocalAuthConfigured() bool { - return len(auth.config.Users) > 0 + return len(auth.config.LocalUsers) > 0 } -func (auth *AuthService) LdapAuthConfigured() bool { +func (auth *AuthService) LDAPAuthConfigured() bool { return auth.ldap.IsConfigured() } -func (auth *AuthService) IsUserAllowed(c *gin.Context, context config.UserContext, acls config.App) bool { - if context.OAuth { +func (auth *AuthService) IsUserAllowed(c *gin.Context, context model.UserContext, acls model.App) bool { + if context.Provider == model.ProviderOAuth { tlog.App.Debug().Msg("Checking OAuth whitelist") - return utils.CheckFilter(acls.OAuth.Whitelist, context.Email) + return utils.CheckFilter(acls.OAuth.Whitelist, context.OAuth.Email) } if acls.Users.Block != "" { tlog.App.Debug().Msg("Checking blocked users") - if utils.CheckFilter(acls.Users.Block, context.Username) { + if utils.CheckFilter(acls.Users.Block, context.GetUsername()) { return false } } tlog.App.Debug().Msg("Checking users") - return utils.CheckFilter(acls.Users.Allow, context.Username) + return utils.CheckFilter(acls.Users.Allow, context.GetUsername()) } -func (auth *AuthService) IsInOAuthGroup(c *gin.Context, context config.UserContext, requiredGroups string) bool { +func (auth *AuthService) IsInOAuthGroup(c *gin.Context, context model.UserContext, requiredGroups string) bool { if requiredGroups == "" { return true } - for id := range config.OverrideProviders { - if context.Provider == id { - tlog.App.Info().Str("provider", id).Msg("OAuth groups not supported for this provider") - return true - } + if !context.IsOAuth() { + tlog.App.Debug().Msg("User is not an OAuth user, skipping OAuth group check") + return false } - for userGroup := range strings.SplitSeq(context.OAuthGroups, ",") { + if _, ok := model.OverrideProviders[context.OAuth.ID]; ok { + tlog.App.Debug().Msg("Provider override for OAuth groups enabled, skipping group check") + return true + } + + for _, userGroup := range context.OAuth.Groups { if utils.CheckFilter(requiredGroups, strings.TrimSpace(userGroup)) { tlog.App.Trace().Str("group", userGroup).Str("required", requiredGroups).Msg("User group matched") return true @@ -498,12 +482,17 @@ func (auth *AuthService) IsInOAuthGroup(c *gin.Context, context config.UserConte return false } -func (auth *AuthService) IsInLdapGroup(c *gin.Context, context config.UserContext, requiredGroups string) bool { +func (auth *AuthService) IsInLDAPGroup(c *gin.Context, context model.UserContext, requiredGroups string) bool { if requiredGroups == "" { return true } - for userGroup := range strings.SplitSeq(context.LdapGroups, ",") { + if !context.IsLDAP() { + tlog.App.Debug().Msg("User is not an LDAP user, skipping LDAP group check") + return false + } + + for _, userGroup := range context.LDAP.Groups { if utils.CheckFilter(requiredGroups, strings.TrimSpace(userGroup)) { tlog.App.Trace().Str("group", userGroup).Str("required", requiredGroups).Msg("User group matched") return true @@ -514,7 +503,7 @@ func (auth *AuthService) IsInLdapGroup(c *gin.Context, context config.UserContex return false } -func (auth *AuthService) IsAuthEnabled(uri string, path config.AppPath) (bool, error) { +func (auth *AuthService) IsAuthEnabled(uri string, path model.AppPath) (bool, error) { // Check for block list if path.Block != "" { regex, err := regexp.Compile(path.Block) @@ -544,19 +533,22 @@ func (auth *AuthService) IsAuthEnabled(uri string, path config.AppPath) (bool, e return true, nil } -func (auth *AuthService) GetBasicAuth(c *gin.Context) *config.User { - username, password, ok := c.Request.BasicAuth() - if !ok { - tlog.App.Debug().Msg("No basic auth provided") - return nil +// local user is used only as a medium to pass the basic auth credentials, user can be ldap too +func (auth *AuthService) GetBasicAuth(req *http.Request) (*model.LocalUser, error) { + if req == nil { + return nil, errors.New("request is nil") } - return &config.User{ + username, password, ok := req.BasicAuth() + if !ok { + return nil, errors.New("no basic auth credentials provided") + } + return &model.LocalUser{ Username: username, Password: password, - } + }, nil } -func (auth *AuthService) CheckIP(acls config.AppIP, ip string) bool { +func (auth *AuthService) CheckIP(acls model.AppIP, ip string) bool { // Merge the global and app IP filter blockedIps := append(auth.config.IP.Block, acls.Block...) allowedIPs := append(auth.config.IP.Allow, acls.Allow...) @@ -594,7 +586,7 @@ func (auth *AuthService) CheckIP(acls config.AppIP, ip string) bool { return true } -func (auth *AuthService) IsBypassedIP(acls config.AppIP, ip string) bool { +func (auth *AuthService) IsBypassedIP(acls model.AppIP, ip string) bool { for _, bypassed := range acls.Bypass { res, err := utils.FilterIP(bypassed, ip) if err != nil { @@ -674,21 +666,21 @@ func (auth *AuthService) GetOAuthToken(sessionId string, code string) (*oauth2.T return token, nil } -func (auth *AuthService) GetOAuthUserinfo(sessionId string) (config.Claims, error) { +func (auth *AuthService) GetOAuthUserinfo(sessionId string) (*model.Claims, error) { session, err := auth.GetOAuthPendingSession(sessionId) if err != nil { - return config.Claims{}, err + return nil, err } if session.Token == nil { - return config.Claims{}, fmt.Errorf("oauth token not found for session: %s", sessionId) + return nil, fmt.Errorf("oauth token not found for session: %s", sessionId) } userinfo, err := (*session.Service).GetUserinfo(session.Token) if err != nil { - return config.Claims{}, fmt.Errorf("failed to get userinfo: %w", err) + return nil, fmt.Errorf("failed to get userinfo: %w", err) } return userinfo, nil diff --git a/internal/service/docker_service.go b/internal/service/docker_service.go index 9717924..f47cd10 100644 --- a/internal/service/docker_service.go +++ b/internal/service/docker_service.go @@ -4,7 +4,7 @@ import ( "context" "strings" - "github.com/tinyauthapp/tinyauth/internal/config" + "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/utils/decoders" "github.com/tinyauthapp/tinyauth/internal/utils/tlog" @@ -66,41 +66,41 @@ func (docker *DockerService) inspectContainer(containerId string) (container.Ins return inspect, nil } -func (docker *DockerService) GetLabels(appDomain string) (config.App, error) { +func (docker *DockerService) GetLabels(appDomain string) (*model.App, error) { if !docker.isConnected { tlog.App.Debug().Msg("Docker not connected, returning empty labels") - return config.App{}, nil + return nil, nil } containers, err := docker.getContainers() if err != nil { - return config.App{}, err + return nil, err } for _, ctr := range containers { inspect, err := docker.inspectContainer(ctr.ID) if err != nil { - return config.App{}, err + return nil, err } - labels, err := decoders.DecodeLabels[config.Apps](inspect.Config.Labels, "apps") + labels, err := decoders.DecodeLabels[model.Apps](inspect.Config.Labels, "apps") if err != nil { - return config.App{}, err + return nil, err } for appName, appLabels := range labels.Apps { if appLabels.Config.Domain == appDomain { tlog.App.Debug().Str("id", inspect.ID).Str("name", inspect.Name).Msg("Found matching container by domain") - return appLabels, nil + return &appLabels, nil } if strings.SplitN(appDomain, ".", 2)[0] == appName { tlog.App.Debug().Str("id", inspect.ID).Str("name", inspect.Name).Msg("Found matching container by app name") - return appLabels, nil + return &appLabels, nil } } } tlog.App.Debug().Msg("No matching container found, returning empty labels") - return config.App{}, nil + return nil, nil } diff --git a/internal/service/kubernetes_service.go b/internal/service/kubernetes_service.go index 6e11eac..a3358ed 100644 --- a/internal/service/kubernetes_service.go +++ b/internal/service/kubernetes_service.go @@ -7,7 +7,7 @@ import ( "sync" "time" - "github.com/tinyauthapp/tinyauth/internal/config" + "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/utils/decoders" "github.com/tinyauthapp/tinyauth/internal/utils/tlog" @@ -32,7 +32,7 @@ type ingressAppKey struct { type ingressApp struct { domain string appName string - app config.App + app model.App } type KubernetesService struct { @@ -89,7 +89,7 @@ func (k *KubernetesService) removeIngress(namespace, name string) { } } -func (k *KubernetesService) getByDomain(domain string) (config.App, bool) { +func (k *KubernetesService) getByDomain(domain string) (*model.App, bool) { k.mu.RLock() defer k.mu.RUnlock() @@ -97,15 +97,15 @@ func (k *KubernetesService) getByDomain(domain string) (config.App, bool) { if apps, ok := k.ingressApps[appKey.ingressKey]; ok { for _, app := range apps { if app.domain == domain && app.appName == appKey.appName { - return app.app, true + return &app.app, true } } } } - return config.App{}, false + return nil, false } -func (k *KubernetesService) getByAppName(appName string) (config.App, bool) { +func (k *KubernetesService) getByAppName(appName string) (*model.App, bool) { k.mu.RLock() defer k.mu.RUnlock() @@ -113,12 +113,12 @@ func (k *KubernetesService) getByAppName(appName string) (config.App, bool) { if apps, ok := k.ingressApps[appKey.ingressKey]; ok { for _, app := range apps { if app.appName == appName { - return app.app, true + return &app.app, true } } } } - return config.App{}, false + return nil, false } func (k *KubernetesService) updateFromItem(item *unstructured.Unstructured) { @@ -129,7 +129,7 @@ func (k *KubernetesService) updateFromItem(item *unstructured.Unstructured) { k.removeIngress(namespace, name) return } - labels, err := decoders.DecodeLabels[config.Apps](annotations, "apps") + labels, err := decoders.DecodeLabels[model.Apps](annotations, "apps") if err != nil { tlog.App.Debug().Err(err).Msg("Failed to decode labels from annotations") k.removeIngress(namespace, name) @@ -280,10 +280,10 @@ func (k *KubernetesService) Init() error { return nil } -func (k *KubernetesService) GetLabels(appDomain string) (config.App, error) { +func (k *KubernetesService) GetLabels(appDomain string) (*model.App, error) { if !k.started { tlog.App.Debug().Msg("Kubernetes not connected, returning empty labels") - return config.App{}, nil + return nil, nil } // First check cache @@ -298,6 +298,5 @@ func (k *KubernetesService) GetLabels(appDomain string) (config.App, error) { } tlog.App.Debug().Str("domain", appDomain).Msg("Cache miss, no matching ingress found") - return config.App{}, nil + return nil, nil } - diff --git a/internal/service/oauth_broker_service.go b/internal/service/oauth_broker_service.go index 610a882..15823c4 100644 --- a/internal/service/oauth_broker_service.go +++ b/internal/service/oauth_broker_service.go @@ -1,7 +1,7 @@ package service import ( - "github.com/tinyauthapp/tinyauth/internal/config" + "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/utils/tlog" "slices" @@ -15,20 +15,20 @@ type OAuthServiceImpl interface { NewRandom() string GetAuthURL(state string, verifier string) string GetToken(code string, verifier string) (*oauth2.Token, error) - GetUserinfo(token *oauth2.Token) (config.Claims, error) + GetUserinfo(token *oauth2.Token) (*model.Claims, error) } type OAuthBrokerService struct { services map[string]OAuthServiceImpl - configs map[string]config.OAuthServiceConfig + configs map[string]model.OAuthServiceConfig } -var presets = map[string]func(config config.OAuthServiceConfig) *OAuthService{ +var presets = map[string]func(config model.OAuthServiceConfig) *OAuthService{ "github": newGitHubOAuthService, "google": newGoogleOAuthService, } -func NewOAuthBrokerService(configs map[string]config.OAuthServiceConfig) *OAuthBrokerService { +func NewOAuthBrokerService(configs map[string]model.OAuthServiceConfig) *OAuthBrokerService { return &OAuthBrokerService{ services: make(map[string]OAuthServiceImpl), configs: configs, diff --git a/internal/service/oauth_extractors.go b/internal/service/oauth_extractors.go index 45d03f7..09515e2 100644 --- a/internal/service/oauth_extractors.go +++ b/internal/service/oauth_extractors.go @@ -8,7 +8,7 @@ import ( "net/http" "strconv" - "github.com/tinyauthapp/tinyauth/internal/config" + "github.com/tinyauthapp/tinyauth/internal/model" ) type GithubEmailResponse []struct { @@ -22,32 +22,32 @@ type GithubUserInfoResponse struct { ID int `json:"id"` } -func defaultExtractor(client *http.Client, url string) (config.Claims, error) { - return simpleReq[config.Claims](client, url, nil) +func defaultExtractor(client *http.Client, url string) (*model.Claims, error) { + return simpleReq[model.Claims](client, url, nil) } -func githubExtractor(client *http.Client, url string) (config.Claims, error) { - var user config.Claims +func githubExtractor(client *http.Client, url string) (*model.Claims, error) { + var user model.Claims userInfo, err := simpleReq[GithubUserInfoResponse](client, "https://api.github.com/user", map[string]string{ "accept": "application/vnd.github+json", }) if err != nil { - return config.Claims{}, err + return nil, err } userEmails, err := simpleReq[GithubEmailResponse](client, "https://api.github.com/user/emails", map[string]string{ "accept": "application/vnd.github+json", }) if err != nil { - return config.Claims{}, err + return nil, err } - if len(userEmails) == 0 { - return user, errors.New("no emails found") + if len(*userEmails) == 0 { + return nil, errors.New("no emails found") } - for _, email := range userEmails { + for _, email := range *userEmails { if email.Primary { user.Email = email.Email break @@ -56,22 +56,22 @@ func githubExtractor(client *http.Client, url string) (config.Claims, error) { // Use first available email if no primary email was found if user.Email == "" { - user.Email = userEmails[0].Email + user.Email = (*userEmails)[0].Email } user.PreferredUsername = userInfo.Login user.Name = userInfo.Name user.Sub = strconv.Itoa(userInfo.ID) - return user, nil + return &user, nil } -func simpleReq[T any](client *http.Client, url string, headers map[string]string) (T, error) { +func simpleReq[T any](client *http.Client, url string, headers map[string]string) (*T, error) { var decodedRes T req, err := http.NewRequest("GET", url, nil) if err != nil { - return decodedRes, err + return nil, err } for key, value := range headers { @@ -80,23 +80,23 @@ func simpleReq[T any](client *http.Client, url string, headers map[string]string res, err := client.Do(req) if err != nil { - return decodedRes, err + return nil, err } defer res.Body.Close() if res.StatusCode < 200 || res.StatusCode >= 300 { - return decodedRes, fmt.Errorf("request failed with status: %s", res.Status) + return nil, fmt.Errorf("request failed with status: %s", res.Status) } body, err := io.ReadAll(res.Body) if err != nil { - return decodedRes, err + return nil, err } err = json.Unmarshal(body, &decodedRes) if err != nil { - return decodedRes, err + return nil, err } - return decodedRes, nil + return &decodedRes, nil } diff --git a/internal/service/oauth_presets.go b/internal/service/oauth_presets.go index df23be5..ef21fa6 100644 --- a/internal/service/oauth_presets.go +++ b/internal/service/oauth_presets.go @@ -1,11 +1,11 @@ package service import ( - "github.com/tinyauthapp/tinyauth/internal/config" + "github.com/tinyauthapp/tinyauth/internal/model" "golang.org/x/oauth2/endpoints" ) -func newGoogleOAuthService(config config.OAuthServiceConfig) *OAuthService { +func newGoogleOAuthService(config model.OAuthServiceConfig) *OAuthService { scopes := []string{"openid", "email", "profile"} config.Scopes = scopes config.AuthURL = endpoints.Google.AuthURL @@ -14,7 +14,7 @@ func newGoogleOAuthService(config config.OAuthServiceConfig) *OAuthService { return NewOAuthService(config, "google") } -func newGitHubOAuthService(config config.OAuthServiceConfig) *OAuthService { +func newGitHubOAuthService(config model.OAuthServiceConfig) *OAuthService { scopes := []string{"read:user", "user:email"} config.Scopes = scopes config.AuthURL = endpoints.GitHub.AuthURL diff --git a/internal/service/oauth_service.go b/internal/service/oauth_service.go index 4ef118e..11b0be9 100644 --- a/internal/service/oauth_service.go +++ b/internal/service/oauth_service.go @@ -6,21 +6,21 @@ import ( "net/http" "time" - "github.com/tinyauthapp/tinyauth/internal/config" + "github.com/tinyauthapp/tinyauth/internal/model" "golang.org/x/oauth2" ) -type UserinfoExtractor func(client *http.Client, url string) (config.Claims, error) +type UserinfoExtractor func(client *http.Client, url string) (*model.Claims, error) type OAuthService struct { - serviceCfg config.OAuthServiceConfig + serviceCfg model.OAuthServiceConfig config *oauth2.Config ctx context.Context userinfoExtractor UserinfoExtractor id string } -func NewOAuthService(config config.OAuthServiceConfig, id string) *OAuthService { +func NewOAuthService(config model.OAuthServiceConfig, id string) *OAuthService { httpClient := &http.Client{ Timeout: 30 * time.Second, Transport: &http.Transport{ @@ -78,7 +78,7 @@ func (s *OAuthService) GetToken(code string, verifier string) (*oauth2.Token, er return s.config.Exchange(s.ctx, code, oauth2.VerifierOption(verifier)) } -func (s *OAuthService) GetUserinfo(token *oauth2.Token) (config.Claims, error) { +func (s *OAuthService) GetUserinfo(token *oauth2.Token) (*model.Claims, error) { client := oauth2.NewClient(s.ctx, oauth2.StaticTokenSource(token)) return s.userinfoExtractor(client, s.serviceCfg.UserinfoURL) } diff --git a/internal/service/oidc_service.go b/internal/service/oidc_service.go index 1ac138a..1e1c198 100644 --- a/internal/service/oidc_service.go +++ b/internal/service/oidc_service.go @@ -22,7 +22,7 @@ import ( "github.com/gin-gonic/gin" "github.com/go-jose/go-jose/v4" - "github.com/tinyauthapp/tinyauth/internal/config" + "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils/tlog" @@ -68,27 +68,27 @@ type ClaimSet struct { } type UserinfoResponse struct { - Sub string `json:"sub"` - Name string `json:"name,omitempty"` - GivenName string `json:"given_name,omitempty"` - FamilyName string `json:"family_name,omitempty"` - MiddleName string `json:"middle_name,omitempty"` - Nickname string `json:"nickname,omitempty"` - Profile string `json:"profile,omitempty"` - Picture string `json:"picture,omitempty"` - Website string `json:"website,omitempty"` - Gender string `json:"gender,omitempty"` - Birthdate string `json:"birthdate,omitempty"` - Zoneinfo string `json:"zoneinfo,omitempty"` - Locale string `json:"locale,omitempty"` - Email string `json:"email,omitempty"` - PreferredUsername string `json:"preferred_username,omitempty"` - Groups []string `json:"groups,omitempty"` - EmailVerified bool `json:"email_verified,omitempty"` - PhoneNumber string `json:"phone_number,omitempty"` - PhoneNumberVerified *bool `json:"phone_number_verified,omitempty"` - Address *config.AddressClaim `json:"address,omitempty"` - UpdatedAt int64 `json:"updated_at"` + Sub string `json:"sub"` + Name string `json:"name,omitempty"` + GivenName string `json:"given_name,omitempty"` + FamilyName string `json:"family_name,omitempty"` + MiddleName string `json:"middle_name,omitempty"` + Nickname string `json:"nickname,omitempty"` + Profile string `json:"profile,omitempty"` + Picture string `json:"picture,omitempty"` + Website string `json:"website,omitempty"` + Gender string `json:"gender,omitempty"` + Birthdate string `json:"birthdate,omitempty"` + Zoneinfo string `json:"zoneinfo,omitempty"` + Locale string `json:"locale,omitempty"` + Email string `json:"email,omitempty"` + PreferredUsername string `json:"preferred_username,omitempty"` + Groups []string `json:"groups,omitempty"` + EmailVerified bool `json:"email_verified,omitempty"` + PhoneNumber string `json:"phone_number,omitempty"` + PhoneNumberVerified *bool `json:"phone_number_verified,omitempty"` + Address *model.AddressClaim `json:"address,omitempty"` + UpdatedAt int64 `json:"updated_at"` } type TokenResponse struct { @@ -112,7 +112,7 @@ type AuthorizeRequest struct { } type OIDCServiceConfig struct { - Clients map[string]config.OIDCClientConfig + Clients map[string]model.OIDCClientConfig PrivateKeyPath string PublicKeyPath string Issuer string @@ -122,7 +122,7 @@ type OIDCServiceConfig struct { type OIDCService struct { config OIDCServiceConfig queries *repository.Queries - clients map[string]config.OIDCClientConfig + clients map[string]model.OIDCClientConfig privateKey *rsa.PrivateKey publicKey crypto.PublicKey issuer string @@ -255,7 +255,7 @@ func (service *OIDCService) Init() error { } // We will reorganize the client into a map with the client ID as the key - service.clients = make(map[string]config.OIDCClientConfig) + service.clients = make(map[string]model.OIDCClientConfig) for id, client := range service.config.Clients { client.ID = id @@ -283,7 +283,7 @@ func (service *OIDCService) GetIssuer() string { return service.issuer } -func (service *OIDCService) GetClient(id string) (config.OIDCClientConfig, bool) { +func (service *OIDCService) GetClient(id string) (model.OIDCClientConfig, bool) { client, ok := service.clients[id] return client, ok } @@ -367,43 +367,45 @@ func (service *OIDCService) StoreCode(c *gin.Context, sub string, code string, r return err } -func (service *OIDCService) StoreUserinfo(c *gin.Context, sub string, userContext config.UserContext, req AuthorizeRequest) error { - addressJSON, err := json.Marshal(userContext.Attributes.Address) - if err != nil { - return err - } - +func (service *OIDCService) StoreUserinfo(c *gin.Context, sub string, userContext model.UserContext, req AuthorizeRequest) error { userInfoParams := repository.CreateOidcUserInfoParams{ Sub: sub, - Name: userContext.Name, - Email: userContext.Email, - PreferredUsername: userContext.Username, + Name: userContext.GetName(), + Email: userContext.GetEmail(), + PreferredUsername: userContext.GetUsername(), UpdatedAt: time.Now().Unix(), - GivenName: userContext.Attributes.GivenName, - FamilyName: userContext.Attributes.FamilyName, - MiddleName: userContext.Attributes.MiddleName, - Nickname: userContext.Attributes.Nickname, - Profile: userContext.Attributes.Profile, - Picture: userContext.Attributes.Picture, - Website: userContext.Attributes.Website, - Gender: userContext.Attributes.Gender, - Birthdate: userContext.Attributes.Birthdate, - Zoneinfo: userContext.Attributes.Zoneinfo, - Locale: userContext.Attributes.Locale, - PhoneNumber: userContext.Attributes.PhoneNumber, - Address: string(addressJSON), + } + + if userContext.IsLocal() { + addressJSON, err := json.Marshal(userContext.Local.Attributes.Address) + if err != nil { + return err + } + userInfoParams.GivenName = userContext.Local.Attributes.GivenName + userInfoParams.FamilyName = userContext.Local.Attributes.FamilyName + userInfoParams.MiddleName = userContext.Local.Attributes.MiddleName + userInfoParams.Nickname = userContext.Local.Attributes.Nickname + userInfoParams.Profile = userContext.Local.Attributes.Profile + userInfoParams.Picture = userContext.Local.Attributes.Picture + userInfoParams.Website = userContext.Local.Attributes.Website + userInfoParams.Gender = userContext.Local.Attributes.Gender + userInfoParams.Birthdate = userContext.Local.Attributes.Birthdate + userInfoParams.Zoneinfo = userContext.Local.Attributes.Zoneinfo + userInfoParams.Locale = userContext.Local.Attributes.Locale + userInfoParams.PhoneNumber = userContext.Local.Attributes.PhoneNumber + userInfoParams.Address = string(addressJSON) } // Tinyauth will pass through the groups it got from an LDAP or an OIDC server - if userContext.Provider == "ldap" { - userInfoParams.Groups = userContext.LdapGroups + if userContext.IsLDAP() { + userInfoParams.Groups = strings.Join(userContext.LDAP.Groups, ",") } - if userContext.OAuth && len(userContext.OAuthGroups) > 0 { - userInfoParams.Groups = userContext.OAuthGroups + if userContext.IsOAuth() { + userInfoParams.Groups = strings.Join(userContext.OAuth.Groups, ",") } - _, err = service.queries.CreateOidcUserInfo(c, userInfoParams) + _, err := service.queries.CreateOidcUserInfo(c, userInfoParams) return err } @@ -445,7 +447,7 @@ func (service *OIDCService) GetCodeEntry(c *gin.Context, codeHash string, client return oidcCode, nil } -func (service *OIDCService) generateIDToken(client config.OIDCClientConfig, user repository.OidcUserinfo, scope string, nonce string) (string, error) { +func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user repository.OidcUserinfo, scope string, nonce string) (string, error) { createdAt := time.Now().Unix() expiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry) * time.Second).Unix() @@ -511,7 +513,7 @@ func (service *OIDCService) generateIDToken(client config.OIDCClientConfig, user return token, nil } -func (service *OIDCService) GenerateAccessToken(c *gin.Context, client config.OIDCClientConfig, codeEntry repository.OidcCode) (TokenResponse, error) { +func (service *OIDCService) GenerateAccessToken(c *gin.Context, client model.OIDCClientConfig, codeEntry repository.OidcCode) (TokenResponse, error) { user, err := service.GetUserinfo(c, codeEntry.Sub) if err != nil { @@ -585,7 +587,7 @@ func (service *OIDCService) RefreshAccessToken(c *gin.Context, refreshToken stri return TokenResponse{}, err } - idToken, err := service.generateIDToken(config.OIDCClientConfig{ + idToken, err := service.generateIDToken(model.OIDCClientConfig{ ClientID: entry.ClientID, }, user, entry.Scope, entry.Nonce) @@ -714,7 +716,7 @@ func (service *OIDCService) CompileUserinfo(user repository.OidcUserinfo, scope } if slices.Contains(scopes, "address") { - var addr config.AddressClaim + var addr model.AddressClaim if err := json.Unmarshal([]byte(user.Address), &addr); err == nil { userInfo.Address = &addr } diff --git a/internal/utils/app_utils.go b/internal/utils/app_utils.go index 55665ee..e7206bd 100644 --- a/internal/utils/app_utils.go +++ b/internal/utils/app_utils.go @@ -7,10 +7,8 @@ import ( "net/url" "strings" - "github.com/tinyauthapp/tinyauth/internal/config" "github.com/tinyauthapp/tinyauth/internal/utils/tlog" - "github.com/gin-gonic/gin" "github.com/weppos/publicsuffix-go/publicsuffix" ) @@ -73,22 +71,6 @@ func Filter[T any](slice []T, test func(T) bool) (res []T) { return res } -func GetContext(c *gin.Context) (config.UserContext, error) { - userContextValue, exists := c.Get("context") - - if !exists { - return config.UserContext{}, errors.New("no user context in request") - } - - userContext, ok := userContextValue.(*config.UserContext) - - if !ok { - return config.UserContext{}, errors.New("invalid user context in request") - } - - return *userContext, nil -} - func IsRedirectSafe(redirectURL string, domain string) bool { if redirectURL == "" { return false diff --git a/internal/utils/app_utils_test.go b/internal/utils/app_utils_test.go index a44c08d..9328593 100644 --- a/internal/utils/app_utils_test.go +++ b/internal/utils/app_utils_test.go @@ -3,10 +3,8 @@ package utils_test import ( "testing" - "github.com/tinyauthapp/tinyauth/internal/config" "github.com/tinyauthapp/tinyauth/internal/utils" - "github.com/gin-gonic/gin" "gotest.tools/v3/assert" ) @@ -129,28 +127,6 @@ func TestFilter(t *testing.T) { assert.DeepEqual(t, expectedStr, resultStr) } -func TestGetContext(t *testing.T) { - // Setup - gin.SetMode(gin.TestMode) - c, _ := gin.CreateTestContext(nil) - - // Normal case - c.Set("context", &config.UserContext{Username: "testuser"}) - result, err := utils.GetContext(c) - assert.NilError(t, err) - assert.Equal(t, "testuser", result.Username) - - // Case with no context - c.Set("context", nil) - _, err = utils.GetContext(c) - assert.Error(t, err, "invalid user context in request") - - // Case with invalid context type - c.Set("context", "invalid type") - _, err = utils.GetContext(c) - assert.Error(t, err, "invalid user context in request") -} - func TestIsRedirectSafe(t *testing.T) { // Setup domain := "example.com"