package middleware_test import ( "context" "encoding/base64" "net/http" "net/http/httptest" "path" "testing" "time" "github.com/gin-gonic/gin" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/tinyauthapp/tinyauth/internal/bootstrap" "github.com/tinyauthapp/tinyauth/internal/middleware" "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/utils/tlog" ) func TestContextMiddleware(t *testing.T) { tlog.NewTestLogger().Init() tempDir := t.TempDir() authServiceCfg := service.AuthServiceConfig{ LocalUsers: &[]model.LocalUser{ { Username: "testuser", Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password }, { Username: "totpuser", Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password TOTPSecret: "JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", }, }, SessionExpiry: 10, // 10 seconds, useful for testing CookieDomain: "example.com", LoginTimeout: 10, // 10 seconds, useful for testing LoginMaxRetries: 3, SessionCookieName: "tinyauth-session", } middlewareCfg := middleware.ContextMiddlewareConfig{ CookieDomain: "example.com", SessionCookieName: "tinyauth-session", } basicAuthHeader := func(username, password string) string { return "Basic " + base64.StdEncoding.EncodeToString([]byte(username+":"+password)) } seedSession := func(t *testing.T, queries *repository.Queries, params repository.CreateSessionParams) { t.Helper() _, err := queries.CreateSession(context.Background(), params) require.NoError(t, err) } type runArgs struct { do func(req *http.Request) (*model.UserContext, *httptest.ResponseRecorder) queries *repository.Queries } type testCase struct { description string run func(t *testing.T, args runArgs) } tests := []testCase{ { description: "Skip path bypasses auth processing", run: func(t *testing.T, args runArgs) { req := httptest.NewRequest("GET", "/api/healthz", nil) req.Header.Set("Authorization", basicAuthHeader("testuser", "password")) userCtx, _ := args.do(req) assert.Nil(t, userCtx) }, }, { description: "No credentials yields no context", run: func(t *testing.T, args runArgs) { req := httptest.NewRequest("GET", "/api/test", nil) userCtx, _ := args.do(req) assert.Nil(t, userCtx) }, }, { description: "Valid session cookie sets authenticated local context", run: func(t *testing.T, args runArgs) { uuid := "session-valid-local" seedSession(t, args.queries, repository.CreateSessionParams{ UUID: uuid, Username: "testuser", Provider: "local", Expiry: time.Now().Add(10 * time.Second).Unix(), CreatedAt: time.Now().Unix(), }) req := httptest.NewRequest("GET", "/api/test", nil) req.AddCookie(&http.Cookie{Name: "tinyauth-session", Value: uuid}) userCtx, _ := args.do(req) require.NotNil(t, userCtx) assert.Equal(t, model.ProviderLocal, userCtx.Provider) assert.Equal(t, "testuser", userCtx.GetUsername()) assert.True(t, userCtx.Authenticated) require.NotNil(t, userCtx.Local) assert.False(t, userCtx.Local.TOTPEnabled) }, }, { description: "Session cookie with totp pending sets unauthenticated context with totp enabled", run: func(t *testing.T, args runArgs) { uuid := "session-totp-pending" seedSession(t, args.queries, repository.CreateSessionParams{ UUID: uuid, Username: "totpuser", Provider: "local", TotpPending: true, Expiry: time.Now().Add(60 * time.Second).Unix(), CreatedAt: time.Now().Unix(), }) req := httptest.NewRequest("GET", "/api/test", nil) req.AddCookie(&http.Cookie{Name: "tinyauth-session", Value: uuid}) userCtx, _ := args.do(req) require.NotNil(t, userCtx) assert.Equal(t, "totpuser", userCtx.GetUsername()) assert.False(t, userCtx.Authenticated) require.NotNil(t, userCtx.Local) assert.True(t, userCtx.Local.TOTPPending) assert.True(t, userCtx.Local.TOTPEnabled) }, }, { description: "Unknown session cookie yields no context", run: func(t *testing.T, args runArgs) { req := httptest.NewRequest("GET", "/api/test", nil) req.AddCookie(&http.Cookie{Name: "tinyauth-session", Value: "does-not-exist"}) userCtx, _ := args.do(req) assert.Nil(t, userCtx) }, }, { description: "Session for missing local user yields no context", run: func(t *testing.T, args runArgs) { uuid := "session-deleted-user" seedSession(t, args.queries, repository.CreateSessionParams{ UUID: uuid, Username: "ghostuser", Provider: "local", Expiry: time.Now().Add(10 * time.Second).Unix(), CreatedAt: time.Now().Unix(), }) req := httptest.NewRequest("GET", "/api/test", nil) req.AddCookie(&http.Cookie{Name: "tinyauth-session", Value: uuid}) userCtx, _ := args.do(req) assert.Nil(t, userCtx) }, }, { description: "Expired session cookie yields no context", run: func(t *testing.T, args runArgs) { uuid := "session-expired" seedSession(t, args.queries, repository.CreateSessionParams{ UUID: uuid, Username: "testuser", Provider: "local", Expiry: time.Now().Add(-1 * time.Second).Unix(), CreatedAt: time.Now().Add(-10 * time.Second).Unix(), }) req := httptest.NewRequest("GET", "/api/test", nil) req.AddCookie(&http.Cookie{Name: "tinyauth-session", Value: uuid}) userCtx, _ := args.do(req) assert.Nil(t, userCtx) }, }, { description: "Valid basic auth sets authenticated local context", run: func(t *testing.T, args runArgs) { req := httptest.NewRequest("GET", "/api/test", nil) req.Header.Set("Authorization", basicAuthHeader("testuser", "password")) userCtx, _ := args.do(req) require.NotNil(t, userCtx) assert.Equal(t, model.ProviderLocal, userCtx.Provider) assert.Equal(t, "testuser", userCtx.GetUsername()) assert.True(t, userCtx.Authenticated) }, }, { description: "Invalid basic auth password yields no context", run: func(t *testing.T, args runArgs) { req := httptest.NewRequest("GET", "/api/test", nil) req.Header.Set("Authorization", basicAuthHeader("testuser", "wrongpassword")) userCtx, _ := args.do(req) assert.Nil(t, userCtx) }, }, { description: "Basic auth is rejected for users with totp", run: func(t *testing.T, args runArgs) { req := httptest.NewRequest("GET", "/api/test", nil) req.Header.Set("Authorization", basicAuthHeader("totpuser", "password")) userCtx, _ := args.do(req) assert.Nil(t, userCtx) }, }, { description: "Locked account on basic auth sets lock headers", run: func(t *testing.T, args runArgs) { for range 3 { req := httptest.NewRequest("GET", "/api/test", nil) req.Header.Set("Authorization", basicAuthHeader("testuser", "wrongpassword")) args.do(req) } req := httptest.NewRequest("GET", "/api/test", nil) req.Header.Set("Authorization", basicAuthHeader("testuser", "password")) userCtx, recorder := args.do(req) assert.Nil(t, userCtx) assert.Equal(t, "true", recorder.Header().Get("x-tinyauth-lock-locked")) assert.NotEmpty(t, recorder.Header().Get("x-tinyauth-lock-reset")) }, }, { description: "Cookie auth takes precedence over basic auth", run: func(t *testing.T, args runArgs) { uuid := "session-precedence" seedSession(t, args.queries, repository.CreateSessionParams{ UUID: uuid, Username: "testuser", Provider: "local", Expiry: time.Now().Add(10 * time.Second).Unix(), CreatedAt: time.Now().Unix(), }) req := httptest.NewRequest("GET", "/api/test", nil) req.AddCookie(&http.Cookie{Name: "tinyauth-session", Value: uuid}) req.Header.Set("Authorization", basicAuthHeader("totpuser", "password")) userCtx, _ := args.do(req) require.NotNil(t, userCtx) assert.Equal(t, "testuser", userCtx.GetUsername()) assert.True(t, userCtx.Authenticated) }, }, { description: "Ensure fallback to basic auth when cookie is missing", run: func(t *testing.T, args runArgs) { req := httptest.NewRequest("GET", "/api/test", nil) req.Header.Set("Authorization", basicAuthHeader("testuser", "password")) userCtx, _ := args.do(req) require.NotNil(t, userCtx) assert.Equal(t, "testuser", userCtx.GetUsername()) assert.True(t, userCtx.Authenticated) }, }, } oauthBrokerCfgs := make(map[string]model.OAuthServiceConfig) app := bootstrap.NewBootstrapApp(model.Config{}) db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db")) require.NoError(t, err) queries := repository.New(db) ldap := service.NewLdapService(service.LdapServiceConfig{}) err = ldap.Init() require.NoError(t, err) broker := service.NewOAuthBrokerService(oauthBrokerCfgs) err = broker.Init() require.NoError(t, err) authService := service.NewAuthService(authServiceCfg, ldap, queries, broker) err = authService.Init() require.NoError(t, err) contextMiddleware := middleware.NewContextMiddleware(middlewareCfg, authService, broker) err = contextMiddleware.Init() require.NoError(t, err) for _, test := range tests { authService.ClearRateLimitsTestingOnly() t.Run(test.description, func(t *testing.T) { gin.SetMode(gin.TestMode) do := func(req *http.Request) (*model.UserContext, *httptest.ResponseRecorder) { var captured *model.UserContext router := gin.New() router.Use(contextMiddleware.Middleware()) handler := func(c *gin.Context) { if val, exists := c.Get("context"); exists { captured, _ = val.(*model.UserContext) } } router.GET("/api/test", handler) router.GET("/api/healthz", handler) recorder := httptest.NewRecorder() router.ServeHTTP(recorder, req) return captured, recorder } test.run(t, runArgs{do: do, queries: queries}) }) } t.Cleanup(func() { err = db.Close() require.NoError(t, err) }) }