mirror of
https://github.com/steveiliop56/tinyauth.git
synced 2026-05-14 16:20:14 +00:00
Compare commits
6 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 5aeb886523 | |||
| b06b60150f | |||
| 4077bacfdf | |||
| 4c0181c5e2 | |||
| 44a7cbf41b | |||
| d90e3d652d |
@@ -91,8 +91,6 @@ TINYAUTH_APPS_name_LDAP_GROUPS=
|
|||||||
|
|
||||||
# Comma-separated list of allowed OAuth domains.
|
# Comma-separated list of allowed OAuth domains.
|
||||||
TINYAUTH_OAUTH_WHITELIST=
|
TINYAUTH_OAUTH_WHITELIST=
|
||||||
# Path to the OAuth whitelist file.
|
|
||||||
TINYAUTH_OAUTH_WHITELISTFILE=
|
|
||||||
# The OAuth provider to use for automatic redirection.
|
# The OAuth provider to use for automatic redirection.
|
||||||
TINYAUTH_OAUTH_AUTOREDIRECT=
|
TINYAUTH_OAUTH_AUTOREDIRECT=
|
||||||
# OAuth client ID.
|
# OAuth client ID.
|
||||||
|
|||||||
@@ -38,6 +38,6 @@ jobs:
|
|||||||
retention-days: 5
|
retention-days: 5
|
||||||
|
|
||||||
- name: Upload to code-scanning
|
- name: Upload to code-scanning
|
||||||
uses: github/codeql-action/upload-sarif@68bde559dea0fdcac2102bfdf6230c5f70eb485e # v4
|
uses: github/codeql-action/upload-sarif@e46ed2cbd01164d986452f91f178727624ae40d7 # v4
|
||||||
with:
|
with:
|
||||||
sarif_file: results.sarif
|
sarif_file: results.sarif
|
||||||
|
|||||||
@@ -48,6 +48,3 @@ __debug_*
|
|||||||
|
|
||||||
# testing config
|
# testing config
|
||||||
config.certify.yml
|
config.certify.yml
|
||||||
|
|
||||||
# deepsec
|
|
||||||
/.deepsec
|
|
||||||
|
|||||||
@@ -65,7 +65,7 @@ Tinyauth is licensed under the GNU General Public License v3.0. TL;DR — You ma
|
|||||||
|
|
||||||
A big thank you to the following people for providing me with more coffee:
|
A big thank you to the following people for providing me with more coffee:
|
||||||
|
|
||||||
<!-- sponsors --><a href="https://github.com/erwinkramer"><img src="https://github.com/erwinkramer.png" width="64px" alt="User avatar: erwinkramer" /></a> <a href="https://github.com/nicotsx"><img src="https://github.com/nicotsx.png" width="64px" alt="User avatar: nicotsx" /></a> <a href="https://github.com/SimpleHomelab"><img src="https://github.com/SimpleHomelab.png" width="64px" alt="User avatar: SimpleHomelab" /></a> <a href="https://github.com/jmadden91"><img src="https://github.com/jmadden91.png" width="64px" alt="User avatar: jmadden91" /></a> <a href="https://github.com/tribor"><img src="https://github.com/tribor.png" width="64px" alt="User avatar: tribor" /></a> <a href="https://github.com/eliasbenb"><img src="https://github.com/eliasbenb.png" width="64px" alt="User avatar: eliasbenb" /></a> <a href="https://github.com/afunworm"><img src="https://github.com/afunworm.png" width="64px" alt="User avatar: afunworm" /></a> <a href="https://github.com/chip-well"><img src="https://github.com/chip-well.png" width="64px" alt="User avatar: chip-well" /></a> <a href="https://github.com/Lancelot-Enguerrand"><img src="https://github.com/Lancelot-Enguerrand.png" width="64px" alt="User avatar: Lancelot-Enguerrand" /></a> <a href="https://github.com/allgoewer"><img src="https://github.com/allgoewer.png" width="64px" alt="User avatar: allgoewer" /></a> <a href="https://github.com/NEANC"><img src="https://github.com/NEANC.png" width="64px" alt="User avatar: NEANC" /></a> <a href="https://github.com/ax-mad"><img src="https://github.com/ax-mad.png" width="64px" alt="User avatar: ax-mad" /></a> <a href="https://github.com/stegratech"><img src="https://github.com/stegratech.png" width="64px" alt="User avatar: stegratech" /></a> <a href="https://github.com/apearson"><img src="https://github.com/apearson.png" width="64px" alt="User avatar: apearson" /></a> <!-- sponsors -->
|
<!-- sponsors --><a href="https://github.com/erwinkramer"><img src="https://github.com/erwinkramer.png" width="64px" alt="User avatar: erwinkramer" /></a> <a href="https://github.com/nicotsx"><img src="https://github.com/nicotsx.png" width="64px" alt="User avatar: nicotsx" /></a> <a href="https://github.com/SimpleHomelab"><img src="https://github.com/SimpleHomelab.png" width="64px" alt="User avatar: SimpleHomelab" /></a> <a href="https://github.com/jmadden91"><img src="https://github.com/jmadden91.png" width="64px" alt="User avatar: jmadden91" /></a> <a href="https://github.com/tribor"><img src="https://github.com/tribor.png" width="64px" alt="User avatar: tribor" /></a> <a href="https://github.com/eliasbenb"><img src="https://github.com/eliasbenb.png" width="64px" alt="User avatar: eliasbenb" /></a> <a href="https://github.com/afunworm"><img src="https://github.com/afunworm.png" width="64px" alt="User avatar: afunworm" /></a> <a href="https://github.com/chip-well"><img src="https://github.com/chip-well.png" width="64px" alt="User avatar: chip-well" /></a> <a href="https://github.com/Lancelot-Enguerrand"><img src="https://github.com/Lancelot-Enguerrand.png" width="64px" alt="User avatar: Lancelot-Enguerrand" /></a> <a href="https://github.com/allgoewer"><img src="https://github.com/allgoewer.png" width="64px" alt="User avatar: allgoewer" /></a> <a href="https://github.com/NEANC"><img src="https://github.com/NEANC.png" width="64px" alt="User avatar: NEANC" /></a> <a href="https://github.com/ax-mad"><img src="https://github.com/ax-mad.png" width="64px" alt="User avatar: ax-mad" /></a> <a href="https://github.com/stegratech"><img src="https://github.com/stegratech.png" width="64px" alt="User avatar: stegratech" /></a> <!-- sponsors -->
|
||||||
|
|
||||||
## Acknowledgements
|
## Acknowledgements
|
||||||
|
|
||||||
|
|||||||
+2
-50
@@ -2,56 +2,8 @@
|
|||||||
|
|
||||||
## Supported Versions
|
## Supported Versions
|
||||||
|
|
||||||
It is recommended to use the [latest](https://github.com/tinyauthapp/tinyauth/releases/latest) available version of Tinyauth. This is because it includes security fixes, new features and dependency updates. Older versions, especially major ones, are not supported and won't receive security or patch updates.
|
It is recommended to use the [latest](https://github.com/tinyauthapp/tinyauth/releases/latest) available version of tinyauth. This is because it includes security fixes, new features and dependency updates. Older versions, especially major ones, are not supported and won't receive security or patch updates.
|
||||||
|
|
||||||
## Reporting a Vulnerability
|
## Reporting a Vulnerability
|
||||||
|
|
||||||
Please **do not** report security vulnerabilities through public GitHub issues, discussions, or pull requests as I won't be able to patch them in time and they may get exploited by malicious actors.
|
Due to the nature of this app, it needs to be secure. If you discover any security issues or vulnerabilities in the app please contact me as soon as possible at <security@tinyauth.app>. Please do not use the issues section to report security issues as I won't be able to patch them in time and they may get exploited by malicious actors.
|
||||||
|
|
||||||
Instead, report them privately using [GitHub's Private Vulnerability Reporting](https://docs.github.com/en/code-security/security-advisories/guidance-on-reporting-and-writing-information-about-vulnerabilities/privately-reporting-a-security-vulnerability) via the **Security** tab of this repository.
|
|
||||||
|
|
||||||
Or send us an email at <security@tinyauth.app>.
|
|
||||||
|
|
||||||
### A note on AI-assisted reports
|
|
||||||
|
|
||||||
If AI tooling (LLMs, automated scanners, agentic assistants, etc.) helped you discover, analyse, or write up this issue, please say so in your report. This isn't a judgement - AI-assisted findings are welcome - but disclosing it up front helps maintainers calibrate how much additional verification a report needs, and tends to make the report itself clearer.
|
|
||||||
|
|
||||||
When submitting a report, please use the structure below so it can be triaged quickly.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### 1. Summary
|
|
||||||
|
|
||||||
A short, one-paragraph description of the vulnerability and its impact (e.g. what an attacker can achieve, who is affected, and under what conditions).
|
|
||||||
|
|
||||||
### 2. Steps to Reproduce / Proof of Concept
|
|
||||||
|
|
||||||
Provide a minimal, reliable reproduction:
|
|
||||||
|
|
||||||
1. Step one
|
|
||||||
2. Step two
|
|
||||||
3. Step three
|
|
||||||
|
|
||||||
Include any required input, payloads, configuration, or code snippets. Attach a PoC script or screenshots where helpful.
|
|
||||||
|
|
||||||
### 3. Expected vs. Actual Behaviour
|
|
||||||
|
|
||||||
- **Expected:** what *should* happen
|
|
||||||
- **Actual:** what *does* happen, and why it's a security issue
|
|
||||||
|
|
||||||
### 4. Suggested Fix or Mitigation *(optional)*
|
|
||||||
|
|
||||||
If you have an idea for how to address the issue, describe it here. A private gist link is welcome but not required.
|
|
||||||
|
|
||||||
- **Have you tested this fix?** Yes / No
|
|
||||||
- **If yes,** briefly describe how it was tested and what was verified.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## What to Expect
|
|
||||||
|
|
||||||
- **Acknowledgement** within a reasonable timeframe after receiving your report
|
|
||||||
- **Updates** as the issue is investigated and addressed
|
|
||||||
- **Public credit** in the resulting advisory, along with any **CVE assigned**, unless you'd prefer to stay anonymous
|
|
||||||
|
|
||||||
We follow a **90-day coordinated disclosure** window: please allow up to 90 days from the date of your report for the issue to be investigated and patched before publicly disclosing it. The publication date - whether earlier if a fix lands sooner, or later if more time is genuinely needed - will be agreed with you in advance.
|
|
||||||
|
|||||||
@@ -6,8 +6,8 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"charm.land/huh/v2"
|
"charm.land/huh/v2"
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||||
"github.com/tinyauthapp/paerser/cli"
|
"github.com/tinyauthapp/paerser/cli"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
|
||||||
"golang.org/x/crypto/bcrypt"
|
"golang.org/x/crypto/bcrypt"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -40,8 +40,7 @@ func createUserCmd() *cli.Command {
|
|||||||
Configuration: tCfg,
|
Configuration: tCfg,
|
||||||
Resources: loaders,
|
Resources: loaders,
|
||||||
Run: func(_ []string) error {
|
Run: func(_ []string) error {
|
||||||
log := logger.NewLogger().WithSimpleConfig()
|
tlog.NewSimpleLogger().Init()
|
||||||
log.Init()
|
|
||||||
|
|
||||||
if tCfg.Interactive {
|
if tCfg.Interactive {
|
||||||
form := huh.NewForm(
|
form := huh.NewForm(
|
||||||
@@ -74,7 +73,7 @@ func createUserCmd() *cli.Command {
|
|||||||
return errors.New("username and password cannot be empty")
|
return errors.New("username and password cannot be empty")
|
||||||
}
|
}
|
||||||
|
|
||||||
log.App.Info().Str("username", tCfg.Username).Msg("Creating user")
|
tlog.App.Info().Str("username", tCfg.Username).Msg("Creating user")
|
||||||
|
|
||||||
passwd, err := bcrypt.GenerateFromPassword([]byte(tCfg.Password), bcrypt.DefaultCost)
|
passwd, err := bcrypt.GenerateFromPassword([]byte(tCfg.Password), bcrypt.DefaultCost)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -87,7 +86,7 @@ func createUserCmd() *cli.Command {
|
|||||||
passwdStr = strings.ReplaceAll(passwdStr, "$", "$$")
|
passwdStr = strings.ReplaceAll(passwdStr, "$", "$$")
|
||||||
}
|
}
|
||||||
|
|
||||||
log.App.Info().Str("user", fmt.Sprintf("%s:%s", tCfg.Username, passwdStr)).Msg("User created")
|
tlog.App.Info().Str("user", fmt.Sprintf("%s:%s", tCfg.Username, passwdStr)).Msg("User created")
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||||
|
|
||||||
"charm.land/huh/v2"
|
"charm.land/huh/v2"
|
||||||
"github.com/mdp/qrterminal/v3"
|
"github.com/mdp/qrterminal/v3"
|
||||||
@@ -40,8 +40,7 @@ func generateTotpCmd() *cli.Command {
|
|||||||
Configuration: tCfg,
|
Configuration: tCfg,
|
||||||
Resources: loaders,
|
Resources: loaders,
|
||||||
Run: func(_ []string) error {
|
Run: func(_ []string) error {
|
||||||
log := logger.NewLogger().WithSimpleConfig()
|
tlog.NewSimpleLogger().Init()
|
||||||
log.Init()
|
|
||||||
|
|
||||||
if tCfg.Interactive {
|
if tCfg.Interactive {
|
||||||
form := huh.NewForm(
|
form := huh.NewForm(
|
||||||
@@ -89,9 +88,9 @@ func generateTotpCmd() *cli.Command {
|
|||||||
|
|
||||||
secret := key.Secret()
|
secret := key.Secret()
|
||||||
|
|
||||||
log.App.Info().Str("secret", secret).Msg("Generated TOTP secret")
|
tlog.App.Info().Str("secret", secret).Msg("Generated TOTP secret")
|
||||||
|
|
||||||
log.App.Info().Msg("Generated QR code")
|
tlog.App.Info().Msg("Generated QR code")
|
||||||
|
|
||||||
config := qrterminal.Config{
|
config := qrterminal.Config{
|
||||||
Level: qrterminal.L,
|
Level: qrterminal.L,
|
||||||
@@ -110,7 +109,7 @@ func generateTotpCmd() *cli.Command {
|
|||||||
user.Password = strings.ReplaceAll(user.Password, "$", "$$")
|
user.Password = strings.ReplaceAll(user.Password, "$", "$$")
|
||||||
}
|
}
|
||||||
|
|
||||||
log.App.Info().Str("user", fmt.Sprintf("%s:%s:%s", user.Username, user.Password, user.TOTPSecret)).Msg("Add the totp secret to your authenticator app then use the verify command to ensure everything is working correctly.")
|
tlog.App.Info().Str("user", fmt.Sprintf("%s:%s:%s", user.Username, user.Password, user.TOTPSecret)).Msg("Add the totp secret to your authenticator app then use the verify command to ensure everything is working correctly.")
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -9,8 +9,8 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||||
"github.com/tinyauthapp/paerser/cli"
|
"github.com/tinyauthapp/paerser/cli"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type healthzResponse struct {
|
type healthzResponse struct {
|
||||||
@@ -26,8 +26,7 @@ func healthcheckCmd() *cli.Command {
|
|||||||
Resources: nil,
|
Resources: nil,
|
||||||
AllowArg: true,
|
AllowArg: true,
|
||||||
Run: func(args []string) error {
|
Run: func(args []string) error {
|
||||||
log := logger.NewLogger().WithSimpleConfig()
|
tlog.NewSimpleLogger().Init()
|
||||||
log.Init()
|
|
||||||
|
|
||||||
srvAddr := os.Getenv("TINYAUTH_SERVER_ADDRESS")
|
srvAddr := os.Getenv("TINYAUTH_SERVER_ADDRESS")
|
||||||
if srvAddr == "" {
|
if srvAddr == "" {
|
||||||
@@ -49,7 +48,7 @@ func healthcheckCmd() *cli.Command {
|
|||||||
return errors.New("Could not determine app URL")
|
return errors.New("Could not determine app URL")
|
||||||
}
|
}
|
||||||
|
|
||||||
log.App.Info().Str("app_url", appUrl).Msg("Performing health check")
|
tlog.App.Info().Str("app_url", appUrl).Msg("Performing health check")
|
||||||
|
|
||||||
client := http.Client{
|
client := http.Client{
|
||||||
Timeout: 30 * time.Second,
|
Timeout: 30 * time.Second,
|
||||||
@@ -87,7 +86,7 @@ func healthcheckCmd() *cli.Command {
|
|||||||
return fmt.Errorf("failed to decode response: %w", err)
|
return fmt.Errorf("failed to decode response: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.App.Info().Interface("response", healthResp).Msg("Tinyauth is healthy")
|
tlog.App.Info().Interface("response", healthResp).Msg("Tinyauth is healthy")
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"github.com/tinyauthapp/tinyauth/internal/bootstrap"
|
"github.com/tinyauthapp/tinyauth/internal/bootstrap"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/loaders"
|
"github.com/tinyauthapp/tinyauth/internal/utils/loaders"
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||||
|
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
"github.com/tinyauthapp/paerser/cli"
|
"github.com/tinyauthapp/paerser/cli"
|
||||||
@@ -108,6 +109,11 @@ func main() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func runCmd(cfg model.Config) error {
|
func runCmd(cfg model.Config) error {
|
||||||
|
logger := tlog.NewLogger(cfg.Log)
|
||||||
|
logger.Init()
|
||||||
|
|
||||||
|
tlog.App.Info().Str("version", model.Version).Msg("Starting tinyauth")
|
||||||
|
|
||||||
app := bootstrap.NewBootstrapApp(cfg)
|
app := bootstrap.NewBootstrapApp(cfg)
|
||||||
|
|
||||||
err := app.Setup()
|
err := app.Setup()
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||||
|
|
||||||
"charm.land/huh/v2"
|
"charm.land/huh/v2"
|
||||||
"github.com/pquerna/otp/totp"
|
"github.com/pquerna/otp/totp"
|
||||||
@@ -44,8 +44,7 @@ func verifyUserCmd() *cli.Command {
|
|||||||
Configuration: tCfg,
|
Configuration: tCfg,
|
||||||
Resources: loaders,
|
Resources: loaders,
|
||||||
Run: func(_ []string) error {
|
Run: func(_ []string) error {
|
||||||
log := logger.NewLogger().WithSimpleConfig()
|
tlog.NewSimpleLogger().Init()
|
||||||
log.Init()
|
|
||||||
|
|
||||||
if tCfg.Interactive {
|
if tCfg.Interactive {
|
||||||
form := huh.NewForm(
|
form := huh.NewForm(
|
||||||
@@ -98,9 +97,9 @@ func verifyUserCmd() *cli.Command {
|
|||||||
|
|
||||||
if user.TOTPSecret == "" {
|
if user.TOTPSecret == "" {
|
||||||
if tCfg.Totp != "" {
|
if tCfg.Totp != "" {
|
||||||
log.App.Warn().Msg("User does not have TOTP secret")
|
tlog.App.Warn().Msg("User does not have TOTP secret")
|
||||||
}
|
}
|
||||||
log.App.Info().Msg("User verified")
|
tlog.App.Info().Msg("User verified")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -110,7 +109,7 @@ func verifyUserCmd() *cli.Command {
|
|||||||
return fmt.Errorf("TOTP code incorrect")
|
return fmt.Errorf("TOTP code incorrect")
|
||||||
}
|
}
|
||||||
|
|
||||||
log.App.Info().Msg("User verified")
|
tlog.App.Info().Msg("User verified")
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -18,11 +18,11 @@ require (
|
|||||||
github.com/stretchr/testify v1.11.1
|
github.com/stretchr/testify v1.11.1
|
||||||
github.com/tinyauthapp/paerser v0.0.0-20260410140347-85c3740d6298
|
github.com/tinyauthapp/paerser v0.0.0-20260410140347-85c3740d6298
|
||||||
github.com/weppos/publicsuffix-go v0.50.3
|
github.com/weppos/publicsuffix-go v0.50.3
|
||||||
golang.org/x/crypto v0.51.0
|
golang.org/x/crypto v0.50.0
|
||||||
golang.org/x/oauth2 v0.36.0
|
golang.org/x/oauth2 v0.36.0
|
||||||
k8s.io/apimachinery v0.36.1
|
k8s.io/apimachinery v0.32.2
|
||||||
k8s.io/client-go v0.36.1
|
k8s.io/client-go v0.32.2
|
||||||
modernc.org/sqlite v1.50.1
|
modernc.org/sqlite v1.49.1
|
||||||
)
|
)
|
||||||
|
|
||||||
require (
|
require (
|
||||||
@@ -63,7 +63,7 @@ require (
|
|||||||
github.com/docker/go-units v0.5.0 // indirect
|
github.com/docker/go-units v0.5.0 // indirect
|
||||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||||
github.com/felixge/httpsnoop v1.0.4 // indirect
|
github.com/felixge/httpsnoop v1.0.4 // indirect
|
||||||
github.com/fxamacker/cbor/v2 v2.9.0 // indirect
|
github.com/fxamacker/cbor/v2 v2.7.0 // indirect
|
||||||
github.com/gabriel-vasile/mimetype v1.4.12 // indirect
|
github.com/gabriel-vasile/mimetype v1.4.12 // indirect
|
||||||
github.com/gin-contrib/sse v1.1.0 // indirect
|
github.com/gin-contrib/sse v1.1.0 // indirect
|
||||||
github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667 // indirect
|
github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667 // indirect
|
||||||
@@ -74,6 +74,9 @@ require (
|
|||||||
github.com/go-playground/validator/v10 v10.30.1 // indirect
|
github.com/go-playground/validator/v10 v10.30.1 // indirect
|
||||||
github.com/goccy/go-json v0.10.5 // indirect
|
github.com/goccy/go-json v0.10.5 // indirect
|
||||||
github.com/goccy/go-yaml v1.19.2 // indirect
|
github.com/goccy/go-yaml v1.19.2 // indirect
|
||||||
|
github.com/gogo/protobuf v1.3.2 // indirect
|
||||||
|
github.com/google/go-cmp v0.7.0 // indirect
|
||||||
|
github.com/google/gofuzz v1.2.0 // indirect
|
||||||
github.com/huandu/xstrings v1.5.0 // indirect
|
github.com/huandu/xstrings v1.5.0 // indirect
|
||||||
github.com/json-iterator/go v1.1.12 // indirect
|
github.com/json-iterator/go v1.1.12 // indirect
|
||||||
github.com/klauspost/cpuid/v2 v2.3.0 // indirect
|
github.com/klauspost/cpuid/v2 v2.3.0 // indirect
|
||||||
@@ -90,7 +93,7 @@ require (
|
|||||||
github.com/moby/sys/atomicwriter v0.1.0 // indirect
|
github.com/moby/sys/atomicwriter v0.1.0 // indirect
|
||||||
github.com/moby/term v0.5.2 // indirect
|
github.com/moby/term v0.5.2 // indirect
|
||||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
||||||
github.com/modern-go/reflect2 v1.0.3-0.20250322232337-35a7c28c31ee // indirect
|
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||||
github.com/muesli/cancelreader v0.2.2 // indirect
|
github.com/muesli/cancelreader v0.2.2 // indirect
|
||||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
|
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
|
||||||
github.com/ncruces/go-strftime v1.0.0 // indirect
|
github.com/ncruces/go-strftime v1.0.0 // indirect
|
||||||
@@ -118,28 +121,25 @@ require (
|
|||||||
go.opentelemetry.io/otel/sdk v1.43.0 // indirect
|
go.opentelemetry.io/otel/sdk v1.43.0 // indirect
|
||||||
go.opentelemetry.io/otel/sdk/metric v1.43.0 // indirect
|
go.opentelemetry.io/otel/sdk/metric v1.43.0 // indirect
|
||||||
go.opentelemetry.io/otel/trace v1.43.0 // indirect
|
go.opentelemetry.io/otel/trace v1.43.0 // indirect
|
||||||
go.yaml.in/yaml/v2 v2.4.3 // indirect
|
|
||||||
golang.org/x/arch v0.22.0 // indirect
|
golang.org/x/arch v0.22.0 // indirect
|
||||||
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect
|
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect
|
||||||
golang.org/x/net v0.53.0 // indirect
|
golang.org/x/net v0.52.0 // indirect
|
||||||
golang.org/x/sync v0.20.0 // indirect
|
golang.org/x/sync v0.20.0 // indirect
|
||||||
golang.org/x/sys v0.44.0 // indirect
|
golang.org/x/sys v0.43.0 // indirect
|
||||||
golang.org/x/term v0.43.0 // indirect
|
golang.org/x/term v0.42.0 // indirect
|
||||||
golang.org/x/text v0.37.0 // indirect
|
golang.org/x/text v0.36.0 // indirect
|
||||||
golang.org/x/time v0.14.0 // indirect
|
golang.org/x/time v0.12.0 // indirect
|
||||||
google.golang.org/protobuf v1.36.12-0.20260120151049-f2248ac996af // indirect
|
google.golang.org/protobuf v1.36.11 // indirect
|
||||||
gopkg.in/inf.v0 v0.9.1 // indirect
|
gopkg.in/inf.v0 v0.9.1 // indirect
|
||||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||||
gotest.tools/v3 v3.5.2 // indirect
|
gotest.tools/v3 v3.5.2 // indirect
|
||||||
k8s.io/klog/v2 v2.140.0 // indirect
|
k8s.io/klog/v2 v2.130.1 // indirect
|
||||||
k8s.io/kube-openapi v0.0.0-20260317180543-43fb72c5454a // indirect
|
k8s.io/utils v0.0.0-20241104100929-3ea5e8cea738 // indirect
|
||||||
k8s.io/utils v0.0.0-20260210185600-b8788abfbbc2 // indirect
|
modernc.org/libc v1.72.0 // indirect
|
||||||
modernc.org/libc v1.72.3 // indirect
|
|
||||||
modernc.org/mathutil v1.7.1 // indirect
|
modernc.org/mathutil v1.7.1 // indirect
|
||||||
modernc.org/memory v1.11.0 // indirect
|
modernc.org/memory v1.11.0 // indirect
|
||||||
rsc.io/qr v0.2.0 // indirect
|
rsc.io/qr v0.2.0 // indirect
|
||||||
sigs.k8s.io/json v0.0.0-20250730193827-2d320260d730 // indirect
|
sigs.k8s.io/json v0.0.0-20241010143419-9aa6b5e7a4b3 // indirect
|
||||||
sigs.k8s.io/randfill v1.0.0 // indirect
|
sigs.k8s.io/structured-merge-diff/v4 v4.4.2 // indirect
|
||||||
sigs.k8s.io/structured-merge-diff/v6 v6.3.2 // indirect
|
sigs.k8s.io/yaml v1.4.0 // indirect
|
||||||
sigs.k8s.io/yaml v1.6.0 // indirect
|
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -97,14 +97,14 @@ github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4
|
|||||||
github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
|
github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
|
||||||
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
||||||
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
||||||
github.com/emicklei/go-restful/v3 v3.13.0 h1:C4Bl2xDndpU6nJ4bc1jXd+uTmYPVUwkD6bFY/oTyCes=
|
github.com/emicklei/go-restful/v3 v3.11.0 h1:rAQeMHw1c7zTmncogyy8VvRZwtkmkZ4FxERmMY4rD+g=
|
||||||
github.com/emicklei/go-restful/v3 v3.13.0/go.mod h1:6n3XBCmQQb25CM2LCACGz8ukIrRry+4bhvbpWn3mrbc=
|
github.com/emicklei/go-restful/v3 v3.11.0/go.mod h1:6n3XBCmQQb25CM2LCACGz8ukIrRry+4bhvbpWn3mrbc=
|
||||||
github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg=
|
github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg=
|
||||||
github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
|
github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
|
||||||
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
|
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
|
||||||
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
|
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
|
||||||
github.com/fxamacker/cbor/v2 v2.9.0 h1:NpKPmjDBgUfBms6tr6JZkTHtfFGcMKsw3eGcmD/sapM=
|
github.com/fxamacker/cbor/v2 v2.7.0 h1:iM5WgngdRBanHcxugY4JySA0nk1wZorNOpTgCMedv5E=
|
||||||
github.com/fxamacker/cbor/v2 v2.9.0/go.mod h1:vM4b+DJCtHn+zz7h3FFp/hDAI9WNWCsZj23V5ytsSxQ=
|
github.com/fxamacker/cbor/v2 v2.7.0/go.mod h1:pxXPTn3joSm21Gbwsv0w9OSA2y1HFR9qXEeXQVeNoDQ=
|
||||||
github.com/gabriel-vasile/mimetype v1.4.12 h1:e9hWvmLYvtp846tLHam2o++qitpguFiYCKbn0w9jyqw=
|
github.com/gabriel-vasile/mimetype v1.4.12 h1:e9hWvmLYvtp846tLHam2o++qitpguFiYCKbn0w9jyqw=
|
||||||
github.com/gabriel-vasile/mimetype v1.4.12/go.mod h1:d+9Oxyo1wTzWdyVUPMmXFvp4F9tea18J8ufA774AB3s=
|
github.com/gabriel-vasile/mimetype v1.4.12/go.mod h1:d+9Oxyo1wTzWdyVUPMmXFvp4F9tea18J8ufA774AB3s=
|
||||||
github.com/gin-contrib/sse v1.1.0 h1:n0w2GMuUpWDVp7qSpvze6fAu9iRxJY4Hmj6AmBOU05w=
|
github.com/gin-contrib/sse v1.1.0 h1:n0w2GMuUpWDVp7qSpvze6fAu9iRxJY4Hmj6AmBOU05w=
|
||||||
@@ -140,16 +140,23 @@ github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4=
|
|||||||
github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
|
github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
|
||||||
github.com/goccy/go-yaml v1.19.2 h1:PmFC1S6h8ljIz6gMRBopkjP1TVT7xuwrButHID66PoM=
|
github.com/goccy/go-yaml v1.19.2 h1:PmFC1S6h8ljIz6gMRBopkjP1TVT7xuwrButHID66PoM=
|
||||||
github.com/goccy/go-yaml v1.19.2/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA=
|
github.com/goccy/go-yaml v1.19.2/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA=
|
||||||
|
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
|
||||||
|
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
|
||||||
github.com/golang-migrate/migrate/v4 v4.19.1 h1:OCyb44lFuQfYXYLx1SCxPZQGU7mcaZ7gH9yH4jSFbBA=
|
github.com/golang-migrate/migrate/v4 v4.19.1 h1:OCyb44lFuQfYXYLx1SCxPZQGU7mcaZ7gH9yH4jSFbBA=
|
||||||
github.com/golang-migrate/migrate/v4 v4.19.1/go.mod h1:CTcgfjxhaUtsLipnLoQRWCrjYXycRz/g5+RWDuYgPrE=
|
github.com/golang-migrate/migrate/v4 v4.19.1/go.mod h1:CTcgfjxhaUtsLipnLoQRWCrjYXycRz/g5+RWDuYgPrE=
|
||||||
github.com/google/gnostic-models v0.7.0 h1:qwTtogB15McXDaNqTZdzPJRHvaVJlAl+HVQnLmJEJxo=
|
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
|
||||||
github.com/google/gnostic-models v0.7.0/go.mod h1:whL5G0m6dmc5cPxKc5bdKdEN3UjI7OUGxBlw57miDrQ=
|
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
|
||||||
|
github.com/google/gnostic-models v0.6.8 h1:yo/ABAfM5IMRsS1VnXjTBvUb61tFIHozhlYvRgGre9I=
|
||||||
|
github.com/google/gnostic-models v0.6.8/go.mod h1:5n7qKqH0f5wFt+aWF8CW6pZLLNOfYuF5OpfBSENuI8U=
|
||||||
|
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||||
github.com/google/go-querystring v1.2.0 h1:yhqkPbu2/OH+V9BfpCVPZkNmUXhb2gBxJArfhIxNtP0=
|
github.com/google/go-querystring v1.2.0 h1:yhqkPbu2/OH+V9BfpCVPZkNmUXhb2gBxJArfhIxNtP0=
|
||||||
github.com/google/go-querystring v1.2.0/go.mod h1:8IFJqpSRITyJ8QhQ13bmbeMBDfmeEJZD5A0egEOmkqU=
|
github.com/google/go-querystring v1.2.0/go.mod h1:8IFJqpSRITyJ8QhQ13bmbeMBDfmeEJZD5A0egEOmkqU=
|
||||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||||
|
github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0=
|
||||||
|
github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
|
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
|
||||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
|
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
|
||||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||||
@@ -178,6 +185,8 @@ github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8Hm
|
|||||||
github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
|
github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
|
||||||
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
||||||
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
||||||
|
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
|
||||||
|
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
|
||||||
github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y=
|
github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y=
|
||||||
github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0=
|
github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0=
|
||||||
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||||
@@ -219,9 +228,8 @@ github.com/moby/term v0.5.2/go.mod h1:d3djjFCrjnB+fl8NJux+EJzu0msscUP+f8it8hPkFL
|
|||||||
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
|
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
|
||||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||||
|
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
|
||||||
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
|
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
|
||||||
github.com/modern-go/reflect2 v1.0.3-0.20250322232337-35a7c28c31ee h1:W5t00kpgFdJifH4BDsTlE89Zl93FEloxaWZfGcifgq8=
|
|
||||||
github.com/modern-go/reflect2 v1.0.3-0.20250322232337-35a7c28c31ee/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
|
|
||||||
github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
|
github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
|
||||||
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
|
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
|
||||||
github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA=
|
github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA=
|
||||||
@@ -261,12 +269,11 @@ github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ
|
|||||||
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
|
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
|
||||||
github.com/spf13/cast v1.10.0 h1:h2x0u2shc1QuLHfxi+cTJvs30+ZAHOGRic8uyGTDWxY=
|
github.com/spf13/cast v1.10.0 h1:h2x0u2shc1QuLHfxi+cTJvs30+ZAHOGRic8uyGTDWxY=
|
||||||
github.com/spf13/cast v1.10.0/go.mod h1:jNfB8QC9IA6ZuY2ZjDp0KtFO2LZZlg4S/7bzP6qqeHo=
|
github.com/spf13/cast v1.10.0/go.mod h1:jNfB8QC9IA6ZuY2ZjDp0KtFO2LZZlg4S/7bzP6qqeHo=
|
||||||
github.com/spf13/pflag v1.0.9 h1:9exaQaMOCwffKiiiYk6/BndUBv+iRViNW+4lEMi0PvY=
|
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
|
||||||
github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
||||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||||
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||||
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
|
|
||||||
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
|
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
|
||||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||||
@@ -287,6 +294,8 @@ github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
|
|||||||
github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
|
github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
|
||||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no=
|
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no=
|
||||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM=
|
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM=
|
||||||
|
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||||
|
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||||
go.mongodb.org/mongo-driver/v2 v2.5.0 h1:yXUhImUjjAInNcpTcAlPHiT7bIXhshCTL3jVBkF3xaE=
|
go.mongodb.org/mongo-driver/v2 v2.5.0 h1:yXUhImUjjAInNcpTcAlPHiT7bIXhshCTL3jVBkF3xaE=
|
||||||
go.mongodb.org/mongo-driver/v2 v2.5.0/go.mod h1:yOI9kBsufol30iFsl1slpdq1I0eHPzybRWdyYUs8K/0=
|
go.mongodb.org/mongo-driver/v2 v2.5.0/go.mod h1:yOI9kBsufol30iFsl1slpdq1I0eHPzybRWdyYUs8K/0=
|
||||||
go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64=
|
go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64=
|
||||||
@@ -311,35 +320,56 @@ go.opentelemetry.io/proto/otlp v1.10.0 h1:IQRWgT5srOCYfiWnpqUYz9CVmbO8bFmKcwYxpu
|
|||||||
go.opentelemetry.io/proto/otlp v1.10.0/go.mod h1:/CV4QoCR/S9yaPj8utp3lvQPoqMtxXdzn7ozvvozVqk=
|
go.opentelemetry.io/proto/otlp v1.10.0/go.mod h1:/CV4QoCR/S9yaPj8utp3lvQPoqMtxXdzn7ozvvozVqk=
|
||||||
go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y=
|
go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y=
|
||||||
go.uber.org/mock v0.6.0/go.mod h1:KiVJ4BqZJaMj4svdfmHM0AUx4NJYO8ZNpPnZn1Z+BBU=
|
go.uber.org/mock v0.6.0/go.mod h1:KiVJ4BqZJaMj4svdfmHM0AUx4NJYO8ZNpPnZn1Z+BBU=
|
||||||
go.yaml.in/yaml/v2 v2.4.3 h1:6gvOSjQoTB3vt1l+CU+tSyi/HOjfOjRLJ4YwYZGwRO0=
|
|
||||||
go.yaml.in/yaml/v2 v2.4.3/go.mod h1:zSxWcmIDjOzPXpjlTTbAsKokqkDNAVtZO0WOMiT90s8=
|
|
||||||
go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
|
|
||||||
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
|
|
||||||
golang.org/x/arch v0.22.0 h1:c/Zle32i5ttqRXjdLyyHZESLD/bB90DCU1g9l/0YBDI=
|
golang.org/x/arch v0.22.0 h1:c/Zle32i5ttqRXjdLyyHZESLD/bB90DCU1g9l/0YBDI=
|
||||||
golang.org/x/arch v0.22.0/go.mod h1:dNHoOeKiyja7GTvF9NJS1l3Z2yntpQNzgrjh1cU103A=
|
golang.org/x/arch v0.22.0/go.mod h1:dNHoOeKiyja7GTvF9NJS1l3Z2yntpQNzgrjh1cU103A=
|
||||||
golang.org/x/crypto v0.51.0 h1:IBPXwPfKxY7cWQZ38ZCIRPI50YLeevDLlLnyC5wRGTI=
|
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||||
golang.org/x/crypto v0.51.0/go.mod h1:8AdwkbraGNABw2kOX6YFPs3WM22XqI4EXEd8g+x7Oc8=
|
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||||
|
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||||
|
golang.org/x/crypto v0.50.0 h1:zO47/JPrL6vsNkINmLoo/PH1gcxpls50DNogFvB5ZGI=
|
||||||
|
golang.org/x/crypto v0.50.0/go.mod h1:3muZ7vA7PBCE6xgPX7nkzzjiUq87kRItoJQM1Yo8S+Q=
|
||||||
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY=
|
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY=
|
||||||
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70=
|
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70=
|
||||||
golang.org/x/mod v0.35.0 h1:Ww1D637e6Pg+Zb2KrWfHQUnH2dQRLBQyAtpr/haaJeM=
|
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||||
golang.org/x/mod v0.35.0/go.mod h1:+GwiRhIInF8wPm+4AoT6L0FA1QWAad3OMdTRx4tFYlU=
|
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||||
golang.org/x/net v0.53.0 h1:d+qAbo5L0orcWAr0a9JweQpjXF19LMXJE8Ey7hwOdUA=
|
golang.org/x/mod v0.34.0 h1:xIHgNUUnW6sYkcM5Jleh05DvLOtwc6RitGHbDk4akRI=
|
||||||
golang.org/x/net v0.53.0/go.mod h1:JvMuJH7rrdiCfbeHoo3fCQU24Lf5JJwT9W3sJFulfgs=
|
golang.org/x/mod v0.34.0/go.mod h1:ykgH52iCZe79kzLLMhyCUzhMci+nQj+0XkbXpNYtVjY=
|
||||||
|
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||||
|
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||||
|
golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||||
|
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
|
||||||
|
golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0=
|
||||||
|
golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw=
|
||||||
golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs=
|
golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs=
|
||||||
golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q=
|
golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q=
|
||||||
|
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
|
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
|
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
|
golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
|
||||||
golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
|
golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
|
||||||
|
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||||
|
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
|
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.44.0 h1:ildZl3J4uzeKP07r2F++Op7E9B29JRUy+a27EibtBTQ=
|
golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI=
|
||||||
golang.org/x/sys v0.44.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
|
golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
|
||||||
golang.org/x/term v0.43.0 h1:S4RLU2sB31O/NCl+zFN9Aru9A/Cq2aqKpTZJ6B+DwT4=
|
golang.org/x/term v0.42.0 h1:UiKe+zDFmJobeJ5ggPwOshJIVt6/Ft0rcfrXZDLWAWY=
|
||||||
golang.org/x/term v0.43.0/go.mod h1:lrhlHNdQJHO+1qVYiHfFKVuVioJIheAc3fBSMFYEIsk=
|
golang.org/x/term v0.42.0/go.mod h1:Dq/D+snpsbazcBG5+F9Q1n2rXV8Ma+71xEjTRufARgY=
|
||||||
golang.org/x/text v0.37.0 h1:Cqjiwd9eSg8e0QAkyCaQTNHFIIzWtidPahFWR83rTrc=
|
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||||
golang.org/x/text v0.37.0/go.mod h1:a5sjxXGs9hsn/AJVwuElvCAo9v8QYLzvavO5z2PiM38=
|
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||||
golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI=
|
golang.org/x/text v0.36.0 h1:JfKh3XmcRPqZPKevfXVpI1wXPTqbkE5f7JA92a55Yxg=
|
||||||
golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
|
golang.org/x/text v0.36.0/go.mod h1:NIdBknypM8iqVmPiuco0Dh6P5Jcdk8lJL0CUebqK164=
|
||||||
golang.org/x/tools v0.44.0 h1:UP4ajHPIcuMjT1GqzDWRlalUEoY+uzoZKnhOjbIPD2c=
|
golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE=
|
||||||
golang.org/x/tools v0.44.0/go.mod h1:KA0AfVErSdxRZIsOVipbv3rQhVXTnlU6UhKxHd1seDI=
|
golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
|
||||||
|
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||||
|
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||||
|
golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
|
||||||
|
golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
|
||||||
|
golang.org/x/tools v0.43.0 h1:12BdW9CeB3Z+J/I/wj34VMl8X+fEXBxVR90JeMX5E7s=
|
||||||
|
golang.org/x/tools v0.43.0/go.mod h1:uHkMso649BX2cZK6+RpuIPXS3ho2hZo4FVwfoy1vIk0=
|
||||||
|
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||||
|
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||||
|
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||||
|
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||||
google.golang.org/genproto v0.0.0-20250603155806-513f23925822 h1:rHWScKit0gvAPuOnu87KpaYtjK5zBMLcULh7gxkCXu4=
|
google.golang.org/genproto v0.0.0-20250603155806-513f23925822 h1:rHWScKit0gvAPuOnu87KpaYtjK5zBMLcULh7gxkCXu4=
|
||||||
google.golang.org/genproto/googleapis/api v0.0.0-20260401024825-9d38bb4040a9 h1:VPWxll4HlMw1Vs/qXtN7BvhZqsS9cdAittCNvVENElA=
|
google.golang.org/genproto/googleapis/api v0.0.0-20260401024825-9d38bb4040a9 h1:VPWxll4HlMw1Vs/qXtN7BvhZqsS9cdAittCNvVENElA=
|
||||||
google.golang.org/genproto/googleapis/api v0.0.0-20260401024825-9d38bb4040a9/go.mod h1:7QBABkRtR8z+TEnmXTqIqwJLlzrZKVfAUm7tY3yGv0M=
|
google.golang.org/genproto/googleapis/api v0.0.0-20260401024825-9d38bb4040a9/go.mod h1:7QBABkRtR8z+TEnmXTqIqwJLlzrZKVfAUm7tY3yGv0M=
|
||||||
@@ -347,13 +377,13 @@ google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9 h1:
|
|||||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8=
|
google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8=
|
||||||
google.golang.org/grpc v1.80.0 h1:Xr6m2WmWZLETvUNvIUmeD5OAagMw3FiKmMlTdViWsHM=
|
google.golang.org/grpc v1.80.0 h1:Xr6m2WmWZLETvUNvIUmeD5OAagMw3FiKmMlTdViWsHM=
|
||||||
google.golang.org/grpc v1.80.0/go.mod h1:ho/dLnxwi3EDJA4Zghp7k2Ec1+c2jqup0bFkw07bwF4=
|
google.golang.org/grpc v1.80.0/go.mod h1:ho/dLnxwi3EDJA4Zghp7k2Ec1+c2jqup0bFkw07bwF4=
|
||||||
google.golang.org/protobuf v1.36.12-0.20260120151049-f2248ac996af h1:+5/Sw3GsDNlEmu7TfklWKPdQ0Ykja5VEmq2i817+jbI=
|
google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE=
|
||||||
google.golang.org/protobuf v1.36.12-0.20260120151049-f2248ac996af/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
|
google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
|
||||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||||
gopkg.in/evanphx/json-patch.v4 v4.13.0 h1:czT3CmqEaQ1aanPc5SdlgQrrEIb8w/wwCvWWnfEbYzo=
|
gopkg.in/evanphx/json-patch.v4 v4.12.0 h1:n6jtcsulIzXPJaxegRbvFNNrZDjbij7ny3gmSPG+6V4=
|
||||||
gopkg.in/evanphx/json-patch.v4 v4.13.0/go.mod h1:p8EYWUEYMpynmqDbY58zCKCFZw8pRWMG4EsWvDvM72M=
|
gopkg.in/evanphx/json-patch.v4 v4.12.0/go.mod h1:p8EYWUEYMpynmqDbY58zCKCFZw8pRWMG4EsWvDvM72M=
|
||||||
gopkg.in/inf.v0 v0.9.1 h1:73M5CoZyi3ZLMOyDlQh031Cx6N9NDJ2Vvfl76EDAgDc=
|
gopkg.in/inf.v0 v0.9.1 h1:73M5CoZyi3ZLMOyDlQh031Cx6N9NDJ2Vvfl76EDAgDc=
|
||||||
gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw=
|
gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw=
|
||||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
@@ -361,22 +391,22 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
|||||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
gotest.tools/v3 v3.5.2 h1:7koQfIKdy+I8UTetycgUqXWSDwpgv193Ka+qRsmBY8Q=
|
gotest.tools/v3 v3.5.2 h1:7koQfIKdy+I8UTetycgUqXWSDwpgv193Ka+qRsmBY8Q=
|
||||||
gotest.tools/v3 v3.5.2/go.mod h1:LtdLGcnqToBH83WByAAi/wiwSFCArdFIUV/xxN4pcjA=
|
gotest.tools/v3 v3.5.2/go.mod h1:LtdLGcnqToBH83WByAAi/wiwSFCArdFIUV/xxN4pcjA=
|
||||||
k8s.io/api v0.36.1 h1:XbL/EMj8K2aJpJtePmqUyQMsM0D4QI2pvl7YKJ20FTY=
|
k8s.io/api v0.32.2 h1:bZrMLEkgizC24G9eViHGOPbW+aRo9duEISRIJKfdJuw=
|
||||||
k8s.io/api v0.36.1/go.mod h1:KOWo4ey3TINlXjeHVuwB3i+tXXnu+UcwFBHlI/9dvEo=
|
k8s.io/api v0.32.2/go.mod h1:hKlhk4x1sJyYnHENsrdCWw31FEmCijNGPJO5WzHiJ6Y=
|
||||||
k8s.io/apimachinery v0.36.1 h1:G63Gjx2W+q0YD+72Vo8oY0nDnePVwnuzTmmy5ENrVSA=
|
k8s.io/apimachinery v0.32.2 h1:yoQBR9ZGkA6Rgmhbp/yuT9/g+4lxtsGYwW6dR6BDPLQ=
|
||||||
k8s.io/apimachinery v0.36.1/go.mod h1:ibYOR00vW/I1kzvi5SF0dRuJ52BvKtfvRdOn35GPQ+8=
|
k8s.io/apimachinery v0.32.2/go.mod h1:GpHVgxoKlTxClKcteaeuF1Ul/lDVb74KpZcxcmLDElE=
|
||||||
k8s.io/client-go v0.36.1 h1:FN/K8QIT2CEDt+2WB2HnWrUANZ50AP5GII43/SP2JR0=
|
k8s.io/client-go v0.32.2 h1:4dYCD4Nz+9RApM2b/3BtVvBHw54QjMFUl1OLcJG5yOA=
|
||||||
k8s.io/client-go v0.36.1/go.mod h1:s6rAnCtTGYDQnpNjEhSaISV+2O8jwruZ6m3QOYBFbtU=
|
k8s.io/client-go v0.32.2/go.mod h1:fpZ4oJXclZ3r2nDOv+Ux3XcJutfrwjKTCHz2H3sww94=
|
||||||
k8s.io/klog/v2 v2.140.0 h1:Tf+J3AH7xnUzZyVVXhTgGhEKnFqye14aadWv7bzXdzc=
|
k8s.io/klog/v2 v2.130.1 h1:n9Xl7H1Xvksem4KFG4PYbdQCQxqc/tTUyrgXaOhHSzk=
|
||||||
k8s.io/klog/v2 v2.140.0/go.mod h1:o+/RWfJ6PwpnFn7OyAG3QnO47BFsymfEfrz6XyYSSp0=
|
k8s.io/klog/v2 v2.130.1/go.mod h1:3Jpz1GvMt720eyJH1ckRHK1EDfpxISzJ7I9OYgaDtPE=
|
||||||
k8s.io/kube-openapi v0.0.0-20260317180543-43fb72c5454a h1:xCeOEAOoGYl2jnJoHkC3hkbPJgdATINPMAxaynU2Ovg=
|
k8s.io/kube-openapi v0.0.0-20241105132330-32ad38e42d3f h1:GA7//TjRY9yWGy1poLzYYJJ4JRdzg3+O6e8I+e+8T5Y=
|
||||||
k8s.io/kube-openapi v0.0.0-20260317180543-43fb72c5454a/go.mod h1:uGBT7iTA6c6MvqUvSXIaYZo9ukscABYi2btjhvgKGZ0=
|
k8s.io/kube-openapi v0.0.0-20241105132330-32ad38e42d3f/go.mod h1:R/HEjbvWI0qdfb8viZUeVZm0X6IZnxAydC7YU42CMw4=
|
||||||
k8s.io/utils v0.0.0-20260210185600-b8788abfbbc2 h1:AZYQSJemyQB5eRxqcPky+/7EdBj0xi3g0ZcxxJ7vbWU=
|
k8s.io/utils v0.0.0-20241104100929-3ea5e8cea738 h1:M3sRQVHv7vB20Xc2ybTt7ODCeFj6JSWYFzOFnYeS6Ro=
|
||||||
k8s.io/utils v0.0.0-20260210185600-b8788abfbbc2/go.mod h1:xDxuJ0whA3d0I4mf/C4ppKHxXynQ+fxnkmQH0vTHnuk=
|
k8s.io/utils v0.0.0-20241104100929-3ea5e8cea738/go.mod h1:OLgZIPagt7ERELqWJFomSt595RzquPNLL48iOWgYOg0=
|
||||||
modernc.org/cc/v4 v4.28.2 h1:3tQ0lf2ADtoby2EtSP+J7IE2SHwEJdP8ioR59wx7XpY=
|
modernc.org/cc/v4 v4.27.3 h1:uNCgn37E5U09mTv1XgskEVUJ8ADKpmFMPxzGJ0TSo+U=
|
||||||
modernc.org/cc/v4 v4.28.2/go.mod h1:OnovgIhbbMXMu1aISnJ0wvVD1KnW+cAUJkIrAWh+kVI=
|
modernc.org/cc/v4 v4.27.3/go.mod h1:3YjcbCqhoTTHPycJDRl2WZKKFj0nwcOIPBfEZK0Hdk8=
|
||||||
modernc.org/ccgo/v4 v4.34.0 h1:yRLPFZieg532OT4rp4JFNIVcquwalMX26G95WQDqwCQ=
|
modernc.org/ccgo/v4 v4.32.4 h1:L5OB8rpEX4ZsXEQwGozRfJyJSFHbbNVOoQ59DU9/KuU=
|
||||||
modernc.org/ccgo/v4 v4.34.0/go.mod h1:AS5WYMyBakQ+fhsHhtP8mWB82KTGPkNNJDGfGQCe0/A=
|
modernc.org/ccgo/v4 v4.32.4/go.mod h1:lY7f+fiTDHfcv6YlRgSkxYfhs+UvOEEzj49jAn2TOx0=
|
||||||
modernc.org/fileutil v1.4.0 h1:j6ZzNTftVS054gi281TyLjHPp6CPHr2KCxEXjEbD6SM=
|
modernc.org/fileutil v1.4.0 h1:j6ZzNTftVS054gi281TyLjHPp6CPHr2KCxEXjEbD6SM=
|
||||||
modernc.org/fileutil v1.4.0/go.mod h1:EqdKFDxiByqxLk8ozOxObDSfcVOv/54xDs/DUHdvCUU=
|
modernc.org/fileutil v1.4.0/go.mod h1:EqdKFDxiByqxLk8ozOxObDSfcVOv/54xDs/DUHdvCUU=
|
||||||
modernc.org/gc/v2 v2.6.5 h1:nyqdV8q46KvTpZlsw66kWqwXRHdjIlJOhG6kxiV/9xI=
|
modernc.org/gc/v2 v2.6.5 h1:nyqdV8q46KvTpZlsw66kWqwXRHdjIlJOhG6kxiV/9xI=
|
||||||
@@ -385,29 +415,27 @@ modernc.org/gc/v3 v3.1.2 h1:ZtDCnhonXSZexk/AYsegNRV1lJGgaNZJuKjJSWKyEqo=
|
|||||||
modernc.org/gc/v3 v3.1.2/go.mod h1:HFK/6AGESC7Ex+EZJhJ2Gni6cTaYpSMmU/cT9RmlfYY=
|
modernc.org/gc/v3 v3.1.2/go.mod h1:HFK/6AGESC7Ex+EZJhJ2Gni6cTaYpSMmU/cT9RmlfYY=
|
||||||
modernc.org/goabi0 v0.2.0 h1:HvEowk7LxcPd0eq6mVOAEMai46V+i7Jrj13t4AzuNks=
|
modernc.org/goabi0 v0.2.0 h1:HvEowk7LxcPd0eq6mVOAEMai46V+i7Jrj13t4AzuNks=
|
||||||
modernc.org/goabi0 v0.2.0/go.mod h1:CEFRnnJhKvWT1c1JTI3Avm+tgOWbkOu5oPA8eH8LnMI=
|
modernc.org/goabi0 v0.2.0/go.mod h1:CEFRnnJhKvWT1c1JTI3Avm+tgOWbkOu5oPA8eH8LnMI=
|
||||||
modernc.org/libc v1.72.3 h1:ZnDF4tXn4NBXFutMMQC4vtbTFSXhhKzR73fv0beZEAU=
|
modernc.org/libc v1.72.0 h1:IEu559v9a0XWjw0DPoVKtXpO2qt5NVLAnFaBbjq+n8c=
|
||||||
modernc.org/libc v1.72.3/go.mod h1:dn0dZNnnn1clLyvRxLxYExxiKRZIRENOfqQ8XEeg4Qs=
|
modernc.org/libc v1.72.0/go.mod h1:tTU8DL8A+XLVkEY3x5E/tO7s2Q/q42EtnNWda/L5QhQ=
|
||||||
modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU=
|
modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU=
|
||||||
modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg=
|
modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg=
|
||||||
modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI=
|
modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI=
|
||||||
modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw=
|
modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw=
|
||||||
modernc.org/opt v0.2.0 h1:tGyef5ApycA7FSEOMraay9SaTk5zmbx7Tu+cJs4QKZg=
|
modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8=
|
||||||
modernc.org/opt v0.2.0/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns=
|
modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns=
|
||||||
modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w=
|
modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w=
|
||||||
modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE=
|
modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE=
|
||||||
modernc.org/sqlite v1.50.1 h1:l+cQvn0sd0zJJtfygGHuQJ5AjlrwXmWPw4KP3ZMwr9w=
|
modernc.org/sqlite v1.49.1 h1:dYGHTKcX1sJ+EQDnUzvz4TJ5GbuvhNJa8Fg6ElGx73U=
|
||||||
modernc.org/sqlite v1.50.1/go.mod h1:tcNzv5p84E0skkmJn038y+hWJbLQXQqEnQfeh5r2JLM=
|
modernc.org/sqlite v1.49.1/go.mod h1:m0w8xhwYUVY3H6pSDwc3gkJ/irZT/0YEXwBlhaxQEew=
|
||||||
modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0=
|
modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0=
|
||||||
modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A=
|
modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A=
|
||||||
modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=
|
modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=
|
||||||
modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM=
|
modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM=
|
||||||
rsc.io/qr v0.2.0 h1:6vBLea5/NRMVTz8V66gipeLycZMl/+UlFmk8DvqQ6WY=
|
rsc.io/qr v0.2.0 h1:6vBLea5/NRMVTz8V66gipeLycZMl/+UlFmk8DvqQ6WY=
|
||||||
rsc.io/qr v0.2.0/go.mod h1:IF+uZjkb9fqyeF/4tlBoynqmQxUoPfWEKh921coOuXs=
|
rsc.io/qr v0.2.0/go.mod h1:IF+uZjkb9fqyeF/4tlBoynqmQxUoPfWEKh921coOuXs=
|
||||||
sigs.k8s.io/json v0.0.0-20250730193827-2d320260d730 h1:IpInykpT6ceI+QxKBbEflcR5EXP7sU1kvOlxwZh5txg=
|
sigs.k8s.io/json v0.0.0-20241010143419-9aa6b5e7a4b3 h1:/Rv+M11QRah1itp8VhT6HoVx1Ray9eB4DBr+K+/sCJ8=
|
||||||
sigs.k8s.io/json v0.0.0-20250730193827-2d320260d730/go.mod h1:mdzfpAEoE6DHQEN0uh9ZbOCuHbLK5wOm7dK4ctXE9Tg=
|
sigs.k8s.io/json v0.0.0-20241010143419-9aa6b5e7a4b3/go.mod h1:18nIHnGi6636UCz6m8i4DhaJ65T6EruyzmoQqI2BVDo=
|
||||||
sigs.k8s.io/randfill v1.0.0 h1:JfjMILfT8A6RbawdsK2JXGBR5AQVfd+9TbzrlneTyrU=
|
sigs.k8s.io/structured-merge-diff/v4 v4.4.2 h1:MdmvkGuXi/8io6ixD5wud3vOLwc1rj0aNqRlpuvjmwA=
|
||||||
sigs.k8s.io/randfill v1.0.0/go.mod h1:XeLlZ/jmk4i1HRopwe7/aU3H5n1zNUcX6TM94b3QxOY=
|
sigs.k8s.io/structured-merge-diff/v4 v4.4.2/go.mod h1:N8f93tFZh9U6vpxwRArLiikrE5/2tiu1w1AGfACIGE4=
|
||||||
sigs.k8s.io/structured-merge-diff/v6 v6.3.2 h1:kwVWMx5yS1CrnFWA/2QHyRVJ8jM6dBA80uLmm0wJkk8=
|
sigs.k8s.io/yaml v1.4.0 h1:Mk1wCc2gy/F0THH0TAp1QYyJNzRm2KCLy3o5ASXVI5E=
|
||||||
sigs.k8s.io/structured-merge-diff/v6 v6.3.2/go.mod h1:M3W8sfWvn2HhQDIbGWj3S099YozAsymCo/wrT5ohRUE=
|
sigs.k8s.io/yaml v1.4.0/go.mod h1:Ejl7/uTz7PSA4eKMyQCUTnhZYNmLIl+5c2lQPGR2BPY=
|
||||||
sigs.k8s.io/yaml v1.6.0 h1:G8fkbMSAFqgEFgh4b1wmtzDnioxFCUgTZhlbj5P9QYs=
|
|
||||||
sigs.k8s.io/yaml v1.6.0/go.mod h1:796bPqUfzR/0jLAl6XjHl3Ck7MiyVv8dbTdyT3/pMf4=
|
|
||||||
|
|||||||
+122
-273
@@ -3,50 +3,38 @@ package bootstrap
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
|
||||||
"syscall"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/tinyauthapp/tinyauth/internal/controller"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Services struct {
|
|
||||||
accessControlService *service.AccessControlsService
|
|
||||||
authService *service.AuthService
|
|
||||||
dockerService *service.DockerService
|
|
||||||
kubernetesService *service.KubernetesService
|
|
||||||
ldapService *service.LdapService
|
|
||||||
oauthBrokerService *service.OAuthBrokerService
|
|
||||||
oidcService *service.OIDCService
|
|
||||||
}
|
|
||||||
|
|
||||||
type BootstrapApp struct {
|
type BootstrapApp struct {
|
||||||
config model.Config
|
config model.Config
|
||||||
runtime model.RuntimeConfig
|
context struct {
|
||||||
|
appUrl string
|
||||||
|
uuid string
|
||||||
|
cookieDomain string
|
||||||
|
sessionCookieName string
|
||||||
|
csrfCookieName string
|
||||||
|
redirectCookieName string
|
||||||
|
oauthSessionCookieName string
|
||||||
|
localUsers *[]model.LocalUser
|
||||||
|
oauthProviders map[string]model.OAuthServiceConfig
|
||||||
|
configuredProviders []controller.Provider
|
||||||
|
oidcClients []model.OIDCClientConfig
|
||||||
|
}
|
||||||
services Services
|
services Services
|
||||||
log *logger.Logger
|
|
||||||
ctx context.Context
|
|
||||||
cancel context.CancelFunc
|
|
||||||
queries *repository.Queries
|
|
||||||
router *gin.Engine
|
|
||||||
db *sql.DB
|
|
||||||
wg sync.WaitGroup
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewBootstrapApp(config model.Config) *BootstrapApp {
|
func NewBootstrapApp(config model.Config) *BootstrapApp {
|
||||||
@@ -56,69 +44,49 @@ func NewBootstrapApp(config model.Config) *BootstrapApp {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (app *BootstrapApp) Setup() error {
|
func (app *BootstrapApp) Setup() error {
|
||||||
// create context
|
|
||||||
ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
|
|
||||||
app.ctx = ctx
|
|
||||||
app.cancel = cancel
|
|
||||||
|
|
||||||
// setup logger
|
|
||||||
log := logger.NewLogger().WithConfig(app.config.Log)
|
|
||||||
log.Init()
|
|
||||||
app.log = log
|
|
||||||
|
|
||||||
// get app url
|
// get app url
|
||||||
if app.config.AppURL == "" {
|
if app.config.AppURL == "" {
|
||||||
return errors.New("app url cannot be empty, perhaps config loading failed")
|
return fmt.Errorf("app URL cannot be empty, perhaps config loading failed")
|
||||||
}
|
}
|
||||||
|
|
||||||
appUrl, err := url.Parse(app.config.AppURL)
|
appUrl, err := url.Parse(app.config.AppURL)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to parse app url: %w", err)
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
app.runtime.AppURL = appUrl.Scheme + "://" + appUrl.Host
|
app.context.appUrl = appUrl.Scheme + "://" + appUrl.Host
|
||||||
|
|
||||||
// validate session config
|
// validate session config
|
||||||
if app.config.Auth.SessionMaxLifetime != 0 && app.config.Auth.SessionMaxLifetime < app.config.Auth.SessionExpiry {
|
if app.config.Auth.SessionMaxLifetime != 0 && app.config.Auth.SessionMaxLifetime < app.config.Auth.SessionExpiry {
|
||||||
return errors.New("session max lifetime cannot be less than session expiry")
|
return fmt.Errorf("session max lifetime cannot be less than session expiry")
|
||||||
}
|
}
|
||||||
|
|
||||||
// parse users
|
// Parse users
|
||||||
users, err := utils.GetUsers(app.config.Auth.Users, app.config.Auth.UsersFile, app.config.Auth.UserAttributes)
|
users, err := utils.GetUsers(app.config.Auth.Users, app.config.Auth.UsersFile, app.config.Auth.UserAttributes)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to load users: %w", err)
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
app.runtime.LocalUsers = *users
|
app.context.localUsers = users
|
||||||
|
|
||||||
// load oauth whitelist
|
// Setup OAuth providers
|
||||||
oauthWhitelist, err := utils.GetStringList(app.config.OAuth.Whitelist, app.config.OAuth.WhitelistFile)
|
app.context.oauthProviders = app.config.OAuth.Providers
|
||||||
|
|
||||||
if err != nil {
|
for name, provider := range app.context.oauthProviders {
|
||||||
return fmt.Errorf("failed to load oauth whitelist: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
app.runtime.OAuthWhitelist = oauthWhitelist
|
|
||||||
|
|
||||||
// setup oauth providers
|
|
||||||
app.runtime.OAuthProviders = app.config.OAuth.Providers
|
|
||||||
|
|
||||||
for id, provider := range app.runtime.OAuthProviders {
|
|
||||||
secret := utils.GetSecret(provider.ClientSecret, provider.ClientSecretFile)
|
secret := utils.GetSecret(provider.ClientSecret, provider.ClientSecretFile)
|
||||||
provider.ClientSecret = secret
|
provider.ClientSecret = secret
|
||||||
provider.ClientSecretFile = ""
|
provider.ClientSecretFile = ""
|
||||||
|
|
||||||
if provider.RedirectURL == "" {
|
if provider.RedirectURL == "" {
|
||||||
provider.RedirectURL = app.runtime.AppURL + "/api/oauth/callback/" + id
|
provider.RedirectURL = app.context.appUrl + "/api/oauth/callback/" + name
|
||||||
}
|
}
|
||||||
|
|
||||||
app.runtime.OAuthProviders[id] = provider
|
app.context.oauthProviders[name] = provider
|
||||||
}
|
}
|
||||||
|
|
||||||
// set presets for built-in providers
|
for id, provider := range app.context.oauthProviders {
|
||||||
for id, provider := range app.runtime.OAuthProviders {
|
|
||||||
if provider.Name == "" {
|
if provider.Name == "" {
|
||||||
if name, ok := model.OverrideProviders[id]; ok {
|
if name, ok := model.OverrideProviders[id]; ok {
|
||||||
provider.Name = name
|
provider.Name = name
|
||||||
@@ -126,72 +94,71 @@ func (app *BootstrapApp) Setup() error {
|
|||||||
provider.Name = utils.Capitalize(id)
|
provider.Name = utils.Capitalize(id)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
app.runtime.OAuthProviders[id] = provider
|
app.context.oauthProviders[id] = provider
|
||||||
}
|
}
|
||||||
|
|
||||||
// setup oidc clients
|
// Setup OIDC clients
|
||||||
for id, client := range app.config.OIDC.Clients {
|
for id, client := range app.config.OIDC.Clients {
|
||||||
client.ID = id
|
client.ID = id
|
||||||
app.runtime.OIDCClients = append(app.runtime.OIDCClients, client)
|
app.context.oidcClients = append(app.context.oidcClients, client)
|
||||||
}
|
}
|
||||||
|
|
||||||
// cookie domain
|
// Get cookie domain
|
||||||
cookieDomainResolver := utils.GetCookieDomain
|
cookieDomainResolver := utils.GetCookieDomain
|
||||||
|
|
||||||
if !app.config.Auth.SubdomainsEnabled {
|
if !app.config.Auth.SubdomainsEnabled {
|
||||||
app.log.App.Warn().Msg("Subdomains are disabled, using standalone cookie domain resolver which will not work with subdomains")
|
tlog.App.Info().Msg("Subdomains disabled, automatic authentication for proxied apps will not work")
|
||||||
cookieDomainResolver = utils.GetStandaloneCookieDomain
|
cookieDomainResolver = utils.GetStandaloneCookieDomain
|
||||||
}
|
}
|
||||||
|
|
||||||
cookieDomain, err := cookieDomainResolver(app.runtime.AppURL)
|
cookieDomain, err := cookieDomainResolver(app.context.appUrl)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to get cookie domain: %w", err)
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
app.runtime.CookieDomain = cookieDomain
|
app.context.cookieDomain = cookieDomain
|
||||||
|
|
||||||
// cookie names
|
// Cookie names
|
||||||
app.runtime.UUID = utils.GenerateUUID(appUrl.Hostname())
|
app.context.uuid = utils.GenerateUUID(appUrl.Hostname())
|
||||||
|
cookieId := strings.Split(app.context.uuid, "-")[0]
|
||||||
|
app.context.sessionCookieName = fmt.Sprintf("%s-%s", model.SessionCookieName, cookieId)
|
||||||
|
app.context.csrfCookieName = fmt.Sprintf("%s-%s", model.CSRFCookieName, cookieId)
|
||||||
|
app.context.redirectCookieName = fmt.Sprintf("%s-%s", model.RedirectCookieName, cookieId)
|
||||||
|
app.context.oauthSessionCookieName = fmt.Sprintf("%s-%s", model.OAuthSessionCookieName, cookieId)
|
||||||
|
|
||||||
cookieId := strings.Split(app.runtime.UUID, "-")[0] // first 8 characters of the uuid should be good enough
|
// Dumps
|
||||||
|
tlog.App.Trace().Interface("config", app.config).Msg("Config dump")
|
||||||
|
tlog.App.Trace().Interface("users", app.context.localUsers).Msg("Users dump")
|
||||||
|
tlog.App.Trace().Interface("oauthProviders", app.context.oauthProviders).Msg("OAuth providers dump")
|
||||||
|
tlog.App.Trace().Str("cookieDomain", app.context.cookieDomain).Msg("Cookie domain")
|
||||||
|
tlog.App.Trace().Str("sessionCookieName", app.context.sessionCookieName).Msg("Session cookie name")
|
||||||
|
tlog.App.Trace().Str("csrfCookieName", app.context.csrfCookieName).Msg("CSRF cookie name")
|
||||||
|
tlog.App.Trace().Str("redirectCookieName", app.context.redirectCookieName).Msg("Redirect cookie name")
|
||||||
|
|
||||||
app.runtime.SessionCookieName = fmt.Sprintf("%s-%s", model.SessionCookieName, cookieId)
|
// Database
|
||||||
app.runtime.CSRFCookieName = fmt.Sprintf("%s-%s", model.CSRFCookieName, cookieId)
|
db, err := app.SetupDatabase(app.config.Database.Path)
|
||||||
app.runtime.RedirectCookieName = fmt.Sprintf("%s-%s", model.RedirectCookieName, cookieId)
|
|
||||||
app.runtime.OAuthSessionCookieName = fmt.Sprintf("%s-%s", model.OAuthSessionCookieName, cookieId)
|
|
||||||
|
|
||||||
// database
|
|
||||||
err = app.SetupDatabase()
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to setup database: %w", err)
|
return fmt.Errorf("failed to setup database: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// after this point, we start initializing dependencies so it's a good time to setup a defer
|
// Queries
|
||||||
// to ensure that resources are cleaned up properly in case of an error during initialization
|
queries := repository.New(db)
|
||||||
defer func() {
|
|
||||||
app.cancel()
|
|
||||||
app.wg.Wait()
|
|
||||||
app.db.Close()
|
|
||||||
}()
|
|
||||||
|
|
||||||
// queries
|
// Services
|
||||||
queries := repository.New(app.db)
|
services, err := app.initServices(queries)
|
||||||
app.queries = queries
|
|
||||||
|
|
||||||
// services
|
|
||||||
err = app.setupServices()
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to initialize services: %w", err)
|
return fmt.Errorf("failed to initialize services: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// configured providers
|
app.services = services
|
||||||
configuredProviders := make([]model.Provider, 0)
|
|
||||||
|
|
||||||
for id, provider := range app.runtime.OAuthProviders {
|
// Configured providers
|
||||||
configuredProviders = append(configuredProviders, model.Provider{
|
configuredProviders := make([]controller.Provider, 0)
|
||||||
|
|
||||||
|
for id, provider := range app.context.oauthProviders {
|
||||||
|
configuredProviders = append(configuredProviders, controller.Provider{
|
||||||
Name: provider.Name,
|
Name: provider.Name,
|
||||||
ID: id,
|
ID: id,
|
||||||
OAuth: true,
|
OAuth: true,
|
||||||
@@ -202,171 +169,70 @@ func (app *BootstrapApp) Setup() error {
|
|||||||
return configuredProviders[i].Name < configuredProviders[j].Name
|
return configuredProviders[i].Name < configuredProviders[j].Name
|
||||||
})
|
})
|
||||||
|
|
||||||
if app.services.authService.LocalAuthConfigured() {
|
if services.authService.LocalAuthConfigured() {
|
||||||
configuredProviders = append(configuredProviders, model.Provider{
|
configuredProviders = append(configuredProviders, controller.Provider{
|
||||||
Name: "Local",
|
Name: "Local",
|
||||||
ID: "local",
|
ID: "local",
|
||||||
OAuth: false,
|
OAuth: false,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
if app.services.authService.LDAPAuthConfigured() {
|
if services.authService.LDAPAuthConfigured() {
|
||||||
configuredProviders = append(configuredProviders, model.Provider{
|
configuredProviders = append(configuredProviders, controller.Provider{
|
||||||
Name: "LDAP",
|
Name: "LDAP",
|
||||||
ID: "ldap",
|
ID: "ldap",
|
||||||
OAuth: false,
|
OAuth: false,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
tlog.App.Debug().Interface("providers", configuredProviders).Msg("Authentication providers")
|
||||||
|
|
||||||
if len(configuredProviders) == 0 {
|
if len(configuredProviders) == 0 {
|
||||||
return errors.New("no authentication providers configured")
|
return fmt.Errorf("no authentication providers configured")
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, provider := range configuredProviders {
|
app.context.configuredProviders = configuredProviders
|
||||||
app.log.App.Debug().Str("provider", provider.Name).Msg("Configured authentication provider")
|
|
||||||
}
|
|
||||||
|
|
||||||
app.runtime.ConfiguredProviders = configuredProviders
|
// Setup router
|
||||||
|
router, err := app.setupRouter()
|
||||||
// setup router
|
|
||||||
err = app.setupRouter()
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to setup routes: %w", err)
|
return fmt.Errorf("failed to setup routes: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// start db cleanup routine
|
// Start db cleanup routine
|
||||||
app.log.App.Debug().Msg("Starting database cleanup routine")
|
tlog.App.Debug().Msg("Starting database cleanup routine")
|
||||||
app.wg.Go(app.dbCleanupRoutine)
|
go app.dbCleanupRoutine(queries)
|
||||||
|
|
||||||
// if analytics are not disabled, start heartbeat
|
// If analytics are not disabled, start heartbeat
|
||||||
if app.config.Analytics.Enabled {
|
if app.config.Analytics.Enabled {
|
||||||
app.log.App.Debug().Msg("Starting heartbeat routine")
|
tlog.App.Debug().Msg("Starting heartbeat routine")
|
||||||
app.wg.Go(app.heartbeatRoutine)
|
go app.heartbeatRoutine()
|
||||||
}
|
}
|
||||||
|
|
||||||
// create err channel to listen for server errors
|
// If we have an socket path, bind to it
|
||||||
errChanLen := 0
|
if app.config.Server.SocketPath != "" {
|
||||||
|
if _, err := os.Stat(app.config.Server.SocketPath); err == nil {
|
||||||
runUnix := app.config.Server.SocketPath != ""
|
tlog.App.Info().Msgf("Removing existing socket file %s", app.config.Server.SocketPath)
|
||||||
runHTTP := app.config.Server.SocketPath == "" || app.config.Server.ConcurrentListenersEnabled
|
err := os.Remove(app.config.Server.SocketPath)
|
||||||
|
|
||||||
if runUnix {
|
|
||||||
errChanLen++
|
|
||||||
}
|
|
||||||
|
|
||||||
if runHTTP {
|
|
||||||
errChanLen++
|
|
||||||
}
|
|
||||||
|
|
||||||
errChan := make(chan error, errChanLen)
|
|
||||||
|
|
||||||
if app.config.Server.ConcurrentListenersEnabled {
|
|
||||||
app.log.App.Info().Msg("Concurrent listeners enabled, will run on all available listeners")
|
|
||||||
}
|
|
||||||
|
|
||||||
// serve unix
|
|
||||||
if runUnix {
|
|
||||||
app.wg.Go(func() {
|
|
||||||
if err := app.serveUnix(); err != nil {
|
|
||||||
errChan <- err
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// serve to http
|
|
||||||
if runHTTP {
|
|
||||||
app.wg.Go(func() {
|
|
||||||
if err := app.serveHTTP(); err != nil {
|
|
||||||
errChan <- err
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// monitor cancellation and server errors
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-app.ctx.Done():
|
|
||||||
app.log.App.Info().Msg("Oh, it's time for me to go, bye!")
|
|
||||||
return nil
|
|
||||||
case err := <-errChan:
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("server error: %w", err)
|
return fmt.Errorf("failed to remove existing socket file: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (app *BootstrapApp) serveHTTP() error {
|
tlog.App.Info().Msgf("Starting server on unix socket %s", app.config.Server.SocketPath)
|
||||||
address := fmt.Sprintf("%s:%d", app.config.Server.Address, app.config.Server.Port)
|
if err := router.RunUnix(app.config.Server.SocketPath); err != nil {
|
||||||
|
tlog.App.Fatal().Err(err).Msg("Failed to start server")
|
||||||
|
}
|
||||||
|
|
||||||
app.log.App.Info().Msgf("Starting server on %s", address)
|
|
||||||
|
|
||||||
server := &http.Server{
|
|
||||||
Addr: address,
|
|
||||||
Handler: app.router.Handler(),
|
|
||||||
}
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
<-app.ctx.Done()
|
|
||||||
app.log.App.Debug().Msg("Shutting down http listener")
|
|
||||||
server.Shutdown(app.ctx)
|
|
||||||
}()
|
|
||||||
|
|
||||||
err := server.ListenAndServe()
|
|
||||||
|
|
||||||
if err != nil && !errors.Is(err, http.ErrServerClosed) {
|
|
||||||
return fmt.Errorf("failed to start http listener: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (app *BootstrapApp) serveUnix() error {
|
|
||||||
if app.config.Server.SocketPath == "" {
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := os.Stat(app.config.Server.SocketPath)
|
// Start server
|
||||||
|
address := fmt.Sprintf("%s:%d", app.config.Server.Address, app.config.Server.Port)
|
||||||
if err == nil {
|
tlog.App.Info().Msgf("Starting server on %s", address)
|
||||||
app.log.App.Info().Msgf("Removing existing socket file %s", app.config.Server.SocketPath)
|
if err := router.Run(address); err != nil {
|
||||||
err := os.Remove(app.config.Server.SocketPath)
|
tlog.App.Fatal().Err(err).Msg("Failed to start server")
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to remove existing socket file: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
app.log.App.Info().Msgf("Starting server on unix socket %s", app.config.Server.SocketPath)
|
|
||||||
|
|
||||||
listener, err := net.Listen("unix", app.config.Server.SocketPath)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to create unix socket listener: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
server := &http.Server{
|
|
||||||
Handler: app.router.Handler(),
|
|
||||||
}
|
|
||||||
|
|
||||||
shutdown := func() {
|
|
||||||
server.Shutdown(app.ctx)
|
|
||||||
listener.Close()
|
|
||||||
os.Remove(app.config.Server.SocketPath)
|
|
||||||
}
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
<-app.ctx.Done()
|
|
||||||
app.log.App.Debug().Msg("Shutting down unix socket listener")
|
|
||||||
shutdown()
|
|
||||||
}()
|
|
||||||
|
|
||||||
err = server.Serve(listener)
|
|
||||||
|
|
||||||
if err != nil && !errors.Is(err, http.ErrServerClosed) {
|
|
||||||
shutdown()
|
|
||||||
return fmt.Errorf("failed to start unix socket listener: %w", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -376,20 +242,20 @@ func (app *BootstrapApp) heartbeatRoutine() {
|
|||||||
ticker := time.NewTicker(time.Duration(12) * time.Hour)
|
ticker := time.NewTicker(time.Duration(12) * time.Hour)
|
||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
|
|
||||||
type Heartbeat struct {
|
type heartbeat struct {
|
||||||
UUID string `json:"uuid"`
|
UUID string `json:"uuid"`
|
||||||
Version string `json:"version"`
|
Version string `json:"version"`
|
||||||
}
|
}
|
||||||
|
|
||||||
var body Heartbeat
|
var body heartbeat
|
||||||
|
|
||||||
body.UUID = app.runtime.UUID
|
body.UUID = app.context.uuid
|
||||||
body.Version = model.Version
|
body.Version = model.Version
|
||||||
|
|
||||||
bodyJson, err := json.Marshal(body)
|
bodyJson, err := json.Marshal(body)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
app.log.App.Error().Err(err).Msg("Failed to marshal heartbeat body, heartbeat routine will not start")
|
tlog.App.Error().Err(err).Msg("Failed to marshal heartbeat body")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -399,60 +265,43 @@ func (app *BootstrapApp) heartbeatRoutine() {
|
|||||||
|
|
||||||
heartbeatURL := model.APIServer + "/v1/instances/heartbeat"
|
heartbeatURL := model.APIServer + "/v1/instances/heartbeat"
|
||||||
|
|
||||||
for {
|
for range ticker.C {
|
||||||
select {
|
tlog.App.Debug().Msg("Sending heartbeat")
|
||||||
case <-ticker.C:
|
|
||||||
app.log.App.Debug().Msg("Sending heartbeat")
|
|
||||||
|
|
||||||
req, err := http.NewRequest(http.MethodPost, heartbeatURL, bytes.NewReader(bodyJson))
|
req, err := http.NewRequest(http.MethodPost, heartbeatURL, bytes.NewReader(bodyJson))
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
app.log.App.Error().Err(err).Msg("Failed to create heartbeat request")
|
tlog.App.Error().Err(err).Msg("Failed to create heartbeat request")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
req.Header.Add("Content-Type", "application/json")
|
req.Header.Add("Content-Type", "application/json")
|
||||||
|
|
||||||
res, err := client.Do(req)
|
res, err := client.Do(req)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
app.log.App.Error().Err(err).Msg("Failed to send heartbeat")
|
tlog.App.Error().Err(err).Msg("Failed to send heartbeat")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
res.Body.Close()
|
res.Body.Close()
|
||||||
|
|
||||||
if res.StatusCode != 200 && res.StatusCode != 201 {
|
if res.StatusCode != 200 && res.StatusCode != 201 {
|
||||||
app.log.App.Debug().Str("status", res.Status).Msg("Heartbeat returned non-200/201 status")
|
tlog.App.Debug().Str("status", res.Status).Msg("Heartbeat returned non-200/201 status")
|
||||||
}
|
|
||||||
case <-app.ctx.Done():
|
|
||||||
app.log.App.Debug().Msg("Stopping heartbeat routine")
|
|
||||||
ticker.Stop()
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (app *BootstrapApp) dbCleanupRoutine() {
|
func (app *BootstrapApp) dbCleanupRoutine(queries *repository.Queries) {
|
||||||
ticker := time.NewTicker(time.Duration(30) * time.Minute)
|
ticker := time.NewTicker(time.Duration(30) * time.Minute)
|
||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
for {
|
for range ticker.C {
|
||||||
select {
|
tlog.App.Debug().Msg("Cleaning up old database sessions")
|
||||||
case <-ticker.C:
|
err := queries.DeleteExpiredSessions(ctx, time.Now().Unix())
|
||||||
app.log.App.Debug().Msg("Running database cleanup")
|
if err != nil {
|
||||||
|
tlog.App.Error().Err(err).Msg("Failed to clean up old database sessions")
|
||||||
err := app.queries.DeleteExpiredSessions(app.ctx, time.Now().Unix())
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
app.log.App.Error().Err(err).Msg("Failed to delete expired sessions")
|
|
||||||
}
|
|
||||||
|
|
||||||
app.log.App.Debug().Msg("Database cleanup completed")
|
|
||||||
case <-app.ctx.Done():
|
|
||||||
app.log.App.Debug().Msg("Stopping database cleanup routine")
|
|
||||||
ticker.Stop()
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,26 +14,19 @@ import (
|
|||||||
_ "modernc.org/sqlite"
|
_ "modernc.org/sqlite"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (app *BootstrapApp) SetupDatabase() error {
|
func (app *BootstrapApp) SetupDatabase(databasePath string) (*sql.DB, error) {
|
||||||
dir := filepath.Dir(app.config.Database.Path)
|
dir := filepath.Dir(databasePath)
|
||||||
|
|
||||||
if err := os.MkdirAll(dir, 0750); err != nil {
|
if err := os.MkdirAll(dir, 0750); err != nil {
|
||||||
return fmt.Errorf("failed to create database directory %s: %w", dir, err)
|
return nil, fmt.Errorf("failed to create database directory %s: %w", dir, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
db, err := sql.Open("sqlite", app.config.Database.Path)
|
db, err := sql.Open("sqlite", databasePath)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to open database: %w", err)
|
return nil, fmt.Errorf("failed to open database: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close the database if there is an error during migration
|
|
||||||
defer func() {
|
|
||||||
if err != nil {
|
|
||||||
db.Close()
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
// Limit to 1 connection to sequence writes, this may need to be revisited in the future
|
// Limit to 1 connection to sequence writes, this may need to be revisited in the future
|
||||||
// if the sqlite connection starts being a bottleneck
|
// if the sqlite connection starts being a bottleneck
|
||||||
db.SetMaxOpenConns(1)
|
db.SetMaxOpenConns(1)
|
||||||
@@ -41,29 +34,24 @@ func (app *BootstrapApp) SetupDatabase() error {
|
|||||||
migrations, err := iofs.New(assets.Migrations, "migrations")
|
migrations, err := iofs.New(assets.Migrations, "migrations")
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to create migrations: %w", err)
|
return nil, fmt.Errorf("failed to create migrations: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
target, err := sqlite3.WithInstance(db, &sqlite3.Config{})
|
target, err := sqlite3.WithInstance(db, &sqlite3.Config{})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to create sqlite3 instance: %w", err)
|
return nil, fmt.Errorf("failed to create sqlite3 instance: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
migrator, err := migrate.NewWithInstance("iofs", migrations, "sqlite3", target)
|
migrator, err := migrate.NewWithInstance("iofs", migrations, "sqlite3", target)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to create migrator: %w", err)
|
return nil, fmt.Errorf("failed to create migrator: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := migrator.Up(); err != nil && err != migrate.ErrNoChange {
|
if err := migrator.Up(); err != nil && err != migrate.ErrNoChange {
|
||||||
return fmt.Errorf("failed to migrate database: %w", err)
|
return nil, fmt.Errorf("failed to migrate database: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
app.db = db
|
return db, nil
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (app *BootstrapApp) GetDB() *sql.DB {
|
|
||||||
return app.db
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,16 +2,21 @@ package bootstrap
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"slices"
|
||||||
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/controller"
|
"github.com/tinyauthapp/tinyauth/internal/controller"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/middleware"
|
"github.com/tinyauthapp/tinyauth/internal/middleware"
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (app *BootstrapApp) setupRouter() error {
|
var DEV_MODES = []string{"main", "test", "development"}
|
||||||
// we don't want gin debug mode
|
|
||||||
gin.SetMode(gin.ReleaseMode)
|
func (app *BootstrapApp) setupRouter() (*gin.Engine, error) {
|
||||||
|
if !slices.Contains(DEV_MODES, model.Version) {
|
||||||
|
gin.SetMode(gin.ReleaseMode)
|
||||||
|
}
|
||||||
|
|
||||||
engine := gin.New()
|
engine := gin.New()
|
||||||
engine.Use(gin.Recovery())
|
engine.Use(gin.Recovery())
|
||||||
@@ -20,36 +25,101 @@ func (app *BootstrapApp) setupRouter() error {
|
|||||||
err := engine.SetTrustedProxies(app.config.Auth.TrustedProxies)
|
err := engine.SetTrustedProxies(app.config.Auth.TrustedProxies)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to set trusted proxies: %w", err)
|
return nil, fmt.Errorf("failed to set trusted proxies: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
contextMiddleware := middleware.NewContextMiddleware(app.log, app.runtime, app.services.authService, app.services.oauthBrokerService)
|
contextMiddleware := middleware.NewContextMiddleware(middleware.ContextMiddlewareConfig{
|
||||||
engine.Use(contextMiddleware.Middleware())
|
CookieDomain: app.context.cookieDomain,
|
||||||
|
SessionCookieName: app.context.sessionCookieName,
|
||||||
|
}, app.services.authService, app.services.oauthBrokerService)
|
||||||
|
|
||||||
uiMiddleware, err := middleware.NewUIMiddleware()
|
err := contextMiddleware.Init()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to initialize UI middleware: %w", err)
|
return nil, fmt.Errorf("failed to initialize context middleware: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
engine.Use(contextMiddleware.Middleware())
|
||||||
|
|
||||||
|
uiMiddleware := middleware.NewUIMiddleware()
|
||||||
|
|
||||||
|
err = uiMiddleware.Init()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to initialize UI middleware: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
engine.Use(uiMiddleware.Middleware())
|
engine.Use(uiMiddleware.Middleware())
|
||||||
|
|
||||||
zerologMiddleware := middleware.NewZerologMiddleware(app.log)
|
zerologMiddleware := middleware.NewZerologMiddleware()
|
||||||
|
|
||||||
|
err = zerologMiddleware.Init()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to initialize zerolog middleware: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
engine.Use(zerologMiddleware.Middleware())
|
engine.Use(zerologMiddleware.Middleware())
|
||||||
|
|
||||||
apiRouter := engine.Group("/api")
|
apiRouter := engine.Group("/api")
|
||||||
|
|
||||||
controller.NewContextController(app.log, app.config, app.runtime, apiRouter)
|
contextController := controller.NewContextController(controller.ContextControllerConfig{
|
||||||
controller.NewOAuthController(app.log, app.config, app.runtime, apiRouter, app.services.authService)
|
Providers: app.context.configuredProviders,
|
||||||
controller.NewOIDCController(app.log, app.services.oidcService, app.runtime, apiRouter)
|
Title: app.config.UI.Title,
|
||||||
controller.NewProxyController(app.log, app.runtime, apiRouter, app.services.accessControlService, app.services.authService)
|
AppURL: app.config.AppURL,
|
||||||
controller.NewUserController(app.log, app.runtime, apiRouter, app.services.authService)
|
CookieDomain: app.context.cookieDomain,
|
||||||
controller.NewResourcesController(app.config, &engine.RouterGroup)
|
ForgotPasswordMessage: app.config.UI.ForgotPasswordMessage,
|
||||||
controller.NewHealthController(apiRouter)
|
BackgroundImage: app.config.UI.BackgroundImage,
|
||||||
controller.NewWellKnownController(app.services.oidcService, &engine.RouterGroup)
|
OAuthAutoRedirect: app.config.OAuth.AutoRedirect,
|
||||||
|
WarningsEnabled: app.config.UI.WarningsEnabled,
|
||||||
|
}, apiRouter)
|
||||||
|
|
||||||
app.router = engine
|
contextController.SetupRoutes()
|
||||||
return nil
|
|
||||||
|
oauthController := controller.NewOAuthController(controller.OAuthControllerConfig{
|
||||||
|
AppURL: app.config.AppURL,
|
||||||
|
SecureCookie: app.config.Auth.SecureCookie,
|
||||||
|
CSRFCookieName: app.context.csrfCookieName,
|
||||||
|
RedirectCookieName: app.context.redirectCookieName,
|
||||||
|
CookieDomain: app.context.cookieDomain,
|
||||||
|
OAuthSessionCookieName: app.context.oauthSessionCookieName,
|
||||||
|
SubdomainsEnabled: app.config.Auth.SubdomainsEnabled,
|
||||||
|
}, apiRouter, app.services.authService)
|
||||||
|
|
||||||
|
oauthController.SetupRoutes()
|
||||||
|
|
||||||
|
oidcController := controller.NewOIDCController(controller.OIDCControllerConfig{}, app.services.oidcService, apiRouter)
|
||||||
|
|
||||||
|
oidcController.SetupRoutes()
|
||||||
|
|
||||||
|
proxyController := controller.NewProxyController(controller.ProxyControllerConfig{
|
||||||
|
AppURL: app.config.AppURL,
|
||||||
|
}, apiRouter, app.services.accessControlService, app.services.authService)
|
||||||
|
|
||||||
|
proxyController.SetupRoutes()
|
||||||
|
|
||||||
|
userController := controller.NewUserController(controller.UserControllerConfig{
|
||||||
|
CookieDomain: app.context.cookieDomain,
|
||||||
|
SessionCookieName: app.context.sessionCookieName,
|
||||||
|
}, apiRouter, app.services.authService)
|
||||||
|
|
||||||
|
userController.SetupRoutes()
|
||||||
|
|
||||||
|
resourcesController := controller.NewResourcesController(controller.ResourcesControllerConfig{
|
||||||
|
Path: app.config.Resources.Path,
|
||||||
|
Enabled: app.config.Resources.Enabled,
|
||||||
|
}, &engine.RouterGroup)
|
||||||
|
|
||||||
|
resourcesController.SetupRoutes()
|
||||||
|
|
||||||
|
healthController := controller.NewHealthController(apiRouter)
|
||||||
|
|
||||||
|
healthController.SetupRoutes()
|
||||||
|
|
||||||
|
wellknownController := controller.NewWellKnownController(controller.WellKnownControllerConfig{}, app.services.oidcService, engine)
|
||||||
|
|
||||||
|
wellknownController.SetupRoutes()
|
||||||
|
|
||||||
|
return engine, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,66 +1,131 @@
|
|||||||
package bootstrap
|
package bootstrap
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (app *BootstrapApp) setupServices() error {
|
type Services struct {
|
||||||
ldapService, err := service.NewLdapService(app.log, app.config, app.ctx, &app.wg)
|
accessControlService *service.AccessControlsService
|
||||||
|
authService *service.AuthService
|
||||||
|
dockerService *service.DockerService
|
||||||
|
kubernetesService *service.KubernetesService
|
||||||
|
ldapService *service.LdapService
|
||||||
|
oauthBrokerService *service.OAuthBrokerService
|
||||||
|
oidcService *service.OIDCService
|
||||||
|
}
|
||||||
|
|
||||||
|
func (app *BootstrapApp) initServices(queries *repository.Queries) (Services, error) {
|
||||||
|
services := Services{}
|
||||||
|
|
||||||
|
ldapService := service.NewLdapService(service.LdapServiceConfig{
|
||||||
|
Address: app.config.LDAP.Address,
|
||||||
|
BindDN: app.config.LDAP.BindDN,
|
||||||
|
BindPassword: app.config.LDAP.BindPassword,
|
||||||
|
BaseDN: app.config.LDAP.BaseDN,
|
||||||
|
Insecure: app.config.LDAP.Insecure,
|
||||||
|
SearchFilter: app.config.LDAP.SearchFilter,
|
||||||
|
AuthCert: app.config.LDAP.AuthCert,
|
||||||
|
AuthKey: app.config.LDAP.AuthKey,
|
||||||
|
})
|
||||||
|
|
||||||
|
err := ldapService.Init()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
app.log.App.Warn().Err(err).Msg("Failed to initialize LDAP connection, will continue without it")
|
tlog.App.Warn().Err(err).Msg("Failed to setup LDAP service, starting without it")
|
||||||
|
ldapService.Unconfigure()
|
||||||
}
|
}
|
||||||
|
|
||||||
app.services.ldapService = ldapService
|
services.ldapService = ldapService
|
||||||
|
|
||||||
|
var labelProvider service.LabelProvider
|
||||||
|
var dockerService *service.DockerService
|
||||||
|
var kubernetesService *service.KubernetesService
|
||||||
|
|
||||||
useKubernetes := app.config.LabelProvider == "kubernetes" ||
|
useKubernetes := app.config.LabelProvider == "kubernetes" ||
|
||||||
(app.config.LabelProvider == "auto" && os.Getenv("KUBERNETES_SERVICE_HOST") != "")
|
(app.config.LabelProvider == "auto" && os.Getenv("KUBERNETES_SERVICE_HOST") != "")
|
||||||
|
|
||||||
var labelProvider service.LabelProvider
|
|
||||||
|
|
||||||
if useKubernetes {
|
if useKubernetes {
|
||||||
app.log.App.Debug().Msg("Using Kubernetes label provider")
|
tlog.App.Debug().Msg("Using Kubernetes label provider")
|
||||||
|
kubernetesService = service.NewKubernetesService()
|
||||||
kubernetesService, err := service.NewKubernetesService(app.log, app.ctx, &app.wg)
|
err = kubernetesService.Init()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to initialize kubernetes service: %w", err)
|
return Services{}, err
|
||||||
}
|
}
|
||||||
|
services.kubernetesService = kubernetesService
|
||||||
app.services.kubernetesService = kubernetesService
|
|
||||||
labelProvider = kubernetesService
|
labelProvider = kubernetesService
|
||||||
} else {
|
} else {
|
||||||
app.log.App.Debug().Msg("Using Docker label provider")
|
tlog.App.Debug().Msg("Using Docker label provider")
|
||||||
|
dockerService = service.NewDockerService()
|
||||||
dockerService, err := service.NewDockerService(app.log, app.ctx, &app.wg)
|
err = dockerService.Init()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to initialize docker service: %w", err)
|
return Services{}, err
|
||||||
}
|
}
|
||||||
|
services.dockerService = dockerService
|
||||||
app.services.dockerService = dockerService
|
|
||||||
labelProvider = dockerService
|
labelProvider = dockerService
|
||||||
}
|
}
|
||||||
|
|
||||||
accessControlsService := service.NewAccessControlsService(app.log, &labelProvider, app.config.Apps)
|
accessControlsService := service.NewAccessControlsService(labelProvider, app.config.Apps)
|
||||||
app.services.accessControlService = accessControlsService
|
|
||||||
|
|
||||||
oauthBrokerService := service.NewOAuthBrokerService(app.log, app.runtime.OAuthProviders, app.ctx)
|
err = accessControlsService.Init()
|
||||||
app.services.oauthBrokerService = oauthBrokerService
|
|
||||||
|
|
||||||
authService := service.NewAuthService(app.log, app.config, app.runtime, app.ctx, &app.wg, app.services.ldapService, app.queries, app.services.oauthBrokerService)
|
|
||||||
app.services.authService = authService
|
|
||||||
|
|
||||||
oidcService, err := service.NewOIDCService(app.log, app.config, app.runtime, app.queries, app.ctx, &app.wg)
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to initialize oidc service: %w", err)
|
return Services{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
app.services.oidcService = oidcService
|
services.accessControlService = accessControlsService
|
||||||
|
|
||||||
return nil
|
oauthBrokerService := service.NewOAuthBrokerService(app.context.oauthProviders)
|
||||||
|
|
||||||
|
err = oauthBrokerService.Init()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return Services{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
services.oauthBrokerService = oauthBrokerService
|
||||||
|
|
||||||
|
authService := service.NewAuthService(service.AuthServiceConfig{
|
||||||
|
LocalUsers: app.context.localUsers,
|
||||||
|
OauthWhitelist: app.config.OAuth.Whitelist,
|
||||||
|
SessionExpiry: app.config.Auth.SessionExpiry,
|
||||||
|
SessionMaxLifetime: app.config.Auth.SessionMaxLifetime,
|
||||||
|
SecureCookie: app.config.Auth.SecureCookie,
|
||||||
|
CookieDomain: app.context.cookieDomain,
|
||||||
|
LoginTimeout: app.config.Auth.LoginTimeout,
|
||||||
|
LoginMaxRetries: app.config.Auth.LoginMaxRetries,
|
||||||
|
SessionCookieName: app.context.sessionCookieName,
|
||||||
|
IP: app.config.Auth.IP,
|
||||||
|
LDAPGroupsCacheTTL: app.config.LDAP.GroupCacheTTL,
|
||||||
|
SubdomainsEnabled: app.config.Auth.SubdomainsEnabled,
|
||||||
|
}, services.ldapService, queries, services.oauthBrokerService)
|
||||||
|
|
||||||
|
err = authService.Init()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return Services{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
services.authService = authService
|
||||||
|
|
||||||
|
oidcService := service.NewOIDCService(service.OIDCServiceConfig{
|
||||||
|
Clients: app.config.OIDC.Clients,
|
||||||
|
PrivateKeyPath: app.config.OIDC.PrivateKeyPath,
|
||||||
|
PublicKeyPath: app.config.OIDC.PublicKeyPath,
|
||||||
|
Issuer: app.config.AppURL,
|
||||||
|
SessionExpiry: app.config.Auth.SessionExpiry,
|
||||||
|
}, queries)
|
||||||
|
|
||||||
|
err = oidcService.Init()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return Services{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
services.oidcService = oidcService
|
||||||
|
|
||||||
|
return services, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import (
|
|||||||
"net/url"
|
"net/url"
|
||||||
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
@@ -24,52 +24,62 @@ type UserContextResponse struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type AppContextResponse struct {
|
type AppContextResponse struct {
|
||||||
Status int `json:"status"`
|
Status int `json:"status"`
|
||||||
Message string `json:"message"`
|
Message string `json:"message"`
|
||||||
Providers []model.Provider `json:"providers"`
|
Providers []Provider `json:"providers"`
|
||||||
Title string `json:"title"`
|
Title string `json:"title"`
|
||||||
AppURL string `json:"appUrl"`
|
AppURL string `json:"appUrl"`
|
||||||
CookieDomain string `json:"cookieDomain"`
|
CookieDomain string `json:"cookieDomain"`
|
||||||
ForgotPasswordMessage string `json:"forgotPasswordMessage"`
|
ForgotPasswordMessage string `json:"forgotPasswordMessage"`
|
||||||
BackgroundImage string `json:"backgroundImage"`
|
BackgroundImage string `json:"backgroundImage"`
|
||||||
OAuthAutoRedirect string `json:"oauthAutoRedirect"`
|
OAuthAutoRedirect string `json:"oauthAutoRedirect"`
|
||||||
WarningsEnabled bool `json:"warningsEnabled"`
|
WarningsEnabled bool `json:"warningsEnabled"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Provider struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
ID string `json:"id"`
|
||||||
|
OAuth bool `json:"oauth"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ContextControllerConfig struct {
|
||||||
|
Providers []Provider
|
||||||
|
Title string
|
||||||
|
AppURL string
|
||||||
|
CookieDomain string
|
||||||
|
ForgotPasswordMessage string
|
||||||
|
BackgroundImage string
|
||||||
|
OAuthAutoRedirect string
|
||||||
|
WarningsEnabled bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type ContextController struct {
|
type ContextController struct {
|
||||||
log *logger.Logger
|
config ContextControllerConfig
|
||||||
config model.Config
|
router *gin.RouterGroup
|
||||||
runtime model.RuntimeConfig
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewContextController(
|
func NewContextController(config ContextControllerConfig, router *gin.RouterGroup) *ContextController {
|
||||||
log *logger.Logger,
|
if !config.WarningsEnabled {
|
||||||
config model.Config,
|
tlog.App.Warn().Msg("UI warnings are disabled. This may expose users to security risks. Proceed with caution.")
|
||||||
runtimeConfig model.RuntimeConfig,
|
|
||||||
router *gin.RouterGroup,
|
|
||||||
) *ContextController {
|
|
||||||
controller := &ContextController{
|
|
||||||
log: log,
|
|
||||||
config: config,
|
|
||||||
runtime: runtimeConfig,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if !config.UI.WarningsEnabled {
|
return &ContextController{
|
||||||
log.App.Warn().Msg("UI warnings are disabled. This may lead to security issues if you are not careful. Make sure to enable warnings in production environments.")
|
config: config,
|
||||||
|
router: router,
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
contextGroup := router.Group("/context")
|
func (controller *ContextController) SetupRoutes() {
|
||||||
|
contextGroup := controller.router.Group("/context")
|
||||||
contextGroup.GET("/user", controller.userContextHandler)
|
contextGroup.GET("/user", controller.userContextHandler)
|
||||||
contextGroup.GET("/app", controller.appContextHandler)
|
contextGroup.GET("/app", controller.appContextHandler)
|
||||||
|
|
||||||
return controller
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (controller *ContextController) userContextHandler(c *gin.Context) {
|
func (controller *ContextController) userContextHandler(c *gin.Context) {
|
||||||
context, err := new(model.UserContext).NewFromGin(c)
|
context, err := new(model.UserContext).NewFromGin(c)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Error().Err(err).Msg("Failed to create user context from request")
|
tlog.App.Debug().Err(err).Msg("No user context found in request")
|
||||||
c.JSON(200, UserContextResponse{
|
c.JSON(200, UserContextResponse{
|
||||||
Status: 401,
|
Status: 401,
|
||||||
Message: "Unauthorized",
|
Message: "Unauthorized",
|
||||||
@@ -95,10 +105,9 @@ func (controller *ContextController) userContextHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (controller *ContextController) appContextHandler(c *gin.Context) {
|
func (controller *ContextController) appContextHandler(c *gin.Context) {
|
||||||
appUrl, err := url.Parse(controller.runtime.AppURL)
|
appUrl, err := url.Parse(controller.config.AppURL)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Error().Err(err).Msg("Failed to parse app URL")
|
tlog.App.Error().Err(err).Msg("Failed to parse app URL")
|
||||||
c.JSON(500, gin.H{
|
c.JSON(500, gin.H{
|
||||||
"status": 500,
|
"status": 500,
|
||||||
"message": "Internal Server Error",
|
"message": "Internal Server Error",
|
||||||
@@ -109,13 +118,13 @@ func (controller *ContextController) appContextHandler(c *gin.Context) {
|
|||||||
c.JSON(200, AppContextResponse{
|
c.JSON(200, AppContextResponse{
|
||||||
Status: 200,
|
Status: 200,
|
||||||
Message: "Success",
|
Message: "Success",
|
||||||
Providers: controller.runtime.ConfiguredProviders,
|
Providers: controller.config.Providers,
|
||||||
Title: controller.config.UI.Title,
|
Title: controller.config.Title,
|
||||||
AppURL: fmt.Sprintf("%s://%s", appUrl.Scheme, appUrl.Host),
|
AppURL: fmt.Sprintf("%s://%s", appUrl.Scheme, appUrl.Host),
|
||||||
CookieDomain: controller.runtime.CookieDomain,
|
CookieDomain: controller.config.CookieDomain,
|
||||||
ForgotPasswordMessage: controller.config.UI.ForgotPasswordMessage,
|
ForgotPasswordMessage: controller.config.ForgotPasswordMessage,
|
||||||
BackgroundImage: controller.config.UI.BackgroundImage,
|
BackgroundImage: controller.config.BackgroundImage,
|
||||||
OAuthAutoRedirect: controller.config.OAuth.AutoRedirect,
|
OAuthAutoRedirect: controller.config.OAuthAutoRedirect,
|
||||||
WarningsEnabled: controller.config.UI.WarningsEnabled,
|
WarningsEnabled: controller.config.WarningsEnabled,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,19 +8,30 @@ import (
|
|||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/controller"
|
"github.com/tinyauthapp/tinyauth/internal/controller"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/test"
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestContextController(t *testing.T) {
|
func TestContextController(t *testing.T) {
|
||||||
log := logger.NewLogger().WithTestConfig()
|
tlog.NewTestLogger().Init()
|
||||||
log.Init()
|
controllerConfig := controller.ContextControllerConfig{
|
||||||
|
Providers: []controller.Provider{
|
||||||
cfg, runtime := test.CreateTestConfigs(t)
|
{
|
||||||
|
Name: "Local",
|
||||||
|
ID: "local",
|
||||||
|
OAuth: false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Title: "Tinyauth",
|
||||||
|
AppURL: "https://tinyauth.example.com",
|
||||||
|
CookieDomain: "example.com",
|
||||||
|
ForgotPasswordMessage: "foo",
|
||||||
|
BackgroundImage: "/background.jpg",
|
||||||
|
OAuthAutoRedirect: "none",
|
||||||
|
WarningsEnabled: true,
|
||||||
|
}
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
description string
|
description string
|
||||||
@@ -36,17 +47,17 @@ func TestContextController(t *testing.T) {
|
|||||||
expectedAppContextResponse := controller.AppContextResponse{
|
expectedAppContextResponse := controller.AppContextResponse{
|
||||||
Status: 200,
|
Status: 200,
|
||||||
Message: "Success",
|
Message: "Success",
|
||||||
Providers: runtime.ConfiguredProviders,
|
Providers: controllerConfig.Providers,
|
||||||
Title: cfg.UI.Title,
|
Title: controllerConfig.Title,
|
||||||
AppURL: runtime.AppURL,
|
AppURL: controllerConfig.AppURL,
|
||||||
CookieDomain: runtime.CookieDomain,
|
CookieDomain: controllerConfig.CookieDomain,
|
||||||
ForgotPasswordMessage: cfg.UI.ForgotPasswordMessage,
|
ForgotPasswordMessage: controllerConfig.ForgotPasswordMessage,
|
||||||
BackgroundImage: cfg.UI.BackgroundImage,
|
BackgroundImage: controllerConfig.BackgroundImage,
|
||||||
OAuthAutoRedirect: cfg.OAuth.AutoRedirect,
|
OAuthAutoRedirect: controllerConfig.OAuthAutoRedirect,
|
||||||
WarningsEnabled: cfg.UI.WarningsEnabled,
|
WarningsEnabled: controllerConfig.WarningsEnabled,
|
||||||
}
|
}
|
||||||
bytes, err := json.Marshal(expectedAppContextResponse)
|
bytes, err := json.Marshal(expectedAppContextResponse)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
return string(bytes)
|
return string(bytes)
|
||||||
}(),
|
}(),
|
||||||
},
|
},
|
||||||
@@ -60,7 +71,7 @@ func TestContextController(t *testing.T) {
|
|||||||
Message: "Unauthorized",
|
Message: "Unauthorized",
|
||||||
}
|
}
|
||||||
bytes, err := json.Marshal(expectedUserContextResponse)
|
bytes, err := json.Marshal(expectedUserContextResponse)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
return string(bytes)
|
return string(bytes)
|
||||||
}(),
|
}(),
|
||||||
},
|
},
|
||||||
@@ -75,7 +86,7 @@ func TestContextController(t *testing.T) {
|
|||||||
BaseContext: model.BaseContext{
|
BaseContext: model.BaseContext{
|
||||||
Username: "johndoe",
|
Username: "johndoe",
|
||||||
Name: "John Doe",
|
Name: "John Doe",
|
||||||
Email: utils.CompileUserEmail("johndoe", runtime.CookieDomain),
|
Email: utils.CompileUserEmail("johndoe", controllerConfig.CookieDomain),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
@@ -89,11 +100,11 @@ func TestContextController(t *testing.T) {
|
|||||||
IsLoggedIn: true,
|
IsLoggedIn: true,
|
||||||
Username: "johndoe",
|
Username: "johndoe",
|
||||||
Name: "John Doe",
|
Name: "John Doe",
|
||||||
Email: utils.CompileUserEmail("johndoe", runtime.CookieDomain),
|
Email: utils.CompileUserEmail("johndoe", controllerConfig.CookieDomain),
|
||||||
Provider: "local",
|
Provider: "local",
|
||||||
}
|
}
|
||||||
bytes, err := json.Marshal(expectedUserContextResponse)
|
bytes, err := json.Marshal(expectedUserContextResponse)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
return string(bytes)
|
return string(bytes)
|
||||||
}(),
|
}(),
|
||||||
},
|
},
|
||||||
@@ -110,12 +121,13 @@ func TestContextController(t *testing.T) {
|
|||||||
group := router.Group("/api")
|
group := router.Group("/api")
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
controller.NewContextController(log, cfg, runtime, group)
|
contextController := controller.NewContextController(controllerConfig, group)
|
||||||
|
contextController.SetupRoutes()
|
||||||
|
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
|
|
||||||
request, err := http.NewRequest("GET", test.path, nil)
|
request, err := http.NewRequest("GET", test.path, nil)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
router.ServeHTTP(recorder, request)
|
router.ServeHTTP(recorder, request)
|
||||||
|
|
||||||
|
|||||||
@@ -3,15 +3,18 @@ package controller
|
|||||||
import "github.com/gin-gonic/gin"
|
import "github.com/gin-gonic/gin"
|
||||||
|
|
||||||
type HealthController struct {
|
type HealthController struct {
|
||||||
|
router *gin.RouterGroup
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewHealthController(router *gin.RouterGroup) *HealthController {
|
func NewHealthController(router *gin.RouterGroup) *HealthController {
|
||||||
controller := &HealthController{}
|
return &HealthController{
|
||||||
|
router: router,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
router.GET("/healthz", controller.healthHandler)
|
func (controller *HealthController) SetupRoutes() {
|
||||||
router.HEAD("/healthz", controller.healthHandler)
|
controller.router.GET("/healthz", controller.healthHandler)
|
||||||
|
controller.router.HEAD("/healthz", controller.healthHandler)
|
||||||
return controller
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (controller *HealthController) healthHandler(c *gin.Context) {
|
func (controller *HealthController) healthHandler(c *gin.Context) {
|
||||||
|
|||||||
@@ -7,12 +7,13 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/controller"
|
"github.com/tinyauthapp/tinyauth/internal/controller"
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestHealthController(t *testing.T) {
|
func TestHealthController(t *testing.T) {
|
||||||
|
tlog.NewTestLogger().Init()
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
description string
|
description string
|
||||||
path string
|
path string
|
||||||
@@ -29,7 +30,7 @@ func TestHealthController(t *testing.T) {
|
|||||||
"message": "Healthy",
|
"message": "Healthy",
|
||||||
}
|
}
|
||||||
bytes, err := json.Marshal(expectedHealthResponse)
|
bytes, err := json.Marshal(expectedHealthResponse)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
return string(bytes)
|
return string(bytes)
|
||||||
}(),
|
}(),
|
||||||
},
|
},
|
||||||
@@ -43,7 +44,7 @@ func TestHealthController(t *testing.T) {
|
|||||||
"message": "Healthy",
|
"message": "Healthy",
|
||||||
}
|
}
|
||||||
bytes, err := json.Marshal(expectedHealthResponse)
|
bytes, err := json.Marshal(expectedHealthResponse)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
return string(bytes)
|
return string(bytes)
|
||||||
}(),
|
}(),
|
||||||
},
|
},
|
||||||
@@ -55,12 +56,13 @@ func TestHealthController(t *testing.T) {
|
|||||||
group := router.Group("/api")
|
group := router.Group("/api")
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
controller.NewHealthController(group)
|
healthController := controller.NewHealthController(group)
|
||||||
|
healthController.SetupRoutes()
|
||||||
|
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
|
|
||||||
request, err := http.NewRequest(test.method, test.path, nil)
|
request, err := http.NewRequest(test.method, test.path, nil)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
router.ServeHTTP(recorder, request)
|
router.ServeHTTP(recorder, request)
|
||||||
|
|
||||||
|
|||||||
@@ -6,11 +6,10 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/google/go-querystring/query"
|
"github.com/google/go-querystring/query"
|
||||||
@@ -20,32 +19,34 @@ type OAuthRequest struct {
|
|||||||
Provider string `uri:"provider" binding:"required"`
|
Provider string `uri:"provider" binding:"required"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type OAuthController struct {
|
type OAuthControllerConfig struct {
|
||||||
log *logger.Logger
|
CSRFCookieName string
|
||||||
config model.Config
|
OAuthSessionCookieName string
|
||||||
runtime model.RuntimeConfig
|
RedirectCookieName string
|
||||||
auth *service.AuthService
|
SecureCookie bool
|
||||||
|
AppURL string
|
||||||
|
CookieDomain string
|
||||||
|
SubdomainsEnabled bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewOAuthController(
|
type OAuthController struct {
|
||||||
log *logger.Logger,
|
config OAuthControllerConfig
|
||||||
config model.Config,
|
router *gin.RouterGroup
|
||||||
runtimeConfig model.RuntimeConfig,
|
auth *service.AuthService
|
||||||
router *gin.RouterGroup,
|
}
|
||||||
auth *service.AuthService,
|
|
||||||
) *OAuthController {
|
|
||||||
controller := &OAuthController{
|
|
||||||
log: log,
|
|
||||||
config: config,
|
|
||||||
runtime: runtimeConfig,
|
|
||||||
auth: auth,
|
|
||||||
}
|
|
||||||
|
|
||||||
oauthGroup := router.Group("/oauth")
|
func NewOAuthController(config OAuthControllerConfig, router *gin.RouterGroup, auth *service.AuthService) *OAuthController {
|
||||||
|
return &OAuthController{
|
||||||
|
config: config,
|
||||||
|
router: router,
|
||||||
|
auth: auth,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (controller *OAuthController) SetupRoutes() {
|
||||||
|
oauthGroup := controller.router.Group("/oauth")
|
||||||
oauthGroup.GET("/url/:provider", controller.oauthURLHandler)
|
oauthGroup.GET("/url/:provider", controller.oauthURLHandler)
|
||||||
oauthGroup.GET("/callback/:provider", controller.oauthCallbackHandler)
|
oauthGroup.GET("/callback/:provider", controller.oauthCallbackHandler)
|
||||||
|
|
||||||
return controller
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
|
func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
|
||||||
@@ -53,7 +54,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
|
|||||||
|
|
||||||
err := c.BindUri(&req)
|
err := c.BindUri(&req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Error().Err(err).Msg("Failed to bind URI")
|
tlog.App.Error().Err(err).Msg("Failed to bind URI")
|
||||||
c.JSON(400, gin.H{
|
c.JSON(400, gin.H{
|
||||||
"status": 400,
|
"status": 400,
|
||||||
"message": "Bad Request",
|
"message": "Bad Request",
|
||||||
@@ -66,7 +67,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
|
|||||||
err = c.BindQuery(&reqParams)
|
err = c.BindQuery(&reqParams)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Error().Err(err).Msg("Failed to bind query parameters")
|
tlog.App.Error().Err(err).Msg("Failed to bind query parameters")
|
||||||
c.JSON(400, gin.H{
|
c.JSON(400, gin.H{
|
||||||
"status": 400,
|
"status": 400,
|
||||||
"message": "Bad Request",
|
"message": "Bad Request",
|
||||||
@@ -75,10 +76,10 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !controller.isOidcRequest(reqParams) {
|
if !controller.isOidcRequest(reqParams) {
|
||||||
isRedirectSafe := utils.IsRedirectSafe(reqParams.RedirectURI, controller.runtime.CookieDomain)
|
isRedirectSafe := utils.IsRedirectSafe(reqParams.RedirectURI, controller.config.CookieDomain)
|
||||||
|
|
||||||
if !isRedirectSafe {
|
if !isRedirectSafe {
|
||||||
controller.log.App.Warn().Str("redirectUri", reqParams.RedirectURI).Msg("Unsafe redirect URI, ignoring")
|
tlog.App.Warn().Str("redirect_uri", reqParams.RedirectURI).Msg("Unsafe redirect URI detected, ignoring")
|
||||||
reqParams.RedirectURI = ""
|
reqParams.RedirectURI = ""
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -86,7 +87,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
|
|||||||
sessionId, _, err := controller.auth.NewOAuthSession(req.Provider, reqParams)
|
sessionId, _, err := controller.auth.NewOAuthSession(req.Provider, reqParams)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Error().Err(err).Msg("Failed to create new OAuth session")
|
tlog.App.Error().Err(err).Msg("Failed to create OAuth session")
|
||||||
c.JSON(500, gin.H{
|
c.JSON(500, gin.H{
|
||||||
"status": 500,
|
"status": 500,
|
||||||
"message": "Internal Server Error",
|
"message": "Internal Server Error",
|
||||||
@@ -97,7 +98,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
|
|||||||
authUrl, err := controller.auth.GetOAuthURL(sessionId)
|
authUrl, err := controller.auth.GetOAuthURL(sessionId)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Error().Err(err).Msg("Failed to get OAuth URL for session")
|
tlog.App.Error().Err(err).Msg("Failed to get OAuth URL")
|
||||||
c.JSON(500, gin.H{
|
c.JSON(500, gin.H{
|
||||||
"status": 500,
|
"status": 500,
|
||||||
"message": "Internal Server Error",
|
"message": "Internal Server Error",
|
||||||
@@ -105,7 +106,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.SetCookie(controller.runtime.OAuthSessionCookieName, sessionId, int(time.Hour.Seconds()), "/", controller.getCookieDomain(), controller.config.Auth.SecureCookie, true)
|
c.SetCookie(controller.config.OAuthSessionCookieName, sessionId, int(time.Hour.Seconds()), "/", controller.getCookieDomain(), controller.config.SecureCookie, true)
|
||||||
|
|
||||||
c.JSON(200, gin.H{
|
c.JSON(200, gin.H{
|
||||||
"status": 200,
|
"status": 200,
|
||||||
@@ -119,7 +120,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
|
|||||||
|
|
||||||
err := c.BindUri(&req)
|
err := c.BindUri(&req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Error().Err(err).Msg("Failed to bind URI")
|
tlog.App.Error().Err(err).Msg("Failed to bind URI")
|
||||||
c.JSON(400, gin.H{
|
c.JSON(400, gin.H{
|
||||||
"status": 400,
|
"status": 400,
|
||||||
"message": "Bad Request",
|
"message": "Bad Request",
|
||||||
@@ -127,21 +128,21 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
sessionIdCookie, err := c.Cookie(controller.runtime.OAuthSessionCookieName)
|
sessionIdCookie, err := c.Cookie(controller.config.OAuthSessionCookieName)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Error().Err(err).Msg("Failed to get OAuth session cookie")
|
tlog.App.Warn().Err(err).Msg("OAuth session cookie missing")
|
||||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
|
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.SetCookie(controller.runtime.OAuthSessionCookieName, "", -1, "/", controller.getCookieDomain(), controller.config.Auth.SecureCookie, true)
|
c.SetCookie(controller.config.OAuthSessionCookieName, "", -1, "/", controller.getCookieDomain(), controller.config.SecureCookie, true)
|
||||||
|
|
||||||
oauthPendingSession, err := controller.auth.GetOAuthPendingSession(sessionIdCookie)
|
oauthPendingSession, err := controller.auth.GetOAuthPendingSession(sessionIdCookie)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Error().Err(err).Msg("Failed to get pending OAuth session")
|
tlog.App.Error().Err(err).Msg("Failed to get OAuth pending session")
|
||||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
|
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -149,8 +150,8 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
|
|||||||
|
|
||||||
state := c.Query("state")
|
state := c.Query("state")
|
||||||
if state != oauthPendingSession.State {
|
if state != oauthPendingSession.State {
|
||||||
controller.log.App.Warn().Msg("OAuth state mismatch")
|
tlog.App.Warn().Err(err).Msg("CSRF token mismatch")
|
||||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
|
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -158,80 +159,68 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
|
|||||||
_, err = controller.auth.GetOAuthToken(sessionIdCookie, code)
|
_, err = controller.auth.GetOAuthToken(sessionIdCookie, code)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Error().Err(err).Msg("Failed to exchange code for token")
|
tlog.App.Error().Err(err).Msg("Failed to exchange code for token")
|
||||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
|
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := controller.auth.GetOAuthUserinfo(sessionIdCookie)
|
user, err := controller.auth.GetOAuthUserinfo(sessionIdCookie)
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
controller.log.App.Error().Err(err).Msg("Failed to get user info from OAuth provider")
|
|
||||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if user == nil {
|
|
||||||
controller.log.App.Warn().Msg("OAuth provider did not return user info")
|
|
||||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if user.Email == "" {
|
if user.Email == "" {
|
||||||
controller.log.App.Warn().Msg("OAuth provider did not return an email")
|
tlog.App.Error().Msg("OAuth provider did not return an email")
|
||||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
|
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if !controller.auth.IsEmailWhitelisted(user.Email) {
|
if !controller.auth.IsEmailWhitelisted(user.Email) {
|
||||||
controller.log.App.Warn().Str("email", user.Email).Msg("Email not whitelisted, denying access")
|
tlog.App.Warn().Str("email", user.Email).Msg("Email not whitelisted")
|
||||||
controller.log.AuditLoginFailure(user.Email, req.Provider, c.ClientIP(), "email not whitelisted")
|
tlog.AuditLoginFailure(c, user.Email, req.Provider, "email not whitelisted")
|
||||||
|
|
||||||
queries, err := query.Values(UnauthorizedQuery{
|
queries, err := query.Values(UnauthorizedQuery{
|
||||||
Username: user.Email,
|
Username: user.Email,
|
||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Error().Err(err).Msg("Failed to encode unauthorized query")
|
tlog.App.Error().Err(err).Msg("Failed to encode unauthorized query")
|
||||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
|
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", controller.runtime.AppURL, queries.Encode()))
|
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", controller.config.AppURL, queries.Encode()))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var name string
|
var name string
|
||||||
|
|
||||||
if strings.TrimSpace(user.Name) != "" {
|
if strings.TrimSpace(user.Name) != "" {
|
||||||
controller.log.App.Debug().Msg("Using name from OAuth provider")
|
tlog.App.Debug().Msg("Using name from OAuth provider")
|
||||||
name = user.Name
|
name = user.Name
|
||||||
} else {
|
} else {
|
||||||
controller.log.App.Debug().Msg("No name from OAuth provider, generating from email")
|
tlog.App.Debug().Msg("No name from OAuth provider, using pseudo name")
|
||||||
name = fmt.Sprintf("%s (%s)", utils.Capitalize(strings.Split(user.Email, "@")[0]), strings.Split(user.Email, "@")[1])
|
name = fmt.Sprintf("%s (%s)", utils.Capitalize(strings.Split(user.Email, "@")[0]), strings.Split(user.Email, "@")[1])
|
||||||
}
|
}
|
||||||
|
|
||||||
var username string
|
var username string
|
||||||
|
|
||||||
if strings.TrimSpace(user.PreferredUsername) != "" {
|
if strings.TrimSpace(user.PreferredUsername) != "" {
|
||||||
controller.log.App.Debug().Msg("Using preferred username from OAuth provider")
|
tlog.App.Debug().Msg("Using preferred username from OAuth provider")
|
||||||
username = user.PreferredUsername
|
username = user.PreferredUsername
|
||||||
} else {
|
} else {
|
||||||
controller.log.App.Debug().Msg("No preferred username from OAuth provider, generating from email")
|
tlog.App.Debug().Msg("No preferred username from OAuth provider, using pseudo username")
|
||||||
username = strings.Replace(user.Email, "@", "_", 1)
|
username = strings.Replace(user.Email, "@", "_", 1)
|
||||||
}
|
}
|
||||||
|
|
||||||
svc, err := controller.auth.GetOAuthService(sessionIdCookie)
|
svc, err := controller.auth.GetOAuthService(sessionIdCookie)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Error().Err(err).Msg("Failed to get OAuth service for session")
|
tlog.App.Error().Err(err).Msg("Failed to get OAuth service for session")
|
||||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
|
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if svc.ID() != req.Provider {
|
if svc.ID() != req.Provider {
|
||||||
controller.log.App.Warn().Msgf("OAuth provider mismatch: expected %s, got %s", req.Provider, svc.ID())
|
tlog.App.Error().Msgf("OAuth service ID mismatch: expected %s, got %s", svc.ID(), req.Provider)
|
||||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
|
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -245,29 +234,29 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
|
|||||||
OAuthSub: user.Sub,
|
OAuthSub: user.Sub,
|
||||||
}
|
}
|
||||||
|
|
||||||
controller.log.App.Debug().Msg("Creating session cookie for user")
|
tlog.App.Trace().Interface("session_cookie", sessionCookie).Msg("Creating session cookie")
|
||||||
|
|
||||||
cookie, err := controller.auth.CreateSession(c, sessionCookie)
|
cookie, err := controller.auth.CreateSession(c, sessionCookie)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Error().Err(err).Msg("Failed to create session cookie")
|
tlog.App.Error().Err(err).Msg("Failed to create session cookie")
|
||||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
|
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
http.SetCookie(c.Writer, cookie)
|
http.SetCookie(c.Writer, cookie)
|
||||||
|
|
||||||
controller.log.AuditLoginSuccess(sessionCookie.Username, sessionCookie.Provider, c.ClientIP())
|
tlog.AuditLoginSuccess(c, sessionCookie.Username, sessionCookie.Provider)
|
||||||
|
|
||||||
if controller.isOidcRequest(oauthPendingSession.CallbackParams) {
|
if controller.isOidcRequest(oauthPendingSession.CallbackParams) {
|
||||||
controller.log.App.Debug().Msg("OIDC request detected, redirecting to authorization endpoint with callback params")
|
tlog.App.Debug().Msg("OIDC request, redirecting to authorize page")
|
||||||
queries, err := query.Values(oauthPendingSession.CallbackParams)
|
queries, err := query.Values(oauthPendingSession.CallbackParams)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Error().Err(err).Msg("Failed to encode OIDC callback query")
|
tlog.App.Error().Err(err).Msg("Failed to encode OIDC callback query")
|
||||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
|
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/authorize?%s", controller.runtime.AppURL, queries.Encode()))
|
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/authorize?%s", controller.config.AppURL, queries.Encode()))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -277,16 +266,16 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Error().Err(err).Msg("Failed to encode redirect query")
|
tlog.App.Error().Err(err).Msg("Failed to encode redirect URI query")
|
||||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
|
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/continue?%s", controller.runtime.AppURL, queries.Encode()))
|
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/continue?%s", controller.config.AppURL, queries.Encode()))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Redirect(http.StatusTemporaryRedirect, controller.runtime.AppURL)
|
c.Redirect(http.StatusTemporaryRedirect, controller.config.AppURL)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (controller *OAuthController) isOidcRequest(params service.OAuthURLParams) bool {
|
func (controller *OAuthController) isOidcRequest(params service.OAuthURLParams) bool {
|
||||||
@@ -297,8 +286,8 @@ func (controller *OAuthController) isOidcRequest(params service.OAuthURLParams)
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (controller *OAuthController) getCookieDomain() string {
|
func (controller *OAuthController) getCookieDomain() string {
|
||||||
if controller.config.Auth.SubdomainsEnabled {
|
if controller.config.SubdomainsEnabled {
|
||||||
return "." + controller.runtime.CookieDomain
|
return "." + controller.config.CookieDomain
|
||||||
}
|
}
|
||||||
return controller.runtime.CookieDomain
|
return controller.config.CookieDomain
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -13,13 +13,15 @@ import (
|
|||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type OIDCControllerConfig struct{}
|
||||||
|
|
||||||
type OIDCController struct {
|
type OIDCController struct {
|
||||||
log *logger.Logger
|
config OIDCControllerConfig
|
||||||
oidc *service.OIDCService
|
router *gin.RouterGroup
|
||||||
runtime model.RuntimeConfig
|
oidc *service.OIDCService
|
||||||
}
|
}
|
||||||
|
|
||||||
type AuthorizeCallback struct {
|
type AuthorizeCallback struct {
|
||||||
@@ -56,42 +58,29 @@ type ClientCredentials struct {
|
|||||||
ClientSecret string
|
ClientSecret string
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewOIDCController(
|
func NewOIDCController(config OIDCControllerConfig, oidcService *service.OIDCService, router *gin.RouterGroup) *OIDCController {
|
||||||
log *logger.Logger,
|
return &OIDCController{
|
||||||
oidcService *service.OIDCService,
|
config: config,
|
||||||
runtimeConfig model.RuntimeConfig,
|
oidc: oidcService,
|
||||||
router *gin.RouterGroup) *OIDCController {
|
router: router,
|
||||||
controller := &OIDCController{
|
|
||||||
log: log,
|
|
||||||
oidc: oidcService,
|
|
||||||
runtime: runtimeConfig,
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
oidcGroup := router.Group("/oidc")
|
func (controller *OIDCController) SetupRoutes() {
|
||||||
|
oidcGroup := controller.router.Group("/oidc")
|
||||||
oidcGroup.GET("/clients/:id", controller.GetClientInfo)
|
oidcGroup.GET("/clients/:id", controller.GetClientInfo)
|
||||||
oidcGroup.POST("/authorize", controller.Authorize)
|
oidcGroup.POST("/authorize", controller.Authorize)
|
||||||
oidcGroup.POST("/token", controller.Token)
|
oidcGroup.POST("/token", controller.Token)
|
||||||
oidcGroup.GET("/userinfo", controller.Userinfo)
|
oidcGroup.GET("/userinfo", controller.Userinfo)
|
||||||
oidcGroup.POST("/userinfo", controller.Userinfo)
|
oidcGroup.POST("/userinfo", controller.Userinfo)
|
||||||
|
|
||||||
return controller
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (controller *OIDCController) GetClientInfo(c *gin.Context) {
|
func (controller *OIDCController) GetClientInfo(c *gin.Context) {
|
||||||
if controller.oidc == nil {
|
|
||||||
controller.log.App.Warn().Msg("Received OIDC client info request but OIDC server is not configured")
|
|
||||||
c.JSON(500, gin.H{
|
|
||||||
"status": 500,
|
|
||||||
"message": "OIDC not configured",
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var req ClientRequest
|
var req ClientRequest
|
||||||
|
|
||||||
err := c.BindUri(&req)
|
err := c.BindUri(&req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Error().Err(err).Msg("Failed to bind URI")
|
tlog.App.Error().Err(err).Msg("Failed to bind URI")
|
||||||
c.JSON(400, gin.H{
|
c.JSON(400, gin.H{
|
||||||
"status": 400,
|
"status": 400,
|
||||||
"message": "Bad Request",
|
"message": "Bad Request",
|
||||||
@@ -102,7 +91,7 @@ func (controller *OIDCController) GetClientInfo(c *gin.Context) {
|
|||||||
client, ok := controller.oidc.GetClient(req.ClientID)
|
client, ok := controller.oidc.GetClient(req.ClientID)
|
||||||
|
|
||||||
if !ok {
|
if !ok {
|
||||||
controller.log.App.Warn().Str("clientId", req.ClientID).Msg("Client not found")
|
tlog.App.Warn().Str("client_id", req.ClientID).Msg("Client not found")
|
||||||
c.JSON(404, gin.H{
|
c.JSON(404, gin.H{
|
||||||
"status": 404,
|
"status": 404,
|
||||||
"message": "Client not found",
|
"message": "Client not found",
|
||||||
@@ -118,7 +107,7 @@ func (controller *OIDCController) GetClientInfo(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (controller *OIDCController) Authorize(c *gin.Context) {
|
func (controller *OIDCController) Authorize(c *gin.Context) {
|
||||||
if controller.oidc == nil {
|
if !controller.oidc.IsConfigured() {
|
||||||
controller.authorizeError(c, errors.New("err_oidc_not_configured"), "OIDC not configured", "This instance is not configured for OIDC", "", "", "")
|
controller.authorizeError(c, errors.New("err_oidc_not_configured"), "OIDC not configured", "This instance is not configured for OIDC", "", "", "")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -153,7 +142,7 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
|
|||||||
err = controller.oidc.ValidateAuthorizeParams(req)
|
err = controller.oidc.ValidateAuthorizeParams(req)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Warn().Err(err).Msg("Failed to validate authorize params")
|
tlog.App.Error().Err(err).Msg("Failed to validate authorize params")
|
||||||
if err.Error() != "invalid_request_uri" {
|
if err.Error() != "invalid_request_uri" {
|
||||||
controller.authorizeError(c, err, "Failed validate authorize params", "Invalid request parameters", req.RedirectURI, err.Error(), req.State)
|
controller.authorizeError(c, err, "Failed validate authorize params", "Invalid request parameters", req.RedirectURI, err.Error(), req.State)
|
||||||
return
|
return
|
||||||
@@ -185,7 +174,7 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
|
|||||||
err = controller.oidc.StoreUserinfo(c, sub, *userContext, req)
|
err = controller.oidc.StoreUserinfo(c, sub, *userContext, req)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Error().Err(err).Msg("Failed to store user info")
|
tlog.App.Error().Err(err).Msg("Failed to insert user info into database")
|
||||||
controller.authorizeError(c, err, "Failed to store user info", "Failed to store user info", req.RedirectURI, "server_error", req.State)
|
controller.authorizeError(c, err, "Failed to store user info", "Failed to store user info", req.RedirectURI, "server_error", req.State)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -208,10 +197,10 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (controller *OIDCController) Token(c *gin.Context) {
|
func (controller *OIDCController) Token(c *gin.Context) {
|
||||||
if controller.oidc == nil {
|
if !controller.oidc.IsConfigured() {
|
||||||
controller.log.App.Warn().Msg("Received OIDC request but OIDC server is not configured")
|
tlog.App.Warn().Msg("OIDC not configured")
|
||||||
c.JSON(500, gin.H{
|
c.JSON(404, gin.H{
|
||||||
"error": "server_error",
|
"error": "not_found",
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -220,7 +209,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
|
|||||||
|
|
||||||
err := c.Bind(&req)
|
err := c.Bind(&req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Warn().Err(err).Msg("Failed to bind token request")
|
tlog.App.Error().Err(err).Msg("Failed to bind token request")
|
||||||
c.JSON(400, gin.H{
|
c.JSON(400, gin.H{
|
||||||
"error": "invalid_request",
|
"error": "invalid_request",
|
||||||
})
|
})
|
||||||
@@ -229,7 +218,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
|
|||||||
|
|
||||||
err = controller.oidc.ValidateGrantType(req.GrantType)
|
err = controller.oidc.ValidateGrantType(req.GrantType)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Warn().Err(err).Msg("Invalid grant type")
|
tlog.App.Warn().Str("grant_type", req.GrantType).Msg("Unsupported grant type")
|
||||||
c.JSON(400, gin.H{
|
c.JSON(400, gin.H{
|
||||||
"error": err.Error(),
|
"error": err.Error(),
|
||||||
})
|
})
|
||||||
@@ -244,12 +233,12 @@ func (controller *OIDCController) Token(c *gin.Context) {
|
|||||||
|
|
||||||
// If it fails, we try basic auth
|
// If it fails, we try basic auth
|
||||||
if creds.ClientID == "" || creds.ClientSecret == "" {
|
if creds.ClientID == "" || creds.ClientSecret == "" {
|
||||||
controller.log.App.Debug().Msg("Client credentials not found in form, trying basic auth")
|
tlog.App.Debug().Msg("Tried form values and they are empty, trying basic auth")
|
||||||
|
|
||||||
clientId, clientSecret, ok := c.Request.BasicAuth()
|
clientId, clientSecret, ok := c.Request.BasicAuth()
|
||||||
|
|
||||||
if !ok {
|
if !ok {
|
||||||
controller.log.App.Warn().Msg("Client credentials not found in basic auth")
|
tlog.App.Error().Msg("Missing authorization header")
|
||||||
c.Header("www-authenticate", `Basic realm="Tinyauth OIDC Token Endpoint"`)
|
c.Header("www-authenticate", `Basic realm="Tinyauth OIDC Token Endpoint"`)
|
||||||
c.JSON(400, gin.H{
|
c.JSON(400, gin.H{
|
||||||
"error": "invalid_client",
|
"error": "invalid_client",
|
||||||
@@ -266,7 +255,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
|
|||||||
client, ok := controller.oidc.GetClient(creds.ClientID)
|
client, ok := controller.oidc.GetClient(creds.ClientID)
|
||||||
|
|
||||||
if !ok {
|
if !ok {
|
||||||
controller.log.App.Warn().Str("clientId", creds.ClientID).Msg("Client not found")
|
tlog.App.Warn().Str("client_id", creds.ClientID).Msg("Client not found")
|
||||||
c.JSON(400, gin.H{
|
c.JSON(400, gin.H{
|
||||||
"error": "invalid_client",
|
"error": "invalid_client",
|
||||||
})
|
})
|
||||||
@@ -274,7 +263,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if client.ClientSecret != creds.ClientSecret {
|
if client.ClientSecret != creds.ClientSecret {
|
||||||
controller.log.App.Warn().Str("clientId", creds.ClientID).Msg("Invalid client secret")
|
tlog.App.Warn().Str("client_id", creds.ClientID).Msg("Invalid client secret")
|
||||||
c.JSON(400, gin.H{
|
c.JSON(400, gin.H{
|
||||||
"error": "invalid_client",
|
"error": "invalid_client",
|
||||||
})
|
})
|
||||||
@@ -288,30 +277,30 @@ func (controller *OIDCController) Token(c *gin.Context) {
|
|||||||
entry, err := controller.oidc.GetCodeEntry(c, controller.oidc.Hash(req.Code), client.ClientID)
|
entry, err := controller.oidc.GetCodeEntry(c, controller.oidc.Hash(req.Code), client.ClientID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err := controller.oidc.DeleteTokenByCodeHash(c, controller.oidc.Hash(req.Code)); err != nil {
|
if err := controller.oidc.DeleteTokenByCodeHash(c, controller.oidc.Hash(req.Code)); err != nil {
|
||||||
controller.log.App.Error().Err(err).Msg("Failed to delete code")
|
tlog.App.Error().Err(err).Msg("Failed to delete access token by code hash")
|
||||||
}
|
}
|
||||||
if errors.Is(err, service.ErrCodeNotFound) {
|
if errors.Is(err, service.ErrCodeNotFound) {
|
||||||
controller.log.App.Warn().Msg("Code not found")
|
tlog.App.Warn().Msg("Code not found")
|
||||||
c.JSON(400, gin.H{
|
c.JSON(400, gin.H{
|
||||||
"error": "invalid_grant",
|
"error": "invalid_grant",
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if errors.Is(err, service.ErrCodeExpired) {
|
if errors.Is(err, service.ErrCodeExpired) {
|
||||||
controller.log.App.Warn().Msg("Code expired")
|
tlog.App.Warn().Msg("Code expired")
|
||||||
c.JSON(400, gin.H{
|
c.JSON(400, gin.H{
|
||||||
"error": "invalid_grant",
|
"error": "invalid_grant",
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if errors.Is(err, service.ErrInvalidClient) {
|
if errors.Is(err, service.ErrInvalidClient) {
|
||||||
controller.log.App.Warn().Msg("Code does not belong to client")
|
tlog.App.Warn().Msg("Invalid client ID")
|
||||||
c.JSON(400, gin.H{
|
c.JSON(400, gin.H{
|
||||||
"error": "invalid_client",
|
"error": "invalid_client",
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
controller.log.App.Error().Err(err).Msg("Failed to get code entry")
|
tlog.App.Warn().Err(err).Msg("Failed to get OIDC code entry")
|
||||||
c.JSON(400, gin.H{
|
c.JSON(400, gin.H{
|
||||||
"error": "server_error",
|
"error": "server_error",
|
||||||
})
|
})
|
||||||
@@ -319,7 +308,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if entry.RedirectURI != req.RedirectURI {
|
if entry.RedirectURI != req.RedirectURI {
|
||||||
controller.log.App.Warn().Msg("Redirect URI does not match")
|
tlog.App.Warn().Str("redirect_uri", req.RedirectURI).Msg("Redirect URI mismatch")
|
||||||
c.JSON(400, gin.H{
|
c.JSON(400, gin.H{
|
||||||
"error": "invalid_grant",
|
"error": "invalid_grant",
|
||||||
})
|
})
|
||||||
@@ -329,7 +318,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
|
|||||||
ok := controller.oidc.ValidatePKCE(entry.CodeChallenge, req.CodeVerifier)
|
ok := controller.oidc.ValidatePKCE(entry.CodeChallenge, req.CodeVerifier)
|
||||||
|
|
||||||
if !ok {
|
if !ok {
|
||||||
controller.log.App.Warn().Msg("PKCE validation failed")
|
tlog.App.Warn().Msg("PKCE validation failed")
|
||||||
c.JSON(400, gin.H{
|
c.JSON(400, gin.H{
|
||||||
"error": "invalid_grant",
|
"error": "invalid_grant",
|
||||||
})
|
})
|
||||||
@@ -339,7 +328,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
|
|||||||
tokenRes, err := controller.oidc.GenerateAccessToken(c, client, entry)
|
tokenRes, err := controller.oidc.GenerateAccessToken(c, client, entry)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Error().Err(err).Msg("Failed to generate access token")
|
tlog.App.Error().Err(err).Msg("Failed to generate access token")
|
||||||
c.JSON(400, gin.H{
|
c.JSON(400, gin.H{
|
||||||
"error": "server_error",
|
"error": "server_error",
|
||||||
})
|
})
|
||||||
@@ -352,7 +341,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
|
|||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, service.ErrTokenExpired) {
|
if errors.Is(err, service.ErrTokenExpired) {
|
||||||
controller.log.App.Warn().Msg("Refresh token expired")
|
tlog.App.Error().Err(err).Msg("Refresh token expired")
|
||||||
c.JSON(400, gin.H{
|
c.JSON(400, gin.H{
|
||||||
"error": "invalid_grant",
|
"error": "invalid_grant",
|
||||||
})
|
})
|
||||||
@@ -360,14 +349,14 @@ func (controller *OIDCController) Token(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if errors.Is(err, service.ErrInvalidClient) {
|
if errors.Is(err, service.ErrInvalidClient) {
|
||||||
controller.log.App.Warn().Msg("Refresh token does not belong to client")
|
tlog.App.Error().Err(err).Msg("Invalid client")
|
||||||
c.JSON(400, gin.H{
|
c.JSON(400, gin.H{
|
||||||
"error": "invalid_grant",
|
"error": "invalid_grant",
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
controller.log.App.Error().Err(err).Msg("Failed to refresh access token")
|
tlog.App.Error().Err(err).Msg("Failed to refresh access token")
|
||||||
c.JSON(400, gin.H{
|
c.JSON(400, gin.H{
|
||||||
"error": "server_error",
|
"error": "server_error",
|
||||||
})
|
})
|
||||||
@@ -384,10 +373,10 @@ func (controller *OIDCController) Token(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (controller *OIDCController) Userinfo(c *gin.Context) {
|
func (controller *OIDCController) Userinfo(c *gin.Context) {
|
||||||
if controller.oidc == nil {
|
if !controller.oidc.IsConfigured() {
|
||||||
controller.log.App.Warn().Msg("Received OIDC userinfo request but OIDC server is not configured")
|
tlog.App.Warn().Msg("OIDC not configured")
|
||||||
c.JSON(500, gin.H{
|
c.JSON(404, gin.H{
|
||||||
"error": "server_error",
|
"error": "not_found",
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -398,7 +387,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
|
|||||||
if authorization != "" {
|
if authorization != "" {
|
||||||
tokenType, bearerToken, ok := strings.Cut(authorization, " ")
|
tokenType, bearerToken, ok := strings.Cut(authorization, " ")
|
||||||
if !ok {
|
if !ok {
|
||||||
controller.log.App.Warn().Msg("OIDC userinfo accessed with invalid authorization header")
|
tlog.App.Warn().Msg("OIDC userinfo accessed with malformed authorization header")
|
||||||
c.JSON(401, gin.H{
|
c.JSON(401, gin.H{
|
||||||
"error": "invalid_request",
|
"error": "invalid_request",
|
||||||
})
|
})
|
||||||
@@ -406,7 +395,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if strings.ToLower(tokenType) != "bearer" {
|
if strings.ToLower(tokenType) != "bearer" {
|
||||||
controller.log.App.Warn().Msg("OIDC userinfo accessed with non-bearer token")
|
tlog.App.Warn().Msg("OIDC userinfo accessed with invalid token type")
|
||||||
c.JSON(401, gin.H{
|
c.JSON(401, gin.H{
|
||||||
"error": "invalid_request",
|
"error": "invalid_request",
|
||||||
})
|
})
|
||||||
@@ -416,7 +405,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
|
|||||||
token = bearerToken
|
token = bearerToken
|
||||||
} else if c.Request.Method == http.MethodPost {
|
} else if c.Request.Method == http.MethodPost {
|
||||||
if c.ContentType() != "application/x-www-form-urlencoded" {
|
if c.ContentType() != "application/x-www-form-urlencoded" {
|
||||||
controller.log.App.Warn().Msg("OIDC userinfo POST accessed with invalid content type")
|
tlog.App.Warn().Msg("OIDC userinfo POST accessed with invalid content type")
|
||||||
c.JSON(400, gin.H{
|
c.JSON(400, gin.H{
|
||||||
"error": "invalid_request",
|
"error": "invalid_request",
|
||||||
})
|
})
|
||||||
@@ -424,14 +413,14 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
token = c.PostForm("access_token")
|
token = c.PostForm("access_token")
|
||||||
if token == "" {
|
if token == "" {
|
||||||
controller.log.App.Warn().Msg("OIDC userinfo POST accessed without access_token")
|
tlog.App.Warn().Msg("OIDC userinfo POST accessed without access_token in body")
|
||||||
c.JSON(401, gin.H{
|
c.JSON(401, gin.H{
|
||||||
"error": "invalid_request",
|
"error": "invalid_request",
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
controller.log.App.Warn().Msg("OIDC userinfo accessed without authorization header or POST body")
|
tlog.App.Warn().Msg("OIDC userinfo accessed without authorization header")
|
||||||
c.JSON(401, gin.H{
|
c.JSON(401, gin.H{
|
||||||
"error": "invalid_request",
|
"error": "invalid_request",
|
||||||
})
|
})
|
||||||
@@ -442,14 +431,14 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
|
|||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, service.ErrTokenNotFound) {
|
if errors.Is(err, service.ErrTokenNotFound) {
|
||||||
controller.log.App.Warn().Msg("OIDC userinfo accessed with invalid token")
|
tlog.App.Warn().Msg("OIDC userinfo accessed with invalid token")
|
||||||
c.JSON(401, gin.H{
|
c.JSON(401, gin.H{
|
||||||
"error": "invalid_grant",
|
"error": "invalid_grant",
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
controller.log.App.Error().Err(err).Msg("Failed to get access token")
|
tlog.App.Err(err).Msg("Failed to get token entry")
|
||||||
c.JSON(401, gin.H{
|
c.JSON(401, gin.H{
|
||||||
"error": "server_error",
|
"error": "server_error",
|
||||||
})
|
})
|
||||||
@@ -458,7 +447,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
|
|||||||
|
|
||||||
// If we don't have the openid scope, return an error
|
// If we don't have the openid scope, return an error
|
||||||
if !slices.Contains(strings.Split(entry.Scope, ","), "openid") {
|
if !slices.Contains(strings.Split(entry.Scope, ","), "openid") {
|
||||||
controller.log.App.Warn().Msg("OIDC userinfo accessed with token missing openid scope")
|
tlog.App.Warn().Msg("OIDC userinfo accessed without openid scope")
|
||||||
c.JSON(401, gin.H{
|
c.JSON(401, gin.H{
|
||||||
"error": "invalid_scope",
|
"error": "invalid_scope",
|
||||||
})
|
})
|
||||||
@@ -468,7 +457,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
|
|||||||
user, err := controller.oidc.GetUserinfo(c, entry.Sub)
|
user, err := controller.oidc.GetUserinfo(c, entry.Sub)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Error().Err(err).Msg("Failed to get user info")
|
tlog.App.Err(err).Msg("Failed to get user entry")
|
||||||
c.JSON(401, gin.H{
|
c.JSON(401, gin.H{
|
||||||
"error": "server_error",
|
"error": "server_error",
|
||||||
})
|
})
|
||||||
@@ -479,7 +468,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (controller *OIDCController) authorizeError(c *gin.Context, err error, reason string, reasonUser string, callback string, callbackError string, state string) {
|
func (controller *OIDCController) authorizeError(c *gin.Context, err error, reason string, reasonUser string, callback string, callbackError string, state string) {
|
||||||
controller.log.App.Warn().Err(err).Str("reason", reason).Msg("Authorization error")
|
tlog.App.Error().Err(err).Msg(reason)
|
||||||
|
|
||||||
if callback != "" {
|
if callback != "" {
|
||||||
errorQueries := CallbackError{
|
errorQueries := CallbackError{
|
||||||
@@ -519,16 +508,8 @@ func (controller *OIDCController) authorizeError(c *gin.Context, err error, reas
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
redirectUrl := ""
|
|
||||||
|
|
||||||
if controller.oidc != nil {
|
|
||||||
redirectUrl = fmt.Sprintf("%s/error?%s", controller.oidc.GetIssuer(), queries.Encode())
|
|
||||||
} else {
|
|
||||||
redirectUrl = fmt.Sprintf("%s/error?%s", controller.runtime.AppURL, queries.Encode())
|
|
||||||
}
|
|
||||||
|
|
||||||
c.JSON(200, gin.H{
|
c.JSON(200, gin.H{
|
||||||
"status": 200,
|
"status": 200,
|
||||||
"redirect_uri": redirectUrl,
|
"redirect_uri": fmt.Sprintf("%s/error?%s", controller.oidc.GetIssuer(), queries.Encode()),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,14 +1,13 @@
|
|||||||
package controller_test
|
package controller_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"path"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@@ -20,15 +19,29 @@ import (
|
|||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/test"
|
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestOIDCController(t *testing.T) {
|
func TestOIDCController(t *testing.T) {
|
||||||
log := logger.NewLogger().WithTestConfig()
|
tlog.NewTestLogger().Init()
|
||||||
log.Init()
|
tempDir := t.TempDir()
|
||||||
|
|
||||||
cfg, runtime := test.CreateTestConfigs(t)
|
oidcServiceCfg := service.OIDCServiceConfig{
|
||||||
|
Clients: map[string]model.OIDCClientConfig{
|
||||||
|
"test": {
|
||||||
|
ClientID: "some-client-id",
|
||||||
|
ClientSecret: "some-client-secret",
|
||||||
|
TrustedRedirectURIs: []string{"https://test.example.com/callback"},
|
||||||
|
Name: "Test Client",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
PrivateKeyPath: path.Join(tempDir, "key.pem"),
|
||||||
|
PublicKeyPath: path.Join(tempDir, "key.pub"),
|
||||||
|
Issuer: "https://tinyauth.example.com",
|
||||||
|
SessionExpiry: 500,
|
||||||
|
}
|
||||||
|
|
||||||
|
controllerCfg := controller.OIDCControllerConfig{}
|
||||||
|
|
||||||
simpleCtx := func(c *gin.Context) {
|
simpleCtx := func(c *gin.Context) {
|
||||||
c.Set("context", &model.UserContext{
|
c.Set("context", &model.UserContext{
|
||||||
@@ -90,7 +103,7 @@ func TestOIDCController(t *testing.T) {
|
|||||||
|
|
||||||
var res map[string]any
|
var res map[string]any
|
||||||
err := json.Unmarshal(recorder.Body.Bytes(), &res)
|
err := json.Unmarshal(recorder.Body.Bytes(), &res)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
assert.Equal(t, res["redirect_uri"], "https://tinyauth.example.com/error?error=User+is+not+logged+in+or+the+session+is+invalid")
|
assert.Equal(t, res["redirect_uri"], "https://tinyauth.example.com/error?error=User+is+not+logged+in+or+the+session+is+invalid")
|
||||||
},
|
},
|
||||||
@@ -110,7 +123,7 @@ func TestOIDCController(t *testing.T) {
|
|||||||
Nonce: "some-nonce",
|
Nonce: "some-nonce",
|
||||||
}
|
}
|
||||||
reqBodyBytes, err := json.Marshal(reqBody)
|
reqBodyBytes, err := json.Marshal(reqBody)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes)))
|
req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes)))
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
@@ -118,7 +131,7 @@ func TestOIDCController(t *testing.T) {
|
|||||||
|
|
||||||
var res map[string]any
|
var res map[string]any
|
||||||
err = json.Unmarshal(recorder.Body.Bytes(), &res)
|
err = json.Unmarshal(recorder.Body.Bytes(), &res)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
assert.Equal(t, res["redirect_uri"], "https://test.example.com/callback?error=unsupported_response_type&error_description=Invalid+request+parameters&state=some-state")
|
assert.Equal(t, res["redirect_uri"], "https://test.example.com/callback?error=unsupported_response_type&error_description=Invalid+request+parameters&state=some-state")
|
||||||
},
|
},
|
||||||
@@ -138,7 +151,7 @@ func TestOIDCController(t *testing.T) {
|
|||||||
Nonce: "some-nonce",
|
Nonce: "some-nonce",
|
||||||
}
|
}
|
||||||
reqBodyBytes, err := json.Marshal(reqBody)
|
reqBodyBytes, err := json.Marshal(reqBody)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes)))
|
req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes)))
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
@@ -147,11 +160,11 @@ func TestOIDCController(t *testing.T) {
|
|||||||
|
|
||||||
var res map[string]any
|
var res map[string]any
|
||||||
err = json.Unmarshal(recorder.Body.Bytes(), &res)
|
err = json.Unmarshal(recorder.Body.Bytes(), &res)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
redirectURI := res["redirect_uri"].(string)
|
redirectURI := res["redirect_uri"].(string)
|
||||||
url, err := url.Parse(redirectURI)
|
url, err := url.Parse(redirectURI)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
queryParams := url.Query()
|
queryParams := url.Query()
|
||||||
assert.Equal(t, queryParams.Get("state"), "some-state")
|
assert.Equal(t, queryParams.Get("state"), "some-state")
|
||||||
@@ -170,7 +183,7 @@ func TestOIDCController(t *testing.T) {
|
|||||||
RedirectURI: "https://test.example.com/callback",
|
RedirectURI: "https://test.example.com/callback",
|
||||||
}
|
}
|
||||||
reqBodyEncoded, err := query.Values(reqBody)
|
reqBodyEncoded, err := query.Values(reqBody)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
|
req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
|
||||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
@@ -178,7 +191,7 @@ func TestOIDCController(t *testing.T) {
|
|||||||
|
|
||||||
var res map[string]any
|
var res map[string]any
|
||||||
err = json.Unmarshal(recorder.Body.Bytes(), &res)
|
err = json.Unmarshal(recorder.Body.Bytes(), &res)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
assert.Equal(t, res["error"], "unsupported_grant_type")
|
assert.Equal(t, res["error"], "unsupported_grant_type")
|
||||||
},
|
},
|
||||||
@@ -193,7 +206,7 @@ func TestOIDCController(t *testing.T) {
|
|||||||
RedirectURI: "https://test.example.com/callback",
|
RedirectURI: "https://test.example.com/callback",
|
||||||
}
|
}
|
||||||
reqBodyEncoded, err := query.Values(reqBody)
|
reqBodyEncoded, err := query.Values(reqBody)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
|
req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
|
||||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
@@ -231,7 +244,7 @@ func TestOIDCController(t *testing.T) {
|
|||||||
RedirectURI: "https://test.example.com/callback",
|
RedirectURI: "https://test.example.com/callback",
|
||||||
}
|
}
|
||||||
reqBodyEncoded, err := query.Values(reqBody)
|
reqBodyEncoded, err := query.Values(reqBody)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
|
req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
|
||||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
@@ -254,11 +267,11 @@ func TestOIDCController(t *testing.T) {
|
|||||||
|
|
||||||
var authorizeRes map[string]any
|
var authorizeRes map[string]any
|
||||||
err := json.Unmarshal(authorizeTestRecorder.Body.Bytes(), &authorizeRes)
|
err := json.Unmarshal(authorizeTestRecorder.Body.Bytes(), &authorizeRes)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
redirectURI := authorizeRes["redirect_uri"].(string)
|
redirectURI := authorizeRes["redirect_uri"].(string)
|
||||||
url, err := url.Parse(redirectURI)
|
url, err := url.Parse(redirectURI)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
queryParams := url.Query()
|
queryParams := url.Query()
|
||||||
code := queryParams.Get("code")
|
code := queryParams.Get("code")
|
||||||
@@ -270,7 +283,7 @@ func TestOIDCController(t *testing.T) {
|
|||||||
RedirectURI: "https://test.example.com/callback",
|
RedirectURI: "https://test.example.com/callback",
|
||||||
}
|
}
|
||||||
reqBodyEncoded, err := query.Values(reqBody)
|
reqBodyEncoded, err := query.Values(reqBody)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
|
req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
|
||||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
@@ -293,7 +306,7 @@ func TestOIDCController(t *testing.T) {
|
|||||||
|
|
||||||
var tokenRes map[string]any
|
var tokenRes map[string]any
|
||||||
err := json.Unmarshal(tokenRecorder.Body.Bytes(), &tokenRes)
|
err := json.Unmarshal(tokenRecorder.Body.Bytes(), &tokenRes)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
_, ok := tokenRes["refresh_token"]
|
_, ok := tokenRes["refresh_token"]
|
||||||
assert.True(t, ok, "Expected refresh token in response")
|
assert.True(t, ok, "Expected refresh token in response")
|
||||||
@@ -307,7 +320,7 @@ func TestOIDCController(t *testing.T) {
|
|||||||
ClientSecret: "some-client-secret",
|
ClientSecret: "some-client-secret",
|
||||||
}
|
}
|
||||||
reqBodyEncoded, err := query.Values(reqBody)
|
reqBodyEncoded, err := query.Values(reqBody)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
|
req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
|
||||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
@@ -319,7 +332,7 @@ func TestOIDCController(t *testing.T) {
|
|||||||
assert.Equal(t, 200, recorder.Code)
|
assert.Equal(t, 200, recorder.Code)
|
||||||
var refreshRes map[string]any
|
var refreshRes map[string]any
|
||||||
err = json.Unmarshal(recorder.Body.Bytes(), &refreshRes)
|
err = json.Unmarshal(recorder.Body.Bytes(), &refreshRes)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
_, ok = refreshRes["access_token"]
|
_, ok = refreshRes["access_token"]
|
||||||
assert.True(t, ok, "Expected access token in refresh response")
|
assert.True(t, ok, "Expected access token in refresh response")
|
||||||
@@ -340,11 +353,11 @@ func TestOIDCController(t *testing.T) {
|
|||||||
|
|
||||||
var authorizeRes map[string]any
|
var authorizeRes map[string]any
|
||||||
err := json.Unmarshal(authorizeTestRecorder.Body.Bytes(), &authorizeRes)
|
err := json.Unmarshal(authorizeTestRecorder.Body.Bytes(), &authorizeRes)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
redirectURI := authorizeRes["redirect_uri"].(string)
|
redirectURI := authorizeRes["redirect_uri"].(string)
|
||||||
url, err := url.Parse(redirectURI)
|
url, err := url.Parse(redirectURI)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
queryParams := url.Query()
|
queryParams := url.Query()
|
||||||
code := queryParams.Get("code")
|
code := queryParams.Get("code")
|
||||||
@@ -356,7 +369,7 @@ func TestOIDCController(t *testing.T) {
|
|||||||
RedirectURI: "https://test.example.com/callback",
|
RedirectURI: "https://test.example.com/callback",
|
||||||
}
|
}
|
||||||
reqBodyEncoded, err := query.Values(reqBody)
|
reqBodyEncoded, err := query.Values(reqBody)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
|
req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
|
||||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
@@ -376,7 +389,7 @@ func TestOIDCController(t *testing.T) {
|
|||||||
|
|
||||||
var secondRes map[string]any
|
var secondRes map[string]any
|
||||||
err = json.Unmarshal(secondRecorder.Body.Bytes(), &secondRes)
|
err = json.Unmarshal(secondRecorder.Body.Bytes(), &secondRes)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
assert.Equal(t, "invalid_grant", secondRes["error"])
|
assert.Equal(t, "invalid_grant", secondRes["error"])
|
||||||
},
|
},
|
||||||
@@ -404,7 +417,7 @@ func TestOIDCController(t *testing.T) {
|
|||||||
|
|
||||||
var tokenRes map[string]any
|
var tokenRes map[string]any
|
||||||
err := json.Unmarshal(tokenRecorder.Body.Bytes(), &tokenRes)
|
err := json.Unmarshal(tokenRecorder.Body.Bytes(), &tokenRes)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
accessToken := tokenRes["access_token"].(string)
|
accessToken := tokenRes["access_token"].(string)
|
||||||
assert.NotEmpty(t, accessToken)
|
assert.NotEmpty(t, accessToken)
|
||||||
@@ -416,7 +429,7 @@ func TestOIDCController(t *testing.T) {
|
|||||||
|
|
||||||
var userInfoRes map[string]any
|
var userInfoRes map[string]any
|
||||||
err = json.Unmarshal(recorder.Body.Bytes(), &userInfoRes)
|
err = json.Unmarshal(recorder.Body.Bytes(), &userInfoRes)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
_, ok := userInfoRes["sub"]
|
_, ok := userInfoRes["sub"]
|
||||||
assert.True(t, ok, "Expected sub claim in userinfo response")
|
assert.True(t, ok, "Expected sub claim in userinfo response")
|
||||||
@@ -436,7 +449,7 @@ func TestOIDCController(t *testing.T) {
|
|||||||
|
|
||||||
var res map[string]any
|
var res map[string]any
|
||||||
err := json.Unmarshal(recorder.Body.Bytes(), &res)
|
err := json.Unmarshal(recorder.Body.Bytes(), &res)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, "invalid_request", res["error"])
|
assert.Equal(t, "invalid_request", res["error"])
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -451,7 +464,7 @@ func TestOIDCController(t *testing.T) {
|
|||||||
|
|
||||||
var res map[string]any
|
var res map[string]any
|
||||||
err := json.Unmarshal(recorder.Body.Bytes(), &res)
|
err := json.Unmarshal(recorder.Body.Bytes(), &res)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, "invalid_request", res["error"])
|
assert.Equal(t, "invalid_request", res["error"])
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -466,7 +479,7 @@ func TestOIDCController(t *testing.T) {
|
|||||||
|
|
||||||
var res map[string]any
|
var res map[string]any
|
||||||
err := json.Unmarshal(recorder.Body.Bytes(), &res)
|
err := json.Unmarshal(recorder.Body.Bytes(), &res)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, "invalid_request", res["error"])
|
assert.Equal(t, "invalid_request", res["error"])
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -481,7 +494,7 @@ func TestOIDCController(t *testing.T) {
|
|||||||
|
|
||||||
var res map[string]any
|
var res map[string]any
|
||||||
err := json.Unmarshal(recorder.Body.Bytes(), &res)
|
err := json.Unmarshal(recorder.Body.Bytes(), &res)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, "invalid_grant", res["error"])
|
assert.Equal(t, "invalid_grant", res["error"])
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -496,7 +509,7 @@ func TestOIDCController(t *testing.T) {
|
|||||||
|
|
||||||
var res map[string]any
|
var res map[string]any
|
||||||
err := json.Unmarshal(recorder.Body.Bytes(), &res)
|
err := json.Unmarshal(recorder.Body.Bytes(), &res)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, "invalid_request", res["error"])
|
assert.Equal(t, "invalid_request", res["error"])
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -511,7 +524,7 @@ func TestOIDCController(t *testing.T) {
|
|||||||
|
|
||||||
var res map[string]any
|
var res map[string]any
|
||||||
err := json.Unmarshal(recorder.Body.Bytes(), &res)
|
err := json.Unmarshal(recorder.Body.Bytes(), &res)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, "invalid_request", res["error"])
|
assert.Equal(t, "invalid_request", res["error"])
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -528,7 +541,7 @@ func TestOIDCController(t *testing.T) {
|
|||||||
|
|
||||||
var tokenRes map[string]any
|
var tokenRes map[string]any
|
||||||
err := json.Unmarshal(tokenRecorder.Body.Bytes(), &tokenRes)
|
err := json.Unmarshal(tokenRecorder.Body.Bytes(), &tokenRes)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
accessToken := tokenRes["access_token"].(string)
|
accessToken := tokenRes["access_token"].(string)
|
||||||
assert.NotEmpty(t, accessToken)
|
assert.NotEmpty(t, accessToken)
|
||||||
@@ -542,7 +555,7 @@ func TestOIDCController(t *testing.T) {
|
|||||||
|
|
||||||
var userInfoRes map[string]any
|
var userInfoRes map[string]any
|
||||||
err = json.Unmarshal(recorder.Body.Bytes(), &userInfoRes)
|
err = json.Unmarshal(recorder.Body.Bytes(), &userInfoRes)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
_, ok := userInfoRes["sub"]
|
_, ok := userInfoRes["sub"]
|
||||||
assert.True(t, ok, "Expected sub claim in userinfo response")
|
assert.True(t, ok, "Expected sub claim in userinfo response")
|
||||||
@@ -566,7 +579,7 @@ func TestOIDCController(t *testing.T) {
|
|||||||
CodeChallengeMethod: "",
|
CodeChallengeMethod: "",
|
||||||
}
|
}
|
||||||
reqBodyBytes, err := json.Marshal(reqBody)
|
reqBodyBytes, err := json.Marshal(reqBody)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes)))
|
req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes)))
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
@@ -575,11 +588,11 @@ func TestOIDCController(t *testing.T) {
|
|||||||
|
|
||||||
var res map[string]any
|
var res map[string]any
|
||||||
err = json.Unmarshal(recorder.Body.Bytes(), &res)
|
err = json.Unmarshal(recorder.Body.Bytes(), &res)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
redirectURI := res["redirect_uri"].(string)
|
redirectURI := res["redirect_uri"].(string)
|
||||||
url, err := url.Parse(redirectURI)
|
url, err := url.Parse(redirectURI)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
queryParams := url.Query()
|
queryParams := url.Query()
|
||||||
assert.Equal(t, queryParams.Get("state"), "some-state")
|
assert.Equal(t, queryParams.Get("state"), "some-state")
|
||||||
@@ -596,7 +609,7 @@ func TestOIDCController(t *testing.T) {
|
|||||||
CodeVerifier: "some-challenge",
|
CodeVerifier: "some-challenge",
|
||||||
}
|
}
|
||||||
reqBodyEncoded, err := query.Values(tokenReqBody)
|
reqBodyEncoded, err := query.Values(tokenReqBody)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
req = httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
|
req = httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
|
||||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
@@ -627,7 +640,7 @@ func TestOIDCController(t *testing.T) {
|
|||||||
CodeChallengeMethod: "S256",
|
CodeChallengeMethod: "S256",
|
||||||
}
|
}
|
||||||
reqBodyBytes, err := json.Marshal(reqBody)
|
reqBodyBytes, err := json.Marshal(reqBody)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes)))
|
req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes)))
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
@@ -636,11 +649,11 @@ func TestOIDCController(t *testing.T) {
|
|||||||
|
|
||||||
var res map[string]any
|
var res map[string]any
|
||||||
err = json.Unmarshal(recorder.Body.Bytes(), &res)
|
err = json.Unmarshal(recorder.Body.Bytes(), &res)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
redirectURI := res["redirect_uri"].(string)
|
redirectURI := res["redirect_uri"].(string)
|
||||||
url, err := url.Parse(redirectURI)
|
url, err := url.Parse(redirectURI)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
queryParams := url.Query()
|
queryParams := url.Query()
|
||||||
assert.Equal(t, queryParams.Get("state"), "some-state")
|
assert.Equal(t, queryParams.Get("state"), "some-state")
|
||||||
@@ -657,7 +670,7 @@ func TestOIDCController(t *testing.T) {
|
|||||||
CodeVerifier: "some-challenge",
|
CodeVerifier: "some-challenge",
|
||||||
}
|
}
|
||||||
reqBodyEncoded, err := query.Values(tokenReqBody)
|
reqBodyEncoded, err := query.Values(tokenReqBody)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
req = httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
|
req = httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
|
||||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
@@ -688,7 +701,7 @@ func TestOIDCController(t *testing.T) {
|
|||||||
CodeChallengeMethod: "S256",
|
CodeChallengeMethod: "S256",
|
||||||
}
|
}
|
||||||
reqBodyBytes, err := json.Marshal(reqBody)
|
reqBodyBytes, err := json.Marshal(reqBody)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes)))
|
req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes)))
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
@@ -697,11 +710,11 @@ func TestOIDCController(t *testing.T) {
|
|||||||
|
|
||||||
var res map[string]any
|
var res map[string]any
|
||||||
err = json.Unmarshal(recorder.Body.Bytes(), &res)
|
err = json.Unmarshal(recorder.Body.Bytes(), &res)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
redirectURI := res["redirect_uri"].(string)
|
redirectURI := res["redirect_uri"].(string)
|
||||||
url, err := url.Parse(redirectURI)
|
url, err := url.Parse(redirectURI)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
queryParams := url.Query()
|
queryParams := url.Query()
|
||||||
assert.Equal(t, queryParams.Get("state"), "some-state")
|
assert.Equal(t, queryParams.Get("state"), "some-state")
|
||||||
@@ -718,7 +731,7 @@ func TestOIDCController(t *testing.T) {
|
|||||||
CodeVerifier: "some-challenge-1",
|
CodeVerifier: "some-challenge-1",
|
||||||
}
|
}
|
||||||
reqBodyEncoded, err := query.Values(tokenReqBody)
|
reqBodyEncoded, err := query.Values(tokenReqBody)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
req = httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
|
req = httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
|
||||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
@@ -749,7 +762,7 @@ func TestOIDCController(t *testing.T) {
|
|||||||
CodeChallengeMethod: "foo",
|
CodeChallengeMethod: "foo",
|
||||||
}
|
}
|
||||||
reqBodyBytes, err := json.Marshal(reqBody)
|
reqBodyBytes, err := json.Marshal(reqBody)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes)))
|
req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes)))
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
@@ -758,11 +771,11 @@ func TestOIDCController(t *testing.T) {
|
|||||||
|
|
||||||
var res map[string]any
|
var res map[string]any
|
||||||
err = json.Unmarshal(recorder.Body.Bytes(), &res)
|
err = json.Unmarshal(recorder.Body.Bytes(), &res)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
redirectURI := res["redirect_uri"].(string)
|
redirectURI := res["redirect_uri"].(string)
|
||||||
url, err := url.Parse(redirectURI)
|
url, err := url.Parse(redirectURI)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
queryParams := url.Query()
|
queryParams := url.Query()
|
||||||
error := queryParams.Get("error")
|
error := queryParams.Get("error")
|
||||||
@@ -781,11 +794,11 @@ func TestOIDCController(t *testing.T) {
|
|||||||
|
|
||||||
var res map[string]any
|
var res map[string]any
|
||||||
err := json.Unmarshal(recorder.Body.Bytes(), &res)
|
err := json.Unmarshal(recorder.Body.Bytes(), &res)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
redirectURI := res["redirect_uri"].(string)
|
redirectURI := res["redirect_uri"].(string)
|
||||||
url, err := url.Parse(redirectURI)
|
url, err := url.Parse(redirectURI)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
queryParams := url.Query()
|
queryParams := url.Query()
|
||||||
code := queryParams.Get("code")
|
code := queryParams.Get("code")
|
||||||
@@ -797,7 +810,7 @@ func TestOIDCController(t *testing.T) {
|
|||||||
RedirectURI: "https://test.example.com/callback",
|
RedirectURI: "https://test.example.com/callback",
|
||||||
}
|
}
|
||||||
reqBodyEncoded, err := query.Values(reqBody)
|
reqBodyEncoded, err := query.Values(reqBody)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
|
req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
|
||||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
@@ -808,7 +821,7 @@ func TestOIDCController(t *testing.T) {
|
|||||||
assert.Equal(t, 200, recorder.Code)
|
assert.Equal(t, 200, recorder.Code)
|
||||||
|
|
||||||
err = json.Unmarshal(recorder.Body.Bytes(), &res)
|
err = json.Unmarshal(recorder.Body.Bytes(), &res)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
accessToken := res["access_token"].(string)
|
accessToken := res["access_token"].(string)
|
||||||
assert.NotEmpty(t, accessToken)
|
assert.NotEmpty(t, accessToken)
|
||||||
@@ -833,22 +846,20 @@ func TestOIDCController(t *testing.T) {
|
|||||||
assert.Equal(t, 401, recorder.Code)
|
assert.Equal(t, 401, recorder.Code)
|
||||||
|
|
||||||
err = json.Unmarshal(recorder.Body.Bytes(), &res)
|
err = json.Unmarshal(recorder.Body.Bytes(), &res)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, "invalid_grant", res["error"])
|
assert.Equal(t, "invalid_grant", res["error"])
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
app := bootstrap.NewBootstrapApp(cfg)
|
app := bootstrap.NewBootstrapApp(model.Config{})
|
||||||
|
|
||||||
err := app.SetupDatabase()
|
db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db"))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
queries := repository.New(app.GetDB())
|
queries := repository.New(db)
|
||||||
|
oidcService := service.NewOIDCService(oidcServiceCfg, queries)
|
||||||
wg := &sync.WaitGroup{}
|
err = oidcService.Init()
|
||||||
|
|
||||||
oidcService, err := service.NewOIDCService(log, cfg, runtime, queries, context.TODO(), wg)
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
@@ -862,7 +873,8 @@ func TestOIDCController(t *testing.T) {
|
|||||||
group := router.Group("/api")
|
group := router.Group("/api")
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
controller.NewOIDCController(log, oidcService, runtime, group)
|
oidcController := controller.NewOIDCController(controllerCfg, oidcService, group)
|
||||||
|
oidcController.SetupRoutes()
|
||||||
|
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
|
|
||||||
@@ -871,6 +883,7 @@ func TestOIDCController(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
app.GetDB().Close()
|
err = db.Close()
|
||||||
|
require.NoError(t, err)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ import (
|
|||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/google/go-querystring/query"
|
"github.com/google/go-querystring/query"
|
||||||
@@ -50,31 +50,29 @@ type ProxyContext struct {
|
|||||||
ProxyType ProxyType
|
ProxyType ProxyType
|
||||||
}
|
}
|
||||||
|
|
||||||
type ProxyController struct {
|
type ProxyControllerConfig struct {
|
||||||
log *logger.Logger
|
AppURL string
|
||||||
runtime model.RuntimeConfig
|
|
||||||
acls *service.AccessControlsService
|
|
||||||
auth *service.AuthService
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewProxyController(
|
type ProxyController struct {
|
||||||
log *logger.Logger,
|
config ProxyControllerConfig
|
||||||
runtime model.RuntimeConfig,
|
router *gin.RouterGroup
|
||||||
router *gin.RouterGroup,
|
acls *service.AccessControlsService
|
||||||
acls *service.AccessControlsService,
|
auth *service.AuthService
|
||||||
auth *service.AuthService,
|
}
|
||||||
) *ProxyController {
|
|
||||||
controller := &ProxyController{
|
func NewProxyController(config ProxyControllerConfig, router *gin.RouterGroup, acls *service.AccessControlsService, auth *service.AuthService) *ProxyController {
|
||||||
log: log,
|
return &ProxyController{
|
||||||
runtime: runtime,
|
config: config,
|
||||||
acls: acls,
|
router: router,
|
||||||
auth: auth,
|
acls: acls,
|
||||||
|
auth: auth,
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
proxyGroup := router.Group("/auth")
|
func (controller *ProxyController) SetupRoutes() {
|
||||||
|
proxyGroup := controller.router.Group("/auth")
|
||||||
proxyGroup.Any("/:proxy", controller.proxyHandler)
|
proxyGroup.Any("/:proxy", controller.proxyHandler)
|
||||||
|
|
||||||
return controller
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
||||||
@@ -82,7 +80,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
|||||||
proxyCtx, err := controller.getProxyContext(c)
|
proxyCtx, err := controller.getProxyContext(c)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Error().Err(err).Msg("Failed to get proxy context from request")
|
tlog.App.Warn().Err(err).Msg("Failed to get proxy context")
|
||||||
c.JSON(400, gin.H{
|
c.JSON(400, gin.H{
|
||||||
"status": 400,
|
"status": 400,
|
||||||
"message": "Bad request",
|
"message": "Bad request",
|
||||||
@@ -90,15 +88,19 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
tlog.App.Trace().Interface("ctx", proxyCtx).Msg("Got proxy context")
|
||||||
|
|
||||||
// Get acls
|
// Get acls
|
||||||
acls, err := controller.acls.GetAccessControls(proxyCtx.Host)
|
acls, err := controller.acls.GetAccessControls(proxyCtx.Host)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Error().Err(err).Msg("Failed to get ACLs for resource")
|
tlog.App.Error().Err(err).Msg("Failed to get access controls for resource")
|
||||||
controller.handleError(c, proxyCtx)
|
controller.handleError(c, proxyCtx)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
tlog.App.Trace().Interface("acls", acls).Msg("ACLs for resource")
|
||||||
|
|
||||||
clientIP := c.ClientIP()
|
clientIP := c.ClientIP()
|
||||||
|
|
||||||
if controller.auth.IsBypassedIP(clientIP, acls) {
|
if controller.auth.IsBypassedIP(clientIP, acls) {
|
||||||
@@ -113,13 +115,13 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
|||||||
authEnabled, err := controller.auth.IsAuthEnabled(proxyCtx.Path, acls)
|
authEnabled, err := controller.auth.IsAuthEnabled(proxyCtx.Path, acls)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Error().Err(err).Msg("Failed to determine if authentication is enabled for resource")
|
tlog.App.Error().Err(err).Msg("Failed to check if auth is enabled for resource")
|
||||||
controller.handleError(c, proxyCtx)
|
controller.handleError(c, proxyCtx)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if !authEnabled {
|
if !authEnabled {
|
||||||
controller.log.App.Debug().Msg("Authentication is disabled for this resource, allowing access without authentication")
|
tlog.App.Debug().Msg("Authentication disabled for resource, allowing access")
|
||||||
controller.setHeaders(c, acls)
|
controller.setHeaders(c, acls)
|
||||||
c.JSON(200, gin.H{
|
c.JSON(200, gin.H{
|
||||||
"status": 200,
|
"status": 200,
|
||||||
@@ -135,12 +137,12 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Error().Err(err).Msg("Failed to encode unauthorized query")
|
tlog.App.Error().Err(err).Msg("Failed to encode unauthorized query")
|
||||||
controller.handleError(c, proxyCtx)
|
controller.handleError(c, proxyCtx)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
redirectURL := fmt.Sprintf("%s/unauthorized?%s", controller.runtime.AppURL, queries.Encode())
|
redirectURL := fmt.Sprintf("%s/unauthorized?%s", controller.config.AppURL, queries.Encode())
|
||||||
|
|
||||||
if !controller.useBrowserResponse(proxyCtx) {
|
if !controller.useBrowserResponse(proxyCtx) {
|
||||||
c.Header("x-tinyauth-location", redirectURL)
|
c.Header("x-tinyauth-location", redirectURL)
|
||||||
@@ -158,24 +160,26 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
|||||||
userContext, err := new(model.UserContext).NewFromGin(c)
|
userContext, err := new(model.UserContext).NewFromGin(c)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Debug().Err(err).Msg("Failed to create user context from request, treating as unauthenticated")
|
tlog.App.Debug().Err(err).Msg("No user context found in request, treating as unauthenticated")
|
||||||
userContext = &model.UserContext{
|
userContext = &model.UserContext{
|
||||||
Authenticated: false,
|
Authenticated: false,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
tlog.App.Trace().Interface("context", userContext).Msg("User context from request")
|
||||||
|
|
||||||
if userContext.Authenticated {
|
if userContext.Authenticated {
|
||||||
userAllowed := controller.auth.IsUserAllowed(c, *userContext, acls)
|
userAllowed := controller.auth.IsUserAllowed(c, *userContext, acls)
|
||||||
|
|
||||||
if !userAllowed {
|
if !userAllowed {
|
||||||
controller.log.App.Warn().Str("user", userContext.GetUsername()).Str("resource", strings.Split(proxyCtx.Host, ".")[0]).Msg("User is not allowed to access resource")
|
tlog.App.Warn().Str("user", userContext.GetUsername()).Str("resource", strings.Split(proxyCtx.Host, ".")[0]).Msg("User not allowed to access resource")
|
||||||
|
|
||||||
queries, err := query.Values(UnauthorizedQuery{
|
queries, err := query.Values(UnauthorizedQuery{
|
||||||
Resource: strings.Split(proxyCtx.Host, ".")[0],
|
Resource: strings.Split(proxyCtx.Host, ".")[0],
|
||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Error().Err(err).Msg("Failed to encode unauthorized query")
|
tlog.App.Error().Err(err).Msg("Failed to encode unauthorized query")
|
||||||
controller.handleError(c, proxyCtx)
|
controller.handleError(c, proxyCtx)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -186,7 +190,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
|||||||
queries.Set("username", userContext.GetUsername())
|
queries.Set("username", userContext.GetUsername())
|
||||||
}
|
}
|
||||||
|
|
||||||
redirectURL := fmt.Sprintf("%s/unauthorized?%s", controller.runtime.AppURL, queries.Encode())
|
redirectURL := fmt.Sprintf("%s/unauthorized?%s", controller.config.AppURL, queries.Encode())
|
||||||
|
|
||||||
if !controller.useBrowserResponse(proxyCtx) {
|
if !controller.useBrowserResponse(proxyCtx) {
|
||||||
c.Header("x-tinyauth-location", redirectURL)
|
c.Header("x-tinyauth-location", redirectURL)
|
||||||
@@ -211,7 +215,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !groupOK {
|
if !groupOK {
|
||||||
controller.log.App.Warn().Str("user", userContext.GetUsername()).Str("resource", strings.Split(proxyCtx.Host, ".")[0]).Msg("User is not in the required group to access resource")
|
tlog.App.Warn().Str("user", userContext.GetUsername()).Str("resource", strings.Split(proxyCtx.Host, ".")[0]).Msg("User groups do not match resource requirements")
|
||||||
|
|
||||||
queries, err := query.Values(UnauthorizedQuery{
|
queries, err := query.Values(UnauthorizedQuery{
|
||||||
Resource: strings.Split(proxyCtx.Host, ".")[0],
|
Resource: strings.Split(proxyCtx.Host, ".")[0],
|
||||||
@@ -219,7 +223,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Error().Err(err).Msg("Failed to encode unauthorized query")
|
tlog.App.Error().Err(err).Msg("Failed to encode unauthorized query")
|
||||||
controller.handleError(c, proxyCtx)
|
controller.handleError(c, proxyCtx)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -230,7 +234,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
|||||||
queries.Set("username", userContext.GetUsername())
|
queries.Set("username", userContext.GetUsername())
|
||||||
}
|
}
|
||||||
|
|
||||||
redirectURL := fmt.Sprintf("%s/unauthorized?%s", controller.runtime.AppURL, queries.Encode())
|
redirectURL := fmt.Sprintf("%s/unauthorized?%s", controller.config.AppURL, queries.Encode())
|
||||||
|
|
||||||
if !controller.useBrowserResponse(proxyCtx) {
|
if !controller.useBrowserResponse(proxyCtx) {
|
||||||
c.Header("x-tinyauth-location", redirectURL)
|
c.Header("x-tinyauth-location", redirectURL)
|
||||||
@@ -273,12 +277,12 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Error().Err(err).Msg("Failed to encode redirect query")
|
tlog.App.Error().Err(err).Msg("Failed to encode redirect URI query")
|
||||||
controller.handleError(c, proxyCtx)
|
controller.handleError(c, proxyCtx)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
redirectURL := fmt.Sprintf("%s/login?%s", controller.runtime.AppURL, queries.Encode())
|
redirectURL := fmt.Sprintf("%s/login?%s", controller.config.AppURL, queries.Encode())
|
||||||
|
|
||||||
if !controller.useBrowserResponse(proxyCtx) {
|
if !controller.useBrowserResponse(proxyCtx) {
|
||||||
c.Header("x-tinyauth-location", redirectURL)
|
c.Header("x-tinyauth-location", redirectURL)
|
||||||
@@ -302,19 +306,20 @@ func (controller *ProxyController) setHeaders(c *gin.Context, acls *model.App) {
|
|||||||
headers := utils.ParseHeaders(acls.Response.Headers)
|
headers := utils.ParseHeaders(acls.Response.Headers)
|
||||||
|
|
||||||
for key, value := range headers {
|
for key, value := range headers {
|
||||||
|
tlog.App.Debug().Str("header", key).Msg("Setting header")
|
||||||
c.Header(key, value)
|
c.Header(key, value)
|
||||||
}
|
}
|
||||||
|
|
||||||
basicPassword := utils.GetSecret(acls.Response.BasicAuth.Password, acls.Response.BasicAuth.PasswordFile)
|
basicPassword := utils.GetSecret(acls.Response.BasicAuth.Password, acls.Response.BasicAuth.PasswordFile)
|
||||||
|
|
||||||
if acls.Response.BasicAuth.Username != "" && basicPassword != "" {
|
if acls.Response.BasicAuth.Username != "" && basicPassword != "" {
|
||||||
controller.log.App.Debug().Msg("Setting basic auth header for response")
|
tlog.App.Debug().Str("username", acls.Response.BasicAuth.Username).Msg("Setting basic auth header")
|
||||||
c.Header("Authorization", fmt.Sprintf("Basic %s", utils.EncodeBasicAuth(acls.Response.BasicAuth.Username, basicPassword)))
|
c.Header("Authorization", fmt.Sprintf("Basic %s", utils.EncodeBasicAuth(acls.Response.BasicAuth.Username, basicPassword)))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (controller *ProxyController) handleError(c *gin.Context, proxyCtx ProxyContext) {
|
func (controller *ProxyController) handleError(c *gin.Context, proxyCtx ProxyContext) {
|
||||||
redirectURL := fmt.Sprintf("%s/error", controller.runtime.AppURL)
|
redirectURL := fmt.Sprintf("%s/error", controller.config.AppURL)
|
||||||
|
|
||||||
if !controller.useBrowserResponse(proxyCtx) {
|
if !controller.useBrowserResponse(proxyCtx) {
|
||||||
c.Header("x-tinyauth-location", redirectURL)
|
c.Header("x-tinyauth-location", redirectURL)
|
||||||
@@ -515,7 +520,7 @@ func (controller *ProxyController) getProxyContext(c *gin.Context) (ProxyContext
|
|||||||
return ProxyContext{}, err
|
return ProxyContext{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
controller.log.App.Debug().Msgf("Determined proxy type: %v", proxy)
|
tlog.App.Debug().Msgf("Proxy: %v", req.Proxy)
|
||||||
|
|
||||||
authModules := controller.determineAuthModules(proxy)
|
authModules := controller.determineAuthModules(proxy)
|
||||||
|
|
||||||
@@ -526,13 +531,13 @@ func (controller *ProxyController) getProxyContext(c *gin.Context) (ProxyContext
|
|||||||
var ctx ProxyContext
|
var ctx ProxyContext
|
||||||
|
|
||||||
for _, module := range authModules {
|
for _, module := range authModules {
|
||||||
controller.log.App.Debug().Msgf("Trying to get context from auth module %v", module)
|
tlog.App.Debug().Msgf("Trying auth module: %v", module)
|
||||||
ctx, err = controller.getContextFromAuthModule(c, module)
|
ctx, err = controller.getContextFromAuthModule(c, module)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
controller.log.App.Debug().Msgf("Successfully got context from auth module %v", module)
|
tlog.App.Debug().Msgf("Auth module %v succeeded", module)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
controller.log.App.Debug().Msgf("Failed to get context from auth module %v: %v", module, err)
|
tlog.App.Debug().Err(err).Msgf("Auth module %v failed", module)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -544,9 +549,9 @@ func (controller *ProxyController) getProxyContext(c *gin.Context) (ProxyContext
|
|||||||
isBrowser := BrowserUserAgentRegex.MatchString(userAgent)
|
isBrowser := BrowserUserAgentRegex.MatchString(userAgent)
|
||||||
|
|
||||||
if isBrowser {
|
if isBrowser {
|
||||||
controller.log.App.Debug().Msg("Request identified as coming from a browser client")
|
tlog.App.Debug().Msg("Request identified as coming from a browser")
|
||||||
} else {
|
} else {
|
||||||
controller.log.App.Debug().Msg("Request identified as coming from a non-browser client")
|
tlog.App.Debug().Msg("Request identified as coming from a non-browser client")
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx.IsBrowser = isBrowser
|
ctx.IsBrowser = isBrowser
|
||||||
|
|||||||
@@ -1,9 +1,8 @@
|
|||||||
package controller_test
|
package controller_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"sync"
|
"path"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@@ -14,15 +13,35 @@ import (
|
|||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/test"
|
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestProxyController(t *testing.T) {
|
func TestProxyController(t *testing.T) {
|
||||||
log := logger.NewLogger().WithTestConfig()
|
tlog.NewTestLogger().Init()
|
||||||
log.Init()
|
tempDir := t.TempDir()
|
||||||
|
|
||||||
cfg, runtime := test.CreateTestConfigs(t)
|
authServiceCfg := service.AuthServiceConfig{
|
||||||
|
LocalUsers: &[]model.LocalUser{
|
||||||
|
{
|
||||||
|
Username: "testuser",
|
||||||
|
Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Username: "totpuser",
|
||||||
|
Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
|
||||||
|
TOTPSecret: "JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
SessionExpiry: 10, // 10 seconds, useful for testing
|
||||||
|
CookieDomain: "example.com",
|
||||||
|
LoginTimeout: 10, // 10 seconds, useful for testing
|
||||||
|
LoginMaxRetries: 3,
|
||||||
|
SessionCookieName: "tinyauth-session",
|
||||||
|
}
|
||||||
|
|
||||||
|
controllerCfg := controller.ProxyControllerConfig{
|
||||||
|
AppURL: "https://tinyauth.example.com",
|
||||||
|
}
|
||||||
|
|
||||||
acls := map[string]model.App{
|
acls := map[string]model.App{
|
||||||
"app_path_allow": {
|
"app_path_allow": {
|
||||||
@@ -379,19 +398,32 @@ func TestProxyController(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
app := bootstrap.NewBootstrapApp(cfg)
|
oauthBrokerCfgs := make(map[string]model.OAuthServiceConfig)
|
||||||
|
|
||||||
err := app.SetupDatabase()
|
app := bootstrap.NewBootstrapApp(model.Config{})
|
||||||
|
|
||||||
|
db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db"))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
queries := repository.New(app.GetDB())
|
queries := repository.New(db)
|
||||||
|
|
||||||
wg := &sync.WaitGroup{}
|
docker := service.NewDockerService()
|
||||||
ctx := context.TODO()
|
err = docker.Init()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx)
|
ldap := service.NewLdapService(service.LdapServiceConfig{})
|
||||||
authService := service.NewAuthService(log, cfg, runtime, ctx, wg, nil, queries, broker)
|
err = ldap.Init()
|
||||||
aclsService := service.NewAccessControlsService(log, nil, acls)
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
broker := service.NewOAuthBrokerService(oauthBrokerCfgs)
|
||||||
|
err = broker.Init()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
authService := service.NewAuthService(authServiceCfg, ldap, queries, broker)
|
||||||
|
err = authService.Init()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
aclsService := service.NewAccessControlsService(docker, acls)
|
||||||
|
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
t.Run(test.description, func(t *testing.T) {
|
t.Run(test.description, func(t *testing.T) {
|
||||||
@@ -406,13 +438,15 @@ func TestProxyController(t *testing.T) {
|
|||||||
|
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
|
|
||||||
controller.NewProxyController(log, runtime, group, aclsService, authService)
|
proxyController := controller.NewProxyController(controllerCfg, group, aclsService, authService)
|
||||||
|
proxyController.SetupRoutes()
|
||||||
|
|
||||||
test.run(t, router, recorder)
|
test.run(t, router, recorder)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
app.GetDB().Close()
|
err = db.Close()
|
||||||
|
require.NoError(t, err)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,39 +4,42 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type ResourcesControllerConfig struct {
|
||||||
|
Path string
|
||||||
|
Enabled bool
|
||||||
|
}
|
||||||
|
|
||||||
type ResourcesController struct {
|
type ResourcesController struct {
|
||||||
config model.Config
|
config ResourcesControllerConfig
|
||||||
|
router *gin.RouterGroup
|
||||||
fileServer http.Handler
|
fileServer http.Handler
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewResourcesController(
|
func NewResourcesController(config ResourcesControllerConfig, router *gin.RouterGroup) *ResourcesController {
|
||||||
config model.Config,
|
fileServer := http.StripPrefix("/resources", http.FileServer(http.Dir(config.Path)))
|
||||||
router *gin.RouterGroup,
|
|
||||||
) *ResourcesController {
|
|
||||||
fileServer := http.StripPrefix("/resources", http.FileServer(http.Dir(config.Resources.Path)))
|
|
||||||
|
|
||||||
controller := &ResourcesController{
|
return &ResourcesController{
|
||||||
config: config,
|
config: config,
|
||||||
|
router: router,
|
||||||
fileServer: fileServer,
|
fileServer: fileServer,
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
router.GET("/resources/*resource", controller.resourcesHandler)
|
func (controller *ResourcesController) SetupRoutes() {
|
||||||
|
controller.router.GET("/resources/*resource", controller.resourcesHandler)
|
||||||
return controller
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (controller *ResourcesController) resourcesHandler(c *gin.Context) {
|
func (controller *ResourcesController) resourcesHandler(c *gin.Context) {
|
||||||
if controller.config.Resources.Path == "" {
|
if controller.config.Path == "" {
|
||||||
c.JSON(404, gin.H{
|
c.JSON(404, gin.H{
|
||||||
"status": 404,
|
"status": 404,
|
||||||
"message": "Resources not found",
|
"message": "Resources not found",
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if !controller.config.Resources.Enabled {
|
if !controller.config.Enabled {
|
||||||
c.JSON(403, gin.H{
|
c.JSON(403, gin.H{
|
||||||
"status": 403,
|
"status": 403,
|
||||||
"message": "Resources are disabled",
|
"message": "Resources are disabled",
|
||||||
|
|||||||
@@ -3,20 +3,26 @@ package controller_test
|
|||||||
import (
|
import (
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/controller"
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/controller"
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/test"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestResourcesController(t *testing.T) {
|
func TestResourcesController(t *testing.T) {
|
||||||
cfg, _ := test.CreateTestConfigs(t)
|
tlog.NewTestLogger().Init()
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
|
||||||
err := os.MkdirAll(cfg.Resources.Path, 0777)
|
resourcesControllerCfg := controller.ResourcesControllerConfig{
|
||||||
|
Path: path.Join(tempDir, "resources"),
|
||||||
|
Enabled: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
err := os.Mkdir(resourcesControllerCfg.Path, 0777)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
type testCase struct {
|
type testCase struct {
|
||||||
@@ -55,11 +61,11 @@ func TestResourcesController(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
testFilePath := cfg.Resources.Path + "/testfile.txt"
|
testFilePath := resourcesControllerCfg.Path + "/testfile.txt"
|
||||||
err = os.WriteFile(testFilePath, []byte("This is a test file."), 0777)
|
err = os.WriteFile(testFilePath, []byte("This is a test file."), 0777)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
testFilePathParent := filepath.Dir(cfg.Resources.Path) + "/somefile.txt"
|
testFilePathParent := tempDir + "/somefile.txt"
|
||||||
err = os.WriteFile(testFilePathParent, []byte("This file should not be accessible."), 0777)
|
err = os.WriteFile(testFilePathParent, []byte("This file should not be accessible."), 0777)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
@@ -69,7 +75,8 @@ func TestResourcesController(t *testing.T) {
|
|||||||
group := router.Group("/")
|
group := router.Group("/")
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
controller.NewResourcesController(cfg, group)
|
resourcesController := controller.NewResourcesController(resourcesControllerCfg, group)
|
||||||
|
resourcesController.SetupRoutes()
|
||||||
|
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
test.run(t, router, recorder)
|
test.run(t, router, recorder)
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ import (
|
|||||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/pquerna/otp/totp"
|
"github.com/pquerna/otp/totp"
|
||||||
@@ -25,30 +25,30 @@ type TotpRequest struct {
|
|||||||
Code string `json:"code"`
|
Code string `json:"code"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type UserController struct {
|
type UserControllerConfig struct {
|
||||||
log *logger.Logger
|
CookieDomain string
|
||||||
runtime model.RuntimeConfig
|
SessionCookieName string
|
||||||
auth *service.AuthService
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewUserController(
|
type UserController struct {
|
||||||
log *logger.Logger,
|
config UserControllerConfig
|
||||||
runtimeConfig model.RuntimeConfig,
|
router *gin.RouterGroup
|
||||||
router *gin.RouterGroup,
|
auth *service.AuthService
|
||||||
auth *service.AuthService,
|
}
|
||||||
) *UserController {
|
|
||||||
controller := &UserController{
|
|
||||||
log: log,
|
|
||||||
runtime: runtimeConfig,
|
|
||||||
auth: auth,
|
|
||||||
}
|
|
||||||
|
|
||||||
userGroup := router.Group("/user")
|
func NewUserController(config UserControllerConfig, router *gin.RouterGroup, auth *service.AuthService) *UserController {
|
||||||
|
return &UserController{
|
||||||
|
config: config,
|
||||||
|
router: router,
|
||||||
|
auth: auth,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (controller *UserController) SetupRoutes() {
|
||||||
|
userGroup := controller.router.Group("/user")
|
||||||
userGroup.POST("/login", controller.loginHandler)
|
userGroup.POST("/login", controller.loginHandler)
|
||||||
userGroup.POST("/logout", controller.logoutHandler)
|
userGroup.POST("/logout", controller.logoutHandler)
|
||||||
userGroup.POST("/totp", controller.totpHandler)
|
userGroup.POST("/totp", controller.totpHandler)
|
||||||
|
|
||||||
return controller
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (controller *UserController) loginHandler(c *gin.Context) {
|
func (controller *UserController) loginHandler(c *gin.Context) {
|
||||||
@@ -56,7 +56,7 @@ func (controller *UserController) loginHandler(c *gin.Context) {
|
|||||||
|
|
||||||
err := c.ShouldBindJSON(&req)
|
err := c.ShouldBindJSON(&req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Error().Err(err).Msg("Failed to bind JSON")
|
tlog.App.Error().Err(err).Msg("Failed to bind JSON")
|
||||||
c.JSON(400, gin.H{
|
c.JSON(400, gin.H{
|
||||||
"status": 400,
|
"status": 400,
|
||||||
"message": "Bad Request",
|
"message": "Bad Request",
|
||||||
@@ -64,13 +64,13 @@ func (controller *UserController) loginHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
controller.log.App.Debug().Str("username", req.Username).Msg("Login attempt")
|
tlog.App.Debug().Str("username", req.Username).Msg("Login attempt")
|
||||||
|
|
||||||
isLocked, remaining := controller.auth.IsAccountLocked(req.Username)
|
isLocked, remaining := controller.auth.IsAccountLocked(req.Username)
|
||||||
|
|
||||||
if isLocked {
|
if isLocked {
|
||||||
controller.log.App.Warn().Str("username", req.Username).Msg("Account is locked due to too many failed login attempts")
|
tlog.App.Warn().Str("username", req.Username).Msg("Account is locked due to too many failed login attempts")
|
||||||
controller.log.AuditLoginFailure(req.Username, "local", c.ClientIP(), "account locked")
|
tlog.AuditLoginFailure(c, req.Username, "username", "account locked")
|
||||||
c.Writer.Header().Add("x-tinyauth-lock-locked", "true")
|
c.Writer.Header().Add("x-tinyauth-lock-locked", "true")
|
||||||
c.Writer.Header().Add("x-tinyauth-lock-reset", time.Now().Add(time.Duration(remaining)*time.Second).Format(time.RFC3339))
|
c.Writer.Header().Add("x-tinyauth-lock-reset", time.Now().Add(time.Duration(remaining)*time.Second).Format(time.RFC3339))
|
||||||
c.JSON(429, gin.H{
|
c.JSON(429, gin.H{
|
||||||
@@ -84,16 +84,16 @@ func (controller *UserController) loginHandler(c *gin.Context) {
|
|||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, service.ErrUserNotFound) {
|
if errors.Is(err, service.ErrUserNotFound) {
|
||||||
controller.log.App.Warn().Str("username", req.Username).Msg("User not found during login attempt")
|
tlog.App.Warn().Str("username", req.Username).Msg("User not found")
|
||||||
controller.auth.RecordLoginAttempt(req.Username, false)
|
controller.auth.RecordLoginAttempt(req.Username, false)
|
||||||
controller.log.AuditLoginFailure(req.Username, "unknown", c.ClientIP(), "user not found")
|
tlog.AuditLoginFailure(c, req.Username, "username", "user not found")
|
||||||
c.JSON(401, gin.H{
|
c.JSON(401, gin.H{
|
||||||
"status": 401,
|
"status": 401,
|
||||||
"message": "Unauthorized",
|
"message": "Unauthorized",
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
controller.log.App.Error().Err(err).Str("username", req.Username).Msg("Error searching for user during login attempt")
|
tlog.App.Error().Err(err).Str("username", req.Username).Msg("Error searching for user")
|
||||||
c.JSON(500, gin.H{
|
c.JSON(500, gin.H{
|
||||||
"status": 500,
|
"status": 500,
|
||||||
"message": "Internal Server Error",
|
"message": "Internal Server Error",
|
||||||
@@ -102,13 +102,9 @@ func (controller *UserController) loginHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err := controller.auth.CheckUserPassword(*search, req.Password); err != nil {
|
if err := controller.auth.CheckUserPassword(*search, req.Password); err != nil {
|
||||||
controller.log.App.Warn().Str("username", req.Username).Msg("Invalid password during login attempt")
|
tlog.App.Warn().Err(err).Str("username", req.Username).Msg("Failed to verify password")
|
||||||
controller.auth.RecordLoginAttempt(req.Username, false)
|
controller.auth.RecordLoginAttempt(req.Username, false)
|
||||||
if search.Type == model.UserLocal {
|
tlog.AuditLoginFailure(c, req.Username, "username", "invalid password")
|
||||||
controller.log.AuditLoginFailure(req.Username, "local", c.ClientIP(), "invalid password")
|
|
||||||
} else {
|
|
||||||
controller.log.AuditLoginFailure(req.Username, "ldap", c.ClientIP(), "invalid password")
|
|
||||||
}
|
|
||||||
c.JSON(401, gin.H{
|
c.JSON(401, gin.H{
|
||||||
"status": 401,
|
"status": 401,
|
||||||
"message": "Unauthorized",
|
"message": "Unauthorized",
|
||||||
@@ -122,7 +118,7 @@ func (controller *UserController) loginHandler(c *gin.Context) {
|
|||||||
localUser = controller.auth.GetLocalUser(req.Username)
|
localUser = controller.auth.GetLocalUser(req.Username)
|
||||||
|
|
||||||
if localUser == nil {
|
if localUser == nil {
|
||||||
controller.log.App.Error().Str("username", req.Username).Msg("Local user not found after successful password verification")
|
tlog.App.Warn().Str("username", req.Username).Msg("User disappeared during login")
|
||||||
c.JSON(401, gin.H{
|
c.JSON(401, gin.H{
|
||||||
"status": 401,
|
"status": 401,
|
||||||
"message": "Unauthorized",
|
"message": "Unauthorized",
|
||||||
@@ -131,7 +127,7 @@ func (controller *UserController) loginHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if localUser.TOTPSecret != "" {
|
if localUser.TOTPSecret != "" {
|
||||||
controller.log.App.Debug().Str("username", req.Username).Msg("TOTP required for user, creating pending TOTP session")
|
tlog.App.Debug().Str("username", req.Username).Msg("User has TOTP enabled, requiring TOTP verification")
|
||||||
|
|
||||||
name := localUser.Attributes.Name
|
name := localUser.Attributes.Name
|
||||||
if name == "" {
|
if name == "" {
|
||||||
@@ -140,7 +136,7 @@ func (controller *UserController) loginHandler(c *gin.Context) {
|
|||||||
|
|
||||||
email := localUser.Attributes.Email
|
email := localUser.Attributes.Email
|
||||||
if email == "" {
|
if email == "" {
|
||||||
email = utils.CompileUserEmail(localUser.Username, controller.runtime.CookieDomain)
|
email = utils.CompileUserEmail(localUser.Username, controller.config.CookieDomain)
|
||||||
}
|
}
|
||||||
|
|
||||||
cookie, err := controller.auth.CreateSession(c, repository.Session{
|
cookie, err := controller.auth.CreateSession(c, repository.Session{
|
||||||
@@ -152,7 +148,7 @@ func (controller *UserController) loginHandler(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Error().Err(err).Str("username", req.Username).Msg("Failed to create pending TOTP session")
|
tlog.App.Error().Err(err).Msg("Failed to create session cookie")
|
||||||
c.JSON(500, gin.H{
|
c.JSON(500, gin.H{
|
||||||
"status": 500,
|
"status": 500,
|
||||||
"message": "Internal Server Error",
|
"message": "Internal Server Error",
|
||||||
@@ -174,7 +170,7 @@ func (controller *UserController) loginHandler(c *gin.Context) {
|
|||||||
sessionCookie := repository.Session{
|
sessionCookie := repository.Session{
|
||||||
Username: req.Username,
|
Username: req.Username,
|
||||||
Name: utils.Capitalize(req.Username),
|
Name: utils.Capitalize(req.Username),
|
||||||
Email: utils.CompileUserEmail(req.Username, controller.runtime.CookieDomain),
|
Email: utils.CompileUserEmail(req.Username, controller.config.CookieDomain),
|
||||||
Provider: "local",
|
Provider: "local",
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -189,15 +185,14 @@ func (controller *UserController) loginHandler(c *gin.Context) {
|
|||||||
|
|
||||||
if search.Type == model.UserLDAP {
|
if search.Type == model.UserLDAP {
|
||||||
sessionCookie.Provider = "ldap"
|
sessionCookie.Provider = "ldap"
|
||||||
if search.Email != "" {
|
|
||||||
sessionCookie.Email = search.Email
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
tlog.App.Trace().Interface("session_cookie", sessionCookie).Msg("Creating session cookie")
|
||||||
|
|
||||||
cookie, err := controller.auth.CreateSession(c, sessionCookie)
|
cookie, err := controller.auth.CreateSession(c, sessionCookie)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Error().Err(err).Str("username", req.Username).Msg("Failed to create session cookie after successful login")
|
tlog.App.Error().Err(err).Msg("Failed to create session cookie")
|
||||||
c.JSON(500, gin.H{
|
c.JSON(500, gin.H{
|
||||||
"status": 500,
|
"status": 500,
|
||||||
"message": "Internal Server Error",
|
"message": "Internal Server Error",
|
||||||
@@ -207,13 +202,8 @@ func (controller *UserController) loginHandler(c *gin.Context) {
|
|||||||
|
|
||||||
http.SetCookie(c.Writer, cookie)
|
http.SetCookie(c.Writer, cookie)
|
||||||
|
|
||||||
controller.log.App.Info().Str("username", req.Username).Msg("Login successful")
|
tlog.App.Info().Str("username", req.Username).Msg("Login successful")
|
||||||
|
tlog.AuditLoginSuccess(c, req.Username, "username")
|
||||||
if search.Type == model.UserLocal {
|
|
||||||
controller.log.AuditLoginSuccess(req.Username, "local", c.ClientIP())
|
|
||||||
} else {
|
|
||||||
controller.log.AuditLoginSuccess(req.Username, "ldap", c.ClientIP())
|
|
||||||
}
|
|
||||||
|
|
||||||
controller.auth.RecordLoginAttempt(req.Username, true)
|
controller.auth.RecordLoginAttempt(req.Username, true)
|
||||||
|
|
||||||
@@ -224,20 +214,20 @@ func (controller *UserController) loginHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (controller *UserController) logoutHandler(c *gin.Context) {
|
func (controller *UserController) logoutHandler(c *gin.Context) {
|
||||||
controller.log.App.Debug().Msg("Logout attempt")
|
tlog.App.Debug().Msg("Logout request received")
|
||||||
|
|
||||||
uuid, err := c.Cookie(controller.runtime.SessionCookieName)
|
uuid, err := c.Cookie(controller.config.SessionCookieName)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, http.ErrNoCookie) {
|
if errors.Is(err, http.ErrNoCookie) {
|
||||||
controller.log.App.Warn().Msg("Logout attempt without session cookie, treating as successful logout")
|
tlog.App.Warn().Msg("No session cookie found on logout request")
|
||||||
c.JSON(200, gin.H{
|
c.JSON(200, gin.H{
|
||||||
"status": 200,
|
"status": 200,
|
||||||
"message": "Logout successful",
|
"message": "Logout successful",
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
controller.log.App.Error().Err(err).Msg("Error retrieving session cookie on logout")
|
tlog.App.Error().Err(err).Msg("Error retrieving session cookie on logout")
|
||||||
c.JSON(500, gin.H{
|
c.JSON(500, gin.H{
|
||||||
"status": 500,
|
"status": 500,
|
||||||
"message": "Internal Server Error",
|
"message": "Internal Server Error",
|
||||||
@@ -248,7 +238,7 @@ func (controller *UserController) logoutHandler(c *gin.Context) {
|
|||||||
cookie, err := controller.auth.DeleteSession(c, uuid)
|
cookie, err := controller.auth.DeleteSession(c, uuid)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Error().Err(err).Msg("Error deleting session on logout")
|
tlog.App.Error().Err(err).Msg("Error deleting session on logout")
|
||||||
c.JSON(500, gin.H{
|
c.JSON(500, gin.H{
|
||||||
"status": 500,
|
"status": 500,
|
||||||
"message": "Internal Server Error",
|
"message": "Internal Server Error",
|
||||||
@@ -259,10 +249,10 @@ func (controller *UserController) logoutHandler(c *gin.Context) {
|
|||||||
context, err := new(model.UserContext).NewFromGin(c)
|
context, err := new(model.UserContext).NewFromGin(c)
|
||||||
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
controller.log.AuditLogout(context.GetUsername(), context.GetProviderID(), c.ClientIP())
|
tlog.AuditLogout(c, context.GetUsername(), context.GetProviderID())
|
||||||
} else {
|
} else {
|
||||||
controller.log.App.Warn().Err(err).Msg("Failed to get user context during logout, logging audit with unknown user")
|
tlog.App.Warn().Err(err).Msg("Failed to get user context for logout audit, proceeding without username")
|
||||||
controller.log.AuditLogout("unknown", "unknown", c.ClientIP())
|
tlog.AuditLogout(c, "unknown", "unknown")
|
||||||
}
|
}
|
||||||
|
|
||||||
http.SetCookie(c.Writer, cookie)
|
http.SetCookie(c.Writer, cookie)
|
||||||
@@ -278,7 +268,7 @@ func (controller *UserController) totpHandler(c *gin.Context) {
|
|||||||
|
|
||||||
err := c.ShouldBindJSON(&req)
|
err := c.ShouldBindJSON(&req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Error().Err(err).Msg("Failed to bind JSON for TOTP verification")
|
tlog.App.Error().Err(err).Msg("Failed to bind JSON")
|
||||||
c.JSON(400, gin.H{
|
c.JSON(400, gin.H{
|
||||||
"status": 400,
|
"status": 400,
|
||||||
"message": "Bad Request",
|
"message": "Bad Request",
|
||||||
@@ -289,7 +279,7 @@ func (controller *UserController) totpHandler(c *gin.Context) {
|
|||||||
context, err := new(model.UserContext).NewFromGin(c)
|
context, err := new(model.UserContext).NewFromGin(c)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Error().Err(err).Msg("Failed to create user context from request for TOTP verification")
|
tlog.App.Error().Err(err).Msg("Failed to get user context")
|
||||||
c.JSON(500, gin.H{
|
c.JSON(500, gin.H{
|
||||||
"status": 500,
|
"status": 500,
|
||||||
"message": "Internal Server Error",
|
"message": "Internal Server Error",
|
||||||
@@ -298,7 +288,7 @@ func (controller *UserController) totpHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !context.TOTPPending() {
|
if !context.TOTPPending() {
|
||||||
controller.log.App.Warn().Str("username", context.GetUsername()).Msg("TOTP verification attempt without pending TOTP session")
|
tlog.App.Warn().Msg("TOTP attempt without a pending TOTP session")
|
||||||
c.JSON(401, gin.H{
|
c.JSON(401, gin.H{
|
||||||
"status": 401,
|
"status": 401,
|
||||||
"message": "Unauthorized",
|
"message": "Unauthorized",
|
||||||
@@ -306,13 +296,12 @@ func (controller *UserController) totpHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
controller.log.App.Debug().Str("username", context.GetUsername()).Msg("TOTP verification attempt")
|
tlog.App.Debug().Str("username", context.GetUsername()).Msg("TOTP verification attempt")
|
||||||
|
|
||||||
isLocked, remaining := controller.auth.IsAccountLocked(context.GetUsername())
|
isLocked, remaining := controller.auth.IsAccountLocked(context.GetUsername())
|
||||||
|
|
||||||
if isLocked {
|
if isLocked {
|
||||||
controller.log.App.Warn().Str("username", context.GetUsername()).Msg("Account is locked due to too many failed TOTP attempts")
|
tlog.App.Warn().Str("username", context.GetUsername()).Msg("Account is locked due to too many failed TOTP attempts")
|
||||||
controller.log.AuditLoginFailure(context.GetUsername(), "local", c.ClientIP(), "account locked")
|
|
||||||
c.Writer.Header().Add("x-tinyauth-lock-locked", "true")
|
c.Writer.Header().Add("x-tinyauth-lock-locked", "true")
|
||||||
c.Writer.Header().Add("x-tinyauth-lock-reset", time.Now().Add(time.Duration(remaining)*time.Second).Format(time.RFC3339))
|
c.Writer.Header().Add("x-tinyauth-lock-reset", time.Now().Add(time.Duration(remaining)*time.Second).Format(time.RFC3339))
|
||||||
c.JSON(429, gin.H{
|
c.JSON(429, gin.H{
|
||||||
@@ -325,7 +314,7 @@ func (controller *UserController) totpHandler(c *gin.Context) {
|
|||||||
user := controller.auth.GetLocalUser(context.GetUsername())
|
user := controller.auth.GetLocalUser(context.GetUsername())
|
||||||
|
|
||||||
if user == nil {
|
if user == nil {
|
||||||
controller.log.App.Error().Str("username", context.GetUsername()).Msg("Local user not found during TOTP verification")
|
tlog.App.Error().Str("username", context.GetUsername()).Msg("User not found in TOTP handler")
|
||||||
c.JSON(401, gin.H{
|
c.JSON(401, gin.H{
|
||||||
"status": 401,
|
"status": 401,
|
||||||
"message": "Unauthorized",
|
"message": "Unauthorized",
|
||||||
@@ -336,9 +325,9 @@ func (controller *UserController) totpHandler(c *gin.Context) {
|
|||||||
ok := totp.Validate(req.Code, user.TOTPSecret)
|
ok := totp.Validate(req.Code, user.TOTPSecret)
|
||||||
|
|
||||||
if !ok {
|
if !ok {
|
||||||
controller.log.App.Warn().Str("username", context.GetUsername()).Msg("Invalid TOTP code during verification attempt")
|
tlog.App.Warn().Str("username", context.GetUsername()).Msg("Invalid TOTP code")
|
||||||
controller.auth.RecordLoginAttempt(context.GetUsername(), false)
|
controller.auth.RecordLoginAttempt(context.GetUsername(), false)
|
||||||
controller.log.AuditLoginFailure(context.GetUsername(), "local", c.ClientIP(), "invalid TOTP code")
|
tlog.AuditLoginFailure(c, context.GetUsername(), "totp", "invalid totp code")
|
||||||
c.JSON(401, gin.H{
|
c.JSON(401, gin.H{
|
||||||
"status": 401,
|
"status": 401,
|
||||||
"message": "Unauthorized",
|
"message": "Unauthorized",
|
||||||
@@ -346,15 +335,15 @@ func (controller *UserController) totpHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
uuid, err := c.Cookie(controller.runtime.SessionCookieName)
|
uuid, err := c.Cookie(controller.config.SessionCookieName)
|
||||||
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
_, err = controller.auth.DeleteSession(c, uuid)
|
_, err = controller.auth.DeleteSession(c, uuid)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Error().Err(err).Msg("Failed to delete pending TOTP session after successful verification")
|
tlog.App.Warn().Err(err).Msg("Failed to delete pending TOTP session")
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
controller.log.App.Warn().Err(err).Msg("Failed to retrieve session cookie for pending TOTP session, cannot delete it")
|
tlog.App.Warn().Err(err).Msg("Failed to retrieve session cookie for pending TOTP session, proceeding without deleting it")
|
||||||
}
|
}
|
||||||
|
|
||||||
controller.auth.RecordLoginAttempt(context.GetUsername(), true)
|
controller.auth.RecordLoginAttempt(context.GetUsername(), true)
|
||||||
@@ -362,7 +351,7 @@ func (controller *UserController) totpHandler(c *gin.Context) {
|
|||||||
sessionCookie := repository.Session{
|
sessionCookie := repository.Session{
|
||||||
Username: user.Username,
|
Username: user.Username,
|
||||||
Name: utils.Capitalize(user.Username),
|
Name: utils.Capitalize(user.Username),
|
||||||
Email: utils.CompileUserEmail(user.Username, controller.runtime.CookieDomain),
|
Email: utils.CompileUserEmail(user.Username, controller.config.CookieDomain),
|
||||||
Provider: "local",
|
Provider: "local",
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -373,10 +362,12 @@ func (controller *UserController) totpHandler(c *gin.Context) {
|
|||||||
sessionCookie.Email = user.Attributes.Email
|
sessionCookie.Email = user.Attributes.Email
|
||||||
}
|
}
|
||||||
|
|
||||||
|
tlog.App.Trace().Interface("session_cookie", sessionCookie).Msg("Creating session cookie")
|
||||||
|
|
||||||
cookie, err := controller.auth.CreateSession(c, sessionCookie)
|
cookie, err := controller.auth.CreateSession(c, sessionCookie)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Error().Err(err).Str("username", context.GetUsername()).Msg("Failed to create session cookie after successful TOTP verification")
|
tlog.App.Error().Err(err).Msg("Failed to create session cookie")
|
||||||
c.JSON(500, gin.H{
|
c.JSON(500, gin.H{
|
||||||
"status": 500,
|
"status": 500,
|
||||||
"message": "Internal Server Error",
|
"message": "Internal Server Error",
|
||||||
@@ -386,8 +377,8 @@ func (controller *UserController) totpHandler(c *gin.Context) {
|
|||||||
|
|
||||||
http.SetCookie(c.Writer, cookie)
|
http.SetCookie(c.Writer, cookie)
|
||||||
|
|
||||||
controller.log.App.Info().Str("username", context.GetUsername()).Msg("TOTP verification successful, login complete")
|
tlog.App.Info().Str("username", context.GetUsername()).Msg("TOTP verification successful")
|
||||||
controller.log.AuditLoginSuccess(context.GetUsername(), "local", c.ClientIP())
|
tlog.AuditLoginSuccess(c, context.GetUsername(), "totp")
|
||||||
|
|
||||||
c.JSON(200, gin.H{
|
c.JSON(200, gin.H{
|
||||||
"status": 200,
|
"status": 200,
|
||||||
|
|||||||
@@ -5,8 +5,8 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"path"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -19,15 +19,53 @@ import (
|
|||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/test"
|
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestUserController(t *testing.T) {
|
func TestUserController(t *testing.T) {
|
||||||
log := logger.NewLogger().WithTestConfig()
|
tlog.NewTestLogger().Init()
|
||||||
log.Init()
|
tempDir := t.TempDir()
|
||||||
|
|
||||||
cfg, runtime := test.CreateTestConfigs(t)
|
authServiceCfg := service.AuthServiceConfig{
|
||||||
|
LocalUsers: &[]model.LocalUser{
|
||||||
|
{
|
||||||
|
Username: "testuser",
|
||||||
|
Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Username: "totpuser",
|
||||||
|
Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
|
||||||
|
TOTPSecret: "JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Username: "attruser",
|
||||||
|
Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
|
||||||
|
Attributes: model.UserAttributes{
|
||||||
|
Name: "Alice Smith",
|
||||||
|
Email: "alice@example.com",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Username: "attrtotpuser",
|
||||||
|
Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
|
||||||
|
TOTPSecret: "JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK",
|
||||||
|
Attributes: model.UserAttributes{
|
||||||
|
Name: "Bob Jones",
|
||||||
|
Email: "bob@example.com",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
SessionExpiry: 10, // 10 seconds, useful for testing
|
||||||
|
CookieDomain: "example.com",
|
||||||
|
LoginTimeout: 10, // 10 seconds, useful for testing
|
||||||
|
LoginMaxRetries: 3,
|
||||||
|
SessionCookieName: "tinyauth-session",
|
||||||
|
}
|
||||||
|
|
||||||
|
userControllerCfg := controller.UserControllerConfig{
|
||||||
|
CookieDomain: "example.com",
|
||||||
|
SessionCookieName: "tinyauth-session",
|
||||||
|
}
|
||||||
|
|
||||||
totpCtx := func(c *gin.Context) {
|
totpCtx := func(c *gin.Context) {
|
||||||
c.Set("context", &model.UserContext{
|
c.Set("context", &model.UserContext{
|
||||||
@@ -73,12 +111,14 @@ func TestUserController(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
app := bootstrap.NewBootstrapApp(cfg)
|
oauthBrokerCfgs := make(map[string]model.OAuthServiceConfig)
|
||||||
|
|
||||||
err := app.SetupDatabase()
|
app := bootstrap.NewBootstrapApp(model.Config{})
|
||||||
|
|
||||||
|
db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db"))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
queries := repository.New(app.GetDB())
|
queries := repository.New(db)
|
||||||
|
|
||||||
type testCase struct {
|
type testCase struct {
|
||||||
description string
|
description string
|
||||||
@@ -96,7 +136,7 @@ func TestUserController(t *testing.T) {
|
|||||||
Password: "password",
|
Password: "password",
|
||||||
}
|
}
|
||||||
loginReqBody, err := json.Marshal(loginReq)
|
loginReqBody, err := json.Marshal(loginReq)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqBody)))
|
req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqBody)))
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
@@ -104,7 +144,7 @@ func TestUserController(t *testing.T) {
|
|||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
assert.Equal(t, 200, recorder.Code)
|
assert.Equal(t, 200, recorder.Code)
|
||||||
require.Len(t, recorder.Result().Cookies(), 1)
|
assert.Len(t, recorder.Result().Cookies(), 1)
|
||||||
|
|
||||||
cookie := recorder.Result().Cookies()[0]
|
cookie := recorder.Result().Cookies()[0]
|
||||||
assert.Equal(t, "tinyauth-session", cookie.Name)
|
assert.Equal(t, "tinyauth-session", cookie.Name)
|
||||||
@@ -124,7 +164,7 @@ func TestUserController(t *testing.T) {
|
|||||||
Password: "wrongpassword",
|
Password: "wrongpassword",
|
||||||
}
|
}
|
||||||
loginReqBody, err := json.Marshal(loginReq)
|
loginReqBody, err := json.Marshal(loginReq)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqBody)))
|
req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqBody)))
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
@@ -145,7 +185,7 @@ func TestUserController(t *testing.T) {
|
|||||||
Password: "wrongpassword",
|
Password: "wrongpassword",
|
||||||
}
|
}
|
||||||
loginReqBody, err := json.Marshal(loginReq)
|
loginReqBody, err := json.Marshal(loginReq)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
for range 3 {
|
for range 3 {
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
@@ -180,7 +220,7 @@ func TestUserController(t *testing.T) {
|
|||||||
Password: "password",
|
Password: "password",
|
||||||
}
|
}
|
||||||
loginReqBody, err := json.Marshal(loginReq)
|
loginReqBody, err := json.Marshal(loginReq)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqBody)))
|
req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqBody)))
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
@@ -191,12 +231,12 @@ func TestUserController(t *testing.T) {
|
|||||||
|
|
||||||
decodedBody := make(map[string]any)
|
decodedBody := make(map[string]any)
|
||||||
err = json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
|
err = json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
assert.Equal(t, decodedBody["totpPending"], true)
|
assert.Equal(t, decodedBody["totpPending"], true)
|
||||||
|
|
||||||
// should set the session cookie
|
// should set the session cookie
|
||||||
require.Len(t, recorder.Result().Cookies(), 1)
|
assert.Len(t, recorder.Result().Cookies(), 1)
|
||||||
cookie := recorder.Result().Cookies()[0]
|
cookie := recorder.Result().Cookies()[0]
|
||||||
assert.Equal(t, "tinyauth-session", cookie.Name)
|
assert.Equal(t, "tinyauth-session", cookie.Name)
|
||||||
assert.True(t, cookie.HttpOnly)
|
assert.True(t, cookie.HttpOnly)
|
||||||
@@ -217,7 +257,7 @@ func TestUserController(t *testing.T) {
|
|||||||
Password: "password",
|
Password: "password",
|
||||||
}
|
}
|
||||||
loginReqBody, err := json.Marshal(loginReq)
|
loginReqBody, err := json.Marshal(loginReq)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqBody)))
|
req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqBody)))
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
@@ -226,7 +266,7 @@ func TestUserController(t *testing.T) {
|
|||||||
|
|
||||||
assert.Equal(t, 200, recorder.Code)
|
assert.Equal(t, 200, recorder.Code)
|
||||||
cookies := recorder.Result().Cookies()
|
cookies := recorder.Result().Cookies()
|
||||||
require.Len(t, cookies, 1)
|
assert.Len(t, cookies, 1)
|
||||||
|
|
||||||
cookie := cookies[0]
|
cookie := cookies[0]
|
||||||
assert.Equal(t, "tinyauth-session", cookie.Name)
|
assert.Equal(t, "tinyauth-session", cookie.Name)
|
||||||
@@ -240,7 +280,7 @@ func TestUserController(t *testing.T) {
|
|||||||
|
|
||||||
assert.Equal(t, 200, recorder.Code)
|
assert.Equal(t, 200, recorder.Code)
|
||||||
cookies = recorder.Result().Cookies()
|
cookies = recorder.Result().Cookies()
|
||||||
require.Len(t, cookies, 1)
|
assert.Len(t, cookies, 1)
|
||||||
|
|
||||||
cookie = cookies[0]
|
cookie = cookies[0]
|
||||||
assert.Equal(t, "tinyauth-session", cookie.Name)
|
assert.Equal(t, "tinyauth-session", cookie.Name)
|
||||||
@@ -267,14 +307,14 @@ func TestUserController(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
code, err := totp.GenerateCode("JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", time.Now())
|
code, err := totp.GenerateCode("JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", time.Now())
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
totpReq := controller.TotpRequest{
|
totpReq := controller.TotpRequest{
|
||||||
Code: code,
|
Code: code,
|
||||||
}
|
}
|
||||||
|
|
||||||
totpReqBody, err := json.Marshal(totpReq)
|
totpReqBody, err := json.Marshal(totpReq)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
recorder = httptest.NewRecorder()
|
recorder = httptest.NewRecorder()
|
||||||
req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(totpReqBody)))
|
req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(totpReqBody)))
|
||||||
@@ -289,7 +329,7 @@ func TestUserController(t *testing.T) {
|
|||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
assert.Equal(t, 200, recorder.Code)
|
assert.Equal(t, 200, recorder.Code)
|
||||||
require.Len(t, recorder.Result().Cookies(), 1)
|
assert.Len(t, recorder.Result().Cookies(), 1)
|
||||||
|
|
||||||
// should set a new session cookie with totp pending removed
|
// should set a new session cookie with totp pending removed
|
||||||
totpCookie := recorder.Result().Cookies()[0]
|
totpCookie := recorder.Result().Cookies()[0]
|
||||||
@@ -312,7 +352,7 @@ func TestUserController(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
totpReqBody, err := json.Marshal(totpReq)
|
totpReqBody, err := json.Marshal(totpReq)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
recorder = httptest.NewRecorder()
|
recorder = httptest.NewRecorder()
|
||||||
req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(totpReqBody)))
|
req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(totpReqBody)))
|
||||||
@@ -416,11 +456,21 @@ func TestUserController(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx := context.TODO()
|
docker := service.NewDockerService()
|
||||||
wg := &sync.WaitGroup{}
|
err = docker.Init()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx)
|
ldap := service.NewLdapService(service.LdapServiceConfig{})
|
||||||
authService := service.NewAuthService(log, cfg, runtime, ctx, wg, nil, queries, broker)
|
err = ldap.Init()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
broker := service.NewOAuthBrokerService(oauthBrokerCfgs)
|
||||||
|
err = broker.Init()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
authService := service.NewAuthService(authServiceCfg, ldap, queries, broker)
|
||||||
|
err = authService.Init()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
beforeEach := func() {
|
beforeEach := func() {
|
||||||
// Clear failed login attempts before each test
|
// Clear failed login attempts before each test
|
||||||
@@ -439,7 +489,8 @@ func TestUserController(t *testing.T) {
|
|||||||
group := router.Group("/api")
|
group := router.Group("/api")
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
controller.NewUserController(log, runtime, group, authService)
|
userController := controller.NewUserController(userControllerCfg, group, authService)
|
||||||
|
userController.SetupRoutes()
|
||||||
|
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
|
|
||||||
@@ -448,6 +499,7 @@ func TestUserController(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
app.GetDB().Close()
|
err = db.Close()
|
||||||
|
require.NoError(t, err)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -26,30 +26,28 @@ type OpenIDConnectConfiguration struct {
|
|||||||
RequestObjectSigningAlgValuesSupported []string `json:"request_object_signing_alg_values_supported"`
|
RequestObjectSigningAlgValuesSupported []string `json:"request_object_signing_alg_values_supported"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type WellKnownControllerConfig struct{}
|
||||||
|
|
||||||
type WellKnownController struct {
|
type WellKnownController struct {
|
||||||
oidc *service.OIDCService
|
config WellKnownControllerConfig
|
||||||
|
engine *gin.Engine
|
||||||
|
oidc *service.OIDCService
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewWellKnownController(oidc *service.OIDCService, router *gin.RouterGroup) *WellKnownController {
|
func NewWellKnownController(config WellKnownControllerConfig, oidc *service.OIDCService, engine *gin.Engine) *WellKnownController {
|
||||||
controller := &WellKnownController{
|
return &WellKnownController{
|
||||||
oidc: oidc,
|
config: config,
|
||||||
|
oidc: oidc,
|
||||||
|
engine: engine,
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
router.GET("/.well-known/openid-configuration", controller.OpenIDConnectConfiguration)
|
func (controller *WellKnownController) SetupRoutes() {
|
||||||
router.GET("/.well-known/jwks.json", controller.JWKS)
|
controller.engine.GET("/.well-known/openid-configuration", controller.OpenIDConnectConfiguration)
|
||||||
|
controller.engine.GET("/.well-known/jwks.json", controller.JWKS)
|
||||||
return controller
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (controller *WellKnownController) OpenIDConnectConfiguration(c *gin.Context) {
|
func (controller *WellKnownController) OpenIDConnectConfiguration(c *gin.Context) {
|
||||||
if controller.oidc == nil {
|
|
||||||
c.JSON(500, gin.H{
|
|
||||||
"status": 500,
|
|
||||||
"message": "OIDC service not configured",
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
issuer := controller.oidc.GetIssuer()
|
issuer := controller.oidc.GetIssuer()
|
||||||
c.JSON(200, OpenIDConnectConfiguration{
|
c.JSON(200, OpenIDConnectConfiguration{
|
||||||
Issuer: issuer,
|
Issuer: issuer,
|
||||||
@@ -71,19 +69,11 @@ func (controller *WellKnownController) OpenIDConnectConfiguration(c *gin.Context
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (controller *WellKnownController) JWKS(c *gin.Context) {
|
func (controller *WellKnownController) JWKS(c *gin.Context) {
|
||||||
if controller.oidc == nil {
|
|
||||||
c.JSON(500, gin.H{
|
|
||||||
"status": 500,
|
|
||||||
"message": "OIDC service not configured",
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
jwks, err := controller.oidc.GetJWK()
|
jwks, err := controller.oidc.GetJWK()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(500, gin.H{
|
c.JSON(500, gin.H{
|
||||||
"status": 500,
|
"status": "500",
|
||||||
"message": "failed to get JWK",
|
"message": "failed to get JWK",
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -1,11 +1,10 @@
|
|||||||
package controller_test
|
package controller_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"sync"
|
"path"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@@ -13,17 +12,30 @@ import (
|
|||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/bootstrap"
|
"github.com/tinyauthapp/tinyauth/internal/bootstrap"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/controller"
|
"github.com/tinyauthapp/tinyauth/internal/controller"
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/test"
|
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestWellKnownController(t *testing.T) {
|
func TestWellKnownController(t *testing.T) {
|
||||||
log := logger.NewLogger().WithTestConfig()
|
tlog.NewTestLogger().Init()
|
||||||
log.Init()
|
tempDir := t.TempDir()
|
||||||
|
|
||||||
cfg, runtime := test.CreateTestConfigs(t)
|
oidcServiceCfg := service.OIDCServiceConfig{
|
||||||
|
Clients: map[string]model.OIDCClientConfig{
|
||||||
|
"test": {
|
||||||
|
ClientID: "some-client-id",
|
||||||
|
ClientSecret: "some-client-secret",
|
||||||
|
TrustedRedirectURIs: []string{"https://test.example.com/callback"},
|
||||||
|
Name: "Test Client",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
PrivateKeyPath: path.Join(tempDir, "key.pem"),
|
||||||
|
PublicKeyPath: path.Join(tempDir, "key.pub"),
|
||||||
|
Issuer: "https://tinyauth.example.com",
|
||||||
|
SessionExpiry: 500,
|
||||||
|
}
|
||||||
|
|
||||||
type testCase struct {
|
type testCase struct {
|
||||||
description string
|
description string
|
||||||
@@ -44,11 +56,11 @@ func TestWellKnownController(t *testing.T) {
|
|||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
expected := controller.OpenIDConnectConfiguration{
|
expected := controller.OpenIDConnectConfiguration{
|
||||||
Issuer: runtime.AppURL,
|
Issuer: oidcServiceCfg.Issuer,
|
||||||
AuthorizationEndpoint: fmt.Sprintf("%s/authorize", runtime.AppURL),
|
AuthorizationEndpoint: fmt.Sprintf("%s/authorize", oidcServiceCfg.Issuer),
|
||||||
TokenEndpoint: fmt.Sprintf("%s/api/oidc/token", runtime.AppURL),
|
TokenEndpoint: fmt.Sprintf("%s/api/oidc/token", oidcServiceCfg.Issuer),
|
||||||
UserinfoEndpoint: fmt.Sprintf("%s/api/oidc/userinfo", runtime.AppURL),
|
UserinfoEndpoint: fmt.Sprintf("%s/api/oidc/userinfo", oidcServiceCfg.Issuer),
|
||||||
JwksUri: fmt.Sprintf("%s/.well-known/jwks.json", runtime.AppURL),
|
JwksUri: fmt.Sprintf("%s/.well-known/jwks.json", oidcServiceCfg.Issuer),
|
||||||
ScopesSupported: service.SupportedScopes,
|
ScopesSupported: service.SupportedScopes,
|
||||||
ResponseTypesSupported: service.SupportedResponseTypes,
|
ResponseTypesSupported: service.SupportedResponseTypes,
|
||||||
GrantTypesSupported: service.SupportedGrantTypes,
|
GrantTypesSupported: service.SupportedGrantTypes,
|
||||||
@@ -89,17 +101,15 @@ func TestWellKnownController(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx := context.TODO()
|
app := bootstrap.NewBootstrapApp(model.Config{})
|
||||||
wg := &sync.WaitGroup{}
|
|
||||||
|
|
||||||
app := bootstrap.NewBootstrapApp(cfg)
|
db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db"))
|
||||||
|
|
||||||
err := app.SetupDatabase()
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
queries := repository.New(app.GetDB())
|
queries := repository.New(db)
|
||||||
|
|
||||||
oidcService, err := service.NewOIDCService(log, cfg, runtime, queries, ctx, wg)
|
oidcService := service.NewOIDCService(oidcServiceCfg, queries)
|
||||||
|
err = oidcService.Init()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
@@ -109,13 +119,15 @@ func TestWellKnownController(t *testing.T) {
|
|||||||
|
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
|
|
||||||
controller.NewWellKnownController(oidcService, &router.RouterGroup)
|
wellKnownController := controller.NewWellKnownController(controller.WellKnownControllerConfig{}, oidcService, router)
|
||||||
|
wellKnownController.SetupRoutes()
|
||||||
|
|
||||||
test.run(t, router, recorder)
|
test.run(t, router, recorder)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
app.GetDB().Close()
|
err = db.Close()
|
||||||
|
require.NoError(t, err)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ import (
|
|||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
@@ -35,27 +35,29 @@ var (
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
type ContextMiddleware struct {
|
type ContextMiddlewareConfig struct {
|
||||||
log *logger.Logger
|
CookieDomain string
|
||||||
runtime model.RuntimeConfig
|
SessionCookieName string
|
||||||
auth *service.AuthService
|
|
||||||
broker *service.OAuthBrokerService
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewContextMiddleware(
|
type ContextMiddleware struct {
|
||||||
log *logger.Logger,
|
config ContextMiddlewareConfig
|
||||||
runtime model.RuntimeConfig,
|
auth *service.AuthService
|
||||||
auth *service.AuthService,
|
broker *service.OAuthBrokerService
|
||||||
broker *service.OAuthBrokerService,
|
}
|
||||||
) *ContextMiddleware {
|
|
||||||
|
func NewContextMiddleware(config ContextMiddlewareConfig, auth *service.AuthService, broker *service.OAuthBrokerService) *ContextMiddleware {
|
||||||
return &ContextMiddleware{
|
return &ContextMiddleware{
|
||||||
log: log,
|
config: config,
|
||||||
runtime: runtime,
|
auth: auth,
|
||||||
auth: auth,
|
broker: broker,
|
||||||
broker: broker,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *ContextMiddleware) Init() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
|
func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
|
||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
if m.isIgnorePath(c.Request.Method + " " + c.Request.URL.Path) {
|
if m.isIgnorePath(c.Request.Method + " " + c.Request.URL.Path) {
|
||||||
@@ -63,7 +65,7 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
uuid, err := c.Cookie(m.runtime.SessionCookieName)
|
uuid, err := c.Cookie(m.config.SessionCookieName)
|
||||||
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
userContext, cookie, err := m.cookieAuth(c.Request.Context(), uuid)
|
userContext, cookie, err := m.cookieAuth(c.Request.Context(), uuid)
|
||||||
@@ -73,12 +75,12 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
|
|||||||
http.SetCookie(c.Writer, cookie)
|
http.SetCookie(c.Writer, cookie)
|
||||||
}
|
}
|
||||||
|
|
||||||
m.log.App.Debug().Msgf("Authenticated user %s via session cookie", userContext.GetUsername())
|
tlog.App.Trace().Msgf("Authenticated user from session cookie: %s", userContext.GetUsername())
|
||||||
c.Set("context", userContext)
|
c.Set("context", userContext)
|
||||||
c.Next()
|
c.Next()
|
||||||
return
|
return
|
||||||
} else {
|
} else {
|
||||||
m.log.App.Debug().Msgf("Error authenticating session cookie: %v", err)
|
tlog.App.Error().Msgf("Error authenticating session cookie: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -88,7 +90,7 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
|
|||||||
userContext, headers, err := m.basicAuth(username, password)
|
userContext, headers, err := m.basicAuth(username, password)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
m.log.App.Error().Msgf("Error authenticating basic auth: %v", err)
|
tlog.App.Error().Msgf("Error authenticating basic auth: %v", err)
|
||||||
c.Next()
|
c.Next()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -139,7 +141,7 @@ func (m *ContextMiddleware) cookieAuth(ctx context.Context, uuid string) (*model
|
|||||||
}
|
}
|
||||||
|
|
||||||
if userContext.Local.Attributes.Email == "" {
|
if userContext.Local.Attributes.Email == "" {
|
||||||
userContext.Local.Attributes.Email = utils.CompileUserEmail(user.Username, m.runtime.CookieDomain)
|
userContext.Local.Attributes.Email = utils.CompileUserEmail(user.Username, m.config.CookieDomain)
|
||||||
}
|
}
|
||||||
case model.ProviderLDAP:
|
case model.ProviderLDAP:
|
||||||
search, err := m.auth.SearchUser(userContext.LDAP.Username)
|
search, err := m.auth.SearchUser(userContext.LDAP.Username)
|
||||||
@@ -160,12 +162,7 @@ func (m *ContextMiddleware) cookieAuth(ctx context.Context, uuid string) (*model
|
|||||||
|
|
||||||
userContext.LDAP.Groups = user.Groups
|
userContext.LDAP.Groups = user.Groups
|
||||||
userContext.LDAP.Name = utils.Capitalize(userContext.LDAP.Username)
|
userContext.LDAP.Name = utils.Capitalize(userContext.LDAP.Username)
|
||||||
|
userContext.LDAP.Email = utils.CompileUserEmail(userContext.LDAP.Username, m.config.CookieDomain)
|
||||||
userContext.LDAP.Email = utils.CompileUserEmail(userContext.LDAP.Username, m.runtime.CookieDomain)
|
|
||||||
if search.Email != "" {
|
|
||||||
userContext.LDAP.Email = search.Email
|
|
||||||
}
|
|
||||||
|
|
||||||
case model.ProviderOAuth:
|
case model.ProviderOAuth:
|
||||||
_, exists := m.broker.GetService(userContext.OAuth.ID)
|
_, exists := m.broker.GetService(userContext.OAuth.ID)
|
||||||
|
|
||||||
@@ -194,7 +191,7 @@ func (m *ContextMiddleware) basicAuth(username string, password string) (*model.
|
|||||||
locked, remaining := m.auth.IsAccountLocked(username)
|
locked, remaining := m.auth.IsAccountLocked(username)
|
||||||
|
|
||||||
if locked {
|
if locked {
|
||||||
m.log.App.Debug().Msgf("Account for user %s is locked for %d seconds, denying auth", username, remaining)
|
tlog.App.Debug().Msgf("Account for user %s is locked for %d seconds, denying auth", username, remaining)
|
||||||
headers["x-tinyauth-lock-locked"] = "true"
|
headers["x-tinyauth-lock-locked"] = "true"
|
||||||
headers["x-tinyauth-lock-reset"] = time.Now().Add(time.Duration(remaining) * time.Second).Format(time.RFC3339)
|
headers["x-tinyauth-lock-reset"] = time.Now().Add(time.Duration(remaining) * time.Second).Format(time.RFC3339)
|
||||||
return nil, headers, nil
|
return nil, headers, nil
|
||||||
@@ -227,7 +224,7 @@ func (m *ContextMiddleware) basicAuth(username string, password string) (*model.
|
|||||||
BaseContext: model.BaseContext{
|
BaseContext: model.BaseContext{
|
||||||
Username: user.Username,
|
Username: user.Username,
|
||||||
Name: utils.Capitalize(user.Username),
|
Name: utils.Capitalize(user.Username),
|
||||||
Email: utils.CompileUserEmail(user.Username, m.runtime.CookieDomain),
|
Email: utils.CompileUserEmail(user.Username, m.config.CookieDomain),
|
||||||
},
|
},
|
||||||
Attributes: user.Attributes,
|
Attributes: user.Attributes,
|
||||||
}
|
}
|
||||||
@@ -243,15 +240,11 @@ func (m *ContextMiddleware) basicAuth(username string, password string) (*model.
|
|||||||
BaseContext: model.BaseContext{
|
BaseContext: model.BaseContext{
|
||||||
Username: username,
|
Username: username,
|
||||||
Name: utils.Capitalize(username),
|
Name: utils.Capitalize(username),
|
||||||
|
Email: utils.CompileUserEmail(username, m.config.CookieDomain),
|
||||||
},
|
},
|
||||||
Groups: user.Groups,
|
Groups: user.Groups,
|
||||||
}
|
}
|
||||||
userContext.Provider = model.ProviderLDAP
|
userContext.Provider = model.ProviderLDAP
|
||||||
|
|
||||||
userContext.LDAP.Email = utils.CompileUserEmail(username, m.runtime.CookieDomain)
|
|
||||||
if search.Email != "" {
|
|
||||||
userContext.LDAP.Email = search.Email
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
userContext.Authenticated = true
|
userContext.Authenticated = true
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import (
|
|||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"sync"
|
"path"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -17,15 +17,36 @@ import (
|
|||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/test"
|
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestContextMiddleware(t *testing.T) {
|
func TestContextMiddleware(t *testing.T) {
|
||||||
log := logger.NewLogger().WithTestConfig()
|
tlog.NewTestLogger().Init()
|
||||||
log.Init()
|
tempDir := t.TempDir()
|
||||||
|
|
||||||
cfg, runtime := test.CreateTestConfigs(t)
|
authServiceCfg := service.AuthServiceConfig{
|
||||||
|
LocalUsers: &[]model.LocalUser{
|
||||||
|
{
|
||||||
|
Username: "testuser",
|
||||||
|
Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Username: "totpuser",
|
||||||
|
Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
|
||||||
|
TOTPSecret: "JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
SessionExpiry: 10, // 10 seconds, useful for testing
|
||||||
|
CookieDomain: "example.com",
|
||||||
|
LoginTimeout: 10, // 10 seconds, useful for testing
|
||||||
|
LoginMaxRetries: 3,
|
||||||
|
SessionCookieName: "tinyauth-session",
|
||||||
|
}
|
||||||
|
|
||||||
|
middlewareCfg := middleware.ContextMiddlewareConfig{
|
||||||
|
CookieDomain: "example.com",
|
||||||
|
SessionCookieName: "tinyauth-session",
|
||||||
|
}
|
||||||
|
|
||||||
basicAuthHeader := func(username, password string) string {
|
basicAuthHeader := func(username, password string) string {
|
||||||
return "Basic " + base64.StdEncoding.EncodeToString([]byte(username+":"+password))
|
return "Basic " + base64.StdEncoding.EncodeToString([]byte(username+":"+password))
|
||||||
@@ -249,20 +270,30 @@ func TestContextMiddleware(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx := context.TODO()
|
oauthBrokerCfgs := make(map[string]model.OAuthServiceConfig)
|
||||||
wg := &sync.WaitGroup{}
|
|
||||||
|
|
||||||
app := bootstrap.NewBootstrapApp(cfg)
|
app := bootstrap.NewBootstrapApp(model.Config{})
|
||||||
|
|
||||||
err := app.SetupDatabase()
|
db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db"))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
queries := repository.New(app.GetDB())
|
queries := repository.New(db)
|
||||||
|
|
||||||
broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx)
|
ldap := service.NewLdapService(service.LdapServiceConfig{})
|
||||||
authService := service.NewAuthService(log, cfg, runtime, ctx, wg, nil, queries, broker)
|
err = ldap.Init()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
contextMiddleware := middleware.NewContextMiddleware(log, runtime, authService, broker)
|
broker := service.NewOAuthBrokerService(oauthBrokerCfgs)
|
||||||
|
err = broker.Init()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
authService := service.NewAuthService(authServiceCfg, ldap, queries, broker)
|
||||||
|
err = authService.Init()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
contextMiddleware := middleware.NewContextMiddleware(middlewareCfg, authService, broker)
|
||||||
|
err = contextMiddleware.Init()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
authService.ClearRateLimitsTestingOnly()
|
authService.ClearRateLimitsTestingOnly()
|
||||||
@@ -291,6 +322,7 @@ func TestContextMiddleware(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
app.GetDB().Close()
|
err = db.Close()
|
||||||
|
require.NoError(t, err)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/assets"
|
"github.com/tinyauthapp/tinyauth/internal/assets"
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
@@ -18,25 +19,29 @@ type UIMiddleware struct {
|
|||||||
uiFileServer http.Handler
|
uiFileServer http.Handler
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewUIMiddleware() (*UIMiddleware, error) {
|
func NewUIMiddleware() *UIMiddleware {
|
||||||
m := &UIMiddleware{}
|
return &UIMiddleware{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *UIMiddleware) Init() error {
|
||||||
ui, err := fs.Sub(assets.FrontendAssets, "dist")
|
ui, err := fs.Sub(assets.FrontendAssets, "dist")
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to load ui assets: %w", err)
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
m.uiFs = ui
|
m.uiFs = ui
|
||||||
m.uiFileServer = http.FileServerFS(ui)
|
m.uiFileServer = http.FileServerFS(ui)
|
||||||
|
|
||||||
return m, nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *UIMiddleware) Middleware() gin.HandlerFunc {
|
func (m *UIMiddleware) Middleware() gin.HandlerFunc {
|
||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
path := strings.TrimPrefix(c.Request.URL.Path, "/")
|
path := strings.TrimPrefix(c.Request.URL.Path, "/")
|
||||||
|
|
||||||
|
tlog.App.Debug().Str("path", path).Msg("path")
|
||||||
|
|
||||||
switch strings.SplitN(path, "/", 2)[0] {
|
switch strings.SplitN(path, "/", 2)[0] {
|
||||||
case "api", "resources", ".well-known":
|
case "api", "resources", ".well-known":
|
||||||
c.Next()
|
c.Next()
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||||
)
|
)
|
||||||
|
|
||||||
// See context middleware for explanation of why we have to do this
|
// See context middleware for explanation of why we have to do this
|
||||||
@@ -17,14 +17,14 @@ var (
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
type ZerologMiddleware struct {
|
type ZerologMiddleware struct{}
|
||||||
log *logger.Logger
|
|
||||||
|
func NewZerologMiddleware() *ZerologMiddleware {
|
||||||
|
return &ZerologMiddleware{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewZerologMiddleware(log *logger.Logger) *ZerologMiddleware {
|
func (m *ZerologMiddleware) Init() error {
|
||||||
return &ZerologMiddleware{
|
return nil
|
||||||
log: log,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *ZerologMiddleware) logPath(path string) bool {
|
func (m *ZerologMiddleware) logPath(path string) bool {
|
||||||
@@ -50,7 +50,7 @@ func (m *ZerologMiddleware) Middleware() gin.HandlerFunc {
|
|||||||
|
|
||||||
latency := time.Since(tStart).String()
|
latency := time.Since(tStart).String()
|
||||||
|
|
||||||
subLogger := m.log.HTTP.With().Str("method", method).
|
subLogger := tlog.HTTP.With().Str("method", method).
|
||||||
Str("path", path).
|
Str("path", path).
|
||||||
Str("address", address).
|
Str("address", address).
|
||||||
Str("client_ip", clientIP).
|
Str("client_ip", clientIP).
|
||||||
|
|||||||
@@ -14,9 +14,8 @@ func NewDefaultConfiguration() *Config {
|
|||||||
Path: "./resources",
|
Path: "./resources",
|
||||||
},
|
},
|
||||||
Server: ServerConfig{
|
Server: ServerConfig{
|
||||||
Port: 3000,
|
Port: 3000,
|
||||||
Address: "0.0.0.0",
|
Address: "0.0.0.0",
|
||||||
ConcurrentListenersEnabled: false,
|
|
||||||
},
|
},
|
||||||
Auth: AuthConfig{
|
Auth: AuthConfig{
|
||||||
SubdomainsEnabled: true,
|
SubdomainsEnabled: true,
|
||||||
@@ -96,10 +95,9 @@ type ResourcesConfig struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type ServerConfig struct {
|
type ServerConfig struct {
|
||||||
Port int `description:"The port on which the server listens." yaml:"port"`
|
Port int `description:"The port on which the server listens." yaml:"port"`
|
||||||
Address string `description:"The address on which the server listens." yaml:"address"`
|
Address string `description:"The address on which the server listens." yaml:"address"`
|
||||||
SocketPath string `description:"The path to the Unix socket." yaml:"socketPath"`
|
SocketPath string `description:"The path to the Unix socket." yaml:"socketPath"`
|
||||||
ConcurrentListenersEnabled bool `description:"Enable listening on both TCP and Unix socket at the same time." yaml:"concurrentListenersEnabled"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type AuthConfig struct {
|
type AuthConfig struct {
|
||||||
@@ -149,10 +147,9 @@ type IPConfig struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type OAuthConfig struct {
|
type OAuthConfig struct {
|
||||||
Whitelist []string `description:"Comma-separated list of allowed OAuth domains." yaml:"whitelist"`
|
Whitelist []string `description:"Comma-separated list of allowed OAuth domains." yaml:"whitelist"`
|
||||||
WhitelistFile string `description:"Path to the OAuth whitelist file." yaml:"whitelistFile"`
|
AutoRedirect string `description:"The OAuth provider to use for automatic redirection." yaml:"autoRedirect"`
|
||||||
AutoRedirect string `description:"The OAuth provider to use for automatic redirection." yaml:"autoRedirect"`
|
Providers map[string]OAuthServiceConfig `description:"OAuth providers configuration." yaml:"providers"`
|
||||||
Providers map[string]OAuthServiceConfig `description:"OAuth providers configuration." yaml:"providers"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type OIDCConfig struct {
|
type OIDCConfig struct {
|
||||||
|
|||||||
@@ -8,10 +8,6 @@ import (
|
|||||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
|
||||||
ErrUserContextNotFound = errors.New("user context not found")
|
|
||||||
)
|
|
||||||
|
|
||||||
type ProviderType int
|
type ProviderType int
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -78,7 +74,7 @@ func (c *UserContext) NewFromGin(ginctx *gin.Context) (*UserContext, error) {
|
|||||||
userContextValue, exists := ginctx.Get("context")
|
userContextValue, exists := ginctx.Get("context")
|
||||||
|
|
||||||
if !exists {
|
if !exists {
|
||||||
return nil, ErrUserContextNotFound
|
return nil, errors.New("failed to get user context")
|
||||||
}
|
}
|
||||||
|
|
||||||
userContext, ok := userContextValue.(*UserContext)
|
userContext, ok := userContextValue.(*UserContext)
|
||||||
@@ -121,7 +117,7 @@ func (c *UserContext) NewFromSession(session *repository.Session) (*UserContext,
|
|||||||
Email: session.Email,
|
Email: session.Email,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
// By default we assume an unknown name which is oauth
|
// By default we assume an unkown name which is oauth
|
||||||
default:
|
default:
|
||||||
c.Provider = ProviderOAuth
|
c.Provider = ProviderOAuth
|
||||||
c.OAuth = &OAuthContext{
|
c.OAuth = &OAuthContext{
|
||||||
|
|||||||
@@ -238,7 +238,7 @@ func TestContext(t *testing.T) {
|
|||||||
_, err := c.NewFromGin(newGinCtx(nil, false))
|
_, err := c.NewFromGin(newGinCtx(nil, false))
|
||||||
return err.Error()
|
return err.Error()
|
||||||
},
|
},
|
||||||
expected: model.ErrUserContextNotFound.Error(),
|
expected: "failed to get user context",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "NewFromGin returns error when context value has wrong type",
|
description: "NewFromGin returns error when context value has wrong type",
|
||||||
|
|||||||
@@ -1,22 +0,0 @@
|
|||||||
package model
|
|
||||||
|
|
||||||
type RuntimeConfig struct {
|
|
||||||
AppURL string
|
|
||||||
UUID string
|
|
||||||
CookieDomain string
|
|
||||||
SessionCookieName string
|
|
||||||
CSRFCookieName string
|
|
||||||
RedirectCookieName string
|
|
||||||
OAuthSessionCookieName string
|
|
||||||
LocalUsers []LocalUser
|
|
||||||
OAuthProviders map[string]OAuthServiceConfig
|
|
||||||
OAuthWhitelist []string
|
|
||||||
ConfiguredProviders []Provider
|
|
||||||
OIDCClients []OIDCClientConfig
|
|
||||||
}
|
|
||||||
|
|
||||||
type Provider struct {
|
|
||||||
Name string `json:"name"`
|
|
||||||
ID string `json:"id"`
|
|
||||||
OAuth bool `json:"oauth"`
|
|
||||||
}
|
|
||||||
@@ -21,6 +21,5 @@ type LocalUser struct {
|
|||||||
|
|
||||||
type UserSearch struct {
|
type UserSearch struct {
|
||||||
Username string
|
Username string
|
||||||
Email string // used for LDAP, we can't throw it to LDAPUser because it would need another cache or an LDAP lookup every time
|
|
||||||
Type UserSearchType
|
Type UserSearchType
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||||
)
|
)
|
||||||
|
|
||||||
type LabelProvider interface {
|
type LabelProvider interface {
|
||||||
@@ -12,33 +12,32 @@ type LabelProvider interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type AccessControlsService struct {
|
type AccessControlsService struct {
|
||||||
log *logger.Logger
|
labelProvider LabelProvider
|
||||||
labelProvider *LabelProvider
|
|
||||||
static map[string]model.App
|
static map[string]model.App
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewAccessControlsService(
|
func NewAccessControlsService(labelProvider LabelProvider, static map[string]model.App) *AccessControlsService {
|
||||||
log *logger.Logger,
|
|
||||||
labelProvider *LabelProvider,
|
|
||||||
static map[string]model.App) *AccessControlsService {
|
|
||||||
return &AccessControlsService{
|
return &AccessControlsService{
|
||||||
log: log,
|
|
||||||
labelProvider: labelProvider,
|
labelProvider: labelProvider,
|
||||||
static: static,
|
static: static,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (acls *AccessControlsService) Init() error {
|
||||||
|
return nil // No initialization needed
|
||||||
|
}
|
||||||
|
|
||||||
func (acls *AccessControlsService) lookupStaticACLs(domain string) *model.App {
|
func (acls *AccessControlsService) lookupStaticACLs(domain string) *model.App {
|
||||||
var appAcls *model.App
|
var appAcls *model.App
|
||||||
for app, config := range acls.static {
|
for app, config := range acls.static {
|
||||||
if config.Config.Domain == domain {
|
if config.Config.Domain == domain {
|
||||||
acls.log.App.Debug().Str("name", app).Msg("Found matching container by domain")
|
tlog.App.Debug().Str("name", app).Msg("Found matching container by domain")
|
||||||
appAcls = &config
|
appAcls = &config
|
||||||
break // If we find a match by domain, we can stop searching
|
break // If we find a match by domain, we can stop searching
|
||||||
}
|
}
|
||||||
|
|
||||||
if strings.SplitN(domain, ".", 2)[0] == app {
|
if strings.SplitN(domain, ".", 2)[0] == app {
|
||||||
acls.log.App.Debug().Str("name", app).Msg("Found matching container by app name")
|
tlog.App.Debug().Str("name", app).Msg("Found matching container by app name")
|
||||||
appAcls = &config
|
appAcls = &config
|
||||||
break // If we find a match by app name, we can stop searching
|
break // If we find a match by app name, we can stop searching
|
||||||
}
|
}
|
||||||
@@ -51,15 +50,11 @@ func (acls *AccessControlsService) GetAccessControls(domain string) (*model.App,
|
|||||||
app := acls.lookupStaticACLs(domain)
|
app := acls.lookupStaticACLs(domain)
|
||||||
|
|
||||||
if app != nil {
|
if app != nil {
|
||||||
acls.log.App.Debug().Msg("Using static ACLs for app")
|
tlog.App.Debug().Msg("Using ACls from static configuration")
|
||||||
return app, nil
|
return app, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// If we have a label provider configured, try to get ACLs from it
|
// Fallback to label provider
|
||||||
if acls.labelProvider != nil {
|
tlog.App.Debug().Msg("Falling back to label provider for ACLs")
|
||||||
return (*acls.labelProvider).GetLabels(domain)
|
return acls.labelProvider.GetLabels(domain)
|
||||||
}
|
|
||||||
|
|
||||||
// no labels
|
|
||||||
return nil, nil
|
|
||||||
}
|
}
|
||||||
|
|||||||
+101
-104
@@ -14,7 +14,7 @@ import (
|
|||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||||
|
|
||||||
"slices"
|
"slices"
|
||||||
|
|
||||||
@@ -72,41 +72,39 @@ type Lockdown struct {
|
|||||||
ActiveUntil time.Time
|
ActiveUntil time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type AuthServiceConfig struct {
|
||||||
|
LocalUsers *[]model.LocalUser
|
||||||
|
OauthWhitelist []string
|
||||||
|
SessionExpiry int
|
||||||
|
SessionMaxLifetime int
|
||||||
|
SecureCookie bool
|
||||||
|
CookieDomain string
|
||||||
|
LoginTimeout int
|
||||||
|
LoginMaxRetries int
|
||||||
|
SessionCookieName string
|
||||||
|
IP model.IPConfig
|
||||||
|
LDAPGroupsCacheTTL int
|
||||||
|
SubdomainsEnabled bool
|
||||||
|
}
|
||||||
|
|
||||||
type AuthService struct {
|
type AuthService struct {
|
||||||
log *logger.Logger
|
config AuthServiceConfig
|
||||||
config model.Config
|
|
||||||
runtime model.RuntimeConfig
|
|
||||||
context context.Context
|
|
||||||
|
|
||||||
ldap *LdapService
|
|
||||||
queries *repository.Queries
|
|
||||||
oauthBroker *OAuthBrokerService
|
|
||||||
|
|
||||||
loginAttempts map[string]*LoginAttempt
|
loginAttempts map[string]*LoginAttempt
|
||||||
ldapGroupsCache map[string]*LdapGroupsCache
|
ldapGroupsCache map[string]*LdapGroupsCache
|
||||||
oauthPendingSessions map[string]*OAuthPendingSession
|
oauthPendingSessions map[string]*OAuthPendingSession
|
||||||
oauthMutex sync.RWMutex
|
oauthMutex sync.RWMutex
|
||||||
loginMutex sync.RWMutex
|
loginMutex sync.RWMutex
|
||||||
ldapGroupsMutex sync.RWMutex
|
ldapGroupsMutex sync.RWMutex
|
||||||
|
ldap *LdapService
|
||||||
|
queries *repository.Queries
|
||||||
|
oauthBroker *OAuthBrokerService
|
||||||
lockdown *Lockdown
|
lockdown *Lockdown
|
||||||
lockdownCtx context.Context
|
lockdownCtx context.Context
|
||||||
lockdownCancelFunc context.CancelFunc
|
lockdownCancelFunc context.CancelFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewAuthService(
|
func NewAuthService(config AuthServiceConfig, ldap *LdapService, queries *repository.Queries, oauthBroker *OAuthBrokerService) *AuthService {
|
||||||
log *logger.Logger,
|
return &AuthService{
|
||||||
config model.Config,
|
|
||||||
runtime model.RuntimeConfig,
|
|
||||||
ctx context.Context,
|
|
||||||
wg *sync.WaitGroup,
|
|
||||||
ldap *LdapService,
|
|
||||||
queries *repository.Queries,
|
|
||||||
oauthBroker *OAuthBrokerService,
|
|
||||||
) *AuthService {
|
|
||||||
service := &AuthService{
|
|
||||||
log: log,
|
|
||||||
runtime: runtime,
|
|
||||||
context: ctx,
|
|
||||||
config: config,
|
config: config,
|
||||||
loginAttempts: make(map[string]*LoginAttempt),
|
loginAttempts: make(map[string]*LoginAttempt),
|
||||||
ldapGroupsCache: make(map[string]*LdapGroupsCache),
|
ldapGroupsCache: make(map[string]*LdapGroupsCache),
|
||||||
@@ -115,10 +113,11 @@ func NewAuthService(
|
|||||||
queries: queries,
|
queries: queries,
|
||||||
oauthBroker: oauthBroker,
|
oauthBroker: oauthBroker,
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
wg.Go(service.CleanupOAuthSessionsRoutine)
|
func (auth *AuthService) Init() error {
|
||||||
|
go auth.CleanupOAuthSessionsRoutine()
|
||||||
return service
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) SearchUser(username string) (*model.UserSearch, error) {
|
func (auth *AuthService) SearchUser(username string) (*model.UserSearch, error) {
|
||||||
@@ -129,8 +128,8 @@ func (auth *AuthService) SearchUser(username string) (*model.UserSearch, error)
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if auth.ldap != nil {
|
if auth.ldap.IsConfigured() {
|
||||||
userDN, email, err := auth.ldap.GetUserInfo(username)
|
userDN, err := auth.ldap.GetUserDN(username)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to get ldap user: %w", err)
|
return nil, fmt.Errorf("failed to get ldap user: %w", err)
|
||||||
@@ -138,7 +137,6 @@ func (auth *AuthService) SearchUser(username string) (*model.UserSearch, error)
|
|||||||
|
|
||||||
return &model.UserSearch{
|
return &model.UserSearch{
|
||||||
Username: userDN,
|
Username: userDN,
|
||||||
Email: email,
|
|
||||||
Type: model.UserLDAP,
|
Type: model.UserLDAP,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
@@ -155,7 +153,7 @@ func (auth *AuthService) CheckUserPassword(search model.UserSearch, password str
|
|||||||
}
|
}
|
||||||
return bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password))
|
return bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password))
|
||||||
case model.UserLDAP:
|
case model.UserLDAP:
|
||||||
if auth.ldap != nil {
|
if auth.ldap.IsConfigured() {
|
||||||
err := auth.ldap.Bind(search.Username, password)
|
err := auth.ldap.Bind(search.Username, password)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to bind to ldap user: %w", err)
|
return fmt.Errorf("failed to bind to ldap user: %w", err)
|
||||||
@@ -175,10 +173,10 @@ func (auth *AuthService) CheckUserPassword(search model.UserSearch, password str
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) GetLocalUser(username string) *model.LocalUser {
|
func (auth *AuthService) GetLocalUser(username string) *model.LocalUser {
|
||||||
if auth.runtime.LocalUsers == nil {
|
if auth.config.LocalUsers == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
for _, user := range auth.runtime.LocalUsers {
|
for _, user := range *auth.config.LocalUsers {
|
||||||
if user.Username == username {
|
if user.Username == username {
|
||||||
return &user
|
return &user
|
||||||
}
|
}
|
||||||
@@ -187,7 +185,7 @@ func (auth *AuthService) GetLocalUser(username string) *model.LocalUser {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) GetLDAPUser(userDN string) (*model.LDAPUser, error) {
|
func (auth *AuthService) GetLDAPUser(userDN string) (*model.LDAPUser, error) {
|
||||||
if auth.ldap == nil {
|
if !auth.ldap.IsConfigured() {
|
||||||
return nil, errors.New("ldap service not configured")
|
return nil, errors.New("ldap service not configured")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -211,7 +209,7 @@ func (auth *AuthService) GetLDAPUser(userDN string) (*model.LDAPUser, error) {
|
|||||||
auth.ldapGroupsMutex.Lock()
|
auth.ldapGroupsMutex.Lock()
|
||||||
auth.ldapGroupsCache[userDN] = &LdapGroupsCache{
|
auth.ldapGroupsCache[userDN] = &LdapGroupsCache{
|
||||||
Groups: groups,
|
Groups: groups,
|
||||||
Expires: time.Now().Add(time.Duration(auth.config.LDAP.GroupCacheTTL) * time.Second),
|
Expires: time.Now().Add(time.Duration(auth.config.LDAPGroupsCacheTTL) * time.Second),
|
||||||
}
|
}
|
||||||
auth.ldapGroupsMutex.Unlock()
|
auth.ldapGroupsMutex.Unlock()
|
||||||
|
|
||||||
@@ -230,7 +228,7 @@ func (auth *AuthService) IsAccountLocked(identifier string) (bool, int) {
|
|||||||
return true, remaining
|
return true, remaining
|
||||||
}
|
}
|
||||||
|
|
||||||
if auth.config.Auth.LoginMaxRetries <= 0 || auth.config.Auth.LoginTimeout <= 0 {
|
if auth.config.LoginMaxRetries <= 0 || auth.config.LoginTimeout <= 0 {
|
||||||
return false, 0
|
return false, 0
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -248,7 +246,7 @@ func (auth *AuthService) IsAccountLocked(identifier string) (bool, int) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) {
|
func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) {
|
||||||
if auth.config.Auth.LoginMaxRetries <= 0 || auth.config.Auth.LoginTimeout <= 0 {
|
if auth.config.LoginMaxRetries <= 0 || auth.config.LoginTimeout <= 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -279,14 +277,14 @@ func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) {
|
|||||||
|
|
||||||
attempt.FailedAttempts++
|
attempt.FailedAttempts++
|
||||||
|
|
||||||
if attempt.FailedAttempts >= auth.config.Auth.LoginMaxRetries {
|
if attempt.FailedAttempts >= auth.config.LoginMaxRetries {
|
||||||
attempt.LockedUntil = time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second)
|
attempt.LockedUntil = time.Now().Add(time.Duration(auth.config.LoginTimeout) * time.Second)
|
||||||
auth.log.App.Warn().Str("identifier", identifier).Int("failedAttempts", attempt.FailedAttempts).Msg("Account locked due to too many failed login attempts")
|
tlog.App.Warn().Str("identifier", identifier).Int("timeout", auth.config.LoginTimeout).Msg("Account locked due to too many failed login attempts")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) IsEmailWhitelisted(email string) bool {
|
func (auth *AuthService) IsEmailWhitelisted(email string) bool {
|
||||||
return utils.CheckFilter(strings.Join(auth.runtime.OAuthWhitelist, ","), email)
|
return utils.CheckFilter(strings.Join(auth.config.OauthWhitelist, ","), email)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) CreateSession(ctx context.Context, data repository.Session) (*http.Cookie, error) {
|
func (auth *AuthService) CreateSession(ctx context.Context, data repository.Session) (*http.Cookie, error) {
|
||||||
@@ -301,7 +299,7 @@ func (auth *AuthService) CreateSession(ctx context.Context, data repository.Sess
|
|||||||
if data.TotpPending {
|
if data.TotpPending {
|
||||||
expiry = 3600
|
expiry = 3600
|
||||||
} else {
|
} else {
|
||||||
expiry = auth.config.Auth.SessionExpiry
|
expiry = auth.config.SessionExpiry
|
||||||
}
|
}
|
||||||
|
|
||||||
expiresAt := time.Now().Add(time.Duration(expiry) * time.Second)
|
expiresAt := time.Now().Add(time.Duration(expiry) * time.Second)
|
||||||
@@ -327,13 +325,13 @@ func (auth *AuthService) CreateSession(ctx context.Context, data repository.Sess
|
|||||||
}
|
}
|
||||||
|
|
||||||
return &http.Cookie{
|
return &http.Cookie{
|
||||||
Name: auth.runtime.SessionCookieName,
|
Name: auth.config.SessionCookieName,
|
||||||
Value: session.UUID,
|
Value: session.UUID,
|
||||||
Path: "/",
|
Path: "/",
|
||||||
Domain: fmt.Sprintf(".%s", auth.runtime.CookieDomain),
|
Domain: fmt.Sprintf(".%s", auth.config.CookieDomain),
|
||||||
Expires: expiresAt,
|
Expires: expiresAt,
|
||||||
MaxAge: int(time.Until(expiresAt).Seconds()),
|
MaxAge: int(time.Until(expiresAt).Seconds()),
|
||||||
Secure: auth.config.Auth.SecureCookie,
|
Secure: auth.config.SecureCookie,
|
||||||
HttpOnly: true,
|
HttpOnly: true,
|
||||||
SameSite: http.SameSiteLaxMode,
|
SameSite: http.SameSiteLaxMode,
|
||||||
}, nil
|
}, nil
|
||||||
@@ -350,8 +348,8 @@ func (auth *AuthService) RefreshSession(ctx context.Context, uuid string) (*http
|
|||||||
|
|
||||||
var refreshThreshold int64
|
var refreshThreshold int64
|
||||||
|
|
||||||
if auth.config.Auth.SessionExpiry <= int(time.Hour.Seconds()) {
|
if auth.config.SessionExpiry <= int(time.Hour.Seconds()) {
|
||||||
refreshThreshold = int64(auth.config.Auth.SessionExpiry / 2)
|
refreshThreshold = int64(auth.config.SessionExpiry / 2)
|
||||||
} else {
|
} else {
|
||||||
refreshThreshold = int64(time.Hour.Seconds())
|
refreshThreshold = int64(time.Hour.Seconds())
|
||||||
}
|
}
|
||||||
@@ -380,13 +378,13 @@ func (auth *AuthService) RefreshSession(ctx context.Context, uuid string) (*http
|
|||||||
}
|
}
|
||||||
|
|
||||||
return &http.Cookie{
|
return &http.Cookie{
|
||||||
Name: auth.runtime.SessionCookieName,
|
Name: auth.config.SessionCookieName,
|
||||||
Value: session.UUID,
|
Value: session.UUID,
|
||||||
Path: "/",
|
Path: "/",
|
||||||
Domain: fmt.Sprintf(".%s", auth.runtime.CookieDomain),
|
Domain: fmt.Sprintf(".%s", auth.config.CookieDomain),
|
||||||
Expires: time.Now().Add(time.Duration(newExpiry-currentTime) * time.Second),
|
Expires: time.Now().Add(time.Duration(newExpiry-currentTime) * time.Second),
|
||||||
MaxAge: int(newExpiry - currentTime),
|
MaxAge: int(newExpiry - currentTime),
|
||||||
Secure: auth.config.Auth.SecureCookie,
|
Secure: auth.config.SecureCookie,
|
||||||
HttpOnly: true,
|
HttpOnly: true,
|
||||||
SameSite: http.SameSiteLaxMode,
|
SameSite: http.SameSiteLaxMode,
|
||||||
}, nil
|
}, nil
|
||||||
@@ -397,17 +395,23 @@ func (auth *AuthService) DeleteSession(ctx context.Context, uuid string) (*http.
|
|||||||
err := auth.queries.DeleteSession(ctx, uuid)
|
err := auth.queries.DeleteSession(ctx, uuid)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
auth.log.App.Error().Err(err).Str("uuid", uuid).Msg("Failed to delete session from database")
|
tlog.App.Warn().Err(err).Msg("Failed to delete session from database, proceeding to clear cookie anyway")
|
||||||
|
}
|
||||||
|
|
||||||
|
err = auth.queries.DeleteSession(ctx, uuid)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return &http.Cookie{
|
return &http.Cookie{
|
||||||
Name: auth.runtime.SessionCookieName,
|
Name: auth.config.SessionCookieName,
|
||||||
Value: "",
|
Value: "",
|
||||||
Path: "/",
|
Path: "/",
|
||||||
Domain: fmt.Sprintf(".%s", auth.runtime.CookieDomain),
|
Domain: fmt.Sprintf(".%s", auth.config.CookieDomain),
|
||||||
Expires: time.Now(),
|
Expires: time.Now(),
|
||||||
MaxAge: -1,
|
MaxAge: -1,
|
||||||
Secure: auth.config.Auth.SecureCookie,
|
Secure: auth.config.SecureCookie,
|
||||||
HttpOnly: true,
|
HttpOnly: true,
|
||||||
SameSite: http.SameSiteLaxMode,
|
SameSite: http.SameSiteLaxMode,
|
||||||
}, nil
|
}, nil
|
||||||
@@ -425,8 +429,8 @@ func (auth *AuthService) GetSession(ctx context.Context, uuid string) (*reposito
|
|||||||
|
|
||||||
currentTime := time.Now().Unix()
|
currentTime := time.Now().Unix()
|
||||||
|
|
||||||
if auth.config.Auth.SessionMaxLifetime != 0 && session.CreatedAt != 0 {
|
if auth.config.SessionMaxLifetime != 0 && session.CreatedAt != 0 {
|
||||||
if currentTime-session.CreatedAt > int64(auth.config.Auth.SessionMaxLifetime) {
|
if currentTime-session.CreatedAt > int64(auth.config.SessionMaxLifetime) {
|
||||||
err = auth.queries.DeleteSession(ctx, uuid)
|
err = auth.queries.DeleteSession(ctx, uuid)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to delete expired session: %w", err)
|
return nil, fmt.Errorf("failed to delete expired session: %w", err)
|
||||||
@@ -447,11 +451,11 @@ func (auth *AuthService) GetSession(ctx context.Context, uuid string) (*reposito
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) LocalAuthConfigured() bool {
|
func (auth *AuthService) LocalAuthConfigured() bool {
|
||||||
return len(auth.runtime.LocalUsers) > 0
|
return auth.config.LocalUsers != nil && len(*auth.config.LocalUsers) > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) LDAPAuthConfigured() bool {
|
func (auth *AuthService) LDAPAuthConfigured() bool {
|
||||||
return auth.ldap != nil
|
return auth.ldap.IsConfigured()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) IsUserAllowed(c *gin.Context, context model.UserContext, acls *model.App) bool {
|
func (auth *AuthService) IsUserAllowed(c *gin.Context, context model.UserContext, acls *model.App) bool {
|
||||||
@@ -460,18 +464,18 @@ func (auth *AuthService) IsUserAllowed(c *gin.Context, context model.UserContext
|
|||||||
}
|
}
|
||||||
|
|
||||||
if context.Provider == model.ProviderOAuth {
|
if context.Provider == model.ProviderOAuth {
|
||||||
auth.log.App.Debug().Msg("User is an OAuth user, checking OAuth whitelist")
|
tlog.App.Debug().Msg("Checking OAuth whitelist")
|
||||||
return utils.CheckFilter(acls.OAuth.Whitelist, context.OAuth.Email)
|
return utils.CheckFilter(acls.OAuth.Whitelist, context.OAuth.Email)
|
||||||
}
|
}
|
||||||
|
|
||||||
if acls.Users.Block != "" {
|
if acls.Users.Block != "" {
|
||||||
auth.log.App.Debug().Msg("Checking users block list")
|
tlog.App.Debug().Msg("Checking blocked users")
|
||||||
if utils.CheckFilter(acls.Users.Block, context.GetUsername()) {
|
if utils.CheckFilter(acls.Users.Block, context.GetUsername()) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
auth.log.App.Debug().Msg("Checking users allow list")
|
tlog.App.Debug().Msg("Checking users")
|
||||||
return utils.CheckFilter(acls.Users.Allow, context.GetUsername())
|
return utils.CheckFilter(acls.Users.Allow, context.GetUsername())
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -481,23 +485,23 @@ func (auth *AuthService) IsInOAuthGroup(c *gin.Context, context model.UserContex
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !context.IsOAuth() {
|
if !context.IsOAuth() {
|
||||||
auth.log.App.Debug().Msg("User is not an OAuth user, skipping OAuth group check")
|
tlog.App.Debug().Msg("User is not an OAuth user, skipping OAuth group check")
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, ok := model.OverrideProviders[context.OAuth.ID]; ok {
|
if _, ok := model.OverrideProviders[context.OAuth.ID]; ok {
|
||||||
auth.log.App.Debug().Str("provider", context.OAuth.ID).Msg("Provider override detected, skipping group check")
|
tlog.App.Debug().Msg("Provider override for OAuth groups enabled, skipping group check")
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, userGroup := range context.OAuth.Groups {
|
for _, userGroup := range context.OAuth.Groups {
|
||||||
if utils.CheckFilter(acls.OAuth.Groups, strings.TrimSpace(userGroup)) {
|
if utils.CheckFilter(acls.OAuth.Groups, strings.TrimSpace(userGroup)) {
|
||||||
auth.log.App.Trace().Str("group", userGroup).Str("required", acls.OAuth.Groups).Msg("User group matched")
|
tlog.App.Trace().Str("group", userGroup).Str("required", acls.OAuth.Groups).Msg("User group matched")
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
auth.log.App.Debug().Msg("No groups matched")
|
tlog.App.Debug().Msg("No groups matched")
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -507,18 +511,18 @@ func (auth *AuthService) IsInLDAPGroup(c *gin.Context, context model.UserContext
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !context.IsLDAP() {
|
if !context.IsLDAP() {
|
||||||
auth.log.App.Debug().Msg("User is not an LDAP user, skipping LDAP group check")
|
tlog.App.Debug().Msg("User is not an LDAP user, skipping LDAP group check")
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, userGroup := range context.LDAP.Groups {
|
for _, userGroup := range context.LDAP.Groups {
|
||||||
if utils.CheckFilter(acls.LDAP.Groups, strings.TrimSpace(userGroup)) {
|
if utils.CheckFilter(acls.LDAP.Groups, strings.TrimSpace(userGroup)) {
|
||||||
auth.log.App.Trace().Str("group", userGroup).Str("required", acls.LDAP.Groups).Msg("User group matched")
|
tlog.App.Trace().Str("group", userGroup).Str("required", acls.LDAP.Groups).Msg("User group matched")
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
auth.log.App.Debug().Msg("No groups matched")
|
tlog.App.Debug().Msg("No groups matched")
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -562,17 +566,17 @@ func (auth *AuthService) CheckIP(ip string, acls *model.App) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Merge the global and app IP filter
|
// Merge the global and app IP filter
|
||||||
blockedIps := append(auth.config.Auth.IP.Block, acls.IP.Block...)
|
blockedIps := append(auth.config.IP.Block, acls.IP.Block...)
|
||||||
allowedIPs := append(auth.config.Auth.IP.Allow, acls.IP.Allow...)
|
allowedIPs := append(auth.config.IP.Allow, acls.IP.Allow...)
|
||||||
|
|
||||||
for _, blocked := range blockedIps {
|
for _, blocked := range blockedIps {
|
||||||
res, err := utils.FilterIP(blocked, ip)
|
res, err := utils.FilterIP(blocked, ip)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
auth.log.App.Warn().Err(err).Str("item", blocked).Msg("Invalid IP/CIDR in block list")
|
tlog.App.Warn().Err(err).Str("item", blocked).Msg("Invalid IP/CIDR in block list")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if res {
|
if res {
|
||||||
auth.log.App.Debug().Str("ip", ip).Str("item", blocked).Msg("IP is in block list, denying access")
|
tlog.App.Debug().Str("ip", ip).Str("item", blocked).Msg("IP is in blocked list, denying access")
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -580,21 +584,21 @@ func (auth *AuthService) CheckIP(ip string, acls *model.App) bool {
|
|||||||
for _, allowed := range allowedIPs {
|
for _, allowed := range allowedIPs {
|
||||||
res, err := utils.FilterIP(allowed, ip)
|
res, err := utils.FilterIP(allowed, ip)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
auth.log.App.Warn().Err(err).Str("item", allowed).Msg("Invalid IP/CIDR in allow list")
|
tlog.App.Warn().Err(err).Str("item", allowed).Msg("Invalid IP/CIDR in allow list")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if res {
|
if res {
|
||||||
auth.log.App.Debug().Str("ip", ip).Str("item", allowed).Msg("IP is in allow list, allowing access")
|
tlog.App.Debug().Str("ip", ip).Str("item", allowed).Msg("IP is in allowed list, allowing access")
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(allowedIPs) > 0 {
|
if len(allowedIPs) > 0 {
|
||||||
auth.log.App.Debug().Str("ip", ip).Msg("IP not in allow list, denying access")
|
tlog.App.Debug().Str("ip", ip).Msg("IP not in allow list, denying access")
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
auth.log.App.Debug().Str("ip", ip).Msg("IP not in any block or allow list, allowing access by default")
|
tlog.App.Debug().Str("ip", ip).Msg("IP not in allow or block list, allowing by default")
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -606,16 +610,16 @@ func (auth *AuthService) IsBypassedIP(ip string, acls *model.App) bool {
|
|||||||
for _, bypassed := range acls.IP.Bypass {
|
for _, bypassed := range acls.IP.Bypass {
|
||||||
res, err := utils.FilterIP(bypassed, ip)
|
res, err := utils.FilterIP(bypassed, ip)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
auth.log.App.Warn().Err(err).Str("item", bypassed).Msg("Invalid IP/CIDR in bypass list")
|
tlog.App.Warn().Err(err).Str("item", bypassed).Msg("Invalid IP/CIDR in bypass list")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if res {
|
if res {
|
||||||
auth.log.App.Debug().Str("ip", ip).Str("item", bypassed).Msg("IP is in bypass list, skipping authentication")
|
tlog.App.Debug().Str("ip", ip).Str("item", bypassed).Msg("IP is in bypass list, allowing access")
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
auth.log.App.Debug().Str("ip", ip).Msg("IP not in bypass list, proceeding with authentication")
|
tlog.App.Debug().Str("ip", ip).Msg("IP not in bypass list, continuing with authentication")
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -719,32 +723,21 @@ func (auth *AuthService) EndOAuthSession(sessionId string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) CleanupOAuthSessionsRoutine() {
|
func (auth *AuthService) CleanupOAuthSessionsRoutine() {
|
||||||
auth.log.App.Debug().Msg("Starting OAuth session cleanup routine")
|
|
||||||
|
|
||||||
ticker := time.NewTicker(30 * time.Minute)
|
ticker := time.NewTicker(30 * time.Minute)
|
||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
|
|
||||||
for {
|
for range ticker.C {
|
||||||
select {
|
auth.oauthMutex.Lock()
|
||||||
case <-ticker.C:
|
|
||||||
auth.log.App.Debug().Msg("Running OAuth session cleanup")
|
|
||||||
|
|
||||||
auth.oauthMutex.Lock()
|
now := time.Now()
|
||||||
|
|
||||||
now := time.Now()
|
for sessionId, session := range auth.oauthPendingSessions {
|
||||||
|
if now.After(session.ExpiresAt) {
|
||||||
for sessionId, session := range auth.oauthPendingSessions {
|
delete(auth.oauthPendingSessions, sessionId)
|
||||||
if now.After(session.ExpiresAt) {
|
|
||||||
delete(auth.oauthPendingSessions, sessionId)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
auth.oauthMutex.Unlock()
|
|
||||||
auth.log.App.Debug().Msg("OAuth session cleanup completed")
|
|
||||||
case <-auth.context.Done():
|
|
||||||
auth.log.App.Debug().Msg("Stopping OAuth session cleanup routine")
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
auth.oauthMutex.Unlock()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -813,11 +806,11 @@ func (auth *AuthService) lockdownMode() {
|
|||||||
|
|
||||||
auth.loginMutex.Lock()
|
auth.loginMutex.Lock()
|
||||||
|
|
||||||
auth.log.App.Warn().Msg("Too many failed login attempts, entering lockdown mode")
|
tlog.App.Warn().Msg("Multiple login attempts detected, possibly DDOS attack. Activating temporary lockdown.")
|
||||||
|
|
||||||
auth.lockdown = &Lockdown{
|
auth.lockdown = &Lockdown{
|
||||||
Active: true,
|
Active: true,
|
||||||
ActiveUntil: time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second),
|
ActiveUntil: time.Now().Add(time.Duration(auth.config.LoginTimeout) * time.Second),
|
||||||
}
|
}
|
||||||
|
|
||||||
// At this point all login attemps will also expire so,
|
// At this point all login attemps will also expire so,
|
||||||
@@ -834,14 +827,11 @@ func (auth *AuthService) lockdownMode() {
|
|||||||
// Timer expired, end lockdown
|
// Timer expired, end lockdown
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
// Context cancelled, end lockdown
|
// Context cancelled, end lockdown
|
||||||
case <-auth.context.Done():
|
|
||||||
// Service is shutting down, end lockdown
|
|
||||||
}
|
}
|
||||||
|
|
||||||
auth.loginMutex.Lock()
|
auth.loginMutex.Lock()
|
||||||
|
|
||||||
auth.log.App.Info().Msg("Exiting lockdown mode")
|
tlog.App.Info().Msg("Lockdown period ended, resuming normal operation")
|
||||||
|
|
||||||
auth.lockdown = nil
|
auth.lockdown = nil
|
||||||
auth.loginMutex.Unlock()
|
auth.loginMutex.Unlock()
|
||||||
}
|
}
|
||||||
@@ -855,3 +845,10 @@ func (auth *AuthService) ClearRateLimitsTestingOnly() {
|
|||||||
}
|
}
|
||||||
auth.loginMutex.Unlock()
|
auth.loginMutex.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (auth *AuthService) getCookieDomain() string {
|
||||||
|
if auth.config.SubdomainsEnabled {
|
||||||
|
return "." + auth.config.CookieDomain
|
||||||
|
}
|
||||||
|
return auth.config.CookieDomain
|
||||||
|
}
|
||||||
|
|||||||
@@ -3,56 +3,51 @@ package service
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
|
||||||
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/decoders"
|
"github.com/tinyauthapp/tinyauth/internal/utils/decoders"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||||
|
|
||||||
container "github.com/docker/docker/api/types/container"
|
container "github.com/docker/docker/api/types/container"
|
||||||
"github.com/docker/docker/client"
|
"github.com/docker/docker/client"
|
||||||
)
|
)
|
||||||
|
|
||||||
type DockerService struct {
|
type DockerService struct {
|
||||||
log *logger.Logger
|
client *client.Client
|
||||||
client *client.Client
|
context context.Context
|
||||||
context context.Context
|
|
||||||
|
|
||||||
isConnected bool
|
isConnected bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewDockerService(
|
func NewDockerService() *DockerService {
|
||||||
log *logger.Logger,
|
return &DockerService{}
|
||||||
ctx context.Context,
|
}
|
||||||
wg *sync.WaitGroup,
|
|
||||||
) (*DockerService, error) {
|
|
||||||
|
|
||||||
|
func (docker *DockerService) Init() error {
|
||||||
client, err := client.NewClientWithOpts(client.FromEnv)
|
client, err := client.NewClientWithOpts(client.FromEnv)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
client.NegotiateAPIVersion(ctx)
|
client.NegotiateAPIVersion(ctx)
|
||||||
|
|
||||||
_, err = client.Ping(ctx)
|
docker.client = client
|
||||||
|
docker.context = ctx
|
||||||
|
|
||||||
|
_, err = docker.client.Ping(docker.context)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.App.Debug().Err(err).Msg("Docker not connected")
|
tlog.App.Debug().Err(err).Msg("Docker not connected")
|
||||||
return nil, nil
|
docker.isConnected = false
|
||||||
|
docker.client = nil
|
||||||
|
docker.context = nil
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
service := &DockerService{
|
docker.isConnected = true
|
||||||
log: log,
|
tlog.App.Debug().Msg("Docker connected")
|
||||||
client: client,
|
|
||||||
context: ctx,
|
|
||||||
}
|
|
||||||
|
|
||||||
service.isConnected = true
|
return nil
|
||||||
service.log.App.Debug().Msg("Docker connected successfully")
|
|
||||||
|
|
||||||
wg.Go(service.watchAndClose)
|
|
||||||
|
|
||||||
return service, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (docker *DockerService) getContainers() ([]container.Summary, error) {
|
func (docker *DockerService) getContainers() ([]container.Summary, error) {
|
||||||
@@ -65,7 +60,7 @@ func (docker *DockerService) inspectContainer(containerId string) (container.Ins
|
|||||||
|
|
||||||
func (docker *DockerService) GetLabels(appDomain string) (*model.App, error) {
|
func (docker *DockerService) GetLabels(appDomain string) (*model.App, error) {
|
||||||
if !docker.isConnected {
|
if !docker.isConnected {
|
||||||
docker.log.App.Debug().Msg("Docker service not connected, returning empty labels")
|
tlog.App.Debug().Msg("Docker not connected, returning empty labels")
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -87,28 +82,17 @@ func (docker *DockerService) GetLabels(appDomain string) (*model.App, error) {
|
|||||||
|
|
||||||
for appName, appLabels := range labels.Apps {
|
for appName, appLabels := range labels.Apps {
|
||||||
if appLabels.Config.Domain == appDomain {
|
if appLabels.Config.Domain == appDomain {
|
||||||
docker.log.App.Debug().Str("id", inspect.ID).Str("name", inspect.Name).Msg("Found matching container by domain")
|
tlog.App.Debug().Str("id", inspect.ID).Str("name", inspect.Name).Msg("Found matching container by domain")
|
||||||
return &appLabels, nil
|
return &appLabels, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if strings.SplitN(appDomain, ".", 2)[0] == appName {
|
if strings.SplitN(appDomain, ".", 2)[0] == appName {
|
||||||
docker.log.App.Debug().Str("id", inspect.ID).Str("name", inspect.Name).Msg("Found matching container by app name")
|
tlog.App.Debug().Str("id", inspect.ID).Str("name", inspect.Name).Msg("Found matching container by app name")
|
||||||
return &appLabels, nil
|
return &appLabels, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
docker.log.App.Debug().Str("domain", appDomain).Msg("No matching container found for domain")
|
tlog.App.Debug().Msg("No matching container found, returning empty labels")
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (docker *DockerService) watchAndClose() {
|
|
||||||
<-docker.context.Done()
|
|
||||||
docker.log.App.Debug().Msg("Closing Docker client")
|
|
||||||
if docker.client != nil {
|
|
||||||
err := docker.client.Close()
|
|
||||||
if err != nil {
|
|
||||||
docker.log.App.Error().Err(err).Msg("Error closing Docker client")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import (
|
|||||||
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/decoders"
|
"github.com/tinyauthapp/tinyauth/internal/utils/decoders"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||||
|
|
||||||
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
|
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
|
||||||
"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
|
"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
|
||||||
@@ -36,10 +36,9 @@ type ingressApp struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type KubernetesService struct {
|
type KubernetesService struct {
|
||||||
log *logger.Logger
|
|
||||||
ctx context.Context
|
|
||||||
|
|
||||||
client dynamic.Interface
|
client dynamic.Interface
|
||||||
|
ctx context.Context
|
||||||
|
cancel context.CancelFunc
|
||||||
started bool
|
started bool
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
ingressApps map[ingressKey][]ingressApp
|
ingressApps map[ingressKey][]ingressApp
|
||||||
@@ -47,55 +46,12 @@ type KubernetesService struct {
|
|||||||
appNameIndex map[string]ingressAppKey
|
appNameIndex map[string]ingressAppKey
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewKubernetesService(
|
func NewKubernetesService() *KubernetesService {
|
||||||
log *logger.Logger,
|
return &KubernetesService{
|
||||||
ctx context.Context,
|
|
||||||
wg *sync.WaitGroup,
|
|
||||||
) (*KubernetesService, error) {
|
|
||||||
cfg, err := rest.InClusterConfig()
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to get in-cluster kubernetes config: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
client, err := dynamic.NewForConfig(cfg)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to create kubernetes client: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
gvr := schema.GroupVersionResource{
|
|
||||||
Group: "networking.k8s.io",
|
|
||||||
Version: "v1",
|
|
||||||
Resource: "ingresses",
|
|
||||||
}
|
|
||||||
|
|
||||||
accessCtx, accessCancel := context.WithTimeout(ctx, 5*time.Second)
|
|
||||||
defer accessCancel()
|
|
||||||
|
|
||||||
_, err = client.Resource(gvr).List(accessCtx, metav1.ListOptions{Limit: 1})
|
|
||||||
if err != nil {
|
|
||||||
log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to access Ingress API, Kubernetes label provider will be disabled")
|
|
||||||
return nil, fmt.Errorf("failed to access ingress api: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Successfully accessed Ingress API, starting watcher")
|
|
||||||
|
|
||||||
service := &KubernetesService{
|
|
||||||
log: log,
|
|
||||||
ctx: ctx,
|
|
||||||
client: client,
|
|
||||||
ingressApps: make(map[ingressKey][]ingressApp),
|
ingressApps: make(map[ingressKey][]ingressApp),
|
||||||
domainIndex: make(map[string]ingressAppKey),
|
domainIndex: make(map[string]ingressAppKey),
|
||||||
appNameIndex: make(map[string]ingressAppKey),
|
appNameIndex: make(map[string]ingressAppKey),
|
||||||
}
|
}
|
||||||
|
|
||||||
wg.Go(func() {
|
|
||||||
service.watchGVR(gvr)
|
|
||||||
})
|
|
||||||
|
|
||||||
service.started = true
|
|
||||||
log.App.Debug().Msg("Kubernetes label provider started successfully")
|
|
||||||
|
|
||||||
return service, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (k *KubernetesService) addIngressApps(namespace, name string, apps []ingressApp) {
|
func (k *KubernetesService) addIngressApps(namespace, name string, apps []ingressApp) {
|
||||||
@@ -177,7 +133,7 @@ func (k *KubernetesService) updateFromItem(item *unstructured.Unstructured) {
|
|||||||
}
|
}
|
||||||
labels, err := decoders.DecodeLabels[model.Apps](annotations, "apps")
|
labels, err := decoders.DecodeLabels[model.Apps](annotations, "apps")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
k.log.App.Warn().Err(err).Str("namespace", namespace).Str("name", name).Msg("Failed to decode ingress labels, skipping")
|
tlog.App.Debug().Err(err).Msg("Failed to decode labels from annotations")
|
||||||
k.removeIngress(namespace, name)
|
k.removeIngress(namespace, name)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -205,13 +161,13 @@ func (k *KubernetesService) resyncGVR(gvr schema.GroupVersionResource) error {
|
|||||||
|
|
||||||
list, err := k.client.Resource(gvr).List(ctx, metav1.ListOptions{})
|
list, err := k.client.Resource(gvr).List(ctx, metav1.ListOptions{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to list resources for resync")
|
tlog.App.Debug().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to list ingresses during resync")
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
for i := range list.Items {
|
for i := range list.Items {
|
||||||
k.updateFromItem(&list.Items[i])
|
k.updateFromItem(&list.Items[i])
|
||||||
}
|
}
|
||||||
k.log.App.Debug().Str("api", gvr.GroupVersion().String()).Int("count", len(list.Items)).Msg("Resync complete")
|
tlog.App.Debug().Str("api", gvr.GroupVersion().String()).Int("count", len(list.Items)).Msg("Resynced ingress cache")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -225,14 +181,14 @@ func (k *KubernetesService) runWatcher(gvr schema.GroupVersionResource, w watch.
|
|||||||
return false
|
return false
|
||||||
case event, ok := <-w.ResultChan():
|
case event, ok := <-w.ResultChan():
|
||||||
if !ok {
|
if !ok {
|
||||||
k.log.App.Warn().Str("api", gvr.GroupVersion().String()).Msg("Watcher channel closed, restarting watcher")
|
tlog.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Watcher channel closed, restarting in 5 seconds")
|
||||||
w.Stop()
|
w.Stop()
|
||||||
time.Sleep(5 * time.Second)
|
time.Sleep(5 * time.Second)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
item, ok := event.Object.(*unstructured.Unstructured)
|
item, ok := event.Object.(*unstructured.Unstructured)
|
||||||
if !ok {
|
if !ok {
|
||||||
k.log.App.Warn().Str("api", gvr.GroupVersion().String()).Msg("Received unexpected event object, skipping")
|
tlog.App.Warn().Str("api", gvr.GroupVersion().String()).Msg("Failed to cast watched object")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
switch event.Type {
|
switch event.Type {
|
||||||
@@ -243,7 +199,7 @@ func (k *KubernetesService) runWatcher(gvr schema.GroupVersionResource, w watch.
|
|||||||
}
|
}
|
||||||
case <-resyncTicker.C:
|
case <-resyncTicker.C:
|
||||||
if err := k.resyncGVR(gvr); err != nil {
|
if err := k.resyncGVR(gvr); err != nil {
|
||||||
k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Periodic resync failed during watcher run")
|
tlog.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Periodic resync failed")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -254,29 +210,29 @@ func (k *KubernetesService) watchGVR(gvr schema.GroupVersionResource) {
|
|||||||
defer resyncTicker.Stop()
|
defer resyncTicker.Stop()
|
||||||
|
|
||||||
if err := k.resyncGVR(gvr); err != nil {
|
if err := k.resyncGVR(gvr); err != nil {
|
||||||
k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Initial resync failed, will retry")
|
tlog.App.Error().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Initial resync failed, retrying in 30 seconds")
|
||||||
time.Sleep(30 * time.Second)
|
time.Sleep(30 * time.Second)
|
||||||
}
|
}
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-k.ctx.Done():
|
case <-k.ctx.Done():
|
||||||
k.log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Shutting down kubernetes watcher")
|
tlog.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Stopping watcher")
|
||||||
return
|
return
|
||||||
case <-resyncTicker.C:
|
case <-resyncTicker.C:
|
||||||
if err := k.resyncGVR(gvr); err != nil {
|
if err := k.resyncGVR(gvr); err != nil {
|
||||||
k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Periodic resync failed, will retry")
|
tlog.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Periodic resync failed")
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
ctx, cancel := context.WithCancel(k.ctx)
|
ctx, cancel := context.WithCancel(k.ctx)
|
||||||
watcher, err := k.client.Resource(gvr).Watch(ctx, metav1.ListOptions{})
|
watcher, err := k.client.Resource(gvr).Watch(ctx, metav1.ListOptions{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to start watcher, will retry")
|
tlog.App.Error().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to start watcher")
|
||||||
cancel()
|
cancel()
|
||||||
time.Sleep(10 * time.Second)
|
time.Sleep(10 * time.Second)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
k.log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Watcher started successfully")
|
tlog.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Watcher started")
|
||||||
if !k.runWatcher(gvr, watcher, resyncTicker) {
|
if !k.runWatcher(gvr, watcher, resyncTicker) {
|
||||||
cancel()
|
cancel()
|
||||||
return
|
return
|
||||||
@@ -286,25 +242,65 @@ func (k *KubernetesService) watchGVR(gvr schema.GroupVersionResource) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (k *KubernetesService) Init() error {
|
||||||
|
var cfg *rest.Config
|
||||||
|
var err error
|
||||||
|
|
||||||
|
cfg, err = rest.InClusterConfig()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get in-cluster Kubernetes config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
client, err := dynamic.NewForConfig(cfg)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create Kubernetes client: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
k.client = client
|
||||||
|
k.ctx, k.cancel = context.WithCancel(context.Background())
|
||||||
|
|
||||||
|
gvr := schema.GroupVersionResource{
|
||||||
|
Group: "networking.k8s.io",
|
||||||
|
Version: "v1",
|
||||||
|
Resource: "ingresses",
|
||||||
|
}
|
||||||
|
|
||||||
|
accessCtx, accessCancel := context.WithTimeout(k.ctx, 5*time.Second)
|
||||||
|
defer accessCancel()
|
||||||
|
_, err = k.client.Resource(gvr).List(accessCtx, metav1.ListOptions{Limit: 1})
|
||||||
|
if err != nil {
|
||||||
|
tlog.App.Warn().Err(err).Msg("Insufficient permissions for networking.k8s.io/v1 Ingress, Kubernetes label provider will not work")
|
||||||
|
k.started = false
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
tlog.App.Debug().Msg("networking.k8s.io/v1 Ingress API accessible")
|
||||||
|
go k.watchGVR(gvr)
|
||||||
|
|
||||||
|
k.started = true
|
||||||
|
tlog.App.Info().Msg("Kubernetes label provider initialized")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (k *KubernetesService) GetLabels(appDomain string) (*model.App, error) {
|
func (k *KubernetesService) GetLabels(appDomain string) (*model.App, error) {
|
||||||
if !k.started {
|
if !k.started {
|
||||||
k.log.App.Debug().Str("domain", appDomain).Msg("Kubernetes label provider not started, skipping")
|
tlog.App.Debug().Msg("Kubernetes not connected, returning empty labels")
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// First check cache
|
// First check cache
|
||||||
app := k.getByDomain(appDomain)
|
app := k.getByDomain(appDomain)
|
||||||
if app != nil {
|
if app != nil {
|
||||||
k.log.App.Debug().Str("domain", appDomain).Msg("Found labels in cache by domain")
|
tlog.App.Debug().Str("domain", appDomain).Msg("Found labels in cache by domain")
|
||||||
return app, nil
|
return app, nil
|
||||||
}
|
}
|
||||||
appName := strings.SplitN(appDomain, ".", 2)[0]
|
appName := strings.SplitN(appDomain, ".", 2)[0]
|
||||||
app = k.getByAppName(appName)
|
app = k.getByAppName(appName)
|
||||||
if app != nil {
|
if app != nil {
|
||||||
k.log.App.Debug().Str("domain", appDomain).Str("appName", appName).Msg("Found labels in cache by app name")
|
tlog.App.Debug().Str("domain", appDomain).Str("appName", appName).Msg("Found labels in cache by app name")
|
||||||
return app, nil
|
return app, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
k.log.App.Debug().Str("domain", appDomain).Msg("No labels found for domain")
|
tlog.App.Debug().Str("domain", appDomain).Msg("Cache miss, no matching ingress found")
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,13 +8,9 @@ import (
|
|||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestKubernetesService(t *testing.T) {
|
func TestKubernetesService(t *testing.T) {
|
||||||
log := logger.NewLogger().WithTestConfig()
|
|
||||||
log.Init()
|
|
||||||
|
|
||||||
type testCase struct {
|
type testCase struct {
|
||||||
description string
|
description string
|
||||||
run func(t *testing.T, svc *KubernetesService)
|
run func(t *testing.T, svc *KubernetesService)
|
||||||
@@ -183,7 +179,6 @@ func TestKubernetesService(t *testing.T) {
|
|||||||
ingressApps: make(map[ingressKey][]ingressApp),
|
ingressApps: make(map[ingressKey][]ingressApp),
|
||||||
domainIndex: make(map[string]ingressAppKey),
|
domainIndex: make(map[string]ingressAppKey),
|
||||||
appNameIndex: make(map[string]ingressAppKey),
|
appNameIndex: make(map[string]ingressAppKey),
|
||||||
log: log,
|
|
||||||
}
|
}
|
||||||
test.run(t, svc)
|
test.run(t, svc)
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -9,47 +9,69 @@ import (
|
|||||||
|
|
||||||
"github.com/cenkalti/backoff/v5"
|
"github.com/cenkalti/backoff/v5"
|
||||||
ldapgo "github.com/go-ldap/ldap/v3"
|
ldapgo "github.com/go-ldap/ldap/v3"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type LdapService struct {
|
type LdapServiceConfig struct {
|
||||||
log *logger.Logger
|
Address string
|
||||||
config model.Config
|
BindDN string
|
||||||
context context.Context
|
BindPassword string
|
||||||
|
BaseDN string
|
||||||
conn *ldapgo.Conn
|
Insecure bool
|
||||||
mutex sync.RWMutex
|
SearchFilter string
|
||||||
cert *tls.Certificate
|
AuthCert string
|
||||||
|
AuthKey string
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewLdapService(
|
type LdapService struct {
|
||||||
log *logger.Logger,
|
config LdapServiceConfig
|
||||||
config model.Config,
|
conn *ldapgo.Conn
|
||||||
ctx context.Context,
|
mutex sync.RWMutex
|
||||||
wg *sync.WaitGroup,
|
cert *tls.Certificate
|
||||||
) (*LdapService, error) {
|
isConfigured bool
|
||||||
if config.LDAP.Address == "" {
|
}
|
||||||
return nil, nil
|
|
||||||
|
func NewLdapService(config LdapServiceConfig) *LdapService {
|
||||||
|
return &LdapService{
|
||||||
|
config: config,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ldap *LdapService) IsConfigured() bool {
|
||||||
|
return ldap.isConfigured
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ldap *LdapService) Unconfigure() error {
|
||||||
|
if !ldap.isConfigured {
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
ldap := &LdapService{
|
if ldap.conn != nil {
|
||||||
log: log,
|
if err := ldap.conn.Close(); err != nil {
|
||||||
config: config,
|
return fmt.Errorf("failed to close LDAP connection: %w", err)
|
||||||
context: ctx,
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ldap.isConfigured = false
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ldap *LdapService) Init() error {
|
||||||
|
if ldap.config.Address == "" {
|
||||||
|
ldap.isConfigured = false
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
ldap.isConfigured = true
|
||||||
|
|
||||||
// Check whether authentication with client certificate is possible
|
// Check whether authentication with client certificate is possible
|
||||||
if config.LDAP.AuthCert != "" && config.LDAP.AuthKey != "" {
|
if ldap.config.AuthCert != "" && ldap.config.AuthKey != "" {
|
||||||
cert, err := tls.LoadX509KeyPair(config.LDAP.AuthCert, config.LDAP.AuthKey)
|
cert, err := tls.LoadX509KeyPair(ldap.config.AuthCert, ldap.config.AuthKey)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to initialize LDAP with mTLS authentication: %w", err)
|
return fmt.Errorf("failed to initialize LDAP with mTLS authentication: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.App.Info().Msg("LDAP mTLS authentication configured successfully")
|
|
||||||
|
|
||||||
ldap.cert = &cert
|
ldap.cert = &cert
|
||||||
|
tlog.App.Info().Msg("Using LDAP with mTLS authentication")
|
||||||
|
|
||||||
// TODO: Add optional extra CA certificates, instead of `InsecureSkipVerify`
|
// TODO: Add optional extra CA certificates, instead of `InsecureSkipVerify`
|
||||||
/*
|
/*
|
||||||
@@ -62,39 +84,26 @@ func NewLdapService(
|
|||||||
}
|
}
|
||||||
*/
|
*/
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := ldap.connect()
|
_, err := ldap.connect()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to connect to ldap server: %w", err)
|
return fmt.Errorf("failed to connect to LDAP server: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
wg.Go(func() {
|
go func() {
|
||||||
ldap.log.App.Debug().Msg("Starting LDAP connection heartbeat routine")
|
for range time.Tick(time.Duration(5) * time.Minute) {
|
||||||
|
err := ldap.heartbeat()
|
||||||
ticker := time.NewTicker(5 * time.Minute)
|
if err != nil {
|
||||||
defer ticker.Stop()
|
tlog.App.Error().Err(err).Msg("LDAP connection heartbeat failed")
|
||||||
|
if reconnectErr := ldap.reconnect(); reconnectErr != nil {
|
||||||
for {
|
tlog.App.Error().Err(reconnectErr).Msg("Failed to reconnect to LDAP server")
|
||||||
select {
|
continue
|
||||||
case <-ticker.C:
|
|
||||||
err := ldap.heartbeat()
|
|
||||||
if err != nil {
|
|
||||||
ldap.log.App.Warn().Err(err).Msg("LDAP connection heartbeat failed, attempting to reconnect")
|
|
||||||
if reconnectErr := ldap.reconnect(); reconnectErr != nil {
|
|
||||||
ldap.log.App.Error().Err(reconnectErr).Msg("Failed to reconnect to LDAP server")
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
ldap.log.App.Info().Msg("Successfully reconnected to LDAP server")
|
|
||||||
}
|
}
|
||||||
case <-ldap.context.Done():
|
tlog.App.Info().Msg("Successfully reconnected to LDAP server")
|
||||||
ldap.log.App.Debug().Msg("LDAP service context cancelled, stopping heartbeat")
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
})
|
}()
|
||||||
|
|
||||||
return ldap, nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ldap *LdapService) connect() (*ldapgo.Conn, error) {
|
func (ldap *LdapService) connect() (*ldapgo.Conn, error) {
|
||||||
@@ -111,13 +120,13 @@ func (ldap *LdapService) connect() (*ldapgo.Conn, error) {
|
|||||||
// 2. conn.StartTLS(tlsConfig)
|
// 2. conn.StartTLS(tlsConfig)
|
||||||
// 3. conn.externalBind()
|
// 3. conn.externalBind()
|
||||||
if ldap.cert != nil {
|
if ldap.cert != nil {
|
||||||
conn, err = ldapgo.DialURL(ldap.config.LDAP.Address, ldapgo.DialWithTLSConfig(&tls.Config{
|
conn, err = ldapgo.DialURL(ldap.config.Address, ldapgo.DialWithTLSConfig(&tls.Config{
|
||||||
MinVersion: tls.VersionTLS12,
|
MinVersion: tls.VersionTLS12,
|
||||||
Certificates: []tls.Certificate{*ldap.cert},
|
Certificates: []tls.Certificate{*ldap.cert},
|
||||||
}))
|
}))
|
||||||
} else {
|
} else {
|
||||||
conn, err = ldapgo.DialURL(ldap.config.LDAP.Address, ldapgo.DialWithTLSConfig(&tls.Config{
|
conn, err = ldapgo.DialURL(ldap.config.Address, ldapgo.DialWithTLSConfig(&tls.Config{
|
||||||
InsecureSkipVerify: ldap.config.LDAP.Insecure,
|
InsecureSkipVerify: ldap.config.Insecure,
|
||||||
MinVersion: tls.VersionTLS12,
|
MinVersion: tls.VersionTLS12,
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
@@ -134,15 +143,16 @@ func (ldap *LdapService) connect() (*ldapgo.Conn, error) {
|
|||||||
return ldap.conn, nil
|
return ldap.conn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ldap *LdapService) GetUserInfo(username string) (dn string, email string, err error) {
|
func (ldap *LdapService) GetUserDN(username string) (string, error) {
|
||||||
|
// Escape the username to prevent LDAP injection
|
||||||
escapedUsername := ldapgo.EscapeFilter(username)
|
escapedUsername := ldapgo.EscapeFilter(username)
|
||||||
filter := fmt.Sprintf(ldap.config.LDAP.SearchFilter, escapedUsername)
|
filter := fmt.Sprintf(ldap.config.SearchFilter, escapedUsername)
|
||||||
|
|
||||||
searchRequest := ldapgo.NewSearchRequest(
|
searchRequest := ldapgo.NewSearchRequest(
|
||||||
ldap.config.LDAP.BaseDN,
|
ldap.config.BaseDN,
|
||||||
ldapgo.ScopeWholeSubtree, ldapgo.NeverDerefAliases, 0, 0, false,
|
ldapgo.ScopeWholeSubtree, ldapgo.NeverDerefAliases, 0, 0, false,
|
||||||
filter,
|
filter,
|
||||||
[]string{"dn", "mail"},
|
[]string{"dn"},
|
||||||
nil,
|
nil,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -151,22 +161,22 @@ func (ldap *LdapService) GetUserInfo(username string) (dn string, email string,
|
|||||||
|
|
||||||
searchResult, err := ldap.conn.Search(searchRequest)
|
searchResult, err := ldap.conn.Search(searchRequest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(searchResult.Entries) != 1 {
|
if len(searchResult.Entries) != 1 {
|
||||||
return "", "", fmt.Errorf("multiple or no entries found for user %s", username)
|
return "", fmt.Errorf("multiple or no entries found for user %s", username)
|
||||||
}
|
}
|
||||||
|
|
||||||
entry := searchResult.Entries[0]
|
userDN := searchResult.Entries[0].DN
|
||||||
return entry.DN, entry.GetAttributeValue("mail"), nil
|
return userDN, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ldap *LdapService) GetUserGroups(userDN string) ([]string, error) {
|
func (ldap *LdapService) GetUserGroups(userDN string) ([]string, error) {
|
||||||
escapedUserDN := ldapgo.EscapeFilter(userDN)
|
escapedUserDN := ldapgo.EscapeFilter(userDN)
|
||||||
|
|
||||||
searchRequest := ldapgo.NewSearchRequest(
|
searchRequest := ldapgo.NewSearchRequest(
|
||||||
ldap.config.LDAP.BaseDN,
|
ldap.config.BaseDN,
|
||||||
ldapgo.ScopeWholeSubtree, ldapgo.NeverDerefAliases, 0, 0, false,
|
ldapgo.ScopeWholeSubtree, ldapgo.NeverDerefAliases, 0, 0, false,
|
||||||
fmt.Sprintf("(&(objectclass=groupOfUniqueNames)(uniquemember=%s))", escapedUserDN),
|
fmt.Sprintf("(&(objectclass=groupOfUniqueNames)(uniquemember=%s))", escapedUserDN),
|
||||||
[]string{"dn"},
|
[]string{"dn"},
|
||||||
@@ -214,7 +224,7 @@ func (ldap *LdapService) BindService(rebind bool) error {
|
|||||||
if ldap.cert != nil {
|
if ldap.cert != nil {
|
||||||
return ldap.conn.ExternalBind()
|
return ldap.conn.ExternalBind()
|
||||||
}
|
}
|
||||||
return ldap.conn.Bind(ldap.config.LDAP.BindDN, ldap.config.LDAP.BindPassword)
|
return ldap.conn.Bind(ldap.config.BindDN, ldap.config.BindPassword)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ldap *LdapService) Bind(userDN string, password string) error {
|
func (ldap *LdapService) Bind(userDN string, password string) error {
|
||||||
@@ -228,7 +238,7 @@ func (ldap *LdapService) Bind(userDN string, password string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (ldap *LdapService) heartbeat() error {
|
func (ldap *LdapService) heartbeat() error {
|
||||||
ldap.log.App.Debug().Msg("Performing LDAP connection heartbeat")
|
tlog.App.Debug().Msg("Performing LDAP connection heartbeat")
|
||||||
|
|
||||||
searchRequest := ldapgo.NewSearchRequest(
|
searchRequest := ldapgo.NewSearchRequest(
|
||||||
"",
|
"",
|
||||||
@@ -250,7 +260,7 @@ func (ldap *LdapService) heartbeat() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (ldap *LdapService) reconnect() error {
|
func (ldap *LdapService) reconnect() error {
|
||||||
ldap.log.App.Info().Msg("Attempting to reconnect to LDAP server")
|
tlog.App.Info().Msg("Reconnecting to LDAP server")
|
||||||
|
|
||||||
exp := backoff.NewExponentialBackOff()
|
exp := backoff.NewExponentialBackOff()
|
||||||
exp.InitialInterval = 500 * time.Millisecond
|
exp.InitialInterval = 500 * time.Millisecond
|
||||||
|
|||||||
@@ -1,10 +1,8 @@
|
|||||||
package service
|
package service
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||||
|
|
||||||
"slices"
|
"slices"
|
||||||
|
|
||||||
@@ -21,39 +19,33 @@ type OAuthServiceImpl interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type OAuthBrokerService struct {
|
type OAuthBrokerService struct {
|
||||||
log *logger.Logger
|
|
||||||
|
|
||||||
services map[string]OAuthServiceImpl
|
services map[string]OAuthServiceImpl
|
||||||
configs map[string]model.OAuthServiceConfig
|
configs map[string]model.OAuthServiceConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
var presets = map[string]func(config model.OAuthServiceConfig, ctx context.Context) *OAuthService{
|
var presets = map[string]func(config model.OAuthServiceConfig) *OAuthService{
|
||||||
"github": newGitHubOAuthService,
|
"github": newGitHubOAuthService,
|
||||||
"google": newGoogleOAuthService,
|
"google": newGoogleOAuthService,
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewOAuthBrokerService(
|
func NewOAuthBrokerService(configs map[string]model.OAuthServiceConfig) *OAuthBrokerService {
|
||||||
log *logger.Logger,
|
return &OAuthBrokerService{
|
||||||
configs map[string]model.OAuthServiceConfig,
|
|
||||||
ctx context.Context,
|
|
||||||
) *OAuthBrokerService {
|
|
||||||
service := &OAuthBrokerService{
|
|
||||||
log: log,
|
|
||||||
services: make(map[string]OAuthServiceImpl),
|
services: make(map[string]OAuthServiceImpl),
|
||||||
configs: configs,
|
configs: configs,
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
for name, cfg := range configs {
|
func (broker *OAuthBrokerService) Init() error {
|
||||||
|
for name, cfg := range broker.configs {
|
||||||
if presetFunc, exists := presets[name]; exists {
|
if presetFunc, exists := presets[name]; exists {
|
||||||
service.services[name] = presetFunc(cfg, ctx)
|
broker.services[name] = presetFunc(cfg)
|
||||||
service.log.App.Debug().Str("service", name).Msg("Loaded OAuth service from preset")
|
tlog.App.Debug().Str("service", name).Msg("Loaded OAuth service from preset")
|
||||||
} else {
|
} else {
|
||||||
service.services[name] = NewOAuthService(cfg, name, ctx)
|
broker.services[name] = NewOAuthService(cfg, name)
|
||||||
service.log.App.Debug().Str("service", name).Msg("Loaded OAuth service from custom config")
|
tlog.App.Debug().Str("service", name).Msg("Loaded OAuth service from config")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return nil
|
||||||
return service
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (broker *OAuthBrokerService) GetConfiguredServices() []string {
|
func (broker *OAuthBrokerService) GetConfiguredServices() []string {
|
||||||
|
|||||||
@@ -1,25 +1,23 @@
|
|||||||
package service
|
package service
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"golang.org/x/oauth2/endpoints"
|
"golang.org/x/oauth2/endpoints"
|
||||||
)
|
)
|
||||||
|
|
||||||
func newGoogleOAuthService(config model.OAuthServiceConfig, ctx context.Context) *OAuthService {
|
func newGoogleOAuthService(config model.OAuthServiceConfig) *OAuthService {
|
||||||
scopes := []string{"openid", "email", "profile"}
|
scopes := []string{"openid", "email", "profile"}
|
||||||
config.Scopes = scopes
|
config.Scopes = scopes
|
||||||
config.AuthURL = endpoints.Google.AuthURL
|
config.AuthURL = endpoints.Google.AuthURL
|
||||||
config.TokenURL = endpoints.Google.TokenURL
|
config.TokenURL = endpoints.Google.TokenURL
|
||||||
config.UserinfoURL = "https://openidconnect.googleapis.com/v1/userinfo"
|
config.UserinfoURL = "https://openidconnect.googleapis.com/v1/userinfo"
|
||||||
return NewOAuthService(config, "google", ctx)
|
return NewOAuthService(config, "google")
|
||||||
}
|
}
|
||||||
|
|
||||||
func newGitHubOAuthService(config model.OAuthServiceConfig, ctx context.Context) *OAuthService {
|
func newGitHubOAuthService(config model.OAuthServiceConfig) *OAuthService {
|
||||||
scopes := []string{"read:user", "user:email"}
|
scopes := []string{"read:user", "user:email"}
|
||||||
config.Scopes = scopes
|
config.Scopes = scopes
|
||||||
config.AuthURL = endpoints.GitHub.AuthURL
|
config.AuthURL = endpoints.GitHub.AuthURL
|
||||||
config.TokenURL = endpoints.GitHub.TokenURL
|
config.TokenURL = endpoints.GitHub.TokenURL
|
||||||
return NewOAuthService(config, "github", ctx).WithUserinfoExtractor(githubExtractor)
|
return NewOAuthService(config, "github").WithUserinfoExtractor(githubExtractor)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ type OAuthService struct {
|
|||||||
id string
|
id string
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewOAuthService(config model.OAuthServiceConfig, id string, ctx context.Context) *OAuthService {
|
func NewOAuthService(config model.OAuthServiceConfig, id string) *OAuthService {
|
||||||
httpClient := &http.Client{
|
httpClient := &http.Client{
|
||||||
Timeout: 30 * time.Second,
|
Timeout: 30 * time.Second,
|
||||||
Transport: &http.Transport{
|
Transport: &http.Transport{
|
||||||
@@ -29,7 +29,8 @@ func NewOAuthService(config model.OAuthServiceConfig, id string, ctx context.Con
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
vctx := context.WithValue(ctx, oauth2.HTTPClient, httpClient)
|
ctx := context.Background()
|
||||||
|
ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient)
|
||||||
|
|
||||||
return &OAuthService{
|
return &OAuthService{
|
||||||
serviceCfg: config,
|
serviceCfg: config,
|
||||||
@@ -43,7 +44,7 @@ func NewOAuthService(config model.OAuthServiceConfig, id string, ctx context.Con
|
|||||||
TokenURL: config.TokenURL,
|
TokenURL: config.TokenURL,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
ctx: vctx,
|
ctx: ctx,
|
||||||
userinfoExtractor: defaultExtractor,
|
userinfoExtractor: defaultExtractor,
|
||||||
id: id,
|
id: id,
|
||||||
}
|
}
|
||||||
|
|||||||
+130
-138
@@ -16,7 +16,6 @@ import (
|
|||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"slices"
|
"slices"
|
||||||
@@ -26,7 +25,7 @@ import (
|
|||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -112,173 +111,172 @@ type AuthorizeRequest struct {
|
|||||||
CodeChallengeMethod string `json:"code_challenge_method"`
|
CodeChallengeMethod string `json:"code_challenge_method"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type OIDCService struct {
|
type OIDCServiceConfig struct {
|
||||||
log *logger.Logger
|
Clients map[string]model.OIDCClientConfig
|
||||||
config model.Config
|
PrivateKeyPath string
|
||||||
runtime model.RuntimeConfig
|
PublicKeyPath string
|
||||||
queries *repository.Queries
|
Issuer string
|
||||||
context context.Context
|
SessionExpiry int
|
||||||
|
|
||||||
clients map[string]model.OIDCClientConfig
|
|
||||||
privateKey *rsa.PrivateKey
|
|
||||||
publicKey crypto.PublicKey
|
|
||||||
issuer string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewOIDCService(
|
type OIDCService struct {
|
||||||
log *logger.Logger,
|
config OIDCServiceConfig
|
||||||
config model.Config,
|
queries *repository.Queries
|
||||||
runtime model.RuntimeConfig,
|
clients map[string]model.OIDCClientConfig
|
||||||
queries *repository.Queries,
|
privateKey *rsa.PrivateKey
|
||||||
ctx context.Context,
|
publicKey crypto.PublicKey
|
||||||
wg *sync.WaitGroup) (*OIDCService, error) {
|
issuer string
|
||||||
|
isConfigured bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewOIDCService(config OIDCServiceConfig, queries *repository.Queries) *OIDCService {
|
||||||
|
return &OIDCService{
|
||||||
|
config: config,
|
||||||
|
queries: queries,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (service *OIDCService) IsConfigured() bool {
|
||||||
|
return service.isConfigured
|
||||||
|
}
|
||||||
|
|
||||||
|
func (service *OIDCService) Init() error {
|
||||||
// If not configured, skip init
|
// If not configured, skip init
|
||||||
if len(runtime.OIDCClients) == 0 {
|
if len(service.config.Clients) == 0 {
|
||||||
return nil, nil
|
service.isConfigured = false
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
service.isConfigured = true
|
||||||
|
|
||||||
// Ensure issuer is https
|
// Ensure issuer is https
|
||||||
uissuer, err := url.Parse(runtime.AppURL)
|
uissuer, err := url.Parse(service.config.Issuer)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to parse app url: %w", err)
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if uissuer.Scheme != "https" {
|
if uissuer.Scheme != "https" {
|
||||||
return nil, errors.New("issuer must be https")
|
return errors.New("issuer must be https")
|
||||||
}
|
}
|
||||||
|
|
||||||
issuer := fmt.Sprintf("%s://%s", uissuer.Scheme, uissuer.Host)
|
service.issuer = fmt.Sprintf("%s://%s", uissuer.Scheme, uissuer.Host)
|
||||||
|
|
||||||
// Create/load private and public keys
|
// Create/load private and public keys
|
||||||
if strings.TrimSpace(config.OIDC.PrivateKeyPath) == "" ||
|
if strings.TrimSpace(service.config.PrivateKeyPath) == "" ||
|
||||||
strings.TrimSpace(config.OIDC.PublicKeyPath) == "" {
|
strings.TrimSpace(service.config.PublicKeyPath) == "" {
|
||||||
return nil, errors.New("private key path and public key path are required")
|
return errors.New("private key path and public key path are required")
|
||||||
}
|
}
|
||||||
|
|
||||||
var privateKey *rsa.PrivateKey
|
var privateKey *rsa.PrivateKey
|
||||||
|
|
||||||
fprivateKey, err := os.ReadFile(config.OIDC.PrivateKeyPath)
|
fprivateKey, err := os.ReadFile(service.config.PrivateKeyPath)
|
||||||
|
|
||||||
if err != nil && !errors.Is(err, os.ErrNotExist) {
|
if err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||||
return nil, err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if errors.Is(err, os.ErrNotExist) {
|
if errors.Is(err, os.ErrNotExist) {
|
||||||
privateKey, err = rsa.GenerateKey(rand.Reader, 2048)
|
privateKey, err = rsa.GenerateKey(rand.Reader, 2048)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to generate private key: %w", err)
|
return err
|
||||||
}
|
}
|
||||||
der := x509.MarshalPKCS1PrivateKey(privateKey)
|
der := x509.MarshalPKCS1PrivateKey(privateKey)
|
||||||
if der == nil {
|
if der == nil {
|
||||||
return nil, errors.New("failed to marshal private key")
|
return errors.New("failed to marshal private key")
|
||||||
}
|
}
|
||||||
encoded := pem.EncodeToMemory(&pem.Block{
|
encoded := pem.EncodeToMemory(&pem.Block{
|
||||||
Type: "RSA PRIVATE KEY",
|
Type: "RSA PRIVATE KEY",
|
||||||
Bytes: der,
|
Bytes: der,
|
||||||
})
|
})
|
||||||
log.App.Trace().Str("type", "RSA PRIVATE KEY").Msg("Generated private RSA key")
|
tlog.App.Trace().Str("type", "RSA PRIVATE KEY").Msg("Generated private RSA key")
|
||||||
err = os.WriteFile(config.OIDC.PrivateKeyPath, encoded, 0600)
|
err = os.WriteFile(service.config.PrivateKeyPath, encoded, 0600)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to write private key to file: %w", err)
|
return err
|
||||||
}
|
}
|
||||||
|
service.privateKey = privateKey
|
||||||
} else {
|
} else {
|
||||||
block, _ := pem.Decode(fprivateKey)
|
block, _ := pem.Decode(fprivateKey)
|
||||||
if block == nil {
|
if block == nil {
|
||||||
return nil, errors.New("failed to decode private key")
|
return errors.New("failed to decode private key")
|
||||||
}
|
}
|
||||||
log.App.Trace().Str("type", block.Type).Msg("Loaded private key")
|
tlog.App.Trace().Str("type", block.Type).Msg("Loaded private key")
|
||||||
privateKey, err = x509.ParsePKCS1PrivateKey(block.Bytes)
|
privateKey, err = x509.ParsePKCS1PrivateKey(block.Bytes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to parse private key: %w", err)
|
return err
|
||||||
}
|
}
|
||||||
|
service.privateKey = privateKey
|
||||||
}
|
}
|
||||||
|
|
||||||
var publicKey crypto.PublicKey
|
fpublicKey, err := os.ReadFile(service.config.PublicKeyPath)
|
||||||
|
|
||||||
fpublicKey, err := os.ReadFile(config.OIDC.PublicKeyPath)
|
|
||||||
|
|
||||||
if err != nil && !errors.Is(err, os.ErrNotExist) {
|
if err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||||
return nil, fmt.Errorf("failed to read public key: %w", err)
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if errors.Is(err, os.ErrNotExist) {
|
if errors.Is(err, os.ErrNotExist) {
|
||||||
publicKey = privateKey.Public()
|
publicKey := service.privateKey.Public()
|
||||||
der := x509.MarshalPKCS1PublicKey(publicKey.(*rsa.PublicKey))
|
der := x509.MarshalPKCS1PublicKey(publicKey.(*rsa.PublicKey))
|
||||||
if der == nil {
|
if der == nil {
|
||||||
return nil, errors.New("failed to marshal public key")
|
return errors.New("failed to marshal public key")
|
||||||
}
|
}
|
||||||
encoded := pem.EncodeToMemory(&pem.Block{
|
encoded := pem.EncodeToMemory(&pem.Block{
|
||||||
Type: "RSA PUBLIC KEY",
|
Type: "RSA PUBLIC KEY",
|
||||||
Bytes: der,
|
Bytes: der,
|
||||||
})
|
})
|
||||||
log.App.Trace().Str("type", "RSA PUBLIC KEY").Msg("Generated public RSA key")
|
tlog.App.Trace().Str("type", "RSA PUBLIC KEY").Msg("Generated public RSA key")
|
||||||
err = os.WriteFile(config.OIDC.PublicKeyPath, encoded, 0644)
|
err = os.WriteFile(service.config.PublicKeyPath, encoded, 0644)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return err
|
||||||
}
|
}
|
||||||
|
service.publicKey = publicKey
|
||||||
} else {
|
} else {
|
||||||
block, _ := pem.Decode(fpublicKey)
|
block, _ := pem.Decode(fpublicKey)
|
||||||
if block == nil {
|
if block == nil {
|
||||||
return nil, errors.New("failed to decode public key")
|
return errors.New("failed to decode public key")
|
||||||
}
|
}
|
||||||
log.App.Trace().Str("type", block.Type).Msg("Loaded public key")
|
tlog.App.Trace().Str("type", block.Type).Msg("Loaded public key")
|
||||||
switch block.Type {
|
switch block.Type {
|
||||||
case "RSA PUBLIC KEY":
|
case "RSA PUBLIC KEY":
|
||||||
publicKey, err = x509.ParsePKCS1PublicKey(block.Bytes)
|
publicKey, err := x509.ParsePKCS1PublicKey(block.Bytes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to parse public key: %w", err)
|
return err
|
||||||
}
|
}
|
||||||
|
service.publicKey = publicKey
|
||||||
case "PUBLIC KEY":
|
case "PUBLIC KEY":
|
||||||
publicKey, err = x509.ParsePKIXPublicKey(block.Bytes)
|
publicKey, err := x509.ParsePKIXPublicKey(block.Bytes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to parse public key: %w", err)
|
return err
|
||||||
}
|
}
|
||||||
|
service.publicKey = publicKey.(crypto.PublicKey)
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("unsupported public key type: %s", block.Type)
|
return fmt.Errorf("unsupported public key type: %s", block.Type)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// We will reorganize the client into a map with the client ID as the key
|
// We will reorganize the client into a map with the client ID as the key
|
||||||
clients := make(map[string]model.OIDCClientConfig)
|
service.clients = make(map[string]model.OIDCClientConfig)
|
||||||
|
|
||||||
for id, client := range config.OIDC.Clients {
|
for id, client := range service.config.Clients {
|
||||||
client.ID = id
|
client.ID = id
|
||||||
if client.Name == "" {
|
if client.Name == "" {
|
||||||
client.Name = utils.Capitalize(client.ID)
|
client.Name = utils.Capitalize(client.ID)
|
||||||
}
|
}
|
||||||
clients[client.ClientID] = client
|
service.clients[client.ClientID] = client
|
||||||
}
|
}
|
||||||
|
|
||||||
// Load the client secrets from files if they exist
|
// Load the client secrets from files if they exist
|
||||||
for id, client := range clients {
|
for id, client := range service.clients {
|
||||||
secret := utils.GetSecret(client.ClientSecret, client.ClientSecretFile)
|
secret := utils.GetSecret(client.ClientSecret, client.ClientSecretFile)
|
||||||
if secret != "" {
|
if secret != "" {
|
||||||
client.ClientSecret = secret
|
client.ClientSecret = secret
|
||||||
}
|
}
|
||||||
client.ClientSecretFile = ""
|
client.ClientSecretFile = ""
|
||||||
clients[id] = client
|
service.clients[id] = client
|
||||||
log.App.Debug().Str("clientId", client.ClientID).Msg("Loaded OIDC client configuration")
|
tlog.App.Info().Str("id", client.ID).Msg("Registered OIDC client")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialize the service
|
return nil
|
||||||
service := &OIDCService{
|
|
||||||
log: log,
|
|
||||||
config: config,
|
|
||||||
runtime: runtime,
|
|
||||||
queries: queries,
|
|
||||||
context: ctx,
|
|
||||||
|
|
||||||
clients: clients,
|
|
||||||
privateKey: privateKey,
|
|
||||||
publicKey: publicKey,
|
|
||||||
issuer: issuer,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start cleanup routine
|
|
||||||
wg.Go(service.cleanupRoutine)
|
|
||||||
|
|
||||||
return service, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (service *OIDCService) GetIssuer() string {
|
func (service *OIDCService) GetIssuer() string {
|
||||||
@@ -296,11 +294,6 @@ func (service *OIDCService) ValidateAuthorizeParams(req AuthorizeRequest) error
|
|||||||
if !ok {
|
if !ok {
|
||||||
return errors.New("access_denied")
|
return errors.New("access_denied")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Redirect URI to verify that it's trusted
|
|
||||||
if !slices.Contains(client.TrustedRedirectURIs, req.RedirectURI) {
|
|
||||||
return errors.New("invalid_request_uri")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Scopes
|
// Scopes
|
||||||
scopes := strings.Split(req.Scope, " ")
|
scopes := strings.Split(req.Scope, " ")
|
||||||
@@ -314,7 +307,7 @@ func (service *OIDCService) ValidateAuthorizeParams(req AuthorizeRequest) error
|
|||||||
return errors.New("invalid_scope")
|
return errors.New("invalid_scope")
|
||||||
}
|
}
|
||||||
if !slices.Contains(SupportedScopes, scope) {
|
if !slices.Contains(SupportedScopes, scope) {
|
||||||
service.log.App.Warn().Str("scope", scope).Msg("Requested unsupported scope")
|
tlog.App.Warn().Str("scope", scope).Msg("Unsupported OIDC scope, will be ignored")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -323,6 +316,11 @@ func (service *OIDCService) ValidateAuthorizeParams(req AuthorizeRequest) error
|
|||||||
return errors.New("unsupported_response_type")
|
return errors.New("unsupported_response_type")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Redirect URI
|
||||||
|
if !slices.Contains(client.TrustedRedirectURIs, req.RedirectURI) {
|
||||||
|
return errors.New("invalid_request_uri")
|
||||||
|
}
|
||||||
|
|
||||||
// PKCE code challenge method if set
|
// PKCE code challenge method if set
|
||||||
if req.CodeChallenge != "" && req.CodeChallengeMethod != "" {
|
if req.CodeChallenge != "" && req.CodeChallengeMethod != "" {
|
||||||
if req.CodeChallengeMethod != "S256" && req.CodeChallengeMethod != "plain" {
|
if req.CodeChallengeMethod != "S256" && req.CodeChallengeMethod != "plain" {
|
||||||
@@ -359,7 +357,7 @@ func (service *OIDCService) StoreCode(c *gin.Context, sub string, code string, r
|
|||||||
entry.CodeChallenge = req.CodeChallenge
|
entry.CodeChallenge = req.CodeChallenge
|
||||||
} else {
|
} else {
|
||||||
entry.CodeChallenge = service.hashAndEncodePKCE(req.CodeChallenge)
|
entry.CodeChallenge = service.hashAndEncodePKCE(req.CodeChallenge)
|
||||||
service.log.App.Warn().Msg("Using plain PKCE code challenge method is not recommended, consider switching to S256 for better security")
|
tlog.App.Warn().Msg("Received plain PKCE code challenge, it's recommended to use S256 for better security")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -451,7 +449,7 @@ func (service *OIDCService) GetCodeEntry(c *gin.Context, codeHash string, client
|
|||||||
|
|
||||||
func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user repository.OidcUserinfo, scope string, nonce string) (string, error) {
|
func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user repository.OidcUserinfo, scope string, nonce string) (string, error) {
|
||||||
createdAt := time.Now().Unix()
|
createdAt := time.Now().Unix()
|
||||||
expiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry) * time.Second).Unix()
|
expiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry) * time.Second).Unix()
|
||||||
|
|
||||||
hasher := sha256.New()
|
hasher := sha256.New()
|
||||||
|
|
||||||
@@ -531,16 +529,16 @@ func (service *OIDCService) GenerateAccessToken(c *gin.Context, client model.OID
|
|||||||
accessToken := utils.GenerateString(32)
|
accessToken := utils.GenerateString(32)
|
||||||
refreshToken := utils.GenerateString(32)
|
refreshToken := utils.GenerateString(32)
|
||||||
|
|
||||||
tokenExpiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry) * time.Second).Unix()
|
tokenExpiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry) * time.Second).Unix()
|
||||||
|
|
||||||
// Refresh token lives double the time of an access token but can't be used to access userinfo
|
// Refresh token lives double the time of an access token but can't be used to access userinfo
|
||||||
refreshTokenExpiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry*2) * time.Second).Unix()
|
refreshTokenExpiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry*2) * time.Second).Unix()
|
||||||
|
|
||||||
tokenResponse := TokenResponse{
|
tokenResponse := TokenResponse{
|
||||||
AccessToken: accessToken,
|
AccessToken: accessToken,
|
||||||
RefreshToken: refreshToken,
|
RefreshToken: refreshToken,
|
||||||
TokenType: "Bearer",
|
TokenType: "Bearer",
|
||||||
ExpiresIn: int64(service.config.Auth.SessionExpiry),
|
ExpiresIn: int64(service.config.SessionExpiry),
|
||||||
IDToken: idToken,
|
IDToken: idToken,
|
||||||
Scope: strings.ReplaceAll(codeEntry.Scope, ",", " "),
|
Scope: strings.ReplaceAll(codeEntry.Scope, ",", " "),
|
||||||
}
|
}
|
||||||
@@ -600,14 +598,14 @@ func (service *OIDCService) RefreshAccessToken(c *gin.Context, refreshToken stri
|
|||||||
accessToken := utils.GenerateString(32)
|
accessToken := utils.GenerateString(32)
|
||||||
newRefreshToken := utils.GenerateString(32)
|
newRefreshToken := utils.GenerateString(32)
|
||||||
|
|
||||||
tokenExpiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry) * time.Second).Unix()
|
tokenExpiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry) * time.Second).Unix()
|
||||||
refreshTokenExpiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry*2) * time.Second).Unix()
|
refreshTokenExpiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry*2) * time.Second).Unix()
|
||||||
|
|
||||||
tokenResponse := TokenResponse{
|
tokenResponse := TokenResponse{
|
||||||
AccessToken: accessToken,
|
AccessToken: accessToken,
|
||||||
RefreshToken: newRefreshToken,
|
RefreshToken: newRefreshToken,
|
||||||
TokenType: "Bearer",
|
TokenType: "Bearer",
|
||||||
ExpiresIn: int64(service.config.Auth.SessionExpiry),
|
ExpiresIn: int64(service.config.SessionExpiry),
|
||||||
IDToken: idToken,
|
IDToken: idToken,
|
||||||
Scope: strings.ReplaceAll(entry.Scope, ",", " "),
|
Scope: strings.ReplaceAll(entry.Scope, ",", " "),
|
||||||
}
|
}
|
||||||
@@ -750,62 +748,56 @@ func (service *OIDCService) DeleteOldSession(ctx context.Context, sub string) er
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Cleanup routine - Resource heavy due to the linked tables
|
// Cleanup routine - Resource heavy due to the linked tables
|
||||||
func (service *OIDCService) cleanupRoutine() {
|
func (service *OIDCService) Cleanup() {
|
||||||
service.log.App.Debug().Msg("Starting OIDC cleanup routine")
|
// We need a context for the routine
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
ticker := time.NewTicker(time.Duration(30) * time.Minute)
|
ticker := time.NewTicker(time.Duration(30) * time.Minute)
|
||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
|
|
||||||
for {
|
for range ticker.C {
|
||||||
select {
|
currentTime := time.Now().Unix()
|
||||||
case <-ticker.C:
|
|
||||||
service.log.App.Debug().Msg("Performing OIDC cleanup routine")
|
|
||||||
|
|
||||||
currentTime := time.Now().Unix()
|
// For the OIDC tokens, if they are expired we delete the userinfo and codes
|
||||||
|
expiredTokens, err := service.queries.DeleteExpiredOidcTokens(ctx, repository.DeleteExpiredOidcTokensParams{
|
||||||
|
TokenExpiresAt: currentTime,
|
||||||
|
RefreshTokenExpiresAt: currentTime,
|
||||||
|
})
|
||||||
|
|
||||||
// For the OIDC tokens, if they are expired we delete the userinfo and codes
|
if err != nil {
|
||||||
expiredTokens, err := service.queries.DeleteExpiredOidcTokens(service.context, repository.DeleteExpiredOidcTokensParams{
|
tlog.App.Warn().Err(err).Msg("Failed to delete expired tokens")
|
||||||
TokenExpiresAt: currentTime,
|
}
|
||||||
RefreshTokenExpiresAt: currentTime,
|
|
||||||
})
|
for _, expiredToken := range expiredTokens {
|
||||||
|
err := service.DeleteOldSession(ctx, expiredToken.Sub)
|
||||||
|
if err != nil {
|
||||||
|
tlog.App.Warn().Err(err).Msg("Failed to delete old session")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// For expired codes, we need to get the sub, check if tokens are expired and if they are remove everything
|
||||||
|
expiredCodes, err := service.queries.DeleteExpiredOidcCodes(ctx, currentTime)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
tlog.App.Warn().Err(err).Msg("Failed to delete expired codes")
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, expiredCode := range expiredCodes {
|
||||||
|
token, err := service.queries.GetOidcTokenBySub(ctx, expiredCode.Sub)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
service.log.App.Warn().Err(err).Msg("Failed to delete expired tokens")
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
}
|
|
||||||
|
|
||||||
for _, expiredToken := range expiredTokens {
|
|
||||||
err := service.DeleteOldSession(service.context, expiredToken.Sub)
|
|
||||||
if err != nil {
|
|
||||||
service.log.App.Warn().Err(err).Msg("Failed to delete session for expired token")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// For expired codes, we need to get the sub, check if tokens are expired and if they are remove everything
|
|
||||||
expiredCodes, err := service.queries.DeleteExpiredOidcCodes(service.context, currentTime)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
service.log.App.Warn().Err(err).Msg("Failed to delete expired codes")
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, expiredCode := range expiredCodes {
|
|
||||||
token, err := service.queries.GetOidcTokenBySub(service.context, expiredCode.Sub)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
service.log.App.Warn().Err(err).Msg("Failed to get token by sub for expired code")
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
tlog.App.Warn().Err(err).Msg("Failed to get OIDC token by sub")
|
||||||
if token.TokenExpiresAt < currentTime && token.RefreshTokenExpiresAt < currentTime {
|
|
||||||
err := service.DeleteOldSession(service.context, expiredCode.Sub)
|
|
||||||
if err != nil {
|
|
||||||
service.log.App.Warn().Err(err).Msg("Failed to delete session for expired code")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
service.log.App.Debug().Msg("Finished OIDC cleanup routine")
|
if token.TokenExpiresAt < currentTime && token.RefreshTokenExpiresAt < currentTime {
|
||||||
case <-service.context.Done():
|
err := service.DeleteOldSession(ctx, expiredCode.Sub)
|
||||||
service.log.App.Debug().Msg("Stopping OIDC cleanup routine")
|
if err != nil {
|
||||||
return
|
tlog.App.Warn().Err(err).Msg("Failed to delete session")
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,9 +1,7 @@
|
|||||||
package service_test
|
package service_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"sync"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
@@ -12,7 +10,6 @@ import (
|
|||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func newTestUser() repository.OidcUserinfo {
|
func newTestUser() repository.OidcUserinfo {
|
||||||
@@ -51,29 +48,13 @@ func newTestUser() repository.OidcUserinfo {
|
|||||||
|
|
||||||
func TestCompileUserinfo(t *testing.T) {
|
func TestCompileUserinfo(t *testing.T) {
|
||||||
dir := t.TempDir()
|
dir := t.TempDir()
|
||||||
|
svc := service.NewOIDCService(service.OIDCServiceConfig{
|
||||||
cfg := model.Config{
|
PrivateKeyPath: dir + "/key.pem",
|
||||||
OIDC: model.OIDCConfig{
|
PublicKeyPath: dir + "/key.pub",
|
||||||
PrivateKeyPath: dir + "/key.pem",
|
Issuer: "https://tinyauth.example.com",
|
||||||
PublicKeyPath: dir + "/key.pub",
|
SessionExpiry: 3600,
|
||||||
},
|
}, nil)
|
||||||
Auth: model.AuthConfig{
|
require.NoError(t, svc.Init())
|
||||||
SessionExpiry: 3600,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
runtime := model.RuntimeConfig{
|
|
||||||
AppURL: "https://tinyauth.example.com",
|
|
||||||
}
|
|
||||||
|
|
||||||
log := logger.NewLogger().WithTestConfig()
|
|
||||||
log.Init()
|
|
||||||
|
|
||||||
ctx := context.TODO()
|
|
||||||
wg := &sync.WaitGroup{}
|
|
||||||
|
|
||||||
svc, err := service.NewOIDCService(log, cfg, runtime, nil, ctx, wg)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
type testCase struct {
|
type testCase struct {
|
||||||
description string
|
description string
|
||||||
|
|||||||
@@ -1,106 +0,0 @@
|
|||||||
package test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"path/filepath"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
|
||||||
"golang.org/x/crypto/bcrypt"
|
|
||||||
)
|
|
||||||
|
|
||||||
var TestingTOTPSecret = "JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK"
|
|
||||||
|
|
||||||
func CreateTestConfigs(t *testing.T) (model.Config, model.RuntimeConfig) {
|
|
||||||
tempDir := t.TempDir()
|
|
||||||
|
|
||||||
config := model.Config{
|
|
||||||
UI: model.UIConfig{
|
|
||||||
Title: "Tinyauth Test",
|
|
||||||
ForgotPasswordMessage: "foo",
|
|
||||||
BackgroundImage: "/background.jpg",
|
|
||||||
WarningsEnabled: true,
|
|
||||||
},
|
|
||||||
OAuth: model.OAuthConfig{
|
|
||||||
AutoRedirect: "none",
|
|
||||||
},
|
|
||||||
OIDC: model.OIDCConfig{
|
|
||||||
Clients: map[string]model.OIDCClientConfig{
|
|
||||||
"test": {
|
|
||||||
ClientID: "some-client-id",
|
|
||||||
ClientSecret: "some-client-secret",
|
|
||||||
TrustedRedirectURIs: []string{"https://test.example.com/callback"},
|
|
||||||
Name: "Test Client",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
PrivateKeyPath: filepath.Join(tempDir, "key.pem"),
|
|
||||||
PublicKeyPath: filepath.Join(tempDir, "key.pub"),
|
|
||||||
},
|
|
||||||
Auth: model.AuthConfig{
|
|
||||||
SessionExpiry: 10,
|
|
||||||
LoginTimeout: 10,
|
|
||||||
LoginMaxRetries: 3,
|
|
||||||
},
|
|
||||||
Database: model.DatabaseConfig{
|
|
||||||
Path: filepath.Join(tempDir, "test.db"),
|
|
||||||
},
|
|
||||||
Resources: model.ResourcesConfig{
|
|
||||||
Enabled: true,
|
|
||||||
Path: filepath.Join(tempDir, "resources"),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
passwd, err := bcrypt.GenerateFromPassword([]byte("password"), bcrypt.DefaultCost)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
runtime := model.RuntimeConfig{
|
|
||||||
ConfiguredProviders: []model.Provider{
|
|
||||||
{
|
|
||||||
Name: "Local",
|
|
||||||
ID: "local",
|
|
||||||
OAuth: false,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
LocalUsers: []model.LocalUser{
|
|
||||||
{
|
|
||||||
Username: "testuser",
|
|
||||||
Password: string(passwd),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Username: "totpuser",
|
|
||||||
Password: string(passwd),
|
|
||||||
TOTPSecret: TestingTOTPSecret,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Username: "attruser",
|
|
||||||
Password: string(passwd),
|
|
||||||
Attributes: model.UserAttributes{
|
|
||||||
Name: "Alice Smith",
|
|
||||||
Email: "alice@example.com",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Username: "attrtotpuser",
|
|
||||||
Password: string(passwd),
|
|
||||||
TOTPSecret: TestingTOTPSecret,
|
|
||||||
Attributes: model.UserAttributes{
|
|
||||||
Name: "Bob Jones",
|
|
||||||
Email: "bob@example.com",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
CookieDomain: "example.com",
|
|
||||||
AppURL: "https://tinyauth.example.com",
|
|
||||||
SessionCookieName: "tinyauth-session",
|
|
||||||
OIDCClients: func() []model.OIDCClientConfig {
|
|
||||||
var clients []model.OIDCClientConfig
|
|
||||||
for id, client := range config.OIDC.Clients {
|
|
||||||
client.ID = id
|
|
||||||
clients = append(clients, client)
|
|
||||||
}
|
|
||||||
return clients
|
|
||||||
}(),
|
|
||||||
}
|
|
||||||
|
|
||||||
return config, runtime
|
|
||||||
}
|
|
||||||
@@ -7,6 +7,8 @@ import (
|
|||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||||
|
|
||||||
"github.com/weppos/publicsuffix-go/publicsuffix"
|
"github.com/weppos/publicsuffix-go/publicsuffix"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -20,12 +22,13 @@ func GetCookieDomain(u string) (string, error) {
|
|||||||
host := parsed.Hostname()
|
host := parsed.Hostname()
|
||||||
|
|
||||||
if netIP := net.ParseIP(host); netIP != nil {
|
if netIP := net.ParseIP(host); netIP != nil {
|
||||||
return "", errors.New("ip addresses not allowed")
|
return "", errors.New("IP addresses not allowed")
|
||||||
}
|
}
|
||||||
|
|
||||||
parts := strings.Split(host, ".")
|
parts := strings.Split(host, ".")
|
||||||
|
|
||||||
if len(parts) == 2 {
|
if len(parts) == 2 {
|
||||||
|
tlog.App.Warn().Msgf("Running on the root domain, cookies will be set for .%v", host)
|
||||||
return host, nil
|
return host, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -50,19 +53,7 @@ func GetStandaloneCookieDomain(u string) (string, error) {
|
|||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
host := parsed.Hostname()
|
return parsed.Hostname(), nil
|
||||||
|
|
||||||
if netIP := net.ParseIP(host); netIP != nil {
|
|
||||||
return "", errors.New("ip addresses not allowed")
|
|
||||||
}
|
|
||||||
|
|
||||||
parts := strings.Split(host, ".")
|
|
||||||
|
|
||||||
if len(parts) < 2 {
|
|
||||||
return "", errors.New("invalid app url")
|
|
||||||
}
|
|
||||||
|
|
||||||
return host, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func ParseFileToLine(content string) string {
|
func ParseFileToLine(content string) string {
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ func TestGetRootDomain(t *testing.T) {
|
|||||||
// IP address
|
// IP address
|
||||||
domain = "http://10.10.10.10"
|
domain = "http://10.10.10.10"
|
||||||
_, err = utils.GetCookieDomain(domain)
|
_, err = utils.GetCookieDomain(domain)
|
||||||
assert.ErrorContains(t, err, "ip addresses not allowed")
|
assert.ErrorContains(t, err, "IP addresses not allowed")
|
||||||
|
|
||||||
// Invalid URL
|
// Invalid URL
|
||||||
domain = "http://[::1]:namedport"
|
domain = "http://[::1]:namedport"
|
||||||
@@ -180,48 +180,3 @@ func TestIsRedirectSafe(t *testing.T) {
|
|||||||
result = utils.IsRedirectSafe(redirectURL, domain)
|
result = utils.IsRedirectSafe(redirectURL, domain)
|
||||||
assert.False(t, result)
|
assert.False(t, result)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGetStandaloneCookieDomain(t *testing.T) {
|
|
||||||
// Normal case
|
|
||||||
domain := "http://tinyauth.app"
|
|
||||||
expected := "tinyauth.app"
|
|
||||||
result, err := utils.GetStandaloneCookieDomain(domain)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, expected, result)
|
|
||||||
|
|
||||||
// URL with subdomain (full hostname is returned, no subdomain stripping)
|
|
||||||
domain = "http://sub.tinyauth.app"
|
|
||||||
expected = "sub.tinyauth.app"
|
|
||||||
result, err = utils.GetStandaloneCookieDomain(domain)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, expected, result)
|
|
||||||
|
|
||||||
// URL with port (port should be stripped)
|
|
||||||
domain = "http://tinyauth.app:8080"
|
|
||||||
expected = "tinyauth.app"
|
|
||||||
result, err = utils.GetStandaloneCookieDomain(domain)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, expected, result)
|
|
||||||
|
|
||||||
// URL with path
|
|
||||||
domain = "https://tinyauth.app/some/path"
|
|
||||||
expected = "tinyauth.app"
|
|
||||||
result, err = utils.GetStandaloneCookieDomain(domain)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, expected, result)
|
|
||||||
|
|
||||||
// IP address
|
|
||||||
domain = "http://10.10.10.10"
|
|
||||||
_, err = utils.GetStandaloneCookieDomain(domain)
|
|
||||||
assert.ErrorContains(t, err, "ip addresses not allowed")
|
|
||||||
|
|
||||||
// Invalid domain (only TLD)
|
|
||||||
domain = "com"
|
|
||||||
_, err = utils.GetStandaloneCookieDomain(domain)
|
|
||||||
assert.ErrorContains(t, err, "invalid app url")
|
|
||||||
|
|
||||||
// Invalid URL
|
|
||||||
domain = "http://[::1]:namedport"
|
|
||||||
_, err = utils.GetStandaloneCookieDomain(domain)
|
|
||||||
assert.ErrorContains(t, err, "parse \"http://[::1]:namedport\": invalid port \":namedport\" after host")
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,160 +0,0 @@
|
|||||||
package logger
|
|
||||||
|
|
||||||
import (
|
|
||||||
"io"
|
|
||||||
"os"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/rs/zerolog"
|
|
||||||
"github.com/rs/zerolog/log"
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Logger struct {
|
|
||||||
HTTP zerolog.Logger
|
|
||||||
App zerolog.Logger
|
|
||||||
config model.LogConfig
|
|
||||||
base zerolog.Logger
|
|
||||||
audit zerolog.Logger
|
|
||||||
writer io.Writer
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewLogger() *Logger {
|
|
||||||
return &Logger{
|
|
||||||
writer: os.Stderr,
|
|
||||||
config: model.LogConfig{
|
|
||||||
Level: "error",
|
|
||||||
Json: true,
|
|
||||||
Streams: model.LogStreams{
|
|
||||||
HTTP: model.LogStreamConfig{
|
|
||||||
Enabled: true,
|
|
||||||
},
|
|
||||||
App: model.LogStreamConfig{
|
|
||||||
Enabled: true,
|
|
||||||
},
|
|
||||||
// No reason to enable audit by default since it will be suppressed by the log level
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *Logger) WithConfig(cfg model.LogConfig) *Logger {
|
|
||||||
l.config = cfg
|
|
||||||
return l
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *Logger) WithSimpleConfig() *Logger {
|
|
||||||
l.config = model.LogConfig{
|
|
||||||
Level: "info",
|
|
||||||
Json: false,
|
|
||||||
Streams: model.LogStreams{
|
|
||||||
HTTP: model.LogStreamConfig{Enabled: true},
|
|
||||||
App: model.LogStreamConfig{Enabled: true},
|
|
||||||
Audit: model.LogStreamConfig{Enabled: false},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
return l
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *Logger) WithTestConfig() *Logger {
|
|
||||||
l.config = model.LogConfig{
|
|
||||||
Level: "trace",
|
|
||||||
Streams: model.LogStreams{
|
|
||||||
HTTP: model.LogStreamConfig{Enabled: true},
|
|
||||||
App: model.LogStreamConfig{Enabled: true},
|
|
||||||
Audit: model.LogStreamConfig{Enabled: true},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
return l
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *Logger) WithWriter(writer io.Writer) *Logger {
|
|
||||||
l.writer = writer
|
|
||||||
return l
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *Logger) Init() {
|
|
||||||
base := log.With().
|
|
||||||
Timestamp().
|
|
||||||
Logger().
|
|
||||||
Level(l.parseLogLevel(l.config.Level)).Output(l.writer)
|
|
||||||
|
|
||||||
if !l.config.Json {
|
|
||||||
base = base.Output(zerolog.ConsoleWriter{
|
|
||||||
Out: l.writer,
|
|
||||||
TimeFormat: time.RFC3339,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
if base.GetLevel() == zerolog.TraceLevel || base.GetLevel() == zerolog.DebugLevel {
|
|
||||||
base = base.With().Caller().Logger()
|
|
||||||
}
|
|
||||||
|
|
||||||
l.base = base
|
|
||||||
l.audit = l.createLogger("audit", l.config.Streams.Audit)
|
|
||||||
l.HTTP = l.createLogger("http", l.config.Streams.HTTP)
|
|
||||||
l.App = l.createLogger("app", l.config.Streams.App)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *Logger) parseLogLevel(level string) zerolog.Level {
|
|
||||||
if level == "" {
|
|
||||||
return zerolog.InfoLevel
|
|
||||||
}
|
|
||||||
parsed, err := zerolog.ParseLevel(strings.ToLower(level))
|
|
||||||
if err != nil {
|
|
||||||
log.Warn().Err(err).Str("level", level).Msg("Invalid log level, defaulting to error")
|
|
||||||
parsed = zerolog.ErrorLevel
|
|
||||||
}
|
|
||||||
return parsed
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *Logger) createLogger(component string, cfg model.LogStreamConfig) zerolog.Logger {
|
|
||||||
if !cfg.Enabled {
|
|
||||||
return zerolog.Nop()
|
|
||||||
}
|
|
||||||
sub := l.base.With().Str("stream", component).Logger()
|
|
||||||
if cfg.Level != "" {
|
|
||||||
sub = sub.Level(l.parseLogLevel(cfg.Level))
|
|
||||||
}
|
|
||||||
return sub
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *Logger) AuditLoginSuccess(username, provider, ip string) {
|
|
||||||
l.audit.Info().
|
|
||||||
CallerSkipFrame(1).
|
|
||||||
Str("event", "login").
|
|
||||||
Str("result", "success").
|
|
||||||
Str("username", username).
|
|
||||||
Str("provider", provider).
|
|
||||||
Str("ip", ip).
|
|
||||||
Send()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *Logger) AuditLoginFailure(username, provider, ip, reason string) {
|
|
||||||
l.audit.Warn().
|
|
||||||
CallerSkipFrame(1).
|
|
||||||
Str("event", "login").
|
|
||||||
Str("result", "failure").
|
|
||||||
Str("username", username).
|
|
||||||
Str("provider", provider).
|
|
||||||
Str("ip", ip).
|
|
||||||
Str("reason", reason).
|
|
||||||
Send()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *Logger) AuditLogout(username, provider, ip string) {
|
|
||||||
l.audit.Info().
|
|
||||||
CallerSkipFrame(1).
|
|
||||||
Str("event", "logout").
|
|
||||||
Str("result", "success").
|
|
||||||
Str("username", username).
|
|
||||||
Str("provider", provider).
|
|
||||||
Str("ip", ip).
|
|
||||||
Send()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Used for testing
|
|
||||||
func (l *Logger) GetConfig() model.LogConfig {
|
|
||||||
return l.config
|
|
||||||
}
|
|
||||||
@@ -1,173 +0,0 @@
|
|||||||
package logger_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"encoding/json"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/rs/zerolog"
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestLogger(t *testing.T) {
|
|
||||||
type testCase struct {
|
|
||||||
description string
|
|
||||||
run func(t *testing.T)
|
|
||||||
}
|
|
||||||
|
|
||||||
tests := []testCase{
|
|
||||||
{
|
|
||||||
description: "Should create a simple logger with the expected config",
|
|
||||||
run: func(t *testing.T) {
|
|
||||||
l := logger.NewLogger().WithSimpleConfig()
|
|
||||||
l.Init()
|
|
||||||
|
|
||||||
cfg := l.GetConfig()
|
|
||||||
|
|
||||||
assert.Equal(t, cfg, model.LogConfig{
|
|
||||||
Level: "info",
|
|
||||||
Json: false,
|
|
||||||
Streams: model.LogStreams{
|
|
||||||
HTTP: model.LogStreamConfig{Enabled: true},
|
|
||||||
App: model.LogStreamConfig{Enabled: true},
|
|
||||||
Audit: model.LogStreamConfig{Enabled: false},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Should create a test logger with the expected config",
|
|
||||||
run: func(t *testing.T) {
|
|
||||||
l := logger.NewLogger().WithTestConfig()
|
|
||||||
l.Init()
|
|
||||||
|
|
||||||
cfg := l.GetConfig()
|
|
||||||
|
|
||||||
assert.Equal(t, cfg, model.LogConfig{
|
|
||||||
Level: "trace",
|
|
||||||
Json: false,
|
|
||||||
Streams: model.LogStreams{
|
|
||||||
HTTP: model.LogStreamConfig{Enabled: true},
|
|
||||||
App: model.LogStreamConfig{Enabled: true},
|
|
||||||
Audit: model.LogStreamConfig{Enabled: true},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Should create a logger with a custom config",
|
|
||||||
run: func(t *testing.T) {
|
|
||||||
customCfg := model.LogConfig{
|
|
||||||
Level: "debug",
|
|
||||||
Json: true,
|
|
||||||
Streams: model.LogStreams{
|
|
||||||
HTTP: model.LogStreamConfig{Enabled: false},
|
|
||||||
App: model.LogStreamConfig{Enabled: true},
|
|
||||||
Audit: model.LogStreamConfig{Enabled: false},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
l := logger.NewLogger().WithConfig(customCfg)
|
|
||||||
l.Init()
|
|
||||||
|
|
||||||
cfg := l.GetConfig()
|
|
||||||
|
|
||||||
assert.Equal(t, cfg, customCfg)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Default logger should use error type and log json",
|
|
||||||
run: func(t *testing.T) {
|
|
||||||
buf := bytes.Buffer{}
|
|
||||||
|
|
||||||
l := logger.NewLogger().WithWriter(&buf)
|
|
||||||
l.Init()
|
|
||||||
|
|
||||||
cfg := l.GetConfig()
|
|
||||||
|
|
||||||
assert.Equal(t, cfg, model.LogConfig{
|
|
||||||
Level: "error",
|
|
||||||
Json: true,
|
|
||||||
Streams: model.LogStreams{
|
|
||||||
HTTP: model.LogStreamConfig{Enabled: true},
|
|
||||||
App: model.LogStreamConfig{Enabled: true},
|
|
||||||
Audit: model.LogStreamConfig{Enabled: false},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
|
|
||||||
l.App.Error().Msg("test")
|
|
||||||
|
|
||||||
var entry map[string]any
|
|
||||||
err := json.Unmarshal(buf.Bytes(), &entry)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
assert.Equal(t, "test", entry["message"])
|
|
||||||
assert.Equal(t, "app", entry["stream"])
|
|
||||||
assert.Equal(t, "error", entry["level"])
|
|
||||||
assert.NotEmpty(t, entry["time"])
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Should default to error level if an invalid level is provided",
|
|
||||||
run: func(t *testing.T) {
|
|
||||||
buf := bytes.Buffer{}
|
|
||||||
|
|
||||||
customCfg := model.LogConfig{
|
|
||||||
Level: "invalid",
|
|
||||||
Json: false,
|
|
||||||
Streams: model.LogStreams{
|
|
||||||
HTTP: model.LogStreamConfig{Enabled: true},
|
|
||||||
App: model.LogStreamConfig{Enabled: true},
|
|
||||||
Audit: model.LogStreamConfig{Enabled: false},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
l := logger.NewLogger().WithConfig(customCfg).WithWriter(&buf)
|
|
||||||
l.Init()
|
|
||||||
|
|
||||||
assert.Equal(t, zerolog.ErrorLevel, l.App.GetLevel())
|
|
||||||
assert.Equal(t, zerolog.ErrorLevel, l.HTTP.GetLevel())
|
|
||||||
|
|
||||||
// should not get logged
|
|
||||||
l.AuditLoginFailure("test", "test", "test", "test")
|
|
||||||
|
|
||||||
assert.Empty(t, buf.String())
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Should use nop logger for disabled streams",
|
|
||||||
run: func(t *testing.T) {
|
|
||||||
buf := bytes.Buffer{}
|
|
||||||
|
|
||||||
customCfg := model.LogConfig{
|
|
||||||
Level: "info",
|
|
||||||
Json: false,
|
|
||||||
Streams: model.LogStreams{
|
|
||||||
HTTP: model.LogStreamConfig{Enabled: false},
|
|
||||||
App: model.LogStreamConfig{Enabled: true},
|
|
||||||
Audit: model.LogStreamConfig{Enabled: false},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
l := logger.NewLogger().WithConfig(customCfg).WithWriter(&buf)
|
|
||||||
l.Init()
|
|
||||||
|
|
||||||
assert.Equal(t, zerolog.Disabled, l.HTTP.GetLevel())
|
|
||||||
|
|
||||||
l.App.Info().Msg("test")
|
|
||||||
|
|
||||||
l.AuditLoginFailure("test_nop", "test_nop", "test_nop", "test_nop")
|
|
||||||
|
|
||||||
assert.NotEmpty(t, buf.String())
|
|
||||||
assert.NotContains(t, buf.String(), "test_nop")
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, test := range tests {
|
|
||||||
t.Run(test.description, test.run)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -28,41 +28,3 @@ func CoalesceToString(value any) string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func ParseNonEmptyLines(contents string) []string {
|
|
||||||
lines := make([]string, 0)
|
|
||||||
|
|
||||||
for line := range strings.SplitSeq(contents, "\n") {
|
|
||||||
lineTrimmed := strings.TrimSpace(line)
|
|
||||||
if lineTrimmed == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
lines = append(lines, lineTrimmed)
|
|
||||||
}
|
|
||||||
|
|
||||||
return lines
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetStringList(valuesCfg []string, valuesPath string) ([]string, error) {
|
|
||||||
values := make([]string, 0, len(valuesCfg))
|
|
||||||
|
|
||||||
for _, value := range valuesCfg {
|
|
||||||
valueTrimmed := strings.TrimSpace(value)
|
|
||||||
if valueTrimmed == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
values = append(values, valueTrimmed)
|
|
||||||
}
|
|
||||||
|
|
||||||
if valuesPath == "" {
|
|
||||||
return values, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
contents, err := ReadFile(valuesPath)
|
|
||||||
if err != nil {
|
|
||||||
return []string{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
values = append(values, ParseNonEmptyLines(contents)...)
|
|
||||||
return values, nil
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package utils_test
|
package utils_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"os"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
@@ -57,33 +56,3 @@ func TestCompileUserEmail(t *testing.T) {
|
|||||||
// Test with invalid email
|
// Test with invalid email
|
||||||
assert.Equal(t, "user@example.com", utils.CompileUserEmail("user", "example.com"))
|
assert.Equal(t, "user@example.com", utils.CompileUserEmail("user", "example.com"))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestParseNonEmptyLines(t *testing.T) {
|
|
||||||
lines := utils.ParseNonEmptyLines(" first@example.com \n\n second@example.com \n \n")
|
|
||||||
|
|
||||||
assert.Equal(t, []string{"first@example.com", "second@example.com"}, lines)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGetStringList(t *testing.T) {
|
|
||||||
file, err := os.Create("/tmp/tinyauth_list_test_file")
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
_, err = file.WriteString(" third@example.com \n\n fourth@example.com \n")
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
err = file.Close()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
defer os.Remove("/tmp/tinyauth_list_test_file")
|
|
||||||
|
|
||||||
values, err := utils.GetStringList([]string{" first@example.com ", "", "second@example.com"}, "/tmp/tinyauth_list_test_file")
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, []string{"first@example.com", "second@example.com", "third@example.com", "fourth@example.com"}, values)
|
|
||||||
|
|
||||||
values, err = utils.GetStringList(nil, "")
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, []string{}, values)
|
|
||||||
|
|
||||||
values, err = utils.GetStringList(nil, "/tmp/non_existing_list_file")
|
|
||||||
assert.ErrorContains(t, err, "no such file or directory")
|
|
||||||
assert.Equal(t, []string{}, values)
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -0,0 +1,39 @@
|
|||||||
|
package tlog
|
||||||
|
|
||||||
|
import "github.com/gin-gonic/gin"
|
||||||
|
|
||||||
|
// functions here use CallerSkipFrame to ensure correct caller info is logged
|
||||||
|
|
||||||
|
func AuditLoginSuccess(c *gin.Context, username, provider string) {
|
||||||
|
Audit.Info().
|
||||||
|
CallerSkipFrame(1).
|
||||||
|
Str("event", "login").
|
||||||
|
Str("result", "success").
|
||||||
|
Str("username", username).
|
||||||
|
Str("provider", provider).
|
||||||
|
Str("ip", c.ClientIP()).
|
||||||
|
Send()
|
||||||
|
}
|
||||||
|
|
||||||
|
func AuditLoginFailure(c *gin.Context, username, provider string, reason string) {
|
||||||
|
Audit.Warn().
|
||||||
|
CallerSkipFrame(1).
|
||||||
|
Str("event", "login").
|
||||||
|
Str("result", "failure").
|
||||||
|
Str("username", username).
|
||||||
|
Str("provider", provider).
|
||||||
|
Str("ip", c.ClientIP()).
|
||||||
|
Str("reason", reason).
|
||||||
|
Send()
|
||||||
|
}
|
||||||
|
|
||||||
|
func AuditLogout(c *gin.Context, username, provider string) {
|
||||||
|
Audit.Info().
|
||||||
|
CallerSkipFrame(1).
|
||||||
|
Str("event", "logout").
|
||||||
|
Str("result", "success").
|
||||||
|
Str("username", username).
|
||||||
|
Str("provider", provider).
|
||||||
|
Str("ip", c.ClientIP()).
|
||||||
|
Send()
|
||||||
|
}
|
||||||
@@ -0,0 +1,97 @@
|
|||||||
|
package tlog
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/rs/zerolog"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Logger struct {
|
||||||
|
Audit zerolog.Logger
|
||||||
|
HTTP zerolog.Logger
|
||||||
|
App zerolog.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
Audit zerolog.Logger
|
||||||
|
HTTP zerolog.Logger
|
||||||
|
App zerolog.Logger
|
||||||
|
)
|
||||||
|
|
||||||
|
func NewLogger(cfg model.LogConfig) *Logger {
|
||||||
|
baseLogger := log.With().
|
||||||
|
Timestamp().
|
||||||
|
Caller().
|
||||||
|
Logger().
|
||||||
|
Level(parseLogLevel(cfg.Level))
|
||||||
|
|
||||||
|
if !cfg.Json {
|
||||||
|
baseLogger = baseLogger.Output(zerolog.ConsoleWriter{
|
||||||
|
Out: os.Stderr,
|
||||||
|
TimeFormat: time.RFC3339,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Logger{
|
||||||
|
Audit: createLogger("audit", cfg.Streams.Audit, baseLogger),
|
||||||
|
HTTP: createLogger("http", cfg.Streams.HTTP, baseLogger),
|
||||||
|
App: createLogger("app", cfg.Streams.App, baseLogger),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewSimpleLogger() *Logger {
|
||||||
|
return NewLogger(model.LogConfig{
|
||||||
|
Level: "info",
|
||||||
|
Json: false,
|
||||||
|
Streams: model.LogStreams{
|
||||||
|
HTTP: model.LogStreamConfig{Enabled: true},
|
||||||
|
App: model.LogStreamConfig{Enabled: true},
|
||||||
|
Audit: model.LogStreamConfig{Enabled: false},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewTestLogger() *Logger {
|
||||||
|
return NewLogger(model.LogConfig{
|
||||||
|
Level: "trace",
|
||||||
|
Streams: model.LogStreams{
|
||||||
|
HTTP: model.LogStreamConfig{Enabled: true},
|
||||||
|
App: model.LogStreamConfig{Enabled: true},
|
||||||
|
Audit: model.LogStreamConfig{Enabled: true},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Logger) Init() {
|
||||||
|
Audit = l.Audit
|
||||||
|
HTTP = l.HTTP
|
||||||
|
App = l.App
|
||||||
|
}
|
||||||
|
|
||||||
|
func createLogger(component string, streamCfg model.LogStreamConfig, baseLogger zerolog.Logger) zerolog.Logger {
|
||||||
|
if !streamCfg.Enabled {
|
||||||
|
return zerolog.Nop()
|
||||||
|
}
|
||||||
|
subLogger := baseLogger.With().Str("log_stream", component).Logger()
|
||||||
|
// override level if specified, otherwise use base level
|
||||||
|
if streamCfg.Level != "" {
|
||||||
|
subLogger = subLogger.Level(parseLogLevel(streamCfg.Level))
|
||||||
|
}
|
||||||
|
return subLogger
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseLogLevel(level string) zerolog.Level {
|
||||||
|
if level == "" {
|
||||||
|
return zerolog.InfoLevel
|
||||||
|
}
|
||||||
|
parsedLevel, err := zerolog.ParseLevel(strings.ToLower(level))
|
||||||
|
if err != nil {
|
||||||
|
log.Warn().Err(err).Str("level", level).Msg("Invalid log level, defaulting to info")
|
||||||
|
parsedLevel = zerolog.InfoLevel
|
||||||
|
}
|
||||||
|
return parsedLevel
|
||||||
|
}
|
||||||
@@ -0,0 +1,93 @@
|
|||||||
|
package tlog_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||||
|
|
||||||
|
"github.com/rs/zerolog"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewLogger(t *testing.T) {
|
||||||
|
cfg := model.LogConfig{
|
||||||
|
Level: "debug",
|
||||||
|
Json: true,
|
||||||
|
Streams: model.LogStreams{
|
||||||
|
HTTP: model.LogStreamConfig{Enabled: true, Level: "info"},
|
||||||
|
App: model.LogStreamConfig{Enabled: true, Level: ""},
|
||||||
|
Audit: model.LogStreamConfig{Enabled: false, Level: ""},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
logger := tlog.NewLogger(cfg)
|
||||||
|
|
||||||
|
assert.NotNil(t, logger)
|
||||||
|
assert.Equal(t, zerolog.InfoLevel, logger.HTTP.GetLevel())
|
||||||
|
assert.Equal(t, zerolog.DebugLevel, logger.App.GetLevel())
|
||||||
|
assert.Equal(t, zerolog.Disabled, logger.Audit.GetLevel())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewSimpleLogger(t *testing.T) {
|
||||||
|
logger := tlog.NewSimpleLogger()
|
||||||
|
assert.NotNil(t, logger)
|
||||||
|
assert.Equal(t, zerolog.InfoLevel, logger.HTTP.GetLevel())
|
||||||
|
assert.Equal(t, zerolog.InfoLevel, logger.App.GetLevel())
|
||||||
|
assert.Equal(t, zerolog.Disabled, logger.Audit.GetLevel())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoggerInit(t *testing.T) {
|
||||||
|
logger := tlog.NewSimpleLogger()
|
||||||
|
logger.Init()
|
||||||
|
|
||||||
|
assert.NotEqual(t, zerolog.Disabled, tlog.App.GetLevel())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoggerWithDisabledStreams(t *testing.T) {
|
||||||
|
cfg := model.LogConfig{
|
||||||
|
Level: "info",
|
||||||
|
Json: false,
|
||||||
|
Streams: model.LogStreams{
|
||||||
|
HTTP: model.LogStreamConfig{Enabled: false},
|
||||||
|
App: model.LogStreamConfig{Enabled: false},
|
||||||
|
Audit: model.LogStreamConfig{Enabled: false},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
logger := tlog.NewLogger(cfg)
|
||||||
|
|
||||||
|
assert.Equal(t, zerolog.Disabled, logger.HTTP.GetLevel())
|
||||||
|
assert.Equal(t, zerolog.Disabled, logger.App.GetLevel())
|
||||||
|
assert.Equal(t, zerolog.Disabled, logger.Audit.GetLevel())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLogStreamField(t *testing.T) {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
|
||||||
|
cfg := model.LogConfig{
|
||||||
|
Level: "info",
|
||||||
|
Json: true,
|
||||||
|
Streams: model.LogStreams{
|
||||||
|
HTTP: model.LogStreamConfig{Enabled: true},
|
||||||
|
App: model.LogStreamConfig{Enabled: true},
|
||||||
|
Audit: model.LogStreamConfig{Enabled: true},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
logger := tlog.NewLogger(cfg)
|
||||||
|
|
||||||
|
// Override output for HTTP logger to capture output
|
||||||
|
logger.HTTP = logger.HTTP.Output(&buf)
|
||||||
|
|
||||||
|
logger.HTTP.Info().Msg("test message")
|
||||||
|
|
||||||
|
var logEntry map[string]interface{}
|
||||||
|
err := json.Unmarshal(buf.Bytes(), &logEntry)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, "http", logEntry["log_stream"])
|
||||||
|
assert.Equal(t, "test message", logEntry["message"])
|
||||||
|
}
|
||||||
@@ -13,7 +13,7 @@ func ParseUsers(usersStr []string, userAttributes map[string]model.UserAttribute
|
|||||||
var users []model.LocalUser
|
var users []model.LocalUser
|
||||||
|
|
||||||
if len(usersStr) == 0 {
|
if len(usersStr) == 0 {
|
||||||
return nil, nil
|
return &users, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, user := range usersStr {
|
for _, user := range usersStr {
|
||||||
@@ -34,9 +34,32 @@ func ParseUsers(usersStr []string, userAttributes map[string]model.UserAttribute
|
|||||||
}
|
}
|
||||||
|
|
||||||
func GetUsers(usersCfg []string, usersPath string, userAttributes map[string]model.UserAttributes) (*[]model.LocalUser, error) {
|
func GetUsers(usersCfg []string, usersPath string, userAttributes map[string]model.UserAttributes) (*[]model.LocalUser, error) {
|
||||||
usersStr, err := GetStringList(usersCfg, usersPath)
|
var usersStr []string
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
if len(usersCfg) == 0 && usersPath == "" {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(usersCfg) > 0 {
|
||||||
|
usersStr = append(usersStr, usersCfg...)
|
||||||
|
}
|
||||||
|
|
||||||
|
if usersPath != "" {
|
||||||
|
contents, err := ReadFile(usersPath)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
lines := strings.SplitSeq(contents, "\n")
|
||||||
|
|
||||||
|
for line := range lines {
|
||||||
|
lineTrimmed := strings.TrimSpace(line)
|
||||||
|
if lineTrimmed == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
usersStr = append(usersStr, lineTrimmed)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return ParseUsers(usersStr, userAttributes)
|
return ParseUsers(usersStr, userAttributes)
|
||||||
|
|||||||
Reference in New Issue
Block a user