mirror of
https://github.com/steveiliop56/tinyauth.git
synced 2026-06-17 17:00:14 +00:00
Compare commits
13 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| c380123602 | |||
| 2cec88799e | |||
| 274069c790 | |||
| ce1ed6207d | |||
| 4387ebcf5a | |||
| e48f9d2517 | |||
| e40f6b50a0 | |||
| addc60d59c | |||
| 53af1b99c0 | |||
| 654b5cc436 | |||
| f7d7f1c4f0 | |||
| e7d26f497d | |||
| a9face749d |
@@ -1,4 +1,4 @@
|
|||||||
package controller_test
|
package controller
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
@@ -9,7 +9,6 @@ import (
|
|||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/controller"
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/test"
|
"github.com/tinyauthapp/tinyauth/internal/test"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||||
@@ -33,22 +32,22 @@ func TestContextController(t *testing.T) {
|
|||||||
middlewares: []gin.HandlerFunc{},
|
middlewares: []gin.HandlerFunc{},
|
||||||
path: "/api/context/app",
|
path: "/api/context/app",
|
||||||
expected: func() string {
|
expected: func() string {
|
||||||
expectedAppContextResponse := controller.AppContextResponse{
|
expectedAppContextResponse := AppContextResponse{
|
||||||
Status: 200,
|
Status: 200,
|
||||||
Message: "Success",
|
Message: "Success",
|
||||||
Auth: controller.ACRAuth{
|
Auth: ACRAuth{
|
||||||
Providers: runtime.ConfiguredProviders,
|
Providers: runtime.ConfiguredProviders,
|
||||||
},
|
},
|
||||||
OAuth: controller.ACROAuth{
|
OAuth: ACROAuth{
|
||||||
AutoRedirect: cfg.OAuth.AutoRedirect,
|
AutoRedirect: cfg.OAuth.AutoRedirect,
|
||||||
},
|
},
|
||||||
UI: controller.ACRUI{
|
UI: ACRUI{
|
||||||
Title: cfg.UI.Title,
|
Title: cfg.UI.Title,
|
||||||
ForgotPasswordMessage: cfg.UI.ForgotPasswordMessage,
|
ForgotPasswordMessage: cfg.UI.ForgotPasswordMessage,
|
||||||
BackgroundImage: cfg.UI.BackgroundImage,
|
BackgroundImage: cfg.UI.BackgroundImage,
|
||||||
WarningsEnabled: cfg.UI.WarningsEnabled,
|
WarningsEnabled: cfg.UI.WarningsEnabled,
|
||||||
},
|
},
|
||||||
App: controller.ACRApp{
|
App: ACRApp{
|
||||||
AppURL: runtime.AppURL,
|
AppURL: runtime.AppURL,
|
||||||
CookieDomain: runtime.CookieDomain,
|
CookieDomain: runtime.CookieDomain,
|
||||||
TrustedDomains: runtime.TrustedDomains,
|
TrustedDomains: runtime.TrustedDomains,
|
||||||
@@ -64,7 +63,7 @@ func TestContextController(t *testing.T) {
|
|||||||
middlewares: []gin.HandlerFunc{},
|
middlewares: []gin.HandlerFunc{},
|
||||||
path: "/api/context/user",
|
path: "/api/context/user",
|
||||||
expected: func() string {
|
expected: func() string {
|
||||||
expectedUserContextResponse := controller.UserContextResponse{
|
expectedUserContextResponse := UserContextResponse{
|
||||||
Status: 401,
|
Status: 401,
|
||||||
Message: "Unauthorized",
|
Message: "Unauthorized",
|
||||||
}
|
}
|
||||||
@@ -92,10 +91,10 @@ func TestContextController(t *testing.T) {
|
|||||||
},
|
},
|
||||||
path: "/api/context/user",
|
path: "/api/context/user",
|
||||||
expected: func() string {
|
expected: func() string {
|
||||||
expectedUserContextResponse := controller.UserContextResponse{
|
expectedUserContextResponse := UserContextResponse{
|
||||||
Status: 200,
|
Status: 200,
|
||||||
Message: "Success",
|
Message: "Success",
|
||||||
Auth: controller.UCRAuth{
|
Auth: UCRAuth{
|
||||||
Authenticated: true,
|
Authenticated: true,
|
||||||
Username: "johndoe",
|
Username: "johndoe",
|
||||||
Name: "John Doe",
|
Name: "John Doe",
|
||||||
@@ -121,7 +120,7 @@ func TestContextController(t *testing.T) {
|
|||||||
group := router.Group("/api")
|
group := router.Group("/api")
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
controller.NewContextController(controller.ContextControllerInput{
|
NewContextController(ContextControllerInput{
|
||||||
Log: log,
|
Log: log,
|
||||||
Config: &cfg,
|
Config: &cfg,
|
||||||
Runtime: &runtime,
|
Runtime: &runtime,
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
package controller_test
|
package controller
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
@@ -9,7 +9,6 @@ import (
|
|||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/controller"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestHealthController(t *testing.T) {
|
func TestHealthController(t *testing.T) {
|
||||||
@@ -55,7 +54,7 @@ func TestHealthController(t *testing.T) {
|
|||||||
group := router.Group("/api")
|
group := router.Group("/api")
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
controller.NewHealthController(controller.HealthControllerInput{
|
NewHealthController(HealthControllerInput{
|
||||||
RouterGroup: group,
|
RouterGroup: group,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package controller
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -11,6 +12,7 @@ import (
|
|||||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||||
|
"github.com/weppos/publicsuffix-go/publicsuffix"
|
||||||
"go.uber.org/dig"
|
"go.uber.org/dig"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@@ -80,9 +82,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !controller.isOidcRequest(reqParams) {
|
if !controller.isOidcRequest(reqParams) {
|
||||||
isRedirectSafe := utils.IsRedirectSafe(reqParams.RedirectURI, controller.runtime.CookieDomain)
|
if !controller.isRedirectSafe(reqParams.RedirectURI) {
|
||||||
|
|
||||||
if !isRedirectSafe {
|
|
||||||
controller.log.App.Warn().Str("redirectUri", reqParams.RedirectURI).Msg("Unsafe redirect URI, ignoring")
|
controller.log.App.Warn().Str("redirectUri", reqParams.RedirectURI).Msg("Unsafe redirect URI, ignoring")
|
||||||
reqParams.RedirectURI = ""
|
reqParams.RedirectURI = ""
|
||||||
}
|
}
|
||||||
@@ -310,3 +310,56 @@ func (controller *OAuthController) getCookieDomain() string {
|
|||||||
}
|
}
|
||||||
return controller.runtime.CookieDomain
|
return controller.runtime.CookieDomain
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (controller *OAuthController) isRedirectSafe(redirectURI string) bool {
|
||||||
|
u, err := url.Parse(redirectURI)
|
||||||
|
|
||||||
|
if err != nil || u.Host == "" || u.Scheme == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, allowed := range controller.runtime.TrustedDomains {
|
||||||
|
tu, err := url.Parse(allowed)
|
||||||
|
if err != nil {
|
||||||
|
controller.log.App.Error().Err(err).Str("allowed", allowed).Msg("Failed to parse trusted domain")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if tu.Scheme != u.Scheme {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// exact match
|
||||||
|
if strings.EqualFold(u.Host, tu.Host) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// if subdomains are disabled, end here
|
||||||
|
if !controller.config.Auth.SubdomainsEnabled {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// get the root domain (e.g. tinyauth.example.com -> example.com or
|
||||||
|
// tinyauth.sub.example.com -> sub.example.com)
|
||||||
|
_, root, ok := strings.Cut(tu.Host, ".")
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
root = strings.ToLower(root)
|
||||||
|
|
||||||
|
// check if the root domain is in the psl
|
||||||
|
_, err = publicsuffix.DomainFromListWithOptions(publicsuffix.DefaultList, root, nil)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// subdomain match
|
||||||
|
if strings.HasSuffix(strings.ToLower(u.Host), "."+root) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|||||||
@@ -0,0 +1,161 @@
|
|||||||
|
package controller
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/test"
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestOAuthController(t *testing.T) {
|
||||||
|
log := logger.NewLogger().WithTestConfig()
|
||||||
|
log.Init()
|
||||||
|
|
||||||
|
cfg, runtime := test.CreateTestConfigs(t)
|
||||||
|
|
||||||
|
type testCase struct {
|
||||||
|
description string
|
||||||
|
run func(ctrl *OAuthController)
|
||||||
|
trustedDomains []string
|
||||||
|
subdomainsEnabled bool
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []testCase{
|
||||||
|
{
|
||||||
|
description: "Test exact match of redirect URI",
|
||||||
|
trustedDomains: []string{"https://tinyauth.example.com"},
|
||||||
|
subdomainsEnabled: true,
|
||||||
|
run: func(ctrl *OAuthController) {
|
||||||
|
redirectUri := "https://tinyauth.example.com"
|
||||||
|
assert.True(t, ctrl.isRedirectSafe(redirectUri))
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Test subdomain match of redirect URI",
|
||||||
|
trustedDomains: []string{"https://tinyauth.example.com"},
|
||||||
|
subdomainsEnabled: true,
|
||||||
|
run: func(ctrl *OAuthController) {
|
||||||
|
redirectUri := "https://sub.example.com"
|
||||||
|
assert.True(t, ctrl.isRedirectSafe(redirectUri))
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Test different trusted domain",
|
||||||
|
trustedDomains: []string{"https://tinyauth.example.com", "https://tinyauth.foo.com"},
|
||||||
|
subdomainsEnabled: true,
|
||||||
|
run: func(ctrl *OAuthController) {
|
||||||
|
redirectUri := "https://app.foo.com"
|
||||||
|
assert.True(t, ctrl.isRedirectSafe(redirectUri))
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Test invalid redirect URI",
|
||||||
|
run: func(ctrl *OAuthController) {
|
||||||
|
redirectUri := "https:/malicious"
|
||||||
|
assert.False(t, ctrl.isRedirectSafe(redirectUri))
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Test empty redirect URI",
|
||||||
|
run: func(ctrl *OAuthController) {
|
||||||
|
redirectUri := ""
|
||||||
|
assert.False(t, ctrl.isRedirectSafe(redirectUri))
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Test redirect URI with different scheme",
|
||||||
|
trustedDomains: []string{"https://tinyauth.example.com"},
|
||||||
|
subdomainsEnabled: true,
|
||||||
|
run: func(ctrl *OAuthController) {
|
||||||
|
redirectUri := "http://tinyauth.example.com"
|
||||||
|
assert.False(t, ctrl.isRedirectSafe(redirectUri))
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Test redirect URI with different port",
|
||||||
|
trustedDomains: []string{"https://tinyauth.example.com"},
|
||||||
|
subdomainsEnabled: true,
|
||||||
|
run: func(ctrl *OAuthController) {
|
||||||
|
redirectUri := "https://tinyauth.example.com:8080"
|
||||||
|
assert.False(t, ctrl.isRedirectSafe(redirectUri))
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
// weird case, subdomains enabled and domain without subdomain can't happen
|
||||||
|
description: "Test with trusted domain that's in PSL when split",
|
||||||
|
trustedDomains: []string{"https://example.com"}, // will become .com which we
|
||||||
|
// obviously don't want to allow
|
||||||
|
subdomainsEnabled: true,
|
||||||
|
run: func(ctrl *OAuthController) {
|
||||||
|
redirectUri := "https://sub.example.com"
|
||||||
|
assert.False(t, ctrl.isRedirectSafe(redirectUri))
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Test subdomain redirect URI when subdomains are disabled",
|
||||||
|
trustedDomains: []string{"https://tinyauth.example.com"},
|
||||||
|
subdomainsEnabled: false,
|
||||||
|
run: func(ctrl *OAuthController) {
|
||||||
|
redirectUri := "https://sub.tinyauth.example.com"
|
||||||
|
assert.False(t, ctrl.isRedirectSafe(redirectUri))
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Test domain like the .co.uk",
|
||||||
|
trustedDomains: []string{"https://example.co.uk"},
|
||||||
|
subdomainsEnabled: true,
|
||||||
|
run: func(ctrl *OAuthController) {
|
||||||
|
redirectUri := "https://sub.example.co.uk"
|
||||||
|
assert.False(t, ctrl.isRedirectSafe(redirectUri))
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Test domain like the .co.uk with subdomains disabled",
|
||||||
|
trustedDomains: []string{"https://example.co.uk"},
|
||||||
|
subdomainsEnabled: false,
|
||||||
|
run: func(ctrl *OAuthController) {
|
||||||
|
redirectUri := "https://example.co.uk"
|
||||||
|
assert.True(t, ctrl.isRedirectSafe(redirectUri))
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Test caps domain",
|
||||||
|
trustedDomains: []string{"https://TINYAUTH.ExAmpLe.com"},
|
||||||
|
subdomainsEnabled: true,
|
||||||
|
run: func(ctrl *OAuthController) {
|
||||||
|
redirectUri := "https://sUb.ExAmPle.com"
|
||||||
|
assert.True(t, ctrl.isRedirectSafe(redirectUri))
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Test edge case with @",
|
||||||
|
trustedDomains: []string{"https://tinyauth.example.com"},
|
||||||
|
subdomainsEnabled: true,
|
||||||
|
run: func(ctrl *OAuthController) {
|
||||||
|
redirectUri := "https://malicious.example.com@evil.com"
|
||||||
|
assert.False(t, ctrl.isRedirectSafe(redirectUri))
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: add auth service
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.description, func(t *testing.T) {
|
||||||
|
router := gin.Default()
|
||||||
|
group := router.Group("/api")
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
// overwrite the trusted domains and subdomain setting for each test case
|
||||||
|
runtime.TrustedDomains = tc.trustedDomains
|
||||||
|
cfg.Auth.SubdomainsEnabled = tc.subdomainsEnabled
|
||||||
|
ctrl := NewOAuthController(OAuthControllerInput{
|
||||||
|
Log: log,
|
||||||
|
Config: &cfg,
|
||||||
|
RuntimeConfig: &runtime,
|
||||||
|
RouterGroup: group,
|
||||||
|
})
|
||||||
|
tc.run(ctrl)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package controller_test
|
package controller
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@@ -15,7 +15,6 @@ import (
|
|||||||
"github.com/steveiliop56/ding"
|
"github.com/steveiliop56/ding"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/controller"
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
|
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
|
||||||
@@ -45,7 +44,7 @@ func TestOIDCController(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Middleware that injects an authenticated local user into the gin context,
|
// Middleware that injects an authenticated local user into the gin context,
|
||||||
// mimicking the context middleware that runs before the OIDC controller.
|
// mimicking the context middleware that runs before the OIDC
|
||||||
authedUser := func(c *gin.Context) {
|
authedUser := func(c *gin.Context) {
|
||||||
c.Set("context", &model.UserContext{
|
c.Set("context", &model.UserContext{
|
||||||
Authenticated: true,
|
Authenticated: true,
|
||||||
@@ -210,10 +209,30 @@ func TestOIDCController(t *testing.T) {
|
|||||||
},
|
},
|
||||||
|
|
||||||
// --- authorize-complete ---
|
// --- authorize-complete ---
|
||||||
|
{
|
||||||
|
description: "Should fail if oidc is disabled",
|
||||||
|
oidcDisabled: true,
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
body, err := json.Marshal(AuthorizeCompleteRequest{Ticket: "some-ticket"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
req := httptest.NewRequest("POST", "/api/oidc/authorize-complete", strings.NewReader(string(body)))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||||
|
|
||||||
|
var res map[string]any
|
||||||
|
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &res))
|
||||||
|
redirectURI, ok := res["redirect_uri"].(string)
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Contains(t, redirectURI, oidcService.GetIssuer()+"/error")
|
||||||
|
},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
description: "Authorize complete returns a JSON error when the user context is missing",
|
description: "Authorize complete returns a JSON error when the user context is missing",
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
body, err := json.Marshal(controller.AuthorizeCompleteRequest{Ticket: "some-ticket"})
|
body, err := json.Marshal(AuthorizeCompleteRequest{Ticket: "some-ticket"})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "/api/oidc/authorize-complete", strings.NewReader(string(body)))
|
req := httptest.NewRequest("POST", "/api/oidc/authorize-complete", strings.NewReader(string(body)))
|
||||||
@@ -243,7 +262,7 @@ func TestOIDCController(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
body, err := json.Marshal(controller.AuthorizeCompleteRequest{Ticket: "some-ticket"})
|
body, err := json.Marshal(AuthorizeCompleteRequest{Ticket: "some-ticket"})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "/api/oidc/authorize-complete", strings.NewReader(string(body)))
|
req := httptest.NewRequest("POST", "/api/oidc/authorize-complete", strings.NewReader(string(body)))
|
||||||
@@ -263,7 +282,7 @@ func TestOIDCController(t *testing.T) {
|
|||||||
description: "Authorize complete returns a JSON error when the ticket is invalid",
|
description: "Authorize complete returns a JSON error when the ticket is invalid",
|
||||||
middlewares: []gin.HandlerFunc{authedUser},
|
middlewares: []gin.HandlerFunc{authedUser},
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
body, err := json.Marshal(controller.AuthorizeCompleteRequest{Ticket: "nonexistent-ticket"})
|
body, err := json.Marshal(AuthorizeCompleteRequest{Ticket: "nonexistent-ticket"})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "/api/oidc/authorize-complete", strings.NewReader(string(body)))
|
req := httptest.NewRequest("POST", "/api/oidc/authorize-complete", strings.NewReader(string(body)))
|
||||||
@@ -291,7 +310,7 @@ func TestOIDCController(t *testing.T) {
|
|||||||
State: "state-123",
|
State: "state-123",
|
||||||
})
|
})
|
||||||
|
|
||||||
body, err := json.Marshal(controller.AuthorizeCompleteRequest{Ticket: ticket})
|
body, err := json.Marshal(AuthorizeCompleteRequest{Ticket: ticket})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "/api/oidc/authorize-complete", strings.NewReader(string(body)))
|
req := httptest.NewRequest("POST", "/api/oidc/authorize-complete", strings.NewReader(string(body)))
|
||||||
@@ -837,7 +856,7 @@ func TestOIDCController(t *testing.T) {
|
|||||||
svc = nil
|
svc = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
controller.NewOIDCController(controller.OIDCControllerInput{
|
NewOIDCController(OIDCControllerInput{
|
||||||
Log: log,
|
Log: log,
|
||||||
OIDCService: svc,
|
OIDCService: svc,
|
||||||
RuntimeConfig: &runtime,
|
RuntimeConfig: &runtime,
|
||||||
|
|||||||
@@ -158,7 +158,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Redirect(http.StatusTemporaryRedirect, redirectURL)
|
c.Redirect(http.StatusFound, redirectURL)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -207,7 +207,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Redirect(http.StatusTemporaryRedirect, redirectURL)
|
c.Redirect(http.StatusFound, redirectURL)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -251,7 +251,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Redirect(http.StatusTemporaryRedirect, redirectURL)
|
c.Redirect(http.StatusFound, redirectURL)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -300,7 +300,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Redirect(http.StatusTemporaryRedirect, redirectURL)
|
c.Redirect(http.StatusFound, redirectURL)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (controller *ProxyController) setHeaders(c *gin.Context, acls *model.App) {
|
func (controller *ProxyController) setHeaders(c *gin.Context, acls *model.App) {
|
||||||
@@ -336,7 +336,7 @@ func (controller *ProxyController) handleError(c *gin.Context, proxyCtx ProxyCon
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Redirect(http.StatusTemporaryRedirect, redirectURL)
|
c.Redirect(http.StatusFound, redirectURL)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (controller *ProxyController) getHeader(c *gin.Context, header string) (string, bool) {
|
func (controller *ProxyController) getHeader(c *gin.Context, header string) (string, bool) {
|
||||||
|
|||||||
@@ -1,7 +1,10 @@
|
|||||||
package controller_test
|
package controller
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"net/url"
|
"net/url"
|
||||||
"testing"
|
"testing"
|
||||||
@@ -10,7 +13,6 @@ import (
|
|||||||
"github.com/steveiliop56/ding"
|
"github.com/steveiliop56/ding"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/controller"
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
|
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||||
@@ -64,6 +66,17 @@ func TestProxyController(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
tests := []testCase{
|
tests := []testCase{
|
||||||
|
{
|
||||||
|
description: "Should get bad request on invalid proxy",
|
||||||
|
middlewares: []gin.HandlerFunc{},
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
req := httptest.NewRequest("GET", "/api/auth/invalid", nil)
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusBadRequest, recorder.Code)
|
||||||
|
assert.Contains(t, recorder.Body.String(), "Bad request")
|
||||||
|
},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
description: "Default forward auth should be detected and used for traefik",
|
description: "Default forward auth should be detected and used for traefik",
|
||||||
middlewares: []gin.HandlerFunc{},
|
middlewares: []gin.HandlerFunc{},
|
||||||
@@ -75,7 +88,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("user-agent", browserUserAgent)
|
req.Header.Set("user-agent", browserUserAgent)
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
assert.Equal(t, 307, recorder.Code)
|
assert.Equal(t, http.StatusFound, recorder.Code)
|
||||||
location := recorder.Header().Get("Location")
|
location := recorder.Header().Get("Location")
|
||||||
assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
|
assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
|
||||||
assert.Contains(t, location, "login_for=app")
|
assert.Contains(t, location, "login_for=app")
|
||||||
@@ -90,7 +103,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("x-original-url", "https://test.example.com/")
|
req.Header.Set("x-original-url", "https://test.example.com/")
|
||||||
req.Header.Set("user-agent", browserUserAgent)
|
req.Header.Set("user-agent", browserUserAgent)
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
assert.Equal(t, 401, recorder.Code)
|
assert.Equal(t, http.StatusUnauthorized, recorder.Code)
|
||||||
location := recorder.Header().Get("x-tinyauth-location")
|
location := recorder.Header().Get("x-tinyauth-location")
|
||||||
assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
|
assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
|
||||||
assert.Contains(t, location, "login_for=app")
|
assert.Contains(t, location, "login_for=app")
|
||||||
@@ -106,7 +119,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("x-forwarded-proto", "https")
|
req.Header.Set("x-forwarded-proto", "https")
|
||||||
req.Header.Set("user-agent", browserUserAgent)
|
req.Header.Set("user-agent", browserUserAgent)
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
assert.Equal(t, 307, recorder.Code)
|
assert.Equal(t, http.StatusFound, recorder.Code)
|
||||||
location := recorder.Header().Get("Location")
|
location := recorder.Header().Get("Location")
|
||||||
assert.Contains(t, location, url.QueryEscape("https://test.example.com/hello"))
|
assert.Contains(t, location, url.QueryEscape("https://test.example.com/hello"))
|
||||||
assert.Contains(t, location, "login_for=app")
|
assert.Contains(t, location, "login_for=app")
|
||||||
@@ -124,7 +137,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("user-agent", browserUserAgent)
|
req.Header.Set("user-agent", browserUserAgent)
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
assert.Equal(t, 307, recorder.Code)
|
assert.Equal(t, http.StatusFound, recorder.Code)
|
||||||
location := recorder.Header().Get("Location")
|
location := recorder.Header().Get("Location")
|
||||||
assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
|
assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
|
||||||
assert.Contains(t, location, "login_for=app")
|
assert.Contains(t, location, "login_for=app")
|
||||||
@@ -141,7 +154,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("x-forwarded-uri", "/")
|
req.Header.Set("x-forwarded-uri", "/")
|
||||||
req.Header.Set("user-agent", browserUserAgent)
|
req.Header.Set("user-agent", browserUserAgent)
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
assert.Equal(t, 401, recorder.Code)
|
assert.Equal(t, http.StatusUnauthorized, recorder.Code)
|
||||||
location := recorder.Header().Get("x-tinyauth-location")
|
location := recorder.Header().Get("x-tinyauth-location")
|
||||||
assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
|
assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
|
||||||
assert.Contains(t, location, "login_for=app")
|
assert.Contains(t, location, "login_for=app")
|
||||||
@@ -159,7 +172,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("x-forwarded-uri", "/hello")
|
req.Header.Set("x-forwarded-uri", "/hello")
|
||||||
req.Header.Set("user-agent", browserUserAgent)
|
req.Header.Set("user-agent", browserUserAgent)
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
assert.Equal(t, 307, recorder.Code)
|
assert.Equal(t, http.StatusFound, recorder.Code)
|
||||||
location := recorder.Header().Get("Location")
|
location := recorder.Header().Get("Location")
|
||||||
assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
|
assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
|
||||||
assert.Contains(t, location, "login_for=app")
|
assert.Contains(t, location, "login_for=app")
|
||||||
@@ -176,7 +189,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("x-forwarded-uri", "/")
|
req.Header.Set("x-forwarded-uri", "/")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
assert.Equal(t, 401, recorder.Code)
|
assert.Equal(t, http.StatusUnauthorized, recorder.Code)
|
||||||
assert.Contains(t, recorder.Body.String(), `"status":401`)
|
assert.Contains(t, recorder.Body.String(), `"status":401`)
|
||||||
assert.Contains(t, recorder.Body.String(), `"message":"Unauthorized"`)
|
assert.Contains(t, recorder.Body.String(), `"message":"Unauthorized"`)
|
||||||
},
|
},
|
||||||
@@ -191,7 +204,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("x-forwarded-uri", "/")
|
req.Header.Set("x-forwarded-uri", "/")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
assert.Equal(t, 401, recorder.Code)
|
assert.Equal(t, http.StatusUnauthorized, recorder.Code)
|
||||||
assert.Contains(t, recorder.Body.String(), `"status":401`)
|
assert.Contains(t, recorder.Body.String(), `"status":401`)
|
||||||
assert.Contains(t, recorder.Body.String(), `"message":"Unauthorized"`)
|
assert.Contains(t, recorder.Body.String(), `"message":"Unauthorized"`)
|
||||||
},
|
},
|
||||||
@@ -206,7 +219,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("x-forwarded-uri", "/hello")
|
req.Header.Set("x-forwarded-uri", "/hello")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
assert.Equal(t, 401, recorder.Code)
|
assert.Equal(t, http.StatusUnauthorized, recorder.Code)
|
||||||
assert.Contains(t, recorder.Body.String(), `"status":401`)
|
assert.Contains(t, recorder.Body.String(), `"status":401`)
|
||||||
assert.Contains(t, recorder.Body.String(), `"message":"Unauthorized"`)
|
assert.Contains(t, recorder.Body.String(), `"message":"Unauthorized"`)
|
||||||
},
|
},
|
||||||
@@ -223,7 +236,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("x-forwarded-uri", "/")
|
req.Header.Set("x-forwarded-uri", "/")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
assert.Equal(t, 200, recorder.Code)
|
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||||
assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
|
assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
|
||||||
assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
|
assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
|
||||||
assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
|
assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
|
||||||
@@ -239,7 +252,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("x-original-url", "https://test.example.com/")
|
req.Header.Set("x-original-url", "https://test.example.com/")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
assert.Equal(t, 200, recorder.Code)
|
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||||
assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
|
assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
|
||||||
assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
|
assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
|
||||||
assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
|
assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
|
||||||
@@ -256,7 +269,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("x-forwarded-proto", "https")
|
req.Header.Set("x-forwarded-proto", "https")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
assert.Equal(t, 200, recorder.Code)
|
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||||
assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
|
assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
|
||||||
assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
|
assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
|
||||||
assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
|
assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
|
||||||
@@ -271,7 +284,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("x-forwarded-proto", "https")
|
req.Header.Set("x-forwarded-proto", "https")
|
||||||
req.Header.Set("x-forwarded-uri", "/allowed")
|
req.Header.Set("x-forwarded-uri", "/allowed")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
assert.Equal(t, 200, recorder.Code)
|
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -281,7 +294,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req := httptest.NewRequest("GET", "/api/auth/nginx", nil)
|
req := httptest.NewRequest("GET", "/api/auth/nginx", nil)
|
||||||
req.Header.Set("x-original-url", "https://path-allow.example.com/allowed")
|
req.Header.Set("x-original-url", "https://path-allow.example.com/allowed")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
assert.Equal(t, 200, recorder.Code)
|
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -292,7 +305,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Host = "path-allow.example.com"
|
req.Host = "path-allow.example.com"
|
||||||
req.Header.Set("x-forwarded-proto", "https")
|
req.Header.Set("x-forwarded-proto", "https")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
assert.Equal(t, 200, recorder.Code)
|
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -305,7 +318,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("x-forwarded-uri", "/")
|
req.Header.Set("x-forwarded-uri", "/")
|
||||||
req.Header.Set("x-forwarded-for", "10.10.10.10")
|
req.Header.Set("x-forwarded-for", "10.10.10.10")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
assert.Equal(t, 200, recorder.Code)
|
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -316,7 +329,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("x-original-url", "https://ip-bypass.example.com/")
|
req.Header.Set("x-original-url", "https://ip-bypass.example.com/")
|
||||||
req.Header.Set("x-forwarded-for", "10.10.10.10")
|
req.Header.Set("x-forwarded-for", "10.10.10.10")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
assert.Equal(t, 200, recorder.Code)
|
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -328,7 +341,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("x-forwarded-proto", "https")
|
req.Header.Set("x-forwarded-proto", "https")
|
||||||
req.Header.Set("x-forwarded-for", "10.10.10.10")
|
req.Header.Set("x-forwarded-for", "10.10.10.10")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
assert.Equal(t, 200, recorder.Code)
|
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -342,7 +355,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("x-forwarded-proto", "https")
|
req.Header.Set("x-forwarded-proto", "https")
|
||||||
req.Header.Set("x-forwarded-uri", "/")
|
req.Header.Set("x-forwarded-uri", "/")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
assert.Equal(t, 200, recorder.Code)
|
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -356,12 +369,301 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("x-forwarded-proto", "https")
|
req.Header.Set("x-forwarded-proto", "https")
|
||||||
req.Header.Set("x-forwarded-uri", "/")
|
req.Header.Set("x-forwarded-uri", "/")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
assert.Equal(t, 403, recorder.Code)
|
assert.Equal(t, http.StatusForbidden, recorder.Code)
|
||||||
assert.Equal(t, "", recorder.Header().Get("remote-user"))
|
assert.Equal(t, "", recorder.Header().Get("remote-user"))
|
||||||
assert.Equal(t, "", recorder.Header().Get("remote-name"))
|
assert.Equal(t, "", recorder.Header().Get("remote-name"))
|
||||||
assert.Equal(t, "", recorder.Header().Get("remote-email"))
|
assert.Equal(t, "", recorder.Header().Get("remote-email"))
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
description: "Test IP block rule, with non browser user agent",
|
||||||
|
middlewares: []gin.HandlerFunc{},
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
|
||||||
|
req.Header.Set("x-forwarded-host", "ip-block.example.com")
|
||||||
|
req.Header.Set("x-forwarded-proto", "https")
|
||||||
|
req.Header.Set("x-forwarded-uri", "/")
|
||||||
|
req.Header.Set("x-forwarded-for", "10.10.10.10")
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
assert.Equal(t, http.StatusForbidden, recorder.Code)
|
||||||
|
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), runtime.AppURL)
|
||||||
|
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "ip=10.10.10.10")
|
||||||
|
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "ip-block")
|
||||||
|
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Test IP block rule, with browser user agent",
|
||||||
|
middlewares: []gin.HandlerFunc{},
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
|
||||||
|
req.Header.Set("x-forwarded-host", "ip-block.example.com")
|
||||||
|
req.Header.Set("x-forwarded-proto", "https")
|
||||||
|
req.Header.Set("x-forwarded-uri", "/")
|
||||||
|
req.Header.Set("x-forwarded-for", "10.10.10.10")
|
||||||
|
req.Header.Set("user-agent", browserUserAgent)
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
assert.Equal(t, http.StatusFound, recorder.Code)
|
||||||
|
location := recorder.Header().Get("Location")
|
||||||
|
assert.Contains(t, location, url.QueryEscape("10.10.10.10"))
|
||||||
|
assert.Contains(t, location, url.QueryEscape("ip-block"))
|
||||||
|
assert.Contains(t, location, runtime.AppURL)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "OAuth allowed group",
|
||||||
|
middlewares: []gin.HandlerFunc{
|
||||||
|
func(ctx *gin.Context) {
|
||||||
|
ctx.Set("context", &model.UserContext{
|
||||||
|
Authenticated: true,
|
||||||
|
Provider: model.ProviderOAuth,
|
||||||
|
OAuth: &model.OAuthContext{
|
||||||
|
BaseContext: model.BaseContext{
|
||||||
|
Username: "testuser",
|
||||||
|
Name: "Testuser",
|
||||||
|
Email: "testuser@example.com",
|
||||||
|
},
|
||||||
|
Groups: []string{"group1"},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
ctx.Next()
|
||||||
|
},
|
||||||
|
},
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
|
||||||
|
req.Header.Set("x-forwarded-host", "oauth-group.example.com")
|
||||||
|
req.Header.Set("x-forwarded-proto", "https")
|
||||||
|
req.Header.Set("x-forwarded-uri", "/")
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||||
|
assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
|
||||||
|
assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
|
||||||
|
assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
|
||||||
|
assert.Equal(t, "group1", recorder.Header().Get("remote-groups"))
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "OAuth not in required groups and non browser",
|
||||||
|
middlewares: []gin.HandlerFunc{
|
||||||
|
func(ctx *gin.Context) {
|
||||||
|
ctx.Set("context", &model.UserContext{
|
||||||
|
Authenticated: true,
|
||||||
|
Provider: model.ProviderOAuth,
|
||||||
|
OAuth: &model.OAuthContext{
|
||||||
|
BaseContext: model.BaseContext{
|
||||||
|
Username: "testuser",
|
||||||
|
Name: "Testuser",
|
||||||
|
Email: "testuser@example.com",
|
||||||
|
},
|
||||||
|
Groups: []string{"group3"},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
ctx.Next()
|
||||||
|
},
|
||||||
|
},
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
|
||||||
|
req.Header.Set("x-forwarded-host", "oauth-group.example.com")
|
||||||
|
req.Header.Set("x-forwarded-proto", "https")
|
||||||
|
req.Header.Set("x-forwarded-uri", "/")
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
assert.Equal(t, http.StatusForbidden, recorder.Code)
|
||||||
|
assert.Equal(t, "", recorder.Header().Get("remote-user"))
|
||||||
|
assert.Equal(t, "", recorder.Header().Get("remote-name"))
|
||||||
|
assert.Equal(t, "", recorder.Header().Get("remote-email"))
|
||||||
|
assert.Equal(t, "", recorder.Header().Get("remote-groups"))
|
||||||
|
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), runtime.AppURL)
|
||||||
|
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "groupErr=true")
|
||||||
|
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "oauth-group")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "OAuth not in required groups and browser",
|
||||||
|
middlewares: []gin.HandlerFunc{
|
||||||
|
func(ctx *gin.Context) {
|
||||||
|
ctx.Set("context", &model.UserContext{
|
||||||
|
Authenticated: true,
|
||||||
|
Provider: model.ProviderOAuth,
|
||||||
|
OAuth: &model.OAuthContext{
|
||||||
|
BaseContext: model.BaseContext{
|
||||||
|
Username: "testuser",
|
||||||
|
Name: "Testuser",
|
||||||
|
Email: "testuser@example.com",
|
||||||
|
},
|
||||||
|
Groups: []string{"group3"},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
ctx.Next()
|
||||||
|
},
|
||||||
|
},
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
|
||||||
|
req.Header.Set("x-forwarded-host", "oauth-group.example.com")
|
||||||
|
req.Header.Set("x-forwarded-proto", "https")
|
||||||
|
req.Header.Set("x-forwarded-uri", "/")
|
||||||
|
req.Header.Set("user-agent", browserUserAgent)
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
assert.Equal(t, http.StatusFound, recorder.Code)
|
||||||
|
location := recorder.Header().Get("Location")
|
||||||
|
assert.Contains(t, location, "groupErr=true")
|
||||||
|
assert.Contains(t, location, "oauth-group")
|
||||||
|
assert.Contains(t, location, runtime.AppURL)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "LDAP allowed group",
|
||||||
|
middlewares: []gin.HandlerFunc{
|
||||||
|
func(ctx *gin.Context) {
|
||||||
|
ctx.Set("context", &model.UserContext{
|
||||||
|
Authenticated: true,
|
||||||
|
Provider: model.ProviderLDAP,
|
||||||
|
LDAP: &model.LDAPContext{
|
||||||
|
BaseContext: model.BaseContext{
|
||||||
|
Username: "testuser",
|
||||||
|
Name: "Testuser",
|
||||||
|
Email: "testuser@example.com",
|
||||||
|
},
|
||||||
|
Groups: []string{"group1"},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
ctx.Next()
|
||||||
|
},
|
||||||
|
},
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
|
||||||
|
req.Header.Set("x-forwarded-host", "ldap-group.example.com")
|
||||||
|
req.Header.Set("x-forwarded-proto", "https")
|
||||||
|
req.Header.Set("x-forwarded-uri", "/")
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||||
|
assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
|
||||||
|
assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
|
||||||
|
assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
|
||||||
|
assert.Equal(t, "group1", recorder.Header().Get("remote-groups"))
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "LDAP not in required groups and non browser",
|
||||||
|
middlewares: []gin.HandlerFunc{
|
||||||
|
func(ctx *gin.Context) {
|
||||||
|
ctx.Set("context", &model.UserContext{
|
||||||
|
Authenticated: true,
|
||||||
|
Provider: model.ProviderLDAP,
|
||||||
|
LDAP: &model.LDAPContext{
|
||||||
|
BaseContext: model.BaseContext{
|
||||||
|
Username: "testuser",
|
||||||
|
Name: "Testuser",
|
||||||
|
Email: "testuser@example.com",
|
||||||
|
},
|
||||||
|
Groups: []string{"group3"},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
ctx.Next()
|
||||||
|
},
|
||||||
|
},
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
|
||||||
|
req.Header.Set("x-forwarded-host", "ldap-group.example.com")
|
||||||
|
req.Header.Set("x-forwarded-proto", "https")
|
||||||
|
req.Header.Set("x-forwarded-uri", "/")
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
assert.Equal(t, http.StatusForbidden, recorder.Code)
|
||||||
|
assert.Equal(t, "", recorder.Header().Get("remote-user"))
|
||||||
|
assert.Equal(t, "", recorder.Header().Get("remote-name"))
|
||||||
|
assert.Equal(t, "", recorder.Header().Get("remote-email"))
|
||||||
|
assert.Equal(t, "", recorder.Header().Get("remote-groups"))
|
||||||
|
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), runtime.AppURL)
|
||||||
|
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "groupErr=true")
|
||||||
|
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "ldap-group")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "LDAP not in required groups and browser",
|
||||||
|
middlewares: []gin.HandlerFunc{
|
||||||
|
func(ctx *gin.Context) {
|
||||||
|
ctx.Set("context", &model.UserContext{
|
||||||
|
Authenticated: true,
|
||||||
|
Provider: model.ProviderLDAP,
|
||||||
|
LDAP: &model.LDAPContext{
|
||||||
|
BaseContext: model.BaseContext{
|
||||||
|
Username: "testuser",
|
||||||
|
Name: "Testuser",
|
||||||
|
Email: "testuser@example.com",
|
||||||
|
},
|
||||||
|
Groups: []string{"group3"},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
ctx.Next()
|
||||||
|
},
|
||||||
|
},
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
|
||||||
|
req.Header.Set("x-forwarded-host", "ldap-group.example.com")
|
||||||
|
req.Header.Set("x-forwarded-proto", "https")
|
||||||
|
req.Header.Set("x-forwarded-uri", "/")
|
||||||
|
req.Header.Set("user-agent", browserUserAgent)
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
assert.Equal(t, http.StatusFound, recorder.Code)
|
||||||
|
location := recorder.Header().Get("Location")
|
||||||
|
assert.Contains(t, location, "groupErr=true")
|
||||||
|
assert.Contains(t, location, "ldap-group")
|
||||||
|
assert.Contains(t, location, runtime.AppURL)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Should add basic auth if it's in ACLs",
|
||||||
|
middlewares: []gin.HandlerFunc{
|
||||||
|
simpleCtx,
|
||||||
|
},
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
|
||||||
|
req.Header.Set("x-forwarded-host", "basic-auth.example.com")
|
||||||
|
req.Header.Set("x-forwarded-proto", "https")
|
||||||
|
req.Header.Set("x-forwarded-uri", "/")
|
||||||
|
req.Header.Set("authorization", "foo") // should be overridden by basic auth
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||||
|
authorizationHeader := recorder.Header().Get("Authorization")
|
||||||
|
assert.NotEmpty(t, authorizationHeader)
|
||||||
|
assert.Equal(t, fmt.Sprintf("Basic %s", base64.StdEncoding.EncodeToString([]byte("test:password"))), authorizationHeader)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Authorization header should be preserved when not basic auth acls",
|
||||||
|
middlewares: []gin.HandlerFunc{
|
||||||
|
simpleCtx,
|
||||||
|
},
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
|
||||||
|
req.Header.Set("x-forwarded-host", "test.example.com")
|
||||||
|
req.Header.Set("x-forwarded-proto", "https")
|
||||||
|
req.Header.Set("x-forwarded-uri", "/")
|
||||||
|
req.Header.Set("authorization", "Bearer mytoken")
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||||
|
authorizationHeader := recorder.Header().Get("Authorization")
|
||||||
|
assert.NotEmpty(t, authorizationHeader)
|
||||||
|
assert.Equal(t, "Bearer mytoken", authorizationHeader)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Should add response headers if present",
|
||||||
|
middlewares: []gin.HandlerFunc{
|
||||||
|
simpleCtx,
|
||||||
|
},
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
|
||||||
|
req.Header.Set("x-forwarded-host", "response-headers.example.com")
|
||||||
|
req.Header.Set("x-forwarded-proto", "https")
|
||||||
|
req.Header.Set("x-forwarded-uri", "/")
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||||
|
assert.Equal(t, "bar", recorder.Header().Get("x-foo"))
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
store := memory.New()
|
store := memory.New()
|
||||||
@@ -432,7 +734,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
|
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
|
|
||||||
controller.NewProxyController(controller.ProxyControllerInput{
|
NewProxyController(ProxyControllerInput{
|
||||||
Log: log,
|
Log: log,
|
||||||
RuntimeConfig: &runtime,
|
RuntimeConfig: &runtime,
|
||||||
RouterGroup: group,
|
RouterGroup: group,
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
package controller_test
|
package controller
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
@@ -9,7 +9,7 @@ import (
|
|||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/controller"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/test"
|
"github.com/tinyauthapp/tinyauth/internal/test"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -19,8 +19,12 @@ func TestResourcesController(t *testing.T) {
|
|||||||
err := os.MkdirAll(cfg.Resources.Path, 0777)
|
err := os.MkdirAll(cfg.Resources.Path, 0777)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// create a "backup" of the original configuration to restore after each test
|
||||||
|
originalCfg := cfg.Resources
|
||||||
|
|
||||||
type testCase struct {
|
type testCase struct {
|
||||||
description string
|
description string
|
||||||
|
customCfg *model.ResourcesConfig
|
||||||
run func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder)
|
run func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -53,6 +57,32 @@ func TestResourcesController(t *testing.T) {
|
|||||||
assert.Equal(t, 404, recorder.Code)
|
assert.Equal(t, 404, recorder.Code)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
description: "Ensure resources controller returns 404 when resources path is empty",
|
||||||
|
customCfg: &model.ResourcesConfig{
|
||||||
|
Path: "",
|
||||||
|
Enabled: true,
|
||||||
|
},
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
req := httptest.NewRequest("GET", "/resources/testfile.txt", nil)
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
assert.Equal(t, 404, recorder.Code)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Ensure resources controller returns 403 when resources are disabled",
|
||||||
|
customCfg: &model.ResourcesConfig{
|
||||||
|
Path: cfg.Resources.Path,
|
||||||
|
Enabled: false,
|
||||||
|
},
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
req := httptest.NewRequest("GET", "/resources/testfile.txt", nil)
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
assert.Equal(t, 403, recorder.Code)
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
testFilePath := cfg.Resources.Path + "/testfile.txt"
|
testFilePath := cfg.Resources.Path + "/testfile.txt"
|
||||||
@@ -69,7 +99,15 @@ func TestResourcesController(t *testing.T) {
|
|||||||
group := router.Group("/")
|
group := router.Group("/")
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
controller.NewResourcesController(controller.ResourcesControllerInput{
|
// if custom configuration is provided, override the default config
|
||||||
|
if test.customCfg != nil {
|
||||||
|
cfg.Resources = *test.customCfg
|
||||||
|
} else {
|
||||||
|
// Reset to default configuration for each test
|
||||||
|
cfg.Resources = originalCfg
|
||||||
|
}
|
||||||
|
|
||||||
|
NewResourcesController(ResourcesControllerInput{
|
||||||
RouterGroup: group,
|
RouterGroup: group,
|
||||||
Config: &cfg,
|
Config: &cfg,
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
package controller_test
|
package controller
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@@ -14,7 +14,6 @@ import (
|
|||||||
"github.com/steveiliop56/ding"
|
"github.com/steveiliop56/ding"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/controller"
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
|
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
|
||||||
@@ -42,6 +41,7 @@ func TestUserController(t *testing.T) {
|
|||||||
TOTPPending: true,
|
TOTPPending: true,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
c.Next()
|
||||||
}
|
}
|
||||||
|
|
||||||
totpAttrCtx := func(c *gin.Context) {
|
totpAttrCtx := func(c *gin.Context) {
|
||||||
@@ -57,6 +57,7 @@ func TestUserController(t *testing.T) {
|
|||||||
TOTPPending: true,
|
TOTPPending: true,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
c.Next()
|
||||||
}
|
}
|
||||||
|
|
||||||
simpleCtx := func(c *gin.Context) {
|
simpleCtx := func(c *gin.Context) {
|
||||||
@@ -71,6 +72,7 @@ func TestUserController(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
c.Next()
|
||||||
}
|
}
|
||||||
|
|
||||||
store := memory.New()
|
store := memory.New()
|
||||||
@@ -82,11 +84,45 @@ func TestUserController(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
tests := []testCase{
|
tests := []testCase{
|
||||||
|
{
|
||||||
|
description: "Login should fail gracefully on invalid json",
|
||||||
|
middlewares: []gin.HandlerFunc{},
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(`{"username": "testuser", "password":`))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
assert.Equal(t, 400, recorder.Code)
|
||||||
|
assert.Contains(t, recorder.Body.String(), "Bad Request")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Should fail on missing user",
|
||||||
|
middlewares: []gin.HandlerFunc{},
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
loginReq := LoginRequest{
|
||||||
|
Username: "nonexistentuser",
|
||||||
|
Password: "password",
|
||||||
|
}
|
||||||
|
loginReqBody, err := json.Marshal(loginReq)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqBody)))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
assert.Equal(t, 401, recorder.Code)
|
||||||
|
assert.Len(t, recorder.Result().Cookies(), 0)
|
||||||
|
assert.Contains(t, recorder.Body.String(), "Unauthorized")
|
||||||
|
},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
description: "Should be able to login with valid credentials",
|
description: "Should be able to login with valid credentials",
|
||||||
middlewares: []gin.HandlerFunc{},
|
middlewares: []gin.HandlerFunc{},
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
loginReq := controller.LoginRequest{
|
loginReq := LoginRequest{
|
||||||
Username: "testuser",
|
Username: "testuser",
|
||||||
Password: "password",
|
Password: "password",
|
||||||
}
|
}
|
||||||
@@ -114,7 +150,7 @@ func TestUserController(t *testing.T) {
|
|||||||
description: "Should reject login with invalid credentials",
|
description: "Should reject login with invalid credentials",
|
||||||
middlewares: []gin.HandlerFunc{},
|
middlewares: []gin.HandlerFunc{},
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
loginReq := controller.LoginRequest{
|
loginReq := LoginRequest{
|
||||||
Username: "testuser",
|
Username: "testuser",
|
||||||
Password: "wrongpassword",
|
Password: "wrongpassword",
|
||||||
}
|
}
|
||||||
@@ -135,7 +171,7 @@ func TestUserController(t *testing.T) {
|
|||||||
description: "Should rate limit on 3 invalid attempts",
|
description: "Should rate limit on 3 invalid attempts",
|
||||||
middlewares: []gin.HandlerFunc{},
|
middlewares: []gin.HandlerFunc{},
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
loginReq := controller.LoginRequest{
|
loginReq := LoginRequest{
|
||||||
Username: "testuser",
|
Username: "testuser",
|
||||||
Password: "wrongpassword",
|
Password: "wrongpassword",
|
||||||
}
|
}
|
||||||
@@ -170,7 +206,7 @@ func TestUserController(t *testing.T) {
|
|||||||
description: "Should not allow full login with totp",
|
description: "Should not allow full login with totp",
|
||||||
middlewares: []gin.HandlerFunc{},
|
middlewares: []gin.HandlerFunc{},
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
loginReq := controller.LoginRequest{
|
loginReq := LoginRequest{
|
||||||
Username: "totpuser",
|
Username: "totpuser",
|
||||||
Password: "password",
|
Password: "password",
|
||||||
}
|
}
|
||||||
@@ -207,7 +243,7 @@ func TestUserController(t *testing.T) {
|
|||||||
},
|
},
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
// First login to get a session cookie
|
// First login to get a session cookie
|
||||||
loginReq := controller.LoginRequest{
|
loginReq := LoginRequest{
|
||||||
Username: "testuser",
|
Username: "testuser",
|
||||||
Password: "password",
|
Password: "password",
|
||||||
}
|
}
|
||||||
@@ -243,6 +279,87 @@ func TestUserController(t *testing.T) {
|
|||||||
assert.Equal(t, -1, cookie.MaxAge) // MaxAge -1 means delete cookie
|
assert.Equal(t, -1, cookie.MaxAge) // MaxAge -1 means delete cookie
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
description: "Logout should be treated as valid without a session cookie",
|
||||||
|
middlewares: []gin.HandlerFunc{},
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
req := httptest.NewRequest("POST", "/api/user/logout", nil)
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
assert.Equal(t, 200, recorder.Code)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "TOTP should gracefully reject invalid json",
|
||||||
|
middlewares: []gin.HandlerFunc{},
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(`{"code":`))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
assert.Equal(t, 400, recorder.Code)
|
||||||
|
assert.Contains(t, recorder.Body.String(), "Bad Request")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "TOTP should fail on non-totp context",
|
||||||
|
middlewares: []gin.HandlerFunc{
|
||||||
|
simpleCtx,
|
||||||
|
},
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
totpReq := TotpRequest{
|
||||||
|
Code: "123456",
|
||||||
|
}
|
||||||
|
|
||||||
|
totpReqBody, err := json.Marshal(totpReq)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
recorder = httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(totpReqBody)))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
assert.Equal(t, 401, recorder.Code)
|
||||||
|
assert.Contains(t, recorder.Body.String(), "Unauthorized")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "TOTP should fail when user in context doesn't exist",
|
||||||
|
middlewares: []gin.HandlerFunc{
|
||||||
|
func(ctx *gin.Context) {
|
||||||
|
ctx.Set("context", &model.UserContext{
|
||||||
|
Authenticated: false,
|
||||||
|
Provider: model.ProviderLocal,
|
||||||
|
Local: &model.LocalContext{
|
||||||
|
BaseContext: model.BaseContext{
|
||||||
|
Username: "idontexist",
|
||||||
|
Name: "Totpuser",
|
||||||
|
Email: "totpuser@example.com",
|
||||||
|
},
|
||||||
|
TOTPPending: true,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
ctx.Next()
|
||||||
|
},
|
||||||
|
},
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
totpReq := TotpRequest{
|
||||||
|
Code: "123456",
|
||||||
|
}
|
||||||
|
|
||||||
|
totpReqBody, err := json.Marshal(totpReq)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
recorder = httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(totpReqBody)))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
assert.Equal(t, 401, recorder.Code)
|
||||||
|
assert.Contains(t, recorder.Body.String(), "Unauthorized")
|
||||||
|
},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
description: "Should be able to login with totp",
|
description: "Should be able to login with totp",
|
||||||
middlewares: []gin.HandlerFunc{
|
middlewares: []gin.HandlerFunc{
|
||||||
@@ -264,7 +381,7 @@ func TestUserController(t *testing.T) {
|
|||||||
code, err := totp.GenerateCode("JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", time.Now())
|
code, err := totp.GenerateCode("JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", time.Now())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
totpReq := controller.TotpRequest{
|
totpReq := TotpRequest{
|
||||||
Code: code,
|
Code: code,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -302,7 +419,7 @@ func TestUserController(t *testing.T) {
|
|||||||
},
|
},
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
for range 3 {
|
for range 3 {
|
||||||
totpReq := controller.TotpRequest{
|
totpReq := TotpRequest{
|
||||||
Code: "000000", // invalid code
|
Code: "000000", // invalid code
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -334,7 +451,7 @@ func TestUserController(t *testing.T) {
|
|||||||
description: "Login uses name and email from user attributes",
|
description: "Login uses name and email from user attributes",
|
||||||
middlewares: []gin.HandlerFunc{},
|
middlewares: []gin.HandlerFunc{},
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
loginReq := controller.LoginRequest{Username: "attruser", Password: "password"}
|
loginReq := LoginRequest{Username: "attruser", Password: "password"}
|
||||||
body, err := json.Marshal(loginReq)
|
body, err := json.Marshal(loginReq)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
@@ -352,7 +469,7 @@ func TestUserController(t *testing.T) {
|
|||||||
description: "Login with TOTP uses name and email from user attributes in pending session",
|
description: "Login with TOTP uses name and email from user attributes in pending session",
|
||||||
middlewares: []gin.HandlerFunc{},
|
middlewares: []gin.HandlerFunc{},
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
loginReq := controller.LoginRequest{Username: "attrtotpuser", Password: "password"}
|
loginReq := LoginRequest{Username: "attrtotpuser", Password: "password"}
|
||||||
body, err := json.Marshal(loginReq)
|
body, err := json.Marshal(loginReq)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
@@ -388,7 +505,7 @@ func TestUserController(t *testing.T) {
|
|||||||
code, err := totp.GenerateCode("JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", time.Now())
|
code, err := totp.GenerateCode("JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", time.Now())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
totpReq := controller.TotpRequest{Code: code}
|
totpReq := TotpRequest{Code: code}
|
||||||
body, err := json.Marshal(totpReq)
|
body, err := json.Marshal(totpReq)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
@@ -455,7 +572,7 @@ func TestUserController(t *testing.T) {
|
|||||||
group := router.Group("/api")
|
group := router.Group("/api")
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
controller.NewUserController(controller.UserControllerInput{
|
NewUserController(UserControllerInput{
|
||||||
Log: log,
|
Log: log,
|
||||||
RuntimeConfig: &runtime,
|
RuntimeConfig: &runtime,
|
||||||
RouterGroup: group,
|
RouterGroup: group,
|
||||||
|
|||||||
@@ -1,17 +1,17 @@
|
|||||||
package controller_test
|
package controller
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/steveiliop56/ding"
|
"github.com/steveiliop56/ding"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/controller"
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
|
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/test"
|
"github.com/tinyauthapp/tinyauth/internal/test"
|
||||||
@@ -26,23 +26,25 @@ func TestWellKnownController(t *testing.T) {
|
|||||||
|
|
||||||
type testCase struct {
|
type testCase struct {
|
||||||
description string
|
description string
|
||||||
|
oidcEnabled bool
|
||||||
run func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder)
|
run func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder)
|
||||||
}
|
}
|
||||||
|
|
||||||
tests := []testCase{
|
tests := []testCase{
|
||||||
{
|
{
|
||||||
description: "Ensure well-known endpoint returns correct OIDC configuration",
|
description: "Ensure well-known endpoint returns correct OIDC configuration",
|
||||||
|
oidcEnabled: true,
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
req := httptest.NewRequest("GET", "/.well-known/openid-configuration", nil)
|
req := httptest.NewRequest("GET", "/.well-known/openid-configuration", nil)
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
assert.Equal(t, 200, recorder.Code)
|
assert.Equal(t, 200, recorder.Code)
|
||||||
|
|
||||||
res := controller.OpenIDConnectConfiguration{}
|
res := OpenIDConnectConfiguration{}
|
||||||
err := json.Unmarshal(recorder.Body.Bytes(), &res)
|
err := json.Unmarshal(recorder.Body.Bytes(), &res)
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
expected := controller.OpenIDConnectConfiguration{
|
expected := OpenIDConnectConfiguration{
|
||||||
Issuer: runtime.AppURL,
|
Issuer: runtime.AppURL,
|
||||||
AuthorizationEndpoint: fmt.Sprintf("%s/authorize", runtime.AppURL),
|
AuthorizationEndpoint: fmt.Sprintf("%s/authorize", runtime.AppURL),
|
||||||
TokenEndpoint: fmt.Sprintf("%s/api/oidc/token", runtime.AppURL),
|
TokenEndpoint: fmt.Sprintf("%s/api/oidc/token", runtime.AppURL),
|
||||||
@@ -56,8 +58,8 @@ func TestWellKnownController(t *testing.T) {
|
|||||||
TokenEndpointAuthMethodsSupported: []string{"client_secret_basic", "client_secret_post"},
|
TokenEndpointAuthMethodsSupported: []string{"client_secret_basic", "client_secret_post"},
|
||||||
ClaimsSupported: []string{"sub", "updated_at", "name", "preferred_username", "email", "email_verified", "groups", "phone_number", "phone_number_verified", "address", "given_name", "family_name", "middle_name", "nickname", "profile", "picture", "website", "gender", "birthdate", "zoneinfo", "locale"},
|
ClaimsSupported: []string{"sub", "updated_at", "name", "preferred_username", "email", "email_verified", "groups", "phone_number", "phone_number_verified", "address", "given_name", "family_name", "middle_name", "nickname", "profile", "picture", "website", "gender", "birthdate", "zoneinfo", "locale"},
|
||||||
ServiceDocumentation: "https://tinyauth.app/docs/guides/oidc",
|
ServiceDocumentation: "https://tinyauth.app/docs/guides/oidc",
|
||||||
RequestParameterSupported: true,
|
|
||||||
RequestObjectSigningAlgValuesSupported: []string{"none"},
|
RequestObjectSigningAlgValuesSupported: []string{"none"},
|
||||||
|
RequestParameterSupported: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
assert.Equal(t, expected, res)
|
assert.Equal(t, expected, res)
|
||||||
@@ -65,6 +67,7 @@ func TestWellKnownController(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "Ensure well-known endpoint returns correct JWKS",
|
description: "Ensure well-known endpoint returns correct JWKS",
|
||||||
|
oidcEnabled: true,
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
req := httptest.NewRequest("GET", "/.well-known/jwks.json", nil)
|
req := httptest.NewRequest("GET", "/.well-known/jwks.json", nil)
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
@@ -73,19 +76,204 @@ func TestWellKnownController(t *testing.T) {
|
|||||||
|
|
||||||
decodedBody := make(map[string]any)
|
decodedBody := make(map[string]any)
|
||||||
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
|
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
keys, ok := decodedBody["keys"].([]any)
|
keys, ok := decodedBody["keys"].([]any)
|
||||||
assert.True(t, ok)
|
require.True(t, ok)
|
||||||
assert.Len(t, keys, 1)
|
assert.Len(t, keys, 1)
|
||||||
|
|
||||||
keyData, ok := keys[0].(map[string]any)
|
keyData, ok := keys[0].(map[string]any)
|
||||||
assert.True(t, ok)
|
require.True(t, ok)
|
||||||
assert.Equal(t, "RSA", keyData["kty"])
|
assert.Equal(t, "RSA", keyData["kty"])
|
||||||
assert.Equal(t, "sig", keyData["use"])
|
assert.Equal(t, "sig", keyData["use"])
|
||||||
assert.Equal(t, "RS256", keyData["alg"])
|
assert.Equal(t, "RS256", keyData["alg"])
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
description: "Ensure openid configuration returns 500 on nil oidc service",
|
||||||
|
oidcEnabled: false,
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
req := httptest.NewRequest("GET", "/.well-known/openid-configuration", nil)
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
assert.Equal(t, 500, recorder.Code)
|
||||||
|
|
||||||
|
decodedBody := make(map[string]any)
|
||||||
|
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, "OIDC service not configured", decodedBody["message"])
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Ensure jwks endpoint returns 500 on nil oidc service",
|
||||||
|
oidcEnabled: false,
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
req := httptest.NewRequest("GET", "/.well-known/jwks.json", nil)
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
assert.Equal(t, 500, recorder.Code)
|
||||||
|
|
||||||
|
decodedBody := make(map[string]any)
|
||||||
|
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, "OIDC service not configured", decodedBody["message"])
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Ensure webfinger returns 400 on invalid resource",
|
||||||
|
oidcEnabled: true,
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
req := httptest.NewRequest("GET", "/.well-known/webfinger?resource=invalid-resource", nil)
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
assert.Equal(t, 400, recorder.Code)
|
||||||
|
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
|
||||||
|
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
|
||||||
|
|
||||||
|
decodedBody := make(map[string]any)
|
||||||
|
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, "invalid resource", decodedBody["message"])
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Ensure webfinger resource validator allows acct",
|
||||||
|
oidcEnabled: true,
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
resource := "acct:testuser@example.com"
|
||||||
|
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s", url.QueryEscape(resource)), nil)
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
assert.Equal(t, 200, recorder.Code)
|
||||||
|
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
|
||||||
|
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Ensure webfinger resource validator allows https",
|
||||||
|
oidcEnabled: true,
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
resource := "https://example.com/testuser"
|
||||||
|
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s", url.QueryEscape(resource)), nil)
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
assert.Equal(t, 200, recorder.Code)
|
||||||
|
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
|
||||||
|
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Ensure webfinger resource validator allows http",
|
||||||
|
oidcEnabled: true,
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
resource := "http://example.com/testuser"
|
||||||
|
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s", url.QueryEscape(resource)), nil)
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
assert.Equal(t, 200, recorder.Code)
|
||||||
|
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
|
||||||
|
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Webfinger should return no links when oidc is nil",
|
||||||
|
oidcEnabled: false,
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
resource := "acct:testuser@example.com"
|
||||||
|
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s", url.QueryEscape(resource)), nil)
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
assert.Equal(t, 200, recorder.Code)
|
||||||
|
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
|
||||||
|
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
|
||||||
|
|
||||||
|
decodedBody := make(map[string]any)
|
||||||
|
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
links, ok := decodedBody["links"].([]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Len(t, links, 0)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Webfinger should return links when oidc is configured and no rel is provided",
|
||||||
|
oidcEnabled: true,
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
resource := "acct:testuser@example.com"
|
||||||
|
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s", url.QueryEscape(resource)), nil)
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
assert.Equal(t, 200, recorder.Code)
|
||||||
|
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
|
||||||
|
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
|
||||||
|
|
||||||
|
decodedBody := make(map[string]any)
|
||||||
|
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
links, ok := decodedBody["links"].([]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Len(t, links, 1)
|
||||||
|
|
||||||
|
linkData, ok := links[0].(map[string]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Equal(t, "http://openid.net/specs/connect/1.0/issuer", linkData["rel"])
|
||||||
|
assert.Equal(t, runtime.AppURL, linkData["href"])
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Webfinger should return links when oidc is configured and rel is provided",
|
||||||
|
oidcEnabled: true,
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
resource := fmt.Sprintf("acct:%s@%s", "testuser", runtime.AppURL)
|
||||||
|
rel := "http://openid.net/specs/connect/1.0/issuer"
|
||||||
|
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s&rel=%s", url.QueryEscape(resource), url.QueryEscape(rel)), nil)
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
assert.Equal(t, 200, recorder.Code)
|
||||||
|
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
|
||||||
|
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
|
||||||
|
|
||||||
|
decodedBody := make(map[string]any)
|
||||||
|
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
links, ok := decodedBody["links"].([]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Len(t, links, 1)
|
||||||
|
|
||||||
|
linkData, ok := links[0].(map[string]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Equal(t, rel, linkData["rel"])
|
||||||
|
assert.Equal(t, runtime.AppURL, linkData["href"])
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Webfinger should return no links when oidc is configured and rel is provided but does not match",
|
||||||
|
oidcEnabled: true,
|
||||||
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
resource := "acct:testuser@example.com"
|
||||||
|
rel := "http://example.com/does-not-exist"
|
||||||
|
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s&rel=%s", url.QueryEscape(resource), url.QueryEscape(rel)), nil)
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
assert.Equal(t, 200, recorder.Code)
|
||||||
|
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
|
||||||
|
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
|
||||||
|
|
||||||
|
decodedBody := make(map[string]any)
|
||||||
|
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
links, ok := decodedBody["links"].([]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Len(t, links, 0)
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx := context.TODO()
|
ctx := context.TODO()
|
||||||
@@ -109,10 +297,15 @@ func TestWellKnownController(t *testing.T) {
|
|||||||
|
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
|
|
||||||
controller.NewWellKnownController(controller.WellKnownControllerInput{
|
wellKnownControllerInput := WellKnownControllerInput{
|
||||||
OIDCService: oidcService,
|
|
||||||
RouterGroup: &router.RouterGroup,
|
RouterGroup: &router.RouterGroup,
|
||||||
})
|
}
|
||||||
|
|
||||||
|
if test.oidcEnabled {
|
||||||
|
wellKnownControllerInput.OIDCService = oidcService
|
||||||
|
}
|
||||||
|
|
||||||
|
NewWellKnownController(wellKnownControllerInput)
|
||||||
|
|
||||||
test.run(t, router, recorder)
|
test.run(t, router, recorder)
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
package middleware_test
|
package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@@ -12,7 +12,6 @@ import (
|
|||||||
"github.com/steveiliop56/ding"
|
"github.com/steveiliop56/ding"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/middleware"
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
|
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
|
||||||
@@ -278,7 +277,7 @@ func TestContextMiddleware(t *testing.T) {
|
|||||||
PolicyEngine: policyEngine,
|
PolicyEngine: policyEngine,
|
||||||
})
|
})
|
||||||
|
|
||||||
contextMiddleware := middleware.NewContextMiddleware(middleware.ContextMiddlewareInput{
|
contextMiddleware := NewContextMiddleware(ContextMiddlewareInput{
|
||||||
Log: log,
|
Log: log,
|
||||||
RuntimeConfig: &runtime,
|
RuntimeConfig: &runtime,
|
||||||
AuthService: authService,
|
AuthService: authService,
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ func NewDefaultConfiguration() *Config {
|
|||||||
ACLs: ACLsConfig{
|
ACLs: ACLsConfig{
|
||||||
Policy: "allow",
|
Policy: "allow",
|
||||||
},
|
},
|
||||||
|
LockdownEnabled: true,
|
||||||
},
|
},
|
||||||
UI: UIConfig{
|
UI: UIConfig{
|
||||||
Title: "Tinyauth",
|
Title: "Tinyauth",
|
||||||
@@ -120,6 +121,7 @@ type AuthConfig struct {
|
|||||||
SessionMaxLifetime int `description:"Maximum session lifetime in seconds." yaml:"sessionMaxLifetime"`
|
SessionMaxLifetime int `description:"Maximum session lifetime in seconds." yaml:"sessionMaxLifetime"`
|
||||||
LoginTimeout int `description:"Login timeout in seconds." yaml:"loginTimeout"`
|
LoginTimeout int `description:"Login timeout in seconds." yaml:"loginTimeout"`
|
||||||
LoginMaxRetries int `description:"Maximum login retries." yaml:"loginMaxRetries"`
|
LoginMaxRetries int `description:"Maximum login retries." yaml:"loginMaxRetries"`
|
||||||
|
LockdownEnabled bool `description:"Enable lockdown mode after maximum login retries. Lockdown mode limit is calculated automatically." yaml:"lockdownEnabled"`
|
||||||
TrustedProxies []string `description:"Comma-separated list of trusted proxy addresses." yaml:"trustedProxies"`
|
TrustedProxies []string `description:"Comma-separated list of trusted proxy addresses." yaml:"trustedProxies"`
|
||||||
ACLs ACLsConfig `description:"ACLs configuration." yaml:"acls"`
|
ACLs ACLsConfig `description:"ACLs configuration." yaml:"acls"`
|
||||||
}
|
}
|
||||||
@@ -178,16 +180,16 @@ type UIConfig struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type LDAPConfig struct {
|
type LDAPConfig struct {
|
||||||
Address string `description:"LDAP server address." yaml:"address"`
|
Address string `description:"LDAP server address." yaml:"address"`
|
||||||
BindDN string `description:"Bind DN for LDAP authentication." yaml:"bindDn"`
|
BindDN string `description:"Bind DN for LDAP authentication." yaml:"bindDn"`
|
||||||
BindPassword string `description:"Bind password for LDAP authentication." yaml:"bindPassword"`
|
BindPassword string `description:"Bind password for LDAP authentication." yaml:"bindPassword"`
|
||||||
BindPasswordFile string `description:"Path to the Bind password." yaml:"bindPasswordFile"`
|
BindPasswordFile string `description:"Path to the Bind password." yaml:"bindPasswordFile"`
|
||||||
BaseDN string `description:"Base DN for LDAP searches." yaml:"baseDn"`
|
BaseDN string `description:"Base DN for LDAP searches." yaml:"baseDn"`
|
||||||
Insecure bool `description:"Allow insecure LDAP connections." yaml:"insecure"`
|
Insecure bool `description:"Allow insecure LDAP connections." yaml:"insecure"`
|
||||||
SearchFilter string `description:"LDAP search filter." yaml:"searchFilter"`
|
SearchFilter string `description:"LDAP search filter." yaml:"searchFilter"`
|
||||||
AuthCert string `description:"Certificate for mTLS authentication." yaml:"authCert"`
|
AuthCert string `description:"Certificate for mTLS authentication." yaml:"authCert"`
|
||||||
AuthKey string `description:"Certificate key for mTLS authentication." yaml:"authKey"`
|
AuthKey string `description:"Certificate key for mTLS authentication." yaml:"authKey"`
|
||||||
GroupCacheTTL int `description:"Cache duration for LDAP group membership in seconds." yaml:"groupCacheTTL"`
|
GroupCacheTTL int `description:"Cache duration for LDAP group membership in seconds." yaml:"groupCacheTTL"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type LogConfig struct {
|
type LogConfig struct {
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
package model_test
|
package model
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
@@ -7,7 +7,6 @@ import (
|
|||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -22,44 +21,44 @@ func TestContext(t *testing.T) {
|
|||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
description string
|
description string
|
||||||
context *model.UserContext
|
context *UserContext
|
||||||
run func(*testing.T, *model.UserContext) any
|
run func(*testing.T, *UserContext) any
|
||||||
expected any
|
expected any
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
description: "IsAuthenticated reflects Authenticated field",
|
description: "IsAuthenticated reflects Authenticated field",
|
||||||
context: &model.UserContext{Authenticated: true},
|
context: &UserContext{Authenticated: true},
|
||||||
run: func(t *testing.T, c *model.UserContext) any { return c.IsAuthenticated() },
|
run: func(t *testing.T, c *UserContext) any { return c.IsAuthenticated() },
|
||||||
expected: true,
|
expected: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "IsLocal returns true for ProviderLocal",
|
description: "IsLocal returns true for ProviderLocal",
|
||||||
context: &model.UserContext{Provider: model.ProviderLocal, Local: &model.LocalContext{}},
|
context: &UserContext{Provider: ProviderLocal, Local: &LocalContext{}},
|
||||||
run: func(t *testing.T, c *model.UserContext) any { return c.IsLocal() },
|
run: func(t *testing.T, c *UserContext) any { return c.IsLocal() },
|
||||||
expected: true,
|
expected: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "IsOAuth returns true for ProviderOAuth",
|
description: "IsOAuth returns true for ProviderOAuth",
|
||||||
context: &model.UserContext{Provider: model.ProviderOAuth, OAuth: &model.OAuthContext{}},
|
context: &UserContext{Provider: ProviderOAuth, OAuth: &OAuthContext{}},
|
||||||
run: func(t *testing.T, c *model.UserContext) any { return c.IsOAuth() },
|
run: func(t *testing.T, c *UserContext) any { return c.IsOAuth() },
|
||||||
expected: true,
|
expected: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "IsLDAP returns true for ProviderLDAP",
|
description: "IsLDAP returns true for ProviderLDAP",
|
||||||
context: &model.UserContext{Provider: model.ProviderLDAP, LDAP: &model.LDAPContext{}},
|
context: &UserContext{Provider: ProviderLDAP, LDAP: &LDAPContext{}},
|
||||||
run: func(t *testing.T, c *model.UserContext) any { return c.IsLDAP() },
|
run: func(t *testing.T, c *UserContext) any { return c.IsLDAP() },
|
||||||
expected: true,
|
expected: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "IsBasicAuth returns true for ProviderBasicAuth",
|
description: "IsBasicAuth returns true for ProviderBasicAuth",
|
||||||
context: &model.UserContext{Provider: model.ProviderBasicAuth, Local: &model.LocalContext{}},
|
context: &UserContext{Provider: ProviderBasicAuth, Local: &LocalContext{}},
|
||||||
run: func(t *testing.T, c *model.UserContext) any { return c.IsBasicAuth() },
|
run: func(t *testing.T, c *UserContext) any { return c.IsBasicAuth() },
|
||||||
expected: true,
|
expected: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "NewFromSession local session is authenticated and ProviderLocal",
|
description: "NewFromSession local session is authenticated and ProviderLocal",
|
||||||
context: &model.UserContext{},
|
context: &UserContext{},
|
||||||
run: func(t *testing.T, c *model.UserContext) any {
|
run: func(t *testing.T, c *UserContext) any {
|
||||||
got, err := c.NewFromSession(&repository.Session{
|
got, err := c.NewFromSession(&repository.Session{
|
||||||
Username: "alice", Email: "alice@example.com", Name: "Alice",
|
Username: "alice", Email: "alice@example.com", Name: "Alice",
|
||||||
Provider: "local",
|
Provider: "local",
|
||||||
@@ -67,12 +66,12 @@ func TestContext(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
return [2]any{got.Provider, got.Authenticated}
|
return [2]any{got.Provider, got.Authenticated}
|
||||||
},
|
},
|
||||||
expected: [2]any{model.ProviderLocal, true},
|
expected: [2]any{ProviderLocal, true},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "NewFromSession local session with TotpPending is not authenticated",
|
description: "NewFromSession local session with TotpPending is not authenticated",
|
||||||
context: &model.UserContext{},
|
context: &UserContext{},
|
||||||
run: func(t *testing.T, c *model.UserContext) any {
|
run: func(t *testing.T, c *UserContext) any {
|
||||||
got, err := c.NewFromSession(&repository.Session{
|
got, err := c.NewFromSession(&repository.Session{
|
||||||
Username: "bob", Provider: "local", TotpPending: true,
|
Username: "bob", Provider: "local", TotpPending: true,
|
||||||
})
|
})
|
||||||
@@ -83,20 +82,20 @@ func TestContext(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "NewFromSession ldap session is ProviderLDAP",
|
description: "NewFromSession ldap session is ProviderLDAP",
|
||||||
context: &model.UserContext{},
|
context: &UserContext{},
|
||||||
run: func(t *testing.T, c *model.UserContext) any {
|
run: func(t *testing.T, c *UserContext) any {
|
||||||
got, err := c.NewFromSession(&repository.Session{
|
got, err := c.NewFromSession(&repository.Session{
|
||||||
Username: "carol", Provider: "ldap",
|
Username: "carol", Provider: "ldap",
|
||||||
})
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
return got.Provider
|
return got.Provider
|
||||||
},
|
},
|
||||||
expected: model.ProviderLDAP,
|
expected: ProviderLDAP,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "NewFromSession unknown provider defaults to OAuth and populates oauth fields",
|
description: "NewFromSession unknown provider defaults to OAuth and populates oauth fields",
|
||||||
context: &model.UserContext{},
|
context: &UserContext{},
|
||||||
run: func(t *testing.T, c *model.UserContext) any {
|
run: func(t *testing.T, c *UserContext) any {
|
||||||
got, err := c.NewFromSession(&repository.Session{
|
got, err := c.NewFromSession(&repository.Session{
|
||||||
Username: "dave", Provider: "github",
|
Username: "dave", Provider: "github",
|
||||||
OAuthGroups: "devs,admins", OAuthSub: "sub-123", OAuthName: "GitHub",
|
OAuthGroups: "devs,admins", OAuthSub: "sub-123", OAuthName: "GitHub",
|
||||||
@@ -104,126 +103,126 @@ func TestContext(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
return [5]any{got.Provider, got.OAuth.ID, got.OAuth.Sub, got.OAuth.DisplayName, got.OAuth.Groups}
|
return [5]any{got.Provider, got.OAuth.ID, got.OAuth.Sub, got.OAuth.DisplayName, got.OAuth.Groups}
|
||||||
},
|
},
|
||||||
expected: [5]any{model.ProviderOAuth, "github", "sub-123", "GitHub", []string{"devs", "admins"}},
|
expected: [5]any{ProviderOAuth, "github", "sub-123", "GitHub", []string{"devs", "admins"}},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "Local getters return BaseContext fields",
|
description: "Local getters return BaseContext fields",
|
||||||
context: &model.UserContext{
|
context: &UserContext{
|
||||||
Provider: model.ProviderLocal,
|
Provider: ProviderLocal,
|
||||||
Local: &model.LocalContext{BaseContext: model.BaseContext{Username: "alice", Email: "alice@example.com", Name: "Alice"}},
|
Local: &LocalContext{BaseContext: BaseContext{Username: "alice", Email: "alice@example.com", Name: "Alice"}},
|
||||||
},
|
},
|
||||||
run: func(t *testing.T, c *model.UserContext) any {
|
run: func(t *testing.T, c *UserContext) any {
|
||||||
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
|
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
|
||||||
},
|
},
|
||||||
expected: [3]string{"alice", "alice@example.com", "Alice"},
|
expected: [3]string{"alice", "alice@example.com", "Alice"},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "BasicAuth getters fall back to local fields",
|
description: "BasicAuth getters fall back to local fields",
|
||||||
context: &model.UserContext{
|
context: &UserContext{
|
||||||
Provider: model.ProviderBasicAuth,
|
Provider: ProviderBasicAuth,
|
||||||
Local: &model.LocalContext{BaseContext: model.BaseContext{Username: "bob", Email: "bob@example.com", Name: "Bob"}},
|
Local: &LocalContext{BaseContext: BaseContext{Username: "bob", Email: "bob@example.com", Name: "Bob"}},
|
||||||
},
|
},
|
||||||
run: func(t *testing.T, c *model.UserContext) any {
|
run: func(t *testing.T, c *UserContext) any {
|
||||||
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
|
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
|
||||||
},
|
},
|
||||||
expected: [3]string{"bob", "bob@example.com", "Bob"},
|
expected: [3]string{"bob", "bob@example.com", "Bob"},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "LDAP getters return LDAP fields",
|
description: "LDAP getters return LDAP fields",
|
||||||
context: &model.UserContext{
|
context: &UserContext{
|
||||||
Provider: model.ProviderLDAP,
|
Provider: ProviderLDAP,
|
||||||
LDAP: &model.LDAPContext{BaseContext: model.BaseContext{Username: "carol", Email: "carol@example.com", Name: "Carol"}},
|
LDAP: &LDAPContext{BaseContext: BaseContext{Username: "carol", Email: "carol@example.com", Name: "Carol"}},
|
||||||
},
|
},
|
||||||
run: func(t *testing.T, c *model.UserContext) any {
|
run: func(t *testing.T, c *UserContext) any {
|
||||||
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
|
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
|
||||||
},
|
},
|
||||||
expected: [3]string{"carol", "carol@example.com", "Carol"},
|
expected: [3]string{"carol", "carol@example.com", "Carol"},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "OAuth getters return OAuth fields",
|
description: "OAuth getters return OAuth fields",
|
||||||
context: &model.UserContext{
|
context: &UserContext{
|
||||||
Provider: model.ProviderOAuth,
|
Provider: ProviderOAuth,
|
||||||
OAuth: &model.OAuthContext{BaseContext: model.BaseContext{Username: "dave", Email: "dave@example.com", Name: "Dave"}},
|
OAuth: &OAuthContext{BaseContext: BaseContext{Username: "dave", Email: "dave@example.com", Name: "Dave"}},
|
||||||
},
|
},
|
||||||
run: func(t *testing.T, c *model.UserContext) any {
|
run: func(t *testing.T, c *UserContext) any {
|
||||||
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
|
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
|
||||||
},
|
},
|
||||||
expected: [3]string{"dave", "dave@example.com", "Dave"},
|
expected: [3]string{"dave", "dave@example.com", "Dave"},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "ProviderName returns 'local' for ProviderLocal",
|
description: "ProviderName returns 'local' for ProviderLocal",
|
||||||
context: &model.UserContext{Provider: model.ProviderLocal},
|
context: &UserContext{Provider: ProviderLocal},
|
||||||
run: func(t *testing.T, c *model.UserContext) any { return c.GetProviderID() },
|
run: func(t *testing.T, c *UserContext) any { return c.GetProviderID() },
|
||||||
expected: "local",
|
expected: "local",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "ProviderName returns 'local' for ProviderBasicAuth",
|
description: "ProviderName returns 'local' for ProviderBasicAuth",
|
||||||
context: &model.UserContext{Provider: model.ProviderBasicAuth},
|
context: &UserContext{Provider: ProviderBasicAuth},
|
||||||
run: func(t *testing.T, c *model.UserContext) any { return c.GetProviderID() },
|
run: func(t *testing.T, c *UserContext) any { return c.GetProviderID() },
|
||||||
expected: "local",
|
expected: "local",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "ProviderName returns 'ldap' for ProviderLDAP",
|
description: "ProviderName returns 'ldap' for ProviderLDAP",
|
||||||
context: &model.UserContext{Provider: model.ProviderLDAP},
|
context: &UserContext{Provider: ProviderLDAP},
|
||||||
run: func(t *testing.T, c *model.UserContext) any { return c.GetProviderID() },
|
run: func(t *testing.T, c *UserContext) any { return c.GetProviderID() },
|
||||||
expected: "ldap",
|
expected: "ldap",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "ProviderName returns OAuth provider ID for ProviderOAuth",
|
description: "ProviderName returns OAuth provider ID for ProviderOAuth",
|
||||||
context: &model.UserContext{
|
context: &UserContext{
|
||||||
Provider: model.ProviderOAuth,
|
Provider: ProviderOAuth,
|
||||||
OAuth: &model.OAuthContext{ID: "github"},
|
OAuth: &OAuthContext{ID: "github"},
|
||||||
},
|
},
|
||||||
run: func(t *testing.T, c *model.UserContext) any { return c.GetProviderID() },
|
run: func(t *testing.T, c *UserContext) any { return c.GetProviderID() },
|
||||||
expected: "github",
|
expected: "github",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "TOTPPending returns true when local context is pending",
|
description: "TOTPPending returns true when local context is pending",
|
||||||
context: &model.UserContext{
|
context: &UserContext{
|
||||||
Provider: model.ProviderLocal,
|
Provider: ProviderLocal,
|
||||||
Local: &model.LocalContext{TOTPPending: true},
|
Local: &LocalContext{TOTPPending: true},
|
||||||
},
|
},
|
||||||
run: func(t *testing.T, c *model.UserContext) any { return c.TOTPPending() },
|
run: func(t *testing.T, c *UserContext) any { return c.TOTPPending() },
|
||||||
expected: true,
|
expected: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "TOTPPending returns false when local context is not pending",
|
description: "TOTPPending returns false when local context is not pending",
|
||||||
context: &model.UserContext{
|
context: &UserContext{
|
||||||
Provider: model.ProviderLocal,
|
Provider: ProviderLocal,
|
||||||
Local: &model.LocalContext{TOTPPending: false},
|
Local: &LocalContext{TOTPPending: false},
|
||||||
},
|
},
|
||||||
run: func(t *testing.T, c *model.UserContext) any { return c.TOTPPending() },
|
run: func(t *testing.T, c *UserContext) any { return c.TOTPPending() },
|
||||||
expected: false,
|
expected: false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "TOTPPending returns false for non-local providers",
|
description: "TOTPPending returns false for non-local providers",
|
||||||
context: &model.UserContext{Provider: model.ProviderOAuth, OAuth: &model.OAuthContext{}},
|
context: &UserContext{Provider: ProviderOAuth, OAuth: &OAuthContext{}},
|
||||||
run: func(t *testing.T, c *model.UserContext) any { return c.TOTPPending() },
|
run: func(t *testing.T, c *UserContext) any { return c.TOTPPending() },
|
||||||
expected: false,
|
expected: false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "OAuthName returns DisplayName for ProviderOAuth",
|
description: "OAuthName returns DisplayName for ProviderOAuth",
|
||||||
context: &model.UserContext{
|
context: &UserContext{
|
||||||
Provider: model.ProviderOAuth,
|
Provider: ProviderOAuth,
|
||||||
OAuth: &model.OAuthContext{DisplayName: "Google"},
|
OAuth: &OAuthContext{DisplayName: "Google"},
|
||||||
},
|
},
|
||||||
run: func(t *testing.T, c *model.UserContext) any { return c.OAuthName() },
|
run: func(t *testing.T, c *UserContext) any { return c.OAuthName() },
|
||||||
expected: "Google",
|
expected: "Google",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "OAuthName returns empty string for non-oauth providers",
|
description: "OAuthName returns empty string for non-oauth providers",
|
||||||
context: &model.UserContext{Provider: model.ProviderLocal, Local: &model.LocalContext{}},
|
context: &UserContext{Provider: ProviderLocal, Local: &LocalContext{}},
|
||||||
run: func(t *testing.T, c *model.UserContext) any { return c.OAuthName() },
|
run: func(t *testing.T, c *UserContext) any { return c.OAuthName() },
|
||||||
expected: "",
|
expected: "",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "NewFromGin populates context from gin value",
|
description: "NewFromGin populates context from gin value",
|
||||||
context: &model.UserContext{},
|
context: &UserContext{},
|
||||||
run: func(t *testing.T, c *model.UserContext) any {
|
run: func(t *testing.T, c *UserContext) any {
|
||||||
stored := &model.UserContext{
|
stored := &UserContext{
|
||||||
Authenticated: true,
|
Authenticated: true,
|
||||||
Provider: model.ProviderLocal,
|
Provider: ProviderLocal,
|
||||||
Local: &model.LocalContext{BaseContext: model.BaseContext{Username: "alice"}},
|
Local: &LocalContext{BaseContext: BaseContext{Username: "alice"}},
|
||||||
}
|
}
|
||||||
got, err := c.NewFromGin(newGinCtx(stored, true))
|
got, err := c.NewFromGin(newGinCtx(stored, true))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -233,17 +232,17 @@ func TestContext(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "NewFromGin returns error when context value is missing",
|
description: "NewFromGin returns error when context value is missing",
|
||||||
context: &model.UserContext{},
|
context: &UserContext{},
|
||||||
run: func(t *testing.T, c *model.UserContext) any {
|
run: func(t *testing.T, c *UserContext) any {
|
||||||
_, err := c.NewFromGin(newGinCtx(nil, false))
|
_, err := c.NewFromGin(newGinCtx(nil, false))
|
||||||
return err.Error()
|
return err.Error()
|
||||||
},
|
},
|
||||||
expected: model.ErrUserContextNotFound.Error(),
|
expected: ErrUserContextNotFound.Error(),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "NewFromGin returns error when context value has wrong type",
|
description: "NewFromGin returns error when context value has wrong type",
|
||||||
context: &model.UserContext{},
|
context: &UserContext{},
|
||||||
run: func(t *testing.T, c *model.UserContext) any {
|
run: func(t *testing.T, c *UserContext) any {
|
||||||
_, err := c.NewFromGin(newGinCtx("not a user context", true))
|
_, err := c.NewFromGin(newGinCtx("not a user context", true))
|
||||||
return err.Error()
|
return err.Error()
|
||||||
},
|
},
|
||||||
@@ -251,17 +250,17 @@ func TestContext(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "NewFromGin returns an error when context doesn't include user information",
|
description: "NewFromGin returns an error when context doesn't include user information",
|
||||||
context: &model.UserContext{},
|
context: &UserContext{},
|
||||||
run: func(t *testing.T, c *model.UserContext) any {
|
run: func(t *testing.T, c *UserContext) any {
|
||||||
_, err := c.NewFromGin(newGinCtx(&model.UserContext{Provider: model.ProviderLocal}, true))
|
_, err := c.NewFromGin(newGinCtx(&UserContext{Provider: ProviderLocal}, true))
|
||||||
return err.Error()
|
return err.Error()
|
||||||
},
|
},
|
||||||
expected: "incomplete user context",
|
expected: "incomplete user context",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "Getters should not panic if provider context is empty",
|
description: "Getters should not panic if provider context is empty",
|
||||||
context: &model.UserContext{Provider: model.ProviderLocal},
|
context: &UserContext{Provider: ProviderLocal},
|
||||||
run: func(t *testing.T, c *model.UserContext) any {
|
run: func(t *testing.T, c *UserContext) any {
|
||||||
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
|
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
|
||||||
},
|
},
|
||||||
expected: [3]string{"", "", ""},
|
expected: [3]string{"", "", ""},
|
||||||
|
|||||||
@@ -2,8 +2,10 @@ package service
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"math/big"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
@@ -25,7 +27,6 @@ import (
|
|||||||
// but for now these are just safety limits to prevent unbounded memory usage
|
// but for now these are just safety limits to prevent unbounded memory usage
|
||||||
const MaxOAuthPendingSessions = 256
|
const MaxOAuthPendingSessions = 256
|
||||||
const OAuthCleanupCount = 16
|
const OAuthCleanupCount = 16
|
||||||
const MaxLoginAttemptRecords = 256
|
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ErrUserNotFound = errors.New("user not found")
|
ErrUserNotFound = errors.New("user not found")
|
||||||
@@ -81,6 +82,8 @@ type AuthService struct {
|
|||||||
oauth *CacheStore[OAuthPendingSession]
|
oauth *CacheStore[OAuthPendingSession]
|
||||||
ldap *CacheStore[[]string]
|
ldap *CacheStore[[]string]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
maxLoginLimits int
|
||||||
}
|
}
|
||||||
|
|
||||||
type AuthServiceInput struct {
|
type AuthServiceInput struct {
|
||||||
@@ -111,9 +114,18 @@ func NewAuthService(i AuthServiceInput) *AuthService {
|
|||||||
policyEngine: i.PolicyEngine,
|
policyEngine: i.PolicyEngine,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// get the max login limits based on the number of users and the configured max retries
|
||||||
|
service.maxLoginLimits = service.calculateLockdownLimit()
|
||||||
|
|
||||||
|
loginCacheSize := 0
|
||||||
|
|
||||||
|
if !service.config.Auth.LockdownEnabled {
|
||||||
|
loginCacheSize = service.maxLoginLimits
|
||||||
|
}
|
||||||
|
|
||||||
// caches setup
|
// caches setup
|
||||||
oauthCache := NewCacheStore[OAuthPendingSession](256)
|
oauthCache := NewCacheStore[OAuthPendingSession](256)
|
||||||
loginCache := NewCacheStore[LoginAttempt](1024)
|
loginCache := NewCacheStore[LoginAttempt](loginCacheSize)
|
||||||
ldapCache := NewCacheStore[[]string](1024)
|
ldapCache := NewCacheStore[[]string](1024)
|
||||||
|
|
||||||
service.caches.oauth = oauthCache
|
service.caches.oauth = oauthCache
|
||||||
@@ -259,7 +271,7 @@ func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if auth.caches.login.Size() >= MaxLoginAttemptRecords {
|
if !success && auth.config.Auth.LockdownEnabled && auth.caches.login.Size() >= auth.maxLoginLimits {
|
||||||
if locked, _ := auth.IsInLockdown(); locked {
|
if locked, _ := auth.IsInLockdown(); locked {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -634,16 +646,17 @@ func (auth *AuthService) lockdownMode() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(auth.ctx)
|
||||||
|
|
||||||
auth.log.App.Warn().Msg("Too many failed login attempts, entering lockdown mode")
|
auth.log.App.Warn().Msg("Too many failed login attempts, entering lockdown mode")
|
||||||
|
|
||||||
auth.lockdown.active = true
|
auth.lockdown.active = true
|
||||||
auth.lockdown.ctx = ctx
|
auth.lockdown.ctx = ctx
|
||||||
auth.lockdown.cancelFunc = cancel
|
auth.lockdown.cancelFunc = cancel
|
||||||
auth.lockdown.until = time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second)
|
|
||||||
|
|
||||||
timer := time.NewTimer(time.Until(auth.lockdown.until))
|
d := time.Duration(auth.config.Auth.LoginTimeout) * time.Second
|
||||||
|
auth.lockdown.until = time.Now().Add(d)
|
||||||
|
timer := time.NewTimer(d)
|
||||||
|
|
||||||
auth.lockdown.mu.Unlock()
|
auth.lockdown.mu.Unlock()
|
||||||
|
|
||||||
@@ -655,14 +668,13 @@ func (auth *AuthService) lockdownMode() {
|
|||||||
// Timer expired, end lockdown
|
// Timer expired, end lockdown
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
// Context cancelled, end lockdown
|
// Context cancelled, end lockdown
|
||||||
case <-auth.ctx.Done():
|
|
||||||
// Service is shutting down, end lockdown
|
|
||||||
}
|
}
|
||||||
|
|
||||||
auth.lockdown.mu.Lock()
|
auth.lockdown.mu.Lock()
|
||||||
|
|
||||||
auth.log.App.Info().Msg("Exiting lockdown mode")
|
auth.log.App.Info().Msg("Exiting lockdown mode")
|
||||||
|
|
||||||
|
auth.caches.login.Clear()
|
||||||
auth.lockdown.active = false
|
auth.lockdown.active = false
|
||||||
auth.lockdown.until = time.Time{}
|
auth.lockdown.until = time.Time{}
|
||||||
auth.lockdown.ctx = nil
|
auth.lockdown.ctx = nil
|
||||||
@@ -685,3 +697,32 @@ func (auth *AuthService) IsInLockdown() (bool, int) {
|
|||||||
func (auth *AuthService) ClearLoginAttempts() {
|
func (auth *AuthService) ClearLoginAttempts() {
|
||||||
auth.caches.login.Clear()
|
auth.caches.login.Clear()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (auth *AuthService) calculateLockdownLimit() int {
|
||||||
|
userCount := len(auth.runtime.LocalUsers)
|
||||||
|
|
||||||
|
if auth.ldap != nil {
|
||||||
|
ldapUsers, err := auth.ldap.GetUserCount()
|
||||||
|
if err != nil {
|
||||||
|
auth.log.App.Warn().Err(err).Msg("Failed to get LDAP user count")
|
||||||
|
} else {
|
||||||
|
userCount += ldapUsers
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
limit := userCount * auth.config.Auth.LoginMaxRetries
|
||||||
|
|
||||||
|
jitter, err := rand.Int(rand.Reader, big.NewInt(64))
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
auth.log.App.Warn().Err(err).Msg("Failed to generate jitter for lockdown limit")
|
||||||
|
} else {
|
||||||
|
limit += int(jitter.Int64())
|
||||||
|
}
|
||||||
|
|
||||||
|
if limit < 256 {
|
||||||
|
limit = 256
|
||||||
|
}
|
||||||
|
|
||||||
|
return limit
|
||||||
|
}
|
||||||
|
|||||||
@@ -169,6 +169,26 @@ func (ldap *LdapService) GetUserInfo(username string) (dn string, email string,
|
|||||||
return entry.DN, entry.GetAttributeValue("mail"), nil
|
return entry.DN, entry.GetAttributeValue("mail"), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (ldap *LdapService) GetUserCount() (int, error) {
|
||||||
|
searchRequest := ldapgo.NewSearchRequest(
|
||||||
|
ldap.config.LDAP.BaseDN,
|
||||||
|
ldapgo.ScopeWholeSubtree, ldapgo.NeverDerefAliases, 0, 0, false,
|
||||||
|
"(objectClass=person)",
|
||||||
|
[]string{"dn"},
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
|
||||||
|
ldap.mutex.Lock()
|
||||||
|
defer ldap.mutex.Unlock()
|
||||||
|
|
||||||
|
searchResult, err := ldap.conn.Search(searchRequest)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return len(searchResult.Entries), nil
|
||||||
|
}
|
||||||
|
|
||||||
func (ldap *LdapService) GetUserGroups(userDN string) ([]string, error) {
|
func (ldap *LdapService) GetUserGroups(userDN string) ([]string, error) {
|
||||||
escapedUserDN := ldapgo.EscapeFilter(userDN)
|
escapedUserDN := ldapgo.EscapeFilter(userDN)
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
package service_test
|
package service
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@@ -10,12 +10,11 @@ import (
|
|||||||
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
|
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||||
)
|
)
|
||||||
|
|
||||||
func newTestUser() service.UserinfoResponse {
|
func newTestUser() UserinfoResponse {
|
||||||
return service.UserinfoResponse{
|
return UserinfoResponse{
|
||||||
Sub: "test-sub",
|
Sub: "test-sub",
|
||||||
Name: "Test User",
|
Name: "Test User",
|
||||||
PreferredUsername: "testuser",
|
PreferredUsername: "testuser",
|
||||||
@@ -70,7 +69,7 @@ func TestCompileUserinfo(t *testing.T) {
|
|||||||
|
|
||||||
store := memory.New()
|
store := memory.New()
|
||||||
|
|
||||||
svc, err := service.NewOIDCService(service.OIDCServiceInput{
|
svc, err := NewOIDCService(OIDCServiceInput{
|
||||||
Log: log,
|
Log: log,
|
||||||
Config: &cfg,
|
Config: &cfg,
|
||||||
Runtime: &runtime,
|
Runtime: &runtime,
|
||||||
@@ -81,16 +80,16 @@ func TestCompileUserinfo(t *testing.T) {
|
|||||||
|
|
||||||
type testCase struct {
|
type testCase struct {
|
||||||
description string
|
description string
|
||||||
mutate func(u *service.UserinfoResponse)
|
mutate func(u *UserinfoResponse)
|
||||||
scope string
|
scope string
|
||||||
run func(t *testing.T, info service.UserinfoResponse)
|
run func(t *testing.T, info UserinfoResponse)
|
||||||
}
|
}
|
||||||
|
|
||||||
tests := []testCase{
|
tests := []testCase{
|
||||||
{
|
{
|
||||||
description: "openid scope only returns sub and updated_at",
|
description: "openid scope only returns sub and updated_at",
|
||||||
scope: "openid",
|
scope: "openid",
|
||||||
run: func(t *testing.T, info service.UserinfoResponse) {
|
run: func(t *testing.T, info UserinfoResponse) {
|
||||||
assert.Equal(t, "test-sub", info.Sub)
|
assert.Equal(t, "test-sub", info.Sub)
|
||||||
assert.Equal(t, int64(1234567890), info.UpdatedAt)
|
assert.Equal(t, int64(1234567890), info.UpdatedAt)
|
||||||
assert.Empty(t, info.Name)
|
assert.Empty(t, info.Name)
|
||||||
@@ -103,7 +102,7 @@ func TestCompileUserinfo(t *testing.T) {
|
|||||||
{
|
{
|
||||||
description: "profile scope returns all profile fields",
|
description: "profile scope returns all profile fields",
|
||||||
scope: "openid profile",
|
scope: "openid profile",
|
||||||
run: func(t *testing.T, info service.UserinfoResponse) {
|
run: func(t *testing.T, info UserinfoResponse) {
|
||||||
assert.Equal(t, "Test User", info.Name)
|
assert.Equal(t, "Test User", info.Name)
|
||||||
assert.Equal(t, "testuser", info.PreferredUsername)
|
assert.Equal(t, "testuser", info.PreferredUsername)
|
||||||
assert.Equal(t, "Test", info.GivenName)
|
assert.Equal(t, "Test", info.GivenName)
|
||||||
@@ -123,7 +122,7 @@ func TestCompileUserinfo(t *testing.T) {
|
|||||||
{
|
{
|
||||||
description: "email scope sets email and email_verified true when email present",
|
description: "email scope sets email and email_verified true when email present",
|
||||||
scope: "openid email",
|
scope: "openid email",
|
||||||
run: func(t *testing.T, info service.UserinfoResponse) {
|
run: func(t *testing.T, info UserinfoResponse) {
|
||||||
assert.Equal(t, "test@example.com", info.Email)
|
assert.Equal(t, "test@example.com", info.Email)
|
||||||
assert.True(t, info.EmailVerified)
|
assert.True(t, info.EmailVerified)
|
||||||
assert.Empty(t, info.Name)
|
assert.Empty(t, info.Name)
|
||||||
@@ -132,8 +131,8 @@ func TestCompileUserinfo(t *testing.T) {
|
|||||||
{
|
{
|
||||||
description: "email scope sets email_verified false when email absent",
|
description: "email scope sets email_verified false when email absent",
|
||||||
scope: "openid email",
|
scope: "openid email",
|
||||||
mutate: func(u *service.UserinfoResponse) { u.Email = "" },
|
mutate: func(u *UserinfoResponse) { u.Email = "" },
|
||||||
run: func(t *testing.T, info service.UserinfoResponse) {
|
run: func(t *testing.T, info UserinfoResponse) {
|
||||||
assert.Empty(t, info.Email)
|
assert.Empty(t, info.Email)
|
||||||
assert.False(t, info.EmailVerified)
|
assert.False(t, info.EmailVerified)
|
||||||
},
|
},
|
||||||
@@ -141,7 +140,7 @@ func TestCompileUserinfo(t *testing.T) {
|
|||||||
{
|
{
|
||||||
description: "phone scope sets phone_number_verified true when phone present",
|
description: "phone scope sets phone_number_verified true when phone present",
|
||||||
scope: "openid phone",
|
scope: "openid phone",
|
||||||
run: func(t *testing.T, info service.UserinfoResponse) {
|
run: func(t *testing.T, info UserinfoResponse) {
|
||||||
assert.Equal(t, "+15555550100", info.PhoneNumber)
|
assert.Equal(t, "+15555550100", info.PhoneNumber)
|
||||||
require.NotNil(t, info.PhoneNumberVerified)
|
require.NotNil(t, info.PhoneNumberVerified)
|
||||||
assert.True(t, *info.PhoneNumberVerified)
|
assert.True(t, *info.PhoneNumberVerified)
|
||||||
@@ -150,8 +149,8 @@ func TestCompileUserinfo(t *testing.T) {
|
|||||||
{
|
{
|
||||||
description: "phone scope sets phone_number_verified false when phone absent",
|
description: "phone scope sets phone_number_verified false when phone absent",
|
||||||
scope: "openid phone",
|
scope: "openid phone",
|
||||||
mutate: func(u *service.UserinfoResponse) { u.PhoneNumber = "" },
|
mutate: func(u *UserinfoResponse) { u.PhoneNumber = "" },
|
||||||
run: func(t *testing.T, info service.UserinfoResponse) {
|
run: func(t *testing.T, info UserinfoResponse) {
|
||||||
require.NotNil(t, info.PhoneNumberVerified)
|
require.NotNil(t, info.PhoneNumberVerified)
|
||||||
assert.False(t, *info.PhoneNumberVerified)
|
assert.False(t, *info.PhoneNumberVerified)
|
||||||
},
|
},
|
||||||
@@ -159,7 +158,7 @@ func TestCompileUserinfo(t *testing.T) {
|
|||||||
{
|
{
|
||||||
description: "address scope returns parsed address",
|
description: "address scope returns parsed address",
|
||||||
scope: "openid address",
|
scope: "openid address",
|
||||||
run: func(t *testing.T, info service.UserinfoResponse) {
|
run: func(t *testing.T, info UserinfoResponse) {
|
||||||
require.NotNil(t, info.Address)
|
require.NotNil(t, info.Address)
|
||||||
assert.Equal(t, "123 Main St", info.Address.Formatted)
|
assert.Equal(t, "123 Main St", info.Address.Formatted)
|
||||||
assert.Equal(t, "123 Main St", info.Address.StreetAddress)
|
assert.Equal(t, "123 Main St", info.Address.StreetAddress)
|
||||||
@@ -172,14 +171,14 @@ func TestCompileUserinfo(t *testing.T) {
|
|||||||
{
|
{
|
||||||
description: "groups scope returns split groups",
|
description: "groups scope returns split groups",
|
||||||
scope: "openid groups",
|
scope: "openid groups",
|
||||||
run: func(t *testing.T, info service.UserinfoResponse) {
|
run: func(t *testing.T, info UserinfoResponse) {
|
||||||
assert.Equal(t, []string{"admins", "users"}, info.Groups)
|
assert.Equal(t, []string{"admins", "users"}, info.Groups)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "all scopes return all fields",
|
description: "all scopes return all fields",
|
||||||
scope: "openid profile email phone address groups",
|
scope: "openid profile email phone address groups",
|
||||||
run: func(t *testing.T, info service.UserinfoResponse) {
|
run: func(t *testing.T, info UserinfoResponse) {
|
||||||
assert.Equal(t, "Test User", info.Name)
|
assert.Equal(t, "Test User", info.Name)
|
||||||
assert.Equal(t, "test@example.com", info.Email)
|
assert.Equal(t, "test@example.com", info.Email)
|
||||||
assert.Equal(t, "+15555550100", info.PhoneNumber)
|
assert.Equal(t, "+15555550100", info.PhoneNumber)
|
||||||
|
|||||||
@@ -1,10 +1,9 @@
|
|||||||
package service_test
|
package service
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/test"
|
"github.com/tinyauthapp/tinyauth/internal/test"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||||
)
|
)
|
||||||
@@ -12,14 +11,14 @@ import (
|
|||||||
// Create test rule
|
// Create test rule
|
||||||
type TestRule struct{}
|
type TestRule struct{}
|
||||||
|
|
||||||
func (rule *TestRule) Evaluate(ctx *service.ACLContext) service.Effect {
|
func (rule *TestRule) Evaluate(ctx *ACLContext) Effect {
|
||||||
switch ctx.Path {
|
switch ctx.Path {
|
||||||
case "/allowed":
|
case "/allowed":
|
||||||
return service.EffectAllow
|
return EffectAllow
|
||||||
case "/denied":
|
case "/denied":
|
||||||
return service.EffectDeny
|
return EffectDeny
|
||||||
default:
|
default:
|
||||||
return service.EffectAbstain
|
return EffectAbstain
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -33,32 +32,32 @@ func TestPolicyEngine(t *testing.T) {
|
|||||||
|
|
||||||
// Engine should fail with invalid policy
|
// Engine should fail with invalid policy
|
||||||
cfg.Auth.ACLs.Policy = "invalid_policy"
|
cfg.Auth.ACLs.Policy = "invalid_policy"
|
||||||
_, err := service.NewPolicyEngine(service.PolicyEngineInput{
|
_, err := NewPolicyEngine(PolicyEngineInput{
|
||||||
Log: log,
|
Log: log,
|
||||||
Config: &cfg,
|
Config: &cfg,
|
||||||
})
|
})
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
|
|
||||||
// Engine should initialize with 'allow' policy
|
// Engine should initialize with 'allow' policy
|
||||||
cfg.Auth.ACLs.Policy = string(service.PolicyAllow)
|
cfg.Auth.ACLs.Policy = string(PolicyAllow)
|
||||||
engine, err := service.NewPolicyEngine(service.PolicyEngineInput{
|
engine, err := NewPolicyEngine(PolicyEngineInput{
|
||||||
Log: log,
|
Log: log,
|
||||||
Config: &cfg,
|
Config: &cfg,
|
||||||
})
|
})
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, service.PolicyAllow, engine.Policy())
|
assert.Equal(t, PolicyAllow, engine.Policy())
|
||||||
|
|
||||||
// Engine should initialize with 'deny' policy
|
// Engine should initialize with 'deny' policy
|
||||||
cfg.Auth.ACLs.Policy = string(service.PolicyDeny)
|
cfg.Auth.ACLs.Policy = string(PolicyDeny)
|
||||||
engine, err = service.NewPolicyEngine(service.PolicyEngineInput{
|
engine, err = NewPolicyEngine(PolicyEngineInput{
|
||||||
Log: log,
|
Log: log,
|
||||||
Config: &cfg,
|
Config: &cfg,
|
||||||
})
|
})
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, service.PolicyDeny, engine.Policy())
|
assert.Equal(t, PolicyDeny, engine.Policy())
|
||||||
|
|
||||||
// Engine should allow adding rules
|
// Engine should allow adding rules
|
||||||
engine, err = service.NewPolicyEngine(service.PolicyEngineInput{
|
engine, err = NewPolicyEngine(PolicyEngineInput{
|
||||||
Log: log,
|
Log: log,
|
||||||
Config: &cfg,
|
Config: &cfg,
|
||||||
})
|
})
|
||||||
@@ -68,8 +67,8 @@ func TestPolicyEngine(t *testing.T) {
|
|||||||
assert.True(t, ok)
|
assert.True(t, ok)
|
||||||
|
|
||||||
// Begin allow policy tests
|
// Begin allow policy tests
|
||||||
cfg.Auth.ACLs.Policy = string(service.PolicyAllow)
|
cfg.Auth.ACLs.Policy = string(PolicyAllow)
|
||||||
engine, err = service.NewPolicyEngine(service.PolicyEngineInput{
|
engine, err = NewPolicyEngine(PolicyEngineInput{
|
||||||
Log: log,
|
Log: log,
|
||||||
Config: &cfg,
|
Config: &cfg,
|
||||||
})
|
})
|
||||||
@@ -77,7 +76,7 @@ func TestPolicyEngine(t *testing.T) {
|
|||||||
engine.RegisterRule("test-rule", testRule)
|
engine.RegisterRule("test-rule", testRule)
|
||||||
|
|
||||||
// With allow policy, if rule allows, access should be allowed
|
// With allow policy, if rule allows, access should be allowed
|
||||||
ctx := &service.ACLContext{Path: "/allowed"}
|
ctx := &ACLContext{Path: "/allowed"}
|
||||||
assert.Equal(t, true, engine.Evaluate("test-rule", ctx))
|
assert.Equal(t, true, engine.Evaluate("test-rule", ctx))
|
||||||
|
|
||||||
// With allow policy, if rule denies, access should be denied
|
// With allow policy, if rule denies, access should be denied
|
||||||
@@ -89,8 +88,8 @@ func TestPolicyEngine(t *testing.T) {
|
|||||||
assert.Equal(t, true, engine.Evaluate("test-rule", ctx))
|
assert.Equal(t, true, engine.Evaluate("test-rule", ctx))
|
||||||
|
|
||||||
// Begin deny policy tests
|
// Begin deny policy tests
|
||||||
cfg.Auth.ACLs.Policy = string(service.PolicyDeny)
|
cfg.Auth.ACLs.Policy = string(PolicyDeny)
|
||||||
engine, err = service.NewPolicyEngine(service.PolicyEngineInput{
|
engine, err = NewPolicyEngine(PolicyEngineInput{
|
||||||
Log: log,
|
Log: log,
|
||||||
Config: &cfg,
|
Config: &cfg,
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -138,8 +138,6 @@ func (ts *TailscaleService) Whois(ctx context.Context, addr string) (*TailscaleW
|
|||||||
NodeName: strings.TrimSuffix(who.Node.Name, "."),
|
NodeName: strings.TrimSuffix(who.Node.Name, "."),
|
||||||
}
|
}
|
||||||
|
|
||||||
ts.log.App.Debug().Interface("res", res).Msg("tailscale")
|
|
||||||
|
|
||||||
return &res, nil
|
return &res, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -76,6 +76,50 @@ func CreateTestConfigs(t *testing.T) (model.Config, model.RuntimeConfig) {
|
|||||||
Bypass: []string{"10.10.10.10"},
|
Bypass: []string{"10.10.10.10"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
"ip_block": {
|
||||||
|
Config: model.AppConfig{
|
||||||
|
Domain: "ip-block.example.com",
|
||||||
|
},
|
||||||
|
IP: model.AppIP{
|
||||||
|
Block: []string{"10.10.10.10"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"oauth_group": {
|
||||||
|
Config: model.AppConfig{
|
||||||
|
Domain: "oauth-group.example.com",
|
||||||
|
},
|
||||||
|
OAuth: model.AppOAuth{
|
||||||
|
Whitelist: "testuser@example.com",
|
||||||
|
Groups: "group1,group2",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"ldap_group": {
|
||||||
|
Config: model.AppConfig{
|
||||||
|
Domain: "ldap-group.example.com",
|
||||||
|
},
|
||||||
|
LDAP: model.AppLDAP{
|
||||||
|
Groups: "group1,group2",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"basic_auth": {
|
||||||
|
Config: model.AppConfig{
|
||||||
|
Domain: "basic-auth.example.com",
|
||||||
|
},
|
||||||
|
Response: model.AppResponse{
|
||||||
|
BasicAuth: model.AppBasicAuth{
|
||||||
|
Username: "test",
|
||||||
|
Password: "password",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"response_headers": {
|
||||||
|
Config: model.AppConfig{
|
||||||
|
Domain: "response-headers.example.com",
|
||||||
|
},
|
||||||
|
Response: model.AppResponse{
|
||||||
|
Headers: []string{"x-foo=bar"},
|
||||||
|
},
|
||||||
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -121,6 +165,10 @@ func CreateTestConfigs(t *testing.T) (model.Config, model.RuntimeConfig) {
|
|||||||
CookieDomain: "example.com",
|
CookieDomain: "example.com",
|
||||||
AppURL: "https://tinyauth.example.com",
|
AppURL: "https://tinyauth.example.com",
|
||||||
SessionCookieName: "tinyauth-session",
|
SessionCookieName: "tinyauth-session",
|
||||||
|
TrustedDomains: []string{
|
||||||
|
"https://tinyauth.example.com",
|
||||||
|
"https://tinyauth.foo.com",
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
return config, runtime
|
return config, runtime
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ package utils
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
|
||||||
"net"
|
"net"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -88,23 +87,3 @@ func Filter[T any](slice []T, test func(T) bool) (res []T) {
|
|||||||
}
|
}
|
||||||
return res
|
return res
|
||||||
}
|
}
|
||||||
|
|
||||||
func IsRedirectSafe(redirectURL string, domain string) bool {
|
|
||||||
if redirectURL == "" {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
parsed, err := url.Parse(redirectURL)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
hostname := parsed.Hostname()
|
|
||||||
|
|
||||||
if strings.HasSuffix(hostname, fmt.Sprintf(".%s", domain)) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
return hostname == domain
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -126,61 +126,6 @@ func TestFilter(t *testing.T) {
|
|||||||
assert.Equal(t, expectedStr, resultStr)
|
assert.Equal(t, expectedStr, resultStr)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestIsRedirectSafe(t *testing.T) {
|
|
||||||
// Setup
|
|
||||||
domain := "example.com"
|
|
||||||
|
|
||||||
// Case with no subdomain
|
|
||||||
redirectURL := "http://example.com/welcome"
|
|
||||||
result := utils.IsRedirectSafe(redirectURL, domain)
|
|
||||||
assert.True(t, result)
|
|
||||||
|
|
||||||
// Case with different domain
|
|
||||||
redirectURL = "http://malicious.com/phishing"
|
|
||||||
result = utils.IsRedirectSafe(redirectURL, domain)
|
|
||||||
assert.False(t, result)
|
|
||||||
|
|
||||||
// Case with subdomain
|
|
||||||
redirectURL = "http://sub.example.com/page"
|
|
||||||
result = utils.IsRedirectSafe(redirectURL, domain)
|
|
||||||
assert.True(t, result)
|
|
||||||
|
|
||||||
// Case with sub-subdomain
|
|
||||||
redirectURL = "http://a.b.example.com/home"
|
|
||||||
result = utils.IsRedirectSafe(redirectURL, domain)
|
|
||||||
assert.True(t, result)
|
|
||||||
|
|
||||||
// Case with empty redirect URL
|
|
||||||
redirectURL = ""
|
|
||||||
result = utils.IsRedirectSafe(redirectURL, domain)
|
|
||||||
assert.False(t, result)
|
|
||||||
|
|
||||||
// Case with invalid URL
|
|
||||||
redirectURL = "http://[::1]:namedport"
|
|
||||||
result = utils.IsRedirectSafe(redirectURL, domain)
|
|
||||||
assert.False(t, result)
|
|
||||||
|
|
||||||
// Case with URL having port
|
|
||||||
redirectURL = "http://sub.example.com:8080/page"
|
|
||||||
result = utils.IsRedirectSafe(redirectURL, domain)
|
|
||||||
assert.True(t, result)
|
|
||||||
|
|
||||||
// Case with URL having different subdomain
|
|
||||||
redirectURL = "http://another.example.com/page"
|
|
||||||
result = utils.IsRedirectSafe(redirectURL, domain)
|
|
||||||
assert.True(t, result)
|
|
||||||
|
|
||||||
// Case with URL having different TLD
|
|
||||||
redirectURL = "http://example.org/page"
|
|
||||||
result = utils.IsRedirectSafe(redirectURL, domain)
|
|
||||||
assert.False(t, result)
|
|
||||||
|
|
||||||
// Case with malicious domain
|
|
||||||
redirectURL = "https://malicious-example.com/yoyo"
|
|
||||||
result = utils.IsRedirectSafe(redirectURL, domain)
|
|
||||||
assert.False(t, result)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGetStandaloneCookieDomain(t *testing.T) {
|
func TestGetStandaloneCookieDomain(t *testing.T) {
|
||||||
// Normal case
|
// Normal case
|
||||||
domain := "http://tinyauth.app"
|
domain := "http://tinyauth.app"
|
||||||
|
|||||||
Reference in New Issue
Block a user