tests: add tests for context middleware

This commit is contained in:
Stavros
2026-05-04 20:52:59 +03:00
parent 4d3860f860
commit e13598bf3c
2 changed files with 417 additions and 166 deletions
+99 -166
View File
@@ -1,14 +1,31 @@
package model_test
import (
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository"
)
func TestContext(t *testing.T) {
errMsg := func(err error) string {
if err == nil {
return ""
}
return err.Error()
}
newGinCtx := func(value any, set bool) *gin.Context {
c, _ := gin.CreateTestContext(httptest.NewRecorder())
if set {
c.Set("context", value)
}
return c
}
tests := []struct {
description string
context *model.UserContext
@@ -16,79 +33,49 @@ func TestContext(t *testing.T) {
expected any
}{
{
description: "IsAuthenticated returns true when Authenticated is true",
description: "IsAuthenticated reflects Authenticated field",
context: &model.UserContext{Authenticated: true},
run: func(c *model.UserContext) any { return c.IsAuthenticated() },
expected: true,
},
{
description: "IsAuthenticated returns false when Authenticated is false",
context: &model.UserContext{Authenticated: false},
run: func(c *model.UserContext) any { return c.IsAuthenticated() },
expected: false,
},
{
description: "IsLocal returns true when Provider is ProviderLocal",
description: "IsLocal returns true for ProviderLocal",
context: &model.UserContext{Provider: model.ProviderLocal},
run: func(c *model.UserContext) any { return c.IsLocal() },
expected: true,
},
{
description: "IsLocal returns false when Provider is not ProviderLocal",
context: &model.UserContext{Provider: model.ProviderOAuth},
run: func(c *model.UserContext) any { return c.IsLocal() },
expected: false,
},
{
description: "IsOAuth returns true when Provider is ProviderOAuth",
description: "IsOAuth returns true for ProviderOAuth",
context: &model.UserContext{Provider: model.ProviderOAuth},
run: func(c *model.UserContext) any { return c.IsOAuth() },
expected: true,
},
{
description: "IsOAuth returns false when Provider is ProviderLocal",
context: &model.UserContext{Provider: model.ProviderLocal},
run: func(c *model.UserContext) any { return c.IsOAuth() },
expected: false,
},
{
description: "IsLDAP returns true when Provider is ProviderLDAP",
description: "IsLDAP returns true for ProviderLDAP",
context: &model.UserContext{Provider: model.ProviderLDAP},
run: func(c *model.UserContext) any { return c.IsLDAP() },
expected: true,
},
{
description: "IsLDAP returns false when Provider is ProviderOAuth",
context: &model.UserContext{Provider: model.ProviderOAuth},
run: func(c *model.UserContext) any { return c.IsLDAP() },
expected: false,
},
{
description: "IsBasicAuth returns true when Provider is ProviderBasicAuth",
description: "IsBasicAuth returns true for ProviderBasicAuth",
context: &model.UserContext{Provider: model.ProviderBasicAuth},
run: func(c *model.UserContext) any { return c.IsBasicAuth() },
expected: true,
},
{
description: "IsBasicAuth returns false when Provider is ProviderLocal",
context: &model.UserContext{Provider: model.ProviderLocal},
run: func(c *model.UserContext) any { return c.IsBasicAuth() },
expected: false,
},
{
description: "NewFromSession local session without TOTP sets ProviderLocal and is authenticated",
description: "NewFromSession local session is authenticated and ProviderLocal",
context: &model.UserContext{},
run: func(c *model.UserContext) any {
got, _ := c.NewFromSession(&repository.Session{
Username: "alice", Email: "alice@example.com", Name: "Alice",
Provider: "local", TotpPending: false,
Provider: "local",
})
return got.Provider == model.ProviderLocal && got.Authenticated
return [2]any{got.Provider, got.Authenticated}
},
expected: true,
expected: [2]any{model.ProviderLocal, true},
},
{
description: "NewFromSession local session with TOTP pending is not authenticated",
description: "NewFromSession local session with TotpPending is not authenticated",
context: &model.UserContext{},
run: func(c *model.UserContext) any {
got, _ := c.NewFromSession(&repository.Session{
@@ -99,136 +86,71 @@ func TestContext(t *testing.T) {
expected: false,
},
{
description: "NewFromSession ldap session sets ProviderLDAP and is authenticated",
description: "NewFromSession ldap session is ProviderLDAP",
context: &model.UserContext{},
run: func(c *model.UserContext) any {
got, _ := c.NewFromSession(&repository.Session{
Username: "carol", Email: "carol@example.com", Name: "Carol",
Provider: "ldap",
Username: "carol", Provider: "ldap",
})
return got.Provider == model.ProviderLDAP && got.Authenticated
return got.Provider
},
expected: true,
expected: model.ProviderLDAP,
},
{
description: "NewFromSession unknown provider defaults to ProviderOAuth",
description: "NewFromSession unknown provider defaults to OAuth and populates oauth fields",
context: &model.UserContext{},
run: func(c *model.UserContext) any {
got, _ := c.NewFromSession(&repository.Session{
Username: "dave", Provider: "github",
OAuthGroups: "devs,admins", OAuthSub: "sub-123", OAuthName: "GitHub",
})
return got.Provider
return [4]any{got.Provider, got.OAuth.ID, got.OAuth.Sub, got.OAuth.DisplayName}
},
expected: model.ProviderOAuth,
expected: [4]any{model.ProviderOAuth, "github", "sub-123", "GitHub"},
},
{
description: "GetUsername returns local username for ProviderLocal",
description: "Local getters return BaseContext fields",
context: &model.UserContext{
Provider: model.ProviderLocal,
Local: &model.LocalContext{BaseContext: model.BaseContext{Username: "alice"}},
Local: &model.LocalContext{BaseContext: model.BaseContext{Username: "alice", Email: "alice@example.com", Name: "Alice"}},
},
run: func(c *model.UserContext) any { return c.GetUsername() },
expected: "alice",
run: func(c *model.UserContext) any {
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
},
expected: [3]string{"alice", "alice@example.com", "Alice"},
},
{
description: "GetUsername returns local username for ProviderBasicAuth",
description: "BasicAuth getters fall back to local fields",
context: &model.UserContext{
Provider: model.ProviderBasicAuth,
Local: &model.LocalContext{BaseContext: model.BaseContext{Username: "bob"}},
Local: &model.LocalContext{BaseContext: model.BaseContext{Username: "bob", Email: "bob@example.com", Name: "Bob"}},
},
run: func(c *model.UserContext) any { return c.GetUsername() },
expected: "bob",
run: func(c *model.UserContext) any {
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
},
expected: [3]string{"bob", "bob@example.com", "Bob"},
},
{
description: "GetUsername returns LDAP username for ProviderLDAP",
description: "LDAP getters return LDAP fields",
context: &model.UserContext{
Provider: model.ProviderLDAP,
LDAP: &model.LDAPContext{BaseContext: model.BaseContext{Username: "carol"}},
LDAP: &model.LDAPContext{BaseContext: model.BaseContext{Username: "carol", Email: "carol@example.com", Name: "Carol"}},
},
run: func(c *model.UserContext) any { return c.GetUsername() },
expected: "carol",
run: func(c *model.UserContext) any {
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
},
expected: [3]string{"carol", "carol@example.com", "Carol"},
},
{
description: "GetUsername returns OAuth username for ProviderOAuth",
description: "OAuth getters return OAuth fields",
context: &model.UserContext{
Provider: model.ProviderOAuth,
OAuth: &model.OAuthContext{BaseContext: model.BaseContext{Username: "dave"}},
OAuth: &model.OAuthContext{BaseContext: model.BaseContext{Username: "dave", Email: "dave@example.com", Name: "Dave"}},
},
run: func(c *model.UserContext) any { return c.GetUsername() },
expected: "dave",
},
{
description: "GetEmail returns local email for ProviderLocal",
context: &model.UserContext{
Provider: model.ProviderLocal,
Local: &model.LocalContext{BaseContext: model.BaseContext{Email: "alice@example.com"}},
run: func(c *model.UserContext) any {
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
},
run: func(c *model.UserContext) any { return c.GetEmail() },
expected: "alice@example.com",
},
{
description: "GetEmail returns local email for ProviderBasicAuth",
context: &model.UserContext{
Provider: model.ProviderBasicAuth,
Local: &model.LocalContext{BaseContext: model.BaseContext{Email: "bob@example.com"}},
},
run: func(c *model.UserContext) any { return c.GetEmail() },
expected: "bob@example.com",
},
{
description: "GetEmail returns LDAP email for ProviderLDAP",
context: &model.UserContext{
Provider: model.ProviderLDAP,
LDAP: &model.LDAPContext{BaseContext: model.BaseContext{Email: "carol@example.com"}},
},
run: func(c *model.UserContext) any { return c.GetEmail() },
expected: "carol@example.com",
},
{
description: "GetEmail returns OAuth email for ProviderOAuth",
context: &model.UserContext{
Provider: model.ProviderOAuth,
OAuth: &model.OAuthContext{BaseContext: model.BaseContext{Email: "dave@example.com"}},
},
run: func(c *model.UserContext) any { return c.GetEmail() },
expected: "dave@example.com",
},
{
description: "GetName returns local name for ProviderLocal",
context: &model.UserContext{
Provider: model.ProviderLocal,
Local: &model.LocalContext{BaseContext: model.BaseContext{Name: "Alice"}},
},
run: func(c *model.UserContext) any { return c.GetName() },
expected: "Alice",
},
{
description: "GetName returns local name for ProviderBasicAuth",
context: &model.UserContext{
Provider: model.ProviderBasicAuth,
Local: &model.LocalContext{BaseContext: model.BaseContext{Name: "Bob"}},
},
run: func(c *model.UserContext) any { return c.GetName() },
expected: "Bob",
},
{
description: "GetName returns LDAP name for ProviderLDAP",
context: &model.UserContext{
Provider: model.ProviderLDAP,
LDAP: &model.LDAPContext{BaseContext: model.BaseContext{Name: "Carol"}},
},
run: func(c *model.UserContext) any { return c.GetName() },
expected: "Carol",
},
{
description: "GetName returns OAuth name for ProviderOAuth",
context: &model.UserContext{
Provider: model.ProviderOAuth,
OAuth: &model.OAuthContext{BaseContext: model.BaseContext{Name: "Dave"}},
},
run: func(c *model.UserContext) any { return c.GetName() },
expected: "Dave",
expected: [3]string{"dave", "dave@example.com", "Dave"},
},
{
description: "ProviderName returns 'local' for ProviderLocal",
@@ -258,7 +180,7 @@ func TestContext(t *testing.T) {
expected: "GitHub",
},
{
description: "TOTPPending returns true for ProviderLocal when TOTPPending is true",
description: "TOTPPending returns true when local context is pending",
context: &model.UserContext{
Provider: model.ProviderLocal,
Local: &model.LocalContext{TOTPPending: true},
@@ -267,7 +189,7 @@ func TestContext(t *testing.T) {
expected: true,
},
{
description: "TOTPPending returns false for ProviderLocal when TOTPPending is false",
description: "TOTPPending returns false when local context is not pending",
context: &model.UserContext{
Provider: model.ProviderLocal,
Local: &model.LocalContext{TOTPPending: false},
@@ -276,22 +198,10 @@ func TestContext(t *testing.T) {
expected: false,
},
{
description: "TOTPPending returns false for ProviderOAuth",
context: &model.UserContext{
Provider: model.ProviderOAuth,
OAuth: &model.OAuthContext{},
},
run: func(c *model.UserContext) any { return c.TOTPPending() },
expected: false,
},
{
description: "TOTPPending returns false for ProviderLDAP",
context: &model.UserContext{
Provider: model.ProviderLDAP,
LDAP: &model.LDAPContext{},
},
run: func(c *model.UserContext) any { return c.TOTPPending() },
expected: false,
description: "TOTPPending returns false for non-local providers",
context: &model.UserContext{Provider: model.ProviderOAuth, OAuth: &model.OAuthContext{}},
run: func(c *model.UserContext) any { return c.TOTPPending() },
expected: false,
},
{
description: "OAuthName returns DisplayName for ProviderOAuth",
@@ -303,22 +213,45 @@ func TestContext(t *testing.T) {
expected: "Google",
},
{
description: "OAuthName returns empty string for ProviderLocal",
context: &model.UserContext{
Provider: model.ProviderLocal,
Local: &model.LocalContext{},
},
run: func(c *model.UserContext) any { return c.OAuthName() },
expected: "",
description: "OAuthName returns empty string for non-oauth providers",
context: &model.UserContext{Provider: model.ProviderLocal, Local: &model.LocalContext{}},
run: func(c *model.UserContext) any { return c.OAuthName() },
expected: "",
},
{
description: "OAuthName returns empty string for ProviderLDAP",
context: &model.UserContext{
Provider: model.ProviderLDAP,
LDAP: &model.LDAPContext{},
description: "NewFromGin populates context from gin value",
context: &model.UserContext{},
run: func(c *model.UserContext) any {
stored := &model.UserContext{
Authenticated: true,
Provider: model.ProviderLocal,
Local: &model.LocalContext{BaseContext: model.BaseContext{Username: "alice"}},
}
got, err := c.NewFromGin(newGinCtx(stored, true))
if err != nil {
return err.Error()
}
return [2]any{got.Authenticated, got.GetUsername()}
},
run: func(c *model.UserContext) any { return c.OAuthName() },
expected: "",
expected: [2]any{true, "alice"},
},
{
description: "NewFromGin returns error when context value is missing",
context: &model.UserContext{},
run: func(c *model.UserContext) any {
_, err := c.NewFromGin(newGinCtx(nil, false))
return errMsg(err)
},
expected: "failed to get user context",
},
{
description: "NewFromGin returns error when context value has wrong type",
context: &model.UserContext{},
run: func(c *model.UserContext) any {
_, err := c.NewFromGin(newGinCtx("not a user context", true))
return errMsg(err)
},
expected: "invalid user context type",
},
}