From ef8bbd8c9fad4903b47c5510582b41f57cd795db Mon Sep 17 00:00:00 2001 From: Scott McKendry Date: Mon, 4 May 2026 05:02:27 +1200 Subject: [PATCH] feat(db): add `memory` storage driver removes the sqlite dependency for tests, also brings back the option for users to run zero persistence instances of tinyauth. adds new mapErr fn for sqlc wrapper gen to prevent sql errors from leaking out of the store implementation. --- cmd/gen/sqlc-wrapper/main.go | 24 +- internal/bootstrap/db_bootstrap.go | 17 +- internal/controller/oidc_controller_test.go | 7 +- internal/controller/proxy_controller_test.go | 9 +- internal/controller/user_controller_test.go | 12 +- .../controller/well_known_controller_test.go | 7 +- .../middleware/context_middleware_test.go | 26 +- internal/model/config.go | 6 +- internal/repository/memory/oidc_queries.go | 241 ++++++++++++++++++ internal/repository/memory/session_queries.go | 63 +++++ internal/repository/memory/store.go | 27 ++ internal/repository/sqlite/store.go | 68 +++-- internal/repository/store.go | 8 +- internal/service/auth_service.go | 3 +- internal/service/oidc_service.go | 15 +- 15 files changed, 443 insertions(+), 90 deletions(-) create mode 100644 internal/repository/memory/oidc_queries.go create mode 100644 internal/repository/memory/session_queries.go create mode 100644 internal/repository/memory/store.go diff --git a/cmd/gen/sqlc-wrapper/main.go b/cmd/gen/sqlc-wrapper/main.go index e66ae8ee..d6cb6318 100644 --- a/cmd/gen/sqlc-wrapper/main.go +++ b/cmd/gen/sqlc-wrapper/main.go @@ -449,18 +449,18 @@ func buildBody(m methodInfo) string { // no repo-typed result → direct return if len(m.Results) == 0 || m.Results[0].RepoType == "" { - return "\treturn " + call + "\n" + return "\treturn mapErr(" + call + ")\n" } r := m.Results[0] if r.IsSlice { return fmt.Sprintf( - "\trows, err := %s\n\tif err != nil {\n\t\treturn nil, err\n\t}\n\tout := make([]%s, len(rows))\n\tfor i, row := range rows {\n\t\tout[i] = %s(row)\n\t}\n\treturn out, nil\n", + "\trows, err := %s\n\tif err != nil {\n\t\treturn nil, mapErr(err)\n\t}\n\tout := make([]%s, len(rows))\n\tfor i, row := range rows {\n\t\tout[i] = %s(row)\n\t}\n\treturn out, nil\n", call, r.RepoType, converterFn(r.TypeStr), ) } return fmt.Sprintf( - "\tr, err := %s\n\tif err != nil {\n\t\treturn %s{}, err\n\t}\n\treturn %s(r), nil\n", + "\tr, err := %s\n\tif err != nil {\n\t\treturn %s{}, mapErr(err)\n\t}\n\treturn %s(r), nil\n", call, r.RepoType, converterFn(r.TypeStr), ) } @@ -477,6 +477,8 @@ package {{.PkgName}} import ( "context" + "database/sql" + "errors" "{{.RepoPkg}}" ) @@ -491,6 +493,22 @@ func NewStore(q *Queries) repository.Store { return &Store{q: q} } +var errMap = []struct { + from error + to error +}{ + {sql.ErrNoRows, repository.ErrNotFound}, +} + +func mapErr(err error) error { + for _, e := range errMap { + if errors.Is(err, e.from) { + return e.to + } + } + return err +} + {{range .ModelTypes -}} func {{converterFn .}}(v {{.}}) repository.{{.}} { return repository.{{.}}(v) diff --git a/internal/bootstrap/db_bootstrap.go b/internal/bootstrap/db_bootstrap.go index 2279cb23..4f09372a 100644 --- a/internal/bootstrap/db_bootstrap.go +++ b/internal/bootstrap/db_bootstrap.go @@ -8,6 +8,7 @@ import ( "github.com/tinyauthapp/tinyauth/internal/assets" "github.com/tinyauthapp/tinyauth/internal/repository" + "github.com/tinyauthapp/tinyauth/internal/repository/memory" "github.com/tinyauthapp/tinyauth/internal/repository/sqlite" "github.com/golang-migrate/migrate/v4" @@ -17,14 +18,14 @@ import ( ) func (app *BootstrapApp) SetupStore() (repository.Store, error) { - return app.setupSQLite(app.config.Database.Path) -} - -// NewSQLiteStore opens a SQLite database at the given path, runs migrations, and returns a Store. -// Useful for testing or when constructing a store outside of a BootstrapApp. -func NewSQLiteStore(databasePath string) (repository.Store, error) { - app := &BootstrapApp{} - return app.setupSQLite(databasePath) + switch app.config.Database.Driver { + case "memory": + return memory.New(), nil + case "sqlite", "": + return app.setupSQLite(app.config.Database.Path) + default: + return nil, fmt.Errorf("unknown database driver %q: valid values are sqlite, memory", app.config.Database.Driver) + } } func (app *BootstrapApp) setupSQLite(databasePath string) (repository.Store, error) { diff --git a/internal/controller/oidc_controller_test.go b/internal/controller/oidc_controller_test.go index 88690410..4f131ac7 100644 --- a/internal/controller/oidc_controller_test.go +++ b/internal/controller/oidc_controller_test.go @@ -14,9 +14,9 @@ import ( "github.com/google/go-querystring/query" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/tinyauthapp/tinyauth/internal/bootstrap" "github.com/tinyauthapp/tinyauth/internal/controller" "github.com/tinyauthapp/tinyauth/internal/model" + "github.com/tinyauthapp/tinyauth/internal/repository/memory" "github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/utils/tlog" ) @@ -851,11 +851,10 @@ func TestOIDCController(t *testing.T) { }, } - store, err := bootstrap.NewSQLiteStore(path.Join(tempDir, "tinyauth.db")) - require.NoError(t, err) + store := memory.New() oidcService := service.NewOIDCService(oidcServiceCfg, store) - err = oidcService.Init() + err := oidcService.Init() require.NoError(t, err) for _, test := range tests { diff --git a/internal/controller/proxy_controller_test.go b/internal/controller/proxy_controller_test.go index f84d791b..c7876713 100644 --- a/internal/controller/proxy_controller_test.go +++ b/internal/controller/proxy_controller_test.go @@ -2,22 +2,20 @@ package controller_test import ( "net/http/httptest" - "path" "testing" "github.com/gin-gonic/gin" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/tinyauthapp/tinyauth/internal/bootstrap" "github.com/tinyauthapp/tinyauth/internal/controller" "github.com/tinyauthapp/tinyauth/internal/model" + "github.com/tinyauthapp/tinyauth/internal/repository/memory" "github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/utils/tlog" ) func TestProxyController(t *testing.T) { tlog.NewTestLogger().Init() - tempDir := t.TempDir() authServiceCfg := service.AuthServiceConfig{ LocalUsers: &[]model.LocalUser{ @@ -399,11 +397,10 @@ func TestProxyController(t *testing.T) { oauthBrokerCfgs := make(map[string]model.OAuthServiceConfig) - store, err := bootstrap.NewSQLiteStore(path.Join(tempDir, "tinyauth.db")) - require.NoError(t, err) + store := memory.New() docker := service.NewDockerService() - err = docker.Init() + err := docker.Init() require.NoError(t, err) ldap := service.NewLdapService(service.LdapServiceConfig{}) diff --git a/internal/controller/user_controller_test.go b/internal/controller/user_controller_test.go index 4184274d..1d2f02e9 100644 --- a/internal/controller/user_controller_test.go +++ b/internal/controller/user_controller_test.go @@ -5,7 +5,6 @@ import ( "encoding/json" "net/http" "net/http/httptest" - "path" "strings" "testing" "time" @@ -14,17 +13,16 @@ import ( "github.com/pquerna/otp/totp" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/tinyauthapp/tinyauth/internal/bootstrap" "github.com/tinyauthapp/tinyauth/internal/controller" "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/repository" + "github.com/tinyauthapp/tinyauth/internal/repository/memory" "github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/utils/tlog" ) func TestUserController(t *testing.T) { tlog.NewTestLogger().Init() - tempDir := t.TempDir() authServiceCfg := service.AuthServiceConfig{ LocalUsers: &[]model.LocalUser{ @@ -117,6 +115,8 @@ func TestUserController(t *testing.T) { run func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) } + store := memory.New() + tests := []testCase{ { description: "Should be able to login with valid credentials", @@ -449,12 +449,8 @@ func TestUserController(t *testing.T) { oauthBrokerCfgs := make(map[string]model.OAuthServiceConfig) - store, err := bootstrap.NewSQLiteStore(path.Join(tempDir, "tinyauth.db")) - require.NoError(t, err) - - docker := service.NewDockerService() - err = docker.Init() + err := docker.Init() require.NoError(t, err) ldap := service.NewLdapService(service.LdapServiceConfig{}) diff --git a/internal/controller/well_known_controller_test.go b/internal/controller/well_known_controller_test.go index 2bb8cfe1..582e4842 100644 --- a/internal/controller/well_known_controller_test.go +++ b/internal/controller/well_known_controller_test.go @@ -10,9 +10,9 @@ import ( "github.com/gin-gonic/gin" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/tinyauthapp/tinyauth/internal/bootstrap" "github.com/tinyauthapp/tinyauth/internal/controller" "github.com/tinyauthapp/tinyauth/internal/model" + "github.com/tinyauthapp/tinyauth/internal/repository/memory" "github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/utils/tlog" ) @@ -100,11 +100,10 @@ func TestWellKnownController(t *testing.T) { }, } - store, err := bootstrap.NewSQLiteStore(path.Join(tempDir, "tinyauth.db")) - require.NoError(t, err) + store := memory.New() oidcService := service.NewOIDCService(oidcServiceCfg, store) - err = oidcService.Init() + err := oidcService.Init() require.NoError(t, err) for _, test := range tests { diff --git a/internal/middleware/context_middleware_test.go b/internal/middleware/context_middleware_test.go index 5dfde3b4..d1dfa99f 100644 --- a/internal/middleware/context_middleware_test.go +++ b/internal/middleware/context_middleware_test.go @@ -5,24 +5,22 @@ import ( "encoding/base64" "net/http" "net/http/httptest" - "path" "testing" "time" "github.com/gin-gonic/gin" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/tinyauthapp/tinyauth/internal/bootstrap" "github.com/tinyauthapp/tinyauth/internal/middleware" "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/repository" + "github.com/tinyauthapp/tinyauth/internal/repository/memory" "github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/utils/tlog" ) func TestContextMiddleware(t *testing.T) { tlog.NewTestLogger().Init() - tempDir := t.TempDir() authServiceCfg := service.AuthServiceConfig{ LocalUsers: &[]model.LocalUser{ @@ -52,7 +50,7 @@ func TestContextMiddleware(t *testing.T) { return "Basic " + base64.StdEncoding.EncodeToString([]byte(username+":"+password)) } - seedSession := func(t *testing.T, queries *repository.Queries, params repository.CreateSessionParams) { + seedSession := func(t *testing.T, queries repository.Store, params repository.CreateSessionParams) { t.Helper() _, err := queries.CreateSession(context.Background(), params) require.NoError(t, err) @@ -60,7 +58,7 @@ func TestContextMiddleware(t *testing.T) { type runArgs struct { do func(req *http.Request) (*model.UserContext, *httptest.ResponseRecorder) - queries *repository.Queries + queries repository.Store } type testCase struct { @@ -272,22 +270,17 @@ func TestContextMiddleware(t *testing.T) { oauthBrokerCfgs := make(map[string]model.OAuthServiceConfig) - app := bootstrap.NewBootstrapApp(model.Config{}) - - db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db")) - require.NoError(t, err) - - queries := repository.New(db) + store := memory.New() ldap := service.NewLdapService(service.LdapServiceConfig{}) - err = ldap.Init() + err := ldap.Init() require.NoError(t, err) broker := service.NewOAuthBrokerService(oauthBrokerCfgs) err = broker.Init() require.NoError(t, err) - authService := service.NewAuthService(authServiceCfg, ldap, queries, broker) + authService := service.NewAuthService(authServiceCfg, ldap, store, broker) err = authService.Init() require.NoError(t, err) @@ -317,12 +310,7 @@ func TestContextMiddleware(t *testing.T) { return captured, recorder } - test.run(t, runArgs{do: do, queries: queries}) + test.run(t, runArgs{do: do, queries: store}) }) } - - t.Cleanup(func() { - err = db.Close() - require.NoError(t, err) - }) } diff --git a/internal/model/config.go b/internal/model/config.go index c2125f92..cffe8734 100644 --- a/internal/model/config.go +++ b/internal/model/config.go @@ -4,7 +4,8 @@ package model func NewDefaultConfiguration() *Config { return &Config{ Database: DatabaseConfig{ - Path: "./tinyauth.db", + Driver: "sqlite", + Path: "./tinyauth.db", }, Analytics: AnalyticsConfig{ Enabled: true, @@ -82,7 +83,8 @@ type Config struct { } type DatabaseConfig struct { - Path string `description:"The path to the SQLite database, including file name." yaml:"path"` + Driver string `description:"The database driver to use. Valid values: sqlite, memory." yaml:"driver"` + Path string `description:"The path to the SQLite database, including file name. Only used when driver is sqlite." yaml:"path"` } type AnalyticsConfig struct { diff --git a/internal/repository/memory/oidc_queries.go b/internal/repository/memory/oidc_queries.go new file mode 100644 index 00000000..80305fc0 --- /dev/null +++ b/internal/repository/memory/oidc_queries.go @@ -0,0 +1,241 @@ +package memory + +import ( + "context" + "fmt" + + "github.com/tinyauthapp/tinyauth/internal/repository" +) + +func (s *Store) CreateOidcCode(_ context.Context, arg repository.CreateOidcCodeParams) (repository.OidcCode, error) { + s.mu.Lock() + defer s.mu.Unlock() + // Enforce sub UNIQUE constraint + for _, c := range s.oidcCodes { + if c.Sub == arg.Sub { + return repository.OidcCode{}, fmt.Errorf("UNIQUE constraint failed: oidc_codes.sub") + } + } + code := repository.OidcCode(arg) + s.oidcCodes[arg.CodeHash] = code + return code, nil +} + +// GetOidcCode is a destructive read: it deletes and returns the code (mirrors SQLite's DELETE...RETURNING). +func (s *Store) GetOidcCode(_ context.Context, codeHash string) (repository.OidcCode, error) { + s.mu.Lock() + defer s.mu.Unlock() + c, ok := s.oidcCodes[codeHash] + if !ok { + return repository.OidcCode{}, repository.ErrNotFound + } + delete(s.oidcCodes, codeHash) + return c, nil +} + +// GetOidcCodeBySub is a destructive read: it deletes and returns the code (mirrors SQLite's DELETE...RETURNING). +func (s *Store) GetOidcCodeBySub(_ context.Context, sub string) (repository.OidcCode, error) { + s.mu.Lock() + defer s.mu.Unlock() + for k, c := range s.oidcCodes { + if c.Sub == sub { + delete(s.oidcCodes, k) + return c, nil + } + } + return repository.OidcCode{}, repository.ErrNotFound +} + +// GetOidcCodeUnsafe is a non-destructive read (mirrors SQLite's SELECT). +func (s *Store) GetOidcCodeUnsafe(_ context.Context, codeHash string) (repository.OidcCode, error) { + s.mu.RLock() + defer s.mu.RUnlock() + c, ok := s.oidcCodes[codeHash] + if !ok { + return repository.OidcCode{}, repository.ErrNotFound + } + return c, nil +} + +// GetOidcCodeBySubUnsafe is a non-destructive read (mirrors SQLite's SELECT). +func (s *Store) GetOidcCodeBySubUnsafe(_ context.Context, sub string) (repository.OidcCode, error) { + s.mu.RLock() + defer s.mu.RUnlock() + for _, c := range s.oidcCodes { + if c.Sub == sub { + return c, nil + } + } + return repository.OidcCode{}, repository.ErrNotFound +} + +func (s *Store) DeleteOidcCode(_ context.Context, codeHash string) error { + s.mu.Lock() + defer s.mu.Unlock() + delete(s.oidcCodes, codeHash) + return nil +} + +func (s *Store) DeleteOidcCodeBySub(_ context.Context, sub string) error { + s.mu.Lock() + defer s.mu.Unlock() + for k, c := range s.oidcCodes { + if c.Sub == sub { + delete(s.oidcCodes, k) + } + } + return nil +} + +func (s *Store) DeleteExpiredOidcCodes(_ context.Context, expiresAt int64) ([]repository.OidcCode, error) { + s.mu.Lock() + defer s.mu.Unlock() + var deleted []repository.OidcCode + for k, c := range s.oidcCodes { + if c.ExpiresAt < expiresAt { + deleted = append(deleted, c) + delete(s.oidcCodes, k) + } + } + return deleted, nil +} + +func (s *Store) CreateOidcToken(_ context.Context, arg repository.CreateOidcTokenParams) (repository.OidcToken, error) { + s.mu.Lock() + defer s.mu.Unlock() + // Enforce sub UNIQUE constraint + for _, t := range s.oidcTokens { + if t.Sub == arg.Sub { + return repository.OidcToken{}, fmt.Errorf("UNIQUE constraint failed: oidc_tokens.sub") + } + } + tok := repository.OidcToken{ + Sub: arg.Sub, + AccessTokenHash: arg.AccessTokenHash, + RefreshTokenHash: arg.RefreshTokenHash, + CodeHash: arg.CodeHash, + Scope: arg.Scope, + ClientID: arg.ClientID, + TokenExpiresAt: arg.TokenExpiresAt, + RefreshTokenExpiresAt: arg.RefreshTokenExpiresAt, + Nonce: arg.Nonce, + } + s.oidcTokens[arg.AccessTokenHash] = tok + return tok, nil +} + +func (s *Store) GetOidcToken(_ context.Context, accessTokenHash string) (repository.OidcToken, error) { + s.mu.RLock() + defer s.mu.RUnlock() + t, ok := s.oidcTokens[accessTokenHash] + if !ok { + return repository.OidcToken{}, repository.ErrNotFound + } + return t, nil +} + +func (s *Store) GetOidcTokenByRefreshToken(_ context.Context, refreshTokenHash string) (repository.OidcToken, error) { + s.mu.RLock() + defer s.mu.RUnlock() + for _, t := range s.oidcTokens { + if t.RefreshTokenHash == refreshTokenHash { + return t, nil + } + } + return repository.OidcToken{}, repository.ErrNotFound +} + +func (s *Store) GetOidcTokenBySub(_ context.Context, sub string) (repository.OidcToken, error) { + s.mu.RLock() + defer s.mu.RUnlock() + for _, t := range s.oidcTokens { + if t.Sub == sub { + return t, nil + } + } + return repository.OidcToken{}, repository.ErrNotFound +} + +func (s *Store) UpdateOidcTokenByRefreshToken(_ context.Context, arg repository.UpdateOidcTokenByRefreshTokenParams) (repository.OidcToken, error) { + s.mu.Lock() + defer s.mu.Unlock() + for k, t := range s.oidcTokens { + if t.RefreshTokenHash == arg.RefreshTokenHash_2 { + delete(s.oidcTokens, k) + t.AccessTokenHash = arg.AccessTokenHash + t.RefreshTokenHash = arg.RefreshTokenHash + t.TokenExpiresAt = arg.TokenExpiresAt + t.RefreshTokenExpiresAt = arg.RefreshTokenExpiresAt + s.oidcTokens[arg.AccessTokenHash] = t + return t, nil + } + } + return repository.OidcToken{}, repository.ErrNotFound +} + +func (s *Store) DeleteOidcToken(_ context.Context, accessTokenHash string) error { + s.mu.Lock() + defer s.mu.Unlock() + delete(s.oidcTokens, accessTokenHash) + return nil +} + +func (s *Store) DeleteOidcTokenBySub(_ context.Context, sub string) error { + s.mu.Lock() + defer s.mu.Unlock() + for k, t := range s.oidcTokens { + if t.Sub == sub { + delete(s.oidcTokens, k) + } + } + return nil +} + +func (s *Store) DeleteOidcTokenByCodeHash(_ context.Context, codeHash string) error { + s.mu.Lock() + defer s.mu.Unlock() + for k, t := range s.oidcTokens { + if t.CodeHash == codeHash { + delete(s.oidcTokens, k) + } + } + return nil +} + +func (s *Store) DeleteExpiredOidcTokens(_ context.Context, arg repository.DeleteExpiredOidcTokensParams) ([]repository.OidcToken, error) { + s.mu.Lock() + defer s.mu.Unlock() + var deleted []repository.OidcToken + for k, t := range s.oidcTokens { + if t.TokenExpiresAt < arg.TokenExpiresAt || t.RefreshTokenExpiresAt < arg.RefreshTokenExpiresAt { + deleted = append(deleted, t) + delete(s.oidcTokens, k) + } + } + return deleted, nil +} + +func (s *Store) CreateOidcUserInfo(_ context.Context, arg repository.CreateOidcUserInfoParams) (repository.OidcUserinfo, error) { + s.mu.Lock() + defer s.mu.Unlock() + u := repository.OidcUserinfo(arg) + s.oidcUsers[arg.Sub] = u + return u, nil +} + +func (s *Store) GetOidcUserInfo(_ context.Context, sub string) (repository.OidcUserinfo, error) { + s.mu.RLock() + defer s.mu.RUnlock() + u, ok := s.oidcUsers[sub] + if !ok { + return repository.OidcUserinfo{}, repository.ErrNotFound + } + return u, nil +} + +func (s *Store) DeleteOidcUserInfo(_ context.Context, sub string) error { + s.mu.Lock() + defer s.mu.Unlock() + delete(s.oidcUsers, sub) + return nil +} diff --git a/internal/repository/memory/session_queries.go b/internal/repository/memory/session_queries.go new file mode 100644 index 00000000..2edde6b1 --- /dev/null +++ b/internal/repository/memory/session_queries.go @@ -0,0 +1,63 @@ +package memory + +import ( + "context" + + "github.com/tinyauthapp/tinyauth/internal/repository" +) + +func (s *Store) CreateSession(_ context.Context, arg repository.CreateSessionParams) (repository.Session, error) { + s.mu.Lock() + defer s.mu.Unlock() + sess := repository.Session(arg) + s.sessions[arg.UUID] = sess + return sess, nil +} + +func (s *Store) GetSession(_ context.Context, uuid string) (repository.Session, error) { + s.mu.RLock() + defer s.mu.RUnlock() + sess, ok := s.sessions[uuid] + if !ok { + return repository.Session{}, repository.ErrNotFound + } + return sess, nil +} + +func (s *Store) UpdateSession(_ context.Context, arg repository.UpdateSessionParams) (repository.Session, error) { + s.mu.Lock() + defer s.mu.Unlock() + sess, ok := s.sessions[arg.UUID] + if !ok { + return repository.Session{}, repository.ErrNotFound + } + sess.Username = arg.Username + sess.Email = arg.Email + sess.Name = arg.Name + sess.Provider = arg.Provider + sess.TotpPending = arg.TotpPending + sess.OAuthGroups = arg.OAuthGroups + sess.Expiry = arg.Expiry + sess.OAuthName = arg.OAuthName + sess.OAuthSub = arg.OAuthSub + s.sessions[arg.UUID] = sess + return sess, nil +} + +func (s *Store) DeleteSession(_ context.Context, uuid string) error { + s.mu.Lock() + defer s.mu.Unlock() + delete(s.sessions, uuid) + return nil +} + +func (s *Store) DeleteExpiredSessions(_ context.Context, expiry int64) error { + s.mu.Lock() + defer s.mu.Unlock() + for k, v := range s.sessions { + if v.Expiry < expiry { + delete(s.sessions, k) + } + } + return nil +} diff --git a/internal/repository/memory/store.go b/internal/repository/memory/store.go new file mode 100644 index 00000000..969cba66 --- /dev/null +++ b/internal/repository/memory/store.go @@ -0,0 +1,27 @@ +// Package memory provides an in-memory implementation of repository.Store for use in tests. +package memory + +import ( + "sync" + + "github.com/tinyauthapp/tinyauth/internal/repository" +) + +// Store is a thread-safe in-memory implementation of repository.Store. +type Store struct { + mu sync.RWMutex + sessions map[string]repository.Session + oidcCodes map[string]repository.OidcCode + oidcTokens map[string]repository.OidcToken + oidcUsers map[string]repository.OidcUserinfo +} + +// New returns a new empty in-memory Store. +func New() repository.Store { + return &Store{ + sessions: make(map[string]repository.Session), + oidcCodes: make(map[string]repository.OidcCode), + oidcTokens: make(map[string]repository.OidcToken), + oidcUsers: make(map[string]repository.OidcUserinfo), + } +} diff --git a/internal/repository/sqlite/store.go b/internal/repository/sqlite/store.go index 65b4e190..f316efa4 100644 --- a/internal/repository/sqlite/store.go +++ b/internal/repository/sqlite/store.go @@ -3,6 +3,8 @@ package sqlite import ( "context" + "database/sql" + "errors" "github.com/tinyauthapp/tinyauth/internal/repository" ) @@ -17,6 +19,22 @@ func NewStore(q *Queries) repository.Store { return &Store{q: q} } +var errMap = []struct { + from error + to error +}{ + {sql.ErrNoRows, repository.ErrNotFound}, +} + +func mapErr(err error) error { + for _, e := range errMap { + if errors.Is(err, e.from) { + return e.to + } + } + return err +} + func oidcCodeToRepo(v OidcCode) repository.OidcCode { return repository.OidcCode(v) } @@ -32,7 +50,7 @@ func sessionToRepo(v Session) repository.Session { func (s *Store) CreateOidcCode(ctx context.Context, arg repository.CreateOidcCodeParams) (repository.OidcCode, error) { r, err := s.q.CreateOidcCode(ctx, CreateOidcCodeParams(arg)) if err != nil { - return repository.OidcCode{}, err + return repository.OidcCode{}, mapErr(err) } return oidcCodeToRepo(r), nil } @@ -40,7 +58,7 @@ func (s *Store) CreateOidcCode(ctx context.Context, arg repository.CreateOidcCod func (s *Store) CreateOidcToken(ctx context.Context, arg repository.CreateOidcTokenParams) (repository.OidcToken, error) { r, err := s.q.CreateOidcToken(ctx, CreateOidcTokenParams(arg)) if err != nil { - return repository.OidcToken{}, err + return repository.OidcToken{}, mapErr(err) } return oidcTokenToRepo(r), nil } @@ -48,7 +66,7 @@ func (s *Store) CreateOidcToken(ctx context.Context, arg repository.CreateOidcTo func (s *Store) CreateOidcUserInfo(ctx context.Context, arg repository.CreateOidcUserInfoParams) (repository.OidcUserinfo, error) { r, err := s.q.CreateOidcUserInfo(ctx, CreateOidcUserInfoParams(arg)) if err != nil { - return repository.OidcUserinfo{}, err + return repository.OidcUserinfo{}, mapErr(err) } return oidcUserinfoToRepo(r), nil } @@ -56,7 +74,7 @@ func (s *Store) CreateOidcUserInfo(ctx context.Context, arg repository.CreateOid func (s *Store) CreateSession(ctx context.Context, arg repository.CreateSessionParams) (repository.Session, error) { r, err := s.q.CreateSession(ctx, CreateSessionParams(arg)) if err != nil { - return repository.Session{}, err + return repository.Session{}, mapErr(err) } return sessionToRepo(r), nil } @@ -64,7 +82,7 @@ func (s *Store) CreateSession(ctx context.Context, arg repository.CreateSessionP func (s *Store) DeleteExpiredOidcCodes(ctx context.Context, expiresAt int64) ([]repository.OidcCode, error) { rows, err := s.q.DeleteExpiredOidcCodes(ctx, expiresAt) if err != nil { - return nil, err + return nil, mapErr(err) } out := make([]repository.OidcCode, len(rows)) for i, row := range rows { @@ -76,7 +94,7 @@ func (s *Store) DeleteExpiredOidcCodes(ctx context.Context, expiresAt int64) ([] func (s *Store) DeleteExpiredOidcTokens(ctx context.Context, arg repository.DeleteExpiredOidcTokensParams) ([]repository.OidcToken, error) { rows, err := s.q.DeleteExpiredOidcTokens(ctx, DeleteExpiredOidcTokensParams(arg)) if err != nil { - return nil, err + return nil, mapErr(err) } out := make([]repository.OidcToken, len(rows)) for i, row := range rows { @@ -86,41 +104,41 @@ func (s *Store) DeleteExpiredOidcTokens(ctx context.Context, arg repository.Dele } func (s *Store) DeleteExpiredSessions(ctx context.Context, expiry int64) error { - return s.q.DeleteExpiredSessions(ctx, expiry) + return mapErr(s.q.DeleteExpiredSessions(ctx, expiry)) } func (s *Store) DeleteOidcCode(ctx context.Context, codeHash string) error { - return s.q.DeleteOidcCode(ctx, codeHash) + return mapErr(s.q.DeleteOidcCode(ctx, codeHash)) } func (s *Store) DeleteOidcCodeBySub(ctx context.Context, sub string) error { - return s.q.DeleteOidcCodeBySub(ctx, sub) + return mapErr(s.q.DeleteOidcCodeBySub(ctx, sub)) } func (s *Store) DeleteOidcToken(ctx context.Context, accessTokenHash string) error { - return s.q.DeleteOidcToken(ctx, accessTokenHash) + return mapErr(s.q.DeleteOidcToken(ctx, accessTokenHash)) } func (s *Store) DeleteOidcTokenByCodeHash(ctx context.Context, codeHash string) error { - return s.q.DeleteOidcTokenByCodeHash(ctx, codeHash) + return mapErr(s.q.DeleteOidcTokenByCodeHash(ctx, codeHash)) } func (s *Store) DeleteOidcTokenBySub(ctx context.Context, sub string) error { - return s.q.DeleteOidcTokenBySub(ctx, sub) + return mapErr(s.q.DeleteOidcTokenBySub(ctx, sub)) } func (s *Store) DeleteOidcUserInfo(ctx context.Context, sub string) error { - return s.q.DeleteOidcUserInfo(ctx, sub) + return mapErr(s.q.DeleteOidcUserInfo(ctx, sub)) } func (s *Store) DeleteSession(ctx context.Context, uuid string) error { - return s.q.DeleteSession(ctx, uuid) + return mapErr(s.q.DeleteSession(ctx, uuid)) } func (s *Store) GetOidcCode(ctx context.Context, codeHash string) (repository.OidcCode, error) { r, err := s.q.GetOidcCode(ctx, codeHash) if err != nil { - return repository.OidcCode{}, err + return repository.OidcCode{}, mapErr(err) } return oidcCodeToRepo(r), nil } @@ -128,7 +146,7 @@ func (s *Store) GetOidcCode(ctx context.Context, codeHash string) (repository.Oi func (s *Store) GetOidcCodeBySub(ctx context.Context, sub string) (repository.OidcCode, error) { r, err := s.q.GetOidcCodeBySub(ctx, sub) if err != nil { - return repository.OidcCode{}, err + return repository.OidcCode{}, mapErr(err) } return oidcCodeToRepo(r), nil } @@ -136,7 +154,7 @@ func (s *Store) GetOidcCodeBySub(ctx context.Context, sub string) (repository.Oi func (s *Store) GetOidcCodeBySubUnsafe(ctx context.Context, sub string) (repository.OidcCode, error) { r, err := s.q.GetOidcCodeBySubUnsafe(ctx, sub) if err != nil { - return repository.OidcCode{}, err + return repository.OidcCode{}, mapErr(err) } return oidcCodeToRepo(r), nil } @@ -144,7 +162,7 @@ func (s *Store) GetOidcCodeBySubUnsafe(ctx context.Context, sub string) (reposit func (s *Store) GetOidcCodeUnsafe(ctx context.Context, codeHash string) (repository.OidcCode, error) { r, err := s.q.GetOidcCodeUnsafe(ctx, codeHash) if err != nil { - return repository.OidcCode{}, err + return repository.OidcCode{}, mapErr(err) } return oidcCodeToRepo(r), nil } @@ -152,7 +170,7 @@ func (s *Store) GetOidcCodeUnsafe(ctx context.Context, codeHash string) (reposit func (s *Store) GetOidcToken(ctx context.Context, accessTokenHash string) (repository.OidcToken, error) { r, err := s.q.GetOidcToken(ctx, accessTokenHash) if err != nil { - return repository.OidcToken{}, err + return repository.OidcToken{}, mapErr(err) } return oidcTokenToRepo(r), nil } @@ -160,7 +178,7 @@ func (s *Store) GetOidcToken(ctx context.Context, accessTokenHash string) (repos func (s *Store) GetOidcTokenByRefreshToken(ctx context.Context, refreshTokenHash string) (repository.OidcToken, error) { r, err := s.q.GetOidcTokenByRefreshToken(ctx, refreshTokenHash) if err != nil { - return repository.OidcToken{}, err + return repository.OidcToken{}, mapErr(err) } return oidcTokenToRepo(r), nil } @@ -168,7 +186,7 @@ func (s *Store) GetOidcTokenByRefreshToken(ctx context.Context, refreshTokenHash func (s *Store) GetOidcTokenBySub(ctx context.Context, sub string) (repository.OidcToken, error) { r, err := s.q.GetOidcTokenBySub(ctx, sub) if err != nil { - return repository.OidcToken{}, err + return repository.OidcToken{}, mapErr(err) } return oidcTokenToRepo(r), nil } @@ -176,7 +194,7 @@ func (s *Store) GetOidcTokenBySub(ctx context.Context, sub string) (repository.O func (s *Store) GetOidcUserInfo(ctx context.Context, sub string) (repository.OidcUserinfo, error) { r, err := s.q.GetOidcUserInfo(ctx, sub) if err != nil { - return repository.OidcUserinfo{}, err + return repository.OidcUserinfo{}, mapErr(err) } return oidcUserinfoToRepo(r), nil } @@ -184,7 +202,7 @@ func (s *Store) GetOidcUserInfo(ctx context.Context, sub string) (repository.Oid func (s *Store) GetSession(ctx context.Context, uuid string) (repository.Session, error) { r, err := s.q.GetSession(ctx, uuid) if err != nil { - return repository.Session{}, err + return repository.Session{}, mapErr(err) } return sessionToRepo(r), nil } @@ -192,7 +210,7 @@ func (s *Store) GetSession(ctx context.Context, uuid string) (repository.Session func (s *Store) UpdateOidcTokenByRefreshToken(ctx context.Context, arg repository.UpdateOidcTokenByRefreshTokenParams) (repository.OidcToken, error) { r, err := s.q.UpdateOidcTokenByRefreshToken(ctx, UpdateOidcTokenByRefreshTokenParams(arg)) if err != nil { - return repository.OidcToken{}, err + return repository.OidcToken{}, mapErr(err) } return oidcTokenToRepo(r), nil } @@ -200,7 +218,7 @@ func (s *Store) UpdateOidcTokenByRefreshToken(ctx context.Context, arg repositor func (s *Store) UpdateSession(ctx context.Context, arg repository.UpdateSessionParams) (repository.Session, error) { r, err := s.q.UpdateSession(ctx, UpdateSessionParams(arg)) if err != nil { - return repository.Session{}, err + return repository.Session{}, mapErr(err) } return sessionToRepo(r), nil } diff --git a/internal/repository/store.go b/internal/repository/store.go index 765df6a5..302f2f10 100644 --- a/internal/repository/store.go +++ b/internal/repository/store.go @@ -1,6 +1,12 @@ package repository -import "context" +import ( + "context" + "errors" +) + +// ErrNotFound is returned by Store methods when the requested record does not exist. +var ErrNotFound = errors.New("not found") // Store is the interface that all storage drivers must implement. // The sqlc-generated *Queries struct satisfies this interface for SQLite. diff --git a/internal/service/auth_service.go b/internal/service/auth_service.go index 29f491f1..c6a12944 100644 --- a/internal/service/auth_service.go +++ b/internal/service/auth_service.go @@ -2,7 +2,6 @@ package service import ( "context" - "database/sql" "errors" "fmt" "net/http" @@ -421,7 +420,7 @@ func (auth *AuthService) GetSession(ctx context.Context, uuid string) (*reposito session, err := auth.queries.GetSession(ctx, uuid) if err != nil { - if errors.Is(err, sql.ErrNoRows) { + if errors.Is(err, repository.ErrNotFound) { return nil, errors.New("session not found") } return nil, err diff --git a/internal/service/oidc_service.go b/internal/service/oidc_service.go index d6b11628..65b5ccda 100644 --- a/internal/service/oidc_service.go +++ b/internal/service/oidc_service.go @@ -7,7 +7,6 @@ import ( "crypto/rsa" "crypto/sha256" "crypto/x509" - "database/sql" "encoding/base64" "encoding/json" "encoding/pem" @@ -422,7 +421,7 @@ func (service *OIDCService) GetCodeEntry(c *gin.Context, codeHash string, client oidcCode, err := service.queries.GetOidcCode(c, codeHash) if err != nil { - if errors.Is(err, sql.ErrNoRows) { + if errors.Is(err, repository.ErrNotFound) { return repository.OidcCode{}, ErrCodeNotFound } return repository.OidcCode{}, err @@ -566,7 +565,7 @@ func (service *OIDCService) RefreshAccessToken(c *gin.Context, refreshToken stri entry, err := service.queries.GetOidcTokenByRefreshToken(c, service.Hash(refreshToken)) if err != nil { - if errors.Is(err, sql.ErrNoRows) { + if errors.Is(err, repository.ErrNotFound) { return TokenResponse{}, ErrTokenNotFound } return TokenResponse{}, err @@ -645,7 +644,7 @@ func (service *OIDCService) GetAccessToken(c *gin.Context, tokenHash string) (re entry, err := service.queries.GetOidcToken(c, tokenHash) if err != nil { - if errors.Is(err, sql.ErrNoRows) { + if errors.Is(err, repository.ErrNotFound) { return repository.OidcToken{}, ErrTokenNotFound } return repository.OidcToken{}, err @@ -733,15 +732,15 @@ func (service *OIDCService) Hash(token string) string { func (service *OIDCService) DeleteOldSession(ctx context.Context, sub string) error { err := service.queries.DeleteOidcCodeBySub(ctx, sub) - if err != nil && !errors.Is(err, sql.ErrNoRows) { + if err != nil && !errors.Is(err, repository.ErrNotFound) { return err } err = service.queries.DeleteOidcTokenBySub(ctx, sub) - if err != nil && !errors.Is(err, sql.ErrNoRows) { + if err != nil && !errors.Is(err, repository.ErrNotFound) { return err } err = service.queries.DeleteOidcUserInfo(ctx, sub) - if err != nil && !errors.Is(err, sql.ErrNoRows) { + if err != nil && !errors.Is(err, repository.ErrNotFound) { return err } return nil @@ -786,7 +785,7 @@ func (service *OIDCService) Cleanup() { token, err := service.queries.GetOidcTokenBySub(ctx, expiredCode.Sub) if err != nil { - if errors.Is(err, sql.ErrNoRows) { + if errors.Is(err, repository.ErrNotFound) { continue } tlog.App.Warn().Err(err).Msg("Failed to get OIDC token by sub")