From dbadb096b4a863c64e53aee6978211fd9fb15246 Mon Sep 17 00:00:00 2001 From: Stavros Date: Mon, 25 Aug 2025 19:33:52 +0300 Subject: [PATCH] feat: create oauth broker service --- internal/config/config.go | 11 - internal/controller/oauth_controller.go | 45 +- internal/controller/proxy_controller.go | 23 +- internal/controller/user_controller.go | 14 +- internal/middleware/context_middleware.go | 33 +- internal/service/auth_service.go | 60 +-- internal/service/generic_oauth_service.go | 6 - internal/service/github_oauth_service.go | 4 - internal/service/google_oauth_service.go | 4 - internal/service/oauth_broker_service.go | 76 +++ internal/utils/utils.go | 43 +- internal/utils/utils_test.go | 548 ---------------------- 12 files changed, 184 insertions(+), 683 deletions(-) create mode 100644 internal/service/oauth_broker_service.go delete mode 100644 internal/utils/utils_test.go diff --git a/internal/config/config.go b/internal/config/config.go index 48dc47f..5584d0e 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -1,7 +1,5 @@ package config -import "time" - type Claims struct { Name string `json:"name"` Email string `json:"email"` @@ -100,7 +98,6 @@ type OAuthServiceConfig struct { TokenURL string UserinfoURL string InsecureSkipVerify bool - Name string } type User struct { @@ -114,8 +111,6 @@ type UserSearch struct { Type string // local, ldap or unknown } -type Users []User - type SessionCookie struct { Username string Name string @@ -137,12 +132,6 @@ type UserContext struct { TotpEnabled bool } -type LoginAttempt struct { - FailedAttempts int - LastAttempt time.Time - LockedUntil time.Time -} - type UnauthorizedQuery struct { Username string `url:"username"` Resource string `url:"resource"` diff --git a/internal/controller/oauth_controller.go b/internal/controller/oauth_controller.go index 63b6322..0178af6 100644 --- a/internal/controller/oauth_controller.go +++ b/internal/controller/oauth_controller.go @@ -5,9 +5,8 @@ import ( "net/http" "strings" "time" - "tinyauth/internal/auth" - "tinyauth/internal/providers" - "tinyauth/internal/types" + "tinyauth/internal/config" + "tinyauth/internal/service" "tinyauth/internal/utils" "github.com/gin-gonic/gin" @@ -26,18 +25,18 @@ type OAuthControllerConfig struct { } type OAuthController struct { - Config OAuthControllerConfig - Router *gin.RouterGroup - Auth *auth.Auth - Providers *providers.Providers + Config OAuthControllerConfig + Router *gin.RouterGroup + Auth *service.AuthService + Broker *service.OAuthBrokerService } -func NewOAuthController(config OAuthControllerConfig, router *gin.RouterGroup, auth *auth.Auth, providers *providers.Providers) *OAuthController { +func NewOAuthController(config OAuthControllerConfig, router *gin.RouterGroup, auth *service.AuthService, broker *service.OAuthBrokerService) *OAuthController { return &OAuthController{ - Config: config, - Router: router, - Auth: auth, - Providers: providers, + Config: config, + Router: router, + Auth: auth, + Broker: broker, } } @@ -59,9 +58,9 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) { return } - provider := controller.Providers.GetProvider(req.Provider) + service, exists := controller.Broker.GetService(req.Provider) - if provider == nil { + if !exists { c.JSON(404, gin.H{ "status": 404, "message": "Not Found", @@ -69,8 +68,8 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) { return } - state := provider.GenerateState() - authURL := provider.GetAuthURL(state) + state := service.GenerateState() + authURL := service.GetAuthURL(state) c.SetCookie(controller.Config.CSRFCookieName, state, int(time.Hour.Seconds()), "/", "", controller.Config.SecureCookie, true) redirectURI := c.Query("redirect_uri") @@ -109,20 +108,20 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) { c.SetCookie(controller.Config.CSRFCookieName, "", -1, "/", "", controller.Config.SecureCookie, true) code := c.Query("code") - provider := controller.Providers.GetProvider(req.Provider) + service, exists := controller.Broker.GetService(req.Provider) - if provider == nil { + if !exists { c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.Config.AppURL)) return } - _, err = provider.ExchangeToken(code) + err = service.VerifyCode(code) if err != nil { c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.Config.AppURL)) return } - user, err := controller.Providers.GetUser(req.Provider) + user, err := controller.Broker.GetUser(req.Provider) if err != nil { c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.Config.AppURL)) @@ -135,7 +134,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) { } if !controller.Auth.EmailWhitelisted(user.Email) { - queries, err := query.Values(types.UnauthorizedQuery{ + queries, err := query.Values(config.UnauthorizedQuery{ Username: user.Email, }) @@ -156,7 +155,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) { name = fmt.Sprintf("%s (%s)", utils.Capitalize(strings.Split(user.Email, "@")[0]), strings.Split(user.Email, "@")[1]) } - controller.Auth.CreateSessionCookie(c, &types.SessionCookie{ + controller.Auth.CreateSessionCookie(c, &config.SessionCookie{ Username: user.Email, Name: name, Email: user.Email, @@ -171,7 +170,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) { return } - queries, err := query.Values(types.RedirectQuery{ + queries, err := query.Values(config.RedirectQuery{ RedirectURI: redirectURI, }) diff --git a/internal/controller/proxy_controller.go b/internal/controller/proxy_controller.go index f8476f0..ced09bf 100644 --- a/internal/controller/proxy_controller.go +++ b/internal/controller/proxy_controller.go @@ -4,9 +4,8 @@ import ( "fmt" "net/http" "strings" - "tinyauth/internal/auth" - "tinyauth/internal/docker" - "tinyauth/internal/types" + "tinyauth/internal/config" + "tinyauth/internal/service" "tinyauth/internal/utils" "github.com/gin-gonic/gin" @@ -24,11 +23,11 @@ type ProxyControllerConfig struct { type ProxyController struct { Config ProxyControllerConfig Router *gin.RouterGroup - Docker *docker.Docker - Auth *auth.Auth + Docker *service.DockerService + Auth *service.AuthService } -func NewProxyController(config ProxyControllerConfig, router *gin.RouterGroup, docker *docker.Docker, auth *auth.Auth) *ProxyController { +func NewProxyController(config ProxyControllerConfig, router *gin.RouterGroup, docker *service.DockerService, auth *service.AuthService) *ProxyController { return &ProxyController{ Config: config, Router: router, @@ -109,7 +108,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { return } - queries, err := query.Values(types.UnauthorizedQuery{ + queries, err := query.Values(config.UnauthorizedQuery{ Resource: strings.Split(host, ".")[0], IP: clientIP, }) @@ -157,12 +156,12 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { return } - var userContext types.UserContext + var userContext config.UserContext context, err := utils.GetContext(c) if err != nil { - userContext = types.UserContext{ + userContext = config.UserContext{ IsLoggedIn: false, } } else { @@ -185,7 +184,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { return } - queries, err := query.Values(types.UnauthorizedQuery{ + queries, err := query.Values(config.UnauthorizedQuery{ Resource: strings.Split(host, ".")[0], }) @@ -216,7 +215,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { return } - queries, err := query.Values(types.UnauthorizedQuery{ + queries, err := query.Values(config.UnauthorizedQuery{ Resource: strings.Split(host, ".")[0], GroupErr: true, }) @@ -268,7 +267,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { return } - queries, err := query.Values(types.RedirectQuery{ + queries, err := query.Values(config.RedirectQuery{ RedirectURI: fmt.Sprintf("%s://%s%s", proto, host, uri), }) diff --git a/internal/controller/user_controller.go b/internal/controller/user_controller.go index e017826..77bb6f3 100644 --- a/internal/controller/user_controller.go +++ b/internal/controller/user_controller.go @@ -3,8 +3,8 @@ package controller import ( "fmt" "strings" - "tinyauth/internal/auth" - "tinyauth/internal/types" + "tinyauth/internal/config" + "tinyauth/internal/service" "tinyauth/internal/utils" "github.com/gin-gonic/gin" @@ -27,10 +27,10 @@ type UserControllerConfig struct { type UserController struct { Config UserControllerConfig Router *gin.RouterGroup - Auth *auth.Auth + Auth *service.AuthService } -func NewUserController(config UserControllerConfig, router *gin.RouterGroup, auth *auth.Auth) *UserController { +func NewUserController(config UserControllerConfig, router *gin.RouterGroup, auth *service.AuthService) *UserController { return &UserController{ Config: config, Router: router, @@ -101,7 +101,7 @@ func (controller *UserController) loginHandler(c *gin.Context) { user := controller.Auth.GetLocalUser(userSearch.Username) if user.TotpSecret != "" { - controller.Auth.CreateSessionCookie(c, &types.SessionCookie{ + controller.Auth.CreateSessionCookie(c, &config.SessionCookie{ Username: user.Username, Name: utils.Capitalize(req.Username), Email: fmt.Sprintf("%s@%s", strings.ToLower(req.Username), controller.Config.Domain), @@ -118,7 +118,7 @@ func (controller *UserController) loginHandler(c *gin.Context) { } } - controller.Auth.CreateSessionCookie(c, &types.SessionCookie{ + controller.Auth.CreateSessionCookie(c, &config.SessionCookie{ Username: req.Username, Name: utils.Capitalize(req.Username), Email: fmt.Sprintf("%s@%s", strings.ToLower(req.Username), controller.Config.Domain), @@ -202,7 +202,7 @@ func (controller *UserController) totpHandler(c *gin.Context) { controller.Auth.RecordLoginAttempt(rateIdentifier, true) - controller.Auth.CreateSessionCookie(c, &types.SessionCookie{ + controller.Auth.CreateSessionCookie(c, &config.SessionCookie{ Username: user.Username, Name: utils.Capitalize(user.Username), Email: fmt.Sprintf("%s@%s", strings.ToLower(user.Username), controller.Config.Domain), diff --git a/internal/middleware/context_middleware.go b/internal/middleware/context_middleware.go index ead4879..a83d465 100644 --- a/internal/middleware/context_middleware.go +++ b/internal/middleware/context_middleware.go @@ -3,9 +3,8 @@ package middleware import ( "fmt" "strings" - "tinyauth/internal/auth" - "tinyauth/internal/providers" - "tinyauth/internal/types" + "tinyauth/internal/config" + "tinyauth/internal/service" "tinyauth/internal/utils" "github.com/gin-gonic/gin" @@ -16,16 +15,16 @@ type ContextMiddlewareConfig struct { } type ContextMiddleware struct { - Config ContextMiddlewareConfig - Auth *auth.Auth - Providers *providers.Providers + Config ContextMiddlewareConfig + Auth *service.AuthService + Broker *service.OAuthBrokerService } -func NewContextMiddleware(config ContextMiddlewareConfig, auth *auth.Auth, providers *providers.Providers) *ContextMiddleware { +func NewContextMiddleware(config ContextMiddlewareConfig, auth *service.AuthService, broker *service.OAuthBrokerService) *ContextMiddleware { return &ContextMiddleware{ - Config: config, - Auth: auth, - Providers: providers, + Config: config, + Auth: auth, + Broker: broker, } } @@ -46,7 +45,7 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc { } if cookie.TotpPending { - c.Set("context", &types.UserContext{ + c.Set("context", &config.UserContext{ Username: cookie.Username, Name: cookie.Name, Email: cookie.Email, @@ -66,7 +65,7 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc { goto basic } - c.Set("context", &types.UserContext{ + c.Set("context", &config.UserContext{ Username: cookie.Username, Name: cookie.Name, Email: cookie.Email, @@ -76,9 +75,9 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc { c.Next() return default: - provider := m.Providers.GetProvider(cookie.Provider) + _, exists := m.Broker.GetService(cookie.Provider) - if provider == nil { + if !exists { goto basic } @@ -87,7 +86,7 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc { goto basic } - c.Set("context", &types.UserContext{ + c.Set("context", &config.UserContext{ Username: cookie.Username, Name: cookie.Name, Email: cookie.Email, @@ -124,7 +123,7 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc { case "local": user := m.Auth.GetLocalUser(basic.Username) - c.Set("context", &types.UserContext{ + c.Set("context", &config.UserContext{ Username: user.Username, Name: utils.Capitalize(user.Username), Email: fmt.Sprintf("%s@%s", strings.ToLower(user.Username), m.Config.Domain), @@ -135,7 +134,7 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc { c.Next() return case "ldap": - c.Set("context", &types.UserContext{ + c.Set("context", &config.UserContext{ Username: basic.Username, Name: utils.Capitalize(basic.Username), Email: fmt.Sprintf("%s@%s", strings.ToLower(basic.Username), m.Config.Domain), diff --git a/internal/service/auth_service.go b/internal/service/auth_service.go index ebbd1ad..46bad06 100644 --- a/internal/service/auth_service.go +++ b/internal/service/auth_service.go @@ -6,7 +6,7 @@ import ( "strings" "sync" "time" - "tinyauth/internal/types" + "tinyauth/internal/config" "tinyauth/internal/utils" "github.com/gin-gonic/gin" @@ -15,8 +15,14 @@ import ( "golang.org/x/crypto/bcrypt" ) +type LoginAttempt struct { + FailedAttempts int + LastAttempt time.Time + LockedUntil time.Time +} + type AuthServiceConfig struct { - Users types.Users + Users []config.User OauthWhitelist string SessionExpiry int CookieSecure bool @@ -31,7 +37,7 @@ type AuthServiceConfig struct { type AuthService struct { Config AuthServiceConfig Docker *DockerService - LoginAttempts map[string]*types.LoginAttempt + LoginAttempts map[string]*LoginAttempt LoginMutex sync.RWMutex Store *sessions.CookieStore LDAP *LdapService @@ -41,7 +47,7 @@ func NewAuthService(config AuthServiceConfig, docker *DockerService, ldap *LdapS return &AuthService{ Config: config, Docker: docker, - LoginAttempts: make(map[string]*types.LoginAttempt), + LoginAttempts: make(map[string]*LoginAttempt), LDAP: ldap, } } @@ -75,13 +81,13 @@ func (auth *AuthService) GetSession(c *gin.Context) (*sessions.Session, error) { return session, nil } -func (auth *AuthService) SearchUser(username string) types.UserSearch { +func (auth *AuthService) SearchUser(username string) config.UserSearch { log.Debug().Str("username", username).Msg("Searching for user") // Check local users first if auth.GetLocalUser(username).Username != "" { log.Debug().Str("username", username).Msg("Found local user") - return types.UserSearch{ + return config.UserSearch{ Username: username, Type: "local", } @@ -93,20 +99,20 @@ func (auth *AuthService) SearchUser(username string) types.UserSearch { userDN, err := auth.LDAP.Search(username) if err != nil { log.Warn().Err(err).Str("username", username).Msg("Failed to find user in LDAP") - return types.UserSearch{} + return config.UserSearch{} } - return types.UserSearch{ + return config.UserSearch{ Username: userDN, Type: "ldap", } } - return types.UserSearch{ + return config.UserSearch{ Type: "unknown", } } -func (auth *AuthService) VerifyUser(search types.UserSearch, password string) bool { +func (auth *AuthService) VerifyUser(search config.UserSearch, password string) bool { // Authenticate the user based on the type switch search.Type { case "local": @@ -144,7 +150,7 @@ func (auth *AuthService) VerifyUser(search types.UserSearch, password string) bo return false } -func (auth *AuthService) GetLocalUser(username string) types.User { +func (auth *AuthService) GetLocalUser(username string) config.User { // Loop through users and return the user if the username matches log.Debug().Str("username", username).Msg("Searching for local user") @@ -156,10 +162,10 @@ func (auth *AuthService) GetLocalUser(username string) types.User { // If no user found, return an empty user log.Warn().Str("username", username).Msg("Local user not found") - return types.User{} + return config.User{} } -func (auth *AuthService) CheckPassword(user types.User, password string) bool { +func (auth *AuthService) CheckPassword(user config.User, password string) bool { return bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)) == nil } @@ -201,7 +207,7 @@ func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) { // Get current attempt record or create a new one attempt, exists := auth.LoginAttempts[identifier] if !exists { - attempt = &types.LoginAttempt{} + attempt = &LoginAttempt{} auth.LoginAttempts[identifier] = attempt } @@ -229,7 +235,7 @@ func (auth *AuthService) EmailWhitelisted(email string) bool { return utils.CheckFilter(auth.Config.OauthWhitelist, email) } -func (auth *AuthService) CreateSessionCookie(c *gin.Context, data *types.SessionCookie) error { +func (auth *AuthService) CreateSessionCookie(c *gin.Context, data *config.SessionCookie) error { log.Debug().Msg("Creating session cookie") session, err := auth.GetSession(c) @@ -288,13 +294,13 @@ func (auth *AuthService) DeleteSessionCookie(c *gin.Context) error { return nil } -func (auth *AuthService) GetSessionCookie(c *gin.Context) (types.SessionCookie, error) { +func (auth *AuthService) GetSessionCookie(c *gin.Context) (config.SessionCookie, error) { log.Debug().Msg("Getting session cookie") session, err := auth.GetSession(c) if err != nil { log.Error().Err(err).Msg("Failed to get session") - return types.SessionCookie{}, err + return config.SessionCookie{}, err } log.Debug().Msg("Got session") @@ -311,18 +317,18 @@ func (auth *AuthService) GetSessionCookie(c *gin.Context) (types.SessionCookie, if !usernameOk || !providerOK || !expiryOk || !totpPendingOk || !emailOk || !nameOk || !oauthGroupsOk { log.Warn().Msg("Session cookie is invalid") auth.DeleteSessionCookie(c) - return types.SessionCookie{}, nil + return config.SessionCookie{}, nil } // If the session cookie has expired, delete it if time.Now().Unix() > expiry { log.Warn().Msg("Session cookie expired") auth.DeleteSessionCookie(c) - return types.SessionCookie{}, nil + return config.SessionCookie{}, nil } log.Debug().Str("username", username).Str("provider", provider).Int64("expiry", expiry).Bool("totpPending", totpPending).Str("name", name).Str("email", email).Str("oauthGroups", oauthGroups).Msg("Parsed cookie") - return types.SessionCookie{ + return config.SessionCookie{ Username: username, Name: name, Email: email, @@ -337,7 +343,7 @@ func (auth *AuthService) UserAuthConfigured() bool { return len(auth.Config.Users) > 0 || auth.LDAP != nil } -func (auth *AuthService) ResourceAllowed(c *gin.Context, context types.UserContext, labels types.Labels) bool { +func (auth *AuthService) ResourceAllowed(c *gin.Context, context config.UserContext, labels config.Labels) bool { if context.OAuth { log.Debug().Msg("Checking OAuth whitelist") return utils.CheckFilter(labels.OAuth.Whitelist, context.Email) @@ -347,7 +353,7 @@ func (auth *AuthService) ResourceAllowed(c *gin.Context, context types.UserConte return utils.CheckFilter(labels.Users, context.Username) } -func (auth *AuthService) OAuthGroup(c *gin.Context, context types.UserContext, labels types.Labels) bool { +func (auth *AuthService) OAuthGroup(c *gin.Context, context config.UserContext, labels config.Labels) bool { if labels.OAuth.Groups == "" { return true } @@ -374,7 +380,7 @@ func (auth *AuthService) OAuthGroup(c *gin.Context, context types.UserContext, l return false } -func (auth *AuthService) AuthEnabled(uri string, labels types.Labels) (bool, error) { +func (auth *AuthService) AuthEnabled(uri string, labels config.Labels) (bool, error) { // If the label is empty, auth is enabled if labels.Allowed == "" { return true, nil @@ -398,18 +404,18 @@ func (auth *AuthService) AuthEnabled(uri string, labels types.Labels) (bool, err return true, nil } -func (auth *AuthService) GetBasicAuth(c *gin.Context) *types.User { +func (auth *AuthService) GetBasicAuth(c *gin.Context) *config.User { username, password, ok := c.Request.BasicAuth() if !ok { return nil } - return &types.User{ + return &config.User{ Username: username, Password: password, } } -func (auth *AuthService) CheckIP(labels types.Labels, ip string) bool { +func (auth *AuthService) CheckIP(labels config.Labels, ip string) bool { // Check if the IP is in block list for _, blocked := range labels.IP.Block { res, err := utils.FilterIP(blocked, ip) @@ -446,7 +452,7 @@ func (auth *AuthService) CheckIP(labels types.Labels, ip string) bool { return true } -func (auth *AuthService) BypassedIP(labels types.Labels, ip string) bool { +func (auth *AuthService) BypassedIP(labels config.Labels, ip string) bool { // For every IP in the bypass list, check if the IP matches for _, bypassed := range labels.IP.Bypass { res, err := utils.FilterIP(bypassed, ip) diff --git a/internal/service/generic_oauth_service.go b/internal/service/generic_oauth_service.go index 9bd6a8e..c68d150 100644 --- a/internal/service/generic_oauth_service.go +++ b/internal/service/generic_oauth_service.go @@ -19,7 +19,6 @@ type GenericOAuthService struct { Token *oauth2.Token Verifier string InsecureSkipVerify bool - ServiceName string UserinfoURL string } @@ -36,7 +35,6 @@ func NewGenericOAuthService(config config.OAuthServiceConfig) *GenericOAuthServi }, }, InsecureSkipVerify: config.InsecureSkipVerify, - ServiceName: config.Name, UserinfoURL: config.UserinfoURL, } } @@ -63,10 +61,6 @@ func (generic *GenericOAuthService) Init() error { return nil } -func (generic *GenericOAuthService) Name() string { - return generic.ServiceName -} - func (generic *GenericOAuthService) GenerateState() string { b := make([]byte, 128) rand.Read(b) diff --git a/internal/service/github_oauth_service.go b/internal/service/github_oauth_service.go index 57d8391..a8c1334 100644 --- a/internal/service/github_oauth_service.go +++ b/internal/service/github_oauth_service.go @@ -54,10 +54,6 @@ func (github *GithubOAuthService) Init() error { return nil } -func (github *GithubOAuthService) Name() string { - return "github" -} - func (github *GithubOAuthService) GenerateState() string { b := make([]byte, 128) rand.Read(b) diff --git a/internal/service/google_oauth_service.go b/internal/service/google_oauth_service.go index 2d86a56..6d9eaed 100644 --- a/internal/service/google_oauth_service.go +++ b/internal/service/google_oauth_service.go @@ -49,10 +49,6 @@ func (google *GoogleOAuthService) Init() error { return nil } -func (google *GoogleOAuthService) Name() string { - return "google" -} - func (oauth *GoogleOAuthService) GenerateState() string { b := make([]byte, 128) rand.Read(b) diff --git a/internal/service/oauth_broker_service.go b/internal/service/oauth_broker_service.go new file mode 100644 index 0000000..6b5b1e6 --- /dev/null +++ b/internal/service/oauth_broker_service.go @@ -0,0 +1,76 @@ +package service + +import ( + "errors" + "tinyauth/internal/config" + + "github.com/rs/zerolog/log" +) + +type OAuthService interface { + Init() error + GenerateState() string + GetAuthURL(state string) string + VerifyCode(code string) error + Userinfo() (config.Claims, error) +} + +type OAuthBrokerService struct { + Services map[string]OAuthService + Configs map[string]config.OAuthServiceConfig +} + +func NewOAuthBrokerService(configs map[string]config.OAuthServiceConfig) *OAuthBrokerService { + return &OAuthBrokerService{ + Services: make(map[string]OAuthService), + Configs: configs, + } +} + +func (broker *OAuthBrokerService) Init() error { + for name, cfg := range broker.Configs { + switch name { + case "github": + service := NewGithubOAuthService(cfg) + broker.Services[name] = service + case "google": + service := NewGoogleOAuthService(cfg) + broker.Services[name] = service + default: + service := NewGenericOAuthService(cfg) + broker.Services[name] = service + } + } + + for name, service := range broker.Services { + err := service.Init() + if err != nil { + log.Error().Err(err).Msgf("Failed to initialize OAuth service: %s", name) + return err + } + log.Info().Msgf("Initialized OAuth service: %s", name) + } + + return nil +} + +func (broker *OAuthBrokerService) GetConfiguredServices() []string { + services := make([]string, 0, len(broker.Services)) + for name := range broker.Services { + services = append(services, name) + } + return services +} + +func (broker *OAuthBrokerService) GetService(name string) (OAuthService, bool) { + service, exists := broker.Services[name] + return service, exists +} + +func (broker *OAuthBrokerService) GetUser(service string) (config.Claims, error) { + oauthService, exists := broker.Services[service] + if !exists { + return config.Claims{}, errors.New("oauth service not found") + } + return oauthService.Userinfo() +} diff --git a/internal/utils/utils.go b/internal/utils/utils.go index 8c2f4ea..67b904f 100644 --- a/internal/utils/utils.go +++ b/internal/utils/utils.go @@ -11,7 +11,7 @@ import ( "os" "regexp" "strings" - "tinyauth/internal/types" + "tinyauth/internal/config" "github.com/gin-gonic/gin" "github.com/traefik/paerser/parser" @@ -22,21 +22,21 @@ import ( ) // Parses a list of comma separated users in a struct -func ParseUsers(users string) (types.Users, error) { +func ParseUsers(users string) ([]config.User, error) { log.Debug().Msg("Parsing users") - var usersParsed types.Users + var usersParsed []config.User userList := strings.Split(users, ",") if len(userList) == 0 { - return types.Users{}, errors.New("invalid user format") + return []config.User{}, errors.New("invalid user format") } for _, user := range userList { parsed, err := ParseUser(user) if err != nil { - return types.Users{}, err + return []config.User{}, err } usersParsed = append(usersParsed, parsed) } @@ -107,11 +107,11 @@ func GetSecret(conf string, file string) string { } // Get the users from the config or file -func GetUsers(conf string, file string) (types.Users, error) { +func GetUsers(conf string, file string) ([]config.User, error) { var users string if conf == "" && file == "" { - return types.Users{}, nil + return []config.User{}, nil } if conf != "" { @@ -152,23 +152,18 @@ func ParseHeaders(headers []string) map[string]string { } // Get labels parses a map of labels into a struct with only the needed labels -func GetLabels(labels map[string]string) (types.Labels, error) { - var labelsParsed types.Labels +func GetLabels(labels map[string]string) (config.Labels, error) { + var labelsParsed config.Labels err := parser.Decode(labels, &labelsParsed, "tinyauth", "tinyauth.users", "tinyauth.allowed", "tinyauth.headers", "tinyauth.domain", "tinyauth.basic", "tinyauth.oauth", "tinyauth.ip") if err != nil { log.Error().Err(err).Msg("Error parsing labels") - return types.Labels{}, err + return config.Labels{}, err } return labelsParsed, nil } -// Check if any of the OAuth providers are configured based on the client id and secret -func OAuthConfigured(config types.Config) bool { - return (config.GithubClientId != "" && config.GithubClientSecret != "") || (config.GoogleClientId != "" && config.GoogleClientSecret != "") || (config.GenericClientId != "" && config.GenericClientSecret != "") -} - // Filter helper function func Filter[T any](slice []T, test func(T) bool) (res []T) { for _, value := range slice { @@ -180,7 +175,7 @@ func Filter[T any](slice []T, test func(T) bool) (res []T) { } // Parse user -func ParseUser(user string) (types.User, error) { +func ParseUser(user string) (config.User, error) { if strings.Contains(user, "$$") { user = strings.ReplaceAll(user, "$$", "$") } @@ -188,23 +183,23 @@ func ParseUser(user string) (types.User, error) { userSplit := strings.Split(user, ":") if len(userSplit) < 2 || len(userSplit) > 3 { - return types.User{}, errors.New("invalid user format") + return config.User{}, errors.New("invalid user format") } for _, userPart := range userSplit { if strings.TrimSpace(userPart) == "" { - return types.User{}, errors.New("invalid user format") + return config.User{}, errors.New("invalid user format") } } if len(userSplit) == 2 { - return types.User{ + return config.User{ Username: strings.TrimSpace(userSplit[0]), Password: strings.TrimSpace(userSplit[1]), }, nil } - return types.User{ + return config.User{ Username: strings.TrimSpace(userSplit[0]), Password: strings.TrimSpace(userSplit[1]), TotpSecret: strings.TrimSpace(userSplit[2]), @@ -350,17 +345,17 @@ func CoalesceToString(value any) string { } } -func GetContext(c *gin.Context) (types.UserContext, error) { +func GetContext(c *gin.Context) (config.UserContext, error) { userContextValue, exists := c.Get("context") if !exists { - return types.UserContext{}, errors.New("no user context in request") + return config.UserContext{}, errors.New("no user context in request") } - userContext, ok := userContextValue.(*types.UserContext) + userContext, ok := userContextValue.(*config.UserContext) if !ok { - return types.UserContext{}, errors.New("invalid user context in request") + return config.UserContext{}, errors.New("invalid user context in request") } return *userContext, nil diff --git a/internal/utils/utils_test.go b/internal/utils/utils_test.go deleted file mode 100644 index 5ae7e89..0000000 --- a/internal/utils/utils_test.go +++ /dev/null @@ -1,548 +0,0 @@ -package utils_test - -import ( - "fmt" - "os" - "reflect" - "testing" - "tinyauth/internal/types" - "tinyauth/internal/utils" -) - -func TestParseUsers(t *testing.T) { - t.Log("Testing parse users with a valid string") - - users := "user1:pass1,user2:pass2" - expected := types.Users{ - { - Username: "user1", - Password: "pass1", - }, - { - Username: "user2", - Password: "pass2", - }, - } - - result, err := utils.ParseUsers(users) - if err != nil { - t.Fatalf("Error parsing users: %v", err) - } - - if !reflect.DeepEqual(expected, result) { - t.Fatalf("Expected %v, got %v", expected, result) - } -} - -func TestGetUpperDomain(t *testing.T) { - t.Log("Testing get upper domain with a valid url") - - url := "https://sub1.sub2.domain.com:8080" - expected := "sub2.domain.com" - - result, err := utils.GetUpperDomain(url) - if err != nil { - t.Fatalf("Error getting root url: %v", err) - } - - if expected != result { - t.Fatalf("Expected %v, got %v", expected, result) - } -} - -func TestReadFile(t *testing.T) { - t.Log("Creating a test file") - - err := os.WriteFile("/tmp/test.txt", []byte("test"), 0644) - if err != nil { - t.Fatalf("Error creating test file: %v", err) - } - - t.Log("Testing read file with a valid file") - - data, err := utils.ReadFile("/tmp/test.txt") - if err != nil { - t.Fatalf("Error reading file: %v", err) - } - - if data != "test" { - t.Fatalf("Expected test, got %v", data) - } - - t.Log("Cleaning up test file") - - err = os.Remove("/tmp/test.txt") - if err != nil { - t.Fatalf("Error cleaning up test file: %v", err) - } -} - -func TestParseFileToLine(t *testing.T) { - t.Log("Testing parse file to line with a valid string") - - content := "\nuser1:pass1\nuser2:pass2\n" - expected := "user1:pass1,user2:pass2" - - result := utils.ParseFileToLine(content) - - if expected != result { - t.Fatalf("Expected %v, got %v", expected, result) - } -} - -func TestGetSecret(t *testing.T) { - t.Log("Testing get secret with an empty config and file") - - conf := "" - file := "/tmp/test.txt" - expected := "test" - - err := os.WriteFile(file, []byte(fmt.Sprintf("\n\n \n\n\n %s \n\n \n ", expected)), 0644) - if err != nil { - t.Fatalf("Error creating test file: %v", err) - } - - result := utils.GetSecret(conf, file) - - if result != expected { - t.Fatalf("Expected %v, got %v", expected, result) - } - - t.Log("Testing get secret with an empty file and a valid config") - - result = utils.GetSecret(expected, "") - - if result != expected { - t.Fatalf("Expected %v, got %v", expected, result) - } - - t.Log("Testing get secret with both a valid config and file") - - result = utils.GetSecret(expected, file) - - if result != expected { - t.Fatalf("Expected %v, got %v", expected, result) - } - - t.Log("Cleaning up test file") - - err = os.Remove(file) - if err != nil { - t.Fatalf("Error cleaning up test file: %v", err) - } -} - -func TestGetUsers(t *testing.T) { - t.Log("Testing get users with a config and no file") - - conf := "user1:pass1,user2:pass2" - file := "" - expected := types.Users{ - { - Username: "user1", - Password: "pass1", - }, - { - Username: "user2", - Password: "pass2", - }, - } - - result, err := utils.GetUsers(conf, file) - if err != nil { - t.Fatalf("Error getting users: %v", err) - } - - if !reflect.DeepEqual(expected, result) { - t.Fatalf("Expected %v, got %v", expected, result) - } - - t.Log("Testing get users with a file and no config") - - conf = "" - file = "/tmp/test.txt" - expected = types.Users{ - { - Username: "user1", - Password: "pass1", - }, - { - Username: "user2", - Password: "pass2", - }, - } - - err = os.WriteFile(file, []byte("user1:pass1\nuser2:pass2"), 0644) - if err != nil { - t.Fatalf("Error creating test file: %v", err) - } - - result, err = utils.GetUsers(conf, file) - if err != nil { - t.Fatalf("Error getting users: %v", err) - } - - if !reflect.DeepEqual(expected, result) { - t.Fatalf("Expected %v, got %v", expected, result) - } - - t.Log("Testing get users with both a config and file") - - conf = "user3:pass3" - expected = types.Users{ - { - Username: "user3", - Password: "pass3", - }, - { - Username: "user1", - Password: "pass1", - }, - { - Username: "user2", - Password: "pass2", - }, - } - - result, err = utils.GetUsers(conf, file) - if err != nil { - t.Fatalf("Error getting users: %v", err) - } - - if !reflect.DeepEqual(expected, result) { - t.Fatalf("Expected %v, got %v", expected, result) - } - - t.Log("Cleaning up test file") - - err = os.Remove(file) - if err != nil { - t.Fatalf("Error cleaning up test file: %v", err) - } -} - -func TestGetLabels(t *testing.T) { - t.Log("Testing get labels with a valid map") - - labels := map[string]string{ - "tinyauth.users": "user1,user2", - "tinyauth.oauth.whitelist": "/regex/", - "tinyauth.allowed": "random", - "tinyauth.headers": "X-Header=value", - "tinyauth.oauth.groups": "group1,group2", - } - - expected := types.Labels{ - Users: "user1,user2", - Allowed: "random", - Headers: []string{"X-Header=value"}, - OAuth: types.OAuthLabels{ - Whitelist: "/regex/", - Groups: "group1,group2", - }, - } - - result, err := utils.GetLabels(labels) - if err != nil { - t.Fatalf("Error getting labels: %v", err) - } - - if !reflect.DeepEqual(expected, result) { - t.Fatalf("Expected %v, got %v", expected, result) - } -} - -func TestParseUser(t *testing.T) { - t.Log("Testing parse user with a valid user") - - user := "user:pass:secret" - expected := types.User{ - Username: "user", - Password: "pass", - TotpSecret: "secret", - } - - result, err := utils.ParseUser(user) - if err != nil { - t.Fatalf("Error parsing user: %v", err) - } - - if !reflect.DeepEqual(expected, result) { - t.Fatalf("Expected %v, got %v", expected, result) - } - - t.Log("Testing parse user with an escaped user") - - user = "user:p$$ass$$:secret" - expected = types.User{ - Username: "user", - Password: "p$ass$", - TotpSecret: "secret", - } - - result, err = utils.ParseUser(user) - if err != nil { - t.Fatalf("Error parsing user: %v", err) - } - - if !reflect.DeepEqual(expected, result) { - t.Fatalf("Expected %v, got %v", expected, result) - } - - t.Log("Testing parse user with an invalid user") - - user = "user::pass" - - _, err = utils.ParseUser(user) - if err == nil { - t.Fatalf("Expected error parsing user") - } -} - -func TestCheckFilter(t *testing.T) { - t.Log("Testing check filter with a comma separated list") - - filter := "user1,user2,user3" - str := "user1" - expected := true - - result := utils.CheckFilter(filter, str) - - if result != expected { - t.Fatalf("Expected %v, got %v", expected, result) - } - - t.Log("Testing check filter with a regex filter") - - filter = "/^user[0-9]+$/" - str = "user1" - expected = true - - result = utils.CheckFilter(filter, str) - - if result != expected { - t.Fatalf("Expected %v, got %v", expected, result) - } - - t.Log("Testing check filter with an empty filter") - - filter = "" - str = "user1" - expected = true - - result = utils.CheckFilter(filter, str) - - if result != expected { - t.Fatalf("Expected %v, got %v", expected, result) - } - - t.Log("Testing check filter with an invalid regex filter") - - filter = "/^user[0-9+$/" - str = "user1" - expected = false - - result = utils.CheckFilter(filter, str) - - if result != expected { - t.Fatalf("Expected %v, got %v", expected, result) - } - - t.Log("Testing check filter with a non matching list") - - filter = "user1,user2,user3" - str = "user4" - expected = false - - result = utils.CheckFilter(filter, str) - - if result != expected { - t.Fatalf("Expected %v, got %v", expected, result) - } -} - -func TestSanitizeHeader(t *testing.T) { - t.Log("Testing sanitize header with a valid string") - - str := "X-Header=value" - expected := "X-Header=value" - - result := utils.SanitizeHeader(str) - - if result != expected { - t.Fatalf("Expected %v, got %v", expected, result) - } - - t.Log("Testing sanitize header with an invalid string") - - str = "X-Header=val\nue" - expected = "X-Header=value" - - result = utils.SanitizeHeader(str) - - if result != expected { - t.Fatalf("Expected %v, got %v", expected, result) - } -} - -func TestParseHeaders(t *testing.T) { - t.Log("Testing parse headers with a valid string") - - headers := []string{"X-Hea\x00der1=value1", "X-Header2=value\n2"} - expected := map[string]string{ - "X-Header1": "value1", - "X-Header2": "value2", - } - - result := utils.ParseHeaders(headers) - - if !reflect.DeepEqual(expected, result) { - t.Fatalf("Expected %v, got %v", expected, result) - } - - t.Log("Testing parse headers with an invalid string") - - headers = []string{"X-Header1=", "X-Header2", "=value", "X-Header3=value3"} - expected = map[string]string{"X-Header3": "value3"} - - result = utils.ParseHeaders(headers) - - if !reflect.DeepEqual(expected, result) { - t.Fatalf("Expected %v, got %v", expected, result) - } -} - -func TestParseSecretFile(t *testing.T) { - t.Log("Testing parse secret file with a valid file") - - content := "\n\n \n\n\n secret \n\n \n " - expected := "secret" - - result := utils.ParseSecretFile(content) - - if result != expected { - t.Fatalf("Expected %v, got %v", expected, result) - } -} - -func TestFilterIP(t *testing.T) { - t.Log("Testing filter IP with an IP and a valid CIDR") - - ip := "10.10.10.10" - filter := "10.10.10.0/24" - expected := true - - result, err := utils.FilterIP(filter, ip) - if err != nil { - t.Fatalf("Error filtering IP: %v", err) - } - - if result != expected { - t.Fatalf("Expected %v, got %v", expected, result) - } - - t.Log("Testing filter IP with an IP and a valid IP") - - filter = "10.10.10.10" - expected = true - - result, err = utils.FilterIP(filter, ip) - if err != nil { - t.Fatalf("Error filtering IP: %v", err) - } - - if result != expected { - t.Fatalf("Expected %v, got %v", expected, result) - } - - t.Log("Testing filter IP with an IP and an non matching CIDR") - - filter = "10.10.15.0/24" - expected = false - - result, err = utils.FilterIP(filter, ip) - if err != nil { - t.Fatalf("Error filtering IP: %v", err) - } - - if result != expected { - t.Fatalf("Expected %v, got %v", expected, result) - } - - t.Log("Testing filter IP with a non matching IP and a valid CIDR") - - filter = "10.10.10.11" - expected = false - - result, err = utils.FilterIP(filter, ip) - - if err != nil { - t.Fatalf("Error filtering IP: %v", err) - } - - if result != expected { - t.Fatalf("Expected %v, got %v", expected, result) - } - - t.Log("Testing filter IP with an IP and an invalid CIDR") - - filter = "10.../83" - - _, err = utils.FilterIP(filter, ip) - if err == nil { - t.Fatalf("Expected error filtering IP") - } -} - -func TestDeriveKey(t *testing.T) { - t.Log("Testing the derive key function") - - master := "master" - info := "info" - expected := "gdrdU/fXzclYjiSXRexEatVgV13qQmKl" - - result, err := utils.DeriveKey(master, info) - - if err != nil { - t.Fatalf("Error deriving key: %v", err) - } - - if result != expected { - t.Fatalf("Expected %v, got %v", expected, result) - } -} - -func TestCoalesceToString(t *testing.T) { - t.Log("Testing coalesce to string with a string") - - value := any("test") - expected := "test" - - result := utils.CoalesceToString(value) - - if result != expected { - t.Fatalf("Expected %v, got %v", expected, result) - } - - t.Log("Testing coalesce to string with a slice of strings") - - value = []any{any("test1"), any("test2"), any(123)} - expected = "test1,test2" - - result = utils.CoalesceToString(value) - - if result != expected { - t.Fatalf("Expected %v, got %v", expected, result) - } - - t.Log("Testing coalesce to string with an unsupported type") - - value = 12345 - expected = "" - - result = utils.CoalesceToString(value) - - if result != expected { - t.Fatalf("Expected %v, got %v", expected, result) - } -}