From 02faabf688e78671a7bf1fc9601bd7b81787f13f Mon Sep 17 00:00:00 2001 From: Stavros Date: Mon, 14 Apr 2025 20:00:58 +0300 Subject: [PATCH] feat: add CSRF cookie protection --- internal/auth/auth.go | 2 -- internal/handlers/handlers.go | 35 +++++++++++++++++++++++++++++------ internal/oauth/oauth.go | 19 +++++++++++++++++-- 3 files changed, 46 insertions(+), 10 deletions(-) diff --git a/internal/auth/auth.go b/internal/auth/auth.go index 594feb9..521c3c8 100644 --- a/internal/auth/auth.go +++ b/internal/auth/auth.go @@ -2,7 +2,6 @@ package auth import ( "fmt" - "net/http" "regexp" "slices" "strings" @@ -42,7 +41,6 @@ func (auth *Auth) GetSession(c *gin.Context) (*sessions.Session, error) { MaxAge: auth.Config.SessionExpiry, Secure: auth.Config.CookieSecure, HttpOnly: true, - SameSite: http.SameSiteDefaultMode, Domain: fmt.Sprintf(".%s", auth.Config.Domain), } diff --git a/internal/handlers/handlers.go b/internal/handlers/handlers.go index 7e15ca3..99739f9 100644 --- a/internal/handlers/handlers.go +++ b/internal/handlers/handlers.go @@ -515,11 +515,17 @@ func (h *Handlers) OauthUrlHandler(c *gin.Context) { log.Debug().Str("provider", request.Provider).Msg("Got provider") + // Create state + state := provider.GenerateState() + // Get auth URL - authURL := provider.GetAuthURL() + authURL := provider.GetAuthURL(state) log.Debug().Msg("Got auth URL") + // Set CSRF cookie + c.SetCookie("tinyauth-csrf", state, int(time.Hour.Seconds()), "/", "", h.Config.CookieSecure, true) + // Get redirect URI redirectURI := c.Query("redirect_uri") @@ -553,16 +559,33 @@ func (h *Handlers) OauthCallbackHandler(c *gin.Context) { log.Debug().Interface("provider", providerName.Provider).Msg("Got provider name") - // Get code - code := c.Query("code") + // Get state + state := c.Query("state") - // Code empty so redirect to error - if code == "" { - log.Error().Msg("No code provided") + // Get CSRF cookie + csrfCookie, err := c.Cookie("tinyauth-csrf") + + if err != nil { + log.Debug().Msg("No CSRF cookie") c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) return } + log.Debug().Str("csrfCookie", csrfCookie).Msg("Got CSRF cookie") + + // Check if CSRF cookie is valid + if csrfCookie != state { + log.Warn().Msg("Invalid CSRF cookie or CSRF cookie does not match with the state") + c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) + return + } + + // Clean up CSRF cookie + c.SetCookie("tinyauth-csrf", "", -1, "/", "", h.Config.CookieSecure, true) + + // Get code + code := c.Query("code") + log.Debug().Msg("Got code") // Get provider diff --git a/internal/oauth/oauth.go b/internal/oauth/oauth.go index 86ca010..e37371f 100644 --- a/internal/oauth/oauth.go +++ b/internal/oauth/oauth.go @@ -2,6 +2,8 @@ package oauth import ( "context" + "crypto/rand" + "encoding/base64" "net/http" "golang.org/x/oauth2" @@ -26,9 +28,9 @@ func (oauth *OAuth) Init() { oauth.Verifier = oauth2.GenerateVerifier() } -func (oauth *OAuth) GetAuthURL() string { +func (oauth *OAuth) GetAuthURL(state string) string { // Return the auth url - return oauth.Config.AuthCodeURL("state", oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(oauth.Verifier)) + return oauth.Config.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(oauth.Verifier)) } func (oauth *OAuth) ExchangeToken(code string) (string, error) { @@ -51,3 +53,16 @@ func (oauth *OAuth) GetClient() *http.Client { // Return the http client with the token set return oauth.Config.Client(oauth.Context, oauth.Token) } + +func (oauth *OAuth) GenerateState() string { + // Generate a random state string + b := make([]byte, 128) + + // Fill the byte slice with random data + rand.Read(b) + + // Encode the byte slice to a base64 string + state := base64.URLEncoding.EncodeToString(b) + + return state +}