fix: review comments batch 2

This commit is contained in:
Stavros
2026-05-05 18:54:45 +03:00
parent d47e4d3d79
commit e04980468f
9 changed files with 78 additions and 23 deletions
+2 -2
View File
@@ -29,7 +29,7 @@ type BootstrapApp struct {
csrfCookieName string csrfCookieName string
redirectCookieName string redirectCookieName string
oauthSessionCookieName string oauthSessionCookieName string
localUsers []model.LocalUser localUsers *[]model.LocalUser
oauthProviders map[string]model.OAuthServiceConfig oauthProviders map[string]model.OAuthServiceConfig
configuredProviders []controller.Provider configuredProviders []controller.Provider
oidcClients []model.OIDCClientConfig oidcClients []model.OIDCClientConfig
@@ -69,7 +69,7 @@ func (app *BootstrapApp) Setup() error {
return err return err
} }
app.context.localUsers = *users app.context.localUsers = users
// Setup OAuth providers // Setup OAuth providers
app.context.oauthProviders = app.config.OAuth.Providers app.context.oauthProviders = app.config.OAuth.Providers
+1 -1
View File
@@ -21,7 +21,7 @@ func TestProxyController(t *testing.T) {
tempDir := t.TempDir() tempDir := t.TempDir()
authServiceCfg := service.AuthServiceConfig{ authServiceCfg := service.AuthServiceConfig{
LocalUsers: []model.LocalUser{ LocalUsers: &[]model.LocalUser{
{ {
Username: "testuser", Username: "testuser",
Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
+1 -1
View File
@@ -25,7 +25,7 @@ func TestUserController(t *testing.T) {
tempDir := t.TempDir() tempDir := t.TempDir()
authServiceCfg := service.AuthServiceConfig{ authServiceCfg := service.AuthServiceConfig{
LocalUsers: []model.LocalUser{ LocalUsers: &[]model.LocalUser{
{ {
Username: "testuser", Username: "testuser",
Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
@@ -25,7 +25,7 @@ func TestContextMiddleware(t *testing.T) {
tempDir := t.TempDir() tempDir := t.TempDir()
authServiceCfg := service.AuthServiceConfig{ authServiceCfg := service.AuthServiceConfig{
LocalUsers: []model.LocalUser{ LocalUsers: &[]model.LocalUser{
{ {
Username: "testuser", Username: "testuser",
Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
+42 -6
View File
@@ -56,19 +56,19 @@ func (c *UserContext) IsAuthenticated() bool {
} }
func (c *UserContext) IsLocal() bool { func (c *UserContext) IsLocal() bool {
return c.Provider == ProviderLocal return c.Provider == ProviderLocal && c.Local != nil
} }
func (c *UserContext) IsOAuth() bool { func (c *UserContext) IsOAuth() bool {
return c.Provider == ProviderOAuth return c.Provider == ProviderOAuth && c.OAuth != nil
} }
func (c *UserContext) IsLDAP() bool { func (c *UserContext) IsLDAP() bool {
return c.Provider == ProviderLDAP return c.Provider == ProviderLDAP && c.LDAP != nil
} }
func (c *UserContext) IsBasicAuth() bool { func (c *UserContext) IsBasicAuth() bool {
return c.Provider == ProviderBasicAuth return c.Provider == ProviderBasicAuth && c.Local != nil
} }
func (c *UserContext) NewFromGin(ginctx *gin.Context) (*UserContext, error) { func (c *UserContext) NewFromGin(ginctx *gin.Context) (*UserContext, error) {
@@ -145,12 +145,24 @@ func (c *UserContext) NewFromSession(session *repository.Session) (*UserContext,
func (c *UserContext) GetUsername() string { func (c *UserContext) GetUsername() string {
switch c.Provider { switch c.Provider {
case ProviderLocal: case ProviderLocal:
if c.Local == nil {
return ""
}
return c.Local.Username return c.Local.Username
case ProviderLDAP: case ProviderLDAP:
if c.LDAP == nil {
return ""
}
return c.LDAP.Username return c.LDAP.Username
case ProviderBasicAuth: case ProviderBasicAuth:
if c.Local == nil {
return ""
}
return c.Local.Username return c.Local.Username
case ProviderOAuth: case ProviderOAuth:
if c.OAuth == nil {
return ""
}
return c.OAuth.Username return c.OAuth.Username
default: default:
return "" return ""
@@ -160,12 +172,24 @@ func (c *UserContext) GetUsername() string {
func (c *UserContext) GetEmail() string { func (c *UserContext) GetEmail() string {
switch c.Provider { switch c.Provider {
case ProviderLocal: case ProviderLocal:
if c.Local == nil {
return ""
}
return c.Local.Email return c.Local.Email
case ProviderLDAP: case ProviderLDAP:
if c.LDAP == nil {
return ""
}
return c.LDAP.Email return c.LDAP.Email
case ProviderBasicAuth: case ProviderBasicAuth:
if c.Local == nil {
return ""
}
return c.Local.Email return c.Local.Email
case ProviderOAuth: case ProviderOAuth:
if c.OAuth == nil {
return ""
}
return c.OAuth.Email return c.OAuth.Email
default: default:
return "" return ""
@@ -175,12 +199,24 @@ func (c *UserContext) GetEmail() string {
func (c *UserContext) GetName() string { func (c *UserContext) GetName() string {
switch c.Provider { switch c.Provider {
case ProviderLocal: case ProviderLocal:
if c.Local == nil {
return ""
}
return c.Local.Name return c.Local.Name
case ProviderLDAP: case ProviderLDAP:
if c.LDAP == nil {
return ""
}
return c.LDAP.Name return c.LDAP.Name
case ProviderBasicAuth: case ProviderBasicAuth:
if c.Local == nil {
return ""
}
return c.Local.Name return c.Local.Name
case ProviderOAuth: case ProviderOAuth:
if c.OAuth == nil {
return ""
}
return c.OAuth.Name return c.OAuth.Name
default: default:
return "" return ""
@@ -201,14 +237,14 @@ func (c *UserContext) ProviderName() string {
} }
func (c *UserContext) TOTPPending() bool { func (c *UserContext) TOTPPending() bool {
if c.Provider == ProviderLocal { if c.Provider == ProviderLocal && c.Local != nil {
return c.Local.TOTPPending return c.Local.TOTPPending
} }
return false return false
} }
func (c *UserContext) OAuthName() string { func (c *UserContext) OAuthName() string {
if c.Provider == ProviderOAuth { if c.Provider == ProviderOAuth && c.OAuth != nil {
return c.OAuth.DisplayName return c.OAuth.DisplayName
} }
return "" return ""
+12 -4
View File
@@ -34,25 +34,25 @@ func TestContext(t *testing.T) {
}, },
{ {
description: "IsLocal returns true for ProviderLocal", description: "IsLocal returns true for ProviderLocal",
context: &model.UserContext{Provider: model.ProviderLocal}, context: &model.UserContext{Provider: model.ProviderLocal, Local: &model.LocalContext{}},
run: func(t *testing.T, c *model.UserContext) any { return c.IsLocal() }, run: func(t *testing.T, c *model.UserContext) any { return c.IsLocal() },
expected: true, expected: true,
}, },
{ {
description: "IsOAuth returns true for ProviderOAuth", description: "IsOAuth returns true for ProviderOAuth",
context: &model.UserContext{Provider: model.ProviderOAuth}, context: &model.UserContext{Provider: model.ProviderOAuth, OAuth: &model.OAuthContext{}},
run: func(t *testing.T, c *model.UserContext) any { return c.IsOAuth() }, run: func(t *testing.T, c *model.UserContext) any { return c.IsOAuth() },
expected: true, expected: true,
}, },
{ {
description: "IsLDAP returns true for ProviderLDAP", description: "IsLDAP returns true for ProviderLDAP",
context: &model.UserContext{Provider: model.ProviderLDAP}, context: &model.UserContext{Provider: model.ProviderLDAP, LDAP: &model.LDAPContext{}},
run: func(t *testing.T, c *model.UserContext) any { return c.IsLDAP() }, run: func(t *testing.T, c *model.UserContext) any { return c.IsLDAP() },
expected: true, expected: true,
}, },
{ {
description: "IsBasicAuth returns true for ProviderBasicAuth", description: "IsBasicAuth returns true for ProviderBasicAuth",
context: &model.UserContext{Provider: model.ProviderBasicAuth}, context: &model.UserContext{Provider: model.ProviderBasicAuth, Local: &model.LocalContext{}},
run: func(t *testing.T, c *model.UserContext) any { return c.IsBasicAuth() }, run: func(t *testing.T, c *model.UserContext) any { return c.IsBasicAuth() },
expected: true, expected: true,
}, },
@@ -258,6 +258,14 @@ func TestContext(t *testing.T) {
}, },
expected: "incomplete user context", expected: "incomplete user context",
}, },
{
description: "Getters should not panic if provider context is empty",
context: &model.UserContext{Provider: model.ProviderLocal},
run: func(t *testing.T, c *model.UserContext) any {
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
},
expected: [3]string{"", "", ""},
},
} }
for _, test := range tests { for _, test := range tests {
+6 -3
View File
@@ -28,18 +28,21 @@ func (acls *AccessControlsService) Init() error {
} }
func (acls *AccessControlsService) lookupStaticACLs(domain string) *model.App { func (acls *AccessControlsService) lookupStaticACLs(domain string) *model.App {
var appAcls *model.App
for app, config := range acls.static { for app, config := range acls.static {
if config.Config.Domain == domain { if config.Config.Domain == domain {
tlog.App.Debug().Str("name", app).Msg("Found matching container by domain") tlog.App.Debug().Str("name", app).Msg("Found matching container by domain")
return &config appAcls = &config
break // If we find a match by domain, we can stop searching
} }
if strings.SplitN(domain, ".", 2)[0] == app { if strings.SplitN(domain, ".", 2)[0] == app {
tlog.App.Debug().Str("name", app).Msg("Found matching container by app name") tlog.App.Debug().Str("name", app).Msg("Found matching container by app name")
return &config appAcls = &config
break // If we find a match by app name, we can stop searching
} }
} }
return nil return appAcls
} }
func (acls *AccessControlsService) GetAccessControls(domain string) (*model.App, error) { func (acls *AccessControlsService) GetAccessControls(domain string) (*model.App, error) {
+9 -3
View File
@@ -73,7 +73,7 @@ type Lockdown struct {
} }
type AuthServiceConfig struct { type AuthServiceConfig struct {
LocalUsers []model.LocalUser LocalUsers *[]model.LocalUser
OauthWhitelist []string OauthWhitelist []string
SessionExpiry int SessionExpiry int
SessionMaxLifetime int SessionMaxLifetime int
@@ -147,6 +147,9 @@ func (auth *AuthService) CheckUserPassword(search model.UserSearch, password str
switch search.Type { switch search.Type {
case model.UserLocal: case model.UserLocal:
user := auth.GetLocalUser(search.Username) user := auth.GetLocalUser(search.Username)
if user == nil {
return ErrUserNotFound
}
return bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)) return bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password))
case model.UserLDAP: case model.UserLDAP:
if auth.ldap.IsConfigured() { if auth.ldap.IsConfigured() {
@@ -169,7 +172,10 @@ func (auth *AuthService) CheckUserPassword(search model.UserSearch, password str
} }
func (auth *AuthService) GetLocalUser(username string) *model.LocalUser { func (auth *AuthService) GetLocalUser(username string) *model.LocalUser {
for _, user := range auth.config.LocalUsers { if auth.config.LocalUsers == nil {
return nil
}
for _, user := range *auth.config.LocalUsers {
if user.Username == username { if user.Username == username {
return &user return &user
} }
@@ -438,7 +444,7 @@ func (auth *AuthService) GetSession(ctx context.Context, uuid string) (*reposito
} }
func (auth *AuthService) LocalAuthConfigured() bool { func (auth *AuthService) LocalAuthConfigured() bool {
return len(auth.config.LocalUsers) > 0 return auth.config.LocalUsers != nil && len(*auth.config.LocalUsers) > 0
} }
func (auth *AuthService) LDAPAuthConfigured() bool { func (auth *AuthService) LDAPAuthConfigured() bool {
+4 -2
View File
@@ -95,7 +95,8 @@ func (k *KubernetesService) getByDomain(domain string) *model.App {
if appKey, ok := k.domainIndex[domain]; ok { if appKey, ok := k.domainIndex[domain]; ok {
if apps, ok := k.ingressApps[appKey.ingressKey]; ok { if apps, ok := k.ingressApps[appKey.ingressKey]; ok {
for _, app := range apps { for i := range apps {
app := &apps[i]
if app.domain == domain && app.appName == appKey.appName { if app.domain == domain && app.appName == appKey.appName {
return &app.app return &app.app
} }
@@ -111,7 +112,8 @@ func (k *KubernetesService) getByAppName(appName string) *model.App {
if appKey, ok := k.appNameIndex[appName]; ok { if appKey, ok := k.appNameIndex[appName]; ok {
if apps, ok := k.ingressApps[appKey.ingressKey]; ok { if apps, ok := k.ingressApps[appKey.ingressKey]; ok {
for _, app := range apps { for i := range apps {
app := &apps[i]
if app.appName == appName { if app.appName == appName {
return &app.app return &app.app
} }