From b2a1bfb1f532e87f205fa3afa3fc9f148c53ab89 Mon Sep 17 00:00:00 2001 From: Stavros Date: Wed, 11 Mar 2026 16:48:04 +0200 Subject: [PATCH] fix: validate client id on oidc token endpoint --- internal/controller/oidc_controller.go | 9 +++- internal/controller/proxy_controller.go | 5 -- internal/controller/proxy_controller_test.go | 55 ++++---------------- internal/middleware/context_middleware.go | 6 ++- internal/service/oidc_service.go | 6 ++- 5 files changed, 27 insertions(+), 54 deletions(-) diff --git a/internal/controller/oidc_controller.go b/internal/controller/oidc_controller.go index 08205b1..160ca2d 100644 --- a/internal/controller/oidc_controller.go +++ b/internal/controller/oidc_controller.go @@ -270,7 +270,7 @@ func (controller *OIDCController) Token(c *gin.Context) { switch req.GrantType { case "authorization_code": - entry, err := controller.oidc.GetCodeEntry(c, controller.oidc.Hash(req.Code)) + entry, err := controller.oidc.GetCodeEntry(c, controller.oidc.Hash(req.Code), client.ClientID) if err != nil { if errors.Is(err, service.ErrCodeNotFound) { tlog.App.Warn().Msg("Code not found") @@ -286,6 +286,13 @@ func (controller *OIDCController) Token(c *gin.Context) { }) return } + if errors.Is(err, service.ErrInvalidClient) { + tlog.App.Warn().Msg("Invalid client ID") + c.JSON(400, gin.H{ + "error": "invalid_client", + }) + return + } tlog.App.Warn().Err(err).Msg("Failed to get OIDC code entry") c.JSON(400, gin.H{ "error": "server_error", diff --git a/internal/controller/proxy_controller.go b/internal/controller/proxy_controller.go index 9b01b7d..3993473 100644 --- a/internal/controller/proxy_controller.go +++ b/internal/controller/proxy_controller.go @@ -185,11 +185,6 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { tlog.App.Trace().Interface("context", userContext).Msg("User context from request") - if userContext.IsBasicAuth && userContext.TotpEnabled { - tlog.App.Debug().Msg("User has TOTP enabled, denying basic auth access") - userContext.IsLoggedIn = false - } - if userContext.IsLoggedIn { userAllowed := controller.auth.IsUserAllowed(c, userContext, acls) diff --git a/internal/controller/proxy_controller_test.go b/internal/controller/proxy_controller_test.go index f23fcec..e2c020f 100644 --- a/internal/controller/proxy_controller_test.go +++ b/internal/controller/proxy_controller_test.go @@ -59,6 +59,11 @@ func setupProxyController(t *testing.T, middlewares *[]gin.HandlerFunc) (*gin.En Username: "testuser", Password: "$2a$10$ne6z693sTgzT3ePoQ05PgOecUHnBjM7sSNj6M.l5CLUP.f6NyCnt.", // test }, + { + Username: "totpuser", + Password: "$2a$10$ne6z693sTgzT3ePoQ05PgOecUHnBjM7sSNj6M.l5CLUP.f6NyCnt.", + TotpSecret: "foo", + }, }, OauthWhitelist: []string{}, SessionExpiry: 3600, @@ -79,9 +84,11 @@ func setupProxyController(t *testing.T, middlewares *[]gin.HandlerFunc) (*gin.En return router, recorder, authService } +// TODO: Needs tests for context middleware + func TestProxyHandler(t *testing.T) { // Setup - router, recorder, authService := setupProxyController(t, nil) + router, recorder, _ := setupProxyController(t, nil) // Test invalid proxy req := httptest.NewRequest("GET", "/api/auth/invalidproxy", nil) @@ -144,21 +151,6 @@ func TestProxyHandler(t *testing.T) { assert.Equal(t, 401, recorder.Code) // Test logged in user - c := gin.CreateTestContextOnly(recorder, router) - - err := authService.CreateSessionCookie(c, &repository.Session{ - Username: "testuser", - Name: "testuser", - Email: "testuser@example.com", - Provider: "local", - TotpPending: false, - OAuthGroups: "", - }) - - assert.NilError(t, err) - - cookie := c.Writer.Header().Get("Set-Cookie") - router, recorder, _ = setupProxyController(t, &[]gin.HandlerFunc{ func(c *gin.Context) { c.Set("context", &config.UserContext{ @@ -177,44 +169,15 @@ func TestProxyHandler(t *testing.T) { }) req = httptest.NewRequest("GET", "/api/auth/traefik", nil) - req.Header.Set("Cookie", cookie) req.Header.Set("X-Forwarded-Proto", "https") req.Header.Set("X-Forwarded-Host", "example.com") req.Header.Set("X-Forwarded-Uri", "/somepath") req.Header.Set("Accept", "text/html") - router.ServeHTTP(recorder, req) + router.ServeHTTP(recorder, req) assert.Equal(t, 200, recorder.Code) assert.Equal(t, "testuser", recorder.Header().Get("Remote-User")) assert.Equal(t, "testuser", recorder.Header().Get("Remote-Name")) assert.Equal(t, "testuser@example.com", recorder.Header().Get("Remote-Email")) - - // Ensure basic auth is disabled for TOTP enabled users - router, recorder, _ = setupProxyController(t, &[]gin.HandlerFunc{ - func(c *gin.Context) { - c.Set("context", &config.UserContext{ - Username: "testuser", - Name: "testuser", - Email: "testuser@example.com", - IsLoggedIn: true, - IsBasicAuth: true, - OAuth: false, - Provider: "local", - TotpPending: false, - OAuthGroups: "", - TotpEnabled: true, - }) - c.Next() - }, - }) - - req = httptest.NewRequest("GET", "/api/auth/traefik", nil) - req.Header.Set("X-Forwarded-Proto", "https") - req.Header.Set("X-Forwarded-Host", "example.com") - req.Header.Set("X-Forwarded-Uri", "/somepath") - req.SetBasicAuth("testuser", "test") - router.ServeHTTP(recorder, req) - - assert.Equal(t, 401, recorder.Code) } diff --git a/internal/middleware/context_middleware.go b/internal/middleware/context_middleware.go index 2067d82..f317b15 100644 --- a/internal/middleware/context_middleware.go +++ b/internal/middleware/context_middleware.go @@ -182,13 +182,17 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc { user := m.auth.GetLocalUser(basic.Username) + if user.TotpSecret != "" { + tlog.App.Debug().Msg("User with TOTP not allowed to login via basic auth") + return + } + c.Set("context", &config.UserContext{ Username: user.Username, Name: utils.Capitalize(user.Username), Email: utils.CompileUserEmail(user.Username, m.config.CookieDomain), Provider: "local", IsLoggedIn: true, - TotpEnabled: user.TotpSecret != "", IsBasicAuth: true, }) c.Next() diff --git a/internal/service/oidc_service.go b/internal/service/oidc_service.go index 22050f3..f732d4d 100644 --- a/internal/service/oidc_service.go +++ b/internal/service/oidc_service.go @@ -352,7 +352,7 @@ func (service *OIDCService) ValidateGrantType(grantType string) error { return nil } -func (service *OIDCService) GetCodeEntry(c *gin.Context, codeHash string) (repository.OidcCode, error) { +func (service *OIDCService) GetCodeEntry(c *gin.Context, codeHash string, clientId string) (repository.OidcCode, error) { oidcCode, err := service.queries.GetOidcCode(c, codeHash) if err != nil { @@ -374,6 +374,10 @@ func (service *OIDCService) GetCodeEntry(c *gin.Context, codeHash string) (repos return repository.OidcCode{}, ErrCodeExpired } + if oidcCode.ClientID != clientId { + return repository.OidcCode{}, ErrInvalidClient + } + return oidcCode, nil }