From 04b8e9884bcd76ae502fd686f6e8ab5c8e703e60 Mon Sep 17 00:00:00 2001 From: Scott McKendry Date: Mon, 4 May 2026 05:02:27 +1200 Subject: [PATCH] feat(db): add `memory` storage driver removes the sqlite dependency for tests, also brings back the option for users to run zero persistence instances of tinyauth. adds new mapErr fn for sqlc wrapper gen to prevent sql errors from leaking out of the store implementation. --- cmd/gen/sqlc-wrapper/main.go | 24 +- internal/bootstrap/db_bootstrap.go | 17 +- internal/config/config.go | 6 +- internal/controller/oidc_controller_test.go | 7 +- internal/controller/proxy_controller_test.go | 9 +- internal/controller/user_controller_test.go | 9 +- .../controller/well_known_controller_test.go | 7 +- internal/repository/memory/oidc_queries.go | 241 ++++++++++++++++++ internal/repository/memory/session_queries.go | 63 +++++ internal/repository/memory/store.go | 27 ++ internal/repository/sqlite/store.go | 68 +++-- internal/repository/store.go | 8 +- internal/service/auth_service.go | 3 +- internal/service/oidc_service.go | 15 +- 14 files changed, 435 insertions(+), 69 deletions(-) create mode 100644 internal/repository/memory/oidc_queries.go create mode 100644 internal/repository/memory/session_queries.go create mode 100644 internal/repository/memory/store.go diff --git a/cmd/gen/sqlc-wrapper/main.go b/cmd/gen/sqlc-wrapper/main.go index e66ae8ee..d6cb6318 100644 --- a/cmd/gen/sqlc-wrapper/main.go +++ b/cmd/gen/sqlc-wrapper/main.go @@ -449,18 +449,18 @@ func buildBody(m methodInfo) string { // no repo-typed result → direct return if len(m.Results) == 0 || m.Results[0].RepoType == "" { - return "\treturn " + call + "\n" + return "\treturn mapErr(" + call + ")\n" } r := m.Results[0] if r.IsSlice { return fmt.Sprintf( - "\trows, err := %s\n\tif err != nil {\n\t\treturn nil, err\n\t}\n\tout := make([]%s, len(rows))\n\tfor i, row := range rows {\n\t\tout[i] = %s(row)\n\t}\n\treturn out, nil\n", + "\trows, err := %s\n\tif err != nil {\n\t\treturn nil, mapErr(err)\n\t}\n\tout := make([]%s, len(rows))\n\tfor i, row := range rows {\n\t\tout[i] = %s(row)\n\t}\n\treturn out, nil\n", call, r.RepoType, converterFn(r.TypeStr), ) } return fmt.Sprintf( - "\tr, err := %s\n\tif err != nil {\n\t\treturn %s{}, err\n\t}\n\treturn %s(r), nil\n", + "\tr, err := %s\n\tif err != nil {\n\t\treturn %s{}, mapErr(err)\n\t}\n\treturn %s(r), nil\n", call, r.RepoType, converterFn(r.TypeStr), ) } @@ -477,6 +477,8 @@ package {{.PkgName}} import ( "context" + "database/sql" + "errors" "{{.RepoPkg}}" ) @@ -491,6 +493,22 @@ func NewStore(q *Queries) repository.Store { return &Store{q: q} } +var errMap = []struct { + from error + to error +}{ + {sql.ErrNoRows, repository.ErrNotFound}, +} + +func mapErr(err error) error { + for _, e := range errMap { + if errors.Is(err, e.from) { + return e.to + } + } + return err +} + {{range .ModelTypes -}} func {{converterFn .}}(v {{.}}) repository.{{.}} { return repository.{{.}}(v) diff --git a/internal/bootstrap/db_bootstrap.go b/internal/bootstrap/db_bootstrap.go index 2279cb23..4f09372a 100644 --- a/internal/bootstrap/db_bootstrap.go +++ b/internal/bootstrap/db_bootstrap.go @@ -8,6 +8,7 @@ import ( "github.com/tinyauthapp/tinyauth/internal/assets" "github.com/tinyauthapp/tinyauth/internal/repository" + "github.com/tinyauthapp/tinyauth/internal/repository/memory" "github.com/tinyauthapp/tinyauth/internal/repository/sqlite" "github.com/golang-migrate/migrate/v4" @@ -17,14 +18,14 @@ import ( ) func (app *BootstrapApp) SetupStore() (repository.Store, error) { - return app.setupSQLite(app.config.Database.Path) -} - -// NewSQLiteStore opens a SQLite database at the given path, runs migrations, and returns a Store. -// Useful for testing or when constructing a store outside of a BootstrapApp. -func NewSQLiteStore(databasePath string) (repository.Store, error) { - app := &BootstrapApp{} - return app.setupSQLite(databasePath) + switch app.config.Database.Driver { + case "memory": + return memory.New(), nil + case "sqlite", "": + return app.setupSQLite(app.config.Database.Path) + default: + return nil, fmt.Errorf("unknown database driver %q: valid values are sqlite, memory", app.config.Database.Driver) + } } func (app *BootstrapApp) setupSQLite(databasePath string) (repository.Store, error) { diff --git a/internal/config/config.go b/internal/config/config.go index 5b14e27e..9d2a8663 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -4,7 +4,8 @@ package config func NewDefaultConfiguration() *Config { return &Config{ Database: DatabaseConfig{ - Path: "./tinyauth.db", + Driver: "sqlite", + Path: "./tinyauth.db", }, Analytics: AnalyticsConfig{ Enabled: true, @@ -95,7 +96,8 @@ type Config struct { } type DatabaseConfig struct { - Path string `description:"The path to the SQLite database, including file name." yaml:"path"` + Driver string `description:"The database driver to use. Valid values: sqlite, memory." yaml:"driver"` + Path string `description:"The path to the SQLite database, including file name. Only used when driver is sqlite." yaml:"path"` } type AnalyticsConfig struct { diff --git a/internal/controller/oidc_controller_test.go b/internal/controller/oidc_controller_test.go index 991f6759..b83094c1 100644 --- a/internal/controller/oidc_controller_test.go +++ b/internal/controller/oidc_controller_test.go @@ -12,9 +12,9 @@ import ( "github.com/gin-gonic/gin" "github.com/google/go-querystring/query" - "github.com/tinyauthapp/tinyauth/internal/bootstrap" "github.com/tinyauthapp/tinyauth/internal/config" "github.com/tinyauthapp/tinyauth/internal/controller" + "github.com/tinyauthapp/tinyauth/internal/repository/memory" "github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/utils/tlog" "github.com/stretchr/testify/assert" @@ -847,11 +847,10 @@ func TestOIDCController(t *testing.T) { }, } - store, err := bootstrap.NewSQLiteStore(path.Join(tempDir, "tinyauth.db")) - require.NoError(t, err) + store := memory.New() oidcService := service.NewOIDCService(oidcServiceCfg, store) - err = oidcService.Init() + err := oidcService.Init() require.NoError(t, err) for _, test := range tests { diff --git a/internal/controller/proxy_controller_test.go b/internal/controller/proxy_controller_test.go index adfc7fb1..74bfdead 100644 --- a/internal/controller/proxy_controller_test.go +++ b/internal/controller/proxy_controller_test.go @@ -2,13 +2,12 @@ package controller_test import ( "net/http/httptest" - "path" "testing" "github.com/gin-gonic/gin" - "github.com/tinyauthapp/tinyauth/internal/bootstrap" "github.com/tinyauthapp/tinyauth/internal/config" "github.com/tinyauthapp/tinyauth/internal/controller" + "github.com/tinyauthapp/tinyauth/internal/repository/memory" "github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/utils/tlog" "github.com/stretchr/testify/assert" @@ -17,7 +16,6 @@ import ( func TestProxyController(t *testing.T) { tlog.NewTestLogger().Init() - tempDir := t.TempDir() authServiceCfg := service.AuthServiceConfig{ Users: []config.User{ @@ -392,11 +390,10 @@ func TestProxyController(t *testing.T) { oauthBrokerCfgs := make(map[string]config.OAuthServiceConfig) - store, err := bootstrap.NewSQLiteStore(path.Join(tempDir, "tinyauth.db")) - require.NoError(t, err) + store := memory.New() docker := service.NewDockerService() - err = docker.Init() + err := docker.Init() require.NoError(t, err) ldap := service.NewLdapService(service.LdapServiceConfig{}) diff --git a/internal/controller/user_controller_test.go b/internal/controller/user_controller_test.go index b67c70fa..1d6e11b2 100644 --- a/internal/controller/user_controller_test.go +++ b/internal/controller/user_controller_test.go @@ -3,16 +3,15 @@ package controller_test import ( "encoding/json" "net/http/httptest" - "path" "strings" "testing" "time" "github.com/gin-gonic/gin" "github.com/pquerna/otp/totp" - "github.com/tinyauthapp/tinyauth/internal/bootstrap" "github.com/tinyauthapp/tinyauth/internal/config" "github.com/tinyauthapp/tinyauth/internal/controller" + "github.com/tinyauthapp/tinyauth/internal/repository/memory" "github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/utils/tlog" "github.com/stretchr/testify/assert" @@ -21,7 +20,6 @@ import ( func TestUserController(t *testing.T) { tlog.NewTestLogger().Init() - tempDir := t.TempDir() authServiceCfg := service.AuthServiceConfig{ Users: []config.User{ @@ -350,11 +348,10 @@ func TestUserController(t *testing.T) { oauthBrokerCfgs := make(map[string]config.OAuthServiceConfig) - store, err := bootstrap.NewSQLiteStore(path.Join(tempDir, "tinyauth.db")) - require.NoError(t, err) + store := memory.New() docker := service.NewDockerService() - err = docker.Init() + err := docker.Init() require.NoError(t, err) ldap := service.NewLdapService(service.LdapServiceConfig{}) diff --git a/internal/controller/well_known_controller_test.go b/internal/controller/well_known_controller_test.go index eba449b0..25c8e5a8 100644 --- a/internal/controller/well_known_controller_test.go +++ b/internal/controller/well_known_controller_test.go @@ -8,9 +8,9 @@ import ( "testing" "github.com/gin-gonic/gin" - "github.com/tinyauthapp/tinyauth/internal/bootstrap" "github.com/tinyauthapp/tinyauth/internal/config" "github.com/tinyauthapp/tinyauth/internal/controller" + "github.com/tinyauthapp/tinyauth/internal/repository/memory" "github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/utils/tlog" "github.com/stretchr/testify/assert" @@ -100,11 +100,10 @@ func TestWellKnownController(t *testing.T) { }, } - store, err := bootstrap.NewSQLiteStore(path.Join(tempDir, "tinyauth.db")) - require.NoError(t, err) + store := memory.New() oidcService := service.NewOIDCService(oidcServiceCfg, store) - err = oidcService.Init() + err := oidcService.Init() require.NoError(t, err) for _, test := range tests { diff --git a/internal/repository/memory/oidc_queries.go b/internal/repository/memory/oidc_queries.go new file mode 100644 index 00000000..80305fc0 --- /dev/null +++ b/internal/repository/memory/oidc_queries.go @@ -0,0 +1,241 @@ +package memory + +import ( + "context" + "fmt" + + "github.com/tinyauthapp/tinyauth/internal/repository" +) + +func (s *Store) CreateOidcCode(_ context.Context, arg repository.CreateOidcCodeParams) (repository.OidcCode, error) { + s.mu.Lock() + defer s.mu.Unlock() + // Enforce sub UNIQUE constraint + for _, c := range s.oidcCodes { + if c.Sub == arg.Sub { + return repository.OidcCode{}, fmt.Errorf("UNIQUE constraint failed: oidc_codes.sub") + } + } + code := repository.OidcCode(arg) + s.oidcCodes[arg.CodeHash] = code + return code, nil +} + +// GetOidcCode is a destructive read: it deletes and returns the code (mirrors SQLite's DELETE...RETURNING). +func (s *Store) GetOidcCode(_ context.Context, codeHash string) (repository.OidcCode, error) { + s.mu.Lock() + defer s.mu.Unlock() + c, ok := s.oidcCodes[codeHash] + if !ok { + return repository.OidcCode{}, repository.ErrNotFound + } + delete(s.oidcCodes, codeHash) + return c, nil +} + +// GetOidcCodeBySub is a destructive read: it deletes and returns the code (mirrors SQLite's DELETE...RETURNING). +func (s *Store) GetOidcCodeBySub(_ context.Context, sub string) (repository.OidcCode, error) { + s.mu.Lock() + defer s.mu.Unlock() + for k, c := range s.oidcCodes { + if c.Sub == sub { + delete(s.oidcCodes, k) + return c, nil + } + } + return repository.OidcCode{}, repository.ErrNotFound +} + +// GetOidcCodeUnsafe is a non-destructive read (mirrors SQLite's SELECT). +func (s *Store) GetOidcCodeUnsafe(_ context.Context, codeHash string) (repository.OidcCode, error) { + s.mu.RLock() + defer s.mu.RUnlock() + c, ok := s.oidcCodes[codeHash] + if !ok { + return repository.OidcCode{}, repository.ErrNotFound + } + return c, nil +} + +// GetOidcCodeBySubUnsafe is a non-destructive read (mirrors SQLite's SELECT). +func (s *Store) GetOidcCodeBySubUnsafe(_ context.Context, sub string) (repository.OidcCode, error) { + s.mu.RLock() + defer s.mu.RUnlock() + for _, c := range s.oidcCodes { + if c.Sub == sub { + return c, nil + } + } + return repository.OidcCode{}, repository.ErrNotFound +} + +func (s *Store) DeleteOidcCode(_ context.Context, codeHash string) error { + s.mu.Lock() + defer s.mu.Unlock() + delete(s.oidcCodes, codeHash) + return nil +} + +func (s *Store) DeleteOidcCodeBySub(_ context.Context, sub string) error { + s.mu.Lock() + defer s.mu.Unlock() + for k, c := range s.oidcCodes { + if c.Sub == sub { + delete(s.oidcCodes, k) + } + } + return nil +} + +func (s *Store) DeleteExpiredOidcCodes(_ context.Context, expiresAt int64) ([]repository.OidcCode, error) { + s.mu.Lock() + defer s.mu.Unlock() + var deleted []repository.OidcCode + for k, c := range s.oidcCodes { + if c.ExpiresAt < expiresAt { + deleted = append(deleted, c) + delete(s.oidcCodes, k) + } + } + return deleted, nil +} + +func (s *Store) CreateOidcToken(_ context.Context, arg repository.CreateOidcTokenParams) (repository.OidcToken, error) { + s.mu.Lock() + defer s.mu.Unlock() + // Enforce sub UNIQUE constraint + for _, t := range s.oidcTokens { + if t.Sub == arg.Sub { + return repository.OidcToken{}, fmt.Errorf("UNIQUE constraint failed: oidc_tokens.sub") + } + } + tok := repository.OidcToken{ + Sub: arg.Sub, + AccessTokenHash: arg.AccessTokenHash, + RefreshTokenHash: arg.RefreshTokenHash, + CodeHash: arg.CodeHash, + Scope: arg.Scope, + ClientID: arg.ClientID, + TokenExpiresAt: arg.TokenExpiresAt, + RefreshTokenExpiresAt: arg.RefreshTokenExpiresAt, + Nonce: arg.Nonce, + } + s.oidcTokens[arg.AccessTokenHash] = tok + return tok, nil +} + +func (s *Store) GetOidcToken(_ context.Context, accessTokenHash string) (repository.OidcToken, error) { + s.mu.RLock() + defer s.mu.RUnlock() + t, ok := s.oidcTokens[accessTokenHash] + if !ok { + return repository.OidcToken{}, repository.ErrNotFound + } + return t, nil +} + +func (s *Store) GetOidcTokenByRefreshToken(_ context.Context, refreshTokenHash string) (repository.OidcToken, error) { + s.mu.RLock() + defer s.mu.RUnlock() + for _, t := range s.oidcTokens { + if t.RefreshTokenHash == refreshTokenHash { + return t, nil + } + } + return repository.OidcToken{}, repository.ErrNotFound +} + +func (s *Store) GetOidcTokenBySub(_ context.Context, sub string) (repository.OidcToken, error) { + s.mu.RLock() + defer s.mu.RUnlock() + for _, t := range s.oidcTokens { + if t.Sub == sub { + return t, nil + } + } + return repository.OidcToken{}, repository.ErrNotFound +} + +func (s *Store) UpdateOidcTokenByRefreshToken(_ context.Context, arg repository.UpdateOidcTokenByRefreshTokenParams) (repository.OidcToken, error) { + s.mu.Lock() + defer s.mu.Unlock() + for k, t := range s.oidcTokens { + if t.RefreshTokenHash == arg.RefreshTokenHash_2 { + delete(s.oidcTokens, k) + t.AccessTokenHash = arg.AccessTokenHash + t.RefreshTokenHash = arg.RefreshTokenHash + t.TokenExpiresAt = arg.TokenExpiresAt + t.RefreshTokenExpiresAt = arg.RefreshTokenExpiresAt + s.oidcTokens[arg.AccessTokenHash] = t + return t, nil + } + } + return repository.OidcToken{}, repository.ErrNotFound +} + +func (s *Store) DeleteOidcToken(_ context.Context, accessTokenHash string) error { + s.mu.Lock() + defer s.mu.Unlock() + delete(s.oidcTokens, accessTokenHash) + return nil +} + +func (s *Store) DeleteOidcTokenBySub(_ context.Context, sub string) error { + s.mu.Lock() + defer s.mu.Unlock() + for k, t := range s.oidcTokens { + if t.Sub == sub { + delete(s.oidcTokens, k) + } + } + return nil +} + +func (s *Store) DeleteOidcTokenByCodeHash(_ context.Context, codeHash string) error { + s.mu.Lock() + defer s.mu.Unlock() + for k, t := range s.oidcTokens { + if t.CodeHash == codeHash { + delete(s.oidcTokens, k) + } + } + return nil +} + +func (s *Store) DeleteExpiredOidcTokens(_ context.Context, arg repository.DeleteExpiredOidcTokensParams) ([]repository.OidcToken, error) { + s.mu.Lock() + defer s.mu.Unlock() + var deleted []repository.OidcToken + for k, t := range s.oidcTokens { + if t.TokenExpiresAt < arg.TokenExpiresAt || t.RefreshTokenExpiresAt < arg.RefreshTokenExpiresAt { + deleted = append(deleted, t) + delete(s.oidcTokens, k) + } + } + return deleted, nil +} + +func (s *Store) CreateOidcUserInfo(_ context.Context, arg repository.CreateOidcUserInfoParams) (repository.OidcUserinfo, error) { + s.mu.Lock() + defer s.mu.Unlock() + u := repository.OidcUserinfo(arg) + s.oidcUsers[arg.Sub] = u + return u, nil +} + +func (s *Store) GetOidcUserInfo(_ context.Context, sub string) (repository.OidcUserinfo, error) { + s.mu.RLock() + defer s.mu.RUnlock() + u, ok := s.oidcUsers[sub] + if !ok { + return repository.OidcUserinfo{}, repository.ErrNotFound + } + return u, nil +} + +func (s *Store) DeleteOidcUserInfo(_ context.Context, sub string) error { + s.mu.Lock() + defer s.mu.Unlock() + delete(s.oidcUsers, sub) + return nil +} diff --git a/internal/repository/memory/session_queries.go b/internal/repository/memory/session_queries.go new file mode 100644 index 00000000..2edde6b1 --- /dev/null +++ b/internal/repository/memory/session_queries.go @@ -0,0 +1,63 @@ +package memory + +import ( + "context" + + "github.com/tinyauthapp/tinyauth/internal/repository" +) + +func (s *Store) CreateSession(_ context.Context, arg repository.CreateSessionParams) (repository.Session, error) { + s.mu.Lock() + defer s.mu.Unlock() + sess := repository.Session(arg) + s.sessions[arg.UUID] = sess + return sess, nil +} + +func (s *Store) GetSession(_ context.Context, uuid string) (repository.Session, error) { + s.mu.RLock() + defer s.mu.RUnlock() + sess, ok := s.sessions[uuid] + if !ok { + return repository.Session{}, repository.ErrNotFound + } + return sess, nil +} + +func (s *Store) UpdateSession(_ context.Context, arg repository.UpdateSessionParams) (repository.Session, error) { + s.mu.Lock() + defer s.mu.Unlock() + sess, ok := s.sessions[arg.UUID] + if !ok { + return repository.Session{}, repository.ErrNotFound + } + sess.Username = arg.Username + sess.Email = arg.Email + sess.Name = arg.Name + sess.Provider = arg.Provider + sess.TotpPending = arg.TotpPending + sess.OAuthGroups = arg.OAuthGroups + sess.Expiry = arg.Expiry + sess.OAuthName = arg.OAuthName + sess.OAuthSub = arg.OAuthSub + s.sessions[arg.UUID] = sess + return sess, nil +} + +func (s *Store) DeleteSession(_ context.Context, uuid string) error { + s.mu.Lock() + defer s.mu.Unlock() + delete(s.sessions, uuid) + return nil +} + +func (s *Store) DeleteExpiredSessions(_ context.Context, expiry int64) error { + s.mu.Lock() + defer s.mu.Unlock() + for k, v := range s.sessions { + if v.Expiry < expiry { + delete(s.sessions, k) + } + } + return nil +} diff --git a/internal/repository/memory/store.go b/internal/repository/memory/store.go new file mode 100644 index 00000000..969cba66 --- /dev/null +++ b/internal/repository/memory/store.go @@ -0,0 +1,27 @@ +// Package memory provides an in-memory implementation of repository.Store for use in tests. +package memory + +import ( + "sync" + + "github.com/tinyauthapp/tinyauth/internal/repository" +) + +// Store is a thread-safe in-memory implementation of repository.Store. +type Store struct { + mu sync.RWMutex + sessions map[string]repository.Session + oidcCodes map[string]repository.OidcCode + oidcTokens map[string]repository.OidcToken + oidcUsers map[string]repository.OidcUserinfo +} + +// New returns a new empty in-memory Store. +func New() repository.Store { + return &Store{ + sessions: make(map[string]repository.Session), + oidcCodes: make(map[string]repository.OidcCode), + oidcTokens: make(map[string]repository.OidcToken), + oidcUsers: make(map[string]repository.OidcUserinfo), + } +} diff --git a/internal/repository/sqlite/store.go b/internal/repository/sqlite/store.go index 65b4e190..f316efa4 100644 --- a/internal/repository/sqlite/store.go +++ b/internal/repository/sqlite/store.go @@ -3,6 +3,8 @@ package sqlite import ( "context" + "database/sql" + "errors" "github.com/tinyauthapp/tinyauth/internal/repository" ) @@ -17,6 +19,22 @@ func NewStore(q *Queries) repository.Store { return &Store{q: q} } +var errMap = []struct { + from error + to error +}{ + {sql.ErrNoRows, repository.ErrNotFound}, +} + +func mapErr(err error) error { + for _, e := range errMap { + if errors.Is(err, e.from) { + return e.to + } + } + return err +} + func oidcCodeToRepo(v OidcCode) repository.OidcCode { return repository.OidcCode(v) } @@ -32,7 +50,7 @@ func sessionToRepo(v Session) repository.Session { func (s *Store) CreateOidcCode(ctx context.Context, arg repository.CreateOidcCodeParams) (repository.OidcCode, error) { r, err := s.q.CreateOidcCode(ctx, CreateOidcCodeParams(arg)) if err != nil { - return repository.OidcCode{}, err + return repository.OidcCode{}, mapErr(err) } return oidcCodeToRepo(r), nil } @@ -40,7 +58,7 @@ func (s *Store) CreateOidcCode(ctx context.Context, arg repository.CreateOidcCod func (s *Store) CreateOidcToken(ctx context.Context, arg repository.CreateOidcTokenParams) (repository.OidcToken, error) { r, err := s.q.CreateOidcToken(ctx, CreateOidcTokenParams(arg)) if err != nil { - return repository.OidcToken{}, err + return repository.OidcToken{}, mapErr(err) } return oidcTokenToRepo(r), nil } @@ -48,7 +66,7 @@ func (s *Store) CreateOidcToken(ctx context.Context, arg repository.CreateOidcTo func (s *Store) CreateOidcUserInfo(ctx context.Context, arg repository.CreateOidcUserInfoParams) (repository.OidcUserinfo, error) { r, err := s.q.CreateOidcUserInfo(ctx, CreateOidcUserInfoParams(arg)) if err != nil { - return repository.OidcUserinfo{}, err + return repository.OidcUserinfo{}, mapErr(err) } return oidcUserinfoToRepo(r), nil } @@ -56,7 +74,7 @@ func (s *Store) CreateOidcUserInfo(ctx context.Context, arg repository.CreateOid func (s *Store) CreateSession(ctx context.Context, arg repository.CreateSessionParams) (repository.Session, error) { r, err := s.q.CreateSession(ctx, CreateSessionParams(arg)) if err != nil { - return repository.Session{}, err + return repository.Session{}, mapErr(err) } return sessionToRepo(r), nil } @@ -64,7 +82,7 @@ func (s *Store) CreateSession(ctx context.Context, arg repository.CreateSessionP func (s *Store) DeleteExpiredOidcCodes(ctx context.Context, expiresAt int64) ([]repository.OidcCode, error) { rows, err := s.q.DeleteExpiredOidcCodes(ctx, expiresAt) if err != nil { - return nil, err + return nil, mapErr(err) } out := make([]repository.OidcCode, len(rows)) for i, row := range rows { @@ -76,7 +94,7 @@ func (s *Store) DeleteExpiredOidcCodes(ctx context.Context, expiresAt int64) ([] func (s *Store) DeleteExpiredOidcTokens(ctx context.Context, arg repository.DeleteExpiredOidcTokensParams) ([]repository.OidcToken, error) { rows, err := s.q.DeleteExpiredOidcTokens(ctx, DeleteExpiredOidcTokensParams(arg)) if err != nil { - return nil, err + return nil, mapErr(err) } out := make([]repository.OidcToken, len(rows)) for i, row := range rows { @@ -86,41 +104,41 @@ func (s *Store) DeleteExpiredOidcTokens(ctx context.Context, arg repository.Dele } func (s *Store) DeleteExpiredSessions(ctx context.Context, expiry int64) error { - return s.q.DeleteExpiredSessions(ctx, expiry) + return mapErr(s.q.DeleteExpiredSessions(ctx, expiry)) } func (s *Store) DeleteOidcCode(ctx context.Context, codeHash string) error { - return s.q.DeleteOidcCode(ctx, codeHash) + return mapErr(s.q.DeleteOidcCode(ctx, codeHash)) } func (s *Store) DeleteOidcCodeBySub(ctx context.Context, sub string) error { - return s.q.DeleteOidcCodeBySub(ctx, sub) + return mapErr(s.q.DeleteOidcCodeBySub(ctx, sub)) } func (s *Store) DeleteOidcToken(ctx context.Context, accessTokenHash string) error { - return s.q.DeleteOidcToken(ctx, accessTokenHash) + return mapErr(s.q.DeleteOidcToken(ctx, accessTokenHash)) } func (s *Store) DeleteOidcTokenByCodeHash(ctx context.Context, codeHash string) error { - return s.q.DeleteOidcTokenByCodeHash(ctx, codeHash) + return mapErr(s.q.DeleteOidcTokenByCodeHash(ctx, codeHash)) } func (s *Store) DeleteOidcTokenBySub(ctx context.Context, sub string) error { - return s.q.DeleteOidcTokenBySub(ctx, sub) + return mapErr(s.q.DeleteOidcTokenBySub(ctx, sub)) } func (s *Store) DeleteOidcUserInfo(ctx context.Context, sub string) error { - return s.q.DeleteOidcUserInfo(ctx, sub) + return mapErr(s.q.DeleteOidcUserInfo(ctx, sub)) } func (s *Store) DeleteSession(ctx context.Context, uuid string) error { - return s.q.DeleteSession(ctx, uuid) + return mapErr(s.q.DeleteSession(ctx, uuid)) } func (s *Store) GetOidcCode(ctx context.Context, codeHash string) (repository.OidcCode, error) { r, err := s.q.GetOidcCode(ctx, codeHash) if err != nil { - return repository.OidcCode{}, err + return repository.OidcCode{}, mapErr(err) } return oidcCodeToRepo(r), nil } @@ -128,7 +146,7 @@ func (s *Store) GetOidcCode(ctx context.Context, codeHash string) (repository.Oi func (s *Store) GetOidcCodeBySub(ctx context.Context, sub string) (repository.OidcCode, error) { r, err := s.q.GetOidcCodeBySub(ctx, sub) if err != nil { - return repository.OidcCode{}, err + return repository.OidcCode{}, mapErr(err) } return oidcCodeToRepo(r), nil } @@ -136,7 +154,7 @@ func (s *Store) GetOidcCodeBySub(ctx context.Context, sub string) (repository.Oi func (s *Store) GetOidcCodeBySubUnsafe(ctx context.Context, sub string) (repository.OidcCode, error) { r, err := s.q.GetOidcCodeBySubUnsafe(ctx, sub) if err != nil { - return repository.OidcCode{}, err + return repository.OidcCode{}, mapErr(err) } return oidcCodeToRepo(r), nil } @@ -144,7 +162,7 @@ func (s *Store) GetOidcCodeBySubUnsafe(ctx context.Context, sub string) (reposit func (s *Store) GetOidcCodeUnsafe(ctx context.Context, codeHash string) (repository.OidcCode, error) { r, err := s.q.GetOidcCodeUnsafe(ctx, codeHash) if err != nil { - return repository.OidcCode{}, err + return repository.OidcCode{}, mapErr(err) } return oidcCodeToRepo(r), nil } @@ -152,7 +170,7 @@ func (s *Store) GetOidcCodeUnsafe(ctx context.Context, codeHash string) (reposit func (s *Store) GetOidcToken(ctx context.Context, accessTokenHash string) (repository.OidcToken, error) { r, err := s.q.GetOidcToken(ctx, accessTokenHash) if err != nil { - return repository.OidcToken{}, err + return repository.OidcToken{}, mapErr(err) } return oidcTokenToRepo(r), nil } @@ -160,7 +178,7 @@ func (s *Store) GetOidcToken(ctx context.Context, accessTokenHash string) (repos func (s *Store) GetOidcTokenByRefreshToken(ctx context.Context, refreshTokenHash string) (repository.OidcToken, error) { r, err := s.q.GetOidcTokenByRefreshToken(ctx, refreshTokenHash) if err != nil { - return repository.OidcToken{}, err + return repository.OidcToken{}, mapErr(err) } return oidcTokenToRepo(r), nil } @@ -168,7 +186,7 @@ func (s *Store) GetOidcTokenByRefreshToken(ctx context.Context, refreshTokenHash func (s *Store) GetOidcTokenBySub(ctx context.Context, sub string) (repository.OidcToken, error) { r, err := s.q.GetOidcTokenBySub(ctx, sub) if err != nil { - return repository.OidcToken{}, err + return repository.OidcToken{}, mapErr(err) } return oidcTokenToRepo(r), nil } @@ -176,7 +194,7 @@ func (s *Store) GetOidcTokenBySub(ctx context.Context, sub string) (repository.O func (s *Store) GetOidcUserInfo(ctx context.Context, sub string) (repository.OidcUserinfo, error) { r, err := s.q.GetOidcUserInfo(ctx, sub) if err != nil { - return repository.OidcUserinfo{}, err + return repository.OidcUserinfo{}, mapErr(err) } return oidcUserinfoToRepo(r), nil } @@ -184,7 +202,7 @@ func (s *Store) GetOidcUserInfo(ctx context.Context, sub string) (repository.Oid func (s *Store) GetSession(ctx context.Context, uuid string) (repository.Session, error) { r, err := s.q.GetSession(ctx, uuid) if err != nil { - return repository.Session{}, err + return repository.Session{}, mapErr(err) } return sessionToRepo(r), nil } @@ -192,7 +210,7 @@ func (s *Store) GetSession(ctx context.Context, uuid string) (repository.Session func (s *Store) UpdateOidcTokenByRefreshToken(ctx context.Context, arg repository.UpdateOidcTokenByRefreshTokenParams) (repository.OidcToken, error) { r, err := s.q.UpdateOidcTokenByRefreshToken(ctx, UpdateOidcTokenByRefreshTokenParams(arg)) if err != nil { - return repository.OidcToken{}, err + return repository.OidcToken{}, mapErr(err) } return oidcTokenToRepo(r), nil } @@ -200,7 +218,7 @@ func (s *Store) UpdateOidcTokenByRefreshToken(ctx context.Context, arg repositor func (s *Store) UpdateSession(ctx context.Context, arg repository.UpdateSessionParams) (repository.Session, error) { r, err := s.q.UpdateSession(ctx, UpdateSessionParams(arg)) if err != nil { - return repository.Session{}, err + return repository.Session{}, mapErr(err) } return sessionToRepo(r), nil } diff --git a/internal/repository/store.go b/internal/repository/store.go index 765df6a5..302f2f10 100644 --- a/internal/repository/store.go +++ b/internal/repository/store.go @@ -1,6 +1,12 @@ package repository -import "context" +import ( + "context" + "errors" +) + +// ErrNotFound is returned by Store methods when the requested record does not exist. +var ErrNotFound = errors.New("not found") // Store is the interface that all storage drivers must implement. // The sqlc-generated *Queries struct satisfies this interface for SQLite. diff --git a/internal/service/auth_service.go b/internal/service/auth_service.go index ab343396..5d2ead2f 100644 --- a/internal/service/auth_service.go +++ b/internal/service/auth_service.go @@ -2,7 +2,6 @@ package service import ( "context" - "database/sql" "errors" "fmt" "regexp" @@ -411,7 +410,7 @@ func (auth *AuthService) GetSessionCookie(c *gin.Context) (repository.Session, e session, err := auth.queries.GetSession(c, cookie) if err != nil { - if errors.Is(err, sql.ErrNoRows) { + if errors.Is(err, repository.ErrNotFound) { return repository.Session{}, fmt.Errorf("session not found") } return repository.Session{}, err diff --git a/internal/service/oidc_service.go b/internal/service/oidc_service.go index e5f7ea76..14d94f61 100644 --- a/internal/service/oidc_service.go +++ b/internal/service/oidc_service.go @@ -7,7 +7,6 @@ import ( "crypto/rsa" "crypto/sha256" "crypto/x509" - "database/sql" "encoding/base64" "encoding/json" "encoding/pem" @@ -420,7 +419,7 @@ func (service *OIDCService) GetCodeEntry(c *gin.Context, codeHash string, client oidcCode, err := service.queries.GetOidcCode(c, codeHash) if err != nil { - if errors.Is(err, sql.ErrNoRows) { + if errors.Is(err, repository.ErrNotFound) { return repository.OidcCode{}, ErrCodeNotFound } return repository.OidcCode{}, err @@ -564,7 +563,7 @@ func (service *OIDCService) RefreshAccessToken(c *gin.Context, refreshToken stri entry, err := service.queries.GetOidcTokenByRefreshToken(c, service.Hash(refreshToken)) if err != nil { - if errors.Is(err, sql.ErrNoRows) { + if errors.Is(err, repository.ErrNotFound) { return TokenResponse{}, ErrTokenNotFound } return TokenResponse{}, err @@ -643,7 +642,7 @@ func (service *OIDCService) GetAccessToken(c *gin.Context, tokenHash string) (re entry, err := service.queries.GetOidcToken(c, tokenHash) if err != nil { - if errors.Is(err, sql.ErrNoRows) { + if errors.Is(err, repository.ErrNotFound) { return repository.OidcToken{}, ErrTokenNotFound } return repository.OidcToken{}, err @@ -731,15 +730,15 @@ func (service *OIDCService) Hash(token string) string { func (service *OIDCService) DeleteOldSession(ctx context.Context, sub string) error { err := service.queries.DeleteOidcCodeBySub(ctx, sub) - if err != nil && !errors.Is(err, sql.ErrNoRows) { + if err != nil && !errors.Is(err, repository.ErrNotFound) { return err } err = service.queries.DeleteOidcTokenBySub(ctx, sub) - if err != nil && !errors.Is(err, sql.ErrNoRows) { + if err != nil && !errors.Is(err, repository.ErrNotFound) { return err } err = service.queries.DeleteOidcUserInfo(ctx, sub) - if err != nil && !errors.Is(err, sql.ErrNoRows) { + if err != nil && !errors.Is(err, repository.ErrNotFound) { return err } return nil @@ -784,7 +783,7 @@ func (service *OIDCService) Cleanup() { token, err := service.queries.GetOidcTokenBySub(ctx, expiredCode.Sub) if err != nil { - if errors.Is(err, sql.ErrNoRows) { + if errors.Is(err, repository.ErrNotFound) { continue } tlog.App.Warn().Err(err).Msg("Failed to get OIDC token by sub")