diff --git a/internal/controller/proxy_controller.go b/internal/controller/proxy_controller.go index 2b6738a..eed127e 100644 --- a/internal/controller/proxy_controller.go +++ b/internal/controller/proxy_controller.go @@ -3,6 +3,7 @@ package controller import ( "fmt" "net/http" + "slices" "strings" "tinyauth/internal/config" "tinyauth/internal/service" @@ -13,6 +14,8 @@ import ( "github.com/rs/zerolog/log" ) +var SupportedProxies = []string{"nginx", "traefik", "caddy", "envoy"} + type Proxy struct { Proxy string `uri:"proxy" binding:"required"` } @@ -39,7 +42,7 @@ func NewProxyController(config ProxyControllerConfig, router *gin.RouterGroup, a func (controller *ProxyController) SetupRoutes() { proxyGroup := controller.router.Group("/auth") - proxyGroup.GET("/:proxy", controller.proxyHandler) + proxyGroup.Any("/:proxy", controller.proxyHandler) } func (controller *ProxyController) proxyHandler(c *gin.Context) { @@ -55,7 +58,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { return } - if req.Proxy != "nginx" && req.Proxy != "traefik" && req.Proxy != "caddy" { + if !slices.Contains(SupportedProxies, req.Proxy) { log.Warn().Str("proxy", req.Proxy).Msg("Invalid proxy") c.JSON(400, gin.H{ "status": 400, @@ -64,6 +67,15 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { return } + if req.Proxy != "envoy" && c.Request.Method != http.MethodGet { + log.Warn().Str("method", c.Request.Method).Msg("Invalid method for proxy") + c.JSON(405, gin.H{ + "status": 405, + "message": "Method Not Allowed", + }) + return + } + isBrowser := strings.Contains(c.Request.Header.Get("Accept"), "text/html") if isBrowser { diff --git a/internal/controller/proxy_controller_test.go b/internal/controller/proxy_controller_test.go index e7e27cf..452155f 100644 --- a/internal/controller/proxy_controller_test.go +++ b/internal/controller/proxy_controller_test.go @@ -80,6 +80,13 @@ func TestProxyHandler(t *testing.T) { assert.Equal(t, 400, recorder.Code) + // Test invalid method + recorder = httptest.NewRecorder() + req = httptest.NewRequest("POST", "/api/auth/traefik", nil) + router.ServeHTTP(recorder, req) + + assert.Equal(t, 405, recorder.Code) + // Test logged out user (traefik/caddy) recorder = httptest.NewRecorder() req = httptest.NewRequest("GET", "/api/auth/traefik", nil) @@ -92,6 +99,18 @@ func TestProxyHandler(t *testing.T) { assert.Equal(t, 307, recorder.Code) assert.Equal(t, "http://localhost:8080/login?redirect_uri=https%3A%2F%2Fexample.com%2Fsomepath", recorder.Header().Get("Location")) + // Test logged out user (envoy) + recorder = httptest.NewRecorder() + req = httptest.NewRequest("POST", "/api/auth/envoy", nil) + req.Header.Set("X-Forwarded-Proto", "https") + req.Header.Set("X-Forwarded-Host", "example.com") + req.Header.Set("X-Forwarded-Uri", "/somepath") + req.Header.Set("Accept", "text/html") + router.ServeHTTP(recorder, req) + + assert.Equal(t, 307, recorder.Code) + assert.Equal(t, "http://localhost:8080/login?redirect_uri=https%3A%2F%2Fexample.com%2Fsomepath", recorder.Header().Get("Location")) + // Test logged out user (nginx) recorder = httptest.NewRecorder() req = httptest.NewRequest("GET", "/api/auth/nginx", nil)