mirror of
https://github.com/steveiliop56/tinyauth.git
synced 2026-03-14 10:42:03 +00:00
fix: validate client id on oidc token endpoint
This commit is contained in:
@@ -270,7 +270,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
|
|||||||
|
|
||||||
switch req.GrantType {
|
switch req.GrantType {
|
||||||
case "authorization_code":
|
case "authorization_code":
|
||||||
entry, err := controller.oidc.GetCodeEntry(c, controller.oidc.Hash(req.Code))
|
entry, err := controller.oidc.GetCodeEntry(c, controller.oidc.Hash(req.Code), client.ClientID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, service.ErrCodeNotFound) {
|
if errors.Is(err, service.ErrCodeNotFound) {
|
||||||
tlog.App.Warn().Msg("Code not found")
|
tlog.App.Warn().Msg("Code not found")
|
||||||
@@ -286,6 +286,13 @@ func (controller *OIDCController) Token(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if errors.Is(err, service.ErrInvalidClient) {
|
||||||
|
tlog.App.Warn().Msg("Invalid client ID")
|
||||||
|
c.JSON(400, gin.H{
|
||||||
|
"error": "invalid_client",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
tlog.App.Warn().Err(err).Msg("Failed to get OIDC code entry")
|
tlog.App.Warn().Err(err).Msg("Failed to get OIDC code entry")
|
||||||
c.JSON(400, gin.H{
|
c.JSON(400, gin.H{
|
||||||
"error": "server_error",
|
"error": "server_error",
|
||||||
|
|||||||
@@ -185,11 +185,6 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
|||||||
|
|
||||||
tlog.App.Trace().Interface("context", userContext).Msg("User context from request")
|
tlog.App.Trace().Interface("context", userContext).Msg("User context from request")
|
||||||
|
|
||||||
if userContext.IsBasicAuth && userContext.TotpEnabled {
|
|
||||||
tlog.App.Debug().Msg("User has TOTP enabled, denying basic auth access")
|
|
||||||
userContext.IsLoggedIn = false
|
|
||||||
}
|
|
||||||
|
|
||||||
if userContext.IsLoggedIn {
|
if userContext.IsLoggedIn {
|
||||||
userAllowed := controller.auth.IsUserAllowed(c, userContext, acls)
|
userAllowed := controller.auth.IsUserAllowed(c, userContext, acls)
|
||||||
|
|
||||||
|
|||||||
@@ -59,6 +59,11 @@ func setupProxyController(t *testing.T, middlewares *[]gin.HandlerFunc) (*gin.En
|
|||||||
Username: "testuser",
|
Username: "testuser",
|
||||||
Password: "$2a$10$ne6z693sTgzT3ePoQ05PgOecUHnBjM7sSNj6M.l5CLUP.f6NyCnt.", // test
|
Password: "$2a$10$ne6z693sTgzT3ePoQ05PgOecUHnBjM7sSNj6M.l5CLUP.f6NyCnt.", // test
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
Username: "totpuser",
|
||||||
|
Password: "$2a$10$ne6z693sTgzT3ePoQ05PgOecUHnBjM7sSNj6M.l5CLUP.f6NyCnt.",
|
||||||
|
TotpSecret: "foo",
|
||||||
|
},
|
||||||
},
|
},
|
||||||
OauthWhitelist: []string{},
|
OauthWhitelist: []string{},
|
||||||
SessionExpiry: 3600,
|
SessionExpiry: 3600,
|
||||||
@@ -79,9 +84,11 @@ func setupProxyController(t *testing.T, middlewares *[]gin.HandlerFunc) (*gin.En
|
|||||||
return router, recorder, authService
|
return router, recorder, authService
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO: Needs tests for context middleware
|
||||||
|
|
||||||
func TestProxyHandler(t *testing.T) {
|
func TestProxyHandler(t *testing.T) {
|
||||||
// Setup
|
// Setup
|
||||||
router, recorder, authService := setupProxyController(t, nil)
|
router, recorder, _ := setupProxyController(t, nil)
|
||||||
|
|
||||||
// Test invalid proxy
|
// Test invalid proxy
|
||||||
req := httptest.NewRequest("GET", "/api/auth/invalidproxy", nil)
|
req := httptest.NewRequest("GET", "/api/auth/invalidproxy", nil)
|
||||||
@@ -144,21 +151,6 @@ func TestProxyHandler(t *testing.T) {
|
|||||||
assert.Equal(t, 401, recorder.Code)
|
assert.Equal(t, 401, recorder.Code)
|
||||||
|
|
||||||
// Test logged in user
|
// Test logged in user
|
||||||
c := gin.CreateTestContextOnly(recorder, router)
|
|
||||||
|
|
||||||
err := authService.CreateSessionCookie(c, &repository.Session{
|
|
||||||
Username: "testuser",
|
|
||||||
Name: "testuser",
|
|
||||||
Email: "testuser@example.com",
|
|
||||||
Provider: "local",
|
|
||||||
TotpPending: false,
|
|
||||||
OAuthGroups: "",
|
|
||||||
})
|
|
||||||
|
|
||||||
assert.NilError(t, err)
|
|
||||||
|
|
||||||
cookie := c.Writer.Header().Get("Set-Cookie")
|
|
||||||
|
|
||||||
router, recorder, _ = setupProxyController(t, &[]gin.HandlerFunc{
|
router, recorder, _ = setupProxyController(t, &[]gin.HandlerFunc{
|
||||||
func(c *gin.Context) {
|
func(c *gin.Context) {
|
||||||
c.Set("context", &config.UserContext{
|
c.Set("context", &config.UserContext{
|
||||||
@@ -177,44 +169,15 @@ func TestProxyHandler(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
req = httptest.NewRequest("GET", "/api/auth/traefik", nil)
|
req = httptest.NewRequest("GET", "/api/auth/traefik", nil)
|
||||||
req.Header.Set("Cookie", cookie)
|
|
||||||
req.Header.Set("X-Forwarded-Proto", "https")
|
req.Header.Set("X-Forwarded-Proto", "https")
|
||||||
req.Header.Set("X-Forwarded-Host", "example.com")
|
req.Header.Set("X-Forwarded-Host", "example.com")
|
||||||
req.Header.Set("X-Forwarded-Uri", "/somepath")
|
req.Header.Set("X-Forwarded-Uri", "/somepath")
|
||||||
req.Header.Set("Accept", "text/html")
|
req.Header.Set("Accept", "text/html")
|
||||||
router.ServeHTTP(recorder, req)
|
|
||||||
|
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
assert.Equal(t, 200, recorder.Code)
|
assert.Equal(t, 200, recorder.Code)
|
||||||
|
|
||||||
assert.Equal(t, "testuser", recorder.Header().Get("Remote-User"))
|
assert.Equal(t, "testuser", recorder.Header().Get("Remote-User"))
|
||||||
assert.Equal(t, "testuser", recorder.Header().Get("Remote-Name"))
|
assert.Equal(t, "testuser", recorder.Header().Get("Remote-Name"))
|
||||||
assert.Equal(t, "testuser@example.com", recorder.Header().Get("Remote-Email"))
|
assert.Equal(t, "testuser@example.com", recorder.Header().Get("Remote-Email"))
|
||||||
|
|
||||||
// Ensure basic auth is disabled for TOTP enabled users
|
|
||||||
router, recorder, _ = setupProxyController(t, &[]gin.HandlerFunc{
|
|
||||||
func(c *gin.Context) {
|
|
||||||
c.Set("context", &config.UserContext{
|
|
||||||
Username: "testuser",
|
|
||||||
Name: "testuser",
|
|
||||||
Email: "testuser@example.com",
|
|
||||||
IsLoggedIn: true,
|
|
||||||
IsBasicAuth: true,
|
|
||||||
OAuth: false,
|
|
||||||
Provider: "local",
|
|
||||||
TotpPending: false,
|
|
||||||
OAuthGroups: "",
|
|
||||||
TotpEnabled: true,
|
|
||||||
})
|
|
||||||
c.Next()
|
|
||||||
},
|
|
||||||
})
|
|
||||||
|
|
||||||
req = httptest.NewRequest("GET", "/api/auth/traefik", nil)
|
|
||||||
req.Header.Set("X-Forwarded-Proto", "https")
|
|
||||||
req.Header.Set("X-Forwarded-Host", "example.com")
|
|
||||||
req.Header.Set("X-Forwarded-Uri", "/somepath")
|
|
||||||
req.SetBasicAuth("testuser", "test")
|
|
||||||
router.ServeHTTP(recorder, req)
|
|
||||||
|
|
||||||
assert.Equal(t, 401, recorder.Code)
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -182,13 +182,17 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
|
|||||||
|
|
||||||
user := m.auth.GetLocalUser(basic.Username)
|
user := m.auth.GetLocalUser(basic.Username)
|
||||||
|
|
||||||
|
if user.TotpSecret != "" {
|
||||||
|
tlog.App.Debug().Msg("User with TOTP not allowed to login via basic auth")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
c.Set("context", &config.UserContext{
|
c.Set("context", &config.UserContext{
|
||||||
Username: user.Username,
|
Username: user.Username,
|
||||||
Name: utils.Capitalize(user.Username),
|
Name: utils.Capitalize(user.Username),
|
||||||
Email: utils.CompileUserEmail(user.Username, m.config.CookieDomain),
|
Email: utils.CompileUserEmail(user.Username, m.config.CookieDomain),
|
||||||
Provider: "local",
|
Provider: "local",
|
||||||
IsLoggedIn: true,
|
IsLoggedIn: true,
|
||||||
TotpEnabled: user.TotpSecret != "",
|
|
||||||
IsBasicAuth: true,
|
IsBasicAuth: true,
|
||||||
})
|
})
|
||||||
c.Next()
|
c.Next()
|
||||||
|
|||||||
@@ -352,7 +352,7 @@ func (service *OIDCService) ValidateGrantType(grantType string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (service *OIDCService) GetCodeEntry(c *gin.Context, codeHash string) (repository.OidcCode, error) {
|
func (service *OIDCService) GetCodeEntry(c *gin.Context, codeHash string, clientId string) (repository.OidcCode, error) {
|
||||||
oidcCode, err := service.queries.GetOidcCode(c, codeHash)
|
oidcCode, err := service.queries.GetOidcCode(c, codeHash)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -374,6 +374,10 @@ func (service *OIDCService) GetCodeEntry(c *gin.Context, codeHash string) (repos
|
|||||||
return repository.OidcCode{}, ErrCodeExpired
|
return repository.OidcCode{}, ErrCodeExpired
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if oidcCode.ClientID != clientId {
|
||||||
|
return repository.OidcCode{}, ErrInvalidClient
|
||||||
|
}
|
||||||
|
|
||||||
return oidcCode, nil
|
return oidcCode, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user