From c8e86d853675449b2ced36dd426ad9baa9c071c4 Mon Sep 17 00:00:00 2001 From: Stavros Date: Wed, 24 Dec 2025 20:17:41 +0200 Subject: [PATCH] refactor: replace gorm with vanilla sql and sqlc --- CONTRIBUTING.md | 2 +- docker-compose.dev.yml | 1 - go.mod | 7 +- go.sum | 10 -- internal/bootstrap/app_bootstrap.go | 25 ++-- internal/bootstrap/db_bootstrap.go | 52 ++++++++ internal/bootstrap/service_bootstrap.go | 20 +-- internal/model/session_model.go | 13 -- internal/repository/db.go | 31 +++++ internal/repository/models.go | 17 +++ internal/repository/query.sql.go | 161 ++++++++++++++++++++++++ internal/service/auth_service.go | 26 ++-- internal/service/database_service.go | 91 -------------- query.sql | 40 ++++++ schema.sql | 11 ++ sqlc.yml | 18 +++ 16 files changed, 366 insertions(+), 159 deletions(-) create mode 100644 internal/bootstrap/db_bootstrap.go delete mode 100644 internal/model/session_model.go create mode 100644 internal/repository/db.go create mode 100644 internal/repository/models.go create mode 100644 internal/repository/query.sql.go delete mode 100644 internal/service/database_service.go create mode 100644 query.sql create mode 100644 schema.sql create mode 100644 sqlc.yml diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 079d181..1ace4da 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -23,7 +23,7 @@ cd tinyauth Although you will not need the requirements in your machine since the development will happen in docker, I still recommend to install them because this way you will not have import errors. To install the go requirements run: ```sh -go mod tidy +go mod download ``` You also need to download the frontend dependencies, this can be done like so: diff --git a/docker-compose.dev.yml b/docker-compose.dev.yml index cc454f6..8ffd6fd 100644 --- a/docker-compose.dev.yml +++ b/docker-compose.dev.yml @@ -42,7 +42,6 @@ services: volumes: - ./internal:/tinyauth/internal - ./cmd:/tinyauth/cmd - - ./main.go:/tinyauth/main.go - /var/run/docker.sock:/var/run/docker.sock - ./data:/data ports: diff --git a/go.mod b/go.mod index 092a6b6..94c7ac8 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,6 @@ toolchain go1.24.3 require ( github.com/cenkalti/backoff/v5 v5.0.3 github.com/gin-gonic/gin v1.11.0 - github.com/glebarez/sqlite v1.11.0 github.com/golang-migrate/migrate/v4 v4.19.1 github.com/google/go-querystring v1.1.0 github.com/google/uuid v1.6.0 @@ -17,8 +16,8 @@ require ( github.com/weppos/publicsuffix-go v0.50.1 golang.org/x/crypto v0.46.0 golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b - gorm.io/gorm v1.31.1 gotest.tools/v3 v3.5.2 + modernc.org/sqlite v1.38.2 ) require ( @@ -32,15 +31,12 @@ require ( github.com/containerd/errdefs v1.0.0 // indirect github.com/containerd/errdefs/pkg v0.3.0 // indirect github.com/containerd/log v0.1.0 // indirect - github.com/glebarez/go-sqlite v1.21.2 // indirect github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667 // indirect github.com/go-playground/validator/v10 v10.28.0 // indirect github.com/goccy/go-yaml v1.18.0 // indirect github.com/google/go-cmp v0.7.0 // indirect github.com/huandu/xstrings v1.5.0 // indirect github.com/imdario/mergo v0.3.11 // indirect - github.com/jinzhu/inflection v1.0.0 // indirect - github.com/jinzhu/now v1.1.5 // indirect github.com/mattn/go-sqlite3 v1.14.32 // indirect github.com/mitchellh/copystructure v1.2.0 // indirect github.com/mitchellh/reflectwalk v1.0.2 // indirect @@ -59,7 +55,6 @@ require ( modernc.org/libc v1.66.3 // indirect modernc.org/mathutil v1.7.1 // indirect modernc.org/memory v1.11.0 // indirect - modernc.org/sqlite v1.38.2 // indirect rsc.io/qr v0.2.0 // indirect ) diff --git a/go.sum b/go.sum index 4a44d52..520b5f4 100644 --- a/go.sum +++ b/go.sum @@ -101,10 +101,6 @@ github.com/gin-contrib/sse v1.1.0 h1:n0w2GMuUpWDVp7qSpvze6fAu9iRxJY4Hmj6AmBOU05w github.com/gin-contrib/sse v1.1.0/go.mod h1:hxRZ5gVpWMT7Z0B0gSNYqqsSCNIJMjzvm6fqCz9vjwM= github.com/gin-gonic/gin v1.11.0 h1:OW/6PLjyusp2PPXtyxKHU0RbX6I/l28FTdDlae5ueWk= github.com/gin-gonic/gin v1.11.0/go.mod h1:+iq/FyxlGzII0KHiBGjuNn4UNENUlKbGlNmc+W50Dls= -github.com/glebarez/go-sqlite v1.21.2 h1:3a6LFC4sKahUunAmynQKLZceZCOzUthkRkEAl9gAXWo= -github.com/glebarez/go-sqlite v1.21.2/go.mod h1:sfxdZyhQjTM2Wry3gVYWaW072Ri1WMdWJi0k6+3382k= -github.com/glebarez/sqlite v1.11.0 h1:wSG0irqzP6VurnMEpFGer5Li19RpIRi2qvQz++w0GMw= -github.com/glebarez/sqlite v1.11.0/go.mod h1:h8/o8j5wiAsqSPoWELDUdJXhjAhsVliSn7bWZjOhrgQ= github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667 h1:BP4M0CvQ4S3TGls2FvczZtj5Re/2ZzkV9VwqPHH/3Bo= github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667/go.mod h1:hEBeB/ic+5LoWskz+yKT7vGhhPYkProFKoKdwZRWMe0= github.com/go-ldap/ldap/v3 v3.4.12 h1:1b81mv7MagXZ7+1r7cLTWmyuTqVqdwbtJSjC0DAp9s4= @@ -161,10 +157,6 @@ github.com/jcmturner/gokrb5/v8 v8.4.4 h1:x1Sv4HaTpepFkXbt2IkL29DXRf8sOfZXo8eRKh6 github.com/jcmturner/gokrb5/v8 v8.4.4/go.mod h1:1btQEpgT6k+unzCwX1KdWMEwPPkkgBtP+F6aCACiMrs= github.com/jcmturner/rpc/v2 v2.0.3 h1:7FXXj8Ti1IaVFpSAziCZWNzbNuZmnvw/i6CqLNdWfZY= github.com/jcmturner/rpc/v2 v2.0.3/go.mod h1:VUJYCIDm3PVOEHw8sgt091/20OJjskO/YJki3ELg/Hc= -github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= -github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= -github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= -github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y= @@ -378,8 +370,6 @@ gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -gorm.io/gorm v1.31.1 h1:7CA8FTFz/gRfgqgpeKIBcervUn3xSyPUmr6B2WXJ7kg= -gorm.io/gorm v1.31.1/go.mod h1:XyQVbO2k6YkOis7C2437jSit3SsDK72s7n7rsSHd+Gs= gotest.tools/v3 v3.5.2 h1:7koQfIKdy+I8UTetycgUqXWSDwpgv193Ka+qRsmBY8Q= gotest.tools/v3 v3.5.2/go.mod h1:LtdLGcnqToBH83WByAAi/wiwSFCArdFIUV/xxN4pcjA= modernc.org/cc/v4 v4.26.2 h1:991HMkLjJzYBIfha6ECZdjrIYz2/1ayr+FL8GN+CNzM= diff --git a/internal/bootstrap/app_bootstrap.go b/internal/bootstrap/app_bootstrap.go index a4d8024..bfbd0c7 100644 --- a/internal/bootstrap/app_bootstrap.go +++ b/internal/bootstrap/app_bootstrap.go @@ -13,11 +13,10 @@ import ( "time" "tinyauth/internal/config" "tinyauth/internal/controller" - "tinyauth/internal/model" + "tinyauth/internal/repository" "tinyauth/internal/utils" "github.com/rs/zerolog/log" - "gorm.io/gorm" ) type BootstrapApp struct { @@ -107,8 +106,18 @@ func (app *BootstrapApp) Setup() error { log.Trace().Str("csrfCookieName", app.context.csrfCookieName).Msg("CSRF cookie name") log.Trace().Str("redirectCookieName", app.context.redirectCookieName).Msg("Redirect cookie name") + // Database + db, err := app.setupDatabase(app.config.DatabasePath) + + if err != nil { + return fmt.Errorf("failed to setup database: %w", err) + } + + // Queries + queries := repository.New(db) + // Services - services, err := app.initServices() + services, err := app.initServices(queries) if err != nil { return fmt.Errorf("failed to initialize services: %w", err) @@ -154,9 +163,9 @@ func (app *BootstrapApp) Setup() error { return fmt.Errorf("failed to setup routes: %w", err) } - // Start DB cleanup routine + // Start db cleanup routine log.Debug().Msg("Starting database cleanup routine") - go app.dbCleanup(services.databaseService.GetDatabase()) + go app.dbCleanup(queries) // If analytics are not disabled, start heartbeat if !app.config.DisableAnalytics { @@ -246,16 +255,16 @@ func (app *BootstrapApp) heartbeat() { } } -func (app *BootstrapApp) dbCleanup(db *gorm.DB) { +func (app *BootstrapApp) dbCleanup(queries *repository.Queries) { ticker := time.NewTicker(time.Duration(30) * time.Minute) defer ticker.Stop() ctx := context.Background() for ; true; <-ticker.C { log.Debug().Msg("Cleaning up old database sessions") - _, err := gorm.G[model.Session](db).Where("expiry < ?", time.Now().Unix()).Delete(ctx) + err := queries.DeleteExpiredSessions(ctx, time.Now().Unix()) if err != nil { - log.Error().Err(err).Msg("Failed to cleanup old sessions") + log.Error().Err(err).Msg("Failed to clean up old database sessions") } } } diff --git a/internal/bootstrap/db_bootstrap.go b/internal/bootstrap/db_bootstrap.go new file mode 100644 index 0000000..9969361 --- /dev/null +++ b/internal/bootstrap/db_bootstrap.go @@ -0,0 +1,52 @@ +package bootstrap + +import ( + "database/sql" + "fmt" + "os" + "path/filepath" + "tinyauth/internal/assets" + + "github.com/golang-migrate/migrate/v4" + "github.com/golang-migrate/migrate/v4/database/sqlite3" + "github.com/golang-migrate/migrate/v4/source/iofs" + _ "modernc.org/sqlite" +) + +func (app *BootstrapApp) setupDatabase(databasePath string) (*sql.DB, error) { + dir := filepath.Dir(databasePath) + + if err := os.MkdirAll(dir, 0755); err != nil { + return nil, fmt.Errorf("failed to create database directory %s: %w", dir, err) + } + + db, err := sql.Open("sqlite", databasePath) + + if err != nil { + return nil, fmt.Errorf("failed to open database: %w", err) + } + + migrations, err := iofs.New(assets.Migrations, "migrations") + + if err != nil { + return nil, fmt.Errorf("failed to create migrations: %w", err) + } + + target, err := sqlite3.WithInstance(db, &sqlite3.Config{}) + + if err != nil { + return nil, fmt.Errorf("failed to create sqlite3 instance: %w", err) + } + + migrator, err := migrate.NewWithInstance("iofs", migrations, "sqlite3", target) + + if err != nil { + return nil, fmt.Errorf("failed to create migrator: %w", err) + } + + if err := migrator.Up(); err != nil && err != migrate.ErrNoChange { + return nil, fmt.Errorf("failed to migrate database: %w", err) + } + + return db, nil +} diff --git a/internal/bootstrap/service_bootstrap.go b/internal/bootstrap/service_bootstrap.go index e18d832..5731616 100644 --- a/internal/bootstrap/service_bootstrap.go +++ b/internal/bootstrap/service_bootstrap.go @@ -1,6 +1,7 @@ package bootstrap import ( + "tinyauth/internal/repository" "tinyauth/internal/service" "github.com/rs/zerolog/log" @@ -9,27 +10,14 @@ import ( type Services struct { accessControlService *service.AccessControlsService authService *service.AuthService - databaseService *service.DatabaseService dockerService *service.DockerService ldapService *service.LdapService oauthBrokerService *service.OAuthBrokerService } -func (app *BootstrapApp) initServices() (Services, error) { +func (app *BootstrapApp) initServices(queries *repository.Queries) (Services, error) { services := Services{} - databaseService := service.NewDatabaseService(service.DatabaseServiceConfig{ - DatabasePath: app.config.DatabasePath, - }) - - err := databaseService.Init() - - if err != nil { - return Services{}, err - } - - services.databaseService = databaseService - ldapService := service.NewLdapService(service.LdapServiceConfig{ Address: app.config.Ldap.Address, BindDN: app.config.Ldap.BindDN, @@ -39,7 +27,7 @@ func (app *BootstrapApp) initServices() (Services, error) { SearchFilter: app.config.Ldap.SearchFilter, }) - err = ldapService.Init() + err := ldapService.Init() if err == nil { services.ldapService = ldapService @@ -76,7 +64,7 @@ func (app *BootstrapApp) initServices() (Services, error) { LoginTimeout: app.config.Auth.LoginTimeout, LoginMaxRetries: app.config.Auth.LoginMaxRetries, SessionCookieName: app.context.sessionCookieName, - }, dockerService, ldapService, databaseService.GetDatabase()) + }, dockerService, ldapService, queries) err = authService.Init() diff --git a/internal/model/session_model.go b/internal/model/session_model.go deleted file mode 100644 index 0fdb6c3..0000000 --- a/internal/model/session_model.go +++ /dev/null @@ -1,13 +0,0 @@ -package model - -type Session struct { - UUID string `gorm:"column:uuid;primaryKey"` - Username string `gorm:"column:username"` - Email string `gorm:"column:email"` - Name string `gorm:"column:name"` - Provider string `gorm:"column:provider"` - TOTPPending bool `gorm:"column:totp_pending"` - OAuthGroups string `gorm:"column:oauth_groups"` - Expiry int64 `gorm:"column:expiry"` - OAuthName string `gorm:"column:oauth_name"` -} diff --git a/internal/repository/db.go b/internal/repository/db.go new file mode 100644 index 0000000..998bfd3 --- /dev/null +++ b/internal/repository/db.go @@ -0,0 +1,31 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.30.0 + +package repository + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/repository/models.go b/internal/repository/models.go new file mode 100644 index 0000000..5283d3f --- /dev/null +++ b/internal/repository/models.go @@ -0,0 +1,17 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.30.0 + +package repository + +type Session struct { + UUID string + Username string + Email string + Name string + Provider string + TotpPending bool + OAuthGroups string + Expiry int64 + OAuthName string +} diff --git a/internal/repository/query.sql.go b/internal/repository/query.sql.go new file mode 100644 index 0000000..ba47872 --- /dev/null +++ b/internal/repository/query.sql.go @@ -0,0 +1,161 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.30.0 +// source: query.sql + +package repository + +import ( + "context" +) + +const createSession = `-- name: CreateSession :one +INSERT INTO sessions ( + "uuid", + "username", + "email", + "name", + "provider", + "totp_pending", + "oauth_groups", + "expiry", + "oauth_name" +) VALUES ( + ?, ?, ?, ?, ?, ?, ?, ?, ? +) +RETURNING uuid, username, email, name, provider, totp_pending, oauth_groups, expiry, oauth_name +` + +type CreateSessionParams struct { + UUID string + Username string + Email string + Name string + Provider string + TotpPending bool + OAuthGroups string + Expiry int64 + OAuthName string +} + +func (q *Queries) CreateSession(ctx context.Context, arg CreateSessionParams) (Session, error) { + row := q.db.QueryRowContext(ctx, createSession, + arg.UUID, + arg.Username, + arg.Email, + arg.Name, + arg.Provider, + arg.TotpPending, + arg.OAuthGroups, + arg.Expiry, + arg.OAuthName, + ) + var i Session + err := row.Scan( + &i.UUID, + &i.Username, + &i.Email, + &i.Name, + &i.Provider, + &i.TotpPending, + &i.OAuthGroups, + &i.Expiry, + &i.OAuthName, + ) + return i, err +} + +const deleteExpiredSessions = `-- name: DeleteExpiredSessions :exec +DELETE FROM "sessions" +WHERE "expiry" < ? +` + +func (q *Queries) DeleteExpiredSessions(ctx context.Context, expiry int64) error { + _, err := q.db.ExecContext(ctx, deleteExpiredSessions, expiry) + return err +} + +const deleteSession = `-- name: DeleteSession :exec +DELETE FROM "sessions" +WHERE "uuid" = ? +` + +func (q *Queries) DeleteSession(ctx context.Context, uuid string) error { + _, err := q.db.ExecContext(ctx, deleteSession, uuid) + return err +} + +const getSession = `-- name: GetSession :one +SELECT uuid, username, email, name, provider, totp_pending, oauth_groups, expiry, oauth_name FROM "sessions" +WHERE "uuid" = ? +` + +func (q *Queries) GetSession(ctx context.Context, uuid string) (Session, error) { + row := q.db.QueryRowContext(ctx, getSession, uuid) + var i Session + err := row.Scan( + &i.UUID, + &i.Username, + &i.Email, + &i.Name, + &i.Provider, + &i.TotpPending, + &i.OAuthGroups, + &i.Expiry, + &i.OAuthName, + ) + return i, err +} + +const updateSession = `-- name: UpdateSession :one +UPDATE "sessions" SET + "username" = ?, + "email" = ?, + "name" = ?, + "provider" = ?, + "totp_pending" = ?, + "oauth_groups" = ?, + "expiry" = ?, + "oauth_name" = ? +WHERE "uuid" = ? +RETURNING uuid, username, email, name, provider, totp_pending, oauth_groups, expiry, oauth_name +` + +type UpdateSessionParams struct { + Username string + Email string + Name string + Provider string + TotpPending bool + OAuthGroups string + Expiry int64 + OAuthName string + UUID string +} + +func (q *Queries) UpdateSession(ctx context.Context, arg UpdateSessionParams) (Session, error) { + row := q.db.QueryRowContext(ctx, updateSession, + arg.Username, + arg.Email, + arg.Name, + arg.Provider, + arg.TotpPending, + arg.OAuthGroups, + arg.Expiry, + arg.OAuthName, + arg.UUID, + ) + var i Session + err := row.Scan( + &i.UUID, + &i.Username, + &i.Email, + &i.Name, + &i.Provider, + &i.TotpPending, + &i.OAuthGroups, + &i.Expiry, + &i.OAuthName, + ) + return i, err +} diff --git a/internal/service/auth_service.go b/internal/service/auth_service.go index bcba481..c09e04d 100644 --- a/internal/service/auth_service.go +++ b/internal/service/auth_service.go @@ -2,6 +2,7 @@ package service import ( "context" + "database/sql" "errors" "fmt" "regexp" @@ -9,14 +10,13 @@ import ( "sync" "time" "tinyauth/internal/config" - "tinyauth/internal/model" + "tinyauth/internal/repository" "tinyauth/internal/utils" "github.com/gin-gonic/gin" "github.com/google/uuid" "github.com/rs/zerolog/log" "golang.org/x/crypto/bcrypt" - "gorm.io/gorm" ) type LoginAttempt struct { @@ -42,17 +42,17 @@ type AuthService struct { loginAttempts map[string]*LoginAttempt loginMutex sync.RWMutex ldap *LdapService - database *gorm.DB + queries *repository.Queries ctx context.Context } -func NewAuthService(config AuthServiceConfig, docker *DockerService, ldap *LdapService, database *gorm.DB) *AuthService { +func NewAuthService(config AuthServiceConfig, docker *DockerService, ldap *LdapService, queries *repository.Queries) *AuthService { return &AuthService{ config: config, docker: docker, loginAttempts: make(map[string]*LoginAttempt), ldap: ldap, - database: database, + queries: queries, } } @@ -205,19 +205,19 @@ func (auth *AuthService) CreateSessionCookie(c *gin.Context, data *config.Sessio expiry = auth.config.SessionExpiry } - session := model.Session{ + session := repository.CreateSessionParams{ UUID: uuid.String(), Username: data.Username, Email: data.Email, Name: data.Name, Provider: data.Provider, - TOTPPending: data.TotpPending, + TotpPending: data.TotpPending, OAuthGroups: data.OAuthGroups, Expiry: time.Now().Add(time.Duration(expiry) * time.Second).Unix(), OAuthName: data.OAuthName, } - err = gorm.G[model.Session](auth.database).Create(auth.ctx, &session) + _, err = auth.queries.CreateSession(c, session) if err != nil { return err @@ -235,7 +235,7 @@ func (auth *AuthService) DeleteSessionCookie(c *gin.Context) error { return err } - _, err = gorm.G[model.Session](auth.database).Where("uuid = ?", cookie).Delete(auth.ctx) + err = auth.queries.DeleteSession(auth.ctx, cookie) if err != nil { return err @@ -253,20 +253,20 @@ func (auth *AuthService) GetSessionCookie(c *gin.Context) (config.SessionCookie, return config.SessionCookie{}, err } - session, err := gorm.G[model.Session](auth.database).Where("uuid = ?", cookie).First(auth.ctx) + session, err := auth.queries.GetSession(auth.ctx, cookie) if err != nil { return config.SessionCookie{}, err } - if errors.Is(err, gorm.ErrRecordNotFound) { + if errors.Is(err, sql.ErrNoRows) { return config.SessionCookie{}, fmt.Errorf("session not found") } currentTime := time.Now().Unix() if currentTime > session.Expiry { - _, err = gorm.G[model.Session](auth.database).Where("uuid = ?", cookie).Delete(auth.ctx) + err = auth.queries.DeleteSession(auth.ctx, cookie) if err != nil { log.Error().Err(err).Msg("Failed to delete expired session") } @@ -279,7 +279,7 @@ func (auth *AuthService) GetSessionCookie(c *gin.Context) (config.SessionCookie, Email: session.Email, Name: session.Name, Provider: session.Provider, - TotpPending: session.TOTPPending, + TotpPending: session.TotpPending, OAuthGroups: session.OAuthGroups, OAuthName: session.OAuthName, }, nil diff --git a/internal/service/database_service.go b/internal/service/database_service.go deleted file mode 100644 index 30d8803..0000000 --- a/internal/service/database_service.go +++ /dev/null @@ -1,91 +0,0 @@ -package service - -import ( - "database/sql" - "fmt" - "os" - "path/filepath" - "tinyauth/internal/assets" - - "github.com/glebarez/sqlite" - "github.com/golang-migrate/migrate/v4" - sqliteMigrate "github.com/golang-migrate/migrate/v4/database/sqlite3" - "github.com/golang-migrate/migrate/v4/source/iofs" - "gorm.io/gorm" -) - -type DatabaseServiceConfig struct { - DatabasePath string -} - -type DatabaseService struct { - config DatabaseServiceConfig - database *gorm.DB -} - -func NewDatabaseService(config DatabaseServiceConfig) *DatabaseService { - return &DatabaseService{ - config: config, - } -} - -func (ds *DatabaseService) Init() error { - dbPath := ds.config.DatabasePath - if dbPath == "" { - dbPath = "/data/tinyauth.db" - } - - dir := filepath.Dir(dbPath) - if err := os.MkdirAll(dir, 0755); err != nil { - return fmt.Errorf("failed to create database directory %s: %w", dir, err) - } - - gormDB, err := gorm.Open(sqlite.Open(dbPath), &gorm.Config{}) - - if err != nil { - return err - } - - sqlDB, err := gormDB.DB() - - if err != nil { - return err - } - - sqlDB.SetMaxOpenConns(1) - - err = ds.migrateDatabase(sqlDB) - - if err != nil && err != migrate.ErrNoChange { - return err - } - - ds.database = gormDB - return nil -} - -func (ds *DatabaseService) migrateDatabase(sqlDB *sql.DB) error { - data, err := iofs.New(assets.Migrations, "migrations") - - if err != nil { - return err - } - - target, err := sqliteMigrate.WithInstance(sqlDB, &sqliteMigrate.Config{}) - - if err != nil { - return err - } - - migrator, err := migrate.NewWithInstance("iofs", data, "tinyauth", target) - - if err != nil { - return err - } - - return migrator.Up() -} - -func (ds *DatabaseService) GetDatabase() *gorm.DB { - return ds.database -} diff --git a/query.sql b/query.sql new file mode 100644 index 0000000..8737f48 --- /dev/null +++ b/query.sql @@ -0,0 +1,40 @@ +-- name: CreateSession :one +INSERT INTO sessions ( + "uuid", + "username", + "email", + "name", + "provider", + "totp_pending", + "oauth_groups", + "expiry", + "oauth_name" +) VALUES ( + ?, ?, ?, ?, ?, ?, ?, ?, ? +) +RETURNING *; + +-- name: GetSession :one +SELECT * FROM "sessions" +WHERE "uuid" = ?; + +-- name: DeleteSession :exec +DELETE FROM "sessions" +WHERE "uuid" = ?; + +-- name: UpdateSession :one +UPDATE "sessions" SET + "username" = ?, + "email" = ?, + "name" = ?, + "provider" = ?, + "totp_pending" = ?, + "oauth_groups" = ?, + "expiry" = ?, + "oauth_name" = ? +WHERE "uuid" = ? +RETURNING *; + +-- name: DeleteExpiredSessions :exec +DELETE FROM "sessions" +WHERE "expiry" < ?; \ No newline at end of file diff --git a/schema.sql b/schema.sql new file mode 100644 index 0000000..d26cfd0 --- /dev/null +++ b/schema.sql @@ -0,0 +1,11 @@ +CREATE TABLE IF NOT EXISTS "sessions" ( + "uuid" TEXT NOT NULL PRIMARY KEY UNIQUE, + "username" TEXT NOT NULL, + "email" TEXT NOT NULL, + "name" TEXT NOT NULL, + "provider" TEXT NOT NULL, + "totp_pending" BOOLEAN NOT NULL, + "oauth_groups" TEXT NULL, + "expiry" INTEGER NOT NULL, + "oauth_name" TEXT NULL +); diff --git a/sqlc.yml b/sqlc.yml new file mode 100644 index 0000000..0ab33c0 --- /dev/null +++ b/sqlc.yml @@ -0,0 +1,18 @@ +version: "2" +sql: + - engine: "sqlite" + queries: "query.sql" + schema: "schema.sql" + gen: + go: + package: "repository" + out: "internal/repository" + rename: + uuid: "UUID" + oauth_groups: "OAuthGroups" + oauth_name: "OAuthName" + overrides: + - column: "sessions.oauth_groups" + go_type: "string" + - column: "sessions.oauth_name" + go_type: "string"