diff --git a/internal/controller/proxy_controller.go b/internal/controller/proxy_controller.go index 404fc51..9790daa 100644 --- a/internal/controller/proxy_controller.go +++ b/internal/controller/proxy_controller.go @@ -131,14 +131,6 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { } if !controller.auth.CheckIP(acls.IP, clientIP) { - if !controller.useBrowserResponse(proxyCtx) { - c.JSON(401, gin.H{ - "status": 401, - "message": "Unauthorized", - }) - return - } - queries, err := query.Values(config.UnauthorizedQuery{ Resource: strings.Split(proxyCtx.Host, ".")[0], IP: clientIP, @@ -146,11 +138,22 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { if err != nil { tlog.App.Error().Err(err).Msg("Failed to encode unauthorized query") - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) + controller.handleError(c, proxyCtx) return } - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", controller.config.AppURL, queries.Encode())) + redirectURL := fmt.Sprintf("%s/unauthorized?%s", controller.config.AppURL, queries.Encode()) + + if !controller.useBrowserResponse(proxyCtx) { + c.Header("x-tinyauth-location", redirectURL) + c.JSON(401, gin.H{ + "status": 401, + "message": "Unauthorized", + }) + return + } + + c.Redirect(http.StatusTemporaryRedirect, redirectURL) return } @@ -175,21 +178,13 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { if !userAllowed { tlog.App.Warn().Str("user", userContext.Username).Str("resource", strings.Split(proxyCtx.Host, ".")[0]).Msg("User not allowed to access resource") - if !controller.useBrowserResponse(proxyCtx) { - c.JSON(403, gin.H{ - "status": 403, - "message": "Forbidden", - }) - return - } - queries, err := query.Values(config.UnauthorizedQuery{ Resource: strings.Split(proxyCtx.Host, ".")[0], }) if err != nil { tlog.App.Error().Err(err).Msg("Failed to encode unauthorized query") - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) + controller.handleError(c, proxyCtx) return } @@ -199,7 +194,18 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { queries.Set("username", userContext.Username) } - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", controller.config.AppURL, queries.Encode())) + redirectURL := fmt.Sprintf("%s/unauthorized?%s", controller.config.AppURL, queries.Encode()) + + if !controller.useBrowserResponse(proxyCtx) { + c.Header("x-tinyauth-location", redirectURL) + c.JSON(403, gin.H{ + "status": 403, + "message": "Forbidden", + }) + return + } + + c.Redirect(http.StatusTemporaryRedirect, redirectURL) return } @@ -215,14 +221,6 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { if !groupOK { tlog.App.Warn().Str("user", userContext.Username).Str("resource", strings.Split(proxyCtx.Host, ".")[0]).Msg("User groups do not match resource requirements") - if !controller.useBrowserResponse(proxyCtx) { - c.JSON(403, gin.H{ - "status": 403, - "message": "Forbidden", - }) - return - } - queries, err := query.Values(config.UnauthorizedQuery{ Resource: strings.Split(proxyCtx.Host, ".")[0], GroupErr: true, @@ -230,7 +228,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { if err != nil { tlog.App.Error().Err(err).Msg("Failed to encode unauthorized query") - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) + controller.handleError(c, proxyCtx) return } @@ -240,7 +238,18 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { queries.Set("username", userContext.Username) } - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", controller.config.AppURL, queries.Encode())) + redirectURL := fmt.Sprintf("%s/unauthorized?%s", controller.config.AppURL, queries.Encode()) + + if !controller.useBrowserResponse(proxyCtx) { + c.Header("x-tinyauth-location", redirectURL) + c.JSON(403, gin.H{ + "status": 403, + "message": "Forbidden", + }) + return + } + + c.Redirect(http.StatusTemporaryRedirect, redirectURL) return } } @@ -266,7 +275,20 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { return } + queries, err := query.Values(config.RedirectQuery{ + RedirectURI: fmt.Sprintf("%s://%s%s", proxyCtx.Proto, proxyCtx.Host, proxyCtx.Path), + }) + + if err != nil { + tlog.App.Error().Err(err).Msg("Failed to encode redirect URI query") + controller.handleError(c, proxyCtx) + return + } + + redirectURL := fmt.Sprintf("%s/login?%s", controller.config.AppURL, queries.Encode()) + if !controller.useBrowserResponse(proxyCtx) { + c.Header("x-tinyauth-location", redirectURL) c.JSON(401, gin.H{ "status": 401, "message": "Unauthorized", @@ -274,17 +296,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { return } - queries, err := query.Values(config.RedirectQuery{ - RedirectURI: fmt.Sprintf("%s://%s%s", proxyCtx.Proto, proxyCtx.Host, proxyCtx.Path), - }) - - if err != nil { - tlog.App.Error().Err(err).Msg("Failed to encode redirect URI query") - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) - return - } - - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/login?%s", controller.config.AppURL, queries.Encode())) + c.Redirect(http.StatusTemporaryRedirect, redirectURL) } func (controller *ProxyController) setHeaders(c *gin.Context, acls config.App) { @@ -306,7 +318,10 @@ func (controller *ProxyController) setHeaders(c *gin.Context, acls config.App) { } func (controller *ProxyController) handleError(c *gin.Context, proxyCtx ProxyContext) { + redirectURL := fmt.Sprintf("%s/error", controller.config.AppURL) + if !controller.useBrowserResponse(proxyCtx) { + c.Header("x-tinyauth-location", redirectURL) c.JSON(500, gin.H{ "status": 500, "message": "Internal Server Error", @@ -314,7 +329,7 @@ func (controller *ProxyController) handleError(c *gin.Context, proxyCtx ProxyCon return } - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) + c.Redirect(http.StatusTemporaryRedirect, redirectURL) } func (controller *ProxyController) getHeader(c *gin.Context, header string) (string, bool) { diff --git a/internal/controller/proxy_controller_test.go b/internal/controller/proxy_controller_test.go index 35d5d6a..f485560 100644 --- a/internal/controller/proxy_controller_test.go +++ b/internal/controller/proxy_controller_test.go @@ -116,8 +116,7 @@ func TestProxyController(t *testing.T) { assert.Equal(t, 307, recorder.Code) location := recorder.Header().Get("Location") - assert.Contains(t, location, "https://tinyauth.example.com/login?redirect_uri=") - assert.Contains(t, location, "https%3A%2F%2Ftest.example.com%2F") + assert.Equal(t, location, "https://tinyauth.example.com/login?redirect_uri=https%3A%2F%2Ftest.example.com%2F") }, }, { @@ -129,6 +128,8 @@ func TestProxyController(t *testing.T) { req.Header.Set("user-agent", browserUserAgent) router.ServeHTTP(recorder, req) assert.Equal(t, 401, recorder.Code) + location := recorder.Header().Get("x-tinyauth-location") + assert.Equal(t, location, "https://tinyauth.example.com/login?redirect_uri=https%3A%2F%2Ftest.example.com%2F") }, }, { @@ -142,8 +143,7 @@ func TestProxyController(t *testing.T) { router.ServeHTTP(recorder, req) assert.Equal(t, 307, recorder.Code) location := recorder.Header().Get("Location") - assert.Contains(t, location, "https://tinyauth.example.com/login?redirect_uri=") - assert.Contains(t, location, "https%3A%2F%2Ftest.example.com%2Fhello") + assert.Equal(t, location, "https://tinyauth.example.com/login?redirect_uri=https%3A%2F%2Ftest.example.com%2Fhello") }, }, { @@ -159,8 +159,7 @@ func TestProxyController(t *testing.T) { assert.Equal(t, 307, recorder.Code) location := recorder.Header().Get("Location") - assert.Contains(t, location, "https://tinyauth.example.com/login?redirect_uri=") - assert.Contains(t, location, "https%3A%2F%2Ftest.example.com%2F") + assert.Equal(t, location, "https://tinyauth.example.com/login?redirect_uri=https%3A%2F%2Ftest.example.com%2F") }, }, { @@ -174,6 +173,8 @@ func TestProxyController(t *testing.T) { req.Header.Set("user-agent", browserUserAgent) router.ServeHTTP(recorder, req) assert.Equal(t, 401, recorder.Code) + location := recorder.Header().Get("x-tinyauth-location") + assert.Equal(t, location, "https://tinyauth.example.com/login?redirect_uri=https%3A%2F%2Ftest.example.com%2F") }, }, { @@ -189,8 +190,7 @@ func TestProxyController(t *testing.T) { router.ServeHTTP(recorder, req) assert.Equal(t, 307, recorder.Code) location := recorder.Header().Get("Location") - assert.Contains(t, location, "https://tinyauth.example.com/login?redirect_uri=") - assert.Contains(t, location, "https%3A%2F%2Ftest.example.com%2Fhello") + assert.Equal(t, location, "https://tinyauth.example.com/login?redirect_uri=https%3A%2F%2Ftest.example.com%2Fhello") }, }, {