From 3eeabd36238ea762d1a46e598b2da012f9313d83 Mon Sep 17 00:00:00 2001 From: Scott McKendry Date: Thu, 7 May 2026 18:55:51 +1200 Subject: [PATCH] refactor(db): cleanup sqlc-wrapper gen --- Makefile | 2 +- .../sqlc-wrapper/{main.go => sqlc_wrapper.go} | 445 +++++++++--------- cmd/gen/sqlc-wrapper/store.tmpl | 46 ++ internal/repository/sqlite/db.go | 2 +- internal/repository/sqlite/models.go | 2 +- .../repository/sqlite/oidc_queries.sql.go | 2 +- .../repository/sqlite/session_queries.sql.go | 2 +- 7 files changed, 267 insertions(+), 234 deletions(-) rename cmd/gen/sqlc-wrapper/{main.go => sqlc_wrapper.go} (65%) create mode 100644 cmd/gen/sqlc-wrapper/store.tmpl diff --git a/Makefile b/Makefile index 7f4e393e..7782830d 100644 --- a/Makefile +++ b/Makefile @@ -84,4 +84,4 @@ sql: # Go gen generate: - go run ./gen + go generate ./internal/repository/... diff --git a/cmd/gen/sqlc-wrapper/main.go b/cmd/gen/sqlc-wrapper/sqlc_wrapper.go similarity index 65% rename from cmd/gen/sqlc-wrapper/main.go rename to cmd/gen/sqlc-wrapper/sqlc_wrapper.go index d6cb6318..0592d20c 100644 --- a/cmd/gen/sqlc-wrapper/main.go +++ b/cmd/gen/sqlc-wrapper/sqlc_wrapper.go @@ -12,6 +12,7 @@ package main import ( "bytes" + _ "embed" "flag" "fmt" "go/format" @@ -27,13 +28,22 @@ import ( "golang.org/x/tools/go/packages" ) +//go:embed store.tmpl +var storeSrc string + func main() { + if err := run(); err != nil { + log.Fatal(err) + } +} + +func run() error { driverPkg := flag.String("pkg", "", "import path of the driver package") out := flag.String("out", "store.go", "output filename relative to driver package directory") flag.Parse() if *driverPkg == "" { - log.Fatal("-pkg is required") + return fmt.Errorf("-pkg is required") } // Resolve the driver package directory so we can overlay the output file @@ -41,8 +51,9 @@ func main() { // type-checker and producing cryptic "undefined" errors. driverDir, err := pkgDir(*driverPkg) if err != nil { - log.Fatalf("resolve driver dir: %v", err) + return fmt.Errorf("resolve driver dir: %w", err) } + outPath := filepath.Join(driverDir, *out) if filepath.IsAbs(*out) { outPath = *out @@ -50,73 +61,81 @@ func main() { // Stub replaces the output file during load so stale generated code is ignored. stub := []byte("package " + filepath.Base(driverDir) + "\n") - cfg := &packages.Config{ Mode: packages.NeedName | packages.NeedTypes | packages.NeedSyntax | packages.NeedImports, Overlay: map[string][]byte{outPath: stub}, } - pkgs, err := packages.Load(cfg, *driverPkg) + + driverTypePkg, err := loadOnePkg(cfg, *driverPkg) if err != nil { - log.Fatalf("load %s: %v", *driverPkg, err) - } - if len(pkgs) != 1 { - log.Fatalf("expected 1 package, got %d", len(pkgs)) - } - pkg := pkgs[0] - if len(pkg.Errors) > 0 { - for _, e := range pkg.Errors { - log.Printf("package error: %v", e) - } - log.Fatal("package has errors") + return fmt.Errorf("load driver package: %w", err) } - repoPkg := parentPkg(*driverPkg) - - // Load the parent (repository) package so we can validate struct shapes. - repoPkgs, err := packages.Load(cfg, repoPkg) + repoPkgPath := parentPkg(*driverPkg) + repoTypePkg, err := loadOnePkg(cfg, repoPkgPath) if err != nil { - log.Fatalf("load repo pkg %s: %v", repoPkg, err) - } - if len(repoPkgs) != 1 || len(repoPkgs[0].Errors) > 0 { - log.Fatalf("could not load repo package %s cleanly", repoPkg) - } - if err := validateStructShapes(pkg.Types, repoPkgs[0].Types); err != nil { - log.Fatalf("struct shape mismatch: %v", err) + return fmt.Errorf("load repo package: %w", err) } - // Check *Queries covers every method in repository.Store before generating. - if err := validateStoreCoverage(pkg.Types, repoPkgs[0].Types); err != nil { - log.Fatalf("%v", err) + if err := validateStructShapes(driverTypePkg, repoTypePkg); err != nil { + return fmt.Errorf("struct shape mismatch: %w", err) + } + if err := validateStoreCoverage(driverTypePkg, repoTypePkg); err != nil { + return err } - methods, err := collectMethods(pkg.Types) + methods, err := collectMethods(driverTypePkg) if err != nil { - log.Fatal(err) + return err } + models, _ := collectTypes(driverTypePkg) - models, _ := collectTypes(pkg.Types) - - data := tmplData{ - PkgName: pkg.Name, - RepoPkg: repoPkg, + src, err := render(tmplData{ + PkgName: driverTypePkg.Name(), + RepoPkg: repoPkgPath, ModelTypes: models, Methods: renderMethods(methods), - } - - src, err := render(data) + }) if err != nil { - log.Fatalf("render: %v", err) + return fmt.Errorf("render: %w", err) } if err := os.WriteFile(outPath, src, 0644); err != nil { - log.Fatalf("write %s: %v", outPath, err) + return fmt.Errorf("write %s: %w", outPath, err) } fmt.Printf("wrote %s\n", outPath) + return nil } +// loadOnePkg loads a single package via cfg and returns its *types.Package, +// or an error if the package fails to load or has type errors. +func loadOnePkg(cfg *packages.Config, importPath string) (*types.Package, error) { + pkgs, err := packages.Load(cfg, importPath) + if err != nil { + return nil, fmt.Errorf("load %s: %w", importPath, err) + } + if len(pkgs) != 1 { + return nil, fmt.Errorf("expected 1 package for %s, got %d", importPath, len(pkgs)) + } + pkg := pkgs[0] + if len(pkg.Errors) > 0 { + msgs := make([]string, len(pkg.Errors)) + for i, e := range pkg.Errors { + msgs[i] = e.Error() + } + return nil, fmt.Errorf("package %s has errors:\n %s", importPath, strings.Join(msgs, "\n ")) + } + return pkg.Types, nil +} + +// parentPkg returns the parent import path (everything before the last /). +// Panics if imp contains no slash — callers are expected to pass driver sub-packages. func parentPkg(imp string) string { - parts := strings.Split(imp, "/") - return strings.Join(parts[:len(parts)-1], "/") + i := strings.LastIndex(imp, "/") + if i < 0 { + panic(fmt.Sprintf("parentPkg: import path %q has no parent", imp)) + } + return imp[:i] } // pkgDir returns the on-disk directory for an import path using `go list`. @@ -128,14 +147,40 @@ func pkgDir(importPath string) (string, error) { return strings.TrimSpace(string(out)), nil } +// scopeStructs returns all named struct types in pkg, excluding the internal +// sqlc types Queries, DBTX, and Store. Names are returned in sorted order. +func scopeStructs(pkg *types.Package) (names []string, byName map[string]*types.Struct) { + byName = make(map[string]*types.Struct) + for _, name := range pkg.Scope().Names() { // Names() is already sorted + switch name { + case "Queries", "DBTX", "Store": + continue + } + obj, ok := pkg.Scope().Lookup(name).(*types.TypeName) + if !ok { + continue + } + named, ok := obj.Type().(*types.Named) + if !ok { + continue + } + s, ok := named.Underlying().(*types.Struct) + if !ok { + continue + } + names = append(names, name) + byName[name] = s + } + return +} + // validateStoreCoverage checks that every method declared in repository.Store // exists on *Queries in the driver package. Missing methods are reported by // name so the developer knows exactly which SQL queries need to be added. func validateStoreCoverage(driverPkg, repoPkg *types.Package) error { - // Collect *Queries method names. queriesObj := driverPkg.Scope().Lookup("Queries") if queriesObj == nil { - return fmt.Errorf("Queries type not found in driver package") + return fmt.Errorf("queries type not found in driver package") } queriesNamed := queriesObj.Type().(*types.Named) queriesMS := types.NewMethodSet(types.NewPointer(queriesNamed)) @@ -144,10 +189,9 @@ func validateStoreCoverage(driverPkg, repoPkg *types.Package) error { queriesMethods[m.Obj().Name()] = true } - // Collect repository.Store interface methods. storeObj := repoPkg.Scope().Lookup("Store") if storeObj == nil { - return fmt.Errorf("Store type not found in repository package") + return fmt.Errorf("store type not found in repository package") } storeIface, ok := storeObj.Type().Underlying().(*types.Interface) if !ok { @@ -155,22 +199,80 @@ func validateStoreCoverage(driverPkg, repoPkg *types.Package) error { } var missing []string - for i := range storeIface.NumMethods() { - name := storeIface.Method(i).Name() - if !queriesMethods[name] { + for method := range storeIface.Methods() { + if name := method.Name(); !queriesMethods[name] { missing = append(missing, name) } } if len(missing) > 0 { sort.Strings(missing) return fmt.Errorf( - "driver *Queries is missing %d method(s) required by repository.Store:\n - %s\n\nRun sqlc generate to regenerate query methods, or add the missing SQL queries.", + "driver *Queries is missing %d method(s) required by repository.Store:\n - %s\n\nRun sqlc generate to regenerate query methods, or add the missing SQL queries", len(missing), strings.Join(missing, "\n - "), ) } return nil } +// validateStructShapes checks that every model/params struct in the driver +// package has fields that exactly match the corresponding type in the repo +// (parent) package. This catches drift between sqlc-generated types and the +// canonical repository types before a broken cast reaches the compiler. +func validateStructShapes(driverPkg, repoPkg *types.Package) error { + _, repoStructs := scopeStructs(repoPkg) + driverNames, driverStructs := scopeStructs(driverPkg) + + var errs []string + for _, name := range driverNames { + repoStruct, ok := repoStructs[name] + if !ok { + // Driver has a type not in repo — fine (e.g. internal helpers). + continue + } + if err := compareStructs(name, driverStructs[name], repoStruct); err != nil { + errs = append(errs, err.Error()) + } + } + if len(errs) > 0 { + sort.Strings(errs) + return fmt.Errorf("%s", strings.Join(errs, "\n ")) + } + return nil +} + +func compareStructs(name string, driver, repo *types.Struct) error { + if driver.NumFields() != repo.NumFields() { + return fmt.Errorf("%s: field count mismatch (driver=%d, repo=%d)", + name, driver.NumFields(), repo.NumFields()) + } + for i := range driver.NumFields() { + df := driver.Field(i) + rf := repo.Field(i) + if df.Name() != rf.Name() { + return fmt.Errorf("%s: field %d name mismatch (driver=%q, repo=%q)", + name, i, df.Name(), rf.Name()) + } + if !types.Identical(df.Type(), rf.Type()) { + return fmt.Errorf("%s.%s: type mismatch (driver=%s, repo=%s)", + name, df.Name(), df.Type(), rf.Type()) + } + } + return nil +} + +// collectTypes returns model and params struct names from the driver package. +func collectTypes(pkg *types.Package) (models []string, params []string) { + names, _ := scopeStructs(pkg) + for _, name := range names { + if strings.HasSuffix(name, "Params") { + params = append(params, name) + } else { + models = append(models, name) + } + } + return +} + type methodInfo struct { Name string Params []paramInfo @@ -227,10 +329,11 @@ func collectMethods(pkg *types.Package) ([]methodInfo, error) { } func makeParam(name string, t types.Type, driverPath string) paramInfo { - pi := paramInfo{Name: name} - pi.TypeStr = localName(t, driverPath) - pi.RepoType = repoName(t, driverPath) - return pi + return paramInfo{ + Name: name, + TypeStr: localName(t, driverPath), + RepoType: repoName(t, driverPath), + } } func makeResult(t types.Type, driverPath string) resultInfo { @@ -266,133 +369,27 @@ func repoName(t types.Type, driverPath string) string { return "" } -func collectTypes(pkg *types.Package) (models []string, params []string) { - for _, name := range pkg.Scope().Names() { - obj := pkg.Scope().Lookup(name) - if obj == nil { - continue - } - tn, ok := obj.(*types.TypeName) - if !ok { - continue - } - named, ok := tn.Type().(*types.Named) - if !ok { - continue - } - if _, ok := named.Underlying().(*types.Struct); !ok { - continue - } - switch name { - case "Queries", "DBTX", "Store": - continue - } - if strings.HasSuffix(name, "Params") { - params = append(params, name) - } else { - models = append(models, name) - } - } - return -} - -// validateStructShapes checks that every model/params struct in the driver -// package has fields that exactly match the corresponding type in the repo -// (parent) package. This catches drift between sqlc-generated types and the -// canonical repository types before a broken cast reaches the compiler. -func validateStructShapes(driverPkg, repoPkg *types.Package) error { - var errs []string - for _, name := range driverPkg.Scope().Names() { - obj := driverPkg.Scope().Lookup(name) - if obj == nil { - continue - } - tn, ok := obj.(*types.TypeName) - if !ok { - continue - } - named, ok := tn.Type().(*types.Named) - if !ok { - continue - } - driverStruct, ok := named.Underlying().(*types.Struct) - if !ok { - continue - } - switch name { - case "Queries", "DBTX", "Store": - continue - } - - repoObj := repoPkg.Scope().Lookup(name) - if repoObj == nil { - // Driver has a type not in repo — that's fine (e.g. internal helpers). - continue - } - repoNamed, ok := repoObj.Type().(*types.Named) - if !ok { - continue - } - repoStruct, ok := repoNamed.Underlying().(*types.Struct) - if !ok { - errs = append(errs, fmt.Sprintf("%s: repo type is not a struct", name)) - continue - } - - if err := compareStructs(name, driverStruct, repoStruct); err != nil { - errs = append(errs, err.Error()) - } - } - if len(errs) > 0 { - return fmt.Errorf("%s", strings.Join(errs, "\n ")) - } - return nil -} - -func compareStructs(name string, driver, repo *types.Struct) error { - if driver.NumFields() != repo.NumFields() { - return fmt.Errorf("%s: field count mismatch (driver=%d, repo=%d)", - name, driver.NumFields(), repo.NumFields()) - } - for i := range driver.NumFields() { - df := driver.Field(i) - rf := repo.Field(i) - if df.Name() != rf.Name() { - return fmt.Errorf("%s: field %d name mismatch (driver=%q, repo=%q)", - name, i, df.Name(), rf.Name()) - } - if !types.Identical(df.Type(), rf.Type()) { - return fmt.Errorf("%s.%s: type mismatch (driver=%s, repo=%s)", - name, df.Name(), df.Type(), rf.Type()) - } - } - return nil -} - -// converterFn: "Session" -> "sessionToRepo" +// converterFn maps a type name to its converter function name: "Session" → "sessionToRepo". func converterFn(s string) string { if s == "" { return "" } - r := []rune(s) - r[0] = []rune(strings.ToLower(string(r[0])))[0] - return string(r) + "ToRepo" + return strings.ToLower(s[:1]) + s[1:] + "ToRepo" } -// renderedMethod is the pre-built method body passed to the template. +// renderedMethod holds pre-built signature and body strings passed to the template. type renderedMethod struct { Signature string Body string } -// renderMethods converts []methodInfo into fully pre-rendered signature+body strings. func renderMethods(methods []methodInfo) []renderedMethod { - var out []renderedMethod - for _, m := range methods { - out = append(out, renderedMethod{ + out := make([]renderedMethod, len(methods)) + for i, m := range methods { + out[i] = renderedMethod{ Signature: buildSig(m), Body: buildBody(m), - }) + } } return out } @@ -429,7 +426,7 @@ func buildSig(m methodInfo) string { } func callArgs(m methodInfo) string { - var args []string + args := make([]string, 0, len(m.Params)) for _, p := range m.Params { if p.RepoType != "" { // convert repo type → driver type: DriverType(arg) @@ -444,25 +441,62 @@ func callArgs(m methodInfo) string { return "ctx, " + strings.Join(args, ", ") } +// bodyTemplates holds the per-shape method body templates, parsed once at init. +var bodyTemplates = template.Must( + template.New("bodies").Parse(` +{{define "void"}} return mapErr({{.Call}}) +{{end}} + +{{define "scalar"}} r, err := {{.Call}} + if err != nil { + return {{.RepoType}}{}, mapErr(err) + } + return {{.Converter}}(r), nil +{{end}} + +{{define "slice"}} rows, err := {{.Call}} + if err != nil { + return nil, mapErr(err) + } + out := make([]{{.RepoType}}, len(rows)) + for i, row := range rows { + out[i] = {{.Converter}}(row) + } + return out, nil +{{end}}`), +) + +type bodyData struct { + Call string + RepoType string + Converter string +} + func buildBody(m methodInfo) string { call := "s.q." + m.Name + "(" + callArgs(m) + ")" - // no repo-typed result → direct return - if len(m.Results) == 0 || m.Results[0].RepoType == "" { - return "\treturn mapErr(" + call + ")\n" + var ( + name string + data bodyData + ) + + switch { + case len(m.Results) == 0 || m.Results[0].RepoType == "": + name = "void" + data = bodyData{Call: call} + case m.Results[0].IsSlice: + name = "slice" + data = bodyData{Call: call, RepoType: m.Results[0].RepoType, Converter: converterFn(m.Results[0].TypeStr)} + default: + name = "scalar" + data = bodyData{Call: call, RepoType: m.Results[0].RepoType, Converter: converterFn(m.Results[0].TypeStr)} } - r := m.Results[0] - if r.IsSlice { - return fmt.Sprintf( - "\trows, err := %s\n\tif err != nil {\n\t\treturn nil, mapErr(err)\n\t}\n\tout := make([]%s, len(rows))\n\tfor i, row := range rows {\n\t\tout[i] = %s(row)\n\t}\n\treturn out, nil\n", - call, r.RepoType, converterFn(r.TypeStr), - ) + var buf bytes.Buffer + if err := bodyTemplates.ExecuteTemplate(&buf, name, data); err != nil { + panic(fmt.Sprintf("buildBody %s: %v", name, err)) } - return fmt.Sprintf( - "\tr, err := %s\n\tif err != nil {\n\t\treturn %s{}, mapErr(err)\n\t}\n\treturn %s(r), nil\n", - call, r.RepoType, converterFn(r.TypeStr), - ) + return buf.String() } type tmplData struct { @@ -472,53 +506,6 @@ type tmplData struct { Methods []renderedMethod } -const storeSrc = `// Code generated by cmd/gen/sqlc-wrapper. DO NOT EDIT. -package {{.PkgName}} - -import ( - "context" - "database/sql" - "errors" - - "{{.RepoPkg}}" -) - -// Store wraps *Queries and implements repository.Store. -type Store struct { - q *Queries -} - -// NewStore wraps a *Queries to satisfy repository.Store. -func NewStore(q *Queries) repository.Store { - return &Store{q: q} -} - -var errMap = []struct { - from error - to error -}{ - {sql.ErrNoRows, repository.ErrNotFound}, -} - -func mapErr(err error) error { - for _, e := range errMap { - if errors.Is(err, e.from) { - return e.to - } - } - return err -} - -{{range .ModelTypes -}} -func {{converterFn .}}(v {{.}}) repository.{{.}} { - return repository.{{.}}(v) -} -{{end -}} -{{range .Methods}}{{.Signature}} { -{{.Body}}} - -{{end}}` - func render(data tmplData) ([]byte, error) { t, err := template.New("store").Funcs(template.FuncMap{ "converterFn": converterFn, diff --git a/cmd/gen/sqlc-wrapper/store.tmpl b/cmd/gen/sqlc-wrapper/store.tmpl new file mode 100644 index 00000000..02bb6fb1 --- /dev/null +++ b/cmd/gen/sqlc-wrapper/store.tmpl @@ -0,0 +1,46 @@ +// Code generated by cmd/gen/sqlc-wrapper. DO NOT EDIT. +package {{.PkgName}} + +import ( + "context" + "database/sql" + "errors" + + "{{.RepoPkg}}" +) + +// Store wraps *Queries and implements repository.Store. +type Store struct { + q *Queries +} + +// NewStore wraps a *Queries to satisfy repository.Store. +func NewStore(q *Queries) repository.Store { + return &Store{q: q} +} + +var errMap = []struct { + from error + to error +}{ + {sql.ErrNoRows, repository.ErrNotFound}, +} + +func mapErr(err error) error { + for _, e := range errMap { + if errors.Is(err, e.from) { + return e.to + } + } + return err +} + +{{range .ModelTypes -}} +func {{converterFn .}}(v {{.}}) repository.{{.}} { + return repository.{{.}}(v) +} +{{end -}} +{{range .Methods}}{{.Signature}} { +{{.Body}}} + +{{end}} diff --git a/internal/repository/sqlite/db.go b/internal/repository/sqlite/db.go index ee310fc2..51a4906a 100644 --- a/internal/repository/sqlite/db.go +++ b/internal/repository/sqlite/db.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.31.0 +// sqlc v1.31.1 package sqlite diff --git a/internal/repository/sqlite/models.go b/internal/repository/sqlite/models.go index caf37f4c..fd6f78da 100644 --- a/internal/repository/sqlite/models.go +++ b/internal/repository/sqlite/models.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.31.0 +// sqlc v1.31.1 package sqlite diff --git a/internal/repository/sqlite/oidc_queries.sql.go b/internal/repository/sqlite/oidc_queries.sql.go index 027ac421..e5d08bc2 100644 --- a/internal/repository/sqlite/oidc_queries.sql.go +++ b/internal/repository/sqlite/oidc_queries.sql.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.31.0 +// sqlc v1.31.1 // source: oidc_queries.sql package sqlite diff --git a/internal/repository/sqlite/session_queries.sql.go b/internal/repository/sqlite/session_queries.sql.go index 4271b727..7792fc4b 100644 --- a/internal/repository/sqlite/session_queries.sql.go +++ b/internal/repository/sqlite/session_queries.sql.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.31.0 +// sqlc v1.31.1 // source: session_queries.sql package sqlite