From 39beed706b6002fcc1cb748214854f1af6ca8bb4 Mon Sep 17 00:00:00 2001 From: Stavros Date: Sat, 28 Mar 2026 20:26:47 +0200 Subject: [PATCH] tests: add tests for user controller --- internal/controller/oidc_controller_test.go | 10 +- internal/controller/user_controller_test.go | 559 +++++++++++--------- internal/service/auth_service.go | 7 + 3 files changed, 309 insertions(+), 267 deletions(-) diff --git a/internal/controller/oidc_controller_test.go b/internal/controller/oidc_controller_test.go index 18e5749..ac8a4cf 100644 --- a/internal/controller/oidc_controller_test.go +++ b/internal/controller/oidc_controller_test.go @@ -431,18 +431,12 @@ func TestOIDCController(t *testing.T) { app := bootstrap.NewBootstrapApp(config.Config{}) db, err := app.SetupDatabase("/tmp/tinyauth_test.db") - - if err != nil { - t.Fatalf("Failed to set up database: %v", err) - } + assert.NoError(t, err) queries := repository.New(db) oidcService := service.NewOIDCService(oidcServiceCfg, queries) err = oidcService.Init() - - if err != nil { - t.Fatalf("Failed to initialize OIDC service: %v", err) - } + assert.NoError(t, err) for _, test := range tests { t.Run(test.description, func(t *testing.T) { diff --git a/internal/controller/user_controller_test.go b/internal/controller/user_controller_test.go index 672740c..06686e6 100644 --- a/internal/controller/user_controller_test.go +++ b/internal/controller/user_controller_test.go @@ -2,305 +2,346 @@ package controller_test import ( "encoding/json" - "net/http" "net/http/httptest" + "slices" "strings" "testing" "time" + "github.com/gin-gonic/gin" + "github.com/pquerna/otp/totp" "github.com/steveiliop56/tinyauth/internal/bootstrap" "github.com/steveiliop56/tinyauth/internal/config" "github.com/steveiliop56/tinyauth/internal/controller" "github.com/steveiliop56/tinyauth/internal/repository" "github.com/steveiliop56/tinyauth/internal/service" "github.com/steveiliop56/tinyauth/internal/utils/tlog" - - "github.com/gin-gonic/gin" - "github.com/pquerna/otp/totp" - "gotest.tools/v3/assert" + "github.com/stretchr/testify/assert" ) -var cookieValue string -var totpSecret = "6WFZXPEZRK5MZHHYAFW4DAOUYQMCASBJ" - -func setupUserController(t *testing.T, middlewares *[]gin.HandlerFunc) (*gin.Engine, *httptest.ResponseRecorder) { - tlog.NewSimpleLogger().Init() - - // Setup - gin.SetMode(gin.TestMode) - router := gin.Default() - - if middlewares != nil { - for _, m := range *middlewares { - router.Use(m) - } - } - - group := router.Group("/api") - recorder := httptest.NewRecorder() - - // Mock app - app := bootstrap.NewBootstrapApp(config.Config{}) - - // Database - db, err := app.SetupDatabase(":memory:") - - assert.NilError(t, err) - - // Queries - queries := repository.New(db) - - // Auth service - authService := service.NewAuthService(service.AuthServiceConfig{ +func TestUserController(t *testing.T) { + authServiceCfg := service.AuthServiceConfig{ Users: []config.User{ { Username: "testuser", - Password: "$2a$10$ne6z693sTgzT3ePoQ05PgOecUHnBjM7sSNj6M.l5CLUP.f6NyCnt.", // test + Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", }, { Username: "totpuser", - Password: "$2a$10$ne6z693sTgzT3ePoQ05PgOecUHnBjM7sSNj6M.l5CLUP.f6NyCnt.", // test - TotpSecret: totpSecret, + Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", + TotpSecret: "JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", }, }, - OauthWhitelist: []string{}, - SessionExpiry: 3600, - SessionMaxLifetime: 0, - SecureCookie: false, - CookieDomain: "localhost", - LoginTimeout: 300, - LoginMaxRetries: 3, - SessionCookieName: "tinyauth-session", - }, nil, nil, queries, &service.OAuthBrokerService{}) - - // Controller - ctrl := controller.NewUserController(controller.UserControllerConfig{ - CookieDomain: "localhost", - }, group, authService) - ctrl.SetupRoutes() - - return router, recorder -} - -func TestLoginHandler(t *testing.T) { - // Setup - router, recorder := setupUserController(t, nil) - - loginReq := controller.LoginRequest{ - Username: "testuser", - Password: "test", + SessionExpiry: 10, // 10 seconds, useful for testing + CookieDomain: "example.com", + LoginTimeout: 10, // 10 seconds, useful for testing + LoginMaxRetries: 3, + SessionCookieName: "tinyauth-session", } - loginReqJson, err := json.Marshal(loginReq) - assert.NilError(t, err) - - // Test - req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqJson))) - router.ServeHTTP(recorder, req) - - assert.Equal(t, 200, recorder.Code) - - cookie := recorder.Result().Cookies()[0] - - assert.Equal(t, "tinyauth-session", cookie.Name) - assert.Assert(t, cookie.Value != "") - - cookieValue = cookie.Value - - // Test invalid credentials - loginReq = controller.LoginRequest{ - Username: "testuser", - Password: "invalid", + userControllerCfg := controller.UserControllerConfig{ + CookieDomain: "example.com", } - loginReqJson, err = json.Marshal(loginReq) - assert.NilError(t, err) - - recorder = httptest.NewRecorder() - req = httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqJson))) - router.ServeHTTP(recorder, req) - - assert.Equal(t, 401, recorder.Code) - - // Test totp required - loginReq = controller.LoginRequest{ - Username: "totpuser", - Password: "test", + type testCase struct { + description string + middlewares []gin.HandlerFunc + run func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) } - loginReqJson, err = json.Marshal(loginReq) - assert.NilError(t, err) + tests := []testCase{ + { + description: "Should be able to login with valid credentials", + middlewares: []gin.HandlerFunc{}, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + loginReq := controller.LoginRequest{ + Username: "testuser", + Password: "password", + } + loginReqBody, err := json.Marshal(loginReq) + assert.NoError(t, err) - recorder = httptest.NewRecorder() - req = httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqJson))) - router.ServeHTTP(recorder, req) + req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqBody))) + req.Header.Set("Content-Type", "application/json") - assert.Equal(t, 200, recorder.Code) + router.ServeHTTP(recorder, req) - loginResJson, err := json.Marshal(map[string]any{ - "message": "TOTP required", - "status": 200, - "totpPending": true, - }) + assert.Equal(t, 200, recorder.Code) + assert.Len(t, recorder.Result().Cookies(), 1) - assert.NilError(t, err) - assert.Equal(t, string(loginResJson), recorder.Body.String()) - - // Test invalid json - recorder = httptest.NewRecorder() - req = httptest.NewRequest("POST", "/api/user/login", strings.NewReader("{invalid json}")) - router.ServeHTTP(recorder, req) - - assert.Equal(t, 400, recorder.Code) - - // Test rate limiting - loginReq = controller.LoginRequest{ - Username: "testuser", - Password: "invalid", - } - - loginReqJson, err = json.Marshal(loginReq) - assert.NilError(t, err) - - for range 5 { - recorder = httptest.NewRecorder() - req = httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqJson))) - router.ServeHTTP(recorder, req) - } - - assert.Equal(t, 429, recorder.Code) -} - -func TestLogoutHandler(t *testing.T) { - // Setup - router, recorder := setupUserController(t, nil) - - // Test - req := httptest.NewRequest("POST", "/api/user/logout", nil) - - req.AddCookie(&http.Cookie{ - Name: "tinyauth-session", - Value: cookieValue, - }) - - router.ServeHTTP(recorder, req) - - assert.Equal(t, 200, recorder.Code) - - cookie := recorder.Result().Cookies()[0] - - assert.Equal(t, "tinyauth-session", cookie.Name) - assert.Equal(t, "", cookie.Value) - assert.Equal(t, -1, cookie.MaxAge) -} - -func TestTotpHandler(t *testing.T) { - // Setup - router, recorder := setupUserController(t, &[]gin.HandlerFunc{ - func(c *gin.Context) { - c.Set("context", &config.UserContext{ - Username: "totpuser", - Name: "totpuser", - Email: "totpuser@example.com", - IsLoggedIn: false, - OAuth: false, - Provider: "local", - TotpPending: true, - OAuthGroups: "", - TotpEnabled: true, - }) - c.Next() + cookie := recorder.Result().Cookies()[0] + assert.Equal(t, "tinyauth-session", cookie.Name) + assert.True(t, cookie.HttpOnly) + assert.Equal(t, "example.com", cookie.Domain) + assert.Equal(t, cookie.MaxAge, 10) + }, }, - }) + { + description: "Should reject login with invalid credentials", + middlewares: []gin.HandlerFunc{}, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + loginReq := controller.LoginRequest{ + Username: "testuser", + Password: "wrongpassword", + } + loginReqBody, err := json.Marshal(loginReq) + assert.NoError(t, err) - // Test - code, err := totp.GenerateCode(totpSecret, time.Now()) + req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqBody))) + req.Header.Set("Content-Type", "application/json") - assert.NilError(t, err) + router.ServeHTTP(recorder, req) - totpReq := controller.TotpRequest{ - Code: code, - } - - totpReqJson, err := json.Marshal(totpReq) - assert.NilError(t, err) - - req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(totpReqJson))) - router.ServeHTTP(recorder, req) - - assert.Equal(t, 200, recorder.Code) - - cookie := recorder.Result().Cookies()[0] - - assert.Equal(t, "tinyauth-session", cookie.Name) - assert.Assert(t, cookie.Value != "") - - // Test invalid json - recorder = httptest.NewRecorder() - req = httptest.NewRequest("POST", "/api/user/totp", strings.NewReader("{invalid json}")) - router.ServeHTTP(recorder, req) - - assert.Equal(t, 400, recorder.Code) - - // Test rate limiting - totpReq = controller.TotpRequest{ - Code: "000000", - } - - totpReqJson, err = json.Marshal(totpReq) - assert.NilError(t, err) - - for range 5 { - recorder = httptest.NewRecorder() - req = httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(totpReqJson))) - router.ServeHTTP(recorder, req) - } - - assert.Equal(t, 429, recorder.Code) - - // Test invalid code - router, recorder = setupUserController(t, &[]gin.HandlerFunc{ - func(c *gin.Context) { - c.Set("context", &config.UserContext{ - Username: "totpuser", - Name: "totpuser", - Email: "totpuser@example.com", - IsLoggedIn: false, - OAuth: false, - Provider: "local", - TotpPending: true, - OAuthGroups: "", - TotpEnabled: true, - }) - c.Next() + assert.Equal(t, 401, recorder.Code) + assert.Len(t, recorder.Result().Cookies(), 0) + assert.Contains(t, recorder.Body.String(), "Unauthorized") + }, }, - }) + { + description: "Should rate limit on 3 invalid attempts", + middlewares: []gin.HandlerFunc{}, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + loginReq := controller.LoginRequest{ + Username: "testuser", + Password: "wrongpassword", + } + loginReqBody, err := json.Marshal(loginReq) + assert.NoError(t, err) - req = httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(totpReqJson))) - router.ServeHTTP(recorder, req) + for range 3 { + recorder := httptest.NewRecorder() - assert.Equal(t, 401, recorder.Code) + req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqBody))) + req.Header.Set("Content-Type", "application/json") - // Test no totp pending - router, recorder = setupUserController(t, &[]gin.HandlerFunc{ - func(c *gin.Context) { - c.Set("context", &config.UserContext{ - Username: "totpuser", - Name: "totpuser", - Email: "totpuser@example.com", - IsLoggedIn: false, - OAuth: false, - Provider: "local", - TotpPending: false, - OAuthGroups: "", - TotpEnabled: false, - }) - c.Next() + router.ServeHTTP(recorder, req) + + assert.Equal(t, 401, recorder.Code) + assert.Len(t, recorder.Result().Cookies(), 0) + assert.Contains(t, recorder.Body.String(), "Unauthorized") + } + + // 4th attempt should be rate limited + recorder = httptest.NewRecorder() + req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqBody))) + req.Header.Set("Content-Type", "application/json") + + router.ServeHTTP(recorder, req) + + assert.Equal(t, 429, recorder.Code) + assert.Contains(t, recorder.Body.String(), "Too many failed login attempts.") + }, }, - }) + { + description: "Should not allow full login with totp", + middlewares: []gin.HandlerFunc{}, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + loginReq := controller.LoginRequest{ + Username: "totpuser", + Password: "password", + } + loginReqBody, err := json.Marshal(loginReq) + assert.NoError(t, err) - req = httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(totpReqJson))) - router.ServeHTTP(recorder, req) + req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqBody))) + req.Header.Set("Content-Type", "application/json") - assert.Equal(t, 401, recorder.Code) + router.ServeHTTP(recorder, req) + + assert.Equal(t, 200, recorder.Code) + + decodedBody := make(map[string]any) + err = json.Unmarshal(recorder.Body.Bytes(), &decodedBody) + assert.NoError(t, err) + + assert.Equal(t, decodedBody["totpPending"], true) + + // should set the session cookie + assert.Len(t, recorder.Result().Cookies(), 1) + cookie := recorder.Result().Cookies()[0] + assert.Equal(t, "tinyauth-session", cookie.Name) + assert.True(t, cookie.HttpOnly) + assert.Equal(t, "example.com", cookie.Domain) + assert.Equal(t, cookie.MaxAge, 3600) // 1 hour, default for totp pending sessions + }, + }, + { + description: "Should be able to logout", + middlewares: []gin.HandlerFunc{}, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + // First login to get a session cookie + loginReq := controller.LoginRequest{ + Username: "testuser", + Password: "password", + } + loginReqBody, err := json.Marshal(loginReq) + assert.NoError(t, err) + + req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqBody))) + req.Header.Set("Content-Type", "application/json") + + router.ServeHTTP(recorder, req) + + assert.Equal(t, 200, recorder.Code) + assert.Len(t, recorder.Result().Cookies(), 1) + + cookie := recorder.Result().Cookies()[0] + assert.Equal(t, "tinyauth-session", cookie.Name) + + // Now logout using the session cookie + recorder = httptest.NewRecorder() + req = httptest.NewRequest("POST", "/api/user/logout", nil) + req.AddCookie(cookie) + + router.ServeHTTP(recorder, req) + + assert.Equal(t, 200, recorder.Code) + assert.Len(t, recorder.Result().Cookies(), 1) + + logoutCookie := recorder.Result().Cookies()[0] + assert.Equal(t, "tinyauth-session", logoutCookie.Name) + assert.Equal(t, "", logoutCookie.Value) + assert.Equal(t, -1, logoutCookie.MaxAge) // MaxAge -1 means delete cookie + }, + }, + { + description: "Should be able to login with totp", + middlewares: []gin.HandlerFunc{}, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + code, err := totp.GenerateCode("JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", time.Now()) + assert.NoError(t, err) + + totpReq := controller.TotpRequest{ + Code: code, + } + + totpReqBody, err := json.Marshal(totpReq) + assert.NoError(t, err) + + recorder = httptest.NewRecorder() + req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(totpReqBody))) + req.Header.Set("Content-Type", "application/json") + + router.ServeHTTP(recorder, req) + + assert.Equal(t, 200, recorder.Code) + assert.Len(t, recorder.Result().Cookies(), 1) + + // should set a new session cookie with totp pending removed + totpCookie := recorder.Result().Cookies()[0] + assert.Equal(t, "tinyauth-session", totpCookie.Name) + assert.True(t, totpCookie.HttpOnly) + assert.Equal(t, "example.com", totpCookie.Domain) + assert.Equal(t, totpCookie.MaxAge, 10) // should use the regular session expiry time + }, + }, + { + description: "Totp should rate limit on multiple invalid attempts", + middlewares: []gin.HandlerFunc{}, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + for range 3 { + totpReq := controller.TotpRequest{ + Code: "000000", // invalid code + } + + totpReqBody, err := json.Marshal(totpReq) + assert.NoError(t, err) + + recorder = httptest.NewRecorder() + req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(totpReqBody))) + req.Header.Set("Content-Type", "application/json") + + router.ServeHTTP(recorder, req) + + assert.Equal(t, 401, recorder.Code) + assert.Contains(t, recorder.Body.String(), "Unauthorized") + } + + // 4th attempt should be rate limited + recorder = httptest.NewRecorder() + req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(`{"code":"000000"}`))) + req.Header.Set("Content-Type", "application/json") + + router.ServeHTTP(recorder, req) + + assert.Equal(t, 429, recorder.Code) + assert.Contains(t, recorder.Body.String(), "Too many failed TOTP attempts.") + }, + }, + } + + tlog.NewSimpleLogger().Init() + + oauthBrokerCfgs := make(map[string]config.OAuthServiceConfig) + + app := bootstrap.NewBootstrapApp(config.Config{}) + + db, err := app.SetupDatabase("/tmp/tinyauth_test.db") + assert.NoError(t, err) + + queries := repository.New(db) + + docker := service.NewDockerService() + err = docker.Init() + assert.NoError(t, err) + + ldap := service.NewLdapService(service.LdapServiceConfig{}) + err = ldap.Init() + assert.NoError(t, err) + + broker := service.NewOAuthBrokerService(oauthBrokerCfgs) + err = broker.Init() + assert.NoError(t, err) + + authService := service.NewAuthService(authServiceCfg, docker, ldap, queries, broker) + err = authService.Init() + assert.NoError(t, err) + + beforeEach := func() { + // Clear failed login attempts before each test + authService.ClearRateLimitsTestingOnly() + } + + setTotpMiddlewareOverrides := []string{ + "Should be able to login with totp", + "Totp should rate limit on multiple invalid attempts", + } + + for _, test := range tests { + beforeEach() + t.Run(test.description, func(t *testing.T) { + router := gin.Default() + + for _, middleware := range test.middlewares { + router.Use(middleware) + } + + // Gin is stupid and doesn't allow setting a middleware after the groups + // so we need to do some stupid overrides here + if slices.Contains(setTotpMiddlewareOverrides, test.description) { + // Assuming the cookie is set, it should be picked up by the + // context middleware + router.Use(func(c *gin.Context) { + c.Set("context", &config.UserContext{ + Username: "totpuser", + Name: "Totpuser", + Email: "totpuser@example.com", + Provider: "local", + TotpPending: true, + TotpEnabled: true, + }) + }) + } + + group := router.Group("/api") + gin.SetMode(gin.TestMode) + + userController := controller.NewUserController(userControllerCfg, group, authService) + userController.SetupRoutes() + + recorder := httptest.NewRecorder() + + test.run(t, router, recorder) + }) + } } diff --git a/internal/service/auth_service.go b/internal/service/auth_service.go index 53c879d..0d1a598 100644 --- a/internal/service/auth_service.go +++ b/internal/service/auth_service.go @@ -746,3 +746,10 @@ func (auth *AuthService) ensureOAuthSessionLimit() { } } } + +// Function only used for testing - do not use in prod! +func (auth *AuthService) ClearRateLimitsTestingOnly() { + auth.loginMutex.Lock() + auth.loginAttempts = make(map[string]*LoginAttempt) + auth.loginMutex.Unlock() +}