diff --git a/go.mod b/go.mod index 9b924ab..f3f2fd0 100644 --- a/go.mod +++ b/go.mod @@ -17,6 +17,7 @@ require ( github.com/mdp/qrterminal/v3 v3.2.1 github.com/pquerna/otp v1.5.0 github.com/rs/zerolog v1.34.0 + github.com/stretchr/testify v1.11.1 github.com/traefik/paerser v0.2.2 github.com/weppos/publicsuffix-go v0.50.3 golang.org/x/crypto v0.49.0 @@ -52,6 +53,7 @@ require ( github.com/containerd/errdefs v1.0.0 // indirect github.com/containerd/errdefs/pkg v0.3.0 // indirect github.com/containerd/log v0.1.0 // indirect + github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/distribution/reference v0.6.0 // indirect github.com/docker/go-connections v0.5.0 // indirect github.com/docker/go-units v0.5.0 // indirect @@ -96,6 +98,7 @@ require ( github.com/opencontainers/image-spec v1.1.0 // indirect github.com/pelletier/go-toml/v2 v2.2.4 // indirect github.com/pkg/errors v0.9.1 // indirect + github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/quic-go/qpack v0.6.0 // indirect github.com/quic-go/quic-go v0.59.0 // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect diff --git a/internal/controller/context_controller_test.go b/internal/controller/context_controller_test.go index 5f5e6e9..9a1ba45 100644 --- a/internal/controller/context_controller_test.go +++ b/internal/controller/context_controller_test.go @@ -2,152 +2,131 @@ package controller_test import ( "encoding/json" + "net/http" "net/http/httptest" "testing" + "github.com/gin-gonic/gin" "github.com/steveiliop56/tinyauth/internal/config" "github.com/steveiliop56/tinyauth/internal/controller" - "github.com/steveiliop56/tinyauth/internal/utils/tlog" - - "github.com/gin-gonic/gin" - "gotest.tools/v3/assert" + "github.com/steveiliop56/tinyauth/internal/utils" + "github.com/stretchr/testify/assert" ) -var contextControllerCfg = controller.ContextControllerConfig{ - Providers: []controller.Provider{ +func TestContextController(t *testing.T) { + controllerConfig := controller.ContextControllerConfig{ + Providers: []controller.Provider{ + { + Name: "Local", + ID: "local", + OAuth: false, + }, + }, + Title: "Tinyauth", + AppURL: "https://tinyauth.example.com", + CookieDomain: "example.com", + ForgotPasswordMessage: "foo", + BackgroundImage: "/background.jpg", + OAuthAutoRedirect: "none", + WarningsEnabled: true, + } + + tests := []struct { + description string + middlewares []gin.HandlerFunc + expected string + path string + }{ { - Name: "Local", - ID: "local", - OAuth: false, + description: "Ensure context controller returns app context", + middlewares: []gin.HandlerFunc{}, + path: "/api/context/app", + expected: func() string { + expectedAppContextResponse := controller.AppContextResponse{ + Status: 200, + Message: "Success", + Providers: controllerConfig.Providers, + Title: controllerConfig.Title, + AppURL: controllerConfig.AppURL, + CookieDomain: controllerConfig.CookieDomain, + ForgotPasswordMessage: controllerConfig.ForgotPasswordMessage, + BackgroundImage: controllerConfig.BackgroundImage, + OAuthAutoRedirect: controllerConfig.OAuthAutoRedirect, + WarningsEnabled: controllerConfig.WarningsEnabled, + } + bytes, err := json.Marshal(expectedAppContextResponse) + assert.NoError(t, err) + return string(bytes) + }(), }, { - Name: "Google", - ID: "google", - OAuth: true, + description: "Ensure user context returns 401 when unauthorized", + middlewares: []gin.HandlerFunc{}, + path: "/api/context/user", + expected: func() string { + expectedUserContextResponse := controller.UserContextResponse{ + Status: 401, + Message: "Unauthorized", + } + bytes, err := json.Marshal(expectedUserContextResponse) + assert.NoError(t, err) + return string(bytes) + }(), }, - }, - Title: "Test App", - AppURL: "http://localhost:8080", - CookieDomain: "localhost", - ForgotPasswordMessage: "Contact admin to reset your password.", - BackgroundImage: "/assets/bg.jpg", - OAuthAutoRedirect: "google", - WarningsEnabled: true, -} - -var contextCtrlTestContext = config.UserContext{ - Username: "testuser", - Name: "testuser", - Email: "test@example.com", - IsLoggedIn: true, - IsBasicAuth: false, - OAuth: false, - Provider: "local", - TotpPending: false, - OAuthGroups: "", - TotpEnabled: false, - OAuthSub: "", -} - -func setupContextController(middlewares *[]gin.HandlerFunc) (*gin.Engine, *httptest.ResponseRecorder) { - tlog.NewSimpleLogger().Init() - - // Setup - gin.SetMode(gin.TestMode) - router := gin.Default() - recorder := httptest.NewRecorder() - - if middlewares != nil { - for _, m := range *middlewares { - router.Use(m) - } - } - - group := router.Group("/api") - - ctrl := controller.NewContextController(contextControllerCfg, group) - ctrl.SetupRoutes() - - return router, recorder -} - -func TestAppContextHandler(t *testing.T) { - expectedRes := controller.AppContextResponse{ - Status: 200, - Message: "Success", - Providers: contextControllerCfg.Providers, - Title: contextControllerCfg.Title, - AppURL: contextControllerCfg.AppURL, - CookieDomain: contextControllerCfg.CookieDomain, - ForgotPasswordMessage: contextControllerCfg.ForgotPasswordMessage, - BackgroundImage: contextControllerCfg.BackgroundImage, - OAuthAutoRedirect: contextControllerCfg.OAuthAutoRedirect, - WarningsEnabled: contextControllerCfg.WarningsEnabled, - } - - router, recorder := setupContextController(nil) - req := httptest.NewRequest("GET", "/api/context/app", nil) - router.ServeHTTP(recorder, req) - - assert.Equal(t, 200, recorder.Code) - - var ctrlRes controller.AppContextResponse - - err := json.Unmarshal(recorder.Body.Bytes(), &ctrlRes) - - assert.NilError(t, err) - assert.DeepEqual(t, expectedRes, ctrlRes) -} - -func TestUserContextHandler(t *testing.T) { - expectedRes := controller.UserContextResponse{ - Status: 200, - Message: "Success", - IsLoggedIn: contextCtrlTestContext.IsLoggedIn, - Username: contextCtrlTestContext.Username, - Name: contextCtrlTestContext.Name, - Email: contextCtrlTestContext.Email, - Provider: contextCtrlTestContext.Provider, - OAuth: contextCtrlTestContext.OAuth, - TotpPending: contextCtrlTestContext.TotpPending, - OAuthName: contextCtrlTestContext.OAuthName, - } - - // Test with context - router, recorder := setupContextController(&[]gin.HandlerFunc{ - func(c *gin.Context) { - c.Set("context", &contextCtrlTestContext) - c.Next() + { + description: "Ensure user context returns when authorized", + middlewares: []gin.HandlerFunc{ + func(c *gin.Context) { + c.Set("context", &config.UserContext{ + Username: "johndoe", + Name: "John Doe", + Email: utils.CompileUserEmail("johndoe", controllerConfig.CookieDomain), + Provider: "local", + IsLoggedIn: true, + }) + }, + }, + path: "/api/context/user", + expected: func() string { + expectedUserContextResponse := controller.UserContextResponse{ + Status: 200, + Message: "Success", + IsLoggedIn: true, + Username: "johndoe", + Name: "John Doe", + Email: utils.CompileUserEmail("johndoe", controllerConfig.CookieDomain), + Provider: "local", + } + bytes, err := json.Marshal(expectedUserContextResponse) + assert.NoError(t, err) + return string(bytes) + }(), }, - }) - - req := httptest.NewRequest("GET", "/api/context/user", nil) - router.ServeHTTP(recorder, req) - - assert.Equal(t, 200, recorder.Code) - - var ctrlRes controller.UserContextResponse - - err := json.Unmarshal(recorder.Body.Bytes(), &ctrlRes) - - assert.NilError(t, err) - assert.DeepEqual(t, expectedRes, ctrlRes) - - // Test no context - expectedRes = controller.UserContextResponse{ - Status: 401, - Message: "Unauthorized", - IsLoggedIn: false, } - router, recorder = setupContextController(nil) - req = httptest.NewRequest("GET", "/api/context/user", nil) - router.ServeHTTP(recorder, req) + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { + router := gin.Default() - assert.Equal(t, 200, recorder.Code) + for _, middleware := range test.middlewares { + router.Use(middleware) + } - err = json.Unmarshal(recorder.Body.Bytes(), &ctrlRes) + group := router.Group("/api") + gin.SetMode(gin.TestMode) - assert.NilError(t, err) - assert.DeepEqual(t, expectedRes, ctrlRes) + contextController := controller.NewContextController(controllerConfig, group) + contextController.SetupRoutes() + + recorder := httptest.NewRecorder() + + request, err := http.NewRequest("GET", test.path, nil) + assert.NoError(t, err) + + router.ServeHTTP(recorder, request) + + assert.Equal(t, http.StatusOK, recorder.Code) + assert.Equal(t, test.expected, recorder.Body.String()) + }) + } } diff --git a/internal/controller/health_controller.go b/internal/controller/health_controller.go index 8ad67b5..1b9adbf 100644 --- a/internal/controller/health_controller.go +++ b/internal/controller/health_controller.go @@ -19,7 +19,7 @@ func (controller *HealthController) SetupRoutes() { func (controller *HealthController) healthHandler(c *gin.Context) { c.JSON(200, gin.H{ - "status": "ok", + "status": 200, "message": "Healthy", }) } diff --git a/internal/controller/health_controller_test.go b/internal/controller/health_controller_test.go new file mode 100644 index 0000000..c8cb36f --- /dev/null +++ b/internal/controller/health_controller_test.go @@ -0,0 +1,71 @@ +package controller_test + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + "github.com/steveiliop56/tinyauth/internal/controller" + "github.com/stretchr/testify/assert" +) + +func TestHealthController(t *testing.T) { + tests := []struct { + description string + path string + method string + expected string + }{ + { + description: "Ensure health endpoint returns 200 OK", + path: "/api/healthz", + method: "GET", + expected: func() string { + expectedHealthResponse := map[string]any{ + "status": 200, + "message": "Healthy", + } + bytes, err := json.Marshal(expectedHealthResponse) + assert.NoError(t, err) + return string(bytes) + }(), + }, + { + description: "Ensure health endpoint returns 200 OK for HEAD request", + path: "/api/healthz", + method: "HEAD", + expected: func() string { + expectedHealthResponse := map[string]any{ + "status": 200, + "message": "Healthy", + } + bytes, err := json.Marshal(expectedHealthResponse) + assert.NoError(t, err) + return string(bytes) + }(), + }, + } + + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { + router := gin.Default() + group := router.Group("/api") + gin.SetMode(gin.TestMode) + + healthController := controller.NewHealthController(group) + healthController.SetupRoutes() + + recorder := httptest.NewRecorder() + + request, err := http.NewRequest(test.method, test.path, nil) + assert.NoError(t, err) + + router.ServeHTTP(recorder, request) + + assert.Equal(t, http.StatusOK, recorder.Code) + assert.Equal(t, test.expected, recorder.Body.String()) + }) + } +} diff --git a/internal/controller/oidc_controller.go b/internal/controller/oidc_controller.go index 160ca2d..76b096d 100644 --- a/internal/controller/oidc_controller.go +++ b/internal/controller/oidc_controller.go @@ -235,7 +235,7 @@ func (controller *OIDCController) Token(c *gin.Context) { if !ok { tlog.App.Error().Msg("Missing authorization header") - c.Header("www-authenticate", "basic") + c.Header("www-authenticate", `Basic realm="Tinyauth OIDC Token Endpoint"`) c.JSON(400, gin.H{ "error": "invalid_client", }) diff --git a/internal/controller/oidc_controller_test.go b/internal/controller/oidc_controller_test.go index e6910a5..c3943f7 100644 --- a/internal/controller/oidc_controller_test.go +++ b/internal/controller/oidc_controller_test.go @@ -2,10 +2,9 @@ package controller_test import ( "encoding/json" - "fmt" - "net/http" "net/http/httptest" "net/url" + "path" "strings" "testing" @@ -16,266 +15,456 @@ import ( "github.com/steveiliop56/tinyauth/internal/controller" "github.com/steveiliop56/tinyauth/internal/repository" "github.com/steveiliop56/tinyauth/internal/service" - "github.com/steveiliop56/tinyauth/internal/utils/tlog" - "gotest.tools/v3/assert" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) -var oidcServiceConfig = service.OIDCServiceConfig{ - Clients: map[string]config.OIDCClientConfig{ - "client1": { - ClientID: "some-client-id", - ClientSecret: "some-client-secret", - ClientSecretFile: "", - TrustedRedirectURIs: []string{ - "https://example.com/oauth/callback", - }, - Name: "Client 1", - }, - }, - PrivateKeyPath: "/tmp/tinyauth_oidc_key", - PublicKeyPath: "/tmp/tinyauth_oidc_key.pub", - Issuer: "https://example.com", - SessionExpiry: 3600, -} - -var oidcCtrlTestContext = config.UserContext{ - Username: "test", - Name: "Test", - Email: "test@example.com", - IsLoggedIn: true, - IsBasicAuth: false, - OAuth: false, - Provider: "ldap", // ldap in order to test the groups - TotpPending: false, - OAuthGroups: "", - TotpEnabled: false, - OAuthName: "", - OAuthSub: "", - LdapGroups: "test1,test2", -} - -// Test is not amazing, but it will confirm the OIDC server works func TestOIDCController(t *testing.T) { - tlog.NewSimpleLogger().Init() + tempDir := t.TempDir() - // Create an app instance - app := bootstrap.NewBootstrapApp(config.Config{}) - - // Get db - db, err := app.SetupDatabase("/tmp/tinyauth.db") - assert.NilError(t, err) - - // Create queries - queries := repository.New(db) - - // Create a new OIDC Servicee - oidcService := service.NewOIDCService(oidcServiceConfig, queries) - err = oidcService.Init() - assert.NilError(t, err) - - // Create test router - gin.SetMode(gin.TestMode) - router := gin.Default() - - router.Use(func(c *gin.Context) { - c.Set("context", &oidcCtrlTestContext) - c.Next() - }) - - group := router.Group("/api") - - // Register oidc controller - oidcController := controller.NewOIDCController(controller.OIDCControllerConfig{}, oidcService, group) - oidcController.SetupRoutes() - - // Get redirect URL test - recorder := httptest.NewRecorder() - - marshalled, err := json.Marshal(service.AuthorizeRequest{ - Scope: "openid profile email groups", - ResponseType: "code", - ClientID: "some-client-id", - RedirectURI: "https://example.com/oauth/callback", - State: "some-state", - }) - - assert.NilError(t, err) - - req, err := http.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(marshalled))) - assert.NilError(t, err) - - router.ServeHTTP(recorder, req) - assert.Equal(t, http.StatusOK, recorder.Code) - - resJson := map[string]any{} - - err = json.Unmarshal(recorder.Body.Bytes(), &resJson) - assert.NilError(t, err) - - redirect_uri, ok := resJson["redirect_uri"].(string) - assert.Assert(t, ok) - - u, err := url.Parse(redirect_uri) - assert.NilError(t, err) - - m, err := url.ParseQuery(u.RawQuery) - assert.NilError(t, err) - assert.Equal(t, m["state"][0], "some-state") - - code := m["code"][0] - - // Exchange code for token - recorder = httptest.NewRecorder() - - params, err := query.Values(controller.TokenRequest{ - GrantType: "authorization_code", - Code: code, - RedirectURI: "https://example.com/oauth/callback", - }) - - assert.NilError(t, err) - - req, err = http.NewRequest("POST", "/api/oidc/token", strings.NewReader(params.Encode())) - - assert.NilError(t, err) - - req.Header.Set("content-type", "application/x-www-form-urlencoded") - req.SetBasicAuth("some-client-id", "some-client-secret") - - router.ServeHTTP(recorder, req) - assert.Equal(t, http.StatusOK, recorder.Code) - - resJson = map[string]any{} - - err = json.Unmarshal(recorder.Body.Bytes(), &resJson) - assert.NilError(t, err) - - accessToken, ok := resJson["access_token"].(string) - assert.Assert(t, ok) - - _, ok = resJson["id_token"].(string) - assert.Assert(t, ok) - - refreshToken, ok := resJson["refresh_token"].(string) - assert.Assert(t, ok) - - expires_in, ok := resJson["expires_in"].(float64) - assert.Assert(t, ok) - assert.Equal(t, expires_in, float64(oidcServiceConfig.SessionExpiry)) - - // Ensure code is expired - recorder = httptest.NewRecorder() - - params, err = query.Values(controller.TokenRequest{ - GrantType: "authorization_code", - Code: code, - RedirectURI: "https://example.com/oauth/callback", - }) - - assert.NilError(t, err) - - req, err = http.NewRequest("POST", "/api/oidc/token", strings.NewReader(params.Encode())) - - assert.NilError(t, err) - - req.Header.Set("content-type", "application/x-www-form-urlencoded") - req.SetBasicAuth("some-client-id", "some-client-secret") - - router.ServeHTTP(recorder, req) - assert.Equal(t, http.StatusBadRequest, recorder.Code) - - // Test userinfo - recorder = httptest.NewRecorder() - - req, err = http.NewRequest("GET", "/api/oidc/userinfo", nil) - assert.NilError(t, err) - - req.Header.Set("authorization", fmt.Sprintf("Bearer %s", accessToken)) - - router.ServeHTTP(recorder, req) - assert.Equal(t, http.StatusOK, recorder.Code) - - resJson = map[string]any{} - - err = json.Unmarshal(recorder.Body.Bytes(), &resJson) - assert.NilError(t, err) - - _, ok = resJson["sub"].(string) - assert.Assert(t, ok) - - name, ok := resJson["name"].(string) - assert.Assert(t, ok) - assert.Equal(t, name, oidcCtrlTestContext.Name) - - email, ok := resJson["email"].(string) - assert.Assert(t, ok) - assert.Equal(t, email, oidcCtrlTestContext.Email) - - preferred_username, ok := resJson["preferred_username"].(string) - assert.Assert(t, ok) - assert.Equal(t, preferred_username, oidcCtrlTestContext.Username) - - // Not sure why this is failing, will look into it later - igroups, ok := resJson["groups"].([]any) - assert.Assert(t, ok) - - groups := make([]string, len(igroups)) - for i, group := range igroups { - groups[i], ok = group.(string) - assert.Assert(t, ok) + oidcServiceCfg := service.OIDCServiceConfig{ + Clients: map[string]config.OIDCClientConfig{ + "test": { + ClientID: "some-client-id", + ClientSecret: "some-client-secret", + TrustedRedirectURIs: []string{"https://test.example.com/callback"}, + Name: "Test Client", + }, + }, + PrivateKeyPath: path.Join(tempDir, "key.pem"), + PublicKeyPath: path.Join(tempDir, "key.pub"), + Issuer: "https://tinyauth.example.com", + SessionExpiry: 500, } - assert.DeepEqual(t, strings.Split(oidcCtrlTestContext.LdapGroups, ","), groups) + controllerCfg := controller.OIDCControllerConfig{} - // Test refresh token - recorder = httptest.NewRecorder() + simpleCtx := func(c *gin.Context) { + c.Set("context", &config.UserContext{ + Username: "test", + Name: "Test User", + Email: "test@example.com", + IsLoggedIn: true, + Provider: "local", + }) + c.Next() + } - params, err = query.Values(controller.TokenRequest{ - GrantType: "refresh_token", - RefreshToken: refreshToken, + type testCase struct { + description string + middlewares []gin.HandlerFunc + run func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) + } + + var tests []testCase + + getTestByDescription := func(description string) (func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder), bool) { + for _, test := range tests { + if test.description == description { + return test.run, true + } + } + return nil, false + } + + tests = []testCase{ + { + description: "Ensure we can fetch the client", + middlewares: []gin.HandlerFunc{}, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + req := httptest.NewRequest("GET", "/api/oidc/clients/some-client-id", nil) + router.ServeHTTP(recorder, req) + assert.Equal(t, 200, recorder.Code) + }, + }, + { + description: "Ensure API fails on non-existent client ID", + middlewares: []gin.HandlerFunc{}, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + req := httptest.NewRequest("GET", "/api/oidc/clients/non-existent-client-id", nil) + router.ServeHTTP(recorder, req) + assert.Equal(t, 404, recorder.Code) + }, + }, + { + description: "Ensure authorize fails with empty context", + middlewares: []gin.HandlerFunc{}, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + req := httptest.NewRequest("POST", "/api/oidc/authorize", nil) + router.ServeHTTP(recorder, req) + + var res map[string]any + err := json.Unmarshal(recorder.Body.Bytes(), &res) + assert.NoError(t, err) + + assert.Equal(t, res["redirect_uri"], "https://tinyauth.example.com/error?error=User+is+not+logged+in+or+the+session+is+invalid") + }, + }, + { + description: "Ensure authorize fails with an invalid param", + middlewares: []gin.HandlerFunc{ + simpleCtx, + }, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + reqBody := service.AuthorizeRequest{ + Scope: "openid", + ResponseType: "some_unsupported_response_type", + ClientID: "some-client-id", + RedirectURI: "https://test.example.com/callback", + State: "some-state", + Nonce: "some-nonce", + } + reqBodyBytes, err := json.Marshal(reqBody) + assert.NoError(t, err) + + req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes))) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(recorder, req) + + var res map[string]any + err = json.Unmarshal(recorder.Body.Bytes(), &res) + assert.NoError(t, err) + + assert.Equal(t, res["redirect_uri"], "https://test.example.com/callback?error=unsupported_response_type&error_description=Invalid+request+parameters&state=some-state") + }, + }, + { + description: "Ensure authorize succeeds with valid params", + middlewares: []gin.HandlerFunc{ + simpleCtx, + }, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + reqBody := service.AuthorizeRequest{ + Scope: "openid", + ResponseType: "code", + ClientID: "some-client-id", + RedirectURI: "https://test.example.com/callback", + State: "some-state", + Nonce: "some-nonce", + } + reqBodyBytes, err := json.Marshal(reqBody) + assert.NoError(t, err) + + req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes))) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(recorder, req) + assert.Equal(t, 200, recorder.Code) + + var res map[string]any + err = json.Unmarshal(recorder.Body.Bytes(), &res) + assert.NoError(t, err) + + redirectURI := res["redirect_uri"].(string) + url, err := url.Parse(redirectURI) + assert.NoError(t, err) + + queryParams := url.Query() + assert.Equal(t, queryParams.Get("state"), "some-state") + + code := queryParams.Get("code") + assert.NotEmpty(t, code) + }, + }, + { + description: "Ensure token request fails with invalid grant", + middlewares: []gin.HandlerFunc{}, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + reqBody := controller.TokenRequest{ + GrantType: "invalid_grant", + Code: "", + RedirectURI: "https://test.example.com/callback", + } + reqBodyEncoded, err := query.Values(reqBody) + assert.NoError(t, err) + + req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + router.ServeHTTP(recorder, req) + + var res map[string]any + err = json.Unmarshal(recorder.Body.Bytes(), &res) + assert.NoError(t, err) + + assert.Equal(t, res["error"], "unsupported_grant_type") + }, + }, + { + description: "Ensure token endpoint accepts basic auth", + middlewares: []gin.HandlerFunc{}, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + reqBody := controller.TokenRequest{ + GrantType: "authorization_code", + Code: "some-code", + RedirectURI: "https://test.example.com/callback", + } + reqBodyEncoded, err := query.Values(reqBody) + assert.NoError(t, err) + + req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.SetBasicAuth("some-client-id", "some-client-secret") + router.ServeHTTP(recorder, req) + + assert.Empty(t, recorder.Header().Get("www-authenticate")) + }, + }, + { + description: "Ensure token endpoint accepts form auth", + middlewares: []gin.HandlerFunc{}, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + form := url.Values{} + form.Set("grant_type", "authorization_code") + form.Set("code", "some-code") + form.Set("redirect_uri", "https://test.example.com/callback") + form.Set("client_id", "some-client-id") + form.Set("client_secret", "some-client-secret") + + req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + router.ServeHTTP(recorder, req) + + assert.Empty(t, recorder.Header().Get("www-authenticate")) + }, + }, + { + description: "Ensure token endpoint sets authenticate header when no auth is available", + middlewares: []gin.HandlerFunc{}, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + reqBody := controller.TokenRequest{ + GrantType: "authorization_code", + Code: "some-code", + RedirectURI: "https://test.example.com/callback", + } + reqBodyEncoded, err := query.Values(reqBody) + assert.NoError(t, err) + + req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + router.ServeHTTP(recorder, req) + + authHeader := recorder.Header().Get("www-authenticate") + assert.Contains(t, authHeader, "Basic") + }, + }, + { + description: "Ensure we can get a token with a valid request", + middlewares: []gin.HandlerFunc{ + simpleCtx, + }, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + authorizeCodeTest, found := getTestByDescription("Ensure authorize succeeds with valid params") + assert.True(t, found, "Authorize test not found") + authorizeTestRecorder := httptest.NewRecorder() + authorizeCodeTest(t, router, authorizeTestRecorder) + + var authorizeRes map[string]any + err := json.Unmarshal(authorizeTestRecorder.Body.Bytes(), &authorizeRes) + assert.NoError(t, err) + + redirectURI := authorizeRes["redirect_uri"].(string) + url, err := url.Parse(redirectURI) + assert.NoError(t, err) + + queryParams := url.Query() + code := queryParams.Get("code") + assert.NotEmpty(t, code) + + reqBody := controller.TokenRequest{ + GrantType: "authorization_code", + Code: code, + RedirectURI: "https://test.example.com/callback", + } + reqBodyEncoded, err := query.Values(reqBody) + assert.NoError(t, err) + + req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.SetBasicAuth("some-client-id", "some-client-secret") + router.ServeHTTP(recorder, req) + + assert.Equal(t, 200, recorder.Code) + }, + }, + { + description: "Ensure we can renew the access token with the refresh token", + middlewares: []gin.HandlerFunc{ + simpleCtx, + }, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + tokenTest, found := getTestByDescription("Ensure we can get a token with a valid request") + assert.True(t, found, "Token test not found") + tokenRecorder := httptest.NewRecorder() + tokenTest(t, router, tokenRecorder) + + var tokenRes map[string]any + err := json.Unmarshal(tokenRecorder.Body.Bytes(), &tokenRes) + assert.NoError(t, err) + + _, ok := tokenRes["refresh_token"] + assert.True(t, ok, "Expected refresh token in response") + refreshToken := tokenRes["refresh_token"].(string) + assert.NotEmpty(t, refreshToken) + + reqBody := controller.TokenRequest{ + GrantType: "refresh_token", + RefreshToken: refreshToken, + ClientID: "some-client-id", + ClientSecret: "some-client-secret", + } + reqBodyEncoded, err := query.Values(reqBody) + assert.NoError(t, err) + + req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + router.ServeHTTP(recorder, req) + + assert.NotEmpty(t, recorder.Header().Get("cache-control")) + assert.NotEmpty(t, recorder.Header().Get("pragma")) + + assert.Equal(t, 200, recorder.Code) + var refreshRes map[string]any + err = json.Unmarshal(recorder.Body.Bytes(), &refreshRes) + assert.NoError(t, err) + + _, ok = refreshRes["access_token"] + assert.True(t, ok, "Expected access token in refresh response") + assert.NotEqual(t, tokenRes["refresh_token"].(string), refreshRes["access_token"].(string)) + assert.NotEqual(t, tokenRes["access_token"].(string), refreshRes["access_token"].(string)) + }, + }, + { + description: "Ensure token endpoint deletes code after use", + middlewares: []gin.HandlerFunc{ + simpleCtx, + }, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + authorizeCodeTest, found := getTestByDescription("Ensure authorize succeeds with valid params") + assert.True(t, found, "Authorize test not found") + authorizeTestRecorder := httptest.NewRecorder() + authorizeCodeTest(t, router, authorizeTestRecorder) + + var authorizeRes map[string]any + err := json.Unmarshal(authorizeTestRecorder.Body.Bytes(), &authorizeRes) + assert.NoError(t, err) + + redirectURI := authorizeRes["redirect_uri"].(string) + url, err := url.Parse(redirectURI) + assert.NoError(t, err) + + queryParams := url.Query() + code := queryParams.Get("code") + assert.NotEmpty(t, code) + + reqBody := controller.TokenRequest{ + GrantType: "authorization_code", + Code: code, + RedirectURI: "https://test.example.com/callback", + } + reqBodyEncoded, err := query.Values(reqBody) + assert.NoError(t, err) + + req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.SetBasicAuth("some-client-id", "some-client-secret") + router.ServeHTTP(recorder, req) + + assert.Equal(t, 200, recorder.Code) + + // Try to use the same code again + secondReq := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode())) + secondReq.Header.Set("Content-Type", "application/x-www-form-urlencoded") + secondReq.SetBasicAuth("some-client-id", "some-client-secret") + secondRecorder := httptest.NewRecorder() + router.ServeHTTP(secondRecorder, secondReq) + + assert.Equal(t, 400, secondRecorder.Code) + + var secondRes map[string]any + err = json.Unmarshal(secondRecorder.Body.Bytes(), &secondRes) + assert.NoError(t, err) + + assert.Equal(t, secondRes["error"], "invalid_grant") + }, + }, + { + description: "Ensure userinfo forbids access with invalid access token", + middlewares: []gin.HandlerFunc{}, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + req := httptest.NewRequest("GET", "/api/oidc/userinfo", nil) + req.Header.Set("Authorization", "Bearer invalid-access-token") + router.ServeHTTP(recorder, req) + assert.Equal(t, 401, recorder.Code) + }, + }, + { + description: "Ensure access token can be used to access protected resources", + middlewares: []gin.HandlerFunc{ + simpleCtx, + }, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + tokenTest, found := getTestByDescription("Ensure we can get a token with a valid request") + assert.True(t, found, "Token test not found") + tokenRecorder := httptest.NewRecorder() + tokenTest(t, router, tokenRecorder) + + var tokenRes map[string]any + err := json.Unmarshal(tokenRecorder.Body.Bytes(), &tokenRes) + assert.NoError(t, err) + + accessToken := tokenRes["access_token"].(string) + assert.NotEmpty(t, accessToken) + + protectedReq := httptest.NewRequest("GET", "/api/oidc/userinfo", nil) + protectedReq.Header.Set("Authorization", "Bearer "+accessToken) + router.ServeHTTP(recorder, protectedReq) + assert.Equal(t, 200, recorder.Code) + + var userInfoRes map[string]any + err = json.Unmarshal(recorder.Body.Bytes(), &userInfoRes) + assert.NoError(t, err) + + _, ok := userInfoRes["sub"] + assert.True(t, ok, "Expected sub claim in userinfo response") + + // We should not have an email claim since we didn't request it in the scope + _, ok = userInfoRes["email"] + assert.False(t, ok, "Did not expect email claim in userinfo response") + }, + }, + } + + app := bootstrap.NewBootstrapApp(config.Config{}) + + db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db")) + require.NoError(t, err) + + queries := repository.New(db) + oidcService := service.NewOIDCService(oidcServiceCfg, queries) + err = oidcService.Init() + require.NoError(t, err) + + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { + router := gin.Default() + + for _, middleware := range test.middlewares { + router.Use(middleware) + } + + group := router.Group("/api") + gin.SetMode(gin.TestMode) + + oidcController := controller.NewOIDCController(controllerCfg, oidcService, group) + oidcController.SetupRoutes() + + recorder := httptest.NewRecorder() + + test.run(t, router, recorder) + }) + } + + t.Cleanup(func() { + err = db.Close() + require.NoError(t, err) }) - - assert.NilError(t, err) - - req, err = http.NewRequest("POST", "/api/oidc/token", strings.NewReader(params.Encode())) - - assert.NilError(t, err) - - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - req.SetBasicAuth("some-client-id", "some-client-secret") - - router.ServeHTTP(recorder, req) - assert.Equal(t, http.StatusOK, recorder.Code) - - resJson = map[string]any{} - - err = json.Unmarshal(recorder.Body.Bytes(), &resJson) - - assert.NilError(t, err) - - newToken, ok := resJson["access_token"].(string) - assert.Assert(t, ok) - assert.Assert(t, newToken != accessToken) - - // Ensure old token is invalid - recorder = httptest.NewRecorder() - req, err = http.NewRequest("GET", "/api/oidc/userinfo", nil) - - assert.NilError(t, err) - - req.Header.Set("authorization", fmt.Sprintf("Bearer %s", accessToken)) - - router.ServeHTTP(recorder, req) - assert.Equal(t, http.StatusUnauthorized, recorder.Code) - - // Test new token - recorder = httptest.NewRecorder() - req, err = http.NewRequest("GET", "/api/oidc/userinfo", nil) - - assert.NilError(t, err) - - req.Header.Set("authorization", fmt.Sprintf("Bearer %s", newToken)) - - router.ServeHTTP(recorder, req) - assert.Equal(t, http.StatusOK, recorder.Code) } diff --git a/internal/controller/proxy_controller_test.go b/internal/controller/proxy_controller_test.go index f7e73ec..5d3169c 100644 --- a/internal/controller/proxy_controller_test.go +++ b/internal/controller/proxy_controller_test.go @@ -1,302 +1,373 @@ package controller_test import ( - "net/http" "net/http/httptest" + "path" "testing" + "github.com/gin-gonic/gin" "github.com/steveiliop56/tinyauth/internal/bootstrap" "github.com/steveiliop56/tinyauth/internal/config" "github.com/steveiliop56/tinyauth/internal/controller" "github.com/steveiliop56/tinyauth/internal/repository" "github.com/steveiliop56/tinyauth/internal/service" - - "github.com/gin-gonic/gin" - "gotest.tools/v3/assert" + "github.com/steveiliop56/tinyauth/internal/utils/tlog" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) -var loggedInCtx = config.UserContext{ - Username: "test", - Name: "Test", - Email: "test@example.com", - IsLoggedIn: true, - Provider: "local", -} +func TestProxyController(t *testing.T) { + tempDir := t.TempDir() -func setupProxyController(t *testing.T, middlewares []gin.HandlerFunc) (*gin.Engine, *httptest.ResponseRecorder) { - // Setup - gin.SetMode(gin.TestMode) - router := gin.Default() - - if len(middlewares) > 0 { - for _, m := range middlewares { - router.Use(m) - } - } - - group := router.Group("/api") - recorder := httptest.NewRecorder() - - // Mock app - app := bootstrap.NewBootstrapApp(config.Config{}) - - // Database - db, err := app.SetupDatabase(":memory:") - - assert.NilError(t, err) - - // Queries - queries := repository.New(db) - - // Docker - dockerService := service.NewDockerService() - - assert.NilError(t, dockerService.Init()) - - // Access controls - accessControlsService := service.NewAccessControlsService(dockerService, map[string]config.App{ - "whoami": { - Path: config.AppPath{ - Allow: "/allow", - }, - }, - }) - - assert.NilError(t, accessControlsService.Init()) - - // Auth service - authService := service.NewAuthService(service.AuthServiceConfig{ + authServiceCfg := service.AuthServiceConfig{ Users: []config.User{ { Username: "testuser", - Password: "$2a$10$ne6z693sTgzT3ePoQ05PgOecUHnBjM7sSNj6M.l5CLUP.f6NyCnt.", // test + Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password }, { Username: "totpuser", - Password: "$2a$10$ne6z693sTgzT3ePoQ05PgOecUHnBjM7sSNj6M.l5CLUP.f6NyCnt.", - TotpSecret: "foo", + Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password + TotpSecret: "JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", }, }, - OauthWhitelist: []string{}, - SessionExpiry: 3600, - SessionMaxLifetime: 0, - SecureCookie: false, - CookieDomain: "localhost", - LoginTimeout: 300, - LoginMaxRetries: 3, - SessionCookieName: "tinyauth-session", - }, dockerService, nil, queries, &service.OAuthBrokerService{}) + SessionExpiry: 10, // 10 seconds, useful for testing + CookieDomain: "example.com", + LoginTimeout: 10, // 10 seconds, useful for testing + LoginMaxRetries: 3, + SessionCookieName: "tinyauth-session", + } - // Controller - ctrl := controller.NewProxyController(controller.ProxyControllerConfig{ - AppURL: "http://tinyauth.example.com", - }, group, accessControlsService, authService) - ctrl.SetupRoutes() + controllerCfg := controller.ProxyControllerConfig{ + AppURL: "https://tinyauth.example.com", + } - return router, recorder -} - -// TODO: Needs tests for context middleware - -func TestProxyHandler(t *testing.T) { - // Test logged out user traefik/caddy (forward_auth) - router, recorder := setupProxyController(t, nil) - - req, err := http.NewRequest("GET", "/api/auth/traefik", nil) - assert.NilError(t, err) - - req.Header.Set("x-forwarded-host", "whoami.example.com") - req.Header.Set("x-forwarded-proto", "http") - req.Header.Set("x-forwarded-uri", "/") - - router.ServeHTTP(recorder, req) - assert.Equal(t, recorder.Code, http.StatusUnauthorized) - - // Test logged out user nginx (auth_request) - router, recorder = setupProxyController(t, nil) - - req, err = http.NewRequest("GET", "/api/auth/nginx", nil) - assert.NilError(t, err) - - req.Header.Set("x-original-url", "http://whoami.example.com/") - - router.ServeHTTP(recorder, req) - assert.Equal(t, recorder.Code, http.StatusUnauthorized) - - // Test logged out user envoy (ext_authz) - router, recorder = setupProxyController(t, nil) - - req, err = http.NewRequest("GET", "/api/auth/envoy?path=/", nil) - assert.NilError(t, err) - - req.Host = "whoami.example.com" - req.Header.Set("x-forwarded-proto", "http") - - router.ServeHTTP(recorder, req) - assert.Equal(t, recorder.Code, http.StatusUnauthorized) - - // Test logged in user traefik/caddy (forward_auth) - router, recorder = setupProxyController(t, []gin.HandlerFunc{ - func(c *gin.Context) { - c.Set("context", &loggedInCtx) - c.Next() - }, - }) - - req, err = http.NewRequest("GET", "/api/auth/traefik", nil) - assert.NilError(t, err) - - req.Header.Set("x-forwarded-host", "whoami.example.com") - req.Header.Set("x-forwarded-proto", "http") - req.Header.Set("x-forwarded-uri", "/") - - router.ServeHTTP(recorder, req) - assert.Equal(t, recorder.Code, http.StatusOK) - - // Test logged in user nginx (auth_request) - router, recorder = setupProxyController(t, []gin.HandlerFunc{ - func(c *gin.Context) { - c.Set("context", &loggedInCtx) - c.Next() - }, - }) - - req, err = http.NewRequest("GET", "/api/auth/nginx", nil) - assert.NilError(t, err) - - req.Header.Set("x-original-url", "http://whoami.example.com/") - - router.ServeHTTP(recorder, req) - assert.Equal(t, recorder.Code, http.StatusOK) - - // Test logged in user envoy (ext_authz) - router, recorder = setupProxyController(t, []gin.HandlerFunc{ - func(c *gin.Context) { - c.Set("context", &loggedInCtx) - c.Next() - }, - }) - - req, err = http.NewRequest("GET", "/api/auth/envoy?path=/", nil) - assert.NilError(t, err) - - req.Host = "whoami.example.com" - req.Header.Set("x-forwarded-proto", "http") - - router.ServeHTTP(recorder, req) - assert.Equal(t, recorder.Code, http.StatusOK) - - // Test ACL allow caddy/traefik (forward_auth) - router, recorder = setupProxyController(t, nil) - - req, err = http.NewRequest("GET", "/api/auth/traefik", nil) - assert.NilError(t, err) - - req.Header.Set("x-forwarded-host", "whoami.example.com") - req.Header.Set("x-forwarded-proto", "http") - req.Header.Set("x-forwarded-uri", "/allow") - - router.ServeHTTP(recorder, req) - assert.Equal(t, recorder.Code, http.StatusOK) - - // Test ACL allow nginx - router, recorder = setupProxyController(t, nil) - - req, err = http.NewRequest("GET", "/api/auth/nginx", nil) - assert.NilError(t, err) - - req.Header.Set("x-original-url", "http://whoami.example.com/allow") - - router.ServeHTTP(recorder, req) - assert.Equal(t, recorder.Code, http.StatusOK) - - // Test ACL allow envoy - router, recorder = setupProxyController(t, nil) - - req, err = http.NewRequest("GET", "/api/auth/envoy?path=/allow", nil) - assert.NilError(t, err) - - req.Host = "whoami.example.com" - req.Header.Set("x-forwarded-proto", "http") - - router.ServeHTTP(recorder, req) - assert.Equal(t, recorder.Code, http.StatusOK) - - // Test traefik/caddy (forward_auth) without required headers - router, recorder = setupProxyController(t, nil) - - req, err = http.NewRequest("GET", "/api/auth/traefik", nil) - assert.NilError(t, err) - - router.ServeHTTP(recorder, req) - assert.Equal(t, recorder.Code, http.StatusBadRequest) - - // Test nginx (forward_auth) without required headers - router, recorder = setupProxyController(t, nil) - - req, err = http.NewRequest("GET", "/api/auth/nginx", nil) - assert.NilError(t, err) - - router.ServeHTTP(recorder, req) - assert.Equal(t, recorder.Code, http.StatusBadRequest) - - // Test envoy (forward_auth) without required headers - router, recorder = setupProxyController(t, nil) - - req, err = http.NewRequest("GET", "/api/auth/envoy", nil) - assert.NilError(t, err) - - router.ServeHTTP(recorder, req) - assert.Equal(t, recorder.Code, http.StatusBadRequest) - - // Test nginx (auth_request) with forward_auth fallback with ACLs - router, recorder = setupProxyController(t, nil) - - req, err = http.NewRequest("GET", "/api/auth/nginx", nil) - assert.NilError(t, err) - - req.Header.Set("x-forwarded-host", "whoami.example.com") - req.Header.Set("x-forwarded-proto", "http") - req.Header.Set("x-forwarded-uri", "/allow") - - router.ServeHTTP(recorder, req) - assert.Equal(t, recorder.Code, http.StatusOK) - - // Test envoy (ext_authz) with forward_auth fallback with ACLs - router, recorder = setupProxyController(t, nil) - - req, err = http.NewRequest("GET", "/api/auth/envoy", nil) - assert.NilError(t, err) - - req.Header.Set("x-forwarded-host", "whoami.example.com") - req.Header.Set("x-forwarded-proto", "http") - req.Header.Set("x-forwarded-uri", "/allow") - - router.ServeHTTP(recorder, req) - assert.Equal(t, recorder.Code, http.StatusOK) - - // Test envoy (ext_authz) with empty path - router, recorder = setupProxyController(t, nil) - - req, err = http.NewRequest("GET", "/api/auth/envoy", nil) - assert.NilError(t, err) - - req.Host = "whoami.example.com" - req.Header.Set("x-forwarded-proto", "http") - - router.ServeHTTP(recorder, req) - assert.Equal(t, recorder.Code, http.StatusUnauthorized) - - // Ensure forward_auth fallback works with path (should ignore) - router, recorder = setupProxyController(t, nil) - - req, err = http.NewRequest("GET", "/api/auth/traefik?path=/allow", nil) - assert.NilError(t, err) - - req.Header.Set("x-forwarded-proto", "http") - req.Header.Set("x-forwarded-host", "whoami.example.com") - req.Header.Set("x-forwarded-uri", "/allow") - - router.ServeHTTP(recorder, req) - assert.Equal(t, recorder.Code, http.StatusOK) + acls := map[string]config.App{ + "app_path_allow": { + Config: config.AppConfig{ + Domain: "path-allow.example.com", + }, + Path: config.AppPath{ + Allow: "/allowed", + }, + }, + "app_user_allow": { + Config: config.AppConfig{ + Domain: "user-allow.example.com", + }, + Users: config.AppUsers{ + Allow: "testuser", + }, + }, + "ip_bypass": { + Config: config.AppConfig{ + Domain: "ip-bypass.example.com", + }, + IP: config.AppIP{ + Bypass: []string{"10.10.10.10"}, + }, + }, + } + + const browserUserAgent = ` + Mozilla/5.0 (Linux; Android 8.0.0; SM-G955U Build/R16NW) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/146.0.0.0 Mobile Safari/537.36` + + simpleCtx := func(c *gin.Context) { + c.Set("context", &config.UserContext{ + Username: "testuser", + Name: "Testuser", + Email: "testuser@example.com", + IsLoggedIn: true, + Provider: "local", + }) + c.Next() + } + + simpleCtxTotp := func(c *gin.Context) { + c.Set("context", &config.UserContext{ + Username: "totpuser", + Name: "Totpuser", + Email: "totpuser@example.com", + IsLoggedIn: true, + Provider: "local", + TotpEnabled: true, + }) + c.Next() + } + + type testCase struct { + description string + middlewares []gin.HandlerFunc + run func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) + } + + tests := []testCase{ + { + description: "Default forward auth should be detected and used", + middlewares: []gin.HandlerFunc{}, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + req := httptest.NewRequest("GET", "/api/auth/traefik", nil) + req.Header.Set("x-forwarded-host", "test.example.com") + req.Header.Set("x-forwarded-proto", "https") + req.Header.Set("x-forwarded-uri", "/") + req.Header.Set("user-agent", browserUserAgent) + router.ServeHTTP(recorder, req) + + assert.Equal(t, 307, recorder.Code) + location := recorder.Header().Get("Location") + assert.Contains(t, location, "https://tinyauth.example.com/login?redirect_uri=") + assert.Contains(t, location, "https%3A%2F%2Ftest.example.com%2F") + }, + }, + { + description: "Auth request (nginx) should be detected and used", + middlewares: []gin.HandlerFunc{}, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + req := httptest.NewRequest("GET", "/api/auth/nginx", nil) + req.Header.Set("x-original-url", "https://test.example.com/") + router.ServeHTTP(recorder, req) + assert.Equal(t, 401, recorder.Code) + }, + }, + { + description: "Ext authz (envoy) should be detected and used", + middlewares: []gin.HandlerFunc{}, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + req := httptest.NewRequest("HEAD", "/api/auth/envoy?path=/hello", nil) // test a different method for envoy + req.Host = "test.example.com" + req.Header.Set("x-forwarded-proto", "https") + router.ServeHTTP(recorder, req) + assert.Equal(t, 401, recorder.Code) + }, + }, + { + description: "Ensure forward auth fallback for nginx", + middlewares: []gin.HandlerFunc{}, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + req := httptest.NewRequest("GET", "/api/auth/nginx", nil) + req.Header.Set("x-forwarded-host", "test.example.com") + req.Header.Set("x-forwarded-proto", "https") + req.Header.Set("x-forwarded-uri", "/") + router.ServeHTTP(recorder, req) + assert.Equal(t, 401, recorder.Code) + }, + }, + { + description: "Ensure forward auth fallback for envoy", + middlewares: []gin.HandlerFunc{}, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + req := httptest.NewRequest("HEAD", "/api/auth/envoy?path=/hello", nil) + req.Header.Set("x-forwarded-host", "test.example.com") + req.Header.Set("x-forwarded-proto", "https") + req.Header.Set("x-forwarded-uri", "/hello") + router.ServeHTTP(recorder, req) + assert.Equal(t, 401, recorder.Code) + }, + }, + { + description: "Ensure normal authentication flow for forward auth", + middlewares: []gin.HandlerFunc{ + simpleCtx, + }, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + req := httptest.NewRequest("GET", "/api/auth/traefik", nil) + req.Header.Set("x-forwarded-host", "test.example.com") + req.Header.Set("x-forwarded-proto", "https") + req.Header.Set("x-forwarded-uri", "/") + router.ServeHTTP(recorder, req) + + assert.Equal(t, 200, recorder.Code) + assert.Equal(t, "testuser", recorder.Header().Get("remote-user")) + assert.Equal(t, "Testuser", recorder.Header().Get("remote-name")) + assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email")) + }, + }, + { + description: "Ensure normal authentication flow for nginx auth request", + middlewares: []gin.HandlerFunc{ + simpleCtx, + }, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + req := httptest.NewRequest("GET", "/api/auth/nginx", nil) + req.Header.Set("x-original-url", "https://test.example.com/") + router.ServeHTTP(recorder, req) + + assert.Equal(t, 200, recorder.Code) + assert.Equal(t, "testuser", recorder.Header().Get("remote-user")) + assert.Equal(t, "Testuser", recorder.Header().Get("remote-name")) + assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email")) + }, + }, + { + description: "Ensure normal authentication flow for envoy ext authz", + middlewares: []gin.HandlerFunc{ + simpleCtx, + }, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + req := httptest.NewRequest("HEAD", "/api/auth/envoy?path=/hello", nil) + req.Host = "test.example.com" + req.Header.Set("x-forwarded-proto", "https") + router.ServeHTTP(recorder, req) + + assert.Equal(t, 200, recorder.Code) + assert.Equal(t, "testuser", recorder.Header().Get("remote-user")) + assert.Equal(t, "Testuser", recorder.Header().Get("remote-name")) + assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email")) + }, + }, + { + description: "Ensure path allow ACL works on forward auth", + middlewares: []gin.HandlerFunc{}, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + req := httptest.NewRequest("GET", "/api/auth/traefik", nil) + req.Header.Set("x-forwarded-host", "path-allow.example.com") + req.Header.Set("x-forwarded-proto", "https") + req.Header.Set("x-forwarded-uri", "/allowed") + router.ServeHTTP(recorder, req) + assert.Equal(t, 200, recorder.Code) + }, + }, + { + description: "Ensure path allow ACL works on nginx auth request", + middlewares: []gin.HandlerFunc{}, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + req := httptest.NewRequest("GET", "/api/auth/nginx", nil) + req.Header.Set("x-original-url", "https://path-allow.example.com/allowed") + router.ServeHTTP(recorder, req) + assert.Equal(t, 200, recorder.Code) + }, + }, + { + description: "Ensure path allow ACL works on envoy ext authz", + middlewares: []gin.HandlerFunc{}, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + req := httptest.NewRequest("HEAD", "/api/auth/envoy?path=/allowed", nil) + req.Host = "path-allow.example.com" + req.Header.Set("x-forwarded-proto", "https") + router.ServeHTTP(recorder, req) + assert.Equal(t, 200, recorder.Code) + }, + }, + { + description: "Ensure ip bypass ACL works on forward auth", + middlewares: []gin.HandlerFunc{}, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + req := httptest.NewRequest("GET", "/api/auth/traefik", nil) + req.Header.Set("x-forwarded-host", "ip-bypass.example.com") + req.Header.Set("x-forwarded-proto", "https") + req.Header.Set("x-forwarded-uri", "/") + req.Header.Set("x-forwarded-for", "10.10.10.10") + router.ServeHTTP(recorder, req) + assert.Equal(t, 200, recorder.Code) + }, + }, + { + description: "Ensure ip bypass ACL works on nginx auth request", + middlewares: []gin.HandlerFunc{}, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + req := httptest.NewRequest("GET", "/api/auth/nginx", nil) + req.Header.Set("x-original-url", "https://ip-bypass.example.com/") + req.Header.Set("x-forwarded-for", "10.10.10.10") + router.ServeHTTP(recorder, req) + assert.Equal(t, 200, recorder.Code) + }, + }, + { + description: "Ensure ip bypass ACL works on envoy ext authz", + middlewares: []gin.HandlerFunc{}, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + req := httptest.NewRequest("HEAD", "/api/auth/envoy?path=/hello", nil) + req.Host = "ip-bypass.example.com" + req.Header.Set("x-forwarded-proto", "https") + req.Header.Set("x-forwarded-for", "10.10.10.10") + router.ServeHTTP(recorder, req) + assert.Equal(t, 200, recorder.Code) + }, + }, + { + description: "Ensure user allow ACL allows correct user (should allow testuser)", + middlewares: []gin.HandlerFunc{ + simpleCtx, + }, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + req := httptest.NewRequest("GET", "/api/auth/traefik", nil) + req.Header.Set("x-forwarded-host", "user-allow.example.com") + req.Header.Set("x-forwarded-proto", "https") + req.Header.Set("x-forwarded-uri", "/") + router.ServeHTTP(recorder, req) + assert.Equal(t, 200, recorder.Code) + }, + }, + { + description: "Ensure user allow ACL blocks incorrect user (should block totpuser)", + middlewares: []gin.HandlerFunc{ + simpleCtxTotp, + }, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + req := httptest.NewRequest("GET", "/api/auth/traefik", nil) + req.Header.Set("x-forwarded-host", "user-allow.example.com") + req.Header.Set("x-forwarded-proto", "https") + req.Header.Set("x-forwarded-uri", "/") + router.ServeHTTP(recorder, req) + assert.Equal(t, 403, recorder.Code) + assert.Equal(t, "", recorder.Header().Get("remote-user")) + assert.Equal(t, "", recorder.Header().Get("remote-name")) + assert.Equal(t, "", recorder.Header().Get("remote-email")) + }, + }, + } + + tlog.NewSimpleLogger().Init() + + oauthBrokerCfgs := make(map[string]config.OAuthServiceConfig) + + app := bootstrap.NewBootstrapApp(config.Config{}) + + db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db")) + require.NoError(t, err) + + queries := repository.New(db) + + docker := service.NewDockerService() + err = docker.Init() + require.NoError(t, err) + + ldap := service.NewLdapService(service.LdapServiceConfig{}) + err = ldap.Init() + require.NoError(t, err) + + broker := service.NewOAuthBrokerService(oauthBrokerCfgs) + err = broker.Init() + require.NoError(t, err) + + authService := service.NewAuthService(authServiceCfg, docker, ldap, queries, broker) + err = authService.Init() + require.NoError(t, err) + + aclsService := service.NewAccessControlsService(docker, acls) + + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { + router := gin.Default() + + for _, m := range test.middlewares { + router.Use(m) + } + + group := router.Group("/api") + gin.SetMode(gin.TestMode) + + recorder := httptest.NewRecorder() + + proxyController := controller.NewProxyController(controllerCfg, group, aclsService, authService) + proxyController.SetupRoutes() + + test.run(t, router, recorder) + }) + } + + t.Cleanup(func() { + err = db.Close() + require.NoError(t, err) + }) } diff --git a/internal/controller/resources_controller_test.go b/internal/controller/resources_controller_test.go index 5f38528..2376aa9 100644 --- a/internal/controller/resources_controller_test.go +++ b/internal/controller/resources_controller_test.go @@ -3,57 +3,81 @@ package controller_test import ( "net/http/httptest" "os" + "path" "testing" - "github.com/steveiliop56/tinyauth/internal/controller" - "github.com/gin-gonic/gin" - "gotest.tools/v3/assert" + "github.com/steveiliop56/tinyauth/internal/controller" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) -func TestResourcesHandler(t *testing.T) { - // Setup - gin.SetMode(gin.TestMode) - router := gin.New() - group := router.Group("/") +func TestResourcesController(t *testing.T) { + tempDir := t.TempDir() - ctrl := controller.NewResourcesController(controller.ResourcesControllerConfig{ - Path: "/tmp/tinyauth", + resourcesControllerCfg := controller.ResourcesControllerConfig{ + Path: path.Join(tempDir, "resources"), Enabled: true, - }, group) - ctrl.SetupRoutes() + } - // Create test data - err := os.Mkdir("/tmp/tinyauth", 0755) - assert.NilError(t, err) - defer os.RemoveAll("/tmp/tinyauth") + err := os.Mkdir(resourcesControllerCfg.Path, 0777) + require.NoError(t, err) - file, err := os.Create("/tmp/tinyauth/test.txt") - assert.NilError(t, err) + type testCase struct { + description string + run func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) + } - _, err = file.WriteString("This is a test file.") - assert.NilError(t, err) - file.Close() + tests := []testCase{ + { + description: "Ensure resources endpoint returns 200 OK for existing file", + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + req := httptest.NewRequest("GET", "/resources/testfile.txt", nil) + router.ServeHTTP(recorder, req) - // Test existing file - req := httptest.NewRequest("GET", "/resources/test.txt", nil) - recorder := httptest.NewRecorder() - router.ServeHTTP(recorder, req) + assert.Equal(t, 200, recorder.Code) + assert.Equal(t, "This is a test file.", recorder.Body.String()) + }, + }, + { + description: "Ensure resources endpoint returns 404 Not Found for non-existing file", + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + req := httptest.NewRequest("GET", "/resources/nonexistent.txt", nil) + router.ServeHTTP(recorder, req) - assert.Equal(t, 200, recorder.Code) - assert.Equal(t, "This is a test file.", recorder.Body.String()) + assert.Equal(t, 404, recorder.Code) + }, + }, + { + description: "Ensure resources controller denies path traversal", + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + req := httptest.NewRequest("GET", "/resources/../somefile.txt", nil) + router.ServeHTTP(recorder, req) - // Test non-existing file - req = httptest.NewRequest("GET", "/resources/nonexistent.txt", nil) - recorder = httptest.NewRecorder() - router.ServeHTTP(recorder, req) + assert.Equal(t, 404, recorder.Code) + }, + }, + } - assert.Equal(t, 404, recorder.Code) + testFilePath := resourcesControllerCfg.Path + "/testfile.txt" + err = os.WriteFile(testFilePath, []byte("This is a test file."), 0777) + require.NoError(t, err) - // Test directory traversal attack - req = httptest.NewRequest("GET", "/resources/../etc/passwd", nil) - recorder = httptest.NewRecorder() - router.ServeHTTP(recorder, req) + testFilePathParent := tempDir + "/somefile.txt" + err = os.WriteFile(testFilePathParent, []byte("This file should not be accessible."), 0777) + require.NoError(t, err) - assert.Equal(t, 404, recorder.Code) + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { + router := gin.Default() + group := router.Group("/") + gin.SetMode(gin.TestMode) + + resourcesController := controller.NewResourcesController(resourcesControllerCfg, group) + resourcesController.SetupRoutes() + + recorder := httptest.NewRecorder() + test.run(t, router, recorder) + }) + } } diff --git a/internal/controller/user_controller_test.go b/internal/controller/user_controller_test.go index 672740c..b4cff86 100644 --- a/internal/controller/user_controller_test.go +++ b/internal/controller/user_controller_test.go @@ -2,305 +2,355 @@ package controller_test import ( "encoding/json" - "net/http" "net/http/httptest" + "path" + "slices" "strings" "testing" "time" + "github.com/gin-gonic/gin" + "github.com/pquerna/otp/totp" "github.com/steveiliop56/tinyauth/internal/bootstrap" "github.com/steveiliop56/tinyauth/internal/config" "github.com/steveiliop56/tinyauth/internal/controller" "github.com/steveiliop56/tinyauth/internal/repository" "github.com/steveiliop56/tinyauth/internal/service" "github.com/steveiliop56/tinyauth/internal/utils/tlog" - - "github.com/gin-gonic/gin" - "github.com/pquerna/otp/totp" - "gotest.tools/v3/assert" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) -var cookieValue string -var totpSecret = "6WFZXPEZRK5MZHHYAFW4DAOUYQMCASBJ" +func TestUserController(t *testing.T) { + tempDir := t.TempDir() -func setupUserController(t *testing.T, middlewares *[]gin.HandlerFunc) (*gin.Engine, *httptest.ResponseRecorder) { - tlog.NewSimpleLogger().Init() - - // Setup - gin.SetMode(gin.TestMode) - router := gin.Default() - - if middlewares != nil { - for _, m := range *middlewares { - router.Use(m) - } - } - - group := router.Group("/api") - recorder := httptest.NewRecorder() - - // Mock app - app := bootstrap.NewBootstrapApp(config.Config{}) - - // Database - db, err := app.SetupDatabase(":memory:") - - assert.NilError(t, err) - - // Queries - queries := repository.New(db) - - // Auth service - authService := service.NewAuthService(service.AuthServiceConfig{ + authServiceCfg := service.AuthServiceConfig{ Users: []config.User{ { Username: "testuser", - Password: "$2a$10$ne6z693sTgzT3ePoQ05PgOecUHnBjM7sSNj6M.l5CLUP.f6NyCnt.", // test + Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password }, { Username: "totpuser", - Password: "$2a$10$ne6z693sTgzT3ePoQ05PgOecUHnBjM7sSNj6M.l5CLUP.f6NyCnt.", // test - TotpSecret: totpSecret, + Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password + TotpSecret: "JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", }, }, - OauthWhitelist: []string{}, - SessionExpiry: 3600, - SessionMaxLifetime: 0, - SecureCookie: false, - CookieDomain: "localhost", - LoginTimeout: 300, - LoginMaxRetries: 3, - SessionCookieName: "tinyauth-session", - }, nil, nil, queries, &service.OAuthBrokerService{}) - - // Controller - ctrl := controller.NewUserController(controller.UserControllerConfig{ - CookieDomain: "localhost", - }, group, authService) - ctrl.SetupRoutes() - - return router, recorder -} - -func TestLoginHandler(t *testing.T) { - // Setup - router, recorder := setupUserController(t, nil) - - loginReq := controller.LoginRequest{ - Username: "testuser", - Password: "test", + SessionExpiry: 10, // 10 seconds, useful for testing + CookieDomain: "example.com", + LoginTimeout: 10, // 10 seconds, useful for testing + LoginMaxRetries: 3, + SessionCookieName: "tinyauth-session", } - loginReqJson, err := json.Marshal(loginReq) - assert.NilError(t, err) - - // Test - req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqJson))) - router.ServeHTTP(recorder, req) - - assert.Equal(t, 200, recorder.Code) - - cookie := recorder.Result().Cookies()[0] - - assert.Equal(t, "tinyauth-session", cookie.Name) - assert.Assert(t, cookie.Value != "") - - cookieValue = cookie.Value - - // Test invalid credentials - loginReq = controller.LoginRequest{ - Username: "testuser", - Password: "invalid", + userControllerCfg := controller.UserControllerConfig{ + CookieDomain: "example.com", } - loginReqJson, err = json.Marshal(loginReq) - assert.NilError(t, err) - - recorder = httptest.NewRecorder() - req = httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqJson))) - router.ServeHTTP(recorder, req) - - assert.Equal(t, 401, recorder.Code) - - // Test totp required - loginReq = controller.LoginRequest{ - Username: "totpuser", - Password: "test", + type testCase struct { + description string + middlewares []gin.HandlerFunc + run func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) } - loginReqJson, err = json.Marshal(loginReq) - assert.NilError(t, err) + tests := []testCase{ + { + description: "Should be able to login with valid credentials", + middlewares: []gin.HandlerFunc{}, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + loginReq := controller.LoginRequest{ + Username: "testuser", + Password: "password", + } + loginReqBody, err := json.Marshal(loginReq) + assert.NoError(t, err) - recorder = httptest.NewRecorder() - req = httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqJson))) - router.ServeHTTP(recorder, req) + req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqBody))) + req.Header.Set("Content-Type", "application/json") - assert.Equal(t, 200, recorder.Code) + router.ServeHTTP(recorder, req) - loginResJson, err := json.Marshal(map[string]any{ - "message": "TOTP required", - "status": 200, - "totpPending": true, - }) + assert.Equal(t, 200, recorder.Code) + assert.Len(t, recorder.Result().Cookies(), 1) - assert.NilError(t, err) - assert.Equal(t, string(loginResJson), recorder.Body.String()) - - // Test invalid json - recorder = httptest.NewRecorder() - req = httptest.NewRequest("POST", "/api/user/login", strings.NewReader("{invalid json}")) - router.ServeHTTP(recorder, req) - - assert.Equal(t, 400, recorder.Code) - - // Test rate limiting - loginReq = controller.LoginRequest{ - Username: "testuser", - Password: "invalid", - } - - loginReqJson, err = json.Marshal(loginReq) - assert.NilError(t, err) - - for range 5 { - recorder = httptest.NewRecorder() - req = httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqJson))) - router.ServeHTTP(recorder, req) - } - - assert.Equal(t, 429, recorder.Code) -} - -func TestLogoutHandler(t *testing.T) { - // Setup - router, recorder := setupUserController(t, nil) - - // Test - req := httptest.NewRequest("POST", "/api/user/logout", nil) - - req.AddCookie(&http.Cookie{ - Name: "tinyauth-session", - Value: cookieValue, - }) - - router.ServeHTTP(recorder, req) - - assert.Equal(t, 200, recorder.Code) - - cookie := recorder.Result().Cookies()[0] - - assert.Equal(t, "tinyauth-session", cookie.Name) - assert.Equal(t, "", cookie.Value) - assert.Equal(t, -1, cookie.MaxAge) -} - -func TestTotpHandler(t *testing.T) { - // Setup - router, recorder := setupUserController(t, &[]gin.HandlerFunc{ - func(c *gin.Context) { - c.Set("context", &config.UserContext{ - Username: "totpuser", - Name: "totpuser", - Email: "totpuser@example.com", - IsLoggedIn: false, - OAuth: false, - Provider: "local", - TotpPending: true, - OAuthGroups: "", - TotpEnabled: true, - }) - c.Next() + cookie := recorder.Result().Cookies()[0] + assert.Equal(t, "tinyauth-session", cookie.Name) + assert.True(t, cookie.HttpOnly) + assert.Equal(t, "example.com", cookie.Domain) + assert.Equal(t, 10, cookie.MaxAge) + }, }, - }) + { + description: "Should reject login with invalid credentials", + middlewares: []gin.HandlerFunc{}, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + loginReq := controller.LoginRequest{ + Username: "testuser", + Password: "wrongpassword", + } + loginReqBody, err := json.Marshal(loginReq) + assert.NoError(t, err) - // Test - code, err := totp.GenerateCode(totpSecret, time.Now()) + req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqBody))) + req.Header.Set("Content-Type", "application/json") - assert.NilError(t, err) + router.ServeHTTP(recorder, req) - totpReq := controller.TotpRequest{ - Code: code, - } - - totpReqJson, err := json.Marshal(totpReq) - assert.NilError(t, err) - - req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(totpReqJson))) - router.ServeHTTP(recorder, req) - - assert.Equal(t, 200, recorder.Code) - - cookie := recorder.Result().Cookies()[0] - - assert.Equal(t, "tinyauth-session", cookie.Name) - assert.Assert(t, cookie.Value != "") - - // Test invalid json - recorder = httptest.NewRecorder() - req = httptest.NewRequest("POST", "/api/user/totp", strings.NewReader("{invalid json}")) - router.ServeHTTP(recorder, req) - - assert.Equal(t, 400, recorder.Code) - - // Test rate limiting - totpReq = controller.TotpRequest{ - Code: "000000", - } - - totpReqJson, err = json.Marshal(totpReq) - assert.NilError(t, err) - - for range 5 { - recorder = httptest.NewRecorder() - req = httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(totpReqJson))) - router.ServeHTTP(recorder, req) - } - - assert.Equal(t, 429, recorder.Code) - - // Test invalid code - router, recorder = setupUserController(t, &[]gin.HandlerFunc{ - func(c *gin.Context) { - c.Set("context", &config.UserContext{ - Username: "totpuser", - Name: "totpuser", - Email: "totpuser@example.com", - IsLoggedIn: false, - OAuth: false, - Provider: "local", - TotpPending: true, - OAuthGroups: "", - TotpEnabled: true, - }) - c.Next() + assert.Equal(t, 401, recorder.Code) + assert.Len(t, recorder.Result().Cookies(), 0) + assert.Contains(t, recorder.Body.String(), "Unauthorized") + }, }, - }) + { + description: "Should rate limit on 3 invalid attempts", + middlewares: []gin.HandlerFunc{}, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + loginReq := controller.LoginRequest{ + Username: "testuser", + Password: "wrongpassword", + } + loginReqBody, err := json.Marshal(loginReq) + assert.NoError(t, err) - req = httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(totpReqJson))) - router.ServeHTTP(recorder, req) + for range 3 { + recorder := httptest.NewRecorder() - assert.Equal(t, 401, recorder.Code) + req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqBody))) + req.Header.Set("Content-Type", "application/json") - // Test no totp pending - router, recorder = setupUserController(t, &[]gin.HandlerFunc{ - func(c *gin.Context) { - c.Set("context", &config.UserContext{ - Username: "totpuser", - Name: "totpuser", - Email: "totpuser@example.com", - IsLoggedIn: false, - OAuth: false, - Provider: "local", - TotpPending: false, - OAuthGroups: "", - TotpEnabled: false, - }) - c.Next() + router.ServeHTTP(recorder, req) + + assert.Equal(t, 401, recorder.Code) + assert.Len(t, recorder.Result().Cookies(), 0) + assert.Contains(t, recorder.Body.String(), "Unauthorized") + } + + // 4th attempt should be rate limited + recorder = httptest.NewRecorder() + req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqBody))) + req.Header.Set("Content-Type", "application/json") + + router.ServeHTTP(recorder, req) + + assert.Equal(t, 429, recorder.Code) + assert.Contains(t, recorder.Body.String(), "Too many failed login attempts.") + }, }, + { + description: "Should not allow full login with totp", + middlewares: []gin.HandlerFunc{}, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + loginReq := controller.LoginRequest{ + Username: "totpuser", + Password: "password", + } + loginReqBody, err := json.Marshal(loginReq) + assert.NoError(t, err) + + req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqBody))) + req.Header.Set("Content-Type", "application/json") + + router.ServeHTTP(recorder, req) + + assert.Equal(t, 200, recorder.Code) + + decodedBody := make(map[string]any) + err = json.Unmarshal(recorder.Body.Bytes(), &decodedBody) + assert.NoError(t, err) + + assert.Equal(t, decodedBody["totpPending"], true) + + // should set the session cookie + assert.Len(t, recorder.Result().Cookies(), 1) + cookie := recorder.Result().Cookies()[0] + assert.Equal(t, "tinyauth-session", cookie.Name) + assert.True(t, cookie.HttpOnly) + assert.Equal(t, "example.com", cookie.Domain) + assert.Equal(t, 3600, cookie.MaxAge) // 1 hour, default for totp pending sessions + }, + }, + { + description: "Should be able to logout", + middlewares: []gin.HandlerFunc{}, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + // First login to get a session cookie + loginReq := controller.LoginRequest{ + Username: "testuser", + Password: "password", + } + loginReqBody, err := json.Marshal(loginReq) + assert.NoError(t, err) + + req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqBody))) + req.Header.Set("Content-Type", "application/json") + + router.ServeHTTP(recorder, req) + + assert.Equal(t, 200, recorder.Code) + assert.Len(t, recorder.Result().Cookies(), 1) + + cookie := recorder.Result().Cookies()[0] + assert.Equal(t, "tinyauth-session", cookie.Name) + + // Now logout using the session cookie + recorder = httptest.NewRecorder() + req = httptest.NewRequest("POST", "/api/user/logout", nil) + req.AddCookie(cookie) + + router.ServeHTTP(recorder, req) + + assert.Equal(t, 200, recorder.Code) + assert.Len(t, recorder.Result().Cookies(), 1) + + logoutCookie := recorder.Result().Cookies()[0] + assert.Equal(t, "tinyauth-session", logoutCookie.Name) + assert.Equal(t, "", logoutCookie.Value) + assert.Equal(t, -1, logoutCookie.MaxAge) // MaxAge -1 means delete cookie + }, + }, + { + description: "Should be able to login with totp", + middlewares: []gin.HandlerFunc{}, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + code, err := totp.GenerateCode("JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", time.Now()) + assert.NoError(t, err) + + totpReq := controller.TotpRequest{ + Code: code, + } + + totpReqBody, err := json.Marshal(totpReq) + assert.NoError(t, err) + + recorder = httptest.NewRecorder() + req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(totpReqBody))) + req.Header.Set("Content-Type", "application/json") + + router.ServeHTTP(recorder, req) + + assert.Equal(t, 200, recorder.Code) + assert.Len(t, recorder.Result().Cookies(), 1) + + // should set a new session cookie with totp pending removed + totpCookie := recorder.Result().Cookies()[0] + assert.Equal(t, "tinyauth-session", totpCookie.Name) + assert.True(t, totpCookie.HttpOnly) + assert.Equal(t, "example.com", totpCookie.Domain) + assert.Equal(t, 10, totpCookie.MaxAge) // should use the regular session expiry time + }, + }, + { + description: "Totp should rate limit on multiple invalid attempts", + middlewares: []gin.HandlerFunc{}, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + for range 3 { + totpReq := controller.TotpRequest{ + Code: "000000", // invalid code + } + + totpReqBody, err := json.Marshal(totpReq) + assert.NoError(t, err) + + recorder = httptest.NewRecorder() + req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(totpReqBody))) + req.Header.Set("Content-Type", "application/json") + + router.ServeHTTP(recorder, req) + + assert.Equal(t, 401, recorder.Code) + assert.Contains(t, recorder.Body.String(), "Unauthorized") + } + + // 4th attempt should be rate limited + recorder = httptest.NewRecorder() + req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(`{"code":"000000"}`))) + req.Header.Set("Content-Type", "application/json") + + router.ServeHTTP(recorder, req) + + assert.Equal(t, 429, recorder.Code) + assert.Contains(t, recorder.Body.String(), "Too many failed TOTP attempts.") + }, + }, + } + + tlog.NewSimpleLogger().Init() + + oauthBrokerCfgs := make(map[string]config.OAuthServiceConfig) + + app := bootstrap.NewBootstrapApp(config.Config{}) + + db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db")) + require.NoError(t, err) + + queries := repository.New(db) + + docker := service.NewDockerService() + err = docker.Init() + require.NoError(t, err) + + ldap := service.NewLdapService(service.LdapServiceConfig{}) + err = ldap.Init() + require.NoError(t, err) + + broker := service.NewOAuthBrokerService(oauthBrokerCfgs) + err = broker.Init() + require.NoError(t, err) + + authService := service.NewAuthService(authServiceCfg, docker, ldap, queries, broker) + err = authService.Init() + require.NoError(t, err) + + beforeEach := func() { + // Clear failed login attempts before each test + authService.ClearRateLimitsTestingOnly() + } + + setTotpMiddlewareOverrides := []string{ + "Should be able to login with totp", + "Totp should rate limit on multiple invalid attempts", + } + + for _, test := range tests { + beforeEach() + t.Run(test.description, func(t *testing.T) { + router := gin.Default() + + for _, middleware := range test.middlewares { + router.Use(middleware) + } + + // Gin is stupid and doesn't allow setting a middleware after the groups + // so we need to do some stupid overrides here + if slices.Contains(setTotpMiddlewareOverrides, test.description) { + // Assuming the cookie is set, it should be picked up by the + // context middleware + router.Use(func(c *gin.Context) { + c.Set("context", &config.UserContext{ + Username: "totpuser", + Name: "Totpuser", + Email: "totpuser@example.com", + Provider: "local", + TotpPending: true, + TotpEnabled: true, + }) + }) + } + + group := router.Group("/api") + gin.SetMode(gin.TestMode) + + userController := controller.NewUserController(userControllerCfg, group, authService) + userController.SetupRoutes() + + recorder := httptest.NewRecorder() + + test.run(t, router, recorder) + }) + } + + t.Cleanup(func() { + err = db.Close() + require.NoError(t, err) }) - - req = httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(totpReqJson))) - router.ServeHTTP(recorder, req) - - assert.Equal(t, 401, recorder.Code) } diff --git a/internal/controller/well_known_controller_test.go b/internal/controller/well_known_controller_test.go new file mode 100644 index 0000000..20dc2a1 --- /dev/null +++ b/internal/controller/well_known_controller_test.go @@ -0,0 +1,129 @@ +package controller_test + +import ( + "encoding/json" + "fmt" + "net/http/httptest" + "path" + "testing" + + "github.com/gin-gonic/gin" + "github.com/steveiliop56/tinyauth/internal/bootstrap" + "github.com/steveiliop56/tinyauth/internal/config" + "github.com/steveiliop56/tinyauth/internal/controller" + "github.com/steveiliop56/tinyauth/internal/repository" + "github.com/steveiliop56/tinyauth/internal/service" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestWellKnownController(t *testing.T) { + tempDir := t.TempDir() + + oidcServiceCfg := service.OIDCServiceConfig{ + Clients: map[string]config.OIDCClientConfig{ + "test": { + ClientID: "some-client-id", + ClientSecret: "some-client-secret", + TrustedRedirectURIs: []string{"https://test.example.com/callback"}, + Name: "Test Client", + }, + }, + PrivateKeyPath: path.Join(tempDir, "key.pem"), + PublicKeyPath: path.Join(tempDir, "key.pub"), + Issuer: "https://tinyauth.example.com", + SessionExpiry: 500, + } + + type testCase struct { + description string + run func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) + } + + tests := []testCase{ + { + description: "Ensure well-known endpoint returns correct OIDC configuration", + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + req := httptest.NewRequest("GET", "/.well-known/openid-configuration", nil) + router.ServeHTTP(recorder, req) + + assert.Equal(t, 200, recorder.Code) + + res := controller.OpenIDConnectConfiguration{} + err := json.Unmarshal(recorder.Body.Bytes(), &res) + assert.NoError(t, err) + + expected := controller.OpenIDConnectConfiguration{ + Issuer: oidcServiceCfg.Issuer, + AuthorizationEndpoint: fmt.Sprintf("%s/authorize", oidcServiceCfg.Issuer), + TokenEndpoint: fmt.Sprintf("%s/api/oidc/token", oidcServiceCfg.Issuer), + UserinfoEndpoint: fmt.Sprintf("%s/api/oidc/userinfo", oidcServiceCfg.Issuer), + JwksUri: fmt.Sprintf("%s/.well-known/jwks.json", oidcServiceCfg.Issuer), + ScopesSupported: service.SupportedScopes, + ResponseTypesSupported: service.SupportedResponseTypes, + GrantTypesSupported: service.SupportedGrantTypes, + SubjectTypesSupported: []string{"pairwise"}, + IDTokenSigningAlgValuesSupported: []string{"RS256"}, + TokenEndpointAuthMethodsSupported: []string{"client_secret_basic", "client_secret_post"}, + ClaimsSupported: []string{"sub", "updated_at", "name", "preferred_username", "email", "email_verified", "groups"}, + ServiceDocumentation: "https://tinyauth.app/docs/guides/oidc", + } + + assert.Equal(t, expected, res) + }, + }, + { + description: "Ensure well-known endpoint returns correct JWKS", + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + req := httptest.NewRequest("GET", "/.well-known/jwks.json", nil) + router.ServeHTTP(recorder, req) + + assert.Equal(t, 200, recorder.Code) + + decodedBody := make(map[string]any) + err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody) + assert.NoError(t, err) + + keys, ok := decodedBody["keys"].([]any) + assert.True(t, ok) + assert.Len(t, keys, 1) + + keyData, ok := keys[0].(map[string]any) + assert.True(t, ok) + assert.Equal(t, "RSA", keyData["kty"]) + assert.Equal(t, "sig", keyData["use"]) + assert.Equal(t, "RS256", keyData["alg"]) + }, + }, + } + + app := bootstrap.NewBootstrapApp(config.Config{}) + + db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db")) + require.NoError(t, err) + + queries := repository.New(db) + + oidcService := service.NewOIDCService(oidcServiceCfg, queries) + err = oidcService.Init() + require.NoError(t, err) + + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { + router := gin.Default() + gin.SetMode(gin.TestMode) + + recorder := httptest.NewRecorder() + + wellKnownController := controller.NewWellKnownController(controller.WellKnownControllerConfig{}, oidcService, router) + wellKnownController.SetupRoutes() + + test.run(t, router, recorder) + }) + } + + t.Cleanup(func() { + err = db.Close() + require.NoError(t, err) + }) +} diff --git a/internal/service/auth_service.go b/internal/service/auth_service.go index e81e7c5..6540fe8 100644 --- a/internal/service/auth_service.go +++ b/internal/service/auth_service.go @@ -1,6 +1,7 @@ package service import ( + "context" "database/sql" "errors" "fmt" @@ -78,6 +79,8 @@ type AuthService struct { queries *repository.Queries oauthBroker *OAuthBrokerService lockdown *Lockdown + lockdownCtx context.Context + lockdownCancelFunc context.CancelFunc } func NewAuthService(config AuthServiceConfig, docker *DockerService, ldap *LdapService, queries *repository.Queries, oauthBroker *OAuthBrokerService) *AuthService { @@ -770,6 +773,11 @@ func (auth *AuthService) ensureOAuthSessionLimit() { } func (auth *AuthService) lockdownMode() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + auth.lockdownCtx = ctx + auth.lockdownCancelFunc = cancel + auth.loginMutex.Lock() tlog.App.Warn().Msg("Multiple login attempts detected, possibly DDOS attack. Activating temporary lockdown.") @@ -788,7 +796,12 @@ func (auth *AuthService) lockdownMode() { auth.loginMutex.Unlock() - <-timer.C + select { + case <-timer.C: + // Timer expired, end lockdown + case <-ctx.Done(): + // Context cancelled, end lockdown + } auth.loginMutex.Lock() @@ -796,3 +809,13 @@ func (auth *AuthService) lockdownMode() { auth.lockdown = nil auth.loginMutex.Unlock() } + +// Function only used for testing - do not use in prod! +func (auth *AuthService) ClearRateLimitsTestingOnly() { + auth.loginMutex.Lock() + auth.loginAttempts = make(map[string]*LoginAttempt) + if auth.lockdown != nil { + auth.lockdownCancelFunc() + } + auth.loginMutex.Unlock() +} diff --git a/internal/utils/app_utils_test.go b/internal/utils/app_utils_test.go index 2af078e..5ca545d 100644 --- a/internal/utils/app_utils_test.go +++ b/internal/utils/app_utils_test.go @@ -25,12 +25,6 @@ func TestGetRootDomain(t *testing.T) { assert.NilError(t, err) assert.Equal(t, expected, result) - // Domain with no subdomain - domain = "http://tinyauth.app" - expected = "tinyauth.app" - _, err = utils.GetCookieDomain(domain) - assert.Error(t, err, "invalid app url, must be at least second level domain") - // Invalid domain (only TLD) domain = "com" _, err = utils.GetCookieDomain(domain)