mirror of
https://github.com/steveiliop56/tinyauth.git
synced 2026-07-01 15:50:13 +00:00
Compare commits
9 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 62ffd2fd11 | |||
| a3ec07230c | |||
| b4eb7090bd | |||
| 2f24f823eb | |||
| 9a219046ac | |||
| 97d58b376d | |||
| b426a1529e | |||
| c7efb71a5a | |||
| eec75a6f49 |
@@ -84,7 +84,7 @@ jobs:
|
||||
- name: Build
|
||||
run: |
|
||||
cp -r frontend/dist internal/assets/dist
|
||||
go build -ldflags "-s -w -X github.com/tinyauthapp/tinyauth/internal/model.Version=${{ needs.generate-metadata.outputs.VERSION }} -X github.com/tinyauthapp/tinyauth/internal/model.CommitHash=${{ needs.generate-metadata.outputs.COMMIT_HASH }} -X github.com/tinyauthapp/tinyauth/internal/model.BuildTimestamp=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}" -o tinyauth-amd64 ./cmd/tinyauth
|
||||
go build -ldflags "-s -w -X github.com/tinyauthapp/tinyauth/internal/config.Version=${{ needs.generate-metadata.outputs.VERSION }} -X github.com/tinyauthapp/tinyauth/internal/config.CommitHash=${{ needs.generate-metadata.outputs.COMMIT_HASH }} -X github.com/tinyauthapp/tinyauth/internal/config.BuildTimestamp=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}" -o tinyauth-amd64 ./cmd/tinyauth
|
||||
env:
|
||||
CGO_ENABLED: 0
|
||||
|
||||
@@ -130,7 +130,7 @@ jobs:
|
||||
- name: Build
|
||||
run: |
|
||||
cp -r frontend/dist internal/assets/dist
|
||||
go build -ldflags "-s -w -X github.com/tinyauthapp/tinyauth/internal/model.Version=${{ needs.generate-metadata.outputs.VERSION }} -X github.com/tinyauthapp/tinyauth/internal/model.CommitHash=${{ needs.generate-metadata.outputs.COMMIT_HASH }} -X github.com/tinyauthapp/tinyauth/internal/model.BuildTimestamp=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}" -o tinyauth-arm64 ./cmd/tinyauth
|
||||
go build -ldflags "-s -w -X github.com/tinyauthapp/tinyauth/internal/config.Version=${{ needs.generate-metadata.outputs.VERSION }} -X github.com/tinyauthapp/tinyauth/internal/config.CommitHash=${{ needs.generate-metadata.outputs.COMMIT_HASH }} -X github.com/tinyauthapp/tinyauth/internal/config.BuildTimestamp=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}" -o tinyauth-arm64 ./cmd/tinyauth
|
||||
env:
|
||||
CGO_ENABLED: 0
|
||||
|
||||
|
||||
@@ -60,7 +60,7 @@ jobs:
|
||||
- name: Build
|
||||
run: |
|
||||
cp -r frontend/dist internal/assets/dist
|
||||
go build -ldflags "-s -w -X github.com/tinyauthapp/tinyauth/internal/model.Version=${{ needs.generate-metadata.outputs.VERSION }} -X github.com/tinyauthapp/tinyauth/internal/model.CommitHash=${{ needs.generate-metadata.outputs.COMMIT_HASH }} -X github.com/tinyauthapp/tinyauth/internal/model.BuildTimestamp=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}" -o tinyauth-amd64 ./cmd/tinyauth
|
||||
go build -ldflags "-s -w -X github.com/tinyauthapp/tinyauth/internal/config.Version=${{ needs.generate-metadata.outputs.VERSION }} -X github.com/tinyauthapp/tinyauth/internal/config.CommitHash=${{ needs.generate-metadata.outputs.COMMIT_HASH }} -X github.com/tinyauthapp/tinyauth/internal/config.BuildTimestamp=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}" -o tinyauth-amd64 ./cmd/tinyauth
|
||||
env:
|
||||
CGO_ENABLED: 0
|
||||
|
||||
@@ -103,7 +103,7 @@ jobs:
|
||||
- name: Build
|
||||
run: |
|
||||
cp -r frontend/dist internal/assets/dist
|
||||
go build -ldflags "-s -w -X github.com/tinyauthapp/tinyauth/internal/model.Version=${{ needs.generate-metadata.outputs.VERSION }} -X github.com/tinyauthapp/tinyauth/internal/model.CommitHash=${{ needs.generate-metadata.outputs.COMMIT_HASH }} -X github.com/tinyauthapp/tinyauth/internal/model.BuildTimestamp=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}" -o tinyauth-arm64 ./cmd/tinyauth
|
||||
go build -ldflags "-s -w -X github.com/tinyauthapp/tinyauth/internal/config.Version=${{ needs.generate-metadata.outputs.VERSION }} -X github.com/tinyauthapp/tinyauth/internal/config.CommitHash=${{ needs.generate-metadata.outputs.COMMIT_HASH }} -X github.com/tinyauthapp/tinyauth/internal/config.BuildTimestamp=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}" -o tinyauth-arm64 ./cmd/tinyauth
|
||||
env:
|
||||
CGO_ENABLED: 0
|
||||
|
||||
|
||||
@@ -38,6 +38,6 @@ jobs:
|
||||
retention-days: 5
|
||||
|
||||
- name: Upload to code-scanning
|
||||
uses: github/codeql-action/upload-sarif@e46ed2cbd01164d986452f91f178727624ae40d7 # v4
|
||||
uses: github/codeql-action/upload-sarif@95e58e9a2cdfd71adc6e0353d5c52f41a045d225 # v4
|
||||
with:
|
||||
sarif_file: results.sarif
|
||||
|
||||
+3
-3
@@ -38,9 +38,9 @@ COPY ./internal ./internal
|
||||
COPY --from=frontend-builder /frontend/dist ./internal/assets/dist
|
||||
|
||||
RUN CGO_ENABLED=0 go build -ldflags "-s -w \
|
||||
-X github.com/tinyauthapp/tinyauth/internal/model.Version=${VERSION} \
|
||||
-X github.com/tinyauthapp/tinyauth/internal/model.CommitHash=${COMMIT_HASH} \
|
||||
-X github.com/tinyauthapp/tinyauth/internal/model.BuildTimestamp=${BUILD_TIMESTAMP}" ./cmd/tinyauth
|
||||
-X github.com/tinyauthapp/tinyauth/internal/config.Version=${VERSION} \
|
||||
-X github.com/tinyauthapp/tinyauth/internal/config.CommitHash=${COMMIT_HASH} \
|
||||
-X github.com/tinyauthapp/tinyauth/internal/config.BuildTimestamp=${BUILD_TIMESTAMP}" ./cmd/tinyauth
|
||||
|
||||
# Runner
|
||||
FROM alpine:3.23 AS runner
|
||||
|
||||
@@ -40,9 +40,9 @@ COPY --from=frontend-builder /frontend/dist ./internal/assets/dist
|
||||
RUN mkdir -p data
|
||||
|
||||
RUN CGO_ENABLED=0 go build -ldflags "-s -w \
|
||||
-X github.com/tinyauthapp/tinyauth/internal/model.Version=${VERSION} \
|
||||
-X github.com/tinyauthapp/tinyauth/internal/model.CommitHash=${COMMIT_HASH} \
|
||||
-X github.com/tinyauthapp/tinyauth/internal/model.BuildTimestamp=${BUILD_TIMESTAMP}" ./cmd/tinyauth
|
||||
-X github.com/tinyauthapp/tinyauth/internal/config.Version=${VERSION} \
|
||||
-X github.com/tinyauthapp/tinyauth/internal/config.CommitHash=${COMMIT_HASH} \
|
||||
-X github.com/tinyauthapp/tinyauth/internal/config.BuildTimestamp=${BUILD_TIMESTAMP}" ./cmd/tinyauth
|
||||
|
||||
# Runner
|
||||
FROM gcr.io/distroless/static-debian12:latest AS runner
|
||||
|
||||
@@ -37,9 +37,9 @@ webui: clean-webui
|
||||
# Build the binary
|
||||
binary: webui
|
||||
CGO_ENABLED=$(CGO_ENABLED) go build -ldflags "-s -w \
|
||||
-X github.com/tinyauthapp/tinyauth/internal/model.Version=${TAG_NAME} \
|
||||
-X github.com/tinyauthapp/tinyauth/internal/model.CommitHash=${COMMIT_HASH} \
|
||||
-X github.com/tinyauthapp/tinyauth/internal/model.BuildTimestamp=${BUILD_TIMESTAMP}" \
|
||||
-X github.com/tinyauthapp/tinyauth/internal/config.Version=${TAG_NAME} \
|
||||
-X github.com/tinyauthapp/tinyauth/internal/config.CommitHash=${COMMIT_HASH} \
|
||||
-X github.com/tinyauthapp/tinyauth/internal/config.BuildTimestamp=${BUILD_TIMESTAMP}" \
|
||||
-o ${BIN_NAME} ./cmd/tinyauth
|
||||
|
||||
# Build for amd64
|
||||
|
||||
+2
-2
@@ -10,7 +10,7 @@ import (
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"github.com/tinyauthapp/tinyauth/internal/config"
|
||||
)
|
||||
|
||||
type EnvEntry struct {
|
||||
@@ -20,7 +20,7 @@ type EnvEntry struct {
|
||||
}
|
||||
|
||||
func generateExampleEnv() {
|
||||
cfg := model.NewDefaultConfiguration()
|
||||
cfg := config.NewDefaultConfiguration()
|
||||
entries := make([]EnvEntry, 0)
|
||||
|
||||
root := reflect.TypeOf(cfg).Elem()
|
||||
|
||||
+2
-2
@@ -10,7 +10,7 @@ import (
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"github.com/tinyauthapp/tinyauth/internal/config"
|
||||
)
|
||||
|
||||
type MarkdownEntry struct {
|
||||
@@ -21,7 +21,7 @@ type MarkdownEntry struct {
|
||||
}
|
||||
|
||||
func generateMarkdown() {
|
||||
cfg := model.NewDefaultConfiguration()
|
||||
cfg := config.NewDefaultConfiguration()
|
||||
entries := make([]MarkdownEntry, 0)
|
||||
|
||||
root := reflect.TypeOf(cfg).Elem()
|
||||
|
||||
@@ -20,6 +20,7 @@ require (
|
||||
github.com/weppos/publicsuffix-go v0.50.3
|
||||
golang.org/x/crypto v0.50.0
|
||||
golang.org/x/oauth2 v0.36.0
|
||||
gotest.tools/v3 v3.5.2
|
||||
k8s.io/apimachinery v0.32.2
|
||||
k8s.io/client-go v0.32.2
|
||||
modernc.org/sqlite v1.49.1
|
||||
@@ -132,7 +133,6 @@ require (
|
||||
google.golang.org/protobuf v1.36.11 // indirect
|
||||
gopkg.in/inf.v0 v0.9.1 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
gotest.tools/v3 v3.5.2 // indirect
|
||||
k8s.io/klog/v2 v2.130.1 // indirect
|
||||
k8s.io/utils v0.0.0-20241104100929-3ea5e8cea738 // indirect
|
||||
modernc.org/libc v1.72.0 // indirect
|
||||
|
||||
@@ -29,7 +29,7 @@ type BootstrapApp struct {
|
||||
csrfCookieName string
|
||||
redirectCookieName string
|
||||
oauthSessionCookieName string
|
||||
localUsers *[]model.LocalUser
|
||||
localUsers []model.LocalUser
|
||||
oauthProviders map[string]model.OAuthServiceConfig
|
||||
configuredProviders []controller.Provider
|
||||
oidcClients []model.OIDCClientConfig
|
||||
@@ -69,7 +69,7 @@ func (app *BootstrapApp) Setup() error {
|
||||
return err
|
||||
}
|
||||
|
||||
app.context.localUsers = users
|
||||
app.context.localUsers = *users
|
||||
|
||||
// Setup OAuth providers
|
||||
app.context.oauthProviders = app.config.OAuth.Providers
|
||||
@@ -104,13 +104,7 @@ func (app *BootstrapApp) Setup() error {
|
||||
}
|
||||
|
||||
// Get cookie domain
|
||||
cookieDomainResolver := utils.GetCookieDomain
|
||||
if !app.config.Auth.SubdomainsEnabled {
|
||||
tlog.App.Info().Msg("Subdomains disabled, automatic authentication for proxied apps will not work")
|
||||
cookieDomainResolver = utils.GetStandaloneCookieDomain
|
||||
}
|
||||
|
||||
cookieDomain, err := cookieDomainResolver(app.context.appUrl)
|
||||
cookieDomain, err := utils.GetCookieDomain(app.context.appUrl)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
@@ -84,7 +84,6 @@ func (app *BootstrapApp) setupRouter() (*gin.Engine, error) {
|
||||
RedirectCookieName: app.context.redirectCookieName,
|
||||
CookieDomain: app.context.cookieDomain,
|
||||
OAuthSessionCookieName: app.context.oauthSessionCookieName,
|
||||
SubdomainsEnabled: app.config.Auth.SubdomainsEnabled,
|
||||
}, apiRouter, app.services.authService)
|
||||
|
||||
oauthController.SetupRoutes()
|
||||
|
||||
@@ -100,7 +100,6 @@ func (app *BootstrapApp) initServices(queries *repository.Queries) (Services, er
|
||||
SessionCookieName: app.context.sessionCookieName,
|
||||
IP: app.config.Auth.IP,
|
||||
LDAPGroupsCacheTTL: app.config.LDAP.GroupCacheTTL,
|
||||
SubdomainsEnabled: app.config.Auth.SubdomainsEnabled,
|
||||
}, services.ldapService, queries, services.oauthBrokerService)
|
||||
|
||||
err = authService.Init()
|
||||
|
||||
@@ -95,7 +95,7 @@ func (controller *ContextController) userContextHandler(c *gin.Context) {
|
||||
Username: context.GetUsername(),
|
||||
Name: context.GetName(),
|
||||
Email: context.GetEmail(),
|
||||
Provider: context.GetProviderID(),
|
||||
Provider: context.ProviderName(),
|
||||
OAuth: context.IsOAuth(),
|
||||
TOTPPending: context.TOTPPending(),
|
||||
OAuthName: context.OAuthName(),
|
||||
|
||||
@@ -7,11 +7,11 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/tinyauthapp/tinyauth/internal/config"
|
||||
"github.com/tinyauthapp/tinyauth/internal/controller"
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestContextController(t *testing.T) {
|
||||
@@ -79,16 +79,12 @@ func TestContextController(t *testing.T) {
|
||||
description: "Ensure user context returns when authorized",
|
||||
middlewares: []gin.HandlerFunc{
|
||||
func(c *gin.Context) {
|
||||
c.Set("context", &model.UserContext{
|
||||
Authenticated: true,
|
||||
Provider: model.ProviderLocal,
|
||||
Local: &model.LocalContext{
|
||||
BaseContext: model.BaseContext{
|
||||
Username: "johndoe",
|
||||
Name: "John Doe",
|
||||
Email: utils.CompileUserEmail("johndoe", controllerConfig.CookieDomain),
|
||||
},
|
||||
},
|
||||
c.Set("context", &config.UserContext{
|
||||
Username: "johndoe",
|
||||
Name: "John Doe",
|
||||
Email: utils.CompileUserEmail("johndoe", controllerConfig.CookieDomain),
|
||||
Provider: "local",
|
||||
IsLoggedIn: true,
|
||||
})
|
||||
},
|
||||
},
|
||||
|
||||
@@ -26,7 +26,6 @@ type OAuthControllerConfig struct {
|
||||
SecureCookie bool
|
||||
AppURL string
|
||||
CookieDomain string
|
||||
SubdomainsEnabled bool
|
||||
}
|
||||
|
||||
type OAuthController struct {
|
||||
@@ -106,7 +105,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
c.SetCookie(controller.config.OAuthSessionCookieName, sessionId, int(time.Hour.Seconds()), "/", controller.getCookieDomain(), controller.config.SecureCookie, true)
|
||||
c.SetCookie(controller.config.OAuthSessionCookieName, sessionId, int(time.Hour.Seconds()), "/", fmt.Sprintf(".%s", controller.config.CookieDomain), controller.config.SecureCookie, true)
|
||||
|
||||
c.JSON(200, gin.H{
|
||||
"status": 200,
|
||||
@@ -136,7 +135,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
c.SetCookie(controller.config.OAuthSessionCookieName, "", -1, "/", controller.getCookieDomain(), controller.config.SecureCookie, true)
|
||||
c.SetCookie(controller.config.OAuthSessionCookieName, "", -1, "/", fmt.Sprintf(".%s", controller.config.CookieDomain), controller.config.SecureCookie, true)
|
||||
|
||||
oauthPendingSession, err := controller.auth.GetOAuthPendingSession(sessionIdCookie)
|
||||
|
||||
@@ -284,10 +283,3 @@ func (controller *OAuthController) isOidcRequest(params service.OAuthURLParams)
|
||||
params.ClientID != "" &&
|
||||
params.RedirectURI != ""
|
||||
}
|
||||
|
||||
func (controller *OAuthController) getCookieDomain() string {
|
||||
if controller.config.SubdomainsEnabled {
|
||||
return "." + controller.config.CookieDomain
|
||||
}
|
||||
return controller.config.CookieDomain
|
||||
}
|
||||
|
||||
@@ -12,14 +12,14 @@ import (
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/go-querystring/query"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tinyauthapp/tinyauth/internal/bootstrap"
|
||||
"github.com/tinyauthapp/tinyauth/internal/config"
|
||||
"github.com/tinyauthapp/tinyauth/internal/controller"
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestOIDCController(t *testing.T) {
|
||||
@@ -27,7 +27,7 @@ func TestOIDCController(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
oidcServiceCfg := service.OIDCServiceConfig{
|
||||
Clients: map[string]model.OIDCClientConfig{
|
||||
Clients: map[string]config.OIDCClientConfig{
|
||||
"test": {
|
||||
ClientID: "some-client-id",
|
||||
ClientSecret: "some-client-secret",
|
||||
@@ -44,16 +44,12 @@ func TestOIDCController(t *testing.T) {
|
||||
controllerCfg := controller.OIDCControllerConfig{}
|
||||
|
||||
simpleCtx := func(c *gin.Context) {
|
||||
c.Set("context", &model.UserContext{
|
||||
Authenticated: true,
|
||||
Provider: model.ProviderLocal,
|
||||
Local: &model.LocalContext{
|
||||
BaseContext: model.BaseContext{
|
||||
Username: "test",
|
||||
Name: "Test User",
|
||||
Email: "test@example.com",
|
||||
},
|
||||
},
|
||||
c.Set("context", &config.UserContext{
|
||||
Username: "test",
|
||||
Name: "Test User",
|
||||
Email: "test@example.com",
|
||||
IsLoggedIn: true,
|
||||
Provider: "local",
|
||||
})
|
||||
c.Next()
|
||||
}
|
||||
@@ -852,7 +848,7 @@ func TestOIDCController(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
app := bootstrap.NewBootstrapApp(model.Config{})
|
||||
app := bootstrap.NewBootstrapApp(config.Config{})
|
||||
|
||||
db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db"))
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -99,12 +99,16 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if acls == nil {
|
||||
acls = &model.App{}
|
||||
}
|
||||
|
||||
tlog.App.Trace().Interface("acls", acls).Msg("ACLs for resource")
|
||||
|
||||
clientIP := c.ClientIP()
|
||||
|
||||
if controller.auth.IsBypassedIP(clientIP, acls) {
|
||||
controller.setHeaders(c, acls)
|
||||
if controller.auth.IsBypassedIP(&acls.IP, clientIP) {
|
||||
controller.setHeaders(c, *acls)
|
||||
c.JSON(200, gin.H{
|
||||
"status": 200,
|
||||
"message": "Authenticated",
|
||||
@@ -112,7 +116,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
authEnabled, err := controller.auth.IsAuthEnabled(proxyCtx.Path, acls)
|
||||
authEnabled, err := controller.auth.IsAuthEnabled(proxyCtx.Path, &acls.Path)
|
||||
|
||||
if err != nil {
|
||||
tlog.App.Error().Err(err).Msg("Failed to check if auth is enabled for resource")
|
||||
@@ -122,7 +126,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
||||
|
||||
if !authEnabled {
|
||||
tlog.App.Debug().Msg("Authentication disabled for resource, allowing access")
|
||||
controller.setHeaders(c, acls)
|
||||
controller.setHeaders(c, *acls)
|
||||
c.JSON(200, gin.H{
|
||||
"status": 200,
|
||||
"message": "Authenticated",
|
||||
@@ -130,7 +134,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if !controller.auth.CheckIP(clientIP, acls) {
|
||||
if !controller.auth.CheckIP(&acls.IP, clientIP) {
|
||||
queries, err := query.Values(UnauthorizedQuery{
|
||||
Resource: strings.Split(proxyCtx.Host, ".")[0],
|
||||
IP: clientIP,
|
||||
@@ -209,9 +213,9 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
||||
var groupOK bool
|
||||
|
||||
if userContext.IsOAuth() {
|
||||
groupOK = controller.auth.IsInOAuthGroup(c, *userContext, acls)
|
||||
groupOK = controller.auth.IsInOAuthGroup(c, *userContext, acls.OAuth.Groups)
|
||||
} else {
|
||||
groupOK = controller.auth.IsInLDAPGroup(c, *userContext, acls)
|
||||
groupOK = controller.auth.IsInLDAPGroup(c, *userContext, acls.LDAP.Groups)
|
||||
}
|
||||
|
||||
if !groupOK {
|
||||
@@ -263,7 +267,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
||||
c.Header("Remote-Sub", utils.SanitizeHeader(userContext.OAuth.Sub))
|
||||
}
|
||||
|
||||
controller.setHeaders(c, acls)
|
||||
controller.setHeaders(c, *acls)
|
||||
|
||||
c.JSON(200, gin.H{
|
||||
"status": 200,
|
||||
@@ -296,13 +300,9 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
||||
c.Redirect(http.StatusTemporaryRedirect, redirectURL)
|
||||
}
|
||||
|
||||
func (controller *ProxyController) setHeaders(c *gin.Context, acls *model.App) {
|
||||
func (controller *ProxyController) setHeaders(c *gin.Context, acls model.App) {
|
||||
c.Header("Authorization", c.Request.Header.Get("Authorization"))
|
||||
|
||||
if acls == nil {
|
||||
return
|
||||
}
|
||||
|
||||
headers := utils.ParseHeaders(acls.Response.Headers)
|
||||
|
||||
for key, value := range headers {
|
||||
@@ -314,7 +314,7 @@ func (controller *ProxyController) setHeaders(c *gin.Context, acls *model.App) {
|
||||
|
||||
if acls.Response.BasicAuth.Username != "" && basicPassword != "" {
|
||||
tlog.App.Debug().Str("username", acls.Response.BasicAuth.Username).Msg("Setting basic auth header")
|
||||
c.Header("Authorization", fmt.Sprintf("Basic %s", utils.EncodeBasicAuth(acls.Response.BasicAuth.Username, basicPassword)))
|
||||
c.Header("Authorization", fmt.Sprintf("Basic %s", utils.GetBasicAuth(acls.Response.BasicAuth.Username, basicPassword)))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -6,14 +6,14 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tinyauthapp/tinyauth/internal/bootstrap"
|
||||
"github.com/tinyauthapp/tinyauth/internal/config"
|
||||
"github.com/tinyauthapp/tinyauth/internal/controller"
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestProxyController(t *testing.T) {
|
||||
@@ -21,7 +21,7 @@ func TestProxyController(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
authServiceCfg := service.AuthServiceConfig{
|
||||
LocalUsers: &[]model.LocalUser{
|
||||
Users: []config.User{
|
||||
{
|
||||
Username: "testuser",
|
||||
Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
|
||||
@@ -29,7 +29,7 @@ func TestProxyController(t *testing.T) {
|
||||
{
|
||||
Username: "totpuser",
|
||||
Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
|
||||
TOTPSecret: "JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK",
|
||||
TotpSecret: "JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK",
|
||||
},
|
||||
},
|
||||
SessionExpiry: 10, // 10 seconds, useful for testing
|
||||
@@ -43,28 +43,28 @@ func TestProxyController(t *testing.T) {
|
||||
AppURL: "https://tinyauth.example.com",
|
||||
}
|
||||
|
||||
acls := map[string]model.App{
|
||||
acls := map[string]config.App{
|
||||
"app_path_allow": {
|
||||
Config: model.AppConfig{
|
||||
Config: config.AppConfig{
|
||||
Domain: "path-allow.example.com",
|
||||
},
|
||||
Path: model.AppPath{
|
||||
Path: config.AppPath{
|
||||
Allow: "/allowed",
|
||||
},
|
||||
},
|
||||
"app_user_allow": {
|
||||
Config: model.AppConfig{
|
||||
Config: config.AppConfig{
|
||||
Domain: "user-allow.example.com",
|
||||
},
|
||||
Users: model.AppUsers{
|
||||
Users: config.AppUsers{
|
||||
Allow: "testuser",
|
||||
},
|
||||
},
|
||||
"ip_bypass": {
|
||||
Config: model.AppConfig{
|
||||
Config: config.AppConfig{
|
||||
Domain: "ip-bypass.example.com",
|
||||
},
|
||||
IP: model.AppIP{
|
||||
IP: config.AppIP{
|
||||
Bypass: []string{"10.10.10.10"},
|
||||
},
|
||||
},
|
||||
@@ -74,31 +74,24 @@ func TestProxyController(t *testing.T) {
|
||||
Mozilla/5.0 (Linux; Android 8.0.0; SM-G955U Build/R16NW) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/146.0.0.0 Mobile Safari/537.36`
|
||||
|
||||
simpleCtx := func(c *gin.Context) {
|
||||
c.Set("context", &model.UserContext{
|
||||
Authenticated: true,
|
||||
Provider: model.ProviderLocal,
|
||||
Local: &model.LocalContext{
|
||||
BaseContext: model.BaseContext{
|
||||
Username: "testuser",
|
||||
Name: "Testuser",
|
||||
Email: "testuser@example.com",
|
||||
},
|
||||
},
|
||||
c.Set("context", &config.UserContext{
|
||||
Username: "testuser",
|
||||
Name: "Testuser",
|
||||
Email: "testuser@example.com",
|
||||
IsLoggedIn: true,
|
||||
Provider: "local",
|
||||
})
|
||||
c.Next()
|
||||
}
|
||||
|
||||
simpleCtxTotp := func(c *gin.Context) {
|
||||
c.Set("context", &model.UserContext{
|
||||
Authenticated: true,
|
||||
Provider: model.ProviderLocal,
|
||||
Local: &model.LocalContext{
|
||||
BaseContext: model.BaseContext{
|
||||
Username: "totpuser",
|
||||
Name: "Totpuser",
|
||||
Email: "totpuser@example.com",
|
||||
},
|
||||
},
|
||||
c.Set("context", &config.UserContext{
|
||||
Username: "totpuser",
|
||||
Name: "Totpuser",
|
||||
Email: "totpuser@example.com",
|
||||
IsLoggedIn: true,
|
||||
Provider: "local",
|
||||
TotpEnabled: true,
|
||||
})
|
||||
c.Next()
|
||||
}
|
||||
@@ -398,9 +391,9 @@ func TestProxyController(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
oauthBrokerCfgs := make(map[string]model.OAuthServiceConfig)
|
||||
oauthBrokerCfgs := make(map[string]config.OAuthServiceConfig)
|
||||
|
||||
app := bootstrap.NewBootstrapApp(model.Config{})
|
||||
app := bootstrap.NewBootstrapApp(config.Config{})
|
||||
|
||||
db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db"))
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -102,7 +102,7 @@ func (controller *UserController) loginHandler(c *gin.Context) {
|
||||
}
|
||||
|
||||
if err := controller.auth.CheckUserPassword(*search, req.Password); err != nil {
|
||||
tlog.App.Warn().Err(err).Str("username", req.Username).Msg("Failed to verify password")
|
||||
tlog.App.Warn().Str("username", req.Username).Msg("Invalid password")
|
||||
controller.auth.RecordLoginAttempt(req.Username, false)
|
||||
tlog.AuditLoginFailure(c, req.Username, "username", "invalid password")
|
||||
c.JSON(401, gin.H{
|
||||
@@ -112,20 +112,16 @@ func (controller *UserController) loginHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
tlog.App.Info().Str("username", req.Username).Msg("Login successful")
|
||||
tlog.AuditLoginSuccess(c, req.Username, "username")
|
||||
|
||||
controller.auth.RecordLoginAttempt(req.Username, true)
|
||||
|
||||
var localUser *model.LocalUser
|
||||
|
||||
if search.Type == model.UserLocal {
|
||||
localUser = controller.auth.GetLocalUser(req.Username)
|
||||
|
||||
if localUser == nil {
|
||||
tlog.App.Warn().Str("username", req.Username).Msg("User disappeared during login")
|
||||
c.JSON(401, gin.H{
|
||||
"status": 401,
|
||||
"message": "Unauthorized",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if localUser.TOTPSecret != "" {
|
||||
tlog.App.Debug().Str("username", req.Username).Msg("User has TOTP enabled, requiring TOTP verification")
|
||||
|
||||
@@ -202,11 +198,6 @@ func (controller *UserController) loginHandler(c *gin.Context) {
|
||||
|
||||
http.SetCookie(c.Writer, cookie)
|
||||
|
||||
tlog.App.Info().Str("username", req.Username).Msg("Login successful")
|
||||
tlog.AuditLoginSuccess(c, req.Username, "username")
|
||||
|
||||
controller.auth.RecordLoginAttempt(req.Username, true)
|
||||
|
||||
c.JSON(200, gin.H{
|
||||
"status": 200,
|
||||
"message": "Login successful",
|
||||
@@ -235,6 +226,17 @@ func (controller *UserController) logoutHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
context, err := new(model.UserContext).NewFromGin(c)
|
||||
|
||||
if err != nil {
|
||||
tlog.App.Error().Err(err).Msg("Failed to get user context on logout")
|
||||
c.JSON(500, gin.H{
|
||||
"status": 500,
|
||||
"message": "Internal Server Error",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
cookie, err := controller.auth.DeleteSession(c, uuid)
|
||||
|
||||
if err != nil {
|
||||
@@ -246,14 +248,7 @@ func (controller *UserController) logoutHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
context, err := new(model.UserContext).NewFromGin(c)
|
||||
|
||||
if err == nil {
|
||||
tlog.AuditLogout(c, context.GetUsername(), context.GetProviderID())
|
||||
} else {
|
||||
tlog.App.Warn().Err(err).Msg("Failed to get user context for logout audit, proceeding without username")
|
||||
tlog.AuditLogout(c, "unknown", "unknown")
|
||||
}
|
||||
tlog.AuditLogout(c, context.GetUsername(), context.ProviderName())
|
||||
|
||||
http.SetCookie(c.Writer, cookie)
|
||||
|
||||
@@ -313,15 +308,6 @@ func (controller *UserController) totpHandler(c *gin.Context) {
|
||||
|
||||
user := controller.auth.GetLocalUser(context.GetUsername())
|
||||
|
||||
if user == nil {
|
||||
tlog.App.Error().Str("username", context.GetUsername()).Msg("User not found in TOTP handler")
|
||||
c.JSON(401, gin.H{
|
||||
"status": 401,
|
||||
"message": "Unauthorized",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
ok := totp.Validate(req.Code, user.TOTPSecret)
|
||||
|
||||
if !ok {
|
||||
@@ -335,16 +321,8 @@ func (controller *UserController) totpHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
uuid, err := c.Cookie(controller.config.SessionCookieName)
|
||||
|
||||
if err == nil {
|
||||
_, err = controller.auth.DeleteSession(c, uuid)
|
||||
if err != nil {
|
||||
tlog.App.Warn().Err(err).Msg("Failed to delete pending TOTP session")
|
||||
}
|
||||
} else {
|
||||
tlog.App.Warn().Err(err).Msg("Failed to retrieve session cookie for pending TOTP session, proceeding without deleting it")
|
||||
}
|
||||
tlog.App.Info().Str("username", context.GetUsername()).Msg("TOTP verification successful")
|
||||
tlog.AuditLoginSuccess(c, context.GetUsername(), "totp")
|
||||
|
||||
controller.auth.RecordLoginAttempt(context.GetUsername(), true)
|
||||
|
||||
@@ -377,9 +355,6 @@ func (controller *UserController) totpHandler(c *gin.Context) {
|
||||
|
||||
http.SetCookie(c.Writer, cookie)
|
||||
|
||||
tlog.App.Info().Str("username", context.GetUsername()).Msg("TOTP verification successful")
|
||||
tlog.AuditLoginSuccess(c, context.GetUsername(), "totp")
|
||||
|
||||
c.JSON(200, gin.H{
|
||||
"status": 200,
|
||||
"message": "Login successful",
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
package controller_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"path"
|
||||
"strings"
|
||||
@@ -12,14 +10,14 @@ import (
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/pquerna/otp/totp"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tinyauthapp/tinyauth/internal/bootstrap"
|
||||
"github.com/tinyauthapp/tinyauth/internal/config"
|
||||
"github.com/tinyauthapp/tinyauth/internal/controller"
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestUserController(t *testing.T) {
|
||||
@@ -27,7 +25,7 @@ func TestUserController(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
authServiceCfg := service.AuthServiceConfig{
|
||||
LocalUsers: &[]model.LocalUser{
|
||||
Users: []config.User{
|
||||
{
|
||||
Username: "testuser",
|
||||
Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
|
||||
@@ -35,12 +33,12 @@ func TestUserController(t *testing.T) {
|
||||
{
|
||||
Username: "totpuser",
|
||||
Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
|
||||
TOTPSecret: "JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK",
|
||||
TotpSecret: "JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK",
|
||||
},
|
||||
{
|
||||
Username: "attruser",
|
||||
Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
|
||||
Attributes: model.UserAttributes{
|
||||
Attributes: config.UserAttributes{
|
||||
Name: "Alice Smith",
|
||||
Email: "alice@example.com",
|
||||
},
|
||||
@@ -48,8 +46,8 @@ func TestUserController(t *testing.T) {
|
||||
{
|
||||
Username: "attrtotpuser",
|
||||
Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
|
||||
TOTPSecret: "JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK",
|
||||
Attributes: model.UserAttributes{
|
||||
TotpSecret: "JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK",
|
||||
Attributes: config.UserAttributes{
|
||||
Name: "Bob Jones",
|
||||
Email: "bob@example.com",
|
||||
},
|
||||
@@ -63,63 +61,9 @@ func TestUserController(t *testing.T) {
|
||||
}
|
||||
|
||||
userControllerCfg := controller.UserControllerConfig{
|
||||
CookieDomain: "example.com",
|
||||
SessionCookieName: "tinyauth-session",
|
||||
CookieDomain: "example.com",
|
||||
}
|
||||
|
||||
totpCtx := func(c *gin.Context) {
|
||||
c.Set("context", &model.UserContext{
|
||||
Authenticated: false,
|
||||
Provider: model.ProviderLocal,
|
||||
Local: &model.LocalContext{
|
||||
BaseContext: model.BaseContext{
|
||||
Username: "totpuser",
|
||||
Name: "Totpuser",
|
||||
Email: "totpuser@example.com",
|
||||
},
|
||||
TOTPPending: true,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
totpAttrCtx := func(c *gin.Context) {
|
||||
c.Set("context", &model.UserContext{
|
||||
Authenticated: false,
|
||||
Provider: model.ProviderLocal,
|
||||
Local: &model.LocalContext{
|
||||
BaseContext: model.BaseContext{
|
||||
Username: "attrtotpuser",
|
||||
Name: "Bob Jones",
|
||||
Email: "bob@example.com",
|
||||
},
|
||||
TOTPPending: true,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
simpleCtx := func(c *gin.Context) {
|
||||
c.Set("context", &model.UserContext{
|
||||
Authenticated: true,
|
||||
Provider: model.ProviderLocal,
|
||||
Local: &model.LocalContext{
|
||||
BaseContext: model.BaseContext{
|
||||
Username: "testuser",
|
||||
Name: "Test User",
|
||||
Email: "testuser@example.com",
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
oauthBrokerCfgs := make(map[string]model.OAuthServiceConfig)
|
||||
|
||||
app := bootstrap.NewBootstrapApp(model.Config{})
|
||||
|
||||
db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db"))
|
||||
require.NoError(t, err)
|
||||
|
||||
queries := repository.New(db)
|
||||
|
||||
type testCase struct {
|
||||
description string
|
||||
middlewares []gin.HandlerFunc
|
||||
@@ -150,9 +94,7 @@ func TestUserController(t *testing.T) {
|
||||
assert.Equal(t, "tinyauth-session", cookie.Name)
|
||||
assert.True(t, cookie.HttpOnly)
|
||||
assert.Equal(t, "example.com", cookie.Domain)
|
||||
// 3 seconds should be more than enough for even slow test environments
|
||||
assert.GreaterOrEqual(t, cookie.MaxAge, 7)
|
||||
assert.LessOrEqual(t, cookie.MaxAge, 10)
|
||||
assert.Equal(t, 10, cookie.MaxAge)
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -241,15 +183,12 @@ func TestUserController(t *testing.T) {
|
||||
assert.Equal(t, "tinyauth-session", cookie.Name)
|
||||
assert.True(t, cookie.HttpOnly)
|
||||
assert.Equal(t, "example.com", cookie.Domain)
|
||||
assert.GreaterOrEqual(t, cookie.MaxAge, 3597)
|
||||
assert.LessOrEqual(t, cookie.MaxAge, 3600)
|
||||
assert.Equal(t, 3600, cookie.MaxAge) // 1 hour, default for totp pending sessions
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Should be able to logout",
|
||||
middlewares: []gin.HandlerFunc{
|
||||
simpleCtx,
|
||||
},
|
||||
middlewares: []gin.HandlerFunc{},
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||
// First login to get a session cookie
|
||||
loginReq := controller.LoginRequest{
|
||||
@@ -265,10 +204,9 @@ func TestUserController(t *testing.T) {
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, 200, recorder.Code)
|
||||
cookies := recorder.Result().Cookies()
|
||||
assert.Len(t, cookies, 1)
|
||||
assert.Len(t, recorder.Result().Cookies(), 1)
|
||||
|
||||
cookie := cookies[0]
|
||||
cookie := recorder.Result().Cookies()[0]
|
||||
assert.Equal(t, "tinyauth-session", cookie.Name)
|
||||
|
||||
// Now logout using the session cookie
|
||||
@@ -279,33 +217,18 @@ func TestUserController(t *testing.T) {
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, 200, recorder.Code)
|
||||
cookies = recorder.Result().Cookies()
|
||||
assert.Len(t, cookies, 1)
|
||||
assert.Len(t, recorder.Result().Cookies(), 1)
|
||||
|
||||
cookie = cookies[0]
|
||||
assert.Equal(t, "tinyauth-session", cookie.Name)
|
||||
assert.Equal(t, "", cookie.Value)
|
||||
assert.Equal(t, -1, cookie.MaxAge) // MaxAge -1 means delete cookie
|
||||
logoutCookie := recorder.Result().Cookies()[0]
|
||||
assert.Equal(t, "tinyauth-session", logoutCookie.Name)
|
||||
assert.Equal(t, "", logoutCookie.Value)
|
||||
assert.Equal(t, -1, logoutCookie.MaxAge) // MaxAge -1 means delete cookie
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Should be able to login with totp",
|
||||
middlewares: []gin.HandlerFunc{
|
||||
totpCtx,
|
||||
},
|
||||
middlewares: []gin.HandlerFunc{},
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||
_, err := queries.CreateSession(context.TODO(), repository.CreateSessionParams{
|
||||
UUID: "test-totp-login-uuid",
|
||||
Username: "test",
|
||||
Email: "test@example.com",
|
||||
Name: "Test",
|
||||
Provider: "local",
|
||||
TotpPending: true,
|
||||
Expiry: time.Now().Add(1 * time.Hour).Unix(),
|
||||
CreatedAt: time.Now().Unix(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
code, err := totp.GenerateCode("JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", time.Now())
|
||||
assert.NoError(t, err)
|
||||
|
||||
@@ -319,13 +242,7 @@ func TestUserController(t *testing.T) {
|
||||
recorder = httptest.NewRecorder()
|
||||
req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(totpReqBody)))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.AddCookie(&http.Cookie{
|
||||
Name: "tinyauth-session",
|
||||
Value: "test-totp-login-uuid",
|
||||
HttpOnly: true,
|
||||
MaxAge: 3600,
|
||||
Expires: time.Now().Add(1 * time.Hour),
|
||||
})
|
||||
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, 200, recorder.Code)
|
||||
@@ -336,15 +253,12 @@ func TestUserController(t *testing.T) {
|
||||
assert.Equal(t, "tinyauth-session", totpCookie.Name)
|
||||
assert.True(t, totpCookie.HttpOnly)
|
||||
assert.Equal(t, "example.com", totpCookie.Domain)
|
||||
assert.GreaterOrEqual(t, totpCookie.MaxAge, 7)
|
||||
assert.LessOrEqual(t, totpCookie.MaxAge, 10)
|
||||
assert.Equal(t, 10, totpCookie.MaxAge) // should use the regular session expiry time
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Totp should rate limit on multiple invalid attempts",
|
||||
middlewares: []gin.HandlerFunc{
|
||||
totpCtx,
|
||||
},
|
||||
middlewares: []gin.HandlerFunc{},
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||
for range 3 {
|
||||
totpReq := controller.TotpRequest{
|
||||
@@ -414,22 +328,8 @@ func TestUserController(t *testing.T) {
|
||||
},
|
||||
{
|
||||
description: "TOTP completion uses name and email from user attributes",
|
||||
middlewares: []gin.HandlerFunc{
|
||||
totpAttrCtx,
|
||||
},
|
||||
middlewares: []gin.HandlerFunc{},
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||
_, err := queries.CreateSession(context.TODO(), repository.CreateSessionParams{
|
||||
UUID: "test-totp-login-attributes-uuid",
|
||||
Username: "test",
|
||||
Email: "test@example.com",
|
||||
Name: "Test",
|
||||
Provider: "local",
|
||||
TotpPending: true,
|
||||
Expiry: time.Now().Add(1 * time.Hour).Unix(),
|
||||
CreatedAt: time.Now().Unix(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
code, err := totp.GenerateCode("JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", time.Now())
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -439,13 +339,6 @@ func TestUserController(t *testing.T) {
|
||||
|
||||
req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(body)))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.AddCookie(&http.Cookie{
|
||||
Name: "tinyauth-session",
|
||||
Value: "test-totp-login-attributes-uuid",
|
||||
HttpOnly: true,
|
||||
MaxAge: 3600,
|
||||
Expires: time.Now().Add(1 * time.Hour),
|
||||
})
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
require.Equal(t, 200, recorder.Code)
|
||||
@@ -456,6 +349,15 @@ func TestUserController(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
oauthBrokerCfgs := make(map[string]config.OAuthServiceConfig)
|
||||
|
||||
app := bootstrap.NewBootstrapApp(config.Config{})
|
||||
|
||||
db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db"))
|
||||
require.NoError(t, err)
|
||||
|
||||
queries := repository.New(db)
|
||||
|
||||
docker := service.NewDockerService()
|
||||
err = docker.Init()
|
||||
require.NoError(t, err)
|
||||
@@ -477,6 +379,33 @@ func TestUserController(t *testing.T) {
|
||||
authService.ClearRateLimitsTestingOnly()
|
||||
}
|
||||
|
||||
setTotpMiddlewareOverrides := map[string]config.UserContext{
|
||||
"Should be able to login with totp": {
|
||||
Username: "totpuser",
|
||||
Name: "Totpuser",
|
||||
Email: "totpuser@example.com",
|
||||
Provider: "local",
|
||||
TotpPending: true,
|
||||
TotpEnabled: true,
|
||||
},
|
||||
"Totp should rate limit on multiple invalid attempts": {
|
||||
Username: "totpuser",
|
||||
Name: "Totpuser",
|
||||
Email: "totpuser@example.com",
|
||||
Provider: "local",
|
||||
TotpPending: true,
|
||||
TotpEnabled: true,
|
||||
},
|
||||
"TOTP completion uses name and email from user attributes": {
|
||||
Username: "attrtotpuser",
|
||||
Name: "Bob Jones",
|
||||
Email: "bob@example.com",
|
||||
Provider: "local",
|
||||
TotpPending: true,
|
||||
TotpEnabled: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
beforeEach()
|
||||
t.Run(test.description, func(t *testing.T) {
|
||||
@@ -486,6 +415,15 @@ func TestUserController(t *testing.T) {
|
||||
router.Use(middleware)
|
||||
}
|
||||
|
||||
// Gin is stupid and doesn't allow setting a middleware after the groups
|
||||
// so we need to do some stupid overrides here
|
||||
if ctx, ok := setTotpMiddlewareOverrides[test.description]; ok {
|
||||
ctx := ctx
|
||||
router.Use(func(c *gin.Context) {
|
||||
c.Set("context", &ctx)
|
||||
})
|
||||
}
|
||||
|
||||
group := router.Group("/api")
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
|
||||
@@ -8,14 +8,14 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tinyauthapp/tinyauth/internal/bootstrap"
|
||||
"github.com/tinyauthapp/tinyauth/internal/config"
|
||||
"github.com/tinyauthapp/tinyauth/internal/controller"
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestWellKnownController(t *testing.T) {
|
||||
@@ -23,7 +23,7 @@ func TestWellKnownController(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
oidcServiceCfg := service.OIDCServiceConfig{
|
||||
Clients: map[string]model.OIDCClientConfig{
|
||||
Clients: map[string]config.OIDCClientConfig{
|
||||
"test": {
|
||||
ClientID: "some-client-id",
|
||||
ClientSecret: "some-client-secret",
|
||||
@@ -101,7 +101,7 @@ func TestWellKnownController(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
app := bootstrap.NewBootstrapApp(model.Config{})
|
||||
app := bootstrap.NewBootstrapApp(config.Config{})
|
||||
|
||||
db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db"))
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -70,24 +70,26 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
|
||||
if err == nil {
|
||||
userContext, cookie, err := m.cookieAuth(c.Request.Context(), uuid)
|
||||
|
||||
if err == nil {
|
||||
if cookie != nil {
|
||||
http.SetCookie(c.Writer, cookie)
|
||||
}
|
||||
|
||||
tlog.App.Trace().Msgf("Authenticated user from session cookie: %s", userContext.GetUsername())
|
||||
c.Set("context", userContext)
|
||||
if err != nil {
|
||||
tlog.App.Error().Msgf("Error authenticating session cookie: %v", err)
|
||||
c.Next()
|
||||
return
|
||||
} else {
|
||||
tlog.App.Error().Msgf("Error authenticating session cookie: %v", err)
|
||||
}
|
||||
|
||||
if cookie != nil {
|
||||
http.SetCookie(c.Writer, cookie)
|
||||
}
|
||||
|
||||
tlog.App.Trace().Msgf("Authenticated user from session cookie: %s", userContext.GetUsername())
|
||||
c.Set("context", userContext)
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
username, password, ok := c.Request.BasicAuth()
|
||||
basic, err := m.auth.GetBasicAuth(c.Request)
|
||||
|
||||
if ok {
|
||||
userContext, headers, err := m.basicAuth(username, password)
|
||||
if err == nil {
|
||||
userContext, headers, err := m.basicAuth(c.Request.Context(), basic)
|
||||
|
||||
if err != nil {
|
||||
tlog.App.Error().Msgf("Error authenticating basic auth: %v", err)
|
||||
@@ -123,6 +125,7 @@ func (m *ContextMiddleware) cookieAuth(ctx context.Context, uuid string) (*model
|
||||
|
||||
if userContext.Provider == model.ProviderLocal &&
|
||||
userContext.Local.TOTPPending {
|
||||
userContext.Local.TOTPEnabled = true
|
||||
return userContext, nil, nil
|
||||
}
|
||||
|
||||
@@ -185,39 +188,39 @@ func (m *ContextMiddleware) cookieAuth(ctx context.Context, uuid string) (*model
|
||||
return userContext, cookie, nil
|
||||
}
|
||||
|
||||
func (m *ContextMiddleware) basicAuth(username string, password string) (*model.UserContext, map[string]string, error) {
|
||||
func (m *ContextMiddleware) basicAuth(ctx context.Context, basic *model.LocalUser) (*model.UserContext, map[string]string, error) {
|
||||
headers := make(map[string]string)
|
||||
userContext := new(model.UserContext)
|
||||
locked, remaining := m.auth.IsAccountLocked(username)
|
||||
locked, remaining := m.auth.IsAccountLocked(basic.Username)
|
||||
|
||||
if locked {
|
||||
tlog.App.Debug().Msgf("Account for user %s is locked for %d seconds, denying auth", username, remaining)
|
||||
tlog.App.Debug().Msgf("Account for user %s is locked for %d seconds, denying auth", basic.Username, remaining)
|
||||
headers["x-tinyauth-lock-locked"] = "true"
|
||||
headers["x-tinyauth-lock-reset"] = time.Now().Add(time.Duration(remaining) * time.Second).Format(time.RFC3339)
|
||||
return nil, headers, nil
|
||||
}
|
||||
|
||||
search, err := m.auth.SearchUser(username)
|
||||
search, err := m.auth.SearchUser(basic.Username)
|
||||
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("error searching for user: %w", err)
|
||||
}
|
||||
|
||||
err = m.auth.CheckUserPassword(*search, password)
|
||||
err = m.auth.CheckUserPassword(*search, basic.Password)
|
||||
|
||||
if err != nil {
|
||||
m.auth.RecordLoginAttempt(username, false)
|
||||
m.auth.RecordLoginAttempt(basic.Username, false)
|
||||
return nil, nil, fmt.Errorf("invalid password for basic auth user: %w", err)
|
||||
}
|
||||
|
||||
m.auth.RecordLoginAttempt(username, true)
|
||||
m.auth.RecordLoginAttempt(basic.Username, true)
|
||||
|
||||
switch search.Type {
|
||||
case model.UserLocal:
|
||||
user := m.auth.GetLocalUser(username)
|
||||
user := m.auth.GetLocalUser(basic.Username)
|
||||
|
||||
if user.TOTPSecret != "" {
|
||||
return nil, nil, fmt.Errorf("user with totp not allowed to login via basic auth: %s", username)
|
||||
return nil, nil, fmt.Errorf("user with totp not allowed to login via basic auth: %s", basic.Username)
|
||||
}
|
||||
|
||||
userContext.Local = &model.LocalContext{
|
||||
@@ -230,7 +233,7 @@ func (m *ContextMiddleware) basicAuth(username string, password string) (*model.
|
||||
}
|
||||
userContext.Provider = model.ProviderLocal
|
||||
case model.UserLDAP:
|
||||
user, err := m.auth.GetLDAPUser(username)
|
||||
user, err := m.auth.GetLDAPUser(basic.Username)
|
||||
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("error retrieving ldap user details: %w", err)
|
||||
@@ -238,9 +241,9 @@ func (m *ContextMiddleware) basicAuth(username string, password string) (*model.
|
||||
|
||||
userContext.LDAP = &model.LDAPContext{
|
||||
BaseContext: model.BaseContext{
|
||||
Username: username,
|
||||
Name: utils.Capitalize(username),
|
||||
Email: utils.CompileUserEmail(username, m.config.CookieDomain),
|
||||
Username: basic.Username,
|
||||
Name: utils.Capitalize(basic.Username),
|
||||
Email: utils.CompileUserEmail(basic.Username, m.config.CookieDomain),
|
||||
},
|
||||
Groups: user.Groups,
|
||||
}
|
||||
|
||||
@@ -1,328 +0,0 @@
|
||||
package middleware_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"path"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tinyauthapp/tinyauth/internal/bootstrap"
|
||||
"github.com/tinyauthapp/tinyauth/internal/middleware"
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||
)
|
||||
|
||||
func TestContextMiddleware(t *testing.T) {
|
||||
tlog.NewTestLogger().Init()
|
||||
tempDir := t.TempDir()
|
||||
|
||||
authServiceCfg := service.AuthServiceConfig{
|
||||
LocalUsers: &[]model.LocalUser{
|
||||
{
|
||||
Username: "testuser",
|
||||
Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
|
||||
},
|
||||
{
|
||||
Username: "totpuser",
|
||||
Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
|
||||
TOTPSecret: "JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK",
|
||||
},
|
||||
},
|
||||
SessionExpiry: 10, // 10 seconds, useful for testing
|
||||
CookieDomain: "example.com",
|
||||
LoginTimeout: 10, // 10 seconds, useful for testing
|
||||
LoginMaxRetries: 3,
|
||||
SessionCookieName: "tinyauth-session",
|
||||
}
|
||||
|
||||
middlewareCfg := middleware.ContextMiddlewareConfig{
|
||||
CookieDomain: "example.com",
|
||||
SessionCookieName: "tinyauth-session",
|
||||
}
|
||||
|
||||
basicAuthHeader := func(username, password string) string {
|
||||
return "Basic " + base64.StdEncoding.EncodeToString([]byte(username+":"+password))
|
||||
}
|
||||
|
||||
seedSession := func(t *testing.T, queries *repository.Queries, params repository.CreateSessionParams) {
|
||||
t.Helper()
|
||||
_, err := queries.CreateSession(context.Background(), params)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
type runArgs struct {
|
||||
do func(req *http.Request) (*model.UserContext, *httptest.ResponseRecorder)
|
||||
queries *repository.Queries
|
||||
}
|
||||
|
||||
type testCase struct {
|
||||
description string
|
||||
run func(t *testing.T, args runArgs)
|
||||
}
|
||||
|
||||
tests := []testCase{
|
||||
{
|
||||
description: "Skip path bypasses auth processing",
|
||||
run: func(t *testing.T, args runArgs) {
|
||||
req := httptest.NewRequest("GET", "/api/healthz", nil)
|
||||
req.Header.Set("Authorization", basicAuthHeader("testuser", "password"))
|
||||
userCtx, _ := args.do(req)
|
||||
|
||||
assert.Nil(t, userCtx)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "No credentials yields no context",
|
||||
run: func(t *testing.T, args runArgs) {
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
userCtx, _ := args.do(req)
|
||||
|
||||
assert.Nil(t, userCtx)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Valid session cookie sets authenticated local context",
|
||||
run: func(t *testing.T, args runArgs) {
|
||||
uuid := "session-valid-local"
|
||||
seedSession(t, args.queries, repository.CreateSessionParams{
|
||||
UUID: uuid,
|
||||
Username: "testuser",
|
||||
Provider: "local",
|
||||
Expiry: time.Now().Add(10 * time.Second).Unix(),
|
||||
CreatedAt: time.Now().Unix(),
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
req.AddCookie(&http.Cookie{Name: "tinyauth-session", Value: uuid})
|
||||
userCtx, _ := args.do(req)
|
||||
|
||||
require.NotNil(t, userCtx)
|
||||
assert.Equal(t, model.ProviderLocal, userCtx.Provider)
|
||||
assert.Equal(t, "testuser", userCtx.GetUsername())
|
||||
assert.True(t, userCtx.Authenticated)
|
||||
require.NotNil(t, userCtx.Local)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Session cookie with totp pending sets unauthenticated context with totp enabled",
|
||||
run: func(t *testing.T, args runArgs) {
|
||||
uuid := "session-totp-pending"
|
||||
seedSession(t, args.queries, repository.CreateSessionParams{
|
||||
UUID: uuid,
|
||||
Username: "totpuser",
|
||||
Provider: "local",
|
||||
TotpPending: true,
|
||||
Expiry: time.Now().Add(60 * time.Second).Unix(),
|
||||
CreatedAt: time.Now().Unix(),
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
req.AddCookie(&http.Cookie{Name: "tinyauth-session", Value: uuid})
|
||||
userCtx, _ := args.do(req)
|
||||
|
||||
require.NotNil(t, userCtx)
|
||||
assert.Equal(t, "totpuser", userCtx.GetUsername())
|
||||
assert.False(t, userCtx.Authenticated)
|
||||
require.NotNil(t, userCtx.Local)
|
||||
assert.True(t, userCtx.Local.TOTPPending)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Unknown session cookie yields no context",
|
||||
run: func(t *testing.T, args runArgs) {
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
req.AddCookie(&http.Cookie{Name: "tinyauth-session", Value: "does-not-exist"})
|
||||
userCtx, _ := args.do(req)
|
||||
|
||||
assert.Nil(t, userCtx)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Session for missing local user yields no context",
|
||||
run: func(t *testing.T, args runArgs) {
|
||||
uuid := "session-deleted-user"
|
||||
seedSession(t, args.queries, repository.CreateSessionParams{
|
||||
UUID: uuid,
|
||||
Username: "ghostuser",
|
||||
Provider: "local",
|
||||
Expiry: time.Now().Add(10 * time.Second).Unix(),
|
||||
CreatedAt: time.Now().Unix(),
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
req.AddCookie(&http.Cookie{Name: "tinyauth-session", Value: uuid})
|
||||
userCtx, _ := args.do(req)
|
||||
|
||||
assert.Nil(t, userCtx)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Expired session cookie yields no context",
|
||||
run: func(t *testing.T, args runArgs) {
|
||||
uuid := "session-expired"
|
||||
seedSession(t, args.queries, repository.CreateSessionParams{
|
||||
UUID: uuid,
|
||||
Username: "testuser",
|
||||
Provider: "local",
|
||||
Expiry: time.Now().Add(-1 * time.Second).Unix(),
|
||||
CreatedAt: time.Now().Add(-10 * time.Second).Unix(),
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
req.AddCookie(&http.Cookie{Name: "tinyauth-session", Value: uuid})
|
||||
userCtx, _ := args.do(req)
|
||||
|
||||
assert.Nil(t, userCtx)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Valid basic auth sets authenticated local context",
|
||||
run: func(t *testing.T, args runArgs) {
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
req.Header.Set("Authorization", basicAuthHeader("testuser", "password"))
|
||||
userCtx, _ := args.do(req)
|
||||
|
||||
require.NotNil(t, userCtx)
|
||||
assert.Equal(t, model.ProviderLocal, userCtx.Provider)
|
||||
assert.Equal(t, "testuser", userCtx.GetUsername())
|
||||
assert.True(t, userCtx.Authenticated)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Invalid basic auth password yields no context",
|
||||
run: func(t *testing.T, args runArgs) {
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
req.Header.Set("Authorization", basicAuthHeader("testuser", "wrongpassword"))
|
||||
userCtx, _ := args.do(req)
|
||||
|
||||
assert.Nil(t, userCtx)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Basic auth is rejected for users with totp",
|
||||
run: func(t *testing.T, args runArgs) {
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
req.Header.Set("Authorization", basicAuthHeader("totpuser", "password"))
|
||||
userCtx, _ := args.do(req)
|
||||
|
||||
assert.Nil(t, userCtx)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Locked account on basic auth sets lock headers",
|
||||
run: func(t *testing.T, args runArgs) {
|
||||
for range 3 {
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
req.Header.Set("Authorization", basicAuthHeader("testuser", "wrongpassword"))
|
||||
args.do(req)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
req.Header.Set("Authorization", basicAuthHeader("testuser", "password"))
|
||||
userCtx, recorder := args.do(req)
|
||||
|
||||
assert.Nil(t, userCtx)
|
||||
assert.Equal(t, "true", recorder.Header().Get("x-tinyauth-lock-locked"))
|
||||
assert.NotEmpty(t, recorder.Header().Get("x-tinyauth-lock-reset"))
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Cookie auth takes precedence over basic auth",
|
||||
run: func(t *testing.T, args runArgs) {
|
||||
uuid := "session-precedence"
|
||||
seedSession(t, args.queries, repository.CreateSessionParams{
|
||||
UUID: uuid,
|
||||
Username: "testuser",
|
||||
Provider: "local",
|
||||
Expiry: time.Now().Add(10 * time.Second).Unix(),
|
||||
CreatedAt: time.Now().Unix(),
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
req.AddCookie(&http.Cookie{Name: "tinyauth-session", Value: uuid})
|
||||
req.Header.Set("Authorization", basicAuthHeader("totpuser", "password"))
|
||||
userCtx, _ := args.do(req)
|
||||
|
||||
require.NotNil(t, userCtx)
|
||||
assert.Equal(t, "testuser", userCtx.GetUsername())
|
||||
assert.True(t, userCtx.Authenticated)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Ensure fallback to basic auth when cookie is missing",
|
||||
run: func(t *testing.T, args runArgs) {
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
req.Header.Set("Authorization", basicAuthHeader("testuser", "password"))
|
||||
userCtx, _ := args.do(req)
|
||||
|
||||
require.NotNil(t, userCtx)
|
||||
assert.Equal(t, "testuser", userCtx.GetUsername())
|
||||
assert.True(t, userCtx.Authenticated)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
oauthBrokerCfgs := make(map[string]model.OAuthServiceConfig)
|
||||
|
||||
app := bootstrap.NewBootstrapApp(model.Config{})
|
||||
|
||||
db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db"))
|
||||
require.NoError(t, err)
|
||||
|
||||
queries := repository.New(db)
|
||||
|
||||
ldap := service.NewLdapService(service.LdapServiceConfig{})
|
||||
err = ldap.Init()
|
||||
require.NoError(t, err)
|
||||
|
||||
broker := service.NewOAuthBrokerService(oauthBrokerCfgs)
|
||||
err = broker.Init()
|
||||
require.NoError(t, err)
|
||||
|
||||
authService := service.NewAuthService(authServiceCfg, ldap, queries, broker)
|
||||
err = authService.Init()
|
||||
require.NoError(t, err)
|
||||
|
||||
contextMiddleware := middleware.NewContextMiddleware(middlewareCfg, authService, broker)
|
||||
err = contextMiddleware.Init()
|
||||
require.NoError(t, err)
|
||||
|
||||
for _, test := range tests {
|
||||
authService.ClearRateLimitsTestingOnly()
|
||||
t.Run(test.description, func(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
do := func(req *http.Request) (*model.UserContext, *httptest.ResponseRecorder) {
|
||||
var captured *model.UserContext
|
||||
router := gin.New()
|
||||
router.Use(contextMiddleware.Middleware())
|
||||
handler := func(c *gin.Context) {
|
||||
if val, exists := c.Get("context"); exists {
|
||||
captured, _ = val.(*model.UserContext)
|
||||
}
|
||||
}
|
||||
router.GET("/api/test", handler)
|
||||
router.GET("/api/healthz", handler)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
router.ServeHTTP(recorder, req)
|
||||
return captured, recorder
|
||||
}
|
||||
|
||||
test.run(t, runArgs{do: do, queries: queries})
|
||||
})
|
||||
}
|
||||
|
||||
t.Cleanup(func() {
|
||||
err = db.Close()
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
@@ -18,7 +18,6 @@ func NewDefaultConfiguration() *Config {
|
||||
Address: "0.0.0.0",
|
||||
},
|
||||
Auth: AuthConfig{
|
||||
SubdomainsEnabled: true,
|
||||
SessionExpiry: 86400, // 1 day
|
||||
SessionMaxLifetime: 0, // disabled
|
||||
LoginTimeout: 300, // 5 minutes
|
||||
@@ -103,7 +102,6 @@ type ServerConfig struct {
|
||||
type AuthConfig struct {
|
||||
IP IPConfig `description:"IP whitelisting config options." yaml:"ip"`
|
||||
Users []string `description:"Comma-separated list of users (username:hashed_password)." yaml:"users"`
|
||||
SubdomainsEnabled bool `description:"Enable subdomains support." yaml:"subdomainsEnabled"`
|
||||
UserAttributes map[string]UserAttributes `description:"Map of per-user OIDC attributes (username -> attributes)." yaml:"userAttributes"`
|
||||
UsersFile string `description:"Path to the users file." yaml:"usersFile"`
|
||||
SecureCookie bool `description:"Enable secure cookies." yaml:"secureCookie"`
|
||||
|
||||
+15
-59
@@ -34,6 +34,7 @@ type BaseContext struct {
|
||||
type LocalContext struct {
|
||||
BaseContext
|
||||
TOTPPending bool
|
||||
TOTPEnabled bool
|
||||
Attributes UserAttributes
|
||||
}
|
||||
|
||||
@@ -55,19 +56,19 @@ func (c *UserContext) IsAuthenticated() bool {
|
||||
}
|
||||
|
||||
func (c *UserContext) IsLocal() bool {
|
||||
return c.Provider == ProviderLocal && c.Local != nil
|
||||
return c.Provider == ProviderLocal
|
||||
}
|
||||
|
||||
func (c *UserContext) IsOAuth() bool {
|
||||
return c.Provider == ProviderOAuth && c.OAuth != nil
|
||||
return c.Provider == ProviderOAuth
|
||||
}
|
||||
|
||||
func (c *UserContext) IsLDAP() bool {
|
||||
return c.Provider == ProviderLDAP && c.LDAP != nil
|
||||
return c.Provider == ProviderLDAP
|
||||
}
|
||||
|
||||
func (c *UserContext) IsBasicAuth() bool {
|
||||
return c.Provider == ProviderBasicAuth && c.Local != nil
|
||||
return c.Provider == ProviderBasicAuth
|
||||
}
|
||||
|
||||
func (c *UserContext) NewFromGin(ginctx *gin.Context) (*UserContext, error) {
|
||||
@@ -79,24 +80,16 @@ func (c *UserContext) NewFromGin(ginctx *gin.Context) (*UserContext, error) {
|
||||
|
||||
userContext, ok := userContextValue.(*UserContext)
|
||||
|
||||
if !ok || userContext == nil {
|
||||
if !ok {
|
||||
return nil, errors.New("invalid user context type")
|
||||
}
|
||||
|
||||
if userContext.LDAP == nil && userContext.Local == nil && userContext.OAuth == nil {
|
||||
return nil, errors.New("incomplete user context")
|
||||
}
|
||||
|
||||
*c = *userContext
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// Compatability layer until we get an excuse to drop in database migrations
|
||||
func (c *UserContext) NewFromSession(session *repository.Session) (*UserContext, error) {
|
||||
*c = UserContext{
|
||||
Authenticated: !session.TotpPending,
|
||||
}
|
||||
|
||||
switch session.Provider {
|
||||
case "local":
|
||||
c.Provider = ProviderLocal
|
||||
@@ -126,42 +119,29 @@ func (c *UserContext) NewFromSession(session *repository.Session) (*UserContext,
|
||||
Name: session.Name,
|
||||
Email: session.Email,
|
||||
},
|
||||
Groups: func() []string {
|
||||
if session.OAuthGroups == "" {
|
||||
return nil
|
||||
}
|
||||
return strings.Split(session.OAuthGroups, ",")
|
||||
}(),
|
||||
Groups: strings.Split(session.OAuthGroups, ","),
|
||||
Sub: session.OAuthSub,
|
||||
DisplayName: session.OAuthName,
|
||||
ID: session.Provider,
|
||||
}
|
||||
}
|
||||
|
||||
if !session.TotpPending {
|
||||
c.Authenticated = true
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (c *UserContext) GetUsername() string {
|
||||
switch c.Provider {
|
||||
case ProviderLocal:
|
||||
if c.Local == nil {
|
||||
return ""
|
||||
}
|
||||
return c.Local.Username
|
||||
case ProviderLDAP:
|
||||
if c.LDAP == nil {
|
||||
return ""
|
||||
}
|
||||
return c.LDAP.Username
|
||||
case ProviderBasicAuth:
|
||||
if c.Local == nil {
|
||||
return ""
|
||||
}
|
||||
return c.Local.Username
|
||||
case ProviderOAuth:
|
||||
if c.OAuth == nil {
|
||||
return ""
|
||||
}
|
||||
return c.OAuth.Username
|
||||
default:
|
||||
return ""
|
||||
@@ -171,24 +151,12 @@ func (c *UserContext) GetUsername() string {
|
||||
func (c *UserContext) GetEmail() string {
|
||||
switch c.Provider {
|
||||
case ProviderLocal:
|
||||
if c.Local == nil {
|
||||
return ""
|
||||
}
|
||||
return c.Local.Email
|
||||
case ProviderLDAP:
|
||||
if c.LDAP == nil {
|
||||
return ""
|
||||
}
|
||||
return c.LDAP.Email
|
||||
case ProviderBasicAuth:
|
||||
if c.Local == nil {
|
||||
return ""
|
||||
}
|
||||
return c.Local.Email
|
||||
case ProviderOAuth:
|
||||
if c.OAuth == nil {
|
||||
return ""
|
||||
}
|
||||
return c.OAuth.Email
|
||||
default:
|
||||
return ""
|
||||
@@ -198,52 +166,40 @@ func (c *UserContext) GetEmail() string {
|
||||
func (c *UserContext) GetName() string {
|
||||
switch c.Provider {
|
||||
case ProviderLocal:
|
||||
if c.Local == nil {
|
||||
return ""
|
||||
}
|
||||
return c.Local.Name
|
||||
case ProviderLDAP:
|
||||
if c.LDAP == nil {
|
||||
return ""
|
||||
}
|
||||
return c.LDAP.Name
|
||||
case ProviderBasicAuth:
|
||||
if c.Local == nil {
|
||||
return ""
|
||||
}
|
||||
return c.Local.Name
|
||||
case ProviderOAuth:
|
||||
if c.OAuth == nil {
|
||||
return ""
|
||||
}
|
||||
return c.OAuth.Name
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func (c *UserContext) GetProviderID() string {
|
||||
func (c *UserContext) ProviderName() string {
|
||||
switch c.Provider {
|
||||
case ProviderBasicAuth, ProviderLocal:
|
||||
return "local"
|
||||
case ProviderLDAP:
|
||||
return "ldap"
|
||||
case ProviderOAuth:
|
||||
return c.OAuth.ID
|
||||
return c.OAuth.DisplayName // compatability
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
func (c *UserContext) TOTPPending() bool {
|
||||
if c.Provider == ProviderLocal && c.Local != nil {
|
||||
if c.Provider == ProviderLocal {
|
||||
return c.Local.TOTPPending
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *UserContext) OAuthName() string {
|
||||
if c.Provider == ProviderOAuth && c.OAuth != nil {
|
||||
if c.Provider == ProviderOAuth {
|
||||
return c.OAuth.DisplayName
|
||||
}
|
||||
return ""
|
||||
|
||||
@@ -1,276 +0,0 @@
|
||||
package model_test
|
||||
|
||||
import (
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||
)
|
||||
|
||||
func TestContext(t *testing.T) {
|
||||
newGinCtx := func(value any, set bool) *gin.Context {
|
||||
c, _ := gin.CreateTestContext(httptest.NewRecorder())
|
||||
if set {
|
||||
c.Set("context", value)
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
description string
|
||||
context *model.UserContext
|
||||
run func(*testing.T, *model.UserContext) any
|
||||
expected any
|
||||
}{
|
||||
{
|
||||
description: "IsAuthenticated reflects Authenticated field",
|
||||
context: &model.UserContext{Authenticated: true},
|
||||
run: func(t *testing.T, c *model.UserContext) any { return c.IsAuthenticated() },
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
description: "IsLocal returns true for ProviderLocal",
|
||||
context: &model.UserContext{Provider: model.ProviderLocal, Local: &model.LocalContext{}},
|
||||
run: func(t *testing.T, c *model.UserContext) any { return c.IsLocal() },
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
description: "IsOAuth returns true for ProviderOAuth",
|
||||
context: &model.UserContext{Provider: model.ProviderOAuth, OAuth: &model.OAuthContext{}},
|
||||
run: func(t *testing.T, c *model.UserContext) any { return c.IsOAuth() },
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
description: "IsLDAP returns true for ProviderLDAP",
|
||||
context: &model.UserContext{Provider: model.ProviderLDAP, LDAP: &model.LDAPContext{}},
|
||||
run: func(t *testing.T, c *model.UserContext) any { return c.IsLDAP() },
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
description: "IsBasicAuth returns true for ProviderBasicAuth",
|
||||
context: &model.UserContext{Provider: model.ProviderBasicAuth, Local: &model.LocalContext{}},
|
||||
run: func(t *testing.T, c *model.UserContext) any { return c.IsBasicAuth() },
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
description: "NewFromSession local session is authenticated and ProviderLocal",
|
||||
context: &model.UserContext{},
|
||||
run: func(t *testing.T, c *model.UserContext) any {
|
||||
got, err := c.NewFromSession(&repository.Session{
|
||||
Username: "alice", Email: "alice@example.com", Name: "Alice",
|
||||
Provider: "local",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
return [2]any{got.Provider, got.Authenticated}
|
||||
},
|
||||
expected: [2]any{model.ProviderLocal, true},
|
||||
},
|
||||
{
|
||||
description: "NewFromSession local session with TotpPending is not authenticated",
|
||||
context: &model.UserContext{},
|
||||
run: func(t *testing.T, c *model.UserContext) any {
|
||||
got, err := c.NewFromSession(&repository.Session{
|
||||
Username: "bob", Provider: "local", TotpPending: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
return got.Authenticated
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
description: "NewFromSession ldap session is ProviderLDAP",
|
||||
context: &model.UserContext{},
|
||||
run: func(t *testing.T, c *model.UserContext) any {
|
||||
got, err := c.NewFromSession(&repository.Session{
|
||||
Username: "carol", Provider: "ldap",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
return got.Provider
|
||||
},
|
||||
expected: model.ProviderLDAP,
|
||||
},
|
||||
{
|
||||
description: "NewFromSession unknown provider defaults to OAuth and populates oauth fields",
|
||||
context: &model.UserContext{},
|
||||
run: func(t *testing.T, c *model.UserContext) any {
|
||||
got, err := c.NewFromSession(&repository.Session{
|
||||
Username: "dave", Provider: "github",
|
||||
OAuthGroups: "devs,admins", OAuthSub: "sub-123", OAuthName: "GitHub",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
return [5]any{got.Provider, got.OAuth.ID, got.OAuth.Sub, got.OAuth.DisplayName, got.OAuth.Groups}
|
||||
},
|
||||
expected: [5]any{model.ProviderOAuth, "github", "sub-123", "GitHub", []string{"devs", "admins"}},
|
||||
},
|
||||
{
|
||||
description: "Local getters return BaseContext fields",
|
||||
context: &model.UserContext{
|
||||
Provider: model.ProviderLocal,
|
||||
Local: &model.LocalContext{BaseContext: model.BaseContext{Username: "alice", Email: "alice@example.com", Name: "Alice"}},
|
||||
},
|
||||
run: func(t *testing.T, c *model.UserContext) any {
|
||||
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
|
||||
},
|
||||
expected: [3]string{"alice", "alice@example.com", "Alice"},
|
||||
},
|
||||
{
|
||||
description: "BasicAuth getters fall back to local fields",
|
||||
context: &model.UserContext{
|
||||
Provider: model.ProviderBasicAuth,
|
||||
Local: &model.LocalContext{BaseContext: model.BaseContext{Username: "bob", Email: "bob@example.com", Name: "Bob"}},
|
||||
},
|
||||
run: func(t *testing.T, c *model.UserContext) any {
|
||||
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
|
||||
},
|
||||
expected: [3]string{"bob", "bob@example.com", "Bob"},
|
||||
},
|
||||
{
|
||||
description: "LDAP getters return LDAP fields",
|
||||
context: &model.UserContext{
|
||||
Provider: model.ProviderLDAP,
|
||||
LDAP: &model.LDAPContext{BaseContext: model.BaseContext{Username: "carol", Email: "carol@example.com", Name: "Carol"}},
|
||||
},
|
||||
run: func(t *testing.T, c *model.UserContext) any {
|
||||
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
|
||||
},
|
||||
expected: [3]string{"carol", "carol@example.com", "Carol"},
|
||||
},
|
||||
{
|
||||
description: "OAuth getters return OAuth fields",
|
||||
context: &model.UserContext{
|
||||
Provider: model.ProviderOAuth,
|
||||
OAuth: &model.OAuthContext{BaseContext: model.BaseContext{Username: "dave", Email: "dave@example.com", Name: "Dave"}},
|
||||
},
|
||||
run: func(t *testing.T, c *model.UserContext) any {
|
||||
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
|
||||
},
|
||||
expected: [3]string{"dave", "dave@example.com", "Dave"},
|
||||
},
|
||||
{
|
||||
description: "ProviderName returns 'local' for ProviderLocal",
|
||||
context: &model.UserContext{Provider: model.ProviderLocal},
|
||||
run: func(t *testing.T, c *model.UserContext) any { return c.GetProviderID() },
|
||||
expected: "local",
|
||||
},
|
||||
{
|
||||
description: "ProviderName returns 'local' for ProviderBasicAuth",
|
||||
context: &model.UserContext{Provider: model.ProviderBasicAuth},
|
||||
run: func(t *testing.T, c *model.UserContext) any { return c.GetProviderID() },
|
||||
expected: "local",
|
||||
},
|
||||
{
|
||||
description: "ProviderName returns 'ldap' for ProviderLDAP",
|
||||
context: &model.UserContext{Provider: model.ProviderLDAP},
|
||||
run: func(t *testing.T, c *model.UserContext) any { return c.GetProviderID() },
|
||||
expected: "ldap",
|
||||
},
|
||||
{
|
||||
description: "ProviderName returns OAuth provider ID for ProviderOAuth",
|
||||
context: &model.UserContext{
|
||||
Provider: model.ProviderOAuth,
|
||||
OAuth: &model.OAuthContext{ID: "github"},
|
||||
},
|
||||
run: func(t *testing.T, c *model.UserContext) any { return c.GetProviderID() },
|
||||
expected: "github",
|
||||
},
|
||||
{
|
||||
description: "TOTPPending returns true when local context is pending",
|
||||
context: &model.UserContext{
|
||||
Provider: model.ProviderLocal,
|
||||
Local: &model.LocalContext{TOTPPending: true},
|
||||
},
|
||||
run: func(t *testing.T, c *model.UserContext) any { return c.TOTPPending() },
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
description: "TOTPPending returns false when local context is not pending",
|
||||
context: &model.UserContext{
|
||||
Provider: model.ProviderLocal,
|
||||
Local: &model.LocalContext{TOTPPending: false},
|
||||
},
|
||||
run: func(t *testing.T, c *model.UserContext) any { return c.TOTPPending() },
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
description: "TOTPPending returns false for non-local providers",
|
||||
context: &model.UserContext{Provider: model.ProviderOAuth, OAuth: &model.OAuthContext{}},
|
||||
run: func(t *testing.T, c *model.UserContext) any { return c.TOTPPending() },
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
description: "OAuthName returns DisplayName for ProviderOAuth",
|
||||
context: &model.UserContext{
|
||||
Provider: model.ProviderOAuth,
|
||||
OAuth: &model.OAuthContext{DisplayName: "Google"},
|
||||
},
|
||||
run: func(t *testing.T, c *model.UserContext) any { return c.OAuthName() },
|
||||
expected: "Google",
|
||||
},
|
||||
{
|
||||
description: "OAuthName returns empty string for non-oauth providers",
|
||||
context: &model.UserContext{Provider: model.ProviderLocal, Local: &model.LocalContext{}},
|
||||
run: func(t *testing.T, c *model.UserContext) any { return c.OAuthName() },
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
description: "NewFromGin populates context from gin value",
|
||||
context: &model.UserContext{},
|
||||
run: func(t *testing.T, c *model.UserContext) any {
|
||||
stored := &model.UserContext{
|
||||
Authenticated: true,
|
||||
Provider: model.ProviderLocal,
|
||||
Local: &model.LocalContext{BaseContext: model.BaseContext{Username: "alice"}},
|
||||
}
|
||||
got, err := c.NewFromGin(newGinCtx(stored, true))
|
||||
require.NoError(t, err)
|
||||
return [2]any{got.Authenticated, got.GetUsername()}
|
||||
},
|
||||
expected: [2]any{true, "alice"},
|
||||
},
|
||||
{
|
||||
description: "NewFromGin returns error when context value is missing",
|
||||
context: &model.UserContext{},
|
||||
run: func(t *testing.T, c *model.UserContext) any {
|
||||
_, err := c.NewFromGin(newGinCtx(nil, false))
|
||||
return err.Error()
|
||||
},
|
||||
expected: "failed to get user context",
|
||||
},
|
||||
{
|
||||
description: "NewFromGin returns error when context value has wrong type",
|
||||
context: &model.UserContext{},
|
||||
run: func(t *testing.T, c *model.UserContext) any {
|
||||
_, err := c.NewFromGin(newGinCtx("not a user context", true))
|
||||
return err.Error()
|
||||
},
|
||||
expected: "invalid user context type",
|
||||
},
|
||||
{
|
||||
description: "NewFromGin returns an error when context doesn't include user information",
|
||||
context: &model.UserContext{},
|
||||
run: func(t *testing.T, c *model.UserContext) any {
|
||||
_, err := c.NewFromGin(newGinCtx(&model.UserContext{Provider: model.ProviderLocal}, true))
|
||||
return err.Error()
|
||||
},
|
||||
expected: "incomplete user context",
|
||||
},
|
||||
{
|
||||
description: "Getters should not panic if provider context is empty",
|
||||
context: &model.UserContext{Provider: model.ProviderLocal},
|
||||
run: func(t *testing.T, c *model.UserContext) any {
|
||||
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
|
||||
},
|
||||
expected: [3]string{"", "", ""},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.description, func(t *testing.T) {
|
||||
assert.Equal(t, test.expected, test.run(t, test.context))
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
@@ -27,29 +28,26 @@ func (acls *AccessControlsService) Init() error {
|
||||
return nil // No initialization needed
|
||||
}
|
||||
|
||||
func (acls *AccessControlsService) lookupStaticACLs(domain string) *model.App {
|
||||
var appAcls *model.App
|
||||
func (acls *AccessControlsService) lookupStaticACLs(domain string) (*model.App, error) {
|
||||
for app, config := range acls.static {
|
||||
if config.Config.Domain == domain {
|
||||
tlog.App.Debug().Str("name", app).Msg("Found matching container by domain")
|
||||
appAcls = &config
|
||||
break // If we find a match by domain, we can stop searching
|
||||
return &config, nil
|
||||
}
|
||||
|
||||
if strings.SplitN(domain, ".", 2)[0] == app {
|
||||
tlog.App.Debug().Str("name", app).Msg("Found matching container by app name")
|
||||
appAcls = &config
|
||||
break // If we find a match by app name, we can stop searching
|
||||
return &config, nil
|
||||
}
|
||||
}
|
||||
return appAcls
|
||||
return nil, errors.New("no results")
|
||||
}
|
||||
|
||||
func (acls *AccessControlsService) GetAccessControls(domain string) (*model.App, error) {
|
||||
// First check in the static config
|
||||
app := acls.lookupStaticACLs(domain)
|
||||
app, err := acls.lookupStaticACLs(domain)
|
||||
|
||||
if app != nil {
|
||||
if err == nil {
|
||||
tlog.App.Debug().Msg("Using ACls from static configuration")
|
||||
return app, nil
|
||||
}
|
||||
|
||||
@@ -73,7 +73,7 @@ type Lockdown struct {
|
||||
}
|
||||
|
||||
type AuthServiceConfig struct {
|
||||
LocalUsers *[]model.LocalUser
|
||||
LocalUsers []model.LocalUser
|
||||
OauthWhitelist []string
|
||||
SessionExpiry int
|
||||
SessionMaxLifetime int
|
||||
@@ -84,7 +84,6 @@ type AuthServiceConfig struct {
|
||||
SessionCookieName string
|
||||
IP model.IPConfig
|
||||
LDAPGroupsCacheTTL int
|
||||
SubdomainsEnabled bool
|
||||
}
|
||||
|
||||
type AuthService struct {
|
||||
@@ -121,7 +120,7 @@ func (auth *AuthService) Init() error {
|
||||
}
|
||||
|
||||
func (auth *AuthService) SearchUser(username string) (*model.UserSearch, error) {
|
||||
if auth.GetLocalUser(username) != nil {
|
||||
if auth.GetLocalUser(username).Username != "" {
|
||||
return &model.UserSearch{
|
||||
Username: username,
|
||||
Type: model.UserLocal,
|
||||
@@ -148,9 +147,6 @@ func (auth *AuthService) CheckUserPassword(search model.UserSearch, password str
|
||||
switch search.Type {
|
||||
case model.UserLocal:
|
||||
user := auth.GetLocalUser(search.Username)
|
||||
if user == nil {
|
||||
return ErrUserNotFound
|
||||
}
|
||||
return bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password))
|
||||
case model.UserLDAP:
|
||||
if auth.ldap.IsConfigured() {
|
||||
@@ -173,10 +169,7 @@ func (auth *AuthService) CheckUserPassword(search model.UserSearch, password str
|
||||
}
|
||||
|
||||
func (auth *AuthService) GetLocalUser(username string) *model.LocalUser {
|
||||
if auth.config.LocalUsers == nil {
|
||||
return nil
|
||||
}
|
||||
for _, user := range *auth.config.LocalUsers {
|
||||
for _, user := range auth.config.LocalUsers {
|
||||
if user.Username == username {
|
||||
return &user
|
||||
}
|
||||
@@ -302,8 +295,6 @@ func (auth *AuthService) CreateSession(ctx context.Context, data repository.Sess
|
||||
expiry = auth.config.SessionExpiry
|
||||
}
|
||||
|
||||
expiresAt := time.Now().Add(time.Duration(expiry) * time.Second)
|
||||
|
||||
session := repository.CreateSessionParams{
|
||||
UUID: uuid.String(),
|
||||
Username: data.Username,
|
||||
@@ -312,7 +303,7 @@ func (auth *AuthService) CreateSession(ctx context.Context, data repository.Sess
|
||||
Provider: data.Provider,
|
||||
TotpPending: data.TotpPending,
|
||||
OAuthGroups: data.OAuthGroups,
|
||||
Expiry: expiresAt.Unix(),
|
||||
Expiry: time.Now().Add(time.Duration(expiry) * time.Second).Unix(),
|
||||
CreatedAt: time.Now().Unix(),
|
||||
OAuthName: data.OAuthName,
|
||||
OAuthSub: data.OAuthSub,
|
||||
@@ -329,8 +320,8 @@ func (auth *AuthService) CreateSession(ctx context.Context, data repository.Sess
|
||||
Value: session.UUID,
|
||||
Path: "/",
|
||||
Domain: fmt.Sprintf(".%s", auth.config.CookieDomain),
|
||||
Expires: expiresAt,
|
||||
MaxAge: int(time.Until(expiresAt).Seconds()),
|
||||
Expires: time.Now().Add(time.Duration(expiry) * time.Second),
|
||||
MaxAge: expiry,
|
||||
Secure: auth.config.SecureCookie,
|
||||
HttpOnly: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
@@ -383,7 +374,7 @@ func (auth *AuthService) RefreshSession(ctx context.Context, uuid string) (*http
|
||||
Path: "/",
|
||||
Domain: fmt.Sprintf(".%s", auth.config.CookieDomain),
|
||||
Expires: time.Now().Add(time.Duration(newExpiry-currentTime) * time.Second),
|
||||
MaxAge: int(newExpiry - currentTime),
|
||||
MaxAge: auth.config.SessionExpiry,
|
||||
Secure: auth.config.SecureCookie,
|
||||
HttpOnly: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
@@ -398,12 +389,6 @@ func (auth *AuthService) DeleteSession(ctx context.Context, uuid string) (*http.
|
||||
tlog.App.Warn().Err(err).Msg("Failed to delete session from database, proceeding to clear cookie anyway")
|
||||
}
|
||||
|
||||
err = auth.queries.DeleteSession(ctx, uuid)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &http.Cookie{
|
||||
Name: auth.config.SessionCookieName,
|
||||
Value: "",
|
||||
@@ -451,7 +436,7 @@ func (auth *AuthService) GetSession(ctx context.Context, uuid string) (*reposito
|
||||
}
|
||||
|
||||
func (auth *AuthService) LocalAuthConfigured() bool {
|
||||
return auth.config.LocalUsers != nil && len(*auth.config.LocalUsers) > 0
|
||||
return len(auth.config.LocalUsers) > 0
|
||||
}
|
||||
|
||||
func (auth *AuthService) LDAPAuthConfigured() bool {
|
||||
@@ -479,8 +464,8 @@ func (auth *AuthService) IsUserAllowed(c *gin.Context, context model.UserContext
|
||||
return utils.CheckFilter(acls.Users.Allow, context.GetUsername())
|
||||
}
|
||||
|
||||
func (auth *AuthService) IsInOAuthGroup(c *gin.Context, context model.UserContext, acls *model.App) bool {
|
||||
if acls == nil {
|
||||
func (auth *AuthService) IsInOAuthGroup(c *gin.Context, context model.UserContext, requiredGroups string) bool {
|
||||
if requiredGroups == "" {
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -495,8 +480,8 @@ func (auth *AuthService) IsInOAuthGroup(c *gin.Context, context model.UserContex
|
||||
}
|
||||
|
||||
for _, userGroup := range context.OAuth.Groups {
|
||||
if utils.CheckFilter(acls.OAuth.Groups, strings.TrimSpace(userGroup)) {
|
||||
tlog.App.Trace().Str("group", userGroup).Str("required", acls.OAuth.Groups).Msg("User group matched")
|
||||
if utils.CheckFilter(requiredGroups, strings.TrimSpace(userGroup)) {
|
||||
tlog.App.Trace().Str("group", userGroup).Str("required", requiredGroups).Msg("User group matched")
|
||||
return true
|
||||
}
|
||||
}
|
||||
@@ -505,8 +490,8 @@ func (auth *AuthService) IsInOAuthGroup(c *gin.Context, context model.UserContex
|
||||
return false
|
||||
}
|
||||
|
||||
func (auth *AuthService) IsInLDAPGroup(c *gin.Context, context model.UserContext, acls *model.App) bool {
|
||||
if acls == nil {
|
||||
func (auth *AuthService) IsInLDAPGroup(c *gin.Context, context model.UserContext, requiredGroups string) bool {
|
||||
if requiredGroups == "" {
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -516,8 +501,8 @@ func (auth *AuthService) IsInLDAPGroup(c *gin.Context, context model.UserContext
|
||||
}
|
||||
|
||||
for _, userGroup := range context.LDAP.Groups {
|
||||
if utils.CheckFilter(acls.LDAP.Groups, strings.TrimSpace(userGroup)) {
|
||||
tlog.App.Trace().Str("group", userGroup).Str("required", acls.LDAP.Groups).Msg("User group matched")
|
||||
if utils.CheckFilter(requiredGroups, strings.TrimSpace(userGroup)) {
|
||||
tlog.App.Trace().Str("group", userGroup).Str("required", requiredGroups).Msg("User group matched")
|
||||
return true
|
||||
}
|
||||
}
|
||||
@@ -526,14 +511,14 @@ func (auth *AuthService) IsInLDAPGroup(c *gin.Context, context model.UserContext
|
||||
return false
|
||||
}
|
||||
|
||||
func (auth *AuthService) IsAuthEnabled(uri string, acls *model.App) (bool, error) {
|
||||
if acls == nil {
|
||||
func (auth *AuthService) IsAuthEnabled(uri string, path *model.AppPath) (bool, error) {
|
||||
if path == nil {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// Check for block list
|
||||
if acls.Path.Block != "" {
|
||||
regex, err := regexp.Compile(acls.Path.Block)
|
||||
if path.Block != "" {
|
||||
regex, err := regexp.Compile(path.Block)
|
||||
|
||||
if err != nil {
|
||||
return true, err
|
||||
@@ -545,8 +530,8 @@ func (auth *AuthService) IsAuthEnabled(uri string, acls *model.App) (bool, error
|
||||
}
|
||||
|
||||
// Check for allow list
|
||||
if acls.Path.Allow != "" {
|
||||
regex, err := regexp.Compile(acls.Path.Allow)
|
||||
if path.Allow != "" {
|
||||
regex, err := regexp.Compile(path.Allow)
|
||||
|
||||
if err != nil {
|
||||
return true, err
|
||||
@@ -560,14 +545,29 @@ func (auth *AuthService) IsAuthEnabled(uri string, acls *model.App) (bool, error
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (auth *AuthService) CheckIP(ip string, acls *model.App) bool {
|
||||
// local user is used only as a medium to pass the basic auth credentials, user can be ldap too
|
||||
func (auth *AuthService) GetBasicAuth(req *http.Request) (*model.LocalUser, error) {
|
||||
if req == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
username, password, ok := req.BasicAuth()
|
||||
if !ok {
|
||||
return nil, errors.New("no basic auth credentials provided")
|
||||
}
|
||||
return &model.LocalUser{
|
||||
Username: username,
|
||||
Password: password,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (auth *AuthService) CheckIP(acls *model.AppIP, ip string) bool {
|
||||
if acls == nil {
|
||||
return true
|
||||
acls = &model.AppIP{}
|
||||
}
|
||||
|
||||
// Merge the global and app IP filter
|
||||
blockedIps := append(auth.config.IP.Block, acls.IP.Block...)
|
||||
allowedIPs := append(auth.config.IP.Allow, acls.IP.Allow...)
|
||||
blockedIps := append(auth.config.IP.Block, acls.Block...)
|
||||
allowedIPs := append(auth.config.IP.Allow, acls.Allow...)
|
||||
|
||||
for _, blocked := range blockedIps {
|
||||
res, err := utils.FilterIP(blocked, ip)
|
||||
@@ -602,12 +602,12 @@ func (auth *AuthService) CheckIP(ip string, acls *model.App) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (auth *AuthService) IsBypassedIP(ip string, acls *model.App) bool {
|
||||
func (auth *AuthService) IsBypassedIP(acls *model.AppIP, ip string) bool {
|
||||
if acls == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, bypassed := range acls.IP.Bypass {
|
||||
for _, bypassed := range acls.Bypass {
|
||||
res, err := utils.FilterIP(bypassed, ip)
|
||||
if err != nil {
|
||||
tlog.App.Warn().Err(err).Str("item", bypassed).Msg("Invalid IP/CIDR in bypass list")
|
||||
@@ -845,10 +845,3 @@ func (auth *AuthService) ClearRateLimitsTestingOnly() {
|
||||
}
|
||||
auth.loginMutex.Unlock()
|
||||
}
|
||||
|
||||
func (auth *AuthService) getCookieDomain() string {
|
||||
if auth.config.SubdomainsEnabled {
|
||||
return "." + auth.config.CookieDomain
|
||||
}
|
||||
return auth.config.CookieDomain
|
||||
}
|
||||
|
||||
@@ -51,11 +51,19 @@ func (docker *DockerService) Init() error {
|
||||
}
|
||||
|
||||
func (docker *DockerService) getContainers() ([]container.Summary, error) {
|
||||
return docker.client.ContainerList(docker.context, container.ListOptions{})
|
||||
containers, err := docker.client.ContainerList(docker.context, container.ListOptions{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return containers, nil
|
||||
}
|
||||
|
||||
func (docker *DockerService) inspectContainer(containerId string) (container.InspectResponse, error) {
|
||||
return docker.client.ContainerInspect(docker.context, containerId)
|
||||
inspect, err := docker.client.ContainerInspect(docker.context, containerId)
|
||||
if err != nil {
|
||||
return container.InspectResponse{}, err
|
||||
}
|
||||
return inspect, nil
|
||||
}
|
||||
|
||||
func (docker *DockerService) GetLabels(appDomain string) (*model.App, error) {
|
||||
|
||||
@@ -89,38 +89,36 @@ func (k *KubernetesService) removeIngress(namespace, name string) {
|
||||
}
|
||||
}
|
||||
|
||||
func (k *KubernetesService) getByDomain(domain string) *model.App {
|
||||
func (k *KubernetesService) getByDomain(domain string) (*model.App, bool) {
|
||||
k.mu.RLock()
|
||||
defer k.mu.RUnlock()
|
||||
|
||||
if appKey, ok := k.domainIndex[domain]; ok {
|
||||
if apps, ok := k.ingressApps[appKey.ingressKey]; ok {
|
||||
for i := range apps {
|
||||
app := &apps[i]
|
||||
for _, app := range apps {
|
||||
if app.domain == domain && app.appName == appKey.appName {
|
||||
return &app.app
|
||||
return &app.app, true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (k *KubernetesService) getByAppName(appName string) *model.App {
|
||||
func (k *KubernetesService) getByAppName(appName string) (*model.App, bool) {
|
||||
k.mu.RLock()
|
||||
defer k.mu.RUnlock()
|
||||
|
||||
if appKey, ok := k.appNameIndex[appName]; ok {
|
||||
if apps, ok := k.ingressApps[appKey.ingressKey]; ok {
|
||||
for i := range apps {
|
||||
app := &apps[i]
|
||||
for _, app := range apps {
|
||||
if app.appName == appName {
|
||||
return &app.app
|
||||
return &app.app, true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (k *KubernetesService) updateFromItem(item *unstructured.Unstructured) {
|
||||
@@ -289,14 +287,12 @@ func (k *KubernetesService) GetLabels(appDomain string) (*model.App, error) {
|
||||
}
|
||||
|
||||
// First check cache
|
||||
app := k.getByDomain(appDomain)
|
||||
if app != nil {
|
||||
if app, found := k.getByDomain(appDomain); found {
|
||||
tlog.App.Debug().Str("domain", appDomain).Msg("Found labels in cache by domain")
|
||||
return app, nil
|
||||
}
|
||||
appName := strings.SplitN(appDomain, ".", 2)[0]
|
||||
app = k.getByAppName(appName)
|
||||
if app != nil {
|
||||
if app, found := k.getByAppName(appName); found {
|
||||
tlog.App.Debug().Str("domain", appDomain).Str("appName", appName).Msg("Found labels in cache by app name")
|
||||
return app, nil
|
||||
}
|
||||
|
||||
@@ -3,11 +3,11 @@ package service
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/tinyauthapp/tinyauth/internal/config"
|
||||
"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
)
|
||||
|
||||
func TestKubernetesService(t *testing.T) {
|
||||
@@ -20,69 +20,69 @@ func TestKubernetesService(t *testing.T) {
|
||||
{
|
||||
description: "Cache by domain returns app and misses unknown domain",
|
||||
run: func(t *testing.T, svc *KubernetesService) {
|
||||
app := model.App{Config: model.AppConfig{Domain: "foo.example.com"}}
|
||||
app := config.App{Config: config.AppConfig{Domain: "foo.example.com"}}
|
||||
svc.addIngressApps("default", "my-ingress", []ingressApp{
|
||||
{domain: "foo.example.com", appName: "foo", app: app},
|
||||
})
|
||||
|
||||
got := svc.getByDomain("foo.example.com")
|
||||
require.NotNil(t, got)
|
||||
got, ok := svc.getByDomain("foo.example.com")
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "foo.example.com", got.Config.Domain)
|
||||
|
||||
got = svc.getByDomain("notfound.example.com")
|
||||
assert.Nil(t, got)
|
||||
_, ok = svc.getByDomain("notfound.example.com")
|
||||
assert.False(t, ok)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Cache by app name returns app and misses unknown name",
|
||||
run: func(t *testing.T, svc *KubernetesService) {
|
||||
app := model.App{Config: model.AppConfig{Domain: "bar.example.com"}}
|
||||
app := config.App{Config: config.AppConfig{Domain: "bar.example.com"}}
|
||||
svc.addIngressApps("default", "my-ingress", []ingressApp{
|
||||
{domain: "bar.example.com", appName: "bar", app: app},
|
||||
})
|
||||
|
||||
got := svc.getByAppName("bar")
|
||||
require.NotNil(t, got)
|
||||
got, ok := svc.getByAppName("bar")
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "bar.example.com", got.Config.Domain)
|
||||
|
||||
got = svc.getByAppName("notfound")
|
||||
assert.Nil(t, got)
|
||||
_, ok = svc.getByAppName("notfound")
|
||||
assert.False(t, ok)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "RemoveIngress clears domain and app name entries",
|
||||
run: func(t *testing.T, svc *KubernetesService) {
|
||||
app := model.App{Config: model.AppConfig{Domain: "baz.example.com"}}
|
||||
app := config.App{Config: config.AppConfig{Domain: "baz.example.com"}}
|
||||
svc.addIngressApps("default", "my-ingress", []ingressApp{
|
||||
{domain: "baz.example.com", appName: "baz", app: app},
|
||||
})
|
||||
|
||||
svc.removeIngress("default", "my-ingress")
|
||||
|
||||
got := svc.getByDomain("baz.example.com")
|
||||
assert.Nil(t, got)
|
||||
got = svc.getByAppName("baz")
|
||||
assert.Nil(t, got)
|
||||
_, ok := svc.getByDomain("baz.example.com")
|
||||
assert.False(t, ok)
|
||||
_, ok = svc.getByAppName("baz")
|
||||
assert.False(t, ok)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "AddIngressApps replaces stale entries for the same ingress",
|
||||
run: func(t *testing.T, svc *KubernetesService) {
|
||||
old := model.App{Config: model.AppConfig{Domain: "old.example.com"}}
|
||||
old := config.App{Config: config.AppConfig{Domain: "old.example.com"}}
|
||||
svc.addIngressApps("default", "my-ingress", []ingressApp{
|
||||
{domain: "old.example.com", appName: "old", app: old},
|
||||
})
|
||||
|
||||
updated := model.App{Config: model.AppConfig{Domain: "new.example.com"}}
|
||||
updated := config.App{Config: config.AppConfig{Domain: "new.example.com"}}
|
||||
svc.addIngressApps("default", "my-ingress", []ingressApp{
|
||||
{domain: "new.example.com", appName: "new", app: updated},
|
||||
})
|
||||
|
||||
got := svc.getByDomain("old.example.com")
|
||||
assert.Nil(t, got)
|
||||
_, ok := svc.getByDomain("old.example.com")
|
||||
assert.False(t, ok)
|
||||
|
||||
got = svc.getByDomain("new.example.com")
|
||||
require.NotNil(t, got)
|
||||
got, ok := svc.getByDomain("new.example.com")
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "new.example.com", got.Config.Domain)
|
||||
},
|
||||
},
|
||||
@@ -91,7 +91,7 @@ func TestKubernetesService(t *testing.T) {
|
||||
run: func(t *testing.T, svc *KubernetesService) {
|
||||
svc.started = true
|
||||
|
||||
app := model.App{Config: model.AppConfig{Domain: "hit.example.com"}}
|
||||
app := config.App{Config: config.AppConfig{Domain: "hit.example.com"}}
|
||||
svc.addIngressApps("default", "ing", []ingressApp{
|
||||
{domain: "hit.example.com", appName: "hit", app: app},
|
||||
})
|
||||
@@ -108,7 +108,7 @@ func TestKubernetesService(t *testing.T) {
|
||||
|
||||
got, err := svc.GetLabels("notfound.example.com")
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, got)
|
||||
assert.Equal(t, config.App{}, got)
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -116,7 +116,7 @@ func TestKubernetesService(t *testing.T) {
|
||||
run: func(t *testing.T, svc *KubernetesService) {
|
||||
svc.started = true
|
||||
|
||||
app := model.App{Config: model.AppConfig{Domain: "myapp.internal.example.com"}}
|
||||
app := config.App{Config: config.AppConfig{Domain: "myapp.internal.example.com"}}
|
||||
svc.addIngressApps("default", "ing", []ingressApp{
|
||||
{domain: "myapp.internal.example.com", appName: "myapp", app: app},
|
||||
})
|
||||
@@ -131,7 +131,7 @@ func TestKubernetesService(t *testing.T) {
|
||||
run: func(t *testing.T, svc *KubernetesService) {
|
||||
got, err := svc.GetLabels("anything.example.com")
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, got)
|
||||
assert.Equal(t, config.App{}, got)
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -147,8 +147,8 @@ func TestKubernetesService(t *testing.T) {
|
||||
|
||||
svc.updateFromItem(&item)
|
||||
|
||||
got := svc.getByDomain("myapp.example.com")
|
||||
require.NotNil(t, got)
|
||||
got, ok := svc.getByDomain("myapp.example.com")
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "myapp.example.com", got.Config.Domain)
|
||||
assert.Equal(t, "alice", got.Users.Allow)
|
||||
},
|
||||
@@ -156,7 +156,7 @@ func TestKubernetesService(t *testing.T) {
|
||||
{
|
||||
description: "UpdateFromItem with no annotations removes existing cache entries",
|
||||
run: func(t *testing.T, svc *KubernetesService) {
|
||||
app := model.App{Config: model.AppConfig{Domain: "todelete.example.com"}}
|
||||
app := config.App{Config: config.AppConfig{Domain: "todelete.example.com"}}
|
||||
svc.addIngressApps("default", "test-ingress", []ingressApp{
|
||||
{domain: "todelete.example.com", appName: "todelete", app: app},
|
||||
})
|
||||
@@ -167,8 +167,8 @@ func TestKubernetesService(t *testing.T) {
|
||||
|
||||
svc.updateFromItem(&item)
|
||||
|
||||
got := svc.getByDomain("todelete.example.com")
|
||||
assert.Nil(t, got)
|
||||
_, ok := svc.getByDomain("todelete.example.com")
|
||||
assert.False(t, ok)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -12,9 +12,8 @@ import (
|
||||
)
|
||||
|
||||
type GithubEmailResponse []struct {
|
||||
Email string `json:"email"`
|
||||
Primary bool `json:"primary"`
|
||||
Verified bool `json:"verified"`
|
||||
Email string `json:"email"`
|
||||
Primary bool `json:"primary"`
|
||||
}
|
||||
|
||||
type GithubUserInfoResponse struct {
|
||||
@@ -27,7 +26,7 @@ func defaultExtractor(client *http.Client, url string) (*model.Claims, error) {
|
||||
return simpleReq[model.Claims](client, url, nil)
|
||||
}
|
||||
|
||||
func githubExtractor(client *http.Client, _ string) (*model.Claims, error) {
|
||||
func githubExtractor(client *http.Client, url string) (*model.Claims, error) {
|
||||
var user model.Claims
|
||||
|
||||
userInfo, err := simpleReq[GithubUserInfoResponse](client, "https://api.github.com/user", map[string]string{
|
||||
@@ -49,7 +48,7 @@ func githubExtractor(client *http.Client, _ string) (*model.Claims, error) {
|
||||
}
|
||||
|
||||
for _, email := range *userEmails {
|
||||
if email.Primary && email.Verified {
|
||||
if email.Primary {
|
||||
user.Email = email.Email
|
||||
break
|
||||
}
|
||||
@@ -57,16 +56,7 @@ func githubExtractor(client *http.Client, _ string) (*model.Claims, error) {
|
||||
|
||||
// Use first available email if no primary email was found
|
||||
if user.Email == "" {
|
||||
for _, email := range *userEmails {
|
||||
if email.Verified {
|
||||
user.Email = email.Email
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if user.Email == "" {
|
||||
return nil, errors.New("no verified email found")
|
||||
user.Email = (*userEmails)[0].Email
|
||||
}
|
||||
|
||||
user.PreferredUsername = userInfo.Login
|
||||
|
||||
@@ -7,13 +7,13 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"github.com/tinyauthapp/tinyauth/internal/config"
|
||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||
)
|
||||
|
||||
func newTestUser() repository.OidcUserinfo {
|
||||
addr := model.AddressClaim{
|
||||
addr := config.AddressClaim{
|
||||
Formatted: "123 Main St",
|
||||
StreetAddress: "123 Main St",
|
||||
Locality: "Springfield",
|
||||
|
||||
@@ -47,15 +47,6 @@ func GetCookieDomain(u string) (string, error) {
|
||||
return domain, nil
|
||||
}
|
||||
|
||||
func GetStandaloneCookieDomain(u string) (string, error) {
|
||||
parsed, err := url.Parse(u)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return parsed.Hostname(), nil
|
||||
}
|
||||
|
||||
func ParseFileToLine(content string) string {
|
||||
lines := strings.Split(content, "\n")
|
||||
users := make([]string, 0)
|
||||
|
||||
@@ -3,8 +3,9 @@ package utils_test
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||
|
||||
"gotest.tools/v3/assert"
|
||||
)
|
||||
|
||||
func TestGetRootDomain(t *testing.T) {
|
||||
@@ -12,14 +13,14 @@ func TestGetRootDomain(t *testing.T) {
|
||||
domain := "http://sub.tinyauth.app"
|
||||
expected := "tinyauth.app"
|
||||
result, err := utils.GetCookieDomain(domain)
|
||||
assert.NoError(t, err)
|
||||
assert.NilError(t, err)
|
||||
assert.Equal(t, expected, result)
|
||||
|
||||
// Domain with multiple subdomains
|
||||
domain = "http://b.c.tinyauth.app"
|
||||
expected = "c.tinyauth.app"
|
||||
result, err = utils.GetCookieDomain(domain)
|
||||
assert.NoError(t, err)
|
||||
assert.NilError(t, err)
|
||||
assert.Equal(t, expected, result)
|
||||
|
||||
// Invalid domain (only TLD)
|
||||
@@ -41,14 +42,14 @@ func TestGetRootDomain(t *testing.T) {
|
||||
domain = "https://sub.tinyauth.app/path"
|
||||
expected = "tinyauth.app"
|
||||
result, err = utils.GetCookieDomain(domain)
|
||||
assert.NoError(t, err)
|
||||
assert.NilError(t, err)
|
||||
assert.Equal(t, expected, result)
|
||||
|
||||
// URL with port
|
||||
domain = "http://sub.tinyauth.app:8080"
|
||||
expected = "tinyauth.app"
|
||||
result, err = utils.GetCookieDomain(domain)
|
||||
assert.NoError(t, err)
|
||||
assert.NilError(t, err)
|
||||
assert.Equal(t, expected, result)
|
||||
|
||||
// Domain managed by ICANN
|
||||
@@ -95,35 +96,35 @@ func TestFilter(t *testing.T) {
|
||||
testFunc := func(n int) bool { return n%2 == 0 }
|
||||
expected := []int{2, 4}
|
||||
result := utils.Filter(slice, testFunc)
|
||||
assert.Equal(t, expected, result)
|
||||
assert.DeepEqual(t, expected, result)
|
||||
|
||||
// Case with no matches
|
||||
slice = []int{1, 3, 5}
|
||||
testFunc = func(n int) bool { return n%2 == 0 }
|
||||
expected = []int{}
|
||||
result = utils.Filter(slice, testFunc)
|
||||
assert.Equal(t, expected, result)
|
||||
assert.DeepEqual(t, expected, result)
|
||||
|
||||
// Case with all matches
|
||||
slice = []int{2, 4, 6}
|
||||
testFunc = func(n int) bool { return n%2 == 0 }
|
||||
expected = []int{2, 4, 6}
|
||||
result = utils.Filter(slice, testFunc)
|
||||
assert.Equal(t, expected, result)
|
||||
assert.DeepEqual(t, expected, result)
|
||||
|
||||
// Case with empty slice
|
||||
slice = []int{}
|
||||
testFunc = func(n int) bool { return n%2 == 0 }
|
||||
expected = []int{}
|
||||
result = utils.Filter(slice, testFunc)
|
||||
assert.Equal(t, expected, result)
|
||||
assert.DeepEqual(t, expected, result)
|
||||
|
||||
// Case with different type (string)
|
||||
sliceStr := []string{"apple", "banana", "cherry"}
|
||||
testFuncStr := func(s string) bool { return len(s) > 5 }
|
||||
expectedStr := []string{"banana", "cherry"}
|
||||
resultStr := utils.Filter(sliceStr, testFuncStr)
|
||||
assert.Equal(t, expectedStr, resultStr)
|
||||
assert.DeepEqual(t, expectedStr, resultStr)
|
||||
}
|
||||
|
||||
func TestIsRedirectSafe(t *testing.T) {
|
||||
@@ -133,50 +134,50 @@ func TestIsRedirectSafe(t *testing.T) {
|
||||
// Case with no subdomain
|
||||
redirectURL := "http://example.com/welcome"
|
||||
result := utils.IsRedirectSafe(redirectURL, domain)
|
||||
assert.True(t, result)
|
||||
assert.Equal(t, true, result)
|
||||
|
||||
// Case with different domain
|
||||
redirectURL = "http://malicious.com/phishing"
|
||||
result = utils.IsRedirectSafe(redirectURL, domain)
|
||||
assert.False(t, result)
|
||||
assert.Equal(t, false, result)
|
||||
|
||||
// Case with subdomain
|
||||
redirectURL = "http://sub.example.com/page"
|
||||
result = utils.IsRedirectSafe(redirectURL, domain)
|
||||
assert.True(t, result)
|
||||
assert.Equal(t, true, result)
|
||||
|
||||
// Case with sub-subdomain
|
||||
redirectURL = "http://a.b.example.com/home"
|
||||
result = utils.IsRedirectSafe(redirectURL, domain)
|
||||
assert.True(t, result)
|
||||
assert.Equal(t, true, result)
|
||||
|
||||
// Case with empty redirect URL
|
||||
redirectURL = ""
|
||||
result = utils.IsRedirectSafe(redirectURL, domain)
|
||||
assert.False(t, result)
|
||||
assert.Equal(t, false, result)
|
||||
|
||||
// Case with invalid URL
|
||||
redirectURL = "http://[::1]:namedport"
|
||||
result = utils.IsRedirectSafe(redirectURL, domain)
|
||||
assert.False(t, result)
|
||||
assert.Equal(t, false, result)
|
||||
|
||||
// Case with URL having port
|
||||
redirectURL = "http://sub.example.com:8080/page"
|
||||
result = utils.IsRedirectSafe(redirectURL, domain)
|
||||
assert.True(t, result)
|
||||
assert.Equal(t, true, result)
|
||||
|
||||
// Case with URL having different subdomain
|
||||
redirectURL = "http://another.example.com/page"
|
||||
result = utils.IsRedirectSafe(redirectURL, domain)
|
||||
assert.True(t, result)
|
||||
assert.Equal(t, true, result)
|
||||
|
||||
// Case with URL having different TLD
|
||||
redirectURL = "http://example.org/page"
|
||||
result = utils.IsRedirectSafe(redirectURL, domain)
|
||||
assert.False(t, result)
|
||||
assert.Equal(t, false, result)
|
||||
|
||||
// Case with malicious domain
|
||||
redirectURL = "https://malicious-example.com/yoyo"
|
||||
result = utils.IsRedirectSafe(redirectURL, domain)
|
||||
assert.False(t, result)
|
||||
assert.Equal(t, false, result)
|
||||
}
|
||||
|
||||
@@ -3,41 +3,42 @@ package decoders_test
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"github.com/tinyauthapp/tinyauth/internal/config"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/decoders"
|
||||
|
||||
"gotest.tools/v3/assert"
|
||||
)
|
||||
|
||||
func TestDecodeLabels(t *testing.T) {
|
||||
// Variables
|
||||
expected := model.Apps{
|
||||
Apps: map[string]model.App{
|
||||
expected := config.Apps{
|
||||
Apps: map[string]config.App{
|
||||
"foo": {
|
||||
Config: model.AppConfig{
|
||||
Config: config.AppConfig{
|
||||
Domain: "example.com",
|
||||
},
|
||||
Users: model.AppUsers{
|
||||
Users: config.AppUsers{
|
||||
Allow: "user1,user2",
|
||||
Block: "user3",
|
||||
},
|
||||
OAuth: model.AppOAuth{
|
||||
OAuth: config.AppOAuth{
|
||||
Whitelist: "somebody@example.com",
|
||||
Groups: "group3",
|
||||
},
|
||||
IP: model.AppIP{
|
||||
IP: config.AppIP{
|
||||
Allow: []string{"10.71.0.1/24", "10.71.0.2"},
|
||||
Block: []string{"10.10.10.10", "10.0.0.0/24"},
|
||||
Bypass: []string{"192.168.1.1"},
|
||||
},
|
||||
Response: model.AppResponse{
|
||||
Response: config.AppResponse{
|
||||
Headers: []string{"X-Foo=Bar", "X-Baz=Qux"},
|
||||
BasicAuth: model.AppBasicAuth{
|
||||
BasicAuth: config.AppBasicAuth{
|
||||
Username: "admin",
|
||||
Password: "password",
|
||||
PasswordFile: "/path/to/passwordfile",
|
||||
},
|
||||
},
|
||||
Path: model.AppPath{
|
||||
Path: config.AppPath{
|
||||
Allow: "/public",
|
||||
Block: "/private",
|
||||
},
|
||||
@@ -62,7 +63,7 @@ func TestDecodeLabels(t *testing.T) {
|
||||
}
|
||||
|
||||
// Test
|
||||
result, err := decoders.DecodeLabels[model.Apps](test, "apps")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, expected, result)
|
||||
result, err := decoders.DecodeLabels[config.Apps](test, "apps")
|
||||
assert.NilError(t, err)
|
||||
assert.DeepEqual(t, expected, result)
|
||||
}
|
||||
|
||||
@@ -4,25 +4,24 @@ import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gotest.tools/v3/assert"
|
||||
)
|
||||
|
||||
func TestReadFile(t *testing.T) {
|
||||
// Setup
|
||||
file, err := os.Create("/tmp/tinyauth_test_file")
|
||||
require.NoError(t, err)
|
||||
assert.NilError(t, err)
|
||||
|
||||
_, err = file.WriteString("file content\n")
|
||||
require.NoError(t, err)
|
||||
assert.NilError(t, err)
|
||||
|
||||
err = file.Close()
|
||||
require.NoError(t, err)
|
||||
assert.NilError(t, err)
|
||||
defer os.Remove("/tmp/tinyauth_test_file")
|
||||
|
||||
// Normal case
|
||||
content, err := ReadFile("/tmp/tinyauth_test_file")
|
||||
assert.NoError(t, err)
|
||||
assert.NilError(t, err)
|
||||
assert.Equal(t, "file content\n", content)
|
||||
|
||||
// Non-existing file
|
||||
|
||||
@@ -3,8 +3,9 @@ package utils_test
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||
|
||||
"gotest.tools/v3/assert"
|
||||
)
|
||||
|
||||
func TestParseHeaders(t *testing.T) {
|
||||
@@ -17,7 +18,7 @@ func TestParseHeaders(t *testing.T) {
|
||||
"X-Custom-Header": "Value",
|
||||
"Another-Header": "AnotherValue",
|
||||
}
|
||||
assert.Equal(t, expected, utils.ParseHeaders(headers))
|
||||
assert.DeepEqual(t, expected, utils.ParseHeaders(headers))
|
||||
|
||||
// Case insensitivity and trimming
|
||||
headers = []string{
|
||||
@@ -28,7 +29,7 @@ func TestParseHeaders(t *testing.T) {
|
||||
"X-Custom-Header": "Value",
|
||||
"Another-Header": "AnotherValue",
|
||||
}
|
||||
assert.Equal(t, expected, utils.ParseHeaders(headers))
|
||||
assert.DeepEqual(t, expected, utils.ParseHeaders(headers))
|
||||
|
||||
// Invalid headers (missing '=', empty key/value)
|
||||
headers = []string{
|
||||
@@ -38,7 +39,7 @@ func TestParseHeaders(t *testing.T) {
|
||||
" = ",
|
||||
}
|
||||
expected = map[string]string{}
|
||||
assert.Equal(t, expected, utils.ParseHeaders(headers))
|
||||
assert.DeepEqual(t, expected, utils.ParseHeaders(headers))
|
||||
|
||||
// Headers with unsafe characters
|
||||
headers = []string{
|
||||
@@ -51,7 +52,7 @@ func TestParseHeaders(t *testing.T) {
|
||||
"Another-Header": "AnotherValue",
|
||||
"Good-Header": "GoodValue",
|
||||
}
|
||||
assert.Equal(t, expected, utils.ParseHeaders(headers))
|
||||
assert.DeepEqual(t, expected, utils.ParseHeaders(headers))
|
||||
|
||||
// Header with spaces in key (should be ignored)
|
||||
headers = []string{
|
||||
@@ -61,7 +62,7 @@ func TestParseHeaders(t *testing.T) {
|
||||
expected = map[string]string{
|
||||
"Valid-Header": "ValidValue",
|
||||
}
|
||||
assert.Equal(t, expected, utils.ParseHeaders(headers))
|
||||
assert.DeepEqual(t, expected, utils.ParseHeaders(headers))
|
||||
}
|
||||
|
||||
func TestSanitizeHeader(t *testing.T) {
|
||||
|
||||
@@ -41,7 +41,7 @@ func ParseSecretFile(contents string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func EncodeBasicAuth(username string, password string) string {
|
||||
func GetBasicAuth(username string, password string) string {
|
||||
auth := username + ":" + password
|
||||
return base64.StdEncoding.EncodeToString([]byte(auth))
|
||||
}
|
||||
|
||||
@@ -4,21 +4,21 @@ import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||
|
||||
"gotest.tools/v3/assert"
|
||||
)
|
||||
|
||||
func TestGetSecret(t *testing.T) {
|
||||
// Setup
|
||||
file, err := os.Create("/tmp/tinyauth_test_secret")
|
||||
require.NoError(t, err)
|
||||
assert.NilError(t, err)
|
||||
|
||||
_, err = file.WriteString(" secret \n")
|
||||
require.NoError(t, err)
|
||||
assert.NilError(t, err)
|
||||
|
||||
err = file.Close()
|
||||
require.NoError(t, err)
|
||||
assert.NilError(t, err)
|
||||
defer os.Remove("/tmp/tinyauth_test_secret")
|
||||
|
||||
// Get from config
|
||||
@@ -55,50 +55,50 @@ func TestParseSecretFile(t *testing.T) {
|
||||
assert.Equal(t, "", utils.ParseSecretFile(content))
|
||||
}
|
||||
|
||||
func TestEncodeBasicAuth(t *testing.T) {
|
||||
func TestGetBasicAuth(t *testing.T) {
|
||||
// Normal case
|
||||
username := "user"
|
||||
password := "pass"
|
||||
expected := "dXNlcjpwYXNz" // base64 of "user:pass"
|
||||
assert.Equal(t, expected, utils.EncodeBasicAuth(username, password))
|
||||
assert.Equal(t, expected, utils.GetBasicAuth(username, password))
|
||||
|
||||
// Empty username
|
||||
username = ""
|
||||
password = "pass"
|
||||
expected = "OnBhc3M=" // base64 of ":pass"
|
||||
assert.Equal(t, expected, utils.EncodeBasicAuth(username, password))
|
||||
assert.Equal(t, expected, utils.GetBasicAuth(username, password))
|
||||
|
||||
// Empty password
|
||||
username = "user"
|
||||
password = ""
|
||||
expected = "dXNlcjo=" // base64 of "user:"
|
||||
assert.Equal(t, expected, utils.EncodeBasicAuth(username, password))
|
||||
assert.Equal(t, expected, utils.GetBasicAuth(username, password))
|
||||
}
|
||||
|
||||
func TestFilterIP(t *testing.T) {
|
||||
// Exact match IPv4
|
||||
ok, err := utils.FilterIP("10.10.0.1", "10.10.0.1")
|
||||
assert.NoError(t, err)
|
||||
assert.NilError(t, err)
|
||||
assert.Equal(t, true, ok)
|
||||
|
||||
// Non-match IPv4
|
||||
ok, err = utils.FilterIP("10.10.0.1", "10.10.0.2")
|
||||
assert.NoError(t, err)
|
||||
assert.NilError(t, err)
|
||||
assert.Equal(t, false, ok)
|
||||
|
||||
// CIDR match IPv4
|
||||
ok, err = utils.FilterIP("10.10.0.0/24", "10.10.0.2")
|
||||
assert.NoError(t, err)
|
||||
assert.NilError(t, err)
|
||||
assert.Equal(t, true, ok)
|
||||
|
||||
// CIDR match IPv4 with '-' instead of '/'
|
||||
ok, err = utils.FilterIP("10.10.10.0-24", "10.10.10.5")
|
||||
assert.NoError(t, err)
|
||||
assert.NilError(t, err)
|
||||
assert.Equal(t, true, ok)
|
||||
|
||||
// CIDR non-match IPv4
|
||||
ok, err = utils.FilterIP("10.10.0.0/24", "10.5.0.1")
|
||||
assert.NoError(t, err)
|
||||
assert.NilError(t, err)
|
||||
assert.Equal(t, false, ok)
|
||||
|
||||
// Invalid CIDR
|
||||
@@ -145,5 +145,5 @@ func TestGenerateUUID(t *testing.T) {
|
||||
|
||||
// Different output for different input
|
||||
id3 := utils.GenerateUUID("differentstring")
|
||||
assert.NotEqual(t, id2, id3)
|
||||
assert.Assert(t, id1 != id3)
|
||||
}
|
||||
|
||||
@@ -3,8 +3,9 @@ package utils_test
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||
|
||||
"gotest.tools/v3/assert"
|
||||
)
|
||||
|
||||
func TestCapitalize(t *testing.T) {
|
||||
|
||||
@@ -5,75 +5,75 @@ import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"github.com/tinyauthapp/tinyauth/internal/config"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"gotest.tools/v3/assert"
|
||||
)
|
||||
|
||||
func TestNewLogger(t *testing.T) {
|
||||
cfg := model.LogConfig{
|
||||
cfg := config.LogConfig{
|
||||
Level: "debug",
|
||||
Json: true,
|
||||
Streams: model.LogStreams{
|
||||
HTTP: model.LogStreamConfig{Enabled: true, Level: "info"},
|
||||
App: model.LogStreamConfig{Enabled: true, Level: ""},
|
||||
Audit: model.LogStreamConfig{Enabled: false, Level: ""},
|
||||
Streams: config.LogStreams{
|
||||
HTTP: config.LogStreamConfig{Enabled: true, Level: "info"},
|
||||
App: config.LogStreamConfig{Enabled: true, Level: ""},
|
||||
Audit: config.LogStreamConfig{Enabled: false, Level: ""},
|
||||
},
|
||||
}
|
||||
|
||||
logger := tlog.NewLogger(cfg)
|
||||
|
||||
assert.NotNil(t, logger)
|
||||
assert.Equal(t, zerolog.InfoLevel, logger.HTTP.GetLevel())
|
||||
assert.Equal(t, zerolog.DebugLevel, logger.App.GetLevel())
|
||||
assert.Equal(t, zerolog.Disabled, logger.Audit.GetLevel())
|
||||
assert.Assert(t, logger != nil)
|
||||
assert.Assert(t, logger.HTTP.GetLevel() == zerolog.InfoLevel)
|
||||
assert.Assert(t, logger.App.GetLevel() == zerolog.DebugLevel)
|
||||
assert.Assert(t, logger.Audit.GetLevel() == zerolog.Disabled)
|
||||
}
|
||||
|
||||
func TestNewSimpleLogger(t *testing.T) {
|
||||
logger := tlog.NewSimpleLogger()
|
||||
assert.NotNil(t, logger)
|
||||
assert.Equal(t, zerolog.InfoLevel, logger.HTTP.GetLevel())
|
||||
assert.Equal(t, zerolog.InfoLevel, logger.App.GetLevel())
|
||||
assert.Equal(t, zerolog.Disabled, logger.Audit.GetLevel())
|
||||
assert.Assert(t, logger != nil)
|
||||
assert.Assert(t, logger.HTTP.GetLevel() == zerolog.InfoLevel)
|
||||
assert.Assert(t, logger.App.GetLevel() == zerolog.InfoLevel)
|
||||
assert.Assert(t, logger.Audit.GetLevel() == zerolog.Disabled)
|
||||
}
|
||||
|
||||
func TestLoggerInit(t *testing.T) {
|
||||
logger := tlog.NewSimpleLogger()
|
||||
logger.Init()
|
||||
|
||||
assert.NotEqual(t, zerolog.Disabled, tlog.App.GetLevel())
|
||||
assert.Assert(t, tlog.App.GetLevel() != zerolog.Disabled)
|
||||
}
|
||||
|
||||
func TestLoggerWithDisabledStreams(t *testing.T) {
|
||||
cfg := model.LogConfig{
|
||||
cfg := config.LogConfig{
|
||||
Level: "info",
|
||||
Json: false,
|
||||
Streams: model.LogStreams{
|
||||
HTTP: model.LogStreamConfig{Enabled: false},
|
||||
App: model.LogStreamConfig{Enabled: false},
|
||||
Audit: model.LogStreamConfig{Enabled: false},
|
||||
Streams: config.LogStreams{
|
||||
HTTP: config.LogStreamConfig{Enabled: false},
|
||||
App: config.LogStreamConfig{Enabled: false},
|
||||
Audit: config.LogStreamConfig{Enabled: false},
|
||||
},
|
||||
}
|
||||
|
||||
logger := tlog.NewLogger(cfg)
|
||||
|
||||
assert.Equal(t, zerolog.Disabled, logger.HTTP.GetLevel())
|
||||
assert.Equal(t, zerolog.Disabled, logger.App.GetLevel())
|
||||
assert.Equal(t, zerolog.Disabled, logger.Audit.GetLevel())
|
||||
assert.Assert(t, logger.HTTP.GetLevel() == zerolog.Disabled)
|
||||
assert.Assert(t, logger.App.GetLevel() == zerolog.Disabled)
|
||||
assert.Assert(t, logger.Audit.GetLevel() == zerolog.Disabled)
|
||||
}
|
||||
|
||||
func TestLogStreamField(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
|
||||
cfg := model.LogConfig{
|
||||
cfg := config.LogConfig{
|
||||
Level: "info",
|
||||
Json: true,
|
||||
Streams: model.LogStreams{
|
||||
HTTP: model.LogStreamConfig{Enabled: true},
|
||||
App: model.LogStreamConfig{Enabled: true},
|
||||
Audit: model.LogStreamConfig{Enabled: true},
|
||||
Streams: config.LogStreams{
|
||||
HTTP: config.LogStreamConfig{Enabled: true},
|
||||
App: config.LogStreamConfig{Enabled: true},
|
||||
Audit: config.LogStreamConfig{Enabled: true},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -86,7 +86,7 @@ func TestLogStreamField(t *testing.T) {
|
||||
|
||||
var logEntry map[string]interface{}
|
||||
err := json.Unmarshal(buf.Bytes(), &logEntry)
|
||||
assert.NoError(t, err)
|
||||
assert.NilError(t, err)
|
||||
|
||||
assert.Equal(t, "http", logEntry["log_stream"])
|
||||
assert.Equal(t, "test message", logEntry["message"])
|
||||
|
||||
@@ -37,7 +37,7 @@ func GetUsers(usersCfg []string, usersPath string, userAttributes map[string]mod
|
||||
var usersStr []string
|
||||
|
||||
if len(usersCfg) == 0 && usersPath == "" {
|
||||
return nil, nil
|
||||
return &[]model.LocalUser{}, nil
|
||||
}
|
||||
|
||||
if len(usersCfg) > 0 {
|
||||
|
||||
@@ -4,76 +4,74 @@ import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"github.com/tinyauthapp/tinyauth/internal/config"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||
|
||||
"gotest.tools/v3/assert"
|
||||
)
|
||||
|
||||
func TestGetUsers(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
hash := "$2a$10$Mz5xhkfSJUtPWkzCd/TdaePh9CaXc5QcGII5wIMPLSR46eTwma30G"
|
||||
|
||||
// Setup
|
||||
file, err := os.Create(tmpDir + "/tinyauth_users_test.txt")
|
||||
require.NoError(t, err)
|
||||
file, err := os.Create("/tmp/tinyauth_users_test.txt")
|
||||
assert.NilError(t, err)
|
||||
|
||||
_, err = file.WriteString(" user1:" + hash + " \n user2:" + hash + " ") // Spacing is on purpose
|
||||
require.NoError(t, err)
|
||||
assert.NilError(t, err)
|
||||
|
||||
err = file.Close()
|
||||
require.NoError(t, err)
|
||||
defer os.Remove(tmpDir + "/tinyauth_users_test.txt")
|
||||
assert.NilError(t, err)
|
||||
defer os.Remove("/tmp/tinyauth_users_test.txt")
|
||||
|
||||
noAttrs := map[string]model.UserAttributes{}
|
||||
noAttrs := map[string]config.UserAttributes{}
|
||||
|
||||
// Test file only
|
||||
users, err := utils.GetUsers([]string{}, tmpDir+"/tinyauth_users_test.txt", noAttrs)
|
||||
users, err := utils.GetUsers([]string{}, "/tmp/tinyauth_users_test.txt", noAttrs)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, users)
|
||||
assert.Len(t, *users, 2)
|
||||
assert.NilError(t, err)
|
||||
|
||||
assert.Equal(t, "user1", (*users)[0].Username)
|
||||
assert.Equal(t, hash, (*users)[0].Password)
|
||||
assert.Equal(t, "user2", (*users)[1].Username)
|
||||
assert.Equal(t, hash, (*users)[1].Password)
|
||||
assert.Equal(t, 2, len(users))
|
||||
|
||||
assert.Equal(t, "user1", users[0].Username)
|
||||
assert.Equal(t, hash, users[0].Password)
|
||||
assert.Equal(t, "user2", users[1].Username)
|
||||
assert.Equal(t, hash, users[1].Password)
|
||||
|
||||
// Test inline config only
|
||||
users, err = utils.GetUsers([]string{"user3:" + hash, "user4:" + hash}, "", noAttrs)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.NilError(t, err)
|
||||
|
||||
assert.Len(t, *users, 2)
|
||||
assert.Equal(t, "user3", (*users)[0].Username)
|
||||
assert.Equal(t, "user4", (*users)[1].Username)
|
||||
assert.Equal(t, 2, len(users))
|
||||
assert.Equal(t, "user3", users[0].Username)
|
||||
assert.Equal(t, "user4", users[1].Username)
|
||||
|
||||
// Test both
|
||||
users, err = utils.GetUsers([]string{"user5:" + hash}, tmpDir+"/tinyauth_users_test.txt", noAttrs)
|
||||
users, err = utils.GetUsers([]string{"user5:" + hash}, "/tmp/tinyauth_users_test.txt", noAttrs)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.NilError(t, err)
|
||||
|
||||
assert.Len(t, *users, 3)
|
||||
assert.Equal(t, 3, len(users))
|
||||
|
||||
usernames := map[string]bool{}
|
||||
for _, u := range *users {
|
||||
for _, u := range users {
|
||||
usernames[u.Username] = true
|
||||
}
|
||||
assert.True(t, usernames["user1"])
|
||||
assert.True(t, usernames["user2"])
|
||||
assert.True(t, usernames["user5"])
|
||||
assert.Assert(t, usernames["user1"])
|
||||
assert.Assert(t, usernames["user2"])
|
||||
assert.Assert(t, usernames["user5"])
|
||||
|
||||
// Test attributes applied from userAttributes map
|
||||
attrs := map[string]model.UserAttributes{
|
||||
attrs := map[string]config.UserAttributes{
|
||||
"user1": {Name: "User One", Email: "user1@example.com"},
|
||||
}
|
||||
users, err = utils.GetUsers([]string{}, tmpDir+"/tinyauth_users_test.txt", attrs)
|
||||
users, err = utils.GetUsers([]string{}, "/tmp/tinyauth_users_test.txt", attrs)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, *users, 2)
|
||||
assert.NilError(t, err)
|
||||
assert.Equal(t, 2, len(users))
|
||||
|
||||
for _, u := range *users {
|
||||
for _, u := range users {
|
||||
if u.Username == "user1" {
|
||||
assert.Equal(t, "User One", u.Attributes.Name)
|
||||
assert.Equal(t, "user1@example.com", u.Attributes.Email)
|
||||
@@ -86,14 +84,16 @@ func TestGetUsers(t *testing.T) {
|
||||
// Test empty
|
||||
users, err = utils.GetUsers([]string{}, "", noAttrs)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, users)
|
||||
assert.NilError(t, err)
|
||||
|
||||
assert.Equal(t, 0, len(users))
|
||||
|
||||
// Test non-existent file
|
||||
users, err = utils.GetUsers([]string{}, tmpDir+"/non_existent_file.txt", noAttrs)
|
||||
users, err = utils.GetUsers([]string{}, "/tmp/non_existent_file.txt", noAttrs)
|
||||
|
||||
assert.ErrorContains(t, err, "no such file or directory")
|
||||
assert.Nil(t, users)
|
||||
|
||||
assert.Equal(t, 0, len(users))
|
||||
}
|
||||
|
||||
func TestParseUser(t *testing.T) {
|
||||
@@ -102,38 +102,38 @@ func TestParseUser(t *testing.T) {
|
||||
// Valid user without TOTP
|
||||
user, err := utils.ParseUser("user1:" + hash)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.NilError(t, err)
|
||||
|
||||
assert.Equal(t, "user1", user.Username)
|
||||
assert.Equal(t, hash, user.Password)
|
||||
assert.Equal(t, "", user.TOTPSecret)
|
||||
assert.Equal(t, "", user.TotpSecret)
|
||||
|
||||
// Valid user with TOTP
|
||||
user, err = utils.ParseUser("user2:" + hash + ":ABCDEF")
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.NilError(t, err)
|
||||
|
||||
assert.Equal(t, "user2", user.Username)
|
||||
assert.Equal(t, hash, user.Password)
|
||||
assert.Equal(t, "ABCDEF", user.TOTPSecret)
|
||||
assert.Equal(t, "ABCDEF", user.TotpSecret)
|
||||
|
||||
// Valid user with $$ in password
|
||||
user, err = utils.ParseUser("user3:pa$$word123")
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.NilError(t, err)
|
||||
|
||||
assert.Equal(t, "user3", user.Username)
|
||||
assert.Equal(t, "pa$word123", user.Password)
|
||||
assert.Equal(t, "", user.TOTPSecret)
|
||||
assert.Equal(t, "", user.TotpSecret)
|
||||
|
||||
// User with spaces
|
||||
user, err = utils.ParseUser(" user4 : password123 : TOTPSECRET ")
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.NilError(t, err)
|
||||
|
||||
assert.Equal(t, "user4", user.Username)
|
||||
assert.Equal(t, "password123", user.Password)
|
||||
assert.Equal(t, "TOTPSECRET", user.TOTPSecret)
|
||||
assert.Equal(t, "TOTPSECRET", user.TotpSecret)
|
||||
|
||||
// Invalid users
|
||||
_, err = utils.ParseUser("user1") // Missing password
|
||||
|
||||
Reference in New Issue
Block a user