mirror of
https://github.com/steveiliop56/tinyauth.git
synced 2026-06-02 17:40:14 +00:00
Compare commits
7 Commits
faa3156672
...
b3c152fa1c
| Author | SHA1 | Date | |
|---|---|---|---|
| b3c152fa1c | |||
| 5caee887de | |||
| b5770ef305 | |||
| 1c4ca8f436 | |||
| a72300484b | |||
| 4fe5de241b | |||
| 83ed9ece57 |
@@ -15,14 +15,15 @@ import (
|
|||||||
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/assets"
|
"github.com/tinyauthapp/tinyauth/internal/assets"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository/postgres"
|
"github.com/tinyauthapp/tinyauth/internal/repository/postgres"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository/sqlite"
|
"github.com/tinyauthapp/tinyauth/internal/repository/sqlite"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (app *BootstrapApp) SetupStore() (repository.Store, error) {
|
func (app *BootstrapApp) SetupStore() (repository.Store, error) {
|
||||||
switch app.config.Database.Driver {
|
switch app.config.Database.Driver {
|
||||||
// case "memory":
|
case "memory":
|
||||||
// return memory.New(), nil
|
return memory.New(), nil
|
||||||
case "sqlite", "":
|
case "sqlite", "":
|
||||||
return app.setupSQLite(app.config.Database.Path)
|
return app.setupSQLite(app.config.Database.Path)
|
||||||
case "postgres":
|
case "postgres":
|
||||||
|
|||||||
@@ -327,6 +327,21 @@ func (controller *OIDCController) Token(c *gin.Context) {
|
|||||||
entry, ok := controller.oidc.GetCodeEntry(controller.oidc.Hash(req.Code), client.ClientID)
|
entry, ok := controller.oidc.GetCodeEntry(controller.oidc.Hash(req.Code), client.ClientID)
|
||||||
|
|
||||||
if !ok {
|
if !ok {
|
||||||
|
// ensure no code reuse
|
||||||
|
usedCodeSub, ok := controller.oidc.IsCodeUsed(controller.oidc.Hash(req.Code))
|
||||||
|
|
||||||
|
if ok {
|
||||||
|
controller.log.App.Warn().Msg("Code reuse detected")
|
||||||
|
err := controller.oidc.DeleteSessionBySub(c, usedCodeSub)
|
||||||
|
if err != nil {
|
||||||
|
controller.log.App.Error().Err(err).Msg("Failed to delete session for reused code")
|
||||||
|
}
|
||||||
|
c.JSON(400, gin.H{
|
||||||
|
"error": "invalid_grant",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
controller.log.App.Warn().Msg("Code not found")
|
controller.log.App.Warn().Msg("Code not found")
|
||||||
c.JSON(400, gin.H{
|
c.JSON(400, gin.H{
|
||||||
"error": "invalid_grant",
|
"error": "invalid_grant",
|
||||||
@@ -334,6 +349,9 @@ func (controller *OIDCController) Token(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// mark code as used to prevent reuse
|
||||||
|
controller.oidc.MarkCodeAsUsed(controller.oidc.Hash(req.Code), entry.Userinfo.Sub)
|
||||||
|
|
||||||
if entry.RedirectURI != req.RedirectURI {
|
if entry.RedirectURI != req.RedirectURI {
|
||||||
controller.log.App.Warn().Msg("Redirect URI does not match")
|
controller.log.App.Warn().Msg("Redirect URI does not match")
|
||||||
c.JSON(400, gin.H{
|
c.JSON(400, gin.H{
|
||||||
|
|||||||
@@ -1,7 +1,3 @@
|
|||||||
//go:build exclude
|
|
||||||
|
|
||||||
// temporary
|
|
||||||
|
|
||||||
package memory_test
|
package memory_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
@@ -105,366 +101,182 @@ func TestMemoryStore(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "Create and get OIDC code",
|
description: "Create and get OIDC session",
|
||||||
run: func(t *testing.T, s repository.Store) {
|
run: func(t *testing.T, s repository.Store) {
|
||||||
code, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{
|
sess, err := s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{
|
||||||
Sub: "sub-1",
|
Sub: "sub-1",
|
||||||
CodeHash: "hash-1",
|
AccessTokenHash: "at-1",
|
||||||
|
RefreshTokenHash: "rt-1",
|
||||||
Scope: "openid",
|
Scope: "openid",
|
||||||
})
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, "sub-1", code.Sub)
|
assert.Equal(t, "sub-1", sess.Sub)
|
||||||
|
|
||||||
// destructive read removes the record
|
got, err := s.GetOIDCSessionBySub(ctx, "sub-1")
|
||||||
got, err := s.GetOidcCode(ctx, "hash-1")
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, code, got)
|
assert.Equal(t, sess, got)
|
||||||
|
},
|
||||||
_, err = s.GetOidcCode(ctx, "hash-1")
|
},
|
||||||
|
{
|
||||||
|
description: "Get OIDC session by sub not found",
|
||||||
|
run: func(t *testing.T, s repository.Store) {
|
||||||
|
_, err := s.GetOIDCSessionBySub(ctx, "missing")
|
||||||
assert.ErrorIs(t, err, repository.ErrNotFound)
|
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "Get OIDC code not found",
|
description: "Get OIDC session by access token hash",
|
||||||
run: func(t *testing.T, s repository.Store) {
|
run: func(t *testing.T, s repository.Store) {
|
||||||
_, err := s.GetOidcCode(ctx, "missing")
|
_, err := s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{
|
||||||
assert.ErrorIs(t, err, repository.ErrNotFound)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Get OIDC code by sub",
|
|
||||||
run: func(t *testing.T, s repository.Store) {
|
|
||||||
_, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"})
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
got, err := s.GetOidcCodeBySub(ctx, "sub-1")
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.Equal(t, "sub-1", got.Sub)
|
|
||||||
|
|
||||||
// destructive — gone after read
|
|
||||||
_, err = s.GetOidcCodeBySub(ctx, "sub-1")
|
|
||||||
assert.ErrorIs(t, err, repository.ErrNotFound)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Get OIDC code by sub not found",
|
|
||||||
run: func(t *testing.T, s repository.Store) {
|
|
||||||
_, err := s.GetOidcCodeBySub(ctx, "missing")
|
|
||||||
assert.ErrorIs(t, err, repository.ErrNotFound)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Get OIDC code unsafe",
|
|
||||||
run: func(t *testing.T, s repository.Store) {
|
|
||||||
_, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"})
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
got, err := s.GetOidcCodeUnsafe(ctx, "hash-1")
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.Equal(t, "sub-1", got.Sub)
|
|
||||||
|
|
||||||
// non-destructive — still present
|
|
||||||
_, err = s.GetOidcCodeUnsafe(ctx, "hash-1")
|
|
||||||
assert.NoError(t, err)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Get OIDC code unsafe not found",
|
|
||||||
run: func(t *testing.T, s repository.Store) {
|
|
||||||
_, err := s.GetOidcCodeUnsafe(ctx, "missing")
|
|
||||||
assert.ErrorIs(t, err, repository.ErrNotFound)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Get OIDC code by sub unsafe",
|
|
||||||
run: func(t *testing.T, s repository.Store) {
|
|
||||||
_, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"})
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
got, err := s.GetOidcCodeBySubUnsafe(ctx, "sub-1")
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.Equal(t, "hash-1", got.CodeHash)
|
|
||||||
|
|
||||||
// non-destructive — still present
|
|
||||||
_, err = s.GetOidcCodeBySubUnsafe(ctx, "sub-1")
|
|
||||||
assert.NoError(t, err)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Get OIDC code by sub unsafe not found",
|
|
||||||
run: func(t *testing.T, s repository.Store) {
|
|
||||||
_, err := s.GetOidcCodeBySubUnsafe(ctx, "missing")
|
|
||||||
assert.ErrorIs(t, err, repository.ErrNotFound)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Create OIDC code unique sub constraint",
|
|
||||||
run: func(t *testing.T, s repository.Store) {
|
|
||||||
_, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"})
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
_, err = s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-2"})
|
|
||||||
assert.ErrorContains(t, err, "UNIQUE constraint failed: oidc_codes.sub")
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Delete OIDC code",
|
|
||||||
run: func(t *testing.T, s repository.Store) {
|
|
||||||
_, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"})
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
require.NoError(t, s.DeleteOidcCode(ctx, "hash-1"))
|
|
||||||
|
|
||||||
_, err = s.GetOidcCodeUnsafe(ctx, "hash-1")
|
|
||||||
assert.ErrorIs(t, err, repository.ErrNotFound)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Delete OIDC code by sub",
|
|
||||||
run: func(t *testing.T, s repository.Store) {
|
|
||||||
_, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"})
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
require.NoError(t, s.DeleteOidcCodeBySub(ctx, "sub-1"))
|
|
||||||
|
|
||||||
_, err = s.GetOidcCodeUnsafe(ctx, "hash-1")
|
|
||||||
assert.ErrorIs(t, err, repository.ErrNotFound)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Delete expired OIDC codes",
|
|
||||||
run: func(t *testing.T, s repository.Store) {
|
|
||||||
_, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1", ExpiresAt: 10})
|
|
||||||
require.NoError(t, err)
|
|
||||||
_, err = s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-2", CodeHash: "hash-2", ExpiresAt: 100})
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
deleted, err := s.DeleteExpiredOidcCodes(ctx, 50)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Len(t, deleted, 1)
|
|
||||||
assert.Equal(t, "hash-1", deleted[0].CodeHash)
|
|
||||||
|
|
||||||
_, err = s.GetOidcCodeUnsafe(ctx, "hash-2")
|
|
||||||
assert.NoError(t, err)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Create and get OIDC token",
|
|
||||||
run: func(t *testing.T, s repository.Store) {
|
|
||||||
tok, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{
|
|
||||||
Sub: "sub-1",
|
|
||||||
AccessTokenHash: "at-hash-1",
|
|
||||||
CodeHash: "code-hash-1",
|
|
||||||
})
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.Equal(t, "sub-1", tok.Sub)
|
|
||||||
|
|
||||||
got, err := s.GetOidcToken(ctx, "at-hash-1")
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.Equal(t, tok, got)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Get OIDC token not found",
|
|
||||||
run: func(t *testing.T, s repository.Store) {
|
|
||||||
_, err := s.GetOidcToken(ctx, "missing")
|
|
||||||
assert.ErrorIs(t, err, repository.ErrNotFound)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Create OIDC token unique sub constraint",
|
|
||||||
run: func(t *testing.T, s repository.Store) {
|
|
||||||
_, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{Sub: "sub-1", AccessTokenHash: "at-1"})
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
_, err = s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{Sub: "sub-1", AccessTokenHash: "at-2"})
|
|
||||||
assert.ErrorContains(t, err, "UNIQUE constraint failed: oidc_tokens.sub")
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Get OIDC token by refresh token",
|
|
||||||
run: func(t *testing.T, s repository.Store) {
|
|
||||||
_, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{
|
|
||||||
Sub: "sub-1",
|
Sub: "sub-1",
|
||||||
AccessTokenHash: "at-1",
|
AccessTokenHash: "at-1",
|
||||||
RefreshTokenHash: "rt-1",
|
RefreshTokenHash: "rt-1",
|
||||||
})
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
got, err := s.GetOidcTokenByRefreshToken(ctx, "rt-1")
|
got, err := s.GetOIDCSessionByAccessTokenHash(ctx, "at-1")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, "sub-1", got.Sub)
|
assert.Equal(t, "sub-1", got.Sub)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "Get OIDC token by refresh token not found",
|
description: "Get OIDC session by access token hash not found",
|
||||||
run: func(t *testing.T, s repository.Store) {
|
run: func(t *testing.T, s repository.Store) {
|
||||||
_, err := s.GetOidcTokenByRefreshToken(ctx, "missing")
|
_, err := s.GetOIDCSessionByAccessTokenHash(ctx, "missing")
|
||||||
assert.ErrorIs(t, err, repository.ErrNotFound)
|
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "Get OIDC token by sub",
|
description: "Get OIDC session by refresh token hash",
|
||||||
run: func(t *testing.T, s repository.Store) {
|
run: func(t *testing.T, s repository.Store) {
|
||||||
_, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{
|
_, err := s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{
|
||||||
Sub: "sub-1",
|
|
||||||
AccessTokenHash: "at-1",
|
|
||||||
})
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
got, err := s.GetOidcTokenBySub(ctx, "sub-1")
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.Equal(t, "at-1", got.AccessTokenHash)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Get OIDC token by sub not found",
|
|
||||||
run: func(t *testing.T, s repository.Store) {
|
|
||||||
_, err := s.GetOidcTokenBySub(ctx, "missing")
|
|
||||||
assert.ErrorIs(t, err, repository.ErrNotFound)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Update OIDC token by refresh token",
|
|
||||||
run: func(t *testing.T, s repository.Store) {
|
|
||||||
_, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{
|
|
||||||
Sub: "sub-1",
|
Sub: "sub-1",
|
||||||
AccessTokenHash: "at-1",
|
AccessTokenHash: "at-1",
|
||||||
RefreshTokenHash: "rt-1",
|
RefreshTokenHash: "rt-1",
|
||||||
})
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
updated, err := s.UpdateOidcTokenByRefreshToken(ctx, repository.UpdateOidcTokenByRefreshTokenParams{
|
got, err := s.GetOIDCSessionByRefreshTokenHash(ctx, "rt-1")
|
||||||
RefreshTokenHash_2: "rt-1",
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "sub-1", got.Sub)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Get OIDC session by refresh token hash not found",
|
||||||
|
run: func(t *testing.T, s repository.Store) {
|
||||||
|
_, err := s.GetOIDCSessionByRefreshTokenHash(ctx, "missing")
|
||||||
|
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Create OIDC session unique sub constraint",
|
||||||
|
run: func(t *testing.T, s repository.Store) {
|
||||||
|
_, err := s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{Sub: "sub-1", AccessTokenHash: "at-1", RefreshTokenHash: "rt-1"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
_, err = s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{Sub: "sub-1", AccessTokenHash: "at-2", RefreshTokenHash: "rt-2"})
|
||||||
|
assert.ErrorContains(t, err, "UNIQUE constraint failed: oidc_sessions.sub")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Create OIDC session unique access token hash constraint",
|
||||||
|
run: func(t *testing.T, s repository.Store) {
|
||||||
|
_, err := s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{Sub: "sub-1", AccessTokenHash: "at-1", RefreshTokenHash: "rt-1"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
_, err = s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{Sub: "sub-2", AccessTokenHash: "at-1", RefreshTokenHash: "rt-2"})
|
||||||
|
assert.ErrorContains(t, err, "UNIQUE constraint failed: oidc_sessions.access_token_hash")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Create OIDC session unique refresh token hash constraint",
|
||||||
|
run: func(t *testing.T, s repository.Store) {
|
||||||
|
_, err := s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{Sub: "sub-1", AccessTokenHash: "at-1", RefreshTokenHash: "rt-1"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
_, err = s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{Sub: "sub-2", AccessTokenHash: "at-2", RefreshTokenHash: "rt-1"})
|
||||||
|
assert.ErrorContains(t, err, "UNIQUE constraint failed: oidc_sessions.refresh_token_hash")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Update OIDC session",
|
||||||
|
run: func(t *testing.T, s repository.Store) {
|
||||||
|
_, err := s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{
|
||||||
|
Sub: "sub-1",
|
||||||
|
AccessTokenHash: "at-1",
|
||||||
|
RefreshTokenHash: "rt-1",
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
updated, err := s.UpdateOIDCSession(ctx, repository.UpdateOIDCSessionParams{
|
||||||
|
Sub: "sub-1",
|
||||||
AccessTokenHash: "at-2",
|
AccessTokenHash: "at-2",
|
||||||
RefreshTokenHash: "rt-2",
|
RefreshTokenHash: "rt-2",
|
||||||
|
Scope: "openid profile",
|
||||||
TokenExpiresAt: 200,
|
TokenExpiresAt: 200,
|
||||||
RefreshTokenExpiresAt: 400,
|
RefreshTokenExpiresAt: 400,
|
||||||
})
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, "at-2", updated.AccessTokenHash)
|
assert.Equal(t, "at-2", updated.AccessTokenHash)
|
||||||
assert.Equal(t, "rt-2", updated.RefreshTokenHash)
|
assert.Equal(t, "rt-2", updated.RefreshTokenHash)
|
||||||
|
assert.Equal(t, "openid profile", updated.Scope)
|
||||||
|
|
||||||
// old key gone, new key present
|
// updated token hashes are now queryable, old ones are gone
|
||||||
_, err = s.GetOidcToken(ctx, "at-1")
|
got, err := s.GetOIDCSessionByAccessTokenHash(ctx, "at-2")
|
||||||
assert.ErrorIs(t, err, repository.ErrNotFound)
|
|
||||||
|
|
||||||
got, err := s.GetOidcToken(ctx, "at-2")
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, "sub-1", got.Sub)
|
assert.Equal(t, "sub-1", got.Sub)
|
||||||
},
|
|
||||||
},
|
_, err = s.GetOIDCSessionByAccessTokenHash(ctx, "at-1")
|
||||||
{
|
|
||||||
description: "Update OIDC token by refresh token not found",
|
|
||||||
run: func(t *testing.T, s repository.Store) {
|
|
||||||
_, err := s.UpdateOidcTokenByRefreshToken(ctx, repository.UpdateOidcTokenByRefreshTokenParams{
|
|
||||||
RefreshTokenHash_2: "missing",
|
|
||||||
})
|
|
||||||
assert.ErrorIs(t, err, repository.ErrNotFound)
|
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "Delete OIDC token",
|
description: "Update OIDC session not found",
|
||||||
run: func(t *testing.T, s repository.Store) {
|
run: func(t *testing.T, s repository.Store) {
|
||||||
_, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{Sub: "sub-1", AccessTokenHash: "at-1"})
|
_, err := s.UpdateOIDCSession(ctx, repository.UpdateOIDCSessionParams{Sub: "missing"})
|
||||||
|
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Delete OIDC session by sub",
|
||||||
|
run: func(t *testing.T, s repository.Store) {
|
||||||
|
_, err := s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{Sub: "sub-1", AccessTokenHash: "at-1", RefreshTokenHash: "rt-1"})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
require.NoError(t, s.DeleteOidcToken(ctx, "at-1"))
|
require.NoError(t, s.DeleteOIDCSessionBySub(ctx, "sub-1"))
|
||||||
|
|
||||||
_, err = s.GetOidcToken(ctx, "at-1")
|
_, err = s.GetOIDCSessionBySub(ctx, "sub-1")
|
||||||
assert.ErrorIs(t, err, repository.ErrNotFound)
|
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "Delete OIDC token by sub",
|
description: "Delete expired OIDC sessions",
|
||||||
run: func(t *testing.T, s repository.Store) {
|
|
||||||
_, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{Sub: "sub-1", AccessTokenHash: "at-1"})
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
require.NoError(t, s.DeleteOidcTokenBySub(ctx, "sub-1"))
|
|
||||||
|
|
||||||
_, err = s.GetOidcToken(ctx, "at-1")
|
|
||||||
assert.ErrorIs(t, err, repository.ErrNotFound)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Delete OIDC token by code hash",
|
|
||||||
run: func(t *testing.T, s repository.Store) {
|
|
||||||
_, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{
|
|
||||||
Sub: "sub-1",
|
|
||||||
AccessTokenHash: "at-1",
|
|
||||||
CodeHash: "code-1",
|
|
||||||
})
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
require.NoError(t, s.DeleteOidcTokenByCodeHash(ctx, "code-1"))
|
|
||||||
|
|
||||||
_, err = s.GetOidcToken(ctx, "at-1")
|
|
||||||
assert.ErrorIs(t, err, repository.ErrNotFound)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Delete expired OIDC tokens",
|
|
||||||
run: func(t *testing.T, s repository.Store) {
|
run: func(t *testing.T, s repository.Store) {
|
||||||
// both expiries past
|
// both expiries past
|
||||||
_, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{
|
_, err := s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{
|
||||||
Sub: "sub-1", AccessTokenHash: "at-1",
|
Sub: "sub-1", AccessTokenHash: "at-1", RefreshTokenHash: "rt-1",
|
||||||
TokenExpiresAt: 10, RefreshTokenExpiresAt: 10,
|
TokenExpiresAt: 10, RefreshTokenExpiresAt: 10,
|
||||||
})
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
// valid
|
// valid
|
||||||
_, err = s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{
|
_, err = s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{
|
||||||
Sub: "sub-3", AccessTokenHash: "at-3",
|
Sub: "sub-2", AccessTokenHash: "at-2", RefreshTokenHash: "rt-2",
|
||||||
TokenExpiresAt: 100, RefreshTokenExpiresAt: 100,
|
TokenExpiresAt: 100, RefreshTokenExpiresAt: 100,
|
||||||
})
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
deleted, err := s.DeleteExpiredOidcTokens(ctx, repository.DeleteExpiredOidcTokensParams{
|
require.NoError(t, s.DeleteExpiredOIDCSessions(ctx, repository.DeleteExpiredOIDCSessionsParams{
|
||||||
TokenExpiresAt: 50,
|
TokenExpiresAt: 50,
|
||||||
RefreshTokenExpiresAt: 50,
|
RefreshTokenExpiresAt: 50,
|
||||||
})
|
}))
|
||||||
require.NoError(t, err)
|
|
||||||
assert.Len(t, deleted, 1)
|
|
||||||
|
|
||||||
_, err = s.GetOidcToken(ctx, "at-3")
|
_, err = s.GetOIDCSessionBySub(ctx, "sub-1")
|
||||||
|
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||||
|
|
||||||
|
_, err = s.GetOIDCSessionBySub(ctx, "sub-2")
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
|
||||||
description: "Create and get OIDC user info",
|
|
||||||
run: func(t *testing.T, s repository.Store) {
|
|
||||||
u, err := s.CreateOidcUserInfo(ctx, repository.CreateOidcUserInfoParams{
|
|
||||||
Sub: "sub-1",
|
|
||||||
Name: "Alice",
|
|
||||||
Email: "alice@example.com",
|
|
||||||
})
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.Equal(t, "sub-1", u.Sub)
|
|
||||||
|
|
||||||
got, err := s.GetOidcUserInfo(ctx, "sub-1")
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.Equal(t, u, got)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Get OIDC user info not found",
|
|
||||||
run: func(t *testing.T, s repository.Store) {
|
|
||||||
_, err := s.GetOidcUserInfo(ctx, "missing")
|
|
||||||
assert.ErrorIs(t, err, repository.ErrNotFound)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Delete OIDC user info",
|
|
||||||
run: func(t *testing.T, s repository.Store) {
|
|
||||||
_, err := s.CreateOidcUserInfo(ctx, repository.CreateOidcUserInfoParams{Sub: "sub-1"})
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
require.NoError(t, s.DeleteOidcUserInfo(ctx, "sub-1"))
|
|
||||||
|
|
||||||
_, err = s.GetOidcUserInfo(ctx, "sub-1")
|
|
||||||
assert.ErrorIs(t, err, repository.ErrNotFound)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
|
|||||||
@@ -1,7 +1,3 @@
|
|||||||
//go:build exclude
|
|
||||||
|
|
||||||
// temporary
|
|
||||||
|
|
||||||
package memory
|
package memory
|
||||||
|
|
||||||
import (
|
import (
|
||||||
@@ -11,235 +7,90 @@ import (
|
|||||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (s *Store) CreateOidcCode(_ context.Context, arg repository.CreateOidcCodeParams) (repository.OidcCode, error) {
|
func (s *Store) CreateOIDCSession(_ context.Context, arg repository.CreateOIDCSessionParams) (repository.OidcSession, error) {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
// Enforce sub UNIQUE constraint
|
// Enforce UNIQUE constraints (sub is the primary key, access/refresh token hashes are unique).
|
||||||
for _, c := range s.oidcCodes {
|
for _, sess := range s.oidcSessions {
|
||||||
if c.Sub == arg.Sub {
|
switch {
|
||||||
return repository.OidcCode{}, fmt.Errorf("UNIQUE constraint failed: oidc_codes.sub")
|
case sess.Sub == arg.Sub:
|
||||||
|
return repository.OidcSession{}, fmt.Errorf("UNIQUE constraint failed: oidc_sessions.sub")
|
||||||
|
case sess.AccessTokenHash == arg.AccessTokenHash:
|
||||||
|
return repository.OidcSession{}, fmt.Errorf("UNIQUE constraint failed: oidc_sessions.access_token_hash")
|
||||||
|
case sess.RefreshTokenHash == arg.RefreshTokenHash:
|
||||||
|
return repository.OidcSession{}, fmt.Errorf("UNIQUE constraint failed: oidc_sessions.refresh_token_hash")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
code := repository.OidcCode(arg)
|
sess := repository.OidcSession(arg)
|
||||||
s.oidcCodes[arg.CodeHash] = code
|
s.oidcSessions[arg.Sub] = sess
|
||||||
return code, nil
|
return sess, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetOidcCode is a destructive read: it deletes and returns the code (mirrors SQLite's DELETE...RETURNING).
|
func (s *Store) GetOIDCSessionBySub(_ context.Context, sub string) (repository.OidcSession, error) {
|
||||||
func (s *Store) GetOidcCode(_ context.Context, codeHash string) (repository.OidcCode, error) {
|
s.mu.RLock()
|
||||||
s.mu.Lock()
|
defer s.mu.RUnlock()
|
||||||
defer s.mu.Unlock()
|
sess, ok := s.oidcSessions[sub]
|
||||||
c, ok := s.oidcCodes[codeHash]
|
|
||||||
if !ok {
|
if !ok {
|
||||||
return repository.OidcCode{}, repository.ErrNotFound
|
return repository.OidcSession{}, repository.ErrNotFound
|
||||||
}
|
}
|
||||||
delete(s.oidcCodes, codeHash)
|
return sess, nil
|
||||||
return c, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetOidcCodeBySub is a destructive read: it deletes and returns the code (mirrors SQLite's DELETE...RETURNING).
|
func (s *Store) GetOIDCSessionByAccessTokenHash(_ context.Context, accessTokenHash string) (repository.OidcSession, error) {
|
||||||
func (s *Store) GetOidcCodeBySub(_ context.Context, sub string) (repository.OidcCode, error) {
|
|
||||||
s.mu.Lock()
|
|
||||||
defer s.mu.Unlock()
|
|
||||||
for k, c := range s.oidcCodes {
|
|
||||||
if c.Sub == sub {
|
|
||||||
delete(s.oidcCodes, k)
|
|
||||||
return c, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return repository.OidcCode{}, repository.ErrNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetOidcCodeUnsafe is a non-destructive read (mirrors SQLite's SELECT).
|
|
||||||
func (s *Store) GetOidcCodeUnsafe(_ context.Context, codeHash string) (repository.OidcCode, error) {
|
|
||||||
s.mu.RLock()
|
s.mu.RLock()
|
||||||
defer s.mu.RUnlock()
|
defer s.mu.RUnlock()
|
||||||
c, ok := s.oidcCodes[codeHash]
|
for _, sess := range s.oidcSessions {
|
||||||
|
if sess.AccessTokenHash == accessTokenHash {
|
||||||
|
return sess, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return repository.OidcSession{}, repository.ErrNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) GetOIDCSessionByRefreshTokenHash(_ context.Context, refreshTokenHash string) (repository.OidcSession, error) {
|
||||||
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
for _, sess := range s.oidcSessions {
|
||||||
|
if sess.RefreshTokenHash == refreshTokenHash {
|
||||||
|
return sess, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return repository.OidcSession{}, repository.ErrNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) UpdateOIDCSession(_ context.Context, arg repository.UpdateOIDCSessionParams) (repository.OidcSession, error) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
sess, ok := s.oidcSessions[arg.Sub]
|
||||||
if !ok {
|
if !ok {
|
||||||
return repository.OidcCode{}, repository.ErrNotFound
|
return repository.OidcSession{}, repository.ErrNotFound
|
||||||
}
|
}
|
||||||
return c, nil
|
sess.AccessTokenHash = arg.AccessTokenHash
|
||||||
|
sess.RefreshTokenHash = arg.RefreshTokenHash
|
||||||
|
sess.Scope = arg.Scope
|
||||||
|
sess.ClientID = arg.ClientID
|
||||||
|
sess.TokenExpiresAt = arg.TokenExpiresAt
|
||||||
|
sess.RefreshTokenExpiresAt = arg.RefreshTokenExpiresAt
|
||||||
|
sess.Nonce = arg.Nonce
|
||||||
|
sess.UserinfoJson = arg.UserinfoJson
|
||||||
|
s.oidcSessions[arg.Sub] = sess
|
||||||
|
return sess, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetOidcCodeBySubUnsafe is a non-destructive read (mirrors SQLite's SELECT).
|
func (s *Store) DeleteOIDCSessionBySub(_ context.Context, sub string) error {
|
||||||
func (s *Store) GetOidcCodeBySubUnsafe(_ context.Context, sub string) (repository.OidcCode, error) {
|
|
||||||
s.mu.RLock()
|
|
||||||
defer s.mu.RUnlock()
|
|
||||||
for _, c := range s.oidcCodes {
|
|
||||||
if c.Sub == sub {
|
|
||||||
return c, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return repository.OidcCode{}, repository.ErrNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Store) DeleteOidcCode(_ context.Context, codeHash string) error {
|
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
delete(s.oidcCodes, codeHash)
|
delete(s.oidcSessions, sub)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Store) DeleteOidcCodeBySub(_ context.Context, sub string) error {
|
func (s *Store) DeleteExpiredOIDCSessions(_ context.Context, arg repository.DeleteExpiredOIDCSessionsParams) error {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
for k, c := range s.oidcCodes {
|
for k, sess := range s.oidcSessions {
|
||||||
if c.Sub == sub {
|
if sess.TokenExpiresAt < arg.TokenExpiresAt && sess.RefreshTokenExpiresAt < arg.RefreshTokenExpiresAt {
|
||||||
delete(s.oidcCodes, k)
|
delete(s.oidcSessions, k)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Store) DeleteExpiredOidcCodes(_ context.Context, expiresAt int64) ([]repository.OidcCode, error) {
|
|
||||||
s.mu.Lock()
|
|
||||||
defer s.mu.Unlock()
|
|
||||||
var deleted []repository.OidcCode
|
|
||||||
for k, c := range s.oidcCodes {
|
|
||||||
if c.ExpiresAt < expiresAt {
|
|
||||||
deleted = append(deleted, c)
|
|
||||||
delete(s.oidcCodes, k)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return deleted, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Store) CreateOidcToken(_ context.Context, arg repository.CreateOidcTokenParams) (repository.OidcToken, error) {
|
|
||||||
s.mu.Lock()
|
|
||||||
defer s.mu.Unlock()
|
|
||||||
// Enforce sub UNIQUE constraint
|
|
||||||
for _, t := range s.oidcTokens {
|
|
||||||
if t.Sub == arg.Sub {
|
|
||||||
return repository.OidcToken{}, fmt.Errorf("UNIQUE constraint failed: oidc_tokens.sub")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
tok := repository.OidcToken{
|
|
||||||
Sub: arg.Sub,
|
|
||||||
AccessTokenHash: arg.AccessTokenHash,
|
|
||||||
RefreshTokenHash: arg.RefreshTokenHash,
|
|
||||||
CodeHash: arg.CodeHash,
|
|
||||||
Scope: arg.Scope,
|
|
||||||
ClientID: arg.ClientID,
|
|
||||||
TokenExpiresAt: arg.TokenExpiresAt,
|
|
||||||
RefreshTokenExpiresAt: arg.RefreshTokenExpiresAt,
|
|
||||||
Nonce: arg.Nonce,
|
|
||||||
}
|
|
||||||
s.oidcTokens[arg.AccessTokenHash] = tok
|
|
||||||
return tok, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Store) GetOidcToken(_ context.Context, accessTokenHash string) (repository.OidcToken, error) {
|
|
||||||
s.mu.RLock()
|
|
||||||
defer s.mu.RUnlock()
|
|
||||||
t, ok := s.oidcTokens[accessTokenHash]
|
|
||||||
if !ok {
|
|
||||||
return repository.OidcToken{}, repository.ErrNotFound
|
|
||||||
}
|
|
||||||
return t, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Store) GetOidcTokenByRefreshToken(_ context.Context, refreshTokenHash string) (repository.OidcToken, error) {
|
|
||||||
s.mu.RLock()
|
|
||||||
defer s.mu.RUnlock()
|
|
||||||
for _, t := range s.oidcTokens {
|
|
||||||
if t.RefreshTokenHash == refreshTokenHash {
|
|
||||||
return t, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return repository.OidcToken{}, repository.ErrNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Store) GetOidcTokenBySub(_ context.Context, sub string) (repository.OidcToken, error) {
|
|
||||||
s.mu.RLock()
|
|
||||||
defer s.mu.RUnlock()
|
|
||||||
for _, t := range s.oidcTokens {
|
|
||||||
if t.Sub == sub {
|
|
||||||
return t, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return repository.OidcToken{}, repository.ErrNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Store) UpdateOidcTokenByRefreshToken(_ context.Context, arg repository.UpdateOidcTokenByRefreshTokenParams) (repository.OidcToken, error) {
|
|
||||||
s.mu.Lock()
|
|
||||||
defer s.mu.Unlock()
|
|
||||||
for k, t := range s.oidcTokens {
|
|
||||||
if t.RefreshTokenHash == arg.RefreshTokenHash_2 {
|
|
||||||
delete(s.oidcTokens, k)
|
|
||||||
t.AccessTokenHash = arg.AccessTokenHash
|
|
||||||
t.RefreshTokenHash = arg.RefreshTokenHash
|
|
||||||
t.TokenExpiresAt = arg.TokenExpiresAt
|
|
||||||
t.RefreshTokenExpiresAt = arg.RefreshTokenExpiresAt
|
|
||||||
s.oidcTokens[arg.AccessTokenHash] = t
|
|
||||||
return t, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return repository.OidcToken{}, repository.ErrNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Store) DeleteOidcToken(_ context.Context, accessTokenHash string) error {
|
|
||||||
s.mu.Lock()
|
|
||||||
defer s.mu.Unlock()
|
|
||||||
delete(s.oidcTokens, accessTokenHash)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Store) DeleteOidcTokenBySub(_ context.Context, sub string) error {
|
|
||||||
s.mu.Lock()
|
|
||||||
defer s.mu.Unlock()
|
|
||||||
for k, t := range s.oidcTokens {
|
|
||||||
if t.Sub == sub {
|
|
||||||
delete(s.oidcTokens, k)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Store) DeleteOidcTokenByCodeHash(_ context.Context, codeHash string) error {
|
|
||||||
s.mu.Lock()
|
|
||||||
defer s.mu.Unlock()
|
|
||||||
for k, t := range s.oidcTokens {
|
|
||||||
if t.CodeHash == codeHash {
|
|
||||||
delete(s.oidcTokens, k)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Store) DeleteExpiredOidcTokens(_ context.Context, arg repository.DeleteExpiredOidcTokensParams) ([]repository.OidcToken, error) {
|
|
||||||
s.mu.Lock()
|
|
||||||
defer s.mu.Unlock()
|
|
||||||
var deleted []repository.OidcToken
|
|
||||||
for k, t := range s.oidcTokens {
|
|
||||||
if t.TokenExpiresAt < arg.TokenExpiresAt && t.RefreshTokenExpiresAt < arg.RefreshTokenExpiresAt {
|
|
||||||
deleted = append(deleted, t)
|
|
||||||
delete(s.oidcTokens, k)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return deleted, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Store) CreateOidcUserInfo(_ context.Context, arg repository.CreateOidcUserInfoParams) (repository.OidcUserinfo, error) {
|
|
||||||
s.mu.Lock()
|
|
||||||
defer s.mu.Unlock()
|
|
||||||
u := repository.OidcUserinfo(arg)
|
|
||||||
s.oidcUsers[arg.Sub] = u
|
|
||||||
return u, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Store) GetOidcUserInfo(_ context.Context, sub string) (repository.OidcUserinfo, error) {
|
|
||||||
s.mu.RLock()
|
|
||||||
defer s.mu.RUnlock()
|
|
||||||
u, ok := s.oidcUsers[sub]
|
|
||||||
if !ok {
|
|
||||||
return repository.OidcUserinfo{}, repository.ErrNotFound
|
|
||||||
}
|
|
||||||
return u, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Store) DeleteOidcUserInfo(_ context.Context, sub string) error {
|
|
||||||
s.mu.Lock()
|
|
||||||
defer s.mu.Unlock()
|
|
||||||
delete(s.oidcUsers, sub)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,7 +1,3 @@
|
|||||||
//go:build exclude
|
|
||||||
|
|
||||||
// temporary
|
|
||||||
|
|
||||||
package memory
|
package memory
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
|||||||
@@ -1,7 +1,3 @@
|
|||||||
//go:build exclude
|
|
||||||
|
|
||||||
// temporary
|
|
||||||
|
|
||||||
// Package memory provides an in-memory implementation of repository.Store for use in tests.
|
// Package memory provides an in-memory implementation of repository.Store for use in tests.
|
||||||
package memory
|
package memory
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
// Code generated by sqlc. DO NOT EDIT.
|
// Code generated by sqlc. DO NOT EDIT.
|
||||||
// versions:
|
// versions:
|
||||||
// sqlc v1.30.0
|
// sqlc v1.31.1
|
||||||
|
|
||||||
package postgres
|
package postgres
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
// Code generated by sqlc. DO NOT EDIT.
|
// Code generated by sqlc. DO NOT EDIT.
|
||||||
// versions:
|
// versions:
|
||||||
// sqlc v1.30.0
|
// sqlc v1.31.1
|
||||||
|
|
||||||
package postgres
|
package postgres
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
// Code generated by sqlc. DO NOT EDIT.
|
// Code generated by sqlc. DO NOT EDIT.
|
||||||
// versions:
|
// versions:
|
||||||
// sqlc v1.30.0
|
// sqlc v1.31.1
|
||||||
// source: oidc_queries.sql
|
// source: oidc_queries.sql
|
||||||
|
|
||||||
package postgres
|
package postgres
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
// Code generated by sqlc. DO NOT EDIT.
|
// Code generated by sqlc. DO NOT EDIT.
|
||||||
// versions:
|
// versions:
|
||||||
// sqlc v1.30.0
|
// sqlc v1.31.1
|
||||||
// source: session_queries.sql
|
// source: session_queries.sql
|
||||||
|
|
||||||
package postgres
|
package postgres
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
// Code generated by sqlc. DO NOT EDIT.
|
// Code generated by sqlc. DO NOT EDIT.
|
||||||
// versions:
|
// versions:
|
||||||
// sqlc v1.30.0
|
// sqlc v1.31.1
|
||||||
|
|
||||||
package sqlite
|
package sqlite
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
// Code generated by sqlc. DO NOT EDIT.
|
// Code generated by sqlc. DO NOT EDIT.
|
||||||
// versions:
|
// versions:
|
||||||
// sqlc v1.30.0
|
// sqlc v1.31.1
|
||||||
|
|
||||||
package sqlite
|
package sqlite
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
// Code generated by sqlc. DO NOT EDIT.
|
// Code generated by sqlc. DO NOT EDIT.
|
||||||
// versions:
|
// versions:
|
||||||
// sqlc v1.30.0
|
// sqlc v1.31.1
|
||||||
// source: oidc_queries.sql
|
// source: oidc_queries.sql
|
||||||
|
|
||||||
package sqlite
|
package sqlite
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
// Code generated by sqlc. DO NOT EDIT.
|
// Code generated by sqlc. DO NOT EDIT.
|
||||||
// versions:
|
// versions:
|
||||||
// sqlc v1.30.0
|
// sqlc v1.31.1
|
||||||
// source: session_queries.sql
|
// source: session_queries.sql
|
||||||
|
|
||||||
package sqlite
|
package sqlite
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ type GithubEmailResponse []struct {
|
|||||||
Verified bool `json:"verified"`
|
Verified bool `json:"verified"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type GithubUserInfoResponse struct {
|
type GithubUserinfoResponse struct {
|
||||||
Login string `json:"login"`
|
Login string `json:"login"`
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
ID int `json:"id"`
|
ID int `json:"id"`
|
||||||
@@ -30,7 +30,7 @@ func defaultExtractor(client *http.Client, url string) (*model.Claims, error) {
|
|||||||
func githubExtractor(client *http.Client, _ string) (*model.Claims, error) {
|
func githubExtractor(client *http.Client, _ string) (*model.Claims, error) {
|
||||||
var user model.Claims
|
var user model.Claims
|
||||||
|
|
||||||
userInfo, err := simpleReq[GithubUserInfoResponse](client, "https://api.github.com/user", map[string]string{
|
userInfo, err := simpleReq[GithubUserinfoResponse](client, "https://api.github.com/user", map[string]string{
|
||||||
"accept": "application/vnd.github+json",
|
"accept": "application/vnd.github+json",
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -10,13 +10,13 @@ import (
|
|||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
)
|
)
|
||||||
|
|
||||||
type UserinfoExtractor func(client *http.Client, url string) (*model.Claims, error)
|
type OAuthUserinfoExtractor func(client *http.Client, url string) (*model.Claims, error)
|
||||||
|
|
||||||
type OAuthService struct {
|
type OAuthService struct {
|
||||||
serviceCfg model.OAuthServiceConfig
|
serviceCfg model.OAuthServiceConfig
|
||||||
config *oauth2.Config
|
config *oauth2.Config
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
userinfoExtractor UserinfoExtractor
|
userinfoExtractor OAuthUserinfoExtractor
|
||||||
id string
|
id string
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -50,7 +50,7 @@ func NewOAuthService(config model.OAuthServiceConfig, id string, ctx context.Con
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *OAuthService) WithUserinfoExtractor(extractor UserinfoExtractor) *OAuthService {
|
func (s *OAuthService) WithUserinfoExtractor(extractor OAuthUserinfoExtractor) *OAuthService {
|
||||||
s.userinfoExtractor = extractor
|
s.userinfoExtractor = extractor
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -126,6 +126,10 @@ type AuthorizeCodeEntry struct {
|
|||||||
Userinfo UserinfoResponse
|
Userinfo UserinfoResponse
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type UsedCodeEntry struct {
|
||||||
|
Sub string
|
||||||
|
}
|
||||||
|
|
||||||
type OIDCService struct {
|
type OIDCService struct {
|
||||||
log *logger.Logger
|
log *logger.Logger
|
||||||
config model.Config
|
config model.Config
|
||||||
@@ -139,6 +143,7 @@ type OIDCService struct {
|
|||||||
|
|
||||||
caches struct {
|
caches struct {
|
||||||
code *CacheStore[AuthorizeCodeEntry]
|
code *CacheStore[AuthorizeCodeEntry]
|
||||||
|
usedCode *CacheStore[UsedCodeEntry]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -301,11 +306,13 @@ func NewOIDCService(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Start cleanup routine
|
// Start cleanup routine
|
||||||
// dg.Go(service.cleanupRoutine, ding.RingMinor)
|
dg.Go(service.cleanupRoutine, ding.RingMinor)
|
||||||
|
|
||||||
// Create caches
|
// Create caches
|
||||||
codeCash := NewCacheStore[AuthorizeCodeEntry](256)
|
codeCash := NewCacheStore[AuthorizeCodeEntry](256)
|
||||||
|
usedCode := NewCacheStore[UsedCodeEntry](256)
|
||||||
service.caches.code = codeCash
|
service.caches.code = codeCash
|
||||||
|
service.caches.usedCode = usedCode
|
||||||
|
|
||||||
// Start cache cleanup routine
|
// Start cache cleanup routine
|
||||||
dg.Go(func(ctx context.Context) {
|
dg.Go(func(ctx context.Context) {
|
||||||
@@ -316,6 +323,7 @@ func NewOIDCService(
|
|||||||
select {
|
select {
|
||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
service.caches.code.Sweep()
|
service.caches.code.Sweep()
|
||||||
|
service.caches.usedCode.Sweep()
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -406,7 +414,7 @@ func (service *OIDCService) CreateCode(req AuthorizeRequest, userContext model.U
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Store the code in the cache
|
// Store the code in the cache
|
||||||
service.caches.code.Set(entry.CodeHash, entry, 10*time.Minute)
|
service.caches.code.Set(entry.CodeHash, entry, 1*time.Minute)
|
||||||
|
|
||||||
return code
|
return code
|
||||||
}
|
}
|
||||||
@@ -676,7 +684,7 @@ func (service *OIDCService) GetSessionByToken(ctx context.Context, tokenHash str
|
|||||||
// since there is no way for the client to access anything anymore
|
// since there is no way for the client to access anything anymore
|
||||||
if entry.RefreshTokenExpiresAt < time.Now().Unix() {
|
if entry.RefreshTokenExpiresAt < time.Now().Unix() {
|
||||||
// Deletes by sub
|
// Deletes by sub
|
||||||
err := service.queries.DeleteSession(ctx, entry.Sub)
|
err := service.queries.DeleteOIDCSessionBySub(ctx, entry.Sub)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -747,68 +755,35 @@ func (service *OIDCService) DeleteOldSession(ctx context.Context, sub string) er
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// // Cleanup routine - Resource heavy due to the linked tables
|
func (service *OIDCService) cleanupRoutine(ctx context.Context) {
|
||||||
// func (service *OIDCService) cleanupRoutine(ctx context.Context) {
|
service.log.App.Debug().Msg("Starting OIDC cleanup routine")
|
||||||
// service.log.App.Debug().Msg("Starting OIDC cleanup routine")
|
ticker := time.NewTicker(30 * time.Minute)
|
||||||
// ticker := time.NewTicker(time.Duration(30) * time.Minute)
|
defer ticker.Stop()
|
||||||
// defer ticker.Stop()
|
|
||||||
|
|
||||||
// for {
|
for {
|
||||||
// select {
|
select {
|
||||||
// case <-ticker.C:
|
case <-ticker.C:
|
||||||
// service.log.App.Debug().Msg("Performing OIDC cleanup routine")
|
service.log.App.Debug().Msg("Performing OIDC cleanup routine")
|
||||||
|
|
||||||
// currentTime := time.Now().Unix()
|
currentTime := time.Now().Unix()
|
||||||
|
|
||||||
// // For the OIDC tokens, if they are expired we delete the userinfo and codes
|
// Limitation of sqlc, meaning we need to specify a timestamp for both token and refresh token expiry
|
||||||
// expiredTokens, err := service.queries.DeleteExpiredOidcTokens(ctx, repository.DeleteExpiredOidcTokensParams{
|
err := service.queries.DeleteExpiredOIDCSessions(ctx, repository.DeleteExpiredOIDCSessionsParams{
|
||||||
// TokenExpiresAt: currentTime,
|
TokenExpiresAt: currentTime,
|
||||||
// RefreshTokenExpiresAt: currentTime,
|
RefreshTokenExpiresAt: currentTime,
|
||||||
// })
|
})
|
||||||
|
|
||||||
// if err != nil {
|
if err != nil {
|
||||||
// service.log.App.Warn().Err(err).Msg("Failed to delete expired tokens")
|
service.log.App.Warn().Err(err).Msg("Failed to delete expired OIDC sessions")
|
||||||
// }
|
}
|
||||||
|
|
||||||
// for _, expiredToken := range expiredTokens {
|
service.log.App.Debug().Msg("Finished OIDC cleanup routine")
|
||||||
// err := service.DeleteOldSession(ctx, expiredToken.Sub)
|
case <-ctx.Done():
|
||||||
// if err != nil {
|
service.log.App.Debug().Msg("Stopping OIDC cleanup routine")
|
||||||
// service.log.App.Warn().Err(err).Msg("Failed to delete session for expired token")
|
return
|
||||||
// }
|
}
|
||||||
// }
|
}
|
||||||
|
}
|
||||||
// // For expired codes, we need to get the sub, check if tokens are expired and if they are remove everything
|
|
||||||
// expiredCodes, err := service.queries.DeleteExpiredOidcCodes(ctx, currentTime)
|
|
||||||
|
|
||||||
// if err != nil {
|
|
||||||
// service.log.App.Warn().Err(err).Msg("Failed to delete expired codes")
|
|
||||||
// }
|
|
||||||
|
|
||||||
// for _, expiredCode := range expiredCodes {
|
|
||||||
// token, err := service.queries.GetOidcTokenBySub(ctx, expiredCode.Sub)
|
|
||||||
|
|
||||||
// if err != nil {
|
|
||||||
// if !errors.Is(err, repository.ErrNotFound) {
|
|
||||||
// service.log.App.Warn().Err(err).Msg("Failed to get token by sub for expired code")
|
|
||||||
// }
|
|
||||||
// continue
|
|
||||||
// }
|
|
||||||
|
|
||||||
// if token.TokenExpiresAt < currentTime && token.RefreshTokenExpiresAt < currentTime {
|
|
||||||
// err := service.DeleteOldSession(ctx, expiredCode.Sub)
|
|
||||||
// if err != nil {
|
|
||||||
// service.log.App.Warn().Err(err).Msg("Failed to delete session for expired code")
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
|
|
||||||
// service.log.App.Debug().Msg("Finished OIDC cleanup routine")
|
|
||||||
// case <-ctx.Done():
|
|
||||||
// service.log.App.Debug().Msg("Stopping OIDC cleanup routine")
|
|
||||||
// return
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
|
|
||||||
func (service *OIDCService) GetJWK() ([]byte, error) {
|
func (service *OIDCService) GetJWK() ([]byte, error) {
|
||||||
hasher := sha256.New()
|
hasher := sha256.New()
|
||||||
@@ -850,3 +825,24 @@ func (service *OIDCService) hashAndEncodePKCE(codeVerifier string) string {
|
|||||||
func (service *OIDCService) CreateSub(userContext model.UserContext, clientId string) string {
|
func (service *OIDCService) CreateSub(userContext model.UserContext, clientId string) string {
|
||||||
return utils.GenerateUUID(fmt.Sprintf("%s:%s", userContext.GetUsername(), clientId))
|
return utils.GenerateUUID(fmt.Sprintf("%s:%s", userContext.GetUsername(), clientId))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (service *OIDCService) IsCodeUsed(codeHash string) (string, bool) {
|
||||||
|
entry, ok := service.caches.usedCode.Get(codeHash)
|
||||||
|
|
||||||
|
if !ok {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
return entry.Sub, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (service *OIDCService) MarkCodeAsUsed(codeHash string, sub string) {
|
||||||
|
entry := UsedCodeEntry{
|
||||||
|
Sub: sub,
|
||||||
|
}
|
||||||
|
service.caches.usedCode.Set(codeHash, entry, 2*time.Minute)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (service *OIDCService) DeleteSessionBySub(ctx context.Context, sub string) error {
|
||||||
|
return service.queries.DeleteOIDCSessionBySub(ctx, sub)
|
||||||
|
}
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ package service_test
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/steveiliop56/ding"
|
"github.com/steveiliop56/ding"
|
||||||
@@ -10,28 +9,17 @@ import (
|
|||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||||
)
|
)
|
||||||
|
|
||||||
func newTestUser() repository.OidcUserinfo {
|
func newTestUser() service.UserinfoResponse {
|
||||||
addr := model.AddressClaim{
|
return service.UserinfoResponse{
|
||||||
Formatted: "123 Main St",
|
|
||||||
StreetAddress: "123 Main St",
|
|
||||||
Locality: "Springfield",
|
|
||||||
Region: "IL",
|
|
||||||
PostalCode: "62701",
|
|
||||||
Country: "US",
|
|
||||||
}
|
|
||||||
addrJSON, _ := json.Marshal(addr)
|
|
||||||
|
|
||||||
return repository.OidcUserinfo{
|
|
||||||
Sub: "test-sub",
|
Sub: "test-sub",
|
||||||
Name: "Test User",
|
Name: "Test User",
|
||||||
PreferredUsername: "testuser",
|
PreferredUsername: "testuser",
|
||||||
Email: "test@example.com",
|
Email: "test@example.com",
|
||||||
Groups: "admins,users",
|
Groups: []string{"admins", "users"},
|
||||||
UpdatedAt: 1234567890,
|
UpdatedAt: 1234567890,
|
||||||
GivenName: "Test",
|
GivenName: "Test",
|
||||||
FamilyName: "User",
|
FamilyName: "User",
|
||||||
@@ -45,7 +33,14 @@ func newTestUser() repository.OidcUserinfo {
|
|||||||
Zoneinfo: "America/Chicago",
|
Zoneinfo: "America/Chicago",
|
||||||
Locale: "en-US",
|
Locale: "en-US",
|
||||||
PhoneNumber: "+15555550100",
|
PhoneNumber: "+15555550100",
|
||||||
Address: string(addrJSON),
|
Address: &model.AddressClaim{
|
||||||
|
Formatted: "123 Main St",
|
||||||
|
StreetAddress: "123 Main St",
|
||||||
|
Locality: "Springfield",
|
||||||
|
Region: "IL",
|
||||||
|
PostalCode: "62701",
|
||||||
|
Country: "US",
|
||||||
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -77,7 +72,7 @@ func TestCompileUserinfo(t *testing.T) {
|
|||||||
|
|
||||||
type testCase struct {
|
type testCase struct {
|
||||||
description string
|
description string
|
||||||
mutate func(u *repository.OidcUserinfo)
|
mutate func(u *service.UserinfoResponse)
|
||||||
scope string
|
scope string
|
||||||
run func(t *testing.T, info service.UserinfoResponse)
|
run func(t *testing.T, info service.UserinfoResponse)
|
||||||
}
|
}
|
||||||
@@ -98,7 +93,7 @@ func TestCompileUserinfo(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "profile scope returns all profile fields",
|
description: "profile scope returns all profile fields",
|
||||||
scope: "openid,profile",
|
scope: "openid profile",
|
||||||
run: func(t *testing.T, info service.UserinfoResponse) {
|
run: func(t *testing.T, info service.UserinfoResponse) {
|
||||||
assert.Equal(t, "Test User", info.Name)
|
assert.Equal(t, "Test User", info.Name)
|
||||||
assert.Equal(t, "testuser", info.PreferredUsername)
|
assert.Equal(t, "testuser", info.PreferredUsername)
|
||||||
@@ -118,7 +113,7 @@ func TestCompileUserinfo(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "email scope sets email and email_verified true when email present",
|
description: "email scope sets email and email_verified true when email present",
|
||||||
scope: "openid,email",
|
scope: "openid email",
|
||||||
run: func(t *testing.T, info service.UserinfoResponse) {
|
run: func(t *testing.T, info service.UserinfoResponse) {
|
||||||
assert.Equal(t, "test@example.com", info.Email)
|
assert.Equal(t, "test@example.com", info.Email)
|
||||||
assert.True(t, info.EmailVerified)
|
assert.True(t, info.EmailVerified)
|
||||||
@@ -127,8 +122,8 @@ func TestCompileUserinfo(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "email scope sets email_verified false when email absent",
|
description: "email scope sets email_verified false when email absent",
|
||||||
scope: "openid,email",
|
scope: "openid email",
|
||||||
mutate: func(u *repository.OidcUserinfo) { u.Email = "" },
|
mutate: func(u *service.UserinfoResponse) { u.Email = "" },
|
||||||
run: func(t *testing.T, info service.UserinfoResponse) {
|
run: func(t *testing.T, info service.UserinfoResponse) {
|
||||||
assert.Empty(t, info.Email)
|
assert.Empty(t, info.Email)
|
||||||
assert.False(t, info.EmailVerified)
|
assert.False(t, info.EmailVerified)
|
||||||
@@ -136,7 +131,7 @@ func TestCompileUserinfo(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "phone scope sets phone_number_verified true when phone present",
|
description: "phone scope sets phone_number_verified true when phone present",
|
||||||
scope: "openid,phone",
|
scope: "openid phone",
|
||||||
run: func(t *testing.T, info service.UserinfoResponse) {
|
run: func(t *testing.T, info service.UserinfoResponse) {
|
||||||
assert.Equal(t, "+15555550100", info.PhoneNumber)
|
assert.Equal(t, "+15555550100", info.PhoneNumber)
|
||||||
require.NotNil(t, info.PhoneNumberVerified)
|
require.NotNil(t, info.PhoneNumberVerified)
|
||||||
@@ -145,8 +140,8 @@ func TestCompileUserinfo(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "phone scope sets phone_number_verified false when phone absent",
|
description: "phone scope sets phone_number_verified false when phone absent",
|
||||||
scope: "openid,phone",
|
scope: "openid phone",
|
||||||
mutate: func(u *repository.OidcUserinfo) { u.PhoneNumber = "" },
|
mutate: func(u *service.UserinfoResponse) { u.PhoneNumber = "" },
|
||||||
run: func(t *testing.T, info service.UserinfoResponse) {
|
run: func(t *testing.T, info service.UserinfoResponse) {
|
||||||
require.NotNil(t, info.PhoneNumberVerified)
|
require.NotNil(t, info.PhoneNumberVerified)
|
||||||
assert.False(t, *info.PhoneNumberVerified)
|
assert.False(t, *info.PhoneNumberVerified)
|
||||||
@@ -154,7 +149,7 @@ func TestCompileUserinfo(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "address scope returns parsed address",
|
description: "address scope returns parsed address",
|
||||||
scope: "openid,address",
|
scope: "openid address",
|
||||||
run: func(t *testing.T, info service.UserinfoResponse) {
|
run: func(t *testing.T, info service.UserinfoResponse) {
|
||||||
require.NotNil(t, info.Address)
|
require.NotNil(t, info.Address)
|
||||||
assert.Equal(t, "123 Main St", info.Address.Formatted)
|
assert.Equal(t, "123 Main St", info.Address.Formatted)
|
||||||
@@ -165,32 +160,16 @@ func TestCompileUserinfo(t *testing.T) {
|
|||||||
assert.Equal(t, "US", info.Address.Country)
|
assert.Equal(t, "US", info.Address.Country)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
|
||||||
description: "address scope with invalid JSON omits address",
|
|
||||||
scope: "openid,address",
|
|
||||||
mutate: func(u *repository.OidcUserinfo) { u.Address = "not-valid-json" },
|
|
||||||
run: func(t *testing.T, info service.UserinfoResponse) {
|
|
||||||
assert.Nil(t, info.Address)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
description: "groups scope returns split groups",
|
description: "groups scope returns split groups",
|
||||||
scope: "openid,groups",
|
scope: "openid groups",
|
||||||
run: func(t *testing.T, info service.UserinfoResponse) {
|
run: func(t *testing.T, info service.UserinfoResponse) {
|
||||||
assert.Equal(t, []string{"admins", "users"}, info.Groups)
|
assert.Equal(t, []string{"admins", "users"}, info.Groups)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
|
||||||
description: "groups scope returns empty slice when no groups",
|
|
||||||
scope: "openid,groups",
|
|
||||||
mutate: func(u *repository.OidcUserinfo) { u.Groups = "" },
|
|
||||||
run: func(t *testing.T, info service.UserinfoResponse) {
|
|
||||||
assert.Equal(t, []string{}, info.Groups)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
description: "all scopes return all fields",
|
description: "all scopes return all fields",
|
||||||
scope: "openid,profile,email,phone,address,groups",
|
scope: "openid profile email phone address groups",
|
||||||
run: func(t *testing.T, info service.UserinfoResponse) {
|
run: func(t *testing.T, info service.UserinfoResponse) {
|
||||||
assert.Equal(t, "Test User", info.Name)
|
assert.Equal(t, "Test User", info.Name)
|
||||||
assert.Equal(t, "test@example.com", info.Email)
|
assert.Equal(t, "test@example.com", info.Email)
|
||||||
|
|||||||
@@ -6,6 +6,6 @@ CREATE TABLE IF NOT EXISTS "oidc_sessions" (
|
|||||||
"client_id" TEXT NOT NULL,
|
"client_id" TEXT NOT NULL,
|
||||||
"token_expires_at" INTEGER NOT NULL,
|
"token_expires_at" INTEGER NOT NULL,
|
||||||
"refresh_token_expires_at" INTEGER NOT NULL,
|
"refresh_token_expires_at" INTEGER NOT NULL,
|
||||||
"nonce" TEXT DEFAULT "",
|
"nonce" TEXT NOT NULL DEFAULT "",
|
||||||
"userinfo_json" TEXT NOT NULL
|
"userinfo_json" TEXT NOT NULL
|
||||||
);
|
);
|
||||||
|
|||||||
Reference in New Issue
Block a user