fix: review comments batch 2

This commit is contained in:
Stavros
2026-05-05 18:54:45 +03:00
parent d47e4d3d79
commit e04980468f
9 changed files with 78 additions and 23 deletions
+2 -2
View File
@@ -29,7 +29,7 @@ type BootstrapApp struct {
csrfCookieName string
redirectCookieName string
oauthSessionCookieName string
localUsers []model.LocalUser
localUsers *[]model.LocalUser
oauthProviders map[string]model.OAuthServiceConfig
configuredProviders []controller.Provider
oidcClients []model.OIDCClientConfig
@@ -69,7 +69,7 @@ func (app *BootstrapApp) Setup() error {
return err
}
app.context.localUsers = *users
app.context.localUsers = users
// Setup OAuth providers
app.context.oauthProviders = app.config.OAuth.Providers
+1 -1
View File
@@ -21,7 +21,7 @@ func TestProxyController(t *testing.T) {
tempDir := t.TempDir()
authServiceCfg := service.AuthServiceConfig{
LocalUsers: []model.LocalUser{
LocalUsers: &[]model.LocalUser{
{
Username: "testuser",
Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
+1 -1
View File
@@ -25,7 +25,7 @@ func TestUserController(t *testing.T) {
tempDir := t.TempDir()
authServiceCfg := service.AuthServiceConfig{
LocalUsers: []model.LocalUser{
LocalUsers: &[]model.LocalUser{
{
Username: "testuser",
Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
@@ -25,7 +25,7 @@ func TestContextMiddleware(t *testing.T) {
tempDir := t.TempDir()
authServiceCfg := service.AuthServiceConfig{
LocalUsers: []model.LocalUser{
LocalUsers: &[]model.LocalUser{
{
Username: "testuser",
Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
+42 -6
View File
@@ -56,19 +56,19 @@ func (c *UserContext) IsAuthenticated() bool {
}
func (c *UserContext) IsLocal() bool {
return c.Provider == ProviderLocal
return c.Provider == ProviderLocal && c.Local != nil
}
func (c *UserContext) IsOAuth() bool {
return c.Provider == ProviderOAuth
return c.Provider == ProviderOAuth && c.OAuth != nil
}
func (c *UserContext) IsLDAP() bool {
return c.Provider == ProviderLDAP
return c.Provider == ProviderLDAP && c.LDAP != nil
}
func (c *UserContext) IsBasicAuth() bool {
return c.Provider == ProviderBasicAuth
return c.Provider == ProviderBasicAuth && c.Local != nil
}
func (c *UserContext) NewFromGin(ginctx *gin.Context) (*UserContext, error) {
@@ -145,12 +145,24 @@ func (c *UserContext) NewFromSession(session *repository.Session) (*UserContext,
func (c *UserContext) GetUsername() string {
switch c.Provider {
case ProviderLocal:
if c.Local == nil {
return ""
}
return c.Local.Username
case ProviderLDAP:
if c.LDAP == nil {
return ""
}
return c.LDAP.Username
case ProviderBasicAuth:
if c.Local == nil {
return ""
}
return c.Local.Username
case ProviderOAuth:
if c.OAuth == nil {
return ""
}
return c.OAuth.Username
default:
return ""
@@ -160,12 +172,24 @@ func (c *UserContext) GetUsername() string {
func (c *UserContext) GetEmail() string {
switch c.Provider {
case ProviderLocal:
if c.Local == nil {
return ""
}
return c.Local.Email
case ProviderLDAP:
if c.LDAP == nil {
return ""
}
return c.LDAP.Email
case ProviderBasicAuth:
if c.Local == nil {
return ""
}
return c.Local.Email
case ProviderOAuth:
if c.OAuth == nil {
return ""
}
return c.OAuth.Email
default:
return ""
@@ -175,12 +199,24 @@ func (c *UserContext) GetEmail() string {
func (c *UserContext) GetName() string {
switch c.Provider {
case ProviderLocal:
if c.Local == nil {
return ""
}
return c.Local.Name
case ProviderLDAP:
if c.LDAP == nil {
return ""
}
return c.LDAP.Name
case ProviderBasicAuth:
if c.Local == nil {
return ""
}
return c.Local.Name
case ProviderOAuth:
if c.OAuth == nil {
return ""
}
return c.OAuth.Name
default:
return ""
@@ -201,14 +237,14 @@ func (c *UserContext) ProviderName() string {
}
func (c *UserContext) TOTPPending() bool {
if c.Provider == ProviderLocal {
if c.Provider == ProviderLocal && c.Local != nil {
return c.Local.TOTPPending
}
return false
}
func (c *UserContext) OAuthName() string {
if c.Provider == ProviderOAuth {
if c.Provider == ProviderOAuth && c.OAuth != nil {
return c.OAuth.DisplayName
}
return ""
+12 -4
View File
@@ -34,25 +34,25 @@ func TestContext(t *testing.T) {
},
{
description: "IsLocal returns true for ProviderLocal",
context: &model.UserContext{Provider: model.ProviderLocal},
context: &model.UserContext{Provider: model.ProviderLocal, Local: &model.LocalContext{}},
run: func(t *testing.T, c *model.UserContext) any { return c.IsLocal() },
expected: true,
},
{
description: "IsOAuth returns true for ProviderOAuth",
context: &model.UserContext{Provider: model.ProviderOAuth},
context: &model.UserContext{Provider: model.ProviderOAuth, OAuth: &model.OAuthContext{}},
run: func(t *testing.T, c *model.UserContext) any { return c.IsOAuth() },
expected: true,
},
{
description: "IsLDAP returns true for ProviderLDAP",
context: &model.UserContext{Provider: model.ProviderLDAP},
context: &model.UserContext{Provider: model.ProviderLDAP, LDAP: &model.LDAPContext{}},
run: func(t *testing.T, c *model.UserContext) any { return c.IsLDAP() },
expected: true,
},
{
description: "IsBasicAuth returns true for ProviderBasicAuth",
context: &model.UserContext{Provider: model.ProviderBasicAuth},
context: &model.UserContext{Provider: model.ProviderBasicAuth, Local: &model.LocalContext{}},
run: func(t *testing.T, c *model.UserContext) any { return c.IsBasicAuth() },
expected: true,
},
@@ -258,6 +258,14 @@ func TestContext(t *testing.T) {
},
expected: "incomplete user context",
},
{
description: "Getters should not panic if provider context is empty",
context: &model.UserContext{Provider: model.ProviderLocal},
run: func(t *testing.T, c *model.UserContext) any {
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
},
expected: [3]string{"", "", ""},
},
}
for _, test := range tests {
+6 -3
View File
@@ -28,18 +28,21 @@ func (acls *AccessControlsService) Init() error {
}
func (acls *AccessControlsService) lookupStaticACLs(domain string) *model.App {
var appAcls *model.App
for app, config := range acls.static {
if config.Config.Domain == domain {
tlog.App.Debug().Str("name", app).Msg("Found matching container by domain")
return &config
appAcls = &config
break // If we find a match by domain, we can stop searching
}
if strings.SplitN(domain, ".", 2)[0] == app {
tlog.App.Debug().Str("name", app).Msg("Found matching container by app name")
return &config
appAcls = &config
break // If we find a match by app name, we can stop searching
}
}
return nil
return appAcls
}
func (acls *AccessControlsService) GetAccessControls(domain string) (*model.App, error) {
+9 -3
View File
@@ -73,7 +73,7 @@ type Lockdown struct {
}
type AuthServiceConfig struct {
LocalUsers []model.LocalUser
LocalUsers *[]model.LocalUser
OauthWhitelist []string
SessionExpiry int
SessionMaxLifetime int
@@ -147,6 +147,9 @@ func (auth *AuthService) CheckUserPassword(search model.UserSearch, password str
switch search.Type {
case model.UserLocal:
user := auth.GetLocalUser(search.Username)
if user == nil {
return ErrUserNotFound
}
return bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password))
case model.UserLDAP:
if auth.ldap.IsConfigured() {
@@ -169,7 +172,10 @@ func (auth *AuthService) CheckUserPassword(search model.UserSearch, password str
}
func (auth *AuthService) GetLocalUser(username string) *model.LocalUser {
for _, user := range auth.config.LocalUsers {
if auth.config.LocalUsers == nil {
return nil
}
for _, user := range *auth.config.LocalUsers {
if user.Username == username {
return &user
}
@@ -438,7 +444,7 @@ func (auth *AuthService) GetSession(ctx context.Context, uuid string) (*reposito
}
func (auth *AuthService) LocalAuthConfigured() bool {
return len(auth.config.LocalUsers) > 0
return auth.config.LocalUsers != nil && len(*auth.config.LocalUsers) > 0
}
func (auth *AuthService) LDAPAuthConfigured() bool {
+4 -2
View File
@@ -95,7 +95,8 @@ func (k *KubernetesService) getByDomain(domain string) *model.App {
if appKey, ok := k.domainIndex[domain]; ok {
if apps, ok := k.ingressApps[appKey.ingressKey]; ok {
for _, app := range apps {
for i := range apps {
app := &apps[i]
if app.domain == domain && app.appName == appKey.appName {
return &app.app
}
@@ -111,7 +112,8 @@ func (k *KubernetesService) getByAppName(appName string) *model.App {
if appKey, ok := k.appNameIndex[appName]; ok {
if apps, ok := k.ingressApps[appKey.ingressKey]; ok {
for _, app := range apps {
for i := range apps {
app := &apps[i]
if app.appName == appName {
return &app.app
}