From e04980468ffa071feac75ec6cc9d47f42c94405f Mon Sep 17 00:00:00 2001 From: Stavros Date: Tue, 5 May 2026 18:54:45 +0300 Subject: [PATCH] fix: review comments batch 2 --- internal/bootstrap/app_bootstrap.go | 4 +- internal/controller/proxy_controller_test.go | 2 +- internal/controller/user_controller_test.go | 2 +- .../middleware/context_middleware_test.go | 2 +- internal/model/context.go | 48 ++++++++++++++++--- internal/model/context_test.go | 16 +++++-- internal/service/access_controls_service.go | 9 ++-- internal/service/auth_service.go | 12 +++-- internal/service/kubernetes_service.go | 6 ++- 9 files changed, 78 insertions(+), 23 deletions(-) diff --git a/internal/bootstrap/app_bootstrap.go b/internal/bootstrap/app_bootstrap.go index 252d75bf..fc86a7ab 100644 --- a/internal/bootstrap/app_bootstrap.go +++ b/internal/bootstrap/app_bootstrap.go @@ -29,7 +29,7 @@ type BootstrapApp struct { csrfCookieName string redirectCookieName string oauthSessionCookieName string - localUsers []model.LocalUser + localUsers *[]model.LocalUser oauthProviders map[string]model.OAuthServiceConfig configuredProviders []controller.Provider oidcClients []model.OIDCClientConfig @@ -69,7 +69,7 @@ func (app *BootstrapApp) Setup() error { return err } - app.context.localUsers = *users + app.context.localUsers = users // Setup OAuth providers app.context.oauthProviders = app.config.OAuth.Providers diff --git a/internal/controller/proxy_controller_test.go b/internal/controller/proxy_controller_test.go index 6e4e4c0e..66c24a5e 100644 --- a/internal/controller/proxy_controller_test.go +++ b/internal/controller/proxy_controller_test.go @@ -21,7 +21,7 @@ func TestProxyController(t *testing.T) { tempDir := t.TempDir() authServiceCfg := service.AuthServiceConfig{ - LocalUsers: []model.LocalUser{ + LocalUsers: &[]model.LocalUser{ { Username: "testuser", Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password diff --git a/internal/controller/user_controller_test.go b/internal/controller/user_controller_test.go index 1667036e..18544c43 100644 --- a/internal/controller/user_controller_test.go +++ b/internal/controller/user_controller_test.go @@ -25,7 +25,7 @@ func TestUserController(t *testing.T) { tempDir := t.TempDir() authServiceCfg := service.AuthServiceConfig{ - LocalUsers: []model.LocalUser{ + LocalUsers: &[]model.LocalUser{ { Username: "testuser", Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password diff --git a/internal/middleware/context_middleware_test.go b/internal/middleware/context_middleware_test.go index 4eac53ef..6e91a585 100644 --- a/internal/middleware/context_middleware_test.go +++ b/internal/middleware/context_middleware_test.go @@ -25,7 +25,7 @@ func TestContextMiddleware(t *testing.T) { tempDir := t.TempDir() authServiceCfg := service.AuthServiceConfig{ - LocalUsers: []model.LocalUser{ + LocalUsers: &[]model.LocalUser{ { Username: "testuser", Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password diff --git a/internal/model/context.go b/internal/model/context.go index b268f362..7df204de 100644 --- a/internal/model/context.go +++ b/internal/model/context.go @@ -56,19 +56,19 @@ func (c *UserContext) IsAuthenticated() bool { } func (c *UserContext) IsLocal() bool { - return c.Provider == ProviderLocal + return c.Provider == ProviderLocal && c.Local != nil } func (c *UserContext) IsOAuth() bool { - return c.Provider == ProviderOAuth + return c.Provider == ProviderOAuth && c.OAuth != nil } func (c *UserContext) IsLDAP() bool { - return c.Provider == ProviderLDAP + return c.Provider == ProviderLDAP && c.LDAP != nil } func (c *UserContext) IsBasicAuth() bool { - return c.Provider == ProviderBasicAuth + return c.Provider == ProviderBasicAuth && c.Local != nil } func (c *UserContext) NewFromGin(ginctx *gin.Context) (*UserContext, error) { @@ -145,12 +145,24 @@ func (c *UserContext) NewFromSession(session *repository.Session) (*UserContext, func (c *UserContext) GetUsername() string { switch c.Provider { case ProviderLocal: + if c.Local == nil { + return "" + } return c.Local.Username case ProviderLDAP: + if c.LDAP == nil { + return "" + } return c.LDAP.Username case ProviderBasicAuth: + if c.Local == nil { + return "" + } return c.Local.Username case ProviderOAuth: + if c.OAuth == nil { + return "" + } return c.OAuth.Username default: return "" @@ -160,12 +172,24 @@ func (c *UserContext) GetUsername() string { func (c *UserContext) GetEmail() string { switch c.Provider { case ProviderLocal: + if c.Local == nil { + return "" + } return c.Local.Email case ProviderLDAP: + if c.LDAP == nil { + return "" + } return c.LDAP.Email case ProviderBasicAuth: + if c.Local == nil { + return "" + } return c.Local.Email case ProviderOAuth: + if c.OAuth == nil { + return "" + } return c.OAuth.Email default: return "" @@ -175,12 +199,24 @@ func (c *UserContext) GetEmail() string { func (c *UserContext) GetName() string { switch c.Provider { case ProviderLocal: + if c.Local == nil { + return "" + } return c.Local.Name case ProviderLDAP: + if c.LDAP == nil { + return "" + } return c.LDAP.Name case ProviderBasicAuth: + if c.Local == nil { + return "" + } return c.Local.Name case ProviderOAuth: + if c.OAuth == nil { + return "" + } return c.OAuth.Name default: return "" @@ -201,14 +237,14 @@ func (c *UserContext) ProviderName() string { } func (c *UserContext) TOTPPending() bool { - if c.Provider == ProviderLocal { + if c.Provider == ProviderLocal && c.Local != nil { return c.Local.TOTPPending } return false } func (c *UserContext) OAuthName() string { - if c.Provider == ProviderOAuth { + if c.Provider == ProviderOAuth && c.OAuth != nil { return c.OAuth.DisplayName } return "" diff --git a/internal/model/context_test.go b/internal/model/context_test.go index ad370af3..b45b9210 100644 --- a/internal/model/context_test.go +++ b/internal/model/context_test.go @@ -34,25 +34,25 @@ func TestContext(t *testing.T) { }, { description: "IsLocal returns true for ProviderLocal", - context: &model.UserContext{Provider: model.ProviderLocal}, + context: &model.UserContext{Provider: model.ProviderLocal, Local: &model.LocalContext{}}, run: func(t *testing.T, c *model.UserContext) any { return c.IsLocal() }, expected: true, }, { description: "IsOAuth returns true for ProviderOAuth", - context: &model.UserContext{Provider: model.ProviderOAuth}, + context: &model.UserContext{Provider: model.ProviderOAuth, OAuth: &model.OAuthContext{}}, run: func(t *testing.T, c *model.UserContext) any { return c.IsOAuth() }, expected: true, }, { description: "IsLDAP returns true for ProviderLDAP", - context: &model.UserContext{Provider: model.ProviderLDAP}, + context: &model.UserContext{Provider: model.ProviderLDAP, LDAP: &model.LDAPContext{}}, run: func(t *testing.T, c *model.UserContext) any { return c.IsLDAP() }, expected: true, }, { description: "IsBasicAuth returns true for ProviderBasicAuth", - context: &model.UserContext{Provider: model.ProviderBasicAuth}, + context: &model.UserContext{Provider: model.ProviderBasicAuth, Local: &model.LocalContext{}}, run: func(t *testing.T, c *model.UserContext) any { return c.IsBasicAuth() }, expected: true, }, @@ -258,6 +258,14 @@ func TestContext(t *testing.T) { }, expected: "incomplete user context", }, + { + description: "Getters should not panic if provider context is empty", + context: &model.UserContext{Provider: model.ProviderLocal}, + run: func(t *testing.T, c *model.UserContext) any { + return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()} + }, + expected: [3]string{"", "", ""}, + }, } for _, test := range tests { diff --git a/internal/service/access_controls_service.go b/internal/service/access_controls_service.go index dedd6dd3..fd57bf39 100644 --- a/internal/service/access_controls_service.go +++ b/internal/service/access_controls_service.go @@ -28,18 +28,21 @@ func (acls *AccessControlsService) Init() error { } func (acls *AccessControlsService) lookupStaticACLs(domain string) *model.App { + var appAcls *model.App for app, config := range acls.static { if config.Config.Domain == domain { tlog.App.Debug().Str("name", app).Msg("Found matching container by domain") - return &config + appAcls = &config + break // If we find a match by domain, we can stop searching } if strings.SplitN(domain, ".", 2)[0] == app { tlog.App.Debug().Str("name", app).Msg("Found matching container by app name") - return &config + appAcls = &config + break // If we find a match by app name, we can stop searching } } - return nil + return appAcls } func (acls *AccessControlsService) GetAccessControls(domain string) (*model.App, error) { diff --git a/internal/service/auth_service.go b/internal/service/auth_service.go index ca64dbdf..cad25608 100644 --- a/internal/service/auth_service.go +++ b/internal/service/auth_service.go @@ -73,7 +73,7 @@ type Lockdown struct { } type AuthServiceConfig struct { - LocalUsers []model.LocalUser + LocalUsers *[]model.LocalUser OauthWhitelist []string SessionExpiry int SessionMaxLifetime int @@ -147,6 +147,9 @@ func (auth *AuthService) CheckUserPassword(search model.UserSearch, password str switch search.Type { case model.UserLocal: user := auth.GetLocalUser(search.Username) + if user == nil { + return ErrUserNotFound + } return bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)) case model.UserLDAP: if auth.ldap.IsConfigured() { @@ -169,7 +172,10 @@ func (auth *AuthService) CheckUserPassword(search model.UserSearch, password str } func (auth *AuthService) GetLocalUser(username string) *model.LocalUser { - for _, user := range auth.config.LocalUsers { + if auth.config.LocalUsers == nil { + return nil + } + for _, user := range *auth.config.LocalUsers { if user.Username == username { return &user } @@ -438,7 +444,7 @@ func (auth *AuthService) GetSession(ctx context.Context, uuid string) (*reposito } func (auth *AuthService) LocalAuthConfigured() bool { - return len(auth.config.LocalUsers) > 0 + return auth.config.LocalUsers != nil && len(*auth.config.LocalUsers) > 0 } func (auth *AuthService) LDAPAuthConfigured() bool { diff --git a/internal/service/kubernetes_service.go b/internal/service/kubernetes_service.go index 11a60100..9c5ad427 100644 --- a/internal/service/kubernetes_service.go +++ b/internal/service/kubernetes_service.go @@ -95,7 +95,8 @@ func (k *KubernetesService) getByDomain(domain string) *model.App { if appKey, ok := k.domainIndex[domain]; ok { if apps, ok := k.ingressApps[appKey.ingressKey]; ok { - for _, app := range apps { + for i := range apps { + app := &apps[i] if app.domain == domain && app.appName == appKey.appName { return &app.app } @@ -111,7 +112,8 @@ func (k *KubernetesService) getByAppName(appName string) *model.App { if appKey, ok := k.appNameIndex[appName]; ok { if apps, ok := k.ingressApps[appKey.ingressKey]; ok { - for _, app := range apps { + for i := range apps { + app := &apps[i] if app.appName == appName { return &app.app }