diff --git a/internal/controller/context_controller.go b/internal/controller/context_controller.go index 3362d0de..f939ba55 100644 --- a/internal/controller/context_controller.go +++ b/internal/controller/context_controller.go @@ -95,7 +95,7 @@ func (controller *ContextController) userContextHandler(c *gin.Context) { Username: context.GetUsername(), Name: context.GetName(), Email: context.GetEmail(), - Provider: context.ProviderName(), + Provider: context.GetProviderID(), OAuth: context.IsOAuth(), TOTPPending: context.TOTPPending(), OAuthName: context.OAuthName(), diff --git a/internal/controller/proxy_controller_test.go b/internal/controller/proxy_controller_test.go index 66c24a5e..7b2e3202 100644 --- a/internal/controller/proxy_controller_test.go +++ b/internal/controller/proxy_controller_test.go @@ -98,7 +98,6 @@ func TestProxyController(t *testing.T) { Name: "Totpuser", Email: "totpuser@example.com", }, - TOTPEnabled: true, }, }) c.Next() diff --git a/internal/controller/user_controller.go b/internal/controller/user_controller.go index 6599a965..cb6d5e6f 100644 --- a/internal/controller/user_controller.go +++ b/internal/controller/user_controller.go @@ -249,7 +249,7 @@ func (controller *UserController) logoutHandler(c *gin.Context) { context, err := new(model.UserContext).NewFromGin(c) if err == nil { - tlog.AuditLogout(c, context.GetUsername(), context.ProviderName()) + tlog.AuditLogout(c, context.GetUsername(), context.GetProviderID()) } else { tlog.App.Warn().Err(err).Msg("Failed to get user context for logout audit, proceeding without username") tlog.AuditLogout(c, "unknown", "unknown") diff --git a/internal/controller/user_controller_test.go b/internal/controller/user_controller_test.go index a39b71e0..4863c16e 100644 --- a/internal/controller/user_controller_test.go +++ b/internal/controller/user_controller_test.go @@ -78,7 +78,6 @@ func TestUserController(t *testing.T) { Email: "totpuser@example.com", }, TOTPPending: true, - TOTPEnabled: true, }, }) } @@ -94,7 +93,6 @@ func TestUserController(t *testing.T) { Email: "bob@example.com", }, TOTPPending: true, - TOTPEnabled: true, }, }) } @@ -152,7 +150,9 @@ func TestUserController(t *testing.T) { assert.Equal(t, "tinyauth-session", cookie.Name) assert.True(t, cookie.HttpOnly) assert.Equal(t, "example.com", cookie.Domain) - assert.Equal(t, 9, cookie.MaxAge) + // 3 seconds should be more than enough for even slow test environments + assert.GreaterOrEqual(t, cookie.MaxAge, 7) + assert.LessOrEqual(t, cookie.MaxAge, 10) }, }, { @@ -241,7 +241,8 @@ func TestUserController(t *testing.T) { assert.Equal(t, "tinyauth-session", cookie.Name) assert.True(t, cookie.HttpOnly) assert.Equal(t, "example.com", cookie.Domain) - assert.Equal(t, 3599, cookie.MaxAge) // 1 hour, default for totp pending sessions + assert.GreaterOrEqual(t, cookie.MaxAge, 3597) + assert.LessOrEqual(t, cookie.MaxAge, 3600) }, }, { @@ -335,7 +336,8 @@ func TestUserController(t *testing.T) { assert.Equal(t, "tinyauth-session", totpCookie.Name) assert.True(t, totpCookie.HttpOnly) assert.Equal(t, "example.com", totpCookie.Domain) - assert.Equal(t, 9, totpCookie.MaxAge) // should use the regular session expiry time + assert.GreaterOrEqual(t, totpCookie.MaxAge, 7) + assert.LessOrEqual(t, totpCookie.MaxAge, 10) }, }, { diff --git a/internal/middleware/context_middleware.go b/internal/middleware/context_middleware.go index a5773dbd..88e96462 100644 --- a/internal/middleware/context_middleware.go +++ b/internal/middleware/context_middleware.go @@ -123,7 +123,6 @@ func (m *ContextMiddleware) cookieAuth(ctx context.Context, uuid string) (*model if userContext.Provider == model.ProviderLocal && userContext.Local.TOTPPending { - userContext.Local.TOTPEnabled = true return userContext, nil, nil } diff --git a/internal/middleware/context_middleware_test.go b/internal/middleware/context_middleware_test.go index 6e91a585..5dfde3b4 100644 --- a/internal/middleware/context_middleware_test.go +++ b/internal/middleware/context_middleware_test.go @@ -109,7 +109,6 @@ func TestContextMiddleware(t *testing.T) { assert.Equal(t, "testuser", userCtx.GetUsername()) assert.True(t, userCtx.Authenticated) require.NotNil(t, userCtx.Local) - assert.False(t, userCtx.Local.TOTPEnabled) }, }, { @@ -134,7 +133,6 @@ func TestContextMiddleware(t *testing.T) { assert.False(t, userCtx.Authenticated) require.NotNil(t, userCtx.Local) assert.True(t, userCtx.Local.TOTPPending) - assert.True(t, userCtx.Local.TOTPEnabled) }, }, { diff --git a/internal/model/context.go b/internal/model/context.go index 7df204de..7384ebe8 100644 --- a/internal/model/context.go +++ b/internal/model/context.go @@ -34,7 +34,6 @@ type BaseContext struct { type LocalContext struct { BaseContext TOTPPending bool - TOTPEnabled bool Attributes UserAttributes } @@ -223,14 +222,14 @@ func (c *UserContext) GetName() string { } } -func (c *UserContext) ProviderName() string { +func (c *UserContext) GetProviderID() string { switch c.Provider { case ProviderBasicAuth, ProviderLocal: return "local" case ProviderLDAP: return "ldap" case ProviderOAuth: - return c.OAuth.DisplayName // compatability + return c.OAuth.ID default: return "unknown" } diff --git a/internal/model/context_test.go b/internal/model/context_test.go index b45b9210..733805a7 100644 --- a/internal/model/context_test.go +++ b/internal/model/context_test.go @@ -153,29 +153,29 @@ func TestContext(t *testing.T) { { description: "ProviderName returns 'local' for ProviderLocal", context: &model.UserContext{Provider: model.ProviderLocal}, - run: func(t *testing.T, c *model.UserContext) any { return c.ProviderName() }, + run: func(t *testing.T, c *model.UserContext) any { return c.GetProviderID() }, expected: "local", }, { description: "ProviderName returns 'local' for ProviderBasicAuth", context: &model.UserContext{Provider: model.ProviderBasicAuth}, - run: func(t *testing.T, c *model.UserContext) any { return c.ProviderName() }, + run: func(t *testing.T, c *model.UserContext) any { return c.GetProviderID() }, expected: "local", }, { description: "ProviderName returns 'ldap' for ProviderLDAP", context: &model.UserContext{Provider: model.ProviderLDAP}, - run: func(t *testing.T, c *model.UserContext) any { return c.ProviderName() }, + run: func(t *testing.T, c *model.UserContext) any { return c.GetProviderID() }, expected: "ldap", }, { - description: "ProviderName returns OAuth DisplayName for ProviderOAuth", + description: "ProviderName returns OAuth provider ID for ProviderOAuth", context: &model.UserContext{ Provider: model.ProviderOAuth, - OAuth: &model.OAuthContext{DisplayName: "GitHub"}, + OAuth: &model.OAuthContext{ID: "github"}, }, - run: func(t *testing.T, c *model.UserContext) any { return c.ProviderName() }, - expected: "GitHub", + run: func(t *testing.T, c *model.UserContext) any { return c.GetProviderID() }, + expected: "github", }, { description: "TOTPPending returns true when local context is pending",