mirror of
https://github.com/steveiliop56/tinyauth.git
synced 2026-05-06 04:18:10 +00:00
Compare commits
24 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 04b2290d73 | |||
| e04980468f | |||
| d47e4d3d79 | |||
| f3965a7470 | |||
| 36d4e3ec52 | |||
| eab9f71110 | |||
| e13598bf3c | |||
| 4d3860f860 | |||
| 3b5da06862 | |||
| 8f337aaff8 | |||
| ff3c25c09d | |||
| 26daef7d4e | |||
| c932817757 | |||
| 004df2f852 | |||
| df56708b9a | |||
| 62ffd2fd11 | |||
| a3ec07230c | |||
| b4eb7090bd | |||
| 2f24f823eb | |||
| 9a219046ac | |||
| 97d58b376d | |||
| b426a1529e | |||
| c7efb71a5a | |||
| eec75a6f49 |
@@ -84,7 +84,7 @@ jobs:
|
|||||||
- name: Build
|
- name: Build
|
||||||
run: |
|
run: |
|
||||||
cp -r frontend/dist internal/assets/dist
|
cp -r frontend/dist internal/assets/dist
|
||||||
go build -ldflags "-s -w -X github.com/tinyauthapp/tinyauth/internal/config.Version=${{ needs.generate-metadata.outputs.VERSION }} -X github.com/tinyauthapp/tinyauth/internal/config.CommitHash=${{ needs.generate-metadata.outputs.COMMIT_HASH }} -X github.com/tinyauthapp/tinyauth/internal/config.BuildTimestamp=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}" -o tinyauth-amd64 ./cmd/tinyauth
|
go build -ldflags "-s -w -X github.com/tinyauthapp/tinyauth/internal/model.Version=${{ needs.generate-metadata.outputs.VERSION }} -X github.com/tinyauthapp/tinyauth/internal/model.CommitHash=${{ needs.generate-metadata.outputs.COMMIT_HASH }} -X github.com/tinyauthapp/tinyauth/internal/model.BuildTimestamp=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}" -o tinyauth-amd64 ./cmd/tinyauth
|
||||||
env:
|
env:
|
||||||
CGO_ENABLED: 0
|
CGO_ENABLED: 0
|
||||||
|
|
||||||
@@ -130,7 +130,7 @@ jobs:
|
|||||||
- name: Build
|
- name: Build
|
||||||
run: |
|
run: |
|
||||||
cp -r frontend/dist internal/assets/dist
|
cp -r frontend/dist internal/assets/dist
|
||||||
go build -ldflags "-s -w -X github.com/tinyauthapp/tinyauth/internal/config.Version=${{ needs.generate-metadata.outputs.VERSION }} -X github.com/tinyauthapp/tinyauth/internal/config.CommitHash=${{ needs.generate-metadata.outputs.COMMIT_HASH }} -X github.com/tinyauthapp/tinyauth/internal/config.BuildTimestamp=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}" -o tinyauth-arm64 ./cmd/tinyauth
|
go build -ldflags "-s -w -X github.com/tinyauthapp/tinyauth/internal/model.Version=${{ needs.generate-metadata.outputs.VERSION }} -X github.com/tinyauthapp/tinyauth/internal/model.CommitHash=${{ needs.generate-metadata.outputs.COMMIT_HASH }} -X github.com/tinyauthapp/tinyauth/internal/model.BuildTimestamp=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}" -o tinyauth-arm64 ./cmd/tinyauth
|
||||||
env:
|
env:
|
||||||
CGO_ENABLED: 0
|
CGO_ENABLED: 0
|
||||||
|
|
||||||
|
|||||||
@@ -60,7 +60,7 @@ jobs:
|
|||||||
- name: Build
|
- name: Build
|
||||||
run: |
|
run: |
|
||||||
cp -r frontend/dist internal/assets/dist
|
cp -r frontend/dist internal/assets/dist
|
||||||
go build -ldflags "-s -w -X github.com/tinyauthapp/tinyauth/internal/config.Version=${{ needs.generate-metadata.outputs.VERSION }} -X github.com/tinyauthapp/tinyauth/internal/config.CommitHash=${{ needs.generate-metadata.outputs.COMMIT_HASH }} -X github.com/tinyauthapp/tinyauth/internal/config.BuildTimestamp=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}" -o tinyauth-amd64 ./cmd/tinyauth
|
go build -ldflags "-s -w -X github.com/tinyauthapp/tinyauth/internal/model.Version=${{ needs.generate-metadata.outputs.VERSION }} -X github.com/tinyauthapp/tinyauth/internal/model.CommitHash=${{ needs.generate-metadata.outputs.COMMIT_HASH }} -X github.com/tinyauthapp/tinyauth/internal/model.BuildTimestamp=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}" -o tinyauth-amd64 ./cmd/tinyauth
|
||||||
env:
|
env:
|
||||||
CGO_ENABLED: 0
|
CGO_ENABLED: 0
|
||||||
|
|
||||||
@@ -103,7 +103,7 @@ jobs:
|
|||||||
- name: Build
|
- name: Build
|
||||||
run: |
|
run: |
|
||||||
cp -r frontend/dist internal/assets/dist
|
cp -r frontend/dist internal/assets/dist
|
||||||
go build -ldflags "-s -w -X github.com/tinyauthapp/tinyauth/internal/config.Version=${{ needs.generate-metadata.outputs.VERSION }} -X github.com/tinyauthapp/tinyauth/internal/config.CommitHash=${{ needs.generate-metadata.outputs.COMMIT_HASH }} -X github.com/tinyauthapp/tinyauth/internal/config.BuildTimestamp=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}" -o tinyauth-arm64 ./cmd/tinyauth
|
go build -ldflags "-s -w -X github.com/tinyauthapp/tinyauth/internal/model.Version=${{ needs.generate-metadata.outputs.VERSION }} -X github.com/tinyauthapp/tinyauth/internal/model.CommitHash=${{ needs.generate-metadata.outputs.COMMIT_HASH }} -X github.com/tinyauthapp/tinyauth/internal/model.BuildTimestamp=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}" -o tinyauth-arm64 ./cmd/tinyauth
|
||||||
env:
|
env:
|
||||||
CGO_ENABLED: 0
|
CGO_ENABLED: 0
|
||||||
|
|
||||||
|
|||||||
+3
-3
@@ -38,9 +38,9 @@ COPY ./internal ./internal
|
|||||||
COPY --from=frontend-builder /frontend/dist ./internal/assets/dist
|
COPY --from=frontend-builder /frontend/dist ./internal/assets/dist
|
||||||
|
|
||||||
RUN CGO_ENABLED=0 go build -ldflags "-s -w \
|
RUN CGO_ENABLED=0 go build -ldflags "-s -w \
|
||||||
-X github.com/tinyauthapp/tinyauth/internal/config.Version=${VERSION} \
|
-X github.com/tinyauthapp/tinyauth/internal/model.Version=${VERSION} \
|
||||||
-X github.com/tinyauthapp/tinyauth/internal/config.CommitHash=${COMMIT_HASH} \
|
-X github.com/tinyauthapp/tinyauth/internal/model.CommitHash=${COMMIT_HASH} \
|
||||||
-X github.com/tinyauthapp/tinyauth/internal/config.BuildTimestamp=${BUILD_TIMESTAMP}" ./cmd/tinyauth
|
-X github.com/tinyauthapp/tinyauth/internal/model.BuildTimestamp=${BUILD_TIMESTAMP}" ./cmd/tinyauth
|
||||||
|
|
||||||
# Runner
|
# Runner
|
||||||
FROM alpine:3.23 AS runner
|
FROM alpine:3.23 AS runner
|
||||||
|
|||||||
@@ -40,9 +40,9 @@ COPY --from=frontend-builder /frontend/dist ./internal/assets/dist
|
|||||||
RUN mkdir -p data
|
RUN mkdir -p data
|
||||||
|
|
||||||
RUN CGO_ENABLED=0 go build -ldflags "-s -w \
|
RUN CGO_ENABLED=0 go build -ldflags "-s -w \
|
||||||
-X github.com/tinyauthapp/tinyauth/internal/config.Version=${VERSION} \
|
-X github.com/tinyauthapp/tinyauth/internal/model.Version=${VERSION} \
|
||||||
-X github.com/tinyauthapp/tinyauth/internal/config.CommitHash=${COMMIT_HASH} \
|
-X github.com/tinyauthapp/tinyauth/internal/model.CommitHash=${COMMIT_HASH} \
|
||||||
-X github.com/tinyauthapp/tinyauth/internal/config.BuildTimestamp=${BUILD_TIMESTAMP}" ./cmd/tinyauth
|
-X github.com/tinyauthapp/tinyauth/internal/model.BuildTimestamp=${BUILD_TIMESTAMP}" ./cmd/tinyauth
|
||||||
|
|
||||||
# Runner
|
# Runner
|
||||||
FROM gcr.io/distroless/static-debian12:latest AS runner
|
FROM gcr.io/distroless/static-debian12:latest AS runner
|
||||||
|
|||||||
@@ -37,9 +37,9 @@ webui: clean-webui
|
|||||||
# Build the binary
|
# Build the binary
|
||||||
binary: webui
|
binary: webui
|
||||||
CGO_ENABLED=$(CGO_ENABLED) go build -ldflags "-s -w \
|
CGO_ENABLED=$(CGO_ENABLED) go build -ldflags "-s -w \
|
||||||
-X github.com/tinyauthapp/tinyauth/internal/config.Version=${TAG_NAME} \
|
-X github.com/tinyauthapp/tinyauth/internal/model.Version=${TAG_NAME} \
|
||||||
-X github.com/tinyauthapp/tinyauth/internal/config.CommitHash=${COMMIT_HASH} \
|
-X github.com/tinyauthapp/tinyauth/internal/model.CommitHash=${COMMIT_HASH} \
|
||||||
-X github.com/tinyauthapp/tinyauth/internal/config.BuildTimestamp=${BUILD_TIMESTAMP}" \
|
-X github.com/tinyauthapp/tinyauth/internal/model.BuildTimestamp=${BUILD_TIMESTAMP}" \
|
||||||
-o ${BIN_NAME} ./cmd/tinyauth
|
-o ${BIN_NAME} ./cmd/tinyauth
|
||||||
|
|
||||||
# Build for amd64
|
# Build for amd64
|
||||||
|
|||||||
@@ -73,7 +73,7 @@ func generateTotpCmd() *cli.Command {
|
|||||||
docker = true
|
docker = true
|
||||||
}
|
}
|
||||||
|
|
||||||
if user.TotpSecret != "" {
|
if user.TOTPSecret != "" {
|
||||||
return fmt.Errorf("user already has a TOTP secret")
|
return fmt.Errorf("user already has a TOTP secret")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -102,14 +102,14 @@ func generateTotpCmd() *cli.Command {
|
|||||||
|
|
||||||
qrterminal.GenerateWithConfig(key.URL(), config)
|
qrterminal.GenerateWithConfig(key.URL(), config)
|
||||||
|
|
||||||
user.TotpSecret = secret
|
user.TOTPSecret = secret
|
||||||
|
|
||||||
// If using docker escape re-escape it
|
// If using docker escape re-escape it
|
||||||
if docker {
|
if docker {
|
||||||
user.Password = strings.ReplaceAll(user.Password, "$", "$$")
|
user.Password = strings.ReplaceAll(user.Password, "$", "$$")
|
||||||
}
|
}
|
||||||
|
|
||||||
tlog.App.Info().Str("user", fmt.Sprintf("%s:%s:%s", user.Username, user.Password, user.TotpSecret)).Msg("Add the totp secret to your authenticator app then use the verify command to ensure everything is working correctly.")
|
tlog.App.Info().Str("user", fmt.Sprintf("%s:%s:%s", user.Username, user.Password, user.TOTPSecret)).Msg("Add the totp secret to your authenticator app then use the verify command to ensure everything is working correctly.")
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import (
|
|||||||
|
|
||||||
"charm.land/huh/v2"
|
"charm.land/huh/v2"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/bootstrap"
|
"github.com/tinyauthapp/tinyauth/internal/bootstrap"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/config"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/loaders"
|
"github.com/tinyauthapp/tinyauth/internal/utils/loaders"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||||
|
|
||||||
@@ -14,7 +14,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
tConfig := config.NewDefaultConfiguration()
|
tConfig := model.NewDefaultConfiguration()
|
||||||
|
|
||||||
loaders := []cli.ResourceLoader{
|
loaders := []cli.ResourceLoader{
|
||||||
&loaders.FileLoader{},
|
&loaders.FileLoader{},
|
||||||
@@ -108,11 +108,11 @@ func main() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func runCmd(cfg config.Config) error {
|
func runCmd(cfg model.Config) error {
|
||||||
logger := tlog.NewLogger(cfg.Log)
|
logger := tlog.NewLogger(cfg.Log)
|
||||||
logger.Init()
|
logger.Init()
|
||||||
|
|
||||||
tlog.App.Info().Str("version", config.Version).Msg("Starting tinyauth")
|
tlog.App.Info().Str("version", model.Version).Msg("Starting tinyauth")
|
||||||
|
|
||||||
app := bootstrap.NewBootstrapApp(cfg)
|
app := bootstrap.NewBootstrapApp(cfg)
|
||||||
|
|
||||||
|
|||||||
@@ -95,7 +95,7 @@ func verifyUserCmd() *cli.Command {
|
|||||||
return fmt.Errorf("password is incorrect: %w", err)
|
return fmt.Errorf("password is incorrect: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if user.TotpSecret == "" {
|
if user.TOTPSecret == "" {
|
||||||
if tCfg.Totp != "" {
|
if tCfg.Totp != "" {
|
||||||
tlog.App.Warn().Msg("User does not have TOTP secret")
|
tlog.App.Warn().Msg("User does not have TOTP secret")
|
||||||
}
|
}
|
||||||
@@ -103,7 +103,7 @@ func verifyUserCmd() *cli.Command {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
ok := totp.Validate(tCfg.Totp, user.TotpSecret)
|
ok := totp.Validate(tCfg.Totp, user.TOTPSecret)
|
||||||
|
|
||||||
if !ok {
|
if !ok {
|
||||||
return fmt.Errorf("TOTP code incorrect")
|
return fmt.Errorf("TOTP code incorrect")
|
||||||
|
|||||||
@@ -3,9 +3,8 @@ package main
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/config"
|
|
||||||
|
|
||||||
"github.com/tinyauthapp/paerser/cli"
|
"github.com/tinyauthapp/paerser/cli"
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
func versionCmd() *cli.Command {
|
func versionCmd() *cli.Command {
|
||||||
@@ -15,9 +14,9 @@ func versionCmd() *cli.Command {
|
|||||||
Configuration: nil,
|
Configuration: nil,
|
||||||
Resources: nil,
|
Resources: nil,
|
||||||
Run: func(_ []string) error {
|
Run: func(_ []string) error {
|
||||||
fmt.Printf("Version: %s\n", config.Version)
|
fmt.Printf("Version: %s\n", model.Version)
|
||||||
fmt.Printf("Commit Hash: %s\n", config.CommitHash)
|
fmt.Printf("Commit Hash: %s\n", model.CommitHash)
|
||||||
fmt.Printf("Build Timestamp: %s\n", config.BuildTimestamp)
|
fmt.Printf("Build Timestamp: %s\n", model.BuildTimestamp)
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|||||||
+2
-2
@@ -10,7 +10,7 @@ import (
|
|||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/config"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
type EnvEntry struct {
|
type EnvEntry struct {
|
||||||
@@ -20,7 +20,7 @@ type EnvEntry struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func generateExampleEnv() {
|
func generateExampleEnv() {
|
||||||
cfg := config.NewDefaultConfiguration()
|
cfg := model.NewDefaultConfiguration()
|
||||||
entries := make([]EnvEntry, 0)
|
entries := make([]EnvEntry, 0)
|
||||||
|
|
||||||
root := reflect.TypeOf(cfg).Elem()
|
root := reflect.TypeOf(cfg).Elem()
|
||||||
|
|||||||
+2
-2
@@ -10,7 +10,7 @@ import (
|
|||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/config"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
type MarkdownEntry struct {
|
type MarkdownEntry struct {
|
||||||
@@ -21,7 +21,7 @@ type MarkdownEntry struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func generateMarkdown() {
|
func generateMarkdown() {
|
||||||
cfg := config.NewDefaultConfiguration()
|
cfg := model.NewDefaultConfiguration()
|
||||||
entries := make([]MarkdownEntry, 0)
|
entries := make([]MarkdownEntry, 0)
|
||||||
|
|
||||||
root := reflect.TypeOf(cfg).Elem()
|
root := reflect.TypeOf(cfg).Elem()
|
||||||
|
|||||||
@@ -20,7 +20,6 @@ require (
|
|||||||
github.com/weppos/publicsuffix-go v0.50.3
|
github.com/weppos/publicsuffix-go v0.50.3
|
||||||
golang.org/x/crypto v0.50.0
|
golang.org/x/crypto v0.50.0
|
||||||
golang.org/x/oauth2 v0.36.0
|
golang.org/x/oauth2 v0.36.0
|
||||||
gotest.tools/v3 v3.5.2
|
|
||||||
k8s.io/apimachinery v0.32.2
|
k8s.io/apimachinery v0.32.2
|
||||||
k8s.io/client-go v0.32.2
|
k8s.io/client-go v0.32.2
|
||||||
modernc.org/sqlite v1.49.1
|
modernc.org/sqlite v1.49.1
|
||||||
@@ -133,6 +132,7 @@ require (
|
|||||||
google.golang.org/protobuf v1.36.11 // indirect
|
google.golang.org/protobuf v1.36.11 // indirect
|
||||||
gopkg.in/inf.v0 v0.9.1 // indirect
|
gopkg.in/inf.v0 v0.9.1 // indirect
|
||||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||||
|
gotest.tools/v3 v3.5.2 // indirect
|
||||||
k8s.io/klog/v2 v2.130.1 // indirect
|
k8s.io/klog/v2 v2.130.1 // indirect
|
||||||
k8s.io/utils v0.0.0-20241104100929-3ea5e8cea738 // indirect
|
k8s.io/utils v0.0.0-20241104100929-3ea5e8cea738 // indirect
|
||||||
modernc.org/libc v1.72.0 // indirect
|
modernc.org/libc v1.72.0 // indirect
|
||||||
|
|||||||
@@ -12,15 +12,15 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/config"
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/controller"
|
"github.com/tinyauthapp/tinyauth/internal/controller"
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||||
)
|
)
|
||||||
|
|
||||||
type BootstrapApp struct {
|
type BootstrapApp struct {
|
||||||
config config.Config
|
config model.Config
|
||||||
context struct {
|
context struct {
|
||||||
appUrl string
|
appUrl string
|
||||||
uuid string
|
uuid string
|
||||||
@@ -29,15 +29,15 @@ type BootstrapApp struct {
|
|||||||
csrfCookieName string
|
csrfCookieName string
|
||||||
redirectCookieName string
|
redirectCookieName string
|
||||||
oauthSessionCookieName string
|
oauthSessionCookieName string
|
||||||
users []config.User
|
localUsers *[]model.LocalUser
|
||||||
oauthProviders map[string]config.OAuthServiceConfig
|
oauthProviders map[string]model.OAuthServiceConfig
|
||||||
configuredProviders []controller.Provider
|
configuredProviders []controller.Provider
|
||||||
oidcClients []config.OIDCClientConfig
|
oidcClients []model.OIDCClientConfig
|
||||||
}
|
}
|
||||||
services Services
|
services Services
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewBootstrapApp(config config.Config) *BootstrapApp {
|
func NewBootstrapApp(config model.Config) *BootstrapApp {
|
||||||
return &BootstrapApp{
|
return &BootstrapApp{
|
||||||
config: config,
|
config: config,
|
||||||
}
|
}
|
||||||
@@ -69,7 +69,7 @@ func (app *BootstrapApp) Setup() error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
app.context.users = users
|
app.context.localUsers = users
|
||||||
|
|
||||||
// Setup OAuth providers
|
// Setup OAuth providers
|
||||||
app.context.oauthProviders = app.config.OAuth.Providers
|
app.context.oauthProviders = app.config.OAuth.Providers
|
||||||
@@ -88,7 +88,7 @@ func (app *BootstrapApp) Setup() error {
|
|||||||
|
|
||||||
for id, provider := range app.context.oauthProviders {
|
for id, provider := range app.context.oauthProviders {
|
||||||
if provider.Name == "" {
|
if provider.Name == "" {
|
||||||
if name, ok := config.OverrideProviders[id]; ok {
|
if name, ok := model.OverrideProviders[id]; ok {
|
||||||
provider.Name = name
|
provider.Name = name
|
||||||
} else {
|
} else {
|
||||||
provider.Name = utils.Capitalize(id)
|
provider.Name = utils.Capitalize(id)
|
||||||
@@ -115,14 +115,14 @@ func (app *BootstrapApp) Setup() error {
|
|||||||
// Cookie names
|
// Cookie names
|
||||||
app.context.uuid = utils.GenerateUUID(appUrl.Hostname())
|
app.context.uuid = utils.GenerateUUID(appUrl.Hostname())
|
||||||
cookieId := strings.Split(app.context.uuid, "-")[0]
|
cookieId := strings.Split(app.context.uuid, "-")[0]
|
||||||
app.context.sessionCookieName = fmt.Sprintf("%s-%s", config.SessionCookieName, cookieId)
|
app.context.sessionCookieName = fmt.Sprintf("%s-%s", model.SessionCookieName, cookieId)
|
||||||
app.context.csrfCookieName = fmt.Sprintf("%s-%s", config.CSRFCookieName, cookieId)
|
app.context.csrfCookieName = fmt.Sprintf("%s-%s", model.CSRFCookieName, cookieId)
|
||||||
app.context.redirectCookieName = fmt.Sprintf("%s-%s", config.RedirectCookieName, cookieId)
|
app.context.redirectCookieName = fmt.Sprintf("%s-%s", model.RedirectCookieName, cookieId)
|
||||||
app.context.oauthSessionCookieName = fmt.Sprintf("%s-%s", config.OAuthSessionCookieName, cookieId)
|
app.context.oauthSessionCookieName = fmt.Sprintf("%s-%s", model.OAuthSessionCookieName, cookieId)
|
||||||
|
|
||||||
// Dumps
|
// Dumps
|
||||||
tlog.App.Trace().Interface("config", app.config).Msg("Config dump")
|
tlog.App.Trace().Interface("config", app.config).Msg("Config dump")
|
||||||
tlog.App.Trace().Interface("users", app.context.users).Msg("Users dump")
|
tlog.App.Trace().Interface("users", app.context.localUsers).Msg("Users dump")
|
||||||
tlog.App.Trace().Interface("oauthProviders", app.context.oauthProviders).Msg("OAuth providers dump")
|
tlog.App.Trace().Interface("oauthProviders", app.context.oauthProviders).Msg("OAuth providers dump")
|
||||||
tlog.App.Trace().Str("cookieDomain", app.context.cookieDomain).Msg("Cookie domain")
|
tlog.App.Trace().Str("cookieDomain", app.context.cookieDomain).Msg("Cookie domain")
|
||||||
tlog.App.Trace().Str("sessionCookieName", app.context.sessionCookieName).Msg("Session cookie name")
|
tlog.App.Trace().Str("sessionCookieName", app.context.sessionCookieName).Msg("Session cookie name")
|
||||||
@@ -171,7 +171,7 @@ func (app *BootstrapApp) Setup() error {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
if services.authService.LdapAuthConfigured() {
|
if services.authService.LDAPAuthConfigured() {
|
||||||
configuredProviders = append(configuredProviders, controller.Provider{
|
configuredProviders = append(configuredProviders, controller.Provider{
|
||||||
Name: "LDAP",
|
Name: "LDAP",
|
||||||
ID: "ldap",
|
ID: "ldap",
|
||||||
@@ -244,7 +244,7 @@ func (app *BootstrapApp) heartbeatRoutine() {
|
|||||||
var body heartbeat
|
var body heartbeat
|
||||||
|
|
||||||
body.UUID = app.context.uuid
|
body.UUID = app.context.uuid
|
||||||
body.Version = config.Version
|
body.Version = model.Version
|
||||||
|
|
||||||
bodyJson, err := json.Marshal(body)
|
bodyJson, err := json.Marshal(body)
|
||||||
|
|
||||||
@@ -257,7 +257,7 @@ func (app *BootstrapApp) heartbeatRoutine() {
|
|||||||
Timeout: 30 * time.Second, // The server should never take more than 30 seconds to respond
|
Timeout: 30 * time.Second, // The server should never take more than 30 seconds to respond
|
||||||
}
|
}
|
||||||
|
|
||||||
heartbeatURL := config.ApiServer + "/v1/instances/heartbeat"
|
heartbeatURL := model.APIServer + "/v1/instances/heartbeat"
|
||||||
|
|
||||||
for range ticker.C {
|
for range ticker.C {
|
||||||
tlog.App.Debug().Msg("Sending heartbeat")
|
tlog.App.Debug().Msg("Sending heartbeat")
|
||||||
|
|||||||
@@ -4,9 +4,9 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"slices"
|
"slices"
|
||||||
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/config"
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/controller"
|
"github.com/tinyauthapp/tinyauth/internal/controller"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/middleware"
|
"github.com/tinyauthapp/tinyauth/internal/middleware"
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
@@ -14,7 +14,7 @@ import (
|
|||||||
var DEV_MODES = []string{"main", "test", "development"}
|
var DEV_MODES = []string{"main", "test", "development"}
|
||||||
|
|
||||||
func (app *BootstrapApp) setupRouter() (*gin.Engine, error) {
|
func (app *BootstrapApp) setupRouter() (*gin.Engine, error) {
|
||||||
if !slices.Contains(DEV_MODES, config.Version) {
|
if !slices.Contains(DEV_MODES, model.Version) {
|
||||||
gin.SetMode(gin.ReleaseMode)
|
gin.SetMode(gin.ReleaseMode)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -31,6 +31,7 @@ func (app *BootstrapApp) setupRouter() (*gin.Engine, error) {
|
|||||||
|
|
||||||
contextMiddleware := middleware.NewContextMiddleware(middleware.ContextMiddlewareConfig{
|
contextMiddleware := middleware.NewContextMiddleware(middleware.ContextMiddlewareConfig{
|
||||||
CookieDomain: app.context.cookieDomain,
|
CookieDomain: app.context.cookieDomain,
|
||||||
|
SessionCookieName: app.context.sessionCookieName,
|
||||||
}, app.services.authService, app.services.oauthBrokerService)
|
}, app.services.authService, app.services.oauthBrokerService)
|
||||||
|
|
||||||
err := contextMiddleware.Init()
|
err := contextMiddleware.Init()
|
||||||
@@ -99,6 +100,7 @@ func (app *BootstrapApp) setupRouter() (*gin.Engine, error) {
|
|||||||
|
|
||||||
userController := controller.NewUserController(controller.UserControllerConfig{
|
userController := controller.NewUserController(controller.UserControllerConfig{
|
||||||
CookieDomain: app.context.cookieDomain,
|
CookieDomain: app.context.cookieDomain,
|
||||||
|
SessionCookieName: app.context.sessionCookieName,
|
||||||
}, apiRouter, app.services.authService)
|
}, apiRouter, app.services.authService)
|
||||||
|
|
||||||
userController.SetupRoutes()
|
userController.SetupRoutes()
|
||||||
|
|||||||
@@ -22,14 +22,14 @@ func (app *BootstrapApp) initServices(queries *repository.Queries) (Services, er
|
|||||||
services := Services{}
|
services := Services{}
|
||||||
|
|
||||||
ldapService := service.NewLdapService(service.LdapServiceConfig{
|
ldapService := service.NewLdapService(service.LdapServiceConfig{
|
||||||
Address: app.config.Ldap.Address,
|
Address: app.config.LDAP.Address,
|
||||||
BindDN: app.config.Ldap.BindDN,
|
BindDN: app.config.LDAP.BindDN,
|
||||||
BindPassword: app.config.Ldap.BindPassword,
|
BindPassword: app.config.LDAP.BindPassword,
|
||||||
BaseDN: app.config.Ldap.BaseDN,
|
BaseDN: app.config.LDAP.BaseDN,
|
||||||
Insecure: app.config.Ldap.Insecure,
|
Insecure: app.config.LDAP.Insecure,
|
||||||
SearchFilter: app.config.Ldap.SearchFilter,
|
SearchFilter: app.config.LDAP.SearchFilter,
|
||||||
AuthCert: app.config.Ldap.AuthCert,
|
AuthCert: app.config.LDAP.AuthCert,
|
||||||
AuthKey: app.config.Ldap.AuthKey,
|
AuthKey: app.config.LDAP.AuthKey,
|
||||||
})
|
})
|
||||||
|
|
||||||
err := ldapService.Init()
|
err := ldapService.Init()
|
||||||
@@ -89,7 +89,7 @@ func (app *BootstrapApp) initServices(queries *repository.Queries) (Services, er
|
|||||||
services.oauthBrokerService = oauthBrokerService
|
services.oauthBrokerService = oauthBrokerService
|
||||||
|
|
||||||
authService := service.NewAuthService(service.AuthServiceConfig{
|
authService := service.NewAuthService(service.AuthServiceConfig{
|
||||||
Users: app.context.users,
|
LocalUsers: app.context.localUsers,
|
||||||
OauthWhitelist: app.config.OAuth.Whitelist,
|
OauthWhitelist: app.config.OAuth.Whitelist,
|
||||||
SessionExpiry: app.config.Auth.SessionExpiry,
|
SessionExpiry: app.config.Auth.SessionExpiry,
|
||||||
SessionMaxLifetime: app.config.Auth.SessionMaxLifetime,
|
SessionMaxLifetime: app.config.Auth.SessionMaxLifetime,
|
||||||
@@ -99,7 +99,7 @@ func (app *BootstrapApp) initServices(queries *repository.Queries) (Services, er
|
|||||||
LoginMaxRetries: app.config.Auth.LoginMaxRetries,
|
LoginMaxRetries: app.config.Auth.LoginMaxRetries,
|
||||||
SessionCookieName: app.context.sessionCookieName,
|
SessionCookieName: app.context.sessionCookieName,
|
||||||
IP: app.config.Auth.IP,
|
IP: app.config.Auth.IP,
|
||||||
LDAPGroupsCacheTTL: app.config.Ldap.GroupCacheTTL,
|
LDAPGroupsCacheTTL: app.config.LDAP.GroupCacheTTL,
|
||||||
}, services.ldapService, queries, services.oauthBrokerService)
|
}, services.ldapService, queries, services.oauthBrokerService)
|
||||||
|
|
||||||
err = authService.Init()
|
err = authService.Init()
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@@ -19,7 +19,7 @@ type UserContextResponse struct {
|
|||||||
Email string `json:"email"`
|
Email string `json:"email"`
|
||||||
Provider string `json:"provider"`
|
Provider string `json:"provider"`
|
||||||
OAuth bool `json:"oauth"`
|
OAuth bool `json:"oauth"`
|
||||||
TotpPending bool `json:"totpPending"`
|
TOTPPending bool `json:"totpPending"`
|
||||||
OAuthName string `json:"oauthName"`
|
OAuthName string `json:"oauthName"`
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -76,28 +76,29 @@ func (controller *ContextController) SetupRoutes() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (controller *ContextController) userContextHandler(c *gin.Context) {
|
func (controller *ContextController) userContextHandler(c *gin.Context) {
|
||||||
context, err := utils.GetContext(c)
|
context, err := new(model.UserContext).NewFromGin(c)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
tlog.App.Debug().Err(err).Msg("No user context found in request")
|
||||||
|
c.JSON(200, UserContextResponse{
|
||||||
|
Status: 401,
|
||||||
|
Message: "Unauthorized",
|
||||||
|
IsLoggedIn: false,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
userContext := UserContextResponse{
|
userContext := UserContextResponse{
|
||||||
Status: 200,
|
Status: 200,
|
||||||
Message: "Success",
|
Message: "Success",
|
||||||
IsLoggedIn: context.IsLoggedIn,
|
IsLoggedIn: context.Authenticated,
|
||||||
Username: context.Username,
|
Username: context.GetUsername(),
|
||||||
Name: context.Name,
|
Name: context.GetName(),
|
||||||
Email: context.Email,
|
Email: context.GetEmail(),
|
||||||
Provider: context.Provider,
|
Provider: context.ProviderName(),
|
||||||
OAuth: context.OAuth,
|
OAuth: context.IsOAuth(),
|
||||||
TotpPending: context.TotpPending,
|
TOTPPending: context.TOTPPending(),
|
||||||
OAuthName: context.OAuthName,
|
OAuthName: context.OAuthName(),
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
tlog.App.Debug().Err(err).Msg("No user context found in request")
|
|
||||||
userContext.Status = 401
|
|
||||||
userContext.Message = "Unauthorized"
|
|
||||||
userContext.IsLoggedIn = false
|
|
||||||
c.JSON(200, userContext)
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
c.JSON(200, userContext)
|
c.JSON(200, userContext)
|
||||||
|
|||||||
@@ -7,11 +7,11 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/config"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/controller"
|
"github.com/tinyauthapp/tinyauth/internal/controller"
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestContextController(t *testing.T) {
|
func TestContextController(t *testing.T) {
|
||||||
@@ -79,12 +79,16 @@ func TestContextController(t *testing.T) {
|
|||||||
description: "Ensure user context returns when authorized",
|
description: "Ensure user context returns when authorized",
|
||||||
middlewares: []gin.HandlerFunc{
|
middlewares: []gin.HandlerFunc{
|
||||||
func(c *gin.Context) {
|
func(c *gin.Context) {
|
||||||
c.Set("context", &config.UserContext{
|
c.Set("context", &model.UserContext{
|
||||||
|
Authenticated: true,
|
||||||
|
Provider: model.ProviderLocal,
|
||||||
|
Local: &model.LocalContext{
|
||||||
|
BaseContext: model.BaseContext{
|
||||||
Username: "johndoe",
|
Username: "johndoe",
|
||||||
Name: "John Doe",
|
Name: "John Doe",
|
||||||
Email: utils.CompileUserEmail("johndoe", controllerConfig.CookieDomain),
|
Email: utils.CompileUserEmail("johndoe", controllerConfig.CookieDomain),
|
||||||
Provider: "local",
|
},
|
||||||
IsLoggedIn: true,
|
},
|
||||||
})
|
})
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -0,0 +1,12 @@
|
|||||||
|
package controller
|
||||||
|
|
||||||
|
type UnauthorizedQuery struct {
|
||||||
|
Username string `url:"username"`
|
||||||
|
Resource string `url:"resource"`
|
||||||
|
GroupErr bool `url:"groupErr"`
|
||||||
|
IP string `url:"ip"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type RedirectQuery struct {
|
||||||
|
RedirectURI string `url:"redirect_uri"`
|
||||||
|
}
|
||||||
@@ -6,7 +6,6 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/config"
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||||
@@ -176,7 +175,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
|
|||||||
tlog.App.Warn().Str("email", user.Email).Msg("Email not whitelisted")
|
tlog.App.Warn().Str("email", user.Email).Msg("Email not whitelisted")
|
||||||
tlog.AuditLoginFailure(c, user.Email, req.Provider, "email not whitelisted")
|
tlog.AuditLoginFailure(c, user.Email, req.Provider, "email not whitelisted")
|
||||||
|
|
||||||
queries, err := query.Values(config.UnauthorizedQuery{
|
queries, err := query.Values(UnauthorizedQuery{
|
||||||
Username: user.Email,
|
Username: user.Email,
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -236,7 +235,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
|
|||||||
|
|
||||||
tlog.App.Trace().Interface("session_cookie", sessionCookie).Msg("Creating session cookie")
|
tlog.App.Trace().Interface("session_cookie", sessionCookie).Msg("Creating session cookie")
|
||||||
|
|
||||||
err = controller.auth.CreateSessionCookie(c, &sessionCookie)
|
cookie, err := controller.auth.CreateSession(c, sessionCookie)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tlog.App.Error().Err(err).Msg("Failed to create session cookie")
|
tlog.App.Error().Err(err).Msg("Failed to create session cookie")
|
||||||
@@ -244,6 +243,8 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
http.SetCookie(c.Writer, cookie)
|
||||||
|
|
||||||
tlog.AuditLoginSuccess(c, sessionCookie.Username, sessionCookie.Provider)
|
tlog.AuditLoginSuccess(c, sessionCookie.Username, sessionCookie.Provider)
|
||||||
|
|
||||||
if controller.isOidcRequest(oauthPendingSession.CallbackParams) {
|
if controller.isOidcRequest(oauthPendingSession.CallbackParams) {
|
||||||
@@ -259,7 +260,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if oauthPendingSession.CallbackParams.RedirectURI != "" {
|
if oauthPendingSession.CallbackParams.RedirectURI != "" {
|
||||||
queries, err := query.Values(config.RedirectQuery{
|
queries, err := query.Values(RedirectQuery{
|
||||||
RedirectURI: oauthPendingSession.CallbackParams.RedirectURI,
|
RedirectURI: oauthPendingSession.CallbackParams.RedirectURI,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/google/go-querystring/query"
|
"github.com/google/go-querystring/query"
|
||||||
|
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||||
@@ -111,14 +112,14 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
userContext, err := utils.GetContext(c)
|
userContext, err := new(model.UserContext).NewFromGin(c)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.authorizeError(c, err, "Failed to get user context", "User is not logged in or the session is invalid", "", "", "")
|
controller.authorizeError(c, err, "Failed to get user context", "User is not logged in or the session is invalid", "", "", "")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if !userContext.IsLoggedIn {
|
if !userContext.Authenticated {
|
||||||
controller.authorizeError(c, errors.New("err user not logged in"), "User not logged in", "The user is not logged in", "", "", "")
|
controller.authorizeError(c, errors.New("err user not logged in"), "User not logged in", "The user is not logged in", "", "", "")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -151,7 +152,7 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// WARNING: Since Tinyauth is stateless, we cannot have a sub that never changes. We will just create a uuid out of the username and client name which remains stable, but if username or client name changes then sub changes too.
|
// WARNING: Since Tinyauth is stateless, we cannot have a sub that never changes. We will just create a uuid out of the username and client name which remains stable, but if username or client name changes then sub changes too.
|
||||||
sub := utils.GenerateUUID(fmt.Sprintf("%s:%s", userContext.Username, client.ID))
|
sub := utils.GenerateUUID(fmt.Sprintf("%s:%s", userContext.GetUsername(), client.ID))
|
||||||
code := utils.GenerateString(32)
|
code := utils.GenerateString(32)
|
||||||
|
|
||||||
// Before storing the code, delete old session
|
// Before storing the code, delete old session
|
||||||
@@ -170,7 +171,7 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
|
|||||||
|
|
||||||
// We also need a snapshot of the user that authorized this (skip if no openid scope)
|
// We also need a snapshot of the user that authorized this (skip if no openid scope)
|
||||||
if slices.Contains(strings.Fields(req.Scope), "openid") {
|
if slices.Contains(strings.Fields(req.Scope), "openid") {
|
||||||
err = controller.oidc.StoreUserinfo(c, sub, userContext, req)
|
err = controller.oidc.StoreUserinfo(c, sub, *userContext, req)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tlog.App.Error().Err(err).Msg("Failed to insert user info into database")
|
tlog.App.Error().Err(err).Msg("Failed to insert user info into database")
|
||||||
|
|||||||
@@ -12,14 +12,14 @@ import (
|
|||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/google/go-querystring/query"
|
"github.com/google/go-querystring/query"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/bootstrap"
|
"github.com/tinyauthapp/tinyauth/internal/bootstrap"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/config"
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/controller"
|
"github.com/tinyauthapp/tinyauth/internal/controller"
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestOIDCController(t *testing.T) {
|
func TestOIDCController(t *testing.T) {
|
||||||
@@ -27,7 +27,7 @@ func TestOIDCController(t *testing.T) {
|
|||||||
tempDir := t.TempDir()
|
tempDir := t.TempDir()
|
||||||
|
|
||||||
oidcServiceCfg := service.OIDCServiceConfig{
|
oidcServiceCfg := service.OIDCServiceConfig{
|
||||||
Clients: map[string]config.OIDCClientConfig{
|
Clients: map[string]model.OIDCClientConfig{
|
||||||
"test": {
|
"test": {
|
||||||
ClientID: "some-client-id",
|
ClientID: "some-client-id",
|
||||||
ClientSecret: "some-client-secret",
|
ClientSecret: "some-client-secret",
|
||||||
@@ -44,12 +44,16 @@ func TestOIDCController(t *testing.T) {
|
|||||||
controllerCfg := controller.OIDCControllerConfig{}
|
controllerCfg := controller.OIDCControllerConfig{}
|
||||||
|
|
||||||
simpleCtx := func(c *gin.Context) {
|
simpleCtx := func(c *gin.Context) {
|
||||||
c.Set("context", &config.UserContext{
|
c.Set("context", &model.UserContext{
|
||||||
|
Authenticated: true,
|
||||||
|
Provider: model.ProviderLocal,
|
||||||
|
Local: &model.LocalContext{
|
||||||
|
BaseContext: model.BaseContext{
|
||||||
Username: "test",
|
Username: "test",
|
||||||
Name: "Test User",
|
Name: "Test User",
|
||||||
Email: "test@example.com",
|
Email: "test@example.com",
|
||||||
IsLoggedIn: true,
|
},
|
||||||
Provider: "local",
|
},
|
||||||
})
|
})
|
||||||
c.Next()
|
c.Next()
|
||||||
}
|
}
|
||||||
@@ -848,7 +852,7 @@ func TestOIDCController(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
app := bootstrap.NewBootstrapApp(config.Config{})
|
app := bootstrap.NewBootstrapApp(model.Config{})
|
||||||
|
|
||||||
db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db"))
|
db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db"))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ import (
|
|||||||
"regexp"
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/config"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||||
@@ -103,7 +103,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
|||||||
|
|
||||||
clientIP := c.ClientIP()
|
clientIP := c.ClientIP()
|
||||||
|
|
||||||
if controller.auth.IsBypassedIP(acls.IP, clientIP) {
|
if controller.auth.IsBypassedIP(clientIP, acls) {
|
||||||
controller.setHeaders(c, acls)
|
controller.setHeaders(c, acls)
|
||||||
c.JSON(200, gin.H{
|
c.JSON(200, gin.H{
|
||||||
"status": 200,
|
"status": 200,
|
||||||
@@ -112,7 +112,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
authEnabled, err := controller.auth.IsAuthEnabled(proxyCtx.Path, acls.Path)
|
authEnabled, err := controller.auth.IsAuthEnabled(proxyCtx.Path, acls)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tlog.App.Error().Err(err).Msg("Failed to check if auth is enabled for resource")
|
tlog.App.Error().Err(err).Msg("Failed to check if auth is enabled for resource")
|
||||||
@@ -130,8 +130,8 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if !controller.auth.CheckIP(acls.IP, clientIP) {
|
if !controller.auth.CheckIP(clientIP, acls) {
|
||||||
queries, err := query.Values(config.UnauthorizedQuery{
|
queries, err := query.Values(UnauthorizedQuery{
|
||||||
Resource: strings.Split(proxyCtx.Host, ".")[0],
|
Resource: strings.Split(proxyCtx.Host, ".")[0],
|
||||||
IP: clientIP,
|
IP: clientIP,
|
||||||
})
|
})
|
||||||
@@ -157,28 +157,24 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var userContext config.UserContext
|
userContext, err := new(model.UserContext).NewFromGin(c)
|
||||||
|
|
||||||
context, err := utils.GetContext(c)
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tlog.App.Debug().Msg("No user context found in request, treating as not logged in")
|
tlog.App.Debug().Err(err).Msg("No user context found in request, treating as unauthenticated")
|
||||||
userContext = config.UserContext{
|
userContext = &model.UserContext{
|
||||||
IsLoggedIn: false,
|
Authenticated: false,
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
userContext = context
|
|
||||||
}
|
}
|
||||||
|
|
||||||
tlog.App.Trace().Interface("context", userContext).Msg("User context from request")
|
tlog.App.Trace().Interface("context", userContext).Msg("User context from request")
|
||||||
|
|
||||||
if userContext.IsLoggedIn {
|
if userContext.Authenticated {
|
||||||
userAllowed := controller.auth.IsUserAllowed(c, userContext, acls)
|
userAllowed := controller.auth.IsUserAllowed(c, *userContext, acls)
|
||||||
|
|
||||||
if !userAllowed {
|
if !userAllowed {
|
||||||
tlog.App.Warn().Str("user", userContext.Username).Str("resource", strings.Split(proxyCtx.Host, ".")[0]).Msg("User not allowed to access resource")
|
tlog.App.Warn().Str("user", userContext.GetUsername()).Str("resource", strings.Split(proxyCtx.Host, ".")[0]).Msg("User not allowed to access resource")
|
||||||
|
|
||||||
queries, err := query.Values(config.UnauthorizedQuery{
|
queries, err := query.Values(UnauthorizedQuery{
|
||||||
Resource: strings.Split(proxyCtx.Host, ".")[0],
|
Resource: strings.Split(proxyCtx.Host, ".")[0],
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -188,10 +184,10 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if userContext.OAuth {
|
if userContext.IsOAuth() {
|
||||||
queries.Set("username", userContext.Email)
|
queries.Set("username", userContext.GetEmail())
|
||||||
} else {
|
} else {
|
||||||
queries.Set("username", userContext.Username)
|
queries.Set("username", userContext.GetUsername())
|
||||||
}
|
}
|
||||||
|
|
||||||
redirectURL := fmt.Sprintf("%s/unauthorized?%s", controller.config.AppURL, queries.Encode())
|
redirectURL := fmt.Sprintf("%s/unauthorized?%s", controller.config.AppURL, queries.Encode())
|
||||||
@@ -209,19 +205,19 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if userContext.OAuth || userContext.Provider == "ldap" {
|
if userContext.IsOAuth() || userContext.IsLDAP() {
|
||||||
var groupOK bool
|
var groupOK bool
|
||||||
|
|
||||||
if userContext.OAuth {
|
if userContext.IsOAuth() {
|
||||||
groupOK = controller.auth.IsInOAuthGroup(c, userContext, acls.OAuth.Groups)
|
groupOK = controller.auth.IsInOAuthGroup(c, *userContext, acls)
|
||||||
} else {
|
} else {
|
||||||
groupOK = controller.auth.IsInLdapGroup(c, userContext, acls.LDAP.Groups)
|
groupOK = controller.auth.IsInLDAPGroup(c, *userContext, acls)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !groupOK {
|
if !groupOK {
|
||||||
tlog.App.Warn().Str("user", userContext.Username).Str("resource", strings.Split(proxyCtx.Host, ".")[0]).Msg("User groups do not match resource requirements")
|
tlog.App.Warn().Str("user", userContext.GetUsername()).Str("resource", strings.Split(proxyCtx.Host, ".")[0]).Msg("User groups do not match resource requirements")
|
||||||
|
|
||||||
queries, err := query.Values(config.UnauthorizedQuery{
|
queries, err := query.Values(UnauthorizedQuery{
|
||||||
Resource: strings.Split(proxyCtx.Host, ".")[0],
|
Resource: strings.Split(proxyCtx.Host, ".")[0],
|
||||||
GroupErr: true,
|
GroupErr: true,
|
||||||
})
|
})
|
||||||
@@ -232,10 +228,10 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if userContext.OAuth {
|
if userContext.IsOAuth() {
|
||||||
queries.Set("username", userContext.Email)
|
queries.Set("username", userContext.GetEmail())
|
||||||
} else {
|
} else {
|
||||||
queries.Set("username", userContext.Username)
|
queries.Set("username", userContext.GetUsername())
|
||||||
}
|
}
|
||||||
|
|
||||||
redirectURL := fmt.Sprintf("%s/unauthorized?%s", controller.config.AppURL, queries.Encode())
|
redirectURL := fmt.Sprintf("%s/unauthorized?%s", controller.config.AppURL, queries.Encode())
|
||||||
@@ -254,17 +250,18 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Header("Remote-User", utils.SanitizeHeader(userContext.Username))
|
c.Header("Remote-User", utils.SanitizeHeader(userContext.GetUsername()))
|
||||||
c.Header("Remote-Name", utils.SanitizeHeader(userContext.Name))
|
c.Header("Remote-Name", utils.SanitizeHeader(userContext.GetName()))
|
||||||
c.Header("Remote-Email", utils.SanitizeHeader(userContext.Email))
|
c.Header("Remote-Email", utils.SanitizeHeader(userContext.GetEmail()))
|
||||||
|
|
||||||
if userContext.Provider == "ldap" {
|
if userContext.IsLDAP() {
|
||||||
c.Header("Remote-Groups", utils.SanitizeHeader(userContext.LdapGroups))
|
c.Header("Remote-Groups", utils.SanitizeHeader(strings.Join(userContext.LDAP.Groups, ",")))
|
||||||
} else if userContext.Provider != "local" {
|
|
||||||
c.Header("Remote-Groups", utils.SanitizeHeader(userContext.OAuthGroups))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Header("Remote-Sub", utils.SanitizeHeader(userContext.OAuthSub))
|
if userContext.IsOAuth() {
|
||||||
|
c.Header("Remote-Groups", utils.SanitizeHeader(strings.Join(userContext.OAuth.Groups, ",")))
|
||||||
|
c.Header("Remote-Sub", utils.SanitizeHeader(userContext.OAuth.Sub))
|
||||||
|
}
|
||||||
|
|
||||||
controller.setHeaders(c, acls)
|
controller.setHeaders(c, acls)
|
||||||
|
|
||||||
@@ -275,7 +272,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
queries, err := query.Values(config.RedirectQuery{
|
queries, err := query.Values(RedirectQuery{
|
||||||
RedirectURI: fmt.Sprintf("%s://%s%s", proxyCtx.Proto, proxyCtx.Host, proxyCtx.Path),
|
RedirectURI: fmt.Sprintf("%s://%s%s", proxyCtx.Proto, proxyCtx.Host, proxyCtx.Path),
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -299,9 +296,13 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
|||||||
c.Redirect(http.StatusTemporaryRedirect, redirectURL)
|
c.Redirect(http.StatusTemporaryRedirect, redirectURL)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (controller *ProxyController) setHeaders(c *gin.Context, acls config.App) {
|
func (controller *ProxyController) setHeaders(c *gin.Context, acls *model.App) {
|
||||||
c.Header("Authorization", c.Request.Header.Get("Authorization"))
|
c.Header("Authorization", c.Request.Header.Get("Authorization"))
|
||||||
|
|
||||||
|
if acls == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
headers := utils.ParseHeaders(acls.Response.Headers)
|
headers := utils.ParseHeaders(acls.Response.Headers)
|
||||||
|
|
||||||
for key, value := range headers {
|
for key, value := range headers {
|
||||||
@@ -313,7 +314,7 @@ func (controller *ProxyController) setHeaders(c *gin.Context, acls config.App) {
|
|||||||
|
|
||||||
if acls.Response.BasicAuth.Username != "" && basicPassword != "" {
|
if acls.Response.BasicAuth.Username != "" && basicPassword != "" {
|
||||||
tlog.App.Debug().Str("username", acls.Response.BasicAuth.Username).Msg("Setting basic auth header")
|
tlog.App.Debug().Str("username", acls.Response.BasicAuth.Username).Msg("Setting basic auth header")
|
||||||
c.Header("Authorization", fmt.Sprintf("Basic %s", utils.GetBasicAuth(acls.Response.BasicAuth.Username, basicPassword)))
|
c.Header("Authorization", fmt.Sprintf("Basic %s", utils.EncodeBasicAuth(acls.Response.BasicAuth.Username, basicPassword)))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -6,14 +6,14 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/bootstrap"
|
"github.com/tinyauthapp/tinyauth/internal/bootstrap"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/config"
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/controller"
|
"github.com/tinyauthapp/tinyauth/internal/controller"
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestProxyController(t *testing.T) {
|
func TestProxyController(t *testing.T) {
|
||||||
@@ -21,7 +21,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
tempDir := t.TempDir()
|
tempDir := t.TempDir()
|
||||||
|
|
||||||
authServiceCfg := service.AuthServiceConfig{
|
authServiceCfg := service.AuthServiceConfig{
|
||||||
Users: []config.User{
|
LocalUsers: &[]model.LocalUser{
|
||||||
{
|
{
|
||||||
Username: "testuser",
|
Username: "testuser",
|
||||||
Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
|
Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
|
||||||
@@ -29,7 +29,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
{
|
{
|
||||||
Username: "totpuser",
|
Username: "totpuser",
|
||||||
Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
|
Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
|
||||||
TotpSecret: "JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK",
|
TOTPSecret: "JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
SessionExpiry: 10, // 10 seconds, useful for testing
|
SessionExpiry: 10, // 10 seconds, useful for testing
|
||||||
@@ -43,28 +43,28 @@ func TestProxyController(t *testing.T) {
|
|||||||
AppURL: "https://tinyauth.example.com",
|
AppURL: "https://tinyauth.example.com",
|
||||||
}
|
}
|
||||||
|
|
||||||
acls := map[string]config.App{
|
acls := map[string]model.App{
|
||||||
"app_path_allow": {
|
"app_path_allow": {
|
||||||
Config: config.AppConfig{
|
Config: model.AppConfig{
|
||||||
Domain: "path-allow.example.com",
|
Domain: "path-allow.example.com",
|
||||||
},
|
},
|
||||||
Path: config.AppPath{
|
Path: model.AppPath{
|
||||||
Allow: "/allowed",
|
Allow: "/allowed",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"app_user_allow": {
|
"app_user_allow": {
|
||||||
Config: config.AppConfig{
|
Config: model.AppConfig{
|
||||||
Domain: "user-allow.example.com",
|
Domain: "user-allow.example.com",
|
||||||
},
|
},
|
||||||
Users: config.AppUsers{
|
Users: model.AppUsers{
|
||||||
Allow: "testuser",
|
Allow: "testuser",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"ip_bypass": {
|
"ip_bypass": {
|
||||||
Config: config.AppConfig{
|
Config: model.AppConfig{
|
||||||
Domain: "ip-bypass.example.com",
|
Domain: "ip-bypass.example.com",
|
||||||
},
|
},
|
||||||
IP: config.AppIP{
|
IP: model.AppIP{
|
||||||
Bypass: []string{"10.10.10.10"},
|
Bypass: []string{"10.10.10.10"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -74,24 +74,32 @@ func TestProxyController(t *testing.T) {
|
|||||||
Mozilla/5.0 (Linux; Android 8.0.0; SM-G955U Build/R16NW) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/146.0.0.0 Mobile Safari/537.36`
|
Mozilla/5.0 (Linux; Android 8.0.0; SM-G955U Build/R16NW) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/146.0.0.0 Mobile Safari/537.36`
|
||||||
|
|
||||||
simpleCtx := func(c *gin.Context) {
|
simpleCtx := func(c *gin.Context) {
|
||||||
c.Set("context", &config.UserContext{
|
c.Set("context", &model.UserContext{
|
||||||
|
Authenticated: true,
|
||||||
|
Provider: model.ProviderLocal,
|
||||||
|
Local: &model.LocalContext{
|
||||||
|
BaseContext: model.BaseContext{
|
||||||
Username: "testuser",
|
Username: "testuser",
|
||||||
Name: "Testuser",
|
Name: "Testuser",
|
||||||
Email: "testuser@example.com",
|
Email: "testuser@example.com",
|
||||||
IsLoggedIn: true,
|
},
|
||||||
Provider: "local",
|
},
|
||||||
})
|
})
|
||||||
c.Next()
|
c.Next()
|
||||||
}
|
}
|
||||||
|
|
||||||
simpleCtxTotp := func(c *gin.Context) {
|
simpleCtxTotp := func(c *gin.Context) {
|
||||||
c.Set("context", &config.UserContext{
|
c.Set("context", &model.UserContext{
|
||||||
|
Authenticated: true,
|
||||||
|
Provider: model.ProviderLocal,
|
||||||
|
Local: &model.LocalContext{
|
||||||
|
BaseContext: model.BaseContext{
|
||||||
Username: "totpuser",
|
Username: "totpuser",
|
||||||
Name: "Totpuser",
|
Name: "Totpuser",
|
||||||
Email: "totpuser@example.com",
|
Email: "totpuser@example.com",
|
||||||
IsLoggedIn: true,
|
},
|
||||||
Provider: "local",
|
TOTPEnabled: true,
|
||||||
TotpEnabled: true,
|
},
|
||||||
})
|
})
|
||||||
c.Next()
|
c.Next()
|
||||||
}
|
}
|
||||||
@@ -391,9 +399,9 @@ func TestProxyController(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
oauthBrokerCfgs := make(map[string]config.OAuthServiceConfig)
|
oauthBrokerCfgs := make(map[string]model.OAuthServiceConfig)
|
||||||
|
|
||||||
app := bootstrap.NewBootstrapApp(config.Config{})
|
app := bootstrap.NewBootstrapApp(model.Config{})
|
||||||
|
|
||||||
db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db"))
|
db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db"))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|||||||
@@ -1,10 +1,12 @@
|
|||||||
package controller
|
package controller
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/config"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||||
@@ -25,6 +27,7 @@ type TotpRequest struct {
|
|||||||
|
|
||||||
type UserControllerConfig struct {
|
type UserControllerConfig struct {
|
||||||
CookieDomain string
|
CookieDomain string
|
||||||
|
SessionCookieName string
|
||||||
}
|
}
|
||||||
|
|
||||||
type UserController struct {
|
type UserController struct {
|
||||||
@@ -77,9 +80,10 @@ func (controller *UserController) loginHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
userSearch := controller.auth.SearchUser(req.Username)
|
search, err := controller.auth.SearchUser(req.Username)
|
||||||
|
|
||||||
if userSearch.Type == "unknown" {
|
if err != nil {
|
||||||
|
if errors.Is(err, service.ErrUserNotFound) {
|
||||||
tlog.App.Warn().Str("username", req.Username).Msg("User not found")
|
tlog.App.Warn().Str("username", req.Username).Msg("User not found")
|
||||||
controller.auth.RecordLoginAttempt(req.Username, false)
|
controller.auth.RecordLoginAttempt(req.Username, false)
|
||||||
tlog.AuditLoginFailure(c, req.Username, "username", "user not found")
|
tlog.AuditLoginFailure(c, req.Username, "username", "user not found")
|
||||||
@@ -89,8 +93,15 @@ func (controller *UserController) loginHandler(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
tlog.App.Error().Err(err).Str("username", req.Username).Msg("Error searching for user")
|
||||||
|
c.JSON(500, gin.H{
|
||||||
|
"status": 500,
|
||||||
|
"message": "Internal Server Error",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
if !controller.auth.VerifyUser(userSearch, req.Password) {
|
if err := controller.auth.CheckUserPassword(*search, req.Password); err != nil {
|
||||||
tlog.App.Warn().Str("username", req.Username).Msg("Invalid password")
|
tlog.App.Warn().Str("username", req.Username).Msg("Invalid password")
|
||||||
controller.auth.RecordLoginAttempt(req.Username, false)
|
controller.auth.RecordLoginAttempt(req.Username, false)
|
||||||
tlog.AuditLoginFailure(c, req.Username, "username", "invalid password")
|
tlog.AuditLoginFailure(c, req.Username, "username", "invalid password")
|
||||||
@@ -106,30 +117,26 @@ func (controller *UserController) loginHandler(c *gin.Context) {
|
|||||||
|
|
||||||
controller.auth.RecordLoginAttempt(req.Username, true)
|
controller.auth.RecordLoginAttempt(req.Username, true)
|
||||||
|
|
||||||
var localUser *config.User
|
var localUser *model.LocalUser
|
||||||
if userSearch.Type == "local" {
|
|
||||||
user := controller.auth.GetLocalUser(userSearch.Username)
|
|
||||||
localUser = &user
|
|
||||||
}
|
|
||||||
|
|
||||||
if userSearch.Type == "local" && localUser != nil {
|
if search.Type == model.UserLocal {
|
||||||
user := *localUser
|
localUser = controller.auth.GetLocalUser(req.Username)
|
||||||
|
|
||||||
if user.TotpSecret != "" {
|
if localUser.TOTPSecret != "" {
|
||||||
tlog.App.Debug().Str("username", req.Username).Msg("User has TOTP enabled, requiring TOTP verification")
|
tlog.App.Debug().Str("username", req.Username).Msg("User has TOTP enabled, requiring TOTP verification")
|
||||||
|
|
||||||
name := user.Attributes.Name
|
name := localUser.Attributes.Name
|
||||||
if name == "" {
|
if name == "" {
|
||||||
name = utils.Capitalize(user.Username)
|
name = utils.Capitalize(localUser.Username)
|
||||||
}
|
}
|
||||||
|
|
||||||
email := user.Attributes.Email
|
email := localUser.Attributes.Email
|
||||||
if email == "" {
|
if email == "" {
|
||||||
email = utils.CompileUserEmail(user.Username, controller.config.CookieDomain)
|
email = utils.CompileUserEmail(localUser.Username, controller.config.CookieDomain)
|
||||||
}
|
}
|
||||||
|
|
||||||
err := controller.auth.CreateSessionCookie(c, &repository.Session{
|
cookie, err := controller.auth.CreateSession(c, repository.Session{
|
||||||
Username: user.Username,
|
Username: localUser.Username,
|
||||||
Name: name,
|
Name: name,
|
||||||
Email: email,
|
Email: email,
|
||||||
Provider: "local",
|
Provider: "local",
|
||||||
@@ -145,6 +152,8 @@ func (controller *UserController) loginHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
http.SetCookie(c.Writer, cookie)
|
||||||
|
|
||||||
c.JSON(200, gin.H{
|
c.JSON(200, gin.H{
|
||||||
"status": 200,
|
"status": 200,
|
||||||
"message": "TOTP required",
|
"message": "TOTP required",
|
||||||
@@ -161,7 +170,7 @@ func (controller *UserController) loginHandler(c *gin.Context) {
|
|||||||
Provider: "local",
|
Provider: "local",
|
||||||
}
|
}
|
||||||
|
|
||||||
if userSearch.Type == "local" && localUser != nil {
|
if search.Type == model.UserLocal {
|
||||||
if localUser.Attributes.Name != "" {
|
if localUser.Attributes.Name != "" {
|
||||||
sessionCookie.Name = localUser.Attributes.Name
|
sessionCookie.Name = localUser.Attributes.Name
|
||||||
}
|
}
|
||||||
@@ -170,13 +179,13 @@ func (controller *UserController) loginHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if userSearch.Type == "ldap" {
|
if search.Type == model.UserLDAP {
|
||||||
sessionCookie.Provider = "ldap"
|
sessionCookie.Provider = "ldap"
|
||||||
}
|
}
|
||||||
|
|
||||||
tlog.App.Trace().Interface("session_cookie", sessionCookie).Msg("Creating session cookie")
|
tlog.App.Trace().Interface("session_cookie", sessionCookie).Msg("Creating session cookie")
|
||||||
|
|
||||||
err = controller.auth.CreateSessionCookie(c, &sessionCookie)
|
cookie, err := controller.auth.CreateSession(c, sessionCookie)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tlog.App.Error().Err(err).Msg("Failed to create session cookie")
|
tlog.App.Error().Err(err).Msg("Failed to create session cookie")
|
||||||
@@ -187,6 +196,8 @@ func (controller *UserController) loginHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
http.SetCookie(c.Writer, cookie)
|
||||||
|
|
||||||
c.JSON(200, gin.H{
|
c.JSON(200, gin.H{
|
||||||
"status": 200,
|
"status": 200,
|
||||||
"message": "Login successful",
|
"message": "Login successful",
|
||||||
@@ -196,12 +207,46 @@ func (controller *UserController) loginHandler(c *gin.Context) {
|
|||||||
func (controller *UserController) logoutHandler(c *gin.Context) {
|
func (controller *UserController) logoutHandler(c *gin.Context) {
|
||||||
tlog.App.Debug().Msg("Logout request received")
|
tlog.App.Debug().Msg("Logout request received")
|
||||||
|
|
||||||
controller.auth.DeleteSessionCookie(c)
|
uuid, err := c.Cookie(controller.config.SessionCookieName)
|
||||||
|
|
||||||
context, err := utils.GetContext(c)
|
if err != nil {
|
||||||
if err == nil && context.IsLoggedIn {
|
if errors.Is(err, http.ErrNoCookie) {
|
||||||
tlog.AuditLogout(c, context.Username, context.Provider)
|
tlog.App.Warn().Msg("No session cookie found on logout request")
|
||||||
|
c.JSON(200, gin.H{
|
||||||
|
"status": 200,
|
||||||
|
"message": "Logout successful",
|
||||||
|
})
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
tlog.App.Error().Err(err).Msg("Error retrieving session cookie on logout")
|
||||||
|
c.JSON(500, gin.H{
|
||||||
|
"status": 500,
|
||||||
|
"message": "Internal Server Error",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
cookie, err := controller.auth.DeleteSession(c, uuid)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
tlog.App.Error().Err(err).Msg("Error deleting session on logout")
|
||||||
|
c.JSON(500, gin.H{
|
||||||
|
"status": 500,
|
||||||
|
"message": "Internal Server Error",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
context, err := new(model.UserContext).NewFromGin(c)
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
tlog.AuditLogout(c, context.GetUsername(), context.ProviderName())
|
||||||
|
} else {
|
||||||
|
tlog.App.Warn().Err(err).Msg("Failed to get user context for logout audit, proceeding without username")
|
||||||
|
tlog.AuditLogout(c, "unknown", "unknown")
|
||||||
|
}
|
||||||
|
|
||||||
|
http.SetCookie(c.Writer, cookie)
|
||||||
|
|
||||||
c.JSON(200, gin.H{
|
c.JSON(200, gin.H{
|
||||||
"status": 200,
|
"status": 200,
|
||||||
@@ -222,7 +267,7 @@ func (controller *UserController) totpHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
context, err := utils.GetContext(c)
|
context, err := new(model.UserContext).NewFromGin(c)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tlog.App.Error().Err(err).Msg("Failed to get user context")
|
tlog.App.Error().Err(err).Msg("Failed to get user context")
|
||||||
@@ -233,7 +278,7 @@ func (controller *UserController) totpHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if !context.TotpPending {
|
if !context.TOTPPending() {
|
||||||
tlog.App.Warn().Msg("TOTP attempt without a pending TOTP session")
|
tlog.App.Warn().Msg("TOTP attempt without a pending TOTP session")
|
||||||
c.JSON(401, gin.H{
|
c.JSON(401, gin.H{
|
||||||
"status": 401,
|
"status": 401,
|
||||||
@@ -242,12 +287,12 @@ func (controller *UserController) totpHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
tlog.App.Debug().Str("username", context.Username).Msg("TOTP verification attempt")
|
tlog.App.Debug().Str("username", context.GetUsername()).Msg("TOTP verification attempt")
|
||||||
|
|
||||||
isLocked, remaining := controller.auth.IsAccountLocked(context.Username)
|
isLocked, remaining := controller.auth.IsAccountLocked(context.GetUsername())
|
||||||
|
|
||||||
if isLocked {
|
if isLocked {
|
||||||
tlog.App.Warn().Str("username", context.Username).Msg("Account is locked due to too many failed TOTP attempts")
|
tlog.App.Warn().Str("username", context.GetUsername()).Msg("Account is locked due to too many failed TOTP attempts")
|
||||||
c.Writer.Header().Add("x-tinyauth-lock-locked", "true")
|
c.Writer.Header().Add("x-tinyauth-lock-locked", "true")
|
||||||
c.Writer.Header().Add("x-tinyauth-lock-reset", time.Now().Add(time.Duration(remaining)*time.Second).Format(time.RFC3339))
|
c.Writer.Header().Add("x-tinyauth-lock-reset", time.Now().Add(time.Duration(remaining)*time.Second).Format(time.RFC3339))
|
||||||
c.JSON(429, gin.H{
|
c.JSON(429, gin.H{
|
||||||
@@ -257,14 +302,10 @@ func (controller *UserController) totpHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
user := controller.auth.GetLocalUser(context.Username)
|
user := controller.auth.GetLocalUser(context.GetUsername())
|
||||||
|
|
||||||
ok := totp.Validate(req.Code, user.TotpSecret)
|
if user == nil {
|
||||||
|
tlog.App.Error().Str("username", context.GetUsername()).Msg("User not found in TOTP handler")
|
||||||
if !ok {
|
|
||||||
tlog.App.Warn().Str("username", context.Username).Msg("Invalid TOTP code")
|
|
||||||
controller.auth.RecordLoginAttempt(context.Username, false)
|
|
||||||
tlog.AuditLoginFailure(c, context.Username, "totp", "invalid totp code")
|
|
||||||
c.JSON(401, gin.H{
|
c.JSON(401, gin.H{
|
||||||
"status": 401,
|
"status": 401,
|
||||||
"message": "Unauthorized",
|
"message": "Unauthorized",
|
||||||
@@ -272,10 +313,23 @@ func (controller *UserController) totpHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
tlog.App.Info().Str("username", context.Username).Msg("TOTP verification successful")
|
ok := totp.Validate(req.Code, user.TOTPSecret)
|
||||||
tlog.AuditLoginSuccess(c, context.Username, "totp")
|
|
||||||
|
|
||||||
controller.auth.RecordLoginAttempt(context.Username, true)
|
if !ok {
|
||||||
|
tlog.App.Warn().Str("username", context.GetUsername()).Msg("Invalid TOTP code")
|
||||||
|
controller.auth.RecordLoginAttempt(context.GetUsername(), false)
|
||||||
|
tlog.AuditLoginFailure(c, context.GetUsername(), "totp", "invalid totp code")
|
||||||
|
c.JSON(401, gin.H{
|
||||||
|
"status": 401,
|
||||||
|
"message": "Unauthorized",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
tlog.App.Info().Str("username", context.GetUsername()).Msg("TOTP verification successful")
|
||||||
|
tlog.AuditLoginSuccess(c, context.GetUsername(), "totp")
|
||||||
|
|
||||||
|
controller.auth.RecordLoginAttempt(context.GetUsername(), true)
|
||||||
|
|
||||||
sessionCookie := repository.Session{
|
sessionCookie := repository.Session{
|
||||||
Username: user.Username,
|
Username: user.Username,
|
||||||
@@ -293,7 +347,7 @@ func (controller *UserController) totpHandler(c *gin.Context) {
|
|||||||
|
|
||||||
tlog.App.Trace().Interface("session_cookie", sessionCookie).Msg("Creating session cookie")
|
tlog.App.Trace().Interface("session_cookie", sessionCookie).Msg("Creating session cookie")
|
||||||
|
|
||||||
err = controller.auth.CreateSessionCookie(c, &sessionCookie)
|
cookie, err := controller.auth.CreateSession(c, sessionCookie)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tlog.App.Error().Err(err).Msg("Failed to create session cookie")
|
tlog.App.Error().Err(err).Msg("Failed to create session cookie")
|
||||||
@@ -304,6 +358,8 @@ func (controller *UserController) totpHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
http.SetCookie(c.Writer, cookie)
|
||||||
|
|
||||||
c.JSON(200, gin.H{
|
c.JSON(200, gin.H{
|
||||||
"status": 200,
|
"status": 200,
|
||||||
"message": "Login successful",
|
"message": "Login successful",
|
||||||
|
|||||||
@@ -10,14 +10,14 @@ import (
|
|||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/pquerna/otp/totp"
|
"github.com/pquerna/otp/totp"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/bootstrap"
|
"github.com/tinyauthapp/tinyauth/internal/bootstrap"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/config"
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/controller"
|
"github.com/tinyauthapp/tinyauth/internal/controller"
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestUserController(t *testing.T) {
|
func TestUserController(t *testing.T) {
|
||||||
@@ -25,7 +25,7 @@ func TestUserController(t *testing.T) {
|
|||||||
tempDir := t.TempDir()
|
tempDir := t.TempDir()
|
||||||
|
|
||||||
authServiceCfg := service.AuthServiceConfig{
|
authServiceCfg := service.AuthServiceConfig{
|
||||||
Users: []config.User{
|
LocalUsers: &[]model.LocalUser{
|
||||||
{
|
{
|
||||||
Username: "testuser",
|
Username: "testuser",
|
||||||
Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
|
Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
|
||||||
@@ -33,12 +33,12 @@ func TestUserController(t *testing.T) {
|
|||||||
{
|
{
|
||||||
Username: "totpuser",
|
Username: "totpuser",
|
||||||
Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
|
Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
|
||||||
TotpSecret: "JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK",
|
TOTPSecret: "JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Username: "attruser",
|
Username: "attruser",
|
||||||
Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
|
Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
|
||||||
Attributes: config.UserAttributes{
|
Attributes: model.UserAttributes{
|
||||||
Name: "Alice Smith",
|
Name: "Alice Smith",
|
||||||
Email: "alice@example.com",
|
Email: "alice@example.com",
|
||||||
},
|
},
|
||||||
@@ -46,8 +46,8 @@ func TestUserController(t *testing.T) {
|
|||||||
{
|
{
|
||||||
Username: "attrtotpuser",
|
Username: "attrtotpuser",
|
||||||
Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
|
Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
|
||||||
TotpSecret: "JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK",
|
TOTPSecret: "JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK",
|
||||||
Attributes: config.UserAttributes{
|
Attributes: model.UserAttributes{
|
||||||
Name: "Bob Jones",
|
Name: "Bob Jones",
|
||||||
Email: "bob@example.com",
|
Email: "bob@example.com",
|
||||||
},
|
},
|
||||||
@@ -62,6 +62,53 @@ func TestUserController(t *testing.T) {
|
|||||||
|
|
||||||
userControllerCfg := controller.UserControllerConfig{
|
userControllerCfg := controller.UserControllerConfig{
|
||||||
CookieDomain: "example.com",
|
CookieDomain: "example.com",
|
||||||
|
SessionCookieName: "tinyauth-session",
|
||||||
|
}
|
||||||
|
|
||||||
|
totpCtx := func(c *gin.Context) {
|
||||||
|
c.Set("context", &model.UserContext{
|
||||||
|
Authenticated: false,
|
||||||
|
Provider: model.ProviderLocal,
|
||||||
|
Local: &model.LocalContext{
|
||||||
|
BaseContext: model.BaseContext{
|
||||||
|
Username: "totpuser",
|
||||||
|
Name: "Totpuser",
|
||||||
|
Email: "totpuser@example.com",
|
||||||
|
},
|
||||||
|
TOTPPending: true,
|
||||||
|
TOTPEnabled: true,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
totpAttrCtx := func(c *gin.Context) {
|
||||||
|
c.Set("context", &model.UserContext{
|
||||||
|
Authenticated: false,
|
||||||
|
Provider: model.ProviderLocal,
|
||||||
|
Local: &model.LocalContext{
|
||||||
|
BaseContext: model.BaseContext{
|
||||||
|
Username: "attrtotpuser",
|
||||||
|
Name: "Bob Jones",
|
||||||
|
Email: "bob@example.com",
|
||||||
|
},
|
||||||
|
TOTPPending: true,
|
||||||
|
TOTPEnabled: true,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
simpleCtx := func(c *gin.Context) {
|
||||||
|
c.Set("context", &model.UserContext{
|
||||||
|
Authenticated: true,
|
||||||
|
Provider: model.ProviderLocal,
|
||||||
|
Local: &model.LocalContext{
|
||||||
|
BaseContext: model.BaseContext{
|
||||||
|
Username: "testuser",
|
||||||
|
Name: "Test User",
|
||||||
|
Email: "testuser@example.com",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
type testCase struct {
|
type testCase struct {
|
||||||
@@ -94,7 +141,7 @@ func TestUserController(t *testing.T) {
|
|||||||
assert.Equal(t, "tinyauth-session", cookie.Name)
|
assert.Equal(t, "tinyauth-session", cookie.Name)
|
||||||
assert.True(t, cookie.HttpOnly)
|
assert.True(t, cookie.HttpOnly)
|
||||||
assert.Equal(t, "example.com", cookie.Domain)
|
assert.Equal(t, "example.com", cookie.Domain)
|
||||||
assert.Equal(t, 10, cookie.MaxAge)
|
assert.Equal(t, 9, cookie.MaxAge)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -183,12 +230,14 @@ func TestUserController(t *testing.T) {
|
|||||||
assert.Equal(t, "tinyauth-session", cookie.Name)
|
assert.Equal(t, "tinyauth-session", cookie.Name)
|
||||||
assert.True(t, cookie.HttpOnly)
|
assert.True(t, cookie.HttpOnly)
|
||||||
assert.Equal(t, "example.com", cookie.Domain)
|
assert.Equal(t, "example.com", cookie.Domain)
|
||||||
assert.Equal(t, 3600, cookie.MaxAge) // 1 hour, default for totp pending sessions
|
assert.Equal(t, 3599, cookie.MaxAge) // 1 hour, default for totp pending sessions
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "Should be able to logout",
|
description: "Should be able to logout",
|
||||||
middlewares: []gin.HandlerFunc{},
|
middlewares: []gin.HandlerFunc{
|
||||||
|
simpleCtx,
|
||||||
|
},
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
// First login to get a session cookie
|
// First login to get a session cookie
|
||||||
loginReq := controller.LoginRequest{
|
loginReq := controller.LoginRequest{
|
||||||
@@ -204,9 +253,10 @@ func TestUserController(t *testing.T) {
|
|||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
assert.Equal(t, 200, recorder.Code)
|
assert.Equal(t, 200, recorder.Code)
|
||||||
assert.Len(t, recorder.Result().Cookies(), 1)
|
cookies := recorder.Result().Cookies()
|
||||||
|
assert.Len(t, cookies, 1)
|
||||||
|
|
||||||
cookie := recorder.Result().Cookies()[0]
|
cookie := cookies[0]
|
||||||
assert.Equal(t, "tinyauth-session", cookie.Name)
|
assert.Equal(t, "tinyauth-session", cookie.Name)
|
||||||
|
|
||||||
// Now logout using the session cookie
|
// Now logout using the session cookie
|
||||||
@@ -217,17 +267,20 @@ func TestUserController(t *testing.T) {
|
|||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
assert.Equal(t, 200, recorder.Code)
|
assert.Equal(t, 200, recorder.Code)
|
||||||
assert.Len(t, recorder.Result().Cookies(), 1)
|
cookies = recorder.Result().Cookies()
|
||||||
|
assert.Len(t, cookies, 1)
|
||||||
|
|
||||||
logoutCookie := recorder.Result().Cookies()[0]
|
cookie = cookies[0]
|
||||||
assert.Equal(t, "tinyauth-session", logoutCookie.Name)
|
assert.Equal(t, "tinyauth-session", cookie.Name)
|
||||||
assert.Equal(t, "", logoutCookie.Value)
|
assert.Equal(t, "", cookie.Value)
|
||||||
assert.Equal(t, -1, logoutCookie.MaxAge) // MaxAge -1 means delete cookie
|
assert.Equal(t, -1, cookie.MaxAge) // MaxAge -1 means delete cookie
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "Should be able to login with totp",
|
description: "Should be able to login with totp",
|
||||||
middlewares: []gin.HandlerFunc{},
|
middlewares: []gin.HandlerFunc{
|
||||||
|
totpCtx,
|
||||||
|
},
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
code, err := totp.GenerateCode("JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", time.Now())
|
code, err := totp.GenerateCode("JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", time.Now())
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
@@ -253,12 +306,14 @@ func TestUserController(t *testing.T) {
|
|||||||
assert.Equal(t, "tinyauth-session", totpCookie.Name)
|
assert.Equal(t, "tinyauth-session", totpCookie.Name)
|
||||||
assert.True(t, totpCookie.HttpOnly)
|
assert.True(t, totpCookie.HttpOnly)
|
||||||
assert.Equal(t, "example.com", totpCookie.Domain)
|
assert.Equal(t, "example.com", totpCookie.Domain)
|
||||||
assert.Equal(t, 10, totpCookie.MaxAge) // should use the regular session expiry time
|
assert.Equal(t, 9, totpCookie.MaxAge) // should use the regular session expiry time
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "Totp should rate limit on multiple invalid attempts",
|
description: "Totp should rate limit on multiple invalid attempts",
|
||||||
middlewares: []gin.HandlerFunc{},
|
middlewares: []gin.HandlerFunc{
|
||||||
|
totpCtx,
|
||||||
|
},
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
for range 3 {
|
for range 3 {
|
||||||
totpReq := controller.TotpRequest{
|
totpReq := controller.TotpRequest{
|
||||||
@@ -328,7 +383,9 @@ func TestUserController(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "TOTP completion uses name and email from user attributes",
|
description: "TOTP completion uses name and email from user attributes",
|
||||||
middlewares: []gin.HandlerFunc{},
|
middlewares: []gin.HandlerFunc{
|
||||||
|
totpAttrCtx,
|
||||||
|
},
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
code, err := totp.GenerateCode("JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", time.Now())
|
code, err := totp.GenerateCode("JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", time.Now())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -349,9 +406,9 @@ func TestUserController(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
oauthBrokerCfgs := make(map[string]config.OAuthServiceConfig)
|
oauthBrokerCfgs := make(map[string]model.OAuthServiceConfig)
|
||||||
|
|
||||||
app := bootstrap.NewBootstrapApp(config.Config{})
|
app := bootstrap.NewBootstrapApp(model.Config{})
|
||||||
|
|
||||||
db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db"))
|
db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db"))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -379,33 +436,6 @@ func TestUserController(t *testing.T) {
|
|||||||
authService.ClearRateLimitsTestingOnly()
|
authService.ClearRateLimitsTestingOnly()
|
||||||
}
|
}
|
||||||
|
|
||||||
setTotpMiddlewareOverrides := map[string]config.UserContext{
|
|
||||||
"Should be able to login with totp": {
|
|
||||||
Username: "totpuser",
|
|
||||||
Name: "Totpuser",
|
|
||||||
Email: "totpuser@example.com",
|
|
||||||
Provider: "local",
|
|
||||||
TotpPending: true,
|
|
||||||
TotpEnabled: true,
|
|
||||||
},
|
|
||||||
"Totp should rate limit on multiple invalid attempts": {
|
|
||||||
Username: "totpuser",
|
|
||||||
Name: "Totpuser",
|
|
||||||
Email: "totpuser@example.com",
|
|
||||||
Provider: "local",
|
|
||||||
TotpPending: true,
|
|
||||||
TotpEnabled: true,
|
|
||||||
},
|
|
||||||
"TOTP completion uses name and email from user attributes": {
|
|
||||||
Username: "attrtotpuser",
|
|
||||||
Name: "Bob Jones",
|
|
||||||
Email: "bob@example.com",
|
|
||||||
Provider: "local",
|
|
||||||
TotpPending: true,
|
|
||||||
TotpEnabled: true,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
beforeEach()
|
beforeEach()
|
||||||
t.Run(test.description, func(t *testing.T) {
|
t.Run(test.description, func(t *testing.T) {
|
||||||
@@ -415,15 +445,6 @@ func TestUserController(t *testing.T) {
|
|||||||
router.Use(middleware)
|
router.Use(middleware)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Gin is stupid and doesn't allow setting a middleware after the groups
|
|
||||||
// so we need to do some stupid overrides here
|
|
||||||
if ctx, ok := setTotpMiddlewareOverrides[test.description]; ok {
|
|
||||||
ctx := ctx
|
|
||||||
router.Use(func(c *gin.Context) {
|
|
||||||
c.Set("context", &ctx)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
group := router.Group("/api")
|
group := router.Group("/api")
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
|||||||
@@ -8,14 +8,14 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/bootstrap"
|
"github.com/tinyauthapp/tinyauth/internal/bootstrap"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/config"
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/controller"
|
"github.com/tinyauthapp/tinyauth/internal/controller"
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestWellKnownController(t *testing.T) {
|
func TestWellKnownController(t *testing.T) {
|
||||||
@@ -23,7 +23,7 @@ func TestWellKnownController(t *testing.T) {
|
|||||||
tempDir := t.TempDir()
|
tempDir := t.TempDir()
|
||||||
|
|
||||||
oidcServiceCfg := service.OIDCServiceConfig{
|
oidcServiceCfg := service.OIDCServiceConfig{
|
||||||
Clients: map[string]config.OIDCClientConfig{
|
Clients: map[string]model.OIDCClientConfig{
|
||||||
"test": {
|
"test": {
|
||||||
ClientID: "some-client-id",
|
ClientID: "some-client-id",
|
||||||
ClientSecret: "some-client-secret",
|
ClientSecret: "some-client-secret",
|
||||||
@@ -101,7 +101,7 @@ func TestWellKnownController(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
app := bootstrap.NewBootstrapApp(config.Config{})
|
app := bootstrap.NewBootstrapApp(model.Config{})
|
||||||
|
|
||||||
db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db"))
|
db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db"))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|||||||
@@ -1,10 +1,13 @@
|
|||||||
package middleware
|
package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/config"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||||
@@ -34,6 +37,7 @@ var (
|
|||||||
|
|
||||||
type ContextMiddlewareConfig struct {
|
type ContextMiddlewareConfig struct {
|
||||||
CookieDomain string
|
CookieDomain string
|
||||||
|
SessionCookieName string
|
||||||
}
|
}
|
||||||
|
|
||||||
type ContextMiddleware struct {
|
type ContextMiddleware struct {
|
||||||
@@ -61,200 +65,191 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
cookie, err := m.auth.GetSessionCookie(c)
|
uuid, err := c.Cookie(m.config.SessionCookieName)
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
userContext, cookie, err := m.cookieAuth(c.Request.Context(), uuid)
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
if cookie != nil {
|
||||||
|
http.SetCookie(c.Writer, cookie)
|
||||||
|
}
|
||||||
|
|
||||||
|
tlog.App.Trace().Msgf("Authenticated user from session cookie: %s", userContext.GetUsername())
|
||||||
|
c.Set("context", userContext)
|
||||||
|
c.Next()
|
||||||
|
return
|
||||||
|
} else {
|
||||||
|
tlog.App.Error().Msgf("Error authenticating session cookie: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
username, password, ok := c.Request.BasicAuth()
|
||||||
|
|
||||||
|
if ok {
|
||||||
|
userContext, headers, err := m.basicAuth(username, password)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tlog.App.Debug().Err(err).Msg("No valid session cookie found")
|
tlog.App.Error().Msgf("Error authenticating basic auth: %v", err)
|
||||||
goto basic
|
|
||||||
}
|
|
||||||
|
|
||||||
if cookie.TotpPending {
|
|
||||||
c.Set("context", &config.UserContext{
|
|
||||||
Username: cookie.Username,
|
|
||||||
Name: cookie.Name,
|
|
||||||
Email: cookie.Email,
|
|
||||||
Provider: "local",
|
|
||||||
TotpPending: true,
|
|
||||||
TotpEnabled: true,
|
|
||||||
})
|
|
||||||
c.Next()
|
c.Next()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
switch cookie.Provider {
|
for k, v := range headers {
|
||||||
case "local", "ldap":
|
c.Header(k, v)
|
||||||
userSearch := m.auth.SearchUser(cookie.Username)
|
|
||||||
|
|
||||||
if userSearch.Type == "unknown" {
|
|
||||||
tlog.App.Debug().Msg("User from session cookie not found")
|
|
||||||
m.auth.DeleteSessionCookie(c)
|
|
||||||
goto basic
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if userSearch.Type != cookie.Provider {
|
c.Set("context", userContext)
|
||||||
tlog.App.Warn().Msg("User type from session cookie does not match user search type")
|
|
||||||
m.auth.DeleteSessionCookie(c)
|
|
||||||
c.Next()
|
c.Next()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var ldapGroups []string
|
c.Next()
|
||||||
var localAttributes config.UserAttributes
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if cookie.Provider == "ldap" {
|
func (m *ContextMiddleware) cookieAuth(ctx context.Context, uuid string) (*model.UserContext, *http.Cookie, error) {
|
||||||
ldapUser, err := m.auth.GetLdapUser(userSearch.Username)
|
session, err := m.auth.GetSession(ctx, uuid)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tlog.App.Error().Err(err).Msg("Error retrieving LDAP user details")
|
return nil, nil, fmt.Errorf("error retrieving session: %w", err)
|
||||||
c.Next()
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
ldapGroups = ldapUser.Groups
|
userContext, err := new(model.UserContext).NewFromSession(session)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("error creating user context from session: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if cookie.Provider == "local" {
|
if userContext.Provider == model.ProviderLocal &&
|
||||||
localUser := m.auth.GetLocalUser(cookie.Username)
|
userContext.Local.TOTPPending {
|
||||||
localAttributes = localUser.Attributes
|
userContext.Local.TOTPEnabled = true
|
||||||
|
return userContext, nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
m.auth.RefreshSessionCookie(c)
|
switch userContext.Provider {
|
||||||
c.Set("context", &config.UserContext{
|
case model.ProviderLocal:
|
||||||
Username: cookie.Username,
|
user := m.auth.GetLocalUser(userContext.Local.Username)
|
||||||
Name: cookie.Name,
|
|
||||||
Email: cookie.Email,
|
if user == nil {
|
||||||
Provider: cookie.Provider,
|
return nil, nil, fmt.Errorf("local user not found")
|
||||||
IsLoggedIn: true,
|
}
|
||||||
LdapGroups: strings.Join(ldapGroups, ","),
|
|
||||||
Attributes: localAttributes,
|
userContext.Local.Attributes = user.Attributes
|
||||||
})
|
|
||||||
c.Next()
|
if userContext.Local.Attributes.Name == "" {
|
||||||
return
|
userContext.Local.Attributes.Name = utils.Capitalize(user.Username)
|
||||||
default:
|
}
|
||||||
_, exists := m.broker.GetService(cookie.Provider)
|
|
||||||
|
if userContext.Local.Attributes.Email == "" {
|
||||||
|
userContext.Local.Attributes.Email = utils.CompileUserEmail(user.Username, m.config.CookieDomain)
|
||||||
|
}
|
||||||
|
case model.ProviderLDAP:
|
||||||
|
search, err := m.auth.SearchUser(userContext.LDAP.Username)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("error searching for ldap user: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if search.Type != model.UserLDAP {
|
||||||
|
return nil, nil, fmt.Errorf("user from session cookie is not ldap")
|
||||||
|
}
|
||||||
|
|
||||||
|
user, err := m.auth.GetLDAPUser(search.Username)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("error retrieving ldap user details: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
userContext.LDAP.Groups = user.Groups
|
||||||
|
userContext.LDAP.Name = utils.Capitalize(userContext.LDAP.Username)
|
||||||
|
userContext.LDAP.Email = utils.CompileUserEmail(userContext.LDAP.Username, m.config.CookieDomain)
|
||||||
|
case model.ProviderOAuth:
|
||||||
|
_, exists := m.broker.GetService(userContext.OAuth.ID)
|
||||||
|
|
||||||
if !exists {
|
if !exists {
|
||||||
tlog.App.Debug().Msg("OAuth provider from session cookie not found")
|
return nil, nil, fmt.Errorf("oauth provider from session cookie not found: %s", userContext.OAuth.ID)
|
||||||
m.auth.DeleteSessionCookie(c)
|
|
||||||
goto basic
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if !m.auth.IsEmailWhitelisted(cookie.Email) {
|
if !m.auth.IsEmailWhitelisted(userContext.OAuth.Email) {
|
||||||
tlog.App.Debug().Msg("Email from session cookie not whitelisted")
|
m.auth.DeleteSession(ctx, uuid)
|
||||||
m.auth.DeleteSessionCookie(c)
|
return nil, nil, fmt.Errorf("email from session cookie not whitelisted: %s", userContext.OAuth.Email)
|
||||||
goto basic
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
m.auth.RefreshSessionCookie(c)
|
cookie, err := m.auth.RefreshSession(ctx, uuid)
|
||||||
c.Set("context", &config.UserContext{
|
|
||||||
Username: cookie.Username,
|
|
||||||
Name: cookie.Name,
|
|
||||||
Email: cookie.Email,
|
|
||||||
Provider: cookie.Provider,
|
|
||||||
OAuthGroups: cookie.OAuthGroups,
|
|
||||||
OAuthName: cookie.OAuthName,
|
|
||||||
OAuthSub: cookie.OAuthSub,
|
|
||||||
IsLoggedIn: true,
|
|
||||||
OAuth: true,
|
|
||||||
})
|
|
||||||
c.Next()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
basic:
|
|
||||||
basic := m.auth.GetBasicAuth(c)
|
|
||||||
|
|
||||||
if basic == nil {
|
|
||||||
tlog.App.Debug().Msg("No basic auth provided")
|
|
||||||
c.Next()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
locked, remaining := m.auth.IsAccountLocked(basic.Username)
|
|
||||||
|
|
||||||
if locked {
|
|
||||||
tlog.App.Debug().Msgf("Account for user %s is locked for %d seconds, denying auth", basic.Username, remaining)
|
|
||||||
c.Writer.Header().Add("x-tinyauth-lock-locked", "true")
|
|
||||||
c.Writer.Header().Add("x-tinyauth-lock-reset", time.Now().Add(time.Duration(remaining)*time.Second).Format(time.RFC3339))
|
|
||||||
c.Next()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
userSearch := m.auth.SearchUser(basic.Username)
|
|
||||||
|
|
||||||
if userSearch.Type == "unknown" || userSearch.Type == "error" {
|
|
||||||
m.auth.RecordLoginAttempt(basic.Username, false)
|
|
||||||
tlog.App.Debug().Msg("User from basic auth not found")
|
|
||||||
c.Next()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if !m.auth.VerifyUser(userSearch, basic.Password) {
|
|
||||||
m.auth.RecordLoginAttempt(basic.Username, false)
|
|
||||||
tlog.App.Debug().Msg("Invalid password for basic auth user")
|
|
||||||
c.Next()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
m.auth.RecordLoginAttempt(basic.Username, true)
|
|
||||||
|
|
||||||
switch userSearch.Type {
|
|
||||||
case "local":
|
|
||||||
tlog.App.Debug().Msg("Basic auth user is local")
|
|
||||||
|
|
||||||
user := m.auth.GetLocalUser(basic.Username)
|
|
||||||
|
|
||||||
if user.TotpSecret != "" {
|
|
||||||
tlog.App.Debug().Msg("User with TOTP not allowed to login via basic auth")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
name := utils.Capitalize(user.Username)
|
|
||||||
if user.Attributes.Name != "" {
|
|
||||||
name = user.Attributes.Name
|
|
||||||
}
|
|
||||||
email := utils.CompileUserEmail(user.Username, m.config.CookieDomain)
|
|
||||||
if user.Attributes.Email != "" {
|
|
||||||
email = user.Attributes.Email
|
|
||||||
}
|
|
||||||
|
|
||||||
c.Set("context", &config.UserContext{
|
|
||||||
Username: user.Username,
|
|
||||||
Name: name,
|
|
||||||
Email: email,
|
|
||||||
Provider: "local",
|
|
||||||
IsLoggedIn: true,
|
|
||||||
IsBasicAuth: true,
|
|
||||||
Attributes: user.Attributes,
|
|
||||||
})
|
|
||||||
c.Next()
|
|
||||||
return
|
|
||||||
case "ldap":
|
|
||||||
tlog.App.Debug().Msg("Basic auth user is LDAP")
|
|
||||||
|
|
||||||
ldapUser, err := m.auth.GetLdapUser(basic.Username)
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tlog.App.Debug().Err(err).Msg("Error retrieving LDAP user details")
|
return nil, nil, fmt.Errorf("error refreshing session: %w", err)
|
||||||
c.Next()
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Set("context", &config.UserContext{
|
return userContext, cookie, nil
|
||||||
Username: basic.Username,
|
}
|
||||||
Name: utils.Capitalize(basic.Username),
|
|
||||||
Email: utils.CompileUserEmail(basic.Username, m.config.CookieDomain),
|
func (m *ContextMiddleware) basicAuth(username string, password string) (*model.UserContext, map[string]string, error) {
|
||||||
Provider: "ldap",
|
headers := make(map[string]string)
|
||||||
IsLoggedIn: true,
|
userContext := new(model.UserContext)
|
||||||
LdapGroups: strings.Join(ldapUser.Groups, ","),
|
locked, remaining := m.auth.IsAccountLocked(username)
|
||||||
IsBasicAuth: true,
|
|
||||||
})
|
if locked {
|
||||||
c.Next()
|
tlog.App.Debug().Msgf("Account for user %s is locked for %d seconds, denying auth", username, remaining)
|
||||||
return
|
headers["x-tinyauth-lock-locked"] = "true"
|
||||||
|
headers["x-tinyauth-lock-reset"] = time.Now().Add(time.Duration(remaining) * time.Second).Format(time.RFC3339)
|
||||||
|
return nil, headers, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Next()
|
search, err := m.auth.SearchUser(username)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("error searching for user: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
err = m.auth.CheckUserPassword(*search, password)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
m.auth.RecordLoginAttempt(username, false)
|
||||||
|
return nil, nil, fmt.Errorf("invalid password for basic auth user: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
m.auth.RecordLoginAttempt(username, true)
|
||||||
|
|
||||||
|
switch search.Type {
|
||||||
|
case model.UserLocal:
|
||||||
|
user := m.auth.GetLocalUser(username)
|
||||||
|
|
||||||
|
if user.TOTPSecret != "" {
|
||||||
|
return nil, nil, fmt.Errorf("user with totp not allowed to login via basic auth: %s", username)
|
||||||
|
}
|
||||||
|
|
||||||
|
userContext.Local = &model.LocalContext{
|
||||||
|
BaseContext: model.BaseContext{
|
||||||
|
Username: user.Username,
|
||||||
|
Name: utils.Capitalize(user.Username),
|
||||||
|
Email: utils.CompileUserEmail(user.Username, m.config.CookieDomain),
|
||||||
|
},
|
||||||
|
Attributes: user.Attributes,
|
||||||
|
}
|
||||||
|
userContext.Provider = model.ProviderLocal
|
||||||
|
case model.UserLDAP:
|
||||||
|
user, err := m.auth.GetLDAPUser(username)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("error retrieving ldap user details: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
userContext.LDAP = &model.LDAPContext{
|
||||||
|
BaseContext: model.BaseContext{
|
||||||
|
Username: username,
|
||||||
|
Name: utils.Capitalize(username),
|
||||||
|
Email: utils.CompileUserEmail(username, m.config.CookieDomain),
|
||||||
|
},
|
||||||
|
Groups: user.Groups,
|
||||||
|
}
|
||||||
|
userContext.Provider = model.ProviderLDAP
|
||||||
|
}
|
||||||
|
|
||||||
|
userContext.Authenticated = true
|
||||||
|
return userContext, nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *ContextMiddleware) isIgnorePath(path string) bool {
|
func (m *ContextMiddleware) isIgnorePath(path string) bool {
|
||||||
|
|||||||
@@ -0,0 +1,330 @@
|
|||||||
|
package middleware_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"path"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/bootstrap"
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/middleware"
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestContextMiddleware(t *testing.T) {
|
||||||
|
tlog.NewTestLogger().Init()
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
|
||||||
|
authServiceCfg := service.AuthServiceConfig{
|
||||||
|
LocalUsers: &[]model.LocalUser{
|
||||||
|
{
|
||||||
|
Username: "testuser",
|
||||||
|
Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Username: "totpuser",
|
||||||
|
Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
|
||||||
|
TOTPSecret: "JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
SessionExpiry: 10, // 10 seconds, useful for testing
|
||||||
|
CookieDomain: "example.com",
|
||||||
|
LoginTimeout: 10, // 10 seconds, useful for testing
|
||||||
|
LoginMaxRetries: 3,
|
||||||
|
SessionCookieName: "tinyauth-session",
|
||||||
|
}
|
||||||
|
|
||||||
|
middlewareCfg := middleware.ContextMiddlewareConfig{
|
||||||
|
CookieDomain: "example.com",
|
||||||
|
SessionCookieName: "tinyauth-session",
|
||||||
|
}
|
||||||
|
|
||||||
|
basicAuthHeader := func(username, password string) string {
|
||||||
|
return "Basic " + base64.StdEncoding.EncodeToString([]byte(username+":"+password))
|
||||||
|
}
|
||||||
|
|
||||||
|
seedSession := func(t *testing.T, queries *repository.Queries, params repository.CreateSessionParams) {
|
||||||
|
t.Helper()
|
||||||
|
_, err := queries.CreateSession(context.Background(), params)
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
type runArgs struct {
|
||||||
|
do func(req *http.Request) (*model.UserContext, *httptest.ResponseRecorder)
|
||||||
|
queries *repository.Queries
|
||||||
|
}
|
||||||
|
|
||||||
|
type testCase struct {
|
||||||
|
description string
|
||||||
|
run func(t *testing.T, args runArgs)
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []testCase{
|
||||||
|
{
|
||||||
|
description: "Skip path bypasses auth processing",
|
||||||
|
run: func(t *testing.T, args runArgs) {
|
||||||
|
req := httptest.NewRequest("GET", "/api/healthz", nil)
|
||||||
|
req.Header.Set("Authorization", basicAuthHeader("testuser", "password"))
|
||||||
|
userCtx, _ := args.do(req)
|
||||||
|
|
||||||
|
assert.Nil(t, userCtx)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "No credentials yields no context",
|
||||||
|
run: func(t *testing.T, args runArgs) {
|
||||||
|
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||||
|
userCtx, _ := args.do(req)
|
||||||
|
|
||||||
|
assert.Nil(t, userCtx)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Valid session cookie sets authenticated local context",
|
||||||
|
run: func(t *testing.T, args runArgs) {
|
||||||
|
uuid := "session-valid-local"
|
||||||
|
seedSession(t, args.queries, repository.CreateSessionParams{
|
||||||
|
UUID: uuid,
|
||||||
|
Username: "testuser",
|
||||||
|
Provider: "local",
|
||||||
|
Expiry: time.Now().Add(10 * time.Second).Unix(),
|
||||||
|
CreatedAt: time.Now().Unix(),
|
||||||
|
})
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||||
|
req.AddCookie(&http.Cookie{Name: "tinyauth-session", Value: uuid})
|
||||||
|
userCtx, _ := args.do(req)
|
||||||
|
|
||||||
|
require.NotNil(t, userCtx)
|
||||||
|
assert.Equal(t, model.ProviderLocal, userCtx.Provider)
|
||||||
|
assert.Equal(t, "testuser", userCtx.GetUsername())
|
||||||
|
assert.True(t, userCtx.Authenticated)
|
||||||
|
require.NotNil(t, userCtx.Local)
|
||||||
|
assert.False(t, userCtx.Local.TOTPEnabled)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Session cookie with totp pending sets unauthenticated context with totp enabled",
|
||||||
|
run: func(t *testing.T, args runArgs) {
|
||||||
|
uuid := "session-totp-pending"
|
||||||
|
seedSession(t, args.queries, repository.CreateSessionParams{
|
||||||
|
UUID: uuid,
|
||||||
|
Username: "totpuser",
|
||||||
|
Provider: "local",
|
||||||
|
TotpPending: true,
|
||||||
|
Expiry: time.Now().Add(60 * time.Second).Unix(),
|
||||||
|
CreatedAt: time.Now().Unix(),
|
||||||
|
})
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||||
|
req.AddCookie(&http.Cookie{Name: "tinyauth-session", Value: uuid})
|
||||||
|
userCtx, _ := args.do(req)
|
||||||
|
|
||||||
|
require.NotNil(t, userCtx)
|
||||||
|
assert.Equal(t, "totpuser", userCtx.GetUsername())
|
||||||
|
assert.False(t, userCtx.Authenticated)
|
||||||
|
require.NotNil(t, userCtx.Local)
|
||||||
|
assert.True(t, userCtx.Local.TOTPPending)
|
||||||
|
assert.True(t, userCtx.Local.TOTPEnabled)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Unknown session cookie yields no context",
|
||||||
|
run: func(t *testing.T, args runArgs) {
|
||||||
|
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||||
|
req.AddCookie(&http.Cookie{Name: "tinyauth-session", Value: "does-not-exist"})
|
||||||
|
userCtx, _ := args.do(req)
|
||||||
|
|
||||||
|
assert.Nil(t, userCtx)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Session for missing local user yields no context",
|
||||||
|
run: func(t *testing.T, args runArgs) {
|
||||||
|
uuid := "session-deleted-user"
|
||||||
|
seedSession(t, args.queries, repository.CreateSessionParams{
|
||||||
|
UUID: uuid,
|
||||||
|
Username: "ghostuser",
|
||||||
|
Provider: "local",
|
||||||
|
Expiry: time.Now().Add(10 * time.Second).Unix(),
|
||||||
|
CreatedAt: time.Now().Unix(),
|
||||||
|
})
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||||
|
req.AddCookie(&http.Cookie{Name: "tinyauth-session", Value: uuid})
|
||||||
|
userCtx, _ := args.do(req)
|
||||||
|
|
||||||
|
assert.Nil(t, userCtx)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Expired session cookie yields no context",
|
||||||
|
run: func(t *testing.T, args runArgs) {
|
||||||
|
uuid := "session-expired"
|
||||||
|
seedSession(t, args.queries, repository.CreateSessionParams{
|
||||||
|
UUID: uuid,
|
||||||
|
Username: "testuser",
|
||||||
|
Provider: "local",
|
||||||
|
Expiry: time.Now().Add(-1 * time.Second).Unix(),
|
||||||
|
CreatedAt: time.Now().Add(-10 * time.Second).Unix(),
|
||||||
|
})
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||||
|
req.AddCookie(&http.Cookie{Name: "tinyauth-session", Value: uuid})
|
||||||
|
userCtx, _ := args.do(req)
|
||||||
|
|
||||||
|
assert.Nil(t, userCtx)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Valid basic auth sets authenticated local context",
|
||||||
|
run: func(t *testing.T, args runArgs) {
|
||||||
|
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||||
|
req.Header.Set("Authorization", basicAuthHeader("testuser", "password"))
|
||||||
|
userCtx, _ := args.do(req)
|
||||||
|
|
||||||
|
require.NotNil(t, userCtx)
|
||||||
|
assert.Equal(t, model.ProviderLocal, userCtx.Provider)
|
||||||
|
assert.Equal(t, "testuser", userCtx.GetUsername())
|
||||||
|
assert.True(t, userCtx.Authenticated)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Invalid basic auth password yields no context",
|
||||||
|
run: func(t *testing.T, args runArgs) {
|
||||||
|
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||||
|
req.Header.Set("Authorization", basicAuthHeader("testuser", "wrongpassword"))
|
||||||
|
userCtx, _ := args.do(req)
|
||||||
|
|
||||||
|
assert.Nil(t, userCtx)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Basic auth is rejected for users with totp",
|
||||||
|
run: func(t *testing.T, args runArgs) {
|
||||||
|
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||||
|
req.Header.Set("Authorization", basicAuthHeader("totpuser", "password"))
|
||||||
|
userCtx, _ := args.do(req)
|
||||||
|
|
||||||
|
assert.Nil(t, userCtx)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Locked account on basic auth sets lock headers",
|
||||||
|
run: func(t *testing.T, args runArgs) {
|
||||||
|
for range 3 {
|
||||||
|
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||||
|
req.Header.Set("Authorization", basicAuthHeader("testuser", "wrongpassword"))
|
||||||
|
args.do(req)
|
||||||
|
}
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||||
|
req.Header.Set("Authorization", basicAuthHeader("testuser", "password"))
|
||||||
|
userCtx, recorder := args.do(req)
|
||||||
|
|
||||||
|
assert.Nil(t, userCtx)
|
||||||
|
assert.Equal(t, "true", recorder.Header().Get("x-tinyauth-lock-locked"))
|
||||||
|
assert.NotEmpty(t, recorder.Header().Get("x-tinyauth-lock-reset"))
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Cookie auth takes precedence over basic auth",
|
||||||
|
run: func(t *testing.T, args runArgs) {
|
||||||
|
uuid := "session-precedence"
|
||||||
|
seedSession(t, args.queries, repository.CreateSessionParams{
|
||||||
|
UUID: uuid,
|
||||||
|
Username: "testuser",
|
||||||
|
Provider: "local",
|
||||||
|
Expiry: time.Now().Add(10 * time.Second).Unix(),
|
||||||
|
CreatedAt: time.Now().Unix(),
|
||||||
|
})
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||||
|
req.AddCookie(&http.Cookie{Name: "tinyauth-session", Value: uuid})
|
||||||
|
req.Header.Set("Authorization", basicAuthHeader("totpuser", "password"))
|
||||||
|
userCtx, _ := args.do(req)
|
||||||
|
|
||||||
|
require.NotNil(t, userCtx)
|
||||||
|
assert.Equal(t, "testuser", userCtx.GetUsername())
|
||||||
|
assert.True(t, userCtx.Authenticated)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Ensure fallback to basic auth when cookie is missing",
|
||||||
|
run: func(t *testing.T, args runArgs) {
|
||||||
|
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||||
|
req.Header.Set("Authorization", basicAuthHeader("testuser", "password"))
|
||||||
|
userCtx, _ := args.do(req)
|
||||||
|
|
||||||
|
require.NotNil(t, userCtx)
|
||||||
|
assert.Equal(t, "testuser", userCtx.GetUsername())
|
||||||
|
assert.True(t, userCtx.Authenticated)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
oauthBrokerCfgs := make(map[string]model.OAuthServiceConfig)
|
||||||
|
|
||||||
|
app := bootstrap.NewBootstrapApp(model.Config{})
|
||||||
|
|
||||||
|
db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
queries := repository.New(db)
|
||||||
|
|
||||||
|
ldap := service.NewLdapService(service.LdapServiceConfig{})
|
||||||
|
err = ldap.Init()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
broker := service.NewOAuthBrokerService(oauthBrokerCfgs)
|
||||||
|
err = broker.Init()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
authService := service.NewAuthService(authServiceCfg, ldap, queries, broker)
|
||||||
|
err = authService.Init()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
contextMiddleware := middleware.NewContextMiddleware(middlewareCfg, authService, broker)
|
||||||
|
err = contextMiddleware.Init()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
for _, test := range tests {
|
||||||
|
authService.ClearRateLimitsTestingOnly()
|
||||||
|
t.Run(test.description, func(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
do := func(req *http.Request) (*model.UserContext, *httptest.ResponseRecorder) {
|
||||||
|
var captured *model.UserContext
|
||||||
|
router := gin.New()
|
||||||
|
router.Use(contextMiddleware.Middleware())
|
||||||
|
handler := func(c *gin.Context) {
|
||||||
|
if val, exists := c.Get("context"); exists {
|
||||||
|
captured, _ = val.(*model.UserContext)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
router.GET("/api/test", handler)
|
||||||
|
router.GET("/api/healthz", handler)
|
||||||
|
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
return captured, recorder
|
||||||
|
}
|
||||||
|
|
||||||
|
test.run(t, runArgs{do: do, queries: queries})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Cleanup(func() {
|
||||||
|
err = db.Close()
|
||||||
|
require.NoError(t, err)
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package config
|
package model
|
||||||
|
|
||||||
// Default configuration
|
// Default configuration
|
||||||
func NewDefaultConfiguration() *Config {
|
func NewDefaultConfiguration() *Config {
|
||||||
@@ -29,7 +29,7 @@ func NewDefaultConfiguration() *Config {
|
|||||||
BackgroundImage: "/background.jpg",
|
BackgroundImage: "/background.jpg",
|
||||||
WarningsEnabled: true,
|
WarningsEnabled: true,
|
||||||
},
|
},
|
||||||
Ldap: LdapConfig{
|
LDAP: LDAPConfig{
|
||||||
Insecure: false,
|
Insecure: false,
|
||||||
SearchFilter: "(uid=%s)",
|
SearchFilter: "(uid=%s)",
|
||||||
GroupCacheTTL: 900, // 15 minutes
|
GroupCacheTTL: 900, // 15 minutes
|
||||||
@@ -63,20 +63,6 @@ func NewDefaultConfiguration() *Config {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Version information, set at build time
|
|
||||||
|
|
||||||
var Version = "development"
|
|
||||||
var CommitHash = "development"
|
|
||||||
var BuildTimestamp = "0000-00-00T00:00:00Z"
|
|
||||||
|
|
||||||
// Cookie name templates
|
|
||||||
|
|
||||||
var SessionCookieName = "tinyauth-session"
|
|
||||||
var CSRFCookieName = "tinyauth-csrf"
|
|
||||||
var RedirectCookieName = "tinyauth-redirect"
|
|
||||||
var OAuthSessionCookieName = "tinyauth-oauth"
|
|
||||||
|
|
||||||
// Main app config
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
AppURL string `description:"The base URL where the app is hosted." yaml:"appUrl"`
|
AppURL string `description:"The base URL where the app is hosted." yaml:"appUrl"`
|
||||||
Database DatabaseConfig `description:"Database configuration." yaml:"database"`
|
Database DatabaseConfig `description:"Database configuration." yaml:"database"`
|
||||||
@@ -88,7 +74,7 @@ type Config struct {
|
|||||||
OAuth OAuthConfig `description:"OAuth configuration." yaml:"oauth"`
|
OAuth OAuthConfig `description:"OAuth configuration." yaml:"oauth"`
|
||||||
OIDC OIDCConfig `description:"OIDC configuration." yaml:"oidc"`
|
OIDC OIDCConfig `description:"OIDC configuration." yaml:"oidc"`
|
||||||
UI UIConfig `description:"UI customization." yaml:"ui"`
|
UI UIConfig `description:"UI customization." yaml:"ui"`
|
||||||
Ldap LdapConfig `description:"LDAP configuration." yaml:"ldap"`
|
LDAP LDAPConfig `description:"LDAP configuration." yaml:"ldap"`
|
||||||
Experimental ExperimentalConfig `description:"Experimental features, use with caution." yaml:"experimental"`
|
Experimental ExperimentalConfig `description:"Experimental features, use with caution." yaml:"experimental"`
|
||||||
LabelProvider string `description:"Label provider to use for ACLs (auto, docker, or kubernetes). auto detects the environment." yaml:"labelProvider"`
|
LabelProvider string `description:"Label provider to use for ACLs (auto, docker, or kubernetes). auto detects the environment." yaml:"labelProvider"`
|
||||||
Log LogConfig `description:"Logging configuration." yaml:"log"`
|
Log LogConfig `description:"Logging configuration." yaml:"log"`
|
||||||
@@ -177,7 +163,7 @@ type UIConfig struct {
|
|||||||
WarningsEnabled bool `description:"Enable UI warnings." yaml:"warningsEnabled"`
|
WarningsEnabled bool `description:"Enable UI warnings." yaml:"warningsEnabled"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type LdapConfig struct {
|
type LDAPConfig struct {
|
||||||
Address string `description:"LDAP server address." yaml:"address"`
|
Address string `description:"LDAP server address." yaml:"address"`
|
||||||
BindDN string `description:"Bind DN for LDAP authentication." yaml:"bindDn"`
|
BindDN string `description:"Bind DN for LDAP authentication." yaml:"bindDn"`
|
||||||
BindPassword string `description:"Bind password for LDAP authentication." yaml:"bindPassword"`
|
BindPassword string `description:"Bind password for LDAP authentication." yaml:"bindPassword"`
|
||||||
@@ -210,20 +196,6 @@ type ExperimentalConfig struct {
|
|||||||
ConfigFile string `description:"Path to config file." yaml:"-"`
|
ConfigFile string `description:"Path to config file." yaml:"-"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Config loader options
|
|
||||||
|
|
||||||
const DefaultNamePrefix = "TINYAUTH_"
|
|
||||||
|
|
||||||
// OAuth/OIDC config
|
|
||||||
|
|
||||||
type Claims struct {
|
|
||||||
Sub string `json:"sub"`
|
|
||||||
Name string `json:"name"`
|
|
||||||
Email string `json:"email"`
|
|
||||||
PreferredUsername string `json:"preferred_username"`
|
|
||||||
Groups any `json:"groups"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type OAuthServiceConfig struct {
|
type OAuthServiceConfig struct {
|
||||||
ClientID string `description:"OAuth client ID." yaml:"clientId"`
|
ClientID string `description:"OAuth client ID." yaml:"clientId"`
|
||||||
ClientSecret string `description:"OAuth client secret." yaml:"clientSecret"`
|
ClientSecret string `description:"OAuth client secret." yaml:"clientSecret"`
|
||||||
@@ -246,60 +218,6 @@ type OIDCClientConfig struct {
|
|||||||
Name string `description:"Client name in UI." yaml:"name"`
|
Name string `description:"Client name in UI." yaml:"name"`
|
||||||
}
|
}
|
||||||
|
|
||||||
var OverrideProviders = map[string]string{
|
|
||||||
"google": "Google",
|
|
||||||
"github": "GitHub",
|
|
||||||
}
|
|
||||||
|
|
||||||
// User/session related stuff
|
|
||||||
|
|
||||||
type User struct {
|
|
||||||
Username string
|
|
||||||
Password string
|
|
||||||
TotpSecret string
|
|
||||||
Attributes UserAttributes
|
|
||||||
}
|
|
||||||
|
|
||||||
type LdapUser struct {
|
|
||||||
DN string
|
|
||||||
Groups []string
|
|
||||||
}
|
|
||||||
|
|
||||||
type UserSearch struct {
|
|
||||||
Username string
|
|
||||||
Type string // local, ldap or unknown
|
|
||||||
}
|
|
||||||
|
|
||||||
type UserContext struct {
|
|
||||||
Username string
|
|
||||||
Name string
|
|
||||||
Email string
|
|
||||||
IsLoggedIn bool
|
|
||||||
IsBasicAuth bool
|
|
||||||
OAuth bool
|
|
||||||
Provider string
|
|
||||||
TotpPending bool
|
|
||||||
OAuthGroups string
|
|
||||||
TotpEnabled bool
|
|
||||||
OAuthName string
|
|
||||||
OAuthSub string
|
|
||||||
LdapGroups string
|
|
||||||
Attributes UserAttributes
|
|
||||||
}
|
|
||||||
|
|
||||||
// API responses and queries
|
|
||||||
|
|
||||||
type UnauthorizedQuery struct {
|
|
||||||
Username string `url:"username"`
|
|
||||||
Resource string `url:"resource"`
|
|
||||||
GroupErr bool `url:"groupErr"`
|
|
||||||
IP string `url:"ip"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type RedirectQuery struct {
|
|
||||||
RedirectURI string `url:"redirect_uri"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// ACLs
|
// ACLs
|
||||||
|
|
||||||
type Apps struct {
|
type Apps struct {
|
||||||
@@ -355,7 +273,3 @@ type AppPath struct {
|
|||||||
Allow string `description:"Comma-separated list of allowed paths." yaml:"allow"`
|
Allow string `description:"Comma-separated list of allowed paths." yaml:"allow"`
|
||||||
Block string `description:"Comma-separated list of blocked paths." yaml:"block"`
|
Block string `description:"Comma-separated list of blocked paths." yaml:"block"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// API server
|
|
||||||
|
|
||||||
var ApiServer = "https://api.tinyauth.app"
|
|
||||||
@@ -0,0 +1,23 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
const DefaultNamePrefix = "TINYAUTH_"
|
||||||
|
|
||||||
|
const APIServer = "https://api.tinyauth.app"
|
||||||
|
|
||||||
|
type Claims struct {
|
||||||
|
Sub string `json:"sub"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Email string `json:"email"`
|
||||||
|
PreferredUsername string `json:"preferred_username"`
|
||||||
|
Groups any `json:"groups"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var OverrideProviders = map[string]string{
|
||||||
|
"google": "Google",
|
||||||
|
"github": "GitHub",
|
||||||
|
}
|
||||||
|
|
||||||
|
const SessionCookieName = "tinyauth-session"
|
||||||
|
const CSRFCookieName = "tinyauth-csrf"
|
||||||
|
const RedirectCookieName = "tinyauth-redirect"
|
||||||
|
const OAuthSessionCookieName = "tinyauth-oauth"
|
||||||
@@ -0,0 +1,251 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ProviderType int
|
||||||
|
|
||||||
|
const (
|
||||||
|
ProviderLocal ProviderType = iota
|
||||||
|
ProviderBasicAuth
|
||||||
|
ProviderOAuth
|
||||||
|
ProviderLDAP
|
||||||
|
)
|
||||||
|
|
||||||
|
type UserContext struct {
|
||||||
|
Authenticated bool
|
||||||
|
Provider ProviderType
|
||||||
|
Local *LocalContext
|
||||||
|
OAuth *OAuthContext
|
||||||
|
LDAP *LDAPContext
|
||||||
|
}
|
||||||
|
|
||||||
|
type BaseContext struct {
|
||||||
|
Username string
|
||||||
|
Name string
|
||||||
|
Email string
|
||||||
|
}
|
||||||
|
|
||||||
|
type LocalContext struct {
|
||||||
|
BaseContext
|
||||||
|
TOTPPending bool
|
||||||
|
TOTPEnabled bool
|
||||||
|
Attributes UserAttributes
|
||||||
|
}
|
||||||
|
|
||||||
|
type OAuthContext struct {
|
||||||
|
BaseContext
|
||||||
|
Groups []string
|
||||||
|
Sub string
|
||||||
|
DisplayName string
|
||||||
|
ID string
|
||||||
|
}
|
||||||
|
|
||||||
|
type LDAPContext struct {
|
||||||
|
BaseContext
|
||||||
|
Groups []string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *UserContext) IsAuthenticated() bool {
|
||||||
|
return c.Authenticated
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *UserContext) IsLocal() bool {
|
||||||
|
return c.Provider == ProviderLocal && c.Local != nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *UserContext) IsOAuth() bool {
|
||||||
|
return c.Provider == ProviderOAuth && c.OAuth != nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *UserContext) IsLDAP() bool {
|
||||||
|
return c.Provider == ProviderLDAP && c.LDAP != nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *UserContext) IsBasicAuth() bool {
|
||||||
|
return c.Provider == ProviderBasicAuth && c.Local != nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *UserContext) NewFromGin(ginctx *gin.Context) (*UserContext, error) {
|
||||||
|
userContextValue, exists := ginctx.Get("context")
|
||||||
|
|
||||||
|
if !exists {
|
||||||
|
return nil, errors.New("failed to get user context")
|
||||||
|
}
|
||||||
|
|
||||||
|
userContext, ok := userContextValue.(*UserContext)
|
||||||
|
|
||||||
|
if !ok || userContext == nil {
|
||||||
|
return nil, errors.New("invalid user context type")
|
||||||
|
}
|
||||||
|
|
||||||
|
if userContext.LDAP == nil && userContext.Local == nil && userContext.OAuth == nil {
|
||||||
|
return nil, errors.New("incomplete user context")
|
||||||
|
}
|
||||||
|
|
||||||
|
*c = *userContext
|
||||||
|
return c, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compatability layer until we get an excuse to drop in database migrations
|
||||||
|
func (c *UserContext) NewFromSession(session *repository.Session) (*UserContext, error) {
|
||||||
|
*c = UserContext{
|
||||||
|
Authenticated: !session.TotpPending,
|
||||||
|
}
|
||||||
|
|
||||||
|
switch session.Provider {
|
||||||
|
case "local":
|
||||||
|
c.Provider = ProviderLocal
|
||||||
|
c.Local = &LocalContext{
|
||||||
|
BaseContext: BaseContext{
|
||||||
|
Username: session.Username,
|
||||||
|
Name: session.Name,
|
||||||
|
Email: session.Email,
|
||||||
|
},
|
||||||
|
TOTPPending: session.TotpPending,
|
||||||
|
}
|
||||||
|
case "ldap":
|
||||||
|
c.Provider = ProviderLDAP
|
||||||
|
c.LDAP = &LDAPContext{
|
||||||
|
BaseContext: BaseContext{
|
||||||
|
Username: session.Username,
|
||||||
|
Name: session.Name,
|
||||||
|
Email: session.Email,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
// By default we assume an unkown name which is oauth
|
||||||
|
default:
|
||||||
|
c.Provider = ProviderOAuth
|
||||||
|
c.OAuth = &OAuthContext{
|
||||||
|
BaseContext: BaseContext{
|
||||||
|
Username: session.Username,
|
||||||
|
Name: session.Name,
|
||||||
|
Email: session.Email,
|
||||||
|
},
|
||||||
|
Groups: func() []string {
|
||||||
|
if session.OAuthGroups == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return strings.Split(session.OAuthGroups, ",")
|
||||||
|
}(),
|
||||||
|
Sub: session.OAuthSub,
|
||||||
|
DisplayName: session.OAuthName,
|
||||||
|
ID: session.Provider,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return c, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *UserContext) GetUsername() string {
|
||||||
|
switch c.Provider {
|
||||||
|
case ProviderLocal:
|
||||||
|
if c.Local == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return c.Local.Username
|
||||||
|
case ProviderLDAP:
|
||||||
|
if c.LDAP == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return c.LDAP.Username
|
||||||
|
case ProviderBasicAuth:
|
||||||
|
if c.Local == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return c.Local.Username
|
||||||
|
case ProviderOAuth:
|
||||||
|
if c.OAuth == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return c.OAuth.Username
|
||||||
|
default:
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *UserContext) GetEmail() string {
|
||||||
|
switch c.Provider {
|
||||||
|
case ProviderLocal:
|
||||||
|
if c.Local == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return c.Local.Email
|
||||||
|
case ProviderLDAP:
|
||||||
|
if c.LDAP == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return c.LDAP.Email
|
||||||
|
case ProviderBasicAuth:
|
||||||
|
if c.Local == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return c.Local.Email
|
||||||
|
case ProviderOAuth:
|
||||||
|
if c.OAuth == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return c.OAuth.Email
|
||||||
|
default:
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *UserContext) GetName() string {
|
||||||
|
switch c.Provider {
|
||||||
|
case ProviderLocal:
|
||||||
|
if c.Local == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return c.Local.Name
|
||||||
|
case ProviderLDAP:
|
||||||
|
if c.LDAP == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return c.LDAP.Name
|
||||||
|
case ProviderBasicAuth:
|
||||||
|
if c.Local == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return c.Local.Name
|
||||||
|
case ProviderOAuth:
|
||||||
|
if c.OAuth == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return c.OAuth.Name
|
||||||
|
default:
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *UserContext) ProviderName() string {
|
||||||
|
switch c.Provider {
|
||||||
|
case ProviderBasicAuth, ProviderLocal:
|
||||||
|
return "local"
|
||||||
|
case ProviderLDAP:
|
||||||
|
return "ldap"
|
||||||
|
case ProviderOAuth:
|
||||||
|
return c.OAuth.DisplayName // compatability
|
||||||
|
default:
|
||||||
|
return "unknown"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *UserContext) TOTPPending() bool {
|
||||||
|
if c.Provider == ProviderLocal && c.Local != nil {
|
||||||
|
return c.Local.TOTPPending
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *UserContext) OAuthName() string {
|
||||||
|
if c.Provider == ProviderOAuth && c.OAuth != nil {
|
||||||
|
return c.OAuth.DisplayName
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
@@ -0,0 +1,276 @@
|
|||||||
|
package model_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestContext(t *testing.T) {
|
||||||
|
newGinCtx := func(value any, set bool) *gin.Context {
|
||||||
|
c, _ := gin.CreateTestContext(httptest.NewRecorder())
|
||||||
|
if set {
|
||||||
|
c.Set("context", value)
|
||||||
|
}
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
description string
|
||||||
|
context *model.UserContext
|
||||||
|
run func(*testing.T, *model.UserContext) any
|
||||||
|
expected any
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
description: "IsAuthenticated reflects Authenticated field",
|
||||||
|
context: &model.UserContext{Authenticated: true},
|
||||||
|
run: func(t *testing.T, c *model.UserContext) any { return c.IsAuthenticated() },
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "IsLocal returns true for ProviderLocal",
|
||||||
|
context: &model.UserContext{Provider: model.ProviderLocal, Local: &model.LocalContext{}},
|
||||||
|
run: func(t *testing.T, c *model.UserContext) any { return c.IsLocal() },
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "IsOAuth returns true for ProviderOAuth",
|
||||||
|
context: &model.UserContext{Provider: model.ProviderOAuth, OAuth: &model.OAuthContext{}},
|
||||||
|
run: func(t *testing.T, c *model.UserContext) any { return c.IsOAuth() },
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "IsLDAP returns true for ProviderLDAP",
|
||||||
|
context: &model.UserContext{Provider: model.ProviderLDAP, LDAP: &model.LDAPContext{}},
|
||||||
|
run: func(t *testing.T, c *model.UserContext) any { return c.IsLDAP() },
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "IsBasicAuth returns true for ProviderBasicAuth",
|
||||||
|
context: &model.UserContext{Provider: model.ProviderBasicAuth, Local: &model.LocalContext{}},
|
||||||
|
run: func(t *testing.T, c *model.UserContext) any { return c.IsBasicAuth() },
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "NewFromSession local session is authenticated and ProviderLocal",
|
||||||
|
context: &model.UserContext{},
|
||||||
|
run: func(t *testing.T, c *model.UserContext) any {
|
||||||
|
got, err := c.NewFromSession(&repository.Session{
|
||||||
|
Username: "alice", Email: "alice@example.com", Name: "Alice",
|
||||||
|
Provider: "local",
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
return [2]any{got.Provider, got.Authenticated}
|
||||||
|
},
|
||||||
|
expected: [2]any{model.ProviderLocal, true},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "NewFromSession local session with TotpPending is not authenticated",
|
||||||
|
context: &model.UserContext{},
|
||||||
|
run: func(t *testing.T, c *model.UserContext) any {
|
||||||
|
got, err := c.NewFromSession(&repository.Session{
|
||||||
|
Username: "bob", Provider: "local", TotpPending: true,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
return got.Authenticated
|
||||||
|
},
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "NewFromSession ldap session is ProviderLDAP",
|
||||||
|
context: &model.UserContext{},
|
||||||
|
run: func(t *testing.T, c *model.UserContext) any {
|
||||||
|
got, err := c.NewFromSession(&repository.Session{
|
||||||
|
Username: "carol", Provider: "ldap",
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
return got.Provider
|
||||||
|
},
|
||||||
|
expected: model.ProviderLDAP,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "NewFromSession unknown provider defaults to OAuth and populates oauth fields",
|
||||||
|
context: &model.UserContext{},
|
||||||
|
run: func(t *testing.T, c *model.UserContext) any {
|
||||||
|
got, err := c.NewFromSession(&repository.Session{
|
||||||
|
Username: "dave", Provider: "github",
|
||||||
|
OAuthGroups: "devs,admins", OAuthSub: "sub-123", OAuthName: "GitHub",
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
return [5]any{got.Provider, got.OAuth.ID, got.OAuth.Sub, got.OAuth.DisplayName, got.OAuth.Groups}
|
||||||
|
},
|
||||||
|
expected: [5]any{model.ProviderOAuth, "github", "sub-123", "GitHub", []string{"devs", "admins"}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Local getters return BaseContext fields",
|
||||||
|
context: &model.UserContext{
|
||||||
|
Provider: model.ProviderLocal,
|
||||||
|
Local: &model.LocalContext{BaseContext: model.BaseContext{Username: "alice", Email: "alice@example.com", Name: "Alice"}},
|
||||||
|
},
|
||||||
|
run: func(t *testing.T, c *model.UserContext) any {
|
||||||
|
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
|
||||||
|
},
|
||||||
|
expected: [3]string{"alice", "alice@example.com", "Alice"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "BasicAuth getters fall back to local fields",
|
||||||
|
context: &model.UserContext{
|
||||||
|
Provider: model.ProviderBasicAuth,
|
||||||
|
Local: &model.LocalContext{BaseContext: model.BaseContext{Username: "bob", Email: "bob@example.com", Name: "Bob"}},
|
||||||
|
},
|
||||||
|
run: func(t *testing.T, c *model.UserContext) any {
|
||||||
|
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
|
||||||
|
},
|
||||||
|
expected: [3]string{"bob", "bob@example.com", "Bob"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "LDAP getters return LDAP fields",
|
||||||
|
context: &model.UserContext{
|
||||||
|
Provider: model.ProviderLDAP,
|
||||||
|
LDAP: &model.LDAPContext{BaseContext: model.BaseContext{Username: "carol", Email: "carol@example.com", Name: "Carol"}},
|
||||||
|
},
|
||||||
|
run: func(t *testing.T, c *model.UserContext) any {
|
||||||
|
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
|
||||||
|
},
|
||||||
|
expected: [3]string{"carol", "carol@example.com", "Carol"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "OAuth getters return OAuth fields",
|
||||||
|
context: &model.UserContext{
|
||||||
|
Provider: model.ProviderOAuth,
|
||||||
|
OAuth: &model.OAuthContext{BaseContext: model.BaseContext{Username: "dave", Email: "dave@example.com", Name: "Dave"}},
|
||||||
|
},
|
||||||
|
run: func(t *testing.T, c *model.UserContext) any {
|
||||||
|
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
|
||||||
|
},
|
||||||
|
expected: [3]string{"dave", "dave@example.com", "Dave"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "ProviderName returns 'local' for ProviderLocal",
|
||||||
|
context: &model.UserContext{Provider: model.ProviderLocal},
|
||||||
|
run: func(t *testing.T, c *model.UserContext) any { return c.ProviderName() },
|
||||||
|
expected: "local",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "ProviderName returns 'local' for ProviderBasicAuth",
|
||||||
|
context: &model.UserContext{Provider: model.ProviderBasicAuth},
|
||||||
|
run: func(t *testing.T, c *model.UserContext) any { return c.ProviderName() },
|
||||||
|
expected: "local",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "ProviderName returns 'ldap' for ProviderLDAP",
|
||||||
|
context: &model.UserContext{Provider: model.ProviderLDAP},
|
||||||
|
run: func(t *testing.T, c *model.UserContext) any { return c.ProviderName() },
|
||||||
|
expected: "ldap",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "ProviderName returns OAuth DisplayName for ProviderOAuth",
|
||||||
|
context: &model.UserContext{
|
||||||
|
Provider: model.ProviderOAuth,
|
||||||
|
OAuth: &model.OAuthContext{DisplayName: "GitHub"},
|
||||||
|
},
|
||||||
|
run: func(t *testing.T, c *model.UserContext) any { return c.ProviderName() },
|
||||||
|
expected: "GitHub",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "TOTPPending returns true when local context is pending",
|
||||||
|
context: &model.UserContext{
|
||||||
|
Provider: model.ProviderLocal,
|
||||||
|
Local: &model.LocalContext{TOTPPending: true},
|
||||||
|
},
|
||||||
|
run: func(t *testing.T, c *model.UserContext) any { return c.TOTPPending() },
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "TOTPPending returns false when local context is not pending",
|
||||||
|
context: &model.UserContext{
|
||||||
|
Provider: model.ProviderLocal,
|
||||||
|
Local: &model.LocalContext{TOTPPending: false},
|
||||||
|
},
|
||||||
|
run: func(t *testing.T, c *model.UserContext) any { return c.TOTPPending() },
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "TOTPPending returns false for non-local providers",
|
||||||
|
context: &model.UserContext{Provider: model.ProviderOAuth, OAuth: &model.OAuthContext{}},
|
||||||
|
run: func(t *testing.T, c *model.UserContext) any { return c.TOTPPending() },
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "OAuthName returns DisplayName for ProviderOAuth",
|
||||||
|
context: &model.UserContext{
|
||||||
|
Provider: model.ProviderOAuth,
|
||||||
|
OAuth: &model.OAuthContext{DisplayName: "Google"},
|
||||||
|
},
|
||||||
|
run: func(t *testing.T, c *model.UserContext) any { return c.OAuthName() },
|
||||||
|
expected: "Google",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "OAuthName returns empty string for non-oauth providers",
|
||||||
|
context: &model.UserContext{Provider: model.ProviderLocal, Local: &model.LocalContext{}},
|
||||||
|
run: func(t *testing.T, c *model.UserContext) any { return c.OAuthName() },
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "NewFromGin populates context from gin value",
|
||||||
|
context: &model.UserContext{},
|
||||||
|
run: func(t *testing.T, c *model.UserContext) any {
|
||||||
|
stored := &model.UserContext{
|
||||||
|
Authenticated: true,
|
||||||
|
Provider: model.ProviderLocal,
|
||||||
|
Local: &model.LocalContext{BaseContext: model.BaseContext{Username: "alice"}},
|
||||||
|
}
|
||||||
|
got, err := c.NewFromGin(newGinCtx(stored, true))
|
||||||
|
require.NoError(t, err)
|
||||||
|
return [2]any{got.Authenticated, got.GetUsername()}
|
||||||
|
},
|
||||||
|
expected: [2]any{true, "alice"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "NewFromGin returns error when context value is missing",
|
||||||
|
context: &model.UserContext{},
|
||||||
|
run: func(t *testing.T, c *model.UserContext) any {
|
||||||
|
_, err := c.NewFromGin(newGinCtx(nil, false))
|
||||||
|
return err.Error()
|
||||||
|
},
|
||||||
|
expected: "failed to get user context",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "NewFromGin returns error when context value has wrong type",
|
||||||
|
context: &model.UserContext{},
|
||||||
|
run: func(t *testing.T, c *model.UserContext) any {
|
||||||
|
_, err := c.NewFromGin(newGinCtx("not a user context", true))
|
||||||
|
return err.Error()
|
||||||
|
},
|
||||||
|
expected: "invalid user context type",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "NewFromGin returns an error when context doesn't include user information",
|
||||||
|
context: &model.UserContext{},
|
||||||
|
run: func(t *testing.T, c *model.UserContext) any {
|
||||||
|
_, err := c.NewFromGin(newGinCtx(&model.UserContext{Provider: model.ProviderLocal}, true))
|
||||||
|
return err.Error()
|
||||||
|
},
|
||||||
|
expected: "incomplete user context",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Getters should not panic if provider context is empty",
|
||||||
|
context: &model.UserContext{Provider: model.ProviderLocal},
|
||||||
|
run: func(t *testing.T, c *model.UserContext) any {
|
||||||
|
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
|
||||||
|
},
|
||||||
|
expected: [3]string{"", "", ""},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range tests {
|
||||||
|
t.Run(test.description, func(t *testing.T) {
|
||||||
|
assert.Equal(t, test.expected, test.run(t, test.context))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,25 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
type UserSearchType int
|
||||||
|
|
||||||
|
const (
|
||||||
|
UserLocal UserSearchType = iota
|
||||||
|
UserLDAP
|
||||||
|
)
|
||||||
|
|
||||||
|
type LDAPUser struct {
|
||||||
|
DN string
|
||||||
|
Groups []string
|
||||||
|
}
|
||||||
|
|
||||||
|
type LocalUser struct {
|
||||||
|
Username string
|
||||||
|
Password string
|
||||||
|
TOTPSecret string
|
||||||
|
Attributes UserAttributes
|
||||||
|
}
|
||||||
|
|
||||||
|
type UserSearch struct {
|
||||||
|
Username string
|
||||||
|
Type UserSearchType
|
||||||
|
}
|
||||||
@@ -0,0 +1,5 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
var Version = "development"
|
||||||
|
var CommitHash = "development"
|
||||||
|
var BuildTimestamp = "0000-00-00T00:00:00Z"
|
||||||
@@ -1,23 +1,22 @@
|
|||||||
package service
|
package service
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/config"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||||
)
|
)
|
||||||
|
|
||||||
type LabelProvider interface {
|
type LabelProvider interface {
|
||||||
GetLabels(appDomain string) (config.App, error)
|
GetLabels(appDomain string) (*model.App, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type AccessControlsService struct {
|
type AccessControlsService struct {
|
||||||
labelProvider LabelProvider
|
labelProvider LabelProvider
|
||||||
static map[string]config.App
|
static map[string]model.App
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewAccessControlsService(labelProvider LabelProvider, static map[string]config.App) *AccessControlsService {
|
func NewAccessControlsService(labelProvider LabelProvider, static map[string]model.App) *AccessControlsService {
|
||||||
return &AccessControlsService{
|
return &AccessControlsService{
|
||||||
labelProvider: labelProvider,
|
labelProvider: labelProvider,
|
||||||
static: static,
|
static: static,
|
||||||
@@ -28,26 +27,29 @@ func (acls *AccessControlsService) Init() error {
|
|||||||
return nil // No initialization needed
|
return nil // No initialization needed
|
||||||
}
|
}
|
||||||
|
|
||||||
func (acls *AccessControlsService) lookupStaticACLs(domain string) (config.App, error) {
|
func (acls *AccessControlsService) lookupStaticACLs(domain string) *model.App {
|
||||||
|
var appAcls *model.App
|
||||||
for app, config := range acls.static {
|
for app, config := range acls.static {
|
||||||
if config.Config.Domain == domain {
|
if config.Config.Domain == domain {
|
||||||
tlog.App.Debug().Str("name", app).Msg("Found matching container by domain")
|
tlog.App.Debug().Str("name", app).Msg("Found matching container by domain")
|
||||||
return config, nil
|
appAcls = &config
|
||||||
|
break // If we find a match by domain, we can stop searching
|
||||||
}
|
}
|
||||||
|
|
||||||
if strings.SplitN(domain, ".", 2)[0] == app {
|
if strings.SplitN(domain, ".", 2)[0] == app {
|
||||||
tlog.App.Debug().Str("name", app).Msg("Found matching container by app name")
|
tlog.App.Debug().Str("name", app).Msg("Found matching container by app name")
|
||||||
return config, nil
|
appAcls = &config
|
||||||
|
break // If we find a match by app name, we can stop searching
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return config.App{}, errors.New("no results")
|
return appAcls
|
||||||
}
|
}
|
||||||
|
|
||||||
func (acls *AccessControlsService) GetAccessControls(domain string) (config.App, error) {
|
func (acls *AccessControlsService) GetAccessControls(domain string) (*model.App, error) {
|
||||||
// First check in the static config
|
// First check in the static config
|
||||||
app, err := acls.lookupStaticACLs(domain)
|
app := acls.lookupStaticACLs(domain)
|
||||||
|
|
||||||
if err == nil {
|
if app != nil {
|
||||||
tlog.App.Debug().Msg("Using ACls from static configuration")
|
tlog.App.Debug().Msg("Using ACls from static configuration")
|
||||||
return app, nil
|
return app, nil
|
||||||
}
|
}
|
||||||
|
|||||||
+161
-156
@@ -5,12 +5,13 @@ import (
|
|||||||
"database/sql"
|
"database/sql"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/http"
|
||||||
"regexp"
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/config"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||||
@@ -29,6 +30,10 @@ const MaxOAuthPendingSessions = 256
|
|||||||
const OAuthCleanupCount = 16
|
const OAuthCleanupCount = 16
|
||||||
const MaxLoginAttemptRecords = 256
|
const MaxLoginAttemptRecords = 256
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrUserNotFound = errors.New("user not found")
|
||||||
|
)
|
||||||
|
|
||||||
// slightly modified version of the AuthorizeRequest from the OIDC service to basically accept all
|
// slightly modified version of the AuthorizeRequest from the OIDC service to basically accept all
|
||||||
// parameters and pass them to the authorize page if needed
|
// parameters and pass them to the authorize page if needed
|
||||||
type OAuthURLParams struct {
|
type OAuthURLParams struct {
|
||||||
@@ -68,7 +73,7 @@ type Lockdown struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type AuthServiceConfig struct {
|
type AuthServiceConfig struct {
|
||||||
Users []config.User
|
LocalUsers *[]model.LocalUser
|
||||||
OauthWhitelist []string
|
OauthWhitelist []string
|
||||||
SessionExpiry int
|
SessionExpiry int
|
||||||
SessionMaxLifetime int
|
SessionMaxLifetime int
|
||||||
@@ -77,7 +82,7 @@ type AuthServiceConfig struct {
|
|||||||
LoginTimeout int
|
LoginTimeout int
|
||||||
LoginMaxRetries int
|
LoginMaxRetries int
|
||||||
SessionCookieName string
|
SessionCookieName string
|
||||||
IP config.IPConfig
|
IP model.IPConfig
|
||||||
LDAPGroupsCacheTTL int
|
LDAPGroupsCacheTTL int
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -106,7 +111,7 @@ func NewAuthService(config AuthServiceConfig, ldap *LdapService, queries *reposi
|
|||||||
ldap: ldap,
|
ldap: ldap,
|
||||||
queries: queries,
|
queries: queries,
|
||||||
oauthBroker: oauthBroker,
|
oauthBroker: oauthBroker,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) Init() error {
|
func (auth *AuthService) Init() error {
|
||||||
@@ -114,79 +119,73 @@ func (auth *AuthService) Init() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) SearchUser(username string) config.UserSearch {
|
func (auth *AuthService) SearchUser(username string) (*model.UserSearch, error) {
|
||||||
if auth.GetLocalUser(username).Username != "" {
|
if auth.GetLocalUser(username) != nil {
|
||||||
return config.UserSearch{
|
return &model.UserSearch{
|
||||||
Username: username,
|
Username: username,
|
||||||
Type: "local",
|
Type: model.UserLocal,
|
||||||
}
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if auth.ldap.IsConfigured() {
|
if auth.ldap.IsConfigured() {
|
||||||
userDN, err := auth.ldap.GetUserDN(username)
|
userDN, err := auth.ldap.GetUserDN(username)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tlog.App.Warn().Err(err).Str("username", username).Msg("Failed to search for user in LDAP")
|
return nil, fmt.Errorf("failed to get ldap user: %w", err)
|
||||||
return config.UserSearch{
|
|
||||||
Type: "unknown",
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return config.UserSearch{
|
return &model.UserSearch{
|
||||||
Username: userDN,
|
Username: userDN,
|
||||||
Type: "ldap",
|
Type: model.UserLDAP,
|
||||||
}
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return config.UserSearch{
|
return nil, ErrUserNotFound
|
||||||
Type: "unknown",
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) VerifyUser(search config.UserSearch, password string) bool {
|
func (auth *AuthService) CheckUserPassword(search model.UserSearch, password string) error {
|
||||||
switch search.Type {
|
switch search.Type {
|
||||||
case "local":
|
case model.UserLocal:
|
||||||
user := auth.GetLocalUser(search.Username)
|
user := auth.GetLocalUser(search.Username)
|
||||||
return auth.CheckPassword(user, password)
|
if user == nil {
|
||||||
case "ldap":
|
return ErrUserNotFound
|
||||||
|
}
|
||||||
|
return bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password))
|
||||||
|
case model.UserLDAP:
|
||||||
if auth.ldap.IsConfigured() {
|
if auth.ldap.IsConfigured() {
|
||||||
err := auth.ldap.Bind(search.Username, password)
|
err := auth.ldap.Bind(search.Username, password)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tlog.App.Warn().Err(err).Str("username", search.Username).Msg("Failed to bind to LDAP")
|
return fmt.Errorf("failed to bind to ldap user: %w", err)
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
|
|
||||||
err = auth.ldap.BindService(true)
|
err = auth.ldap.BindService(true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tlog.App.Error().Err(err).Msg("Failed to rebind with service account after user authentication")
|
return fmt.Errorf("failed to bind to ldap service account: %w", err)
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return true
|
return nil
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
tlog.App.Debug().Str("type", search.Type).Msg("Unknown user type for authentication")
|
return errors.New("unknown user search type")
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
|
return errors.New("user authentication failed")
|
||||||
tlog.App.Warn().Str("username", search.Username).Msg("User authentication failed")
|
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) GetLocalUser(username string) config.User {
|
func (auth *AuthService) GetLocalUser(username string) *model.LocalUser {
|
||||||
for _, user := range auth.config.Users {
|
if auth.config.LocalUsers == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
for _, user := range *auth.config.LocalUsers {
|
||||||
if user.Username == username {
|
if user.Username == username {
|
||||||
return user
|
return &user
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return nil
|
||||||
tlog.App.Warn().Str("username", username).Msg("Local user not found")
|
|
||||||
return config.User{}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) GetLdapUser(userDN string) (config.LdapUser, error) {
|
func (auth *AuthService) GetLDAPUser(userDN string) (*model.LDAPUser, error) {
|
||||||
if !auth.ldap.IsConfigured() {
|
if !auth.ldap.IsConfigured() {
|
||||||
return config.LdapUser{}, errors.New("LDAP service not initialized")
|
return nil, errors.New("ldap service not configured")
|
||||||
}
|
}
|
||||||
|
|
||||||
auth.ldapGroupsMutex.RLock()
|
auth.ldapGroupsMutex.RLock()
|
||||||
@@ -194,7 +193,7 @@ func (auth *AuthService) GetLdapUser(userDN string) (config.LdapUser, error) {
|
|||||||
auth.ldapGroupsMutex.RUnlock()
|
auth.ldapGroupsMutex.RUnlock()
|
||||||
|
|
||||||
if exists && time.Now().Before(entry.Expires) {
|
if exists && time.Now().Before(entry.Expires) {
|
||||||
return config.LdapUser{
|
return &model.LDAPUser{
|
||||||
DN: userDN,
|
DN: userDN,
|
||||||
Groups: entry.Groups,
|
Groups: entry.Groups,
|
||||||
}, nil
|
}, nil
|
||||||
@@ -203,7 +202,7 @@ func (auth *AuthService) GetLdapUser(userDN string) (config.LdapUser, error) {
|
|||||||
groups, err := auth.ldap.GetUserGroups(userDN)
|
groups, err := auth.ldap.GetUserGroups(userDN)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return config.LdapUser{}, err
|
return nil, fmt.Errorf("failed to get ldap groups: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
auth.ldapGroupsMutex.Lock()
|
auth.ldapGroupsMutex.Lock()
|
||||||
@@ -213,16 +212,12 @@ func (auth *AuthService) GetLdapUser(userDN string) (config.LdapUser, error) {
|
|||||||
}
|
}
|
||||||
auth.ldapGroupsMutex.Unlock()
|
auth.ldapGroupsMutex.Unlock()
|
||||||
|
|
||||||
return config.LdapUser{
|
return &model.LDAPUser{
|
||||||
DN: userDN,
|
DN: userDN,
|
||||||
Groups: groups,
|
Groups: groups,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) CheckPassword(user config.User, password string) bool {
|
|
||||||
return bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)) == nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (auth *AuthService) IsAccountLocked(identifier string) (bool, int) {
|
func (auth *AuthService) IsAccountLocked(identifier string) (bool, int) {
|
||||||
auth.loginMutex.RLock()
|
auth.loginMutex.RLock()
|
||||||
defer auth.loginMutex.RUnlock()
|
defer auth.loginMutex.RUnlock()
|
||||||
@@ -291,11 +286,11 @@ func (auth *AuthService) IsEmailWhitelisted(email string) bool {
|
|||||||
return utils.CheckFilter(strings.Join(auth.config.OauthWhitelist, ","), email)
|
return utils.CheckFilter(strings.Join(auth.config.OauthWhitelist, ","), email)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) CreateSessionCookie(c *gin.Context, data *repository.Session) error {
|
func (auth *AuthService) CreateSession(ctx context.Context, data repository.Session) (*http.Cookie, error) {
|
||||||
uuid, err := uuid.NewRandom()
|
uuid, err := uuid.NewRandom()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, fmt.Errorf("failed to generate session uuid: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var expiry int
|
var expiry int
|
||||||
@@ -306,6 +301,8 @@ func (auth *AuthService) CreateSessionCookie(c *gin.Context, data *repository.Se
|
|||||||
expiry = auth.config.SessionExpiry
|
expiry = auth.config.SessionExpiry
|
||||||
}
|
}
|
||||||
|
|
||||||
|
expiresAt := time.Now().Add(time.Duration(expiry) * time.Second)
|
||||||
|
|
||||||
session := repository.CreateSessionParams{
|
session := repository.CreateSessionParams{
|
||||||
UUID: uuid.String(),
|
UUID: uuid.String(),
|
||||||
Username: data.Username,
|
Username: data.Username,
|
||||||
@@ -314,34 +311,36 @@ func (auth *AuthService) CreateSessionCookie(c *gin.Context, data *repository.Se
|
|||||||
Provider: data.Provider,
|
Provider: data.Provider,
|
||||||
TotpPending: data.TotpPending,
|
TotpPending: data.TotpPending,
|
||||||
OAuthGroups: data.OAuthGroups,
|
OAuthGroups: data.OAuthGroups,
|
||||||
Expiry: time.Now().Add(time.Duration(expiry) * time.Second).Unix(),
|
Expiry: expiresAt.Unix(),
|
||||||
CreatedAt: time.Now().Unix(),
|
CreatedAt: time.Now().Unix(),
|
||||||
OAuthName: data.OAuthName,
|
OAuthName: data.OAuthName,
|
||||||
OAuthSub: data.OAuthSub,
|
OAuthSub: data.OAuthSub,
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = auth.queries.CreateSession(c, session)
|
_, err = auth.queries.CreateSession(ctx, session)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, fmt.Errorf("failed to create session entry: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
c.SetCookie(auth.config.SessionCookieName, session.UUID, expiry, "/", fmt.Sprintf(".%s", auth.config.CookieDomain), auth.config.SecureCookie, true)
|
return &http.Cookie{
|
||||||
|
Name: auth.config.SessionCookieName,
|
||||||
return nil
|
Value: session.UUID,
|
||||||
|
Path: "/",
|
||||||
|
Domain: fmt.Sprintf(".%s", auth.config.CookieDomain),
|
||||||
|
Expires: expiresAt,
|
||||||
|
MaxAge: int(time.Until(expiresAt).Seconds()),
|
||||||
|
Secure: auth.config.SecureCookie,
|
||||||
|
HttpOnly: true,
|
||||||
|
SameSite: http.SameSiteLaxMode,
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) RefreshSessionCookie(c *gin.Context) error {
|
func (auth *AuthService) RefreshSession(ctx context.Context, uuid string) (*http.Cookie, error) {
|
||||||
cookie, err := c.Cookie(auth.config.SessionCookieName)
|
session, err := auth.queries.GetSession(ctx, uuid)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, fmt.Errorf("failed to retrieve session: %w", err)
|
||||||
}
|
|
||||||
|
|
||||||
session, err := auth.queries.GetSession(c, cookie)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
currentTime := time.Now().Unix()
|
currentTime := time.Now().Unix()
|
||||||
@@ -355,12 +354,12 @@ func (auth *AuthService) RefreshSessionCookie(c *gin.Context) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if session.Expiry-currentTime > refreshThreshold {
|
if session.Expiry-currentTime > refreshThreshold {
|
||||||
return nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
newExpiry := session.Expiry + refreshThreshold
|
newExpiry := session.Expiry + refreshThreshold
|
||||||
|
|
||||||
_, err = auth.queries.UpdateSession(c, repository.UpdateSessionParams{
|
_, err = auth.queries.UpdateSession(ctx, repository.UpdateSessionParams{
|
||||||
Username: session.Username,
|
Username: session.Username,
|
||||||
Email: session.Email,
|
Email: session.Email,
|
||||||
Name: session.Name,
|
Name: session.Name,
|
||||||
@@ -374,122 +373,123 @@ func (auth *AuthService) RefreshSessionCookie(c *gin.Context) error {
|
|||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, fmt.Errorf("failed to update session expiry: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
c.SetCookie(auth.config.SessionCookieName, cookie, int(newExpiry-currentTime), "/", fmt.Sprintf(".%s", auth.config.CookieDomain), auth.config.SecureCookie, true)
|
return &http.Cookie{
|
||||||
tlog.App.Trace().Str("username", session.Username).Msg("Session cookie refreshed")
|
Name: auth.config.SessionCookieName,
|
||||||
|
Value: session.UUID,
|
||||||
|
Path: "/",
|
||||||
|
Domain: fmt.Sprintf(".%s", auth.config.CookieDomain),
|
||||||
|
Expires: time.Now().Add(time.Duration(newExpiry-currentTime) * time.Second),
|
||||||
|
MaxAge: int(newExpiry - currentTime),
|
||||||
|
Secure: auth.config.SecureCookie,
|
||||||
|
HttpOnly: true,
|
||||||
|
SameSite: http.SameSiteLaxMode,
|
||||||
|
}, nil
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) DeleteSessionCookie(c *gin.Context) error {
|
func (auth *AuthService) DeleteSession(ctx context.Context, uuid string) (*http.Cookie, error) {
|
||||||
cookie, err := c.Cookie(auth.config.SessionCookieName)
|
err := auth.queries.DeleteSession(ctx, uuid)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
tlog.App.Warn().Err(err).Msg("Failed to delete session from database, proceeding to clear cookie anyway")
|
||||||
}
|
}
|
||||||
|
|
||||||
err = auth.queries.DeleteSession(c, cookie)
|
return &http.Cookie{
|
||||||
|
Name: auth.config.SessionCookieName,
|
||||||
if err != nil {
|
Value: "",
|
||||||
return err
|
Path: "/",
|
||||||
}
|
Domain: fmt.Sprintf(".%s", auth.config.CookieDomain),
|
||||||
|
Expires: time.Now(),
|
||||||
c.SetCookie(auth.config.SessionCookieName, "", -1, "/", fmt.Sprintf(".%s", auth.config.CookieDomain), auth.config.SecureCookie, true)
|
MaxAge: -1,
|
||||||
|
Secure: auth.config.SecureCookie,
|
||||||
return nil
|
HttpOnly: true,
|
||||||
|
SameSite: http.SameSiteLaxMode,
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) GetSessionCookie(c *gin.Context) (repository.Session, error) {
|
func (auth *AuthService) GetSession(ctx context.Context, uuid string) (*repository.Session, error) {
|
||||||
cookie, err := c.Cookie(auth.config.SessionCookieName)
|
session, err := auth.queries.GetSession(ctx, uuid)
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return repository.Session{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
session, err := auth.queries.GetSession(c, cookie)
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, sql.ErrNoRows) {
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
return repository.Session{}, fmt.Errorf("session not found")
|
return nil, errors.New("session not found")
|
||||||
}
|
}
|
||||||
return repository.Session{}, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
currentTime := time.Now().Unix()
|
currentTime := time.Now().Unix()
|
||||||
|
|
||||||
if auth.config.SessionMaxLifetime != 0 && session.CreatedAt != 0 {
|
if auth.config.SessionMaxLifetime != 0 && session.CreatedAt != 0 {
|
||||||
if currentTime-session.CreatedAt > int64(auth.config.SessionMaxLifetime) {
|
if currentTime-session.CreatedAt > int64(auth.config.SessionMaxLifetime) {
|
||||||
err = auth.queries.DeleteSession(c, cookie)
|
err = auth.queries.DeleteSession(ctx, uuid)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tlog.App.Error().Err(err).Msg("Failed to delete session exceeding max lifetime")
|
return nil, fmt.Errorf("failed to delete expired session: %w", err)
|
||||||
}
|
}
|
||||||
return repository.Session{}, fmt.Errorf("session expired due to max lifetime exceeded")
|
return nil, fmt.Errorf("session max lifetime exceeded")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if currentTime > session.Expiry {
|
if currentTime > session.Expiry {
|
||||||
err = auth.queries.DeleteSession(c, cookie)
|
err = auth.queries.DeleteSession(ctx, uuid)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tlog.App.Error().Err(err).Msg("Failed to delete expired session")
|
return nil, fmt.Errorf("failed to delete expired session: %w", err)
|
||||||
}
|
}
|
||||||
return repository.Session{}, fmt.Errorf("session expired")
|
return nil, fmt.Errorf("session expired")
|
||||||
}
|
}
|
||||||
|
|
||||||
return repository.Session{
|
return &session, nil
|
||||||
UUID: session.UUID,
|
|
||||||
Username: session.Username,
|
|
||||||
Email: session.Email,
|
|
||||||
Name: session.Name,
|
|
||||||
Provider: session.Provider,
|
|
||||||
TotpPending: session.TotpPending,
|
|
||||||
OAuthGroups: session.OAuthGroups,
|
|
||||||
OAuthName: session.OAuthName,
|
|
||||||
OAuthSub: session.OAuthSub,
|
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) LocalAuthConfigured() bool {
|
func (auth *AuthService) LocalAuthConfigured() bool {
|
||||||
return len(auth.config.Users) > 0
|
return auth.config.LocalUsers != nil && len(*auth.config.LocalUsers) > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) LdapAuthConfigured() bool {
|
func (auth *AuthService) LDAPAuthConfigured() bool {
|
||||||
return auth.ldap.IsConfigured()
|
return auth.ldap.IsConfigured()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) IsUserAllowed(c *gin.Context, context config.UserContext, acls config.App) bool {
|
func (auth *AuthService) IsUserAllowed(c *gin.Context, context model.UserContext, acls *model.App) bool {
|
||||||
if context.OAuth {
|
if acls == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
if context.Provider == model.ProviderOAuth {
|
||||||
tlog.App.Debug().Msg("Checking OAuth whitelist")
|
tlog.App.Debug().Msg("Checking OAuth whitelist")
|
||||||
return utils.CheckFilter(acls.OAuth.Whitelist, context.Email)
|
return utils.CheckFilter(acls.OAuth.Whitelist, context.OAuth.Email)
|
||||||
}
|
}
|
||||||
|
|
||||||
if acls.Users.Block != "" {
|
if acls.Users.Block != "" {
|
||||||
tlog.App.Debug().Msg("Checking blocked users")
|
tlog.App.Debug().Msg("Checking blocked users")
|
||||||
if utils.CheckFilter(acls.Users.Block, context.Username) {
|
if utils.CheckFilter(acls.Users.Block, context.GetUsername()) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
tlog.App.Debug().Msg("Checking users")
|
tlog.App.Debug().Msg("Checking users")
|
||||||
return utils.CheckFilter(acls.Users.Allow, context.Username)
|
return utils.CheckFilter(acls.Users.Allow, context.GetUsername())
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) IsInOAuthGroup(c *gin.Context, context config.UserContext, requiredGroups string) bool {
|
func (auth *AuthService) IsInOAuthGroup(c *gin.Context, context model.UserContext, acls *model.App) bool {
|
||||||
if requiredGroups == "" {
|
if acls == nil {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
for id := range config.OverrideProviders {
|
if !context.IsOAuth() {
|
||||||
if context.Provider == id {
|
tlog.App.Debug().Msg("User is not an OAuth user, skipping OAuth group check")
|
||||||
tlog.App.Info().Str("provider", id).Msg("OAuth groups not supported for this provider")
|
return false
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for userGroup := range strings.SplitSeq(context.OAuthGroups, ",") {
|
if _, ok := model.OverrideProviders[context.OAuth.ID]; ok {
|
||||||
if utils.CheckFilter(requiredGroups, strings.TrimSpace(userGroup)) {
|
tlog.App.Debug().Msg("Provider override for OAuth groups enabled, skipping group check")
|
||||||
tlog.App.Trace().Str("group", userGroup).Str("required", requiredGroups).Msg("User group matched")
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, userGroup := range context.OAuth.Groups {
|
||||||
|
if utils.CheckFilter(acls.OAuth.Groups, strings.TrimSpace(userGroup)) {
|
||||||
|
tlog.App.Trace().Str("group", userGroup).Str("required", acls.OAuth.Groups).Msg("User group matched")
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -498,14 +498,19 @@ func (auth *AuthService) IsInOAuthGroup(c *gin.Context, context config.UserConte
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) IsInLdapGroup(c *gin.Context, context config.UserContext, requiredGroups string) bool {
|
func (auth *AuthService) IsInLDAPGroup(c *gin.Context, context model.UserContext, acls *model.App) bool {
|
||||||
if requiredGroups == "" {
|
if acls == nil {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
for userGroup := range strings.SplitSeq(context.LdapGroups, ",") {
|
if !context.IsLDAP() {
|
||||||
if utils.CheckFilter(requiredGroups, strings.TrimSpace(userGroup)) {
|
tlog.App.Debug().Msg("User is not an LDAP user, skipping LDAP group check")
|
||||||
tlog.App.Trace().Str("group", userGroup).Str("required", requiredGroups).Msg("User group matched")
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, userGroup := range context.LDAP.Groups {
|
||||||
|
if utils.CheckFilter(acls.LDAP.Groups, strings.TrimSpace(userGroup)) {
|
||||||
|
tlog.App.Trace().Str("group", userGroup).Str("required", acls.LDAP.Groups).Msg("User group matched")
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -514,10 +519,14 @@ func (auth *AuthService) IsInLdapGroup(c *gin.Context, context config.UserContex
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) IsAuthEnabled(uri string, path config.AppPath) (bool, error) {
|
func (auth *AuthService) IsAuthEnabled(uri string, acls *model.App) (bool, error) {
|
||||||
|
if acls == nil {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
// Check for block list
|
// Check for block list
|
||||||
if path.Block != "" {
|
if acls.Path.Block != "" {
|
||||||
regex, err := regexp.Compile(path.Block)
|
regex, err := regexp.Compile(acls.Path.Block)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return true, err
|
return true, err
|
||||||
@@ -529,8 +538,8 @@ func (auth *AuthService) IsAuthEnabled(uri string, path config.AppPath) (bool, e
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Check for allow list
|
// Check for allow list
|
||||||
if path.Allow != "" {
|
if acls.Path.Allow != "" {
|
||||||
regex, err := regexp.Compile(path.Allow)
|
regex, err := regexp.Compile(acls.Path.Allow)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return true, err
|
return true, err
|
||||||
@@ -544,22 +553,14 @@ func (auth *AuthService) IsAuthEnabled(uri string, path config.AppPath) (bool, e
|
|||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) GetBasicAuth(c *gin.Context) *config.User {
|
func (auth *AuthService) CheckIP(ip string, acls *model.App) bool {
|
||||||
username, password, ok := c.Request.BasicAuth()
|
if acls == nil {
|
||||||
if !ok {
|
return true
|
||||||
tlog.App.Debug().Msg("No basic auth provided")
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
return &config.User{
|
|
||||||
Username: username,
|
|
||||||
Password: password,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (auth *AuthService) CheckIP(acls config.AppIP, ip string) bool {
|
|
||||||
// Merge the global and app IP filter
|
// Merge the global and app IP filter
|
||||||
blockedIps := append(auth.config.IP.Block, acls.Block...)
|
blockedIps := append(auth.config.IP.Block, acls.IP.Block...)
|
||||||
allowedIPs := append(auth.config.IP.Allow, acls.Allow...)
|
allowedIPs := append(auth.config.IP.Allow, acls.IP.Allow...)
|
||||||
|
|
||||||
for _, blocked := range blockedIps {
|
for _, blocked := range blockedIps {
|
||||||
res, err := utils.FilterIP(blocked, ip)
|
res, err := utils.FilterIP(blocked, ip)
|
||||||
@@ -594,8 +595,12 @@ func (auth *AuthService) CheckIP(acls config.AppIP, ip string) bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) IsBypassedIP(acls config.AppIP, ip string) bool {
|
func (auth *AuthService) IsBypassedIP(ip string, acls *model.App) bool {
|
||||||
for _, bypassed := range acls.Bypass {
|
if acls == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, bypassed := range acls.IP.Bypass {
|
||||||
res, err := utils.FilterIP(bypassed, ip)
|
res, err := utils.FilterIP(bypassed, ip)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tlog.App.Warn().Err(err).Str("item", bypassed).Msg("Invalid IP/CIDR in bypass list")
|
tlog.App.Warn().Err(err).Str("item", bypassed).Msg("Invalid IP/CIDR in bypass list")
|
||||||
@@ -674,21 +679,21 @@ func (auth *AuthService) GetOAuthToken(sessionId string, code string) (*oauth2.T
|
|||||||
return token, nil
|
return token, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) GetOAuthUserinfo(sessionId string) (config.Claims, error) {
|
func (auth *AuthService) GetOAuthUserinfo(sessionId string) (*model.Claims, error) {
|
||||||
session, err := auth.GetOAuthPendingSession(sessionId)
|
session, err := auth.GetOAuthPendingSession(sessionId)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return config.Claims{}, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if session.Token == nil {
|
if session.Token == nil {
|
||||||
return config.Claims{}, fmt.Errorf("oauth token not found for session: %s", sessionId)
|
return nil, fmt.Errorf("oauth token not found for session: %s", sessionId)
|
||||||
}
|
}
|
||||||
|
|
||||||
userinfo, err := (*session.Service).GetUserinfo(session.Token)
|
userinfo, err := (*session.Service).GetUserinfo(session.Token)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return config.Claims{}, fmt.Errorf("failed to get userinfo: %w", err)
|
return nil, fmt.Errorf("failed to get userinfo: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return userinfo, nil
|
return userinfo, nil
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/config"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/decoders"
|
"github.com/tinyauthapp/tinyauth/internal/utils/decoders"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||||
|
|
||||||
@@ -51,56 +51,48 @@ func (docker *DockerService) Init() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (docker *DockerService) getContainers() ([]container.Summary, error) {
|
func (docker *DockerService) getContainers() ([]container.Summary, error) {
|
||||||
containers, err := docker.client.ContainerList(docker.context, container.ListOptions{})
|
return docker.client.ContainerList(docker.context, container.ListOptions{})
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return containers, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (docker *DockerService) inspectContainer(containerId string) (container.InspectResponse, error) {
|
func (docker *DockerService) inspectContainer(containerId string) (container.InspectResponse, error) {
|
||||||
inspect, err := docker.client.ContainerInspect(docker.context, containerId)
|
return docker.client.ContainerInspect(docker.context, containerId)
|
||||||
if err != nil {
|
|
||||||
return container.InspectResponse{}, err
|
|
||||||
}
|
|
||||||
return inspect, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (docker *DockerService) GetLabels(appDomain string) (config.App, error) {
|
func (docker *DockerService) GetLabels(appDomain string) (*model.App, error) {
|
||||||
if !docker.isConnected {
|
if !docker.isConnected {
|
||||||
tlog.App.Debug().Msg("Docker not connected, returning empty labels")
|
tlog.App.Debug().Msg("Docker not connected, returning empty labels")
|
||||||
return config.App{}, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
containers, err := docker.getContainers()
|
containers, err := docker.getContainers()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return config.App{}, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, ctr := range containers {
|
for _, ctr := range containers {
|
||||||
inspect, err := docker.inspectContainer(ctr.ID)
|
inspect, err := docker.inspectContainer(ctr.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return config.App{}, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
labels, err := decoders.DecodeLabels[config.Apps](inspect.Config.Labels, "apps")
|
labels, err := decoders.DecodeLabels[model.Apps](inspect.Config.Labels, "apps")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return config.App{}, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
for appName, appLabels := range labels.Apps {
|
for appName, appLabels := range labels.Apps {
|
||||||
if appLabels.Config.Domain == appDomain {
|
if appLabels.Config.Domain == appDomain {
|
||||||
tlog.App.Debug().Str("id", inspect.ID).Str("name", inspect.Name).Msg("Found matching container by domain")
|
tlog.App.Debug().Str("id", inspect.ID).Str("name", inspect.Name).Msg("Found matching container by domain")
|
||||||
return appLabels, nil
|
return &appLabels, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if strings.SplitN(appDomain, ".", 2)[0] == appName {
|
if strings.SplitN(appDomain, ".", 2)[0] == appName {
|
||||||
tlog.App.Debug().Str("id", inspect.ID).Str("name", inspect.Name).Msg("Found matching container by app name")
|
tlog.App.Debug().Str("id", inspect.ID).Str("name", inspect.Name).Msg("Found matching container by app name")
|
||||||
return appLabels, nil
|
return &appLabels, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
tlog.App.Debug().Msg("No matching container found, returning empty labels")
|
tlog.App.Debug().Msg("No matching container found, returning empty labels")
|
||||||
return config.App{}, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/config"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/decoders"
|
"github.com/tinyauthapp/tinyauth/internal/utils/decoders"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||||
|
|
||||||
@@ -32,7 +32,7 @@ type ingressAppKey struct {
|
|||||||
type ingressApp struct {
|
type ingressApp struct {
|
||||||
domain string
|
domain string
|
||||||
appName string
|
appName string
|
||||||
app config.App
|
app model.App
|
||||||
}
|
}
|
||||||
|
|
||||||
type KubernetesService struct {
|
type KubernetesService struct {
|
||||||
@@ -89,36 +89,38 @@ func (k *KubernetesService) removeIngress(namespace, name string) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (k *KubernetesService) getByDomain(domain string) (config.App, bool) {
|
func (k *KubernetesService) getByDomain(domain string) *model.App {
|
||||||
k.mu.RLock()
|
k.mu.RLock()
|
||||||
defer k.mu.RUnlock()
|
defer k.mu.RUnlock()
|
||||||
|
|
||||||
if appKey, ok := k.domainIndex[domain]; ok {
|
if appKey, ok := k.domainIndex[domain]; ok {
|
||||||
if apps, ok := k.ingressApps[appKey.ingressKey]; ok {
|
if apps, ok := k.ingressApps[appKey.ingressKey]; ok {
|
||||||
for _, app := range apps {
|
for i := range apps {
|
||||||
|
app := &apps[i]
|
||||||
if app.domain == domain && app.appName == appKey.appName {
|
if app.domain == domain && app.appName == appKey.appName {
|
||||||
return app.app, true
|
return &app.app
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return config.App{}, false
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (k *KubernetesService) getByAppName(appName string) (config.App, bool) {
|
func (k *KubernetesService) getByAppName(appName string) *model.App {
|
||||||
k.mu.RLock()
|
k.mu.RLock()
|
||||||
defer k.mu.RUnlock()
|
defer k.mu.RUnlock()
|
||||||
|
|
||||||
if appKey, ok := k.appNameIndex[appName]; ok {
|
if appKey, ok := k.appNameIndex[appName]; ok {
|
||||||
if apps, ok := k.ingressApps[appKey.ingressKey]; ok {
|
if apps, ok := k.ingressApps[appKey.ingressKey]; ok {
|
||||||
for _, app := range apps {
|
for i := range apps {
|
||||||
|
app := &apps[i]
|
||||||
if app.appName == appName {
|
if app.appName == appName {
|
||||||
return app.app, true
|
return &app.app
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return config.App{}, false
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (k *KubernetesService) updateFromItem(item *unstructured.Unstructured) {
|
func (k *KubernetesService) updateFromItem(item *unstructured.Unstructured) {
|
||||||
@@ -129,7 +131,7 @@ func (k *KubernetesService) updateFromItem(item *unstructured.Unstructured) {
|
|||||||
k.removeIngress(namespace, name)
|
k.removeIngress(namespace, name)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
labels, err := decoders.DecodeLabels[config.Apps](annotations, "apps")
|
labels, err := decoders.DecodeLabels[model.Apps](annotations, "apps")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tlog.App.Debug().Err(err).Msg("Failed to decode labels from annotations")
|
tlog.App.Debug().Err(err).Msg("Failed to decode labels from annotations")
|
||||||
k.removeIngress(namespace, name)
|
k.removeIngress(namespace, name)
|
||||||
@@ -280,24 +282,25 @@ func (k *KubernetesService) Init() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (k *KubernetesService) GetLabels(appDomain string) (config.App, error) {
|
func (k *KubernetesService) GetLabels(appDomain string) (*model.App, error) {
|
||||||
if !k.started {
|
if !k.started {
|
||||||
tlog.App.Debug().Msg("Kubernetes not connected, returning empty labels")
|
tlog.App.Debug().Msg("Kubernetes not connected, returning empty labels")
|
||||||
return config.App{}, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// First check cache
|
// First check cache
|
||||||
if app, found := k.getByDomain(appDomain); found {
|
app := k.getByDomain(appDomain)
|
||||||
|
if app != nil {
|
||||||
tlog.App.Debug().Str("domain", appDomain).Msg("Found labels in cache by domain")
|
tlog.App.Debug().Str("domain", appDomain).Msg("Found labels in cache by domain")
|
||||||
return app, nil
|
return app, nil
|
||||||
}
|
}
|
||||||
appName := strings.SplitN(appDomain, ".", 2)[0]
|
appName := strings.SplitN(appDomain, ".", 2)[0]
|
||||||
if app, found := k.getByAppName(appName); found {
|
app = k.getByAppName(appName)
|
||||||
|
if app != nil {
|
||||||
tlog.App.Debug().Str("domain", appDomain).Str("appName", appName).Msg("Found labels in cache by app name")
|
tlog.App.Debug().Str("domain", appDomain).Str("appName", appName).Msg("Found labels in cache by app name")
|
||||||
return app, nil
|
return app, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
tlog.App.Debug().Str("domain", appDomain).Msg("Cache miss, no matching ingress found")
|
tlog.App.Debug().Str("domain", appDomain).Msg("Cache miss, no matching ingress found")
|
||||||
return config.App{}, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -3,11 +3,11 @@ package service
|
|||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/config"
|
|
||||||
"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
|
"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestKubernetesService(t *testing.T) {
|
func TestKubernetesService(t *testing.T) {
|
||||||
@@ -20,69 +20,69 @@ func TestKubernetesService(t *testing.T) {
|
|||||||
{
|
{
|
||||||
description: "Cache by domain returns app and misses unknown domain",
|
description: "Cache by domain returns app and misses unknown domain",
|
||||||
run: func(t *testing.T, svc *KubernetesService) {
|
run: func(t *testing.T, svc *KubernetesService) {
|
||||||
app := config.App{Config: config.AppConfig{Domain: "foo.example.com"}}
|
app := model.App{Config: model.AppConfig{Domain: "foo.example.com"}}
|
||||||
svc.addIngressApps("default", "my-ingress", []ingressApp{
|
svc.addIngressApps("default", "my-ingress", []ingressApp{
|
||||||
{domain: "foo.example.com", appName: "foo", app: app},
|
{domain: "foo.example.com", appName: "foo", app: app},
|
||||||
})
|
})
|
||||||
|
|
||||||
got, ok := svc.getByDomain("foo.example.com")
|
got := svc.getByDomain("foo.example.com")
|
||||||
require.True(t, ok)
|
require.NotNil(t, got)
|
||||||
assert.Equal(t, "foo.example.com", got.Config.Domain)
|
assert.Equal(t, "foo.example.com", got.Config.Domain)
|
||||||
|
|
||||||
_, ok = svc.getByDomain("notfound.example.com")
|
got = svc.getByDomain("notfound.example.com")
|
||||||
assert.False(t, ok)
|
assert.Nil(t, got)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "Cache by app name returns app and misses unknown name",
|
description: "Cache by app name returns app and misses unknown name",
|
||||||
run: func(t *testing.T, svc *KubernetesService) {
|
run: func(t *testing.T, svc *KubernetesService) {
|
||||||
app := config.App{Config: config.AppConfig{Domain: "bar.example.com"}}
|
app := model.App{Config: model.AppConfig{Domain: "bar.example.com"}}
|
||||||
svc.addIngressApps("default", "my-ingress", []ingressApp{
|
svc.addIngressApps("default", "my-ingress", []ingressApp{
|
||||||
{domain: "bar.example.com", appName: "bar", app: app},
|
{domain: "bar.example.com", appName: "bar", app: app},
|
||||||
})
|
})
|
||||||
|
|
||||||
got, ok := svc.getByAppName("bar")
|
got := svc.getByAppName("bar")
|
||||||
require.True(t, ok)
|
require.NotNil(t, got)
|
||||||
assert.Equal(t, "bar.example.com", got.Config.Domain)
|
assert.Equal(t, "bar.example.com", got.Config.Domain)
|
||||||
|
|
||||||
_, ok = svc.getByAppName("notfound")
|
got = svc.getByAppName("notfound")
|
||||||
assert.False(t, ok)
|
assert.Nil(t, got)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "RemoveIngress clears domain and app name entries",
|
description: "RemoveIngress clears domain and app name entries",
|
||||||
run: func(t *testing.T, svc *KubernetesService) {
|
run: func(t *testing.T, svc *KubernetesService) {
|
||||||
app := config.App{Config: config.AppConfig{Domain: "baz.example.com"}}
|
app := model.App{Config: model.AppConfig{Domain: "baz.example.com"}}
|
||||||
svc.addIngressApps("default", "my-ingress", []ingressApp{
|
svc.addIngressApps("default", "my-ingress", []ingressApp{
|
||||||
{domain: "baz.example.com", appName: "baz", app: app},
|
{domain: "baz.example.com", appName: "baz", app: app},
|
||||||
})
|
})
|
||||||
|
|
||||||
svc.removeIngress("default", "my-ingress")
|
svc.removeIngress("default", "my-ingress")
|
||||||
|
|
||||||
_, ok := svc.getByDomain("baz.example.com")
|
got := svc.getByDomain("baz.example.com")
|
||||||
assert.False(t, ok)
|
assert.Nil(t, got)
|
||||||
_, ok = svc.getByAppName("baz")
|
got = svc.getByAppName("baz")
|
||||||
assert.False(t, ok)
|
assert.Nil(t, got)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "AddIngressApps replaces stale entries for the same ingress",
|
description: "AddIngressApps replaces stale entries for the same ingress",
|
||||||
run: func(t *testing.T, svc *KubernetesService) {
|
run: func(t *testing.T, svc *KubernetesService) {
|
||||||
old := config.App{Config: config.AppConfig{Domain: "old.example.com"}}
|
old := model.App{Config: model.AppConfig{Domain: "old.example.com"}}
|
||||||
svc.addIngressApps("default", "my-ingress", []ingressApp{
|
svc.addIngressApps("default", "my-ingress", []ingressApp{
|
||||||
{domain: "old.example.com", appName: "old", app: old},
|
{domain: "old.example.com", appName: "old", app: old},
|
||||||
})
|
})
|
||||||
|
|
||||||
updated := config.App{Config: config.AppConfig{Domain: "new.example.com"}}
|
updated := model.App{Config: model.AppConfig{Domain: "new.example.com"}}
|
||||||
svc.addIngressApps("default", "my-ingress", []ingressApp{
|
svc.addIngressApps("default", "my-ingress", []ingressApp{
|
||||||
{domain: "new.example.com", appName: "new", app: updated},
|
{domain: "new.example.com", appName: "new", app: updated},
|
||||||
})
|
})
|
||||||
|
|
||||||
_, ok := svc.getByDomain("old.example.com")
|
got := svc.getByDomain("old.example.com")
|
||||||
assert.False(t, ok)
|
assert.Nil(t, got)
|
||||||
|
|
||||||
got, ok := svc.getByDomain("new.example.com")
|
got = svc.getByDomain("new.example.com")
|
||||||
require.True(t, ok)
|
require.NotNil(t, got)
|
||||||
assert.Equal(t, "new.example.com", got.Config.Domain)
|
assert.Equal(t, "new.example.com", got.Config.Domain)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -91,7 +91,7 @@ func TestKubernetesService(t *testing.T) {
|
|||||||
run: func(t *testing.T, svc *KubernetesService) {
|
run: func(t *testing.T, svc *KubernetesService) {
|
||||||
svc.started = true
|
svc.started = true
|
||||||
|
|
||||||
app := config.App{Config: config.AppConfig{Domain: "hit.example.com"}}
|
app := model.App{Config: model.AppConfig{Domain: "hit.example.com"}}
|
||||||
svc.addIngressApps("default", "ing", []ingressApp{
|
svc.addIngressApps("default", "ing", []ingressApp{
|
||||||
{domain: "hit.example.com", appName: "hit", app: app},
|
{domain: "hit.example.com", appName: "hit", app: app},
|
||||||
})
|
})
|
||||||
@@ -108,7 +108,7 @@ func TestKubernetesService(t *testing.T) {
|
|||||||
|
|
||||||
got, err := svc.GetLabels("notfound.example.com")
|
got, err := svc.GetLabels("notfound.example.com")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, config.App{}, got)
|
assert.Nil(t, got)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -116,7 +116,7 @@ func TestKubernetesService(t *testing.T) {
|
|||||||
run: func(t *testing.T, svc *KubernetesService) {
|
run: func(t *testing.T, svc *KubernetesService) {
|
||||||
svc.started = true
|
svc.started = true
|
||||||
|
|
||||||
app := config.App{Config: config.AppConfig{Domain: "myapp.internal.example.com"}}
|
app := model.App{Config: model.AppConfig{Domain: "myapp.internal.example.com"}}
|
||||||
svc.addIngressApps("default", "ing", []ingressApp{
|
svc.addIngressApps("default", "ing", []ingressApp{
|
||||||
{domain: "myapp.internal.example.com", appName: "myapp", app: app},
|
{domain: "myapp.internal.example.com", appName: "myapp", app: app},
|
||||||
})
|
})
|
||||||
@@ -131,7 +131,7 @@ func TestKubernetesService(t *testing.T) {
|
|||||||
run: func(t *testing.T, svc *KubernetesService) {
|
run: func(t *testing.T, svc *KubernetesService) {
|
||||||
got, err := svc.GetLabels("anything.example.com")
|
got, err := svc.GetLabels("anything.example.com")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, config.App{}, got)
|
assert.Nil(t, got)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -147,8 +147,8 @@ func TestKubernetesService(t *testing.T) {
|
|||||||
|
|
||||||
svc.updateFromItem(&item)
|
svc.updateFromItem(&item)
|
||||||
|
|
||||||
got, ok := svc.getByDomain("myapp.example.com")
|
got := svc.getByDomain("myapp.example.com")
|
||||||
require.True(t, ok)
|
require.NotNil(t, got)
|
||||||
assert.Equal(t, "myapp.example.com", got.Config.Domain)
|
assert.Equal(t, "myapp.example.com", got.Config.Domain)
|
||||||
assert.Equal(t, "alice", got.Users.Allow)
|
assert.Equal(t, "alice", got.Users.Allow)
|
||||||
},
|
},
|
||||||
@@ -156,7 +156,7 @@ func TestKubernetesService(t *testing.T) {
|
|||||||
{
|
{
|
||||||
description: "UpdateFromItem with no annotations removes existing cache entries",
|
description: "UpdateFromItem with no annotations removes existing cache entries",
|
||||||
run: func(t *testing.T, svc *KubernetesService) {
|
run: func(t *testing.T, svc *KubernetesService) {
|
||||||
app := config.App{Config: config.AppConfig{Domain: "todelete.example.com"}}
|
app := model.App{Config: model.AppConfig{Domain: "todelete.example.com"}}
|
||||||
svc.addIngressApps("default", "test-ingress", []ingressApp{
|
svc.addIngressApps("default", "test-ingress", []ingressApp{
|
||||||
{domain: "todelete.example.com", appName: "todelete", app: app},
|
{domain: "todelete.example.com", appName: "todelete", app: app},
|
||||||
})
|
})
|
||||||
@@ -167,8 +167,8 @@ func TestKubernetesService(t *testing.T) {
|
|||||||
|
|
||||||
svc.updateFromItem(&item)
|
svc.updateFromItem(&item)
|
||||||
|
|
||||||
_, ok := svc.getByDomain("todelete.example.com")
|
got := svc.getByDomain("todelete.example.com")
|
||||||
assert.False(t, ok)
|
assert.Nil(t, got)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
package service
|
package service
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/tinyauthapp/tinyauth/internal/config"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||||
|
|
||||||
"slices"
|
"slices"
|
||||||
@@ -15,20 +15,20 @@ type OAuthServiceImpl interface {
|
|||||||
NewRandom() string
|
NewRandom() string
|
||||||
GetAuthURL(state string, verifier string) string
|
GetAuthURL(state string, verifier string) string
|
||||||
GetToken(code string, verifier string) (*oauth2.Token, error)
|
GetToken(code string, verifier string) (*oauth2.Token, error)
|
||||||
GetUserinfo(token *oauth2.Token) (config.Claims, error)
|
GetUserinfo(token *oauth2.Token) (*model.Claims, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type OAuthBrokerService struct {
|
type OAuthBrokerService struct {
|
||||||
services map[string]OAuthServiceImpl
|
services map[string]OAuthServiceImpl
|
||||||
configs map[string]config.OAuthServiceConfig
|
configs map[string]model.OAuthServiceConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
var presets = map[string]func(config config.OAuthServiceConfig) *OAuthService{
|
var presets = map[string]func(config model.OAuthServiceConfig) *OAuthService{
|
||||||
"github": newGitHubOAuthService,
|
"github": newGitHubOAuthService,
|
||||||
"google": newGoogleOAuthService,
|
"google": newGoogleOAuthService,
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewOAuthBrokerService(configs map[string]config.OAuthServiceConfig) *OAuthBrokerService {
|
func NewOAuthBrokerService(configs map[string]model.OAuthServiceConfig) *OAuthBrokerService {
|
||||||
return &OAuthBrokerService{
|
return &OAuthBrokerService{
|
||||||
services: make(map[string]OAuthServiceImpl),
|
services: make(map[string]OAuthServiceImpl),
|
||||||
configs: configs,
|
configs: configs,
|
||||||
|
|||||||
@@ -8,12 +8,13 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/config"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
type GithubEmailResponse []struct {
|
type GithubEmailResponse []struct {
|
||||||
Email string `json:"email"`
|
Email string `json:"email"`
|
||||||
Primary bool `json:"primary"`
|
Primary bool `json:"primary"`
|
||||||
|
Verified bool `json:"verified"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type GithubUserInfoResponse struct {
|
type GithubUserInfoResponse struct {
|
||||||
@@ -22,32 +23,32 @@ type GithubUserInfoResponse struct {
|
|||||||
ID int `json:"id"`
|
ID int `json:"id"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func defaultExtractor(client *http.Client, url string) (config.Claims, error) {
|
func defaultExtractor(client *http.Client, url string) (*model.Claims, error) {
|
||||||
return simpleReq[config.Claims](client, url, nil)
|
return simpleReq[model.Claims](client, url, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func githubExtractor(client *http.Client, url string) (config.Claims, error) {
|
func githubExtractor(client *http.Client, url string) (*model.Claims, error) {
|
||||||
var user config.Claims
|
var user model.Claims
|
||||||
|
|
||||||
userInfo, err := simpleReq[GithubUserInfoResponse](client, "https://api.github.com/user", map[string]string{
|
userInfo, err := simpleReq[GithubUserInfoResponse](client, "https://api.github.com/user", map[string]string{
|
||||||
"accept": "application/vnd.github+json",
|
"accept": "application/vnd.github+json",
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return config.Claims{}, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
userEmails, err := simpleReq[GithubEmailResponse](client, "https://api.github.com/user/emails", map[string]string{
|
userEmails, err := simpleReq[GithubEmailResponse](client, "https://api.github.com/user/emails", map[string]string{
|
||||||
"accept": "application/vnd.github+json",
|
"accept": "application/vnd.github+json",
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return config.Claims{}, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(userEmails) == 0 {
|
if len(*userEmails) == 0 {
|
||||||
return user, errors.New("no emails found")
|
return nil, errors.New("no emails found")
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, email := range userEmails {
|
for _, email := range *userEmails {
|
||||||
if email.Primary {
|
if email.Primary {
|
||||||
user.Email = email.Email
|
user.Email = email.Email
|
||||||
break
|
break
|
||||||
@@ -56,22 +57,31 @@ func githubExtractor(client *http.Client, url string) (config.Claims, error) {
|
|||||||
|
|
||||||
// Use first available email if no primary email was found
|
// Use first available email if no primary email was found
|
||||||
if user.Email == "" {
|
if user.Email == "" {
|
||||||
user.Email = userEmails[0].Email
|
for _, email := range *userEmails {
|
||||||
|
if email.Verified {
|
||||||
|
user.Email = email.Email
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if user.Email == "" {
|
||||||
|
return nil, errors.New("no verified email found")
|
||||||
}
|
}
|
||||||
|
|
||||||
user.PreferredUsername = userInfo.Login
|
user.PreferredUsername = userInfo.Login
|
||||||
user.Name = userInfo.Name
|
user.Name = userInfo.Name
|
||||||
user.Sub = strconv.Itoa(userInfo.ID)
|
user.Sub = strconv.Itoa(userInfo.ID)
|
||||||
|
|
||||||
return user, nil
|
return &user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func simpleReq[T any](client *http.Client, url string, headers map[string]string) (T, error) {
|
func simpleReq[T any](client *http.Client, url string, headers map[string]string) (*T, error) {
|
||||||
var decodedRes T
|
var decodedRes T
|
||||||
|
|
||||||
req, err := http.NewRequest("GET", url, nil)
|
req, err := http.NewRequest("GET", url, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return decodedRes, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
for key, value := range headers {
|
for key, value := range headers {
|
||||||
@@ -80,23 +90,23 @@ func simpleReq[T any](client *http.Client, url string, headers map[string]string
|
|||||||
|
|
||||||
res, err := client.Do(req)
|
res, err := client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return decodedRes, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer res.Body.Close()
|
defer res.Body.Close()
|
||||||
|
|
||||||
if res.StatusCode < 200 || res.StatusCode >= 300 {
|
if res.StatusCode < 200 || res.StatusCode >= 300 {
|
||||||
return decodedRes, fmt.Errorf("request failed with status: %s", res.Status)
|
return nil, fmt.Errorf("request failed with status: %s", res.Status)
|
||||||
}
|
}
|
||||||
|
|
||||||
body, err := io.ReadAll(res.Body)
|
body, err := io.ReadAll(res.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return decodedRes, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = json.Unmarshal(body, &decodedRes)
|
err = json.Unmarshal(body, &decodedRes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return decodedRes, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return decodedRes, nil
|
return &decodedRes, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,11 +1,11 @@
|
|||||||
package service
|
package service
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/tinyauthapp/tinyauth/internal/config"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"golang.org/x/oauth2/endpoints"
|
"golang.org/x/oauth2/endpoints"
|
||||||
)
|
)
|
||||||
|
|
||||||
func newGoogleOAuthService(config config.OAuthServiceConfig) *OAuthService {
|
func newGoogleOAuthService(config model.OAuthServiceConfig) *OAuthService {
|
||||||
scopes := []string{"openid", "email", "profile"}
|
scopes := []string{"openid", "email", "profile"}
|
||||||
config.Scopes = scopes
|
config.Scopes = scopes
|
||||||
config.AuthURL = endpoints.Google.AuthURL
|
config.AuthURL = endpoints.Google.AuthURL
|
||||||
@@ -14,7 +14,7 @@ func newGoogleOAuthService(config config.OAuthServiceConfig) *OAuthService {
|
|||||||
return NewOAuthService(config, "google")
|
return NewOAuthService(config, "google")
|
||||||
}
|
}
|
||||||
|
|
||||||
func newGitHubOAuthService(config config.OAuthServiceConfig) *OAuthService {
|
func newGitHubOAuthService(config model.OAuthServiceConfig) *OAuthService {
|
||||||
scopes := []string{"read:user", "user:email"}
|
scopes := []string{"read:user", "user:email"}
|
||||||
config.Scopes = scopes
|
config.Scopes = scopes
|
||||||
config.AuthURL = endpoints.GitHub.AuthURL
|
config.AuthURL = endpoints.GitHub.AuthURL
|
||||||
|
|||||||
@@ -6,21 +6,21 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/config"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
)
|
)
|
||||||
|
|
||||||
type UserinfoExtractor func(client *http.Client, url string) (config.Claims, error)
|
type UserinfoExtractor func(client *http.Client, url string) (*model.Claims, error)
|
||||||
|
|
||||||
type OAuthService struct {
|
type OAuthService struct {
|
||||||
serviceCfg config.OAuthServiceConfig
|
serviceCfg model.OAuthServiceConfig
|
||||||
config *oauth2.Config
|
config *oauth2.Config
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
userinfoExtractor UserinfoExtractor
|
userinfoExtractor UserinfoExtractor
|
||||||
id string
|
id string
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewOAuthService(config config.OAuthServiceConfig, id string) *OAuthService {
|
func NewOAuthService(config model.OAuthServiceConfig, id string) *OAuthService {
|
||||||
httpClient := &http.Client{
|
httpClient := &http.Client{
|
||||||
Timeout: 30 * time.Second,
|
Timeout: 30 * time.Second,
|
||||||
Transport: &http.Transport{
|
Transport: &http.Transport{
|
||||||
@@ -78,7 +78,7 @@ func (s *OAuthService) GetToken(code string, verifier string) (*oauth2.Token, er
|
|||||||
return s.config.Exchange(s.ctx, code, oauth2.VerifierOption(verifier))
|
return s.config.Exchange(s.ctx, code, oauth2.VerifierOption(verifier))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *OAuthService) GetUserinfo(token *oauth2.Token) (config.Claims, error) {
|
func (s *OAuthService) GetUserinfo(token *oauth2.Token) (*model.Claims, error) {
|
||||||
client := oauth2.NewClient(s.ctx, oauth2.StaticTokenSource(token))
|
client := oauth2.NewClient(s.ctx, oauth2.StaticTokenSource(token))
|
||||||
return s.userinfoExtractor(client, s.serviceCfg.UserinfoURL)
|
return s.userinfoExtractor(client, s.serviceCfg.UserinfoURL)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ import (
|
|||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/go-jose/go-jose/v4"
|
"github.com/go-jose/go-jose/v4"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/config"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||||
@@ -87,7 +87,7 @@ type UserinfoResponse struct {
|
|||||||
EmailVerified bool `json:"email_verified,omitempty"`
|
EmailVerified bool `json:"email_verified,omitempty"`
|
||||||
PhoneNumber string `json:"phone_number,omitempty"`
|
PhoneNumber string `json:"phone_number,omitempty"`
|
||||||
PhoneNumberVerified *bool `json:"phone_number_verified,omitempty"`
|
PhoneNumberVerified *bool `json:"phone_number_verified,omitempty"`
|
||||||
Address *config.AddressClaim `json:"address,omitempty"`
|
Address *model.AddressClaim `json:"address,omitempty"`
|
||||||
UpdatedAt int64 `json:"updated_at"`
|
UpdatedAt int64 `json:"updated_at"`
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -112,7 +112,7 @@ type AuthorizeRequest struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type OIDCServiceConfig struct {
|
type OIDCServiceConfig struct {
|
||||||
Clients map[string]config.OIDCClientConfig
|
Clients map[string]model.OIDCClientConfig
|
||||||
PrivateKeyPath string
|
PrivateKeyPath string
|
||||||
PublicKeyPath string
|
PublicKeyPath string
|
||||||
Issuer string
|
Issuer string
|
||||||
@@ -122,7 +122,7 @@ type OIDCServiceConfig struct {
|
|||||||
type OIDCService struct {
|
type OIDCService struct {
|
||||||
config OIDCServiceConfig
|
config OIDCServiceConfig
|
||||||
queries *repository.Queries
|
queries *repository.Queries
|
||||||
clients map[string]config.OIDCClientConfig
|
clients map[string]model.OIDCClientConfig
|
||||||
privateKey *rsa.PrivateKey
|
privateKey *rsa.PrivateKey
|
||||||
publicKey crypto.PublicKey
|
publicKey crypto.PublicKey
|
||||||
issuer string
|
issuer string
|
||||||
@@ -255,7 +255,7 @@ func (service *OIDCService) Init() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// We will reorganize the client into a map with the client ID as the key
|
// We will reorganize the client into a map with the client ID as the key
|
||||||
service.clients = make(map[string]config.OIDCClientConfig)
|
service.clients = make(map[string]model.OIDCClientConfig)
|
||||||
|
|
||||||
for id, client := range service.config.Clients {
|
for id, client := range service.config.Clients {
|
||||||
client.ID = id
|
client.ID = id
|
||||||
@@ -283,7 +283,7 @@ func (service *OIDCService) GetIssuer() string {
|
|||||||
return service.issuer
|
return service.issuer
|
||||||
}
|
}
|
||||||
|
|
||||||
func (service *OIDCService) GetClient(id string) (config.OIDCClientConfig, bool) {
|
func (service *OIDCService) GetClient(id string) (model.OIDCClientConfig, bool) {
|
||||||
client, ok := service.clients[id]
|
client, ok := service.clients[id]
|
||||||
return client, ok
|
return client, ok
|
||||||
}
|
}
|
||||||
@@ -367,43 +367,45 @@ func (service *OIDCService) StoreCode(c *gin.Context, sub string, code string, r
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (service *OIDCService) StoreUserinfo(c *gin.Context, sub string, userContext config.UserContext, req AuthorizeRequest) error {
|
func (service *OIDCService) StoreUserinfo(c *gin.Context, sub string, userContext model.UserContext, req AuthorizeRequest) error {
|
||||||
addressJSON, err := json.Marshal(userContext.Attributes.Address)
|
userInfoParams := repository.CreateOidcUserInfoParams{
|
||||||
|
Sub: sub,
|
||||||
|
Name: userContext.GetName(),
|
||||||
|
Email: userContext.GetEmail(),
|
||||||
|
PreferredUsername: userContext.GetUsername(),
|
||||||
|
UpdatedAt: time.Now().Unix(),
|
||||||
|
}
|
||||||
|
|
||||||
|
if userContext.IsLocal() {
|
||||||
|
addressJSON, err := json.Marshal(userContext.Local.Attributes.Address)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
userInfoParams.GivenName = userContext.Local.Attributes.GivenName
|
||||||
userInfoParams := repository.CreateOidcUserInfoParams{
|
userInfoParams.FamilyName = userContext.Local.Attributes.FamilyName
|
||||||
Sub: sub,
|
userInfoParams.MiddleName = userContext.Local.Attributes.MiddleName
|
||||||
Name: userContext.Name,
|
userInfoParams.Nickname = userContext.Local.Attributes.Nickname
|
||||||
Email: userContext.Email,
|
userInfoParams.Profile = userContext.Local.Attributes.Profile
|
||||||
PreferredUsername: userContext.Username,
|
userInfoParams.Picture = userContext.Local.Attributes.Picture
|
||||||
UpdatedAt: time.Now().Unix(),
|
userInfoParams.Website = userContext.Local.Attributes.Website
|
||||||
GivenName: userContext.Attributes.GivenName,
|
userInfoParams.Gender = userContext.Local.Attributes.Gender
|
||||||
FamilyName: userContext.Attributes.FamilyName,
|
userInfoParams.Birthdate = userContext.Local.Attributes.Birthdate
|
||||||
MiddleName: userContext.Attributes.MiddleName,
|
userInfoParams.Zoneinfo = userContext.Local.Attributes.Zoneinfo
|
||||||
Nickname: userContext.Attributes.Nickname,
|
userInfoParams.Locale = userContext.Local.Attributes.Locale
|
||||||
Profile: userContext.Attributes.Profile,
|
userInfoParams.PhoneNumber = userContext.Local.Attributes.PhoneNumber
|
||||||
Picture: userContext.Attributes.Picture,
|
userInfoParams.Address = string(addressJSON)
|
||||||
Website: userContext.Attributes.Website,
|
|
||||||
Gender: userContext.Attributes.Gender,
|
|
||||||
Birthdate: userContext.Attributes.Birthdate,
|
|
||||||
Zoneinfo: userContext.Attributes.Zoneinfo,
|
|
||||||
Locale: userContext.Attributes.Locale,
|
|
||||||
PhoneNumber: userContext.Attributes.PhoneNumber,
|
|
||||||
Address: string(addressJSON),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Tinyauth will pass through the groups it got from an LDAP or an OIDC server
|
// Tinyauth will pass through the groups it got from an LDAP or an OIDC server
|
||||||
if userContext.Provider == "ldap" {
|
if userContext.IsLDAP() {
|
||||||
userInfoParams.Groups = userContext.LdapGroups
|
userInfoParams.Groups = strings.Join(userContext.LDAP.Groups, ",")
|
||||||
}
|
}
|
||||||
|
|
||||||
if userContext.OAuth && len(userContext.OAuthGroups) > 0 {
|
if userContext.IsOAuth() {
|
||||||
userInfoParams.Groups = userContext.OAuthGroups
|
userInfoParams.Groups = strings.Join(userContext.OAuth.Groups, ",")
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = service.queries.CreateOidcUserInfo(c, userInfoParams)
|
_, err := service.queries.CreateOidcUserInfo(c, userInfoParams)
|
||||||
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -445,7 +447,7 @@ func (service *OIDCService) GetCodeEntry(c *gin.Context, codeHash string, client
|
|||||||
return oidcCode, nil
|
return oidcCode, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (service *OIDCService) generateIDToken(client config.OIDCClientConfig, user repository.OidcUserinfo, scope string, nonce string) (string, error) {
|
func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user repository.OidcUserinfo, scope string, nonce string) (string, error) {
|
||||||
createdAt := time.Now().Unix()
|
createdAt := time.Now().Unix()
|
||||||
expiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry) * time.Second).Unix()
|
expiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry) * time.Second).Unix()
|
||||||
|
|
||||||
@@ -511,7 +513,7 @@ func (service *OIDCService) generateIDToken(client config.OIDCClientConfig, user
|
|||||||
return token, nil
|
return token, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (service *OIDCService) GenerateAccessToken(c *gin.Context, client config.OIDCClientConfig, codeEntry repository.OidcCode) (TokenResponse, error) {
|
func (service *OIDCService) GenerateAccessToken(c *gin.Context, client model.OIDCClientConfig, codeEntry repository.OidcCode) (TokenResponse, error) {
|
||||||
user, err := service.GetUserinfo(c, codeEntry.Sub)
|
user, err := service.GetUserinfo(c, codeEntry.Sub)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -585,7 +587,7 @@ func (service *OIDCService) RefreshAccessToken(c *gin.Context, refreshToken stri
|
|||||||
return TokenResponse{}, err
|
return TokenResponse{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
idToken, err := service.generateIDToken(config.OIDCClientConfig{
|
idToken, err := service.generateIDToken(model.OIDCClientConfig{
|
||||||
ClientID: entry.ClientID,
|
ClientID: entry.ClientID,
|
||||||
}, user, entry.Scope, entry.Nonce)
|
}, user, entry.Scope, entry.Nonce)
|
||||||
|
|
||||||
@@ -714,7 +716,7 @@ func (service *OIDCService) CompileUserinfo(user repository.OidcUserinfo, scope
|
|||||||
}
|
}
|
||||||
|
|
||||||
if slices.Contains(scopes, "address") {
|
if slices.Contains(scopes, "address") {
|
||||||
var addr config.AddressClaim
|
var addr model.AddressClaim
|
||||||
if err := json.Unmarshal([]byte(user.Address), &addr); err == nil {
|
if err := json.Unmarshal([]byte(user.Address), &addr); err == nil {
|
||||||
userInfo.Address = &addr
|
userInfo.Address = &addr
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,13 +7,13 @@ import (
|
|||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/config"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||||
)
|
)
|
||||||
|
|
||||||
func newTestUser() repository.OidcUserinfo {
|
func newTestUser() repository.OidcUserinfo {
|
||||||
addr := config.AddressClaim{
|
addr := model.AddressClaim{
|
||||||
Formatted: "123 Main St",
|
Formatted: "123 Main St",
|
||||||
StreetAddress: "123 Main St",
|
StreetAddress: "123 Main St",
|
||||||
Locality: "Springfield",
|
Locality: "Springfield",
|
||||||
|
|||||||
@@ -7,10 +7,8 @@ import (
|
|||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/config"
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"github.com/weppos/publicsuffix-go/publicsuffix"
|
"github.com/weppos/publicsuffix-go/publicsuffix"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -73,22 +71,6 @@ func Filter[T any](slice []T, test func(T) bool) (res []T) {
|
|||||||
return res
|
return res
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetContext(c *gin.Context) (config.UserContext, error) {
|
|
||||||
userContextValue, exists := c.Get("context")
|
|
||||||
|
|
||||||
if !exists {
|
|
||||||
return config.UserContext{}, errors.New("no user context in request")
|
|
||||||
}
|
|
||||||
|
|
||||||
userContext, ok := userContextValue.(*config.UserContext)
|
|
||||||
|
|
||||||
if !ok {
|
|
||||||
return config.UserContext{}, errors.New("invalid user context in request")
|
|
||||||
}
|
|
||||||
|
|
||||||
return *userContext, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func IsRedirectSafe(redirectURL string, domain string) bool {
|
func IsRedirectSafe(redirectURL string, domain string) bool {
|
||||||
if redirectURL == "" {
|
if redirectURL == "" {
|
||||||
return false
|
return false
|
||||||
|
|||||||
@@ -3,11 +3,8 @@ package utils_test
|
|||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/config"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"gotest.tools/v3/assert"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestGetRootDomain(t *testing.T) {
|
func TestGetRootDomain(t *testing.T) {
|
||||||
@@ -15,14 +12,14 @@ func TestGetRootDomain(t *testing.T) {
|
|||||||
domain := "http://sub.tinyauth.app"
|
domain := "http://sub.tinyauth.app"
|
||||||
expected := "tinyauth.app"
|
expected := "tinyauth.app"
|
||||||
result, err := utils.GetCookieDomain(domain)
|
result, err := utils.GetCookieDomain(domain)
|
||||||
assert.NilError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, expected, result)
|
assert.Equal(t, expected, result)
|
||||||
|
|
||||||
// Domain with multiple subdomains
|
// Domain with multiple subdomains
|
||||||
domain = "http://b.c.tinyauth.app"
|
domain = "http://b.c.tinyauth.app"
|
||||||
expected = "c.tinyauth.app"
|
expected = "c.tinyauth.app"
|
||||||
result, err = utils.GetCookieDomain(domain)
|
result, err = utils.GetCookieDomain(domain)
|
||||||
assert.NilError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, expected, result)
|
assert.Equal(t, expected, result)
|
||||||
|
|
||||||
// Invalid domain (only TLD)
|
// Invalid domain (only TLD)
|
||||||
@@ -44,14 +41,14 @@ func TestGetRootDomain(t *testing.T) {
|
|||||||
domain = "https://sub.tinyauth.app/path"
|
domain = "https://sub.tinyauth.app/path"
|
||||||
expected = "tinyauth.app"
|
expected = "tinyauth.app"
|
||||||
result, err = utils.GetCookieDomain(domain)
|
result, err = utils.GetCookieDomain(domain)
|
||||||
assert.NilError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, expected, result)
|
assert.Equal(t, expected, result)
|
||||||
|
|
||||||
// URL with port
|
// URL with port
|
||||||
domain = "http://sub.tinyauth.app:8080"
|
domain = "http://sub.tinyauth.app:8080"
|
||||||
expected = "tinyauth.app"
|
expected = "tinyauth.app"
|
||||||
result, err = utils.GetCookieDomain(domain)
|
result, err = utils.GetCookieDomain(domain)
|
||||||
assert.NilError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, expected, result)
|
assert.Equal(t, expected, result)
|
||||||
|
|
||||||
// Domain managed by ICANN
|
// Domain managed by ICANN
|
||||||
@@ -98,57 +95,35 @@ func TestFilter(t *testing.T) {
|
|||||||
testFunc := func(n int) bool { return n%2 == 0 }
|
testFunc := func(n int) bool { return n%2 == 0 }
|
||||||
expected := []int{2, 4}
|
expected := []int{2, 4}
|
||||||
result := utils.Filter(slice, testFunc)
|
result := utils.Filter(slice, testFunc)
|
||||||
assert.DeepEqual(t, expected, result)
|
assert.Equal(t, expected, result)
|
||||||
|
|
||||||
// Case with no matches
|
// Case with no matches
|
||||||
slice = []int{1, 3, 5}
|
slice = []int{1, 3, 5}
|
||||||
testFunc = func(n int) bool { return n%2 == 0 }
|
testFunc = func(n int) bool { return n%2 == 0 }
|
||||||
expected = []int{}
|
expected = []int{}
|
||||||
result = utils.Filter(slice, testFunc)
|
result = utils.Filter(slice, testFunc)
|
||||||
assert.DeepEqual(t, expected, result)
|
assert.Equal(t, expected, result)
|
||||||
|
|
||||||
// Case with all matches
|
// Case with all matches
|
||||||
slice = []int{2, 4, 6}
|
slice = []int{2, 4, 6}
|
||||||
testFunc = func(n int) bool { return n%2 == 0 }
|
testFunc = func(n int) bool { return n%2 == 0 }
|
||||||
expected = []int{2, 4, 6}
|
expected = []int{2, 4, 6}
|
||||||
result = utils.Filter(slice, testFunc)
|
result = utils.Filter(slice, testFunc)
|
||||||
assert.DeepEqual(t, expected, result)
|
assert.Equal(t, expected, result)
|
||||||
|
|
||||||
// Case with empty slice
|
// Case with empty slice
|
||||||
slice = []int{}
|
slice = []int{}
|
||||||
testFunc = func(n int) bool { return n%2 == 0 }
|
testFunc = func(n int) bool { return n%2 == 0 }
|
||||||
expected = []int{}
|
expected = []int{}
|
||||||
result = utils.Filter(slice, testFunc)
|
result = utils.Filter(slice, testFunc)
|
||||||
assert.DeepEqual(t, expected, result)
|
assert.Equal(t, expected, result)
|
||||||
|
|
||||||
// Case with different type (string)
|
// Case with different type (string)
|
||||||
sliceStr := []string{"apple", "banana", "cherry"}
|
sliceStr := []string{"apple", "banana", "cherry"}
|
||||||
testFuncStr := func(s string) bool { return len(s) > 5 }
|
testFuncStr := func(s string) bool { return len(s) > 5 }
|
||||||
expectedStr := []string{"banana", "cherry"}
|
expectedStr := []string{"banana", "cherry"}
|
||||||
resultStr := utils.Filter(sliceStr, testFuncStr)
|
resultStr := utils.Filter(sliceStr, testFuncStr)
|
||||||
assert.DeepEqual(t, expectedStr, resultStr)
|
assert.Equal(t, expectedStr, resultStr)
|
||||||
}
|
|
||||||
|
|
||||||
func TestGetContext(t *testing.T) {
|
|
||||||
// Setup
|
|
||||||
gin.SetMode(gin.TestMode)
|
|
||||||
c, _ := gin.CreateTestContext(nil)
|
|
||||||
|
|
||||||
// Normal case
|
|
||||||
c.Set("context", &config.UserContext{Username: "testuser"})
|
|
||||||
result, err := utils.GetContext(c)
|
|
||||||
assert.NilError(t, err)
|
|
||||||
assert.Equal(t, "testuser", result.Username)
|
|
||||||
|
|
||||||
// Case with no context
|
|
||||||
c.Set("context", nil)
|
|
||||||
_, err = utils.GetContext(c)
|
|
||||||
assert.Error(t, err, "invalid user context in request")
|
|
||||||
|
|
||||||
// Case with invalid context type
|
|
||||||
c.Set("context", "invalid type")
|
|
||||||
_, err = utils.GetContext(c)
|
|
||||||
assert.Error(t, err, "invalid user context in request")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestIsRedirectSafe(t *testing.T) {
|
func TestIsRedirectSafe(t *testing.T) {
|
||||||
@@ -158,50 +133,50 @@ func TestIsRedirectSafe(t *testing.T) {
|
|||||||
// Case with no subdomain
|
// Case with no subdomain
|
||||||
redirectURL := "http://example.com/welcome"
|
redirectURL := "http://example.com/welcome"
|
||||||
result := utils.IsRedirectSafe(redirectURL, domain)
|
result := utils.IsRedirectSafe(redirectURL, domain)
|
||||||
assert.Equal(t, true, result)
|
assert.True(t, result)
|
||||||
|
|
||||||
// Case with different domain
|
// Case with different domain
|
||||||
redirectURL = "http://malicious.com/phishing"
|
redirectURL = "http://malicious.com/phishing"
|
||||||
result = utils.IsRedirectSafe(redirectURL, domain)
|
result = utils.IsRedirectSafe(redirectURL, domain)
|
||||||
assert.Equal(t, false, result)
|
assert.False(t, result)
|
||||||
|
|
||||||
// Case with subdomain
|
// Case with subdomain
|
||||||
redirectURL = "http://sub.example.com/page"
|
redirectURL = "http://sub.example.com/page"
|
||||||
result = utils.IsRedirectSafe(redirectURL, domain)
|
result = utils.IsRedirectSafe(redirectURL, domain)
|
||||||
assert.Equal(t, true, result)
|
assert.True(t, result)
|
||||||
|
|
||||||
// Case with sub-subdomain
|
// Case with sub-subdomain
|
||||||
redirectURL = "http://a.b.example.com/home"
|
redirectURL = "http://a.b.example.com/home"
|
||||||
result = utils.IsRedirectSafe(redirectURL, domain)
|
result = utils.IsRedirectSafe(redirectURL, domain)
|
||||||
assert.Equal(t, true, result)
|
assert.True(t, result)
|
||||||
|
|
||||||
// Case with empty redirect URL
|
// Case with empty redirect URL
|
||||||
redirectURL = ""
|
redirectURL = ""
|
||||||
result = utils.IsRedirectSafe(redirectURL, domain)
|
result = utils.IsRedirectSafe(redirectURL, domain)
|
||||||
assert.Equal(t, false, result)
|
assert.False(t, result)
|
||||||
|
|
||||||
// Case with invalid URL
|
// Case with invalid URL
|
||||||
redirectURL = "http://[::1]:namedport"
|
redirectURL = "http://[::1]:namedport"
|
||||||
result = utils.IsRedirectSafe(redirectURL, domain)
|
result = utils.IsRedirectSafe(redirectURL, domain)
|
||||||
assert.Equal(t, false, result)
|
assert.False(t, result)
|
||||||
|
|
||||||
// Case with URL having port
|
// Case with URL having port
|
||||||
redirectURL = "http://sub.example.com:8080/page"
|
redirectURL = "http://sub.example.com:8080/page"
|
||||||
result = utils.IsRedirectSafe(redirectURL, domain)
|
result = utils.IsRedirectSafe(redirectURL, domain)
|
||||||
assert.Equal(t, true, result)
|
assert.True(t, result)
|
||||||
|
|
||||||
// Case with URL having different subdomain
|
// Case with URL having different subdomain
|
||||||
redirectURL = "http://another.example.com/page"
|
redirectURL = "http://another.example.com/page"
|
||||||
result = utils.IsRedirectSafe(redirectURL, domain)
|
result = utils.IsRedirectSafe(redirectURL, domain)
|
||||||
assert.Equal(t, true, result)
|
assert.True(t, result)
|
||||||
|
|
||||||
// Case with URL having different TLD
|
// Case with URL having different TLD
|
||||||
redirectURL = "http://example.org/page"
|
redirectURL = "http://example.org/page"
|
||||||
result = utils.IsRedirectSafe(redirectURL, domain)
|
result = utils.IsRedirectSafe(redirectURL, domain)
|
||||||
assert.Equal(t, false, result)
|
assert.False(t, result)
|
||||||
|
|
||||||
// Case with malicious domain
|
// Case with malicious domain
|
||||||
redirectURL = "https://malicious-example.com/yoyo"
|
redirectURL = "https://malicious-example.com/yoyo"
|
||||||
result = utils.IsRedirectSafe(redirectURL, domain)
|
result = utils.IsRedirectSafe(redirectURL, domain)
|
||||||
assert.Equal(t, false, result)
|
assert.False(t, result)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,42 +3,41 @@ package decoders_test
|
|||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/config"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/decoders"
|
"github.com/tinyauthapp/tinyauth/internal/utils/decoders"
|
||||||
|
|
||||||
"gotest.tools/v3/assert"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestDecodeLabels(t *testing.T) {
|
func TestDecodeLabels(t *testing.T) {
|
||||||
// Variables
|
// Variables
|
||||||
expected := config.Apps{
|
expected := model.Apps{
|
||||||
Apps: map[string]config.App{
|
Apps: map[string]model.App{
|
||||||
"foo": {
|
"foo": {
|
||||||
Config: config.AppConfig{
|
Config: model.AppConfig{
|
||||||
Domain: "example.com",
|
Domain: "example.com",
|
||||||
},
|
},
|
||||||
Users: config.AppUsers{
|
Users: model.AppUsers{
|
||||||
Allow: "user1,user2",
|
Allow: "user1,user2",
|
||||||
Block: "user3",
|
Block: "user3",
|
||||||
},
|
},
|
||||||
OAuth: config.AppOAuth{
|
OAuth: model.AppOAuth{
|
||||||
Whitelist: "somebody@example.com",
|
Whitelist: "somebody@example.com",
|
||||||
Groups: "group3",
|
Groups: "group3",
|
||||||
},
|
},
|
||||||
IP: config.AppIP{
|
IP: model.AppIP{
|
||||||
Allow: []string{"10.71.0.1/24", "10.71.0.2"},
|
Allow: []string{"10.71.0.1/24", "10.71.0.2"},
|
||||||
Block: []string{"10.10.10.10", "10.0.0.0/24"},
|
Block: []string{"10.10.10.10", "10.0.0.0/24"},
|
||||||
Bypass: []string{"192.168.1.1"},
|
Bypass: []string{"192.168.1.1"},
|
||||||
},
|
},
|
||||||
Response: config.AppResponse{
|
Response: model.AppResponse{
|
||||||
Headers: []string{"X-Foo=Bar", "X-Baz=Qux"},
|
Headers: []string{"X-Foo=Bar", "X-Baz=Qux"},
|
||||||
BasicAuth: config.AppBasicAuth{
|
BasicAuth: model.AppBasicAuth{
|
||||||
Username: "admin",
|
Username: "admin",
|
||||||
Password: "password",
|
Password: "password",
|
||||||
PasswordFile: "/path/to/passwordfile",
|
PasswordFile: "/path/to/passwordfile",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
Path: config.AppPath{
|
Path: model.AppPath{
|
||||||
Allow: "/public",
|
Allow: "/public",
|
||||||
Block: "/private",
|
Block: "/private",
|
||||||
},
|
},
|
||||||
@@ -63,7 +62,7 @@ func TestDecodeLabels(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Test
|
// Test
|
||||||
result, err := decoders.DecodeLabels[config.Apps](test, "apps")
|
result, err := decoders.DecodeLabels[model.Apps](test, "apps")
|
||||||
assert.NilError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.DeepEqual(t, expected, result)
|
assert.Equal(t, expected, result)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,24 +4,25 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"gotest.tools/v3/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestReadFile(t *testing.T) {
|
func TestReadFile(t *testing.T) {
|
||||||
// Setup
|
// Setup
|
||||||
file, err := os.Create("/tmp/tinyauth_test_file")
|
file, err := os.Create("/tmp/tinyauth_test_file")
|
||||||
assert.NilError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
_, err = file.WriteString("file content\n")
|
_, err = file.WriteString("file content\n")
|
||||||
assert.NilError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
err = file.Close()
|
err = file.Close()
|
||||||
assert.NilError(t, err)
|
require.NoError(t, err)
|
||||||
defer os.Remove("/tmp/tinyauth_test_file")
|
defer os.Remove("/tmp/tinyauth_test_file")
|
||||||
|
|
||||||
// Normal case
|
// Normal case
|
||||||
content, err := ReadFile("/tmp/tinyauth_test_file")
|
content, err := ReadFile("/tmp/tinyauth_test_file")
|
||||||
assert.NilError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, "file content\n", content)
|
assert.Equal(t, "file content\n", content)
|
||||||
|
|
||||||
// Non-existing file
|
// Non-existing file
|
||||||
|
|||||||
@@ -3,9 +3,8 @@ package utils_test
|
|||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||||
|
|
||||||
"gotest.tools/v3/assert"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestParseHeaders(t *testing.T) {
|
func TestParseHeaders(t *testing.T) {
|
||||||
@@ -18,7 +17,7 @@ func TestParseHeaders(t *testing.T) {
|
|||||||
"X-Custom-Header": "Value",
|
"X-Custom-Header": "Value",
|
||||||
"Another-Header": "AnotherValue",
|
"Another-Header": "AnotherValue",
|
||||||
}
|
}
|
||||||
assert.DeepEqual(t, expected, utils.ParseHeaders(headers))
|
assert.Equal(t, expected, utils.ParseHeaders(headers))
|
||||||
|
|
||||||
// Case insensitivity and trimming
|
// Case insensitivity and trimming
|
||||||
headers = []string{
|
headers = []string{
|
||||||
@@ -29,7 +28,7 @@ func TestParseHeaders(t *testing.T) {
|
|||||||
"X-Custom-Header": "Value",
|
"X-Custom-Header": "Value",
|
||||||
"Another-Header": "AnotherValue",
|
"Another-Header": "AnotherValue",
|
||||||
}
|
}
|
||||||
assert.DeepEqual(t, expected, utils.ParseHeaders(headers))
|
assert.Equal(t, expected, utils.ParseHeaders(headers))
|
||||||
|
|
||||||
// Invalid headers (missing '=', empty key/value)
|
// Invalid headers (missing '=', empty key/value)
|
||||||
headers = []string{
|
headers = []string{
|
||||||
@@ -39,7 +38,7 @@ func TestParseHeaders(t *testing.T) {
|
|||||||
" = ",
|
" = ",
|
||||||
}
|
}
|
||||||
expected = map[string]string{}
|
expected = map[string]string{}
|
||||||
assert.DeepEqual(t, expected, utils.ParseHeaders(headers))
|
assert.Equal(t, expected, utils.ParseHeaders(headers))
|
||||||
|
|
||||||
// Headers with unsafe characters
|
// Headers with unsafe characters
|
||||||
headers = []string{
|
headers = []string{
|
||||||
@@ -52,7 +51,7 @@ func TestParseHeaders(t *testing.T) {
|
|||||||
"Another-Header": "AnotherValue",
|
"Another-Header": "AnotherValue",
|
||||||
"Good-Header": "GoodValue",
|
"Good-Header": "GoodValue",
|
||||||
}
|
}
|
||||||
assert.DeepEqual(t, expected, utils.ParseHeaders(headers))
|
assert.Equal(t, expected, utils.ParseHeaders(headers))
|
||||||
|
|
||||||
// Header with spaces in key (should be ignored)
|
// Header with spaces in key (should be ignored)
|
||||||
headers = []string{
|
headers = []string{
|
||||||
@@ -62,7 +61,7 @@ func TestParseHeaders(t *testing.T) {
|
|||||||
expected = map[string]string{
|
expected = map[string]string{
|
||||||
"Valid-Header": "ValidValue",
|
"Valid-Header": "ValidValue",
|
||||||
}
|
}
|
||||||
assert.DeepEqual(t, expected, utils.ParseHeaders(headers))
|
assert.Equal(t, expected, utils.ParseHeaders(headers))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSanitizeHeader(t *testing.T) {
|
func TestSanitizeHeader(t *testing.T) {
|
||||||
|
|||||||
@@ -4,21 +4,20 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/config"
|
|
||||||
|
|
||||||
"github.com/tinyauthapp/paerser/cli"
|
"github.com/tinyauthapp/paerser/cli"
|
||||||
"github.com/tinyauthapp/paerser/env"
|
"github.com/tinyauthapp/paerser/env"
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
type EnvLoader struct{}
|
type EnvLoader struct{}
|
||||||
|
|
||||||
func (e *EnvLoader) Load(_ []string, cmd *cli.Command) (bool, error) {
|
func (e *EnvLoader) Load(_ []string, cmd *cli.Command) (bool, error) {
|
||||||
vars := env.FindPrefixedEnvVars(os.Environ(), config.DefaultNamePrefix, cmd.Configuration)
|
vars := env.FindPrefixedEnvVars(os.Environ(), model.DefaultNamePrefix, cmd.Configuration)
|
||||||
if len(vars) == 0 {
|
if len(vars) == 0 {
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := env.Decode(vars, config.DefaultNamePrefix, cmd.Configuration); err != nil {
|
if err := env.Decode(vars, model.DefaultNamePrefix, cmd.Configuration); err != nil {
|
||||||
return false, fmt.Errorf("failed to decode configuration from environment variables: %w", err)
|
return false, fmt.Errorf("failed to decode configuration from environment variables: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -41,7 +41,7 @@ func ParseSecretFile(contents string) string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetBasicAuth(username string, password string) string {
|
func EncodeBasicAuth(username string, password string) string {
|
||||||
auth := username + ":" + password
|
auth := username + ":" + password
|
||||||
return base64.StdEncoding.EncodeToString([]byte(auth))
|
return base64.StdEncoding.EncodeToString([]byte(auth))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,21 +4,21 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||||
|
|
||||||
"gotest.tools/v3/assert"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestGetSecret(t *testing.T) {
|
func TestGetSecret(t *testing.T) {
|
||||||
// Setup
|
// Setup
|
||||||
file, err := os.Create("/tmp/tinyauth_test_secret")
|
file, err := os.Create("/tmp/tinyauth_test_secret")
|
||||||
assert.NilError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
_, err = file.WriteString(" secret \n")
|
_, err = file.WriteString(" secret \n")
|
||||||
assert.NilError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
err = file.Close()
|
err = file.Close()
|
||||||
assert.NilError(t, err)
|
require.NoError(t, err)
|
||||||
defer os.Remove("/tmp/tinyauth_test_secret")
|
defer os.Remove("/tmp/tinyauth_test_secret")
|
||||||
|
|
||||||
// Get from config
|
// Get from config
|
||||||
@@ -55,50 +55,50 @@ func TestParseSecretFile(t *testing.T) {
|
|||||||
assert.Equal(t, "", utils.ParseSecretFile(content))
|
assert.Equal(t, "", utils.ParseSecretFile(content))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGetBasicAuth(t *testing.T) {
|
func TestEncodeBasicAuth(t *testing.T) {
|
||||||
// Normal case
|
// Normal case
|
||||||
username := "user"
|
username := "user"
|
||||||
password := "pass"
|
password := "pass"
|
||||||
expected := "dXNlcjpwYXNz" // base64 of "user:pass"
|
expected := "dXNlcjpwYXNz" // base64 of "user:pass"
|
||||||
assert.Equal(t, expected, utils.GetBasicAuth(username, password))
|
assert.Equal(t, expected, utils.EncodeBasicAuth(username, password))
|
||||||
|
|
||||||
// Empty username
|
// Empty username
|
||||||
username = ""
|
username = ""
|
||||||
password = "pass"
|
password = "pass"
|
||||||
expected = "OnBhc3M=" // base64 of ":pass"
|
expected = "OnBhc3M=" // base64 of ":pass"
|
||||||
assert.Equal(t, expected, utils.GetBasicAuth(username, password))
|
assert.Equal(t, expected, utils.EncodeBasicAuth(username, password))
|
||||||
|
|
||||||
// Empty password
|
// Empty password
|
||||||
username = "user"
|
username = "user"
|
||||||
password = ""
|
password = ""
|
||||||
expected = "dXNlcjo=" // base64 of "user:"
|
expected = "dXNlcjo=" // base64 of "user:"
|
||||||
assert.Equal(t, expected, utils.GetBasicAuth(username, password))
|
assert.Equal(t, expected, utils.EncodeBasicAuth(username, password))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestFilterIP(t *testing.T) {
|
func TestFilterIP(t *testing.T) {
|
||||||
// Exact match IPv4
|
// Exact match IPv4
|
||||||
ok, err := utils.FilterIP("10.10.0.1", "10.10.0.1")
|
ok, err := utils.FilterIP("10.10.0.1", "10.10.0.1")
|
||||||
assert.NilError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, true, ok)
|
assert.Equal(t, true, ok)
|
||||||
|
|
||||||
// Non-match IPv4
|
// Non-match IPv4
|
||||||
ok, err = utils.FilterIP("10.10.0.1", "10.10.0.2")
|
ok, err = utils.FilterIP("10.10.0.1", "10.10.0.2")
|
||||||
assert.NilError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, false, ok)
|
assert.Equal(t, false, ok)
|
||||||
|
|
||||||
// CIDR match IPv4
|
// CIDR match IPv4
|
||||||
ok, err = utils.FilterIP("10.10.0.0/24", "10.10.0.2")
|
ok, err = utils.FilterIP("10.10.0.0/24", "10.10.0.2")
|
||||||
assert.NilError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, true, ok)
|
assert.Equal(t, true, ok)
|
||||||
|
|
||||||
// CIDR match IPv4 with '-' instead of '/'
|
// CIDR match IPv4 with '-' instead of '/'
|
||||||
ok, err = utils.FilterIP("10.10.10.0-24", "10.10.10.5")
|
ok, err = utils.FilterIP("10.10.10.0-24", "10.10.10.5")
|
||||||
assert.NilError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, true, ok)
|
assert.Equal(t, true, ok)
|
||||||
|
|
||||||
// CIDR non-match IPv4
|
// CIDR non-match IPv4
|
||||||
ok, err = utils.FilterIP("10.10.0.0/24", "10.5.0.1")
|
ok, err = utils.FilterIP("10.10.0.0/24", "10.5.0.1")
|
||||||
assert.NilError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, false, ok)
|
assert.Equal(t, false, ok)
|
||||||
|
|
||||||
// Invalid CIDR
|
// Invalid CIDR
|
||||||
@@ -145,5 +145,5 @@ func TestGenerateUUID(t *testing.T) {
|
|||||||
|
|
||||||
// Different output for different input
|
// Different output for different input
|
||||||
id3 := utils.GenerateUUID("differentstring")
|
id3 := utils.GenerateUUID("differentstring")
|
||||||
assert.Assert(t, id1 != id3)
|
assert.NotEqual(t, id2, id3)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,9 +3,8 @@ package utils_test
|
|||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||||
|
|
||||||
"gotest.tools/v3/assert"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestCapitalize(t *testing.T) {
|
func TestCapitalize(t *testing.T) {
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import (
|
|||||||
|
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/config"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Logger struct {
|
type Logger struct {
|
||||||
@@ -22,7 +22,7 @@ var (
|
|||||||
App zerolog.Logger
|
App zerolog.Logger
|
||||||
)
|
)
|
||||||
|
|
||||||
func NewLogger(cfg config.LogConfig) *Logger {
|
func NewLogger(cfg model.LogConfig) *Logger {
|
||||||
baseLogger := log.With().
|
baseLogger := log.With().
|
||||||
Timestamp().
|
Timestamp().
|
||||||
Caller().
|
Caller().
|
||||||
@@ -44,24 +44,24 @@ func NewLogger(cfg config.LogConfig) *Logger {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func NewSimpleLogger() *Logger {
|
func NewSimpleLogger() *Logger {
|
||||||
return NewLogger(config.LogConfig{
|
return NewLogger(model.LogConfig{
|
||||||
Level: "info",
|
Level: "info",
|
||||||
Json: false,
|
Json: false,
|
||||||
Streams: config.LogStreams{
|
Streams: model.LogStreams{
|
||||||
HTTP: config.LogStreamConfig{Enabled: true},
|
HTTP: model.LogStreamConfig{Enabled: true},
|
||||||
App: config.LogStreamConfig{Enabled: true},
|
App: model.LogStreamConfig{Enabled: true},
|
||||||
Audit: config.LogStreamConfig{Enabled: false},
|
Audit: model.LogStreamConfig{Enabled: false},
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewTestLogger() *Logger {
|
func NewTestLogger() *Logger {
|
||||||
return NewLogger(config.LogConfig{
|
return NewLogger(model.LogConfig{
|
||||||
Level: "trace",
|
Level: "trace",
|
||||||
Streams: config.LogStreams{
|
Streams: model.LogStreams{
|
||||||
HTTP: config.LogStreamConfig{Enabled: true},
|
HTTP: model.LogStreamConfig{Enabled: true},
|
||||||
App: config.LogStreamConfig{Enabled: true},
|
App: model.LogStreamConfig{Enabled: true},
|
||||||
Audit: config.LogStreamConfig{Enabled: true},
|
Audit: model.LogStreamConfig{Enabled: true},
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -72,7 +72,7 @@ func (l *Logger) Init() {
|
|||||||
App = l.App
|
App = l.App
|
||||||
}
|
}
|
||||||
|
|
||||||
func createLogger(component string, streamCfg config.LogStreamConfig, baseLogger zerolog.Logger) zerolog.Logger {
|
func createLogger(component string, streamCfg model.LogStreamConfig, baseLogger zerolog.Logger) zerolog.Logger {
|
||||||
if !streamCfg.Enabled {
|
if !streamCfg.Enabled {
|
||||||
return zerolog.Nop()
|
return zerolog.Nop()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,75 +5,75 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/config"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||||
|
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
"gotest.tools/v3/assert"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestNewLogger(t *testing.T) {
|
func TestNewLogger(t *testing.T) {
|
||||||
cfg := config.LogConfig{
|
cfg := model.LogConfig{
|
||||||
Level: "debug",
|
Level: "debug",
|
||||||
Json: true,
|
Json: true,
|
||||||
Streams: config.LogStreams{
|
Streams: model.LogStreams{
|
||||||
HTTP: config.LogStreamConfig{Enabled: true, Level: "info"},
|
HTTP: model.LogStreamConfig{Enabled: true, Level: "info"},
|
||||||
App: config.LogStreamConfig{Enabled: true, Level: ""},
|
App: model.LogStreamConfig{Enabled: true, Level: ""},
|
||||||
Audit: config.LogStreamConfig{Enabled: false, Level: ""},
|
Audit: model.LogStreamConfig{Enabled: false, Level: ""},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
logger := tlog.NewLogger(cfg)
|
logger := tlog.NewLogger(cfg)
|
||||||
|
|
||||||
assert.Assert(t, logger != nil)
|
assert.NotNil(t, logger)
|
||||||
assert.Assert(t, logger.HTTP.GetLevel() == zerolog.InfoLevel)
|
assert.Equal(t, zerolog.InfoLevel, logger.HTTP.GetLevel())
|
||||||
assert.Assert(t, logger.App.GetLevel() == zerolog.DebugLevel)
|
assert.Equal(t, zerolog.DebugLevel, logger.App.GetLevel())
|
||||||
assert.Assert(t, logger.Audit.GetLevel() == zerolog.Disabled)
|
assert.Equal(t, zerolog.Disabled, logger.Audit.GetLevel())
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNewSimpleLogger(t *testing.T) {
|
func TestNewSimpleLogger(t *testing.T) {
|
||||||
logger := tlog.NewSimpleLogger()
|
logger := tlog.NewSimpleLogger()
|
||||||
assert.Assert(t, logger != nil)
|
assert.NotNil(t, logger)
|
||||||
assert.Assert(t, logger.HTTP.GetLevel() == zerolog.InfoLevel)
|
assert.Equal(t, zerolog.InfoLevel, logger.HTTP.GetLevel())
|
||||||
assert.Assert(t, logger.App.GetLevel() == zerolog.InfoLevel)
|
assert.Equal(t, zerolog.InfoLevel, logger.App.GetLevel())
|
||||||
assert.Assert(t, logger.Audit.GetLevel() == zerolog.Disabled)
|
assert.Equal(t, zerolog.Disabled, logger.Audit.GetLevel())
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestLoggerInit(t *testing.T) {
|
func TestLoggerInit(t *testing.T) {
|
||||||
logger := tlog.NewSimpleLogger()
|
logger := tlog.NewSimpleLogger()
|
||||||
logger.Init()
|
logger.Init()
|
||||||
|
|
||||||
assert.Assert(t, tlog.App.GetLevel() != zerolog.Disabled)
|
assert.NotEqual(t, zerolog.Disabled, tlog.App.GetLevel())
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestLoggerWithDisabledStreams(t *testing.T) {
|
func TestLoggerWithDisabledStreams(t *testing.T) {
|
||||||
cfg := config.LogConfig{
|
cfg := model.LogConfig{
|
||||||
Level: "info",
|
Level: "info",
|
||||||
Json: false,
|
Json: false,
|
||||||
Streams: config.LogStreams{
|
Streams: model.LogStreams{
|
||||||
HTTP: config.LogStreamConfig{Enabled: false},
|
HTTP: model.LogStreamConfig{Enabled: false},
|
||||||
App: config.LogStreamConfig{Enabled: false},
|
App: model.LogStreamConfig{Enabled: false},
|
||||||
Audit: config.LogStreamConfig{Enabled: false},
|
Audit: model.LogStreamConfig{Enabled: false},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
logger := tlog.NewLogger(cfg)
|
logger := tlog.NewLogger(cfg)
|
||||||
|
|
||||||
assert.Assert(t, logger.HTTP.GetLevel() == zerolog.Disabled)
|
assert.Equal(t, zerolog.Disabled, logger.HTTP.GetLevel())
|
||||||
assert.Assert(t, logger.App.GetLevel() == zerolog.Disabled)
|
assert.Equal(t, zerolog.Disabled, logger.App.GetLevel())
|
||||||
assert.Assert(t, logger.Audit.GetLevel() == zerolog.Disabled)
|
assert.Equal(t, zerolog.Disabled, logger.Audit.GetLevel())
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestLogStreamField(t *testing.T) {
|
func TestLogStreamField(t *testing.T) {
|
||||||
var buf bytes.Buffer
|
var buf bytes.Buffer
|
||||||
|
|
||||||
cfg := config.LogConfig{
|
cfg := model.LogConfig{
|
||||||
Level: "info",
|
Level: "info",
|
||||||
Json: true,
|
Json: true,
|
||||||
Streams: config.LogStreams{
|
Streams: model.LogStreams{
|
||||||
HTTP: config.LogStreamConfig{Enabled: true},
|
HTTP: model.LogStreamConfig{Enabled: true},
|
||||||
App: config.LogStreamConfig{Enabled: true},
|
App: model.LogStreamConfig{Enabled: true},
|
||||||
Audit: config.LogStreamConfig{Enabled: true},
|
Audit: model.LogStreamConfig{Enabled: true},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -86,7 +86,7 @@ func TestLogStreamField(t *testing.T) {
|
|||||||
|
|
||||||
var logEntry map[string]interface{}
|
var logEntry map[string]interface{}
|
||||||
err := json.Unmarshal(buf.Bytes(), &logEntry)
|
err := json.Unmarshal(buf.Bytes(), &logEntry)
|
||||||
assert.NilError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
assert.Equal(t, "http", logEntry["log_stream"])
|
assert.Equal(t, "http", logEntry["log_stream"])
|
||||||
assert.Equal(t, "test message", logEntry["message"])
|
assert.Equal(t, "test message", logEntry["message"])
|
||||||
|
|||||||
@@ -6,14 +6,14 @@ import (
|
|||||||
"net/mail"
|
"net/mail"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/config"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
func ParseUsers(usersStr []string, userAttributes map[string]config.UserAttributes) ([]config.User, error) {
|
func ParseUsers(usersStr []string, userAttributes map[string]model.UserAttributes) (*[]model.LocalUser, error) {
|
||||||
var users []config.User
|
var users []model.LocalUser
|
||||||
|
|
||||||
if len(usersStr) == 0 {
|
if len(usersStr) == 0 {
|
||||||
return []config.User{}, nil
|
return &users, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, user := range usersStr {
|
for _, user := range usersStr {
|
||||||
@@ -22,22 +22,22 @@ func ParseUsers(usersStr []string, userAttributes map[string]config.UserAttribut
|
|||||||
}
|
}
|
||||||
parsed, err := ParseUser(strings.TrimSpace(user))
|
parsed, err := ParseUser(strings.TrimSpace(user))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return []config.User{}, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if attrs, ok := userAttributes[parsed.Username]; ok {
|
if attrs, ok := userAttributes[parsed.Username]; ok {
|
||||||
parsed.Attributes = attrs
|
parsed.Attributes = attrs
|
||||||
}
|
}
|
||||||
users = append(users, parsed)
|
users = append(users, *parsed)
|
||||||
}
|
}
|
||||||
|
|
||||||
return users, nil
|
return &users, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetUsers(usersCfg []string, usersPath string, userAttributes map[string]config.UserAttributes) ([]config.User, error) {
|
func GetUsers(usersCfg []string, usersPath string, userAttributes map[string]model.UserAttributes) (*[]model.LocalUser, error) {
|
||||||
var usersStr []string
|
var usersStr []string
|
||||||
|
|
||||||
if len(usersCfg) == 0 && usersPath == "" {
|
if len(usersCfg) == 0 && usersPath == "" {
|
||||||
return []config.User{}, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(usersCfg) > 0 {
|
if len(usersCfg) > 0 {
|
||||||
@@ -48,7 +48,7 @@ func GetUsers(usersCfg []string, usersPath string, userAttributes map[string]con
|
|||||||
contents, err := ReadFile(usersPath)
|
contents, err := ReadFile(usersPath)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return []config.User{}, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
lines := strings.SplitSeq(contents, "\n")
|
lines := strings.SplitSeq(contents, "\n")
|
||||||
@@ -65,7 +65,7 @@ func GetUsers(usersCfg []string, usersPath string, userAttributes map[string]con
|
|||||||
return ParseUsers(usersStr, userAttributes)
|
return ParseUsers(usersStr, userAttributes)
|
||||||
}
|
}
|
||||||
|
|
||||||
func ParseUser(userStr string) (config.User, error) {
|
func ParseUser(userStr string) (*model.LocalUser, error) {
|
||||||
if strings.Contains(userStr, "$$") {
|
if strings.Contains(userStr, "$$") {
|
||||||
userStr = strings.ReplaceAll(userStr, "$$", "$")
|
userStr = strings.ReplaceAll(userStr, "$$", "$")
|
||||||
}
|
}
|
||||||
@@ -73,27 +73,27 @@ func ParseUser(userStr string) (config.User, error) {
|
|||||||
parts := strings.SplitN(userStr, ":", 4)
|
parts := strings.SplitN(userStr, ":", 4)
|
||||||
|
|
||||||
if len(parts) < 2 || len(parts) > 3 {
|
if len(parts) < 2 || len(parts) > 3 {
|
||||||
return config.User{}, errors.New("invalid user format")
|
return nil, errors.New("invalid user format")
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, part := range parts {
|
for i, part := range parts {
|
||||||
trimmed := strings.TrimSpace(part)
|
trimmed := strings.TrimSpace(part)
|
||||||
if trimmed == "" {
|
if trimmed == "" {
|
||||||
return config.User{}, errors.New("invalid user format")
|
return nil, errors.New("invalid user format")
|
||||||
}
|
}
|
||||||
parts[i] = trimmed
|
parts[i] = trimmed
|
||||||
}
|
}
|
||||||
|
|
||||||
user := config.User{
|
user := model.LocalUser{
|
||||||
Username: parts[0],
|
Username: parts[0],
|
||||||
Password: parts[1],
|
Password: parts[1],
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(parts) == 3 {
|
if len(parts) == 3 {
|
||||||
user.TotpSecret = parts[2]
|
user.TOTPSecret = parts[2]
|
||||||
}
|
}
|
||||||
|
|
||||||
return user, nil
|
return &user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func CompileUserEmail(username string, domain string) string {
|
func CompileUserEmail(username string, domain string) string {
|
||||||
|
|||||||
@@ -4,74 +4,76 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/config"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||||
|
|
||||||
"gotest.tools/v3/assert"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestGetUsers(t *testing.T) {
|
func TestGetUsers(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
|
||||||
hash := "$2a$10$Mz5xhkfSJUtPWkzCd/TdaePh9CaXc5QcGII5wIMPLSR46eTwma30G"
|
hash := "$2a$10$Mz5xhkfSJUtPWkzCd/TdaePh9CaXc5QcGII5wIMPLSR46eTwma30G"
|
||||||
|
|
||||||
// Setup
|
// Setup
|
||||||
file, err := os.Create("/tmp/tinyauth_users_test.txt")
|
file, err := os.Create(tmpDir + "/tinyauth_users_test.txt")
|
||||||
assert.NilError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
_, err = file.WriteString(" user1:" + hash + " \n user2:" + hash + " ") // Spacing is on purpose
|
_, err = file.WriteString(" user1:" + hash + " \n user2:" + hash + " ") // Spacing is on purpose
|
||||||
assert.NilError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
err = file.Close()
|
err = file.Close()
|
||||||
assert.NilError(t, err)
|
require.NoError(t, err)
|
||||||
defer os.Remove("/tmp/tinyauth_users_test.txt")
|
defer os.Remove(tmpDir + "/tinyauth_users_test.txt")
|
||||||
|
|
||||||
noAttrs := map[string]config.UserAttributes{}
|
noAttrs := map[string]model.UserAttributes{}
|
||||||
|
|
||||||
// Test file only
|
// Test file only
|
||||||
users, err := utils.GetUsers([]string{}, "/tmp/tinyauth_users_test.txt", noAttrs)
|
users, err := utils.GetUsers([]string{}, tmpDir+"/tinyauth_users_test.txt", noAttrs)
|
||||||
|
|
||||||
assert.NilError(t, err)
|
assert.NoError(t, err)
|
||||||
|
assert.NotNil(t, users)
|
||||||
|
assert.Len(t, *users, 2)
|
||||||
|
|
||||||
assert.Equal(t, 2, len(users))
|
assert.Equal(t, "user1", (*users)[0].Username)
|
||||||
|
assert.Equal(t, hash, (*users)[0].Password)
|
||||||
assert.Equal(t, "user1", users[0].Username)
|
assert.Equal(t, "user2", (*users)[1].Username)
|
||||||
assert.Equal(t, hash, users[0].Password)
|
assert.Equal(t, hash, (*users)[1].Password)
|
||||||
assert.Equal(t, "user2", users[1].Username)
|
|
||||||
assert.Equal(t, hash, users[1].Password)
|
|
||||||
|
|
||||||
// Test inline config only
|
// Test inline config only
|
||||||
users, err = utils.GetUsers([]string{"user3:" + hash, "user4:" + hash}, "", noAttrs)
|
users, err = utils.GetUsers([]string{"user3:" + hash, "user4:" + hash}, "", noAttrs)
|
||||||
|
|
||||||
assert.NilError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
assert.Equal(t, 2, len(users))
|
assert.Len(t, *users, 2)
|
||||||
assert.Equal(t, "user3", users[0].Username)
|
assert.Equal(t, "user3", (*users)[0].Username)
|
||||||
assert.Equal(t, "user4", users[1].Username)
|
assert.Equal(t, "user4", (*users)[1].Username)
|
||||||
|
|
||||||
// Test both
|
// Test both
|
||||||
users, err = utils.GetUsers([]string{"user5:" + hash}, "/tmp/tinyauth_users_test.txt", noAttrs)
|
users, err = utils.GetUsers([]string{"user5:" + hash}, tmpDir+"/tinyauth_users_test.txt", noAttrs)
|
||||||
|
|
||||||
assert.NilError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
assert.Equal(t, 3, len(users))
|
assert.Len(t, *users, 3)
|
||||||
|
|
||||||
usernames := map[string]bool{}
|
usernames := map[string]bool{}
|
||||||
for _, u := range users {
|
for _, u := range *users {
|
||||||
usernames[u.Username] = true
|
usernames[u.Username] = true
|
||||||
}
|
}
|
||||||
assert.Assert(t, usernames["user1"])
|
assert.True(t, usernames["user1"])
|
||||||
assert.Assert(t, usernames["user2"])
|
assert.True(t, usernames["user2"])
|
||||||
assert.Assert(t, usernames["user5"])
|
assert.True(t, usernames["user5"])
|
||||||
|
|
||||||
// Test attributes applied from userAttributes map
|
// Test attributes applied from userAttributes map
|
||||||
attrs := map[string]config.UserAttributes{
|
attrs := map[string]model.UserAttributes{
|
||||||
"user1": {Name: "User One", Email: "user1@example.com"},
|
"user1": {Name: "User One", Email: "user1@example.com"},
|
||||||
}
|
}
|
||||||
users, err = utils.GetUsers([]string{}, "/tmp/tinyauth_users_test.txt", attrs)
|
users, err = utils.GetUsers([]string{}, tmpDir+"/tinyauth_users_test.txt", attrs)
|
||||||
|
|
||||||
assert.NilError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, 2, len(users))
|
assert.Len(t, *users, 2)
|
||||||
|
|
||||||
for _, u := range users {
|
for _, u := range *users {
|
||||||
if u.Username == "user1" {
|
if u.Username == "user1" {
|
||||||
assert.Equal(t, "User One", u.Attributes.Name)
|
assert.Equal(t, "User One", u.Attributes.Name)
|
||||||
assert.Equal(t, "user1@example.com", u.Attributes.Email)
|
assert.Equal(t, "user1@example.com", u.Attributes.Email)
|
||||||
@@ -84,16 +86,14 @@ func TestGetUsers(t *testing.T) {
|
|||||||
// Test empty
|
// Test empty
|
||||||
users, err = utils.GetUsers([]string{}, "", noAttrs)
|
users, err = utils.GetUsers([]string{}, "", noAttrs)
|
||||||
|
|
||||||
assert.NilError(t, err)
|
assert.NoError(t, err)
|
||||||
|
assert.Nil(t, users)
|
||||||
assert.Equal(t, 0, len(users))
|
|
||||||
|
|
||||||
// Test non-existent file
|
// Test non-existent file
|
||||||
users, err = utils.GetUsers([]string{}, "/tmp/non_existent_file.txt", noAttrs)
|
users, err = utils.GetUsers([]string{}, tmpDir+"/non_existent_file.txt", noAttrs)
|
||||||
|
|
||||||
assert.ErrorContains(t, err, "no such file or directory")
|
assert.ErrorContains(t, err, "no such file or directory")
|
||||||
|
assert.Nil(t, users)
|
||||||
assert.Equal(t, 0, len(users))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestParseUser(t *testing.T) {
|
func TestParseUser(t *testing.T) {
|
||||||
@@ -102,38 +102,38 @@ func TestParseUser(t *testing.T) {
|
|||||||
// Valid user without TOTP
|
// Valid user without TOTP
|
||||||
user, err := utils.ParseUser("user1:" + hash)
|
user, err := utils.ParseUser("user1:" + hash)
|
||||||
|
|
||||||
assert.NilError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
assert.Equal(t, "user1", user.Username)
|
assert.Equal(t, "user1", user.Username)
|
||||||
assert.Equal(t, hash, user.Password)
|
assert.Equal(t, hash, user.Password)
|
||||||
assert.Equal(t, "", user.TotpSecret)
|
assert.Equal(t, "", user.TOTPSecret)
|
||||||
|
|
||||||
// Valid user with TOTP
|
// Valid user with TOTP
|
||||||
user, err = utils.ParseUser("user2:" + hash + ":ABCDEF")
|
user, err = utils.ParseUser("user2:" + hash + ":ABCDEF")
|
||||||
|
|
||||||
assert.NilError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
assert.Equal(t, "user2", user.Username)
|
assert.Equal(t, "user2", user.Username)
|
||||||
assert.Equal(t, hash, user.Password)
|
assert.Equal(t, hash, user.Password)
|
||||||
assert.Equal(t, "ABCDEF", user.TotpSecret)
|
assert.Equal(t, "ABCDEF", user.TOTPSecret)
|
||||||
|
|
||||||
// Valid user with $$ in password
|
// Valid user with $$ in password
|
||||||
user, err = utils.ParseUser("user3:pa$$word123")
|
user, err = utils.ParseUser("user3:pa$$word123")
|
||||||
|
|
||||||
assert.NilError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
assert.Equal(t, "user3", user.Username)
|
assert.Equal(t, "user3", user.Username)
|
||||||
assert.Equal(t, "pa$word123", user.Password)
|
assert.Equal(t, "pa$word123", user.Password)
|
||||||
assert.Equal(t, "", user.TotpSecret)
|
assert.Equal(t, "", user.TOTPSecret)
|
||||||
|
|
||||||
// User with spaces
|
// User with spaces
|
||||||
user, err = utils.ParseUser(" user4 : password123 : TOTPSECRET ")
|
user, err = utils.ParseUser(" user4 : password123 : TOTPSECRET ")
|
||||||
|
|
||||||
assert.NilError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
assert.Equal(t, "user4", user.Username)
|
assert.Equal(t, "user4", user.Username)
|
||||||
assert.Equal(t, "password123", user.Password)
|
assert.Equal(t, "password123", user.Password)
|
||||||
assert.Equal(t, "TOTPSECRET", user.TotpSecret)
|
assert.Equal(t, "TOTPSECRET", user.TOTPSecret)
|
||||||
|
|
||||||
// Invalid users
|
// Invalid users
|
||||||
_, err = utils.ParseUser("user1") // Missing password
|
_, err = utils.ParseUser("user1") // Missing password
|
||||||
|
|||||||
Reference in New Issue
Block a user