Compare commits

..

4 Commits

Author SHA1 Message Date
Scott McKendry 3eeabd3623 refactor(db): cleanup sqlc-wrapper gen 2026-05-07 19:12:00 +12:00
Scott McKendry 04b8e9884b feat(db): add memory storage driver
removes the sqlite dependency for tests, also brings back the option for
users to run zero persistence instances of tinyauth.

adds new mapErr fn for sqlc wrapper gen to prevent sql errors from
leaking out of the store implementation.
2026-05-04 05:02:27 +12:00
Scott McKendry 0244f39387 feat(db): add code gen to build sqlc-compatible wrappers 2026-05-03 14:11:54 +12:00
Scott McKendry 1d0a4627a9 refactor(db): use new store interface 2026-04-30 19:18:33 +12:00
51 changed files with 1406 additions and 108 deletions
+6
View File
@@ -26,6 +26,12 @@ jobs:
- name: Go dependencies
run: go mod download
- name: Check codegen is up to date
run: |
go generate ./internal/repository/...
git diff --exit-code -- internal/repository/
git status --porcelain -- internal/repository/ | grep -q . && echo "untracked files in internal/repository/" && exit 1 || true
- name: Install frontend dependencies
run: |
cd frontend
+1 -1
View File
@@ -38,6 +38,6 @@ jobs:
retention-days: 5
- name: Upload to code-scanning
uses: github/codeql-action/upload-sarif@e46ed2cbd01164d986452f91f178727624ae40d7 # v4
uses: github/codeql-action/upload-sarif@95e58e9a2cdfd71adc6e0353d5c52f41a045d225 # v4
with:
sarif_file: results.sarif
+1 -1
View File
@@ -84,4 +84,4 @@ sql:
# Go gen
generate:
go run ./gen
go generate ./internal/repository/...
+527
View File
@@ -0,0 +1,527 @@
// gen/sqlc-wrapper generates store.go wrapper files for each sqlc driver package under
// internal/repository/<driver>/. Run via:
//
// go generate ./internal/repository/...
//
// The generator introspects *Queries methods and the model/params types in the
// driver package, then emits a store.go that wraps *Queries so it satisfies
// repository.Store using the canonical shared types in the parent package.
// This generator is specific to sqlc-generated drivers. Non-sqlc drivers should
// implement repository.Store directly by hand.
package main
import (
"bytes"
_ "embed"
"flag"
"fmt"
"go/format"
"go/types"
"log"
"os"
"os/exec"
"path/filepath"
"sort"
"strings"
"text/template"
"golang.org/x/tools/go/packages"
)
//go:embed store.tmpl
var storeSrc string
func main() {
if err := run(); err != nil {
log.Fatal(err)
}
}
func run() error {
driverPkg := flag.String("pkg", "", "import path of the driver package")
out := flag.String("out", "store.go", "output filename relative to driver package directory")
flag.Parse()
if *driverPkg == "" {
return fmt.Errorf("-pkg is required")
}
// Resolve the driver package directory so we can overlay the output file
// with a valid stub. This prevents a stale store.go from poisoning the
// type-checker and producing cryptic "undefined" errors.
driverDir, err := pkgDir(*driverPkg)
if err != nil {
return fmt.Errorf("resolve driver dir: %w", err)
}
outPath := filepath.Join(driverDir, *out)
if filepath.IsAbs(*out) {
outPath = *out
}
// Stub replaces the output file during load so stale generated code is ignored.
stub := []byte("package " + filepath.Base(driverDir) + "\n")
cfg := &packages.Config{
Mode: packages.NeedName | packages.NeedTypes | packages.NeedSyntax | packages.NeedImports,
Overlay: map[string][]byte{outPath: stub},
}
driverTypePkg, err := loadOnePkg(cfg, *driverPkg)
if err != nil {
return fmt.Errorf("load driver package: %w", err)
}
repoPkgPath := parentPkg(*driverPkg)
repoTypePkg, err := loadOnePkg(cfg, repoPkgPath)
if err != nil {
return fmt.Errorf("load repo package: %w", err)
}
if err := validateStructShapes(driverTypePkg, repoTypePkg); err != nil {
return fmt.Errorf("struct shape mismatch: %w", err)
}
if err := validateStoreCoverage(driverTypePkg, repoTypePkg); err != nil {
return err
}
methods, err := collectMethods(driverTypePkg)
if err != nil {
return err
}
models, _ := collectTypes(driverTypePkg)
src, err := render(tmplData{
PkgName: driverTypePkg.Name(),
RepoPkg: repoPkgPath,
ModelTypes: models,
Methods: renderMethods(methods),
})
if err != nil {
return fmt.Errorf("render: %w", err)
}
if err := os.WriteFile(outPath, src, 0644); err != nil {
return fmt.Errorf("write %s: %w", outPath, err)
}
fmt.Printf("wrote %s\n", outPath)
return nil
}
// loadOnePkg loads a single package via cfg and returns its *types.Package,
// or an error if the package fails to load or has type errors.
func loadOnePkg(cfg *packages.Config, importPath string) (*types.Package, error) {
pkgs, err := packages.Load(cfg, importPath)
if err != nil {
return nil, fmt.Errorf("load %s: %w", importPath, err)
}
if len(pkgs) != 1 {
return nil, fmt.Errorf("expected 1 package for %s, got %d", importPath, len(pkgs))
}
pkg := pkgs[0]
if len(pkg.Errors) > 0 {
msgs := make([]string, len(pkg.Errors))
for i, e := range pkg.Errors {
msgs[i] = e.Error()
}
return nil, fmt.Errorf("package %s has errors:\n %s", importPath, strings.Join(msgs, "\n "))
}
return pkg.Types, nil
}
// parentPkg returns the parent import path (everything before the last /).
// Panics if imp contains no slash — callers are expected to pass driver sub-packages.
func parentPkg(imp string) string {
i := strings.LastIndex(imp, "/")
if i < 0 {
panic(fmt.Sprintf("parentPkg: import path %q has no parent", imp))
}
return imp[:i]
}
// pkgDir returns the on-disk directory for an import path using `go list`.
func pkgDir(importPath string) (string, error) {
out, err := exec.Command("go", "list", "-f", "{{.Dir}}", importPath).Output()
if err != nil {
return "", fmt.Errorf("go list %s: %w", importPath, err)
}
return strings.TrimSpace(string(out)), nil
}
// scopeStructs returns all named struct types in pkg, excluding the internal
// sqlc types Queries, DBTX, and Store. Names are returned in sorted order.
func scopeStructs(pkg *types.Package) (names []string, byName map[string]*types.Struct) {
byName = make(map[string]*types.Struct)
for _, name := range pkg.Scope().Names() { // Names() is already sorted
switch name {
case "Queries", "DBTX", "Store":
continue
}
obj, ok := pkg.Scope().Lookup(name).(*types.TypeName)
if !ok {
continue
}
named, ok := obj.Type().(*types.Named)
if !ok {
continue
}
s, ok := named.Underlying().(*types.Struct)
if !ok {
continue
}
names = append(names, name)
byName[name] = s
}
return
}
// validateStoreCoverage checks that every method declared in repository.Store
// exists on *Queries in the driver package. Missing methods are reported by
// name so the developer knows exactly which SQL queries need to be added.
func validateStoreCoverage(driverPkg, repoPkg *types.Package) error {
queriesObj := driverPkg.Scope().Lookup("Queries")
if queriesObj == nil {
return fmt.Errorf("queries type not found in driver package")
}
queriesNamed := queriesObj.Type().(*types.Named)
queriesMS := types.NewMethodSet(types.NewPointer(queriesNamed))
queriesMethods := make(map[string]bool)
for m := range queriesMS.Methods() {
queriesMethods[m.Obj().Name()] = true
}
storeObj := repoPkg.Scope().Lookup("Store")
if storeObj == nil {
return fmt.Errorf("store type not found in repository package")
}
storeIface, ok := storeObj.Type().Underlying().(*types.Interface)
if !ok {
return fmt.Errorf("repository.Store is not an interface")
}
var missing []string
for method := range storeIface.Methods() {
if name := method.Name(); !queriesMethods[name] {
missing = append(missing, name)
}
}
if len(missing) > 0 {
sort.Strings(missing)
return fmt.Errorf(
"driver *Queries is missing %d method(s) required by repository.Store:\n - %s\n\nRun sqlc generate to regenerate query methods, or add the missing SQL queries",
len(missing), strings.Join(missing, "\n - "),
)
}
return nil
}
// validateStructShapes checks that every model/params struct in the driver
// package has fields that exactly match the corresponding type in the repo
// (parent) package. This catches drift between sqlc-generated types and the
// canonical repository types before a broken cast reaches the compiler.
func validateStructShapes(driverPkg, repoPkg *types.Package) error {
_, repoStructs := scopeStructs(repoPkg)
driverNames, driverStructs := scopeStructs(driverPkg)
var errs []string
for _, name := range driverNames {
repoStruct, ok := repoStructs[name]
if !ok {
// Driver has a type not in repo — fine (e.g. internal helpers).
continue
}
if err := compareStructs(name, driverStructs[name], repoStruct); err != nil {
errs = append(errs, err.Error())
}
}
if len(errs) > 0 {
sort.Strings(errs)
return fmt.Errorf("%s", strings.Join(errs, "\n "))
}
return nil
}
func compareStructs(name string, driver, repo *types.Struct) error {
if driver.NumFields() != repo.NumFields() {
return fmt.Errorf("%s: field count mismatch (driver=%d, repo=%d)",
name, driver.NumFields(), repo.NumFields())
}
for i := range driver.NumFields() {
df := driver.Field(i)
rf := repo.Field(i)
if df.Name() != rf.Name() {
return fmt.Errorf("%s: field %d name mismatch (driver=%q, repo=%q)",
name, i, df.Name(), rf.Name())
}
if !types.Identical(df.Type(), rf.Type()) {
return fmt.Errorf("%s.%s: type mismatch (driver=%s, repo=%s)",
name, df.Name(), df.Type(), rf.Type())
}
}
return nil
}
// collectTypes returns model and params struct names from the driver package.
func collectTypes(pkg *types.Package) (models []string, params []string) {
names, _ := scopeStructs(pkg)
for _, name := range names {
if strings.HasSuffix(name, "Params") {
params = append(params, name)
} else {
models = append(models, name)
}
}
return
}
type methodInfo struct {
Name string
Params []paramInfo
Results []resultInfo
}
type paramInfo struct {
Name string
TypeStr string // local (unqualified) type name
RepoType string // "repository.X" if this is a driver model/params type; else ""
}
type resultInfo struct {
TypeStr string
IsSlice bool
RepoType string // "repository.X" if driver type; else ""
}
func collectMethods(pkg *types.Package) ([]methodInfo, error) {
obj := pkg.Scope().Lookup("Queries")
if obj == nil {
return nil, fmt.Errorf("queries type not found in %s", pkg.Path())
}
named, ok := obj.Type().(*types.Named)
if !ok {
return nil, fmt.Errorf("queries is not a named type")
}
ms := types.NewMethodSet(types.NewPointer(named))
var out []methodInfo
for method := range ms.Methods() {
fn, ok := method.Obj().(*types.Func)
if !ok || fn.Name() == "WithTx" {
continue
}
sig := fn.Type().(*types.Signature)
mi := methodInfo{Name: fn.Name()}
// params: skip receiver + first (context.Context)
for i := 1; i < sig.Params().Len(); i++ {
p := sig.Params().At(i)
mi.Params = append(mi.Params, makeParam(p.Name(), p.Type(), pkg.Path()))
}
// results: skip error
for r := range sig.Results().Variables() {
if r.Type().String() == "error" {
continue
}
mi.Results = append(mi.Results, makeResult(r.Type(), pkg.Path()))
}
out = append(out, mi)
}
return out, nil
}
func makeParam(name string, t types.Type, driverPath string) paramInfo {
return paramInfo{
Name: name,
TypeStr: localName(t, driverPath),
RepoType: repoName(t, driverPath),
}
}
func makeResult(t types.Type, driverPath string) resultInfo {
ri := resultInfo{}
if sl, ok := t.(*types.Slice); ok {
ri.IsSlice = true
t = sl.Elem()
}
ri.TypeStr = localName(t, driverPath)
ri.RepoType = repoName(t, driverPath)
return ri
}
func localName(t types.Type, driverPath string) string {
named, ok := t.(*types.Named)
if !ok {
return types.TypeString(t, nil)
}
if named.Obj().Pkg() != nil && named.Obj().Pkg().Path() == driverPath {
return named.Obj().Name()
}
return types.TypeString(t, func(p *types.Package) string { return p.Name() })
}
func repoName(t types.Type, driverPath string) string {
named, ok := t.(*types.Named)
if !ok {
return ""
}
if named.Obj().Pkg() != nil && named.Obj().Pkg().Path() == driverPath {
return "repository." + named.Obj().Name()
}
return ""
}
// converterFn maps a type name to its converter function name: "Session" → "sessionToRepo".
func converterFn(s string) string {
if s == "" {
return ""
}
return strings.ToLower(s[:1]) + s[1:] + "ToRepo"
}
// renderedMethod holds pre-built signature and body strings passed to the template.
type renderedMethod struct {
Signature string
Body string
}
func renderMethods(methods []methodInfo) []renderedMethod {
out := make([]renderedMethod, len(methods))
for i, m := range methods {
out[i] = renderedMethod{
Signature: buildSig(m),
Body: buildBody(m),
}
}
return out
}
func buildSig(m methodInfo) string {
var sb strings.Builder
sb.WriteString("func (s *Store) ")
sb.WriteString(m.Name)
sb.WriteString("(ctx context.Context")
for _, p := range m.Params {
sb.WriteString(", ")
sb.WriteString(p.Name)
sb.WriteString(" ")
if p.RepoType != "" {
sb.WriteString(p.RepoType)
} else {
sb.WriteString(p.TypeStr)
}
}
sb.WriteString(") (")
for _, r := range m.Results {
if r.IsSlice {
sb.WriteString("[]")
}
if r.RepoType != "" {
sb.WriteString(r.RepoType)
} else {
sb.WriteString(r.TypeStr)
}
sb.WriteString(", ")
}
sb.WriteString("error)")
return sb.String()
}
func callArgs(m methodInfo) string {
args := make([]string, 0, len(m.Params))
for _, p := range m.Params {
if p.RepoType != "" {
// convert repo type → driver type: DriverType(arg)
args = append(args, p.TypeStr+"("+p.Name+")")
} else {
args = append(args, p.Name)
}
}
if len(args) == 0 {
return "ctx"
}
return "ctx, " + strings.Join(args, ", ")
}
// bodyTemplates holds the per-shape method body templates, parsed once at init.
var bodyTemplates = template.Must(
template.New("bodies").Parse(`
{{define "void"}} return mapErr({{.Call}})
{{end}}
{{define "scalar"}} r, err := {{.Call}}
if err != nil {
return {{.RepoType}}{}, mapErr(err)
}
return {{.Converter}}(r), nil
{{end}}
{{define "slice"}} rows, err := {{.Call}}
if err != nil {
return nil, mapErr(err)
}
out := make([]{{.RepoType}}, len(rows))
for i, row := range rows {
out[i] = {{.Converter}}(row)
}
return out, nil
{{end}}`),
)
type bodyData struct {
Call string
RepoType string
Converter string
}
func buildBody(m methodInfo) string {
call := "s.q." + m.Name + "(" + callArgs(m) + ")"
var (
name string
data bodyData
)
switch {
case len(m.Results) == 0 || m.Results[0].RepoType == "":
name = "void"
data = bodyData{Call: call}
case m.Results[0].IsSlice:
name = "slice"
data = bodyData{Call: call, RepoType: m.Results[0].RepoType, Converter: converterFn(m.Results[0].TypeStr)}
default:
name = "scalar"
data = bodyData{Call: call, RepoType: m.Results[0].RepoType, Converter: converterFn(m.Results[0].TypeStr)}
}
var buf bytes.Buffer
if err := bodyTemplates.ExecuteTemplate(&buf, name, data); err != nil {
panic(fmt.Sprintf("buildBody %s: %v", name, err))
}
return buf.String()
}
type tmplData struct {
PkgName string
RepoPkg string
ModelTypes []string
Methods []renderedMethod
}
func render(data tmplData) ([]byte, error) {
t, err := template.New("store").Funcs(template.FuncMap{
"converterFn": converterFn,
}).Parse(storeSrc)
if err != nil {
return nil, fmt.Errorf("parse template: %w", err)
}
var buf bytes.Buffer
if err := t.Execute(&buf, data); err != nil {
return nil, fmt.Errorf("execute template: %w", err)
}
formatted, err := format.Source(buf.Bytes())
if err != nil {
return buf.Bytes(), fmt.Errorf("format source: %w\nraw:\n%s", err, buf.String())
}
return formatted, nil
}
+46
View File
@@ -0,0 +1,46 @@
// Code generated by cmd/gen/sqlc-wrapper. DO NOT EDIT.
package {{.PkgName}}
import (
"context"
"database/sql"
"errors"
"{{.RepoPkg}}"
)
// Store wraps *Queries and implements repository.Store.
type Store struct {
q *Queries
}
// NewStore wraps a *Queries to satisfy repository.Store.
func NewStore(q *Queries) repository.Store {
return &Store{q: q}
}
var errMap = []struct {
from error
to error
}{
{sql.ErrNoRows, repository.ErrNotFound},
}
func mapErr(err error) error {
for _, e := range errMap {
if errors.Is(err, e.from) {
return e.to
}
}
return err
}
{{range .ModelTypes -}}
func {{converterFn .}}(v {{.}}) repository.{{.}} {
return repository.{{.}}(v)
}
{{end -}}
{{range .Methods}}{{.Signature}} {
{{.Body}}}
{{end}}
+2
View File
@@ -20,6 +20,7 @@ require (
github.com/weppos/publicsuffix-go v0.50.3
golang.org/x/crypto v0.50.0
golang.org/x/oauth2 v0.36.0
golang.org/x/tools v0.43.0
gotest.tools/v3 v3.5.2
k8s.io/apimachinery v0.32.2
k8s.io/client-go v0.32.2
@@ -124,6 +125,7 @@ require (
go.opentelemetry.io/otel/trace v1.43.0 // indirect
golang.org/x/arch v0.22.0 // indirect
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect
golang.org/x/mod v0.34.0 // indirect
golang.org/x/net v0.52.0 // indirect
golang.org/x/sync v0.20.0 // indirect
golang.org/x/sys v0.43.0 // indirect
+1 -1
View File
@@ -11,5 +11,5 @@ var FrontendAssets embed.FS
// Migrations
//
//go:embed migrations/*.sql
//go:embed migrations/sqlite/*.sql
var Migrations embed.FS
+4 -7
View File
@@ -130,17 +130,14 @@ func (app *BootstrapApp) Setup() error {
tlog.App.Trace().Str("redirectCookieName", app.context.redirectCookieName).Msg("Redirect cookie name")
// Database
db, err := app.SetupDatabase(app.config.Database.Path)
store, err := app.SetupStore()
if err != nil {
return fmt.Errorf("failed to setup database: %w", err)
}
// Queries
queries := repository.New(db)
// Services
services, err := app.initServices(queries)
services, err := app.initServices(store)
if err != nil {
return fmt.Errorf("failed to initialize services: %w", err)
@@ -196,7 +193,7 @@ func (app *BootstrapApp) Setup() error {
// Start db cleanup routine
tlog.App.Debug().Msg("Starting database cleanup routine")
go app.dbCleanupRoutine(queries)
go app.dbCleanupRoutine(store)
// If analytics are not disabled, start heartbeat
if app.config.Analytics.Enabled {
@@ -286,7 +283,7 @@ func (app *BootstrapApp) heartbeatRoutine() {
}
}
func (app *BootstrapApp) dbCleanupRoutine(queries *repository.Queries) {
func (app *BootstrapApp) dbCleanupRoutine(queries repository.Store) {
ticker := time.NewTicker(time.Duration(30) * time.Minute)
defer ticker.Stop()
ctx := context.Background()
+17 -3
View File
@@ -7,6 +7,9 @@ import (
"path/filepath"
"github.com/tinyauthapp/tinyauth/internal/assets"
"github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
"github.com/tinyauthapp/tinyauth/internal/repository/sqlite"
"github.com/golang-migrate/migrate/v4"
"github.com/golang-migrate/migrate/v4/database/sqlite3"
@@ -14,7 +17,18 @@ import (
_ "modernc.org/sqlite"
)
func (app *BootstrapApp) SetupDatabase(databasePath string) (*sql.DB, error) {
func (app *BootstrapApp) SetupStore() (repository.Store, error) {
switch app.config.Database.Driver {
case "memory":
return memory.New(), nil
case "sqlite", "":
return app.setupSQLite(app.config.Database.Path)
default:
return nil, fmt.Errorf("unknown database driver %q: valid values are sqlite, memory", app.config.Database.Driver)
}
}
func (app *BootstrapApp) setupSQLite(databasePath string) (repository.Store, error) {
dir := filepath.Dir(databasePath)
if err := os.MkdirAll(dir, 0750); err != nil {
@@ -31,7 +45,7 @@ func (app *BootstrapApp) SetupDatabase(databasePath string) (*sql.DB, error) {
// if the sqlite connection starts being a bottleneck
db.SetMaxOpenConns(1)
migrations, err := iofs.New(assets.Migrations, "migrations")
migrations, err := iofs.New(assets.Migrations, "migrations/sqlite")
if err != nil {
return nil, fmt.Errorf("failed to create migrations: %w", err)
@@ -53,5 +67,5 @@ func (app *BootstrapApp) SetupDatabase(databasePath string) (*sql.DB, error) {
return nil, fmt.Errorf("failed to migrate database: %w", err)
}
return db, nil
return sqlite.NewStore(sqlite.New(db)), nil
}
+1 -1
View File
@@ -18,7 +18,7 @@ type Services struct {
oidcService *service.OIDCService
}
func (app *BootstrapApp) initServices(queries *repository.Queries) (Services, error) {
func (app *BootstrapApp) initServices(queries repository.Store) (Services, error) {
services := Services{}
ldapService := service.NewLdapService(service.LdapServiceConfig{
+3 -1
View File
@@ -4,6 +4,7 @@ package config
func NewDefaultConfiguration() *Config {
return &Config{
Database: DatabaseConfig{
Driver: "sqlite",
Path: "./tinyauth.db",
},
Analytics: AnalyticsConfig{
@@ -95,7 +96,8 @@ type Config struct {
}
type DatabaseConfig struct {
Path string `description:"The path to the database, including file name." yaml:"path"`
Driver string `description:"The database driver to use. Valid values: sqlite, memory." yaml:"driver"`
Path string `description:"The path to the SQLite database, including file name. Only used when driver is sqlite." yaml:"path"`
}
type AnalyticsConfig struct {
+4 -14
View File
@@ -12,10 +12,9 @@ import (
"github.com/gin-gonic/gin"
"github.com/google/go-querystring/query"
"github.com/tinyauthapp/tinyauth/internal/bootstrap"
"github.com/tinyauthapp/tinyauth/internal/config"
"github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
"github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/stretchr/testify/assert"
@@ -848,14 +847,10 @@ func TestOIDCController(t *testing.T) {
},
}
app := bootstrap.NewBootstrapApp(config.Config{})
store := memory.New()
db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db"))
require.NoError(t, err)
queries := repository.New(db)
oidcService := service.NewOIDCService(oidcServiceCfg, queries)
err = oidcService.Init()
oidcService := service.NewOIDCService(oidcServiceCfg, store)
err := oidcService.Init()
require.NoError(t, err)
for _, test := range tests {
@@ -877,9 +872,4 @@ func TestOIDCController(t *testing.T) {
test.run(t, router, recorder)
})
}
t.Cleanup(func() {
err = db.Close()
require.NoError(t, err)
})
}
+4 -17
View File
@@ -2,14 +2,12 @@ package controller_test
import (
"net/http/httptest"
"path"
"testing"
"github.com/gin-gonic/gin"
"github.com/tinyauthapp/tinyauth/internal/bootstrap"
"github.com/tinyauthapp/tinyauth/internal/config"
"github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
"github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/stretchr/testify/assert"
@@ -18,7 +16,6 @@ import (
func TestProxyController(t *testing.T) {
tlog.NewTestLogger().Init()
tempDir := t.TempDir()
authServiceCfg := service.AuthServiceConfig{
Users: []config.User{
@@ -393,15 +390,10 @@ func TestProxyController(t *testing.T) {
oauthBrokerCfgs := make(map[string]config.OAuthServiceConfig)
app := bootstrap.NewBootstrapApp(config.Config{})
db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db"))
require.NoError(t, err)
queries := repository.New(db)
store := memory.New()
docker := service.NewDockerService()
err = docker.Init()
err := docker.Init()
require.NoError(t, err)
ldap := service.NewLdapService(service.LdapServiceConfig{})
@@ -412,7 +404,7 @@ func TestProxyController(t *testing.T) {
err = broker.Init()
require.NoError(t, err)
authService := service.NewAuthService(authServiceCfg, ldap, queries, broker)
authService := service.NewAuthService(authServiceCfg, ldap, store, broker)
err = authService.Init()
require.NoError(t, err)
@@ -437,9 +429,4 @@ func TestProxyController(t *testing.T) {
test.run(t, router, recorder)
})
}
t.Cleanup(func() {
err = db.Close()
require.NoError(t, err)
})
}
+4 -17
View File
@@ -3,17 +3,15 @@ package controller_test
import (
"encoding/json"
"net/http/httptest"
"path"
"strings"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/pquerna/otp/totp"
"github.com/tinyauthapp/tinyauth/internal/bootstrap"
"github.com/tinyauthapp/tinyauth/internal/config"
"github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
"github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/stretchr/testify/assert"
@@ -22,7 +20,6 @@ import (
func TestUserController(t *testing.T) {
tlog.NewTestLogger().Init()
tempDir := t.TempDir()
authServiceCfg := service.AuthServiceConfig{
Users: []config.User{
@@ -351,15 +348,10 @@ func TestUserController(t *testing.T) {
oauthBrokerCfgs := make(map[string]config.OAuthServiceConfig)
app := bootstrap.NewBootstrapApp(config.Config{})
db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db"))
require.NoError(t, err)
queries := repository.New(db)
store := memory.New()
docker := service.NewDockerService()
err = docker.Init()
err := docker.Init()
require.NoError(t, err)
ldap := service.NewLdapService(service.LdapServiceConfig{})
@@ -370,7 +362,7 @@ func TestUserController(t *testing.T) {
err = broker.Init()
require.NoError(t, err)
authService := service.NewAuthService(authServiceCfg, ldap, queries, broker)
authService := service.NewAuthService(authServiceCfg, ldap, store, broker)
err = authService.Init()
require.NoError(t, err)
@@ -435,9 +427,4 @@ func TestUserController(t *testing.T) {
test.run(t, router, recorder)
})
}
t.Cleanup(func() {
err = db.Close()
require.NoError(t, err)
})
}
@@ -8,10 +8,9 @@ import (
"testing"
"github.com/gin-gonic/gin"
"github.com/tinyauthapp/tinyauth/internal/bootstrap"
"github.com/tinyauthapp/tinyauth/internal/config"
"github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
"github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/stretchr/testify/assert"
@@ -101,15 +100,10 @@ func TestWellKnownController(t *testing.T) {
},
}
app := bootstrap.NewBootstrapApp(config.Config{})
store := memory.New()
db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db"))
require.NoError(t, err)
queries := repository.New(db)
oidcService := service.NewOIDCService(oidcServiceCfg, queries)
err = oidcService.Init()
oidcService := service.NewOIDCService(oidcServiceCfg, store)
err := oidcService.Init()
require.NoError(t, err)
for _, test := range tests {
@@ -125,9 +119,4 @@ func TestWellKnownController(t *testing.T) {
test.run(t, router, recorder)
})
}
t.Cleanup(func() {
err = db.Close()
require.NoError(t, err)
})
}
+241
View File
@@ -0,0 +1,241 @@
package memory
import (
"context"
"fmt"
"github.com/tinyauthapp/tinyauth/internal/repository"
)
func (s *Store) CreateOidcCode(_ context.Context, arg repository.CreateOidcCodeParams) (repository.OidcCode, error) {
s.mu.Lock()
defer s.mu.Unlock()
// Enforce sub UNIQUE constraint
for _, c := range s.oidcCodes {
if c.Sub == arg.Sub {
return repository.OidcCode{}, fmt.Errorf("UNIQUE constraint failed: oidc_codes.sub")
}
}
code := repository.OidcCode(arg)
s.oidcCodes[arg.CodeHash] = code
return code, nil
}
// GetOidcCode is a destructive read: it deletes and returns the code (mirrors SQLite's DELETE...RETURNING).
func (s *Store) GetOidcCode(_ context.Context, codeHash string) (repository.OidcCode, error) {
s.mu.Lock()
defer s.mu.Unlock()
c, ok := s.oidcCodes[codeHash]
if !ok {
return repository.OidcCode{}, repository.ErrNotFound
}
delete(s.oidcCodes, codeHash)
return c, nil
}
// GetOidcCodeBySub is a destructive read: it deletes and returns the code (mirrors SQLite's DELETE...RETURNING).
func (s *Store) GetOidcCodeBySub(_ context.Context, sub string) (repository.OidcCode, error) {
s.mu.Lock()
defer s.mu.Unlock()
for k, c := range s.oidcCodes {
if c.Sub == sub {
delete(s.oidcCodes, k)
return c, nil
}
}
return repository.OidcCode{}, repository.ErrNotFound
}
// GetOidcCodeUnsafe is a non-destructive read (mirrors SQLite's SELECT).
func (s *Store) GetOidcCodeUnsafe(_ context.Context, codeHash string) (repository.OidcCode, error) {
s.mu.RLock()
defer s.mu.RUnlock()
c, ok := s.oidcCodes[codeHash]
if !ok {
return repository.OidcCode{}, repository.ErrNotFound
}
return c, nil
}
// GetOidcCodeBySubUnsafe is a non-destructive read (mirrors SQLite's SELECT).
func (s *Store) GetOidcCodeBySubUnsafe(_ context.Context, sub string) (repository.OidcCode, error) {
s.mu.RLock()
defer s.mu.RUnlock()
for _, c := range s.oidcCodes {
if c.Sub == sub {
return c, nil
}
}
return repository.OidcCode{}, repository.ErrNotFound
}
func (s *Store) DeleteOidcCode(_ context.Context, codeHash string) error {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.oidcCodes, codeHash)
return nil
}
func (s *Store) DeleteOidcCodeBySub(_ context.Context, sub string) error {
s.mu.Lock()
defer s.mu.Unlock()
for k, c := range s.oidcCodes {
if c.Sub == sub {
delete(s.oidcCodes, k)
}
}
return nil
}
func (s *Store) DeleteExpiredOidcCodes(_ context.Context, expiresAt int64) ([]repository.OidcCode, error) {
s.mu.Lock()
defer s.mu.Unlock()
var deleted []repository.OidcCode
for k, c := range s.oidcCodes {
if c.ExpiresAt < expiresAt {
deleted = append(deleted, c)
delete(s.oidcCodes, k)
}
}
return deleted, nil
}
func (s *Store) CreateOidcToken(_ context.Context, arg repository.CreateOidcTokenParams) (repository.OidcToken, error) {
s.mu.Lock()
defer s.mu.Unlock()
// Enforce sub UNIQUE constraint
for _, t := range s.oidcTokens {
if t.Sub == arg.Sub {
return repository.OidcToken{}, fmt.Errorf("UNIQUE constraint failed: oidc_tokens.sub")
}
}
tok := repository.OidcToken{
Sub: arg.Sub,
AccessTokenHash: arg.AccessTokenHash,
RefreshTokenHash: arg.RefreshTokenHash,
CodeHash: arg.CodeHash,
Scope: arg.Scope,
ClientID: arg.ClientID,
TokenExpiresAt: arg.TokenExpiresAt,
RefreshTokenExpiresAt: arg.RefreshTokenExpiresAt,
Nonce: arg.Nonce,
}
s.oidcTokens[arg.AccessTokenHash] = tok
return tok, nil
}
func (s *Store) GetOidcToken(_ context.Context, accessTokenHash string) (repository.OidcToken, error) {
s.mu.RLock()
defer s.mu.RUnlock()
t, ok := s.oidcTokens[accessTokenHash]
if !ok {
return repository.OidcToken{}, repository.ErrNotFound
}
return t, nil
}
func (s *Store) GetOidcTokenByRefreshToken(_ context.Context, refreshTokenHash string) (repository.OidcToken, error) {
s.mu.RLock()
defer s.mu.RUnlock()
for _, t := range s.oidcTokens {
if t.RefreshTokenHash == refreshTokenHash {
return t, nil
}
}
return repository.OidcToken{}, repository.ErrNotFound
}
func (s *Store) GetOidcTokenBySub(_ context.Context, sub string) (repository.OidcToken, error) {
s.mu.RLock()
defer s.mu.RUnlock()
for _, t := range s.oidcTokens {
if t.Sub == sub {
return t, nil
}
}
return repository.OidcToken{}, repository.ErrNotFound
}
func (s *Store) UpdateOidcTokenByRefreshToken(_ context.Context, arg repository.UpdateOidcTokenByRefreshTokenParams) (repository.OidcToken, error) {
s.mu.Lock()
defer s.mu.Unlock()
for k, t := range s.oidcTokens {
if t.RefreshTokenHash == arg.RefreshTokenHash_2 {
delete(s.oidcTokens, k)
t.AccessTokenHash = arg.AccessTokenHash
t.RefreshTokenHash = arg.RefreshTokenHash
t.TokenExpiresAt = arg.TokenExpiresAt
t.RefreshTokenExpiresAt = arg.RefreshTokenExpiresAt
s.oidcTokens[arg.AccessTokenHash] = t
return t, nil
}
}
return repository.OidcToken{}, repository.ErrNotFound
}
func (s *Store) DeleteOidcToken(_ context.Context, accessTokenHash string) error {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.oidcTokens, accessTokenHash)
return nil
}
func (s *Store) DeleteOidcTokenBySub(_ context.Context, sub string) error {
s.mu.Lock()
defer s.mu.Unlock()
for k, t := range s.oidcTokens {
if t.Sub == sub {
delete(s.oidcTokens, k)
}
}
return nil
}
func (s *Store) DeleteOidcTokenByCodeHash(_ context.Context, codeHash string) error {
s.mu.Lock()
defer s.mu.Unlock()
for k, t := range s.oidcTokens {
if t.CodeHash == codeHash {
delete(s.oidcTokens, k)
}
}
return nil
}
func (s *Store) DeleteExpiredOidcTokens(_ context.Context, arg repository.DeleteExpiredOidcTokensParams) ([]repository.OidcToken, error) {
s.mu.Lock()
defer s.mu.Unlock()
var deleted []repository.OidcToken
for k, t := range s.oidcTokens {
if t.TokenExpiresAt < arg.TokenExpiresAt || t.RefreshTokenExpiresAt < arg.RefreshTokenExpiresAt {
deleted = append(deleted, t)
delete(s.oidcTokens, k)
}
}
return deleted, nil
}
func (s *Store) CreateOidcUserInfo(_ context.Context, arg repository.CreateOidcUserInfoParams) (repository.OidcUserinfo, error) {
s.mu.Lock()
defer s.mu.Unlock()
u := repository.OidcUserinfo(arg)
s.oidcUsers[arg.Sub] = u
return u, nil
}
func (s *Store) GetOidcUserInfo(_ context.Context, sub string) (repository.OidcUserinfo, error) {
s.mu.RLock()
defer s.mu.RUnlock()
u, ok := s.oidcUsers[sub]
if !ok {
return repository.OidcUserinfo{}, repository.ErrNotFound
}
return u, nil
}
func (s *Store) DeleteOidcUserInfo(_ context.Context, sub string) error {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.oidcUsers, sub)
return nil
}
@@ -0,0 +1,63 @@
package memory
import (
"context"
"github.com/tinyauthapp/tinyauth/internal/repository"
)
func (s *Store) CreateSession(_ context.Context, arg repository.CreateSessionParams) (repository.Session, error) {
s.mu.Lock()
defer s.mu.Unlock()
sess := repository.Session(arg)
s.sessions[arg.UUID] = sess
return sess, nil
}
func (s *Store) GetSession(_ context.Context, uuid string) (repository.Session, error) {
s.mu.RLock()
defer s.mu.RUnlock()
sess, ok := s.sessions[uuid]
if !ok {
return repository.Session{}, repository.ErrNotFound
}
return sess, nil
}
func (s *Store) UpdateSession(_ context.Context, arg repository.UpdateSessionParams) (repository.Session, error) {
s.mu.Lock()
defer s.mu.Unlock()
sess, ok := s.sessions[arg.UUID]
if !ok {
return repository.Session{}, repository.ErrNotFound
}
sess.Username = arg.Username
sess.Email = arg.Email
sess.Name = arg.Name
sess.Provider = arg.Provider
sess.TotpPending = arg.TotpPending
sess.OAuthGroups = arg.OAuthGroups
sess.Expiry = arg.Expiry
sess.OAuthName = arg.OAuthName
sess.OAuthSub = arg.OAuthSub
s.sessions[arg.UUID] = sess
return sess, nil
}
func (s *Store) DeleteSession(_ context.Context, uuid string) error {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.sessions, uuid)
return nil
}
func (s *Store) DeleteExpiredSessions(_ context.Context, expiry int64) error {
s.mu.Lock()
defer s.mu.Unlock()
for k, v := range s.sessions {
if v.Expiry < expiry {
delete(s.sessions, k)
}
}
return nil
}
+27
View File
@@ -0,0 +1,27 @@
// Package memory provides an in-memory implementation of repository.Store for use in tests.
package memory
import (
"sync"
"github.com/tinyauthapp/tinyauth/internal/repository"
)
// Store is a thread-safe in-memory implementation of repository.Store.
type Store struct {
mu sync.RWMutex
sessions map[string]repository.Session
oidcCodes map[string]repository.OidcCode
oidcTokens map[string]repository.OidcToken
oidcUsers map[string]repository.OidcUserinfo
}
// New returns a new empty in-memory Store.
func New() repository.Store {
return &Store{
sessions: make(map[string]repository.Session),
oidcCodes: make(map[string]repository.OidcCode),
oidcTokens: make(map[string]repository.OidcToken),
oidcUsers: make(map[string]repository.OidcUserinfo),
}
}
+89 -5
View File
@@ -1,9 +1,22 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.30.0
package repository
// Shared model and parameter types for all storage drivers.
// sqlc-generated driver packages use these via the conversion layer in their store.go.
type Session struct {
UUID string
Username string
Email string
Name string
Provider string
TotpPending bool
OAuthGroups string
Expiry int64
CreatedAt int64
OAuthName string
OAuthSub string
}
type OidcCode struct {
Sub string
CodeHash string
@@ -49,7 +62,7 @@ type OidcUserinfo struct {
Address string
}
type Session struct {
type CreateSessionParams struct {
UUID string
Username string
Email string
@@ -62,3 +75,74 @@ type Session struct {
OAuthName string
OAuthSub string
}
type UpdateSessionParams struct {
Username string
Email string
Name string
Provider string
TotpPending bool
OAuthGroups string
Expiry int64
OAuthName string
OAuthSub string
UUID string
}
type CreateOidcCodeParams struct {
Sub string
CodeHash string
Scope string
RedirectURI string
ClientID string
ExpiresAt int64
Nonce string
CodeChallenge string
}
type CreateOidcTokenParams struct {
Sub string
AccessTokenHash string
RefreshTokenHash string
Scope string
ClientID string
TokenExpiresAt int64
RefreshTokenExpiresAt int64
CodeHash string
Nonce string
}
type UpdateOidcTokenByRefreshTokenParams struct {
AccessTokenHash string
RefreshTokenHash string
TokenExpiresAt int64
RefreshTokenExpiresAt int64
RefreshTokenHash_2 string
}
type DeleteExpiredOidcTokensParams struct {
TokenExpiresAt int64
RefreshTokenExpiresAt int64
}
type CreateOidcUserInfoParams struct {
Sub string
Name string
PreferredUsername string
Email string
Groups string
UpdatedAt int64
GivenName string
FamilyName string
MiddleName string
Nickname string
Profile string
Picture string
Website string
Gender string
Birthdate string
Zoneinfo string
Locale string
PhoneNumber string
Address string
}
@@ -1,8 +1,8 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.30.0
// sqlc v1.31.1
package repository
package sqlite
import (
"context"
+3
View File
@@ -0,0 +1,3 @@
package sqlite
//go:generate go run github.com/tinyauthapp/tinyauth/cmd/gen/sqlc-wrapper -pkg github.com/tinyauthapp/tinyauth/internal/repository/sqlite
+64
View File
@@ -0,0 +1,64 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.31.1
package sqlite
type OidcCode struct {
Sub string
CodeHash string
Scope string
RedirectURI string
ClientID string
ExpiresAt int64
Nonce string
CodeChallenge string
}
type OidcToken struct {
Sub string
AccessTokenHash string
RefreshTokenHash string
CodeHash string
Scope string
ClientID string
TokenExpiresAt int64
RefreshTokenExpiresAt int64
Nonce string
}
type OidcUserinfo struct {
Sub string
Name string
PreferredUsername string
Email string
Groups string
UpdatedAt int64
GivenName string
FamilyName string
MiddleName string
Nickname string
Profile string
Picture string
Website string
Gender string
Birthdate string
Zoneinfo string
Locale string
PhoneNumber string
Address string
}
type Session struct {
UUID string
Username string
Email string
Name string
Provider string
TotpPending bool
OAuthGroups string
Expiry int64
CreatedAt int64
OAuthName string
OAuthSub string
}
@@ -1,9 +1,9 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.30.0
// sqlc v1.31.1
// source: oidc_queries.sql
package repository
package sqlite
import (
"context"
@@ -1,9 +1,9 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.30.0
// sqlc v1.31.1
// source: session_queries.sql
package repository
package sqlite
import (
"context"
+224
View File
@@ -0,0 +1,224 @@
// Code generated by cmd/gen/sqlc-wrapper. DO NOT EDIT.
package sqlite
import (
"context"
"database/sql"
"errors"
"github.com/tinyauthapp/tinyauth/internal/repository"
)
// Store wraps *Queries and implements repository.Store.
type Store struct {
q *Queries
}
// NewStore wraps a *Queries to satisfy repository.Store.
func NewStore(q *Queries) repository.Store {
return &Store{q: q}
}
var errMap = []struct {
from error
to error
}{
{sql.ErrNoRows, repository.ErrNotFound},
}
func mapErr(err error) error {
for _, e := range errMap {
if errors.Is(err, e.from) {
return e.to
}
}
return err
}
func oidcCodeToRepo(v OidcCode) repository.OidcCode {
return repository.OidcCode(v)
}
func oidcTokenToRepo(v OidcToken) repository.OidcToken {
return repository.OidcToken(v)
}
func oidcUserinfoToRepo(v OidcUserinfo) repository.OidcUserinfo {
return repository.OidcUserinfo(v)
}
func sessionToRepo(v Session) repository.Session {
return repository.Session(v)
}
func (s *Store) CreateOidcCode(ctx context.Context, arg repository.CreateOidcCodeParams) (repository.OidcCode, error) {
r, err := s.q.CreateOidcCode(ctx, CreateOidcCodeParams(arg))
if err != nil {
return repository.OidcCode{}, mapErr(err)
}
return oidcCodeToRepo(r), nil
}
func (s *Store) CreateOidcToken(ctx context.Context, arg repository.CreateOidcTokenParams) (repository.OidcToken, error) {
r, err := s.q.CreateOidcToken(ctx, CreateOidcTokenParams(arg))
if err != nil {
return repository.OidcToken{}, mapErr(err)
}
return oidcTokenToRepo(r), nil
}
func (s *Store) CreateOidcUserInfo(ctx context.Context, arg repository.CreateOidcUserInfoParams) (repository.OidcUserinfo, error) {
r, err := s.q.CreateOidcUserInfo(ctx, CreateOidcUserInfoParams(arg))
if err != nil {
return repository.OidcUserinfo{}, mapErr(err)
}
return oidcUserinfoToRepo(r), nil
}
func (s *Store) CreateSession(ctx context.Context, arg repository.CreateSessionParams) (repository.Session, error) {
r, err := s.q.CreateSession(ctx, CreateSessionParams(arg))
if err != nil {
return repository.Session{}, mapErr(err)
}
return sessionToRepo(r), nil
}
func (s *Store) DeleteExpiredOidcCodes(ctx context.Context, expiresAt int64) ([]repository.OidcCode, error) {
rows, err := s.q.DeleteExpiredOidcCodes(ctx, expiresAt)
if err != nil {
return nil, mapErr(err)
}
out := make([]repository.OidcCode, len(rows))
for i, row := range rows {
out[i] = oidcCodeToRepo(row)
}
return out, nil
}
func (s *Store) DeleteExpiredOidcTokens(ctx context.Context, arg repository.DeleteExpiredOidcTokensParams) ([]repository.OidcToken, error) {
rows, err := s.q.DeleteExpiredOidcTokens(ctx, DeleteExpiredOidcTokensParams(arg))
if err != nil {
return nil, mapErr(err)
}
out := make([]repository.OidcToken, len(rows))
for i, row := range rows {
out[i] = oidcTokenToRepo(row)
}
return out, nil
}
func (s *Store) DeleteExpiredSessions(ctx context.Context, expiry int64) error {
return mapErr(s.q.DeleteExpiredSessions(ctx, expiry))
}
func (s *Store) DeleteOidcCode(ctx context.Context, codeHash string) error {
return mapErr(s.q.DeleteOidcCode(ctx, codeHash))
}
func (s *Store) DeleteOidcCodeBySub(ctx context.Context, sub string) error {
return mapErr(s.q.DeleteOidcCodeBySub(ctx, sub))
}
func (s *Store) DeleteOidcToken(ctx context.Context, accessTokenHash string) error {
return mapErr(s.q.DeleteOidcToken(ctx, accessTokenHash))
}
func (s *Store) DeleteOidcTokenByCodeHash(ctx context.Context, codeHash string) error {
return mapErr(s.q.DeleteOidcTokenByCodeHash(ctx, codeHash))
}
func (s *Store) DeleteOidcTokenBySub(ctx context.Context, sub string) error {
return mapErr(s.q.DeleteOidcTokenBySub(ctx, sub))
}
func (s *Store) DeleteOidcUserInfo(ctx context.Context, sub string) error {
return mapErr(s.q.DeleteOidcUserInfo(ctx, sub))
}
func (s *Store) DeleteSession(ctx context.Context, uuid string) error {
return mapErr(s.q.DeleteSession(ctx, uuid))
}
func (s *Store) GetOidcCode(ctx context.Context, codeHash string) (repository.OidcCode, error) {
r, err := s.q.GetOidcCode(ctx, codeHash)
if err != nil {
return repository.OidcCode{}, mapErr(err)
}
return oidcCodeToRepo(r), nil
}
func (s *Store) GetOidcCodeBySub(ctx context.Context, sub string) (repository.OidcCode, error) {
r, err := s.q.GetOidcCodeBySub(ctx, sub)
if err != nil {
return repository.OidcCode{}, mapErr(err)
}
return oidcCodeToRepo(r), nil
}
func (s *Store) GetOidcCodeBySubUnsafe(ctx context.Context, sub string) (repository.OidcCode, error) {
r, err := s.q.GetOidcCodeBySubUnsafe(ctx, sub)
if err != nil {
return repository.OidcCode{}, mapErr(err)
}
return oidcCodeToRepo(r), nil
}
func (s *Store) GetOidcCodeUnsafe(ctx context.Context, codeHash string) (repository.OidcCode, error) {
r, err := s.q.GetOidcCodeUnsafe(ctx, codeHash)
if err != nil {
return repository.OidcCode{}, mapErr(err)
}
return oidcCodeToRepo(r), nil
}
func (s *Store) GetOidcToken(ctx context.Context, accessTokenHash string) (repository.OidcToken, error) {
r, err := s.q.GetOidcToken(ctx, accessTokenHash)
if err != nil {
return repository.OidcToken{}, mapErr(err)
}
return oidcTokenToRepo(r), nil
}
func (s *Store) GetOidcTokenByRefreshToken(ctx context.Context, refreshTokenHash string) (repository.OidcToken, error) {
r, err := s.q.GetOidcTokenByRefreshToken(ctx, refreshTokenHash)
if err != nil {
return repository.OidcToken{}, mapErr(err)
}
return oidcTokenToRepo(r), nil
}
func (s *Store) GetOidcTokenBySub(ctx context.Context, sub string) (repository.OidcToken, error) {
r, err := s.q.GetOidcTokenBySub(ctx, sub)
if err != nil {
return repository.OidcToken{}, mapErr(err)
}
return oidcTokenToRepo(r), nil
}
func (s *Store) GetOidcUserInfo(ctx context.Context, sub string) (repository.OidcUserinfo, error) {
r, err := s.q.GetOidcUserInfo(ctx, sub)
if err != nil {
return repository.OidcUserinfo{}, mapErr(err)
}
return oidcUserinfoToRepo(r), nil
}
func (s *Store) GetSession(ctx context.Context, uuid string) (repository.Session, error) {
r, err := s.q.GetSession(ctx, uuid)
if err != nil {
return repository.Session{}, mapErr(err)
}
return sessionToRepo(r), nil
}
func (s *Store) UpdateOidcTokenByRefreshToken(ctx context.Context, arg repository.UpdateOidcTokenByRefreshTokenParams) (repository.OidcToken, error) {
r, err := s.q.UpdateOidcTokenByRefreshToken(ctx, UpdateOidcTokenByRefreshTokenParams(arg))
if err != nil {
return repository.OidcToken{}, mapErr(err)
}
return oidcTokenToRepo(r), nil
}
func (s *Store) UpdateSession(ctx context.Context, arg repository.UpdateSessionParams) (repository.Session, error) {
r, err := s.q.UpdateSession(ctx, UpdateSessionParams(arg))
if err != nil {
return repository.Session{}, mapErr(err)
}
return sessionToRepo(r), nil
}
+47
View File
@@ -0,0 +1,47 @@
package repository
import (
"context"
"errors"
)
// ErrNotFound is returned by Store methods when the requested record does not exist.
var ErrNotFound = errors.New("not found")
// Store is the interface that all storage drivers must implement.
// The sqlc-generated *Queries struct satisfies this interface for SQLite.
// Future drivers (postgres, etc.) must return the shared types defined in this package.
type Store interface {
// Sessions
CreateSession(ctx context.Context, arg CreateSessionParams) (Session, error)
GetSession(ctx context.Context, uuid string) (Session, error)
UpdateSession(ctx context.Context, arg UpdateSessionParams) (Session, error)
DeleteSession(ctx context.Context, uuid string) error
DeleteExpiredSessions(ctx context.Context, expiry int64) error
// OIDC codes
CreateOidcCode(ctx context.Context, arg CreateOidcCodeParams) (OidcCode, error)
GetOidcCode(ctx context.Context, codeHash string) (OidcCode, error)
GetOidcCodeBySub(ctx context.Context, sub string) (OidcCode, error)
GetOidcCodeUnsafe(ctx context.Context, codeHash string) (OidcCode, error)
GetOidcCodeBySubUnsafe(ctx context.Context, sub string) (OidcCode, error)
DeleteOidcCode(ctx context.Context, codeHash string) error
DeleteOidcCodeBySub(ctx context.Context, sub string) error
DeleteExpiredOidcCodes(ctx context.Context, expiresAt int64) ([]OidcCode, error)
// OIDC tokens
CreateOidcToken(ctx context.Context, arg CreateOidcTokenParams) (OidcToken, error)
GetOidcToken(ctx context.Context, accessTokenHash string) (OidcToken, error)
GetOidcTokenByRefreshToken(ctx context.Context, refreshTokenHash string) (OidcToken, error)
GetOidcTokenBySub(ctx context.Context, sub string) (OidcToken, error)
UpdateOidcTokenByRefreshToken(ctx context.Context, arg UpdateOidcTokenByRefreshTokenParams) (OidcToken, error)
DeleteOidcToken(ctx context.Context, accessTokenHash string) error
DeleteOidcTokenBySub(ctx context.Context, sub string) error
DeleteOidcTokenByCodeHash(ctx context.Context, codeHash string) error
DeleteExpiredOidcTokens(ctx context.Context, arg DeleteExpiredOidcTokensParams) ([]OidcToken, error)
// OIDC userinfo
CreateOidcUserInfo(ctx context.Context, arg CreateOidcUserInfoParams) (OidcUserinfo, error)
GetOidcUserInfo(ctx context.Context, sub string) (OidcUserinfo, error)
DeleteOidcUserInfo(ctx context.Context, sub string) error
}
+3 -4
View File
@@ -2,7 +2,6 @@ package service
import (
"context"
"database/sql"
"errors"
"fmt"
"regexp"
@@ -90,14 +89,14 @@ type AuthService struct {
loginMutex sync.RWMutex
ldapGroupsMutex sync.RWMutex
ldap *LdapService
queries *repository.Queries
queries repository.Store
oauthBroker *OAuthBrokerService
lockdown *Lockdown
lockdownCtx context.Context
lockdownCancelFunc context.CancelFunc
}
func NewAuthService(config AuthServiceConfig, ldap *LdapService, queries *repository.Queries, oauthBroker *OAuthBrokerService) *AuthService {
func NewAuthService(config AuthServiceConfig, ldap *LdapService, queries repository.Store, oauthBroker *OAuthBrokerService) *AuthService {
return &AuthService{
config: config,
loginAttempts: make(map[string]*LoginAttempt),
@@ -411,7 +410,7 @@ func (auth *AuthService) GetSessionCookie(c *gin.Context) (repository.Session, e
session, err := auth.queries.GetSession(c, cookie)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
if errors.Is(err, repository.ErrNotFound) {
return repository.Session{}, fmt.Errorf("session not found")
}
return repository.Session{}, err
+9 -10
View File
@@ -7,7 +7,6 @@ import (
"crypto/rsa"
"crypto/sha256"
"crypto/x509"
"database/sql"
"encoding/base64"
"encoding/json"
"encoding/pem"
@@ -121,7 +120,7 @@ type OIDCServiceConfig struct {
type OIDCService struct {
config OIDCServiceConfig
queries *repository.Queries
queries repository.Store
clients map[string]config.OIDCClientConfig
privateKey *rsa.PrivateKey
publicKey crypto.PublicKey
@@ -129,7 +128,7 @@ type OIDCService struct {
isConfigured bool
}
func NewOIDCService(config OIDCServiceConfig, queries *repository.Queries) *OIDCService {
func NewOIDCService(config OIDCServiceConfig, queries repository.Store) *OIDCService {
return &OIDCService{
config: config,
queries: queries,
@@ -420,7 +419,7 @@ func (service *OIDCService) GetCodeEntry(c *gin.Context, codeHash string, client
oidcCode, err := service.queries.GetOidcCode(c, codeHash)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
if errors.Is(err, repository.ErrNotFound) {
return repository.OidcCode{}, ErrCodeNotFound
}
return repository.OidcCode{}, err
@@ -564,7 +563,7 @@ func (service *OIDCService) RefreshAccessToken(c *gin.Context, refreshToken stri
entry, err := service.queries.GetOidcTokenByRefreshToken(c, service.Hash(refreshToken))
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
if errors.Is(err, repository.ErrNotFound) {
return TokenResponse{}, ErrTokenNotFound
}
return TokenResponse{}, err
@@ -643,7 +642,7 @@ func (service *OIDCService) GetAccessToken(c *gin.Context, tokenHash string) (re
entry, err := service.queries.GetOidcToken(c, tokenHash)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
if errors.Is(err, repository.ErrNotFound) {
return repository.OidcToken{}, ErrTokenNotFound
}
return repository.OidcToken{}, err
@@ -731,15 +730,15 @@ func (service *OIDCService) Hash(token string) string {
func (service *OIDCService) DeleteOldSession(ctx context.Context, sub string) error {
err := service.queries.DeleteOidcCodeBySub(ctx, sub)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
if err != nil && !errors.Is(err, repository.ErrNotFound) {
return err
}
err = service.queries.DeleteOidcTokenBySub(ctx, sub)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
if err != nil && !errors.Is(err, repository.ErrNotFound) {
return err
}
err = service.queries.DeleteOidcUserInfo(ctx, sub)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
if err != nil && !errors.Is(err, repository.ErrNotFound) {
return err
}
return nil
@@ -784,7 +783,7 @@ func (service *OIDCService) Cleanup() {
token, err := service.queries.GetOidcTokenBySub(ctx, expiredCode.Sub)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
if errors.Is(err, repository.ErrNotFound) {
continue
}
tlog.App.Warn().Err(err).Msg("Failed to get OIDC token by sub")
+4 -4
View File
@@ -1,12 +1,12 @@
version: "2"
sql:
- engine: "sqlite"
queries: "sql/*_queries.sql"
schema: "sql/*_schemas.sql"
queries: "sql/sqlite/*_queries.sql"
schema: "sql/sqlite/*_schemas.sql"
gen:
go:
package: "repository"
out: "internal/repository"
package: "sqlite"
out: "internal/repository/sqlite"
rename:
uuid: "UUID"
oauth_groups: "OAuthGroups"