Merge branch 'main' into feat/oidc-preserve-consent

This commit is contained in:
Stavros
2026-06-21 12:53:07 +03:00
54 changed files with 2348 additions and 839 deletions
+21 -16
View File
@@ -3,6 +3,7 @@ package controller
import (
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
"go.uber.org/dig"
"github.com/gin-gonic/gin"
)
@@ -71,29 +72,33 @@ type AppContextResponse struct {
App ACRApp `json:"app"`
}
type ContextController struct {
log *logger.Logger
config model.Config
runtime model.RuntimeConfig
type ContextControllerInput struct {
dig.In
Log *logger.Logger
Config *model.Config
Runtime *model.RuntimeConfig
RouterGroup *gin.RouterGroup `name:"apiRouterGroup"`
}
func NewContextController(
log *logger.Logger,
config model.Config,
runtimeConfig model.RuntimeConfig,
router *gin.RouterGroup,
) *ContextController {
type ContextController struct {
log *logger.Logger
config *model.Config
runtime *model.RuntimeConfig
}
func NewContextController(i ContextControllerInput) *ContextController {
controller := &ContextController{
log: log,
config: config,
runtime: runtimeConfig,
log: i.Log,
config: i.Config,
runtime: i.Runtime,
}
if !config.UI.WarningsEnabled {
log.App.Warn().Msg("UI warnings are disabled. This may lead to security issues if you are not careful. Make sure to enable warnings in production environments.")
if !i.Config.UI.WarningsEnabled {
i.Log.App.Warn().Msg("UI warnings are disabled. This may lead to security issues if you are not careful. Make sure to enable warnings in production environments.")
}
contextGroup := router.Group("/context")
contextGroup := i.RouterGroup.Group("/context")
contextGroup.GET("/user", controller.userContextHandler)
contextGroup.GET("/app", controller.appContextHandler)
+15 -11
View File
@@ -1,4 +1,4 @@
package controller_test
package controller
import (
"encoding/json"
@@ -9,7 +9,6 @@ import (
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/test"
"github.com/tinyauthapp/tinyauth/internal/utils"
@@ -33,22 +32,22 @@ func TestContextController(t *testing.T) {
middlewares: []gin.HandlerFunc{},
path: "/api/context/app",
expected: func() string {
expectedAppContextResponse := controller.AppContextResponse{
expectedAppContextResponse := AppContextResponse{
Status: 200,
Message: "Success",
Auth: controller.ACRAuth{
Auth: ACRAuth{
Providers: runtime.ConfiguredProviders,
},
OAuth: controller.ACROAuth{
OAuth: ACROAuth{
AutoRedirect: cfg.OAuth.AutoRedirect,
},
UI: controller.ACRUI{
UI: ACRUI{
Title: cfg.UI.Title,
ForgotPasswordMessage: cfg.UI.ForgotPasswordMessage,
BackgroundImage: cfg.UI.BackgroundImage,
WarningsEnabled: cfg.UI.WarningsEnabled,
},
App: controller.ACRApp{
App: ACRApp{
AppURL: runtime.AppURL,
CookieDomain: runtime.CookieDomain,
TrustedDomains: runtime.TrustedDomains,
@@ -64,7 +63,7 @@ func TestContextController(t *testing.T) {
middlewares: []gin.HandlerFunc{},
path: "/api/context/user",
expected: func() string {
expectedUserContextResponse := controller.UserContextResponse{
expectedUserContextResponse := UserContextResponse{
Status: 401,
Message: "Unauthorized",
}
@@ -92,10 +91,10 @@ func TestContextController(t *testing.T) {
},
path: "/api/context/user",
expected: func() string {
expectedUserContextResponse := controller.UserContextResponse{
expectedUserContextResponse := UserContextResponse{
Status: 200,
Message: "Success",
Auth: controller.UCRAuth{
Auth: UCRAuth{
Authenticated: true,
Username: "johndoe",
Name: "John Doe",
@@ -121,7 +120,12 @@ func TestContextController(t *testing.T) {
group := router.Group("/api")
gin.SetMode(gin.TestMode)
controller.NewContextController(log, cfg, runtime, group)
NewContextController(ContextControllerInput{
Log: log,
Config: &cfg,
Runtime: &runtime,
RouterGroup: group,
})
recorder := httptest.NewRecorder()
+13 -4
View File
@@ -1,15 +1,24 @@
package controller
import "github.com/gin-gonic/gin"
import (
"github.com/gin-gonic/gin"
"go.uber.org/dig"
)
type HealthController struct {
}
func NewHealthController(router *gin.RouterGroup) *HealthController {
type HealthControllerInput struct {
dig.In
RouterGroup *gin.RouterGroup `name:"apiRouterGroup"`
}
func NewHealthController(i HealthControllerInput) *HealthController {
controller := &HealthController{}
router.GET("/healthz", controller.healthHandler)
router.HEAD("/healthz", controller.healthHandler)
i.RouterGroup.GET("/healthz", controller.healthHandler)
i.RouterGroup.HEAD("/healthz", controller.healthHandler)
return controller
}
@@ -1,4 +1,4 @@
package controller_test
package controller
import (
"encoding/json"
@@ -9,7 +9,6 @@ import (
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/controller"
)
func TestHealthController(t *testing.T) {
@@ -55,7 +54,9 @@ func TestHealthController(t *testing.T) {
group := router.Group("/api")
gin.SetMode(gin.TestMode)
controller.NewHealthController(group)
NewHealthController(HealthControllerInput{
RouterGroup: group,
})
recorder := httptest.NewRecorder()
+82 -20
View File
@@ -3,6 +3,7 @@ package controller
import (
"fmt"
"net/http"
"net/url"
"strings"
"time"
@@ -11,6 +12,8 @@ import (
"github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
"github.com/weppos/publicsuffix-go/publicsuffix"
"go.uber.org/dig"
"github.com/gin-gonic/gin"
"github.com/google/go-querystring/query"
@@ -22,29 +25,30 @@ type OAuthRequest struct {
type OAuthController struct {
log *logger.Logger
config model.Config
runtime model.RuntimeConfig
helpers *model.RuntimeHelpers
config *model.Config
runtime *model.RuntimeConfig
auth *service.AuthService
}
func NewOAuthController(
log *logger.Logger,
config model.Config,
runtimeConfig model.RuntimeConfig,
helpers *model.RuntimeHelpers,
router *gin.RouterGroup,
auth *service.AuthService,
) *OAuthController {
type OAuthControllerInput struct {
dig.In
Log *logger.Logger
Config *model.Config
RuntimeConfig *model.RuntimeConfig
RouterGroup *gin.RouterGroup `name:"apiRouterGroup"`
AuthService *service.AuthService
}
func NewOAuthController(i OAuthControllerInput) *OAuthController {
controller := &OAuthController{
log: log,
config: config,
runtime: runtimeConfig,
helpers: helpers,
auth: auth,
log: i.Log,
config: i.Config,
runtime: i.RuntimeConfig,
auth: i.AuthService,
}
oauthGroup := router.Group("/oauth")
oauthGroup := i.RouterGroup.Group("/oauth")
oauthGroup.GET("/url/:provider", controller.oauthURLHandler)
oauthGroup.GET("/callback/:provider", controller.oauthCallbackHandler)
@@ -78,9 +82,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
}
if !controller.isOidcRequest(reqParams) {
isRedirectSafe := utils.IsRedirectSafe(reqParams.RedirectURI, controller.runtime.CookieDomain)
if !isRedirectSafe {
if !controller.isRedirectSafe(reqParams.RedirectURI) {
controller.log.App.Warn().Str("redirectUri", reqParams.RedirectURI).Msg("Unsafe redirect URI, ignoring")
reqParams.RedirectURI = ""
}
@@ -320,3 +322,63 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
func (controller *OAuthController) isOidcRequest(params service.OAuthCallbackParams) bool {
return params.LoginFor == string(FrontendLoginForOIDC)
}
func (controller *OAuthController) getCookieDomain() string {
if controller.config.Auth.SubdomainsEnabled {
return "." + controller.runtime.CookieDomain
}
return controller.runtime.CookieDomain
}
func (controller *OAuthController) isRedirectSafe(redirectURI string) bool {
u, err := url.Parse(redirectURI)
if err != nil || u.Host == "" || u.Scheme == "" {
return false
}
for _, allowed := range controller.runtime.TrustedDomains {
tu, err := url.Parse(allowed)
if err != nil {
controller.log.App.Error().Err(err).Str("allowed", allowed).Msg("Failed to parse trusted domain")
continue
}
if tu.Scheme != u.Scheme {
continue
}
// exact match
if strings.EqualFold(u.Host, tu.Host) {
return true
}
// if subdomains are disabled, end here
if !controller.config.Auth.SubdomainsEnabled {
continue
}
// get the root domain (e.g. tinyauth.example.com -> example.com or
// tinyauth.sub.example.com -> sub.example.com)
_, root, ok := strings.Cut(tu.Host, ".")
if !ok {
continue
}
root = strings.ToLower(root)
// check if the root domain is in the psl
_, err = publicsuffix.DomainFromListWithOptions(publicsuffix.DefaultList, root, nil)
if err != nil {
continue
}
// subdomain match
if strings.HasSuffix(strings.ToLower(u.Host), "."+root) {
return true
}
}
return false
}
@@ -0,0 +1,161 @@
package controller
import (
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/tinyauthapp/tinyauth/internal/test"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
)
func TestOAuthController(t *testing.T) {
log := logger.NewLogger().WithTestConfig()
log.Init()
cfg, runtime := test.CreateTestConfigs(t)
type testCase struct {
description string
run func(ctrl *OAuthController)
trustedDomains []string
subdomainsEnabled bool
}
tests := []testCase{
{
description: "Test exact match of redirect URI",
trustedDomains: []string{"https://tinyauth.example.com"},
subdomainsEnabled: true,
run: func(ctrl *OAuthController) {
redirectUri := "https://tinyauth.example.com"
assert.True(t, ctrl.isRedirectSafe(redirectUri))
},
},
{
description: "Test subdomain match of redirect URI",
trustedDomains: []string{"https://tinyauth.example.com"},
subdomainsEnabled: true,
run: func(ctrl *OAuthController) {
redirectUri := "https://sub.example.com"
assert.True(t, ctrl.isRedirectSafe(redirectUri))
},
},
{
description: "Test different trusted domain",
trustedDomains: []string{"https://tinyauth.example.com", "https://tinyauth.foo.com"},
subdomainsEnabled: true,
run: func(ctrl *OAuthController) {
redirectUri := "https://app.foo.com"
assert.True(t, ctrl.isRedirectSafe(redirectUri))
},
},
{
description: "Test invalid redirect URI",
run: func(ctrl *OAuthController) {
redirectUri := "https:/malicious"
assert.False(t, ctrl.isRedirectSafe(redirectUri))
},
},
{
description: "Test empty redirect URI",
run: func(ctrl *OAuthController) {
redirectUri := ""
assert.False(t, ctrl.isRedirectSafe(redirectUri))
},
},
{
description: "Test redirect URI with different scheme",
trustedDomains: []string{"https://tinyauth.example.com"},
subdomainsEnabled: true,
run: func(ctrl *OAuthController) {
redirectUri := "http://tinyauth.example.com"
assert.False(t, ctrl.isRedirectSafe(redirectUri))
},
},
{
description: "Test redirect URI with different port",
trustedDomains: []string{"https://tinyauth.example.com"},
subdomainsEnabled: true,
run: func(ctrl *OAuthController) {
redirectUri := "https://tinyauth.example.com:8080"
assert.False(t, ctrl.isRedirectSafe(redirectUri))
},
},
{
// weird case, subdomains enabled and domain without subdomain can't happen
description: "Test with trusted domain that's in PSL when split",
trustedDomains: []string{"https://example.com"}, // will become .com which we
// obviously don't want to allow
subdomainsEnabled: true,
run: func(ctrl *OAuthController) {
redirectUri := "https://sub.example.com"
assert.False(t, ctrl.isRedirectSafe(redirectUri))
},
},
{
description: "Test subdomain redirect URI when subdomains are disabled",
trustedDomains: []string{"https://tinyauth.example.com"},
subdomainsEnabled: false,
run: func(ctrl *OAuthController) {
redirectUri := "https://sub.tinyauth.example.com"
assert.False(t, ctrl.isRedirectSafe(redirectUri))
},
},
{
description: "Test domain like the .co.uk",
trustedDomains: []string{"https://example.co.uk"},
subdomainsEnabled: true,
run: func(ctrl *OAuthController) {
redirectUri := "https://sub.example.co.uk"
assert.False(t, ctrl.isRedirectSafe(redirectUri))
},
},
{
description: "Test domain like the .co.uk with subdomains disabled",
trustedDomains: []string{"https://example.co.uk"},
subdomainsEnabled: false,
run: func(ctrl *OAuthController) {
redirectUri := "https://example.co.uk"
assert.True(t, ctrl.isRedirectSafe(redirectUri))
},
},
{
description: "Test caps domain",
trustedDomains: []string{"https://TINYAUTH.ExAmpLe.com"},
subdomainsEnabled: true,
run: func(ctrl *OAuthController) {
redirectUri := "https://sUb.ExAmPle.com"
assert.True(t, ctrl.isRedirectSafe(redirectUri))
},
},
{
description: "Test edge case with @",
trustedDomains: []string{"https://tinyauth.example.com"},
subdomainsEnabled: true,
run: func(ctrl *OAuthController) {
redirectUri := "https://malicious.example.com@evil.com"
assert.False(t, ctrl.isRedirectSafe(redirectUri))
},
},
}
// TODO: add auth service
for _, tc := range tests {
t.Run(tc.description, func(t *testing.T) {
router := gin.Default()
group := router.Group("/api")
gin.SetMode(gin.TestMode)
// overwrite the trusted domains and subdomain setting for each test case
runtime.TrustedDomains = tc.trustedDomains
cfg.Auth.SubdomainsEnabled = tc.subdomainsEnabled
ctrl := NewOAuthController(OAuthControllerInput{
Log: log,
Config: &cfg,
RuntimeConfig: &runtime,
RouterGroup: group,
})
tc.run(ctrl)
})
}
}
+118 -55
View File
@@ -7,12 +7,14 @@ import (
"fmt"
"net/http"
"slices"
"strconv"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/gin-gonic/gin/binding"
"github.com/google/go-querystring/query"
"go.uber.org/dig"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/service"
@@ -32,9 +34,7 @@ type authorizeErrorParams struct {
type OIDCController struct {
log *logger.Logger
oidc *service.OIDCService
runtime model.RuntimeConfig
helpers *model.RuntimeHelpers
config model.Config
runtime *model.RuntimeConfig
}
type AuthorizeCallback struct {
@@ -72,37 +72,38 @@ type ClientCredentials struct {
}
type AuthorizeScreenParams struct {
LoginFor FrontendLoginFor `url:"login_for"`
OIDCTicket string `url:"oidc_ticket"`
OIDCScope string `url:"oidc_scope"`
OIDCName string `url:"oidc_name"`
OIDCShowConsent bool `url:"oidc_show_consent"`
LoginFor FrontendLoginFor `url:"login_for"`
OIDCTicket string `url:"oidc_ticket"`
OIDCScope string `url:"oidc_scope"`
OIDCName string `url:"oidc_name"`
OIDCPrompt service.OIDCPrompt `url:"oidc_prompt,omitempty"`
}
type AuthorizeCompleteRequest struct {
Ticket string `json:"ticket" binding:"required"`
}
func NewOIDCController(
log *logger.Logger,
oidcService *service.OIDCService,
runtimeConfig model.RuntimeConfig,
helpers *model.RuntimeHelpers,
config model.Config,
router *gin.RouterGroup,
mainRouter *gin.RouterGroup) *OIDCController {
type OIDCControllerInput struct {
dig.In
Log *logger.Logger
OIDCService *service.OIDCService
RuntimeConfig *model.RuntimeConfig
RouterGroup *gin.RouterGroup `name:"apiRouterGroup"`
MainRouter *gin.RouterGroup `name:"mainRouterGroup"`
}
func NewOIDCController(i OIDCControllerInput) *OIDCController {
controller := &OIDCController{
log: log,
oidc: oidcService,
runtime: runtimeConfig,
helpers: helpers,
config: config,
log: i.Log,
oidc: i.OIDCService,
runtime: i.RuntimeConfig,
}
mainRouter.POST("/authorize", controller.authorize)
mainRouter.GET("/authorize", controller.authorize)
i.MainRouter.POST("/authorize", controller.authorize)
i.MainRouter.GET("/authorize", controller.authorize)
oidcGroup := router.Group("/oidc")
oidcGroup := i.RouterGroup.Group("/oidc")
oidcGroup.POST("/authorize-complete", controller.authorizeComplete)
oidcGroup.POST("/token", controller.Token)
oidcGroup.GET("/userinfo", controller.Userinfo)
@@ -170,40 +171,106 @@ func (controller *OIDCController) authorize(c *gin.Context) {
return
}
prompts := controller.oidc.GetPrompt(req.Prompt)
if slices.Contains(prompts, service.OIDCPromptNone) && len(prompts) > 1 {
controller.authorizeError(c, authorizeErrorParams{
err: errors.New("invalid prompt"),
reason: "Invalid prompt",
reasonPublic: "The prompt parameters are invalid",
callback: req.RedirectURI,
callbackError: "invalid_request",
state: req.State,
})
return
}
userContext, err := new(model.UserContext).NewFromGin(c)
if err != nil {
if !errors.Is(err, model.ErrUserContextNotFound) {
controller.log.App.Warn().Err(err).Msg("Failed to get user context")
}
}
if (err != nil || !userContext.Authenticated) && slices.Contains(prompts, service.OIDCPromptNone) {
controller.authorizeError(c, authorizeErrorParams{
err: errors.New("user not logged in"),
reason: "User not logged in",
reasonPublic: "The user is not logged in",
callback: req.RedirectURI,
callbackError: "login_required",
state: req.State,
})
return
}
ticket := controller.oidc.CreateAuthorizeRequestTicket(*req)
// Check if we have consented before for this client and scope
consnetCookie, err := c.Cookie(controller.runtime.ConsentCookieName)
values := AuthorizeScreenParams{
LoginFor: FrontendLoginForOIDC,
OIDCTicket: ticket,
OIDCScope: req.Scope,
OIDCName: client.Name,
}
showConsent := true
if slices.Contains(prompts, service.OIDCPromptLogin) {
values.OIDCPrompt = service.OIDCPromptLogin
} else if slices.Contains(prompts, service.OIDCPromptNone) {
values.OIDCPrompt = service.OIDCPromptNone
}
if err == nil {
consentEntry, err := controller.oidc.GetConsentEntry(c, consnetCookie)
// If no prompt is already set, we can check if we can/should skip it based on the cookie
if values.OIDCPrompt == "" {
consnetCookie, err := c.Cookie(controller.runtime.ConsentCookieName)
if err == nil && consentEntry != nil {
if consentEntry.ClientID == req.ClientID && consentEntry.Scopes == req.Scope {
showConsent = false
}
} else {
if !errors.Is(err, sql.ErrNoRows) {
controller.log.App.Error().Err(err).Msg("Failed to get consent entry for consent cookie")
if err == nil {
consentEntry, err := controller.oidc.GetConsentEntry(c, consnetCookie)
if err == nil && consentEntry != nil {
if consentEntry.ClientID == req.ClientID && consentEntry.Scopes == req.Scope {
values.OIDCPrompt = service.OIDCPromptNone
}
} else {
if !errors.Is(err, sql.ErrNoRows) {
controller.log.App.Error().Err(err).Msg("Failed to get consent entry for consent cookie")
}
}
}
}
queries, err := query.Values(AuthorizeScreenParams{
LoginFor: FrontendLoginForOIDC,
OIDCTicket: ticket,
OIDCScope: req.Scope,
OIDCName: client.Name,
OIDCShowConsent: showConsent,
})
if req.MaxAge != "" && userContext != nil {
maxAge, err := strconv.Atoi(req.MaxAge)
if err != nil {
controller.authorizeError(c, authorizeErrorParams{
err: err,
reason: "Invalid max_age",
reasonPublic: "The max_age parameter is invalid",
callback: req.RedirectURI,
callbackError: "invalid_request",
state: req.State,
})
return
}
if userContext.Authenticated {
authTime := time.Unix(userContext.AuthTime, 0)
if authTime.Add(time.Duration(maxAge) * time.Second).Before(time.Now()) {
values.OIDCPrompt = service.OIDCPromptLogin
}
}
}
queries, err := query.Values(values)
if err != nil {
controller.authorizeError(c, authorizeErrorParams{
err: err,
reason: "Failed to compile authorize queries",
reasonPublic: "An internal error occured while processing your request",
err: err,
reason: "Failed to compile authorize queries",
reasonPublic: "An internal error occured while processing your request",
callback: req.RedirectURI,
callbackError: "server_error",
state: req.State,
})
return
}
@@ -231,16 +298,12 @@ func (controller *OIDCController) authorizeComplete(c *gin.Context) {
userContext, err := new(model.UserContext).NewFromGin(c)
if err != nil {
controller.authorizeError(c, authorizeErrorParams{
err: err,
reason: "Failed to get user context",
reasonPublic: "User is not logged in or the session is invalid",
json: true,
})
return
if !errors.Is(err, model.ErrUserContextNotFound) {
controller.log.App.Warn().Err(err).Msg("Failed to get user context")
}
}
if !userContext.Authenticated {
if err != nil || !userContext.Authenticated {
controller.authorizeError(c, authorizeErrorParams{
err: errors.New("err user not logged in"),
reason: "User not logged in",
@@ -475,7 +538,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
return
}
tokenRes, err := controller.oidc.GenerateAccessToken(c, client, *entry)
tokenRes, err := controller.oidc.GenerateAccessToken(c, client, *entry, entry.AuthTime)
if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to generate access token")
+40 -9
View File
@@ -1,4 +1,4 @@
package controller_test
package controller
import (
"context"
@@ -15,7 +15,6 @@ import (
"github.com/steveiliop56/ding"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
@@ -37,11 +36,17 @@ func TestOIDCController(t *testing.T) {
store := memory.New()
oidcService, err := service.NewOIDCService(log, cfg, runtime, store, dg)
oidcService, err := service.NewOIDCService(service.OIDCServiceInput{
Log: log,
Config: &cfg,
Runtime: &runtime,
Queries: store,
Ding: dg,
})
require.NoError(t, err)
// Middleware that injects an authenticated local user into the gin context,
// mimicking the context middleware that runs before the OIDC controller.
// mimicking the context middleware that runs before the OIDC
authedUser := func(c *gin.Context) {
c.Set("context", &model.UserContext{
Authenticated: true,
@@ -206,10 +211,30 @@ func TestOIDCController(t *testing.T) {
},
// --- authorize-complete ---
{
description: "Should fail if oidc is disabled",
oidcDisabled: true,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
body, err := json.Marshal(AuthorizeCompleteRequest{Ticket: "some-ticket"})
require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/authorize-complete", strings.NewReader(string(body)))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusOK, recorder.Code)
var res map[string]any
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &res))
redirectURI, ok := res["redirect_uri"].(string)
require.True(t, ok)
assert.Contains(t, redirectURI, oidcService.GetIssuer()+"/error")
},
},
{
description: "Authorize complete returns a JSON error when the user context is missing",
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
body, err := json.Marshal(controller.AuthorizeCompleteRequest{Ticket: "some-ticket"})
body, err := json.Marshal(AuthorizeCompleteRequest{Ticket: "some-ticket"})
require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/authorize-complete", strings.NewReader(string(body)))
@@ -239,7 +264,7 @@ func TestOIDCController(t *testing.T) {
},
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
body, err := json.Marshal(controller.AuthorizeCompleteRequest{Ticket: "some-ticket"})
body, err := json.Marshal(AuthorizeCompleteRequest{Ticket: "some-ticket"})
require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/authorize-complete", strings.NewReader(string(body)))
@@ -259,7 +284,7 @@ func TestOIDCController(t *testing.T) {
description: "Authorize complete returns a JSON error when the ticket is invalid",
middlewares: []gin.HandlerFunc{authedUser},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
body, err := json.Marshal(controller.AuthorizeCompleteRequest{Ticket: "nonexistent-ticket"})
body, err := json.Marshal(AuthorizeCompleteRequest{Ticket: "nonexistent-ticket"})
require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/authorize-complete", strings.NewReader(string(body)))
@@ -287,7 +312,7 @@ func TestOIDCController(t *testing.T) {
State: "state-123",
})
body, err := json.Marshal(controller.AuthorizeCompleteRequest{Ticket: ticket})
body, err := json.Marshal(AuthorizeCompleteRequest{Ticket: ticket})
require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/authorize-complete", strings.NewReader(string(body)))
@@ -833,7 +858,13 @@ func TestOIDCController(t *testing.T) {
svc = nil
}
controller.NewOIDCController(log, svc, runtime, helpers, cfg, group, &router.RouterGroup)
NewOIDCController(OIDCControllerInput{
Log: log,
OIDCService: svc,
RuntimeConfig: &runtime,
RouterGroup: group,
MainRouter: &router.RouterGroup,
})
recorder := httptest.NewRecorder()
+25 -20
View File
@@ -13,6 +13,7 @@ import (
"github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
"go.uber.org/dig"
"github.com/gin-gonic/gin"
"github.com/google/go-querystring/query"
@@ -53,29 +54,33 @@ type ProxyContext struct {
type ProxyController struct {
log *logger.Logger
runtime model.RuntimeConfig
runtime *model.RuntimeConfig
acls *service.AccessControlsService
auth *service.AuthService
policyEngine *service.PolicyEngine
}
func NewProxyController(
log *logger.Logger,
runtime model.RuntimeConfig,
router *gin.RouterGroup,
acls *service.AccessControlsService,
auth *service.AuthService,
policyEngine *service.PolicyEngine,
) *ProxyController {
type ProxyControllerInput struct {
dig.In
Log *logger.Logger
RuntimeConfig *model.RuntimeConfig
RouterGroup *gin.RouterGroup `name:"apiRouterGroup"`
ACLsService *service.AccessControlsService
AuthService *service.AuthService
PolicyEngine *service.PolicyEngine
}
func NewProxyController(i ProxyControllerInput) *ProxyController {
controller := &ProxyController{
log: log,
runtime: runtime,
acls: acls,
auth: auth,
policyEngine: policyEngine,
log: i.Log,
runtime: i.RuntimeConfig,
acls: i.ACLsService,
auth: i.AuthService,
policyEngine: i.PolicyEngine,
}
proxyGroup := router.Group("/auth")
proxyGroup := i.RouterGroup.Group("/auth")
proxyGroup.Any("/:proxy", controller.proxyHandler)
return controller
@@ -153,7 +158,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
return
}
c.Redirect(http.StatusTemporaryRedirect, redirectURL)
c.Redirect(http.StatusFound, redirectURL)
return
}
@@ -202,7 +207,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
return
}
c.Redirect(http.StatusTemporaryRedirect, redirectURL)
c.Redirect(http.StatusFound, redirectURL)
return
}
@@ -246,7 +251,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
return
}
c.Redirect(http.StatusTemporaryRedirect, redirectURL)
c.Redirect(http.StatusFound, redirectURL)
return
}
}
@@ -295,7 +300,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
return
}
c.Redirect(http.StatusTemporaryRedirect, redirectURL)
c.Redirect(http.StatusFound, redirectURL)
}
func (controller *ProxyController) setHeaders(c *gin.Context, acls *model.App) {
@@ -331,7 +336,7 @@ func (controller *ProxyController) handleError(c *gin.Context, proxyCtx ProxyCon
return
}
c.Redirect(http.StatusTemporaryRedirect, redirectURL)
c.Redirect(http.StatusFound, redirectURL)
}
func (controller *ProxyController) getHeader(c *gin.Context, header string) (string, bool) {
+358 -27
View File
@@ -1,7 +1,10 @@
package controller_test
package controller
import (
"context"
"encoding/base64"
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"testing"
@@ -10,7 +13,6 @@ import (
"github.com/steveiliop56/ding"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
"github.com/tinyauthapp/tinyauth/internal/service"
@@ -66,6 +68,17 @@ func TestProxyController(t *testing.T) {
}
tests := []testCase{
{
description: "Should get bad request on invalid proxy",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/auth/invalid", nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusBadRequest, recorder.Code)
assert.Contains(t, recorder.Body.String(), "Bad request")
},
},
{
description: "Default forward auth should be detected and used for traefik",
middlewares: []gin.HandlerFunc{},
@@ -77,7 +90,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("user-agent", browserUserAgent)
router.ServeHTTP(recorder, req)
assert.Equal(t, 307, recorder.Code)
assert.Equal(t, http.StatusFound, recorder.Code)
location := recorder.Header().Get("Location")
assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
assert.Contains(t, location, "login_for=app")
@@ -92,7 +105,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-original-url", "https://test.example.com/")
req.Header.Set("user-agent", browserUserAgent)
router.ServeHTTP(recorder, req)
assert.Equal(t, 401, recorder.Code)
assert.Equal(t, http.StatusUnauthorized, recorder.Code)
location := recorder.Header().Get("x-tinyauth-location")
assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
assert.Contains(t, location, "login_for=app")
@@ -108,7 +121,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("user-agent", browserUserAgent)
router.ServeHTTP(recorder, req)
assert.Equal(t, 307, recorder.Code)
assert.Equal(t, http.StatusFound, recorder.Code)
location := recorder.Header().Get("Location")
assert.Contains(t, location, url.QueryEscape("https://test.example.com/hello"))
assert.Contains(t, location, "login_for=app")
@@ -126,7 +139,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("user-agent", browserUserAgent)
router.ServeHTTP(recorder, req)
assert.Equal(t, 307, recorder.Code)
assert.Equal(t, http.StatusFound, recorder.Code)
location := recorder.Header().Get("Location")
assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
assert.Contains(t, location, "login_for=app")
@@ -143,7 +156,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-uri", "/")
req.Header.Set("user-agent", browserUserAgent)
router.ServeHTTP(recorder, req)
assert.Equal(t, 401, recorder.Code)
assert.Equal(t, http.StatusUnauthorized, recorder.Code)
location := recorder.Header().Get("x-tinyauth-location")
assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
assert.Contains(t, location, "login_for=app")
@@ -161,7 +174,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-uri", "/hello")
req.Header.Set("user-agent", browserUserAgent)
router.ServeHTTP(recorder, req)
assert.Equal(t, 307, recorder.Code)
assert.Equal(t, http.StatusFound, recorder.Code)
location := recorder.Header().Get("Location")
assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
assert.Contains(t, location, "login_for=app")
@@ -178,7 +191,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-uri", "/")
router.ServeHTTP(recorder, req)
assert.Equal(t, 401, recorder.Code)
assert.Equal(t, http.StatusUnauthorized, recorder.Code)
assert.Contains(t, recorder.Body.String(), `"status":401`)
assert.Contains(t, recorder.Body.String(), `"message":"Unauthorized"`)
},
@@ -193,7 +206,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-uri", "/")
router.ServeHTTP(recorder, req)
assert.Equal(t, 401, recorder.Code)
assert.Equal(t, http.StatusUnauthorized, recorder.Code)
assert.Contains(t, recorder.Body.String(), `"status":401`)
assert.Contains(t, recorder.Body.String(), `"message":"Unauthorized"`)
},
@@ -208,7 +221,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-uri", "/hello")
router.ServeHTTP(recorder, req)
assert.Equal(t, 401, recorder.Code)
assert.Equal(t, http.StatusUnauthorized, recorder.Code)
assert.Contains(t, recorder.Body.String(), `"status":401`)
assert.Contains(t, recorder.Body.String(), `"message":"Unauthorized"`)
},
@@ -225,7 +238,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-uri", "/")
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
assert.Equal(t, http.StatusOK, recorder.Code)
assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
@@ -241,7 +254,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-original-url", "https://test.example.com/")
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
assert.Equal(t, http.StatusOK, recorder.Code)
assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
@@ -258,7 +271,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-proto", "https")
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
assert.Equal(t, http.StatusOK, recorder.Code)
assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
@@ -273,7 +286,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/allowed")
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
assert.Equal(t, http.StatusOK, recorder.Code)
},
},
{
@@ -283,7 +296,7 @@ func TestProxyController(t *testing.T) {
req := httptest.NewRequest("GET", "/api/auth/nginx", nil)
req.Header.Set("x-original-url", "https://path-allow.example.com/allowed")
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
assert.Equal(t, http.StatusOK, recorder.Code)
},
},
{
@@ -294,7 +307,7 @@ func TestProxyController(t *testing.T) {
req.Host = "path-allow.example.com"
req.Header.Set("x-forwarded-proto", "https")
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
assert.Equal(t, http.StatusOK, recorder.Code)
},
},
{
@@ -307,7 +320,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-uri", "/")
req.Header.Set("x-forwarded-for", "10.10.10.10")
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
assert.Equal(t, http.StatusOK, recorder.Code)
},
},
{
@@ -318,7 +331,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-original-url", "https://ip-bypass.example.com/")
req.Header.Set("x-forwarded-for", "10.10.10.10")
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
assert.Equal(t, http.StatusOK, recorder.Code)
},
},
{
@@ -330,7 +343,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-for", "10.10.10.10")
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
assert.Equal(t, http.StatusOK, recorder.Code)
},
},
{
@@ -344,7 +357,7 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/")
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
assert.Equal(t, http.StatusOK, recorder.Code)
},
},
{
@@ -358,12 +371,301 @@ func TestProxyController(t *testing.T) {
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/")
router.ServeHTTP(recorder, req)
assert.Equal(t, 403, recorder.Code)
assert.Equal(t, http.StatusForbidden, recorder.Code)
assert.Equal(t, "", recorder.Header().Get("remote-user"))
assert.Equal(t, "", recorder.Header().Get("remote-name"))
assert.Equal(t, "", recorder.Header().Get("remote-email"))
},
},
{
description: "Test IP block rule, with non browser user agent",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
req.Header.Set("x-forwarded-host", "ip-block.example.com")
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/")
req.Header.Set("x-forwarded-for", "10.10.10.10")
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusForbidden, recorder.Code)
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), runtime.AppURL)
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "ip=10.10.10.10")
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "ip-block")
},
},
{
description: "Test IP block rule, with browser user agent",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
req.Header.Set("x-forwarded-host", "ip-block.example.com")
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/")
req.Header.Set("x-forwarded-for", "10.10.10.10")
req.Header.Set("user-agent", browserUserAgent)
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusFound, recorder.Code)
location := recorder.Header().Get("Location")
assert.Contains(t, location, url.QueryEscape("10.10.10.10"))
assert.Contains(t, location, url.QueryEscape("ip-block"))
assert.Contains(t, location, runtime.AppURL)
},
},
{
description: "OAuth allowed group",
middlewares: []gin.HandlerFunc{
func(ctx *gin.Context) {
ctx.Set("context", &model.UserContext{
Authenticated: true,
Provider: model.ProviderOAuth,
OAuth: &model.OAuthContext{
BaseContext: model.BaseContext{
Username: "testuser",
Name: "Testuser",
Email: "testuser@example.com",
},
Groups: []string{"group1"},
},
})
ctx.Next()
},
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
req.Header.Set("x-forwarded-host", "oauth-group.example.com")
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/")
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusOK, recorder.Code)
assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
assert.Equal(t, "group1", recorder.Header().Get("remote-groups"))
},
},
{
description: "OAuth not in required groups and non browser",
middlewares: []gin.HandlerFunc{
func(ctx *gin.Context) {
ctx.Set("context", &model.UserContext{
Authenticated: true,
Provider: model.ProviderOAuth,
OAuth: &model.OAuthContext{
BaseContext: model.BaseContext{
Username: "testuser",
Name: "Testuser",
Email: "testuser@example.com",
},
Groups: []string{"group3"},
},
})
ctx.Next()
},
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
req.Header.Set("x-forwarded-host", "oauth-group.example.com")
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/")
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusForbidden, recorder.Code)
assert.Equal(t, "", recorder.Header().Get("remote-user"))
assert.Equal(t, "", recorder.Header().Get("remote-name"))
assert.Equal(t, "", recorder.Header().Get("remote-email"))
assert.Equal(t, "", recorder.Header().Get("remote-groups"))
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), runtime.AppURL)
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "groupErr=true")
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "oauth-group")
},
},
{
description: "OAuth not in required groups and browser",
middlewares: []gin.HandlerFunc{
func(ctx *gin.Context) {
ctx.Set("context", &model.UserContext{
Authenticated: true,
Provider: model.ProviderOAuth,
OAuth: &model.OAuthContext{
BaseContext: model.BaseContext{
Username: "testuser",
Name: "Testuser",
Email: "testuser@example.com",
},
Groups: []string{"group3"},
},
})
ctx.Next()
},
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
req.Header.Set("x-forwarded-host", "oauth-group.example.com")
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/")
req.Header.Set("user-agent", browserUserAgent)
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusFound, recorder.Code)
location := recorder.Header().Get("Location")
assert.Contains(t, location, "groupErr=true")
assert.Contains(t, location, "oauth-group")
assert.Contains(t, location, runtime.AppURL)
},
},
{
description: "LDAP allowed group",
middlewares: []gin.HandlerFunc{
func(ctx *gin.Context) {
ctx.Set("context", &model.UserContext{
Authenticated: true,
Provider: model.ProviderLDAP,
LDAP: &model.LDAPContext{
BaseContext: model.BaseContext{
Username: "testuser",
Name: "Testuser",
Email: "testuser@example.com",
},
Groups: []string{"group1"},
},
})
ctx.Next()
},
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
req.Header.Set("x-forwarded-host", "ldap-group.example.com")
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/")
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusOK, recorder.Code)
assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
assert.Equal(t, "group1", recorder.Header().Get("remote-groups"))
},
},
{
description: "LDAP not in required groups and non browser",
middlewares: []gin.HandlerFunc{
func(ctx *gin.Context) {
ctx.Set("context", &model.UserContext{
Authenticated: true,
Provider: model.ProviderLDAP,
LDAP: &model.LDAPContext{
BaseContext: model.BaseContext{
Username: "testuser",
Name: "Testuser",
Email: "testuser@example.com",
},
Groups: []string{"group3"},
},
})
ctx.Next()
},
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
req.Header.Set("x-forwarded-host", "ldap-group.example.com")
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/")
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusForbidden, recorder.Code)
assert.Equal(t, "", recorder.Header().Get("remote-user"))
assert.Equal(t, "", recorder.Header().Get("remote-name"))
assert.Equal(t, "", recorder.Header().Get("remote-email"))
assert.Equal(t, "", recorder.Header().Get("remote-groups"))
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), runtime.AppURL)
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "groupErr=true")
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "ldap-group")
},
},
{
description: "LDAP not in required groups and browser",
middlewares: []gin.HandlerFunc{
func(ctx *gin.Context) {
ctx.Set("context", &model.UserContext{
Authenticated: true,
Provider: model.ProviderLDAP,
LDAP: &model.LDAPContext{
BaseContext: model.BaseContext{
Username: "testuser",
Name: "Testuser",
Email: "testuser@example.com",
},
Groups: []string{"group3"},
},
})
ctx.Next()
},
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
req.Header.Set("x-forwarded-host", "ldap-group.example.com")
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/")
req.Header.Set("user-agent", browserUserAgent)
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusFound, recorder.Code)
location := recorder.Header().Get("Location")
assert.Contains(t, location, "groupErr=true")
assert.Contains(t, location, "ldap-group")
assert.Contains(t, location, runtime.AppURL)
},
},
{
description: "Should add basic auth if it's in ACLs",
middlewares: []gin.HandlerFunc{
simpleCtx,
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
req.Header.Set("x-forwarded-host", "basic-auth.example.com")
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/")
req.Header.Set("authorization", "foo") // should be overridden by basic auth
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusOK, recorder.Code)
authorizationHeader := recorder.Header().Get("Authorization")
assert.NotEmpty(t, authorizationHeader)
assert.Equal(t, fmt.Sprintf("Basic %s", base64.StdEncoding.EncodeToString([]byte("test:password"))), authorizationHeader)
},
},
{
description: "Authorization header should be preserved when not basic auth acls",
middlewares: []gin.HandlerFunc{
simpleCtx,
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
req.Header.Set("x-forwarded-host", "test.example.com")
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/")
req.Header.Set("authorization", "Bearer mytoken")
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusOK, recorder.Code)
authorizationHeader := recorder.Header().Get("Authorization")
assert.NotEmpty(t, authorizationHeader)
assert.Equal(t, "Bearer mytoken", authorizationHeader)
},
},
{
description: "Should add response headers if present",
middlewares: []gin.HandlerFunc{
simpleCtx,
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
req.Header.Set("x-forwarded-host", "response-headers.example.com")
req.Header.Set("x-forwarded-proto", "https")
req.Header.Set("x-forwarded-uri", "/")
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusOK, recorder.Code)
assert.Equal(t, "bar", recorder.Header().Get("x-foo"))
},
},
}
store := memory.New()
@@ -371,10 +673,21 @@ func TestProxyController(t *testing.T) {
ctx := context.TODO()
dg := ding.New(ctx)
broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx)
aclsService := service.NewAccessControlsService(log, cfg, nil)
broker := service.NewOAuthBrokerService(service.OAuthBrokerServiceInput{
Log: log,
Runtime: &runtime,
Ctx: ctx,
})
aclsService := service.NewAccessControlsService(service.AccessControlServiceInput{
Log: log,
Config: &cfg,
LabelProvider: nil,
})
policyEngine, err := service.NewPolicyEngine(cfg, log)
policyEngine, err := service.NewPolicyEngine(service.PolicyEngineInput{
Log: log,
Config: &cfg,
})
require.NoError(t, err)
policyEngine.RegisterRule(service.RuleUserAllowed, &service.UserAllowedRule{
@@ -397,7 +710,18 @@ func TestProxyController(t *testing.T) {
Log: log,
})
authService := service.NewAuthService(log, cfg, runtime, helpers, ctx, dg, nil, store, broker, nil, policyEngine)
authService := service.NewAuthService(service.AuthServiceInput{
Log: log,
Config: &cfg,
Runtime: &runtime,
Ctx: ctx,
Ding: dg,
LDAP: nil,
Queries: store,
OAuthBroker: broker,
Tailscale: nil,
PolicyEngine: policyEngine,
})
for _, test := range tests {
t.Run(test.description, func(t *testing.T) {
@@ -412,7 +736,14 @@ func TestProxyController(t *testing.T) {
recorder := httptest.NewRecorder()
controller.NewProxyController(log, runtime, group, aclsService, authService, policyEngine)
NewProxyController(ProxyControllerInput{
Log: log,
RuntimeConfig: &runtime,
RouterGroup: group,
ACLsService: aclsService,
AuthService: authService,
PolicyEngine: policyEngine,
})
test.run(t, router, recorder)
})
+13 -8
View File
@@ -5,25 +5,30 @@ import (
"github.com/gin-gonic/gin"
"github.com/tinyauthapp/tinyauth/internal/model"
"go.uber.org/dig"
)
type ResourcesController struct {
config model.Config
config *model.Config
fileServer http.Handler
}
func NewResourcesController(
config model.Config,
router *gin.RouterGroup,
) *ResourcesController {
fileServer := http.StripPrefix("/resources", http.FileServer(http.Dir(config.Resources.Path)))
type ResourcesControllerInput struct {
dig.In
RouterGroup *gin.RouterGroup `name:"mainRouterGroup"`
Config *model.Config
}
func NewResourcesController(i ResourcesControllerInput) *ResourcesController {
fileServer := http.StripPrefix("/resources", http.FileServer(http.Dir(i.Config.Resources.Path)))
controller := &ResourcesController{
config: config,
config: i.Config,
fileServer: fileServer,
}
router.GET("/resources/*resource", controller.resourcesHandler)
i.RouterGroup.GET("/resources/*resource", controller.resourcesHandler)
return controller
}
@@ -1,4 +1,4 @@
package controller_test
package controller
import (
"net/http/httptest"
@@ -9,7 +9,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/test"
)
@@ -19,8 +19,12 @@ func TestResourcesController(t *testing.T) {
err := os.MkdirAll(cfg.Resources.Path, 0777)
require.NoError(t, err)
// create a "backup" of the original configuration to restore after each test
originalCfg := cfg.Resources
type testCase struct {
description string
customCfg *model.ResourcesConfig
run func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder)
}
@@ -53,6 +57,32 @@ func TestResourcesController(t *testing.T) {
assert.Equal(t, 404, recorder.Code)
},
},
{
description: "Ensure resources controller returns 404 when resources path is empty",
customCfg: &model.ResourcesConfig{
Path: "",
Enabled: true,
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/resources/testfile.txt", nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 404, recorder.Code)
},
},
{
description: "Ensure resources controller returns 403 when resources are disabled",
customCfg: &model.ResourcesConfig{
Path: cfg.Resources.Path,
Enabled: false,
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/resources/testfile.txt", nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 403, recorder.Code)
},
},
}
testFilePath := cfg.Resources.Path + "/testfile.txt"
@@ -69,7 +99,18 @@ func TestResourcesController(t *testing.T) {
group := router.Group("/")
gin.SetMode(gin.TestMode)
controller.NewResourcesController(cfg, group)
// if custom configuration is provided, override the default config
if test.customCfg != nil {
cfg.Resources = *test.customCfg
} else {
// Reset to default configuration for each test
cfg.Resources = originalCfg
}
NewResourcesController(ResourcesControllerInput{
RouterGroup: group,
Config: &cfg,
})
recorder := httptest.NewRecorder()
test.run(t, router, recorder)
+16 -11
View File
@@ -11,6 +11,7 @@ import (
"github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
"go.uber.org/dig"
"github.com/gin-gonic/gin"
"github.com/pquerna/otp/totp"
@@ -27,23 +28,27 @@ type TotpRequest struct {
type UserController struct {
log *logger.Logger
runtime model.RuntimeConfig
runtime *model.RuntimeConfig
auth *service.AuthService
}
func NewUserController(
log *logger.Logger,
runtimeConfig model.RuntimeConfig,
router *gin.RouterGroup,
auth *service.AuthService,
) *UserController {
type UserControllerInput struct {
dig.In
Log *logger.Logger
RuntimeConfig *model.RuntimeConfig
RouterGroup *gin.RouterGroup `name:"apiRouterGroup"`
AuthService *service.AuthService
}
func NewUserController(i UserControllerInput) *UserController {
controller := &UserController{
log: log,
runtime: runtimeConfig,
auth: auth,
log: i.Log,
runtime: i.RuntimeConfig,
auth: i.AuthService,
}
userGroup := router.Group("/user")
userGroup := i.RouterGroup.Group("/user")
userGroup.POST("/login", controller.loginHandler)
userGroup.POST("/logout", controller.logoutHandler)
userGroup.POST("/totp", controller.totpHandler)
+156 -16
View File
@@ -1,4 +1,4 @@
package controller_test
package controller
import (
"context"
@@ -14,7 +14,6 @@ import (
"github.com/steveiliop56/ding"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
@@ -44,6 +43,7 @@ func TestUserController(t *testing.T) {
TOTPPending: true,
},
})
c.Next()
}
totpAttrCtx := func(c *gin.Context) {
@@ -59,6 +59,7 @@ func TestUserController(t *testing.T) {
TOTPPending: true,
},
})
c.Next()
}
simpleCtx := func(c *gin.Context) {
@@ -73,6 +74,7 @@ func TestUserController(t *testing.T) {
},
},
})
c.Next()
}
store := memory.New()
@@ -84,11 +86,45 @@ func TestUserController(t *testing.T) {
}
tests := []testCase{
{
description: "Login should fail gracefully on invalid json",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(`{"username": "testuser", "password":`))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(recorder, req)
assert.Equal(t, 400, recorder.Code)
assert.Contains(t, recorder.Body.String(), "Bad Request")
},
},
{
description: "Should fail on missing user",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
loginReq := LoginRequest{
Username: "nonexistentuser",
Password: "password",
}
loginReqBody, err := json.Marshal(loginReq)
require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqBody)))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(recorder, req)
assert.Equal(t, 401, recorder.Code)
assert.Len(t, recorder.Result().Cookies(), 0)
assert.Contains(t, recorder.Body.String(), "Unauthorized")
},
},
{
description: "Should be able to login with valid credentials",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
loginReq := controller.LoginRequest{
loginReq := LoginRequest{
Username: "testuser",
Password: "password",
}
@@ -116,7 +152,7 @@ func TestUserController(t *testing.T) {
description: "Should reject login with invalid credentials",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
loginReq := controller.LoginRequest{
loginReq := LoginRequest{
Username: "testuser",
Password: "wrongpassword",
}
@@ -137,7 +173,7 @@ func TestUserController(t *testing.T) {
description: "Should rate limit on 3 invalid attempts",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
loginReq := controller.LoginRequest{
loginReq := LoginRequest{
Username: "testuser",
Password: "wrongpassword",
}
@@ -172,7 +208,7 @@ func TestUserController(t *testing.T) {
description: "Should not allow full login with totp",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
loginReq := controller.LoginRequest{
loginReq := LoginRequest{
Username: "totpuser",
Password: "password",
}
@@ -209,7 +245,7 @@ func TestUserController(t *testing.T) {
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
// First login to get a session cookie
loginReq := controller.LoginRequest{
loginReq := LoginRequest{
Username: "testuser",
Password: "password",
}
@@ -245,6 +281,87 @@ func TestUserController(t *testing.T) {
assert.Equal(t, -1, cookie.MaxAge) // MaxAge -1 means delete cookie
},
},
{
description: "Logout should be treated as valid without a session cookie",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("POST", "/api/user/logout", nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
},
},
{
description: "TOTP should gracefully reject invalid json",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(`{"code":`))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(recorder, req)
assert.Equal(t, 400, recorder.Code)
assert.Contains(t, recorder.Body.String(), "Bad Request")
},
},
{
description: "TOTP should fail on non-totp context",
middlewares: []gin.HandlerFunc{
simpleCtx,
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
totpReq := TotpRequest{
Code: "123456",
}
totpReqBody, err := json.Marshal(totpReq)
require.NoError(t, err)
recorder = httptest.NewRecorder()
req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(totpReqBody)))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(recorder, req)
assert.Equal(t, 401, recorder.Code)
assert.Contains(t, recorder.Body.String(), "Unauthorized")
},
},
{
description: "TOTP should fail when user in context doesn't exist",
middlewares: []gin.HandlerFunc{
func(ctx *gin.Context) {
ctx.Set("context", &model.UserContext{
Authenticated: false,
Provider: model.ProviderLocal,
Local: &model.LocalContext{
BaseContext: model.BaseContext{
Username: "idontexist",
Name: "Totpuser",
Email: "totpuser@example.com",
},
TOTPPending: true,
},
})
ctx.Next()
},
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
totpReq := TotpRequest{
Code: "123456",
}
totpReqBody, err := json.Marshal(totpReq)
require.NoError(t, err)
recorder = httptest.NewRecorder()
req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(totpReqBody)))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(recorder, req)
assert.Equal(t, 401, recorder.Code)
assert.Contains(t, recorder.Body.String(), "Unauthorized")
},
},
{
description: "Should be able to login with totp",
middlewares: []gin.HandlerFunc{
@@ -266,7 +383,7 @@ func TestUserController(t *testing.T) {
code, err := totp.GenerateCode("JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", time.Now())
require.NoError(t, err)
totpReq := controller.TotpRequest{
totpReq := TotpRequest{
Code: code,
}
@@ -304,7 +421,7 @@ func TestUserController(t *testing.T) {
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
for range 3 {
totpReq := controller.TotpRequest{
totpReq := TotpRequest{
Code: "000000", // invalid code
}
@@ -336,7 +453,7 @@ func TestUserController(t *testing.T) {
description: "Login uses name and email from user attributes",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
loginReq := controller.LoginRequest{Username: "attruser", Password: "password"}
loginReq := LoginRequest{Username: "attruser", Password: "password"}
body, err := json.Marshal(loginReq)
require.NoError(t, err)
@@ -354,7 +471,7 @@ func TestUserController(t *testing.T) {
description: "Login with TOTP uses name and email from user attributes in pending session",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
loginReq := controller.LoginRequest{Username: "attrtotpuser", Password: "password"}
loginReq := LoginRequest{Username: "attrtotpuser", Password: "password"}
body, err := json.Marshal(loginReq)
require.NoError(t, err)
@@ -390,7 +507,7 @@ func TestUserController(t *testing.T) {
code, err := totp.GenerateCode("JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", time.Now())
require.NoError(t, err)
totpReq := controller.TotpRequest{Code: code}
totpReq := TotpRequest{Code: code}
body, err := json.Marshal(totpReq)
require.NoError(t, err)
@@ -416,11 +533,29 @@ func TestUserController(t *testing.T) {
ctx := context.TODO()
dg := ding.New(ctx)
policyEngine, err := service.NewPolicyEngine(cfg, log)
policyEngine, err := service.NewPolicyEngine(service.PolicyEngineInput{
Log: log,
Config: &cfg,
})
require.NoError(t, err)
broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx)
authService := service.NewAuthService(log, cfg, runtime, helpers, ctx, dg, nil, store, broker, nil, policyEngine)
broker := service.NewOAuthBrokerService(service.OAuthBrokerServiceInput{
Log: log,
Runtime: &runtime,
Ctx: ctx,
})
authService := service.NewAuthService(service.AuthServiceInput{
Log: log,
Config: &cfg,
Runtime: &runtime,
Ctx: ctx,
Ding: dg,
LDAP: nil,
Queries: store,
OAuthBroker: broker,
Tailscale: nil,
PolicyEngine: policyEngine,
})
beforeEach := func() {
// Clear failed login attempts before each test
@@ -439,7 +574,12 @@ func TestUserController(t *testing.T) {
group := router.Group("/api")
gin.SetMode(gin.TestMode)
controller.NewUserController(log, runtime, group, authService)
NewUserController(UserControllerInput{
Log: log,
RuntimeConfig: &runtime,
RouterGroup: group,
AuthService: authService,
})
recorder := httptest.NewRecorder()
+87 -4
View File
@@ -3,11 +3,27 @@ package controller
import (
"fmt"
"net/http"
"net/url"
"slices"
"strings"
"github.com/gin-gonic/gin"
"github.com/tinyauthapp/tinyauth/internal/service"
"go.uber.org/dig"
)
const OpenIDConnectRel = "http://openid.net/specs/connect/1.0/issuer"
type WebfingerResponseLink struct {
Rel string `json:"rel,omitempty"`
Href string `json:"href"`
}
type WebfingerResponse struct {
Subject string `json:"subject"`
Links []WebfingerResponseLink `json:"links"`
}
type OpenIDConnectConfiguration struct {
Issuer string `json:"issuer"`
AuthorizationEndpoint string `json:"authorization_endpoint"`
@@ -30,13 +46,21 @@ type WellKnownController struct {
oidc *service.OIDCService
}
func NewWellKnownController(oidc *service.OIDCService, router *gin.RouterGroup) *WellKnownController {
type WellKnownControllerInput struct {
dig.In
OIDCService *service.OIDCService
RouterGroup *gin.RouterGroup `name:"mainRouterGroup"`
}
func NewWellKnownController(i WellKnownControllerInput) *WellKnownController {
controller := &WellKnownController{
oidc: oidc,
oidc: i.OIDCService,
}
router.GET("/.well-known/openid-configuration", controller.OpenIDConnectConfiguration)
router.GET("/.well-known/jwks.json", controller.JWKS)
i.RouterGroup.GET("/.well-known/openid-configuration", controller.OpenIDConnectConfiguration)
i.RouterGroup.GET("/.well-known/jwks.json", controller.JWKS)
i.RouterGroup.GET("/.well-known/webfinger", controller.WebFinger)
return controller
}
@@ -97,3 +121,62 @@ func (controller *WellKnownController) JWKS(c *gin.Context) {
c.Status(http.StatusOK)
}
func (controller *WellKnownController) WebFinger(c *gin.Context) {
c.Header("Content-Type", "application/jrd+json")
c.Header("Access-Control-Allow-Origin", "*")
resource := c.Query("resource")
if !controller.validateWebFingerResource(resource) {
c.JSON(400, gin.H{
"status": 400,
"message": "invalid resource",
})
return
}
res := WebfingerResponse{
Subject: resource,
Links: []WebfingerResponseLink{},
}
rel := c.Request.URL.Query()["rel"]
if controller.oidc != nil && (len(rel) == 0 || slices.Contains(rel, OpenIDConnectRel)) {
res.Links = append(res.Links, WebfingerResponseLink{Rel: OpenIDConnectRel, Href: controller.oidc.GetIssuer()})
}
c.JSON(200, res)
}
func (controller *WellKnownController) validateWebFingerResource(resource string) bool {
prefix, suffix, found := strings.Cut(resource, ":")
if !found {
return false
}
switch prefix {
case "acct":
if strings.Count(suffix, "@") != 1 {
return false
}
username, domain, found := strings.Cut(suffix, "@")
if !found || username == "" || domain == "" {
return false
}
case "https", "http":
u, err := url.Parse(resource)
if err != nil {
return false
}
if u.Host == "" {
return false
}
default:
return false
}
return true
}
+213 -11
View File
@@ -1,17 +1,17 @@
package controller_test
package controller
import (
"context"
"encoding/json"
"fmt"
"net/http/httptest"
"net/url"
"testing"
"github.com/gin-gonic/gin"
"github.com/steveiliop56/ding"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
"github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/test"
@@ -26,23 +26,25 @@ func TestWellKnownController(t *testing.T) {
type testCase struct {
description string
oidcEnabled bool
run func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder)
}
tests := []testCase{
{
description: "Ensure well-known endpoint returns correct OIDC configuration",
oidcEnabled: true,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/.well-known/openid-configuration", nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
res := controller.OpenIDConnectConfiguration{}
res := OpenIDConnectConfiguration{}
err := json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err)
require.NoError(t, err)
expected := controller.OpenIDConnectConfiguration{
expected := OpenIDConnectConfiguration{
Issuer: runtime.AppURL,
AuthorizationEndpoint: fmt.Sprintf("%s/authorize", runtime.AppURL),
TokenEndpoint: fmt.Sprintf("%s/api/oidc/token", runtime.AppURL),
@@ -56,8 +58,8 @@ func TestWellKnownController(t *testing.T) {
TokenEndpointAuthMethodsSupported: []string{"client_secret_basic", "client_secret_post"},
ClaimsSupported: []string{"sub", "updated_at", "name", "preferred_username", "email", "email_verified", "groups", "phone_number", "phone_number_verified", "address", "given_name", "family_name", "middle_name", "nickname", "profile", "picture", "website", "gender", "birthdate", "zoneinfo", "locale"},
ServiceDocumentation: "https://tinyauth.app/docs/guides/oidc",
RequestParameterSupported: true,
RequestObjectSigningAlgValuesSupported: []string{"none"},
RequestParameterSupported: true,
}
assert.Equal(t, expected, res)
@@ -65,6 +67,7 @@ func TestWellKnownController(t *testing.T) {
},
{
description: "Ensure well-known endpoint returns correct JWKS",
oidcEnabled: true,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/.well-known/jwks.json", nil)
router.ServeHTTP(recorder, req)
@@ -73,19 +76,204 @@ func TestWellKnownController(t *testing.T) {
decodedBody := make(map[string]any)
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
assert.NoError(t, err)
require.NoError(t, err)
keys, ok := decodedBody["keys"].([]any)
assert.True(t, ok)
require.True(t, ok)
assert.Len(t, keys, 1)
keyData, ok := keys[0].(map[string]any)
assert.True(t, ok)
require.True(t, ok)
assert.Equal(t, "RSA", keyData["kty"])
assert.Equal(t, "sig", keyData["use"])
assert.Equal(t, "RS256", keyData["alg"])
},
},
{
description: "Ensure openid configuration returns 500 on nil oidc service",
oidcEnabled: false,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/.well-known/openid-configuration", nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 500, recorder.Code)
decodedBody := make(map[string]any)
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
require.NoError(t, err)
assert.Equal(t, "OIDC service not configured", decodedBody["message"])
},
},
{
description: "Ensure jwks endpoint returns 500 on nil oidc service",
oidcEnabled: false,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/.well-known/jwks.json", nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 500, recorder.Code)
decodedBody := make(map[string]any)
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
require.NoError(t, err)
assert.Equal(t, "OIDC service not configured", decodedBody["message"])
},
},
{
description: "Ensure webfinger returns 400 on invalid resource",
oidcEnabled: true,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/.well-known/webfinger?resource=invalid-resource", nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 400, recorder.Code)
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
decodedBody := make(map[string]any)
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
require.NoError(t, err)
assert.Equal(t, "invalid resource", decodedBody["message"])
},
},
{
description: "Ensure webfinger resource validator allows acct",
oidcEnabled: true,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
resource := "acct:testuser@example.com"
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s", url.QueryEscape(resource)), nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
},
},
{
description: "Ensure webfinger resource validator allows https",
oidcEnabled: true,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
resource := "https://example.com/testuser"
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s", url.QueryEscape(resource)), nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
},
},
{
description: "Ensure webfinger resource validator allows http",
oidcEnabled: true,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
resource := "http://example.com/testuser"
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s", url.QueryEscape(resource)), nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
},
},
{
description: "Webfinger should return no links when oidc is nil",
oidcEnabled: false,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
resource := "acct:testuser@example.com"
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s", url.QueryEscape(resource)), nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
decodedBody := make(map[string]any)
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
require.NoError(t, err)
links, ok := decodedBody["links"].([]any)
require.True(t, ok)
assert.Len(t, links, 0)
},
},
{
description: "Webfinger should return links when oidc is configured and no rel is provided",
oidcEnabled: true,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
resource := "acct:testuser@example.com"
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s", url.QueryEscape(resource)), nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
decodedBody := make(map[string]any)
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
require.NoError(t, err)
links, ok := decodedBody["links"].([]any)
require.True(t, ok)
assert.Len(t, links, 1)
linkData, ok := links[0].(map[string]any)
require.True(t, ok)
assert.Equal(t, "http://openid.net/specs/connect/1.0/issuer", linkData["rel"])
assert.Equal(t, runtime.AppURL, linkData["href"])
},
},
{
description: "Webfinger should return links when oidc is configured and rel is provided",
oidcEnabled: true,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
resource := fmt.Sprintf("acct:%s@%s", "testuser", runtime.AppURL)
rel := "http://openid.net/specs/connect/1.0/issuer"
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s&rel=%s", url.QueryEscape(resource), url.QueryEscape(rel)), nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
decodedBody := make(map[string]any)
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
require.NoError(t, err)
links, ok := decodedBody["links"].([]any)
require.True(t, ok)
assert.Len(t, links, 1)
linkData, ok := links[0].(map[string]any)
require.True(t, ok)
assert.Equal(t, rel, linkData["rel"])
assert.Equal(t, runtime.AppURL, linkData["href"])
},
},
{
description: "Webfinger should return no links when oidc is configured and rel is provided but does not match",
oidcEnabled: true,
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
resource := "acct:testuser@example.com"
rel := "http://example.com/does-not-exist"
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s&rel=%s", url.QueryEscape(resource), url.QueryEscape(rel)), nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
decodedBody := make(map[string]any)
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
require.NoError(t, err)
links, ok := decodedBody["links"].([]any)
require.True(t, ok)
assert.Len(t, links, 0)
},
},
}
ctx := context.TODO()
@@ -93,7 +281,13 @@ func TestWellKnownController(t *testing.T) {
store := memory.New()
oidcService, err := service.NewOIDCService(log, cfg, runtime, store, dg)
oidcService, err := service.NewOIDCService(service.OIDCServiceInput{
Log: log,
Config: &cfg,
Runtime: &runtime,
Queries: store,
Ding: dg,
})
require.NoError(t, err)
for _, test := range tests {
@@ -103,7 +297,15 @@ func TestWellKnownController(t *testing.T) {
recorder := httptest.NewRecorder()
controller.NewWellKnownController(oidcService, &router.RouterGroup)
wellKnownControllerInput := WellKnownControllerInput{
RouterGroup: &router.RouterGroup,
}
if test.oidcEnabled {
wellKnownControllerInput.OIDCService = oidcService
}
NewWellKnownController(wellKnownControllerInput)
test.run(t, router, recorder)
})