mirror of
https://github.com/steveiliop56/tinyauth.git
synced 2026-05-10 14:28:12 +00:00
Compare commits
5 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 374b87964f | |||
| db911a41c3 | |||
| ef8bbd8c9f | |||
| 9566105245 | |||
| 7e9d21462d |
@@ -26,6 +26,12 @@ jobs:
|
||||
- name: Go dependencies
|
||||
run: go mod download
|
||||
|
||||
- name: Check codegen is up to date
|
||||
run: |
|
||||
go generate ./internal/repository/...
|
||||
git diff --exit-code -- internal/repository/
|
||||
git status --porcelain -- internal/repository/ | grep -q . && echo "untracked files in internal/repository/" && exit 1 || true
|
||||
|
||||
- name: Install frontend dependencies
|
||||
run: |
|
||||
cd frontend
|
||||
@@ -49,11 +55,6 @@ jobs:
|
||||
run: |
|
||||
cp -r frontend/dist internal/assets/dist
|
||||
|
||||
- name: Lint backend
|
||||
uses: golangci/golangci-lint-action@v9
|
||||
with:
|
||||
version: v2.12
|
||||
|
||||
- name: Run tests
|
||||
run: go test -coverprofile=coverage.txt -v ./...
|
||||
|
||||
|
||||
@@ -1,17 +0,0 @@
|
||||
version: "2"
|
||||
linters:
|
||||
settings:
|
||||
errcheck:
|
||||
exclude-functions:
|
||||
- (http.ResponseWriter).Write
|
||||
- (http.ResponseWriter).WriteString
|
||||
- (github.com/gin-gonic/gin.ResponseWriter).Write
|
||||
- (github.com/gin-gonic/gin.ResponseWriter).WriteString
|
||||
exclusions:
|
||||
rules:
|
||||
- linters:
|
||||
- errcheck
|
||||
text: "//nolint:errcheck"
|
||||
- linters:
|
||||
- staticcheck
|
||||
text: "//nolint:staticcheck"
|
||||
@@ -84,4 +84,4 @@ sql:
|
||||
|
||||
# Go gen
|
||||
generate:
|
||||
go run ./gen
|
||||
go generate ./internal/repository/...
|
||||
|
||||
@@ -0,0 +1,527 @@
|
||||
// gen/sqlc-wrapper generates store.go wrapper files for each sqlc driver package under
|
||||
// internal/repository/<driver>/. Run via:
|
||||
//
|
||||
// go generate ./internal/repository/...
|
||||
//
|
||||
// The generator introspects *Queries methods and the model/params types in the
|
||||
// driver package, then emits a store.go that wraps *Queries so it satisfies
|
||||
// repository.Store using the canonical shared types in the parent package.
|
||||
// This generator is specific to sqlc-generated drivers. Non-sqlc drivers should
|
||||
// implement repository.Store directly by hand.
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
_ "embed"
|
||||
"flag"
|
||||
"fmt"
|
||||
"go/format"
|
||||
"go/types"
|
||||
"log"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
"text/template"
|
||||
|
||||
"golang.org/x/tools/go/packages"
|
||||
)
|
||||
|
||||
//go:embed store.tmpl
|
||||
var storeSrc string
|
||||
|
||||
func main() {
|
||||
if err := run(); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func run() error {
|
||||
driverPkg := flag.String("pkg", "", "import path of the driver package")
|
||||
out := flag.String("out", "store.go", "output filename relative to driver package directory")
|
||||
flag.Parse()
|
||||
|
||||
if *driverPkg == "" {
|
||||
return fmt.Errorf("-pkg is required")
|
||||
}
|
||||
|
||||
// Resolve the driver package directory so we can overlay the output file
|
||||
// with a valid stub. This prevents a stale store.go from poisoning the
|
||||
// type-checker and producing cryptic "undefined" errors.
|
||||
driverDir, err := pkgDir(*driverPkg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("resolve driver dir: %w", err)
|
||||
}
|
||||
|
||||
outPath := filepath.Join(driverDir, *out)
|
||||
if filepath.IsAbs(*out) {
|
||||
outPath = *out
|
||||
}
|
||||
|
||||
// Stub replaces the output file during load so stale generated code is ignored.
|
||||
stub := []byte("package " + filepath.Base(driverDir) + "\n")
|
||||
cfg := &packages.Config{
|
||||
Mode: packages.NeedName | packages.NeedTypes | packages.NeedSyntax | packages.NeedImports,
|
||||
Overlay: map[string][]byte{outPath: stub},
|
||||
}
|
||||
|
||||
driverTypePkg, err := loadOnePkg(cfg, *driverPkg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("load driver package: %w", err)
|
||||
}
|
||||
|
||||
repoPkgPath := parentPkg(*driverPkg)
|
||||
repoTypePkg, err := loadOnePkg(cfg, repoPkgPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("load repo package: %w", err)
|
||||
}
|
||||
|
||||
if err := validateStructShapes(driverTypePkg, repoTypePkg); err != nil {
|
||||
return fmt.Errorf("struct shape mismatch: %w", err)
|
||||
}
|
||||
if err := validateStoreCoverage(driverTypePkg, repoTypePkg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
methods, err := collectMethods(driverTypePkg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
models, _ := collectTypes(driverTypePkg)
|
||||
|
||||
src, err := render(tmplData{
|
||||
PkgName: driverTypePkg.Name(),
|
||||
RepoPkg: repoPkgPath,
|
||||
ModelTypes: models,
|
||||
Methods: renderMethods(methods),
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("render: %w", err)
|
||||
}
|
||||
|
||||
if err := os.WriteFile(outPath, src, 0644); err != nil {
|
||||
return fmt.Errorf("write %s: %w", outPath, err)
|
||||
}
|
||||
fmt.Printf("wrote %s\n", outPath)
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadOnePkg loads a single package via cfg and returns its *types.Package,
|
||||
// or an error if the package fails to load or has type errors.
|
||||
func loadOnePkg(cfg *packages.Config, importPath string) (*types.Package, error) {
|
||||
pkgs, err := packages.Load(cfg, importPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load %s: %w", importPath, err)
|
||||
}
|
||||
if len(pkgs) != 1 {
|
||||
return nil, fmt.Errorf("expected 1 package for %s, got %d", importPath, len(pkgs))
|
||||
}
|
||||
pkg := pkgs[0]
|
||||
if len(pkg.Errors) > 0 {
|
||||
msgs := make([]string, len(pkg.Errors))
|
||||
for i, e := range pkg.Errors {
|
||||
msgs[i] = e.Error()
|
||||
}
|
||||
return nil, fmt.Errorf("package %s has errors:\n %s", importPath, strings.Join(msgs, "\n "))
|
||||
}
|
||||
return pkg.Types, nil
|
||||
}
|
||||
|
||||
// parentPkg returns the parent import path (everything before the last /).
|
||||
// Panics if imp contains no slash — callers are expected to pass driver sub-packages.
|
||||
func parentPkg(imp string) string {
|
||||
i := strings.LastIndex(imp, "/")
|
||||
if i < 0 {
|
||||
panic(fmt.Sprintf("parentPkg: import path %q has no parent", imp))
|
||||
}
|
||||
return imp[:i]
|
||||
}
|
||||
|
||||
// pkgDir returns the on-disk directory for an import path using `go list`.
|
||||
func pkgDir(importPath string) (string, error) {
|
||||
out, err := exec.Command("go", "list", "-f", "{{.Dir}}", importPath).Output()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("go list %s: %w", importPath, err)
|
||||
}
|
||||
return strings.TrimSpace(string(out)), nil
|
||||
}
|
||||
|
||||
// scopeStructs returns all named struct types in pkg, excluding the internal
|
||||
// sqlc types Queries, DBTX, and Store. Names are returned in sorted order.
|
||||
func scopeStructs(pkg *types.Package) (names []string, byName map[string]*types.Struct) {
|
||||
byName = make(map[string]*types.Struct)
|
||||
for _, name := range pkg.Scope().Names() { // Names() is already sorted
|
||||
switch name {
|
||||
case "Queries", "DBTX", "Store":
|
||||
continue
|
||||
}
|
||||
obj, ok := pkg.Scope().Lookup(name).(*types.TypeName)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
named, ok := obj.Type().(*types.Named)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
s, ok := named.Underlying().(*types.Struct)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
names = append(names, name)
|
||||
byName[name] = s
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// validateStoreCoverage checks that every method declared in repository.Store
|
||||
// exists on *Queries in the driver package. Missing methods are reported by
|
||||
// name so the developer knows exactly which SQL queries need to be added.
|
||||
func validateStoreCoverage(driverPkg, repoPkg *types.Package) error {
|
||||
queriesObj := driverPkg.Scope().Lookup("Queries")
|
||||
if queriesObj == nil {
|
||||
return fmt.Errorf("queries type not found in driver package")
|
||||
}
|
||||
queriesNamed := queriesObj.Type().(*types.Named)
|
||||
queriesMS := types.NewMethodSet(types.NewPointer(queriesNamed))
|
||||
queriesMethods := make(map[string]bool)
|
||||
for m := range queriesMS.Methods() {
|
||||
queriesMethods[m.Obj().Name()] = true
|
||||
}
|
||||
|
||||
storeObj := repoPkg.Scope().Lookup("Store")
|
||||
if storeObj == nil {
|
||||
return fmt.Errorf("store type not found in repository package")
|
||||
}
|
||||
storeIface, ok := storeObj.Type().Underlying().(*types.Interface)
|
||||
if !ok {
|
||||
return fmt.Errorf("repository.Store is not an interface")
|
||||
}
|
||||
|
||||
var missing []string
|
||||
for method := range storeIface.Methods() {
|
||||
if name := method.Name(); !queriesMethods[name] {
|
||||
missing = append(missing, name)
|
||||
}
|
||||
}
|
||||
if len(missing) > 0 {
|
||||
sort.Strings(missing)
|
||||
return fmt.Errorf(
|
||||
"driver *Queries is missing %d method(s) required by repository.Store:\n - %s\n\nRun sqlc generate to regenerate query methods, or add the missing SQL queries",
|
||||
len(missing), strings.Join(missing, "\n - "),
|
||||
)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateStructShapes checks that every model/params struct in the driver
|
||||
// package has fields that exactly match the corresponding type in the repo
|
||||
// (parent) package. This catches drift between sqlc-generated types and the
|
||||
// canonical repository types before a broken cast reaches the compiler.
|
||||
func validateStructShapes(driverPkg, repoPkg *types.Package) error {
|
||||
_, repoStructs := scopeStructs(repoPkg)
|
||||
driverNames, driverStructs := scopeStructs(driverPkg)
|
||||
|
||||
var errs []string
|
||||
for _, name := range driverNames {
|
||||
repoStruct, ok := repoStructs[name]
|
||||
if !ok {
|
||||
// Driver has a type not in repo — fine (e.g. internal helpers).
|
||||
continue
|
||||
}
|
||||
if err := compareStructs(name, driverStructs[name], repoStruct); err != nil {
|
||||
errs = append(errs, err.Error())
|
||||
}
|
||||
}
|
||||
if len(errs) > 0 {
|
||||
sort.Strings(errs)
|
||||
return fmt.Errorf("%s", strings.Join(errs, "\n "))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func compareStructs(name string, driver, repo *types.Struct) error {
|
||||
if driver.NumFields() != repo.NumFields() {
|
||||
return fmt.Errorf("%s: field count mismatch (driver=%d, repo=%d)",
|
||||
name, driver.NumFields(), repo.NumFields())
|
||||
}
|
||||
for i := range driver.NumFields() {
|
||||
df := driver.Field(i)
|
||||
rf := repo.Field(i)
|
||||
if df.Name() != rf.Name() {
|
||||
return fmt.Errorf("%s: field %d name mismatch (driver=%q, repo=%q)",
|
||||
name, i, df.Name(), rf.Name())
|
||||
}
|
||||
if !types.Identical(df.Type(), rf.Type()) {
|
||||
return fmt.Errorf("%s.%s: type mismatch (driver=%s, repo=%s)",
|
||||
name, df.Name(), df.Type(), rf.Type())
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// collectTypes returns model and params struct names from the driver package.
|
||||
func collectTypes(pkg *types.Package) (models []string, params []string) {
|
||||
names, _ := scopeStructs(pkg)
|
||||
for _, name := range names {
|
||||
if strings.HasSuffix(name, "Params") {
|
||||
params = append(params, name)
|
||||
} else {
|
||||
models = append(models, name)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
type methodInfo struct {
|
||||
Name string
|
||||
Params []paramInfo
|
||||
Results []resultInfo
|
||||
}
|
||||
|
||||
type paramInfo struct {
|
||||
Name string
|
||||
TypeStr string // local (unqualified) type name
|
||||
RepoType string // "repository.X" if this is a driver model/params type; else ""
|
||||
}
|
||||
|
||||
type resultInfo struct {
|
||||
TypeStr string
|
||||
IsSlice bool
|
||||
RepoType string // "repository.X" if driver type; else ""
|
||||
}
|
||||
|
||||
func collectMethods(pkg *types.Package) ([]methodInfo, error) {
|
||||
obj := pkg.Scope().Lookup("Queries")
|
||||
if obj == nil {
|
||||
return nil, fmt.Errorf("queries type not found in %s", pkg.Path())
|
||||
}
|
||||
named, ok := obj.Type().(*types.Named)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("queries is not a named type")
|
||||
}
|
||||
ms := types.NewMethodSet(types.NewPointer(named))
|
||||
|
||||
var out []methodInfo
|
||||
for method := range ms.Methods() {
|
||||
fn, ok := method.Obj().(*types.Func)
|
||||
if !ok || fn.Name() == "WithTx" {
|
||||
continue
|
||||
}
|
||||
sig := fn.Type().(*types.Signature)
|
||||
mi := methodInfo{Name: fn.Name()}
|
||||
|
||||
// params: skip receiver + first (context.Context)
|
||||
for i := 1; i < sig.Params().Len(); i++ {
|
||||
p := sig.Params().At(i)
|
||||
mi.Params = append(mi.Params, makeParam(p.Name(), p.Type(), pkg.Path()))
|
||||
}
|
||||
// results: skip error
|
||||
for r := range sig.Results().Variables() {
|
||||
if r.Type().String() == "error" {
|
||||
continue
|
||||
}
|
||||
mi.Results = append(mi.Results, makeResult(r.Type(), pkg.Path()))
|
||||
}
|
||||
out = append(out, mi)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func makeParam(name string, t types.Type, driverPath string) paramInfo {
|
||||
return paramInfo{
|
||||
Name: name,
|
||||
TypeStr: localName(t, driverPath),
|
||||
RepoType: repoName(t, driverPath),
|
||||
}
|
||||
}
|
||||
|
||||
func makeResult(t types.Type, driverPath string) resultInfo {
|
||||
ri := resultInfo{}
|
||||
if sl, ok := t.(*types.Slice); ok {
|
||||
ri.IsSlice = true
|
||||
t = sl.Elem()
|
||||
}
|
||||
ri.TypeStr = localName(t, driverPath)
|
||||
ri.RepoType = repoName(t, driverPath)
|
||||
return ri
|
||||
}
|
||||
|
||||
func localName(t types.Type, driverPath string) string {
|
||||
named, ok := t.(*types.Named)
|
||||
if !ok {
|
||||
return types.TypeString(t, nil)
|
||||
}
|
||||
if named.Obj().Pkg() != nil && named.Obj().Pkg().Path() == driverPath {
|
||||
return named.Obj().Name()
|
||||
}
|
||||
return types.TypeString(t, func(p *types.Package) string { return p.Name() })
|
||||
}
|
||||
|
||||
func repoName(t types.Type, driverPath string) string {
|
||||
named, ok := t.(*types.Named)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
if named.Obj().Pkg() != nil && named.Obj().Pkg().Path() == driverPath {
|
||||
return "repository." + named.Obj().Name()
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// converterFn maps a type name to its converter function name: "Session" → "sessionToRepo".
|
||||
func converterFn(s string) string {
|
||||
if s == "" {
|
||||
return ""
|
||||
}
|
||||
return strings.ToLower(s[:1]) + s[1:] + "ToRepo"
|
||||
}
|
||||
|
||||
// renderedMethod holds pre-built signature and body strings passed to the template.
|
||||
type renderedMethod struct {
|
||||
Signature string
|
||||
Body string
|
||||
}
|
||||
|
||||
func renderMethods(methods []methodInfo) []renderedMethod {
|
||||
out := make([]renderedMethod, len(methods))
|
||||
for i, m := range methods {
|
||||
out[i] = renderedMethod{
|
||||
Signature: buildSig(m),
|
||||
Body: buildBody(m),
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func buildSig(m methodInfo) string {
|
||||
var sb strings.Builder
|
||||
sb.WriteString("func (s *Store) ")
|
||||
sb.WriteString(m.Name)
|
||||
sb.WriteString("(ctx context.Context")
|
||||
for _, p := range m.Params {
|
||||
sb.WriteString(", ")
|
||||
sb.WriteString(p.Name)
|
||||
sb.WriteString(" ")
|
||||
if p.RepoType != "" {
|
||||
sb.WriteString(p.RepoType)
|
||||
} else {
|
||||
sb.WriteString(p.TypeStr)
|
||||
}
|
||||
}
|
||||
sb.WriteString(") (")
|
||||
for _, r := range m.Results {
|
||||
if r.IsSlice {
|
||||
sb.WriteString("[]")
|
||||
}
|
||||
if r.RepoType != "" {
|
||||
sb.WriteString(r.RepoType)
|
||||
} else {
|
||||
sb.WriteString(r.TypeStr)
|
||||
}
|
||||
sb.WriteString(", ")
|
||||
}
|
||||
sb.WriteString("error)")
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
func callArgs(m methodInfo) string {
|
||||
args := make([]string, 0, len(m.Params))
|
||||
for _, p := range m.Params {
|
||||
if p.RepoType != "" {
|
||||
// convert repo type → driver type: DriverType(arg)
|
||||
args = append(args, p.TypeStr+"("+p.Name+")")
|
||||
} else {
|
||||
args = append(args, p.Name)
|
||||
}
|
||||
}
|
||||
if len(args) == 0 {
|
||||
return "ctx"
|
||||
}
|
||||
return "ctx, " + strings.Join(args, ", ")
|
||||
}
|
||||
|
||||
// bodyTemplates holds the per-shape method body templates, parsed once at init.
|
||||
var bodyTemplates = template.Must(
|
||||
template.New("bodies").Parse(`
|
||||
{{define "void"}} return mapErr({{.Call}})
|
||||
{{end}}
|
||||
|
||||
{{define "scalar"}} r, err := {{.Call}}
|
||||
if err != nil {
|
||||
return {{.RepoType}}{}, mapErr(err)
|
||||
}
|
||||
return {{.Converter}}(r), nil
|
||||
{{end}}
|
||||
|
||||
{{define "slice"}} rows, err := {{.Call}}
|
||||
if err != nil {
|
||||
return nil, mapErr(err)
|
||||
}
|
||||
out := make([]{{.RepoType}}, len(rows))
|
||||
for i, row := range rows {
|
||||
out[i] = {{.Converter}}(row)
|
||||
}
|
||||
return out, nil
|
||||
{{end}}`),
|
||||
)
|
||||
|
||||
type bodyData struct {
|
||||
Call string
|
||||
RepoType string
|
||||
Converter string
|
||||
}
|
||||
|
||||
func buildBody(m methodInfo) string {
|
||||
call := "s.q." + m.Name + "(" + callArgs(m) + ")"
|
||||
|
||||
var (
|
||||
name string
|
||||
data bodyData
|
||||
)
|
||||
|
||||
switch {
|
||||
case len(m.Results) == 0 || m.Results[0].RepoType == "":
|
||||
name = "void"
|
||||
data = bodyData{Call: call}
|
||||
case m.Results[0].IsSlice:
|
||||
name = "slice"
|
||||
data = bodyData{Call: call, RepoType: m.Results[0].RepoType, Converter: converterFn(m.Results[0].TypeStr)}
|
||||
default:
|
||||
name = "scalar"
|
||||
data = bodyData{Call: call, RepoType: m.Results[0].RepoType, Converter: converterFn(m.Results[0].TypeStr)}
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
if err := bodyTemplates.ExecuteTemplate(&buf, name, data); err != nil {
|
||||
panic(fmt.Sprintf("buildBody %s: %v", name, err))
|
||||
}
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
type tmplData struct {
|
||||
PkgName string
|
||||
RepoPkg string
|
||||
ModelTypes []string
|
||||
Methods []renderedMethod
|
||||
}
|
||||
|
||||
func render(data tmplData) ([]byte, error) {
|
||||
t, err := template.New("store").Funcs(template.FuncMap{
|
||||
"converterFn": converterFn,
|
||||
}).Parse(storeSrc)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse template: %w", err)
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
if err := t.Execute(&buf, data); err != nil {
|
||||
return nil, fmt.Errorf("execute template: %w", err)
|
||||
}
|
||||
|
||||
formatted, err := format.Source(buf.Bytes())
|
||||
if err != nil {
|
||||
return buf.Bytes(), fmt.Errorf("format source: %w\nraw:\n%s", err, buf.String())
|
||||
}
|
||||
return formatted, nil
|
||||
}
|
||||
@@ -0,0 +1,46 @@
|
||||
// Code generated by cmd/gen/sqlc-wrapper. DO NOT EDIT.
|
||||
package {{.PkgName}}
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
|
||||
"{{.RepoPkg}}"
|
||||
)
|
||||
|
||||
// Store wraps *Queries and implements repository.Store.
|
||||
type Store struct {
|
||||
q *Queries
|
||||
}
|
||||
|
||||
// NewStore wraps a *Queries to satisfy repository.Store.
|
||||
func NewStore(q *Queries) repository.Store {
|
||||
return &Store{q: q}
|
||||
}
|
||||
|
||||
var errMap = []struct {
|
||||
from error
|
||||
to error
|
||||
}{
|
||||
{sql.ErrNoRows, repository.ErrNotFound},
|
||||
}
|
||||
|
||||
func mapErr(err error) error {
|
||||
for _, e := range errMap {
|
||||
if errors.Is(err, e.from) {
|
||||
return e.to
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
{{range .ModelTypes -}}
|
||||
func {{converterFn .}}(v {{.}}) repository.{{.}} {
|
||||
return repository.{{.}}(v)
|
||||
}
|
||||
{{end -}}
|
||||
{{range .Methods}}{{.Signature}} {
|
||||
{{.Body}}}
|
||||
|
||||
{{end}}
|
||||
@@ -68,7 +68,10 @@ func generateTotpCmd() *cli.Command {
|
||||
return fmt.Errorf("failed to parse user: %w", err)
|
||||
}
|
||||
|
||||
docker := strings.Contains(tCfg.User, "$$")
|
||||
docker := false
|
||||
if strings.Contains(tCfg.User, "$$") {
|
||||
docker = true
|
||||
}
|
||||
|
||||
if user.TOTPSecret != "" {
|
||||
return fmt.Errorf("user already has a TOTP secret")
|
||||
|
||||
@@ -9,8 +9,8 @@ import (
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/tinyauthapp/paerser/cli"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||
"github.com/tinyauthapp/paerser/cli"
|
||||
)
|
||||
|
||||
type healthzResponse struct {
|
||||
@@ -45,7 +45,7 @@ func healthcheckCmd() *cli.Command {
|
||||
}
|
||||
|
||||
if appUrl == "" {
|
||||
return errors.New("could not determine app url")
|
||||
return errors.New("Could not determine app URL")
|
||||
}
|
||||
|
||||
tlog.App.Info().Str("app_url", appUrl).Msg("Performing health check")
|
||||
@@ -70,7 +70,7 @@ func healthcheckCmd() *cli.Command {
|
||||
return fmt.Errorf("service is not healthy, got: %s", resp.Status)
|
||||
}
|
||||
|
||||
defer resp.Body.Close() //nolint:errcheck
|
||||
defer resp.Body.Close()
|
||||
|
||||
var healthResp healthzResponse
|
||||
|
||||
|
||||
@@ -20,6 +20,7 @@ require (
|
||||
github.com/weppos/publicsuffix-go v0.50.3
|
||||
golang.org/x/crypto v0.50.0
|
||||
golang.org/x/oauth2 v0.36.0
|
||||
golang.org/x/tools v0.43.0
|
||||
k8s.io/apimachinery v0.36.0
|
||||
k8s.io/client-go v0.36.0
|
||||
modernc.org/sqlite v1.50.0
|
||||
@@ -121,6 +122,7 @@ require (
|
||||
go.yaml.in/yaml/v2 v2.4.3 // indirect
|
||||
golang.org/x/arch v0.22.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect
|
||||
golang.org/x/mod v0.34.0 // indirect
|
||||
golang.org/x/net v0.52.0 // indirect
|
||||
golang.org/x/sync v0.20.0 // indirect
|
||||
golang.org/x/sys v0.43.0 // indirect
|
||||
|
||||
@@ -11,5 +11,5 @@ var FrontendAssets embed.FS
|
||||
|
||||
// Migrations
|
||||
//
|
||||
//go:embed migrations/*.sql
|
||||
//go:embed migrations/sqlite/*.sql
|
||||
var Migrations embed.FS
|
||||
|
||||
@@ -144,17 +144,14 @@ func (app *BootstrapApp) Setup() error {
|
||||
tlog.App.Trace().Str("redirectCookieName", app.context.redirectCookieName).Msg("Redirect cookie name")
|
||||
|
||||
// Database
|
||||
db, err := app.SetupDatabase(app.config.Database.Path)
|
||||
store, err := app.SetupStore()
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to setup database: %w", err)
|
||||
}
|
||||
|
||||
// Queries
|
||||
queries := repository.New(db)
|
||||
|
||||
// Services
|
||||
services, err := app.initServices(queries)
|
||||
services, err := app.initServices(store)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to initialize services: %w", err)
|
||||
@@ -210,7 +207,7 @@ func (app *BootstrapApp) Setup() error {
|
||||
|
||||
// Start db cleanup routine
|
||||
tlog.App.Debug().Msg("Starting database cleanup routine")
|
||||
go app.dbCleanupRoutine(queries)
|
||||
go app.dbCleanupRoutine(store)
|
||||
|
||||
// If analytics are not disabled, start heartbeat
|
||||
if app.config.Analytics.Enabled {
|
||||
@@ -292,7 +289,7 @@ func (app *BootstrapApp) heartbeatRoutine() {
|
||||
continue
|
||||
}
|
||||
|
||||
res.Body.Close() //nolint:errcheck
|
||||
res.Body.Close()
|
||||
|
||||
if res.StatusCode != 200 && res.StatusCode != 201 {
|
||||
tlog.App.Debug().Str("status", res.Status).Msg("Heartbeat returned non-200/201 status")
|
||||
@@ -300,7 +297,7 @@ func (app *BootstrapApp) heartbeatRoutine() {
|
||||
}
|
||||
}
|
||||
|
||||
func (app *BootstrapApp) dbCleanupRoutine(queries *repository.Queries) {
|
||||
func (app *BootstrapApp) dbCleanupRoutine(queries repository.Store) {
|
||||
ticker := time.NewTicker(time.Duration(30) * time.Minute)
|
||||
defer ticker.Stop()
|
||||
ctx := context.Background()
|
||||
|
||||
@@ -7,6 +7,9 @@ import (
|
||||
"path/filepath"
|
||||
|
||||
"github.com/tinyauthapp/tinyauth/internal/assets"
|
||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
|
||||
"github.com/tinyauthapp/tinyauth/internal/repository/sqlite"
|
||||
|
||||
"github.com/golang-migrate/migrate/v4"
|
||||
"github.com/golang-migrate/migrate/v4/database/sqlite3"
|
||||
@@ -14,7 +17,18 @@ import (
|
||||
_ "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
func (app *BootstrapApp) SetupDatabase(databasePath string) (*sql.DB, error) {
|
||||
func (app *BootstrapApp) SetupStore() (repository.Store, error) {
|
||||
switch app.config.Database.Driver {
|
||||
case "memory":
|
||||
return memory.New(), nil
|
||||
case "sqlite", "":
|
||||
return app.setupSQLite(app.config.Database.Path)
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown database driver %q: valid values are sqlite, memory", app.config.Database.Driver)
|
||||
}
|
||||
}
|
||||
|
||||
func (app *BootstrapApp) setupSQLite(databasePath string) (repository.Store, error) {
|
||||
dir := filepath.Dir(databasePath)
|
||||
|
||||
if err := os.MkdirAll(dir, 0750); err != nil {
|
||||
@@ -31,7 +45,7 @@ func (app *BootstrapApp) SetupDatabase(databasePath string) (*sql.DB, error) {
|
||||
// if the sqlite connection starts being a bottleneck
|
||||
db.SetMaxOpenConns(1)
|
||||
|
||||
migrations, err := iofs.New(assets.Migrations, "migrations")
|
||||
migrations, err := iofs.New(assets.Migrations, "migrations/sqlite")
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create migrations: %w", err)
|
||||
@@ -53,5 +67,5 @@ func (app *BootstrapApp) SetupDatabase(databasePath string) (*sql.DB, error) {
|
||||
return nil, fmt.Errorf("failed to migrate database: %w", err)
|
||||
}
|
||||
|
||||
return db, nil
|
||||
return sqlite.NewStore(sqlite.New(db)), nil
|
||||
}
|
||||
|
||||
@@ -18,7 +18,7 @@ type Services struct {
|
||||
oidcService *service.OIDCService
|
||||
}
|
||||
|
||||
func (app *BootstrapApp) initServices(queries *repository.Queries) (Services, error) {
|
||||
func (app *BootstrapApp) initServices(queries repository.Store) (Services, error) {
|
||||
services := Services{}
|
||||
|
||||
ldapService := service.NewLdapService(service.LdapServiceConfig{
|
||||
@@ -36,7 +36,7 @@ func (app *BootstrapApp) initServices(queries *repository.Queries) (Services, er
|
||||
|
||||
if err != nil {
|
||||
tlog.App.Warn().Err(err).Msg("Failed to setup LDAP service, starting without it")
|
||||
ldapService.Unconfigure() //nolint:errcheck
|
||||
ldapService.Unconfigure()
|
||||
}
|
||||
|
||||
services.ldapService = ldapService
|
||||
|
||||
@@ -166,12 +166,6 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
|
||||
|
||||
user, err := controller.auth.GetOAuthUserinfo(sessionIdCookie)
|
||||
|
||||
if err != nil {
|
||||
tlog.App.Error().Err(err).Msg("Failed to get user info from OAuth provider")
|
||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
|
||||
return
|
||||
}
|
||||
|
||||
if user.Email == "" {
|
||||
tlog.App.Error().Msg("OAuth provider did not return an email")
|
||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
|
||||
|
||||
@@ -14,10 +14,9 @@ import (
|
||||
"github.com/google/go-querystring/query"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tinyauthapp/tinyauth/internal/bootstrap"
|
||||
"github.com/tinyauthapp/tinyauth/internal/controller"
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
|
||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||
)
|
||||
@@ -852,14 +851,10 @@ func TestOIDCController(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
app := bootstrap.NewBootstrapApp(model.Config{})
|
||||
store := memory.New()
|
||||
|
||||
db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db"))
|
||||
require.NoError(t, err)
|
||||
|
||||
queries := repository.New(db)
|
||||
oidcService := service.NewOIDCService(oidcServiceCfg, queries)
|
||||
err = oidcService.Init()
|
||||
oidcService := service.NewOIDCService(oidcServiceCfg, store)
|
||||
err := oidcService.Init()
|
||||
require.NoError(t, err)
|
||||
|
||||
for _, test := range tests {
|
||||
@@ -881,9 +876,4 @@ func TestOIDCController(t *testing.T) {
|
||||
test.run(t, router, recorder)
|
||||
})
|
||||
}
|
||||
|
||||
t.Cleanup(func() {
|
||||
err = db.Close()
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -2,23 +2,20 @@ package controller_test
|
||||
|
||||
import (
|
||||
"net/http/httptest"
|
||||
"path"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tinyauthapp/tinyauth/internal/bootstrap"
|
||||
"github.com/tinyauthapp/tinyauth/internal/controller"
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
|
||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||
)
|
||||
|
||||
func TestProxyController(t *testing.T) {
|
||||
tlog.NewTestLogger().Init()
|
||||
tempDir := t.TempDir()
|
||||
|
||||
authServiceCfg := service.AuthServiceConfig{
|
||||
LocalUsers: &[]model.LocalUser{
|
||||
@@ -400,15 +397,10 @@ func TestProxyController(t *testing.T) {
|
||||
|
||||
oauthBrokerCfgs := make(map[string]model.OAuthServiceConfig)
|
||||
|
||||
app := bootstrap.NewBootstrapApp(model.Config{})
|
||||
|
||||
db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db"))
|
||||
require.NoError(t, err)
|
||||
|
||||
queries := repository.New(db)
|
||||
store := memory.New()
|
||||
|
||||
docker := service.NewDockerService()
|
||||
err = docker.Init()
|
||||
err := docker.Init()
|
||||
require.NoError(t, err)
|
||||
|
||||
ldap := service.NewLdapService(service.LdapServiceConfig{})
|
||||
@@ -419,7 +411,7 @@ func TestProxyController(t *testing.T) {
|
||||
err = broker.Init()
|
||||
require.NoError(t, err)
|
||||
|
||||
authService := service.NewAuthService(authServiceCfg, ldap, queries, broker)
|
||||
authService := service.NewAuthService(authServiceCfg, ldap, store, broker)
|
||||
err = authService.Init()
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -444,9 +436,4 @@ func TestProxyController(t *testing.T) {
|
||||
test.run(t, router, recorder)
|
||||
})
|
||||
}
|
||||
|
||||
t.Cleanup(func() {
|
||||
err = db.Close()
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"path"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -14,17 +13,16 @@ import (
|
||||
"github.com/pquerna/otp/totp"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tinyauthapp/tinyauth/internal/bootstrap"
|
||||
"github.com/tinyauthapp/tinyauth/internal/controller"
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
|
||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||
)
|
||||
|
||||
func TestUserController(t *testing.T) {
|
||||
tlog.NewTestLogger().Init()
|
||||
tempDir := t.TempDir()
|
||||
|
||||
authServiceCfg := service.AuthServiceConfig{
|
||||
LocalUsers: &[]model.LocalUser{
|
||||
@@ -111,21 +109,14 @@ func TestUserController(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
oauthBrokerCfgs := make(map[string]model.OAuthServiceConfig)
|
||||
|
||||
app := bootstrap.NewBootstrapApp(model.Config{})
|
||||
|
||||
db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db"))
|
||||
require.NoError(t, err)
|
||||
|
||||
queries := repository.New(db)
|
||||
|
||||
type testCase struct {
|
||||
description string
|
||||
middlewares []gin.HandlerFunc
|
||||
run func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder)
|
||||
}
|
||||
|
||||
store := memory.New()
|
||||
|
||||
tests := []testCase{
|
||||
{
|
||||
description: "Should be able to login with valid credentials",
|
||||
@@ -179,7 +170,7 @@ func TestUserController(t *testing.T) {
|
||||
{
|
||||
description: "Should rate limit on 3 invalid attempts",
|
||||
middlewares: []gin.HandlerFunc{},
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { //nolint:staticcheck
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||
loginReq := controller.LoginRequest{
|
||||
Username: "testuser",
|
||||
Password: "wrongpassword",
|
||||
@@ -201,7 +192,7 @@ func TestUserController(t *testing.T) {
|
||||
}
|
||||
|
||||
// 4th attempt should be rate limited
|
||||
recorder = httptest.NewRecorder() //nolint:staticcheck
|
||||
recorder = httptest.NewRecorder()
|
||||
req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqBody)))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
@@ -293,8 +284,8 @@ func TestUserController(t *testing.T) {
|
||||
middlewares: []gin.HandlerFunc{
|
||||
totpCtx,
|
||||
},
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { //nolint:staticcheck
|
||||
_, err := queries.CreateSession(context.TODO(), repository.CreateSessionParams{
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||
_, err := store.CreateSession(context.TODO(), repository.CreateSessionParams{
|
||||
UUID: "test-totp-login-uuid",
|
||||
Username: "test",
|
||||
Email: "test@example.com",
|
||||
@@ -316,7 +307,7 @@ func TestUserController(t *testing.T) {
|
||||
totpReqBody, err := json.Marshal(totpReq)
|
||||
assert.NoError(t, err)
|
||||
|
||||
recorder = httptest.NewRecorder() //nolint:staticcheck
|
||||
recorder = httptest.NewRecorder()
|
||||
req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(totpReqBody)))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.AddCookie(&http.Cookie{
|
||||
@@ -345,7 +336,7 @@ func TestUserController(t *testing.T) {
|
||||
middlewares: []gin.HandlerFunc{
|
||||
totpCtx,
|
||||
},
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { //nolint:staticcheck
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||
for range 3 {
|
||||
totpReq := controller.TotpRequest{
|
||||
Code: "000000", // invalid code
|
||||
@@ -354,7 +345,7 @@ func TestUserController(t *testing.T) {
|
||||
totpReqBody, err := json.Marshal(totpReq)
|
||||
assert.NoError(t, err)
|
||||
|
||||
recorder = httptest.NewRecorder() //nolint:staticcheck
|
||||
recorder = httptest.NewRecorder()
|
||||
req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(totpReqBody)))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
@@ -418,7 +409,7 @@ func TestUserController(t *testing.T) {
|
||||
totpAttrCtx,
|
||||
},
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||
_, err := queries.CreateSession(context.TODO(), repository.CreateSessionParams{
|
||||
_, err := store.CreateSession(context.TODO(), repository.CreateSessionParams{
|
||||
UUID: "test-totp-login-attributes-uuid",
|
||||
Username: "test",
|
||||
Email: "test@example.com",
|
||||
@@ -456,8 +447,10 @@ func TestUserController(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
oauthBrokerCfgs := make(map[string]model.OAuthServiceConfig)
|
||||
|
||||
docker := service.NewDockerService()
|
||||
err = docker.Init()
|
||||
err := docker.Init()
|
||||
require.NoError(t, err)
|
||||
|
||||
ldap := service.NewLdapService(service.LdapServiceConfig{})
|
||||
@@ -468,7 +461,7 @@ func TestUserController(t *testing.T) {
|
||||
err = broker.Init()
|
||||
require.NoError(t, err)
|
||||
|
||||
authService := service.NewAuthService(authServiceCfg, ldap, queries, broker)
|
||||
authService := service.NewAuthService(authServiceCfg, ldap, store, broker)
|
||||
err = authService.Init()
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -497,9 +490,4 @@ func TestUserController(t *testing.T) {
|
||||
test.run(t, router, recorder)
|
||||
})
|
||||
}
|
||||
|
||||
t.Cleanup(func() {
|
||||
err = db.Close()
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -10,10 +10,9 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tinyauthapp/tinyauth/internal/bootstrap"
|
||||
"github.com/tinyauthapp/tinyauth/internal/controller"
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
|
||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||
)
|
||||
@@ -101,15 +100,10 @@ func TestWellKnownController(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
app := bootstrap.NewBootstrapApp(model.Config{})
|
||||
store := memory.New()
|
||||
|
||||
db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db"))
|
||||
require.NoError(t, err)
|
||||
|
||||
queries := repository.New(db)
|
||||
|
||||
oidcService := service.NewOIDCService(oidcServiceCfg, queries)
|
||||
err = oidcService.Init()
|
||||
oidcService := service.NewOIDCService(oidcServiceCfg, store)
|
||||
err := oidcService.Init()
|
||||
require.NoError(t, err)
|
||||
|
||||
for _, test := range tests {
|
||||
@@ -125,9 +119,4 @@ func TestWellKnownController(t *testing.T) {
|
||||
test.run(t, router, recorder)
|
||||
})
|
||||
}
|
||||
|
||||
t.Cleanup(func() {
|
||||
err = db.Close()
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -171,7 +171,7 @@ func (m *ContextMiddleware) cookieAuth(ctx context.Context, uuid string) (*model
|
||||
}
|
||||
|
||||
if !m.auth.IsEmailWhitelisted(userContext.OAuth.Email) {
|
||||
m.auth.DeleteSession(ctx, uuid) //nolint:errcheck
|
||||
m.auth.DeleteSession(ctx, uuid)
|
||||
return nil, nil, fmt.Errorf("email from session cookie not whitelisted: %s", userContext.OAuth.Email)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,24 +5,22 @@ import (
|
||||
"encoding/base64"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"path"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tinyauthapp/tinyauth/internal/bootstrap"
|
||||
"github.com/tinyauthapp/tinyauth/internal/middleware"
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
|
||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||
)
|
||||
|
||||
func TestContextMiddleware(t *testing.T) {
|
||||
tlog.NewTestLogger().Init()
|
||||
tempDir := t.TempDir()
|
||||
|
||||
authServiceCfg := service.AuthServiceConfig{
|
||||
LocalUsers: &[]model.LocalUser{
|
||||
@@ -52,7 +50,7 @@ func TestContextMiddleware(t *testing.T) {
|
||||
return "Basic " + base64.StdEncoding.EncodeToString([]byte(username+":"+password))
|
||||
}
|
||||
|
||||
seedSession := func(t *testing.T, queries *repository.Queries, params repository.CreateSessionParams) {
|
||||
seedSession := func(t *testing.T, queries repository.Store, params repository.CreateSessionParams) {
|
||||
t.Helper()
|
||||
_, err := queries.CreateSession(context.Background(), params)
|
||||
require.NoError(t, err)
|
||||
@@ -60,7 +58,7 @@ func TestContextMiddleware(t *testing.T) {
|
||||
|
||||
type runArgs struct {
|
||||
do func(req *http.Request) (*model.UserContext, *httptest.ResponseRecorder)
|
||||
queries *repository.Queries
|
||||
queries repository.Store
|
||||
}
|
||||
|
||||
type testCase struct {
|
||||
@@ -272,22 +270,17 @@ func TestContextMiddleware(t *testing.T) {
|
||||
|
||||
oauthBrokerCfgs := make(map[string]model.OAuthServiceConfig)
|
||||
|
||||
app := bootstrap.NewBootstrapApp(model.Config{})
|
||||
|
||||
db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db"))
|
||||
require.NoError(t, err)
|
||||
|
||||
queries := repository.New(db)
|
||||
store := memory.New()
|
||||
|
||||
ldap := service.NewLdapService(service.LdapServiceConfig{})
|
||||
err = ldap.Init()
|
||||
err := ldap.Init()
|
||||
require.NoError(t, err)
|
||||
|
||||
broker := service.NewOAuthBrokerService(oauthBrokerCfgs)
|
||||
err = broker.Init()
|
||||
require.NoError(t, err)
|
||||
|
||||
authService := service.NewAuthService(authServiceCfg, ldap, queries, broker)
|
||||
authService := service.NewAuthService(authServiceCfg, ldap, store, broker)
|
||||
err = authService.Init()
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -317,12 +310,7 @@ func TestContextMiddleware(t *testing.T) {
|
||||
return captured, recorder
|
||||
}
|
||||
|
||||
test.run(t, runArgs{do: do, queries: queries})
|
||||
test.run(t, runArgs{do: do, queries: store})
|
||||
})
|
||||
}
|
||||
|
||||
t.Cleanup(func() {
|
||||
err = db.Close()
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ package model
|
||||
func NewDefaultConfiguration() *Config {
|
||||
return &Config{
|
||||
Database: DatabaseConfig{
|
||||
Driver: "sqlite",
|
||||
Path: "./tinyauth.db",
|
||||
},
|
||||
Analytics: AnalyticsConfig{
|
||||
@@ -82,7 +83,8 @@ type Config struct {
|
||||
}
|
||||
|
||||
type DatabaseConfig struct {
|
||||
Path string `description:"The path to the database, including file name." yaml:"path"`
|
||||
Driver string `description:"The database driver to use. Valid values: sqlite, memory." yaml:"driver"`
|
||||
Path string `description:"The path to the SQLite database, including file name. Only used when driver is sqlite." yaml:"path"`
|
||||
}
|
||||
|
||||
type AnalyticsConfig struct {
|
||||
|
||||
@@ -0,0 +1,427 @@
|
||||
package memory_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
|
||||
)
|
||||
|
||||
var ctx = context.Background()
|
||||
|
||||
func TestCreateAndGetSession(t *testing.T) {
|
||||
s := memory.New()
|
||||
sess, err := s.CreateSession(ctx, repository.CreateSessionParams{
|
||||
UUID: "uuid-1",
|
||||
Username: "alice",
|
||||
Expiry: 9999,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "uuid-1", sess.UUID)
|
||||
assert.Equal(t, "alice", sess.Username)
|
||||
|
||||
got, err := s.GetSession(ctx, "uuid-1")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, sess, got)
|
||||
}
|
||||
|
||||
func TestGetSession_NotFound(t *testing.T) {
|
||||
s := memory.New()
|
||||
_, err := s.GetSession(ctx, "missing")
|
||||
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||
}
|
||||
|
||||
func TestUpdateSession(t *testing.T) {
|
||||
s := memory.New()
|
||||
_, err := s.CreateSession(ctx, repository.CreateSessionParams{UUID: "uuid-1", Username: "alice"})
|
||||
require.NoError(t, err)
|
||||
|
||||
updated, err := s.UpdateSession(ctx, repository.UpdateSessionParams{
|
||||
UUID: "uuid-1",
|
||||
Username: "bob",
|
||||
Email: "bob@example.com",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "bob", updated.Username)
|
||||
assert.Equal(t, "bob@example.com", updated.Email)
|
||||
|
||||
got, err := s.GetSession(ctx, "uuid-1")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, updated, got)
|
||||
}
|
||||
|
||||
func TestUpdateSession_NotFound(t *testing.T) {
|
||||
s := memory.New()
|
||||
_, err := s.UpdateSession(ctx, repository.UpdateSessionParams{UUID: "missing"})
|
||||
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||
}
|
||||
|
||||
func TestDeleteSession(t *testing.T) {
|
||||
s := memory.New()
|
||||
_, err := s.CreateSession(ctx, repository.CreateSessionParams{UUID: "uuid-1"})
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, s.DeleteSession(ctx, "uuid-1"))
|
||||
|
||||
_, err = s.GetSession(ctx, "uuid-1")
|
||||
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||
}
|
||||
|
||||
func TestDeleteExpiredSessions(t *testing.T) {
|
||||
s := memory.New()
|
||||
_, err := s.CreateSession(ctx, repository.CreateSessionParams{UUID: "expired", Expiry: 10})
|
||||
require.NoError(t, err)
|
||||
_, err = s.CreateSession(ctx, repository.CreateSessionParams{UUID: "valid", Expiry: 100})
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, s.DeleteExpiredSessions(ctx, 50))
|
||||
|
||||
_, err = s.GetSession(ctx, "expired")
|
||||
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||
|
||||
_, err = s.GetSession(ctx, "valid")
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestCreateAndGetOidcCode(t *testing.T) {
|
||||
s := memory.New()
|
||||
code, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{
|
||||
Sub: "sub-1",
|
||||
CodeHash: "hash-1",
|
||||
Scope: "openid",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "sub-1", code.Sub)
|
||||
|
||||
// destructive read removes the record
|
||||
got, err := s.GetOidcCode(ctx, "hash-1")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, code, got)
|
||||
|
||||
_, err = s.GetOidcCode(ctx, "hash-1")
|
||||
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||
}
|
||||
|
||||
func TestGetOidcCode_NotFound(t *testing.T) {
|
||||
s := memory.New()
|
||||
_, err := s.GetOidcCode(ctx, "missing")
|
||||
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||
}
|
||||
|
||||
func TestGetOidcCodeBySub(t *testing.T) {
|
||||
s := memory.New()
|
||||
_, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"})
|
||||
require.NoError(t, err)
|
||||
|
||||
got, err := s.GetOidcCodeBySub(ctx, "sub-1")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "sub-1", got.Sub)
|
||||
|
||||
// destructive — gone after read
|
||||
_, err = s.GetOidcCodeBySub(ctx, "sub-1")
|
||||
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||
}
|
||||
|
||||
func TestGetOidcCodeBySub_NotFound(t *testing.T) {
|
||||
s := memory.New()
|
||||
_, err := s.GetOidcCodeBySub(ctx, "missing")
|
||||
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||
}
|
||||
|
||||
func TestGetOidcCodeUnsafe(t *testing.T) {
|
||||
s := memory.New()
|
||||
_, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"})
|
||||
require.NoError(t, err)
|
||||
|
||||
got, err := s.GetOidcCodeUnsafe(ctx, "hash-1")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "sub-1", got.Sub)
|
||||
|
||||
// non-destructive — still present
|
||||
_, err = s.GetOidcCodeUnsafe(ctx, "hash-1")
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestGetOidcCodeUnsafe_NotFound(t *testing.T) {
|
||||
s := memory.New()
|
||||
_, err := s.GetOidcCodeUnsafe(ctx, "missing")
|
||||
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||
}
|
||||
|
||||
func TestGetOidcCodeBySubUnsafe(t *testing.T) {
|
||||
s := memory.New()
|
||||
_, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"})
|
||||
require.NoError(t, err)
|
||||
|
||||
got, err := s.GetOidcCodeBySubUnsafe(ctx, "sub-1")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "hash-1", got.CodeHash)
|
||||
|
||||
// non-destructive — still present
|
||||
_, err = s.GetOidcCodeBySubUnsafe(ctx, "sub-1")
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestGetOidcCodeBySubUnsafe_NotFound(t *testing.T) {
|
||||
s := memory.New()
|
||||
_, err := s.GetOidcCodeBySubUnsafe(ctx, "missing")
|
||||
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||
}
|
||||
|
||||
func TestCreateOidcCode_UniqueSubConstraint(t *testing.T) {
|
||||
s := memory.New()
|
||||
_, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-2"})
|
||||
assert.ErrorContains(t, err, "UNIQUE constraint failed: oidc_codes.sub")
|
||||
}
|
||||
|
||||
func TestDeleteOidcCode(t *testing.T) {
|
||||
s := memory.New()
|
||||
_, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"})
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, s.DeleteOidcCode(ctx, "hash-1"))
|
||||
|
||||
_, err = s.GetOidcCodeUnsafe(ctx, "hash-1")
|
||||
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||
}
|
||||
|
||||
func TestDeleteOidcCodeBySub(t *testing.T) {
|
||||
s := memory.New()
|
||||
_, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"})
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, s.DeleteOidcCodeBySub(ctx, "sub-1"))
|
||||
|
||||
_, err = s.GetOidcCodeUnsafe(ctx, "hash-1")
|
||||
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||
}
|
||||
|
||||
func TestDeleteExpiredOidcCodes(t *testing.T) {
|
||||
s := memory.New()
|
||||
_, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1", ExpiresAt: 10})
|
||||
require.NoError(t, err)
|
||||
_, err = s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-2", CodeHash: "hash-2", ExpiresAt: 100})
|
||||
require.NoError(t, err)
|
||||
|
||||
deleted, err := s.DeleteExpiredOidcCodes(ctx, 50)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, deleted, 1)
|
||||
assert.Equal(t, "hash-1", deleted[0].CodeHash)
|
||||
|
||||
_, err = s.GetOidcCodeUnsafe(ctx, "hash-2")
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestCreateAndGetOidcToken(t *testing.T) {
|
||||
s := memory.New()
|
||||
tok, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{
|
||||
Sub: "sub-1",
|
||||
AccessTokenHash: "at-hash-1",
|
||||
CodeHash: "code-hash-1",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "sub-1", tok.Sub)
|
||||
|
||||
got, err := s.GetOidcToken(ctx, "at-hash-1")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tok, got)
|
||||
}
|
||||
|
||||
func TestGetOidcToken_NotFound(t *testing.T) {
|
||||
s := memory.New()
|
||||
_, err := s.GetOidcToken(ctx, "missing")
|
||||
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||
}
|
||||
|
||||
func TestCreateOidcToken_UniqueSubConstraint(t *testing.T) {
|
||||
s := memory.New()
|
||||
_, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{Sub: "sub-1", AccessTokenHash: "at-1"})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{Sub: "sub-1", AccessTokenHash: "at-2"})
|
||||
assert.ErrorContains(t, err, "UNIQUE constraint failed: oidc_tokens.sub")
|
||||
}
|
||||
|
||||
func TestGetOidcTokenByRefreshToken(t *testing.T) {
|
||||
s := memory.New()
|
||||
_, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{
|
||||
Sub: "sub-1",
|
||||
AccessTokenHash: "at-1",
|
||||
RefreshTokenHash: "rt-1",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
got, err := s.GetOidcTokenByRefreshToken(ctx, "rt-1")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "sub-1", got.Sub)
|
||||
}
|
||||
|
||||
func TestGetOidcTokenByRefreshToken_NotFound(t *testing.T) {
|
||||
s := memory.New()
|
||||
_, err := s.GetOidcTokenByRefreshToken(ctx, "missing")
|
||||
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||
}
|
||||
|
||||
func TestGetOidcTokenBySub(t *testing.T) {
|
||||
s := memory.New()
|
||||
_, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{
|
||||
Sub: "sub-1",
|
||||
AccessTokenHash: "at-1",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
got, err := s.GetOidcTokenBySub(ctx, "sub-1")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "at-1", got.AccessTokenHash)
|
||||
}
|
||||
|
||||
func TestGetOidcTokenBySub_NotFound(t *testing.T) {
|
||||
s := memory.New()
|
||||
_, err := s.GetOidcTokenBySub(ctx, "missing")
|
||||
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||
}
|
||||
|
||||
func TestUpdateOidcTokenByRefreshToken(t *testing.T) {
|
||||
s := memory.New()
|
||||
_, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{
|
||||
Sub: "sub-1",
|
||||
AccessTokenHash: "at-1",
|
||||
RefreshTokenHash: "rt-1",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
updated, err := s.UpdateOidcTokenByRefreshToken(ctx, repository.UpdateOidcTokenByRefreshTokenParams{
|
||||
RefreshTokenHash_2: "rt-1",
|
||||
AccessTokenHash: "at-2",
|
||||
RefreshTokenHash: "rt-2",
|
||||
TokenExpiresAt: 200,
|
||||
RefreshTokenExpiresAt: 400,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "at-2", updated.AccessTokenHash)
|
||||
assert.Equal(t, "rt-2", updated.RefreshTokenHash)
|
||||
|
||||
// old key gone, new key present
|
||||
_, err = s.GetOidcToken(ctx, "at-1")
|
||||
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||
|
||||
got, err := s.GetOidcToken(ctx, "at-2")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "sub-1", got.Sub)
|
||||
}
|
||||
|
||||
func TestUpdateOidcTokenByRefreshToken_NotFound(t *testing.T) {
|
||||
s := memory.New()
|
||||
_, err := s.UpdateOidcTokenByRefreshToken(ctx, repository.UpdateOidcTokenByRefreshTokenParams{
|
||||
RefreshTokenHash_2: "missing",
|
||||
})
|
||||
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||
}
|
||||
|
||||
func TestDeleteOidcToken(t *testing.T) {
|
||||
s := memory.New()
|
||||
_, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{Sub: "sub-1", AccessTokenHash: "at-1"})
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, s.DeleteOidcToken(ctx, "at-1"))
|
||||
|
||||
_, err = s.GetOidcToken(ctx, "at-1")
|
||||
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||
}
|
||||
|
||||
func TestDeleteOidcTokenBySub(t *testing.T) {
|
||||
s := memory.New()
|
||||
_, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{Sub: "sub-1", AccessTokenHash: "at-1"})
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, s.DeleteOidcTokenBySub(ctx, "sub-1"))
|
||||
|
||||
_, err = s.GetOidcToken(ctx, "at-1")
|
||||
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||
}
|
||||
|
||||
func TestDeleteOidcTokenByCodeHash(t *testing.T) {
|
||||
s := memory.New()
|
||||
_, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{
|
||||
Sub: "sub-1",
|
||||
AccessTokenHash: "at-1",
|
||||
CodeHash: "code-1",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, s.DeleteOidcTokenByCodeHash(ctx, "code-1"))
|
||||
|
||||
_, err = s.GetOidcToken(ctx, "at-1")
|
||||
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||
}
|
||||
|
||||
func TestDeleteExpiredOidcTokens(t *testing.T) {
|
||||
s := memory.New()
|
||||
// expired by TokenExpiresAt
|
||||
_, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{
|
||||
Sub: "sub-1", AccessTokenHash: "at-1",
|
||||
TokenExpiresAt: 10, RefreshTokenExpiresAt: 100,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
// expired by RefreshTokenExpiresAt
|
||||
_, err = s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{
|
||||
Sub: "sub-2", AccessTokenHash: "at-2",
|
||||
TokenExpiresAt: 100, RefreshTokenExpiresAt: 10,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
// valid
|
||||
_, err = s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{
|
||||
Sub: "sub-3", AccessTokenHash: "at-3",
|
||||
TokenExpiresAt: 100, RefreshTokenExpiresAt: 100,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
deleted, err := s.DeleteExpiredOidcTokens(ctx, repository.DeleteExpiredOidcTokensParams{
|
||||
TokenExpiresAt: 50,
|
||||
RefreshTokenExpiresAt: 50,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, deleted, 2)
|
||||
|
||||
_, err = s.GetOidcToken(ctx, "at-3")
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestCreateAndGetOidcUserInfo(t *testing.T) {
|
||||
s := memory.New()
|
||||
u, err := s.CreateOidcUserInfo(ctx, repository.CreateOidcUserInfoParams{
|
||||
Sub: "sub-1",
|
||||
Name: "Alice",
|
||||
Email: "alice@example.com",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "sub-1", u.Sub)
|
||||
|
||||
got, err := s.GetOidcUserInfo(ctx, "sub-1")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, u, got)
|
||||
}
|
||||
|
||||
func TestGetOidcUserInfo_NotFound(t *testing.T) {
|
||||
s := memory.New()
|
||||
_, err := s.GetOidcUserInfo(ctx, "missing")
|
||||
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||
}
|
||||
|
||||
func TestDeleteOidcUserInfo(t *testing.T) {
|
||||
s := memory.New()
|
||||
_, err := s.CreateOidcUserInfo(ctx, repository.CreateOidcUserInfoParams{Sub: "sub-1"})
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, s.DeleteOidcUserInfo(ctx, "sub-1"))
|
||||
|
||||
_, err = s.GetOidcUserInfo(ctx, "sub-1")
|
||||
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||
}
|
||||
@@ -0,0 +1,241 @@
|
||||
package memory
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||
)
|
||||
|
||||
func (s *Store) CreateOidcCode(_ context.Context, arg repository.CreateOidcCodeParams) (repository.OidcCode, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
// Enforce sub UNIQUE constraint
|
||||
for _, c := range s.oidcCodes {
|
||||
if c.Sub == arg.Sub {
|
||||
return repository.OidcCode{}, fmt.Errorf("UNIQUE constraint failed: oidc_codes.sub")
|
||||
}
|
||||
}
|
||||
code := repository.OidcCode(arg)
|
||||
s.oidcCodes[arg.CodeHash] = code
|
||||
return code, nil
|
||||
}
|
||||
|
||||
// GetOidcCode is a destructive read: it deletes and returns the code (mirrors SQLite's DELETE...RETURNING).
|
||||
func (s *Store) GetOidcCode(_ context.Context, codeHash string) (repository.OidcCode, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
c, ok := s.oidcCodes[codeHash]
|
||||
if !ok {
|
||||
return repository.OidcCode{}, repository.ErrNotFound
|
||||
}
|
||||
delete(s.oidcCodes, codeHash)
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// GetOidcCodeBySub is a destructive read: it deletes and returns the code (mirrors SQLite's DELETE...RETURNING).
|
||||
func (s *Store) GetOidcCodeBySub(_ context.Context, sub string) (repository.OidcCode, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
for k, c := range s.oidcCodes {
|
||||
if c.Sub == sub {
|
||||
delete(s.oidcCodes, k)
|
||||
return c, nil
|
||||
}
|
||||
}
|
||||
return repository.OidcCode{}, repository.ErrNotFound
|
||||
}
|
||||
|
||||
// GetOidcCodeUnsafe is a non-destructive read (mirrors SQLite's SELECT).
|
||||
func (s *Store) GetOidcCodeUnsafe(_ context.Context, codeHash string) (repository.OidcCode, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
c, ok := s.oidcCodes[codeHash]
|
||||
if !ok {
|
||||
return repository.OidcCode{}, repository.ErrNotFound
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// GetOidcCodeBySubUnsafe is a non-destructive read (mirrors SQLite's SELECT).
|
||||
func (s *Store) GetOidcCodeBySubUnsafe(_ context.Context, sub string) (repository.OidcCode, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
for _, c := range s.oidcCodes {
|
||||
if c.Sub == sub {
|
||||
return c, nil
|
||||
}
|
||||
}
|
||||
return repository.OidcCode{}, repository.ErrNotFound
|
||||
}
|
||||
|
||||
func (s *Store) DeleteOidcCode(_ context.Context, codeHash string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
delete(s.oidcCodes, codeHash)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Store) DeleteOidcCodeBySub(_ context.Context, sub string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
for k, c := range s.oidcCodes {
|
||||
if c.Sub == sub {
|
||||
delete(s.oidcCodes, k)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Store) DeleteExpiredOidcCodes(_ context.Context, expiresAt int64) ([]repository.OidcCode, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
var deleted []repository.OidcCode
|
||||
for k, c := range s.oidcCodes {
|
||||
if c.ExpiresAt < expiresAt {
|
||||
deleted = append(deleted, c)
|
||||
delete(s.oidcCodes, k)
|
||||
}
|
||||
}
|
||||
return deleted, nil
|
||||
}
|
||||
|
||||
func (s *Store) CreateOidcToken(_ context.Context, arg repository.CreateOidcTokenParams) (repository.OidcToken, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
// Enforce sub UNIQUE constraint
|
||||
for _, t := range s.oidcTokens {
|
||||
if t.Sub == arg.Sub {
|
||||
return repository.OidcToken{}, fmt.Errorf("UNIQUE constraint failed: oidc_tokens.sub")
|
||||
}
|
||||
}
|
||||
tok := repository.OidcToken{
|
||||
Sub: arg.Sub,
|
||||
AccessTokenHash: arg.AccessTokenHash,
|
||||
RefreshTokenHash: arg.RefreshTokenHash,
|
||||
CodeHash: arg.CodeHash,
|
||||
Scope: arg.Scope,
|
||||
ClientID: arg.ClientID,
|
||||
TokenExpiresAt: arg.TokenExpiresAt,
|
||||
RefreshTokenExpiresAt: arg.RefreshTokenExpiresAt,
|
||||
Nonce: arg.Nonce,
|
||||
}
|
||||
s.oidcTokens[arg.AccessTokenHash] = tok
|
||||
return tok, nil
|
||||
}
|
||||
|
||||
func (s *Store) GetOidcToken(_ context.Context, accessTokenHash string) (repository.OidcToken, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
t, ok := s.oidcTokens[accessTokenHash]
|
||||
if !ok {
|
||||
return repository.OidcToken{}, repository.ErrNotFound
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func (s *Store) GetOidcTokenByRefreshToken(_ context.Context, refreshTokenHash string) (repository.OidcToken, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
for _, t := range s.oidcTokens {
|
||||
if t.RefreshTokenHash == refreshTokenHash {
|
||||
return t, nil
|
||||
}
|
||||
}
|
||||
return repository.OidcToken{}, repository.ErrNotFound
|
||||
}
|
||||
|
||||
func (s *Store) GetOidcTokenBySub(_ context.Context, sub string) (repository.OidcToken, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
for _, t := range s.oidcTokens {
|
||||
if t.Sub == sub {
|
||||
return t, nil
|
||||
}
|
||||
}
|
||||
return repository.OidcToken{}, repository.ErrNotFound
|
||||
}
|
||||
|
||||
func (s *Store) UpdateOidcTokenByRefreshToken(_ context.Context, arg repository.UpdateOidcTokenByRefreshTokenParams) (repository.OidcToken, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
for k, t := range s.oidcTokens {
|
||||
if t.RefreshTokenHash == arg.RefreshTokenHash_2 {
|
||||
delete(s.oidcTokens, k)
|
||||
t.AccessTokenHash = arg.AccessTokenHash
|
||||
t.RefreshTokenHash = arg.RefreshTokenHash
|
||||
t.TokenExpiresAt = arg.TokenExpiresAt
|
||||
t.RefreshTokenExpiresAt = arg.RefreshTokenExpiresAt
|
||||
s.oidcTokens[arg.AccessTokenHash] = t
|
||||
return t, nil
|
||||
}
|
||||
}
|
||||
return repository.OidcToken{}, repository.ErrNotFound
|
||||
}
|
||||
|
||||
func (s *Store) DeleteOidcToken(_ context.Context, accessTokenHash string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
delete(s.oidcTokens, accessTokenHash)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Store) DeleteOidcTokenBySub(_ context.Context, sub string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
for k, t := range s.oidcTokens {
|
||||
if t.Sub == sub {
|
||||
delete(s.oidcTokens, k)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Store) DeleteOidcTokenByCodeHash(_ context.Context, codeHash string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
for k, t := range s.oidcTokens {
|
||||
if t.CodeHash == codeHash {
|
||||
delete(s.oidcTokens, k)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Store) DeleteExpiredOidcTokens(_ context.Context, arg repository.DeleteExpiredOidcTokensParams) ([]repository.OidcToken, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
var deleted []repository.OidcToken
|
||||
for k, t := range s.oidcTokens {
|
||||
if t.TokenExpiresAt < arg.TokenExpiresAt || t.RefreshTokenExpiresAt < arg.RefreshTokenExpiresAt {
|
||||
deleted = append(deleted, t)
|
||||
delete(s.oidcTokens, k)
|
||||
}
|
||||
}
|
||||
return deleted, nil
|
||||
}
|
||||
|
||||
func (s *Store) CreateOidcUserInfo(_ context.Context, arg repository.CreateOidcUserInfoParams) (repository.OidcUserinfo, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
u := repository.OidcUserinfo(arg)
|
||||
s.oidcUsers[arg.Sub] = u
|
||||
return u, nil
|
||||
}
|
||||
|
||||
func (s *Store) GetOidcUserInfo(_ context.Context, sub string) (repository.OidcUserinfo, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
u, ok := s.oidcUsers[sub]
|
||||
if !ok {
|
||||
return repository.OidcUserinfo{}, repository.ErrNotFound
|
||||
}
|
||||
return u, nil
|
||||
}
|
||||
|
||||
func (s *Store) DeleteOidcUserInfo(_ context.Context, sub string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
delete(s.oidcUsers, sub)
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,63 @@
|
||||
package memory
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||
)
|
||||
|
||||
func (s *Store) CreateSession(_ context.Context, arg repository.CreateSessionParams) (repository.Session, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
sess := repository.Session(arg)
|
||||
s.sessions[arg.UUID] = sess
|
||||
return sess, nil
|
||||
}
|
||||
|
||||
func (s *Store) GetSession(_ context.Context, uuid string) (repository.Session, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
sess, ok := s.sessions[uuid]
|
||||
if !ok {
|
||||
return repository.Session{}, repository.ErrNotFound
|
||||
}
|
||||
return sess, nil
|
||||
}
|
||||
|
||||
func (s *Store) UpdateSession(_ context.Context, arg repository.UpdateSessionParams) (repository.Session, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
sess, ok := s.sessions[arg.UUID]
|
||||
if !ok {
|
||||
return repository.Session{}, repository.ErrNotFound
|
||||
}
|
||||
sess.Username = arg.Username
|
||||
sess.Email = arg.Email
|
||||
sess.Name = arg.Name
|
||||
sess.Provider = arg.Provider
|
||||
sess.TotpPending = arg.TotpPending
|
||||
sess.OAuthGroups = arg.OAuthGroups
|
||||
sess.Expiry = arg.Expiry
|
||||
sess.OAuthName = arg.OAuthName
|
||||
sess.OAuthSub = arg.OAuthSub
|
||||
s.sessions[arg.UUID] = sess
|
||||
return sess, nil
|
||||
}
|
||||
|
||||
func (s *Store) DeleteSession(_ context.Context, uuid string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
delete(s.sessions, uuid)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Store) DeleteExpiredSessions(_ context.Context, expiry int64) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
for k, v := range s.sessions {
|
||||
if v.Expiry < expiry {
|
||||
delete(s.sessions, k)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,27 @@
|
||||
// Package memory provides an in-memory implementation of repository.Store for use in tests.
|
||||
package memory
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||
)
|
||||
|
||||
// Store is a thread-safe in-memory implementation of repository.Store.
|
||||
type Store struct {
|
||||
mu sync.RWMutex
|
||||
sessions map[string]repository.Session
|
||||
oidcCodes map[string]repository.OidcCode
|
||||
oidcTokens map[string]repository.OidcToken
|
||||
oidcUsers map[string]repository.OidcUserinfo
|
||||
}
|
||||
|
||||
// New returns a new empty in-memory Store.
|
||||
func New() repository.Store {
|
||||
return &Store{
|
||||
sessions: make(map[string]repository.Session),
|
||||
oidcCodes: make(map[string]repository.OidcCode),
|
||||
oidcTokens: make(map[string]repository.OidcToken),
|
||||
oidcUsers: make(map[string]repository.OidcUserinfo),
|
||||
}
|
||||
}
|
||||
@@ -1,9 +1,22 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.30.0
|
||||
|
||||
package repository
|
||||
|
||||
// Shared model and parameter types for all storage drivers.
|
||||
// sqlc-generated driver packages use these via the conversion layer in their store.go.
|
||||
|
||||
type Session struct {
|
||||
UUID string
|
||||
Username string
|
||||
Email string
|
||||
Name string
|
||||
Provider string
|
||||
TotpPending bool
|
||||
OAuthGroups string
|
||||
Expiry int64
|
||||
CreatedAt int64
|
||||
OAuthName string
|
||||
OAuthSub string
|
||||
}
|
||||
|
||||
type OidcCode struct {
|
||||
Sub string
|
||||
CodeHash string
|
||||
@@ -49,7 +62,7 @@ type OidcUserinfo struct {
|
||||
Address string
|
||||
}
|
||||
|
||||
type Session struct {
|
||||
type CreateSessionParams struct {
|
||||
UUID string
|
||||
Username string
|
||||
Email string
|
||||
@@ -62,3 +75,74 @@ type Session struct {
|
||||
OAuthName string
|
||||
OAuthSub string
|
||||
}
|
||||
|
||||
type UpdateSessionParams struct {
|
||||
Username string
|
||||
Email string
|
||||
Name string
|
||||
Provider string
|
||||
TotpPending bool
|
||||
OAuthGroups string
|
||||
Expiry int64
|
||||
OAuthName string
|
||||
OAuthSub string
|
||||
UUID string
|
||||
}
|
||||
|
||||
type CreateOidcCodeParams struct {
|
||||
Sub string
|
||||
CodeHash string
|
||||
Scope string
|
||||
RedirectURI string
|
||||
ClientID string
|
||||
ExpiresAt int64
|
||||
Nonce string
|
||||
CodeChallenge string
|
||||
}
|
||||
|
||||
type CreateOidcTokenParams struct {
|
||||
Sub string
|
||||
AccessTokenHash string
|
||||
RefreshTokenHash string
|
||||
Scope string
|
||||
ClientID string
|
||||
TokenExpiresAt int64
|
||||
RefreshTokenExpiresAt int64
|
||||
CodeHash string
|
||||
Nonce string
|
||||
}
|
||||
|
||||
type UpdateOidcTokenByRefreshTokenParams struct {
|
||||
AccessTokenHash string
|
||||
RefreshTokenHash string
|
||||
TokenExpiresAt int64
|
||||
RefreshTokenExpiresAt int64
|
||||
RefreshTokenHash_2 string
|
||||
}
|
||||
|
||||
type DeleteExpiredOidcTokensParams struct {
|
||||
TokenExpiresAt int64
|
||||
RefreshTokenExpiresAt int64
|
||||
}
|
||||
|
||||
type CreateOidcUserInfoParams struct {
|
||||
Sub string
|
||||
Name string
|
||||
PreferredUsername string
|
||||
Email string
|
||||
Groups string
|
||||
UpdatedAt int64
|
||||
GivenName string
|
||||
FamilyName string
|
||||
MiddleName string
|
||||
Nickname string
|
||||
Profile string
|
||||
Picture string
|
||||
Website string
|
||||
Gender string
|
||||
Birthdate string
|
||||
Zoneinfo string
|
||||
Locale string
|
||||
PhoneNumber string
|
||||
Address string
|
||||
}
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.30.0
|
||||
// sqlc v1.31.1
|
||||
|
||||
package repository
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -0,0 +1,3 @@
|
||||
package sqlite
|
||||
|
||||
//go:generate go run github.com/tinyauthapp/tinyauth/cmd/gen/sqlc-wrapper -pkg github.com/tinyauthapp/tinyauth/internal/repository/sqlite
|
||||
@@ -0,0 +1,64 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.31.1
|
||||
|
||||
package sqlite
|
||||
|
||||
type OidcCode struct {
|
||||
Sub string
|
||||
CodeHash string
|
||||
Scope string
|
||||
RedirectURI string
|
||||
ClientID string
|
||||
ExpiresAt int64
|
||||
Nonce string
|
||||
CodeChallenge string
|
||||
}
|
||||
|
||||
type OidcToken struct {
|
||||
Sub string
|
||||
AccessTokenHash string
|
||||
RefreshTokenHash string
|
||||
CodeHash string
|
||||
Scope string
|
||||
ClientID string
|
||||
TokenExpiresAt int64
|
||||
RefreshTokenExpiresAt int64
|
||||
Nonce string
|
||||
}
|
||||
|
||||
type OidcUserinfo struct {
|
||||
Sub string
|
||||
Name string
|
||||
PreferredUsername string
|
||||
Email string
|
||||
Groups string
|
||||
UpdatedAt int64
|
||||
GivenName string
|
||||
FamilyName string
|
||||
MiddleName string
|
||||
Nickname string
|
||||
Profile string
|
||||
Picture string
|
||||
Website string
|
||||
Gender string
|
||||
Birthdate string
|
||||
Zoneinfo string
|
||||
Locale string
|
||||
PhoneNumber string
|
||||
Address string
|
||||
}
|
||||
|
||||
type Session struct {
|
||||
UUID string
|
||||
Username string
|
||||
Email string
|
||||
Name string
|
||||
Provider string
|
||||
TotpPending bool
|
||||
OAuthGroups string
|
||||
Expiry int64
|
||||
CreatedAt int64
|
||||
OAuthName string
|
||||
OAuthSub string
|
||||
}
|
||||
+2
-2
@@ -1,9 +1,9 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.30.0
|
||||
// sqlc v1.31.1
|
||||
// source: oidc_queries.sql
|
||||
|
||||
package repository
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
+2
-2
@@ -1,9 +1,9 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.30.0
|
||||
// sqlc v1.31.1
|
||||
// source: session_queries.sql
|
||||
|
||||
package repository
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -0,0 +1,224 @@
|
||||
// Code generated by cmd/gen/sqlc-wrapper. DO NOT EDIT.
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
|
||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||
)
|
||||
|
||||
// Store wraps *Queries and implements repository.Store.
|
||||
type Store struct {
|
||||
q *Queries
|
||||
}
|
||||
|
||||
// NewStore wraps a *Queries to satisfy repository.Store.
|
||||
func NewStore(q *Queries) repository.Store {
|
||||
return &Store{q: q}
|
||||
}
|
||||
|
||||
var errMap = []struct {
|
||||
from error
|
||||
to error
|
||||
}{
|
||||
{sql.ErrNoRows, repository.ErrNotFound},
|
||||
}
|
||||
|
||||
func mapErr(err error) error {
|
||||
for _, e := range errMap {
|
||||
if errors.Is(err, e.from) {
|
||||
return e.to
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func oidcCodeToRepo(v OidcCode) repository.OidcCode {
|
||||
return repository.OidcCode(v)
|
||||
}
|
||||
func oidcTokenToRepo(v OidcToken) repository.OidcToken {
|
||||
return repository.OidcToken(v)
|
||||
}
|
||||
func oidcUserinfoToRepo(v OidcUserinfo) repository.OidcUserinfo {
|
||||
return repository.OidcUserinfo(v)
|
||||
}
|
||||
func sessionToRepo(v Session) repository.Session {
|
||||
return repository.Session(v)
|
||||
}
|
||||
func (s *Store) CreateOidcCode(ctx context.Context, arg repository.CreateOidcCodeParams) (repository.OidcCode, error) {
|
||||
r, err := s.q.CreateOidcCode(ctx, CreateOidcCodeParams(arg))
|
||||
if err != nil {
|
||||
return repository.OidcCode{}, mapErr(err)
|
||||
}
|
||||
return oidcCodeToRepo(r), nil
|
||||
}
|
||||
|
||||
func (s *Store) CreateOidcToken(ctx context.Context, arg repository.CreateOidcTokenParams) (repository.OidcToken, error) {
|
||||
r, err := s.q.CreateOidcToken(ctx, CreateOidcTokenParams(arg))
|
||||
if err != nil {
|
||||
return repository.OidcToken{}, mapErr(err)
|
||||
}
|
||||
return oidcTokenToRepo(r), nil
|
||||
}
|
||||
|
||||
func (s *Store) CreateOidcUserInfo(ctx context.Context, arg repository.CreateOidcUserInfoParams) (repository.OidcUserinfo, error) {
|
||||
r, err := s.q.CreateOidcUserInfo(ctx, CreateOidcUserInfoParams(arg))
|
||||
if err != nil {
|
||||
return repository.OidcUserinfo{}, mapErr(err)
|
||||
}
|
||||
return oidcUserinfoToRepo(r), nil
|
||||
}
|
||||
|
||||
func (s *Store) CreateSession(ctx context.Context, arg repository.CreateSessionParams) (repository.Session, error) {
|
||||
r, err := s.q.CreateSession(ctx, CreateSessionParams(arg))
|
||||
if err != nil {
|
||||
return repository.Session{}, mapErr(err)
|
||||
}
|
||||
return sessionToRepo(r), nil
|
||||
}
|
||||
|
||||
func (s *Store) DeleteExpiredOidcCodes(ctx context.Context, expiresAt int64) ([]repository.OidcCode, error) {
|
||||
rows, err := s.q.DeleteExpiredOidcCodes(ctx, expiresAt)
|
||||
if err != nil {
|
||||
return nil, mapErr(err)
|
||||
}
|
||||
out := make([]repository.OidcCode, len(rows))
|
||||
for i, row := range rows {
|
||||
out[i] = oidcCodeToRepo(row)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (s *Store) DeleteExpiredOidcTokens(ctx context.Context, arg repository.DeleteExpiredOidcTokensParams) ([]repository.OidcToken, error) {
|
||||
rows, err := s.q.DeleteExpiredOidcTokens(ctx, DeleteExpiredOidcTokensParams(arg))
|
||||
if err != nil {
|
||||
return nil, mapErr(err)
|
||||
}
|
||||
out := make([]repository.OidcToken, len(rows))
|
||||
for i, row := range rows {
|
||||
out[i] = oidcTokenToRepo(row)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (s *Store) DeleteExpiredSessions(ctx context.Context, expiry int64) error {
|
||||
return mapErr(s.q.DeleteExpiredSessions(ctx, expiry))
|
||||
}
|
||||
|
||||
func (s *Store) DeleteOidcCode(ctx context.Context, codeHash string) error {
|
||||
return mapErr(s.q.DeleteOidcCode(ctx, codeHash))
|
||||
}
|
||||
|
||||
func (s *Store) DeleteOidcCodeBySub(ctx context.Context, sub string) error {
|
||||
return mapErr(s.q.DeleteOidcCodeBySub(ctx, sub))
|
||||
}
|
||||
|
||||
func (s *Store) DeleteOidcToken(ctx context.Context, accessTokenHash string) error {
|
||||
return mapErr(s.q.DeleteOidcToken(ctx, accessTokenHash))
|
||||
}
|
||||
|
||||
func (s *Store) DeleteOidcTokenByCodeHash(ctx context.Context, codeHash string) error {
|
||||
return mapErr(s.q.DeleteOidcTokenByCodeHash(ctx, codeHash))
|
||||
}
|
||||
|
||||
func (s *Store) DeleteOidcTokenBySub(ctx context.Context, sub string) error {
|
||||
return mapErr(s.q.DeleteOidcTokenBySub(ctx, sub))
|
||||
}
|
||||
|
||||
func (s *Store) DeleteOidcUserInfo(ctx context.Context, sub string) error {
|
||||
return mapErr(s.q.DeleteOidcUserInfo(ctx, sub))
|
||||
}
|
||||
|
||||
func (s *Store) DeleteSession(ctx context.Context, uuid string) error {
|
||||
return mapErr(s.q.DeleteSession(ctx, uuid))
|
||||
}
|
||||
|
||||
func (s *Store) GetOidcCode(ctx context.Context, codeHash string) (repository.OidcCode, error) {
|
||||
r, err := s.q.GetOidcCode(ctx, codeHash)
|
||||
if err != nil {
|
||||
return repository.OidcCode{}, mapErr(err)
|
||||
}
|
||||
return oidcCodeToRepo(r), nil
|
||||
}
|
||||
|
||||
func (s *Store) GetOidcCodeBySub(ctx context.Context, sub string) (repository.OidcCode, error) {
|
||||
r, err := s.q.GetOidcCodeBySub(ctx, sub)
|
||||
if err != nil {
|
||||
return repository.OidcCode{}, mapErr(err)
|
||||
}
|
||||
return oidcCodeToRepo(r), nil
|
||||
}
|
||||
|
||||
func (s *Store) GetOidcCodeBySubUnsafe(ctx context.Context, sub string) (repository.OidcCode, error) {
|
||||
r, err := s.q.GetOidcCodeBySubUnsafe(ctx, sub)
|
||||
if err != nil {
|
||||
return repository.OidcCode{}, mapErr(err)
|
||||
}
|
||||
return oidcCodeToRepo(r), nil
|
||||
}
|
||||
|
||||
func (s *Store) GetOidcCodeUnsafe(ctx context.Context, codeHash string) (repository.OidcCode, error) {
|
||||
r, err := s.q.GetOidcCodeUnsafe(ctx, codeHash)
|
||||
if err != nil {
|
||||
return repository.OidcCode{}, mapErr(err)
|
||||
}
|
||||
return oidcCodeToRepo(r), nil
|
||||
}
|
||||
|
||||
func (s *Store) GetOidcToken(ctx context.Context, accessTokenHash string) (repository.OidcToken, error) {
|
||||
r, err := s.q.GetOidcToken(ctx, accessTokenHash)
|
||||
if err != nil {
|
||||
return repository.OidcToken{}, mapErr(err)
|
||||
}
|
||||
return oidcTokenToRepo(r), nil
|
||||
}
|
||||
|
||||
func (s *Store) GetOidcTokenByRefreshToken(ctx context.Context, refreshTokenHash string) (repository.OidcToken, error) {
|
||||
r, err := s.q.GetOidcTokenByRefreshToken(ctx, refreshTokenHash)
|
||||
if err != nil {
|
||||
return repository.OidcToken{}, mapErr(err)
|
||||
}
|
||||
return oidcTokenToRepo(r), nil
|
||||
}
|
||||
|
||||
func (s *Store) GetOidcTokenBySub(ctx context.Context, sub string) (repository.OidcToken, error) {
|
||||
r, err := s.q.GetOidcTokenBySub(ctx, sub)
|
||||
if err != nil {
|
||||
return repository.OidcToken{}, mapErr(err)
|
||||
}
|
||||
return oidcTokenToRepo(r), nil
|
||||
}
|
||||
|
||||
func (s *Store) GetOidcUserInfo(ctx context.Context, sub string) (repository.OidcUserinfo, error) {
|
||||
r, err := s.q.GetOidcUserInfo(ctx, sub)
|
||||
if err != nil {
|
||||
return repository.OidcUserinfo{}, mapErr(err)
|
||||
}
|
||||
return oidcUserinfoToRepo(r), nil
|
||||
}
|
||||
|
||||
func (s *Store) GetSession(ctx context.Context, uuid string) (repository.Session, error) {
|
||||
r, err := s.q.GetSession(ctx, uuid)
|
||||
if err != nil {
|
||||
return repository.Session{}, mapErr(err)
|
||||
}
|
||||
return sessionToRepo(r), nil
|
||||
}
|
||||
|
||||
func (s *Store) UpdateOidcTokenByRefreshToken(ctx context.Context, arg repository.UpdateOidcTokenByRefreshTokenParams) (repository.OidcToken, error) {
|
||||
r, err := s.q.UpdateOidcTokenByRefreshToken(ctx, UpdateOidcTokenByRefreshTokenParams(arg))
|
||||
if err != nil {
|
||||
return repository.OidcToken{}, mapErr(err)
|
||||
}
|
||||
return oidcTokenToRepo(r), nil
|
||||
}
|
||||
|
||||
func (s *Store) UpdateSession(ctx context.Context, arg repository.UpdateSessionParams) (repository.Session, error) {
|
||||
r, err := s.q.UpdateSession(ctx, UpdateSessionParams(arg))
|
||||
if err != nil {
|
||||
return repository.Session{}, mapErr(err)
|
||||
}
|
||||
return sessionToRepo(r), nil
|
||||
}
|
||||
@@ -0,0 +1,47 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
)
|
||||
|
||||
// ErrNotFound is returned by Store methods when the requested record does not exist.
|
||||
var ErrNotFound = errors.New("not found")
|
||||
|
||||
// Store is the interface that all storage drivers must implement.
|
||||
// The sqlc-generated *Queries struct satisfies this interface for SQLite.
|
||||
// Future drivers (postgres, etc.) must return the shared types defined in this package.
|
||||
type Store interface {
|
||||
// Sessions
|
||||
CreateSession(ctx context.Context, arg CreateSessionParams) (Session, error)
|
||||
GetSession(ctx context.Context, uuid string) (Session, error)
|
||||
UpdateSession(ctx context.Context, arg UpdateSessionParams) (Session, error)
|
||||
DeleteSession(ctx context.Context, uuid string) error
|
||||
DeleteExpiredSessions(ctx context.Context, expiry int64) error
|
||||
|
||||
// OIDC codes
|
||||
CreateOidcCode(ctx context.Context, arg CreateOidcCodeParams) (OidcCode, error)
|
||||
GetOidcCode(ctx context.Context, codeHash string) (OidcCode, error)
|
||||
GetOidcCodeBySub(ctx context.Context, sub string) (OidcCode, error)
|
||||
GetOidcCodeUnsafe(ctx context.Context, codeHash string) (OidcCode, error)
|
||||
GetOidcCodeBySubUnsafe(ctx context.Context, sub string) (OidcCode, error)
|
||||
DeleteOidcCode(ctx context.Context, codeHash string) error
|
||||
DeleteOidcCodeBySub(ctx context.Context, sub string) error
|
||||
DeleteExpiredOidcCodes(ctx context.Context, expiresAt int64) ([]OidcCode, error)
|
||||
|
||||
// OIDC tokens
|
||||
CreateOidcToken(ctx context.Context, arg CreateOidcTokenParams) (OidcToken, error)
|
||||
GetOidcToken(ctx context.Context, accessTokenHash string) (OidcToken, error)
|
||||
GetOidcTokenByRefreshToken(ctx context.Context, refreshTokenHash string) (OidcToken, error)
|
||||
GetOidcTokenBySub(ctx context.Context, sub string) (OidcToken, error)
|
||||
UpdateOidcTokenByRefreshToken(ctx context.Context, arg UpdateOidcTokenByRefreshTokenParams) (OidcToken, error)
|
||||
DeleteOidcToken(ctx context.Context, accessTokenHash string) error
|
||||
DeleteOidcTokenBySub(ctx context.Context, sub string) error
|
||||
DeleteOidcTokenByCodeHash(ctx context.Context, codeHash string) error
|
||||
DeleteExpiredOidcTokens(ctx context.Context, arg DeleteExpiredOidcTokensParams) ([]OidcToken, error)
|
||||
|
||||
// OIDC userinfo
|
||||
CreateOidcUserInfo(ctx context.Context, arg CreateOidcUserInfoParams) (OidcUserinfo, error)
|
||||
GetOidcUserInfo(ctx context.Context, sub string) (OidcUserinfo, error)
|
||||
DeleteOidcUserInfo(ctx context.Context, sub string) error
|
||||
}
|
||||
@@ -2,7 +2,6 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
@@ -96,14 +95,14 @@ type AuthService struct {
|
||||
loginMutex sync.RWMutex
|
||||
ldapGroupsMutex sync.RWMutex
|
||||
ldap *LdapService
|
||||
queries *repository.Queries
|
||||
queries repository.Store
|
||||
oauthBroker *OAuthBrokerService
|
||||
lockdown *Lockdown
|
||||
lockdownCtx context.Context
|
||||
lockdownCancelFunc context.CancelFunc
|
||||
}
|
||||
|
||||
func NewAuthService(config AuthServiceConfig, ldap *LdapService, queries *repository.Queries, oauthBroker *OAuthBrokerService) *AuthService {
|
||||
func NewAuthService(config AuthServiceConfig, ldap *LdapService, queries repository.Store, oauthBroker *OAuthBrokerService) *AuthService {
|
||||
return &AuthService{
|
||||
config: config,
|
||||
loginAttempts: make(map[string]*LoginAttempt),
|
||||
@@ -421,7 +420,7 @@ func (auth *AuthService) GetSession(ctx context.Context, uuid string) (*reposito
|
||||
session, err := auth.queries.GetSession(ctx, uuid)
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
if errors.Is(err, repository.ErrNotFound) {
|
||||
return nil, errors.New("session not found")
|
||||
}
|
||||
return nil, err
|
||||
@@ -845,3 +844,10 @@ func (auth *AuthService) ClearRateLimitsTestingOnly() {
|
||||
}
|
||||
auth.loginMutex.Unlock()
|
||||
}
|
||||
|
||||
func (auth *AuthService) getCookieDomain() string {
|
||||
if auth.config.SubdomainsEnabled {
|
||||
return "." + auth.config.CookieDomain
|
||||
}
|
||||
return auth.config.CookieDomain
|
||||
}
|
||||
|
||||
@@ -269,7 +269,7 @@ func (ldap *LdapService) reconnect() error {
|
||||
exp.Reset()
|
||||
|
||||
operation := func() (*ldapgo.Conn, error) {
|
||||
ldap.conn.Close() //nolint:errcheck
|
||||
ldap.conn.Close()
|
||||
conn, err := ldap.connect()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -92,7 +92,7 @@ func simpleReq[T any](client *http.Client, url string, headers map[string]string
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer res.Body.Close() //nolint:errcheck
|
||||
defer res.Body.Close()
|
||||
|
||||
if res.StatusCode < 200 || res.StatusCode >= 300 {
|
||||
return nil, fmt.Errorf("request failed with status: %s", res.Status)
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"crypto/rsa"
|
||||
"crypto/sha256"
|
||||
"crypto/x509"
|
||||
"database/sql"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
@@ -121,7 +120,7 @@ type OIDCServiceConfig struct {
|
||||
|
||||
type OIDCService struct {
|
||||
config OIDCServiceConfig
|
||||
queries *repository.Queries
|
||||
queries repository.Store
|
||||
clients map[string]model.OIDCClientConfig
|
||||
privateKey *rsa.PrivateKey
|
||||
publicKey crypto.PublicKey
|
||||
@@ -129,7 +128,7 @@ type OIDCService struct {
|
||||
isConfigured bool
|
||||
}
|
||||
|
||||
func NewOIDCService(config OIDCServiceConfig, queries *repository.Queries) *OIDCService {
|
||||
func NewOIDCService(config OIDCServiceConfig, queries repository.Store) *OIDCService {
|
||||
return &OIDCService{
|
||||
config: config,
|
||||
queries: queries,
|
||||
@@ -422,7 +421,7 @@ func (service *OIDCService) GetCodeEntry(c *gin.Context, codeHash string, client
|
||||
oidcCode, err := service.queries.GetOidcCode(c, codeHash)
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
if errors.Is(err, repository.ErrNotFound) {
|
||||
return repository.OidcCode{}, ErrCodeNotFound
|
||||
}
|
||||
return repository.OidcCode{}, err
|
||||
@@ -566,7 +565,7 @@ func (service *OIDCService) RefreshAccessToken(c *gin.Context, refreshToken stri
|
||||
entry, err := service.queries.GetOidcTokenByRefreshToken(c, service.Hash(refreshToken))
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
if errors.Is(err, repository.ErrNotFound) {
|
||||
return TokenResponse{}, ErrTokenNotFound
|
||||
}
|
||||
return TokenResponse{}, err
|
||||
@@ -645,7 +644,7 @@ func (service *OIDCService) GetAccessToken(c *gin.Context, tokenHash string) (re
|
||||
entry, err := service.queries.GetOidcToken(c, tokenHash)
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
if errors.Is(err, repository.ErrNotFound) {
|
||||
return repository.OidcToken{}, ErrTokenNotFound
|
||||
}
|
||||
return repository.OidcToken{}, err
|
||||
@@ -733,15 +732,15 @@ func (service *OIDCService) Hash(token string) string {
|
||||
|
||||
func (service *OIDCService) DeleteOldSession(ctx context.Context, sub string) error {
|
||||
err := service.queries.DeleteOidcCodeBySub(ctx, sub)
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
if err != nil && !errors.Is(err, repository.ErrNotFound) {
|
||||
return err
|
||||
}
|
||||
err = service.queries.DeleteOidcTokenBySub(ctx, sub)
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
if err != nil && !errors.Is(err, repository.ErrNotFound) {
|
||||
return err
|
||||
}
|
||||
err = service.queries.DeleteOidcUserInfo(ctx, sub)
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
if err != nil && !errors.Is(err, repository.ErrNotFound) {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
@@ -786,7 +785,7 @@ func (service *OIDCService) Cleanup() {
|
||||
token, err := service.queries.GetOidcTokenBySub(ctx, expiredCode.Sub)
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
if errors.Is(err, repository.ErrNotFound) {
|
||||
continue
|
||||
}
|
||||
tlog.App.Warn().Err(err).Msg("Failed to get OIDC token by sub")
|
||||
|
||||
@@ -18,7 +18,7 @@ func TestReadFile(t *testing.T) {
|
||||
|
||||
err = file.Close()
|
||||
require.NoError(t, err)
|
||||
defer os.Remove("/tmp/tinyauth_test_file") //nolint:errcheck
|
||||
defer os.Remove("/tmp/tinyauth_test_file")
|
||||
|
||||
// Normal case
|
||||
content, err := ReadFile("/tmp/tinyauth_test_file")
|
||||
|
||||
@@ -53,7 +53,7 @@ func FilterIP(filter string, ip string) (bool, error) {
|
||||
return false, errors.New("invalid IP address")
|
||||
}
|
||||
|
||||
filter = strings.ReplaceAll(filter, "-", "/")
|
||||
filter = strings.Replace(filter, "-", "/", -1)
|
||||
|
||||
if strings.Contains(filter, "/") {
|
||||
_, cidr, err := net.ParseCIDR(filter)
|
||||
|
||||
@@ -19,7 +19,7 @@ func TestGetSecret(t *testing.T) {
|
||||
|
||||
err = file.Close()
|
||||
require.NoError(t, err)
|
||||
defer os.Remove("/tmp/tinyauth_test_secret") //nolint:errcheck
|
||||
defer os.Remove("/tmp/tinyauth_test_secret")
|
||||
|
||||
// Get from config
|
||||
assert.Equal(t, "mysecret", utils.GetSecret("mysecret", ""))
|
||||
|
||||
@@ -73,7 +73,7 @@ func TestGetStringList(t *testing.T) {
|
||||
|
||||
err = file.Close()
|
||||
assert.NoError(t, err)
|
||||
defer os.Remove("/tmp/tinyauth_list_test_file") //nolint:errcheck
|
||||
defer os.Remove("/tmp/tinyauth_list_test_file")
|
||||
|
||||
values, err := utils.GetStringList([]string{" first@example.com ", "", "second@example.com"}, "/tmp/tinyauth_list_test_file")
|
||||
assert.NoError(t, err)
|
||||
|
||||
@@ -24,7 +24,7 @@ func TestGetUsers(t *testing.T) {
|
||||
|
||||
err = file.Close()
|
||||
require.NoError(t, err)
|
||||
defer os.Remove(tmpDir + "/tinyauth_users_test.txt") //nolint:errcheck
|
||||
defer os.Remove(tmpDir + "/tinyauth_users_test.txt")
|
||||
|
||||
noAttrs := map[string]model.UserAttributes{}
|
||||
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
version: "2"
|
||||
sql:
|
||||
- engine: "sqlite"
|
||||
queries: "sql/*_queries.sql"
|
||||
schema: "sql/*_schemas.sql"
|
||||
queries: "sql/sqlite/*_queries.sql"
|
||||
schema: "sql/sqlite/*_schemas.sql"
|
||||
gen:
|
||||
go:
|
||||
package: "repository"
|
||||
out: "internal/repository"
|
||||
package: "sqlite"
|
||||
out: "internal/repository/sqlite"
|
||||
rename:
|
||||
uuid: "UUID"
|
||||
oauth_groups: "OAuthGroups"
|
||||
|
||||
Reference in New Issue
Block a user