Compare commits

..

7 Commits

Author SHA1 Message Date
Scott McKendry e2fd60529c test(db): add memory driver tests 2026-05-11 17:29:43 +12:00
Scott McKendry 68f04ec66a refactor(db): cleanup sqlc-wrapper gen 2026-05-11 17:29:43 +12:00
Scott McKendry 36bfcd45c1 feat(db): add memory storage driver
removes the sqlite dependency for tests, also brings back the option for
users to run zero persistence instances of tinyauth.

adds new mapErr fn for sqlc wrapper gen to prevent sql errors from
leaking out of the store implementation.
2026-05-11 17:29:43 +12:00
Scott McKendry 03cebb8dba feat(db): add code gen to build sqlc-compatible wrappers 2026-05-11 17:25:23 +12:00
Scott McKendry ad6751df2a refactor(db): use new store interface 2026-05-11 17:24:56 +12:00
dependabot[bot] a6351790c3 chore(deps): bump github/codeql-action from 4.35.3 to 4.35.4 (#842)
Bumps [github/codeql-action](https://github.com/github/codeql-action) from 4.35.3 to 4.35.4.
- [Release notes](https://github.com/github/codeql-action/releases)
- [Changelog](https://github.com/github/codeql-action/blob/main/CHANGELOG.md)
- [Commits](https://github.com/github/codeql-action/compare/e46ed2cbd01164d986452f91f178727624ae40d7...68bde559dea0fdcac2102bfdf6230c5f70eb485e)

---
updated-dependencies:
- dependency-name: github/codeql-action
  dependency-version: 4.35.4
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-05-10 16:36:01 +03:00
Stavros 4f7335ed73 refactor: rework app logging, dependency injection and cancellation (#844)
* feat: add new logger

* refactor: use one struct for context handling and cancellation

* refactor: rework logging and config in controllers

* refactor: rework logging and config in middlewares

* refactor: rework logging and cancellation in services

* refactor: rework cli logging

* fix: improve logging in routines

* feat: use sync groups for better cancellation

* refactor: simplify middleware, controller and service init

* tests: fix controller tests

* tests: use require instead of assert where previous step is required

* tests: fix middleware tests

* tests: fix service tests

* tests: fix context tests

* fix: fix typos

* feat: add option to enable or disable concurrent listeners

* fix: assign public key correctly in oidc server

* tests: fix don't try to test logger with char size

* fix: coderabbit comments

* tests: use filepath join instead of path join

* fix: ensure unix socket shutdown doesn't run twice

* chore: remove temp lint file
2026-05-10 16:10:36 +03:00
99 changed files with 3686 additions and 1813 deletions
+6 -5
View File
@@ -26,6 +26,12 @@ jobs:
- name: Go dependencies - name: Go dependencies
run: go mod download run: go mod download
- name: Check codegen is up to date
run: |
go generate ./internal/repository/...
git diff --exit-code -- internal/repository/
git status --porcelain -- internal/repository/ | grep -q . && echo "untracked files in internal/repository/" && exit 1 || true
- name: Install frontend dependencies - name: Install frontend dependencies
run: | run: |
cd frontend cd frontend
@@ -49,11 +55,6 @@ jobs:
run: | run: |
cp -r frontend/dist internal/assets/dist cp -r frontend/dist internal/assets/dist
- name: Lint backend
uses: golangci/golangci-lint-action@v9
with:
version: v2.12
- name: Run tests - name: Run tests
run: go test -coverprofile=coverage.txt -v ./... run: go test -coverprofile=coverage.txt -v ./...
+1 -1
View File
@@ -38,6 +38,6 @@ jobs:
retention-days: 5 retention-days: 5
- name: Upload to code-scanning - name: Upload to code-scanning
uses: github/codeql-action/upload-sarif@e46ed2cbd01164d986452f91f178727624ae40d7 # v4 uses: github/codeql-action/upload-sarif@68bde559dea0fdcac2102bfdf6230c5f70eb485e # v4
with: with:
sarif_file: results.sarif sarif_file: results.sarif
-17
View File
@@ -1,17 +0,0 @@
version: "2"
linters:
settings:
errcheck:
exclude-functions:
- (http.ResponseWriter).Write
- (http.ResponseWriter).WriteString
- (github.com/gin-gonic/gin.ResponseWriter).Write
- (github.com/gin-gonic/gin.ResponseWriter).WriteString
exclusions:
rules:
- linters:
- errcheck
text: "//nolint:errcheck"
- linters:
- staticcheck
text: "//nolint:staticcheck"
+1 -1
View File
@@ -84,4 +84,4 @@ sql:
# Go gen # Go gen
generate: generate:
go run ./gen go generate ./internal/repository/...
+527
View File
@@ -0,0 +1,527 @@
// gen/sqlc-wrapper generates store.go wrapper files for each sqlc driver package under
// internal/repository/<driver>/. Run via:
//
// go generate ./internal/repository/...
//
// The generator introspects *Queries methods and the model/params types in the
// driver package, then emits a store.go that wraps *Queries so it satisfies
// repository.Store using the canonical shared types in the parent package.
// This generator is specific to sqlc-generated drivers. Non-sqlc drivers should
// implement repository.Store directly by hand.
package main
import (
"bytes"
_ "embed"
"flag"
"fmt"
"go/format"
"go/types"
"log"
"os"
"os/exec"
"path/filepath"
"sort"
"strings"
"text/template"
"golang.org/x/tools/go/packages"
)
//go:embed store.tmpl
var storeSrc string
func main() {
if err := run(); err != nil {
log.Fatal(err)
}
}
func run() error {
driverPkg := flag.String("pkg", "", "import path of the driver package")
out := flag.String("out", "store.go", "output filename relative to driver package directory")
flag.Parse()
if *driverPkg == "" {
return fmt.Errorf("-pkg is required")
}
// Resolve the driver package directory so we can overlay the output file
// with a valid stub. This prevents a stale store.go from poisoning the
// type-checker and producing cryptic "undefined" errors.
driverDir, err := pkgDir(*driverPkg)
if err != nil {
return fmt.Errorf("resolve driver dir: %w", err)
}
outPath := filepath.Join(driverDir, *out)
if filepath.IsAbs(*out) {
outPath = *out
}
// Stub replaces the output file during load so stale generated code is ignored.
stub := []byte("package " + filepath.Base(driverDir) + "\n")
cfg := &packages.Config{
Mode: packages.NeedName | packages.NeedTypes | packages.NeedSyntax | packages.NeedImports,
Overlay: map[string][]byte{outPath: stub},
}
driverTypePkg, err := loadOnePkg(cfg, *driverPkg)
if err != nil {
return fmt.Errorf("load driver package: %w", err)
}
repoPkgPath := parentPkg(*driverPkg)
repoTypePkg, err := loadOnePkg(cfg, repoPkgPath)
if err != nil {
return fmt.Errorf("load repo package: %w", err)
}
if err := validateStructShapes(driverTypePkg, repoTypePkg); err != nil {
return fmt.Errorf("struct shape mismatch: %w", err)
}
if err := validateStoreCoverage(driverTypePkg, repoTypePkg); err != nil {
return err
}
methods, err := collectMethods(driverTypePkg)
if err != nil {
return err
}
models, _ := collectTypes(driverTypePkg)
src, err := render(tmplData{
PkgName: driverTypePkg.Name(),
RepoPkg: repoPkgPath,
ModelTypes: models,
Methods: renderMethods(methods),
})
if err != nil {
return fmt.Errorf("render: %w", err)
}
if err := os.WriteFile(outPath, src, 0644); err != nil {
return fmt.Errorf("write %s: %w", outPath, err)
}
fmt.Printf("wrote %s\n", outPath)
return nil
}
// loadOnePkg loads a single package via cfg and returns its *types.Package,
// or an error if the package fails to load or has type errors.
func loadOnePkg(cfg *packages.Config, importPath string) (*types.Package, error) {
pkgs, err := packages.Load(cfg, importPath)
if err != nil {
return nil, fmt.Errorf("load %s: %w", importPath, err)
}
if len(pkgs) != 1 {
return nil, fmt.Errorf("expected 1 package for %s, got %d", importPath, len(pkgs))
}
pkg := pkgs[0]
if len(pkg.Errors) > 0 {
msgs := make([]string, len(pkg.Errors))
for i, e := range pkg.Errors {
msgs[i] = e.Error()
}
return nil, fmt.Errorf("package %s has errors:\n %s", importPath, strings.Join(msgs, "\n "))
}
return pkg.Types, nil
}
// parentPkg returns the parent import path (everything before the last /).
// Panics if imp contains no slash — callers are expected to pass driver sub-packages.
func parentPkg(imp string) string {
i := strings.LastIndex(imp, "/")
if i < 0 {
panic(fmt.Sprintf("parentPkg: import path %q has no parent", imp))
}
return imp[:i]
}
// pkgDir returns the on-disk directory for an import path using `go list`.
func pkgDir(importPath string) (string, error) {
out, err := exec.Command("go", "list", "-f", "{{.Dir}}", importPath).Output()
if err != nil {
return "", fmt.Errorf("go list %s: %w", importPath, err)
}
return strings.TrimSpace(string(out)), nil
}
// scopeStructs returns all named struct types in pkg, excluding the internal
// sqlc types Queries, DBTX, and Store. Names are returned in sorted order.
func scopeStructs(pkg *types.Package) (names []string, byName map[string]*types.Struct) {
byName = make(map[string]*types.Struct)
for _, name := range pkg.Scope().Names() { // Names() is already sorted
switch name {
case "Queries", "DBTX", "Store":
continue
}
obj, ok := pkg.Scope().Lookup(name).(*types.TypeName)
if !ok {
continue
}
named, ok := obj.Type().(*types.Named)
if !ok {
continue
}
s, ok := named.Underlying().(*types.Struct)
if !ok {
continue
}
names = append(names, name)
byName[name] = s
}
return
}
// validateStoreCoverage checks that every method declared in repository.Store
// exists on *Queries in the driver package. Missing methods are reported by
// name so the developer knows exactly which SQL queries need to be added.
func validateStoreCoverage(driverPkg, repoPkg *types.Package) error {
queriesObj := driverPkg.Scope().Lookup("Queries")
if queriesObj == nil {
return fmt.Errorf("queries type not found in driver package")
}
queriesNamed := queriesObj.Type().(*types.Named)
queriesMS := types.NewMethodSet(types.NewPointer(queriesNamed))
queriesMethods := make(map[string]bool)
for m := range queriesMS.Methods() {
queriesMethods[m.Obj().Name()] = true
}
storeObj := repoPkg.Scope().Lookup("Store")
if storeObj == nil {
return fmt.Errorf("store type not found in repository package")
}
storeIface, ok := storeObj.Type().Underlying().(*types.Interface)
if !ok {
return fmt.Errorf("repository.Store is not an interface")
}
var missing []string
for method := range storeIface.Methods() {
if name := method.Name(); !queriesMethods[name] {
missing = append(missing, name)
}
}
if len(missing) > 0 {
sort.Strings(missing)
return fmt.Errorf(
"driver *Queries is missing %d method(s) required by repository.Store:\n - %s\n\nRun sqlc generate to regenerate query methods, or add the missing SQL queries",
len(missing), strings.Join(missing, "\n - "),
)
}
return nil
}
// validateStructShapes checks that every model/params struct in the driver
// package has fields that exactly match the corresponding type in the repo
// (parent) package. This catches drift between sqlc-generated types and the
// canonical repository types before a broken cast reaches the compiler.
func validateStructShapes(driverPkg, repoPkg *types.Package) error {
_, repoStructs := scopeStructs(repoPkg)
driverNames, driverStructs := scopeStructs(driverPkg)
var errs []string
for _, name := range driverNames {
repoStruct, ok := repoStructs[name]
if !ok {
// Driver has a type not in repo — fine (e.g. internal helpers).
continue
}
if err := compareStructs(name, driverStructs[name], repoStruct); err != nil {
errs = append(errs, err.Error())
}
}
if len(errs) > 0 {
sort.Strings(errs)
return fmt.Errorf("%s", strings.Join(errs, "\n "))
}
return nil
}
func compareStructs(name string, driver, repo *types.Struct) error {
if driver.NumFields() != repo.NumFields() {
return fmt.Errorf("%s: field count mismatch (driver=%d, repo=%d)",
name, driver.NumFields(), repo.NumFields())
}
for i := range driver.NumFields() {
df := driver.Field(i)
rf := repo.Field(i)
if df.Name() != rf.Name() {
return fmt.Errorf("%s: field %d name mismatch (driver=%q, repo=%q)",
name, i, df.Name(), rf.Name())
}
if !types.Identical(df.Type(), rf.Type()) {
return fmt.Errorf("%s.%s: type mismatch (driver=%s, repo=%s)",
name, df.Name(), df.Type(), rf.Type())
}
}
return nil
}
// collectTypes returns model and params struct names from the driver package.
func collectTypes(pkg *types.Package) (models []string, params []string) {
names, _ := scopeStructs(pkg)
for _, name := range names {
if strings.HasSuffix(name, "Params") {
params = append(params, name)
} else {
models = append(models, name)
}
}
return
}
type methodInfo struct {
Name string
Params []paramInfo
Results []resultInfo
}
type paramInfo struct {
Name string
TypeStr string // local (unqualified) type name
RepoType string // "repository.X" if this is a driver model/params type; else ""
}
type resultInfo struct {
TypeStr string
IsSlice bool
RepoType string // "repository.X" if driver type; else ""
}
func collectMethods(pkg *types.Package) ([]methodInfo, error) {
obj := pkg.Scope().Lookup("Queries")
if obj == nil {
return nil, fmt.Errorf("queries type not found in %s", pkg.Path())
}
named, ok := obj.Type().(*types.Named)
if !ok {
return nil, fmt.Errorf("queries is not a named type")
}
ms := types.NewMethodSet(types.NewPointer(named))
var out []methodInfo
for method := range ms.Methods() {
fn, ok := method.Obj().(*types.Func)
if !ok || fn.Name() == "WithTx" {
continue
}
sig := fn.Type().(*types.Signature)
mi := methodInfo{Name: fn.Name()}
// params: skip receiver + first (context.Context)
for i := 1; i < sig.Params().Len(); i++ {
p := sig.Params().At(i)
mi.Params = append(mi.Params, makeParam(p.Name(), p.Type(), pkg.Path()))
}
// results: skip error
for r := range sig.Results().Variables() {
if r.Type().String() == "error" {
continue
}
mi.Results = append(mi.Results, makeResult(r.Type(), pkg.Path()))
}
out = append(out, mi)
}
return out, nil
}
func makeParam(name string, t types.Type, driverPath string) paramInfo {
return paramInfo{
Name: name,
TypeStr: localName(t, driverPath),
RepoType: repoName(t, driverPath),
}
}
func makeResult(t types.Type, driverPath string) resultInfo {
ri := resultInfo{}
if sl, ok := t.(*types.Slice); ok {
ri.IsSlice = true
t = sl.Elem()
}
ri.TypeStr = localName(t, driverPath)
ri.RepoType = repoName(t, driverPath)
return ri
}
func localName(t types.Type, driverPath string) string {
named, ok := t.(*types.Named)
if !ok {
return types.TypeString(t, nil)
}
if named.Obj().Pkg() != nil && named.Obj().Pkg().Path() == driverPath {
return named.Obj().Name()
}
return types.TypeString(t, func(p *types.Package) string { return p.Name() })
}
func repoName(t types.Type, driverPath string) string {
named, ok := t.(*types.Named)
if !ok {
return ""
}
if named.Obj().Pkg() != nil && named.Obj().Pkg().Path() == driverPath {
return "repository." + named.Obj().Name()
}
return ""
}
// converterFn maps a type name to its converter function name: "Session" → "sessionToRepo".
func converterFn(s string) string {
if s == "" {
return ""
}
return strings.ToLower(s[:1]) + s[1:] + "ToRepo"
}
// renderedMethod holds pre-built signature and body strings passed to the template.
type renderedMethod struct {
Signature string
Body string
}
func renderMethods(methods []methodInfo) []renderedMethod {
out := make([]renderedMethod, len(methods))
for i, m := range methods {
out[i] = renderedMethod{
Signature: buildSig(m),
Body: buildBody(m),
}
}
return out
}
func buildSig(m methodInfo) string {
var sb strings.Builder
sb.WriteString("func (s *Store) ")
sb.WriteString(m.Name)
sb.WriteString("(ctx context.Context")
for _, p := range m.Params {
sb.WriteString(", ")
sb.WriteString(p.Name)
sb.WriteString(" ")
if p.RepoType != "" {
sb.WriteString(p.RepoType)
} else {
sb.WriteString(p.TypeStr)
}
}
sb.WriteString(") (")
for _, r := range m.Results {
if r.IsSlice {
sb.WriteString("[]")
}
if r.RepoType != "" {
sb.WriteString(r.RepoType)
} else {
sb.WriteString(r.TypeStr)
}
sb.WriteString(", ")
}
sb.WriteString("error)")
return sb.String()
}
func callArgs(m methodInfo) string {
args := make([]string, 0, len(m.Params))
for _, p := range m.Params {
if p.RepoType != "" {
// convert repo type → driver type: DriverType(arg)
args = append(args, p.TypeStr+"("+p.Name+")")
} else {
args = append(args, p.Name)
}
}
if len(args) == 0 {
return "ctx"
}
return "ctx, " + strings.Join(args, ", ")
}
// bodyTemplates holds the per-shape method body templates, parsed once at init.
var bodyTemplates = template.Must(
template.New("bodies").Parse(`
{{define "void"}} return mapErr({{.Call}})
{{end}}
{{define "scalar"}} r, err := {{.Call}}
if err != nil {
return {{.RepoType}}{}, mapErr(err)
}
return {{.Converter}}(r), nil
{{end}}
{{define "slice"}} rows, err := {{.Call}}
if err != nil {
return nil, mapErr(err)
}
out := make([]{{.RepoType}}, len(rows))
for i, row := range rows {
out[i] = {{.Converter}}(row)
}
return out, nil
{{end}}`),
)
type bodyData struct {
Call string
RepoType string
Converter string
}
func buildBody(m methodInfo) string {
call := "s.q." + m.Name + "(" + callArgs(m) + ")"
var (
name string
data bodyData
)
switch {
case len(m.Results) == 0 || m.Results[0].RepoType == "":
name = "void"
data = bodyData{Call: call}
case m.Results[0].IsSlice:
name = "slice"
data = bodyData{Call: call, RepoType: m.Results[0].RepoType, Converter: converterFn(m.Results[0].TypeStr)}
default:
name = "scalar"
data = bodyData{Call: call, RepoType: m.Results[0].RepoType, Converter: converterFn(m.Results[0].TypeStr)}
}
var buf bytes.Buffer
if err := bodyTemplates.ExecuteTemplate(&buf, name, data); err != nil {
panic(fmt.Sprintf("buildBody %s: %v", name, err))
}
return buf.String()
}
type tmplData struct {
PkgName string
RepoPkg string
ModelTypes []string
Methods []renderedMethod
}
func render(data tmplData) ([]byte, error) {
t, err := template.New("store").Funcs(template.FuncMap{
"converterFn": converterFn,
}).Parse(storeSrc)
if err != nil {
return nil, fmt.Errorf("parse template: %w", err)
}
var buf bytes.Buffer
if err := t.Execute(&buf, data); err != nil {
return nil, fmt.Errorf("execute template: %w", err)
}
formatted, err := format.Source(buf.Bytes())
if err != nil {
return buf.Bytes(), fmt.Errorf("format source: %w\nraw:\n%s", err, buf.String())
}
return formatted, nil
}
+46
View File
@@ -0,0 +1,46 @@
// Code generated by cmd/gen/sqlc-wrapper. DO NOT EDIT.
package {{.PkgName}}
import (
"context"
"database/sql"
"errors"
"{{.RepoPkg}}"
)
// Store wraps *Queries and implements repository.Store.
type Store struct {
q *Queries
}
// NewStore wraps a *Queries to satisfy repository.Store.
func NewStore(q *Queries) repository.Store {
return &Store{q: q}
}
var errMap = []struct {
from error
to error
}{
{sql.ErrNoRows, repository.ErrNotFound},
}
func mapErr(err error) error {
for _, e := range errMap {
if errors.Is(err, e.from) {
return e.to
}
}
return err
}
{{range .ModelTypes -}}
func {{converterFn .}}(v {{.}}) repository.{{.}} {
return repository.{{.}}(v)
}
{{end -}}
{{range .Methods}}{{.Signature}} {
{{.Body}}}
{{end}}
+5 -4
View File
@@ -6,8 +6,8 @@ import (
"strings" "strings"
"charm.land/huh/v2" "charm.land/huh/v2"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/tinyauthapp/paerser/cli" "github.com/tinyauthapp/paerser/cli"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
) )
@@ -40,7 +40,8 @@ func createUserCmd() *cli.Command {
Configuration: tCfg, Configuration: tCfg,
Resources: loaders, Resources: loaders,
Run: func(_ []string) error { Run: func(_ []string) error {
tlog.NewSimpleLogger().Init() log := logger.NewLogger().WithSimpleConfig()
log.Init()
if tCfg.Interactive { if tCfg.Interactive {
form := huh.NewForm( form := huh.NewForm(
@@ -73,7 +74,7 @@ func createUserCmd() *cli.Command {
return errors.New("username and password cannot be empty") return errors.New("username and password cannot be empty")
} }
tlog.App.Info().Str("username", tCfg.Username).Msg("Creating user") log.App.Info().Str("username", tCfg.Username).Msg("Creating user")
passwd, err := bcrypt.GenerateFromPassword([]byte(tCfg.Password), bcrypt.DefaultCost) passwd, err := bcrypt.GenerateFromPassword([]byte(tCfg.Password), bcrypt.DefaultCost)
if err != nil { if err != nil {
@@ -86,7 +87,7 @@ func createUserCmd() *cli.Command {
passwdStr = strings.ReplaceAll(passwdStr, "$", "$$") passwdStr = strings.ReplaceAll(passwdStr, "$", "$$")
} }
tlog.App.Info().Str("user", fmt.Sprintf("%s:%s", tCfg.Username, passwdStr)).Msg("User created") log.App.Info().Str("user", fmt.Sprintf("%s:%s", tCfg.Username, passwdStr)).Msg("User created")
return nil return nil
}, },
+10 -6
View File
@@ -7,7 +7,7 @@ import (
"strings" "strings"
"github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
"charm.land/huh/v2" "charm.land/huh/v2"
"github.com/mdp/qrterminal/v3" "github.com/mdp/qrterminal/v3"
@@ -40,7 +40,8 @@ func generateTotpCmd() *cli.Command {
Configuration: tCfg, Configuration: tCfg,
Resources: loaders, Resources: loaders,
Run: func(_ []string) error { Run: func(_ []string) error {
tlog.NewSimpleLogger().Init() log := logger.NewLogger().WithSimpleConfig()
log.Init()
if tCfg.Interactive { if tCfg.Interactive {
form := huh.NewForm( form := huh.NewForm(
@@ -68,7 +69,10 @@ func generateTotpCmd() *cli.Command {
return fmt.Errorf("failed to parse user: %w", err) return fmt.Errorf("failed to parse user: %w", err)
} }
docker := strings.Contains(tCfg.User, "$$") docker := false
if strings.Contains(tCfg.User, "$$") {
docker = true
}
if user.TOTPSecret != "" { if user.TOTPSecret != "" {
return fmt.Errorf("user already has a TOTP secret") return fmt.Errorf("user already has a TOTP secret")
@@ -85,9 +89,9 @@ func generateTotpCmd() *cli.Command {
secret := key.Secret() secret := key.Secret()
tlog.App.Info().Str("secret", secret).Msg("Generated TOTP secret") log.App.Info().Str("secret", secret).Msg("Generated TOTP secret")
tlog.App.Info().Msg("Generated QR code") log.App.Info().Msg("Generated QR code")
config := qrterminal.Config{ config := qrterminal.Config{
Level: qrterminal.L, Level: qrterminal.L,
@@ -106,7 +110,7 @@ func generateTotpCmd() *cli.Command {
user.Password = strings.ReplaceAll(user.Password, "$", "$$") user.Password = strings.ReplaceAll(user.Password, "$", "$$")
} }
tlog.App.Info().Str("user", fmt.Sprintf("%s:%s:%s", user.Username, user.Password, user.TOTPSecret)).Msg("Add the totp secret to your authenticator app then use the verify command to ensure everything is working correctly.") log.App.Info().Str("user", fmt.Sprintf("%s:%s:%s", user.Username, user.Password, user.TOTPSecret)).Msg("Add the totp secret to your authenticator app then use the verify command to ensure everything is working correctly.")
return nil return nil
}, },
+7 -6
View File
@@ -10,7 +10,7 @@ import (
"time" "time"
"github.com/tinyauthapp/paerser/cli" "github.com/tinyauthapp/paerser/cli"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
) )
type healthzResponse struct { type healthzResponse struct {
@@ -26,7 +26,8 @@ func healthcheckCmd() *cli.Command {
Resources: nil, Resources: nil,
AllowArg: true, AllowArg: true,
Run: func(args []string) error { Run: func(args []string) error {
tlog.NewSimpleLogger().Init() log := logger.NewLogger().WithSimpleConfig()
log.Init()
srvAddr := os.Getenv("TINYAUTH_SERVER_ADDRESS") srvAddr := os.Getenv("TINYAUTH_SERVER_ADDRESS")
if srvAddr == "" { if srvAddr == "" {
@@ -45,10 +46,10 @@ func healthcheckCmd() *cli.Command {
} }
if appUrl == "" { if appUrl == "" {
return errors.New("could not determine app url") return errors.New("Could not determine app URL")
} }
tlog.App.Info().Str("app_url", appUrl).Msg("Performing health check") log.App.Info().Str("app_url", appUrl).Msg("Performing health check")
client := http.Client{ client := http.Client{
Timeout: 30 * time.Second, Timeout: 30 * time.Second,
@@ -70,7 +71,7 @@ func healthcheckCmd() *cli.Command {
return fmt.Errorf("service is not healthy, got: %s", resp.Status) return fmt.Errorf("service is not healthy, got: %s", resp.Status)
} }
defer resp.Body.Close() //nolint:errcheck defer resp.Body.Close()
var healthResp healthzResponse var healthResp healthzResponse
@@ -86,7 +87,7 @@ func healthcheckCmd() *cli.Command {
return fmt.Errorf("failed to decode response: %w", err) return fmt.Errorf("failed to decode response: %w", err)
} }
tlog.App.Info().Interface("response", healthResp).Msg("Tinyauth is healthy") log.App.Info().Interface("response", healthResp).Msg("Tinyauth is healthy")
return nil return nil
}, },
-6
View File
@@ -7,7 +7,6 @@ import (
"github.com/tinyauthapp/tinyauth/internal/bootstrap" "github.com/tinyauthapp/tinyauth/internal/bootstrap"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/loaders" "github.com/tinyauthapp/tinyauth/internal/utils/loaders"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"github.com/tinyauthapp/paerser/cli" "github.com/tinyauthapp/paerser/cli"
@@ -109,11 +108,6 @@ func main() {
} }
func runCmd(cfg model.Config) error { func runCmd(cfg model.Config) error {
logger := tlog.NewLogger(cfg.Log)
logger.Init()
tlog.App.Info().Str("version", model.Version).Msg("Starting tinyauth")
app := bootstrap.NewBootstrapApp(cfg) app := bootstrap.NewBootstrapApp(cfg)
err := app.Setup() err := app.Setup()
+6 -5
View File
@@ -5,7 +5,7 @@ import (
"fmt" "fmt"
"github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
"charm.land/huh/v2" "charm.land/huh/v2"
"github.com/pquerna/otp/totp" "github.com/pquerna/otp/totp"
@@ -44,7 +44,8 @@ func verifyUserCmd() *cli.Command {
Configuration: tCfg, Configuration: tCfg,
Resources: loaders, Resources: loaders,
Run: func(_ []string) error { Run: func(_ []string) error {
tlog.NewSimpleLogger().Init() log := logger.NewLogger().WithSimpleConfig()
log.Init()
if tCfg.Interactive { if tCfg.Interactive {
form := huh.NewForm( form := huh.NewForm(
@@ -97,9 +98,9 @@ func verifyUserCmd() *cli.Command {
if user.TOTPSecret == "" { if user.TOTPSecret == "" {
if tCfg.Totp != "" { if tCfg.Totp != "" {
tlog.App.Warn().Msg("User does not have TOTP secret") log.App.Warn().Msg("User does not have TOTP secret")
} }
tlog.App.Info().Msg("User verified") log.App.Info().Msg("User verified")
return nil return nil
} }
@@ -109,7 +110,7 @@ func verifyUserCmd() *cli.Command {
return fmt.Errorf("TOTP code incorrect") return fmt.Errorf("TOTP code incorrect")
} }
tlog.App.Info().Msg("User verified") log.App.Info().Msg("User verified")
return nil return nil
}, },
+2
View File
@@ -20,6 +20,7 @@ require (
github.com/weppos/publicsuffix-go v0.50.3 github.com/weppos/publicsuffix-go v0.50.3
golang.org/x/crypto v0.50.0 golang.org/x/crypto v0.50.0
golang.org/x/oauth2 v0.36.0 golang.org/x/oauth2 v0.36.0
golang.org/x/tools v0.43.0
k8s.io/apimachinery v0.36.0 k8s.io/apimachinery v0.36.0
k8s.io/client-go v0.36.0 k8s.io/client-go v0.36.0
modernc.org/sqlite v1.50.0 modernc.org/sqlite v1.50.0
@@ -121,6 +122,7 @@ require (
go.yaml.in/yaml/v2 v2.4.3 // indirect go.yaml.in/yaml/v2 v2.4.3 // indirect
golang.org/x/arch v0.22.0 // indirect golang.org/x/arch v0.22.0 // indirect
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect
golang.org/x/mod v0.34.0 // indirect
golang.org/x/net v0.52.0 // indirect golang.org/x/net v0.52.0 // indirect
golang.org/x/sync v0.20.0 // indirect golang.org/x/sync v0.20.0 // indirect
golang.org/x/sys v0.43.0 // indirect golang.org/x/sys v0.43.0 // indirect
+1 -1
View File
@@ -11,5 +11,5 @@ var FrontendAssets embed.FS
// Migrations // Migrations
// //
//go:embed migrations/*.sql //go:embed migrations/sqlite/*.sql
var Migrations embed.FS var Migrations embed.FS
+267 -125
View File
@@ -3,39 +3,50 @@ package bootstrap
import ( import (
"bytes" "bytes"
"context" "context"
"database/sql"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"net"
"net/http" "net/http"
"net/url" "net/url"
"os" "os"
"os/signal"
"sort" "sort"
"strings" "strings"
"sync"
"syscall"
"time" "time"
"github.com/tinyauthapp/tinyauth/internal/controller" "github.com/gin-gonic/gin"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
) )
type Services struct {
accessControlService *service.AccessControlsService
authService *service.AuthService
dockerService *service.DockerService
kubernetesService *service.KubernetesService
ldapService *service.LdapService
oauthBrokerService *service.OAuthBrokerService
oidcService *service.OIDCService
}
type BootstrapApp struct { type BootstrapApp struct {
config model.Config config model.Config
context struct { runtime model.RuntimeConfig
appUrl string
uuid string
cookieDomain string
sessionCookieName string
csrfCookieName string
redirectCookieName string
oauthSessionCookieName string
localUsers *[]model.LocalUser
oauthProviders map[string]model.OAuthServiceConfig
oauthWhitelist []string
configuredProviders []controller.Provider
oidcClients []model.OIDCClientConfig
}
services Services services Services
log *logger.Logger
ctx context.Context
cancel context.CancelFunc
queries repository.Store
router *gin.Engine
db *sql.DB
wg sync.WaitGroup
} }
func NewBootstrapApp(config model.Config) *BootstrapApp { func NewBootstrapApp(config model.Config) *BootstrapApp {
@@ -45,56 +56,69 @@ func NewBootstrapApp(config model.Config) *BootstrapApp {
} }
func (app *BootstrapApp) Setup() error { func (app *BootstrapApp) Setup() error {
// create context
ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
app.ctx = ctx
app.cancel = cancel
// setup logger
log := logger.NewLogger().WithConfig(app.config.Log)
log.Init()
app.log = log
// get app url // get app url
if app.config.AppURL == "" { if app.config.AppURL == "" {
return fmt.Errorf("app URL cannot be empty, perhaps config loading failed") return errors.New("app url cannot be empty, perhaps config loading failed")
} }
appUrl, err := url.Parse(app.config.AppURL) appUrl, err := url.Parse(app.config.AppURL)
if err != nil { if err != nil {
return err return fmt.Errorf("failed to parse app url: %w", err)
} }
app.context.appUrl = appUrl.Scheme + "://" + appUrl.Host app.runtime.AppURL = appUrl.Scheme + "://" + appUrl.Host
// validate session config // validate session config
if app.config.Auth.SessionMaxLifetime != 0 && app.config.Auth.SessionMaxLifetime < app.config.Auth.SessionExpiry { if app.config.Auth.SessionMaxLifetime != 0 && app.config.Auth.SessionMaxLifetime < app.config.Auth.SessionExpiry {
return fmt.Errorf("session max lifetime cannot be less than session expiry") return errors.New("session max lifetime cannot be less than session expiry")
} }
// Parse users // parse users
users, err := utils.GetUsers(app.config.Auth.Users, app.config.Auth.UsersFile, app.config.Auth.UserAttributes) users, err := utils.GetUsers(app.config.Auth.Users, app.config.Auth.UsersFile, app.config.Auth.UserAttributes)
if err != nil { if err != nil {
return err return fmt.Errorf("failed to load users: %w", err)
} }
app.context.localUsers = users app.runtime.LocalUsers = *users
// load oauth whitelist
oauthWhitelist, err := utils.GetStringList(app.config.OAuth.Whitelist, app.config.OAuth.WhitelistFile) oauthWhitelist, err := utils.GetStringList(app.config.OAuth.Whitelist, app.config.OAuth.WhitelistFile)
if err != nil { if err != nil {
return err return fmt.Errorf("failed to load oauth whitelist: %w", err)
} }
app.context.oauthWhitelist = oauthWhitelist app.runtime.OAuthWhitelist = oauthWhitelist
// Setup OAuth providers // setup oauth providers
app.context.oauthProviders = app.config.OAuth.Providers app.runtime.OAuthProviders = app.config.OAuth.Providers
for name, provider := range app.context.oauthProviders { for id, provider := range app.runtime.OAuthProviders {
secret := utils.GetSecret(provider.ClientSecret, provider.ClientSecretFile) secret := utils.GetSecret(provider.ClientSecret, provider.ClientSecretFile)
provider.ClientSecret = secret provider.ClientSecret = secret
provider.ClientSecretFile = "" provider.ClientSecretFile = ""
if provider.RedirectURL == "" { if provider.RedirectURL == "" {
provider.RedirectURL = app.context.appUrl + "/api/oauth/callback/" + name provider.RedirectURL = app.runtime.AppURL + "/api/oauth/callback/" + id
} }
app.context.oauthProviders[name] = provider app.runtime.OAuthProviders[id] = provider
} }
for id, provider := range app.context.oauthProviders { // set presets for built-in providers
for id, provider := range app.runtime.OAuthProviders {
if provider.Name == "" { if provider.Name == "" {
if name, ok := model.OverrideProviders[id]; ok { if name, ok := model.OverrideProviders[id]; ok {
provider.Name = name provider.Name = name
@@ -102,71 +126,71 @@ func (app *BootstrapApp) Setup() error {
provider.Name = utils.Capitalize(id) provider.Name = utils.Capitalize(id)
} }
} }
app.context.oauthProviders[id] = provider app.runtime.OAuthProviders[id] = provider
} }
// Setup OIDC clients // setup oidc clients
for id, client := range app.config.OIDC.Clients { for id, client := range app.config.OIDC.Clients {
client.ID = id client.ID = id
app.context.oidcClients = append(app.context.oidcClients, client) app.runtime.OIDCClients = append(app.runtime.OIDCClients, client)
} }
// Get cookie domain // cookie domain
cookieDomainResolver := utils.GetCookieDomain cookieDomainResolver := utils.GetCookieDomain
if !app.config.Auth.SubdomainsEnabled { if !app.config.Auth.SubdomainsEnabled {
tlog.App.Info().Msg("Subdomains disabled, automatic authentication for proxied apps will not work") app.log.App.Warn().Msg("Subdomains are disabled, using standalone cookie domain resolver which will not work with subdomains")
cookieDomainResolver = utils.GetStandaloneCookieDomain cookieDomainResolver = utils.GetStandaloneCookieDomain
} }
cookieDomain, err := cookieDomainResolver(app.context.appUrl) cookieDomain, err := cookieDomainResolver(app.runtime.AppURL)
if err != nil { if err != nil {
return err return fmt.Errorf("failed to get cookie domain: %w", err)
} }
app.context.cookieDomain = cookieDomain app.runtime.CookieDomain = cookieDomain
// Cookie names // cookie names
app.context.uuid = utils.GenerateUUID(appUrl.Hostname()) app.runtime.UUID = utils.GenerateUUID(appUrl.Hostname())
cookieId := strings.Split(app.context.uuid, "-")[0]
app.context.sessionCookieName = fmt.Sprintf("%s-%s", model.SessionCookieName, cookieId)
app.context.csrfCookieName = fmt.Sprintf("%s-%s", model.CSRFCookieName, cookieId)
app.context.redirectCookieName = fmt.Sprintf("%s-%s", model.RedirectCookieName, cookieId)
app.context.oauthSessionCookieName = fmt.Sprintf("%s-%s", model.OAuthSessionCookieName, cookieId)
// Dumps cookieId := strings.Split(app.runtime.UUID, "-")[0] // first 8 characters of the uuid should be good enough
tlog.App.Trace().Interface("config", app.config).Msg("Config dump")
tlog.App.Trace().Interface("users", app.context.localUsers).Msg("Users dump")
tlog.App.Trace().Interface("oauthProviders", app.context.oauthProviders).Msg("OAuth providers dump")
tlog.App.Trace().Str("cookieDomain", app.context.cookieDomain).Msg("Cookie domain")
tlog.App.Trace().Str("sessionCookieName", app.context.sessionCookieName).Msg("Session cookie name")
tlog.App.Trace().Str("csrfCookieName", app.context.csrfCookieName).Msg("CSRF cookie name")
tlog.App.Trace().Str("redirectCookieName", app.context.redirectCookieName).Msg("Redirect cookie name")
// Database app.runtime.SessionCookieName = fmt.Sprintf("%s-%s", model.SessionCookieName, cookieId)
db, err := app.SetupDatabase(app.config.Database.Path) app.runtime.CSRFCookieName = fmt.Sprintf("%s-%s", model.CSRFCookieName, cookieId)
app.runtime.RedirectCookieName = fmt.Sprintf("%s-%s", model.RedirectCookieName, cookieId)
app.runtime.OAuthSessionCookieName = fmt.Sprintf("%s-%s", model.OAuthSessionCookieName, cookieId)
// database
store, err := app.SetupStore()
if err != nil { if err != nil {
return fmt.Errorf("failed to setup database: %w", err) return fmt.Errorf("failed to setup database: %w", err)
} }
// Queries // after this point, we start initializing dependencies so it's a good time to setup a defer
queries := repository.New(db) // to ensure that resources are cleaned up properly in case of an error during initialization
defer func() {
app.cancel()
app.wg.Wait()
app.db.Close()
}()
// Services // store
services, err := app.initServices(queries) app.queries = store
// services
err = app.setupServices()
if err != nil { if err != nil {
return fmt.Errorf("failed to initialize services: %w", err) return fmt.Errorf("failed to initialize services: %w", err)
} }
app.services = services // configured providers
configuredProviders := make([]model.Provider, 0)
// Configured providers for id, provider := range app.runtime.OAuthProviders {
configuredProviders := make([]controller.Provider, 0) configuredProviders = append(configuredProviders, model.Provider{
for id, provider := range app.context.oauthProviders {
configuredProviders = append(configuredProviders, controller.Provider{
Name: provider.Name, Name: provider.Name,
ID: id, ID: id,
OAuth: true, OAuth: true,
@@ -177,70 +201,171 @@ func (app *BootstrapApp) Setup() error {
return configuredProviders[i].Name < configuredProviders[j].Name return configuredProviders[i].Name < configuredProviders[j].Name
}) })
if services.authService.LocalAuthConfigured() { if app.services.authService.LocalAuthConfigured() {
configuredProviders = append(configuredProviders, controller.Provider{ configuredProviders = append(configuredProviders, model.Provider{
Name: "Local", Name: "Local",
ID: "local", ID: "local",
OAuth: false, OAuth: false,
}) })
} }
if services.authService.LDAPAuthConfigured() { if app.services.authService.LDAPAuthConfigured() {
configuredProviders = append(configuredProviders, controller.Provider{ configuredProviders = append(configuredProviders, model.Provider{
Name: "LDAP", Name: "LDAP",
ID: "ldap", ID: "ldap",
OAuth: false, OAuth: false,
}) })
} }
tlog.App.Debug().Interface("providers", configuredProviders).Msg("Authentication providers")
if len(configuredProviders) == 0 { if len(configuredProviders) == 0 {
return fmt.Errorf("no authentication providers configured") return errors.New("no authentication providers configured")
} }
app.context.configuredProviders = configuredProviders for _, provider := range configuredProviders {
app.log.App.Debug().Str("provider", provider.Name).Msg("Configured authentication provider")
}
// Setup router app.runtime.ConfiguredProviders = configuredProviders
router, err := app.setupRouter()
// setup router
err = app.setupRouter()
if err != nil { if err != nil {
return fmt.Errorf("failed to setup routes: %w", err) return fmt.Errorf("failed to setup routes: %w", err)
} }
// Start db cleanup routine // start db cleanup routine
tlog.App.Debug().Msg("Starting database cleanup routine") app.log.App.Debug().Msg("Starting database cleanup routine")
go app.dbCleanupRoutine(queries) app.wg.Go(app.dbCleanupRoutine)
// If analytics are not disabled, start heartbeat // if analytics are not disabled, start heartbeat
if app.config.Analytics.Enabled { if app.config.Analytics.Enabled {
tlog.App.Debug().Msg("Starting heartbeat routine") app.log.App.Debug().Msg("Starting heartbeat routine")
go app.heartbeatRoutine() app.wg.Go(app.heartbeatRoutine)
} }
// If we have an socket path, bind to it // create err channel to listen for server errors
if app.config.Server.SocketPath != "" { errChanLen := 0
if _, err := os.Stat(app.config.Server.SocketPath); err == nil {
tlog.App.Info().Msgf("Removing existing socket file %s", app.config.Server.SocketPath) runUnix := app.config.Server.SocketPath != ""
err := os.Remove(app.config.Server.SocketPath) runHTTP := app.config.Server.SocketPath == "" || app.config.Server.ConcurrentListenersEnabled
if runUnix {
errChanLen++
}
if runHTTP {
errChanLen++
}
errChan := make(chan error, errChanLen)
if app.config.Server.ConcurrentListenersEnabled {
app.log.App.Info().Msg("Concurrent listeners enabled, will run on all available listeners")
}
// serve unix
if runUnix {
app.wg.Go(func() {
if err := app.serveUnix(); err != nil {
errChan <- err
}
})
}
// serve to http
if runHTTP {
app.wg.Go(func() {
if err := app.serveHTTP(); err != nil {
errChan <- err
}
})
}
// monitor cancellation and server errors
for {
select {
case <-app.ctx.Done():
app.log.App.Info().Msg("Oh, it's time for me to go, bye!")
return nil
case err := <-errChan:
if err != nil { if err != nil {
return fmt.Errorf("failed to remove existing socket file: %w", err) return fmt.Errorf("server error: %w", err)
} }
} }
}
}
tlog.App.Info().Msgf("Starting server on unix socket %s", app.config.Server.SocketPath) func (app *BootstrapApp) serveHTTP() error {
if err := router.RunUnix(app.config.Server.SocketPath); err != nil { address := fmt.Sprintf("%s:%d", app.config.Server.Address, app.config.Server.Port)
tlog.App.Fatal().Err(err).Msg("Failed to start server")
}
app.log.App.Info().Msgf("Starting server on %s", address)
server := &http.Server{
Addr: address,
Handler: app.router.Handler(),
}
go func() {
<-app.ctx.Done()
app.log.App.Debug().Msg("Shutting down http listener")
server.Shutdown(app.ctx)
}()
err := server.ListenAndServe()
if err != nil && !errors.Is(err, http.ErrServerClosed) {
return fmt.Errorf("failed to start http listener: %w", err)
}
return nil
}
func (app *BootstrapApp) serveUnix() error {
if app.config.Server.SocketPath == "" {
return nil return nil
} }
// Start server _, err := os.Stat(app.config.Server.SocketPath)
address := fmt.Sprintf("%s:%d", app.config.Server.Address, app.config.Server.Port)
tlog.App.Info().Msgf("Starting server on %s", address) if err == nil {
if err := router.Run(address); err != nil { app.log.App.Info().Msgf("Removing existing socket file %s", app.config.Server.SocketPath)
tlog.App.Fatal().Err(err).Msg("Failed to start server") err := os.Remove(app.config.Server.SocketPath)
if err != nil {
return fmt.Errorf("failed to remove existing socket file: %w", err)
}
}
app.log.App.Info().Msgf("Starting server on unix socket %s", app.config.Server.SocketPath)
listener, err := net.Listen("unix", app.config.Server.SocketPath)
if err != nil {
return fmt.Errorf("failed to create unix socket listener: %w", err)
}
server := &http.Server{
Handler: app.router.Handler(),
}
shutdown := func() {
server.Shutdown(app.ctx)
listener.Close()
os.Remove(app.config.Server.SocketPath)
}
go func() {
<-app.ctx.Done()
app.log.App.Debug().Msg("Shutting down unix socket listener")
shutdown()
}()
err = server.Serve(listener)
if err != nil && !errors.Is(err, http.ErrServerClosed) {
shutdown()
return fmt.Errorf("failed to start unix socket listener: %w", err)
} }
return nil return nil
@@ -250,20 +375,20 @@ func (app *BootstrapApp) heartbeatRoutine() {
ticker := time.NewTicker(time.Duration(12) * time.Hour) ticker := time.NewTicker(time.Duration(12) * time.Hour)
defer ticker.Stop() defer ticker.Stop()
type heartbeat struct { type Heartbeat struct {
UUID string `json:"uuid"` UUID string `json:"uuid"`
Version string `json:"version"` Version string `json:"version"`
} }
var body heartbeat var body Heartbeat
body.UUID = app.context.uuid body.UUID = app.runtime.UUID
body.Version = model.Version body.Version = model.Version
bodyJson, err := json.Marshal(body) bodyJson, err := json.Marshal(body)
if err != nil { if err != nil {
tlog.App.Error().Err(err).Msg("Failed to marshal heartbeat body") app.log.App.Error().Err(err).Msg("Failed to marshal heartbeat body, heartbeat routine will not start")
return return
} }
@@ -273,43 +398,60 @@ func (app *BootstrapApp) heartbeatRoutine() {
heartbeatURL := model.APIServer + "/v1/instances/heartbeat" heartbeatURL := model.APIServer + "/v1/instances/heartbeat"
for range ticker.C { for {
tlog.App.Debug().Msg("Sending heartbeat") select {
case <-ticker.C:
app.log.App.Debug().Msg("Sending heartbeat")
req, err := http.NewRequest(http.MethodPost, heartbeatURL, bytes.NewReader(bodyJson)) req, err := http.NewRequest(http.MethodPost, heartbeatURL, bytes.NewReader(bodyJson))
if err != nil { if err != nil {
tlog.App.Error().Err(err).Msg("Failed to create heartbeat request") app.log.App.Error().Err(err).Msg("Failed to create heartbeat request")
continue continue
} }
req.Header.Add("Content-Type", "application/json") req.Header.Add("Content-Type", "application/json")
res, err := client.Do(req) res, err := client.Do(req)
if err != nil { if err != nil {
tlog.App.Error().Err(err).Msg("Failed to send heartbeat") app.log.App.Error().Err(err).Msg("Failed to send heartbeat")
continue continue
} }
res.Body.Close() //nolint:errcheck res.Body.Close()
if res.StatusCode != 200 && res.StatusCode != 201 { if res.StatusCode != 200 && res.StatusCode != 201 {
tlog.App.Debug().Str("status", res.Status).Msg("Heartbeat returned non-200/201 status") app.log.App.Debug().Str("status", res.Status).Msg("Heartbeat returned non-200/201 status")
}
case <-app.ctx.Done():
app.log.App.Debug().Msg("Stopping heartbeat routine")
ticker.Stop()
return
} }
} }
} }
func (app *BootstrapApp) dbCleanupRoutine(queries *repository.Queries) { func (app *BootstrapApp) dbCleanupRoutine() {
ticker := time.NewTicker(time.Duration(30) * time.Minute) ticker := time.NewTicker(time.Duration(30) * time.Minute)
defer ticker.Stop() defer ticker.Stop()
ctx := context.Background()
for range ticker.C { for {
tlog.App.Debug().Msg("Cleaning up old database sessions") select {
err := queries.DeleteExpiredSessions(ctx, time.Now().Unix()) case <-ticker.C:
if err != nil { app.log.App.Debug().Msg("Running database cleanup")
tlog.App.Error().Err(err).Msg("Failed to clean up old database sessions")
err := app.queries.DeleteExpiredSessions(app.ctx, time.Now().Unix())
if err != nil {
app.log.App.Error().Err(err).Msg("Failed to delete expired sessions")
}
app.log.App.Debug().Msg("Database cleanup completed")
case <-app.ctx.Done():
app.log.App.Debug().Msg("Stopping database cleanup routine")
ticker.Stop()
return
} }
} }
} }
+26 -3
View File
@@ -7,6 +7,9 @@ import (
"path/filepath" "path/filepath"
"github.com/tinyauthapp/tinyauth/internal/assets" "github.com/tinyauthapp/tinyauth/internal/assets"
"github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
"github.com/tinyauthapp/tinyauth/internal/repository/sqlite"
"github.com/golang-migrate/migrate/v4" "github.com/golang-migrate/migrate/v4"
"github.com/golang-migrate/migrate/v4/database/sqlite3" "github.com/golang-migrate/migrate/v4/database/sqlite3"
@@ -14,7 +17,18 @@ import (
_ "modernc.org/sqlite" _ "modernc.org/sqlite"
) )
func (app *BootstrapApp) SetupDatabase(databasePath string) (*sql.DB, error) { func (app *BootstrapApp) SetupStore() (repository.Store, error) {
switch app.config.Database.Driver {
case "memory":
return memory.New(), nil
case "sqlite", "":
return app.setupSQLite(app.config.Database.Path)
default:
return nil, fmt.Errorf("unknown database driver %q: valid values are sqlite, memory", app.config.Database.Driver)
}
}
func (app *BootstrapApp) setupSQLite(databasePath string) (repository.Store, error) {
dir := filepath.Dir(databasePath) dir := filepath.Dir(databasePath)
if err := os.MkdirAll(dir, 0750); err != nil { if err := os.MkdirAll(dir, 0750); err != nil {
@@ -27,11 +41,18 @@ func (app *BootstrapApp) SetupDatabase(databasePath string) (*sql.DB, error) {
return nil, fmt.Errorf("failed to open database: %w", err) return nil, fmt.Errorf("failed to open database: %w", err)
} }
// Close the database if there is an error during migration
defer func() {
if err != nil {
db.Close()
}
}()
// Limit to 1 connection to sequence writes, this may need to be revisited in the future // Limit to 1 connection to sequence writes, this may need to be revisited in the future
// if the sqlite connection starts being a bottleneck // if the sqlite connection starts being a bottleneck
db.SetMaxOpenConns(1) db.SetMaxOpenConns(1)
migrations, err := iofs.New(assets.Migrations, "migrations") migrations, err := iofs.New(assets.Migrations, "migrations/sqlite")
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create migrations: %w", err) return nil, fmt.Errorf("failed to create migrations: %w", err)
@@ -53,5 +74,7 @@ func (app *BootstrapApp) SetupDatabase(databasePath string) (*sql.DB, error) {
return nil, fmt.Errorf("failed to migrate database: %w", err) return nil, fmt.Errorf("failed to migrate database: %w", err)
} }
return db, nil app.db = db
return sqlite.NewStore(sqlite.New(db)), nil
} }
+18 -88
View File
@@ -2,21 +2,16 @@ package bootstrap
import ( import (
"fmt" "fmt"
"slices"
"github.com/tinyauthapp/tinyauth/internal/controller" "github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/middleware" "github.com/tinyauthapp/tinyauth/internal/middleware"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
var DEV_MODES = []string{"main", "test", "development"} func (app *BootstrapApp) setupRouter() error {
// we don't want gin debug mode
func (app *BootstrapApp) setupRouter() (*gin.Engine, error) { gin.SetMode(gin.ReleaseMode)
if !slices.Contains(DEV_MODES, model.Version) {
gin.SetMode(gin.ReleaseMode)
}
engine := gin.New() engine := gin.New()
engine.Use(gin.Recovery()) engine.Use(gin.Recovery())
@@ -25,101 +20,36 @@ func (app *BootstrapApp) setupRouter() (*gin.Engine, error) {
err := engine.SetTrustedProxies(app.config.Auth.TrustedProxies) err := engine.SetTrustedProxies(app.config.Auth.TrustedProxies)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to set trusted proxies: %w", err) return fmt.Errorf("failed to set trusted proxies: %w", err)
} }
} }
contextMiddleware := middleware.NewContextMiddleware(middleware.ContextMiddlewareConfig{ contextMiddleware := middleware.NewContextMiddleware(app.log, app.runtime, app.services.authService, app.services.oauthBrokerService)
CookieDomain: app.context.cookieDomain,
SessionCookieName: app.context.sessionCookieName,
}, app.services.authService, app.services.oauthBrokerService)
err := contextMiddleware.Init()
if err != nil {
return nil, fmt.Errorf("failed to initialize context middleware: %w", err)
}
engine.Use(contextMiddleware.Middleware()) engine.Use(contextMiddleware.Middleware())
uiMiddleware := middleware.NewUIMiddleware() uiMiddleware, err := middleware.NewUIMiddleware()
err = uiMiddleware.Init()
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to initialize UI middleware: %w", err) return fmt.Errorf("failed to initialize UI middleware: %w", err)
} }
engine.Use(uiMiddleware.Middleware()) engine.Use(uiMiddleware.Middleware())
zerologMiddleware := middleware.NewZerologMiddleware() zerologMiddleware := middleware.NewZerologMiddleware(app.log)
err = zerologMiddleware.Init()
if err != nil {
return nil, fmt.Errorf("failed to initialize zerolog middleware: %w", err)
}
engine.Use(zerologMiddleware.Middleware()) engine.Use(zerologMiddleware.Middleware())
apiRouter := engine.Group("/api") apiRouter := engine.Group("/api")
contextController := controller.NewContextController(controller.ContextControllerConfig{ controller.NewContextController(app.log, app.config, app.runtime, apiRouter)
Providers: app.context.configuredProviders, controller.NewOAuthController(app.log, app.config, app.runtime, apiRouter, app.services.authService)
Title: app.config.UI.Title, controller.NewOIDCController(app.log, app.services.oidcService, app.runtime, apiRouter)
AppURL: app.config.AppURL, controller.NewProxyController(app.log, app.runtime, apiRouter, app.services.accessControlService, app.services.authService)
CookieDomain: app.context.cookieDomain, controller.NewUserController(app.log, app.runtime, apiRouter, app.services.authService)
ForgotPasswordMessage: app.config.UI.ForgotPasswordMessage, controller.NewResourcesController(app.config, &engine.RouterGroup)
BackgroundImage: app.config.UI.BackgroundImage, controller.NewHealthController(apiRouter)
OAuthAutoRedirect: app.config.OAuth.AutoRedirect, controller.NewWellKnownController(app.services.oidcService, &engine.RouterGroup)
WarningsEnabled: app.config.UI.WarningsEnabled,
}, apiRouter)
contextController.SetupRoutes() app.router = engine
return nil
oauthController := controller.NewOAuthController(controller.OAuthControllerConfig{
AppURL: app.config.AppURL,
SecureCookie: app.config.Auth.SecureCookie,
CSRFCookieName: app.context.csrfCookieName,
RedirectCookieName: app.context.redirectCookieName,
CookieDomain: app.context.cookieDomain,
OAuthSessionCookieName: app.context.oauthSessionCookieName,
SubdomainsEnabled: app.config.Auth.SubdomainsEnabled,
}, apiRouter, app.services.authService)
oauthController.SetupRoutes()
oidcController := controller.NewOIDCController(controller.OIDCControllerConfig{}, app.services.oidcService, apiRouter)
oidcController.SetupRoutes()
proxyController := controller.NewProxyController(controller.ProxyControllerConfig{
AppURL: app.config.AppURL,
}, apiRouter, app.services.accessControlService, app.services.authService)
proxyController.SetupRoutes()
userController := controller.NewUserController(controller.UserControllerConfig{
CookieDomain: app.context.cookieDomain,
SessionCookieName: app.context.sessionCookieName,
}, apiRouter, app.services.authService)
userController.SetupRoutes()
resourcesController := controller.NewResourcesController(controller.ResourcesControllerConfig{
Path: app.config.Resources.Path,
Enabled: app.config.Resources.Enabled,
}, &engine.RouterGroup)
resourcesController.SetupRoutes()
healthController := controller.NewHealthController(apiRouter)
healthController.SetupRoutes()
wellknownController := controller.NewWellKnownController(controller.WellKnownControllerConfig{}, app.services.oidcService, engine)
wellknownController.SetupRoutes()
return engine, nil
} }
+33 -98
View File
@@ -1,131 +1,66 @@
package bootstrap package bootstrap
import ( import (
"fmt"
"os" "os"
"github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
) )
type Services struct { func (app *BootstrapApp) setupServices() error {
accessControlService *service.AccessControlsService ldapService, err := service.NewLdapService(app.log, app.config, app.ctx, &app.wg)
authService *service.AuthService
dockerService *service.DockerService
kubernetesService *service.KubernetesService
ldapService *service.LdapService
oauthBrokerService *service.OAuthBrokerService
oidcService *service.OIDCService
}
func (app *BootstrapApp) initServices(queries *repository.Queries) (Services, error) {
services := Services{}
ldapService := service.NewLdapService(service.LdapServiceConfig{
Address: app.config.LDAP.Address,
BindDN: app.config.LDAP.BindDN,
BindPassword: app.config.LDAP.BindPassword,
BaseDN: app.config.LDAP.BaseDN,
Insecure: app.config.LDAP.Insecure,
SearchFilter: app.config.LDAP.SearchFilter,
AuthCert: app.config.LDAP.AuthCert,
AuthKey: app.config.LDAP.AuthKey,
})
err := ldapService.Init()
if err != nil { if err != nil {
tlog.App.Warn().Err(err).Msg("Failed to setup LDAP service, starting without it") app.log.App.Warn().Err(err).Msg("Failed to initialize LDAP connection, will continue without it")
ldapService.Unconfigure() //nolint:errcheck
} }
services.ldapService = ldapService app.services.ldapService = ldapService
var labelProvider service.LabelProvider
var dockerService *service.DockerService
var kubernetesService *service.KubernetesService
useKubernetes := app.config.LabelProvider == "kubernetes" || useKubernetes := app.config.LabelProvider == "kubernetes" ||
(app.config.LabelProvider == "auto" && os.Getenv("KUBERNETES_SERVICE_HOST") != "") (app.config.LabelProvider == "auto" && os.Getenv("KUBERNETES_SERVICE_HOST") != "")
var labelProvider service.LabelProvider
if useKubernetes { if useKubernetes {
tlog.App.Debug().Msg("Using Kubernetes label provider") app.log.App.Debug().Msg("Using Kubernetes label provider")
kubernetesService = service.NewKubernetesService()
err = kubernetesService.Init() kubernetesService, err := service.NewKubernetesService(app.log, app.ctx, &app.wg)
if err != nil { if err != nil {
return Services{}, err return fmt.Errorf("failed to initialize kubernetes service: %w", err)
} }
services.kubernetesService = kubernetesService
app.services.kubernetesService = kubernetesService
labelProvider = kubernetesService labelProvider = kubernetesService
} else { } else {
tlog.App.Debug().Msg("Using Docker label provider") app.log.App.Debug().Msg("Using Docker label provider")
dockerService = service.NewDockerService()
err = dockerService.Init() dockerService, err := service.NewDockerService(app.log, app.ctx, &app.wg)
if err != nil { if err != nil {
return Services{}, err return fmt.Errorf("failed to initialize docker service: %w", err)
} }
services.dockerService = dockerService
app.services.dockerService = dockerService
labelProvider = dockerService labelProvider = dockerService
} }
accessControlsService := service.NewAccessControlsService(labelProvider, app.config.Apps) accessControlsService := service.NewAccessControlsService(app.log, &labelProvider, app.config.Apps)
app.services.accessControlService = accessControlsService
err = accessControlsService.Init() oauthBrokerService := service.NewOAuthBrokerService(app.log, app.runtime.OAuthProviders, app.ctx)
app.services.oauthBrokerService = oauthBrokerService
authService := service.NewAuthService(app.log, app.config, app.runtime, app.ctx, &app.wg, app.services.ldapService, app.queries, app.services.oauthBrokerService)
app.services.authService = authService
oidcService, err := service.NewOIDCService(app.log, app.config, app.runtime, app.queries, app.ctx, &app.wg)
if err != nil { if err != nil {
return Services{}, err return fmt.Errorf("failed to initialize oidc service: %w", err)
} }
services.accessControlService = accessControlsService app.services.oidcService = oidcService
oauthBrokerService := service.NewOAuthBrokerService(app.context.oauthProviders) return nil
err = oauthBrokerService.Init()
if err != nil {
return Services{}, err
}
services.oauthBrokerService = oauthBrokerService
authService := service.NewAuthService(service.AuthServiceConfig{
LocalUsers: app.context.localUsers,
OauthWhitelist: app.context.oauthWhitelist,
SessionExpiry: app.config.Auth.SessionExpiry,
SessionMaxLifetime: app.config.Auth.SessionMaxLifetime,
SecureCookie: app.config.Auth.SecureCookie,
CookieDomain: app.context.cookieDomain,
LoginTimeout: app.config.Auth.LoginTimeout,
LoginMaxRetries: app.config.Auth.LoginMaxRetries,
SessionCookieName: app.context.sessionCookieName,
IP: app.config.Auth.IP,
LDAPGroupsCacheTTL: app.config.LDAP.GroupCacheTTL,
SubdomainsEnabled: app.config.Auth.SubdomainsEnabled,
}, services.ldapService, queries, services.oauthBrokerService)
err = authService.Init()
if err != nil {
return Services{}, err
}
services.authService = authService
oidcService := service.NewOIDCService(service.OIDCServiceConfig{
Clients: app.config.OIDC.Clients,
PrivateKeyPath: app.config.OIDC.PrivateKeyPath,
PublicKeyPath: app.config.OIDC.PublicKeyPath,
Issuer: app.config.AppURL,
SessionExpiry: app.config.Auth.SessionExpiry,
}, queries)
err = oidcService.Init()
if err != nil {
return Services{}, err
}
services.oidcService = oidcService
return services, nil
} }
+40 -49
View File
@@ -5,7 +5,7 @@ import (
"net/url" "net/url"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
@@ -24,62 +24,52 @@ type UserContextResponse struct {
} }
type AppContextResponse struct { type AppContextResponse struct {
Status int `json:"status"` Status int `json:"status"`
Message string `json:"message"` Message string `json:"message"`
Providers []Provider `json:"providers"` Providers []model.Provider `json:"providers"`
Title string `json:"title"` Title string `json:"title"`
AppURL string `json:"appUrl"` AppURL string `json:"appUrl"`
CookieDomain string `json:"cookieDomain"` CookieDomain string `json:"cookieDomain"`
ForgotPasswordMessage string `json:"forgotPasswordMessage"` ForgotPasswordMessage string `json:"forgotPasswordMessage"`
BackgroundImage string `json:"backgroundImage"` BackgroundImage string `json:"backgroundImage"`
OAuthAutoRedirect string `json:"oauthAutoRedirect"` OAuthAutoRedirect string `json:"oauthAutoRedirect"`
WarningsEnabled bool `json:"warningsEnabled"` WarningsEnabled bool `json:"warningsEnabled"`
}
type Provider struct {
Name string `json:"name"`
ID string `json:"id"`
OAuth bool `json:"oauth"`
}
type ContextControllerConfig struct {
Providers []Provider
Title string
AppURL string
CookieDomain string
ForgotPasswordMessage string
BackgroundImage string
OAuthAutoRedirect string
WarningsEnabled bool
} }
type ContextController struct { type ContextController struct {
config ContextControllerConfig log *logger.Logger
router *gin.RouterGroup config model.Config
runtime model.RuntimeConfig
} }
func NewContextController(config ContextControllerConfig, router *gin.RouterGroup) *ContextController { func NewContextController(
if !config.WarningsEnabled { log *logger.Logger,
tlog.App.Warn().Msg("UI warnings are disabled. This may expose users to security risks. Proceed with caution.") config model.Config,
runtimeConfig model.RuntimeConfig,
router *gin.RouterGroup,
) *ContextController {
controller := &ContextController{
log: log,
config: config,
runtime: runtimeConfig,
} }
return &ContextController{ if !config.UI.WarningsEnabled {
config: config, log.App.Warn().Msg("UI warnings are disabled. This may lead to security issues if you are not careful. Make sure to enable warnings in production environments.")
router: router,
} }
}
func (controller *ContextController) SetupRoutes() { contextGroup := router.Group("/context")
contextGroup := controller.router.Group("/context")
contextGroup.GET("/user", controller.userContextHandler) contextGroup.GET("/user", controller.userContextHandler)
contextGroup.GET("/app", controller.appContextHandler) contextGroup.GET("/app", controller.appContextHandler)
return controller
} }
func (controller *ContextController) userContextHandler(c *gin.Context) { func (controller *ContextController) userContextHandler(c *gin.Context) {
context, err := new(model.UserContext).NewFromGin(c) context, err := new(model.UserContext).NewFromGin(c)
if err != nil { if err != nil {
tlog.App.Debug().Err(err).Msg("No user context found in request") controller.log.App.Error().Err(err).Msg("Failed to create user context from request")
c.JSON(200, UserContextResponse{ c.JSON(200, UserContextResponse{
Status: 401, Status: 401,
Message: "Unauthorized", Message: "Unauthorized",
@@ -105,9 +95,10 @@ func (controller *ContextController) userContextHandler(c *gin.Context) {
} }
func (controller *ContextController) appContextHandler(c *gin.Context) { func (controller *ContextController) appContextHandler(c *gin.Context) {
appUrl, err := url.Parse(controller.config.AppURL) appUrl, err := url.Parse(controller.runtime.AppURL)
if err != nil { if err != nil {
tlog.App.Error().Err(err).Msg("Failed to parse app URL") controller.log.App.Error().Err(err).Msg("Failed to parse app URL")
c.JSON(500, gin.H{ c.JSON(500, gin.H{
"status": 500, "status": 500,
"message": "Internal Server Error", "message": "Internal Server Error",
@@ -118,13 +109,13 @@ func (controller *ContextController) appContextHandler(c *gin.Context) {
c.JSON(200, AppContextResponse{ c.JSON(200, AppContextResponse{
Status: 200, Status: 200,
Message: "Success", Message: "Success",
Providers: controller.config.Providers, Providers: controller.runtime.ConfiguredProviders,
Title: controller.config.Title, Title: controller.config.UI.Title,
AppURL: fmt.Sprintf("%s://%s", appUrl.Scheme, appUrl.Host), AppURL: fmt.Sprintf("%s://%s", appUrl.Scheme, appUrl.Host),
CookieDomain: controller.config.CookieDomain, CookieDomain: controller.runtime.CookieDomain,
ForgotPasswordMessage: controller.config.ForgotPasswordMessage, ForgotPasswordMessage: controller.config.UI.ForgotPasswordMessage,
BackgroundImage: controller.config.BackgroundImage, BackgroundImage: controller.config.UI.BackgroundImage,
OAuthAutoRedirect: controller.config.OAuthAutoRedirect, OAuthAutoRedirect: controller.config.OAuth.AutoRedirect,
WarningsEnabled: controller.config.WarningsEnabled, WarningsEnabled: controller.config.UI.WarningsEnabled,
}) })
} }
+22 -34
View File
@@ -8,30 +8,19 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/controller" "github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/test"
"github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
) )
func TestContextController(t *testing.T) { func TestContextController(t *testing.T) {
tlog.NewTestLogger().Init() log := logger.NewLogger().WithTestConfig()
controllerConfig := controller.ContextControllerConfig{ log.Init()
Providers: []controller.Provider{
{ cfg, runtime := test.CreateTestConfigs(t)
Name: "Local",
ID: "local",
OAuth: false,
},
},
Title: "Tinyauth",
AppURL: "https://tinyauth.example.com",
CookieDomain: "example.com",
ForgotPasswordMessage: "foo",
BackgroundImage: "/background.jpg",
OAuthAutoRedirect: "none",
WarningsEnabled: true,
}
tests := []struct { tests := []struct {
description string description string
@@ -47,17 +36,17 @@ func TestContextController(t *testing.T) {
expectedAppContextResponse := controller.AppContextResponse{ expectedAppContextResponse := controller.AppContextResponse{
Status: 200, Status: 200,
Message: "Success", Message: "Success",
Providers: controllerConfig.Providers, Providers: runtime.ConfiguredProviders,
Title: controllerConfig.Title, Title: cfg.UI.Title,
AppURL: controllerConfig.AppURL, AppURL: runtime.AppURL,
CookieDomain: controllerConfig.CookieDomain, CookieDomain: runtime.CookieDomain,
ForgotPasswordMessage: controllerConfig.ForgotPasswordMessage, ForgotPasswordMessage: cfg.UI.ForgotPasswordMessage,
BackgroundImage: controllerConfig.BackgroundImage, BackgroundImage: cfg.UI.BackgroundImage,
OAuthAutoRedirect: controllerConfig.OAuthAutoRedirect, OAuthAutoRedirect: cfg.OAuth.AutoRedirect,
WarningsEnabled: controllerConfig.WarningsEnabled, WarningsEnabled: cfg.UI.WarningsEnabled,
} }
bytes, err := json.Marshal(expectedAppContextResponse) bytes, err := json.Marshal(expectedAppContextResponse)
assert.NoError(t, err) require.NoError(t, err)
return string(bytes) return string(bytes)
}(), }(),
}, },
@@ -71,7 +60,7 @@ func TestContextController(t *testing.T) {
Message: "Unauthorized", Message: "Unauthorized",
} }
bytes, err := json.Marshal(expectedUserContextResponse) bytes, err := json.Marshal(expectedUserContextResponse)
assert.NoError(t, err) require.NoError(t, err)
return string(bytes) return string(bytes)
}(), }(),
}, },
@@ -86,7 +75,7 @@ func TestContextController(t *testing.T) {
BaseContext: model.BaseContext{ BaseContext: model.BaseContext{
Username: "johndoe", Username: "johndoe",
Name: "John Doe", Name: "John Doe",
Email: utils.CompileUserEmail("johndoe", controllerConfig.CookieDomain), Email: utils.CompileUserEmail("johndoe", runtime.CookieDomain),
}, },
}, },
}) })
@@ -100,11 +89,11 @@ func TestContextController(t *testing.T) {
IsLoggedIn: true, IsLoggedIn: true,
Username: "johndoe", Username: "johndoe",
Name: "John Doe", Name: "John Doe",
Email: utils.CompileUserEmail("johndoe", controllerConfig.CookieDomain), Email: utils.CompileUserEmail("johndoe", runtime.CookieDomain),
Provider: "local", Provider: "local",
} }
bytes, err := json.Marshal(expectedUserContextResponse) bytes, err := json.Marshal(expectedUserContextResponse)
assert.NoError(t, err) require.NoError(t, err)
return string(bytes) return string(bytes)
}(), }(),
}, },
@@ -121,13 +110,12 @@ func TestContextController(t *testing.T) {
group := router.Group("/api") group := router.Group("/api")
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
contextController := controller.NewContextController(controllerConfig, group) controller.NewContextController(log, cfg, runtime, group)
contextController.SetupRoutes()
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
request, err := http.NewRequest("GET", test.path, nil) request, err := http.NewRequest("GET", test.path, nil)
assert.NoError(t, err) require.NoError(t, err)
router.ServeHTTP(recorder, request) router.ServeHTTP(recorder, request)
+5 -8
View File
@@ -3,18 +3,15 @@ package controller
import "github.com/gin-gonic/gin" import "github.com/gin-gonic/gin"
type HealthController struct { type HealthController struct {
router *gin.RouterGroup
} }
func NewHealthController(router *gin.RouterGroup) *HealthController { func NewHealthController(router *gin.RouterGroup) *HealthController {
return &HealthController{ controller := &HealthController{}
router: router,
}
}
func (controller *HealthController) SetupRoutes() { router.GET("/healthz", controller.healthHandler)
controller.router.GET("/healthz", controller.healthHandler) router.HEAD("/healthz", controller.healthHandler)
controller.router.HEAD("/healthz", controller.healthHandler)
return controller
} }
func (controller *HealthController) healthHandler(c *gin.Context) { func (controller *HealthController) healthHandler(c *gin.Context) {
@@ -7,13 +7,12 @@ import (
"testing" "testing"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/controller"
) )
func TestHealthController(t *testing.T) { func TestHealthController(t *testing.T) {
tlog.NewTestLogger().Init()
tests := []struct { tests := []struct {
description string description string
path string path string
@@ -30,7 +29,7 @@ func TestHealthController(t *testing.T) {
"message": "Healthy", "message": "Healthy",
} }
bytes, err := json.Marshal(expectedHealthResponse) bytes, err := json.Marshal(expectedHealthResponse)
assert.NoError(t, err) require.NoError(t, err)
return string(bytes) return string(bytes)
}(), }(),
}, },
@@ -44,7 +43,7 @@ func TestHealthController(t *testing.T) {
"message": "Healthy", "message": "Healthy",
} }
bytes, err := json.Marshal(expectedHealthResponse) bytes, err := json.Marshal(expectedHealthResponse)
assert.NoError(t, err) require.NoError(t, err)
return string(bytes) return string(bytes)
}(), }(),
}, },
@@ -56,13 +55,12 @@ func TestHealthController(t *testing.T) {
group := router.Group("/api") group := router.Group("/api")
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
healthController := controller.NewHealthController(group) controller.NewHealthController(group)
healthController.SetupRoutes()
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
request, err := http.NewRequest(test.method, test.path, nil) request, err := http.NewRequest(test.method, test.path, nil)
assert.NoError(t, err) require.NoError(t, err)
router.ServeHTTP(recorder, request) router.ServeHTTP(recorder, request)
+77 -72
View File
@@ -6,10 +6,11 @@ import (
"strings" "strings"
"time" "time"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/google/go-querystring/query" "github.com/google/go-querystring/query"
@@ -19,34 +20,32 @@ type OAuthRequest struct {
Provider string `uri:"provider" binding:"required"` Provider string `uri:"provider" binding:"required"`
} }
type OAuthControllerConfig struct {
CSRFCookieName string
OAuthSessionCookieName string
RedirectCookieName string
SecureCookie bool
AppURL string
CookieDomain string
SubdomainsEnabled bool
}
type OAuthController struct { type OAuthController struct {
config OAuthControllerConfig log *logger.Logger
router *gin.RouterGroup config model.Config
auth *service.AuthService runtime model.RuntimeConfig
auth *service.AuthService
} }
func NewOAuthController(config OAuthControllerConfig, router *gin.RouterGroup, auth *service.AuthService) *OAuthController { func NewOAuthController(
return &OAuthController{ log *logger.Logger,
config: config, config model.Config,
router: router, runtimeConfig model.RuntimeConfig,
auth: auth, router *gin.RouterGroup,
auth *service.AuthService,
) *OAuthController {
controller := &OAuthController{
log: log,
config: config,
runtime: runtimeConfig,
auth: auth,
} }
}
func (controller *OAuthController) SetupRoutes() { oauthGroup := router.Group("/oauth")
oauthGroup := controller.router.Group("/oauth")
oauthGroup.GET("/url/:provider", controller.oauthURLHandler) oauthGroup.GET("/url/:provider", controller.oauthURLHandler)
oauthGroup.GET("/callback/:provider", controller.oauthCallbackHandler) oauthGroup.GET("/callback/:provider", controller.oauthCallbackHandler)
return controller
} }
func (controller *OAuthController) oauthURLHandler(c *gin.Context) { func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
@@ -54,7 +53,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
err := c.BindUri(&req) err := c.BindUri(&req)
if err != nil { if err != nil {
tlog.App.Error().Err(err).Msg("Failed to bind URI") controller.log.App.Error().Err(err).Msg("Failed to bind URI")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
"status": 400, "status": 400,
"message": "Bad Request", "message": "Bad Request",
@@ -67,7 +66,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
err = c.BindQuery(&reqParams) err = c.BindQuery(&reqParams)
if err != nil { if err != nil {
tlog.App.Error().Err(err).Msg("Failed to bind query parameters") controller.log.App.Error().Err(err).Msg("Failed to bind query parameters")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
"status": 400, "status": 400,
"message": "Bad Request", "message": "Bad Request",
@@ -76,10 +75,10 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
} }
if !controller.isOidcRequest(reqParams) { if !controller.isOidcRequest(reqParams) {
isRedirectSafe := utils.IsRedirectSafe(reqParams.RedirectURI, controller.config.CookieDomain) isRedirectSafe := utils.IsRedirectSafe(reqParams.RedirectURI, controller.runtime.CookieDomain)
if !isRedirectSafe { if !isRedirectSafe {
tlog.App.Warn().Str("redirect_uri", reqParams.RedirectURI).Msg("Unsafe redirect URI detected, ignoring") controller.log.App.Warn().Str("redirectUri", reqParams.RedirectURI).Msg("Unsafe redirect URI, ignoring")
reqParams.RedirectURI = "" reqParams.RedirectURI = ""
} }
} }
@@ -87,7 +86,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
sessionId, _, err := controller.auth.NewOAuthSession(req.Provider, reqParams) sessionId, _, err := controller.auth.NewOAuthSession(req.Provider, reqParams)
if err != nil { if err != nil {
tlog.App.Error().Err(err).Msg("Failed to create OAuth session") controller.log.App.Error().Err(err).Msg("Failed to create new OAuth session")
c.JSON(500, gin.H{ c.JSON(500, gin.H{
"status": 500, "status": 500,
"message": "Internal Server Error", "message": "Internal Server Error",
@@ -98,7 +97,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
authUrl, err := controller.auth.GetOAuthURL(sessionId) authUrl, err := controller.auth.GetOAuthURL(sessionId)
if err != nil { if err != nil {
tlog.App.Error().Err(err).Msg("Failed to get OAuth URL") controller.log.App.Error().Err(err).Msg("Failed to get OAuth URL for session")
c.JSON(500, gin.H{ c.JSON(500, gin.H{
"status": 500, "status": 500,
"message": "Internal Server Error", "message": "Internal Server Error",
@@ -106,7 +105,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
return return
} }
c.SetCookie(controller.config.OAuthSessionCookieName, sessionId, int(time.Hour.Seconds()), "/", controller.getCookieDomain(), controller.config.SecureCookie, true) c.SetCookie(controller.runtime.OAuthSessionCookieName, sessionId, int(time.Hour.Seconds()), "/", controller.getCookieDomain(), controller.config.Auth.SecureCookie, true)
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"status": 200, "status": 200,
@@ -120,7 +119,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
err := c.BindUri(&req) err := c.BindUri(&req)
if err != nil { if err != nil {
tlog.App.Error().Err(err).Msg("Failed to bind URI") controller.log.App.Error().Err(err).Msg("Failed to bind URI")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
"status": 400, "status": 400,
"message": "Bad Request", "message": "Bad Request",
@@ -128,21 +127,21 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
return return
} }
sessionIdCookie, err := c.Cookie(controller.config.OAuthSessionCookieName) sessionIdCookie, err := c.Cookie(controller.runtime.OAuthSessionCookieName)
if err != nil { if err != nil {
tlog.App.Warn().Err(err).Msg("OAuth session cookie missing") controller.log.App.Error().Err(err).Msg("Failed to get OAuth session cookie")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
return return
} }
c.SetCookie(controller.config.OAuthSessionCookieName, "", -1, "/", controller.getCookieDomain(), controller.config.SecureCookie, true) c.SetCookie(controller.runtime.OAuthSessionCookieName, "", -1, "/", controller.getCookieDomain(), controller.config.Auth.SecureCookie, true)
oauthPendingSession, err := controller.auth.GetOAuthPendingSession(sessionIdCookie) oauthPendingSession, err := controller.auth.GetOAuthPendingSession(sessionIdCookie)
if err != nil { if err != nil {
tlog.App.Error().Err(err).Msg("Failed to get OAuth pending session") controller.log.App.Error().Err(err).Msg("Failed to get pending OAuth session")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
return return
} }
@@ -150,8 +149,8 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
state := c.Query("state") state := c.Query("state")
if state != oauthPendingSession.State { if state != oauthPendingSession.State {
tlog.App.Warn().Err(err).Msg("CSRF token mismatch") controller.log.App.Warn().Msg("OAuth state mismatch")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
return return
} }
@@ -159,74 +158,80 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
_, err = controller.auth.GetOAuthToken(sessionIdCookie, code) _, err = controller.auth.GetOAuthToken(sessionIdCookie, code)
if err != nil { if err != nil {
tlog.App.Error().Err(err).Msg("Failed to exchange code for token") controller.log.App.Error().Err(err).Msg("Failed to exchange code for token")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
return return
} }
user, err := controller.auth.GetOAuthUserinfo(sessionIdCookie) user, err := controller.auth.GetOAuthUserinfo(sessionIdCookie)
if err != nil { if err != nil {
tlog.App.Error().Err(err).Msg("Failed to get user info from OAuth provider") controller.log.App.Error().Err(err).Msg("Failed to get user info from OAuth provider")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
return
}
if user == nil {
controller.log.App.Warn().Msg("OAuth provider did not return user info")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
return return
} }
if user.Email == "" { if user.Email == "" {
tlog.App.Error().Msg("OAuth provider did not return an email") controller.log.App.Warn().Msg("OAuth provider did not return an email")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
return return
} }
if !controller.auth.IsEmailWhitelisted(user.Email) { if !controller.auth.IsEmailWhitelisted(user.Email) {
tlog.App.Warn().Str("email", user.Email).Msg("Email not whitelisted") controller.log.App.Warn().Str("email", user.Email).Msg("Email not whitelisted, denying access")
tlog.AuditLoginFailure(c, user.Email, req.Provider, "email not whitelisted") controller.log.AuditLoginFailure(user.Email, req.Provider, c.ClientIP(), "email not whitelisted")
queries, err := query.Values(UnauthorizedQuery{ queries, err := query.Values(UnauthorizedQuery{
Username: user.Email, Username: user.Email,
}) })
if err != nil { if err != nil {
tlog.App.Error().Err(err).Msg("Failed to encode unauthorized query") controller.log.App.Error().Err(err).Msg("Failed to encode unauthorized query")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
return return
} }
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", controller.config.AppURL, queries.Encode())) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", controller.runtime.AppURL, queries.Encode()))
return return
} }
var name string var name string
if strings.TrimSpace(user.Name) != "" { if strings.TrimSpace(user.Name) != "" {
tlog.App.Debug().Msg("Using name from OAuth provider") controller.log.App.Debug().Msg("Using name from OAuth provider")
name = user.Name name = user.Name
} else { } else {
tlog.App.Debug().Msg("No name from OAuth provider, using pseudo name") controller.log.App.Debug().Msg("No name from OAuth provider, generating from email")
name = fmt.Sprintf("%s (%s)", utils.Capitalize(strings.Split(user.Email, "@")[0]), strings.Split(user.Email, "@")[1]) name = fmt.Sprintf("%s (%s)", utils.Capitalize(strings.Split(user.Email, "@")[0]), strings.Split(user.Email, "@")[1])
} }
var username string var username string
if strings.TrimSpace(user.PreferredUsername) != "" { if strings.TrimSpace(user.PreferredUsername) != "" {
tlog.App.Debug().Msg("Using preferred username from OAuth provider") controller.log.App.Debug().Msg("Using preferred username from OAuth provider")
username = user.PreferredUsername username = user.PreferredUsername
} else { } else {
tlog.App.Debug().Msg("No preferred username from OAuth provider, using pseudo username") controller.log.App.Debug().Msg("No preferred username from OAuth provider, generating from email")
username = strings.Replace(user.Email, "@", "_", 1) username = strings.Replace(user.Email, "@", "_", 1)
} }
svc, err := controller.auth.GetOAuthService(sessionIdCookie) svc, err := controller.auth.GetOAuthService(sessionIdCookie)
if err != nil { if err != nil {
tlog.App.Error().Err(err).Msg("Failed to get OAuth service for session") controller.log.App.Error().Err(err).Msg("Failed to get OAuth service for session")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
return return
} }
if svc.ID() != req.Provider { if svc.ID() != req.Provider {
tlog.App.Error().Msgf("OAuth service ID mismatch: expected %s, got %s", svc.ID(), req.Provider) controller.log.App.Warn().Msgf("OAuth provider mismatch: expected %s, got %s", req.Provider, svc.ID())
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
return return
} }
@@ -240,29 +245,29 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
OAuthSub: user.Sub, OAuthSub: user.Sub,
} }
tlog.App.Trace().Interface("session_cookie", sessionCookie).Msg("Creating session cookie") controller.log.App.Debug().Msg("Creating session cookie for user")
cookie, err := controller.auth.CreateSession(c, sessionCookie) cookie, err := controller.auth.CreateSession(c, sessionCookie)
if err != nil { if err != nil {
tlog.App.Error().Err(err).Msg("Failed to create session cookie") controller.log.App.Error().Err(err).Msg("Failed to create session cookie")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
return return
} }
http.SetCookie(c.Writer, cookie) http.SetCookie(c.Writer, cookie)
tlog.AuditLoginSuccess(c, sessionCookie.Username, sessionCookie.Provider) controller.log.AuditLoginSuccess(sessionCookie.Username, sessionCookie.Provider, c.ClientIP())
if controller.isOidcRequest(oauthPendingSession.CallbackParams) { if controller.isOidcRequest(oauthPendingSession.CallbackParams) {
tlog.App.Debug().Msg("OIDC request, redirecting to authorize page") controller.log.App.Debug().Msg("OIDC request detected, redirecting to authorization endpoint with callback params")
queries, err := query.Values(oauthPendingSession.CallbackParams) queries, err := query.Values(oauthPendingSession.CallbackParams)
if err != nil { if err != nil {
tlog.App.Error().Err(err).Msg("Failed to encode OIDC callback query") controller.log.App.Error().Err(err).Msg("Failed to encode OIDC callback query")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
return return
} }
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/authorize?%s", controller.config.AppURL, queries.Encode())) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/authorize?%s", controller.runtime.AppURL, queries.Encode()))
return return
} }
@@ -272,16 +277,16 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
}) })
if err != nil { if err != nil {
tlog.App.Error().Err(err).Msg("Failed to encode redirect URI query") controller.log.App.Error().Err(err).Msg("Failed to encode redirect query")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
return return
} }
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/continue?%s", controller.config.AppURL, queries.Encode())) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/continue?%s", controller.runtime.AppURL, queries.Encode()))
return return
} }
c.Redirect(http.StatusTemporaryRedirect, controller.config.AppURL) c.Redirect(http.StatusTemporaryRedirect, controller.runtime.AppURL)
} }
func (controller *OAuthController) isOidcRequest(params service.OAuthURLParams) bool { func (controller *OAuthController) isOidcRequest(params service.OAuthURLParams) bool {
@@ -292,8 +297,8 @@ func (controller *OAuthController) isOidcRequest(params service.OAuthURLParams)
} }
func (controller *OAuthController) getCookieDomain() string { func (controller *OAuthController) getCookieDomain() string {
if controller.config.SubdomainsEnabled { if controller.config.Auth.SubdomainsEnabled {
return "." + controller.config.CookieDomain return "." + controller.runtime.CookieDomain
} }
return controller.config.CookieDomain return controller.runtime.CookieDomain
} }
+74 -55
View File
@@ -13,15 +13,13 @@ import (
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
) )
type OIDCControllerConfig struct{}
type OIDCController struct { type OIDCController struct {
config OIDCControllerConfig log *logger.Logger
router *gin.RouterGroup oidc *service.OIDCService
oidc *service.OIDCService runtime model.RuntimeConfig
} }
type AuthorizeCallback struct { type AuthorizeCallback struct {
@@ -58,29 +56,42 @@ type ClientCredentials struct {
ClientSecret string ClientSecret string
} }
func NewOIDCController(config OIDCControllerConfig, oidcService *service.OIDCService, router *gin.RouterGroup) *OIDCController { func NewOIDCController(
return &OIDCController{ log *logger.Logger,
config: config, oidcService *service.OIDCService,
oidc: oidcService, runtimeConfig model.RuntimeConfig,
router: router, router *gin.RouterGroup) *OIDCController {
controller := &OIDCController{
log: log,
oidc: oidcService,
runtime: runtimeConfig,
} }
}
func (controller *OIDCController) SetupRoutes() { oidcGroup := router.Group("/oidc")
oidcGroup := controller.router.Group("/oidc")
oidcGroup.GET("/clients/:id", controller.GetClientInfo) oidcGroup.GET("/clients/:id", controller.GetClientInfo)
oidcGroup.POST("/authorize", controller.Authorize) oidcGroup.POST("/authorize", controller.Authorize)
oidcGroup.POST("/token", controller.Token) oidcGroup.POST("/token", controller.Token)
oidcGroup.GET("/userinfo", controller.Userinfo) oidcGroup.GET("/userinfo", controller.Userinfo)
oidcGroup.POST("/userinfo", controller.Userinfo) oidcGroup.POST("/userinfo", controller.Userinfo)
return controller
} }
func (controller *OIDCController) GetClientInfo(c *gin.Context) { func (controller *OIDCController) GetClientInfo(c *gin.Context) {
if controller.oidc == nil {
controller.log.App.Warn().Msg("Received OIDC client info request but OIDC server is not configured")
c.JSON(500, gin.H{
"status": 500,
"message": "OIDC not configured",
})
return
}
var req ClientRequest var req ClientRequest
err := c.BindUri(&req) err := c.BindUri(&req)
if err != nil { if err != nil {
tlog.App.Error().Err(err).Msg("Failed to bind URI") controller.log.App.Error().Err(err).Msg("Failed to bind URI")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
"status": 400, "status": 400,
"message": "Bad Request", "message": "Bad Request",
@@ -91,7 +102,7 @@ func (controller *OIDCController) GetClientInfo(c *gin.Context) {
client, ok := controller.oidc.GetClient(req.ClientID) client, ok := controller.oidc.GetClient(req.ClientID)
if !ok { if !ok {
tlog.App.Warn().Str("client_id", req.ClientID).Msg("Client not found") controller.log.App.Warn().Str("clientId", req.ClientID).Msg("Client not found")
c.JSON(404, gin.H{ c.JSON(404, gin.H{
"status": 404, "status": 404,
"message": "Client not found", "message": "Client not found",
@@ -107,7 +118,7 @@ func (controller *OIDCController) GetClientInfo(c *gin.Context) {
} }
func (controller *OIDCController) Authorize(c *gin.Context) { func (controller *OIDCController) Authorize(c *gin.Context) {
if !controller.oidc.IsConfigured() { if controller.oidc == nil {
controller.authorizeError(c, errors.New("err_oidc_not_configured"), "OIDC not configured", "This instance is not configured for OIDC", "", "", "") controller.authorizeError(c, errors.New("err_oidc_not_configured"), "OIDC not configured", "This instance is not configured for OIDC", "", "", "")
return return
} }
@@ -142,7 +153,7 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
err = controller.oidc.ValidateAuthorizeParams(req) err = controller.oidc.ValidateAuthorizeParams(req)
if err != nil { if err != nil {
tlog.App.Error().Err(err).Msg("Failed to validate authorize params") controller.log.App.Warn().Err(err).Msg("Failed to validate authorize params")
if err.Error() != "invalid_request_uri" { if err.Error() != "invalid_request_uri" {
controller.authorizeError(c, err, "Failed validate authorize params", "Invalid request parameters", req.RedirectURI, err.Error(), req.State) controller.authorizeError(c, err, "Failed validate authorize params", "Invalid request parameters", req.RedirectURI, err.Error(), req.State)
return return
@@ -174,7 +185,7 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
err = controller.oidc.StoreUserinfo(c, sub, *userContext, req) err = controller.oidc.StoreUserinfo(c, sub, *userContext, req)
if err != nil { if err != nil {
tlog.App.Error().Err(err).Msg("Failed to insert user info into database") controller.log.App.Error().Err(err).Msg("Failed to store user info")
controller.authorizeError(c, err, "Failed to store user info", "Failed to store user info", req.RedirectURI, "server_error", req.State) controller.authorizeError(c, err, "Failed to store user info", "Failed to store user info", req.RedirectURI, "server_error", req.State)
return return
} }
@@ -197,10 +208,10 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
} }
func (controller *OIDCController) Token(c *gin.Context) { func (controller *OIDCController) Token(c *gin.Context) {
if !controller.oidc.IsConfigured() { if controller.oidc == nil {
tlog.App.Warn().Msg("OIDC not configured") controller.log.App.Warn().Msg("Received OIDC request but OIDC server is not configured")
c.JSON(404, gin.H{ c.JSON(500, gin.H{
"error": "not_found", "error": "server_error",
}) })
return return
} }
@@ -209,7 +220,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
err := c.Bind(&req) err := c.Bind(&req)
if err != nil { if err != nil {
tlog.App.Error().Err(err).Msg("Failed to bind token request") controller.log.App.Warn().Err(err).Msg("Failed to bind token request")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
"error": "invalid_request", "error": "invalid_request",
}) })
@@ -218,7 +229,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
err = controller.oidc.ValidateGrantType(req.GrantType) err = controller.oidc.ValidateGrantType(req.GrantType)
if err != nil { if err != nil {
tlog.App.Warn().Str("grant_type", req.GrantType).Msg("Unsupported grant type") controller.log.App.Warn().Err(err).Msg("Invalid grant type")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
"error": err.Error(), "error": err.Error(),
}) })
@@ -233,12 +244,12 @@ func (controller *OIDCController) Token(c *gin.Context) {
// If it fails, we try basic auth // If it fails, we try basic auth
if creds.ClientID == "" || creds.ClientSecret == "" { if creds.ClientID == "" || creds.ClientSecret == "" {
tlog.App.Debug().Msg("Tried form values and they are empty, trying basic auth") controller.log.App.Debug().Msg("Client credentials not found in form, trying basic auth")
clientId, clientSecret, ok := c.Request.BasicAuth() clientId, clientSecret, ok := c.Request.BasicAuth()
if !ok { if !ok {
tlog.App.Error().Msg("Missing authorization header") controller.log.App.Warn().Msg("Client credentials not found in basic auth")
c.Header("www-authenticate", `Basic realm="Tinyauth OIDC Token Endpoint"`) c.Header("www-authenticate", `Basic realm="Tinyauth OIDC Token Endpoint"`)
c.JSON(400, gin.H{ c.JSON(400, gin.H{
"error": "invalid_client", "error": "invalid_client",
@@ -255,7 +266,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
client, ok := controller.oidc.GetClient(creds.ClientID) client, ok := controller.oidc.GetClient(creds.ClientID)
if !ok { if !ok {
tlog.App.Warn().Str("client_id", creds.ClientID).Msg("Client not found") controller.log.App.Warn().Str("clientId", creds.ClientID).Msg("Client not found")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
"error": "invalid_client", "error": "invalid_client",
}) })
@@ -263,7 +274,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
} }
if client.ClientSecret != creds.ClientSecret { if client.ClientSecret != creds.ClientSecret {
tlog.App.Warn().Str("client_id", creds.ClientID).Msg("Invalid client secret") controller.log.App.Warn().Str("clientId", creds.ClientID).Msg("Invalid client secret")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
"error": "invalid_client", "error": "invalid_client",
}) })
@@ -277,30 +288,30 @@ func (controller *OIDCController) Token(c *gin.Context) {
entry, err := controller.oidc.GetCodeEntry(c, controller.oidc.Hash(req.Code), client.ClientID) entry, err := controller.oidc.GetCodeEntry(c, controller.oidc.Hash(req.Code), client.ClientID)
if err != nil { if err != nil {
if err := controller.oidc.DeleteTokenByCodeHash(c, controller.oidc.Hash(req.Code)); err != nil { if err := controller.oidc.DeleteTokenByCodeHash(c, controller.oidc.Hash(req.Code)); err != nil {
tlog.App.Error().Err(err).Msg("Failed to delete access token by code hash") controller.log.App.Error().Err(err).Msg("Failed to delete code")
} }
if errors.Is(err, service.ErrCodeNotFound) { if errors.Is(err, service.ErrCodeNotFound) {
tlog.App.Warn().Msg("Code not found") controller.log.App.Warn().Msg("Code not found")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
"error": "invalid_grant", "error": "invalid_grant",
}) })
return return
} }
if errors.Is(err, service.ErrCodeExpired) { if errors.Is(err, service.ErrCodeExpired) {
tlog.App.Warn().Msg("Code expired") controller.log.App.Warn().Msg("Code expired")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
"error": "invalid_grant", "error": "invalid_grant",
}) })
return return
} }
if errors.Is(err, service.ErrInvalidClient) { if errors.Is(err, service.ErrInvalidClient) {
tlog.App.Warn().Msg("Invalid client ID") controller.log.App.Warn().Msg("Code does not belong to client")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
"error": "invalid_client", "error": "invalid_client",
}) })
return return
} }
tlog.App.Warn().Err(err).Msg("Failed to get OIDC code entry") controller.log.App.Error().Err(err).Msg("Failed to get code entry")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
"error": "server_error", "error": "server_error",
}) })
@@ -308,7 +319,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
} }
if entry.RedirectURI != req.RedirectURI { if entry.RedirectURI != req.RedirectURI {
tlog.App.Warn().Str("redirect_uri", req.RedirectURI).Msg("Redirect URI mismatch") controller.log.App.Warn().Msg("Redirect URI does not match")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
"error": "invalid_grant", "error": "invalid_grant",
}) })
@@ -318,7 +329,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
ok := controller.oidc.ValidatePKCE(entry.CodeChallenge, req.CodeVerifier) ok := controller.oidc.ValidatePKCE(entry.CodeChallenge, req.CodeVerifier)
if !ok { if !ok {
tlog.App.Warn().Msg("PKCE validation failed") controller.log.App.Warn().Msg("PKCE validation failed")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
"error": "invalid_grant", "error": "invalid_grant",
}) })
@@ -328,7 +339,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
tokenRes, err := controller.oidc.GenerateAccessToken(c, client, entry) tokenRes, err := controller.oidc.GenerateAccessToken(c, client, entry)
if err != nil { if err != nil {
tlog.App.Error().Err(err).Msg("Failed to generate access token") controller.log.App.Error().Err(err).Msg("Failed to generate access token")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
"error": "server_error", "error": "server_error",
}) })
@@ -341,7 +352,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
if err != nil { if err != nil {
if errors.Is(err, service.ErrTokenExpired) { if errors.Is(err, service.ErrTokenExpired) {
tlog.App.Error().Err(err).Msg("Refresh token expired") controller.log.App.Warn().Msg("Refresh token expired")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
"error": "invalid_grant", "error": "invalid_grant",
}) })
@@ -349,14 +360,14 @@ func (controller *OIDCController) Token(c *gin.Context) {
} }
if errors.Is(err, service.ErrInvalidClient) { if errors.Is(err, service.ErrInvalidClient) {
tlog.App.Error().Err(err).Msg("Invalid client") controller.log.App.Warn().Msg("Refresh token does not belong to client")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
"error": "invalid_grant", "error": "invalid_grant",
}) })
return return
} }
tlog.App.Error().Err(err).Msg("Failed to refresh access token") controller.log.App.Error().Err(err).Msg("Failed to refresh access token")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
"error": "server_error", "error": "server_error",
}) })
@@ -373,10 +384,10 @@ func (controller *OIDCController) Token(c *gin.Context) {
} }
func (controller *OIDCController) Userinfo(c *gin.Context) { func (controller *OIDCController) Userinfo(c *gin.Context) {
if !controller.oidc.IsConfigured() { if controller.oidc == nil {
tlog.App.Warn().Msg("OIDC not configured") controller.log.App.Warn().Msg("Received OIDC userinfo request but OIDC server is not configured")
c.JSON(404, gin.H{ c.JSON(500, gin.H{
"error": "not_found", "error": "server_error",
}) })
return return
} }
@@ -387,7 +398,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
if authorization != "" { if authorization != "" {
tokenType, bearerToken, ok := strings.Cut(authorization, " ") tokenType, bearerToken, ok := strings.Cut(authorization, " ")
if !ok { if !ok {
tlog.App.Warn().Msg("OIDC userinfo accessed with malformed authorization header") controller.log.App.Warn().Msg("OIDC userinfo accessed with invalid authorization header")
c.JSON(401, gin.H{ c.JSON(401, gin.H{
"error": "invalid_request", "error": "invalid_request",
}) })
@@ -395,7 +406,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
} }
if strings.ToLower(tokenType) != "bearer" { if strings.ToLower(tokenType) != "bearer" {
tlog.App.Warn().Msg("OIDC userinfo accessed with invalid token type") controller.log.App.Warn().Msg("OIDC userinfo accessed with non-bearer token")
c.JSON(401, gin.H{ c.JSON(401, gin.H{
"error": "invalid_request", "error": "invalid_request",
}) })
@@ -405,7 +416,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
token = bearerToken token = bearerToken
} else if c.Request.Method == http.MethodPost { } else if c.Request.Method == http.MethodPost {
if c.ContentType() != "application/x-www-form-urlencoded" { if c.ContentType() != "application/x-www-form-urlencoded" {
tlog.App.Warn().Msg("OIDC userinfo POST accessed with invalid content type") controller.log.App.Warn().Msg("OIDC userinfo POST accessed with invalid content type")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
"error": "invalid_request", "error": "invalid_request",
}) })
@@ -413,14 +424,14 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
} }
token = c.PostForm("access_token") token = c.PostForm("access_token")
if token == "" { if token == "" {
tlog.App.Warn().Msg("OIDC userinfo POST accessed without access_token in body") controller.log.App.Warn().Msg("OIDC userinfo POST accessed without access_token")
c.JSON(401, gin.H{ c.JSON(401, gin.H{
"error": "invalid_request", "error": "invalid_request",
}) })
return return
} }
} else { } else {
tlog.App.Warn().Msg("OIDC userinfo accessed without authorization header") controller.log.App.Warn().Msg("OIDC userinfo accessed without authorization header or POST body")
c.JSON(401, gin.H{ c.JSON(401, gin.H{
"error": "invalid_request", "error": "invalid_request",
}) })
@@ -431,14 +442,14 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
if err != nil { if err != nil {
if errors.Is(err, service.ErrTokenNotFound) { if errors.Is(err, service.ErrTokenNotFound) {
tlog.App.Warn().Msg("OIDC userinfo accessed with invalid token") controller.log.App.Warn().Msg("OIDC userinfo accessed with invalid token")
c.JSON(401, gin.H{ c.JSON(401, gin.H{
"error": "invalid_grant", "error": "invalid_grant",
}) })
return return
} }
tlog.App.Err(err).Msg("Failed to get token entry") controller.log.App.Error().Err(err).Msg("Failed to get access token")
c.JSON(401, gin.H{ c.JSON(401, gin.H{
"error": "server_error", "error": "server_error",
}) })
@@ -447,7 +458,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
// If we don't have the openid scope, return an error // If we don't have the openid scope, return an error
if !slices.Contains(strings.Split(entry.Scope, ","), "openid") { if !slices.Contains(strings.Split(entry.Scope, ","), "openid") {
tlog.App.Warn().Msg("OIDC userinfo accessed without openid scope") controller.log.App.Warn().Msg("OIDC userinfo accessed with token missing openid scope")
c.JSON(401, gin.H{ c.JSON(401, gin.H{
"error": "invalid_scope", "error": "invalid_scope",
}) })
@@ -457,7 +468,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
user, err := controller.oidc.GetUserinfo(c, entry.Sub) user, err := controller.oidc.GetUserinfo(c, entry.Sub)
if err != nil { if err != nil {
tlog.App.Err(err).Msg("Failed to get user entry") controller.log.App.Error().Err(err).Msg("Failed to get user info")
c.JSON(401, gin.H{ c.JSON(401, gin.H{
"error": "server_error", "error": "server_error",
}) })
@@ -468,7 +479,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
} }
func (controller *OIDCController) authorizeError(c *gin.Context, err error, reason string, reasonUser string, callback string, callbackError string, state string) { func (controller *OIDCController) authorizeError(c *gin.Context, err error, reason string, reasonUser string, callback string, callbackError string, state string) {
tlog.App.Error().Err(err).Msg(reason) controller.log.App.Warn().Err(err).Str("reason", reason).Msg("Authorization error")
if callback != "" { if callback != "" {
errorQueries := CallbackError{ errorQueries := CallbackError{
@@ -508,8 +519,16 @@ func (controller *OIDCController) authorizeError(c *gin.Context, err error, reas
return return
} }
redirectUrl := ""
if controller.oidc != nil {
redirectUrl = fmt.Sprintf("%s/error?%s", controller.oidc.GetIssuer(), queries.Encode())
} else {
redirectUrl = fmt.Sprintf("%s/error?%s", controller.runtime.AppURL, queries.Encode())
}
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"status": 200, "status": 200,
"redirect_uri": fmt.Sprintf("%s/error?%s", controller.oidc.GetIssuer(), queries.Encode()), "redirect_uri": redirectUrl,
}) })
} }
+62 -85
View File
@@ -1,47 +1,33 @@
package controller_test package controller_test
import ( import (
"context"
"crypto/sha256" "crypto/sha256"
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
"path"
"strings" "strings"
"sync"
"testing" "testing"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/google/go-querystring/query" "github.com/google/go-querystring/query"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/bootstrap"
"github.com/tinyauthapp/tinyauth/internal/controller" "github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/repository/memory"
"github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog" "github.com/tinyauthapp/tinyauth/internal/test"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
) )
func TestOIDCController(t *testing.T) { func TestOIDCController(t *testing.T) {
tlog.NewTestLogger().Init() log := logger.NewLogger().WithTestConfig()
tempDir := t.TempDir() log.Init()
oidcServiceCfg := service.OIDCServiceConfig{ cfg, runtime := test.CreateTestConfigs(t)
Clients: map[string]model.OIDCClientConfig{
"test": {
ClientID: "some-client-id",
ClientSecret: "some-client-secret",
TrustedRedirectURIs: []string{"https://test.example.com/callback"},
Name: "Test Client",
},
},
PrivateKeyPath: path.Join(tempDir, "key.pem"),
PublicKeyPath: path.Join(tempDir, "key.pub"),
Issuer: "https://tinyauth.example.com",
SessionExpiry: 500,
}
controllerCfg := controller.OIDCControllerConfig{}
simpleCtx := func(c *gin.Context) { simpleCtx := func(c *gin.Context) {
c.Set("context", &model.UserContext{ c.Set("context", &model.UserContext{
@@ -103,7 +89,7 @@ func TestOIDCController(t *testing.T) {
var res map[string]any var res map[string]any
err := json.Unmarshal(recorder.Body.Bytes(), &res) err := json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, res["redirect_uri"], "https://tinyauth.example.com/error?error=User+is+not+logged+in+or+the+session+is+invalid") assert.Equal(t, res["redirect_uri"], "https://tinyauth.example.com/error?error=User+is+not+logged+in+or+the+session+is+invalid")
}, },
@@ -123,7 +109,7 @@ func TestOIDCController(t *testing.T) {
Nonce: "some-nonce", Nonce: "some-nonce",
} }
reqBodyBytes, err := json.Marshal(reqBody) reqBodyBytes, err := json.Marshal(reqBody)
assert.NoError(t, err) require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes))) req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes)))
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
@@ -131,7 +117,7 @@ func TestOIDCController(t *testing.T) {
var res map[string]any var res map[string]any
err = json.Unmarshal(recorder.Body.Bytes(), &res) err = json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, res["redirect_uri"], "https://test.example.com/callback?error=unsupported_response_type&error_description=Invalid+request+parameters&state=some-state") assert.Equal(t, res["redirect_uri"], "https://test.example.com/callback?error=unsupported_response_type&error_description=Invalid+request+parameters&state=some-state")
}, },
@@ -151,7 +137,7 @@ func TestOIDCController(t *testing.T) {
Nonce: "some-nonce", Nonce: "some-nonce",
} }
reqBodyBytes, err := json.Marshal(reqBody) reqBodyBytes, err := json.Marshal(reqBody)
assert.NoError(t, err) require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes))) req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes)))
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
@@ -160,11 +146,11 @@ func TestOIDCController(t *testing.T) {
var res map[string]any var res map[string]any
err = json.Unmarshal(recorder.Body.Bytes(), &res) err = json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err) require.NoError(t, err)
redirectURI := res["redirect_uri"].(string) redirectURI := res["redirect_uri"].(string)
url, err := url.Parse(redirectURI) url, err := url.Parse(redirectURI)
assert.NoError(t, err) require.NoError(t, err)
queryParams := url.Query() queryParams := url.Query()
assert.Equal(t, queryParams.Get("state"), "some-state") assert.Equal(t, queryParams.Get("state"), "some-state")
@@ -183,7 +169,7 @@ func TestOIDCController(t *testing.T) {
RedirectURI: "https://test.example.com/callback", RedirectURI: "https://test.example.com/callback",
} }
reqBodyEncoded, err := query.Values(reqBody) reqBodyEncoded, err := query.Values(reqBody)
assert.NoError(t, err) require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode())) req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
@@ -191,7 +177,7 @@ func TestOIDCController(t *testing.T) {
var res map[string]any var res map[string]any
err = json.Unmarshal(recorder.Body.Bytes(), &res) err = json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, res["error"], "unsupported_grant_type") assert.Equal(t, res["error"], "unsupported_grant_type")
}, },
@@ -206,7 +192,7 @@ func TestOIDCController(t *testing.T) {
RedirectURI: "https://test.example.com/callback", RedirectURI: "https://test.example.com/callback",
} }
reqBodyEncoded, err := query.Values(reqBody) reqBodyEncoded, err := query.Values(reqBody)
assert.NoError(t, err) require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode())) req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
@@ -244,7 +230,7 @@ func TestOIDCController(t *testing.T) {
RedirectURI: "https://test.example.com/callback", RedirectURI: "https://test.example.com/callback",
} }
reqBodyEncoded, err := query.Values(reqBody) reqBodyEncoded, err := query.Values(reqBody)
assert.NoError(t, err) require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode())) req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
@@ -267,11 +253,11 @@ func TestOIDCController(t *testing.T) {
var authorizeRes map[string]any var authorizeRes map[string]any
err := json.Unmarshal(authorizeTestRecorder.Body.Bytes(), &authorizeRes) err := json.Unmarshal(authorizeTestRecorder.Body.Bytes(), &authorizeRes)
assert.NoError(t, err) require.NoError(t, err)
redirectURI := authorizeRes["redirect_uri"].(string) redirectURI := authorizeRes["redirect_uri"].(string)
url, err := url.Parse(redirectURI) url, err := url.Parse(redirectURI)
assert.NoError(t, err) require.NoError(t, err)
queryParams := url.Query() queryParams := url.Query()
code := queryParams.Get("code") code := queryParams.Get("code")
@@ -283,7 +269,7 @@ func TestOIDCController(t *testing.T) {
RedirectURI: "https://test.example.com/callback", RedirectURI: "https://test.example.com/callback",
} }
reqBodyEncoded, err := query.Values(reqBody) reqBodyEncoded, err := query.Values(reqBody)
assert.NoError(t, err) require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode())) req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
@@ -306,7 +292,7 @@ func TestOIDCController(t *testing.T) {
var tokenRes map[string]any var tokenRes map[string]any
err := json.Unmarshal(tokenRecorder.Body.Bytes(), &tokenRes) err := json.Unmarshal(tokenRecorder.Body.Bytes(), &tokenRes)
assert.NoError(t, err) require.NoError(t, err)
_, ok := tokenRes["refresh_token"] _, ok := tokenRes["refresh_token"]
assert.True(t, ok, "Expected refresh token in response") assert.True(t, ok, "Expected refresh token in response")
@@ -320,7 +306,7 @@ func TestOIDCController(t *testing.T) {
ClientSecret: "some-client-secret", ClientSecret: "some-client-secret",
} }
reqBodyEncoded, err := query.Values(reqBody) reqBodyEncoded, err := query.Values(reqBody)
assert.NoError(t, err) require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode())) req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
@@ -332,7 +318,7 @@ func TestOIDCController(t *testing.T) {
assert.Equal(t, 200, recorder.Code) assert.Equal(t, 200, recorder.Code)
var refreshRes map[string]any var refreshRes map[string]any
err = json.Unmarshal(recorder.Body.Bytes(), &refreshRes) err = json.Unmarshal(recorder.Body.Bytes(), &refreshRes)
assert.NoError(t, err) require.NoError(t, err)
_, ok = refreshRes["access_token"] _, ok = refreshRes["access_token"]
assert.True(t, ok, "Expected access token in refresh response") assert.True(t, ok, "Expected access token in refresh response")
@@ -353,11 +339,11 @@ func TestOIDCController(t *testing.T) {
var authorizeRes map[string]any var authorizeRes map[string]any
err := json.Unmarshal(authorizeTestRecorder.Body.Bytes(), &authorizeRes) err := json.Unmarshal(authorizeTestRecorder.Body.Bytes(), &authorizeRes)
assert.NoError(t, err) require.NoError(t, err)
redirectURI := authorizeRes["redirect_uri"].(string) redirectURI := authorizeRes["redirect_uri"].(string)
url, err := url.Parse(redirectURI) url, err := url.Parse(redirectURI)
assert.NoError(t, err) require.NoError(t, err)
queryParams := url.Query() queryParams := url.Query()
code := queryParams.Get("code") code := queryParams.Get("code")
@@ -369,7 +355,7 @@ func TestOIDCController(t *testing.T) {
RedirectURI: "https://test.example.com/callback", RedirectURI: "https://test.example.com/callback",
} }
reqBodyEncoded, err := query.Values(reqBody) reqBodyEncoded, err := query.Values(reqBody)
assert.NoError(t, err) require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode())) req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
@@ -389,7 +375,7 @@ func TestOIDCController(t *testing.T) {
var secondRes map[string]any var secondRes map[string]any
err = json.Unmarshal(secondRecorder.Body.Bytes(), &secondRes) err = json.Unmarshal(secondRecorder.Body.Bytes(), &secondRes)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "invalid_grant", secondRes["error"]) assert.Equal(t, "invalid_grant", secondRes["error"])
}, },
@@ -417,7 +403,7 @@ func TestOIDCController(t *testing.T) {
var tokenRes map[string]any var tokenRes map[string]any
err := json.Unmarshal(tokenRecorder.Body.Bytes(), &tokenRes) err := json.Unmarshal(tokenRecorder.Body.Bytes(), &tokenRes)
assert.NoError(t, err) require.NoError(t, err)
accessToken := tokenRes["access_token"].(string) accessToken := tokenRes["access_token"].(string)
assert.NotEmpty(t, accessToken) assert.NotEmpty(t, accessToken)
@@ -429,7 +415,7 @@ func TestOIDCController(t *testing.T) {
var userInfoRes map[string]any var userInfoRes map[string]any
err = json.Unmarshal(recorder.Body.Bytes(), &userInfoRes) err = json.Unmarshal(recorder.Body.Bytes(), &userInfoRes)
assert.NoError(t, err) require.NoError(t, err)
_, ok := userInfoRes["sub"] _, ok := userInfoRes["sub"]
assert.True(t, ok, "Expected sub claim in userinfo response") assert.True(t, ok, "Expected sub claim in userinfo response")
@@ -449,7 +435,7 @@ func TestOIDCController(t *testing.T) {
var res map[string]any var res map[string]any
err := json.Unmarshal(recorder.Body.Bytes(), &res) err := json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "invalid_request", res["error"]) assert.Equal(t, "invalid_request", res["error"])
}, },
}, },
@@ -464,7 +450,7 @@ func TestOIDCController(t *testing.T) {
var res map[string]any var res map[string]any
err := json.Unmarshal(recorder.Body.Bytes(), &res) err := json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "invalid_request", res["error"]) assert.Equal(t, "invalid_request", res["error"])
}, },
}, },
@@ -479,7 +465,7 @@ func TestOIDCController(t *testing.T) {
var res map[string]any var res map[string]any
err := json.Unmarshal(recorder.Body.Bytes(), &res) err := json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "invalid_request", res["error"]) assert.Equal(t, "invalid_request", res["error"])
}, },
}, },
@@ -494,7 +480,7 @@ func TestOIDCController(t *testing.T) {
var res map[string]any var res map[string]any
err := json.Unmarshal(recorder.Body.Bytes(), &res) err := json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "invalid_grant", res["error"]) assert.Equal(t, "invalid_grant", res["error"])
}, },
}, },
@@ -509,7 +495,7 @@ func TestOIDCController(t *testing.T) {
var res map[string]any var res map[string]any
err := json.Unmarshal(recorder.Body.Bytes(), &res) err := json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "invalid_request", res["error"]) assert.Equal(t, "invalid_request", res["error"])
}, },
}, },
@@ -524,7 +510,7 @@ func TestOIDCController(t *testing.T) {
var res map[string]any var res map[string]any
err := json.Unmarshal(recorder.Body.Bytes(), &res) err := json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "invalid_request", res["error"]) assert.Equal(t, "invalid_request", res["error"])
}, },
}, },
@@ -541,7 +527,7 @@ func TestOIDCController(t *testing.T) {
var tokenRes map[string]any var tokenRes map[string]any
err := json.Unmarshal(tokenRecorder.Body.Bytes(), &tokenRes) err := json.Unmarshal(tokenRecorder.Body.Bytes(), &tokenRes)
assert.NoError(t, err) require.NoError(t, err)
accessToken := tokenRes["access_token"].(string) accessToken := tokenRes["access_token"].(string)
assert.NotEmpty(t, accessToken) assert.NotEmpty(t, accessToken)
@@ -555,7 +541,7 @@ func TestOIDCController(t *testing.T) {
var userInfoRes map[string]any var userInfoRes map[string]any
err = json.Unmarshal(recorder.Body.Bytes(), &userInfoRes) err = json.Unmarshal(recorder.Body.Bytes(), &userInfoRes)
assert.NoError(t, err) require.NoError(t, err)
_, ok := userInfoRes["sub"] _, ok := userInfoRes["sub"]
assert.True(t, ok, "Expected sub claim in userinfo response") assert.True(t, ok, "Expected sub claim in userinfo response")
@@ -579,7 +565,7 @@ func TestOIDCController(t *testing.T) {
CodeChallengeMethod: "", CodeChallengeMethod: "",
} }
reqBodyBytes, err := json.Marshal(reqBody) reqBodyBytes, err := json.Marshal(reqBody)
assert.NoError(t, err) require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes))) req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes)))
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
@@ -588,11 +574,11 @@ func TestOIDCController(t *testing.T) {
var res map[string]any var res map[string]any
err = json.Unmarshal(recorder.Body.Bytes(), &res) err = json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err) require.NoError(t, err)
redirectURI := res["redirect_uri"].(string) redirectURI := res["redirect_uri"].(string)
url, err := url.Parse(redirectURI) url, err := url.Parse(redirectURI)
assert.NoError(t, err) require.NoError(t, err)
queryParams := url.Query() queryParams := url.Query()
assert.Equal(t, queryParams.Get("state"), "some-state") assert.Equal(t, queryParams.Get("state"), "some-state")
@@ -609,7 +595,7 @@ func TestOIDCController(t *testing.T) {
CodeVerifier: "some-challenge", CodeVerifier: "some-challenge",
} }
reqBodyEncoded, err := query.Values(tokenReqBody) reqBodyEncoded, err := query.Values(tokenReqBody)
assert.NoError(t, err) require.NoError(t, err)
req = httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode())) req = httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
@@ -640,7 +626,7 @@ func TestOIDCController(t *testing.T) {
CodeChallengeMethod: "S256", CodeChallengeMethod: "S256",
} }
reqBodyBytes, err := json.Marshal(reqBody) reqBodyBytes, err := json.Marshal(reqBody)
assert.NoError(t, err) require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes))) req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes)))
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
@@ -649,11 +635,11 @@ func TestOIDCController(t *testing.T) {
var res map[string]any var res map[string]any
err = json.Unmarshal(recorder.Body.Bytes(), &res) err = json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err) require.NoError(t, err)
redirectURI := res["redirect_uri"].(string) redirectURI := res["redirect_uri"].(string)
url, err := url.Parse(redirectURI) url, err := url.Parse(redirectURI)
assert.NoError(t, err) require.NoError(t, err)
queryParams := url.Query() queryParams := url.Query()
assert.Equal(t, queryParams.Get("state"), "some-state") assert.Equal(t, queryParams.Get("state"), "some-state")
@@ -670,7 +656,7 @@ func TestOIDCController(t *testing.T) {
CodeVerifier: "some-challenge", CodeVerifier: "some-challenge",
} }
reqBodyEncoded, err := query.Values(tokenReqBody) reqBodyEncoded, err := query.Values(tokenReqBody)
assert.NoError(t, err) require.NoError(t, err)
req = httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode())) req = httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
@@ -701,7 +687,7 @@ func TestOIDCController(t *testing.T) {
CodeChallengeMethod: "S256", CodeChallengeMethod: "S256",
} }
reqBodyBytes, err := json.Marshal(reqBody) reqBodyBytes, err := json.Marshal(reqBody)
assert.NoError(t, err) require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes))) req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes)))
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
@@ -710,11 +696,11 @@ func TestOIDCController(t *testing.T) {
var res map[string]any var res map[string]any
err = json.Unmarshal(recorder.Body.Bytes(), &res) err = json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err) require.NoError(t, err)
redirectURI := res["redirect_uri"].(string) redirectURI := res["redirect_uri"].(string)
url, err := url.Parse(redirectURI) url, err := url.Parse(redirectURI)
assert.NoError(t, err) require.NoError(t, err)
queryParams := url.Query() queryParams := url.Query()
assert.Equal(t, queryParams.Get("state"), "some-state") assert.Equal(t, queryParams.Get("state"), "some-state")
@@ -731,7 +717,7 @@ func TestOIDCController(t *testing.T) {
CodeVerifier: "some-challenge-1", CodeVerifier: "some-challenge-1",
} }
reqBodyEncoded, err := query.Values(tokenReqBody) reqBodyEncoded, err := query.Values(tokenReqBody)
assert.NoError(t, err) require.NoError(t, err)
req = httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode())) req = httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
@@ -762,7 +748,7 @@ func TestOIDCController(t *testing.T) {
CodeChallengeMethod: "foo", CodeChallengeMethod: "foo",
} }
reqBodyBytes, err := json.Marshal(reqBody) reqBodyBytes, err := json.Marshal(reqBody)
assert.NoError(t, err) require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes))) req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes)))
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
@@ -771,11 +757,11 @@ func TestOIDCController(t *testing.T) {
var res map[string]any var res map[string]any
err = json.Unmarshal(recorder.Body.Bytes(), &res) err = json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err) require.NoError(t, err)
redirectURI := res["redirect_uri"].(string) redirectURI := res["redirect_uri"].(string)
url, err := url.Parse(redirectURI) url, err := url.Parse(redirectURI)
assert.NoError(t, err) require.NoError(t, err)
queryParams := url.Query() queryParams := url.Query()
error := queryParams.Get("error") error := queryParams.Get("error")
@@ -794,11 +780,11 @@ func TestOIDCController(t *testing.T) {
var res map[string]any var res map[string]any
err := json.Unmarshal(recorder.Body.Bytes(), &res) err := json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err) require.NoError(t, err)
redirectURI := res["redirect_uri"].(string) redirectURI := res["redirect_uri"].(string)
url, err := url.Parse(redirectURI) url, err := url.Parse(redirectURI)
assert.NoError(t, err) require.NoError(t, err)
queryParams := url.Query() queryParams := url.Query()
code := queryParams.Get("code") code := queryParams.Get("code")
@@ -810,7 +796,7 @@ func TestOIDCController(t *testing.T) {
RedirectURI: "https://test.example.com/callback", RedirectURI: "https://test.example.com/callback",
} }
reqBodyEncoded, err := query.Values(reqBody) reqBodyEncoded, err := query.Values(reqBody)
assert.NoError(t, err) require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode())) req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
@@ -821,7 +807,7 @@ func TestOIDCController(t *testing.T) {
assert.Equal(t, 200, recorder.Code) assert.Equal(t, 200, recorder.Code)
err = json.Unmarshal(recorder.Body.Bytes(), &res) err = json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err) require.NoError(t, err)
accessToken := res["access_token"].(string) accessToken := res["access_token"].(string)
assert.NotEmpty(t, accessToken) assert.NotEmpty(t, accessToken)
@@ -846,20 +832,17 @@ func TestOIDCController(t *testing.T) {
assert.Equal(t, 401, recorder.Code) assert.Equal(t, 401, recorder.Code)
err = json.Unmarshal(recorder.Body.Bytes(), &res) err = json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "invalid_grant", res["error"]) assert.Equal(t, "invalid_grant", res["error"])
}, },
}, },
} }
app := bootstrap.NewBootstrapApp(model.Config{}) store := memory.New()
db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db")) wg := &sync.WaitGroup{}
require.NoError(t, err)
queries := repository.New(db) oidcService, err := service.NewOIDCService(log, cfg, runtime, store, context.TODO(), wg)
oidcService := service.NewOIDCService(oidcServiceCfg, queries)
err = oidcService.Init()
require.NoError(t, err) require.NoError(t, err)
for _, test := range tests { for _, test := range tests {
@@ -873,17 +856,11 @@ func TestOIDCController(t *testing.T) {
group := router.Group("/api") group := router.Group("/api")
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
oidcController := controller.NewOIDCController(controllerCfg, oidcService, group) controller.NewOIDCController(log, oidcService, runtime, group)
oidcController.SetupRoutes()
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
test.run(t, router, recorder) test.run(t, router, recorder)
}) })
} }
t.Cleanup(func() {
err = db.Close()
require.NoError(t, err)
})
} }
+43 -48
View File
@@ -11,7 +11,7 @@ import (
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/google/go-querystring/query" "github.com/google/go-querystring/query"
@@ -50,29 +50,31 @@ type ProxyContext struct {
ProxyType ProxyType ProxyType ProxyType
} }
type ProxyControllerConfig struct {
AppURL string
}
type ProxyController struct { type ProxyController struct {
config ProxyControllerConfig log *logger.Logger
router *gin.RouterGroup runtime model.RuntimeConfig
acls *service.AccessControlsService acls *service.AccessControlsService
auth *service.AuthService auth *service.AuthService
} }
func NewProxyController(config ProxyControllerConfig, router *gin.RouterGroup, acls *service.AccessControlsService, auth *service.AuthService) *ProxyController { func NewProxyController(
return &ProxyController{ log *logger.Logger,
config: config, runtime model.RuntimeConfig,
router: router, router *gin.RouterGroup,
acls: acls, acls *service.AccessControlsService,
auth: auth, auth *service.AuthService,
) *ProxyController {
controller := &ProxyController{
log: log,
runtime: runtime,
acls: acls,
auth: auth,
} }
}
func (controller *ProxyController) SetupRoutes() { proxyGroup := router.Group("/auth")
proxyGroup := controller.router.Group("/auth")
proxyGroup.Any("/:proxy", controller.proxyHandler) proxyGroup.Any("/:proxy", controller.proxyHandler)
return controller
} }
func (controller *ProxyController) proxyHandler(c *gin.Context) { func (controller *ProxyController) proxyHandler(c *gin.Context) {
@@ -80,7 +82,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
proxyCtx, err := controller.getProxyContext(c) proxyCtx, err := controller.getProxyContext(c)
if err != nil { if err != nil {
tlog.App.Warn().Err(err).Msg("Failed to get proxy context") controller.log.App.Error().Err(err).Msg("Failed to get proxy context from request")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
"status": 400, "status": 400,
"message": "Bad request", "message": "Bad request",
@@ -88,19 +90,15 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
return return
} }
tlog.App.Trace().Interface("ctx", proxyCtx).Msg("Got proxy context")
// Get acls // Get acls
acls, err := controller.acls.GetAccessControls(proxyCtx.Host) acls, err := controller.acls.GetAccessControls(proxyCtx.Host)
if err != nil { if err != nil {
tlog.App.Error().Err(err).Msg("Failed to get access controls for resource") controller.log.App.Error().Err(err).Msg("Failed to get ACLs for resource")
controller.handleError(c, proxyCtx) controller.handleError(c, proxyCtx)
return return
} }
tlog.App.Trace().Interface("acls", acls).Msg("ACLs for resource")
clientIP := c.ClientIP() clientIP := c.ClientIP()
if controller.auth.IsBypassedIP(clientIP, acls) { if controller.auth.IsBypassedIP(clientIP, acls) {
@@ -115,13 +113,13 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
authEnabled, err := controller.auth.IsAuthEnabled(proxyCtx.Path, acls) authEnabled, err := controller.auth.IsAuthEnabled(proxyCtx.Path, acls)
if err != nil { if err != nil {
tlog.App.Error().Err(err).Msg("Failed to check if auth is enabled for resource") controller.log.App.Error().Err(err).Msg("Failed to determine if authentication is enabled for resource")
controller.handleError(c, proxyCtx) controller.handleError(c, proxyCtx)
return return
} }
if !authEnabled { if !authEnabled {
tlog.App.Debug().Msg("Authentication disabled for resource, allowing access") controller.log.App.Debug().Msg("Authentication is disabled for this resource, allowing access without authentication")
controller.setHeaders(c, acls) controller.setHeaders(c, acls)
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"status": 200, "status": 200,
@@ -137,12 +135,12 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
}) })
if err != nil { if err != nil {
tlog.App.Error().Err(err).Msg("Failed to encode unauthorized query") controller.log.App.Error().Err(err).Msg("Failed to encode unauthorized query")
controller.handleError(c, proxyCtx) controller.handleError(c, proxyCtx)
return return
} }
redirectURL := fmt.Sprintf("%s/unauthorized?%s", controller.config.AppURL, queries.Encode()) redirectURL := fmt.Sprintf("%s/unauthorized?%s", controller.runtime.AppURL, queries.Encode())
if !controller.useBrowserResponse(proxyCtx) { if !controller.useBrowserResponse(proxyCtx) {
c.Header("x-tinyauth-location", redirectURL) c.Header("x-tinyauth-location", redirectURL)
@@ -160,26 +158,24 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
userContext, err := new(model.UserContext).NewFromGin(c) userContext, err := new(model.UserContext).NewFromGin(c)
if err != nil { if err != nil {
tlog.App.Debug().Err(err).Msg("No user context found in request, treating as unauthenticated") controller.log.App.Debug().Err(err).Msg("Failed to create user context from request, treating as unauthenticated")
userContext = &model.UserContext{ userContext = &model.UserContext{
Authenticated: false, Authenticated: false,
} }
} }
tlog.App.Trace().Interface("context", userContext).Msg("User context from request")
if userContext.Authenticated { if userContext.Authenticated {
userAllowed := controller.auth.IsUserAllowed(c, *userContext, acls) userAllowed := controller.auth.IsUserAllowed(c, *userContext, acls)
if !userAllowed { if !userAllowed {
tlog.App.Warn().Str("user", userContext.GetUsername()).Str("resource", strings.Split(proxyCtx.Host, ".")[0]).Msg("User not allowed to access resource") controller.log.App.Warn().Str("user", userContext.GetUsername()).Str("resource", strings.Split(proxyCtx.Host, ".")[0]).Msg("User is not allowed to access resource")
queries, err := query.Values(UnauthorizedQuery{ queries, err := query.Values(UnauthorizedQuery{
Resource: strings.Split(proxyCtx.Host, ".")[0], Resource: strings.Split(proxyCtx.Host, ".")[0],
}) })
if err != nil { if err != nil {
tlog.App.Error().Err(err).Msg("Failed to encode unauthorized query") controller.log.App.Error().Err(err).Msg("Failed to encode unauthorized query")
controller.handleError(c, proxyCtx) controller.handleError(c, proxyCtx)
return return
} }
@@ -190,7 +186,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
queries.Set("username", userContext.GetUsername()) queries.Set("username", userContext.GetUsername())
} }
redirectURL := fmt.Sprintf("%s/unauthorized?%s", controller.config.AppURL, queries.Encode()) redirectURL := fmt.Sprintf("%s/unauthorized?%s", controller.runtime.AppURL, queries.Encode())
if !controller.useBrowserResponse(proxyCtx) { if !controller.useBrowserResponse(proxyCtx) {
c.Header("x-tinyauth-location", redirectURL) c.Header("x-tinyauth-location", redirectURL)
@@ -215,7 +211,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
} }
if !groupOK { if !groupOK {
tlog.App.Warn().Str("user", userContext.GetUsername()).Str("resource", strings.Split(proxyCtx.Host, ".")[0]).Msg("User groups do not match resource requirements") controller.log.App.Warn().Str("user", userContext.GetUsername()).Str("resource", strings.Split(proxyCtx.Host, ".")[0]).Msg("User is not in the required group to access resource")
queries, err := query.Values(UnauthorizedQuery{ queries, err := query.Values(UnauthorizedQuery{
Resource: strings.Split(proxyCtx.Host, ".")[0], Resource: strings.Split(proxyCtx.Host, ".")[0],
@@ -223,7 +219,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
}) })
if err != nil { if err != nil {
tlog.App.Error().Err(err).Msg("Failed to encode unauthorized query") controller.log.App.Error().Err(err).Msg("Failed to encode unauthorized query")
controller.handleError(c, proxyCtx) controller.handleError(c, proxyCtx)
return return
} }
@@ -234,7 +230,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
queries.Set("username", userContext.GetUsername()) queries.Set("username", userContext.GetUsername())
} }
redirectURL := fmt.Sprintf("%s/unauthorized?%s", controller.config.AppURL, queries.Encode()) redirectURL := fmt.Sprintf("%s/unauthorized?%s", controller.runtime.AppURL, queries.Encode())
if !controller.useBrowserResponse(proxyCtx) { if !controller.useBrowserResponse(proxyCtx) {
c.Header("x-tinyauth-location", redirectURL) c.Header("x-tinyauth-location", redirectURL)
@@ -277,12 +273,12 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
}) })
if err != nil { if err != nil {
tlog.App.Error().Err(err).Msg("Failed to encode redirect URI query") controller.log.App.Error().Err(err).Msg("Failed to encode redirect query")
controller.handleError(c, proxyCtx) controller.handleError(c, proxyCtx)
return return
} }
redirectURL := fmt.Sprintf("%s/login?%s", controller.config.AppURL, queries.Encode()) redirectURL := fmt.Sprintf("%s/login?%s", controller.runtime.AppURL, queries.Encode())
if !controller.useBrowserResponse(proxyCtx) { if !controller.useBrowserResponse(proxyCtx) {
c.Header("x-tinyauth-location", redirectURL) c.Header("x-tinyauth-location", redirectURL)
@@ -306,20 +302,19 @@ func (controller *ProxyController) setHeaders(c *gin.Context, acls *model.App) {
headers := utils.ParseHeaders(acls.Response.Headers) headers := utils.ParseHeaders(acls.Response.Headers)
for key, value := range headers { for key, value := range headers {
tlog.App.Debug().Str("header", key).Msg("Setting header")
c.Header(key, value) c.Header(key, value)
} }
basicPassword := utils.GetSecret(acls.Response.BasicAuth.Password, acls.Response.BasicAuth.PasswordFile) basicPassword := utils.GetSecret(acls.Response.BasicAuth.Password, acls.Response.BasicAuth.PasswordFile)
if acls.Response.BasicAuth.Username != "" && basicPassword != "" { if acls.Response.BasicAuth.Username != "" && basicPassword != "" {
tlog.App.Debug().Str("username", acls.Response.BasicAuth.Username).Msg("Setting basic auth header") controller.log.App.Debug().Msg("Setting basic auth header for response")
c.Header("Authorization", fmt.Sprintf("Basic %s", utils.EncodeBasicAuth(acls.Response.BasicAuth.Username, basicPassword))) c.Header("Authorization", fmt.Sprintf("Basic %s", utils.EncodeBasicAuth(acls.Response.BasicAuth.Username, basicPassword)))
} }
} }
func (controller *ProxyController) handleError(c *gin.Context, proxyCtx ProxyContext) { func (controller *ProxyController) handleError(c *gin.Context, proxyCtx ProxyContext) {
redirectURL := fmt.Sprintf("%s/error", controller.config.AppURL) redirectURL := fmt.Sprintf("%s/error", controller.runtime.AppURL)
if !controller.useBrowserResponse(proxyCtx) { if !controller.useBrowserResponse(proxyCtx) {
c.Header("x-tinyauth-location", redirectURL) c.Header("x-tinyauth-location", redirectURL)
@@ -520,7 +515,7 @@ func (controller *ProxyController) getProxyContext(c *gin.Context) (ProxyContext
return ProxyContext{}, err return ProxyContext{}, err
} }
tlog.App.Debug().Msgf("Proxy: %v", req.Proxy) controller.log.App.Debug().Msgf("Determined proxy type: %v", proxy)
authModules := controller.determineAuthModules(proxy) authModules := controller.determineAuthModules(proxy)
@@ -531,13 +526,13 @@ func (controller *ProxyController) getProxyContext(c *gin.Context) (ProxyContext
var ctx ProxyContext var ctx ProxyContext
for _, module := range authModules { for _, module := range authModules {
tlog.App.Debug().Msgf("Trying auth module: %v", module) controller.log.App.Debug().Msgf("Trying to get context from auth module %v", module)
ctx, err = controller.getContextFromAuthModule(c, module) ctx, err = controller.getContextFromAuthModule(c, module)
if err == nil { if err == nil {
tlog.App.Debug().Msgf("Auth module %v succeeded", module) controller.log.App.Debug().Msgf("Successfully got context from auth module %v", module)
break break
} }
tlog.App.Debug().Err(err).Msgf("Auth module %v failed", module) controller.log.App.Debug().Msgf("Failed to get context from auth module %v: %v", module, err)
} }
if err != nil { if err != nil {
@@ -549,9 +544,9 @@ func (controller *ProxyController) getProxyContext(c *gin.Context) (ProxyContext
isBrowser := BrowserUserAgentRegex.MatchString(userAgent) isBrowser := BrowserUserAgentRegex.MatchString(userAgent)
if isBrowser { if isBrowser {
tlog.App.Debug().Msg("Request identified as coming from a browser") controller.log.App.Debug().Msg("Request identified as coming from a browser client")
} else { } else {
tlog.App.Debug().Msg("Request identified as coming from a non-browser client") controller.log.App.Debug().Msg("Request identified as coming from a non-browser client")
} }
ctx.IsBrowser = isBrowser ctx.IsBrowser = isBrowser
+15 -60
View File
@@ -1,47 +1,26 @@
package controller_test package controller_test
import ( import (
"context"
"net/http/httptest" "net/http/httptest"
"path" "sync"
"testing" "testing"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/bootstrap"
"github.com/tinyauthapp/tinyauth/internal/controller" "github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/repository/memory"
"github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog" "github.com/tinyauthapp/tinyauth/internal/test"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
) )
func TestProxyController(t *testing.T) { func TestProxyController(t *testing.T) {
tlog.NewTestLogger().Init() log := logger.NewLogger().WithTestConfig()
tempDir := t.TempDir() log.Init()
authServiceCfg := service.AuthServiceConfig{ cfg, runtime := test.CreateTestConfigs(t)
LocalUsers: &[]model.LocalUser{
{
Username: "testuser",
Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
},
{
Username: "totpuser",
Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
TOTPSecret: "JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK",
},
},
SessionExpiry: 10, // 10 seconds, useful for testing
CookieDomain: "example.com",
LoginTimeout: 10, // 10 seconds, useful for testing
LoginMaxRetries: 3,
SessionCookieName: "tinyauth-session",
}
controllerCfg := controller.ProxyControllerConfig{
AppURL: "https://tinyauth.example.com",
}
acls := map[string]model.App{ acls := map[string]model.App{
"app_path_allow": { "app_path_allow": {
@@ -398,32 +377,14 @@ func TestProxyController(t *testing.T) {
}, },
} }
oauthBrokerCfgs := make(map[string]model.OAuthServiceConfig) store := memory.New()
app := bootstrap.NewBootstrapApp(model.Config{}) wg := &sync.WaitGroup{}
ctx := context.TODO()
db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db")) broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx)
require.NoError(t, err) authService := service.NewAuthService(log, cfg, runtime, ctx, wg, nil, store, broker)
aclsService := service.NewAccessControlsService(log, nil, acls)
queries := repository.New(db)
docker := service.NewDockerService()
err = docker.Init()
require.NoError(t, err)
ldap := service.NewLdapService(service.LdapServiceConfig{})
err = ldap.Init()
require.NoError(t, err)
broker := service.NewOAuthBrokerService(oauthBrokerCfgs)
err = broker.Init()
require.NoError(t, err)
authService := service.NewAuthService(authServiceCfg, ldap, queries, broker)
err = authService.Init()
require.NoError(t, err)
aclsService := service.NewAccessControlsService(docker, acls)
for _, test := range tests { for _, test := range tests {
t.Run(test.description, func(t *testing.T) { t.Run(test.description, func(t *testing.T) {
@@ -438,15 +399,9 @@ func TestProxyController(t *testing.T) {
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
proxyController := controller.NewProxyController(controllerCfg, group, aclsService, authService) controller.NewProxyController(log, runtime, group, aclsService, authService)
proxyController.SetupRoutes()
test.run(t, router, recorder) test.run(t, router, recorder)
}) })
} }
t.Cleanup(func() {
err = db.Close()
require.NoError(t, err)
})
} }
+13 -16
View File
@@ -4,42 +4,39 @@ import (
"net/http" "net/http"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/tinyauthapp/tinyauth/internal/model"
) )
type ResourcesControllerConfig struct {
Path string
Enabled bool
}
type ResourcesController struct { type ResourcesController struct {
config ResourcesControllerConfig config model.Config
router *gin.RouterGroup
fileServer http.Handler fileServer http.Handler
} }
func NewResourcesController(config ResourcesControllerConfig, router *gin.RouterGroup) *ResourcesController { func NewResourcesController(
fileServer := http.StripPrefix("/resources", http.FileServer(http.Dir(config.Path))) config model.Config,
router *gin.RouterGroup,
) *ResourcesController {
fileServer := http.StripPrefix("/resources", http.FileServer(http.Dir(config.Resources.Path)))
return &ResourcesController{ controller := &ResourcesController{
config: config, config: config,
router: router,
fileServer: fileServer, fileServer: fileServer,
} }
}
func (controller *ResourcesController) SetupRoutes() { router.GET("/resources/*resource", controller.resourcesHandler)
controller.router.GET("/resources/*resource", controller.resourcesHandler)
return controller
} }
func (controller *ResourcesController) resourcesHandler(c *gin.Context) { func (controller *ResourcesController) resourcesHandler(c *gin.Context) {
if controller.config.Path == "" { if controller.config.Resources.Path == "" {
c.JSON(404, gin.H{ c.JSON(404, gin.H{
"status": 404, "status": 404,
"message": "Resources not found", "message": "Resources not found",
}) })
return return
} }
if !controller.config.Enabled { if !controller.config.Resources.Enabled {
c.JSON(403, gin.H{ c.JSON(403, gin.H{
"status": 403, "status": 403,
"message": "Resources are disabled", "message": "Resources are disabled",
@@ -3,26 +3,20 @@ package controller_test
import ( import (
"net/http/httptest" "net/http/httptest"
"os" "os"
"path" "path/filepath"
"testing" "testing"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/test"
) )
func TestResourcesController(t *testing.T) { func TestResourcesController(t *testing.T) {
tlog.NewTestLogger().Init() cfg, _ := test.CreateTestConfigs(t)
tempDir := t.TempDir()
resourcesControllerCfg := controller.ResourcesControllerConfig{ err := os.MkdirAll(cfg.Resources.Path, 0777)
Path: path.Join(tempDir, "resources"),
Enabled: true,
}
err := os.Mkdir(resourcesControllerCfg.Path, 0777)
require.NoError(t, err) require.NoError(t, err)
type testCase struct { type testCase struct {
@@ -61,11 +55,11 @@ func TestResourcesController(t *testing.T) {
}, },
} }
testFilePath := resourcesControllerCfg.Path + "/testfile.txt" testFilePath := cfg.Resources.Path + "/testfile.txt"
err = os.WriteFile(testFilePath, []byte("This is a test file."), 0777) err = os.WriteFile(testFilePath, []byte("This is a test file."), 0777)
require.NoError(t, err) require.NoError(t, err)
testFilePathParent := tempDir + "/somefile.txt" testFilePathParent := filepath.Dir(cfg.Resources.Path) + "/somefile.txt"
err = os.WriteFile(testFilePathParent, []byte("This file should not be accessible."), 0777) err = os.WriteFile(testFilePathParent, []byte("This file should not be accessible."), 0777)
require.NoError(t, err) require.NoError(t, err)
@@ -75,8 +69,7 @@ func TestResourcesController(t *testing.T) {
group := router.Group("/") group := router.Group("/")
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
resourcesController := controller.NewResourcesController(resourcesControllerCfg, group) controller.NewResourcesController(cfg, group)
resourcesController.SetupRoutes()
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
test.run(t, router, recorder) test.run(t, router, recorder)
+67 -61
View File
@@ -10,7 +10,7 @@ import (
"github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/pquerna/otp/totp" "github.com/pquerna/otp/totp"
@@ -25,30 +25,30 @@ type TotpRequest struct {
Code string `json:"code"` Code string `json:"code"`
} }
type UserControllerConfig struct {
CookieDomain string
SessionCookieName string
}
type UserController struct { type UserController struct {
config UserControllerConfig log *logger.Logger
router *gin.RouterGroup runtime model.RuntimeConfig
auth *service.AuthService auth *service.AuthService
} }
func NewUserController(config UserControllerConfig, router *gin.RouterGroup, auth *service.AuthService) *UserController { func NewUserController(
return &UserController{ log *logger.Logger,
config: config, runtimeConfig model.RuntimeConfig,
router: router, router *gin.RouterGroup,
auth: auth, auth *service.AuthService,
) *UserController {
controller := &UserController{
log: log,
runtime: runtimeConfig,
auth: auth,
} }
}
func (controller *UserController) SetupRoutes() { userGroup := router.Group("/user")
userGroup := controller.router.Group("/user")
userGroup.POST("/login", controller.loginHandler) userGroup.POST("/login", controller.loginHandler)
userGroup.POST("/logout", controller.logoutHandler) userGroup.POST("/logout", controller.logoutHandler)
userGroup.POST("/totp", controller.totpHandler) userGroup.POST("/totp", controller.totpHandler)
return controller
} }
func (controller *UserController) loginHandler(c *gin.Context) { func (controller *UserController) loginHandler(c *gin.Context) {
@@ -56,7 +56,7 @@ func (controller *UserController) loginHandler(c *gin.Context) {
err := c.ShouldBindJSON(&req) err := c.ShouldBindJSON(&req)
if err != nil { if err != nil {
tlog.App.Error().Err(err).Msg("Failed to bind JSON") controller.log.App.Error().Err(err).Msg("Failed to bind JSON")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
"status": 400, "status": 400,
"message": "Bad Request", "message": "Bad Request",
@@ -64,13 +64,13 @@ func (controller *UserController) loginHandler(c *gin.Context) {
return return
} }
tlog.App.Debug().Str("username", req.Username).Msg("Login attempt") controller.log.App.Debug().Str("username", req.Username).Msg("Login attempt")
isLocked, remaining := controller.auth.IsAccountLocked(req.Username) isLocked, remaining := controller.auth.IsAccountLocked(req.Username)
if isLocked { if isLocked {
tlog.App.Warn().Str("username", req.Username).Msg("Account is locked due to too many failed login attempts") controller.log.App.Warn().Str("username", req.Username).Msg("Account is locked due to too many failed login attempts")
tlog.AuditLoginFailure(c, req.Username, "username", "account locked") controller.log.AuditLoginFailure(req.Username, "local", c.ClientIP(), "account locked")
c.Writer.Header().Add("x-tinyauth-lock-locked", "true") c.Writer.Header().Add("x-tinyauth-lock-locked", "true")
c.Writer.Header().Add("x-tinyauth-lock-reset", time.Now().Add(time.Duration(remaining)*time.Second).Format(time.RFC3339)) c.Writer.Header().Add("x-tinyauth-lock-reset", time.Now().Add(time.Duration(remaining)*time.Second).Format(time.RFC3339))
c.JSON(429, gin.H{ c.JSON(429, gin.H{
@@ -84,16 +84,16 @@ func (controller *UserController) loginHandler(c *gin.Context) {
if err != nil { if err != nil {
if errors.Is(err, service.ErrUserNotFound) { if errors.Is(err, service.ErrUserNotFound) {
tlog.App.Warn().Str("username", req.Username).Msg("User not found") controller.log.App.Warn().Str("username", req.Username).Msg("User not found during login attempt")
controller.auth.RecordLoginAttempt(req.Username, false) controller.auth.RecordLoginAttempt(req.Username, false)
tlog.AuditLoginFailure(c, req.Username, "username", "user not found") controller.log.AuditLoginFailure(req.Username, "unknown", c.ClientIP(), "user not found")
c.JSON(401, gin.H{ c.JSON(401, gin.H{
"status": 401, "status": 401,
"message": "Unauthorized", "message": "Unauthorized",
}) })
return return
} }
tlog.App.Error().Err(err).Str("username", req.Username).Msg("Error searching for user") controller.log.App.Error().Err(err).Str("username", req.Username).Msg("Error searching for user during login attempt")
c.JSON(500, gin.H{ c.JSON(500, gin.H{
"status": 500, "status": 500,
"message": "Internal Server Error", "message": "Internal Server Error",
@@ -102,9 +102,13 @@ func (controller *UserController) loginHandler(c *gin.Context) {
} }
if err := controller.auth.CheckUserPassword(*search, req.Password); err != nil { if err := controller.auth.CheckUserPassword(*search, req.Password); err != nil {
tlog.App.Warn().Err(err).Str("username", req.Username).Msg("Failed to verify password") controller.log.App.Warn().Str("username", req.Username).Msg("Invalid password during login attempt")
controller.auth.RecordLoginAttempt(req.Username, false) controller.auth.RecordLoginAttempt(req.Username, false)
tlog.AuditLoginFailure(c, req.Username, "username", "invalid password") if search.Type == model.UserLocal {
controller.log.AuditLoginFailure(req.Username, "local", c.ClientIP(), "invalid password")
} else {
controller.log.AuditLoginFailure(req.Username, "ldap", c.ClientIP(), "invalid password")
}
c.JSON(401, gin.H{ c.JSON(401, gin.H{
"status": 401, "status": 401,
"message": "Unauthorized", "message": "Unauthorized",
@@ -118,7 +122,7 @@ func (controller *UserController) loginHandler(c *gin.Context) {
localUser = controller.auth.GetLocalUser(req.Username) localUser = controller.auth.GetLocalUser(req.Username)
if localUser == nil { if localUser == nil {
tlog.App.Warn().Str("username", req.Username).Msg("User disappeared during login") controller.log.App.Error().Str("username", req.Username).Msg("Local user not found after successful password verification")
c.JSON(401, gin.H{ c.JSON(401, gin.H{
"status": 401, "status": 401,
"message": "Unauthorized", "message": "Unauthorized",
@@ -127,7 +131,7 @@ func (controller *UserController) loginHandler(c *gin.Context) {
} }
if localUser.TOTPSecret != "" { if localUser.TOTPSecret != "" {
tlog.App.Debug().Str("username", req.Username).Msg("User has TOTP enabled, requiring TOTP verification") controller.log.App.Debug().Str("username", req.Username).Msg("TOTP required for user, creating pending TOTP session")
name := localUser.Attributes.Name name := localUser.Attributes.Name
if name == "" { if name == "" {
@@ -136,7 +140,7 @@ func (controller *UserController) loginHandler(c *gin.Context) {
email := localUser.Attributes.Email email := localUser.Attributes.Email
if email == "" { if email == "" {
email = utils.CompileUserEmail(localUser.Username, controller.config.CookieDomain) email = utils.CompileUserEmail(localUser.Username, controller.runtime.CookieDomain)
} }
cookie, err := controller.auth.CreateSession(c, repository.Session{ cookie, err := controller.auth.CreateSession(c, repository.Session{
@@ -148,7 +152,7 @@ func (controller *UserController) loginHandler(c *gin.Context) {
}) })
if err != nil { if err != nil {
tlog.App.Error().Err(err).Msg("Failed to create session cookie") controller.log.App.Error().Err(err).Str("username", req.Username).Msg("Failed to create pending TOTP session")
c.JSON(500, gin.H{ c.JSON(500, gin.H{
"status": 500, "status": 500,
"message": "Internal Server Error", "message": "Internal Server Error",
@@ -170,7 +174,7 @@ func (controller *UserController) loginHandler(c *gin.Context) {
sessionCookie := repository.Session{ sessionCookie := repository.Session{
Username: req.Username, Username: req.Username,
Name: utils.Capitalize(req.Username), Name: utils.Capitalize(req.Username),
Email: utils.CompileUserEmail(req.Username, controller.config.CookieDomain), Email: utils.CompileUserEmail(req.Username, controller.runtime.CookieDomain),
Provider: "local", Provider: "local",
} }
@@ -187,12 +191,10 @@ func (controller *UserController) loginHandler(c *gin.Context) {
sessionCookie.Provider = "ldap" sessionCookie.Provider = "ldap"
} }
tlog.App.Trace().Interface("session_cookie", sessionCookie).Msg("Creating session cookie")
cookie, err := controller.auth.CreateSession(c, sessionCookie) cookie, err := controller.auth.CreateSession(c, sessionCookie)
if err != nil { if err != nil {
tlog.App.Error().Err(err).Msg("Failed to create session cookie") controller.log.App.Error().Err(err).Str("username", req.Username).Msg("Failed to create session cookie after successful login")
c.JSON(500, gin.H{ c.JSON(500, gin.H{
"status": 500, "status": 500,
"message": "Internal Server Error", "message": "Internal Server Error",
@@ -202,8 +204,13 @@ func (controller *UserController) loginHandler(c *gin.Context) {
http.SetCookie(c.Writer, cookie) http.SetCookie(c.Writer, cookie)
tlog.App.Info().Str("username", req.Username).Msg("Login successful") controller.log.App.Info().Str("username", req.Username).Msg("Login successful")
tlog.AuditLoginSuccess(c, req.Username, "username")
if search.Type == model.UserLocal {
controller.log.AuditLoginSuccess(req.Username, "local", c.ClientIP())
} else {
controller.log.AuditLoginSuccess(req.Username, "ldap", c.ClientIP())
}
controller.auth.RecordLoginAttempt(req.Username, true) controller.auth.RecordLoginAttempt(req.Username, true)
@@ -214,20 +221,20 @@ func (controller *UserController) loginHandler(c *gin.Context) {
} }
func (controller *UserController) logoutHandler(c *gin.Context) { func (controller *UserController) logoutHandler(c *gin.Context) {
tlog.App.Debug().Msg("Logout request received") controller.log.App.Debug().Msg("Logout attempt")
uuid, err := c.Cookie(controller.config.SessionCookieName) uuid, err := c.Cookie(controller.runtime.SessionCookieName)
if err != nil { if err != nil {
if errors.Is(err, http.ErrNoCookie) { if errors.Is(err, http.ErrNoCookie) {
tlog.App.Warn().Msg("No session cookie found on logout request") controller.log.App.Warn().Msg("Logout attempt without session cookie, treating as successful logout")
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"status": 200, "status": 200,
"message": "Logout successful", "message": "Logout successful",
}) })
return return
} }
tlog.App.Error().Err(err).Msg("Error retrieving session cookie on logout") controller.log.App.Error().Err(err).Msg("Error retrieving session cookie on logout")
c.JSON(500, gin.H{ c.JSON(500, gin.H{
"status": 500, "status": 500,
"message": "Internal Server Error", "message": "Internal Server Error",
@@ -238,7 +245,7 @@ func (controller *UserController) logoutHandler(c *gin.Context) {
cookie, err := controller.auth.DeleteSession(c, uuid) cookie, err := controller.auth.DeleteSession(c, uuid)
if err != nil { if err != nil {
tlog.App.Error().Err(err).Msg("Error deleting session on logout") controller.log.App.Error().Err(err).Msg("Error deleting session on logout")
c.JSON(500, gin.H{ c.JSON(500, gin.H{
"status": 500, "status": 500,
"message": "Internal Server Error", "message": "Internal Server Error",
@@ -249,10 +256,10 @@ func (controller *UserController) logoutHandler(c *gin.Context) {
context, err := new(model.UserContext).NewFromGin(c) context, err := new(model.UserContext).NewFromGin(c)
if err == nil { if err == nil {
tlog.AuditLogout(c, context.GetUsername(), context.GetProviderID()) controller.log.AuditLogout(context.GetUsername(), context.GetProviderID(), c.ClientIP())
} else { } else {
tlog.App.Warn().Err(err).Msg("Failed to get user context for logout audit, proceeding without username") controller.log.App.Warn().Err(err).Msg("Failed to get user context during logout, logging audit with unknown user")
tlog.AuditLogout(c, "unknown", "unknown") controller.log.AuditLogout("unknown", "unknown", c.ClientIP())
} }
http.SetCookie(c.Writer, cookie) http.SetCookie(c.Writer, cookie)
@@ -268,7 +275,7 @@ func (controller *UserController) totpHandler(c *gin.Context) {
err := c.ShouldBindJSON(&req) err := c.ShouldBindJSON(&req)
if err != nil { if err != nil {
tlog.App.Error().Err(err).Msg("Failed to bind JSON") controller.log.App.Error().Err(err).Msg("Failed to bind JSON for TOTP verification")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
"status": 400, "status": 400,
"message": "Bad Request", "message": "Bad Request",
@@ -279,7 +286,7 @@ func (controller *UserController) totpHandler(c *gin.Context) {
context, err := new(model.UserContext).NewFromGin(c) context, err := new(model.UserContext).NewFromGin(c)
if err != nil { if err != nil {
tlog.App.Error().Err(err).Msg("Failed to get user context") controller.log.App.Error().Err(err).Msg("Failed to create user context from request for TOTP verification")
c.JSON(500, gin.H{ c.JSON(500, gin.H{
"status": 500, "status": 500,
"message": "Internal Server Error", "message": "Internal Server Error",
@@ -288,7 +295,7 @@ func (controller *UserController) totpHandler(c *gin.Context) {
} }
if !context.TOTPPending() { if !context.TOTPPending() {
tlog.App.Warn().Msg("TOTP attempt without a pending TOTP session") controller.log.App.Warn().Str("username", context.GetUsername()).Msg("TOTP verification attempt without pending TOTP session")
c.JSON(401, gin.H{ c.JSON(401, gin.H{
"status": 401, "status": 401,
"message": "Unauthorized", "message": "Unauthorized",
@@ -296,12 +303,13 @@ func (controller *UserController) totpHandler(c *gin.Context) {
return return
} }
tlog.App.Debug().Str("username", context.GetUsername()).Msg("TOTP verification attempt") controller.log.App.Debug().Str("username", context.GetUsername()).Msg("TOTP verification attempt")
isLocked, remaining := controller.auth.IsAccountLocked(context.GetUsername()) isLocked, remaining := controller.auth.IsAccountLocked(context.GetUsername())
if isLocked { if isLocked {
tlog.App.Warn().Str("username", context.GetUsername()).Msg("Account is locked due to too many failed TOTP attempts") controller.log.App.Warn().Str("username", context.GetUsername()).Msg("Account is locked due to too many failed TOTP attempts")
controller.log.AuditLoginFailure(context.GetUsername(), "local", c.ClientIP(), "account locked")
c.Writer.Header().Add("x-tinyauth-lock-locked", "true") c.Writer.Header().Add("x-tinyauth-lock-locked", "true")
c.Writer.Header().Add("x-tinyauth-lock-reset", time.Now().Add(time.Duration(remaining)*time.Second).Format(time.RFC3339)) c.Writer.Header().Add("x-tinyauth-lock-reset", time.Now().Add(time.Duration(remaining)*time.Second).Format(time.RFC3339))
c.JSON(429, gin.H{ c.JSON(429, gin.H{
@@ -314,7 +322,7 @@ func (controller *UserController) totpHandler(c *gin.Context) {
user := controller.auth.GetLocalUser(context.GetUsername()) user := controller.auth.GetLocalUser(context.GetUsername())
if user == nil { if user == nil {
tlog.App.Error().Str("username", context.GetUsername()).Msg("User not found in TOTP handler") controller.log.App.Error().Str("username", context.GetUsername()).Msg("Local user not found during TOTP verification")
c.JSON(401, gin.H{ c.JSON(401, gin.H{
"status": 401, "status": 401,
"message": "Unauthorized", "message": "Unauthorized",
@@ -325,9 +333,9 @@ func (controller *UserController) totpHandler(c *gin.Context) {
ok := totp.Validate(req.Code, user.TOTPSecret) ok := totp.Validate(req.Code, user.TOTPSecret)
if !ok { if !ok {
tlog.App.Warn().Str("username", context.GetUsername()).Msg("Invalid TOTP code") controller.log.App.Warn().Str("username", context.GetUsername()).Msg("Invalid TOTP code during verification attempt")
controller.auth.RecordLoginAttempt(context.GetUsername(), false) controller.auth.RecordLoginAttempt(context.GetUsername(), false)
tlog.AuditLoginFailure(c, context.GetUsername(), "totp", "invalid totp code") controller.log.AuditLoginFailure(context.GetUsername(), "local", c.ClientIP(), "invalid TOTP code")
c.JSON(401, gin.H{ c.JSON(401, gin.H{
"status": 401, "status": 401,
"message": "Unauthorized", "message": "Unauthorized",
@@ -335,15 +343,15 @@ func (controller *UserController) totpHandler(c *gin.Context) {
return return
} }
uuid, err := c.Cookie(controller.config.SessionCookieName) uuid, err := c.Cookie(controller.runtime.SessionCookieName)
if err == nil { if err == nil {
_, err = controller.auth.DeleteSession(c, uuid) _, err = controller.auth.DeleteSession(c, uuid)
if err != nil { if err != nil {
tlog.App.Warn().Err(err).Msg("Failed to delete pending TOTP session") controller.log.App.Error().Err(err).Msg("Failed to delete pending TOTP session after successful verification")
} }
} else { } else {
tlog.App.Warn().Err(err).Msg("Failed to retrieve session cookie for pending TOTP session, proceeding without deleting it") controller.log.App.Warn().Err(err).Msg("Failed to retrieve session cookie for pending TOTP session, cannot delete it")
} }
controller.auth.RecordLoginAttempt(context.GetUsername(), true) controller.auth.RecordLoginAttempt(context.GetUsername(), true)
@@ -351,7 +359,7 @@ func (controller *UserController) totpHandler(c *gin.Context) {
sessionCookie := repository.Session{ sessionCookie := repository.Session{
Username: user.Username, Username: user.Username,
Name: utils.Capitalize(user.Username), Name: utils.Capitalize(user.Username),
Email: utils.CompileUserEmail(user.Username, controller.config.CookieDomain), Email: utils.CompileUserEmail(user.Username, controller.runtime.CookieDomain),
Provider: "local", Provider: "local",
} }
@@ -362,12 +370,10 @@ func (controller *UserController) totpHandler(c *gin.Context) {
sessionCookie.Email = user.Attributes.Email sessionCookie.Email = user.Attributes.Email
} }
tlog.App.Trace().Interface("session_cookie", sessionCookie).Msg("Creating session cookie")
cookie, err := controller.auth.CreateSession(c, sessionCookie) cookie, err := controller.auth.CreateSession(c, sessionCookie)
if err != nil { if err != nil {
tlog.App.Error().Err(err).Msg("Failed to create session cookie") controller.log.App.Error().Err(err).Str("username", context.GetUsername()).Msg("Failed to create session cookie after successful TOTP verification")
c.JSON(500, gin.H{ c.JSON(500, gin.H{
"status": 500, "status": 500,
"message": "Internal Server Error", "message": "Internal Server Error",
@@ -377,8 +383,8 @@ func (controller *UserController) totpHandler(c *gin.Context) {
http.SetCookie(c.Writer, cookie) http.SetCookie(c.Writer, cookie)
tlog.App.Info().Str("username", context.GetUsername()).Msg("TOTP verification successful") controller.log.App.Info().Str("username", context.GetUsername()).Msg("TOTP verification successful, login complete")
tlog.AuditLoginSuccess(c, context.GetUsername(), "totp") controller.log.AuditLoginSuccess(context.GetUsername(), "local", c.ClientIP())
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"status": 200, "status": 200,
+35 -96
View File
@@ -5,8 +5,8 @@ import (
"encoding/json" "encoding/json"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"path"
"strings" "strings"
"sync"
"testing" "testing"
"time" "time"
@@ -14,58 +14,20 @@ import (
"github.com/pquerna/otp/totp" "github.com/pquerna/otp/totp"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/bootstrap"
"github.com/tinyauthapp/tinyauth/internal/controller" "github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
"github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog" "github.com/tinyauthapp/tinyauth/internal/test"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
) )
func TestUserController(t *testing.T) { func TestUserController(t *testing.T) {
tlog.NewTestLogger().Init() log := logger.NewLogger().WithTestConfig()
tempDir := t.TempDir() log.Init()
authServiceCfg := service.AuthServiceConfig{ cfg, runtime := test.CreateTestConfigs(t)
LocalUsers: &[]model.LocalUser{
{
Username: "testuser",
Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
},
{
Username: "totpuser",
Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
TOTPSecret: "JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK",
},
{
Username: "attruser",
Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
Attributes: model.UserAttributes{
Name: "Alice Smith",
Email: "alice@example.com",
},
},
{
Username: "attrtotpuser",
Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
TOTPSecret: "JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK",
Attributes: model.UserAttributes{
Name: "Bob Jones",
Email: "bob@example.com",
},
},
},
SessionExpiry: 10, // 10 seconds, useful for testing
CookieDomain: "example.com",
LoginTimeout: 10, // 10 seconds, useful for testing
LoginMaxRetries: 3,
SessionCookieName: "tinyauth-session",
}
userControllerCfg := controller.UserControllerConfig{
CookieDomain: "example.com",
SessionCookieName: "tinyauth-session",
}
totpCtx := func(c *gin.Context) { totpCtx := func(c *gin.Context) {
c.Set("context", &model.UserContext{ c.Set("context", &model.UserContext{
@@ -111,14 +73,7 @@ func TestUserController(t *testing.T) {
}) })
} }
oauthBrokerCfgs := make(map[string]model.OAuthServiceConfig) store := memory.New()
app := bootstrap.NewBootstrapApp(model.Config{})
db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db"))
require.NoError(t, err)
queries := repository.New(db)
type testCase struct { type testCase struct {
description string description string
@@ -136,7 +91,7 @@ func TestUserController(t *testing.T) {
Password: "password", Password: "password",
} }
loginReqBody, err := json.Marshal(loginReq) loginReqBody, err := json.Marshal(loginReq)
assert.NoError(t, err) require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqBody))) req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqBody)))
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
@@ -144,7 +99,7 @@ func TestUserController(t *testing.T) {
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code) assert.Equal(t, 200, recorder.Code)
assert.Len(t, recorder.Result().Cookies(), 1) require.Len(t, recorder.Result().Cookies(), 1)
cookie := recorder.Result().Cookies()[0] cookie := recorder.Result().Cookies()[0]
assert.Equal(t, "tinyauth-session", cookie.Name) assert.Equal(t, "tinyauth-session", cookie.Name)
@@ -164,7 +119,7 @@ func TestUserController(t *testing.T) {
Password: "wrongpassword", Password: "wrongpassword",
} }
loginReqBody, err := json.Marshal(loginReq) loginReqBody, err := json.Marshal(loginReq)
assert.NoError(t, err) require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqBody))) req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqBody)))
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
@@ -179,13 +134,13 @@ func TestUserController(t *testing.T) {
{ {
description: "Should rate limit on 3 invalid attempts", description: "Should rate limit on 3 invalid attempts",
middlewares: []gin.HandlerFunc{}, middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { //nolint:staticcheck run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
loginReq := controller.LoginRequest{ loginReq := controller.LoginRequest{
Username: "testuser", Username: "testuser",
Password: "wrongpassword", Password: "wrongpassword",
} }
loginReqBody, err := json.Marshal(loginReq) loginReqBody, err := json.Marshal(loginReq)
assert.NoError(t, err) require.NoError(t, err)
for range 3 { for range 3 {
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
@@ -201,7 +156,7 @@ func TestUserController(t *testing.T) {
} }
// 4th attempt should be rate limited // 4th attempt should be rate limited
recorder = httptest.NewRecorder() //nolint:staticcheck recorder = httptest.NewRecorder()
req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqBody))) req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqBody)))
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
@@ -220,7 +175,7 @@ func TestUserController(t *testing.T) {
Password: "password", Password: "password",
} }
loginReqBody, err := json.Marshal(loginReq) loginReqBody, err := json.Marshal(loginReq)
assert.NoError(t, err) require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqBody))) req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqBody)))
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
@@ -231,12 +186,12 @@ func TestUserController(t *testing.T) {
decodedBody := make(map[string]any) decodedBody := make(map[string]any)
err = json.Unmarshal(recorder.Body.Bytes(), &decodedBody) err = json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, decodedBody["totpPending"], true) assert.Equal(t, decodedBody["totpPending"], true)
// should set the session cookie // should set the session cookie
assert.Len(t, recorder.Result().Cookies(), 1) require.Len(t, recorder.Result().Cookies(), 1)
cookie := recorder.Result().Cookies()[0] cookie := recorder.Result().Cookies()[0]
assert.Equal(t, "tinyauth-session", cookie.Name) assert.Equal(t, "tinyauth-session", cookie.Name)
assert.True(t, cookie.HttpOnly) assert.True(t, cookie.HttpOnly)
@@ -257,7 +212,7 @@ func TestUserController(t *testing.T) {
Password: "password", Password: "password",
} }
loginReqBody, err := json.Marshal(loginReq) loginReqBody, err := json.Marshal(loginReq)
assert.NoError(t, err) require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqBody))) req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqBody)))
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
@@ -266,7 +221,7 @@ func TestUserController(t *testing.T) {
assert.Equal(t, 200, recorder.Code) assert.Equal(t, 200, recorder.Code)
cookies := recorder.Result().Cookies() cookies := recorder.Result().Cookies()
assert.Len(t, cookies, 1) require.Len(t, cookies, 1)
cookie := cookies[0] cookie := cookies[0]
assert.Equal(t, "tinyauth-session", cookie.Name) assert.Equal(t, "tinyauth-session", cookie.Name)
@@ -280,7 +235,7 @@ func TestUserController(t *testing.T) {
assert.Equal(t, 200, recorder.Code) assert.Equal(t, 200, recorder.Code)
cookies = recorder.Result().Cookies() cookies = recorder.Result().Cookies()
assert.Len(t, cookies, 1) require.Len(t, cookies, 1)
cookie = cookies[0] cookie = cookies[0]
assert.Equal(t, "tinyauth-session", cookie.Name) assert.Equal(t, "tinyauth-session", cookie.Name)
@@ -293,8 +248,8 @@ func TestUserController(t *testing.T) {
middlewares: []gin.HandlerFunc{ middlewares: []gin.HandlerFunc{
totpCtx, totpCtx,
}, },
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { //nolint:staticcheck run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
_, err := queries.CreateSession(context.TODO(), repository.CreateSessionParams{ _, err := store.CreateSession(context.TODO(), repository.CreateSessionParams{
UUID: "test-totp-login-uuid", UUID: "test-totp-login-uuid",
Username: "test", Username: "test",
Email: "test@example.com", Email: "test@example.com",
@@ -307,16 +262,16 @@ func TestUserController(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
code, err := totp.GenerateCode("JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", time.Now()) code, err := totp.GenerateCode("JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", time.Now())
assert.NoError(t, err) require.NoError(t, err)
totpReq := controller.TotpRequest{ totpReq := controller.TotpRequest{
Code: code, Code: code,
} }
totpReqBody, err := json.Marshal(totpReq) totpReqBody, err := json.Marshal(totpReq)
assert.NoError(t, err) require.NoError(t, err)
recorder = httptest.NewRecorder() //nolint:staticcheck recorder = httptest.NewRecorder()
req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(totpReqBody))) req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(totpReqBody)))
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
req.AddCookie(&http.Cookie{ req.AddCookie(&http.Cookie{
@@ -329,7 +284,7 @@ func TestUserController(t *testing.T) {
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code) assert.Equal(t, 200, recorder.Code)
assert.Len(t, recorder.Result().Cookies(), 1) require.Len(t, recorder.Result().Cookies(), 1)
// should set a new session cookie with totp pending removed // should set a new session cookie with totp pending removed
totpCookie := recorder.Result().Cookies()[0] totpCookie := recorder.Result().Cookies()[0]
@@ -345,16 +300,16 @@ func TestUserController(t *testing.T) {
middlewares: []gin.HandlerFunc{ middlewares: []gin.HandlerFunc{
totpCtx, totpCtx,
}, },
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { //nolint:staticcheck run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
for range 3 { for range 3 {
totpReq := controller.TotpRequest{ totpReq := controller.TotpRequest{
Code: "000000", // invalid code Code: "000000", // invalid code
} }
totpReqBody, err := json.Marshal(totpReq) totpReqBody, err := json.Marshal(totpReq)
assert.NoError(t, err) require.NoError(t, err)
recorder = httptest.NewRecorder() //nolint:staticcheck recorder = httptest.NewRecorder()
req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(totpReqBody))) req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(totpReqBody)))
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
@@ -418,7 +373,7 @@ func TestUserController(t *testing.T) {
totpAttrCtx, totpAttrCtx,
}, },
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
_, err := queries.CreateSession(context.TODO(), repository.CreateSessionParams{ _, err := store.CreateSession(context.TODO(), repository.CreateSessionParams{
UUID: "test-totp-login-attributes-uuid", UUID: "test-totp-login-attributes-uuid",
Username: "test", Username: "test",
Email: "test@example.com", Email: "test@example.com",
@@ -456,21 +411,11 @@ func TestUserController(t *testing.T) {
}, },
} }
docker := service.NewDockerService() ctx := context.TODO()
err = docker.Init() wg := &sync.WaitGroup{}
require.NoError(t, err)
ldap := service.NewLdapService(service.LdapServiceConfig{}) broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx)
err = ldap.Init() authService := service.NewAuthService(log, cfg, runtime, ctx, wg, nil, store, broker)
require.NoError(t, err)
broker := service.NewOAuthBrokerService(oauthBrokerCfgs)
err = broker.Init()
require.NoError(t, err)
authService := service.NewAuthService(authServiceCfg, ldap, queries, broker)
err = authService.Init()
require.NoError(t, err)
beforeEach := func() { beforeEach := func() {
// Clear failed login attempts before each test // Clear failed login attempts before each test
@@ -489,17 +434,11 @@ func TestUserController(t *testing.T) {
group := router.Group("/api") group := router.Group("/api")
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
userController := controller.NewUserController(userControllerCfg, group, authService) controller.NewUserController(log, runtime, group, authService)
userController.SetupRoutes()
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
test.run(t, router, recorder) test.run(t, router, recorder)
}) })
} }
t.Cleanup(func() {
err = db.Close()
require.NoError(t, err)
})
} }
+25 -15
View File
@@ -26,28 +26,30 @@ type OpenIDConnectConfiguration struct {
RequestObjectSigningAlgValuesSupported []string `json:"request_object_signing_alg_values_supported"` RequestObjectSigningAlgValuesSupported []string `json:"request_object_signing_alg_values_supported"`
} }
type WellKnownControllerConfig struct{}
type WellKnownController struct { type WellKnownController struct {
config WellKnownControllerConfig oidc *service.OIDCService
engine *gin.Engine
oidc *service.OIDCService
} }
func NewWellKnownController(config WellKnownControllerConfig, oidc *service.OIDCService, engine *gin.Engine) *WellKnownController { func NewWellKnownController(oidc *service.OIDCService, router *gin.RouterGroup) *WellKnownController {
return &WellKnownController{ controller := &WellKnownController{
config: config, oidc: oidc,
oidc: oidc,
engine: engine,
} }
}
func (controller *WellKnownController) SetupRoutes() { router.GET("/.well-known/openid-configuration", controller.OpenIDConnectConfiguration)
controller.engine.GET("/.well-known/openid-configuration", controller.OpenIDConnectConfiguration) router.GET("/.well-known/jwks.json", controller.JWKS)
controller.engine.GET("/.well-known/jwks.json", controller.JWKS)
return controller
} }
func (controller *WellKnownController) OpenIDConnectConfiguration(c *gin.Context) { func (controller *WellKnownController) OpenIDConnectConfiguration(c *gin.Context) {
if controller.oidc == nil {
c.JSON(500, gin.H{
"status": 500,
"message": "OIDC service not configured",
})
return
}
issuer := controller.oidc.GetIssuer() issuer := controller.oidc.GetIssuer()
c.JSON(200, OpenIDConnectConfiguration{ c.JSON(200, OpenIDConnectConfiguration{
Issuer: issuer, Issuer: issuer,
@@ -69,11 +71,19 @@ func (controller *WellKnownController) OpenIDConnectConfiguration(c *gin.Context
} }
func (controller *WellKnownController) JWKS(c *gin.Context) { func (controller *WellKnownController) JWKS(c *gin.Context) {
if controller.oidc == nil {
c.JSON(500, gin.H{
"status": 500,
"message": "OIDC service not configured",
})
return
}
jwks, err := controller.oidc.GetJWK() jwks, err := controller.oidc.GetJWK()
if err != nil { if err != nil {
c.JSON(500, gin.H{ c.JSON(500, gin.H{
"status": "500", "status": 500,
"message": "failed to get JWK", "message": "failed to get JWK",
}) })
return return
@@ -1,41 +1,28 @@
package controller_test package controller_test
import ( import (
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http/httptest" "net/http/httptest"
"path" "sync"
"testing" "testing"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/bootstrap"
"github.com/tinyauthapp/tinyauth/internal/controller" "github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/repository/memory"
"github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog" "github.com/tinyauthapp/tinyauth/internal/test"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
) )
func TestWellKnownController(t *testing.T) { func TestWellKnownController(t *testing.T) {
tlog.NewTestLogger().Init() log := logger.NewLogger().WithTestConfig()
tempDir := t.TempDir() log.Init()
oidcServiceCfg := service.OIDCServiceConfig{ cfg, runtime := test.CreateTestConfigs(t)
Clients: map[string]model.OIDCClientConfig{
"test": {
ClientID: "some-client-id",
ClientSecret: "some-client-secret",
TrustedRedirectURIs: []string{"https://test.example.com/callback"},
Name: "Test Client",
},
},
PrivateKeyPath: path.Join(tempDir, "key.pem"),
PublicKeyPath: path.Join(tempDir, "key.pub"),
Issuer: "https://tinyauth.example.com",
SessionExpiry: 500,
}
type testCase struct { type testCase struct {
description string description string
@@ -56,11 +43,11 @@ func TestWellKnownController(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
expected := controller.OpenIDConnectConfiguration{ expected := controller.OpenIDConnectConfiguration{
Issuer: oidcServiceCfg.Issuer, Issuer: runtime.AppURL,
AuthorizationEndpoint: fmt.Sprintf("%s/authorize", oidcServiceCfg.Issuer), AuthorizationEndpoint: fmt.Sprintf("%s/authorize", runtime.AppURL),
TokenEndpoint: fmt.Sprintf("%s/api/oidc/token", oidcServiceCfg.Issuer), TokenEndpoint: fmt.Sprintf("%s/api/oidc/token", runtime.AppURL),
UserinfoEndpoint: fmt.Sprintf("%s/api/oidc/userinfo", oidcServiceCfg.Issuer), UserinfoEndpoint: fmt.Sprintf("%s/api/oidc/userinfo", runtime.AppURL),
JwksUri: fmt.Sprintf("%s/.well-known/jwks.json", oidcServiceCfg.Issuer), JwksUri: fmt.Sprintf("%s/.well-known/jwks.json", runtime.AppURL),
ScopesSupported: service.SupportedScopes, ScopesSupported: service.SupportedScopes,
ResponseTypesSupported: service.SupportedResponseTypes, ResponseTypesSupported: service.SupportedResponseTypes,
GrantTypesSupported: service.SupportedGrantTypes, GrantTypesSupported: service.SupportedGrantTypes,
@@ -101,15 +88,12 @@ func TestWellKnownController(t *testing.T) {
}, },
} }
app := bootstrap.NewBootstrapApp(model.Config{}) ctx := context.TODO()
wg := &sync.WaitGroup{}
db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db")) store := memory.New()
require.NoError(t, err)
queries := repository.New(db) oidcService, err := service.NewOIDCService(log, cfg, runtime, store, ctx, wg)
oidcService := service.NewOIDCService(oidcServiceCfg, queries)
err = oidcService.Init()
require.NoError(t, err) require.NoError(t, err)
for _, test := range tests { for _, test := range tests {
@@ -119,15 +103,9 @@ func TestWellKnownController(t *testing.T) {
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
wellKnownController := controller.NewWellKnownController(controller.WellKnownControllerConfig{}, oidcService, router) controller.NewWellKnownController(oidcService, &router.RouterGroup)
wellKnownController.SetupRoutes()
test.run(t, router, recorder) test.run(t, router, recorder)
}) })
} }
t.Cleanup(func() {
err = db.Close()
require.NoError(t, err)
})
} }
+25 -27
View File
@@ -10,7 +10,7 @@ import (
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
@@ -35,29 +35,27 @@ var (
} }
) )
type ContextMiddlewareConfig struct {
CookieDomain string
SessionCookieName string
}
type ContextMiddleware struct { type ContextMiddleware struct {
config ContextMiddlewareConfig log *logger.Logger
auth *service.AuthService runtime model.RuntimeConfig
broker *service.OAuthBrokerService auth *service.AuthService
broker *service.OAuthBrokerService
} }
func NewContextMiddleware(config ContextMiddlewareConfig, auth *service.AuthService, broker *service.OAuthBrokerService) *ContextMiddleware { func NewContextMiddleware(
log *logger.Logger,
runtime model.RuntimeConfig,
auth *service.AuthService,
broker *service.OAuthBrokerService,
) *ContextMiddleware {
return &ContextMiddleware{ return &ContextMiddleware{
config: config, log: log,
auth: auth, runtime: runtime,
broker: broker, auth: auth,
broker: broker,
} }
} }
func (m *ContextMiddleware) Init() error {
return nil
}
func (m *ContextMiddleware) Middleware() gin.HandlerFunc { func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
if m.isIgnorePath(c.Request.Method + " " + c.Request.URL.Path) { if m.isIgnorePath(c.Request.Method + " " + c.Request.URL.Path) {
@@ -65,7 +63,7 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
return return
} }
uuid, err := c.Cookie(m.config.SessionCookieName) uuid, err := c.Cookie(m.runtime.SessionCookieName)
if err == nil { if err == nil {
userContext, cookie, err := m.cookieAuth(c.Request.Context(), uuid) userContext, cookie, err := m.cookieAuth(c.Request.Context(), uuid)
@@ -75,12 +73,12 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
http.SetCookie(c.Writer, cookie) http.SetCookie(c.Writer, cookie)
} }
tlog.App.Trace().Msgf("Authenticated user from session cookie: %s", userContext.GetUsername()) m.log.App.Debug().Msgf("Authenticated user %s via session cookie", userContext.GetUsername())
c.Set("context", userContext) c.Set("context", userContext)
c.Next() c.Next()
return return
} else { } else {
tlog.App.Error().Msgf("Error authenticating session cookie: %v", err) m.log.App.Debug().Msgf("Error authenticating session cookie: %v", err)
} }
} }
@@ -90,7 +88,7 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
userContext, headers, err := m.basicAuth(username, password) userContext, headers, err := m.basicAuth(username, password)
if err != nil { if err != nil {
tlog.App.Error().Msgf("Error authenticating basic auth: %v", err) m.log.App.Error().Msgf("Error authenticating basic auth: %v", err)
c.Next() c.Next()
return return
} }
@@ -141,7 +139,7 @@ func (m *ContextMiddleware) cookieAuth(ctx context.Context, uuid string) (*model
} }
if userContext.Local.Attributes.Email == "" { if userContext.Local.Attributes.Email == "" {
userContext.Local.Attributes.Email = utils.CompileUserEmail(user.Username, m.config.CookieDomain) userContext.Local.Attributes.Email = utils.CompileUserEmail(user.Username, m.runtime.CookieDomain)
} }
case model.ProviderLDAP: case model.ProviderLDAP:
search, err := m.auth.SearchUser(userContext.LDAP.Username) search, err := m.auth.SearchUser(userContext.LDAP.Username)
@@ -162,7 +160,7 @@ func (m *ContextMiddleware) cookieAuth(ctx context.Context, uuid string) (*model
userContext.LDAP.Groups = user.Groups userContext.LDAP.Groups = user.Groups
userContext.LDAP.Name = utils.Capitalize(userContext.LDAP.Username) userContext.LDAP.Name = utils.Capitalize(userContext.LDAP.Username)
userContext.LDAP.Email = utils.CompileUserEmail(userContext.LDAP.Username, m.config.CookieDomain) userContext.LDAP.Email = utils.CompileUserEmail(userContext.LDAP.Username, m.runtime.CookieDomain)
case model.ProviderOAuth: case model.ProviderOAuth:
_, exists := m.broker.GetService(userContext.OAuth.ID) _, exists := m.broker.GetService(userContext.OAuth.ID)
@@ -171,7 +169,7 @@ func (m *ContextMiddleware) cookieAuth(ctx context.Context, uuid string) (*model
} }
if !m.auth.IsEmailWhitelisted(userContext.OAuth.Email) { if !m.auth.IsEmailWhitelisted(userContext.OAuth.Email) {
m.auth.DeleteSession(ctx, uuid) //nolint:errcheck m.auth.DeleteSession(ctx, uuid)
return nil, nil, fmt.Errorf("email from session cookie not whitelisted: %s", userContext.OAuth.Email) return nil, nil, fmt.Errorf("email from session cookie not whitelisted: %s", userContext.OAuth.Email)
} }
} }
@@ -191,7 +189,7 @@ func (m *ContextMiddleware) basicAuth(username string, password string) (*model.
locked, remaining := m.auth.IsAccountLocked(username) locked, remaining := m.auth.IsAccountLocked(username)
if locked { if locked {
tlog.App.Debug().Msgf("Account for user %s is locked for %d seconds, denying auth", username, remaining) m.log.App.Debug().Msgf("Account for user %s is locked for %d seconds, denying auth", username, remaining)
headers["x-tinyauth-lock-locked"] = "true" headers["x-tinyauth-lock-locked"] = "true"
headers["x-tinyauth-lock-reset"] = time.Now().Add(time.Duration(remaining) * time.Second).Format(time.RFC3339) headers["x-tinyauth-lock-reset"] = time.Now().Add(time.Duration(remaining) * time.Second).Format(time.RFC3339)
return nil, headers, nil return nil, headers, nil
@@ -224,7 +222,7 @@ func (m *ContextMiddleware) basicAuth(username string, password string) (*model.
BaseContext: model.BaseContext{ BaseContext: model.BaseContext{
Username: user.Username, Username: user.Username,
Name: utils.Capitalize(user.Username), Name: utils.Capitalize(user.Username),
Email: utils.CompileUserEmail(user.Username, m.config.CookieDomain), Email: utils.CompileUserEmail(user.Username, m.runtime.CookieDomain),
}, },
Attributes: user.Attributes, Attributes: user.Attributes,
} }
@@ -240,7 +238,7 @@ func (m *ContextMiddleware) basicAuth(username string, password string) (*model.
BaseContext: model.BaseContext{ BaseContext: model.BaseContext{
Username: username, Username: username,
Name: utils.Capitalize(username), Name: utils.Capitalize(username),
Email: utils.CompileUserEmail(username, m.config.CookieDomain), Email: utils.CompileUserEmail(username, m.runtime.CookieDomain),
}, },
Groups: user.Groups, Groups: user.Groups,
} }
+16 -57
View File
@@ -5,54 +5,33 @@ import (
"encoding/base64" "encoding/base64"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"path" "sync"
"testing" "testing"
"time" "time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/bootstrap"
"github.com/tinyauthapp/tinyauth/internal/middleware" "github.com/tinyauthapp/tinyauth/internal/middleware"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
"github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog" "github.com/tinyauthapp/tinyauth/internal/test"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
) )
func TestContextMiddleware(t *testing.T) { func TestContextMiddleware(t *testing.T) {
tlog.NewTestLogger().Init() log := logger.NewLogger().WithTestConfig()
tempDir := t.TempDir() log.Init()
authServiceCfg := service.AuthServiceConfig{ cfg, runtime := test.CreateTestConfigs(t)
LocalUsers: &[]model.LocalUser{
{
Username: "testuser",
Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
},
{
Username: "totpuser",
Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
TOTPSecret: "JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK",
},
},
SessionExpiry: 10, // 10 seconds, useful for testing
CookieDomain: "example.com",
LoginTimeout: 10, // 10 seconds, useful for testing
LoginMaxRetries: 3,
SessionCookieName: "tinyauth-session",
}
middlewareCfg := middleware.ContextMiddlewareConfig{
CookieDomain: "example.com",
SessionCookieName: "tinyauth-session",
}
basicAuthHeader := func(username, password string) string { basicAuthHeader := func(username, password string) string {
return "Basic " + base64.StdEncoding.EncodeToString([]byte(username+":"+password)) return "Basic " + base64.StdEncoding.EncodeToString([]byte(username+":"+password))
} }
seedSession := func(t *testing.T, queries *repository.Queries, params repository.CreateSessionParams) { seedSession := func(t *testing.T, queries repository.Store, params repository.CreateSessionParams) {
t.Helper() t.Helper()
_, err := queries.CreateSession(context.Background(), params) _, err := queries.CreateSession(context.Background(), params)
require.NoError(t, err) require.NoError(t, err)
@@ -60,7 +39,7 @@ func TestContextMiddleware(t *testing.T) {
type runArgs struct { type runArgs struct {
do func(req *http.Request) (*model.UserContext, *httptest.ResponseRecorder) do func(req *http.Request) (*model.UserContext, *httptest.ResponseRecorder)
queries *repository.Queries queries repository.Store
} }
type testCase struct { type testCase struct {
@@ -270,30 +249,15 @@ func TestContextMiddleware(t *testing.T) {
}, },
} }
oauthBrokerCfgs := make(map[string]model.OAuthServiceConfig) ctx := context.TODO()
wg := &sync.WaitGroup{}
app := bootstrap.NewBootstrapApp(model.Config{}) store := memory.New()
db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db")) broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx)
require.NoError(t, err) authService := service.NewAuthService(log, cfg, runtime, ctx, wg, nil, store, broker)
queries := repository.New(db) contextMiddleware := middleware.NewContextMiddleware(log, runtime, authService, broker)
ldap := service.NewLdapService(service.LdapServiceConfig{})
err = ldap.Init()
require.NoError(t, err)
broker := service.NewOAuthBrokerService(oauthBrokerCfgs)
err = broker.Init()
require.NoError(t, err)
authService := service.NewAuthService(authServiceCfg, ldap, queries, broker)
err = authService.Init()
require.NoError(t, err)
contextMiddleware := middleware.NewContextMiddleware(middlewareCfg, authService, broker)
err = contextMiddleware.Init()
require.NoError(t, err)
for _, test := range tests { for _, test := range tests {
authService.ClearRateLimitsTestingOnly() authService.ClearRateLimitsTestingOnly()
@@ -317,12 +281,7 @@ func TestContextMiddleware(t *testing.T) {
return captured, recorder return captured, recorder
} }
test.run(t, runArgs{do: do, queries: queries}) test.run(t, runArgs{do: do, queries: store})
}) })
} }
t.Cleanup(func() {
err = db.Close()
require.NoError(t, err)
})
} }
+4 -9
View File
@@ -9,7 +9,6 @@ import (
"time" "time"
"github.com/tinyauthapp/tinyauth/internal/assets" "github.com/tinyauthapp/tinyauth/internal/assets"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
@@ -19,29 +18,25 @@ type UIMiddleware struct {
uiFileServer http.Handler uiFileServer http.Handler
} }
func NewUIMiddleware() *UIMiddleware { func NewUIMiddleware() (*UIMiddleware, error) {
return &UIMiddleware{} m := &UIMiddleware{}
}
func (m *UIMiddleware) Init() error {
ui, err := fs.Sub(assets.FrontendAssets, "dist") ui, err := fs.Sub(assets.FrontendAssets, "dist")
if err != nil { if err != nil {
return err return nil, fmt.Errorf("failed to load ui assets: %w", err)
} }
m.uiFs = ui m.uiFs = ui
m.uiFileServer = http.FileServerFS(ui) m.uiFileServer = http.FileServerFS(ui)
return nil return m, nil
} }
func (m *UIMiddleware) Middleware() gin.HandlerFunc { func (m *UIMiddleware) Middleware() gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
path := strings.TrimPrefix(c.Request.URL.Path, "/") path := strings.TrimPrefix(c.Request.URL.Path, "/")
tlog.App.Debug().Str("path", path).Msg("path")
switch strings.SplitN(path, "/", 2)[0] { switch strings.SplitN(path, "/", 2)[0] {
case "api", "resources", ".well-known": case "api", "resources", ".well-known":
c.Next() c.Next()
+8 -8
View File
@@ -5,7 +5,7 @@ import (
"time" "time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
) )
// See context middleware for explanation of why we have to do this // See context middleware for explanation of why we have to do this
@@ -17,14 +17,14 @@ var (
} }
) )
type ZerologMiddleware struct{} type ZerologMiddleware struct {
log *logger.Logger
func NewZerologMiddleware() *ZerologMiddleware {
return &ZerologMiddleware{}
} }
func (m *ZerologMiddleware) Init() error { func NewZerologMiddleware(log *logger.Logger) *ZerologMiddleware {
return nil return &ZerologMiddleware{
log: log,
}
} }
func (m *ZerologMiddleware) logPath(path string) bool { func (m *ZerologMiddleware) logPath(path string) bool {
@@ -50,7 +50,7 @@ func (m *ZerologMiddleware) Middleware() gin.HandlerFunc {
latency := time.Since(tStart).String() latency := time.Since(tStart).String()
subLogger := tlog.HTTP.With().Str("method", method). subLogger := m.log.HTTP.With().Str("method", method).
Str("path", path). Str("path", path).
Str("address", address). Str("address", address).
Str("client_ip", clientIP). Str("client_ip", clientIP).
+15 -11
View File
@@ -4,7 +4,8 @@ package model
func NewDefaultConfiguration() *Config { func NewDefaultConfiguration() *Config {
return &Config{ return &Config{
Database: DatabaseConfig{ Database: DatabaseConfig{
Path: "./tinyauth.db", Driver: "sqlite",
Path: "./tinyauth.db",
}, },
Analytics: AnalyticsConfig{ Analytics: AnalyticsConfig{
Enabled: true, Enabled: true,
@@ -14,8 +15,9 @@ func NewDefaultConfiguration() *Config {
Path: "./resources", Path: "./resources",
}, },
Server: ServerConfig{ Server: ServerConfig{
Port: 3000, Port: 3000,
Address: "0.0.0.0", Address: "0.0.0.0",
ConcurrentListenersEnabled: false,
}, },
Auth: AuthConfig{ Auth: AuthConfig{
SubdomainsEnabled: true, SubdomainsEnabled: true,
@@ -82,7 +84,8 @@ type Config struct {
} }
type DatabaseConfig struct { type DatabaseConfig struct {
Path string `description:"The path to the database, including file name." yaml:"path"` Driver string `description:"The database driver to use. Valid values: sqlite, memory." yaml:"driver"`
Path string `description:"The path to the SQLite database, including file name. Only used when driver is sqlite." yaml:"path"`
} }
type AnalyticsConfig struct { type AnalyticsConfig struct {
@@ -95,9 +98,10 @@ type ResourcesConfig struct {
} }
type ServerConfig struct { type ServerConfig struct {
Port int `description:"The port on which the server listens." yaml:"port"` Port int `description:"The port on which the server listens." yaml:"port"`
Address string `description:"The address on which the server listens." yaml:"address"` Address string `description:"The address on which the server listens." yaml:"address"`
SocketPath string `description:"The path to the Unix socket." yaml:"socketPath"` SocketPath string `description:"The path to the Unix socket." yaml:"socketPath"`
ConcurrentListenersEnabled bool `description:"Enable listening on both TCP and Unix socket at the same time." yaml:"concurrentListenersEnabled"`
} }
type AuthConfig struct { type AuthConfig struct {
@@ -147,10 +151,10 @@ type IPConfig struct {
} }
type OAuthConfig struct { type OAuthConfig struct {
Whitelist []string `description:"Comma-separated list of allowed OAuth domains." yaml:"whitelist"` Whitelist []string `description:"Comma-separated list of allowed OAuth domains." yaml:"whitelist"`
WhitelistFile string `description:"Path to the OAuth whitelist file." yaml:"whitelistFile"` WhitelistFile string `description:"Path to the OAuth whitelist file." yaml:"whitelistFile"`
AutoRedirect string `description:"The OAuth provider to use for automatic redirection." yaml:"autoRedirect"` AutoRedirect string `description:"The OAuth provider to use for automatic redirection." yaml:"autoRedirect"`
Providers map[string]OAuthServiceConfig `description:"OAuth providers configuration." yaml:"providers"` Providers map[string]OAuthServiceConfig `description:"OAuth providers configuration." yaml:"providers"`
} }
type OIDCConfig struct { type OIDCConfig struct {
+6 -2
View File
@@ -8,6 +8,10 @@ import (
"github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/repository"
) )
var (
ErrUserContextNotFound = errors.New("user context not found")
)
type ProviderType int type ProviderType int
const ( const (
@@ -74,7 +78,7 @@ func (c *UserContext) NewFromGin(ginctx *gin.Context) (*UserContext, error) {
userContextValue, exists := ginctx.Get("context") userContextValue, exists := ginctx.Get("context")
if !exists { if !exists {
return nil, errors.New("failed to get user context") return nil, ErrUserContextNotFound
} }
userContext, ok := userContextValue.(*UserContext) userContext, ok := userContextValue.(*UserContext)
@@ -117,7 +121,7 @@ func (c *UserContext) NewFromSession(session *repository.Session) (*UserContext,
Email: session.Email, Email: session.Email,
}, },
} }
// By default we assume an unkown name which is oauth // By default we assume an unknown name which is oauth
default: default:
c.Provider = ProviderOAuth c.Provider = ProviderOAuth
c.OAuth = &OAuthContext{ c.OAuth = &OAuthContext{
+1 -1
View File
@@ -238,7 +238,7 @@ func TestContext(t *testing.T) {
_, err := c.NewFromGin(newGinCtx(nil, false)) _, err := c.NewFromGin(newGinCtx(nil, false))
return err.Error() return err.Error()
}, },
expected: "failed to get user context", expected: model.ErrUserContextNotFound.Error(),
}, },
{ {
description: "NewFromGin returns error when context value has wrong type", description: "NewFromGin returns error when context value has wrong type",
+22
View File
@@ -0,0 +1,22 @@
package model
type RuntimeConfig struct {
AppURL string
UUID string
CookieDomain string
SessionCookieName string
CSRFCookieName string
RedirectCookieName string
OAuthSessionCookieName string
LocalUsers []LocalUser
OAuthProviders map[string]OAuthServiceConfig
OAuthWhitelist []string
ConfiguredProviders []Provider
OIDCClients []OIDCClientConfig
}
type Provider struct {
Name string `json:"name"`
ID string `json:"id"`
OAuth bool `json:"oauth"`
}
+427
View File
@@ -0,0 +1,427 @@
package memory_test
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
)
var ctx = context.Background()
func TestCreateAndGetSession(t *testing.T) {
s := memory.New()
sess, err := s.CreateSession(ctx, repository.CreateSessionParams{
UUID: "uuid-1",
Username: "alice",
Expiry: 9999,
})
require.NoError(t, err)
assert.Equal(t, "uuid-1", sess.UUID)
assert.Equal(t, "alice", sess.Username)
got, err := s.GetSession(ctx, "uuid-1")
require.NoError(t, err)
assert.Equal(t, sess, got)
}
func TestGetSession_NotFound(t *testing.T) {
s := memory.New()
_, err := s.GetSession(ctx, "missing")
assert.ErrorIs(t, err, repository.ErrNotFound)
}
func TestUpdateSession(t *testing.T) {
s := memory.New()
_, err := s.CreateSession(ctx, repository.CreateSessionParams{UUID: "uuid-1", Username: "alice"})
require.NoError(t, err)
updated, err := s.UpdateSession(ctx, repository.UpdateSessionParams{
UUID: "uuid-1",
Username: "bob",
Email: "bob@example.com",
})
require.NoError(t, err)
assert.Equal(t, "bob", updated.Username)
assert.Equal(t, "bob@example.com", updated.Email)
got, err := s.GetSession(ctx, "uuid-1")
require.NoError(t, err)
assert.Equal(t, updated, got)
}
func TestUpdateSession_NotFound(t *testing.T) {
s := memory.New()
_, err := s.UpdateSession(ctx, repository.UpdateSessionParams{UUID: "missing"})
assert.ErrorIs(t, err, repository.ErrNotFound)
}
func TestDeleteSession(t *testing.T) {
s := memory.New()
_, err := s.CreateSession(ctx, repository.CreateSessionParams{UUID: "uuid-1"})
require.NoError(t, err)
require.NoError(t, s.DeleteSession(ctx, "uuid-1"))
_, err = s.GetSession(ctx, "uuid-1")
assert.ErrorIs(t, err, repository.ErrNotFound)
}
func TestDeleteExpiredSessions(t *testing.T) {
s := memory.New()
_, err := s.CreateSession(ctx, repository.CreateSessionParams{UUID: "expired", Expiry: 10})
require.NoError(t, err)
_, err = s.CreateSession(ctx, repository.CreateSessionParams{UUID: "valid", Expiry: 100})
require.NoError(t, err)
require.NoError(t, s.DeleteExpiredSessions(ctx, 50))
_, err = s.GetSession(ctx, "expired")
assert.ErrorIs(t, err, repository.ErrNotFound)
_, err = s.GetSession(ctx, "valid")
assert.NoError(t, err)
}
func TestCreateAndGetOidcCode(t *testing.T) {
s := memory.New()
code, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{
Sub: "sub-1",
CodeHash: "hash-1",
Scope: "openid",
})
require.NoError(t, err)
assert.Equal(t, "sub-1", code.Sub)
// destructive read removes the record
got, err := s.GetOidcCode(ctx, "hash-1")
require.NoError(t, err)
assert.Equal(t, code, got)
_, err = s.GetOidcCode(ctx, "hash-1")
assert.ErrorIs(t, err, repository.ErrNotFound)
}
func TestGetOidcCode_NotFound(t *testing.T) {
s := memory.New()
_, err := s.GetOidcCode(ctx, "missing")
assert.ErrorIs(t, err, repository.ErrNotFound)
}
func TestGetOidcCodeBySub(t *testing.T) {
s := memory.New()
_, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"})
require.NoError(t, err)
got, err := s.GetOidcCodeBySub(ctx, "sub-1")
require.NoError(t, err)
assert.Equal(t, "sub-1", got.Sub)
// destructive — gone after read
_, err = s.GetOidcCodeBySub(ctx, "sub-1")
assert.ErrorIs(t, err, repository.ErrNotFound)
}
func TestGetOidcCodeBySub_NotFound(t *testing.T) {
s := memory.New()
_, err := s.GetOidcCodeBySub(ctx, "missing")
assert.ErrorIs(t, err, repository.ErrNotFound)
}
func TestGetOidcCodeUnsafe(t *testing.T) {
s := memory.New()
_, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"})
require.NoError(t, err)
got, err := s.GetOidcCodeUnsafe(ctx, "hash-1")
require.NoError(t, err)
assert.Equal(t, "sub-1", got.Sub)
// non-destructive — still present
_, err = s.GetOidcCodeUnsafe(ctx, "hash-1")
assert.NoError(t, err)
}
func TestGetOidcCodeUnsafe_NotFound(t *testing.T) {
s := memory.New()
_, err := s.GetOidcCodeUnsafe(ctx, "missing")
assert.ErrorIs(t, err, repository.ErrNotFound)
}
func TestGetOidcCodeBySubUnsafe(t *testing.T) {
s := memory.New()
_, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"})
require.NoError(t, err)
got, err := s.GetOidcCodeBySubUnsafe(ctx, "sub-1")
require.NoError(t, err)
assert.Equal(t, "hash-1", got.CodeHash)
// non-destructive — still present
_, err = s.GetOidcCodeBySubUnsafe(ctx, "sub-1")
assert.NoError(t, err)
}
func TestGetOidcCodeBySubUnsafe_NotFound(t *testing.T) {
s := memory.New()
_, err := s.GetOidcCodeBySubUnsafe(ctx, "missing")
assert.ErrorIs(t, err, repository.ErrNotFound)
}
func TestCreateOidcCode_UniqueSubConstraint(t *testing.T) {
s := memory.New()
_, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"})
require.NoError(t, err)
_, err = s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-2"})
assert.ErrorContains(t, err, "UNIQUE constraint failed: oidc_codes.sub")
}
func TestDeleteOidcCode(t *testing.T) {
s := memory.New()
_, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"})
require.NoError(t, err)
require.NoError(t, s.DeleteOidcCode(ctx, "hash-1"))
_, err = s.GetOidcCodeUnsafe(ctx, "hash-1")
assert.ErrorIs(t, err, repository.ErrNotFound)
}
func TestDeleteOidcCodeBySub(t *testing.T) {
s := memory.New()
_, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"})
require.NoError(t, err)
require.NoError(t, s.DeleteOidcCodeBySub(ctx, "sub-1"))
_, err = s.GetOidcCodeUnsafe(ctx, "hash-1")
assert.ErrorIs(t, err, repository.ErrNotFound)
}
func TestDeleteExpiredOidcCodes(t *testing.T) {
s := memory.New()
_, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1", ExpiresAt: 10})
require.NoError(t, err)
_, err = s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-2", CodeHash: "hash-2", ExpiresAt: 100})
require.NoError(t, err)
deleted, err := s.DeleteExpiredOidcCodes(ctx, 50)
require.NoError(t, err)
require.Len(t, deleted, 1)
assert.Equal(t, "hash-1", deleted[0].CodeHash)
_, err = s.GetOidcCodeUnsafe(ctx, "hash-2")
assert.NoError(t, err)
}
func TestCreateAndGetOidcToken(t *testing.T) {
s := memory.New()
tok, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{
Sub: "sub-1",
AccessTokenHash: "at-hash-1",
CodeHash: "code-hash-1",
})
require.NoError(t, err)
assert.Equal(t, "sub-1", tok.Sub)
got, err := s.GetOidcToken(ctx, "at-hash-1")
require.NoError(t, err)
assert.Equal(t, tok, got)
}
func TestGetOidcToken_NotFound(t *testing.T) {
s := memory.New()
_, err := s.GetOidcToken(ctx, "missing")
assert.ErrorIs(t, err, repository.ErrNotFound)
}
func TestCreateOidcToken_UniqueSubConstraint(t *testing.T) {
s := memory.New()
_, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{Sub: "sub-1", AccessTokenHash: "at-1"})
require.NoError(t, err)
_, err = s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{Sub: "sub-1", AccessTokenHash: "at-2"})
assert.ErrorContains(t, err, "UNIQUE constraint failed: oidc_tokens.sub")
}
func TestGetOidcTokenByRefreshToken(t *testing.T) {
s := memory.New()
_, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{
Sub: "sub-1",
AccessTokenHash: "at-1",
RefreshTokenHash: "rt-1",
})
require.NoError(t, err)
got, err := s.GetOidcTokenByRefreshToken(ctx, "rt-1")
require.NoError(t, err)
assert.Equal(t, "sub-1", got.Sub)
}
func TestGetOidcTokenByRefreshToken_NotFound(t *testing.T) {
s := memory.New()
_, err := s.GetOidcTokenByRefreshToken(ctx, "missing")
assert.ErrorIs(t, err, repository.ErrNotFound)
}
func TestGetOidcTokenBySub(t *testing.T) {
s := memory.New()
_, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{
Sub: "sub-1",
AccessTokenHash: "at-1",
})
require.NoError(t, err)
got, err := s.GetOidcTokenBySub(ctx, "sub-1")
require.NoError(t, err)
assert.Equal(t, "at-1", got.AccessTokenHash)
}
func TestGetOidcTokenBySub_NotFound(t *testing.T) {
s := memory.New()
_, err := s.GetOidcTokenBySub(ctx, "missing")
assert.ErrorIs(t, err, repository.ErrNotFound)
}
func TestUpdateOidcTokenByRefreshToken(t *testing.T) {
s := memory.New()
_, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{
Sub: "sub-1",
AccessTokenHash: "at-1",
RefreshTokenHash: "rt-1",
})
require.NoError(t, err)
updated, err := s.UpdateOidcTokenByRefreshToken(ctx, repository.UpdateOidcTokenByRefreshTokenParams{
RefreshTokenHash_2: "rt-1",
AccessTokenHash: "at-2",
RefreshTokenHash: "rt-2",
TokenExpiresAt: 200,
RefreshTokenExpiresAt: 400,
})
require.NoError(t, err)
assert.Equal(t, "at-2", updated.AccessTokenHash)
assert.Equal(t, "rt-2", updated.RefreshTokenHash)
// old key gone, new key present
_, err = s.GetOidcToken(ctx, "at-1")
assert.ErrorIs(t, err, repository.ErrNotFound)
got, err := s.GetOidcToken(ctx, "at-2")
require.NoError(t, err)
assert.Equal(t, "sub-1", got.Sub)
}
func TestUpdateOidcTokenByRefreshToken_NotFound(t *testing.T) {
s := memory.New()
_, err := s.UpdateOidcTokenByRefreshToken(ctx, repository.UpdateOidcTokenByRefreshTokenParams{
RefreshTokenHash_2: "missing",
})
assert.ErrorIs(t, err, repository.ErrNotFound)
}
func TestDeleteOidcToken(t *testing.T) {
s := memory.New()
_, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{Sub: "sub-1", AccessTokenHash: "at-1"})
require.NoError(t, err)
require.NoError(t, s.DeleteOidcToken(ctx, "at-1"))
_, err = s.GetOidcToken(ctx, "at-1")
assert.ErrorIs(t, err, repository.ErrNotFound)
}
func TestDeleteOidcTokenBySub(t *testing.T) {
s := memory.New()
_, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{Sub: "sub-1", AccessTokenHash: "at-1"})
require.NoError(t, err)
require.NoError(t, s.DeleteOidcTokenBySub(ctx, "sub-1"))
_, err = s.GetOidcToken(ctx, "at-1")
assert.ErrorIs(t, err, repository.ErrNotFound)
}
func TestDeleteOidcTokenByCodeHash(t *testing.T) {
s := memory.New()
_, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{
Sub: "sub-1",
AccessTokenHash: "at-1",
CodeHash: "code-1",
})
require.NoError(t, err)
require.NoError(t, s.DeleteOidcTokenByCodeHash(ctx, "code-1"))
_, err = s.GetOidcToken(ctx, "at-1")
assert.ErrorIs(t, err, repository.ErrNotFound)
}
func TestDeleteExpiredOidcTokens(t *testing.T) {
s := memory.New()
// expired by TokenExpiresAt
_, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{
Sub: "sub-1", AccessTokenHash: "at-1",
TokenExpiresAt: 10, RefreshTokenExpiresAt: 100,
})
require.NoError(t, err)
// expired by RefreshTokenExpiresAt
_, err = s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{
Sub: "sub-2", AccessTokenHash: "at-2",
TokenExpiresAt: 100, RefreshTokenExpiresAt: 10,
})
require.NoError(t, err)
// valid
_, err = s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{
Sub: "sub-3", AccessTokenHash: "at-3",
TokenExpiresAt: 100, RefreshTokenExpiresAt: 100,
})
require.NoError(t, err)
deleted, err := s.DeleteExpiredOidcTokens(ctx, repository.DeleteExpiredOidcTokensParams{
TokenExpiresAt: 50,
RefreshTokenExpiresAt: 50,
})
require.NoError(t, err)
assert.Len(t, deleted, 2)
_, err = s.GetOidcToken(ctx, "at-3")
assert.NoError(t, err)
}
func TestCreateAndGetOidcUserInfo(t *testing.T) {
s := memory.New()
u, err := s.CreateOidcUserInfo(ctx, repository.CreateOidcUserInfoParams{
Sub: "sub-1",
Name: "Alice",
Email: "alice@example.com",
})
require.NoError(t, err)
assert.Equal(t, "sub-1", u.Sub)
got, err := s.GetOidcUserInfo(ctx, "sub-1")
require.NoError(t, err)
assert.Equal(t, u, got)
}
func TestGetOidcUserInfo_NotFound(t *testing.T) {
s := memory.New()
_, err := s.GetOidcUserInfo(ctx, "missing")
assert.ErrorIs(t, err, repository.ErrNotFound)
}
func TestDeleteOidcUserInfo(t *testing.T) {
s := memory.New()
_, err := s.CreateOidcUserInfo(ctx, repository.CreateOidcUserInfoParams{Sub: "sub-1"})
require.NoError(t, err)
require.NoError(t, s.DeleteOidcUserInfo(ctx, "sub-1"))
_, err = s.GetOidcUserInfo(ctx, "sub-1")
assert.ErrorIs(t, err, repository.ErrNotFound)
}
+241
View File
@@ -0,0 +1,241 @@
package memory
import (
"context"
"fmt"
"github.com/tinyauthapp/tinyauth/internal/repository"
)
func (s *Store) CreateOidcCode(_ context.Context, arg repository.CreateOidcCodeParams) (repository.OidcCode, error) {
s.mu.Lock()
defer s.mu.Unlock()
// Enforce sub UNIQUE constraint
for _, c := range s.oidcCodes {
if c.Sub == arg.Sub {
return repository.OidcCode{}, fmt.Errorf("UNIQUE constraint failed: oidc_codes.sub")
}
}
code := repository.OidcCode(arg)
s.oidcCodes[arg.CodeHash] = code
return code, nil
}
// GetOidcCode is a destructive read: it deletes and returns the code (mirrors SQLite's DELETE...RETURNING).
func (s *Store) GetOidcCode(_ context.Context, codeHash string) (repository.OidcCode, error) {
s.mu.Lock()
defer s.mu.Unlock()
c, ok := s.oidcCodes[codeHash]
if !ok {
return repository.OidcCode{}, repository.ErrNotFound
}
delete(s.oidcCodes, codeHash)
return c, nil
}
// GetOidcCodeBySub is a destructive read: it deletes and returns the code (mirrors SQLite's DELETE...RETURNING).
func (s *Store) GetOidcCodeBySub(_ context.Context, sub string) (repository.OidcCode, error) {
s.mu.Lock()
defer s.mu.Unlock()
for k, c := range s.oidcCodes {
if c.Sub == sub {
delete(s.oidcCodes, k)
return c, nil
}
}
return repository.OidcCode{}, repository.ErrNotFound
}
// GetOidcCodeUnsafe is a non-destructive read (mirrors SQLite's SELECT).
func (s *Store) GetOidcCodeUnsafe(_ context.Context, codeHash string) (repository.OidcCode, error) {
s.mu.RLock()
defer s.mu.RUnlock()
c, ok := s.oidcCodes[codeHash]
if !ok {
return repository.OidcCode{}, repository.ErrNotFound
}
return c, nil
}
// GetOidcCodeBySubUnsafe is a non-destructive read (mirrors SQLite's SELECT).
func (s *Store) GetOidcCodeBySubUnsafe(_ context.Context, sub string) (repository.OidcCode, error) {
s.mu.RLock()
defer s.mu.RUnlock()
for _, c := range s.oidcCodes {
if c.Sub == sub {
return c, nil
}
}
return repository.OidcCode{}, repository.ErrNotFound
}
func (s *Store) DeleteOidcCode(_ context.Context, codeHash string) error {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.oidcCodes, codeHash)
return nil
}
func (s *Store) DeleteOidcCodeBySub(_ context.Context, sub string) error {
s.mu.Lock()
defer s.mu.Unlock()
for k, c := range s.oidcCodes {
if c.Sub == sub {
delete(s.oidcCodes, k)
}
}
return nil
}
func (s *Store) DeleteExpiredOidcCodes(_ context.Context, expiresAt int64) ([]repository.OidcCode, error) {
s.mu.Lock()
defer s.mu.Unlock()
var deleted []repository.OidcCode
for k, c := range s.oidcCodes {
if c.ExpiresAt < expiresAt {
deleted = append(deleted, c)
delete(s.oidcCodes, k)
}
}
return deleted, nil
}
func (s *Store) CreateOidcToken(_ context.Context, arg repository.CreateOidcTokenParams) (repository.OidcToken, error) {
s.mu.Lock()
defer s.mu.Unlock()
// Enforce sub UNIQUE constraint
for _, t := range s.oidcTokens {
if t.Sub == arg.Sub {
return repository.OidcToken{}, fmt.Errorf("UNIQUE constraint failed: oidc_tokens.sub")
}
}
tok := repository.OidcToken{
Sub: arg.Sub,
AccessTokenHash: arg.AccessTokenHash,
RefreshTokenHash: arg.RefreshTokenHash,
CodeHash: arg.CodeHash,
Scope: arg.Scope,
ClientID: arg.ClientID,
TokenExpiresAt: arg.TokenExpiresAt,
RefreshTokenExpiresAt: arg.RefreshTokenExpiresAt,
Nonce: arg.Nonce,
}
s.oidcTokens[arg.AccessTokenHash] = tok
return tok, nil
}
func (s *Store) GetOidcToken(_ context.Context, accessTokenHash string) (repository.OidcToken, error) {
s.mu.RLock()
defer s.mu.RUnlock()
t, ok := s.oidcTokens[accessTokenHash]
if !ok {
return repository.OidcToken{}, repository.ErrNotFound
}
return t, nil
}
func (s *Store) GetOidcTokenByRefreshToken(_ context.Context, refreshTokenHash string) (repository.OidcToken, error) {
s.mu.RLock()
defer s.mu.RUnlock()
for _, t := range s.oidcTokens {
if t.RefreshTokenHash == refreshTokenHash {
return t, nil
}
}
return repository.OidcToken{}, repository.ErrNotFound
}
func (s *Store) GetOidcTokenBySub(_ context.Context, sub string) (repository.OidcToken, error) {
s.mu.RLock()
defer s.mu.RUnlock()
for _, t := range s.oidcTokens {
if t.Sub == sub {
return t, nil
}
}
return repository.OidcToken{}, repository.ErrNotFound
}
func (s *Store) UpdateOidcTokenByRefreshToken(_ context.Context, arg repository.UpdateOidcTokenByRefreshTokenParams) (repository.OidcToken, error) {
s.mu.Lock()
defer s.mu.Unlock()
for k, t := range s.oidcTokens {
if t.RefreshTokenHash == arg.RefreshTokenHash_2 {
delete(s.oidcTokens, k)
t.AccessTokenHash = arg.AccessTokenHash
t.RefreshTokenHash = arg.RefreshTokenHash
t.TokenExpiresAt = arg.TokenExpiresAt
t.RefreshTokenExpiresAt = arg.RefreshTokenExpiresAt
s.oidcTokens[arg.AccessTokenHash] = t
return t, nil
}
}
return repository.OidcToken{}, repository.ErrNotFound
}
func (s *Store) DeleteOidcToken(_ context.Context, accessTokenHash string) error {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.oidcTokens, accessTokenHash)
return nil
}
func (s *Store) DeleteOidcTokenBySub(_ context.Context, sub string) error {
s.mu.Lock()
defer s.mu.Unlock()
for k, t := range s.oidcTokens {
if t.Sub == sub {
delete(s.oidcTokens, k)
}
}
return nil
}
func (s *Store) DeleteOidcTokenByCodeHash(_ context.Context, codeHash string) error {
s.mu.Lock()
defer s.mu.Unlock()
for k, t := range s.oidcTokens {
if t.CodeHash == codeHash {
delete(s.oidcTokens, k)
}
}
return nil
}
func (s *Store) DeleteExpiredOidcTokens(_ context.Context, arg repository.DeleteExpiredOidcTokensParams) ([]repository.OidcToken, error) {
s.mu.Lock()
defer s.mu.Unlock()
var deleted []repository.OidcToken
for k, t := range s.oidcTokens {
if t.TokenExpiresAt < arg.TokenExpiresAt || t.RefreshTokenExpiresAt < arg.RefreshTokenExpiresAt {
deleted = append(deleted, t)
delete(s.oidcTokens, k)
}
}
return deleted, nil
}
func (s *Store) CreateOidcUserInfo(_ context.Context, arg repository.CreateOidcUserInfoParams) (repository.OidcUserinfo, error) {
s.mu.Lock()
defer s.mu.Unlock()
u := repository.OidcUserinfo(arg)
s.oidcUsers[arg.Sub] = u
return u, nil
}
func (s *Store) GetOidcUserInfo(_ context.Context, sub string) (repository.OidcUserinfo, error) {
s.mu.RLock()
defer s.mu.RUnlock()
u, ok := s.oidcUsers[sub]
if !ok {
return repository.OidcUserinfo{}, repository.ErrNotFound
}
return u, nil
}
func (s *Store) DeleteOidcUserInfo(_ context.Context, sub string) error {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.oidcUsers, sub)
return nil
}
@@ -0,0 +1,63 @@
package memory
import (
"context"
"github.com/tinyauthapp/tinyauth/internal/repository"
)
func (s *Store) CreateSession(_ context.Context, arg repository.CreateSessionParams) (repository.Session, error) {
s.mu.Lock()
defer s.mu.Unlock()
sess := repository.Session(arg)
s.sessions[arg.UUID] = sess
return sess, nil
}
func (s *Store) GetSession(_ context.Context, uuid string) (repository.Session, error) {
s.mu.RLock()
defer s.mu.RUnlock()
sess, ok := s.sessions[uuid]
if !ok {
return repository.Session{}, repository.ErrNotFound
}
return sess, nil
}
func (s *Store) UpdateSession(_ context.Context, arg repository.UpdateSessionParams) (repository.Session, error) {
s.mu.Lock()
defer s.mu.Unlock()
sess, ok := s.sessions[arg.UUID]
if !ok {
return repository.Session{}, repository.ErrNotFound
}
sess.Username = arg.Username
sess.Email = arg.Email
sess.Name = arg.Name
sess.Provider = arg.Provider
sess.TotpPending = arg.TotpPending
sess.OAuthGroups = arg.OAuthGroups
sess.Expiry = arg.Expiry
sess.OAuthName = arg.OAuthName
sess.OAuthSub = arg.OAuthSub
s.sessions[arg.UUID] = sess
return sess, nil
}
func (s *Store) DeleteSession(_ context.Context, uuid string) error {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.sessions, uuid)
return nil
}
func (s *Store) DeleteExpiredSessions(_ context.Context, expiry int64) error {
s.mu.Lock()
defer s.mu.Unlock()
for k, v := range s.sessions {
if v.Expiry < expiry {
delete(s.sessions, k)
}
}
return nil
}
+27
View File
@@ -0,0 +1,27 @@
// Package memory provides an in-memory implementation of repository.Store for use in tests.
package memory
import (
"sync"
"github.com/tinyauthapp/tinyauth/internal/repository"
)
// Store is a thread-safe in-memory implementation of repository.Store.
type Store struct {
mu sync.RWMutex
sessions map[string]repository.Session
oidcCodes map[string]repository.OidcCode
oidcTokens map[string]repository.OidcToken
oidcUsers map[string]repository.OidcUserinfo
}
// New returns a new empty in-memory Store.
func New() repository.Store {
return &Store{
sessions: make(map[string]repository.Session),
oidcCodes: make(map[string]repository.OidcCode),
oidcTokens: make(map[string]repository.OidcToken),
oidcUsers: make(map[string]repository.OidcUserinfo),
}
}
+89 -5
View File
@@ -1,9 +1,22 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.30.0
package repository package repository
// Shared model and parameter types for all storage drivers.
// sqlc-generated driver packages use these via the conversion layer in their store.go.
type Session struct {
UUID string
Username string
Email string
Name string
Provider string
TotpPending bool
OAuthGroups string
Expiry int64
CreatedAt int64
OAuthName string
OAuthSub string
}
type OidcCode struct { type OidcCode struct {
Sub string Sub string
CodeHash string CodeHash string
@@ -49,7 +62,7 @@ type OidcUserinfo struct {
Address string Address string
} }
type Session struct { type CreateSessionParams struct {
UUID string UUID string
Username string Username string
Email string Email string
@@ -62,3 +75,74 @@ type Session struct {
OAuthName string OAuthName string
OAuthSub string OAuthSub string
} }
type UpdateSessionParams struct {
Username string
Email string
Name string
Provider string
TotpPending bool
OAuthGroups string
Expiry int64
OAuthName string
OAuthSub string
UUID string
}
type CreateOidcCodeParams struct {
Sub string
CodeHash string
Scope string
RedirectURI string
ClientID string
ExpiresAt int64
Nonce string
CodeChallenge string
}
type CreateOidcTokenParams struct {
Sub string
AccessTokenHash string
RefreshTokenHash string
Scope string
ClientID string
TokenExpiresAt int64
RefreshTokenExpiresAt int64
CodeHash string
Nonce string
}
type UpdateOidcTokenByRefreshTokenParams struct {
AccessTokenHash string
RefreshTokenHash string
TokenExpiresAt int64
RefreshTokenExpiresAt int64
RefreshTokenHash_2 string
}
type DeleteExpiredOidcTokensParams struct {
TokenExpiresAt int64
RefreshTokenExpiresAt int64
}
type CreateOidcUserInfoParams struct {
Sub string
Name string
PreferredUsername string
Email string
Groups string
UpdatedAt int64
GivenName string
FamilyName string
MiddleName string
Nickname string
Profile string
Picture string
Website string
Gender string
Birthdate string
Zoneinfo string
Locale string
PhoneNumber string
Address string
}
@@ -1,8 +1,8 @@
// Code generated by sqlc. DO NOT EDIT. // Code generated by sqlc. DO NOT EDIT.
// versions: // versions:
// sqlc v1.30.0 // sqlc v1.31.1
package repository package sqlite
import ( import (
"context" "context"
+3
View File
@@ -0,0 +1,3 @@
package sqlite
//go:generate go run github.com/tinyauthapp/tinyauth/cmd/gen/sqlc-wrapper -pkg github.com/tinyauthapp/tinyauth/internal/repository/sqlite
+64
View File
@@ -0,0 +1,64 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.31.1
package sqlite
type OidcCode struct {
Sub string
CodeHash string
Scope string
RedirectURI string
ClientID string
ExpiresAt int64
Nonce string
CodeChallenge string
}
type OidcToken struct {
Sub string
AccessTokenHash string
RefreshTokenHash string
CodeHash string
Scope string
ClientID string
TokenExpiresAt int64
RefreshTokenExpiresAt int64
Nonce string
}
type OidcUserinfo struct {
Sub string
Name string
PreferredUsername string
Email string
Groups string
UpdatedAt int64
GivenName string
FamilyName string
MiddleName string
Nickname string
Profile string
Picture string
Website string
Gender string
Birthdate string
Zoneinfo string
Locale string
PhoneNumber string
Address string
}
type Session struct {
UUID string
Username string
Email string
Name string
Provider string
TotpPending bool
OAuthGroups string
Expiry int64
CreatedAt int64
OAuthName string
OAuthSub string
}
@@ -1,9 +1,9 @@
// Code generated by sqlc. DO NOT EDIT. // Code generated by sqlc. DO NOT EDIT.
// versions: // versions:
// sqlc v1.30.0 // sqlc v1.31.1
// source: oidc_queries.sql // source: oidc_queries.sql
package repository package sqlite
import ( import (
"context" "context"
@@ -1,9 +1,9 @@
// Code generated by sqlc. DO NOT EDIT. // Code generated by sqlc. DO NOT EDIT.
// versions: // versions:
// sqlc v1.30.0 // sqlc v1.31.1
// source: session_queries.sql // source: session_queries.sql
package repository package sqlite
import ( import (
"context" "context"
+224
View File
@@ -0,0 +1,224 @@
// Code generated by cmd/gen/sqlc-wrapper. DO NOT EDIT.
package sqlite
import (
"context"
"database/sql"
"errors"
"github.com/tinyauthapp/tinyauth/internal/repository"
)
// Store wraps *Queries and implements repository.Store.
type Store struct {
q *Queries
}
// NewStore wraps a *Queries to satisfy repository.Store.
func NewStore(q *Queries) repository.Store {
return &Store{q: q}
}
var errMap = []struct {
from error
to error
}{
{sql.ErrNoRows, repository.ErrNotFound},
}
func mapErr(err error) error {
for _, e := range errMap {
if errors.Is(err, e.from) {
return e.to
}
}
return err
}
func oidcCodeToRepo(v OidcCode) repository.OidcCode {
return repository.OidcCode(v)
}
func oidcTokenToRepo(v OidcToken) repository.OidcToken {
return repository.OidcToken(v)
}
func oidcUserinfoToRepo(v OidcUserinfo) repository.OidcUserinfo {
return repository.OidcUserinfo(v)
}
func sessionToRepo(v Session) repository.Session {
return repository.Session(v)
}
func (s *Store) CreateOidcCode(ctx context.Context, arg repository.CreateOidcCodeParams) (repository.OidcCode, error) {
r, err := s.q.CreateOidcCode(ctx, CreateOidcCodeParams(arg))
if err != nil {
return repository.OidcCode{}, mapErr(err)
}
return oidcCodeToRepo(r), nil
}
func (s *Store) CreateOidcToken(ctx context.Context, arg repository.CreateOidcTokenParams) (repository.OidcToken, error) {
r, err := s.q.CreateOidcToken(ctx, CreateOidcTokenParams(arg))
if err != nil {
return repository.OidcToken{}, mapErr(err)
}
return oidcTokenToRepo(r), nil
}
func (s *Store) CreateOidcUserInfo(ctx context.Context, arg repository.CreateOidcUserInfoParams) (repository.OidcUserinfo, error) {
r, err := s.q.CreateOidcUserInfo(ctx, CreateOidcUserInfoParams(arg))
if err != nil {
return repository.OidcUserinfo{}, mapErr(err)
}
return oidcUserinfoToRepo(r), nil
}
func (s *Store) CreateSession(ctx context.Context, arg repository.CreateSessionParams) (repository.Session, error) {
r, err := s.q.CreateSession(ctx, CreateSessionParams(arg))
if err != nil {
return repository.Session{}, mapErr(err)
}
return sessionToRepo(r), nil
}
func (s *Store) DeleteExpiredOidcCodes(ctx context.Context, expiresAt int64) ([]repository.OidcCode, error) {
rows, err := s.q.DeleteExpiredOidcCodes(ctx, expiresAt)
if err != nil {
return nil, mapErr(err)
}
out := make([]repository.OidcCode, len(rows))
for i, row := range rows {
out[i] = oidcCodeToRepo(row)
}
return out, nil
}
func (s *Store) DeleteExpiredOidcTokens(ctx context.Context, arg repository.DeleteExpiredOidcTokensParams) ([]repository.OidcToken, error) {
rows, err := s.q.DeleteExpiredOidcTokens(ctx, DeleteExpiredOidcTokensParams(arg))
if err != nil {
return nil, mapErr(err)
}
out := make([]repository.OidcToken, len(rows))
for i, row := range rows {
out[i] = oidcTokenToRepo(row)
}
return out, nil
}
func (s *Store) DeleteExpiredSessions(ctx context.Context, expiry int64) error {
return mapErr(s.q.DeleteExpiredSessions(ctx, expiry))
}
func (s *Store) DeleteOidcCode(ctx context.Context, codeHash string) error {
return mapErr(s.q.DeleteOidcCode(ctx, codeHash))
}
func (s *Store) DeleteOidcCodeBySub(ctx context.Context, sub string) error {
return mapErr(s.q.DeleteOidcCodeBySub(ctx, sub))
}
func (s *Store) DeleteOidcToken(ctx context.Context, accessTokenHash string) error {
return mapErr(s.q.DeleteOidcToken(ctx, accessTokenHash))
}
func (s *Store) DeleteOidcTokenByCodeHash(ctx context.Context, codeHash string) error {
return mapErr(s.q.DeleteOidcTokenByCodeHash(ctx, codeHash))
}
func (s *Store) DeleteOidcTokenBySub(ctx context.Context, sub string) error {
return mapErr(s.q.DeleteOidcTokenBySub(ctx, sub))
}
func (s *Store) DeleteOidcUserInfo(ctx context.Context, sub string) error {
return mapErr(s.q.DeleteOidcUserInfo(ctx, sub))
}
func (s *Store) DeleteSession(ctx context.Context, uuid string) error {
return mapErr(s.q.DeleteSession(ctx, uuid))
}
func (s *Store) GetOidcCode(ctx context.Context, codeHash string) (repository.OidcCode, error) {
r, err := s.q.GetOidcCode(ctx, codeHash)
if err != nil {
return repository.OidcCode{}, mapErr(err)
}
return oidcCodeToRepo(r), nil
}
func (s *Store) GetOidcCodeBySub(ctx context.Context, sub string) (repository.OidcCode, error) {
r, err := s.q.GetOidcCodeBySub(ctx, sub)
if err != nil {
return repository.OidcCode{}, mapErr(err)
}
return oidcCodeToRepo(r), nil
}
func (s *Store) GetOidcCodeBySubUnsafe(ctx context.Context, sub string) (repository.OidcCode, error) {
r, err := s.q.GetOidcCodeBySubUnsafe(ctx, sub)
if err != nil {
return repository.OidcCode{}, mapErr(err)
}
return oidcCodeToRepo(r), nil
}
func (s *Store) GetOidcCodeUnsafe(ctx context.Context, codeHash string) (repository.OidcCode, error) {
r, err := s.q.GetOidcCodeUnsafe(ctx, codeHash)
if err != nil {
return repository.OidcCode{}, mapErr(err)
}
return oidcCodeToRepo(r), nil
}
func (s *Store) GetOidcToken(ctx context.Context, accessTokenHash string) (repository.OidcToken, error) {
r, err := s.q.GetOidcToken(ctx, accessTokenHash)
if err != nil {
return repository.OidcToken{}, mapErr(err)
}
return oidcTokenToRepo(r), nil
}
func (s *Store) GetOidcTokenByRefreshToken(ctx context.Context, refreshTokenHash string) (repository.OidcToken, error) {
r, err := s.q.GetOidcTokenByRefreshToken(ctx, refreshTokenHash)
if err != nil {
return repository.OidcToken{}, mapErr(err)
}
return oidcTokenToRepo(r), nil
}
func (s *Store) GetOidcTokenBySub(ctx context.Context, sub string) (repository.OidcToken, error) {
r, err := s.q.GetOidcTokenBySub(ctx, sub)
if err != nil {
return repository.OidcToken{}, mapErr(err)
}
return oidcTokenToRepo(r), nil
}
func (s *Store) GetOidcUserInfo(ctx context.Context, sub string) (repository.OidcUserinfo, error) {
r, err := s.q.GetOidcUserInfo(ctx, sub)
if err != nil {
return repository.OidcUserinfo{}, mapErr(err)
}
return oidcUserinfoToRepo(r), nil
}
func (s *Store) GetSession(ctx context.Context, uuid string) (repository.Session, error) {
r, err := s.q.GetSession(ctx, uuid)
if err != nil {
return repository.Session{}, mapErr(err)
}
return sessionToRepo(r), nil
}
func (s *Store) UpdateOidcTokenByRefreshToken(ctx context.Context, arg repository.UpdateOidcTokenByRefreshTokenParams) (repository.OidcToken, error) {
r, err := s.q.UpdateOidcTokenByRefreshToken(ctx, UpdateOidcTokenByRefreshTokenParams(arg))
if err != nil {
return repository.OidcToken{}, mapErr(err)
}
return oidcTokenToRepo(r), nil
}
func (s *Store) UpdateSession(ctx context.Context, arg repository.UpdateSessionParams) (repository.Session, error) {
r, err := s.q.UpdateSession(ctx, UpdateSessionParams(arg))
if err != nil {
return repository.Session{}, mapErr(err)
}
return sessionToRepo(r), nil
}
+47
View File
@@ -0,0 +1,47 @@
package repository
import (
"context"
"errors"
)
// ErrNotFound is returned by Store methods when the requested record does not exist.
var ErrNotFound = errors.New("not found")
// Store is the interface that all storage drivers must implement.
// The sqlc-generated *Queries struct satisfies this interface for SQLite.
// Future drivers (postgres, etc.) must return the shared types defined in this package.
type Store interface {
// Sessions
CreateSession(ctx context.Context, arg CreateSessionParams) (Session, error)
GetSession(ctx context.Context, uuid string) (Session, error)
UpdateSession(ctx context.Context, arg UpdateSessionParams) (Session, error)
DeleteSession(ctx context.Context, uuid string) error
DeleteExpiredSessions(ctx context.Context, expiry int64) error
// OIDC codes
CreateOidcCode(ctx context.Context, arg CreateOidcCodeParams) (OidcCode, error)
GetOidcCode(ctx context.Context, codeHash string) (OidcCode, error)
GetOidcCodeBySub(ctx context.Context, sub string) (OidcCode, error)
GetOidcCodeUnsafe(ctx context.Context, codeHash string) (OidcCode, error)
GetOidcCodeBySubUnsafe(ctx context.Context, sub string) (OidcCode, error)
DeleteOidcCode(ctx context.Context, codeHash string) error
DeleteOidcCodeBySub(ctx context.Context, sub string) error
DeleteExpiredOidcCodes(ctx context.Context, expiresAt int64) ([]OidcCode, error)
// OIDC tokens
CreateOidcToken(ctx context.Context, arg CreateOidcTokenParams) (OidcToken, error)
GetOidcToken(ctx context.Context, accessTokenHash string) (OidcToken, error)
GetOidcTokenByRefreshToken(ctx context.Context, refreshTokenHash string) (OidcToken, error)
GetOidcTokenBySub(ctx context.Context, sub string) (OidcToken, error)
UpdateOidcTokenByRefreshToken(ctx context.Context, arg UpdateOidcTokenByRefreshTokenParams) (OidcToken, error)
DeleteOidcToken(ctx context.Context, accessTokenHash string) error
DeleteOidcTokenBySub(ctx context.Context, sub string) error
DeleteOidcTokenByCodeHash(ctx context.Context, codeHash string) error
DeleteExpiredOidcTokens(ctx context.Context, arg DeleteExpiredOidcTokensParams) ([]OidcToken, error)
// OIDC userinfo
CreateOidcUserInfo(ctx context.Context, arg CreateOidcUserInfoParams) (OidcUserinfo, error)
GetOidcUserInfo(ctx context.Context, sub string) (OidcUserinfo, error)
DeleteOidcUserInfo(ctx context.Context, sub string) error
}
+18 -13
View File
@@ -4,7 +4,7 @@ import (
"strings" "strings"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
) )
type LabelProvider interface { type LabelProvider interface {
@@ -12,32 +12,33 @@ type LabelProvider interface {
} }
type AccessControlsService struct { type AccessControlsService struct {
labelProvider LabelProvider log *logger.Logger
labelProvider *LabelProvider
static map[string]model.App static map[string]model.App
} }
func NewAccessControlsService(labelProvider LabelProvider, static map[string]model.App) *AccessControlsService { func NewAccessControlsService(
log *logger.Logger,
labelProvider *LabelProvider,
static map[string]model.App) *AccessControlsService {
return &AccessControlsService{ return &AccessControlsService{
log: log,
labelProvider: labelProvider, labelProvider: labelProvider,
static: static, static: static,
} }
} }
func (acls *AccessControlsService) Init() error {
return nil // No initialization needed
}
func (acls *AccessControlsService) lookupStaticACLs(domain string) *model.App { func (acls *AccessControlsService) lookupStaticACLs(domain string) *model.App {
var appAcls *model.App var appAcls *model.App
for app, config := range acls.static { for app, config := range acls.static {
if config.Config.Domain == domain { if config.Config.Domain == domain {
tlog.App.Debug().Str("name", app).Msg("Found matching container by domain") acls.log.App.Debug().Str("name", app).Msg("Found matching container by domain")
appAcls = &config appAcls = &config
break // If we find a match by domain, we can stop searching break // If we find a match by domain, we can stop searching
} }
if strings.SplitN(domain, ".", 2)[0] == app { if strings.SplitN(domain, ".", 2)[0] == app {
tlog.App.Debug().Str("name", app).Msg("Found matching container by app name") acls.log.App.Debug().Str("name", app).Msg("Found matching container by app name")
appAcls = &config appAcls = &config
break // If we find a match by app name, we can stop searching break // If we find a match by app name, we can stop searching
} }
@@ -50,11 +51,15 @@ func (acls *AccessControlsService) GetAccessControls(domain string) (*model.App,
app := acls.lookupStaticACLs(domain) app := acls.lookupStaticACLs(domain)
if app != nil { if app != nil {
tlog.App.Debug().Msg("Using ACls from static configuration") acls.log.App.Debug().Msg("Using static ACLs for app")
return app, nil return app, nil
} }
// Fallback to label provider // If we have a label provider configured, try to get ACLs from it
tlog.App.Debug().Msg("Falling back to label provider for ACLs") if acls.labelProvider != nil {
return acls.labelProvider.GetLabels(domain) return (*acls.labelProvider).GetLabels(domain)
}
// no labels
return nil, nil
} }
+103 -95
View File
@@ -2,7 +2,6 @@ package service
import ( import (
"context" "context"
"database/sql"
"errors" "errors"
"fmt" "fmt"
"net/http" "net/http"
@@ -14,7 +13,7 @@ import (
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
"slices" "slices"
@@ -72,39 +71,41 @@ type Lockdown struct {
ActiveUntil time.Time ActiveUntil time.Time
} }
type AuthServiceConfig struct {
LocalUsers *[]model.LocalUser
OauthWhitelist []string
SessionExpiry int
SessionMaxLifetime int
SecureCookie bool
CookieDomain string
LoginTimeout int
LoginMaxRetries int
SessionCookieName string
IP model.IPConfig
LDAPGroupsCacheTTL int
SubdomainsEnabled bool
}
type AuthService struct { type AuthService struct {
config AuthServiceConfig log *logger.Logger
config model.Config
runtime model.RuntimeConfig
context context.Context
ldap *LdapService
queries repository.Store
oauthBroker *OAuthBrokerService
loginAttempts map[string]*LoginAttempt loginAttempts map[string]*LoginAttempt
ldapGroupsCache map[string]*LdapGroupsCache ldapGroupsCache map[string]*LdapGroupsCache
oauthPendingSessions map[string]*OAuthPendingSession oauthPendingSessions map[string]*OAuthPendingSession
oauthMutex sync.RWMutex oauthMutex sync.RWMutex
loginMutex sync.RWMutex loginMutex sync.RWMutex
ldapGroupsMutex sync.RWMutex ldapGroupsMutex sync.RWMutex
ldap *LdapService
queries *repository.Queries
oauthBroker *OAuthBrokerService
lockdown *Lockdown lockdown *Lockdown
lockdownCtx context.Context lockdownCtx context.Context
lockdownCancelFunc context.CancelFunc lockdownCancelFunc context.CancelFunc
} }
func NewAuthService(config AuthServiceConfig, ldap *LdapService, queries *repository.Queries, oauthBroker *OAuthBrokerService) *AuthService { func NewAuthService(
return &AuthService{ log *logger.Logger,
config model.Config,
runtime model.RuntimeConfig,
ctx context.Context,
wg *sync.WaitGroup,
ldap *LdapService,
queries repository.Store,
oauthBroker *OAuthBrokerService,
) *AuthService {
service := &AuthService{
log: log,
runtime: runtime,
context: ctx,
config: config, config: config,
loginAttempts: make(map[string]*LoginAttempt), loginAttempts: make(map[string]*LoginAttempt),
ldapGroupsCache: make(map[string]*LdapGroupsCache), ldapGroupsCache: make(map[string]*LdapGroupsCache),
@@ -113,11 +114,10 @@ func NewAuthService(config AuthServiceConfig, ldap *LdapService, queries *reposi
queries: queries, queries: queries,
oauthBroker: oauthBroker, oauthBroker: oauthBroker,
} }
}
func (auth *AuthService) Init() error { wg.Go(service.CleanupOAuthSessionsRoutine)
go auth.CleanupOAuthSessionsRoutine()
return nil return service
} }
func (auth *AuthService) SearchUser(username string) (*model.UserSearch, error) { func (auth *AuthService) SearchUser(username string) (*model.UserSearch, error) {
@@ -128,7 +128,7 @@ func (auth *AuthService) SearchUser(username string) (*model.UserSearch, error)
}, nil }, nil
} }
if auth.ldap.IsConfigured() { if auth.ldap != nil {
userDN, err := auth.ldap.GetUserDN(username) userDN, err := auth.ldap.GetUserDN(username)
if err != nil { if err != nil {
@@ -153,7 +153,7 @@ func (auth *AuthService) CheckUserPassword(search model.UserSearch, password str
} }
return bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)) return bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password))
case model.UserLDAP: case model.UserLDAP:
if auth.ldap.IsConfigured() { if auth.ldap != nil {
err := auth.ldap.Bind(search.Username, password) err := auth.ldap.Bind(search.Username, password)
if err != nil { if err != nil {
return fmt.Errorf("failed to bind to ldap user: %w", err) return fmt.Errorf("failed to bind to ldap user: %w", err)
@@ -173,10 +173,10 @@ func (auth *AuthService) CheckUserPassword(search model.UserSearch, password str
} }
func (auth *AuthService) GetLocalUser(username string) *model.LocalUser { func (auth *AuthService) GetLocalUser(username string) *model.LocalUser {
if auth.config.LocalUsers == nil { if auth.runtime.LocalUsers == nil {
return nil return nil
} }
for _, user := range *auth.config.LocalUsers { for _, user := range auth.runtime.LocalUsers {
if user.Username == username { if user.Username == username {
return &user return &user
} }
@@ -185,7 +185,7 @@ func (auth *AuthService) GetLocalUser(username string) *model.LocalUser {
} }
func (auth *AuthService) GetLDAPUser(userDN string) (*model.LDAPUser, error) { func (auth *AuthService) GetLDAPUser(userDN string) (*model.LDAPUser, error) {
if !auth.ldap.IsConfigured() { if auth.ldap == nil {
return nil, errors.New("ldap service not configured") return nil, errors.New("ldap service not configured")
} }
@@ -209,7 +209,7 @@ func (auth *AuthService) GetLDAPUser(userDN string) (*model.LDAPUser, error) {
auth.ldapGroupsMutex.Lock() auth.ldapGroupsMutex.Lock()
auth.ldapGroupsCache[userDN] = &LdapGroupsCache{ auth.ldapGroupsCache[userDN] = &LdapGroupsCache{
Groups: groups, Groups: groups,
Expires: time.Now().Add(time.Duration(auth.config.LDAPGroupsCacheTTL) * time.Second), Expires: time.Now().Add(time.Duration(auth.config.LDAP.GroupCacheTTL) * time.Second),
} }
auth.ldapGroupsMutex.Unlock() auth.ldapGroupsMutex.Unlock()
@@ -228,7 +228,7 @@ func (auth *AuthService) IsAccountLocked(identifier string) (bool, int) {
return true, remaining return true, remaining
} }
if auth.config.LoginMaxRetries <= 0 || auth.config.LoginTimeout <= 0 { if auth.config.Auth.LoginMaxRetries <= 0 || auth.config.Auth.LoginTimeout <= 0 {
return false, 0 return false, 0
} }
@@ -246,7 +246,7 @@ func (auth *AuthService) IsAccountLocked(identifier string) (bool, int) {
} }
func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) { func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) {
if auth.config.LoginMaxRetries <= 0 || auth.config.LoginTimeout <= 0 { if auth.config.Auth.LoginMaxRetries <= 0 || auth.config.Auth.LoginTimeout <= 0 {
return return
} }
@@ -277,14 +277,14 @@ func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) {
attempt.FailedAttempts++ attempt.FailedAttempts++
if attempt.FailedAttempts >= auth.config.LoginMaxRetries { if attempt.FailedAttempts >= auth.config.Auth.LoginMaxRetries {
attempt.LockedUntil = time.Now().Add(time.Duration(auth.config.LoginTimeout) * time.Second) attempt.LockedUntil = time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second)
tlog.App.Warn().Str("identifier", identifier).Int("timeout", auth.config.LoginTimeout).Msg("Account locked due to too many failed login attempts") auth.log.App.Warn().Str("identifier", identifier).Int("failedAttempts", attempt.FailedAttempts).Msg("Account locked due to too many failed login attempts")
} }
} }
func (auth *AuthService) IsEmailWhitelisted(email string) bool { func (auth *AuthService) IsEmailWhitelisted(email string) bool {
return utils.CheckFilter(strings.Join(auth.config.OauthWhitelist, ","), email) return utils.CheckFilter(strings.Join(auth.runtime.OAuthWhitelist, ","), email)
} }
func (auth *AuthService) CreateSession(ctx context.Context, data repository.Session) (*http.Cookie, error) { func (auth *AuthService) CreateSession(ctx context.Context, data repository.Session) (*http.Cookie, error) {
@@ -299,7 +299,7 @@ func (auth *AuthService) CreateSession(ctx context.Context, data repository.Sess
if data.TotpPending { if data.TotpPending {
expiry = 3600 expiry = 3600
} else { } else {
expiry = auth.config.SessionExpiry expiry = auth.config.Auth.SessionExpiry
} }
expiresAt := time.Now().Add(time.Duration(expiry) * time.Second) expiresAt := time.Now().Add(time.Duration(expiry) * time.Second)
@@ -325,13 +325,13 @@ func (auth *AuthService) CreateSession(ctx context.Context, data repository.Sess
} }
return &http.Cookie{ return &http.Cookie{
Name: auth.config.SessionCookieName, Name: auth.runtime.SessionCookieName,
Value: session.UUID, Value: session.UUID,
Path: "/", Path: "/",
Domain: fmt.Sprintf(".%s", auth.config.CookieDomain), Domain: fmt.Sprintf(".%s", auth.runtime.CookieDomain),
Expires: expiresAt, Expires: expiresAt,
MaxAge: int(time.Until(expiresAt).Seconds()), MaxAge: int(time.Until(expiresAt).Seconds()),
Secure: auth.config.SecureCookie, Secure: auth.config.Auth.SecureCookie,
HttpOnly: true, HttpOnly: true,
SameSite: http.SameSiteLaxMode, SameSite: http.SameSiteLaxMode,
}, nil }, nil
@@ -348,8 +348,8 @@ func (auth *AuthService) RefreshSession(ctx context.Context, uuid string) (*http
var refreshThreshold int64 var refreshThreshold int64
if auth.config.SessionExpiry <= int(time.Hour.Seconds()) { if auth.config.Auth.SessionExpiry <= int(time.Hour.Seconds()) {
refreshThreshold = int64(auth.config.SessionExpiry / 2) refreshThreshold = int64(auth.config.Auth.SessionExpiry / 2)
} else { } else {
refreshThreshold = int64(time.Hour.Seconds()) refreshThreshold = int64(time.Hour.Seconds())
} }
@@ -378,13 +378,13 @@ func (auth *AuthService) RefreshSession(ctx context.Context, uuid string) (*http
} }
return &http.Cookie{ return &http.Cookie{
Name: auth.config.SessionCookieName, Name: auth.runtime.SessionCookieName,
Value: session.UUID, Value: session.UUID,
Path: "/", Path: "/",
Domain: fmt.Sprintf(".%s", auth.config.CookieDomain), Domain: fmt.Sprintf(".%s", auth.runtime.CookieDomain),
Expires: time.Now().Add(time.Duration(newExpiry-currentTime) * time.Second), Expires: time.Now().Add(time.Duration(newExpiry-currentTime) * time.Second),
MaxAge: int(newExpiry - currentTime), MaxAge: int(newExpiry - currentTime),
Secure: auth.config.SecureCookie, Secure: auth.config.Auth.SecureCookie,
HttpOnly: true, HttpOnly: true,
SameSite: http.SameSiteLaxMode, SameSite: http.SameSiteLaxMode,
}, nil }, nil
@@ -395,23 +395,17 @@ func (auth *AuthService) DeleteSession(ctx context.Context, uuid string) (*http.
err := auth.queries.DeleteSession(ctx, uuid) err := auth.queries.DeleteSession(ctx, uuid)
if err != nil { if err != nil {
tlog.App.Warn().Err(err).Msg("Failed to delete session from database, proceeding to clear cookie anyway") auth.log.App.Error().Err(err).Str("uuid", uuid).Msg("Failed to delete session from database")
}
err = auth.queries.DeleteSession(ctx, uuid)
if err != nil {
return nil, err
} }
return &http.Cookie{ return &http.Cookie{
Name: auth.config.SessionCookieName, Name: auth.runtime.SessionCookieName,
Value: "", Value: "",
Path: "/", Path: "/",
Domain: fmt.Sprintf(".%s", auth.config.CookieDomain), Domain: fmt.Sprintf(".%s", auth.runtime.CookieDomain),
Expires: time.Now(), Expires: time.Now(),
MaxAge: -1, MaxAge: -1,
Secure: auth.config.SecureCookie, Secure: auth.config.Auth.SecureCookie,
HttpOnly: true, HttpOnly: true,
SameSite: http.SameSiteLaxMode, SameSite: http.SameSiteLaxMode,
}, nil }, nil
@@ -421,7 +415,7 @@ func (auth *AuthService) GetSession(ctx context.Context, uuid string) (*reposito
session, err := auth.queries.GetSession(ctx, uuid) session, err := auth.queries.GetSession(ctx, uuid)
if err != nil { if err != nil {
if errors.Is(err, sql.ErrNoRows) { if errors.Is(err, repository.ErrNotFound) {
return nil, errors.New("session not found") return nil, errors.New("session not found")
} }
return nil, err return nil, err
@@ -429,8 +423,8 @@ func (auth *AuthService) GetSession(ctx context.Context, uuid string) (*reposito
currentTime := time.Now().Unix() currentTime := time.Now().Unix()
if auth.config.SessionMaxLifetime != 0 && session.CreatedAt != 0 { if auth.config.Auth.SessionMaxLifetime != 0 && session.CreatedAt != 0 {
if currentTime-session.CreatedAt > int64(auth.config.SessionMaxLifetime) { if currentTime-session.CreatedAt > int64(auth.config.Auth.SessionMaxLifetime) {
err = auth.queries.DeleteSession(ctx, uuid) err = auth.queries.DeleteSession(ctx, uuid)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to delete expired session: %w", err) return nil, fmt.Errorf("failed to delete expired session: %w", err)
@@ -451,11 +445,11 @@ func (auth *AuthService) GetSession(ctx context.Context, uuid string) (*reposito
} }
func (auth *AuthService) LocalAuthConfigured() bool { func (auth *AuthService) LocalAuthConfigured() bool {
return auth.config.LocalUsers != nil && len(*auth.config.LocalUsers) > 0 return len(auth.runtime.LocalUsers) > 0
} }
func (auth *AuthService) LDAPAuthConfigured() bool { func (auth *AuthService) LDAPAuthConfigured() bool {
return auth.ldap.IsConfigured() return auth.ldap != nil
} }
func (auth *AuthService) IsUserAllowed(c *gin.Context, context model.UserContext, acls *model.App) bool { func (auth *AuthService) IsUserAllowed(c *gin.Context, context model.UserContext, acls *model.App) bool {
@@ -464,18 +458,18 @@ func (auth *AuthService) IsUserAllowed(c *gin.Context, context model.UserContext
} }
if context.Provider == model.ProviderOAuth { if context.Provider == model.ProviderOAuth {
tlog.App.Debug().Msg("Checking OAuth whitelist") auth.log.App.Debug().Msg("User is an OAuth user, checking OAuth whitelist")
return utils.CheckFilter(acls.OAuth.Whitelist, context.OAuth.Email) return utils.CheckFilter(acls.OAuth.Whitelist, context.OAuth.Email)
} }
if acls.Users.Block != "" { if acls.Users.Block != "" {
tlog.App.Debug().Msg("Checking blocked users") auth.log.App.Debug().Msg("Checking users block list")
if utils.CheckFilter(acls.Users.Block, context.GetUsername()) { if utils.CheckFilter(acls.Users.Block, context.GetUsername()) {
return false return false
} }
} }
tlog.App.Debug().Msg("Checking users") auth.log.App.Debug().Msg("Checking users allow list")
return utils.CheckFilter(acls.Users.Allow, context.GetUsername()) return utils.CheckFilter(acls.Users.Allow, context.GetUsername())
} }
@@ -485,23 +479,23 @@ func (auth *AuthService) IsInOAuthGroup(c *gin.Context, context model.UserContex
} }
if !context.IsOAuth() { if !context.IsOAuth() {
tlog.App.Debug().Msg("User is not an OAuth user, skipping OAuth group check") auth.log.App.Debug().Msg("User is not an OAuth user, skipping OAuth group check")
return false return false
} }
if _, ok := model.OverrideProviders[context.OAuth.ID]; ok { if _, ok := model.OverrideProviders[context.OAuth.ID]; ok {
tlog.App.Debug().Msg("Provider override for OAuth groups enabled, skipping group check") auth.log.App.Debug().Str("provider", context.OAuth.ID).Msg("Provider override detected, skipping group check")
return true return true
} }
for _, userGroup := range context.OAuth.Groups { for _, userGroup := range context.OAuth.Groups {
if utils.CheckFilter(acls.OAuth.Groups, strings.TrimSpace(userGroup)) { if utils.CheckFilter(acls.OAuth.Groups, strings.TrimSpace(userGroup)) {
tlog.App.Trace().Str("group", userGroup).Str("required", acls.OAuth.Groups).Msg("User group matched") auth.log.App.Trace().Str("group", userGroup).Str("required", acls.OAuth.Groups).Msg("User group matched")
return true return true
} }
} }
tlog.App.Debug().Msg("No groups matched") auth.log.App.Debug().Msg("No groups matched")
return false return false
} }
@@ -511,18 +505,18 @@ func (auth *AuthService) IsInLDAPGroup(c *gin.Context, context model.UserContext
} }
if !context.IsLDAP() { if !context.IsLDAP() {
tlog.App.Debug().Msg("User is not an LDAP user, skipping LDAP group check") auth.log.App.Debug().Msg("User is not an LDAP user, skipping LDAP group check")
return false return false
} }
for _, userGroup := range context.LDAP.Groups { for _, userGroup := range context.LDAP.Groups {
if utils.CheckFilter(acls.LDAP.Groups, strings.TrimSpace(userGroup)) { if utils.CheckFilter(acls.LDAP.Groups, strings.TrimSpace(userGroup)) {
tlog.App.Trace().Str("group", userGroup).Str("required", acls.LDAP.Groups).Msg("User group matched") auth.log.App.Trace().Str("group", userGroup).Str("required", acls.LDAP.Groups).Msg("User group matched")
return true return true
} }
} }
tlog.App.Debug().Msg("No groups matched") auth.log.App.Debug().Msg("No groups matched")
return false return false
} }
@@ -566,17 +560,17 @@ func (auth *AuthService) CheckIP(ip string, acls *model.App) bool {
} }
// Merge the global and app IP filter // Merge the global and app IP filter
blockedIps := append(auth.config.IP.Block, acls.IP.Block...) blockedIps := append(auth.config.Auth.IP.Block, acls.IP.Block...)
allowedIPs := append(auth.config.IP.Allow, acls.IP.Allow...) allowedIPs := append(auth.config.Auth.IP.Allow, acls.IP.Allow...)
for _, blocked := range blockedIps { for _, blocked := range blockedIps {
res, err := utils.FilterIP(blocked, ip) res, err := utils.FilterIP(blocked, ip)
if err != nil { if err != nil {
tlog.App.Warn().Err(err).Str("item", blocked).Msg("Invalid IP/CIDR in block list") auth.log.App.Warn().Err(err).Str("item", blocked).Msg("Invalid IP/CIDR in block list")
continue continue
} }
if res { if res {
tlog.App.Debug().Str("ip", ip).Str("item", blocked).Msg("IP is in blocked list, denying access") auth.log.App.Debug().Str("ip", ip).Str("item", blocked).Msg("IP is in block list, denying access")
return false return false
} }
} }
@@ -584,21 +578,21 @@ func (auth *AuthService) CheckIP(ip string, acls *model.App) bool {
for _, allowed := range allowedIPs { for _, allowed := range allowedIPs {
res, err := utils.FilterIP(allowed, ip) res, err := utils.FilterIP(allowed, ip)
if err != nil { if err != nil {
tlog.App.Warn().Err(err).Str("item", allowed).Msg("Invalid IP/CIDR in allow list") auth.log.App.Warn().Err(err).Str("item", allowed).Msg("Invalid IP/CIDR in allow list")
continue continue
} }
if res { if res {
tlog.App.Debug().Str("ip", ip).Str("item", allowed).Msg("IP is in allowed list, allowing access") auth.log.App.Debug().Str("ip", ip).Str("item", allowed).Msg("IP is in allow list, allowing access")
return true return true
} }
} }
if len(allowedIPs) > 0 { if len(allowedIPs) > 0 {
tlog.App.Debug().Str("ip", ip).Msg("IP not in allow list, denying access") auth.log.App.Debug().Str("ip", ip).Msg("IP not in allow list, denying access")
return false return false
} }
tlog.App.Debug().Str("ip", ip).Msg("IP not in allow or block list, allowing by default") auth.log.App.Debug().Str("ip", ip).Msg("IP not in any block or allow list, allowing access by default")
return true return true
} }
@@ -610,16 +604,16 @@ func (auth *AuthService) IsBypassedIP(ip string, acls *model.App) bool {
for _, bypassed := range acls.IP.Bypass { for _, bypassed := range acls.IP.Bypass {
res, err := utils.FilterIP(bypassed, ip) res, err := utils.FilterIP(bypassed, ip)
if err != nil { if err != nil {
tlog.App.Warn().Err(err).Str("item", bypassed).Msg("Invalid IP/CIDR in bypass list") auth.log.App.Warn().Err(err).Str("item", bypassed).Msg("Invalid IP/CIDR in bypass list")
continue continue
} }
if res { if res {
tlog.App.Debug().Str("ip", ip).Str("item", bypassed).Msg("IP is in bypass list, allowing access") auth.log.App.Debug().Str("ip", ip).Str("item", bypassed).Msg("IP is in bypass list, skipping authentication")
return true return true
} }
} }
tlog.App.Debug().Str("ip", ip).Msg("IP not in bypass list, continuing with authentication") auth.log.App.Debug().Str("ip", ip).Msg("IP not in bypass list, proceeding with authentication")
return false return false
} }
@@ -723,21 +717,32 @@ func (auth *AuthService) EndOAuthSession(sessionId string) {
} }
func (auth *AuthService) CleanupOAuthSessionsRoutine() { func (auth *AuthService) CleanupOAuthSessionsRoutine() {
auth.log.App.Debug().Msg("Starting OAuth session cleanup routine")
ticker := time.NewTicker(30 * time.Minute) ticker := time.NewTicker(30 * time.Minute)
defer ticker.Stop() defer ticker.Stop()
for range ticker.C { for {
auth.oauthMutex.Lock() select {
case <-ticker.C:
auth.log.App.Debug().Msg("Running OAuth session cleanup")
now := time.Now() auth.oauthMutex.Lock()
for sessionId, session := range auth.oauthPendingSessions { now := time.Now()
if now.After(session.ExpiresAt) {
delete(auth.oauthPendingSessions, sessionId) for sessionId, session := range auth.oauthPendingSessions {
if now.After(session.ExpiresAt) {
delete(auth.oauthPendingSessions, sessionId)
}
} }
}
auth.oauthMutex.Unlock() auth.oauthMutex.Unlock()
auth.log.App.Debug().Msg("OAuth session cleanup completed")
case <-auth.context.Done():
auth.log.App.Debug().Msg("Stopping OAuth session cleanup routine")
return
}
} }
} }
@@ -806,11 +811,11 @@ func (auth *AuthService) lockdownMode() {
auth.loginMutex.Lock() auth.loginMutex.Lock()
tlog.App.Warn().Msg("Multiple login attempts detected, possibly DDOS attack. Activating temporary lockdown.") auth.log.App.Warn().Msg("Too many failed login attempts, entering lockdown mode")
auth.lockdown = &Lockdown{ auth.lockdown = &Lockdown{
Active: true, Active: true,
ActiveUntil: time.Now().Add(time.Duration(auth.config.LoginTimeout) * time.Second), ActiveUntil: time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second),
} }
// At this point all login attemps will also expire so, // At this point all login attemps will also expire so,
@@ -827,11 +832,14 @@ func (auth *AuthService) lockdownMode() {
// Timer expired, end lockdown // Timer expired, end lockdown
case <-ctx.Done(): case <-ctx.Done():
// Context cancelled, end lockdown // Context cancelled, end lockdown
case <-auth.context.Done():
// Service is shutting down, end lockdown
} }
auth.loginMutex.Lock() auth.loginMutex.Lock()
tlog.App.Info().Msg("Lockdown period ended, resuming normal operation") auth.log.App.Info().Msg("Exiting lockdown mode")
auth.lockdown = nil auth.lockdown = nil
auth.loginMutex.Unlock() auth.loginMutex.Unlock()
} }
+41 -25
View File
@@ -3,51 +3,56 @@ package service
import ( import (
"context" "context"
"strings" "strings"
"sync"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/decoders" "github.com/tinyauthapp/tinyauth/internal/utils/decoders"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
container "github.com/docker/docker/api/types/container" container "github.com/docker/docker/api/types/container"
"github.com/docker/docker/client" "github.com/docker/docker/client"
) )
type DockerService struct { type DockerService struct {
client *client.Client log *logger.Logger
context context.Context client *client.Client
context context.Context
isConnected bool isConnected bool
} }
func NewDockerService() *DockerService { func NewDockerService(
return &DockerService{} log *logger.Logger,
} ctx context.Context,
wg *sync.WaitGroup,
) (*DockerService, error) {
func (docker *DockerService) Init() error {
client, err := client.NewClientWithOpts(client.FromEnv) client, err := client.NewClientWithOpts(client.FromEnv)
if err != nil { if err != nil {
return err return nil, err
} }
ctx := context.Background()
client.NegotiateAPIVersion(ctx) client.NegotiateAPIVersion(ctx)
docker.client = client _, err = client.Ping(ctx)
docker.context = ctx
_, err = docker.client.Ping(docker.context)
if err != nil { if err != nil {
tlog.App.Debug().Err(err).Msg("Docker not connected") log.App.Debug().Err(err).Msg("Docker not connected")
docker.isConnected = false return nil, nil
docker.client = nil
docker.context = nil
return nil
} }
docker.isConnected = true service := &DockerService{
tlog.App.Debug().Msg("Docker connected") log: log,
client: client,
context: ctx,
}
return nil service.isConnected = true
service.log.App.Debug().Msg("Docker connected successfully")
wg.Go(service.watchAndClose)
return service, nil
} }
func (docker *DockerService) getContainers() ([]container.Summary, error) { func (docker *DockerService) getContainers() ([]container.Summary, error) {
@@ -60,7 +65,7 @@ func (docker *DockerService) inspectContainer(containerId string) (container.Ins
func (docker *DockerService) GetLabels(appDomain string) (*model.App, error) { func (docker *DockerService) GetLabels(appDomain string) (*model.App, error) {
if !docker.isConnected { if !docker.isConnected {
tlog.App.Debug().Msg("Docker not connected, returning empty labels") docker.log.App.Debug().Msg("Docker service not connected, returning empty labels")
return nil, nil return nil, nil
} }
@@ -82,17 +87,28 @@ func (docker *DockerService) GetLabels(appDomain string) (*model.App, error) {
for appName, appLabels := range labels.Apps { for appName, appLabels := range labels.Apps {
if appLabels.Config.Domain == appDomain { if appLabels.Config.Domain == appDomain {
tlog.App.Debug().Str("id", inspect.ID).Str("name", inspect.Name).Msg("Found matching container by domain") docker.log.App.Debug().Str("id", inspect.ID).Str("name", inspect.Name).Msg("Found matching container by domain")
return &appLabels, nil return &appLabels, nil
} }
if strings.SplitN(appDomain, ".", 2)[0] == appName { if strings.SplitN(appDomain, ".", 2)[0] == appName {
tlog.App.Debug().Str("id", inspect.ID).Str("name", inspect.Name).Msg("Found matching container by app name") docker.log.App.Debug().Str("id", inspect.ID).Str("name", inspect.Name).Msg("Found matching container by app name")
return &appLabels, nil return &appLabels, nil
} }
} }
} }
tlog.App.Debug().Msg("No matching container found, returning empty labels") docker.log.App.Debug().Str("domain", appDomain).Msg("No matching container found for domain")
return nil, nil return nil, nil
} }
func (docker *DockerService) watchAndClose() {
<-docker.context.Done()
docker.log.App.Debug().Msg("Closing Docker client")
if docker.client != nil {
err := docker.client.Close()
if err != nil {
docker.log.App.Error().Err(err).Msg("Error closing Docker client")
}
}
}
+64 -60
View File
@@ -9,7 +9,7 @@ import (
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/decoders" "github.com/tinyauthapp/tinyauth/internal/utils/decoders"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
@@ -36,9 +36,10 @@ type ingressApp struct {
} }
type KubernetesService struct { type KubernetesService struct {
log *logger.Logger
ctx context.Context
client dynamic.Interface client dynamic.Interface
ctx context.Context
cancel context.CancelFunc
started bool started bool
mu sync.RWMutex mu sync.RWMutex
ingressApps map[ingressKey][]ingressApp ingressApps map[ingressKey][]ingressApp
@@ -46,12 +47,55 @@ type KubernetesService struct {
appNameIndex map[string]ingressAppKey appNameIndex map[string]ingressAppKey
} }
func NewKubernetesService() *KubernetesService { func NewKubernetesService(
return &KubernetesService{ log *logger.Logger,
ctx context.Context,
wg *sync.WaitGroup,
) (*KubernetesService, error) {
cfg, err := rest.InClusterConfig()
if err != nil {
return nil, fmt.Errorf("failed to get in-cluster kubernetes config: %w", err)
}
client, err := dynamic.NewForConfig(cfg)
if err != nil {
return nil, fmt.Errorf("failed to create kubernetes client: %w", err)
}
gvr := schema.GroupVersionResource{
Group: "networking.k8s.io",
Version: "v1",
Resource: "ingresses",
}
accessCtx, accessCancel := context.WithTimeout(ctx, 5*time.Second)
defer accessCancel()
_, err = client.Resource(gvr).List(accessCtx, metav1.ListOptions{Limit: 1})
if err != nil {
log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to access Ingress API, Kubernetes label provider will be disabled")
return nil, fmt.Errorf("failed to access ingress api: %w", err)
}
log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Successfully accessed Ingress API, starting watcher")
service := &KubernetesService{
log: log,
ctx: ctx,
client: client,
ingressApps: make(map[ingressKey][]ingressApp), ingressApps: make(map[ingressKey][]ingressApp),
domainIndex: make(map[string]ingressAppKey), domainIndex: make(map[string]ingressAppKey),
appNameIndex: make(map[string]ingressAppKey), appNameIndex: make(map[string]ingressAppKey),
} }
wg.Go(func() {
service.watchGVR(gvr)
})
service.started = true
log.App.Debug().Msg("Kubernetes label provider started successfully")
return service, nil
} }
func (k *KubernetesService) addIngressApps(namespace, name string, apps []ingressApp) { func (k *KubernetesService) addIngressApps(namespace, name string, apps []ingressApp) {
@@ -133,7 +177,7 @@ func (k *KubernetesService) updateFromItem(item *unstructured.Unstructured) {
} }
labels, err := decoders.DecodeLabels[model.Apps](annotations, "apps") labels, err := decoders.DecodeLabels[model.Apps](annotations, "apps")
if err != nil { if err != nil {
tlog.App.Debug().Err(err).Msg("Failed to decode labels from annotations") k.log.App.Warn().Err(err).Str("namespace", namespace).Str("name", name).Msg("Failed to decode ingress labels, skipping")
k.removeIngress(namespace, name) k.removeIngress(namespace, name)
return return
} }
@@ -161,13 +205,13 @@ func (k *KubernetesService) resyncGVR(gvr schema.GroupVersionResource) error {
list, err := k.client.Resource(gvr).List(ctx, metav1.ListOptions{}) list, err := k.client.Resource(gvr).List(ctx, metav1.ListOptions{})
if err != nil { if err != nil {
tlog.App.Debug().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to list ingresses during resync") k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to list resources for resync")
return err return err
} }
for i := range list.Items { for i := range list.Items {
k.updateFromItem(&list.Items[i]) k.updateFromItem(&list.Items[i])
} }
tlog.App.Debug().Str("api", gvr.GroupVersion().String()).Int("count", len(list.Items)).Msg("Resynced ingress cache") k.log.App.Debug().Str("api", gvr.GroupVersion().String()).Int("count", len(list.Items)).Msg("Resync complete")
return nil return nil
} }
@@ -181,14 +225,14 @@ func (k *KubernetesService) runWatcher(gvr schema.GroupVersionResource, w watch.
return false return false
case event, ok := <-w.ResultChan(): case event, ok := <-w.ResultChan():
if !ok { if !ok {
tlog.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Watcher channel closed, restarting in 5 seconds") k.log.App.Warn().Str("api", gvr.GroupVersion().String()).Msg("Watcher channel closed, restarting watcher")
w.Stop() w.Stop()
time.Sleep(5 * time.Second) time.Sleep(5 * time.Second)
return true return true
} }
item, ok := event.Object.(*unstructured.Unstructured) item, ok := event.Object.(*unstructured.Unstructured)
if !ok { if !ok {
tlog.App.Warn().Str("api", gvr.GroupVersion().String()).Msg("Failed to cast watched object") k.log.App.Warn().Str("api", gvr.GroupVersion().String()).Msg("Received unexpected event object, skipping")
continue continue
} }
switch event.Type { switch event.Type {
@@ -199,7 +243,7 @@ func (k *KubernetesService) runWatcher(gvr schema.GroupVersionResource, w watch.
} }
case <-resyncTicker.C: case <-resyncTicker.C:
if err := k.resyncGVR(gvr); err != nil { if err := k.resyncGVR(gvr); err != nil {
tlog.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Periodic resync failed") k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Periodic resync failed during watcher run")
} }
} }
} }
@@ -210,29 +254,29 @@ func (k *KubernetesService) watchGVR(gvr schema.GroupVersionResource) {
defer resyncTicker.Stop() defer resyncTicker.Stop()
if err := k.resyncGVR(gvr); err != nil { if err := k.resyncGVR(gvr); err != nil {
tlog.App.Error().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Initial resync failed, retrying in 30 seconds") k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Initial resync failed, will retry")
time.Sleep(30 * time.Second) time.Sleep(30 * time.Second)
} }
for { for {
select { select {
case <-k.ctx.Done(): case <-k.ctx.Done():
tlog.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Stopping watcher") k.log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Shutting down kubernetes watcher")
return return
case <-resyncTicker.C: case <-resyncTicker.C:
if err := k.resyncGVR(gvr); err != nil { if err := k.resyncGVR(gvr); err != nil {
tlog.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Periodic resync failed") k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Periodic resync failed, will retry")
} }
default: default:
ctx, cancel := context.WithCancel(k.ctx) ctx, cancel := context.WithCancel(k.ctx)
watcher, err := k.client.Resource(gvr).Watch(ctx, metav1.ListOptions{}) watcher, err := k.client.Resource(gvr).Watch(ctx, metav1.ListOptions{})
if err != nil { if err != nil {
tlog.App.Error().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to start watcher") k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to start watcher, will retry")
cancel() cancel()
time.Sleep(10 * time.Second) time.Sleep(10 * time.Second)
continue continue
} }
tlog.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Watcher started") k.log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Watcher started successfully")
if !k.runWatcher(gvr, watcher, resyncTicker) { if !k.runWatcher(gvr, watcher, resyncTicker) {
cancel() cancel()
return return
@@ -242,65 +286,25 @@ func (k *KubernetesService) watchGVR(gvr schema.GroupVersionResource) {
} }
} }
func (k *KubernetesService) Init() error {
var cfg *rest.Config
var err error
cfg, err = rest.InClusterConfig()
if err != nil {
return fmt.Errorf("failed to get in-cluster Kubernetes config: %w", err)
}
client, err := dynamic.NewForConfig(cfg)
if err != nil {
return fmt.Errorf("failed to create Kubernetes client: %w", err)
}
k.client = client
k.ctx, k.cancel = context.WithCancel(context.Background())
gvr := schema.GroupVersionResource{
Group: "networking.k8s.io",
Version: "v1",
Resource: "ingresses",
}
accessCtx, accessCancel := context.WithTimeout(k.ctx, 5*time.Second)
defer accessCancel()
_, err = k.client.Resource(gvr).List(accessCtx, metav1.ListOptions{Limit: 1})
if err != nil {
tlog.App.Warn().Err(err).Msg("Insufficient permissions for networking.k8s.io/v1 Ingress, Kubernetes label provider will not work")
k.started = false
return nil
}
tlog.App.Debug().Msg("networking.k8s.io/v1 Ingress API accessible")
go k.watchGVR(gvr)
k.started = true
tlog.App.Info().Msg("Kubernetes label provider initialized")
return nil
}
func (k *KubernetesService) GetLabels(appDomain string) (*model.App, error) { func (k *KubernetesService) GetLabels(appDomain string) (*model.App, error) {
if !k.started { if !k.started {
tlog.App.Debug().Msg("Kubernetes not connected, returning empty labels") k.log.App.Debug().Str("domain", appDomain).Msg("Kubernetes label provider not started, skipping")
return nil, nil return nil, nil
} }
// First check cache // First check cache
app := k.getByDomain(appDomain) app := k.getByDomain(appDomain)
if app != nil { if app != nil {
tlog.App.Debug().Str("domain", appDomain).Msg("Found labels in cache by domain") k.log.App.Debug().Str("domain", appDomain).Msg("Found labels in cache by domain")
return app, nil return app, nil
} }
appName := strings.SplitN(appDomain, ".", 2)[0] appName := strings.SplitN(appDomain, ".", 2)[0]
app = k.getByAppName(appName) app = k.getByAppName(appName)
if app != nil { if app != nil {
tlog.App.Debug().Str("domain", appDomain).Str("appName", appName).Msg("Found labels in cache by app name") k.log.App.Debug().Str("domain", appDomain).Str("appName", appName).Msg("Found labels in cache by app name")
return app, nil return app, nil
} }
tlog.App.Debug().Str("domain", appDomain).Msg("Cache miss, no matching ingress found") k.log.App.Debug().Str("domain", appDomain).Msg("No labels found for domain")
return nil, nil return nil, nil
} }
@@ -8,9 +8,13 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
) )
func TestKubernetesService(t *testing.T) { func TestKubernetesService(t *testing.T) {
log := logger.NewLogger().WithTestConfig()
log.Init()
type testCase struct { type testCase struct {
description string description string
run func(t *testing.T, svc *KubernetesService) run func(t *testing.T, svc *KubernetesService)
@@ -179,6 +183,7 @@ func TestKubernetesService(t *testing.T) {
ingressApps: make(map[ingressKey][]ingressApp), ingressApps: make(map[ingressKey][]ingressApp),
domainIndex: make(map[string]ingressAppKey), domainIndex: make(map[string]ingressAppKey),
appNameIndex: make(map[string]ingressAppKey), appNameIndex: make(map[string]ingressAppKey),
log: log,
} }
test.run(t, svc) test.run(t, svc)
}) })
+63 -72
View File
@@ -9,69 +9,47 @@ import (
"github.com/cenkalti/backoff/v5" "github.com/cenkalti/backoff/v5"
ldapgo "github.com/go-ldap/ldap/v3" ldapgo "github.com/go-ldap/ldap/v3"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
) )
type LdapServiceConfig struct {
Address string
BindDN string
BindPassword string
BaseDN string
Insecure bool
SearchFilter string
AuthCert string
AuthKey string
}
type LdapService struct { type LdapService struct {
config LdapServiceConfig log *logger.Logger
conn *ldapgo.Conn config model.Config
mutex sync.RWMutex context context.Context
cert *tls.Certificate
isConfigured bool conn *ldapgo.Conn
mutex sync.RWMutex
cert *tls.Certificate
} }
func NewLdapService(config LdapServiceConfig) *LdapService { func NewLdapService(
return &LdapService{ log *logger.Logger,
config: config, config model.Config,
} ctx context.Context,
} wg *sync.WaitGroup,
) (*LdapService, error) {
func (ldap *LdapService) IsConfigured() bool { if config.LDAP.Address == "" {
return ldap.isConfigured return nil, nil
}
func (ldap *LdapService) Unconfigure() error {
if !ldap.isConfigured {
return nil
} }
if ldap.conn != nil { ldap := &LdapService{
if err := ldap.conn.Close(); err != nil { log: log,
return fmt.Errorf("failed to close LDAP connection: %w", err) config: config,
} context: ctx,
} }
ldap.isConfigured = false
return nil
}
func (ldap *LdapService) Init() error {
if ldap.config.Address == "" {
ldap.isConfigured = false
return nil
}
ldap.isConfigured = true
// Check whether authentication with client certificate is possible // Check whether authentication with client certificate is possible
if ldap.config.AuthCert != "" && ldap.config.AuthKey != "" { if config.LDAP.AuthCert != "" && config.LDAP.AuthKey != "" {
cert, err := tls.LoadX509KeyPair(ldap.config.AuthCert, ldap.config.AuthKey) cert, err := tls.LoadX509KeyPair(config.LDAP.AuthCert, config.LDAP.AuthKey)
if err != nil { if err != nil {
return fmt.Errorf("failed to initialize LDAP with mTLS authentication: %w", err) return nil, fmt.Errorf("failed to initialize LDAP with mTLS authentication: %w", err)
} }
log.App.Info().Msg("LDAP mTLS authentication configured successfully")
ldap.cert = &cert ldap.cert = &cert
tlog.App.Info().Msg("Using LDAP with mTLS authentication")
// TODO: Add optional extra CA certificates, instead of `InsecureSkipVerify` // TODO: Add optional extra CA certificates, instead of `InsecureSkipVerify`
/* /*
@@ -84,26 +62,39 @@ func (ldap *LdapService) Init() error {
} }
*/ */
} }
_, err := ldap.connect() _, err := ldap.connect()
if err != nil { if err != nil {
return fmt.Errorf("failed to connect to LDAP server: %w", err) return nil, fmt.Errorf("failed to connect to ldap server: %w", err)
} }
go func() { wg.Go(func() {
for range time.Tick(time.Duration(5) * time.Minute) { ldap.log.App.Debug().Msg("Starting LDAP connection heartbeat routine")
err := ldap.heartbeat()
if err != nil { ticker := time.NewTicker(5 * time.Minute)
tlog.App.Error().Err(err).Msg("LDAP connection heartbeat failed") defer ticker.Stop()
if reconnectErr := ldap.reconnect(); reconnectErr != nil {
tlog.App.Error().Err(reconnectErr).Msg("Failed to reconnect to LDAP server") for {
continue select {
case <-ticker.C:
err := ldap.heartbeat()
if err != nil {
ldap.log.App.Warn().Err(err).Msg("LDAP connection heartbeat failed, attempting to reconnect")
if reconnectErr := ldap.reconnect(); reconnectErr != nil {
ldap.log.App.Error().Err(reconnectErr).Msg("Failed to reconnect to LDAP server")
continue
}
ldap.log.App.Info().Msg("Successfully reconnected to LDAP server")
} }
tlog.App.Info().Msg("Successfully reconnected to LDAP server") case <-ldap.context.Done():
ldap.log.App.Debug().Msg("LDAP service context cancelled, stopping heartbeat")
return
} }
} }
}() })
return nil return ldap, nil
} }
func (ldap *LdapService) connect() (*ldapgo.Conn, error) { func (ldap *LdapService) connect() (*ldapgo.Conn, error) {
@@ -120,13 +111,13 @@ func (ldap *LdapService) connect() (*ldapgo.Conn, error) {
// 2. conn.StartTLS(tlsConfig) // 2. conn.StartTLS(tlsConfig)
// 3. conn.externalBind() // 3. conn.externalBind()
if ldap.cert != nil { if ldap.cert != nil {
conn, err = ldapgo.DialURL(ldap.config.Address, ldapgo.DialWithTLSConfig(&tls.Config{ conn, err = ldapgo.DialURL(ldap.config.LDAP.Address, ldapgo.DialWithTLSConfig(&tls.Config{
MinVersion: tls.VersionTLS12, MinVersion: tls.VersionTLS12,
Certificates: []tls.Certificate{*ldap.cert}, Certificates: []tls.Certificate{*ldap.cert},
})) }))
} else { } else {
conn, err = ldapgo.DialURL(ldap.config.Address, ldapgo.DialWithTLSConfig(&tls.Config{ conn, err = ldapgo.DialURL(ldap.config.LDAP.Address, ldapgo.DialWithTLSConfig(&tls.Config{
InsecureSkipVerify: ldap.config.Insecure, InsecureSkipVerify: ldap.config.LDAP.Insecure,
MinVersion: tls.VersionTLS12, MinVersion: tls.VersionTLS12,
})) }))
} }
@@ -146,10 +137,10 @@ func (ldap *LdapService) connect() (*ldapgo.Conn, error) {
func (ldap *LdapService) GetUserDN(username string) (string, error) { func (ldap *LdapService) GetUserDN(username string) (string, error) {
// Escape the username to prevent LDAP injection // Escape the username to prevent LDAP injection
escapedUsername := ldapgo.EscapeFilter(username) escapedUsername := ldapgo.EscapeFilter(username)
filter := fmt.Sprintf(ldap.config.SearchFilter, escapedUsername) filter := fmt.Sprintf(ldap.config.LDAP.SearchFilter, escapedUsername)
searchRequest := ldapgo.NewSearchRequest( searchRequest := ldapgo.NewSearchRequest(
ldap.config.BaseDN, ldap.config.LDAP.BaseDN,
ldapgo.ScopeWholeSubtree, ldapgo.NeverDerefAliases, 0, 0, false, ldapgo.ScopeWholeSubtree, ldapgo.NeverDerefAliases, 0, 0, false,
filter, filter,
[]string{"dn"}, []string{"dn"},
@@ -176,7 +167,7 @@ func (ldap *LdapService) GetUserGroups(userDN string) ([]string, error) {
escapedUserDN := ldapgo.EscapeFilter(userDN) escapedUserDN := ldapgo.EscapeFilter(userDN)
searchRequest := ldapgo.NewSearchRequest( searchRequest := ldapgo.NewSearchRequest(
ldap.config.BaseDN, ldap.config.LDAP.BaseDN,
ldapgo.ScopeWholeSubtree, ldapgo.NeverDerefAliases, 0, 0, false, ldapgo.ScopeWholeSubtree, ldapgo.NeverDerefAliases, 0, 0, false,
fmt.Sprintf("(&(objectclass=groupOfUniqueNames)(uniquemember=%s))", escapedUserDN), fmt.Sprintf("(&(objectclass=groupOfUniqueNames)(uniquemember=%s))", escapedUserDN),
[]string{"dn"}, []string{"dn"},
@@ -224,7 +215,7 @@ func (ldap *LdapService) BindService(rebind bool) error {
if ldap.cert != nil { if ldap.cert != nil {
return ldap.conn.ExternalBind() return ldap.conn.ExternalBind()
} }
return ldap.conn.Bind(ldap.config.BindDN, ldap.config.BindPassword) return ldap.conn.Bind(ldap.config.LDAP.BindDN, ldap.config.LDAP.BindPassword)
} }
func (ldap *LdapService) Bind(userDN string, password string) error { func (ldap *LdapService) Bind(userDN string, password string) error {
@@ -238,7 +229,7 @@ func (ldap *LdapService) Bind(userDN string, password string) error {
} }
func (ldap *LdapService) heartbeat() error { func (ldap *LdapService) heartbeat() error {
tlog.App.Debug().Msg("Performing LDAP connection heartbeat") ldap.log.App.Debug().Msg("Performing LDAP connection heartbeat")
searchRequest := ldapgo.NewSearchRequest( searchRequest := ldapgo.NewSearchRequest(
"", "",
@@ -260,7 +251,7 @@ func (ldap *LdapService) heartbeat() error {
} }
func (ldap *LdapService) reconnect() error { func (ldap *LdapService) reconnect() error {
tlog.App.Info().Msg("Reconnecting to LDAP server") ldap.log.App.Info().Msg("Attempting to reconnect to LDAP server")
exp := backoff.NewExponentialBackOff() exp := backoff.NewExponentialBackOff()
exp.InitialInterval = 500 * time.Millisecond exp.InitialInterval = 500 * time.Millisecond
@@ -269,7 +260,7 @@ func (ldap *LdapService) reconnect() error {
exp.Reset() exp.Reset()
operation := func() (*ldapgo.Conn, error) { operation := func() (*ldapgo.Conn, error) {
ldap.conn.Close() //nolint:errcheck ldap.conn.Close()
conn, err := ldap.connect() conn, err := ldap.connect()
if err != nil { if err != nil {
return nil, err return nil, err
+20 -12
View File
@@ -1,8 +1,10 @@
package service package service
import ( import (
"context"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
"slices" "slices"
@@ -19,33 +21,39 @@ type OAuthServiceImpl interface {
} }
type OAuthBrokerService struct { type OAuthBrokerService struct {
log *logger.Logger
services map[string]OAuthServiceImpl services map[string]OAuthServiceImpl
configs map[string]model.OAuthServiceConfig configs map[string]model.OAuthServiceConfig
} }
var presets = map[string]func(config model.OAuthServiceConfig) *OAuthService{ var presets = map[string]func(config model.OAuthServiceConfig, ctx context.Context) *OAuthService{
"github": newGitHubOAuthService, "github": newGitHubOAuthService,
"google": newGoogleOAuthService, "google": newGoogleOAuthService,
} }
func NewOAuthBrokerService(configs map[string]model.OAuthServiceConfig) *OAuthBrokerService { func NewOAuthBrokerService(
return &OAuthBrokerService{ log *logger.Logger,
configs map[string]model.OAuthServiceConfig,
ctx context.Context,
) *OAuthBrokerService {
service := &OAuthBrokerService{
log: log,
services: make(map[string]OAuthServiceImpl), services: make(map[string]OAuthServiceImpl),
configs: configs, configs: configs,
} }
}
func (broker *OAuthBrokerService) Init() error { for name, cfg := range configs {
for name, cfg := range broker.configs {
if presetFunc, exists := presets[name]; exists { if presetFunc, exists := presets[name]; exists {
broker.services[name] = presetFunc(cfg) service.services[name] = presetFunc(cfg, ctx)
tlog.App.Debug().Str("service", name).Msg("Loaded OAuth service from preset") service.log.App.Debug().Str("service", name).Msg("Loaded OAuth service from preset")
} else { } else {
broker.services[name] = NewOAuthService(cfg, name) service.services[name] = NewOAuthService(cfg, name, ctx)
tlog.App.Debug().Str("service", name).Msg("Loaded OAuth service from config") service.log.App.Debug().Str("service", name).Msg("Loaded OAuth service from custom config")
} }
} }
return nil
return service
} }
func (broker *OAuthBrokerService) GetConfiguredServices() []string { func (broker *OAuthBrokerService) GetConfiguredServices() []string {
+1 -1
View File
@@ -92,7 +92,7 @@ func simpleReq[T any](client *http.Client, url string, headers map[string]string
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer res.Body.Close() //nolint:errcheck defer res.Body.Close()
if res.StatusCode < 200 || res.StatusCode >= 300 { if res.StatusCode < 200 || res.StatusCode >= 300 {
return nil, fmt.Errorf("request failed with status: %s", res.Status) return nil, fmt.Errorf("request failed with status: %s", res.Status)
+6 -4
View File
@@ -1,23 +1,25 @@
package service package service
import ( import (
"context"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"golang.org/x/oauth2/endpoints" "golang.org/x/oauth2/endpoints"
) )
func newGoogleOAuthService(config model.OAuthServiceConfig) *OAuthService { func newGoogleOAuthService(config model.OAuthServiceConfig, ctx context.Context) *OAuthService {
scopes := []string{"openid", "email", "profile"} scopes := []string{"openid", "email", "profile"}
config.Scopes = scopes config.Scopes = scopes
config.AuthURL = endpoints.Google.AuthURL config.AuthURL = endpoints.Google.AuthURL
config.TokenURL = endpoints.Google.TokenURL config.TokenURL = endpoints.Google.TokenURL
config.UserinfoURL = "https://openidconnect.googleapis.com/v1/userinfo" config.UserinfoURL = "https://openidconnect.googleapis.com/v1/userinfo"
return NewOAuthService(config, "google") return NewOAuthService(config, "google", ctx)
} }
func newGitHubOAuthService(config model.OAuthServiceConfig) *OAuthService { func newGitHubOAuthService(config model.OAuthServiceConfig, ctx context.Context) *OAuthService {
scopes := []string{"read:user", "user:email"} scopes := []string{"read:user", "user:email"}
config.Scopes = scopes config.Scopes = scopes
config.AuthURL = endpoints.GitHub.AuthURL config.AuthURL = endpoints.GitHub.AuthURL
config.TokenURL = endpoints.GitHub.TokenURL config.TokenURL = endpoints.GitHub.TokenURL
return NewOAuthService(config, "github").WithUserinfoExtractor(githubExtractor) return NewOAuthService(config, "github", ctx).WithUserinfoExtractor(githubExtractor)
} }
+3 -4
View File
@@ -20,7 +20,7 @@ type OAuthService struct {
id string id string
} }
func NewOAuthService(config model.OAuthServiceConfig, id string) *OAuthService { func NewOAuthService(config model.OAuthServiceConfig, id string, ctx context.Context) *OAuthService {
httpClient := &http.Client{ httpClient := &http.Client{
Timeout: 30 * time.Second, Timeout: 30 * time.Second,
Transport: &http.Transport{ Transport: &http.Transport{
@@ -29,8 +29,7 @@ func NewOAuthService(config model.OAuthServiceConfig, id string) *OAuthService {
}, },
}, },
} }
ctx := context.Background() vctx := context.WithValue(ctx, oauth2.HTTPClient, httpClient)
ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient)
return &OAuthService{ return &OAuthService{
serviceCfg: config, serviceCfg: config,
@@ -44,7 +43,7 @@ func NewOAuthService(config model.OAuthServiceConfig, id string) *OAuthService {
TokenURL: config.TokenURL, TokenURL: config.TokenURL,
}, },
}, },
ctx: ctx, ctx: vctx,
userinfoExtractor: defaultExtractor, userinfoExtractor: defaultExtractor,
id: id, id: id,
} }
+134 -125
View File
@@ -7,7 +7,6 @@ import (
"crypto/rsa" "crypto/rsa"
"crypto/sha256" "crypto/sha256"
"crypto/x509" "crypto/x509"
"database/sql"
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
"encoding/pem" "encoding/pem"
@@ -16,6 +15,7 @@ import (
"net/url" "net/url"
"os" "os"
"strings" "strings"
"sync"
"time" "time"
"slices" "slices"
@@ -25,7 +25,7 @@ import (
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
) )
var ( var (
@@ -111,172 +111,173 @@ type AuthorizeRequest struct {
CodeChallengeMethod string `json:"code_challenge_method"` CodeChallengeMethod string `json:"code_challenge_method"`
} }
type OIDCServiceConfig struct {
Clients map[string]model.OIDCClientConfig
PrivateKeyPath string
PublicKeyPath string
Issuer string
SessionExpiry int
}
type OIDCService struct { type OIDCService struct {
config OIDCServiceConfig log *logger.Logger
queries *repository.Queries config model.Config
clients map[string]model.OIDCClientConfig runtime model.RuntimeConfig
privateKey *rsa.PrivateKey queries repository.Store
publicKey crypto.PublicKey context context.Context
issuer string
isConfigured bool clients map[string]model.OIDCClientConfig
privateKey *rsa.PrivateKey
publicKey crypto.PublicKey
issuer string
} }
func NewOIDCService(config OIDCServiceConfig, queries *repository.Queries) *OIDCService { func NewOIDCService(
return &OIDCService{ log *logger.Logger,
config: config, config model.Config,
queries: queries, runtime model.RuntimeConfig,
} queries repository.Store,
} ctx context.Context,
wg *sync.WaitGroup) (*OIDCService, error) {
func (service *OIDCService) IsConfigured() bool {
return service.isConfigured
}
func (service *OIDCService) Init() error {
// If not configured, skip init // If not configured, skip init
if len(service.config.Clients) == 0 { if len(runtime.OIDCClients) == 0 {
service.isConfigured = false return nil, nil
return nil
} }
service.isConfigured = true
// Ensure issuer is https // Ensure issuer is https
uissuer, err := url.Parse(service.config.Issuer) uissuer, err := url.Parse(runtime.AppURL)
if err != nil { if err != nil {
return err return nil, fmt.Errorf("failed to parse app url: %w", err)
} }
if uissuer.Scheme != "https" { if uissuer.Scheme != "https" {
return errors.New("issuer must be https") return nil, errors.New("issuer must be https")
} }
service.issuer = fmt.Sprintf("%s://%s", uissuer.Scheme, uissuer.Host) issuer := fmt.Sprintf("%s://%s", uissuer.Scheme, uissuer.Host)
// Create/load private and public keys // Create/load private and public keys
if strings.TrimSpace(service.config.PrivateKeyPath) == "" || if strings.TrimSpace(config.OIDC.PrivateKeyPath) == "" ||
strings.TrimSpace(service.config.PublicKeyPath) == "" { strings.TrimSpace(config.OIDC.PublicKeyPath) == "" {
return errors.New("private key path and public key path are required") return nil, errors.New("private key path and public key path are required")
} }
var privateKey *rsa.PrivateKey var privateKey *rsa.PrivateKey
fprivateKey, err := os.ReadFile(service.config.PrivateKeyPath) fprivateKey, err := os.ReadFile(config.OIDC.PrivateKeyPath)
if err != nil && !errors.Is(err, os.ErrNotExist) { if err != nil && !errors.Is(err, os.ErrNotExist) {
return err return nil, err
} }
if errors.Is(err, os.ErrNotExist) { if errors.Is(err, os.ErrNotExist) {
privateKey, err = rsa.GenerateKey(rand.Reader, 2048) privateKey, err = rsa.GenerateKey(rand.Reader, 2048)
if err != nil { if err != nil {
return err return nil, fmt.Errorf("failed to generate private key: %w", err)
} }
der := x509.MarshalPKCS1PrivateKey(privateKey) der := x509.MarshalPKCS1PrivateKey(privateKey)
if der == nil { if der == nil {
return errors.New("failed to marshal private key") return nil, errors.New("failed to marshal private key")
} }
encoded := pem.EncodeToMemory(&pem.Block{ encoded := pem.EncodeToMemory(&pem.Block{
Type: "RSA PRIVATE KEY", Type: "RSA PRIVATE KEY",
Bytes: der, Bytes: der,
}) })
tlog.App.Trace().Str("type", "RSA PRIVATE KEY").Msg("Generated private RSA key") log.App.Trace().Str("type", "RSA PRIVATE KEY").Msg("Generated private RSA key")
err = os.WriteFile(service.config.PrivateKeyPath, encoded, 0600) err = os.WriteFile(config.OIDC.PrivateKeyPath, encoded, 0600)
if err != nil { if err != nil {
return err return nil, fmt.Errorf("failed to write private key to file: %w", err)
} }
service.privateKey = privateKey
} else { } else {
block, _ := pem.Decode(fprivateKey) block, _ := pem.Decode(fprivateKey)
if block == nil { if block == nil {
return errors.New("failed to decode private key") return nil, errors.New("failed to decode private key")
} }
tlog.App.Trace().Str("type", block.Type).Msg("Loaded private key") log.App.Trace().Str("type", block.Type).Msg("Loaded private key")
privateKey, err = x509.ParsePKCS1PrivateKey(block.Bytes) privateKey, err = x509.ParsePKCS1PrivateKey(block.Bytes)
if err != nil { if err != nil {
return err return nil, fmt.Errorf("failed to parse private key: %w", err)
} }
service.privateKey = privateKey
} }
fpublicKey, err := os.ReadFile(service.config.PublicKeyPath) var publicKey crypto.PublicKey
fpublicKey, err := os.ReadFile(config.OIDC.PublicKeyPath)
if err != nil && !errors.Is(err, os.ErrNotExist) { if err != nil && !errors.Is(err, os.ErrNotExist) {
return err return nil, fmt.Errorf("failed to read public key: %w", err)
} }
if errors.Is(err, os.ErrNotExist) { if errors.Is(err, os.ErrNotExist) {
publicKey := service.privateKey.Public() publicKey = privateKey.Public()
der := x509.MarshalPKCS1PublicKey(publicKey.(*rsa.PublicKey)) der := x509.MarshalPKCS1PublicKey(publicKey.(*rsa.PublicKey))
if der == nil { if der == nil {
return errors.New("failed to marshal public key") return nil, errors.New("failed to marshal public key")
} }
encoded := pem.EncodeToMemory(&pem.Block{ encoded := pem.EncodeToMemory(&pem.Block{
Type: "RSA PUBLIC KEY", Type: "RSA PUBLIC KEY",
Bytes: der, Bytes: der,
}) })
tlog.App.Trace().Str("type", "RSA PUBLIC KEY").Msg("Generated public RSA key") log.App.Trace().Str("type", "RSA PUBLIC KEY").Msg("Generated public RSA key")
err = os.WriteFile(service.config.PublicKeyPath, encoded, 0644) err = os.WriteFile(config.OIDC.PublicKeyPath, encoded, 0644)
if err != nil { if err != nil {
return err return nil, err
} }
service.publicKey = publicKey
} else { } else {
block, _ := pem.Decode(fpublicKey) block, _ := pem.Decode(fpublicKey)
if block == nil { if block == nil {
return errors.New("failed to decode public key") return nil, errors.New("failed to decode public key")
} }
tlog.App.Trace().Str("type", block.Type).Msg("Loaded public key") log.App.Trace().Str("type", block.Type).Msg("Loaded public key")
switch block.Type { switch block.Type {
case "RSA PUBLIC KEY": case "RSA PUBLIC KEY":
publicKey, err := x509.ParsePKCS1PublicKey(block.Bytes) publicKey, err = x509.ParsePKCS1PublicKey(block.Bytes)
if err != nil { if err != nil {
return err return nil, fmt.Errorf("failed to parse public key: %w", err)
} }
service.publicKey = publicKey
case "PUBLIC KEY": case "PUBLIC KEY":
publicKey, err := x509.ParsePKIXPublicKey(block.Bytes) publicKey, err = x509.ParsePKIXPublicKey(block.Bytes)
if err != nil { if err != nil {
return err return nil, fmt.Errorf("failed to parse public key: %w", err)
} }
service.publicKey = publicKey.(crypto.PublicKey)
default: default:
return fmt.Errorf("unsupported public key type: %s", block.Type) return nil, fmt.Errorf("unsupported public key type: %s", block.Type)
} }
} }
// We will reorganize the client into a map with the client ID as the key // We will reorganize the client into a map with the client ID as the key
service.clients = make(map[string]model.OIDCClientConfig) clients := make(map[string]model.OIDCClientConfig)
for id, client := range service.config.Clients { for id, client := range config.OIDC.Clients {
client.ID = id client.ID = id
if client.Name == "" { if client.Name == "" {
client.Name = utils.Capitalize(client.ID) client.Name = utils.Capitalize(client.ID)
} }
service.clients[client.ClientID] = client clients[client.ClientID] = client
} }
// Load the client secrets from files if they exist // Load the client secrets from files if they exist
for id, client := range service.clients { for id, client := range clients {
secret := utils.GetSecret(client.ClientSecret, client.ClientSecretFile) secret := utils.GetSecret(client.ClientSecret, client.ClientSecretFile)
if secret != "" { if secret != "" {
client.ClientSecret = secret client.ClientSecret = secret
} }
client.ClientSecretFile = "" client.ClientSecretFile = ""
service.clients[id] = client clients[id] = client
tlog.App.Info().Str("id", client.ID).Msg("Registered OIDC client") log.App.Debug().Str("clientId", client.ClientID).Msg("Loaded OIDC client configuration")
} }
return nil // Initialize the service
service := &OIDCService{
log: log,
config: config,
runtime: runtime,
queries: queries,
context: ctx,
clients: clients,
privateKey: privateKey,
publicKey: publicKey,
issuer: issuer,
}
// Start cleanup routine
wg.Go(service.cleanupRoutine)
return service, nil
} }
func (service *OIDCService) GetIssuer() string { func (service *OIDCService) GetIssuer() string {
@@ -307,7 +308,7 @@ func (service *OIDCService) ValidateAuthorizeParams(req AuthorizeRequest) error
return errors.New("invalid_scope") return errors.New("invalid_scope")
} }
if !slices.Contains(SupportedScopes, scope) { if !slices.Contains(SupportedScopes, scope) {
tlog.App.Warn().Str("scope", scope).Msg("Unsupported OIDC scope, will be ignored") service.log.App.Warn().Str("scope", scope).Msg("Requested unsupported scope")
} }
} }
@@ -357,7 +358,7 @@ func (service *OIDCService) StoreCode(c *gin.Context, sub string, code string, r
entry.CodeChallenge = req.CodeChallenge entry.CodeChallenge = req.CodeChallenge
} else { } else {
entry.CodeChallenge = service.hashAndEncodePKCE(req.CodeChallenge) entry.CodeChallenge = service.hashAndEncodePKCE(req.CodeChallenge)
tlog.App.Warn().Msg("Received plain PKCE code challenge, it's recommended to use S256 for better security") service.log.App.Warn().Msg("Using plain PKCE code challenge method is not recommended, consider switching to S256 for better security")
} }
} }
@@ -422,7 +423,7 @@ func (service *OIDCService) GetCodeEntry(c *gin.Context, codeHash string, client
oidcCode, err := service.queries.GetOidcCode(c, codeHash) oidcCode, err := service.queries.GetOidcCode(c, codeHash)
if err != nil { if err != nil {
if errors.Is(err, sql.ErrNoRows) { if errors.Is(err, repository.ErrNotFound) {
return repository.OidcCode{}, ErrCodeNotFound return repository.OidcCode{}, ErrCodeNotFound
} }
return repository.OidcCode{}, err return repository.OidcCode{}, err
@@ -449,7 +450,7 @@ func (service *OIDCService) GetCodeEntry(c *gin.Context, codeHash string, client
func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user repository.OidcUserinfo, scope string, nonce string) (string, error) { func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user repository.OidcUserinfo, scope string, nonce string) (string, error) {
createdAt := time.Now().Unix() createdAt := time.Now().Unix()
expiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry) * time.Second).Unix() expiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry) * time.Second).Unix()
hasher := sha256.New() hasher := sha256.New()
@@ -529,16 +530,16 @@ func (service *OIDCService) GenerateAccessToken(c *gin.Context, client model.OID
accessToken := utils.GenerateString(32) accessToken := utils.GenerateString(32)
refreshToken := utils.GenerateString(32) refreshToken := utils.GenerateString(32)
tokenExpiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry) * time.Second).Unix() tokenExpiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry) * time.Second).Unix()
// Refresh token lives double the time of an access token but can't be used to access userinfo // Refresh token lives double the time of an access token but can't be used to access userinfo
refreshTokenExpiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry*2) * time.Second).Unix() refreshTokenExpiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry*2) * time.Second).Unix()
tokenResponse := TokenResponse{ tokenResponse := TokenResponse{
AccessToken: accessToken, AccessToken: accessToken,
RefreshToken: refreshToken, RefreshToken: refreshToken,
TokenType: "Bearer", TokenType: "Bearer",
ExpiresIn: int64(service.config.SessionExpiry), ExpiresIn: int64(service.config.Auth.SessionExpiry),
IDToken: idToken, IDToken: idToken,
Scope: strings.ReplaceAll(codeEntry.Scope, ",", " "), Scope: strings.ReplaceAll(codeEntry.Scope, ",", " "),
} }
@@ -566,7 +567,7 @@ func (service *OIDCService) RefreshAccessToken(c *gin.Context, refreshToken stri
entry, err := service.queries.GetOidcTokenByRefreshToken(c, service.Hash(refreshToken)) entry, err := service.queries.GetOidcTokenByRefreshToken(c, service.Hash(refreshToken))
if err != nil { if err != nil {
if errors.Is(err, sql.ErrNoRows) { if errors.Is(err, repository.ErrNotFound) {
return TokenResponse{}, ErrTokenNotFound return TokenResponse{}, ErrTokenNotFound
} }
return TokenResponse{}, err return TokenResponse{}, err
@@ -598,14 +599,14 @@ func (service *OIDCService) RefreshAccessToken(c *gin.Context, refreshToken stri
accessToken := utils.GenerateString(32) accessToken := utils.GenerateString(32)
newRefreshToken := utils.GenerateString(32) newRefreshToken := utils.GenerateString(32)
tokenExpiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry) * time.Second).Unix() tokenExpiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry) * time.Second).Unix()
refreshTokenExpiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry*2) * time.Second).Unix() refreshTokenExpiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry*2) * time.Second).Unix()
tokenResponse := TokenResponse{ tokenResponse := TokenResponse{
AccessToken: accessToken, AccessToken: accessToken,
RefreshToken: newRefreshToken, RefreshToken: newRefreshToken,
TokenType: "Bearer", TokenType: "Bearer",
ExpiresIn: int64(service.config.SessionExpiry), ExpiresIn: int64(service.config.Auth.SessionExpiry),
IDToken: idToken, IDToken: idToken,
Scope: strings.ReplaceAll(entry.Scope, ",", " "), Scope: strings.ReplaceAll(entry.Scope, ",", " "),
} }
@@ -645,7 +646,7 @@ func (service *OIDCService) GetAccessToken(c *gin.Context, tokenHash string) (re
entry, err := service.queries.GetOidcToken(c, tokenHash) entry, err := service.queries.GetOidcToken(c, tokenHash)
if err != nil { if err != nil {
if errors.Is(err, sql.ErrNoRows) { if errors.Is(err, repository.ErrNotFound) {
return repository.OidcToken{}, ErrTokenNotFound return repository.OidcToken{}, ErrTokenNotFound
} }
return repository.OidcToken{}, err return repository.OidcToken{}, err
@@ -733,71 +734,79 @@ func (service *OIDCService) Hash(token string) string {
func (service *OIDCService) DeleteOldSession(ctx context.Context, sub string) error { func (service *OIDCService) DeleteOldSession(ctx context.Context, sub string) error {
err := service.queries.DeleteOidcCodeBySub(ctx, sub) err := service.queries.DeleteOidcCodeBySub(ctx, sub)
if err != nil && !errors.Is(err, sql.ErrNoRows) { if err != nil && !errors.Is(err, repository.ErrNotFound) {
return err return err
} }
err = service.queries.DeleteOidcTokenBySub(ctx, sub) err = service.queries.DeleteOidcTokenBySub(ctx, sub)
if err != nil && !errors.Is(err, sql.ErrNoRows) { if err != nil && !errors.Is(err, repository.ErrNotFound) {
return err return err
} }
err = service.queries.DeleteOidcUserInfo(ctx, sub) err = service.queries.DeleteOidcUserInfo(ctx, sub)
if err != nil && !errors.Is(err, sql.ErrNoRows) { if err != nil && !errors.Is(err, repository.ErrNotFound) {
return err return err
} }
return nil return nil
} }
// Cleanup routine - Resource heavy due to the linked tables // Cleanup routine - Resource heavy due to the linked tables
func (service *OIDCService) Cleanup() { func (service *OIDCService) cleanupRoutine() {
// We need a context for the routine service.log.App.Debug().Msg("Starting OIDC cleanup routine")
ctx := context.Background()
ticker := time.NewTicker(time.Duration(30) * time.Minute) ticker := time.NewTicker(time.Duration(30) * time.Minute)
defer ticker.Stop() defer ticker.Stop()
for range ticker.C { for {
currentTime := time.Now().Unix() select {
case <-ticker.C:
service.log.App.Debug().Msg("Performing OIDC cleanup routine")
// For the OIDC tokens, if they are expired we delete the userinfo and codes currentTime := time.Now().Unix()
expiredTokens, err := service.queries.DeleteExpiredOidcTokens(ctx, repository.DeleteExpiredOidcTokensParams{
TokenExpiresAt: currentTime,
RefreshTokenExpiresAt: currentTime,
})
if err != nil { // For the OIDC tokens, if they are expired we delete the userinfo and codes
tlog.App.Warn().Err(err).Msg("Failed to delete expired tokens") expiredTokens, err := service.queries.DeleteExpiredOidcTokens(service.context, repository.DeleteExpiredOidcTokensParams{
} TokenExpiresAt: currentTime,
RefreshTokenExpiresAt: currentTime,
})
for _, expiredToken := range expiredTokens {
err := service.DeleteOldSession(ctx, expiredToken.Sub)
if err != nil { if err != nil {
tlog.App.Warn().Err(err).Msg("Failed to delete old session") service.log.App.Warn().Err(err).Msg("Failed to delete expired tokens")
} }
}
// For expired codes, we need to get the sub, check if tokens are expired and if they are remove everything for _, expiredToken := range expiredTokens {
expiredCodes, err := service.queries.DeleteExpiredOidcCodes(ctx, currentTime) err := service.DeleteOldSession(service.context, expiredToken.Sub)
if err != nil {
service.log.App.Warn().Err(err).Msg("Failed to delete session for expired token")
}
}
if err != nil { // For expired codes, we need to get the sub, check if tokens are expired and if they are remove everything
tlog.App.Warn().Err(err).Msg("Failed to delete expired codes") expiredCodes, err := service.queries.DeleteExpiredOidcCodes(service.context, currentTime)
}
for _, expiredCode := range expiredCodes {
token, err := service.queries.GetOidcTokenBySub(ctx, expiredCode.Sub)
if err != nil { if err != nil {
if errors.Is(err, sql.ErrNoRows) { service.log.App.Warn().Err(err).Msg("Failed to delete expired codes")
}
for _, expiredCode := range expiredCodes {
token, err := service.queries.GetOidcTokenBySub(service.context, expiredCode.Sub)
if err != nil {
if !errors.Is(err, repository.ErrNotFound) {
service.log.App.Warn().Err(err).Msg("Failed to get token by sub for expired code")
}
continue continue
} }
tlog.App.Warn().Err(err).Msg("Failed to get OIDC token by sub")
}
if token.TokenExpiresAt < currentTime && token.RefreshTokenExpiresAt < currentTime { if token.TokenExpiresAt < currentTime && token.RefreshTokenExpiresAt < currentTime {
err := service.DeleteOldSession(ctx, expiredCode.Sub) err := service.DeleteOldSession(service.context, expiredCode.Sub)
if err != nil { if err != nil {
tlog.App.Warn().Err(err).Msg("Failed to delete session") service.log.App.Warn().Err(err).Msg("Failed to delete session for expired code")
}
} }
} }
service.log.App.Debug().Msg("Finished OIDC cleanup routine")
case <-service.context.Done():
service.log.App.Debug().Msg("Stopping OIDC cleanup routine")
return
} }
} }
} }
+26 -7
View File
@@ -1,7 +1,9 @@
package service_test package service_test
import ( import (
"context"
"encoding/json" "encoding/json"
"sync"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@@ -10,6 +12,7 @@ import (
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
) )
func newTestUser() repository.OidcUserinfo { func newTestUser() repository.OidcUserinfo {
@@ -48,13 +51,29 @@ func newTestUser() repository.OidcUserinfo {
func TestCompileUserinfo(t *testing.T) { func TestCompileUserinfo(t *testing.T) {
dir := t.TempDir() dir := t.TempDir()
svc := service.NewOIDCService(service.OIDCServiceConfig{
PrivateKeyPath: dir + "/key.pem", cfg := model.Config{
PublicKeyPath: dir + "/key.pub", OIDC: model.OIDCConfig{
Issuer: "https://tinyauth.example.com", PrivateKeyPath: dir + "/key.pem",
SessionExpiry: 3600, PublicKeyPath: dir + "/key.pub",
}, nil) },
require.NoError(t, svc.Init()) Auth: model.AuthConfig{
SessionExpiry: 3600,
},
}
runtime := model.RuntimeConfig{
AppURL: "https://tinyauth.example.com",
}
log := logger.NewLogger().WithTestConfig()
log.Init()
ctx := context.TODO()
wg := &sync.WaitGroup{}
svc, err := service.NewOIDCService(log, cfg, runtime, nil, ctx, wg)
require.NoError(t, err)
type testCase struct { type testCase struct {
description string description string
+106
View File
@@ -0,0 +1,106 @@
package test
import (
"path/filepath"
"testing"
"github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/model"
"golang.org/x/crypto/bcrypt"
)
var TestingTOTPSecret = "JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK"
func CreateTestConfigs(t *testing.T) (model.Config, model.RuntimeConfig) {
tempDir := t.TempDir()
config := model.Config{
UI: model.UIConfig{
Title: "Tinyauth Test",
ForgotPasswordMessage: "foo",
BackgroundImage: "/background.jpg",
WarningsEnabled: true,
},
OAuth: model.OAuthConfig{
AutoRedirect: "none",
},
OIDC: model.OIDCConfig{
Clients: map[string]model.OIDCClientConfig{
"test": {
ClientID: "some-client-id",
ClientSecret: "some-client-secret",
TrustedRedirectURIs: []string{"https://test.example.com/callback"},
Name: "Test Client",
},
},
PrivateKeyPath: filepath.Join(tempDir, "key.pem"),
PublicKeyPath: filepath.Join(tempDir, "key.pub"),
},
Auth: model.AuthConfig{
SessionExpiry: 10,
LoginTimeout: 10,
LoginMaxRetries: 3,
},
Database: model.DatabaseConfig{
Path: filepath.Join(tempDir, "test.db"),
},
Resources: model.ResourcesConfig{
Enabled: true,
Path: filepath.Join(tempDir, "resources"),
},
}
passwd, err := bcrypt.GenerateFromPassword([]byte("password"), bcrypt.DefaultCost)
require.NoError(t, err)
runtime := model.RuntimeConfig{
ConfiguredProviders: []model.Provider{
{
Name: "Local",
ID: "local",
OAuth: false,
},
},
LocalUsers: []model.LocalUser{
{
Username: "testuser",
Password: string(passwd),
},
{
Username: "totpuser",
Password: string(passwd),
TOTPSecret: TestingTOTPSecret,
},
{
Username: "attruser",
Password: string(passwd),
Attributes: model.UserAttributes{
Name: "Alice Smith",
Email: "alice@example.com",
},
},
{
Username: "attrtotpuser",
Password: string(passwd),
TOTPSecret: TestingTOTPSecret,
Attributes: model.UserAttributes{
Name: "Bob Jones",
Email: "bob@example.com",
},
},
},
CookieDomain: "example.com",
AppURL: "https://tinyauth.example.com",
SessionCookieName: "tinyauth-session",
OIDCClients: func() []model.OIDCClientConfig {
var clients []model.OIDCClientConfig
for id, client := range config.OIDC.Clients {
client.ID = id
clients = append(clients, client)
}
return clients
}(),
}
return config, runtime
}
-3
View File
@@ -7,8 +7,6 @@ import (
"net/url" "net/url"
"strings" "strings"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/weppos/publicsuffix-go/publicsuffix" "github.com/weppos/publicsuffix-go/publicsuffix"
) )
@@ -28,7 +26,6 @@ func GetCookieDomain(u string) (string, error) {
parts := strings.Split(host, ".") parts := strings.Split(host, ".")
if len(parts) == 2 { if len(parts) == 2 {
tlog.App.Warn().Msgf("Running on the root domain, cookies will be set for .%v", host)
return host, nil return host, nil
} }
+1 -1
View File
@@ -18,7 +18,7 @@ func TestReadFile(t *testing.T) {
err = file.Close() err = file.Close()
require.NoError(t, err) require.NoError(t, err)
defer os.Remove("/tmp/tinyauth_test_file") //nolint:errcheck defer os.Remove("/tmp/tinyauth_test_file")
// Normal case // Normal case
content, err := ReadFile("/tmp/tinyauth_test_file") content, err := ReadFile("/tmp/tinyauth_test_file")
+160
View File
@@ -0,0 +1,160 @@
package logger
import (
"io"
"os"
"strings"
"time"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"github.com/tinyauthapp/tinyauth/internal/model"
)
type Logger struct {
HTTP zerolog.Logger
App zerolog.Logger
config model.LogConfig
base zerolog.Logger
audit zerolog.Logger
writer io.Writer
}
func NewLogger() *Logger {
return &Logger{
writer: os.Stderr,
config: model.LogConfig{
Level: "error",
Json: true,
Streams: model.LogStreams{
HTTP: model.LogStreamConfig{
Enabled: true,
},
App: model.LogStreamConfig{
Enabled: true,
},
// No reason to enable audit by default since it will be suppressed by the log level
},
},
}
}
func (l *Logger) WithConfig(cfg model.LogConfig) *Logger {
l.config = cfg
return l
}
func (l *Logger) WithSimpleConfig() *Logger {
l.config = model.LogConfig{
Level: "info",
Json: false,
Streams: model.LogStreams{
HTTP: model.LogStreamConfig{Enabled: true},
App: model.LogStreamConfig{Enabled: true},
Audit: model.LogStreamConfig{Enabled: false},
},
}
return l
}
func (l *Logger) WithTestConfig() *Logger {
l.config = model.LogConfig{
Level: "trace",
Streams: model.LogStreams{
HTTP: model.LogStreamConfig{Enabled: true},
App: model.LogStreamConfig{Enabled: true},
Audit: model.LogStreamConfig{Enabled: true},
},
}
return l
}
func (l *Logger) WithWriter(writer io.Writer) *Logger {
l.writer = writer
return l
}
func (l *Logger) Init() {
base := log.With().
Timestamp().
Logger().
Level(l.parseLogLevel(l.config.Level)).Output(l.writer)
if !l.config.Json {
base = base.Output(zerolog.ConsoleWriter{
Out: l.writer,
TimeFormat: time.RFC3339,
})
}
if base.GetLevel() == zerolog.TraceLevel || base.GetLevel() == zerolog.DebugLevel {
base = base.With().Caller().Logger()
}
l.base = base
l.audit = l.createLogger("audit", l.config.Streams.Audit)
l.HTTP = l.createLogger("http", l.config.Streams.HTTP)
l.App = l.createLogger("app", l.config.Streams.App)
}
func (l *Logger) parseLogLevel(level string) zerolog.Level {
if level == "" {
return zerolog.InfoLevel
}
parsed, err := zerolog.ParseLevel(strings.ToLower(level))
if err != nil {
log.Warn().Err(err).Str("level", level).Msg("Invalid log level, defaulting to error")
parsed = zerolog.ErrorLevel
}
return parsed
}
func (l *Logger) createLogger(component string, cfg model.LogStreamConfig) zerolog.Logger {
if !cfg.Enabled {
return zerolog.Nop()
}
sub := l.base.With().Str("stream", component).Logger()
if cfg.Level != "" {
sub = sub.Level(l.parseLogLevel(cfg.Level))
}
return sub
}
func (l *Logger) AuditLoginSuccess(username, provider, ip string) {
l.audit.Info().
CallerSkipFrame(1).
Str("event", "login").
Str("result", "success").
Str("username", username).
Str("provider", provider).
Str("ip", ip).
Send()
}
func (l *Logger) AuditLoginFailure(username, provider, ip, reason string) {
l.audit.Warn().
CallerSkipFrame(1).
Str("event", "login").
Str("result", "failure").
Str("username", username).
Str("provider", provider).
Str("ip", ip).
Str("reason", reason).
Send()
}
func (l *Logger) AuditLogout(username, provider, ip string) {
l.audit.Info().
CallerSkipFrame(1).
Str("event", "logout").
Str("result", "success").
Str("username", username).
Str("provider", provider).
Str("ip", ip).
Send()
}
// Used for testing
func (l *Logger) GetConfig() model.LogConfig {
return l.config
}
+173
View File
@@ -0,0 +1,173 @@
package logger_test
import (
"bytes"
"encoding/json"
"testing"
"github.com/rs/zerolog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
)
func TestLogger(t *testing.T) {
type testCase struct {
description string
run func(t *testing.T)
}
tests := []testCase{
{
description: "Should create a simple logger with the expected config",
run: func(t *testing.T) {
l := logger.NewLogger().WithSimpleConfig()
l.Init()
cfg := l.GetConfig()
assert.Equal(t, cfg, model.LogConfig{
Level: "info",
Json: false,
Streams: model.LogStreams{
HTTP: model.LogStreamConfig{Enabled: true},
App: model.LogStreamConfig{Enabled: true},
Audit: model.LogStreamConfig{Enabled: false},
},
})
},
},
{
description: "Should create a test logger with the expected config",
run: func(t *testing.T) {
l := logger.NewLogger().WithTestConfig()
l.Init()
cfg := l.GetConfig()
assert.Equal(t, cfg, model.LogConfig{
Level: "trace",
Json: false,
Streams: model.LogStreams{
HTTP: model.LogStreamConfig{Enabled: true},
App: model.LogStreamConfig{Enabled: true},
Audit: model.LogStreamConfig{Enabled: true},
},
})
},
},
{
description: "Should create a logger with a custom config",
run: func(t *testing.T) {
customCfg := model.LogConfig{
Level: "debug",
Json: true,
Streams: model.LogStreams{
HTTP: model.LogStreamConfig{Enabled: false},
App: model.LogStreamConfig{Enabled: true},
Audit: model.LogStreamConfig{Enabled: false},
},
}
l := logger.NewLogger().WithConfig(customCfg)
l.Init()
cfg := l.GetConfig()
assert.Equal(t, cfg, customCfg)
},
},
{
description: "Default logger should use error type and log json",
run: func(t *testing.T) {
buf := bytes.Buffer{}
l := logger.NewLogger().WithWriter(&buf)
l.Init()
cfg := l.GetConfig()
assert.Equal(t, cfg, model.LogConfig{
Level: "error",
Json: true,
Streams: model.LogStreams{
HTTP: model.LogStreamConfig{Enabled: true},
App: model.LogStreamConfig{Enabled: true},
Audit: model.LogStreamConfig{Enabled: false},
},
})
l.App.Error().Msg("test")
var entry map[string]any
err := json.Unmarshal(buf.Bytes(), &entry)
require.NoError(t, err)
assert.Equal(t, "test", entry["message"])
assert.Equal(t, "app", entry["stream"])
assert.Equal(t, "error", entry["level"])
assert.NotEmpty(t, entry["time"])
},
},
{
description: "Should default to error level if an invalid level is provided",
run: func(t *testing.T) {
buf := bytes.Buffer{}
customCfg := model.LogConfig{
Level: "invalid",
Json: false,
Streams: model.LogStreams{
HTTP: model.LogStreamConfig{Enabled: true},
App: model.LogStreamConfig{Enabled: true},
Audit: model.LogStreamConfig{Enabled: false},
},
}
l := logger.NewLogger().WithConfig(customCfg).WithWriter(&buf)
l.Init()
assert.Equal(t, zerolog.ErrorLevel, l.App.GetLevel())
assert.Equal(t, zerolog.ErrorLevel, l.HTTP.GetLevel())
// should not get logged
l.AuditLoginFailure("test", "test", "test", "test")
assert.Empty(t, buf.String())
},
},
{
description: "Should use nop logger for disabled streams",
run: func(t *testing.T) {
buf := bytes.Buffer{}
customCfg := model.LogConfig{
Level: "info",
Json: false,
Streams: model.LogStreams{
HTTP: model.LogStreamConfig{Enabled: false},
App: model.LogStreamConfig{Enabled: true},
Audit: model.LogStreamConfig{Enabled: false},
},
}
l := logger.NewLogger().WithConfig(customCfg).WithWriter(&buf)
l.Init()
assert.Equal(t, zerolog.Disabled, l.HTTP.GetLevel())
l.App.Info().Msg("test")
l.AuditLoginFailure("test_nop", "test_nop", "test_nop", "test_nop")
assert.NotEmpty(t, buf.String())
assert.NotContains(t, buf.String(), "test_nop")
},
},
}
for _, test := range tests {
t.Run(test.description, test.run)
}
}
+1 -1
View File
@@ -53,7 +53,7 @@ func FilterIP(filter string, ip string) (bool, error) {
return false, errors.New("invalid IP address") return false, errors.New("invalid IP address")
} }
filter = strings.ReplaceAll(filter, "-", "/") filter = strings.Replace(filter, "-", "/", -1)
if strings.Contains(filter, "/") { if strings.Contains(filter, "/") {
_, cidr, err := net.ParseCIDR(filter) _, cidr, err := net.ParseCIDR(filter)
+1 -1
View File
@@ -19,7 +19,7 @@ func TestGetSecret(t *testing.T) {
err = file.Close() err = file.Close()
require.NoError(t, err) require.NoError(t, err)
defer os.Remove("/tmp/tinyauth_test_secret") //nolint:errcheck defer os.Remove("/tmp/tinyauth_test_secret")
// Get from config // Get from config
assert.Equal(t, "mysecret", utils.GetSecret("mysecret", "")) assert.Equal(t, "mysecret", utils.GetSecret("mysecret", ""))
+1 -1
View File
@@ -73,7 +73,7 @@ func TestGetStringList(t *testing.T) {
err = file.Close() err = file.Close()
assert.NoError(t, err) assert.NoError(t, err)
defer os.Remove("/tmp/tinyauth_list_test_file") //nolint:errcheck defer os.Remove("/tmp/tinyauth_list_test_file")
values, err := utils.GetStringList([]string{" first@example.com ", "", "second@example.com"}, "/tmp/tinyauth_list_test_file") values, err := utils.GetStringList([]string{" first@example.com ", "", "second@example.com"}, "/tmp/tinyauth_list_test_file")
assert.NoError(t, err) assert.NoError(t, err)
-39
View File
@@ -1,39 +0,0 @@
package tlog
import "github.com/gin-gonic/gin"
// functions here use CallerSkipFrame to ensure correct caller info is logged
func AuditLoginSuccess(c *gin.Context, username, provider string) {
Audit.Info().
CallerSkipFrame(1).
Str("event", "login").
Str("result", "success").
Str("username", username).
Str("provider", provider).
Str("ip", c.ClientIP()).
Send()
}
func AuditLoginFailure(c *gin.Context, username, provider string, reason string) {
Audit.Warn().
CallerSkipFrame(1).
Str("event", "login").
Str("result", "failure").
Str("username", username).
Str("provider", provider).
Str("ip", c.ClientIP()).
Str("reason", reason).
Send()
}
func AuditLogout(c *gin.Context, username, provider string) {
Audit.Info().
CallerSkipFrame(1).
Str("event", "logout").
Str("result", "success").
Str("username", username).
Str("provider", provider).
Str("ip", c.ClientIP()).
Send()
}
-97
View File
@@ -1,97 +0,0 @@
package tlog
import (
"os"
"strings"
"time"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"github.com/tinyauthapp/tinyauth/internal/model"
)
type Logger struct {
Audit zerolog.Logger
HTTP zerolog.Logger
App zerolog.Logger
}
var (
Audit zerolog.Logger
HTTP zerolog.Logger
App zerolog.Logger
)
func NewLogger(cfg model.LogConfig) *Logger {
baseLogger := log.With().
Timestamp().
Caller().
Logger().
Level(parseLogLevel(cfg.Level))
if !cfg.Json {
baseLogger = baseLogger.Output(zerolog.ConsoleWriter{
Out: os.Stderr,
TimeFormat: time.RFC3339,
})
}
return &Logger{
Audit: createLogger("audit", cfg.Streams.Audit, baseLogger),
HTTP: createLogger("http", cfg.Streams.HTTP, baseLogger),
App: createLogger("app", cfg.Streams.App, baseLogger),
}
}
func NewSimpleLogger() *Logger {
return NewLogger(model.LogConfig{
Level: "info",
Json: false,
Streams: model.LogStreams{
HTTP: model.LogStreamConfig{Enabled: true},
App: model.LogStreamConfig{Enabled: true},
Audit: model.LogStreamConfig{Enabled: false},
},
})
}
func NewTestLogger() *Logger {
return NewLogger(model.LogConfig{
Level: "trace",
Streams: model.LogStreams{
HTTP: model.LogStreamConfig{Enabled: true},
App: model.LogStreamConfig{Enabled: true},
Audit: model.LogStreamConfig{Enabled: true},
},
})
}
func (l *Logger) Init() {
Audit = l.Audit
HTTP = l.HTTP
App = l.App
}
func createLogger(component string, streamCfg model.LogStreamConfig, baseLogger zerolog.Logger) zerolog.Logger {
if !streamCfg.Enabled {
return zerolog.Nop()
}
subLogger := baseLogger.With().Str("log_stream", component).Logger()
// override level if specified, otherwise use base level
if streamCfg.Level != "" {
subLogger = subLogger.Level(parseLogLevel(streamCfg.Level))
}
return subLogger
}
func parseLogLevel(level string) zerolog.Level {
if level == "" {
return zerolog.InfoLevel
}
parsedLevel, err := zerolog.ParseLevel(strings.ToLower(level))
if err != nil {
log.Warn().Err(err).Str("level", level).Msg("Invalid log level, defaulting to info")
parsedLevel = zerolog.InfoLevel
}
return parsedLevel
}
-93
View File
@@ -1,93 +0,0 @@
package tlog_test
import (
"bytes"
"encoding/json"
"testing"
"github.com/stretchr/testify/assert"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/rs/zerolog"
)
func TestNewLogger(t *testing.T) {
cfg := model.LogConfig{
Level: "debug",
Json: true,
Streams: model.LogStreams{
HTTP: model.LogStreamConfig{Enabled: true, Level: "info"},
App: model.LogStreamConfig{Enabled: true, Level: ""},
Audit: model.LogStreamConfig{Enabled: false, Level: ""},
},
}
logger := tlog.NewLogger(cfg)
assert.NotNil(t, logger)
assert.Equal(t, zerolog.InfoLevel, logger.HTTP.GetLevel())
assert.Equal(t, zerolog.DebugLevel, logger.App.GetLevel())
assert.Equal(t, zerolog.Disabled, logger.Audit.GetLevel())
}
func TestNewSimpleLogger(t *testing.T) {
logger := tlog.NewSimpleLogger()
assert.NotNil(t, logger)
assert.Equal(t, zerolog.InfoLevel, logger.HTTP.GetLevel())
assert.Equal(t, zerolog.InfoLevel, logger.App.GetLevel())
assert.Equal(t, zerolog.Disabled, logger.Audit.GetLevel())
}
func TestLoggerInit(t *testing.T) {
logger := tlog.NewSimpleLogger()
logger.Init()
assert.NotEqual(t, zerolog.Disabled, tlog.App.GetLevel())
}
func TestLoggerWithDisabledStreams(t *testing.T) {
cfg := model.LogConfig{
Level: "info",
Json: false,
Streams: model.LogStreams{
HTTP: model.LogStreamConfig{Enabled: false},
App: model.LogStreamConfig{Enabled: false},
Audit: model.LogStreamConfig{Enabled: false},
},
}
logger := tlog.NewLogger(cfg)
assert.Equal(t, zerolog.Disabled, logger.HTTP.GetLevel())
assert.Equal(t, zerolog.Disabled, logger.App.GetLevel())
assert.Equal(t, zerolog.Disabled, logger.Audit.GetLevel())
}
func TestLogStreamField(t *testing.T) {
var buf bytes.Buffer
cfg := model.LogConfig{
Level: "info",
Json: true,
Streams: model.LogStreams{
HTTP: model.LogStreamConfig{Enabled: true},
App: model.LogStreamConfig{Enabled: true},
Audit: model.LogStreamConfig{Enabled: true},
},
}
logger := tlog.NewLogger(cfg)
// Override output for HTTP logger to capture output
logger.HTTP = logger.HTTP.Output(&buf)
logger.HTTP.Info().Msg("test message")
var logEntry map[string]interface{}
err := json.Unmarshal(buf.Bytes(), &logEntry)
assert.NoError(t, err)
assert.Equal(t, "http", logEntry["log_stream"])
assert.Equal(t, "test message", logEntry["message"])
}
+1 -1
View File
@@ -24,7 +24,7 @@ func TestGetUsers(t *testing.T) {
err = file.Close() err = file.Close()
require.NoError(t, err) require.NoError(t, err)
defer os.Remove(tmpDir + "/tinyauth_users_test.txt") //nolint:errcheck defer os.Remove(tmpDir + "/tinyauth_users_test.txt")
noAttrs := map[string]model.UserAttributes{} noAttrs := map[string]model.UserAttributes{}
+4 -4
View File
@@ -1,12 +1,12 @@
version: "2" version: "2"
sql: sql:
- engine: "sqlite" - engine: "sqlite"
queries: "sql/*_queries.sql" queries: "sql/sqlite/*_queries.sql"
schema: "sql/*_schemas.sql" schema: "sql/sqlite/*_schemas.sql"
gen: gen:
go: go:
package: "repository" package: "sqlite"
out: "internal/repository" out: "internal/repository/sqlite"
rename: rename:
uuid: "UUID" uuid: "UUID"
oauth_groups: "OAuthGroups" oauth_groups: "OAuthGroups"