diff --git a/internal/controller/proxy_controller.go b/internal/controller/proxy_controller.go index 2ed63545..014c9db6 100644 --- a/internal/controller/proxy_controller.go +++ b/internal/controller/proxy_controller.go @@ -107,7 +107,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { clientIP := c.ClientIP() - if controller.auth.IsBypassedIP(&acls.IP, clientIP) { + if controller.auth.IsBypassedIP(clientIP, acls) { controller.setHeaders(c, *acls) c.JSON(200, gin.H{ "status": 200, @@ -116,7 +116,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { return } - authEnabled, err := controller.auth.IsAuthEnabled(proxyCtx.Path, &acls.Path) + authEnabled, err := controller.auth.IsAuthEnabled(proxyCtx.Path, acls) if err != nil { tlog.App.Error().Err(err).Msg("Failed to check if auth is enabled for resource") @@ -134,7 +134,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { return } - if !controller.auth.CheckIP(&acls.IP, clientIP) { + if !controller.auth.CheckIP(clientIP, acls) { queries, err := query.Values(UnauthorizedQuery{ Resource: strings.Split(proxyCtx.Host, ".")[0], IP: clientIP, @@ -213,9 +213,9 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { var groupOK bool if userContext.IsOAuth() { - groupOK = controller.auth.IsInOAuthGroup(c, *userContext, acls.OAuth.Groups) + groupOK = controller.auth.IsInOAuthGroup(c, *userContext, acls) } else { - groupOK = controller.auth.IsInLDAPGroup(c, *userContext, acls.LDAP.Groups) + groupOK = controller.auth.IsInLDAPGroup(c, *userContext, acls) } if !groupOK { diff --git a/internal/middleware/context_middleware.go b/internal/middleware/context_middleware.go index ad162a90..168b1eea 100644 --- a/internal/middleware/context_middleware.go +++ b/internal/middleware/context_middleware.go @@ -86,10 +86,10 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc { return } - basic, err := m.auth.GetBasicAuth(c.Request) + username, password, ok := c.Request.BasicAuth() - if err == nil { - userContext, headers, err := m.basicAuth(c.Request.Context(), basic) + if ok { + userContext, headers, err := m.basicAuth(username, password) if err != nil { tlog.App.Error().Msgf("Error authenticating basic auth: %v", err) @@ -188,39 +188,39 @@ func (m *ContextMiddleware) cookieAuth(ctx context.Context, uuid string) (*model return userContext, cookie, nil } -func (m *ContextMiddleware) basicAuth(ctx context.Context, basic *model.LocalUser) (*model.UserContext, map[string]string, error) { +func (m *ContextMiddleware) basicAuth(username string, password string) (*model.UserContext, map[string]string, error) { headers := make(map[string]string) userContext := new(model.UserContext) - locked, remaining := m.auth.IsAccountLocked(basic.Username) + locked, remaining := m.auth.IsAccountLocked(username) if locked { - tlog.App.Debug().Msgf("Account for user %s is locked for %d seconds, denying auth", basic.Username, remaining) + tlog.App.Debug().Msgf("Account for user %s is locked for %d seconds, denying auth", username, remaining) headers["x-tinyauth-lock-locked"] = "true" headers["x-tinyauth-lock-reset"] = time.Now().Add(time.Duration(remaining) * time.Second).Format(time.RFC3339) return nil, headers, nil } - search, err := m.auth.SearchUser(basic.Username) + search, err := m.auth.SearchUser(username) if err != nil { return nil, nil, fmt.Errorf("error searching for user: %w", err) } - err = m.auth.CheckUserPassword(*search, basic.Password) + err = m.auth.CheckUserPassword(*search, password) if err != nil { - m.auth.RecordLoginAttempt(basic.Username, false) + m.auth.RecordLoginAttempt(username, false) return nil, nil, fmt.Errorf("invalid password for basic auth user: %w", err) } - m.auth.RecordLoginAttempt(basic.Username, true) + m.auth.RecordLoginAttempt(username, true) switch search.Type { case model.UserLocal: - user := m.auth.GetLocalUser(basic.Username) + user := m.auth.GetLocalUser(username) if user.TOTPSecret != "" { - return nil, nil, fmt.Errorf("user with totp not allowed to login via basic auth: %s", basic.Username) + return nil, nil, fmt.Errorf("user with totp not allowed to login via basic auth: %s", username) } userContext.Local = &model.LocalContext{ @@ -233,7 +233,7 @@ func (m *ContextMiddleware) basicAuth(ctx context.Context, basic *model.LocalUse } userContext.Provider = model.ProviderLocal case model.UserLDAP: - user, err := m.auth.GetLDAPUser(basic.Username) + user, err := m.auth.GetLDAPUser(username) if err != nil { return nil, nil, fmt.Errorf("error retrieving ldap user details: %w", err) @@ -241,9 +241,9 @@ func (m *ContextMiddleware) basicAuth(ctx context.Context, basic *model.LocalUse userContext.LDAP = &model.LDAPContext{ BaseContext: model.BaseContext{ - Username: basic.Username, - Name: utils.Capitalize(basic.Username), - Email: utils.CompileUserEmail(basic.Username, m.config.CookieDomain), + Username: username, + Name: utils.Capitalize(username), + Email: utils.CompileUserEmail(username, m.config.CookieDomain), }, Groups: user.Groups, } diff --git a/internal/service/access_controls_service.go b/internal/service/access_controls_service.go index 065117ec..dedd6dd3 100644 --- a/internal/service/access_controls_service.go +++ b/internal/service/access_controls_service.go @@ -1,7 +1,6 @@ package service import ( - "errors" "strings" "github.com/tinyauthapp/tinyauth/internal/model" @@ -28,26 +27,26 @@ func (acls *AccessControlsService) Init() error { return nil // No initialization needed } -func (acls *AccessControlsService) lookupStaticACLs(domain string) (*model.App, error) { +func (acls *AccessControlsService) lookupStaticACLs(domain string) *model.App { for app, config := range acls.static { if config.Config.Domain == domain { tlog.App.Debug().Str("name", app).Msg("Found matching container by domain") - return &config, nil + return &config } if strings.SplitN(domain, ".", 2)[0] == app { tlog.App.Debug().Str("name", app).Msg("Found matching container by app name") - return &config, nil + return &config } } - return nil, errors.New("no results") + return nil } func (acls *AccessControlsService) GetAccessControls(domain string) (*model.App, error) { // First check in the static config - app, err := acls.lookupStaticACLs(domain) + app := acls.lookupStaticACLs(domain) - if err == nil { + if app != nil { tlog.App.Debug().Msg("Using ACls from static configuration") return app, nil } diff --git a/internal/service/auth_service.go b/internal/service/auth_service.go index be01ccd6..3c5946b1 100644 --- a/internal/service/auth_service.go +++ b/internal/service/auth_service.go @@ -464,8 +464,8 @@ func (auth *AuthService) IsUserAllowed(c *gin.Context, context model.UserContext return utils.CheckFilter(acls.Users.Allow, context.GetUsername()) } -func (auth *AuthService) IsInOAuthGroup(c *gin.Context, context model.UserContext, requiredGroups string) bool { - if requiredGroups == "" { +func (auth *AuthService) IsInOAuthGroup(c *gin.Context, context model.UserContext, acls *model.App) bool { + if acls == nil { return true } @@ -480,8 +480,8 @@ func (auth *AuthService) IsInOAuthGroup(c *gin.Context, context model.UserContex } for _, userGroup := range context.OAuth.Groups { - if utils.CheckFilter(requiredGroups, strings.TrimSpace(userGroup)) { - tlog.App.Trace().Str("group", userGroup).Str("required", requiredGroups).Msg("User group matched") + if utils.CheckFilter(acls.OAuth.Groups, strings.TrimSpace(userGroup)) { + tlog.App.Trace().Str("group", userGroup).Str("required", acls.OAuth.Groups).Msg("User group matched") return true } } @@ -490,8 +490,8 @@ func (auth *AuthService) IsInOAuthGroup(c *gin.Context, context model.UserContex return false } -func (auth *AuthService) IsInLDAPGroup(c *gin.Context, context model.UserContext, requiredGroups string) bool { - if requiredGroups == "" { +func (auth *AuthService) IsInLDAPGroup(c *gin.Context, context model.UserContext, acls *model.App) bool { + if acls == nil { return true } @@ -501,8 +501,8 @@ func (auth *AuthService) IsInLDAPGroup(c *gin.Context, context model.UserContext } for _, userGroup := range context.LDAP.Groups { - if utils.CheckFilter(requiredGroups, strings.TrimSpace(userGroup)) { - tlog.App.Trace().Str("group", userGroup).Str("required", requiredGroups).Msg("User group matched") + if utils.CheckFilter(acls.LDAP.Groups, strings.TrimSpace(userGroup)) { + tlog.App.Trace().Str("group", userGroup).Str("required", acls.LDAP.Groups).Msg("User group matched") return true } } @@ -511,14 +511,14 @@ func (auth *AuthService) IsInLDAPGroup(c *gin.Context, context model.UserContext return false } -func (auth *AuthService) IsAuthEnabled(uri string, path *model.AppPath) (bool, error) { - if path == nil { +func (auth *AuthService) IsAuthEnabled(uri string, acls *model.App) (bool, error) { + if acls == nil { return true, nil } // Check for block list - if path.Block != "" { - regex, err := regexp.Compile(path.Block) + if acls.Path.Block != "" { + regex, err := regexp.Compile(acls.Path.Block) if err != nil { return true, err @@ -530,8 +530,8 @@ func (auth *AuthService) IsAuthEnabled(uri string, path *model.AppPath) (bool, e } // Check for allow list - if path.Allow != "" { - regex, err := regexp.Compile(path.Allow) + if acls.Path.Allow != "" { + regex, err := regexp.Compile(acls.Path.Allow) if err != nil { return true, err @@ -545,29 +545,14 @@ func (auth *AuthService) IsAuthEnabled(uri string, path *model.AppPath) (bool, e return true, nil } -// local user is used only as a medium to pass the basic auth credentials, user can be ldap too -func (auth *AuthService) GetBasicAuth(req *http.Request) (*model.LocalUser, error) { - if req == nil { - return nil, errors.New("request is nil") - } - username, password, ok := req.BasicAuth() - if !ok { - return nil, errors.New("no basic auth credentials provided") - } - return &model.LocalUser{ - Username: username, - Password: password, - }, nil -} - -func (auth *AuthService) CheckIP(acls *model.AppIP, ip string) bool { +func (auth *AuthService) CheckIP(ip string, acls *model.App) bool { if acls == nil { - acls = &model.AppIP{} + return true } // Merge the global and app IP filter - blockedIps := append(auth.config.IP.Block, acls.Block...) - allowedIPs := append(auth.config.IP.Allow, acls.Allow...) + blockedIps := append(auth.config.IP.Block, acls.IP.Block...) + allowedIPs := append(auth.config.IP.Allow, acls.IP.Allow...) for _, blocked := range blockedIps { res, err := utils.FilterIP(blocked, ip) @@ -602,12 +587,12 @@ func (auth *AuthService) CheckIP(acls *model.AppIP, ip string) bool { return true } -func (auth *AuthService) IsBypassedIP(acls *model.AppIP, ip string) bool { +func (auth *AuthService) IsBypassedIP(ip string, acls *model.App) bool { if acls == nil { return false } - for _, bypassed := range acls.Bypass { + for _, bypassed := range acls.IP.Bypass { res, err := utils.FilterIP(bypassed, ip) if err != nil { tlog.App.Warn().Err(err).Str("item", bypassed).Msg("Invalid IP/CIDR in bypass list") diff --git a/internal/service/docker_service.go b/internal/service/docker_service.go index f47cd10e..c5f95dd4 100644 --- a/internal/service/docker_service.go +++ b/internal/service/docker_service.go @@ -51,19 +51,11 @@ func (docker *DockerService) Init() error { } func (docker *DockerService) getContainers() ([]container.Summary, error) { - containers, err := docker.client.ContainerList(docker.context, container.ListOptions{}) - if err != nil { - return nil, err - } - return containers, nil + return docker.client.ContainerList(docker.context, container.ListOptions{}) } func (docker *DockerService) inspectContainer(containerId string) (container.InspectResponse, error) { - inspect, err := docker.client.ContainerInspect(docker.context, containerId) - if err != nil { - return container.InspectResponse{}, err - } - return inspect, nil + return docker.client.ContainerInspect(docker.context, containerId) } func (docker *DockerService) GetLabels(appDomain string) (*model.App, error) { diff --git a/internal/service/kubernetes_service.go b/internal/service/kubernetes_service.go index a3358ed6..11a60100 100644 --- a/internal/service/kubernetes_service.go +++ b/internal/service/kubernetes_service.go @@ -89,7 +89,7 @@ func (k *KubernetesService) removeIngress(namespace, name string) { } } -func (k *KubernetesService) getByDomain(domain string) (*model.App, bool) { +func (k *KubernetesService) getByDomain(domain string) *model.App { k.mu.RLock() defer k.mu.RUnlock() @@ -97,15 +97,15 @@ func (k *KubernetesService) getByDomain(domain string) (*model.App, bool) { if apps, ok := k.ingressApps[appKey.ingressKey]; ok { for _, app := range apps { if app.domain == domain && app.appName == appKey.appName { - return &app.app, true + return &app.app } } } } - return nil, false + return nil } -func (k *KubernetesService) getByAppName(appName string) (*model.App, bool) { +func (k *KubernetesService) getByAppName(appName string) *model.App { k.mu.RLock() defer k.mu.RUnlock() @@ -113,12 +113,12 @@ func (k *KubernetesService) getByAppName(appName string) (*model.App, bool) { if apps, ok := k.ingressApps[appKey.ingressKey]; ok { for _, app := range apps { if app.appName == appName { - return &app.app, true + return &app.app } } } } - return nil, false + return nil } func (k *KubernetesService) updateFromItem(item *unstructured.Unstructured) { @@ -287,12 +287,14 @@ func (k *KubernetesService) GetLabels(appDomain string) (*model.App, error) { } // First check cache - if app, found := k.getByDomain(appDomain); found { + app := k.getByDomain(appDomain) + if app != nil { tlog.App.Debug().Str("domain", appDomain).Msg("Found labels in cache by domain") return app, nil } appName := strings.SplitN(appDomain, ".", 2)[0] - if app, found := k.getByAppName(appName); found { + app = k.getByAppName(appName) + if app != nil { tlog.App.Debug().Str("domain", appDomain).Str("appName", appName).Msg("Found labels in cache by app name") return app, nil }