refactor memeory_test.go

This commit is contained in:
Scott McKendry
2026-05-17 18:27:18 +12:00
parent ac1ff0a07f
commit 225041126e
+461 -410
View File
@@ -13,415 +13,466 @@ import (
var ctx = context.Background() var ctx = context.Background()
func TestCreateAndGetSession(t *testing.T) { func TestMemoryStore(t *testing.T) {
s := memory.New() type testCase struct {
sess, err := s.CreateSession(ctx, repository.CreateSessionParams{ description string
UUID: "uuid-1", run func(t *testing.T, s repository.Store)
Username: "alice", }
Expiry: 9999,
})
require.NoError(t, err)
assert.Equal(t, "uuid-1", sess.UUID)
assert.Equal(t, "alice", sess.Username)
got, err := s.GetSession(ctx, "uuid-1") tests := []testCase{
require.NoError(t, err) {
assert.Equal(t, sess, got) description: "Create and get session",
} run: func(t *testing.T, s repository.Store) {
sess, err := s.CreateSession(ctx, repository.CreateSessionParams{
func TestGetSession_NotFound(t *testing.T) { UUID: "uuid-1",
s := memory.New() Username: "alice",
_, err := s.GetSession(ctx, "missing") Expiry: 9999,
assert.ErrorIs(t, err, repository.ErrNotFound) })
} require.NoError(t, err)
assert.Equal(t, "uuid-1", sess.UUID)
func TestUpdateSession(t *testing.T) { assert.Equal(t, "alice", sess.Username)
s := memory.New()
_, err := s.CreateSession(ctx, repository.CreateSessionParams{UUID: "uuid-1", Username: "alice"}) got, err := s.GetSession(ctx, "uuid-1")
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, sess, got)
updated, err := s.UpdateSession(ctx, repository.UpdateSessionParams{ },
UUID: "uuid-1", },
Username: "bob", {
Email: "bob@example.com", description: "Get session not found",
}) run: func(t *testing.T, s repository.Store) {
require.NoError(t, err) _, err := s.GetSession(ctx, "missing")
assert.Equal(t, "bob", updated.Username) assert.ErrorIs(t, err, repository.ErrNotFound)
assert.Equal(t, "bob@example.com", updated.Email) },
},
got, err := s.GetSession(ctx, "uuid-1") {
require.NoError(t, err) description: "Update session",
assert.Equal(t, updated, got) run: func(t *testing.T, s repository.Store) {
} _, err := s.CreateSession(ctx, repository.CreateSessionParams{UUID: "uuid-1", Username: "alice"})
require.NoError(t, err)
func TestUpdateSession_NotFound(t *testing.T) {
s := memory.New() updated, err := s.UpdateSession(ctx, repository.UpdateSessionParams{
_, err := s.UpdateSession(ctx, repository.UpdateSessionParams{UUID: "missing"}) UUID: "uuid-1",
assert.ErrorIs(t, err, repository.ErrNotFound) Username: "bob",
} Email: "bob@example.com",
})
func TestDeleteSession(t *testing.T) { require.NoError(t, err)
s := memory.New() assert.Equal(t, "bob", updated.Username)
_, err := s.CreateSession(ctx, repository.CreateSessionParams{UUID: "uuid-1"}) assert.Equal(t, "bob@example.com", updated.Email)
require.NoError(t, err)
got, err := s.GetSession(ctx, "uuid-1")
require.NoError(t, s.DeleteSession(ctx, "uuid-1")) require.NoError(t, err)
assert.Equal(t, updated, got)
_, err = s.GetSession(ctx, "uuid-1") },
assert.ErrorIs(t, err, repository.ErrNotFound) },
} {
description: "Update session not found",
func TestDeleteExpiredSessions(t *testing.T) { run: func(t *testing.T, s repository.Store) {
s := memory.New() _, err := s.UpdateSession(ctx, repository.UpdateSessionParams{UUID: "missing"})
_, err := s.CreateSession(ctx, repository.CreateSessionParams{UUID: "expired", Expiry: 10}) assert.ErrorIs(t, err, repository.ErrNotFound)
require.NoError(t, err) },
_, err = s.CreateSession(ctx, repository.CreateSessionParams{UUID: "valid", Expiry: 100}) },
require.NoError(t, err) {
description: "Delete session",
require.NoError(t, s.DeleteExpiredSessions(ctx, 50)) run: func(t *testing.T, s repository.Store) {
_, err := s.CreateSession(ctx, repository.CreateSessionParams{UUID: "uuid-1"})
_, err = s.GetSession(ctx, "expired") require.NoError(t, err)
assert.ErrorIs(t, err, repository.ErrNotFound)
require.NoError(t, s.DeleteSession(ctx, "uuid-1"))
_, err = s.GetSession(ctx, "valid")
assert.NoError(t, err) _, err = s.GetSession(ctx, "uuid-1")
} assert.ErrorIs(t, err, repository.ErrNotFound)
},
func TestCreateAndGetOidcCode(t *testing.T) { },
s := memory.New() {
code, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{ description: "Delete expired sessions",
Sub: "sub-1", run: func(t *testing.T, s repository.Store) {
CodeHash: "hash-1", _, err := s.CreateSession(ctx, repository.CreateSessionParams{UUID: "expired", Expiry: 10})
Scope: "openid", require.NoError(t, err)
}) _, err = s.CreateSession(ctx, repository.CreateSessionParams{UUID: "valid", Expiry: 100})
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "sub-1", code.Sub)
require.NoError(t, s.DeleteExpiredSessions(ctx, 50))
// destructive read removes the record
got, err := s.GetOidcCode(ctx, "hash-1") _, err = s.GetSession(ctx, "expired")
require.NoError(t, err) assert.ErrorIs(t, err, repository.ErrNotFound)
assert.Equal(t, code, got)
_, err = s.GetSession(ctx, "valid")
_, err = s.GetOidcCode(ctx, "hash-1") assert.NoError(t, err)
assert.ErrorIs(t, err, repository.ErrNotFound) },
} },
{
func TestGetOidcCode_NotFound(t *testing.T) { description: "Create and get OIDC code",
s := memory.New() run: func(t *testing.T, s repository.Store) {
_, err := s.GetOidcCode(ctx, "missing") code, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{
assert.ErrorIs(t, err, repository.ErrNotFound) Sub: "sub-1",
} CodeHash: "hash-1",
Scope: "openid",
func TestGetOidcCodeBySub(t *testing.T) { })
s := memory.New() require.NoError(t, err)
_, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"}) assert.Equal(t, "sub-1", code.Sub)
require.NoError(t, err)
// destructive read removes the record
got, err := s.GetOidcCodeBySub(ctx, "sub-1") got, err := s.GetOidcCode(ctx, "hash-1")
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "sub-1", got.Sub) assert.Equal(t, code, got)
// destructive — gone after read _, err = s.GetOidcCode(ctx, "hash-1")
_, err = s.GetOidcCodeBySub(ctx, "sub-1") assert.ErrorIs(t, err, repository.ErrNotFound)
assert.ErrorIs(t, err, repository.ErrNotFound) },
} },
{
func TestGetOidcCodeBySub_NotFound(t *testing.T) { description: "Get OIDC code not found",
s := memory.New() run: func(t *testing.T, s repository.Store) {
_, err := s.GetOidcCodeBySub(ctx, "missing") _, err := s.GetOidcCode(ctx, "missing")
assert.ErrorIs(t, err, repository.ErrNotFound) assert.ErrorIs(t, err, repository.ErrNotFound)
} },
},
func TestGetOidcCodeUnsafe(t *testing.T) { {
s := memory.New() description: "Get OIDC code by sub",
_, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"}) run: func(t *testing.T, s repository.Store) {
require.NoError(t, err) _, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"})
require.NoError(t, err)
got, err := s.GetOidcCodeUnsafe(ctx, "hash-1")
require.NoError(t, err) got, err := s.GetOidcCodeBySub(ctx, "sub-1")
assert.Equal(t, "sub-1", got.Sub) require.NoError(t, err)
assert.Equal(t, "sub-1", got.Sub)
// non-destructive — still present
_, err = s.GetOidcCodeUnsafe(ctx, "hash-1") // destructive — gone after read
assert.NoError(t, err) _, err = s.GetOidcCodeBySub(ctx, "sub-1")
} assert.ErrorIs(t, err, repository.ErrNotFound)
},
func TestGetOidcCodeUnsafe_NotFound(t *testing.T) { },
s := memory.New() {
_, err := s.GetOidcCodeUnsafe(ctx, "missing") description: "Get OIDC code by sub not found",
assert.ErrorIs(t, err, repository.ErrNotFound) run: func(t *testing.T, s repository.Store) {
} _, err := s.GetOidcCodeBySub(ctx, "missing")
assert.ErrorIs(t, err, repository.ErrNotFound)
func TestGetOidcCodeBySubUnsafe(t *testing.T) { },
s := memory.New() },
_, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"}) {
require.NoError(t, err) description: "Get OIDC code unsafe",
run: func(t *testing.T, s repository.Store) {
got, err := s.GetOidcCodeBySubUnsafe(ctx, "sub-1") _, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"})
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "hash-1", got.CodeHash)
got, err := s.GetOidcCodeUnsafe(ctx, "hash-1")
// non-destructive — still present require.NoError(t, err)
_, err = s.GetOidcCodeBySubUnsafe(ctx, "sub-1") assert.Equal(t, "sub-1", got.Sub)
assert.NoError(t, err)
} // non-destructive — still present
_, err = s.GetOidcCodeUnsafe(ctx, "hash-1")
func TestGetOidcCodeBySubUnsafe_NotFound(t *testing.T) { assert.NoError(t, err)
s := memory.New() },
_, err := s.GetOidcCodeBySubUnsafe(ctx, "missing") },
assert.ErrorIs(t, err, repository.ErrNotFound) {
} description: "Get OIDC code unsafe not found",
run: func(t *testing.T, s repository.Store) {
func TestCreateOidcCode_UniqueSubConstraint(t *testing.T) { _, err := s.GetOidcCodeUnsafe(ctx, "missing")
s := memory.New() assert.ErrorIs(t, err, repository.ErrNotFound)
_, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"}) },
require.NoError(t, err) },
{
_, err = s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-2"}) description: "Get OIDC code by sub unsafe",
assert.ErrorContains(t, err, "UNIQUE constraint failed: oidc_codes.sub") run: func(t *testing.T, s repository.Store) {
} _, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"})
require.NoError(t, err)
func TestDeleteOidcCode(t *testing.T) {
s := memory.New() got, err := s.GetOidcCodeBySubUnsafe(ctx, "sub-1")
_, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"}) require.NoError(t, err)
require.NoError(t, err) assert.Equal(t, "hash-1", got.CodeHash)
require.NoError(t, s.DeleteOidcCode(ctx, "hash-1")) // non-destructive — still present
_, err = s.GetOidcCodeBySubUnsafe(ctx, "sub-1")
_, err = s.GetOidcCodeUnsafe(ctx, "hash-1") assert.NoError(t, err)
assert.ErrorIs(t, err, repository.ErrNotFound) },
} },
{
func TestDeleteOidcCodeBySub(t *testing.T) { description: "Get OIDC code by sub unsafe not found",
s := memory.New() run: func(t *testing.T, s repository.Store) {
_, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"}) _, err := s.GetOidcCodeBySubUnsafe(ctx, "missing")
require.NoError(t, err) assert.ErrorIs(t, err, repository.ErrNotFound)
},
require.NoError(t, s.DeleteOidcCodeBySub(ctx, "sub-1")) },
{
_, err = s.GetOidcCodeUnsafe(ctx, "hash-1") description: "Create OIDC code unique sub constraint",
assert.ErrorIs(t, err, repository.ErrNotFound) run: func(t *testing.T, s repository.Store) {
} _, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"})
require.NoError(t, err)
func TestDeleteExpiredOidcCodes(t *testing.T) {
s := memory.New() _, err = s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-2"})
_, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1", ExpiresAt: 10}) assert.ErrorContains(t, err, "UNIQUE constraint failed: oidc_codes.sub")
require.NoError(t, err) },
_, err = s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-2", CodeHash: "hash-2", ExpiresAt: 100}) },
require.NoError(t, err) {
description: "Delete OIDC code",
deleted, err := s.DeleteExpiredOidcCodes(ctx, 50) run: func(t *testing.T, s repository.Store) {
require.NoError(t, err) _, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"})
require.Len(t, deleted, 1) require.NoError(t, err)
assert.Equal(t, "hash-1", deleted[0].CodeHash)
require.NoError(t, s.DeleteOidcCode(ctx, "hash-1"))
_, err = s.GetOidcCodeUnsafe(ctx, "hash-2")
assert.NoError(t, err) _, err = s.GetOidcCodeUnsafe(ctx, "hash-1")
} assert.ErrorIs(t, err, repository.ErrNotFound)
},
func TestCreateAndGetOidcToken(t *testing.T) { },
s := memory.New() {
tok, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{ description: "Delete OIDC code by sub",
Sub: "sub-1", run: func(t *testing.T, s repository.Store) {
AccessTokenHash: "at-hash-1", _, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"})
CodeHash: "code-hash-1", require.NoError(t, err)
})
require.NoError(t, err) require.NoError(t, s.DeleteOidcCodeBySub(ctx, "sub-1"))
assert.Equal(t, "sub-1", tok.Sub)
_, err = s.GetOidcCodeUnsafe(ctx, "hash-1")
got, err := s.GetOidcToken(ctx, "at-hash-1") assert.ErrorIs(t, err, repository.ErrNotFound)
require.NoError(t, err) },
assert.Equal(t, tok, got) },
} {
description: "Delete expired OIDC codes",
func TestGetOidcToken_NotFound(t *testing.T) { run: func(t *testing.T, s repository.Store) {
s := memory.New() _, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1", ExpiresAt: 10})
_, err := s.GetOidcToken(ctx, "missing") require.NoError(t, err)
assert.ErrorIs(t, err, repository.ErrNotFound) _, err = s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-2", CodeHash: "hash-2", ExpiresAt: 100})
} require.NoError(t, err)
func TestCreateOidcToken_UniqueSubConstraint(t *testing.T) { deleted, err := s.DeleteExpiredOidcCodes(ctx, 50)
s := memory.New() require.NoError(t, err)
_, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{Sub: "sub-1", AccessTokenHash: "at-1"}) require.Len(t, deleted, 1)
require.NoError(t, err) assert.Equal(t, "hash-1", deleted[0].CodeHash)
_, err = s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{Sub: "sub-1", AccessTokenHash: "at-2"}) _, err = s.GetOidcCodeUnsafe(ctx, "hash-2")
assert.ErrorContains(t, err, "UNIQUE constraint failed: oidc_tokens.sub") assert.NoError(t, err)
} },
},
func TestGetOidcTokenByRefreshToken(t *testing.T) { {
s := memory.New() description: "Create and get OIDC token",
_, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{ run: func(t *testing.T, s repository.Store) {
Sub: "sub-1", tok, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{
AccessTokenHash: "at-1", Sub: "sub-1",
RefreshTokenHash: "rt-1", AccessTokenHash: "at-hash-1",
}) CodeHash: "code-hash-1",
require.NoError(t, err) })
require.NoError(t, err)
got, err := s.GetOidcTokenByRefreshToken(ctx, "rt-1") assert.Equal(t, "sub-1", tok.Sub)
require.NoError(t, err)
assert.Equal(t, "sub-1", got.Sub) got, err := s.GetOidcToken(ctx, "at-hash-1")
} require.NoError(t, err)
assert.Equal(t, tok, got)
func TestGetOidcTokenByRefreshToken_NotFound(t *testing.T) { },
s := memory.New() },
_, err := s.GetOidcTokenByRefreshToken(ctx, "missing") {
assert.ErrorIs(t, err, repository.ErrNotFound) description: "Get OIDC token not found",
} run: func(t *testing.T, s repository.Store) {
_, err := s.GetOidcToken(ctx, "missing")
func TestGetOidcTokenBySub(t *testing.T) { assert.ErrorIs(t, err, repository.ErrNotFound)
s := memory.New() },
_, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{ },
Sub: "sub-1", {
AccessTokenHash: "at-1", description: "Create OIDC token unique sub constraint",
}) run: func(t *testing.T, s repository.Store) {
require.NoError(t, err) _, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{Sub: "sub-1", AccessTokenHash: "at-1"})
require.NoError(t, err)
got, err := s.GetOidcTokenBySub(ctx, "sub-1")
require.NoError(t, err) _, err = s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{Sub: "sub-1", AccessTokenHash: "at-2"})
assert.Equal(t, "at-1", got.AccessTokenHash) assert.ErrorContains(t, err, "UNIQUE constraint failed: oidc_tokens.sub")
} },
},
func TestGetOidcTokenBySub_NotFound(t *testing.T) { {
s := memory.New() description: "Get OIDC token by refresh token",
_, err := s.GetOidcTokenBySub(ctx, "missing") run: func(t *testing.T, s repository.Store) {
assert.ErrorIs(t, err, repository.ErrNotFound) _, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{
} Sub: "sub-1",
AccessTokenHash: "at-1",
func TestUpdateOidcTokenByRefreshToken(t *testing.T) { RefreshTokenHash: "rt-1",
s := memory.New() })
_, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{ require.NoError(t, err)
Sub: "sub-1",
AccessTokenHash: "at-1", got, err := s.GetOidcTokenByRefreshToken(ctx, "rt-1")
RefreshTokenHash: "rt-1", require.NoError(t, err)
}) assert.Equal(t, "sub-1", got.Sub)
require.NoError(t, err) },
},
updated, err := s.UpdateOidcTokenByRefreshToken(ctx, repository.UpdateOidcTokenByRefreshTokenParams{ {
RefreshTokenHash_2: "rt-1", description: "Get OIDC token by refresh token not found",
AccessTokenHash: "at-2", run: func(t *testing.T, s repository.Store) {
RefreshTokenHash: "rt-2", _, err := s.GetOidcTokenByRefreshToken(ctx, "missing")
TokenExpiresAt: 200, assert.ErrorIs(t, err, repository.ErrNotFound)
RefreshTokenExpiresAt: 400, },
}) },
require.NoError(t, err) {
assert.Equal(t, "at-2", updated.AccessTokenHash) description: "Get OIDC token by sub",
assert.Equal(t, "rt-2", updated.RefreshTokenHash) run: func(t *testing.T, s repository.Store) {
_, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{
// old key gone, new key present Sub: "sub-1",
_, err = s.GetOidcToken(ctx, "at-1") AccessTokenHash: "at-1",
assert.ErrorIs(t, err, repository.ErrNotFound) })
require.NoError(t, err)
got, err := s.GetOidcToken(ctx, "at-2")
require.NoError(t, err) got, err := s.GetOidcTokenBySub(ctx, "sub-1")
assert.Equal(t, "sub-1", got.Sub) require.NoError(t, err)
} assert.Equal(t, "at-1", got.AccessTokenHash)
},
func TestUpdateOidcTokenByRefreshToken_NotFound(t *testing.T) { },
s := memory.New() {
_, err := s.UpdateOidcTokenByRefreshToken(ctx, repository.UpdateOidcTokenByRefreshTokenParams{ description: "Get OIDC token by sub not found",
RefreshTokenHash_2: "missing", run: func(t *testing.T, s repository.Store) {
}) _, err := s.GetOidcTokenBySub(ctx, "missing")
assert.ErrorIs(t, err, repository.ErrNotFound) assert.ErrorIs(t, err, repository.ErrNotFound)
} },
},
func TestDeleteOidcToken(t *testing.T) { {
s := memory.New() description: "Update OIDC token by refresh token",
_, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{Sub: "sub-1", AccessTokenHash: "at-1"}) run: func(t *testing.T, s repository.Store) {
require.NoError(t, err) _, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{
Sub: "sub-1",
require.NoError(t, s.DeleteOidcToken(ctx, "at-1")) AccessTokenHash: "at-1",
RefreshTokenHash: "rt-1",
_, err = s.GetOidcToken(ctx, "at-1") })
assert.ErrorIs(t, err, repository.ErrNotFound) require.NoError(t, err)
}
updated, err := s.UpdateOidcTokenByRefreshToken(ctx, repository.UpdateOidcTokenByRefreshTokenParams{
func TestDeleteOidcTokenBySub(t *testing.T) { RefreshTokenHash_2: "rt-1",
s := memory.New() AccessTokenHash: "at-2",
_, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{Sub: "sub-1", AccessTokenHash: "at-1"}) RefreshTokenHash: "rt-2",
require.NoError(t, err) TokenExpiresAt: 200,
RefreshTokenExpiresAt: 400,
require.NoError(t, s.DeleteOidcTokenBySub(ctx, "sub-1")) })
require.NoError(t, err)
_, err = s.GetOidcToken(ctx, "at-1") assert.Equal(t, "at-2", updated.AccessTokenHash)
assert.ErrorIs(t, err, repository.ErrNotFound) assert.Equal(t, "rt-2", updated.RefreshTokenHash)
}
// old key gone, new key present
func TestDeleteOidcTokenByCodeHash(t *testing.T) { _, err = s.GetOidcToken(ctx, "at-1")
s := memory.New() assert.ErrorIs(t, err, repository.ErrNotFound)
_, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{
Sub: "sub-1", got, err := s.GetOidcToken(ctx, "at-2")
AccessTokenHash: "at-1", require.NoError(t, err)
CodeHash: "code-1", assert.Equal(t, "sub-1", got.Sub)
}) },
require.NoError(t, err) },
{
require.NoError(t, s.DeleteOidcTokenByCodeHash(ctx, "code-1")) description: "Update OIDC token by refresh token not found",
run: func(t *testing.T, s repository.Store) {
_, err = s.GetOidcToken(ctx, "at-1") _, err := s.UpdateOidcTokenByRefreshToken(ctx, repository.UpdateOidcTokenByRefreshTokenParams{
assert.ErrorIs(t, err, repository.ErrNotFound) RefreshTokenHash_2: "missing",
} })
assert.ErrorIs(t, err, repository.ErrNotFound)
func TestDeleteExpiredOidcTokens(t *testing.T) { },
s := memory.New() },
// expired by TokenExpiresAt {
_, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{ description: "Delete OIDC token",
Sub: "sub-1", AccessTokenHash: "at-1", run: func(t *testing.T, s repository.Store) {
TokenExpiresAt: 10, RefreshTokenExpiresAt: 100, _, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{Sub: "sub-1", AccessTokenHash: "at-1"})
}) require.NoError(t, err)
require.NoError(t, err)
// expired by RefreshTokenExpiresAt require.NoError(t, s.DeleteOidcToken(ctx, "at-1"))
_, err = s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{
Sub: "sub-2", AccessTokenHash: "at-2", _, err = s.GetOidcToken(ctx, "at-1")
TokenExpiresAt: 100, RefreshTokenExpiresAt: 10, assert.ErrorIs(t, err, repository.ErrNotFound)
}) },
require.NoError(t, err) },
// valid {
_, err = s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{ description: "Delete OIDC token by sub",
Sub: "sub-3", AccessTokenHash: "at-3", run: func(t *testing.T, s repository.Store) {
TokenExpiresAt: 100, RefreshTokenExpiresAt: 100, _, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{Sub: "sub-1", AccessTokenHash: "at-1"})
}) require.NoError(t, err)
require.NoError(t, err)
require.NoError(t, s.DeleteOidcTokenBySub(ctx, "sub-1"))
deleted, err := s.DeleteExpiredOidcTokens(ctx, repository.DeleteExpiredOidcTokensParams{
TokenExpiresAt: 50, _, err = s.GetOidcToken(ctx, "at-1")
RefreshTokenExpiresAt: 50, assert.ErrorIs(t, err, repository.ErrNotFound)
}) },
require.NoError(t, err) },
assert.Len(t, deleted, 2) {
description: "Delete OIDC token by code hash",
_, err = s.GetOidcToken(ctx, "at-3") run: func(t *testing.T, s repository.Store) {
assert.NoError(t, err) _, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{
} Sub: "sub-1",
AccessTokenHash: "at-1",
func TestCreateAndGetOidcUserInfo(t *testing.T) { CodeHash: "code-1",
s := memory.New() })
u, err := s.CreateOidcUserInfo(ctx, repository.CreateOidcUserInfoParams{ require.NoError(t, err)
Sub: "sub-1",
Name: "Alice", require.NoError(t, s.DeleteOidcTokenByCodeHash(ctx, "code-1"))
Email: "alice@example.com",
}) _, err = s.GetOidcToken(ctx, "at-1")
require.NoError(t, err) assert.ErrorIs(t, err, repository.ErrNotFound)
assert.Equal(t, "sub-1", u.Sub) },
},
got, err := s.GetOidcUserInfo(ctx, "sub-1") {
require.NoError(t, err) description: "Delete expired OIDC tokens",
assert.Equal(t, u, got) run: func(t *testing.T, s repository.Store) {
} // expired by TokenExpiresAt
_, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{
func TestGetOidcUserInfo_NotFound(t *testing.T) { Sub: "sub-1", AccessTokenHash: "at-1",
s := memory.New() TokenExpiresAt: 10, RefreshTokenExpiresAt: 100,
_, err := s.GetOidcUserInfo(ctx, "missing") })
assert.ErrorIs(t, err, repository.ErrNotFound) require.NoError(t, err)
} // expired by RefreshTokenExpiresAt
_, err = s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{
func TestDeleteOidcUserInfo(t *testing.T) { Sub: "sub-2", AccessTokenHash: "at-2",
s := memory.New() TokenExpiresAt: 100, RefreshTokenExpiresAt: 10,
_, err := s.CreateOidcUserInfo(ctx, repository.CreateOidcUserInfoParams{Sub: "sub-1"}) })
require.NoError(t, err) require.NoError(t, err)
// valid
require.NoError(t, s.DeleteOidcUserInfo(ctx, "sub-1")) _, err = s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{
Sub: "sub-3", AccessTokenHash: "at-3",
_, err = s.GetOidcUserInfo(ctx, "sub-1") TokenExpiresAt: 100, RefreshTokenExpiresAt: 100,
assert.ErrorIs(t, err, repository.ErrNotFound) })
require.NoError(t, err)
deleted, err := s.DeleteExpiredOidcTokens(ctx, repository.DeleteExpiredOidcTokensParams{
TokenExpiresAt: 50,
RefreshTokenExpiresAt: 50,
})
require.NoError(t, err)
assert.Len(t, deleted, 2)
_, err = s.GetOidcToken(ctx, "at-3")
assert.NoError(t, err)
},
},
{
description: "Create and get OIDC user info",
run: func(t *testing.T, s repository.Store) {
u, err := s.CreateOidcUserInfo(ctx, repository.CreateOidcUserInfoParams{
Sub: "sub-1",
Name: "Alice",
Email: "alice@example.com",
})
require.NoError(t, err)
assert.Equal(t, "sub-1", u.Sub)
got, err := s.GetOidcUserInfo(ctx, "sub-1")
require.NoError(t, err)
assert.Equal(t, u, got)
},
},
{
description: "Get OIDC user info not found",
run: func(t *testing.T, s repository.Store) {
_, err := s.GetOidcUserInfo(ctx, "missing")
assert.ErrorIs(t, err, repository.ErrNotFound)
},
},
{
description: "Delete OIDC user info",
run: func(t *testing.T, s repository.Store) {
_, err := s.CreateOidcUserInfo(ctx, repository.CreateOidcUserInfoParams{Sub: "sub-1"})
require.NoError(t, err)
require.NoError(t, s.DeleteOidcUserInfo(ctx, "sub-1"))
_, err = s.GetOidcUserInfo(ctx, "sub-1")
assert.ErrorIs(t, err, repository.ErrNotFound)
},
},
}
for _, test := range tests {
t.Run(test.description, func(t *testing.T) {
s := memory.New()
test.run(t, s)
})
}
} }