diff --git a/cmd/root.go b/cmd/root.go index 776c3cc..b5d76b5 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -95,6 +95,8 @@ var rootCmd = &cobra.Command{ DisableContinue: config.DisableContinue, Title: config.Title, GenericName: config.GenericName, + CookieSecure: config.CookieSecure, + Domain: domain, } // Create api config diff --git a/internal/auth/auth.go b/internal/auth/auth.go index d4a73ac..594feb9 100644 --- a/internal/auth/auth.go +++ b/internal/auth/auth.go @@ -178,7 +178,6 @@ func (auth *Auth) CreateSessionCookie(c *gin.Context, data *types.SessionCookie) session.Values["provider"] = data.Provider session.Values["expiry"] = time.Now().Add(time.Duration(sessionExpiry) * time.Second).Unix() session.Values["totpPending"] = data.TotpPending - session.Values["redirectURI"] = data.RedirectURI // Save session err = session.Save(c.Request, c.Writer) @@ -230,11 +229,10 @@ func (auth *Auth) GetSessionCookie(c *gin.Context) (types.SessionCookie, error) // Get data from session username, usernameOk := session.Values["username"].(string) provider, providerOK := session.Values["provider"].(string) - redirectURI, redirectOK := session.Values["redirectURI"].(string) expiry, expiryOk := session.Values["expiry"].(int64) totpPending, totpPendingOk := session.Values["totpPending"].(bool) - if !usernameOk || !providerOK || !expiryOk || !redirectOK || !totpPendingOk { + if !usernameOk || !providerOK || !expiryOk || !totpPendingOk { log.Warn().Msg("Session cookie is missing data") return types.SessionCookie{}, nil } @@ -257,7 +255,6 @@ func (auth *Auth) GetSessionCookie(c *gin.Context) (types.SessionCookie, error) Username: username, Provider: provider, TotpPending: totpPending, - RedirectURI: redirectURI, }, nil } diff --git a/internal/handlers/handlers.go b/internal/handlers/handlers.go index f839f64..d4b0308 100644 --- a/internal/handlers/handlers.go +++ b/internal/handlers/handlers.go @@ -4,6 +4,7 @@ import ( "fmt" "net/http" "strings" + "time" "tinyauth/internal/auth" "tinyauth/internal/docker" "tinyauth/internal/hooks" @@ -525,9 +526,7 @@ func (h *Handlers) OauthUrlHandler(c *gin.Context) { // Set redirect cookie if redirect URI is provided if redirectURI != "" { log.Debug().Str("redirectURI", redirectURI).Msg("Setting redirect cookie") - h.Auth.CreateSessionCookie(c, &types.SessionCookie{ - RedirectURI: redirectURI, - }) + c.SetCookie("tinyauth-redirect", redirectURI, int(time.Hour.Seconds()), "/", "", h.Config.CookieSecure, true) } // Return auth URL @@ -623,25 +622,26 @@ func (h *Handlers) OauthCallbackHandler(c *gin.Context) { log.Debug().Msg("Email whitelisted") - // Get redirect URI - cookie, err := h.Auth.GetSessionCookie(c) - // Create session cookie (also cleans up redirect cookie) h.Auth.CreateSessionCookie(c, &types.SessionCookie{ Username: email, Provider: providerName.Provider, }) - // If it is empty it means that no redirect_uri was provided to the login screen so we just log in + // Check if we have a redirect URI + redirectCookie, err := c.Cookie("tinyauth-redirect") + if err != nil { + log.Debug().Msg("No redirect cookie") c.Redirect(http.StatusPermanentRedirect, h.Config.AppURL) + return } - log.Debug().Str("redirectURI", cookie.RedirectURI).Msg("Got redirect URI") + log.Debug().Str("redirectURI", redirectCookie).Msg("Got redirect URI") // Build query queries, err := query.Values(types.LoginQuery{ - RedirectURI: cookie.RedirectURI, + RedirectURI: redirectCookie, }) log.Debug().Msg("Got redirect query") diff --git a/internal/types/config.go b/internal/types/config.go index bbab8cc..13730b4 100644 --- a/internal/types/config.go +++ b/internal/types/config.go @@ -37,6 +37,8 @@ type Config struct { // Server configuration type HandlersConfig struct { AppURL string + Domain string + CookieSecure bool DisableContinue bool GenericName string Title string diff --git a/internal/types/types.go b/internal/types/types.go index 652779a..19d877d 100644 --- a/internal/types/types.go +++ b/internal/types/types.go @@ -27,7 +27,6 @@ type SessionCookie struct { Username string Provider string TotpPending bool - RedirectURI string } // TinyauthLabels is the labels for the tinyauth container