|
|
|
@@ -12,6 +12,7 @@ package main
|
|
|
|
|
|
|
|
|
|
import (
|
|
|
|
|
"bytes"
|
|
|
|
|
_ "embed"
|
|
|
|
|
"flag"
|
|
|
|
|
"fmt"
|
|
|
|
|
"go/format"
|
|
|
|
@@ -27,13 +28,22 @@ import (
|
|
|
|
|
"golang.org/x/tools/go/packages"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
//go:embed store.tmpl
|
|
|
|
|
var storeSrc string
|
|
|
|
|
|
|
|
|
|
func main() {
|
|
|
|
|
if err := run(); err != nil {
|
|
|
|
|
log.Fatal(err)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func run() error {
|
|
|
|
|
driverPkg := flag.String("pkg", "", "import path of the driver package")
|
|
|
|
|
out := flag.String("out", "store.go", "output filename relative to driver package directory")
|
|
|
|
|
flag.Parse()
|
|
|
|
|
|
|
|
|
|
if *driverPkg == "" {
|
|
|
|
|
log.Fatal("-pkg is required")
|
|
|
|
|
return fmt.Errorf("-pkg is required")
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Resolve the driver package directory so we can overlay the output file
|
|
|
|
@@ -41,8 +51,9 @@ func main() {
|
|
|
|
|
// type-checker and producing cryptic "undefined" errors.
|
|
|
|
|
driverDir, err := pkgDir(*driverPkg)
|
|
|
|
|
if err != nil {
|
|
|
|
|
log.Fatalf("resolve driver dir: %v", err)
|
|
|
|
|
return fmt.Errorf("resolve driver dir: %w", err)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
outPath := filepath.Join(driverDir, *out)
|
|
|
|
|
if filepath.IsAbs(*out) {
|
|
|
|
|
outPath = *out
|
|
|
|
@@ -50,73 +61,81 @@ func main() {
|
|
|
|
|
|
|
|
|
|
// Stub replaces the output file during load so stale generated code is ignored.
|
|
|
|
|
stub := []byte("package " + filepath.Base(driverDir) + "\n")
|
|
|
|
|
|
|
|
|
|
cfg := &packages.Config{
|
|
|
|
|
Mode: packages.NeedName | packages.NeedTypes | packages.NeedSyntax | packages.NeedImports,
|
|
|
|
|
Overlay: map[string][]byte{outPath: stub},
|
|
|
|
|
}
|
|
|
|
|
pkgs, err := packages.Load(cfg, *driverPkg)
|
|
|
|
|
|
|
|
|
|
driverTypePkg, err := loadOnePkg(cfg, *driverPkg)
|
|
|
|
|
if err != nil {
|
|
|
|
|
log.Fatalf("load %s: %v", *driverPkg, err)
|
|
|
|
|
}
|
|
|
|
|
if len(pkgs) != 1 {
|
|
|
|
|
log.Fatalf("expected 1 package, got %d", len(pkgs))
|
|
|
|
|
}
|
|
|
|
|
pkg := pkgs[0]
|
|
|
|
|
if len(pkg.Errors) > 0 {
|
|
|
|
|
for _, e := range pkg.Errors {
|
|
|
|
|
log.Printf("package error: %v", e)
|
|
|
|
|
}
|
|
|
|
|
log.Fatal("package has errors")
|
|
|
|
|
return fmt.Errorf("load driver package: %w", err)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
repoPkg := parentPkg(*driverPkg)
|
|
|
|
|
|
|
|
|
|
// Load the parent (repository) package so we can validate struct shapes.
|
|
|
|
|
repoPkgs, err := packages.Load(cfg, repoPkg)
|
|
|
|
|
repoPkgPath := parentPkg(*driverPkg)
|
|
|
|
|
repoTypePkg, err := loadOnePkg(cfg, repoPkgPath)
|
|
|
|
|
if err != nil {
|
|
|
|
|
log.Fatalf("load repo pkg %s: %v", repoPkg, err)
|
|
|
|
|
}
|
|
|
|
|
if len(repoPkgs) != 1 || len(repoPkgs[0].Errors) > 0 {
|
|
|
|
|
log.Fatalf("could not load repo package %s cleanly", repoPkg)
|
|
|
|
|
}
|
|
|
|
|
if err := validateStructShapes(pkg.Types, repoPkgs[0].Types); err != nil {
|
|
|
|
|
log.Fatalf("struct shape mismatch: %v", err)
|
|
|
|
|
return fmt.Errorf("load repo package: %w", err)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Check *Queries covers every method in repository.Store before generating.
|
|
|
|
|
if err := validateStoreCoverage(pkg.Types, repoPkgs[0].Types); err != nil {
|
|
|
|
|
log.Fatalf("%v", err)
|
|
|
|
|
if err := validateStructShapes(driverTypePkg, repoTypePkg); err != nil {
|
|
|
|
|
return fmt.Errorf("struct shape mismatch: %w", err)
|
|
|
|
|
}
|
|
|
|
|
if err := validateStoreCoverage(driverTypePkg, repoTypePkg); err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
methods, err := collectMethods(pkg.Types)
|
|
|
|
|
methods, err := collectMethods(driverTypePkg)
|
|
|
|
|
if err != nil {
|
|
|
|
|
log.Fatal(err)
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
models, _ := collectTypes(driverTypePkg)
|
|
|
|
|
|
|
|
|
|
models, _ := collectTypes(pkg.Types)
|
|
|
|
|
|
|
|
|
|
data := tmplData{
|
|
|
|
|
PkgName: pkg.Name,
|
|
|
|
|
RepoPkg: repoPkg,
|
|
|
|
|
src, err := render(tmplData{
|
|
|
|
|
PkgName: driverTypePkg.Name(),
|
|
|
|
|
RepoPkg: repoPkgPath,
|
|
|
|
|
ModelTypes: models,
|
|
|
|
|
Methods: renderMethods(methods),
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
src, err := render(data)
|
|
|
|
|
})
|
|
|
|
|
if err != nil {
|
|
|
|
|
log.Fatalf("render: %v", err)
|
|
|
|
|
return fmt.Errorf("render: %w", err)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if err := os.WriteFile(outPath, src, 0644); err != nil {
|
|
|
|
|
log.Fatalf("write %s: %v", outPath, err)
|
|
|
|
|
return fmt.Errorf("write %s: %w", outPath, err)
|
|
|
|
|
}
|
|
|
|
|
fmt.Printf("wrote %s\n", outPath)
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// loadOnePkg loads a single package via cfg and returns its *types.Package,
|
|
|
|
|
// or an error if the package fails to load or has type errors.
|
|
|
|
|
func loadOnePkg(cfg *packages.Config, importPath string) (*types.Package, error) {
|
|
|
|
|
pkgs, err := packages.Load(cfg, importPath)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return nil, fmt.Errorf("load %s: %w", importPath, err)
|
|
|
|
|
}
|
|
|
|
|
if len(pkgs) != 1 {
|
|
|
|
|
return nil, fmt.Errorf("expected 1 package for %s, got %d", importPath, len(pkgs))
|
|
|
|
|
}
|
|
|
|
|
pkg := pkgs[0]
|
|
|
|
|
if len(pkg.Errors) > 0 {
|
|
|
|
|
msgs := make([]string, len(pkg.Errors))
|
|
|
|
|
for i, e := range pkg.Errors {
|
|
|
|
|
msgs[i] = e.Error()
|
|
|
|
|
}
|
|
|
|
|
return nil, fmt.Errorf("package %s has errors:\n %s", importPath, strings.Join(msgs, "\n "))
|
|
|
|
|
}
|
|
|
|
|
return pkg.Types, nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// parentPkg returns the parent import path (everything before the last /).
|
|
|
|
|
// Panics if imp contains no slash — callers are expected to pass driver sub-packages.
|
|
|
|
|
func parentPkg(imp string) string {
|
|
|
|
|
parts := strings.Split(imp, "/")
|
|
|
|
|
return strings.Join(parts[:len(parts)-1], "/")
|
|
|
|
|
i := strings.LastIndex(imp, "/")
|
|
|
|
|
if i < 0 {
|
|
|
|
|
panic(fmt.Sprintf("parentPkg: import path %q has no parent", imp))
|
|
|
|
|
}
|
|
|
|
|
return imp[:i]
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// pkgDir returns the on-disk directory for an import path using `go list`.
|
|
|
|
@@ -128,14 +147,40 @@ func pkgDir(importPath string) (string, error) {
|
|
|
|
|
return strings.TrimSpace(string(out)), nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// scopeStructs returns all named struct types in pkg, excluding the internal
|
|
|
|
|
// sqlc types Queries, DBTX, and Store. Names are returned in sorted order.
|
|
|
|
|
func scopeStructs(pkg *types.Package) (names []string, byName map[string]*types.Struct) {
|
|
|
|
|
byName = make(map[string]*types.Struct)
|
|
|
|
|
for _, name := range pkg.Scope().Names() { // Names() is already sorted
|
|
|
|
|
switch name {
|
|
|
|
|
case "Queries", "DBTX", "Store":
|
|
|
|
|
continue
|
|
|
|
|
}
|
|
|
|
|
obj, ok := pkg.Scope().Lookup(name).(*types.TypeName)
|
|
|
|
|
if !ok {
|
|
|
|
|
continue
|
|
|
|
|
}
|
|
|
|
|
named, ok := obj.Type().(*types.Named)
|
|
|
|
|
if !ok {
|
|
|
|
|
continue
|
|
|
|
|
}
|
|
|
|
|
s, ok := named.Underlying().(*types.Struct)
|
|
|
|
|
if !ok {
|
|
|
|
|
continue
|
|
|
|
|
}
|
|
|
|
|
names = append(names, name)
|
|
|
|
|
byName[name] = s
|
|
|
|
|
}
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// validateStoreCoverage checks that every method declared in repository.Store
|
|
|
|
|
// exists on *Queries in the driver package. Missing methods are reported by
|
|
|
|
|
// name so the developer knows exactly which SQL queries need to be added.
|
|
|
|
|
func validateStoreCoverage(driverPkg, repoPkg *types.Package) error {
|
|
|
|
|
// Collect *Queries method names.
|
|
|
|
|
queriesObj := driverPkg.Scope().Lookup("Queries")
|
|
|
|
|
if queriesObj == nil {
|
|
|
|
|
return fmt.Errorf("Queries type not found in driver package")
|
|
|
|
|
return fmt.Errorf("queries type not found in driver package")
|
|
|
|
|
}
|
|
|
|
|
queriesNamed := queriesObj.Type().(*types.Named)
|
|
|
|
|
queriesMS := types.NewMethodSet(types.NewPointer(queriesNamed))
|
|
|
|
@@ -144,10 +189,9 @@ func validateStoreCoverage(driverPkg, repoPkg *types.Package) error {
|
|
|
|
|
queriesMethods[m.Obj().Name()] = true
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Collect repository.Store interface methods.
|
|
|
|
|
storeObj := repoPkg.Scope().Lookup("Store")
|
|
|
|
|
if storeObj == nil {
|
|
|
|
|
return fmt.Errorf("Store type not found in repository package")
|
|
|
|
|
return fmt.Errorf("store type not found in repository package")
|
|
|
|
|
}
|
|
|
|
|
storeIface, ok := storeObj.Type().Underlying().(*types.Interface)
|
|
|
|
|
if !ok {
|
|
|
|
@@ -155,22 +199,80 @@ func validateStoreCoverage(driverPkg, repoPkg *types.Package) error {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
var missing []string
|
|
|
|
|
for i := range storeIface.NumMethods() {
|
|
|
|
|
name := storeIface.Method(i).Name()
|
|
|
|
|
if !queriesMethods[name] {
|
|
|
|
|
for method := range storeIface.Methods() {
|
|
|
|
|
if name := method.Name(); !queriesMethods[name] {
|
|
|
|
|
missing = append(missing, name)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if len(missing) > 0 {
|
|
|
|
|
sort.Strings(missing)
|
|
|
|
|
return fmt.Errorf(
|
|
|
|
|
"driver *Queries is missing %d method(s) required by repository.Store:\n - %s\n\nRun sqlc generate to regenerate query methods, or add the missing SQL queries.",
|
|
|
|
|
"driver *Queries is missing %d method(s) required by repository.Store:\n - %s\n\nRun sqlc generate to regenerate query methods, or add the missing SQL queries",
|
|
|
|
|
len(missing), strings.Join(missing, "\n - "),
|
|
|
|
|
)
|
|
|
|
|
}
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// validateStructShapes checks that every model/params struct in the driver
|
|
|
|
|
// package has fields that exactly match the corresponding type in the repo
|
|
|
|
|
// (parent) package. This catches drift between sqlc-generated types and the
|
|
|
|
|
// canonical repository types before a broken cast reaches the compiler.
|
|
|
|
|
func validateStructShapes(driverPkg, repoPkg *types.Package) error {
|
|
|
|
|
_, repoStructs := scopeStructs(repoPkg)
|
|
|
|
|
driverNames, driverStructs := scopeStructs(driverPkg)
|
|
|
|
|
|
|
|
|
|
var errs []string
|
|
|
|
|
for _, name := range driverNames {
|
|
|
|
|
repoStruct, ok := repoStructs[name]
|
|
|
|
|
if !ok {
|
|
|
|
|
// Driver has a type not in repo — fine (e.g. internal helpers).
|
|
|
|
|
continue
|
|
|
|
|
}
|
|
|
|
|
if err := compareStructs(name, driverStructs[name], repoStruct); err != nil {
|
|
|
|
|
errs = append(errs, err.Error())
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if len(errs) > 0 {
|
|
|
|
|
sort.Strings(errs)
|
|
|
|
|
return fmt.Errorf("%s", strings.Join(errs, "\n "))
|
|
|
|
|
}
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func compareStructs(name string, driver, repo *types.Struct) error {
|
|
|
|
|
if driver.NumFields() != repo.NumFields() {
|
|
|
|
|
return fmt.Errorf("%s: field count mismatch (driver=%d, repo=%d)",
|
|
|
|
|
name, driver.NumFields(), repo.NumFields())
|
|
|
|
|
}
|
|
|
|
|
for i := range driver.NumFields() {
|
|
|
|
|
df := driver.Field(i)
|
|
|
|
|
rf := repo.Field(i)
|
|
|
|
|
if df.Name() != rf.Name() {
|
|
|
|
|
return fmt.Errorf("%s: field %d name mismatch (driver=%q, repo=%q)",
|
|
|
|
|
name, i, df.Name(), rf.Name())
|
|
|
|
|
}
|
|
|
|
|
if !types.Identical(df.Type(), rf.Type()) {
|
|
|
|
|
return fmt.Errorf("%s.%s: type mismatch (driver=%s, repo=%s)",
|
|
|
|
|
name, df.Name(), df.Type(), rf.Type())
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// collectTypes returns model and params struct names from the driver package.
|
|
|
|
|
func collectTypes(pkg *types.Package) (models []string, params []string) {
|
|
|
|
|
names, _ := scopeStructs(pkg)
|
|
|
|
|
for _, name := range names {
|
|
|
|
|
if strings.HasSuffix(name, "Params") {
|
|
|
|
|
params = append(params, name)
|
|
|
|
|
} else {
|
|
|
|
|
models = append(models, name)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
type methodInfo struct {
|
|
|
|
|
Name string
|
|
|
|
|
Params []paramInfo
|
|
|
|
@@ -227,10 +329,11 @@ func collectMethods(pkg *types.Package) ([]methodInfo, error) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func makeParam(name string, t types.Type, driverPath string) paramInfo {
|
|
|
|
|
pi := paramInfo{Name: name}
|
|
|
|
|
pi.TypeStr = localName(t, driverPath)
|
|
|
|
|
pi.RepoType = repoName(t, driverPath)
|
|
|
|
|
return pi
|
|
|
|
|
return paramInfo{
|
|
|
|
|
Name: name,
|
|
|
|
|
TypeStr: localName(t, driverPath),
|
|
|
|
|
RepoType: repoName(t, driverPath),
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func makeResult(t types.Type, driverPath string) resultInfo {
|
|
|
|
@@ -266,133 +369,27 @@ func repoName(t types.Type, driverPath string) string {
|
|
|
|
|
return ""
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func collectTypes(pkg *types.Package) (models []string, params []string) {
|
|
|
|
|
for _, name := range pkg.Scope().Names() {
|
|
|
|
|
obj := pkg.Scope().Lookup(name)
|
|
|
|
|
if obj == nil {
|
|
|
|
|
continue
|
|
|
|
|
}
|
|
|
|
|
tn, ok := obj.(*types.TypeName)
|
|
|
|
|
if !ok {
|
|
|
|
|
continue
|
|
|
|
|
}
|
|
|
|
|
named, ok := tn.Type().(*types.Named)
|
|
|
|
|
if !ok {
|
|
|
|
|
continue
|
|
|
|
|
}
|
|
|
|
|
if _, ok := named.Underlying().(*types.Struct); !ok {
|
|
|
|
|
continue
|
|
|
|
|
}
|
|
|
|
|
switch name {
|
|
|
|
|
case "Queries", "DBTX", "Store":
|
|
|
|
|
continue
|
|
|
|
|
}
|
|
|
|
|
if strings.HasSuffix(name, "Params") {
|
|
|
|
|
params = append(params, name)
|
|
|
|
|
} else {
|
|
|
|
|
models = append(models, name)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// validateStructShapes checks that every model/params struct in the driver
|
|
|
|
|
// package has fields that exactly match the corresponding type in the repo
|
|
|
|
|
// (parent) package. This catches drift between sqlc-generated types and the
|
|
|
|
|
// canonical repository types before a broken cast reaches the compiler.
|
|
|
|
|
func validateStructShapes(driverPkg, repoPkg *types.Package) error {
|
|
|
|
|
var errs []string
|
|
|
|
|
for _, name := range driverPkg.Scope().Names() {
|
|
|
|
|
obj := driverPkg.Scope().Lookup(name)
|
|
|
|
|
if obj == nil {
|
|
|
|
|
continue
|
|
|
|
|
}
|
|
|
|
|
tn, ok := obj.(*types.TypeName)
|
|
|
|
|
if !ok {
|
|
|
|
|
continue
|
|
|
|
|
}
|
|
|
|
|
named, ok := tn.Type().(*types.Named)
|
|
|
|
|
if !ok {
|
|
|
|
|
continue
|
|
|
|
|
}
|
|
|
|
|
driverStruct, ok := named.Underlying().(*types.Struct)
|
|
|
|
|
if !ok {
|
|
|
|
|
continue
|
|
|
|
|
}
|
|
|
|
|
switch name {
|
|
|
|
|
case "Queries", "DBTX", "Store":
|
|
|
|
|
continue
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
repoObj := repoPkg.Scope().Lookup(name)
|
|
|
|
|
if repoObj == nil {
|
|
|
|
|
// Driver has a type not in repo — that's fine (e.g. internal helpers).
|
|
|
|
|
continue
|
|
|
|
|
}
|
|
|
|
|
repoNamed, ok := repoObj.Type().(*types.Named)
|
|
|
|
|
if !ok {
|
|
|
|
|
continue
|
|
|
|
|
}
|
|
|
|
|
repoStruct, ok := repoNamed.Underlying().(*types.Struct)
|
|
|
|
|
if !ok {
|
|
|
|
|
errs = append(errs, fmt.Sprintf("%s: repo type is not a struct", name))
|
|
|
|
|
continue
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if err := compareStructs(name, driverStruct, repoStruct); err != nil {
|
|
|
|
|
errs = append(errs, err.Error())
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if len(errs) > 0 {
|
|
|
|
|
return fmt.Errorf("%s", strings.Join(errs, "\n "))
|
|
|
|
|
}
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func compareStructs(name string, driver, repo *types.Struct) error {
|
|
|
|
|
if driver.NumFields() != repo.NumFields() {
|
|
|
|
|
return fmt.Errorf("%s: field count mismatch (driver=%d, repo=%d)",
|
|
|
|
|
name, driver.NumFields(), repo.NumFields())
|
|
|
|
|
}
|
|
|
|
|
for i := range driver.NumFields() {
|
|
|
|
|
df := driver.Field(i)
|
|
|
|
|
rf := repo.Field(i)
|
|
|
|
|
if df.Name() != rf.Name() {
|
|
|
|
|
return fmt.Errorf("%s: field %d name mismatch (driver=%q, repo=%q)",
|
|
|
|
|
name, i, df.Name(), rf.Name())
|
|
|
|
|
}
|
|
|
|
|
if !types.Identical(df.Type(), rf.Type()) {
|
|
|
|
|
return fmt.Errorf("%s.%s: type mismatch (driver=%s, repo=%s)",
|
|
|
|
|
name, df.Name(), df.Type(), rf.Type())
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// converterFn: "Session" -> "sessionToRepo"
|
|
|
|
|
// converterFn maps a type name to its converter function name: "Session" → "sessionToRepo".
|
|
|
|
|
func converterFn(s string) string {
|
|
|
|
|
if s == "" {
|
|
|
|
|
return ""
|
|
|
|
|
}
|
|
|
|
|
r := []rune(s)
|
|
|
|
|
r[0] = []rune(strings.ToLower(string(r[0])))[0]
|
|
|
|
|
return string(r) + "ToRepo"
|
|
|
|
|
return strings.ToLower(s[:1]) + s[1:] + "ToRepo"
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// renderedMethod is the pre-built method body passed to the template.
|
|
|
|
|
// renderedMethod holds pre-built signature and body strings passed to the template.
|
|
|
|
|
type renderedMethod struct {
|
|
|
|
|
Signature string
|
|
|
|
|
Body string
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// renderMethods converts []methodInfo into fully pre-rendered signature+body strings.
|
|
|
|
|
func renderMethods(methods []methodInfo) []renderedMethod {
|
|
|
|
|
var out []renderedMethod
|
|
|
|
|
for _, m := range methods {
|
|
|
|
|
out = append(out, renderedMethod{
|
|
|
|
|
out := make([]renderedMethod, len(methods))
|
|
|
|
|
for i, m := range methods {
|
|
|
|
|
out[i] = renderedMethod{
|
|
|
|
|
Signature: buildSig(m),
|
|
|
|
|
Body: buildBody(m),
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return out
|
|
|
|
|
}
|
|
|
|
@@ -429,7 +426,7 @@ func buildSig(m methodInfo) string {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func callArgs(m methodInfo) string {
|
|
|
|
|
var args []string
|
|
|
|
|
args := make([]string, 0, len(m.Params))
|
|
|
|
|
for _, p := range m.Params {
|
|
|
|
|
if p.RepoType != "" {
|
|
|
|
|
// convert repo type → driver type: DriverType(arg)
|
|
|
|
@@ -444,25 +441,62 @@ func callArgs(m methodInfo) string {
|
|
|
|
|
return "ctx, " + strings.Join(args, ", ")
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// bodyTemplates holds the per-shape method body templates, parsed once at init.
|
|
|
|
|
var bodyTemplates = template.Must(
|
|
|
|
|
template.New("bodies").Parse(`
|
|
|
|
|
{{define "void"}} return mapErr({{.Call}})
|
|
|
|
|
{{end}}
|
|
|
|
|
|
|
|
|
|
{{define "scalar"}} r, err := {{.Call}}
|
|
|
|
|
if err != nil {
|
|
|
|
|
return {{.RepoType}}{}, mapErr(err)
|
|
|
|
|
}
|
|
|
|
|
return {{.Converter}}(r), nil
|
|
|
|
|
{{end}}
|
|
|
|
|
|
|
|
|
|
{{define "slice"}} rows, err := {{.Call}}
|
|
|
|
|
if err != nil {
|
|
|
|
|
return nil, mapErr(err)
|
|
|
|
|
}
|
|
|
|
|
out := make([]{{.RepoType}}, len(rows))
|
|
|
|
|
for i, row := range rows {
|
|
|
|
|
out[i] = {{.Converter}}(row)
|
|
|
|
|
}
|
|
|
|
|
return out, nil
|
|
|
|
|
{{end}}`),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
type bodyData struct {
|
|
|
|
|
Call string
|
|
|
|
|
RepoType string
|
|
|
|
|
Converter string
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func buildBody(m methodInfo) string {
|
|
|
|
|
call := "s.q." + m.Name + "(" + callArgs(m) + ")"
|
|
|
|
|
|
|
|
|
|
// no repo-typed result → direct return
|
|
|
|
|
if len(m.Results) == 0 || m.Results[0].RepoType == "" {
|
|
|
|
|
return "\treturn mapErr(" + call + ")\n"
|
|
|
|
|
var (
|
|
|
|
|
name string
|
|
|
|
|
data bodyData
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
switch {
|
|
|
|
|
case len(m.Results) == 0 || m.Results[0].RepoType == "":
|
|
|
|
|
name = "void"
|
|
|
|
|
data = bodyData{Call: call}
|
|
|
|
|
case m.Results[0].IsSlice:
|
|
|
|
|
name = "slice"
|
|
|
|
|
data = bodyData{Call: call, RepoType: m.Results[0].RepoType, Converter: converterFn(m.Results[0].TypeStr)}
|
|
|
|
|
default:
|
|
|
|
|
name = "scalar"
|
|
|
|
|
data = bodyData{Call: call, RepoType: m.Results[0].RepoType, Converter: converterFn(m.Results[0].TypeStr)}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
r := m.Results[0]
|
|
|
|
|
if r.IsSlice {
|
|
|
|
|
return fmt.Sprintf(
|
|
|
|
|
"\trows, err := %s\n\tif err != nil {\n\t\treturn nil, mapErr(err)\n\t}\n\tout := make([]%s, len(rows))\n\tfor i, row := range rows {\n\t\tout[i] = %s(row)\n\t}\n\treturn out, nil\n",
|
|
|
|
|
call, r.RepoType, converterFn(r.TypeStr),
|
|
|
|
|
)
|
|
|
|
|
var buf bytes.Buffer
|
|
|
|
|
if err := bodyTemplates.ExecuteTemplate(&buf, name, data); err != nil {
|
|
|
|
|
panic(fmt.Sprintf("buildBody %s: %v", name, err))
|
|
|
|
|
}
|
|
|
|
|
return fmt.Sprintf(
|
|
|
|
|
"\tr, err := %s\n\tif err != nil {\n\t\treturn %s{}, mapErr(err)\n\t}\n\treturn %s(r), nil\n",
|
|
|
|
|
call, r.RepoType, converterFn(r.TypeStr),
|
|
|
|
|
)
|
|
|
|
|
return buf.String()
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
type tmplData struct {
|
|
|
|
@@ -472,53 +506,6 @@ type tmplData struct {
|
|
|
|
|
Methods []renderedMethod
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const storeSrc = `// Code generated by cmd/gen/sqlc-wrapper. DO NOT EDIT.
|
|
|
|
|
package {{.PkgName}}
|
|
|
|
|
|
|
|
|
|
import (
|
|
|
|
|
"context"
|
|
|
|
|
"database/sql"
|
|
|
|
|
"errors"
|
|
|
|
|
|
|
|
|
|
"{{.RepoPkg}}"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
// Store wraps *Queries and implements repository.Store.
|
|
|
|
|
type Store struct {
|
|
|
|
|
q *Queries
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// NewStore wraps a *Queries to satisfy repository.Store.
|
|
|
|
|
func NewStore(q *Queries) repository.Store {
|
|
|
|
|
return &Store{q: q}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
var errMap = []struct {
|
|
|
|
|
from error
|
|
|
|
|
to error
|
|
|
|
|
}{
|
|
|
|
|
{sql.ErrNoRows, repository.ErrNotFound},
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func mapErr(err error) error {
|
|
|
|
|
for _, e := range errMap {
|
|
|
|
|
if errors.Is(err, e.from) {
|
|
|
|
|
return e.to
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
{{range .ModelTypes -}}
|
|
|
|
|
func {{converterFn .}}(v {{.}}) repository.{{.}} {
|
|
|
|
|
return repository.{{.}}(v)
|
|
|
|
|
}
|
|
|
|
|
{{end -}}
|
|
|
|
|
{{range .Methods}}{{.Signature}} {
|
|
|
|
|
{{.Body}}}
|
|
|
|
|
|
|
|
|
|
{{end}}`
|
|
|
|
|
|
|
|
|
|
func render(data tmplData) ([]byte, error) {
|
|
|
|
|
t, err := template.New("store").Funcs(template.FuncMap{
|
|
|
|
|
"converterFn": converterFn,
|