feat(db): add memory storage driver

removes the sqlite dependency for tests, also brings back the option for
users to run zero persistence instances of tinyauth.

adds new mapErr fn for sqlc wrapper gen to prevent sql errors from
leaking out of the store implementation.
This commit is contained in:
Scott McKendry
2026-05-04 05:02:27 +12:00
parent 0244f39387
commit 04b8e9884b
14 changed files with 435 additions and 69 deletions
+43 -25
View File
@@ -3,6 +3,8 @@ package sqlite
import (
"context"
"database/sql"
"errors"
"github.com/tinyauthapp/tinyauth/internal/repository"
)
@@ -17,6 +19,22 @@ func NewStore(q *Queries) repository.Store {
return &Store{q: q}
}
var errMap = []struct {
from error
to error
}{
{sql.ErrNoRows, repository.ErrNotFound},
}
func mapErr(err error) error {
for _, e := range errMap {
if errors.Is(err, e.from) {
return e.to
}
}
return err
}
func oidcCodeToRepo(v OidcCode) repository.OidcCode {
return repository.OidcCode(v)
}
@@ -32,7 +50,7 @@ func sessionToRepo(v Session) repository.Session {
func (s *Store) CreateOidcCode(ctx context.Context, arg repository.CreateOidcCodeParams) (repository.OidcCode, error) {
r, err := s.q.CreateOidcCode(ctx, CreateOidcCodeParams(arg))
if err != nil {
return repository.OidcCode{}, err
return repository.OidcCode{}, mapErr(err)
}
return oidcCodeToRepo(r), nil
}
@@ -40,7 +58,7 @@ func (s *Store) CreateOidcCode(ctx context.Context, arg repository.CreateOidcCod
func (s *Store) CreateOidcToken(ctx context.Context, arg repository.CreateOidcTokenParams) (repository.OidcToken, error) {
r, err := s.q.CreateOidcToken(ctx, CreateOidcTokenParams(arg))
if err != nil {
return repository.OidcToken{}, err
return repository.OidcToken{}, mapErr(err)
}
return oidcTokenToRepo(r), nil
}
@@ -48,7 +66,7 @@ func (s *Store) CreateOidcToken(ctx context.Context, arg repository.CreateOidcTo
func (s *Store) CreateOidcUserInfo(ctx context.Context, arg repository.CreateOidcUserInfoParams) (repository.OidcUserinfo, error) {
r, err := s.q.CreateOidcUserInfo(ctx, CreateOidcUserInfoParams(arg))
if err != nil {
return repository.OidcUserinfo{}, err
return repository.OidcUserinfo{}, mapErr(err)
}
return oidcUserinfoToRepo(r), nil
}
@@ -56,7 +74,7 @@ func (s *Store) CreateOidcUserInfo(ctx context.Context, arg repository.CreateOid
func (s *Store) CreateSession(ctx context.Context, arg repository.CreateSessionParams) (repository.Session, error) {
r, err := s.q.CreateSession(ctx, CreateSessionParams(arg))
if err != nil {
return repository.Session{}, err
return repository.Session{}, mapErr(err)
}
return sessionToRepo(r), nil
}
@@ -64,7 +82,7 @@ func (s *Store) CreateSession(ctx context.Context, arg repository.CreateSessionP
func (s *Store) DeleteExpiredOidcCodes(ctx context.Context, expiresAt int64) ([]repository.OidcCode, error) {
rows, err := s.q.DeleteExpiredOidcCodes(ctx, expiresAt)
if err != nil {
return nil, err
return nil, mapErr(err)
}
out := make([]repository.OidcCode, len(rows))
for i, row := range rows {
@@ -76,7 +94,7 @@ func (s *Store) DeleteExpiredOidcCodes(ctx context.Context, expiresAt int64) ([]
func (s *Store) DeleteExpiredOidcTokens(ctx context.Context, arg repository.DeleteExpiredOidcTokensParams) ([]repository.OidcToken, error) {
rows, err := s.q.DeleteExpiredOidcTokens(ctx, DeleteExpiredOidcTokensParams(arg))
if err != nil {
return nil, err
return nil, mapErr(err)
}
out := make([]repository.OidcToken, len(rows))
for i, row := range rows {
@@ -86,41 +104,41 @@ func (s *Store) DeleteExpiredOidcTokens(ctx context.Context, arg repository.Dele
}
func (s *Store) DeleteExpiredSessions(ctx context.Context, expiry int64) error {
return s.q.DeleteExpiredSessions(ctx, expiry)
return mapErr(s.q.DeleteExpiredSessions(ctx, expiry))
}
func (s *Store) DeleteOidcCode(ctx context.Context, codeHash string) error {
return s.q.DeleteOidcCode(ctx, codeHash)
return mapErr(s.q.DeleteOidcCode(ctx, codeHash))
}
func (s *Store) DeleteOidcCodeBySub(ctx context.Context, sub string) error {
return s.q.DeleteOidcCodeBySub(ctx, sub)
return mapErr(s.q.DeleteOidcCodeBySub(ctx, sub))
}
func (s *Store) DeleteOidcToken(ctx context.Context, accessTokenHash string) error {
return s.q.DeleteOidcToken(ctx, accessTokenHash)
return mapErr(s.q.DeleteOidcToken(ctx, accessTokenHash))
}
func (s *Store) DeleteOidcTokenByCodeHash(ctx context.Context, codeHash string) error {
return s.q.DeleteOidcTokenByCodeHash(ctx, codeHash)
return mapErr(s.q.DeleteOidcTokenByCodeHash(ctx, codeHash))
}
func (s *Store) DeleteOidcTokenBySub(ctx context.Context, sub string) error {
return s.q.DeleteOidcTokenBySub(ctx, sub)
return mapErr(s.q.DeleteOidcTokenBySub(ctx, sub))
}
func (s *Store) DeleteOidcUserInfo(ctx context.Context, sub string) error {
return s.q.DeleteOidcUserInfo(ctx, sub)
return mapErr(s.q.DeleteOidcUserInfo(ctx, sub))
}
func (s *Store) DeleteSession(ctx context.Context, uuid string) error {
return s.q.DeleteSession(ctx, uuid)
return mapErr(s.q.DeleteSession(ctx, uuid))
}
func (s *Store) GetOidcCode(ctx context.Context, codeHash string) (repository.OidcCode, error) {
r, err := s.q.GetOidcCode(ctx, codeHash)
if err != nil {
return repository.OidcCode{}, err
return repository.OidcCode{}, mapErr(err)
}
return oidcCodeToRepo(r), nil
}
@@ -128,7 +146,7 @@ func (s *Store) GetOidcCode(ctx context.Context, codeHash string) (repository.Oi
func (s *Store) GetOidcCodeBySub(ctx context.Context, sub string) (repository.OidcCode, error) {
r, err := s.q.GetOidcCodeBySub(ctx, sub)
if err != nil {
return repository.OidcCode{}, err
return repository.OidcCode{}, mapErr(err)
}
return oidcCodeToRepo(r), nil
}
@@ -136,7 +154,7 @@ func (s *Store) GetOidcCodeBySub(ctx context.Context, sub string) (repository.Oi
func (s *Store) GetOidcCodeBySubUnsafe(ctx context.Context, sub string) (repository.OidcCode, error) {
r, err := s.q.GetOidcCodeBySubUnsafe(ctx, sub)
if err != nil {
return repository.OidcCode{}, err
return repository.OidcCode{}, mapErr(err)
}
return oidcCodeToRepo(r), nil
}
@@ -144,7 +162,7 @@ func (s *Store) GetOidcCodeBySubUnsafe(ctx context.Context, sub string) (reposit
func (s *Store) GetOidcCodeUnsafe(ctx context.Context, codeHash string) (repository.OidcCode, error) {
r, err := s.q.GetOidcCodeUnsafe(ctx, codeHash)
if err != nil {
return repository.OidcCode{}, err
return repository.OidcCode{}, mapErr(err)
}
return oidcCodeToRepo(r), nil
}
@@ -152,7 +170,7 @@ func (s *Store) GetOidcCodeUnsafe(ctx context.Context, codeHash string) (reposit
func (s *Store) GetOidcToken(ctx context.Context, accessTokenHash string) (repository.OidcToken, error) {
r, err := s.q.GetOidcToken(ctx, accessTokenHash)
if err != nil {
return repository.OidcToken{}, err
return repository.OidcToken{}, mapErr(err)
}
return oidcTokenToRepo(r), nil
}
@@ -160,7 +178,7 @@ func (s *Store) GetOidcToken(ctx context.Context, accessTokenHash string) (repos
func (s *Store) GetOidcTokenByRefreshToken(ctx context.Context, refreshTokenHash string) (repository.OidcToken, error) {
r, err := s.q.GetOidcTokenByRefreshToken(ctx, refreshTokenHash)
if err != nil {
return repository.OidcToken{}, err
return repository.OidcToken{}, mapErr(err)
}
return oidcTokenToRepo(r), nil
}
@@ -168,7 +186,7 @@ func (s *Store) GetOidcTokenByRefreshToken(ctx context.Context, refreshTokenHash
func (s *Store) GetOidcTokenBySub(ctx context.Context, sub string) (repository.OidcToken, error) {
r, err := s.q.GetOidcTokenBySub(ctx, sub)
if err != nil {
return repository.OidcToken{}, err
return repository.OidcToken{}, mapErr(err)
}
return oidcTokenToRepo(r), nil
}
@@ -176,7 +194,7 @@ func (s *Store) GetOidcTokenBySub(ctx context.Context, sub string) (repository.O
func (s *Store) GetOidcUserInfo(ctx context.Context, sub string) (repository.OidcUserinfo, error) {
r, err := s.q.GetOidcUserInfo(ctx, sub)
if err != nil {
return repository.OidcUserinfo{}, err
return repository.OidcUserinfo{}, mapErr(err)
}
return oidcUserinfoToRepo(r), nil
}
@@ -184,7 +202,7 @@ func (s *Store) GetOidcUserInfo(ctx context.Context, sub string) (repository.Oid
func (s *Store) GetSession(ctx context.Context, uuid string) (repository.Session, error) {
r, err := s.q.GetSession(ctx, uuid)
if err != nil {
return repository.Session{}, err
return repository.Session{}, mapErr(err)
}
return sessionToRepo(r), nil
}
@@ -192,7 +210,7 @@ func (s *Store) GetSession(ctx context.Context, uuid string) (repository.Session
func (s *Store) UpdateOidcTokenByRefreshToken(ctx context.Context, arg repository.UpdateOidcTokenByRefreshTokenParams) (repository.OidcToken, error) {
r, err := s.q.UpdateOidcTokenByRefreshToken(ctx, UpdateOidcTokenByRefreshTokenParams(arg))
if err != nil {
return repository.OidcToken{}, err
return repository.OidcToken{}, mapErr(err)
}
return oidcTokenToRepo(r), nil
}
@@ -200,7 +218,7 @@ func (s *Store) UpdateOidcTokenByRefreshToken(ctx context.Context, arg repositor
func (s *Store) UpdateSession(ctx context.Context, arg repository.UpdateSessionParams) (repository.Session, error) {
r, err := s.q.UpdateSession(ctx, UpdateSessionParams(arg))
if err != nil {
return repository.Session{}, err
return repository.Session{}, mapErr(err)
}
return sessionToRepo(r), nil
}