fix: review comments

This commit is contained in:
Stavros
2026-01-24 16:16:26 +02:00
parent 71bc3966bc
commit cf1a613229
10 changed files with 124 additions and 117 deletions

View File

@@ -4,6 +4,7 @@ import (
"crypto"
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"crypto/x509"
"database/sql"
"encoding/pem"
@@ -245,8 +246,8 @@ func (service *OIDCService) StoreCode(c *gin.Context, sub string, code string, r
// Insert the code into the database
_, err := service.queries.CreateOidcCode(c, repository.CreateOidcCodeParams{
Sub: sub,
Code: code,
Sub: sub,
CodeHash: service.Hash(code),
// Here it's safe to split and trust the output since, we validated the scopes before
Scope: strings.Join(service.filterScopes(strings.Split(req.Scope, " ")), ","),
RedirectURI: req.RedirectURI,
@@ -288,8 +289,8 @@ func (service *OIDCService) ValidateGrantType(grantType string) error {
return nil
}
func (service *OIDCService) GetCodeEntry(c *gin.Context, code string) (repository.OidcCode, error) {
oidcCode, err := service.queries.GetOidcCode(c, code)
func (service *OIDCService) GetCodeEntry(c *gin.Context, codeHash string) (repository.OidcCode, error) {
oidcCode, err := service.queries.GetOidcCode(c, codeHash)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
@@ -299,7 +300,7 @@ func (service *OIDCService) GetCodeEntry(c *gin.Context, code string) (repositor
}
if time.Now().Unix() > oidcCode.ExpiresAt {
err = service.queries.DeleteOidcCode(c, code)
err = service.queries.DeleteOidcCode(c, codeHash)
if err != nil {
return repository.OidcCode{}, err
}
@@ -360,10 +361,10 @@ func (service *OIDCService) GenerateAccessToken(c *gin.Context, client config.OI
}
_, err = service.queries.CreateOidcToken(c, repository.CreateOidcTokenParams{
Sub: sub,
AccessToken: accessToken,
Scope: scope,
ExpiresAt: expiresAt,
Sub: sub,
AccessTokenHash: service.Hash(accessToken),
Scope: scope,
ExpiresAt: expiresAt,
})
if err != nil {
@@ -373,20 +374,20 @@ func (service *OIDCService) GenerateAccessToken(c *gin.Context, client config.OI
return tokenResponse, nil
}
func (service *OIDCService) DeleteCodeEntry(c *gin.Context, code string) error {
return service.queries.DeleteOidcCode(c, code)
func (service *OIDCService) DeleteCodeEntry(c *gin.Context, codeHash string) error {
return service.queries.DeleteOidcCode(c, codeHash)
}
func (service *OIDCService) DeleteUserinfo(c *gin.Context, sub string) error {
return service.queries.DeleteOidcUserInfo(c, sub)
}
func (service *OIDCService) DeleteToken(c *gin.Context, token string) error {
return service.queries.DeleteOidcToken(c, token)
func (service *OIDCService) DeleteToken(c *gin.Context, tokenHash string) error {
return service.queries.DeleteOidcToken(c, tokenHash)
}
func (service *OIDCService) GetAccessToken(c *gin.Context, token string) (repository.OidcToken, error) {
entry, err := service.queries.GetOidcToken(c, token)
func (service *OIDCService) GetAccessToken(c *gin.Context, tokenHash string) (repository.OidcToken, error) {
entry, err := service.queries.GetOidcToken(c, tokenHash)
if err != nil {
if err == sql.ErrNoRows {
@@ -396,7 +397,7 @@ func (service *OIDCService) GetAccessToken(c *gin.Context, token string) (reposi
}
if entry.ExpiresAt < time.Now().Unix() {
err := service.DeleteToken(c, token)
err := service.DeleteToken(c, tokenHash)
if err != nil {
return repository.OidcToken{}, err
}
@@ -436,3 +437,25 @@ func (service *OIDCService) CompileUserinfo(user repository.OidcUserinfo, scope
return userInfo
}
func (service *OIDCService) Hash(token string) string {
hasher := sha256.New()
hasher.Write([]byte(token))
return fmt.Sprintf("%x", hasher.Sum(nil))
}
func (service *OIDCService) CleanupOldSessions(c *gin.Context, sub string) error {
err := service.queries.DeleteOidcCodeBySub(c, sub)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return err
}
err = service.queries.DeleteOidcTokenBySub(c, sub)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return err
}
err = service.queries.DeleteOidcUserInfo(c, sub)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return err
}
return nil
}