refactor: replace gorm with vanilla sql and sqlc (#541)

* refactor: replace gorm with vanilla sql and sqlc

* chore: go mod tidy

* refactor: rebase for main

* tests: fix tests

* fix: review comments
This commit is contained in:
Stavros
2025-12-31 17:59:21 +02:00
committed by GitHub
parent 2dc047d9b7
commit 7e17a4ad86
16 changed files with 416 additions and 177 deletions

View File

@@ -14,11 +14,10 @@ import (
"github.com/steveiliop56/tinyauth/internal/config"
"github.com/steveiliop56/tinyauth/internal/controller"
"github.com/steveiliop56/tinyauth/internal/model"
"github.com/steveiliop56/tinyauth/internal/repository"
"github.com/steveiliop56/tinyauth/internal/utils"
"github.com/rs/zerolog/log"
"gorm.io/gorm"
)
type BootstrapApp struct {
@@ -108,8 +107,18 @@ func (app *BootstrapApp) Setup() error {
log.Trace().Str("csrfCookieName", app.context.csrfCookieName).Msg("CSRF cookie name")
log.Trace().Str("redirectCookieName", app.context.redirectCookieName).Msg("Redirect cookie name")
// Database
db, err := app.SetupDatabase(app.config.DatabasePath)
if err != nil {
return fmt.Errorf("failed to setup database: %w", err)
}
// Queries
queries := repository.New(db)
// Services
services, err := app.initServices()
services, err := app.initServices(queries)
if err != nil {
return fmt.Errorf("failed to initialize services: %w", err)
@@ -155,9 +164,9 @@ func (app *BootstrapApp) Setup() error {
return fmt.Errorf("failed to setup routes: %w", err)
}
// Start DB cleanup routine
// Start db cleanup routine
log.Debug().Msg("Starting database cleanup routine")
go app.dbCleanup(services.databaseService.GetDatabase())
go app.dbCleanup(queries)
// If analytics are not disabled, start heartbeat
if !app.config.DisableAnalytics {
@@ -247,16 +256,16 @@ func (app *BootstrapApp) heartbeat() {
}
}
func (app *BootstrapApp) dbCleanup(db *gorm.DB) {
func (app *BootstrapApp) dbCleanup(queries *repository.Queries) {
ticker := time.NewTicker(time.Duration(30) * time.Minute)
defer ticker.Stop()
ctx := context.Background()
for ; true; <-ticker.C {
log.Debug().Msg("Cleaning up old database sessions")
_, err := gorm.G[model.Session](db).Where("expiry < ?", time.Now().Unix()).Delete(ctx)
err := queries.DeleteExpiredSessions(ctx, time.Now().Unix())
if err != nil {
log.Error().Err(err).Msg("Failed to cleanup old sessions")
log.Error().Err(err).Msg("Failed to clean up old database sessions")
}
}
}

View File

@@ -0,0 +1,53 @@
package bootstrap
import (
"database/sql"
"fmt"
"os"
"path/filepath"
"github.com/steveiliop56/tinyauth/internal/assets"
"github.com/golang-migrate/migrate/v4"
"github.com/golang-migrate/migrate/v4/database/sqlite3"
"github.com/golang-migrate/migrate/v4/source/iofs"
_ "modernc.org/sqlite"
)
func (app *BootstrapApp) SetupDatabase(databasePath string) (*sql.DB, error) {
dir := filepath.Dir(databasePath)
if err := os.MkdirAll(dir, 0750); err != nil {
return nil, fmt.Errorf("failed to create database directory %s: %w", dir, err)
}
db, err := sql.Open("sqlite", databasePath)
if err != nil {
return nil, fmt.Errorf("failed to open database: %w", err)
}
migrations, err := iofs.New(assets.Migrations, "migrations")
if err != nil {
return nil, fmt.Errorf("failed to create migrations: %w", err)
}
target, err := sqlite3.WithInstance(db, &sqlite3.Config{})
if err != nil {
return nil, fmt.Errorf("failed to create sqlite3 instance: %w", err)
}
migrator, err := migrate.NewWithInstance("iofs", migrations, "sqlite3", target)
if err != nil {
return nil, fmt.Errorf("failed to create migrator: %w", err)
}
if err := migrator.Up(); err != nil && err != migrate.ErrNoChange {
return nil, fmt.Errorf("failed to migrate database: %w", err)
}
return db, nil
}

View File

@@ -1,6 +1,7 @@
package bootstrap
import (
"github.com/steveiliop56/tinyauth/internal/repository"
"github.com/steveiliop56/tinyauth/internal/service"
"github.com/rs/zerolog/log"
@@ -9,27 +10,14 @@ import (
type Services struct {
accessControlService *service.AccessControlsService
authService *service.AuthService
databaseService *service.DatabaseService
dockerService *service.DockerService
ldapService *service.LdapService
oauthBrokerService *service.OAuthBrokerService
}
func (app *BootstrapApp) initServices() (Services, error) {
func (app *BootstrapApp) initServices(queries *repository.Queries) (Services, error) {
services := Services{}
databaseService := service.NewDatabaseService(service.DatabaseServiceConfig{
DatabasePath: app.config.DatabasePath,
})
err := databaseService.Init()
if err != nil {
return Services{}, err
}
services.databaseService = databaseService
ldapService := service.NewLdapService(service.LdapServiceConfig{
Address: app.config.Ldap.Address,
BindDN: app.config.Ldap.BindDN,
@@ -39,7 +27,7 @@ func (app *BootstrapApp) initServices() (Services, error) {
SearchFilter: app.config.Ldap.SearchFilter,
})
err = ldapService.Init()
err := ldapService.Init()
if err == nil {
services.ldapService = ldapService
@@ -76,7 +64,7 @@ func (app *BootstrapApp) initServices() (Services, error) {
LoginTimeout: app.config.Auth.LoginTimeout,
LoginMaxRetries: app.config.Auth.LoginMaxRetries,
SessionCookieName: app.context.sessionCookieName,
}, dockerService, ldapService, databaseService.GetDatabase())
}, dockerService, ldapService, queries)
err = authService.Init()

View File

@@ -4,8 +4,10 @@ import (
"net/http/httptest"
"testing"
"github.com/steveiliop56/tinyauth/internal/bootstrap"
"github.com/steveiliop56/tinyauth/internal/config"
"github.com/steveiliop56/tinyauth/internal/controller"
"github.com/steveiliop56/tinyauth/internal/repository"
"github.com/steveiliop56/tinyauth/internal/service"
"github.com/gin-gonic/gin"
@@ -26,14 +28,16 @@ func setupProxyController(t *testing.T, middlewares *[]gin.HandlerFunc) (*gin.En
group := router.Group("/api")
recorder := httptest.NewRecorder()
// Mock app
app := bootstrap.NewBootstrapApp(config.Config{})
// Database
databaseService := service.NewDatabaseService(service.DatabaseServiceConfig{
DatabasePath: "/tmp/tinyauth_test.db",
})
db, err := app.SetupDatabase(":memory:")
assert.NilError(t, databaseService.Init())
assert.NilError(t, err)
database := databaseService.GetDatabase()
// Queries
queries := repository.New(db)
// Docker
dockerService := service.NewDockerService()
@@ -60,7 +64,7 @@ func setupProxyController(t *testing.T, middlewares *[]gin.HandlerFunc) (*gin.En
LoginTimeout: 300,
LoginMaxRetries: 3,
SessionCookieName: "tinyauth-session",
}, dockerService, nil, database)
}, dockerService, nil, queries)
// Controller
ctrl := controller.NewProxyController(controller.ProxyControllerConfig{

View File

@@ -8,8 +8,10 @@ import (
"testing"
"time"
"github.com/steveiliop56/tinyauth/internal/bootstrap"
"github.com/steveiliop56/tinyauth/internal/config"
"github.com/steveiliop56/tinyauth/internal/controller"
"github.com/steveiliop56/tinyauth/internal/repository"
"github.com/steveiliop56/tinyauth/internal/service"
"github.com/gin-gonic/gin"
@@ -34,14 +36,16 @@ func setupUserController(t *testing.T, middlewares *[]gin.HandlerFunc) (*gin.Eng
group := router.Group("/api")
recorder := httptest.NewRecorder()
// Mock app
app := bootstrap.NewBootstrapApp(config.Config{})
// Database
databaseService := service.NewDatabaseService(service.DatabaseServiceConfig{
DatabasePath: "/tmp/tinyauth_test.db",
})
db, err := app.SetupDatabase(":memory:")
assert.NilError(t, databaseService.Init())
assert.NilError(t, err)
database := databaseService.GetDatabase()
// Queries
queries := repository.New(db)
// Auth service
authService := service.NewAuthService(service.AuthServiceConfig{
@@ -63,7 +67,7 @@ func setupUserController(t *testing.T, middlewares *[]gin.HandlerFunc) (*gin.Eng
LoginTimeout: 300,
LoginMaxRetries: 3,
SessionCookieName: "tinyauth-session",
}, nil, nil, database)
}, nil, nil, queries)
// Controller
ctrl := controller.NewUserController(controller.UserControllerConfig{

View File

@@ -1,14 +0,0 @@
package model
type Session struct {
UUID string `gorm:"column:uuid;primaryKey"`
Username string `gorm:"column:username"`
Email string `gorm:"column:email"`
Name string `gorm:"column:name"`
Provider string `gorm:"column:provider"`
TOTPPending bool `gorm:"column:totp_pending"`
OAuthGroups string `gorm:"column:oauth_groups"`
Expiry int64 `gorm:"column:expiry"`
OAuthName string `gorm:"column:oauth_name"`
OAuthSub string `gorm:"column:oauth_sub"`
}

31
internal/repository/db.go Normal file
View File

@@ -0,0 +1,31 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.30.0
package repository
import (
"context"
"database/sql"
)
type DBTX interface {
ExecContext(context.Context, string, ...interface{}) (sql.Result, error)
PrepareContext(context.Context, string) (*sql.Stmt, error)
QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error)
QueryRowContext(context.Context, string, ...interface{}) *sql.Row
}
func New(db DBTX) *Queries {
return &Queries{db: db}
}
type Queries struct {
db DBTX
}
func (q *Queries) WithTx(tx *sql.Tx) *Queries {
return &Queries{
db: tx,
}
}

View File

@@ -0,0 +1,18 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.30.0
package repository
type Session struct {
UUID string
Username string
Email string
Name string
Provider string
TotpPending bool
OAuthGroups string
Expiry int64
OAuthName string
OAuthSub string
}

View File

@@ -0,0 +1,170 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.30.0
// source: query.sql
package repository
import (
"context"
)
const createSession = `-- name: CreateSession :one
INSERT INTO sessions (
"uuid",
"username",
"email",
"name",
"provider",
"totp_pending",
"oauth_groups",
"expiry",
"oauth_name",
"oauth_sub"
) VALUES (
?, ?, ?, ?, ?, ?, ?, ?, ?, ?
)
RETURNING uuid, username, email, name, provider, totp_pending, oauth_groups, expiry, oauth_name, oauth_sub
`
type CreateSessionParams struct {
UUID string
Username string
Email string
Name string
Provider string
TotpPending bool
OAuthGroups string
Expiry int64
OAuthName string
OAuthSub string
}
func (q *Queries) CreateSession(ctx context.Context, arg CreateSessionParams) (Session, error) {
row := q.db.QueryRowContext(ctx, createSession,
arg.UUID,
arg.Username,
arg.Email,
arg.Name,
arg.Provider,
arg.TotpPending,
arg.OAuthGroups,
arg.Expiry,
arg.OAuthName,
arg.OAuthSub,
)
var i Session
err := row.Scan(
&i.UUID,
&i.Username,
&i.Email,
&i.Name,
&i.Provider,
&i.TotpPending,
&i.OAuthGroups,
&i.Expiry,
&i.OAuthName,
&i.OAuthSub,
)
return i, err
}
const deleteExpiredSessions = `-- name: DeleteExpiredSessions :exec
DELETE FROM "sessions"
WHERE "expiry" < ?
`
func (q *Queries) DeleteExpiredSessions(ctx context.Context, expiry int64) error {
_, err := q.db.ExecContext(ctx, deleteExpiredSessions, expiry)
return err
}
const deleteSession = `-- name: DeleteSession :exec
DELETE FROM "sessions"
WHERE "uuid" = ?
`
func (q *Queries) DeleteSession(ctx context.Context, uuid string) error {
_, err := q.db.ExecContext(ctx, deleteSession, uuid)
return err
}
const getSession = `-- name: GetSession :one
SELECT uuid, username, email, name, provider, totp_pending, oauth_groups, expiry, oauth_name, oauth_sub FROM "sessions"
WHERE "uuid" = ?
`
func (q *Queries) GetSession(ctx context.Context, uuid string) (Session, error) {
row := q.db.QueryRowContext(ctx, getSession, uuid)
var i Session
err := row.Scan(
&i.UUID,
&i.Username,
&i.Email,
&i.Name,
&i.Provider,
&i.TotpPending,
&i.OAuthGroups,
&i.Expiry,
&i.OAuthName,
&i.OAuthSub,
)
return i, err
}
const updateSession = `-- name: UpdateSession :one
UPDATE "sessions" SET
"username" = ?,
"email" = ?,
"name" = ?,
"provider" = ?,
"totp_pending" = ?,
"oauth_groups" = ?,
"expiry" = ?,
"oauth_name" = ?,
"oauth_sub" = ?
WHERE "uuid" = ?
RETURNING uuid, username, email, name, provider, totp_pending, oauth_groups, expiry, oauth_name, oauth_sub
`
type UpdateSessionParams struct {
Username string
Email string
Name string
Provider string
TotpPending bool
OAuthGroups string
Expiry int64
OAuthName string
OAuthSub string
UUID string
}
func (q *Queries) UpdateSession(ctx context.Context, arg UpdateSessionParams) (Session, error) {
row := q.db.QueryRowContext(ctx, updateSession,
arg.Username,
arg.Email,
arg.Name,
arg.Provider,
arg.TotpPending,
arg.OAuthGroups,
arg.Expiry,
arg.OAuthName,
arg.OAuthSub,
arg.UUID,
)
var i Session
err := row.Scan(
&i.UUID,
&i.Username,
&i.Email,
&i.Name,
&i.Provider,
&i.TotpPending,
&i.OAuthGroups,
&i.Expiry,
&i.OAuthName,
&i.OAuthSub,
)
return i, err
}

View File

@@ -1,6 +1,7 @@
package service
import (
"database/sql"
"errors"
"fmt"
"regexp"
@@ -9,14 +10,13 @@ import (
"time"
"github.com/steveiliop56/tinyauth/internal/config"
"github.com/steveiliop56/tinyauth/internal/model"
"github.com/steveiliop56/tinyauth/internal/repository"
"github.com/steveiliop56/tinyauth/internal/utils"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/rs/zerolog/log"
"golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
)
type LoginAttempt struct {
@@ -42,16 +42,16 @@ type AuthService struct {
loginAttempts map[string]*LoginAttempt
loginMutex sync.RWMutex
ldap *LdapService
database *gorm.DB
queries *repository.Queries
}
func NewAuthService(config AuthServiceConfig, docker *DockerService, ldap *LdapService, database *gorm.DB) *AuthService {
func NewAuthService(config AuthServiceConfig, docker *DockerService, ldap *LdapService, queries *repository.Queries) *AuthService {
return &AuthService{
config: config,
docker: docker,
loginAttempts: make(map[string]*LoginAttempt),
ldap: ldap,
database: database,
queries: queries,
}
}
@@ -203,20 +203,20 @@ func (auth *AuthService) CreateSessionCookie(c *gin.Context, data *config.Sessio
expiry = auth.config.SessionExpiry
}
session := model.Session{
session := repository.CreateSessionParams{
UUID: uuid.String(),
Username: data.Username,
Email: data.Email,
Name: data.Name,
Provider: data.Provider,
TOTPPending: data.TotpPending,
TotpPending: data.TotpPending,
OAuthGroups: data.OAuthGroups,
Expiry: time.Now().Add(time.Duration(expiry) * time.Second).Unix(),
OAuthName: data.OAuthName,
OAuthSub: data.OAuthSub,
}
err = gorm.G[model.Session](auth.database).Create(c, &session)
_, err = auth.queries.CreateSession(c, session)
if err != nil {
return err
@@ -234,7 +234,7 @@ func (auth *AuthService) RefreshSessionCookie(c *gin.Context) error {
return err
}
session, err := gorm.G[model.Session](auth.database).Where("uuid = ?", cookie).First(c)
session, err := auth.queries.GetSession(c, cookie)
if err != nil {
return err
@@ -248,8 +248,17 @@ func (auth *AuthService) RefreshSessionCookie(c *gin.Context) error {
newExpiry := currentTime + int64(time.Hour.Seconds())
_, err = gorm.G[model.Session](auth.database).Where("uuid = ?", cookie).Updates(c, model.Session{
Expiry: newExpiry,
_, err = auth.queries.UpdateSession(c, repository.UpdateSessionParams{
Username: session.Username,
Email: session.Email,
Name: session.Name,
Provider: session.Provider,
TotpPending: session.TotpPending,
OAuthGroups: session.OAuthGroups,
Expiry: newExpiry,
OAuthName: session.OAuthName,
OAuthSub: session.OAuthSub,
UUID: session.UUID,
})
if err != nil {
@@ -268,7 +277,7 @@ func (auth *AuthService) DeleteSessionCookie(c *gin.Context) error {
return err
}
_, err = gorm.G[model.Session](auth.database).Where("uuid = ?", cookie).Delete(c)
err = auth.queries.DeleteSession(c, cookie)
if err != nil {
return err
@@ -286,20 +295,19 @@ func (auth *AuthService) GetSessionCookie(c *gin.Context) (config.SessionCookie,
return config.SessionCookie{}, err
}
session, err := gorm.G[model.Session](auth.database).Where("uuid = ?", cookie).First(c)
session, err := auth.queries.GetSession(c, cookie)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return config.SessionCookie{}, fmt.Errorf("session not found")
}
return config.SessionCookie{}, err
}
if errors.Is(err, gorm.ErrRecordNotFound) {
return config.SessionCookie{}, fmt.Errorf("session not found")
}
currentTime := time.Now().Unix()
if currentTime > session.Expiry {
_, err = gorm.G[model.Session](auth.database).Where("uuid = ?", cookie).Delete(c)
err = auth.queries.DeleteSession(c, cookie)
if err != nil {
log.Error().Err(err).Msg("Failed to delete expired session")
}
@@ -312,7 +320,7 @@ func (auth *AuthService) GetSessionCookie(c *gin.Context) (config.SessionCookie,
Email: session.Email,
Name: session.Name,
Provider: session.Provider,
TotpPending: session.TOTPPending,
TotpPending: session.TotpPending,
OAuthGroups: session.OAuthGroups,
OAuthName: session.OAuthName,
OAuthSub: session.OAuthSub,

View File

@@ -1,92 +0,0 @@
package service
import (
"database/sql"
"fmt"
"os"
"path/filepath"
"github.com/steveiliop56/tinyauth/internal/assets"
"github.com/glebarez/sqlite"
"github.com/golang-migrate/migrate/v4"
sqliteMigrate "github.com/golang-migrate/migrate/v4/database/sqlite3"
"github.com/golang-migrate/migrate/v4/source/iofs"
"gorm.io/gorm"
)
type DatabaseServiceConfig struct {
DatabasePath string
}
type DatabaseService struct {
config DatabaseServiceConfig
database *gorm.DB
}
func NewDatabaseService(config DatabaseServiceConfig) *DatabaseService {
return &DatabaseService{
config: config,
}
}
func (ds *DatabaseService) Init() error {
dbPath := ds.config.DatabasePath
if dbPath == "" {
dbPath = "/data/tinyauth.db"
}
dir := filepath.Dir(dbPath)
if err := os.MkdirAll(dir, 0755); err != nil {
return fmt.Errorf("failed to create database directory %s: %w", dir, err)
}
gormDB, err := gorm.Open(sqlite.Open(dbPath), &gorm.Config{})
if err != nil {
return err
}
sqlDB, err := gormDB.DB()
if err != nil {
return err
}
sqlDB.SetMaxOpenConns(1)
err = ds.migrateDatabase(sqlDB)
if err != nil && err != migrate.ErrNoChange {
return err
}
ds.database = gormDB
return nil
}
func (ds *DatabaseService) migrateDatabase(sqlDB *sql.DB) error {
data, err := iofs.New(assets.Migrations, "migrations")
if err != nil {
return err
}
target, err := sqliteMigrate.WithInstance(sqlDB, &sqliteMigrate.Config{})
if err != nil {
return err
}
migrator, err := migrate.NewWithInstance("iofs", data, "tinyauth", target)
if err != nil {
return err
}
return migrator.Up()
}
func (ds *DatabaseService) GetDatabase() *gorm.DB {
return ds.database
}