refactor(db): use new store interface

This commit is contained in:
Scott McKendry
2026-04-30 18:16:50 +12:00
parent e6b291d21c
commit 06071e1f54
39 changed files with 174 additions and 139 deletions
+1 -1
View File
@@ -11,5 +11,5 @@ var FrontendAssets embed.FS
// Migrations // Migrations
// //
//go:embed migrations/*.sql //go:embed migrations/sqlite/*.sql
var Migrations embed.FS var Migrations embed.FS
+4 -5
View File
@@ -43,7 +43,7 @@ type BootstrapApp struct {
log *logger.Logger log *logger.Logger
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
queries *repository.Queries queries repository.Store
router *gin.Engine router *gin.Engine
db *sql.DB db *sql.DB
wg sync.WaitGroup wg sync.WaitGroup
@@ -162,7 +162,7 @@ func (app *BootstrapApp) Setup() error {
app.runtime.OAuthSessionCookieName = fmt.Sprintf("%s-%s", model.OAuthSessionCookieName, cookieId) app.runtime.OAuthSessionCookieName = fmt.Sprintf("%s-%s", model.OAuthSessionCookieName, cookieId)
// database // database
err = app.SetupDatabase() store, err := app.SetupStore()
if err != nil { if err != nil {
return fmt.Errorf("failed to setup database: %w", err) return fmt.Errorf("failed to setup database: %w", err)
@@ -176,9 +176,8 @@ func (app *BootstrapApp) Setup() error {
app.db.Close() app.db.Close()
}() }()
// queries // store
queries := repository.New(app.db) app.queries = store
app.queries = queries
// services // services
err = app.setupServices() err = app.setupServices()
+24 -14
View File
@@ -7,6 +7,8 @@ import (
"path/filepath" "path/filepath"
"github.com/tinyauthapp/tinyauth/internal/assets" "github.com/tinyauthapp/tinyauth/internal/assets"
"github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/repository/sqlite"
"github.com/golang-migrate/migrate/v4" "github.com/golang-migrate/migrate/v4"
"github.com/golang-migrate/migrate/v4/database/sqlite3" "github.com/golang-migrate/migrate/v4/database/sqlite3"
@@ -14,17 +16,28 @@ import (
_ "modernc.org/sqlite" _ "modernc.org/sqlite"
) )
func (app *BootstrapApp) SetupDatabase() error { func (app *BootstrapApp) SetupStore() (repository.Store, error) {
dir := filepath.Dir(app.config.Database.Path) return app.setupSQLite(app.config.Database.Path)
}
// NewSQLiteStore opens a SQLite database at the given path, runs migrations, and returns a Store.
// Useful for testing or when constructing a store outside of a BootstrapApp.
func NewSQLiteStore(databasePath string) (repository.Store, error) {
app := &BootstrapApp{}
return app.setupSQLite(databasePath)
}
func (app *BootstrapApp) setupSQLite(databasePath string) (repository.Store, error) {
dir := filepath.Dir(databasePath)
if err := os.MkdirAll(dir, 0750); err != nil { if err := os.MkdirAll(dir, 0750); err != nil {
return fmt.Errorf("failed to create database directory %s: %w", dir, err) return nil, fmt.Errorf("failed to create database directory %s: %w", dir, err)
} }
db, err := sql.Open("sqlite", app.config.Database.Path) db, err := sql.Open("sqlite", databasePath)
if err != nil { if err != nil {
return fmt.Errorf("failed to open database: %w", err) return nil, fmt.Errorf("failed to open database: %w", err)
} }
// Close the database if there is an error during migration // Close the database if there is an error during migration
@@ -38,32 +51,29 @@ func (app *BootstrapApp) SetupDatabase() error {
// if the sqlite connection starts being a bottleneck // if the sqlite connection starts being a bottleneck
db.SetMaxOpenConns(1) db.SetMaxOpenConns(1)
migrations, err := iofs.New(assets.Migrations, "migrations") migrations, err := iofs.New(assets.Migrations, "migrations/sqlite")
if err != nil { if err != nil {
return fmt.Errorf("failed to create migrations: %w", err) return nil, fmt.Errorf("failed to create migrations: %w", err)
} }
target, err := sqlite3.WithInstance(db, &sqlite3.Config{}) target, err := sqlite3.WithInstance(db, &sqlite3.Config{})
if err != nil { if err != nil {
return fmt.Errorf("failed to create sqlite3 instance: %w", err) return nil, fmt.Errorf("failed to create sqlite3 instance: %w", err)
} }
migrator, err := migrate.NewWithInstance("iofs", migrations, "sqlite3", target) migrator, err := migrate.NewWithInstance("iofs", migrations, "sqlite3", target)
if err != nil { if err != nil {
return fmt.Errorf("failed to create migrator: %w", err) return nil, fmt.Errorf("failed to create migrator: %w", err)
} }
if err := migrator.Up(); err != nil && err != migrate.ErrNoChange { if err := migrator.Up(); err != nil && err != migrate.ErrNoChange {
return fmt.Errorf("failed to migrate database: %w", err) return nil, fmt.Errorf("failed to migrate database: %w", err)
} }
app.db = db app.db = db
return nil
}
func (app *BootstrapApp) GetDB() *sql.DB { return sqlite.New(db), nil
return app.db
} }
+2 -11
View File
@@ -18,7 +18,6 @@ import (
"github.com/tinyauthapp/tinyauth/internal/bootstrap" "github.com/tinyauthapp/tinyauth/internal/bootstrap"
"github.com/tinyauthapp/tinyauth/internal/controller" "github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/test" "github.com/tinyauthapp/tinyauth/internal/test"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
@@ -839,16 +838,12 @@ func TestOIDCController(t *testing.T) {
}, },
} }
app := bootstrap.NewBootstrapApp(cfg) store, err := bootstrap.NewSQLiteStore(cfg.Database.Path)
err := app.SetupDatabase()
require.NoError(t, err) require.NoError(t, err)
queries := repository.New(app.GetDB())
wg := &sync.WaitGroup{} wg := &sync.WaitGroup{}
oidcService, err := service.NewOIDCService(log, cfg, runtime, queries, context.TODO(), wg) oidcService, err := service.NewOIDCService(log, cfg, runtime, store, context.TODO(), wg)
require.NoError(t, err) require.NoError(t, err)
for _, test := range tests { for _, test := range tests {
@@ -869,8 +864,4 @@ func TestOIDCController(t *testing.T) {
test.run(t, router, recorder) test.run(t, router, recorder)
}) })
} }
t.Cleanup(func() {
app.GetDB().Close()
})
} }
+2 -11
View File
@@ -12,7 +12,6 @@ import (
"github.com/tinyauthapp/tinyauth/internal/bootstrap" "github.com/tinyauthapp/tinyauth/internal/bootstrap"
"github.com/tinyauthapp/tinyauth/internal/controller" "github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/test" "github.com/tinyauthapp/tinyauth/internal/test"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
@@ -379,18 +378,14 @@ func TestProxyController(t *testing.T) {
}, },
} }
app := bootstrap.NewBootstrapApp(cfg) store, err := bootstrap.NewSQLiteStore(cfg.Database.Path)
err := app.SetupDatabase()
require.NoError(t, err) require.NoError(t, err)
queries := repository.New(app.GetDB())
wg := &sync.WaitGroup{} wg := &sync.WaitGroup{}
ctx := context.TODO() ctx := context.TODO()
broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx) broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx)
authService := service.NewAuthService(log, cfg, runtime, ctx, wg, nil, queries, broker) authService := service.NewAuthService(log, cfg, runtime, ctx, wg, nil, store, broker)
aclsService := service.NewAccessControlsService(log, nil, acls) aclsService := service.NewAccessControlsService(log, nil, acls)
for _, test := range tests { for _, test := range tests {
@@ -411,8 +406,4 @@ func TestProxyController(t *testing.T) {
test.run(t, router, recorder) test.run(t, router, recorder)
}) })
} }
t.Cleanup(func() {
app.GetDB().Close()
})
} }
+5 -12
View File
@@ -73,13 +73,9 @@ func TestUserController(t *testing.T) {
}) })
} }
app := bootstrap.NewBootstrapApp(cfg) store, err := bootstrap.NewSQLiteStore(cfg.Database.Path)
err := app.SetupDatabase()
require.NoError(t, err) require.NoError(t, err)
queries := repository.New(app.GetDB())
type testCase struct { type testCase struct {
description string description string
middlewares []gin.HandlerFunc middlewares []gin.HandlerFunc
@@ -254,7 +250,7 @@ func TestUserController(t *testing.T) {
totpCtx, totpCtx,
}, },
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
_, err := queries.CreateSession(context.TODO(), repository.CreateSessionParams{ _, err := store.CreateSession(context.TODO(), repository.CreateSessionParams{
UUID: "test-totp-login-uuid", UUID: "test-totp-login-uuid",
Username: "test", Username: "test",
Email: "test@example.com", Email: "test@example.com",
@@ -378,7 +374,7 @@ func TestUserController(t *testing.T) {
totpAttrCtx, totpAttrCtx,
}, },
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
_, err := queries.CreateSession(context.TODO(), repository.CreateSessionParams{ _, err := store.CreateSession(context.TODO(), repository.CreateSessionParams{
UUID: "test-totp-login-attributes-uuid", UUID: "test-totp-login-attributes-uuid",
Username: "test", Username: "test",
Email: "test@example.com", Email: "test@example.com",
@@ -420,7 +416,7 @@ func TestUserController(t *testing.T) {
wg := &sync.WaitGroup{} wg := &sync.WaitGroup{}
broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx) broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx)
authService := service.NewAuthService(log, cfg, runtime, ctx, wg, nil, queries, broker) authService := service.NewAuthService(log, cfg, runtime, ctx, wg, nil, store, broker)
beforeEach := func() { beforeEach := func() {
// Clear failed login attempts before each test // Clear failed login attempts before each test
@@ -446,8 +442,5 @@ func TestUserController(t *testing.T) {
test.run(t, router, recorder) test.run(t, router, recorder)
}) })
} }
}
t.Cleanup(func() {
app.GetDB().Close()
})
} }
@@ -13,7 +13,6 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/bootstrap" "github.com/tinyauthapp/tinyauth/internal/bootstrap"
"github.com/tinyauthapp/tinyauth/internal/controller" "github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/test" "github.com/tinyauthapp/tinyauth/internal/test"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
@@ -92,14 +91,10 @@ func TestWellKnownController(t *testing.T) {
ctx := context.TODO() ctx := context.TODO()
wg := &sync.WaitGroup{} wg := &sync.WaitGroup{}
app := bootstrap.NewBootstrapApp(cfg) store, err := bootstrap.NewSQLiteStore(cfg.Database.Path)
err := app.SetupDatabase()
require.NoError(t, err) require.NoError(t, err)
queries := repository.New(app.GetDB()) oidcService, err := service.NewOIDCService(log, cfg, runtime, store, ctx, wg)
oidcService, err := service.NewOIDCService(log, cfg, runtime, queries, ctx, wg)
require.NoError(t, err) require.NoError(t, err)
for _, test := range tests { for _, test := range tests {
@@ -114,8 +109,4 @@ func TestWellKnownController(t *testing.T) {
test.run(t, router, recorder) test.run(t, router, recorder)
}) })
} }
t.Cleanup(func() {
app.GetDB().Close()
})
} }
+1 -1
View File
@@ -83,7 +83,7 @@ type Config struct {
} }
type DatabaseConfig struct { type DatabaseConfig struct {
Path string `description:"The path to the database, including file name." yaml:"path"` Path string `description:"The path to the SQLite database, including file name." yaml:"path"`
} }
type AnalyticsConfig struct { type AnalyticsConfig struct {
+14 -59
View File
@@ -1,64 +1,19 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.30.0
package repository package repository
type OidcCode struct { // This file is a stop-gap until more drivers are added. It re-exports the models from the sqlite package so that the rest
Sub string // of the codebase can import them from a single location without needing to know about the underlying database implementation.
CodeHash string
Scope string
RedirectURI string
ClientID string
ExpiresAt int64
Nonce string
CodeChallenge string
}
type OidcToken struct { import "github.com/tinyauthapp/tinyauth/internal/repository/sqlite"
Sub string
AccessTokenHash string
RefreshTokenHash string
CodeHash string
Scope string
ClientID string
TokenExpiresAt int64
RefreshTokenExpiresAt int64
Nonce string
}
type OidcUserinfo struct { type Session = sqlite.Session
Sub string type OidcCode = sqlite.OidcCode
Name string type OidcToken = sqlite.OidcToken
PreferredUsername string type OidcUserinfo = sqlite.OidcUserinfo
Email string
Groups string
UpdatedAt int64
GivenName string
FamilyName string
MiddleName string
Nickname string
Profile string
Picture string
Website string
Gender string
Birthdate string
Zoneinfo string
Locale string
PhoneNumber string
Address string
}
type Session struct { type CreateSessionParams = sqlite.CreateSessionParams
UUID string type UpdateSessionParams = sqlite.UpdateSessionParams
Username string type CreateOidcCodeParams = sqlite.CreateOidcCodeParams
Email string type CreateOidcTokenParams = sqlite.CreateOidcTokenParams
Name string type UpdateOidcTokenByRefreshTokenParams = sqlite.UpdateOidcTokenByRefreshTokenParams
Provider string type DeleteExpiredOidcTokensParams = sqlite.DeleteExpiredOidcTokensParams
TotpPending bool type CreateOidcUserInfoParams = sqlite.CreateOidcUserInfoParams
OAuthGroups string
Expiry int64
CreatedAt int64
OAuthName string
OAuthSub string
}
@@ -1,8 +1,8 @@
// Code generated by sqlc. DO NOT EDIT. // Code generated by sqlc. DO NOT EDIT.
// versions: // versions:
// sqlc v1.30.0 // sqlc v1.31.0
package repository package sqlite
import ( import (
"context" "context"
+64
View File
@@ -0,0 +1,64 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.31.0
package sqlite
type OidcCode struct {
Sub string
CodeHash string
Scope string
RedirectURI string
ClientID string
ExpiresAt int64
Nonce string
CodeChallenge string
}
type OidcToken struct {
Sub string
AccessTokenHash string
RefreshTokenHash string
CodeHash string
Scope string
ClientID string
TokenExpiresAt int64
RefreshTokenExpiresAt int64
Nonce string
}
type OidcUserinfo struct {
Sub string
Name string
PreferredUsername string
Email string
Groups string
UpdatedAt int64
GivenName string
FamilyName string
MiddleName string
Nickname string
Profile string
Picture string
Website string
Gender string
Birthdate string
Zoneinfo string
Locale string
PhoneNumber string
Address string
}
type Session struct {
UUID string
Username string
Email string
Name string
Provider string
TotpPending bool
OAuthGroups string
Expiry int64
CreatedAt int64
OAuthName string
OAuthSub string
}
@@ -1,9 +1,9 @@
// Code generated by sqlc. DO NOT EDIT. // Code generated by sqlc. DO NOT EDIT.
// versions: // versions:
// sqlc v1.30.0 // sqlc v1.31.0
// source: oidc_queries.sql // source: oidc_queries.sql
package repository package sqlite
import ( import (
"context" "context"
@@ -1,9 +1,9 @@
// Code generated by sqlc. DO NOT EDIT. // Code generated by sqlc. DO NOT EDIT.
// versions: // versions:
// sqlc v1.30.0 // sqlc v1.31.0
// source: session_queries.sql // source: session_queries.sql
package repository package sqlite
import ( import (
"context" "context"
+41
View File
@@ -0,0 +1,41 @@
package repository
import "context"
// Store is the interface that all storage drivers must implement.
// The sqlc-generated *Queries struct satisfies this interface for SQLite.
// Future drivers (postgres, etc.) must return the shared types defined in this package.
type Store interface {
// Sessions
CreateSession(ctx context.Context, arg CreateSessionParams) (Session, error)
GetSession(ctx context.Context, uuid string) (Session, error)
UpdateSession(ctx context.Context, arg UpdateSessionParams) (Session, error)
DeleteSession(ctx context.Context, uuid string) error
DeleteExpiredSessions(ctx context.Context, expiry int64) error
// OIDC codes
CreateOidcCode(ctx context.Context, arg CreateOidcCodeParams) (OidcCode, error)
GetOidcCode(ctx context.Context, codeHash string) (OidcCode, error)
GetOidcCodeBySub(ctx context.Context, sub string) (OidcCode, error)
GetOidcCodeUnsafe(ctx context.Context, codeHash string) (OidcCode, error)
GetOidcCodeBySubUnsafe(ctx context.Context, sub string) (OidcCode, error)
DeleteOidcCode(ctx context.Context, codeHash string) error
DeleteOidcCodeBySub(ctx context.Context, sub string) error
DeleteExpiredOidcCodes(ctx context.Context, expiresAt int64) ([]OidcCode, error)
// OIDC tokens
CreateOidcToken(ctx context.Context, arg CreateOidcTokenParams) (OidcToken, error)
GetOidcToken(ctx context.Context, accessTokenHash string) (OidcToken, error)
GetOidcTokenByRefreshToken(ctx context.Context, refreshTokenHash string) (OidcToken, error)
GetOidcTokenBySub(ctx context.Context, sub string) (OidcToken, error)
UpdateOidcTokenByRefreshToken(ctx context.Context, arg UpdateOidcTokenByRefreshTokenParams) (OidcToken, error)
DeleteOidcToken(ctx context.Context, accessTokenHash string) error
DeleteOidcTokenBySub(ctx context.Context, sub string) error
DeleteOidcTokenByCodeHash(ctx context.Context, codeHash string) error
DeleteExpiredOidcTokens(ctx context.Context, arg DeleteExpiredOidcTokensParams) ([]OidcToken, error)
// OIDC userinfo
CreateOidcUserInfo(ctx context.Context, arg CreateOidcUserInfoParams) (OidcUserinfo, error)
GetOidcUserInfo(ctx context.Context, sub string) (OidcUserinfo, error)
DeleteOidcUserInfo(ctx context.Context, sub string) error
}
+2 -2
View File
@@ -79,7 +79,7 @@ type AuthService struct {
context context.Context context context.Context
ldap *LdapService ldap *LdapService
queries *repository.Queries queries repository.Store
oauthBroker *OAuthBrokerService oauthBroker *OAuthBrokerService
loginAttempts map[string]*LoginAttempt loginAttempts map[string]*LoginAttempt
@@ -100,7 +100,7 @@ func NewAuthService(
ctx context.Context, ctx context.Context,
wg *sync.WaitGroup, wg *sync.WaitGroup,
ldap *LdapService, ldap *LdapService,
queries *repository.Queries, queries repository.Store,
oauthBroker *OAuthBrokerService, oauthBroker *OAuthBrokerService,
) *AuthService { ) *AuthService {
service := &AuthService{ service := &AuthService{
+2 -2
View File
@@ -116,7 +116,7 @@ type OIDCService struct {
log *logger.Logger log *logger.Logger
config model.Config config model.Config
runtime model.RuntimeConfig runtime model.RuntimeConfig
queries *repository.Queries queries repository.Store
context context.Context context context.Context
clients map[string]model.OIDCClientConfig clients map[string]model.OIDCClientConfig
@@ -129,7 +129,7 @@ func NewOIDCService(
log *logger.Logger, log *logger.Logger,
config model.Config, config model.Config,
runtime model.RuntimeConfig, runtime model.RuntimeConfig,
queries *repository.Queries, queries repository.Store,
ctx context.Context, ctx context.Context,
wg *sync.WaitGroup) (*OIDCService, error) { wg *sync.WaitGroup) (*OIDCService, error) {
// If not configured, skip init // If not configured, skip init
+4 -4
View File
@@ -1,12 +1,12 @@
version: "2" version: "2"
sql: sql:
- engine: "sqlite" - engine: "sqlite"
queries: "sql/*_queries.sql" queries: "sql/sqlite/*_queries.sql"
schema: "sql/*_schemas.sql" schema: "sql/sqlite/*_schemas.sql"
gen: gen:
go: go:
package: "repository" package: "sqlite"
out: "internal/repository" out: "internal/repository/sqlite"
rename: rename:
uuid: "UUID" uuid: "UUID"
oauth_groups: "OAuthGroups" oauth_groups: "OAuthGroups"