diff --git a/Makefile b/Makefile index 80e7a629..88022407 100644 --- a/Makefile +++ b/Makefile @@ -84,4 +84,5 @@ sql: # Go gen generate: + go run ./gen go generate ./internal/repository/... diff --git a/cmd/gen/sqlc-wrapper/store.tmpl b/cmd/gen/sqlc-wrapper/store.tmpl deleted file mode 100644 index 02bb6fb1..00000000 --- a/cmd/gen/sqlc-wrapper/store.tmpl +++ /dev/null @@ -1,46 +0,0 @@ -// Code generated by cmd/gen/sqlc-wrapper. DO NOT EDIT. -package {{.PkgName}} - -import ( - "context" - "database/sql" - "errors" - - "{{.RepoPkg}}" -) - -// Store wraps *Queries and implements repository.Store. -type Store struct { - q *Queries -} - -// NewStore wraps a *Queries to satisfy repository.Store. -func NewStore(q *Queries) repository.Store { - return &Store{q: q} -} - -var errMap = []struct { - from error - to error -}{ - {sql.ErrNoRows, repository.ErrNotFound}, -} - -func mapErr(err error) error { - for _, e := range errMap { - if errors.Is(err, e.from) { - return e.to - } - } - return err -} - -{{range .ModelTypes -}} -func {{converterFn .}}(v {{.}}) repository.{{.}} { - return repository.{{.}}(v) -} -{{end -}} -{{range .Methods}}{{.Signature}} { -{{.Body}}} - -{{end}} diff --git a/docker-compose.dev.yml b/docker-compose.dev.yml index 8dd5fc1d..eb4f7ce8 100644 --- a/docker-compose.dev.yml +++ b/docker-compose.dev.yml @@ -1,9 +1,10 @@ services: traefik: image: traefik:v3.6 - command: --api.insecure=true --providers.docker + command: --api.insecure=true --providers.docker --entrypoints.web.address=:80 --entrypoints.websecure.address=:443 ports: - 80:80 + - 443:443 volumes: - /var/run/docker.sock:/var/run/docker.sock @@ -25,6 +26,8 @@ services: labels: traefik.enable: true traefik.http.routers.tinyauth.rule: Host(`tinyauth.127.0.0.1.sslip.io`) + traefik.http.routers.tinyauth.entrypoints: websecure + traefik.http.routers.tinyauth.tls: true tinyauth-backend: build: diff --git a/cmd/gen/sqlc-wrapper/sqlc_wrapper.go b/gen/sqlc-wrapper/sqlc_wrapper.go similarity index 86% rename from cmd/gen/sqlc-wrapper/sqlc_wrapper.go rename to gen/sqlc-wrapper/sqlc_wrapper.go index 0592d20c..a7a75eb4 100644 --- a/cmd/gen/sqlc-wrapper/sqlc_wrapper.go +++ b/gen/sqlc-wrapper/sqlc_wrapper.go @@ -32,6 +32,7 @@ import ( var storeSrc string func main() { + fmt.Println("sqlc-wrapper: generating store.go files for sqlc driver packages...") if err := run(); err != nil { log.Fatal(err) } @@ -88,13 +89,11 @@ func run() error { if err != nil { return err } - models, _ := collectTypes(driverTypePkg) src, err := render(tmplData{ - PkgName: driverTypePkg.Name(), - RepoPkg: repoPkgPath, - ModelTypes: models, - Methods: renderMethods(methods), + PkgName: driverTypePkg.Name(), + RepoPkg: repoPkgPath, + Methods: renderMethods(methods), }) if err != nil { return fmt.Errorf("render: %w", err) @@ -260,19 +259,6 @@ func compareStructs(name string, driver, repo *types.Struct) error { return nil } -// collectTypes returns model and params struct names from the driver package. -func collectTypes(pkg *types.Package) (models []string, params []string) { - names, _ := scopeStructs(pkg) - for _, name := range names { - if strings.HasSuffix(name, "Params") { - params = append(params, name) - } else { - models = append(models, name) - } - } - return -} - type methodInfo struct { Name string Params []paramInfo @@ -369,14 +355,6 @@ func repoName(t types.Type, driverPath string) string { return "" } -// converterFn maps a type name to its converter function name: "Session" → "sessionToRepo". -func converterFn(s string) string { - if s == "" { - return "" - } - return strings.ToLower(s[:1]) + s[1:] + "ToRepo" -} - // renderedMethod holds pre-built signature and body strings passed to the template. type renderedMethod struct { Signature string @@ -441,35 +419,11 @@ func callArgs(m methodInfo) string { return "ctx, " + strings.Join(args, ", ") } -// bodyTemplates holds the per-shape method body templates, parsed once at init. -var bodyTemplates = template.Must( - template.New("bodies").Parse(` -{{define "void"}} return mapErr({{.Call}}) -{{end}} - -{{define "scalar"}} r, err := {{.Call}} - if err != nil { - return {{.RepoType}}{}, mapErr(err) - } - return {{.Converter}}(r), nil -{{end}} - -{{define "slice"}} rows, err := {{.Call}} - if err != nil { - return nil, mapErr(err) - } - out := make([]{{.RepoType}}, len(rows)) - for i, row := range rows { - out[i] = {{.Converter}}(row) - } - return out, nil -{{end}}`), -) +var bodyTmpl = template.Must(template.New("store").Parse(storeSrc)) type bodyData struct { - Call string - RepoType string - Converter string + Call string + RepoType string } func buildBody(m methodInfo) string { @@ -486,36 +440,28 @@ func buildBody(m methodInfo) string { data = bodyData{Call: call} case m.Results[0].IsSlice: name = "slice" - data = bodyData{Call: call, RepoType: m.Results[0].RepoType, Converter: converterFn(m.Results[0].TypeStr)} + data = bodyData{Call: call, RepoType: m.Results[0].RepoType} default: name = "scalar" - data = bodyData{Call: call, RepoType: m.Results[0].RepoType, Converter: converterFn(m.Results[0].TypeStr)} + data = bodyData{Call: call, RepoType: m.Results[0].RepoType} } var buf bytes.Buffer - if err := bodyTemplates.ExecuteTemplate(&buf, name, data); err != nil { + if err := bodyTmpl.ExecuteTemplate(&buf, name, data); err != nil { panic(fmt.Sprintf("buildBody %s: %v", name, err)) } return buf.String() } type tmplData struct { - PkgName string - RepoPkg string - ModelTypes []string - Methods []renderedMethod + PkgName string + RepoPkg string + Methods []renderedMethod } func render(data tmplData) ([]byte, error) { - t, err := template.New("store").Funcs(template.FuncMap{ - "converterFn": converterFn, - }).Parse(storeSrc) - if err != nil { - return nil, fmt.Errorf("parse template: %w", err) - } - var buf bytes.Buffer - if err := t.Execute(&buf, data); err != nil { + if err := bodyTmpl.Execute(&buf, data); err != nil { return nil, fmt.Errorf("execute template: %w", err) } diff --git a/gen/sqlc-wrapper/store.tmpl b/gen/sqlc-wrapper/store.tmpl new file mode 100644 index 00000000..fa4acf01 --- /dev/null +++ b/gen/sqlc-wrapper/store.tmpl @@ -0,0 +1,59 @@ +// Code generated by cmd/gen/sqlc-wrapper. DO NOT EDIT. +package {{.PkgName}} + +import ( + "context" + "database/sql" + "errors" + + "{{.RepoPkg}}" +) + +// Store wraps *Queries and implements repository.Store. +type Store struct { + q *Queries +} + +// NewStore wraps a *Queries to satisfy repository.Store. +func NewStore(q *Queries) repository.Store { + return &Store{q: q} +} + +var errorMap = map[error]error{ + sql.ErrNoRows: repository.ErrNotFound, +} + +func mapErr(err error) error { + for from, to := range errorMap { + if errors.Is(err, from) { + return to + } + } + return err +} + +{{range .Methods}}{{.Signature}} { +{{.Body}}} + +{{end}} + +{{- define "void"}} return mapErr({{.Call}}) +{{end}} + +{{- define "scalar"}} r, err := {{.Call}} + if err != nil { + return {{.RepoType}}{}, mapErr(err) + } + return {{.RepoType}}(r), nil +{{end}} + +{{- define "slice"}} rows, err := {{.Call}} + if err != nil { + return nil, mapErr(err) + } + out := make([]{{.RepoType}}, len(rows)) + for i, row := range rows { + out[i] = {{.RepoType}}(row) + } + return out, nil +{{end}} diff --git a/internal/repository/sqlite/generate.go b/internal/repository/sqlite/generate.go index 5f6011f1..ed695567 100644 --- a/internal/repository/sqlite/generate.go +++ b/internal/repository/sqlite/generate.go @@ -1,3 +1,3 @@ package sqlite -//go:generate go run github.com/tinyauthapp/tinyauth/cmd/gen/sqlc-wrapper -pkg github.com/tinyauthapp/tinyauth/internal/repository/sqlite +//go:generate go run github.com/tinyauthapp/tinyauth/gen/sqlc-wrapper -pkg github.com/tinyauthapp/tinyauth/internal/repository/sqlite diff --git a/internal/repository/sqlite/store.go b/internal/repository/sqlite/store.go index f316efa4..e7ce1792 100644 --- a/internal/repository/sqlite/store.go +++ b/internal/repository/sqlite/store.go @@ -19,40 +19,25 @@ func NewStore(q *Queries) repository.Store { return &Store{q: q} } -var errMap = []struct { - from error - to error -}{ - {sql.ErrNoRows, repository.ErrNotFound}, +var errorMap = map[error]error{ + sql.ErrNoRows: repository.ErrNotFound, } func mapErr(err error) error { - for _, e := range errMap { - if errors.Is(err, e.from) { - return e.to + for from, to := range errorMap { + if errors.Is(err, from) { + return to } } return err } -func oidcCodeToRepo(v OidcCode) repository.OidcCode { - return repository.OidcCode(v) -} -func oidcTokenToRepo(v OidcToken) repository.OidcToken { - return repository.OidcToken(v) -} -func oidcUserinfoToRepo(v OidcUserinfo) repository.OidcUserinfo { - return repository.OidcUserinfo(v) -} -func sessionToRepo(v Session) repository.Session { - return repository.Session(v) -} func (s *Store) CreateOidcCode(ctx context.Context, arg repository.CreateOidcCodeParams) (repository.OidcCode, error) { r, err := s.q.CreateOidcCode(ctx, CreateOidcCodeParams(arg)) if err != nil { return repository.OidcCode{}, mapErr(err) } - return oidcCodeToRepo(r), nil + return repository.OidcCode(r), nil } func (s *Store) CreateOidcToken(ctx context.Context, arg repository.CreateOidcTokenParams) (repository.OidcToken, error) { @@ -60,7 +45,7 @@ func (s *Store) CreateOidcToken(ctx context.Context, arg repository.CreateOidcTo if err != nil { return repository.OidcToken{}, mapErr(err) } - return oidcTokenToRepo(r), nil + return repository.OidcToken(r), nil } func (s *Store) CreateOidcUserInfo(ctx context.Context, arg repository.CreateOidcUserInfoParams) (repository.OidcUserinfo, error) { @@ -68,7 +53,7 @@ func (s *Store) CreateOidcUserInfo(ctx context.Context, arg repository.CreateOid if err != nil { return repository.OidcUserinfo{}, mapErr(err) } - return oidcUserinfoToRepo(r), nil + return repository.OidcUserinfo(r), nil } func (s *Store) CreateSession(ctx context.Context, arg repository.CreateSessionParams) (repository.Session, error) { @@ -76,7 +61,7 @@ func (s *Store) CreateSession(ctx context.Context, arg repository.CreateSessionP if err != nil { return repository.Session{}, mapErr(err) } - return sessionToRepo(r), nil + return repository.Session(r), nil } func (s *Store) DeleteExpiredOidcCodes(ctx context.Context, expiresAt int64) ([]repository.OidcCode, error) { @@ -86,7 +71,7 @@ func (s *Store) DeleteExpiredOidcCodes(ctx context.Context, expiresAt int64) ([] } out := make([]repository.OidcCode, len(rows)) for i, row := range rows { - out[i] = oidcCodeToRepo(row) + out[i] = repository.OidcCode(row) } return out, nil } @@ -98,7 +83,7 @@ func (s *Store) DeleteExpiredOidcTokens(ctx context.Context, arg repository.Dele } out := make([]repository.OidcToken, len(rows)) for i, row := range rows { - out[i] = oidcTokenToRepo(row) + out[i] = repository.OidcToken(row) } return out, nil } @@ -140,7 +125,7 @@ func (s *Store) GetOidcCode(ctx context.Context, codeHash string) (repository.Oi if err != nil { return repository.OidcCode{}, mapErr(err) } - return oidcCodeToRepo(r), nil + return repository.OidcCode(r), nil } func (s *Store) GetOidcCodeBySub(ctx context.Context, sub string) (repository.OidcCode, error) { @@ -148,7 +133,7 @@ func (s *Store) GetOidcCodeBySub(ctx context.Context, sub string) (repository.Oi if err != nil { return repository.OidcCode{}, mapErr(err) } - return oidcCodeToRepo(r), nil + return repository.OidcCode(r), nil } func (s *Store) GetOidcCodeBySubUnsafe(ctx context.Context, sub string) (repository.OidcCode, error) { @@ -156,7 +141,7 @@ func (s *Store) GetOidcCodeBySubUnsafe(ctx context.Context, sub string) (reposit if err != nil { return repository.OidcCode{}, mapErr(err) } - return oidcCodeToRepo(r), nil + return repository.OidcCode(r), nil } func (s *Store) GetOidcCodeUnsafe(ctx context.Context, codeHash string) (repository.OidcCode, error) { @@ -164,7 +149,7 @@ func (s *Store) GetOidcCodeUnsafe(ctx context.Context, codeHash string) (reposit if err != nil { return repository.OidcCode{}, mapErr(err) } - return oidcCodeToRepo(r), nil + return repository.OidcCode(r), nil } func (s *Store) GetOidcToken(ctx context.Context, accessTokenHash string) (repository.OidcToken, error) { @@ -172,7 +157,7 @@ func (s *Store) GetOidcToken(ctx context.Context, accessTokenHash string) (repos if err != nil { return repository.OidcToken{}, mapErr(err) } - return oidcTokenToRepo(r), nil + return repository.OidcToken(r), nil } func (s *Store) GetOidcTokenByRefreshToken(ctx context.Context, refreshTokenHash string) (repository.OidcToken, error) { @@ -180,7 +165,7 @@ func (s *Store) GetOidcTokenByRefreshToken(ctx context.Context, refreshTokenHash if err != nil { return repository.OidcToken{}, mapErr(err) } - return oidcTokenToRepo(r), nil + return repository.OidcToken(r), nil } func (s *Store) GetOidcTokenBySub(ctx context.Context, sub string) (repository.OidcToken, error) { @@ -188,7 +173,7 @@ func (s *Store) GetOidcTokenBySub(ctx context.Context, sub string) (repository.O if err != nil { return repository.OidcToken{}, mapErr(err) } - return oidcTokenToRepo(r), nil + return repository.OidcToken(r), nil } func (s *Store) GetOidcUserInfo(ctx context.Context, sub string) (repository.OidcUserinfo, error) { @@ -196,7 +181,7 @@ func (s *Store) GetOidcUserInfo(ctx context.Context, sub string) (repository.Oid if err != nil { return repository.OidcUserinfo{}, mapErr(err) } - return oidcUserinfoToRepo(r), nil + return repository.OidcUserinfo(r), nil } func (s *Store) GetSession(ctx context.Context, uuid string) (repository.Session, error) { @@ -204,7 +189,7 @@ func (s *Store) GetSession(ctx context.Context, uuid string) (repository.Session if err != nil { return repository.Session{}, mapErr(err) } - return sessionToRepo(r), nil + return repository.Session(r), nil } func (s *Store) UpdateOidcTokenByRefreshToken(ctx context.Context, arg repository.UpdateOidcTokenByRefreshTokenParams) (repository.OidcToken, error) { @@ -212,7 +197,7 @@ func (s *Store) UpdateOidcTokenByRefreshToken(ctx context.Context, arg repositor if err != nil { return repository.OidcToken{}, mapErr(err) } - return oidcTokenToRepo(r), nil + return repository.OidcToken(r), nil } func (s *Store) UpdateSession(ctx context.Context, arg repository.UpdateSessionParams) (repository.Session, error) { @@ -220,5 +205,5 @@ func (s *Store) UpdateSession(ctx context.Context, arg repository.UpdateSessionP if err != nil { return repository.Session{}, mapErr(err) } - return sessionToRepo(r), nil + return repository.Session(r), nil }