diff --git a/internal/controller/proxy_controller.go b/internal/controller/proxy_controller.go index a471216..404fc51 100644 --- a/internal/controller/proxy_controller.go +++ b/internal/controller/proxy_controller.go @@ -323,12 +323,12 @@ func (controller *ProxyController) getHeader(c *gin.Context, header string) (str } func (controller *ProxyController) useBrowserResponse(proxyCtx ProxyContext) bool { - // If it's nginx or envoy we need non-browser response - if proxyCtx.ProxyType == Nginx || proxyCtx.ProxyType == Envoy { + // If it's nginx we need non-browser response + if proxyCtx.ProxyType == Nginx { return false } - // For other proxies (traefik or caddy) we can check + // For other proxies (traefik/caddy/envoy) we can check // the user agent to determine if it's a browser or not if proxyCtx.IsBrowser { return true diff --git a/internal/controller/proxy_controller_test.go b/internal/controller/proxy_controller_test.go index 89e94de..5434aae 100644 --- a/internal/controller/proxy_controller_test.go +++ b/internal/controller/proxy_controller_test.go @@ -104,7 +104,7 @@ func TestProxyController(t *testing.T) { tests := []testCase{ { - description: "Default forward auth should be detected and used", + description: "Default forward auth should be detected and used for traefik", middlewares: []gin.HandlerFunc{}, run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { req := httptest.NewRequest("GET", "/api/auth/traefik", nil) @@ -126,6 +126,7 @@ func TestProxyController(t *testing.T) { run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { req := httptest.NewRequest("GET", "/api/auth/nginx", nil) req.Header.Set("x-original-url", "https://test.example.com/") + req.Header.Set("user-agent", browserUserAgent) router.ServeHTTP(recorder, req) assert.Equal(t, 401, recorder.Code) }, @@ -137,8 +138,12 @@ func TestProxyController(t *testing.T) { req := httptest.NewRequest("HEAD", "/api/auth/envoy?path=/hello", nil) // test a different method for envoy req.Host = "test.example.com" req.Header.Set("x-forwarded-proto", "https") + req.Header.Set("user-agent", browserUserAgent) router.ServeHTTP(recorder, req) - assert.Equal(t, 401, recorder.Code) + assert.Equal(t, 307, recorder.Code) + location := recorder.Header().Get("Location") + assert.Contains(t, location, "https://tinyauth.example.com/login?redirect_uri=") + assert.Contains(t, location, "https%3A%2F%2Ftest.example.com%2Fhello") }, }, { @@ -149,6 +154,7 @@ func TestProxyController(t *testing.T) { req.Header.Set("x-forwarded-host", "test.example.com") req.Header.Set("x-forwarded-proto", "https") req.Header.Set("x-forwarded-uri", "/") + req.Header.Set("user-agent", browserUserAgent) router.ServeHTTP(recorder, req) assert.Equal(t, 401, recorder.Code) }, @@ -158,41 +164,20 @@ func TestProxyController(t *testing.T) { middlewares: []gin.HandlerFunc{}, run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { req := httptest.NewRequest("HEAD", "/api/auth/envoy?path=/hello", nil) - req.Header.Set("x-forwarded-host", "test.example.com") - req.Header.Set("x-forwarded-proto", "https") - req.Header.Set("x-forwarded-uri", "/hello") - router.ServeHTTP(recorder, req) - assert.Equal(t, 401, recorder.Code) - }, - }, - { - description: "Ensure forward auth fallback for nginx with browser user agent", - middlewares: []gin.HandlerFunc{}, - run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { - req := httptest.NewRequest("GET", "/api/auth/nginx", nil) - req.Header.Set("x-forwarded-host", "test.example.com") - req.Header.Set("x-forwarded-proto", "https") - req.Header.Set("x-forwarded-uri", "/") - req.Header.Set("user-agent", browserUserAgent) - router.ServeHTTP(recorder, req) - assert.Equal(t, 401, recorder.Code) - }, - }, - { - description: "Ensure forward auth fallback for envoy with browser user agent", - middlewares: []gin.HandlerFunc{}, - run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { - req := httptest.NewRequest("HEAD", "/api/auth/envoy?path=/hello", nil) + req.Host = "" req.Header.Set("x-forwarded-host", "test.example.com") req.Header.Set("x-forwarded-proto", "https") req.Header.Set("x-forwarded-uri", "/hello") req.Header.Set("user-agent", browserUserAgent) router.ServeHTTP(recorder, req) - assert.Equal(t, 401, recorder.Code) + assert.Equal(t, 307, recorder.Code) + location := recorder.Header().Get("Location") + assert.Contains(t, location, "https://tinyauth.example.com/login?redirect_uri=") + assert.Contains(t, location, "https%3A%2F%2Ftest.example.com%2Fhello") }, }, { - description: "Ensure forward auth with is browser false returns json", + description: "Ensure forward auth with non browser returns json", middlewares: []gin.HandlerFunc{}, run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { req := httptest.NewRequest("GET", "/api/auth/traefik", nil) @@ -210,7 +195,7 @@ func TestProxyController(t *testing.T) { description: "Ensure forward auth with caddy and browser user agent returns redirect", middlewares: []gin.HandlerFunc{}, run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { - req := httptest.NewRequest("GET", "/api/auth/traefik", nil) + req := httptest.NewRequest("GET", "/api/auth/caddy", nil) req.Header.Set("x-forwarded-host", "test.example.com") req.Header.Set("x-forwarded-proto", "https") req.Header.Set("x-forwarded-uri", "/") @@ -238,6 +223,21 @@ func TestProxyController(t *testing.T) { assert.Contains(t, recorder.Body.String(), `"message":"Unauthorized"`) }, }, + { + description: "Ensure envoy non browser returns json", + middlewares: []gin.HandlerFunc{}, + run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { + req := httptest.NewRequest("HEAD", "/api/auth/envoy?path=/hello", nil) + req.Header.Set("x-forwarded-host", "test.example.com") + req.Header.Set("x-forwarded-proto", "https") + req.Header.Set("x-forwarded-uri", "/hello") + router.ServeHTTP(recorder, req) + + assert.Equal(t, 401, recorder.Code) + assert.Contains(t, recorder.Body.String(), `"status":401`) + assert.Contains(t, recorder.Body.String(), `"message":"Unauthorized"`) + }, + }, { description: "Ensure normal authentication flow for forward auth", middlewares: []gin.HandlerFunc{