mirror of
https://github.com/steveiliop56/tinyauth.git
synced 2025-10-28 04:35:40 +00:00
feat: add CSRF cookie protection
This commit is contained in:
@@ -2,7 +2,6 @@ package auth
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"slices"
|
||||
"strings"
|
||||
@@ -42,7 +41,6 @@ func (auth *Auth) GetSession(c *gin.Context) (*sessions.Session, error) {
|
||||
MaxAge: auth.Config.SessionExpiry,
|
||||
Secure: auth.Config.CookieSecure,
|
||||
HttpOnly: true,
|
||||
SameSite: http.SameSiteDefaultMode,
|
||||
Domain: fmt.Sprintf(".%s", auth.Config.Domain),
|
||||
}
|
||||
|
||||
|
||||
@@ -515,11 +515,17 @@ func (h *Handlers) OauthUrlHandler(c *gin.Context) {
|
||||
|
||||
log.Debug().Str("provider", request.Provider).Msg("Got provider")
|
||||
|
||||
// Create state
|
||||
state := provider.GenerateState()
|
||||
|
||||
// Get auth URL
|
||||
authURL := provider.GetAuthURL()
|
||||
authURL := provider.GetAuthURL(state)
|
||||
|
||||
log.Debug().Msg("Got auth URL")
|
||||
|
||||
// Set CSRF cookie
|
||||
c.SetCookie("tinyauth-csrf", state, int(time.Hour.Seconds()), "/", "", h.Config.CookieSecure, true)
|
||||
|
||||
// Get redirect URI
|
||||
redirectURI := c.Query("redirect_uri")
|
||||
|
||||
@@ -553,16 +559,33 @@ func (h *Handlers) OauthCallbackHandler(c *gin.Context) {
|
||||
|
||||
log.Debug().Interface("provider", providerName.Provider).Msg("Got provider name")
|
||||
|
||||
// Get code
|
||||
code := c.Query("code")
|
||||
// Get state
|
||||
state := c.Query("state")
|
||||
|
||||
// Code empty so redirect to error
|
||||
if code == "" {
|
||||
log.Error().Msg("No code provided")
|
||||
// Get CSRF cookie
|
||||
csrfCookie, err := c.Cookie("tinyauth-csrf")
|
||||
|
||||
if err != nil {
|
||||
log.Debug().Msg("No CSRF cookie")
|
||||
c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/error", h.Config.AppURL))
|
||||
return
|
||||
}
|
||||
|
||||
log.Debug().Str("csrfCookie", csrfCookie).Msg("Got CSRF cookie")
|
||||
|
||||
// Check if CSRF cookie is valid
|
||||
if csrfCookie != state {
|
||||
log.Warn().Msg("Invalid CSRF cookie or CSRF cookie does not match with the state")
|
||||
c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/error", h.Config.AppURL))
|
||||
return
|
||||
}
|
||||
|
||||
// Clean up CSRF cookie
|
||||
c.SetCookie("tinyauth-csrf", "", -1, "/", "", h.Config.CookieSecure, true)
|
||||
|
||||
// Get code
|
||||
code := c.Query("code")
|
||||
|
||||
log.Debug().Msg("Got code")
|
||||
|
||||
// Get provider
|
||||
|
||||
@@ -2,6 +2,8 @@ package oauth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"net/http"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
@@ -26,9 +28,9 @@ func (oauth *OAuth) Init() {
|
||||
oauth.Verifier = oauth2.GenerateVerifier()
|
||||
}
|
||||
|
||||
func (oauth *OAuth) GetAuthURL() string {
|
||||
func (oauth *OAuth) GetAuthURL(state string) string {
|
||||
// Return the auth url
|
||||
return oauth.Config.AuthCodeURL("state", oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(oauth.Verifier))
|
||||
return oauth.Config.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(oauth.Verifier))
|
||||
}
|
||||
|
||||
func (oauth *OAuth) ExchangeToken(code string) (string, error) {
|
||||
@@ -51,3 +53,16 @@ func (oauth *OAuth) GetClient() *http.Client {
|
||||
// Return the http client with the token set
|
||||
return oauth.Config.Client(oauth.Context, oauth.Token)
|
||||
}
|
||||
|
||||
func (oauth *OAuth) GenerateState() string {
|
||||
// Generate a random state string
|
||||
b := make([]byte, 128)
|
||||
|
||||
// Fill the byte slice with random data
|
||||
rand.Read(b)
|
||||
|
||||
// Encode the byte slice to a base64 string
|
||||
state := base64.URLEncoding.EncodeToString(b)
|
||||
|
||||
return state
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user