mirror of
https://github.com/steveiliop56/tinyauth.git
synced 2026-05-17 09:40:14 +00:00
Merge branch 'main' into feat/deny-by-default-acls
This commit is contained in:
@@ -208,7 +208,12 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
|
||||
name = user.Name
|
||||
} else {
|
||||
controller.log.App.Debug().Msg("No name from OAuth provider, generating from email")
|
||||
name = fmt.Sprintf("%s (%s)", utils.Capitalize(strings.Split(user.Email, "@")[0]), strings.Split(user.Email, "@")[1])
|
||||
parts := strings.SplitN(user.Email, "@", 2)
|
||||
if len(parts) == 2 {
|
||||
name = fmt.Sprintf("%s (%s)", utils.Capitalize(parts[0]), parts[1])
|
||||
} else {
|
||||
name = utils.Capitalize(user.Email)
|
||||
}
|
||||
}
|
||||
|
||||
var username string
|
||||
|
||||
@@ -146,7 +146,7 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
|
||||
client, ok := controller.oidc.GetClient(req.ClientID)
|
||||
|
||||
if !ok {
|
||||
controller.authorizeError(c, err, "Client not found", "The client ID is invalid", "", "", "")
|
||||
controller.authorizeError(c, fmt.Errorf("client not found: %s", req.ClientID), "Client not found", "The client ID is invalid", "", "", "")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -288,7 +288,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
|
||||
entry, err := controller.oidc.GetCodeEntry(c, controller.oidc.Hash(req.Code), client.ClientID)
|
||||
if err != nil {
|
||||
if err := controller.oidc.DeleteTokenByCodeHash(c, controller.oidc.Hash(req.Code)); err != nil {
|
||||
controller.log.App.Error().Err(err).Msg("Failed to delete code")
|
||||
controller.log.App.Error().Err(err).Msg("Failed to revoke tokens for replayed code")
|
||||
}
|
||||
if errors.Is(err, service.ErrCodeNotFound) {
|
||||
controller.log.App.Warn().Msg("Code not found")
|
||||
|
||||
@@ -138,9 +138,9 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
||||
|
||||
if !controller.useBrowserResponse(proxyCtx) {
|
||||
c.Header("x-tinyauth-location", redirectURL)
|
||||
c.JSON(401, gin.H{
|
||||
"status": 401,
|
||||
"message": "Unauthorized",
|
||||
c.JSON(403, gin.H{
|
||||
"status": 403,
|
||||
"message": "Forbidden",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -32,7 +32,7 @@ func (controller *ResourcesController) resourcesHandler(c *gin.Context) {
|
||||
if controller.config.Resources.Path == "" {
|
||||
c.JSON(404, gin.H{
|
||||
"status": 404,
|
||||
"message": "Resources not found",
|
||||
"message": "Resource not found",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -606,46 +606,49 @@ func (auth *AuthService) ensureOAuthSessionLimit() {
|
||||
auth.oauthMutex.Lock()
|
||||
defer auth.oauthMutex.Unlock()
|
||||
|
||||
if len(auth.oauthPendingSessions) >= MaxOAuthPendingSessions {
|
||||
if len(auth.oauthPendingSessions) <= MaxOAuthPendingSessions {
|
||||
return
|
||||
}
|
||||
|
||||
cleanupIds := make([]string, 0, OAuthCleanupCount)
|
||||
type entry struct {
|
||||
id string
|
||||
expiresAt int64
|
||||
}
|
||||
|
||||
for range OAuthCleanupCount {
|
||||
oldestId := ""
|
||||
oldestTime := int64(0)
|
||||
entries := make([]entry, 0, len(auth.oauthPendingSessions))
|
||||
for id, session := range auth.oauthPendingSessions {
|
||||
entries = append(entries, entry{id, session.ExpiresAt.Unix()})
|
||||
}
|
||||
|
||||
for id, session := range auth.oauthPendingSessions {
|
||||
if oldestTime == 0 {
|
||||
oldestId = id
|
||||
oldestTime = session.ExpiresAt.Unix()
|
||||
continue
|
||||
}
|
||||
if slices.Contains(cleanupIds, id) {
|
||||
continue
|
||||
}
|
||||
if session.ExpiresAt.Unix() < oldestTime {
|
||||
oldestId = id
|
||||
oldestTime = session.ExpiresAt.Unix()
|
||||
}
|
||||
}
|
||||
|
||||
cleanupIds = append(cleanupIds, oldestId)
|
||||
slices.SortFunc(entries, func(a, b entry) int {
|
||||
if a.expiresAt < b.expiresAt {
|
||||
return -1
|
||||
}
|
||||
|
||||
for _, id := range cleanupIds {
|
||||
delete(auth.oauthPendingSessions, id)
|
||||
if a.expiresAt > b.expiresAt {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
})
|
||||
|
||||
for _, e := range entries[:OAuthCleanupCount] {
|
||||
delete(auth.oauthPendingSessions, e.id)
|
||||
}
|
||||
}
|
||||
|
||||
func (auth *AuthService) lockdownMode() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
auth.lockdownCtx = ctx
|
||||
auth.lockdownCancelFunc = cancel
|
||||
|
||||
auth.loginMutex.Lock()
|
||||
|
||||
if auth.lockdown != nil && auth.lockdown.Active {
|
||||
auth.loginMutex.Unlock()
|
||||
cancel()
|
||||
return
|
||||
}
|
||||
|
||||
auth.lockdownCtx = ctx
|
||||
auth.lockdownCancelFunc = cancel
|
||||
|
||||
auth.log.App.Warn().Msg("Too many failed login attempts, entering lockdown mode")
|
||||
|
||||
auth.lockdown = &Lockdown{
|
||||
@@ -658,10 +661,12 @@ func (auth *AuthService) lockdownMode() {
|
||||
auth.loginAttempts = make(map[string]*LoginAttempt)
|
||||
|
||||
timer := time.NewTimer(time.Until(auth.lockdown.ActiveUntil))
|
||||
defer timer.Stop()
|
||||
|
||||
auth.loginMutex.Unlock()
|
||||
|
||||
defer cancel()
|
||||
defer timer.Stop()
|
||||
|
||||
select {
|
||||
case <-timer.C:
|
||||
// Timer expired, end lockdown
|
||||
|
||||
@@ -26,6 +26,7 @@ func NewOAuthService(config model.OAuthServiceConfig, id string, ctx context.Con
|
||||
Transport: &http.Transport{
|
||||
TLSClientConfig: &tls.Config{
|
||||
InsecureSkipVerify: config.Insecure,
|
||||
MinVersion: tls.VersionTLS12,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -121,7 +121,7 @@ type OIDCService struct {
|
||||
|
||||
clients map[string]model.OIDCClientConfig
|
||||
privateKey *rsa.PrivateKey
|
||||
publicKey crypto.PublicKey
|
||||
publicKey *rsa.PublicKey
|
||||
issuer string
|
||||
}
|
||||
|
||||
@@ -239,6 +239,16 @@ func NewOIDCService(
|
||||
}
|
||||
}
|
||||
|
||||
rPublicKey, ok := publicKey.(*rsa.PublicKey)
|
||||
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("public key is not an rsa public key")
|
||||
}
|
||||
|
||||
if rPublicKey.N.Cmp(privateKey.N) != 0 || rPublicKey.E != privateKey.E {
|
||||
return nil, fmt.Errorf("public key does not pair with private key")
|
||||
}
|
||||
|
||||
// We will reorganize the client into a map with the client ID as the key
|
||||
clients := make(map[string]model.OIDCClientConfig)
|
||||
|
||||
@@ -271,7 +281,7 @@ func NewOIDCService(
|
||||
|
||||
clients: clients,
|
||||
privateKey: privateKey,
|
||||
publicKey: publicKey,
|
||||
publicKey: rPublicKey,
|
||||
issuer: issuer,
|
||||
}
|
||||
|
||||
@@ -297,6 +307,11 @@ func (service *OIDCService) ValidateAuthorizeParams(req AuthorizeRequest) error
|
||||
return errors.New("access_denied")
|
||||
}
|
||||
|
||||
// Redirect URI to verify that it's trusted
|
||||
if !slices.Contains(client.TrustedRedirectURIs, req.RedirectURI) {
|
||||
return errors.New("invalid_request_uri")
|
||||
}
|
||||
|
||||
// Scopes
|
||||
scopes := strings.Split(req.Scope, " ")
|
||||
|
||||
@@ -318,11 +333,6 @@ func (service *OIDCService) ValidateAuthorizeParams(req AuthorizeRequest) error
|
||||
return errors.New("unsupported_response_type")
|
||||
}
|
||||
|
||||
// Redirect URI
|
||||
if !slices.Contains(client.TrustedRedirectURIs, req.RedirectURI) {
|
||||
return errors.New("invalid_request_uri")
|
||||
}
|
||||
|
||||
// PKCE code challenge method if set
|
||||
if req.CodeChallenge != "" && req.CodeChallengeMethod != "" {
|
||||
if req.CodeChallengeMethod != "S256" && req.CodeChallengeMethod != "plain" {
|
||||
@@ -455,7 +465,7 @@ func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user
|
||||
|
||||
hasher := sha256.New()
|
||||
|
||||
der := x509.MarshalPKCS1PublicKey(&service.privateKey.PublicKey)
|
||||
der := x509.MarshalPKCS1PublicKey(service.publicKey)
|
||||
|
||||
if der == nil {
|
||||
return "", errors.New("failed to marshal public key")
|
||||
@@ -813,7 +823,7 @@ func (service *OIDCService) cleanupRoutine() {
|
||||
func (service *OIDCService) GetJWK() ([]byte, error) {
|
||||
hasher := sha256.New()
|
||||
|
||||
der := x509.MarshalPKCS1PublicKey(&service.privateKey.PublicKey)
|
||||
der := x509.MarshalPKCS1PublicKey(service.publicKey)
|
||||
|
||||
if der == nil {
|
||||
return nil, errors.New("failed to marshal public key")
|
||||
@@ -822,13 +832,13 @@ func (service *OIDCService) GetJWK() ([]byte, error) {
|
||||
hasher.Write(der)
|
||||
|
||||
jwk := jose.JSONWebKey{
|
||||
Key: service.privateKey,
|
||||
Key: service.publicKey,
|
||||
Algorithm: string(jose.RS256),
|
||||
Use: "sig",
|
||||
KeyID: base64.URLEncoding.EncodeToString(hasher.Sum(nil)),
|
||||
}
|
||||
|
||||
return jwk.Public().MarshalJSON()
|
||||
return jwk.MarshalJSON()
|
||||
}
|
||||
|
||||
func (service *OIDCService) ValidatePKCE(codeChallenge string, codeVerifier string) bool {
|
||||
|
||||
Reference in New Issue
Block a user