refactor: rework cookie domain logic

This commit is contained in:
Stavros
2026-06-21 16:51:39 +03:00
parent 8c739c68e3
commit 21877190e4
7 changed files with 50 additions and 94 deletions
+6 -4
View File
@@ -9,6 +9,7 @@ type IuseRedirectUri = {
export const useRedirectUri = ( export const useRedirectUri = (
redirect_uri: string | undefined, redirect_uri: string | undefined,
cookieDomain: string, cookieDomain: string,
subdomainsEnabled: boolean,
): IuseRedirectUri => { ): IuseRedirectUri => {
let isValid = false; let isValid = false;
let isTrusted = false; let isTrusted = false;
@@ -39,10 +40,11 @@ export const useRedirectUri = (
isValid = true; isValid = true;
if ( if (url.hostname == cookieDomain) {
url.hostname == cookieDomain || isTrusted = true;
url.hostname.endsWith(`.${cookieDomain}`) }
) {
if (subdomainsEnabled && url.hostname.endsWith("." + cookieDomain)) {
isTrusted = true; isTrusted = true;
} }
+6 -1
View File
@@ -37,6 +37,7 @@ export const ContinuePage = () => {
const { url, valid, trusted, allowedProto, httpsDowngrade } = useRedirectUri( const { url, valid, trusted, allowedProto, httpsDowngrade } = useRedirectUri(
redirectUri, redirectUri,
app.cookieDomain, app.cookieDomain,
app.subdomainsEnabled,
); );
const urlHref = url?.href; const urlHref = url?.href;
@@ -108,7 +109,11 @@ export const ContinuePage = () => {
components={{ components={{
code: <code />, code: <code />,
}} }}
values={{ cookieDomain: app.cookieDomain }} values={{
cookieDomain: app.subdomainsEnabled
? `.${app.cookieDomain}`
: app.cookieDomain,
}}
shouldUnescape={true} shouldUnescape={true}
/> />
</CardDescription> </CardDescription>
@@ -24,6 +24,7 @@ const uiSchema = z.object({
const appSchema = z.object({ const appSchema = z.object({
appUrl: z.string(), appUrl: z.string(),
cookieDomain: z.string(), cookieDomain: z.string(),
subdomainsEnabled: z.boolean(),
}); });
export const appContextSchema = z.object({ export const appContextSchema = z.object({
+11 -6
View File
@@ -97,7 +97,7 @@ func (app *BootstrapApp) Setup() error {
return fmt.Errorf("failed to parse app url: %w", err) return fmt.Errorf("failed to parse app url: %w", err)
} }
app.runtime.AppURL = appUrl.Scheme + "://" + appUrl.Host app.runtime.AppURL = strings.ToLower(appUrl.Scheme + "://" + appUrl.Host)
// validate session config // validate session config
if app.config.Auth.SessionMaxLifetime != 0 && app.config.Auth.SessionMaxLifetime < app.config.Auth.SessionExpiry { if app.config.Auth.SessionMaxLifetime != 0 && app.config.Auth.SessionMaxLifetime < app.config.Auth.SessionExpiry {
@@ -162,14 +162,11 @@ func (app *BootstrapApp) Setup() error {
} }
// cookie domain // cookie domain
cookieDomainResolver := utils.GetCookieDomain
if !app.config.Auth.SubdomainsEnabled { if !app.config.Auth.SubdomainsEnabled {
app.log.App.Warn().Msg("Subdomains are disabled, using standalone cookie domain resolver which will not work with subdomains") app.log.App.Warn().Msg("Subdomains are disabled, cookies will be set for the current domain only")
cookieDomainResolver = utils.GetStandaloneCookieDomain
} }
cookieDomain, err := cookieDomainResolver(app.runtime.AppURL) cookieDomain, err := utils.GetCookieDomain(app.runtime.AppURL)
if err != nil { if err != nil {
return fmt.Errorf("failed to get cookie domain: %w", err) return fmt.Errorf("failed to get cookie domain: %w", err)
@@ -290,6 +287,14 @@ func (app *BootstrapApp) Setup() error {
if tailscaleUrl != app.runtime.AppURL { if tailscaleUrl != app.runtime.AppURL {
app.log.App.Info().Msg("Listening on tailscale, replacing app url with tailscale hostname") app.log.App.Info().Msg("Listening on tailscale, replacing app url with tailscale hostname")
app.runtime.AppURL = tailscaleUrl app.runtime.AppURL = tailscaleUrl
// also update cookie domain
cookieDomain, err := utils.GetCookieDomain(tailscaleUrl)
if err != nil {
return fmt.Errorf("failed to get cookie domain: %w", err)
}
app.runtime.CookieDomain = cookieDomain
} }
} }
@@ -60,6 +60,7 @@ type ACRUI struct {
type ACRApp struct { type ACRApp struct {
AppURL string `json:"appUrl"` AppURL string `json:"appUrl"`
CookieDomain string `json:"cookieDomain"` CookieDomain string `json:"cookieDomain"`
SubdomainsEnabled bool `json:"subdomainsEnabled"`
} }
type AppContextResponse struct { type AppContextResponse struct {
@@ -161,6 +162,7 @@ func (controller *ContextController) appContextHandler(c *gin.Context) {
App: ACRApp{ App: ACRApp{
AppURL: controller.runtime.AppURL, AppURL: controller.runtime.AppURL,
CookieDomain: controller.runtime.CookieDomain, CookieDomain: controller.runtime.CookieDomain,
SubdomainsEnabled: controller.config.Auth.SubdomainsEnabled,
}, },
}) })
} }
+20 -34
View File
@@ -1,7 +1,7 @@
package utils package utils
import ( import (
"errors" "fmt"
"net" "net"
"net/url" "net/url"
"strings" "strings"
@@ -10,58 +10,44 @@ import (
) )
// Get cookie domain parses a hostname and returns the upper domain (e.g. sub1.sub2.domain.com -> sub2.domain.com) // Get cookie domain parses a hostname and returns the upper domain (e.g. sub1.sub2.domain.com -> sub2.domain.com)
func GetCookieDomain(u string) (string, error) { func GetCookieDomain(appUrl string) (string, error) {
parsed, err := url.Parse(u) u, err := url.Parse(appUrl)
if err != nil { if err != nil {
return "", err return "", fmt.Errorf("invalid app url: %w", err)
} }
host := parsed.Hostname() hostname := strings.ToLower(u.Hostname())
if netIP := net.ParseIP(host); netIP != nil { if netIP := net.ParseIP(hostname); netIP != nil {
return "", errors.New("ip addresses not allowed") return "", fmt.Errorf("ip addresses not allowed")
} }
parts := strings.Split(host, ".") parts := strings.Split(hostname, ".")
if len(parts) < 2 {
return "", fmt.Errorf("invalid app url, must be in format subdomain.domain.tld or domain.tld")
}
if len(parts) == 2 { if len(parts) == 2 {
return host, nil return strings.ToLower(u.Host), nil
} }
if len(parts) < 3 { // parts > 3
return "", errors.New("invalid app url, must be at least second level domain")
}
domain := strings.Join(parts[1:], ".") domain := strings.Join(parts[1:], ".")
_, err = publicsuffix.DomainFromListWithOptions(publicsuffix.DefaultList, domain, nil) _, err = publicsuffix.DomainFromListWithOptions(publicsuffix.DefaultList, domain, nil)
if err != nil { if err != nil {
return "", errors.New("domain in public suffix list, cannot set cookies") return "", fmt.Errorf("domain in public suffix list, cannot set cookies: %w", err)
} }
return domain, nil // now that we validated the domain, return with the port
} parts = strings.Split(strings.ToLower(u.Host), ":")
domainWithPort := strings.Join(parts[1:], ":")
func GetStandaloneCookieDomain(u string) (string, error) { return domainWithPort, nil
parsed, err := url.Parse(u)
if err != nil {
return "", err
}
host := parsed.Hostname()
if netIP := net.ParseIP(host); netIP != nil {
return "", errors.New("ip addresses not allowed")
}
parts := strings.Split(host, ".")
if len(parts) < 2 {
return "", errors.New("invalid app url")
}
return host, nil
} }
func ParseFileToLine(content string) string { func ParseFileToLine(content string) string {
-45
View File
@@ -125,48 +125,3 @@ func TestFilter(t *testing.T) {
resultStr := utils.Filter(sliceStr, testFuncStr) resultStr := utils.Filter(sliceStr, testFuncStr)
assert.Equal(t, expectedStr, resultStr) assert.Equal(t, expectedStr, resultStr)
} }
func TestGetStandaloneCookieDomain(t *testing.T) {
// Normal case
domain := "http://tinyauth.app"
expected := "tinyauth.app"
result, err := utils.GetStandaloneCookieDomain(domain)
assert.NoError(t, err)
assert.Equal(t, expected, result)
// URL with subdomain (full hostname is returned, no subdomain stripping)
domain = "http://sub.tinyauth.app"
expected = "sub.tinyauth.app"
result, err = utils.GetStandaloneCookieDomain(domain)
assert.NoError(t, err)
assert.Equal(t, expected, result)
// URL with port (port should be stripped)
domain = "http://tinyauth.app:8080"
expected = "tinyauth.app"
result, err = utils.GetStandaloneCookieDomain(domain)
assert.NoError(t, err)
assert.Equal(t, expected, result)
// URL with path
domain = "https://tinyauth.app/some/path"
expected = "tinyauth.app"
result, err = utils.GetStandaloneCookieDomain(domain)
assert.NoError(t, err)
assert.Equal(t, expected, result)
// IP address
domain = "http://10.10.10.10"
_, err = utils.GetStandaloneCookieDomain(domain)
assert.ErrorContains(t, err, "ip addresses not allowed")
// Invalid domain (only TLD)
domain = "com"
_, err = utils.GetStandaloneCookieDomain(domain)
assert.ErrorContains(t, err, "invalid app url")
// Invalid URL
domain = "http://[::1]:namedport"
_, err = utils.GetStandaloneCookieDomain(domain)
assert.ErrorContains(t, err, "parse \"http://[::1]:namedport\": invalid port \":namedport\" after host")
}