fix: own comments

This commit is contained in:
Stavros
2026-05-06 23:39:07 +03:00
parent c6d36673eb
commit e718471ad3
8 changed files with 18 additions and 21 deletions
+1 -1
View File
@@ -95,7 +95,7 @@ func (controller *ContextController) userContextHandler(c *gin.Context) {
Username: context.GetUsername(), Username: context.GetUsername(),
Name: context.GetName(), Name: context.GetName(),
Email: context.GetEmail(), Email: context.GetEmail(),
Provider: context.ProviderName(), Provider: context.GetProviderID(),
OAuth: context.IsOAuth(), OAuth: context.IsOAuth(),
TOTPPending: context.TOTPPending(), TOTPPending: context.TOTPPending(),
OAuthName: context.OAuthName(), OAuthName: context.OAuthName(),
@@ -98,7 +98,6 @@ func TestProxyController(t *testing.T) {
Name: "Totpuser", Name: "Totpuser",
Email: "totpuser@example.com", Email: "totpuser@example.com",
}, },
TOTPEnabled: true,
}, },
}) })
c.Next() c.Next()
+1 -1
View File
@@ -249,7 +249,7 @@ func (controller *UserController) logoutHandler(c *gin.Context) {
context, err := new(model.UserContext).NewFromGin(c) context, err := new(model.UserContext).NewFromGin(c)
if err == nil { if err == nil {
tlog.AuditLogout(c, context.GetUsername(), context.ProviderName()) tlog.AuditLogout(c, context.GetUsername(), context.GetProviderID())
} else { } else {
tlog.App.Warn().Err(err).Msg("Failed to get user context for logout audit, proceeding without username") tlog.App.Warn().Err(err).Msg("Failed to get user context for logout audit, proceeding without username")
tlog.AuditLogout(c, "unknown", "unknown") tlog.AuditLogout(c, "unknown", "unknown")
+7 -5
View File
@@ -78,7 +78,6 @@ func TestUserController(t *testing.T) {
Email: "totpuser@example.com", Email: "totpuser@example.com",
}, },
TOTPPending: true, TOTPPending: true,
TOTPEnabled: true,
}, },
}) })
} }
@@ -94,7 +93,6 @@ func TestUserController(t *testing.T) {
Email: "bob@example.com", Email: "bob@example.com",
}, },
TOTPPending: true, TOTPPending: true,
TOTPEnabled: true,
}, },
}) })
} }
@@ -152,7 +150,9 @@ func TestUserController(t *testing.T) {
assert.Equal(t, "tinyauth-session", cookie.Name) assert.Equal(t, "tinyauth-session", cookie.Name)
assert.True(t, cookie.HttpOnly) assert.True(t, cookie.HttpOnly)
assert.Equal(t, "example.com", cookie.Domain) assert.Equal(t, "example.com", cookie.Domain)
assert.Equal(t, 9, cookie.MaxAge) // 3 seconds should be more than enough for even slow test environments
assert.GreaterOrEqual(t, cookie.MaxAge, 7)
assert.LessOrEqual(t, cookie.MaxAge, 10)
}, },
}, },
{ {
@@ -241,7 +241,8 @@ func TestUserController(t *testing.T) {
assert.Equal(t, "tinyauth-session", cookie.Name) assert.Equal(t, "tinyauth-session", cookie.Name)
assert.True(t, cookie.HttpOnly) assert.True(t, cookie.HttpOnly)
assert.Equal(t, "example.com", cookie.Domain) assert.Equal(t, "example.com", cookie.Domain)
assert.Equal(t, 3599, cookie.MaxAge) // 1 hour, default for totp pending sessions assert.GreaterOrEqual(t, cookie.MaxAge, 3597)
assert.LessOrEqual(t, cookie.MaxAge, 3600)
}, },
}, },
{ {
@@ -335,7 +336,8 @@ func TestUserController(t *testing.T) {
assert.Equal(t, "tinyauth-session", totpCookie.Name) assert.Equal(t, "tinyauth-session", totpCookie.Name)
assert.True(t, totpCookie.HttpOnly) assert.True(t, totpCookie.HttpOnly)
assert.Equal(t, "example.com", totpCookie.Domain) assert.Equal(t, "example.com", totpCookie.Domain)
assert.Equal(t, 9, totpCookie.MaxAge) // should use the regular session expiry time assert.GreaterOrEqual(t, totpCookie.MaxAge, 7)
assert.LessOrEqual(t, totpCookie.MaxAge, 10)
}, },
}, },
{ {
@@ -123,7 +123,6 @@ func (m *ContextMiddleware) cookieAuth(ctx context.Context, uuid string) (*model
if userContext.Provider == model.ProviderLocal && if userContext.Provider == model.ProviderLocal &&
userContext.Local.TOTPPending { userContext.Local.TOTPPending {
userContext.Local.TOTPEnabled = true
return userContext, nil, nil return userContext, nil, nil
} }
@@ -109,7 +109,6 @@ func TestContextMiddleware(t *testing.T) {
assert.Equal(t, "testuser", userCtx.GetUsername()) assert.Equal(t, "testuser", userCtx.GetUsername())
assert.True(t, userCtx.Authenticated) assert.True(t, userCtx.Authenticated)
require.NotNil(t, userCtx.Local) require.NotNil(t, userCtx.Local)
assert.False(t, userCtx.Local.TOTPEnabled)
}, },
}, },
{ {
@@ -134,7 +133,6 @@ func TestContextMiddleware(t *testing.T) {
assert.False(t, userCtx.Authenticated) assert.False(t, userCtx.Authenticated)
require.NotNil(t, userCtx.Local) require.NotNil(t, userCtx.Local)
assert.True(t, userCtx.Local.TOTPPending) assert.True(t, userCtx.Local.TOTPPending)
assert.True(t, userCtx.Local.TOTPEnabled)
}, },
}, },
{ {
+2 -3
View File
@@ -34,7 +34,6 @@ type BaseContext struct {
type LocalContext struct { type LocalContext struct {
BaseContext BaseContext
TOTPPending bool TOTPPending bool
TOTPEnabled bool
Attributes UserAttributes Attributes UserAttributes
} }
@@ -223,14 +222,14 @@ func (c *UserContext) GetName() string {
} }
} }
func (c *UserContext) ProviderName() string { func (c *UserContext) GetProviderID() string {
switch c.Provider { switch c.Provider {
case ProviderBasicAuth, ProviderLocal: case ProviderBasicAuth, ProviderLocal:
return "local" return "local"
case ProviderLDAP: case ProviderLDAP:
return "ldap" return "ldap"
case ProviderOAuth: case ProviderOAuth:
return c.OAuth.DisplayName // compatability return c.OAuth.ID
default: default:
return "unknown" return "unknown"
} }
+7 -7
View File
@@ -153,29 +153,29 @@ func TestContext(t *testing.T) {
{ {
description: "ProviderName returns 'local' for ProviderLocal", description: "ProviderName returns 'local' for ProviderLocal",
context: &model.UserContext{Provider: model.ProviderLocal}, context: &model.UserContext{Provider: model.ProviderLocal},
run: func(t *testing.T, c *model.UserContext) any { return c.ProviderName() }, run: func(t *testing.T, c *model.UserContext) any { return c.GetProviderID() },
expected: "local", expected: "local",
}, },
{ {
description: "ProviderName returns 'local' for ProviderBasicAuth", description: "ProviderName returns 'local' for ProviderBasicAuth",
context: &model.UserContext{Provider: model.ProviderBasicAuth}, context: &model.UserContext{Provider: model.ProviderBasicAuth},
run: func(t *testing.T, c *model.UserContext) any { return c.ProviderName() }, run: func(t *testing.T, c *model.UserContext) any { return c.GetProviderID() },
expected: "local", expected: "local",
}, },
{ {
description: "ProviderName returns 'ldap' for ProviderLDAP", description: "ProviderName returns 'ldap' for ProviderLDAP",
context: &model.UserContext{Provider: model.ProviderLDAP}, context: &model.UserContext{Provider: model.ProviderLDAP},
run: func(t *testing.T, c *model.UserContext) any { return c.ProviderName() }, run: func(t *testing.T, c *model.UserContext) any { return c.GetProviderID() },
expected: "ldap", expected: "ldap",
}, },
{ {
description: "ProviderName returns OAuth DisplayName for ProviderOAuth", description: "ProviderName returns OAuth provider ID for ProviderOAuth",
context: &model.UserContext{ context: &model.UserContext{
Provider: model.ProviderOAuth, Provider: model.ProviderOAuth,
OAuth: &model.OAuthContext{DisplayName: "GitHub"}, OAuth: &model.OAuthContext{ID: "github"},
}, },
run: func(t *testing.T, c *model.UserContext) any { return c.ProviderName() }, run: func(t *testing.T, c *model.UserContext) any { return c.GetProviderID() },
expected: "GitHub", expected: "github",
}, },
{ {
description: "TOTPPending returns true when local context is pending", description: "TOTPPending returns true when local context is pending",