mirror of
https://github.com/steveiliop56/tinyauth.git
synced 2026-06-25 04:40:15 +00:00
refactor: remove concurrent listeners and rework cookie logic (#950)
This commit is contained in:
@@ -46,7 +46,7 @@ type OAuthPendingSession struct {
|
||||
State string
|
||||
Verifier string
|
||||
Token *oauth2.Token
|
||||
Service *OAuthServiceImpl
|
||||
Service IOAuthService
|
||||
ExpiresAt time.Time
|
||||
CallbackParams OAuthCallbackParams
|
||||
}
|
||||
@@ -380,33 +380,11 @@ func (auth *AuthService) CreateSession(ctx context.Context, data repository.Sess
|
||||
return nil, fmt.Errorf("failed to create session entry: %w", err)
|
||||
}
|
||||
|
||||
if data.Provider == "tailscale" {
|
||||
auth.log.App.Trace().Str("url", fmt.Sprintf("https://%s", auth.tailscale.GetHostname())).Msg("Extracting root domain from Tailscale hostname")
|
||||
|
||||
tsCookieDomain, err := utils.GetCookieDomain(fmt.Sprintf("https://%s", auth.tailscale.GetHostname()))
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get cookie domain for tailscale user: %w", err)
|
||||
}
|
||||
|
||||
return &http.Cookie{
|
||||
Name: auth.runtime.SessionCookieName,
|
||||
Value: session.UUID,
|
||||
Path: "/",
|
||||
Domain: fmt.Sprintf(".%s", tsCookieDomain),
|
||||
Expires: expiresAt,
|
||||
MaxAge: int(time.Until(expiresAt).Seconds()),
|
||||
Secure: auth.config.Auth.SecureCookie,
|
||||
HttpOnly: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
}, nil
|
||||
}
|
||||
|
||||
return &http.Cookie{
|
||||
Name: auth.runtime.SessionCookieName,
|
||||
Value: session.UUID,
|
||||
Path: "/",
|
||||
Domain: fmt.Sprintf(".%s", auth.runtime.CookieDomain),
|
||||
Domain: auth.getCookieDomain(),
|
||||
Expires: expiresAt,
|
||||
MaxAge: int(time.Until(expiresAt).Seconds()),
|
||||
Secure: auth.config.Auth.SecureCookie,
|
||||
@@ -459,7 +437,7 @@ func (auth *AuthService) RefreshSession(ctx context.Context, uuid string) (*http
|
||||
Name: auth.runtime.SessionCookieName,
|
||||
Value: session.UUID,
|
||||
Path: "/",
|
||||
Domain: fmt.Sprintf(".%s", auth.runtime.CookieDomain),
|
||||
Domain: auth.getCookieDomain(),
|
||||
Expires: time.Now().Add(time.Duration(newExpiry-currentTime) * time.Second),
|
||||
MaxAge: int(newExpiry - currentTime),
|
||||
Secure: auth.config.Auth.SecureCookie,
|
||||
@@ -480,7 +458,7 @@ func (auth *AuthService) DeleteSession(ctx context.Context, uuid string) (*http.
|
||||
Name: auth.runtime.SessionCookieName,
|
||||
Value: "",
|
||||
Path: "/",
|
||||
Domain: fmt.Sprintf(".%s", auth.runtime.CookieDomain),
|
||||
Domain: auth.getCookieDomain(),
|
||||
Expires: time.Now(),
|
||||
MaxAge: -1,
|
||||
Secure: auth.config.Auth.SecureCookie,
|
||||
@@ -549,7 +527,7 @@ func (auth *AuthService) NewOAuthSession(serviceName string, params OAuthCallbac
|
||||
session := OAuthPendingSession{
|
||||
State: state,
|
||||
Verifier: verifier,
|
||||
Service: &service,
|
||||
Service: service,
|
||||
ExpiresAt: time.Now().Add(1 * time.Hour),
|
||||
CallbackParams: params,
|
||||
}
|
||||
@@ -566,7 +544,7 @@ func (auth *AuthService) GetOAuthURL(sessionId string) (string, error) {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return (*session.Service).GetAuthURL(session.State, session.Verifier), nil
|
||||
return session.Service.GetAuthURL(session.State, session.Verifier), nil
|
||||
}
|
||||
|
||||
func (auth *AuthService) GetOAuthToken(sessionId string, code string) (*oauth2.Token, error) {
|
||||
@@ -576,7 +554,7 @@ func (auth *AuthService) GetOAuthToken(sessionId string, code string) (*oauth2.T
|
||||
return nil, fmt.Errorf("oauth session not found: %s", sessionId)
|
||||
}
|
||||
|
||||
token, err := (*session.Service).GetToken(code, session.Verifier)
|
||||
token, err := session.Service.GetToken(code, session.Verifier)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to exchange code for token: %w", err)
|
||||
@@ -605,7 +583,7 @@ func (auth *AuthService) GetOAuthUserinfo(sessionId string) (*model.Claims, erro
|
||||
return nil, fmt.Errorf("oauth token not found for session: %s", sessionId)
|
||||
}
|
||||
|
||||
userinfo, err := (*session.Service).GetUserinfo(session.Token)
|
||||
userinfo, err := session.Service.GetUserinfo(session.Token)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get userinfo: %w", err)
|
||||
@@ -614,14 +592,14 @@ func (auth *AuthService) GetOAuthUserinfo(sessionId string) (*model.Claims, erro
|
||||
return userinfo, nil
|
||||
}
|
||||
|
||||
func (auth *AuthService) GetOAuthService(sessionId string) (OAuthServiceImpl, error) {
|
||||
func (auth *AuthService) GetOAuthService(sessionId string) (IOAuthService, error) {
|
||||
session, err := auth.GetOAuthPendingSession(sessionId)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return *session.Service, nil
|
||||
return session.Service, nil
|
||||
}
|
||||
|
||||
func (auth *AuthService) EndOAuthSession(sessionId string) {
|
||||
@@ -726,3 +704,10 @@ func (auth *AuthService) calculateLockdownLimit() int {
|
||||
|
||||
return limit
|
||||
}
|
||||
|
||||
func (auth *AuthService) getCookieDomain() string {
|
||||
if !auth.config.Auth.SubdomainsEnabled {
|
||||
return ""
|
||||
}
|
||||
return auth.runtime.CookieDomain
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user