mirror of
https://github.com/steveiliop56/tinyauth.git
synced 2026-01-27 09:42:30 +00:00
fix: review comments
This commit is contained in:
@@ -1,6 +1,6 @@
|
|||||||
CREATE TABLE IF NOT EXISTS "oidc_codes" (
|
CREATE TABLE IF NOT EXISTS "oidc_codes" (
|
||||||
"sub" TEXT NOT NULL UNIQUE,
|
"sub" TEXT NOT NULL UNIQUE,
|
||||||
"code" TEXT NOT NULL PRIMARY KEY UNIQUE,
|
"code_hash" TEXT NOT NULL PRIMARY KEY UNIQUE,
|
||||||
"scope" TEXT NOT NULL,
|
"scope" TEXT NOT NULL,
|
||||||
"redirect_uri" TEXT NOT NULL,
|
"redirect_uri" TEXT NOT NULL,
|
||||||
"client_id" TEXT NOT NULL,
|
"client_id" TEXT NOT NULL,
|
||||||
@@ -9,7 +9,7 @@ CREATE TABLE IF NOT EXISTS "oidc_codes" (
|
|||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS "oidc_tokens" (
|
CREATE TABLE IF NOT EXISTS "oidc_tokens" (
|
||||||
"sub" TEXT NOT NULL UNIQUE,
|
"sub" TEXT NOT NULL UNIQUE,
|
||||||
"access_token" TEXT NOT NULL PRIMARY KEY UNIQUE,
|
"access_token_hash" TEXT NOT NULL PRIMARY KEY UNIQUE,
|
||||||
"scope" TEXT NOT NULL,
|
"scope" TEXT NOT NULL,
|
||||||
"client_id" TEXT NOT NULL,
|
"client_id" TEXT NOT NULL,
|
||||||
"expires_at" INTEGER NOT NULL
|
"expires_at" INTEGER NOT NULL
|
||||||
|
|||||||
@@ -134,6 +134,13 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
|
|||||||
sub := utils.GenerateUUID(userContext.Username)
|
sub := utils.GenerateUUID(userContext.Username)
|
||||||
code := rand.Text()
|
code := rand.Text()
|
||||||
|
|
||||||
|
// Before storing the code, clean up old sessions
|
||||||
|
err = controller.oidc.CleanupOldSessions(c, sub)
|
||||||
|
if err != nil {
|
||||||
|
controller.authorizeError(c, err, "Failed to clean up old sessions", "Failed to clean up old sessions", req.RedirectURI, "server_error", req.State)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
err = controller.oidc.StoreCode(c, sub, code, req)
|
err = controller.oidc.StoreCode(c, sub, code, req)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -215,7 +222,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
entry, err := controller.oidc.GetCodeEntry(c, req.Code)
|
entry, err := controller.oidc.GetCodeEntry(c, controller.oidc.Hash(req.Code))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, service.ErrCodeExpired) {
|
if errors.Is(err, service.ErrCodeExpired) {
|
||||||
tlog.App.Warn().Str("code", req.Code).Msg("Code expired")
|
tlog.App.Warn().Str("code", req.Code).Msg("Code expired")
|
||||||
@@ -256,7 +263,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err = controller.oidc.DeleteCodeEntry(c, entry.Code)
|
err = controller.oidc.DeleteCodeEntry(c, entry.CodeHash)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tlog.App.Error().Err(err).Msg("Failed to delete code in database")
|
tlog.App.Error().Err(err).Msg("Failed to delete code in database")
|
||||||
@@ -290,7 +297,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
entry, err := controller.oidc.GetAccessToken(c, token)
|
entry, err := controller.oidc.GetAccessToken(c, controller.oidc.Hash(token))
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == service.ErrTokenNotFound {
|
if err == service.ErrTokenNotFound {
|
||||||
|
|||||||
@@ -42,7 +42,7 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
|
|||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
// There is no point in trying to get credentials if it's an OIDC endpoint
|
// There is no point in trying to get credentials if it's an OIDC endpoint
|
||||||
path := c.Request.URL.Path
|
path := c.Request.URL.Path
|
||||||
if slices.Contains(OIDCIgnorePaths, path) {
|
if slices.Contains(OIDCIgnorePaths, strings.TrimSuffix(path, "/")) {
|
||||||
c.Next()
|
c.Next()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ package repository
|
|||||||
|
|
||||||
type OidcCode struct {
|
type OidcCode struct {
|
||||||
Sub string
|
Sub string
|
||||||
Code string
|
CodeHash string
|
||||||
Scope string
|
Scope string
|
||||||
RedirectURI string
|
RedirectURI string
|
||||||
ClientID string
|
ClientID string
|
||||||
@@ -14,11 +14,11 @@ type OidcCode struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type OidcToken struct {
|
type OidcToken struct {
|
||||||
Sub string
|
Sub string
|
||||||
AccessToken string
|
AccessTokenHash string
|
||||||
Scope string
|
Scope string
|
||||||
ClientID string
|
ClientID string
|
||||||
ExpiresAt int64
|
ExpiresAt int64
|
||||||
}
|
}
|
||||||
|
|
||||||
type OidcUserinfo struct {
|
type OidcUserinfo struct {
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ import (
|
|||||||
const createOidcCode = `-- name: CreateOidcCode :one
|
const createOidcCode = `-- name: CreateOidcCode :one
|
||||||
INSERT INTO "oidc_codes" (
|
INSERT INTO "oidc_codes" (
|
||||||
"sub",
|
"sub",
|
||||||
"code",
|
"code_hash",
|
||||||
"scope",
|
"scope",
|
||||||
"redirect_uri",
|
"redirect_uri",
|
||||||
"client_id",
|
"client_id",
|
||||||
@@ -20,12 +20,12 @@ INSERT INTO "oidc_codes" (
|
|||||||
) VALUES (
|
) VALUES (
|
||||||
?, ?, ?, ?, ?, ?
|
?, ?, ?, ?, ?, ?
|
||||||
)
|
)
|
||||||
RETURNING sub, code, scope, redirect_uri, client_id, expires_at
|
RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at
|
||||||
`
|
`
|
||||||
|
|
||||||
type CreateOidcCodeParams struct {
|
type CreateOidcCodeParams struct {
|
||||||
Sub string
|
Sub string
|
||||||
Code string
|
CodeHash string
|
||||||
Scope string
|
Scope string
|
||||||
RedirectURI string
|
RedirectURI string
|
||||||
ClientID string
|
ClientID string
|
||||||
@@ -35,7 +35,7 @@ type CreateOidcCodeParams struct {
|
|||||||
func (q *Queries) CreateOidcCode(ctx context.Context, arg CreateOidcCodeParams) (OidcCode, error) {
|
func (q *Queries) CreateOidcCode(ctx context.Context, arg CreateOidcCodeParams) (OidcCode, error) {
|
||||||
row := q.db.QueryRowContext(ctx, createOidcCode,
|
row := q.db.QueryRowContext(ctx, createOidcCode,
|
||||||
arg.Sub,
|
arg.Sub,
|
||||||
arg.Code,
|
arg.CodeHash,
|
||||||
arg.Scope,
|
arg.Scope,
|
||||||
arg.RedirectURI,
|
arg.RedirectURI,
|
||||||
arg.ClientID,
|
arg.ClientID,
|
||||||
@@ -44,7 +44,7 @@ func (q *Queries) CreateOidcCode(ctx context.Context, arg CreateOidcCodeParams)
|
|||||||
var i OidcCode
|
var i OidcCode
|
||||||
err := row.Scan(
|
err := row.Scan(
|
||||||
&i.Sub,
|
&i.Sub,
|
||||||
&i.Code,
|
&i.CodeHash,
|
||||||
&i.Scope,
|
&i.Scope,
|
||||||
&i.RedirectURI,
|
&i.RedirectURI,
|
||||||
&i.ClientID,
|
&i.ClientID,
|
||||||
@@ -56,28 +56,28 @@ func (q *Queries) CreateOidcCode(ctx context.Context, arg CreateOidcCodeParams)
|
|||||||
const createOidcToken = `-- name: CreateOidcToken :one
|
const createOidcToken = `-- name: CreateOidcToken :one
|
||||||
INSERT INTO "oidc_tokens" (
|
INSERT INTO "oidc_tokens" (
|
||||||
"sub",
|
"sub",
|
||||||
"access_token",
|
"access_token_hash",
|
||||||
"scope",
|
"scope",
|
||||||
"client_id",
|
"client_id",
|
||||||
"expires_at"
|
"expires_at"
|
||||||
) VALUES (
|
) VALUES (
|
||||||
?, ?, ?, ?, ?
|
?, ?, ?, ?, ?
|
||||||
)
|
)
|
||||||
RETURNING sub, access_token, scope, client_id, expires_at
|
RETURNING sub, access_token_hash, scope, client_id, expires_at
|
||||||
`
|
`
|
||||||
|
|
||||||
type CreateOidcTokenParams struct {
|
type CreateOidcTokenParams struct {
|
||||||
Sub string
|
Sub string
|
||||||
AccessToken string
|
AccessTokenHash string
|
||||||
Scope string
|
Scope string
|
||||||
ClientID string
|
ClientID string
|
||||||
ExpiresAt int64
|
ExpiresAt int64
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q *Queries) CreateOidcToken(ctx context.Context, arg CreateOidcTokenParams) (OidcToken, error) {
|
func (q *Queries) CreateOidcToken(ctx context.Context, arg CreateOidcTokenParams) (OidcToken, error) {
|
||||||
row := q.db.QueryRowContext(ctx, createOidcToken,
|
row := q.db.QueryRowContext(ctx, createOidcToken,
|
||||||
arg.Sub,
|
arg.Sub,
|
||||||
arg.AccessToken,
|
arg.AccessTokenHash,
|
||||||
arg.Scope,
|
arg.Scope,
|
||||||
arg.ClientID,
|
arg.ClientID,
|
||||||
arg.ExpiresAt,
|
arg.ExpiresAt,
|
||||||
@@ -85,7 +85,7 @@ func (q *Queries) CreateOidcToken(ctx context.Context, arg CreateOidcTokenParams
|
|||||||
var i OidcToken
|
var i OidcToken
|
||||||
err := row.Scan(
|
err := row.Scan(
|
||||||
&i.Sub,
|
&i.Sub,
|
||||||
&i.AccessToken,
|
&i.AccessTokenHash,
|
||||||
&i.Scope,
|
&i.Scope,
|
||||||
&i.ClientID,
|
&i.ClientID,
|
||||||
&i.ExpiresAt,
|
&i.ExpiresAt,
|
||||||
@@ -139,21 +139,41 @@ func (q *Queries) CreateOidcUserInfo(ctx context.Context, arg CreateOidcUserInfo
|
|||||||
|
|
||||||
const deleteOidcCode = `-- name: DeleteOidcCode :exec
|
const deleteOidcCode = `-- name: DeleteOidcCode :exec
|
||||||
DELETE FROM "oidc_codes"
|
DELETE FROM "oidc_codes"
|
||||||
WHERE "code" = ?
|
WHERE "code_hash" = ?
|
||||||
`
|
`
|
||||||
|
|
||||||
func (q *Queries) DeleteOidcCode(ctx context.Context, code string) error {
|
func (q *Queries) DeleteOidcCode(ctx context.Context, codeHash string) error {
|
||||||
_, err := q.db.ExecContext(ctx, deleteOidcCode, code)
|
_, err := q.db.ExecContext(ctx, deleteOidcCode, codeHash)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
const deleteOidcCodeBySub = `-- name: DeleteOidcCodeBySub :exec
|
||||||
|
DELETE FROM "oidc_codes"
|
||||||
|
WHERE "sub" = ?
|
||||||
|
`
|
||||||
|
|
||||||
|
func (q *Queries) DeleteOidcCodeBySub(ctx context.Context, sub string) error {
|
||||||
|
_, err := q.db.ExecContext(ctx, deleteOidcCodeBySub, sub)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
const deleteOidcToken = `-- name: DeleteOidcToken :exec
|
const deleteOidcToken = `-- name: DeleteOidcToken :exec
|
||||||
DELETE FROM "oidc_tokens"
|
DELETE FROM "oidc_tokens"
|
||||||
WHERE "access_token" = ?
|
WHERE "access_token_hash" = ?
|
||||||
`
|
`
|
||||||
|
|
||||||
func (q *Queries) DeleteOidcToken(ctx context.Context, accessToken string) error {
|
func (q *Queries) DeleteOidcToken(ctx context.Context, accessTokenHash string) error {
|
||||||
_, err := q.db.ExecContext(ctx, deleteOidcToken, accessToken)
|
_, err := q.db.ExecContext(ctx, deleteOidcToken, accessTokenHash)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
const deleteOidcTokenBySub = `-- name: DeleteOidcTokenBySub :exec
|
||||||
|
DELETE FROM "oidc_tokens"
|
||||||
|
WHERE "sub" = ?
|
||||||
|
`
|
||||||
|
|
||||||
|
func (q *Queries) DeleteOidcTokenBySub(ctx context.Context, sub string) error {
|
||||||
|
_, err := q.db.ExecContext(ctx, deleteOidcTokenBySub, sub)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -168,16 +188,16 @@ func (q *Queries) DeleteOidcUserInfo(ctx context.Context, sub string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const getOidcCode = `-- name: GetOidcCode :one
|
const getOidcCode = `-- name: GetOidcCode :one
|
||||||
SELECT sub, code, scope, redirect_uri, client_id, expires_at FROM "oidc_codes"
|
SELECT sub, code_hash, scope, redirect_uri, client_id, expires_at FROM "oidc_codes"
|
||||||
WHERE "code" = ?
|
WHERE "code_hash" = ?
|
||||||
`
|
`
|
||||||
|
|
||||||
func (q *Queries) GetOidcCode(ctx context.Context, code string) (OidcCode, error) {
|
func (q *Queries) GetOidcCode(ctx context.Context, codeHash string) (OidcCode, error) {
|
||||||
row := q.db.QueryRowContext(ctx, getOidcCode, code)
|
row := q.db.QueryRowContext(ctx, getOidcCode, codeHash)
|
||||||
var i OidcCode
|
var i OidcCode
|
||||||
err := row.Scan(
|
err := row.Scan(
|
||||||
&i.Sub,
|
&i.Sub,
|
||||||
&i.Code,
|
&i.CodeHash,
|
||||||
&i.Scope,
|
&i.Scope,
|
||||||
&i.RedirectURI,
|
&i.RedirectURI,
|
||||||
&i.ClientID,
|
&i.ClientID,
|
||||||
@@ -187,16 +207,16 @@ func (q *Queries) GetOidcCode(ctx context.Context, code string) (OidcCode, error
|
|||||||
}
|
}
|
||||||
|
|
||||||
const getOidcToken = `-- name: GetOidcToken :one
|
const getOidcToken = `-- name: GetOidcToken :one
|
||||||
SELECT sub, access_token, scope, client_id, expires_at FROM "oidc_tokens"
|
SELECT sub, access_token_hash, scope, client_id, expires_at FROM "oidc_tokens"
|
||||||
WHERE "access_token" = ?
|
WHERE "access_token_hash" = ?
|
||||||
`
|
`
|
||||||
|
|
||||||
func (q *Queries) GetOidcToken(ctx context.Context, accessToken string) (OidcToken, error) {
|
func (q *Queries) GetOidcToken(ctx context.Context, accessTokenHash string) (OidcToken, error) {
|
||||||
row := q.db.QueryRowContext(ctx, getOidcToken, accessToken)
|
row := q.db.QueryRowContext(ctx, getOidcToken, accessTokenHash)
|
||||||
var i OidcToken
|
var i OidcToken
|
||||||
err := row.Scan(
|
err := row.Scan(
|
||||||
&i.Sub,
|
&i.Sub,
|
||||||
&i.AccessToken,
|
&i.AccessTokenHash,
|
||||||
&i.Scope,
|
&i.Scope,
|
||||||
&i.ClientID,
|
&i.ClientID,
|
||||||
&i.ExpiresAt,
|
&i.ExpiresAt,
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"crypto"
|
"crypto"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"crypto/rsa"
|
"crypto/rsa"
|
||||||
|
"crypto/sha256"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"encoding/pem"
|
"encoding/pem"
|
||||||
@@ -245,8 +246,8 @@ func (service *OIDCService) StoreCode(c *gin.Context, sub string, code string, r
|
|||||||
|
|
||||||
// Insert the code into the database
|
// Insert the code into the database
|
||||||
_, err := service.queries.CreateOidcCode(c, repository.CreateOidcCodeParams{
|
_, err := service.queries.CreateOidcCode(c, repository.CreateOidcCodeParams{
|
||||||
Sub: sub,
|
Sub: sub,
|
||||||
Code: code,
|
CodeHash: service.Hash(code),
|
||||||
// Here it's safe to split and trust the output since, we validated the scopes before
|
// Here it's safe to split and trust the output since, we validated the scopes before
|
||||||
Scope: strings.Join(service.filterScopes(strings.Split(req.Scope, " ")), ","),
|
Scope: strings.Join(service.filterScopes(strings.Split(req.Scope, " ")), ","),
|
||||||
RedirectURI: req.RedirectURI,
|
RedirectURI: req.RedirectURI,
|
||||||
@@ -288,8 +289,8 @@ func (service *OIDCService) ValidateGrantType(grantType string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (service *OIDCService) GetCodeEntry(c *gin.Context, code string) (repository.OidcCode, error) {
|
func (service *OIDCService) GetCodeEntry(c *gin.Context, codeHash string) (repository.OidcCode, error) {
|
||||||
oidcCode, err := service.queries.GetOidcCode(c, code)
|
oidcCode, err := service.queries.GetOidcCode(c, codeHash)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, sql.ErrNoRows) {
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
@@ -299,7 +300,7 @@ func (service *OIDCService) GetCodeEntry(c *gin.Context, code string) (repositor
|
|||||||
}
|
}
|
||||||
|
|
||||||
if time.Now().Unix() > oidcCode.ExpiresAt {
|
if time.Now().Unix() > oidcCode.ExpiresAt {
|
||||||
err = service.queries.DeleteOidcCode(c, code)
|
err = service.queries.DeleteOidcCode(c, codeHash)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return repository.OidcCode{}, err
|
return repository.OidcCode{}, err
|
||||||
}
|
}
|
||||||
@@ -360,10 +361,10 @@ func (service *OIDCService) GenerateAccessToken(c *gin.Context, client config.OI
|
|||||||
}
|
}
|
||||||
|
|
||||||
_, err = service.queries.CreateOidcToken(c, repository.CreateOidcTokenParams{
|
_, err = service.queries.CreateOidcToken(c, repository.CreateOidcTokenParams{
|
||||||
Sub: sub,
|
Sub: sub,
|
||||||
AccessToken: accessToken,
|
AccessTokenHash: service.Hash(accessToken),
|
||||||
Scope: scope,
|
Scope: scope,
|
||||||
ExpiresAt: expiresAt,
|
ExpiresAt: expiresAt,
|
||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -373,20 +374,20 @@ func (service *OIDCService) GenerateAccessToken(c *gin.Context, client config.OI
|
|||||||
return tokenResponse, nil
|
return tokenResponse, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (service *OIDCService) DeleteCodeEntry(c *gin.Context, code string) error {
|
func (service *OIDCService) DeleteCodeEntry(c *gin.Context, codeHash string) error {
|
||||||
return service.queries.DeleteOidcCode(c, code)
|
return service.queries.DeleteOidcCode(c, codeHash)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (service *OIDCService) DeleteUserinfo(c *gin.Context, sub string) error {
|
func (service *OIDCService) DeleteUserinfo(c *gin.Context, sub string) error {
|
||||||
return service.queries.DeleteOidcUserInfo(c, sub)
|
return service.queries.DeleteOidcUserInfo(c, sub)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (service *OIDCService) DeleteToken(c *gin.Context, token string) error {
|
func (service *OIDCService) DeleteToken(c *gin.Context, tokenHash string) error {
|
||||||
return service.queries.DeleteOidcToken(c, token)
|
return service.queries.DeleteOidcToken(c, tokenHash)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (service *OIDCService) GetAccessToken(c *gin.Context, token string) (repository.OidcToken, error) {
|
func (service *OIDCService) GetAccessToken(c *gin.Context, tokenHash string) (repository.OidcToken, error) {
|
||||||
entry, err := service.queries.GetOidcToken(c, token)
|
entry, err := service.queries.GetOidcToken(c, tokenHash)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
@@ -396,7 +397,7 @@ func (service *OIDCService) GetAccessToken(c *gin.Context, token string) (reposi
|
|||||||
}
|
}
|
||||||
|
|
||||||
if entry.ExpiresAt < time.Now().Unix() {
|
if entry.ExpiresAt < time.Now().Unix() {
|
||||||
err := service.DeleteToken(c, token)
|
err := service.DeleteToken(c, tokenHash)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return repository.OidcToken{}, err
|
return repository.OidcToken{}, err
|
||||||
}
|
}
|
||||||
@@ -436,3 +437,25 @@ func (service *OIDCService) CompileUserinfo(user repository.OidcUserinfo, scope
|
|||||||
|
|
||||||
return userInfo
|
return userInfo
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (service *OIDCService) Hash(token string) string {
|
||||||
|
hasher := sha256.New()
|
||||||
|
hasher.Write([]byte(token))
|
||||||
|
return fmt.Sprintf("%x", hasher.Sum(nil))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (service *OIDCService) CleanupOldSessions(c *gin.Context, sub string) error {
|
||||||
|
err := service.queries.DeleteOidcCodeBySub(c, sub)
|
||||||
|
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = service.queries.DeleteOidcTokenBySub(c, sub)
|
||||||
|
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = service.queries.DeleteOidcUserInfo(c, sub)
|
||||||
|
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,11 +1,8 @@
|
|||||||
package utils
|
package utils
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/rand"
|
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"errors"
|
"errors"
|
||||||
"math"
|
|
||||||
"math/big"
|
|
||||||
"net"
|
"net"
|
||||||
"regexp"
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -108,28 +105,3 @@ func GenerateUUID(str string) string {
|
|||||||
uuid := uuid.NewSHA1(uuid.NameSpaceURL, []byte(str))
|
uuid := uuid.NewSHA1(uuid.NameSpaceURL, []byte(str))
|
||||||
return uuid.String()
|
return uuid.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
// These could definitely be improved A LOT but at least they are cryptographically secure
|
|
||||||
func GetRandomString(length int) (string, error) {
|
|
||||||
if length < 1 {
|
|
||||||
return "", errors.New("length must be greater than 0")
|
|
||||||
}
|
|
||||||
b := make([]byte, length)
|
|
||||||
_, err := rand.Read(b)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
state := base64.RawURLEncoding.EncodeToString(b)
|
|
||||||
return state[:length], nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetRandomInt(length int) (int64, error) {
|
|
||||||
if length < 1 {
|
|
||||||
return 0, errors.New("length must be greater than 0")
|
|
||||||
}
|
|
||||||
a, err := rand.Int(rand.Reader, big.NewInt(int64(math.Pow(10, float64(length)))))
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
return a.Int64(), nil
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ package utils_test
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"os"
|
"os"
|
||||||
"strconv"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/steveiliop56/tinyauth/internal/utils"
|
"github.com/steveiliop56/tinyauth/internal/utils"
|
||||||
@@ -148,25 +147,3 @@ func TestGenerateUUID(t *testing.T) {
|
|||||||
id3 := utils.GenerateUUID("differentstring")
|
id3 := utils.GenerateUUID("differentstring")
|
||||||
assert.Assert(t, id1 != id3)
|
assert.Assert(t, id1 != id3)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGetRandomString(t *testing.T) {
|
|
||||||
// Test with normal length
|
|
||||||
state, err := utils.GetRandomString(16)
|
|
||||||
assert.NilError(t, err)
|
|
||||||
assert.Equal(t, 16, len(state))
|
|
||||||
|
|
||||||
// Test with zero length
|
|
||||||
state, err = utils.GetRandomString(0)
|
|
||||||
assert.Error(t, err, "length must be greater than 0")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGetRandomInt(t *testing.T) {
|
|
||||||
// Test with normal length
|
|
||||||
state, err := utils.GetRandomInt(16)
|
|
||||||
assert.NilError(t, err)
|
|
||||||
assert.Equal(t, 16, len(strconv.Itoa(int(state))))
|
|
||||||
|
|
||||||
// Test with zero length
|
|
||||||
state, err = utils.GetRandomInt(0)
|
|
||||||
assert.Error(t, err, "length must be greater than 0")
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
-- name: CreateOidcCode :one
|
-- name: CreateOidcCode :one
|
||||||
INSERT INTO "oidc_codes" (
|
INSERT INTO "oidc_codes" (
|
||||||
"sub",
|
"sub",
|
||||||
"code",
|
"code_hash",
|
||||||
"scope",
|
"scope",
|
||||||
"redirect_uri",
|
"redirect_uri",
|
||||||
"client_id",
|
"client_id",
|
||||||
@@ -13,16 +13,20 @@ RETURNING *;
|
|||||||
|
|
||||||
-- name: DeleteOidcCode :exec
|
-- name: DeleteOidcCode :exec
|
||||||
DELETE FROM "oidc_codes"
|
DELETE FROM "oidc_codes"
|
||||||
WHERE "code" = ?;
|
WHERE "code_hash" = ?;
|
||||||
|
|
||||||
|
-- name: DeleteOidcCodeBySub :exec
|
||||||
|
DELETE FROM "oidc_codes"
|
||||||
|
WHERE "sub" = ?;
|
||||||
|
|
||||||
-- name: GetOidcCode :one
|
-- name: GetOidcCode :one
|
||||||
SELECT * FROM "oidc_codes"
|
SELECT * FROM "oidc_codes"
|
||||||
WHERE "code" = ?;
|
WHERE "code_hash" = ?;
|
||||||
|
|
||||||
-- name: CreateOidcToken :one
|
-- name: CreateOidcToken :one
|
||||||
INSERT INTO "oidc_tokens" (
|
INSERT INTO "oidc_tokens" (
|
||||||
"sub",
|
"sub",
|
||||||
"access_token",
|
"access_token_hash",
|
||||||
"scope",
|
"scope",
|
||||||
"client_id",
|
"client_id",
|
||||||
"expires_at"
|
"expires_at"
|
||||||
@@ -33,11 +37,15 @@ RETURNING *;
|
|||||||
|
|
||||||
-- name: DeleteOidcToken :exec
|
-- name: DeleteOidcToken :exec
|
||||||
DELETE FROM "oidc_tokens"
|
DELETE FROM "oidc_tokens"
|
||||||
WHERE "access_token" = ?;
|
WHERE "access_token_hash" = ?;
|
||||||
|
|
||||||
|
-- name: DeleteOidcTokenBySub :exec
|
||||||
|
DELETE FROM "oidc_tokens"
|
||||||
|
WHERE "sub" = ?;
|
||||||
|
|
||||||
-- name: GetOidcToken :one
|
-- name: GetOidcToken :one
|
||||||
SELECT * FROM "oidc_tokens"
|
SELECT * FROM "oidc_tokens"
|
||||||
WHERE "access_token" = ?;
|
WHERE "access_token_hash" = ?;
|
||||||
|
|
||||||
-- name: CreateOidcUserInfo :one
|
-- name: CreateOidcUserInfo :one
|
||||||
INSERT INTO "oidc_userinfo" (
|
INSERT INTO "oidc_userinfo" (
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
CREATE TABLE IF NOT EXISTS "oidc_codes" (
|
CREATE TABLE IF NOT EXISTS "oidc_codes" (
|
||||||
"sub" TEXT NOT NULL UNIQUE,
|
"sub" TEXT NOT NULL UNIQUE,
|
||||||
"code" TEXT NOT NULL PRIMARY KEY UNIQUE,
|
"code_hash" TEXT NOT NULL PRIMARY KEY UNIQUE,
|
||||||
"scope" TEXT NOT NULL,
|
"scope" TEXT NOT NULL,
|
||||||
"redirect_uri" TEXT NOT NULL,
|
"redirect_uri" TEXT NOT NULL,
|
||||||
"client_id" TEXT NOT NULL,
|
"client_id" TEXT NOT NULL,
|
||||||
@@ -9,7 +9,7 @@ CREATE TABLE IF NOT EXISTS "oidc_codes" (
|
|||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS "oidc_tokens" (
|
CREATE TABLE IF NOT EXISTS "oidc_tokens" (
|
||||||
"sub" TEXT NOT NULL UNIQUE,
|
"sub" TEXT NOT NULL UNIQUE,
|
||||||
"access_token" TEXT NOT NULL PRIMARY KEY UNIQUE,
|
"access_token_hash" TEXT NOT NULL PRIMARY KEY UNIQUE,
|
||||||
"scope" TEXT NOT NULL,
|
"scope" TEXT NOT NULL,
|
||||||
"client_id" TEXT NOT NULL,
|
"client_id" TEXT NOT NULL,
|
||||||
"expires_at" INTEGER NOT NULL
|
"expires_at" INTEGER NOT NULL
|
||||||
|
|||||||
Reference in New Issue
Block a user