From 4fe5de241bc3aab8120839d2aa0f57c7153e3f6b Mon Sep 17 00:00:00 2001 From: Stavros Date: Mon, 1 Jun 2026 11:55:47 +0300 Subject: [PATCH] chore: fix memory store --- internal/repository/memory/memory_test.go | 396 +++++------------- internal/repository/memory/oidc_queries.go | 269 +++--------- internal/repository/memory/session_queries.go | 4 - internal/repository/memory/store.go | 4 - 4 files changed, 164 insertions(+), 509 deletions(-) diff --git a/internal/repository/memory/memory_test.go b/internal/repository/memory/memory_test.go index 07fee88d..558ed234 100644 --- a/internal/repository/memory/memory_test.go +++ b/internal/repository/memory/memory_test.go @@ -1,7 +1,3 @@ -//go:build exclude - -// temporary - package memory_test import ( @@ -105,366 +101,182 @@ func TestMemoryStore(t *testing.T) { }, }, { - description: "Create and get OIDC code", + description: "Create and get OIDC session", run: func(t *testing.T, s repository.Store) { - code, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{ - Sub: "sub-1", - CodeHash: "hash-1", - Scope: "openid", + sess, err := s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{ + Sub: "sub-1", + AccessTokenHash: "at-1", + RefreshTokenHash: "rt-1", + Scope: "openid", }) require.NoError(t, err) - assert.Equal(t, "sub-1", code.Sub) + assert.Equal(t, "sub-1", sess.Sub) - // destructive read removes the record - got, err := s.GetOidcCode(ctx, "hash-1") + got, err := s.GetOIDCSessionBySub(ctx, "sub-1") require.NoError(t, err) - assert.Equal(t, code, got) - - _, err = s.GetOidcCode(ctx, "hash-1") + assert.Equal(t, sess, got) + }, + }, + { + description: "Get OIDC session by sub not found", + run: func(t *testing.T, s repository.Store) { + _, err := s.GetOIDCSessionBySub(ctx, "missing") assert.ErrorIs(t, err, repository.ErrNotFound) }, }, { - description: "Get OIDC code not found", + description: "Get OIDC session by access token hash", run: func(t *testing.T, s repository.Store) { - _, err := s.GetOidcCode(ctx, "missing") - assert.ErrorIs(t, err, repository.ErrNotFound) - }, - }, - { - description: "Get OIDC code by sub", - run: func(t *testing.T, s repository.Store) { - _, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"}) - require.NoError(t, err) - - got, err := s.GetOidcCodeBySub(ctx, "sub-1") - require.NoError(t, err) - assert.Equal(t, "sub-1", got.Sub) - - // destructive — gone after read - _, err = s.GetOidcCodeBySub(ctx, "sub-1") - assert.ErrorIs(t, err, repository.ErrNotFound) - }, - }, - { - description: "Get OIDC code by sub not found", - run: func(t *testing.T, s repository.Store) { - _, err := s.GetOidcCodeBySub(ctx, "missing") - assert.ErrorIs(t, err, repository.ErrNotFound) - }, - }, - { - description: "Get OIDC code unsafe", - run: func(t *testing.T, s repository.Store) { - _, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"}) - require.NoError(t, err) - - got, err := s.GetOidcCodeUnsafe(ctx, "hash-1") - require.NoError(t, err) - assert.Equal(t, "sub-1", got.Sub) - - // non-destructive — still present - _, err = s.GetOidcCodeUnsafe(ctx, "hash-1") - assert.NoError(t, err) - }, - }, - { - description: "Get OIDC code unsafe not found", - run: func(t *testing.T, s repository.Store) { - _, err := s.GetOidcCodeUnsafe(ctx, "missing") - assert.ErrorIs(t, err, repository.ErrNotFound) - }, - }, - { - description: "Get OIDC code by sub unsafe", - run: func(t *testing.T, s repository.Store) { - _, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"}) - require.NoError(t, err) - - got, err := s.GetOidcCodeBySubUnsafe(ctx, "sub-1") - require.NoError(t, err) - assert.Equal(t, "hash-1", got.CodeHash) - - // non-destructive — still present - _, err = s.GetOidcCodeBySubUnsafe(ctx, "sub-1") - assert.NoError(t, err) - }, - }, - { - description: "Get OIDC code by sub unsafe not found", - run: func(t *testing.T, s repository.Store) { - _, err := s.GetOidcCodeBySubUnsafe(ctx, "missing") - assert.ErrorIs(t, err, repository.ErrNotFound) - }, - }, - { - description: "Create OIDC code unique sub constraint", - run: func(t *testing.T, s repository.Store) { - _, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"}) - require.NoError(t, err) - - _, err = s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-2"}) - assert.ErrorContains(t, err, "UNIQUE constraint failed: oidc_codes.sub") - }, - }, - { - description: "Delete OIDC code", - run: func(t *testing.T, s repository.Store) { - _, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"}) - require.NoError(t, err) - - require.NoError(t, s.DeleteOidcCode(ctx, "hash-1")) - - _, err = s.GetOidcCodeUnsafe(ctx, "hash-1") - assert.ErrorIs(t, err, repository.ErrNotFound) - }, - }, - { - description: "Delete OIDC code by sub", - run: func(t *testing.T, s repository.Store) { - _, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"}) - require.NoError(t, err) - - require.NoError(t, s.DeleteOidcCodeBySub(ctx, "sub-1")) - - _, err = s.GetOidcCodeUnsafe(ctx, "hash-1") - assert.ErrorIs(t, err, repository.ErrNotFound) - }, - }, - { - description: "Delete expired OIDC codes", - run: func(t *testing.T, s repository.Store) { - _, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1", ExpiresAt: 10}) - require.NoError(t, err) - _, err = s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-2", CodeHash: "hash-2", ExpiresAt: 100}) - require.NoError(t, err) - - deleted, err := s.DeleteExpiredOidcCodes(ctx, 50) - require.NoError(t, err) - require.Len(t, deleted, 1) - assert.Equal(t, "hash-1", deleted[0].CodeHash) - - _, err = s.GetOidcCodeUnsafe(ctx, "hash-2") - assert.NoError(t, err) - }, - }, - { - description: "Create and get OIDC token", - run: func(t *testing.T, s repository.Store) { - tok, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{ - Sub: "sub-1", - AccessTokenHash: "at-hash-1", - CodeHash: "code-hash-1", - }) - require.NoError(t, err) - assert.Equal(t, "sub-1", tok.Sub) - - got, err := s.GetOidcToken(ctx, "at-hash-1") - require.NoError(t, err) - assert.Equal(t, tok, got) - }, - }, - { - description: "Get OIDC token not found", - run: func(t *testing.T, s repository.Store) { - _, err := s.GetOidcToken(ctx, "missing") - assert.ErrorIs(t, err, repository.ErrNotFound) - }, - }, - { - description: "Create OIDC token unique sub constraint", - run: func(t *testing.T, s repository.Store) { - _, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{Sub: "sub-1", AccessTokenHash: "at-1"}) - require.NoError(t, err) - - _, err = s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{Sub: "sub-1", AccessTokenHash: "at-2"}) - assert.ErrorContains(t, err, "UNIQUE constraint failed: oidc_tokens.sub") - }, - }, - { - description: "Get OIDC token by refresh token", - run: func(t *testing.T, s repository.Store) { - _, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{ + _, err := s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{ Sub: "sub-1", AccessTokenHash: "at-1", RefreshTokenHash: "rt-1", }) require.NoError(t, err) - got, err := s.GetOidcTokenByRefreshToken(ctx, "rt-1") + got, err := s.GetOIDCSessionByAccessTokenHash(ctx, "at-1") require.NoError(t, err) assert.Equal(t, "sub-1", got.Sub) }, }, { - description: "Get OIDC token by refresh token not found", + description: "Get OIDC session by access token hash not found", run: func(t *testing.T, s repository.Store) { - _, err := s.GetOidcTokenByRefreshToken(ctx, "missing") + _, err := s.GetOIDCSessionByAccessTokenHash(ctx, "missing") assert.ErrorIs(t, err, repository.ErrNotFound) }, }, { - description: "Get OIDC token by sub", + description: "Get OIDC session by refresh token hash", run: func(t *testing.T, s repository.Store) { - _, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{ - Sub: "sub-1", - AccessTokenHash: "at-1", - }) - require.NoError(t, err) - - got, err := s.GetOidcTokenBySub(ctx, "sub-1") - require.NoError(t, err) - assert.Equal(t, "at-1", got.AccessTokenHash) - }, - }, - { - description: "Get OIDC token by sub not found", - run: func(t *testing.T, s repository.Store) { - _, err := s.GetOidcTokenBySub(ctx, "missing") - assert.ErrorIs(t, err, repository.ErrNotFound) - }, - }, - { - description: "Update OIDC token by refresh token", - run: func(t *testing.T, s repository.Store) { - _, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{ + _, err := s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{ Sub: "sub-1", AccessTokenHash: "at-1", RefreshTokenHash: "rt-1", }) require.NoError(t, err) - updated, err := s.UpdateOidcTokenByRefreshToken(ctx, repository.UpdateOidcTokenByRefreshTokenParams{ - RefreshTokenHash_2: "rt-1", + got, err := s.GetOIDCSessionByRefreshTokenHash(ctx, "rt-1") + require.NoError(t, err) + assert.Equal(t, "sub-1", got.Sub) + }, + }, + { + description: "Get OIDC session by refresh token hash not found", + run: func(t *testing.T, s repository.Store) { + _, err := s.GetOIDCSessionByRefreshTokenHash(ctx, "missing") + assert.ErrorIs(t, err, repository.ErrNotFound) + }, + }, + { + description: "Create OIDC session unique sub constraint", + run: func(t *testing.T, s repository.Store) { + _, err := s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{Sub: "sub-1", AccessTokenHash: "at-1", RefreshTokenHash: "rt-1"}) + require.NoError(t, err) + + _, err = s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{Sub: "sub-1", AccessTokenHash: "at-2", RefreshTokenHash: "rt-2"}) + assert.ErrorContains(t, err, "UNIQUE constraint failed: oidc_sessions.sub") + }, + }, + { + description: "Create OIDC session unique access token hash constraint", + run: func(t *testing.T, s repository.Store) { + _, err := s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{Sub: "sub-1", AccessTokenHash: "at-1", RefreshTokenHash: "rt-1"}) + require.NoError(t, err) + + _, err = s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{Sub: "sub-2", AccessTokenHash: "at-1", RefreshTokenHash: "rt-2"}) + assert.ErrorContains(t, err, "UNIQUE constraint failed: oidc_sessions.access_token_hash") + }, + }, + { + description: "Create OIDC session unique refresh token hash constraint", + run: func(t *testing.T, s repository.Store) { + _, err := s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{Sub: "sub-1", AccessTokenHash: "at-1", RefreshTokenHash: "rt-1"}) + require.NoError(t, err) + + _, err = s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{Sub: "sub-2", AccessTokenHash: "at-2", RefreshTokenHash: "rt-1"}) + assert.ErrorContains(t, err, "UNIQUE constraint failed: oidc_sessions.refresh_token_hash") + }, + }, + { + description: "Update OIDC session", + run: func(t *testing.T, s repository.Store) { + _, err := s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{ + Sub: "sub-1", + AccessTokenHash: "at-1", + RefreshTokenHash: "rt-1", + }) + require.NoError(t, err) + + updated, err := s.UpdateOIDCSession(ctx, repository.UpdateOIDCSessionParams{ + Sub: "sub-1", AccessTokenHash: "at-2", RefreshTokenHash: "rt-2", + Scope: "openid profile", TokenExpiresAt: 200, RefreshTokenExpiresAt: 400, }) require.NoError(t, err) assert.Equal(t, "at-2", updated.AccessTokenHash) assert.Equal(t, "rt-2", updated.RefreshTokenHash) + assert.Equal(t, "openid profile", updated.Scope) - // old key gone, new key present - _, err = s.GetOidcToken(ctx, "at-1") - assert.ErrorIs(t, err, repository.ErrNotFound) - - got, err := s.GetOidcToken(ctx, "at-2") + // updated token hashes are now queryable, old ones are gone + got, err := s.GetOIDCSessionByAccessTokenHash(ctx, "at-2") require.NoError(t, err) assert.Equal(t, "sub-1", got.Sub) - }, - }, - { - description: "Update OIDC token by refresh token not found", - run: func(t *testing.T, s repository.Store) { - _, err := s.UpdateOidcTokenByRefreshToken(ctx, repository.UpdateOidcTokenByRefreshTokenParams{ - RefreshTokenHash_2: "missing", - }) + + _, err = s.GetOIDCSessionByAccessTokenHash(ctx, "at-1") assert.ErrorIs(t, err, repository.ErrNotFound) }, }, { - description: "Delete OIDC token", + description: "Update OIDC session not found", run: func(t *testing.T, s repository.Store) { - _, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{Sub: "sub-1", AccessTokenHash: "at-1"}) + _, err := s.UpdateOIDCSession(ctx, repository.UpdateOIDCSessionParams{Sub: "missing"}) + assert.ErrorIs(t, err, repository.ErrNotFound) + }, + }, + { + description: "Delete OIDC session by sub", + run: func(t *testing.T, s repository.Store) { + _, err := s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{Sub: "sub-1", AccessTokenHash: "at-1", RefreshTokenHash: "rt-1"}) require.NoError(t, err) - require.NoError(t, s.DeleteOidcToken(ctx, "at-1")) + require.NoError(t, s.DeleteOIDCSessionBySub(ctx, "sub-1")) - _, err = s.GetOidcToken(ctx, "at-1") + _, err = s.GetOIDCSessionBySub(ctx, "sub-1") assert.ErrorIs(t, err, repository.ErrNotFound) }, }, { - description: "Delete OIDC token by sub", - run: func(t *testing.T, s repository.Store) { - _, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{Sub: "sub-1", AccessTokenHash: "at-1"}) - require.NoError(t, err) - - require.NoError(t, s.DeleteOidcTokenBySub(ctx, "sub-1")) - - _, err = s.GetOidcToken(ctx, "at-1") - assert.ErrorIs(t, err, repository.ErrNotFound) - }, - }, - { - description: "Delete OIDC token by code hash", - run: func(t *testing.T, s repository.Store) { - _, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{ - Sub: "sub-1", - AccessTokenHash: "at-1", - CodeHash: "code-1", - }) - require.NoError(t, err) - - require.NoError(t, s.DeleteOidcTokenByCodeHash(ctx, "code-1")) - - _, err = s.GetOidcToken(ctx, "at-1") - assert.ErrorIs(t, err, repository.ErrNotFound) - }, - }, - { - description: "Delete expired OIDC tokens", + description: "Delete expired OIDC sessions", run: func(t *testing.T, s repository.Store) { // both expiries past - _, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{ - Sub: "sub-1", AccessTokenHash: "at-1", + _, err := s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{ + Sub: "sub-1", AccessTokenHash: "at-1", RefreshTokenHash: "rt-1", TokenExpiresAt: 10, RefreshTokenExpiresAt: 10, }) require.NoError(t, err) // valid - _, err = s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{ - Sub: "sub-3", AccessTokenHash: "at-3", + _, err = s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{ + Sub: "sub-2", AccessTokenHash: "at-2", RefreshTokenHash: "rt-2", TokenExpiresAt: 100, RefreshTokenExpiresAt: 100, }) require.NoError(t, err) - deleted, err := s.DeleteExpiredOidcTokens(ctx, repository.DeleteExpiredOidcTokensParams{ + require.NoError(t, s.DeleteExpiredOIDCSessions(ctx, repository.DeleteExpiredOIDCSessionsParams{ TokenExpiresAt: 50, RefreshTokenExpiresAt: 50, - }) - require.NoError(t, err) - assert.Len(t, deleted, 1) + })) - _, err = s.GetOidcToken(ctx, "at-3") + _, err = s.GetOIDCSessionBySub(ctx, "sub-1") + assert.ErrorIs(t, err, repository.ErrNotFound) + + _, err = s.GetOIDCSessionBySub(ctx, "sub-2") assert.NoError(t, err) }, }, - { - description: "Create and get OIDC user info", - run: func(t *testing.T, s repository.Store) { - u, err := s.CreateOidcUserInfo(ctx, repository.CreateOidcUserInfoParams{ - Sub: "sub-1", - Name: "Alice", - Email: "alice@example.com", - }) - require.NoError(t, err) - assert.Equal(t, "sub-1", u.Sub) - - got, err := s.GetOidcUserInfo(ctx, "sub-1") - require.NoError(t, err) - assert.Equal(t, u, got) - }, - }, - { - description: "Get OIDC user info not found", - run: func(t *testing.T, s repository.Store) { - _, err := s.GetOidcUserInfo(ctx, "missing") - assert.ErrorIs(t, err, repository.ErrNotFound) - }, - }, - { - description: "Delete OIDC user info", - run: func(t *testing.T, s repository.Store) { - _, err := s.CreateOidcUserInfo(ctx, repository.CreateOidcUserInfoParams{Sub: "sub-1"}) - require.NoError(t, err) - - require.NoError(t, s.DeleteOidcUserInfo(ctx, "sub-1")) - - _, err = s.GetOidcUserInfo(ctx, "sub-1") - assert.ErrorIs(t, err, repository.ErrNotFound) - }, - }, } for _, test := range tests { diff --git a/internal/repository/memory/oidc_queries.go b/internal/repository/memory/oidc_queries.go index 0b4d758f..1ee81c8b 100644 --- a/internal/repository/memory/oidc_queries.go +++ b/internal/repository/memory/oidc_queries.go @@ -1,7 +1,3 @@ -//go:build exclude - -// temporary - package memory import ( @@ -11,235 +7,90 @@ import ( "github.com/tinyauthapp/tinyauth/internal/repository" ) -func (s *Store) CreateOidcCode(_ context.Context, arg repository.CreateOidcCodeParams) (repository.OidcCode, error) { +func (s *Store) CreateOIDCSession(_ context.Context, arg repository.CreateOIDCSessionParams) (repository.OidcSession, error) { s.mu.Lock() defer s.mu.Unlock() - // Enforce sub UNIQUE constraint - for _, c := range s.oidcCodes { - if c.Sub == arg.Sub { - return repository.OidcCode{}, fmt.Errorf("UNIQUE constraint failed: oidc_codes.sub") + // Enforce UNIQUE constraints (sub is the primary key, access/refresh token hashes are unique). + for _, sess := range s.oidcSessions { + switch { + case sess.Sub == arg.Sub: + return repository.OidcSession{}, fmt.Errorf("UNIQUE constraint failed: oidc_sessions.sub") + case sess.AccessTokenHash == arg.AccessTokenHash: + return repository.OidcSession{}, fmt.Errorf("UNIQUE constraint failed: oidc_sessions.access_token_hash") + case sess.RefreshTokenHash == arg.RefreshTokenHash: + return repository.OidcSession{}, fmt.Errorf("UNIQUE constraint failed: oidc_sessions.refresh_token_hash") } } - code := repository.OidcCode(arg) - s.oidcCodes[arg.CodeHash] = code - return code, nil + sess := repository.OidcSession(arg) + s.oidcSessions[arg.Sub] = sess + return sess, nil } -// GetOidcCode is a destructive read: it deletes and returns the code (mirrors SQLite's DELETE...RETURNING). -func (s *Store) GetOidcCode(_ context.Context, codeHash string) (repository.OidcCode, error) { - s.mu.Lock() - defer s.mu.Unlock() - c, ok := s.oidcCodes[codeHash] +func (s *Store) GetOIDCSessionBySub(_ context.Context, sub string) (repository.OidcSession, error) { + s.mu.RLock() + defer s.mu.RUnlock() + sess, ok := s.oidcSessions[sub] if !ok { - return repository.OidcCode{}, repository.ErrNotFound + return repository.OidcSession{}, repository.ErrNotFound } - delete(s.oidcCodes, codeHash) - return c, nil + return sess, nil } -// GetOidcCodeBySub is a destructive read: it deletes and returns the code (mirrors SQLite's DELETE...RETURNING). -func (s *Store) GetOidcCodeBySub(_ context.Context, sub string) (repository.OidcCode, error) { - s.mu.Lock() - defer s.mu.Unlock() - for k, c := range s.oidcCodes { - if c.Sub == sub { - delete(s.oidcCodes, k) - return c, nil - } - } - return repository.OidcCode{}, repository.ErrNotFound -} - -// GetOidcCodeUnsafe is a non-destructive read (mirrors SQLite's SELECT). -func (s *Store) GetOidcCodeUnsafe(_ context.Context, codeHash string) (repository.OidcCode, error) { +func (s *Store) GetOIDCSessionByAccessTokenHash(_ context.Context, accessTokenHash string) (repository.OidcSession, error) { s.mu.RLock() defer s.mu.RUnlock() - c, ok := s.oidcCodes[codeHash] + for _, sess := range s.oidcSessions { + if sess.AccessTokenHash == accessTokenHash { + return sess, nil + } + } + return repository.OidcSession{}, repository.ErrNotFound +} + +func (s *Store) GetOIDCSessionByRefreshTokenHash(_ context.Context, refreshTokenHash string) (repository.OidcSession, error) { + s.mu.RLock() + defer s.mu.RUnlock() + for _, sess := range s.oidcSessions { + if sess.RefreshTokenHash == refreshTokenHash { + return sess, nil + } + } + return repository.OidcSession{}, repository.ErrNotFound +} + +func (s *Store) UpdateOIDCSession(_ context.Context, arg repository.UpdateOIDCSessionParams) (repository.OidcSession, error) { + s.mu.Lock() + defer s.mu.Unlock() + sess, ok := s.oidcSessions[arg.Sub] if !ok { - return repository.OidcCode{}, repository.ErrNotFound + return repository.OidcSession{}, repository.ErrNotFound } - return c, nil + sess.AccessTokenHash = arg.AccessTokenHash + sess.RefreshTokenHash = arg.RefreshTokenHash + sess.Scope = arg.Scope + sess.ClientID = arg.ClientID + sess.TokenExpiresAt = arg.TokenExpiresAt + sess.RefreshTokenExpiresAt = arg.RefreshTokenExpiresAt + sess.Nonce = arg.Nonce + sess.UserinfoJson = arg.UserinfoJson + s.oidcSessions[arg.Sub] = sess + return sess, nil } -// GetOidcCodeBySubUnsafe is a non-destructive read (mirrors SQLite's SELECT). -func (s *Store) GetOidcCodeBySubUnsafe(_ context.Context, sub string) (repository.OidcCode, error) { - s.mu.RLock() - defer s.mu.RUnlock() - for _, c := range s.oidcCodes { - if c.Sub == sub { - return c, nil - } - } - return repository.OidcCode{}, repository.ErrNotFound -} - -func (s *Store) DeleteOidcCode(_ context.Context, codeHash string) error { +func (s *Store) DeleteOIDCSessionBySub(_ context.Context, sub string) error { s.mu.Lock() defer s.mu.Unlock() - delete(s.oidcCodes, codeHash) + delete(s.oidcSessions, sub) return nil } -func (s *Store) DeleteOidcCodeBySub(_ context.Context, sub string) error { +func (s *Store) DeleteExpiredOIDCSessions(_ context.Context, arg repository.DeleteExpiredOIDCSessionsParams) error { s.mu.Lock() defer s.mu.Unlock() - for k, c := range s.oidcCodes { - if c.Sub == sub { - delete(s.oidcCodes, k) + for k, sess := range s.oidcSessions { + if sess.TokenExpiresAt < arg.TokenExpiresAt && sess.RefreshTokenExpiresAt < arg.RefreshTokenExpiresAt { + delete(s.oidcSessions, k) } } return nil } - -func (s *Store) DeleteExpiredOidcCodes(_ context.Context, expiresAt int64) ([]repository.OidcCode, error) { - s.mu.Lock() - defer s.mu.Unlock() - var deleted []repository.OidcCode - for k, c := range s.oidcCodes { - if c.ExpiresAt < expiresAt { - deleted = append(deleted, c) - delete(s.oidcCodes, k) - } - } - return deleted, nil -} - -func (s *Store) CreateOidcToken(_ context.Context, arg repository.CreateOidcTokenParams) (repository.OidcToken, error) { - s.mu.Lock() - defer s.mu.Unlock() - // Enforce sub UNIQUE constraint - for _, t := range s.oidcTokens { - if t.Sub == arg.Sub { - return repository.OidcToken{}, fmt.Errorf("UNIQUE constraint failed: oidc_tokens.sub") - } - } - tok := repository.OidcToken{ - Sub: arg.Sub, - AccessTokenHash: arg.AccessTokenHash, - RefreshTokenHash: arg.RefreshTokenHash, - CodeHash: arg.CodeHash, - Scope: arg.Scope, - ClientID: arg.ClientID, - TokenExpiresAt: arg.TokenExpiresAt, - RefreshTokenExpiresAt: arg.RefreshTokenExpiresAt, - Nonce: arg.Nonce, - } - s.oidcTokens[arg.AccessTokenHash] = tok - return tok, nil -} - -func (s *Store) GetOidcToken(_ context.Context, accessTokenHash string) (repository.OidcToken, error) { - s.mu.RLock() - defer s.mu.RUnlock() - t, ok := s.oidcTokens[accessTokenHash] - if !ok { - return repository.OidcToken{}, repository.ErrNotFound - } - return t, nil -} - -func (s *Store) GetOidcTokenByRefreshToken(_ context.Context, refreshTokenHash string) (repository.OidcToken, error) { - s.mu.RLock() - defer s.mu.RUnlock() - for _, t := range s.oidcTokens { - if t.RefreshTokenHash == refreshTokenHash { - return t, nil - } - } - return repository.OidcToken{}, repository.ErrNotFound -} - -func (s *Store) GetOidcTokenBySub(_ context.Context, sub string) (repository.OidcToken, error) { - s.mu.RLock() - defer s.mu.RUnlock() - for _, t := range s.oidcTokens { - if t.Sub == sub { - return t, nil - } - } - return repository.OidcToken{}, repository.ErrNotFound -} - -func (s *Store) UpdateOidcTokenByRefreshToken(_ context.Context, arg repository.UpdateOidcTokenByRefreshTokenParams) (repository.OidcToken, error) { - s.mu.Lock() - defer s.mu.Unlock() - for k, t := range s.oidcTokens { - if t.RefreshTokenHash == arg.RefreshTokenHash_2 { - delete(s.oidcTokens, k) - t.AccessTokenHash = arg.AccessTokenHash - t.RefreshTokenHash = arg.RefreshTokenHash - t.TokenExpiresAt = arg.TokenExpiresAt - t.RefreshTokenExpiresAt = arg.RefreshTokenExpiresAt - s.oidcTokens[arg.AccessTokenHash] = t - return t, nil - } - } - return repository.OidcToken{}, repository.ErrNotFound -} - -func (s *Store) DeleteOidcToken(_ context.Context, accessTokenHash string) error { - s.mu.Lock() - defer s.mu.Unlock() - delete(s.oidcTokens, accessTokenHash) - return nil -} - -func (s *Store) DeleteOidcTokenBySub(_ context.Context, sub string) error { - s.mu.Lock() - defer s.mu.Unlock() - for k, t := range s.oidcTokens { - if t.Sub == sub { - delete(s.oidcTokens, k) - } - } - return nil -} - -func (s *Store) DeleteOidcTokenByCodeHash(_ context.Context, codeHash string) error { - s.mu.Lock() - defer s.mu.Unlock() - for k, t := range s.oidcTokens { - if t.CodeHash == codeHash { - delete(s.oidcTokens, k) - } - } - return nil -} - -func (s *Store) DeleteExpiredOidcTokens(_ context.Context, arg repository.DeleteExpiredOidcTokensParams) ([]repository.OidcToken, error) { - s.mu.Lock() - defer s.mu.Unlock() - var deleted []repository.OidcToken - for k, t := range s.oidcTokens { - if t.TokenExpiresAt < arg.TokenExpiresAt && t.RefreshTokenExpiresAt < arg.RefreshTokenExpiresAt { - deleted = append(deleted, t) - delete(s.oidcTokens, k) - } - } - return deleted, nil -} - -func (s *Store) CreateOidcUserInfo(_ context.Context, arg repository.CreateOidcUserInfoParams) (repository.OidcUserinfo, error) { - s.mu.Lock() - defer s.mu.Unlock() - u := repository.OidcUserinfo(arg) - s.oidcUsers[arg.Sub] = u - return u, nil -} - -func (s *Store) GetOidcUserInfo(_ context.Context, sub string) (repository.OidcUserinfo, error) { - s.mu.RLock() - defer s.mu.RUnlock() - u, ok := s.oidcUsers[sub] - if !ok { - return repository.OidcUserinfo{}, repository.ErrNotFound - } - return u, nil -} - -func (s *Store) DeleteOidcUserInfo(_ context.Context, sub string) error { - s.mu.Lock() - defer s.mu.Unlock() - delete(s.oidcUsers, sub) - return nil -} diff --git a/internal/repository/memory/session_queries.go b/internal/repository/memory/session_queries.go index fbbb43cf..2edde6b1 100644 --- a/internal/repository/memory/session_queries.go +++ b/internal/repository/memory/session_queries.go @@ -1,7 +1,3 @@ -//go:build exclude - -// temporary - package memory import ( diff --git a/internal/repository/memory/store.go b/internal/repository/memory/store.go index a2a56ad3..684ddeb3 100644 --- a/internal/repository/memory/store.go +++ b/internal/repository/memory/store.go @@ -1,7 +1,3 @@ -//go:build exclude - -// temporary - // Package memory provides an in-memory implementation of repository.Store for use in tests. package memory