diff --git a/frontend/src/lib/hooks/redirect-uri.ts b/frontend/src/lib/hooks/redirect-uri.ts index 38e8b5c5..99b14f07 100644 --- a/frontend/src/lib/hooks/redirect-uri.ts +++ b/frontend/src/lib/hooks/redirect-uri.ts @@ -9,6 +9,7 @@ type IuseRedirectUri = { export const useRedirectUri = ( redirect_uri: string | undefined, cookieDomain: string, + subdomainsEnabled: boolean, ): IuseRedirectUri => { let isValid = false; let isTrusted = false; @@ -39,10 +40,11 @@ export const useRedirectUri = ( isValid = true; - if ( - url.hostname == cookieDomain || - url.hostname.endsWith(`.${cookieDomain}`) - ) { + if (url.hostname == cookieDomain) { + isTrusted = true; + } + + if (subdomainsEnabled && url.hostname.endsWith("." + cookieDomain)) { isTrusted = true; } diff --git a/frontend/src/pages/continue-page.tsx b/frontend/src/pages/continue-page.tsx index 3220ac99..e63bb7e0 100644 --- a/frontend/src/pages/continue-page.tsx +++ b/frontend/src/pages/continue-page.tsx @@ -37,6 +37,7 @@ export const ContinuePage = () => { const { url, valid, trusted, allowedProto, httpsDowngrade } = useRedirectUri( redirectUri, app.cookieDomain, + app.subdomainsEnabled, ); const urlHref = url?.href; @@ -108,7 +109,11 @@ export const ContinuePage = () => { components={{ code: , }} - values={{ cookieDomain: app.cookieDomain }} + values={{ + cookieDomain: app.subdomainsEnabled + ? `.${app.cookieDomain}` + : app.cookieDomain, + }} shouldUnescape={true} /> diff --git a/frontend/src/schemas/app-context-schema.ts b/frontend/src/schemas/app-context-schema.ts index 4ad64940..f8740a70 100644 --- a/frontend/src/schemas/app-context-schema.ts +++ b/frontend/src/schemas/app-context-schema.ts @@ -24,6 +24,7 @@ const uiSchema = z.object({ const appSchema = z.object({ appUrl: z.string(), cookieDomain: z.string(), + subdomainsEnabled: z.boolean(), }); export const appContextSchema = z.object({ diff --git a/internal/bootstrap/app_bootstrap.go b/internal/bootstrap/app_bootstrap.go index fc8bba18..b163f61e 100644 --- a/internal/bootstrap/app_bootstrap.go +++ b/internal/bootstrap/app_bootstrap.go @@ -97,7 +97,7 @@ func (app *BootstrapApp) Setup() error { return fmt.Errorf("failed to parse app url: %w", err) } - app.runtime.AppURL = appUrl.Scheme + "://" + appUrl.Host + app.runtime.AppURL = strings.ToLower(appUrl.Scheme + "://" + appUrl.Host) // validate session config if app.config.Auth.SessionMaxLifetime != 0 && app.config.Auth.SessionMaxLifetime < app.config.Auth.SessionExpiry { @@ -162,14 +162,11 @@ func (app *BootstrapApp) Setup() error { } // cookie domain - cookieDomainResolver := utils.GetCookieDomain - if !app.config.Auth.SubdomainsEnabled { - app.log.App.Warn().Msg("Subdomains are disabled, using standalone cookie domain resolver which will not work with subdomains") - cookieDomainResolver = utils.GetStandaloneCookieDomain + app.log.App.Warn().Msg("Subdomains are disabled, cookies will be set for the current domain only") } - cookieDomain, err := cookieDomainResolver(app.runtime.AppURL) + cookieDomain, err := utils.GetCookieDomain(app.runtime.AppURL) if err != nil { return fmt.Errorf("failed to get cookie domain: %w", err) @@ -290,6 +287,14 @@ func (app *BootstrapApp) Setup() error { if tailscaleUrl != app.runtime.AppURL { app.log.App.Info().Msg("Listening on tailscale, replacing app url with tailscale hostname") app.runtime.AppURL = tailscaleUrl + // also update cookie domain + cookieDomain, err := utils.GetCookieDomain(tailscaleUrl) + + if err != nil { + return fmt.Errorf("failed to get cookie domain: %w", err) + } + + app.runtime.CookieDomain = cookieDomain } } diff --git a/internal/controller/context_controller.go b/internal/controller/context_controller.go index 5579c923..84f52cc3 100644 --- a/internal/controller/context_controller.go +++ b/internal/controller/context_controller.go @@ -58,8 +58,9 @@ type ACRUI struct { } type ACRApp struct { - AppURL string `json:"appUrl"` - CookieDomain string `json:"cookieDomain"` + AppURL string `json:"appUrl"` + CookieDomain string `json:"cookieDomain"` + SubdomainsEnabled bool `json:"subdomainsEnabled"` } type AppContextResponse struct { @@ -159,8 +160,9 @@ func (controller *ContextController) appContextHandler(c *gin.Context) { WarningsEnabled: controller.config.UI.WarningsEnabled, }, App: ACRApp{ - AppURL: controller.runtime.AppURL, - CookieDomain: controller.runtime.CookieDomain, + AppURL: controller.runtime.AppURL, + CookieDomain: controller.runtime.CookieDomain, + SubdomainsEnabled: controller.config.Auth.SubdomainsEnabled, }, }) } diff --git a/internal/utils/app_utils.go b/internal/utils/app_utils.go index 777e380d..5544b3b8 100644 --- a/internal/utils/app_utils.go +++ b/internal/utils/app_utils.go @@ -1,7 +1,7 @@ package utils import ( - "errors" + "fmt" "net" "net/url" "strings" @@ -10,58 +10,44 @@ import ( ) // Get cookie domain parses a hostname and returns the upper domain (e.g. sub1.sub2.domain.com -> sub2.domain.com) -func GetCookieDomain(u string) (string, error) { - parsed, err := url.Parse(u) +func GetCookieDomain(appUrl string) (string, error) { + u, err := url.Parse(appUrl) + if err != nil { - return "", err + return "", fmt.Errorf("invalid app url: %w", err) } - host := parsed.Hostname() + hostname := strings.ToLower(u.Hostname()) - if netIP := net.ParseIP(host); netIP != nil { - return "", errors.New("ip addresses not allowed") + if netIP := net.ParseIP(hostname); netIP != nil { + return "", fmt.Errorf("ip addresses not allowed") } - parts := strings.Split(host, ".") + parts := strings.Split(hostname, ".") + + if len(parts) < 2 { + return "", fmt.Errorf("invalid app url, must be in format subdomain.domain.tld or domain.tld") + } if len(parts) == 2 { - return host, nil + return strings.ToLower(u.Host), nil } - if len(parts) < 3 { - return "", errors.New("invalid app url, must be at least second level domain") - } + // parts > 3 domain := strings.Join(parts[1:], ".") _, err = publicsuffix.DomainFromListWithOptions(publicsuffix.DefaultList, domain, nil) if err != nil { - return "", errors.New("domain in public suffix list, cannot set cookies") + return "", fmt.Errorf("domain in public suffix list, cannot set cookies: %w", err) } - return domain, nil -} + // now that we validated the domain, return with the port + parts = strings.Split(strings.ToLower(u.Host), ":") + domainWithPort := strings.Join(parts[1:], ":") -func GetStandaloneCookieDomain(u string) (string, error) { - parsed, err := url.Parse(u) - if err != nil { - return "", err - } - - host := parsed.Hostname() - - if netIP := net.ParseIP(host); netIP != nil { - return "", errors.New("ip addresses not allowed") - } - - parts := strings.Split(host, ".") - - if len(parts) < 2 { - return "", errors.New("invalid app url") - } - - return host, nil + return domainWithPort, nil } func ParseFileToLine(content string) string { diff --git a/internal/utils/app_utils_test.go b/internal/utils/app_utils_test.go index f0c3625c..c01c88e1 100644 --- a/internal/utils/app_utils_test.go +++ b/internal/utils/app_utils_test.go @@ -125,48 +125,3 @@ func TestFilter(t *testing.T) { resultStr := utils.Filter(sliceStr, testFuncStr) assert.Equal(t, expectedStr, resultStr) } - -func TestGetStandaloneCookieDomain(t *testing.T) { - // Normal case - domain := "http://tinyauth.app" - expected := "tinyauth.app" - result, err := utils.GetStandaloneCookieDomain(domain) - assert.NoError(t, err) - assert.Equal(t, expected, result) - - // URL with subdomain (full hostname is returned, no subdomain stripping) - domain = "http://sub.tinyauth.app" - expected = "sub.tinyauth.app" - result, err = utils.GetStandaloneCookieDomain(domain) - assert.NoError(t, err) - assert.Equal(t, expected, result) - - // URL with port (port should be stripped) - domain = "http://tinyauth.app:8080" - expected = "tinyauth.app" - result, err = utils.GetStandaloneCookieDomain(domain) - assert.NoError(t, err) - assert.Equal(t, expected, result) - - // URL with path - domain = "https://tinyauth.app/some/path" - expected = "tinyauth.app" - result, err = utils.GetStandaloneCookieDomain(domain) - assert.NoError(t, err) - assert.Equal(t, expected, result) - - // IP address - domain = "http://10.10.10.10" - _, err = utils.GetStandaloneCookieDomain(domain) - assert.ErrorContains(t, err, "ip addresses not allowed") - - // Invalid domain (only TLD) - domain = "com" - _, err = utils.GetStandaloneCookieDomain(domain) - assert.ErrorContains(t, err, "invalid app url") - - // Invalid URL - domain = "http://[::1]:namedport" - _, err = utils.GetStandaloneCookieDomain(domain) - assert.ErrorContains(t, err, "parse \"http://[::1]:namedport\": invalid port \":namedport\" after host") -}