mirror of
https://github.com/steveiliop56/tinyauth.git
synced 2026-05-10 06:18:11 +00:00
test(db): add memory driver tests
This commit is contained in:
@@ -0,0 +1,427 @@
|
|||||||
|
package memory_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
|
||||||
|
)
|
||||||
|
|
||||||
|
var ctx = context.Background()
|
||||||
|
|
||||||
|
func TestCreateAndGetSession(t *testing.T) {
|
||||||
|
s := memory.New()
|
||||||
|
sess, err := s.CreateSession(ctx, repository.CreateSessionParams{
|
||||||
|
UUID: "uuid-1",
|
||||||
|
Username: "alice",
|
||||||
|
Expiry: 9999,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "uuid-1", sess.UUID)
|
||||||
|
assert.Equal(t, "alice", sess.Username)
|
||||||
|
|
||||||
|
got, err := s.GetSession(ctx, "uuid-1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, sess, got)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetSession_NotFound(t *testing.T) {
|
||||||
|
s := memory.New()
|
||||||
|
_, err := s.GetSession(ctx, "missing")
|
||||||
|
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateSession(t *testing.T) {
|
||||||
|
s := memory.New()
|
||||||
|
_, err := s.CreateSession(ctx, repository.CreateSessionParams{UUID: "uuid-1", Username: "alice"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
updated, err := s.UpdateSession(ctx, repository.UpdateSessionParams{
|
||||||
|
UUID: "uuid-1",
|
||||||
|
Username: "bob",
|
||||||
|
Email: "bob@example.com",
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "bob", updated.Username)
|
||||||
|
assert.Equal(t, "bob@example.com", updated.Email)
|
||||||
|
|
||||||
|
got, err := s.GetSession(ctx, "uuid-1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, updated, got)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateSession_NotFound(t *testing.T) {
|
||||||
|
s := memory.New()
|
||||||
|
_, err := s.UpdateSession(ctx, repository.UpdateSessionParams{UUID: "missing"})
|
||||||
|
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeleteSession(t *testing.T) {
|
||||||
|
s := memory.New()
|
||||||
|
_, err := s.CreateSession(ctx, repository.CreateSessionParams{UUID: "uuid-1"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.NoError(t, s.DeleteSession(ctx, "uuid-1"))
|
||||||
|
|
||||||
|
_, err = s.GetSession(ctx, "uuid-1")
|
||||||
|
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeleteExpiredSessions(t *testing.T) {
|
||||||
|
s := memory.New()
|
||||||
|
_, err := s.CreateSession(ctx, repository.CreateSessionParams{UUID: "expired", Expiry: 10})
|
||||||
|
require.NoError(t, err)
|
||||||
|
_, err = s.CreateSession(ctx, repository.CreateSessionParams{UUID: "valid", Expiry: 100})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.NoError(t, s.DeleteExpiredSessions(ctx, 50))
|
||||||
|
|
||||||
|
_, err = s.GetSession(ctx, "expired")
|
||||||
|
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||||
|
|
||||||
|
_, err = s.GetSession(ctx, "valid")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateAndGetOidcCode(t *testing.T) {
|
||||||
|
s := memory.New()
|
||||||
|
code, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{
|
||||||
|
Sub: "sub-1",
|
||||||
|
CodeHash: "hash-1",
|
||||||
|
Scope: "openid",
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "sub-1", code.Sub)
|
||||||
|
|
||||||
|
// destructive read removes the record
|
||||||
|
got, err := s.GetOidcCode(ctx, "hash-1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, code, got)
|
||||||
|
|
||||||
|
_, err = s.GetOidcCode(ctx, "hash-1")
|
||||||
|
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetOidcCode_NotFound(t *testing.T) {
|
||||||
|
s := memory.New()
|
||||||
|
_, err := s.GetOidcCode(ctx, "missing")
|
||||||
|
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetOidcCodeBySub(t *testing.T) {
|
||||||
|
s := memory.New()
|
||||||
|
_, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
got, err := s.GetOidcCodeBySub(ctx, "sub-1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "sub-1", got.Sub)
|
||||||
|
|
||||||
|
// destructive — gone after read
|
||||||
|
_, err = s.GetOidcCodeBySub(ctx, "sub-1")
|
||||||
|
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetOidcCodeBySub_NotFound(t *testing.T) {
|
||||||
|
s := memory.New()
|
||||||
|
_, err := s.GetOidcCodeBySub(ctx, "missing")
|
||||||
|
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetOidcCodeUnsafe(t *testing.T) {
|
||||||
|
s := memory.New()
|
||||||
|
_, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
got, err := s.GetOidcCodeUnsafe(ctx, "hash-1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "sub-1", got.Sub)
|
||||||
|
|
||||||
|
// non-destructive — still present
|
||||||
|
_, err = s.GetOidcCodeUnsafe(ctx, "hash-1")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetOidcCodeUnsafe_NotFound(t *testing.T) {
|
||||||
|
s := memory.New()
|
||||||
|
_, err := s.GetOidcCodeUnsafe(ctx, "missing")
|
||||||
|
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetOidcCodeBySubUnsafe(t *testing.T) {
|
||||||
|
s := memory.New()
|
||||||
|
_, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
got, err := s.GetOidcCodeBySubUnsafe(ctx, "sub-1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "hash-1", got.CodeHash)
|
||||||
|
|
||||||
|
// non-destructive — still present
|
||||||
|
_, err = s.GetOidcCodeBySubUnsafe(ctx, "sub-1")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetOidcCodeBySubUnsafe_NotFound(t *testing.T) {
|
||||||
|
s := memory.New()
|
||||||
|
_, err := s.GetOidcCodeBySubUnsafe(ctx, "missing")
|
||||||
|
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateOidcCode_UniqueSubConstraint(t *testing.T) {
|
||||||
|
s := memory.New()
|
||||||
|
_, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
_, err = s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-2"})
|
||||||
|
assert.ErrorContains(t, err, "UNIQUE constraint failed: oidc_codes.sub")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeleteOidcCode(t *testing.T) {
|
||||||
|
s := memory.New()
|
||||||
|
_, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.NoError(t, s.DeleteOidcCode(ctx, "hash-1"))
|
||||||
|
|
||||||
|
_, err = s.GetOidcCodeUnsafe(ctx, "hash-1")
|
||||||
|
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeleteOidcCodeBySub(t *testing.T) {
|
||||||
|
s := memory.New()
|
||||||
|
_, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.NoError(t, s.DeleteOidcCodeBySub(ctx, "sub-1"))
|
||||||
|
|
||||||
|
_, err = s.GetOidcCodeUnsafe(ctx, "hash-1")
|
||||||
|
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeleteExpiredOidcCodes(t *testing.T) {
|
||||||
|
s := memory.New()
|
||||||
|
_, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1", ExpiresAt: 10})
|
||||||
|
require.NoError(t, err)
|
||||||
|
_, err = s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-2", CodeHash: "hash-2", ExpiresAt: 100})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
deleted, err := s.DeleteExpiredOidcCodes(ctx, 50)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, deleted, 1)
|
||||||
|
assert.Equal(t, "hash-1", deleted[0].CodeHash)
|
||||||
|
|
||||||
|
_, err = s.GetOidcCodeUnsafe(ctx, "hash-2")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateAndGetOidcToken(t *testing.T) {
|
||||||
|
s := memory.New()
|
||||||
|
tok, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{
|
||||||
|
Sub: "sub-1",
|
||||||
|
AccessTokenHash: "at-hash-1",
|
||||||
|
CodeHash: "code-hash-1",
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "sub-1", tok.Sub)
|
||||||
|
|
||||||
|
got, err := s.GetOidcToken(ctx, "at-hash-1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, tok, got)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetOidcToken_NotFound(t *testing.T) {
|
||||||
|
s := memory.New()
|
||||||
|
_, err := s.GetOidcToken(ctx, "missing")
|
||||||
|
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateOidcToken_UniqueSubConstraint(t *testing.T) {
|
||||||
|
s := memory.New()
|
||||||
|
_, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{Sub: "sub-1", AccessTokenHash: "at-1"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
_, err = s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{Sub: "sub-1", AccessTokenHash: "at-2"})
|
||||||
|
assert.ErrorContains(t, err, "UNIQUE constraint failed: oidc_tokens.sub")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetOidcTokenByRefreshToken(t *testing.T) {
|
||||||
|
s := memory.New()
|
||||||
|
_, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{
|
||||||
|
Sub: "sub-1",
|
||||||
|
AccessTokenHash: "at-1",
|
||||||
|
RefreshTokenHash: "rt-1",
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
got, err := s.GetOidcTokenByRefreshToken(ctx, "rt-1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "sub-1", got.Sub)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetOidcTokenByRefreshToken_NotFound(t *testing.T) {
|
||||||
|
s := memory.New()
|
||||||
|
_, err := s.GetOidcTokenByRefreshToken(ctx, "missing")
|
||||||
|
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetOidcTokenBySub(t *testing.T) {
|
||||||
|
s := memory.New()
|
||||||
|
_, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{
|
||||||
|
Sub: "sub-1",
|
||||||
|
AccessTokenHash: "at-1",
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
got, err := s.GetOidcTokenBySub(ctx, "sub-1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "at-1", got.AccessTokenHash)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetOidcTokenBySub_NotFound(t *testing.T) {
|
||||||
|
s := memory.New()
|
||||||
|
_, err := s.GetOidcTokenBySub(ctx, "missing")
|
||||||
|
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateOidcTokenByRefreshToken(t *testing.T) {
|
||||||
|
s := memory.New()
|
||||||
|
_, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{
|
||||||
|
Sub: "sub-1",
|
||||||
|
AccessTokenHash: "at-1",
|
||||||
|
RefreshTokenHash: "rt-1",
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
updated, err := s.UpdateOidcTokenByRefreshToken(ctx, repository.UpdateOidcTokenByRefreshTokenParams{
|
||||||
|
RefreshTokenHash_2: "rt-1",
|
||||||
|
AccessTokenHash: "at-2",
|
||||||
|
RefreshTokenHash: "rt-2",
|
||||||
|
TokenExpiresAt: 200,
|
||||||
|
RefreshTokenExpiresAt: 400,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "at-2", updated.AccessTokenHash)
|
||||||
|
assert.Equal(t, "rt-2", updated.RefreshTokenHash)
|
||||||
|
|
||||||
|
// old key gone, new key present
|
||||||
|
_, err = s.GetOidcToken(ctx, "at-1")
|
||||||
|
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||||
|
|
||||||
|
got, err := s.GetOidcToken(ctx, "at-2")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "sub-1", got.Sub)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateOidcTokenByRefreshToken_NotFound(t *testing.T) {
|
||||||
|
s := memory.New()
|
||||||
|
_, err := s.UpdateOidcTokenByRefreshToken(ctx, repository.UpdateOidcTokenByRefreshTokenParams{
|
||||||
|
RefreshTokenHash_2: "missing",
|
||||||
|
})
|
||||||
|
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeleteOidcToken(t *testing.T) {
|
||||||
|
s := memory.New()
|
||||||
|
_, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{Sub: "sub-1", AccessTokenHash: "at-1"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.NoError(t, s.DeleteOidcToken(ctx, "at-1"))
|
||||||
|
|
||||||
|
_, err = s.GetOidcToken(ctx, "at-1")
|
||||||
|
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeleteOidcTokenBySub(t *testing.T) {
|
||||||
|
s := memory.New()
|
||||||
|
_, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{Sub: "sub-1", AccessTokenHash: "at-1"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.NoError(t, s.DeleteOidcTokenBySub(ctx, "sub-1"))
|
||||||
|
|
||||||
|
_, err = s.GetOidcToken(ctx, "at-1")
|
||||||
|
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeleteOidcTokenByCodeHash(t *testing.T) {
|
||||||
|
s := memory.New()
|
||||||
|
_, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{
|
||||||
|
Sub: "sub-1",
|
||||||
|
AccessTokenHash: "at-1",
|
||||||
|
CodeHash: "code-1",
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.NoError(t, s.DeleteOidcTokenByCodeHash(ctx, "code-1"))
|
||||||
|
|
||||||
|
_, err = s.GetOidcToken(ctx, "at-1")
|
||||||
|
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeleteExpiredOidcTokens(t *testing.T) {
|
||||||
|
s := memory.New()
|
||||||
|
// expired by TokenExpiresAt
|
||||||
|
_, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{
|
||||||
|
Sub: "sub-1", AccessTokenHash: "at-1",
|
||||||
|
TokenExpiresAt: 10, RefreshTokenExpiresAt: 100,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
// expired by RefreshTokenExpiresAt
|
||||||
|
_, err = s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{
|
||||||
|
Sub: "sub-2", AccessTokenHash: "at-2",
|
||||||
|
TokenExpiresAt: 100, RefreshTokenExpiresAt: 10,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
// valid
|
||||||
|
_, err = s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{
|
||||||
|
Sub: "sub-3", AccessTokenHash: "at-3",
|
||||||
|
TokenExpiresAt: 100, RefreshTokenExpiresAt: 100,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
deleted, err := s.DeleteExpiredOidcTokens(ctx, repository.DeleteExpiredOidcTokensParams{
|
||||||
|
TokenExpiresAt: 50,
|
||||||
|
RefreshTokenExpiresAt: 50,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Len(t, deleted, 2)
|
||||||
|
|
||||||
|
_, err = s.GetOidcToken(ctx, "at-3")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateAndGetOidcUserInfo(t *testing.T) {
|
||||||
|
s := memory.New()
|
||||||
|
u, err := s.CreateOidcUserInfo(ctx, repository.CreateOidcUserInfoParams{
|
||||||
|
Sub: "sub-1",
|
||||||
|
Name: "Alice",
|
||||||
|
Email: "alice@example.com",
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "sub-1", u.Sub)
|
||||||
|
|
||||||
|
got, err := s.GetOidcUserInfo(ctx, "sub-1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, u, got)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetOidcUserInfo_NotFound(t *testing.T) {
|
||||||
|
s := memory.New()
|
||||||
|
_, err := s.GetOidcUserInfo(ctx, "missing")
|
||||||
|
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeleteOidcUserInfo(t *testing.T) {
|
||||||
|
s := memory.New()
|
||||||
|
_, err := s.CreateOidcUserInfo(ctx, repository.CreateOidcUserInfoParams{Sub: "sub-1"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.NoError(t, s.DeleteOidcUserInfo(ctx, "sub-1"))
|
||||||
|
|
||||||
|
_, err = s.GetOidcUserInfo(ctx, "sub-1")
|
||||||
|
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user