From 0244f39387be47a925af2eec15eba5b8a7820b30 Mon Sep 17 00:00:00 2001 From: Scott McKendry Date: Sun, 3 May 2026 13:49:24 +1200 Subject: [PATCH] feat(db): add code gen to build sqlc-compatible wrappers --- .github/workflows/ci.yml | 6 + cmd/gen/sqlc-wrapper/main.go | 522 +++++++++++++++++++++++++ go.mod | 2 + internal/bootstrap/db_bootstrap.go | 2 +- internal/repository/models.go | 157 +++++++- internal/repository/sqlite/generate.go | 3 + internal/repository/sqlite/store.go | 206 ++++++++++ 7 files changed, 883 insertions(+), 15 deletions(-) create mode 100644 cmd/gen/sqlc-wrapper/main.go create mode 100644 internal/repository/sqlite/generate.go create mode 100644 internal/repository/sqlite/store.go diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 12db1641..fb8c9736 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -26,6 +26,12 @@ jobs: - name: Go dependencies run: go mod download + - name: Check codegen is up to date + run: | + go generate ./internal/repository/... + git diff --exit-code -- internal/repository/ + git status --porcelain -- internal/repository/ | grep -q . && echo "untracked files in internal/repository/" && exit 1 || true + - name: Install frontend dependencies run: | cd frontend diff --git a/cmd/gen/sqlc-wrapper/main.go b/cmd/gen/sqlc-wrapper/main.go new file mode 100644 index 00000000..e66ae8ee --- /dev/null +++ b/cmd/gen/sqlc-wrapper/main.go @@ -0,0 +1,522 @@ +// gen/sqlc-wrapper generates store.go wrapper files for each sqlc driver package under +// internal/repository//. Run via: +// +// go generate ./internal/repository/... +// +// The generator introspects *Queries methods and the model/params types in the +// driver package, then emits a store.go that wraps *Queries so it satisfies +// repository.Store using the canonical shared types in the parent package. +// This generator is specific to sqlc-generated drivers. Non-sqlc drivers should +// implement repository.Store directly by hand. +package main + +import ( + "bytes" + "flag" + "fmt" + "go/format" + "go/types" + "log" + "os" + "os/exec" + "path/filepath" + "sort" + "strings" + "text/template" + + "golang.org/x/tools/go/packages" +) + +func main() { + driverPkg := flag.String("pkg", "", "import path of the driver package") + out := flag.String("out", "store.go", "output filename relative to driver package directory") + flag.Parse() + + if *driverPkg == "" { + log.Fatal("-pkg is required") + } + + // Resolve the driver package directory so we can overlay the output file + // with a valid stub. This prevents a stale store.go from poisoning the + // type-checker and producing cryptic "undefined" errors. + driverDir, err := pkgDir(*driverPkg) + if err != nil { + log.Fatalf("resolve driver dir: %v", err) + } + outPath := filepath.Join(driverDir, *out) + if filepath.IsAbs(*out) { + outPath = *out + } + + // Stub replaces the output file during load so stale generated code is ignored. + stub := []byte("package " + filepath.Base(driverDir) + "\n") + + cfg := &packages.Config{ + Mode: packages.NeedName | packages.NeedTypes | packages.NeedSyntax | packages.NeedImports, + Overlay: map[string][]byte{outPath: stub}, + } + pkgs, err := packages.Load(cfg, *driverPkg) + if err != nil { + log.Fatalf("load %s: %v", *driverPkg, err) + } + if len(pkgs) != 1 { + log.Fatalf("expected 1 package, got %d", len(pkgs)) + } + pkg := pkgs[0] + if len(pkg.Errors) > 0 { + for _, e := range pkg.Errors { + log.Printf("package error: %v", e) + } + log.Fatal("package has errors") + } + + repoPkg := parentPkg(*driverPkg) + + // Load the parent (repository) package so we can validate struct shapes. + repoPkgs, err := packages.Load(cfg, repoPkg) + if err != nil { + log.Fatalf("load repo pkg %s: %v", repoPkg, err) + } + if len(repoPkgs) != 1 || len(repoPkgs[0].Errors) > 0 { + log.Fatalf("could not load repo package %s cleanly", repoPkg) + } + if err := validateStructShapes(pkg.Types, repoPkgs[0].Types); err != nil { + log.Fatalf("struct shape mismatch: %v", err) + } + + // Check *Queries covers every method in repository.Store before generating. + if err := validateStoreCoverage(pkg.Types, repoPkgs[0].Types); err != nil { + log.Fatalf("%v", err) + } + + methods, err := collectMethods(pkg.Types) + if err != nil { + log.Fatal(err) + } + + models, _ := collectTypes(pkg.Types) + + data := tmplData{ + PkgName: pkg.Name, + RepoPkg: repoPkg, + ModelTypes: models, + Methods: renderMethods(methods), + } + + src, err := render(data) + if err != nil { + log.Fatalf("render: %v", err) + } + + if err := os.WriteFile(outPath, src, 0644); err != nil { + log.Fatalf("write %s: %v", outPath, err) + } + fmt.Printf("wrote %s\n", outPath) +} + +func parentPkg(imp string) string { + parts := strings.Split(imp, "/") + return strings.Join(parts[:len(parts)-1], "/") +} + +// pkgDir returns the on-disk directory for an import path using `go list`. +func pkgDir(importPath string) (string, error) { + out, err := exec.Command("go", "list", "-f", "{{.Dir}}", importPath).Output() + if err != nil { + return "", fmt.Errorf("go list %s: %w", importPath, err) + } + return strings.TrimSpace(string(out)), nil +} + +// validateStoreCoverage checks that every method declared in repository.Store +// exists on *Queries in the driver package. Missing methods are reported by +// name so the developer knows exactly which SQL queries need to be added. +func validateStoreCoverage(driverPkg, repoPkg *types.Package) error { + // Collect *Queries method names. + queriesObj := driverPkg.Scope().Lookup("Queries") + if queriesObj == nil { + return fmt.Errorf("Queries type not found in driver package") + } + queriesNamed := queriesObj.Type().(*types.Named) + queriesMS := types.NewMethodSet(types.NewPointer(queriesNamed)) + queriesMethods := make(map[string]bool) + for m := range queriesMS.Methods() { + queriesMethods[m.Obj().Name()] = true + } + + // Collect repository.Store interface methods. + storeObj := repoPkg.Scope().Lookup("Store") + if storeObj == nil { + return fmt.Errorf("Store type not found in repository package") + } + storeIface, ok := storeObj.Type().Underlying().(*types.Interface) + if !ok { + return fmt.Errorf("repository.Store is not an interface") + } + + var missing []string + for i := range storeIface.NumMethods() { + name := storeIface.Method(i).Name() + if !queriesMethods[name] { + missing = append(missing, name) + } + } + if len(missing) > 0 { + sort.Strings(missing) + return fmt.Errorf( + "driver *Queries is missing %d method(s) required by repository.Store:\n - %s\n\nRun sqlc generate to regenerate query methods, or add the missing SQL queries.", + len(missing), strings.Join(missing, "\n - "), + ) + } + return nil +} + +type methodInfo struct { + Name string + Params []paramInfo + Results []resultInfo +} + +type paramInfo struct { + Name string + TypeStr string // local (unqualified) type name + RepoType string // "repository.X" if this is a driver model/params type; else "" +} + +type resultInfo struct { + TypeStr string + IsSlice bool + RepoType string // "repository.X" if driver type; else "" +} + +func collectMethods(pkg *types.Package) ([]methodInfo, error) { + obj := pkg.Scope().Lookup("Queries") + if obj == nil { + return nil, fmt.Errorf("queries type not found in %s", pkg.Path()) + } + named, ok := obj.Type().(*types.Named) + if !ok { + return nil, fmt.Errorf("queries is not a named type") + } + ms := types.NewMethodSet(types.NewPointer(named)) + + var out []methodInfo + for method := range ms.Methods() { + fn, ok := method.Obj().(*types.Func) + if !ok || fn.Name() == "WithTx" { + continue + } + sig := fn.Type().(*types.Signature) + mi := methodInfo{Name: fn.Name()} + + // params: skip receiver + first (context.Context) + for i := 1; i < sig.Params().Len(); i++ { + p := sig.Params().At(i) + mi.Params = append(mi.Params, makeParam(p.Name(), p.Type(), pkg.Path())) + } + // results: skip error + for r := range sig.Results().Variables() { + if r.Type().String() == "error" { + continue + } + mi.Results = append(mi.Results, makeResult(r.Type(), pkg.Path())) + } + out = append(out, mi) + } + return out, nil +} + +func makeParam(name string, t types.Type, driverPath string) paramInfo { + pi := paramInfo{Name: name} + pi.TypeStr = localName(t, driverPath) + pi.RepoType = repoName(t, driverPath) + return pi +} + +func makeResult(t types.Type, driverPath string) resultInfo { + ri := resultInfo{} + if sl, ok := t.(*types.Slice); ok { + ri.IsSlice = true + t = sl.Elem() + } + ri.TypeStr = localName(t, driverPath) + ri.RepoType = repoName(t, driverPath) + return ri +} + +func localName(t types.Type, driverPath string) string { + named, ok := t.(*types.Named) + if !ok { + return types.TypeString(t, nil) + } + if named.Obj().Pkg() != nil && named.Obj().Pkg().Path() == driverPath { + return named.Obj().Name() + } + return types.TypeString(t, func(p *types.Package) string { return p.Name() }) +} + +func repoName(t types.Type, driverPath string) string { + named, ok := t.(*types.Named) + if !ok { + return "" + } + if named.Obj().Pkg() != nil && named.Obj().Pkg().Path() == driverPath { + return "repository." + named.Obj().Name() + } + return "" +} + +func collectTypes(pkg *types.Package) (models []string, params []string) { + for _, name := range pkg.Scope().Names() { + obj := pkg.Scope().Lookup(name) + if obj == nil { + continue + } + tn, ok := obj.(*types.TypeName) + if !ok { + continue + } + named, ok := tn.Type().(*types.Named) + if !ok { + continue + } + if _, ok := named.Underlying().(*types.Struct); !ok { + continue + } + switch name { + case "Queries", "DBTX", "Store": + continue + } + if strings.HasSuffix(name, "Params") { + params = append(params, name) + } else { + models = append(models, name) + } + } + return +} + +// validateStructShapes checks that every model/params struct in the driver +// package has fields that exactly match the corresponding type in the repo +// (parent) package. This catches drift between sqlc-generated types and the +// canonical repository types before a broken cast reaches the compiler. +func validateStructShapes(driverPkg, repoPkg *types.Package) error { + var errs []string + for _, name := range driverPkg.Scope().Names() { + obj := driverPkg.Scope().Lookup(name) + if obj == nil { + continue + } + tn, ok := obj.(*types.TypeName) + if !ok { + continue + } + named, ok := tn.Type().(*types.Named) + if !ok { + continue + } + driverStruct, ok := named.Underlying().(*types.Struct) + if !ok { + continue + } + switch name { + case "Queries", "DBTX", "Store": + continue + } + + repoObj := repoPkg.Scope().Lookup(name) + if repoObj == nil { + // Driver has a type not in repo — that's fine (e.g. internal helpers). + continue + } + repoNamed, ok := repoObj.Type().(*types.Named) + if !ok { + continue + } + repoStruct, ok := repoNamed.Underlying().(*types.Struct) + if !ok { + errs = append(errs, fmt.Sprintf("%s: repo type is not a struct", name)) + continue + } + + if err := compareStructs(name, driverStruct, repoStruct); err != nil { + errs = append(errs, err.Error()) + } + } + if len(errs) > 0 { + return fmt.Errorf("%s", strings.Join(errs, "\n ")) + } + return nil +} + +func compareStructs(name string, driver, repo *types.Struct) error { + if driver.NumFields() != repo.NumFields() { + return fmt.Errorf("%s: field count mismatch (driver=%d, repo=%d)", + name, driver.NumFields(), repo.NumFields()) + } + for i := range driver.NumFields() { + df := driver.Field(i) + rf := repo.Field(i) + if df.Name() != rf.Name() { + return fmt.Errorf("%s: field %d name mismatch (driver=%q, repo=%q)", + name, i, df.Name(), rf.Name()) + } + if !types.Identical(df.Type(), rf.Type()) { + return fmt.Errorf("%s.%s: type mismatch (driver=%s, repo=%s)", + name, df.Name(), df.Type(), rf.Type()) + } + } + return nil +} + +// converterFn: "Session" -> "sessionToRepo" +func converterFn(s string) string { + if s == "" { + return "" + } + r := []rune(s) + r[0] = []rune(strings.ToLower(string(r[0])))[0] + return string(r) + "ToRepo" +} + +// renderedMethod is the pre-built method body passed to the template. +type renderedMethod struct { + Signature string + Body string +} + +// renderMethods converts []methodInfo into fully pre-rendered signature+body strings. +func renderMethods(methods []methodInfo) []renderedMethod { + var out []renderedMethod + for _, m := range methods { + out = append(out, renderedMethod{ + Signature: buildSig(m), + Body: buildBody(m), + }) + } + return out +} + +func buildSig(m methodInfo) string { + var sb strings.Builder + sb.WriteString("func (s *Store) ") + sb.WriteString(m.Name) + sb.WriteString("(ctx context.Context") + for _, p := range m.Params { + sb.WriteString(", ") + sb.WriteString(p.Name) + sb.WriteString(" ") + if p.RepoType != "" { + sb.WriteString(p.RepoType) + } else { + sb.WriteString(p.TypeStr) + } + } + sb.WriteString(") (") + for _, r := range m.Results { + if r.IsSlice { + sb.WriteString("[]") + } + if r.RepoType != "" { + sb.WriteString(r.RepoType) + } else { + sb.WriteString(r.TypeStr) + } + sb.WriteString(", ") + } + sb.WriteString("error)") + return sb.String() +} + +func callArgs(m methodInfo) string { + var args []string + for _, p := range m.Params { + if p.RepoType != "" { + // convert repo type → driver type: DriverType(arg) + args = append(args, p.TypeStr+"("+p.Name+")") + } else { + args = append(args, p.Name) + } + } + if len(args) == 0 { + return "ctx" + } + return "ctx, " + strings.Join(args, ", ") +} + +func buildBody(m methodInfo) string { + call := "s.q." + m.Name + "(" + callArgs(m) + ")" + + // no repo-typed result → direct return + if len(m.Results) == 0 || m.Results[0].RepoType == "" { + return "\treturn " + call + "\n" + } + + r := m.Results[0] + if r.IsSlice { + return fmt.Sprintf( + "\trows, err := %s\n\tif err != nil {\n\t\treturn nil, err\n\t}\n\tout := make([]%s, len(rows))\n\tfor i, row := range rows {\n\t\tout[i] = %s(row)\n\t}\n\treturn out, nil\n", + call, r.RepoType, converterFn(r.TypeStr), + ) + } + return fmt.Sprintf( + "\tr, err := %s\n\tif err != nil {\n\t\treturn %s{}, err\n\t}\n\treturn %s(r), nil\n", + call, r.RepoType, converterFn(r.TypeStr), + ) +} + +type tmplData struct { + PkgName string + RepoPkg string + ModelTypes []string + Methods []renderedMethod +} + +const storeSrc = `// Code generated by cmd/gen/sqlc-wrapper. DO NOT EDIT. +package {{.PkgName}} + +import ( + "context" + + "{{.RepoPkg}}" +) + +// Store wraps *Queries and implements repository.Store. +type Store struct { + q *Queries +} + +// NewStore wraps a *Queries to satisfy repository.Store. +func NewStore(q *Queries) repository.Store { + return &Store{q: q} +} + +{{range .ModelTypes -}} +func {{converterFn .}}(v {{.}}) repository.{{.}} { + return repository.{{.}}(v) +} +{{end -}} +{{range .Methods}}{{.Signature}} { +{{.Body}}} + +{{end}}` + +func render(data tmplData) ([]byte, error) { + t, err := template.New("store").Funcs(template.FuncMap{ + "converterFn": converterFn, + }).Parse(storeSrc) + if err != nil { + return nil, fmt.Errorf("parse template: %w", err) + } + + var buf bytes.Buffer + if err := t.Execute(&buf, data); err != nil { + return nil, fmt.Errorf("execute template: %w", err) + } + + formatted, err := format.Source(buf.Bytes()) + if err != nil { + return buf.Bytes(), fmt.Errorf("format source: %w\nraw:\n%s", err, buf.String()) + } + return formatted, nil +} diff --git a/go.mod b/go.mod index d0c5a515..fb4a459c 100644 --- a/go.mod +++ b/go.mod @@ -20,6 +20,7 @@ require ( github.com/weppos/publicsuffix-go v0.50.3 golang.org/x/crypto v0.50.0 golang.org/x/oauth2 v0.36.0 + golang.org/x/tools v0.43.0 gotest.tools/v3 v3.5.2 k8s.io/apimachinery v0.32.2 k8s.io/client-go v0.32.2 @@ -124,6 +125,7 @@ require ( go.opentelemetry.io/otel/trace v1.43.0 // indirect golang.org/x/arch v0.22.0 // indirect golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect + golang.org/x/mod v0.34.0 // indirect golang.org/x/net v0.52.0 // indirect golang.org/x/sync v0.20.0 // indirect golang.org/x/sys v0.43.0 // indirect diff --git a/internal/bootstrap/db_bootstrap.go b/internal/bootstrap/db_bootstrap.go index efc21311..2279cb23 100644 --- a/internal/bootstrap/db_bootstrap.go +++ b/internal/bootstrap/db_bootstrap.go @@ -66,5 +66,5 @@ func (app *BootstrapApp) setupSQLite(databasePath string) (repository.Store, err return nil, fmt.Errorf("failed to migrate database: %w", err) } - return sqlite.New(db), nil + return sqlite.NewStore(sqlite.New(db)), nil } diff --git a/internal/repository/models.go b/internal/repository/models.go index 0c33e038..3f58dd66 100644 --- a/internal/repository/models.go +++ b/internal/repository/models.go @@ -1,19 +1,148 @@ package repository -// This file is a stop-gap until more drivers are added. It re-exports the models from the sqlite package so that the rest -// of the codebase can import them from a single location without needing to know about the underlying database implementation. +// Shared model and parameter types for all storage drivers. +// sqlc-generated driver packages use these via the conversion layer in their store.go. -import "github.com/tinyauthapp/tinyauth/internal/repository/sqlite" +type Session struct { + UUID string + Username string + Email string + Name string + Provider string + TotpPending bool + OAuthGroups string + Expiry int64 + CreatedAt int64 + OAuthName string + OAuthSub string +} -type Session = sqlite.Session -type OidcCode = sqlite.OidcCode -type OidcToken = sqlite.OidcToken -type OidcUserinfo = sqlite.OidcUserinfo +type OidcCode struct { + Sub string + CodeHash string + Scope string + RedirectURI string + ClientID string + ExpiresAt int64 + Nonce string + CodeChallenge string +} -type CreateSessionParams = sqlite.CreateSessionParams -type UpdateSessionParams = sqlite.UpdateSessionParams -type CreateOidcCodeParams = sqlite.CreateOidcCodeParams -type CreateOidcTokenParams = sqlite.CreateOidcTokenParams -type UpdateOidcTokenByRefreshTokenParams = sqlite.UpdateOidcTokenByRefreshTokenParams -type DeleteExpiredOidcTokensParams = sqlite.DeleteExpiredOidcTokensParams -type CreateOidcUserInfoParams = sqlite.CreateOidcUserInfoParams +type OidcToken struct { + Sub string + AccessTokenHash string + RefreshTokenHash string + CodeHash string + Scope string + ClientID string + TokenExpiresAt int64 + RefreshTokenExpiresAt int64 + Nonce string +} + +type OidcUserinfo struct { + Sub string + Name string + PreferredUsername string + Email string + Groups string + UpdatedAt int64 + GivenName string + FamilyName string + MiddleName string + Nickname string + Profile string + Picture string + Website string + Gender string + Birthdate string + Zoneinfo string + Locale string + PhoneNumber string + Address string +} + +type CreateSessionParams struct { + UUID string + Username string + Email string + Name string + Provider string + TotpPending bool + OAuthGroups string + Expiry int64 + CreatedAt int64 + OAuthName string + OAuthSub string +} + +type UpdateSessionParams struct { + Username string + Email string + Name string + Provider string + TotpPending bool + OAuthGroups string + Expiry int64 + OAuthName string + OAuthSub string + UUID string +} + +type CreateOidcCodeParams struct { + Sub string + CodeHash string + Scope string + RedirectURI string + ClientID string + ExpiresAt int64 + Nonce string + CodeChallenge string +} + +type CreateOidcTokenParams struct { + Sub string + AccessTokenHash string + RefreshTokenHash string + Scope string + ClientID string + TokenExpiresAt int64 + RefreshTokenExpiresAt int64 + CodeHash string + Nonce string +} + +type UpdateOidcTokenByRefreshTokenParams struct { + AccessTokenHash string + RefreshTokenHash string + TokenExpiresAt int64 + RefreshTokenExpiresAt int64 + RefreshTokenHash_2 string +} + +type DeleteExpiredOidcTokensParams struct { + TokenExpiresAt int64 + RefreshTokenExpiresAt int64 +} + +type CreateOidcUserInfoParams struct { + Sub string + Name string + PreferredUsername string + Email string + Groups string + UpdatedAt int64 + GivenName string + FamilyName string + MiddleName string + Nickname string + Profile string + Picture string + Website string + Gender string + Birthdate string + Zoneinfo string + Locale string + PhoneNumber string + Address string +} diff --git a/internal/repository/sqlite/generate.go b/internal/repository/sqlite/generate.go new file mode 100644 index 00000000..5f6011f1 --- /dev/null +++ b/internal/repository/sqlite/generate.go @@ -0,0 +1,3 @@ +package sqlite + +//go:generate go run github.com/tinyauthapp/tinyauth/cmd/gen/sqlc-wrapper -pkg github.com/tinyauthapp/tinyauth/internal/repository/sqlite diff --git a/internal/repository/sqlite/store.go b/internal/repository/sqlite/store.go new file mode 100644 index 00000000..65b4e190 --- /dev/null +++ b/internal/repository/sqlite/store.go @@ -0,0 +1,206 @@ +// Code generated by cmd/gen/sqlc-wrapper. DO NOT EDIT. +package sqlite + +import ( + "context" + + "github.com/tinyauthapp/tinyauth/internal/repository" +) + +// Store wraps *Queries and implements repository.Store. +type Store struct { + q *Queries +} + +// NewStore wraps a *Queries to satisfy repository.Store. +func NewStore(q *Queries) repository.Store { + return &Store{q: q} +} + +func oidcCodeToRepo(v OidcCode) repository.OidcCode { + return repository.OidcCode(v) +} +func oidcTokenToRepo(v OidcToken) repository.OidcToken { + return repository.OidcToken(v) +} +func oidcUserinfoToRepo(v OidcUserinfo) repository.OidcUserinfo { + return repository.OidcUserinfo(v) +} +func sessionToRepo(v Session) repository.Session { + return repository.Session(v) +} +func (s *Store) CreateOidcCode(ctx context.Context, arg repository.CreateOidcCodeParams) (repository.OidcCode, error) { + r, err := s.q.CreateOidcCode(ctx, CreateOidcCodeParams(arg)) + if err != nil { + return repository.OidcCode{}, err + } + return oidcCodeToRepo(r), nil +} + +func (s *Store) CreateOidcToken(ctx context.Context, arg repository.CreateOidcTokenParams) (repository.OidcToken, error) { + r, err := s.q.CreateOidcToken(ctx, CreateOidcTokenParams(arg)) + if err != nil { + return repository.OidcToken{}, err + } + return oidcTokenToRepo(r), nil +} + +func (s *Store) CreateOidcUserInfo(ctx context.Context, arg repository.CreateOidcUserInfoParams) (repository.OidcUserinfo, error) { + r, err := s.q.CreateOidcUserInfo(ctx, CreateOidcUserInfoParams(arg)) + if err != nil { + return repository.OidcUserinfo{}, err + } + return oidcUserinfoToRepo(r), nil +} + +func (s *Store) CreateSession(ctx context.Context, arg repository.CreateSessionParams) (repository.Session, error) { + r, err := s.q.CreateSession(ctx, CreateSessionParams(arg)) + if err != nil { + return repository.Session{}, err + } + return sessionToRepo(r), nil +} + +func (s *Store) DeleteExpiredOidcCodes(ctx context.Context, expiresAt int64) ([]repository.OidcCode, error) { + rows, err := s.q.DeleteExpiredOidcCodes(ctx, expiresAt) + if err != nil { + return nil, err + } + out := make([]repository.OidcCode, len(rows)) + for i, row := range rows { + out[i] = oidcCodeToRepo(row) + } + return out, nil +} + +func (s *Store) DeleteExpiredOidcTokens(ctx context.Context, arg repository.DeleteExpiredOidcTokensParams) ([]repository.OidcToken, error) { + rows, err := s.q.DeleteExpiredOidcTokens(ctx, DeleteExpiredOidcTokensParams(arg)) + if err != nil { + return nil, err + } + out := make([]repository.OidcToken, len(rows)) + for i, row := range rows { + out[i] = oidcTokenToRepo(row) + } + return out, nil +} + +func (s *Store) DeleteExpiredSessions(ctx context.Context, expiry int64) error { + return s.q.DeleteExpiredSessions(ctx, expiry) +} + +func (s *Store) DeleteOidcCode(ctx context.Context, codeHash string) error { + return s.q.DeleteOidcCode(ctx, codeHash) +} + +func (s *Store) DeleteOidcCodeBySub(ctx context.Context, sub string) error { + return s.q.DeleteOidcCodeBySub(ctx, sub) +} + +func (s *Store) DeleteOidcToken(ctx context.Context, accessTokenHash string) error { + return s.q.DeleteOidcToken(ctx, accessTokenHash) +} + +func (s *Store) DeleteOidcTokenByCodeHash(ctx context.Context, codeHash string) error { + return s.q.DeleteOidcTokenByCodeHash(ctx, codeHash) +} + +func (s *Store) DeleteOidcTokenBySub(ctx context.Context, sub string) error { + return s.q.DeleteOidcTokenBySub(ctx, sub) +} + +func (s *Store) DeleteOidcUserInfo(ctx context.Context, sub string) error { + return s.q.DeleteOidcUserInfo(ctx, sub) +} + +func (s *Store) DeleteSession(ctx context.Context, uuid string) error { + return s.q.DeleteSession(ctx, uuid) +} + +func (s *Store) GetOidcCode(ctx context.Context, codeHash string) (repository.OidcCode, error) { + r, err := s.q.GetOidcCode(ctx, codeHash) + if err != nil { + return repository.OidcCode{}, err + } + return oidcCodeToRepo(r), nil +} + +func (s *Store) GetOidcCodeBySub(ctx context.Context, sub string) (repository.OidcCode, error) { + r, err := s.q.GetOidcCodeBySub(ctx, sub) + if err != nil { + return repository.OidcCode{}, err + } + return oidcCodeToRepo(r), nil +} + +func (s *Store) GetOidcCodeBySubUnsafe(ctx context.Context, sub string) (repository.OidcCode, error) { + r, err := s.q.GetOidcCodeBySubUnsafe(ctx, sub) + if err != nil { + return repository.OidcCode{}, err + } + return oidcCodeToRepo(r), nil +} + +func (s *Store) GetOidcCodeUnsafe(ctx context.Context, codeHash string) (repository.OidcCode, error) { + r, err := s.q.GetOidcCodeUnsafe(ctx, codeHash) + if err != nil { + return repository.OidcCode{}, err + } + return oidcCodeToRepo(r), nil +} + +func (s *Store) GetOidcToken(ctx context.Context, accessTokenHash string) (repository.OidcToken, error) { + r, err := s.q.GetOidcToken(ctx, accessTokenHash) + if err != nil { + return repository.OidcToken{}, err + } + return oidcTokenToRepo(r), nil +} + +func (s *Store) GetOidcTokenByRefreshToken(ctx context.Context, refreshTokenHash string) (repository.OidcToken, error) { + r, err := s.q.GetOidcTokenByRefreshToken(ctx, refreshTokenHash) + if err != nil { + return repository.OidcToken{}, err + } + return oidcTokenToRepo(r), nil +} + +func (s *Store) GetOidcTokenBySub(ctx context.Context, sub string) (repository.OidcToken, error) { + r, err := s.q.GetOidcTokenBySub(ctx, sub) + if err != nil { + return repository.OidcToken{}, err + } + return oidcTokenToRepo(r), nil +} + +func (s *Store) GetOidcUserInfo(ctx context.Context, sub string) (repository.OidcUserinfo, error) { + r, err := s.q.GetOidcUserInfo(ctx, sub) + if err != nil { + return repository.OidcUserinfo{}, err + } + return oidcUserinfoToRepo(r), nil +} + +func (s *Store) GetSession(ctx context.Context, uuid string) (repository.Session, error) { + r, err := s.q.GetSession(ctx, uuid) + if err != nil { + return repository.Session{}, err + } + return sessionToRepo(r), nil +} + +func (s *Store) UpdateOidcTokenByRefreshToken(ctx context.Context, arg repository.UpdateOidcTokenByRefreshTokenParams) (repository.OidcToken, error) { + r, err := s.q.UpdateOidcTokenByRefreshToken(ctx, UpdateOidcTokenByRefreshTokenParams(arg)) + if err != nil { + return repository.OidcToken{}, err + } + return oidcTokenToRepo(r), nil +} + +func (s *Store) UpdateSession(ctx context.Context, arg repository.UpdateSessionParams) (repository.Session, error) { + r, err := s.q.UpdateSession(ctx, UpdateSessionParams(arg)) + if err != nil { + return repository.Session{}, err + } + return sessionToRepo(r), nil +}