diff --git a/cmd/gen/sqlc-wrapper/main.go b/cmd/gen/sqlc-wrapper/main.go index e66ae8ee..d6cb6318 100644 --- a/cmd/gen/sqlc-wrapper/main.go +++ b/cmd/gen/sqlc-wrapper/main.go @@ -449,18 +449,18 @@ func buildBody(m methodInfo) string { // no repo-typed result → direct return if len(m.Results) == 0 || m.Results[0].RepoType == "" { - return "\treturn " + call + "\n" + return "\treturn mapErr(" + call + ")\n" } r := m.Results[0] if r.IsSlice { return fmt.Sprintf( - "\trows, err := %s\n\tif err != nil {\n\t\treturn nil, err\n\t}\n\tout := make([]%s, len(rows))\n\tfor i, row := range rows {\n\t\tout[i] = %s(row)\n\t}\n\treturn out, nil\n", + "\trows, err := %s\n\tif err != nil {\n\t\treturn nil, mapErr(err)\n\t}\n\tout := make([]%s, len(rows))\n\tfor i, row := range rows {\n\t\tout[i] = %s(row)\n\t}\n\treturn out, nil\n", call, r.RepoType, converterFn(r.TypeStr), ) } return fmt.Sprintf( - "\tr, err := %s\n\tif err != nil {\n\t\treturn %s{}, err\n\t}\n\treturn %s(r), nil\n", + "\tr, err := %s\n\tif err != nil {\n\t\treturn %s{}, mapErr(err)\n\t}\n\treturn %s(r), nil\n", call, r.RepoType, converterFn(r.TypeStr), ) } @@ -477,6 +477,8 @@ package {{.PkgName}} import ( "context" + "database/sql" + "errors" "{{.RepoPkg}}" ) @@ -491,6 +493,22 @@ func NewStore(q *Queries) repository.Store { return &Store{q: q} } +var errMap = []struct { + from error + to error +}{ + {sql.ErrNoRows, repository.ErrNotFound}, +} + +func mapErr(err error) error { + for _, e := range errMap { + if errors.Is(err, e.from) { + return e.to + } + } + return err +} + {{range .ModelTypes -}} func {{converterFn .}}(v {{.}}) repository.{{.}} { return repository.{{.}}(v) diff --git a/internal/bootstrap/db_bootstrap.go b/internal/bootstrap/db_bootstrap.go index 2279cb23..4f09372a 100644 --- a/internal/bootstrap/db_bootstrap.go +++ b/internal/bootstrap/db_bootstrap.go @@ -8,6 +8,7 @@ import ( "github.com/tinyauthapp/tinyauth/internal/assets" "github.com/tinyauthapp/tinyauth/internal/repository" + "github.com/tinyauthapp/tinyauth/internal/repository/memory" "github.com/tinyauthapp/tinyauth/internal/repository/sqlite" "github.com/golang-migrate/migrate/v4" @@ -17,14 +18,14 @@ import ( ) func (app *BootstrapApp) SetupStore() (repository.Store, error) { - return app.setupSQLite(app.config.Database.Path) -} - -// NewSQLiteStore opens a SQLite database at the given path, runs migrations, and returns a Store. -// Useful for testing or when constructing a store outside of a BootstrapApp. -func NewSQLiteStore(databasePath string) (repository.Store, error) { - app := &BootstrapApp{} - return app.setupSQLite(databasePath) + switch app.config.Database.Driver { + case "memory": + return memory.New(), nil + case "sqlite", "": + return app.setupSQLite(app.config.Database.Path) + default: + return nil, fmt.Errorf("unknown database driver %q: valid values are sqlite, memory", app.config.Database.Driver) + } } func (app *BootstrapApp) setupSQLite(databasePath string) (repository.Store, error) { diff --git a/internal/config/config.go b/internal/config/config.go index 5b14e27e..9d2a8663 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -4,7 +4,8 @@ package config func NewDefaultConfiguration() *Config { return &Config{ Database: DatabaseConfig{ - Path: "./tinyauth.db", + Driver: "sqlite", + Path: "./tinyauth.db", }, Analytics: AnalyticsConfig{ Enabled: true, @@ -95,7 +96,8 @@ type Config struct { } type DatabaseConfig struct { - Path string `description:"The path to the SQLite database, including file name." yaml:"path"` + Driver string `description:"The database driver to use. Valid values: sqlite, memory." yaml:"driver"` + Path string `description:"The path to the SQLite database, including file name. Only used when driver is sqlite." yaml:"path"` } type AnalyticsConfig struct { diff --git a/internal/controller/oidc_controller_test.go b/internal/controller/oidc_controller_test.go index 991f6759..b83094c1 100644 --- a/internal/controller/oidc_controller_test.go +++ b/internal/controller/oidc_controller_test.go @@ -12,9 +12,9 @@ import ( "github.com/gin-gonic/gin" "github.com/google/go-querystring/query" - "github.com/tinyauthapp/tinyauth/internal/bootstrap" "github.com/tinyauthapp/tinyauth/internal/config" "github.com/tinyauthapp/tinyauth/internal/controller" + "github.com/tinyauthapp/tinyauth/internal/repository/memory" "github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/utils/tlog" "github.com/stretchr/testify/assert" @@ -847,11 +847,10 @@ func TestOIDCController(t *testing.T) { }, } - store, err := bootstrap.NewSQLiteStore(path.Join(tempDir, "tinyauth.db")) - require.NoError(t, err) + store := memory.New() oidcService := service.NewOIDCService(oidcServiceCfg, store) - err = oidcService.Init() + err := oidcService.Init() require.NoError(t, err) for _, test := range tests { diff --git a/internal/controller/proxy_controller_test.go b/internal/controller/proxy_controller_test.go index adfc7fb1..74bfdead 100644 --- a/internal/controller/proxy_controller_test.go +++ b/internal/controller/proxy_controller_test.go @@ -2,13 +2,12 @@ package controller_test import ( "net/http/httptest" - "path" "testing" "github.com/gin-gonic/gin" - "github.com/tinyauthapp/tinyauth/internal/bootstrap" "github.com/tinyauthapp/tinyauth/internal/config" "github.com/tinyauthapp/tinyauth/internal/controller" + "github.com/tinyauthapp/tinyauth/internal/repository/memory" "github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/utils/tlog" "github.com/stretchr/testify/assert" @@ -17,7 +16,6 @@ import ( func TestProxyController(t *testing.T) { tlog.NewTestLogger().Init() - tempDir := t.TempDir() authServiceCfg := service.AuthServiceConfig{ Users: []config.User{ @@ -392,11 +390,10 @@ func TestProxyController(t *testing.T) { oauthBrokerCfgs := make(map[string]config.OAuthServiceConfig) - store, err := bootstrap.NewSQLiteStore(path.Join(tempDir, "tinyauth.db")) - require.NoError(t, err) + store := memory.New() docker := service.NewDockerService() - err = docker.Init() + err := docker.Init() require.NoError(t, err) ldap := service.NewLdapService(service.LdapServiceConfig{}) diff --git a/internal/controller/user_controller_test.go b/internal/controller/user_controller_test.go index b67c70fa..1d6e11b2 100644 --- a/internal/controller/user_controller_test.go +++ b/internal/controller/user_controller_test.go @@ -3,16 +3,15 @@ package controller_test import ( "encoding/json" "net/http/httptest" - "path" "strings" "testing" "time" "github.com/gin-gonic/gin" "github.com/pquerna/otp/totp" - "github.com/tinyauthapp/tinyauth/internal/bootstrap" "github.com/tinyauthapp/tinyauth/internal/config" "github.com/tinyauthapp/tinyauth/internal/controller" + "github.com/tinyauthapp/tinyauth/internal/repository/memory" "github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/utils/tlog" "github.com/stretchr/testify/assert" @@ -21,7 +20,6 @@ import ( func TestUserController(t *testing.T) { tlog.NewTestLogger().Init() - tempDir := t.TempDir() authServiceCfg := service.AuthServiceConfig{ Users: []config.User{ @@ -350,11 +348,10 @@ func TestUserController(t *testing.T) { oauthBrokerCfgs := make(map[string]config.OAuthServiceConfig) - store, err := bootstrap.NewSQLiteStore(path.Join(tempDir, "tinyauth.db")) - require.NoError(t, err) + store := memory.New() docker := service.NewDockerService() - err = docker.Init() + err := docker.Init() require.NoError(t, err) ldap := service.NewLdapService(service.LdapServiceConfig{}) diff --git a/internal/controller/well_known_controller_test.go b/internal/controller/well_known_controller_test.go index eba449b0..25c8e5a8 100644 --- a/internal/controller/well_known_controller_test.go +++ b/internal/controller/well_known_controller_test.go @@ -8,9 +8,9 @@ import ( "testing" "github.com/gin-gonic/gin" - "github.com/tinyauthapp/tinyauth/internal/bootstrap" "github.com/tinyauthapp/tinyauth/internal/config" "github.com/tinyauthapp/tinyauth/internal/controller" + "github.com/tinyauthapp/tinyauth/internal/repository/memory" "github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/utils/tlog" "github.com/stretchr/testify/assert" @@ -100,11 +100,10 @@ func TestWellKnownController(t *testing.T) { }, } - store, err := bootstrap.NewSQLiteStore(path.Join(tempDir, "tinyauth.db")) - require.NoError(t, err) + store := memory.New() oidcService := service.NewOIDCService(oidcServiceCfg, store) - err = oidcService.Init() + err := oidcService.Init() require.NoError(t, err) for _, test := range tests { diff --git a/internal/repository/memory/oidc_queries.go b/internal/repository/memory/oidc_queries.go new file mode 100644 index 00000000..80305fc0 --- /dev/null +++ b/internal/repository/memory/oidc_queries.go @@ -0,0 +1,241 @@ +package memory + +import ( + "context" + "fmt" + + "github.com/tinyauthapp/tinyauth/internal/repository" +) + +func (s *Store) CreateOidcCode(_ context.Context, arg repository.CreateOidcCodeParams) (repository.OidcCode, error) { + s.mu.Lock() + defer s.mu.Unlock() + // Enforce sub UNIQUE constraint + for _, c := range s.oidcCodes { + if c.Sub == arg.Sub { + return repository.OidcCode{}, fmt.Errorf("UNIQUE constraint failed: oidc_codes.sub") + } + } + code := repository.OidcCode(arg) + s.oidcCodes[arg.CodeHash] = code + return code, nil +} + +// GetOidcCode is a destructive read: it deletes and returns the code (mirrors SQLite's DELETE...RETURNING). +func (s *Store) GetOidcCode(_ context.Context, codeHash string) (repository.OidcCode, error) { + s.mu.Lock() + defer s.mu.Unlock() + c, ok := s.oidcCodes[codeHash] + if !ok { + return repository.OidcCode{}, repository.ErrNotFound + } + delete(s.oidcCodes, codeHash) + return c, nil +} + +// GetOidcCodeBySub is a destructive read: it deletes and returns the code (mirrors SQLite's DELETE...RETURNING). +func (s *Store) GetOidcCodeBySub(_ context.Context, sub string) (repository.OidcCode, error) { + s.mu.Lock() + defer s.mu.Unlock() + for k, c := range s.oidcCodes { + if c.Sub == sub { + delete(s.oidcCodes, k) + return c, nil + } + } + return repository.OidcCode{}, repository.ErrNotFound +} + +// GetOidcCodeUnsafe is a non-destructive read (mirrors SQLite's SELECT). +func (s *Store) GetOidcCodeUnsafe(_ context.Context, codeHash string) (repository.OidcCode, error) { + s.mu.RLock() + defer s.mu.RUnlock() + c, ok := s.oidcCodes[codeHash] + if !ok { + return repository.OidcCode{}, repository.ErrNotFound + } + return c, nil +} + +// GetOidcCodeBySubUnsafe is a non-destructive read (mirrors SQLite's SELECT). +func (s *Store) GetOidcCodeBySubUnsafe(_ context.Context, sub string) (repository.OidcCode, error) { + s.mu.RLock() + defer s.mu.RUnlock() + for _, c := range s.oidcCodes { + if c.Sub == sub { + return c, nil + } + } + return repository.OidcCode{}, repository.ErrNotFound +} + +func (s *Store) DeleteOidcCode(_ context.Context, codeHash string) error { + s.mu.Lock() + defer s.mu.Unlock() + delete(s.oidcCodes, codeHash) + return nil +} + +func (s *Store) DeleteOidcCodeBySub(_ context.Context, sub string) error { + s.mu.Lock() + defer s.mu.Unlock() + for k, c := range s.oidcCodes { + if c.Sub == sub { + delete(s.oidcCodes, k) + } + } + return nil +} + +func (s *Store) DeleteExpiredOidcCodes(_ context.Context, expiresAt int64) ([]repository.OidcCode, error) { + s.mu.Lock() + defer s.mu.Unlock() + var deleted []repository.OidcCode + for k, c := range s.oidcCodes { + if c.ExpiresAt < expiresAt { + deleted = append(deleted, c) + delete(s.oidcCodes, k) + } + } + return deleted, nil +} + +func (s *Store) CreateOidcToken(_ context.Context, arg repository.CreateOidcTokenParams) (repository.OidcToken, error) { + s.mu.Lock() + defer s.mu.Unlock() + // Enforce sub UNIQUE constraint + for _, t := range s.oidcTokens { + if t.Sub == arg.Sub { + return repository.OidcToken{}, fmt.Errorf("UNIQUE constraint failed: oidc_tokens.sub") + } + } + tok := repository.OidcToken{ + Sub: arg.Sub, + AccessTokenHash: arg.AccessTokenHash, + RefreshTokenHash: arg.RefreshTokenHash, + CodeHash: arg.CodeHash, + Scope: arg.Scope, + ClientID: arg.ClientID, + TokenExpiresAt: arg.TokenExpiresAt, + RefreshTokenExpiresAt: arg.RefreshTokenExpiresAt, + Nonce: arg.Nonce, + } + s.oidcTokens[arg.AccessTokenHash] = tok + return tok, nil +} + +func (s *Store) GetOidcToken(_ context.Context, accessTokenHash string) (repository.OidcToken, error) { + s.mu.RLock() + defer s.mu.RUnlock() + t, ok := s.oidcTokens[accessTokenHash] + if !ok { + return repository.OidcToken{}, repository.ErrNotFound + } + return t, nil +} + +func (s *Store) GetOidcTokenByRefreshToken(_ context.Context, refreshTokenHash string) (repository.OidcToken, error) { + s.mu.RLock() + defer s.mu.RUnlock() + for _, t := range s.oidcTokens { + if t.RefreshTokenHash == refreshTokenHash { + return t, nil + } + } + return repository.OidcToken{}, repository.ErrNotFound +} + +func (s *Store) GetOidcTokenBySub(_ context.Context, sub string) (repository.OidcToken, error) { + s.mu.RLock() + defer s.mu.RUnlock() + for _, t := range s.oidcTokens { + if t.Sub == sub { + return t, nil + } + } + return repository.OidcToken{}, repository.ErrNotFound +} + +func (s *Store) UpdateOidcTokenByRefreshToken(_ context.Context, arg repository.UpdateOidcTokenByRefreshTokenParams) (repository.OidcToken, error) { + s.mu.Lock() + defer s.mu.Unlock() + for k, t := range s.oidcTokens { + if t.RefreshTokenHash == arg.RefreshTokenHash_2 { + delete(s.oidcTokens, k) + t.AccessTokenHash = arg.AccessTokenHash + t.RefreshTokenHash = arg.RefreshTokenHash + t.TokenExpiresAt = arg.TokenExpiresAt + t.RefreshTokenExpiresAt = arg.RefreshTokenExpiresAt + s.oidcTokens[arg.AccessTokenHash] = t + return t, nil + } + } + return repository.OidcToken{}, repository.ErrNotFound +} + +func (s *Store) DeleteOidcToken(_ context.Context, accessTokenHash string) error { + s.mu.Lock() + defer s.mu.Unlock() + delete(s.oidcTokens, accessTokenHash) + return nil +} + +func (s *Store) DeleteOidcTokenBySub(_ context.Context, sub string) error { + s.mu.Lock() + defer s.mu.Unlock() + for k, t := range s.oidcTokens { + if t.Sub == sub { + delete(s.oidcTokens, k) + } + } + return nil +} + +func (s *Store) DeleteOidcTokenByCodeHash(_ context.Context, codeHash string) error { + s.mu.Lock() + defer s.mu.Unlock() + for k, t := range s.oidcTokens { + if t.CodeHash == codeHash { + delete(s.oidcTokens, k) + } + } + return nil +} + +func (s *Store) DeleteExpiredOidcTokens(_ context.Context, arg repository.DeleteExpiredOidcTokensParams) ([]repository.OidcToken, error) { + s.mu.Lock() + defer s.mu.Unlock() + var deleted []repository.OidcToken + for k, t := range s.oidcTokens { + if t.TokenExpiresAt < arg.TokenExpiresAt || t.RefreshTokenExpiresAt < arg.RefreshTokenExpiresAt { + deleted = append(deleted, t) + delete(s.oidcTokens, k) + } + } + return deleted, nil +} + +func (s *Store) CreateOidcUserInfo(_ context.Context, arg repository.CreateOidcUserInfoParams) (repository.OidcUserinfo, error) { + s.mu.Lock() + defer s.mu.Unlock() + u := repository.OidcUserinfo(arg) + s.oidcUsers[arg.Sub] = u + return u, nil +} + +func (s *Store) GetOidcUserInfo(_ context.Context, sub string) (repository.OidcUserinfo, error) { + s.mu.RLock() + defer s.mu.RUnlock() + u, ok := s.oidcUsers[sub] + if !ok { + return repository.OidcUserinfo{}, repository.ErrNotFound + } + return u, nil +} + +func (s *Store) DeleteOidcUserInfo(_ context.Context, sub string) error { + s.mu.Lock() + defer s.mu.Unlock() + delete(s.oidcUsers, sub) + return nil +} diff --git a/internal/repository/memory/session_queries.go b/internal/repository/memory/session_queries.go new file mode 100644 index 00000000..2edde6b1 --- /dev/null +++ b/internal/repository/memory/session_queries.go @@ -0,0 +1,63 @@ +package memory + +import ( + "context" + + "github.com/tinyauthapp/tinyauth/internal/repository" +) + +func (s *Store) CreateSession(_ context.Context, arg repository.CreateSessionParams) (repository.Session, error) { + s.mu.Lock() + defer s.mu.Unlock() + sess := repository.Session(arg) + s.sessions[arg.UUID] = sess + return sess, nil +} + +func (s *Store) GetSession(_ context.Context, uuid string) (repository.Session, error) { + s.mu.RLock() + defer s.mu.RUnlock() + sess, ok := s.sessions[uuid] + if !ok { + return repository.Session{}, repository.ErrNotFound + } + return sess, nil +} + +func (s *Store) UpdateSession(_ context.Context, arg repository.UpdateSessionParams) (repository.Session, error) { + s.mu.Lock() + defer s.mu.Unlock() + sess, ok := s.sessions[arg.UUID] + if !ok { + return repository.Session{}, repository.ErrNotFound + } + sess.Username = arg.Username + sess.Email = arg.Email + sess.Name = arg.Name + sess.Provider = arg.Provider + sess.TotpPending = arg.TotpPending + sess.OAuthGroups = arg.OAuthGroups + sess.Expiry = arg.Expiry + sess.OAuthName = arg.OAuthName + sess.OAuthSub = arg.OAuthSub + s.sessions[arg.UUID] = sess + return sess, nil +} + +func (s *Store) DeleteSession(_ context.Context, uuid string) error { + s.mu.Lock() + defer s.mu.Unlock() + delete(s.sessions, uuid) + return nil +} + +func (s *Store) DeleteExpiredSessions(_ context.Context, expiry int64) error { + s.mu.Lock() + defer s.mu.Unlock() + for k, v := range s.sessions { + if v.Expiry < expiry { + delete(s.sessions, k) + } + } + return nil +} diff --git a/internal/repository/memory/store.go b/internal/repository/memory/store.go new file mode 100644 index 00000000..969cba66 --- /dev/null +++ b/internal/repository/memory/store.go @@ -0,0 +1,27 @@ +// Package memory provides an in-memory implementation of repository.Store for use in tests. +package memory + +import ( + "sync" + + "github.com/tinyauthapp/tinyauth/internal/repository" +) + +// Store is a thread-safe in-memory implementation of repository.Store. +type Store struct { + mu sync.RWMutex + sessions map[string]repository.Session + oidcCodes map[string]repository.OidcCode + oidcTokens map[string]repository.OidcToken + oidcUsers map[string]repository.OidcUserinfo +} + +// New returns a new empty in-memory Store. +func New() repository.Store { + return &Store{ + sessions: make(map[string]repository.Session), + oidcCodes: make(map[string]repository.OidcCode), + oidcTokens: make(map[string]repository.OidcToken), + oidcUsers: make(map[string]repository.OidcUserinfo), + } +} diff --git a/internal/repository/sqlite/store.go b/internal/repository/sqlite/store.go index 65b4e190..f316efa4 100644 --- a/internal/repository/sqlite/store.go +++ b/internal/repository/sqlite/store.go @@ -3,6 +3,8 @@ package sqlite import ( "context" + "database/sql" + "errors" "github.com/tinyauthapp/tinyauth/internal/repository" ) @@ -17,6 +19,22 @@ func NewStore(q *Queries) repository.Store { return &Store{q: q} } +var errMap = []struct { + from error + to error +}{ + {sql.ErrNoRows, repository.ErrNotFound}, +} + +func mapErr(err error) error { + for _, e := range errMap { + if errors.Is(err, e.from) { + return e.to + } + } + return err +} + func oidcCodeToRepo(v OidcCode) repository.OidcCode { return repository.OidcCode(v) } @@ -32,7 +50,7 @@ func sessionToRepo(v Session) repository.Session { func (s *Store) CreateOidcCode(ctx context.Context, arg repository.CreateOidcCodeParams) (repository.OidcCode, error) { r, err := s.q.CreateOidcCode(ctx, CreateOidcCodeParams(arg)) if err != nil { - return repository.OidcCode{}, err + return repository.OidcCode{}, mapErr(err) } return oidcCodeToRepo(r), nil } @@ -40,7 +58,7 @@ func (s *Store) CreateOidcCode(ctx context.Context, arg repository.CreateOidcCod func (s *Store) CreateOidcToken(ctx context.Context, arg repository.CreateOidcTokenParams) (repository.OidcToken, error) { r, err := s.q.CreateOidcToken(ctx, CreateOidcTokenParams(arg)) if err != nil { - return repository.OidcToken{}, err + return repository.OidcToken{}, mapErr(err) } return oidcTokenToRepo(r), nil } @@ -48,7 +66,7 @@ func (s *Store) CreateOidcToken(ctx context.Context, arg repository.CreateOidcTo func (s *Store) CreateOidcUserInfo(ctx context.Context, arg repository.CreateOidcUserInfoParams) (repository.OidcUserinfo, error) { r, err := s.q.CreateOidcUserInfo(ctx, CreateOidcUserInfoParams(arg)) if err != nil { - return repository.OidcUserinfo{}, err + return repository.OidcUserinfo{}, mapErr(err) } return oidcUserinfoToRepo(r), nil } @@ -56,7 +74,7 @@ func (s *Store) CreateOidcUserInfo(ctx context.Context, arg repository.CreateOid func (s *Store) CreateSession(ctx context.Context, arg repository.CreateSessionParams) (repository.Session, error) { r, err := s.q.CreateSession(ctx, CreateSessionParams(arg)) if err != nil { - return repository.Session{}, err + return repository.Session{}, mapErr(err) } return sessionToRepo(r), nil } @@ -64,7 +82,7 @@ func (s *Store) CreateSession(ctx context.Context, arg repository.CreateSessionP func (s *Store) DeleteExpiredOidcCodes(ctx context.Context, expiresAt int64) ([]repository.OidcCode, error) { rows, err := s.q.DeleteExpiredOidcCodes(ctx, expiresAt) if err != nil { - return nil, err + return nil, mapErr(err) } out := make([]repository.OidcCode, len(rows)) for i, row := range rows { @@ -76,7 +94,7 @@ func (s *Store) DeleteExpiredOidcCodes(ctx context.Context, expiresAt int64) ([] func (s *Store) DeleteExpiredOidcTokens(ctx context.Context, arg repository.DeleteExpiredOidcTokensParams) ([]repository.OidcToken, error) { rows, err := s.q.DeleteExpiredOidcTokens(ctx, DeleteExpiredOidcTokensParams(arg)) if err != nil { - return nil, err + return nil, mapErr(err) } out := make([]repository.OidcToken, len(rows)) for i, row := range rows { @@ -86,41 +104,41 @@ func (s *Store) DeleteExpiredOidcTokens(ctx context.Context, arg repository.Dele } func (s *Store) DeleteExpiredSessions(ctx context.Context, expiry int64) error { - return s.q.DeleteExpiredSessions(ctx, expiry) + return mapErr(s.q.DeleteExpiredSessions(ctx, expiry)) } func (s *Store) DeleteOidcCode(ctx context.Context, codeHash string) error { - return s.q.DeleteOidcCode(ctx, codeHash) + return mapErr(s.q.DeleteOidcCode(ctx, codeHash)) } func (s *Store) DeleteOidcCodeBySub(ctx context.Context, sub string) error { - return s.q.DeleteOidcCodeBySub(ctx, sub) + return mapErr(s.q.DeleteOidcCodeBySub(ctx, sub)) } func (s *Store) DeleteOidcToken(ctx context.Context, accessTokenHash string) error { - return s.q.DeleteOidcToken(ctx, accessTokenHash) + return mapErr(s.q.DeleteOidcToken(ctx, accessTokenHash)) } func (s *Store) DeleteOidcTokenByCodeHash(ctx context.Context, codeHash string) error { - return s.q.DeleteOidcTokenByCodeHash(ctx, codeHash) + return mapErr(s.q.DeleteOidcTokenByCodeHash(ctx, codeHash)) } func (s *Store) DeleteOidcTokenBySub(ctx context.Context, sub string) error { - return s.q.DeleteOidcTokenBySub(ctx, sub) + return mapErr(s.q.DeleteOidcTokenBySub(ctx, sub)) } func (s *Store) DeleteOidcUserInfo(ctx context.Context, sub string) error { - return s.q.DeleteOidcUserInfo(ctx, sub) + return mapErr(s.q.DeleteOidcUserInfo(ctx, sub)) } func (s *Store) DeleteSession(ctx context.Context, uuid string) error { - return s.q.DeleteSession(ctx, uuid) + return mapErr(s.q.DeleteSession(ctx, uuid)) } func (s *Store) GetOidcCode(ctx context.Context, codeHash string) (repository.OidcCode, error) { r, err := s.q.GetOidcCode(ctx, codeHash) if err != nil { - return repository.OidcCode{}, err + return repository.OidcCode{}, mapErr(err) } return oidcCodeToRepo(r), nil } @@ -128,7 +146,7 @@ func (s *Store) GetOidcCode(ctx context.Context, codeHash string) (repository.Oi func (s *Store) GetOidcCodeBySub(ctx context.Context, sub string) (repository.OidcCode, error) { r, err := s.q.GetOidcCodeBySub(ctx, sub) if err != nil { - return repository.OidcCode{}, err + return repository.OidcCode{}, mapErr(err) } return oidcCodeToRepo(r), nil } @@ -136,7 +154,7 @@ func (s *Store) GetOidcCodeBySub(ctx context.Context, sub string) (repository.Oi func (s *Store) GetOidcCodeBySubUnsafe(ctx context.Context, sub string) (repository.OidcCode, error) { r, err := s.q.GetOidcCodeBySubUnsafe(ctx, sub) if err != nil { - return repository.OidcCode{}, err + return repository.OidcCode{}, mapErr(err) } return oidcCodeToRepo(r), nil } @@ -144,7 +162,7 @@ func (s *Store) GetOidcCodeBySubUnsafe(ctx context.Context, sub string) (reposit func (s *Store) GetOidcCodeUnsafe(ctx context.Context, codeHash string) (repository.OidcCode, error) { r, err := s.q.GetOidcCodeUnsafe(ctx, codeHash) if err != nil { - return repository.OidcCode{}, err + return repository.OidcCode{}, mapErr(err) } return oidcCodeToRepo(r), nil } @@ -152,7 +170,7 @@ func (s *Store) GetOidcCodeUnsafe(ctx context.Context, codeHash string) (reposit func (s *Store) GetOidcToken(ctx context.Context, accessTokenHash string) (repository.OidcToken, error) { r, err := s.q.GetOidcToken(ctx, accessTokenHash) if err != nil { - return repository.OidcToken{}, err + return repository.OidcToken{}, mapErr(err) } return oidcTokenToRepo(r), nil } @@ -160,7 +178,7 @@ func (s *Store) GetOidcToken(ctx context.Context, accessTokenHash string) (repos func (s *Store) GetOidcTokenByRefreshToken(ctx context.Context, refreshTokenHash string) (repository.OidcToken, error) { r, err := s.q.GetOidcTokenByRefreshToken(ctx, refreshTokenHash) if err != nil { - return repository.OidcToken{}, err + return repository.OidcToken{}, mapErr(err) } return oidcTokenToRepo(r), nil } @@ -168,7 +186,7 @@ func (s *Store) GetOidcTokenByRefreshToken(ctx context.Context, refreshTokenHash func (s *Store) GetOidcTokenBySub(ctx context.Context, sub string) (repository.OidcToken, error) { r, err := s.q.GetOidcTokenBySub(ctx, sub) if err != nil { - return repository.OidcToken{}, err + return repository.OidcToken{}, mapErr(err) } return oidcTokenToRepo(r), nil } @@ -176,7 +194,7 @@ func (s *Store) GetOidcTokenBySub(ctx context.Context, sub string) (repository.O func (s *Store) GetOidcUserInfo(ctx context.Context, sub string) (repository.OidcUserinfo, error) { r, err := s.q.GetOidcUserInfo(ctx, sub) if err != nil { - return repository.OidcUserinfo{}, err + return repository.OidcUserinfo{}, mapErr(err) } return oidcUserinfoToRepo(r), nil } @@ -184,7 +202,7 @@ func (s *Store) GetOidcUserInfo(ctx context.Context, sub string) (repository.Oid func (s *Store) GetSession(ctx context.Context, uuid string) (repository.Session, error) { r, err := s.q.GetSession(ctx, uuid) if err != nil { - return repository.Session{}, err + return repository.Session{}, mapErr(err) } return sessionToRepo(r), nil } @@ -192,7 +210,7 @@ func (s *Store) GetSession(ctx context.Context, uuid string) (repository.Session func (s *Store) UpdateOidcTokenByRefreshToken(ctx context.Context, arg repository.UpdateOidcTokenByRefreshTokenParams) (repository.OidcToken, error) { r, err := s.q.UpdateOidcTokenByRefreshToken(ctx, UpdateOidcTokenByRefreshTokenParams(arg)) if err != nil { - return repository.OidcToken{}, err + return repository.OidcToken{}, mapErr(err) } return oidcTokenToRepo(r), nil } @@ -200,7 +218,7 @@ func (s *Store) UpdateOidcTokenByRefreshToken(ctx context.Context, arg repositor func (s *Store) UpdateSession(ctx context.Context, arg repository.UpdateSessionParams) (repository.Session, error) { r, err := s.q.UpdateSession(ctx, UpdateSessionParams(arg)) if err != nil { - return repository.Session{}, err + return repository.Session{}, mapErr(err) } return sessionToRepo(r), nil } diff --git a/internal/repository/store.go b/internal/repository/store.go index 765df6a5..302f2f10 100644 --- a/internal/repository/store.go +++ b/internal/repository/store.go @@ -1,6 +1,12 @@ package repository -import "context" +import ( + "context" + "errors" +) + +// ErrNotFound is returned by Store methods when the requested record does not exist. +var ErrNotFound = errors.New("not found") // Store is the interface that all storage drivers must implement. // The sqlc-generated *Queries struct satisfies this interface for SQLite. diff --git a/internal/service/auth_service.go b/internal/service/auth_service.go index ab343396..5d2ead2f 100644 --- a/internal/service/auth_service.go +++ b/internal/service/auth_service.go @@ -2,7 +2,6 @@ package service import ( "context" - "database/sql" "errors" "fmt" "regexp" @@ -411,7 +410,7 @@ func (auth *AuthService) GetSessionCookie(c *gin.Context) (repository.Session, e session, err := auth.queries.GetSession(c, cookie) if err != nil { - if errors.Is(err, sql.ErrNoRows) { + if errors.Is(err, repository.ErrNotFound) { return repository.Session{}, fmt.Errorf("session not found") } return repository.Session{}, err diff --git a/internal/service/oidc_service.go b/internal/service/oidc_service.go index e5f7ea76..14d94f61 100644 --- a/internal/service/oidc_service.go +++ b/internal/service/oidc_service.go @@ -7,7 +7,6 @@ import ( "crypto/rsa" "crypto/sha256" "crypto/x509" - "database/sql" "encoding/base64" "encoding/json" "encoding/pem" @@ -420,7 +419,7 @@ func (service *OIDCService) GetCodeEntry(c *gin.Context, codeHash string, client oidcCode, err := service.queries.GetOidcCode(c, codeHash) if err != nil { - if errors.Is(err, sql.ErrNoRows) { + if errors.Is(err, repository.ErrNotFound) { return repository.OidcCode{}, ErrCodeNotFound } return repository.OidcCode{}, err @@ -564,7 +563,7 @@ func (service *OIDCService) RefreshAccessToken(c *gin.Context, refreshToken stri entry, err := service.queries.GetOidcTokenByRefreshToken(c, service.Hash(refreshToken)) if err != nil { - if errors.Is(err, sql.ErrNoRows) { + if errors.Is(err, repository.ErrNotFound) { return TokenResponse{}, ErrTokenNotFound } return TokenResponse{}, err @@ -643,7 +642,7 @@ func (service *OIDCService) GetAccessToken(c *gin.Context, tokenHash string) (re entry, err := service.queries.GetOidcToken(c, tokenHash) if err != nil { - if errors.Is(err, sql.ErrNoRows) { + if errors.Is(err, repository.ErrNotFound) { return repository.OidcToken{}, ErrTokenNotFound } return repository.OidcToken{}, err @@ -731,15 +730,15 @@ func (service *OIDCService) Hash(token string) string { func (service *OIDCService) DeleteOldSession(ctx context.Context, sub string) error { err := service.queries.DeleteOidcCodeBySub(ctx, sub) - if err != nil && !errors.Is(err, sql.ErrNoRows) { + if err != nil && !errors.Is(err, repository.ErrNotFound) { return err } err = service.queries.DeleteOidcTokenBySub(ctx, sub) - if err != nil && !errors.Is(err, sql.ErrNoRows) { + if err != nil && !errors.Is(err, repository.ErrNotFound) { return err } err = service.queries.DeleteOidcUserInfo(ctx, sub) - if err != nil && !errors.Is(err, sql.ErrNoRows) { + if err != nil && !errors.Is(err, repository.ErrNotFound) { return err } return nil @@ -784,7 +783,7 @@ func (service *OIDCService) Cleanup() { token, err := service.queries.GetOidcTokenBySub(ctx, expiredCode.Sub) if err != nil { - if errors.Is(err, sql.ErrNoRows) { + if errors.Is(err, repository.ErrNotFound) { continue } tlog.App.Warn().Err(err).Msg("Failed to get OIDC token by sub")