fix: use runtime trusted uris in oauth controller

This commit is contained in:
Stavros
2026-06-17 12:33:09 +03:00
parent a9face749d
commit e7d26f497d
5 changed files with 125 additions and 79 deletions
+38 -3
View File
@@ -3,6 +3,7 @@ package controller
import (
"fmt"
"net/http"
"net/url"
"strings"
"time"
@@ -80,9 +81,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
}
if !controller.isOidcRequest(reqParams) {
isRedirectSafe := utils.IsRedirectSafe(reqParams.RedirectURI, controller.runtime.CookieDomain)
if !isRedirectSafe {
if !controller.isRedirectSafe(reqParams.RedirectURI) {
controller.log.App.Warn().Str("redirectUri", reqParams.RedirectURI).Msg("Unsafe redirect URI, ignoring")
reqParams.RedirectURI = ""
}
@@ -310,3 +309,39 @@ func (controller *OAuthController) getCookieDomain() string {
}
return controller.runtime.CookieDomain
}
func (controller *OAuthController) isRedirectSafe(redirectURI string) bool {
u, err := url.Parse(redirectURI)
if err != nil || u.Host == "" || u.Scheme == "" {
return false
}
for _, allowed := range controller.runtime.TrustedDomains {
tu, err := url.Parse(allowed)
if err != nil {
controller.log.App.Error().Err(err).Str("allowed", allowed).Msg("Failed to parse trusted domain")
continue
}
if tu.Scheme != u.Scheme {
continue
}
// exact match
if u.Host == tu.Host {
return true
}
// subdomain match (trim the tinyauth part)
_, root, ok := strings.Cut(tu.Host, ".")
if !ok {
continue
}
if strings.HasSuffix(u.Host, "."+root) {
return true
}
}
return false
}
@@ -0,0 +1,83 @@
package controller
import (
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/tinyauthapp/tinyauth/internal/test"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
)
func TestOAuthController(t *testing.T) {
log := logger.NewLogger().WithTestConfig()
log.Init()
cfg, runtime := test.CreateTestConfigs(t)
type testCase struct {
description string
run func(ctrl *OAuthController)
}
tests := []testCase{
{
description: "Test exact match of redirect URI",
run: func(ctrl *OAuthController) {
redirectUri := "https://tinyauth.example.com"
assert.True(t, ctrl.isRedirectSafe(redirectUri))
},
},
{
description: "Test subdomain match of redirect URI",
run: func(ctrl *OAuthController) {
redirectUri := "https://sub.example.com"
assert.True(t, ctrl.isRedirectSafe(redirectUri))
},
},
{
description: "Test different trusted domain",
run: func(ctrl *OAuthController) {
redirectUri := "https://app.foo.com"
assert.True(t, ctrl.isRedirectSafe(redirectUri))
},
},
{
description: "Test invalid redirect URI",
run: func(ctrl *OAuthController) {
redirectUri := "https://malicious.com"
assert.False(t, ctrl.isRedirectSafe(redirectUri))
},
},
{
description: "Test empty redirect URI",
run: func(ctrl *OAuthController) {
redirectUri := ""
assert.False(t, ctrl.isRedirectSafe(redirectUri))
},
},
{
description: "Test redirect URI with different scheme",
run: func(ctrl *OAuthController) {
redirectUri := "http://tinyauth.example.com"
assert.False(t, ctrl.isRedirectSafe(redirectUri))
},
},
}
// TODO: add auth service
for _, tc := range tests {
t.Run(tc.description, func(t *testing.T) {
router := gin.Default()
group := router.Group("/api")
gin.SetMode(gin.TestMode)
ctrl := NewOAuthController(OAuthControllerInput{
Log: log,
Config: &cfg,
RuntimeConfig: &runtime,
RouterGroup: group,
})
tc.run(ctrl)
})
}
}
+4
View File
@@ -121,6 +121,10 @@ func CreateTestConfigs(t *testing.T) (model.Config, model.RuntimeConfig) {
CookieDomain: "example.com",
AppURL: "https://tinyauth.example.com",
SessionCookieName: "tinyauth-session",
TrustedDomains: []string{
"https://tinyauth.example.com",
"https://tinyauth.foo.com",
},
}
return config, runtime
-21
View File
@@ -2,7 +2,6 @@ package utils
import (
"errors"
"fmt"
"net"
"net/url"
"strings"
@@ -88,23 +87,3 @@ func Filter[T any](slice []T, test func(T) bool) (res []T) {
}
return res
}
func IsRedirectSafe(redirectURL string, domain string) bool {
if redirectURL == "" {
return false
}
parsed, err := url.Parse(redirectURL)
if err != nil {
return false
}
hostname := parsed.Hostname()
if strings.HasSuffix(hostname, fmt.Sprintf(".%s", domain)) {
return true
}
return hostname == domain
}
-55
View File
@@ -126,61 +126,6 @@ func TestFilter(t *testing.T) {
assert.Equal(t, expectedStr, resultStr)
}
func TestIsRedirectSafe(t *testing.T) {
// Setup
domain := "example.com"
// Case with no subdomain
redirectURL := "http://example.com/welcome"
result := utils.IsRedirectSafe(redirectURL, domain)
assert.True(t, result)
// Case with different domain
redirectURL = "http://malicious.com/phishing"
result = utils.IsRedirectSafe(redirectURL, domain)
assert.False(t, result)
// Case with subdomain
redirectURL = "http://sub.example.com/page"
result = utils.IsRedirectSafe(redirectURL, domain)
assert.True(t, result)
// Case with sub-subdomain
redirectURL = "http://a.b.example.com/home"
result = utils.IsRedirectSafe(redirectURL, domain)
assert.True(t, result)
// Case with empty redirect URL
redirectURL = ""
result = utils.IsRedirectSafe(redirectURL, domain)
assert.False(t, result)
// Case with invalid URL
redirectURL = "http://[::1]:namedport"
result = utils.IsRedirectSafe(redirectURL, domain)
assert.False(t, result)
// Case with URL having port
redirectURL = "http://sub.example.com:8080/page"
result = utils.IsRedirectSafe(redirectURL, domain)
assert.True(t, result)
// Case with URL having different subdomain
redirectURL = "http://another.example.com/page"
result = utils.IsRedirectSafe(redirectURL, domain)
assert.True(t, result)
// Case with URL having different TLD
redirectURL = "http://example.org/page"
result = utils.IsRedirectSafe(redirectURL, domain)
assert.False(t, result)
// Case with malicious domain
redirectURL = "https://malicious-example.com/yoyo"
result = utils.IsRedirectSafe(redirectURL, domain)
assert.False(t, result)
}
func TestGetStandaloneCookieDomain(t *testing.T) {
// Normal case
domain := "http://tinyauth.app"