feat: add CSRF cookie protection

This commit is contained in:
Stavros
2025-04-14 20:00:58 +03:00
parent eb36b2211b
commit 02faabf688
3 changed files with 46 additions and 10 deletions

View File

@@ -2,7 +2,6 @@ package auth
import (
"fmt"
"net/http"
"regexp"
"slices"
"strings"
@@ -42,7 +41,6 @@ func (auth *Auth) GetSession(c *gin.Context) (*sessions.Session, error) {
MaxAge: auth.Config.SessionExpiry,
Secure: auth.Config.CookieSecure,
HttpOnly: true,
SameSite: http.SameSiteDefaultMode,
Domain: fmt.Sprintf(".%s", auth.Config.Domain),
}

View File

@@ -515,11 +515,17 @@ func (h *Handlers) OauthUrlHandler(c *gin.Context) {
log.Debug().Str("provider", request.Provider).Msg("Got provider")
// Create state
state := provider.GenerateState()
// Get auth URL
authURL := provider.GetAuthURL()
authURL := provider.GetAuthURL(state)
log.Debug().Msg("Got auth URL")
// Set CSRF cookie
c.SetCookie("tinyauth-csrf", state, int(time.Hour.Seconds()), "/", "", h.Config.CookieSecure, true)
// Get redirect URI
redirectURI := c.Query("redirect_uri")
@@ -553,16 +559,33 @@ func (h *Handlers) OauthCallbackHandler(c *gin.Context) {
log.Debug().Interface("provider", providerName.Provider).Msg("Got provider name")
// Get code
code := c.Query("code")
// Get state
state := c.Query("state")
// Code empty so redirect to error
if code == "" {
log.Error().Msg("No code provided")
// Get CSRF cookie
csrfCookie, err := c.Cookie("tinyauth-csrf")
if err != nil {
log.Debug().Msg("No CSRF cookie")
c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/error", h.Config.AppURL))
return
}
log.Debug().Str("csrfCookie", csrfCookie).Msg("Got CSRF cookie")
// Check if CSRF cookie is valid
if csrfCookie != state {
log.Warn().Msg("Invalid CSRF cookie or CSRF cookie does not match with the state")
c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/error", h.Config.AppURL))
return
}
// Clean up CSRF cookie
c.SetCookie("tinyauth-csrf", "", -1, "/", "", h.Config.CookieSecure, true)
// Get code
code := c.Query("code")
log.Debug().Msg("Got code")
// Get provider

View File

@@ -2,6 +2,8 @@ package oauth
import (
"context"
"crypto/rand"
"encoding/base64"
"net/http"
"golang.org/x/oauth2"
@@ -26,9 +28,9 @@ func (oauth *OAuth) Init() {
oauth.Verifier = oauth2.GenerateVerifier()
}
func (oauth *OAuth) GetAuthURL() string {
func (oauth *OAuth) GetAuthURL(state string) string {
// Return the auth url
return oauth.Config.AuthCodeURL("state", oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(oauth.Verifier))
return oauth.Config.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(oauth.Verifier))
}
func (oauth *OAuth) ExchangeToken(code string) (string, error) {
@@ -51,3 +53,16 @@ func (oauth *OAuth) GetClient() *http.Client {
// Return the http client with the token set
return oauth.Config.Client(oauth.Context, oauth.Token)
}
func (oauth *OAuth) GenerateState() string {
// Generate a random state string
b := make([]byte, 128)
// Fill the byte slice with random data
rand.Read(b)
// Encode the byte slice to a base64 string
state := base64.URLEncoding.EncodeToString(b)
return state
}