mirror of
https://github.com/steveiliop56/tinyauth.git
synced 2025-10-28 04:35:40 +00:00
feat: add CSRF cookie protection
This commit is contained in:
@@ -2,7 +2,6 @@ package auth
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
|
||||||
"regexp"
|
"regexp"
|
||||||
"slices"
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -42,7 +41,6 @@ func (auth *Auth) GetSession(c *gin.Context) (*sessions.Session, error) {
|
|||||||
MaxAge: auth.Config.SessionExpiry,
|
MaxAge: auth.Config.SessionExpiry,
|
||||||
Secure: auth.Config.CookieSecure,
|
Secure: auth.Config.CookieSecure,
|
||||||
HttpOnly: true,
|
HttpOnly: true,
|
||||||
SameSite: http.SameSiteDefaultMode,
|
|
||||||
Domain: fmt.Sprintf(".%s", auth.Config.Domain),
|
Domain: fmt.Sprintf(".%s", auth.Config.Domain),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -515,11 +515,17 @@ func (h *Handlers) OauthUrlHandler(c *gin.Context) {
|
|||||||
|
|
||||||
log.Debug().Str("provider", request.Provider).Msg("Got provider")
|
log.Debug().Str("provider", request.Provider).Msg("Got provider")
|
||||||
|
|
||||||
|
// Create state
|
||||||
|
state := provider.GenerateState()
|
||||||
|
|
||||||
// Get auth URL
|
// Get auth URL
|
||||||
authURL := provider.GetAuthURL()
|
authURL := provider.GetAuthURL(state)
|
||||||
|
|
||||||
log.Debug().Msg("Got auth URL")
|
log.Debug().Msg("Got auth URL")
|
||||||
|
|
||||||
|
// Set CSRF cookie
|
||||||
|
c.SetCookie("tinyauth-csrf", state, int(time.Hour.Seconds()), "/", "", h.Config.CookieSecure, true)
|
||||||
|
|
||||||
// Get redirect URI
|
// Get redirect URI
|
||||||
redirectURI := c.Query("redirect_uri")
|
redirectURI := c.Query("redirect_uri")
|
||||||
|
|
||||||
@@ -553,16 +559,33 @@ func (h *Handlers) OauthCallbackHandler(c *gin.Context) {
|
|||||||
|
|
||||||
log.Debug().Interface("provider", providerName.Provider).Msg("Got provider name")
|
log.Debug().Interface("provider", providerName.Provider).Msg("Got provider name")
|
||||||
|
|
||||||
// Get code
|
// Get state
|
||||||
code := c.Query("code")
|
state := c.Query("state")
|
||||||
|
|
||||||
// Code empty so redirect to error
|
// Get CSRF cookie
|
||||||
if code == "" {
|
csrfCookie, err := c.Cookie("tinyauth-csrf")
|
||||||
log.Error().Msg("No code provided")
|
|
||||||
|
if err != nil {
|
||||||
|
log.Debug().Msg("No CSRF cookie")
|
||||||
c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/error", h.Config.AppURL))
|
c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/error", h.Config.AppURL))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
log.Debug().Str("csrfCookie", csrfCookie).Msg("Got CSRF cookie")
|
||||||
|
|
||||||
|
// Check if CSRF cookie is valid
|
||||||
|
if csrfCookie != state {
|
||||||
|
log.Warn().Msg("Invalid CSRF cookie or CSRF cookie does not match with the state")
|
||||||
|
c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/error", h.Config.AppURL))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clean up CSRF cookie
|
||||||
|
c.SetCookie("tinyauth-csrf", "", -1, "/", "", h.Config.CookieSecure, true)
|
||||||
|
|
||||||
|
// Get code
|
||||||
|
code := c.Query("code")
|
||||||
|
|
||||||
log.Debug().Msg("Got code")
|
log.Debug().Msg("Got code")
|
||||||
|
|
||||||
// Get provider
|
// Get provider
|
||||||
|
|||||||
@@ -2,6 +2,8 @@ package oauth
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/base64"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
@@ -26,9 +28,9 @@ func (oauth *OAuth) Init() {
|
|||||||
oauth.Verifier = oauth2.GenerateVerifier()
|
oauth.Verifier = oauth2.GenerateVerifier()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (oauth *OAuth) GetAuthURL() string {
|
func (oauth *OAuth) GetAuthURL(state string) string {
|
||||||
// Return the auth url
|
// Return the auth url
|
||||||
return oauth.Config.AuthCodeURL("state", oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(oauth.Verifier))
|
return oauth.Config.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(oauth.Verifier))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (oauth *OAuth) ExchangeToken(code string) (string, error) {
|
func (oauth *OAuth) ExchangeToken(code string) (string, error) {
|
||||||
@@ -51,3 +53,16 @@ func (oauth *OAuth) GetClient() *http.Client {
|
|||||||
// Return the http client with the token set
|
// Return the http client with the token set
|
||||||
return oauth.Config.Client(oauth.Context, oauth.Token)
|
return oauth.Config.Client(oauth.Context, oauth.Token)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (oauth *OAuth) GenerateState() string {
|
||||||
|
// Generate a random state string
|
||||||
|
b := make([]byte, 128)
|
||||||
|
|
||||||
|
// Fill the byte slice with random data
|
||||||
|
rand.Read(b)
|
||||||
|
|
||||||
|
// Encode the byte slice to a base64 string
|
||||||
|
state := base64.URLEncoding.EncodeToString(b)
|
||||||
|
|
||||||
|
return state
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user