mirror of
https://github.com/steveiliop56/tinyauth.git
synced 2026-05-05 03:48:14 +00:00
tests: add tests for context middleware
This commit is contained in:
@@ -0,0 +1,318 @@
|
||||
package middleware_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"path"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tinyauthapp/tinyauth/internal/bootstrap"
|
||||
"github.com/tinyauthapp/tinyauth/internal/middleware"
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||
)
|
||||
|
||||
func TestContextMiddleware(t *testing.T) {
|
||||
tlog.NewTestLogger().Init()
|
||||
tempDir := t.TempDir()
|
||||
|
||||
authServiceCfg := service.AuthServiceConfig{
|
||||
LocalUsers: []model.LocalUser{
|
||||
{
|
||||
Username: "testuser",
|
||||
Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
|
||||
},
|
||||
{
|
||||
Username: "totpuser",
|
||||
Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
|
||||
TOTPSecret: "JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK",
|
||||
},
|
||||
},
|
||||
SessionExpiry: 10, // 10 seconds, useful for testing
|
||||
CookieDomain: "example.com",
|
||||
LoginTimeout: 10, // 10 seconds, useful for testing
|
||||
LoginMaxRetries: 3,
|
||||
SessionCookieName: "tinyauth-session",
|
||||
}
|
||||
|
||||
middlewareCfg := middleware.ContextMiddlewareConfig{
|
||||
CookieDomain: "example.com",
|
||||
SessionCookieName: "tinyauth-session",
|
||||
}
|
||||
|
||||
basicAuthHeader := func(username, password string) string {
|
||||
return "Basic " + base64.StdEncoding.EncodeToString([]byte(username+":"+password))
|
||||
}
|
||||
|
||||
seedSession := func(t *testing.T, queries *repository.Queries, params repository.CreateSessionParams) {
|
||||
t.Helper()
|
||||
_, err := queries.CreateSession(context.Background(), params)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
type runArgs struct {
|
||||
do func(req *http.Request) (*model.UserContext, *httptest.ResponseRecorder)
|
||||
queries *repository.Queries
|
||||
}
|
||||
|
||||
type testCase struct {
|
||||
description string
|
||||
run func(t *testing.T, args runArgs)
|
||||
}
|
||||
|
||||
tests := []testCase{
|
||||
{
|
||||
description: "Skip path bypasses auth processing",
|
||||
run: func(t *testing.T, args runArgs) {
|
||||
req := httptest.NewRequest("GET", "/api/healthz", nil)
|
||||
req.Header.Set("Authorization", basicAuthHeader("testuser", "password"))
|
||||
userCtx, _ := args.do(req)
|
||||
|
||||
assert.Nil(t, userCtx)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "No credentials yields no context",
|
||||
run: func(t *testing.T, args runArgs) {
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
userCtx, _ := args.do(req)
|
||||
|
||||
assert.Nil(t, userCtx)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Valid session cookie sets authenticated local context",
|
||||
run: func(t *testing.T, args runArgs) {
|
||||
uuid := "session-valid-local"
|
||||
seedSession(t, args.queries, repository.CreateSessionParams{
|
||||
UUID: uuid,
|
||||
Username: "testuser",
|
||||
Provider: "local",
|
||||
Expiry: time.Now().Add(10 * time.Second).Unix(),
|
||||
CreatedAt: time.Now().Unix(),
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
req.AddCookie(&http.Cookie{Name: "tinyauth-session", Value: uuid})
|
||||
userCtx, _ := args.do(req)
|
||||
|
||||
require.NotNil(t, userCtx)
|
||||
assert.Equal(t, model.ProviderLocal, userCtx.Provider)
|
||||
assert.Equal(t, "testuser", userCtx.GetUsername())
|
||||
assert.True(t, userCtx.Authenticated)
|
||||
require.NotNil(t, userCtx.Local)
|
||||
assert.False(t, userCtx.Local.TOTPEnabled)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Session cookie with totp pending sets unauthenticated context with totp enabled",
|
||||
run: func(t *testing.T, args runArgs) {
|
||||
uuid := "session-totp-pending"
|
||||
seedSession(t, args.queries, repository.CreateSessionParams{
|
||||
UUID: uuid,
|
||||
Username: "totpuser",
|
||||
Provider: "local",
|
||||
TotpPending: true,
|
||||
Expiry: time.Now().Add(60 * time.Second).Unix(),
|
||||
CreatedAt: time.Now().Unix(),
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
req.AddCookie(&http.Cookie{Name: "tinyauth-session", Value: uuid})
|
||||
userCtx, _ := args.do(req)
|
||||
|
||||
require.NotNil(t, userCtx)
|
||||
assert.Equal(t, "totpuser", userCtx.GetUsername())
|
||||
assert.False(t, userCtx.Authenticated)
|
||||
require.NotNil(t, userCtx.Local)
|
||||
assert.True(t, userCtx.Local.TOTPPending)
|
||||
assert.True(t, userCtx.Local.TOTPEnabled)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Unknown session cookie yields no context",
|
||||
run: func(t *testing.T, args runArgs) {
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
req.AddCookie(&http.Cookie{Name: "tinyauth-session", Value: "does-not-exist"})
|
||||
userCtx, _ := args.do(req)
|
||||
|
||||
assert.Nil(t, userCtx)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Session for missing local user yields no context",
|
||||
run: func(t *testing.T, args runArgs) {
|
||||
uuid := "session-deleted-user"
|
||||
seedSession(t, args.queries, repository.CreateSessionParams{
|
||||
UUID: uuid,
|
||||
Username: "ghostuser",
|
||||
Provider: "local",
|
||||
Expiry: time.Now().Add(10 * time.Second).Unix(),
|
||||
CreatedAt: time.Now().Unix(),
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
req.AddCookie(&http.Cookie{Name: "tinyauth-session", Value: uuid})
|
||||
userCtx, _ := args.do(req)
|
||||
|
||||
assert.Nil(t, userCtx)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Expired session cookie yields no context",
|
||||
run: func(t *testing.T, args runArgs) {
|
||||
uuid := "session-expired"
|
||||
seedSession(t, args.queries, repository.CreateSessionParams{
|
||||
UUID: uuid,
|
||||
Username: "testuser",
|
||||
Provider: "local",
|
||||
Expiry: time.Now().Add(-1 * time.Second).Unix(),
|
||||
CreatedAt: time.Now().Add(-10 * time.Second).Unix(),
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
req.AddCookie(&http.Cookie{Name: "tinyauth-session", Value: uuid})
|
||||
userCtx, _ := args.do(req)
|
||||
|
||||
assert.Nil(t, userCtx)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Valid basic auth sets authenticated local context",
|
||||
run: func(t *testing.T, args runArgs) {
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
req.Header.Set("Authorization", basicAuthHeader("testuser", "password"))
|
||||
userCtx, _ := args.do(req)
|
||||
|
||||
require.NotNil(t, userCtx)
|
||||
assert.Equal(t, model.ProviderLocal, userCtx.Provider)
|
||||
assert.Equal(t, "testuser", userCtx.GetUsername())
|
||||
assert.True(t, userCtx.Authenticated)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Invalid basic auth password yields no context",
|
||||
run: func(t *testing.T, args runArgs) {
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
req.Header.Set("Authorization", basicAuthHeader("testuser", "wrongpassword"))
|
||||
userCtx, _ := args.do(req)
|
||||
|
||||
assert.Nil(t, userCtx)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Basic auth is rejected for users with totp",
|
||||
run: func(t *testing.T, args runArgs) {
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
req.Header.Set("Authorization", basicAuthHeader("totpuser", "password"))
|
||||
userCtx, _ := args.do(req)
|
||||
|
||||
assert.Nil(t, userCtx)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Locked account on basic auth sets lock headers",
|
||||
run: func(t *testing.T, args runArgs) {
|
||||
for range 3 {
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
req.Header.Set("Authorization", basicAuthHeader("testuser", "wrongpassword"))
|
||||
args.do(req)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
req.Header.Set("Authorization", basicAuthHeader("testuser", "password"))
|
||||
userCtx, recorder := args.do(req)
|
||||
|
||||
assert.Nil(t, userCtx)
|
||||
assert.Equal(t, "true", recorder.Header().Get("x-tinyauth-lock-locked"))
|
||||
assert.NotEmpty(t, recorder.Header().Get("x-tinyauth-lock-reset"))
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Cookie auth takes precedence over basic auth",
|
||||
run: func(t *testing.T, args runArgs) {
|
||||
uuid := "session-precedence"
|
||||
seedSession(t, args.queries, repository.CreateSessionParams{
|
||||
UUID: uuid,
|
||||
Username: "testuser",
|
||||
Provider: "local",
|
||||
Expiry: time.Now().Add(10 * time.Second).Unix(),
|
||||
CreatedAt: time.Now().Unix(),
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
req.AddCookie(&http.Cookie{Name: "tinyauth-session", Value: uuid})
|
||||
req.Header.Set("Authorization", basicAuthHeader("totpuser", "password"))
|
||||
userCtx, _ := args.do(req)
|
||||
|
||||
require.NotNil(t, userCtx)
|
||||
assert.Equal(t, "testuser", userCtx.GetUsername())
|
||||
assert.True(t, userCtx.Authenticated)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
oauthBrokerCfgs := make(map[string]model.OAuthServiceConfig)
|
||||
|
||||
app := bootstrap.NewBootstrapApp(model.Config{})
|
||||
|
||||
db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db"))
|
||||
require.NoError(t, err)
|
||||
|
||||
queries := repository.New(db)
|
||||
|
||||
ldap := service.NewLdapService(service.LdapServiceConfig{})
|
||||
err = ldap.Init()
|
||||
require.NoError(t, err)
|
||||
|
||||
broker := service.NewOAuthBrokerService(oauthBrokerCfgs)
|
||||
err = broker.Init()
|
||||
require.NoError(t, err)
|
||||
|
||||
authService := service.NewAuthService(authServiceCfg, ldap, queries, broker)
|
||||
err = authService.Init()
|
||||
require.NoError(t, err)
|
||||
|
||||
contextMiddleware := middleware.NewContextMiddleware(middlewareCfg, authService, broker)
|
||||
err = contextMiddleware.Init()
|
||||
require.NoError(t, err)
|
||||
|
||||
for _, test := range tests {
|
||||
authService.ClearRateLimitsTestingOnly()
|
||||
t.Run(test.description, func(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
do := func(req *http.Request) (*model.UserContext, *httptest.ResponseRecorder) {
|
||||
var captured *model.UserContext
|
||||
router := gin.New()
|
||||
router.Use(contextMiddleware.Middleware())
|
||||
handler := func(c *gin.Context) {
|
||||
if val, exists := c.Get("context"); exists {
|
||||
captured, _ = val.(*model.UserContext)
|
||||
}
|
||||
}
|
||||
router.GET("/api/test", handler)
|
||||
router.GET("/api/healthz", handler)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
router.ServeHTTP(recorder, req)
|
||||
return captured, recorder
|
||||
}
|
||||
|
||||
test.run(t, runArgs{do: do, queries: queries})
|
||||
})
|
||||
}
|
||||
|
||||
t.Cleanup(func() {
|
||||
err = db.Close()
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
+99
-166
@@ -1,14 +1,31 @@
|
||||
package model_test
|
||||
|
||||
import (
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||
)
|
||||
|
||||
func TestContext(t *testing.T) {
|
||||
errMsg := func(err error) string {
|
||||
if err == nil {
|
||||
return ""
|
||||
}
|
||||
return err.Error()
|
||||
}
|
||||
|
||||
newGinCtx := func(value any, set bool) *gin.Context {
|
||||
c, _ := gin.CreateTestContext(httptest.NewRecorder())
|
||||
if set {
|
||||
c.Set("context", value)
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
description string
|
||||
context *model.UserContext
|
||||
@@ -16,79 +33,49 @@ func TestContext(t *testing.T) {
|
||||
expected any
|
||||
}{
|
||||
{
|
||||
description: "IsAuthenticated returns true when Authenticated is true",
|
||||
description: "IsAuthenticated reflects Authenticated field",
|
||||
context: &model.UserContext{Authenticated: true},
|
||||
run: func(c *model.UserContext) any { return c.IsAuthenticated() },
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
description: "IsAuthenticated returns false when Authenticated is false",
|
||||
context: &model.UserContext{Authenticated: false},
|
||||
run: func(c *model.UserContext) any { return c.IsAuthenticated() },
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
description: "IsLocal returns true when Provider is ProviderLocal",
|
||||
description: "IsLocal returns true for ProviderLocal",
|
||||
context: &model.UserContext{Provider: model.ProviderLocal},
|
||||
run: func(c *model.UserContext) any { return c.IsLocal() },
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
description: "IsLocal returns false when Provider is not ProviderLocal",
|
||||
context: &model.UserContext{Provider: model.ProviderOAuth},
|
||||
run: func(c *model.UserContext) any { return c.IsLocal() },
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
description: "IsOAuth returns true when Provider is ProviderOAuth",
|
||||
description: "IsOAuth returns true for ProviderOAuth",
|
||||
context: &model.UserContext{Provider: model.ProviderOAuth},
|
||||
run: func(c *model.UserContext) any { return c.IsOAuth() },
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
description: "IsOAuth returns false when Provider is ProviderLocal",
|
||||
context: &model.UserContext{Provider: model.ProviderLocal},
|
||||
run: func(c *model.UserContext) any { return c.IsOAuth() },
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
description: "IsLDAP returns true when Provider is ProviderLDAP",
|
||||
description: "IsLDAP returns true for ProviderLDAP",
|
||||
context: &model.UserContext{Provider: model.ProviderLDAP},
|
||||
run: func(c *model.UserContext) any { return c.IsLDAP() },
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
description: "IsLDAP returns false when Provider is ProviderOAuth",
|
||||
context: &model.UserContext{Provider: model.ProviderOAuth},
|
||||
run: func(c *model.UserContext) any { return c.IsLDAP() },
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
description: "IsBasicAuth returns true when Provider is ProviderBasicAuth",
|
||||
description: "IsBasicAuth returns true for ProviderBasicAuth",
|
||||
context: &model.UserContext{Provider: model.ProviderBasicAuth},
|
||||
run: func(c *model.UserContext) any { return c.IsBasicAuth() },
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
description: "IsBasicAuth returns false when Provider is ProviderLocal",
|
||||
context: &model.UserContext{Provider: model.ProviderLocal},
|
||||
run: func(c *model.UserContext) any { return c.IsBasicAuth() },
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
description: "NewFromSession local session without TOTP sets ProviderLocal and is authenticated",
|
||||
description: "NewFromSession local session is authenticated and ProviderLocal",
|
||||
context: &model.UserContext{},
|
||||
run: func(c *model.UserContext) any {
|
||||
got, _ := c.NewFromSession(&repository.Session{
|
||||
Username: "alice", Email: "alice@example.com", Name: "Alice",
|
||||
Provider: "local", TotpPending: false,
|
||||
Provider: "local",
|
||||
})
|
||||
return got.Provider == model.ProviderLocal && got.Authenticated
|
||||
return [2]any{got.Provider, got.Authenticated}
|
||||
},
|
||||
expected: true,
|
||||
expected: [2]any{model.ProviderLocal, true},
|
||||
},
|
||||
{
|
||||
description: "NewFromSession local session with TOTP pending is not authenticated",
|
||||
description: "NewFromSession local session with TotpPending is not authenticated",
|
||||
context: &model.UserContext{},
|
||||
run: func(c *model.UserContext) any {
|
||||
got, _ := c.NewFromSession(&repository.Session{
|
||||
@@ -99,136 +86,71 @@ func TestContext(t *testing.T) {
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
description: "NewFromSession ldap session sets ProviderLDAP and is authenticated",
|
||||
description: "NewFromSession ldap session is ProviderLDAP",
|
||||
context: &model.UserContext{},
|
||||
run: func(c *model.UserContext) any {
|
||||
got, _ := c.NewFromSession(&repository.Session{
|
||||
Username: "carol", Email: "carol@example.com", Name: "Carol",
|
||||
Provider: "ldap",
|
||||
Username: "carol", Provider: "ldap",
|
||||
})
|
||||
return got.Provider == model.ProviderLDAP && got.Authenticated
|
||||
return got.Provider
|
||||
},
|
||||
expected: true,
|
||||
expected: model.ProviderLDAP,
|
||||
},
|
||||
{
|
||||
description: "NewFromSession unknown provider defaults to ProviderOAuth",
|
||||
description: "NewFromSession unknown provider defaults to OAuth and populates oauth fields",
|
||||
context: &model.UserContext{},
|
||||
run: func(c *model.UserContext) any {
|
||||
got, _ := c.NewFromSession(&repository.Session{
|
||||
Username: "dave", Provider: "github",
|
||||
OAuthGroups: "devs,admins", OAuthSub: "sub-123", OAuthName: "GitHub",
|
||||
})
|
||||
return got.Provider
|
||||
return [4]any{got.Provider, got.OAuth.ID, got.OAuth.Sub, got.OAuth.DisplayName}
|
||||
},
|
||||
expected: model.ProviderOAuth,
|
||||
expected: [4]any{model.ProviderOAuth, "github", "sub-123", "GitHub"},
|
||||
},
|
||||
{
|
||||
description: "GetUsername returns local username for ProviderLocal",
|
||||
description: "Local getters return BaseContext fields",
|
||||
context: &model.UserContext{
|
||||
Provider: model.ProviderLocal,
|
||||
Local: &model.LocalContext{BaseContext: model.BaseContext{Username: "alice"}},
|
||||
Local: &model.LocalContext{BaseContext: model.BaseContext{Username: "alice", Email: "alice@example.com", Name: "Alice"}},
|
||||
},
|
||||
run: func(c *model.UserContext) any { return c.GetUsername() },
|
||||
expected: "alice",
|
||||
run: func(c *model.UserContext) any {
|
||||
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
|
||||
},
|
||||
expected: [3]string{"alice", "alice@example.com", "Alice"},
|
||||
},
|
||||
{
|
||||
description: "GetUsername returns local username for ProviderBasicAuth",
|
||||
description: "BasicAuth getters fall back to local fields",
|
||||
context: &model.UserContext{
|
||||
Provider: model.ProviderBasicAuth,
|
||||
Local: &model.LocalContext{BaseContext: model.BaseContext{Username: "bob"}},
|
||||
Local: &model.LocalContext{BaseContext: model.BaseContext{Username: "bob", Email: "bob@example.com", Name: "Bob"}},
|
||||
},
|
||||
run: func(c *model.UserContext) any { return c.GetUsername() },
|
||||
expected: "bob",
|
||||
run: func(c *model.UserContext) any {
|
||||
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
|
||||
},
|
||||
expected: [3]string{"bob", "bob@example.com", "Bob"},
|
||||
},
|
||||
{
|
||||
description: "GetUsername returns LDAP username for ProviderLDAP",
|
||||
description: "LDAP getters return LDAP fields",
|
||||
context: &model.UserContext{
|
||||
Provider: model.ProviderLDAP,
|
||||
LDAP: &model.LDAPContext{BaseContext: model.BaseContext{Username: "carol"}},
|
||||
LDAP: &model.LDAPContext{BaseContext: model.BaseContext{Username: "carol", Email: "carol@example.com", Name: "Carol"}},
|
||||
},
|
||||
run: func(c *model.UserContext) any { return c.GetUsername() },
|
||||
expected: "carol",
|
||||
run: func(c *model.UserContext) any {
|
||||
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
|
||||
},
|
||||
expected: [3]string{"carol", "carol@example.com", "Carol"},
|
||||
},
|
||||
{
|
||||
description: "GetUsername returns OAuth username for ProviderOAuth",
|
||||
description: "OAuth getters return OAuth fields",
|
||||
context: &model.UserContext{
|
||||
Provider: model.ProviderOAuth,
|
||||
OAuth: &model.OAuthContext{BaseContext: model.BaseContext{Username: "dave"}},
|
||||
OAuth: &model.OAuthContext{BaseContext: model.BaseContext{Username: "dave", Email: "dave@example.com", Name: "Dave"}},
|
||||
},
|
||||
run: func(c *model.UserContext) any { return c.GetUsername() },
|
||||
expected: "dave",
|
||||
},
|
||||
{
|
||||
description: "GetEmail returns local email for ProviderLocal",
|
||||
context: &model.UserContext{
|
||||
Provider: model.ProviderLocal,
|
||||
Local: &model.LocalContext{BaseContext: model.BaseContext{Email: "alice@example.com"}},
|
||||
run: func(c *model.UserContext) any {
|
||||
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
|
||||
},
|
||||
run: func(c *model.UserContext) any { return c.GetEmail() },
|
||||
expected: "alice@example.com",
|
||||
},
|
||||
{
|
||||
description: "GetEmail returns local email for ProviderBasicAuth",
|
||||
context: &model.UserContext{
|
||||
Provider: model.ProviderBasicAuth,
|
||||
Local: &model.LocalContext{BaseContext: model.BaseContext{Email: "bob@example.com"}},
|
||||
},
|
||||
run: func(c *model.UserContext) any { return c.GetEmail() },
|
||||
expected: "bob@example.com",
|
||||
},
|
||||
{
|
||||
description: "GetEmail returns LDAP email for ProviderLDAP",
|
||||
context: &model.UserContext{
|
||||
Provider: model.ProviderLDAP,
|
||||
LDAP: &model.LDAPContext{BaseContext: model.BaseContext{Email: "carol@example.com"}},
|
||||
},
|
||||
run: func(c *model.UserContext) any { return c.GetEmail() },
|
||||
expected: "carol@example.com",
|
||||
},
|
||||
{
|
||||
description: "GetEmail returns OAuth email for ProviderOAuth",
|
||||
context: &model.UserContext{
|
||||
Provider: model.ProviderOAuth,
|
||||
OAuth: &model.OAuthContext{BaseContext: model.BaseContext{Email: "dave@example.com"}},
|
||||
},
|
||||
run: func(c *model.UserContext) any { return c.GetEmail() },
|
||||
expected: "dave@example.com",
|
||||
},
|
||||
{
|
||||
description: "GetName returns local name for ProviderLocal",
|
||||
context: &model.UserContext{
|
||||
Provider: model.ProviderLocal,
|
||||
Local: &model.LocalContext{BaseContext: model.BaseContext{Name: "Alice"}},
|
||||
},
|
||||
run: func(c *model.UserContext) any { return c.GetName() },
|
||||
expected: "Alice",
|
||||
},
|
||||
{
|
||||
description: "GetName returns local name for ProviderBasicAuth",
|
||||
context: &model.UserContext{
|
||||
Provider: model.ProviderBasicAuth,
|
||||
Local: &model.LocalContext{BaseContext: model.BaseContext{Name: "Bob"}},
|
||||
},
|
||||
run: func(c *model.UserContext) any { return c.GetName() },
|
||||
expected: "Bob",
|
||||
},
|
||||
{
|
||||
description: "GetName returns LDAP name for ProviderLDAP",
|
||||
context: &model.UserContext{
|
||||
Provider: model.ProviderLDAP,
|
||||
LDAP: &model.LDAPContext{BaseContext: model.BaseContext{Name: "Carol"}},
|
||||
},
|
||||
run: func(c *model.UserContext) any { return c.GetName() },
|
||||
expected: "Carol",
|
||||
},
|
||||
{
|
||||
description: "GetName returns OAuth name for ProviderOAuth",
|
||||
context: &model.UserContext{
|
||||
Provider: model.ProviderOAuth,
|
||||
OAuth: &model.OAuthContext{BaseContext: model.BaseContext{Name: "Dave"}},
|
||||
},
|
||||
run: func(c *model.UserContext) any { return c.GetName() },
|
||||
expected: "Dave",
|
||||
expected: [3]string{"dave", "dave@example.com", "Dave"},
|
||||
},
|
||||
{
|
||||
description: "ProviderName returns 'local' for ProviderLocal",
|
||||
@@ -258,7 +180,7 @@ func TestContext(t *testing.T) {
|
||||
expected: "GitHub",
|
||||
},
|
||||
{
|
||||
description: "TOTPPending returns true for ProviderLocal when TOTPPending is true",
|
||||
description: "TOTPPending returns true when local context is pending",
|
||||
context: &model.UserContext{
|
||||
Provider: model.ProviderLocal,
|
||||
Local: &model.LocalContext{TOTPPending: true},
|
||||
@@ -267,7 +189,7 @@ func TestContext(t *testing.T) {
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
description: "TOTPPending returns false for ProviderLocal when TOTPPending is false",
|
||||
description: "TOTPPending returns false when local context is not pending",
|
||||
context: &model.UserContext{
|
||||
Provider: model.ProviderLocal,
|
||||
Local: &model.LocalContext{TOTPPending: false},
|
||||
@@ -276,22 +198,10 @@ func TestContext(t *testing.T) {
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
description: "TOTPPending returns false for ProviderOAuth",
|
||||
context: &model.UserContext{
|
||||
Provider: model.ProviderOAuth,
|
||||
OAuth: &model.OAuthContext{},
|
||||
},
|
||||
run: func(c *model.UserContext) any { return c.TOTPPending() },
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
description: "TOTPPending returns false for ProviderLDAP",
|
||||
context: &model.UserContext{
|
||||
Provider: model.ProviderLDAP,
|
||||
LDAP: &model.LDAPContext{},
|
||||
},
|
||||
run: func(c *model.UserContext) any { return c.TOTPPending() },
|
||||
expected: false,
|
||||
description: "TOTPPending returns false for non-local providers",
|
||||
context: &model.UserContext{Provider: model.ProviderOAuth, OAuth: &model.OAuthContext{}},
|
||||
run: func(c *model.UserContext) any { return c.TOTPPending() },
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
description: "OAuthName returns DisplayName for ProviderOAuth",
|
||||
@@ -303,22 +213,45 @@ func TestContext(t *testing.T) {
|
||||
expected: "Google",
|
||||
},
|
||||
{
|
||||
description: "OAuthName returns empty string for ProviderLocal",
|
||||
context: &model.UserContext{
|
||||
Provider: model.ProviderLocal,
|
||||
Local: &model.LocalContext{},
|
||||
},
|
||||
run: func(c *model.UserContext) any { return c.OAuthName() },
|
||||
expected: "",
|
||||
description: "OAuthName returns empty string for non-oauth providers",
|
||||
context: &model.UserContext{Provider: model.ProviderLocal, Local: &model.LocalContext{}},
|
||||
run: func(c *model.UserContext) any { return c.OAuthName() },
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
description: "OAuthName returns empty string for ProviderLDAP",
|
||||
context: &model.UserContext{
|
||||
Provider: model.ProviderLDAP,
|
||||
LDAP: &model.LDAPContext{},
|
||||
description: "NewFromGin populates context from gin value",
|
||||
context: &model.UserContext{},
|
||||
run: func(c *model.UserContext) any {
|
||||
stored := &model.UserContext{
|
||||
Authenticated: true,
|
||||
Provider: model.ProviderLocal,
|
||||
Local: &model.LocalContext{BaseContext: model.BaseContext{Username: "alice"}},
|
||||
}
|
||||
got, err := c.NewFromGin(newGinCtx(stored, true))
|
||||
if err != nil {
|
||||
return err.Error()
|
||||
}
|
||||
return [2]any{got.Authenticated, got.GetUsername()}
|
||||
},
|
||||
run: func(c *model.UserContext) any { return c.OAuthName() },
|
||||
expected: "",
|
||||
expected: [2]any{true, "alice"},
|
||||
},
|
||||
{
|
||||
description: "NewFromGin returns error when context value is missing",
|
||||
context: &model.UserContext{},
|
||||
run: func(c *model.UserContext) any {
|
||||
_, err := c.NewFromGin(newGinCtx(nil, false))
|
||||
return errMsg(err)
|
||||
},
|
||||
expected: "failed to get user context",
|
||||
},
|
||||
{
|
||||
description: "NewFromGin returns error when context value has wrong type",
|
||||
context: &model.UserContext{},
|
||||
run: func(c *model.UserContext) any {
|
||||
_, err := c.NewFromGin(newGinCtx("not a user context", true))
|
||||
return errMsg(err)
|
||||
},
|
||||
expected: "invalid user context type",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user