mirror of
https://github.com/steveiliop56/tinyauth.git
synced 2026-05-19 18:50:14 +00:00
Compare commits
12 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| c855f9b8ac | |||
| a56c349525 | |||
| 8b4ba23328 | |||
| 8932f2ad46 | |||
| 482ba9d99f | |||
| 1bcd1bb59a | |||
| 5349f21212 | |||
| e8071a9d80 | |||
| 1f67797605 | |||
| ca06099466 | |||
| d4b4245017 | |||
| 4c741a5990 |
@@ -1,38 +0,0 @@
|
|||||||
---
|
|
||||||
name: Bug report
|
|
||||||
about: Create a report to help improve Tinyauth
|
|
||||||
title: "[BUG]"
|
|
||||||
labels: bug
|
|
||||||
assignees:
|
|
||||||
- steveiliop56
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
**Describe the bug**
|
|
||||||
A clear and concise description of what the bug is.
|
|
||||||
|
|
||||||
**To Reproduce**
|
|
||||||
Steps to reproduce the behavior:
|
|
||||||
1. Go to '...'
|
|
||||||
2. Click on '....'
|
|
||||||
3. Scroll down to '....'
|
|
||||||
4. See error
|
|
||||||
|
|
||||||
**Expected behavior**
|
|
||||||
A clear and concise description of what you expected to happen.
|
|
||||||
|
|
||||||
**Screenshots**
|
|
||||||
If applicable, add screenshots to help explain your problem.
|
|
||||||
|
|
||||||
**Logs**
|
|
||||||
Please include the Tinyauth logs below, make sure to not include sensitive info.
|
|
||||||
|
|
||||||
**Device (please complete the following information):**
|
|
||||||
- OS: [e.g. iOS]
|
|
||||||
- Browser [e.g. chrome, safari]
|
|
||||||
- Tinyauth [e.g. v2.1.1]
|
|
||||||
- Docker [e.g. 27.3.1]
|
|
||||||
|
|
||||||
**
|
|
||||||
**Additional context**
|
|
||||||
Add any other context about the problem here.
|
|
||||||
@@ -0,0 +1,89 @@
|
|||||||
|
name: Bug Report
|
||||||
|
description: Create a report to help us improve this project
|
||||||
|
title: "[BUG]"
|
||||||
|
labels: bug
|
||||||
|
assignees:
|
||||||
|
- steveiliop56
|
||||||
|
|
||||||
|
body:
|
||||||
|
- type: markdown
|
||||||
|
attributes:
|
||||||
|
value: |
|
||||||
|
Thanks for reporting a bug! Please provide detailed information below.
|
||||||
|
|
||||||
|
- type: textarea
|
||||||
|
id: description
|
||||||
|
attributes:
|
||||||
|
label: Describe the Bug
|
||||||
|
description: "A clear and concise description of what the bug is."
|
||||||
|
validations:
|
||||||
|
required: true
|
||||||
|
|
||||||
|
- type: textarea
|
||||||
|
id: reproduce
|
||||||
|
attributes:
|
||||||
|
label: How to Reproduce
|
||||||
|
description: Steps to reproduce the behavior.
|
||||||
|
value: |
|
||||||
|
1. Go to '...'
|
||||||
|
2. Click on '....'
|
||||||
|
3. Scroll down to '....'
|
||||||
|
4. See error
|
||||||
|
validations:
|
||||||
|
required: false
|
||||||
|
|
||||||
|
- type: textarea
|
||||||
|
id: expected
|
||||||
|
attributes:
|
||||||
|
label: Expected Behavior
|
||||||
|
description: "A clear and concise description of what you expected to happen."
|
||||||
|
validations:
|
||||||
|
required: true
|
||||||
|
|
||||||
|
- type: textarea
|
||||||
|
id: context
|
||||||
|
attributes:
|
||||||
|
label: "Additional Context"
|
||||||
|
description: "If applicable add screenshots to help explain your problem."
|
||||||
|
validations:
|
||||||
|
required: false
|
||||||
|
|
||||||
|
- type: textarea
|
||||||
|
id: logs
|
||||||
|
attributes:
|
||||||
|
label: "Logs"
|
||||||
|
description: "Please include the Tinyauth logs, make sure to not include sensitive info."
|
||||||
|
validations:
|
||||||
|
required: false
|
||||||
|
|
||||||
|
- type: input
|
||||||
|
id: os
|
||||||
|
attributes:
|
||||||
|
label: Operating System
|
||||||
|
placeholder: "e.g. iOS, Android, Windows, Linux, etc"
|
||||||
|
|
||||||
|
- type: input
|
||||||
|
id: browser
|
||||||
|
attributes:
|
||||||
|
label: Browser
|
||||||
|
placeholder: "e.g. Chrome, Firefox, Safari, Edge, etc"
|
||||||
|
|
||||||
|
- type: input
|
||||||
|
id: tinyauth
|
||||||
|
attributes:
|
||||||
|
label: Tinyauth Version
|
||||||
|
placeholder: "e.g. v5.0.0"
|
||||||
|
|
||||||
|
- type: input
|
||||||
|
id: docker
|
||||||
|
attributes:
|
||||||
|
label: Docker Version (if applicable)
|
||||||
|
placeholder: "e.g. 27.3.1"
|
||||||
|
|
||||||
|
- type: checkboxes
|
||||||
|
id: not-llm
|
||||||
|
attributes:
|
||||||
|
label: Human Written Confirmation
|
||||||
|
options:
|
||||||
|
- label: I confirm this issue was written by me and not generated by an LLM or AI assistant.
|
||||||
|
required: true
|
||||||
@@ -0,0 +1,8 @@
|
|||||||
|
blank_issues_enabled: true
|
||||||
|
contact_links:
|
||||||
|
- name: Tinyauth Community Support on Discord
|
||||||
|
url: https://discord.gg/eHzVaCzRRd
|
||||||
|
about: Please ask and answer questions here.
|
||||||
|
- name: Tinyauth Documentation
|
||||||
|
url: https://tinyauth.app/docs/getting-started/
|
||||||
|
about: Please check the documentation here.
|
||||||
@@ -1,21 +0,0 @@
|
|||||||
---
|
|
||||||
name: Feature request
|
|
||||||
about: Suggest an idea for this project
|
|
||||||
title: "[FEATURE]"
|
|
||||||
labels: enhancement
|
|
||||||
assignees:
|
|
||||||
- steveiliop56
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
**Is your feature request related to a problem? Please describe.**
|
|
||||||
A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
|
|
||||||
|
|
||||||
**Describe the solution you'd like**
|
|
||||||
A clear and concise description of what you want to happen.
|
|
||||||
|
|
||||||
**Describe alternatives you've considered**
|
|
||||||
A clear and concise description of any alternative solutions or features you've considered.
|
|
||||||
|
|
||||||
**Additional context**
|
|
||||||
Add any other context or screenshots about the feature request here.
|
|
||||||
@@ -0,0 +1,52 @@
|
|||||||
|
name: Feature request
|
||||||
|
description: Suggest an idea for this project
|
||||||
|
title: "[FEATURE]"
|
||||||
|
labels: enhancement
|
||||||
|
assignees:
|
||||||
|
- steveiliop56
|
||||||
|
|
||||||
|
body:
|
||||||
|
- type: markdown
|
||||||
|
attributes:
|
||||||
|
value: |
|
||||||
|
Thanks for suggesting a feature! Please provide detailed information below.
|
||||||
|
|
||||||
|
- type: textarea
|
||||||
|
id: problem
|
||||||
|
attributes:
|
||||||
|
label: Is your feature request related to a problem? Please describe.
|
||||||
|
description: "A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]"
|
||||||
|
validations:
|
||||||
|
required: false
|
||||||
|
|
||||||
|
- type: textarea
|
||||||
|
id: solution
|
||||||
|
attributes:
|
||||||
|
label: Describe the solution you'd like.
|
||||||
|
description: "A clear and concise description of what you want to happen."
|
||||||
|
validations:
|
||||||
|
required: true
|
||||||
|
|
||||||
|
- type: textarea
|
||||||
|
id: alternatives
|
||||||
|
attributes:
|
||||||
|
label: Describe alternatives you've considered.
|
||||||
|
description: "A clear and concise description of any alternative solutions or features you've considered."
|
||||||
|
validations:
|
||||||
|
required: false
|
||||||
|
|
||||||
|
- type: textarea
|
||||||
|
id: context
|
||||||
|
attributes:
|
||||||
|
label: Additional context
|
||||||
|
description: "Add any other context or screenshots about the feature request here."
|
||||||
|
validations:
|
||||||
|
required: false
|
||||||
|
|
||||||
|
- type: checkboxes
|
||||||
|
id: not-llm
|
||||||
|
attributes:
|
||||||
|
label: Human Written Confirmation
|
||||||
|
options:
|
||||||
|
- label: I confirm this request was written by me and not generated by an LLM or AI assistant.
|
||||||
|
required: true
|
||||||
@@ -28,6 +28,18 @@ jobs:
|
|||||||
- name: Go dependencies
|
- name: Go dependencies
|
||||||
run: go mod download
|
run: go mod download
|
||||||
|
|
||||||
|
- name: Setup sqlc
|
||||||
|
uses: sqlc-dev/setup-sqlc@v4
|
||||||
|
with:
|
||||||
|
sqlc-version: "1.31.1"
|
||||||
|
|
||||||
|
- name: Check codegen is up to date
|
||||||
|
run: |
|
||||||
|
sqlc generate
|
||||||
|
go generate ./internal/repository/...
|
||||||
|
git diff --exit-code -- internal/repository/
|
||||||
|
git status --porcelain -- internal/repository/ | grep -q . && echo "untracked files in internal/repository/" && exit 1 || true
|
||||||
|
|
||||||
- name: Install frontend dependencies
|
- name: Install frontend dependencies
|
||||||
working-directory: ./frontend
|
working-directory: ./frontend
|
||||||
run: pnpm ci
|
run: pnpm ci
|
||||||
|
|||||||
@@ -38,6 +38,6 @@ jobs:
|
|||||||
retention-days: 5
|
retention-days: 5
|
||||||
|
|
||||||
- name: Upload to code-scanning
|
- name: Upload to code-scanning
|
||||||
uses: github/codeql-action/upload-sarif@9e0d7b8d25671d64c341c19c0152d693099fb5ba # v4
|
uses: github/codeql-action/upload-sarif@68bde559dea0fdcac2102bfdf6230c5f70eb485e # v4
|
||||||
with:
|
with:
|
||||||
sarif_file: results.sarif
|
sarif_file: results.sarif
|
||||||
|
|||||||
@@ -1,24 +0,0 @@
|
|||||||
name: Close stale issues and PRs
|
|
||||||
on:
|
|
||||||
schedule:
|
|
||||||
- cron: 0 10 * * *
|
|
||||||
|
|
||||||
permissions:
|
|
||||||
issues: write
|
|
||||||
pull-requests: write
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
stale:
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
steps:
|
|
||||||
- uses: actions/stale@b5d41d4e1d5dceea10e7104786b73624c18a190f # v10
|
|
||||||
with:
|
|
||||||
days-before-stale: 30
|
|
||||||
stale-pr-message: This PR has been inactive for 30 days and will be marked as stale.
|
|
||||||
stale-issue-message: This issue has been inactive for 30 days and will be marked as stale.
|
|
||||||
close-issue-message: Closed for inactivity.
|
|
||||||
close-pr-message: Closed for inactivity.
|
|
||||||
stale-issue-label: stale
|
|
||||||
stale-pr-label: stale
|
|
||||||
exempt-issue-labels: pinned
|
|
||||||
exempt-pr-labels: pinned
|
|
||||||
@@ -85,3 +85,4 @@ sql:
|
|||||||
# Go gen
|
# Go gen
|
||||||
generate:
|
generate:
|
||||||
go run ./gen
|
go run ./gen
|
||||||
|
go generate ./internal/repository/...
|
||||||
|
|||||||
@@ -1,9 +1,10 @@
|
|||||||
services:
|
services:
|
||||||
traefik:
|
traefik:
|
||||||
image: traefik:v3.6
|
image: traefik:v3.6
|
||||||
command: --api.insecure=true --providers.docker
|
command: --api.insecure=true --providers.docker --entrypoints.web.address=:80 --entrypoints.websecure.address=:443
|
||||||
ports:
|
ports:
|
||||||
- 80:80
|
- 80:80
|
||||||
|
- 443:443
|
||||||
volumes:
|
volumes:
|
||||||
- /var/run/docker.sock:/var/run/docker.sock
|
- /var/run/docker.sock:/var/run/docker.sock
|
||||||
|
|
||||||
@@ -25,6 +26,8 @@ services:
|
|||||||
labels:
|
labels:
|
||||||
traefik.enable: true
|
traefik.enable: true
|
||||||
traefik.http.routers.tinyauth.rule: Host(`tinyauth.127.0.0.1.sslip.io`)
|
traefik.http.routers.tinyauth.rule: Host(`tinyauth.127.0.0.1.sslip.io`)
|
||||||
|
traefik.http.routers.tinyauth.entrypoints: websecure
|
||||||
|
traefik.http.routers.tinyauth.tls: true
|
||||||
|
|
||||||
tinyauth-backend:
|
tinyauth-backend:
|
||||||
build:
|
build:
|
||||||
|
|||||||
@@ -0,0 +1,473 @@
|
|||||||
|
// gen/sqlc-wrapper generates store.go wrapper files for each sqlc driver package under
|
||||||
|
// internal/repository/<driver>/. Run via:
|
||||||
|
//
|
||||||
|
// go generate ./internal/repository/...
|
||||||
|
//
|
||||||
|
// The generator introspects *Queries methods and the model/params types in the
|
||||||
|
// driver package, then emits a store.go that wraps *Queries so it satisfies
|
||||||
|
// repository.Store using the canonical shared types in the parent package.
|
||||||
|
// This generator is specific to sqlc-generated drivers. Non-sqlc drivers should
|
||||||
|
// implement repository.Store directly by hand.
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
_ "embed"
|
||||||
|
"flag"
|
||||||
|
"fmt"
|
||||||
|
"go/format"
|
||||||
|
"go/types"
|
||||||
|
"log"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"path/filepath"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
"text/template"
|
||||||
|
|
||||||
|
"golang.org/x/tools/go/packages"
|
||||||
|
)
|
||||||
|
|
||||||
|
//go:embed store.tmpl
|
||||||
|
var storeSrc string
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
fmt.Println("sqlc-wrapper: generating store.go files for sqlc driver packages...")
|
||||||
|
if err := run(); err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func run() error {
|
||||||
|
driverPkg := flag.String("pkg", "", "import path of the driver package")
|
||||||
|
out := flag.String("out", "store.go", "output filename relative to driver package directory")
|
||||||
|
flag.Parse()
|
||||||
|
|
||||||
|
if *driverPkg == "" {
|
||||||
|
return fmt.Errorf("-pkg is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Resolve the driver package directory so we can overlay the output file
|
||||||
|
// with a valid stub. This prevents a stale store.go from poisoning the
|
||||||
|
// type-checker and producing cryptic "undefined" errors.
|
||||||
|
driverDir, err := pkgDir(*driverPkg)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("resolve driver dir: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
outPath := filepath.Join(driverDir, *out)
|
||||||
|
if filepath.IsAbs(*out) {
|
||||||
|
outPath = *out
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stub replaces the output file during load so stale generated code is ignored.
|
||||||
|
stub := []byte("package " + filepath.Base(driverDir) + "\n")
|
||||||
|
cfg := &packages.Config{
|
||||||
|
Mode: packages.NeedName | packages.NeedTypes | packages.NeedSyntax | packages.NeedImports,
|
||||||
|
Overlay: map[string][]byte{outPath: stub},
|
||||||
|
}
|
||||||
|
|
||||||
|
driverTypePkg, err := loadOnePkg(cfg, *driverPkg)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("load driver package: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
repoPkgPath := parentPkg(*driverPkg)
|
||||||
|
repoTypePkg, err := loadOnePkg(cfg, repoPkgPath)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("load repo package: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := validateStructShapes(driverTypePkg, repoTypePkg); err != nil {
|
||||||
|
return fmt.Errorf("struct shape mismatch: %w", err)
|
||||||
|
}
|
||||||
|
if err := validateStoreCoverage(driverTypePkg, repoTypePkg); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
methods, err := collectMethods(driverTypePkg)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
src, err := render(tmplData{
|
||||||
|
PkgName: driverTypePkg.Name(),
|
||||||
|
RepoPkg: repoPkgPath,
|
||||||
|
Methods: renderMethods(methods),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("render: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := os.WriteFile(outPath, src, 0644); err != nil {
|
||||||
|
return fmt.Errorf("write %s: %w", outPath, err)
|
||||||
|
}
|
||||||
|
fmt.Printf("wrote %s\n", outPath)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// loadOnePkg loads a single package via cfg and returns its *types.Package,
|
||||||
|
// or an error if the package fails to load or has type errors.
|
||||||
|
func loadOnePkg(cfg *packages.Config, importPath string) (*types.Package, error) {
|
||||||
|
pkgs, err := packages.Load(cfg, importPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("load %s: %w", importPath, err)
|
||||||
|
}
|
||||||
|
if len(pkgs) != 1 {
|
||||||
|
return nil, fmt.Errorf("expected 1 package for %s, got %d", importPath, len(pkgs))
|
||||||
|
}
|
||||||
|
pkg := pkgs[0]
|
||||||
|
if len(pkg.Errors) > 0 {
|
||||||
|
msgs := make([]string, len(pkg.Errors))
|
||||||
|
for i, e := range pkg.Errors {
|
||||||
|
msgs[i] = e.Error()
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("package %s has errors:\n %s", importPath, strings.Join(msgs, "\n "))
|
||||||
|
}
|
||||||
|
return pkg.Types, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// parentPkg returns the parent import path (everything before the last /).
|
||||||
|
// Panics if imp contains no slash — callers are expected to pass driver sub-packages.
|
||||||
|
func parentPkg(imp string) string {
|
||||||
|
i := strings.LastIndex(imp, "/")
|
||||||
|
if i < 0 {
|
||||||
|
panic(fmt.Sprintf("parentPkg: import path %q has no parent", imp))
|
||||||
|
}
|
||||||
|
return imp[:i]
|
||||||
|
}
|
||||||
|
|
||||||
|
// pkgDir returns the on-disk directory for an import path using `go list`.
|
||||||
|
func pkgDir(importPath string) (string, error) {
|
||||||
|
out, err := exec.Command("go", "list", "-f", "{{.Dir}}", importPath).Output()
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("go list %s: %w", importPath, err)
|
||||||
|
}
|
||||||
|
return strings.TrimSpace(string(out)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// scopeStructs returns all named struct types in pkg, excluding the internal
|
||||||
|
// sqlc types Queries, DBTX, and Store. Names are returned in sorted order.
|
||||||
|
func scopeStructs(pkg *types.Package) (names []string, byName map[string]*types.Struct) {
|
||||||
|
byName = make(map[string]*types.Struct)
|
||||||
|
for _, name := range pkg.Scope().Names() { // Names() is already sorted
|
||||||
|
switch name {
|
||||||
|
case "Queries", "DBTX", "Store":
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
obj, ok := pkg.Scope().Lookup(name).(*types.TypeName)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
named, ok := obj.Type().(*types.Named)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
s, ok := named.Underlying().(*types.Struct)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
names = append(names, name)
|
||||||
|
byName[name] = s
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// validateStoreCoverage checks that every method declared in repository.Store
|
||||||
|
// exists on *Queries in the driver package. Missing methods are reported by
|
||||||
|
// name so the developer knows exactly which SQL queries need to be added.
|
||||||
|
func validateStoreCoverage(driverPkg, repoPkg *types.Package) error {
|
||||||
|
queriesObj := driverPkg.Scope().Lookup("Queries")
|
||||||
|
if queriesObj == nil {
|
||||||
|
return fmt.Errorf("queries type not found in driver package")
|
||||||
|
}
|
||||||
|
queriesNamed := queriesObj.Type().(*types.Named)
|
||||||
|
queriesMS := types.NewMethodSet(types.NewPointer(queriesNamed))
|
||||||
|
queriesMethods := make(map[string]bool)
|
||||||
|
for m := range queriesMS.Methods() {
|
||||||
|
queriesMethods[m.Obj().Name()] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
storeObj := repoPkg.Scope().Lookup("Store")
|
||||||
|
if storeObj == nil {
|
||||||
|
return fmt.Errorf("store type not found in repository package")
|
||||||
|
}
|
||||||
|
storeIface, ok := storeObj.Type().Underlying().(*types.Interface)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("repository.Store is not an interface")
|
||||||
|
}
|
||||||
|
|
||||||
|
var missing []string
|
||||||
|
for method := range storeIface.Methods() {
|
||||||
|
if name := method.Name(); !queriesMethods[name] {
|
||||||
|
missing = append(missing, name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(missing) > 0 {
|
||||||
|
sort.Strings(missing)
|
||||||
|
return fmt.Errorf(
|
||||||
|
"driver *Queries is missing %d method(s) required by repository.Store:\n - %s\n\nRun sqlc generate to regenerate query methods, or add the missing SQL queries",
|
||||||
|
len(missing), strings.Join(missing, "\n - "),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// validateStructShapes checks that every model/params struct in the driver
|
||||||
|
// package has fields that exactly match the corresponding type in the repo
|
||||||
|
// (parent) package. This catches drift between sqlc-generated types and the
|
||||||
|
// canonical repository types before a broken cast reaches the compiler.
|
||||||
|
func validateStructShapes(driverPkg, repoPkg *types.Package) error {
|
||||||
|
_, repoStructs := scopeStructs(repoPkg)
|
||||||
|
driverNames, driverStructs := scopeStructs(driverPkg)
|
||||||
|
|
||||||
|
var errs []string
|
||||||
|
for _, name := range driverNames {
|
||||||
|
repoStruct, ok := repoStructs[name]
|
||||||
|
if !ok {
|
||||||
|
// Driver has a type not in repo — fine (e.g. internal helpers).
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := compareStructs(name, driverStructs[name], repoStruct); err != nil {
|
||||||
|
errs = append(errs, err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(errs) > 0 {
|
||||||
|
sort.Strings(errs)
|
||||||
|
return fmt.Errorf("%s", strings.Join(errs, "\n "))
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func compareStructs(name string, driver, repo *types.Struct) error {
|
||||||
|
if driver.NumFields() != repo.NumFields() {
|
||||||
|
return fmt.Errorf("%s: field count mismatch (driver=%d, repo=%d)",
|
||||||
|
name, driver.NumFields(), repo.NumFields())
|
||||||
|
}
|
||||||
|
for i := range driver.NumFields() {
|
||||||
|
df := driver.Field(i)
|
||||||
|
rf := repo.Field(i)
|
||||||
|
if df.Name() != rf.Name() {
|
||||||
|
return fmt.Errorf("%s: field %d name mismatch (driver=%q, repo=%q)",
|
||||||
|
name, i, df.Name(), rf.Name())
|
||||||
|
}
|
||||||
|
if !types.Identical(df.Type(), rf.Type()) {
|
||||||
|
return fmt.Errorf("%s.%s: type mismatch (driver=%s, repo=%s)",
|
||||||
|
name, df.Name(), df.Type(), rf.Type())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type methodInfo struct {
|
||||||
|
Name string
|
||||||
|
Params []paramInfo
|
||||||
|
Results []resultInfo
|
||||||
|
}
|
||||||
|
|
||||||
|
type paramInfo struct {
|
||||||
|
Name string
|
||||||
|
TypeStr string // local (unqualified) type name
|
||||||
|
RepoType string // "repository.X" if this is a driver model/params type; else ""
|
||||||
|
}
|
||||||
|
|
||||||
|
type resultInfo struct {
|
||||||
|
TypeStr string
|
||||||
|
IsSlice bool
|
||||||
|
RepoType string // "repository.X" if driver type; else ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func collectMethods(pkg *types.Package) ([]methodInfo, error) {
|
||||||
|
obj := pkg.Scope().Lookup("Queries")
|
||||||
|
if obj == nil {
|
||||||
|
return nil, fmt.Errorf("queries type not found in %s", pkg.Path())
|
||||||
|
}
|
||||||
|
named, ok := obj.Type().(*types.Named)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("queries is not a named type")
|
||||||
|
}
|
||||||
|
ms := types.NewMethodSet(types.NewPointer(named))
|
||||||
|
|
||||||
|
var out []methodInfo
|
||||||
|
for method := range ms.Methods() {
|
||||||
|
fn, ok := method.Obj().(*types.Func)
|
||||||
|
if !ok || fn.Name() == "WithTx" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
sig := fn.Type().(*types.Signature)
|
||||||
|
mi := methodInfo{Name: fn.Name()}
|
||||||
|
|
||||||
|
// params: skip receiver + first (context.Context)
|
||||||
|
for i := 1; i < sig.Params().Len(); i++ {
|
||||||
|
p := sig.Params().At(i)
|
||||||
|
mi.Params = append(mi.Params, makeParam(p.Name(), p.Type(), pkg.Path()))
|
||||||
|
}
|
||||||
|
// results: skip error
|
||||||
|
for r := range sig.Results().Variables() {
|
||||||
|
if r.Type().String() == "error" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
mi.Results = append(mi.Results, makeResult(r.Type(), pkg.Path()))
|
||||||
|
}
|
||||||
|
out = append(out, mi)
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func makeParam(name string, t types.Type, driverPath string) paramInfo {
|
||||||
|
return paramInfo{
|
||||||
|
Name: name,
|
||||||
|
TypeStr: localName(t, driverPath),
|
||||||
|
RepoType: repoName(t, driverPath),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func makeResult(t types.Type, driverPath string) resultInfo {
|
||||||
|
ri := resultInfo{}
|
||||||
|
if sl, ok := t.(*types.Slice); ok {
|
||||||
|
ri.IsSlice = true
|
||||||
|
t = sl.Elem()
|
||||||
|
}
|
||||||
|
ri.TypeStr = localName(t, driverPath)
|
||||||
|
ri.RepoType = repoName(t, driverPath)
|
||||||
|
return ri
|
||||||
|
}
|
||||||
|
|
||||||
|
func localName(t types.Type, driverPath string) string {
|
||||||
|
named, ok := t.(*types.Named)
|
||||||
|
if !ok {
|
||||||
|
return types.TypeString(t, nil)
|
||||||
|
}
|
||||||
|
if named.Obj().Pkg() != nil && named.Obj().Pkg().Path() == driverPath {
|
||||||
|
return named.Obj().Name()
|
||||||
|
}
|
||||||
|
return types.TypeString(t, func(p *types.Package) string { return p.Name() })
|
||||||
|
}
|
||||||
|
|
||||||
|
func repoName(t types.Type, driverPath string) string {
|
||||||
|
named, ok := t.(*types.Named)
|
||||||
|
if !ok {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if named.Obj().Pkg() != nil && named.Obj().Pkg().Path() == driverPath {
|
||||||
|
return "repository." + named.Obj().Name()
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// renderedMethod holds pre-built signature and body strings passed to the template.
|
||||||
|
type renderedMethod struct {
|
||||||
|
Signature string
|
||||||
|
Body string
|
||||||
|
}
|
||||||
|
|
||||||
|
func renderMethods(methods []methodInfo) []renderedMethod {
|
||||||
|
out := make([]renderedMethod, len(methods))
|
||||||
|
for i, m := range methods {
|
||||||
|
out[i] = renderedMethod{
|
||||||
|
Signature: buildSig(m),
|
||||||
|
Body: buildBody(m),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildSig(m methodInfo) string {
|
||||||
|
var sb strings.Builder
|
||||||
|
sb.WriteString("func (s *Store) ")
|
||||||
|
sb.WriteString(m.Name)
|
||||||
|
sb.WriteString("(ctx context.Context")
|
||||||
|
for _, p := range m.Params {
|
||||||
|
sb.WriteString(", ")
|
||||||
|
sb.WriteString(p.Name)
|
||||||
|
sb.WriteString(" ")
|
||||||
|
if p.RepoType != "" {
|
||||||
|
sb.WriteString(p.RepoType)
|
||||||
|
} else {
|
||||||
|
sb.WriteString(p.TypeStr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sb.WriteString(") (")
|
||||||
|
for _, r := range m.Results {
|
||||||
|
if r.IsSlice {
|
||||||
|
sb.WriteString("[]")
|
||||||
|
}
|
||||||
|
if r.RepoType != "" {
|
||||||
|
sb.WriteString(r.RepoType)
|
||||||
|
} else {
|
||||||
|
sb.WriteString(r.TypeStr)
|
||||||
|
}
|
||||||
|
sb.WriteString(", ")
|
||||||
|
}
|
||||||
|
sb.WriteString("error)")
|
||||||
|
return sb.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func callArgs(m methodInfo) string {
|
||||||
|
args := make([]string, 0, len(m.Params))
|
||||||
|
for _, p := range m.Params {
|
||||||
|
if p.RepoType != "" {
|
||||||
|
// convert repo type → driver type: DriverType(arg)
|
||||||
|
args = append(args, p.TypeStr+"("+p.Name+")")
|
||||||
|
} else {
|
||||||
|
args = append(args, p.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(args) == 0 {
|
||||||
|
return "ctx"
|
||||||
|
}
|
||||||
|
return "ctx, " + strings.Join(args, ", ")
|
||||||
|
}
|
||||||
|
|
||||||
|
var bodyTmpl = template.Must(template.New("store").Parse(storeSrc))
|
||||||
|
|
||||||
|
type bodyData struct {
|
||||||
|
Call string
|
||||||
|
RepoType string
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildBody(m methodInfo) string {
|
||||||
|
call := "s.q." + m.Name + "(" + callArgs(m) + ")"
|
||||||
|
|
||||||
|
var (
|
||||||
|
name string
|
||||||
|
data bodyData
|
||||||
|
)
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case len(m.Results) == 0 || m.Results[0].RepoType == "":
|
||||||
|
name = "void"
|
||||||
|
data = bodyData{Call: call}
|
||||||
|
case m.Results[0].IsSlice:
|
||||||
|
name = "slice"
|
||||||
|
data = bodyData{Call: call, RepoType: m.Results[0].RepoType}
|
||||||
|
default:
|
||||||
|
name = "scalar"
|
||||||
|
data = bodyData{Call: call, RepoType: m.Results[0].RepoType}
|
||||||
|
}
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
if err := bodyTmpl.ExecuteTemplate(&buf, name, data); err != nil {
|
||||||
|
panic(fmt.Sprintf("buildBody %s: %v", name, err))
|
||||||
|
}
|
||||||
|
return buf.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
type tmplData struct {
|
||||||
|
PkgName string
|
||||||
|
RepoPkg string
|
||||||
|
Methods []renderedMethod
|
||||||
|
}
|
||||||
|
|
||||||
|
func render(data tmplData) ([]byte, error) {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
if err := bodyTmpl.Execute(&buf, data); err != nil {
|
||||||
|
return nil, fmt.Errorf("execute template: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
formatted, err := format.Source(buf.Bytes())
|
||||||
|
if err != nil {
|
||||||
|
return buf.Bytes(), fmt.Errorf("format source: %w\nraw:\n%s", err, buf.String())
|
||||||
|
}
|
||||||
|
return formatted, nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,59 @@
|
|||||||
|
// Code generated by cmd/gen/sqlc-wrapper. DO NOT EDIT.
|
||||||
|
package {{.PkgName}}
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"{{.RepoPkg}}"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Store wraps *Queries and implements repository.Store.
|
||||||
|
type Store struct {
|
||||||
|
q *Queries
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewStore wraps a *Queries to satisfy repository.Store.
|
||||||
|
func NewStore(q *Queries) repository.Store {
|
||||||
|
return &Store{q: q}
|
||||||
|
}
|
||||||
|
|
||||||
|
var errorMap = map[error]error{
|
||||||
|
sql.ErrNoRows: repository.ErrNotFound,
|
||||||
|
}
|
||||||
|
|
||||||
|
func mapErr(err error) error {
|
||||||
|
for from, to := range errorMap {
|
||||||
|
if errors.Is(err, from) {
|
||||||
|
return to
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
{{range .Methods}}{{.Signature}} {
|
||||||
|
{{.Body}}}
|
||||||
|
|
||||||
|
{{end}}
|
||||||
|
|
||||||
|
{{- define "void"}} return mapErr({{.Call}})
|
||||||
|
{{end}}
|
||||||
|
|
||||||
|
{{- define "scalar"}} r, err := {{.Call}}
|
||||||
|
if err != nil {
|
||||||
|
return {{.RepoType}}{}, mapErr(err)
|
||||||
|
}
|
||||||
|
return {{.RepoType}}(r), nil
|
||||||
|
{{end}}
|
||||||
|
|
||||||
|
{{- define "slice"}} rows, err := {{.Call}}
|
||||||
|
if err != nil {
|
||||||
|
return nil, mapErr(err)
|
||||||
|
}
|
||||||
|
out := make([]{{.RepoType}}, len(rows))
|
||||||
|
for i, row := range rows {
|
||||||
|
out[i] = {{.RepoType}}(row)
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
{{end}}
|
||||||
@@ -20,6 +20,7 @@ require (
|
|||||||
github.com/weppos/publicsuffix-go v0.50.3
|
github.com/weppos/publicsuffix-go v0.50.3
|
||||||
golang.org/x/crypto v0.50.0
|
golang.org/x/crypto v0.50.0
|
||||||
golang.org/x/oauth2 v0.36.0
|
golang.org/x/oauth2 v0.36.0
|
||||||
|
golang.org/x/tools v0.43.0
|
||||||
k8s.io/apimachinery v0.36.0
|
k8s.io/apimachinery v0.36.0
|
||||||
k8s.io/client-go v0.36.0
|
k8s.io/client-go v0.36.0
|
||||||
modernc.org/sqlite v1.50.0
|
modernc.org/sqlite v1.50.0
|
||||||
@@ -121,6 +122,7 @@ require (
|
|||||||
go.yaml.in/yaml/v2 v2.4.3 // indirect
|
go.yaml.in/yaml/v2 v2.4.3 // indirect
|
||||||
golang.org/x/arch v0.22.0 // indirect
|
golang.org/x/arch v0.22.0 // indirect
|
||||||
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect
|
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect
|
||||||
|
golang.org/x/mod v0.34.0 // indirect
|
||||||
golang.org/x/net v0.52.0 // indirect
|
golang.org/x/net v0.52.0 // indirect
|
||||||
golang.org/x/sync v0.20.0 // indirect
|
golang.org/x/sync v0.20.0 // indirect
|
||||||
golang.org/x/sys v0.43.0 // indirect
|
golang.org/x/sys v0.43.0 // indirect
|
||||||
|
|||||||
@@ -11,5 +11,5 @@ var FrontendAssets embed.FS
|
|||||||
|
|
||||||
// Migrations
|
// Migrations
|
||||||
//
|
//
|
||||||
//go:embed migrations/*.sql
|
//go:embed migrations/sqlite/*.sql
|
||||||
var Migrations embed.FS
|
var Migrations embed.FS
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||||
@@ -34,6 +35,7 @@ type Services struct {
|
|||||||
ldapService *service.LdapService
|
ldapService *service.LdapService
|
||||||
oauthBrokerService *service.OAuthBrokerService
|
oauthBrokerService *service.OAuthBrokerService
|
||||||
oidcService *service.OIDCService
|
oidcService *service.OIDCService
|
||||||
|
policyEngine *service.PolicyEngine
|
||||||
}
|
}
|
||||||
|
|
||||||
type BootstrapApp struct {
|
type BootstrapApp struct {
|
||||||
@@ -43,7 +45,7 @@ type BootstrapApp struct {
|
|||||||
log *logger.Logger
|
log *logger.Logger
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
cancel context.CancelFunc
|
cancel context.CancelFunc
|
||||||
queries *repository.Queries
|
queries repository.Store
|
||||||
router *gin.Engine
|
router *gin.Engine
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
wg sync.WaitGroup
|
wg sync.WaitGroup
|
||||||
@@ -162,7 +164,7 @@ func (app *BootstrapApp) Setup() error {
|
|||||||
app.runtime.OAuthSessionCookieName = fmt.Sprintf("%s-%s", model.OAuthSessionCookieName, cookieId)
|
app.runtime.OAuthSessionCookieName = fmt.Sprintf("%s-%s", model.OAuthSessionCookieName, cookieId)
|
||||||
|
|
||||||
// database
|
// database
|
||||||
err = app.SetupDatabase()
|
store, err := app.SetupStore()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to setup database: %w", err)
|
return fmt.Errorf("failed to setup database: %w", err)
|
||||||
@@ -173,12 +175,13 @@ func (app *BootstrapApp) Setup() error {
|
|||||||
defer func() {
|
defer func() {
|
||||||
app.cancel()
|
app.cancel()
|
||||||
app.wg.Wait()
|
app.wg.Wait()
|
||||||
|
if app.db != nil {
|
||||||
app.db.Close()
|
app.db.Close()
|
||||||
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// queries
|
// store
|
||||||
queries := repository.New(app.db)
|
app.queries = store
|
||||||
app.queries = queries
|
|
||||||
|
|
||||||
// services
|
// services
|
||||||
err = app.setupServices()
|
err = app.setupServices()
|
||||||
|
|||||||
@@ -7,6 +7,9 @@ import (
|
|||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/assets"
|
"github.com/tinyauthapp/tinyauth/internal/assets"
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/repository/sqlite"
|
||||||
|
|
||||||
"github.com/golang-migrate/migrate/v4"
|
"github.com/golang-migrate/migrate/v4"
|
||||||
"github.com/golang-migrate/migrate/v4/database/sqlite3"
|
"github.com/golang-migrate/migrate/v4/database/sqlite3"
|
||||||
@@ -14,17 +17,28 @@ import (
|
|||||||
_ "modernc.org/sqlite"
|
_ "modernc.org/sqlite"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (app *BootstrapApp) SetupDatabase() error {
|
func (app *BootstrapApp) SetupStore() (repository.Store, error) {
|
||||||
dir := filepath.Dir(app.config.Database.Path)
|
switch app.config.Database.Driver {
|
||||||
|
case "memory":
|
||||||
|
return memory.New(), nil
|
||||||
|
case "sqlite", "":
|
||||||
|
return app.setupSQLite(app.config.Database.Path)
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unknown database driver %q: valid values are sqlite, memory", app.config.Database.Driver)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (app *BootstrapApp) setupSQLite(databasePath string) (repository.Store, error) {
|
||||||
|
dir := filepath.Dir(databasePath)
|
||||||
|
|
||||||
if err := os.MkdirAll(dir, 0750); err != nil {
|
if err := os.MkdirAll(dir, 0750); err != nil {
|
||||||
return fmt.Errorf("failed to create database directory %s: %w", dir, err)
|
return nil, fmt.Errorf("failed to create database directory %s: %w", dir, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
db, err := sql.Open("sqlite", app.config.Database.Path)
|
db, err := sql.Open("sqlite", databasePath)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to open database: %w", err)
|
return nil, fmt.Errorf("failed to open database: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close the database if there is an error during migration
|
// Close the database if there is an error during migration
|
||||||
@@ -38,32 +52,29 @@ func (app *BootstrapApp) SetupDatabase() error {
|
|||||||
// if the sqlite connection starts being a bottleneck
|
// if the sqlite connection starts being a bottleneck
|
||||||
db.SetMaxOpenConns(1)
|
db.SetMaxOpenConns(1)
|
||||||
|
|
||||||
migrations, err := iofs.New(assets.Migrations, "migrations")
|
migrations, err := iofs.New(assets.Migrations, "migrations/sqlite")
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to create migrations: %w", err)
|
return nil, fmt.Errorf("failed to create migrations: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
target, err := sqlite3.WithInstance(db, &sqlite3.Config{})
|
target, err := sqlite3.WithInstance(db, &sqlite3.Config{})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to create sqlite3 instance: %w", err)
|
return nil, fmt.Errorf("failed to create sqlite3 instance: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
migrator, err := migrate.NewWithInstance("iofs", migrations, "sqlite3", target)
|
migrator, err := migrate.NewWithInstance("iofs", migrations, "sqlite3", target)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to create migrator: %w", err)
|
return nil, fmt.Errorf("failed to create migrator: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := migrator.Up(); err != nil && err != migrate.ErrNoChange {
|
if err := migrator.Up(); err != nil && err != migrate.ErrNoChange {
|
||||||
return fmt.Errorf("failed to migrate database: %w", err)
|
return nil, fmt.Errorf("failed to migrate database: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
app.db = db
|
app.db = db
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (app *BootstrapApp) GetDB() *sql.DB {
|
return sqlite.NewStore(sqlite.New(db)), nil
|
||||||
return app.db
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -44,7 +44,7 @@ func (app *BootstrapApp) setupRouter() error {
|
|||||||
controller.NewContextController(app.log, app.config, app.runtime, apiRouter)
|
controller.NewContextController(app.log, app.config, app.runtime, apiRouter)
|
||||||
controller.NewOAuthController(app.log, app.config, app.runtime, apiRouter, app.services.authService)
|
controller.NewOAuthController(app.log, app.config, app.runtime, apiRouter, app.services.authService)
|
||||||
controller.NewOIDCController(app.log, app.services.oidcService, app.runtime, apiRouter)
|
controller.NewOIDCController(app.log, app.services.oidcService, app.runtime, apiRouter)
|
||||||
controller.NewProxyController(app.log, app.runtime, apiRouter, app.services.accessControlService, app.services.authService)
|
controller.NewProxyController(app.log, app.runtime, apiRouter, app.services.accessControlService, app.services.authService, app.services.policyEngine)
|
||||||
controller.NewUserController(app.log, app.runtime, apiRouter, app.services.authService)
|
controller.NewUserController(app.log, app.runtime, apiRouter, app.services.authService)
|
||||||
controller.NewResourcesController(app.config, &engine.RouterGroup)
|
controller.NewResourcesController(app.config, &engine.RouterGroup)
|
||||||
controller.NewHealthController(apiRouter)
|
controller.NewHealthController(apiRouter)
|
||||||
|
|||||||
@@ -16,38 +16,21 @@ func (app *BootstrapApp) setupServices() error {
|
|||||||
|
|
||||||
app.services.ldapService = ldapService
|
app.services.ldapService = ldapService
|
||||||
|
|
||||||
useKubernetes := app.config.LabelProvider == "kubernetes" ||
|
labelProvider, err := app.getLabelProvider()
|
||||||
(app.config.LabelProvider == "auto" && os.Getenv("KUBERNETES_SERVICE_HOST") != "")
|
|
||||||
|
|
||||||
var labelProvider service.LabelProvider
|
|
||||||
|
|
||||||
if useKubernetes {
|
|
||||||
app.log.App.Debug().Msg("Using Kubernetes label provider")
|
|
||||||
|
|
||||||
kubernetesService, err := service.NewKubernetesService(app.log, app.ctx, &app.wg)
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to initialize kubernetes service: %w", err)
|
return fmt.Errorf("failed to initialize label provider: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
app.services.kubernetesService = kubernetesService
|
accessControlsService := service.NewAccessControlsService(app.log, app.config, &labelProvider)
|
||||||
labelProvider = kubernetesService
|
|
||||||
} else {
|
|
||||||
app.log.App.Debug().Msg("Using Docker label provider")
|
|
||||||
|
|
||||||
dockerService, err := service.NewDockerService(app.log, app.ctx, &app.wg)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to initialize docker service: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
app.services.dockerService = dockerService
|
|
||||||
labelProvider = dockerService
|
|
||||||
}
|
|
||||||
|
|
||||||
accessControlsService := service.NewAccessControlsService(app.log, &labelProvider, app.config.Apps)
|
|
||||||
app.services.accessControlService = accessControlsService
|
app.services.accessControlService = accessControlsService
|
||||||
|
|
||||||
|
err = app.setupPolicyEngine()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to initialize policy engine: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
oauthBrokerService := service.NewOAuthBrokerService(app.log, app.runtime.OAuthProviders, app.ctx)
|
oauthBrokerService := service.NewOAuthBrokerService(app.log, app.runtime.OAuthProviders, app.ctx)
|
||||||
app.services.oauthBrokerService = oauthBrokerService
|
app.services.oauthBrokerService = oauthBrokerService
|
||||||
|
|
||||||
@@ -64,3 +47,79 @@ func (app *BootstrapApp) setupServices() error {
|
|||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (app *BootstrapApp) getLabelProvider() (service.LabelProvider, error) {
|
||||||
|
switch app.config.LabelProvider {
|
||||||
|
case "none", "docker", "kubernetes", "auto":
|
||||||
|
if app.config.LabelProvider == "none" {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
useKubernetes := app.config.LabelProvider == "kubernetes" ||
|
||||||
|
(app.config.LabelProvider == "auto" && os.Getenv("KUBERNETES_SERVICE_HOST") != "")
|
||||||
|
|
||||||
|
if useKubernetes {
|
||||||
|
app.log.App.Debug().Msg("Using Kubernetes label provider")
|
||||||
|
|
||||||
|
kubernetesService, err := service.NewKubernetesService(app.log, app.ctx, &app.wg)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to initialize kubernetes service: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
app.services.kubernetesService = kubernetesService
|
||||||
|
return kubernetesService, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
app.log.App.Debug().Msg("Using Docker label provider")
|
||||||
|
|
||||||
|
dockerService, err := service.NewDockerService(app.log, app.ctx, &app.wg)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to initialize docker service: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if dockerService == nil {
|
||||||
|
if app.config.LabelProvider == "docker" {
|
||||||
|
app.log.App.Warn().Msg("Docker label provider selected but Docker is not available, will continue without it")
|
||||||
|
}
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
app.services.dockerService = dockerService
|
||||||
|
return dockerService, nil
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("invalid label provider: %s", app.config.LabelProvider)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (app *BootstrapApp) setupPolicyEngine() error {
|
||||||
|
policyEngine, err := service.NewPolicyEngine(app.config, app.log)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to initialize policy engine: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
policyEngine.RegisterRule(service.RuleUserAllowed, &service.UserAllowedRule{
|
||||||
|
Log: app.log,
|
||||||
|
})
|
||||||
|
policyEngine.RegisterRule(service.RuleOAuthGroup, &service.OAuthGroupRule{
|
||||||
|
Log: app.log,
|
||||||
|
})
|
||||||
|
policyEngine.RegisterRule(service.RuleLDAPGroup, &service.LDAPGroupRule{
|
||||||
|
Log: app.log,
|
||||||
|
})
|
||||||
|
policyEngine.RegisterRule(service.RuleAuthEnabled, &service.AuthEnabledRule{
|
||||||
|
Log: app.log,
|
||||||
|
})
|
||||||
|
policyEngine.RegisterRule(service.RuleIPAllowed, &service.IPAllowedRule{
|
||||||
|
Log: app.log,
|
||||||
|
Config: app.config,
|
||||||
|
})
|
||||||
|
policyEngine.RegisterRule(service.RuleIPBypassed, &service.IPBypassedRule{
|
||||||
|
Log: app.log,
|
||||||
|
})
|
||||||
|
|
||||||
|
app.services.policyEngine = policyEngine
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -208,7 +208,12 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
|
|||||||
name = user.Name
|
name = user.Name
|
||||||
} else {
|
} else {
|
||||||
controller.log.App.Debug().Msg("No name from OAuth provider, generating from email")
|
controller.log.App.Debug().Msg("No name from OAuth provider, generating from email")
|
||||||
name = fmt.Sprintf("%s (%s)", utils.Capitalize(strings.Split(user.Email, "@")[0]), strings.Split(user.Email, "@")[1])
|
parts := strings.SplitN(user.Email, "@", 2)
|
||||||
|
if len(parts) == 2 {
|
||||||
|
name = fmt.Sprintf("%s (%s)", utils.Capitalize(parts[0]), parts[1])
|
||||||
|
} else {
|
||||||
|
name = utils.Capitalize(user.Email)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var username string
|
var username string
|
||||||
|
|||||||
@@ -146,7 +146,7 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
|
|||||||
client, ok := controller.oidc.GetClient(req.ClientID)
|
client, ok := controller.oidc.GetClient(req.ClientID)
|
||||||
|
|
||||||
if !ok {
|
if !ok {
|
||||||
controller.authorizeError(c, err, "Client not found", "The client ID is invalid", "", "", "")
|
controller.authorizeError(c, fmt.Errorf("client not found: %s", req.ClientID), "Client not found", "The client ID is invalid", "", "", "")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -288,7 +288,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
|
|||||||
entry, err := controller.oidc.GetCodeEntry(c, controller.oidc.Hash(req.Code), client.ClientID)
|
entry, err := controller.oidc.GetCodeEntry(c, controller.oidc.Hash(req.Code), client.ClientID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err := controller.oidc.DeleteTokenByCodeHash(c, controller.oidc.Hash(req.Code)); err != nil {
|
if err := controller.oidc.DeleteTokenByCodeHash(c, controller.oidc.Hash(req.Code)); err != nil {
|
||||||
controller.log.App.Error().Err(err).Msg("Failed to delete code")
|
controller.log.App.Error().Err(err).Msg("Failed to revoke tokens for replayed code")
|
||||||
}
|
}
|
||||||
if errors.Is(err, service.ErrCodeNotFound) {
|
if errors.Is(err, service.ErrCodeNotFound) {
|
||||||
controller.log.App.Warn().Msg("Code not found")
|
controller.log.App.Warn().Msg("Code not found")
|
||||||
|
|||||||
@@ -15,10 +15,9 @@ import (
|
|||||||
"github.com/google/go-querystring/query"
|
"github.com/google/go-querystring/query"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/bootstrap"
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/controller"
|
"github.com/tinyauthapp/tinyauth/internal/controller"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/test"
|
"github.com/tinyauthapp/tinyauth/internal/test"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||||
@@ -839,16 +838,11 @@ func TestOIDCController(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
app := bootstrap.NewBootstrapApp(cfg)
|
store := memory.New()
|
||||||
|
|
||||||
err := app.SetupDatabase()
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
queries := repository.New(app.GetDB())
|
|
||||||
|
|
||||||
wg := &sync.WaitGroup{}
|
wg := &sync.WaitGroup{}
|
||||||
|
|
||||||
oidcService, err := service.NewOIDCService(log, cfg, runtime, queries, context.TODO(), wg)
|
oidcService, err := service.NewOIDCService(log, cfg, runtime, store, context.TODO(), wg)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
@@ -869,8 +863,4 @@ func TestOIDCController(t *testing.T) {
|
|||||||
test.run(t, router, recorder)
|
test.run(t, router, recorder)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Cleanup(func() {
|
|
||||||
app.GetDB().Close()
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package controller
|
|||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"regexp"
|
"regexp"
|
||||||
@@ -55,6 +56,7 @@ type ProxyController struct {
|
|||||||
runtime model.RuntimeConfig
|
runtime model.RuntimeConfig
|
||||||
acls *service.AccessControlsService
|
acls *service.AccessControlsService
|
||||||
auth *service.AuthService
|
auth *service.AuthService
|
||||||
|
policyEngine *service.PolicyEngine
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewProxyController(
|
func NewProxyController(
|
||||||
@@ -63,12 +65,14 @@ func NewProxyController(
|
|||||||
router *gin.RouterGroup,
|
router *gin.RouterGroup,
|
||||||
acls *service.AccessControlsService,
|
acls *service.AccessControlsService,
|
||||||
auth *service.AuthService,
|
auth *service.AuthService,
|
||||||
|
policyEngine *service.PolicyEngine,
|
||||||
) *ProxyController {
|
) *ProxyController {
|
||||||
controller := &ProxyController{
|
controller := &ProxyController{
|
||||||
log: log,
|
log: log,
|
||||||
runtime: runtime,
|
runtime: runtime,
|
||||||
acls: acls,
|
acls: acls,
|
||||||
auth: auth,
|
auth: auth,
|
||||||
|
policyEngine: policyEngine,
|
||||||
}
|
}
|
||||||
|
|
||||||
proxyGroup := router.Group("/auth")
|
proxyGroup := router.Group("/auth")
|
||||||
@@ -101,7 +105,13 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
|||||||
|
|
||||||
clientIP := c.ClientIP()
|
clientIP := c.ClientIP()
|
||||||
|
|
||||||
if controller.auth.IsBypassedIP(clientIP, acls) {
|
aclsCtx := &service.ACLContext{
|
||||||
|
ACLs: acls,
|
||||||
|
IP: net.ParseIP(clientIP),
|
||||||
|
Path: proxyCtx.Path,
|
||||||
|
}
|
||||||
|
|
||||||
|
if controller.policyEngine.Evaluate(service.RuleIPBypassed, aclsCtx) {
|
||||||
controller.setHeaders(c, acls)
|
controller.setHeaders(c, acls)
|
||||||
c.JSON(200, gin.H{
|
c.JSON(200, gin.H{
|
||||||
"status": 200,
|
"status": 200,
|
||||||
@@ -110,15 +120,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
authEnabled, err := controller.auth.IsAuthEnabled(proxyCtx.Path, acls)
|
if controller.policyEngine.Evaluate(service.RuleAuthEnabled, aclsCtx) {
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
controller.log.App.Error().Err(err).Msg("Failed to determine if authentication is enabled for resource")
|
|
||||||
controller.handleError(c, proxyCtx)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if !authEnabled {
|
|
||||||
controller.log.App.Debug().Msg("Authentication is disabled for this resource, allowing access without authentication")
|
controller.log.App.Debug().Msg("Authentication is disabled for this resource, allowing access without authentication")
|
||||||
controller.setHeaders(c, acls)
|
controller.setHeaders(c, acls)
|
||||||
c.JSON(200, gin.H{
|
c.JSON(200, gin.H{
|
||||||
@@ -128,7 +130,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if !controller.auth.CheckIP(clientIP, acls) {
|
if !controller.policyEngine.Evaluate(service.RuleIPAllowed, aclsCtx) {
|
||||||
queries, err := query.Values(UnauthorizedQuery{
|
queries, err := query.Values(UnauthorizedQuery{
|
||||||
Resource: strings.Split(proxyCtx.Host, ".")[0],
|
Resource: strings.Split(proxyCtx.Host, ".")[0],
|
||||||
IP: clientIP,
|
IP: clientIP,
|
||||||
@@ -144,9 +146,9 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
|||||||
|
|
||||||
if !controller.useBrowserResponse(proxyCtx) {
|
if !controller.useBrowserResponse(proxyCtx) {
|
||||||
c.Header("x-tinyauth-location", redirectURL)
|
c.Header("x-tinyauth-location", redirectURL)
|
||||||
c.JSON(401, gin.H{
|
c.JSON(403, gin.H{
|
||||||
"status": 401,
|
"status": 403,
|
||||||
"message": "Unauthorized",
|
"message": "Forbidden",
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -164,10 +166,10 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if userContext.Authenticated {
|
aclsCtx.UserContext = userContext
|
||||||
userAllowed := controller.auth.IsUserAllowed(c, *userContext, acls)
|
|
||||||
|
|
||||||
if !userAllowed {
|
if userContext.Authenticated {
|
||||||
|
if !controller.policyEngine.Evaluate(service.RuleUserAllowed, aclsCtx) {
|
||||||
controller.log.App.Warn().Str("user", userContext.GetUsername()).Str("resource", strings.Split(proxyCtx.Host, ".")[0]).Msg("User is not allowed to access resource")
|
controller.log.App.Warn().Str("user", userContext.GetUsername()).Str("resource", strings.Split(proxyCtx.Host, ".")[0]).Msg("User is not allowed to access resource")
|
||||||
|
|
||||||
queries, err := query.Values(UnauthorizedQuery{
|
queries, err := query.Values(UnauthorizedQuery{
|
||||||
@@ -205,9 +207,9 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
|||||||
var groupOK bool
|
var groupOK bool
|
||||||
|
|
||||||
if userContext.IsOAuth() {
|
if userContext.IsOAuth() {
|
||||||
groupOK = controller.auth.IsInOAuthGroup(c, *userContext, acls)
|
groupOK = controller.policyEngine.Evaluate(service.RuleOAuthGroup, aclsCtx)
|
||||||
} else {
|
} else {
|
||||||
groupOK = controller.auth.IsInLDAPGroup(c, *userContext, acls)
|
groupOK = controller.policyEngine.Evaluate(service.RuleLDAPGroup, aclsCtx)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !groupOK {
|
if !groupOK {
|
||||||
|
|||||||
@@ -9,10 +9,9 @@ import (
|
|||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/bootstrap"
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/controller"
|
"github.com/tinyauthapp/tinyauth/internal/controller"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/test"
|
"github.com/tinyauthapp/tinyauth/internal/test"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||||
@@ -24,33 +23,6 @@ func TestProxyController(t *testing.T) {
|
|||||||
|
|
||||||
cfg, runtime := test.CreateTestConfigs(t)
|
cfg, runtime := test.CreateTestConfigs(t)
|
||||||
|
|
||||||
acls := map[string]model.App{
|
|
||||||
"app_path_allow": {
|
|
||||||
Config: model.AppConfig{
|
|
||||||
Domain: "path-allow.example.com",
|
|
||||||
},
|
|
||||||
Path: model.AppPath{
|
|
||||||
Allow: "/allowed",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"app_user_allow": {
|
|
||||||
Config: model.AppConfig{
|
|
||||||
Domain: "user-allow.example.com",
|
|
||||||
},
|
|
||||||
Users: model.AppUsers{
|
|
||||||
Allow: "testuser",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"ip_bypass": {
|
|
||||||
Config: model.AppConfig{
|
|
||||||
Domain: "ip-bypass.example.com",
|
|
||||||
},
|
|
||||||
IP: model.AppIP{
|
|
||||||
Bypass: []string{"10.10.10.10"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
const browserUserAgent = `
|
const browserUserAgent = `
|
||||||
Mozilla/5.0 (Linux; Android 8.0.0; SM-G955U Build/R16NW) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/146.0.0.0 Mobile Safari/537.36`
|
Mozilla/5.0 (Linux; Android 8.0.0; SM-G955U Build/R16NW) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/146.0.0.0 Mobile Safari/537.36`
|
||||||
|
|
||||||
@@ -379,19 +351,37 @@ func TestProxyController(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
app := bootstrap.NewBootstrapApp(cfg)
|
store := memory.New()
|
||||||
|
|
||||||
err := app.SetupDatabase()
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
queries := repository.New(app.GetDB())
|
|
||||||
|
|
||||||
wg := &sync.WaitGroup{}
|
wg := &sync.WaitGroup{}
|
||||||
ctx := context.TODO()
|
ctx := context.TODO()
|
||||||
|
|
||||||
broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx)
|
broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx)
|
||||||
authService := service.NewAuthService(log, cfg, runtime, ctx, wg, nil, queries, broker)
|
authService := service.NewAuthService(log, cfg, runtime, ctx, wg, nil, store, broker)
|
||||||
aclsService := service.NewAccessControlsService(log, nil, acls)
|
aclsService := service.NewAccessControlsService(log, cfg, nil)
|
||||||
|
|
||||||
|
policyEngine, err := service.NewPolicyEngine(cfg, log)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
policyEngine.RegisterRule(service.RuleUserAllowed, &service.UserAllowedRule{
|
||||||
|
Log: log,
|
||||||
|
})
|
||||||
|
policyEngine.RegisterRule(service.RuleOAuthGroup, &service.OAuthGroupRule{
|
||||||
|
Log: log,
|
||||||
|
})
|
||||||
|
policyEngine.RegisterRule(service.RuleLDAPGroup, &service.LDAPGroupRule{
|
||||||
|
Log: log,
|
||||||
|
})
|
||||||
|
policyEngine.RegisterRule(service.RuleAuthEnabled, &service.AuthEnabledRule{
|
||||||
|
Log: log,
|
||||||
|
})
|
||||||
|
policyEngine.RegisterRule(service.RuleIPAllowed, &service.IPAllowedRule{
|
||||||
|
Log: log,
|
||||||
|
Config: cfg,
|
||||||
|
})
|
||||||
|
policyEngine.RegisterRule(service.RuleIPBypassed, &service.IPBypassedRule{
|
||||||
|
Log: log,
|
||||||
|
})
|
||||||
|
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
t.Run(test.description, func(t *testing.T) {
|
t.Run(test.description, func(t *testing.T) {
|
||||||
@@ -406,13 +396,9 @@ func TestProxyController(t *testing.T) {
|
|||||||
|
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
|
|
||||||
controller.NewProxyController(log, runtime, group, aclsService, authService)
|
controller.NewProxyController(log, runtime, group, aclsService, authService, policyEngine)
|
||||||
|
|
||||||
test.run(t, router, recorder)
|
test.run(t, router, recorder)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Cleanup(func() {
|
|
||||||
app.GetDB().Close()
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ func (controller *ResourcesController) resourcesHandler(c *gin.Context) {
|
|||||||
if controller.config.Resources.Path == "" {
|
if controller.config.Resources.Path == "" {
|
||||||
c.JSON(404, gin.H{
|
c.JSON(404, gin.H{
|
||||||
"status": 404,
|
"status": 404,
|
||||||
"message": "Resources not found",
|
"message": "Resource not found",
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,10 +14,10 @@ import (
|
|||||||
"github.com/pquerna/otp/totp"
|
"github.com/pquerna/otp/totp"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/bootstrap"
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/controller"
|
"github.com/tinyauthapp/tinyauth/internal/controller"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/test"
|
"github.com/tinyauthapp/tinyauth/internal/test"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||||
@@ -73,12 +73,7 @@ func TestUserController(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
app := bootstrap.NewBootstrapApp(cfg)
|
store := memory.New()
|
||||||
|
|
||||||
err := app.SetupDatabase()
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
queries := repository.New(app.GetDB())
|
|
||||||
|
|
||||||
type testCase struct {
|
type testCase struct {
|
||||||
description string
|
description string
|
||||||
@@ -254,7 +249,7 @@ func TestUserController(t *testing.T) {
|
|||||||
totpCtx,
|
totpCtx,
|
||||||
},
|
},
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
_, err := queries.CreateSession(context.TODO(), repository.CreateSessionParams{
|
_, err := store.CreateSession(context.TODO(), repository.CreateSessionParams{
|
||||||
UUID: "test-totp-login-uuid",
|
UUID: "test-totp-login-uuid",
|
||||||
Username: "test",
|
Username: "test",
|
||||||
Email: "test@example.com",
|
Email: "test@example.com",
|
||||||
@@ -378,7 +373,7 @@ func TestUserController(t *testing.T) {
|
|||||||
totpAttrCtx,
|
totpAttrCtx,
|
||||||
},
|
},
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
_, err := queries.CreateSession(context.TODO(), repository.CreateSessionParams{
|
_, err := store.CreateSession(context.TODO(), repository.CreateSessionParams{
|
||||||
UUID: "test-totp-login-attributes-uuid",
|
UUID: "test-totp-login-attributes-uuid",
|
||||||
Username: "test",
|
Username: "test",
|
||||||
Email: "test@example.com",
|
Email: "test@example.com",
|
||||||
@@ -420,7 +415,7 @@ func TestUserController(t *testing.T) {
|
|||||||
wg := &sync.WaitGroup{}
|
wg := &sync.WaitGroup{}
|
||||||
|
|
||||||
broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx)
|
broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx)
|
||||||
authService := service.NewAuthService(log, cfg, runtime, ctx, wg, nil, queries, broker)
|
authService := service.NewAuthService(log, cfg, runtime, ctx, wg, nil, store, broker)
|
||||||
|
|
||||||
beforeEach := func() {
|
beforeEach := func() {
|
||||||
// Clear failed login attempts before each test
|
// Clear failed login attempts before each test
|
||||||
@@ -446,8 +441,4 @@ func TestUserController(t *testing.T) {
|
|||||||
test.run(t, router, recorder)
|
test.run(t, router, recorder)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Cleanup(func() {
|
|
||||||
app.GetDB().Close()
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,9 +11,8 @@ import (
|
|||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/bootstrap"
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/controller"
|
"github.com/tinyauthapp/tinyauth/internal/controller"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/test"
|
"github.com/tinyauthapp/tinyauth/internal/test"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||||
@@ -92,14 +91,9 @@ func TestWellKnownController(t *testing.T) {
|
|||||||
ctx := context.TODO()
|
ctx := context.TODO()
|
||||||
wg := &sync.WaitGroup{}
|
wg := &sync.WaitGroup{}
|
||||||
|
|
||||||
app := bootstrap.NewBootstrapApp(cfg)
|
store := memory.New()
|
||||||
|
|
||||||
err := app.SetupDatabase()
|
oidcService, err := service.NewOIDCService(log, cfg, runtime, store, ctx, wg)
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
queries := repository.New(app.GetDB())
|
|
||||||
|
|
||||||
oidcService, err := service.NewOIDCService(log, cfg, runtime, queries, ctx, wg)
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
@@ -114,8 +108,4 @@ func TestWellKnownController(t *testing.T) {
|
|||||||
test.run(t, router, recorder)
|
test.run(t, router, recorder)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Cleanup(func() {
|
|
||||||
app.GetDB().Close()
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,10 +12,10 @@ import (
|
|||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/bootstrap"
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/middleware"
|
"github.com/tinyauthapp/tinyauth/internal/middleware"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/test"
|
"github.com/tinyauthapp/tinyauth/internal/test"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||||
@@ -31,7 +31,7 @@ func TestContextMiddleware(t *testing.T) {
|
|||||||
return "Basic " + base64.StdEncoding.EncodeToString([]byte(username+":"+password))
|
return "Basic " + base64.StdEncoding.EncodeToString([]byte(username+":"+password))
|
||||||
}
|
}
|
||||||
|
|
||||||
seedSession := func(t *testing.T, queries *repository.Queries, params repository.CreateSessionParams) {
|
seedSession := func(t *testing.T, queries repository.Store, params repository.CreateSessionParams) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
_, err := queries.CreateSession(context.Background(), params)
|
_, err := queries.CreateSession(context.Background(), params)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -39,7 +39,7 @@ func TestContextMiddleware(t *testing.T) {
|
|||||||
|
|
||||||
type runArgs struct {
|
type runArgs struct {
|
||||||
do func(req *http.Request) (*model.UserContext, *httptest.ResponseRecorder)
|
do func(req *http.Request) (*model.UserContext, *httptest.ResponseRecorder)
|
||||||
queries *repository.Queries
|
queries repository.Store
|
||||||
}
|
}
|
||||||
|
|
||||||
type testCase struct {
|
type testCase struct {
|
||||||
@@ -252,15 +252,10 @@ func TestContextMiddleware(t *testing.T) {
|
|||||||
ctx := context.TODO()
|
ctx := context.TODO()
|
||||||
wg := &sync.WaitGroup{}
|
wg := &sync.WaitGroup{}
|
||||||
|
|
||||||
app := bootstrap.NewBootstrapApp(cfg)
|
store := memory.New()
|
||||||
|
|
||||||
err := app.SetupDatabase()
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
queries := repository.New(app.GetDB())
|
|
||||||
|
|
||||||
broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx)
|
broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx)
|
||||||
authService := service.NewAuthService(log, cfg, runtime, ctx, wg, nil, queries, broker)
|
authService := service.NewAuthService(log, cfg, runtime, ctx, wg, nil, store, broker)
|
||||||
|
|
||||||
contextMiddleware := middleware.NewContextMiddleware(log, runtime, authService, broker)
|
contextMiddleware := middleware.NewContextMiddleware(log, runtime, authService, broker)
|
||||||
|
|
||||||
@@ -286,11 +281,7 @@ func TestContextMiddleware(t *testing.T) {
|
|||||||
return captured, recorder
|
return captured, recorder
|
||||||
}
|
}
|
||||||
|
|
||||||
test.run(t, runArgs{do: do, queries: queries})
|
test.run(t, runArgs{do: do, queries: store})
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Cleanup(func() {
|
|
||||||
app.GetDB().Close()
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ package model
|
|||||||
func NewDefaultConfiguration() *Config {
|
func NewDefaultConfiguration() *Config {
|
||||||
return &Config{
|
return &Config{
|
||||||
Database: DatabaseConfig{
|
Database: DatabaseConfig{
|
||||||
|
Driver: "sqlite",
|
||||||
Path: "./tinyauth.db",
|
Path: "./tinyauth.db",
|
||||||
},
|
},
|
||||||
Analytics: AnalyticsConfig{
|
Analytics: AnalyticsConfig{
|
||||||
@@ -24,6 +25,9 @@ func NewDefaultConfiguration() *Config {
|
|||||||
SessionMaxLifetime: 0, // disabled
|
SessionMaxLifetime: 0, // disabled
|
||||||
LoginTimeout: 300, // 5 minutes
|
LoginTimeout: 300, // 5 minutes
|
||||||
LoginMaxRetries: 3,
|
LoginMaxRetries: 3,
|
||||||
|
ACLs: ACLsConfig{
|
||||||
|
Policy: "allow",
|
||||||
|
},
|
||||||
},
|
},
|
||||||
UI: UIConfig{
|
UI: UIConfig{
|
||||||
Title: "Tinyauth",
|
Title: "Tinyauth",
|
||||||
@@ -78,12 +82,13 @@ type Config struct {
|
|||||||
UI UIConfig `description:"UI customization." yaml:"ui"`
|
UI UIConfig `description:"UI customization." yaml:"ui"`
|
||||||
LDAP LDAPConfig `description:"LDAP configuration." yaml:"ldap"`
|
LDAP LDAPConfig `description:"LDAP configuration." yaml:"ldap"`
|
||||||
Experimental ExperimentalConfig `description:"Experimental features, use with caution." yaml:"experimental"`
|
Experimental ExperimentalConfig `description:"Experimental features, use with caution." yaml:"experimental"`
|
||||||
LabelProvider string `description:"Label provider to use for ACLs (auto, docker, or kubernetes). auto detects the environment." yaml:"labelProvider"`
|
LabelProvider string `description:"Label provider to use for ACLs (auto, docker, kubernetes or none to disable). auto detects the environment." yaml:"labelProvider"`
|
||||||
Log LogConfig `description:"Logging configuration." yaml:"log"`
|
Log LogConfig `description:"Logging configuration." yaml:"log"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type DatabaseConfig struct {
|
type DatabaseConfig struct {
|
||||||
Path string `description:"The path to the database, including file name." yaml:"path"`
|
Driver string `description:"The database driver to use. Valid values: sqlite, memory." yaml:"driver"`
|
||||||
|
Path string `description:"The path to the SQLite database, including file name. Only used when driver is sqlite." yaml:"path"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type AnalyticsConfig struct {
|
type AnalyticsConfig struct {
|
||||||
@@ -114,6 +119,7 @@ type AuthConfig struct {
|
|||||||
LoginTimeout int `description:"Login timeout in seconds." yaml:"loginTimeout"`
|
LoginTimeout int `description:"Login timeout in seconds." yaml:"loginTimeout"`
|
||||||
LoginMaxRetries int `description:"Maximum login retries." yaml:"loginMaxRetries"`
|
LoginMaxRetries int `description:"Maximum login retries." yaml:"loginMaxRetries"`
|
||||||
TrustedProxies []string `description:"Comma-separated list of trusted proxy addresses." yaml:"trustedProxies"`
|
TrustedProxies []string `description:"Comma-separated list of trusted proxy addresses." yaml:"trustedProxies"`
|
||||||
|
ACLs ACLsConfig `description:"ACLs configuration." yaml:"acls"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type UserAttributes struct {
|
type UserAttributes struct {
|
||||||
@@ -223,6 +229,10 @@ type OIDCClientConfig struct {
|
|||||||
Name string `description:"Client name in UI." yaml:"name"`
|
Name string `description:"Client name in UI." yaml:"name"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type ACLsConfig struct {
|
||||||
|
Policy string `description:"ACL policy for allow-by-default or deny-by-default, available options are allow and deny, default is allow." yaml:"policy"`
|
||||||
|
}
|
||||||
|
|
||||||
// ACLs
|
// ACLs
|
||||||
|
|
||||||
type Apps struct {
|
type Apps struct {
|
||||||
|
|||||||
@@ -0,0 +1,472 @@
|
|||||||
|
package memory_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
|
||||||
|
)
|
||||||
|
|
||||||
|
var ctx = context.Background()
|
||||||
|
|
||||||
|
func TestMemoryStore(t *testing.T) {
|
||||||
|
type testCase struct {
|
||||||
|
description string
|
||||||
|
run func(t *testing.T, s repository.Store)
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []testCase{
|
||||||
|
{
|
||||||
|
description: "Create and get session",
|
||||||
|
run: func(t *testing.T, s repository.Store) {
|
||||||
|
sess, err := s.CreateSession(ctx, repository.CreateSessionParams{
|
||||||
|
UUID: "uuid-1",
|
||||||
|
Username: "alice",
|
||||||
|
Expiry: 9999,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "uuid-1", sess.UUID)
|
||||||
|
assert.Equal(t, "alice", sess.Username)
|
||||||
|
|
||||||
|
got, err := s.GetSession(ctx, "uuid-1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, sess, got)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Get session not found",
|
||||||
|
run: func(t *testing.T, s repository.Store) {
|
||||||
|
_, err := s.GetSession(ctx, "missing")
|
||||||
|
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Update session",
|
||||||
|
run: func(t *testing.T, s repository.Store) {
|
||||||
|
_, err := s.CreateSession(ctx, repository.CreateSessionParams{UUID: "uuid-1", Username: "alice"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
updated, err := s.UpdateSession(ctx, repository.UpdateSessionParams{
|
||||||
|
UUID: "uuid-1",
|
||||||
|
Username: "bob",
|
||||||
|
Email: "bob@example.com",
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "bob", updated.Username)
|
||||||
|
assert.Equal(t, "bob@example.com", updated.Email)
|
||||||
|
|
||||||
|
got, err := s.GetSession(ctx, "uuid-1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, updated, got)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Update session not found",
|
||||||
|
run: func(t *testing.T, s repository.Store) {
|
||||||
|
_, err := s.UpdateSession(ctx, repository.UpdateSessionParams{UUID: "missing"})
|
||||||
|
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Delete session",
|
||||||
|
run: func(t *testing.T, s repository.Store) {
|
||||||
|
_, err := s.CreateSession(ctx, repository.CreateSessionParams{UUID: "uuid-1"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.NoError(t, s.DeleteSession(ctx, "uuid-1"))
|
||||||
|
|
||||||
|
_, err = s.GetSession(ctx, "uuid-1")
|
||||||
|
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Delete expired sessions",
|
||||||
|
run: func(t *testing.T, s repository.Store) {
|
||||||
|
_, err := s.CreateSession(ctx, repository.CreateSessionParams{UUID: "expired", Expiry: 10})
|
||||||
|
require.NoError(t, err)
|
||||||
|
_, err = s.CreateSession(ctx, repository.CreateSessionParams{UUID: "valid", Expiry: 100})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.NoError(t, s.DeleteExpiredSessions(ctx, 50))
|
||||||
|
|
||||||
|
_, err = s.GetSession(ctx, "expired")
|
||||||
|
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||||
|
|
||||||
|
_, err = s.GetSession(ctx, "valid")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Create and get OIDC code",
|
||||||
|
run: func(t *testing.T, s repository.Store) {
|
||||||
|
code, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{
|
||||||
|
Sub: "sub-1",
|
||||||
|
CodeHash: "hash-1",
|
||||||
|
Scope: "openid",
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "sub-1", code.Sub)
|
||||||
|
|
||||||
|
// destructive read removes the record
|
||||||
|
got, err := s.GetOidcCode(ctx, "hash-1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, code, got)
|
||||||
|
|
||||||
|
_, err = s.GetOidcCode(ctx, "hash-1")
|
||||||
|
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Get OIDC code not found",
|
||||||
|
run: func(t *testing.T, s repository.Store) {
|
||||||
|
_, err := s.GetOidcCode(ctx, "missing")
|
||||||
|
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Get OIDC code by sub",
|
||||||
|
run: func(t *testing.T, s repository.Store) {
|
||||||
|
_, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
got, err := s.GetOidcCodeBySub(ctx, "sub-1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "sub-1", got.Sub)
|
||||||
|
|
||||||
|
// destructive — gone after read
|
||||||
|
_, err = s.GetOidcCodeBySub(ctx, "sub-1")
|
||||||
|
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Get OIDC code by sub not found",
|
||||||
|
run: func(t *testing.T, s repository.Store) {
|
||||||
|
_, err := s.GetOidcCodeBySub(ctx, "missing")
|
||||||
|
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Get OIDC code unsafe",
|
||||||
|
run: func(t *testing.T, s repository.Store) {
|
||||||
|
_, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
got, err := s.GetOidcCodeUnsafe(ctx, "hash-1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "sub-1", got.Sub)
|
||||||
|
|
||||||
|
// non-destructive — still present
|
||||||
|
_, err = s.GetOidcCodeUnsafe(ctx, "hash-1")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Get OIDC code unsafe not found",
|
||||||
|
run: func(t *testing.T, s repository.Store) {
|
||||||
|
_, err := s.GetOidcCodeUnsafe(ctx, "missing")
|
||||||
|
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Get OIDC code by sub unsafe",
|
||||||
|
run: func(t *testing.T, s repository.Store) {
|
||||||
|
_, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
got, err := s.GetOidcCodeBySubUnsafe(ctx, "sub-1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "hash-1", got.CodeHash)
|
||||||
|
|
||||||
|
// non-destructive — still present
|
||||||
|
_, err = s.GetOidcCodeBySubUnsafe(ctx, "sub-1")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Get OIDC code by sub unsafe not found",
|
||||||
|
run: func(t *testing.T, s repository.Store) {
|
||||||
|
_, err := s.GetOidcCodeBySubUnsafe(ctx, "missing")
|
||||||
|
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Create OIDC code unique sub constraint",
|
||||||
|
run: func(t *testing.T, s repository.Store) {
|
||||||
|
_, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
_, err = s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-2"})
|
||||||
|
assert.ErrorContains(t, err, "UNIQUE constraint failed: oidc_codes.sub")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Delete OIDC code",
|
||||||
|
run: func(t *testing.T, s repository.Store) {
|
||||||
|
_, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.NoError(t, s.DeleteOidcCode(ctx, "hash-1"))
|
||||||
|
|
||||||
|
_, err = s.GetOidcCodeUnsafe(ctx, "hash-1")
|
||||||
|
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Delete OIDC code by sub",
|
||||||
|
run: func(t *testing.T, s repository.Store) {
|
||||||
|
_, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.NoError(t, s.DeleteOidcCodeBySub(ctx, "sub-1"))
|
||||||
|
|
||||||
|
_, err = s.GetOidcCodeUnsafe(ctx, "hash-1")
|
||||||
|
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Delete expired OIDC codes",
|
||||||
|
run: func(t *testing.T, s repository.Store) {
|
||||||
|
_, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1", ExpiresAt: 10})
|
||||||
|
require.NoError(t, err)
|
||||||
|
_, err = s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-2", CodeHash: "hash-2", ExpiresAt: 100})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
deleted, err := s.DeleteExpiredOidcCodes(ctx, 50)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, deleted, 1)
|
||||||
|
assert.Equal(t, "hash-1", deleted[0].CodeHash)
|
||||||
|
|
||||||
|
_, err = s.GetOidcCodeUnsafe(ctx, "hash-2")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Create and get OIDC token",
|
||||||
|
run: func(t *testing.T, s repository.Store) {
|
||||||
|
tok, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{
|
||||||
|
Sub: "sub-1",
|
||||||
|
AccessTokenHash: "at-hash-1",
|
||||||
|
CodeHash: "code-hash-1",
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "sub-1", tok.Sub)
|
||||||
|
|
||||||
|
got, err := s.GetOidcToken(ctx, "at-hash-1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, tok, got)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Get OIDC token not found",
|
||||||
|
run: func(t *testing.T, s repository.Store) {
|
||||||
|
_, err := s.GetOidcToken(ctx, "missing")
|
||||||
|
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Create OIDC token unique sub constraint",
|
||||||
|
run: func(t *testing.T, s repository.Store) {
|
||||||
|
_, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{Sub: "sub-1", AccessTokenHash: "at-1"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
_, err = s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{Sub: "sub-1", AccessTokenHash: "at-2"})
|
||||||
|
assert.ErrorContains(t, err, "UNIQUE constraint failed: oidc_tokens.sub")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Get OIDC token by refresh token",
|
||||||
|
run: func(t *testing.T, s repository.Store) {
|
||||||
|
_, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{
|
||||||
|
Sub: "sub-1",
|
||||||
|
AccessTokenHash: "at-1",
|
||||||
|
RefreshTokenHash: "rt-1",
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
got, err := s.GetOidcTokenByRefreshToken(ctx, "rt-1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "sub-1", got.Sub)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Get OIDC token by refresh token not found",
|
||||||
|
run: func(t *testing.T, s repository.Store) {
|
||||||
|
_, err := s.GetOidcTokenByRefreshToken(ctx, "missing")
|
||||||
|
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Get OIDC token by sub",
|
||||||
|
run: func(t *testing.T, s repository.Store) {
|
||||||
|
_, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{
|
||||||
|
Sub: "sub-1",
|
||||||
|
AccessTokenHash: "at-1",
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
got, err := s.GetOidcTokenBySub(ctx, "sub-1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "at-1", got.AccessTokenHash)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Get OIDC token by sub not found",
|
||||||
|
run: func(t *testing.T, s repository.Store) {
|
||||||
|
_, err := s.GetOidcTokenBySub(ctx, "missing")
|
||||||
|
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Update OIDC token by refresh token",
|
||||||
|
run: func(t *testing.T, s repository.Store) {
|
||||||
|
_, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{
|
||||||
|
Sub: "sub-1",
|
||||||
|
AccessTokenHash: "at-1",
|
||||||
|
RefreshTokenHash: "rt-1",
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
updated, err := s.UpdateOidcTokenByRefreshToken(ctx, repository.UpdateOidcTokenByRefreshTokenParams{
|
||||||
|
RefreshTokenHash_2: "rt-1",
|
||||||
|
AccessTokenHash: "at-2",
|
||||||
|
RefreshTokenHash: "rt-2",
|
||||||
|
TokenExpiresAt: 200,
|
||||||
|
RefreshTokenExpiresAt: 400,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "at-2", updated.AccessTokenHash)
|
||||||
|
assert.Equal(t, "rt-2", updated.RefreshTokenHash)
|
||||||
|
|
||||||
|
// old key gone, new key present
|
||||||
|
_, err = s.GetOidcToken(ctx, "at-1")
|
||||||
|
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||||
|
|
||||||
|
got, err := s.GetOidcToken(ctx, "at-2")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "sub-1", got.Sub)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Update OIDC token by refresh token not found",
|
||||||
|
run: func(t *testing.T, s repository.Store) {
|
||||||
|
_, err := s.UpdateOidcTokenByRefreshToken(ctx, repository.UpdateOidcTokenByRefreshTokenParams{
|
||||||
|
RefreshTokenHash_2: "missing",
|
||||||
|
})
|
||||||
|
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Delete OIDC token",
|
||||||
|
run: func(t *testing.T, s repository.Store) {
|
||||||
|
_, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{Sub: "sub-1", AccessTokenHash: "at-1"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.NoError(t, s.DeleteOidcToken(ctx, "at-1"))
|
||||||
|
|
||||||
|
_, err = s.GetOidcToken(ctx, "at-1")
|
||||||
|
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Delete OIDC token by sub",
|
||||||
|
run: func(t *testing.T, s repository.Store) {
|
||||||
|
_, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{Sub: "sub-1", AccessTokenHash: "at-1"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.NoError(t, s.DeleteOidcTokenBySub(ctx, "sub-1"))
|
||||||
|
|
||||||
|
_, err = s.GetOidcToken(ctx, "at-1")
|
||||||
|
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Delete OIDC token by code hash",
|
||||||
|
run: func(t *testing.T, s repository.Store) {
|
||||||
|
_, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{
|
||||||
|
Sub: "sub-1",
|
||||||
|
AccessTokenHash: "at-1",
|
||||||
|
CodeHash: "code-1",
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.NoError(t, s.DeleteOidcTokenByCodeHash(ctx, "code-1"))
|
||||||
|
|
||||||
|
_, err = s.GetOidcToken(ctx, "at-1")
|
||||||
|
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Delete expired OIDC tokens",
|
||||||
|
run: func(t *testing.T, s repository.Store) {
|
||||||
|
// both expiries past
|
||||||
|
_, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{
|
||||||
|
Sub: "sub-1", AccessTokenHash: "at-1",
|
||||||
|
TokenExpiresAt: 10, RefreshTokenExpiresAt: 10,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
// valid
|
||||||
|
_, err = s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{
|
||||||
|
Sub: "sub-3", AccessTokenHash: "at-3",
|
||||||
|
TokenExpiresAt: 100, RefreshTokenExpiresAt: 100,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
deleted, err := s.DeleteExpiredOidcTokens(ctx, repository.DeleteExpiredOidcTokensParams{
|
||||||
|
TokenExpiresAt: 50,
|
||||||
|
RefreshTokenExpiresAt: 50,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Len(t, deleted, 1)
|
||||||
|
|
||||||
|
_, err = s.GetOidcToken(ctx, "at-3")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Create and get OIDC user info",
|
||||||
|
run: func(t *testing.T, s repository.Store) {
|
||||||
|
u, err := s.CreateOidcUserInfo(ctx, repository.CreateOidcUserInfoParams{
|
||||||
|
Sub: "sub-1",
|
||||||
|
Name: "Alice",
|
||||||
|
Email: "alice@example.com",
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "sub-1", u.Sub)
|
||||||
|
|
||||||
|
got, err := s.GetOidcUserInfo(ctx, "sub-1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, u, got)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Get OIDC user info not found",
|
||||||
|
run: func(t *testing.T, s repository.Store) {
|
||||||
|
_, err := s.GetOidcUserInfo(ctx, "missing")
|
||||||
|
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Delete OIDC user info",
|
||||||
|
run: func(t *testing.T, s repository.Store) {
|
||||||
|
_, err := s.CreateOidcUserInfo(ctx, repository.CreateOidcUserInfoParams{Sub: "sub-1"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.NoError(t, s.DeleteOidcUserInfo(ctx, "sub-1"))
|
||||||
|
|
||||||
|
_, err = s.GetOidcUserInfo(ctx, "sub-1")
|
||||||
|
assert.ErrorIs(t, err, repository.ErrNotFound)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range tests {
|
||||||
|
t.Run(test.description, func(t *testing.T) {
|
||||||
|
s := memory.New()
|
||||||
|
test.run(t, s)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,241 @@
|
|||||||
|
package memory
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (s *Store) CreateOidcCode(_ context.Context, arg repository.CreateOidcCodeParams) (repository.OidcCode, error) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
// Enforce sub UNIQUE constraint
|
||||||
|
for _, c := range s.oidcCodes {
|
||||||
|
if c.Sub == arg.Sub {
|
||||||
|
return repository.OidcCode{}, fmt.Errorf("UNIQUE constraint failed: oidc_codes.sub")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
code := repository.OidcCode(arg)
|
||||||
|
s.oidcCodes[arg.CodeHash] = code
|
||||||
|
return code, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetOidcCode is a destructive read: it deletes and returns the code (mirrors SQLite's DELETE...RETURNING).
|
||||||
|
func (s *Store) GetOidcCode(_ context.Context, codeHash string) (repository.OidcCode, error) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
c, ok := s.oidcCodes[codeHash]
|
||||||
|
if !ok {
|
||||||
|
return repository.OidcCode{}, repository.ErrNotFound
|
||||||
|
}
|
||||||
|
delete(s.oidcCodes, codeHash)
|
||||||
|
return c, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetOidcCodeBySub is a destructive read: it deletes and returns the code (mirrors SQLite's DELETE...RETURNING).
|
||||||
|
func (s *Store) GetOidcCodeBySub(_ context.Context, sub string) (repository.OidcCode, error) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
for k, c := range s.oidcCodes {
|
||||||
|
if c.Sub == sub {
|
||||||
|
delete(s.oidcCodes, k)
|
||||||
|
return c, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return repository.OidcCode{}, repository.ErrNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetOidcCodeUnsafe is a non-destructive read (mirrors SQLite's SELECT).
|
||||||
|
func (s *Store) GetOidcCodeUnsafe(_ context.Context, codeHash string) (repository.OidcCode, error) {
|
||||||
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
c, ok := s.oidcCodes[codeHash]
|
||||||
|
if !ok {
|
||||||
|
return repository.OidcCode{}, repository.ErrNotFound
|
||||||
|
}
|
||||||
|
return c, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetOidcCodeBySubUnsafe is a non-destructive read (mirrors SQLite's SELECT).
|
||||||
|
func (s *Store) GetOidcCodeBySubUnsafe(_ context.Context, sub string) (repository.OidcCode, error) {
|
||||||
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
for _, c := range s.oidcCodes {
|
||||||
|
if c.Sub == sub {
|
||||||
|
return c, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return repository.OidcCode{}, repository.ErrNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) DeleteOidcCode(_ context.Context, codeHash string) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
delete(s.oidcCodes, codeHash)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) DeleteOidcCodeBySub(_ context.Context, sub string) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
for k, c := range s.oidcCodes {
|
||||||
|
if c.Sub == sub {
|
||||||
|
delete(s.oidcCodes, k)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) DeleteExpiredOidcCodes(_ context.Context, expiresAt int64) ([]repository.OidcCode, error) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
var deleted []repository.OidcCode
|
||||||
|
for k, c := range s.oidcCodes {
|
||||||
|
if c.ExpiresAt < expiresAt {
|
||||||
|
deleted = append(deleted, c)
|
||||||
|
delete(s.oidcCodes, k)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return deleted, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) CreateOidcToken(_ context.Context, arg repository.CreateOidcTokenParams) (repository.OidcToken, error) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
// Enforce sub UNIQUE constraint
|
||||||
|
for _, t := range s.oidcTokens {
|
||||||
|
if t.Sub == arg.Sub {
|
||||||
|
return repository.OidcToken{}, fmt.Errorf("UNIQUE constraint failed: oidc_tokens.sub")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tok := repository.OidcToken{
|
||||||
|
Sub: arg.Sub,
|
||||||
|
AccessTokenHash: arg.AccessTokenHash,
|
||||||
|
RefreshTokenHash: arg.RefreshTokenHash,
|
||||||
|
CodeHash: arg.CodeHash,
|
||||||
|
Scope: arg.Scope,
|
||||||
|
ClientID: arg.ClientID,
|
||||||
|
TokenExpiresAt: arg.TokenExpiresAt,
|
||||||
|
RefreshTokenExpiresAt: arg.RefreshTokenExpiresAt,
|
||||||
|
Nonce: arg.Nonce,
|
||||||
|
}
|
||||||
|
s.oidcTokens[arg.AccessTokenHash] = tok
|
||||||
|
return tok, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) GetOidcToken(_ context.Context, accessTokenHash string) (repository.OidcToken, error) {
|
||||||
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
t, ok := s.oidcTokens[accessTokenHash]
|
||||||
|
if !ok {
|
||||||
|
return repository.OidcToken{}, repository.ErrNotFound
|
||||||
|
}
|
||||||
|
return t, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) GetOidcTokenByRefreshToken(_ context.Context, refreshTokenHash string) (repository.OidcToken, error) {
|
||||||
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
for _, t := range s.oidcTokens {
|
||||||
|
if t.RefreshTokenHash == refreshTokenHash {
|
||||||
|
return t, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return repository.OidcToken{}, repository.ErrNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) GetOidcTokenBySub(_ context.Context, sub string) (repository.OidcToken, error) {
|
||||||
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
for _, t := range s.oidcTokens {
|
||||||
|
if t.Sub == sub {
|
||||||
|
return t, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return repository.OidcToken{}, repository.ErrNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) UpdateOidcTokenByRefreshToken(_ context.Context, arg repository.UpdateOidcTokenByRefreshTokenParams) (repository.OidcToken, error) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
for k, t := range s.oidcTokens {
|
||||||
|
if t.RefreshTokenHash == arg.RefreshTokenHash_2 {
|
||||||
|
delete(s.oidcTokens, k)
|
||||||
|
t.AccessTokenHash = arg.AccessTokenHash
|
||||||
|
t.RefreshTokenHash = arg.RefreshTokenHash
|
||||||
|
t.TokenExpiresAt = arg.TokenExpiresAt
|
||||||
|
t.RefreshTokenExpiresAt = arg.RefreshTokenExpiresAt
|
||||||
|
s.oidcTokens[arg.AccessTokenHash] = t
|
||||||
|
return t, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return repository.OidcToken{}, repository.ErrNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) DeleteOidcToken(_ context.Context, accessTokenHash string) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
delete(s.oidcTokens, accessTokenHash)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) DeleteOidcTokenBySub(_ context.Context, sub string) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
for k, t := range s.oidcTokens {
|
||||||
|
if t.Sub == sub {
|
||||||
|
delete(s.oidcTokens, k)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) DeleteOidcTokenByCodeHash(_ context.Context, codeHash string) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
for k, t := range s.oidcTokens {
|
||||||
|
if t.CodeHash == codeHash {
|
||||||
|
delete(s.oidcTokens, k)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) DeleteExpiredOidcTokens(_ context.Context, arg repository.DeleteExpiredOidcTokensParams) ([]repository.OidcToken, error) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
var deleted []repository.OidcToken
|
||||||
|
for k, t := range s.oidcTokens {
|
||||||
|
if t.TokenExpiresAt < arg.TokenExpiresAt && t.RefreshTokenExpiresAt < arg.RefreshTokenExpiresAt {
|
||||||
|
deleted = append(deleted, t)
|
||||||
|
delete(s.oidcTokens, k)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return deleted, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) CreateOidcUserInfo(_ context.Context, arg repository.CreateOidcUserInfoParams) (repository.OidcUserinfo, error) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
u := repository.OidcUserinfo(arg)
|
||||||
|
s.oidcUsers[arg.Sub] = u
|
||||||
|
return u, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) GetOidcUserInfo(_ context.Context, sub string) (repository.OidcUserinfo, error) {
|
||||||
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
u, ok := s.oidcUsers[sub]
|
||||||
|
if !ok {
|
||||||
|
return repository.OidcUserinfo{}, repository.ErrNotFound
|
||||||
|
}
|
||||||
|
return u, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) DeleteOidcUserInfo(_ context.Context, sub string) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
delete(s.oidcUsers, sub)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,63 @@
|
|||||||
|
package memory
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (s *Store) CreateSession(_ context.Context, arg repository.CreateSessionParams) (repository.Session, error) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
sess := repository.Session(arg)
|
||||||
|
s.sessions[arg.UUID] = sess
|
||||||
|
return sess, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) GetSession(_ context.Context, uuid string) (repository.Session, error) {
|
||||||
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
sess, ok := s.sessions[uuid]
|
||||||
|
if !ok {
|
||||||
|
return repository.Session{}, repository.ErrNotFound
|
||||||
|
}
|
||||||
|
return sess, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) UpdateSession(_ context.Context, arg repository.UpdateSessionParams) (repository.Session, error) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
sess, ok := s.sessions[arg.UUID]
|
||||||
|
if !ok {
|
||||||
|
return repository.Session{}, repository.ErrNotFound
|
||||||
|
}
|
||||||
|
sess.Username = arg.Username
|
||||||
|
sess.Email = arg.Email
|
||||||
|
sess.Name = arg.Name
|
||||||
|
sess.Provider = arg.Provider
|
||||||
|
sess.TotpPending = arg.TotpPending
|
||||||
|
sess.OAuthGroups = arg.OAuthGroups
|
||||||
|
sess.Expiry = arg.Expiry
|
||||||
|
sess.OAuthName = arg.OAuthName
|
||||||
|
sess.OAuthSub = arg.OAuthSub
|
||||||
|
s.sessions[arg.UUID] = sess
|
||||||
|
return sess, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) DeleteSession(_ context.Context, uuid string) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
delete(s.sessions, uuid)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) DeleteExpiredSessions(_ context.Context, expiry int64) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
for k, v := range s.sessions {
|
||||||
|
if v.Expiry < expiry {
|
||||||
|
delete(s.sessions, k)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,27 @@
|
|||||||
|
// Package memory provides an in-memory implementation of repository.Store for use in tests.
|
||||||
|
package memory
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Store is a thread-safe in-memory implementation of repository.Store.
|
||||||
|
type Store struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
sessions map[string]repository.Session
|
||||||
|
oidcCodes map[string]repository.OidcCode
|
||||||
|
oidcTokens map[string]repository.OidcToken
|
||||||
|
oidcUsers map[string]repository.OidcUserinfo
|
||||||
|
}
|
||||||
|
|
||||||
|
// New returns a new empty in-memory Store.
|
||||||
|
func New() repository.Store {
|
||||||
|
return &Store{
|
||||||
|
sessions: make(map[string]repository.Session),
|
||||||
|
oidcCodes: make(map[string]repository.OidcCode),
|
||||||
|
oidcTokens: make(map[string]repository.OidcToken),
|
||||||
|
oidcUsers: make(map[string]repository.OidcUserinfo),
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,9 +1,22 @@
|
|||||||
// Code generated by sqlc. DO NOT EDIT.
|
|
||||||
// versions:
|
|
||||||
// sqlc v1.30.0
|
|
||||||
|
|
||||||
package repository
|
package repository
|
||||||
|
|
||||||
|
// Shared model and parameter types for all storage drivers.
|
||||||
|
// sqlc-generated driver packages use these via the conversion layer in their store.go.
|
||||||
|
|
||||||
|
type Session struct {
|
||||||
|
UUID string
|
||||||
|
Username string
|
||||||
|
Email string
|
||||||
|
Name string
|
||||||
|
Provider string
|
||||||
|
TotpPending bool
|
||||||
|
OAuthGroups string
|
||||||
|
Expiry int64
|
||||||
|
CreatedAt int64
|
||||||
|
OAuthName string
|
||||||
|
OAuthSub string
|
||||||
|
}
|
||||||
|
|
||||||
type OidcCode struct {
|
type OidcCode struct {
|
||||||
Sub string
|
Sub string
|
||||||
CodeHash string
|
CodeHash string
|
||||||
@@ -49,7 +62,7 @@ type OidcUserinfo struct {
|
|||||||
Address string
|
Address string
|
||||||
}
|
}
|
||||||
|
|
||||||
type Session struct {
|
type CreateSessionParams struct {
|
||||||
UUID string
|
UUID string
|
||||||
Username string
|
Username string
|
||||||
Email string
|
Email string
|
||||||
@@ -62,3 +75,74 @@ type Session struct {
|
|||||||
OAuthName string
|
OAuthName string
|
||||||
OAuthSub string
|
OAuthSub string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type UpdateSessionParams struct {
|
||||||
|
Username string
|
||||||
|
Email string
|
||||||
|
Name string
|
||||||
|
Provider string
|
||||||
|
TotpPending bool
|
||||||
|
OAuthGroups string
|
||||||
|
Expiry int64
|
||||||
|
OAuthName string
|
||||||
|
OAuthSub string
|
||||||
|
UUID string
|
||||||
|
}
|
||||||
|
|
||||||
|
type CreateOidcCodeParams struct {
|
||||||
|
Sub string
|
||||||
|
CodeHash string
|
||||||
|
Scope string
|
||||||
|
RedirectURI string
|
||||||
|
ClientID string
|
||||||
|
ExpiresAt int64
|
||||||
|
Nonce string
|
||||||
|
CodeChallenge string
|
||||||
|
}
|
||||||
|
|
||||||
|
type CreateOidcTokenParams struct {
|
||||||
|
Sub string
|
||||||
|
AccessTokenHash string
|
||||||
|
RefreshTokenHash string
|
||||||
|
Scope string
|
||||||
|
ClientID string
|
||||||
|
TokenExpiresAt int64
|
||||||
|
RefreshTokenExpiresAt int64
|
||||||
|
CodeHash string
|
||||||
|
Nonce string
|
||||||
|
}
|
||||||
|
|
||||||
|
type UpdateOidcTokenByRefreshTokenParams struct {
|
||||||
|
AccessTokenHash string
|
||||||
|
RefreshTokenHash string
|
||||||
|
TokenExpiresAt int64
|
||||||
|
RefreshTokenExpiresAt int64
|
||||||
|
RefreshTokenHash_2 string
|
||||||
|
}
|
||||||
|
|
||||||
|
type DeleteExpiredOidcTokensParams struct {
|
||||||
|
TokenExpiresAt int64
|
||||||
|
RefreshTokenExpiresAt int64
|
||||||
|
}
|
||||||
|
|
||||||
|
type CreateOidcUserInfoParams struct {
|
||||||
|
Sub string
|
||||||
|
Name string
|
||||||
|
PreferredUsername string
|
||||||
|
Email string
|
||||||
|
Groups string
|
||||||
|
UpdatedAt int64
|
||||||
|
GivenName string
|
||||||
|
FamilyName string
|
||||||
|
MiddleName string
|
||||||
|
Nickname string
|
||||||
|
Profile string
|
||||||
|
Picture string
|
||||||
|
Website string
|
||||||
|
Gender string
|
||||||
|
Birthdate string
|
||||||
|
Zoneinfo string
|
||||||
|
Locale string
|
||||||
|
PhoneNumber string
|
||||||
|
Address string
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
// Code generated by sqlc. DO NOT EDIT.
|
// Code generated by sqlc. DO NOT EDIT.
|
||||||
// versions:
|
// versions:
|
||||||
// sqlc v1.30.0
|
// sqlc v1.31.1
|
||||||
|
|
||||||
package repository
|
package sqlite
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@@ -0,0 +1,3 @@
|
|||||||
|
package sqlite
|
||||||
|
|
||||||
|
//go:generate go run github.com/tinyauthapp/tinyauth/gen/sqlc-wrapper -pkg github.com/tinyauthapp/tinyauth/internal/repository/sqlite
|
||||||
@@ -0,0 +1,64 @@
|
|||||||
|
// Code generated by sqlc. DO NOT EDIT.
|
||||||
|
// versions:
|
||||||
|
// sqlc v1.31.1
|
||||||
|
|
||||||
|
package sqlite
|
||||||
|
|
||||||
|
type OidcCode struct {
|
||||||
|
Sub string
|
||||||
|
CodeHash string
|
||||||
|
Scope string
|
||||||
|
RedirectURI string
|
||||||
|
ClientID string
|
||||||
|
ExpiresAt int64
|
||||||
|
Nonce string
|
||||||
|
CodeChallenge string
|
||||||
|
}
|
||||||
|
|
||||||
|
type OidcToken struct {
|
||||||
|
Sub string
|
||||||
|
AccessTokenHash string
|
||||||
|
RefreshTokenHash string
|
||||||
|
CodeHash string
|
||||||
|
Scope string
|
||||||
|
ClientID string
|
||||||
|
TokenExpiresAt int64
|
||||||
|
RefreshTokenExpiresAt int64
|
||||||
|
Nonce string
|
||||||
|
}
|
||||||
|
|
||||||
|
type OidcUserinfo struct {
|
||||||
|
Sub string
|
||||||
|
Name string
|
||||||
|
PreferredUsername string
|
||||||
|
Email string
|
||||||
|
Groups string
|
||||||
|
UpdatedAt int64
|
||||||
|
GivenName string
|
||||||
|
FamilyName string
|
||||||
|
MiddleName string
|
||||||
|
Nickname string
|
||||||
|
Profile string
|
||||||
|
Picture string
|
||||||
|
Website string
|
||||||
|
Gender string
|
||||||
|
Birthdate string
|
||||||
|
Zoneinfo string
|
||||||
|
Locale string
|
||||||
|
PhoneNumber string
|
||||||
|
Address string
|
||||||
|
}
|
||||||
|
|
||||||
|
type Session struct {
|
||||||
|
UUID string
|
||||||
|
Username string
|
||||||
|
Email string
|
||||||
|
Name string
|
||||||
|
Provider string
|
||||||
|
TotpPending bool
|
||||||
|
OAuthGroups string
|
||||||
|
Expiry int64
|
||||||
|
CreatedAt int64
|
||||||
|
OAuthName string
|
||||||
|
OAuthSub string
|
||||||
|
}
|
||||||
+2
-2
@@ -1,9 +1,9 @@
|
|||||||
// Code generated by sqlc. DO NOT EDIT.
|
// Code generated by sqlc. DO NOT EDIT.
|
||||||
// versions:
|
// versions:
|
||||||
// sqlc v1.30.0
|
// sqlc v1.31.1
|
||||||
// source: oidc_queries.sql
|
// source: oidc_queries.sql
|
||||||
|
|
||||||
package repository
|
package sqlite
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
+2
-2
@@ -1,9 +1,9 @@
|
|||||||
// Code generated by sqlc. DO NOT EDIT.
|
// Code generated by sqlc. DO NOT EDIT.
|
||||||
// versions:
|
// versions:
|
||||||
// sqlc v1.30.0
|
// sqlc v1.31.1
|
||||||
// source: session_queries.sql
|
// source: session_queries.sql
|
||||||
|
|
||||||
package repository
|
package sqlite
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@@ -0,0 +1,209 @@
|
|||||||
|
// Code generated by cmd/gen/sqlc-wrapper. DO NOT EDIT.
|
||||||
|
package sqlite
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Store wraps *Queries and implements repository.Store.
|
||||||
|
type Store struct {
|
||||||
|
q *Queries
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewStore wraps a *Queries to satisfy repository.Store.
|
||||||
|
func NewStore(q *Queries) repository.Store {
|
||||||
|
return &Store{q: q}
|
||||||
|
}
|
||||||
|
|
||||||
|
var errorMap = map[error]error{
|
||||||
|
sql.ErrNoRows: repository.ErrNotFound,
|
||||||
|
}
|
||||||
|
|
||||||
|
func mapErr(err error) error {
|
||||||
|
for from, to := range errorMap {
|
||||||
|
if errors.Is(err, from) {
|
||||||
|
return to
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) CreateOidcCode(ctx context.Context, arg repository.CreateOidcCodeParams) (repository.OidcCode, error) {
|
||||||
|
r, err := s.q.CreateOidcCode(ctx, CreateOidcCodeParams(arg))
|
||||||
|
if err != nil {
|
||||||
|
return repository.OidcCode{}, mapErr(err)
|
||||||
|
}
|
||||||
|
return repository.OidcCode(r), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) CreateOidcToken(ctx context.Context, arg repository.CreateOidcTokenParams) (repository.OidcToken, error) {
|
||||||
|
r, err := s.q.CreateOidcToken(ctx, CreateOidcTokenParams(arg))
|
||||||
|
if err != nil {
|
||||||
|
return repository.OidcToken{}, mapErr(err)
|
||||||
|
}
|
||||||
|
return repository.OidcToken(r), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) CreateOidcUserInfo(ctx context.Context, arg repository.CreateOidcUserInfoParams) (repository.OidcUserinfo, error) {
|
||||||
|
r, err := s.q.CreateOidcUserInfo(ctx, CreateOidcUserInfoParams(arg))
|
||||||
|
if err != nil {
|
||||||
|
return repository.OidcUserinfo{}, mapErr(err)
|
||||||
|
}
|
||||||
|
return repository.OidcUserinfo(r), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) CreateSession(ctx context.Context, arg repository.CreateSessionParams) (repository.Session, error) {
|
||||||
|
r, err := s.q.CreateSession(ctx, CreateSessionParams(arg))
|
||||||
|
if err != nil {
|
||||||
|
return repository.Session{}, mapErr(err)
|
||||||
|
}
|
||||||
|
return repository.Session(r), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) DeleteExpiredOidcCodes(ctx context.Context, expiresAt int64) ([]repository.OidcCode, error) {
|
||||||
|
rows, err := s.q.DeleteExpiredOidcCodes(ctx, expiresAt)
|
||||||
|
if err != nil {
|
||||||
|
return nil, mapErr(err)
|
||||||
|
}
|
||||||
|
out := make([]repository.OidcCode, len(rows))
|
||||||
|
for i, row := range rows {
|
||||||
|
out[i] = repository.OidcCode(row)
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) DeleteExpiredOidcTokens(ctx context.Context, arg repository.DeleteExpiredOidcTokensParams) ([]repository.OidcToken, error) {
|
||||||
|
rows, err := s.q.DeleteExpiredOidcTokens(ctx, DeleteExpiredOidcTokensParams(arg))
|
||||||
|
if err != nil {
|
||||||
|
return nil, mapErr(err)
|
||||||
|
}
|
||||||
|
out := make([]repository.OidcToken, len(rows))
|
||||||
|
for i, row := range rows {
|
||||||
|
out[i] = repository.OidcToken(row)
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) DeleteExpiredSessions(ctx context.Context, expiry int64) error {
|
||||||
|
return mapErr(s.q.DeleteExpiredSessions(ctx, expiry))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) DeleteOidcCode(ctx context.Context, codeHash string) error {
|
||||||
|
return mapErr(s.q.DeleteOidcCode(ctx, codeHash))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) DeleteOidcCodeBySub(ctx context.Context, sub string) error {
|
||||||
|
return mapErr(s.q.DeleteOidcCodeBySub(ctx, sub))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) DeleteOidcToken(ctx context.Context, accessTokenHash string) error {
|
||||||
|
return mapErr(s.q.DeleteOidcToken(ctx, accessTokenHash))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) DeleteOidcTokenByCodeHash(ctx context.Context, codeHash string) error {
|
||||||
|
return mapErr(s.q.DeleteOidcTokenByCodeHash(ctx, codeHash))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) DeleteOidcTokenBySub(ctx context.Context, sub string) error {
|
||||||
|
return mapErr(s.q.DeleteOidcTokenBySub(ctx, sub))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) DeleteOidcUserInfo(ctx context.Context, sub string) error {
|
||||||
|
return mapErr(s.q.DeleteOidcUserInfo(ctx, sub))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) DeleteSession(ctx context.Context, uuid string) error {
|
||||||
|
return mapErr(s.q.DeleteSession(ctx, uuid))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) GetOidcCode(ctx context.Context, codeHash string) (repository.OidcCode, error) {
|
||||||
|
r, err := s.q.GetOidcCode(ctx, codeHash)
|
||||||
|
if err != nil {
|
||||||
|
return repository.OidcCode{}, mapErr(err)
|
||||||
|
}
|
||||||
|
return repository.OidcCode(r), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) GetOidcCodeBySub(ctx context.Context, sub string) (repository.OidcCode, error) {
|
||||||
|
r, err := s.q.GetOidcCodeBySub(ctx, sub)
|
||||||
|
if err != nil {
|
||||||
|
return repository.OidcCode{}, mapErr(err)
|
||||||
|
}
|
||||||
|
return repository.OidcCode(r), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) GetOidcCodeBySubUnsafe(ctx context.Context, sub string) (repository.OidcCode, error) {
|
||||||
|
r, err := s.q.GetOidcCodeBySubUnsafe(ctx, sub)
|
||||||
|
if err != nil {
|
||||||
|
return repository.OidcCode{}, mapErr(err)
|
||||||
|
}
|
||||||
|
return repository.OidcCode(r), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) GetOidcCodeUnsafe(ctx context.Context, codeHash string) (repository.OidcCode, error) {
|
||||||
|
r, err := s.q.GetOidcCodeUnsafe(ctx, codeHash)
|
||||||
|
if err != nil {
|
||||||
|
return repository.OidcCode{}, mapErr(err)
|
||||||
|
}
|
||||||
|
return repository.OidcCode(r), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) GetOidcToken(ctx context.Context, accessTokenHash string) (repository.OidcToken, error) {
|
||||||
|
r, err := s.q.GetOidcToken(ctx, accessTokenHash)
|
||||||
|
if err != nil {
|
||||||
|
return repository.OidcToken{}, mapErr(err)
|
||||||
|
}
|
||||||
|
return repository.OidcToken(r), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) GetOidcTokenByRefreshToken(ctx context.Context, refreshTokenHash string) (repository.OidcToken, error) {
|
||||||
|
r, err := s.q.GetOidcTokenByRefreshToken(ctx, refreshTokenHash)
|
||||||
|
if err != nil {
|
||||||
|
return repository.OidcToken{}, mapErr(err)
|
||||||
|
}
|
||||||
|
return repository.OidcToken(r), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) GetOidcTokenBySub(ctx context.Context, sub string) (repository.OidcToken, error) {
|
||||||
|
r, err := s.q.GetOidcTokenBySub(ctx, sub)
|
||||||
|
if err != nil {
|
||||||
|
return repository.OidcToken{}, mapErr(err)
|
||||||
|
}
|
||||||
|
return repository.OidcToken(r), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) GetOidcUserInfo(ctx context.Context, sub string) (repository.OidcUserinfo, error) {
|
||||||
|
r, err := s.q.GetOidcUserInfo(ctx, sub)
|
||||||
|
if err != nil {
|
||||||
|
return repository.OidcUserinfo{}, mapErr(err)
|
||||||
|
}
|
||||||
|
return repository.OidcUserinfo(r), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) GetSession(ctx context.Context, uuid string) (repository.Session, error) {
|
||||||
|
r, err := s.q.GetSession(ctx, uuid)
|
||||||
|
if err != nil {
|
||||||
|
return repository.Session{}, mapErr(err)
|
||||||
|
}
|
||||||
|
return repository.Session(r), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) UpdateOidcTokenByRefreshToken(ctx context.Context, arg repository.UpdateOidcTokenByRefreshTokenParams) (repository.OidcToken, error) {
|
||||||
|
r, err := s.q.UpdateOidcTokenByRefreshToken(ctx, UpdateOidcTokenByRefreshTokenParams(arg))
|
||||||
|
if err != nil {
|
||||||
|
return repository.OidcToken{}, mapErr(err)
|
||||||
|
}
|
||||||
|
return repository.OidcToken(r), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) UpdateSession(ctx context.Context, arg repository.UpdateSessionParams) (repository.Session, error) {
|
||||||
|
r, err := s.q.UpdateSession(ctx, UpdateSessionParams(arg))
|
||||||
|
if err != nil {
|
||||||
|
return repository.Session{}, mapErr(err)
|
||||||
|
}
|
||||||
|
return repository.Session(r), nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,47 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ErrNotFound is returned by Store methods when the requested record does not exist.
|
||||||
|
var ErrNotFound = errors.New("not found")
|
||||||
|
|
||||||
|
// Store is the interface that all storage drivers must implement.
|
||||||
|
// The sqlc-generated *Queries struct satisfies this interface for SQLite.
|
||||||
|
// Future drivers (postgres, etc.) must return the shared types defined in this package.
|
||||||
|
type Store interface {
|
||||||
|
// Sessions
|
||||||
|
CreateSession(ctx context.Context, arg CreateSessionParams) (Session, error)
|
||||||
|
GetSession(ctx context.Context, uuid string) (Session, error)
|
||||||
|
UpdateSession(ctx context.Context, arg UpdateSessionParams) (Session, error)
|
||||||
|
DeleteSession(ctx context.Context, uuid string) error
|
||||||
|
DeleteExpiredSessions(ctx context.Context, expiry int64) error
|
||||||
|
|
||||||
|
// OIDC codes
|
||||||
|
CreateOidcCode(ctx context.Context, arg CreateOidcCodeParams) (OidcCode, error)
|
||||||
|
GetOidcCode(ctx context.Context, codeHash string) (OidcCode, error)
|
||||||
|
GetOidcCodeBySub(ctx context.Context, sub string) (OidcCode, error)
|
||||||
|
GetOidcCodeUnsafe(ctx context.Context, codeHash string) (OidcCode, error)
|
||||||
|
GetOidcCodeBySubUnsafe(ctx context.Context, sub string) (OidcCode, error)
|
||||||
|
DeleteOidcCode(ctx context.Context, codeHash string) error
|
||||||
|
DeleteOidcCodeBySub(ctx context.Context, sub string) error
|
||||||
|
DeleteExpiredOidcCodes(ctx context.Context, expiresAt int64) ([]OidcCode, error)
|
||||||
|
|
||||||
|
// OIDC tokens
|
||||||
|
CreateOidcToken(ctx context.Context, arg CreateOidcTokenParams) (OidcToken, error)
|
||||||
|
GetOidcToken(ctx context.Context, accessTokenHash string) (OidcToken, error)
|
||||||
|
GetOidcTokenByRefreshToken(ctx context.Context, refreshTokenHash string) (OidcToken, error)
|
||||||
|
GetOidcTokenBySub(ctx context.Context, sub string) (OidcToken, error)
|
||||||
|
UpdateOidcTokenByRefreshToken(ctx context.Context, arg UpdateOidcTokenByRefreshTokenParams) (OidcToken, error)
|
||||||
|
DeleteOidcToken(ctx context.Context, accessTokenHash string) error
|
||||||
|
DeleteOidcTokenBySub(ctx context.Context, sub string) error
|
||||||
|
DeleteOidcTokenByCodeHash(ctx context.Context, codeHash string) error
|
||||||
|
DeleteExpiredOidcTokens(ctx context.Context, arg DeleteExpiredOidcTokensParams) ([]OidcToken, error)
|
||||||
|
|
||||||
|
// OIDC userinfo
|
||||||
|
CreateOidcUserInfo(ctx context.Context, arg CreateOidcUserInfoParams) (OidcUserinfo, error)
|
||||||
|
GetOidcUserInfo(ctx context.Context, sub string) (OidcUserinfo, error)
|
||||||
|
DeleteOidcUserInfo(ctx context.Context, sub string) error
|
||||||
|
}
|
||||||
@@ -0,0 +1,249 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
type RuleName string
|
||||||
|
|
||||||
|
const (
|
||||||
|
RuleUserAllowed RuleName = "rule-user-allowed"
|
||||||
|
RuleOAuthGroup RuleName = "rule-oauth-group"
|
||||||
|
RuleLDAPGroup RuleName = "rule-ldap-group"
|
||||||
|
RuleAuthEnabled RuleName = "rule-auth-enabled"
|
||||||
|
RuleIPAllowed RuleName = "rule-ip-allowed"
|
||||||
|
RuleIPBypassed RuleName = "rule-ip-bypassed"
|
||||||
|
)
|
||||||
|
|
||||||
|
type UserAllowedRule struct {
|
||||||
|
Log *logger.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rule *UserAllowedRule) Evaluate(ctx *ACLContext) Effect {
|
||||||
|
if ctx.ACLs == nil || ctx.UserContext == nil {
|
||||||
|
return EffectAbstain
|
||||||
|
}
|
||||||
|
|
||||||
|
if ctx.UserContext.Provider == model.ProviderOAuth {
|
||||||
|
rule.Log.App.Debug().Msg("User is an OAuth user, checking OAuth whitelist")
|
||||||
|
match, err := utils.CheckFilter(ctx.ACLs.OAuth.Whitelist, ctx.UserContext.OAuth.Email)
|
||||||
|
if err != nil {
|
||||||
|
rule.Log.App.Warn().Err(err).Str("item", ctx.UserContext.OAuth.Email).Msg("Invalid entry in OAuth whitelist")
|
||||||
|
return EffectAbstain
|
||||||
|
}
|
||||||
|
if match {
|
||||||
|
rule.Log.App.Debug().Str("email", ctx.UserContext.OAuth.Email).Msg("User is in OAuth whitelist, allowing access")
|
||||||
|
return EffectAllow
|
||||||
|
}
|
||||||
|
return EffectDeny
|
||||||
|
}
|
||||||
|
|
||||||
|
if ctx.ACLs.Users.Block != "" {
|
||||||
|
rule.Log.App.Debug().Msg("Checking users block list")
|
||||||
|
match, err := utils.CheckFilter(ctx.ACLs.Users.Block, ctx.UserContext.GetUsername())
|
||||||
|
if err != nil {
|
||||||
|
rule.Log.App.Warn().Err(err).Str("item", ctx.UserContext.GetUsername()).Msg("Invalid entry in users block list")
|
||||||
|
return EffectAbstain
|
||||||
|
}
|
||||||
|
if match {
|
||||||
|
rule.Log.App.Debug().Str("username", ctx.UserContext.GetUsername()).Msg("User is in users block list, denying access")
|
||||||
|
return EffectDeny
|
||||||
|
}
|
||||||
|
return EffectAllow
|
||||||
|
}
|
||||||
|
|
||||||
|
rule.Log.App.Debug().Msg("Checking users allow list")
|
||||||
|
|
||||||
|
match, err := utils.CheckFilter(ctx.ACLs.Users.Allow, ctx.UserContext.GetUsername())
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
rule.Log.App.Warn().Err(err).Str("item", ctx.UserContext.GetUsername()).Msg("Invalid entry in users allow list")
|
||||||
|
return EffectAbstain
|
||||||
|
}
|
||||||
|
|
||||||
|
if match {
|
||||||
|
rule.Log.App.Debug().Str("username", ctx.UserContext.GetUsername()).Msg("User is in users allow list, allowing access")
|
||||||
|
return EffectAllow
|
||||||
|
}
|
||||||
|
|
||||||
|
rule.Log.App.Debug().Str("username", ctx.UserContext.GetUsername()).Msg("User is not in users allow list, denying access")
|
||||||
|
return EffectDeny
|
||||||
|
}
|
||||||
|
|
||||||
|
type OAuthGroupRule struct {
|
||||||
|
Log *logger.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rule *OAuthGroupRule) Evaluate(ctx *ACLContext) Effect {
|
||||||
|
if ctx.ACLs == nil || ctx.UserContext == nil {
|
||||||
|
return EffectAbstain
|
||||||
|
}
|
||||||
|
|
||||||
|
if !ctx.UserContext.IsOAuth() {
|
||||||
|
rule.Log.App.Debug().Msg("User is not an OAuth user, skipping OAuth group check")
|
||||||
|
return EffectAbstain
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := model.OverrideProviders[ctx.UserContext.OAuth.ID]; ok {
|
||||||
|
rule.Log.App.Debug().Str("provider", ctx.UserContext.OAuth.ID).Msg("Provider override detected, skipping group check")
|
||||||
|
return EffectAllow
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, group := range ctx.UserContext.OAuth.Groups {
|
||||||
|
match, err := utils.CheckFilter(ctx.ACLs.OAuth.Groups, strings.TrimSpace(group))
|
||||||
|
if err != nil {
|
||||||
|
return EffectAbstain
|
||||||
|
}
|
||||||
|
if match {
|
||||||
|
rule.Log.App.Trace().Str("group", group).Str("required", ctx.ACLs.OAuth.Groups).Msg("User group matched, allowing access")
|
||||||
|
return EffectAllow
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
rule.Log.App.Debug().Msg("No groups matched")
|
||||||
|
return EffectDeny
|
||||||
|
}
|
||||||
|
|
||||||
|
type LDAPGroupRule struct {
|
||||||
|
Log *logger.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rule *LDAPGroupRule) Evaluate(ctx *ACLContext) Effect {
|
||||||
|
if ctx == nil || ctx.UserContext == nil {
|
||||||
|
return EffectAbstain
|
||||||
|
}
|
||||||
|
|
||||||
|
if !ctx.UserContext.IsLDAP() {
|
||||||
|
rule.Log.App.Debug().Msg("User is not an LDAP user, skipping LDAP group check")
|
||||||
|
return EffectAbstain
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, group := range ctx.UserContext.LDAP.Groups {
|
||||||
|
match, err := utils.CheckFilter(ctx.ACLs.LDAP.Groups, strings.TrimSpace(group))
|
||||||
|
if err != nil {
|
||||||
|
return EffectAbstain
|
||||||
|
}
|
||||||
|
if match {
|
||||||
|
rule.Log.App.Trace().Str("group", group).Str("required", ctx.ACLs.LDAP.Groups).Msg("User group matched, allowing access")
|
||||||
|
return EffectAllow
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
rule.Log.App.Debug().Msg("No groups matched")
|
||||||
|
return EffectDeny
|
||||||
|
}
|
||||||
|
|
||||||
|
type AuthEnabledRule struct {
|
||||||
|
Log *logger.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rule *AuthEnabledRule) Evaluate(ctx *ACLContext) Effect {
|
||||||
|
if ctx.ACLs == nil {
|
||||||
|
return EffectDeny
|
||||||
|
}
|
||||||
|
|
||||||
|
if ctx.ACLs.Path.Block != "" {
|
||||||
|
regex, err := regexp.Compile(ctx.ACLs.Path.Block)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
rule.Log.App.Error().Err(err).Msg("Failed to compile block regex")
|
||||||
|
return EffectDeny
|
||||||
|
}
|
||||||
|
|
||||||
|
if !regex.MatchString(ctx.Path) {
|
||||||
|
return EffectAllow
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if ctx.ACLs.Path.Allow != "" {
|
||||||
|
regex, err := regexp.Compile(ctx.ACLs.Path.Allow)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
rule.Log.App.Error().Err(err).Msg("Failed to compile allow regex")
|
||||||
|
return EffectDeny
|
||||||
|
}
|
||||||
|
|
||||||
|
if regex.MatchString(ctx.Path) {
|
||||||
|
return EffectAllow
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return EffectDeny
|
||||||
|
}
|
||||||
|
|
||||||
|
type IPAllowedRule struct {
|
||||||
|
Log *logger.Logger
|
||||||
|
Config model.Config
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rule *IPAllowedRule) Evaluate(ctx *ACLContext) Effect {
|
||||||
|
if ctx.ACLs == nil {
|
||||||
|
return EffectAbstain
|
||||||
|
}
|
||||||
|
|
||||||
|
// Merge the global and app IP filter
|
||||||
|
blockedIps := append(ctx.ACLs.IP.Block, rule.Config.Auth.IP.Block...)
|
||||||
|
allowedIPs := append(ctx.ACLs.IP.Allow, rule.Config.Auth.IP.Allow...)
|
||||||
|
|
||||||
|
for _, blocked := range blockedIps {
|
||||||
|
match, err := utils.CheckIPFilter(blocked, ctx.IP.String())
|
||||||
|
if err != nil {
|
||||||
|
rule.Log.App.Warn().Err(err).Str("item", blocked).Msg("Invalid IP/CIDR in block list")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if match {
|
||||||
|
rule.Log.App.Debug().Str("ip", ctx.IP.String()).Str("item", blocked).Msg("IP is in block list, denying access")
|
||||||
|
return EffectDeny
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, allowed := range allowedIPs {
|
||||||
|
match, err := utils.CheckIPFilter(allowed, ctx.IP.String())
|
||||||
|
if err != nil {
|
||||||
|
rule.Log.App.Warn().Err(err).Str("item", allowed).Msg("Invalid IP/CIDR in allow list")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if match {
|
||||||
|
rule.Log.App.Debug().Str("ip", ctx.IP.String()).Str("item", allowed).Msg("IP is in allow list, allowing access")
|
||||||
|
return EffectAllow
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(allowedIPs) > 0 {
|
||||||
|
rule.Log.App.Debug().Str("ip", ctx.IP.String()).Msg("IP not in allow list, denying access")
|
||||||
|
return EffectDeny
|
||||||
|
}
|
||||||
|
|
||||||
|
rule.Log.App.Debug().Str("ip", ctx.IP.String()).Msg("IP not in block or allow list, allowing access")
|
||||||
|
return EffectAllow
|
||||||
|
}
|
||||||
|
|
||||||
|
type IPBypassedRule struct {
|
||||||
|
Log *logger.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rule *IPBypassedRule) Evaluate(ctx *ACLContext) Effect {
|
||||||
|
if ctx.ACLs == nil {
|
||||||
|
return EffectDeny
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, bypassed := range ctx.ACLs.IP.Bypass {
|
||||||
|
match, err := utils.CheckIPFilter(bypassed, ctx.IP.String())
|
||||||
|
if err != nil {
|
||||||
|
rule.Log.App.Warn().Err(err).Str("item", bypassed).Msg("Invalid IP/CIDR in bypass list")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if match {
|
||||||
|
rule.Log.App.Debug().Str("ip", ctx.IP.String()).Str("item", bypassed).Msg("IP is in bypass list, skipping authentication")
|
||||||
|
return EffectAllow
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
rule.Log.App.Debug().Str("ip", ctx.IP.String()).Msg("IP not in bypass list, proceeding with authentication")
|
||||||
|
return EffectDeny
|
||||||
|
}
|
||||||
@@ -0,0 +1,732 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestUserAllowedRule(t *testing.T) {
|
||||||
|
log := logger.NewLogger().WithTestConfig()
|
||||||
|
log.Init()
|
||||||
|
|
||||||
|
rule := &UserAllowedRule{Log: log}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
ctx *ACLContext
|
||||||
|
expected Effect
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "abstains when ACLs are nil",
|
||||||
|
ctx: &ACLContext{
|
||||||
|
ACLs: nil,
|
||||||
|
UserContext: &model.UserContext{
|
||||||
|
Provider: model.ProviderLocal,
|
||||||
|
Local: &model.LocalContext{
|
||||||
|
BaseContext: model.BaseContext{Username: "alice"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: EffectAbstain,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "abstains when user context is nil",
|
||||||
|
ctx: &ACLContext{
|
||||||
|
ACLs: &model.App{
|
||||||
|
OAuth: model.AppOAuth{Whitelist: "alice"},
|
||||||
|
},
|
||||||
|
UserContext: nil,
|
||||||
|
},
|
||||||
|
expected: EffectAbstain,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "allows OAuth user when email matches whitelist",
|
||||||
|
ctx: &ACLContext{
|
||||||
|
ACLs: &model.App{
|
||||||
|
OAuth: model.AppOAuth{Whitelist: "allowed@example.com"},
|
||||||
|
},
|
||||||
|
UserContext: &model.UserContext{
|
||||||
|
Provider: model.ProviderOAuth,
|
||||||
|
OAuth: &model.OAuthContext{
|
||||||
|
BaseContext: model.BaseContext{
|
||||||
|
Username: "different-username",
|
||||||
|
Email: "allowed@example.com",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: EffectAllow,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "denies OAuth user when email does not match whitelist",
|
||||||
|
ctx: &ACLContext{
|
||||||
|
ACLs: &model.App{
|
||||||
|
OAuth: model.AppOAuth{Whitelist: "allowed@example.com"},
|
||||||
|
},
|
||||||
|
UserContext: &model.UserContext{
|
||||||
|
Provider: model.ProviderOAuth,
|
||||||
|
OAuth: &model.OAuthContext{
|
||||||
|
BaseContext: model.BaseContext{Email: "denied@example.com"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: EffectDeny,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "abstains for OAuth user when whitelist filter is invalid",
|
||||||
|
ctx: &ACLContext{
|
||||||
|
ACLs: &model.App{
|
||||||
|
OAuth: model.AppOAuth{Whitelist: "/[/"},
|
||||||
|
},
|
||||||
|
UserContext: &model.UserContext{
|
||||||
|
Provider: model.ProviderOAuth,
|
||||||
|
OAuth: &model.OAuthContext{
|
||||||
|
BaseContext: model.BaseContext{Email: "allowed@example.com"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: EffectAbstain,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "denies local user when username matches block list",
|
||||||
|
ctx: &ACLContext{
|
||||||
|
ACLs: &model.App{
|
||||||
|
Users: model.AppUsers{Block: "alice,bob"},
|
||||||
|
},
|
||||||
|
UserContext: &model.UserContext{
|
||||||
|
Provider: model.ProviderLocal,
|
||||||
|
Local: &model.LocalContext{
|
||||||
|
BaseContext: model.BaseContext{Username: "alice"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: EffectDeny,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "allows local user when username does not match block list",
|
||||||
|
ctx: &ACLContext{
|
||||||
|
ACLs: &model.App{
|
||||||
|
Users: model.AppUsers{Block: "alice,bob"},
|
||||||
|
},
|
||||||
|
UserContext: &model.UserContext{
|
||||||
|
Provider: model.ProviderLocal,
|
||||||
|
Local: &model.LocalContext{
|
||||||
|
BaseContext: model.BaseContext{Username: "charlie"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: EffectAllow,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "abstains when block list filter is invalid",
|
||||||
|
ctx: &ACLContext{
|
||||||
|
ACLs: &model.App{
|
||||||
|
Users: model.AppUsers{Block: "/[/"},
|
||||||
|
},
|
||||||
|
UserContext: &model.UserContext{
|
||||||
|
Provider: model.ProviderLocal,
|
||||||
|
Local: &model.LocalContext{
|
||||||
|
BaseContext: model.BaseContext{Username: "alice"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: EffectAbstain,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "allows local user when username matches allow list",
|
||||||
|
ctx: &ACLContext{
|
||||||
|
ACLs: &model.App{
|
||||||
|
Users: model.AppUsers{Allow: "alice,bob"},
|
||||||
|
},
|
||||||
|
UserContext: &model.UserContext{
|
||||||
|
Provider: model.ProviderLocal,
|
||||||
|
Local: &model.LocalContext{
|
||||||
|
BaseContext: model.BaseContext{Username: "alice"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: EffectAllow,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "denies local user when username does not match allow list",
|
||||||
|
ctx: &ACLContext{
|
||||||
|
ACLs: &model.App{
|
||||||
|
Users: model.AppUsers{Allow: "alice,bob"},
|
||||||
|
},
|
||||||
|
UserContext: &model.UserContext{
|
||||||
|
Provider: model.ProviderLocal,
|
||||||
|
Local: &model.LocalContext{
|
||||||
|
BaseContext: model.BaseContext{Username: "charlie"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: EffectDeny,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "abstains when allow list filter is invalid",
|
||||||
|
ctx: &ACLContext{
|
||||||
|
ACLs: &model.App{
|
||||||
|
Users: model.AppUsers{Allow: "/[/"},
|
||||||
|
},
|
||||||
|
UserContext: &model.UserContext{
|
||||||
|
Provider: model.ProviderLocal,
|
||||||
|
Local: &model.LocalContext{
|
||||||
|
BaseContext: model.BaseContext{Username: "alice"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: EffectAbstain,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
assert.Equal(t, tt.expected, rule.Evaluate(tt.ctx))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOAuthGroupRule(t *testing.T) {
|
||||||
|
log := logger.NewLogger().WithTestConfig()
|
||||||
|
log.Init()
|
||||||
|
|
||||||
|
rule := &OAuthGroupRule{Log: log}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
ctx *ACLContext
|
||||||
|
expected Effect
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "abstains when ACLs are nil",
|
||||||
|
ctx: &ACLContext{
|
||||||
|
ACLs: nil,
|
||||||
|
UserContext: &model.UserContext{
|
||||||
|
Provider: model.ProviderOAuth,
|
||||||
|
OAuth: &model.OAuthContext{
|
||||||
|
Groups: []string{"admins"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: EffectAbstain,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "abstains when user context is nil",
|
||||||
|
ctx: &ACLContext{
|
||||||
|
ACLs: &model.App{
|
||||||
|
OAuth: model.AppOAuth{Whitelist: "alice"},
|
||||||
|
},
|
||||||
|
UserContext: nil,
|
||||||
|
},
|
||||||
|
expected: EffectAbstain,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "abstains when user is not OAuth",
|
||||||
|
ctx: &ACLContext{
|
||||||
|
ACLs: &model.App{
|
||||||
|
OAuth: model.AppOAuth{Groups: "admins"},
|
||||||
|
},
|
||||||
|
UserContext: &model.UserContext{
|
||||||
|
Provider: model.ProviderLocal,
|
||||||
|
Local: &model.LocalContext{
|
||||||
|
BaseContext: model.BaseContext{Username: "alice"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: EffectAbstain,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "allows when provider is an override provider regardless of groups",
|
||||||
|
ctx: &ACLContext{
|
||||||
|
ACLs: &model.App{
|
||||||
|
OAuth: model.AppOAuth{Groups: "admins"},
|
||||||
|
},
|
||||||
|
UserContext: &model.UserContext{
|
||||||
|
Provider: model.ProviderOAuth,
|
||||||
|
OAuth: &model.OAuthContext{
|
||||||
|
ID: "google",
|
||||||
|
Groups: []string{"unrelated"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: EffectAllow,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "allows OAuth user when a group matches",
|
||||||
|
ctx: &ACLContext{
|
||||||
|
ACLs: &model.App{
|
||||||
|
OAuth: model.AppOAuth{Groups: "admins,users"},
|
||||||
|
},
|
||||||
|
UserContext: &model.UserContext{
|
||||||
|
Provider: model.ProviderOAuth,
|
||||||
|
OAuth: &model.OAuthContext{
|
||||||
|
ID: "custom",
|
||||||
|
Groups: []string{"users"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: EffectAllow,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "denies OAuth user when no group matches",
|
||||||
|
ctx: &ACLContext{
|
||||||
|
ACLs: &model.App{
|
||||||
|
OAuth: model.AppOAuth{Groups: "admins"},
|
||||||
|
},
|
||||||
|
UserContext: &model.UserContext{
|
||||||
|
Provider: model.ProviderOAuth,
|
||||||
|
OAuth: &model.OAuthContext{
|
||||||
|
ID: "custom",
|
||||||
|
Groups: []string{"users", "guests"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: EffectDeny,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "denies OAuth user when user has no groups",
|
||||||
|
ctx: &ACLContext{
|
||||||
|
ACLs: &model.App{
|
||||||
|
OAuth: model.AppOAuth{Groups: "admins"},
|
||||||
|
},
|
||||||
|
UserContext: &model.UserContext{
|
||||||
|
Provider: model.ProviderOAuth,
|
||||||
|
OAuth: &model.OAuthContext{
|
||||||
|
ID: "custom",
|
||||||
|
Groups: nil,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: EffectDeny,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "abstains when groups filter is invalid",
|
||||||
|
ctx: &ACLContext{
|
||||||
|
ACLs: &model.App{
|
||||||
|
OAuth: model.AppOAuth{Groups: "/[/"},
|
||||||
|
},
|
||||||
|
UserContext: &model.UserContext{
|
||||||
|
Provider: model.ProviderOAuth,
|
||||||
|
OAuth: &model.OAuthContext{
|
||||||
|
ID: "custom",
|
||||||
|
Groups: []string{"admins"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: EffectAbstain,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
assert.Equal(t, tt.expected, rule.Evaluate(tt.ctx))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLDAPGroupRule(t *testing.T) {
|
||||||
|
log := logger.NewLogger().WithTestConfig()
|
||||||
|
log.Init()
|
||||||
|
|
||||||
|
rule := &LDAPGroupRule{Log: log}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
ctx *ACLContext
|
||||||
|
expected Effect
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "abstains when context is nil",
|
||||||
|
ctx: nil,
|
||||||
|
expected: EffectAbstain,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "abstains when user context is nil",
|
||||||
|
ctx: &ACLContext{
|
||||||
|
ACLs: &model.App{
|
||||||
|
OAuth: model.AppOAuth{Whitelist: "alice"},
|
||||||
|
},
|
||||||
|
UserContext: nil,
|
||||||
|
},
|
||||||
|
expected: EffectAbstain,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "abstains when user is not LDAP",
|
||||||
|
ctx: &ACLContext{
|
||||||
|
ACLs: &model.App{
|
||||||
|
LDAP: model.AppLDAP{Groups: "admins"},
|
||||||
|
},
|
||||||
|
UserContext: &model.UserContext{
|
||||||
|
Provider: model.ProviderLocal,
|
||||||
|
Local: &model.LocalContext{
|
||||||
|
BaseContext: model.BaseContext{Username: "alice"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: EffectAbstain,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "allows LDAP user when a group matches",
|
||||||
|
ctx: &ACLContext{
|
||||||
|
ACLs: &model.App{
|
||||||
|
LDAP: model.AppLDAP{Groups: "admins,users"},
|
||||||
|
},
|
||||||
|
UserContext: &model.UserContext{
|
||||||
|
Provider: model.ProviderLDAP,
|
||||||
|
LDAP: &model.LDAPContext{
|
||||||
|
Groups: []string{"users"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: EffectAllow,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "denies LDAP user when no group matches",
|
||||||
|
ctx: &ACLContext{
|
||||||
|
ACLs: &model.App{
|
||||||
|
LDAP: model.AppLDAP{Groups: "admins"},
|
||||||
|
},
|
||||||
|
UserContext: &model.UserContext{
|
||||||
|
Provider: model.ProviderLDAP,
|
||||||
|
LDAP: &model.LDAPContext{
|
||||||
|
Groups: []string{"users", "guests"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: EffectDeny,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "denies LDAP user when user has no groups",
|
||||||
|
ctx: &ACLContext{
|
||||||
|
ACLs: &model.App{
|
||||||
|
LDAP: model.AppLDAP{Groups: "admins"},
|
||||||
|
},
|
||||||
|
UserContext: &model.UserContext{
|
||||||
|
Provider: model.ProviderLDAP,
|
||||||
|
LDAP: &model.LDAPContext{
|
||||||
|
Groups: nil,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: EffectDeny,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "abstains when groups filter is invalid",
|
||||||
|
ctx: &ACLContext{
|
||||||
|
ACLs: &model.App{
|
||||||
|
LDAP: model.AppLDAP{Groups: "/[/"},
|
||||||
|
},
|
||||||
|
UserContext: &model.UserContext{
|
||||||
|
Provider: model.ProviderLDAP,
|
||||||
|
LDAP: &model.LDAPContext{
|
||||||
|
Groups: []string{"admins"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: EffectAbstain,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
assert.Equal(t, tt.expected, rule.Evaluate(tt.ctx))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthEnabledRule(t *testing.T) {
|
||||||
|
log := logger.NewLogger().WithTestConfig()
|
||||||
|
log.Init()
|
||||||
|
|
||||||
|
rule := &AuthEnabledRule{Log: log}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
ctx *ACLContext
|
||||||
|
expected Effect
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "deny when ACLs are nil",
|
||||||
|
ctx: &ACLContext{
|
||||||
|
ACLs: nil,
|
||||||
|
Path: "/anything",
|
||||||
|
},
|
||||||
|
expected: EffectDeny,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "allows when path does not match block regex",
|
||||||
|
ctx: &ACLContext{
|
||||||
|
ACLs: &model.App{
|
||||||
|
Path: model.AppPath{Block: "^/admin"},
|
||||||
|
},
|
||||||
|
Path: "/public",
|
||||||
|
},
|
||||||
|
expected: EffectAllow,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "denies when path matches block regex and no allow regex",
|
||||||
|
ctx: &ACLContext{
|
||||||
|
ACLs: &model.App{
|
||||||
|
Path: model.AppPath{Block: "^/admin"},
|
||||||
|
},
|
||||||
|
Path: "/admin/users",
|
||||||
|
},
|
||||||
|
expected: EffectDeny,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "allows when path matches allow regex",
|
||||||
|
ctx: &ACLContext{
|
||||||
|
ACLs: &model.App{
|
||||||
|
Path: model.AppPath{Allow: "^/public"},
|
||||||
|
},
|
||||||
|
Path: "/public/index",
|
||||||
|
},
|
||||||
|
expected: EffectAllow,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "denies when path does not match allow regex",
|
||||||
|
ctx: &ACLContext{
|
||||||
|
ACLs: &model.App{
|
||||||
|
Path: model.AppPath{Allow: "^/public"},
|
||||||
|
},
|
||||||
|
Path: "/private",
|
||||||
|
},
|
||||||
|
expected: EffectDeny,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "allows when blocked path is also explicitly allowed",
|
||||||
|
ctx: &ACLContext{
|
||||||
|
ACLs: &model.App{
|
||||||
|
Path: model.AppPath{
|
||||||
|
Block: "^/admin",
|
||||||
|
Allow: "^/admin/public",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Path: "/admin/public/page",
|
||||||
|
},
|
||||||
|
expected: EffectAllow,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "denies when block regex fails to compile",
|
||||||
|
ctx: &ACLContext{
|
||||||
|
ACLs: &model.App{
|
||||||
|
Path: model.AppPath{Block: "[invalid"},
|
||||||
|
},
|
||||||
|
Path: "/anything",
|
||||||
|
},
|
||||||
|
expected: EffectDeny,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "denies when allow regex fails to compile",
|
||||||
|
ctx: &ACLContext{
|
||||||
|
ACLs: &model.App{
|
||||||
|
Path: model.AppPath{Allow: "[invalid"},
|
||||||
|
},
|
||||||
|
Path: "/anything",
|
||||||
|
},
|
||||||
|
expected: EffectDeny,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "denies when no path rules are configured",
|
||||||
|
ctx: &ACLContext{
|
||||||
|
ACLs: &model.App{},
|
||||||
|
Path: "/anything",
|
||||||
|
},
|
||||||
|
expected: EffectDeny,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
assert.Equal(t, tt.expected, rule.Evaluate(tt.ctx))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIPAllowedRule(t *testing.T) {
|
||||||
|
log := logger.NewLogger().WithTestConfig()
|
||||||
|
log.Init()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
config model.Config
|
||||||
|
ctx *ACLContext
|
||||||
|
expected Effect
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "abstains when ACLs are nil",
|
||||||
|
ctx: &ACLContext{
|
||||||
|
ACLs: nil,
|
||||||
|
IP: net.ParseIP("10.0.0.1"),
|
||||||
|
},
|
||||||
|
expected: EffectAbstain,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "denies when IP matches app block list",
|
||||||
|
ctx: &ACLContext{
|
||||||
|
ACLs: &model.App{
|
||||||
|
IP: model.AppIP{Block: []string{"10.0.0.1"}},
|
||||||
|
},
|
||||||
|
IP: net.ParseIP("10.0.0.1"),
|
||||||
|
},
|
||||||
|
expected: EffectDeny,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "denies when IP matches global block list",
|
||||||
|
config: model.Config{
|
||||||
|
Auth: model.AuthConfig{
|
||||||
|
IP: model.IPConfig{Block: []string{"10.0.0.0/24"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
ctx: &ACLContext{
|
||||||
|
ACLs: &model.App{},
|
||||||
|
IP: net.ParseIP("10.0.0.5"),
|
||||||
|
},
|
||||||
|
expected: EffectDeny,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "allows when IP matches app allow list",
|
||||||
|
ctx: &ACLContext{
|
||||||
|
ACLs: &model.App{
|
||||||
|
IP: model.AppIP{Allow: []string{"192.168.1.0/24"}},
|
||||||
|
},
|
||||||
|
IP: net.ParseIP("192.168.1.10"),
|
||||||
|
},
|
||||||
|
expected: EffectAllow,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "allows when IP matches global allow list",
|
||||||
|
config: model.Config{
|
||||||
|
Auth: model.AuthConfig{
|
||||||
|
IP: model.IPConfig{Allow: []string{"192.168.1.10"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
ctx: &ACLContext{
|
||||||
|
ACLs: &model.App{},
|
||||||
|
IP: net.ParseIP("192.168.1.10"),
|
||||||
|
},
|
||||||
|
expected: EffectAllow,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "denies when allow list is set and IP does not match",
|
||||||
|
ctx: &ACLContext{
|
||||||
|
ACLs: &model.App{
|
||||||
|
IP: model.AppIP{Allow: []string{"192.168.1.0/24"}},
|
||||||
|
},
|
||||||
|
IP: net.ParseIP("10.0.0.1"),
|
||||||
|
},
|
||||||
|
expected: EffectDeny,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "allows when no block or allow lists are configured",
|
||||||
|
ctx: &ACLContext{
|
||||||
|
ACLs: &model.App{},
|
||||||
|
IP: net.ParseIP("10.0.0.1"),
|
||||||
|
},
|
||||||
|
expected: EffectAllow,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "block list takes precedence over allow list",
|
||||||
|
ctx: &ACLContext{
|
||||||
|
ACLs: &model.App{
|
||||||
|
IP: model.AppIP{
|
||||||
|
Block: []string{"10.0.0.1"},
|
||||||
|
Allow: []string{"10.0.0.1"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
IP: net.ParseIP("10.0.0.1"),
|
||||||
|
},
|
||||||
|
expected: EffectDeny,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "skips invalid block entries and continues evaluation",
|
||||||
|
ctx: &ACLContext{
|
||||||
|
ACLs: &model.App{
|
||||||
|
IP: model.AppIP{
|
||||||
|
Block: []string{"not-an-ip"},
|
||||||
|
Allow: []string{"10.0.0.1"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
IP: net.ParseIP("10.0.0.1"),
|
||||||
|
},
|
||||||
|
expected: EffectAllow,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
rule := &IPAllowedRule{Log: log, Config: tt.config}
|
||||||
|
assert.Equal(t, tt.expected, rule.Evaluate(tt.ctx))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIPBypassedRule(t *testing.T) {
|
||||||
|
log := logger.NewLogger().WithTestConfig()
|
||||||
|
log.Init()
|
||||||
|
|
||||||
|
rule := &IPBypassedRule{Log: log}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
ctx *ACLContext
|
||||||
|
expected Effect
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "deny when ACLs are nil",
|
||||||
|
ctx: &ACLContext{
|
||||||
|
ACLs: nil,
|
||||||
|
IP: net.ParseIP("10.0.0.1"),
|
||||||
|
},
|
||||||
|
expected: EffectDeny,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "allows when IP matches bypass list",
|
||||||
|
ctx: &ACLContext{
|
||||||
|
ACLs: &model.App{
|
||||||
|
IP: model.AppIP{Bypass: []string{"10.0.0.0/24"}},
|
||||||
|
},
|
||||||
|
IP: net.ParseIP("10.0.0.5"),
|
||||||
|
},
|
||||||
|
expected: EffectAllow,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "denies when IP does not match bypass list",
|
||||||
|
ctx: &ACLContext{
|
||||||
|
ACLs: &model.App{
|
||||||
|
IP: model.AppIP{Bypass: []string{"10.0.0.0/24"}},
|
||||||
|
},
|
||||||
|
IP: net.ParseIP("192.168.1.1"),
|
||||||
|
},
|
||||||
|
expected: EffectDeny,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "denies when bypass list is empty",
|
||||||
|
ctx: &ACLContext{
|
||||||
|
ACLs: &model.App{},
|
||||||
|
IP: net.ParseIP("10.0.0.1"),
|
||||||
|
},
|
||||||
|
expected: EffectDeny,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "skips invalid bypass entries and allows on later match",
|
||||||
|
ctx: &ACLContext{
|
||||||
|
ACLs: &model.App{
|
||||||
|
IP: model.AppIP{Bypass: []string{"not-an-ip", "10.0.0.1"}},
|
||||||
|
},
|
||||||
|
IP: net.ParseIP("10.0.0.1"),
|
||||||
|
},
|
||||||
|
expected: EffectAllow,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
assert.Equal(t, tt.expected, rule.Evaluate(tt.ctx))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -13,51 +13,52 @@ type LabelProvider interface {
|
|||||||
|
|
||||||
type AccessControlsService struct {
|
type AccessControlsService struct {
|
||||||
log *logger.Logger
|
log *logger.Logger
|
||||||
|
config model.Config
|
||||||
labelProvider *LabelProvider
|
labelProvider *LabelProvider
|
||||||
static map[string]model.App
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewAccessControlsService(
|
func NewAccessControlsService(
|
||||||
log *logger.Logger,
|
log *logger.Logger,
|
||||||
labelProvider *LabelProvider,
|
config model.Config,
|
||||||
static map[string]model.App) *AccessControlsService {
|
labelProvider *LabelProvider) *AccessControlsService {
|
||||||
|
|
||||||
return &AccessControlsService{
|
return &AccessControlsService{
|
||||||
log: log,
|
log: log,
|
||||||
|
config: config,
|
||||||
labelProvider: labelProvider,
|
labelProvider: labelProvider,
|
||||||
static: static,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (acls *AccessControlsService) lookupStaticACLs(domain string) *model.App {
|
func (service *AccessControlsService) lookupStaticACLs(domain string) *model.App {
|
||||||
var appAcls *model.App
|
var nameMatch *model.App
|
||||||
for app, config := range acls.static {
|
|
||||||
|
// First try to find a matching app by domain, then fallback to matching by app name (subdomain)
|
||||||
|
for app, config := range service.config.Apps {
|
||||||
if config.Config.Domain == domain {
|
if config.Config.Domain == domain {
|
||||||
acls.log.App.Debug().Str("name", app).Msg("Found matching container by domain")
|
service.log.App.Debug().Str("name", app).Msg("Found matching container by domain")
|
||||||
appAcls = &config
|
return &config
|
||||||
break // If we find a match by domain, we can stop searching
|
}
|
||||||
|
if strings.SplitN(domain, ".", 2)[0] == app {
|
||||||
|
service.log.App.Debug().Str("name", app).Msg("Found matching container by app name")
|
||||||
|
nameMatch = &config
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if strings.SplitN(domain, ".", 2)[0] == app {
|
return nameMatch
|
||||||
acls.log.App.Debug().Str("name", app).Msg("Found matching container by app name")
|
|
||||||
appAcls = &config
|
|
||||||
break // If we find a match by app name, we can stop searching
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return appAcls
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (acls *AccessControlsService) GetAccessControls(domain string) (*model.App, error) {
|
func (service *AccessControlsService) GetAccessControls(domain string) (*model.App, error) {
|
||||||
// First check in the static config
|
// First check in the static config
|
||||||
app := acls.lookupStaticACLs(domain)
|
app := service.lookupStaticACLs(domain)
|
||||||
|
|
||||||
if app != nil {
|
if app != nil {
|
||||||
acls.log.App.Debug().Msg("Using static ACLs for app")
|
service.log.App.Debug().Msg("Using static ACLs for app")
|
||||||
return app, nil
|
return app, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// If we have a label provider configured, try to get ACLs from it
|
// If we have a label provider configured, try to get ACLs from it
|
||||||
if acls.labelProvider != nil {
|
if service.labelProvider != nil && *service.labelProvider != nil {
|
||||||
return (*acls.labelProvider).GetLabels(domain)
|
return (*service.labelProvider).GetLabels(domain)
|
||||||
}
|
}
|
||||||
|
|
||||||
// no labels
|
// no labels
|
||||||
|
|||||||
@@ -0,0 +1,199 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
type mockLabelProvider struct {
|
||||||
|
getLabelsFn func(appDomain string) (*model.App, error)
|
||||||
|
calledWith string
|
||||||
|
callCount int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockLabelProvider) GetLabels(appDomain string) (*model.App, error) {
|
||||||
|
m.calledWith = appDomain
|
||||||
|
m.callCount++
|
||||||
|
if m.getLabelsFn != nil {
|
||||||
|
return m.getLabelsFn(appDomain)
|
||||||
|
}
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLookupStaticACLs(t *testing.T) {
|
||||||
|
log := logger.NewLogger().WithTestConfig()
|
||||||
|
log.Init()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
apps map[string]model.App
|
||||||
|
domain string
|
||||||
|
expectNil bool
|
||||||
|
expectedDomain string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "returns nil when no apps are configured",
|
||||||
|
apps: nil,
|
||||||
|
domain: "foo.example.com",
|
||||||
|
expectNil: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "returns nil when no app matches",
|
||||||
|
apps: map[string]model.App{
|
||||||
|
"foo": {Config: model.AppConfig{Domain: "foo.example.com"}},
|
||||||
|
},
|
||||||
|
domain: "bar.example.com",
|
||||||
|
expectNil: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "matches by exact domain",
|
||||||
|
apps: map[string]model.App{
|
||||||
|
"foo": {Config: model.AppConfig{Domain: "foo.example.com"}},
|
||||||
|
},
|
||||||
|
domain: "foo.example.com",
|
||||||
|
expectedDomain: "foo.example.com",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "matches by app name when domain does not match any app",
|
||||||
|
apps: map[string]model.App{
|
||||||
|
"foo": {Config: model.AppConfig{Domain: "configured.example.com"}},
|
||||||
|
},
|
||||||
|
domain: "foo.example.com",
|
||||||
|
expectedDomain: "configured.example.com",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "matches by app name for nested subdomains",
|
||||||
|
apps: map[string]model.App{
|
||||||
|
"foo": {Config: model.AppConfig{Domain: "configured.example.com"}},
|
||||||
|
},
|
||||||
|
domain: "foo.sub.example.com",
|
||||||
|
expectedDomain: "configured.example.com",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "selects the app matching by domain among multiple apps",
|
||||||
|
apps: map[string]model.App{
|
||||||
|
"unrelated": {Config: model.AppConfig{Domain: "other.example.com"}},
|
||||||
|
"target": {Config: model.AppConfig{Domain: "foo.example.com"}},
|
||||||
|
},
|
||||||
|
domain: "foo.example.com",
|
||||||
|
expectedDomain: "foo.example.com",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
svc := NewAccessControlsService(log, model.Config{Apps: tt.apps}, nil)
|
||||||
|
got := svc.lookupStaticACLs(tt.domain)
|
||||||
|
if tt.expectNil {
|
||||||
|
assert.Nil(t, got)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
require.NotNil(t, got)
|
||||||
|
assert.Equal(t, tt.expectedDomain, got.Config.Domain)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetAccessControls(t *testing.T) {
|
||||||
|
log := logger.NewLogger().WithTestConfig()
|
||||||
|
log.Init()
|
||||||
|
|
||||||
|
t.Run("returns static ACLs when domain matches", func(t *testing.T) {
|
||||||
|
config := model.Config{
|
||||||
|
Apps: map[string]model.App{
|
||||||
|
"foo": {
|
||||||
|
Config: model.AppConfig{Domain: "foo.example.com"},
|
||||||
|
Users: model.AppUsers{Allow: "alice"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := NewAccessControlsService(log, config, nil)
|
||||||
|
|
||||||
|
got, err := svc.GetAccessControls("foo.example.com")
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, got)
|
||||||
|
assert.Equal(t, "foo.example.com", got.Config.Domain)
|
||||||
|
assert.Equal(t, "alice", got.Users.Allow)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("returns nil when no static match and no label provider", func(t *testing.T) {
|
||||||
|
svc := NewAccessControlsService(log, model.Config{}, nil)
|
||||||
|
|
||||||
|
got, err := svc.GetAccessControls("unknown.example.com")
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Nil(t, got)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("returns nil when label provider pointer wraps a nil interface", func(t *testing.T) {
|
||||||
|
var provider LabelProvider
|
||||||
|
svc := NewAccessControlsService(log, model.Config{}, &provider)
|
||||||
|
|
||||||
|
got, err := svc.GetAccessControls("unknown.example.com")
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Nil(t, got)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("falls back to label provider when no static match", func(t *testing.T) {
|
||||||
|
expected := &model.App{
|
||||||
|
Config: model.AppConfig{Domain: "dynamic.example.com"},
|
||||||
|
Users: model.AppUsers{Allow: "bob"},
|
||||||
|
}
|
||||||
|
mock := &mockLabelProvider{
|
||||||
|
getLabelsFn: func(appDomain string) (*model.App, error) {
|
||||||
|
return expected, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
var provider LabelProvider = mock
|
||||||
|
svc := NewAccessControlsService(log, model.Config{}, &provider)
|
||||||
|
|
||||||
|
got, err := svc.GetAccessControls("dynamic.example.com")
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Same(t, expected, got)
|
||||||
|
assert.Equal(t, "dynamic.example.com", mock.calledWith)
|
||||||
|
assert.Equal(t, 1, mock.callCount)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("does not call label provider when static match found", func(t *testing.T) {
|
||||||
|
mock := &mockLabelProvider{}
|
||||||
|
var provider LabelProvider = mock
|
||||||
|
config := model.Config{
|
||||||
|
Apps: map[string]model.App{
|
||||||
|
"foo": {Config: model.AppConfig{Domain: "foo.example.com"}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := NewAccessControlsService(log, config, &provider)
|
||||||
|
|
||||||
|
got, err := svc.GetAccessControls("foo.example.com")
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, got)
|
||||||
|
assert.Equal(t, "foo.example.com", got.Config.Domain)
|
||||||
|
assert.Equal(t, 0, mock.callCount)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("propagates label provider errors", func(t *testing.T) {
|
||||||
|
providerErr := errors.New("provider boom")
|
||||||
|
mock := &mockLabelProvider{
|
||||||
|
getLabelsFn: func(appDomain string) (*model.App, error) {
|
||||||
|
return nil, providerErr
|
||||||
|
},
|
||||||
|
}
|
||||||
|
var provider LabelProvider = mock
|
||||||
|
svc := NewAccessControlsService(log, model.Config{}, &provider)
|
||||||
|
|
||||||
|
got, err := svc.GetAccessControls("dynamic.example.com")
|
||||||
|
|
||||||
|
assert.Nil(t, got)
|
||||||
|
assert.ErrorIs(t, err, providerErr)
|
||||||
|
assert.Equal(t, 1, mock.callCount)
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -2,11 +2,9 @@ package service
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"regexp"
|
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
@@ -18,7 +16,6 @@ import (
|
|||||||
|
|
||||||
"slices"
|
"slices"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"golang.org/x/crypto/bcrypt"
|
"golang.org/x/crypto/bcrypt"
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
@@ -79,7 +76,7 @@ type AuthService struct {
|
|||||||
context context.Context
|
context context.Context
|
||||||
|
|
||||||
ldap *LdapService
|
ldap *LdapService
|
||||||
queries *repository.Queries
|
queries repository.Store
|
||||||
oauthBroker *OAuthBrokerService
|
oauthBroker *OAuthBrokerService
|
||||||
|
|
||||||
loginAttempts map[string]*LoginAttempt
|
loginAttempts map[string]*LoginAttempt
|
||||||
@@ -100,7 +97,7 @@ func NewAuthService(
|
|||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
wg *sync.WaitGroup,
|
wg *sync.WaitGroup,
|
||||||
ldap *LdapService,
|
ldap *LdapService,
|
||||||
queries *repository.Queries,
|
queries repository.Store,
|
||||||
oauthBroker *OAuthBrokerService,
|
oauthBroker *OAuthBrokerService,
|
||||||
) *AuthService {
|
) *AuthService {
|
||||||
service := &AuthService{
|
service := &AuthService{
|
||||||
@@ -286,7 +283,12 @@ func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) IsEmailWhitelisted(email string) bool {
|
func (auth *AuthService) IsEmailWhitelisted(email string) bool {
|
||||||
return utils.CheckFilter(strings.Join(auth.runtime.OAuthWhitelist, ","), email)
|
match, err := utils.CheckFilter(strings.Join(auth.runtime.OAuthWhitelist, ","), email)
|
||||||
|
if err != nil {
|
||||||
|
auth.log.App.Warn().Err(err).Str("email", email).Msg("Invalid email filter pattern")
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return match
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) CreateSession(ctx context.Context, data repository.Session) (*http.Cookie, error) {
|
func (auth *AuthService) CreateSession(ctx context.Context, data repository.Session) (*http.Cookie, error) {
|
||||||
@@ -417,7 +419,7 @@ func (auth *AuthService) GetSession(ctx context.Context, uuid string) (*reposito
|
|||||||
session, err := auth.queries.GetSession(ctx, uuid)
|
session, err := auth.queries.GetSession(ctx, uuid)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, sql.ErrNoRows) {
|
if errors.Is(err, repository.ErrNotFound) {
|
||||||
return nil, errors.New("session not found")
|
return nil, errors.New("session not found")
|
||||||
}
|
}
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -454,171 +456,6 @@ func (auth *AuthService) LDAPAuthConfigured() bool {
|
|||||||
return auth.ldap != nil
|
return auth.ldap != nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) IsUserAllowed(c *gin.Context, context model.UserContext, acls *model.App) bool {
|
|
||||||
if acls == nil {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
if context.Provider == model.ProviderOAuth {
|
|
||||||
auth.log.App.Debug().Msg("User is an OAuth user, checking OAuth whitelist")
|
|
||||||
return utils.CheckFilter(acls.OAuth.Whitelist, context.OAuth.Email)
|
|
||||||
}
|
|
||||||
|
|
||||||
if acls.Users.Block != "" {
|
|
||||||
auth.log.App.Debug().Msg("Checking users block list")
|
|
||||||
if utils.CheckFilter(acls.Users.Block, context.GetUsername()) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
auth.log.App.Debug().Msg("Checking users allow list")
|
|
||||||
return utils.CheckFilter(acls.Users.Allow, context.GetUsername())
|
|
||||||
}
|
|
||||||
|
|
||||||
func (auth *AuthService) IsInOAuthGroup(c *gin.Context, context model.UserContext, acls *model.App) bool {
|
|
||||||
if acls == nil {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
if !context.IsOAuth() {
|
|
||||||
auth.log.App.Debug().Msg("User is not an OAuth user, skipping OAuth group check")
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, ok := model.OverrideProviders[context.OAuth.ID]; ok {
|
|
||||||
auth.log.App.Debug().Str("provider", context.OAuth.ID).Msg("Provider override detected, skipping group check")
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, userGroup := range context.OAuth.Groups {
|
|
||||||
if utils.CheckFilter(acls.OAuth.Groups, strings.TrimSpace(userGroup)) {
|
|
||||||
auth.log.App.Trace().Str("group", userGroup).Str("required", acls.OAuth.Groups).Msg("User group matched")
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
auth.log.App.Debug().Msg("No groups matched")
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (auth *AuthService) IsInLDAPGroup(c *gin.Context, context model.UserContext, acls *model.App) bool {
|
|
||||||
if acls == nil {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
if !context.IsLDAP() {
|
|
||||||
auth.log.App.Debug().Msg("User is not an LDAP user, skipping LDAP group check")
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, userGroup := range context.LDAP.Groups {
|
|
||||||
if utils.CheckFilter(acls.LDAP.Groups, strings.TrimSpace(userGroup)) {
|
|
||||||
auth.log.App.Trace().Str("group", userGroup).Str("required", acls.LDAP.Groups).Msg("User group matched")
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
auth.log.App.Debug().Msg("No groups matched")
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (auth *AuthService) IsAuthEnabled(uri string, acls *model.App) (bool, error) {
|
|
||||||
if acls == nil {
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check for block list
|
|
||||||
if acls.Path.Block != "" {
|
|
||||||
regex, err := regexp.Compile(acls.Path.Block)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return true, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if !regex.MatchString(uri) {
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check for allow list
|
|
||||||
if acls.Path.Allow != "" {
|
|
||||||
regex, err := regexp.Compile(acls.Path.Allow)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return true, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if regex.MatchString(uri) {
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (auth *AuthService) CheckIP(ip string, acls *model.App) bool {
|
|
||||||
if acls == nil {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// Merge the global and app IP filter
|
|
||||||
blockedIps := append(auth.config.Auth.IP.Block, acls.IP.Block...)
|
|
||||||
allowedIPs := append(auth.config.Auth.IP.Allow, acls.IP.Allow...)
|
|
||||||
|
|
||||||
for _, blocked := range blockedIps {
|
|
||||||
res, err := utils.FilterIP(blocked, ip)
|
|
||||||
if err != nil {
|
|
||||||
auth.log.App.Warn().Err(err).Str("item", blocked).Msg("Invalid IP/CIDR in block list")
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if res {
|
|
||||||
auth.log.App.Debug().Str("ip", ip).Str("item", blocked).Msg("IP is in block list, denying access")
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, allowed := range allowedIPs {
|
|
||||||
res, err := utils.FilterIP(allowed, ip)
|
|
||||||
if err != nil {
|
|
||||||
auth.log.App.Warn().Err(err).Str("item", allowed).Msg("Invalid IP/CIDR in allow list")
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if res {
|
|
||||||
auth.log.App.Debug().Str("ip", ip).Str("item", allowed).Msg("IP is in allow list, allowing access")
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(allowedIPs) > 0 {
|
|
||||||
auth.log.App.Debug().Str("ip", ip).Msg("IP not in allow list, denying access")
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
auth.log.App.Debug().Str("ip", ip).Msg("IP not in any block or allow list, allowing access by default")
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (auth *AuthService) IsBypassedIP(ip string, acls *model.App) bool {
|
|
||||||
if acls == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, bypassed := range acls.IP.Bypass {
|
|
||||||
res, err := utils.FilterIP(bypassed, ip)
|
|
||||||
if err != nil {
|
|
||||||
auth.log.App.Warn().Err(err).Str("item", bypassed).Msg("Invalid IP/CIDR in bypass list")
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if res {
|
|
||||||
auth.log.App.Debug().Str("ip", ip).Str("item", bypassed).Msg("IP is in bypass list, skipping authentication")
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
auth.log.App.Debug().Str("ip", ip).Msg("IP not in bypass list, proceeding with authentication")
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (auth *AuthService) NewOAuthSession(serviceName string, params OAuthURLParams) (string, OAuthPendingSession, error) {
|
func (auth *AuthService) NewOAuthSession(serviceName string, params OAuthURLParams) (string, OAuthPendingSession, error) {
|
||||||
auth.ensureOAuthSessionLimit()
|
auth.ensureOAuthSessionLimit()
|
||||||
|
|
||||||
@@ -773,46 +610,49 @@ func (auth *AuthService) ensureOAuthSessionLimit() {
|
|||||||
auth.oauthMutex.Lock()
|
auth.oauthMutex.Lock()
|
||||||
defer auth.oauthMutex.Unlock()
|
defer auth.oauthMutex.Unlock()
|
||||||
|
|
||||||
if len(auth.oauthPendingSessions) >= MaxOAuthPendingSessions {
|
if len(auth.oauthPendingSessions) <= MaxOAuthPendingSessions {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
cleanupIds := make([]string, 0, OAuthCleanupCount)
|
type entry struct {
|
||||||
|
id string
|
||||||
for range OAuthCleanupCount {
|
expiresAt int64
|
||||||
oldestId := ""
|
}
|
||||||
oldestTime := int64(0)
|
|
||||||
|
|
||||||
|
entries := make([]entry, 0, len(auth.oauthPendingSessions))
|
||||||
for id, session := range auth.oauthPendingSessions {
|
for id, session := range auth.oauthPendingSessions {
|
||||||
if oldestTime == 0 {
|
entries = append(entries, entry{id, session.ExpiresAt.Unix()})
|
||||||
oldestId = id
|
|
||||||
oldestTime = session.ExpiresAt.Unix()
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if slices.Contains(cleanupIds, id) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if session.ExpiresAt.Unix() < oldestTime {
|
|
||||||
oldestId = id
|
|
||||||
oldestTime = session.ExpiresAt.Unix()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
cleanupIds = append(cleanupIds, oldestId)
|
slices.SortFunc(entries, func(a, b entry) int {
|
||||||
|
if a.expiresAt < b.expiresAt {
|
||||||
|
return -1
|
||||||
}
|
}
|
||||||
|
if a.expiresAt > b.expiresAt {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
})
|
||||||
|
|
||||||
for _, id := range cleanupIds {
|
for _, e := range entries[:OAuthCleanupCount] {
|
||||||
delete(auth.oauthPendingSessions, id)
|
delete(auth.oauthPendingSessions, e.id)
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) lockdownMode() {
|
func (auth *AuthService) lockdownMode() {
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
defer cancel()
|
|
||||||
auth.lockdownCtx = ctx
|
|
||||||
auth.lockdownCancelFunc = cancel
|
|
||||||
|
|
||||||
auth.loginMutex.Lock()
|
auth.loginMutex.Lock()
|
||||||
|
|
||||||
|
if auth.lockdown != nil && auth.lockdown.Active {
|
||||||
|
auth.loginMutex.Unlock()
|
||||||
|
cancel()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
auth.lockdownCtx = ctx
|
||||||
|
auth.lockdownCancelFunc = cancel
|
||||||
|
|
||||||
auth.log.App.Warn().Msg("Too many failed login attempts, entering lockdown mode")
|
auth.log.App.Warn().Msg("Too many failed login attempts, entering lockdown mode")
|
||||||
|
|
||||||
auth.lockdown = &Lockdown{
|
auth.lockdown = &Lockdown{
|
||||||
@@ -825,10 +665,12 @@ func (auth *AuthService) lockdownMode() {
|
|||||||
auth.loginAttempts = make(map[string]*LoginAttempt)
|
auth.loginAttempts = make(map[string]*LoginAttempt)
|
||||||
|
|
||||||
timer := time.NewTimer(time.Until(auth.lockdown.ActiveUntil))
|
timer := time.NewTimer(time.Until(auth.lockdown.ActiveUntil))
|
||||||
defer timer.Stop()
|
|
||||||
|
|
||||||
auth.loginMutex.Unlock()
|
auth.loginMutex.Unlock()
|
||||||
|
|
||||||
|
defer cancel()
|
||||||
|
defer timer.Stop()
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-timer.C:
|
case <-timer.C:
|
||||||
// Timer expired, end lockdown
|
// Timer expired, end lockdown
|
||||||
|
|||||||
@@ -85,17 +85,23 @@ func (docker *DockerService) GetLabels(appDomain string) (*model.App, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var nameMatch *model.App
|
||||||
|
|
||||||
|
// First try to find a matching app by domain, then fallback to matching by app name (subdomain)
|
||||||
for appName, appLabels := range labels.Apps {
|
for appName, appLabels := range labels.Apps {
|
||||||
if appLabels.Config.Domain == appDomain {
|
if appLabels.Config.Domain == appDomain {
|
||||||
docker.log.App.Debug().Str("id", inspect.ID).Str("name", inspect.Name).Msg("Found matching container by domain")
|
docker.log.App.Debug().Str("id", inspect.ID).Str("name", inspect.Name).Msg("Found matching container by domain")
|
||||||
return &appLabels, nil
|
return &appLabels, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if strings.SplitN(appDomain, ".", 2)[0] == appName {
|
if strings.SplitN(appDomain, ".", 2)[0] == appName {
|
||||||
docker.log.App.Debug().Str("id", inspect.ID).Str("name", inspect.Name).Msg("Found matching container by app name")
|
docker.log.App.Debug().Str("id", inspect.ID).Str("name", inspect.Name).Msg("Found matching container by app name")
|
||||||
return &appLabels, nil
|
nameMatch = &appLabels
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if nameMatch != nil {
|
||||||
|
return nameMatch, nil
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
docker.log.App.Debug().Str("domain", appDomain).Msg("No matching container found for domain")
|
docker.log.App.Debug().Str("domain", appDomain).Msg("No matching container found for domain")
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ func NewOAuthService(config model.OAuthServiceConfig, id string, ctx context.Con
|
|||||||
Transport: &http.Transport{
|
Transport: &http.Transport{
|
||||||
TLSClientConfig: &tls.Config{
|
TLSClientConfig: &tls.Config{
|
||||||
InsecureSkipVerify: config.Insecure,
|
InsecureSkipVerify: config.Insecure,
|
||||||
|
MinVersion: tls.VersionTLS12,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ import (
|
|||||||
"crypto/rsa"
|
"crypto/rsa"
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"database/sql"
|
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"encoding/pem"
|
"encoding/pem"
|
||||||
@@ -116,12 +115,12 @@ type OIDCService struct {
|
|||||||
log *logger.Logger
|
log *logger.Logger
|
||||||
config model.Config
|
config model.Config
|
||||||
runtime model.RuntimeConfig
|
runtime model.RuntimeConfig
|
||||||
queries *repository.Queries
|
queries repository.Store
|
||||||
context context.Context
|
context context.Context
|
||||||
|
|
||||||
clients map[string]model.OIDCClientConfig
|
clients map[string]model.OIDCClientConfig
|
||||||
privateKey *rsa.PrivateKey
|
privateKey *rsa.PrivateKey
|
||||||
publicKey crypto.PublicKey
|
publicKey *rsa.PublicKey
|
||||||
issuer string
|
issuer string
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -129,7 +128,7 @@ func NewOIDCService(
|
|||||||
log *logger.Logger,
|
log *logger.Logger,
|
||||||
config model.Config,
|
config model.Config,
|
||||||
runtime model.RuntimeConfig,
|
runtime model.RuntimeConfig,
|
||||||
queries *repository.Queries,
|
queries repository.Store,
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
wg *sync.WaitGroup) (*OIDCService, error) {
|
wg *sync.WaitGroup) (*OIDCService, error) {
|
||||||
// If not configured, skip init
|
// If not configured, skip init
|
||||||
@@ -239,6 +238,16 @@ func NewOIDCService(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
rPublicKey, ok := publicKey.(*rsa.PublicKey)
|
||||||
|
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("public key is not an rsa public key")
|
||||||
|
}
|
||||||
|
|
||||||
|
if rPublicKey.N.Cmp(privateKey.N) != 0 || rPublicKey.E != privateKey.E {
|
||||||
|
return nil, fmt.Errorf("public key does not pair with private key")
|
||||||
|
}
|
||||||
|
|
||||||
// We will reorganize the client into a map with the client ID as the key
|
// We will reorganize the client into a map with the client ID as the key
|
||||||
clients := make(map[string]model.OIDCClientConfig)
|
clients := make(map[string]model.OIDCClientConfig)
|
||||||
|
|
||||||
@@ -271,7 +280,7 @@ func NewOIDCService(
|
|||||||
|
|
||||||
clients: clients,
|
clients: clients,
|
||||||
privateKey: privateKey,
|
privateKey: privateKey,
|
||||||
publicKey: publicKey,
|
publicKey: rPublicKey,
|
||||||
issuer: issuer,
|
issuer: issuer,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -424,7 +433,7 @@ func (service *OIDCService) GetCodeEntry(c *gin.Context, codeHash string, client
|
|||||||
oidcCode, err := service.queries.GetOidcCode(c, codeHash)
|
oidcCode, err := service.queries.GetOidcCode(c, codeHash)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, sql.ErrNoRows) {
|
if errors.Is(err, repository.ErrNotFound) {
|
||||||
return repository.OidcCode{}, ErrCodeNotFound
|
return repository.OidcCode{}, ErrCodeNotFound
|
||||||
}
|
}
|
||||||
return repository.OidcCode{}, err
|
return repository.OidcCode{}, err
|
||||||
@@ -455,7 +464,7 @@ func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user
|
|||||||
|
|
||||||
hasher := sha256.New()
|
hasher := sha256.New()
|
||||||
|
|
||||||
der := x509.MarshalPKCS1PublicKey(&service.privateKey.PublicKey)
|
der := x509.MarshalPKCS1PublicKey(service.publicKey)
|
||||||
|
|
||||||
if der == nil {
|
if der == nil {
|
||||||
return "", errors.New("failed to marshal public key")
|
return "", errors.New("failed to marshal public key")
|
||||||
@@ -568,7 +577,7 @@ func (service *OIDCService) RefreshAccessToken(c *gin.Context, refreshToken stri
|
|||||||
entry, err := service.queries.GetOidcTokenByRefreshToken(c, service.Hash(refreshToken))
|
entry, err := service.queries.GetOidcTokenByRefreshToken(c, service.Hash(refreshToken))
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, sql.ErrNoRows) {
|
if errors.Is(err, repository.ErrNotFound) {
|
||||||
return TokenResponse{}, ErrTokenNotFound
|
return TokenResponse{}, ErrTokenNotFound
|
||||||
}
|
}
|
||||||
return TokenResponse{}, err
|
return TokenResponse{}, err
|
||||||
@@ -647,7 +656,7 @@ func (service *OIDCService) GetAccessToken(c *gin.Context, tokenHash string) (re
|
|||||||
entry, err := service.queries.GetOidcToken(c, tokenHash)
|
entry, err := service.queries.GetOidcToken(c, tokenHash)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, sql.ErrNoRows) {
|
if errors.Is(err, repository.ErrNotFound) {
|
||||||
return repository.OidcToken{}, ErrTokenNotFound
|
return repository.OidcToken{}, ErrTokenNotFound
|
||||||
}
|
}
|
||||||
return repository.OidcToken{}, err
|
return repository.OidcToken{}, err
|
||||||
@@ -735,15 +744,15 @@ func (service *OIDCService) Hash(token string) string {
|
|||||||
|
|
||||||
func (service *OIDCService) DeleteOldSession(ctx context.Context, sub string) error {
|
func (service *OIDCService) DeleteOldSession(ctx context.Context, sub string) error {
|
||||||
err := service.queries.DeleteOidcCodeBySub(ctx, sub)
|
err := service.queries.DeleteOidcCodeBySub(ctx, sub)
|
||||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
if err != nil && !errors.Is(err, repository.ErrNotFound) {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
err = service.queries.DeleteOidcTokenBySub(ctx, sub)
|
err = service.queries.DeleteOidcTokenBySub(ctx, sub)
|
||||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
if err != nil && !errors.Is(err, repository.ErrNotFound) {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
err = service.queries.DeleteOidcUserInfo(ctx, sub)
|
err = service.queries.DeleteOidcUserInfo(ctx, sub)
|
||||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
if err != nil && !errors.Is(err, repository.ErrNotFound) {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@@ -790,7 +799,9 @@ func (service *OIDCService) cleanupRoutine() {
|
|||||||
token, err := service.queries.GetOidcTokenBySub(service.context, expiredCode.Sub)
|
token, err := service.queries.GetOidcTokenBySub(service.context, expiredCode.Sub)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if !errors.Is(err, repository.ErrNotFound) {
|
||||||
service.log.App.Warn().Err(err).Msg("Failed to get token by sub for expired code")
|
service.log.App.Warn().Err(err).Msg("Failed to get token by sub for expired code")
|
||||||
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -813,7 +824,7 @@ func (service *OIDCService) cleanupRoutine() {
|
|||||||
func (service *OIDCService) GetJWK() ([]byte, error) {
|
func (service *OIDCService) GetJWK() ([]byte, error) {
|
||||||
hasher := sha256.New()
|
hasher := sha256.New()
|
||||||
|
|
||||||
der := x509.MarshalPKCS1PublicKey(&service.privateKey.PublicKey)
|
der := x509.MarshalPKCS1PublicKey(service.publicKey)
|
||||||
|
|
||||||
if der == nil {
|
if der == nil {
|
||||||
return nil, errors.New("failed to marshal public key")
|
return nil, errors.New("failed to marshal public key")
|
||||||
@@ -822,13 +833,13 @@ func (service *OIDCService) GetJWK() ([]byte, error) {
|
|||||||
hasher.Write(der)
|
hasher.Write(der)
|
||||||
|
|
||||||
jwk := jose.JSONWebKey{
|
jwk := jose.JSONWebKey{
|
||||||
Key: service.privateKey,
|
Key: service.publicKey,
|
||||||
Algorithm: string(jose.RS256),
|
Algorithm: string(jose.RS256),
|
||||||
Use: "sig",
|
Use: "sig",
|
||||||
KeyID: base64.URLEncoding.EncodeToString(hasher.Sum(nil)),
|
KeyID: base64.URLEncoding.EncodeToString(hasher.Sum(nil)),
|
||||||
}
|
}
|
||||||
|
|
||||||
return jwk.Public().MarshalJSON()
|
return jwk.MarshalJSON()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (service *OIDCService) ValidatePKCE(codeChallenge string, codeVerifier string) bool {
|
func (service *OIDCService) ValidatePKCE(codeChallenge string, codeVerifier string) bool {
|
||||||
|
|||||||
@@ -0,0 +1,110 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Policy string
|
||||||
|
|
||||||
|
const (
|
||||||
|
PolicyAllow Policy = "allow"
|
||||||
|
PolicyDeny Policy = "deny"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Effect int
|
||||||
|
|
||||||
|
const (
|
||||||
|
EffectAbstain Effect = iota
|
||||||
|
EffectAllow
|
||||||
|
EffectDeny
|
||||||
|
)
|
||||||
|
|
||||||
|
type Rule interface {
|
||||||
|
Evaluate(ctx *ACLContext) Effect
|
||||||
|
}
|
||||||
|
|
||||||
|
type ACLContext struct {
|
||||||
|
ACLs *model.App
|
||||||
|
UserContext *model.UserContext
|
||||||
|
IP net.IP
|
||||||
|
Path string
|
||||||
|
}
|
||||||
|
|
||||||
|
type PolicyEngine struct {
|
||||||
|
log *logger.Logger
|
||||||
|
rules map[RuleName]Rule
|
||||||
|
policy Policy
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewPolicyEngine(config model.Config, log *logger.Logger) (*PolicyEngine, error) {
|
||||||
|
engine := PolicyEngine{
|
||||||
|
log: log,
|
||||||
|
rules: make(map[RuleName]Rule),
|
||||||
|
}
|
||||||
|
|
||||||
|
switch config.Auth.ACLs.Policy {
|
||||||
|
case string(PolicyAllow):
|
||||||
|
log.App.Debug().Msg("Using 'allow' ACL policy: access to apps will be allowed by default unless explicitly blocked")
|
||||||
|
engine.policy = PolicyAllow
|
||||||
|
case string(PolicyDeny):
|
||||||
|
log.App.Debug().Msg("Using 'deny' ACL policy: access to apps will be blocked by default unless explicitly allowed")
|
||||||
|
engine.policy = PolicyDeny
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("invalid acl policy: %s", config.Auth.ACLs.Policy)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &engine, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (engine *PolicyEngine) RegisterRule(name RuleName, rule Rule) {
|
||||||
|
engine.log.App.Debug().Str("rule", string(name)).Msg("Registering ACL rule in policy engine")
|
||||||
|
engine.rules[name] = rule
|
||||||
|
}
|
||||||
|
|
||||||
|
func (engine *PolicyEngine) evaluateRuleByName(name RuleName, ctx *ACLContext) Effect {
|
||||||
|
rule, exists := engine.rules[name]
|
||||||
|
|
||||||
|
if !exists {
|
||||||
|
engine.log.App.Warn().Str("rule", string(name)).Msg("Rule not found in policy engine, defaulting to deny")
|
||||||
|
return EffectDeny
|
||||||
|
}
|
||||||
|
|
||||||
|
return rule.Evaluate(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (engine *PolicyEngine) effectToAccess(effect Effect) bool {
|
||||||
|
switch effect {
|
||||||
|
case EffectAllow:
|
||||||
|
return true
|
||||||
|
case EffectDeny:
|
||||||
|
return false
|
||||||
|
default:
|
||||||
|
// If the effect is abstain, we fall back to the default policy
|
||||||
|
return engine.policy == PolicyAllow
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (engine *PolicyEngine) Evaluate(name RuleName, ctx *ACLContext) bool {
|
||||||
|
effect := engine.evaluateRuleByName(name, ctx)
|
||||||
|
access := engine.effectToAccess(effect)
|
||||||
|
|
||||||
|
engine.log.App.Debug().
|
||||||
|
Str("rule", string(name)).
|
||||||
|
Int("effect", int(effect)).
|
||||||
|
Bool("access", access).
|
||||||
|
Msg("Evaluated ACL rule")
|
||||||
|
|
||||||
|
return access
|
||||||
|
}
|
||||||
|
|
||||||
|
func (engine *PolicyEngine) Policy() Policy {
|
||||||
|
return engine.policy
|
||||||
|
}
|
||||||
|
|
||||||
|
func (engine *PolicyEngine) Rules() map[RuleName]Rule {
|
||||||
|
return engine.rules
|
||||||
|
}
|
||||||
@@ -0,0 +1,93 @@
|
|||||||
|
package service_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/test"
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Create test rule
|
||||||
|
type TestRule struct{}
|
||||||
|
|
||||||
|
func (rule *TestRule) Evaluate(ctx *service.ACLContext) service.Effect {
|
||||||
|
switch ctx.Path {
|
||||||
|
case "/allowed":
|
||||||
|
return service.EffectAllow
|
||||||
|
case "/denied":
|
||||||
|
return service.EffectDeny
|
||||||
|
default:
|
||||||
|
return service.EffectAbstain
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPolicyEngine(t *testing.T) {
|
||||||
|
log := logger.NewLogger().WithTestConfig()
|
||||||
|
log.Init()
|
||||||
|
|
||||||
|
cfg, _ := test.CreateTestConfigs(t)
|
||||||
|
|
||||||
|
testRule := &TestRule{}
|
||||||
|
|
||||||
|
// Engine should fail with invalid policy
|
||||||
|
cfg.Auth.ACLs.Policy = "invalid_policy"
|
||||||
|
_, err := service.NewPolicyEngine(cfg, log)
|
||||||
|
assert.Error(t, err)
|
||||||
|
|
||||||
|
// Engine should initialize with 'allow' policy
|
||||||
|
cfg.Auth.ACLs.Policy = string(service.PolicyAllow)
|
||||||
|
engine, err := service.NewPolicyEngine(cfg, log)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, service.PolicyAllow, engine.Policy())
|
||||||
|
|
||||||
|
// Engine should initialize with 'deny' policy
|
||||||
|
cfg.Auth.ACLs.Policy = string(service.PolicyDeny)
|
||||||
|
engine, err = service.NewPolicyEngine(cfg, log)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, service.PolicyDeny, engine.Policy())
|
||||||
|
|
||||||
|
// Engine should allow adding rules
|
||||||
|
engine, err = service.NewPolicyEngine(cfg, log)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
engine.RegisterRule("test-rule", testRule)
|
||||||
|
_, ok := engine.Rules()["test-rule"]
|
||||||
|
assert.True(t, ok)
|
||||||
|
|
||||||
|
// Begin allow policy tests
|
||||||
|
cfg.Auth.ACLs.Policy = string(service.PolicyAllow)
|
||||||
|
engine, err = service.NewPolicyEngine(cfg, log)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
engine.RegisterRule("test-rule", testRule)
|
||||||
|
|
||||||
|
// With allow policy, if rule allows, access should be allowed
|
||||||
|
ctx := &service.ACLContext{Path: "/allowed"}
|
||||||
|
assert.Equal(t, true, engine.Evaluate("test-rule", ctx))
|
||||||
|
|
||||||
|
// With allow policy, if rule denies, access should be denied
|
||||||
|
ctx.Path = "/denied"
|
||||||
|
assert.Equal(t, false, engine.Evaluate("test-rule", ctx))
|
||||||
|
|
||||||
|
// With allow policy, if rule abstains, access should be allowed (default)
|
||||||
|
ctx.Path = "/abstain"
|
||||||
|
assert.Equal(t, true, engine.Evaluate("test-rule", ctx))
|
||||||
|
|
||||||
|
// Begin deny policy tests
|
||||||
|
cfg.Auth.ACLs.Policy = string(service.PolicyDeny)
|
||||||
|
engine, err = service.NewPolicyEngine(cfg, log)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
engine.RegisterRule("test-rule", testRule)
|
||||||
|
|
||||||
|
// With deny policy, if rule allows, access should be allowed
|
||||||
|
ctx.Path = "/allowed"
|
||||||
|
assert.Equal(t, true, engine.Evaluate("test-rule", ctx))
|
||||||
|
|
||||||
|
// With deny policy, if rule denies, access should be denied
|
||||||
|
ctx.Path = "/denied"
|
||||||
|
assert.Equal(t, false, engine.Evaluate("test-rule", ctx))
|
||||||
|
|
||||||
|
// With deny policy, if rule abstains, access should be denied (default)
|
||||||
|
ctx.Path = "/abstain"
|
||||||
|
assert.Equal(t, false, engine.Evaluate("test-rule", ctx))
|
||||||
|
}
|
||||||
@@ -40,6 +40,9 @@ func CreateTestConfigs(t *testing.T) (model.Config, model.RuntimeConfig) {
|
|||||||
SessionExpiry: 10,
|
SessionExpiry: 10,
|
||||||
LoginTimeout: 10,
|
LoginTimeout: 10,
|
||||||
LoginMaxRetries: 3,
|
LoginMaxRetries: 3,
|
||||||
|
ACLs: model.ACLsConfig{
|
||||||
|
Policy: "allow",
|
||||||
|
},
|
||||||
},
|
},
|
||||||
Database: model.DatabaseConfig{
|
Database: model.DatabaseConfig{
|
||||||
Path: filepath.Join(tempDir, "test.db"),
|
Path: filepath.Join(tempDir, "test.db"),
|
||||||
@@ -48,6 +51,32 @@ func CreateTestConfigs(t *testing.T) (model.Config, model.RuntimeConfig) {
|
|||||||
Enabled: true,
|
Enabled: true,
|
||||||
Path: filepath.Join(tempDir, "resources"),
|
Path: filepath.Join(tempDir, "resources"),
|
||||||
},
|
},
|
||||||
|
Apps: map[string]model.App{
|
||||||
|
"app_path_allow": {
|
||||||
|
Config: model.AppConfig{
|
||||||
|
Domain: "path-allow.example.com",
|
||||||
|
},
|
||||||
|
Path: model.AppPath{
|
||||||
|
Allow: "/allowed",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"app_user_allow": {
|
||||||
|
Config: model.AppConfig{
|
||||||
|
Domain: "user-allow.example.com",
|
||||||
|
},
|
||||||
|
Users: model.AppUsers{
|
||||||
|
Allow: "testuser",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"ip_bypass": {
|
||||||
|
Config: model.AppConfig{
|
||||||
|
Domain: "ip-bypass.example.com",
|
||||||
|
},
|
||||||
|
IP: model.AppIP{
|
||||||
|
Bypass: []string{"10.10.10.10"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
passwd, err := bcrypt.GenerateFromPassword([]byte("password"), bcrypt.DefaultCost)
|
passwd, err := bcrypt.GenerateFromPassword([]byte("password"), bcrypt.DefaultCost)
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ package utils
|
|||||||
import (
|
import (
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"errors"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"regexp"
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -46,26 +46,27 @@ func EncodeBasicAuth(username string, password string) string {
|
|||||||
return base64.StdEncoding.EncodeToString([]byte(auth))
|
return base64.StdEncoding.EncodeToString([]byte(auth))
|
||||||
}
|
}
|
||||||
|
|
||||||
func FilterIP(filter string, ip string) (bool, error) {
|
func CheckIPFilter(filter string, ip string) (bool, error) {
|
||||||
ipAddr := net.ParseIP(ip)
|
ipAddr := net.ParseIP(ip)
|
||||||
|
|
||||||
if ipAddr == nil {
|
if ipAddr == nil {
|
||||||
return false, errors.New("invalid IP address")
|
return false, fmt.Errorf("invalid ip address")
|
||||||
}
|
}
|
||||||
|
|
||||||
filter = strings.Replace(filter, "-", "/", -1)
|
filter = strings.ReplaceAll(filter, "-", "/")
|
||||||
|
|
||||||
if strings.Contains(filter, "/") {
|
if strings.Contains(filter, "/") {
|
||||||
_, cidr, err := net.ParseCIDR(filter)
|
_, cidr, err := net.ParseCIDR(filter)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, fmt.Errorf("invalid cidr notation: %w", err)
|
||||||
}
|
}
|
||||||
return cidr.Contains(ipAddr), nil
|
return cidr.Contains(ipAddr), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
ipFilter := net.ParseIP(filter)
|
ipFilter := net.ParseIP(filter)
|
||||||
|
|
||||||
if ipFilter == nil {
|
if ipFilter == nil {
|
||||||
return false, errors.New("invalid IP address in filter")
|
return false, fmt.Errorf("invalid ip address")
|
||||||
}
|
}
|
||||||
|
|
||||||
if ipFilter.Equal(ipAddr) {
|
if ipFilter.Equal(ipAddr) {
|
||||||
@@ -75,31 +76,29 @@ func FilterIP(filter string, ip string) (bool, error) {
|
|||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func CheckFilter(filter string, str string) bool {
|
func CheckFilter(filter string, input string) (bool, error) {
|
||||||
if len(strings.TrimSpace(filter)) == 0 {
|
if len(strings.TrimSpace(filter)) == 0 {
|
||||||
return true
|
return false, fmt.Errorf("filter is empty")
|
||||||
}
|
}
|
||||||
|
|
||||||
if strings.HasPrefix(filter, "/") && strings.HasSuffix(filter, "/") {
|
if strings.HasPrefix(filter, "/") && strings.HasSuffix(filter, "/") {
|
||||||
re, err := regexp.Compile(filter[1 : len(filter)-1])
|
re, err := regexp.Compile(filter[1 : len(filter)-1])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false
|
return false, fmt.Errorf("invalid regex filter: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if re.MatchString(strings.TrimSpace(str)) {
|
if re.MatchString(input) {
|
||||||
return true
|
return true, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
filterSplit := strings.Split(filter, ",")
|
for item := range strings.SplitSeq(filter, ",") {
|
||||||
|
if strings.TrimSpace(item) == input {
|
||||||
for _, item := range filterSplit {
|
return true, nil
|
||||||
if strings.TrimSpace(item) == strings.TrimSpace(str) {
|
|
||||||
return true
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return false
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func GenerateUUID(str string) string {
|
func GenerateUUID(str string) string {
|
||||||
|
|||||||
@@ -75,66 +75,77 @@ func TestEncodeBasicAuth(t *testing.T) {
|
|||||||
assert.Equal(t, expected, utils.EncodeBasicAuth(username, password))
|
assert.Equal(t, expected, utils.EncodeBasicAuth(username, password))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestFilterIP(t *testing.T) {
|
func TestCheckIPFilter(t *testing.T) {
|
||||||
// Exact match IPv4
|
// Exact match IPv4
|
||||||
ok, err := utils.FilterIP("10.10.0.1", "10.10.0.1")
|
ok, err := utils.CheckIPFilter("10.10.0.1", "10.10.0.1")
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, true, ok)
|
assert.Equal(t, true, ok)
|
||||||
|
|
||||||
// Non-match IPv4
|
// Non-match IPv4
|
||||||
ok, err = utils.FilterIP("10.10.0.1", "10.10.0.2")
|
ok, err = utils.CheckIPFilter("10.10.0.1", "10.10.0.2")
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, false, ok)
|
assert.Equal(t, false, ok)
|
||||||
|
|
||||||
// CIDR match IPv4
|
// CIDR match IPv4
|
||||||
ok, err = utils.FilterIP("10.10.0.0/24", "10.10.0.2")
|
ok, err = utils.CheckIPFilter("10.10.0.0/24", "10.10.0.2")
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, true, ok)
|
assert.Equal(t, true, ok)
|
||||||
|
|
||||||
// CIDR match IPv4 with '-' instead of '/'
|
// CIDR match IPv4 with '-' instead of '/'
|
||||||
ok, err = utils.FilterIP("10.10.10.0-24", "10.10.10.5")
|
ok, err = utils.CheckIPFilter("10.10.10.0-24", "10.10.10.5")
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, true, ok)
|
assert.Equal(t, true, ok)
|
||||||
|
|
||||||
// CIDR non-match IPv4
|
// CIDR non-match IPv4
|
||||||
ok, err = utils.FilterIP("10.10.0.0/24", "10.5.0.1")
|
ok, err = utils.CheckIPFilter("10.10.0.0/24", "10.5.0.1")
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, false, ok)
|
assert.Equal(t, false, ok)
|
||||||
|
|
||||||
// Invalid CIDR
|
// Invalid CIDR
|
||||||
ok, err = utils.FilterIP("10.10.0.0/222", "10.0.0.1")
|
ok, err = utils.CheckIPFilter("10.10.0.0/222", "10.0.0.1")
|
||||||
assert.ErrorContains(t, err, "invalid CIDR address")
|
assert.ErrorContains(t, err, "invalid cidr notation: invalid CIDR address: 10.10.0.0/222")
|
||||||
assert.Equal(t, false, ok)
|
assert.Equal(t, false, ok)
|
||||||
|
|
||||||
// Invalid IP in filter
|
// Invalid IP in filter
|
||||||
ok, err = utils.FilterIP("invalid_ip", "10.5.5.5")
|
ok, err = utils.CheckIPFilter("invalid_ip", "10.5.5.5")
|
||||||
assert.ErrorContains(t, err, "invalid IP address in filter")
|
assert.ErrorContains(t, err, "invalid ip address")
|
||||||
assert.Equal(t, false, ok)
|
assert.Equal(t, false, ok)
|
||||||
|
|
||||||
// Invalid IP to check
|
// Invalid IP to check
|
||||||
ok, err = utils.FilterIP("10.10.10.10", "invalid_ip")
|
ok, err = utils.CheckIPFilter("10.10.10.10", "invalid_ip")
|
||||||
assert.ErrorContains(t, err, "invalid IP address")
|
assert.ErrorContains(t, err, "invalid ip address")
|
||||||
assert.Equal(t, false, ok)
|
assert.Equal(t, false, ok)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCheckFilter(t *testing.T) {
|
func TestCheckFilter(t *testing.T) {
|
||||||
// Empty filter
|
// Empty filter
|
||||||
assert.Equal(t, true, utils.CheckFilter("", "anystring"))
|
_, err := utils.CheckFilter("", "anystring")
|
||||||
|
assert.ErrorContains(t, err, "filter is empty")
|
||||||
|
|
||||||
// Exact match
|
// Exact match
|
||||||
assert.Equal(t, true, utils.CheckFilter("hello", "hello"))
|
ok, err := utils.CheckFilter("hello", "hello")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, true, ok)
|
||||||
|
|
||||||
// Regex match
|
// Regex match
|
||||||
assert.Equal(t, true, utils.CheckFilter("/^h.*o$/", "hello"))
|
ok, err = utils.CheckFilter("/^h.*o$/", "hello")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, true, ok)
|
||||||
|
|
||||||
// Invalid regex
|
// Invalid regex
|
||||||
assert.Equal(t, false, utils.CheckFilter("/[unclosed", "test"))
|
ok, err = utils.CheckFilter("/[unclosed/", "test")
|
||||||
|
assert.ErrorContains(t, err, "invalid regex")
|
||||||
|
assert.Equal(t, false, ok)
|
||||||
|
|
||||||
// Comma-separated values
|
// Comma-separated values
|
||||||
assert.Equal(t, true, utils.CheckFilter("apple, banana, cherry", "banana"))
|
ok, err = utils.CheckFilter("apple, banana, cherry", "banana")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, true, ok)
|
||||||
|
|
||||||
// No match
|
// No match
|
||||||
assert.Equal(t, false, utils.CheckFilter("apple, banana, cherry", "grape"))
|
ok, err = utils.CheckFilter("apple, banana, cherry", "grape")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, false, ok)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGenerateUUID(t *testing.T) {
|
func TestGenerateUUID(t *testing.T) {
|
||||||
|
|||||||
@@ -1,12 +1,12 @@
|
|||||||
version: "2"
|
version: "2"
|
||||||
sql:
|
sql:
|
||||||
- engine: "sqlite"
|
- engine: "sqlite"
|
||||||
queries: "sql/*_queries.sql"
|
queries: "sql/sqlite/*_queries.sql"
|
||||||
schema: "sql/*_schemas.sql"
|
schema: "sql/sqlite/*_schemas.sql"
|
||||||
gen:
|
gen:
|
||||||
go:
|
go:
|
||||||
package: "repository"
|
package: "sqlite"
|
||||||
out: "internal/repository"
|
out: "internal/repository/sqlite"
|
||||||
rename:
|
rename:
|
||||||
uuid: "UUID"
|
uuid: "UUID"
|
||||||
oauth_groups: "OAuthGroups"
|
oauth_groups: "OAuthGroups"
|
||||||
|
|||||||
Reference in New Issue
Block a user