refactor(db): use new store interface

This commit is contained in:
Scott McKendry
2026-04-30 18:16:50 +12:00
parent a6351790c3
commit ad6751df2a
39 changed files with 174 additions and 139 deletions
+1 -1
View File
@@ -11,5 +11,5 @@ var FrontendAssets embed.FS
// Migrations
//
//go:embed migrations/*.sql
//go:embed migrations/sqlite/*.sql
var Migrations embed.FS
+4 -5
View File
@@ -43,7 +43,7 @@ type BootstrapApp struct {
log *logger.Logger
ctx context.Context
cancel context.CancelFunc
queries *repository.Queries
queries repository.Store
router *gin.Engine
db *sql.DB
wg sync.WaitGroup
@@ -162,7 +162,7 @@ func (app *BootstrapApp) Setup() error {
app.runtime.OAuthSessionCookieName = fmt.Sprintf("%s-%s", model.OAuthSessionCookieName, cookieId)
// database
err = app.SetupDatabase()
store, err := app.SetupStore()
if err != nil {
return fmt.Errorf("failed to setup database: %w", err)
@@ -176,9 +176,8 @@ func (app *BootstrapApp) Setup() error {
app.db.Close()
}()
// queries
queries := repository.New(app.db)
app.queries = queries
// store
app.queries = store
// services
err = app.setupServices()
+24 -14
View File
@@ -7,6 +7,8 @@ import (
"path/filepath"
"github.com/tinyauthapp/tinyauth/internal/assets"
"github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/repository/sqlite"
"github.com/golang-migrate/migrate/v4"
"github.com/golang-migrate/migrate/v4/database/sqlite3"
@@ -14,17 +16,28 @@ import (
_ "modernc.org/sqlite"
)
func (app *BootstrapApp) SetupDatabase() error {
dir := filepath.Dir(app.config.Database.Path)
func (app *BootstrapApp) SetupStore() (repository.Store, error) {
return app.setupSQLite(app.config.Database.Path)
}
// NewSQLiteStore opens a SQLite database at the given path, runs migrations, and returns a Store.
// Useful for testing or when constructing a store outside of a BootstrapApp.
func NewSQLiteStore(databasePath string) (repository.Store, error) {
app := &BootstrapApp{}
return app.setupSQLite(databasePath)
}
func (app *BootstrapApp) setupSQLite(databasePath string) (repository.Store, error) {
dir := filepath.Dir(databasePath)
if err := os.MkdirAll(dir, 0750); err != nil {
return fmt.Errorf("failed to create database directory %s: %w", dir, err)
return nil, fmt.Errorf("failed to create database directory %s: %w", dir, err)
}
db, err := sql.Open("sqlite", app.config.Database.Path)
db, err := sql.Open("sqlite", databasePath)
if err != nil {
return fmt.Errorf("failed to open database: %w", err)
return nil, fmt.Errorf("failed to open database: %w", err)
}
// Close the database if there is an error during migration
@@ -38,32 +51,29 @@ func (app *BootstrapApp) SetupDatabase() error {
// if the sqlite connection starts being a bottleneck
db.SetMaxOpenConns(1)
migrations, err := iofs.New(assets.Migrations, "migrations")
migrations, err := iofs.New(assets.Migrations, "migrations/sqlite")
if err != nil {
return fmt.Errorf("failed to create migrations: %w", err)
return nil, fmt.Errorf("failed to create migrations: %w", err)
}
target, err := sqlite3.WithInstance(db, &sqlite3.Config{})
if err != nil {
return fmt.Errorf("failed to create sqlite3 instance: %w", err)
return nil, fmt.Errorf("failed to create sqlite3 instance: %w", err)
}
migrator, err := migrate.NewWithInstance("iofs", migrations, "sqlite3", target)
if err != nil {
return fmt.Errorf("failed to create migrator: %w", err)
return nil, fmt.Errorf("failed to create migrator: %w", err)
}
if err := migrator.Up(); err != nil && err != migrate.ErrNoChange {
return fmt.Errorf("failed to migrate database: %w", err)
return nil, fmt.Errorf("failed to migrate database: %w", err)
}
app.db = db
return nil
}
func (app *BootstrapApp) GetDB() *sql.DB {
return app.db
return sqlite.New(db), nil
}
+2 -11
View File
@@ -18,7 +18,6 @@ import (
"github.com/tinyauthapp/tinyauth/internal/bootstrap"
"github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/test"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
@@ -839,16 +838,12 @@ func TestOIDCController(t *testing.T) {
},
}
app := bootstrap.NewBootstrapApp(cfg)
err := app.SetupDatabase()
store, err := bootstrap.NewSQLiteStore(cfg.Database.Path)
require.NoError(t, err)
queries := repository.New(app.GetDB())
wg := &sync.WaitGroup{}
oidcService, err := service.NewOIDCService(log, cfg, runtime, queries, context.TODO(), wg)
oidcService, err := service.NewOIDCService(log, cfg, runtime, store, context.TODO(), wg)
require.NoError(t, err)
for _, test := range tests {
@@ -869,8 +864,4 @@ func TestOIDCController(t *testing.T) {
test.run(t, router, recorder)
})
}
t.Cleanup(func() {
app.GetDB().Close()
})
}
+2 -11
View File
@@ -12,7 +12,6 @@ import (
"github.com/tinyauthapp/tinyauth/internal/bootstrap"
"github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/test"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
@@ -379,18 +378,14 @@ func TestProxyController(t *testing.T) {
},
}
app := bootstrap.NewBootstrapApp(cfg)
err := app.SetupDatabase()
store, err := bootstrap.NewSQLiteStore(cfg.Database.Path)
require.NoError(t, err)
queries := repository.New(app.GetDB())
wg := &sync.WaitGroup{}
ctx := context.TODO()
broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx)
authService := service.NewAuthService(log, cfg, runtime, ctx, wg, nil, queries, broker)
authService := service.NewAuthService(log, cfg, runtime, ctx, wg, nil, store, broker)
aclsService := service.NewAccessControlsService(log, nil, acls)
for _, test := range tests {
@@ -411,8 +406,4 @@ func TestProxyController(t *testing.T) {
test.run(t, router, recorder)
})
}
t.Cleanup(func() {
app.GetDB().Close()
})
}
+5 -12
View File
@@ -73,13 +73,9 @@ func TestUserController(t *testing.T) {
})
}
app := bootstrap.NewBootstrapApp(cfg)
err := app.SetupDatabase()
store, err := bootstrap.NewSQLiteStore(cfg.Database.Path)
require.NoError(t, err)
queries := repository.New(app.GetDB())
type testCase struct {
description string
middlewares []gin.HandlerFunc
@@ -254,7 +250,7 @@ func TestUserController(t *testing.T) {
totpCtx,
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
_, err := queries.CreateSession(context.TODO(), repository.CreateSessionParams{
_, err := store.CreateSession(context.TODO(), repository.CreateSessionParams{
UUID: "test-totp-login-uuid",
Username: "test",
Email: "test@example.com",
@@ -378,7 +374,7 @@ func TestUserController(t *testing.T) {
totpAttrCtx,
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
_, err := queries.CreateSession(context.TODO(), repository.CreateSessionParams{
_, err := store.CreateSession(context.TODO(), repository.CreateSessionParams{
UUID: "test-totp-login-attributes-uuid",
Username: "test",
Email: "test@example.com",
@@ -420,7 +416,7 @@ func TestUserController(t *testing.T) {
wg := &sync.WaitGroup{}
broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx)
authService := service.NewAuthService(log, cfg, runtime, ctx, wg, nil, queries, broker)
authService := service.NewAuthService(log, cfg, runtime, ctx, wg, nil, store, broker)
beforeEach := func() {
// Clear failed login attempts before each test
@@ -446,8 +442,5 @@ func TestUserController(t *testing.T) {
test.run(t, router, recorder)
})
}
t.Cleanup(func() {
app.GetDB().Close()
})
}
}
@@ -13,7 +13,6 @@ import (
"github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/bootstrap"
"github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/test"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
@@ -92,14 +91,10 @@ func TestWellKnownController(t *testing.T) {
ctx := context.TODO()
wg := &sync.WaitGroup{}
app := bootstrap.NewBootstrapApp(cfg)
err := app.SetupDatabase()
store, err := bootstrap.NewSQLiteStore(cfg.Database.Path)
require.NoError(t, err)
queries := repository.New(app.GetDB())
oidcService, err := service.NewOIDCService(log, cfg, runtime, queries, ctx, wg)
oidcService, err := service.NewOIDCService(log, cfg, runtime, store, ctx, wg)
require.NoError(t, err)
for _, test := range tests {
@@ -114,8 +109,4 @@ func TestWellKnownController(t *testing.T) {
test.run(t, router, recorder)
})
}
t.Cleanup(func() {
app.GetDB().Close()
})
}
+1 -1
View File
@@ -83,7 +83,7 @@ type Config struct {
}
type DatabaseConfig struct {
Path string `description:"The path to the database, including file name." yaml:"path"`
Path string `description:"The path to the SQLite database, including file name." yaml:"path"`
}
type AnalyticsConfig struct {
+14 -59
View File
@@ -1,64 +1,19 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.30.0
package repository
type OidcCode struct {
Sub string
CodeHash string
Scope string
RedirectURI string
ClientID string
ExpiresAt int64
Nonce string
CodeChallenge string
}
// This file is a stop-gap until more drivers are added. It re-exports the models from the sqlite package so that the rest
// of the codebase can import them from a single location without needing to know about the underlying database implementation.
type OidcToken struct {
Sub string
AccessTokenHash string
RefreshTokenHash string
CodeHash string
Scope string
ClientID string
TokenExpiresAt int64
RefreshTokenExpiresAt int64
Nonce string
}
import "github.com/tinyauthapp/tinyauth/internal/repository/sqlite"
type OidcUserinfo struct {
Sub string
Name string
PreferredUsername string
Email string
Groups string
UpdatedAt int64
GivenName string
FamilyName string
MiddleName string
Nickname string
Profile string
Picture string
Website string
Gender string
Birthdate string
Zoneinfo string
Locale string
PhoneNumber string
Address string
}
type Session = sqlite.Session
type OidcCode = sqlite.OidcCode
type OidcToken = sqlite.OidcToken
type OidcUserinfo = sqlite.OidcUserinfo
type Session struct {
UUID string
Username string
Email string
Name string
Provider string
TotpPending bool
OAuthGroups string
Expiry int64
CreatedAt int64
OAuthName string
OAuthSub string
}
type CreateSessionParams = sqlite.CreateSessionParams
type UpdateSessionParams = sqlite.UpdateSessionParams
type CreateOidcCodeParams = sqlite.CreateOidcCodeParams
type CreateOidcTokenParams = sqlite.CreateOidcTokenParams
type UpdateOidcTokenByRefreshTokenParams = sqlite.UpdateOidcTokenByRefreshTokenParams
type DeleteExpiredOidcTokensParams = sqlite.DeleteExpiredOidcTokensParams
type CreateOidcUserInfoParams = sqlite.CreateOidcUserInfoParams
@@ -1,8 +1,8 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.30.0
// sqlc v1.31.0
package repository
package sqlite
import (
"context"
+64
View File
@@ -0,0 +1,64 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.31.0
package sqlite
type OidcCode struct {
Sub string
CodeHash string
Scope string
RedirectURI string
ClientID string
ExpiresAt int64
Nonce string
CodeChallenge string
}
type OidcToken struct {
Sub string
AccessTokenHash string
RefreshTokenHash string
CodeHash string
Scope string
ClientID string
TokenExpiresAt int64
RefreshTokenExpiresAt int64
Nonce string
}
type OidcUserinfo struct {
Sub string
Name string
PreferredUsername string
Email string
Groups string
UpdatedAt int64
GivenName string
FamilyName string
MiddleName string
Nickname string
Profile string
Picture string
Website string
Gender string
Birthdate string
Zoneinfo string
Locale string
PhoneNumber string
Address string
}
type Session struct {
UUID string
Username string
Email string
Name string
Provider string
TotpPending bool
OAuthGroups string
Expiry int64
CreatedAt int64
OAuthName string
OAuthSub string
}
@@ -1,9 +1,9 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.30.0
// sqlc v1.31.0
// source: oidc_queries.sql
package repository
package sqlite
import (
"context"
@@ -1,9 +1,9 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.30.0
// sqlc v1.31.0
// source: session_queries.sql
package repository
package sqlite
import (
"context"
+41
View File
@@ -0,0 +1,41 @@
package repository
import "context"
// Store is the interface that all storage drivers must implement.
// The sqlc-generated *Queries struct satisfies this interface for SQLite.
// Future drivers (postgres, etc.) must return the shared types defined in this package.
type Store interface {
// Sessions
CreateSession(ctx context.Context, arg CreateSessionParams) (Session, error)
GetSession(ctx context.Context, uuid string) (Session, error)
UpdateSession(ctx context.Context, arg UpdateSessionParams) (Session, error)
DeleteSession(ctx context.Context, uuid string) error
DeleteExpiredSessions(ctx context.Context, expiry int64) error
// OIDC codes
CreateOidcCode(ctx context.Context, arg CreateOidcCodeParams) (OidcCode, error)
GetOidcCode(ctx context.Context, codeHash string) (OidcCode, error)
GetOidcCodeBySub(ctx context.Context, sub string) (OidcCode, error)
GetOidcCodeUnsafe(ctx context.Context, codeHash string) (OidcCode, error)
GetOidcCodeBySubUnsafe(ctx context.Context, sub string) (OidcCode, error)
DeleteOidcCode(ctx context.Context, codeHash string) error
DeleteOidcCodeBySub(ctx context.Context, sub string) error
DeleteExpiredOidcCodes(ctx context.Context, expiresAt int64) ([]OidcCode, error)
// OIDC tokens
CreateOidcToken(ctx context.Context, arg CreateOidcTokenParams) (OidcToken, error)
GetOidcToken(ctx context.Context, accessTokenHash string) (OidcToken, error)
GetOidcTokenByRefreshToken(ctx context.Context, refreshTokenHash string) (OidcToken, error)
GetOidcTokenBySub(ctx context.Context, sub string) (OidcToken, error)
UpdateOidcTokenByRefreshToken(ctx context.Context, arg UpdateOidcTokenByRefreshTokenParams) (OidcToken, error)
DeleteOidcToken(ctx context.Context, accessTokenHash string) error
DeleteOidcTokenBySub(ctx context.Context, sub string) error
DeleteOidcTokenByCodeHash(ctx context.Context, codeHash string) error
DeleteExpiredOidcTokens(ctx context.Context, arg DeleteExpiredOidcTokensParams) ([]OidcToken, error)
// OIDC userinfo
CreateOidcUserInfo(ctx context.Context, arg CreateOidcUserInfoParams) (OidcUserinfo, error)
GetOidcUserInfo(ctx context.Context, sub string) (OidcUserinfo, error)
DeleteOidcUserInfo(ctx context.Context, sub string) error
}
+2 -2
View File
@@ -79,7 +79,7 @@ type AuthService struct {
context context.Context
ldap *LdapService
queries *repository.Queries
queries repository.Store
oauthBroker *OAuthBrokerService
loginAttempts map[string]*LoginAttempt
@@ -100,7 +100,7 @@ func NewAuthService(
ctx context.Context,
wg *sync.WaitGroup,
ldap *LdapService,
queries *repository.Queries,
queries repository.Store,
oauthBroker *OAuthBrokerService,
) *AuthService {
service := &AuthService{
+2 -2
View File
@@ -116,7 +116,7 @@ type OIDCService struct {
log *logger.Logger
config model.Config
runtime model.RuntimeConfig
queries *repository.Queries
queries repository.Store
context context.Context
clients map[string]model.OIDCClientConfig
@@ -129,7 +129,7 @@ func NewOIDCService(
log *logger.Logger,
config model.Config,
runtime model.RuntimeConfig,
queries *repository.Queries,
queries repository.Store,
ctx context.Context,
wg *sync.WaitGroup) (*OIDCService, error) {
// If not configured, skip init
+4 -4
View File
@@ -1,12 +1,12 @@
version: "2"
sql:
- engine: "sqlite"
queries: "sql/*_queries.sql"
schema: "sql/*_schemas.sql"
queries: "sql/sqlite/*_queries.sql"
schema: "sql/sqlite/*_schemas.sql"
gen:
go:
package: "repository"
out: "internal/repository"
package: "sqlite"
out: "internal/repository/sqlite"
rename:
uuid: "UUID"
oauth_groups: "OAuthGroups"