fix: review comments batch 1

This commit is contained in:
Stavros
2026-05-05 18:43:22 +03:00
parent f3965a7470
commit d47e4d3d79
10 changed files with 131 additions and 88 deletions
+46 -34
View File
@@ -6,6 +6,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository"
)
@@ -22,47 +23,48 @@ func TestContext(t *testing.T) {
tests := []struct {
description string
context *model.UserContext
run func(*model.UserContext) any
run func(*testing.T, *model.UserContext) any
expected any
}{
{
description: "IsAuthenticated reflects Authenticated field",
context: &model.UserContext{Authenticated: true},
run: func(c *model.UserContext) any { return c.IsAuthenticated() },
run: func(t *testing.T, c *model.UserContext) any { return c.IsAuthenticated() },
expected: true,
},
{
description: "IsLocal returns true for ProviderLocal",
context: &model.UserContext{Provider: model.ProviderLocal},
run: func(c *model.UserContext) any { return c.IsLocal() },
run: func(t *testing.T, c *model.UserContext) any { return c.IsLocal() },
expected: true,
},
{
description: "IsOAuth returns true for ProviderOAuth",
context: &model.UserContext{Provider: model.ProviderOAuth},
run: func(c *model.UserContext) any { return c.IsOAuth() },
run: func(t *testing.T, c *model.UserContext) any { return c.IsOAuth() },
expected: true,
},
{
description: "IsLDAP returns true for ProviderLDAP",
context: &model.UserContext{Provider: model.ProviderLDAP},
run: func(c *model.UserContext) any { return c.IsLDAP() },
run: func(t *testing.T, c *model.UserContext) any { return c.IsLDAP() },
expected: true,
},
{
description: "IsBasicAuth returns true for ProviderBasicAuth",
context: &model.UserContext{Provider: model.ProviderBasicAuth},
run: func(c *model.UserContext) any { return c.IsBasicAuth() },
run: func(t *testing.T, c *model.UserContext) any { return c.IsBasicAuth() },
expected: true,
},
{
description: "NewFromSession local session is authenticated and ProviderLocal",
context: &model.UserContext{},
run: func(c *model.UserContext) any {
got, _ := c.NewFromSession(&repository.Session{
run: func(t *testing.T, c *model.UserContext) any {
got, err := c.NewFromSession(&repository.Session{
Username: "alice", Email: "alice@example.com", Name: "Alice",
Provider: "local",
})
require.NoError(t, err)
return [2]any{got.Provider, got.Authenticated}
},
expected: [2]any{model.ProviderLocal, true},
@@ -70,10 +72,11 @@ func TestContext(t *testing.T) {
{
description: "NewFromSession local session with TotpPending is not authenticated",
context: &model.UserContext{},
run: func(c *model.UserContext) any {
got, _ := c.NewFromSession(&repository.Session{
run: func(t *testing.T, c *model.UserContext) any {
got, err := c.NewFromSession(&repository.Session{
Username: "bob", Provider: "local", TotpPending: true,
})
require.NoError(t, err)
return got.Authenticated
},
expected: false,
@@ -81,10 +84,11 @@ func TestContext(t *testing.T) {
{
description: "NewFromSession ldap session is ProviderLDAP",
context: &model.UserContext{},
run: func(c *model.UserContext) any {
got, _ := c.NewFromSession(&repository.Session{
run: func(t *testing.T, c *model.UserContext) any {
got, err := c.NewFromSession(&repository.Session{
Username: "carol", Provider: "ldap",
})
require.NoError(t, err)
return got.Provider
},
expected: model.ProviderLDAP,
@@ -92,11 +96,12 @@ func TestContext(t *testing.T) {
{
description: "NewFromSession unknown provider defaults to OAuth and populates oauth fields",
context: &model.UserContext{},
run: func(c *model.UserContext) any {
got, _ := c.NewFromSession(&repository.Session{
run: func(t *testing.T, c *model.UserContext) any {
got, err := c.NewFromSession(&repository.Session{
Username: "dave", Provider: "github",
OAuthGroups: "devs,admins", OAuthSub: "sub-123", OAuthName: "GitHub",
})
require.NoError(t, err)
return [5]any{got.Provider, got.OAuth.ID, got.OAuth.Sub, got.OAuth.DisplayName, got.OAuth.Groups}
},
expected: [5]any{model.ProviderOAuth, "github", "sub-123", "GitHub", []string{"devs", "admins"}},
@@ -107,7 +112,7 @@ func TestContext(t *testing.T) {
Provider: model.ProviderLocal,
Local: &model.LocalContext{BaseContext: model.BaseContext{Username: "alice", Email: "alice@example.com", Name: "Alice"}},
},
run: func(c *model.UserContext) any {
run: func(t *testing.T, c *model.UserContext) any {
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
},
expected: [3]string{"alice", "alice@example.com", "Alice"},
@@ -118,7 +123,7 @@ func TestContext(t *testing.T) {
Provider: model.ProviderBasicAuth,
Local: &model.LocalContext{BaseContext: model.BaseContext{Username: "bob", Email: "bob@example.com", Name: "Bob"}},
},
run: func(c *model.UserContext) any {
run: func(t *testing.T, c *model.UserContext) any {
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
},
expected: [3]string{"bob", "bob@example.com", "Bob"},
@@ -129,7 +134,7 @@ func TestContext(t *testing.T) {
Provider: model.ProviderLDAP,
LDAP: &model.LDAPContext{BaseContext: model.BaseContext{Username: "carol", Email: "carol@example.com", Name: "Carol"}},
},
run: func(c *model.UserContext) any {
run: func(t *testing.T, c *model.UserContext) any {
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
},
expected: [3]string{"carol", "carol@example.com", "Carol"},
@@ -140,7 +145,7 @@ func TestContext(t *testing.T) {
Provider: model.ProviderOAuth,
OAuth: &model.OAuthContext{BaseContext: model.BaseContext{Username: "dave", Email: "dave@example.com", Name: "Dave"}},
},
run: func(c *model.UserContext) any {
run: func(t *testing.T, c *model.UserContext) any {
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
},
expected: [3]string{"dave", "dave@example.com", "Dave"},
@@ -148,19 +153,19 @@ func TestContext(t *testing.T) {
{
description: "ProviderName returns 'local' for ProviderLocal",
context: &model.UserContext{Provider: model.ProviderLocal},
run: func(c *model.UserContext) any { return c.ProviderName() },
run: func(t *testing.T, c *model.UserContext) any { return c.ProviderName() },
expected: "local",
},
{
description: "ProviderName returns 'local' for ProviderBasicAuth",
context: &model.UserContext{Provider: model.ProviderBasicAuth},
run: func(c *model.UserContext) any { return c.ProviderName() },
run: func(t *testing.T, c *model.UserContext) any { return c.ProviderName() },
expected: "local",
},
{
description: "ProviderName returns 'ldap' for ProviderLDAP",
context: &model.UserContext{Provider: model.ProviderLDAP},
run: func(c *model.UserContext) any { return c.ProviderName() },
run: func(t *testing.T, c *model.UserContext) any { return c.ProviderName() },
expected: "ldap",
},
{
@@ -169,7 +174,7 @@ func TestContext(t *testing.T) {
Provider: model.ProviderOAuth,
OAuth: &model.OAuthContext{DisplayName: "GitHub"},
},
run: func(c *model.UserContext) any { return c.ProviderName() },
run: func(t *testing.T, c *model.UserContext) any { return c.ProviderName() },
expected: "GitHub",
},
{
@@ -178,7 +183,7 @@ func TestContext(t *testing.T) {
Provider: model.ProviderLocal,
Local: &model.LocalContext{TOTPPending: true},
},
run: func(c *model.UserContext) any { return c.TOTPPending() },
run: func(t *testing.T, c *model.UserContext) any { return c.TOTPPending() },
expected: true,
},
{
@@ -187,13 +192,13 @@ func TestContext(t *testing.T) {
Provider: model.ProviderLocal,
Local: &model.LocalContext{TOTPPending: false},
},
run: func(c *model.UserContext) any { return c.TOTPPending() },
run: func(t *testing.T, c *model.UserContext) any { return c.TOTPPending() },
expected: false,
},
{
description: "TOTPPending returns false for non-local providers",
context: &model.UserContext{Provider: model.ProviderOAuth, OAuth: &model.OAuthContext{}},
run: func(c *model.UserContext) any { return c.TOTPPending() },
run: func(t *testing.T, c *model.UserContext) any { return c.TOTPPending() },
expected: false,
},
{
@@ -202,28 +207,26 @@ func TestContext(t *testing.T) {
Provider: model.ProviderOAuth,
OAuth: &model.OAuthContext{DisplayName: "Google"},
},
run: func(c *model.UserContext) any { return c.OAuthName() },
run: func(t *testing.T, c *model.UserContext) any { return c.OAuthName() },
expected: "Google",
},
{
description: "OAuthName returns empty string for non-oauth providers",
context: &model.UserContext{Provider: model.ProviderLocal, Local: &model.LocalContext{}},
run: func(c *model.UserContext) any { return c.OAuthName() },
run: func(t *testing.T, c *model.UserContext) any { return c.OAuthName() },
expected: "",
},
{
description: "NewFromGin populates context from gin value",
context: &model.UserContext{},
run: func(c *model.UserContext) any {
run: func(t *testing.T, c *model.UserContext) any {
stored := &model.UserContext{
Authenticated: true,
Provider: model.ProviderLocal,
Local: &model.LocalContext{BaseContext: model.BaseContext{Username: "alice"}},
}
got, err := c.NewFromGin(newGinCtx(stored, true))
if err != nil {
return err.Error()
}
require.NoError(t, err)
return [2]any{got.Authenticated, got.GetUsername()}
},
expected: [2]any{true, "alice"},
@@ -231,7 +234,7 @@ func TestContext(t *testing.T) {
{
description: "NewFromGin returns error when context value is missing",
context: &model.UserContext{},
run: func(c *model.UserContext) any {
run: func(t *testing.T, c *model.UserContext) any {
_, err := c.NewFromGin(newGinCtx(nil, false))
return err.Error()
},
@@ -240,17 +243,26 @@ func TestContext(t *testing.T) {
{
description: "NewFromGin returns error when context value has wrong type",
context: &model.UserContext{},
run: func(c *model.UserContext) any {
run: func(t *testing.T, c *model.UserContext) any {
_, err := c.NewFromGin(newGinCtx("not a user context", true))
return err.Error()
},
expected: "invalid user context type",
},
{
description: "NewFromGin returns an error when context doesn't include user information",
context: &model.UserContext{},
run: func(t *testing.T, c *model.UserContext) any {
_, err := c.NewFromGin(newGinCtx(&model.UserContext{Provider: model.ProviderLocal}, true))
return err.Error()
},
expected: "incomplete user context",
},
}
for _, test := range tests {
t.Run(test.description, func(t *testing.T) {
assert.Equal(t, test.expected, test.run(test.context))
assert.Equal(t, test.expected, test.run(t, test.context))
})
}
}