From e7d26f497d731d3fa757a21b2e5183abaeae5e86 Mon Sep 17 00:00:00 2001 From: Stavros Date: Wed, 17 Jun 2026 12:33:09 +0300 Subject: [PATCH] fix: use runtime trusted uris in oauth controller --- internal/controller/oauth_controller.go | 41 +++++++++- internal/controller/oauth_controller_test.go | 83 ++++++++++++++++++++ internal/test/test.go | 4 + internal/utils/app_utils.go | 21 ----- internal/utils/app_utils_test.go | 55 ------------- 5 files changed, 125 insertions(+), 79 deletions(-) create mode 100644 internal/controller/oauth_controller_test.go diff --git a/internal/controller/oauth_controller.go b/internal/controller/oauth_controller.go index 21877705..75962673 100644 --- a/internal/controller/oauth_controller.go +++ b/internal/controller/oauth_controller.go @@ -3,6 +3,7 @@ package controller import ( "fmt" "net/http" + "net/url" "strings" "time" @@ -80,9 +81,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) { } if !controller.isOidcRequest(reqParams) { - isRedirectSafe := utils.IsRedirectSafe(reqParams.RedirectURI, controller.runtime.CookieDomain) - - if !isRedirectSafe { + if !controller.isRedirectSafe(reqParams.RedirectURI) { controller.log.App.Warn().Str("redirectUri", reqParams.RedirectURI).Msg("Unsafe redirect URI, ignoring") reqParams.RedirectURI = "" } @@ -310,3 +309,39 @@ func (controller *OAuthController) getCookieDomain() string { } return controller.runtime.CookieDomain } + +func (controller *OAuthController) isRedirectSafe(redirectURI string) bool { + u, err := url.Parse(redirectURI) + + if err != nil || u.Host == "" || u.Scheme == "" { + return false + } + + for _, allowed := range controller.runtime.TrustedDomains { + tu, err := url.Parse(allowed) + if err != nil { + controller.log.App.Error().Err(err).Str("allowed", allowed).Msg("Failed to parse trusted domain") + continue + } + + if tu.Scheme != u.Scheme { + continue + } + + // exact match + if u.Host == tu.Host { + return true + } + + // subdomain match (trim the tinyauth part) + _, root, ok := strings.Cut(tu.Host, ".") + if !ok { + continue + } + if strings.HasSuffix(u.Host, "."+root) { + return true + } + } + + return false +} diff --git a/internal/controller/oauth_controller_test.go b/internal/controller/oauth_controller_test.go new file mode 100644 index 00000000..355f9c8f --- /dev/null +++ b/internal/controller/oauth_controller_test.go @@ -0,0 +1,83 @@ +package controller + +import ( + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" + "github.com/tinyauthapp/tinyauth/internal/test" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" +) + +func TestOAuthController(t *testing.T) { + log := logger.NewLogger().WithTestConfig() + log.Init() + + cfg, runtime := test.CreateTestConfigs(t) + + type testCase struct { + description string + run func(ctrl *OAuthController) + } + + tests := []testCase{ + { + description: "Test exact match of redirect URI", + run: func(ctrl *OAuthController) { + redirectUri := "https://tinyauth.example.com" + assert.True(t, ctrl.isRedirectSafe(redirectUri)) + }, + }, + { + description: "Test subdomain match of redirect URI", + run: func(ctrl *OAuthController) { + redirectUri := "https://sub.example.com" + assert.True(t, ctrl.isRedirectSafe(redirectUri)) + }, + }, + { + description: "Test different trusted domain", + run: func(ctrl *OAuthController) { + redirectUri := "https://app.foo.com" + assert.True(t, ctrl.isRedirectSafe(redirectUri)) + }, + }, + { + description: "Test invalid redirect URI", + run: func(ctrl *OAuthController) { + redirectUri := "https://malicious.com" + assert.False(t, ctrl.isRedirectSafe(redirectUri)) + }, + }, + { + description: "Test empty redirect URI", + run: func(ctrl *OAuthController) { + redirectUri := "" + assert.False(t, ctrl.isRedirectSafe(redirectUri)) + }, + }, + { + description: "Test redirect URI with different scheme", + run: func(ctrl *OAuthController) { + redirectUri := "http://tinyauth.example.com" + assert.False(t, ctrl.isRedirectSafe(redirectUri)) + }, + }, + } + + // TODO: add auth service + for _, tc := range tests { + t.Run(tc.description, func(t *testing.T) { + router := gin.Default() + group := router.Group("/api") + gin.SetMode(gin.TestMode) + ctrl := NewOAuthController(OAuthControllerInput{ + Log: log, + Config: &cfg, + RuntimeConfig: &runtime, + RouterGroup: group, + }) + tc.run(ctrl) + }) + } +} diff --git a/internal/test/test.go b/internal/test/test.go index 76c31a27..df10f2b4 100644 --- a/internal/test/test.go +++ b/internal/test/test.go @@ -121,6 +121,10 @@ func CreateTestConfigs(t *testing.T) (model.Config, model.RuntimeConfig) { CookieDomain: "example.com", AppURL: "https://tinyauth.example.com", SessionCookieName: "tinyauth-session", + TrustedDomains: []string{ + "https://tinyauth.example.com", + "https://tinyauth.foo.com", + }, } return config, runtime diff --git a/internal/utils/app_utils.go b/internal/utils/app_utils.go index 6413755b..777e380d 100644 --- a/internal/utils/app_utils.go +++ b/internal/utils/app_utils.go @@ -2,7 +2,6 @@ package utils import ( "errors" - "fmt" "net" "net/url" "strings" @@ -88,23 +87,3 @@ func Filter[T any](slice []T, test func(T) bool) (res []T) { } return res } - -func IsRedirectSafe(redirectURL string, domain string) bool { - if redirectURL == "" { - return false - } - - parsed, err := url.Parse(redirectURL) - - if err != nil { - return false - } - - hostname := parsed.Hostname() - - if strings.HasSuffix(hostname, fmt.Sprintf(".%s", domain)) { - return true - } - - return hostname == domain -} diff --git a/internal/utils/app_utils_test.go b/internal/utils/app_utils_test.go index 6554fad8..f0c3625c 100644 --- a/internal/utils/app_utils_test.go +++ b/internal/utils/app_utils_test.go @@ -126,61 +126,6 @@ func TestFilter(t *testing.T) { assert.Equal(t, expectedStr, resultStr) } -func TestIsRedirectSafe(t *testing.T) { - // Setup - domain := "example.com" - - // Case with no subdomain - redirectURL := "http://example.com/welcome" - result := utils.IsRedirectSafe(redirectURL, domain) - assert.True(t, result) - - // Case with different domain - redirectURL = "http://malicious.com/phishing" - result = utils.IsRedirectSafe(redirectURL, domain) - assert.False(t, result) - - // Case with subdomain - redirectURL = "http://sub.example.com/page" - result = utils.IsRedirectSafe(redirectURL, domain) - assert.True(t, result) - - // Case with sub-subdomain - redirectURL = "http://a.b.example.com/home" - result = utils.IsRedirectSafe(redirectURL, domain) - assert.True(t, result) - - // Case with empty redirect URL - redirectURL = "" - result = utils.IsRedirectSafe(redirectURL, domain) - assert.False(t, result) - - // Case with invalid URL - redirectURL = "http://[::1]:namedport" - result = utils.IsRedirectSafe(redirectURL, domain) - assert.False(t, result) - - // Case with URL having port - redirectURL = "http://sub.example.com:8080/page" - result = utils.IsRedirectSafe(redirectURL, domain) - assert.True(t, result) - - // Case with URL having different subdomain - redirectURL = "http://another.example.com/page" - result = utils.IsRedirectSafe(redirectURL, domain) - assert.True(t, result) - - // Case with URL having different TLD - redirectURL = "http://example.org/page" - result = utils.IsRedirectSafe(redirectURL, domain) - assert.False(t, result) - - // Case with malicious domain - redirectURL = "https://malicious-example.com/yoyo" - result = utils.IsRedirectSafe(redirectURL, domain) - assert.False(t, result) -} - func TestGetStandaloneCookieDomain(t *testing.T) { // Normal case domain := "http://tinyauth.app"