Compare commits

..

24 Commits

Author SHA1 Message Date
Stavros 04b2290d73 fix: review comments batch 3 2026-05-05 18:56:35 +03:00
Stavros e04980468f fix: review comments batch 2 2026-05-05 18:54:45 +03:00
Stavros d47e4d3d79 fix: review comments batch 1 2026-05-05 18:43:22 +03:00
Stavros f3965a7470 fix: fix verion setting in cd and dockerfiles 2026-05-04 21:08:45 +03:00
Stavros 36d4e3ec52 tests: fix log wrapper tests 2026-05-04 21:01:13 +03:00
Stavros eab9f71110 tests: remove error wrapper from context tests 2026-05-04 20:57:37 +03:00
Stavros e13598bf3c tests: add tests for context middleware 2026-05-04 20:52:59 +03:00
Stavros 4d3860f860 tests: add tests for context parsing 2026-05-04 20:33:49 +03:00
Stavros 3b5da06862 fix: fix config reference generator 2026-05-04 20:25:56 +03:00
Stavros 8f337aaff8 tests: move to testify for testing in utils 2026-05-04 20:25:16 +03:00
Stavros ff3c25c09d tests: fix utils tests 2026-05-04 20:18:34 +03:00
Stavros 26daef7d4e tests: fix service tests 2026-05-04 20:11:07 +03:00
Stavros c932817757 fix: fix controller tests 2026-05-04 20:07:03 +03:00
Stavros 004df2f852 chore: rename get basic auth to encode basic auth for clarity 2026-05-04 16:14:45 +03:00
Stavros df56708b9a refactor: simplify acls checking logic by passing the entire acl struct 2026-05-04 16:13:39 +03:00
Stavros 62ffd2fd11 feat: finalize context functionality 2026-04-29 20:11:43 +03:00
Stavros a3ec07230c fix: fix oauth and oidc controller imports and context 2026-04-29 20:00:36 +03:00
Stavros b4eb7090bd fix: fix imports and context in proxy controller 2026-04-29 19:58:39 +03:00
Stavros 2f24f823eb fix: use new context in user controller 2026-04-29 19:45:23 +03:00
Stavros 9a219046ac fix: context controller 2026-04-29 19:31:44 +03:00
Stavros 97d58b376d fix: fix cli imports 2026-04-29 19:28:40 +03:00
Stavros b426a1529e fix: fix bootstrap import issues 2026-04-29 19:27:38 +03:00
Stavros c7efb71a5a fix: fix util imports 2026-04-29 19:25:23 +03:00
Stavros eec75a6f49 wip 2026-04-29 19:21:07 +03:00
96 changed files with 2106 additions and 2474 deletions
-6
View File
@@ -26,12 +26,6 @@ jobs:
- name: Go dependencies
run: go mod download
- name: Check codegen is up to date
run: |
go generate ./internal/repository/...
git diff --exit-code -- internal/repository/
git status --porcelain -- internal/repository/ | grep -q . && echo "untracked files in internal/repository/" && exit 1 || true
- name: Install frontend dependencies
run: |
cd frontend
+2 -2
View File
@@ -84,7 +84,7 @@ jobs:
- name: Build
run: |
cp -r frontend/dist internal/assets/dist
go build -ldflags "-s -w -X github.com/tinyauthapp/tinyauth/internal/config.Version=${{ needs.generate-metadata.outputs.VERSION }} -X github.com/tinyauthapp/tinyauth/internal/config.CommitHash=${{ needs.generate-metadata.outputs.COMMIT_HASH }} -X github.com/tinyauthapp/tinyauth/internal/config.BuildTimestamp=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}" -o tinyauth-amd64 ./cmd/tinyauth
go build -ldflags "-s -w -X github.com/tinyauthapp/tinyauth/internal/model.Version=${{ needs.generate-metadata.outputs.VERSION }} -X github.com/tinyauthapp/tinyauth/internal/model.CommitHash=${{ needs.generate-metadata.outputs.COMMIT_HASH }} -X github.com/tinyauthapp/tinyauth/internal/model.BuildTimestamp=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}" -o tinyauth-amd64 ./cmd/tinyauth
env:
CGO_ENABLED: 0
@@ -130,7 +130,7 @@ jobs:
- name: Build
run: |
cp -r frontend/dist internal/assets/dist
go build -ldflags "-s -w -X github.com/tinyauthapp/tinyauth/internal/config.Version=${{ needs.generate-metadata.outputs.VERSION }} -X github.com/tinyauthapp/tinyauth/internal/config.CommitHash=${{ needs.generate-metadata.outputs.COMMIT_HASH }} -X github.com/tinyauthapp/tinyauth/internal/config.BuildTimestamp=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}" -o tinyauth-arm64 ./cmd/tinyauth
go build -ldflags "-s -w -X github.com/tinyauthapp/tinyauth/internal/model.Version=${{ needs.generate-metadata.outputs.VERSION }} -X github.com/tinyauthapp/tinyauth/internal/model.CommitHash=${{ needs.generate-metadata.outputs.COMMIT_HASH }} -X github.com/tinyauthapp/tinyauth/internal/model.BuildTimestamp=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}" -o tinyauth-arm64 ./cmd/tinyauth
env:
CGO_ENABLED: 0
+2 -2
View File
@@ -60,7 +60,7 @@ jobs:
- name: Build
run: |
cp -r frontend/dist internal/assets/dist
go build -ldflags "-s -w -X github.com/tinyauthapp/tinyauth/internal/config.Version=${{ needs.generate-metadata.outputs.VERSION }} -X github.com/tinyauthapp/tinyauth/internal/config.CommitHash=${{ needs.generate-metadata.outputs.COMMIT_HASH }} -X github.com/tinyauthapp/tinyauth/internal/config.BuildTimestamp=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}" -o tinyauth-amd64 ./cmd/tinyauth
go build -ldflags "-s -w -X github.com/tinyauthapp/tinyauth/internal/model.Version=${{ needs.generate-metadata.outputs.VERSION }} -X github.com/tinyauthapp/tinyauth/internal/model.CommitHash=${{ needs.generate-metadata.outputs.COMMIT_HASH }} -X github.com/tinyauthapp/tinyauth/internal/model.BuildTimestamp=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}" -o tinyauth-amd64 ./cmd/tinyauth
env:
CGO_ENABLED: 0
@@ -103,7 +103,7 @@ jobs:
- name: Build
run: |
cp -r frontend/dist internal/assets/dist
go build -ldflags "-s -w -X github.com/tinyauthapp/tinyauth/internal/config.Version=${{ needs.generate-metadata.outputs.VERSION }} -X github.com/tinyauthapp/tinyauth/internal/config.CommitHash=${{ needs.generate-metadata.outputs.COMMIT_HASH }} -X github.com/tinyauthapp/tinyauth/internal/config.BuildTimestamp=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}" -o tinyauth-arm64 ./cmd/tinyauth
go build -ldflags "-s -w -X github.com/tinyauthapp/tinyauth/internal/model.Version=${{ needs.generate-metadata.outputs.VERSION }} -X github.com/tinyauthapp/tinyauth/internal/model.CommitHash=${{ needs.generate-metadata.outputs.COMMIT_HASH }} -X github.com/tinyauthapp/tinyauth/internal/model.BuildTimestamp=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}" -o tinyauth-arm64 ./cmd/tinyauth
env:
CGO_ENABLED: 0
+3 -3
View File
@@ -38,9 +38,9 @@ COPY ./internal ./internal
COPY --from=frontend-builder /frontend/dist ./internal/assets/dist
RUN CGO_ENABLED=0 go build -ldflags "-s -w \
-X github.com/tinyauthapp/tinyauth/internal/config.Version=${VERSION} \
-X github.com/tinyauthapp/tinyauth/internal/config.CommitHash=${COMMIT_HASH} \
-X github.com/tinyauthapp/tinyauth/internal/config.BuildTimestamp=${BUILD_TIMESTAMP}" ./cmd/tinyauth
-X github.com/tinyauthapp/tinyauth/internal/model.Version=${VERSION} \
-X github.com/tinyauthapp/tinyauth/internal/model.CommitHash=${COMMIT_HASH} \
-X github.com/tinyauthapp/tinyauth/internal/model.BuildTimestamp=${BUILD_TIMESTAMP}" ./cmd/tinyauth
# Runner
FROM alpine:3.23 AS runner
+3 -3
View File
@@ -40,9 +40,9 @@ COPY --from=frontend-builder /frontend/dist ./internal/assets/dist
RUN mkdir -p data
RUN CGO_ENABLED=0 go build -ldflags "-s -w \
-X github.com/tinyauthapp/tinyauth/internal/config.Version=${VERSION} \
-X github.com/tinyauthapp/tinyauth/internal/config.CommitHash=${COMMIT_HASH} \
-X github.com/tinyauthapp/tinyauth/internal/config.BuildTimestamp=${BUILD_TIMESTAMP}" ./cmd/tinyauth
-X github.com/tinyauthapp/tinyauth/internal/model.Version=${VERSION} \
-X github.com/tinyauthapp/tinyauth/internal/model.CommitHash=${COMMIT_HASH} \
-X github.com/tinyauthapp/tinyauth/internal/model.BuildTimestamp=${BUILD_TIMESTAMP}" ./cmd/tinyauth
# Runner
FROM gcr.io/distroless/static-debian12:latest AS runner
+3 -3
View File
@@ -37,9 +37,9 @@ webui: clean-webui
# Build the binary
binary: webui
CGO_ENABLED=$(CGO_ENABLED) go build -ldflags "-s -w \
-X github.com/tinyauthapp/tinyauth/internal/config.Version=${TAG_NAME} \
-X github.com/tinyauthapp/tinyauth/internal/config.CommitHash=${COMMIT_HASH} \
-X github.com/tinyauthapp/tinyauth/internal/config.BuildTimestamp=${BUILD_TIMESTAMP}" \
-X github.com/tinyauthapp/tinyauth/internal/model.Version=${TAG_NAME} \
-X github.com/tinyauthapp/tinyauth/internal/model.CommitHash=${COMMIT_HASH} \
-X github.com/tinyauthapp/tinyauth/internal/model.BuildTimestamp=${BUILD_TIMESTAMP}" \
-o ${BIN_NAME} ./cmd/tinyauth
# Build for amd64
-540
View File
@@ -1,540 +0,0 @@
// gen/sqlc-wrapper generates store.go wrapper files for each sqlc driver package under
// internal/repository/<driver>/. Run via:
//
// go generate ./internal/repository/...
//
// The generator introspects *Queries methods and the model/params types in the
// driver package, then emits a store.go that wraps *Queries so it satisfies
// repository.Store using the canonical shared types in the parent package.
// This generator is specific to sqlc-generated drivers. Non-sqlc drivers should
// implement repository.Store directly by hand.
package main
import (
"bytes"
"flag"
"fmt"
"go/format"
"go/types"
"log"
"os"
"os/exec"
"path/filepath"
"sort"
"strings"
"text/template"
"golang.org/x/tools/go/packages"
)
func main() {
driverPkg := flag.String("pkg", "", "import path of the driver package")
out := flag.String("out", "store.go", "output filename relative to driver package directory")
flag.Parse()
if *driverPkg == "" {
log.Fatal("-pkg is required")
}
// Resolve the driver package directory so we can overlay the output file
// with a valid stub. This prevents a stale store.go from poisoning the
// type-checker and producing cryptic "undefined" errors.
driverDir, err := pkgDir(*driverPkg)
if err != nil {
log.Fatalf("resolve driver dir: %v", err)
}
outPath := filepath.Join(driverDir, *out)
if filepath.IsAbs(*out) {
outPath = *out
}
// Stub replaces the output file during load so stale generated code is ignored.
stub := []byte("package " + filepath.Base(driverDir) + "\n")
cfg := &packages.Config{
Mode: packages.NeedName | packages.NeedTypes | packages.NeedSyntax | packages.NeedImports,
Overlay: map[string][]byte{outPath: stub},
}
pkgs, err := packages.Load(cfg, *driverPkg)
if err != nil {
log.Fatalf("load %s: %v", *driverPkg, err)
}
if len(pkgs) != 1 {
log.Fatalf("expected 1 package, got %d", len(pkgs))
}
pkg := pkgs[0]
if len(pkg.Errors) > 0 {
for _, e := range pkg.Errors {
log.Printf("package error: %v", e)
}
log.Fatal("package has errors")
}
repoPkg := parentPkg(*driverPkg)
// Load the parent (repository) package so we can validate struct shapes.
repoPkgs, err := packages.Load(cfg, repoPkg)
if err != nil {
log.Fatalf("load repo pkg %s: %v", repoPkg, err)
}
if len(repoPkgs) != 1 || len(repoPkgs[0].Errors) > 0 {
log.Fatalf("could not load repo package %s cleanly", repoPkg)
}
if err := validateStructShapes(pkg.Types, repoPkgs[0].Types); err != nil {
log.Fatalf("struct shape mismatch: %v", err)
}
// Check *Queries covers every method in repository.Store before generating.
if err := validateStoreCoverage(pkg.Types, repoPkgs[0].Types); err != nil {
log.Fatalf("%v", err)
}
methods, err := collectMethods(pkg.Types)
if err != nil {
log.Fatal(err)
}
models, _ := collectTypes(pkg.Types)
data := tmplData{
PkgName: pkg.Name,
RepoPkg: repoPkg,
ModelTypes: models,
Methods: renderMethods(methods),
}
src, err := render(data)
if err != nil {
log.Fatalf("render: %v", err)
}
if err := os.WriteFile(outPath, src, 0644); err != nil {
log.Fatalf("write %s: %v", outPath, err)
}
fmt.Printf("wrote %s\n", outPath)
}
func parentPkg(imp string) string {
parts := strings.Split(imp, "/")
return strings.Join(parts[:len(parts)-1], "/")
}
// pkgDir returns the on-disk directory for an import path using `go list`.
func pkgDir(importPath string) (string, error) {
out, err := exec.Command("go", "list", "-f", "{{.Dir}}", importPath).Output()
if err != nil {
return "", fmt.Errorf("go list %s: %w", importPath, err)
}
return strings.TrimSpace(string(out)), nil
}
// validateStoreCoverage checks that every method declared in repository.Store
// exists on *Queries in the driver package. Missing methods are reported by
// name so the developer knows exactly which SQL queries need to be added.
func validateStoreCoverage(driverPkg, repoPkg *types.Package) error {
// Collect *Queries method names.
queriesObj := driverPkg.Scope().Lookup("Queries")
if queriesObj == nil {
return fmt.Errorf("Queries type not found in driver package")
}
queriesNamed := queriesObj.Type().(*types.Named)
queriesMS := types.NewMethodSet(types.NewPointer(queriesNamed))
queriesMethods := make(map[string]bool)
for m := range queriesMS.Methods() {
queriesMethods[m.Obj().Name()] = true
}
// Collect repository.Store interface methods.
storeObj := repoPkg.Scope().Lookup("Store")
if storeObj == nil {
return fmt.Errorf("Store type not found in repository package")
}
storeIface, ok := storeObj.Type().Underlying().(*types.Interface)
if !ok {
return fmt.Errorf("repository.Store is not an interface")
}
var missing []string
for i := range storeIface.NumMethods() {
name := storeIface.Method(i).Name()
if !queriesMethods[name] {
missing = append(missing, name)
}
}
if len(missing) > 0 {
sort.Strings(missing)
return fmt.Errorf(
"driver *Queries is missing %d method(s) required by repository.Store:\n - %s\n\nRun sqlc generate to regenerate query methods, or add the missing SQL queries.",
len(missing), strings.Join(missing, "\n - "),
)
}
return nil
}
type methodInfo struct {
Name string
Params []paramInfo
Results []resultInfo
}
type paramInfo struct {
Name string
TypeStr string // local (unqualified) type name
RepoType string // "repository.X" if this is a driver model/params type; else ""
}
type resultInfo struct {
TypeStr string
IsSlice bool
RepoType string // "repository.X" if driver type; else ""
}
func collectMethods(pkg *types.Package) ([]methodInfo, error) {
obj := pkg.Scope().Lookup("Queries")
if obj == nil {
return nil, fmt.Errorf("queries type not found in %s", pkg.Path())
}
named, ok := obj.Type().(*types.Named)
if !ok {
return nil, fmt.Errorf("queries is not a named type")
}
ms := types.NewMethodSet(types.NewPointer(named))
var out []methodInfo
for method := range ms.Methods() {
fn, ok := method.Obj().(*types.Func)
if !ok || fn.Name() == "WithTx" {
continue
}
sig := fn.Type().(*types.Signature)
mi := methodInfo{Name: fn.Name()}
// params: skip receiver + first (context.Context)
for i := 1; i < sig.Params().Len(); i++ {
p := sig.Params().At(i)
mi.Params = append(mi.Params, makeParam(p.Name(), p.Type(), pkg.Path()))
}
// results: skip error
for r := range sig.Results().Variables() {
if r.Type().String() == "error" {
continue
}
mi.Results = append(mi.Results, makeResult(r.Type(), pkg.Path()))
}
out = append(out, mi)
}
return out, nil
}
func makeParam(name string, t types.Type, driverPath string) paramInfo {
pi := paramInfo{Name: name}
pi.TypeStr = localName(t, driverPath)
pi.RepoType = repoName(t, driverPath)
return pi
}
func makeResult(t types.Type, driverPath string) resultInfo {
ri := resultInfo{}
if sl, ok := t.(*types.Slice); ok {
ri.IsSlice = true
t = sl.Elem()
}
ri.TypeStr = localName(t, driverPath)
ri.RepoType = repoName(t, driverPath)
return ri
}
func localName(t types.Type, driverPath string) string {
named, ok := t.(*types.Named)
if !ok {
return types.TypeString(t, nil)
}
if named.Obj().Pkg() != nil && named.Obj().Pkg().Path() == driverPath {
return named.Obj().Name()
}
return types.TypeString(t, func(p *types.Package) string { return p.Name() })
}
func repoName(t types.Type, driverPath string) string {
named, ok := t.(*types.Named)
if !ok {
return ""
}
if named.Obj().Pkg() != nil && named.Obj().Pkg().Path() == driverPath {
return "repository." + named.Obj().Name()
}
return ""
}
func collectTypes(pkg *types.Package) (models []string, params []string) {
for _, name := range pkg.Scope().Names() {
obj := pkg.Scope().Lookup(name)
if obj == nil {
continue
}
tn, ok := obj.(*types.TypeName)
if !ok {
continue
}
named, ok := tn.Type().(*types.Named)
if !ok {
continue
}
if _, ok := named.Underlying().(*types.Struct); !ok {
continue
}
switch name {
case "Queries", "DBTX", "Store":
continue
}
if strings.HasSuffix(name, "Params") {
params = append(params, name)
} else {
models = append(models, name)
}
}
return
}
// validateStructShapes checks that every model/params struct in the driver
// package has fields that exactly match the corresponding type in the repo
// (parent) package. This catches drift between sqlc-generated types and the
// canonical repository types before a broken cast reaches the compiler.
func validateStructShapes(driverPkg, repoPkg *types.Package) error {
var errs []string
for _, name := range driverPkg.Scope().Names() {
obj := driverPkg.Scope().Lookup(name)
if obj == nil {
continue
}
tn, ok := obj.(*types.TypeName)
if !ok {
continue
}
named, ok := tn.Type().(*types.Named)
if !ok {
continue
}
driverStruct, ok := named.Underlying().(*types.Struct)
if !ok {
continue
}
switch name {
case "Queries", "DBTX", "Store":
continue
}
repoObj := repoPkg.Scope().Lookup(name)
if repoObj == nil {
// Driver has a type not in repo — that's fine (e.g. internal helpers).
continue
}
repoNamed, ok := repoObj.Type().(*types.Named)
if !ok {
continue
}
repoStruct, ok := repoNamed.Underlying().(*types.Struct)
if !ok {
errs = append(errs, fmt.Sprintf("%s: repo type is not a struct", name))
continue
}
if err := compareStructs(name, driverStruct, repoStruct); err != nil {
errs = append(errs, err.Error())
}
}
if len(errs) > 0 {
return fmt.Errorf("%s", strings.Join(errs, "\n "))
}
return nil
}
func compareStructs(name string, driver, repo *types.Struct) error {
if driver.NumFields() != repo.NumFields() {
return fmt.Errorf("%s: field count mismatch (driver=%d, repo=%d)",
name, driver.NumFields(), repo.NumFields())
}
for i := range driver.NumFields() {
df := driver.Field(i)
rf := repo.Field(i)
if df.Name() != rf.Name() {
return fmt.Errorf("%s: field %d name mismatch (driver=%q, repo=%q)",
name, i, df.Name(), rf.Name())
}
if !types.Identical(df.Type(), rf.Type()) {
return fmt.Errorf("%s.%s: type mismatch (driver=%s, repo=%s)",
name, df.Name(), df.Type(), rf.Type())
}
}
return nil
}
// converterFn: "Session" -> "sessionToRepo"
func converterFn(s string) string {
if s == "" {
return ""
}
r := []rune(s)
r[0] = []rune(strings.ToLower(string(r[0])))[0]
return string(r) + "ToRepo"
}
// renderedMethod is the pre-built method body passed to the template.
type renderedMethod struct {
Signature string
Body string
}
// renderMethods converts []methodInfo into fully pre-rendered signature+body strings.
func renderMethods(methods []methodInfo) []renderedMethod {
var out []renderedMethod
for _, m := range methods {
out = append(out, renderedMethod{
Signature: buildSig(m),
Body: buildBody(m),
})
}
return out
}
func buildSig(m methodInfo) string {
var sb strings.Builder
sb.WriteString("func (s *Store) ")
sb.WriteString(m.Name)
sb.WriteString("(ctx context.Context")
for _, p := range m.Params {
sb.WriteString(", ")
sb.WriteString(p.Name)
sb.WriteString(" ")
if p.RepoType != "" {
sb.WriteString(p.RepoType)
} else {
sb.WriteString(p.TypeStr)
}
}
sb.WriteString(") (")
for _, r := range m.Results {
if r.IsSlice {
sb.WriteString("[]")
}
if r.RepoType != "" {
sb.WriteString(r.RepoType)
} else {
sb.WriteString(r.TypeStr)
}
sb.WriteString(", ")
}
sb.WriteString("error)")
return sb.String()
}
func callArgs(m methodInfo) string {
var args []string
for _, p := range m.Params {
if p.RepoType != "" {
// convert repo type → driver type: DriverType(arg)
args = append(args, p.TypeStr+"("+p.Name+")")
} else {
args = append(args, p.Name)
}
}
if len(args) == 0 {
return "ctx"
}
return "ctx, " + strings.Join(args, ", ")
}
func buildBody(m methodInfo) string {
call := "s.q." + m.Name + "(" + callArgs(m) + ")"
// no repo-typed result → direct return
if len(m.Results) == 0 || m.Results[0].RepoType == "" {
return "\treturn mapErr(" + call + ")\n"
}
r := m.Results[0]
if r.IsSlice {
return fmt.Sprintf(
"\trows, err := %s\n\tif err != nil {\n\t\treturn nil, mapErr(err)\n\t}\n\tout := make([]%s, len(rows))\n\tfor i, row := range rows {\n\t\tout[i] = %s(row)\n\t}\n\treturn out, nil\n",
call, r.RepoType, converterFn(r.TypeStr),
)
}
return fmt.Sprintf(
"\tr, err := %s\n\tif err != nil {\n\t\treturn %s{}, mapErr(err)\n\t}\n\treturn %s(r), nil\n",
call, r.RepoType, converterFn(r.TypeStr),
)
}
type tmplData struct {
PkgName string
RepoPkg string
ModelTypes []string
Methods []renderedMethod
}
const storeSrc = `// Code generated by cmd/gen/sqlc-wrapper. DO NOT EDIT.
package {{.PkgName}}
import (
"context"
"database/sql"
"errors"
"{{.RepoPkg}}"
)
// Store wraps *Queries and implements repository.Store.
type Store struct {
q *Queries
}
// NewStore wraps a *Queries to satisfy repository.Store.
func NewStore(q *Queries) repository.Store {
return &Store{q: q}
}
var errMap = []struct {
from error
to error
}{
{sql.ErrNoRows, repository.ErrNotFound},
}
func mapErr(err error) error {
for _, e := range errMap {
if errors.Is(err, e.from) {
return e.to
}
}
return err
}
{{range .ModelTypes -}}
func {{converterFn .}}(v {{.}}) repository.{{.}} {
return repository.{{.}}(v)
}
{{end -}}
{{range .Methods}}{{.Signature}} {
{{.Body}}}
{{end}}`
func render(data tmplData) ([]byte, error) {
t, err := template.New("store").Funcs(template.FuncMap{
"converterFn": converterFn,
}).Parse(storeSrc)
if err != nil {
return nil, fmt.Errorf("parse template: %w", err)
}
var buf bytes.Buffer
if err := t.Execute(&buf, data); err != nil {
return nil, fmt.Errorf("execute template: %w", err)
}
formatted, err := format.Source(buf.Bytes())
if err != nil {
return buf.Bytes(), fmt.Errorf("format source: %w\nraw:\n%s", err, buf.String())
}
return formatted, nil
}
+3 -3
View File
@@ -73,7 +73,7 @@ func generateTotpCmd() *cli.Command {
docker = true
}
if user.TotpSecret != "" {
if user.TOTPSecret != "" {
return fmt.Errorf("user already has a TOTP secret")
}
@@ -102,14 +102,14 @@ func generateTotpCmd() *cli.Command {
qrterminal.GenerateWithConfig(key.URL(), config)
user.TotpSecret = secret
user.TOTPSecret = secret
// If using docker escape re-escape it
if docker {
user.Password = strings.ReplaceAll(user.Password, "$", "$$")
}
tlog.App.Info().Str("user", fmt.Sprintf("%s:%s:%s", user.Username, user.Password, user.TotpSecret)).Msg("Add the totp secret to your authenticator app then use the verify command to ensure everything is working correctly.")
tlog.App.Info().Str("user", fmt.Sprintf("%s:%s:%s", user.Username, user.Password, user.TOTPSecret)).Msg("Add the totp secret to your authenticator app then use the verify command to ensure everything is working correctly.")
return nil
},
+4 -4
View File
@@ -5,7 +5,7 @@ import (
"charm.land/huh/v2"
"github.com/tinyauthapp/tinyauth/internal/bootstrap"
"github.com/tinyauthapp/tinyauth/internal/config"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/loaders"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
@@ -14,7 +14,7 @@ import (
)
func main() {
tConfig := config.NewDefaultConfiguration()
tConfig := model.NewDefaultConfiguration()
loaders := []cli.ResourceLoader{
&loaders.FileLoader{},
@@ -108,11 +108,11 @@ func main() {
}
}
func runCmd(cfg config.Config) error {
func runCmd(cfg model.Config) error {
logger := tlog.NewLogger(cfg.Log)
logger.Init()
tlog.App.Info().Str("version", config.Version).Msg("Starting tinyauth")
tlog.App.Info().Str("version", model.Version).Msg("Starting tinyauth")
app := bootstrap.NewBootstrapApp(cfg)
+2 -2
View File
@@ -95,7 +95,7 @@ func verifyUserCmd() *cli.Command {
return fmt.Errorf("password is incorrect: %w", err)
}
if user.TotpSecret == "" {
if user.TOTPSecret == "" {
if tCfg.Totp != "" {
tlog.App.Warn().Msg("User does not have TOTP secret")
}
@@ -103,7 +103,7 @@ func verifyUserCmd() *cli.Command {
return nil
}
ok := totp.Validate(tCfg.Totp, user.TotpSecret)
ok := totp.Validate(tCfg.Totp, user.TOTPSecret)
if !ok {
return fmt.Errorf("TOTP code incorrect")
+4 -5
View File
@@ -3,9 +3,8 @@ package main
import (
"fmt"
"github.com/tinyauthapp/tinyauth/internal/config"
"github.com/tinyauthapp/paerser/cli"
"github.com/tinyauthapp/tinyauth/internal/model"
)
func versionCmd() *cli.Command {
@@ -15,9 +14,9 @@ func versionCmd() *cli.Command {
Configuration: nil,
Resources: nil,
Run: func(_ []string) error {
fmt.Printf("Version: %s\n", config.Version)
fmt.Printf("Commit Hash: %s\n", config.CommitHash)
fmt.Printf("Build Timestamp: %s\n", config.BuildTimestamp)
fmt.Printf("Version: %s\n", model.Version)
fmt.Printf("Commit Hash: %s\n", model.CommitHash)
fmt.Printf("Build Timestamp: %s\n", model.BuildTimestamp)
return nil
},
}
+2 -2
View File
@@ -10,7 +10,7 @@ import (
"reflect"
"strings"
"github.com/tinyauthapp/tinyauth/internal/config"
"github.com/tinyauthapp/tinyauth/internal/model"
)
type EnvEntry struct {
@@ -20,7 +20,7 @@ type EnvEntry struct {
}
func generateExampleEnv() {
cfg := config.NewDefaultConfiguration()
cfg := model.NewDefaultConfiguration()
entries := make([]EnvEntry, 0)
root := reflect.TypeOf(cfg).Elem()
+2 -2
View File
@@ -10,7 +10,7 @@ import (
"reflect"
"strings"
"github.com/tinyauthapp/tinyauth/internal/config"
"github.com/tinyauthapp/tinyauth/internal/model"
)
type MarkdownEntry struct {
@@ -21,7 +21,7 @@ type MarkdownEntry struct {
}
func generateMarkdown() {
cfg := config.NewDefaultConfiguration()
cfg := model.NewDefaultConfiguration()
entries := make([]MarkdownEntry, 0)
root := reflect.TypeOf(cfg).Elem()
+1 -3
View File
@@ -20,8 +20,6 @@ require (
github.com/weppos/publicsuffix-go v0.50.3
golang.org/x/crypto v0.50.0
golang.org/x/oauth2 v0.36.0
golang.org/x/tools v0.43.0
gotest.tools/v3 v3.5.2
k8s.io/apimachinery v0.32.2
k8s.io/client-go v0.32.2
modernc.org/sqlite v1.49.1
@@ -125,7 +123,6 @@ require (
go.opentelemetry.io/otel/trace v1.43.0 // indirect
golang.org/x/arch v0.22.0 // indirect
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect
golang.org/x/mod v0.34.0 // indirect
golang.org/x/net v0.52.0 // indirect
golang.org/x/sync v0.20.0 // indirect
golang.org/x/sys v0.43.0 // indirect
@@ -135,6 +132,7 @@ require (
google.golang.org/protobuf v1.36.11 // indirect
gopkg.in/inf.v0 v0.9.1 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
gotest.tools/v3 v3.5.2 // indirect
k8s.io/klog/v2 v2.130.1 // indirect
k8s.io/utils v0.0.0-20241104100929-3ea5e8cea738 // indirect
modernc.org/libc v1.72.0 // indirect
+1 -1
View File
@@ -11,5 +11,5 @@ var FrontendAssets embed.FS
// Migrations
//
//go:embed migrations/sqlite/*.sql
//go:embed migrations/*.sql
var Migrations embed.FS
+23 -20
View File
@@ -12,15 +12,15 @@ import (
"strings"
"time"
"github.com/tinyauthapp/tinyauth/internal/config"
"github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
)
type BootstrapApp struct {
config config.Config
config model.Config
context struct {
appUrl string
uuid string
@@ -29,15 +29,15 @@ type BootstrapApp struct {
csrfCookieName string
redirectCookieName string
oauthSessionCookieName string
users []config.User
oauthProviders map[string]config.OAuthServiceConfig
localUsers *[]model.LocalUser
oauthProviders map[string]model.OAuthServiceConfig
configuredProviders []controller.Provider
oidcClients []config.OIDCClientConfig
oidcClients []model.OIDCClientConfig
}
services Services
}
func NewBootstrapApp(config config.Config) *BootstrapApp {
func NewBootstrapApp(config model.Config) *BootstrapApp {
return &BootstrapApp{
config: config,
}
@@ -69,7 +69,7 @@ func (app *BootstrapApp) Setup() error {
return err
}
app.context.users = users
app.context.localUsers = users
// Setup OAuth providers
app.context.oauthProviders = app.config.OAuth.Providers
@@ -88,7 +88,7 @@ func (app *BootstrapApp) Setup() error {
for id, provider := range app.context.oauthProviders {
if provider.Name == "" {
if name, ok := config.OverrideProviders[id]; ok {
if name, ok := model.OverrideProviders[id]; ok {
provider.Name = name
} else {
provider.Name = utils.Capitalize(id)
@@ -115,14 +115,14 @@ func (app *BootstrapApp) Setup() error {
// Cookie names
app.context.uuid = utils.GenerateUUID(appUrl.Hostname())
cookieId := strings.Split(app.context.uuid, "-")[0]
app.context.sessionCookieName = fmt.Sprintf("%s-%s", config.SessionCookieName, cookieId)
app.context.csrfCookieName = fmt.Sprintf("%s-%s", config.CSRFCookieName, cookieId)
app.context.redirectCookieName = fmt.Sprintf("%s-%s", config.RedirectCookieName, cookieId)
app.context.oauthSessionCookieName = fmt.Sprintf("%s-%s", config.OAuthSessionCookieName, cookieId)
app.context.sessionCookieName = fmt.Sprintf("%s-%s", model.SessionCookieName, cookieId)
app.context.csrfCookieName = fmt.Sprintf("%s-%s", model.CSRFCookieName, cookieId)
app.context.redirectCookieName = fmt.Sprintf("%s-%s", model.RedirectCookieName, cookieId)
app.context.oauthSessionCookieName = fmt.Sprintf("%s-%s", model.OAuthSessionCookieName, cookieId)
// Dumps
tlog.App.Trace().Interface("config", app.config).Msg("Config dump")
tlog.App.Trace().Interface("users", app.context.users).Msg("Users dump")
tlog.App.Trace().Interface("users", app.context.localUsers).Msg("Users dump")
tlog.App.Trace().Interface("oauthProviders", app.context.oauthProviders).Msg("OAuth providers dump")
tlog.App.Trace().Str("cookieDomain", app.context.cookieDomain).Msg("Cookie domain")
tlog.App.Trace().Str("sessionCookieName", app.context.sessionCookieName).Msg("Session cookie name")
@@ -130,14 +130,17 @@ func (app *BootstrapApp) Setup() error {
tlog.App.Trace().Str("redirectCookieName", app.context.redirectCookieName).Msg("Redirect cookie name")
// Database
store, err := app.SetupStore()
db, err := app.SetupDatabase(app.config.Database.Path)
if err != nil {
return fmt.Errorf("failed to setup database: %w", err)
}
// Queries
queries := repository.New(db)
// Services
services, err := app.initServices(store)
services, err := app.initServices(queries)
if err != nil {
return fmt.Errorf("failed to initialize services: %w", err)
@@ -168,7 +171,7 @@ func (app *BootstrapApp) Setup() error {
})
}
if services.authService.LdapAuthConfigured() {
if services.authService.LDAPAuthConfigured() {
configuredProviders = append(configuredProviders, controller.Provider{
Name: "LDAP",
ID: "ldap",
@@ -193,7 +196,7 @@ func (app *BootstrapApp) Setup() error {
// Start db cleanup routine
tlog.App.Debug().Msg("Starting database cleanup routine")
go app.dbCleanupRoutine(store)
go app.dbCleanupRoutine(queries)
// If analytics are not disabled, start heartbeat
if app.config.Analytics.Enabled {
@@ -241,7 +244,7 @@ func (app *BootstrapApp) heartbeatRoutine() {
var body heartbeat
body.UUID = app.context.uuid
body.Version = config.Version
body.Version = model.Version
bodyJson, err := json.Marshal(body)
@@ -254,7 +257,7 @@ func (app *BootstrapApp) heartbeatRoutine() {
Timeout: 30 * time.Second, // The server should never take more than 30 seconds to respond
}
heartbeatURL := config.ApiServer + "/v1/instances/heartbeat"
heartbeatURL := model.APIServer + "/v1/instances/heartbeat"
for range ticker.C {
tlog.App.Debug().Msg("Sending heartbeat")
@@ -283,7 +286,7 @@ func (app *BootstrapApp) heartbeatRoutine() {
}
}
func (app *BootstrapApp) dbCleanupRoutine(queries repository.Store) {
func (app *BootstrapApp) dbCleanupRoutine(queries *repository.Queries) {
ticker := time.NewTicker(time.Duration(30) * time.Minute)
defer ticker.Stop()
ctx := context.Background()
+3 -17
View File
@@ -7,9 +7,6 @@ import (
"path/filepath"
"github.com/tinyauthapp/tinyauth/internal/assets"
"github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
"github.com/tinyauthapp/tinyauth/internal/repository/sqlite"
"github.com/golang-migrate/migrate/v4"
"github.com/golang-migrate/migrate/v4/database/sqlite3"
@@ -17,18 +14,7 @@ import (
_ "modernc.org/sqlite"
)
func (app *BootstrapApp) SetupStore() (repository.Store, error) {
switch app.config.Database.Driver {
case "memory":
return memory.New(), nil
case "sqlite", "":
return app.setupSQLite(app.config.Database.Path)
default:
return nil, fmt.Errorf("unknown database driver %q: valid values are sqlite, memory", app.config.Database.Driver)
}
}
func (app *BootstrapApp) setupSQLite(databasePath string) (repository.Store, error) {
func (app *BootstrapApp) SetupDatabase(databasePath string) (*sql.DB, error) {
dir := filepath.Dir(databasePath)
if err := os.MkdirAll(dir, 0750); err != nil {
@@ -45,7 +31,7 @@ func (app *BootstrapApp) setupSQLite(databasePath string) (repository.Store, err
// if the sqlite connection starts being a bottleneck
db.SetMaxOpenConns(1)
migrations, err := iofs.New(assets.Migrations, "migrations/sqlite")
migrations, err := iofs.New(assets.Migrations, "migrations")
if err != nil {
return nil, fmt.Errorf("failed to create migrations: %w", err)
@@ -67,5 +53,5 @@ func (app *BootstrapApp) setupSQLite(databasePath string) (repository.Store, err
return nil, fmt.Errorf("failed to migrate database: %w", err)
}
return sqlite.NewStore(sqlite.New(db)), nil
return db, nil
}
+6 -4
View File
@@ -4,9 +4,9 @@ import (
"fmt"
"slices"
"github.com/tinyauthapp/tinyauth/internal/config"
"github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/middleware"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/gin-gonic/gin"
)
@@ -14,7 +14,7 @@ import (
var DEV_MODES = []string{"main", "test", "development"}
func (app *BootstrapApp) setupRouter() (*gin.Engine, error) {
if !slices.Contains(DEV_MODES, config.Version) {
if !slices.Contains(DEV_MODES, model.Version) {
gin.SetMode(gin.ReleaseMode)
}
@@ -30,7 +30,8 @@ func (app *BootstrapApp) setupRouter() (*gin.Engine, error) {
}
contextMiddleware := middleware.NewContextMiddleware(middleware.ContextMiddlewareConfig{
CookieDomain: app.context.cookieDomain,
CookieDomain: app.context.cookieDomain,
SessionCookieName: app.context.sessionCookieName,
}, app.services.authService, app.services.oauthBrokerService)
err := contextMiddleware.Init()
@@ -98,7 +99,8 @@ func (app *BootstrapApp) setupRouter() (*gin.Engine, error) {
proxyController.SetupRoutes()
userController := controller.NewUserController(controller.UserControllerConfig{
CookieDomain: app.context.cookieDomain,
CookieDomain: app.context.cookieDomain,
SessionCookieName: app.context.sessionCookieName,
}, apiRouter, app.services.authService)
userController.SetupRoutes()
+11 -11
View File
@@ -18,18 +18,18 @@ type Services struct {
oidcService *service.OIDCService
}
func (app *BootstrapApp) initServices(queries repository.Store) (Services, error) {
func (app *BootstrapApp) initServices(queries *repository.Queries) (Services, error) {
services := Services{}
ldapService := service.NewLdapService(service.LdapServiceConfig{
Address: app.config.Ldap.Address,
BindDN: app.config.Ldap.BindDN,
BindPassword: app.config.Ldap.BindPassword,
BaseDN: app.config.Ldap.BaseDN,
Insecure: app.config.Ldap.Insecure,
SearchFilter: app.config.Ldap.SearchFilter,
AuthCert: app.config.Ldap.AuthCert,
AuthKey: app.config.Ldap.AuthKey,
Address: app.config.LDAP.Address,
BindDN: app.config.LDAP.BindDN,
BindPassword: app.config.LDAP.BindPassword,
BaseDN: app.config.LDAP.BaseDN,
Insecure: app.config.LDAP.Insecure,
SearchFilter: app.config.LDAP.SearchFilter,
AuthCert: app.config.LDAP.AuthCert,
AuthKey: app.config.LDAP.AuthKey,
})
err := ldapService.Init()
@@ -89,7 +89,7 @@ func (app *BootstrapApp) initServices(queries repository.Store) (Services, error
services.oauthBrokerService = oauthBrokerService
authService := service.NewAuthService(service.AuthServiceConfig{
Users: app.context.users,
LocalUsers: app.context.localUsers,
OauthWhitelist: app.config.OAuth.Whitelist,
SessionExpiry: app.config.Auth.SessionExpiry,
SessionMaxLifetime: app.config.Auth.SessionMaxLifetime,
@@ -99,7 +99,7 @@ func (app *BootstrapApp) initServices(queries repository.Store) (Services, error
LoginMaxRetries: app.config.Auth.LoginMaxRetries,
SessionCookieName: app.context.sessionCookieName,
IP: app.config.Auth.IP,
LDAPGroupsCacheTTL: app.config.Ldap.GroupCacheTTL,
LDAPGroupsCacheTTL: app.config.LDAP.GroupCacheTTL,
}, services.ldapService, queries, services.oauthBrokerService)
err = authService.Init()
+21 -20
View File
@@ -4,7 +4,7 @@ import (
"fmt"
"net/url"
"github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/gin-gonic/gin"
@@ -19,7 +19,7 @@ type UserContextResponse struct {
Email string `json:"email"`
Provider string `json:"provider"`
OAuth bool `json:"oauth"`
TotpPending bool `json:"totpPending"`
TOTPPending bool `json:"totpPending"`
OAuthName string `json:"oauthName"`
}
@@ -76,28 +76,29 @@ func (controller *ContextController) SetupRoutes() {
}
func (controller *ContextController) userContextHandler(c *gin.Context) {
context, err := utils.GetContext(c)
context, err := new(model.UserContext).NewFromGin(c)
if err != nil {
tlog.App.Debug().Err(err).Msg("No user context found in request")
c.JSON(200, UserContextResponse{
Status: 401,
Message: "Unauthorized",
IsLoggedIn: false,
})
return
}
userContext := UserContextResponse{
Status: 200,
Message: "Success",
IsLoggedIn: context.IsLoggedIn,
Username: context.Username,
Name: context.Name,
Email: context.Email,
Provider: context.Provider,
OAuth: context.OAuth,
TotpPending: context.TotpPending,
OAuthName: context.OAuthName,
}
if err != nil {
tlog.App.Debug().Err(err).Msg("No user context found in request")
userContext.Status = 401
userContext.Message = "Unauthorized"
userContext.IsLoggedIn = false
c.JSON(200, userContext)
return
IsLoggedIn: context.Authenticated,
Username: context.GetUsername(),
Name: context.GetName(),
Email: context.GetEmail(),
Provider: context.ProviderName(),
OAuth: context.IsOAuth(),
TOTPPending: context.TOTPPending(),
OAuthName: context.OAuthName(),
}
c.JSON(200, userContext)
+12 -8
View File
@@ -7,11 +7,11 @@ import (
"testing"
"github.com/gin-gonic/gin"
"github.com/tinyauthapp/tinyauth/internal/config"
"github.com/stretchr/testify/assert"
"github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/stretchr/testify/assert"
)
func TestContextController(t *testing.T) {
@@ -79,12 +79,16 @@ func TestContextController(t *testing.T) {
description: "Ensure user context returns when authorized",
middlewares: []gin.HandlerFunc{
func(c *gin.Context) {
c.Set("context", &config.UserContext{
Username: "johndoe",
Name: "John Doe",
Email: utils.CompileUserEmail("johndoe", controllerConfig.CookieDomain),
Provider: "local",
IsLoggedIn: true,
c.Set("context", &model.UserContext{
Authenticated: true,
Provider: model.ProviderLocal,
Local: &model.LocalContext{
BaseContext: model.BaseContext{
Username: "johndoe",
Name: "John Doe",
Email: utils.CompileUserEmail("johndoe", controllerConfig.CookieDomain),
},
},
})
},
},
+12
View File
@@ -0,0 +1,12 @@
package controller
type UnauthorizedQuery struct {
Username string `url:"username"`
Resource string `url:"resource"`
GroupErr bool `url:"groupErr"`
IP string `url:"ip"`
}
type RedirectQuery struct {
RedirectURI string `url:"redirect_uri"`
}
+5 -4
View File
@@ -6,7 +6,6 @@ import (
"strings"
"time"
"github.com/tinyauthapp/tinyauth/internal/config"
"github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils"
@@ -176,7 +175,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
tlog.App.Warn().Str("email", user.Email).Msg("Email not whitelisted")
tlog.AuditLoginFailure(c, user.Email, req.Provider, "email not whitelisted")
queries, err := query.Values(config.UnauthorizedQuery{
queries, err := query.Values(UnauthorizedQuery{
Username: user.Email,
})
@@ -236,7 +235,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
tlog.App.Trace().Interface("session_cookie", sessionCookie).Msg("Creating session cookie")
err = controller.auth.CreateSessionCookie(c, &sessionCookie)
cookie, err := controller.auth.CreateSession(c, sessionCookie)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to create session cookie")
@@ -244,6 +243,8 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
return
}
http.SetCookie(c.Writer, cookie)
tlog.AuditLoginSuccess(c, sessionCookie.Username, sessionCookie.Provider)
if controller.isOidcRequest(oauthPendingSession.CallbackParams) {
@@ -259,7 +260,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
}
if oauthPendingSession.CallbackParams.RedirectURI != "" {
queries, err := query.Values(config.RedirectQuery{
queries, err := query.Values(RedirectQuery{
RedirectURI: oauthPendingSession.CallbackParams.RedirectURI,
})
+5 -4
View File
@@ -10,6 +10,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/google/go-querystring/query"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
@@ -111,14 +112,14 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
return
}
userContext, err := utils.GetContext(c)
userContext, err := new(model.UserContext).NewFromGin(c)
if err != nil {
controller.authorizeError(c, err, "Failed to get user context", "User is not logged in or the session is invalid", "", "", "")
return
}
if !userContext.IsLoggedIn {
if !userContext.Authenticated {
controller.authorizeError(c, errors.New("err user not logged in"), "User not logged in", "The user is not logged in", "", "", "")
return
}
@@ -151,7 +152,7 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
}
// WARNING: Since Tinyauth is stateless, we cannot have a sub that never changes. We will just create a uuid out of the username and client name which remains stable, but if username or client name changes then sub changes too.
sub := utils.GenerateUUID(fmt.Sprintf("%s:%s", userContext.Username, client.ID))
sub := utils.GenerateUUID(fmt.Sprintf("%s:%s", userContext.GetUsername(), client.ID))
code := utils.GenerateString(32)
// Before storing the code, delete old session
@@ -170,7 +171,7 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
// We also need a snapshot of the user that authorized this (skip if no openid scope)
if slices.Contains(strings.Fields(req.Scope), "openid") {
err = controller.oidc.StoreUserinfo(c, sub, userContext, req)
err = controller.oidc.StoreUserinfo(c, sub, *userContext, req)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to insert user info into database")
+29 -15
View File
@@ -12,13 +12,14 @@ import (
"github.com/gin-gonic/gin"
"github.com/google/go-querystring/query"
"github.com/tinyauthapp/tinyauth/internal/config"
"github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
"github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/bootstrap"
"github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
)
func TestOIDCController(t *testing.T) {
@@ -26,7 +27,7 @@ func TestOIDCController(t *testing.T) {
tempDir := t.TempDir()
oidcServiceCfg := service.OIDCServiceConfig{
Clients: map[string]config.OIDCClientConfig{
Clients: map[string]model.OIDCClientConfig{
"test": {
ClientID: "some-client-id",
ClientSecret: "some-client-secret",
@@ -43,12 +44,16 @@ func TestOIDCController(t *testing.T) {
controllerCfg := controller.OIDCControllerConfig{}
simpleCtx := func(c *gin.Context) {
c.Set("context", &config.UserContext{
Username: "test",
Name: "Test User",
Email: "test@example.com",
IsLoggedIn: true,
Provider: "local",
c.Set("context", &model.UserContext{
Authenticated: true,
Provider: model.ProviderLocal,
Local: &model.LocalContext{
BaseContext: model.BaseContext{
Username: "test",
Name: "Test User",
Email: "test@example.com",
},
},
})
c.Next()
}
@@ -847,10 +852,14 @@ func TestOIDCController(t *testing.T) {
},
}
store := memory.New()
app := bootstrap.NewBootstrapApp(model.Config{})
oidcService := service.NewOIDCService(oidcServiceCfg, store)
err := oidcService.Init()
db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db"))
require.NoError(t, err)
queries := repository.New(db)
oidcService := service.NewOIDCService(oidcServiceCfg, queries)
err = oidcService.Init()
require.NoError(t, err)
for _, test := range tests {
@@ -872,4 +881,9 @@ func TestOIDCController(t *testing.T) {
test.run(t, router, recorder)
})
}
t.Cleanup(func() {
err = db.Close()
require.NoError(t, err)
})
}
+41 -40
View File
@@ -8,7 +8,7 @@ import (
"regexp"
"strings"
"github.com/tinyauthapp/tinyauth/internal/config"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
@@ -103,7 +103,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
clientIP := c.ClientIP()
if controller.auth.IsBypassedIP(acls.IP, clientIP) {
if controller.auth.IsBypassedIP(clientIP, acls) {
controller.setHeaders(c, acls)
c.JSON(200, gin.H{
"status": 200,
@@ -112,7 +112,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
return
}
authEnabled, err := controller.auth.IsAuthEnabled(proxyCtx.Path, acls.Path)
authEnabled, err := controller.auth.IsAuthEnabled(proxyCtx.Path, acls)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to check if auth is enabled for resource")
@@ -130,8 +130,8 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
return
}
if !controller.auth.CheckIP(acls.IP, clientIP) {
queries, err := query.Values(config.UnauthorizedQuery{
if !controller.auth.CheckIP(clientIP, acls) {
queries, err := query.Values(UnauthorizedQuery{
Resource: strings.Split(proxyCtx.Host, ".")[0],
IP: clientIP,
})
@@ -157,28 +157,24 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
return
}
var userContext config.UserContext
context, err := utils.GetContext(c)
userContext, err := new(model.UserContext).NewFromGin(c)
if err != nil {
tlog.App.Debug().Msg("No user context found in request, treating as not logged in")
userContext = config.UserContext{
IsLoggedIn: false,
tlog.App.Debug().Err(err).Msg("No user context found in request, treating as unauthenticated")
userContext = &model.UserContext{
Authenticated: false,
}
} else {
userContext = context
}
tlog.App.Trace().Interface("context", userContext).Msg("User context from request")
if userContext.IsLoggedIn {
userAllowed := controller.auth.IsUserAllowed(c, userContext, acls)
if userContext.Authenticated {
userAllowed := controller.auth.IsUserAllowed(c, *userContext, acls)
if !userAllowed {
tlog.App.Warn().Str("user", userContext.Username).Str("resource", strings.Split(proxyCtx.Host, ".")[0]).Msg("User not allowed to access resource")
tlog.App.Warn().Str("user", userContext.GetUsername()).Str("resource", strings.Split(proxyCtx.Host, ".")[0]).Msg("User not allowed to access resource")
queries, err := query.Values(config.UnauthorizedQuery{
queries, err := query.Values(UnauthorizedQuery{
Resource: strings.Split(proxyCtx.Host, ".")[0],
})
@@ -188,10 +184,10 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
return
}
if userContext.OAuth {
queries.Set("username", userContext.Email)
if userContext.IsOAuth() {
queries.Set("username", userContext.GetEmail())
} else {
queries.Set("username", userContext.Username)
queries.Set("username", userContext.GetUsername())
}
redirectURL := fmt.Sprintf("%s/unauthorized?%s", controller.config.AppURL, queries.Encode())
@@ -209,19 +205,19 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
return
}
if userContext.OAuth || userContext.Provider == "ldap" {
if userContext.IsOAuth() || userContext.IsLDAP() {
var groupOK bool
if userContext.OAuth {
groupOK = controller.auth.IsInOAuthGroup(c, userContext, acls.OAuth.Groups)
if userContext.IsOAuth() {
groupOK = controller.auth.IsInOAuthGroup(c, *userContext, acls)
} else {
groupOK = controller.auth.IsInLdapGroup(c, userContext, acls.LDAP.Groups)
groupOK = controller.auth.IsInLDAPGroup(c, *userContext, acls)
}
if !groupOK {
tlog.App.Warn().Str("user", userContext.Username).Str("resource", strings.Split(proxyCtx.Host, ".")[0]).Msg("User groups do not match resource requirements")
tlog.App.Warn().Str("user", userContext.GetUsername()).Str("resource", strings.Split(proxyCtx.Host, ".")[0]).Msg("User groups do not match resource requirements")
queries, err := query.Values(config.UnauthorizedQuery{
queries, err := query.Values(UnauthorizedQuery{
Resource: strings.Split(proxyCtx.Host, ".")[0],
GroupErr: true,
})
@@ -232,10 +228,10 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
return
}
if userContext.OAuth {
queries.Set("username", userContext.Email)
if userContext.IsOAuth() {
queries.Set("username", userContext.GetEmail())
} else {
queries.Set("username", userContext.Username)
queries.Set("username", userContext.GetUsername())
}
redirectURL := fmt.Sprintf("%s/unauthorized?%s", controller.config.AppURL, queries.Encode())
@@ -254,17 +250,18 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
}
}
c.Header("Remote-User", utils.SanitizeHeader(userContext.Username))
c.Header("Remote-Name", utils.SanitizeHeader(userContext.Name))
c.Header("Remote-Email", utils.SanitizeHeader(userContext.Email))
c.Header("Remote-User", utils.SanitizeHeader(userContext.GetUsername()))
c.Header("Remote-Name", utils.SanitizeHeader(userContext.GetName()))
c.Header("Remote-Email", utils.SanitizeHeader(userContext.GetEmail()))
if userContext.Provider == "ldap" {
c.Header("Remote-Groups", utils.SanitizeHeader(userContext.LdapGroups))
} else if userContext.Provider != "local" {
c.Header("Remote-Groups", utils.SanitizeHeader(userContext.OAuthGroups))
if userContext.IsLDAP() {
c.Header("Remote-Groups", utils.SanitizeHeader(strings.Join(userContext.LDAP.Groups, ",")))
}
c.Header("Remote-Sub", utils.SanitizeHeader(userContext.OAuthSub))
if userContext.IsOAuth() {
c.Header("Remote-Groups", utils.SanitizeHeader(strings.Join(userContext.OAuth.Groups, ",")))
c.Header("Remote-Sub", utils.SanitizeHeader(userContext.OAuth.Sub))
}
controller.setHeaders(c, acls)
@@ -275,7 +272,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
return
}
queries, err := query.Values(config.RedirectQuery{
queries, err := query.Values(RedirectQuery{
RedirectURI: fmt.Sprintf("%s://%s%s", proxyCtx.Proto, proxyCtx.Host, proxyCtx.Path),
})
@@ -299,9 +296,13 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
c.Redirect(http.StatusTemporaryRedirect, redirectURL)
}
func (controller *ProxyController) setHeaders(c *gin.Context, acls config.App) {
func (controller *ProxyController) setHeaders(c *gin.Context, acls *model.App) {
c.Header("Authorization", c.Request.Header.Get("Authorization"))
if acls == nil {
return
}
headers := utils.ParseHeaders(acls.Response.Headers)
for key, value := range headers {
@@ -313,7 +314,7 @@ func (controller *ProxyController) setHeaders(c *gin.Context, acls config.App) {
if acls.Response.BasicAuth.Username != "" && basicPassword != "" {
tlog.App.Debug().Str("username", acls.Response.BasicAuth.Username).Msg("Setting basic auth header")
c.Header("Authorization", fmt.Sprintf("Basic %s", utils.GetBasicAuth(acls.Response.BasicAuth.Username, basicPassword)))
c.Header("Authorization", fmt.Sprintf("Basic %s", utils.EncodeBasicAuth(acls.Response.BasicAuth.Username, basicPassword)))
}
}
+52 -31
View File
@@ -2,23 +2,26 @@ package controller_test
import (
"net/http/httptest"
"path"
"testing"
"github.com/gin-gonic/gin"
"github.com/tinyauthapp/tinyauth/internal/config"
"github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
"github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/bootstrap"
"github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
)
func TestProxyController(t *testing.T) {
tlog.NewTestLogger().Init()
tempDir := t.TempDir()
authServiceCfg := service.AuthServiceConfig{
Users: []config.User{
LocalUsers: &[]model.LocalUser{
{
Username: "testuser",
Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
@@ -26,7 +29,7 @@ func TestProxyController(t *testing.T) {
{
Username: "totpuser",
Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
TotpSecret: "JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK",
TOTPSecret: "JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK",
},
},
SessionExpiry: 10, // 10 seconds, useful for testing
@@ -40,28 +43,28 @@ func TestProxyController(t *testing.T) {
AppURL: "https://tinyauth.example.com",
}
acls := map[string]config.App{
acls := map[string]model.App{
"app_path_allow": {
Config: config.AppConfig{
Config: model.AppConfig{
Domain: "path-allow.example.com",
},
Path: config.AppPath{
Path: model.AppPath{
Allow: "/allowed",
},
},
"app_user_allow": {
Config: config.AppConfig{
Config: model.AppConfig{
Domain: "user-allow.example.com",
},
Users: config.AppUsers{
Users: model.AppUsers{
Allow: "testuser",
},
},
"ip_bypass": {
Config: config.AppConfig{
Config: model.AppConfig{
Domain: "ip-bypass.example.com",
},
IP: config.AppIP{
IP: model.AppIP{
Bypass: []string{"10.10.10.10"},
},
},
@@ -71,24 +74,32 @@ func TestProxyController(t *testing.T) {
Mozilla/5.0 (Linux; Android 8.0.0; SM-G955U Build/R16NW) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/146.0.0.0 Mobile Safari/537.36`
simpleCtx := func(c *gin.Context) {
c.Set("context", &config.UserContext{
Username: "testuser",
Name: "Testuser",
Email: "testuser@example.com",
IsLoggedIn: true,
Provider: "local",
c.Set("context", &model.UserContext{
Authenticated: true,
Provider: model.ProviderLocal,
Local: &model.LocalContext{
BaseContext: model.BaseContext{
Username: "testuser",
Name: "Testuser",
Email: "testuser@example.com",
},
},
})
c.Next()
}
simpleCtxTotp := func(c *gin.Context) {
c.Set("context", &config.UserContext{
Username: "totpuser",
Name: "Totpuser",
Email: "totpuser@example.com",
IsLoggedIn: true,
Provider: "local",
TotpEnabled: true,
c.Set("context", &model.UserContext{
Authenticated: true,
Provider: model.ProviderLocal,
Local: &model.LocalContext{
BaseContext: model.BaseContext{
Username: "totpuser",
Name: "Totpuser",
Email: "totpuser@example.com",
},
TOTPEnabled: true,
},
})
c.Next()
}
@@ -388,12 +399,17 @@ func TestProxyController(t *testing.T) {
},
}
oauthBrokerCfgs := make(map[string]config.OAuthServiceConfig)
oauthBrokerCfgs := make(map[string]model.OAuthServiceConfig)
store := memory.New()
app := bootstrap.NewBootstrapApp(model.Config{})
db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db"))
require.NoError(t, err)
queries := repository.New(db)
docker := service.NewDockerService()
err := docker.Init()
err = docker.Init()
require.NoError(t, err)
ldap := service.NewLdapService(service.LdapServiceConfig{})
@@ -404,7 +420,7 @@ func TestProxyController(t *testing.T) {
err = broker.Init()
require.NoError(t, err)
authService := service.NewAuthService(authServiceCfg, ldap, store, broker)
authService := service.NewAuthService(authServiceCfg, ldap, queries, broker)
err = authService.Init()
require.NoError(t, err)
@@ -429,4 +445,9 @@ func TestProxyController(t *testing.T) {
test.run(t, router, recorder)
})
}
t.Cleanup(func() {
err = db.Close()
require.NoError(t, err)
})
}
+104 -48
View File
@@ -1,10 +1,12 @@
package controller
import (
"errors"
"fmt"
"net/http"
"time"
"github.com/tinyauthapp/tinyauth/internal/config"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils"
@@ -24,7 +26,8 @@ type TotpRequest struct {
}
type UserControllerConfig struct {
CookieDomain string
CookieDomain string
SessionCookieName string
}
type UserController struct {
@@ -77,20 +80,28 @@ func (controller *UserController) loginHandler(c *gin.Context) {
return
}
userSearch := controller.auth.SearchUser(req.Username)
search, err := controller.auth.SearchUser(req.Username)
if userSearch.Type == "unknown" {
tlog.App.Warn().Str("username", req.Username).Msg("User not found")
controller.auth.RecordLoginAttempt(req.Username, false)
tlog.AuditLoginFailure(c, req.Username, "username", "user not found")
c.JSON(401, gin.H{
"status": 401,
"message": "Unauthorized",
if err != nil {
if errors.Is(err, service.ErrUserNotFound) {
tlog.App.Warn().Str("username", req.Username).Msg("User not found")
controller.auth.RecordLoginAttempt(req.Username, false)
tlog.AuditLoginFailure(c, req.Username, "username", "user not found")
c.JSON(401, gin.H{
"status": 401,
"message": "Unauthorized",
})
return
}
tlog.App.Error().Err(err).Str("username", req.Username).Msg("Error searching for user")
c.JSON(500, gin.H{
"status": 500,
"message": "Internal Server Error",
})
return
}
if !controller.auth.VerifyUser(userSearch, req.Password) {
if err := controller.auth.CheckUserPassword(*search, req.Password); err != nil {
tlog.App.Warn().Str("username", req.Username).Msg("Invalid password")
controller.auth.RecordLoginAttempt(req.Username, false)
tlog.AuditLoginFailure(c, req.Username, "username", "invalid password")
@@ -106,30 +117,26 @@ func (controller *UserController) loginHandler(c *gin.Context) {
controller.auth.RecordLoginAttempt(req.Username, true)
var localUser *config.User
if userSearch.Type == "local" {
user := controller.auth.GetLocalUser(userSearch.Username)
localUser = &user
}
var localUser *model.LocalUser
if userSearch.Type == "local" && localUser != nil {
user := *localUser
if search.Type == model.UserLocal {
localUser = controller.auth.GetLocalUser(req.Username)
if user.TotpSecret != "" {
if localUser.TOTPSecret != "" {
tlog.App.Debug().Str("username", req.Username).Msg("User has TOTP enabled, requiring TOTP verification")
name := user.Attributes.Name
name := localUser.Attributes.Name
if name == "" {
name = utils.Capitalize(user.Username)
name = utils.Capitalize(localUser.Username)
}
email := user.Attributes.Email
email := localUser.Attributes.Email
if email == "" {
email = utils.CompileUserEmail(user.Username, controller.config.CookieDomain)
email = utils.CompileUserEmail(localUser.Username, controller.config.CookieDomain)
}
err := controller.auth.CreateSessionCookie(c, &repository.Session{
Username: user.Username,
cookie, err := controller.auth.CreateSession(c, repository.Session{
Username: localUser.Username,
Name: name,
Email: email,
Provider: "local",
@@ -145,6 +152,8 @@ func (controller *UserController) loginHandler(c *gin.Context) {
return
}
http.SetCookie(c.Writer, cookie)
c.JSON(200, gin.H{
"status": 200,
"message": "TOTP required",
@@ -161,7 +170,7 @@ func (controller *UserController) loginHandler(c *gin.Context) {
Provider: "local",
}
if userSearch.Type == "local" && localUser != nil {
if search.Type == model.UserLocal {
if localUser.Attributes.Name != "" {
sessionCookie.Name = localUser.Attributes.Name
}
@@ -170,13 +179,13 @@ func (controller *UserController) loginHandler(c *gin.Context) {
}
}
if userSearch.Type == "ldap" {
if search.Type == model.UserLDAP {
sessionCookie.Provider = "ldap"
}
tlog.App.Trace().Interface("session_cookie", sessionCookie).Msg("Creating session cookie")
err = controller.auth.CreateSessionCookie(c, &sessionCookie)
cookie, err := controller.auth.CreateSession(c, sessionCookie)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to create session cookie")
@@ -187,6 +196,8 @@ func (controller *UserController) loginHandler(c *gin.Context) {
return
}
http.SetCookie(c.Writer, cookie)
c.JSON(200, gin.H{
"status": 200,
"message": "Login successful",
@@ -196,13 +207,47 @@ func (controller *UserController) loginHandler(c *gin.Context) {
func (controller *UserController) logoutHandler(c *gin.Context) {
tlog.App.Debug().Msg("Logout request received")
controller.auth.DeleteSessionCookie(c)
uuid, err := c.Cookie(controller.config.SessionCookieName)
context, err := utils.GetContext(c)
if err == nil && context.IsLoggedIn {
tlog.AuditLogout(c, context.Username, context.Provider)
if err != nil {
if errors.Is(err, http.ErrNoCookie) {
tlog.App.Warn().Msg("No session cookie found on logout request")
c.JSON(200, gin.H{
"status": 200,
"message": "Logout successful",
})
return
}
tlog.App.Error().Err(err).Msg("Error retrieving session cookie on logout")
c.JSON(500, gin.H{
"status": 500,
"message": "Internal Server Error",
})
return
}
cookie, err := controller.auth.DeleteSession(c, uuid)
if err != nil {
tlog.App.Error().Err(err).Msg("Error deleting session on logout")
c.JSON(500, gin.H{
"status": 500,
"message": "Internal Server Error",
})
return
}
context, err := new(model.UserContext).NewFromGin(c)
if err == nil {
tlog.AuditLogout(c, context.GetUsername(), context.ProviderName())
} else {
tlog.App.Warn().Err(err).Msg("Failed to get user context for logout audit, proceeding without username")
tlog.AuditLogout(c, "unknown", "unknown")
}
http.SetCookie(c.Writer, cookie)
c.JSON(200, gin.H{
"status": 200,
"message": "Logout successful",
@@ -222,7 +267,7 @@ func (controller *UserController) totpHandler(c *gin.Context) {
return
}
context, err := utils.GetContext(c)
context, err := new(model.UserContext).NewFromGin(c)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to get user context")
@@ -233,7 +278,7 @@ func (controller *UserController) totpHandler(c *gin.Context) {
return
}
if !context.TotpPending {
if !context.TOTPPending() {
tlog.App.Warn().Msg("TOTP attempt without a pending TOTP session")
c.JSON(401, gin.H{
"status": 401,
@@ -242,12 +287,12 @@ func (controller *UserController) totpHandler(c *gin.Context) {
return
}
tlog.App.Debug().Str("username", context.Username).Msg("TOTP verification attempt")
tlog.App.Debug().Str("username", context.GetUsername()).Msg("TOTP verification attempt")
isLocked, remaining := controller.auth.IsAccountLocked(context.Username)
isLocked, remaining := controller.auth.IsAccountLocked(context.GetUsername())
if isLocked {
tlog.App.Warn().Str("username", context.Username).Msg("Account is locked due to too many failed TOTP attempts")
tlog.App.Warn().Str("username", context.GetUsername()).Msg("Account is locked due to too many failed TOTP attempts")
c.Writer.Header().Add("x-tinyauth-lock-locked", "true")
c.Writer.Header().Add("x-tinyauth-lock-reset", time.Now().Add(time.Duration(remaining)*time.Second).Format(time.RFC3339))
c.JSON(429, gin.H{
@@ -257,14 +302,10 @@ func (controller *UserController) totpHandler(c *gin.Context) {
return
}
user := controller.auth.GetLocalUser(context.Username)
user := controller.auth.GetLocalUser(context.GetUsername())
ok := totp.Validate(req.Code, user.TotpSecret)
if !ok {
tlog.App.Warn().Str("username", context.Username).Msg("Invalid TOTP code")
controller.auth.RecordLoginAttempt(context.Username, false)
tlog.AuditLoginFailure(c, context.Username, "totp", "invalid totp code")
if user == nil {
tlog.App.Error().Str("username", context.GetUsername()).Msg("User not found in TOTP handler")
c.JSON(401, gin.H{
"status": 401,
"message": "Unauthorized",
@@ -272,10 +313,23 @@ func (controller *UserController) totpHandler(c *gin.Context) {
return
}
tlog.App.Info().Str("username", context.Username).Msg("TOTP verification successful")
tlog.AuditLoginSuccess(c, context.Username, "totp")
ok := totp.Validate(req.Code, user.TOTPSecret)
controller.auth.RecordLoginAttempt(context.Username, true)
if !ok {
tlog.App.Warn().Str("username", context.GetUsername()).Msg("Invalid TOTP code")
controller.auth.RecordLoginAttempt(context.GetUsername(), false)
tlog.AuditLoginFailure(c, context.GetUsername(), "totp", "invalid totp code")
c.JSON(401, gin.H{
"status": 401,
"message": "Unauthorized",
})
return
}
tlog.App.Info().Str("username", context.GetUsername()).Msg("TOTP verification successful")
tlog.AuditLoginSuccess(c, context.GetUsername(), "totp")
controller.auth.RecordLoginAttempt(context.GetUsername(), true)
sessionCookie := repository.Session{
Username: user.Username,
@@ -293,7 +347,7 @@ func (controller *UserController) totpHandler(c *gin.Context) {
tlog.App.Trace().Interface("session_cookie", sessionCookie).Msg("Creating session cookie")
err = controller.auth.CreateSessionCookie(c, &sessionCookie)
cookie, err := controller.auth.CreateSession(c, sessionCookie)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to create session cookie")
@@ -304,6 +358,8 @@ func (controller *UserController) totpHandler(c *gin.Context) {
return
}
http.SetCookie(c.Writer, cookie)
c.JSON(200, gin.H{
"status": 200,
"message": "Login successful",
+99 -65
View File
@@ -3,26 +3,29 @@ package controller_test
import (
"encoding/json"
"net/http/httptest"
"path"
"strings"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/pquerna/otp/totp"
"github.com/tinyauthapp/tinyauth/internal/config"
"github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
"github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/bootstrap"
"github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
)
func TestUserController(t *testing.T) {
tlog.NewTestLogger().Init()
tempDir := t.TempDir()
authServiceCfg := service.AuthServiceConfig{
Users: []config.User{
LocalUsers: &[]model.LocalUser{
{
Username: "testuser",
Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
@@ -30,12 +33,12 @@ func TestUserController(t *testing.T) {
{
Username: "totpuser",
Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
TotpSecret: "JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK",
TOTPSecret: "JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK",
},
{
Username: "attruser",
Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
Attributes: config.UserAttributes{
Attributes: model.UserAttributes{
Name: "Alice Smith",
Email: "alice@example.com",
},
@@ -43,8 +46,8 @@ func TestUserController(t *testing.T) {
{
Username: "attrtotpuser",
Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
TotpSecret: "JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK",
Attributes: config.UserAttributes{
TOTPSecret: "JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK",
Attributes: model.UserAttributes{
Name: "Bob Jones",
Email: "bob@example.com",
},
@@ -58,7 +61,54 @@ func TestUserController(t *testing.T) {
}
userControllerCfg := controller.UserControllerConfig{
CookieDomain: "example.com",
CookieDomain: "example.com",
SessionCookieName: "tinyauth-session",
}
totpCtx := func(c *gin.Context) {
c.Set("context", &model.UserContext{
Authenticated: false,
Provider: model.ProviderLocal,
Local: &model.LocalContext{
BaseContext: model.BaseContext{
Username: "totpuser",
Name: "Totpuser",
Email: "totpuser@example.com",
},
TOTPPending: true,
TOTPEnabled: true,
},
})
}
totpAttrCtx := func(c *gin.Context) {
c.Set("context", &model.UserContext{
Authenticated: false,
Provider: model.ProviderLocal,
Local: &model.LocalContext{
BaseContext: model.BaseContext{
Username: "attrtotpuser",
Name: "Bob Jones",
Email: "bob@example.com",
},
TOTPPending: true,
TOTPEnabled: true,
},
})
}
simpleCtx := func(c *gin.Context) {
c.Set("context", &model.UserContext{
Authenticated: true,
Provider: model.ProviderLocal,
Local: &model.LocalContext{
BaseContext: model.BaseContext{
Username: "testuser",
Name: "Test User",
Email: "testuser@example.com",
},
},
})
}
type testCase struct {
@@ -91,7 +141,7 @@ func TestUserController(t *testing.T) {
assert.Equal(t, "tinyauth-session", cookie.Name)
assert.True(t, cookie.HttpOnly)
assert.Equal(t, "example.com", cookie.Domain)
assert.Equal(t, 10, cookie.MaxAge)
assert.Equal(t, 9, cookie.MaxAge)
},
},
{
@@ -180,12 +230,14 @@ func TestUserController(t *testing.T) {
assert.Equal(t, "tinyauth-session", cookie.Name)
assert.True(t, cookie.HttpOnly)
assert.Equal(t, "example.com", cookie.Domain)
assert.Equal(t, 3600, cookie.MaxAge) // 1 hour, default for totp pending sessions
assert.Equal(t, 3599, cookie.MaxAge) // 1 hour, default for totp pending sessions
},
},
{
description: "Should be able to logout",
middlewares: []gin.HandlerFunc{},
middlewares: []gin.HandlerFunc{
simpleCtx,
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
// First login to get a session cookie
loginReq := controller.LoginRequest{
@@ -201,9 +253,10 @@ func TestUserController(t *testing.T) {
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
assert.Len(t, recorder.Result().Cookies(), 1)
cookies := recorder.Result().Cookies()
assert.Len(t, cookies, 1)
cookie := recorder.Result().Cookies()[0]
cookie := cookies[0]
assert.Equal(t, "tinyauth-session", cookie.Name)
// Now logout using the session cookie
@@ -214,17 +267,20 @@ func TestUserController(t *testing.T) {
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
assert.Len(t, recorder.Result().Cookies(), 1)
cookies = recorder.Result().Cookies()
assert.Len(t, cookies, 1)
logoutCookie := recorder.Result().Cookies()[0]
assert.Equal(t, "tinyauth-session", logoutCookie.Name)
assert.Equal(t, "", logoutCookie.Value)
assert.Equal(t, -1, logoutCookie.MaxAge) // MaxAge -1 means delete cookie
cookie = cookies[0]
assert.Equal(t, "tinyauth-session", cookie.Name)
assert.Equal(t, "", cookie.Value)
assert.Equal(t, -1, cookie.MaxAge) // MaxAge -1 means delete cookie
},
},
{
description: "Should be able to login with totp",
middlewares: []gin.HandlerFunc{},
middlewares: []gin.HandlerFunc{
totpCtx,
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
code, err := totp.GenerateCode("JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", time.Now())
assert.NoError(t, err)
@@ -250,12 +306,14 @@ func TestUserController(t *testing.T) {
assert.Equal(t, "tinyauth-session", totpCookie.Name)
assert.True(t, totpCookie.HttpOnly)
assert.Equal(t, "example.com", totpCookie.Domain)
assert.Equal(t, 10, totpCookie.MaxAge) // should use the regular session expiry time
assert.Equal(t, 9, totpCookie.MaxAge) // should use the regular session expiry time
},
},
{
description: "Totp should rate limit on multiple invalid attempts",
middlewares: []gin.HandlerFunc{},
middlewares: []gin.HandlerFunc{
totpCtx,
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
for range 3 {
totpReq := controller.TotpRequest{
@@ -325,7 +383,9 @@ func TestUserController(t *testing.T) {
},
{
description: "TOTP completion uses name and email from user attributes",
middlewares: []gin.HandlerFunc{},
middlewares: []gin.HandlerFunc{
totpAttrCtx,
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
code, err := totp.GenerateCode("JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", time.Now())
require.NoError(t, err)
@@ -346,12 +406,17 @@ func TestUserController(t *testing.T) {
},
}
oauthBrokerCfgs := make(map[string]config.OAuthServiceConfig)
oauthBrokerCfgs := make(map[string]model.OAuthServiceConfig)
store := memory.New()
app := bootstrap.NewBootstrapApp(model.Config{})
db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db"))
require.NoError(t, err)
queries := repository.New(db)
docker := service.NewDockerService()
err := docker.Init()
err = docker.Init()
require.NoError(t, err)
ldap := service.NewLdapService(service.LdapServiceConfig{})
@@ -362,7 +427,7 @@ func TestUserController(t *testing.T) {
err = broker.Init()
require.NoError(t, err)
authService := service.NewAuthService(authServiceCfg, ldap, store, broker)
authService := service.NewAuthService(authServiceCfg, ldap, queries, broker)
err = authService.Init()
require.NoError(t, err)
@@ -371,33 +436,6 @@ func TestUserController(t *testing.T) {
authService.ClearRateLimitsTestingOnly()
}
setTotpMiddlewareOverrides := map[string]config.UserContext{
"Should be able to login with totp": {
Username: "totpuser",
Name: "Totpuser",
Email: "totpuser@example.com",
Provider: "local",
TotpPending: true,
TotpEnabled: true,
},
"Totp should rate limit on multiple invalid attempts": {
Username: "totpuser",
Name: "Totpuser",
Email: "totpuser@example.com",
Provider: "local",
TotpPending: true,
TotpEnabled: true,
},
"TOTP completion uses name and email from user attributes": {
Username: "attrtotpuser",
Name: "Bob Jones",
Email: "bob@example.com",
Provider: "local",
TotpPending: true,
TotpEnabled: true,
},
}
for _, test := range tests {
beforeEach()
t.Run(test.description, func(t *testing.T) {
@@ -407,15 +445,6 @@ func TestUserController(t *testing.T) {
router.Use(middleware)
}
// Gin is stupid and doesn't allow setting a middleware after the groups
// so we need to do some stupid overrides here
if ctx, ok := setTotpMiddlewareOverrides[test.description]; ok {
ctx := ctx
router.Use(func(c *gin.Context) {
c.Set("context", &ctx)
})
}
group := router.Group("/api")
gin.SetMode(gin.TestMode)
@@ -427,4 +456,9 @@ func TestUserController(t *testing.T) {
test.run(t, router, recorder)
})
}
t.Cleanup(func() {
err = db.Close()
require.NoError(t, err)
})
}
@@ -8,13 +8,14 @@ import (
"testing"
"github.com/gin-gonic/gin"
"github.com/tinyauthapp/tinyauth/internal/config"
"github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
"github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/bootstrap"
"github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
)
func TestWellKnownController(t *testing.T) {
@@ -22,7 +23,7 @@ func TestWellKnownController(t *testing.T) {
tempDir := t.TempDir()
oidcServiceCfg := service.OIDCServiceConfig{
Clients: map[string]config.OIDCClientConfig{
Clients: map[string]model.OIDCClientConfig{
"test": {
ClientID: "some-client-id",
ClientSecret: "some-client-secret",
@@ -100,10 +101,15 @@ func TestWellKnownController(t *testing.T) {
},
}
store := memory.New()
app := bootstrap.NewBootstrapApp(model.Config{})
oidcService := service.NewOIDCService(oidcServiceCfg, store)
err := oidcService.Init()
db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db"))
require.NoError(t, err)
queries := repository.New(db)
oidcService := service.NewOIDCService(oidcServiceCfg, queries)
err = oidcService.Init()
require.NoError(t, err)
for _, test := range tests {
@@ -119,4 +125,9 @@ func TestWellKnownController(t *testing.T) {
test.run(t, router, recorder)
})
}
t.Cleanup(func() {
err = db.Close()
require.NoError(t, err)
})
}
+171 -176
View File
@@ -1,10 +1,13 @@
package middleware
import (
"context"
"fmt"
"net/http"
"strings"
"time"
"github.com/tinyauthapp/tinyauth/internal/config"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
@@ -33,7 +36,8 @@ var (
)
type ContextMiddlewareConfig struct {
CookieDomain string
CookieDomain string
SessionCookieName string
}
type ContextMiddleware struct {
@@ -61,194 +65,41 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
return
}
cookie, err := m.auth.GetSessionCookie(c)
uuid, err := c.Cookie(m.config.SessionCookieName)
if err != nil {
tlog.App.Debug().Err(err).Msg("No valid session cookie found")
goto basic
}
if err == nil {
userContext, cookie, err := m.cookieAuth(c.Request.Context(), uuid)
if cookie.TotpPending {
c.Set("context", &config.UserContext{
Username: cookie.Username,
Name: cookie.Name,
Email: cookie.Email,
Provider: "local",
TotpPending: true,
TotpEnabled: true,
})
c.Next()
return
}
switch cookie.Provider {
case "local", "ldap":
userSearch := m.auth.SearchUser(cookie.Username)
if userSearch.Type == "unknown" {
tlog.App.Debug().Msg("User from session cookie not found")
m.auth.DeleteSessionCookie(c)
goto basic
}
if userSearch.Type != cookie.Provider {
tlog.App.Warn().Msg("User type from session cookie does not match user search type")
m.auth.DeleteSessionCookie(c)
c.Next()
return
}
var ldapGroups []string
var localAttributes config.UserAttributes
if cookie.Provider == "ldap" {
ldapUser, err := m.auth.GetLdapUser(userSearch.Username)
if err != nil {
tlog.App.Error().Err(err).Msg("Error retrieving LDAP user details")
c.Next()
return
if err == nil {
if cookie != nil {
http.SetCookie(c.Writer, cookie)
}
ldapGroups = ldapUser.Groups
}
if cookie.Provider == "local" {
localUser := m.auth.GetLocalUser(cookie.Username)
localAttributes = localUser.Attributes
}
m.auth.RefreshSessionCookie(c)
c.Set("context", &config.UserContext{
Username: cookie.Username,
Name: cookie.Name,
Email: cookie.Email,
Provider: cookie.Provider,
IsLoggedIn: true,
LdapGroups: strings.Join(ldapGroups, ","),
Attributes: localAttributes,
})
c.Next()
return
default:
_, exists := m.broker.GetService(cookie.Provider)
if !exists {
tlog.App.Debug().Msg("OAuth provider from session cookie not found")
m.auth.DeleteSessionCookie(c)
goto basic
}
if !m.auth.IsEmailWhitelisted(cookie.Email) {
tlog.App.Debug().Msg("Email from session cookie not whitelisted")
m.auth.DeleteSessionCookie(c)
goto basic
}
m.auth.RefreshSessionCookie(c)
c.Set("context", &config.UserContext{
Username: cookie.Username,
Name: cookie.Name,
Email: cookie.Email,
Provider: cookie.Provider,
OAuthGroups: cookie.OAuthGroups,
OAuthName: cookie.OAuthName,
OAuthSub: cookie.OAuthSub,
IsLoggedIn: true,
OAuth: true,
})
c.Next()
return
}
basic:
basic := m.auth.GetBasicAuth(c)
if basic == nil {
tlog.App.Debug().Msg("No basic auth provided")
c.Next()
return
}
locked, remaining := m.auth.IsAccountLocked(basic.Username)
if locked {
tlog.App.Debug().Msgf("Account for user %s is locked for %d seconds, denying auth", basic.Username, remaining)
c.Writer.Header().Add("x-tinyauth-lock-locked", "true")
c.Writer.Header().Add("x-tinyauth-lock-reset", time.Now().Add(time.Duration(remaining)*time.Second).Format(time.RFC3339))
c.Next()
return
}
userSearch := m.auth.SearchUser(basic.Username)
if userSearch.Type == "unknown" || userSearch.Type == "error" {
m.auth.RecordLoginAttempt(basic.Username, false)
tlog.App.Debug().Msg("User from basic auth not found")
c.Next()
return
}
if !m.auth.VerifyUser(userSearch, basic.Password) {
m.auth.RecordLoginAttempt(basic.Username, false)
tlog.App.Debug().Msg("Invalid password for basic auth user")
c.Next()
return
}
m.auth.RecordLoginAttempt(basic.Username, true)
switch userSearch.Type {
case "local":
tlog.App.Debug().Msg("Basic auth user is local")
user := m.auth.GetLocalUser(basic.Username)
if user.TotpSecret != "" {
tlog.App.Debug().Msg("User with TOTP not allowed to login via basic auth")
tlog.App.Trace().Msgf("Authenticated user from session cookie: %s", userContext.GetUsername())
c.Set("context", userContext)
c.Next()
return
} else {
tlog.App.Error().Msgf("Error authenticating session cookie: %v", err)
}
}
name := utils.Capitalize(user.Username)
if user.Attributes.Name != "" {
name = user.Attributes.Name
}
email := utils.CompileUserEmail(user.Username, m.config.CookieDomain)
if user.Attributes.Email != "" {
email = user.Attributes.Email
}
username, password, ok := c.Request.BasicAuth()
c.Set("context", &config.UserContext{
Username: user.Username,
Name: name,
Email: email,
Provider: "local",
IsLoggedIn: true,
IsBasicAuth: true,
Attributes: user.Attributes,
})
c.Next()
return
case "ldap":
tlog.App.Debug().Msg("Basic auth user is LDAP")
ldapUser, err := m.auth.GetLdapUser(basic.Username)
if ok {
userContext, headers, err := m.basicAuth(username, password)
if err != nil {
tlog.App.Debug().Err(err).Msg("Error retrieving LDAP user details")
tlog.App.Error().Msgf("Error authenticating basic auth: %v", err)
c.Next()
return
}
c.Set("context", &config.UserContext{
Username: basic.Username,
Name: utils.Capitalize(basic.Username),
Email: utils.CompileUserEmail(basic.Username, m.config.CookieDomain),
Provider: "ldap",
IsLoggedIn: true,
LdapGroups: strings.Join(ldapUser.Groups, ","),
IsBasicAuth: true,
})
for k, v := range headers {
c.Header(k, v)
}
c.Set("context", userContext)
c.Next()
return
}
@@ -257,6 +108,150 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
}
}
func (m *ContextMiddleware) cookieAuth(ctx context.Context, uuid string) (*model.UserContext, *http.Cookie, error) {
session, err := m.auth.GetSession(ctx, uuid)
if err != nil {
return nil, nil, fmt.Errorf("error retrieving session: %w", err)
}
userContext, err := new(model.UserContext).NewFromSession(session)
if err != nil {
return nil, nil, fmt.Errorf("error creating user context from session: %w", err)
}
if userContext.Provider == model.ProviderLocal &&
userContext.Local.TOTPPending {
userContext.Local.TOTPEnabled = true
return userContext, nil, nil
}
switch userContext.Provider {
case model.ProviderLocal:
user := m.auth.GetLocalUser(userContext.Local.Username)
if user == nil {
return nil, nil, fmt.Errorf("local user not found")
}
userContext.Local.Attributes = user.Attributes
if userContext.Local.Attributes.Name == "" {
userContext.Local.Attributes.Name = utils.Capitalize(user.Username)
}
if userContext.Local.Attributes.Email == "" {
userContext.Local.Attributes.Email = utils.CompileUserEmail(user.Username, m.config.CookieDomain)
}
case model.ProviderLDAP:
search, err := m.auth.SearchUser(userContext.LDAP.Username)
if err != nil {
return nil, nil, fmt.Errorf("error searching for ldap user: %w", err)
}
if search.Type != model.UserLDAP {
return nil, nil, fmt.Errorf("user from session cookie is not ldap")
}
user, err := m.auth.GetLDAPUser(search.Username)
if err != nil {
return nil, nil, fmt.Errorf("error retrieving ldap user details: %w", err)
}
userContext.LDAP.Groups = user.Groups
userContext.LDAP.Name = utils.Capitalize(userContext.LDAP.Username)
userContext.LDAP.Email = utils.CompileUserEmail(userContext.LDAP.Username, m.config.CookieDomain)
case model.ProviderOAuth:
_, exists := m.broker.GetService(userContext.OAuth.ID)
if !exists {
return nil, nil, fmt.Errorf("oauth provider from session cookie not found: %s", userContext.OAuth.ID)
}
if !m.auth.IsEmailWhitelisted(userContext.OAuth.Email) {
m.auth.DeleteSession(ctx, uuid)
return nil, nil, fmt.Errorf("email from session cookie not whitelisted: %s", userContext.OAuth.Email)
}
}
cookie, err := m.auth.RefreshSession(ctx, uuid)
if err != nil {
return nil, nil, fmt.Errorf("error refreshing session: %w", err)
}
return userContext, cookie, nil
}
func (m *ContextMiddleware) basicAuth(username string, password string) (*model.UserContext, map[string]string, error) {
headers := make(map[string]string)
userContext := new(model.UserContext)
locked, remaining := m.auth.IsAccountLocked(username)
if locked {
tlog.App.Debug().Msgf("Account for user %s is locked for %d seconds, denying auth", username, remaining)
headers["x-tinyauth-lock-locked"] = "true"
headers["x-tinyauth-lock-reset"] = time.Now().Add(time.Duration(remaining) * time.Second).Format(time.RFC3339)
return nil, headers, nil
}
search, err := m.auth.SearchUser(username)
if err != nil {
return nil, nil, fmt.Errorf("error searching for user: %w", err)
}
err = m.auth.CheckUserPassword(*search, password)
if err != nil {
m.auth.RecordLoginAttempt(username, false)
return nil, nil, fmt.Errorf("invalid password for basic auth user: %w", err)
}
m.auth.RecordLoginAttempt(username, true)
switch search.Type {
case model.UserLocal:
user := m.auth.GetLocalUser(username)
if user.TOTPSecret != "" {
return nil, nil, fmt.Errorf("user with totp not allowed to login via basic auth: %s", username)
}
userContext.Local = &model.LocalContext{
BaseContext: model.BaseContext{
Username: user.Username,
Name: utils.Capitalize(user.Username),
Email: utils.CompileUserEmail(user.Username, m.config.CookieDomain),
},
Attributes: user.Attributes,
}
userContext.Provider = model.ProviderLocal
case model.UserLDAP:
user, err := m.auth.GetLDAPUser(username)
if err != nil {
return nil, nil, fmt.Errorf("error retrieving ldap user details: %w", err)
}
userContext.LDAP = &model.LDAPContext{
BaseContext: model.BaseContext{
Username: username,
Name: utils.Capitalize(username),
Email: utils.CompileUserEmail(username, m.config.CookieDomain),
},
Groups: user.Groups,
}
userContext.Provider = model.ProviderLDAP
}
userContext.Authenticated = true
return userContext, nil, nil
}
func (m *ContextMiddleware) isIgnorePath(path string) bool {
for _, prefix := range contextSkipPathsPrefix {
if strings.HasPrefix(path, prefix) {
@@ -0,0 +1,330 @@
package middleware_test
import (
"context"
"encoding/base64"
"net/http"
"net/http/httptest"
"path"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/bootstrap"
"github.com/tinyauthapp/tinyauth/internal/middleware"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
)
func TestContextMiddleware(t *testing.T) {
tlog.NewTestLogger().Init()
tempDir := t.TempDir()
authServiceCfg := service.AuthServiceConfig{
LocalUsers: &[]model.LocalUser{
{
Username: "testuser",
Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
},
{
Username: "totpuser",
Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
TOTPSecret: "JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK",
},
},
SessionExpiry: 10, // 10 seconds, useful for testing
CookieDomain: "example.com",
LoginTimeout: 10, // 10 seconds, useful for testing
LoginMaxRetries: 3,
SessionCookieName: "tinyauth-session",
}
middlewareCfg := middleware.ContextMiddlewareConfig{
CookieDomain: "example.com",
SessionCookieName: "tinyauth-session",
}
basicAuthHeader := func(username, password string) string {
return "Basic " + base64.StdEncoding.EncodeToString([]byte(username+":"+password))
}
seedSession := func(t *testing.T, queries *repository.Queries, params repository.CreateSessionParams) {
t.Helper()
_, err := queries.CreateSession(context.Background(), params)
require.NoError(t, err)
}
type runArgs struct {
do func(req *http.Request) (*model.UserContext, *httptest.ResponseRecorder)
queries *repository.Queries
}
type testCase struct {
description string
run func(t *testing.T, args runArgs)
}
tests := []testCase{
{
description: "Skip path bypasses auth processing",
run: func(t *testing.T, args runArgs) {
req := httptest.NewRequest("GET", "/api/healthz", nil)
req.Header.Set("Authorization", basicAuthHeader("testuser", "password"))
userCtx, _ := args.do(req)
assert.Nil(t, userCtx)
},
},
{
description: "No credentials yields no context",
run: func(t *testing.T, args runArgs) {
req := httptest.NewRequest("GET", "/api/test", nil)
userCtx, _ := args.do(req)
assert.Nil(t, userCtx)
},
},
{
description: "Valid session cookie sets authenticated local context",
run: func(t *testing.T, args runArgs) {
uuid := "session-valid-local"
seedSession(t, args.queries, repository.CreateSessionParams{
UUID: uuid,
Username: "testuser",
Provider: "local",
Expiry: time.Now().Add(10 * time.Second).Unix(),
CreatedAt: time.Now().Unix(),
})
req := httptest.NewRequest("GET", "/api/test", nil)
req.AddCookie(&http.Cookie{Name: "tinyauth-session", Value: uuid})
userCtx, _ := args.do(req)
require.NotNil(t, userCtx)
assert.Equal(t, model.ProviderLocal, userCtx.Provider)
assert.Equal(t, "testuser", userCtx.GetUsername())
assert.True(t, userCtx.Authenticated)
require.NotNil(t, userCtx.Local)
assert.False(t, userCtx.Local.TOTPEnabled)
},
},
{
description: "Session cookie with totp pending sets unauthenticated context with totp enabled",
run: func(t *testing.T, args runArgs) {
uuid := "session-totp-pending"
seedSession(t, args.queries, repository.CreateSessionParams{
UUID: uuid,
Username: "totpuser",
Provider: "local",
TotpPending: true,
Expiry: time.Now().Add(60 * time.Second).Unix(),
CreatedAt: time.Now().Unix(),
})
req := httptest.NewRequest("GET", "/api/test", nil)
req.AddCookie(&http.Cookie{Name: "tinyauth-session", Value: uuid})
userCtx, _ := args.do(req)
require.NotNil(t, userCtx)
assert.Equal(t, "totpuser", userCtx.GetUsername())
assert.False(t, userCtx.Authenticated)
require.NotNil(t, userCtx.Local)
assert.True(t, userCtx.Local.TOTPPending)
assert.True(t, userCtx.Local.TOTPEnabled)
},
},
{
description: "Unknown session cookie yields no context",
run: func(t *testing.T, args runArgs) {
req := httptest.NewRequest("GET", "/api/test", nil)
req.AddCookie(&http.Cookie{Name: "tinyauth-session", Value: "does-not-exist"})
userCtx, _ := args.do(req)
assert.Nil(t, userCtx)
},
},
{
description: "Session for missing local user yields no context",
run: func(t *testing.T, args runArgs) {
uuid := "session-deleted-user"
seedSession(t, args.queries, repository.CreateSessionParams{
UUID: uuid,
Username: "ghostuser",
Provider: "local",
Expiry: time.Now().Add(10 * time.Second).Unix(),
CreatedAt: time.Now().Unix(),
})
req := httptest.NewRequest("GET", "/api/test", nil)
req.AddCookie(&http.Cookie{Name: "tinyauth-session", Value: uuid})
userCtx, _ := args.do(req)
assert.Nil(t, userCtx)
},
},
{
description: "Expired session cookie yields no context",
run: func(t *testing.T, args runArgs) {
uuid := "session-expired"
seedSession(t, args.queries, repository.CreateSessionParams{
UUID: uuid,
Username: "testuser",
Provider: "local",
Expiry: time.Now().Add(-1 * time.Second).Unix(),
CreatedAt: time.Now().Add(-10 * time.Second).Unix(),
})
req := httptest.NewRequest("GET", "/api/test", nil)
req.AddCookie(&http.Cookie{Name: "tinyauth-session", Value: uuid})
userCtx, _ := args.do(req)
assert.Nil(t, userCtx)
},
},
{
description: "Valid basic auth sets authenticated local context",
run: func(t *testing.T, args runArgs) {
req := httptest.NewRequest("GET", "/api/test", nil)
req.Header.Set("Authorization", basicAuthHeader("testuser", "password"))
userCtx, _ := args.do(req)
require.NotNil(t, userCtx)
assert.Equal(t, model.ProviderLocal, userCtx.Provider)
assert.Equal(t, "testuser", userCtx.GetUsername())
assert.True(t, userCtx.Authenticated)
},
},
{
description: "Invalid basic auth password yields no context",
run: func(t *testing.T, args runArgs) {
req := httptest.NewRequest("GET", "/api/test", nil)
req.Header.Set("Authorization", basicAuthHeader("testuser", "wrongpassword"))
userCtx, _ := args.do(req)
assert.Nil(t, userCtx)
},
},
{
description: "Basic auth is rejected for users with totp",
run: func(t *testing.T, args runArgs) {
req := httptest.NewRequest("GET", "/api/test", nil)
req.Header.Set("Authorization", basicAuthHeader("totpuser", "password"))
userCtx, _ := args.do(req)
assert.Nil(t, userCtx)
},
},
{
description: "Locked account on basic auth sets lock headers",
run: func(t *testing.T, args runArgs) {
for range 3 {
req := httptest.NewRequest("GET", "/api/test", nil)
req.Header.Set("Authorization", basicAuthHeader("testuser", "wrongpassword"))
args.do(req)
}
req := httptest.NewRequest("GET", "/api/test", nil)
req.Header.Set("Authorization", basicAuthHeader("testuser", "password"))
userCtx, recorder := args.do(req)
assert.Nil(t, userCtx)
assert.Equal(t, "true", recorder.Header().Get("x-tinyauth-lock-locked"))
assert.NotEmpty(t, recorder.Header().Get("x-tinyauth-lock-reset"))
},
},
{
description: "Cookie auth takes precedence over basic auth",
run: func(t *testing.T, args runArgs) {
uuid := "session-precedence"
seedSession(t, args.queries, repository.CreateSessionParams{
UUID: uuid,
Username: "testuser",
Provider: "local",
Expiry: time.Now().Add(10 * time.Second).Unix(),
CreatedAt: time.Now().Unix(),
})
req := httptest.NewRequest("GET", "/api/test", nil)
req.AddCookie(&http.Cookie{Name: "tinyauth-session", Value: uuid})
req.Header.Set("Authorization", basicAuthHeader("totpuser", "password"))
userCtx, _ := args.do(req)
require.NotNil(t, userCtx)
assert.Equal(t, "testuser", userCtx.GetUsername())
assert.True(t, userCtx.Authenticated)
},
},
{
description: "Ensure fallback to basic auth when cookie is missing",
run: func(t *testing.T, args runArgs) {
req := httptest.NewRequest("GET", "/api/test", nil)
req.Header.Set("Authorization", basicAuthHeader("testuser", "password"))
userCtx, _ := args.do(req)
require.NotNil(t, userCtx)
assert.Equal(t, "testuser", userCtx.GetUsername())
assert.True(t, userCtx.Authenticated)
},
},
}
oauthBrokerCfgs := make(map[string]model.OAuthServiceConfig)
app := bootstrap.NewBootstrapApp(model.Config{})
db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db"))
require.NoError(t, err)
queries := repository.New(db)
ldap := service.NewLdapService(service.LdapServiceConfig{})
err = ldap.Init()
require.NoError(t, err)
broker := service.NewOAuthBrokerService(oauthBrokerCfgs)
err = broker.Init()
require.NoError(t, err)
authService := service.NewAuthService(authServiceCfg, ldap, queries, broker)
err = authService.Init()
require.NoError(t, err)
contextMiddleware := middleware.NewContextMiddleware(middlewareCfg, authService, broker)
err = contextMiddleware.Init()
require.NoError(t, err)
for _, test := range tests {
authService.ClearRateLimitsTestingOnly()
t.Run(test.description, func(t *testing.T) {
gin.SetMode(gin.TestMode)
do := func(req *http.Request) (*model.UserContext, *httptest.ResponseRecorder) {
var captured *model.UserContext
router := gin.New()
router.Use(contextMiddleware.Middleware())
handler := func(c *gin.Context) {
if val, exists := c.Get("context"); exists {
captured, _ = val.(*model.UserContext)
}
}
router.GET("/api/test", handler)
router.GET("/api/healthz", handler)
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, req)
return captured, recorder
}
test.run(t, runArgs{do: do, queries: queries})
})
}
t.Cleanup(func() {
err = db.Close()
require.NoError(t, err)
})
}
@@ -1,11 +1,10 @@
package config
package model
// Default configuration
func NewDefaultConfiguration() *Config {
return &Config{
Database: DatabaseConfig{
Driver: "sqlite",
Path: "./tinyauth.db",
Path: "./tinyauth.db",
},
Analytics: AnalyticsConfig{
Enabled: true,
@@ -30,7 +29,7 @@ func NewDefaultConfiguration() *Config {
BackgroundImage: "/background.jpg",
WarningsEnabled: true,
},
Ldap: LdapConfig{
LDAP: LDAPConfig{
Insecure: false,
SearchFilter: "(uid=%s)",
GroupCacheTTL: 900, // 15 minutes
@@ -64,20 +63,6 @@ func NewDefaultConfiguration() *Config {
}
}
// Version information, set at build time
var Version = "development"
var CommitHash = "development"
var BuildTimestamp = "0000-00-00T00:00:00Z"
// Cookie name templates
var SessionCookieName = "tinyauth-session"
var CSRFCookieName = "tinyauth-csrf"
var RedirectCookieName = "tinyauth-redirect"
var OAuthSessionCookieName = "tinyauth-oauth"
// Main app config
type Config struct {
AppURL string `description:"The base URL where the app is hosted." yaml:"appUrl"`
Database DatabaseConfig `description:"Database configuration." yaml:"database"`
@@ -89,15 +74,14 @@ type Config struct {
OAuth OAuthConfig `description:"OAuth configuration." yaml:"oauth"`
OIDC OIDCConfig `description:"OIDC configuration." yaml:"oidc"`
UI UIConfig `description:"UI customization." yaml:"ui"`
Ldap LdapConfig `description:"LDAP configuration." yaml:"ldap"`
LDAP LDAPConfig `description:"LDAP configuration." yaml:"ldap"`
Experimental ExperimentalConfig `description:"Experimental features, use with caution." yaml:"experimental"`
LabelProvider string `description:"Label provider to use for ACLs (auto, docker, or kubernetes). auto detects the environment." yaml:"labelProvider"`
Log LogConfig `description:"Logging configuration." yaml:"log"`
}
type DatabaseConfig struct {
Driver string `description:"The database driver to use. Valid values: sqlite, memory." yaml:"driver"`
Path string `description:"The path to the SQLite database, including file name. Only used when driver is sqlite." yaml:"path"`
Path string `description:"The path to the database, including file name." yaml:"path"`
}
type AnalyticsConfig struct {
@@ -179,7 +163,7 @@ type UIConfig struct {
WarningsEnabled bool `description:"Enable UI warnings." yaml:"warningsEnabled"`
}
type LdapConfig struct {
type LDAPConfig struct {
Address string `description:"LDAP server address." yaml:"address"`
BindDN string `description:"Bind DN for LDAP authentication." yaml:"bindDn"`
BindPassword string `description:"Bind password for LDAP authentication." yaml:"bindPassword"`
@@ -212,20 +196,6 @@ type ExperimentalConfig struct {
ConfigFile string `description:"Path to config file." yaml:"-"`
}
// Config loader options
const DefaultNamePrefix = "TINYAUTH_"
// OAuth/OIDC config
type Claims struct {
Sub string `json:"sub"`
Name string `json:"name"`
Email string `json:"email"`
PreferredUsername string `json:"preferred_username"`
Groups any `json:"groups"`
}
type OAuthServiceConfig struct {
ClientID string `description:"OAuth client ID." yaml:"clientId"`
ClientSecret string `description:"OAuth client secret." yaml:"clientSecret"`
@@ -248,60 +218,6 @@ type OIDCClientConfig struct {
Name string `description:"Client name in UI." yaml:"name"`
}
var OverrideProviders = map[string]string{
"google": "Google",
"github": "GitHub",
}
// User/session related stuff
type User struct {
Username string
Password string
TotpSecret string
Attributes UserAttributes
}
type LdapUser struct {
DN string
Groups []string
}
type UserSearch struct {
Username string
Type string // local, ldap or unknown
}
type UserContext struct {
Username string
Name string
Email string
IsLoggedIn bool
IsBasicAuth bool
OAuth bool
Provider string
TotpPending bool
OAuthGroups string
TotpEnabled bool
OAuthName string
OAuthSub string
LdapGroups string
Attributes UserAttributes
}
// API responses and queries
type UnauthorizedQuery struct {
Username string `url:"username"`
Resource string `url:"resource"`
GroupErr bool `url:"groupErr"`
IP string `url:"ip"`
}
type RedirectQuery struct {
RedirectURI string `url:"redirect_uri"`
}
// ACLs
type Apps struct {
@@ -357,7 +273,3 @@ type AppPath struct {
Allow string `description:"Comma-separated list of allowed paths." yaml:"allow"`
Block string `description:"Comma-separated list of blocked paths." yaml:"block"`
}
// API server
var ApiServer = "https://api.tinyauth.app"
+23
View File
@@ -0,0 +1,23 @@
package model
const DefaultNamePrefix = "TINYAUTH_"
const APIServer = "https://api.tinyauth.app"
type Claims struct {
Sub string `json:"sub"`
Name string `json:"name"`
Email string `json:"email"`
PreferredUsername string `json:"preferred_username"`
Groups any `json:"groups"`
}
var OverrideProviders = map[string]string{
"google": "Google",
"github": "GitHub",
}
const SessionCookieName = "tinyauth-session"
const CSRFCookieName = "tinyauth-csrf"
const RedirectCookieName = "tinyauth-redirect"
const OAuthSessionCookieName = "tinyauth-oauth"
+251
View File
@@ -0,0 +1,251 @@
package model
import (
"errors"
"strings"
"github.com/gin-gonic/gin"
"github.com/tinyauthapp/tinyauth/internal/repository"
)
type ProviderType int
const (
ProviderLocal ProviderType = iota
ProviderBasicAuth
ProviderOAuth
ProviderLDAP
)
type UserContext struct {
Authenticated bool
Provider ProviderType
Local *LocalContext
OAuth *OAuthContext
LDAP *LDAPContext
}
type BaseContext struct {
Username string
Name string
Email string
}
type LocalContext struct {
BaseContext
TOTPPending bool
TOTPEnabled bool
Attributes UserAttributes
}
type OAuthContext struct {
BaseContext
Groups []string
Sub string
DisplayName string
ID string
}
type LDAPContext struct {
BaseContext
Groups []string
}
func (c *UserContext) IsAuthenticated() bool {
return c.Authenticated
}
func (c *UserContext) IsLocal() bool {
return c.Provider == ProviderLocal && c.Local != nil
}
func (c *UserContext) IsOAuth() bool {
return c.Provider == ProviderOAuth && c.OAuth != nil
}
func (c *UserContext) IsLDAP() bool {
return c.Provider == ProviderLDAP && c.LDAP != nil
}
func (c *UserContext) IsBasicAuth() bool {
return c.Provider == ProviderBasicAuth && c.Local != nil
}
func (c *UserContext) NewFromGin(ginctx *gin.Context) (*UserContext, error) {
userContextValue, exists := ginctx.Get("context")
if !exists {
return nil, errors.New("failed to get user context")
}
userContext, ok := userContextValue.(*UserContext)
if !ok || userContext == nil {
return nil, errors.New("invalid user context type")
}
if userContext.LDAP == nil && userContext.Local == nil && userContext.OAuth == nil {
return nil, errors.New("incomplete user context")
}
*c = *userContext
return c, nil
}
// Compatability layer until we get an excuse to drop in database migrations
func (c *UserContext) NewFromSession(session *repository.Session) (*UserContext, error) {
*c = UserContext{
Authenticated: !session.TotpPending,
}
switch session.Provider {
case "local":
c.Provider = ProviderLocal
c.Local = &LocalContext{
BaseContext: BaseContext{
Username: session.Username,
Name: session.Name,
Email: session.Email,
},
TOTPPending: session.TotpPending,
}
case "ldap":
c.Provider = ProviderLDAP
c.LDAP = &LDAPContext{
BaseContext: BaseContext{
Username: session.Username,
Name: session.Name,
Email: session.Email,
},
}
// By default we assume an unkown name which is oauth
default:
c.Provider = ProviderOAuth
c.OAuth = &OAuthContext{
BaseContext: BaseContext{
Username: session.Username,
Name: session.Name,
Email: session.Email,
},
Groups: func() []string {
if session.OAuthGroups == "" {
return nil
}
return strings.Split(session.OAuthGroups, ",")
}(),
Sub: session.OAuthSub,
DisplayName: session.OAuthName,
ID: session.Provider,
}
}
return c, nil
}
func (c *UserContext) GetUsername() string {
switch c.Provider {
case ProviderLocal:
if c.Local == nil {
return ""
}
return c.Local.Username
case ProviderLDAP:
if c.LDAP == nil {
return ""
}
return c.LDAP.Username
case ProviderBasicAuth:
if c.Local == nil {
return ""
}
return c.Local.Username
case ProviderOAuth:
if c.OAuth == nil {
return ""
}
return c.OAuth.Username
default:
return ""
}
}
func (c *UserContext) GetEmail() string {
switch c.Provider {
case ProviderLocal:
if c.Local == nil {
return ""
}
return c.Local.Email
case ProviderLDAP:
if c.LDAP == nil {
return ""
}
return c.LDAP.Email
case ProviderBasicAuth:
if c.Local == nil {
return ""
}
return c.Local.Email
case ProviderOAuth:
if c.OAuth == nil {
return ""
}
return c.OAuth.Email
default:
return ""
}
}
func (c *UserContext) GetName() string {
switch c.Provider {
case ProviderLocal:
if c.Local == nil {
return ""
}
return c.Local.Name
case ProviderLDAP:
if c.LDAP == nil {
return ""
}
return c.LDAP.Name
case ProviderBasicAuth:
if c.Local == nil {
return ""
}
return c.Local.Name
case ProviderOAuth:
if c.OAuth == nil {
return ""
}
return c.OAuth.Name
default:
return ""
}
}
func (c *UserContext) ProviderName() string {
switch c.Provider {
case ProviderBasicAuth, ProviderLocal:
return "local"
case ProviderLDAP:
return "ldap"
case ProviderOAuth:
return c.OAuth.DisplayName // compatability
default:
return "unknown"
}
}
func (c *UserContext) TOTPPending() bool {
if c.Provider == ProviderLocal && c.Local != nil {
return c.Local.TOTPPending
}
return false
}
func (c *UserContext) OAuthName() string {
if c.Provider == ProviderOAuth && c.OAuth != nil {
return c.OAuth.DisplayName
}
return ""
}
+276
View File
@@ -0,0 +1,276 @@
package model_test
import (
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository"
)
func TestContext(t *testing.T) {
newGinCtx := func(value any, set bool) *gin.Context {
c, _ := gin.CreateTestContext(httptest.NewRecorder())
if set {
c.Set("context", value)
}
return c
}
tests := []struct {
description string
context *model.UserContext
run func(*testing.T, *model.UserContext) any
expected any
}{
{
description: "IsAuthenticated reflects Authenticated field",
context: &model.UserContext{Authenticated: true},
run: func(t *testing.T, c *model.UserContext) any { return c.IsAuthenticated() },
expected: true,
},
{
description: "IsLocal returns true for ProviderLocal",
context: &model.UserContext{Provider: model.ProviderLocal, Local: &model.LocalContext{}},
run: func(t *testing.T, c *model.UserContext) any { return c.IsLocal() },
expected: true,
},
{
description: "IsOAuth returns true for ProviderOAuth",
context: &model.UserContext{Provider: model.ProviderOAuth, OAuth: &model.OAuthContext{}},
run: func(t *testing.T, c *model.UserContext) any { return c.IsOAuth() },
expected: true,
},
{
description: "IsLDAP returns true for ProviderLDAP",
context: &model.UserContext{Provider: model.ProviderLDAP, LDAP: &model.LDAPContext{}},
run: func(t *testing.T, c *model.UserContext) any { return c.IsLDAP() },
expected: true,
},
{
description: "IsBasicAuth returns true for ProviderBasicAuth",
context: &model.UserContext{Provider: model.ProviderBasicAuth, Local: &model.LocalContext{}},
run: func(t *testing.T, c *model.UserContext) any { return c.IsBasicAuth() },
expected: true,
},
{
description: "NewFromSession local session is authenticated and ProviderLocal",
context: &model.UserContext{},
run: func(t *testing.T, c *model.UserContext) any {
got, err := c.NewFromSession(&repository.Session{
Username: "alice", Email: "alice@example.com", Name: "Alice",
Provider: "local",
})
require.NoError(t, err)
return [2]any{got.Provider, got.Authenticated}
},
expected: [2]any{model.ProviderLocal, true},
},
{
description: "NewFromSession local session with TotpPending is not authenticated",
context: &model.UserContext{},
run: func(t *testing.T, c *model.UserContext) any {
got, err := c.NewFromSession(&repository.Session{
Username: "bob", Provider: "local", TotpPending: true,
})
require.NoError(t, err)
return got.Authenticated
},
expected: false,
},
{
description: "NewFromSession ldap session is ProviderLDAP",
context: &model.UserContext{},
run: func(t *testing.T, c *model.UserContext) any {
got, err := c.NewFromSession(&repository.Session{
Username: "carol", Provider: "ldap",
})
require.NoError(t, err)
return got.Provider
},
expected: model.ProviderLDAP,
},
{
description: "NewFromSession unknown provider defaults to OAuth and populates oauth fields",
context: &model.UserContext{},
run: func(t *testing.T, c *model.UserContext) any {
got, err := c.NewFromSession(&repository.Session{
Username: "dave", Provider: "github",
OAuthGroups: "devs,admins", OAuthSub: "sub-123", OAuthName: "GitHub",
})
require.NoError(t, err)
return [5]any{got.Provider, got.OAuth.ID, got.OAuth.Sub, got.OAuth.DisplayName, got.OAuth.Groups}
},
expected: [5]any{model.ProviderOAuth, "github", "sub-123", "GitHub", []string{"devs", "admins"}},
},
{
description: "Local getters return BaseContext fields",
context: &model.UserContext{
Provider: model.ProviderLocal,
Local: &model.LocalContext{BaseContext: model.BaseContext{Username: "alice", Email: "alice@example.com", Name: "Alice"}},
},
run: func(t *testing.T, c *model.UserContext) any {
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
},
expected: [3]string{"alice", "alice@example.com", "Alice"},
},
{
description: "BasicAuth getters fall back to local fields",
context: &model.UserContext{
Provider: model.ProviderBasicAuth,
Local: &model.LocalContext{BaseContext: model.BaseContext{Username: "bob", Email: "bob@example.com", Name: "Bob"}},
},
run: func(t *testing.T, c *model.UserContext) any {
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
},
expected: [3]string{"bob", "bob@example.com", "Bob"},
},
{
description: "LDAP getters return LDAP fields",
context: &model.UserContext{
Provider: model.ProviderLDAP,
LDAP: &model.LDAPContext{BaseContext: model.BaseContext{Username: "carol", Email: "carol@example.com", Name: "Carol"}},
},
run: func(t *testing.T, c *model.UserContext) any {
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
},
expected: [3]string{"carol", "carol@example.com", "Carol"},
},
{
description: "OAuth getters return OAuth fields",
context: &model.UserContext{
Provider: model.ProviderOAuth,
OAuth: &model.OAuthContext{BaseContext: model.BaseContext{Username: "dave", Email: "dave@example.com", Name: "Dave"}},
},
run: func(t *testing.T, c *model.UserContext) any {
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
},
expected: [3]string{"dave", "dave@example.com", "Dave"},
},
{
description: "ProviderName returns 'local' for ProviderLocal",
context: &model.UserContext{Provider: model.ProviderLocal},
run: func(t *testing.T, c *model.UserContext) any { return c.ProviderName() },
expected: "local",
},
{
description: "ProviderName returns 'local' for ProviderBasicAuth",
context: &model.UserContext{Provider: model.ProviderBasicAuth},
run: func(t *testing.T, c *model.UserContext) any { return c.ProviderName() },
expected: "local",
},
{
description: "ProviderName returns 'ldap' for ProviderLDAP",
context: &model.UserContext{Provider: model.ProviderLDAP},
run: func(t *testing.T, c *model.UserContext) any { return c.ProviderName() },
expected: "ldap",
},
{
description: "ProviderName returns OAuth DisplayName for ProviderOAuth",
context: &model.UserContext{
Provider: model.ProviderOAuth,
OAuth: &model.OAuthContext{DisplayName: "GitHub"},
},
run: func(t *testing.T, c *model.UserContext) any { return c.ProviderName() },
expected: "GitHub",
},
{
description: "TOTPPending returns true when local context is pending",
context: &model.UserContext{
Provider: model.ProviderLocal,
Local: &model.LocalContext{TOTPPending: true},
},
run: func(t *testing.T, c *model.UserContext) any { return c.TOTPPending() },
expected: true,
},
{
description: "TOTPPending returns false when local context is not pending",
context: &model.UserContext{
Provider: model.ProviderLocal,
Local: &model.LocalContext{TOTPPending: false},
},
run: func(t *testing.T, c *model.UserContext) any { return c.TOTPPending() },
expected: false,
},
{
description: "TOTPPending returns false for non-local providers",
context: &model.UserContext{Provider: model.ProviderOAuth, OAuth: &model.OAuthContext{}},
run: func(t *testing.T, c *model.UserContext) any { return c.TOTPPending() },
expected: false,
},
{
description: "OAuthName returns DisplayName for ProviderOAuth",
context: &model.UserContext{
Provider: model.ProviderOAuth,
OAuth: &model.OAuthContext{DisplayName: "Google"},
},
run: func(t *testing.T, c *model.UserContext) any { return c.OAuthName() },
expected: "Google",
},
{
description: "OAuthName returns empty string for non-oauth providers",
context: &model.UserContext{Provider: model.ProviderLocal, Local: &model.LocalContext{}},
run: func(t *testing.T, c *model.UserContext) any { return c.OAuthName() },
expected: "",
},
{
description: "NewFromGin populates context from gin value",
context: &model.UserContext{},
run: func(t *testing.T, c *model.UserContext) any {
stored := &model.UserContext{
Authenticated: true,
Provider: model.ProviderLocal,
Local: &model.LocalContext{BaseContext: model.BaseContext{Username: "alice"}},
}
got, err := c.NewFromGin(newGinCtx(stored, true))
require.NoError(t, err)
return [2]any{got.Authenticated, got.GetUsername()}
},
expected: [2]any{true, "alice"},
},
{
description: "NewFromGin returns error when context value is missing",
context: &model.UserContext{},
run: func(t *testing.T, c *model.UserContext) any {
_, err := c.NewFromGin(newGinCtx(nil, false))
return err.Error()
},
expected: "failed to get user context",
},
{
description: "NewFromGin returns error when context value has wrong type",
context: &model.UserContext{},
run: func(t *testing.T, c *model.UserContext) any {
_, err := c.NewFromGin(newGinCtx("not a user context", true))
return err.Error()
},
expected: "invalid user context type",
},
{
description: "NewFromGin returns an error when context doesn't include user information",
context: &model.UserContext{},
run: func(t *testing.T, c *model.UserContext) any {
_, err := c.NewFromGin(newGinCtx(&model.UserContext{Provider: model.ProviderLocal}, true))
return err.Error()
},
expected: "incomplete user context",
},
{
description: "Getters should not panic if provider context is empty",
context: &model.UserContext{Provider: model.ProviderLocal},
run: func(t *testing.T, c *model.UserContext) any {
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
},
expected: [3]string{"", "", ""},
},
}
for _, test := range tests {
t.Run(test.description, func(t *testing.T) {
assert.Equal(t, test.expected, test.run(t, test.context))
})
}
}
+25
View File
@@ -0,0 +1,25 @@
package model
type UserSearchType int
const (
UserLocal UserSearchType = iota
UserLDAP
)
type LDAPUser struct {
DN string
Groups []string
}
type LocalUser struct {
Username string
Password string
TOTPSecret string
Attributes UserAttributes
}
type UserSearch struct {
Username string
Type UserSearchType
}
+5
View File
@@ -0,0 +1,5 @@
package model
var Version = "development"
var CommitHash = "development"
var BuildTimestamp = "0000-00-00T00:00:00Z"
@@ -1,8 +1,8 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.31.0
// sqlc v1.30.0
package sqlite
package repository
import (
"context"
-241
View File
@@ -1,241 +0,0 @@
package memory
import (
"context"
"fmt"
"github.com/tinyauthapp/tinyauth/internal/repository"
)
func (s *Store) CreateOidcCode(_ context.Context, arg repository.CreateOidcCodeParams) (repository.OidcCode, error) {
s.mu.Lock()
defer s.mu.Unlock()
// Enforce sub UNIQUE constraint
for _, c := range s.oidcCodes {
if c.Sub == arg.Sub {
return repository.OidcCode{}, fmt.Errorf("UNIQUE constraint failed: oidc_codes.sub")
}
}
code := repository.OidcCode(arg)
s.oidcCodes[arg.CodeHash] = code
return code, nil
}
// GetOidcCode is a destructive read: it deletes and returns the code (mirrors SQLite's DELETE...RETURNING).
func (s *Store) GetOidcCode(_ context.Context, codeHash string) (repository.OidcCode, error) {
s.mu.Lock()
defer s.mu.Unlock()
c, ok := s.oidcCodes[codeHash]
if !ok {
return repository.OidcCode{}, repository.ErrNotFound
}
delete(s.oidcCodes, codeHash)
return c, nil
}
// GetOidcCodeBySub is a destructive read: it deletes and returns the code (mirrors SQLite's DELETE...RETURNING).
func (s *Store) GetOidcCodeBySub(_ context.Context, sub string) (repository.OidcCode, error) {
s.mu.Lock()
defer s.mu.Unlock()
for k, c := range s.oidcCodes {
if c.Sub == sub {
delete(s.oidcCodes, k)
return c, nil
}
}
return repository.OidcCode{}, repository.ErrNotFound
}
// GetOidcCodeUnsafe is a non-destructive read (mirrors SQLite's SELECT).
func (s *Store) GetOidcCodeUnsafe(_ context.Context, codeHash string) (repository.OidcCode, error) {
s.mu.RLock()
defer s.mu.RUnlock()
c, ok := s.oidcCodes[codeHash]
if !ok {
return repository.OidcCode{}, repository.ErrNotFound
}
return c, nil
}
// GetOidcCodeBySubUnsafe is a non-destructive read (mirrors SQLite's SELECT).
func (s *Store) GetOidcCodeBySubUnsafe(_ context.Context, sub string) (repository.OidcCode, error) {
s.mu.RLock()
defer s.mu.RUnlock()
for _, c := range s.oidcCodes {
if c.Sub == sub {
return c, nil
}
}
return repository.OidcCode{}, repository.ErrNotFound
}
func (s *Store) DeleteOidcCode(_ context.Context, codeHash string) error {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.oidcCodes, codeHash)
return nil
}
func (s *Store) DeleteOidcCodeBySub(_ context.Context, sub string) error {
s.mu.Lock()
defer s.mu.Unlock()
for k, c := range s.oidcCodes {
if c.Sub == sub {
delete(s.oidcCodes, k)
}
}
return nil
}
func (s *Store) DeleteExpiredOidcCodes(_ context.Context, expiresAt int64) ([]repository.OidcCode, error) {
s.mu.Lock()
defer s.mu.Unlock()
var deleted []repository.OidcCode
for k, c := range s.oidcCodes {
if c.ExpiresAt < expiresAt {
deleted = append(deleted, c)
delete(s.oidcCodes, k)
}
}
return deleted, nil
}
func (s *Store) CreateOidcToken(_ context.Context, arg repository.CreateOidcTokenParams) (repository.OidcToken, error) {
s.mu.Lock()
defer s.mu.Unlock()
// Enforce sub UNIQUE constraint
for _, t := range s.oidcTokens {
if t.Sub == arg.Sub {
return repository.OidcToken{}, fmt.Errorf("UNIQUE constraint failed: oidc_tokens.sub")
}
}
tok := repository.OidcToken{
Sub: arg.Sub,
AccessTokenHash: arg.AccessTokenHash,
RefreshTokenHash: arg.RefreshTokenHash,
CodeHash: arg.CodeHash,
Scope: arg.Scope,
ClientID: arg.ClientID,
TokenExpiresAt: arg.TokenExpiresAt,
RefreshTokenExpiresAt: arg.RefreshTokenExpiresAt,
Nonce: arg.Nonce,
}
s.oidcTokens[arg.AccessTokenHash] = tok
return tok, nil
}
func (s *Store) GetOidcToken(_ context.Context, accessTokenHash string) (repository.OidcToken, error) {
s.mu.RLock()
defer s.mu.RUnlock()
t, ok := s.oidcTokens[accessTokenHash]
if !ok {
return repository.OidcToken{}, repository.ErrNotFound
}
return t, nil
}
func (s *Store) GetOidcTokenByRefreshToken(_ context.Context, refreshTokenHash string) (repository.OidcToken, error) {
s.mu.RLock()
defer s.mu.RUnlock()
for _, t := range s.oidcTokens {
if t.RefreshTokenHash == refreshTokenHash {
return t, nil
}
}
return repository.OidcToken{}, repository.ErrNotFound
}
func (s *Store) GetOidcTokenBySub(_ context.Context, sub string) (repository.OidcToken, error) {
s.mu.RLock()
defer s.mu.RUnlock()
for _, t := range s.oidcTokens {
if t.Sub == sub {
return t, nil
}
}
return repository.OidcToken{}, repository.ErrNotFound
}
func (s *Store) UpdateOidcTokenByRefreshToken(_ context.Context, arg repository.UpdateOidcTokenByRefreshTokenParams) (repository.OidcToken, error) {
s.mu.Lock()
defer s.mu.Unlock()
for k, t := range s.oidcTokens {
if t.RefreshTokenHash == arg.RefreshTokenHash_2 {
delete(s.oidcTokens, k)
t.AccessTokenHash = arg.AccessTokenHash
t.RefreshTokenHash = arg.RefreshTokenHash
t.TokenExpiresAt = arg.TokenExpiresAt
t.RefreshTokenExpiresAt = arg.RefreshTokenExpiresAt
s.oidcTokens[arg.AccessTokenHash] = t
return t, nil
}
}
return repository.OidcToken{}, repository.ErrNotFound
}
func (s *Store) DeleteOidcToken(_ context.Context, accessTokenHash string) error {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.oidcTokens, accessTokenHash)
return nil
}
func (s *Store) DeleteOidcTokenBySub(_ context.Context, sub string) error {
s.mu.Lock()
defer s.mu.Unlock()
for k, t := range s.oidcTokens {
if t.Sub == sub {
delete(s.oidcTokens, k)
}
}
return nil
}
func (s *Store) DeleteOidcTokenByCodeHash(_ context.Context, codeHash string) error {
s.mu.Lock()
defer s.mu.Unlock()
for k, t := range s.oidcTokens {
if t.CodeHash == codeHash {
delete(s.oidcTokens, k)
}
}
return nil
}
func (s *Store) DeleteExpiredOidcTokens(_ context.Context, arg repository.DeleteExpiredOidcTokensParams) ([]repository.OidcToken, error) {
s.mu.Lock()
defer s.mu.Unlock()
var deleted []repository.OidcToken
for k, t := range s.oidcTokens {
if t.TokenExpiresAt < arg.TokenExpiresAt || t.RefreshTokenExpiresAt < arg.RefreshTokenExpiresAt {
deleted = append(deleted, t)
delete(s.oidcTokens, k)
}
}
return deleted, nil
}
func (s *Store) CreateOidcUserInfo(_ context.Context, arg repository.CreateOidcUserInfoParams) (repository.OidcUserinfo, error) {
s.mu.Lock()
defer s.mu.Unlock()
u := repository.OidcUserinfo(arg)
s.oidcUsers[arg.Sub] = u
return u, nil
}
func (s *Store) GetOidcUserInfo(_ context.Context, sub string) (repository.OidcUserinfo, error) {
s.mu.RLock()
defer s.mu.RUnlock()
u, ok := s.oidcUsers[sub]
if !ok {
return repository.OidcUserinfo{}, repository.ErrNotFound
}
return u, nil
}
func (s *Store) DeleteOidcUserInfo(_ context.Context, sub string) error {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.oidcUsers, sub)
return nil
}
@@ -1,63 +0,0 @@
package memory
import (
"context"
"github.com/tinyauthapp/tinyauth/internal/repository"
)
func (s *Store) CreateSession(_ context.Context, arg repository.CreateSessionParams) (repository.Session, error) {
s.mu.Lock()
defer s.mu.Unlock()
sess := repository.Session(arg)
s.sessions[arg.UUID] = sess
return sess, nil
}
func (s *Store) GetSession(_ context.Context, uuid string) (repository.Session, error) {
s.mu.RLock()
defer s.mu.RUnlock()
sess, ok := s.sessions[uuid]
if !ok {
return repository.Session{}, repository.ErrNotFound
}
return sess, nil
}
func (s *Store) UpdateSession(_ context.Context, arg repository.UpdateSessionParams) (repository.Session, error) {
s.mu.Lock()
defer s.mu.Unlock()
sess, ok := s.sessions[arg.UUID]
if !ok {
return repository.Session{}, repository.ErrNotFound
}
sess.Username = arg.Username
sess.Email = arg.Email
sess.Name = arg.Name
sess.Provider = arg.Provider
sess.TotpPending = arg.TotpPending
sess.OAuthGroups = arg.OAuthGroups
sess.Expiry = arg.Expiry
sess.OAuthName = arg.OAuthName
sess.OAuthSub = arg.OAuthSub
s.sessions[arg.UUID] = sess
return sess, nil
}
func (s *Store) DeleteSession(_ context.Context, uuid string) error {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.sessions, uuid)
return nil
}
func (s *Store) DeleteExpiredSessions(_ context.Context, expiry int64) error {
s.mu.Lock()
defer s.mu.Unlock()
for k, v := range s.sessions {
if v.Expiry < expiry {
delete(s.sessions, k)
}
}
return nil
}
-27
View File
@@ -1,27 +0,0 @@
// Package memory provides an in-memory implementation of repository.Store for use in tests.
package memory
import (
"sync"
"github.com/tinyauthapp/tinyauth/internal/repository"
)
// Store is a thread-safe in-memory implementation of repository.Store.
type Store struct {
mu sync.RWMutex
sessions map[string]repository.Session
oidcCodes map[string]repository.OidcCode
oidcTokens map[string]repository.OidcToken
oidcUsers map[string]repository.OidcUserinfo
}
// New returns a new empty in-memory Store.
func New() repository.Store {
return &Store{
sessions: make(map[string]repository.Session),
oidcCodes: make(map[string]repository.OidcCode),
oidcTokens: make(map[string]repository.OidcToken),
oidcUsers: make(map[string]repository.OidcUserinfo),
}
}
+5 -89
View File
@@ -1,22 +1,9 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.30.0
package repository
// Shared model and parameter types for all storage drivers.
// sqlc-generated driver packages use these via the conversion layer in their store.go.
type Session struct {
UUID string
Username string
Email string
Name string
Provider string
TotpPending bool
OAuthGroups string
Expiry int64
CreatedAt int64
OAuthName string
OAuthSub string
}
type OidcCode struct {
Sub string
CodeHash string
@@ -62,7 +49,7 @@ type OidcUserinfo struct {
Address string
}
type CreateSessionParams struct {
type Session struct {
UUID string
Username string
Email string
@@ -75,74 +62,3 @@ type CreateSessionParams struct {
OAuthName string
OAuthSub string
}
type UpdateSessionParams struct {
Username string
Email string
Name string
Provider string
TotpPending bool
OAuthGroups string
Expiry int64
OAuthName string
OAuthSub string
UUID string
}
type CreateOidcCodeParams struct {
Sub string
CodeHash string
Scope string
RedirectURI string
ClientID string
ExpiresAt int64
Nonce string
CodeChallenge string
}
type CreateOidcTokenParams struct {
Sub string
AccessTokenHash string
RefreshTokenHash string
Scope string
ClientID string
TokenExpiresAt int64
RefreshTokenExpiresAt int64
CodeHash string
Nonce string
}
type UpdateOidcTokenByRefreshTokenParams struct {
AccessTokenHash string
RefreshTokenHash string
TokenExpiresAt int64
RefreshTokenExpiresAt int64
RefreshTokenHash_2 string
}
type DeleteExpiredOidcTokensParams struct {
TokenExpiresAt int64
RefreshTokenExpiresAt int64
}
type CreateOidcUserInfoParams struct {
Sub string
Name string
PreferredUsername string
Email string
Groups string
UpdatedAt int64
GivenName string
FamilyName string
MiddleName string
Nickname string
Profile string
Picture string
Website string
Gender string
Birthdate string
Zoneinfo string
Locale string
PhoneNumber string
Address string
}
@@ -1,9 +1,9 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.31.0
// sqlc v1.30.0
// source: oidc_queries.sql
package sqlite
package repository
import (
"context"
@@ -1,9 +1,9 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.31.0
// sqlc v1.30.0
// source: session_queries.sql
package sqlite
package repository
import (
"context"
-3
View File
@@ -1,3 +0,0 @@
package sqlite
//go:generate go run github.com/tinyauthapp/tinyauth/cmd/gen/sqlc-wrapper -pkg github.com/tinyauthapp/tinyauth/internal/repository/sqlite
-64
View File
@@ -1,64 +0,0 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.31.0
package sqlite
type OidcCode struct {
Sub string
CodeHash string
Scope string
RedirectURI string
ClientID string
ExpiresAt int64
Nonce string
CodeChallenge string
}
type OidcToken struct {
Sub string
AccessTokenHash string
RefreshTokenHash string
CodeHash string
Scope string
ClientID string
TokenExpiresAt int64
RefreshTokenExpiresAt int64
Nonce string
}
type OidcUserinfo struct {
Sub string
Name string
PreferredUsername string
Email string
Groups string
UpdatedAt int64
GivenName string
FamilyName string
MiddleName string
Nickname string
Profile string
Picture string
Website string
Gender string
Birthdate string
Zoneinfo string
Locale string
PhoneNumber string
Address string
}
type Session struct {
UUID string
Username string
Email string
Name string
Provider string
TotpPending bool
OAuthGroups string
Expiry int64
CreatedAt int64
OAuthName string
OAuthSub string
}
-224
View File
@@ -1,224 +0,0 @@
// Code generated by cmd/gen/sqlc-wrapper. DO NOT EDIT.
package sqlite
import (
"context"
"database/sql"
"errors"
"github.com/tinyauthapp/tinyauth/internal/repository"
)
// Store wraps *Queries and implements repository.Store.
type Store struct {
q *Queries
}
// NewStore wraps a *Queries to satisfy repository.Store.
func NewStore(q *Queries) repository.Store {
return &Store{q: q}
}
var errMap = []struct {
from error
to error
}{
{sql.ErrNoRows, repository.ErrNotFound},
}
func mapErr(err error) error {
for _, e := range errMap {
if errors.Is(err, e.from) {
return e.to
}
}
return err
}
func oidcCodeToRepo(v OidcCode) repository.OidcCode {
return repository.OidcCode(v)
}
func oidcTokenToRepo(v OidcToken) repository.OidcToken {
return repository.OidcToken(v)
}
func oidcUserinfoToRepo(v OidcUserinfo) repository.OidcUserinfo {
return repository.OidcUserinfo(v)
}
func sessionToRepo(v Session) repository.Session {
return repository.Session(v)
}
func (s *Store) CreateOidcCode(ctx context.Context, arg repository.CreateOidcCodeParams) (repository.OidcCode, error) {
r, err := s.q.CreateOidcCode(ctx, CreateOidcCodeParams(arg))
if err != nil {
return repository.OidcCode{}, mapErr(err)
}
return oidcCodeToRepo(r), nil
}
func (s *Store) CreateOidcToken(ctx context.Context, arg repository.CreateOidcTokenParams) (repository.OidcToken, error) {
r, err := s.q.CreateOidcToken(ctx, CreateOidcTokenParams(arg))
if err != nil {
return repository.OidcToken{}, mapErr(err)
}
return oidcTokenToRepo(r), nil
}
func (s *Store) CreateOidcUserInfo(ctx context.Context, arg repository.CreateOidcUserInfoParams) (repository.OidcUserinfo, error) {
r, err := s.q.CreateOidcUserInfo(ctx, CreateOidcUserInfoParams(arg))
if err != nil {
return repository.OidcUserinfo{}, mapErr(err)
}
return oidcUserinfoToRepo(r), nil
}
func (s *Store) CreateSession(ctx context.Context, arg repository.CreateSessionParams) (repository.Session, error) {
r, err := s.q.CreateSession(ctx, CreateSessionParams(arg))
if err != nil {
return repository.Session{}, mapErr(err)
}
return sessionToRepo(r), nil
}
func (s *Store) DeleteExpiredOidcCodes(ctx context.Context, expiresAt int64) ([]repository.OidcCode, error) {
rows, err := s.q.DeleteExpiredOidcCodes(ctx, expiresAt)
if err != nil {
return nil, mapErr(err)
}
out := make([]repository.OidcCode, len(rows))
for i, row := range rows {
out[i] = oidcCodeToRepo(row)
}
return out, nil
}
func (s *Store) DeleteExpiredOidcTokens(ctx context.Context, arg repository.DeleteExpiredOidcTokensParams) ([]repository.OidcToken, error) {
rows, err := s.q.DeleteExpiredOidcTokens(ctx, DeleteExpiredOidcTokensParams(arg))
if err != nil {
return nil, mapErr(err)
}
out := make([]repository.OidcToken, len(rows))
for i, row := range rows {
out[i] = oidcTokenToRepo(row)
}
return out, nil
}
func (s *Store) DeleteExpiredSessions(ctx context.Context, expiry int64) error {
return mapErr(s.q.DeleteExpiredSessions(ctx, expiry))
}
func (s *Store) DeleteOidcCode(ctx context.Context, codeHash string) error {
return mapErr(s.q.DeleteOidcCode(ctx, codeHash))
}
func (s *Store) DeleteOidcCodeBySub(ctx context.Context, sub string) error {
return mapErr(s.q.DeleteOidcCodeBySub(ctx, sub))
}
func (s *Store) DeleteOidcToken(ctx context.Context, accessTokenHash string) error {
return mapErr(s.q.DeleteOidcToken(ctx, accessTokenHash))
}
func (s *Store) DeleteOidcTokenByCodeHash(ctx context.Context, codeHash string) error {
return mapErr(s.q.DeleteOidcTokenByCodeHash(ctx, codeHash))
}
func (s *Store) DeleteOidcTokenBySub(ctx context.Context, sub string) error {
return mapErr(s.q.DeleteOidcTokenBySub(ctx, sub))
}
func (s *Store) DeleteOidcUserInfo(ctx context.Context, sub string) error {
return mapErr(s.q.DeleteOidcUserInfo(ctx, sub))
}
func (s *Store) DeleteSession(ctx context.Context, uuid string) error {
return mapErr(s.q.DeleteSession(ctx, uuid))
}
func (s *Store) GetOidcCode(ctx context.Context, codeHash string) (repository.OidcCode, error) {
r, err := s.q.GetOidcCode(ctx, codeHash)
if err != nil {
return repository.OidcCode{}, mapErr(err)
}
return oidcCodeToRepo(r), nil
}
func (s *Store) GetOidcCodeBySub(ctx context.Context, sub string) (repository.OidcCode, error) {
r, err := s.q.GetOidcCodeBySub(ctx, sub)
if err != nil {
return repository.OidcCode{}, mapErr(err)
}
return oidcCodeToRepo(r), nil
}
func (s *Store) GetOidcCodeBySubUnsafe(ctx context.Context, sub string) (repository.OidcCode, error) {
r, err := s.q.GetOidcCodeBySubUnsafe(ctx, sub)
if err != nil {
return repository.OidcCode{}, mapErr(err)
}
return oidcCodeToRepo(r), nil
}
func (s *Store) GetOidcCodeUnsafe(ctx context.Context, codeHash string) (repository.OidcCode, error) {
r, err := s.q.GetOidcCodeUnsafe(ctx, codeHash)
if err != nil {
return repository.OidcCode{}, mapErr(err)
}
return oidcCodeToRepo(r), nil
}
func (s *Store) GetOidcToken(ctx context.Context, accessTokenHash string) (repository.OidcToken, error) {
r, err := s.q.GetOidcToken(ctx, accessTokenHash)
if err != nil {
return repository.OidcToken{}, mapErr(err)
}
return oidcTokenToRepo(r), nil
}
func (s *Store) GetOidcTokenByRefreshToken(ctx context.Context, refreshTokenHash string) (repository.OidcToken, error) {
r, err := s.q.GetOidcTokenByRefreshToken(ctx, refreshTokenHash)
if err != nil {
return repository.OidcToken{}, mapErr(err)
}
return oidcTokenToRepo(r), nil
}
func (s *Store) GetOidcTokenBySub(ctx context.Context, sub string) (repository.OidcToken, error) {
r, err := s.q.GetOidcTokenBySub(ctx, sub)
if err != nil {
return repository.OidcToken{}, mapErr(err)
}
return oidcTokenToRepo(r), nil
}
func (s *Store) GetOidcUserInfo(ctx context.Context, sub string) (repository.OidcUserinfo, error) {
r, err := s.q.GetOidcUserInfo(ctx, sub)
if err != nil {
return repository.OidcUserinfo{}, mapErr(err)
}
return oidcUserinfoToRepo(r), nil
}
func (s *Store) GetSession(ctx context.Context, uuid string) (repository.Session, error) {
r, err := s.q.GetSession(ctx, uuid)
if err != nil {
return repository.Session{}, mapErr(err)
}
return sessionToRepo(r), nil
}
func (s *Store) UpdateOidcTokenByRefreshToken(ctx context.Context, arg repository.UpdateOidcTokenByRefreshTokenParams) (repository.OidcToken, error) {
r, err := s.q.UpdateOidcTokenByRefreshToken(ctx, UpdateOidcTokenByRefreshTokenParams(arg))
if err != nil {
return repository.OidcToken{}, mapErr(err)
}
return oidcTokenToRepo(r), nil
}
func (s *Store) UpdateSession(ctx context.Context, arg repository.UpdateSessionParams) (repository.Session, error) {
r, err := s.q.UpdateSession(ctx, UpdateSessionParams(arg))
if err != nil {
return repository.Session{}, mapErr(err)
}
return sessionToRepo(r), nil
}
-47
View File
@@ -1,47 +0,0 @@
package repository
import (
"context"
"errors"
)
// ErrNotFound is returned by Store methods when the requested record does not exist.
var ErrNotFound = errors.New("not found")
// Store is the interface that all storage drivers must implement.
// The sqlc-generated *Queries struct satisfies this interface for SQLite.
// Future drivers (postgres, etc.) must return the shared types defined in this package.
type Store interface {
// Sessions
CreateSession(ctx context.Context, arg CreateSessionParams) (Session, error)
GetSession(ctx context.Context, uuid string) (Session, error)
UpdateSession(ctx context.Context, arg UpdateSessionParams) (Session, error)
DeleteSession(ctx context.Context, uuid string) error
DeleteExpiredSessions(ctx context.Context, expiry int64) error
// OIDC codes
CreateOidcCode(ctx context.Context, arg CreateOidcCodeParams) (OidcCode, error)
GetOidcCode(ctx context.Context, codeHash string) (OidcCode, error)
GetOidcCodeBySub(ctx context.Context, sub string) (OidcCode, error)
GetOidcCodeUnsafe(ctx context.Context, codeHash string) (OidcCode, error)
GetOidcCodeBySubUnsafe(ctx context.Context, sub string) (OidcCode, error)
DeleteOidcCode(ctx context.Context, codeHash string) error
DeleteOidcCodeBySub(ctx context.Context, sub string) error
DeleteExpiredOidcCodes(ctx context.Context, expiresAt int64) ([]OidcCode, error)
// OIDC tokens
CreateOidcToken(ctx context.Context, arg CreateOidcTokenParams) (OidcToken, error)
GetOidcToken(ctx context.Context, accessTokenHash string) (OidcToken, error)
GetOidcTokenByRefreshToken(ctx context.Context, refreshTokenHash string) (OidcToken, error)
GetOidcTokenBySub(ctx context.Context, sub string) (OidcToken, error)
UpdateOidcTokenByRefreshToken(ctx context.Context, arg UpdateOidcTokenByRefreshTokenParams) (OidcToken, error)
DeleteOidcToken(ctx context.Context, accessTokenHash string) error
DeleteOidcTokenBySub(ctx context.Context, sub string) error
DeleteOidcTokenByCodeHash(ctx context.Context, codeHash string) error
DeleteExpiredOidcTokens(ctx context.Context, arg DeleteExpiredOidcTokensParams) ([]OidcToken, error)
// OIDC userinfo
CreateOidcUserInfo(ctx context.Context, arg CreateOidcUserInfoParams) (OidcUserinfo, error)
GetOidcUserInfo(ctx context.Context, sub string) (OidcUserinfo, error)
DeleteOidcUserInfo(ctx context.Context, sub string) error
}
+14 -12
View File
@@ -1,23 +1,22 @@
package service
import (
"errors"
"strings"
"github.com/tinyauthapp/tinyauth/internal/config"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
)
type LabelProvider interface {
GetLabels(appDomain string) (config.App, error)
GetLabels(appDomain string) (*model.App, error)
}
type AccessControlsService struct {
labelProvider LabelProvider
static map[string]config.App
static map[string]model.App
}
func NewAccessControlsService(labelProvider LabelProvider, static map[string]config.App) *AccessControlsService {
func NewAccessControlsService(labelProvider LabelProvider, static map[string]model.App) *AccessControlsService {
return &AccessControlsService{
labelProvider: labelProvider,
static: static,
@@ -28,26 +27,29 @@ func (acls *AccessControlsService) Init() error {
return nil // No initialization needed
}
func (acls *AccessControlsService) lookupStaticACLs(domain string) (config.App, error) {
func (acls *AccessControlsService) lookupStaticACLs(domain string) *model.App {
var appAcls *model.App
for app, config := range acls.static {
if config.Config.Domain == domain {
tlog.App.Debug().Str("name", app).Msg("Found matching container by domain")
return config, nil
appAcls = &config
break // If we find a match by domain, we can stop searching
}
if strings.SplitN(domain, ".", 2)[0] == app {
tlog.App.Debug().Str("name", app).Msg("Found matching container by app name")
return config, nil
appAcls = &config
break // If we find a match by app name, we can stop searching
}
}
return config.App{}, errors.New("no results")
return appAcls
}
func (acls *AccessControlsService) GetAccessControls(domain string) (config.App, error) {
func (acls *AccessControlsService) GetAccessControls(domain string) (*model.App, error) {
// First check in the static config
app, err := acls.lookupStaticACLs(domain)
app := acls.lookupStaticACLs(domain)
if err == nil {
if app != nil {
tlog.App.Debug().Msg("Using ACls from static configuration")
return app, nil
}
+165 -159
View File
@@ -2,14 +2,16 @@ package service
import (
"context"
"database/sql"
"errors"
"fmt"
"net/http"
"regexp"
"strings"
"sync"
"time"
"github.com/tinyauthapp/tinyauth/internal/config"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
@@ -28,6 +30,10 @@ const MaxOAuthPendingSessions = 256
const OAuthCleanupCount = 16
const MaxLoginAttemptRecords = 256
var (
ErrUserNotFound = errors.New("user not found")
)
// slightly modified version of the AuthorizeRequest from the OIDC service to basically accept all
// parameters and pass them to the authorize page if needed
type OAuthURLParams struct {
@@ -67,7 +73,7 @@ type Lockdown struct {
}
type AuthServiceConfig struct {
Users []config.User
LocalUsers *[]model.LocalUser
OauthWhitelist []string
SessionExpiry int
SessionMaxLifetime int
@@ -76,7 +82,7 @@ type AuthServiceConfig struct {
LoginTimeout int
LoginMaxRetries int
SessionCookieName string
IP config.IPConfig
IP model.IPConfig
LDAPGroupsCacheTTL int
}
@@ -89,14 +95,14 @@ type AuthService struct {
loginMutex sync.RWMutex
ldapGroupsMutex sync.RWMutex
ldap *LdapService
queries repository.Store
queries *repository.Queries
oauthBroker *OAuthBrokerService
lockdown *Lockdown
lockdownCtx context.Context
lockdownCancelFunc context.CancelFunc
}
func NewAuthService(config AuthServiceConfig, ldap *LdapService, queries repository.Store, oauthBroker *OAuthBrokerService) *AuthService {
func NewAuthService(config AuthServiceConfig, ldap *LdapService, queries *repository.Queries, oauthBroker *OAuthBrokerService) *AuthService {
return &AuthService{
config: config,
loginAttempts: make(map[string]*LoginAttempt),
@@ -105,7 +111,7 @@ func NewAuthService(config AuthServiceConfig, ldap *LdapService, queries reposit
ldap: ldap,
queries: queries,
oauthBroker: oauthBroker,
}
}
}
func (auth *AuthService) Init() error {
@@ -113,79 +119,73 @@ func (auth *AuthService) Init() error {
return nil
}
func (auth *AuthService) SearchUser(username string) config.UserSearch {
if auth.GetLocalUser(username).Username != "" {
return config.UserSearch{
func (auth *AuthService) SearchUser(username string) (*model.UserSearch, error) {
if auth.GetLocalUser(username) != nil {
return &model.UserSearch{
Username: username,
Type: "local",
}
Type: model.UserLocal,
}, nil
}
if auth.ldap.IsConfigured() {
userDN, err := auth.ldap.GetUserDN(username)
if err != nil {
tlog.App.Warn().Err(err).Str("username", username).Msg("Failed to search for user in LDAP")
return config.UserSearch{
Type: "unknown",
}
return nil, fmt.Errorf("failed to get ldap user: %w", err)
}
return config.UserSearch{
return &model.UserSearch{
Username: userDN,
Type: "ldap",
}
Type: model.UserLDAP,
}, nil
}
return config.UserSearch{
Type: "unknown",
}
return nil, ErrUserNotFound
}
func (auth *AuthService) VerifyUser(search config.UserSearch, password string) bool {
func (auth *AuthService) CheckUserPassword(search model.UserSearch, password string) error {
switch search.Type {
case "local":
case model.UserLocal:
user := auth.GetLocalUser(search.Username)
return auth.CheckPassword(user, password)
case "ldap":
if user == nil {
return ErrUserNotFound
}
return bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password))
case model.UserLDAP:
if auth.ldap.IsConfigured() {
err := auth.ldap.Bind(search.Username, password)
if err != nil {
tlog.App.Warn().Err(err).Str("username", search.Username).Msg("Failed to bind to LDAP")
return false
return fmt.Errorf("failed to bind to ldap user: %w", err)
}
err = auth.ldap.BindService(true)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to rebind with service account after user authentication")
return false
return fmt.Errorf("failed to bind to ldap service account: %w", err)
}
return true
return nil
}
default:
tlog.App.Debug().Str("type", search.Type).Msg("Unknown user type for authentication")
return false
return errors.New("unknown user search type")
}
tlog.App.Warn().Str("username", search.Username).Msg("User authentication failed")
return false
return errors.New("user authentication failed")
}
func (auth *AuthService) GetLocalUser(username string) config.User {
for _, user := range auth.config.Users {
func (auth *AuthService) GetLocalUser(username string) *model.LocalUser {
if auth.config.LocalUsers == nil {
return nil
}
for _, user := range *auth.config.LocalUsers {
if user.Username == username {
return user
return &user
}
}
tlog.App.Warn().Str("username", username).Msg("Local user not found")
return config.User{}
return nil
}
func (auth *AuthService) GetLdapUser(userDN string) (config.LdapUser, error) {
func (auth *AuthService) GetLDAPUser(userDN string) (*model.LDAPUser, error) {
if !auth.ldap.IsConfigured() {
return config.LdapUser{}, errors.New("LDAP service not initialized")
return nil, errors.New("ldap service not configured")
}
auth.ldapGroupsMutex.RLock()
@@ -193,7 +193,7 @@ func (auth *AuthService) GetLdapUser(userDN string) (config.LdapUser, error) {
auth.ldapGroupsMutex.RUnlock()
if exists && time.Now().Before(entry.Expires) {
return config.LdapUser{
return &model.LDAPUser{
DN: userDN,
Groups: entry.Groups,
}, nil
@@ -202,7 +202,7 @@ func (auth *AuthService) GetLdapUser(userDN string) (config.LdapUser, error) {
groups, err := auth.ldap.GetUserGroups(userDN)
if err != nil {
return config.LdapUser{}, err
return nil, fmt.Errorf("failed to get ldap groups: %w", err)
}
auth.ldapGroupsMutex.Lock()
@@ -212,16 +212,12 @@ func (auth *AuthService) GetLdapUser(userDN string) (config.LdapUser, error) {
}
auth.ldapGroupsMutex.Unlock()
return config.LdapUser{
return &model.LDAPUser{
DN: userDN,
Groups: groups,
}, nil
}
func (auth *AuthService) CheckPassword(user config.User, password string) bool {
return bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)) == nil
}
func (auth *AuthService) IsAccountLocked(identifier string) (bool, int) {
auth.loginMutex.RLock()
defer auth.loginMutex.RUnlock()
@@ -290,11 +286,11 @@ func (auth *AuthService) IsEmailWhitelisted(email string) bool {
return utils.CheckFilter(strings.Join(auth.config.OauthWhitelist, ","), email)
}
func (auth *AuthService) CreateSessionCookie(c *gin.Context, data *repository.Session) error {
func (auth *AuthService) CreateSession(ctx context.Context, data repository.Session) (*http.Cookie, error) {
uuid, err := uuid.NewRandom()
if err != nil {
return err
return nil, fmt.Errorf("failed to generate session uuid: %w", err)
}
var expiry int
@@ -305,6 +301,8 @@ func (auth *AuthService) CreateSessionCookie(c *gin.Context, data *repository.Se
expiry = auth.config.SessionExpiry
}
expiresAt := time.Now().Add(time.Duration(expiry) * time.Second)
session := repository.CreateSessionParams{
UUID: uuid.String(),
Username: data.Username,
@@ -313,34 +311,36 @@ func (auth *AuthService) CreateSessionCookie(c *gin.Context, data *repository.Se
Provider: data.Provider,
TotpPending: data.TotpPending,
OAuthGroups: data.OAuthGroups,
Expiry: time.Now().Add(time.Duration(expiry) * time.Second).Unix(),
Expiry: expiresAt.Unix(),
CreatedAt: time.Now().Unix(),
OAuthName: data.OAuthName,
OAuthSub: data.OAuthSub,
}
_, err = auth.queries.CreateSession(c, session)
_, err = auth.queries.CreateSession(ctx, session)
if err != nil {
return err
return nil, fmt.Errorf("failed to create session entry: %w", err)
}
c.SetCookie(auth.config.SessionCookieName, session.UUID, expiry, "/", fmt.Sprintf(".%s", auth.config.CookieDomain), auth.config.SecureCookie, true)
return nil
return &http.Cookie{
Name: auth.config.SessionCookieName,
Value: session.UUID,
Path: "/",
Domain: fmt.Sprintf(".%s", auth.config.CookieDomain),
Expires: expiresAt,
MaxAge: int(time.Until(expiresAt).Seconds()),
Secure: auth.config.SecureCookie,
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
}, nil
}
func (auth *AuthService) RefreshSessionCookie(c *gin.Context) error {
cookie, err := c.Cookie(auth.config.SessionCookieName)
func (auth *AuthService) RefreshSession(ctx context.Context, uuid string) (*http.Cookie, error) {
session, err := auth.queries.GetSession(ctx, uuid)
if err != nil {
return err
}
session, err := auth.queries.GetSession(c, cookie)
if err != nil {
return err
return nil, fmt.Errorf("failed to retrieve session: %w", err)
}
currentTime := time.Now().Unix()
@@ -354,12 +354,12 @@ func (auth *AuthService) RefreshSessionCookie(c *gin.Context) error {
}
if session.Expiry-currentTime > refreshThreshold {
return nil
return nil, nil
}
newExpiry := session.Expiry + refreshThreshold
_, err = auth.queries.UpdateSession(c, repository.UpdateSessionParams{
_, err = auth.queries.UpdateSession(ctx, repository.UpdateSessionParams{
Username: session.Username,
Email: session.Email,
Name: session.Name,
@@ -373,122 +373,123 @@ func (auth *AuthService) RefreshSessionCookie(c *gin.Context) error {
})
if err != nil {
return err
return nil, fmt.Errorf("failed to update session expiry: %w", err)
}
c.SetCookie(auth.config.SessionCookieName, cookie, int(newExpiry-currentTime), "/", fmt.Sprintf(".%s", auth.config.CookieDomain), auth.config.SecureCookie, true)
tlog.App.Trace().Str("username", session.Username).Msg("Session cookie refreshed")
return &http.Cookie{
Name: auth.config.SessionCookieName,
Value: session.UUID,
Path: "/",
Domain: fmt.Sprintf(".%s", auth.config.CookieDomain),
Expires: time.Now().Add(time.Duration(newExpiry-currentTime) * time.Second),
MaxAge: int(newExpiry - currentTime),
Secure: auth.config.SecureCookie,
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
}, nil
return nil
}
func (auth *AuthService) DeleteSessionCookie(c *gin.Context) error {
cookie, err := c.Cookie(auth.config.SessionCookieName)
func (auth *AuthService) DeleteSession(ctx context.Context, uuid string) (*http.Cookie, error) {
err := auth.queries.DeleteSession(ctx, uuid)
if err != nil {
return err
tlog.App.Warn().Err(err).Msg("Failed to delete session from database, proceeding to clear cookie anyway")
}
err = auth.queries.DeleteSession(c, cookie)
if err != nil {
return err
}
c.SetCookie(auth.config.SessionCookieName, "", -1, "/", fmt.Sprintf(".%s", auth.config.CookieDomain), auth.config.SecureCookie, true)
return nil
return &http.Cookie{
Name: auth.config.SessionCookieName,
Value: "",
Path: "/",
Domain: fmt.Sprintf(".%s", auth.config.CookieDomain),
Expires: time.Now(),
MaxAge: -1,
Secure: auth.config.SecureCookie,
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
}, nil
}
func (auth *AuthService) GetSessionCookie(c *gin.Context) (repository.Session, error) {
cookie, err := c.Cookie(auth.config.SessionCookieName)
func (auth *AuthService) GetSession(ctx context.Context, uuid string) (*repository.Session, error) {
session, err := auth.queries.GetSession(ctx, uuid)
if err != nil {
return repository.Session{}, err
}
session, err := auth.queries.GetSession(c, cookie)
if err != nil {
if errors.Is(err, repository.ErrNotFound) {
return repository.Session{}, fmt.Errorf("session not found")
if errors.Is(err, sql.ErrNoRows) {
return nil, errors.New("session not found")
}
return repository.Session{}, err
return nil, err
}
currentTime := time.Now().Unix()
if auth.config.SessionMaxLifetime != 0 && session.CreatedAt != 0 {
if currentTime-session.CreatedAt > int64(auth.config.SessionMaxLifetime) {
err = auth.queries.DeleteSession(c, cookie)
err = auth.queries.DeleteSession(ctx, uuid)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to delete session exceeding max lifetime")
return nil, fmt.Errorf("failed to delete expired session: %w", err)
}
return repository.Session{}, fmt.Errorf("session expired due to max lifetime exceeded")
return nil, fmt.Errorf("session max lifetime exceeded")
}
}
if currentTime > session.Expiry {
err = auth.queries.DeleteSession(c, cookie)
err = auth.queries.DeleteSession(ctx, uuid)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to delete expired session")
return nil, fmt.Errorf("failed to delete expired session: %w", err)
}
return repository.Session{}, fmt.Errorf("session expired")
return nil, fmt.Errorf("session expired")
}
return repository.Session{
UUID: session.UUID,
Username: session.Username,
Email: session.Email,
Name: session.Name,
Provider: session.Provider,
TotpPending: session.TotpPending,
OAuthGroups: session.OAuthGroups,
OAuthName: session.OAuthName,
OAuthSub: session.OAuthSub,
}, nil
return &session, nil
}
func (auth *AuthService) LocalAuthConfigured() bool {
return len(auth.config.Users) > 0
return auth.config.LocalUsers != nil && len(*auth.config.LocalUsers) > 0
}
func (auth *AuthService) LdapAuthConfigured() bool {
func (auth *AuthService) LDAPAuthConfigured() bool {
return auth.ldap.IsConfigured()
}
func (auth *AuthService) IsUserAllowed(c *gin.Context, context config.UserContext, acls config.App) bool {
if context.OAuth {
func (auth *AuthService) IsUserAllowed(c *gin.Context, context model.UserContext, acls *model.App) bool {
if acls == nil {
return true
}
if context.Provider == model.ProviderOAuth {
tlog.App.Debug().Msg("Checking OAuth whitelist")
return utils.CheckFilter(acls.OAuth.Whitelist, context.Email)
return utils.CheckFilter(acls.OAuth.Whitelist, context.OAuth.Email)
}
if acls.Users.Block != "" {
tlog.App.Debug().Msg("Checking blocked users")
if utils.CheckFilter(acls.Users.Block, context.Username) {
if utils.CheckFilter(acls.Users.Block, context.GetUsername()) {
return false
}
}
tlog.App.Debug().Msg("Checking users")
return utils.CheckFilter(acls.Users.Allow, context.Username)
return utils.CheckFilter(acls.Users.Allow, context.GetUsername())
}
func (auth *AuthService) IsInOAuthGroup(c *gin.Context, context config.UserContext, requiredGroups string) bool {
if requiredGroups == "" {
func (auth *AuthService) IsInOAuthGroup(c *gin.Context, context model.UserContext, acls *model.App) bool {
if acls == nil {
return true
}
for id := range config.OverrideProviders {
if context.Provider == id {
tlog.App.Info().Str("provider", id).Msg("OAuth groups not supported for this provider")
return true
}
if !context.IsOAuth() {
tlog.App.Debug().Msg("User is not an OAuth user, skipping OAuth group check")
return false
}
for userGroup := range strings.SplitSeq(context.OAuthGroups, ",") {
if utils.CheckFilter(requiredGroups, strings.TrimSpace(userGroup)) {
tlog.App.Trace().Str("group", userGroup).Str("required", requiredGroups).Msg("User group matched")
if _, ok := model.OverrideProviders[context.OAuth.ID]; ok {
tlog.App.Debug().Msg("Provider override for OAuth groups enabled, skipping group check")
return true
}
for _, userGroup := range context.OAuth.Groups {
if utils.CheckFilter(acls.OAuth.Groups, strings.TrimSpace(userGroup)) {
tlog.App.Trace().Str("group", userGroup).Str("required", acls.OAuth.Groups).Msg("User group matched")
return true
}
}
@@ -497,14 +498,19 @@ func (auth *AuthService) IsInOAuthGroup(c *gin.Context, context config.UserConte
return false
}
func (auth *AuthService) IsInLdapGroup(c *gin.Context, context config.UserContext, requiredGroups string) bool {
if requiredGroups == "" {
func (auth *AuthService) IsInLDAPGroup(c *gin.Context, context model.UserContext, acls *model.App) bool {
if acls == nil {
return true
}
for userGroup := range strings.SplitSeq(context.LdapGroups, ",") {
if utils.CheckFilter(requiredGroups, strings.TrimSpace(userGroup)) {
tlog.App.Trace().Str("group", userGroup).Str("required", requiredGroups).Msg("User group matched")
if !context.IsLDAP() {
tlog.App.Debug().Msg("User is not an LDAP user, skipping LDAP group check")
return false
}
for _, userGroup := range context.LDAP.Groups {
if utils.CheckFilter(acls.LDAP.Groups, strings.TrimSpace(userGroup)) {
tlog.App.Trace().Str("group", userGroup).Str("required", acls.LDAP.Groups).Msg("User group matched")
return true
}
}
@@ -513,10 +519,14 @@ func (auth *AuthService) IsInLdapGroup(c *gin.Context, context config.UserContex
return false
}
func (auth *AuthService) IsAuthEnabled(uri string, path config.AppPath) (bool, error) {
func (auth *AuthService) IsAuthEnabled(uri string, acls *model.App) (bool, error) {
if acls == nil {
return true, nil
}
// Check for block list
if path.Block != "" {
regex, err := regexp.Compile(path.Block)
if acls.Path.Block != "" {
regex, err := regexp.Compile(acls.Path.Block)
if err != nil {
return true, err
@@ -528,8 +538,8 @@ func (auth *AuthService) IsAuthEnabled(uri string, path config.AppPath) (bool, e
}
// Check for allow list
if path.Allow != "" {
regex, err := regexp.Compile(path.Allow)
if acls.Path.Allow != "" {
regex, err := regexp.Compile(acls.Path.Allow)
if err != nil {
return true, err
@@ -543,22 +553,14 @@ func (auth *AuthService) IsAuthEnabled(uri string, path config.AppPath) (bool, e
return true, nil
}
func (auth *AuthService) GetBasicAuth(c *gin.Context) *config.User {
username, password, ok := c.Request.BasicAuth()
if !ok {
tlog.App.Debug().Msg("No basic auth provided")
return nil
func (auth *AuthService) CheckIP(ip string, acls *model.App) bool {
if acls == nil {
return true
}
return &config.User{
Username: username,
Password: password,
}
}
func (auth *AuthService) CheckIP(acls config.AppIP, ip string) bool {
// Merge the global and app IP filter
blockedIps := append(auth.config.IP.Block, acls.Block...)
allowedIPs := append(auth.config.IP.Allow, acls.Allow...)
blockedIps := append(auth.config.IP.Block, acls.IP.Block...)
allowedIPs := append(auth.config.IP.Allow, acls.IP.Allow...)
for _, blocked := range blockedIps {
res, err := utils.FilterIP(blocked, ip)
@@ -593,8 +595,12 @@ func (auth *AuthService) CheckIP(acls config.AppIP, ip string) bool {
return true
}
func (auth *AuthService) IsBypassedIP(acls config.AppIP, ip string) bool {
for _, bypassed := range acls.Bypass {
func (auth *AuthService) IsBypassedIP(ip string, acls *model.App) bool {
if acls == nil {
return false
}
for _, bypassed := range acls.IP.Bypass {
res, err := utils.FilterIP(bypassed, ip)
if err != nil {
tlog.App.Warn().Err(err).Str("item", bypassed).Msg("Invalid IP/CIDR in bypass list")
@@ -673,21 +679,21 @@ func (auth *AuthService) GetOAuthToken(sessionId string, code string) (*oauth2.T
return token, nil
}
func (auth *AuthService) GetOAuthUserinfo(sessionId string) (config.Claims, error) {
func (auth *AuthService) GetOAuthUserinfo(sessionId string) (*model.Claims, error) {
session, err := auth.GetOAuthPendingSession(sessionId)
if err != nil {
return config.Claims{}, err
return nil, err
}
if session.Token == nil {
return config.Claims{}, fmt.Errorf("oauth token not found for session: %s", sessionId)
return nil, fmt.Errorf("oauth token not found for session: %s", sessionId)
}
userinfo, err := (*session.Service).GetUserinfo(session.Token)
if err != nil {
return config.Claims{}, fmt.Errorf("failed to get userinfo: %w", err)
return nil, fmt.Errorf("failed to get userinfo: %w", err)
}
return userinfo, nil
+12 -20
View File
@@ -4,7 +4,7 @@ import (
"context"
"strings"
"github.com/tinyauthapp/tinyauth/internal/config"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/decoders"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
@@ -51,56 +51,48 @@ func (docker *DockerService) Init() error {
}
func (docker *DockerService) getContainers() ([]container.Summary, error) {
containers, err := docker.client.ContainerList(docker.context, container.ListOptions{})
if err != nil {
return nil, err
}
return containers, nil
return docker.client.ContainerList(docker.context, container.ListOptions{})
}
func (docker *DockerService) inspectContainer(containerId string) (container.InspectResponse, error) {
inspect, err := docker.client.ContainerInspect(docker.context, containerId)
if err != nil {
return container.InspectResponse{}, err
}
return inspect, nil
return docker.client.ContainerInspect(docker.context, containerId)
}
func (docker *DockerService) GetLabels(appDomain string) (config.App, error) {
func (docker *DockerService) GetLabels(appDomain string) (*model.App, error) {
if !docker.isConnected {
tlog.App.Debug().Msg("Docker not connected, returning empty labels")
return config.App{}, nil
return nil, nil
}
containers, err := docker.getContainers()
if err != nil {
return config.App{}, err
return nil, err
}
for _, ctr := range containers {
inspect, err := docker.inspectContainer(ctr.ID)
if err != nil {
return config.App{}, err
return nil, err
}
labels, err := decoders.DecodeLabels[config.Apps](inspect.Config.Labels, "apps")
labels, err := decoders.DecodeLabels[model.Apps](inspect.Config.Labels, "apps")
if err != nil {
return config.App{}, err
return nil, err
}
for appName, appLabels := range labels.Apps {
if appLabels.Config.Domain == appDomain {
tlog.App.Debug().Str("id", inspect.ID).Str("name", inspect.Name).Msg("Found matching container by domain")
return appLabels, nil
return &appLabels, nil
}
if strings.SplitN(appDomain, ".", 2)[0] == appName {
tlog.App.Debug().Str("id", inspect.ID).Str("name", inspect.Name).Msg("Found matching container by app name")
return appLabels, nil
return &appLabels, nil
}
}
}
tlog.App.Debug().Msg("No matching container found, returning empty labels")
return config.App{}, nil
return nil, nil
}
+20 -17
View File
@@ -7,7 +7,7 @@ import (
"sync"
"time"
"github.com/tinyauthapp/tinyauth/internal/config"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/decoders"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
@@ -32,7 +32,7 @@ type ingressAppKey struct {
type ingressApp struct {
domain string
appName string
app config.App
app model.App
}
type KubernetesService struct {
@@ -89,36 +89,38 @@ func (k *KubernetesService) removeIngress(namespace, name string) {
}
}
func (k *KubernetesService) getByDomain(domain string) (config.App, bool) {
func (k *KubernetesService) getByDomain(domain string) *model.App {
k.mu.RLock()
defer k.mu.RUnlock()
if appKey, ok := k.domainIndex[domain]; ok {
if apps, ok := k.ingressApps[appKey.ingressKey]; ok {
for _, app := range apps {
for i := range apps {
app := &apps[i]
if app.domain == domain && app.appName == appKey.appName {
return app.app, true
return &app.app
}
}
}
}
return config.App{}, false
return nil
}
func (k *KubernetesService) getByAppName(appName string) (config.App, bool) {
func (k *KubernetesService) getByAppName(appName string) *model.App {
k.mu.RLock()
defer k.mu.RUnlock()
if appKey, ok := k.appNameIndex[appName]; ok {
if apps, ok := k.ingressApps[appKey.ingressKey]; ok {
for _, app := range apps {
for i := range apps {
app := &apps[i]
if app.appName == appName {
return app.app, true
return &app.app
}
}
}
}
return config.App{}, false
return nil
}
func (k *KubernetesService) updateFromItem(item *unstructured.Unstructured) {
@@ -129,7 +131,7 @@ func (k *KubernetesService) updateFromItem(item *unstructured.Unstructured) {
k.removeIngress(namespace, name)
return
}
labels, err := decoders.DecodeLabels[config.Apps](annotations, "apps")
labels, err := decoders.DecodeLabels[model.Apps](annotations, "apps")
if err != nil {
tlog.App.Debug().Err(err).Msg("Failed to decode labels from annotations")
k.removeIngress(namespace, name)
@@ -280,24 +282,25 @@ func (k *KubernetesService) Init() error {
return nil
}
func (k *KubernetesService) GetLabels(appDomain string) (config.App, error) {
func (k *KubernetesService) GetLabels(appDomain string) (*model.App, error) {
if !k.started {
tlog.App.Debug().Msg("Kubernetes not connected, returning empty labels")
return config.App{}, nil
return nil, nil
}
// First check cache
if app, found := k.getByDomain(appDomain); found {
app := k.getByDomain(appDomain)
if app != nil {
tlog.App.Debug().Str("domain", appDomain).Msg("Found labels in cache by domain")
return app, nil
}
appName := strings.SplitN(appDomain, ".", 2)[0]
if app, found := k.getByAppName(appName); found {
app = k.getByAppName(appName)
if app != nil {
tlog.App.Debug().Str("domain", appDomain).Str("appName", appName).Msg("Found labels in cache by app name")
return app, nil
}
tlog.App.Debug().Str("domain", appDomain).Msg("Cache miss, no matching ingress found")
return config.App{}, nil
return nil, nil
}
+31 -31
View File
@@ -3,11 +3,11 @@ package service
import (
"testing"
"github.com/tinyauthapp/tinyauth/internal/config"
"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/model"
)
func TestKubernetesService(t *testing.T) {
@@ -20,69 +20,69 @@ func TestKubernetesService(t *testing.T) {
{
description: "Cache by domain returns app and misses unknown domain",
run: func(t *testing.T, svc *KubernetesService) {
app := config.App{Config: config.AppConfig{Domain: "foo.example.com"}}
app := model.App{Config: model.AppConfig{Domain: "foo.example.com"}}
svc.addIngressApps("default", "my-ingress", []ingressApp{
{domain: "foo.example.com", appName: "foo", app: app},
})
got, ok := svc.getByDomain("foo.example.com")
require.True(t, ok)
got := svc.getByDomain("foo.example.com")
require.NotNil(t, got)
assert.Equal(t, "foo.example.com", got.Config.Domain)
_, ok = svc.getByDomain("notfound.example.com")
assert.False(t, ok)
got = svc.getByDomain("notfound.example.com")
assert.Nil(t, got)
},
},
{
description: "Cache by app name returns app and misses unknown name",
run: func(t *testing.T, svc *KubernetesService) {
app := config.App{Config: config.AppConfig{Domain: "bar.example.com"}}
app := model.App{Config: model.AppConfig{Domain: "bar.example.com"}}
svc.addIngressApps("default", "my-ingress", []ingressApp{
{domain: "bar.example.com", appName: "bar", app: app},
})
got, ok := svc.getByAppName("bar")
require.True(t, ok)
got := svc.getByAppName("bar")
require.NotNil(t, got)
assert.Equal(t, "bar.example.com", got.Config.Domain)
_, ok = svc.getByAppName("notfound")
assert.False(t, ok)
got = svc.getByAppName("notfound")
assert.Nil(t, got)
},
},
{
description: "RemoveIngress clears domain and app name entries",
run: func(t *testing.T, svc *KubernetesService) {
app := config.App{Config: config.AppConfig{Domain: "baz.example.com"}}
app := model.App{Config: model.AppConfig{Domain: "baz.example.com"}}
svc.addIngressApps("default", "my-ingress", []ingressApp{
{domain: "baz.example.com", appName: "baz", app: app},
})
svc.removeIngress("default", "my-ingress")
_, ok := svc.getByDomain("baz.example.com")
assert.False(t, ok)
_, ok = svc.getByAppName("baz")
assert.False(t, ok)
got := svc.getByDomain("baz.example.com")
assert.Nil(t, got)
got = svc.getByAppName("baz")
assert.Nil(t, got)
},
},
{
description: "AddIngressApps replaces stale entries for the same ingress",
run: func(t *testing.T, svc *KubernetesService) {
old := config.App{Config: config.AppConfig{Domain: "old.example.com"}}
old := model.App{Config: model.AppConfig{Domain: "old.example.com"}}
svc.addIngressApps("default", "my-ingress", []ingressApp{
{domain: "old.example.com", appName: "old", app: old},
})
updated := config.App{Config: config.AppConfig{Domain: "new.example.com"}}
updated := model.App{Config: model.AppConfig{Domain: "new.example.com"}}
svc.addIngressApps("default", "my-ingress", []ingressApp{
{domain: "new.example.com", appName: "new", app: updated},
})
_, ok := svc.getByDomain("old.example.com")
assert.False(t, ok)
got := svc.getByDomain("old.example.com")
assert.Nil(t, got)
got, ok := svc.getByDomain("new.example.com")
require.True(t, ok)
got = svc.getByDomain("new.example.com")
require.NotNil(t, got)
assert.Equal(t, "new.example.com", got.Config.Domain)
},
},
@@ -91,7 +91,7 @@ func TestKubernetesService(t *testing.T) {
run: func(t *testing.T, svc *KubernetesService) {
svc.started = true
app := config.App{Config: config.AppConfig{Domain: "hit.example.com"}}
app := model.App{Config: model.AppConfig{Domain: "hit.example.com"}}
svc.addIngressApps("default", "ing", []ingressApp{
{domain: "hit.example.com", appName: "hit", app: app},
})
@@ -108,7 +108,7 @@ func TestKubernetesService(t *testing.T) {
got, err := svc.GetLabels("notfound.example.com")
require.NoError(t, err)
assert.Equal(t, config.App{}, got)
assert.Nil(t, got)
},
},
{
@@ -116,7 +116,7 @@ func TestKubernetesService(t *testing.T) {
run: func(t *testing.T, svc *KubernetesService) {
svc.started = true
app := config.App{Config: config.AppConfig{Domain: "myapp.internal.example.com"}}
app := model.App{Config: model.AppConfig{Domain: "myapp.internal.example.com"}}
svc.addIngressApps("default", "ing", []ingressApp{
{domain: "myapp.internal.example.com", appName: "myapp", app: app},
})
@@ -131,7 +131,7 @@ func TestKubernetesService(t *testing.T) {
run: func(t *testing.T, svc *KubernetesService) {
got, err := svc.GetLabels("anything.example.com")
require.NoError(t, err)
assert.Equal(t, config.App{}, got)
assert.Nil(t, got)
},
},
{
@@ -147,8 +147,8 @@ func TestKubernetesService(t *testing.T) {
svc.updateFromItem(&item)
got, ok := svc.getByDomain("myapp.example.com")
require.True(t, ok)
got := svc.getByDomain("myapp.example.com")
require.NotNil(t, got)
assert.Equal(t, "myapp.example.com", got.Config.Domain)
assert.Equal(t, "alice", got.Users.Allow)
},
@@ -156,7 +156,7 @@ func TestKubernetesService(t *testing.T) {
{
description: "UpdateFromItem with no annotations removes existing cache entries",
run: func(t *testing.T, svc *KubernetesService) {
app := config.App{Config: config.AppConfig{Domain: "todelete.example.com"}}
app := model.App{Config: model.AppConfig{Domain: "todelete.example.com"}}
svc.addIngressApps("default", "test-ingress", []ingressApp{
{domain: "todelete.example.com", appName: "todelete", app: app},
})
@@ -167,8 +167,8 @@ func TestKubernetesService(t *testing.T) {
svc.updateFromItem(&item)
_, ok := svc.getByDomain("todelete.example.com")
assert.False(t, ok)
got := svc.getByDomain("todelete.example.com")
assert.Nil(t, got)
},
},
}
+5 -5
View File
@@ -1,7 +1,7 @@
package service
import (
"github.com/tinyauthapp/tinyauth/internal/config"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"slices"
@@ -15,20 +15,20 @@ type OAuthServiceImpl interface {
NewRandom() string
GetAuthURL(state string, verifier string) string
GetToken(code string, verifier string) (*oauth2.Token, error)
GetUserinfo(token *oauth2.Token) (config.Claims, error)
GetUserinfo(token *oauth2.Token) (*model.Claims, error)
}
type OAuthBrokerService struct {
services map[string]OAuthServiceImpl
configs map[string]config.OAuthServiceConfig
configs map[string]model.OAuthServiceConfig
}
var presets = map[string]func(config config.OAuthServiceConfig) *OAuthService{
var presets = map[string]func(config model.OAuthServiceConfig) *OAuthService{
"github": newGitHubOAuthService,
"google": newGoogleOAuthService,
}
func NewOAuthBrokerService(configs map[string]config.OAuthServiceConfig) *OAuthBrokerService {
func NewOAuthBrokerService(configs map[string]model.OAuthServiceConfig) *OAuthBrokerService {
return &OAuthBrokerService{
services: make(map[string]OAuthServiceImpl),
configs: configs,
+31 -21
View File
@@ -8,12 +8,13 @@ import (
"net/http"
"strconv"
"github.com/tinyauthapp/tinyauth/internal/config"
"github.com/tinyauthapp/tinyauth/internal/model"
)
type GithubEmailResponse []struct {
Email string `json:"email"`
Primary bool `json:"primary"`
Email string `json:"email"`
Primary bool `json:"primary"`
Verified bool `json:"verified"`
}
type GithubUserInfoResponse struct {
@@ -22,32 +23,32 @@ type GithubUserInfoResponse struct {
ID int `json:"id"`
}
func defaultExtractor(client *http.Client, url string) (config.Claims, error) {
return simpleReq[config.Claims](client, url, nil)
func defaultExtractor(client *http.Client, url string) (*model.Claims, error) {
return simpleReq[model.Claims](client, url, nil)
}
func githubExtractor(client *http.Client, url string) (config.Claims, error) {
var user config.Claims
func githubExtractor(client *http.Client, url string) (*model.Claims, error) {
var user model.Claims
userInfo, err := simpleReq[GithubUserInfoResponse](client, "https://api.github.com/user", map[string]string{
"accept": "application/vnd.github+json",
})
if err != nil {
return config.Claims{}, err
return nil, err
}
userEmails, err := simpleReq[GithubEmailResponse](client, "https://api.github.com/user/emails", map[string]string{
"accept": "application/vnd.github+json",
})
if err != nil {
return config.Claims{}, err
return nil, err
}
if len(userEmails) == 0 {
return user, errors.New("no emails found")
if len(*userEmails) == 0 {
return nil, errors.New("no emails found")
}
for _, email := range userEmails {
for _, email := range *userEmails {
if email.Primary {
user.Email = email.Email
break
@@ -56,22 +57,31 @@ func githubExtractor(client *http.Client, url string) (config.Claims, error) {
// Use first available email if no primary email was found
if user.Email == "" {
user.Email = userEmails[0].Email
for _, email := range *userEmails {
if email.Verified {
user.Email = email.Email
break
}
}
}
if user.Email == "" {
return nil, errors.New("no verified email found")
}
user.PreferredUsername = userInfo.Login
user.Name = userInfo.Name
user.Sub = strconv.Itoa(userInfo.ID)
return user, nil
return &user, nil
}
func simpleReq[T any](client *http.Client, url string, headers map[string]string) (T, error) {
func simpleReq[T any](client *http.Client, url string, headers map[string]string) (*T, error) {
var decodedRes T
req, err := http.NewRequest("GET", url, nil)
if err != nil {
return decodedRes, err
return nil, err
}
for key, value := range headers {
@@ -80,23 +90,23 @@ func simpleReq[T any](client *http.Client, url string, headers map[string]string
res, err := client.Do(req)
if err != nil {
return decodedRes, err
return nil, err
}
defer res.Body.Close()
if res.StatusCode < 200 || res.StatusCode >= 300 {
return decodedRes, fmt.Errorf("request failed with status: %s", res.Status)
return nil, fmt.Errorf("request failed with status: %s", res.Status)
}
body, err := io.ReadAll(res.Body)
if err != nil {
return decodedRes, err
return nil, err
}
err = json.Unmarshal(body, &decodedRes)
if err != nil {
return decodedRes, err
return nil, err
}
return decodedRes, nil
return &decodedRes, nil
}
+3 -3
View File
@@ -1,11 +1,11 @@
package service
import (
"github.com/tinyauthapp/tinyauth/internal/config"
"github.com/tinyauthapp/tinyauth/internal/model"
"golang.org/x/oauth2/endpoints"
)
func newGoogleOAuthService(config config.OAuthServiceConfig) *OAuthService {
func newGoogleOAuthService(config model.OAuthServiceConfig) *OAuthService {
scopes := []string{"openid", "email", "profile"}
config.Scopes = scopes
config.AuthURL = endpoints.Google.AuthURL
@@ -14,7 +14,7 @@ func newGoogleOAuthService(config config.OAuthServiceConfig) *OAuthService {
return NewOAuthService(config, "google")
}
func newGitHubOAuthService(config config.OAuthServiceConfig) *OAuthService {
func newGitHubOAuthService(config model.OAuthServiceConfig) *OAuthService {
scopes := []string{"read:user", "user:email"}
config.Scopes = scopes
config.AuthURL = endpoints.GitHub.AuthURL
+5 -5
View File
@@ -6,21 +6,21 @@ import (
"net/http"
"time"
"github.com/tinyauthapp/tinyauth/internal/config"
"github.com/tinyauthapp/tinyauth/internal/model"
"golang.org/x/oauth2"
)
type UserinfoExtractor func(client *http.Client, url string) (config.Claims, error)
type UserinfoExtractor func(client *http.Client, url string) (*model.Claims, error)
type OAuthService struct {
serviceCfg config.OAuthServiceConfig
serviceCfg model.OAuthServiceConfig
config *oauth2.Config
ctx context.Context
userinfoExtractor UserinfoExtractor
id string
}
func NewOAuthService(config config.OAuthServiceConfig, id string) *OAuthService {
func NewOAuthService(config model.OAuthServiceConfig, id string) *OAuthService {
httpClient := &http.Client{
Timeout: 30 * time.Second,
Transport: &http.Transport{
@@ -78,7 +78,7 @@ func (s *OAuthService) GetToken(code string, verifier string) (*oauth2.Token, er
return s.config.Exchange(s.ctx, code, oauth2.VerifierOption(verifier))
}
func (s *OAuthService) GetUserinfo(token *oauth2.Token) (config.Claims, error) {
func (s *OAuthService) GetUserinfo(token *oauth2.Token) (*model.Claims, error) {
client := oauth2.NewClient(s.ctx, oauth2.StaticTokenSource(token))
return s.userinfoExtractor(client, s.serviceCfg.UserinfoURL)
}
+69 -66
View File
@@ -7,6 +7,7 @@ import (
"crypto/rsa"
"crypto/sha256"
"crypto/x509"
"database/sql"
"encoding/base64"
"encoding/json"
"encoding/pem"
@@ -21,7 +22,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/go-jose/go-jose/v4"
"github.com/tinyauthapp/tinyauth/internal/config"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
@@ -67,27 +68,27 @@ type ClaimSet struct {
}
type UserinfoResponse struct {
Sub string `json:"sub"`
Name string `json:"name,omitempty"`
GivenName string `json:"given_name,omitempty"`
FamilyName string `json:"family_name,omitempty"`
MiddleName string `json:"middle_name,omitempty"`
Nickname string `json:"nickname,omitempty"`
Profile string `json:"profile,omitempty"`
Picture string `json:"picture,omitempty"`
Website string `json:"website,omitempty"`
Gender string `json:"gender,omitempty"`
Birthdate string `json:"birthdate,omitempty"`
Zoneinfo string `json:"zoneinfo,omitempty"`
Locale string `json:"locale,omitempty"`
Email string `json:"email,omitempty"`
PreferredUsername string `json:"preferred_username,omitempty"`
Groups []string `json:"groups,omitempty"`
EmailVerified bool `json:"email_verified,omitempty"`
PhoneNumber string `json:"phone_number,omitempty"`
PhoneNumberVerified *bool `json:"phone_number_verified,omitempty"`
Address *config.AddressClaim `json:"address,omitempty"`
UpdatedAt int64 `json:"updated_at"`
Sub string `json:"sub"`
Name string `json:"name,omitempty"`
GivenName string `json:"given_name,omitempty"`
FamilyName string `json:"family_name,omitempty"`
MiddleName string `json:"middle_name,omitempty"`
Nickname string `json:"nickname,omitempty"`
Profile string `json:"profile,omitempty"`
Picture string `json:"picture,omitempty"`
Website string `json:"website,omitempty"`
Gender string `json:"gender,omitempty"`
Birthdate string `json:"birthdate,omitempty"`
Zoneinfo string `json:"zoneinfo,omitempty"`
Locale string `json:"locale,omitempty"`
Email string `json:"email,omitempty"`
PreferredUsername string `json:"preferred_username,omitempty"`
Groups []string `json:"groups,omitempty"`
EmailVerified bool `json:"email_verified,omitempty"`
PhoneNumber string `json:"phone_number,omitempty"`
PhoneNumberVerified *bool `json:"phone_number_verified,omitempty"`
Address *model.AddressClaim `json:"address,omitempty"`
UpdatedAt int64 `json:"updated_at"`
}
type TokenResponse struct {
@@ -111,7 +112,7 @@ type AuthorizeRequest struct {
}
type OIDCServiceConfig struct {
Clients map[string]config.OIDCClientConfig
Clients map[string]model.OIDCClientConfig
PrivateKeyPath string
PublicKeyPath string
Issuer string
@@ -120,15 +121,15 @@ type OIDCServiceConfig struct {
type OIDCService struct {
config OIDCServiceConfig
queries repository.Store
clients map[string]config.OIDCClientConfig
queries *repository.Queries
clients map[string]model.OIDCClientConfig
privateKey *rsa.PrivateKey
publicKey crypto.PublicKey
issuer string
isConfigured bool
}
func NewOIDCService(config OIDCServiceConfig, queries repository.Store) *OIDCService {
func NewOIDCService(config OIDCServiceConfig, queries *repository.Queries) *OIDCService {
return &OIDCService{
config: config,
queries: queries,
@@ -254,7 +255,7 @@ func (service *OIDCService) Init() error {
}
// We will reorganize the client into a map with the client ID as the key
service.clients = make(map[string]config.OIDCClientConfig)
service.clients = make(map[string]model.OIDCClientConfig)
for id, client := range service.config.Clients {
client.ID = id
@@ -282,7 +283,7 @@ func (service *OIDCService) GetIssuer() string {
return service.issuer
}
func (service *OIDCService) GetClient(id string) (config.OIDCClientConfig, bool) {
func (service *OIDCService) GetClient(id string) (model.OIDCClientConfig, bool) {
client, ok := service.clients[id]
return client, ok
}
@@ -366,43 +367,45 @@ func (service *OIDCService) StoreCode(c *gin.Context, sub string, code string, r
return err
}
func (service *OIDCService) StoreUserinfo(c *gin.Context, sub string, userContext config.UserContext, req AuthorizeRequest) error {
addressJSON, err := json.Marshal(userContext.Attributes.Address)
if err != nil {
return err
}
func (service *OIDCService) StoreUserinfo(c *gin.Context, sub string, userContext model.UserContext, req AuthorizeRequest) error {
userInfoParams := repository.CreateOidcUserInfoParams{
Sub: sub,
Name: userContext.Name,
Email: userContext.Email,
PreferredUsername: userContext.Username,
Name: userContext.GetName(),
Email: userContext.GetEmail(),
PreferredUsername: userContext.GetUsername(),
UpdatedAt: time.Now().Unix(),
GivenName: userContext.Attributes.GivenName,
FamilyName: userContext.Attributes.FamilyName,
MiddleName: userContext.Attributes.MiddleName,
Nickname: userContext.Attributes.Nickname,
Profile: userContext.Attributes.Profile,
Picture: userContext.Attributes.Picture,
Website: userContext.Attributes.Website,
Gender: userContext.Attributes.Gender,
Birthdate: userContext.Attributes.Birthdate,
Zoneinfo: userContext.Attributes.Zoneinfo,
Locale: userContext.Attributes.Locale,
PhoneNumber: userContext.Attributes.PhoneNumber,
Address: string(addressJSON),
}
if userContext.IsLocal() {
addressJSON, err := json.Marshal(userContext.Local.Attributes.Address)
if err != nil {
return err
}
userInfoParams.GivenName = userContext.Local.Attributes.GivenName
userInfoParams.FamilyName = userContext.Local.Attributes.FamilyName
userInfoParams.MiddleName = userContext.Local.Attributes.MiddleName
userInfoParams.Nickname = userContext.Local.Attributes.Nickname
userInfoParams.Profile = userContext.Local.Attributes.Profile
userInfoParams.Picture = userContext.Local.Attributes.Picture
userInfoParams.Website = userContext.Local.Attributes.Website
userInfoParams.Gender = userContext.Local.Attributes.Gender
userInfoParams.Birthdate = userContext.Local.Attributes.Birthdate
userInfoParams.Zoneinfo = userContext.Local.Attributes.Zoneinfo
userInfoParams.Locale = userContext.Local.Attributes.Locale
userInfoParams.PhoneNumber = userContext.Local.Attributes.PhoneNumber
userInfoParams.Address = string(addressJSON)
}
// Tinyauth will pass through the groups it got from an LDAP or an OIDC server
if userContext.Provider == "ldap" {
userInfoParams.Groups = userContext.LdapGroups
if userContext.IsLDAP() {
userInfoParams.Groups = strings.Join(userContext.LDAP.Groups, ",")
}
if userContext.OAuth && len(userContext.OAuthGroups) > 0 {
userInfoParams.Groups = userContext.OAuthGroups
if userContext.IsOAuth() {
userInfoParams.Groups = strings.Join(userContext.OAuth.Groups, ",")
}
_, err = service.queries.CreateOidcUserInfo(c, userInfoParams)
_, err := service.queries.CreateOidcUserInfo(c, userInfoParams)
return err
}
@@ -419,7 +422,7 @@ func (service *OIDCService) GetCodeEntry(c *gin.Context, codeHash string, client
oidcCode, err := service.queries.GetOidcCode(c, codeHash)
if err != nil {
if errors.Is(err, repository.ErrNotFound) {
if errors.Is(err, sql.ErrNoRows) {
return repository.OidcCode{}, ErrCodeNotFound
}
return repository.OidcCode{}, err
@@ -444,7 +447,7 @@ func (service *OIDCService) GetCodeEntry(c *gin.Context, codeHash string, client
return oidcCode, nil
}
func (service *OIDCService) generateIDToken(client config.OIDCClientConfig, user repository.OidcUserinfo, scope string, nonce string) (string, error) {
func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user repository.OidcUserinfo, scope string, nonce string) (string, error) {
createdAt := time.Now().Unix()
expiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry) * time.Second).Unix()
@@ -510,7 +513,7 @@ func (service *OIDCService) generateIDToken(client config.OIDCClientConfig, user
return token, nil
}
func (service *OIDCService) GenerateAccessToken(c *gin.Context, client config.OIDCClientConfig, codeEntry repository.OidcCode) (TokenResponse, error) {
func (service *OIDCService) GenerateAccessToken(c *gin.Context, client model.OIDCClientConfig, codeEntry repository.OidcCode) (TokenResponse, error) {
user, err := service.GetUserinfo(c, codeEntry.Sub)
if err != nil {
@@ -563,7 +566,7 @@ func (service *OIDCService) RefreshAccessToken(c *gin.Context, refreshToken stri
entry, err := service.queries.GetOidcTokenByRefreshToken(c, service.Hash(refreshToken))
if err != nil {
if errors.Is(err, repository.ErrNotFound) {
if errors.Is(err, sql.ErrNoRows) {
return TokenResponse{}, ErrTokenNotFound
}
return TokenResponse{}, err
@@ -584,7 +587,7 @@ func (service *OIDCService) RefreshAccessToken(c *gin.Context, refreshToken stri
return TokenResponse{}, err
}
idToken, err := service.generateIDToken(config.OIDCClientConfig{
idToken, err := service.generateIDToken(model.OIDCClientConfig{
ClientID: entry.ClientID,
}, user, entry.Scope, entry.Nonce)
@@ -642,7 +645,7 @@ func (service *OIDCService) GetAccessToken(c *gin.Context, tokenHash string) (re
entry, err := service.queries.GetOidcToken(c, tokenHash)
if err != nil {
if errors.Is(err, repository.ErrNotFound) {
if errors.Is(err, sql.ErrNoRows) {
return repository.OidcToken{}, ErrTokenNotFound
}
return repository.OidcToken{}, err
@@ -713,7 +716,7 @@ func (service *OIDCService) CompileUserinfo(user repository.OidcUserinfo, scope
}
if slices.Contains(scopes, "address") {
var addr config.AddressClaim
var addr model.AddressClaim
if err := json.Unmarshal([]byte(user.Address), &addr); err == nil {
userInfo.Address = &addr
}
@@ -730,15 +733,15 @@ func (service *OIDCService) Hash(token string) string {
func (service *OIDCService) DeleteOldSession(ctx context.Context, sub string) error {
err := service.queries.DeleteOidcCodeBySub(ctx, sub)
if err != nil && !errors.Is(err, repository.ErrNotFound) {
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return err
}
err = service.queries.DeleteOidcTokenBySub(ctx, sub)
if err != nil && !errors.Is(err, repository.ErrNotFound) {
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return err
}
err = service.queries.DeleteOidcUserInfo(ctx, sub)
if err != nil && !errors.Is(err, repository.ErrNotFound) {
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return err
}
return nil
@@ -783,7 +786,7 @@ func (service *OIDCService) Cleanup() {
token, err := service.queries.GetOidcTokenBySub(ctx, expiredCode.Sub)
if err != nil {
if errors.Is(err, repository.ErrNotFound) {
if errors.Is(err, sql.ErrNoRows) {
continue
}
tlog.App.Warn().Err(err).Msg("Failed to get OIDC token by sub")
+2 -2
View File
@@ -7,13 +7,13 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/config"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/service"
)
func newTestUser() repository.OidcUserinfo {
addr := config.AddressClaim{
addr := model.AddressClaim{
Formatted: "123 Main St",
StreetAddress: "123 Main St",
Locality: "Springfield",
-18
View File
@@ -7,10 +7,8 @@ import (
"net/url"
"strings"
"github.com/tinyauthapp/tinyauth/internal/config"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/gin-gonic/gin"
"github.com/weppos/publicsuffix-go/publicsuffix"
)
@@ -73,22 +71,6 @@ func Filter[T any](slice []T, test func(T) bool) (res []T) {
return res
}
func GetContext(c *gin.Context) (config.UserContext, error) {
userContextValue, exists := c.Get("context")
if !exists {
return config.UserContext{}, errors.New("no user context in request")
}
userContext, ok := userContextValue.(*config.UserContext)
if !ok {
return config.UserContext{}, errors.New("invalid user context in request")
}
return *userContext, nil
}
func IsRedirectSafe(redirectURL string, domain string) bool {
if redirectURL == "" {
return false
+20 -45
View File
@@ -3,11 +3,8 @@ package utils_test
import (
"testing"
"github.com/tinyauthapp/tinyauth/internal/config"
"github.com/stretchr/testify/assert"
"github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/gin-gonic/gin"
"gotest.tools/v3/assert"
)
func TestGetRootDomain(t *testing.T) {
@@ -15,14 +12,14 @@ func TestGetRootDomain(t *testing.T) {
domain := "http://sub.tinyauth.app"
expected := "tinyauth.app"
result, err := utils.GetCookieDomain(domain)
assert.NilError(t, err)
assert.NoError(t, err)
assert.Equal(t, expected, result)
// Domain with multiple subdomains
domain = "http://b.c.tinyauth.app"
expected = "c.tinyauth.app"
result, err = utils.GetCookieDomain(domain)
assert.NilError(t, err)
assert.NoError(t, err)
assert.Equal(t, expected, result)
// Invalid domain (only TLD)
@@ -44,14 +41,14 @@ func TestGetRootDomain(t *testing.T) {
domain = "https://sub.tinyauth.app/path"
expected = "tinyauth.app"
result, err = utils.GetCookieDomain(domain)
assert.NilError(t, err)
assert.NoError(t, err)
assert.Equal(t, expected, result)
// URL with port
domain = "http://sub.tinyauth.app:8080"
expected = "tinyauth.app"
result, err = utils.GetCookieDomain(domain)
assert.NilError(t, err)
assert.NoError(t, err)
assert.Equal(t, expected, result)
// Domain managed by ICANN
@@ -98,57 +95,35 @@ func TestFilter(t *testing.T) {
testFunc := func(n int) bool { return n%2 == 0 }
expected := []int{2, 4}
result := utils.Filter(slice, testFunc)
assert.DeepEqual(t, expected, result)
assert.Equal(t, expected, result)
// Case with no matches
slice = []int{1, 3, 5}
testFunc = func(n int) bool { return n%2 == 0 }
expected = []int{}
result = utils.Filter(slice, testFunc)
assert.DeepEqual(t, expected, result)
assert.Equal(t, expected, result)
// Case with all matches
slice = []int{2, 4, 6}
testFunc = func(n int) bool { return n%2 == 0 }
expected = []int{2, 4, 6}
result = utils.Filter(slice, testFunc)
assert.DeepEqual(t, expected, result)
assert.Equal(t, expected, result)
// Case with empty slice
slice = []int{}
testFunc = func(n int) bool { return n%2 == 0 }
expected = []int{}
result = utils.Filter(slice, testFunc)
assert.DeepEqual(t, expected, result)
assert.Equal(t, expected, result)
// Case with different type (string)
sliceStr := []string{"apple", "banana", "cherry"}
testFuncStr := func(s string) bool { return len(s) > 5 }
expectedStr := []string{"banana", "cherry"}
resultStr := utils.Filter(sliceStr, testFuncStr)
assert.DeepEqual(t, expectedStr, resultStr)
}
func TestGetContext(t *testing.T) {
// Setup
gin.SetMode(gin.TestMode)
c, _ := gin.CreateTestContext(nil)
// Normal case
c.Set("context", &config.UserContext{Username: "testuser"})
result, err := utils.GetContext(c)
assert.NilError(t, err)
assert.Equal(t, "testuser", result.Username)
// Case with no context
c.Set("context", nil)
_, err = utils.GetContext(c)
assert.Error(t, err, "invalid user context in request")
// Case with invalid context type
c.Set("context", "invalid type")
_, err = utils.GetContext(c)
assert.Error(t, err, "invalid user context in request")
assert.Equal(t, expectedStr, resultStr)
}
func TestIsRedirectSafe(t *testing.T) {
@@ -158,50 +133,50 @@ func TestIsRedirectSafe(t *testing.T) {
// Case with no subdomain
redirectURL := "http://example.com/welcome"
result := utils.IsRedirectSafe(redirectURL, domain)
assert.Equal(t, true, result)
assert.True(t, result)
// Case with different domain
redirectURL = "http://malicious.com/phishing"
result = utils.IsRedirectSafe(redirectURL, domain)
assert.Equal(t, false, result)
assert.False(t, result)
// Case with subdomain
redirectURL = "http://sub.example.com/page"
result = utils.IsRedirectSafe(redirectURL, domain)
assert.Equal(t, true, result)
assert.True(t, result)
// Case with sub-subdomain
redirectURL = "http://a.b.example.com/home"
result = utils.IsRedirectSafe(redirectURL, domain)
assert.Equal(t, true, result)
assert.True(t, result)
// Case with empty redirect URL
redirectURL = ""
result = utils.IsRedirectSafe(redirectURL, domain)
assert.Equal(t, false, result)
assert.False(t, result)
// Case with invalid URL
redirectURL = "http://[::1]:namedport"
result = utils.IsRedirectSafe(redirectURL, domain)
assert.Equal(t, false, result)
assert.False(t, result)
// Case with URL having port
redirectURL = "http://sub.example.com:8080/page"
result = utils.IsRedirectSafe(redirectURL, domain)
assert.Equal(t, true, result)
assert.True(t, result)
// Case with URL having different subdomain
redirectURL = "http://another.example.com/page"
result = utils.IsRedirectSafe(redirectURL, domain)
assert.Equal(t, true, result)
assert.True(t, result)
// Case with URL having different TLD
redirectURL = "http://example.org/page"
result = utils.IsRedirectSafe(redirectURL, domain)
assert.Equal(t, false, result)
assert.False(t, result)
// Case with malicious domain
redirectURL = "https://malicious-example.com/yoyo"
result = utils.IsRedirectSafe(redirectURL, domain)
assert.Equal(t, false, result)
assert.False(t, result)
}
+14 -15
View File
@@ -3,42 +3,41 @@ package decoders_test
import (
"testing"
"github.com/tinyauthapp/tinyauth/internal/config"
"github.com/stretchr/testify/assert"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/decoders"
"gotest.tools/v3/assert"
)
func TestDecodeLabels(t *testing.T) {
// Variables
expected := config.Apps{
Apps: map[string]config.App{
expected := model.Apps{
Apps: map[string]model.App{
"foo": {
Config: config.AppConfig{
Config: model.AppConfig{
Domain: "example.com",
},
Users: config.AppUsers{
Users: model.AppUsers{
Allow: "user1,user2",
Block: "user3",
},
OAuth: config.AppOAuth{
OAuth: model.AppOAuth{
Whitelist: "somebody@example.com",
Groups: "group3",
},
IP: config.AppIP{
IP: model.AppIP{
Allow: []string{"10.71.0.1/24", "10.71.0.2"},
Block: []string{"10.10.10.10", "10.0.0.0/24"},
Bypass: []string{"192.168.1.1"},
},
Response: config.AppResponse{
Response: model.AppResponse{
Headers: []string{"X-Foo=Bar", "X-Baz=Qux"},
BasicAuth: config.AppBasicAuth{
BasicAuth: model.AppBasicAuth{
Username: "admin",
Password: "password",
PasswordFile: "/path/to/passwordfile",
},
},
Path: config.AppPath{
Path: model.AppPath{
Allow: "/public",
Block: "/private",
},
@@ -63,7 +62,7 @@ func TestDecodeLabels(t *testing.T) {
}
// Test
result, err := decoders.DecodeLabels[config.Apps](test, "apps")
assert.NilError(t, err)
assert.DeepEqual(t, expected, result)
result, err := decoders.DecodeLabels[model.Apps](test, "apps")
assert.NoError(t, err)
assert.Equal(t, expected, result)
}
+6 -5
View File
@@ -4,24 +4,25 @@ import (
"os"
"testing"
"gotest.tools/v3/assert"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestReadFile(t *testing.T) {
// Setup
file, err := os.Create("/tmp/tinyauth_test_file")
assert.NilError(t, err)
require.NoError(t, err)
_, err = file.WriteString("file content\n")
assert.NilError(t, err)
require.NoError(t, err)
err = file.Close()
assert.NilError(t, err)
require.NoError(t, err)
defer os.Remove("/tmp/tinyauth_test_file")
// Normal case
content, err := ReadFile("/tmp/tinyauth_test_file")
assert.NilError(t, err)
assert.NoError(t, err)
assert.Equal(t, "file content\n", content)
// Non-existing file
+6 -7
View File
@@ -3,9 +3,8 @@ package utils_test
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/tinyauthapp/tinyauth/internal/utils"
"gotest.tools/v3/assert"
)
func TestParseHeaders(t *testing.T) {
@@ -18,7 +17,7 @@ func TestParseHeaders(t *testing.T) {
"X-Custom-Header": "Value",
"Another-Header": "AnotherValue",
}
assert.DeepEqual(t, expected, utils.ParseHeaders(headers))
assert.Equal(t, expected, utils.ParseHeaders(headers))
// Case insensitivity and trimming
headers = []string{
@@ -29,7 +28,7 @@ func TestParseHeaders(t *testing.T) {
"X-Custom-Header": "Value",
"Another-Header": "AnotherValue",
}
assert.DeepEqual(t, expected, utils.ParseHeaders(headers))
assert.Equal(t, expected, utils.ParseHeaders(headers))
// Invalid headers (missing '=', empty key/value)
headers = []string{
@@ -39,7 +38,7 @@ func TestParseHeaders(t *testing.T) {
" = ",
}
expected = map[string]string{}
assert.DeepEqual(t, expected, utils.ParseHeaders(headers))
assert.Equal(t, expected, utils.ParseHeaders(headers))
// Headers with unsafe characters
headers = []string{
@@ -52,7 +51,7 @@ func TestParseHeaders(t *testing.T) {
"Another-Header": "AnotherValue",
"Good-Header": "GoodValue",
}
assert.DeepEqual(t, expected, utils.ParseHeaders(headers))
assert.Equal(t, expected, utils.ParseHeaders(headers))
// Header with spaces in key (should be ignored)
headers = []string{
@@ -62,7 +61,7 @@ func TestParseHeaders(t *testing.T) {
expected = map[string]string{
"Valid-Header": "ValidValue",
}
assert.DeepEqual(t, expected, utils.ParseHeaders(headers))
assert.Equal(t, expected, utils.ParseHeaders(headers))
}
func TestSanitizeHeader(t *testing.T) {
+3 -4
View File
@@ -4,21 +4,20 @@ import (
"fmt"
"os"
"github.com/tinyauthapp/tinyauth/internal/config"
"github.com/tinyauthapp/paerser/cli"
"github.com/tinyauthapp/paerser/env"
"github.com/tinyauthapp/tinyauth/internal/model"
)
type EnvLoader struct{}
func (e *EnvLoader) Load(_ []string, cmd *cli.Command) (bool, error) {
vars := env.FindPrefixedEnvVars(os.Environ(), config.DefaultNamePrefix, cmd.Configuration)
vars := env.FindPrefixedEnvVars(os.Environ(), model.DefaultNamePrefix, cmd.Configuration)
if len(vars) == 0 {
return false, nil
}
if err := env.Decode(vars, config.DefaultNamePrefix, cmd.Configuration); err != nil {
if err := env.Decode(vars, model.DefaultNamePrefix, cmd.Configuration); err != nil {
return false, fmt.Errorf("failed to decode configuration from environment variables: %w", err)
}
+1 -1
View File
@@ -41,7 +41,7 @@ func ParseSecretFile(contents string) string {
return ""
}
func GetBasicAuth(username string, password string) string {
func EncodeBasicAuth(username string, password string) string {
auth := username + ":" + password
return base64.StdEncoding.EncodeToString([]byte(auth))
}
+15 -15
View File
@@ -4,21 +4,21 @@ import (
"os"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/utils"
"gotest.tools/v3/assert"
)
func TestGetSecret(t *testing.T) {
// Setup
file, err := os.Create("/tmp/tinyauth_test_secret")
assert.NilError(t, err)
require.NoError(t, err)
_, err = file.WriteString(" secret \n")
assert.NilError(t, err)
require.NoError(t, err)
err = file.Close()
assert.NilError(t, err)
require.NoError(t, err)
defer os.Remove("/tmp/tinyauth_test_secret")
// Get from config
@@ -55,50 +55,50 @@ func TestParseSecretFile(t *testing.T) {
assert.Equal(t, "", utils.ParseSecretFile(content))
}
func TestGetBasicAuth(t *testing.T) {
func TestEncodeBasicAuth(t *testing.T) {
// Normal case
username := "user"
password := "pass"
expected := "dXNlcjpwYXNz" // base64 of "user:pass"
assert.Equal(t, expected, utils.GetBasicAuth(username, password))
assert.Equal(t, expected, utils.EncodeBasicAuth(username, password))
// Empty username
username = ""
password = "pass"
expected = "OnBhc3M=" // base64 of ":pass"
assert.Equal(t, expected, utils.GetBasicAuth(username, password))
assert.Equal(t, expected, utils.EncodeBasicAuth(username, password))
// Empty password
username = "user"
password = ""
expected = "dXNlcjo=" // base64 of "user:"
assert.Equal(t, expected, utils.GetBasicAuth(username, password))
assert.Equal(t, expected, utils.EncodeBasicAuth(username, password))
}
func TestFilterIP(t *testing.T) {
// Exact match IPv4
ok, err := utils.FilterIP("10.10.0.1", "10.10.0.1")
assert.NilError(t, err)
assert.NoError(t, err)
assert.Equal(t, true, ok)
// Non-match IPv4
ok, err = utils.FilterIP("10.10.0.1", "10.10.0.2")
assert.NilError(t, err)
assert.NoError(t, err)
assert.Equal(t, false, ok)
// CIDR match IPv4
ok, err = utils.FilterIP("10.10.0.0/24", "10.10.0.2")
assert.NilError(t, err)
assert.NoError(t, err)
assert.Equal(t, true, ok)
// CIDR match IPv4 with '-' instead of '/'
ok, err = utils.FilterIP("10.10.10.0-24", "10.10.10.5")
assert.NilError(t, err)
assert.NoError(t, err)
assert.Equal(t, true, ok)
// CIDR non-match IPv4
ok, err = utils.FilterIP("10.10.0.0/24", "10.5.0.1")
assert.NilError(t, err)
assert.NoError(t, err)
assert.Equal(t, false, ok)
// Invalid CIDR
@@ -145,5 +145,5 @@ func TestGenerateUUID(t *testing.T) {
// Different output for different input
id3 := utils.GenerateUUID("differentstring")
assert.Assert(t, id1 != id3)
assert.NotEqual(t, id2, id3)
}
+1 -2
View File
@@ -3,9 +3,8 @@ package utils_test
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/tinyauthapp/tinyauth/internal/utils"
"gotest.tools/v3/assert"
)
func TestCapitalize(t *testing.T) {
+13 -13
View File
@@ -7,7 +7,7 @@ import (
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"github.com/tinyauthapp/tinyauth/internal/config"
"github.com/tinyauthapp/tinyauth/internal/model"
)
type Logger struct {
@@ -22,7 +22,7 @@ var (
App zerolog.Logger
)
func NewLogger(cfg config.LogConfig) *Logger {
func NewLogger(cfg model.LogConfig) *Logger {
baseLogger := log.With().
Timestamp().
Caller().
@@ -44,24 +44,24 @@ func NewLogger(cfg config.LogConfig) *Logger {
}
func NewSimpleLogger() *Logger {
return NewLogger(config.LogConfig{
return NewLogger(model.LogConfig{
Level: "info",
Json: false,
Streams: config.LogStreams{
HTTP: config.LogStreamConfig{Enabled: true},
App: config.LogStreamConfig{Enabled: true},
Audit: config.LogStreamConfig{Enabled: false},
Streams: model.LogStreams{
HTTP: model.LogStreamConfig{Enabled: true},
App: model.LogStreamConfig{Enabled: true},
Audit: model.LogStreamConfig{Enabled: false},
},
})
}
func NewTestLogger() *Logger {
return NewLogger(config.LogConfig{
return NewLogger(model.LogConfig{
Level: "trace",
Streams: config.LogStreams{
HTTP: config.LogStreamConfig{Enabled: true},
App: config.LogStreamConfig{Enabled: true},
Audit: config.LogStreamConfig{Enabled: true},
Streams: model.LogStreams{
HTTP: model.LogStreamConfig{Enabled: true},
App: model.LogStreamConfig{Enabled: true},
Audit: model.LogStreamConfig{Enabled: true},
},
})
}
@@ -72,7 +72,7 @@ func (l *Logger) Init() {
App = l.App
}
func createLogger(component string, streamCfg config.LogStreamConfig, baseLogger zerolog.Logger) zerolog.Logger {
func createLogger(component string, streamCfg model.LogStreamConfig, baseLogger zerolog.Logger) zerolog.Logger {
if !streamCfg.Enabled {
return zerolog.Nop()
}
+30 -30
View File
@@ -5,75 +5,75 @@ import (
"encoding/json"
"testing"
"github.com/tinyauthapp/tinyauth/internal/config"
"github.com/stretchr/testify/assert"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/rs/zerolog"
"gotest.tools/v3/assert"
)
func TestNewLogger(t *testing.T) {
cfg := config.LogConfig{
cfg := model.LogConfig{
Level: "debug",
Json: true,
Streams: config.LogStreams{
HTTP: config.LogStreamConfig{Enabled: true, Level: "info"},
App: config.LogStreamConfig{Enabled: true, Level: ""},
Audit: config.LogStreamConfig{Enabled: false, Level: ""},
Streams: model.LogStreams{
HTTP: model.LogStreamConfig{Enabled: true, Level: "info"},
App: model.LogStreamConfig{Enabled: true, Level: ""},
Audit: model.LogStreamConfig{Enabled: false, Level: ""},
},
}
logger := tlog.NewLogger(cfg)
assert.Assert(t, logger != nil)
assert.Assert(t, logger.HTTP.GetLevel() == zerolog.InfoLevel)
assert.Assert(t, logger.App.GetLevel() == zerolog.DebugLevel)
assert.Assert(t, logger.Audit.GetLevel() == zerolog.Disabled)
assert.NotNil(t, logger)
assert.Equal(t, zerolog.InfoLevel, logger.HTTP.GetLevel())
assert.Equal(t, zerolog.DebugLevel, logger.App.GetLevel())
assert.Equal(t, zerolog.Disabled, logger.Audit.GetLevel())
}
func TestNewSimpleLogger(t *testing.T) {
logger := tlog.NewSimpleLogger()
assert.Assert(t, logger != nil)
assert.Assert(t, logger.HTTP.GetLevel() == zerolog.InfoLevel)
assert.Assert(t, logger.App.GetLevel() == zerolog.InfoLevel)
assert.Assert(t, logger.Audit.GetLevel() == zerolog.Disabled)
assert.NotNil(t, logger)
assert.Equal(t, zerolog.InfoLevel, logger.HTTP.GetLevel())
assert.Equal(t, zerolog.InfoLevel, logger.App.GetLevel())
assert.Equal(t, zerolog.Disabled, logger.Audit.GetLevel())
}
func TestLoggerInit(t *testing.T) {
logger := tlog.NewSimpleLogger()
logger.Init()
assert.Assert(t, tlog.App.GetLevel() != zerolog.Disabled)
assert.NotEqual(t, zerolog.Disabled, tlog.App.GetLevel())
}
func TestLoggerWithDisabledStreams(t *testing.T) {
cfg := config.LogConfig{
cfg := model.LogConfig{
Level: "info",
Json: false,
Streams: config.LogStreams{
HTTP: config.LogStreamConfig{Enabled: false},
App: config.LogStreamConfig{Enabled: false},
Audit: config.LogStreamConfig{Enabled: false},
Streams: model.LogStreams{
HTTP: model.LogStreamConfig{Enabled: false},
App: model.LogStreamConfig{Enabled: false},
Audit: model.LogStreamConfig{Enabled: false},
},
}
logger := tlog.NewLogger(cfg)
assert.Assert(t, logger.HTTP.GetLevel() == zerolog.Disabled)
assert.Assert(t, logger.App.GetLevel() == zerolog.Disabled)
assert.Assert(t, logger.Audit.GetLevel() == zerolog.Disabled)
assert.Equal(t, zerolog.Disabled, logger.HTTP.GetLevel())
assert.Equal(t, zerolog.Disabled, logger.App.GetLevel())
assert.Equal(t, zerolog.Disabled, logger.Audit.GetLevel())
}
func TestLogStreamField(t *testing.T) {
var buf bytes.Buffer
cfg := config.LogConfig{
cfg := model.LogConfig{
Level: "info",
Json: true,
Streams: config.LogStreams{
HTTP: config.LogStreamConfig{Enabled: true},
App: config.LogStreamConfig{Enabled: true},
Audit: config.LogStreamConfig{Enabled: true},
Streams: model.LogStreams{
HTTP: model.LogStreamConfig{Enabled: true},
App: model.LogStreamConfig{Enabled: true},
Audit: model.LogStreamConfig{Enabled: true},
},
}
@@ -86,7 +86,7 @@ func TestLogStreamField(t *testing.T) {
var logEntry map[string]interface{}
err := json.Unmarshal(buf.Bytes(), &logEntry)
assert.NilError(t, err)
assert.NoError(t, err)
assert.Equal(t, "http", logEntry["log_stream"])
assert.Equal(t, "test message", logEntry["message"])
+16 -16
View File
@@ -6,14 +6,14 @@ import (
"net/mail"
"strings"
"github.com/tinyauthapp/tinyauth/internal/config"
"github.com/tinyauthapp/tinyauth/internal/model"
)
func ParseUsers(usersStr []string, userAttributes map[string]config.UserAttributes) ([]config.User, error) {
var users []config.User
func ParseUsers(usersStr []string, userAttributes map[string]model.UserAttributes) (*[]model.LocalUser, error) {
var users []model.LocalUser
if len(usersStr) == 0 {
return []config.User{}, nil
return &users, nil
}
for _, user := range usersStr {
@@ -22,22 +22,22 @@ func ParseUsers(usersStr []string, userAttributes map[string]config.UserAttribut
}
parsed, err := ParseUser(strings.TrimSpace(user))
if err != nil {
return []config.User{}, err
return nil, err
}
if attrs, ok := userAttributes[parsed.Username]; ok {
parsed.Attributes = attrs
}
users = append(users, parsed)
users = append(users, *parsed)
}
return users, nil
return &users, nil
}
func GetUsers(usersCfg []string, usersPath string, userAttributes map[string]config.UserAttributes) ([]config.User, error) {
func GetUsers(usersCfg []string, usersPath string, userAttributes map[string]model.UserAttributes) (*[]model.LocalUser, error) {
var usersStr []string
if len(usersCfg) == 0 && usersPath == "" {
return []config.User{}, nil
return nil, nil
}
if len(usersCfg) > 0 {
@@ -48,7 +48,7 @@ func GetUsers(usersCfg []string, usersPath string, userAttributes map[string]con
contents, err := ReadFile(usersPath)
if err != nil {
return []config.User{}, err
return nil, err
}
lines := strings.SplitSeq(contents, "\n")
@@ -65,7 +65,7 @@ func GetUsers(usersCfg []string, usersPath string, userAttributes map[string]con
return ParseUsers(usersStr, userAttributes)
}
func ParseUser(userStr string) (config.User, error) {
func ParseUser(userStr string) (*model.LocalUser, error) {
if strings.Contains(userStr, "$$") {
userStr = strings.ReplaceAll(userStr, "$$", "$")
}
@@ -73,27 +73,27 @@ func ParseUser(userStr string) (config.User, error) {
parts := strings.SplitN(userStr, ":", 4)
if len(parts) < 2 || len(parts) > 3 {
return config.User{}, errors.New("invalid user format")
return nil, errors.New("invalid user format")
}
for i, part := range parts {
trimmed := strings.TrimSpace(part)
if trimmed == "" {
return config.User{}, errors.New("invalid user format")
return nil, errors.New("invalid user format")
}
parts[i] = trimmed
}
user := config.User{
user := model.LocalUser{
Username: parts[0],
Password: parts[1],
}
if len(parts) == 3 {
user.TotpSecret = parts[2]
user.TOTPSecret = parts[2]
}
return user, nil
return &user, nil
}
func CompileUserEmail(username string, domain string) string {
+47 -47
View File
@@ -4,74 +4,76 @@ import (
"os"
"testing"
"github.com/tinyauthapp/tinyauth/internal/config"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils"
"gotest.tools/v3/assert"
)
func TestGetUsers(t *testing.T) {
tmpDir := t.TempDir()
hash := "$2a$10$Mz5xhkfSJUtPWkzCd/TdaePh9CaXc5QcGII5wIMPLSR46eTwma30G"
// Setup
file, err := os.Create("/tmp/tinyauth_users_test.txt")
assert.NilError(t, err)
file, err := os.Create(tmpDir + "/tinyauth_users_test.txt")
require.NoError(t, err)
_, err = file.WriteString(" user1:" + hash + " \n user2:" + hash + " ") // Spacing is on purpose
assert.NilError(t, err)
require.NoError(t, err)
err = file.Close()
assert.NilError(t, err)
defer os.Remove("/tmp/tinyauth_users_test.txt")
require.NoError(t, err)
defer os.Remove(tmpDir + "/tinyauth_users_test.txt")
noAttrs := map[string]config.UserAttributes{}
noAttrs := map[string]model.UserAttributes{}
// Test file only
users, err := utils.GetUsers([]string{}, "/tmp/tinyauth_users_test.txt", noAttrs)
users, err := utils.GetUsers([]string{}, tmpDir+"/tinyauth_users_test.txt", noAttrs)
assert.NilError(t, err)
assert.NoError(t, err)
assert.NotNil(t, users)
assert.Len(t, *users, 2)
assert.Equal(t, 2, len(users))
assert.Equal(t, "user1", users[0].Username)
assert.Equal(t, hash, users[0].Password)
assert.Equal(t, "user2", users[1].Username)
assert.Equal(t, hash, users[1].Password)
assert.Equal(t, "user1", (*users)[0].Username)
assert.Equal(t, hash, (*users)[0].Password)
assert.Equal(t, "user2", (*users)[1].Username)
assert.Equal(t, hash, (*users)[1].Password)
// Test inline config only
users, err = utils.GetUsers([]string{"user3:" + hash, "user4:" + hash}, "", noAttrs)
assert.NilError(t, err)
assert.NoError(t, err)
assert.Equal(t, 2, len(users))
assert.Equal(t, "user3", users[0].Username)
assert.Equal(t, "user4", users[1].Username)
assert.Len(t, *users, 2)
assert.Equal(t, "user3", (*users)[0].Username)
assert.Equal(t, "user4", (*users)[1].Username)
// Test both
users, err = utils.GetUsers([]string{"user5:" + hash}, "/tmp/tinyauth_users_test.txt", noAttrs)
users, err = utils.GetUsers([]string{"user5:" + hash}, tmpDir+"/tinyauth_users_test.txt", noAttrs)
assert.NilError(t, err)
assert.NoError(t, err)
assert.Equal(t, 3, len(users))
assert.Len(t, *users, 3)
usernames := map[string]bool{}
for _, u := range users {
for _, u := range *users {
usernames[u.Username] = true
}
assert.Assert(t, usernames["user1"])
assert.Assert(t, usernames["user2"])
assert.Assert(t, usernames["user5"])
assert.True(t, usernames["user1"])
assert.True(t, usernames["user2"])
assert.True(t, usernames["user5"])
// Test attributes applied from userAttributes map
attrs := map[string]config.UserAttributes{
attrs := map[string]model.UserAttributes{
"user1": {Name: "User One", Email: "user1@example.com"},
}
users, err = utils.GetUsers([]string{}, "/tmp/tinyauth_users_test.txt", attrs)
users, err = utils.GetUsers([]string{}, tmpDir+"/tinyauth_users_test.txt", attrs)
assert.NilError(t, err)
assert.Equal(t, 2, len(users))
assert.NoError(t, err)
assert.Len(t, *users, 2)
for _, u := range users {
for _, u := range *users {
if u.Username == "user1" {
assert.Equal(t, "User One", u.Attributes.Name)
assert.Equal(t, "user1@example.com", u.Attributes.Email)
@@ -84,16 +86,14 @@ func TestGetUsers(t *testing.T) {
// Test empty
users, err = utils.GetUsers([]string{}, "", noAttrs)
assert.NilError(t, err)
assert.Equal(t, 0, len(users))
assert.NoError(t, err)
assert.Nil(t, users)
// Test non-existent file
users, err = utils.GetUsers([]string{}, "/tmp/non_existent_file.txt", noAttrs)
users, err = utils.GetUsers([]string{}, tmpDir+"/non_existent_file.txt", noAttrs)
assert.ErrorContains(t, err, "no such file or directory")
assert.Equal(t, 0, len(users))
assert.Nil(t, users)
}
func TestParseUser(t *testing.T) {
@@ -102,38 +102,38 @@ func TestParseUser(t *testing.T) {
// Valid user without TOTP
user, err := utils.ParseUser("user1:" + hash)
assert.NilError(t, err)
assert.NoError(t, err)
assert.Equal(t, "user1", user.Username)
assert.Equal(t, hash, user.Password)
assert.Equal(t, "", user.TotpSecret)
assert.Equal(t, "", user.TOTPSecret)
// Valid user with TOTP
user, err = utils.ParseUser("user2:" + hash + ":ABCDEF")
assert.NilError(t, err)
assert.NoError(t, err)
assert.Equal(t, "user2", user.Username)
assert.Equal(t, hash, user.Password)
assert.Equal(t, "ABCDEF", user.TotpSecret)
assert.Equal(t, "ABCDEF", user.TOTPSecret)
// Valid user with $$ in password
user, err = utils.ParseUser("user3:pa$$word123")
assert.NilError(t, err)
assert.NoError(t, err)
assert.Equal(t, "user3", user.Username)
assert.Equal(t, "pa$word123", user.Password)
assert.Equal(t, "", user.TotpSecret)
assert.Equal(t, "", user.TOTPSecret)
// User with spaces
user, err = utils.ParseUser(" user4 : password123 : TOTPSECRET ")
assert.NilError(t, err)
assert.NoError(t, err)
assert.Equal(t, "user4", user.Username)
assert.Equal(t, "password123", user.Password)
assert.Equal(t, "TOTPSECRET", user.TotpSecret)
assert.Equal(t, "TOTPSECRET", user.TOTPSecret)
// Invalid users
_, err = utils.ParseUser("user1") // Missing password
+4 -4
View File
@@ -1,12 +1,12 @@
version: "2"
sql:
- engine: "sqlite"
queries: "sql/sqlite/*_queries.sql"
schema: "sql/sqlite/*_schemas.sql"
queries: "sql/*_queries.sql"
schema: "sql/*_schemas.sql"
gen:
go:
package: "sqlite"
out: "internal/repository/sqlite"
package: "repository"
out: "internal/repository"
rename:
uuid: "UUID"
oauth_groups: "OAuthGroups"