more wrapper cleanup

This commit is contained in:
Scott McKendry
2026-05-15 19:48:41 +12:00
parent 4149084329
commit 5f5b188511
7 changed files with 101 additions and 153 deletions
+1
View File
@@ -84,4 +84,5 @@ sql:
# Go gen # Go gen
generate: generate:
go run ./gen
go generate ./internal/repository/... go generate ./internal/repository/...
-46
View File
@@ -1,46 +0,0 @@
// Code generated by cmd/gen/sqlc-wrapper. DO NOT EDIT.
package {{.PkgName}}
import (
"context"
"database/sql"
"errors"
"{{.RepoPkg}}"
)
// Store wraps *Queries and implements repository.Store.
type Store struct {
q *Queries
}
// NewStore wraps a *Queries to satisfy repository.Store.
func NewStore(q *Queries) repository.Store {
return &Store{q: q}
}
var errMap = []struct {
from error
to error
}{
{sql.ErrNoRows, repository.ErrNotFound},
}
func mapErr(err error) error {
for _, e := range errMap {
if errors.Is(err, e.from) {
return e.to
}
}
return err
}
{{range .ModelTypes -}}
func {{converterFn .}}(v {{.}}) repository.{{.}} {
return repository.{{.}}(v)
}
{{end -}}
{{range .Methods}}{{.Signature}} {
{{.Body}}}
{{end}}
+4 -1
View File
@@ -1,9 +1,10 @@
services: services:
traefik: traefik:
image: traefik:v3.6 image: traefik:v3.6
command: --api.insecure=true --providers.docker command: --api.insecure=true --providers.docker --entrypoints.web.address=:80 --entrypoints.websecure.address=:443
ports: ports:
- 80:80 - 80:80
- 443:443
volumes: volumes:
- /var/run/docker.sock:/var/run/docker.sock - /var/run/docker.sock:/var/run/docker.sock
@@ -25,6 +26,8 @@ services:
labels: labels:
traefik.enable: true traefik.enable: true
traefik.http.routers.tinyauth.rule: Host(`tinyauth.127.0.0.1.sslip.io`) traefik.http.routers.tinyauth.rule: Host(`tinyauth.127.0.0.1.sslip.io`)
traefik.http.routers.tinyauth.entrypoints: websecure
traefik.http.routers.tinyauth.tls: true
tinyauth-backend: tinyauth-backend:
build: build:
@@ -32,6 +32,7 @@ import (
var storeSrc string var storeSrc string
func main() { func main() {
fmt.Println("sqlc-wrapper: generating store.go files for sqlc driver packages...")
if err := run(); err != nil { if err := run(); err != nil {
log.Fatal(err) log.Fatal(err)
} }
@@ -88,13 +89,11 @@ func run() error {
if err != nil { if err != nil {
return err return err
} }
models, _ := collectTypes(driverTypePkg)
src, err := render(tmplData{ src, err := render(tmplData{
PkgName: driverTypePkg.Name(), PkgName: driverTypePkg.Name(),
RepoPkg: repoPkgPath, RepoPkg: repoPkgPath,
ModelTypes: models, Methods: renderMethods(methods),
Methods: renderMethods(methods),
}) })
if err != nil { if err != nil {
return fmt.Errorf("render: %w", err) return fmt.Errorf("render: %w", err)
@@ -260,19 +259,6 @@ func compareStructs(name string, driver, repo *types.Struct) error {
return nil return nil
} }
// collectTypes returns model and params struct names from the driver package.
func collectTypes(pkg *types.Package) (models []string, params []string) {
names, _ := scopeStructs(pkg)
for _, name := range names {
if strings.HasSuffix(name, "Params") {
params = append(params, name)
} else {
models = append(models, name)
}
}
return
}
type methodInfo struct { type methodInfo struct {
Name string Name string
Params []paramInfo Params []paramInfo
@@ -369,14 +355,6 @@ func repoName(t types.Type, driverPath string) string {
return "" return ""
} }
// converterFn maps a type name to its converter function name: "Session" → "sessionToRepo".
func converterFn(s string) string {
if s == "" {
return ""
}
return strings.ToLower(s[:1]) + s[1:] + "ToRepo"
}
// renderedMethod holds pre-built signature and body strings passed to the template. // renderedMethod holds pre-built signature and body strings passed to the template.
type renderedMethod struct { type renderedMethod struct {
Signature string Signature string
@@ -441,35 +419,11 @@ func callArgs(m methodInfo) string {
return "ctx, " + strings.Join(args, ", ") return "ctx, " + strings.Join(args, ", ")
} }
// bodyTemplates holds the per-shape method body templates, parsed once at init. var bodyTmpl = template.Must(template.New("store").Parse(storeSrc))
var bodyTemplates = template.Must(
template.New("bodies").Parse(`
{{define "void"}} return mapErr({{.Call}})
{{end}}
{{define "scalar"}} r, err := {{.Call}}
if err != nil {
return {{.RepoType}}{}, mapErr(err)
}
return {{.Converter}}(r), nil
{{end}}
{{define "slice"}} rows, err := {{.Call}}
if err != nil {
return nil, mapErr(err)
}
out := make([]{{.RepoType}}, len(rows))
for i, row := range rows {
out[i] = {{.Converter}}(row)
}
return out, nil
{{end}}`),
)
type bodyData struct { type bodyData struct {
Call string Call string
RepoType string RepoType string
Converter string
} }
func buildBody(m methodInfo) string { func buildBody(m methodInfo) string {
@@ -486,36 +440,28 @@ func buildBody(m methodInfo) string {
data = bodyData{Call: call} data = bodyData{Call: call}
case m.Results[0].IsSlice: case m.Results[0].IsSlice:
name = "slice" name = "slice"
data = bodyData{Call: call, RepoType: m.Results[0].RepoType, Converter: converterFn(m.Results[0].TypeStr)} data = bodyData{Call: call, RepoType: m.Results[0].RepoType}
default: default:
name = "scalar" name = "scalar"
data = bodyData{Call: call, RepoType: m.Results[0].RepoType, Converter: converterFn(m.Results[0].TypeStr)} data = bodyData{Call: call, RepoType: m.Results[0].RepoType}
} }
var buf bytes.Buffer var buf bytes.Buffer
if err := bodyTemplates.ExecuteTemplate(&buf, name, data); err != nil { if err := bodyTmpl.ExecuteTemplate(&buf, name, data); err != nil {
panic(fmt.Sprintf("buildBody %s: %v", name, err)) panic(fmt.Sprintf("buildBody %s: %v", name, err))
} }
return buf.String() return buf.String()
} }
type tmplData struct { type tmplData struct {
PkgName string PkgName string
RepoPkg string RepoPkg string
ModelTypes []string Methods []renderedMethod
Methods []renderedMethod
} }
func render(data tmplData) ([]byte, error) { func render(data tmplData) ([]byte, error) {
t, err := template.New("store").Funcs(template.FuncMap{
"converterFn": converterFn,
}).Parse(storeSrc)
if err != nil {
return nil, fmt.Errorf("parse template: %w", err)
}
var buf bytes.Buffer var buf bytes.Buffer
if err := t.Execute(&buf, data); err != nil { if err := bodyTmpl.Execute(&buf, data); err != nil {
return nil, fmt.Errorf("execute template: %w", err) return nil, fmt.Errorf("execute template: %w", err)
} }
+59
View File
@@ -0,0 +1,59 @@
// Code generated by cmd/gen/sqlc-wrapper. DO NOT EDIT.
package {{.PkgName}}
import (
"context"
"database/sql"
"errors"
"{{.RepoPkg}}"
)
// Store wraps *Queries and implements repository.Store.
type Store struct {
q *Queries
}
// NewStore wraps a *Queries to satisfy repository.Store.
func NewStore(q *Queries) repository.Store {
return &Store{q: q}
}
var errorMap = map[error]error{
sql.ErrNoRows: repository.ErrNotFound,
}
func mapErr(err error) error {
for from, to := range errorMap {
if errors.Is(err, from) {
return to
}
}
return err
}
{{range .Methods}}{{.Signature}} {
{{.Body}}}
{{end}}
{{- define "void"}} return mapErr({{.Call}})
{{end}}
{{- define "scalar"}} r, err := {{.Call}}
if err != nil {
return {{.RepoType}}{}, mapErr(err)
}
return {{.RepoType}}(r), nil
{{end}}
{{- define "slice"}} rows, err := {{.Call}}
if err != nil {
return nil, mapErr(err)
}
out := make([]{{.RepoType}}, len(rows))
for i, row := range rows {
out[i] = {{.RepoType}}(row)
}
return out, nil
{{end}}
+1 -1
View File
@@ -1,3 +1,3 @@
package sqlite package sqlite
//go:generate go run github.com/tinyauthapp/tinyauth/cmd/gen/sqlc-wrapper -pkg github.com/tinyauthapp/tinyauth/internal/repository/sqlite //go:generate go run github.com/tinyauthapp/tinyauth/gen/sqlc-wrapper -pkg github.com/tinyauthapp/tinyauth/internal/repository/sqlite
+22 -37
View File
@@ -19,40 +19,25 @@ func NewStore(q *Queries) repository.Store {
return &Store{q: q} return &Store{q: q}
} }
var errMap = []struct { var errorMap = map[error]error{
from error sql.ErrNoRows: repository.ErrNotFound,
to error
}{
{sql.ErrNoRows, repository.ErrNotFound},
} }
func mapErr(err error) error { func mapErr(err error) error {
for _, e := range errMap { for from, to := range errorMap {
if errors.Is(err, e.from) { if errors.Is(err, from) {
return e.to return to
} }
} }
return err return err
} }
func oidcCodeToRepo(v OidcCode) repository.OidcCode {
return repository.OidcCode(v)
}
func oidcTokenToRepo(v OidcToken) repository.OidcToken {
return repository.OidcToken(v)
}
func oidcUserinfoToRepo(v OidcUserinfo) repository.OidcUserinfo {
return repository.OidcUserinfo(v)
}
func sessionToRepo(v Session) repository.Session {
return repository.Session(v)
}
func (s *Store) CreateOidcCode(ctx context.Context, arg repository.CreateOidcCodeParams) (repository.OidcCode, error) { func (s *Store) CreateOidcCode(ctx context.Context, arg repository.CreateOidcCodeParams) (repository.OidcCode, error) {
r, err := s.q.CreateOidcCode(ctx, CreateOidcCodeParams(arg)) r, err := s.q.CreateOidcCode(ctx, CreateOidcCodeParams(arg))
if err != nil { if err != nil {
return repository.OidcCode{}, mapErr(err) return repository.OidcCode{}, mapErr(err)
} }
return oidcCodeToRepo(r), nil return repository.OidcCode(r), nil
} }
func (s *Store) CreateOidcToken(ctx context.Context, arg repository.CreateOidcTokenParams) (repository.OidcToken, error) { func (s *Store) CreateOidcToken(ctx context.Context, arg repository.CreateOidcTokenParams) (repository.OidcToken, error) {
@@ -60,7 +45,7 @@ func (s *Store) CreateOidcToken(ctx context.Context, arg repository.CreateOidcTo
if err != nil { if err != nil {
return repository.OidcToken{}, mapErr(err) return repository.OidcToken{}, mapErr(err)
} }
return oidcTokenToRepo(r), nil return repository.OidcToken(r), nil
} }
func (s *Store) CreateOidcUserInfo(ctx context.Context, arg repository.CreateOidcUserInfoParams) (repository.OidcUserinfo, error) { func (s *Store) CreateOidcUserInfo(ctx context.Context, arg repository.CreateOidcUserInfoParams) (repository.OidcUserinfo, error) {
@@ -68,7 +53,7 @@ func (s *Store) CreateOidcUserInfo(ctx context.Context, arg repository.CreateOid
if err != nil { if err != nil {
return repository.OidcUserinfo{}, mapErr(err) return repository.OidcUserinfo{}, mapErr(err)
} }
return oidcUserinfoToRepo(r), nil return repository.OidcUserinfo(r), nil
} }
func (s *Store) CreateSession(ctx context.Context, arg repository.CreateSessionParams) (repository.Session, error) { func (s *Store) CreateSession(ctx context.Context, arg repository.CreateSessionParams) (repository.Session, error) {
@@ -76,7 +61,7 @@ func (s *Store) CreateSession(ctx context.Context, arg repository.CreateSessionP
if err != nil { if err != nil {
return repository.Session{}, mapErr(err) return repository.Session{}, mapErr(err)
} }
return sessionToRepo(r), nil return repository.Session(r), nil
} }
func (s *Store) DeleteExpiredOidcCodes(ctx context.Context, expiresAt int64) ([]repository.OidcCode, error) { func (s *Store) DeleteExpiredOidcCodes(ctx context.Context, expiresAt int64) ([]repository.OidcCode, error) {
@@ -86,7 +71,7 @@ func (s *Store) DeleteExpiredOidcCodes(ctx context.Context, expiresAt int64) ([]
} }
out := make([]repository.OidcCode, len(rows)) out := make([]repository.OidcCode, len(rows))
for i, row := range rows { for i, row := range rows {
out[i] = oidcCodeToRepo(row) out[i] = repository.OidcCode(row)
} }
return out, nil return out, nil
} }
@@ -98,7 +83,7 @@ func (s *Store) DeleteExpiredOidcTokens(ctx context.Context, arg repository.Dele
} }
out := make([]repository.OidcToken, len(rows)) out := make([]repository.OidcToken, len(rows))
for i, row := range rows { for i, row := range rows {
out[i] = oidcTokenToRepo(row) out[i] = repository.OidcToken(row)
} }
return out, nil return out, nil
} }
@@ -140,7 +125,7 @@ func (s *Store) GetOidcCode(ctx context.Context, codeHash string) (repository.Oi
if err != nil { if err != nil {
return repository.OidcCode{}, mapErr(err) return repository.OidcCode{}, mapErr(err)
} }
return oidcCodeToRepo(r), nil return repository.OidcCode(r), nil
} }
func (s *Store) GetOidcCodeBySub(ctx context.Context, sub string) (repository.OidcCode, error) { func (s *Store) GetOidcCodeBySub(ctx context.Context, sub string) (repository.OidcCode, error) {
@@ -148,7 +133,7 @@ func (s *Store) GetOidcCodeBySub(ctx context.Context, sub string) (repository.Oi
if err != nil { if err != nil {
return repository.OidcCode{}, mapErr(err) return repository.OidcCode{}, mapErr(err)
} }
return oidcCodeToRepo(r), nil return repository.OidcCode(r), nil
} }
func (s *Store) GetOidcCodeBySubUnsafe(ctx context.Context, sub string) (repository.OidcCode, error) { func (s *Store) GetOidcCodeBySubUnsafe(ctx context.Context, sub string) (repository.OidcCode, error) {
@@ -156,7 +141,7 @@ func (s *Store) GetOidcCodeBySubUnsafe(ctx context.Context, sub string) (reposit
if err != nil { if err != nil {
return repository.OidcCode{}, mapErr(err) return repository.OidcCode{}, mapErr(err)
} }
return oidcCodeToRepo(r), nil return repository.OidcCode(r), nil
} }
func (s *Store) GetOidcCodeUnsafe(ctx context.Context, codeHash string) (repository.OidcCode, error) { func (s *Store) GetOidcCodeUnsafe(ctx context.Context, codeHash string) (repository.OidcCode, error) {
@@ -164,7 +149,7 @@ func (s *Store) GetOidcCodeUnsafe(ctx context.Context, codeHash string) (reposit
if err != nil { if err != nil {
return repository.OidcCode{}, mapErr(err) return repository.OidcCode{}, mapErr(err)
} }
return oidcCodeToRepo(r), nil return repository.OidcCode(r), nil
} }
func (s *Store) GetOidcToken(ctx context.Context, accessTokenHash string) (repository.OidcToken, error) { func (s *Store) GetOidcToken(ctx context.Context, accessTokenHash string) (repository.OidcToken, error) {
@@ -172,7 +157,7 @@ func (s *Store) GetOidcToken(ctx context.Context, accessTokenHash string) (repos
if err != nil { if err != nil {
return repository.OidcToken{}, mapErr(err) return repository.OidcToken{}, mapErr(err)
} }
return oidcTokenToRepo(r), nil return repository.OidcToken(r), nil
} }
func (s *Store) GetOidcTokenByRefreshToken(ctx context.Context, refreshTokenHash string) (repository.OidcToken, error) { func (s *Store) GetOidcTokenByRefreshToken(ctx context.Context, refreshTokenHash string) (repository.OidcToken, error) {
@@ -180,7 +165,7 @@ func (s *Store) GetOidcTokenByRefreshToken(ctx context.Context, refreshTokenHash
if err != nil { if err != nil {
return repository.OidcToken{}, mapErr(err) return repository.OidcToken{}, mapErr(err)
} }
return oidcTokenToRepo(r), nil return repository.OidcToken(r), nil
} }
func (s *Store) GetOidcTokenBySub(ctx context.Context, sub string) (repository.OidcToken, error) { func (s *Store) GetOidcTokenBySub(ctx context.Context, sub string) (repository.OidcToken, error) {
@@ -188,7 +173,7 @@ func (s *Store) GetOidcTokenBySub(ctx context.Context, sub string) (repository.O
if err != nil { if err != nil {
return repository.OidcToken{}, mapErr(err) return repository.OidcToken{}, mapErr(err)
} }
return oidcTokenToRepo(r), nil return repository.OidcToken(r), nil
} }
func (s *Store) GetOidcUserInfo(ctx context.Context, sub string) (repository.OidcUserinfo, error) { func (s *Store) GetOidcUserInfo(ctx context.Context, sub string) (repository.OidcUserinfo, error) {
@@ -196,7 +181,7 @@ func (s *Store) GetOidcUserInfo(ctx context.Context, sub string) (repository.Oid
if err != nil { if err != nil {
return repository.OidcUserinfo{}, mapErr(err) return repository.OidcUserinfo{}, mapErr(err)
} }
return oidcUserinfoToRepo(r), nil return repository.OidcUserinfo(r), nil
} }
func (s *Store) GetSession(ctx context.Context, uuid string) (repository.Session, error) { func (s *Store) GetSession(ctx context.Context, uuid string) (repository.Session, error) {
@@ -204,7 +189,7 @@ func (s *Store) GetSession(ctx context.Context, uuid string) (repository.Session
if err != nil { if err != nil {
return repository.Session{}, mapErr(err) return repository.Session{}, mapErr(err)
} }
return sessionToRepo(r), nil return repository.Session(r), nil
} }
func (s *Store) UpdateOidcTokenByRefreshToken(ctx context.Context, arg repository.UpdateOidcTokenByRefreshTokenParams) (repository.OidcToken, error) { func (s *Store) UpdateOidcTokenByRefreshToken(ctx context.Context, arg repository.UpdateOidcTokenByRefreshTokenParams) (repository.OidcToken, error) {
@@ -212,7 +197,7 @@ func (s *Store) UpdateOidcTokenByRefreshToken(ctx context.Context, arg repositor
if err != nil { if err != nil {
return repository.OidcToken{}, mapErr(err) return repository.OidcToken{}, mapErr(err)
} }
return oidcTokenToRepo(r), nil return repository.OidcToken(r), nil
} }
func (s *Store) UpdateSession(ctx context.Context, arg repository.UpdateSessionParams) (repository.Session, error) { func (s *Store) UpdateSession(ctx context.Context, arg repository.UpdateSessionParams) (repository.Session, error) {
@@ -220,5 +205,5 @@ func (s *Store) UpdateSession(ctx context.Context, arg repository.UpdateSessionP
if err != nil { if err != nil {
return repository.Session{}, mapErr(err) return repository.Session{}, mapErr(err)
} }
return sessionToRepo(r), nil return repository.Session(r), nil
} }