mirror of
https://github.com/steveiliop56/tinyauth.git
synced 2026-05-10 06:18:11 +00:00
ef8bbd8c9f
Removes the SQLite dependency for tests and also brings back the option for users to run zero-persistence instances of tinyauth. Adds a new mapErr function for the sqlc wrapper generation to prevent SQL errors from leaking out of the store implementation.
317 lines
9.7 KiB
Go
317 lines
9.7 KiB
Go
package middleware_test
|
|
|
|
import (
|
|
"context"
|
|
"encoding/base64"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"github.com/tinyauthapp/tinyauth/internal/middleware"
|
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
|
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
|
|
"github.com/tinyauthapp/tinyauth/internal/service"
|
|
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
|
)
|
|
|
|
func TestContextMiddleware(t *testing.T) {
|
|
tlog.NewTestLogger().Init()
|
|
|
|
authServiceCfg := service.AuthServiceConfig{
|
|
LocalUsers: &[]model.LocalUser{
|
|
{
|
|
Username: "testuser",
|
|
Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
|
|
},
|
|
{
|
|
Username: "totpuser",
|
|
Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
|
|
TOTPSecret: "JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK",
|
|
},
|
|
},
|
|
SessionExpiry: 10, // 10 seconds, useful for testing
|
|
CookieDomain: "example.com",
|
|
LoginTimeout: 10, // 10 seconds, useful for testing
|
|
LoginMaxRetries: 3,
|
|
SessionCookieName: "tinyauth-session",
|
|
}
|
|
|
|
middlewareCfg := middleware.ContextMiddlewareConfig{
|
|
CookieDomain: "example.com",
|
|
SessionCookieName: "tinyauth-session",
|
|
}
|
|
|
|
basicAuthHeader := func(username, password string) string {
|
|
return "Basic " + base64.StdEncoding.EncodeToString([]byte(username+":"+password))
|
|
}
|
|
|
|
seedSession := func(t *testing.T, queries repository.Store, params repository.CreateSessionParams) {
|
|
t.Helper()
|
|
_, err := queries.CreateSession(context.Background(), params)
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
type runArgs struct {
|
|
do func(req *http.Request) (*model.UserContext, *httptest.ResponseRecorder)
|
|
queries repository.Store
|
|
}
|
|
|
|
type testCase struct {
|
|
description string
|
|
run func(t *testing.T, args runArgs)
|
|
}
|
|
|
|
tests := []testCase{
|
|
{
|
|
description: "Skip path bypasses auth processing",
|
|
run: func(t *testing.T, args runArgs) {
|
|
req := httptest.NewRequest("GET", "/api/healthz", nil)
|
|
req.Header.Set("Authorization", basicAuthHeader("testuser", "password"))
|
|
userCtx, _ := args.do(req)
|
|
|
|
assert.Nil(t, userCtx)
|
|
},
|
|
},
|
|
{
|
|
description: "No credentials yields no context",
|
|
run: func(t *testing.T, args runArgs) {
|
|
req := httptest.NewRequest("GET", "/api/test", nil)
|
|
userCtx, _ := args.do(req)
|
|
|
|
assert.Nil(t, userCtx)
|
|
},
|
|
},
|
|
{
|
|
description: "Valid session cookie sets authenticated local context",
|
|
run: func(t *testing.T, args runArgs) {
|
|
uuid := "session-valid-local"
|
|
seedSession(t, args.queries, repository.CreateSessionParams{
|
|
UUID: uuid,
|
|
Username: "testuser",
|
|
Provider: "local",
|
|
Expiry: time.Now().Add(10 * time.Second).Unix(),
|
|
CreatedAt: time.Now().Unix(),
|
|
})
|
|
|
|
req := httptest.NewRequest("GET", "/api/test", nil)
|
|
req.AddCookie(&http.Cookie{Name: "tinyauth-session", Value: uuid})
|
|
userCtx, _ := args.do(req)
|
|
|
|
require.NotNil(t, userCtx)
|
|
assert.Equal(t, model.ProviderLocal, userCtx.Provider)
|
|
assert.Equal(t, "testuser", userCtx.GetUsername())
|
|
assert.True(t, userCtx.Authenticated)
|
|
require.NotNil(t, userCtx.Local)
|
|
},
|
|
},
|
|
{
|
|
description: "Session cookie with totp pending sets unauthenticated context with totp enabled",
|
|
run: func(t *testing.T, args runArgs) {
|
|
uuid := "session-totp-pending"
|
|
seedSession(t, args.queries, repository.CreateSessionParams{
|
|
UUID: uuid,
|
|
Username: "totpuser",
|
|
Provider: "local",
|
|
TotpPending: true,
|
|
Expiry: time.Now().Add(60 * time.Second).Unix(),
|
|
CreatedAt: time.Now().Unix(),
|
|
})
|
|
|
|
req := httptest.NewRequest("GET", "/api/test", nil)
|
|
req.AddCookie(&http.Cookie{Name: "tinyauth-session", Value: uuid})
|
|
userCtx, _ := args.do(req)
|
|
|
|
require.NotNil(t, userCtx)
|
|
assert.Equal(t, "totpuser", userCtx.GetUsername())
|
|
assert.False(t, userCtx.Authenticated)
|
|
require.NotNil(t, userCtx.Local)
|
|
assert.True(t, userCtx.Local.TOTPPending)
|
|
},
|
|
},
|
|
{
|
|
description: "Unknown session cookie yields no context",
|
|
run: func(t *testing.T, args runArgs) {
|
|
req := httptest.NewRequest("GET", "/api/test", nil)
|
|
req.AddCookie(&http.Cookie{Name: "tinyauth-session", Value: "does-not-exist"})
|
|
userCtx, _ := args.do(req)
|
|
|
|
assert.Nil(t, userCtx)
|
|
},
|
|
},
|
|
{
|
|
description: "Session for missing local user yields no context",
|
|
run: func(t *testing.T, args runArgs) {
|
|
uuid := "session-deleted-user"
|
|
seedSession(t, args.queries, repository.CreateSessionParams{
|
|
UUID: uuid,
|
|
Username: "ghostuser",
|
|
Provider: "local",
|
|
Expiry: time.Now().Add(10 * time.Second).Unix(),
|
|
CreatedAt: time.Now().Unix(),
|
|
})
|
|
|
|
req := httptest.NewRequest("GET", "/api/test", nil)
|
|
req.AddCookie(&http.Cookie{Name: "tinyauth-session", Value: uuid})
|
|
userCtx, _ := args.do(req)
|
|
|
|
assert.Nil(t, userCtx)
|
|
},
|
|
},
|
|
{
|
|
description: "Expired session cookie yields no context",
|
|
run: func(t *testing.T, args runArgs) {
|
|
uuid := "session-expired"
|
|
seedSession(t, args.queries, repository.CreateSessionParams{
|
|
UUID: uuid,
|
|
Username: "testuser",
|
|
Provider: "local",
|
|
Expiry: time.Now().Add(-1 * time.Second).Unix(),
|
|
CreatedAt: time.Now().Add(-10 * time.Second).Unix(),
|
|
})
|
|
|
|
req := httptest.NewRequest("GET", "/api/test", nil)
|
|
req.AddCookie(&http.Cookie{Name: "tinyauth-session", Value: uuid})
|
|
userCtx, _ := args.do(req)
|
|
|
|
assert.Nil(t, userCtx)
|
|
},
|
|
},
|
|
{
|
|
description: "Valid basic auth sets authenticated local context",
|
|
run: func(t *testing.T, args runArgs) {
|
|
req := httptest.NewRequest("GET", "/api/test", nil)
|
|
req.Header.Set("Authorization", basicAuthHeader("testuser", "password"))
|
|
userCtx, _ := args.do(req)
|
|
|
|
require.NotNil(t, userCtx)
|
|
assert.Equal(t, model.ProviderLocal, userCtx.Provider)
|
|
assert.Equal(t, "testuser", userCtx.GetUsername())
|
|
assert.True(t, userCtx.Authenticated)
|
|
},
|
|
},
|
|
{
|
|
description: "Invalid basic auth password yields no context",
|
|
run: func(t *testing.T, args runArgs) {
|
|
req := httptest.NewRequest("GET", "/api/test", nil)
|
|
req.Header.Set("Authorization", basicAuthHeader("testuser", "wrongpassword"))
|
|
userCtx, _ := args.do(req)
|
|
|
|
assert.Nil(t, userCtx)
|
|
},
|
|
},
|
|
{
|
|
description: "Basic auth is rejected for users with totp",
|
|
run: func(t *testing.T, args runArgs) {
|
|
req := httptest.NewRequest("GET", "/api/test", nil)
|
|
req.Header.Set("Authorization", basicAuthHeader("totpuser", "password"))
|
|
userCtx, _ := args.do(req)
|
|
|
|
assert.Nil(t, userCtx)
|
|
},
|
|
},
|
|
{
|
|
description: "Locked account on basic auth sets lock headers",
|
|
run: func(t *testing.T, args runArgs) {
|
|
for range 3 {
|
|
req := httptest.NewRequest("GET", "/api/test", nil)
|
|
req.Header.Set("Authorization", basicAuthHeader("testuser", "wrongpassword"))
|
|
args.do(req)
|
|
}
|
|
|
|
req := httptest.NewRequest("GET", "/api/test", nil)
|
|
req.Header.Set("Authorization", basicAuthHeader("testuser", "password"))
|
|
userCtx, recorder := args.do(req)
|
|
|
|
assert.Nil(t, userCtx)
|
|
assert.Equal(t, "true", recorder.Header().Get("x-tinyauth-lock-locked"))
|
|
assert.NotEmpty(t, recorder.Header().Get("x-tinyauth-lock-reset"))
|
|
},
|
|
},
|
|
{
|
|
description: "Cookie auth takes precedence over basic auth",
|
|
run: func(t *testing.T, args runArgs) {
|
|
uuid := "session-precedence"
|
|
seedSession(t, args.queries, repository.CreateSessionParams{
|
|
UUID: uuid,
|
|
Username: "testuser",
|
|
Provider: "local",
|
|
Expiry: time.Now().Add(10 * time.Second).Unix(),
|
|
CreatedAt: time.Now().Unix(),
|
|
})
|
|
|
|
req := httptest.NewRequest("GET", "/api/test", nil)
|
|
req.AddCookie(&http.Cookie{Name: "tinyauth-session", Value: uuid})
|
|
req.Header.Set("Authorization", basicAuthHeader("totpuser", "password"))
|
|
userCtx, _ := args.do(req)
|
|
|
|
require.NotNil(t, userCtx)
|
|
assert.Equal(t, "testuser", userCtx.GetUsername())
|
|
assert.True(t, userCtx.Authenticated)
|
|
},
|
|
},
|
|
{
|
|
description: "Ensure fallback to basic auth when cookie is missing",
|
|
run: func(t *testing.T, args runArgs) {
|
|
req := httptest.NewRequest("GET", "/api/test", nil)
|
|
req.Header.Set("Authorization", basicAuthHeader("testuser", "password"))
|
|
userCtx, _ := args.do(req)
|
|
|
|
require.NotNil(t, userCtx)
|
|
assert.Equal(t, "testuser", userCtx.GetUsername())
|
|
assert.True(t, userCtx.Authenticated)
|
|
},
|
|
},
|
|
}
|
|
|
|
oauthBrokerCfgs := make(map[string]model.OAuthServiceConfig)
|
|
|
|
store := memory.New()
|
|
|
|
ldap := service.NewLdapService(service.LdapServiceConfig{})
|
|
err := ldap.Init()
|
|
require.NoError(t, err)
|
|
|
|
broker := service.NewOAuthBrokerService(oauthBrokerCfgs)
|
|
err = broker.Init()
|
|
require.NoError(t, err)
|
|
|
|
authService := service.NewAuthService(authServiceCfg, ldap, store, broker)
|
|
err = authService.Init()
|
|
require.NoError(t, err)
|
|
|
|
contextMiddleware := middleware.NewContextMiddleware(middlewareCfg, authService, broker)
|
|
err = contextMiddleware.Init()
|
|
require.NoError(t, err)
|
|
|
|
for _, test := range tests {
|
|
authService.ClearRateLimitsTestingOnly()
|
|
t.Run(test.description, func(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
|
|
do := func(req *http.Request) (*model.UserContext, *httptest.ResponseRecorder) {
|
|
var captured *model.UserContext
|
|
router := gin.New()
|
|
router.Use(contextMiddleware.Middleware())
|
|
handler := func(c *gin.Context) {
|
|
if val, exists := c.Get("context"); exists {
|
|
captured, _ = val.(*model.UserContext)
|
|
}
|
|
}
|
|
router.GET("/api/test", handler)
|
|
router.GET("/api/healthz", handler)
|
|
|
|
recorder := httptest.NewRecorder()
|
|
router.ServeHTTP(recorder, req)
|
|
return captured, recorder
|
|
}
|
|
|
|
test.run(t, runArgs{do: do, queries: store})
|
|
})
|
|
}
|
|
}
|