mirror of
https://github.com/steveiliop56/tinyauth.git
synced 2026-05-10 14:28:12 +00:00
Compare commits
28 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| e718471ad3 | |||
| c6d36673eb | |||
| 71ae3e0cd2 | |||
| 0d9865793c | |||
| 04b2290d73 | |||
| e04980468f | |||
| d47e4d3d79 | |||
| f3965a7470 | |||
| 36d4e3ec52 | |||
| eab9f71110 | |||
| e13598bf3c | |||
| 4d3860f860 | |||
| 3b5da06862 | |||
| 8f337aaff8 | |||
| ff3c25c09d | |||
| 26daef7d4e | |||
| c932817757 | |||
| 004df2f852 | |||
| df56708b9a | |||
| 62ffd2fd11 | |||
| a3ec07230c | |||
| b4eb7090bd | |||
| 2f24f823eb | |||
| 9a219046ac | |||
| 97d58b376d | |||
| b426a1529e | |||
| c7efb71a5a | |||
| eec75a6f49 |
@@ -91,8 +91,6 @@ TINYAUTH_APPS_name_LDAP_GROUPS=
|
|||||||
|
|
||||||
# Comma-separated list of allowed OAuth domains.
|
# Comma-separated list of allowed OAuth domains.
|
||||||
TINYAUTH_OAUTH_WHITELIST=
|
TINYAUTH_OAUTH_WHITELIST=
|
||||||
# Path to the OAuth whitelist file.
|
|
||||||
TINYAUTH_OAUTH_WHITELISTFILE=
|
|
||||||
# The OAuth provider to use for automatic redirection.
|
# The OAuth provider to use for automatic redirection.
|
||||||
TINYAUTH_OAUTH_AUTOREDIRECT=
|
TINYAUTH_OAUTH_AUTOREDIRECT=
|
||||||
# OAuth client ID.
|
# OAuth client ID.
|
||||||
|
|||||||
@@ -38,6 +38,6 @@ jobs:
|
|||||||
retention-days: 5
|
retention-days: 5
|
||||||
|
|
||||||
- name: Upload to code-scanning
|
- name: Upload to code-scanning
|
||||||
uses: github/codeql-action/upload-sarif@68bde559dea0fdcac2102bfdf6230c5f70eb485e # v4
|
uses: github/codeql-action/upload-sarif@95e58e9a2cdfd71adc6e0353d5c52f41a045d225 # v4
|
||||||
with:
|
with:
|
||||||
sarif_file: results.sarif
|
sarif_file: results.sarif
|
||||||
|
|||||||
@@ -65,7 +65,7 @@ Tinyauth is licensed under the GNU General Public License v3.0. TL;DR — You ma
|
|||||||
|
|
||||||
A big thank you to the following people for providing me with more coffee:
|
A big thank you to the following people for providing me with more coffee:
|
||||||
|
|
||||||
<!-- sponsors --><a href="https://github.com/erwinkramer"><img src="https://github.com/erwinkramer.png" width="64px" alt="User avatar: erwinkramer" /></a> <a href="https://github.com/nicotsx"><img src="https://github.com/nicotsx.png" width="64px" alt="User avatar: nicotsx" /></a> <a href="https://github.com/SimpleHomelab"><img src="https://github.com/SimpleHomelab.png" width="64px" alt="User avatar: SimpleHomelab" /></a> <a href="https://github.com/jmadden91"><img src="https://github.com/jmadden91.png" width="64px" alt="User avatar: jmadden91" /></a> <a href="https://github.com/tribor"><img src="https://github.com/tribor.png" width="64px" alt="User avatar: tribor" /></a> <a href="https://github.com/eliasbenb"><img src="https://github.com/eliasbenb.png" width="64px" alt="User avatar: eliasbenb" /></a> <a href="https://github.com/afunworm"><img src="https://github.com/afunworm.png" width="64px" alt="User avatar: afunworm" /></a> <a href="https://github.com/chip-well"><img src="https://github.com/chip-well.png" width="64px" alt="User avatar: chip-well" /></a> <a href="https://github.com/Lancelot-Enguerrand"><img src="https://github.com/Lancelot-Enguerrand.png" width="64px" alt="User avatar: Lancelot-Enguerrand" /></a> <a href="https://github.com/allgoewer"><img src="https://github.com/allgoewer.png" width="64px" alt="User avatar: allgoewer" /></a> <a href="https://github.com/NEANC"><img src="https://github.com/NEANC.png" width="64px" alt="User avatar: NEANC" /></a> <a href="https://github.com/ax-mad"><img src="https://github.com/ax-mad.png" width="64px" alt="User avatar: ax-mad" /></a> <a href="https://github.com/stegratech"><img src="https://github.com/stegratech.png" width="64px" alt="User avatar: stegratech" /></a> <a href="https://github.com/apearson"><img src="https://github.com/apearson.png" width="64px" alt="User avatar: apearson" /></a> <!-- sponsors -->
|
<!-- sponsors --><a href="https://github.com/erwinkramer"><img src="https://github.com/erwinkramer.png" width="64px" alt="User avatar: erwinkramer" /></a> <a href="https://github.com/nicotsx"><img src="https://github.com/nicotsx.png" width="64px" alt="User avatar: nicotsx" /></a> <a href="https://github.com/SimpleHomelab"><img src="https://github.com/SimpleHomelab.png" width="64px" alt="User avatar: SimpleHomelab" /></a> <a href="https://github.com/jmadden91"><img src="https://github.com/jmadden91.png" width="64px" alt="User avatar: jmadden91" /></a> <a href="https://github.com/tribor"><img src="https://github.com/tribor.png" width="64px" alt="User avatar: tribor" /></a> <a href="https://github.com/eliasbenb"><img src="https://github.com/eliasbenb.png" width="64px" alt="User avatar: eliasbenb" /></a> <a href="https://github.com/afunworm"><img src="https://github.com/afunworm.png" width="64px" alt="User avatar: afunworm" /></a> <a href="https://github.com/chip-well"><img src="https://github.com/chip-well.png" width="64px" alt="User avatar: chip-well" /></a> <a href="https://github.com/Lancelot-Enguerrand"><img src="https://github.com/Lancelot-Enguerrand.png" width="64px" alt="User avatar: Lancelot-Enguerrand" /></a> <a href="https://github.com/allgoewer"><img src="https://github.com/allgoewer.png" width="64px" alt="User avatar: allgoewer" /></a> <a href="https://github.com/NEANC"><img src="https://github.com/NEANC.png" width="64px" alt="User avatar: NEANC" /></a> <a href="https://github.com/ax-mad"><img src="https://github.com/ax-mad.png" width="64px" alt="User avatar: ax-mad" /></a> <a href="https://github.com/stegratech"><img src="https://github.com/stegratech.png" width="64px" alt="User avatar: stegratech" /></a> <!-- sponsors -->
|
||||||
|
|
||||||
## Acknowledgements
|
## Acknowledgements
|
||||||
|
|
||||||
|
|||||||
@@ -6,8 +6,8 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"charm.land/huh/v2"
|
"charm.land/huh/v2"
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||||
"github.com/tinyauthapp/paerser/cli"
|
"github.com/tinyauthapp/paerser/cli"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
|
||||||
"golang.org/x/crypto/bcrypt"
|
"golang.org/x/crypto/bcrypt"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -40,8 +40,7 @@ func createUserCmd() *cli.Command {
|
|||||||
Configuration: tCfg,
|
Configuration: tCfg,
|
||||||
Resources: loaders,
|
Resources: loaders,
|
||||||
Run: func(_ []string) error {
|
Run: func(_ []string) error {
|
||||||
log := logger.NewLogger().WithSimpleConfig()
|
tlog.NewSimpleLogger().Init()
|
||||||
log.Init()
|
|
||||||
|
|
||||||
if tCfg.Interactive {
|
if tCfg.Interactive {
|
||||||
form := huh.NewForm(
|
form := huh.NewForm(
|
||||||
@@ -74,7 +73,7 @@ func createUserCmd() *cli.Command {
|
|||||||
return errors.New("username and password cannot be empty")
|
return errors.New("username and password cannot be empty")
|
||||||
}
|
}
|
||||||
|
|
||||||
log.App.Info().Str("username", tCfg.Username).Msg("Creating user")
|
tlog.App.Info().Str("username", tCfg.Username).Msg("Creating user")
|
||||||
|
|
||||||
passwd, err := bcrypt.GenerateFromPassword([]byte(tCfg.Password), bcrypt.DefaultCost)
|
passwd, err := bcrypt.GenerateFromPassword([]byte(tCfg.Password), bcrypt.DefaultCost)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -87,7 +86,7 @@ func createUserCmd() *cli.Command {
|
|||||||
passwdStr = strings.ReplaceAll(passwdStr, "$", "$$")
|
passwdStr = strings.ReplaceAll(passwdStr, "$", "$$")
|
||||||
}
|
}
|
||||||
|
|
||||||
log.App.Info().Str("user", fmt.Sprintf("%s:%s", tCfg.Username, passwdStr)).Msg("User created")
|
tlog.App.Info().Str("user", fmt.Sprintf("%s:%s", tCfg.Username, passwdStr)).Msg("User created")
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||||
|
|
||||||
"charm.land/huh/v2"
|
"charm.land/huh/v2"
|
||||||
"github.com/mdp/qrterminal/v3"
|
"github.com/mdp/qrterminal/v3"
|
||||||
@@ -40,8 +40,7 @@ func generateTotpCmd() *cli.Command {
|
|||||||
Configuration: tCfg,
|
Configuration: tCfg,
|
||||||
Resources: loaders,
|
Resources: loaders,
|
||||||
Run: func(_ []string) error {
|
Run: func(_ []string) error {
|
||||||
log := logger.NewLogger().WithSimpleConfig()
|
tlog.NewSimpleLogger().Init()
|
||||||
log.Init()
|
|
||||||
|
|
||||||
if tCfg.Interactive {
|
if tCfg.Interactive {
|
||||||
form := huh.NewForm(
|
form := huh.NewForm(
|
||||||
@@ -89,9 +88,9 @@ func generateTotpCmd() *cli.Command {
|
|||||||
|
|
||||||
secret := key.Secret()
|
secret := key.Secret()
|
||||||
|
|
||||||
log.App.Info().Str("secret", secret).Msg("Generated TOTP secret")
|
tlog.App.Info().Str("secret", secret).Msg("Generated TOTP secret")
|
||||||
|
|
||||||
log.App.Info().Msg("Generated QR code")
|
tlog.App.Info().Msg("Generated QR code")
|
||||||
|
|
||||||
config := qrterminal.Config{
|
config := qrterminal.Config{
|
||||||
Level: qrterminal.L,
|
Level: qrterminal.L,
|
||||||
@@ -110,7 +109,7 @@ func generateTotpCmd() *cli.Command {
|
|||||||
user.Password = strings.ReplaceAll(user.Password, "$", "$$")
|
user.Password = strings.ReplaceAll(user.Password, "$", "$$")
|
||||||
}
|
}
|
||||||
|
|
||||||
log.App.Info().Str("user", fmt.Sprintf("%s:%s:%s", user.Username, user.Password, user.TOTPSecret)).Msg("Add the totp secret to your authenticator app then use the verify command to ensure everything is working correctly.")
|
tlog.App.Info().Str("user", fmt.Sprintf("%s:%s:%s", user.Username, user.Password, user.TOTPSecret)).Msg("Add the totp secret to your authenticator app then use the verify command to ensure everything is working correctly.")
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -9,8 +9,8 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||||
"github.com/tinyauthapp/paerser/cli"
|
"github.com/tinyauthapp/paerser/cli"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type healthzResponse struct {
|
type healthzResponse struct {
|
||||||
@@ -26,8 +26,7 @@ func healthcheckCmd() *cli.Command {
|
|||||||
Resources: nil,
|
Resources: nil,
|
||||||
AllowArg: true,
|
AllowArg: true,
|
||||||
Run: func(args []string) error {
|
Run: func(args []string) error {
|
||||||
log := logger.NewLogger().WithSimpleConfig()
|
tlog.NewSimpleLogger().Init()
|
||||||
log.Init()
|
|
||||||
|
|
||||||
srvAddr := os.Getenv("TINYAUTH_SERVER_ADDRESS")
|
srvAddr := os.Getenv("TINYAUTH_SERVER_ADDRESS")
|
||||||
if srvAddr == "" {
|
if srvAddr == "" {
|
||||||
@@ -49,7 +48,7 @@ func healthcheckCmd() *cli.Command {
|
|||||||
return errors.New("Could not determine app URL")
|
return errors.New("Could not determine app URL")
|
||||||
}
|
}
|
||||||
|
|
||||||
log.App.Info().Str("app_url", appUrl).Msg("Performing health check")
|
tlog.App.Info().Str("app_url", appUrl).Msg("Performing health check")
|
||||||
|
|
||||||
client := http.Client{
|
client := http.Client{
|
||||||
Timeout: 30 * time.Second,
|
Timeout: 30 * time.Second,
|
||||||
@@ -87,7 +86,7 @@ func healthcheckCmd() *cli.Command {
|
|||||||
return fmt.Errorf("failed to decode response: %w", err)
|
return fmt.Errorf("failed to decode response: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.App.Info().Interface("response", healthResp).Msg("Tinyauth is healthy")
|
tlog.App.Info().Interface("response", healthResp).Msg("Tinyauth is healthy")
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"github.com/tinyauthapp/tinyauth/internal/bootstrap"
|
"github.com/tinyauthapp/tinyauth/internal/bootstrap"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/loaders"
|
"github.com/tinyauthapp/tinyauth/internal/utils/loaders"
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||||
|
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
"github.com/tinyauthapp/paerser/cli"
|
"github.com/tinyauthapp/paerser/cli"
|
||||||
@@ -108,6 +109,11 @@ func main() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func runCmd(cfg model.Config) error {
|
func runCmd(cfg model.Config) error {
|
||||||
|
logger := tlog.NewLogger(cfg.Log)
|
||||||
|
logger.Init()
|
||||||
|
|
||||||
|
tlog.App.Info().Str("version", model.Version).Msg("Starting tinyauth")
|
||||||
|
|
||||||
app := bootstrap.NewBootstrapApp(cfg)
|
app := bootstrap.NewBootstrapApp(cfg)
|
||||||
|
|
||||||
err := app.Setup()
|
err := app.Setup()
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||||
|
|
||||||
"charm.land/huh/v2"
|
"charm.land/huh/v2"
|
||||||
"github.com/pquerna/otp/totp"
|
"github.com/pquerna/otp/totp"
|
||||||
@@ -44,8 +44,7 @@ func verifyUserCmd() *cli.Command {
|
|||||||
Configuration: tCfg,
|
Configuration: tCfg,
|
||||||
Resources: loaders,
|
Resources: loaders,
|
||||||
Run: func(_ []string) error {
|
Run: func(_ []string) error {
|
||||||
log := logger.NewLogger().WithSimpleConfig()
|
tlog.NewSimpleLogger().Init()
|
||||||
log.Init()
|
|
||||||
|
|
||||||
if tCfg.Interactive {
|
if tCfg.Interactive {
|
||||||
form := huh.NewForm(
|
form := huh.NewForm(
|
||||||
@@ -98,9 +97,9 @@ func verifyUserCmd() *cli.Command {
|
|||||||
|
|
||||||
if user.TOTPSecret == "" {
|
if user.TOTPSecret == "" {
|
||||||
if tCfg.Totp != "" {
|
if tCfg.Totp != "" {
|
||||||
log.App.Warn().Msg("User does not have TOTP secret")
|
tlog.App.Warn().Msg("User does not have TOTP secret")
|
||||||
}
|
}
|
||||||
log.App.Info().Msg("User verified")
|
tlog.App.Info().Msg("User verified")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -110,7 +109,7 @@ func verifyUserCmd() *cli.Command {
|
|||||||
return fmt.Errorf("TOTP code incorrect")
|
return fmt.Errorf("TOTP code incorrect")
|
||||||
}
|
}
|
||||||
|
|
||||||
log.App.Info().Msg("User verified")
|
tlog.App.Info().Msg("User verified")
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -20,9 +20,9 @@ require (
|
|||||||
github.com/weppos/publicsuffix-go v0.50.3
|
github.com/weppos/publicsuffix-go v0.50.3
|
||||||
golang.org/x/crypto v0.50.0
|
golang.org/x/crypto v0.50.0
|
||||||
golang.org/x/oauth2 v0.36.0
|
golang.org/x/oauth2 v0.36.0
|
||||||
k8s.io/apimachinery v0.36.0
|
k8s.io/apimachinery v0.32.2
|
||||||
k8s.io/client-go v0.36.0
|
k8s.io/client-go v0.32.2
|
||||||
modernc.org/sqlite v1.50.0
|
modernc.org/sqlite v1.49.1
|
||||||
)
|
)
|
||||||
|
|
||||||
require (
|
require (
|
||||||
@@ -63,7 +63,7 @@ require (
|
|||||||
github.com/docker/go-units v0.5.0 // indirect
|
github.com/docker/go-units v0.5.0 // indirect
|
||||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||||
github.com/felixge/httpsnoop v1.0.4 // indirect
|
github.com/felixge/httpsnoop v1.0.4 // indirect
|
||||||
github.com/fxamacker/cbor/v2 v2.9.0 // indirect
|
github.com/fxamacker/cbor/v2 v2.7.0 // indirect
|
||||||
github.com/gabriel-vasile/mimetype v1.4.12 // indirect
|
github.com/gabriel-vasile/mimetype v1.4.12 // indirect
|
||||||
github.com/gin-contrib/sse v1.1.0 // indirect
|
github.com/gin-contrib/sse v1.1.0 // indirect
|
||||||
github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667 // indirect
|
github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667 // indirect
|
||||||
@@ -74,6 +74,9 @@ require (
|
|||||||
github.com/go-playground/validator/v10 v10.30.1 // indirect
|
github.com/go-playground/validator/v10 v10.30.1 // indirect
|
||||||
github.com/goccy/go-json v0.10.5 // indirect
|
github.com/goccy/go-json v0.10.5 // indirect
|
||||||
github.com/goccy/go-yaml v1.19.2 // indirect
|
github.com/goccy/go-yaml v1.19.2 // indirect
|
||||||
|
github.com/gogo/protobuf v1.3.2 // indirect
|
||||||
|
github.com/google/go-cmp v0.7.0 // indirect
|
||||||
|
github.com/google/gofuzz v1.2.0 // indirect
|
||||||
github.com/huandu/xstrings v1.5.0 // indirect
|
github.com/huandu/xstrings v1.5.0 // indirect
|
||||||
github.com/json-iterator/go v1.1.12 // indirect
|
github.com/json-iterator/go v1.1.12 // indirect
|
||||||
github.com/klauspost/cpuid/v2 v2.3.0 // indirect
|
github.com/klauspost/cpuid/v2 v2.3.0 // indirect
|
||||||
@@ -90,7 +93,7 @@ require (
|
|||||||
github.com/moby/sys/atomicwriter v0.1.0 // indirect
|
github.com/moby/sys/atomicwriter v0.1.0 // indirect
|
||||||
github.com/moby/term v0.5.2 // indirect
|
github.com/moby/term v0.5.2 // indirect
|
||||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
||||||
github.com/modern-go/reflect2 v1.0.3-0.20250322232337-35a7c28c31ee // indirect
|
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||||
github.com/muesli/cancelreader v0.2.2 // indirect
|
github.com/muesli/cancelreader v0.2.2 // indirect
|
||||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
|
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
|
||||||
github.com/ncruces/go-strftime v1.0.0 // indirect
|
github.com/ncruces/go-strftime v1.0.0 // indirect
|
||||||
@@ -118,7 +121,6 @@ require (
|
|||||||
go.opentelemetry.io/otel/sdk v1.43.0 // indirect
|
go.opentelemetry.io/otel/sdk v1.43.0 // indirect
|
||||||
go.opentelemetry.io/otel/sdk/metric v1.43.0 // indirect
|
go.opentelemetry.io/otel/sdk/metric v1.43.0 // indirect
|
||||||
go.opentelemetry.io/otel/trace v1.43.0 // indirect
|
go.opentelemetry.io/otel/trace v1.43.0 // indirect
|
||||||
go.yaml.in/yaml/v2 v2.4.3 // indirect
|
|
||||||
golang.org/x/arch v0.22.0 // indirect
|
golang.org/x/arch v0.22.0 // indirect
|
||||||
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect
|
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect
|
||||||
golang.org/x/net v0.52.0 // indirect
|
golang.org/x/net v0.52.0 // indirect
|
||||||
@@ -126,20 +128,18 @@ require (
|
|||||||
golang.org/x/sys v0.43.0 // indirect
|
golang.org/x/sys v0.43.0 // indirect
|
||||||
golang.org/x/term v0.42.0 // indirect
|
golang.org/x/term v0.42.0 // indirect
|
||||||
golang.org/x/text v0.36.0 // indirect
|
golang.org/x/text v0.36.0 // indirect
|
||||||
golang.org/x/time v0.14.0 // indirect
|
golang.org/x/time v0.12.0 // indirect
|
||||||
google.golang.org/protobuf v1.36.12-0.20260120151049-f2248ac996af // indirect
|
google.golang.org/protobuf v1.36.11 // indirect
|
||||||
gopkg.in/inf.v0 v0.9.1 // indirect
|
gopkg.in/inf.v0 v0.9.1 // indirect
|
||||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||||
gotest.tools/v3 v3.5.2 // indirect
|
gotest.tools/v3 v3.5.2 // indirect
|
||||||
k8s.io/klog/v2 v2.140.0 // indirect
|
k8s.io/klog/v2 v2.130.1 // indirect
|
||||||
k8s.io/kube-openapi v0.0.0-20260317180543-43fb72c5454a // indirect
|
k8s.io/utils v0.0.0-20241104100929-3ea5e8cea738 // indirect
|
||||||
k8s.io/utils v0.0.0-20260210185600-b8788abfbbc2 // indirect
|
|
||||||
modernc.org/libc v1.72.0 // indirect
|
modernc.org/libc v1.72.0 // indirect
|
||||||
modernc.org/mathutil v1.7.1 // indirect
|
modernc.org/mathutil v1.7.1 // indirect
|
||||||
modernc.org/memory v1.11.0 // indirect
|
modernc.org/memory v1.11.0 // indirect
|
||||||
rsc.io/qr v0.2.0 // indirect
|
rsc.io/qr v0.2.0 // indirect
|
||||||
sigs.k8s.io/json v0.0.0-20250730193827-2d320260d730 // indirect
|
sigs.k8s.io/json v0.0.0-20241010143419-9aa6b5e7a4b3 // indirect
|
||||||
sigs.k8s.io/randfill v1.0.0 // indirect
|
sigs.k8s.io/structured-merge-diff/v4 v4.4.2 // indirect
|
||||||
sigs.k8s.io/structured-merge-diff/v6 v6.3.2 // indirect
|
sigs.k8s.io/yaml v1.4.0 // indirect
|
||||||
sigs.k8s.io/yaml v1.6.0 // indirect
|
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -97,14 +97,14 @@ github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4
|
|||||||
github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
|
github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
|
||||||
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
||||||
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
||||||
github.com/emicklei/go-restful/v3 v3.13.0 h1:C4Bl2xDndpU6nJ4bc1jXd+uTmYPVUwkD6bFY/oTyCes=
|
github.com/emicklei/go-restful/v3 v3.11.0 h1:rAQeMHw1c7zTmncogyy8VvRZwtkmkZ4FxERmMY4rD+g=
|
||||||
github.com/emicklei/go-restful/v3 v3.13.0/go.mod h1:6n3XBCmQQb25CM2LCACGz8ukIrRry+4bhvbpWn3mrbc=
|
github.com/emicklei/go-restful/v3 v3.11.0/go.mod h1:6n3XBCmQQb25CM2LCACGz8ukIrRry+4bhvbpWn3mrbc=
|
||||||
github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg=
|
github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg=
|
||||||
github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
|
github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
|
||||||
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
|
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
|
||||||
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
|
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
|
||||||
github.com/fxamacker/cbor/v2 v2.9.0 h1:NpKPmjDBgUfBms6tr6JZkTHtfFGcMKsw3eGcmD/sapM=
|
github.com/fxamacker/cbor/v2 v2.7.0 h1:iM5WgngdRBanHcxugY4JySA0nk1wZorNOpTgCMedv5E=
|
||||||
github.com/fxamacker/cbor/v2 v2.9.0/go.mod h1:vM4b+DJCtHn+zz7h3FFp/hDAI9WNWCsZj23V5ytsSxQ=
|
github.com/fxamacker/cbor/v2 v2.7.0/go.mod h1:pxXPTn3joSm21Gbwsv0w9OSA2y1HFR9qXEeXQVeNoDQ=
|
||||||
github.com/gabriel-vasile/mimetype v1.4.12 h1:e9hWvmLYvtp846tLHam2o++qitpguFiYCKbn0w9jyqw=
|
github.com/gabriel-vasile/mimetype v1.4.12 h1:e9hWvmLYvtp846tLHam2o++qitpguFiYCKbn0w9jyqw=
|
||||||
github.com/gabriel-vasile/mimetype v1.4.12/go.mod h1:d+9Oxyo1wTzWdyVUPMmXFvp4F9tea18J8ufA774AB3s=
|
github.com/gabriel-vasile/mimetype v1.4.12/go.mod h1:d+9Oxyo1wTzWdyVUPMmXFvp4F9tea18J8ufA774AB3s=
|
||||||
github.com/gin-contrib/sse v1.1.0 h1:n0w2GMuUpWDVp7qSpvze6fAu9iRxJY4Hmj6AmBOU05w=
|
github.com/gin-contrib/sse v1.1.0 h1:n0w2GMuUpWDVp7qSpvze6fAu9iRxJY4Hmj6AmBOU05w=
|
||||||
@@ -140,16 +140,23 @@ github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4=
|
|||||||
github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
|
github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
|
||||||
github.com/goccy/go-yaml v1.19.2 h1:PmFC1S6h8ljIz6gMRBopkjP1TVT7xuwrButHID66PoM=
|
github.com/goccy/go-yaml v1.19.2 h1:PmFC1S6h8ljIz6gMRBopkjP1TVT7xuwrButHID66PoM=
|
||||||
github.com/goccy/go-yaml v1.19.2/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA=
|
github.com/goccy/go-yaml v1.19.2/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA=
|
||||||
|
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
|
||||||
|
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
|
||||||
github.com/golang-migrate/migrate/v4 v4.19.1 h1:OCyb44lFuQfYXYLx1SCxPZQGU7mcaZ7gH9yH4jSFbBA=
|
github.com/golang-migrate/migrate/v4 v4.19.1 h1:OCyb44lFuQfYXYLx1SCxPZQGU7mcaZ7gH9yH4jSFbBA=
|
||||||
github.com/golang-migrate/migrate/v4 v4.19.1/go.mod h1:CTcgfjxhaUtsLipnLoQRWCrjYXycRz/g5+RWDuYgPrE=
|
github.com/golang-migrate/migrate/v4 v4.19.1/go.mod h1:CTcgfjxhaUtsLipnLoQRWCrjYXycRz/g5+RWDuYgPrE=
|
||||||
github.com/google/gnostic-models v0.7.0 h1:qwTtogB15McXDaNqTZdzPJRHvaVJlAl+HVQnLmJEJxo=
|
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
|
||||||
github.com/google/gnostic-models v0.7.0/go.mod h1:whL5G0m6dmc5cPxKc5bdKdEN3UjI7OUGxBlw57miDrQ=
|
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
|
||||||
|
github.com/google/gnostic-models v0.6.8 h1:yo/ABAfM5IMRsS1VnXjTBvUb61tFIHozhlYvRgGre9I=
|
||||||
|
github.com/google/gnostic-models v0.6.8/go.mod h1:5n7qKqH0f5wFt+aWF8CW6pZLLNOfYuF5OpfBSENuI8U=
|
||||||
|
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||||
github.com/google/go-querystring v1.2.0 h1:yhqkPbu2/OH+V9BfpCVPZkNmUXhb2gBxJArfhIxNtP0=
|
github.com/google/go-querystring v1.2.0 h1:yhqkPbu2/OH+V9BfpCVPZkNmUXhb2gBxJArfhIxNtP0=
|
||||||
github.com/google/go-querystring v1.2.0/go.mod h1:8IFJqpSRITyJ8QhQ13bmbeMBDfmeEJZD5A0egEOmkqU=
|
github.com/google/go-querystring v1.2.0/go.mod h1:8IFJqpSRITyJ8QhQ13bmbeMBDfmeEJZD5A0egEOmkqU=
|
||||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||||
|
github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0=
|
||||||
|
github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
|
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
|
||||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
|
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
|
||||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||||
@@ -178,6 +185,8 @@ github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8Hm
|
|||||||
github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
|
github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
|
||||||
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
||||||
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
||||||
|
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
|
||||||
|
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
|
||||||
github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y=
|
github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y=
|
||||||
github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0=
|
github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0=
|
||||||
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||||
@@ -219,9 +228,8 @@ github.com/moby/term v0.5.2/go.mod h1:d3djjFCrjnB+fl8NJux+EJzu0msscUP+f8it8hPkFL
|
|||||||
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
|
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
|
||||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||||
|
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
|
||||||
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
|
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
|
||||||
github.com/modern-go/reflect2 v1.0.3-0.20250322232337-35a7c28c31ee h1:W5t00kpgFdJifH4BDsTlE89Zl93FEloxaWZfGcifgq8=
|
|
||||||
github.com/modern-go/reflect2 v1.0.3-0.20250322232337-35a7c28c31ee/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
|
|
||||||
github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
|
github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
|
||||||
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
|
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
|
||||||
github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA=
|
github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA=
|
||||||
@@ -261,12 +269,11 @@ github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ
|
|||||||
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
|
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
|
||||||
github.com/spf13/cast v1.10.0 h1:h2x0u2shc1QuLHfxi+cTJvs30+ZAHOGRic8uyGTDWxY=
|
github.com/spf13/cast v1.10.0 h1:h2x0u2shc1QuLHfxi+cTJvs30+ZAHOGRic8uyGTDWxY=
|
||||||
github.com/spf13/cast v1.10.0/go.mod h1:jNfB8QC9IA6ZuY2ZjDp0KtFO2LZZlg4S/7bzP6qqeHo=
|
github.com/spf13/cast v1.10.0/go.mod h1:jNfB8QC9IA6ZuY2ZjDp0KtFO2LZZlg4S/7bzP6qqeHo=
|
||||||
github.com/spf13/pflag v1.0.9 h1:9exaQaMOCwffKiiiYk6/BndUBv+iRViNW+4lEMi0PvY=
|
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
|
||||||
github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
||||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||||
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||||
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
|
|
||||||
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
|
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
|
||||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||||
@@ -287,6 +294,8 @@ github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
|
|||||||
github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
|
github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
|
||||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no=
|
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no=
|
||||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM=
|
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM=
|
||||||
|
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||||
|
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||||
go.mongodb.org/mongo-driver/v2 v2.5.0 h1:yXUhImUjjAInNcpTcAlPHiT7bIXhshCTL3jVBkF3xaE=
|
go.mongodb.org/mongo-driver/v2 v2.5.0 h1:yXUhImUjjAInNcpTcAlPHiT7bIXhshCTL3jVBkF3xaE=
|
||||||
go.mongodb.org/mongo-driver/v2 v2.5.0/go.mod h1:yOI9kBsufol30iFsl1slpdq1I0eHPzybRWdyYUs8K/0=
|
go.mongodb.org/mongo-driver/v2 v2.5.0/go.mod h1:yOI9kBsufol30iFsl1slpdq1I0eHPzybRWdyYUs8K/0=
|
||||||
go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64=
|
go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64=
|
||||||
@@ -311,35 +320,56 @@ go.opentelemetry.io/proto/otlp v1.10.0 h1:IQRWgT5srOCYfiWnpqUYz9CVmbO8bFmKcwYxpu
|
|||||||
go.opentelemetry.io/proto/otlp v1.10.0/go.mod h1:/CV4QoCR/S9yaPj8utp3lvQPoqMtxXdzn7ozvvozVqk=
|
go.opentelemetry.io/proto/otlp v1.10.0/go.mod h1:/CV4QoCR/S9yaPj8utp3lvQPoqMtxXdzn7ozvvozVqk=
|
||||||
go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y=
|
go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y=
|
||||||
go.uber.org/mock v0.6.0/go.mod h1:KiVJ4BqZJaMj4svdfmHM0AUx4NJYO8ZNpPnZn1Z+BBU=
|
go.uber.org/mock v0.6.0/go.mod h1:KiVJ4BqZJaMj4svdfmHM0AUx4NJYO8ZNpPnZn1Z+BBU=
|
||||||
go.yaml.in/yaml/v2 v2.4.3 h1:6gvOSjQoTB3vt1l+CU+tSyi/HOjfOjRLJ4YwYZGwRO0=
|
|
||||||
go.yaml.in/yaml/v2 v2.4.3/go.mod h1:zSxWcmIDjOzPXpjlTTbAsKokqkDNAVtZO0WOMiT90s8=
|
|
||||||
go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
|
|
||||||
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
|
|
||||||
golang.org/x/arch v0.22.0 h1:c/Zle32i5ttqRXjdLyyHZESLD/bB90DCU1g9l/0YBDI=
|
golang.org/x/arch v0.22.0 h1:c/Zle32i5ttqRXjdLyyHZESLD/bB90DCU1g9l/0YBDI=
|
||||||
golang.org/x/arch v0.22.0/go.mod h1:dNHoOeKiyja7GTvF9NJS1l3Z2yntpQNzgrjh1cU103A=
|
golang.org/x/arch v0.22.0/go.mod h1:dNHoOeKiyja7GTvF9NJS1l3Z2yntpQNzgrjh1cU103A=
|
||||||
|
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||||
|
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||||
|
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||||
golang.org/x/crypto v0.50.0 h1:zO47/JPrL6vsNkINmLoo/PH1gcxpls50DNogFvB5ZGI=
|
golang.org/x/crypto v0.50.0 h1:zO47/JPrL6vsNkINmLoo/PH1gcxpls50DNogFvB5ZGI=
|
||||||
golang.org/x/crypto v0.50.0/go.mod h1:3muZ7vA7PBCE6xgPX7nkzzjiUq87kRItoJQM1Yo8S+Q=
|
golang.org/x/crypto v0.50.0/go.mod h1:3muZ7vA7PBCE6xgPX7nkzzjiUq87kRItoJQM1Yo8S+Q=
|
||||||
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY=
|
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY=
|
||||||
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70=
|
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70=
|
||||||
|
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||||
|
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||||
golang.org/x/mod v0.34.0 h1:xIHgNUUnW6sYkcM5Jleh05DvLOtwc6RitGHbDk4akRI=
|
golang.org/x/mod v0.34.0 h1:xIHgNUUnW6sYkcM5Jleh05DvLOtwc6RitGHbDk4akRI=
|
||||||
golang.org/x/mod v0.34.0/go.mod h1:ykgH52iCZe79kzLLMhyCUzhMci+nQj+0XkbXpNYtVjY=
|
golang.org/x/mod v0.34.0/go.mod h1:ykgH52iCZe79kzLLMhyCUzhMci+nQj+0XkbXpNYtVjY=
|
||||||
|
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||||
|
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||||
|
golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||||
|
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
|
||||||
golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0=
|
golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0=
|
||||||
golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw=
|
golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw=
|
||||||
golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs=
|
golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs=
|
||||||
golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q=
|
golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q=
|
||||||
|
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
|
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
|
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
|
golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
|
||||||
golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
|
golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
|
||||||
|
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||||
|
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
|
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI=
|
golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI=
|
||||||
golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
|
golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
|
||||||
golang.org/x/term v0.42.0 h1:UiKe+zDFmJobeJ5ggPwOshJIVt6/Ft0rcfrXZDLWAWY=
|
golang.org/x/term v0.42.0 h1:UiKe+zDFmJobeJ5ggPwOshJIVt6/Ft0rcfrXZDLWAWY=
|
||||||
golang.org/x/term v0.42.0/go.mod h1:Dq/D+snpsbazcBG5+F9Q1n2rXV8Ma+71xEjTRufARgY=
|
golang.org/x/term v0.42.0/go.mod h1:Dq/D+snpsbazcBG5+F9Q1n2rXV8Ma+71xEjTRufARgY=
|
||||||
|
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||||
|
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||||
golang.org/x/text v0.36.0 h1:JfKh3XmcRPqZPKevfXVpI1wXPTqbkE5f7JA92a55Yxg=
|
golang.org/x/text v0.36.0 h1:JfKh3XmcRPqZPKevfXVpI1wXPTqbkE5f7JA92a55Yxg=
|
||||||
golang.org/x/text v0.36.0/go.mod h1:NIdBknypM8iqVmPiuco0Dh6P5Jcdk8lJL0CUebqK164=
|
golang.org/x/text v0.36.0/go.mod h1:NIdBknypM8iqVmPiuco0Dh6P5Jcdk8lJL0CUebqK164=
|
||||||
golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI=
|
golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE=
|
||||||
golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
|
golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
|
||||||
|
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||||
|
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||||
|
golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
|
||||||
|
golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
|
||||||
golang.org/x/tools v0.43.0 h1:12BdW9CeB3Z+J/I/wj34VMl8X+fEXBxVR90JeMX5E7s=
|
golang.org/x/tools v0.43.0 h1:12BdW9CeB3Z+J/I/wj34VMl8X+fEXBxVR90JeMX5E7s=
|
||||||
golang.org/x/tools v0.43.0/go.mod h1:uHkMso649BX2cZK6+RpuIPXS3ho2hZo4FVwfoy1vIk0=
|
golang.org/x/tools v0.43.0/go.mod h1:uHkMso649BX2cZK6+RpuIPXS3ho2hZo4FVwfoy1vIk0=
|
||||||
|
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||||
|
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||||
|
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||||
|
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||||
google.golang.org/genproto v0.0.0-20250603155806-513f23925822 h1:rHWScKit0gvAPuOnu87KpaYtjK5zBMLcULh7gxkCXu4=
|
google.golang.org/genproto v0.0.0-20250603155806-513f23925822 h1:rHWScKit0gvAPuOnu87KpaYtjK5zBMLcULh7gxkCXu4=
|
||||||
google.golang.org/genproto/googleapis/api v0.0.0-20260401024825-9d38bb4040a9 h1:VPWxll4HlMw1Vs/qXtN7BvhZqsS9cdAittCNvVENElA=
|
google.golang.org/genproto/googleapis/api v0.0.0-20260401024825-9d38bb4040a9 h1:VPWxll4HlMw1Vs/qXtN7BvhZqsS9cdAittCNvVENElA=
|
||||||
google.golang.org/genproto/googleapis/api v0.0.0-20260401024825-9d38bb4040a9/go.mod h1:7QBABkRtR8z+TEnmXTqIqwJLlzrZKVfAUm7tY3yGv0M=
|
google.golang.org/genproto/googleapis/api v0.0.0-20260401024825-9d38bb4040a9/go.mod h1:7QBABkRtR8z+TEnmXTqIqwJLlzrZKVfAUm7tY3yGv0M=
|
||||||
@@ -347,13 +377,13 @@ google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9 h1:
|
|||||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8=
|
google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8=
|
||||||
google.golang.org/grpc v1.80.0 h1:Xr6m2WmWZLETvUNvIUmeD5OAagMw3FiKmMlTdViWsHM=
|
google.golang.org/grpc v1.80.0 h1:Xr6m2WmWZLETvUNvIUmeD5OAagMw3FiKmMlTdViWsHM=
|
||||||
google.golang.org/grpc v1.80.0/go.mod h1:ho/dLnxwi3EDJA4Zghp7k2Ec1+c2jqup0bFkw07bwF4=
|
google.golang.org/grpc v1.80.0/go.mod h1:ho/dLnxwi3EDJA4Zghp7k2Ec1+c2jqup0bFkw07bwF4=
|
||||||
google.golang.org/protobuf v1.36.12-0.20260120151049-f2248ac996af h1:+5/Sw3GsDNlEmu7TfklWKPdQ0Ykja5VEmq2i817+jbI=
|
google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE=
|
||||||
google.golang.org/protobuf v1.36.12-0.20260120151049-f2248ac996af/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
|
google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
|
||||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||||
gopkg.in/evanphx/json-patch.v4 v4.13.0 h1:czT3CmqEaQ1aanPc5SdlgQrrEIb8w/wwCvWWnfEbYzo=
|
gopkg.in/evanphx/json-patch.v4 v4.12.0 h1:n6jtcsulIzXPJaxegRbvFNNrZDjbij7ny3gmSPG+6V4=
|
||||||
gopkg.in/evanphx/json-patch.v4 v4.13.0/go.mod h1:p8EYWUEYMpynmqDbY58zCKCFZw8pRWMG4EsWvDvM72M=
|
gopkg.in/evanphx/json-patch.v4 v4.12.0/go.mod h1:p8EYWUEYMpynmqDbY58zCKCFZw8pRWMG4EsWvDvM72M=
|
||||||
gopkg.in/inf.v0 v0.9.1 h1:73M5CoZyi3ZLMOyDlQh031Cx6N9NDJ2Vvfl76EDAgDc=
|
gopkg.in/inf.v0 v0.9.1 h1:73M5CoZyi3ZLMOyDlQh031Cx6N9NDJ2Vvfl76EDAgDc=
|
||||||
gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw=
|
gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw=
|
||||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
@@ -361,18 +391,18 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
|||||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
gotest.tools/v3 v3.5.2 h1:7koQfIKdy+I8UTetycgUqXWSDwpgv193Ka+qRsmBY8Q=
|
gotest.tools/v3 v3.5.2 h1:7koQfIKdy+I8UTetycgUqXWSDwpgv193Ka+qRsmBY8Q=
|
||||||
gotest.tools/v3 v3.5.2/go.mod h1:LtdLGcnqToBH83WByAAi/wiwSFCArdFIUV/xxN4pcjA=
|
gotest.tools/v3 v3.5.2/go.mod h1:LtdLGcnqToBH83WByAAi/wiwSFCArdFIUV/xxN4pcjA=
|
||||||
k8s.io/api v0.36.0 h1:SgqDhZzHdOtMk40xVSvCXkP9ME0H05hPM3p9AB1kL80=
|
k8s.io/api v0.32.2 h1:bZrMLEkgizC24G9eViHGOPbW+aRo9duEISRIJKfdJuw=
|
||||||
k8s.io/api v0.36.0/go.mod h1:m1LVrGPNYax5NBHdO+QuAedXyuzTt4RryI/qnmNvs34=
|
k8s.io/api v0.32.2/go.mod h1:hKlhk4x1sJyYnHENsrdCWw31FEmCijNGPJO5WzHiJ6Y=
|
||||||
k8s.io/apimachinery v0.36.0 h1:jZyPzhd5Z+3h9vJLt0z9XdzW9VzNzWAUw+P1xZ9PXtQ=
|
k8s.io/apimachinery v0.32.2 h1:yoQBR9ZGkA6Rgmhbp/yuT9/g+4lxtsGYwW6dR6BDPLQ=
|
||||||
k8s.io/apimachinery v0.36.0/go.mod h1:FklypaRJt6n5wUIwWXIP6GJlIpUizTgfo1T/As+Tyxc=
|
k8s.io/apimachinery v0.32.2/go.mod h1:GpHVgxoKlTxClKcteaeuF1Ul/lDVb74KpZcxcmLDElE=
|
||||||
k8s.io/client-go v0.36.0 h1:pOYi7C4RHChYjMiHpZSpSbIM6ZxVbRXBy7CuiIwqA3c=
|
k8s.io/client-go v0.32.2 h1:4dYCD4Nz+9RApM2b/3BtVvBHw54QjMFUl1OLcJG5yOA=
|
||||||
k8s.io/client-go v0.36.0/go.mod h1:ZKKcpwF0aLYfkHFCjillCKaTK/yBkEDHTDXCFY6AS9Y=
|
k8s.io/client-go v0.32.2/go.mod h1:fpZ4oJXclZ3r2nDOv+Ux3XcJutfrwjKTCHz2H3sww94=
|
||||||
k8s.io/klog/v2 v2.140.0 h1:Tf+J3AH7xnUzZyVVXhTgGhEKnFqye14aadWv7bzXdzc=
|
k8s.io/klog/v2 v2.130.1 h1:n9Xl7H1Xvksem4KFG4PYbdQCQxqc/tTUyrgXaOhHSzk=
|
||||||
k8s.io/klog/v2 v2.140.0/go.mod h1:o+/RWfJ6PwpnFn7OyAG3QnO47BFsymfEfrz6XyYSSp0=
|
k8s.io/klog/v2 v2.130.1/go.mod h1:3Jpz1GvMt720eyJH1ckRHK1EDfpxISzJ7I9OYgaDtPE=
|
||||||
k8s.io/kube-openapi v0.0.0-20260317180543-43fb72c5454a h1:xCeOEAOoGYl2jnJoHkC3hkbPJgdATINPMAxaynU2Ovg=
|
k8s.io/kube-openapi v0.0.0-20241105132330-32ad38e42d3f h1:GA7//TjRY9yWGy1poLzYYJJ4JRdzg3+O6e8I+e+8T5Y=
|
||||||
k8s.io/kube-openapi v0.0.0-20260317180543-43fb72c5454a/go.mod h1:uGBT7iTA6c6MvqUvSXIaYZo9ukscABYi2btjhvgKGZ0=
|
k8s.io/kube-openapi v0.0.0-20241105132330-32ad38e42d3f/go.mod h1:R/HEjbvWI0qdfb8viZUeVZm0X6IZnxAydC7YU42CMw4=
|
||||||
k8s.io/utils v0.0.0-20260210185600-b8788abfbbc2 h1:AZYQSJemyQB5eRxqcPky+/7EdBj0xi3g0ZcxxJ7vbWU=
|
k8s.io/utils v0.0.0-20241104100929-3ea5e8cea738 h1:M3sRQVHv7vB20Xc2ybTt7ODCeFj6JSWYFzOFnYeS6Ro=
|
||||||
k8s.io/utils v0.0.0-20260210185600-b8788abfbbc2/go.mod h1:xDxuJ0whA3d0I4mf/C4ppKHxXynQ+fxnkmQH0vTHnuk=
|
k8s.io/utils v0.0.0-20241104100929-3ea5e8cea738/go.mod h1:OLgZIPagt7ERELqWJFomSt595RzquPNLL48iOWgYOg0=
|
||||||
modernc.org/cc/v4 v4.27.3 h1:uNCgn37E5U09mTv1XgskEVUJ8ADKpmFMPxzGJ0TSo+U=
|
modernc.org/cc/v4 v4.27.3 h1:uNCgn37E5U09mTv1XgskEVUJ8ADKpmFMPxzGJ0TSo+U=
|
||||||
modernc.org/cc/v4 v4.27.3/go.mod h1:3YjcbCqhoTTHPycJDRl2WZKKFj0nwcOIPBfEZK0Hdk8=
|
modernc.org/cc/v4 v4.27.3/go.mod h1:3YjcbCqhoTTHPycJDRl2WZKKFj0nwcOIPBfEZK0Hdk8=
|
||||||
modernc.org/ccgo/v4 v4.32.4 h1:L5OB8rpEX4ZsXEQwGozRfJyJSFHbbNVOoQ59DU9/KuU=
|
modernc.org/ccgo/v4 v4.32.4 h1:L5OB8rpEX4ZsXEQwGozRfJyJSFHbbNVOoQ59DU9/KuU=
|
||||||
@@ -395,19 +425,17 @@ modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8=
|
|||||||
modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns=
|
modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns=
|
||||||
modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w=
|
modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w=
|
||||||
modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE=
|
modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE=
|
||||||
modernc.org/sqlite v1.50.0 h1:eMowQSWLK0MeiQTdmz3lqoF5dqclujdlIKeJA11+7oM=
|
modernc.org/sqlite v1.49.1 h1:dYGHTKcX1sJ+EQDnUzvz4TJ5GbuvhNJa8Fg6ElGx73U=
|
||||||
modernc.org/sqlite v1.50.0/go.mod h1:m0w8xhwYUVY3H6pSDwc3gkJ/irZT/0YEXwBlhaxQEew=
|
modernc.org/sqlite v1.49.1/go.mod h1:m0w8xhwYUVY3H6pSDwc3gkJ/irZT/0YEXwBlhaxQEew=
|
||||||
modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0=
|
modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0=
|
||||||
modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A=
|
modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A=
|
||||||
modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=
|
modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=
|
||||||
modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM=
|
modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM=
|
||||||
rsc.io/qr v0.2.0 h1:6vBLea5/NRMVTz8V66gipeLycZMl/+UlFmk8DvqQ6WY=
|
rsc.io/qr v0.2.0 h1:6vBLea5/NRMVTz8V66gipeLycZMl/+UlFmk8DvqQ6WY=
|
||||||
rsc.io/qr v0.2.0/go.mod h1:IF+uZjkb9fqyeF/4tlBoynqmQxUoPfWEKh921coOuXs=
|
rsc.io/qr v0.2.0/go.mod h1:IF+uZjkb9fqyeF/4tlBoynqmQxUoPfWEKh921coOuXs=
|
||||||
sigs.k8s.io/json v0.0.0-20250730193827-2d320260d730 h1:IpInykpT6ceI+QxKBbEflcR5EXP7sU1kvOlxwZh5txg=
|
sigs.k8s.io/json v0.0.0-20241010143419-9aa6b5e7a4b3 h1:/Rv+M11QRah1itp8VhT6HoVx1Ray9eB4DBr+K+/sCJ8=
|
||||||
sigs.k8s.io/json v0.0.0-20250730193827-2d320260d730/go.mod h1:mdzfpAEoE6DHQEN0uh9ZbOCuHbLK5wOm7dK4ctXE9Tg=
|
sigs.k8s.io/json v0.0.0-20241010143419-9aa6b5e7a4b3/go.mod h1:18nIHnGi6636UCz6m8i4DhaJ65T6EruyzmoQqI2BVDo=
|
||||||
sigs.k8s.io/randfill v1.0.0 h1:JfjMILfT8A6RbawdsK2JXGBR5AQVfd+9TbzrlneTyrU=
|
sigs.k8s.io/structured-merge-diff/v4 v4.4.2 h1:MdmvkGuXi/8io6ixD5wud3vOLwc1rj0aNqRlpuvjmwA=
|
||||||
sigs.k8s.io/randfill v1.0.0/go.mod h1:XeLlZ/jmk4i1HRopwe7/aU3H5n1zNUcX6TM94b3QxOY=
|
sigs.k8s.io/structured-merge-diff/v4 v4.4.2/go.mod h1:N8f93tFZh9U6vpxwRArLiikrE5/2tiu1w1AGfACIGE4=
|
||||||
sigs.k8s.io/structured-merge-diff/v6 v6.3.2 h1:kwVWMx5yS1CrnFWA/2QHyRVJ8jM6dBA80uLmm0wJkk8=
|
sigs.k8s.io/yaml v1.4.0 h1:Mk1wCc2gy/F0THH0TAp1QYyJNzRm2KCLy3o5ASXVI5E=
|
||||||
sigs.k8s.io/structured-merge-diff/v6 v6.3.2/go.mod h1:M3W8sfWvn2HhQDIbGWj3S099YozAsymCo/wrT5ohRUE=
|
sigs.k8s.io/yaml v1.4.0/go.mod h1:Ejl7/uTz7PSA4eKMyQCUTnhZYNmLIl+5c2lQPGR2BPY=
|
||||||
sigs.k8s.io/yaml v1.6.0 h1:G8fkbMSAFqgEFgh4b1wmtzDnioxFCUgTZhlbj5P9QYs=
|
|
||||||
sigs.k8s.io/yaml v1.6.0/go.mod h1:796bPqUfzR/0jLAl6XjHl3Ck7MiyVv8dbTdyT3/pMf4=
|
|
||||||
|
|||||||
+121
-278
@@ -3,50 +3,38 @@ package bootstrap
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
|
||||||
"syscall"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/tinyauthapp/tinyauth/internal/controller"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Services struct {
|
|
||||||
accessControlService *service.AccessControlsService
|
|
||||||
authService *service.AuthService
|
|
||||||
dockerService *service.DockerService
|
|
||||||
kubernetesService *service.KubernetesService
|
|
||||||
ldapService *service.LdapService
|
|
||||||
oauthBrokerService *service.OAuthBrokerService
|
|
||||||
oidcService *service.OIDCService
|
|
||||||
}
|
|
||||||
|
|
||||||
type BootstrapApp struct {
|
type BootstrapApp struct {
|
||||||
config model.Config
|
config model.Config
|
||||||
runtime model.RuntimeConfig
|
context struct {
|
||||||
|
appUrl string
|
||||||
|
uuid string
|
||||||
|
cookieDomain string
|
||||||
|
sessionCookieName string
|
||||||
|
csrfCookieName string
|
||||||
|
redirectCookieName string
|
||||||
|
oauthSessionCookieName string
|
||||||
|
localUsers *[]model.LocalUser
|
||||||
|
oauthProviders map[string]model.OAuthServiceConfig
|
||||||
|
configuredProviders []controller.Provider
|
||||||
|
oidcClients []model.OIDCClientConfig
|
||||||
|
}
|
||||||
services Services
|
services Services
|
||||||
log *logger.Logger
|
|
||||||
ctx context.Context
|
|
||||||
cancel context.CancelFunc
|
|
||||||
queries *repository.Queries
|
|
||||||
router *gin.Engine
|
|
||||||
db *sql.DB
|
|
||||||
wg sync.WaitGroup
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewBootstrapApp(config model.Config) *BootstrapApp {
|
func NewBootstrapApp(config model.Config) *BootstrapApp {
|
||||||
@@ -56,69 +44,49 @@ func NewBootstrapApp(config model.Config) *BootstrapApp {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (app *BootstrapApp) Setup() error {
|
func (app *BootstrapApp) Setup() error {
|
||||||
// create context
|
|
||||||
ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
|
|
||||||
app.ctx = ctx
|
|
||||||
app.cancel = cancel
|
|
||||||
|
|
||||||
// setup logger
|
|
||||||
log := logger.NewLogger().WithConfig(app.config.Log)
|
|
||||||
log.Init()
|
|
||||||
app.log = log
|
|
||||||
|
|
||||||
// get app url
|
// get app url
|
||||||
if app.config.AppURL == "" {
|
if app.config.AppURL == "" {
|
||||||
return errors.New("app url cannot be empty, perhaps config loading failed")
|
return fmt.Errorf("app URL cannot be empty, perhaps config loading failed")
|
||||||
}
|
}
|
||||||
|
|
||||||
appUrl, err := url.Parse(app.config.AppURL)
|
appUrl, err := url.Parse(app.config.AppURL)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to parse app url: %w", err)
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
app.runtime.AppURL = appUrl.Scheme + "://" + appUrl.Host
|
app.context.appUrl = appUrl.Scheme + "://" + appUrl.Host
|
||||||
|
|
||||||
// validate session config
|
// validate session config
|
||||||
if app.config.Auth.SessionMaxLifetime != 0 && app.config.Auth.SessionMaxLifetime < app.config.Auth.SessionExpiry {
|
if app.config.Auth.SessionMaxLifetime != 0 && app.config.Auth.SessionMaxLifetime < app.config.Auth.SessionExpiry {
|
||||||
return errors.New("session max lifetime cannot be less than session expiry")
|
return fmt.Errorf("session max lifetime cannot be less than session expiry")
|
||||||
}
|
}
|
||||||
|
|
||||||
// parse users
|
// Parse users
|
||||||
users, err := utils.GetUsers(app.config.Auth.Users, app.config.Auth.UsersFile, app.config.Auth.UserAttributes)
|
users, err := utils.GetUsers(app.config.Auth.Users, app.config.Auth.UsersFile, app.config.Auth.UserAttributes)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to load users: %w", err)
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
app.runtime.LocalUsers = *users
|
app.context.localUsers = users
|
||||||
|
|
||||||
// load oauth whitelist
|
// Setup OAuth providers
|
||||||
oauthWhitelist, err := utils.GetStringList(app.config.OAuth.Whitelist, app.config.OAuth.WhitelistFile)
|
app.context.oauthProviders = app.config.OAuth.Providers
|
||||||
|
|
||||||
if err != nil {
|
for name, provider := range app.context.oauthProviders {
|
||||||
return fmt.Errorf("failed to load oauth whitelist: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
app.runtime.OAuthWhitelist = oauthWhitelist
|
|
||||||
|
|
||||||
// setup oauth providers
|
|
||||||
app.runtime.OAuthProviders = app.config.OAuth.Providers
|
|
||||||
|
|
||||||
for id, provider := range app.runtime.OAuthProviders {
|
|
||||||
secret := utils.GetSecret(provider.ClientSecret, provider.ClientSecretFile)
|
secret := utils.GetSecret(provider.ClientSecret, provider.ClientSecretFile)
|
||||||
provider.ClientSecret = secret
|
provider.ClientSecret = secret
|
||||||
provider.ClientSecretFile = ""
|
provider.ClientSecretFile = ""
|
||||||
|
|
||||||
if provider.RedirectURL == "" {
|
if provider.RedirectURL == "" {
|
||||||
provider.RedirectURL = app.runtime.AppURL + "/api/oauth/callback/" + id
|
provider.RedirectURL = app.context.appUrl + "/api/oauth/callback/" + name
|
||||||
}
|
}
|
||||||
|
|
||||||
app.runtime.OAuthProviders[id] = provider
|
app.context.oauthProviders[name] = provider
|
||||||
}
|
}
|
||||||
|
|
||||||
// set presets for built-in providers
|
for id, provider := range app.context.oauthProviders {
|
||||||
for id, provider := range app.runtime.OAuthProviders {
|
|
||||||
if provider.Name == "" {
|
if provider.Name == "" {
|
||||||
if name, ok := model.OverrideProviders[id]; ok {
|
if name, ok := model.OverrideProviders[id]; ok {
|
||||||
provider.Name = name
|
provider.Name = name
|
||||||
@@ -126,72 +94,65 @@ func (app *BootstrapApp) Setup() error {
|
|||||||
provider.Name = utils.Capitalize(id)
|
provider.Name = utils.Capitalize(id)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
app.runtime.OAuthProviders[id] = provider
|
app.context.oauthProviders[id] = provider
|
||||||
}
|
}
|
||||||
|
|
||||||
// setup oidc clients
|
// Setup OIDC clients
|
||||||
for id, client := range app.config.OIDC.Clients {
|
for id, client := range app.config.OIDC.Clients {
|
||||||
client.ID = id
|
client.ID = id
|
||||||
app.runtime.OIDCClients = append(app.runtime.OIDCClients, client)
|
app.context.oidcClients = append(app.context.oidcClients, client)
|
||||||
}
|
}
|
||||||
|
|
||||||
// cookie domain
|
// Get cookie domain
|
||||||
cookieDomainResolver := utils.GetCookieDomain
|
cookieDomain, err := utils.GetCookieDomain(app.context.appUrl)
|
||||||
|
|
||||||
if !app.config.Auth.SubdomainsEnabled {
|
|
||||||
app.log.App.Warn().Msg("Subdomains are disabled, using standalone cookie domain resolver which will not work with subdomains")
|
|
||||||
cookieDomainResolver = utils.GetStandaloneCookieDomain
|
|
||||||
}
|
|
||||||
|
|
||||||
cookieDomain, err := cookieDomainResolver(app.runtime.AppURL)
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to get cookie domain: %w", err)
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
app.runtime.CookieDomain = cookieDomain
|
app.context.cookieDomain = cookieDomain
|
||||||
|
|
||||||
// cookie names
|
// Cookie names
|
||||||
app.runtime.UUID = utils.GenerateUUID(appUrl.Hostname())
|
app.context.uuid = utils.GenerateUUID(appUrl.Hostname())
|
||||||
|
cookieId := strings.Split(app.context.uuid, "-")[0]
|
||||||
|
app.context.sessionCookieName = fmt.Sprintf("%s-%s", model.SessionCookieName, cookieId)
|
||||||
|
app.context.csrfCookieName = fmt.Sprintf("%s-%s", model.CSRFCookieName, cookieId)
|
||||||
|
app.context.redirectCookieName = fmt.Sprintf("%s-%s", model.RedirectCookieName, cookieId)
|
||||||
|
app.context.oauthSessionCookieName = fmt.Sprintf("%s-%s", model.OAuthSessionCookieName, cookieId)
|
||||||
|
|
||||||
cookieId := strings.Split(app.runtime.UUID, "-")[0] // first 8 characters of the uuid should be good enough
|
// Dumps
|
||||||
|
tlog.App.Trace().Interface("config", app.config).Msg("Config dump")
|
||||||
|
tlog.App.Trace().Interface("users", app.context.localUsers).Msg("Users dump")
|
||||||
|
tlog.App.Trace().Interface("oauthProviders", app.context.oauthProviders).Msg("OAuth providers dump")
|
||||||
|
tlog.App.Trace().Str("cookieDomain", app.context.cookieDomain).Msg("Cookie domain")
|
||||||
|
tlog.App.Trace().Str("sessionCookieName", app.context.sessionCookieName).Msg("Session cookie name")
|
||||||
|
tlog.App.Trace().Str("csrfCookieName", app.context.csrfCookieName).Msg("CSRF cookie name")
|
||||||
|
tlog.App.Trace().Str("redirectCookieName", app.context.redirectCookieName).Msg("Redirect cookie name")
|
||||||
|
|
||||||
app.runtime.SessionCookieName = fmt.Sprintf("%s-%s", model.SessionCookieName, cookieId)
|
// Database
|
||||||
app.runtime.CSRFCookieName = fmt.Sprintf("%s-%s", model.CSRFCookieName, cookieId)
|
db, err := app.SetupDatabase(app.config.Database.Path)
|
||||||
app.runtime.RedirectCookieName = fmt.Sprintf("%s-%s", model.RedirectCookieName, cookieId)
|
|
||||||
app.runtime.OAuthSessionCookieName = fmt.Sprintf("%s-%s", model.OAuthSessionCookieName, cookieId)
|
|
||||||
|
|
||||||
// database
|
|
||||||
err = app.SetupDatabase()
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to setup database: %w", err)
|
return fmt.Errorf("failed to setup database: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// after this point, we start initializing dependencies so it's a good time to setup a defer
|
// Queries
|
||||||
// to ensure that resources are cleaned up properly in case of an error during initialization
|
queries := repository.New(db)
|
||||||
defer func() {
|
|
||||||
app.cancel()
|
|
||||||
app.wg.Wait()
|
|
||||||
app.db.Close()
|
|
||||||
}()
|
|
||||||
|
|
||||||
// queries
|
// Services
|
||||||
queries := repository.New(app.db)
|
services, err := app.initServices(queries)
|
||||||
app.queries = queries
|
|
||||||
|
|
||||||
// services
|
|
||||||
err = app.setupServices()
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to initialize services: %w", err)
|
return fmt.Errorf("failed to initialize services: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// configured providers
|
app.services = services
|
||||||
configuredProviders := make([]model.Provider, 0)
|
|
||||||
|
|
||||||
for id, provider := range app.runtime.OAuthProviders {
|
// Configured providers
|
||||||
configuredProviders = append(configuredProviders, model.Provider{
|
configuredProviders := make([]controller.Provider, 0)
|
||||||
|
|
||||||
|
for id, provider := range app.context.oauthProviders {
|
||||||
|
configuredProviders = append(configuredProviders, controller.Provider{
|
||||||
Name: provider.Name,
|
Name: provider.Name,
|
||||||
ID: id,
|
ID: id,
|
||||||
OAuth: true,
|
OAuth: true,
|
||||||
@@ -202,171 +163,70 @@ func (app *BootstrapApp) Setup() error {
|
|||||||
return configuredProviders[i].Name < configuredProviders[j].Name
|
return configuredProviders[i].Name < configuredProviders[j].Name
|
||||||
})
|
})
|
||||||
|
|
||||||
if app.services.authService.LocalAuthConfigured() {
|
if services.authService.LocalAuthConfigured() {
|
||||||
configuredProviders = append(configuredProviders, model.Provider{
|
configuredProviders = append(configuredProviders, controller.Provider{
|
||||||
Name: "Local",
|
Name: "Local",
|
||||||
ID: "local",
|
ID: "local",
|
||||||
OAuth: false,
|
OAuth: false,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
if app.services.authService.LDAPAuthConfigured() {
|
if services.authService.LDAPAuthConfigured() {
|
||||||
configuredProviders = append(configuredProviders, model.Provider{
|
configuredProviders = append(configuredProviders, controller.Provider{
|
||||||
Name: "LDAP",
|
Name: "LDAP",
|
||||||
ID: "ldap",
|
ID: "ldap",
|
||||||
OAuth: false,
|
OAuth: false,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
tlog.App.Debug().Interface("providers", configuredProviders).Msg("Authentication providers")
|
||||||
|
|
||||||
if len(configuredProviders) == 0 {
|
if len(configuredProviders) == 0 {
|
||||||
return errors.New("no authentication providers configured")
|
return fmt.Errorf("no authentication providers configured")
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, provider := range configuredProviders {
|
app.context.configuredProviders = configuredProviders
|
||||||
app.log.App.Debug().Str("provider", provider.Name).Msg("Configured authentication provider")
|
|
||||||
}
|
|
||||||
|
|
||||||
app.runtime.ConfiguredProviders = configuredProviders
|
// Setup router
|
||||||
|
router, err := app.setupRouter()
|
||||||
// setup router
|
|
||||||
err = app.setupRouter()
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to setup routes: %w", err)
|
return fmt.Errorf("failed to setup routes: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// start db cleanup routine
|
// Start db cleanup routine
|
||||||
app.log.App.Debug().Msg("Starting database cleanup routine")
|
tlog.App.Debug().Msg("Starting database cleanup routine")
|
||||||
app.wg.Go(app.dbCleanupRoutine)
|
go app.dbCleanupRoutine(queries)
|
||||||
|
|
||||||
// if analytics are not disabled, start heartbeat
|
// If analytics are not disabled, start heartbeat
|
||||||
if app.config.Analytics.Enabled {
|
if app.config.Analytics.Enabled {
|
||||||
app.log.App.Debug().Msg("Starting heartbeat routine")
|
tlog.App.Debug().Msg("Starting heartbeat routine")
|
||||||
app.wg.Go(app.heartbeatRoutine)
|
go app.heartbeatRoutine()
|
||||||
}
|
}
|
||||||
|
|
||||||
// create err channel to listen for server errors
|
// If we have an socket path, bind to it
|
||||||
errChanLen := 0
|
if app.config.Server.SocketPath != "" {
|
||||||
|
if _, err := os.Stat(app.config.Server.SocketPath); err == nil {
|
||||||
runUnix := app.config.Server.SocketPath != ""
|
tlog.App.Info().Msgf("Removing existing socket file %s", app.config.Server.SocketPath)
|
||||||
runHTTP := app.config.Server.SocketPath == "" || app.config.Server.ConcurrentListenersEnabled
|
err := os.Remove(app.config.Server.SocketPath)
|
||||||
|
|
||||||
if runUnix {
|
|
||||||
errChanLen++
|
|
||||||
}
|
|
||||||
|
|
||||||
if runHTTP {
|
|
||||||
errChanLen++
|
|
||||||
}
|
|
||||||
|
|
||||||
errChan := make(chan error, errChanLen)
|
|
||||||
|
|
||||||
if app.config.Server.ConcurrentListenersEnabled {
|
|
||||||
app.log.App.Info().Msg("Concurrent listeners enabled, will run on all available listeners")
|
|
||||||
}
|
|
||||||
|
|
||||||
// serve unix
|
|
||||||
if runUnix {
|
|
||||||
app.wg.Go(func() {
|
|
||||||
if err := app.serveUnix(); err != nil {
|
|
||||||
errChan <- err
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// serve to http
|
|
||||||
if runHTTP {
|
|
||||||
app.wg.Go(func() {
|
|
||||||
if err := app.serveHTTP(); err != nil {
|
|
||||||
errChan <- err
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// monitor cancellation and server errors
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-app.ctx.Done():
|
|
||||||
app.log.App.Info().Msg("Oh, it's time for me to go, bye!")
|
|
||||||
return nil
|
|
||||||
case err := <-errChan:
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("server error: %w", err)
|
return fmt.Errorf("failed to remove existing socket file: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (app *BootstrapApp) serveHTTP() error {
|
tlog.App.Info().Msgf("Starting server on unix socket %s", app.config.Server.SocketPath)
|
||||||
address := fmt.Sprintf("%s:%d", app.config.Server.Address, app.config.Server.Port)
|
if err := router.RunUnix(app.config.Server.SocketPath); err != nil {
|
||||||
|
tlog.App.Fatal().Err(err).Msg("Failed to start server")
|
||||||
|
}
|
||||||
|
|
||||||
app.log.App.Info().Msgf("Starting server on %s", address)
|
|
||||||
|
|
||||||
server := &http.Server{
|
|
||||||
Addr: address,
|
|
||||||
Handler: app.router.Handler(),
|
|
||||||
}
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
<-app.ctx.Done()
|
|
||||||
app.log.App.Debug().Msg("Shutting down http listener")
|
|
||||||
server.Shutdown(app.ctx)
|
|
||||||
}()
|
|
||||||
|
|
||||||
err := server.ListenAndServe()
|
|
||||||
|
|
||||||
if err != nil && !errors.Is(err, http.ErrServerClosed) {
|
|
||||||
return fmt.Errorf("failed to start http listener: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (app *BootstrapApp) serveUnix() error {
|
|
||||||
if app.config.Server.SocketPath == "" {
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := os.Stat(app.config.Server.SocketPath)
|
// Start server
|
||||||
|
address := fmt.Sprintf("%s:%d", app.config.Server.Address, app.config.Server.Port)
|
||||||
if err == nil {
|
tlog.App.Info().Msgf("Starting server on %s", address)
|
||||||
app.log.App.Info().Msgf("Removing existing socket file %s", app.config.Server.SocketPath)
|
if err := router.Run(address); err != nil {
|
||||||
err := os.Remove(app.config.Server.SocketPath)
|
tlog.App.Fatal().Err(err).Msg("Failed to start server")
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to remove existing socket file: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
app.log.App.Info().Msgf("Starting server on unix socket %s", app.config.Server.SocketPath)
|
|
||||||
|
|
||||||
listener, err := net.Listen("unix", app.config.Server.SocketPath)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to create unix socket listener: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
server := &http.Server{
|
|
||||||
Handler: app.router.Handler(),
|
|
||||||
}
|
|
||||||
|
|
||||||
shutdown := func() {
|
|
||||||
server.Shutdown(app.ctx)
|
|
||||||
listener.Close()
|
|
||||||
os.Remove(app.config.Server.SocketPath)
|
|
||||||
}
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
<-app.ctx.Done()
|
|
||||||
app.log.App.Debug().Msg("Shutting down unix socket listener")
|
|
||||||
shutdown()
|
|
||||||
}()
|
|
||||||
|
|
||||||
err = server.Serve(listener)
|
|
||||||
|
|
||||||
if err != nil && !errors.Is(err, http.ErrServerClosed) {
|
|
||||||
shutdown()
|
|
||||||
return fmt.Errorf("failed to start unix socket listener: %w", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -376,20 +236,20 @@ func (app *BootstrapApp) heartbeatRoutine() {
|
|||||||
ticker := time.NewTicker(time.Duration(12) * time.Hour)
|
ticker := time.NewTicker(time.Duration(12) * time.Hour)
|
||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
|
|
||||||
type Heartbeat struct {
|
type heartbeat struct {
|
||||||
UUID string `json:"uuid"`
|
UUID string `json:"uuid"`
|
||||||
Version string `json:"version"`
|
Version string `json:"version"`
|
||||||
}
|
}
|
||||||
|
|
||||||
var body Heartbeat
|
var body heartbeat
|
||||||
|
|
||||||
body.UUID = app.runtime.UUID
|
body.UUID = app.context.uuid
|
||||||
body.Version = model.Version
|
body.Version = model.Version
|
||||||
|
|
||||||
bodyJson, err := json.Marshal(body)
|
bodyJson, err := json.Marshal(body)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
app.log.App.Error().Err(err).Msg("Failed to marshal heartbeat body, heartbeat routine will not start")
|
tlog.App.Error().Err(err).Msg("Failed to marshal heartbeat body")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -399,60 +259,43 @@ func (app *BootstrapApp) heartbeatRoutine() {
|
|||||||
|
|
||||||
heartbeatURL := model.APIServer + "/v1/instances/heartbeat"
|
heartbeatURL := model.APIServer + "/v1/instances/heartbeat"
|
||||||
|
|
||||||
for {
|
for range ticker.C {
|
||||||
select {
|
tlog.App.Debug().Msg("Sending heartbeat")
|
||||||
case <-ticker.C:
|
|
||||||
app.log.App.Debug().Msg("Sending heartbeat")
|
|
||||||
|
|
||||||
req, err := http.NewRequest(http.MethodPost, heartbeatURL, bytes.NewReader(bodyJson))
|
req, err := http.NewRequest(http.MethodPost, heartbeatURL, bytes.NewReader(bodyJson))
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
app.log.App.Error().Err(err).Msg("Failed to create heartbeat request")
|
tlog.App.Error().Err(err).Msg("Failed to create heartbeat request")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
req.Header.Add("Content-Type", "application/json")
|
req.Header.Add("Content-Type", "application/json")
|
||||||
|
|
||||||
res, err := client.Do(req)
|
res, err := client.Do(req)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
app.log.App.Error().Err(err).Msg("Failed to send heartbeat")
|
tlog.App.Error().Err(err).Msg("Failed to send heartbeat")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
res.Body.Close()
|
res.Body.Close()
|
||||||
|
|
||||||
if res.StatusCode != 200 && res.StatusCode != 201 {
|
if res.StatusCode != 200 && res.StatusCode != 201 {
|
||||||
app.log.App.Debug().Str("status", res.Status).Msg("Heartbeat returned non-200/201 status")
|
tlog.App.Debug().Str("status", res.Status).Msg("Heartbeat returned non-200/201 status")
|
||||||
}
|
|
||||||
case <-app.ctx.Done():
|
|
||||||
app.log.App.Debug().Msg("Stopping heartbeat routine")
|
|
||||||
ticker.Stop()
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (app *BootstrapApp) dbCleanupRoutine() {
|
func (app *BootstrapApp) dbCleanupRoutine(queries *repository.Queries) {
|
||||||
ticker := time.NewTicker(time.Duration(30) * time.Minute)
|
ticker := time.NewTicker(time.Duration(30) * time.Minute)
|
||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
for {
|
for range ticker.C {
|
||||||
select {
|
tlog.App.Debug().Msg("Cleaning up old database sessions")
|
||||||
case <-ticker.C:
|
err := queries.DeleteExpiredSessions(ctx, time.Now().Unix())
|
||||||
app.log.App.Debug().Msg("Running database cleanup")
|
if err != nil {
|
||||||
|
tlog.App.Error().Err(err).Msg("Failed to clean up old database sessions")
|
||||||
err := app.queries.DeleteExpiredSessions(app.ctx, time.Now().Unix())
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
app.log.App.Error().Err(err).Msg("Failed to delete expired sessions")
|
|
||||||
}
|
|
||||||
|
|
||||||
app.log.App.Debug().Msg("Database cleanup completed")
|
|
||||||
case <-app.ctx.Done():
|
|
||||||
app.log.App.Debug().Msg("Stopping database cleanup routine")
|
|
||||||
ticker.Stop()
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,26 +14,19 @@ import (
|
|||||||
_ "modernc.org/sqlite"
|
_ "modernc.org/sqlite"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (app *BootstrapApp) SetupDatabase() error {
|
func (app *BootstrapApp) SetupDatabase(databasePath string) (*sql.DB, error) {
|
||||||
dir := filepath.Dir(app.config.Database.Path)
|
dir := filepath.Dir(databasePath)
|
||||||
|
|
||||||
if err := os.MkdirAll(dir, 0750); err != nil {
|
if err := os.MkdirAll(dir, 0750); err != nil {
|
||||||
return fmt.Errorf("failed to create database directory %s: %w", dir, err)
|
return nil, fmt.Errorf("failed to create database directory %s: %w", dir, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
db, err := sql.Open("sqlite", app.config.Database.Path)
|
db, err := sql.Open("sqlite", databasePath)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to open database: %w", err)
|
return nil, fmt.Errorf("failed to open database: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close the database if there is an error during migration
|
|
||||||
defer func() {
|
|
||||||
if err != nil {
|
|
||||||
db.Close()
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
// Limit to 1 connection to sequence writes, this may need to be revisited in the future
|
// Limit to 1 connection to sequence writes, this may need to be revisited in the future
|
||||||
// if the sqlite connection starts being a bottleneck
|
// if the sqlite connection starts being a bottleneck
|
||||||
db.SetMaxOpenConns(1)
|
db.SetMaxOpenConns(1)
|
||||||
@@ -41,29 +34,24 @@ func (app *BootstrapApp) SetupDatabase() error {
|
|||||||
migrations, err := iofs.New(assets.Migrations, "migrations")
|
migrations, err := iofs.New(assets.Migrations, "migrations")
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to create migrations: %w", err)
|
return nil, fmt.Errorf("failed to create migrations: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
target, err := sqlite3.WithInstance(db, &sqlite3.Config{})
|
target, err := sqlite3.WithInstance(db, &sqlite3.Config{})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to create sqlite3 instance: %w", err)
|
return nil, fmt.Errorf("failed to create sqlite3 instance: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
migrator, err := migrate.NewWithInstance("iofs", migrations, "sqlite3", target)
|
migrator, err := migrate.NewWithInstance("iofs", migrations, "sqlite3", target)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to create migrator: %w", err)
|
return nil, fmt.Errorf("failed to create migrator: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := migrator.Up(); err != nil && err != migrate.ErrNoChange {
|
if err := migrator.Up(); err != nil && err != migrate.ErrNoChange {
|
||||||
return fmt.Errorf("failed to migrate database: %w", err)
|
return nil, fmt.Errorf("failed to migrate database: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
app.db = db
|
return db, nil
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (app *BootstrapApp) GetDB() *sql.DB {
|
|
||||||
return app.db
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,16 +2,21 @@ package bootstrap
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"slices"
|
||||||
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/controller"
|
"github.com/tinyauthapp/tinyauth/internal/controller"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/middleware"
|
"github.com/tinyauthapp/tinyauth/internal/middleware"
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (app *BootstrapApp) setupRouter() error {
|
var DEV_MODES = []string{"main", "test", "development"}
|
||||||
// we don't want gin debug mode
|
|
||||||
gin.SetMode(gin.ReleaseMode)
|
func (app *BootstrapApp) setupRouter() (*gin.Engine, error) {
|
||||||
|
if !slices.Contains(DEV_MODES, model.Version) {
|
||||||
|
gin.SetMode(gin.ReleaseMode)
|
||||||
|
}
|
||||||
|
|
||||||
engine := gin.New()
|
engine := gin.New()
|
||||||
engine.Use(gin.Recovery())
|
engine.Use(gin.Recovery())
|
||||||
@@ -20,36 +25,100 @@ func (app *BootstrapApp) setupRouter() error {
|
|||||||
err := engine.SetTrustedProxies(app.config.Auth.TrustedProxies)
|
err := engine.SetTrustedProxies(app.config.Auth.TrustedProxies)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to set trusted proxies: %w", err)
|
return nil, fmt.Errorf("failed to set trusted proxies: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
contextMiddleware := middleware.NewContextMiddleware(app.log, app.runtime, app.services.authService, app.services.oauthBrokerService)
|
contextMiddleware := middleware.NewContextMiddleware(middleware.ContextMiddlewareConfig{
|
||||||
engine.Use(contextMiddleware.Middleware())
|
CookieDomain: app.context.cookieDomain,
|
||||||
|
SessionCookieName: app.context.sessionCookieName,
|
||||||
|
}, app.services.authService, app.services.oauthBrokerService)
|
||||||
|
|
||||||
uiMiddleware, err := middleware.NewUIMiddleware()
|
err := contextMiddleware.Init()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to initialize UI middleware: %w", err)
|
return nil, fmt.Errorf("failed to initialize context middleware: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
engine.Use(contextMiddleware.Middleware())
|
||||||
|
|
||||||
|
uiMiddleware := middleware.NewUIMiddleware()
|
||||||
|
|
||||||
|
err = uiMiddleware.Init()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to initialize UI middleware: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
engine.Use(uiMiddleware.Middleware())
|
engine.Use(uiMiddleware.Middleware())
|
||||||
|
|
||||||
zerologMiddleware := middleware.NewZerologMiddleware(app.log)
|
zerologMiddleware := middleware.NewZerologMiddleware()
|
||||||
|
|
||||||
|
err = zerologMiddleware.Init()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to initialize zerolog middleware: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
engine.Use(zerologMiddleware.Middleware())
|
engine.Use(zerologMiddleware.Middleware())
|
||||||
|
|
||||||
apiRouter := engine.Group("/api")
|
apiRouter := engine.Group("/api")
|
||||||
|
|
||||||
controller.NewContextController(app.log, app.config, app.runtime, apiRouter)
|
contextController := controller.NewContextController(controller.ContextControllerConfig{
|
||||||
controller.NewOAuthController(app.log, app.config, app.runtime, apiRouter, app.services.authService)
|
Providers: app.context.configuredProviders,
|
||||||
controller.NewOIDCController(app.log, app.services.oidcService, app.runtime, apiRouter)
|
Title: app.config.UI.Title,
|
||||||
controller.NewProxyController(app.log, app.runtime, apiRouter, app.services.accessControlService, app.services.authService)
|
AppURL: app.config.AppURL,
|
||||||
controller.NewUserController(app.log, app.runtime, apiRouter, app.services.authService)
|
CookieDomain: app.context.cookieDomain,
|
||||||
controller.NewResourcesController(app.config, &engine.RouterGroup)
|
ForgotPasswordMessage: app.config.UI.ForgotPasswordMessage,
|
||||||
controller.NewHealthController(apiRouter)
|
BackgroundImage: app.config.UI.BackgroundImage,
|
||||||
controller.NewWellKnownController(app.services.oidcService, &engine.RouterGroup)
|
OAuthAutoRedirect: app.config.OAuth.AutoRedirect,
|
||||||
|
WarningsEnabled: app.config.UI.WarningsEnabled,
|
||||||
|
}, apiRouter)
|
||||||
|
|
||||||
app.router = engine
|
contextController.SetupRoutes()
|
||||||
return nil
|
|
||||||
|
oauthController := controller.NewOAuthController(controller.OAuthControllerConfig{
|
||||||
|
AppURL: app.config.AppURL,
|
||||||
|
SecureCookie: app.config.Auth.SecureCookie,
|
||||||
|
CSRFCookieName: app.context.csrfCookieName,
|
||||||
|
RedirectCookieName: app.context.redirectCookieName,
|
||||||
|
CookieDomain: app.context.cookieDomain,
|
||||||
|
OAuthSessionCookieName: app.context.oauthSessionCookieName,
|
||||||
|
}, apiRouter, app.services.authService)
|
||||||
|
|
||||||
|
oauthController.SetupRoutes()
|
||||||
|
|
||||||
|
oidcController := controller.NewOIDCController(controller.OIDCControllerConfig{}, app.services.oidcService, apiRouter)
|
||||||
|
|
||||||
|
oidcController.SetupRoutes()
|
||||||
|
|
||||||
|
proxyController := controller.NewProxyController(controller.ProxyControllerConfig{
|
||||||
|
AppURL: app.config.AppURL,
|
||||||
|
}, apiRouter, app.services.accessControlService, app.services.authService)
|
||||||
|
|
||||||
|
proxyController.SetupRoutes()
|
||||||
|
|
||||||
|
userController := controller.NewUserController(controller.UserControllerConfig{
|
||||||
|
CookieDomain: app.context.cookieDomain,
|
||||||
|
SessionCookieName: app.context.sessionCookieName,
|
||||||
|
}, apiRouter, app.services.authService)
|
||||||
|
|
||||||
|
userController.SetupRoutes()
|
||||||
|
|
||||||
|
resourcesController := controller.NewResourcesController(controller.ResourcesControllerConfig{
|
||||||
|
Path: app.config.Resources.Path,
|
||||||
|
Enabled: app.config.Resources.Enabled,
|
||||||
|
}, &engine.RouterGroup)
|
||||||
|
|
||||||
|
resourcesController.SetupRoutes()
|
||||||
|
|
||||||
|
healthController := controller.NewHealthController(apiRouter)
|
||||||
|
|
||||||
|
healthController.SetupRoutes()
|
||||||
|
|
||||||
|
wellknownController := controller.NewWellKnownController(controller.WellKnownControllerConfig{}, app.services.oidcService, engine)
|
||||||
|
|
||||||
|
wellknownController.SetupRoutes()
|
||||||
|
|
||||||
|
return engine, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,66 +1,130 @@
|
|||||||
package bootstrap
|
package bootstrap
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (app *BootstrapApp) setupServices() error {
|
type Services struct {
|
||||||
ldapService, err := service.NewLdapService(app.log, app.config, app.ctx, &app.wg)
|
accessControlService *service.AccessControlsService
|
||||||
|
authService *service.AuthService
|
||||||
|
dockerService *service.DockerService
|
||||||
|
kubernetesService *service.KubernetesService
|
||||||
|
ldapService *service.LdapService
|
||||||
|
oauthBrokerService *service.OAuthBrokerService
|
||||||
|
oidcService *service.OIDCService
|
||||||
|
}
|
||||||
|
|
||||||
|
func (app *BootstrapApp) initServices(queries *repository.Queries) (Services, error) {
|
||||||
|
services := Services{}
|
||||||
|
|
||||||
|
ldapService := service.NewLdapService(service.LdapServiceConfig{
|
||||||
|
Address: app.config.LDAP.Address,
|
||||||
|
BindDN: app.config.LDAP.BindDN,
|
||||||
|
BindPassword: app.config.LDAP.BindPassword,
|
||||||
|
BaseDN: app.config.LDAP.BaseDN,
|
||||||
|
Insecure: app.config.LDAP.Insecure,
|
||||||
|
SearchFilter: app.config.LDAP.SearchFilter,
|
||||||
|
AuthCert: app.config.LDAP.AuthCert,
|
||||||
|
AuthKey: app.config.LDAP.AuthKey,
|
||||||
|
})
|
||||||
|
|
||||||
|
err := ldapService.Init()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
app.log.App.Warn().Err(err).Msg("Failed to initialize LDAP connection, will continue without it")
|
tlog.App.Warn().Err(err).Msg("Failed to setup LDAP service, starting without it")
|
||||||
|
ldapService.Unconfigure()
|
||||||
}
|
}
|
||||||
|
|
||||||
app.services.ldapService = ldapService
|
services.ldapService = ldapService
|
||||||
|
|
||||||
|
var labelProvider service.LabelProvider
|
||||||
|
var dockerService *service.DockerService
|
||||||
|
var kubernetesService *service.KubernetesService
|
||||||
|
|
||||||
useKubernetes := app.config.LabelProvider == "kubernetes" ||
|
useKubernetes := app.config.LabelProvider == "kubernetes" ||
|
||||||
(app.config.LabelProvider == "auto" && os.Getenv("KUBERNETES_SERVICE_HOST") != "")
|
(app.config.LabelProvider == "auto" && os.Getenv("KUBERNETES_SERVICE_HOST") != "")
|
||||||
|
|
||||||
var labelProvider service.LabelProvider
|
|
||||||
|
|
||||||
if useKubernetes {
|
if useKubernetes {
|
||||||
app.log.App.Debug().Msg("Using Kubernetes label provider")
|
tlog.App.Debug().Msg("Using Kubernetes label provider")
|
||||||
|
kubernetesService = service.NewKubernetesService()
|
||||||
kubernetesService, err := service.NewKubernetesService(app.log, app.ctx, &app.wg)
|
err = kubernetesService.Init()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to initialize kubernetes service: %w", err)
|
return Services{}, err
|
||||||
}
|
}
|
||||||
|
services.kubernetesService = kubernetesService
|
||||||
app.services.kubernetesService = kubernetesService
|
|
||||||
labelProvider = kubernetesService
|
labelProvider = kubernetesService
|
||||||
} else {
|
} else {
|
||||||
app.log.App.Debug().Msg("Using Docker label provider")
|
tlog.App.Debug().Msg("Using Docker label provider")
|
||||||
|
dockerService = service.NewDockerService()
|
||||||
dockerService, err := service.NewDockerService(app.log, app.ctx, &app.wg)
|
err = dockerService.Init()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to initialize docker service: %w", err)
|
return Services{}, err
|
||||||
}
|
}
|
||||||
|
services.dockerService = dockerService
|
||||||
app.services.dockerService = dockerService
|
|
||||||
labelProvider = dockerService
|
labelProvider = dockerService
|
||||||
}
|
}
|
||||||
|
|
||||||
accessControlsService := service.NewAccessControlsService(app.log, &labelProvider, app.config.Apps)
|
accessControlsService := service.NewAccessControlsService(labelProvider, app.config.Apps)
|
||||||
app.services.accessControlService = accessControlsService
|
|
||||||
|
|
||||||
oauthBrokerService := service.NewOAuthBrokerService(app.log, app.runtime.OAuthProviders, app.ctx)
|
err = accessControlsService.Init()
|
||||||
app.services.oauthBrokerService = oauthBrokerService
|
|
||||||
|
|
||||||
authService := service.NewAuthService(app.log, app.config, app.runtime, app.ctx, &app.wg, app.services.ldapService, app.queries, app.services.oauthBrokerService)
|
|
||||||
app.services.authService = authService
|
|
||||||
|
|
||||||
oidcService, err := service.NewOIDCService(app.log, app.config, app.runtime, app.queries, app.ctx, &app.wg)
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to initialize oidc service: %w", err)
|
return Services{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
app.services.oidcService = oidcService
|
services.accessControlService = accessControlsService
|
||||||
|
|
||||||
return nil
|
oauthBrokerService := service.NewOAuthBrokerService(app.context.oauthProviders)
|
||||||
|
|
||||||
|
err = oauthBrokerService.Init()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return Services{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
services.oauthBrokerService = oauthBrokerService
|
||||||
|
|
||||||
|
authService := service.NewAuthService(service.AuthServiceConfig{
|
||||||
|
LocalUsers: app.context.localUsers,
|
||||||
|
OauthWhitelist: app.config.OAuth.Whitelist,
|
||||||
|
SessionExpiry: app.config.Auth.SessionExpiry,
|
||||||
|
SessionMaxLifetime: app.config.Auth.SessionMaxLifetime,
|
||||||
|
SecureCookie: app.config.Auth.SecureCookie,
|
||||||
|
CookieDomain: app.context.cookieDomain,
|
||||||
|
LoginTimeout: app.config.Auth.LoginTimeout,
|
||||||
|
LoginMaxRetries: app.config.Auth.LoginMaxRetries,
|
||||||
|
SessionCookieName: app.context.sessionCookieName,
|
||||||
|
IP: app.config.Auth.IP,
|
||||||
|
LDAPGroupsCacheTTL: app.config.LDAP.GroupCacheTTL,
|
||||||
|
}, services.ldapService, queries, services.oauthBrokerService)
|
||||||
|
|
||||||
|
err = authService.Init()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return Services{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
services.authService = authService
|
||||||
|
|
||||||
|
oidcService := service.NewOIDCService(service.OIDCServiceConfig{
|
||||||
|
Clients: app.config.OIDC.Clients,
|
||||||
|
PrivateKeyPath: app.config.OIDC.PrivateKeyPath,
|
||||||
|
PublicKeyPath: app.config.OIDC.PublicKeyPath,
|
||||||
|
Issuer: app.config.AppURL,
|
||||||
|
SessionExpiry: app.config.Auth.SessionExpiry,
|
||||||
|
}, queries)
|
||||||
|
|
||||||
|
err = oidcService.Init()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return Services{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
services.oidcService = oidcService
|
||||||
|
|
||||||
|
return services, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import (
|
|||||||
"net/url"
|
"net/url"
|
||||||
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
@@ -24,52 +24,62 @@ type UserContextResponse struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type AppContextResponse struct {
|
type AppContextResponse struct {
|
||||||
Status int `json:"status"`
|
Status int `json:"status"`
|
||||||
Message string `json:"message"`
|
Message string `json:"message"`
|
||||||
Providers []model.Provider `json:"providers"`
|
Providers []Provider `json:"providers"`
|
||||||
Title string `json:"title"`
|
Title string `json:"title"`
|
||||||
AppURL string `json:"appUrl"`
|
AppURL string `json:"appUrl"`
|
||||||
CookieDomain string `json:"cookieDomain"`
|
CookieDomain string `json:"cookieDomain"`
|
||||||
ForgotPasswordMessage string `json:"forgotPasswordMessage"`
|
ForgotPasswordMessage string `json:"forgotPasswordMessage"`
|
||||||
BackgroundImage string `json:"backgroundImage"`
|
BackgroundImage string `json:"backgroundImage"`
|
||||||
OAuthAutoRedirect string `json:"oauthAutoRedirect"`
|
OAuthAutoRedirect string `json:"oauthAutoRedirect"`
|
||||||
WarningsEnabled bool `json:"warningsEnabled"`
|
WarningsEnabled bool `json:"warningsEnabled"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Provider struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
ID string `json:"id"`
|
||||||
|
OAuth bool `json:"oauth"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ContextControllerConfig struct {
|
||||||
|
Providers []Provider
|
||||||
|
Title string
|
||||||
|
AppURL string
|
||||||
|
CookieDomain string
|
||||||
|
ForgotPasswordMessage string
|
||||||
|
BackgroundImage string
|
||||||
|
OAuthAutoRedirect string
|
||||||
|
WarningsEnabled bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type ContextController struct {
|
type ContextController struct {
|
||||||
log *logger.Logger
|
config ContextControllerConfig
|
||||||
config model.Config
|
router *gin.RouterGroup
|
||||||
runtime model.RuntimeConfig
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewContextController(
|
func NewContextController(config ContextControllerConfig, router *gin.RouterGroup) *ContextController {
|
||||||
log *logger.Logger,
|
if !config.WarningsEnabled {
|
||||||
config model.Config,
|
tlog.App.Warn().Msg("UI warnings are disabled. This may expose users to security risks. Proceed with caution.")
|
||||||
runtimeConfig model.RuntimeConfig,
|
|
||||||
router *gin.RouterGroup,
|
|
||||||
) *ContextController {
|
|
||||||
controller := &ContextController{
|
|
||||||
log: log,
|
|
||||||
config: config,
|
|
||||||
runtime: runtimeConfig,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if !config.UI.WarningsEnabled {
|
return &ContextController{
|
||||||
log.App.Warn().Msg("UI warnings are disabled. This may lead to security issues if you are not careful. Make sure to enable warnings in production environments.")
|
config: config,
|
||||||
|
router: router,
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
contextGroup := router.Group("/context")
|
func (controller *ContextController) SetupRoutes() {
|
||||||
|
contextGroup := controller.router.Group("/context")
|
||||||
contextGroup.GET("/user", controller.userContextHandler)
|
contextGroup.GET("/user", controller.userContextHandler)
|
||||||
contextGroup.GET("/app", controller.appContextHandler)
|
contextGroup.GET("/app", controller.appContextHandler)
|
||||||
|
|
||||||
return controller
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (controller *ContextController) userContextHandler(c *gin.Context) {
|
func (controller *ContextController) userContextHandler(c *gin.Context) {
|
||||||
context, err := new(model.UserContext).NewFromGin(c)
|
context, err := new(model.UserContext).NewFromGin(c)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Error().Err(err).Msg("Failed to create user context from request")
|
tlog.App.Debug().Err(err).Msg("No user context found in request")
|
||||||
c.JSON(200, UserContextResponse{
|
c.JSON(200, UserContextResponse{
|
||||||
Status: 401,
|
Status: 401,
|
||||||
Message: "Unauthorized",
|
Message: "Unauthorized",
|
||||||
@@ -95,10 +105,9 @@ func (controller *ContextController) userContextHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (controller *ContextController) appContextHandler(c *gin.Context) {
|
func (controller *ContextController) appContextHandler(c *gin.Context) {
|
||||||
appUrl, err := url.Parse(controller.runtime.AppURL)
|
appUrl, err := url.Parse(controller.config.AppURL)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Error().Err(err).Msg("Failed to parse app URL")
|
tlog.App.Error().Err(err).Msg("Failed to parse app URL")
|
||||||
c.JSON(500, gin.H{
|
c.JSON(500, gin.H{
|
||||||
"status": 500,
|
"status": 500,
|
||||||
"message": "Internal Server Error",
|
"message": "Internal Server Error",
|
||||||
@@ -109,13 +118,13 @@ func (controller *ContextController) appContextHandler(c *gin.Context) {
|
|||||||
c.JSON(200, AppContextResponse{
|
c.JSON(200, AppContextResponse{
|
||||||
Status: 200,
|
Status: 200,
|
||||||
Message: "Success",
|
Message: "Success",
|
||||||
Providers: controller.runtime.ConfiguredProviders,
|
Providers: controller.config.Providers,
|
||||||
Title: controller.config.UI.Title,
|
Title: controller.config.Title,
|
||||||
AppURL: fmt.Sprintf("%s://%s", appUrl.Scheme, appUrl.Host),
|
AppURL: fmt.Sprintf("%s://%s", appUrl.Scheme, appUrl.Host),
|
||||||
CookieDomain: controller.runtime.CookieDomain,
|
CookieDomain: controller.config.CookieDomain,
|
||||||
ForgotPasswordMessage: controller.config.UI.ForgotPasswordMessage,
|
ForgotPasswordMessage: controller.config.ForgotPasswordMessage,
|
||||||
BackgroundImage: controller.config.UI.BackgroundImage,
|
BackgroundImage: controller.config.BackgroundImage,
|
||||||
OAuthAutoRedirect: controller.config.OAuth.AutoRedirect,
|
OAuthAutoRedirect: controller.config.OAuthAutoRedirect,
|
||||||
WarningsEnabled: controller.config.UI.WarningsEnabled,
|
WarningsEnabled: controller.config.WarningsEnabled,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,19 +8,30 @@ import (
|
|||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/controller"
|
"github.com/tinyauthapp/tinyauth/internal/controller"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/test"
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestContextController(t *testing.T) {
|
func TestContextController(t *testing.T) {
|
||||||
log := logger.NewLogger().WithTestConfig()
|
tlog.NewTestLogger().Init()
|
||||||
log.Init()
|
controllerConfig := controller.ContextControllerConfig{
|
||||||
|
Providers: []controller.Provider{
|
||||||
cfg, runtime := test.CreateTestConfigs(t)
|
{
|
||||||
|
Name: "Local",
|
||||||
|
ID: "local",
|
||||||
|
OAuth: false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Title: "Tinyauth",
|
||||||
|
AppURL: "https://tinyauth.example.com",
|
||||||
|
CookieDomain: "example.com",
|
||||||
|
ForgotPasswordMessage: "foo",
|
||||||
|
BackgroundImage: "/background.jpg",
|
||||||
|
OAuthAutoRedirect: "none",
|
||||||
|
WarningsEnabled: true,
|
||||||
|
}
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
description string
|
description string
|
||||||
@@ -36,17 +47,17 @@ func TestContextController(t *testing.T) {
|
|||||||
expectedAppContextResponse := controller.AppContextResponse{
|
expectedAppContextResponse := controller.AppContextResponse{
|
||||||
Status: 200,
|
Status: 200,
|
||||||
Message: "Success",
|
Message: "Success",
|
||||||
Providers: runtime.ConfiguredProviders,
|
Providers: controllerConfig.Providers,
|
||||||
Title: cfg.UI.Title,
|
Title: controllerConfig.Title,
|
||||||
AppURL: runtime.AppURL,
|
AppURL: controllerConfig.AppURL,
|
||||||
CookieDomain: runtime.CookieDomain,
|
CookieDomain: controllerConfig.CookieDomain,
|
||||||
ForgotPasswordMessage: cfg.UI.ForgotPasswordMessage,
|
ForgotPasswordMessage: controllerConfig.ForgotPasswordMessage,
|
||||||
BackgroundImage: cfg.UI.BackgroundImage,
|
BackgroundImage: controllerConfig.BackgroundImage,
|
||||||
OAuthAutoRedirect: cfg.OAuth.AutoRedirect,
|
OAuthAutoRedirect: controllerConfig.OAuthAutoRedirect,
|
||||||
WarningsEnabled: cfg.UI.WarningsEnabled,
|
WarningsEnabled: controllerConfig.WarningsEnabled,
|
||||||
}
|
}
|
||||||
bytes, err := json.Marshal(expectedAppContextResponse)
|
bytes, err := json.Marshal(expectedAppContextResponse)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
return string(bytes)
|
return string(bytes)
|
||||||
}(),
|
}(),
|
||||||
},
|
},
|
||||||
@@ -60,7 +71,7 @@ func TestContextController(t *testing.T) {
|
|||||||
Message: "Unauthorized",
|
Message: "Unauthorized",
|
||||||
}
|
}
|
||||||
bytes, err := json.Marshal(expectedUserContextResponse)
|
bytes, err := json.Marshal(expectedUserContextResponse)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
return string(bytes)
|
return string(bytes)
|
||||||
}(),
|
}(),
|
||||||
},
|
},
|
||||||
@@ -75,7 +86,7 @@ func TestContextController(t *testing.T) {
|
|||||||
BaseContext: model.BaseContext{
|
BaseContext: model.BaseContext{
|
||||||
Username: "johndoe",
|
Username: "johndoe",
|
||||||
Name: "John Doe",
|
Name: "John Doe",
|
||||||
Email: utils.CompileUserEmail("johndoe", runtime.CookieDomain),
|
Email: utils.CompileUserEmail("johndoe", controllerConfig.CookieDomain),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
@@ -89,11 +100,11 @@ func TestContextController(t *testing.T) {
|
|||||||
IsLoggedIn: true,
|
IsLoggedIn: true,
|
||||||
Username: "johndoe",
|
Username: "johndoe",
|
||||||
Name: "John Doe",
|
Name: "John Doe",
|
||||||
Email: utils.CompileUserEmail("johndoe", runtime.CookieDomain),
|
Email: utils.CompileUserEmail("johndoe", controllerConfig.CookieDomain),
|
||||||
Provider: "local",
|
Provider: "local",
|
||||||
}
|
}
|
||||||
bytes, err := json.Marshal(expectedUserContextResponse)
|
bytes, err := json.Marshal(expectedUserContextResponse)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
return string(bytes)
|
return string(bytes)
|
||||||
}(),
|
}(),
|
||||||
},
|
},
|
||||||
@@ -110,12 +121,13 @@ func TestContextController(t *testing.T) {
|
|||||||
group := router.Group("/api")
|
group := router.Group("/api")
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
controller.NewContextController(log, cfg, runtime, group)
|
contextController := controller.NewContextController(controllerConfig, group)
|
||||||
|
contextController.SetupRoutes()
|
||||||
|
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
|
|
||||||
request, err := http.NewRequest("GET", test.path, nil)
|
request, err := http.NewRequest("GET", test.path, nil)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
router.ServeHTTP(recorder, request)
|
router.ServeHTTP(recorder, request)
|
||||||
|
|
||||||
|
|||||||
@@ -3,15 +3,18 @@ package controller
|
|||||||
import "github.com/gin-gonic/gin"
|
import "github.com/gin-gonic/gin"
|
||||||
|
|
||||||
type HealthController struct {
|
type HealthController struct {
|
||||||
|
router *gin.RouterGroup
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewHealthController(router *gin.RouterGroup) *HealthController {
|
func NewHealthController(router *gin.RouterGroup) *HealthController {
|
||||||
controller := &HealthController{}
|
return &HealthController{
|
||||||
|
router: router,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
router.GET("/healthz", controller.healthHandler)
|
func (controller *HealthController) SetupRoutes() {
|
||||||
router.HEAD("/healthz", controller.healthHandler)
|
controller.router.GET("/healthz", controller.healthHandler)
|
||||||
|
controller.router.HEAD("/healthz", controller.healthHandler)
|
||||||
return controller
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (controller *HealthController) healthHandler(c *gin.Context) {
|
func (controller *HealthController) healthHandler(c *gin.Context) {
|
||||||
|
|||||||
@@ -7,12 +7,13 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/controller"
|
"github.com/tinyauthapp/tinyauth/internal/controller"
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestHealthController(t *testing.T) {
|
func TestHealthController(t *testing.T) {
|
||||||
|
tlog.NewTestLogger().Init()
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
description string
|
description string
|
||||||
path string
|
path string
|
||||||
@@ -29,7 +30,7 @@ func TestHealthController(t *testing.T) {
|
|||||||
"message": "Healthy",
|
"message": "Healthy",
|
||||||
}
|
}
|
||||||
bytes, err := json.Marshal(expectedHealthResponse)
|
bytes, err := json.Marshal(expectedHealthResponse)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
return string(bytes)
|
return string(bytes)
|
||||||
}(),
|
}(),
|
||||||
},
|
},
|
||||||
@@ -43,7 +44,7 @@ func TestHealthController(t *testing.T) {
|
|||||||
"message": "Healthy",
|
"message": "Healthy",
|
||||||
}
|
}
|
||||||
bytes, err := json.Marshal(expectedHealthResponse)
|
bytes, err := json.Marshal(expectedHealthResponse)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
return string(bytes)
|
return string(bytes)
|
||||||
}(),
|
}(),
|
||||||
},
|
},
|
||||||
@@ -55,12 +56,13 @@ func TestHealthController(t *testing.T) {
|
|||||||
group := router.Group("/api")
|
group := router.Group("/api")
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
controller.NewHealthController(group)
|
healthController := controller.NewHealthController(group)
|
||||||
|
healthController.SetupRoutes()
|
||||||
|
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
|
|
||||||
request, err := http.NewRequest(test.method, test.path, nil)
|
request, err := http.NewRequest(test.method, test.path, nil)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
router.ServeHTTP(recorder, request)
|
router.ServeHTTP(recorder, request)
|
||||||
|
|
||||||
|
|||||||
@@ -6,11 +6,10 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/google/go-querystring/query"
|
"github.com/google/go-querystring/query"
|
||||||
@@ -20,32 +19,33 @@ type OAuthRequest struct {
|
|||||||
Provider string `uri:"provider" binding:"required"`
|
Provider string `uri:"provider" binding:"required"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type OAuthController struct {
|
type OAuthControllerConfig struct {
|
||||||
log *logger.Logger
|
CSRFCookieName string
|
||||||
config model.Config
|
OAuthSessionCookieName string
|
||||||
runtime model.RuntimeConfig
|
RedirectCookieName string
|
||||||
auth *service.AuthService
|
SecureCookie bool
|
||||||
|
AppURL string
|
||||||
|
CookieDomain string
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewOAuthController(
|
type OAuthController struct {
|
||||||
log *logger.Logger,
|
config OAuthControllerConfig
|
||||||
config model.Config,
|
router *gin.RouterGroup
|
||||||
runtimeConfig model.RuntimeConfig,
|
auth *service.AuthService
|
||||||
router *gin.RouterGroup,
|
}
|
||||||
auth *service.AuthService,
|
|
||||||
) *OAuthController {
|
|
||||||
controller := &OAuthController{
|
|
||||||
log: log,
|
|
||||||
config: config,
|
|
||||||
runtime: runtimeConfig,
|
|
||||||
auth: auth,
|
|
||||||
}
|
|
||||||
|
|
||||||
oauthGroup := router.Group("/oauth")
|
func NewOAuthController(config OAuthControllerConfig, router *gin.RouterGroup, auth *service.AuthService) *OAuthController {
|
||||||
|
return &OAuthController{
|
||||||
|
config: config,
|
||||||
|
router: router,
|
||||||
|
auth: auth,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (controller *OAuthController) SetupRoutes() {
|
||||||
|
oauthGroup := controller.router.Group("/oauth")
|
||||||
oauthGroup.GET("/url/:provider", controller.oauthURLHandler)
|
oauthGroup.GET("/url/:provider", controller.oauthURLHandler)
|
||||||
oauthGroup.GET("/callback/:provider", controller.oauthCallbackHandler)
|
oauthGroup.GET("/callback/:provider", controller.oauthCallbackHandler)
|
||||||
|
|
||||||
return controller
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
|
func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
|
||||||
@@ -53,7 +53,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
|
|||||||
|
|
||||||
err := c.BindUri(&req)
|
err := c.BindUri(&req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Error().Err(err).Msg("Failed to bind URI")
|
tlog.App.Error().Err(err).Msg("Failed to bind URI")
|
||||||
c.JSON(400, gin.H{
|
c.JSON(400, gin.H{
|
||||||
"status": 400,
|
"status": 400,
|
||||||
"message": "Bad Request",
|
"message": "Bad Request",
|
||||||
@@ -66,7 +66,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
|
|||||||
err = c.BindQuery(&reqParams)
|
err = c.BindQuery(&reqParams)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Error().Err(err).Msg("Failed to bind query parameters")
|
tlog.App.Error().Err(err).Msg("Failed to bind query parameters")
|
||||||
c.JSON(400, gin.H{
|
c.JSON(400, gin.H{
|
||||||
"status": 400,
|
"status": 400,
|
||||||
"message": "Bad Request",
|
"message": "Bad Request",
|
||||||
@@ -75,10 +75,10 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !controller.isOidcRequest(reqParams) {
|
if !controller.isOidcRequest(reqParams) {
|
||||||
isRedirectSafe := utils.IsRedirectSafe(reqParams.RedirectURI, controller.runtime.CookieDomain)
|
isRedirectSafe := utils.IsRedirectSafe(reqParams.RedirectURI, controller.config.CookieDomain)
|
||||||
|
|
||||||
if !isRedirectSafe {
|
if !isRedirectSafe {
|
||||||
controller.log.App.Warn().Str("redirectUri", reqParams.RedirectURI).Msg("Unsafe redirect URI, ignoring")
|
tlog.App.Warn().Str("redirect_uri", reqParams.RedirectURI).Msg("Unsafe redirect URI detected, ignoring")
|
||||||
reqParams.RedirectURI = ""
|
reqParams.RedirectURI = ""
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -86,7 +86,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
|
|||||||
sessionId, _, err := controller.auth.NewOAuthSession(req.Provider, reqParams)
|
sessionId, _, err := controller.auth.NewOAuthSession(req.Provider, reqParams)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Error().Err(err).Msg("Failed to create new OAuth session")
|
tlog.App.Error().Err(err).Msg("Failed to create OAuth session")
|
||||||
c.JSON(500, gin.H{
|
c.JSON(500, gin.H{
|
||||||
"status": 500,
|
"status": 500,
|
||||||
"message": "Internal Server Error",
|
"message": "Internal Server Error",
|
||||||
@@ -97,7 +97,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
|
|||||||
authUrl, err := controller.auth.GetOAuthURL(sessionId)
|
authUrl, err := controller.auth.GetOAuthURL(sessionId)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Error().Err(err).Msg("Failed to get OAuth URL for session")
|
tlog.App.Error().Err(err).Msg("Failed to get OAuth URL")
|
||||||
c.JSON(500, gin.H{
|
c.JSON(500, gin.H{
|
||||||
"status": 500,
|
"status": 500,
|
||||||
"message": "Internal Server Error",
|
"message": "Internal Server Error",
|
||||||
@@ -105,7 +105,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.SetCookie(controller.runtime.OAuthSessionCookieName, sessionId, int(time.Hour.Seconds()), "/", controller.getCookieDomain(), controller.config.Auth.SecureCookie, true)
|
c.SetCookie(controller.config.OAuthSessionCookieName, sessionId, int(time.Hour.Seconds()), "/", fmt.Sprintf(".%s", controller.config.CookieDomain), controller.config.SecureCookie, true)
|
||||||
|
|
||||||
c.JSON(200, gin.H{
|
c.JSON(200, gin.H{
|
||||||
"status": 200,
|
"status": 200,
|
||||||
@@ -119,7 +119,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
|
|||||||
|
|
||||||
err := c.BindUri(&req)
|
err := c.BindUri(&req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Error().Err(err).Msg("Failed to bind URI")
|
tlog.App.Error().Err(err).Msg("Failed to bind URI")
|
||||||
c.JSON(400, gin.H{
|
c.JSON(400, gin.H{
|
||||||
"status": 400,
|
"status": 400,
|
||||||
"message": "Bad Request",
|
"message": "Bad Request",
|
||||||
@@ -127,21 +127,21 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
sessionIdCookie, err := c.Cookie(controller.runtime.OAuthSessionCookieName)
|
sessionIdCookie, err := c.Cookie(controller.config.OAuthSessionCookieName)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Error().Err(err).Msg("Failed to get OAuth session cookie")
|
tlog.App.Warn().Err(err).Msg("OAuth session cookie missing")
|
||||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
|
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.SetCookie(controller.runtime.OAuthSessionCookieName, "", -1, "/", controller.getCookieDomain(), controller.config.Auth.SecureCookie, true)
|
c.SetCookie(controller.config.OAuthSessionCookieName, "", -1, "/", fmt.Sprintf(".%s", controller.config.CookieDomain), controller.config.SecureCookie, true)
|
||||||
|
|
||||||
oauthPendingSession, err := controller.auth.GetOAuthPendingSession(sessionIdCookie)
|
oauthPendingSession, err := controller.auth.GetOAuthPendingSession(sessionIdCookie)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Error().Err(err).Msg("Failed to get pending OAuth session")
|
tlog.App.Error().Err(err).Msg("Failed to get OAuth pending session")
|
||||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
|
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -149,8 +149,8 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
|
|||||||
|
|
||||||
state := c.Query("state")
|
state := c.Query("state")
|
||||||
if state != oauthPendingSession.State {
|
if state != oauthPendingSession.State {
|
||||||
controller.log.App.Warn().Msg("OAuth state mismatch")
|
tlog.App.Warn().Err(err).Msg("CSRF token mismatch")
|
||||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
|
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -158,80 +158,68 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
|
|||||||
_, err = controller.auth.GetOAuthToken(sessionIdCookie, code)
|
_, err = controller.auth.GetOAuthToken(sessionIdCookie, code)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Error().Err(err).Msg("Failed to exchange code for token")
|
tlog.App.Error().Err(err).Msg("Failed to exchange code for token")
|
||||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
|
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := controller.auth.GetOAuthUserinfo(sessionIdCookie)
|
user, err := controller.auth.GetOAuthUserinfo(sessionIdCookie)
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
controller.log.App.Error().Err(err).Msg("Failed to get user info from OAuth provider")
|
|
||||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if user == nil {
|
|
||||||
controller.log.App.Warn().Msg("OAuth provider did not return user info")
|
|
||||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if user.Email == "" {
|
if user.Email == "" {
|
||||||
controller.log.App.Warn().Msg("OAuth provider did not return an email")
|
tlog.App.Error().Msg("OAuth provider did not return an email")
|
||||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
|
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if !controller.auth.IsEmailWhitelisted(user.Email) {
|
if !controller.auth.IsEmailWhitelisted(user.Email) {
|
||||||
controller.log.App.Warn().Str("email", user.Email).Msg("Email not whitelisted, denying access")
|
tlog.App.Warn().Str("email", user.Email).Msg("Email not whitelisted")
|
||||||
controller.log.AuditLoginFailure(user.Email, req.Provider, c.ClientIP(), "email not whitelisted")
|
tlog.AuditLoginFailure(c, user.Email, req.Provider, "email not whitelisted")
|
||||||
|
|
||||||
queries, err := query.Values(UnauthorizedQuery{
|
queries, err := query.Values(UnauthorizedQuery{
|
||||||
Username: user.Email,
|
Username: user.Email,
|
||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Error().Err(err).Msg("Failed to encode unauthorized query")
|
tlog.App.Error().Err(err).Msg("Failed to encode unauthorized query")
|
||||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
|
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", controller.runtime.AppURL, queries.Encode()))
|
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", controller.config.AppURL, queries.Encode()))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var name string
|
var name string
|
||||||
|
|
||||||
if strings.TrimSpace(user.Name) != "" {
|
if strings.TrimSpace(user.Name) != "" {
|
||||||
controller.log.App.Debug().Msg("Using name from OAuth provider")
|
tlog.App.Debug().Msg("Using name from OAuth provider")
|
||||||
name = user.Name
|
name = user.Name
|
||||||
} else {
|
} else {
|
||||||
controller.log.App.Debug().Msg("No name from OAuth provider, generating from email")
|
tlog.App.Debug().Msg("No name from OAuth provider, using pseudo name")
|
||||||
name = fmt.Sprintf("%s (%s)", utils.Capitalize(strings.Split(user.Email, "@")[0]), strings.Split(user.Email, "@")[1])
|
name = fmt.Sprintf("%s (%s)", utils.Capitalize(strings.Split(user.Email, "@")[0]), strings.Split(user.Email, "@")[1])
|
||||||
}
|
}
|
||||||
|
|
||||||
var username string
|
var username string
|
||||||
|
|
||||||
if strings.TrimSpace(user.PreferredUsername) != "" {
|
if strings.TrimSpace(user.PreferredUsername) != "" {
|
||||||
controller.log.App.Debug().Msg("Using preferred username from OAuth provider")
|
tlog.App.Debug().Msg("Using preferred username from OAuth provider")
|
||||||
username = user.PreferredUsername
|
username = user.PreferredUsername
|
||||||
} else {
|
} else {
|
||||||
controller.log.App.Debug().Msg("No preferred username from OAuth provider, generating from email")
|
tlog.App.Debug().Msg("No preferred username from OAuth provider, using pseudo username")
|
||||||
username = strings.Replace(user.Email, "@", "_", 1)
|
username = strings.Replace(user.Email, "@", "_", 1)
|
||||||
}
|
}
|
||||||
|
|
||||||
svc, err := controller.auth.GetOAuthService(sessionIdCookie)
|
svc, err := controller.auth.GetOAuthService(sessionIdCookie)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Error().Err(err).Msg("Failed to get OAuth service for session")
|
tlog.App.Error().Err(err).Msg("Failed to get OAuth service for session")
|
||||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
|
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if svc.ID() != req.Provider {
|
if svc.ID() != req.Provider {
|
||||||
controller.log.App.Warn().Msgf("OAuth provider mismatch: expected %s, got %s", req.Provider, svc.ID())
|
tlog.App.Error().Msgf("OAuth service ID mismatch: expected %s, got %s", svc.ID(), req.Provider)
|
||||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
|
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -245,29 +233,29 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
|
|||||||
OAuthSub: user.Sub,
|
OAuthSub: user.Sub,
|
||||||
}
|
}
|
||||||
|
|
||||||
controller.log.App.Debug().Msg("Creating session cookie for user")
|
tlog.App.Trace().Interface("session_cookie", sessionCookie).Msg("Creating session cookie")
|
||||||
|
|
||||||
cookie, err := controller.auth.CreateSession(c, sessionCookie)
|
cookie, err := controller.auth.CreateSession(c, sessionCookie)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Error().Err(err).Msg("Failed to create session cookie")
|
tlog.App.Error().Err(err).Msg("Failed to create session cookie")
|
||||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
|
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
http.SetCookie(c.Writer, cookie)
|
http.SetCookie(c.Writer, cookie)
|
||||||
|
|
||||||
controller.log.AuditLoginSuccess(sessionCookie.Username, sessionCookie.Provider, c.ClientIP())
|
tlog.AuditLoginSuccess(c, sessionCookie.Username, sessionCookie.Provider)
|
||||||
|
|
||||||
if controller.isOidcRequest(oauthPendingSession.CallbackParams) {
|
if controller.isOidcRequest(oauthPendingSession.CallbackParams) {
|
||||||
controller.log.App.Debug().Msg("OIDC request detected, redirecting to authorization endpoint with callback params")
|
tlog.App.Debug().Msg("OIDC request, redirecting to authorize page")
|
||||||
queries, err := query.Values(oauthPendingSession.CallbackParams)
|
queries, err := query.Values(oauthPendingSession.CallbackParams)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Error().Err(err).Msg("Failed to encode OIDC callback query")
|
tlog.App.Error().Err(err).Msg("Failed to encode OIDC callback query")
|
||||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
|
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/authorize?%s", controller.runtime.AppURL, queries.Encode()))
|
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/authorize?%s", controller.config.AppURL, queries.Encode()))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -277,16 +265,16 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Error().Err(err).Msg("Failed to encode redirect query")
|
tlog.App.Error().Err(err).Msg("Failed to encode redirect URI query")
|
||||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
|
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/continue?%s", controller.runtime.AppURL, queries.Encode()))
|
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/continue?%s", controller.config.AppURL, queries.Encode()))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Redirect(http.StatusTemporaryRedirect, controller.runtime.AppURL)
|
c.Redirect(http.StatusTemporaryRedirect, controller.config.AppURL)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (controller *OAuthController) isOidcRequest(params service.OAuthURLParams) bool {
|
func (controller *OAuthController) isOidcRequest(params service.OAuthURLParams) bool {
|
||||||
@@ -295,10 +283,3 @@ func (controller *OAuthController) isOidcRequest(params service.OAuthURLParams)
|
|||||||
params.ClientID != "" &&
|
params.ClientID != "" &&
|
||||||
params.RedirectURI != ""
|
params.RedirectURI != ""
|
||||||
}
|
}
|
||||||
|
|
||||||
func (controller *OAuthController) getCookieDomain() string {
|
|
||||||
if controller.config.Auth.SubdomainsEnabled {
|
|
||||||
return "." + controller.runtime.CookieDomain
|
|
||||||
}
|
|
||||||
return controller.runtime.CookieDomain
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -13,13 +13,15 @@ import (
|
|||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type OIDCControllerConfig struct{}
|
||||||
|
|
||||||
type OIDCController struct {
|
type OIDCController struct {
|
||||||
log *logger.Logger
|
config OIDCControllerConfig
|
||||||
oidc *service.OIDCService
|
router *gin.RouterGroup
|
||||||
runtime model.RuntimeConfig
|
oidc *service.OIDCService
|
||||||
}
|
}
|
||||||
|
|
||||||
type AuthorizeCallback struct {
|
type AuthorizeCallback struct {
|
||||||
@@ -56,42 +58,29 @@ type ClientCredentials struct {
|
|||||||
ClientSecret string
|
ClientSecret string
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewOIDCController(
|
func NewOIDCController(config OIDCControllerConfig, oidcService *service.OIDCService, router *gin.RouterGroup) *OIDCController {
|
||||||
log *logger.Logger,
|
return &OIDCController{
|
||||||
oidcService *service.OIDCService,
|
config: config,
|
||||||
runtimeConfig model.RuntimeConfig,
|
oidc: oidcService,
|
||||||
router *gin.RouterGroup) *OIDCController {
|
router: router,
|
||||||
controller := &OIDCController{
|
|
||||||
log: log,
|
|
||||||
oidc: oidcService,
|
|
||||||
runtime: runtimeConfig,
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
oidcGroup := router.Group("/oidc")
|
func (controller *OIDCController) SetupRoutes() {
|
||||||
|
oidcGroup := controller.router.Group("/oidc")
|
||||||
oidcGroup.GET("/clients/:id", controller.GetClientInfo)
|
oidcGroup.GET("/clients/:id", controller.GetClientInfo)
|
||||||
oidcGroup.POST("/authorize", controller.Authorize)
|
oidcGroup.POST("/authorize", controller.Authorize)
|
||||||
oidcGroup.POST("/token", controller.Token)
|
oidcGroup.POST("/token", controller.Token)
|
||||||
oidcGroup.GET("/userinfo", controller.Userinfo)
|
oidcGroup.GET("/userinfo", controller.Userinfo)
|
||||||
oidcGroup.POST("/userinfo", controller.Userinfo)
|
oidcGroup.POST("/userinfo", controller.Userinfo)
|
||||||
|
|
||||||
return controller
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (controller *OIDCController) GetClientInfo(c *gin.Context) {
|
func (controller *OIDCController) GetClientInfo(c *gin.Context) {
|
||||||
if controller.oidc == nil {
|
|
||||||
controller.log.App.Warn().Msg("Received OIDC client info request but OIDC server is not configured")
|
|
||||||
c.JSON(500, gin.H{
|
|
||||||
"status": 500,
|
|
||||||
"message": "OIDC not configured",
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var req ClientRequest
|
var req ClientRequest
|
||||||
|
|
||||||
err := c.BindUri(&req)
|
err := c.BindUri(&req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Error().Err(err).Msg("Failed to bind URI")
|
tlog.App.Error().Err(err).Msg("Failed to bind URI")
|
||||||
c.JSON(400, gin.H{
|
c.JSON(400, gin.H{
|
||||||
"status": 400,
|
"status": 400,
|
||||||
"message": "Bad Request",
|
"message": "Bad Request",
|
||||||
@@ -102,7 +91,7 @@ func (controller *OIDCController) GetClientInfo(c *gin.Context) {
|
|||||||
client, ok := controller.oidc.GetClient(req.ClientID)
|
client, ok := controller.oidc.GetClient(req.ClientID)
|
||||||
|
|
||||||
if !ok {
|
if !ok {
|
||||||
controller.log.App.Warn().Str("clientId", req.ClientID).Msg("Client not found")
|
tlog.App.Warn().Str("client_id", req.ClientID).Msg("Client not found")
|
||||||
c.JSON(404, gin.H{
|
c.JSON(404, gin.H{
|
||||||
"status": 404,
|
"status": 404,
|
||||||
"message": "Client not found",
|
"message": "Client not found",
|
||||||
@@ -118,7 +107,7 @@ func (controller *OIDCController) GetClientInfo(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (controller *OIDCController) Authorize(c *gin.Context) {
|
func (controller *OIDCController) Authorize(c *gin.Context) {
|
||||||
if controller.oidc == nil {
|
if !controller.oidc.IsConfigured() {
|
||||||
controller.authorizeError(c, errors.New("err_oidc_not_configured"), "OIDC not configured", "This instance is not configured for OIDC", "", "", "")
|
controller.authorizeError(c, errors.New("err_oidc_not_configured"), "OIDC not configured", "This instance is not configured for OIDC", "", "", "")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -153,7 +142,7 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
|
|||||||
err = controller.oidc.ValidateAuthorizeParams(req)
|
err = controller.oidc.ValidateAuthorizeParams(req)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Warn().Err(err).Msg("Failed to validate authorize params")
|
tlog.App.Error().Err(err).Msg("Failed to validate authorize params")
|
||||||
if err.Error() != "invalid_request_uri" {
|
if err.Error() != "invalid_request_uri" {
|
||||||
controller.authorizeError(c, err, "Failed validate authorize params", "Invalid request parameters", req.RedirectURI, err.Error(), req.State)
|
controller.authorizeError(c, err, "Failed validate authorize params", "Invalid request parameters", req.RedirectURI, err.Error(), req.State)
|
||||||
return
|
return
|
||||||
@@ -185,7 +174,7 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
|
|||||||
err = controller.oidc.StoreUserinfo(c, sub, *userContext, req)
|
err = controller.oidc.StoreUserinfo(c, sub, *userContext, req)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Error().Err(err).Msg("Failed to store user info")
|
tlog.App.Error().Err(err).Msg("Failed to insert user info into database")
|
||||||
controller.authorizeError(c, err, "Failed to store user info", "Failed to store user info", req.RedirectURI, "server_error", req.State)
|
controller.authorizeError(c, err, "Failed to store user info", "Failed to store user info", req.RedirectURI, "server_error", req.State)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -208,10 +197,10 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (controller *OIDCController) Token(c *gin.Context) {
|
func (controller *OIDCController) Token(c *gin.Context) {
|
||||||
if controller.oidc == nil {
|
if !controller.oidc.IsConfigured() {
|
||||||
controller.log.App.Warn().Msg("Received OIDC request but OIDC server is not configured")
|
tlog.App.Warn().Msg("OIDC not configured")
|
||||||
c.JSON(500, gin.H{
|
c.JSON(404, gin.H{
|
||||||
"error": "server_error",
|
"error": "not_found",
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -220,7 +209,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
|
|||||||
|
|
||||||
err := c.Bind(&req)
|
err := c.Bind(&req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Warn().Err(err).Msg("Failed to bind token request")
|
tlog.App.Error().Err(err).Msg("Failed to bind token request")
|
||||||
c.JSON(400, gin.H{
|
c.JSON(400, gin.H{
|
||||||
"error": "invalid_request",
|
"error": "invalid_request",
|
||||||
})
|
})
|
||||||
@@ -229,7 +218,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
|
|||||||
|
|
||||||
err = controller.oidc.ValidateGrantType(req.GrantType)
|
err = controller.oidc.ValidateGrantType(req.GrantType)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Warn().Err(err).Msg("Invalid grant type")
|
tlog.App.Warn().Str("grant_type", req.GrantType).Msg("Unsupported grant type")
|
||||||
c.JSON(400, gin.H{
|
c.JSON(400, gin.H{
|
||||||
"error": err.Error(),
|
"error": err.Error(),
|
||||||
})
|
})
|
||||||
@@ -244,12 +233,12 @@ func (controller *OIDCController) Token(c *gin.Context) {
|
|||||||
|
|
||||||
// If it fails, we try basic auth
|
// If it fails, we try basic auth
|
||||||
if creds.ClientID == "" || creds.ClientSecret == "" {
|
if creds.ClientID == "" || creds.ClientSecret == "" {
|
||||||
controller.log.App.Debug().Msg("Client credentials not found in form, trying basic auth")
|
tlog.App.Debug().Msg("Tried form values and they are empty, trying basic auth")
|
||||||
|
|
||||||
clientId, clientSecret, ok := c.Request.BasicAuth()
|
clientId, clientSecret, ok := c.Request.BasicAuth()
|
||||||
|
|
||||||
if !ok {
|
if !ok {
|
||||||
controller.log.App.Warn().Msg("Client credentials not found in basic auth")
|
tlog.App.Error().Msg("Missing authorization header")
|
||||||
c.Header("www-authenticate", `Basic realm="Tinyauth OIDC Token Endpoint"`)
|
c.Header("www-authenticate", `Basic realm="Tinyauth OIDC Token Endpoint"`)
|
||||||
c.JSON(400, gin.H{
|
c.JSON(400, gin.H{
|
||||||
"error": "invalid_client",
|
"error": "invalid_client",
|
||||||
@@ -266,7 +255,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
|
|||||||
client, ok := controller.oidc.GetClient(creds.ClientID)
|
client, ok := controller.oidc.GetClient(creds.ClientID)
|
||||||
|
|
||||||
if !ok {
|
if !ok {
|
||||||
controller.log.App.Warn().Str("clientId", creds.ClientID).Msg("Client not found")
|
tlog.App.Warn().Str("client_id", creds.ClientID).Msg("Client not found")
|
||||||
c.JSON(400, gin.H{
|
c.JSON(400, gin.H{
|
||||||
"error": "invalid_client",
|
"error": "invalid_client",
|
||||||
})
|
})
|
||||||
@@ -274,7 +263,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if client.ClientSecret != creds.ClientSecret {
|
if client.ClientSecret != creds.ClientSecret {
|
||||||
controller.log.App.Warn().Str("clientId", creds.ClientID).Msg("Invalid client secret")
|
tlog.App.Warn().Str("client_id", creds.ClientID).Msg("Invalid client secret")
|
||||||
c.JSON(400, gin.H{
|
c.JSON(400, gin.H{
|
||||||
"error": "invalid_client",
|
"error": "invalid_client",
|
||||||
})
|
})
|
||||||
@@ -288,30 +277,30 @@ func (controller *OIDCController) Token(c *gin.Context) {
|
|||||||
entry, err := controller.oidc.GetCodeEntry(c, controller.oidc.Hash(req.Code), client.ClientID)
|
entry, err := controller.oidc.GetCodeEntry(c, controller.oidc.Hash(req.Code), client.ClientID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err := controller.oidc.DeleteTokenByCodeHash(c, controller.oidc.Hash(req.Code)); err != nil {
|
if err := controller.oidc.DeleteTokenByCodeHash(c, controller.oidc.Hash(req.Code)); err != nil {
|
||||||
controller.log.App.Error().Err(err).Msg("Failed to delete code")
|
tlog.App.Error().Err(err).Msg("Failed to delete access token by code hash")
|
||||||
}
|
}
|
||||||
if errors.Is(err, service.ErrCodeNotFound) {
|
if errors.Is(err, service.ErrCodeNotFound) {
|
||||||
controller.log.App.Warn().Msg("Code not found")
|
tlog.App.Warn().Msg("Code not found")
|
||||||
c.JSON(400, gin.H{
|
c.JSON(400, gin.H{
|
||||||
"error": "invalid_grant",
|
"error": "invalid_grant",
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if errors.Is(err, service.ErrCodeExpired) {
|
if errors.Is(err, service.ErrCodeExpired) {
|
||||||
controller.log.App.Warn().Msg("Code expired")
|
tlog.App.Warn().Msg("Code expired")
|
||||||
c.JSON(400, gin.H{
|
c.JSON(400, gin.H{
|
||||||
"error": "invalid_grant",
|
"error": "invalid_grant",
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if errors.Is(err, service.ErrInvalidClient) {
|
if errors.Is(err, service.ErrInvalidClient) {
|
||||||
controller.log.App.Warn().Msg("Code does not belong to client")
|
tlog.App.Warn().Msg("Invalid client ID")
|
||||||
c.JSON(400, gin.H{
|
c.JSON(400, gin.H{
|
||||||
"error": "invalid_client",
|
"error": "invalid_client",
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
controller.log.App.Error().Err(err).Msg("Failed to get code entry")
|
tlog.App.Warn().Err(err).Msg("Failed to get OIDC code entry")
|
||||||
c.JSON(400, gin.H{
|
c.JSON(400, gin.H{
|
||||||
"error": "server_error",
|
"error": "server_error",
|
||||||
})
|
})
|
||||||
@@ -319,7 +308,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if entry.RedirectURI != req.RedirectURI {
|
if entry.RedirectURI != req.RedirectURI {
|
||||||
controller.log.App.Warn().Msg("Redirect URI does not match")
|
tlog.App.Warn().Str("redirect_uri", req.RedirectURI).Msg("Redirect URI mismatch")
|
||||||
c.JSON(400, gin.H{
|
c.JSON(400, gin.H{
|
||||||
"error": "invalid_grant",
|
"error": "invalid_grant",
|
||||||
})
|
})
|
||||||
@@ -329,7 +318,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
|
|||||||
ok := controller.oidc.ValidatePKCE(entry.CodeChallenge, req.CodeVerifier)
|
ok := controller.oidc.ValidatePKCE(entry.CodeChallenge, req.CodeVerifier)
|
||||||
|
|
||||||
if !ok {
|
if !ok {
|
||||||
controller.log.App.Warn().Msg("PKCE validation failed")
|
tlog.App.Warn().Msg("PKCE validation failed")
|
||||||
c.JSON(400, gin.H{
|
c.JSON(400, gin.H{
|
||||||
"error": "invalid_grant",
|
"error": "invalid_grant",
|
||||||
})
|
})
|
||||||
@@ -339,7 +328,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
|
|||||||
tokenRes, err := controller.oidc.GenerateAccessToken(c, client, entry)
|
tokenRes, err := controller.oidc.GenerateAccessToken(c, client, entry)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Error().Err(err).Msg("Failed to generate access token")
|
tlog.App.Error().Err(err).Msg("Failed to generate access token")
|
||||||
c.JSON(400, gin.H{
|
c.JSON(400, gin.H{
|
||||||
"error": "server_error",
|
"error": "server_error",
|
||||||
})
|
})
|
||||||
@@ -352,7 +341,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
|
|||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, service.ErrTokenExpired) {
|
if errors.Is(err, service.ErrTokenExpired) {
|
||||||
controller.log.App.Warn().Msg("Refresh token expired")
|
tlog.App.Error().Err(err).Msg("Refresh token expired")
|
||||||
c.JSON(400, gin.H{
|
c.JSON(400, gin.H{
|
||||||
"error": "invalid_grant",
|
"error": "invalid_grant",
|
||||||
})
|
})
|
||||||
@@ -360,14 +349,14 @@ func (controller *OIDCController) Token(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if errors.Is(err, service.ErrInvalidClient) {
|
if errors.Is(err, service.ErrInvalidClient) {
|
||||||
controller.log.App.Warn().Msg("Refresh token does not belong to client")
|
tlog.App.Error().Err(err).Msg("Invalid client")
|
||||||
c.JSON(400, gin.H{
|
c.JSON(400, gin.H{
|
||||||
"error": "invalid_grant",
|
"error": "invalid_grant",
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
controller.log.App.Error().Err(err).Msg("Failed to refresh access token")
|
tlog.App.Error().Err(err).Msg("Failed to refresh access token")
|
||||||
c.JSON(400, gin.H{
|
c.JSON(400, gin.H{
|
||||||
"error": "server_error",
|
"error": "server_error",
|
||||||
})
|
})
|
||||||
@@ -384,10 +373,10 @@ func (controller *OIDCController) Token(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (controller *OIDCController) Userinfo(c *gin.Context) {
|
func (controller *OIDCController) Userinfo(c *gin.Context) {
|
||||||
if controller.oidc == nil {
|
if !controller.oidc.IsConfigured() {
|
||||||
controller.log.App.Warn().Msg("Received OIDC userinfo request but OIDC server is not configured")
|
tlog.App.Warn().Msg("OIDC not configured")
|
||||||
c.JSON(500, gin.H{
|
c.JSON(404, gin.H{
|
||||||
"error": "server_error",
|
"error": "not_found",
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -398,7 +387,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
|
|||||||
if authorization != "" {
|
if authorization != "" {
|
||||||
tokenType, bearerToken, ok := strings.Cut(authorization, " ")
|
tokenType, bearerToken, ok := strings.Cut(authorization, " ")
|
||||||
if !ok {
|
if !ok {
|
||||||
controller.log.App.Warn().Msg("OIDC userinfo accessed with invalid authorization header")
|
tlog.App.Warn().Msg("OIDC userinfo accessed with malformed authorization header")
|
||||||
c.JSON(401, gin.H{
|
c.JSON(401, gin.H{
|
||||||
"error": "invalid_request",
|
"error": "invalid_request",
|
||||||
})
|
})
|
||||||
@@ -406,7 +395,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if strings.ToLower(tokenType) != "bearer" {
|
if strings.ToLower(tokenType) != "bearer" {
|
||||||
controller.log.App.Warn().Msg("OIDC userinfo accessed with non-bearer token")
|
tlog.App.Warn().Msg("OIDC userinfo accessed with invalid token type")
|
||||||
c.JSON(401, gin.H{
|
c.JSON(401, gin.H{
|
||||||
"error": "invalid_request",
|
"error": "invalid_request",
|
||||||
})
|
})
|
||||||
@@ -416,7 +405,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
|
|||||||
token = bearerToken
|
token = bearerToken
|
||||||
} else if c.Request.Method == http.MethodPost {
|
} else if c.Request.Method == http.MethodPost {
|
||||||
if c.ContentType() != "application/x-www-form-urlencoded" {
|
if c.ContentType() != "application/x-www-form-urlencoded" {
|
||||||
controller.log.App.Warn().Msg("OIDC userinfo POST accessed with invalid content type")
|
tlog.App.Warn().Msg("OIDC userinfo POST accessed with invalid content type")
|
||||||
c.JSON(400, gin.H{
|
c.JSON(400, gin.H{
|
||||||
"error": "invalid_request",
|
"error": "invalid_request",
|
||||||
})
|
})
|
||||||
@@ -424,14 +413,14 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
token = c.PostForm("access_token")
|
token = c.PostForm("access_token")
|
||||||
if token == "" {
|
if token == "" {
|
||||||
controller.log.App.Warn().Msg("OIDC userinfo POST accessed without access_token")
|
tlog.App.Warn().Msg("OIDC userinfo POST accessed without access_token in body")
|
||||||
c.JSON(401, gin.H{
|
c.JSON(401, gin.H{
|
||||||
"error": "invalid_request",
|
"error": "invalid_request",
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
controller.log.App.Warn().Msg("OIDC userinfo accessed without authorization header or POST body")
|
tlog.App.Warn().Msg("OIDC userinfo accessed without authorization header")
|
||||||
c.JSON(401, gin.H{
|
c.JSON(401, gin.H{
|
||||||
"error": "invalid_request",
|
"error": "invalid_request",
|
||||||
})
|
})
|
||||||
@@ -442,14 +431,14 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
|
|||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, service.ErrTokenNotFound) {
|
if errors.Is(err, service.ErrTokenNotFound) {
|
||||||
controller.log.App.Warn().Msg("OIDC userinfo accessed with invalid token")
|
tlog.App.Warn().Msg("OIDC userinfo accessed with invalid token")
|
||||||
c.JSON(401, gin.H{
|
c.JSON(401, gin.H{
|
||||||
"error": "invalid_grant",
|
"error": "invalid_grant",
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
controller.log.App.Error().Err(err).Msg("Failed to get access token")
|
tlog.App.Err(err).Msg("Failed to get token entry")
|
||||||
c.JSON(401, gin.H{
|
c.JSON(401, gin.H{
|
||||||
"error": "server_error",
|
"error": "server_error",
|
||||||
})
|
})
|
||||||
@@ -458,7 +447,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
|
|||||||
|
|
||||||
// If we don't have the openid scope, return an error
|
// If we don't have the openid scope, return an error
|
||||||
if !slices.Contains(strings.Split(entry.Scope, ","), "openid") {
|
if !slices.Contains(strings.Split(entry.Scope, ","), "openid") {
|
||||||
controller.log.App.Warn().Msg("OIDC userinfo accessed with token missing openid scope")
|
tlog.App.Warn().Msg("OIDC userinfo accessed without openid scope")
|
||||||
c.JSON(401, gin.H{
|
c.JSON(401, gin.H{
|
||||||
"error": "invalid_scope",
|
"error": "invalid_scope",
|
||||||
})
|
})
|
||||||
@@ -468,7 +457,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
|
|||||||
user, err := controller.oidc.GetUserinfo(c, entry.Sub)
|
user, err := controller.oidc.GetUserinfo(c, entry.Sub)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Error().Err(err).Msg("Failed to get user info")
|
tlog.App.Err(err).Msg("Failed to get user entry")
|
||||||
c.JSON(401, gin.H{
|
c.JSON(401, gin.H{
|
||||||
"error": "server_error",
|
"error": "server_error",
|
||||||
})
|
})
|
||||||
@@ -479,7 +468,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (controller *OIDCController) authorizeError(c *gin.Context, err error, reason string, reasonUser string, callback string, callbackError string, state string) {
|
func (controller *OIDCController) authorizeError(c *gin.Context, err error, reason string, reasonUser string, callback string, callbackError string, state string) {
|
||||||
controller.log.App.Warn().Err(err).Str("reason", reason).Msg("Authorization error")
|
tlog.App.Error().Err(err).Msg(reason)
|
||||||
|
|
||||||
if callback != "" {
|
if callback != "" {
|
||||||
errorQueries := CallbackError{
|
errorQueries := CallbackError{
|
||||||
@@ -519,16 +508,8 @@ func (controller *OIDCController) authorizeError(c *gin.Context, err error, reas
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
redirectUrl := ""
|
|
||||||
|
|
||||||
if controller.oidc != nil {
|
|
||||||
redirectUrl = fmt.Sprintf("%s/error?%s", controller.oidc.GetIssuer(), queries.Encode())
|
|
||||||
} else {
|
|
||||||
redirectUrl = fmt.Sprintf("%s/error?%s", controller.runtime.AppURL, queries.Encode())
|
|
||||||
}
|
|
||||||
|
|
||||||
c.JSON(200, gin.H{
|
c.JSON(200, gin.H{
|
||||||
"status": 200,
|
"status": 200,
|
||||||
"redirect_uri": redirectUrl,
|
"redirect_uri": fmt.Sprintf("%s/error?%s", controller.oidc.GetIssuer(), queries.Encode()),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,14 +1,13 @@
|
|||||||
package controller_test
|
package controller_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"path"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@@ -20,15 +19,29 @@ import (
|
|||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/test"
|
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestOIDCController(t *testing.T) {
|
func TestOIDCController(t *testing.T) {
|
||||||
log := logger.NewLogger().WithTestConfig()
|
tlog.NewTestLogger().Init()
|
||||||
log.Init()
|
tempDir := t.TempDir()
|
||||||
|
|
||||||
cfg, runtime := test.CreateTestConfigs(t)
|
oidcServiceCfg := service.OIDCServiceConfig{
|
||||||
|
Clients: map[string]model.OIDCClientConfig{
|
||||||
|
"test": {
|
||||||
|
ClientID: "some-client-id",
|
||||||
|
ClientSecret: "some-client-secret",
|
||||||
|
TrustedRedirectURIs: []string{"https://test.example.com/callback"},
|
||||||
|
Name: "Test Client",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
PrivateKeyPath: path.Join(tempDir, "key.pem"),
|
||||||
|
PublicKeyPath: path.Join(tempDir, "key.pub"),
|
||||||
|
Issuer: "https://tinyauth.example.com",
|
||||||
|
SessionExpiry: 500,
|
||||||
|
}
|
||||||
|
|
||||||
|
controllerCfg := controller.OIDCControllerConfig{}
|
||||||
|
|
||||||
simpleCtx := func(c *gin.Context) {
|
simpleCtx := func(c *gin.Context) {
|
||||||
c.Set("context", &model.UserContext{
|
c.Set("context", &model.UserContext{
|
||||||
@@ -90,7 +103,7 @@ func TestOIDCController(t *testing.T) {
|
|||||||
|
|
||||||
var res map[string]any
|
var res map[string]any
|
||||||
err := json.Unmarshal(recorder.Body.Bytes(), &res)
|
err := json.Unmarshal(recorder.Body.Bytes(), &res)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
assert.Equal(t, res["redirect_uri"], "https://tinyauth.example.com/error?error=User+is+not+logged+in+or+the+session+is+invalid")
|
assert.Equal(t, res["redirect_uri"], "https://tinyauth.example.com/error?error=User+is+not+logged+in+or+the+session+is+invalid")
|
||||||
},
|
},
|
||||||
@@ -110,7 +123,7 @@ func TestOIDCController(t *testing.T) {
|
|||||||
Nonce: "some-nonce",
|
Nonce: "some-nonce",
|
||||||
}
|
}
|
||||||
reqBodyBytes, err := json.Marshal(reqBody)
|
reqBodyBytes, err := json.Marshal(reqBody)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes)))
|
req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes)))
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
@@ -118,7 +131,7 @@ func TestOIDCController(t *testing.T) {
|
|||||||
|
|
||||||
var res map[string]any
|
var res map[string]any
|
||||||
err = json.Unmarshal(recorder.Body.Bytes(), &res)
|
err = json.Unmarshal(recorder.Body.Bytes(), &res)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
assert.Equal(t, res["redirect_uri"], "https://test.example.com/callback?error=unsupported_response_type&error_description=Invalid+request+parameters&state=some-state")
|
assert.Equal(t, res["redirect_uri"], "https://test.example.com/callback?error=unsupported_response_type&error_description=Invalid+request+parameters&state=some-state")
|
||||||
},
|
},
|
||||||
@@ -138,7 +151,7 @@ func TestOIDCController(t *testing.T) {
|
|||||||
Nonce: "some-nonce",
|
Nonce: "some-nonce",
|
||||||
}
|
}
|
||||||
reqBodyBytes, err := json.Marshal(reqBody)
|
reqBodyBytes, err := json.Marshal(reqBody)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes)))
|
req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes)))
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
@@ -147,11 +160,11 @@ func TestOIDCController(t *testing.T) {
|
|||||||
|
|
||||||
var res map[string]any
|
var res map[string]any
|
||||||
err = json.Unmarshal(recorder.Body.Bytes(), &res)
|
err = json.Unmarshal(recorder.Body.Bytes(), &res)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
redirectURI := res["redirect_uri"].(string)
|
redirectURI := res["redirect_uri"].(string)
|
||||||
url, err := url.Parse(redirectURI)
|
url, err := url.Parse(redirectURI)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
queryParams := url.Query()
|
queryParams := url.Query()
|
||||||
assert.Equal(t, queryParams.Get("state"), "some-state")
|
assert.Equal(t, queryParams.Get("state"), "some-state")
|
||||||
@@ -170,7 +183,7 @@ func TestOIDCController(t *testing.T) {
|
|||||||
RedirectURI: "https://test.example.com/callback",
|
RedirectURI: "https://test.example.com/callback",
|
||||||
}
|
}
|
||||||
reqBodyEncoded, err := query.Values(reqBody)
|
reqBodyEncoded, err := query.Values(reqBody)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
|
req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
|
||||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
@@ -178,7 +191,7 @@ func TestOIDCController(t *testing.T) {
|
|||||||
|
|
||||||
var res map[string]any
|
var res map[string]any
|
||||||
err = json.Unmarshal(recorder.Body.Bytes(), &res)
|
err = json.Unmarshal(recorder.Body.Bytes(), &res)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
assert.Equal(t, res["error"], "unsupported_grant_type")
|
assert.Equal(t, res["error"], "unsupported_grant_type")
|
||||||
},
|
},
|
||||||
@@ -193,7 +206,7 @@ func TestOIDCController(t *testing.T) {
|
|||||||
RedirectURI: "https://test.example.com/callback",
|
RedirectURI: "https://test.example.com/callback",
|
||||||
}
|
}
|
||||||
reqBodyEncoded, err := query.Values(reqBody)
|
reqBodyEncoded, err := query.Values(reqBody)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
|
req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
|
||||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
@@ -231,7 +244,7 @@ func TestOIDCController(t *testing.T) {
|
|||||||
RedirectURI: "https://test.example.com/callback",
|
RedirectURI: "https://test.example.com/callback",
|
||||||
}
|
}
|
||||||
reqBodyEncoded, err := query.Values(reqBody)
|
reqBodyEncoded, err := query.Values(reqBody)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
|
req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
|
||||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
@@ -254,11 +267,11 @@ func TestOIDCController(t *testing.T) {
|
|||||||
|
|
||||||
var authorizeRes map[string]any
|
var authorizeRes map[string]any
|
||||||
err := json.Unmarshal(authorizeTestRecorder.Body.Bytes(), &authorizeRes)
|
err := json.Unmarshal(authorizeTestRecorder.Body.Bytes(), &authorizeRes)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
redirectURI := authorizeRes["redirect_uri"].(string)
|
redirectURI := authorizeRes["redirect_uri"].(string)
|
||||||
url, err := url.Parse(redirectURI)
|
url, err := url.Parse(redirectURI)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
queryParams := url.Query()
|
queryParams := url.Query()
|
||||||
code := queryParams.Get("code")
|
code := queryParams.Get("code")
|
||||||
@@ -270,7 +283,7 @@ func TestOIDCController(t *testing.T) {
|
|||||||
RedirectURI: "https://test.example.com/callback",
|
RedirectURI: "https://test.example.com/callback",
|
||||||
}
|
}
|
||||||
reqBodyEncoded, err := query.Values(reqBody)
|
reqBodyEncoded, err := query.Values(reqBody)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
|
req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
|
||||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
@@ -293,7 +306,7 @@ func TestOIDCController(t *testing.T) {
|
|||||||
|
|
||||||
var tokenRes map[string]any
|
var tokenRes map[string]any
|
||||||
err := json.Unmarshal(tokenRecorder.Body.Bytes(), &tokenRes)
|
err := json.Unmarshal(tokenRecorder.Body.Bytes(), &tokenRes)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
_, ok := tokenRes["refresh_token"]
|
_, ok := tokenRes["refresh_token"]
|
||||||
assert.True(t, ok, "Expected refresh token in response")
|
assert.True(t, ok, "Expected refresh token in response")
|
||||||
@@ -307,7 +320,7 @@ func TestOIDCController(t *testing.T) {
|
|||||||
ClientSecret: "some-client-secret",
|
ClientSecret: "some-client-secret",
|
||||||
}
|
}
|
||||||
reqBodyEncoded, err := query.Values(reqBody)
|
reqBodyEncoded, err := query.Values(reqBody)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
|
req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
|
||||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
@@ -319,7 +332,7 @@ func TestOIDCController(t *testing.T) {
|
|||||||
assert.Equal(t, 200, recorder.Code)
|
assert.Equal(t, 200, recorder.Code)
|
||||||
var refreshRes map[string]any
|
var refreshRes map[string]any
|
||||||
err = json.Unmarshal(recorder.Body.Bytes(), &refreshRes)
|
err = json.Unmarshal(recorder.Body.Bytes(), &refreshRes)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
_, ok = refreshRes["access_token"]
|
_, ok = refreshRes["access_token"]
|
||||||
assert.True(t, ok, "Expected access token in refresh response")
|
assert.True(t, ok, "Expected access token in refresh response")
|
||||||
@@ -340,11 +353,11 @@ func TestOIDCController(t *testing.T) {
|
|||||||
|
|
||||||
var authorizeRes map[string]any
|
var authorizeRes map[string]any
|
||||||
err := json.Unmarshal(authorizeTestRecorder.Body.Bytes(), &authorizeRes)
|
err := json.Unmarshal(authorizeTestRecorder.Body.Bytes(), &authorizeRes)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
redirectURI := authorizeRes["redirect_uri"].(string)
|
redirectURI := authorizeRes["redirect_uri"].(string)
|
||||||
url, err := url.Parse(redirectURI)
|
url, err := url.Parse(redirectURI)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
queryParams := url.Query()
|
queryParams := url.Query()
|
||||||
code := queryParams.Get("code")
|
code := queryParams.Get("code")
|
||||||
@@ -356,7 +369,7 @@ func TestOIDCController(t *testing.T) {
|
|||||||
RedirectURI: "https://test.example.com/callback",
|
RedirectURI: "https://test.example.com/callback",
|
||||||
}
|
}
|
||||||
reqBodyEncoded, err := query.Values(reqBody)
|
reqBodyEncoded, err := query.Values(reqBody)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
|
req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
|
||||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
@@ -376,7 +389,7 @@ func TestOIDCController(t *testing.T) {
|
|||||||
|
|
||||||
var secondRes map[string]any
|
var secondRes map[string]any
|
||||||
err = json.Unmarshal(secondRecorder.Body.Bytes(), &secondRes)
|
err = json.Unmarshal(secondRecorder.Body.Bytes(), &secondRes)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
assert.Equal(t, "invalid_grant", secondRes["error"])
|
assert.Equal(t, "invalid_grant", secondRes["error"])
|
||||||
},
|
},
|
||||||
@@ -404,7 +417,7 @@ func TestOIDCController(t *testing.T) {
|
|||||||
|
|
||||||
var tokenRes map[string]any
|
var tokenRes map[string]any
|
||||||
err := json.Unmarshal(tokenRecorder.Body.Bytes(), &tokenRes)
|
err := json.Unmarshal(tokenRecorder.Body.Bytes(), &tokenRes)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
accessToken := tokenRes["access_token"].(string)
|
accessToken := tokenRes["access_token"].(string)
|
||||||
assert.NotEmpty(t, accessToken)
|
assert.NotEmpty(t, accessToken)
|
||||||
@@ -416,7 +429,7 @@ func TestOIDCController(t *testing.T) {
|
|||||||
|
|
||||||
var userInfoRes map[string]any
|
var userInfoRes map[string]any
|
||||||
err = json.Unmarshal(recorder.Body.Bytes(), &userInfoRes)
|
err = json.Unmarshal(recorder.Body.Bytes(), &userInfoRes)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
_, ok := userInfoRes["sub"]
|
_, ok := userInfoRes["sub"]
|
||||||
assert.True(t, ok, "Expected sub claim in userinfo response")
|
assert.True(t, ok, "Expected sub claim in userinfo response")
|
||||||
@@ -436,7 +449,7 @@ func TestOIDCController(t *testing.T) {
|
|||||||
|
|
||||||
var res map[string]any
|
var res map[string]any
|
||||||
err := json.Unmarshal(recorder.Body.Bytes(), &res)
|
err := json.Unmarshal(recorder.Body.Bytes(), &res)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, "invalid_request", res["error"])
|
assert.Equal(t, "invalid_request", res["error"])
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -451,7 +464,7 @@ func TestOIDCController(t *testing.T) {
|
|||||||
|
|
||||||
var res map[string]any
|
var res map[string]any
|
||||||
err := json.Unmarshal(recorder.Body.Bytes(), &res)
|
err := json.Unmarshal(recorder.Body.Bytes(), &res)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, "invalid_request", res["error"])
|
assert.Equal(t, "invalid_request", res["error"])
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -466,7 +479,7 @@ func TestOIDCController(t *testing.T) {
|
|||||||
|
|
||||||
var res map[string]any
|
var res map[string]any
|
||||||
err := json.Unmarshal(recorder.Body.Bytes(), &res)
|
err := json.Unmarshal(recorder.Body.Bytes(), &res)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, "invalid_request", res["error"])
|
assert.Equal(t, "invalid_request", res["error"])
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -481,7 +494,7 @@ func TestOIDCController(t *testing.T) {
|
|||||||
|
|
||||||
var res map[string]any
|
var res map[string]any
|
||||||
err := json.Unmarshal(recorder.Body.Bytes(), &res)
|
err := json.Unmarshal(recorder.Body.Bytes(), &res)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, "invalid_grant", res["error"])
|
assert.Equal(t, "invalid_grant", res["error"])
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -496,7 +509,7 @@ func TestOIDCController(t *testing.T) {
|
|||||||
|
|
||||||
var res map[string]any
|
var res map[string]any
|
||||||
err := json.Unmarshal(recorder.Body.Bytes(), &res)
|
err := json.Unmarshal(recorder.Body.Bytes(), &res)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, "invalid_request", res["error"])
|
assert.Equal(t, "invalid_request", res["error"])
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -511,7 +524,7 @@ func TestOIDCController(t *testing.T) {
|
|||||||
|
|
||||||
var res map[string]any
|
var res map[string]any
|
||||||
err := json.Unmarshal(recorder.Body.Bytes(), &res)
|
err := json.Unmarshal(recorder.Body.Bytes(), &res)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, "invalid_request", res["error"])
|
assert.Equal(t, "invalid_request", res["error"])
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -528,7 +541,7 @@ func TestOIDCController(t *testing.T) {
|
|||||||
|
|
||||||
var tokenRes map[string]any
|
var tokenRes map[string]any
|
||||||
err := json.Unmarshal(tokenRecorder.Body.Bytes(), &tokenRes)
|
err := json.Unmarshal(tokenRecorder.Body.Bytes(), &tokenRes)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
accessToken := tokenRes["access_token"].(string)
|
accessToken := tokenRes["access_token"].(string)
|
||||||
assert.NotEmpty(t, accessToken)
|
assert.NotEmpty(t, accessToken)
|
||||||
@@ -542,7 +555,7 @@ func TestOIDCController(t *testing.T) {
|
|||||||
|
|
||||||
var userInfoRes map[string]any
|
var userInfoRes map[string]any
|
||||||
err = json.Unmarshal(recorder.Body.Bytes(), &userInfoRes)
|
err = json.Unmarshal(recorder.Body.Bytes(), &userInfoRes)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
_, ok := userInfoRes["sub"]
|
_, ok := userInfoRes["sub"]
|
||||||
assert.True(t, ok, "Expected sub claim in userinfo response")
|
assert.True(t, ok, "Expected sub claim in userinfo response")
|
||||||
@@ -566,7 +579,7 @@ func TestOIDCController(t *testing.T) {
|
|||||||
CodeChallengeMethod: "",
|
CodeChallengeMethod: "",
|
||||||
}
|
}
|
||||||
reqBodyBytes, err := json.Marshal(reqBody)
|
reqBodyBytes, err := json.Marshal(reqBody)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes)))
|
req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes)))
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
@@ -575,11 +588,11 @@ func TestOIDCController(t *testing.T) {
|
|||||||
|
|
||||||
var res map[string]any
|
var res map[string]any
|
||||||
err = json.Unmarshal(recorder.Body.Bytes(), &res)
|
err = json.Unmarshal(recorder.Body.Bytes(), &res)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
redirectURI := res["redirect_uri"].(string)
|
redirectURI := res["redirect_uri"].(string)
|
||||||
url, err := url.Parse(redirectURI)
|
url, err := url.Parse(redirectURI)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
queryParams := url.Query()
|
queryParams := url.Query()
|
||||||
assert.Equal(t, queryParams.Get("state"), "some-state")
|
assert.Equal(t, queryParams.Get("state"), "some-state")
|
||||||
@@ -596,7 +609,7 @@ func TestOIDCController(t *testing.T) {
|
|||||||
CodeVerifier: "some-challenge",
|
CodeVerifier: "some-challenge",
|
||||||
}
|
}
|
||||||
reqBodyEncoded, err := query.Values(tokenReqBody)
|
reqBodyEncoded, err := query.Values(tokenReqBody)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
req = httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
|
req = httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
|
||||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
@@ -627,7 +640,7 @@ func TestOIDCController(t *testing.T) {
|
|||||||
CodeChallengeMethod: "S256",
|
CodeChallengeMethod: "S256",
|
||||||
}
|
}
|
||||||
reqBodyBytes, err := json.Marshal(reqBody)
|
reqBodyBytes, err := json.Marshal(reqBody)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes)))
|
req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes)))
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
@@ -636,11 +649,11 @@ func TestOIDCController(t *testing.T) {
|
|||||||
|
|
||||||
var res map[string]any
|
var res map[string]any
|
||||||
err = json.Unmarshal(recorder.Body.Bytes(), &res)
|
err = json.Unmarshal(recorder.Body.Bytes(), &res)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
redirectURI := res["redirect_uri"].(string)
|
redirectURI := res["redirect_uri"].(string)
|
||||||
url, err := url.Parse(redirectURI)
|
url, err := url.Parse(redirectURI)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
queryParams := url.Query()
|
queryParams := url.Query()
|
||||||
assert.Equal(t, queryParams.Get("state"), "some-state")
|
assert.Equal(t, queryParams.Get("state"), "some-state")
|
||||||
@@ -657,7 +670,7 @@ func TestOIDCController(t *testing.T) {
|
|||||||
CodeVerifier: "some-challenge",
|
CodeVerifier: "some-challenge",
|
||||||
}
|
}
|
||||||
reqBodyEncoded, err := query.Values(tokenReqBody)
|
reqBodyEncoded, err := query.Values(tokenReqBody)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
req = httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
|
req = httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
|
||||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
@@ -688,7 +701,7 @@ func TestOIDCController(t *testing.T) {
|
|||||||
CodeChallengeMethod: "S256",
|
CodeChallengeMethod: "S256",
|
||||||
}
|
}
|
||||||
reqBodyBytes, err := json.Marshal(reqBody)
|
reqBodyBytes, err := json.Marshal(reqBody)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes)))
|
req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes)))
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
@@ -697,11 +710,11 @@ func TestOIDCController(t *testing.T) {
|
|||||||
|
|
||||||
var res map[string]any
|
var res map[string]any
|
||||||
err = json.Unmarshal(recorder.Body.Bytes(), &res)
|
err = json.Unmarshal(recorder.Body.Bytes(), &res)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
redirectURI := res["redirect_uri"].(string)
|
redirectURI := res["redirect_uri"].(string)
|
||||||
url, err := url.Parse(redirectURI)
|
url, err := url.Parse(redirectURI)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
queryParams := url.Query()
|
queryParams := url.Query()
|
||||||
assert.Equal(t, queryParams.Get("state"), "some-state")
|
assert.Equal(t, queryParams.Get("state"), "some-state")
|
||||||
@@ -718,7 +731,7 @@ func TestOIDCController(t *testing.T) {
|
|||||||
CodeVerifier: "some-challenge-1",
|
CodeVerifier: "some-challenge-1",
|
||||||
}
|
}
|
||||||
reqBodyEncoded, err := query.Values(tokenReqBody)
|
reqBodyEncoded, err := query.Values(tokenReqBody)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
req = httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
|
req = httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
|
||||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
@@ -749,7 +762,7 @@ func TestOIDCController(t *testing.T) {
|
|||||||
CodeChallengeMethod: "foo",
|
CodeChallengeMethod: "foo",
|
||||||
}
|
}
|
||||||
reqBodyBytes, err := json.Marshal(reqBody)
|
reqBodyBytes, err := json.Marshal(reqBody)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes)))
|
req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes)))
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
@@ -758,11 +771,11 @@ func TestOIDCController(t *testing.T) {
|
|||||||
|
|
||||||
var res map[string]any
|
var res map[string]any
|
||||||
err = json.Unmarshal(recorder.Body.Bytes(), &res)
|
err = json.Unmarshal(recorder.Body.Bytes(), &res)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
redirectURI := res["redirect_uri"].(string)
|
redirectURI := res["redirect_uri"].(string)
|
||||||
url, err := url.Parse(redirectURI)
|
url, err := url.Parse(redirectURI)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
queryParams := url.Query()
|
queryParams := url.Query()
|
||||||
error := queryParams.Get("error")
|
error := queryParams.Get("error")
|
||||||
@@ -781,11 +794,11 @@ func TestOIDCController(t *testing.T) {
|
|||||||
|
|
||||||
var res map[string]any
|
var res map[string]any
|
||||||
err := json.Unmarshal(recorder.Body.Bytes(), &res)
|
err := json.Unmarshal(recorder.Body.Bytes(), &res)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
redirectURI := res["redirect_uri"].(string)
|
redirectURI := res["redirect_uri"].(string)
|
||||||
url, err := url.Parse(redirectURI)
|
url, err := url.Parse(redirectURI)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
queryParams := url.Query()
|
queryParams := url.Query()
|
||||||
code := queryParams.Get("code")
|
code := queryParams.Get("code")
|
||||||
@@ -797,7 +810,7 @@ func TestOIDCController(t *testing.T) {
|
|||||||
RedirectURI: "https://test.example.com/callback",
|
RedirectURI: "https://test.example.com/callback",
|
||||||
}
|
}
|
||||||
reqBodyEncoded, err := query.Values(reqBody)
|
reqBodyEncoded, err := query.Values(reqBody)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
|
req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
|
||||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
@@ -808,7 +821,7 @@ func TestOIDCController(t *testing.T) {
|
|||||||
assert.Equal(t, 200, recorder.Code)
|
assert.Equal(t, 200, recorder.Code)
|
||||||
|
|
||||||
err = json.Unmarshal(recorder.Body.Bytes(), &res)
|
err = json.Unmarshal(recorder.Body.Bytes(), &res)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
accessToken := res["access_token"].(string)
|
accessToken := res["access_token"].(string)
|
||||||
assert.NotEmpty(t, accessToken)
|
assert.NotEmpty(t, accessToken)
|
||||||
@@ -833,22 +846,20 @@ func TestOIDCController(t *testing.T) {
|
|||||||
assert.Equal(t, 401, recorder.Code)
|
assert.Equal(t, 401, recorder.Code)
|
||||||
|
|
||||||
err = json.Unmarshal(recorder.Body.Bytes(), &res)
|
err = json.Unmarshal(recorder.Body.Bytes(), &res)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, "invalid_grant", res["error"])
|
assert.Equal(t, "invalid_grant", res["error"])
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
app := bootstrap.NewBootstrapApp(cfg)
|
app := bootstrap.NewBootstrapApp(model.Config{})
|
||||||
|
|
||||||
err := app.SetupDatabase()
|
db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db"))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
queries := repository.New(app.GetDB())
|
queries := repository.New(db)
|
||||||
|
oidcService := service.NewOIDCService(oidcServiceCfg, queries)
|
||||||
wg := &sync.WaitGroup{}
|
err = oidcService.Init()
|
||||||
|
|
||||||
oidcService, err := service.NewOIDCService(log, cfg, runtime, queries, context.TODO(), wg)
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
@@ -862,7 +873,8 @@ func TestOIDCController(t *testing.T) {
|
|||||||
group := router.Group("/api")
|
group := router.Group("/api")
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
controller.NewOIDCController(log, oidcService, runtime, group)
|
oidcController := controller.NewOIDCController(controllerCfg, oidcService, group)
|
||||||
|
oidcController.SetupRoutes()
|
||||||
|
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
|
|
||||||
@@ -871,6 +883,7 @@ func TestOIDCController(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
app.GetDB().Close()
|
err = db.Close()
|
||||||
|
require.NoError(t, err)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ import (
|
|||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/google/go-querystring/query"
|
"github.com/google/go-querystring/query"
|
||||||
@@ -50,31 +50,29 @@ type ProxyContext struct {
|
|||||||
ProxyType ProxyType
|
ProxyType ProxyType
|
||||||
}
|
}
|
||||||
|
|
||||||
type ProxyController struct {
|
type ProxyControllerConfig struct {
|
||||||
log *logger.Logger
|
AppURL string
|
||||||
runtime model.RuntimeConfig
|
|
||||||
acls *service.AccessControlsService
|
|
||||||
auth *service.AuthService
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewProxyController(
|
type ProxyController struct {
|
||||||
log *logger.Logger,
|
config ProxyControllerConfig
|
||||||
runtime model.RuntimeConfig,
|
router *gin.RouterGroup
|
||||||
router *gin.RouterGroup,
|
acls *service.AccessControlsService
|
||||||
acls *service.AccessControlsService,
|
auth *service.AuthService
|
||||||
auth *service.AuthService,
|
}
|
||||||
) *ProxyController {
|
|
||||||
controller := &ProxyController{
|
func NewProxyController(config ProxyControllerConfig, router *gin.RouterGroup, acls *service.AccessControlsService, auth *service.AuthService) *ProxyController {
|
||||||
log: log,
|
return &ProxyController{
|
||||||
runtime: runtime,
|
config: config,
|
||||||
acls: acls,
|
router: router,
|
||||||
auth: auth,
|
acls: acls,
|
||||||
|
auth: auth,
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
proxyGroup := router.Group("/auth")
|
func (controller *ProxyController) SetupRoutes() {
|
||||||
|
proxyGroup := controller.router.Group("/auth")
|
||||||
proxyGroup.Any("/:proxy", controller.proxyHandler)
|
proxyGroup.Any("/:proxy", controller.proxyHandler)
|
||||||
|
|
||||||
return controller
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
||||||
@@ -82,7 +80,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
|||||||
proxyCtx, err := controller.getProxyContext(c)
|
proxyCtx, err := controller.getProxyContext(c)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Error().Err(err).Msg("Failed to get proxy context from request")
|
tlog.App.Warn().Err(err).Msg("Failed to get proxy context")
|
||||||
c.JSON(400, gin.H{
|
c.JSON(400, gin.H{
|
||||||
"status": 400,
|
"status": 400,
|
||||||
"message": "Bad request",
|
"message": "Bad request",
|
||||||
@@ -90,15 +88,19 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
tlog.App.Trace().Interface("ctx", proxyCtx).Msg("Got proxy context")
|
||||||
|
|
||||||
// Get acls
|
// Get acls
|
||||||
acls, err := controller.acls.GetAccessControls(proxyCtx.Host)
|
acls, err := controller.acls.GetAccessControls(proxyCtx.Host)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Error().Err(err).Msg("Failed to get ACLs for resource")
|
tlog.App.Error().Err(err).Msg("Failed to get access controls for resource")
|
||||||
controller.handleError(c, proxyCtx)
|
controller.handleError(c, proxyCtx)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
tlog.App.Trace().Interface("acls", acls).Msg("ACLs for resource")
|
||||||
|
|
||||||
clientIP := c.ClientIP()
|
clientIP := c.ClientIP()
|
||||||
|
|
||||||
if controller.auth.IsBypassedIP(clientIP, acls) {
|
if controller.auth.IsBypassedIP(clientIP, acls) {
|
||||||
@@ -113,13 +115,13 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
|||||||
authEnabled, err := controller.auth.IsAuthEnabled(proxyCtx.Path, acls)
|
authEnabled, err := controller.auth.IsAuthEnabled(proxyCtx.Path, acls)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Error().Err(err).Msg("Failed to determine if authentication is enabled for resource")
|
tlog.App.Error().Err(err).Msg("Failed to check if auth is enabled for resource")
|
||||||
controller.handleError(c, proxyCtx)
|
controller.handleError(c, proxyCtx)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if !authEnabled {
|
if !authEnabled {
|
||||||
controller.log.App.Debug().Msg("Authentication is disabled for this resource, allowing access without authentication")
|
tlog.App.Debug().Msg("Authentication disabled for resource, allowing access")
|
||||||
controller.setHeaders(c, acls)
|
controller.setHeaders(c, acls)
|
||||||
c.JSON(200, gin.H{
|
c.JSON(200, gin.H{
|
||||||
"status": 200,
|
"status": 200,
|
||||||
@@ -135,12 +137,12 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Error().Err(err).Msg("Failed to encode unauthorized query")
|
tlog.App.Error().Err(err).Msg("Failed to encode unauthorized query")
|
||||||
controller.handleError(c, proxyCtx)
|
controller.handleError(c, proxyCtx)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
redirectURL := fmt.Sprintf("%s/unauthorized?%s", controller.runtime.AppURL, queries.Encode())
|
redirectURL := fmt.Sprintf("%s/unauthorized?%s", controller.config.AppURL, queries.Encode())
|
||||||
|
|
||||||
if !controller.useBrowserResponse(proxyCtx) {
|
if !controller.useBrowserResponse(proxyCtx) {
|
||||||
c.Header("x-tinyauth-location", redirectURL)
|
c.Header("x-tinyauth-location", redirectURL)
|
||||||
@@ -158,24 +160,26 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
|||||||
userContext, err := new(model.UserContext).NewFromGin(c)
|
userContext, err := new(model.UserContext).NewFromGin(c)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Debug().Err(err).Msg("Failed to create user context from request, treating as unauthenticated")
|
tlog.App.Debug().Err(err).Msg("No user context found in request, treating as unauthenticated")
|
||||||
userContext = &model.UserContext{
|
userContext = &model.UserContext{
|
||||||
Authenticated: false,
|
Authenticated: false,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
tlog.App.Trace().Interface("context", userContext).Msg("User context from request")
|
||||||
|
|
||||||
if userContext.Authenticated {
|
if userContext.Authenticated {
|
||||||
userAllowed := controller.auth.IsUserAllowed(c, *userContext, acls)
|
userAllowed := controller.auth.IsUserAllowed(c, *userContext, acls)
|
||||||
|
|
||||||
if !userAllowed {
|
if !userAllowed {
|
||||||
controller.log.App.Warn().Str("user", userContext.GetUsername()).Str("resource", strings.Split(proxyCtx.Host, ".")[0]).Msg("User is not allowed to access resource")
|
tlog.App.Warn().Str("user", userContext.GetUsername()).Str("resource", strings.Split(proxyCtx.Host, ".")[0]).Msg("User not allowed to access resource")
|
||||||
|
|
||||||
queries, err := query.Values(UnauthorizedQuery{
|
queries, err := query.Values(UnauthorizedQuery{
|
||||||
Resource: strings.Split(proxyCtx.Host, ".")[0],
|
Resource: strings.Split(proxyCtx.Host, ".")[0],
|
||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Error().Err(err).Msg("Failed to encode unauthorized query")
|
tlog.App.Error().Err(err).Msg("Failed to encode unauthorized query")
|
||||||
controller.handleError(c, proxyCtx)
|
controller.handleError(c, proxyCtx)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -186,7 +190,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
|||||||
queries.Set("username", userContext.GetUsername())
|
queries.Set("username", userContext.GetUsername())
|
||||||
}
|
}
|
||||||
|
|
||||||
redirectURL := fmt.Sprintf("%s/unauthorized?%s", controller.runtime.AppURL, queries.Encode())
|
redirectURL := fmt.Sprintf("%s/unauthorized?%s", controller.config.AppURL, queries.Encode())
|
||||||
|
|
||||||
if !controller.useBrowserResponse(proxyCtx) {
|
if !controller.useBrowserResponse(proxyCtx) {
|
||||||
c.Header("x-tinyauth-location", redirectURL)
|
c.Header("x-tinyauth-location", redirectURL)
|
||||||
@@ -211,7 +215,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !groupOK {
|
if !groupOK {
|
||||||
controller.log.App.Warn().Str("user", userContext.GetUsername()).Str("resource", strings.Split(proxyCtx.Host, ".")[0]).Msg("User is not in the required group to access resource")
|
tlog.App.Warn().Str("user", userContext.GetUsername()).Str("resource", strings.Split(proxyCtx.Host, ".")[0]).Msg("User groups do not match resource requirements")
|
||||||
|
|
||||||
queries, err := query.Values(UnauthorizedQuery{
|
queries, err := query.Values(UnauthorizedQuery{
|
||||||
Resource: strings.Split(proxyCtx.Host, ".")[0],
|
Resource: strings.Split(proxyCtx.Host, ".")[0],
|
||||||
@@ -219,7 +223,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Error().Err(err).Msg("Failed to encode unauthorized query")
|
tlog.App.Error().Err(err).Msg("Failed to encode unauthorized query")
|
||||||
controller.handleError(c, proxyCtx)
|
controller.handleError(c, proxyCtx)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -230,7 +234,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
|||||||
queries.Set("username", userContext.GetUsername())
|
queries.Set("username", userContext.GetUsername())
|
||||||
}
|
}
|
||||||
|
|
||||||
redirectURL := fmt.Sprintf("%s/unauthorized?%s", controller.runtime.AppURL, queries.Encode())
|
redirectURL := fmt.Sprintf("%s/unauthorized?%s", controller.config.AppURL, queries.Encode())
|
||||||
|
|
||||||
if !controller.useBrowserResponse(proxyCtx) {
|
if !controller.useBrowserResponse(proxyCtx) {
|
||||||
c.Header("x-tinyauth-location", redirectURL)
|
c.Header("x-tinyauth-location", redirectURL)
|
||||||
@@ -273,12 +277,12 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Error().Err(err).Msg("Failed to encode redirect query")
|
tlog.App.Error().Err(err).Msg("Failed to encode redirect URI query")
|
||||||
controller.handleError(c, proxyCtx)
|
controller.handleError(c, proxyCtx)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
redirectURL := fmt.Sprintf("%s/login?%s", controller.runtime.AppURL, queries.Encode())
|
redirectURL := fmt.Sprintf("%s/login?%s", controller.config.AppURL, queries.Encode())
|
||||||
|
|
||||||
if !controller.useBrowserResponse(proxyCtx) {
|
if !controller.useBrowserResponse(proxyCtx) {
|
||||||
c.Header("x-tinyauth-location", redirectURL)
|
c.Header("x-tinyauth-location", redirectURL)
|
||||||
@@ -302,19 +306,20 @@ func (controller *ProxyController) setHeaders(c *gin.Context, acls *model.App) {
|
|||||||
headers := utils.ParseHeaders(acls.Response.Headers)
|
headers := utils.ParseHeaders(acls.Response.Headers)
|
||||||
|
|
||||||
for key, value := range headers {
|
for key, value := range headers {
|
||||||
|
tlog.App.Debug().Str("header", key).Msg("Setting header")
|
||||||
c.Header(key, value)
|
c.Header(key, value)
|
||||||
}
|
}
|
||||||
|
|
||||||
basicPassword := utils.GetSecret(acls.Response.BasicAuth.Password, acls.Response.BasicAuth.PasswordFile)
|
basicPassword := utils.GetSecret(acls.Response.BasicAuth.Password, acls.Response.BasicAuth.PasswordFile)
|
||||||
|
|
||||||
if acls.Response.BasicAuth.Username != "" && basicPassword != "" {
|
if acls.Response.BasicAuth.Username != "" && basicPassword != "" {
|
||||||
controller.log.App.Debug().Msg("Setting basic auth header for response")
|
tlog.App.Debug().Str("username", acls.Response.BasicAuth.Username).Msg("Setting basic auth header")
|
||||||
c.Header("Authorization", fmt.Sprintf("Basic %s", utils.EncodeBasicAuth(acls.Response.BasicAuth.Username, basicPassword)))
|
c.Header("Authorization", fmt.Sprintf("Basic %s", utils.EncodeBasicAuth(acls.Response.BasicAuth.Username, basicPassword)))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (controller *ProxyController) handleError(c *gin.Context, proxyCtx ProxyContext) {
|
func (controller *ProxyController) handleError(c *gin.Context, proxyCtx ProxyContext) {
|
||||||
redirectURL := fmt.Sprintf("%s/error", controller.runtime.AppURL)
|
redirectURL := fmt.Sprintf("%s/error", controller.config.AppURL)
|
||||||
|
|
||||||
if !controller.useBrowserResponse(proxyCtx) {
|
if !controller.useBrowserResponse(proxyCtx) {
|
||||||
c.Header("x-tinyauth-location", redirectURL)
|
c.Header("x-tinyauth-location", redirectURL)
|
||||||
@@ -515,7 +520,7 @@ func (controller *ProxyController) getProxyContext(c *gin.Context) (ProxyContext
|
|||||||
return ProxyContext{}, err
|
return ProxyContext{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
controller.log.App.Debug().Msgf("Determined proxy type: %v", proxy)
|
tlog.App.Debug().Msgf("Proxy: %v", req.Proxy)
|
||||||
|
|
||||||
authModules := controller.determineAuthModules(proxy)
|
authModules := controller.determineAuthModules(proxy)
|
||||||
|
|
||||||
@@ -526,13 +531,13 @@ func (controller *ProxyController) getProxyContext(c *gin.Context) (ProxyContext
|
|||||||
var ctx ProxyContext
|
var ctx ProxyContext
|
||||||
|
|
||||||
for _, module := range authModules {
|
for _, module := range authModules {
|
||||||
controller.log.App.Debug().Msgf("Trying to get context from auth module %v", module)
|
tlog.App.Debug().Msgf("Trying auth module: %v", module)
|
||||||
ctx, err = controller.getContextFromAuthModule(c, module)
|
ctx, err = controller.getContextFromAuthModule(c, module)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
controller.log.App.Debug().Msgf("Successfully got context from auth module %v", module)
|
tlog.App.Debug().Msgf("Auth module %v succeeded", module)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
controller.log.App.Debug().Msgf("Failed to get context from auth module %v: %v", module, err)
|
tlog.App.Debug().Err(err).Msgf("Auth module %v failed", module)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -544,9 +549,9 @@ func (controller *ProxyController) getProxyContext(c *gin.Context) (ProxyContext
|
|||||||
isBrowser := BrowserUserAgentRegex.MatchString(userAgent)
|
isBrowser := BrowserUserAgentRegex.MatchString(userAgent)
|
||||||
|
|
||||||
if isBrowser {
|
if isBrowser {
|
||||||
controller.log.App.Debug().Msg("Request identified as coming from a browser client")
|
tlog.App.Debug().Msg("Request identified as coming from a browser")
|
||||||
} else {
|
} else {
|
||||||
controller.log.App.Debug().Msg("Request identified as coming from a non-browser client")
|
tlog.App.Debug().Msg("Request identified as coming from a non-browser client")
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx.IsBrowser = isBrowser
|
ctx.IsBrowser = isBrowser
|
||||||
|
|||||||
@@ -1,9 +1,8 @@
|
|||||||
package controller_test
|
package controller_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"sync"
|
"path"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@@ -14,15 +13,35 @@ import (
|
|||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/test"
|
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestProxyController(t *testing.T) {
|
func TestProxyController(t *testing.T) {
|
||||||
log := logger.NewLogger().WithTestConfig()
|
tlog.NewTestLogger().Init()
|
||||||
log.Init()
|
tempDir := t.TempDir()
|
||||||
|
|
||||||
cfg, runtime := test.CreateTestConfigs(t)
|
authServiceCfg := service.AuthServiceConfig{
|
||||||
|
LocalUsers: &[]model.LocalUser{
|
||||||
|
{
|
||||||
|
Username: "testuser",
|
||||||
|
Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Username: "totpuser",
|
||||||
|
Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
|
||||||
|
TOTPSecret: "JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
SessionExpiry: 10, // 10 seconds, useful for testing
|
||||||
|
CookieDomain: "example.com",
|
||||||
|
LoginTimeout: 10, // 10 seconds, useful for testing
|
||||||
|
LoginMaxRetries: 3,
|
||||||
|
SessionCookieName: "tinyauth-session",
|
||||||
|
}
|
||||||
|
|
||||||
|
controllerCfg := controller.ProxyControllerConfig{
|
||||||
|
AppURL: "https://tinyauth.example.com",
|
||||||
|
}
|
||||||
|
|
||||||
acls := map[string]model.App{
|
acls := map[string]model.App{
|
||||||
"app_path_allow": {
|
"app_path_allow": {
|
||||||
@@ -379,19 +398,32 @@ func TestProxyController(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
app := bootstrap.NewBootstrapApp(cfg)
|
oauthBrokerCfgs := make(map[string]model.OAuthServiceConfig)
|
||||||
|
|
||||||
err := app.SetupDatabase()
|
app := bootstrap.NewBootstrapApp(model.Config{})
|
||||||
|
|
||||||
|
db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db"))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
queries := repository.New(app.GetDB())
|
queries := repository.New(db)
|
||||||
|
|
||||||
wg := &sync.WaitGroup{}
|
docker := service.NewDockerService()
|
||||||
ctx := context.TODO()
|
err = docker.Init()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx)
|
ldap := service.NewLdapService(service.LdapServiceConfig{})
|
||||||
authService := service.NewAuthService(log, cfg, runtime, ctx, wg, nil, queries, broker)
|
err = ldap.Init()
|
||||||
aclsService := service.NewAccessControlsService(log, nil, acls)
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
broker := service.NewOAuthBrokerService(oauthBrokerCfgs)
|
||||||
|
err = broker.Init()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
authService := service.NewAuthService(authServiceCfg, ldap, queries, broker)
|
||||||
|
err = authService.Init()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
aclsService := service.NewAccessControlsService(docker, acls)
|
||||||
|
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
t.Run(test.description, func(t *testing.T) {
|
t.Run(test.description, func(t *testing.T) {
|
||||||
@@ -406,13 +438,15 @@ func TestProxyController(t *testing.T) {
|
|||||||
|
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
|
|
||||||
controller.NewProxyController(log, runtime, group, aclsService, authService)
|
proxyController := controller.NewProxyController(controllerCfg, group, aclsService, authService)
|
||||||
|
proxyController.SetupRoutes()
|
||||||
|
|
||||||
test.run(t, router, recorder)
|
test.run(t, router, recorder)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
app.GetDB().Close()
|
err = db.Close()
|
||||||
|
require.NoError(t, err)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,39 +4,42 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type ResourcesControllerConfig struct {
|
||||||
|
Path string
|
||||||
|
Enabled bool
|
||||||
|
}
|
||||||
|
|
||||||
type ResourcesController struct {
|
type ResourcesController struct {
|
||||||
config model.Config
|
config ResourcesControllerConfig
|
||||||
|
router *gin.RouterGroup
|
||||||
fileServer http.Handler
|
fileServer http.Handler
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewResourcesController(
|
func NewResourcesController(config ResourcesControllerConfig, router *gin.RouterGroup) *ResourcesController {
|
||||||
config model.Config,
|
fileServer := http.StripPrefix("/resources", http.FileServer(http.Dir(config.Path)))
|
||||||
router *gin.RouterGroup,
|
|
||||||
) *ResourcesController {
|
|
||||||
fileServer := http.StripPrefix("/resources", http.FileServer(http.Dir(config.Resources.Path)))
|
|
||||||
|
|
||||||
controller := &ResourcesController{
|
return &ResourcesController{
|
||||||
config: config,
|
config: config,
|
||||||
|
router: router,
|
||||||
fileServer: fileServer,
|
fileServer: fileServer,
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
router.GET("/resources/*resource", controller.resourcesHandler)
|
func (controller *ResourcesController) SetupRoutes() {
|
||||||
|
controller.router.GET("/resources/*resource", controller.resourcesHandler)
|
||||||
return controller
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (controller *ResourcesController) resourcesHandler(c *gin.Context) {
|
func (controller *ResourcesController) resourcesHandler(c *gin.Context) {
|
||||||
if controller.config.Resources.Path == "" {
|
if controller.config.Path == "" {
|
||||||
c.JSON(404, gin.H{
|
c.JSON(404, gin.H{
|
||||||
"status": 404,
|
"status": 404,
|
||||||
"message": "Resources not found",
|
"message": "Resources not found",
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if !controller.config.Resources.Enabled {
|
if !controller.config.Enabled {
|
||||||
c.JSON(403, gin.H{
|
c.JSON(403, gin.H{
|
||||||
"status": 403,
|
"status": 403,
|
||||||
"message": "Resources are disabled",
|
"message": "Resources are disabled",
|
||||||
|
|||||||
@@ -3,20 +3,26 @@ package controller_test
|
|||||||
import (
|
import (
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/controller"
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/controller"
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/test"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestResourcesController(t *testing.T) {
|
func TestResourcesController(t *testing.T) {
|
||||||
cfg, _ := test.CreateTestConfigs(t)
|
tlog.NewTestLogger().Init()
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
|
||||||
err := os.MkdirAll(cfg.Resources.Path, 0777)
|
resourcesControllerCfg := controller.ResourcesControllerConfig{
|
||||||
|
Path: path.Join(tempDir, "resources"),
|
||||||
|
Enabled: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
err := os.Mkdir(resourcesControllerCfg.Path, 0777)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
type testCase struct {
|
type testCase struct {
|
||||||
@@ -55,11 +61,11 @@ func TestResourcesController(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
testFilePath := cfg.Resources.Path + "/testfile.txt"
|
testFilePath := resourcesControllerCfg.Path + "/testfile.txt"
|
||||||
err = os.WriteFile(testFilePath, []byte("This is a test file."), 0777)
|
err = os.WriteFile(testFilePath, []byte("This is a test file."), 0777)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
testFilePathParent := filepath.Dir(cfg.Resources.Path) + "/somefile.txt"
|
testFilePathParent := tempDir + "/somefile.txt"
|
||||||
err = os.WriteFile(testFilePathParent, []byte("This file should not be accessible."), 0777)
|
err = os.WriteFile(testFilePathParent, []byte("This file should not be accessible."), 0777)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
@@ -69,7 +75,8 @@ func TestResourcesController(t *testing.T) {
|
|||||||
group := router.Group("/")
|
group := router.Group("/")
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
controller.NewResourcesController(cfg, group)
|
resourcesController := controller.NewResourcesController(resourcesControllerCfg, group)
|
||||||
|
resourcesController.SetupRoutes()
|
||||||
|
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
test.run(t, router, recorder)
|
test.run(t, router, recorder)
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ import (
|
|||||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/pquerna/otp/totp"
|
"github.com/pquerna/otp/totp"
|
||||||
@@ -25,30 +25,30 @@ type TotpRequest struct {
|
|||||||
Code string `json:"code"`
|
Code string `json:"code"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type UserController struct {
|
type UserControllerConfig struct {
|
||||||
log *logger.Logger
|
CookieDomain string
|
||||||
runtime model.RuntimeConfig
|
SessionCookieName string
|
||||||
auth *service.AuthService
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewUserController(
|
type UserController struct {
|
||||||
log *logger.Logger,
|
config UserControllerConfig
|
||||||
runtimeConfig model.RuntimeConfig,
|
router *gin.RouterGroup
|
||||||
router *gin.RouterGroup,
|
auth *service.AuthService
|
||||||
auth *service.AuthService,
|
}
|
||||||
) *UserController {
|
|
||||||
controller := &UserController{
|
|
||||||
log: log,
|
|
||||||
runtime: runtimeConfig,
|
|
||||||
auth: auth,
|
|
||||||
}
|
|
||||||
|
|
||||||
userGroup := router.Group("/user")
|
func NewUserController(config UserControllerConfig, router *gin.RouterGroup, auth *service.AuthService) *UserController {
|
||||||
|
return &UserController{
|
||||||
|
config: config,
|
||||||
|
router: router,
|
||||||
|
auth: auth,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (controller *UserController) SetupRoutes() {
|
||||||
|
userGroup := controller.router.Group("/user")
|
||||||
userGroup.POST("/login", controller.loginHandler)
|
userGroup.POST("/login", controller.loginHandler)
|
||||||
userGroup.POST("/logout", controller.logoutHandler)
|
userGroup.POST("/logout", controller.logoutHandler)
|
||||||
userGroup.POST("/totp", controller.totpHandler)
|
userGroup.POST("/totp", controller.totpHandler)
|
||||||
|
|
||||||
return controller
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (controller *UserController) loginHandler(c *gin.Context) {
|
func (controller *UserController) loginHandler(c *gin.Context) {
|
||||||
@@ -56,7 +56,7 @@ func (controller *UserController) loginHandler(c *gin.Context) {
|
|||||||
|
|
||||||
err := c.ShouldBindJSON(&req)
|
err := c.ShouldBindJSON(&req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Error().Err(err).Msg("Failed to bind JSON")
|
tlog.App.Error().Err(err).Msg("Failed to bind JSON")
|
||||||
c.JSON(400, gin.H{
|
c.JSON(400, gin.H{
|
||||||
"status": 400,
|
"status": 400,
|
||||||
"message": "Bad Request",
|
"message": "Bad Request",
|
||||||
@@ -64,13 +64,13 @@ func (controller *UserController) loginHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
controller.log.App.Debug().Str("username", req.Username).Msg("Login attempt")
|
tlog.App.Debug().Str("username", req.Username).Msg("Login attempt")
|
||||||
|
|
||||||
isLocked, remaining := controller.auth.IsAccountLocked(req.Username)
|
isLocked, remaining := controller.auth.IsAccountLocked(req.Username)
|
||||||
|
|
||||||
if isLocked {
|
if isLocked {
|
||||||
controller.log.App.Warn().Str("username", req.Username).Msg("Account is locked due to too many failed login attempts")
|
tlog.App.Warn().Str("username", req.Username).Msg("Account is locked due to too many failed login attempts")
|
||||||
controller.log.AuditLoginFailure(req.Username, "local", c.ClientIP(), "account locked")
|
tlog.AuditLoginFailure(c, req.Username, "username", "account locked")
|
||||||
c.Writer.Header().Add("x-tinyauth-lock-locked", "true")
|
c.Writer.Header().Add("x-tinyauth-lock-locked", "true")
|
||||||
c.Writer.Header().Add("x-tinyauth-lock-reset", time.Now().Add(time.Duration(remaining)*time.Second).Format(time.RFC3339))
|
c.Writer.Header().Add("x-tinyauth-lock-reset", time.Now().Add(time.Duration(remaining)*time.Second).Format(time.RFC3339))
|
||||||
c.JSON(429, gin.H{
|
c.JSON(429, gin.H{
|
||||||
@@ -84,16 +84,16 @@ func (controller *UserController) loginHandler(c *gin.Context) {
|
|||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, service.ErrUserNotFound) {
|
if errors.Is(err, service.ErrUserNotFound) {
|
||||||
controller.log.App.Warn().Str("username", req.Username).Msg("User not found during login attempt")
|
tlog.App.Warn().Str("username", req.Username).Msg("User not found")
|
||||||
controller.auth.RecordLoginAttempt(req.Username, false)
|
controller.auth.RecordLoginAttempt(req.Username, false)
|
||||||
controller.log.AuditLoginFailure(req.Username, "unknown", c.ClientIP(), "user not found")
|
tlog.AuditLoginFailure(c, req.Username, "username", "user not found")
|
||||||
c.JSON(401, gin.H{
|
c.JSON(401, gin.H{
|
||||||
"status": 401,
|
"status": 401,
|
||||||
"message": "Unauthorized",
|
"message": "Unauthorized",
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
controller.log.App.Error().Err(err).Str("username", req.Username).Msg("Error searching for user during login attempt")
|
tlog.App.Error().Err(err).Str("username", req.Username).Msg("Error searching for user")
|
||||||
c.JSON(500, gin.H{
|
c.JSON(500, gin.H{
|
||||||
"status": 500,
|
"status": 500,
|
||||||
"message": "Internal Server Error",
|
"message": "Internal Server Error",
|
||||||
@@ -102,13 +102,9 @@ func (controller *UserController) loginHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err := controller.auth.CheckUserPassword(*search, req.Password); err != nil {
|
if err := controller.auth.CheckUserPassword(*search, req.Password); err != nil {
|
||||||
controller.log.App.Warn().Str("username", req.Username).Msg("Invalid password during login attempt")
|
tlog.App.Warn().Err(err).Str("username", req.Username).Msg("Failed to verify password")
|
||||||
controller.auth.RecordLoginAttempt(req.Username, false)
|
controller.auth.RecordLoginAttempt(req.Username, false)
|
||||||
if search.Type == model.UserLocal {
|
tlog.AuditLoginFailure(c, req.Username, "username", "invalid password")
|
||||||
controller.log.AuditLoginFailure(req.Username, "local", c.ClientIP(), "invalid password")
|
|
||||||
} else {
|
|
||||||
controller.log.AuditLoginFailure(req.Username, "ldap", c.ClientIP(), "invalid password")
|
|
||||||
}
|
|
||||||
c.JSON(401, gin.H{
|
c.JSON(401, gin.H{
|
||||||
"status": 401,
|
"status": 401,
|
||||||
"message": "Unauthorized",
|
"message": "Unauthorized",
|
||||||
@@ -122,7 +118,7 @@ func (controller *UserController) loginHandler(c *gin.Context) {
|
|||||||
localUser = controller.auth.GetLocalUser(req.Username)
|
localUser = controller.auth.GetLocalUser(req.Username)
|
||||||
|
|
||||||
if localUser == nil {
|
if localUser == nil {
|
||||||
controller.log.App.Error().Str("username", req.Username).Msg("Local user not found after successful password verification")
|
tlog.App.Warn().Str("username", req.Username).Msg("User disappeared during login")
|
||||||
c.JSON(401, gin.H{
|
c.JSON(401, gin.H{
|
||||||
"status": 401,
|
"status": 401,
|
||||||
"message": "Unauthorized",
|
"message": "Unauthorized",
|
||||||
@@ -131,7 +127,7 @@ func (controller *UserController) loginHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if localUser.TOTPSecret != "" {
|
if localUser.TOTPSecret != "" {
|
||||||
controller.log.App.Debug().Str("username", req.Username).Msg("TOTP required for user, creating pending TOTP session")
|
tlog.App.Debug().Str("username", req.Username).Msg("User has TOTP enabled, requiring TOTP verification")
|
||||||
|
|
||||||
name := localUser.Attributes.Name
|
name := localUser.Attributes.Name
|
||||||
if name == "" {
|
if name == "" {
|
||||||
@@ -140,7 +136,7 @@ func (controller *UserController) loginHandler(c *gin.Context) {
|
|||||||
|
|
||||||
email := localUser.Attributes.Email
|
email := localUser.Attributes.Email
|
||||||
if email == "" {
|
if email == "" {
|
||||||
email = utils.CompileUserEmail(localUser.Username, controller.runtime.CookieDomain)
|
email = utils.CompileUserEmail(localUser.Username, controller.config.CookieDomain)
|
||||||
}
|
}
|
||||||
|
|
||||||
cookie, err := controller.auth.CreateSession(c, repository.Session{
|
cookie, err := controller.auth.CreateSession(c, repository.Session{
|
||||||
@@ -152,7 +148,7 @@ func (controller *UserController) loginHandler(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Error().Err(err).Str("username", req.Username).Msg("Failed to create pending TOTP session")
|
tlog.App.Error().Err(err).Msg("Failed to create session cookie")
|
||||||
c.JSON(500, gin.H{
|
c.JSON(500, gin.H{
|
||||||
"status": 500,
|
"status": 500,
|
||||||
"message": "Internal Server Error",
|
"message": "Internal Server Error",
|
||||||
@@ -174,7 +170,7 @@ func (controller *UserController) loginHandler(c *gin.Context) {
|
|||||||
sessionCookie := repository.Session{
|
sessionCookie := repository.Session{
|
||||||
Username: req.Username,
|
Username: req.Username,
|
||||||
Name: utils.Capitalize(req.Username),
|
Name: utils.Capitalize(req.Username),
|
||||||
Email: utils.CompileUserEmail(req.Username, controller.runtime.CookieDomain),
|
Email: utils.CompileUserEmail(req.Username, controller.config.CookieDomain),
|
||||||
Provider: "local",
|
Provider: "local",
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -191,10 +187,12 @@ func (controller *UserController) loginHandler(c *gin.Context) {
|
|||||||
sessionCookie.Provider = "ldap"
|
sessionCookie.Provider = "ldap"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
tlog.App.Trace().Interface("session_cookie", sessionCookie).Msg("Creating session cookie")
|
||||||
|
|
||||||
cookie, err := controller.auth.CreateSession(c, sessionCookie)
|
cookie, err := controller.auth.CreateSession(c, sessionCookie)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Error().Err(err).Str("username", req.Username).Msg("Failed to create session cookie after successful login")
|
tlog.App.Error().Err(err).Msg("Failed to create session cookie")
|
||||||
c.JSON(500, gin.H{
|
c.JSON(500, gin.H{
|
||||||
"status": 500,
|
"status": 500,
|
||||||
"message": "Internal Server Error",
|
"message": "Internal Server Error",
|
||||||
@@ -204,13 +202,8 @@ func (controller *UserController) loginHandler(c *gin.Context) {
|
|||||||
|
|
||||||
http.SetCookie(c.Writer, cookie)
|
http.SetCookie(c.Writer, cookie)
|
||||||
|
|
||||||
controller.log.App.Info().Str("username", req.Username).Msg("Login successful")
|
tlog.App.Info().Str("username", req.Username).Msg("Login successful")
|
||||||
|
tlog.AuditLoginSuccess(c, req.Username, "username")
|
||||||
if search.Type == model.UserLocal {
|
|
||||||
controller.log.AuditLoginSuccess(req.Username, "local", c.ClientIP())
|
|
||||||
} else {
|
|
||||||
controller.log.AuditLoginSuccess(req.Username, "ldap", c.ClientIP())
|
|
||||||
}
|
|
||||||
|
|
||||||
controller.auth.RecordLoginAttempt(req.Username, true)
|
controller.auth.RecordLoginAttempt(req.Username, true)
|
||||||
|
|
||||||
@@ -221,20 +214,20 @@ func (controller *UserController) loginHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (controller *UserController) logoutHandler(c *gin.Context) {
|
func (controller *UserController) logoutHandler(c *gin.Context) {
|
||||||
controller.log.App.Debug().Msg("Logout attempt")
|
tlog.App.Debug().Msg("Logout request received")
|
||||||
|
|
||||||
uuid, err := c.Cookie(controller.runtime.SessionCookieName)
|
uuid, err := c.Cookie(controller.config.SessionCookieName)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, http.ErrNoCookie) {
|
if errors.Is(err, http.ErrNoCookie) {
|
||||||
controller.log.App.Warn().Msg("Logout attempt without session cookie, treating as successful logout")
|
tlog.App.Warn().Msg("No session cookie found on logout request")
|
||||||
c.JSON(200, gin.H{
|
c.JSON(200, gin.H{
|
||||||
"status": 200,
|
"status": 200,
|
||||||
"message": "Logout successful",
|
"message": "Logout successful",
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
controller.log.App.Error().Err(err).Msg("Error retrieving session cookie on logout")
|
tlog.App.Error().Err(err).Msg("Error retrieving session cookie on logout")
|
||||||
c.JSON(500, gin.H{
|
c.JSON(500, gin.H{
|
||||||
"status": 500,
|
"status": 500,
|
||||||
"message": "Internal Server Error",
|
"message": "Internal Server Error",
|
||||||
@@ -245,7 +238,7 @@ func (controller *UserController) logoutHandler(c *gin.Context) {
|
|||||||
cookie, err := controller.auth.DeleteSession(c, uuid)
|
cookie, err := controller.auth.DeleteSession(c, uuid)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Error().Err(err).Msg("Error deleting session on logout")
|
tlog.App.Error().Err(err).Msg("Error deleting session on logout")
|
||||||
c.JSON(500, gin.H{
|
c.JSON(500, gin.H{
|
||||||
"status": 500,
|
"status": 500,
|
||||||
"message": "Internal Server Error",
|
"message": "Internal Server Error",
|
||||||
@@ -256,10 +249,10 @@ func (controller *UserController) logoutHandler(c *gin.Context) {
|
|||||||
context, err := new(model.UserContext).NewFromGin(c)
|
context, err := new(model.UserContext).NewFromGin(c)
|
||||||
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
controller.log.AuditLogout(context.GetUsername(), context.GetProviderID(), c.ClientIP())
|
tlog.AuditLogout(c, context.GetUsername(), context.GetProviderID())
|
||||||
} else {
|
} else {
|
||||||
controller.log.App.Warn().Err(err).Msg("Failed to get user context during logout, logging audit with unknown user")
|
tlog.App.Warn().Err(err).Msg("Failed to get user context for logout audit, proceeding without username")
|
||||||
controller.log.AuditLogout("unknown", "unknown", c.ClientIP())
|
tlog.AuditLogout(c, "unknown", "unknown")
|
||||||
}
|
}
|
||||||
|
|
||||||
http.SetCookie(c.Writer, cookie)
|
http.SetCookie(c.Writer, cookie)
|
||||||
@@ -275,7 +268,7 @@ func (controller *UserController) totpHandler(c *gin.Context) {
|
|||||||
|
|
||||||
err := c.ShouldBindJSON(&req)
|
err := c.ShouldBindJSON(&req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Error().Err(err).Msg("Failed to bind JSON for TOTP verification")
|
tlog.App.Error().Err(err).Msg("Failed to bind JSON")
|
||||||
c.JSON(400, gin.H{
|
c.JSON(400, gin.H{
|
||||||
"status": 400,
|
"status": 400,
|
||||||
"message": "Bad Request",
|
"message": "Bad Request",
|
||||||
@@ -286,7 +279,7 @@ func (controller *UserController) totpHandler(c *gin.Context) {
|
|||||||
context, err := new(model.UserContext).NewFromGin(c)
|
context, err := new(model.UserContext).NewFromGin(c)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Error().Err(err).Msg("Failed to create user context from request for TOTP verification")
|
tlog.App.Error().Err(err).Msg("Failed to get user context")
|
||||||
c.JSON(500, gin.H{
|
c.JSON(500, gin.H{
|
||||||
"status": 500,
|
"status": 500,
|
||||||
"message": "Internal Server Error",
|
"message": "Internal Server Error",
|
||||||
@@ -295,7 +288,7 @@ func (controller *UserController) totpHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !context.TOTPPending() {
|
if !context.TOTPPending() {
|
||||||
controller.log.App.Warn().Str("username", context.GetUsername()).Msg("TOTP verification attempt without pending TOTP session")
|
tlog.App.Warn().Msg("TOTP attempt without a pending TOTP session")
|
||||||
c.JSON(401, gin.H{
|
c.JSON(401, gin.H{
|
||||||
"status": 401,
|
"status": 401,
|
||||||
"message": "Unauthorized",
|
"message": "Unauthorized",
|
||||||
@@ -303,13 +296,12 @@ func (controller *UserController) totpHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
controller.log.App.Debug().Str("username", context.GetUsername()).Msg("TOTP verification attempt")
|
tlog.App.Debug().Str("username", context.GetUsername()).Msg("TOTP verification attempt")
|
||||||
|
|
||||||
isLocked, remaining := controller.auth.IsAccountLocked(context.GetUsername())
|
isLocked, remaining := controller.auth.IsAccountLocked(context.GetUsername())
|
||||||
|
|
||||||
if isLocked {
|
if isLocked {
|
||||||
controller.log.App.Warn().Str("username", context.GetUsername()).Msg("Account is locked due to too many failed TOTP attempts")
|
tlog.App.Warn().Str("username", context.GetUsername()).Msg("Account is locked due to too many failed TOTP attempts")
|
||||||
controller.log.AuditLoginFailure(context.GetUsername(), "local", c.ClientIP(), "account locked")
|
|
||||||
c.Writer.Header().Add("x-tinyauth-lock-locked", "true")
|
c.Writer.Header().Add("x-tinyauth-lock-locked", "true")
|
||||||
c.Writer.Header().Add("x-tinyauth-lock-reset", time.Now().Add(time.Duration(remaining)*time.Second).Format(time.RFC3339))
|
c.Writer.Header().Add("x-tinyauth-lock-reset", time.Now().Add(time.Duration(remaining)*time.Second).Format(time.RFC3339))
|
||||||
c.JSON(429, gin.H{
|
c.JSON(429, gin.H{
|
||||||
@@ -322,7 +314,7 @@ func (controller *UserController) totpHandler(c *gin.Context) {
|
|||||||
user := controller.auth.GetLocalUser(context.GetUsername())
|
user := controller.auth.GetLocalUser(context.GetUsername())
|
||||||
|
|
||||||
if user == nil {
|
if user == nil {
|
||||||
controller.log.App.Error().Str("username", context.GetUsername()).Msg("Local user not found during TOTP verification")
|
tlog.App.Error().Str("username", context.GetUsername()).Msg("User not found in TOTP handler")
|
||||||
c.JSON(401, gin.H{
|
c.JSON(401, gin.H{
|
||||||
"status": 401,
|
"status": 401,
|
||||||
"message": "Unauthorized",
|
"message": "Unauthorized",
|
||||||
@@ -333,9 +325,9 @@ func (controller *UserController) totpHandler(c *gin.Context) {
|
|||||||
ok := totp.Validate(req.Code, user.TOTPSecret)
|
ok := totp.Validate(req.Code, user.TOTPSecret)
|
||||||
|
|
||||||
if !ok {
|
if !ok {
|
||||||
controller.log.App.Warn().Str("username", context.GetUsername()).Msg("Invalid TOTP code during verification attempt")
|
tlog.App.Warn().Str("username", context.GetUsername()).Msg("Invalid TOTP code")
|
||||||
controller.auth.RecordLoginAttempt(context.GetUsername(), false)
|
controller.auth.RecordLoginAttempt(context.GetUsername(), false)
|
||||||
controller.log.AuditLoginFailure(context.GetUsername(), "local", c.ClientIP(), "invalid TOTP code")
|
tlog.AuditLoginFailure(c, context.GetUsername(), "totp", "invalid totp code")
|
||||||
c.JSON(401, gin.H{
|
c.JSON(401, gin.H{
|
||||||
"status": 401,
|
"status": 401,
|
||||||
"message": "Unauthorized",
|
"message": "Unauthorized",
|
||||||
@@ -343,15 +335,15 @@ func (controller *UserController) totpHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
uuid, err := c.Cookie(controller.runtime.SessionCookieName)
|
uuid, err := c.Cookie(controller.config.SessionCookieName)
|
||||||
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
_, err = controller.auth.DeleteSession(c, uuid)
|
_, err = controller.auth.DeleteSession(c, uuid)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Error().Err(err).Msg("Failed to delete pending TOTP session after successful verification")
|
tlog.App.Warn().Err(err).Msg("Failed to delete pending TOTP session")
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
controller.log.App.Warn().Err(err).Msg("Failed to retrieve session cookie for pending TOTP session, cannot delete it")
|
tlog.App.Warn().Err(err).Msg("Failed to retrieve session cookie for pending TOTP session, proceeding without deleting it")
|
||||||
}
|
}
|
||||||
|
|
||||||
controller.auth.RecordLoginAttempt(context.GetUsername(), true)
|
controller.auth.RecordLoginAttempt(context.GetUsername(), true)
|
||||||
@@ -359,7 +351,7 @@ func (controller *UserController) totpHandler(c *gin.Context) {
|
|||||||
sessionCookie := repository.Session{
|
sessionCookie := repository.Session{
|
||||||
Username: user.Username,
|
Username: user.Username,
|
||||||
Name: utils.Capitalize(user.Username),
|
Name: utils.Capitalize(user.Username),
|
||||||
Email: utils.CompileUserEmail(user.Username, controller.runtime.CookieDomain),
|
Email: utils.CompileUserEmail(user.Username, controller.config.CookieDomain),
|
||||||
Provider: "local",
|
Provider: "local",
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -370,10 +362,12 @@ func (controller *UserController) totpHandler(c *gin.Context) {
|
|||||||
sessionCookie.Email = user.Attributes.Email
|
sessionCookie.Email = user.Attributes.Email
|
||||||
}
|
}
|
||||||
|
|
||||||
|
tlog.App.Trace().Interface("session_cookie", sessionCookie).Msg("Creating session cookie")
|
||||||
|
|
||||||
cookie, err := controller.auth.CreateSession(c, sessionCookie)
|
cookie, err := controller.auth.CreateSession(c, sessionCookie)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Error().Err(err).Str("username", context.GetUsername()).Msg("Failed to create session cookie after successful TOTP verification")
|
tlog.App.Error().Err(err).Msg("Failed to create session cookie")
|
||||||
c.JSON(500, gin.H{
|
c.JSON(500, gin.H{
|
||||||
"status": 500,
|
"status": 500,
|
||||||
"message": "Internal Server Error",
|
"message": "Internal Server Error",
|
||||||
@@ -383,8 +377,8 @@ func (controller *UserController) totpHandler(c *gin.Context) {
|
|||||||
|
|
||||||
http.SetCookie(c.Writer, cookie)
|
http.SetCookie(c.Writer, cookie)
|
||||||
|
|
||||||
controller.log.App.Info().Str("username", context.GetUsername()).Msg("TOTP verification successful, login complete")
|
tlog.App.Info().Str("username", context.GetUsername()).Msg("TOTP verification successful")
|
||||||
controller.log.AuditLoginSuccess(context.GetUsername(), "local", c.ClientIP())
|
tlog.AuditLoginSuccess(c, context.GetUsername(), "totp")
|
||||||
|
|
||||||
c.JSON(200, gin.H{
|
c.JSON(200, gin.H{
|
||||||
"status": 200,
|
"status": 200,
|
||||||
|
|||||||
@@ -5,8 +5,8 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"path"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -19,15 +19,53 @@ import (
|
|||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/test"
|
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestUserController(t *testing.T) {
|
func TestUserController(t *testing.T) {
|
||||||
log := logger.NewLogger().WithTestConfig()
|
tlog.NewTestLogger().Init()
|
||||||
log.Init()
|
tempDir := t.TempDir()
|
||||||
|
|
||||||
cfg, runtime := test.CreateTestConfigs(t)
|
authServiceCfg := service.AuthServiceConfig{
|
||||||
|
LocalUsers: &[]model.LocalUser{
|
||||||
|
{
|
||||||
|
Username: "testuser",
|
||||||
|
Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Username: "totpuser",
|
||||||
|
Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
|
||||||
|
TOTPSecret: "JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Username: "attruser",
|
||||||
|
Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
|
||||||
|
Attributes: model.UserAttributes{
|
||||||
|
Name: "Alice Smith",
|
||||||
|
Email: "alice@example.com",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Username: "attrtotpuser",
|
||||||
|
Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
|
||||||
|
TOTPSecret: "JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK",
|
||||||
|
Attributes: model.UserAttributes{
|
||||||
|
Name: "Bob Jones",
|
||||||
|
Email: "bob@example.com",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
SessionExpiry: 10, // 10 seconds, useful for testing
|
||||||
|
CookieDomain: "example.com",
|
||||||
|
LoginTimeout: 10, // 10 seconds, useful for testing
|
||||||
|
LoginMaxRetries: 3,
|
||||||
|
SessionCookieName: "tinyauth-session",
|
||||||
|
}
|
||||||
|
|
||||||
|
userControllerCfg := controller.UserControllerConfig{
|
||||||
|
CookieDomain: "example.com",
|
||||||
|
SessionCookieName: "tinyauth-session",
|
||||||
|
}
|
||||||
|
|
||||||
totpCtx := func(c *gin.Context) {
|
totpCtx := func(c *gin.Context) {
|
||||||
c.Set("context", &model.UserContext{
|
c.Set("context", &model.UserContext{
|
||||||
@@ -73,12 +111,14 @@ func TestUserController(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
app := bootstrap.NewBootstrapApp(cfg)
|
oauthBrokerCfgs := make(map[string]model.OAuthServiceConfig)
|
||||||
|
|
||||||
err := app.SetupDatabase()
|
app := bootstrap.NewBootstrapApp(model.Config{})
|
||||||
|
|
||||||
|
db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db"))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
queries := repository.New(app.GetDB())
|
queries := repository.New(db)
|
||||||
|
|
||||||
type testCase struct {
|
type testCase struct {
|
||||||
description string
|
description string
|
||||||
@@ -96,7 +136,7 @@ func TestUserController(t *testing.T) {
|
|||||||
Password: "password",
|
Password: "password",
|
||||||
}
|
}
|
||||||
loginReqBody, err := json.Marshal(loginReq)
|
loginReqBody, err := json.Marshal(loginReq)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqBody)))
|
req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqBody)))
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
@@ -104,7 +144,7 @@ func TestUserController(t *testing.T) {
|
|||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
assert.Equal(t, 200, recorder.Code)
|
assert.Equal(t, 200, recorder.Code)
|
||||||
require.Len(t, recorder.Result().Cookies(), 1)
|
assert.Len(t, recorder.Result().Cookies(), 1)
|
||||||
|
|
||||||
cookie := recorder.Result().Cookies()[0]
|
cookie := recorder.Result().Cookies()[0]
|
||||||
assert.Equal(t, "tinyauth-session", cookie.Name)
|
assert.Equal(t, "tinyauth-session", cookie.Name)
|
||||||
@@ -124,7 +164,7 @@ func TestUserController(t *testing.T) {
|
|||||||
Password: "wrongpassword",
|
Password: "wrongpassword",
|
||||||
}
|
}
|
||||||
loginReqBody, err := json.Marshal(loginReq)
|
loginReqBody, err := json.Marshal(loginReq)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqBody)))
|
req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqBody)))
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
@@ -145,7 +185,7 @@ func TestUserController(t *testing.T) {
|
|||||||
Password: "wrongpassword",
|
Password: "wrongpassword",
|
||||||
}
|
}
|
||||||
loginReqBody, err := json.Marshal(loginReq)
|
loginReqBody, err := json.Marshal(loginReq)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
for range 3 {
|
for range 3 {
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
@@ -180,7 +220,7 @@ func TestUserController(t *testing.T) {
|
|||||||
Password: "password",
|
Password: "password",
|
||||||
}
|
}
|
||||||
loginReqBody, err := json.Marshal(loginReq)
|
loginReqBody, err := json.Marshal(loginReq)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqBody)))
|
req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqBody)))
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
@@ -191,12 +231,12 @@ func TestUserController(t *testing.T) {
|
|||||||
|
|
||||||
decodedBody := make(map[string]any)
|
decodedBody := make(map[string]any)
|
||||||
err = json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
|
err = json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
assert.Equal(t, decodedBody["totpPending"], true)
|
assert.Equal(t, decodedBody["totpPending"], true)
|
||||||
|
|
||||||
// should set the session cookie
|
// should set the session cookie
|
||||||
require.Len(t, recorder.Result().Cookies(), 1)
|
assert.Len(t, recorder.Result().Cookies(), 1)
|
||||||
cookie := recorder.Result().Cookies()[0]
|
cookie := recorder.Result().Cookies()[0]
|
||||||
assert.Equal(t, "tinyauth-session", cookie.Name)
|
assert.Equal(t, "tinyauth-session", cookie.Name)
|
||||||
assert.True(t, cookie.HttpOnly)
|
assert.True(t, cookie.HttpOnly)
|
||||||
@@ -217,7 +257,7 @@ func TestUserController(t *testing.T) {
|
|||||||
Password: "password",
|
Password: "password",
|
||||||
}
|
}
|
||||||
loginReqBody, err := json.Marshal(loginReq)
|
loginReqBody, err := json.Marshal(loginReq)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqBody)))
|
req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqBody)))
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
@@ -226,7 +266,7 @@ func TestUserController(t *testing.T) {
|
|||||||
|
|
||||||
assert.Equal(t, 200, recorder.Code)
|
assert.Equal(t, 200, recorder.Code)
|
||||||
cookies := recorder.Result().Cookies()
|
cookies := recorder.Result().Cookies()
|
||||||
require.Len(t, cookies, 1)
|
assert.Len(t, cookies, 1)
|
||||||
|
|
||||||
cookie := cookies[0]
|
cookie := cookies[0]
|
||||||
assert.Equal(t, "tinyauth-session", cookie.Name)
|
assert.Equal(t, "tinyauth-session", cookie.Name)
|
||||||
@@ -240,7 +280,7 @@ func TestUserController(t *testing.T) {
|
|||||||
|
|
||||||
assert.Equal(t, 200, recorder.Code)
|
assert.Equal(t, 200, recorder.Code)
|
||||||
cookies = recorder.Result().Cookies()
|
cookies = recorder.Result().Cookies()
|
||||||
require.Len(t, cookies, 1)
|
assert.Len(t, cookies, 1)
|
||||||
|
|
||||||
cookie = cookies[0]
|
cookie = cookies[0]
|
||||||
assert.Equal(t, "tinyauth-session", cookie.Name)
|
assert.Equal(t, "tinyauth-session", cookie.Name)
|
||||||
@@ -267,14 +307,14 @@ func TestUserController(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
code, err := totp.GenerateCode("JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", time.Now())
|
code, err := totp.GenerateCode("JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", time.Now())
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
totpReq := controller.TotpRequest{
|
totpReq := controller.TotpRequest{
|
||||||
Code: code,
|
Code: code,
|
||||||
}
|
}
|
||||||
|
|
||||||
totpReqBody, err := json.Marshal(totpReq)
|
totpReqBody, err := json.Marshal(totpReq)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
recorder = httptest.NewRecorder()
|
recorder = httptest.NewRecorder()
|
||||||
req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(totpReqBody)))
|
req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(totpReqBody)))
|
||||||
@@ -289,7 +329,7 @@ func TestUserController(t *testing.T) {
|
|||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
assert.Equal(t, 200, recorder.Code)
|
assert.Equal(t, 200, recorder.Code)
|
||||||
require.Len(t, recorder.Result().Cookies(), 1)
|
assert.Len(t, recorder.Result().Cookies(), 1)
|
||||||
|
|
||||||
// should set a new session cookie with totp pending removed
|
// should set a new session cookie with totp pending removed
|
||||||
totpCookie := recorder.Result().Cookies()[0]
|
totpCookie := recorder.Result().Cookies()[0]
|
||||||
@@ -312,7 +352,7 @@ func TestUserController(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
totpReqBody, err := json.Marshal(totpReq)
|
totpReqBody, err := json.Marshal(totpReq)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
recorder = httptest.NewRecorder()
|
recorder = httptest.NewRecorder()
|
||||||
req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(totpReqBody)))
|
req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(totpReqBody)))
|
||||||
@@ -416,11 +456,21 @@ func TestUserController(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx := context.TODO()
|
docker := service.NewDockerService()
|
||||||
wg := &sync.WaitGroup{}
|
err = docker.Init()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx)
|
ldap := service.NewLdapService(service.LdapServiceConfig{})
|
||||||
authService := service.NewAuthService(log, cfg, runtime, ctx, wg, nil, queries, broker)
|
err = ldap.Init()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
broker := service.NewOAuthBrokerService(oauthBrokerCfgs)
|
||||||
|
err = broker.Init()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
authService := service.NewAuthService(authServiceCfg, ldap, queries, broker)
|
||||||
|
err = authService.Init()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
beforeEach := func() {
|
beforeEach := func() {
|
||||||
// Clear failed login attempts before each test
|
// Clear failed login attempts before each test
|
||||||
@@ -439,7 +489,8 @@ func TestUserController(t *testing.T) {
|
|||||||
group := router.Group("/api")
|
group := router.Group("/api")
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
controller.NewUserController(log, runtime, group, authService)
|
userController := controller.NewUserController(userControllerCfg, group, authService)
|
||||||
|
userController.SetupRoutes()
|
||||||
|
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
|
|
||||||
@@ -448,6 +499,7 @@ func TestUserController(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
app.GetDB().Close()
|
err = db.Close()
|
||||||
|
require.NoError(t, err)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -26,30 +26,28 @@ type OpenIDConnectConfiguration struct {
|
|||||||
RequestObjectSigningAlgValuesSupported []string `json:"request_object_signing_alg_values_supported"`
|
RequestObjectSigningAlgValuesSupported []string `json:"request_object_signing_alg_values_supported"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type WellKnownControllerConfig struct{}
|
||||||
|
|
||||||
type WellKnownController struct {
|
type WellKnownController struct {
|
||||||
oidc *service.OIDCService
|
config WellKnownControllerConfig
|
||||||
|
engine *gin.Engine
|
||||||
|
oidc *service.OIDCService
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewWellKnownController(oidc *service.OIDCService, router *gin.RouterGroup) *WellKnownController {
|
func NewWellKnownController(config WellKnownControllerConfig, oidc *service.OIDCService, engine *gin.Engine) *WellKnownController {
|
||||||
controller := &WellKnownController{
|
return &WellKnownController{
|
||||||
oidc: oidc,
|
config: config,
|
||||||
|
oidc: oidc,
|
||||||
|
engine: engine,
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
router.GET("/.well-known/openid-configuration", controller.OpenIDConnectConfiguration)
|
func (controller *WellKnownController) SetupRoutes() {
|
||||||
router.GET("/.well-known/jwks.json", controller.JWKS)
|
controller.engine.GET("/.well-known/openid-configuration", controller.OpenIDConnectConfiguration)
|
||||||
|
controller.engine.GET("/.well-known/jwks.json", controller.JWKS)
|
||||||
return controller
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (controller *WellKnownController) OpenIDConnectConfiguration(c *gin.Context) {
|
func (controller *WellKnownController) OpenIDConnectConfiguration(c *gin.Context) {
|
||||||
if controller.oidc == nil {
|
|
||||||
c.JSON(500, gin.H{
|
|
||||||
"status": 500,
|
|
||||||
"message": "OIDC service not configured",
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
issuer := controller.oidc.GetIssuer()
|
issuer := controller.oidc.GetIssuer()
|
||||||
c.JSON(200, OpenIDConnectConfiguration{
|
c.JSON(200, OpenIDConnectConfiguration{
|
||||||
Issuer: issuer,
|
Issuer: issuer,
|
||||||
@@ -71,19 +69,11 @@ func (controller *WellKnownController) OpenIDConnectConfiguration(c *gin.Context
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (controller *WellKnownController) JWKS(c *gin.Context) {
|
func (controller *WellKnownController) JWKS(c *gin.Context) {
|
||||||
if controller.oidc == nil {
|
|
||||||
c.JSON(500, gin.H{
|
|
||||||
"status": 500,
|
|
||||||
"message": "OIDC service not configured",
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
jwks, err := controller.oidc.GetJWK()
|
jwks, err := controller.oidc.GetJWK()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(500, gin.H{
|
c.JSON(500, gin.H{
|
||||||
"status": 500,
|
"status": "500",
|
||||||
"message": "failed to get JWK",
|
"message": "failed to get JWK",
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -1,11 +1,10 @@
|
|||||||
package controller_test
|
package controller_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"sync"
|
"path"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@@ -13,17 +12,30 @@ import (
|
|||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/bootstrap"
|
"github.com/tinyauthapp/tinyauth/internal/bootstrap"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/controller"
|
"github.com/tinyauthapp/tinyauth/internal/controller"
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/test"
|
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestWellKnownController(t *testing.T) {
|
func TestWellKnownController(t *testing.T) {
|
||||||
log := logger.NewLogger().WithTestConfig()
|
tlog.NewTestLogger().Init()
|
||||||
log.Init()
|
tempDir := t.TempDir()
|
||||||
|
|
||||||
cfg, runtime := test.CreateTestConfigs(t)
|
oidcServiceCfg := service.OIDCServiceConfig{
|
||||||
|
Clients: map[string]model.OIDCClientConfig{
|
||||||
|
"test": {
|
||||||
|
ClientID: "some-client-id",
|
||||||
|
ClientSecret: "some-client-secret",
|
||||||
|
TrustedRedirectURIs: []string{"https://test.example.com/callback"},
|
||||||
|
Name: "Test Client",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
PrivateKeyPath: path.Join(tempDir, "key.pem"),
|
||||||
|
PublicKeyPath: path.Join(tempDir, "key.pub"),
|
||||||
|
Issuer: "https://tinyauth.example.com",
|
||||||
|
SessionExpiry: 500,
|
||||||
|
}
|
||||||
|
|
||||||
type testCase struct {
|
type testCase struct {
|
||||||
description string
|
description string
|
||||||
@@ -44,11 +56,11 @@ func TestWellKnownController(t *testing.T) {
|
|||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
expected := controller.OpenIDConnectConfiguration{
|
expected := controller.OpenIDConnectConfiguration{
|
||||||
Issuer: runtime.AppURL,
|
Issuer: oidcServiceCfg.Issuer,
|
||||||
AuthorizationEndpoint: fmt.Sprintf("%s/authorize", runtime.AppURL),
|
AuthorizationEndpoint: fmt.Sprintf("%s/authorize", oidcServiceCfg.Issuer),
|
||||||
TokenEndpoint: fmt.Sprintf("%s/api/oidc/token", runtime.AppURL),
|
TokenEndpoint: fmt.Sprintf("%s/api/oidc/token", oidcServiceCfg.Issuer),
|
||||||
UserinfoEndpoint: fmt.Sprintf("%s/api/oidc/userinfo", runtime.AppURL),
|
UserinfoEndpoint: fmt.Sprintf("%s/api/oidc/userinfo", oidcServiceCfg.Issuer),
|
||||||
JwksUri: fmt.Sprintf("%s/.well-known/jwks.json", runtime.AppURL),
|
JwksUri: fmt.Sprintf("%s/.well-known/jwks.json", oidcServiceCfg.Issuer),
|
||||||
ScopesSupported: service.SupportedScopes,
|
ScopesSupported: service.SupportedScopes,
|
||||||
ResponseTypesSupported: service.SupportedResponseTypes,
|
ResponseTypesSupported: service.SupportedResponseTypes,
|
||||||
GrantTypesSupported: service.SupportedGrantTypes,
|
GrantTypesSupported: service.SupportedGrantTypes,
|
||||||
@@ -89,17 +101,15 @@ func TestWellKnownController(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx := context.TODO()
|
app := bootstrap.NewBootstrapApp(model.Config{})
|
||||||
wg := &sync.WaitGroup{}
|
|
||||||
|
|
||||||
app := bootstrap.NewBootstrapApp(cfg)
|
db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db"))
|
||||||
|
|
||||||
err := app.SetupDatabase()
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
queries := repository.New(app.GetDB())
|
queries := repository.New(db)
|
||||||
|
|
||||||
oidcService, err := service.NewOIDCService(log, cfg, runtime, queries, ctx, wg)
|
oidcService := service.NewOIDCService(oidcServiceCfg, queries)
|
||||||
|
err = oidcService.Init()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
@@ -109,13 +119,15 @@ func TestWellKnownController(t *testing.T) {
|
|||||||
|
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
|
|
||||||
controller.NewWellKnownController(oidcService, &router.RouterGroup)
|
wellKnownController := controller.NewWellKnownController(controller.WellKnownControllerConfig{}, oidcService, router)
|
||||||
|
wellKnownController.SetupRoutes()
|
||||||
|
|
||||||
test.run(t, router, recorder)
|
test.run(t, router, recorder)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
app.GetDB().Close()
|
err = db.Close()
|
||||||
|
require.NoError(t, err)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ import (
|
|||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
@@ -35,27 +35,29 @@ var (
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
type ContextMiddleware struct {
|
type ContextMiddlewareConfig struct {
|
||||||
log *logger.Logger
|
CookieDomain string
|
||||||
runtime model.RuntimeConfig
|
SessionCookieName string
|
||||||
auth *service.AuthService
|
|
||||||
broker *service.OAuthBrokerService
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewContextMiddleware(
|
type ContextMiddleware struct {
|
||||||
log *logger.Logger,
|
config ContextMiddlewareConfig
|
||||||
runtime model.RuntimeConfig,
|
auth *service.AuthService
|
||||||
auth *service.AuthService,
|
broker *service.OAuthBrokerService
|
||||||
broker *service.OAuthBrokerService,
|
}
|
||||||
) *ContextMiddleware {
|
|
||||||
|
func NewContextMiddleware(config ContextMiddlewareConfig, auth *service.AuthService, broker *service.OAuthBrokerService) *ContextMiddleware {
|
||||||
return &ContextMiddleware{
|
return &ContextMiddleware{
|
||||||
log: log,
|
config: config,
|
||||||
runtime: runtime,
|
auth: auth,
|
||||||
auth: auth,
|
broker: broker,
|
||||||
broker: broker,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *ContextMiddleware) Init() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
|
func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
|
||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
if m.isIgnorePath(c.Request.Method + " " + c.Request.URL.Path) {
|
if m.isIgnorePath(c.Request.Method + " " + c.Request.URL.Path) {
|
||||||
@@ -63,7 +65,7 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
uuid, err := c.Cookie(m.runtime.SessionCookieName)
|
uuid, err := c.Cookie(m.config.SessionCookieName)
|
||||||
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
userContext, cookie, err := m.cookieAuth(c.Request.Context(), uuid)
|
userContext, cookie, err := m.cookieAuth(c.Request.Context(), uuid)
|
||||||
@@ -73,12 +75,12 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
|
|||||||
http.SetCookie(c.Writer, cookie)
|
http.SetCookie(c.Writer, cookie)
|
||||||
}
|
}
|
||||||
|
|
||||||
m.log.App.Debug().Msgf("Authenticated user %s via session cookie", userContext.GetUsername())
|
tlog.App.Trace().Msgf("Authenticated user from session cookie: %s", userContext.GetUsername())
|
||||||
c.Set("context", userContext)
|
c.Set("context", userContext)
|
||||||
c.Next()
|
c.Next()
|
||||||
return
|
return
|
||||||
} else {
|
} else {
|
||||||
m.log.App.Debug().Msgf("Error authenticating session cookie: %v", err)
|
tlog.App.Error().Msgf("Error authenticating session cookie: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -88,7 +90,7 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
|
|||||||
userContext, headers, err := m.basicAuth(username, password)
|
userContext, headers, err := m.basicAuth(username, password)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
m.log.App.Error().Msgf("Error authenticating basic auth: %v", err)
|
tlog.App.Error().Msgf("Error authenticating basic auth: %v", err)
|
||||||
c.Next()
|
c.Next()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -139,7 +141,7 @@ func (m *ContextMiddleware) cookieAuth(ctx context.Context, uuid string) (*model
|
|||||||
}
|
}
|
||||||
|
|
||||||
if userContext.Local.Attributes.Email == "" {
|
if userContext.Local.Attributes.Email == "" {
|
||||||
userContext.Local.Attributes.Email = utils.CompileUserEmail(user.Username, m.runtime.CookieDomain)
|
userContext.Local.Attributes.Email = utils.CompileUserEmail(user.Username, m.config.CookieDomain)
|
||||||
}
|
}
|
||||||
case model.ProviderLDAP:
|
case model.ProviderLDAP:
|
||||||
search, err := m.auth.SearchUser(userContext.LDAP.Username)
|
search, err := m.auth.SearchUser(userContext.LDAP.Username)
|
||||||
@@ -160,7 +162,7 @@ func (m *ContextMiddleware) cookieAuth(ctx context.Context, uuid string) (*model
|
|||||||
|
|
||||||
userContext.LDAP.Groups = user.Groups
|
userContext.LDAP.Groups = user.Groups
|
||||||
userContext.LDAP.Name = utils.Capitalize(userContext.LDAP.Username)
|
userContext.LDAP.Name = utils.Capitalize(userContext.LDAP.Username)
|
||||||
userContext.LDAP.Email = utils.CompileUserEmail(userContext.LDAP.Username, m.runtime.CookieDomain)
|
userContext.LDAP.Email = utils.CompileUserEmail(userContext.LDAP.Username, m.config.CookieDomain)
|
||||||
case model.ProviderOAuth:
|
case model.ProviderOAuth:
|
||||||
_, exists := m.broker.GetService(userContext.OAuth.ID)
|
_, exists := m.broker.GetService(userContext.OAuth.ID)
|
||||||
|
|
||||||
@@ -189,7 +191,7 @@ func (m *ContextMiddleware) basicAuth(username string, password string) (*model.
|
|||||||
locked, remaining := m.auth.IsAccountLocked(username)
|
locked, remaining := m.auth.IsAccountLocked(username)
|
||||||
|
|
||||||
if locked {
|
if locked {
|
||||||
m.log.App.Debug().Msgf("Account for user %s is locked for %d seconds, denying auth", username, remaining)
|
tlog.App.Debug().Msgf("Account for user %s is locked for %d seconds, denying auth", username, remaining)
|
||||||
headers["x-tinyauth-lock-locked"] = "true"
|
headers["x-tinyauth-lock-locked"] = "true"
|
||||||
headers["x-tinyauth-lock-reset"] = time.Now().Add(time.Duration(remaining) * time.Second).Format(time.RFC3339)
|
headers["x-tinyauth-lock-reset"] = time.Now().Add(time.Duration(remaining) * time.Second).Format(time.RFC3339)
|
||||||
return nil, headers, nil
|
return nil, headers, nil
|
||||||
@@ -222,7 +224,7 @@ func (m *ContextMiddleware) basicAuth(username string, password string) (*model.
|
|||||||
BaseContext: model.BaseContext{
|
BaseContext: model.BaseContext{
|
||||||
Username: user.Username,
|
Username: user.Username,
|
||||||
Name: utils.Capitalize(user.Username),
|
Name: utils.Capitalize(user.Username),
|
||||||
Email: utils.CompileUserEmail(user.Username, m.runtime.CookieDomain),
|
Email: utils.CompileUserEmail(user.Username, m.config.CookieDomain),
|
||||||
},
|
},
|
||||||
Attributes: user.Attributes,
|
Attributes: user.Attributes,
|
||||||
}
|
}
|
||||||
@@ -238,7 +240,7 @@ func (m *ContextMiddleware) basicAuth(username string, password string) (*model.
|
|||||||
BaseContext: model.BaseContext{
|
BaseContext: model.BaseContext{
|
||||||
Username: username,
|
Username: username,
|
||||||
Name: utils.Capitalize(username),
|
Name: utils.Capitalize(username),
|
||||||
Email: utils.CompileUserEmail(username, m.runtime.CookieDomain),
|
Email: utils.CompileUserEmail(username, m.config.CookieDomain),
|
||||||
},
|
},
|
||||||
Groups: user.Groups,
|
Groups: user.Groups,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import (
|
|||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"sync"
|
"path"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -17,15 +17,36 @@ import (
|
|||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/test"
|
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestContextMiddleware(t *testing.T) {
|
func TestContextMiddleware(t *testing.T) {
|
||||||
log := logger.NewLogger().WithTestConfig()
|
tlog.NewTestLogger().Init()
|
||||||
log.Init()
|
tempDir := t.TempDir()
|
||||||
|
|
||||||
cfg, runtime := test.CreateTestConfigs(t)
|
authServiceCfg := service.AuthServiceConfig{
|
||||||
|
LocalUsers: &[]model.LocalUser{
|
||||||
|
{
|
||||||
|
Username: "testuser",
|
||||||
|
Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Username: "totpuser",
|
||||||
|
Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
|
||||||
|
TOTPSecret: "JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
SessionExpiry: 10, // 10 seconds, useful for testing
|
||||||
|
CookieDomain: "example.com",
|
||||||
|
LoginTimeout: 10, // 10 seconds, useful for testing
|
||||||
|
LoginMaxRetries: 3,
|
||||||
|
SessionCookieName: "tinyauth-session",
|
||||||
|
}
|
||||||
|
|
||||||
|
middlewareCfg := middleware.ContextMiddlewareConfig{
|
||||||
|
CookieDomain: "example.com",
|
||||||
|
SessionCookieName: "tinyauth-session",
|
||||||
|
}
|
||||||
|
|
||||||
basicAuthHeader := func(username, password string) string {
|
basicAuthHeader := func(username, password string) string {
|
||||||
return "Basic " + base64.StdEncoding.EncodeToString([]byte(username+":"+password))
|
return "Basic " + base64.StdEncoding.EncodeToString([]byte(username+":"+password))
|
||||||
@@ -249,20 +270,30 @@ func TestContextMiddleware(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx := context.TODO()
|
oauthBrokerCfgs := make(map[string]model.OAuthServiceConfig)
|
||||||
wg := &sync.WaitGroup{}
|
|
||||||
|
|
||||||
app := bootstrap.NewBootstrapApp(cfg)
|
app := bootstrap.NewBootstrapApp(model.Config{})
|
||||||
|
|
||||||
err := app.SetupDatabase()
|
db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db"))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
queries := repository.New(app.GetDB())
|
queries := repository.New(db)
|
||||||
|
|
||||||
broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx)
|
ldap := service.NewLdapService(service.LdapServiceConfig{})
|
||||||
authService := service.NewAuthService(log, cfg, runtime, ctx, wg, nil, queries, broker)
|
err = ldap.Init()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
contextMiddleware := middleware.NewContextMiddleware(log, runtime, authService, broker)
|
broker := service.NewOAuthBrokerService(oauthBrokerCfgs)
|
||||||
|
err = broker.Init()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
authService := service.NewAuthService(authServiceCfg, ldap, queries, broker)
|
||||||
|
err = authService.Init()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
contextMiddleware := middleware.NewContextMiddleware(middlewareCfg, authService, broker)
|
||||||
|
err = contextMiddleware.Init()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
authService.ClearRateLimitsTestingOnly()
|
authService.ClearRateLimitsTestingOnly()
|
||||||
@@ -291,6 +322,7 @@ func TestContextMiddleware(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
app.GetDB().Close()
|
err = db.Close()
|
||||||
|
require.NoError(t, err)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/assets"
|
"github.com/tinyauthapp/tinyauth/internal/assets"
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
@@ -18,25 +19,29 @@ type UIMiddleware struct {
|
|||||||
uiFileServer http.Handler
|
uiFileServer http.Handler
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewUIMiddleware() (*UIMiddleware, error) {
|
func NewUIMiddleware() *UIMiddleware {
|
||||||
m := &UIMiddleware{}
|
return &UIMiddleware{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *UIMiddleware) Init() error {
|
||||||
ui, err := fs.Sub(assets.FrontendAssets, "dist")
|
ui, err := fs.Sub(assets.FrontendAssets, "dist")
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to load ui assets: %w", err)
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
m.uiFs = ui
|
m.uiFs = ui
|
||||||
m.uiFileServer = http.FileServerFS(ui)
|
m.uiFileServer = http.FileServerFS(ui)
|
||||||
|
|
||||||
return m, nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *UIMiddleware) Middleware() gin.HandlerFunc {
|
func (m *UIMiddleware) Middleware() gin.HandlerFunc {
|
||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
path := strings.TrimPrefix(c.Request.URL.Path, "/")
|
path := strings.TrimPrefix(c.Request.URL.Path, "/")
|
||||||
|
|
||||||
|
tlog.App.Debug().Str("path", path).Msg("path")
|
||||||
|
|
||||||
switch strings.SplitN(path, "/", 2)[0] {
|
switch strings.SplitN(path, "/", 2)[0] {
|
||||||
case "api", "resources", ".well-known":
|
case "api", "resources", ".well-known":
|
||||||
c.Next()
|
c.Next()
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||||
)
|
)
|
||||||
|
|
||||||
// See context middleware for explanation of why we have to do this
|
// See context middleware for explanation of why we have to do this
|
||||||
@@ -17,14 +17,14 @@ var (
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
type ZerologMiddleware struct {
|
type ZerologMiddleware struct{}
|
||||||
log *logger.Logger
|
|
||||||
|
func NewZerologMiddleware() *ZerologMiddleware {
|
||||||
|
return &ZerologMiddleware{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewZerologMiddleware(log *logger.Logger) *ZerologMiddleware {
|
func (m *ZerologMiddleware) Init() error {
|
||||||
return &ZerologMiddleware{
|
return nil
|
||||||
log: log,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *ZerologMiddleware) logPath(path string) bool {
|
func (m *ZerologMiddleware) logPath(path string) bool {
|
||||||
@@ -50,7 +50,7 @@ func (m *ZerologMiddleware) Middleware() gin.HandlerFunc {
|
|||||||
|
|
||||||
latency := time.Since(tStart).String()
|
latency := time.Since(tStart).String()
|
||||||
|
|
||||||
subLogger := m.log.HTTP.With().Str("method", method).
|
subLogger := tlog.HTTP.With().Str("method", method).
|
||||||
Str("path", path).
|
Str("path", path).
|
||||||
Str("address", address).
|
Str("address", address).
|
||||||
Str("client_ip", clientIP).
|
Str("client_ip", clientIP).
|
||||||
|
|||||||
@@ -14,12 +14,10 @@ func NewDefaultConfiguration() *Config {
|
|||||||
Path: "./resources",
|
Path: "./resources",
|
||||||
},
|
},
|
||||||
Server: ServerConfig{
|
Server: ServerConfig{
|
||||||
Port: 3000,
|
Port: 3000,
|
||||||
Address: "0.0.0.0",
|
Address: "0.0.0.0",
|
||||||
ConcurrentListenersEnabled: false,
|
|
||||||
},
|
},
|
||||||
Auth: AuthConfig{
|
Auth: AuthConfig{
|
||||||
SubdomainsEnabled: true,
|
|
||||||
SessionExpiry: 86400, // 1 day
|
SessionExpiry: 86400, // 1 day
|
||||||
SessionMaxLifetime: 0, // disabled
|
SessionMaxLifetime: 0, // disabled
|
||||||
LoginTimeout: 300, // 5 minutes
|
LoginTimeout: 300, // 5 minutes
|
||||||
@@ -96,16 +94,14 @@ type ResourcesConfig struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type ServerConfig struct {
|
type ServerConfig struct {
|
||||||
Port int `description:"The port on which the server listens." yaml:"port"`
|
Port int `description:"The port on which the server listens." yaml:"port"`
|
||||||
Address string `description:"The address on which the server listens." yaml:"address"`
|
Address string `description:"The address on which the server listens." yaml:"address"`
|
||||||
SocketPath string `description:"The path to the Unix socket." yaml:"socketPath"`
|
SocketPath string `description:"The path to the Unix socket." yaml:"socketPath"`
|
||||||
ConcurrentListenersEnabled bool `description:"Enable listening on both TCP and Unix socket at the same time." yaml:"concurrentListenersEnabled"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type AuthConfig struct {
|
type AuthConfig struct {
|
||||||
IP IPConfig `description:"IP whitelisting config options." yaml:"ip"`
|
IP IPConfig `description:"IP whitelisting config options." yaml:"ip"`
|
||||||
Users []string `description:"Comma-separated list of users (username:hashed_password)." yaml:"users"`
|
Users []string `description:"Comma-separated list of users (username:hashed_password)." yaml:"users"`
|
||||||
SubdomainsEnabled bool `description:"Enable subdomains support." yaml:"subdomainsEnabled"`
|
|
||||||
UserAttributes map[string]UserAttributes `description:"Map of per-user OIDC attributes (username -> attributes)." yaml:"userAttributes"`
|
UserAttributes map[string]UserAttributes `description:"Map of per-user OIDC attributes (username -> attributes)." yaml:"userAttributes"`
|
||||||
UsersFile string `description:"Path to the users file." yaml:"usersFile"`
|
UsersFile string `description:"Path to the users file." yaml:"usersFile"`
|
||||||
SecureCookie bool `description:"Enable secure cookies." yaml:"secureCookie"`
|
SecureCookie bool `description:"Enable secure cookies." yaml:"secureCookie"`
|
||||||
@@ -149,10 +145,9 @@ type IPConfig struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type OAuthConfig struct {
|
type OAuthConfig struct {
|
||||||
Whitelist []string `description:"Comma-separated list of allowed OAuth domains." yaml:"whitelist"`
|
Whitelist []string `description:"Comma-separated list of allowed OAuth domains." yaml:"whitelist"`
|
||||||
WhitelistFile string `description:"Path to the OAuth whitelist file." yaml:"whitelistFile"`
|
AutoRedirect string `description:"The OAuth provider to use for automatic redirection." yaml:"autoRedirect"`
|
||||||
AutoRedirect string `description:"The OAuth provider to use for automatic redirection." yaml:"autoRedirect"`
|
Providers map[string]OAuthServiceConfig `description:"OAuth providers configuration." yaml:"providers"`
|
||||||
Providers map[string]OAuthServiceConfig `description:"OAuth providers configuration." yaml:"providers"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type OIDCConfig struct {
|
type OIDCConfig struct {
|
||||||
|
|||||||
@@ -8,10 +8,6 @@ import (
|
|||||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
|
||||||
ErrUserContextNotFound = errors.New("user context not found")
|
|
||||||
)
|
|
||||||
|
|
||||||
type ProviderType int
|
type ProviderType int
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -78,7 +74,7 @@ func (c *UserContext) NewFromGin(ginctx *gin.Context) (*UserContext, error) {
|
|||||||
userContextValue, exists := ginctx.Get("context")
|
userContextValue, exists := ginctx.Get("context")
|
||||||
|
|
||||||
if !exists {
|
if !exists {
|
||||||
return nil, ErrUserContextNotFound
|
return nil, errors.New("failed to get user context")
|
||||||
}
|
}
|
||||||
|
|
||||||
userContext, ok := userContextValue.(*UserContext)
|
userContext, ok := userContextValue.(*UserContext)
|
||||||
@@ -121,7 +117,7 @@ func (c *UserContext) NewFromSession(session *repository.Session) (*UserContext,
|
|||||||
Email: session.Email,
|
Email: session.Email,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
// By default we assume an unknown name which is oauth
|
// By default we assume an unkown name which is oauth
|
||||||
default:
|
default:
|
||||||
c.Provider = ProviderOAuth
|
c.Provider = ProviderOAuth
|
||||||
c.OAuth = &OAuthContext{
|
c.OAuth = &OAuthContext{
|
||||||
|
|||||||
@@ -238,7 +238,7 @@ func TestContext(t *testing.T) {
|
|||||||
_, err := c.NewFromGin(newGinCtx(nil, false))
|
_, err := c.NewFromGin(newGinCtx(nil, false))
|
||||||
return err.Error()
|
return err.Error()
|
||||||
},
|
},
|
||||||
expected: model.ErrUserContextNotFound.Error(),
|
expected: "failed to get user context",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "NewFromGin returns error when context value has wrong type",
|
description: "NewFromGin returns error when context value has wrong type",
|
||||||
|
|||||||
@@ -1,22 +0,0 @@
|
|||||||
package model
|
|
||||||
|
|
||||||
type RuntimeConfig struct {
|
|
||||||
AppURL string
|
|
||||||
UUID string
|
|
||||||
CookieDomain string
|
|
||||||
SessionCookieName string
|
|
||||||
CSRFCookieName string
|
|
||||||
RedirectCookieName string
|
|
||||||
OAuthSessionCookieName string
|
|
||||||
LocalUsers []LocalUser
|
|
||||||
OAuthProviders map[string]OAuthServiceConfig
|
|
||||||
OAuthWhitelist []string
|
|
||||||
ConfiguredProviders []Provider
|
|
||||||
OIDCClients []OIDCClientConfig
|
|
||||||
}
|
|
||||||
|
|
||||||
type Provider struct {
|
|
||||||
Name string `json:"name"`
|
|
||||||
ID string `json:"id"`
|
|
||||||
OAuth bool `json:"oauth"`
|
|
||||||
}
|
|
||||||
@@ -4,7 +4,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||||
)
|
)
|
||||||
|
|
||||||
type LabelProvider interface {
|
type LabelProvider interface {
|
||||||
@@ -12,33 +12,32 @@ type LabelProvider interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type AccessControlsService struct {
|
type AccessControlsService struct {
|
||||||
log *logger.Logger
|
labelProvider LabelProvider
|
||||||
labelProvider *LabelProvider
|
|
||||||
static map[string]model.App
|
static map[string]model.App
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewAccessControlsService(
|
func NewAccessControlsService(labelProvider LabelProvider, static map[string]model.App) *AccessControlsService {
|
||||||
log *logger.Logger,
|
|
||||||
labelProvider *LabelProvider,
|
|
||||||
static map[string]model.App) *AccessControlsService {
|
|
||||||
return &AccessControlsService{
|
return &AccessControlsService{
|
||||||
log: log,
|
|
||||||
labelProvider: labelProvider,
|
labelProvider: labelProvider,
|
||||||
static: static,
|
static: static,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (acls *AccessControlsService) Init() error {
|
||||||
|
return nil // No initialization needed
|
||||||
|
}
|
||||||
|
|
||||||
func (acls *AccessControlsService) lookupStaticACLs(domain string) *model.App {
|
func (acls *AccessControlsService) lookupStaticACLs(domain string) *model.App {
|
||||||
var appAcls *model.App
|
var appAcls *model.App
|
||||||
for app, config := range acls.static {
|
for app, config := range acls.static {
|
||||||
if config.Config.Domain == domain {
|
if config.Config.Domain == domain {
|
||||||
acls.log.App.Debug().Str("name", app).Msg("Found matching container by domain")
|
tlog.App.Debug().Str("name", app).Msg("Found matching container by domain")
|
||||||
appAcls = &config
|
appAcls = &config
|
||||||
break // If we find a match by domain, we can stop searching
|
break // If we find a match by domain, we can stop searching
|
||||||
}
|
}
|
||||||
|
|
||||||
if strings.SplitN(domain, ".", 2)[0] == app {
|
if strings.SplitN(domain, ".", 2)[0] == app {
|
||||||
acls.log.App.Debug().Str("name", app).Msg("Found matching container by app name")
|
tlog.App.Debug().Str("name", app).Msg("Found matching container by app name")
|
||||||
appAcls = &config
|
appAcls = &config
|
||||||
break // If we find a match by app name, we can stop searching
|
break // If we find a match by app name, we can stop searching
|
||||||
}
|
}
|
||||||
@@ -51,15 +50,11 @@ func (acls *AccessControlsService) GetAccessControls(domain string) (*model.App,
|
|||||||
app := acls.lookupStaticACLs(domain)
|
app := acls.lookupStaticACLs(domain)
|
||||||
|
|
||||||
if app != nil {
|
if app != nil {
|
||||||
acls.log.App.Debug().Msg("Using static ACLs for app")
|
tlog.App.Debug().Msg("Using ACls from static configuration")
|
||||||
return app, nil
|
return app, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// If we have a label provider configured, try to get ACLs from it
|
// Fallback to label provider
|
||||||
if acls.labelProvider != nil {
|
tlog.App.Debug().Msg("Falling back to label provider for ACLs")
|
||||||
return (*acls.labelProvider).GetLabels(domain)
|
return acls.labelProvider.GetLabels(domain)
|
||||||
}
|
|
||||||
|
|
||||||
// no labels
|
|
||||||
return nil, nil
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ import (
|
|||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||||
|
|
||||||
"slices"
|
"slices"
|
||||||
|
|
||||||
@@ -72,41 +72,38 @@ type Lockdown struct {
|
|||||||
ActiveUntil time.Time
|
ActiveUntil time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type AuthServiceConfig struct {
|
||||||
|
LocalUsers *[]model.LocalUser
|
||||||
|
OauthWhitelist []string
|
||||||
|
SessionExpiry int
|
||||||
|
SessionMaxLifetime int
|
||||||
|
SecureCookie bool
|
||||||
|
CookieDomain string
|
||||||
|
LoginTimeout int
|
||||||
|
LoginMaxRetries int
|
||||||
|
SessionCookieName string
|
||||||
|
IP model.IPConfig
|
||||||
|
LDAPGroupsCacheTTL int
|
||||||
|
}
|
||||||
|
|
||||||
type AuthService struct {
|
type AuthService struct {
|
||||||
log *logger.Logger
|
config AuthServiceConfig
|
||||||
config model.Config
|
|
||||||
runtime model.RuntimeConfig
|
|
||||||
context context.Context
|
|
||||||
|
|
||||||
ldap *LdapService
|
|
||||||
queries *repository.Queries
|
|
||||||
oauthBroker *OAuthBrokerService
|
|
||||||
|
|
||||||
loginAttempts map[string]*LoginAttempt
|
loginAttempts map[string]*LoginAttempt
|
||||||
ldapGroupsCache map[string]*LdapGroupsCache
|
ldapGroupsCache map[string]*LdapGroupsCache
|
||||||
oauthPendingSessions map[string]*OAuthPendingSession
|
oauthPendingSessions map[string]*OAuthPendingSession
|
||||||
oauthMutex sync.RWMutex
|
oauthMutex sync.RWMutex
|
||||||
loginMutex sync.RWMutex
|
loginMutex sync.RWMutex
|
||||||
ldapGroupsMutex sync.RWMutex
|
ldapGroupsMutex sync.RWMutex
|
||||||
|
ldap *LdapService
|
||||||
|
queries *repository.Queries
|
||||||
|
oauthBroker *OAuthBrokerService
|
||||||
lockdown *Lockdown
|
lockdown *Lockdown
|
||||||
lockdownCtx context.Context
|
lockdownCtx context.Context
|
||||||
lockdownCancelFunc context.CancelFunc
|
lockdownCancelFunc context.CancelFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewAuthService(
|
func NewAuthService(config AuthServiceConfig, ldap *LdapService, queries *repository.Queries, oauthBroker *OAuthBrokerService) *AuthService {
|
||||||
log *logger.Logger,
|
return &AuthService{
|
||||||
config model.Config,
|
|
||||||
runtime model.RuntimeConfig,
|
|
||||||
ctx context.Context,
|
|
||||||
wg *sync.WaitGroup,
|
|
||||||
ldap *LdapService,
|
|
||||||
queries *repository.Queries,
|
|
||||||
oauthBroker *OAuthBrokerService,
|
|
||||||
) *AuthService {
|
|
||||||
service := &AuthService{
|
|
||||||
log: log,
|
|
||||||
runtime: runtime,
|
|
||||||
context: ctx,
|
|
||||||
config: config,
|
config: config,
|
||||||
loginAttempts: make(map[string]*LoginAttempt),
|
loginAttempts: make(map[string]*LoginAttempt),
|
||||||
ldapGroupsCache: make(map[string]*LdapGroupsCache),
|
ldapGroupsCache: make(map[string]*LdapGroupsCache),
|
||||||
@@ -115,10 +112,11 @@ func NewAuthService(
|
|||||||
queries: queries,
|
queries: queries,
|
||||||
oauthBroker: oauthBroker,
|
oauthBroker: oauthBroker,
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
wg.Go(service.CleanupOAuthSessionsRoutine)
|
func (auth *AuthService) Init() error {
|
||||||
|
go auth.CleanupOAuthSessionsRoutine()
|
||||||
return service
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) SearchUser(username string) (*model.UserSearch, error) {
|
func (auth *AuthService) SearchUser(username string) (*model.UserSearch, error) {
|
||||||
@@ -129,7 +127,7 @@ func (auth *AuthService) SearchUser(username string) (*model.UserSearch, error)
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if auth.ldap != nil {
|
if auth.ldap.IsConfigured() {
|
||||||
userDN, err := auth.ldap.GetUserDN(username)
|
userDN, err := auth.ldap.GetUserDN(username)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -154,7 +152,7 @@ func (auth *AuthService) CheckUserPassword(search model.UserSearch, password str
|
|||||||
}
|
}
|
||||||
return bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password))
|
return bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password))
|
||||||
case model.UserLDAP:
|
case model.UserLDAP:
|
||||||
if auth.ldap != nil {
|
if auth.ldap.IsConfigured() {
|
||||||
err := auth.ldap.Bind(search.Username, password)
|
err := auth.ldap.Bind(search.Username, password)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to bind to ldap user: %w", err)
|
return fmt.Errorf("failed to bind to ldap user: %w", err)
|
||||||
@@ -174,10 +172,10 @@ func (auth *AuthService) CheckUserPassword(search model.UserSearch, password str
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) GetLocalUser(username string) *model.LocalUser {
|
func (auth *AuthService) GetLocalUser(username string) *model.LocalUser {
|
||||||
if auth.runtime.LocalUsers == nil {
|
if auth.config.LocalUsers == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
for _, user := range auth.runtime.LocalUsers {
|
for _, user := range *auth.config.LocalUsers {
|
||||||
if user.Username == username {
|
if user.Username == username {
|
||||||
return &user
|
return &user
|
||||||
}
|
}
|
||||||
@@ -186,7 +184,7 @@ func (auth *AuthService) GetLocalUser(username string) *model.LocalUser {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) GetLDAPUser(userDN string) (*model.LDAPUser, error) {
|
func (auth *AuthService) GetLDAPUser(userDN string) (*model.LDAPUser, error) {
|
||||||
if auth.ldap == nil {
|
if !auth.ldap.IsConfigured() {
|
||||||
return nil, errors.New("ldap service not configured")
|
return nil, errors.New("ldap service not configured")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -210,7 +208,7 @@ func (auth *AuthService) GetLDAPUser(userDN string) (*model.LDAPUser, error) {
|
|||||||
auth.ldapGroupsMutex.Lock()
|
auth.ldapGroupsMutex.Lock()
|
||||||
auth.ldapGroupsCache[userDN] = &LdapGroupsCache{
|
auth.ldapGroupsCache[userDN] = &LdapGroupsCache{
|
||||||
Groups: groups,
|
Groups: groups,
|
||||||
Expires: time.Now().Add(time.Duration(auth.config.LDAP.GroupCacheTTL) * time.Second),
|
Expires: time.Now().Add(time.Duration(auth.config.LDAPGroupsCacheTTL) * time.Second),
|
||||||
}
|
}
|
||||||
auth.ldapGroupsMutex.Unlock()
|
auth.ldapGroupsMutex.Unlock()
|
||||||
|
|
||||||
@@ -229,7 +227,7 @@ func (auth *AuthService) IsAccountLocked(identifier string) (bool, int) {
|
|||||||
return true, remaining
|
return true, remaining
|
||||||
}
|
}
|
||||||
|
|
||||||
if auth.config.Auth.LoginMaxRetries <= 0 || auth.config.Auth.LoginTimeout <= 0 {
|
if auth.config.LoginMaxRetries <= 0 || auth.config.LoginTimeout <= 0 {
|
||||||
return false, 0
|
return false, 0
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -247,7 +245,7 @@ func (auth *AuthService) IsAccountLocked(identifier string) (bool, int) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) {
|
func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) {
|
||||||
if auth.config.Auth.LoginMaxRetries <= 0 || auth.config.Auth.LoginTimeout <= 0 {
|
if auth.config.LoginMaxRetries <= 0 || auth.config.LoginTimeout <= 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -278,14 +276,14 @@ func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) {
|
|||||||
|
|
||||||
attempt.FailedAttempts++
|
attempt.FailedAttempts++
|
||||||
|
|
||||||
if attempt.FailedAttempts >= auth.config.Auth.LoginMaxRetries {
|
if attempt.FailedAttempts >= auth.config.LoginMaxRetries {
|
||||||
attempt.LockedUntil = time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second)
|
attempt.LockedUntil = time.Now().Add(time.Duration(auth.config.LoginTimeout) * time.Second)
|
||||||
auth.log.App.Warn().Str("identifier", identifier).Int("failedAttempts", attempt.FailedAttempts).Msg("Account locked due to too many failed login attempts")
|
tlog.App.Warn().Str("identifier", identifier).Int("timeout", auth.config.LoginTimeout).Msg("Account locked due to too many failed login attempts")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) IsEmailWhitelisted(email string) bool {
|
func (auth *AuthService) IsEmailWhitelisted(email string) bool {
|
||||||
return utils.CheckFilter(strings.Join(auth.runtime.OAuthWhitelist, ","), email)
|
return utils.CheckFilter(strings.Join(auth.config.OauthWhitelist, ","), email)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) CreateSession(ctx context.Context, data repository.Session) (*http.Cookie, error) {
|
func (auth *AuthService) CreateSession(ctx context.Context, data repository.Session) (*http.Cookie, error) {
|
||||||
@@ -300,7 +298,7 @@ func (auth *AuthService) CreateSession(ctx context.Context, data repository.Sess
|
|||||||
if data.TotpPending {
|
if data.TotpPending {
|
||||||
expiry = 3600
|
expiry = 3600
|
||||||
} else {
|
} else {
|
||||||
expiry = auth.config.Auth.SessionExpiry
|
expiry = auth.config.SessionExpiry
|
||||||
}
|
}
|
||||||
|
|
||||||
expiresAt := time.Now().Add(time.Duration(expiry) * time.Second)
|
expiresAt := time.Now().Add(time.Duration(expiry) * time.Second)
|
||||||
@@ -326,13 +324,13 @@ func (auth *AuthService) CreateSession(ctx context.Context, data repository.Sess
|
|||||||
}
|
}
|
||||||
|
|
||||||
return &http.Cookie{
|
return &http.Cookie{
|
||||||
Name: auth.runtime.SessionCookieName,
|
Name: auth.config.SessionCookieName,
|
||||||
Value: session.UUID,
|
Value: session.UUID,
|
||||||
Path: "/",
|
Path: "/",
|
||||||
Domain: fmt.Sprintf(".%s", auth.runtime.CookieDomain),
|
Domain: fmt.Sprintf(".%s", auth.config.CookieDomain),
|
||||||
Expires: expiresAt,
|
Expires: expiresAt,
|
||||||
MaxAge: int(time.Until(expiresAt).Seconds()),
|
MaxAge: int(time.Until(expiresAt).Seconds()),
|
||||||
Secure: auth.config.Auth.SecureCookie,
|
Secure: auth.config.SecureCookie,
|
||||||
HttpOnly: true,
|
HttpOnly: true,
|
||||||
SameSite: http.SameSiteLaxMode,
|
SameSite: http.SameSiteLaxMode,
|
||||||
}, nil
|
}, nil
|
||||||
@@ -349,8 +347,8 @@ func (auth *AuthService) RefreshSession(ctx context.Context, uuid string) (*http
|
|||||||
|
|
||||||
var refreshThreshold int64
|
var refreshThreshold int64
|
||||||
|
|
||||||
if auth.config.Auth.SessionExpiry <= int(time.Hour.Seconds()) {
|
if auth.config.SessionExpiry <= int(time.Hour.Seconds()) {
|
||||||
refreshThreshold = int64(auth.config.Auth.SessionExpiry / 2)
|
refreshThreshold = int64(auth.config.SessionExpiry / 2)
|
||||||
} else {
|
} else {
|
||||||
refreshThreshold = int64(time.Hour.Seconds())
|
refreshThreshold = int64(time.Hour.Seconds())
|
||||||
}
|
}
|
||||||
@@ -379,13 +377,13 @@ func (auth *AuthService) RefreshSession(ctx context.Context, uuid string) (*http
|
|||||||
}
|
}
|
||||||
|
|
||||||
return &http.Cookie{
|
return &http.Cookie{
|
||||||
Name: auth.runtime.SessionCookieName,
|
Name: auth.config.SessionCookieName,
|
||||||
Value: session.UUID,
|
Value: session.UUID,
|
||||||
Path: "/",
|
Path: "/",
|
||||||
Domain: fmt.Sprintf(".%s", auth.runtime.CookieDomain),
|
Domain: fmt.Sprintf(".%s", auth.config.CookieDomain),
|
||||||
Expires: time.Now().Add(time.Duration(newExpiry-currentTime) * time.Second),
|
Expires: time.Now().Add(time.Duration(newExpiry-currentTime) * time.Second),
|
||||||
MaxAge: int(newExpiry - currentTime),
|
MaxAge: int(newExpiry - currentTime),
|
||||||
Secure: auth.config.Auth.SecureCookie,
|
Secure: auth.config.SecureCookie,
|
||||||
HttpOnly: true,
|
HttpOnly: true,
|
||||||
SameSite: http.SameSiteLaxMode,
|
SameSite: http.SameSiteLaxMode,
|
||||||
}, nil
|
}, nil
|
||||||
@@ -396,17 +394,17 @@ func (auth *AuthService) DeleteSession(ctx context.Context, uuid string) (*http.
|
|||||||
err := auth.queries.DeleteSession(ctx, uuid)
|
err := auth.queries.DeleteSession(ctx, uuid)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
auth.log.App.Error().Err(err).Str("uuid", uuid).Msg("Failed to delete session from database")
|
tlog.App.Warn().Err(err).Msg("Failed to delete session from database, proceeding to clear cookie anyway")
|
||||||
}
|
}
|
||||||
|
|
||||||
return &http.Cookie{
|
return &http.Cookie{
|
||||||
Name: auth.runtime.SessionCookieName,
|
Name: auth.config.SessionCookieName,
|
||||||
Value: "",
|
Value: "",
|
||||||
Path: "/",
|
Path: "/",
|
||||||
Domain: fmt.Sprintf(".%s", auth.runtime.CookieDomain),
|
Domain: fmt.Sprintf(".%s", auth.config.CookieDomain),
|
||||||
Expires: time.Now(),
|
Expires: time.Now(),
|
||||||
MaxAge: -1,
|
MaxAge: -1,
|
||||||
Secure: auth.config.Auth.SecureCookie,
|
Secure: auth.config.SecureCookie,
|
||||||
HttpOnly: true,
|
HttpOnly: true,
|
||||||
SameSite: http.SameSiteLaxMode,
|
SameSite: http.SameSiteLaxMode,
|
||||||
}, nil
|
}, nil
|
||||||
@@ -424,8 +422,8 @@ func (auth *AuthService) GetSession(ctx context.Context, uuid string) (*reposito
|
|||||||
|
|
||||||
currentTime := time.Now().Unix()
|
currentTime := time.Now().Unix()
|
||||||
|
|
||||||
if auth.config.Auth.SessionMaxLifetime != 0 && session.CreatedAt != 0 {
|
if auth.config.SessionMaxLifetime != 0 && session.CreatedAt != 0 {
|
||||||
if currentTime-session.CreatedAt > int64(auth.config.Auth.SessionMaxLifetime) {
|
if currentTime-session.CreatedAt > int64(auth.config.SessionMaxLifetime) {
|
||||||
err = auth.queries.DeleteSession(ctx, uuid)
|
err = auth.queries.DeleteSession(ctx, uuid)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to delete expired session: %w", err)
|
return nil, fmt.Errorf("failed to delete expired session: %w", err)
|
||||||
@@ -446,11 +444,11 @@ func (auth *AuthService) GetSession(ctx context.Context, uuid string) (*reposito
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) LocalAuthConfigured() bool {
|
func (auth *AuthService) LocalAuthConfigured() bool {
|
||||||
return len(auth.runtime.LocalUsers) > 0
|
return auth.config.LocalUsers != nil && len(*auth.config.LocalUsers) > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) LDAPAuthConfigured() bool {
|
func (auth *AuthService) LDAPAuthConfigured() bool {
|
||||||
return auth.ldap != nil
|
return auth.ldap.IsConfigured()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) IsUserAllowed(c *gin.Context, context model.UserContext, acls *model.App) bool {
|
func (auth *AuthService) IsUserAllowed(c *gin.Context, context model.UserContext, acls *model.App) bool {
|
||||||
@@ -459,18 +457,18 @@ func (auth *AuthService) IsUserAllowed(c *gin.Context, context model.UserContext
|
|||||||
}
|
}
|
||||||
|
|
||||||
if context.Provider == model.ProviderOAuth {
|
if context.Provider == model.ProviderOAuth {
|
||||||
auth.log.App.Debug().Msg("User is an OAuth user, checking OAuth whitelist")
|
tlog.App.Debug().Msg("Checking OAuth whitelist")
|
||||||
return utils.CheckFilter(acls.OAuth.Whitelist, context.OAuth.Email)
|
return utils.CheckFilter(acls.OAuth.Whitelist, context.OAuth.Email)
|
||||||
}
|
}
|
||||||
|
|
||||||
if acls.Users.Block != "" {
|
if acls.Users.Block != "" {
|
||||||
auth.log.App.Debug().Msg("Checking users block list")
|
tlog.App.Debug().Msg("Checking blocked users")
|
||||||
if utils.CheckFilter(acls.Users.Block, context.GetUsername()) {
|
if utils.CheckFilter(acls.Users.Block, context.GetUsername()) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
auth.log.App.Debug().Msg("Checking users allow list")
|
tlog.App.Debug().Msg("Checking users")
|
||||||
return utils.CheckFilter(acls.Users.Allow, context.GetUsername())
|
return utils.CheckFilter(acls.Users.Allow, context.GetUsername())
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -480,23 +478,23 @@ func (auth *AuthService) IsInOAuthGroup(c *gin.Context, context model.UserContex
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !context.IsOAuth() {
|
if !context.IsOAuth() {
|
||||||
auth.log.App.Debug().Msg("User is not an OAuth user, skipping OAuth group check")
|
tlog.App.Debug().Msg("User is not an OAuth user, skipping OAuth group check")
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, ok := model.OverrideProviders[context.OAuth.ID]; ok {
|
if _, ok := model.OverrideProviders[context.OAuth.ID]; ok {
|
||||||
auth.log.App.Debug().Str("provider", context.OAuth.ID).Msg("Provider override detected, skipping group check")
|
tlog.App.Debug().Msg("Provider override for OAuth groups enabled, skipping group check")
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, userGroup := range context.OAuth.Groups {
|
for _, userGroup := range context.OAuth.Groups {
|
||||||
if utils.CheckFilter(acls.OAuth.Groups, strings.TrimSpace(userGroup)) {
|
if utils.CheckFilter(acls.OAuth.Groups, strings.TrimSpace(userGroup)) {
|
||||||
auth.log.App.Trace().Str("group", userGroup).Str("required", acls.OAuth.Groups).Msg("User group matched")
|
tlog.App.Trace().Str("group", userGroup).Str("required", acls.OAuth.Groups).Msg("User group matched")
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
auth.log.App.Debug().Msg("No groups matched")
|
tlog.App.Debug().Msg("No groups matched")
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -506,18 +504,18 @@ func (auth *AuthService) IsInLDAPGroup(c *gin.Context, context model.UserContext
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !context.IsLDAP() {
|
if !context.IsLDAP() {
|
||||||
auth.log.App.Debug().Msg("User is not an LDAP user, skipping LDAP group check")
|
tlog.App.Debug().Msg("User is not an LDAP user, skipping LDAP group check")
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, userGroup := range context.LDAP.Groups {
|
for _, userGroup := range context.LDAP.Groups {
|
||||||
if utils.CheckFilter(acls.LDAP.Groups, strings.TrimSpace(userGroup)) {
|
if utils.CheckFilter(acls.LDAP.Groups, strings.TrimSpace(userGroup)) {
|
||||||
auth.log.App.Trace().Str("group", userGroup).Str("required", acls.LDAP.Groups).Msg("User group matched")
|
tlog.App.Trace().Str("group", userGroup).Str("required", acls.LDAP.Groups).Msg("User group matched")
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
auth.log.App.Debug().Msg("No groups matched")
|
tlog.App.Debug().Msg("No groups matched")
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -561,17 +559,17 @@ func (auth *AuthService) CheckIP(ip string, acls *model.App) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Merge the global and app IP filter
|
// Merge the global and app IP filter
|
||||||
blockedIps := append(auth.config.Auth.IP.Block, acls.IP.Block...)
|
blockedIps := append(auth.config.IP.Block, acls.IP.Block...)
|
||||||
allowedIPs := append(auth.config.Auth.IP.Allow, acls.IP.Allow...)
|
allowedIPs := append(auth.config.IP.Allow, acls.IP.Allow...)
|
||||||
|
|
||||||
for _, blocked := range blockedIps {
|
for _, blocked := range blockedIps {
|
||||||
res, err := utils.FilterIP(blocked, ip)
|
res, err := utils.FilterIP(blocked, ip)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
auth.log.App.Warn().Err(err).Str("item", blocked).Msg("Invalid IP/CIDR in block list")
|
tlog.App.Warn().Err(err).Str("item", blocked).Msg("Invalid IP/CIDR in block list")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if res {
|
if res {
|
||||||
auth.log.App.Debug().Str("ip", ip).Str("item", blocked).Msg("IP is in block list, denying access")
|
tlog.App.Debug().Str("ip", ip).Str("item", blocked).Msg("IP is in blocked list, denying access")
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -579,21 +577,21 @@ func (auth *AuthService) CheckIP(ip string, acls *model.App) bool {
|
|||||||
for _, allowed := range allowedIPs {
|
for _, allowed := range allowedIPs {
|
||||||
res, err := utils.FilterIP(allowed, ip)
|
res, err := utils.FilterIP(allowed, ip)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
auth.log.App.Warn().Err(err).Str("item", allowed).Msg("Invalid IP/CIDR in allow list")
|
tlog.App.Warn().Err(err).Str("item", allowed).Msg("Invalid IP/CIDR in allow list")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if res {
|
if res {
|
||||||
auth.log.App.Debug().Str("ip", ip).Str("item", allowed).Msg("IP is in allow list, allowing access")
|
tlog.App.Debug().Str("ip", ip).Str("item", allowed).Msg("IP is in allowed list, allowing access")
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(allowedIPs) > 0 {
|
if len(allowedIPs) > 0 {
|
||||||
auth.log.App.Debug().Str("ip", ip).Msg("IP not in allow list, denying access")
|
tlog.App.Debug().Str("ip", ip).Msg("IP not in allow list, denying access")
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
auth.log.App.Debug().Str("ip", ip).Msg("IP not in any block or allow list, allowing access by default")
|
tlog.App.Debug().Str("ip", ip).Msg("IP not in allow or block list, allowing by default")
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -605,16 +603,16 @@ func (auth *AuthService) IsBypassedIP(ip string, acls *model.App) bool {
|
|||||||
for _, bypassed := range acls.IP.Bypass {
|
for _, bypassed := range acls.IP.Bypass {
|
||||||
res, err := utils.FilterIP(bypassed, ip)
|
res, err := utils.FilterIP(bypassed, ip)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
auth.log.App.Warn().Err(err).Str("item", bypassed).Msg("Invalid IP/CIDR in bypass list")
|
tlog.App.Warn().Err(err).Str("item", bypassed).Msg("Invalid IP/CIDR in bypass list")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if res {
|
if res {
|
||||||
auth.log.App.Debug().Str("ip", ip).Str("item", bypassed).Msg("IP is in bypass list, skipping authentication")
|
tlog.App.Debug().Str("ip", ip).Str("item", bypassed).Msg("IP is in bypass list, allowing access")
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
auth.log.App.Debug().Str("ip", ip).Msg("IP not in bypass list, proceeding with authentication")
|
tlog.App.Debug().Str("ip", ip).Msg("IP not in bypass list, continuing with authentication")
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -718,32 +716,21 @@ func (auth *AuthService) EndOAuthSession(sessionId string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) CleanupOAuthSessionsRoutine() {
|
func (auth *AuthService) CleanupOAuthSessionsRoutine() {
|
||||||
auth.log.App.Debug().Msg("Starting OAuth session cleanup routine")
|
|
||||||
|
|
||||||
ticker := time.NewTicker(30 * time.Minute)
|
ticker := time.NewTicker(30 * time.Minute)
|
||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
|
|
||||||
for {
|
for range ticker.C {
|
||||||
select {
|
auth.oauthMutex.Lock()
|
||||||
case <-ticker.C:
|
|
||||||
auth.log.App.Debug().Msg("Running OAuth session cleanup")
|
|
||||||
|
|
||||||
auth.oauthMutex.Lock()
|
now := time.Now()
|
||||||
|
|
||||||
now := time.Now()
|
for sessionId, session := range auth.oauthPendingSessions {
|
||||||
|
if now.After(session.ExpiresAt) {
|
||||||
for sessionId, session := range auth.oauthPendingSessions {
|
delete(auth.oauthPendingSessions, sessionId)
|
||||||
if now.After(session.ExpiresAt) {
|
|
||||||
delete(auth.oauthPendingSessions, sessionId)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
auth.oauthMutex.Unlock()
|
|
||||||
auth.log.App.Debug().Msg("OAuth session cleanup completed")
|
|
||||||
case <-auth.context.Done():
|
|
||||||
auth.log.App.Debug().Msg("Stopping OAuth session cleanup routine")
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
auth.oauthMutex.Unlock()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -812,11 +799,11 @@ func (auth *AuthService) lockdownMode() {
|
|||||||
|
|
||||||
auth.loginMutex.Lock()
|
auth.loginMutex.Lock()
|
||||||
|
|
||||||
auth.log.App.Warn().Msg("Too many failed login attempts, entering lockdown mode")
|
tlog.App.Warn().Msg("Multiple login attempts detected, possibly DDOS attack. Activating temporary lockdown.")
|
||||||
|
|
||||||
auth.lockdown = &Lockdown{
|
auth.lockdown = &Lockdown{
|
||||||
Active: true,
|
Active: true,
|
||||||
ActiveUntil: time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second),
|
ActiveUntil: time.Now().Add(time.Duration(auth.config.LoginTimeout) * time.Second),
|
||||||
}
|
}
|
||||||
|
|
||||||
// At this point all login attemps will also expire so,
|
// At this point all login attemps will also expire so,
|
||||||
@@ -833,14 +820,11 @@ func (auth *AuthService) lockdownMode() {
|
|||||||
// Timer expired, end lockdown
|
// Timer expired, end lockdown
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
// Context cancelled, end lockdown
|
// Context cancelled, end lockdown
|
||||||
case <-auth.context.Done():
|
|
||||||
// Service is shutting down, end lockdown
|
|
||||||
}
|
}
|
||||||
|
|
||||||
auth.loginMutex.Lock()
|
auth.loginMutex.Lock()
|
||||||
|
|
||||||
auth.log.App.Info().Msg("Exiting lockdown mode")
|
tlog.App.Info().Msg("Lockdown period ended, resuming normal operation")
|
||||||
|
|
||||||
auth.lockdown = nil
|
auth.lockdown = nil
|
||||||
auth.loginMutex.Unlock()
|
auth.loginMutex.Unlock()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,56 +3,51 @@ package service
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
|
||||||
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/decoders"
|
"github.com/tinyauthapp/tinyauth/internal/utils/decoders"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||||
|
|
||||||
container "github.com/docker/docker/api/types/container"
|
container "github.com/docker/docker/api/types/container"
|
||||||
"github.com/docker/docker/client"
|
"github.com/docker/docker/client"
|
||||||
)
|
)
|
||||||
|
|
||||||
type DockerService struct {
|
type DockerService struct {
|
||||||
log *logger.Logger
|
client *client.Client
|
||||||
client *client.Client
|
context context.Context
|
||||||
context context.Context
|
|
||||||
|
|
||||||
isConnected bool
|
isConnected bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewDockerService(
|
func NewDockerService() *DockerService {
|
||||||
log *logger.Logger,
|
return &DockerService{}
|
||||||
ctx context.Context,
|
}
|
||||||
wg *sync.WaitGroup,
|
|
||||||
) (*DockerService, error) {
|
|
||||||
|
|
||||||
|
func (docker *DockerService) Init() error {
|
||||||
client, err := client.NewClientWithOpts(client.FromEnv)
|
client, err := client.NewClientWithOpts(client.FromEnv)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
client.NegotiateAPIVersion(ctx)
|
client.NegotiateAPIVersion(ctx)
|
||||||
|
|
||||||
_, err = client.Ping(ctx)
|
docker.client = client
|
||||||
|
docker.context = ctx
|
||||||
|
|
||||||
|
_, err = docker.client.Ping(docker.context)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.App.Debug().Err(err).Msg("Docker not connected")
|
tlog.App.Debug().Err(err).Msg("Docker not connected")
|
||||||
return nil, nil
|
docker.isConnected = false
|
||||||
|
docker.client = nil
|
||||||
|
docker.context = nil
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
service := &DockerService{
|
docker.isConnected = true
|
||||||
log: log,
|
tlog.App.Debug().Msg("Docker connected")
|
||||||
client: client,
|
|
||||||
context: ctx,
|
|
||||||
}
|
|
||||||
|
|
||||||
service.isConnected = true
|
return nil
|
||||||
service.log.App.Debug().Msg("Docker connected successfully")
|
|
||||||
|
|
||||||
wg.Go(service.watchAndClose)
|
|
||||||
|
|
||||||
return service, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (docker *DockerService) getContainers() ([]container.Summary, error) {
|
func (docker *DockerService) getContainers() ([]container.Summary, error) {
|
||||||
@@ -65,7 +60,7 @@ func (docker *DockerService) inspectContainer(containerId string) (container.Ins
|
|||||||
|
|
||||||
func (docker *DockerService) GetLabels(appDomain string) (*model.App, error) {
|
func (docker *DockerService) GetLabels(appDomain string) (*model.App, error) {
|
||||||
if !docker.isConnected {
|
if !docker.isConnected {
|
||||||
docker.log.App.Debug().Msg("Docker service not connected, returning empty labels")
|
tlog.App.Debug().Msg("Docker not connected, returning empty labels")
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -87,28 +82,17 @@ func (docker *DockerService) GetLabels(appDomain string) (*model.App, error) {
|
|||||||
|
|
||||||
for appName, appLabels := range labels.Apps {
|
for appName, appLabels := range labels.Apps {
|
||||||
if appLabels.Config.Domain == appDomain {
|
if appLabels.Config.Domain == appDomain {
|
||||||
docker.log.App.Debug().Str("id", inspect.ID).Str("name", inspect.Name).Msg("Found matching container by domain")
|
tlog.App.Debug().Str("id", inspect.ID).Str("name", inspect.Name).Msg("Found matching container by domain")
|
||||||
return &appLabels, nil
|
return &appLabels, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if strings.SplitN(appDomain, ".", 2)[0] == appName {
|
if strings.SplitN(appDomain, ".", 2)[0] == appName {
|
||||||
docker.log.App.Debug().Str("id", inspect.ID).Str("name", inspect.Name).Msg("Found matching container by app name")
|
tlog.App.Debug().Str("id", inspect.ID).Str("name", inspect.Name).Msg("Found matching container by app name")
|
||||||
return &appLabels, nil
|
return &appLabels, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
docker.log.App.Debug().Str("domain", appDomain).Msg("No matching container found for domain")
|
tlog.App.Debug().Msg("No matching container found, returning empty labels")
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (docker *DockerService) watchAndClose() {
|
|
||||||
<-docker.context.Done()
|
|
||||||
docker.log.App.Debug().Msg("Closing Docker client")
|
|
||||||
if docker.client != nil {
|
|
||||||
err := docker.client.Close()
|
|
||||||
if err != nil {
|
|
||||||
docker.log.App.Error().Err(err).Msg("Error closing Docker client")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import (
|
|||||||
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/decoders"
|
"github.com/tinyauthapp/tinyauth/internal/utils/decoders"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||||
|
|
||||||
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
|
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
|
||||||
"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
|
"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
|
||||||
@@ -36,10 +36,9 @@ type ingressApp struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type KubernetesService struct {
|
type KubernetesService struct {
|
||||||
log *logger.Logger
|
|
||||||
ctx context.Context
|
|
||||||
|
|
||||||
client dynamic.Interface
|
client dynamic.Interface
|
||||||
|
ctx context.Context
|
||||||
|
cancel context.CancelFunc
|
||||||
started bool
|
started bool
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
ingressApps map[ingressKey][]ingressApp
|
ingressApps map[ingressKey][]ingressApp
|
||||||
@@ -47,55 +46,12 @@ type KubernetesService struct {
|
|||||||
appNameIndex map[string]ingressAppKey
|
appNameIndex map[string]ingressAppKey
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewKubernetesService(
|
func NewKubernetesService() *KubernetesService {
|
||||||
log *logger.Logger,
|
return &KubernetesService{
|
||||||
ctx context.Context,
|
|
||||||
wg *sync.WaitGroup,
|
|
||||||
) (*KubernetesService, error) {
|
|
||||||
cfg, err := rest.InClusterConfig()
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to get in-cluster kubernetes config: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
client, err := dynamic.NewForConfig(cfg)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to create kubernetes client: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
gvr := schema.GroupVersionResource{
|
|
||||||
Group: "networking.k8s.io",
|
|
||||||
Version: "v1",
|
|
||||||
Resource: "ingresses",
|
|
||||||
}
|
|
||||||
|
|
||||||
accessCtx, accessCancel := context.WithTimeout(ctx, 5*time.Second)
|
|
||||||
defer accessCancel()
|
|
||||||
|
|
||||||
_, err = client.Resource(gvr).List(accessCtx, metav1.ListOptions{Limit: 1})
|
|
||||||
if err != nil {
|
|
||||||
log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to access Ingress API, Kubernetes label provider will be disabled")
|
|
||||||
return nil, fmt.Errorf("failed to access ingress api: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Successfully accessed Ingress API, starting watcher")
|
|
||||||
|
|
||||||
service := &KubernetesService{
|
|
||||||
log: log,
|
|
||||||
ctx: ctx,
|
|
||||||
client: client,
|
|
||||||
ingressApps: make(map[ingressKey][]ingressApp),
|
ingressApps: make(map[ingressKey][]ingressApp),
|
||||||
domainIndex: make(map[string]ingressAppKey),
|
domainIndex: make(map[string]ingressAppKey),
|
||||||
appNameIndex: make(map[string]ingressAppKey),
|
appNameIndex: make(map[string]ingressAppKey),
|
||||||
}
|
}
|
||||||
|
|
||||||
wg.Go(func() {
|
|
||||||
service.watchGVR(gvr)
|
|
||||||
})
|
|
||||||
|
|
||||||
service.started = true
|
|
||||||
log.App.Debug().Msg("Kubernetes label provider started successfully")
|
|
||||||
|
|
||||||
return service, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (k *KubernetesService) addIngressApps(namespace, name string, apps []ingressApp) {
|
func (k *KubernetesService) addIngressApps(namespace, name string, apps []ingressApp) {
|
||||||
@@ -177,7 +133,7 @@ func (k *KubernetesService) updateFromItem(item *unstructured.Unstructured) {
|
|||||||
}
|
}
|
||||||
labels, err := decoders.DecodeLabels[model.Apps](annotations, "apps")
|
labels, err := decoders.DecodeLabels[model.Apps](annotations, "apps")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
k.log.App.Warn().Err(err).Str("namespace", namespace).Str("name", name).Msg("Failed to decode ingress labels, skipping")
|
tlog.App.Debug().Err(err).Msg("Failed to decode labels from annotations")
|
||||||
k.removeIngress(namespace, name)
|
k.removeIngress(namespace, name)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -205,13 +161,13 @@ func (k *KubernetesService) resyncGVR(gvr schema.GroupVersionResource) error {
|
|||||||
|
|
||||||
list, err := k.client.Resource(gvr).List(ctx, metav1.ListOptions{})
|
list, err := k.client.Resource(gvr).List(ctx, metav1.ListOptions{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to list resources for resync")
|
tlog.App.Debug().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to list ingresses during resync")
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
for i := range list.Items {
|
for i := range list.Items {
|
||||||
k.updateFromItem(&list.Items[i])
|
k.updateFromItem(&list.Items[i])
|
||||||
}
|
}
|
||||||
k.log.App.Debug().Str("api", gvr.GroupVersion().String()).Int("count", len(list.Items)).Msg("Resync complete")
|
tlog.App.Debug().Str("api", gvr.GroupVersion().String()).Int("count", len(list.Items)).Msg("Resynced ingress cache")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -225,14 +181,14 @@ func (k *KubernetesService) runWatcher(gvr schema.GroupVersionResource, w watch.
|
|||||||
return false
|
return false
|
||||||
case event, ok := <-w.ResultChan():
|
case event, ok := <-w.ResultChan():
|
||||||
if !ok {
|
if !ok {
|
||||||
k.log.App.Warn().Str("api", gvr.GroupVersion().String()).Msg("Watcher channel closed, restarting watcher")
|
tlog.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Watcher channel closed, restarting in 5 seconds")
|
||||||
w.Stop()
|
w.Stop()
|
||||||
time.Sleep(5 * time.Second)
|
time.Sleep(5 * time.Second)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
item, ok := event.Object.(*unstructured.Unstructured)
|
item, ok := event.Object.(*unstructured.Unstructured)
|
||||||
if !ok {
|
if !ok {
|
||||||
k.log.App.Warn().Str("api", gvr.GroupVersion().String()).Msg("Received unexpected event object, skipping")
|
tlog.App.Warn().Str("api", gvr.GroupVersion().String()).Msg("Failed to cast watched object")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
switch event.Type {
|
switch event.Type {
|
||||||
@@ -243,7 +199,7 @@ func (k *KubernetesService) runWatcher(gvr schema.GroupVersionResource, w watch.
|
|||||||
}
|
}
|
||||||
case <-resyncTicker.C:
|
case <-resyncTicker.C:
|
||||||
if err := k.resyncGVR(gvr); err != nil {
|
if err := k.resyncGVR(gvr); err != nil {
|
||||||
k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Periodic resync failed during watcher run")
|
tlog.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Periodic resync failed")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -254,29 +210,29 @@ func (k *KubernetesService) watchGVR(gvr schema.GroupVersionResource) {
|
|||||||
defer resyncTicker.Stop()
|
defer resyncTicker.Stop()
|
||||||
|
|
||||||
if err := k.resyncGVR(gvr); err != nil {
|
if err := k.resyncGVR(gvr); err != nil {
|
||||||
k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Initial resync failed, will retry")
|
tlog.App.Error().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Initial resync failed, retrying in 30 seconds")
|
||||||
time.Sleep(30 * time.Second)
|
time.Sleep(30 * time.Second)
|
||||||
}
|
}
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-k.ctx.Done():
|
case <-k.ctx.Done():
|
||||||
k.log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Shutting down kubernetes watcher")
|
tlog.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Stopping watcher")
|
||||||
return
|
return
|
||||||
case <-resyncTicker.C:
|
case <-resyncTicker.C:
|
||||||
if err := k.resyncGVR(gvr); err != nil {
|
if err := k.resyncGVR(gvr); err != nil {
|
||||||
k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Periodic resync failed, will retry")
|
tlog.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Periodic resync failed")
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
ctx, cancel := context.WithCancel(k.ctx)
|
ctx, cancel := context.WithCancel(k.ctx)
|
||||||
watcher, err := k.client.Resource(gvr).Watch(ctx, metav1.ListOptions{})
|
watcher, err := k.client.Resource(gvr).Watch(ctx, metav1.ListOptions{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to start watcher, will retry")
|
tlog.App.Error().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to start watcher")
|
||||||
cancel()
|
cancel()
|
||||||
time.Sleep(10 * time.Second)
|
time.Sleep(10 * time.Second)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
k.log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Watcher started successfully")
|
tlog.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Watcher started")
|
||||||
if !k.runWatcher(gvr, watcher, resyncTicker) {
|
if !k.runWatcher(gvr, watcher, resyncTicker) {
|
||||||
cancel()
|
cancel()
|
||||||
return
|
return
|
||||||
@@ -286,25 +242,65 @@ func (k *KubernetesService) watchGVR(gvr schema.GroupVersionResource) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (k *KubernetesService) Init() error {
|
||||||
|
var cfg *rest.Config
|
||||||
|
var err error
|
||||||
|
|
||||||
|
cfg, err = rest.InClusterConfig()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get in-cluster Kubernetes config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
client, err := dynamic.NewForConfig(cfg)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create Kubernetes client: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
k.client = client
|
||||||
|
k.ctx, k.cancel = context.WithCancel(context.Background())
|
||||||
|
|
||||||
|
gvr := schema.GroupVersionResource{
|
||||||
|
Group: "networking.k8s.io",
|
||||||
|
Version: "v1",
|
||||||
|
Resource: "ingresses",
|
||||||
|
}
|
||||||
|
|
||||||
|
accessCtx, accessCancel := context.WithTimeout(k.ctx, 5*time.Second)
|
||||||
|
defer accessCancel()
|
||||||
|
_, err = k.client.Resource(gvr).List(accessCtx, metav1.ListOptions{Limit: 1})
|
||||||
|
if err != nil {
|
||||||
|
tlog.App.Warn().Err(err).Msg("Insufficient permissions for networking.k8s.io/v1 Ingress, Kubernetes label provider will not work")
|
||||||
|
k.started = false
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
tlog.App.Debug().Msg("networking.k8s.io/v1 Ingress API accessible")
|
||||||
|
go k.watchGVR(gvr)
|
||||||
|
|
||||||
|
k.started = true
|
||||||
|
tlog.App.Info().Msg("Kubernetes label provider initialized")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (k *KubernetesService) GetLabels(appDomain string) (*model.App, error) {
|
func (k *KubernetesService) GetLabels(appDomain string) (*model.App, error) {
|
||||||
if !k.started {
|
if !k.started {
|
||||||
k.log.App.Debug().Str("domain", appDomain).Msg("Kubernetes label provider not started, skipping")
|
tlog.App.Debug().Msg("Kubernetes not connected, returning empty labels")
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// First check cache
|
// First check cache
|
||||||
app := k.getByDomain(appDomain)
|
app := k.getByDomain(appDomain)
|
||||||
if app != nil {
|
if app != nil {
|
||||||
k.log.App.Debug().Str("domain", appDomain).Msg("Found labels in cache by domain")
|
tlog.App.Debug().Str("domain", appDomain).Msg("Found labels in cache by domain")
|
||||||
return app, nil
|
return app, nil
|
||||||
}
|
}
|
||||||
appName := strings.SplitN(appDomain, ".", 2)[0]
|
appName := strings.SplitN(appDomain, ".", 2)[0]
|
||||||
app = k.getByAppName(appName)
|
app = k.getByAppName(appName)
|
||||||
if app != nil {
|
if app != nil {
|
||||||
k.log.App.Debug().Str("domain", appDomain).Str("appName", appName).Msg("Found labels in cache by app name")
|
tlog.App.Debug().Str("domain", appDomain).Str("appName", appName).Msg("Found labels in cache by app name")
|
||||||
return app, nil
|
return app, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
k.log.App.Debug().Str("domain", appDomain).Msg("No labels found for domain")
|
tlog.App.Debug().Str("domain", appDomain).Msg("Cache miss, no matching ingress found")
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,13 +8,9 @@ import (
|
|||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestKubernetesService(t *testing.T) {
|
func TestKubernetesService(t *testing.T) {
|
||||||
log := logger.NewLogger().WithTestConfig()
|
|
||||||
log.Init()
|
|
||||||
|
|
||||||
type testCase struct {
|
type testCase struct {
|
||||||
description string
|
description string
|
||||||
run func(t *testing.T, svc *KubernetesService)
|
run func(t *testing.T, svc *KubernetesService)
|
||||||
@@ -183,7 +179,6 @@ func TestKubernetesService(t *testing.T) {
|
|||||||
ingressApps: make(map[ingressKey][]ingressApp),
|
ingressApps: make(map[ingressKey][]ingressApp),
|
||||||
domainIndex: make(map[string]ingressAppKey),
|
domainIndex: make(map[string]ingressAppKey),
|
||||||
appNameIndex: make(map[string]ingressAppKey),
|
appNameIndex: make(map[string]ingressAppKey),
|
||||||
log: log,
|
|
||||||
}
|
}
|
||||||
test.run(t, svc)
|
test.run(t, svc)
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -9,47 +9,69 @@ import (
|
|||||||
|
|
||||||
"github.com/cenkalti/backoff/v5"
|
"github.com/cenkalti/backoff/v5"
|
||||||
ldapgo "github.com/go-ldap/ldap/v3"
|
ldapgo "github.com/go-ldap/ldap/v3"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type LdapService struct {
|
type LdapServiceConfig struct {
|
||||||
log *logger.Logger
|
Address string
|
||||||
config model.Config
|
BindDN string
|
||||||
context context.Context
|
BindPassword string
|
||||||
|
BaseDN string
|
||||||
conn *ldapgo.Conn
|
Insecure bool
|
||||||
mutex sync.RWMutex
|
SearchFilter string
|
||||||
cert *tls.Certificate
|
AuthCert string
|
||||||
|
AuthKey string
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewLdapService(
|
type LdapService struct {
|
||||||
log *logger.Logger,
|
config LdapServiceConfig
|
||||||
config model.Config,
|
conn *ldapgo.Conn
|
||||||
ctx context.Context,
|
mutex sync.RWMutex
|
||||||
wg *sync.WaitGroup,
|
cert *tls.Certificate
|
||||||
) (*LdapService, error) {
|
isConfigured bool
|
||||||
if config.LDAP.Address == "" {
|
}
|
||||||
return nil, nil
|
|
||||||
|
func NewLdapService(config LdapServiceConfig) *LdapService {
|
||||||
|
return &LdapService{
|
||||||
|
config: config,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ldap *LdapService) IsConfigured() bool {
|
||||||
|
return ldap.isConfigured
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ldap *LdapService) Unconfigure() error {
|
||||||
|
if !ldap.isConfigured {
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
ldap := &LdapService{
|
if ldap.conn != nil {
|
||||||
log: log,
|
if err := ldap.conn.Close(); err != nil {
|
||||||
config: config,
|
return fmt.Errorf("failed to close LDAP connection: %w", err)
|
||||||
context: ctx,
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ldap.isConfigured = false
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ldap *LdapService) Init() error {
|
||||||
|
if ldap.config.Address == "" {
|
||||||
|
ldap.isConfigured = false
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
ldap.isConfigured = true
|
||||||
|
|
||||||
// Check whether authentication with client certificate is possible
|
// Check whether authentication with client certificate is possible
|
||||||
if config.LDAP.AuthCert != "" && config.LDAP.AuthKey != "" {
|
if ldap.config.AuthCert != "" && ldap.config.AuthKey != "" {
|
||||||
cert, err := tls.LoadX509KeyPair(config.LDAP.AuthCert, config.LDAP.AuthKey)
|
cert, err := tls.LoadX509KeyPair(ldap.config.AuthCert, ldap.config.AuthKey)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to initialize LDAP with mTLS authentication: %w", err)
|
return fmt.Errorf("failed to initialize LDAP with mTLS authentication: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.App.Info().Msg("LDAP mTLS authentication configured successfully")
|
|
||||||
|
|
||||||
ldap.cert = &cert
|
ldap.cert = &cert
|
||||||
|
tlog.App.Info().Msg("Using LDAP with mTLS authentication")
|
||||||
|
|
||||||
// TODO: Add optional extra CA certificates, instead of `InsecureSkipVerify`
|
// TODO: Add optional extra CA certificates, instead of `InsecureSkipVerify`
|
||||||
/*
|
/*
|
||||||
@@ -62,39 +84,26 @@ func NewLdapService(
|
|||||||
}
|
}
|
||||||
*/
|
*/
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := ldap.connect()
|
_, err := ldap.connect()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to connect to ldap server: %w", err)
|
return fmt.Errorf("failed to connect to LDAP server: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
wg.Go(func() {
|
go func() {
|
||||||
ldap.log.App.Debug().Msg("Starting LDAP connection heartbeat routine")
|
for range time.Tick(time.Duration(5) * time.Minute) {
|
||||||
|
err := ldap.heartbeat()
|
||||||
ticker := time.NewTicker(5 * time.Minute)
|
if err != nil {
|
||||||
defer ticker.Stop()
|
tlog.App.Error().Err(err).Msg("LDAP connection heartbeat failed")
|
||||||
|
if reconnectErr := ldap.reconnect(); reconnectErr != nil {
|
||||||
for {
|
tlog.App.Error().Err(reconnectErr).Msg("Failed to reconnect to LDAP server")
|
||||||
select {
|
continue
|
||||||
case <-ticker.C:
|
|
||||||
err := ldap.heartbeat()
|
|
||||||
if err != nil {
|
|
||||||
ldap.log.App.Warn().Err(err).Msg("LDAP connection heartbeat failed, attempting to reconnect")
|
|
||||||
if reconnectErr := ldap.reconnect(); reconnectErr != nil {
|
|
||||||
ldap.log.App.Error().Err(reconnectErr).Msg("Failed to reconnect to LDAP server")
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
ldap.log.App.Info().Msg("Successfully reconnected to LDAP server")
|
|
||||||
}
|
}
|
||||||
case <-ldap.context.Done():
|
tlog.App.Info().Msg("Successfully reconnected to LDAP server")
|
||||||
ldap.log.App.Debug().Msg("LDAP service context cancelled, stopping heartbeat")
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
})
|
}()
|
||||||
|
|
||||||
return ldap, nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ldap *LdapService) connect() (*ldapgo.Conn, error) {
|
func (ldap *LdapService) connect() (*ldapgo.Conn, error) {
|
||||||
@@ -111,13 +120,13 @@ func (ldap *LdapService) connect() (*ldapgo.Conn, error) {
|
|||||||
// 2. conn.StartTLS(tlsConfig)
|
// 2. conn.StartTLS(tlsConfig)
|
||||||
// 3. conn.externalBind()
|
// 3. conn.externalBind()
|
||||||
if ldap.cert != nil {
|
if ldap.cert != nil {
|
||||||
conn, err = ldapgo.DialURL(ldap.config.LDAP.Address, ldapgo.DialWithTLSConfig(&tls.Config{
|
conn, err = ldapgo.DialURL(ldap.config.Address, ldapgo.DialWithTLSConfig(&tls.Config{
|
||||||
MinVersion: tls.VersionTLS12,
|
MinVersion: tls.VersionTLS12,
|
||||||
Certificates: []tls.Certificate{*ldap.cert},
|
Certificates: []tls.Certificate{*ldap.cert},
|
||||||
}))
|
}))
|
||||||
} else {
|
} else {
|
||||||
conn, err = ldapgo.DialURL(ldap.config.LDAP.Address, ldapgo.DialWithTLSConfig(&tls.Config{
|
conn, err = ldapgo.DialURL(ldap.config.Address, ldapgo.DialWithTLSConfig(&tls.Config{
|
||||||
InsecureSkipVerify: ldap.config.LDAP.Insecure,
|
InsecureSkipVerify: ldap.config.Insecure,
|
||||||
MinVersion: tls.VersionTLS12,
|
MinVersion: tls.VersionTLS12,
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
@@ -137,10 +146,10 @@ func (ldap *LdapService) connect() (*ldapgo.Conn, error) {
|
|||||||
func (ldap *LdapService) GetUserDN(username string) (string, error) {
|
func (ldap *LdapService) GetUserDN(username string) (string, error) {
|
||||||
// Escape the username to prevent LDAP injection
|
// Escape the username to prevent LDAP injection
|
||||||
escapedUsername := ldapgo.EscapeFilter(username)
|
escapedUsername := ldapgo.EscapeFilter(username)
|
||||||
filter := fmt.Sprintf(ldap.config.LDAP.SearchFilter, escapedUsername)
|
filter := fmt.Sprintf(ldap.config.SearchFilter, escapedUsername)
|
||||||
|
|
||||||
searchRequest := ldapgo.NewSearchRequest(
|
searchRequest := ldapgo.NewSearchRequest(
|
||||||
ldap.config.LDAP.BaseDN,
|
ldap.config.BaseDN,
|
||||||
ldapgo.ScopeWholeSubtree, ldapgo.NeverDerefAliases, 0, 0, false,
|
ldapgo.ScopeWholeSubtree, ldapgo.NeverDerefAliases, 0, 0, false,
|
||||||
filter,
|
filter,
|
||||||
[]string{"dn"},
|
[]string{"dn"},
|
||||||
@@ -167,7 +176,7 @@ func (ldap *LdapService) GetUserGroups(userDN string) ([]string, error) {
|
|||||||
escapedUserDN := ldapgo.EscapeFilter(userDN)
|
escapedUserDN := ldapgo.EscapeFilter(userDN)
|
||||||
|
|
||||||
searchRequest := ldapgo.NewSearchRequest(
|
searchRequest := ldapgo.NewSearchRequest(
|
||||||
ldap.config.LDAP.BaseDN,
|
ldap.config.BaseDN,
|
||||||
ldapgo.ScopeWholeSubtree, ldapgo.NeverDerefAliases, 0, 0, false,
|
ldapgo.ScopeWholeSubtree, ldapgo.NeverDerefAliases, 0, 0, false,
|
||||||
fmt.Sprintf("(&(objectclass=groupOfUniqueNames)(uniquemember=%s))", escapedUserDN),
|
fmt.Sprintf("(&(objectclass=groupOfUniqueNames)(uniquemember=%s))", escapedUserDN),
|
||||||
[]string{"dn"},
|
[]string{"dn"},
|
||||||
@@ -215,7 +224,7 @@ func (ldap *LdapService) BindService(rebind bool) error {
|
|||||||
if ldap.cert != nil {
|
if ldap.cert != nil {
|
||||||
return ldap.conn.ExternalBind()
|
return ldap.conn.ExternalBind()
|
||||||
}
|
}
|
||||||
return ldap.conn.Bind(ldap.config.LDAP.BindDN, ldap.config.LDAP.BindPassword)
|
return ldap.conn.Bind(ldap.config.BindDN, ldap.config.BindPassword)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ldap *LdapService) Bind(userDN string, password string) error {
|
func (ldap *LdapService) Bind(userDN string, password string) error {
|
||||||
@@ -229,7 +238,7 @@ func (ldap *LdapService) Bind(userDN string, password string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (ldap *LdapService) heartbeat() error {
|
func (ldap *LdapService) heartbeat() error {
|
||||||
ldap.log.App.Debug().Msg("Performing LDAP connection heartbeat")
|
tlog.App.Debug().Msg("Performing LDAP connection heartbeat")
|
||||||
|
|
||||||
searchRequest := ldapgo.NewSearchRequest(
|
searchRequest := ldapgo.NewSearchRequest(
|
||||||
"",
|
"",
|
||||||
@@ -251,7 +260,7 @@ func (ldap *LdapService) heartbeat() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (ldap *LdapService) reconnect() error {
|
func (ldap *LdapService) reconnect() error {
|
||||||
ldap.log.App.Info().Msg("Attempting to reconnect to LDAP server")
|
tlog.App.Info().Msg("Reconnecting to LDAP server")
|
||||||
|
|
||||||
exp := backoff.NewExponentialBackOff()
|
exp := backoff.NewExponentialBackOff()
|
||||||
exp.InitialInterval = 500 * time.Millisecond
|
exp.InitialInterval = 500 * time.Millisecond
|
||||||
|
|||||||
@@ -1,10 +1,8 @@
|
|||||||
package service
|
package service
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||||
|
|
||||||
"slices"
|
"slices"
|
||||||
|
|
||||||
@@ -21,39 +19,33 @@ type OAuthServiceImpl interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type OAuthBrokerService struct {
|
type OAuthBrokerService struct {
|
||||||
log *logger.Logger
|
|
||||||
|
|
||||||
services map[string]OAuthServiceImpl
|
services map[string]OAuthServiceImpl
|
||||||
configs map[string]model.OAuthServiceConfig
|
configs map[string]model.OAuthServiceConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
var presets = map[string]func(config model.OAuthServiceConfig, ctx context.Context) *OAuthService{
|
var presets = map[string]func(config model.OAuthServiceConfig) *OAuthService{
|
||||||
"github": newGitHubOAuthService,
|
"github": newGitHubOAuthService,
|
||||||
"google": newGoogleOAuthService,
|
"google": newGoogleOAuthService,
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewOAuthBrokerService(
|
func NewOAuthBrokerService(configs map[string]model.OAuthServiceConfig) *OAuthBrokerService {
|
||||||
log *logger.Logger,
|
return &OAuthBrokerService{
|
||||||
configs map[string]model.OAuthServiceConfig,
|
|
||||||
ctx context.Context,
|
|
||||||
) *OAuthBrokerService {
|
|
||||||
service := &OAuthBrokerService{
|
|
||||||
log: log,
|
|
||||||
services: make(map[string]OAuthServiceImpl),
|
services: make(map[string]OAuthServiceImpl),
|
||||||
configs: configs,
|
configs: configs,
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
for name, cfg := range configs {
|
func (broker *OAuthBrokerService) Init() error {
|
||||||
|
for name, cfg := range broker.configs {
|
||||||
if presetFunc, exists := presets[name]; exists {
|
if presetFunc, exists := presets[name]; exists {
|
||||||
service.services[name] = presetFunc(cfg, ctx)
|
broker.services[name] = presetFunc(cfg)
|
||||||
service.log.App.Debug().Str("service", name).Msg("Loaded OAuth service from preset")
|
tlog.App.Debug().Str("service", name).Msg("Loaded OAuth service from preset")
|
||||||
} else {
|
} else {
|
||||||
service.services[name] = NewOAuthService(cfg, name, ctx)
|
broker.services[name] = NewOAuthService(cfg, name)
|
||||||
service.log.App.Debug().Str("service", name).Msg("Loaded OAuth service from custom config")
|
tlog.App.Debug().Str("service", name).Msg("Loaded OAuth service from config")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return nil
|
||||||
return service
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (broker *OAuthBrokerService) GetConfiguredServices() []string {
|
func (broker *OAuthBrokerService) GetConfiguredServices() []string {
|
||||||
|
|||||||
@@ -1,25 +1,23 @@
|
|||||||
package service
|
package service
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"golang.org/x/oauth2/endpoints"
|
"golang.org/x/oauth2/endpoints"
|
||||||
)
|
)
|
||||||
|
|
||||||
func newGoogleOAuthService(config model.OAuthServiceConfig, ctx context.Context) *OAuthService {
|
func newGoogleOAuthService(config model.OAuthServiceConfig) *OAuthService {
|
||||||
scopes := []string{"openid", "email", "profile"}
|
scopes := []string{"openid", "email", "profile"}
|
||||||
config.Scopes = scopes
|
config.Scopes = scopes
|
||||||
config.AuthURL = endpoints.Google.AuthURL
|
config.AuthURL = endpoints.Google.AuthURL
|
||||||
config.TokenURL = endpoints.Google.TokenURL
|
config.TokenURL = endpoints.Google.TokenURL
|
||||||
config.UserinfoURL = "https://openidconnect.googleapis.com/v1/userinfo"
|
config.UserinfoURL = "https://openidconnect.googleapis.com/v1/userinfo"
|
||||||
return NewOAuthService(config, "google", ctx)
|
return NewOAuthService(config, "google")
|
||||||
}
|
}
|
||||||
|
|
||||||
func newGitHubOAuthService(config model.OAuthServiceConfig, ctx context.Context) *OAuthService {
|
func newGitHubOAuthService(config model.OAuthServiceConfig) *OAuthService {
|
||||||
scopes := []string{"read:user", "user:email"}
|
scopes := []string{"read:user", "user:email"}
|
||||||
config.Scopes = scopes
|
config.Scopes = scopes
|
||||||
config.AuthURL = endpoints.GitHub.AuthURL
|
config.AuthURL = endpoints.GitHub.AuthURL
|
||||||
config.TokenURL = endpoints.GitHub.TokenURL
|
config.TokenURL = endpoints.GitHub.TokenURL
|
||||||
return NewOAuthService(config, "github", ctx).WithUserinfoExtractor(githubExtractor)
|
return NewOAuthService(config, "github").WithUserinfoExtractor(githubExtractor)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ type OAuthService struct {
|
|||||||
id string
|
id string
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewOAuthService(config model.OAuthServiceConfig, id string, ctx context.Context) *OAuthService {
|
func NewOAuthService(config model.OAuthServiceConfig, id string) *OAuthService {
|
||||||
httpClient := &http.Client{
|
httpClient := &http.Client{
|
||||||
Timeout: 30 * time.Second,
|
Timeout: 30 * time.Second,
|
||||||
Transport: &http.Transport{
|
Transport: &http.Transport{
|
||||||
@@ -29,7 +29,8 @@ func NewOAuthService(config model.OAuthServiceConfig, id string, ctx context.Con
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
vctx := context.WithValue(ctx, oauth2.HTTPClient, httpClient)
|
ctx := context.Background()
|
||||||
|
ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient)
|
||||||
|
|
||||||
return &OAuthService{
|
return &OAuthService{
|
||||||
serviceCfg: config,
|
serviceCfg: config,
|
||||||
@@ -43,7 +44,7 @@ func NewOAuthService(config model.OAuthServiceConfig, id string, ctx context.Con
|
|||||||
TokenURL: config.TokenURL,
|
TokenURL: config.TokenURL,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
ctx: vctx,
|
ctx: ctx,
|
||||||
userinfoExtractor: defaultExtractor,
|
userinfoExtractor: defaultExtractor,
|
||||||
id: id,
|
id: id,
|
||||||
}
|
}
|
||||||
|
|||||||
+125
-133
@@ -16,7 +16,6 @@ import (
|
|||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"slices"
|
"slices"
|
||||||
@@ -26,7 +25,7 @@ import (
|
|||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -112,173 +111,172 @@ type AuthorizeRequest struct {
|
|||||||
CodeChallengeMethod string `json:"code_challenge_method"`
|
CodeChallengeMethod string `json:"code_challenge_method"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type OIDCService struct {
|
type OIDCServiceConfig struct {
|
||||||
log *logger.Logger
|
Clients map[string]model.OIDCClientConfig
|
||||||
config model.Config
|
PrivateKeyPath string
|
||||||
runtime model.RuntimeConfig
|
PublicKeyPath string
|
||||||
queries *repository.Queries
|
Issuer string
|
||||||
context context.Context
|
SessionExpiry int
|
||||||
|
|
||||||
clients map[string]model.OIDCClientConfig
|
|
||||||
privateKey *rsa.PrivateKey
|
|
||||||
publicKey crypto.PublicKey
|
|
||||||
issuer string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewOIDCService(
|
type OIDCService struct {
|
||||||
log *logger.Logger,
|
config OIDCServiceConfig
|
||||||
config model.Config,
|
queries *repository.Queries
|
||||||
runtime model.RuntimeConfig,
|
clients map[string]model.OIDCClientConfig
|
||||||
queries *repository.Queries,
|
privateKey *rsa.PrivateKey
|
||||||
ctx context.Context,
|
publicKey crypto.PublicKey
|
||||||
wg *sync.WaitGroup) (*OIDCService, error) {
|
issuer string
|
||||||
|
isConfigured bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewOIDCService(config OIDCServiceConfig, queries *repository.Queries) *OIDCService {
|
||||||
|
return &OIDCService{
|
||||||
|
config: config,
|
||||||
|
queries: queries,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (service *OIDCService) IsConfigured() bool {
|
||||||
|
return service.isConfigured
|
||||||
|
}
|
||||||
|
|
||||||
|
func (service *OIDCService) Init() error {
|
||||||
// If not configured, skip init
|
// If not configured, skip init
|
||||||
if len(runtime.OIDCClients) == 0 {
|
if len(service.config.Clients) == 0 {
|
||||||
return nil, nil
|
service.isConfigured = false
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
service.isConfigured = true
|
||||||
|
|
||||||
// Ensure issuer is https
|
// Ensure issuer is https
|
||||||
uissuer, err := url.Parse(runtime.AppURL)
|
uissuer, err := url.Parse(service.config.Issuer)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to parse app url: %w", err)
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if uissuer.Scheme != "https" {
|
if uissuer.Scheme != "https" {
|
||||||
return nil, errors.New("issuer must be https")
|
return errors.New("issuer must be https")
|
||||||
}
|
}
|
||||||
|
|
||||||
issuer := fmt.Sprintf("%s://%s", uissuer.Scheme, uissuer.Host)
|
service.issuer = fmt.Sprintf("%s://%s", uissuer.Scheme, uissuer.Host)
|
||||||
|
|
||||||
// Create/load private and public keys
|
// Create/load private and public keys
|
||||||
if strings.TrimSpace(config.OIDC.PrivateKeyPath) == "" ||
|
if strings.TrimSpace(service.config.PrivateKeyPath) == "" ||
|
||||||
strings.TrimSpace(config.OIDC.PublicKeyPath) == "" {
|
strings.TrimSpace(service.config.PublicKeyPath) == "" {
|
||||||
return nil, errors.New("private key path and public key path are required")
|
return errors.New("private key path and public key path are required")
|
||||||
}
|
}
|
||||||
|
|
||||||
var privateKey *rsa.PrivateKey
|
var privateKey *rsa.PrivateKey
|
||||||
|
|
||||||
fprivateKey, err := os.ReadFile(config.OIDC.PrivateKeyPath)
|
fprivateKey, err := os.ReadFile(service.config.PrivateKeyPath)
|
||||||
|
|
||||||
if err != nil && !errors.Is(err, os.ErrNotExist) {
|
if err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||||
return nil, err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if errors.Is(err, os.ErrNotExist) {
|
if errors.Is(err, os.ErrNotExist) {
|
||||||
privateKey, err = rsa.GenerateKey(rand.Reader, 2048)
|
privateKey, err = rsa.GenerateKey(rand.Reader, 2048)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to generate private key: %w", err)
|
return err
|
||||||
}
|
}
|
||||||
der := x509.MarshalPKCS1PrivateKey(privateKey)
|
der := x509.MarshalPKCS1PrivateKey(privateKey)
|
||||||
if der == nil {
|
if der == nil {
|
||||||
return nil, errors.New("failed to marshal private key")
|
return errors.New("failed to marshal private key")
|
||||||
}
|
}
|
||||||
encoded := pem.EncodeToMemory(&pem.Block{
|
encoded := pem.EncodeToMemory(&pem.Block{
|
||||||
Type: "RSA PRIVATE KEY",
|
Type: "RSA PRIVATE KEY",
|
||||||
Bytes: der,
|
Bytes: der,
|
||||||
})
|
})
|
||||||
log.App.Trace().Str("type", "RSA PRIVATE KEY").Msg("Generated private RSA key")
|
tlog.App.Trace().Str("type", "RSA PRIVATE KEY").Msg("Generated private RSA key")
|
||||||
err = os.WriteFile(config.OIDC.PrivateKeyPath, encoded, 0600)
|
err = os.WriteFile(service.config.PrivateKeyPath, encoded, 0600)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to write private key to file: %w", err)
|
return err
|
||||||
}
|
}
|
||||||
|
service.privateKey = privateKey
|
||||||
} else {
|
} else {
|
||||||
block, _ := pem.Decode(fprivateKey)
|
block, _ := pem.Decode(fprivateKey)
|
||||||
if block == nil {
|
if block == nil {
|
||||||
return nil, errors.New("failed to decode private key")
|
return errors.New("failed to decode private key")
|
||||||
}
|
}
|
||||||
log.App.Trace().Str("type", block.Type).Msg("Loaded private key")
|
tlog.App.Trace().Str("type", block.Type).Msg("Loaded private key")
|
||||||
privateKey, err = x509.ParsePKCS1PrivateKey(block.Bytes)
|
privateKey, err = x509.ParsePKCS1PrivateKey(block.Bytes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to parse private key: %w", err)
|
return err
|
||||||
}
|
}
|
||||||
|
service.privateKey = privateKey
|
||||||
}
|
}
|
||||||
|
|
||||||
var publicKey crypto.PublicKey
|
fpublicKey, err := os.ReadFile(service.config.PublicKeyPath)
|
||||||
|
|
||||||
fpublicKey, err := os.ReadFile(config.OIDC.PublicKeyPath)
|
|
||||||
|
|
||||||
if err != nil && !errors.Is(err, os.ErrNotExist) {
|
if err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||||
return nil, fmt.Errorf("failed to read public key: %w", err)
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if errors.Is(err, os.ErrNotExist) {
|
if errors.Is(err, os.ErrNotExist) {
|
||||||
publicKey = privateKey.Public()
|
publicKey := service.privateKey.Public()
|
||||||
der := x509.MarshalPKCS1PublicKey(publicKey.(*rsa.PublicKey))
|
der := x509.MarshalPKCS1PublicKey(publicKey.(*rsa.PublicKey))
|
||||||
if der == nil {
|
if der == nil {
|
||||||
return nil, errors.New("failed to marshal public key")
|
return errors.New("failed to marshal public key")
|
||||||
}
|
}
|
||||||
encoded := pem.EncodeToMemory(&pem.Block{
|
encoded := pem.EncodeToMemory(&pem.Block{
|
||||||
Type: "RSA PUBLIC KEY",
|
Type: "RSA PUBLIC KEY",
|
||||||
Bytes: der,
|
Bytes: der,
|
||||||
})
|
})
|
||||||
log.App.Trace().Str("type", "RSA PUBLIC KEY").Msg("Generated public RSA key")
|
tlog.App.Trace().Str("type", "RSA PUBLIC KEY").Msg("Generated public RSA key")
|
||||||
err = os.WriteFile(config.OIDC.PublicKeyPath, encoded, 0644)
|
err = os.WriteFile(service.config.PublicKeyPath, encoded, 0644)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return err
|
||||||
}
|
}
|
||||||
|
service.publicKey = publicKey
|
||||||
} else {
|
} else {
|
||||||
block, _ := pem.Decode(fpublicKey)
|
block, _ := pem.Decode(fpublicKey)
|
||||||
if block == nil {
|
if block == nil {
|
||||||
return nil, errors.New("failed to decode public key")
|
return errors.New("failed to decode public key")
|
||||||
}
|
}
|
||||||
log.App.Trace().Str("type", block.Type).Msg("Loaded public key")
|
tlog.App.Trace().Str("type", block.Type).Msg("Loaded public key")
|
||||||
switch block.Type {
|
switch block.Type {
|
||||||
case "RSA PUBLIC KEY":
|
case "RSA PUBLIC KEY":
|
||||||
publicKey, err = x509.ParsePKCS1PublicKey(block.Bytes)
|
publicKey, err := x509.ParsePKCS1PublicKey(block.Bytes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to parse public key: %w", err)
|
return err
|
||||||
}
|
}
|
||||||
|
service.publicKey = publicKey
|
||||||
case "PUBLIC KEY":
|
case "PUBLIC KEY":
|
||||||
publicKey, err = x509.ParsePKIXPublicKey(block.Bytes)
|
publicKey, err := x509.ParsePKIXPublicKey(block.Bytes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to parse public key: %w", err)
|
return err
|
||||||
}
|
}
|
||||||
|
service.publicKey = publicKey.(crypto.PublicKey)
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("unsupported public key type: %s", block.Type)
|
return fmt.Errorf("unsupported public key type: %s", block.Type)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// We will reorganize the client into a map with the client ID as the key
|
// We will reorganize the client into a map with the client ID as the key
|
||||||
clients := make(map[string]model.OIDCClientConfig)
|
service.clients = make(map[string]model.OIDCClientConfig)
|
||||||
|
|
||||||
for id, client := range config.OIDC.Clients {
|
for id, client := range service.config.Clients {
|
||||||
client.ID = id
|
client.ID = id
|
||||||
if client.Name == "" {
|
if client.Name == "" {
|
||||||
client.Name = utils.Capitalize(client.ID)
|
client.Name = utils.Capitalize(client.ID)
|
||||||
}
|
}
|
||||||
clients[client.ClientID] = client
|
service.clients[client.ClientID] = client
|
||||||
}
|
}
|
||||||
|
|
||||||
// Load the client secrets from files if they exist
|
// Load the client secrets from files if they exist
|
||||||
for id, client := range clients {
|
for id, client := range service.clients {
|
||||||
secret := utils.GetSecret(client.ClientSecret, client.ClientSecretFile)
|
secret := utils.GetSecret(client.ClientSecret, client.ClientSecretFile)
|
||||||
if secret != "" {
|
if secret != "" {
|
||||||
client.ClientSecret = secret
|
client.ClientSecret = secret
|
||||||
}
|
}
|
||||||
client.ClientSecretFile = ""
|
client.ClientSecretFile = ""
|
||||||
clients[id] = client
|
service.clients[id] = client
|
||||||
log.App.Debug().Str("clientId", client.ClientID).Msg("Loaded OIDC client configuration")
|
tlog.App.Info().Str("id", client.ID).Msg("Registered OIDC client")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialize the service
|
return nil
|
||||||
service := &OIDCService{
|
|
||||||
log: log,
|
|
||||||
config: config,
|
|
||||||
runtime: runtime,
|
|
||||||
queries: queries,
|
|
||||||
context: ctx,
|
|
||||||
|
|
||||||
clients: clients,
|
|
||||||
privateKey: privateKey,
|
|
||||||
publicKey: publicKey,
|
|
||||||
issuer: issuer,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start cleanup routine
|
|
||||||
wg.Go(service.cleanupRoutine)
|
|
||||||
|
|
||||||
return service, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (service *OIDCService) GetIssuer() string {
|
func (service *OIDCService) GetIssuer() string {
|
||||||
@@ -309,7 +307,7 @@ func (service *OIDCService) ValidateAuthorizeParams(req AuthorizeRequest) error
|
|||||||
return errors.New("invalid_scope")
|
return errors.New("invalid_scope")
|
||||||
}
|
}
|
||||||
if !slices.Contains(SupportedScopes, scope) {
|
if !slices.Contains(SupportedScopes, scope) {
|
||||||
service.log.App.Warn().Str("scope", scope).Msg("Requested unsupported scope")
|
tlog.App.Warn().Str("scope", scope).Msg("Unsupported OIDC scope, will be ignored")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -359,7 +357,7 @@ func (service *OIDCService) StoreCode(c *gin.Context, sub string, code string, r
|
|||||||
entry.CodeChallenge = req.CodeChallenge
|
entry.CodeChallenge = req.CodeChallenge
|
||||||
} else {
|
} else {
|
||||||
entry.CodeChallenge = service.hashAndEncodePKCE(req.CodeChallenge)
|
entry.CodeChallenge = service.hashAndEncodePKCE(req.CodeChallenge)
|
||||||
service.log.App.Warn().Msg("Using plain PKCE code challenge method is not recommended, consider switching to S256 for better security")
|
tlog.App.Warn().Msg("Received plain PKCE code challenge, it's recommended to use S256 for better security")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -451,7 +449,7 @@ func (service *OIDCService) GetCodeEntry(c *gin.Context, codeHash string, client
|
|||||||
|
|
||||||
func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user repository.OidcUserinfo, scope string, nonce string) (string, error) {
|
func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user repository.OidcUserinfo, scope string, nonce string) (string, error) {
|
||||||
createdAt := time.Now().Unix()
|
createdAt := time.Now().Unix()
|
||||||
expiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry) * time.Second).Unix()
|
expiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry) * time.Second).Unix()
|
||||||
|
|
||||||
hasher := sha256.New()
|
hasher := sha256.New()
|
||||||
|
|
||||||
@@ -531,16 +529,16 @@ func (service *OIDCService) GenerateAccessToken(c *gin.Context, client model.OID
|
|||||||
accessToken := utils.GenerateString(32)
|
accessToken := utils.GenerateString(32)
|
||||||
refreshToken := utils.GenerateString(32)
|
refreshToken := utils.GenerateString(32)
|
||||||
|
|
||||||
tokenExpiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry) * time.Second).Unix()
|
tokenExpiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry) * time.Second).Unix()
|
||||||
|
|
||||||
// Refresh token lives double the time of an access token but can't be used to access userinfo
|
// Refresh token lives double the time of an access token but can't be used to access userinfo
|
||||||
refreshTokenExpiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry*2) * time.Second).Unix()
|
refreshTokenExpiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry*2) * time.Second).Unix()
|
||||||
|
|
||||||
tokenResponse := TokenResponse{
|
tokenResponse := TokenResponse{
|
||||||
AccessToken: accessToken,
|
AccessToken: accessToken,
|
||||||
RefreshToken: refreshToken,
|
RefreshToken: refreshToken,
|
||||||
TokenType: "Bearer",
|
TokenType: "Bearer",
|
||||||
ExpiresIn: int64(service.config.Auth.SessionExpiry),
|
ExpiresIn: int64(service.config.SessionExpiry),
|
||||||
IDToken: idToken,
|
IDToken: idToken,
|
||||||
Scope: strings.ReplaceAll(codeEntry.Scope, ",", " "),
|
Scope: strings.ReplaceAll(codeEntry.Scope, ",", " "),
|
||||||
}
|
}
|
||||||
@@ -600,14 +598,14 @@ func (service *OIDCService) RefreshAccessToken(c *gin.Context, refreshToken stri
|
|||||||
accessToken := utils.GenerateString(32)
|
accessToken := utils.GenerateString(32)
|
||||||
newRefreshToken := utils.GenerateString(32)
|
newRefreshToken := utils.GenerateString(32)
|
||||||
|
|
||||||
tokenExpiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry) * time.Second).Unix()
|
tokenExpiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry) * time.Second).Unix()
|
||||||
refreshTokenExpiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry*2) * time.Second).Unix()
|
refreshTokenExpiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry*2) * time.Second).Unix()
|
||||||
|
|
||||||
tokenResponse := TokenResponse{
|
tokenResponse := TokenResponse{
|
||||||
AccessToken: accessToken,
|
AccessToken: accessToken,
|
||||||
RefreshToken: newRefreshToken,
|
RefreshToken: newRefreshToken,
|
||||||
TokenType: "Bearer",
|
TokenType: "Bearer",
|
||||||
ExpiresIn: int64(service.config.Auth.SessionExpiry),
|
ExpiresIn: int64(service.config.SessionExpiry),
|
||||||
IDToken: idToken,
|
IDToken: idToken,
|
||||||
Scope: strings.ReplaceAll(entry.Scope, ",", " "),
|
Scope: strings.ReplaceAll(entry.Scope, ",", " "),
|
||||||
}
|
}
|
||||||
@@ -750,62 +748,56 @@ func (service *OIDCService) DeleteOldSession(ctx context.Context, sub string) er
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Cleanup routine - Resource heavy due to the linked tables
|
// Cleanup routine - Resource heavy due to the linked tables
|
||||||
func (service *OIDCService) cleanupRoutine() {
|
func (service *OIDCService) Cleanup() {
|
||||||
service.log.App.Debug().Msg("Starting OIDC cleanup routine")
|
// We need a context for the routine
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
ticker := time.NewTicker(time.Duration(30) * time.Minute)
|
ticker := time.NewTicker(time.Duration(30) * time.Minute)
|
||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
|
|
||||||
for {
|
for range ticker.C {
|
||||||
select {
|
currentTime := time.Now().Unix()
|
||||||
case <-ticker.C:
|
|
||||||
service.log.App.Debug().Msg("Performing OIDC cleanup routine")
|
|
||||||
|
|
||||||
currentTime := time.Now().Unix()
|
// For the OIDC tokens, if they are expired we delete the userinfo and codes
|
||||||
|
expiredTokens, err := service.queries.DeleteExpiredOidcTokens(ctx, repository.DeleteExpiredOidcTokensParams{
|
||||||
|
TokenExpiresAt: currentTime,
|
||||||
|
RefreshTokenExpiresAt: currentTime,
|
||||||
|
})
|
||||||
|
|
||||||
// For the OIDC tokens, if they are expired we delete the userinfo and codes
|
if err != nil {
|
||||||
expiredTokens, err := service.queries.DeleteExpiredOidcTokens(service.context, repository.DeleteExpiredOidcTokensParams{
|
tlog.App.Warn().Err(err).Msg("Failed to delete expired tokens")
|
||||||
TokenExpiresAt: currentTime,
|
}
|
||||||
RefreshTokenExpiresAt: currentTime,
|
|
||||||
})
|
for _, expiredToken := range expiredTokens {
|
||||||
|
err := service.DeleteOldSession(ctx, expiredToken.Sub)
|
||||||
|
if err != nil {
|
||||||
|
tlog.App.Warn().Err(err).Msg("Failed to delete old session")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// For expired codes, we need to get the sub, check if tokens are expired and if they are remove everything
|
||||||
|
expiredCodes, err := service.queries.DeleteExpiredOidcCodes(ctx, currentTime)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
tlog.App.Warn().Err(err).Msg("Failed to delete expired codes")
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, expiredCode := range expiredCodes {
|
||||||
|
token, err := service.queries.GetOidcTokenBySub(ctx, expiredCode.Sub)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
service.log.App.Warn().Err(err).Msg("Failed to delete expired tokens")
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
}
|
|
||||||
|
|
||||||
for _, expiredToken := range expiredTokens {
|
|
||||||
err := service.DeleteOldSession(service.context, expiredToken.Sub)
|
|
||||||
if err != nil {
|
|
||||||
service.log.App.Warn().Err(err).Msg("Failed to delete session for expired token")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// For expired codes, we need to get the sub, check if tokens are expired and if they are remove everything
|
|
||||||
expiredCodes, err := service.queries.DeleteExpiredOidcCodes(service.context, currentTime)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
service.log.App.Warn().Err(err).Msg("Failed to delete expired codes")
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, expiredCode := range expiredCodes {
|
|
||||||
token, err := service.queries.GetOidcTokenBySub(service.context, expiredCode.Sub)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
service.log.App.Warn().Err(err).Msg("Failed to get token by sub for expired code")
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
tlog.App.Warn().Err(err).Msg("Failed to get OIDC token by sub")
|
||||||
if token.TokenExpiresAt < currentTime && token.RefreshTokenExpiresAt < currentTime {
|
|
||||||
err := service.DeleteOldSession(service.context, expiredCode.Sub)
|
|
||||||
if err != nil {
|
|
||||||
service.log.App.Warn().Err(err).Msg("Failed to delete session for expired code")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
service.log.App.Debug().Msg("Finished OIDC cleanup routine")
|
if token.TokenExpiresAt < currentTime && token.RefreshTokenExpiresAt < currentTime {
|
||||||
case <-service.context.Done():
|
err := service.DeleteOldSession(ctx, expiredCode.Sub)
|
||||||
service.log.App.Debug().Msg("Stopping OIDC cleanup routine")
|
if err != nil {
|
||||||
return
|
tlog.App.Warn().Err(err).Msg("Failed to delete session")
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,9 +1,7 @@
|
|||||||
package service_test
|
package service_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"sync"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
@@ -12,7 +10,6 @@ import (
|
|||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func newTestUser() repository.OidcUserinfo {
|
func newTestUser() repository.OidcUserinfo {
|
||||||
@@ -51,29 +48,13 @@ func newTestUser() repository.OidcUserinfo {
|
|||||||
|
|
||||||
func TestCompileUserinfo(t *testing.T) {
|
func TestCompileUserinfo(t *testing.T) {
|
||||||
dir := t.TempDir()
|
dir := t.TempDir()
|
||||||
|
svc := service.NewOIDCService(service.OIDCServiceConfig{
|
||||||
cfg := model.Config{
|
PrivateKeyPath: dir + "/key.pem",
|
||||||
OIDC: model.OIDCConfig{
|
PublicKeyPath: dir + "/key.pub",
|
||||||
PrivateKeyPath: dir + "/key.pem",
|
Issuer: "https://tinyauth.example.com",
|
||||||
PublicKeyPath: dir + "/key.pub",
|
SessionExpiry: 3600,
|
||||||
},
|
}, nil)
|
||||||
Auth: model.AuthConfig{
|
require.NoError(t, svc.Init())
|
||||||
SessionExpiry: 3600,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
runtime := model.RuntimeConfig{
|
|
||||||
AppURL: "https://tinyauth.example.com",
|
|
||||||
}
|
|
||||||
|
|
||||||
log := logger.NewLogger().WithTestConfig()
|
|
||||||
log.Init()
|
|
||||||
|
|
||||||
ctx := context.TODO()
|
|
||||||
wg := &sync.WaitGroup{}
|
|
||||||
|
|
||||||
svc, err := service.NewOIDCService(log, cfg, runtime, nil, ctx, wg)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
type testCase struct {
|
type testCase struct {
|
||||||
description string
|
description string
|
||||||
|
|||||||
@@ -1,106 +0,0 @@
|
|||||||
package test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"path/filepath"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
|
||||||
"golang.org/x/crypto/bcrypt"
|
|
||||||
)
|
|
||||||
|
|
||||||
var TestingTOTPSecret = "JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK"
|
|
||||||
|
|
||||||
func CreateTestConfigs(t *testing.T) (model.Config, model.RuntimeConfig) {
|
|
||||||
tempDir := t.TempDir()
|
|
||||||
|
|
||||||
config := model.Config{
|
|
||||||
UI: model.UIConfig{
|
|
||||||
Title: "Tinyauth Test",
|
|
||||||
ForgotPasswordMessage: "foo",
|
|
||||||
BackgroundImage: "/background.jpg",
|
|
||||||
WarningsEnabled: true,
|
|
||||||
},
|
|
||||||
OAuth: model.OAuthConfig{
|
|
||||||
AutoRedirect: "none",
|
|
||||||
},
|
|
||||||
OIDC: model.OIDCConfig{
|
|
||||||
Clients: map[string]model.OIDCClientConfig{
|
|
||||||
"test": {
|
|
||||||
ClientID: "some-client-id",
|
|
||||||
ClientSecret: "some-client-secret",
|
|
||||||
TrustedRedirectURIs: []string{"https://test.example.com/callback"},
|
|
||||||
Name: "Test Client",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
PrivateKeyPath: filepath.Join(tempDir, "key.pem"),
|
|
||||||
PublicKeyPath: filepath.Join(tempDir, "key.pub"),
|
|
||||||
},
|
|
||||||
Auth: model.AuthConfig{
|
|
||||||
SessionExpiry: 10,
|
|
||||||
LoginTimeout: 10,
|
|
||||||
LoginMaxRetries: 3,
|
|
||||||
},
|
|
||||||
Database: model.DatabaseConfig{
|
|
||||||
Path: filepath.Join(tempDir, "test.db"),
|
|
||||||
},
|
|
||||||
Resources: model.ResourcesConfig{
|
|
||||||
Enabled: true,
|
|
||||||
Path: filepath.Join(tempDir, "resources"),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
passwd, err := bcrypt.GenerateFromPassword([]byte("password"), bcrypt.DefaultCost)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
runtime := model.RuntimeConfig{
|
|
||||||
ConfiguredProviders: []model.Provider{
|
|
||||||
{
|
|
||||||
Name: "Local",
|
|
||||||
ID: "local",
|
|
||||||
OAuth: false,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
LocalUsers: []model.LocalUser{
|
|
||||||
{
|
|
||||||
Username: "testuser",
|
|
||||||
Password: string(passwd),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Username: "totpuser",
|
|
||||||
Password: string(passwd),
|
|
||||||
TOTPSecret: TestingTOTPSecret,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Username: "attruser",
|
|
||||||
Password: string(passwd),
|
|
||||||
Attributes: model.UserAttributes{
|
|
||||||
Name: "Alice Smith",
|
|
||||||
Email: "alice@example.com",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Username: "attrtotpuser",
|
|
||||||
Password: string(passwd),
|
|
||||||
TOTPSecret: TestingTOTPSecret,
|
|
||||||
Attributes: model.UserAttributes{
|
|
||||||
Name: "Bob Jones",
|
|
||||||
Email: "bob@example.com",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
CookieDomain: "example.com",
|
|
||||||
AppURL: "https://tinyauth.example.com",
|
|
||||||
SessionCookieName: "tinyauth-session",
|
|
||||||
OIDCClients: func() []model.OIDCClientConfig {
|
|
||||||
var clients []model.OIDCClientConfig
|
|
||||||
for id, client := range config.OIDC.Clients {
|
|
||||||
client.ID = id
|
|
||||||
clients = append(clients, client)
|
|
||||||
}
|
|
||||||
return clients
|
|
||||||
}(),
|
|
||||||
}
|
|
||||||
|
|
||||||
return config, runtime
|
|
||||||
}
|
|
||||||
@@ -7,6 +7,8 @@ import (
|
|||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||||
|
|
||||||
"github.com/weppos/publicsuffix-go/publicsuffix"
|
"github.com/weppos/publicsuffix-go/publicsuffix"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -20,12 +22,13 @@ func GetCookieDomain(u string) (string, error) {
|
|||||||
host := parsed.Hostname()
|
host := parsed.Hostname()
|
||||||
|
|
||||||
if netIP := net.ParseIP(host); netIP != nil {
|
if netIP := net.ParseIP(host); netIP != nil {
|
||||||
return "", errors.New("ip addresses not allowed")
|
return "", errors.New("IP addresses not allowed")
|
||||||
}
|
}
|
||||||
|
|
||||||
parts := strings.Split(host, ".")
|
parts := strings.Split(host, ".")
|
||||||
|
|
||||||
if len(parts) == 2 {
|
if len(parts) == 2 {
|
||||||
|
tlog.App.Warn().Msgf("Running on the root domain, cookies will be set for .%v", host)
|
||||||
return host, nil
|
return host, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -44,27 +47,6 @@ func GetCookieDomain(u string) (string, error) {
|
|||||||
return domain, nil
|
return domain, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetStandaloneCookieDomain(u string) (string, error) {
|
|
||||||
parsed, err := url.Parse(u)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
host := parsed.Hostname()
|
|
||||||
|
|
||||||
if netIP := net.ParseIP(host); netIP != nil {
|
|
||||||
return "", errors.New("ip addresses not allowed")
|
|
||||||
}
|
|
||||||
|
|
||||||
parts := strings.Split(host, ".")
|
|
||||||
|
|
||||||
if len(parts) < 2 {
|
|
||||||
return "", errors.New("invalid app url")
|
|
||||||
}
|
|
||||||
|
|
||||||
return host, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func ParseFileToLine(content string) string {
|
func ParseFileToLine(content string) string {
|
||||||
lines := strings.Split(content, "\n")
|
lines := strings.Split(content, "\n")
|
||||||
users := make([]string, 0)
|
users := make([]string, 0)
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ func TestGetRootDomain(t *testing.T) {
|
|||||||
// IP address
|
// IP address
|
||||||
domain = "http://10.10.10.10"
|
domain = "http://10.10.10.10"
|
||||||
_, err = utils.GetCookieDomain(domain)
|
_, err = utils.GetCookieDomain(domain)
|
||||||
assert.ErrorContains(t, err, "ip addresses not allowed")
|
assert.ErrorContains(t, err, "IP addresses not allowed")
|
||||||
|
|
||||||
// Invalid URL
|
// Invalid URL
|
||||||
domain = "http://[::1]:namedport"
|
domain = "http://[::1]:namedport"
|
||||||
@@ -180,48 +180,3 @@ func TestIsRedirectSafe(t *testing.T) {
|
|||||||
result = utils.IsRedirectSafe(redirectURL, domain)
|
result = utils.IsRedirectSafe(redirectURL, domain)
|
||||||
assert.False(t, result)
|
assert.False(t, result)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGetStandaloneCookieDomain(t *testing.T) {
|
|
||||||
// Normal case
|
|
||||||
domain := "http://tinyauth.app"
|
|
||||||
expected := "tinyauth.app"
|
|
||||||
result, err := utils.GetStandaloneCookieDomain(domain)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, expected, result)
|
|
||||||
|
|
||||||
// URL with subdomain (full hostname is returned, no subdomain stripping)
|
|
||||||
domain = "http://sub.tinyauth.app"
|
|
||||||
expected = "sub.tinyauth.app"
|
|
||||||
result, err = utils.GetStandaloneCookieDomain(domain)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, expected, result)
|
|
||||||
|
|
||||||
// URL with port (port should be stripped)
|
|
||||||
domain = "http://tinyauth.app:8080"
|
|
||||||
expected = "tinyauth.app"
|
|
||||||
result, err = utils.GetStandaloneCookieDomain(domain)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, expected, result)
|
|
||||||
|
|
||||||
// URL with path
|
|
||||||
domain = "https://tinyauth.app/some/path"
|
|
||||||
expected = "tinyauth.app"
|
|
||||||
result, err = utils.GetStandaloneCookieDomain(domain)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, expected, result)
|
|
||||||
|
|
||||||
// IP address
|
|
||||||
domain = "http://10.10.10.10"
|
|
||||||
_, err = utils.GetStandaloneCookieDomain(domain)
|
|
||||||
assert.ErrorContains(t, err, "ip addresses not allowed")
|
|
||||||
|
|
||||||
// Invalid domain (only TLD)
|
|
||||||
domain = "com"
|
|
||||||
_, err = utils.GetStandaloneCookieDomain(domain)
|
|
||||||
assert.ErrorContains(t, err, "invalid app url")
|
|
||||||
|
|
||||||
// Invalid URL
|
|
||||||
domain = "http://[::1]:namedport"
|
|
||||||
_, err = utils.GetStandaloneCookieDomain(domain)
|
|
||||||
assert.ErrorContains(t, err, "parse \"http://[::1]:namedport\": invalid port \":namedport\" after host")
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,160 +0,0 @@
|
|||||||
package logger
|
|
||||||
|
|
||||||
import (
|
|
||||||
"io"
|
|
||||||
"os"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/rs/zerolog"
|
|
||||||
"github.com/rs/zerolog/log"
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Logger struct {
|
|
||||||
HTTP zerolog.Logger
|
|
||||||
App zerolog.Logger
|
|
||||||
config model.LogConfig
|
|
||||||
base zerolog.Logger
|
|
||||||
audit zerolog.Logger
|
|
||||||
writer io.Writer
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewLogger() *Logger {
|
|
||||||
return &Logger{
|
|
||||||
writer: os.Stderr,
|
|
||||||
config: model.LogConfig{
|
|
||||||
Level: "error",
|
|
||||||
Json: true,
|
|
||||||
Streams: model.LogStreams{
|
|
||||||
HTTP: model.LogStreamConfig{
|
|
||||||
Enabled: true,
|
|
||||||
},
|
|
||||||
App: model.LogStreamConfig{
|
|
||||||
Enabled: true,
|
|
||||||
},
|
|
||||||
// No reason to enable audit by default since it will be suppressed by the log level
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *Logger) WithConfig(cfg model.LogConfig) *Logger {
|
|
||||||
l.config = cfg
|
|
||||||
return l
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *Logger) WithSimpleConfig() *Logger {
|
|
||||||
l.config = model.LogConfig{
|
|
||||||
Level: "info",
|
|
||||||
Json: false,
|
|
||||||
Streams: model.LogStreams{
|
|
||||||
HTTP: model.LogStreamConfig{Enabled: true},
|
|
||||||
App: model.LogStreamConfig{Enabled: true},
|
|
||||||
Audit: model.LogStreamConfig{Enabled: false},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
return l
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *Logger) WithTestConfig() *Logger {
|
|
||||||
l.config = model.LogConfig{
|
|
||||||
Level: "trace",
|
|
||||||
Streams: model.LogStreams{
|
|
||||||
HTTP: model.LogStreamConfig{Enabled: true},
|
|
||||||
App: model.LogStreamConfig{Enabled: true},
|
|
||||||
Audit: model.LogStreamConfig{Enabled: true},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
return l
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *Logger) WithWriter(writer io.Writer) *Logger {
|
|
||||||
l.writer = writer
|
|
||||||
return l
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *Logger) Init() {
|
|
||||||
base := log.With().
|
|
||||||
Timestamp().
|
|
||||||
Logger().
|
|
||||||
Level(l.parseLogLevel(l.config.Level)).Output(l.writer)
|
|
||||||
|
|
||||||
if !l.config.Json {
|
|
||||||
base = base.Output(zerolog.ConsoleWriter{
|
|
||||||
Out: l.writer,
|
|
||||||
TimeFormat: time.RFC3339,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
if base.GetLevel() == zerolog.TraceLevel || base.GetLevel() == zerolog.DebugLevel {
|
|
||||||
base = base.With().Caller().Logger()
|
|
||||||
}
|
|
||||||
|
|
||||||
l.base = base
|
|
||||||
l.audit = l.createLogger("audit", l.config.Streams.Audit)
|
|
||||||
l.HTTP = l.createLogger("http", l.config.Streams.HTTP)
|
|
||||||
l.App = l.createLogger("app", l.config.Streams.App)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *Logger) parseLogLevel(level string) zerolog.Level {
|
|
||||||
if level == "" {
|
|
||||||
return zerolog.InfoLevel
|
|
||||||
}
|
|
||||||
parsed, err := zerolog.ParseLevel(strings.ToLower(level))
|
|
||||||
if err != nil {
|
|
||||||
log.Warn().Err(err).Str("level", level).Msg("Invalid log level, defaulting to error")
|
|
||||||
parsed = zerolog.ErrorLevel
|
|
||||||
}
|
|
||||||
return parsed
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *Logger) createLogger(component string, cfg model.LogStreamConfig) zerolog.Logger {
|
|
||||||
if !cfg.Enabled {
|
|
||||||
return zerolog.Nop()
|
|
||||||
}
|
|
||||||
sub := l.base.With().Str("stream", component).Logger()
|
|
||||||
if cfg.Level != "" {
|
|
||||||
sub = sub.Level(l.parseLogLevel(cfg.Level))
|
|
||||||
}
|
|
||||||
return sub
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *Logger) AuditLoginSuccess(username, provider, ip string) {
|
|
||||||
l.audit.Info().
|
|
||||||
CallerSkipFrame(1).
|
|
||||||
Str("event", "login").
|
|
||||||
Str("result", "success").
|
|
||||||
Str("username", username).
|
|
||||||
Str("provider", provider).
|
|
||||||
Str("ip", ip).
|
|
||||||
Send()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *Logger) AuditLoginFailure(username, provider, ip, reason string) {
|
|
||||||
l.audit.Warn().
|
|
||||||
CallerSkipFrame(1).
|
|
||||||
Str("event", "login").
|
|
||||||
Str("result", "failure").
|
|
||||||
Str("username", username).
|
|
||||||
Str("provider", provider).
|
|
||||||
Str("ip", ip).
|
|
||||||
Str("reason", reason).
|
|
||||||
Send()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *Logger) AuditLogout(username, provider, ip string) {
|
|
||||||
l.audit.Info().
|
|
||||||
CallerSkipFrame(1).
|
|
||||||
Str("event", "logout").
|
|
||||||
Str("result", "success").
|
|
||||||
Str("username", username).
|
|
||||||
Str("provider", provider).
|
|
||||||
Str("ip", ip).
|
|
||||||
Send()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Used for testing
|
|
||||||
func (l *Logger) GetConfig() model.LogConfig {
|
|
||||||
return l.config
|
|
||||||
}
|
|
||||||
@@ -1,173 +0,0 @@
|
|||||||
package logger_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"encoding/json"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/rs/zerolog"
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestLogger(t *testing.T) {
|
|
||||||
type testCase struct {
|
|
||||||
description string
|
|
||||||
run func(t *testing.T)
|
|
||||||
}
|
|
||||||
|
|
||||||
tests := []testCase{
|
|
||||||
{
|
|
||||||
description: "Should create a simple logger with the expected config",
|
|
||||||
run: func(t *testing.T) {
|
|
||||||
l := logger.NewLogger().WithSimpleConfig()
|
|
||||||
l.Init()
|
|
||||||
|
|
||||||
cfg := l.GetConfig()
|
|
||||||
|
|
||||||
assert.Equal(t, cfg, model.LogConfig{
|
|
||||||
Level: "info",
|
|
||||||
Json: false,
|
|
||||||
Streams: model.LogStreams{
|
|
||||||
HTTP: model.LogStreamConfig{Enabled: true},
|
|
||||||
App: model.LogStreamConfig{Enabled: true},
|
|
||||||
Audit: model.LogStreamConfig{Enabled: false},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Should create a test logger with the expected config",
|
|
||||||
run: func(t *testing.T) {
|
|
||||||
l := logger.NewLogger().WithTestConfig()
|
|
||||||
l.Init()
|
|
||||||
|
|
||||||
cfg := l.GetConfig()
|
|
||||||
|
|
||||||
assert.Equal(t, cfg, model.LogConfig{
|
|
||||||
Level: "trace",
|
|
||||||
Json: false,
|
|
||||||
Streams: model.LogStreams{
|
|
||||||
HTTP: model.LogStreamConfig{Enabled: true},
|
|
||||||
App: model.LogStreamConfig{Enabled: true},
|
|
||||||
Audit: model.LogStreamConfig{Enabled: true},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Should create a logger with a custom config",
|
|
||||||
run: func(t *testing.T) {
|
|
||||||
customCfg := model.LogConfig{
|
|
||||||
Level: "debug",
|
|
||||||
Json: true,
|
|
||||||
Streams: model.LogStreams{
|
|
||||||
HTTP: model.LogStreamConfig{Enabled: false},
|
|
||||||
App: model.LogStreamConfig{Enabled: true},
|
|
||||||
Audit: model.LogStreamConfig{Enabled: false},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
l := logger.NewLogger().WithConfig(customCfg)
|
|
||||||
l.Init()
|
|
||||||
|
|
||||||
cfg := l.GetConfig()
|
|
||||||
|
|
||||||
assert.Equal(t, cfg, customCfg)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Default logger should use error type and log json",
|
|
||||||
run: func(t *testing.T) {
|
|
||||||
buf := bytes.Buffer{}
|
|
||||||
|
|
||||||
l := logger.NewLogger().WithWriter(&buf)
|
|
||||||
l.Init()
|
|
||||||
|
|
||||||
cfg := l.GetConfig()
|
|
||||||
|
|
||||||
assert.Equal(t, cfg, model.LogConfig{
|
|
||||||
Level: "error",
|
|
||||||
Json: true,
|
|
||||||
Streams: model.LogStreams{
|
|
||||||
HTTP: model.LogStreamConfig{Enabled: true},
|
|
||||||
App: model.LogStreamConfig{Enabled: true},
|
|
||||||
Audit: model.LogStreamConfig{Enabled: false},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
|
|
||||||
l.App.Error().Msg("test")
|
|
||||||
|
|
||||||
var entry map[string]any
|
|
||||||
err := json.Unmarshal(buf.Bytes(), &entry)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
assert.Equal(t, "test", entry["message"])
|
|
||||||
assert.Equal(t, "app", entry["stream"])
|
|
||||||
assert.Equal(t, "error", entry["level"])
|
|
||||||
assert.NotEmpty(t, entry["time"])
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Should default to error level if an invalid level is provided",
|
|
||||||
run: func(t *testing.T) {
|
|
||||||
buf := bytes.Buffer{}
|
|
||||||
|
|
||||||
customCfg := model.LogConfig{
|
|
||||||
Level: "invalid",
|
|
||||||
Json: false,
|
|
||||||
Streams: model.LogStreams{
|
|
||||||
HTTP: model.LogStreamConfig{Enabled: true},
|
|
||||||
App: model.LogStreamConfig{Enabled: true},
|
|
||||||
Audit: model.LogStreamConfig{Enabled: false},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
l := logger.NewLogger().WithConfig(customCfg).WithWriter(&buf)
|
|
||||||
l.Init()
|
|
||||||
|
|
||||||
assert.Equal(t, zerolog.ErrorLevel, l.App.GetLevel())
|
|
||||||
assert.Equal(t, zerolog.ErrorLevel, l.HTTP.GetLevel())
|
|
||||||
|
|
||||||
// should not get logged
|
|
||||||
l.AuditLoginFailure("test", "test", "test", "test")
|
|
||||||
|
|
||||||
assert.Empty(t, buf.String())
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Should use nop logger for disabled streams",
|
|
||||||
run: func(t *testing.T) {
|
|
||||||
buf := bytes.Buffer{}
|
|
||||||
|
|
||||||
customCfg := model.LogConfig{
|
|
||||||
Level: "info",
|
|
||||||
Json: false,
|
|
||||||
Streams: model.LogStreams{
|
|
||||||
HTTP: model.LogStreamConfig{Enabled: false},
|
|
||||||
App: model.LogStreamConfig{Enabled: true},
|
|
||||||
Audit: model.LogStreamConfig{Enabled: false},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
l := logger.NewLogger().WithConfig(customCfg).WithWriter(&buf)
|
|
||||||
l.Init()
|
|
||||||
|
|
||||||
assert.Equal(t, zerolog.Disabled, l.HTTP.GetLevel())
|
|
||||||
|
|
||||||
l.App.Info().Msg("test")
|
|
||||||
|
|
||||||
l.AuditLoginFailure("test_nop", "test_nop", "test_nop", "test_nop")
|
|
||||||
|
|
||||||
assert.NotEmpty(t, buf.String())
|
|
||||||
assert.NotContains(t, buf.String(), "test_nop")
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, test := range tests {
|
|
||||||
t.Run(test.description, test.run)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -28,41 +28,3 @@ func CoalesceToString(value any) string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func ParseNonEmptyLines(contents string) []string {
|
|
||||||
lines := make([]string, 0)
|
|
||||||
|
|
||||||
for line := range strings.SplitSeq(contents, "\n") {
|
|
||||||
lineTrimmed := strings.TrimSpace(line)
|
|
||||||
if lineTrimmed == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
lines = append(lines, lineTrimmed)
|
|
||||||
}
|
|
||||||
|
|
||||||
return lines
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetStringList(valuesCfg []string, valuesPath string) ([]string, error) {
|
|
||||||
values := make([]string, 0, len(valuesCfg))
|
|
||||||
|
|
||||||
for _, value := range valuesCfg {
|
|
||||||
valueTrimmed := strings.TrimSpace(value)
|
|
||||||
if valueTrimmed == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
values = append(values, valueTrimmed)
|
|
||||||
}
|
|
||||||
|
|
||||||
if valuesPath == "" {
|
|
||||||
return values, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
contents, err := ReadFile(valuesPath)
|
|
||||||
if err != nil {
|
|
||||||
return []string{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
values = append(values, ParseNonEmptyLines(contents)...)
|
|
||||||
return values, nil
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package utils_test
|
package utils_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"os"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
@@ -57,33 +56,3 @@ func TestCompileUserEmail(t *testing.T) {
|
|||||||
// Test with invalid email
|
// Test with invalid email
|
||||||
assert.Equal(t, "user@example.com", utils.CompileUserEmail("user", "example.com"))
|
assert.Equal(t, "user@example.com", utils.CompileUserEmail("user", "example.com"))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestParseNonEmptyLines(t *testing.T) {
|
|
||||||
lines := utils.ParseNonEmptyLines(" first@example.com \n\n second@example.com \n \n")
|
|
||||||
|
|
||||||
assert.Equal(t, []string{"first@example.com", "second@example.com"}, lines)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGetStringList(t *testing.T) {
|
|
||||||
file, err := os.Create("/tmp/tinyauth_list_test_file")
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
_, err = file.WriteString(" third@example.com \n\n fourth@example.com \n")
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
err = file.Close()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
defer os.Remove("/tmp/tinyauth_list_test_file")
|
|
||||||
|
|
||||||
values, err := utils.GetStringList([]string{" first@example.com ", "", "second@example.com"}, "/tmp/tinyauth_list_test_file")
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, []string{"first@example.com", "second@example.com", "third@example.com", "fourth@example.com"}, values)
|
|
||||||
|
|
||||||
values, err = utils.GetStringList(nil, "")
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, []string{}, values)
|
|
||||||
|
|
||||||
values, err = utils.GetStringList(nil, "/tmp/non_existing_list_file")
|
|
||||||
assert.ErrorContains(t, err, "no such file or directory")
|
|
||||||
assert.Equal(t, []string{}, values)
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -0,0 +1,39 @@
|
|||||||
|
package tlog
|
||||||
|
|
||||||
|
import "github.com/gin-gonic/gin"
|
||||||
|
|
||||||
|
// functions here use CallerSkipFrame to ensure correct caller info is logged
|
||||||
|
|
||||||
|
func AuditLoginSuccess(c *gin.Context, username, provider string) {
|
||||||
|
Audit.Info().
|
||||||
|
CallerSkipFrame(1).
|
||||||
|
Str("event", "login").
|
||||||
|
Str("result", "success").
|
||||||
|
Str("username", username).
|
||||||
|
Str("provider", provider).
|
||||||
|
Str("ip", c.ClientIP()).
|
||||||
|
Send()
|
||||||
|
}
|
||||||
|
|
||||||
|
func AuditLoginFailure(c *gin.Context, username, provider string, reason string) {
|
||||||
|
Audit.Warn().
|
||||||
|
CallerSkipFrame(1).
|
||||||
|
Str("event", "login").
|
||||||
|
Str("result", "failure").
|
||||||
|
Str("username", username).
|
||||||
|
Str("provider", provider).
|
||||||
|
Str("ip", c.ClientIP()).
|
||||||
|
Str("reason", reason).
|
||||||
|
Send()
|
||||||
|
}
|
||||||
|
|
||||||
|
func AuditLogout(c *gin.Context, username, provider string) {
|
||||||
|
Audit.Info().
|
||||||
|
CallerSkipFrame(1).
|
||||||
|
Str("event", "logout").
|
||||||
|
Str("result", "success").
|
||||||
|
Str("username", username).
|
||||||
|
Str("provider", provider).
|
||||||
|
Str("ip", c.ClientIP()).
|
||||||
|
Send()
|
||||||
|
}
|
||||||
@@ -0,0 +1,97 @@
|
|||||||
|
package tlog
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/rs/zerolog"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Logger struct {
|
||||||
|
Audit zerolog.Logger
|
||||||
|
HTTP zerolog.Logger
|
||||||
|
App zerolog.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
Audit zerolog.Logger
|
||||||
|
HTTP zerolog.Logger
|
||||||
|
App zerolog.Logger
|
||||||
|
)
|
||||||
|
|
||||||
|
func NewLogger(cfg model.LogConfig) *Logger {
|
||||||
|
baseLogger := log.With().
|
||||||
|
Timestamp().
|
||||||
|
Caller().
|
||||||
|
Logger().
|
||||||
|
Level(parseLogLevel(cfg.Level))
|
||||||
|
|
||||||
|
if !cfg.Json {
|
||||||
|
baseLogger = baseLogger.Output(zerolog.ConsoleWriter{
|
||||||
|
Out: os.Stderr,
|
||||||
|
TimeFormat: time.RFC3339,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Logger{
|
||||||
|
Audit: createLogger("audit", cfg.Streams.Audit, baseLogger),
|
||||||
|
HTTP: createLogger("http", cfg.Streams.HTTP, baseLogger),
|
||||||
|
App: createLogger("app", cfg.Streams.App, baseLogger),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewSimpleLogger() *Logger {
|
||||||
|
return NewLogger(model.LogConfig{
|
||||||
|
Level: "info",
|
||||||
|
Json: false,
|
||||||
|
Streams: model.LogStreams{
|
||||||
|
HTTP: model.LogStreamConfig{Enabled: true},
|
||||||
|
App: model.LogStreamConfig{Enabled: true},
|
||||||
|
Audit: model.LogStreamConfig{Enabled: false},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewTestLogger() *Logger {
|
||||||
|
return NewLogger(model.LogConfig{
|
||||||
|
Level: "trace",
|
||||||
|
Streams: model.LogStreams{
|
||||||
|
HTTP: model.LogStreamConfig{Enabled: true},
|
||||||
|
App: model.LogStreamConfig{Enabled: true},
|
||||||
|
Audit: model.LogStreamConfig{Enabled: true},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Logger) Init() {
|
||||||
|
Audit = l.Audit
|
||||||
|
HTTP = l.HTTP
|
||||||
|
App = l.App
|
||||||
|
}
|
||||||
|
|
||||||
|
func createLogger(component string, streamCfg model.LogStreamConfig, baseLogger zerolog.Logger) zerolog.Logger {
|
||||||
|
if !streamCfg.Enabled {
|
||||||
|
return zerolog.Nop()
|
||||||
|
}
|
||||||
|
subLogger := baseLogger.With().Str("log_stream", component).Logger()
|
||||||
|
// override level if specified, otherwise use base level
|
||||||
|
if streamCfg.Level != "" {
|
||||||
|
subLogger = subLogger.Level(parseLogLevel(streamCfg.Level))
|
||||||
|
}
|
||||||
|
return subLogger
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseLogLevel(level string) zerolog.Level {
|
||||||
|
if level == "" {
|
||||||
|
return zerolog.InfoLevel
|
||||||
|
}
|
||||||
|
parsedLevel, err := zerolog.ParseLevel(strings.ToLower(level))
|
||||||
|
if err != nil {
|
||||||
|
log.Warn().Err(err).Str("level", level).Msg("Invalid log level, defaulting to info")
|
||||||
|
parsedLevel = zerolog.InfoLevel
|
||||||
|
}
|
||||||
|
return parsedLevel
|
||||||
|
}
|
||||||
@@ -0,0 +1,93 @@
|
|||||||
|
package tlog_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||||
|
|
||||||
|
"github.com/rs/zerolog"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewLogger(t *testing.T) {
|
||||||
|
cfg := model.LogConfig{
|
||||||
|
Level: "debug",
|
||||||
|
Json: true,
|
||||||
|
Streams: model.LogStreams{
|
||||||
|
HTTP: model.LogStreamConfig{Enabled: true, Level: "info"},
|
||||||
|
App: model.LogStreamConfig{Enabled: true, Level: ""},
|
||||||
|
Audit: model.LogStreamConfig{Enabled: false, Level: ""},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
logger := tlog.NewLogger(cfg)
|
||||||
|
|
||||||
|
assert.NotNil(t, logger)
|
||||||
|
assert.Equal(t, zerolog.InfoLevel, logger.HTTP.GetLevel())
|
||||||
|
assert.Equal(t, zerolog.DebugLevel, logger.App.GetLevel())
|
||||||
|
assert.Equal(t, zerolog.Disabled, logger.Audit.GetLevel())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewSimpleLogger(t *testing.T) {
|
||||||
|
logger := tlog.NewSimpleLogger()
|
||||||
|
assert.NotNil(t, logger)
|
||||||
|
assert.Equal(t, zerolog.InfoLevel, logger.HTTP.GetLevel())
|
||||||
|
assert.Equal(t, zerolog.InfoLevel, logger.App.GetLevel())
|
||||||
|
assert.Equal(t, zerolog.Disabled, logger.Audit.GetLevel())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoggerInit(t *testing.T) {
|
||||||
|
logger := tlog.NewSimpleLogger()
|
||||||
|
logger.Init()
|
||||||
|
|
||||||
|
assert.NotEqual(t, zerolog.Disabled, tlog.App.GetLevel())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoggerWithDisabledStreams(t *testing.T) {
|
||||||
|
cfg := model.LogConfig{
|
||||||
|
Level: "info",
|
||||||
|
Json: false,
|
||||||
|
Streams: model.LogStreams{
|
||||||
|
HTTP: model.LogStreamConfig{Enabled: false},
|
||||||
|
App: model.LogStreamConfig{Enabled: false},
|
||||||
|
Audit: model.LogStreamConfig{Enabled: false},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
logger := tlog.NewLogger(cfg)
|
||||||
|
|
||||||
|
assert.Equal(t, zerolog.Disabled, logger.HTTP.GetLevel())
|
||||||
|
assert.Equal(t, zerolog.Disabled, logger.App.GetLevel())
|
||||||
|
assert.Equal(t, zerolog.Disabled, logger.Audit.GetLevel())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLogStreamField(t *testing.T) {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
|
||||||
|
cfg := model.LogConfig{
|
||||||
|
Level: "info",
|
||||||
|
Json: true,
|
||||||
|
Streams: model.LogStreams{
|
||||||
|
HTTP: model.LogStreamConfig{Enabled: true},
|
||||||
|
App: model.LogStreamConfig{Enabled: true},
|
||||||
|
Audit: model.LogStreamConfig{Enabled: true},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
logger := tlog.NewLogger(cfg)
|
||||||
|
|
||||||
|
// Override output for HTTP logger to capture output
|
||||||
|
logger.HTTP = logger.HTTP.Output(&buf)
|
||||||
|
|
||||||
|
logger.HTTP.Info().Msg("test message")
|
||||||
|
|
||||||
|
var logEntry map[string]interface{}
|
||||||
|
err := json.Unmarshal(buf.Bytes(), &logEntry)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, "http", logEntry["log_stream"])
|
||||||
|
assert.Equal(t, "test message", logEntry["message"])
|
||||||
|
}
|
||||||
@@ -13,7 +13,7 @@ func ParseUsers(usersStr []string, userAttributes map[string]model.UserAttribute
|
|||||||
var users []model.LocalUser
|
var users []model.LocalUser
|
||||||
|
|
||||||
if len(usersStr) == 0 {
|
if len(usersStr) == 0 {
|
||||||
return nil, nil
|
return &users, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, user := range usersStr {
|
for _, user := range usersStr {
|
||||||
@@ -34,9 +34,32 @@ func ParseUsers(usersStr []string, userAttributes map[string]model.UserAttribute
|
|||||||
}
|
}
|
||||||
|
|
||||||
func GetUsers(usersCfg []string, usersPath string, userAttributes map[string]model.UserAttributes) (*[]model.LocalUser, error) {
|
func GetUsers(usersCfg []string, usersPath string, userAttributes map[string]model.UserAttributes) (*[]model.LocalUser, error) {
|
||||||
usersStr, err := GetStringList(usersCfg, usersPath)
|
var usersStr []string
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
if len(usersCfg) == 0 && usersPath == "" {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(usersCfg) > 0 {
|
||||||
|
usersStr = append(usersStr, usersCfg...)
|
||||||
|
}
|
||||||
|
|
||||||
|
if usersPath != "" {
|
||||||
|
contents, err := ReadFile(usersPath)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
lines := strings.SplitSeq(contents, "\n")
|
||||||
|
|
||||||
|
for line := range lines {
|
||||||
|
lineTrimmed := strings.TrimSpace(line)
|
||||||
|
if lineTrimmed == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
usersStr = append(usersStr, lineTrimmed)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return ParseUsers(usersStr, userAttributes)
|
return ParseUsers(usersStr, userAttributes)
|
||||||
|
|||||||
Reference in New Issue
Block a user