fix: review comments batch 2

This commit is contained in:
Stavros
2026-05-05 18:54:45 +03:00
parent d47e4d3d79
commit e04980468f
9 changed files with 78 additions and 23 deletions
+42 -6
View File
@@ -56,19 +56,19 @@ func (c *UserContext) IsAuthenticated() bool {
}
func (c *UserContext) IsLocal() bool {
return c.Provider == ProviderLocal
return c.Provider == ProviderLocal && c.Local != nil
}
func (c *UserContext) IsOAuth() bool {
return c.Provider == ProviderOAuth
return c.Provider == ProviderOAuth && c.OAuth != nil
}
func (c *UserContext) IsLDAP() bool {
return c.Provider == ProviderLDAP
return c.Provider == ProviderLDAP && c.LDAP != nil
}
func (c *UserContext) IsBasicAuth() bool {
return c.Provider == ProviderBasicAuth
return c.Provider == ProviderBasicAuth && c.Local != nil
}
func (c *UserContext) NewFromGin(ginctx *gin.Context) (*UserContext, error) {
@@ -145,12 +145,24 @@ func (c *UserContext) NewFromSession(session *repository.Session) (*UserContext,
func (c *UserContext) GetUsername() string {
switch c.Provider {
case ProviderLocal:
if c.Local == nil {
return ""
}
return c.Local.Username
case ProviderLDAP:
if c.LDAP == nil {
return ""
}
return c.LDAP.Username
case ProviderBasicAuth:
if c.Local == nil {
return ""
}
return c.Local.Username
case ProviderOAuth:
if c.OAuth == nil {
return ""
}
return c.OAuth.Username
default:
return ""
@@ -160,12 +172,24 @@ func (c *UserContext) GetUsername() string {
func (c *UserContext) GetEmail() string {
switch c.Provider {
case ProviderLocal:
if c.Local == nil {
return ""
}
return c.Local.Email
case ProviderLDAP:
if c.LDAP == nil {
return ""
}
return c.LDAP.Email
case ProviderBasicAuth:
if c.Local == nil {
return ""
}
return c.Local.Email
case ProviderOAuth:
if c.OAuth == nil {
return ""
}
return c.OAuth.Email
default:
return ""
@@ -175,12 +199,24 @@ func (c *UserContext) GetEmail() string {
func (c *UserContext) GetName() string {
switch c.Provider {
case ProviderLocal:
if c.Local == nil {
return ""
}
return c.Local.Name
case ProviderLDAP:
if c.LDAP == nil {
return ""
}
return c.LDAP.Name
case ProviderBasicAuth:
if c.Local == nil {
return ""
}
return c.Local.Name
case ProviderOAuth:
if c.OAuth == nil {
return ""
}
return c.OAuth.Name
default:
return ""
@@ -201,14 +237,14 @@ func (c *UserContext) ProviderName() string {
}
func (c *UserContext) TOTPPending() bool {
if c.Provider == ProviderLocal {
if c.Provider == ProviderLocal && c.Local != nil {
return c.Local.TOTPPending
}
return false
}
func (c *UserContext) OAuthName() string {
if c.Provider == ProviderOAuth {
if c.Provider == ProviderOAuth && c.OAuth != nil {
return c.OAuth.DisplayName
}
return ""
+12 -4
View File
@@ -34,25 +34,25 @@ func TestContext(t *testing.T) {
},
{
description: "IsLocal returns true for ProviderLocal",
context: &model.UserContext{Provider: model.ProviderLocal},
context: &model.UserContext{Provider: model.ProviderLocal, Local: &model.LocalContext{}},
run: func(t *testing.T, c *model.UserContext) any { return c.IsLocal() },
expected: true,
},
{
description: "IsOAuth returns true for ProviderOAuth",
context: &model.UserContext{Provider: model.ProviderOAuth},
context: &model.UserContext{Provider: model.ProviderOAuth, OAuth: &model.OAuthContext{}},
run: func(t *testing.T, c *model.UserContext) any { return c.IsOAuth() },
expected: true,
},
{
description: "IsLDAP returns true for ProviderLDAP",
context: &model.UserContext{Provider: model.ProviderLDAP},
context: &model.UserContext{Provider: model.ProviderLDAP, LDAP: &model.LDAPContext{}},
run: func(t *testing.T, c *model.UserContext) any { return c.IsLDAP() },
expected: true,
},
{
description: "IsBasicAuth returns true for ProviderBasicAuth",
context: &model.UserContext{Provider: model.ProviderBasicAuth},
context: &model.UserContext{Provider: model.ProviderBasicAuth, Local: &model.LocalContext{}},
run: func(t *testing.T, c *model.UserContext) any { return c.IsBasicAuth() },
expected: true,
},
@@ -258,6 +258,14 @@ func TestContext(t *testing.T) {
},
expected: "incomplete user context",
},
{
description: "Getters should not panic if provider context is empty",
context: &model.UserContext{Provider: model.ProviderLocal},
run: func(t *testing.T, c *model.UserContext) any {
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
},
expected: [3]string{"", "", ""},
},
}
for _, test := range tests {