feat: add CSRF cookie protection

This commit is contained in:
Stavros
2025-04-14 20:00:58 +03:00
parent eb36b2211b
commit 02faabf688
3 changed files with 46 additions and 10 deletions

View File

@@ -2,7 +2,6 @@ package auth
import ( import (
"fmt" "fmt"
"net/http"
"regexp" "regexp"
"slices" "slices"
"strings" "strings"
@@ -42,7 +41,6 @@ func (auth *Auth) GetSession(c *gin.Context) (*sessions.Session, error) {
MaxAge: auth.Config.SessionExpiry, MaxAge: auth.Config.SessionExpiry,
Secure: auth.Config.CookieSecure, Secure: auth.Config.CookieSecure,
HttpOnly: true, HttpOnly: true,
SameSite: http.SameSiteDefaultMode,
Domain: fmt.Sprintf(".%s", auth.Config.Domain), Domain: fmt.Sprintf(".%s", auth.Config.Domain),
} }

View File

@@ -515,11 +515,17 @@ func (h *Handlers) OauthUrlHandler(c *gin.Context) {
log.Debug().Str("provider", request.Provider).Msg("Got provider") log.Debug().Str("provider", request.Provider).Msg("Got provider")
// Create state
state := provider.GenerateState()
// Get auth URL // Get auth URL
authURL := provider.GetAuthURL() authURL := provider.GetAuthURL(state)
log.Debug().Msg("Got auth URL") log.Debug().Msg("Got auth URL")
// Set CSRF cookie
c.SetCookie("tinyauth-csrf", state, int(time.Hour.Seconds()), "/", "", h.Config.CookieSecure, true)
// Get redirect URI // Get redirect URI
redirectURI := c.Query("redirect_uri") redirectURI := c.Query("redirect_uri")
@@ -553,16 +559,33 @@ func (h *Handlers) OauthCallbackHandler(c *gin.Context) {
log.Debug().Interface("provider", providerName.Provider).Msg("Got provider name") log.Debug().Interface("provider", providerName.Provider).Msg("Got provider name")
// Get code // Get state
code := c.Query("code") state := c.Query("state")
// Code empty so redirect to error // Get CSRF cookie
if code == "" { csrfCookie, err := c.Cookie("tinyauth-csrf")
log.Error().Msg("No code provided")
if err != nil {
log.Debug().Msg("No CSRF cookie")
c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/error", h.Config.AppURL))
return return
} }
log.Debug().Str("csrfCookie", csrfCookie).Msg("Got CSRF cookie")
// Check if CSRF cookie is valid
if csrfCookie != state {
log.Warn().Msg("Invalid CSRF cookie or CSRF cookie does not match with the state")
c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/error", h.Config.AppURL))
return
}
// Clean up CSRF cookie
c.SetCookie("tinyauth-csrf", "", -1, "/", "", h.Config.CookieSecure, true)
// Get code
code := c.Query("code")
log.Debug().Msg("Got code") log.Debug().Msg("Got code")
// Get provider // Get provider

View File

@@ -2,6 +2,8 @@ package oauth
import ( import (
"context" "context"
"crypto/rand"
"encoding/base64"
"net/http" "net/http"
"golang.org/x/oauth2" "golang.org/x/oauth2"
@@ -26,9 +28,9 @@ func (oauth *OAuth) Init() {
oauth.Verifier = oauth2.GenerateVerifier() oauth.Verifier = oauth2.GenerateVerifier()
} }
func (oauth *OAuth) GetAuthURL() string { func (oauth *OAuth) GetAuthURL(state string) string {
// Return the auth url // Return the auth url
return oauth.Config.AuthCodeURL("state", oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(oauth.Verifier)) return oauth.Config.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(oauth.Verifier))
} }
func (oauth *OAuth) ExchangeToken(code string) (string, error) { func (oauth *OAuth) ExchangeToken(code string) (string, error) {
@@ -51,3 +53,16 @@ func (oauth *OAuth) GetClient() *http.Client {
// Return the http client with the token set // Return the http client with the token set
return oauth.Config.Client(oauth.Context, oauth.Token) return oauth.Config.Client(oauth.Context, oauth.Token)
} }
func (oauth *OAuth) GenerateState() string {
// Generate a random state string
b := make([]byte, 128)
// Fill the byte slice with random data
rand.Read(b)
// Encode the byte slice to a base64 string
state := base64.URLEncoding.EncodeToString(b)
return state
}