mirror of
https://github.com/steveiliop56/tinyauth.git
synced 2026-05-19 02:30:14 +00:00
Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 03d6bea4e1 | |||
| a56c349525 | |||
| 8b4ba23328 |
@@ -28,6 +28,18 @@ jobs:
|
||||
- name: Go dependencies
|
||||
run: go mod download
|
||||
|
||||
- name: Setup sqlc
|
||||
uses: sqlc-dev/setup-sqlc@v4
|
||||
with:
|
||||
sqlc-version: "1.31.1"
|
||||
|
||||
- name: Check codegen is up to date
|
||||
run: |
|
||||
sqlc generate
|
||||
go generate ./internal/repository/...
|
||||
git diff --exit-code -- internal/repository/
|
||||
git status --porcelain -- internal/repository/ | grep -q . && echo "untracked files in internal/repository/" && exit 1 || true
|
||||
|
||||
- name: Install frontend dependencies
|
||||
working-directory: ./frontend
|
||||
run: pnpm ci
|
||||
|
||||
@@ -1,24 +0,0 @@
|
||||
name: Close stale issues and PRs
|
||||
on:
|
||||
schedule:
|
||||
- cron: 0 10 * * *
|
||||
|
||||
permissions:
|
||||
issues: write
|
||||
pull-requests: write
|
||||
|
||||
jobs:
|
||||
stale:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/stale@b5d41d4e1d5dceea10e7104786b73624c18a190f # v10
|
||||
with:
|
||||
days-before-stale: 30
|
||||
stale-pr-message: This PR has been inactive for 30 days and will be marked as stale.
|
||||
stale-issue-message: This issue has been inactive for 30 days and will be marked as stale.
|
||||
close-issue-message: Closed for inactivity.
|
||||
close-pr-message: Closed for inactivity.
|
||||
stale-issue-label: stale
|
||||
stale-pr-label: stale
|
||||
exempt-issue-labels: pinned
|
||||
exempt-pr-labels: pinned
|
||||
@@ -85,3 +85,4 @@ sql:
|
||||
# Go gen
|
||||
generate:
|
||||
go run ./gen
|
||||
go generate ./internal/repository/...
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
services:
|
||||
traefik:
|
||||
image: traefik:v3.6
|
||||
command: --api.insecure=true --providers.docker
|
||||
command: --api.insecure=true --providers.docker --entrypoints.web.address=:80 --entrypoints.websecure.address=:443
|
||||
ports:
|
||||
- 80:80
|
||||
- 443:443
|
||||
volumes:
|
||||
- /var/run/docker.sock:/var/run/docker.sock
|
||||
|
||||
@@ -25,6 +26,8 @@ services:
|
||||
labels:
|
||||
traefik.enable: true
|
||||
traefik.http.routers.tinyauth.rule: Host(`tinyauth.127.0.0.1.sslip.io`)
|
||||
traefik.http.routers.tinyauth.entrypoints: websecure
|
||||
traefik.http.routers.tinyauth.tls: true
|
||||
|
||||
tinyauth-backend:
|
||||
build:
|
||||
|
||||
@@ -0,0 +1,473 @@
|
||||
// gen/sqlc-wrapper generates store.go wrapper files for each sqlc driver package under
|
||||
// internal/repository/<driver>/. Run via:
|
||||
//
|
||||
// go generate ./internal/repository/...
|
||||
//
|
||||
// The generator introspects *Queries methods and the model/params types in the
|
||||
// driver package, then emits a store.go that wraps *Queries so it satisfies
|
||||
// repository.Store using the canonical shared types in the parent package.
|
||||
// This generator is specific to sqlc-generated drivers. Non-sqlc drivers should
|
||||
// implement repository.Store directly by hand.
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
_ "embed"
|
||||
"flag"
|
||||
"fmt"
|
||||
"go/format"
|
||||
"go/types"
|
||||
"log"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
"text/template"
|
||||
|
||||
"golang.org/x/tools/go/packages"
|
||||
)
|
||||
|
||||
//go:embed store.tmpl
|
||||
var storeSrc string
|
||||
|
||||
func main() {
|
||||
fmt.Println("sqlc-wrapper: generating store.go files for sqlc driver packages...")
|
||||
if err := run(); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func run() error {
|
||||
driverPkg := flag.String("pkg", "", "import path of the driver package")
|
||||
out := flag.String("out", "store.go", "output filename relative to driver package directory")
|
||||
flag.Parse()
|
||||
|
||||
if *driverPkg == "" {
|
||||
return fmt.Errorf("-pkg is required")
|
||||
}
|
||||
|
||||
// Resolve the driver package directory so we can overlay the output file
|
||||
// with a valid stub. This prevents a stale store.go from poisoning the
|
||||
// type-checker and producing cryptic "undefined" errors.
|
||||
driverDir, err := pkgDir(*driverPkg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("resolve driver dir: %w", err)
|
||||
}
|
||||
|
||||
outPath := filepath.Join(driverDir, *out)
|
||||
if filepath.IsAbs(*out) {
|
||||
outPath = *out
|
||||
}
|
||||
|
||||
// Stub replaces the output file during load so stale generated code is ignored.
|
||||
stub := []byte("package " + filepath.Base(driverDir) + "\n")
|
||||
cfg := &packages.Config{
|
||||
Mode: packages.NeedName | packages.NeedTypes | packages.NeedSyntax | packages.NeedImports,
|
||||
Overlay: map[string][]byte{outPath: stub},
|
||||
}
|
||||
|
||||
driverTypePkg, err := loadOnePkg(cfg, *driverPkg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("load driver package: %w", err)
|
||||
}
|
||||
|
||||
repoPkgPath := parentPkg(*driverPkg)
|
||||
repoTypePkg, err := loadOnePkg(cfg, repoPkgPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("load repo package: %w", err)
|
||||
}
|
||||
|
||||
if err := validateStructShapes(driverTypePkg, repoTypePkg); err != nil {
|
||||
return fmt.Errorf("struct shape mismatch: %w", err)
|
||||
}
|
||||
if err := validateStoreCoverage(driverTypePkg, repoTypePkg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
methods, err := collectMethods(driverTypePkg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
src, err := render(tmplData{
|
||||
PkgName: driverTypePkg.Name(),
|
||||
RepoPkg: repoPkgPath,
|
||||
Methods: renderMethods(methods),
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("render: %w", err)
|
||||
}
|
||||
|
||||
if err := os.WriteFile(outPath, src, 0644); err != nil {
|
||||
return fmt.Errorf("write %s: %w", outPath, err)
|
||||
}
|
||||
fmt.Printf("wrote %s\n", outPath)
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadOnePkg loads a single package via cfg and returns its *types.Package,
|
||||
// or an error if the package fails to load or has type errors.
|
||||
func loadOnePkg(cfg *packages.Config, importPath string) (*types.Package, error) {
|
||||
pkgs, err := packages.Load(cfg, importPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load %s: %w", importPath, err)
|
||||
}
|
||||
if len(pkgs) != 1 {
|
||||
return nil, fmt.Errorf("expected 1 package for %s, got %d", importPath, len(pkgs))
|
||||
}
|
||||
pkg := pkgs[0]
|
||||
if len(pkg.Errors) > 0 {
|
||||
msgs := make([]string, len(pkg.Errors))
|
||||
for i, e := range pkg.Errors {
|
||||
msgs[i] = e.Error()
|
||||
}
|
||||
return nil, fmt.Errorf("package %s has errors:\n %s", importPath, strings.Join(msgs, "\n "))
|
||||
}
|
||||
return pkg.Types, nil
|
||||
}
|
||||
|
||||
// parentPkg returns the parent import path (everything before the last /).
|
||||
// Panics if imp contains no slash — callers are expected to pass driver sub-packages.
|
||||
func parentPkg(imp string) string {
|
||||
i := strings.LastIndex(imp, "/")
|
||||
if i < 0 {
|
||||
panic(fmt.Sprintf("parentPkg: import path %q has no parent", imp))
|
||||
}
|
||||
return imp[:i]
|
||||
}
|
||||
|
||||
// pkgDir returns the on-disk directory for an import path using `go list`.
|
||||
func pkgDir(importPath string) (string, error) {
|
||||
out, err := exec.Command("go", "list", "-f", "{{.Dir}}", importPath).Output()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("go list %s: %w", importPath, err)
|
||||
}
|
||||
return strings.TrimSpace(string(out)), nil
|
||||
}
|
||||
|
||||
// scopeStructs returns all named struct types in pkg, excluding the internal
|
||||
// sqlc types Queries, DBTX, and Store. Names are returned in sorted order.
|
||||
func scopeStructs(pkg *types.Package) (names []string, byName map[string]*types.Struct) {
|
||||
byName = make(map[string]*types.Struct)
|
||||
for _, name := range pkg.Scope().Names() { // Names() is already sorted
|
||||
switch name {
|
||||
case "Queries", "DBTX", "Store":
|
||||
continue
|
||||
}
|
||||
obj, ok := pkg.Scope().Lookup(name).(*types.TypeName)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
named, ok := obj.Type().(*types.Named)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
s, ok := named.Underlying().(*types.Struct)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
names = append(names, name)
|
||||
byName[name] = s
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// validateStoreCoverage checks that every method declared in repository.Store
|
||||
// exists on *Queries in the driver package. Missing methods are reported by
|
||||
// name so the developer knows exactly which SQL queries need to be added.
|
||||
func validateStoreCoverage(driverPkg, repoPkg *types.Package) error {
|
||||
queriesObj := driverPkg.Scope().Lookup("Queries")
|
||||
if queriesObj == nil {
|
||||
return fmt.Errorf("queries type not found in driver package")
|
||||
}
|
||||
queriesNamed := queriesObj.Type().(*types.Named)
|
||||
queriesMS := types.NewMethodSet(types.NewPointer(queriesNamed))
|
||||
queriesMethods := make(map[string]bool)
|
||||
for m := range queriesMS.Methods() {
|
||||
queriesMethods[m.Obj().Name()] = true
|
||||
}
|
||||
|
||||
storeObj := repoPkg.Scope().Lookup("Store")
|
||||
if storeObj == nil {
|
||||
return fmt.Errorf("store type not found in repository package")
|
||||
}
|
||||
storeIface, ok := storeObj.Type().Underlying().(*types.Interface)
|
||||
if !ok {
|
||||
return fmt.Errorf("repository.Store is not an interface")
|
||||
}
|
||||
|
||||
var missing []string
|
||||
for method := range storeIface.Methods() {
|
||||
if name := method.Name(); !queriesMethods[name] {
|
||||
missing = append(missing, name)
|
||||
}
|
||||
}
|
||||
if len(missing) > 0 {
|
||||
sort.Strings(missing)
|
||||
return fmt.Errorf(
|
||||
"driver *Queries is missing %d method(s) required by repository.Store:\n - %s\n\nRun sqlc generate to regenerate query methods, or add the missing SQL queries",
|
||||
len(missing), strings.Join(missing, "\n - "),
|
||||
)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateStructShapes checks that every model/params struct in the driver
|
||||
// package has fields that exactly match the corresponding type in the repo
|
||||
// (parent) package. This catches drift between sqlc-generated types and the
|
||||
// canonical repository types before a broken cast reaches the compiler.
|
||||
func validateStructShapes(driverPkg, repoPkg *types.Package) error {
|
||||
_, repoStructs := scopeStructs(repoPkg)
|
||||
driverNames, driverStructs := scopeStructs(driverPkg)
|
||||
|
||||
var errs []string
|
||||
for _, name := range driverNames {
|
||||
repoStruct, ok := repoStructs[name]
|
||||
if !ok {
|
||||
// Driver has a type not in repo — fine (e.g. internal helpers).
|
||||
continue
|
||||
}
|
||||
if err := compareStructs(name, driverStructs[name], repoStruct); err != nil {
|
||||
errs = append(errs, err.Error())
|
||||
}
|
||||
}
|
||||
if len(errs) > 0 {
|
||||
sort.Strings(errs)
|
||||
return fmt.Errorf("%s", strings.Join(errs, "\n "))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func compareStructs(name string, driver, repo *types.Struct) error {
|
||||
if driver.NumFields() != repo.NumFields() {
|
||||
return fmt.Errorf("%s: field count mismatch (driver=%d, repo=%d)",
|
||||
name, driver.NumFields(), repo.NumFields())
|
||||
}
|
||||
for i := range driver.NumFields() {
|
||||
df := driver.Field(i)
|
||||
rf := repo.Field(i)
|
||||
if df.Name() != rf.Name() {
|
||||
return fmt.Errorf("%s: field %d name mismatch (driver=%q, repo=%q)",
|
||||
name, i, df.Name(), rf.Name())
|
||||
}
|
||||
if !types.Identical(df.Type(), rf.Type()) {
|
||||
return fmt.Errorf("%s.%s: type mismatch (driver=%s, repo=%s)",
|
||||
name, df.Name(), df.Type(), rf.Type())
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type methodInfo struct {
|
||||
Name string
|
||||
Params []paramInfo
|
||||
Results []resultInfo
|
||||
}
|
||||
|
||||
type paramInfo struct {
|
||||
Name string
|
||||
TypeStr string // local (unqualified) type name
|
||||
RepoType string // "repository.X" if this is a driver model/params type; else ""
|
||||
}
|
||||
|
||||
type resultInfo struct {
|
||||
TypeStr string
|
||||
IsSlice bool
|
||||
RepoType string // "repository.X" if driver type; else ""
|
||||
}
|
||||
|
||||
func collectMethods(pkg *types.Package) ([]methodInfo, error) {
|
||||
obj := pkg.Scope().Lookup("Queries")
|
||||
if obj == nil {
|
||||
return nil, fmt.Errorf("queries type not found in %s", pkg.Path())
|
||||
}
|
||||
named, ok := obj.Type().(*types.Named)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("queries is not a named type")
|
||||
}
|
||||
ms := types.NewMethodSet(types.NewPointer(named))
|
||||
|
||||
var out []methodInfo
|
||||
for method := range ms.Methods() {
|
||||
fn, ok := method.Obj().(*types.Func)
|
||||
if !ok || fn.Name() == "WithTx" {
|
||||
continue
|
||||
}
|
||||
sig := fn.Type().(*types.Signature)
|
||||
mi := methodInfo{Name: fn.Name()}
|
||||
|
||||
// params: skip receiver + first (context.Context)
|
||||
for i := 1; i < sig.Params().Len(); i++ {
|
||||
p := sig.Params().At(i)
|
||||
mi.Params = append(mi.Params, makeParam(p.Name(), p.Type(), pkg.Path()))
|
||||
}
|
||||
// results: skip error
|
||||
for r := range sig.Results().Variables() {
|
||||
if r.Type().String() == "error" {
|
||||
continue
|
||||
}
|
||||
mi.Results = append(mi.Results, makeResult(r.Type(), pkg.Path()))
|
||||
}
|
||||
out = append(out, mi)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func makeParam(name string, t types.Type, driverPath string) paramInfo {
|
||||
return paramInfo{
|
||||
Name: name,
|
||||
TypeStr: localName(t, driverPath),
|
||||
RepoType: repoName(t, driverPath),
|
||||
}
|
||||
}
|
||||
|
||||
func makeResult(t types.Type, driverPath string) resultInfo {
|
||||
ri := resultInfo{}
|
||||
if sl, ok := t.(*types.Slice); ok {
|
||||
ri.IsSlice = true
|
||||
t = sl.Elem()
|
||||
}
|
||||
ri.TypeStr = localName(t, driverPath)
|
||||
ri.RepoType = repoName(t, driverPath)
|
||||
return ri
|
||||
}
|
||||
|
||||
func localName(t types.Type, driverPath string) string {
|
||||
named, ok := t.(*types.Named)
|
||||
if !ok {
|
||||
return types.TypeString(t, nil)
|
||||
}
|
||||
if named.Obj().Pkg() != nil && named.Obj().Pkg().Path() == driverPath {
|
||||
return named.Obj().Name()
|
||||
}
|
||||
return types.TypeString(t, func(p *types.Package) string { return p.Name() })
|
||||
}
|
||||
|
||||
func repoName(t types.Type, driverPath string) string {
|
||||
named, ok := t.(*types.Named)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
if named.Obj().Pkg() != nil && named.Obj().Pkg().Path() == driverPath {
|
||||
return "repository." + named.Obj().Name()
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// renderedMethod holds pre-built signature and body strings passed to the template.
|
||||
type renderedMethod struct {
|
||||
Signature string
|
||||
Body string
|
||||
}
|
||||
|
||||
func renderMethods(methods []methodInfo) []renderedMethod {
|
||||
out := make([]renderedMethod, len(methods))
|
||||
for i, m := range methods {
|
||||
out[i] = renderedMethod{
|
||||
Signature: buildSig(m),
|
||||
Body: buildBody(m),
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func buildSig(m methodInfo) string {
|
||||
var sb strings.Builder
|
||||
sb.WriteString("func (s *Store) ")
|
||||
sb.WriteString(m.Name)
|
||||
sb.WriteString("(ctx context.Context")
|
||||
for _, p := range m.Params {
|
||||
sb.WriteString(", ")
|
||||
sb.WriteString(p.Name)
|
||||
sb.WriteString(" ")
|
||||
if p.RepoType != "" {
|
||||
sb.WriteString(p.RepoType)
|
||||
} else {
|
||||
sb.WriteString(p.TypeStr)
|
||||
}
|
||||
}
|
||||
sb.WriteString(") (")
|
||||
for _, r := range m.Results {
|
||||
if r.IsSlice {
|
||||
sb.WriteString("[]")
|
||||
}
|
||||
if r.RepoType != "" {
|
||||
sb.WriteString(r.RepoType)
|
||||
} else {
|
||||
sb.WriteString(r.TypeStr)
|
||||
}
|
||||
sb.WriteString(", ")
|
||||
}
|
||||
sb.WriteString("error)")
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
func callArgs(m methodInfo) string {
|
||||
args := make([]string, 0, len(m.Params))
|
||||
for _, p := range m.Params {
|
||||
if p.RepoType != "" {
|
||||
// convert repo type → driver type: DriverType(arg)
|
||||
args = append(args, p.TypeStr+"("+p.Name+")")
|
||||
} else {
|
||||
args = append(args, p.Name)
|
||||
}
|
||||
}
|
||||
if len(args) == 0 {
|
||||
return "ctx"
|
||||
}
|
||||
return "ctx, " + strings.Join(args, ", ")
|
||||
}
|
||||
|
||||
var bodyTmpl = template.Must(template.New("store").Parse(storeSrc))
|
||||
|
||||
type bodyData struct {
|
||||
Call string
|
||||
RepoType string
|
||||
}
|
||||
|
||||
func buildBody(m methodInfo) string {
|
||||
call := "s.q." + m.Name + "(" + callArgs(m) + ")"
|
||||
|
||||
var (
|
||||
name string
|
||||
data bodyData
|
||||
)
|
||||
|
||||
switch {
|
||||
case len(m.Results) == 0 || m.Results[0].RepoType == "":
|
||||
name = "void"
|
||||
data = bodyData{Call: call}
|
||||
case m.Results[0].IsSlice:
|
||||
name = "slice"
|
||||
data = bodyData{Call: call, RepoType: m.Results[0].RepoType}
|
||||
default:
|
||||
name = "scalar"
|
||||
data = bodyData{Call: call, RepoType: m.Results[0].RepoType}
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
if err := bodyTmpl.ExecuteTemplate(&buf, name, data); err != nil {
|
||||
panic(fmt.Sprintf("buildBody %s: %v", name, err))
|
||||
}
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
type tmplData struct {
|
||||
PkgName string
|
||||
RepoPkg string
|
||||
Methods []renderedMethod
|
||||
}
|
||||
|
||||
func render(data tmplData) ([]byte, error) {
|
||||
var buf bytes.Buffer
|
||||
if err := bodyTmpl.Execute(&buf, data); err != nil {
|
||||
return nil, fmt.Errorf("execute template: %w", err)
|
||||
}
|
||||
|
||||
formatted, err := format.Source(buf.Bytes())
|
||||
if err != nil {
|
||||
return buf.Bytes(), fmt.Errorf("format source: %w\nraw:\n%s", err, buf.String())
|
||||
}
|
||||
return formatted, nil
|
||||
}
|
||||
@@ -0,0 +1,59 @@
|
||||
// Code generated by cmd/gen/sqlc-wrapper. DO NOT EDIT.
|
||||
package {{.PkgName}}
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
|
||||
"{{.RepoPkg}}"
|
||||
)
|
||||
|
||||
// Store wraps *Queries and implements repository.Store.
|
||||
type Store struct {
|
||||
q *Queries
|
||||
}
|
||||
|
||||
// NewStore wraps a *Queries to satisfy repository.Store.
|
||||
func NewStore(q *Queries) repository.Store {
|
||||
return &Store{q: q}
|
||||
}
|
||||
|
||||
var errorMap = map[error]error{
|
||||
sql.ErrNoRows: repository.ErrNotFound,
|
||||
}
|
||||
|
||||
func mapErr(err error) error {
|
||||
for from, to := range errorMap {
|
||||
if errors.Is(err, from) {
|
||||
return to
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
{{range .Methods}}{{.Signature}} {
|
||||
{{.Body}}}
|
||||
|
||||
{{end}}
|
||||
|
||||
{{- define "void"}} return mapErr({{.Call}})
|
||||
{{end}}
|
||||
|
||||
{{- define "scalar"}} r, err := {{.Call}}
|
||||
if err != nil {
|
||||
return {{.RepoType}}{}, mapErr(err)
|
||||
}
|
||||
return {{.RepoType}}(r), nil
|
||||
{{end}}
|
||||
|
||||
{{- define "slice"}} rows, err := {{.Call}}
|
||||
if err != nil {
|
||||
return nil, mapErr(err)
|
||||
}
|
||||
out := make([]{{.RepoType}}, len(rows))
|
||||
for i, row := range rows {
|
||||
out[i] = {{.RepoType}}(row)
|
||||
}
|
||||
return out, nil
|
||||
{{end}}
|
||||
@@ -18,11 +18,12 @@ require (
|
||||
github.com/stretchr/testify v1.11.1
|
||||
github.com/tinyauthapp/paerser v0.0.0-20260410140347-85c3740d6298
|
||||
github.com/weppos/publicsuffix-go v0.50.3
|
||||
golang.org/x/crypto v0.50.0
|
||||
golang.org/x/crypto v0.51.0
|
||||
golang.org/x/oauth2 v0.36.0
|
||||
k8s.io/apimachinery v0.36.0
|
||||
k8s.io/client-go v0.36.0
|
||||
modernc.org/sqlite v1.50.0
|
||||
golang.org/x/tools v0.44.0
|
||||
k8s.io/apimachinery v0.36.1
|
||||
k8s.io/client-go v0.36.1
|
||||
modernc.org/sqlite v1.50.1
|
||||
)
|
||||
|
||||
require (
|
||||
@@ -121,11 +122,12 @@ require (
|
||||
go.yaml.in/yaml/v2 v2.4.3 // indirect
|
||||
golang.org/x/arch v0.22.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect
|
||||
golang.org/x/net v0.52.0 // indirect
|
||||
golang.org/x/mod v0.35.0 // indirect
|
||||
golang.org/x/net v0.53.0 // indirect
|
||||
golang.org/x/sync v0.20.0 // indirect
|
||||
golang.org/x/sys v0.43.0 // indirect
|
||||
golang.org/x/term v0.42.0 // indirect
|
||||
golang.org/x/text v0.36.0 // indirect
|
||||
golang.org/x/sys v0.44.0 // indirect
|
||||
golang.org/x/term v0.43.0 // indirect
|
||||
golang.org/x/text v0.37.0 // indirect
|
||||
golang.org/x/time v0.14.0 // indirect
|
||||
google.golang.org/protobuf v1.36.12-0.20260120151049-f2248ac996af // indirect
|
||||
gopkg.in/inf.v0 v0.9.1 // indirect
|
||||
@@ -134,7 +136,7 @@ require (
|
||||
k8s.io/klog/v2 v2.140.0 // indirect
|
||||
k8s.io/kube-openapi v0.0.0-20260317180543-43fb72c5454a // indirect
|
||||
k8s.io/utils v0.0.0-20260210185600-b8788abfbbc2 // indirect
|
||||
modernc.org/libc v1.72.0 // indirect
|
||||
modernc.org/libc v1.72.3 // indirect
|
||||
modernc.org/mathutil v1.7.1 // indirect
|
||||
modernc.org/memory v1.11.0 // indirect
|
||||
rsc.io/qr v0.2.0 // indirect
|
||||
|
||||
@@ -317,29 +317,29 @@ go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
|
||||
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
|
||||
golang.org/x/arch v0.22.0 h1:c/Zle32i5ttqRXjdLyyHZESLD/bB90DCU1g9l/0YBDI=
|
||||
golang.org/x/arch v0.22.0/go.mod h1:dNHoOeKiyja7GTvF9NJS1l3Z2yntpQNzgrjh1cU103A=
|
||||
golang.org/x/crypto v0.50.0 h1:zO47/JPrL6vsNkINmLoo/PH1gcxpls50DNogFvB5ZGI=
|
||||
golang.org/x/crypto v0.50.0/go.mod h1:3muZ7vA7PBCE6xgPX7nkzzjiUq87kRItoJQM1Yo8S+Q=
|
||||
golang.org/x/crypto v0.51.0 h1:IBPXwPfKxY7cWQZ38ZCIRPI50YLeevDLlLnyC5wRGTI=
|
||||
golang.org/x/crypto v0.51.0/go.mod h1:8AdwkbraGNABw2kOX6YFPs3WM22XqI4EXEd8g+x7Oc8=
|
||||
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY=
|
||||
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70=
|
||||
golang.org/x/mod v0.34.0 h1:xIHgNUUnW6sYkcM5Jleh05DvLOtwc6RitGHbDk4akRI=
|
||||
golang.org/x/mod v0.34.0/go.mod h1:ykgH52iCZe79kzLLMhyCUzhMci+nQj+0XkbXpNYtVjY=
|
||||
golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0=
|
||||
golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw=
|
||||
golang.org/x/mod v0.35.0 h1:Ww1D637e6Pg+Zb2KrWfHQUnH2dQRLBQyAtpr/haaJeM=
|
||||
golang.org/x/mod v0.35.0/go.mod h1:+GwiRhIInF8wPm+4AoT6L0FA1QWAad3OMdTRx4tFYlU=
|
||||
golang.org/x/net v0.53.0 h1:d+qAbo5L0orcWAr0a9JweQpjXF19LMXJE8Ey7hwOdUA=
|
||||
golang.org/x/net v0.53.0/go.mod h1:JvMuJH7rrdiCfbeHoo3fCQU24Lf5JJwT9W3sJFulfgs=
|
||||
golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs=
|
||||
golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q=
|
||||
golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
|
||||
golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI=
|
||||
golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
|
||||
golang.org/x/term v0.42.0 h1:UiKe+zDFmJobeJ5ggPwOshJIVt6/Ft0rcfrXZDLWAWY=
|
||||
golang.org/x/term v0.42.0/go.mod h1:Dq/D+snpsbazcBG5+F9Q1n2rXV8Ma+71xEjTRufARgY=
|
||||
golang.org/x/text v0.36.0 h1:JfKh3XmcRPqZPKevfXVpI1wXPTqbkE5f7JA92a55Yxg=
|
||||
golang.org/x/text v0.36.0/go.mod h1:NIdBknypM8iqVmPiuco0Dh6P5Jcdk8lJL0CUebqK164=
|
||||
golang.org/x/sys v0.44.0 h1:ildZl3J4uzeKP07r2F++Op7E9B29JRUy+a27EibtBTQ=
|
||||
golang.org/x/sys v0.44.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
|
||||
golang.org/x/term v0.43.0 h1:S4RLU2sB31O/NCl+zFN9Aru9A/Cq2aqKpTZJ6B+DwT4=
|
||||
golang.org/x/term v0.43.0/go.mod h1:lrhlHNdQJHO+1qVYiHfFKVuVioJIheAc3fBSMFYEIsk=
|
||||
golang.org/x/text v0.37.0 h1:Cqjiwd9eSg8e0QAkyCaQTNHFIIzWtidPahFWR83rTrc=
|
||||
golang.org/x/text v0.37.0/go.mod h1:a5sjxXGs9hsn/AJVwuElvCAo9v8QYLzvavO5z2PiM38=
|
||||
golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI=
|
||||
golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
|
||||
golang.org/x/tools v0.43.0 h1:12BdW9CeB3Z+J/I/wj34VMl8X+fEXBxVR90JeMX5E7s=
|
||||
golang.org/x/tools v0.43.0/go.mod h1:uHkMso649BX2cZK6+RpuIPXS3ho2hZo4FVwfoy1vIk0=
|
||||
golang.org/x/tools v0.44.0 h1:UP4ajHPIcuMjT1GqzDWRlalUEoY+uzoZKnhOjbIPD2c=
|
||||
golang.org/x/tools v0.44.0/go.mod h1:KA0AfVErSdxRZIsOVipbv3rQhVXTnlU6UhKxHd1seDI=
|
||||
google.golang.org/genproto v0.0.0-20250603155806-513f23925822 h1:rHWScKit0gvAPuOnu87KpaYtjK5zBMLcULh7gxkCXu4=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20260401024825-9d38bb4040a9 h1:VPWxll4HlMw1Vs/qXtN7BvhZqsS9cdAittCNvVENElA=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20260401024825-9d38bb4040a9/go.mod h1:7QBABkRtR8z+TEnmXTqIqwJLlzrZKVfAUm7tY3yGv0M=
|
||||
@@ -361,22 +361,22 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gotest.tools/v3 v3.5.2 h1:7koQfIKdy+I8UTetycgUqXWSDwpgv193Ka+qRsmBY8Q=
|
||||
gotest.tools/v3 v3.5.2/go.mod h1:LtdLGcnqToBH83WByAAi/wiwSFCArdFIUV/xxN4pcjA=
|
||||
k8s.io/api v0.36.0 h1:SgqDhZzHdOtMk40xVSvCXkP9ME0H05hPM3p9AB1kL80=
|
||||
k8s.io/api v0.36.0/go.mod h1:m1LVrGPNYax5NBHdO+QuAedXyuzTt4RryI/qnmNvs34=
|
||||
k8s.io/apimachinery v0.36.0 h1:jZyPzhd5Z+3h9vJLt0z9XdzW9VzNzWAUw+P1xZ9PXtQ=
|
||||
k8s.io/apimachinery v0.36.0/go.mod h1:FklypaRJt6n5wUIwWXIP6GJlIpUizTgfo1T/As+Tyxc=
|
||||
k8s.io/client-go v0.36.0 h1:pOYi7C4RHChYjMiHpZSpSbIM6ZxVbRXBy7CuiIwqA3c=
|
||||
k8s.io/client-go v0.36.0/go.mod h1:ZKKcpwF0aLYfkHFCjillCKaTK/yBkEDHTDXCFY6AS9Y=
|
||||
k8s.io/api v0.36.1 h1:XbL/EMj8K2aJpJtePmqUyQMsM0D4QI2pvl7YKJ20FTY=
|
||||
k8s.io/api v0.36.1/go.mod h1:KOWo4ey3TINlXjeHVuwB3i+tXXnu+UcwFBHlI/9dvEo=
|
||||
k8s.io/apimachinery v0.36.1 h1:G63Gjx2W+q0YD+72Vo8oY0nDnePVwnuzTmmy5ENrVSA=
|
||||
k8s.io/apimachinery v0.36.1/go.mod h1:ibYOR00vW/I1kzvi5SF0dRuJ52BvKtfvRdOn35GPQ+8=
|
||||
k8s.io/client-go v0.36.1 h1:FN/K8QIT2CEDt+2WB2HnWrUANZ50AP5GII43/SP2JR0=
|
||||
k8s.io/client-go v0.36.1/go.mod h1:s6rAnCtTGYDQnpNjEhSaISV+2O8jwruZ6m3QOYBFbtU=
|
||||
k8s.io/klog/v2 v2.140.0 h1:Tf+J3AH7xnUzZyVVXhTgGhEKnFqye14aadWv7bzXdzc=
|
||||
k8s.io/klog/v2 v2.140.0/go.mod h1:o+/RWfJ6PwpnFn7OyAG3QnO47BFsymfEfrz6XyYSSp0=
|
||||
k8s.io/kube-openapi v0.0.0-20260317180543-43fb72c5454a h1:xCeOEAOoGYl2jnJoHkC3hkbPJgdATINPMAxaynU2Ovg=
|
||||
k8s.io/kube-openapi v0.0.0-20260317180543-43fb72c5454a/go.mod h1:uGBT7iTA6c6MvqUvSXIaYZo9ukscABYi2btjhvgKGZ0=
|
||||
k8s.io/utils v0.0.0-20260210185600-b8788abfbbc2 h1:AZYQSJemyQB5eRxqcPky+/7EdBj0xi3g0ZcxxJ7vbWU=
|
||||
k8s.io/utils v0.0.0-20260210185600-b8788abfbbc2/go.mod h1:xDxuJ0whA3d0I4mf/C4ppKHxXynQ+fxnkmQH0vTHnuk=
|
||||
modernc.org/cc/v4 v4.27.3 h1:uNCgn37E5U09mTv1XgskEVUJ8ADKpmFMPxzGJ0TSo+U=
|
||||
modernc.org/cc/v4 v4.27.3/go.mod h1:3YjcbCqhoTTHPycJDRl2WZKKFj0nwcOIPBfEZK0Hdk8=
|
||||
modernc.org/ccgo/v4 v4.32.4 h1:L5OB8rpEX4ZsXEQwGozRfJyJSFHbbNVOoQ59DU9/KuU=
|
||||
modernc.org/ccgo/v4 v4.32.4/go.mod h1:lY7f+fiTDHfcv6YlRgSkxYfhs+UvOEEzj49jAn2TOx0=
|
||||
modernc.org/cc/v4 v4.28.2 h1:3tQ0lf2ADtoby2EtSP+J7IE2SHwEJdP8ioR59wx7XpY=
|
||||
modernc.org/cc/v4 v4.28.2/go.mod h1:OnovgIhbbMXMu1aISnJ0wvVD1KnW+cAUJkIrAWh+kVI=
|
||||
modernc.org/ccgo/v4 v4.34.0 h1:yRLPFZieg532OT4rp4JFNIVcquwalMX26G95WQDqwCQ=
|
||||
modernc.org/ccgo/v4 v4.34.0/go.mod h1:AS5WYMyBakQ+fhsHhtP8mWB82KTGPkNNJDGfGQCe0/A=
|
||||
modernc.org/fileutil v1.4.0 h1:j6ZzNTftVS054gi281TyLjHPp6CPHr2KCxEXjEbD6SM=
|
||||
modernc.org/fileutil v1.4.0/go.mod h1:EqdKFDxiByqxLk8ozOxObDSfcVOv/54xDs/DUHdvCUU=
|
||||
modernc.org/gc/v2 v2.6.5 h1:nyqdV8q46KvTpZlsw66kWqwXRHdjIlJOhG6kxiV/9xI=
|
||||
@@ -385,18 +385,18 @@ modernc.org/gc/v3 v3.1.2 h1:ZtDCnhonXSZexk/AYsegNRV1lJGgaNZJuKjJSWKyEqo=
|
||||
modernc.org/gc/v3 v3.1.2/go.mod h1:HFK/6AGESC7Ex+EZJhJ2Gni6cTaYpSMmU/cT9RmlfYY=
|
||||
modernc.org/goabi0 v0.2.0 h1:HvEowk7LxcPd0eq6mVOAEMai46V+i7Jrj13t4AzuNks=
|
||||
modernc.org/goabi0 v0.2.0/go.mod h1:CEFRnnJhKvWT1c1JTI3Avm+tgOWbkOu5oPA8eH8LnMI=
|
||||
modernc.org/libc v1.72.0 h1:IEu559v9a0XWjw0DPoVKtXpO2qt5NVLAnFaBbjq+n8c=
|
||||
modernc.org/libc v1.72.0/go.mod h1:tTU8DL8A+XLVkEY3x5E/tO7s2Q/q42EtnNWda/L5QhQ=
|
||||
modernc.org/libc v1.72.3 h1:ZnDF4tXn4NBXFutMMQC4vtbTFSXhhKzR73fv0beZEAU=
|
||||
modernc.org/libc v1.72.3/go.mod h1:dn0dZNnnn1clLyvRxLxYExxiKRZIRENOfqQ8XEeg4Qs=
|
||||
modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU=
|
||||
modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg=
|
||||
modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI=
|
||||
modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw=
|
||||
modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8=
|
||||
modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns=
|
||||
modernc.org/opt v0.2.0 h1:tGyef5ApycA7FSEOMraay9SaTk5zmbx7Tu+cJs4QKZg=
|
||||
modernc.org/opt v0.2.0/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns=
|
||||
modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w=
|
||||
modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE=
|
||||
modernc.org/sqlite v1.50.0 h1:eMowQSWLK0MeiQTdmz3lqoF5dqclujdlIKeJA11+7oM=
|
||||
modernc.org/sqlite v1.50.0/go.mod h1:m0w8xhwYUVY3H6pSDwc3gkJ/irZT/0YEXwBlhaxQEew=
|
||||
modernc.org/sqlite v1.50.1 h1:l+cQvn0sd0zJJtfygGHuQJ5AjlrwXmWPw4KP3ZMwr9w=
|
||||
modernc.org/sqlite v1.50.1/go.mod h1:tcNzv5p84E0skkmJn038y+hWJbLQXQqEnQfeh5r2JLM=
|
||||
modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0=
|
||||
modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A=
|
||||
modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=
|
||||
|
||||
@@ -11,5 +11,5 @@ var FrontendAssets embed.FS
|
||||
|
||||
// Migrations
|
||||
//
|
||||
//go:embed migrations/*.sql
|
||||
//go:embed migrations/sqlite/*.sql
|
||||
var Migrations embed.FS
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||
@@ -34,7 +35,6 @@ type Services struct {
|
||||
ldapService *service.LdapService
|
||||
oauthBrokerService *service.OAuthBrokerService
|
||||
oidcService *service.OIDCService
|
||||
policyEngine *service.PolicyEngine
|
||||
}
|
||||
|
||||
type BootstrapApp struct {
|
||||
@@ -44,7 +44,7 @@ type BootstrapApp struct {
|
||||
log *logger.Logger
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
queries *repository.Queries
|
||||
queries repository.Store
|
||||
router *gin.Engine
|
||||
db *sql.DB
|
||||
wg sync.WaitGroup
|
||||
@@ -163,7 +163,7 @@ func (app *BootstrapApp) Setup() error {
|
||||
app.runtime.OAuthSessionCookieName = fmt.Sprintf("%s-%s", model.OAuthSessionCookieName, cookieId)
|
||||
|
||||
// database
|
||||
err = app.SetupDatabase()
|
||||
store, err := app.SetupStore()
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to setup database: %w", err)
|
||||
@@ -174,12 +174,13 @@ func (app *BootstrapApp) Setup() error {
|
||||
defer func() {
|
||||
app.cancel()
|
||||
app.wg.Wait()
|
||||
app.db.Close()
|
||||
if app.db != nil {
|
||||
app.db.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
// queries
|
||||
queries := repository.New(app.db)
|
||||
app.queries = queries
|
||||
// store
|
||||
app.queries = store
|
||||
|
||||
// services
|
||||
err = app.setupServices()
|
||||
|
||||
@@ -7,6 +7,9 @@ import (
|
||||
"path/filepath"
|
||||
|
||||
"github.com/tinyauthapp/tinyauth/internal/assets"
|
||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
|
||||
"github.com/tinyauthapp/tinyauth/internal/repository/sqlite"
|
||||
|
||||
"github.com/golang-migrate/migrate/v4"
|
||||
"github.com/golang-migrate/migrate/v4/database/sqlite3"
|
||||
@@ -14,17 +17,28 @@ import (
|
||||
_ "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
func (app *BootstrapApp) SetupDatabase() error {
|
||||
dir := filepath.Dir(app.config.Database.Path)
|
||||
func (app *BootstrapApp) SetupStore() (repository.Store, error) {
|
||||
switch app.config.Database.Driver {
|
||||
case "memory":
|
||||
return memory.New(), nil
|
||||
case "sqlite", "":
|
||||
return app.setupSQLite(app.config.Database.Path)
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown database driver %q: valid values are sqlite, memory", app.config.Database.Driver)
|
||||
}
|
||||
}
|
||||
|
||||
func (app *BootstrapApp) setupSQLite(databasePath string) (repository.Store, error) {
|
||||
dir := filepath.Dir(databasePath)
|
||||
|
||||
if err := os.MkdirAll(dir, 0750); err != nil {
|
||||
return fmt.Errorf("failed to create database directory %s: %w", dir, err)
|
||||
return nil, fmt.Errorf("failed to create database directory %s: %w", dir, err)
|
||||
}
|
||||
|
||||
db, err := sql.Open("sqlite", app.config.Database.Path)
|
||||
db, err := sql.Open("sqlite", databasePath)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open database: %w", err)
|
||||
return nil, fmt.Errorf("failed to open database: %w", err)
|
||||
}
|
||||
|
||||
// Close the database if there is an error during migration
|
||||
@@ -38,32 +52,29 @@ func (app *BootstrapApp) SetupDatabase() error {
|
||||
// if the sqlite connection starts being a bottleneck
|
||||
db.SetMaxOpenConns(1)
|
||||
|
||||
migrations, err := iofs.New(assets.Migrations, "migrations")
|
||||
migrations, err := iofs.New(assets.Migrations, "migrations/sqlite")
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create migrations: %w", err)
|
||||
return nil, fmt.Errorf("failed to create migrations: %w", err)
|
||||
}
|
||||
|
||||
target, err := sqlite3.WithInstance(db, &sqlite3.Config{})
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create sqlite3 instance: %w", err)
|
||||
return nil, fmt.Errorf("failed to create sqlite3 instance: %w", err)
|
||||
}
|
||||
|
||||
migrator, err := migrate.NewWithInstance("iofs", migrations, "sqlite3", target)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create migrator: %w", err)
|
||||
return nil, fmt.Errorf("failed to create migrator: %w", err)
|
||||
}
|
||||
|
||||
if err := migrator.Up(); err != nil && err != migrate.ErrNoChange {
|
||||
return fmt.Errorf("failed to migrate database: %w", err)
|
||||
return nil, fmt.Errorf("failed to migrate database: %w", err)
|
||||
}
|
||||
|
||||
app.db = db
|
||||
return nil
|
||||
}
|
||||
|
||||
func (app *BootstrapApp) GetDB() *sql.DB {
|
||||
return app.db
|
||||
return sqlite.NewStore(sqlite.New(db)), nil
|
||||
}
|
||||
|
||||
@@ -44,7 +44,7 @@ func (app *BootstrapApp) setupRouter() error {
|
||||
controller.NewContextController(app.log, app.config, app.runtime, apiRouter)
|
||||
controller.NewOAuthController(app.log, app.config, app.runtime, apiRouter, app.services.authService)
|
||||
controller.NewOIDCController(app.log, app.services.oidcService, app.runtime, apiRouter)
|
||||
controller.NewProxyController(app.log, app.runtime, apiRouter, app.services.accessControlService, app.services.authService, app.services.policyEngine)
|
||||
controller.NewProxyController(app.log, app.runtime, apiRouter, app.services.accessControlService, app.services.authService)
|
||||
controller.NewUserController(app.log, app.runtime, apiRouter, app.services.authService)
|
||||
controller.NewResourcesController(app.config, &engine.RouterGroup)
|
||||
controller.NewHealthController(apiRouter)
|
||||
|
||||
@@ -16,21 +16,38 @@ func (app *BootstrapApp) setupServices() error {
|
||||
|
||||
app.services.ldapService = ldapService
|
||||
|
||||
labelProvider, err := app.getLabelProvider()
|
||||
useKubernetes := app.config.LabelProvider == "kubernetes" ||
|
||||
(app.config.LabelProvider == "auto" && os.Getenv("KUBERNETES_SERVICE_HOST") != "")
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to initialize label provider: %w", err)
|
||||
var labelProvider service.LabelProvider
|
||||
|
||||
if useKubernetes {
|
||||
app.log.App.Debug().Msg("Using Kubernetes label provider")
|
||||
|
||||
kubernetesService, err := service.NewKubernetesService(app.log, app.ctx, &app.wg)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to initialize kubernetes service: %w", err)
|
||||
}
|
||||
|
||||
app.services.kubernetesService = kubernetesService
|
||||
labelProvider = kubernetesService
|
||||
} else {
|
||||
app.log.App.Debug().Msg("Using Docker label provider")
|
||||
|
||||
dockerService, err := service.NewDockerService(app.log, app.ctx, &app.wg)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to initialize docker service: %w", err)
|
||||
}
|
||||
|
||||
app.services.dockerService = dockerService
|
||||
labelProvider = dockerService
|
||||
}
|
||||
|
||||
accessControlsService := service.NewAccessControlsService(app.log, app.config, &labelProvider)
|
||||
accessControlsService := service.NewAccessControlsService(app.log, &labelProvider, app.config.Apps)
|
||||
app.services.accessControlService = accessControlsService
|
||||
|
||||
err = app.setupPolicyEngine()
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to initialize policy engine: %w", err)
|
||||
}
|
||||
|
||||
oauthBrokerService := service.NewOAuthBrokerService(app.log, app.runtime.OAuthProviders, app.ctx)
|
||||
app.services.oauthBrokerService = oauthBrokerService
|
||||
|
||||
@@ -47,79 +64,3 @@ func (app *BootstrapApp) setupServices() error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (app *BootstrapApp) getLabelProvider() (service.LabelProvider, error) {
|
||||
switch app.config.LabelProvider {
|
||||
case "none", "docker", "kubernetes", "auto":
|
||||
if app.config.LabelProvider == "none" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
useKubernetes := app.config.LabelProvider == "kubernetes" ||
|
||||
(app.config.LabelProvider == "auto" && os.Getenv("KUBERNETES_SERVICE_HOST") != "")
|
||||
|
||||
if useKubernetes {
|
||||
app.log.App.Debug().Msg("Using Kubernetes label provider")
|
||||
|
||||
kubernetesService, err := service.NewKubernetesService(app.log, app.ctx, &app.wg)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize kubernetes service: %w", err)
|
||||
}
|
||||
|
||||
app.services.kubernetesService = kubernetesService
|
||||
return kubernetesService, nil
|
||||
}
|
||||
|
||||
app.log.App.Debug().Msg("Using Docker label provider")
|
||||
|
||||
dockerService, err := service.NewDockerService(app.log, app.ctx, &app.wg)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize docker service: %w", err)
|
||||
}
|
||||
|
||||
if dockerService == nil {
|
||||
if app.config.LabelProvider == "docker" {
|
||||
app.log.App.Warn().Msg("Docker label provider selected but Docker is not available, will continue without it")
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
app.services.dockerService = dockerService
|
||||
return dockerService, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid label provider: %s", app.config.LabelProvider)
|
||||
}
|
||||
}
|
||||
|
||||
func (app *BootstrapApp) setupPolicyEngine() error {
|
||||
policyEngine, err := service.NewPolicyEngine(app.config, app.log)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to initialize policy engine: %w", err)
|
||||
}
|
||||
|
||||
policyEngine.RegisterRule(service.RuleUserAllowed, &service.UserAllowedRule{
|
||||
Log: app.log,
|
||||
})
|
||||
policyEngine.RegisterRule(service.RuleOAuthGroup, &service.OAuthGroupRule{
|
||||
Log: app.log,
|
||||
})
|
||||
policyEngine.RegisterRule(service.RuleLDAPGroup, &service.LDAPGroupRule{
|
||||
Log: app.log,
|
||||
})
|
||||
policyEngine.RegisterRule(service.RuleAuthEnabled, &service.AuthEnabledRule{
|
||||
Log: app.log,
|
||||
})
|
||||
policyEngine.RegisterRule(service.RuleIPAllowed, &service.IPAllowedRule{
|
||||
Log: app.log,
|
||||
Config: app.config,
|
||||
})
|
||||
policyEngine.RegisterRule(service.RuleIPBypassed, &service.IPBypassedRule{
|
||||
Log: app.log,
|
||||
})
|
||||
|
||||
app.services.policyEngine = policyEngine
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -15,10 +15,9 @@ import (
|
||||
"github.com/google/go-querystring/query"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tinyauthapp/tinyauth/internal/bootstrap"
|
||||
"github.com/tinyauthapp/tinyauth/internal/controller"
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
|
||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||
"github.com/tinyauthapp/tinyauth/internal/test"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||
@@ -839,16 +838,11 @@ func TestOIDCController(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
app := bootstrap.NewBootstrapApp(cfg)
|
||||
|
||||
err := app.SetupDatabase()
|
||||
require.NoError(t, err)
|
||||
|
||||
queries := repository.New(app.GetDB())
|
||||
store := memory.New()
|
||||
|
||||
wg := &sync.WaitGroup{}
|
||||
|
||||
oidcService, err := service.NewOIDCService(log, cfg, runtime, queries, context.TODO(), wg)
|
||||
oidcService, err := service.NewOIDCService(log, cfg, runtime, store, context.TODO(), wg)
|
||||
require.NoError(t, err)
|
||||
|
||||
for _, test := range tests {
|
||||
@@ -869,8 +863,4 @@ func TestOIDCController(t *testing.T) {
|
||||
test.run(t, router, recorder)
|
||||
})
|
||||
}
|
||||
|
||||
t.Cleanup(func() {
|
||||
app.GetDB().Close()
|
||||
})
|
||||
}
|
||||
|
||||
@@ -3,7 +3,6 @@ package controller
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"regexp"
|
||||
@@ -52,11 +51,10 @@ type ProxyContext struct {
|
||||
}
|
||||
|
||||
type ProxyController struct {
|
||||
log *logger.Logger
|
||||
runtime model.RuntimeConfig
|
||||
acls *service.AccessControlsService
|
||||
auth *service.AuthService
|
||||
policyEngine *service.PolicyEngine
|
||||
log *logger.Logger
|
||||
runtime model.RuntimeConfig
|
||||
acls *service.AccessControlsService
|
||||
auth *service.AuthService
|
||||
}
|
||||
|
||||
func NewProxyController(
|
||||
@@ -65,14 +63,12 @@ func NewProxyController(
|
||||
router *gin.RouterGroup,
|
||||
acls *service.AccessControlsService,
|
||||
auth *service.AuthService,
|
||||
policyEngine *service.PolicyEngine,
|
||||
) *ProxyController {
|
||||
controller := &ProxyController{
|
||||
log: log,
|
||||
runtime: runtime,
|
||||
acls: acls,
|
||||
auth: auth,
|
||||
policyEngine: policyEngine,
|
||||
log: log,
|
||||
runtime: runtime,
|
||||
acls: acls,
|
||||
auth: auth,
|
||||
}
|
||||
|
||||
proxyGroup := router.Group("/auth")
|
||||
@@ -105,13 +101,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
||||
|
||||
clientIP := c.ClientIP()
|
||||
|
||||
aclsCtx := &service.ACLContext{
|
||||
ACLs: acls,
|
||||
IP: net.ParseIP(clientIP),
|
||||
Path: proxyCtx.Path,
|
||||
}
|
||||
|
||||
if controller.policyEngine.Evaluate(service.RuleIPBypassed, aclsCtx) {
|
||||
if controller.auth.IsBypassedIP(clientIP, acls) {
|
||||
controller.setHeaders(c, acls)
|
||||
c.JSON(200, gin.H{
|
||||
"status": 200,
|
||||
@@ -120,7 +110,15 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if controller.policyEngine.Evaluate(service.RuleAuthEnabled, aclsCtx) {
|
||||
authEnabled, err := controller.auth.IsAuthEnabled(proxyCtx.Path, acls)
|
||||
|
||||
if err != nil {
|
||||
controller.log.App.Error().Err(err).Msg("Failed to determine if authentication is enabled for resource")
|
||||
controller.handleError(c, proxyCtx)
|
||||
return
|
||||
}
|
||||
|
||||
if !authEnabled {
|
||||
controller.log.App.Debug().Msg("Authentication is disabled for this resource, allowing access without authentication")
|
||||
controller.setHeaders(c, acls)
|
||||
c.JSON(200, gin.H{
|
||||
@@ -130,7 +128,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if !controller.policyEngine.Evaluate(service.RuleIPAllowed, aclsCtx) {
|
||||
if !controller.auth.CheckIP(clientIP, acls) {
|
||||
queries, err := query.Values(UnauthorizedQuery{
|
||||
Resource: strings.Split(proxyCtx.Host, ".")[0],
|
||||
IP: clientIP,
|
||||
@@ -166,10 +164,10 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
aclsCtx.UserContext = userContext
|
||||
|
||||
if userContext.Authenticated {
|
||||
if !controller.policyEngine.Evaluate(service.RuleUserAllowed, aclsCtx) {
|
||||
userAllowed := controller.auth.IsUserAllowed(c, *userContext, acls)
|
||||
|
||||
if !userAllowed {
|
||||
controller.log.App.Warn().Str("user", userContext.GetUsername()).Str("resource", strings.Split(proxyCtx.Host, ".")[0]).Msg("User is not allowed to access resource")
|
||||
|
||||
queries, err := query.Values(UnauthorizedQuery{
|
||||
@@ -207,9 +205,9 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
||||
var groupOK bool
|
||||
|
||||
if userContext.IsOAuth() {
|
||||
groupOK = controller.policyEngine.Evaluate(service.RuleOAuthGroup, aclsCtx)
|
||||
groupOK = controller.auth.IsInOAuthGroup(c, *userContext, acls)
|
||||
} else {
|
||||
groupOK = controller.policyEngine.Evaluate(service.RuleLDAPGroup, aclsCtx)
|
||||
groupOK = controller.auth.IsInLDAPGroup(c, *userContext, acls)
|
||||
}
|
||||
|
||||
if !groupOK {
|
||||
|
||||
@@ -8,11 +8,9 @@ import (
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tinyauthapp/tinyauth/internal/bootstrap"
|
||||
"github.com/tinyauthapp/tinyauth/internal/controller"
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
|
||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||
"github.com/tinyauthapp/tinyauth/internal/test"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||
@@ -24,6 +22,33 @@ func TestProxyController(t *testing.T) {
|
||||
|
||||
cfg, runtime := test.CreateTestConfigs(t)
|
||||
|
||||
acls := map[string]model.App{
|
||||
"app_path_allow": {
|
||||
Config: model.AppConfig{
|
||||
Domain: "path-allow.example.com",
|
||||
},
|
||||
Path: model.AppPath{
|
||||
Allow: "/allowed",
|
||||
},
|
||||
},
|
||||
"app_user_allow": {
|
||||
Config: model.AppConfig{
|
||||
Domain: "user-allow.example.com",
|
||||
},
|
||||
Users: model.AppUsers{
|
||||
Allow: "testuser",
|
||||
},
|
||||
},
|
||||
"ip_bypass": {
|
||||
Config: model.AppConfig{
|
||||
Domain: "ip-bypass.example.com",
|
||||
},
|
||||
IP: model.AppIP{
|
||||
Bypass: []string{"10.10.10.10"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
const browserUserAgent = `
|
||||
Mozilla/5.0 (Linux; Android 8.0.0; SM-G955U Build/R16NW) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/146.0.0.0 Mobile Safari/537.36`
|
||||
|
||||
@@ -352,41 +377,14 @@ func TestProxyController(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
app := bootstrap.NewBootstrapApp(cfg)
|
||||
|
||||
err := app.SetupDatabase()
|
||||
require.NoError(t, err)
|
||||
|
||||
queries := repository.New(app.GetDB())
|
||||
store := memory.New()
|
||||
|
||||
wg := &sync.WaitGroup{}
|
||||
ctx := context.TODO()
|
||||
|
||||
broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx)
|
||||
authService := service.NewAuthService(log, cfg, runtime, ctx, wg, nil, queries, broker)
|
||||
aclsService := service.NewAccessControlsService(log, cfg, nil)
|
||||
policyEngine, err := service.NewPolicyEngine(cfg, log)
|
||||
require.NoError(t, err)
|
||||
|
||||
policyEngine.RegisterRule(service.RuleUserAllowed, &service.UserAllowedRule{
|
||||
Log: log,
|
||||
})
|
||||
policyEngine.RegisterRule(service.RuleOAuthGroup, &service.OAuthGroupRule{
|
||||
Log: log,
|
||||
})
|
||||
policyEngine.RegisterRule(service.RuleLDAPGroup, &service.LDAPGroupRule{
|
||||
Log: log,
|
||||
})
|
||||
policyEngine.RegisterRule(service.RuleAuthEnabled, &service.AuthEnabledRule{
|
||||
Log: log,
|
||||
})
|
||||
policyEngine.RegisterRule(service.RuleIPAllowed, &service.IPAllowedRule{
|
||||
Log: log,
|
||||
Config: cfg,
|
||||
})
|
||||
policyEngine.RegisterRule(service.RuleIPBypassed, &service.IPBypassedRule{
|
||||
Log: log,
|
||||
})
|
||||
authService := service.NewAuthService(log, cfg, runtime, ctx, wg, nil, store, broker)
|
||||
aclsService := service.NewAccessControlsService(log, nil, acls)
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.description, func(t *testing.T) {
|
||||
@@ -401,13 +399,9 @@ func TestProxyController(t *testing.T) {
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
controller.NewProxyController(log, runtime, group, aclsService, authService, policyEngine)
|
||||
controller.NewProxyController(log, runtime, group, aclsService, authService)
|
||||
|
||||
test.run(t, router, recorder)
|
||||
})
|
||||
}
|
||||
|
||||
t.Cleanup(func() {
|
||||
app.GetDB().Close()
|
||||
})
|
||||
}
|
||||
|
||||
@@ -14,10 +14,10 @@ import (
|
||||
"github.com/pquerna/otp/totp"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tinyauthapp/tinyauth/internal/bootstrap"
|
||||
"github.com/tinyauthapp/tinyauth/internal/controller"
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
|
||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||
"github.com/tinyauthapp/tinyauth/internal/test"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||
@@ -73,12 +73,7 @@ func TestUserController(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
app := bootstrap.NewBootstrapApp(cfg)
|
||||
|
||||
err := app.SetupDatabase()
|
||||
require.NoError(t, err)
|
||||
|
||||
queries := repository.New(app.GetDB())
|
||||
store := memory.New()
|
||||
|
||||
type testCase struct {
|
||||
description string
|
||||
@@ -254,7 +249,7 @@ func TestUserController(t *testing.T) {
|
||||
totpCtx,
|
||||
},
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||
_, err := queries.CreateSession(context.TODO(), repository.CreateSessionParams{
|
||||
_, err := store.CreateSession(context.TODO(), repository.CreateSessionParams{
|
||||
UUID: "test-totp-login-uuid",
|
||||
Username: "test",
|
||||
Email: "test@example.com",
|
||||
@@ -378,7 +373,7 @@ func TestUserController(t *testing.T) {
|
||||
totpAttrCtx,
|
||||
},
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||
_, err := queries.CreateSession(context.TODO(), repository.CreateSessionParams{
|
||||
_, err := store.CreateSession(context.TODO(), repository.CreateSessionParams{
|
||||
UUID: "test-totp-login-attributes-uuid",
|
||||
Username: "test",
|
||||
Email: "test@example.com",
|
||||
@@ -420,7 +415,7 @@ func TestUserController(t *testing.T) {
|
||||
wg := &sync.WaitGroup{}
|
||||
|
||||
broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx)
|
||||
authService := service.NewAuthService(log, cfg, runtime, ctx, wg, nil, queries, broker)
|
||||
authService := service.NewAuthService(log, cfg, runtime, ctx, wg, nil, store, broker)
|
||||
|
||||
beforeEach := func() {
|
||||
// Clear failed login attempts before each test
|
||||
@@ -446,8 +441,4 @@ func TestUserController(t *testing.T) {
|
||||
test.run(t, router, recorder)
|
||||
})
|
||||
}
|
||||
|
||||
t.Cleanup(func() {
|
||||
app.GetDB().Close()
|
||||
})
|
||||
}
|
||||
|
||||
@@ -11,9 +11,8 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tinyauthapp/tinyauth/internal/bootstrap"
|
||||
"github.com/tinyauthapp/tinyauth/internal/controller"
|
||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
|
||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||
"github.com/tinyauthapp/tinyauth/internal/test"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||
@@ -92,14 +91,9 @@ func TestWellKnownController(t *testing.T) {
|
||||
ctx := context.TODO()
|
||||
wg := &sync.WaitGroup{}
|
||||
|
||||
app := bootstrap.NewBootstrapApp(cfg)
|
||||
store := memory.New()
|
||||
|
||||
err := app.SetupDatabase()
|
||||
require.NoError(t, err)
|
||||
|
||||
queries := repository.New(app.GetDB())
|
||||
|
||||
oidcService, err := service.NewOIDCService(log, cfg, runtime, queries, ctx, wg)
|
||||
oidcService, err := service.NewOIDCService(log, cfg, runtime, store, ctx, wg)
|
||||
require.NoError(t, err)
|
||||
|
||||
for _, test := range tests {
|
||||
@@ -114,8 +108,4 @@ func TestWellKnownController(t *testing.T) {
|
||||
test.run(t, router, recorder)
|
||||
})
|
||||
}
|
||||
|
||||
t.Cleanup(func() {
|
||||
app.GetDB().Close()
|
||||
})
|
||||
}
|
||||
|
||||
@@ -12,10 +12,10 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tinyauthapp/tinyauth/internal/bootstrap"
|
||||
"github.com/tinyauthapp/tinyauth/internal/middleware"
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
|
||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||
"github.com/tinyauthapp/tinyauth/internal/test"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||
@@ -31,7 +31,7 @@ func TestContextMiddleware(t *testing.T) {
|
||||
return "Basic " + base64.StdEncoding.EncodeToString([]byte(username+":"+password))
|
||||
}
|
||||
|
||||
seedSession := func(t *testing.T, queries *repository.Queries, params repository.CreateSessionParams) {
|
||||
seedSession := func(t *testing.T, queries repository.Store, params repository.CreateSessionParams) {
|
||||
t.Helper()
|
||||
_, err := queries.CreateSession(context.Background(), params)
|
||||
require.NoError(t, err)
|
||||
@@ -39,7 +39,7 @@ func TestContextMiddleware(t *testing.T) {
|
||||
|
||||
type runArgs struct {
|
||||
do func(req *http.Request) (*model.UserContext, *httptest.ResponseRecorder)
|
||||
queries *repository.Queries
|
||||
queries repository.Store
|
||||
}
|
||||
|
||||
type testCase struct {
|
||||
@@ -252,15 +252,10 @@ func TestContextMiddleware(t *testing.T) {
|
||||
ctx := context.TODO()
|
||||
wg := &sync.WaitGroup{}
|
||||
|
||||
app := bootstrap.NewBootstrapApp(cfg)
|
||||
|
||||
err := app.SetupDatabase()
|
||||
require.NoError(t, err)
|
||||
|
||||
queries := repository.New(app.GetDB())
|
||||
store := memory.New()
|
||||
|
||||
broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx)
|
||||
authService := service.NewAuthService(log, cfg, runtime, ctx, wg, nil, queries, broker)
|
||||
authService := service.NewAuthService(log, cfg, runtime, ctx, wg, nil, store, broker)
|
||||
|
||||
contextMiddleware := middleware.NewContextMiddleware(log, runtime, authService, broker)
|
||||
|
||||
@@ -286,11 +281,7 @@ func TestContextMiddleware(t *testing.T) {
|
||||
return captured, recorder
|
||||
}
|
||||
|
||||
test.run(t, runArgs{do: do, queries: queries})
|
||||
test.run(t, runArgs{do: do, queries: store})
|
||||
})
|
||||
}
|
||||
|
||||
t.Cleanup(func() {
|
||||
app.GetDB().Close()
|
||||
})
|
||||
}
|
||||
|
||||
@@ -4,7 +4,8 @@ package model
|
||||
func NewDefaultConfiguration() *Config {
|
||||
return &Config{
|
||||
Database: DatabaseConfig{
|
||||
Path: "./tinyauth.db",
|
||||
Driver: "sqlite",
|
||||
Path: "./tinyauth.db",
|
||||
},
|
||||
Analytics: AnalyticsConfig{
|
||||
Enabled: true,
|
||||
@@ -24,9 +25,6 @@ func NewDefaultConfiguration() *Config {
|
||||
SessionMaxLifetime: 0, // disabled
|
||||
LoginTimeout: 300, // 5 minutes
|
||||
LoginMaxRetries: 3,
|
||||
ACLs: ACLsConfig{
|
||||
Policy: "allow",
|
||||
},
|
||||
},
|
||||
UI: UIConfig{
|
||||
Title: "Tinyauth",
|
||||
@@ -81,12 +79,13 @@ type Config struct {
|
||||
UI UIConfig `description:"UI customization." yaml:"ui"`
|
||||
LDAP LDAPConfig `description:"LDAP configuration." yaml:"ldap"`
|
||||
Experimental ExperimentalConfig `description:"Experimental features, use with caution." yaml:"experimental"`
|
||||
LabelProvider string `description:"Label provider to use for ACLs (auto, docker, kubernetes or none to disable). auto detects the environment." yaml:"labelProvider"`
|
||||
LabelProvider string `description:"Label provider to use for ACLs (auto, docker, or kubernetes). auto detects the environment." yaml:"labelProvider"`
|
||||
Log LogConfig `description:"Logging configuration." yaml:"log"`
|
||||
}
|
||||
|
||||
type DatabaseConfig struct {
|
||||
Path string `description:"The path to the database, including file name." yaml:"path"`
|
||||
Driver string `description:"The database driver to use. Valid values: sqlite, memory." yaml:"driver"`
|
||||
Path string `description:"The path to the SQLite database, including file name. Only used when driver is sqlite." yaml:"path"`
|
||||
}
|
||||
|
||||
type AnalyticsConfig struct {
|
||||
@@ -117,7 +116,6 @@ type AuthConfig struct {
|
||||
LoginTimeout int `description:"Login timeout in seconds." yaml:"loginTimeout"`
|
||||
LoginMaxRetries int `description:"Maximum login retries." yaml:"loginMaxRetries"`
|
||||
TrustedProxies []string `description:"Comma-separated list of trusted proxy addresses." yaml:"trustedProxies"`
|
||||
ACLs ACLsConfig `description:"ACLs configuration." yaml:"acls"`
|
||||
}
|
||||
|
||||
type UserAttributes struct {
|
||||
@@ -227,10 +225,6 @@ type OIDCClientConfig struct {
|
||||
Name string `description:"Client name in UI." yaml:"name"`
|
||||
}
|
||||
|
||||
type ACLsConfig struct {
|
||||
Policy string `description:"ACL policy for allow-by-default or deny-by-default, available options are allow and deny, default is allow." yaml:"policy"`
|
||||
}
|
||||
|
||||
// ACLs
|
||||
|
||||
type Apps struct {
|
||||
|
||||
@@ -0,0 +1,472 @@
|
||||
package memory_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
|
||||
)
|
||||
|
||||
var ctx = context.Background()
|
||||
|
||||
func TestMemoryStore(t *testing.T) {
|
||||
type testCase struct {
|
||||
description string
|
||||
run func(t *testing.T, s repository.Store)
|
||||
}
|
||||
|
||||
tests := []testCase{
|
||||
{
|
||||
description: "Create and get session",
|
||||
run: func(t *testing.T, s repository.Store) {
|
||||
sess, err := s.CreateSession(ctx, repository.CreateSessionParams{
|
||||
UUID: "uuid-1",
|
||||
Username: "alice",
|
||||
Expiry: 9999,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "uuid-1", sess.UUID)
|
||||
assert.Equal(t, "alice", sess.Username)
|
||||
|
||||
got, err := s.GetSession(ctx, "uuid-1")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, sess, got)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Get session not found",
|
||||
run: func(t *testing.T, s repository.Store) {
|
||||
_, err := s.GetSession(ctx, "missing")
|
||||
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Update session",
|
||||
run: func(t *testing.T, s repository.Store) {
|
||||
_, err := s.CreateSession(ctx, repository.CreateSessionParams{UUID: "uuid-1", Username: "alice"})
|
||||
require.NoError(t, err)
|
||||
|
||||
updated, err := s.UpdateSession(ctx, repository.UpdateSessionParams{
|
||||
UUID: "uuid-1",
|
||||
Username: "bob",
|
||||
Email: "bob@example.com",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "bob", updated.Username)
|
||||
assert.Equal(t, "bob@example.com", updated.Email)
|
||||
|
||||
got, err := s.GetSession(ctx, "uuid-1")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, updated, got)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Update session not found",
|
||||
run: func(t *testing.T, s repository.Store) {
|
||||
_, err := s.UpdateSession(ctx, repository.UpdateSessionParams{UUID: "missing"})
|
||||
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Delete session",
|
||||
run: func(t *testing.T, s repository.Store) {
|
||||
_, err := s.CreateSession(ctx, repository.CreateSessionParams{UUID: "uuid-1"})
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, s.DeleteSession(ctx, "uuid-1"))
|
||||
|
||||
_, err = s.GetSession(ctx, "uuid-1")
|
||||
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Delete expired sessions",
|
||||
run: func(t *testing.T, s repository.Store) {
|
||||
_, err := s.CreateSession(ctx, repository.CreateSessionParams{UUID: "expired", Expiry: 10})
|
||||
require.NoError(t, err)
|
||||
_, err = s.CreateSession(ctx, repository.CreateSessionParams{UUID: "valid", Expiry: 100})
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, s.DeleteExpiredSessions(ctx, 50))
|
||||
|
||||
_, err = s.GetSession(ctx, "expired")
|
||||
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||
|
||||
_, err = s.GetSession(ctx, "valid")
|
||||
assert.NoError(t, err)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Create and get OIDC code",
|
||||
run: func(t *testing.T, s repository.Store) {
|
||||
code, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{
|
||||
Sub: "sub-1",
|
||||
CodeHash: "hash-1",
|
||||
Scope: "openid",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "sub-1", code.Sub)
|
||||
|
||||
// destructive read removes the record
|
||||
got, err := s.GetOidcCode(ctx, "hash-1")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, code, got)
|
||||
|
||||
_, err = s.GetOidcCode(ctx, "hash-1")
|
||||
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Get OIDC code not found",
|
||||
run: func(t *testing.T, s repository.Store) {
|
||||
_, err := s.GetOidcCode(ctx, "missing")
|
||||
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Get OIDC code by sub",
|
||||
run: func(t *testing.T, s repository.Store) {
|
||||
_, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"})
|
||||
require.NoError(t, err)
|
||||
|
||||
got, err := s.GetOidcCodeBySub(ctx, "sub-1")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "sub-1", got.Sub)
|
||||
|
||||
// destructive — gone after read
|
||||
_, err = s.GetOidcCodeBySub(ctx, "sub-1")
|
||||
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Get OIDC code by sub not found",
|
||||
run: func(t *testing.T, s repository.Store) {
|
||||
_, err := s.GetOidcCodeBySub(ctx, "missing")
|
||||
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Get OIDC code unsafe",
|
||||
run: func(t *testing.T, s repository.Store) {
|
||||
_, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"})
|
||||
require.NoError(t, err)
|
||||
|
||||
got, err := s.GetOidcCodeUnsafe(ctx, "hash-1")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "sub-1", got.Sub)
|
||||
|
||||
// non-destructive — still present
|
||||
_, err = s.GetOidcCodeUnsafe(ctx, "hash-1")
|
||||
assert.NoError(t, err)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Get OIDC code unsafe not found",
|
||||
run: func(t *testing.T, s repository.Store) {
|
||||
_, err := s.GetOidcCodeUnsafe(ctx, "missing")
|
||||
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Get OIDC code by sub unsafe",
|
||||
run: func(t *testing.T, s repository.Store) {
|
||||
_, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"})
|
||||
require.NoError(t, err)
|
||||
|
||||
got, err := s.GetOidcCodeBySubUnsafe(ctx, "sub-1")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "hash-1", got.CodeHash)
|
||||
|
||||
// non-destructive — still present
|
||||
_, err = s.GetOidcCodeBySubUnsafe(ctx, "sub-1")
|
||||
assert.NoError(t, err)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Get OIDC code by sub unsafe not found",
|
||||
run: func(t *testing.T, s repository.Store) {
|
||||
_, err := s.GetOidcCodeBySubUnsafe(ctx, "missing")
|
||||
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Create OIDC code unique sub constraint",
|
||||
run: func(t *testing.T, s repository.Store) {
|
||||
_, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-2"})
|
||||
assert.ErrorContains(t, err, "UNIQUE constraint failed: oidc_codes.sub")
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Delete OIDC code",
|
||||
run: func(t *testing.T, s repository.Store) {
|
||||
_, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"})
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, s.DeleteOidcCode(ctx, "hash-1"))
|
||||
|
||||
_, err = s.GetOidcCodeUnsafe(ctx, "hash-1")
|
||||
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Delete OIDC code by sub",
|
||||
run: func(t *testing.T, s repository.Store) {
|
||||
_, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"})
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, s.DeleteOidcCodeBySub(ctx, "sub-1"))
|
||||
|
||||
_, err = s.GetOidcCodeUnsafe(ctx, "hash-1")
|
||||
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Delete expired OIDC codes",
|
||||
run: func(t *testing.T, s repository.Store) {
|
||||
_, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1", ExpiresAt: 10})
|
||||
require.NoError(t, err)
|
||||
_, err = s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-2", CodeHash: "hash-2", ExpiresAt: 100})
|
||||
require.NoError(t, err)
|
||||
|
||||
deleted, err := s.DeleteExpiredOidcCodes(ctx, 50)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, deleted, 1)
|
||||
assert.Equal(t, "hash-1", deleted[0].CodeHash)
|
||||
|
||||
_, err = s.GetOidcCodeUnsafe(ctx, "hash-2")
|
||||
assert.NoError(t, err)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Create and get OIDC token",
|
||||
run: func(t *testing.T, s repository.Store) {
|
||||
tok, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{
|
||||
Sub: "sub-1",
|
||||
AccessTokenHash: "at-hash-1",
|
||||
CodeHash: "code-hash-1",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "sub-1", tok.Sub)
|
||||
|
||||
got, err := s.GetOidcToken(ctx, "at-hash-1")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tok, got)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Get OIDC token not found",
|
||||
run: func(t *testing.T, s repository.Store) {
|
||||
_, err := s.GetOidcToken(ctx, "missing")
|
||||
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Create OIDC token unique sub constraint",
|
||||
run: func(t *testing.T, s repository.Store) {
|
||||
_, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{Sub: "sub-1", AccessTokenHash: "at-1"})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{Sub: "sub-1", AccessTokenHash: "at-2"})
|
||||
assert.ErrorContains(t, err, "UNIQUE constraint failed: oidc_tokens.sub")
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Get OIDC token by refresh token",
|
||||
run: func(t *testing.T, s repository.Store) {
|
||||
_, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{
|
||||
Sub: "sub-1",
|
||||
AccessTokenHash: "at-1",
|
||||
RefreshTokenHash: "rt-1",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
got, err := s.GetOidcTokenByRefreshToken(ctx, "rt-1")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "sub-1", got.Sub)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Get OIDC token by refresh token not found",
|
||||
run: func(t *testing.T, s repository.Store) {
|
||||
_, err := s.GetOidcTokenByRefreshToken(ctx, "missing")
|
||||
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Get OIDC token by sub",
|
||||
run: func(t *testing.T, s repository.Store) {
|
||||
_, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{
|
||||
Sub: "sub-1",
|
||||
AccessTokenHash: "at-1",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
got, err := s.GetOidcTokenBySub(ctx, "sub-1")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "at-1", got.AccessTokenHash)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Get OIDC token by sub not found",
|
||||
run: func(t *testing.T, s repository.Store) {
|
||||
_, err := s.GetOidcTokenBySub(ctx, "missing")
|
||||
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Update OIDC token by refresh token",
|
||||
run: func(t *testing.T, s repository.Store) {
|
||||
_, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{
|
||||
Sub: "sub-1",
|
||||
AccessTokenHash: "at-1",
|
||||
RefreshTokenHash: "rt-1",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
updated, err := s.UpdateOidcTokenByRefreshToken(ctx, repository.UpdateOidcTokenByRefreshTokenParams{
|
||||
RefreshTokenHash_2: "rt-1",
|
||||
AccessTokenHash: "at-2",
|
||||
RefreshTokenHash: "rt-2",
|
||||
TokenExpiresAt: 200,
|
||||
RefreshTokenExpiresAt: 400,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "at-2", updated.AccessTokenHash)
|
||||
assert.Equal(t, "rt-2", updated.RefreshTokenHash)
|
||||
|
||||
// old key gone, new key present
|
||||
_, err = s.GetOidcToken(ctx, "at-1")
|
||||
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||
|
||||
got, err := s.GetOidcToken(ctx, "at-2")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "sub-1", got.Sub)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Update OIDC token by refresh token not found",
|
||||
run: func(t *testing.T, s repository.Store) {
|
||||
_, err := s.UpdateOidcTokenByRefreshToken(ctx, repository.UpdateOidcTokenByRefreshTokenParams{
|
||||
RefreshTokenHash_2: "missing",
|
||||
})
|
||||
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Delete OIDC token",
|
||||
run: func(t *testing.T, s repository.Store) {
|
||||
_, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{Sub: "sub-1", AccessTokenHash: "at-1"})
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, s.DeleteOidcToken(ctx, "at-1"))
|
||||
|
||||
_, err = s.GetOidcToken(ctx, "at-1")
|
||||
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Delete OIDC token by sub",
|
||||
run: func(t *testing.T, s repository.Store) {
|
||||
_, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{Sub: "sub-1", AccessTokenHash: "at-1"})
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, s.DeleteOidcTokenBySub(ctx, "sub-1"))
|
||||
|
||||
_, err = s.GetOidcToken(ctx, "at-1")
|
||||
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Delete OIDC token by code hash",
|
||||
run: func(t *testing.T, s repository.Store) {
|
||||
_, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{
|
||||
Sub: "sub-1",
|
||||
AccessTokenHash: "at-1",
|
||||
CodeHash: "code-1",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, s.DeleteOidcTokenByCodeHash(ctx, "code-1"))
|
||||
|
||||
_, err = s.GetOidcToken(ctx, "at-1")
|
||||
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Delete expired OIDC tokens",
|
||||
run: func(t *testing.T, s repository.Store) {
|
||||
// both expiries past
|
||||
_, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{
|
||||
Sub: "sub-1", AccessTokenHash: "at-1",
|
||||
TokenExpiresAt: 10, RefreshTokenExpiresAt: 10,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
// valid
|
||||
_, err = s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{
|
||||
Sub: "sub-3", AccessTokenHash: "at-3",
|
||||
TokenExpiresAt: 100, RefreshTokenExpiresAt: 100,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
deleted, err := s.DeleteExpiredOidcTokens(ctx, repository.DeleteExpiredOidcTokensParams{
|
||||
TokenExpiresAt: 50,
|
||||
RefreshTokenExpiresAt: 50,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, deleted, 1)
|
||||
|
||||
_, err = s.GetOidcToken(ctx, "at-3")
|
||||
assert.NoError(t, err)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Create and get OIDC user info",
|
||||
run: func(t *testing.T, s repository.Store) {
|
||||
u, err := s.CreateOidcUserInfo(ctx, repository.CreateOidcUserInfoParams{
|
||||
Sub: "sub-1",
|
||||
Name: "Alice",
|
||||
Email: "alice@example.com",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "sub-1", u.Sub)
|
||||
|
||||
got, err := s.GetOidcUserInfo(ctx, "sub-1")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, u, got)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Get OIDC user info not found",
|
||||
run: func(t *testing.T, s repository.Store) {
|
||||
_, err := s.GetOidcUserInfo(ctx, "missing")
|
||||
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Delete OIDC user info",
|
||||
run: func(t *testing.T, s repository.Store) {
|
||||
_, err := s.CreateOidcUserInfo(ctx, repository.CreateOidcUserInfoParams{Sub: "sub-1"})
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, s.DeleteOidcUserInfo(ctx, "sub-1"))
|
||||
|
||||
_, err = s.GetOidcUserInfo(ctx, "sub-1")
|
||||
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.description, func(t *testing.T) {
|
||||
s := memory.New()
|
||||
test.run(t, s)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,241 @@
|
||||
package memory
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||
)
|
||||
|
||||
func (s *Store) CreateOidcCode(_ context.Context, arg repository.CreateOidcCodeParams) (repository.OidcCode, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
// Enforce sub UNIQUE constraint
|
||||
for _, c := range s.oidcCodes {
|
||||
if c.Sub == arg.Sub {
|
||||
return repository.OidcCode{}, fmt.Errorf("UNIQUE constraint failed: oidc_codes.sub")
|
||||
}
|
||||
}
|
||||
code := repository.OidcCode(arg)
|
||||
s.oidcCodes[arg.CodeHash] = code
|
||||
return code, nil
|
||||
}
|
||||
|
||||
// GetOidcCode is a destructive read: it deletes and returns the code (mirrors SQLite's DELETE...RETURNING).
|
||||
func (s *Store) GetOidcCode(_ context.Context, codeHash string) (repository.OidcCode, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
c, ok := s.oidcCodes[codeHash]
|
||||
if !ok {
|
||||
return repository.OidcCode{}, repository.ErrNotFound
|
||||
}
|
||||
delete(s.oidcCodes, codeHash)
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// GetOidcCodeBySub is a destructive read: it deletes and returns the code (mirrors SQLite's DELETE...RETURNING).
|
||||
func (s *Store) GetOidcCodeBySub(_ context.Context, sub string) (repository.OidcCode, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
for k, c := range s.oidcCodes {
|
||||
if c.Sub == sub {
|
||||
delete(s.oidcCodes, k)
|
||||
return c, nil
|
||||
}
|
||||
}
|
||||
return repository.OidcCode{}, repository.ErrNotFound
|
||||
}
|
||||
|
||||
// GetOidcCodeUnsafe is a non-destructive read (mirrors SQLite's SELECT).
|
||||
func (s *Store) GetOidcCodeUnsafe(_ context.Context, codeHash string) (repository.OidcCode, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
c, ok := s.oidcCodes[codeHash]
|
||||
if !ok {
|
||||
return repository.OidcCode{}, repository.ErrNotFound
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// GetOidcCodeBySubUnsafe is a non-destructive read (mirrors SQLite's SELECT).
|
||||
func (s *Store) GetOidcCodeBySubUnsafe(_ context.Context, sub string) (repository.OidcCode, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
for _, c := range s.oidcCodes {
|
||||
if c.Sub == sub {
|
||||
return c, nil
|
||||
}
|
||||
}
|
||||
return repository.OidcCode{}, repository.ErrNotFound
|
||||
}
|
||||
|
||||
func (s *Store) DeleteOidcCode(_ context.Context, codeHash string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
delete(s.oidcCodes, codeHash)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Store) DeleteOidcCodeBySub(_ context.Context, sub string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
for k, c := range s.oidcCodes {
|
||||
if c.Sub == sub {
|
||||
delete(s.oidcCodes, k)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Store) DeleteExpiredOidcCodes(_ context.Context, expiresAt int64) ([]repository.OidcCode, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
var deleted []repository.OidcCode
|
||||
for k, c := range s.oidcCodes {
|
||||
if c.ExpiresAt < expiresAt {
|
||||
deleted = append(deleted, c)
|
||||
delete(s.oidcCodes, k)
|
||||
}
|
||||
}
|
||||
return deleted, nil
|
||||
}
|
||||
|
||||
func (s *Store) CreateOidcToken(_ context.Context, arg repository.CreateOidcTokenParams) (repository.OidcToken, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
// Enforce sub UNIQUE constraint
|
||||
for _, t := range s.oidcTokens {
|
||||
if t.Sub == arg.Sub {
|
||||
return repository.OidcToken{}, fmt.Errorf("UNIQUE constraint failed: oidc_tokens.sub")
|
||||
}
|
||||
}
|
||||
tok := repository.OidcToken{
|
||||
Sub: arg.Sub,
|
||||
AccessTokenHash: arg.AccessTokenHash,
|
||||
RefreshTokenHash: arg.RefreshTokenHash,
|
||||
CodeHash: arg.CodeHash,
|
||||
Scope: arg.Scope,
|
||||
ClientID: arg.ClientID,
|
||||
TokenExpiresAt: arg.TokenExpiresAt,
|
||||
RefreshTokenExpiresAt: arg.RefreshTokenExpiresAt,
|
||||
Nonce: arg.Nonce,
|
||||
}
|
||||
s.oidcTokens[arg.AccessTokenHash] = tok
|
||||
return tok, nil
|
||||
}
|
||||
|
||||
func (s *Store) GetOidcToken(_ context.Context, accessTokenHash string) (repository.OidcToken, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
t, ok := s.oidcTokens[accessTokenHash]
|
||||
if !ok {
|
||||
return repository.OidcToken{}, repository.ErrNotFound
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func (s *Store) GetOidcTokenByRefreshToken(_ context.Context, refreshTokenHash string) (repository.OidcToken, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
for _, t := range s.oidcTokens {
|
||||
if t.RefreshTokenHash == refreshTokenHash {
|
||||
return t, nil
|
||||
}
|
||||
}
|
||||
return repository.OidcToken{}, repository.ErrNotFound
|
||||
}
|
||||
|
||||
func (s *Store) GetOidcTokenBySub(_ context.Context, sub string) (repository.OidcToken, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
for _, t := range s.oidcTokens {
|
||||
if t.Sub == sub {
|
||||
return t, nil
|
||||
}
|
||||
}
|
||||
return repository.OidcToken{}, repository.ErrNotFound
|
||||
}
|
||||
|
||||
func (s *Store) UpdateOidcTokenByRefreshToken(_ context.Context, arg repository.UpdateOidcTokenByRefreshTokenParams) (repository.OidcToken, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
for k, t := range s.oidcTokens {
|
||||
if t.RefreshTokenHash == arg.RefreshTokenHash_2 {
|
||||
delete(s.oidcTokens, k)
|
||||
t.AccessTokenHash = arg.AccessTokenHash
|
||||
t.RefreshTokenHash = arg.RefreshTokenHash
|
||||
t.TokenExpiresAt = arg.TokenExpiresAt
|
||||
t.RefreshTokenExpiresAt = arg.RefreshTokenExpiresAt
|
||||
s.oidcTokens[arg.AccessTokenHash] = t
|
||||
return t, nil
|
||||
}
|
||||
}
|
||||
return repository.OidcToken{}, repository.ErrNotFound
|
||||
}
|
||||
|
||||
func (s *Store) DeleteOidcToken(_ context.Context, accessTokenHash string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
delete(s.oidcTokens, accessTokenHash)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Store) DeleteOidcTokenBySub(_ context.Context, sub string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
for k, t := range s.oidcTokens {
|
||||
if t.Sub == sub {
|
||||
delete(s.oidcTokens, k)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Store) DeleteOidcTokenByCodeHash(_ context.Context, codeHash string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
for k, t := range s.oidcTokens {
|
||||
if t.CodeHash == codeHash {
|
||||
delete(s.oidcTokens, k)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Store) DeleteExpiredOidcTokens(_ context.Context, arg repository.DeleteExpiredOidcTokensParams) ([]repository.OidcToken, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
var deleted []repository.OidcToken
|
||||
for k, t := range s.oidcTokens {
|
||||
if t.TokenExpiresAt < arg.TokenExpiresAt && t.RefreshTokenExpiresAt < arg.RefreshTokenExpiresAt {
|
||||
deleted = append(deleted, t)
|
||||
delete(s.oidcTokens, k)
|
||||
}
|
||||
}
|
||||
return deleted, nil
|
||||
}
|
||||
|
||||
func (s *Store) CreateOidcUserInfo(_ context.Context, arg repository.CreateOidcUserInfoParams) (repository.OidcUserinfo, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
u := repository.OidcUserinfo(arg)
|
||||
s.oidcUsers[arg.Sub] = u
|
||||
return u, nil
|
||||
}
|
||||
|
||||
func (s *Store) GetOidcUserInfo(_ context.Context, sub string) (repository.OidcUserinfo, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
u, ok := s.oidcUsers[sub]
|
||||
if !ok {
|
||||
return repository.OidcUserinfo{}, repository.ErrNotFound
|
||||
}
|
||||
return u, nil
|
||||
}
|
||||
|
||||
func (s *Store) DeleteOidcUserInfo(_ context.Context, sub string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
delete(s.oidcUsers, sub)
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,63 @@
|
||||
package memory
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||
)
|
||||
|
||||
func (s *Store) CreateSession(_ context.Context, arg repository.CreateSessionParams) (repository.Session, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
sess := repository.Session(arg)
|
||||
s.sessions[arg.UUID] = sess
|
||||
return sess, nil
|
||||
}
|
||||
|
||||
func (s *Store) GetSession(_ context.Context, uuid string) (repository.Session, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
sess, ok := s.sessions[uuid]
|
||||
if !ok {
|
||||
return repository.Session{}, repository.ErrNotFound
|
||||
}
|
||||
return sess, nil
|
||||
}
|
||||
|
||||
func (s *Store) UpdateSession(_ context.Context, arg repository.UpdateSessionParams) (repository.Session, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
sess, ok := s.sessions[arg.UUID]
|
||||
if !ok {
|
||||
return repository.Session{}, repository.ErrNotFound
|
||||
}
|
||||
sess.Username = arg.Username
|
||||
sess.Email = arg.Email
|
||||
sess.Name = arg.Name
|
||||
sess.Provider = arg.Provider
|
||||
sess.TotpPending = arg.TotpPending
|
||||
sess.OAuthGroups = arg.OAuthGroups
|
||||
sess.Expiry = arg.Expiry
|
||||
sess.OAuthName = arg.OAuthName
|
||||
sess.OAuthSub = arg.OAuthSub
|
||||
s.sessions[arg.UUID] = sess
|
||||
return sess, nil
|
||||
}
|
||||
|
||||
func (s *Store) DeleteSession(_ context.Context, uuid string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
delete(s.sessions, uuid)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Store) DeleteExpiredSessions(_ context.Context, expiry int64) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
for k, v := range s.sessions {
|
||||
if v.Expiry < expiry {
|
||||
delete(s.sessions, k)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,27 @@
|
||||
// Package memory provides an in-memory implementation of repository.Store for use in tests.
|
||||
package memory
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||
)
|
||||
|
||||
// Store is a thread-safe in-memory implementation of repository.Store.
|
||||
type Store struct {
|
||||
mu sync.RWMutex
|
||||
sessions map[string]repository.Session
|
||||
oidcCodes map[string]repository.OidcCode
|
||||
oidcTokens map[string]repository.OidcToken
|
||||
oidcUsers map[string]repository.OidcUserinfo
|
||||
}
|
||||
|
||||
// New returns a new empty in-memory Store.
|
||||
func New() repository.Store {
|
||||
return &Store{
|
||||
sessions: make(map[string]repository.Session),
|
||||
oidcCodes: make(map[string]repository.OidcCode),
|
||||
oidcTokens: make(map[string]repository.OidcToken),
|
||||
oidcUsers: make(map[string]repository.OidcUserinfo),
|
||||
}
|
||||
}
|
||||
@@ -1,9 +1,22 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.30.0
|
||||
|
||||
package repository
|
||||
|
||||
// Shared model and parameter types for all storage drivers.
|
||||
// sqlc-generated driver packages use these via the conversion layer in their store.go.
|
||||
|
||||
type Session struct {
|
||||
UUID string
|
||||
Username string
|
||||
Email string
|
||||
Name string
|
||||
Provider string
|
||||
TotpPending bool
|
||||
OAuthGroups string
|
||||
Expiry int64
|
||||
CreatedAt int64
|
||||
OAuthName string
|
||||
OAuthSub string
|
||||
}
|
||||
|
||||
type OidcCode struct {
|
||||
Sub string
|
||||
CodeHash string
|
||||
@@ -49,7 +62,7 @@ type OidcUserinfo struct {
|
||||
Address string
|
||||
}
|
||||
|
||||
type Session struct {
|
||||
type CreateSessionParams struct {
|
||||
UUID string
|
||||
Username string
|
||||
Email string
|
||||
@@ -62,3 +75,74 @@ type Session struct {
|
||||
OAuthName string
|
||||
OAuthSub string
|
||||
}
|
||||
|
||||
type UpdateSessionParams struct {
|
||||
Username string
|
||||
Email string
|
||||
Name string
|
||||
Provider string
|
||||
TotpPending bool
|
||||
OAuthGroups string
|
||||
Expiry int64
|
||||
OAuthName string
|
||||
OAuthSub string
|
||||
UUID string
|
||||
}
|
||||
|
||||
type CreateOidcCodeParams struct {
|
||||
Sub string
|
||||
CodeHash string
|
||||
Scope string
|
||||
RedirectURI string
|
||||
ClientID string
|
||||
ExpiresAt int64
|
||||
Nonce string
|
||||
CodeChallenge string
|
||||
}
|
||||
|
||||
type CreateOidcTokenParams struct {
|
||||
Sub string
|
||||
AccessTokenHash string
|
||||
RefreshTokenHash string
|
||||
Scope string
|
||||
ClientID string
|
||||
TokenExpiresAt int64
|
||||
RefreshTokenExpiresAt int64
|
||||
CodeHash string
|
||||
Nonce string
|
||||
}
|
||||
|
||||
type UpdateOidcTokenByRefreshTokenParams struct {
|
||||
AccessTokenHash string
|
||||
RefreshTokenHash string
|
||||
TokenExpiresAt int64
|
||||
RefreshTokenExpiresAt int64
|
||||
RefreshTokenHash_2 string
|
||||
}
|
||||
|
||||
type DeleteExpiredOidcTokensParams struct {
|
||||
TokenExpiresAt int64
|
||||
RefreshTokenExpiresAt int64
|
||||
}
|
||||
|
||||
type CreateOidcUserInfoParams struct {
|
||||
Sub string
|
||||
Name string
|
||||
PreferredUsername string
|
||||
Email string
|
||||
Groups string
|
||||
UpdatedAt int64
|
||||
GivenName string
|
||||
FamilyName string
|
||||
MiddleName string
|
||||
Nickname string
|
||||
Profile string
|
||||
Picture string
|
||||
Website string
|
||||
Gender string
|
||||
Birthdate string
|
||||
Zoneinfo string
|
||||
Locale string
|
||||
PhoneNumber string
|
||||
Address string
|
||||
}
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.30.0
|
||||
// sqlc v1.31.1
|
||||
|
||||
package repository
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -0,0 +1,3 @@
|
||||
package sqlite
|
||||
|
||||
//go:generate go run github.com/tinyauthapp/tinyauth/gen/sqlc-wrapper -pkg github.com/tinyauthapp/tinyauth/internal/repository/sqlite
|
||||
@@ -0,0 +1,64 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.31.1
|
||||
|
||||
package sqlite
|
||||
|
||||
type OidcCode struct {
|
||||
Sub string
|
||||
CodeHash string
|
||||
Scope string
|
||||
RedirectURI string
|
||||
ClientID string
|
||||
ExpiresAt int64
|
||||
Nonce string
|
||||
CodeChallenge string
|
||||
}
|
||||
|
||||
type OidcToken struct {
|
||||
Sub string
|
||||
AccessTokenHash string
|
||||
RefreshTokenHash string
|
||||
CodeHash string
|
||||
Scope string
|
||||
ClientID string
|
||||
TokenExpiresAt int64
|
||||
RefreshTokenExpiresAt int64
|
||||
Nonce string
|
||||
}
|
||||
|
||||
type OidcUserinfo struct {
|
||||
Sub string
|
||||
Name string
|
||||
PreferredUsername string
|
||||
Email string
|
||||
Groups string
|
||||
UpdatedAt int64
|
||||
GivenName string
|
||||
FamilyName string
|
||||
MiddleName string
|
||||
Nickname string
|
||||
Profile string
|
||||
Picture string
|
||||
Website string
|
||||
Gender string
|
||||
Birthdate string
|
||||
Zoneinfo string
|
||||
Locale string
|
||||
PhoneNumber string
|
||||
Address string
|
||||
}
|
||||
|
||||
type Session struct {
|
||||
UUID string
|
||||
Username string
|
||||
Email string
|
||||
Name string
|
||||
Provider string
|
||||
TotpPending bool
|
||||
OAuthGroups string
|
||||
Expiry int64
|
||||
CreatedAt int64
|
||||
OAuthName string
|
||||
OAuthSub string
|
||||
}
|
||||
+2
-2
@@ -1,9 +1,9 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.30.0
|
||||
// sqlc v1.31.1
|
||||
// source: oidc_queries.sql
|
||||
|
||||
package repository
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
+2
-2
@@ -1,9 +1,9 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.30.0
|
||||
// sqlc v1.31.1
|
||||
// source: session_queries.sql
|
||||
|
||||
package repository
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -0,0 +1,209 @@
|
||||
// Code generated by cmd/gen/sqlc-wrapper. DO NOT EDIT.
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
|
||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||
)
|
||||
|
||||
// Store wraps *Queries and implements repository.Store.
|
||||
type Store struct {
|
||||
q *Queries
|
||||
}
|
||||
|
||||
// NewStore wraps a *Queries to satisfy repository.Store.
|
||||
func NewStore(q *Queries) repository.Store {
|
||||
return &Store{q: q}
|
||||
}
|
||||
|
||||
var errorMap = map[error]error{
|
||||
sql.ErrNoRows: repository.ErrNotFound,
|
||||
}
|
||||
|
||||
func mapErr(err error) error {
|
||||
for from, to := range errorMap {
|
||||
if errors.Is(err, from) {
|
||||
return to
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Store) CreateOidcCode(ctx context.Context, arg repository.CreateOidcCodeParams) (repository.OidcCode, error) {
|
||||
r, err := s.q.CreateOidcCode(ctx, CreateOidcCodeParams(arg))
|
||||
if err != nil {
|
||||
return repository.OidcCode{}, mapErr(err)
|
||||
}
|
||||
return repository.OidcCode(r), nil
|
||||
}
|
||||
|
||||
func (s *Store) CreateOidcToken(ctx context.Context, arg repository.CreateOidcTokenParams) (repository.OidcToken, error) {
|
||||
r, err := s.q.CreateOidcToken(ctx, CreateOidcTokenParams(arg))
|
||||
if err != nil {
|
||||
return repository.OidcToken{}, mapErr(err)
|
||||
}
|
||||
return repository.OidcToken(r), nil
|
||||
}
|
||||
|
||||
func (s *Store) CreateOidcUserInfo(ctx context.Context, arg repository.CreateOidcUserInfoParams) (repository.OidcUserinfo, error) {
|
||||
r, err := s.q.CreateOidcUserInfo(ctx, CreateOidcUserInfoParams(arg))
|
||||
if err != nil {
|
||||
return repository.OidcUserinfo{}, mapErr(err)
|
||||
}
|
||||
return repository.OidcUserinfo(r), nil
|
||||
}
|
||||
|
||||
func (s *Store) CreateSession(ctx context.Context, arg repository.CreateSessionParams) (repository.Session, error) {
|
||||
r, err := s.q.CreateSession(ctx, CreateSessionParams(arg))
|
||||
if err != nil {
|
||||
return repository.Session{}, mapErr(err)
|
||||
}
|
||||
return repository.Session(r), nil
|
||||
}
|
||||
|
||||
func (s *Store) DeleteExpiredOidcCodes(ctx context.Context, expiresAt int64) ([]repository.OidcCode, error) {
|
||||
rows, err := s.q.DeleteExpiredOidcCodes(ctx, expiresAt)
|
||||
if err != nil {
|
||||
return nil, mapErr(err)
|
||||
}
|
||||
out := make([]repository.OidcCode, len(rows))
|
||||
for i, row := range rows {
|
||||
out[i] = repository.OidcCode(row)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (s *Store) DeleteExpiredOidcTokens(ctx context.Context, arg repository.DeleteExpiredOidcTokensParams) ([]repository.OidcToken, error) {
|
||||
rows, err := s.q.DeleteExpiredOidcTokens(ctx, DeleteExpiredOidcTokensParams(arg))
|
||||
if err != nil {
|
||||
return nil, mapErr(err)
|
||||
}
|
||||
out := make([]repository.OidcToken, len(rows))
|
||||
for i, row := range rows {
|
||||
out[i] = repository.OidcToken(row)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (s *Store) DeleteExpiredSessions(ctx context.Context, expiry int64) error {
|
||||
return mapErr(s.q.DeleteExpiredSessions(ctx, expiry))
|
||||
}
|
||||
|
||||
func (s *Store) DeleteOidcCode(ctx context.Context, codeHash string) error {
|
||||
return mapErr(s.q.DeleteOidcCode(ctx, codeHash))
|
||||
}
|
||||
|
||||
func (s *Store) DeleteOidcCodeBySub(ctx context.Context, sub string) error {
|
||||
return mapErr(s.q.DeleteOidcCodeBySub(ctx, sub))
|
||||
}
|
||||
|
||||
func (s *Store) DeleteOidcToken(ctx context.Context, accessTokenHash string) error {
|
||||
return mapErr(s.q.DeleteOidcToken(ctx, accessTokenHash))
|
||||
}
|
||||
|
||||
func (s *Store) DeleteOidcTokenByCodeHash(ctx context.Context, codeHash string) error {
|
||||
return mapErr(s.q.DeleteOidcTokenByCodeHash(ctx, codeHash))
|
||||
}
|
||||
|
||||
func (s *Store) DeleteOidcTokenBySub(ctx context.Context, sub string) error {
|
||||
return mapErr(s.q.DeleteOidcTokenBySub(ctx, sub))
|
||||
}
|
||||
|
||||
func (s *Store) DeleteOidcUserInfo(ctx context.Context, sub string) error {
|
||||
return mapErr(s.q.DeleteOidcUserInfo(ctx, sub))
|
||||
}
|
||||
|
||||
func (s *Store) DeleteSession(ctx context.Context, uuid string) error {
|
||||
return mapErr(s.q.DeleteSession(ctx, uuid))
|
||||
}
|
||||
|
||||
func (s *Store) GetOidcCode(ctx context.Context, codeHash string) (repository.OidcCode, error) {
|
||||
r, err := s.q.GetOidcCode(ctx, codeHash)
|
||||
if err != nil {
|
||||
return repository.OidcCode{}, mapErr(err)
|
||||
}
|
||||
return repository.OidcCode(r), nil
|
||||
}
|
||||
|
||||
func (s *Store) GetOidcCodeBySub(ctx context.Context, sub string) (repository.OidcCode, error) {
|
||||
r, err := s.q.GetOidcCodeBySub(ctx, sub)
|
||||
if err != nil {
|
||||
return repository.OidcCode{}, mapErr(err)
|
||||
}
|
||||
return repository.OidcCode(r), nil
|
||||
}
|
||||
|
||||
func (s *Store) GetOidcCodeBySubUnsafe(ctx context.Context, sub string) (repository.OidcCode, error) {
|
||||
r, err := s.q.GetOidcCodeBySubUnsafe(ctx, sub)
|
||||
if err != nil {
|
||||
return repository.OidcCode{}, mapErr(err)
|
||||
}
|
||||
return repository.OidcCode(r), nil
|
||||
}
|
||||
|
||||
func (s *Store) GetOidcCodeUnsafe(ctx context.Context, codeHash string) (repository.OidcCode, error) {
|
||||
r, err := s.q.GetOidcCodeUnsafe(ctx, codeHash)
|
||||
if err != nil {
|
||||
return repository.OidcCode{}, mapErr(err)
|
||||
}
|
||||
return repository.OidcCode(r), nil
|
||||
}
|
||||
|
||||
func (s *Store) GetOidcToken(ctx context.Context, accessTokenHash string) (repository.OidcToken, error) {
|
||||
r, err := s.q.GetOidcToken(ctx, accessTokenHash)
|
||||
if err != nil {
|
||||
return repository.OidcToken{}, mapErr(err)
|
||||
}
|
||||
return repository.OidcToken(r), nil
|
||||
}
|
||||
|
||||
func (s *Store) GetOidcTokenByRefreshToken(ctx context.Context, refreshTokenHash string) (repository.OidcToken, error) {
|
||||
r, err := s.q.GetOidcTokenByRefreshToken(ctx, refreshTokenHash)
|
||||
if err != nil {
|
||||
return repository.OidcToken{}, mapErr(err)
|
||||
}
|
||||
return repository.OidcToken(r), nil
|
||||
}
|
||||
|
||||
func (s *Store) GetOidcTokenBySub(ctx context.Context, sub string) (repository.OidcToken, error) {
|
||||
r, err := s.q.GetOidcTokenBySub(ctx, sub)
|
||||
if err != nil {
|
||||
return repository.OidcToken{}, mapErr(err)
|
||||
}
|
||||
return repository.OidcToken(r), nil
|
||||
}
|
||||
|
||||
func (s *Store) GetOidcUserInfo(ctx context.Context, sub string) (repository.OidcUserinfo, error) {
|
||||
r, err := s.q.GetOidcUserInfo(ctx, sub)
|
||||
if err != nil {
|
||||
return repository.OidcUserinfo{}, mapErr(err)
|
||||
}
|
||||
return repository.OidcUserinfo(r), nil
|
||||
}
|
||||
|
||||
func (s *Store) GetSession(ctx context.Context, uuid string) (repository.Session, error) {
|
||||
r, err := s.q.GetSession(ctx, uuid)
|
||||
if err != nil {
|
||||
return repository.Session{}, mapErr(err)
|
||||
}
|
||||
return repository.Session(r), nil
|
||||
}
|
||||
|
||||
func (s *Store) UpdateOidcTokenByRefreshToken(ctx context.Context, arg repository.UpdateOidcTokenByRefreshTokenParams) (repository.OidcToken, error) {
|
||||
r, err := s.q.UpdateOidcTokenByRefreshToken(ctx, UpdateOidcTokenByRefreshTokenParams(arg))
|
||||
if err != nil {
|
||||
return repository.OidcToken{}, mapErr(err)
|
||||
}
|
||||
return repository.OidcToken(r), nil
|
||||
}
|
||||
|
||||
func (s *Store) UpdateSession(ctx context.Context, arg repository.UpdateSessionParams) (repository.Session, error) {
|
||||
r, err := s.q.UpdateSession(ctx, UpdateSessionParams(arg))
|
||||
if err != nil {
|
||||
return repository.Session{}, mapErr(err)
|
||||
}
|
||||
return repository.Session(r), nil
|
||||
}
|
||||
@@ -0,0 +1,47 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
)
|
||||
|
||||
// ErrNotFound is returned by Store methods when the requested record does not exist.
|
||||
var ErrNotFound = errors.New("not found")
|
||||
|
||||
// Store is the interface that all storage drivers must implement.
|
||||
// The sqlc-generated *Queries struct satisfies this interface for SQLite.
|
||||
// Future drivers (postgres, etc.) must return the shared types defined in this package.
|
||||
type Store interface {
|
||||
// Sessions
|
||||
CreateSession(ctx context.Context, arg CreateSessionParams) (Session, error)
|
||||
GetSession(ctx context.Context, uuid string) (Session, error)
|
||||
UpdateSession(ctx context.Context, arg UpdateSessionParams) (Session, error)
|
||||
DeleteSession(ctx context.Context, uuid string) error
|
||||
DeleteExpiredSessions(ctx context.Context, expiry int64) error
|
||||
|
||||
// OIDC codes
|
||||
CreateOidcCode(ctx context.Context, arg CreateOidcCodeParams) (OidcCode, error)
|
||||
GetOidcCode(ctx context.Context, codeHash string) (OidcCode, error)
|
||||
GetOidcCodeBySub(ctx context.Context, sub string) (OidcCode, error)
|
||||
GetOidcCodeUnsafe(ctx context.Context, codeHash string) (OidcCode, error)
|
||||
GetOidcCodeBySubUnsafe(ctx context.Context, sub string) (OidcCode, error)
|
||||
DeleteOidcCode(ctx context.Context, codeHash string) error
|
||||
DeleteOidcCodeBySub(ctx context.Context, sub string) error
|
||||
DeleteExpiredOidcCodes(ctx context.Context, expiresAt int64) ([]OidcCode, error)
|
||||
|
||||
// OIDC tokens
|
||||
CreateOidcToken(ctx context.Context, arg CreateOidcTokenParams) (OidcToken, error)
|
||||
GetOidcToken(ctx context.Context, accessTokenHash string) (OidcToken, error)
|
||||
GetOidcTokenByRefreshToken(ctx context.Context, refreshTokenHash string) (OidcToken, error)
|
||||
GetOidcTokenBySub(ctx context.Context, sub string) (OidcToken, error)
|
||||
UpdateOidcTokenByRefreshToken(ctx context.Context, arg UpdateOidcTokenByRefreshTokenParams) (OidcToken, error)
|
||||
DeleteOidcToken(ctx context.Context, accessTokenHash string) error
|
||||
DeleteOidcTokenBySub(ctx context.Context, sub string) error
|
||||
DeleteOidcTokenByCodeHash(ctx context.Context, codeHash string) error
|
||||
DeleteExpiredOidcTokens(ctx context.Context, arg DeleteExpiredOidcTokensParams) ([]OidcToken, error)
|
||||
|
||||
// OIDC userinfo
|
||||
CreateOidcUserInfo(ctx context.Context, arg CreateOidcUserInfoParams) (OidcUserinfo, error)
|
||||
GetOidcUserInfo(ctx context.Context, sub string) (OidcUserinfo, error)
|
||||
DeleteOidcUserInfo(ctx context.Context, sub string) error
|
||||
}
|
||||
@@ -1,249 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||
)
|
||||
|
||||
type RuleName string
|
||||
|
||||
const (
|
||||
RuleUserAllowed RuleName = "rule-user-allowed"
|
||||
RuleOAuthGroup RuleName = "rule-oauth-group"
|
||||
RuleLDAPGroup RuleName = "rule-ldap-group"
|
||||
RuleAuthEnabled RuleName = "rule-auth-enabled"
|
||||
RuleIPAllowed RuleName = "rule-ip-allowed"
|
||||
RuleIPBypassed RuleName = "rule-ip-bypassed"
|
||||
)
|
||||
|
||||
type UserAllowedRule struct {
|
||||
Log *logger.Logger
|
||||
}
|
||||
|
||||
func (rule *UserAllowedRule) Evaluate(ctx *ACLContext) Effect {
|
||||
if ctx.ACLs == nil || ctx.UserContext == nil {
|
||||
return EffectAbstain
|
||||
}
|
||||
|
||||
if ctx.UserContext.Provider == model.ProviderOAuth {
|
||||
rule.Log.App.Debug().Msg("User is an OAuth user, checking OAuth whitelist")
|
||||
match, err := utils.CheckFilter(ctx.ACLs.OAuth.Whitelist, ctx.UserContext.OAuth.Email)
|
||||
if err != nil {
|
||||
rule.Log.App.Warn().Err(err).Str("item", ctx.UserContext.OAuth.Email).Msg("Invalid entry in OAuth whitelist")
|
||||
return EffectAbstain
|
||||
}
|
||||
if match {
|
||||
rule.Log.App.Debug().Str("email", ctx.UserContext.OAuth.Email).Msg("User is in OAuth whitelist, allowing access")
|
||||
return EffectAllow
|
||||
}
|
||||
return EffectDeny
|
||||
}
|
||||
|
||||
if ctx.ACLs.Users.Block != "" {
|
||||
rule.Log.App.Debug().Msg("Checking users block list")
|
||||
match, err := utils.CheckFilter(ctx.ACLs.Users.Block, ctx.UserContext.GetUsername())
|
||||
if err != nil {
|
||||
rule.Log.App.Warn().Err(err).Str("item", ctx.UserContext.GetUsername()).Msg("Invalid entry in users block list")
|
||||
return EffectAbstain
|
||||
}
|
||||
if match {
|
||||
rule.Log.App.Debug().Str("username", ctx.UserContext.GetUsername()).Msg("User is in users block list, denying access")
|
||||
return EffectDeny
|
||||
}
|
||||
return EffectAllow
|
||||
}
|
||||
|
||||
rule.Log.App.Debug().Msg("Checking users allow list")
|
||||
|
||||
match, err := utils.CheckFilter(ctx.ACLs.Users.Allow, ctx.UserContext.GetUsername())
|
||||
|
||||
if err != nil {
|
||||
rule.Log.App.Warn().Err(err).Str("item", ctx.UserContext.GetUsername()).Msg("Invalid entry in users allow list")
|
||||
return EffectAbstain
|
||||
}
|
||||
|
||||
if match {
|
||||
rule.Log.App.Debug().Str("username", ctx.UserContext.GetUsername()).Msg("User is in users allow list, allowing access")
|
||||
return EffectAllow
|
||||
}
|
||||
|
||||
rule.Log.App.Debug().Str("username", ctx.UserContext.GetUsername()).Msg("User is not in users allow list, denying access")
|
||||
return EffectDeny
|
||||
}
|
||||
|
||||
type OAuthGroupRule struct {
|
||||
Log *logger.Logger
|
||||
}
|
||||
|
||||
func (rule *OAuthGroupRule) Evaluate(ctx *ACLContext) Effect {
|
||||
if ctx.ACLs == nil || ctx.UserContext == nil {
|
||||
return EffectAbstain
|
||||
}
|
||||
|
||||
if !ctx.UserContext.IsOAuth() {
|
||||
rule.Log.App.Debug().Msg("User is not an OAuth user, skipping OAuth group check")
|
||||
return EffectAbstain
|
||||
}
|
||||
|
||||
if _, ok := model.OverrideProviders[ctx.UserContext.OAuth.ID]; ok {
|
||||
rule.Log.App.Debug().Str("provider", ctx.UserContext.OAuth.ID).Msg("Provider override detected, skipping group check")
|
||||
return EffectAllow
|
||||
}
|
||||
|
||||
for _, group := range ctx.UserContext.OAuth.Groups {
|
||||
match, err := utils.CheckFilter(ctx.ACLs.OAuth.Groups, strings.TrimSpace(group))
|
||||
if err != nil {
|
||||
return EffectAbstain
|
||||
}
|
||||
if match {
|
||||
rule.Log.App.Trace().Str("group", group).Str("required", ctx.ACLs.OAuth.Groups).Msg("User group matched, allowing access")
|
||||
return EffectAllow
|
||||
}
|
||||
}
|
||||
|
||||
rule.Log.App.Debug().Msg("No groups matched")
|
||||
return EffectDeny
|
||||
}
|
||||
|
||||
type LDAPGroupRule struct {
|
||||
Log *logger.Logger
|
||||
}
|
||||
|
||||
func (rule *LDAPGroupRule) Evaluate(ctx *ACLContext) Effect {
|
||||
if ctx == nil || ctx.UserContext == nil {
|
||||
return EffectAbstain
|
||||
}
|
||||
|
||||
if !ctx.UserContext.IsLDAP() {
|
||||
rule.Log.App.Debug().Msg("User is not an LDAP user, skipping LDAP group check")
|
||||
return EffectAbstain
|
||||
}
|
||||
|
||||
for _, group := range ctx.UserContext.LDAP.Groups {
|
||||
match, err := utils.CheckFilter(ctx.ACLs.LDAP.Groups, strings.TrimSpace(group))
|
||||
if err != nil {
|
||||
return EffectAbstain
|
||||
}
|
||||
if match {
|
||||
rule.Log.App.Trace().Str("group", group).Str("required", ctx.ACLs.LDAP.Groups).Msg("User group matched, allowing access")
|
||||
return EffectAllow
|
||||
}
|
||||
}
|
||||
|
||||
rule.Log.App.Debug().Msg("No groups matched")
|
||||
return EffectDeny
|
||||
}
|
||||
|
||||
type AuthEnabledRule struct {
|
||||
Log *logger.Logger
|
||||
}
|
||||
|
||||
func (rule *AuthEnabledRule) Evaluate(ctx *ACLContext) Effect {
|
||||
if ctx.ACLs == nil {
|
||||
return EffectDeny
|
||||
}
|
||||
|
||||
if ctx.ACLs.Path.Block != "" {
|
||||
regex, err := regexp.Compile(ctx.ACLs.Path.Block)
|
||||
|
||||
if err != nil {
|
||||
rule.Log.App.Error().Err(err).Msg("Failed to compile block regex")
|
||||
return EffectDeny
|
||||
}
|
||||
|
||||
if !regex.MatchString(ctx.Path) {
|
||||
return EffectAllow
|
||||
}
|
||||
}
|
||||
|
||||
if ctx.ACLs.Path.Allow != "" {
|
||||
regex, err := regexp.Compile(ctx.ACLs.Path.Allow)
|
||||
|
||||
if err != nil {
|
||||
rule.Log.App.Error().Err(err).Msg("Failed to compile allow regex")
|
||||
return EffectDeny
|
||||
}
|
||||
|
||||
if regex.MatchString(ctx.Path) {
|
||||
return EffectAllow
|
||||
}
|
||||
}
|
||||
|
||||
return EffectDeny
|
||||
}
|
||||
|
||||
type IPAllowedRule struct {
|
||||
Log *logger.Logger
|
||||
Config model.Config
|
||||
}
|
||||
|
||||
func (rule *IPAllowedRule) Evaluate(ctx *ACLContext) Effect {
|
||||
if ctx.ACLs == nil {
|
||||
return EffectAbstain
|
||||
}
|
||||
|
||||
// Merge the global and app IP filter
|
||||
blockedIps := append(ctx.ACLs.IP.Block, rule.Config.Auth.IP.Block...)
|
||||
allowedIPs := append(ctx.ACLs.IP.Allow, rule.Config.Auth.IP.Allow...)
|
||||
|
||||
for _, blocked := range blockedIps {
|
||||
match, err := utils.CheckIPFilter(blocked, ctx.IP.String())
|
||||
if err != nil {
|
||||
rule.Log.App.Warn().Err(err).Str("item", blocked).Msg("Invalid IP/CIDR in block list")
|
||||
continue
|
||||
}
|
||||
if match {
|
||||
rule.Log.App.Debug().Str("ip", ctx.IP.String()).Str("item", blocked).Msg("IP is in block list, denying access")
|
||||
return EffectDeny
|
||||
}
|
||||
}
|
||||
|
||||
for _, allowed := range allowedIPs {
|
||||
match, err := utils.CheckIPFilter(allowed, ctx.IP.String())
|
||||
if err != nil {
|
||||
rule.Log.App.Warn().Err(err).Str("item", allowed).Msg("Invalid IP/CIDR in allow list")
|
||||
continue
|
||||
}
|
||||
if match {
|
||||
rule.Log.App.Debug().Str("ip", ctx.IP.String()).Str("item", allowed).Msg("IP is in allow list, allowing access")
|
||||
return EffectAllow
|
||||
}
|
||||
}
|
||||
|
||||
if len(allowedIPs) > 0 {
|
||||
rule.Log.App.Debug().Str("ip", ctx.IP.String()).Msg("IP not in allow list, denying access")
|
||||
return EffectDeny
|
||||
}
|
||||
|
||||
rule.Log.App.Debug().Str("ip", ctx.IP.String()).Msg("IP not in block or allow list, allowing access")
|
||||
return EffectAllow
|
||||
}
|
||||
|
||||
type IPBypassedRule struct {
|
||||
Log *logger.Logger
|
||||
}
|
||||
|
||||
func (rule *IPBypassedRule) Evaluate(ctx *ACLContext) Effect {
|
||||
if ctx.ACLs == nil {
|
||||
return EffectDeny
|
||||
}
|
||||
|
||||
for _, bypassed := range ctx.ACLs.IP.Bypass {
|
||||
match, err := utils.CheckIPFilter(bypassed, ctx.IP.String())
|
||||
if err != nil {
|
||||
rule.Log.App.Warn().Err(err).Str("item", bypassed).Msg("Invalid IP/CIDR in bypass list")
|
||||
continue
|
||||
}
|
||||
if match {
|
||||
rule.Log.App.Debug().Str("ip", ctx.IP.String()).Str("item", bypassed).Msg("IP is in bypass list, skipping authentication")
|
||||
return EffectAllow
|
||||
}
|
||||
}
|
||||
|
||||
rule.Log.App.Debug().Str("ip", ctx.IP.String()).Msg("IP not in bypass list, proceeding with authentication")
|
||||
return EffectDeny
|
||||
}
|
||||
@@ -1,732 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||
)
|
||||
|
||||
func TestUserAllowedRule(t *testing.T) {
|
||||
log := logger.NewLogger().WithTestConfig()
|
||||
log.Init()
|
||||
|
||||
rule := &UserAllowedRule{Log: log}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
ctx *ACLContext
|
||||
expected Effect
|
||||
}{
|
||||
{
|
||||
name: "abstains when ACLs are nil",
|
||||
ctx: &ACLContext{
|
||||
ACLs: nil,
|
||||
UserContext: &model.UserContext{
|
||||
Provider: model.ProviderLocal,
|
||||
Local: &model.LocalContext{
|
||||
BaseContext: model.BaseContext{Username: "alice"},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: EffectAbstain,
|
||||
},
|
||||
{
|
||||
name: "abstains when user context is nil",
|
||||
ctx: &ACLContext{
|
||||
ACLs: &model.App{
|
||||
OAuth: model.AppOAuth{Whitelist: "alice"},
|
||||
},
|
||||
UserContext: nil,
|
||||
},
|
||||
expected: EffectAbstain,
|
||||
},
|
||||
{
|
||||
name: "allows OAuth user when email matches whitelist",
|
||||
ctx: &ACLContext{
|
||||
ACLs: &model.App{
|
||||
OAuth: model.AppOAuth{Whitelist: "allowed@example.com"},
|
||||
},
|
||||
UserContext: &model.UserContext{
|
||||
Provider: model.ProviderOAuth,
|
||||
OAuth: &model.OAuthContext{
|
||||
BaseContext: model.BaseContext{
|
||||
Username: "different-username",
|
||||
Email: "allowed@example.com",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: EffectAllow,
|
||||
},
|
||||
{
|
||||
name: "denies OAuth user when email does not match whitelist",
|
||||
ctx: &ACLContext{
|
||||
ACLs: &model.App{
|
||||
OAuth: model.AppOAuth{Whitelist: "allowed@example.com"},
|
||||
},
|
||||
UserContext: &model.UserContext{
|
||||
Provider: model.ProviderOAuth,
|
||||
OAuth: &model.OAuthContext{
|
||||
BaseContext: model.BaseContext{Email: "denied@example.com"},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: EffectDeny,
|
||||
},
|
||||
{
|
||||
name: "abstains for OAuth user when whitelist filter is invalid",
|
||||
ctx: &ACLContext{
|
||||
ACLs: &model.App{
|
||||
OAuth: model.AppOAuth{Whitelist: "/[/"},
|
||||
},
|
||||
UserContext: &model.UserContext{
|
||||
Provider: model.ProviderOAuth,
|
||||
OAuth: &model.OAuthContext{
|
||||
BaseContext: model.BaseContext{Email: "allowed@example.com"},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: EffectAbstain,
|
||||
},
|
||||
{
|
||||
name: "denies local user when username matches block list",
|
||||
ctx: &ACLContext{
|
||||
ACLs: &model.App{
|
||||
Users: model.AppUsers{Block: "alice,bob"},
|
||||
},
|
||||
UserContext: &model.UserContext{
|
||||
Provider: model.ProviderLocal,
|
||||
Local: &model.LocalContext{
|
||||
BaseContext: model.BaseContext{Username: "alice"},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: EffectDeny,
|
||||
},
|
||||
{
|
||||
name: "allows local user when username does not match block list",
|
||||
ctx: &ACLContext{
|
||||
ACLs: &model.App{
|
||||
Users: model.AppUsers{Block: "alice,bob"},
|
||||
},
|
||||
UserContext: &model.UserContext{
|
||||
Provider: model.ProviderLocal,
|
||||
Local: &model.LocalContext{
|
||||
BaseContext: model.BaseContext{Username: "charlie"},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: EffectAllow,
|
||||
},
|
||||
{
|
||||
name: "abstains when block list filter is invalid",
|
||||
ctx: &ACLContext{
|
||||
ACLs: &model.App{
|
||||
Users: model.AppUsers{Block: "/[/"},
|
||||
},
|
||||
UserContext: &model.UserContext{
|
||||
Provider: model.ProviderLocal,
|
||||
Local: &model.LocalContext{
|
||||
BaseContext: model.BaseContext{Username: "alice"},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: EffectAbstain,
|
||||
},
|
||||
{
|
||||
name: "allows local user when username matches allow list",
|
||||
ctx: &ACLContext{
|
||||
ACLs: &model.App{
|
||||
Users: model.AppUsers{Allow: "alice,bob"},
|
||||
},
|
||||
UserContext: &model.UserContext{
|
||||
Provider: model.ProviderLocal,
|
||||
Local: &model.LocalContext{
|
||||
BaseContext: model.BaseContext{Username: "alice"},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: EffectAllow,
|
||||
},
|
||||
{
|
||||
name: "denies local user when username does not match allow list",
|
||||
ctx: &ACLContext{
|
||||
ACLs: &model.App{
|
||||
Users: model.AppUsers{Allow: "alice,bob"},
|
||||
},
|
||||
UserContext: &model.UserContext{
|
||||
Provider: model.ProviderLocal,
|
||||
Local: &model.LocalContext{
|
||||
BaseContext: model.BaseContext{Username: "charlie"},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: EffectDeny,
|
||||
},
|
||||
{
|
||||
name: "abstains when allow list filter is invalid",
|
||||
ctx: &ACLContext{
|
||||
ACLs: &model.App{
|
||||
Users: model.AppUsers{Allow: "/[/"},
|
||||
},
|
||||
UserContext: &model.UserContext{
|
||||
Provider: model.ProviderLocal,
|
||||
Local: &model.LocalContext{
|
||||
BaseContext: model.BaseContext{Username: "alice"},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: EffectAbstain,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assert.Equal(t, tt.expected, rule.Evaluate(tt.ctx))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOAuthGroupRule(t *testing.T) {
|
||||
log := logger.NewLogger().WithTestConfig()
|
||||
log.Init()
|
||||
|
||||
rule := &OAuthGroupRule{Log: log}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
ctx *ACLContext
|
||||
expected Effect
|
||||
}{
|
||||
{
|
||||
name: "abstains when ACLs are nil",
|
||||
ctx: &ACLContext{
|
||||
ACLs: nil,
|
||||
UserContext: &model.UserContext{
|
||||
Provider: model.ProviderOAuth,
|
||||
OAuth: &model.OAuthContext{
|
||||
Groups: []string{"admins"},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: EffectAbstain,
|
||||
},
|
||||
{
|
||||
name: "abstains when user context is nil",
|
||||
ctx: &ACLContext{
|
||||
ACLs: &model.App{
|
||||
OAuth: model.AppOAuth{Whitelist: "alice"},
|
||||
},
|
||||
UserContext: nil,
|
||||
},
|
||||
expected: EffectAbstain,
|
||||
},
|
||||
{
|
||||
name: "abstains when user is not OAuth",
|
||||
ctx: &ACLContext{
|
||||
ACLs: &model.App{
|
||||
OAuth: model.AppOAuth{Groups: "admins"},
|
||||
},
|
||||
UserContext: &model.UserContext{
|
||||
Provider: model.ProviderLocal,
|
||||
Local: &model.LocalContext{
|
||||
BaseContext: model.BaseContext{Username: "alice"},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: EffectAbstain,
|
||||
},
|
||||
{
|
||||
name: "allows when provider is an override provider regardless of groups",
|
||||
ctx: &ACLContext{
|
||||
ACLs: &model.App{
|
||||
OAuth: model.AppOAuth{Groups: "admins"},
|
||||
},
|
||||
UserContext: &model.UserContext{
|
||||
Provider: model.ProviderOAuth,
|
||||
OAuth: &model.OAuthContext{
|
||||
ID: "google",
|
||||
Groups: []string{"unrelated"},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: EffectAllow,
|
||||
},
|
||||
{
|
||||
name: "allows OAuth user when a group matches",
|
||||
ctx: &ACLContext{
|
||||
ACLs: &model.App{
|
||||
OAuth: model.AppOAuth{Groups: "admins,users"},
|
||||
},
|
||||
UserContext: &model.UserContext{
|
||||
Provider: model.ProviderOAuth,
|
||||
OAuth: &model.OAuthContext{
|
||||
ID: "custom",
|
||||
Groups: []string{"users"},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: EffectAllow,
|
||||
},
|
||||
{
|
||||
name: "denies OAuth user when no group matches",
|
||||
ctx: &ACLContext{
|
||||
ACLs: &model.App{
|
||||
OAuth: model.AppOAuth{Groups: "admins"},
|
||||
},
|
||||
UserContext: &model.UserContext{
|
||||
Provider: model.ProviderOAuth,
|
||||
OAuth: &model.OAuthContext{
|
||||
ID: "custom",
|
||||
Groups: []string{"users", "guests"},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: EffectDeny,
|
||||
},
|
||||
{
|
||||
name: "denies OAuth user when user has no groups",
|
||||
ctx: &ACLContext{
|
||||
ACLs: &model.App{
|
||||
OAuth: model.AppOAuth{Groups: "admins"},
|
||||
},
|
||||
UserContext: &model.UserContext{
|
||||
Provider: model.ProviderOAuth,
|
||||
OAuth: &model.OAuthContext{
|
||||
ID: "custom",
|
||||
Groups: nil,
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: EffectDeny,
|
||||
},
|
||||
{
|
||||
name: "abstains when groups filter is invalid",
|
||||
ctx: &ACLContext{
|
||||
ACLs: &model.App{
|
||||
OAuth: model.AppOAuth{Groups: "/[/"},
|
||||
},
|
||||
UserContext: &model.UserContext{
|
||||
Provider: model.ProviderOAuth,
|
||||
OAuth: &model.OAuthContext{
|
||||
ID: "custom",
|
||||
Groups: []string{"admins"},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: EffectAbstain,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assert.Equal(t, tt.expected, rule.Evaluate(tt.ctx))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLDAPGroupRule(t *testing.T) {
|
||||
log := logger.NewLogger().WithTestConfig()
|
||||
log.Init()
|
||||
|
||||
rule := &LDAPGroupRule{Log: log}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
ctx *ACLContext
|
||||
expected Effect
|
||||
}{
|
||||
{
|
||||
name: "abstains when context is nil",
|
||||
ctx: nil,
|
||||
expected: EffectAbstain,
|
||||
},
|
||||
{
|
||||
name: "abstains when user context is nil",
|
||||
ctx: &ACLContext{
|
||||
ACLs: &model.App{
|
||||
OAuth: model.AppOAuth{Whitelist: "alice"},
|
||||
},
|
||||
UserContext: nil,
|
||||
},
|
||||
expected: EffectAbstain,
|
||||
},
|
||||
{
|
||||
name: "abstains when user is not LDAP",
|
||||
ctx: &ACLContext{
|
||||
ACLs: &model.App{
|
||||
LDAP: model.AppLDAP{Groups: "admins"},
|
||||
},
|
||||
UserContext: &model.UserContext{
|
||||
Provider: model.ProviderLocal,
|
||||
Local: &model.LocalContext{
|
||||
BaseContext: model.BaseContext{Username: "alice"},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: EffectAbstain,
|
||||
},
|
||||
{
|
||||
name: "allows LDAP user when a group matches",
|
||||
ctx: &ACLContext{
|
||||
ACLs: &model.App{
|
||||
LDAP: model.AppLDAP{Groups: "admins,users"},
|
||||
},
|
||||
UserContext: &model.UserContext{
|
||||
Provider: model.ProviderLDAP,
|
||||
LDAP: &model.LDAPContext{
|
||||
Groups: []string{"users"},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: EffectAllow,
|
||||
},
|
||||
{
|
||||
name: "denies LDAP user when no group matches",
|
||||
ctx: &ACLContext{
|
||||
ACLs: &model.App{
|
||||
LDAP: model.AppLDAP{Groups: "admins"},
|
||||
},
|
||||
UserContext: &model.UserContext{
|
||||
Provider: model.ProviderLDAP,
|
||||
LDAP: &model.LDAPContext{
|
||||
Groups: []string{"users", "guests"},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: EffectDeny,
|
||||
},
|
||||
{
|
||||
name: "denies LDAP user when user has no groups",
|
||||
ctx: &ACLContext{
|
||||
ACLs: &model.App{
|
||||
LDAP: model.AppLDAP{Groups: "admins"},
|
||||
},
|
||||
UserContext: &model.UserContext{
|
||||
Provider: model.ProviderLDAP,
|
||||
LDAP: &model.LDAPContext{
|
||||
Groups: nil,
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: EffectDeny,
|
||||
},
|
||||
{
|
||||
name: "abstains when groups filter is invalid",
|
||||
ctx: &ACLContext{
|
||||
ACLs: &model.App{
|
||||
LDAP: model.AppLDAP{Groups: "/[/"},
|
||||
},
|
||||
UserContext: &model.UserContext{
|
||||
Provider: model.ProviderLDAP,
|
||||
LDAP: &model.LDAPContext{
|
||||
Groups: []string{"admins"},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: EffectAbstain,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assert.Equal(t, tt.expected, rule.Evaluate(tt.ctx))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthEnabledRule(t *testing.T) {
|
||||
log := logger.NewLogger().WithTestConfig()
|
||||
log.Init()
|
||||
|
||||
rule := &AuthEnabledRule{Log: log}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
ctx *ACLContext
|
||||
expected Effect
|
||||
}{
|
||||
{
|
||||
name: "deny when ACLs are nil",
|
||||
ctx: &ACLContext{
|
||||
ACLs: nil,
|
||||
Path: "/anything",
|
||||
},
|
||||
expected: EffectDeny,
|
||||
},
|
||||
{
|
||||
name: "allows when path does not match block regex",
|
||||
ctx: &ACLContext{
|
||||
ACLs: &model.App{
|
||||
Path: model.AppPath{Block: "^/admin"},
|
||||
},
|
||||
Path: "/public",
|
||||
},
|
||||
expected: EffectAllow,
|
||||
},
|
||||
{
|
||||
name: "denies when path matches block regex and no allow regex",
|
||||
ctx: &ACLContext{
|
||||
ACLs: &model.App{
|
||||
Path: model.AppPath{Block: "^/admin"},
|
||||
},
|
||||
Path: "/admin/users",
|
||||
},
|
||||
expected: EffectDeny,
|
||||
},
|
||||
{
|
||||
name: "allows when path matches allow regex",
|
||||
ctx: &ACLContext{
|
||||
ACLs: &model.App{
|
||||
Path: model.AppPath{Allow: "^/public"},
|
||||
},
|
||||
Path: "/public/index",
|
||||
},
|
||||
expected: EffectAllow,
|
||||
},
|
||||
{
|
||||
name: "denies when path does not match allow regex",
|
||||
ctx: &ACLContext{
|
||||
ACLs: &model.App{
|
||||
Path: model.AppPath{Allow: "^/public"},
|
||||
},
|
||||
Path: "/private",
|
||||
},
|
||||
expected: EffectDeny,
|
||||
},
|
||||
{
|
||||
name: "allows when blocked path is also explicitly allowed",
|
||||
ctx: &ACLContext{
|
||||
ACLs: &model.App{
|
||||
Path: model.AppPath{
|
||||
Block: "^/admin",
|
||||
Allow: "^/admin/public",
|
||||
},
|
||||
},
|
||||
Path: "/admin/public/page",
|
||||
},
|
||||
expected: EffectAllow,
|
||||
},
|
||||
{
|
||||
name: "denies when block regex fails to compile",
|
||||
ctx: &ACLContext{
|
||||
ACLs: &model.App{
|
||||
Path: model.AppPath{Block: "[invalid"},
|
||||
},
|
||||
Path: "/anything",
|
||||
},
|
||||
expected: EffectDeny,
|
||||
},
|
||||
{
|
||||
name: "denies when allow regex fails to compile",
|
||||
ctx: &ACLContext{
|
||||
ACLs: &model.App{
|
||||
Path: model.AppPath{Allow: "[invalid"},
|
||||
},
|
||||
Path: "/anything",
|
||||
},
|
||||
expected: EffectDeny,
|
||||
},
|
||||
{
|
||||
name: "denies when no path rules are configured",
|
||||
ctx: &ACLContext{
|
||||
ACLs: &model.App{},
|
||||
Path: "/anything",
|
||||
},
|
||||
expected: EffectDeny,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assert.Equal(t, tt.expected, rule.Evaluate(tt.ctx))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIPAllowedRule(t *testing.T) {
|
||||
log := logger.NewLogger().WithTestConfig()
|
||||
log.Init()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
config model.Config
|
||||
ctx *ACLContext
|
||||
expected Effect
|
||||
}{
|
||||
{
|
||||
name: "abstains when ACLs are nil",
|
||||
ctx: &ACLContext{
|
||||
ACLs: nil,
|
||||
IP: net.ParseIP("10.0.0.1"),
|
||||
},
|
||||
expected: EffectAbstain,
|
||||
},
|
||||
{
|
||||
name: "denies when IP matches app block list",
|
||||
ctx: &ACLContext{
|
||||
ACLs: &model.App{
|
||||
IP: model.AppIP{Block: []string{"10.0.0.1"}},
|
||||
},
|
||||
IP: net.ParseIP("10.0.0.1"),
|
||||
},
|
||||
expected: EffectDeny,
|
||||
},
|
||||
{
|
||||
name: "denies when IP matches global block list",
|
||||
config: model.Config{
|
||||
Auth: model.AuthConfig{
|
||||
IP: model.IPConfig{Block: []string{"10.0.0.0/24"}},
|
||||
},
|
||||
},
|
||||
ctx: &ACLContext{
|
||||
ACLs: &model.App{},
|
||||
IP: net.ParseIP("10.0.0.5"),
|
||||
},
|
||||
expected: EffectDeny,
|
||||
},
|
||||
{
|
||||
name: "allows when IP matches app allow list",
|
||||
ctx: &ACLContext{
|
||||
ACLs: &model.App{
|
||||
IP: model.AppIP{Allow: []string{"192.168.1.0/24"}},
|
||||
},
|
||||
IP: net.ParseIP("192.168.1.10"),
|
||||
},
|
||||
expected: EffectAllow,
|
||||
},
|
||||
{
|
||||
name: "allows when IP matches global allow list",
|
||||
config: model.Config{
|
||||
Auth: model.AuthConfig{
|
||||
IP: model.IPConfig{Allow: []string{"192.168.1.10"}},
|
||||
},
|
||||
},
|
||||
ctx: &ACLContext{
|
||||
ACLs: &model.App{},
|
||||
IP: net.ParseIP("192.168.1.10"),
|
||||
},
|
||||
expected: EffectAllow,
|
||||
},
|
||||
{
|
||||
name: "denies when allow list is set and IP does not match",
|
||||
ctx: &ACLContext{
|
||||
ACLs: &model.App{
|
||||
IP: model.AppIP{Allow: []string{"192.168.1.0/24"}},
|
||||
},
|
||||
IP: net.ParseIP("10.0.0.1"),
|
||||
},
|
||||
expected: EffectDeny,
|
||||
},
|
||||
{
|
||||
name: "allows when no block or allow lists are configured",
|
||||
ctx: &ACLContext{
|
||||
ACLs: &model.App{},
|
||||
IP: net.ParseIP("10.0.0.1"),
|
||||
},
|
||||
expected: EffectAllow,
|
||||
},
|
||||
{
|
||||
name: "block list takes precedence over allow list",
|
||||
ctx: &ACLContext{
|
||||
ACLs: &model.App{
|
||||
IP: model.AppIP{
|
||||
Block: []string{"10.0.0.1"},
|
||||
Allow: []string{"10.0.0.1"},
|
||||
},
|
||||
},
|
||||
IP: net.ParseIP("10.0.0.1"),
|
||||
},
|
||||
expected: EffectDeny,
|
||||
},
|
||||
{
|
||||
name: "skips invalid block entries and continues evaluation",
|
||||
ctx: &ACLContext{
|
||||
ACLs: &model.App{
|
||||
IP: model.AppIP{
|
||||
Block: []string{"not-an-ip"},
|
||||
Allow: []string{"10.0.0.1"},
|
||||
},
|
||||
},
|
||||
IP: net.ParseIP("10.0.0.1"),
|
||||
},
|
||||
expected: EffectAllow,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
rule := &IPAllowedRule{Log: log, Config: tt.config}
|
||||
assert.Equal(t, tt.expected, rule.Evaluate(tt.ctx))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIPBypassedRule(t *testing.T) {
|
||||
log := logger.NewLogger().WithTestConfig()
|
||||
log.Init()
|
||||
|
||||
rule := &IPBypassedRule{Log: log}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
ctx *ACLContext
|
||||
expected Effect
|
||||
}{
|
||||
{
|
||||
name: "deny when ACLs are nil",
|
||||
ctx: &ACLContext{
|
||||
ACLs: nil,
|
||||
IP: net.ParseIP("10.0.0.1"),
|
||||
},
|
||||
expected: EffectDeny,
|
||||
},
|
||||
{
|
||||
name: "allows when IP matches bypass list",
|
||||
ctx: &ACLContext{
|
||||
ACLs: &model.App{
|
||||
IP: model.AppIP{Bypass: []string{"10.0.0.0/24"}},
|
||||
},
|
||||
IP: net.ParseIP("10.0.0.5"),
|
||||
},
|
||||
expected: EffectAllow,
|
||||
},
|
||||
{
|
||||
name: "denies when IP does not match bypass list",
|
||||
ctx: &ACLContext{
|
||||
ACLs: &model.App{
|
||||
IP: model.AppIP{Bypass: []string{"10.0.0.0/24"}},
|
||||
},
|
||||
IP: net.ParseIP("192.168.1.1"),
|
||||
},
|
||||
expected: EffectDeny,
|
||||
},
|
||||
{
|
||||
name: "denies when bypass list is empty",
|
||||
ctx: &ACLContext{
|
||||
ACLs: &model.App{},
|
||||
IP: net.ParseIP("10.0.0.1"),
|
||||
},
|
||||
expected: EffectDeny,
|
||||
},
|
||||
{
|
||||
name: "skips invalid bypass entries and allows on later match",
|
||||
ctx: &ACLContext{
|
||||
ACLs: &model.App{
|
||||
IP: model.AppIP{Bypass: []string{"not-an-ip", "10.0.0.1"}},
|
||||
},
|
||||
IP: net.ParseIP("10.0.0.1"),
|
||||
},
|
||||
expected: EffectAllow,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assert.Equal(t, tt.expected, rule.Evaluate(tt.ctx))
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -13,52 +13,51 @@ type LabelProvider interface {
|
||||
|
||||
type AccessControlsService struct {
|
||||
log *logger.Logger
|
||||
config model.Config
|
||||
labelProvider *LabelProvider
|
||||
static map[string]model.App
|
||||
}
|
||||
|
||||
func NewAccessControlsService(
|
||||
log *logger.Logger,
|
||||
config model.Config,
|
||||
labelProvider *LabelProvider) *AccessControlsService {
|
||||
|
||||
labelProvider *LabelProvider,
|
||||
static map[string]model.App) *AccessControlsService {
|
||||
return &AccessControlsService{
|
||||
log: log,
|
||||
config: config,
|
||||
labelProvider: labelProvider,
|
||||
static: static,
|
||||
}
|
||||
}
|
||||
|
||||
func (service *AccessControlsService) lookupStaticACLs(domain string) *model.App {
|
||||
var nameMatch *model.App
|
||||
|
||||
// First try to find a matching app by domain, then fallback to matching by app name (subdomain)
|
||||
for app, config := range service.config.Apps {
|
||||
func (acls *AccessControlsService) lookupStaticACLs(domain string) *model.App {
|
||||
var appAcls *model.App
|
||||
for app, config := range acls.static {
|
||||
if config.Config.Domain == domain {
|
||||
service.log.App.Debug().Str("name", app).Msg("Found matching container by domain")
|
||||
return &config
|
||||
acls.log.App.Debug().Str("name", app).Msg("Found matching container by domain")
|
||||
appAcls = &config
|
||||
break // If we find a match by domain, we can stop searching
|
||||
}
|
||||
|
||||
if strings.SplitN(domain, ".", 2)[0] == app {
|
||||
service.log.App.Debug().Str("name", app).Msg("Found matching container by app name")
|
||||
nameMatch = &config
|
||||
acls.log.App.Debug().Str("name", app).Msg("Found matching container by app name")
|
||||
appAcls = &config
|
||||
break // If we find a match by app name, we can stop searching
|
||||
}
|
||||
}
|
||||
|
||||
return nameMatch
|
||||
return appAcls
|
||||
}
|
||||
|
||||
func (service *AccessControlsService) GetAccessControls(domain string) (*model.App, error) {
|
||||
func (acls *AccessControlsService) GetAccessControls(domain string) (*model.App, error) {
|
||||
// First check in the static config
|
||||
app := service.lookupStaticACLs(domain)
|
||||
app := acls.lookupStaticACLs(domain)
|
||||
|
||||
if app != nil {
|
||||
service.log.App.Debug().Msg("Using static ACLs for app")
|
||||
acls.log.App.Debug().Msg("Using static ACLs for app")
|
||||
return app, nil
|
||||
}
|
||||
|
||||
// If we have a label provider configured, try to get ACLs from it
|
||||
if service.labelProvider != nil && *service.labelProvider != nil {
|
||||
return (*service.labelProvider).GetLabels(domain)
|
||||
if acls.labelProvider != nil {
|
||||
return (*acls.labelProvider).GetLabels(domain)
|
||||
}
|
||||
|
||||
// no labels
|
||||
|
||||
@@ -1,199 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||
)
|
||||
|
||||
type mockLabelProvider struct {
|
||||
getLabelsFn func(appDomain string) (*model.App, error)
|
||||
calledWith string
|
||||
callCount int
|
||||
}
|
||||
|
||||
func (m *mockLabelProvider) GetLabels(appDomain string) (*model.App, error) {
|
||||
m.calledWith = appDomain
|
||||
m.callCount++
|
||||
if m.getLabelsFn != nil {
|
||||
return m.getLabelsFn(appDomain)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func TestLookupStaticACLs(t *testing.T) {
|
||||
log := logger.NewLogger().WithTestConfig()
|
||||
log.Init()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
apps map[string]model.App
|
||||
domain string
|
||||
expectNil bool
|
||||
expectedDomain string
|
||||
}{
|
||||
{
|
||||
name: "returns nil when no apps are configured",
|
||||
apps: nil,
|
||||
domain: "foo.example.com",
|
||||
expectNil: true,
|
||||
},
|
||||
{
|
||||
name: "returns nil when no app matches",
|
||||
apps: map[string]model.App{
|
||||
"foo": {Config: model.AppConfig{Domain: "foo.example.com"}},
|
||||
},
|
||||
domain: "bar.example.com",
|
||||
expectNil: true,
|
||||
},
|
||||
{
|
||||
name: "matches by exact domain",
|
||||
apps: map[string]model.App{
|
||||
"foo": {Config: model.AppConfig{Domain: "foo.example.com"}},
|
||||
},
|
||||
domain: "foo.example.com",
|
||||
expectedDomain: "foo.example.com",
|
||||
},
|
||||
{
|
||||
name: "matches by app name when domain does not match any app",
|
||||
apps: map[string]model.App{
|
||||
"foo": {Config: model.AppConfig{Domain: "configured.example.com"}},
|
||||
},
|
||||
domain: "foo.example.com",
|
||||
expectedDomain: "configured.example.com",
|
||||
},
|
||||
{
|
||||
name: "matches by app name for nested subdomains",
|
||||
apps: map[string]model.App{
|
||||
"foo": {Config: model.AppConfig{Domain: "configured.example.com"}},
|
||||
},
|
||||
domain: "foo.sub.example.com",
|
||||
expectedDomain: "configured.example.com",
|
||||
},
|
||||
{
|
||||
name: "selects the app matching by domain among multiple apps",
|
||||
apps: map[string]model.App{
|
||||
"unrelated": {Config: model.AppConfig{Domain: "other.example.com"}},
|
||||
"target": {Config: model.AppConfig{Domain: "foo.example.com"}},
|
||||
},
|
||||
domain: "foo.example.com",
|
||||
expectedDomain: "foo.example.com",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
svc := NewAccessControlsService(log, model.Config{Apps: tt.apps}, nil)
|
||||
got := svc.lookupStaticACLs(tt.domain)
|
||||
if tt.expectNil {
|
||||
assert.Nil(t, got)
|
||||
return
|
||||
}
|
||||
require.NotNil(t, got)
|
||||
assert.Equal(t, tt.expectedDomain, got.Config.Domain)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAccessControls(t *testing.T) {
|
||||
log := logger.NewLogger().WithTestConfig()
|
||||
log.Init()
|
||||
|
||||
t.Run("returns static ACLs when domain matches", func(t *testing.T) {
|
||||
config := model.Config{
|
||||
Apps: map[string]model.App{
|
||||
"foo": {
|
||||
Config: model.AppConfig{Domain: "foo.example.com"},
|
||||
Users: model.AppUsers{Allow: "alice"},
|
||||
},
|
||||
},
|
||||
}
|
||||
svc := NewAccessControlsService(log, config, nil)
|
||||
|
||||
got, err := svc.GetAccessControls("foo.example.com")
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, got)
|
||||
assert.Equal(t, "foo.example.com", got.Config.Domain)
|
||||
assert.Equal(t, "alice", got.Users.Allow)
|
||||
})
|
||||
|
||||
t.Run("returns nil when no static match and no label provider", func(t *testing.T) {
|
||||
svc := NewAccessControlsService(log, model.Config{}, nil)
|
||||
|
||||
got, err := svc.GetAccessControls("unknown.example.com")
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, got)
|
||||
})
|
||||
|
||||
t.Run("returns nil when label provider pointer wraps a nil interface", func(t *testing.T) {
|
||||
var provider LabelProvider
|
||||
svc := NewAccessControlsService(log, model.Config{}, &provider)
|
||||
|
||||
got, err := svc.GetAccessControls("unknown.example.com")
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, got)
|
||||
})
|
||||
|
||||
t.Run("falls back to label provider when no static match", func(t *testing.T) {
|
||||
expected := &model.App{
|
||||
Config: model.AppConfig{Domain: "dynamic.example.com"},
|
||||
Users: model.AppUsers{Allow: "bob"},
|
||||
}
|
||||
mock := &mockLabelProvider{
|
||||
getLabelsFn: func(appDomain string) (*model.App, error) {
|
||||
return expected, nil
|
||||
},
|
||||
}
|
||||
var provider LabelProvider = mock
|
||||
svc := NewAccessControlsService(log, model.Config{}, &provider)
|
||||
|
||||
got, err := svc.GetAccessControls("dynamic.example.com")
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Same(t, expected, got)
|
||||
assert.Equal(t, "dynamic.example.com", mock.calledWith)
|
||||
assert.Equal(t, 1, mock.callCount)
|
||||
})
|
||||
|
||||
t.Run("does not call label provider when static match found", func(t *testing.T) {
|
||||
mock := &mockLabelProvider{}
|
||||
var provider LabelProvider = mock
|
||||
config := model.Config{
|
||||
Apps: map[string]model.App{
|
||||
"foo": {Config: model.AppConfig{Domain: "foo.example.com"}},
|
||||
},
|
||||
}
|
||||
svc := NewAccessControlsService(log, config, &provider)
|
||||
|
||||
got, err := svc.GetAccessControls("foo.example.com")
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, got)
|
||||
assert.Equal(t, "foo.example.com", got.Config.Domain)
|
||||
assert.Equal(t, 0, mock.callCount)
|
||||
})
|
||||
|
||||
t.Run("propagates label provider errors", func(t *testing.T) {
|
||||
providerErr := errors.New("provider boom")
|
||||
mock := &mockLabelProvider{
|
||||
getLabelsFn: func(appDomain string) (*model.App, error) {
|
||||
return nil, providerErr
|
||||
},
|
||||
}
|
||||
var provider LabelProvider = mock
|
||||
svc := NewAccessControlsService(log, model.Config{}, &provider)
|
||||
|
||||
got, err := svc.GetAccessControls("dynamic.example.com")
|
||||
|
||||
assert.Nil(t, got)
|
||||
assert.ErrorIs(t, err, providerErr)
|
||||
assert.Equal(t, 1, mock.callCount)
|
||||
})
|
||||
}
|
||||
@@ -2,10 +2,10 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
|
||||
"slices"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"golang.org/x/oauth2"
|
||||
@@ -77,7 +78,7 @@ type AuthService struct {
|
||||
context context.Context
|
||||
|
||||
ldap *LdapService
|
||||
queries *repository.Queries
|
||||
queries repository.Store
|
||||
oauthBroker *OAuthBrokerService
|
||||
|
||||
loginAttempts map[string]*LoginAttempt
|
||||
@@ -98,7 +99,7 @@ func NewAuthService(
|
||||
ctx context.Context,
|
||||
wg *sync.WaitGroup,
|
||||
ldap *LdapService,
|
||||
queries *repository.Queries,
|
||||
queries repository.Store,
|
||||
oauthBroker *OAuthBrokerService,
|
||||
) *AuthService {
|
||||
service := &AuthService{
|
||||
@@ -284,12 +285,7 @@ func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) {
|
||||
}
|
||||
|
||||
func (auth *AuthService) IsEmailWhitelisted(email string) bool {
|
||||
match, err := utils.CheckFilter(strings.Join(auth.runtime.OAuthWhitelist, ","), email)
|
||||
if err != nil {
|
||||
auth.log.App.Warn().Err(err).Str("email", email).Msg("Invalid email filter pattern")
|
||||
return false
|
||||
}
|
||||
return match
|
||||
return utils.CheckFilter(strings.Join(auth.runtime.OAuthWhitelist, ","), email)
|
||||
}
|
||||
|
||||
func (auth *AuthService) CreateSession(ctx context.Context, data repository.Session) (*http.Cookie, error) {
|
||||
@@ -420,7 +416,7 @@ func (auth *AuthService) GetSession(ctx context.Context, uuid string) (*reposito
|
||||
session, err := auth.queries.GetSession(ctx, uuid)
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
if errors.Is(err, repository.ErrNotFound) {
|
||||
return nil, errors.New("session not found")
|
||||
}
|
||||
return nil, err
|
||||
@@ -457,6 +453,171 @@ func (auth *AuthService) LDAPAuthConfigured() bool {
|
||||
return auth.ldap != nil
|
||||
}
|
||||
|
||||
func (auth *AuthService) IsUserAllowed(c *gin.Context, context model.UserContext, acls *model.App) bool {
|
||||
if acls == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
if context.Provider == model.ProviderOAuth {
|
||||
auth.log.App.Debug().Msg("User is an OAuth user, checking OAuth whitelist")
|
||||
return utils.CheckFilter(acls.OAuth.Whitelist, context.OAuth.Email)
|
||||
}
|
||||
|
||||
if acls.Users.Block != "" {
|
||||
auth.log.App.Debug().Msg("Checking users block list")
|
||||
if utils.CheckFilter(acls.Users.Block, context.GetUsername()) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
auth.log.App.Debug().Msg("Checking users allow list")
|
||||
return utils.CheckFilter(acls.Users.Allow, context.GetUsername())
|
||||
}
|
||||
|
||||
func (auth *AuthService) IsInOAuthGroup(c *gin.Context, context model.UserContext, acls *model.App) bool {
|
||||
if acls == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
if !context.IsOAuth() {
|
||||
auth.log.App.Debug().Msg("User is not an OAuth user, skipping OAuth group check")
|
||||
return false
|
||||
}
|
||||
|
||||
if _, ok := model.OverrideProviders[context.OAuth.ID]; ok {
|
||||
auth.log.App.Debug().Str("provider", context.OAuth.ID).Msg("Provider override detected, skipping group check")
|
||||
return true
|
||||
}
|
||||
|
||||
for _, userGroup := range context.OAuth.Groups {
|
||||
if utils.CheckFilter(acls.OAuth.Groups, strings.TrimSpace(userGroup)) {
|
||||
auth.log.App.Trace().Str("group", userGroup).Str("required", acls.OAuth.Groups).Msg("User group matched")
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
auth.log.App.Debug().Msg("No groups matched")
|
||||
return false
|
||||
}
|
||||
|
||||
func (auth *AuthService) IsInLDAPGroup(c *gin.Context, context model.UserContext, acls *model.App) bool {
|
||||
if acls == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
if !context.IsLDAP() {
|
||||
auth.log.App.Debug().Msg("User is not an LDAP user, skipping LDAP group check")
|
||||
return false
|
||||
}
|
||||
|
||||
for _, userGroup := range context.LDAP.Groups {
|
||||
if utils.CheckFilter(acls.LDAP.Groups, strings.TrimSpace(userGroup)) {
|
||||
auth.log.App.Trace().Str("group", userGroup).Str("required", acls.LDAP.Groups).Msg("User group matched")
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
auth.log.App.Debug().Msg("No groups matched")
|
||||
return false
|
||||
}
|
||||
|
||||
func (auth *AuthService) IsAuthEnabled(uri string, acls *model.App) (bool, error) {
|
||||
if acls == nil {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// Check for block list
|
||||
if acls.Path.Block != "" {
|
||||
regex, err := regexp.Compile(acls.Path.Block)
|
||||
|
||||
if err != nil {
|
||||
return true, err
|
||||
}
|
||||
|
||||
if !regex.MatchString(uri) {
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Check for allow list
|
||||
if acls.Path.Allow != "" {
|
||||
regex, err := regexp.Compile(acls.Path.Allow)
|
||||
|
||||
if err != nil {
|
||||
return true, err
|
||||
}
|
||||
|
||||
if regex.MatchString(uri) {
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (auth *AuthService) CheckIP(ip string, acls *model.App) bool {
|
||||
if acls == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
// Merge the global and app IP filter
|
||||
blockedIps := append(auth.config.Auth.IP.Block, acls.IP.Block...)
|
||||
allowedIPs := append(auth.config.Auth.IP.Allow, acls.IP.Allow...)
|
||||
|
||||
for _, blocked := range blockedIps {
|
||||
res, err := utils.FilterIP(blocked, ip)
|
||||
if err != nil {
|
||||
auth.log.App.Warn().Err(err).Str("item", blocked).Msg("Invalid IP/CIDR in block list")
|
||||
continue
|
||||
}
|
||||
if res {
|
||||
auth.log.App.Debug().Str("ip", ip).Str("item", blocked).Msg("IP is in block list, denying access")
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
for _, allowed := range allowedIPs {
|
||||
res, err := utils.FilterIP(allowed, ip)
|
||||
if err != nil {
|
||||
auth.log.App.Warn().Err(err).Str("item", allowed).Msg("Invalid IP/CIDR in allow list")
|
||||
continue
|
||||
}
|
||||
if res {
|
||||
auth.log.App.Debug().Str("ip", ip).Str("item", allowed).Msg("IP is in allow list, allowing access")
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
if len(allowedIPs) > 0 {
|
||||
auth.log.App.Debug().Str("ip", ip).Msg("IP not in allow list, denying access")
|
||||
return false
|
||||
}
|
||||
|
||||
auth.log.App.Debug().Str("ip", ip).Msg("IP not in any block or allow list, allowing access by default")
|
||||
return true
|
||||
}
|
||||
|
||||
func (auth *AuthService) IsBypassedIP(ip string, acls *model.App) bool {
|
||||
if acls == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, bypassed := range acls.IP.Bypass {
|
||||
res, err := utils.FilterIP(bypassed, ip)
|
||||
if err != nil {
|
||||
auth.log.App.Warn().Err(err).Str("item", bypassed).Msg("Invalid IP/CIDR in bypass list")
|
||||
continue
|
||||
}
|
||||
if res {
|
||||
auth.log.App.Debug().Str("ip", ip).Str("item", bypassed).Msg("IP is in bypass list, skipping authentication")
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
auth.log.App.Debug().Str("ip", ip).Msg("IP not in bypass list, proceeding with authentication")
|
||||
return false
|
||||
}
|
||||
|
||||
func (auth *AuthService) NewOAuthSession(serviceName string, params OAuthURLParams) (string, OAuthPendingSession, error) {
|
||||
auth.ensureOAuthSessionLimit()
|
||||
|
||||
|
||||
@@ -85,23 +85,17 @@ func (docker *DockerService) GetLabels(appDomain string) (*model.App, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var nameMatch *model.App
|
||||
|
||||
// First try to find a matching app by domain, then fallback to matching by app name (subdomain)
|
||||
for appName, appLabels := range labels.Apps {
|
||||
if appLabels.Config.Domain == appDomain {
|
||||
docker.log.App.Debug().Str("id", inspect.ID).Str("name", inspect.Name).Msg("Found matching container by domain")
|
||||
return &appLabels, nil
|
||||
}
|
||||
|
||||
if strings.SplitN(appDomain, ".", 2)[0] == appName {
|
||||
docker.log.App.Debug().Str("id", inspect.ID).Str("name", inspect.Name).Msg("Found matching container by app name")
|
||||
nameMatch = &appLabels
|
||||
return &appLabels, nil
|
||||
}
|
||||
}
|
||||
|
||||
if nameMatch != nil {
|
||||
return nameMatch, nil
|
||||
}
|
||||
}
|
||||
|
||||
docker.log.App.Debug().Str("domain", appDomain).Msg("No matching container found for domain")
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"crypto/rsa"
|
||||
"crypto/sha256"
|
||||
"crypto/x509"
|
||||
"database/sql"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
@@ -116,7 +115,7 @@ type OIDCService struct {
|
||||
log *logger.Logger
|
||||
config model.Config
|
||||
runtime model.RuntimeConfig
|
||||
queries *repository.Queries
|
||||
queries repository.Store
|
||||
context context.Context
|
||||
|
||||
clients map[string]model.OIDCClientConfig
|
||||
@@ -129,7 +128,7 @@ func NewOIDCService(
|
||||
log *logger.Logger,
|
||||
config model.Config,
|
||||
runtime model.RuntimeConfig,
|
||||
queries *repository.Queries,
|
||||
queries repository.Store,
|
||||
ctx context.Context,
|
||||
wg *sync.WaitGroup) (*OIDCService, error) {
|
||||
// If not configured, skip init
|
||||
@@ -434,7 +433,7 @@ func (service *OIDCService) GetCodeEntry(c *gin.Context, codeHash string, client
|
||||
oidcCode, err := service.queries.GetOidcCode(c, codeHash)
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
if errors.Is(err, repository.ErrNotFound) {
|
||||
return repository.OidcCode{}, ErrCodeNotFound
|
||||
}
|
||||
return repository.OidcCode{}, err
|
||||
@@ -578,7 +577,7 @@ func (service *OIDCService) RefreshAccessToken(c *gin.Context, refreshToken stri
|
||||
entry, err := service.queries.GetOidcTokenByRefreshToken(c, service.Hash(refreshToken))
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
if errors.Is(err, repository.ErrNotFound) {
|
||||
return TokenResponse{}, ErrTokenNotFound
|
||||
}
|
||||
return TokenResponse{}, err
|
||||
@@ -657,7 +656,7 @@ func (service *OIDCService) GetAccessToken(c *gin.Context, tokenHash string) (re
|
||||
entry, err := service.queries.GetOidcToken(c, tokenHash)
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
if errors.Is(err, repository.ErrNotFound) {
|
||||
return repository.OidcToken{}, ErrTokenNotFound
|
||||
}
|
||||
return repository.OidcToken{}, err
|
||||
@@ -745,15 +744,15 @@ func (service *OIDCService) Hash(token string) string {
|
||||
|
||||
func (service *OIDCService) DeleteOldSession(ctx context.Context, sub string) error {
|
||||
err := service.queries.DeleteOidcCodeBySub(ctx, sub)
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
if err != nil && !errors.Is(err, repository.ErrNotFound) {
|
||||
return err
|
||||
}
|
||||
err = service.queries.DeleteOidcTokenBySub(ctx, sub)
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
if err != nil && !errors.Is(err, repository.ErrNotFound) {
|
||||
return err
|
||||
}
|
||||
err = service.queries.DeleteOidcUserInfo(ctx, sub)
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
if err != nil && !errors.Is(err, repository.ErrNotFound) {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
@@ -793,14 +792,16 @@ func (service *OIDCService) cleanupRoutine() {
|
||||
expiredCodes, err := service.queries.DeleteExpiredOidcCodes(service.context, currentTime)
|
||||
|
||||
if err != nil {
|
||||
service.log.App.Warn().Err(err).Msg("Failed to delete expired codes")
|
||||
service.log.App.Warn().Err(err).Msg("Failed to delete expired codes")
|
||||
}
|
||||
|
||||
for _, expiredCode := range expiredCodes {
|
||||
token, err := service.queries.GetOidcTokenBySub(service.context, expiredCode.Sub)
|
||||
|
||||
if err != nil {
|
||||
service.log.App.Warn().Err(err).Msg("Failed to get token by sub for expired code")
|
||||
if !errors.Is(err, repository.ErrNotFound) {
|
||||
service.log.App.Warn().Err(err).Msg("Failed to get token by sub for expired code")
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
|
||||
@@ -1,110 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||
)
|
||||
|
||||
type Policy string
|
||||
|
||||
const (
|
||||
PolicyAllow Policy = "allow"
|
||||
PolicyDeny Policy = "deny"
|
||||
)
|
||||
|
||||
type Effect int
|
||||
|
||||
const (
|
||||
EffectAbstain Effect = iota
|
||||
EffectAllow
|
||||
EffectDeny
|
||||
)
|
||||
|
||||
type Rule interface {
|
||||
Evaluate(ctx *ACLContext) Effect
|
||||
}
|
||||
|
||||
type ACLContext struct {
|
||||
ACLs *model.App
|
||||
UserContext *model.UserContext
|
||||
IP net.IP
|
||||
Path string
|
||||
}
|
||||
|
||||
type PolicyEngine struct {
|
||||
log *logger.Logger
|
||||
rules map[RuleName]Rule
|
||||
policy Policy
|
||||
}
|
||||
|
||||
func NewPolicyEngine(config model.Config, log *logger.Logger) (*PolicyEngine, error) {
|
||||
engine := PolicyEngine{
|
||||
log: log,
|
||||
rules: make(map[RuleName]Rule),
|
||||
}
|
||||
|
||||
switch config.Auth.ACLs.Policy {
|
||||
case string(PolicyAllow):
|
||||
log.App.Debug().Msg("Using 'allow' ACL policy: access to apps will be allowed by default unless explicitly blocked")
|
||||
engine.policy = PolicyAllow
|
||||
case string(PolicyDeny):
|
||||
log.App.Debug().Msg("Using 'deny' ACL policy: access to apps will be blocked by default unless explicitly allowed")
|
||||
engine.policy = PolicyDeny
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid acl policy: %s", config.Auth.ACLs.Policy)
|
||||
}
|
||||
|
||||
return &engine, nil
|
||||
}
|
||||
|
||||
func (engine *PolicyEngine) RegisterRule(name RuleName, rule Rule) {
|
||||
engine.log.App.Debug().Str("rule", string(name)).Msg("Registering ACL rule in policy engine")
|
||||
engine.rules[name] = rule
|
||||
}
|
||||
|
||||
func (engine *PolicyEngine) evaluateRuleByName(name RuleName, ctx *ACLContext) Effect {
|
||||
rule, exists := engine.rules[name]
|
||||
|
||||
if !exists {
|
||||
engine.log.App.Warn().Str("rule", string(name)).Msg("Rule not found in policy engine, defaulting to deny")
|
||||
return EffectDeny
|
||||
}
|
||||
|
||||
return rule.Evaluate(ctx)
|
||||
}
|
||||
|
||||
func (engine *PolicyEngine) effectToAccess(effect Effect) bool {
|
||||
switch effect {
|
||||
case EffectAllow:
|
||||
return true
|
||||
case EffectDeny:
|
||||
return false
|
||||
default:
|
||||
// If the effect is abstain, we fall back to the default policy
|
||||
return engine.policy == PolicyAllow
|
||||
}
|
||||
}
|
||||
|
||||
func (engine *PolicyEngine) Evaluate(name RuleName, ctx *ACLContext) bool {
|
||||
effect := engine.evaluateRuleByName(name, ctx)
|
||||
access := engine.effectToAccess(effect)
|
||||
|
||||
engine.log.App.Debug().
|
||||
Str("rule", string(name)).
|
||||
Int("effect", int(effect)).
|
||||
Bool("access", access).
|
||||
Msg("Evaluated ACL rule")
|
||||
|
||||
return access
|
||||
}
|
||||
|
||||
func (engine *PolicyEngine) Policy() Policy {
|
||||
return engine.policy
|
||||
}
|
||||
|
||||
func (engine *PolicyEngine) Rules() map[RuleName]Rule {
|
||||
return engine.rules
|
||||
}
|
||||
@@ -1,93 +0,0 @@
|
||||
package service_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||
"github.com/tinyauthapp/tinyauth/internal/test"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||
)
|
||||
|
||||
// Create test rule
|
||||
type TestRule struct{}
|
||||
|
||||
func (rule *TestRule) Evaluate(ctx *service.ACLContext) service.Effect {
|
||||
switch ctx.Path {
|
||||
case "/allowed":
|
||||
return service.EffectAllow
|
||||
case "/denied":
|
||||
return service.EffectDeny
|
||||
default:
|
||||
return service.EffectAbstain
|
||||
}
|
||||
}
|
||||
|
||||
func TestPolicyEngine(t *testing.T) {
|
||||
log := logger.NewLogger().WithTestConfig()
|
||||
log.Init()
|
||||
|
||||
cfg, _ := test.CreateTestConfigs(t)
|
||||
|
||||
testRule := &TestRule{}
|
||||
|
||||
// Engine should fail with invalid policy
|
||||
cfg.Auth.ACLs.Policy = "invalid_policy"
|
||||
_, err := service.NewPolicyEngine(cfg, log)
|
||||
assert.Error(t, err)
|
||||
|
||||
// Engine should initialize with 'allow' policy
|
||||
cfg.Auth.ACLs.Policy = string(service.PolicyAllow)
|
||||
engine, err := service.NewPolicyEngine(cfg, log)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, service.PolicyAllow, engine.Policy())
|
||||
|
||||
// Engine should initialize with 'deny' policy
|
||||
cfg.Auth.ACLs.Policy = string(service.PolicyDeny)
|
||||
engine, err = service.NewPolicyEngine(cfg, log)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, service.PolicyDeny, engine.Policy())
|
||||
|
||||
// Engine should allow adding rules
|
||||
engine, err = service.NewPolicyEngine(cfg, log)
|
||||
assert.NoError(t, err)
|
||||
engine.RegisterRule("test-rule", testRule)
|
||||
_, ok := engine.Rules()["test-rule"]
|
||||
assert.True(t, ok)
|
||||
|
||||
// Begin allow policy tests
|
||||
cfg.Auth.ACLs.Policy = string(service.PolicyAllow)
|
||||
engine, err = service.NewPolicyEngine(cfg, log)
|
||||
assert.NoError(t, err)
|
||||
engine.RegisterRule("test-rule", testRule)
|
||||
|
||||
// With allow policy, if rule allows, access should be allowed
|
||||
ctx := &service.ACLContext{Path: "/allowed"}
|
||||
assert.Equal(t, true, engine.Evaluate("test-rule", ctx))
|
||||
|
||||
// With allow policy, if rule denies, access should be denied
|
||||
ctx.Path = "/denied"
|
||||
assert.Equal(t, false, engine.Evaluate("test-rule", ctx))
|
||||
|
||||
// With allow policy, if rule abstains, access should be allowed (default)
|
||||
ctx.Path = "/abstain"
|
||||
assert.Equal(t, true, engine.Evaluate("test-rule", ctx))
|
||||
|
||||
// Begin deny policy tests
|
||||
cfg.Auth.ACLs.Policy = string(service.PolicyDeny)
|
||||
engine, err = service.NewPolicyEngine(cfg, log)
|
||||
assert.NoError(t, err)
|
||||
engine.RegisterRule("test-rule", testRule)
|
||||
|
||||
// With deny policy, if rule allows, access should be allowed
|
||||
ctx.Path = "/allowed"
|
||||
assert.Equal(t, true, engine.Evaluate("test-rule", ctx))
|
||||
|
||||
// With deny policy, if rule denies, access should be denied
|
||||
ctx.Path = "/denied"
|
||||
assert.Equal(t, false, engine.Evaluate("test-rule", ctx))
|
||||
|
||||
// With deny policy, if rule abstains, access should be denied (default)
|
||||
ctx.Path = "/abstain"
|
||||
assert.Equal(t, false, engine.Evaluate("test-rule", ctx))
|
||||
}
|
||||
@@ -40,9 +40,6 @@ func CreateTestConfigs(t *testing.T) (model.Config, model.RuntimeConfig) {
|
||||
SessionExpiry: 10,
|
||||
LoginTimeout: 10,
|
||||
LoginMaxRetries: 3,
|
||||
ACLs: model.ACLsConfig{
|
||||
Policy: "allow",
|
||||
},
|
||||
},
|
||||
Database: model.DatabaseConfig{
|
||||
Path: filepath.Join(tempDir, "test.db"),
|
||||
@@ -51,32 +48,6 @@ func CreateTestConfigs(t *testing.T) (model.Config, model.RuntimeConfig) {
|
||||
Enabled: true,
|
||||
Path: filepath.Join(tempDir, "resources"),
|
||||
},
|
||||
Apps: map[string]model.App{
|
||||
"app_path_allow": {
|
||||
Config: model.AppConfig{
|
||||
Domain: "path-allow.example.com",
|
||||
},
|
||||
Path: model.AppPath{
|
||||
Allow: "/allowed",
|
||||
},
|
||||
},
|
||||
"app_user_allow": {
|
||||
Config: model.AppConfig{
|
||||
Domain: "user-allow.example.com",
|
||||
},
|
||||
Users: model.AppUsers{
|
||||
Allow: "testuser",
|
||||
},
|
||||
},
|
||||
"ip_bypass": {
|
||||
Config: model.AppConfig{
|
||||
Domain: "ip-bypass.example.com",
|
||||
},
|
||||
IP: model.AppIP{
|
||||
Bypass: []string{"10.10.10.10"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
passwd, err := bcrypt.GenerateFromPassword([]byte("password"), bcrypt.DefaultCost)
|
||||
|
||||
@@ -3,7 +3,7 @@ package utils
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"errors"
|
||||
"net"
|
||||
"regexp"
|
||||
"strings"
|
||||
@@ -46,27 +46,26 @@ func EncodeBasicAuth(username string, password string) string {
|
||||
return base64.StdEncoding.EncodeToString([]byte(auth))
|
||||
}
|
||||
|
||||
func CheckIPFilter(filter string, ip string) (bool, error) {
|
||||
func FilterIP(filter string, ip string) (bool, error) {
|
||||
ipAddr := net.ParseIP(ip)
|
||||
|
||||
if ipAddr == nil {
|
||||
return false, fmt.Errorf("invalid ip address")
|
||||
return false, errors.New("invalid IP address")
|
||||
}
|
||||
|
||||
filter = strings.ReplaceAll(filter, "-", "/")
|
||||
filter = strings.Replace(filter, "-", "/", -1)
|
||||
|
||||
if strings.Contains(filter, "/") {
|
||||
_, cidr, err := net.ParseCIDR(filter)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("invalid cidr notation: %w", err)
|
||||
return false, err
|
||||
}
|
||||
return cidr.Contains(ipAddr), nil
|
||||
}
|
||||
|
||||
ipFilter := net.ParseIP(filter)
|
||||
|
||||
if ipFilter == nil {
|
||||
return false, fmt.Errorf("invalid ip address")
|
||||
return false, errors.New("invalid IP address in filter")
|
||||
}
|
||||
|
||||
if ipFilter.Equal(ipAddr) {
|
||||
@@ -76,29 +75,31 @@ func CheckIPFilter(filter string, ip string) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func CheckFilter(filter string, input string) (bool, error) {
|
||||
func CheckFilter(filter string, str string) bool {
|
||||
if len(strings.TrimSpace(filter)) == 0 {
|
||||
return false, fmt.Errorf("filter is empty")
|
||||
return true
|
||||
}
|
||||
|
||||
if strings.HasPrefix(filter, "/") && strings.HasSuffix(filter, "/") {
|
||||
re, err := regexp.Compile(filter[1 : len(filter)-1])
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("invalid regex filter: %w", err)
|
||||
return false
|
||||
}
|
||||
|
||||
if re.MatchString(input) {
|
||||
return true, nil
|
||||
if re.MatchString(strings.TrimSpace(str)) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
for item := range strings.SplitSeq(filter, ",") {
|
||||
if strings.TrimSpace(item) == input {
|
||||
return true, nil
|
||||
filterSplit := strings.Split(filter, ",")
|
||||
|
||||
for _, item := range filterSplit {
|
||||
if strings.TrimSpace(item) == strings.TrimSpace(str) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
return false
|
||||
}
|
||||
|
||||
func GenerateUUID(str string) string {
|
||||
|
||||
@@ -75,77 +75,66 @@ func TestEncodeBasicAuth(t *testing.T) {
|
||||
assert.Equal(t, expected, utils.EncodeBasicAuth(username, password))
|
||||
}
|
||||
|
||||
func TestCheckIPFilter(t *testing.T) {
|
||||
func TestFilterIP(t *testing.T) {
|
||||
// Exact match IPv4
|
||||
ok, err := utils.CheckIPFilter("10.10.0.1", "10.10.0.1")
|
||||
ok, err := utils.FilterIP("10.10.0.1", "10.10.0.1")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, true, ok)
|
||||
|
||||
// Non-match IPv4
|
||||
ok, err = utils.CheckIPFilter("10.10.0.1", "10.10.0.2")
|
||||
ok, err = utils.FilterIP("10.10.0.1", "10.10.0.2")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, false, ok)
|
||||
|
||||
// CIDR match IPv4
|
||||
ok, err = utils.CheckIPFilter("10.10.0.0/24", "10.10.0.2")
|
||||
ok, err = utils.FilterIP("10.10.0.0/24", "10.10.0.2")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, true, ok)
|
||||
|
||||
// CIDR match IPv4 with '-' instead of '/'
|
||||
ok, err = utils.CheckIPFilter("10.10.10.0-24", "10.10.10.5")
|
||||
ok, err = utils.FilterIP("10.10.10.0-24", "10.10.10.5")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, true, ok)
|
||||
|
||||
// CIDR non-match IPv4
|
||||
ok, err = utils.CheckIPFilter("10.10.0.0/24", "10.5.0.1")
|
||||
ok, err = utils.FilterIP("10.10.0.0/24", "10.5.0.1")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, false, ok)
|
||||
|
||||
// Invalid CIDR
|
||||
ok, err = utils.CheckIPFilter("10.10.0.0/222", "10.0.0.1")
|
||||
assert.ErrorContains(t, err, "invalid cidr notation: invalid CIDR address: 10.10.0.0/222")
|
||||
ok, err = utils.FilterIP("10.10.0.0/222", "10.0.0.1")
|
||||
assert.ErrorContains(t, err, "invalid CIDR address")
|
||||
assert.Equal(t, false, ok)
|
||||
|
||||
// Invalid IP in filter
|
||||
ok, err = utils.CheckIPFilter("invalid_ip", "10.5.5.5")
|
||||
assert.ErrorContains(t, err, "invalid ip address")
|
||||
ok, err = utils.FilterIP("invalid_ip", "10.5.5.5")
|
||||
assert.ErrorContains(t, err, "invalid IP address in filter")
|
||||
assert.Equal(t, false, ok)
|
||||
|
||||
// Invalid IP to check
|
||||
ok, err = utils.CheckIPFilter("10.10.10.10", "invalid_ip")
|
||||
assert.ErrorContains(t, err, "invalid ip address")
|
||||
ok, err = utils.FilterIP("10.10.10.10", "invalid_ip")
|
||||
assert.ErrorContains(t, err, "invalid IP address")
|
||||
assert.Equal(t, false, ok)
|
||||
}
|
||||
|
||||
func TestCheckFilter(t *testing.T) {
|
||||
// Empty filter
|
||||
_, err := utils.CheckFilter("", "anystring")
|
||||
assert.ErrorContains(t, err, "filter is empty")
|
||||
assert.Equal(t, true, utils.CheckFilter("", "anystring"))
|
||||
|
||||
// Exact match
|
||||
ok, err := utils.CheckFilter("hello", "hello")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, true, ok)
|
||||
assert.Equal(t, true, utils.CheckFilter("hello", "hello"))
|
||||
|
||||
// Regex match
|
||||
ok, err = utils.CheckFilter("/^h.*o$/", "hello")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, true, ok)
|
||||
assert.Equal(t, true, utils.CheckFilter("/^h.*o$/", "hello"))
|
||||
|
||||
// Invalid regex
|
||||
ok, err = utils.CheckFilter("/[unclosed/", "test")
|
||||
assert.ErrorContains(t, err, "invalid regex")
|
||||
assert.Equal(t, false, ok)
|
||||
assert.Equal(t, false, utils.CheckFilter("/[unclosed", "test"))
|
||||
|
||||
// Comma-separated values
|
||||
ok, err = utils.CheckFilter("apple, banana, cherry", "banana")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, true, ok)
|
||||
assert.Equal(t, true, utils.CheckFilter("apple, banana, cherry", "banana"))
|
||||
|
||||
// No match
|
||||
ok, err = utils.CheckFilter("apple, banana, cherry", "grape")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, false, ok)
|
||||
assert.Equal(t, false, utils.CheckFilter("apple, banana, cherry", "grape"))
|
||||
}
|
||||
|
||||
func TestGenerateUUID(t *testing.T) {
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
version: "2"
|
||||
sql:
|
||||
- engine: "sqlite"
|
||||
queries: "sql/*_queries.sql"
|
||||
schema: "sql/*_schemas.sql"
|
||||
queries: "sql/sqlite/*_queries.sql"
|
||||
schema: "sql/sqlite/*_schemas.sql"
|
||||
gen:
|
||||
go:
|
||||
package: "repository"
|
||||
out: "internal/repository"
|
||||
package: "sqlite"
|
||||
out: "internal/repository/sqlite"
|
||||
rename:
|
||||
uuid: "UUID"
|
||||
oauth_groups: "OAuthGroups"
|
||||
|
||||
Reference in New Issue
Block a user