mirror of
https://github.com/steveiliop56/tinyauth.git
synced 2026-03-26 16:37:56 +00:00
Compare commits
3 Commits
refactor/t
...
feat/lockd
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
eb86fff0e4 | ||
|
|
6b0d804ba3 | ||
|
|
f9b1aeb23e |
3
go.mod
3
go.mod
@@ -17,7 +17,6 @@ require (
|
|||||||
github.com/mdp/qrterminal/v3 v3.2.1
|
github.com/mdp/qrterminal/v3 v3.2.1
|
||||||
github.com/pquerna/otp v1.5.0
|
github.com/pquerna/otp v1.5.0
|
||||||
github.com/rs/zerolog v1.34.0
|
github.com/rs/zerolog v1.34.0
|
||||||
github.com/stretchr/testify v1.11.1
|
|
||||||
github.com/traefik/paerser v0.2.2
|
github.com/traefik/paerser v0.2.2
|
||||||
github.com/weppos/publicsuffix-go v0.50.3
|
github.com/weppos/publicsuffix-go v0.50.3
|
||||||
golang.org/x/crypto v0.49.0
|
golang.org/x/crypto v0.49.0
|
||||||
@@ -53,7 +52,6 @@ require (
|
|||||||
github.com/containerd/errdefs v1.0.0 // indirect
|
github.com/containerd/errdefs v1.0.0 // indirect
|
||||||
github.com/containerd/errdefs/pkg v0.3.0 // indirect
|
github.com/containerd/errdefs/pkg v0.3.0 // indirect
|
||||||
github.com/containerd/log v0.1.0 // indirect
|
github.com/containerd/log v0.1.0 // indirect
|
||||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
|
|
||||||
github.com/distribution/reference v0.6.0 // indirect
|
github.com/distribution/reference v0.6.0 // indirect
|
||||||
github.com/docker/go-connections v0.5.0 // indirect
|
github.com/docker/go-connections v0.5.0 // indirect
|
||||||
github.com/docker/go-units v0.5.0 // indirect
|
github.com/docker/go-units v0.5.0 // indirect
|
||||||
@@ -98,7 +96,6 @@ require (
|
|||||||
github.com/opencontainers/image-spec v1.1.0 // indirect
|
github.com/opencontainers/image-spec v1.1.0 // indirect
|
||||||
github.com/pelletier/go-toml/v2 v2.2.4 // indirect
|
github.com/pelletier/go-toml/v2 v2.2.4 // indirect
|
||||||
github.com/pkg/errors v0.9.1 // indirect
|
github.com/pkg/errors v0.9.1 // indirect
|
||||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
|
|
||||||
github.com/quic-go/qpack v0.6.0 // indirect
|
github.com/quic-go/qpack v0.6.0 // indirect
|
||||||
github.com/quic-go/quic-go v0.59.0 // indirect
|
github.com/quic-go/quic-go v0.59.0 // indirect
|
||||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
||||||
|
|||||||
@@ -2,131 +2,152 @@ package controller_test
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"net/http"
|
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"github.com/steveiliop56/tinyauth/internal/config"
|
"github.com/steveiliop56/tinyauth/internal/config"
|
||||||
"github.com/steveiliop56/tinyauth/internal/controller"
|
"github.com/steveiliop56/tinyauth/internal/controller"
|
||||||
"github.com/steveiliop56/tinyauth/internal/utils"
|
"github.com/steveiliop56/tinyauth/internal/utils/tlog"
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"gotest.tools/v3/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestContextController(t *testing.T) {
|
var contextControllerCfg = controller.ContextControllerConfig{
|
||||||
controllerConfig := controller.ContextControllerConfig{
|
|
||||||
Providers: []controller.Provider{
|
Providers: []controller.Provider{
|
||||||
{
|
{
|
||||||
Name: "Local",
|
Name: "Local",
|
||||||
ID: "local",
|
ID: "local",
|
||||||
OAuth: false,
|
OAuth: false,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
Name: "Google",
|
||||||
|
ID: "google",
|
||||||
|
OAuth: true,
|
||||||
},
|
},
|
||||||
Title: "Tinyauth",
|
},
|
||||||
AppURL: "https://tinyauth.example.com",
|
Title: "Test App",
|
||||||
CookieDomain: "example.com",
|
AppURL: "http://localhost:8080",
|
||||||
ForgotPasswordMessage: "foo",
|
CookieDomain: "localhost",
|
||||||
BackgroundImage: "/background.jpg",
|
ForgotPasswordMessage: "Contact admin to reset your password.",
|
||||||
OAuthAutoRedirect: "none",
|
BackgroundImage: "/assets/bg.jpg",
|
||||||
|
OAuthAutoRedirect: "google",
|
||||||
WarningsEnabled: true,
|
WarningsEnabled: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
tests := []struct {
|
var contextCtrlTestContext = config.UserContext{
|
||||||
description string
|
Username: "testuser",
|
||||||
middlewares []gin.HandlerFunc
|
Name: "testuser",
|
||||||
expected string
|
Email: "test@example.com",
|
||||||
path string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
description: "Ensure context controller returns app context",
|
|
||||||
middlewares: []gin.HandlerFunc{},
|
|
||||||
path: "/api/context/app",
|
|
||||||
expected: func() string {
|
|
||||||
expectedAppContextResponse := controller.AppContextResponse{
|
|
||||||
Status: 200,
|
|
||||||
Message: "Success",
|
|
||||||
Providers: controllerConfig.Providers,
|
|
||||||
Title: controllerConfig.Title,
|
|
||||||
AppURL: controllerConfig.AppURL,
|
|
||||||
CookieDomain: controllerConfig.CookieDomain,
|
|
||||||
ForgotPasswordMessage: controllerConfig.ForgotPasswordMessage,
|
|
||||||
BackgroundImage: controllerConfig.BackgroundImage,
|
|
||||||
OAuthAutoRedirect: controllerConfig.OAuthAutoRedirect,
|
|
||||||
WarningsEnabled: controllerConfig.WarningsEnabled,
|
|
||||||
}
|
|
||||||
bytes, err := json.Marshal(expectedAppContextResponse)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
return string(bytes)
|
|
||||||
}(),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Ensure user context returns 401 when unauthorized",
|
|
||||||
middlewares: []gin.HandlerFunc{},
|
|
||||||
path: "/api/context/user",
|
|
||||||
expected: func() string {
|
|
||||||
expectedUserContextResponse := controller.UserContextResponse{
|
|
||||||
Status: 401,
|
|
||||||
Message: "Unauthorized",
|
|
||||||
}
|
|
||||||
bytes, err := json.Marshal(expectedUserContextResponse)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
return string(bytes)
|
|
||||||
}(),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Ensure user context returns when authorized",
|
|
||||||
middlewares: []gin.HandlerFunc{
|
|
||||||
func(c *gin.Context) {
|
|
||||||
c.Set("context", &config.UserContext{
|
|
||||||
Username: "johndoe",
|
|
||||||
Name: "John Doe",
|
|
||||||
Email: utils.CompileUserEmail("johndoe", controllerConfig.CookieDomain),
|
|
||||||
Provider: "local",
|
|
||||||
IsLoggedIn: true,
|
IsLoggedIn: true,
|
||||||
})
|
IsBasicAuth: false,
|
||||||
},
|
OAuth: false,
|
||||||
},
|
|
||||||
path: "/api/context/user",
|
|
||||||
expected: func() string {
|
|
||||||
expectedUserContextResponse := controller.UserContextResponse{
|
|
||||||
Status: 200,
|
|
||||||
Message: "Success",
|
|
||||||
IsLoggedIn: true,
|
|
||||||
Username: "johndoe",
|
|
||||||
Name: "John Doe",
|
|
||||||
Email: utils.CompileUserEmail("johndoe", controllerConfig.CookieDomain),
|
|
||||||
Provider: "local",
|
Provider: "local",
|
||||||
}
|
TotpPending: false,
|
||||||
bytes, err := json.Marshal(expectedUserContextResponse)
|
OAuthGroups: "",
|
||||||
assert.NoError(t, err)
|
TotpEnabled: false,
|
||||||
return string(bytes)
|
OAuthSub: "",
|
||||||
}(),
|
}
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, test := range tests {
|
func setupContextController(middlewares *[]gin.HandlerFunc) (*gin.Engine, *httptest.ResponseRecorder) {
|
||||||
t.Run(test.description, func(t *testing.T) {
|
tlog.NewSimpleLogger().Init()
|
||||||
|
|
||||||
|
// Setup
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
router := gin.Default()
|
router := gin.Default()
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
|
||||||
for _, middleware := range test.middlewares {
|
if middlewares != nil {
|
||||||
router.Use(middleware)
|
for _, m := range *middlewares {
|
||||||
|
router.Use(m)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
group := router.Group("/api")
|
group := router.Group("/api")
|
||||||
gin.SetMode(gin.TestMode)
|
|
||||||
|
|
||||||
contextController := controller.NewContextController(controllerConfig, group)
|
ctrl := controller.NewContextController(contextControllerCfg, group)
|
||||||
contextController.SetupRoutes()
|
ctrl.SetupRoutes()
|
||||||
|
|
||||||
recorder := httptest.NewRecorder()
|
return router, recorder
|
||||||
|
}
|
||||||
request, err := http.NewRequest("GET", test.path, nil)
|
|
||||||
assert.NoError(t, err)
|
func TestAppContextHandler(t *testing.T) {
|
||||||
|
expectedRes := controller.AppContextResponse{
|
||||||
router.ServeHTTP(recorder, request)
|
Status: 200,
|
||||||
|
Message: "Success",
|
||||||
assert.Equal(t, recorder.Result().StatusCode, http.StatusOK)
|
Providers: contextControllerCfg.Providers,
|
||||||
assert.Equal(t, test.expected, recorder.Body.String())
|
Title: contextControllerCfg.Title,
|
||||||
})
|
AppURL: contextControllerCfg.AppURL,
|
||||||
}
|
CookieDomain: contextControllerCfg.CookieDomain,
|
||||||
|
ForgotPasswordMessage: contextControllerCfg.ForgotPasswordMessage,
|
||||||
|
BackgroundImage: contextControllerCfg.BackgroundImage,
|
||||||
|
OAuthAutoRedirect: contextControllerCfg.OAuthAutoRedirect,
|
||||||
|
WarningsEnabled: contextControllerCfg.WarningsEnabled,
|
||||||
|
}
|
||||||
|
|
||||||
|
router, recorder := setupContextController(nil)
|
||||||
|
req := httptest.NewRequest("GET", "/api/context/app", nil)
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
assert.Equal(t, 200, recorder.Code)
|
||||||
|
|
||||||
|
var ctrlRes controller.AppContextResponse
|
||||||
|
|
||||||
|
err := json.Unmarshal(recorder.Body.Bytes(), &ctrlRes)
|
||||||
|
|
||||||
|
assert.NilError(t, err)
|
||||||
|
assert.DeepEqual(t, expectedRes, ctrlRes)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUserContextHandler(t *testing.T) {
|
||||||
|
expectedRes := controller.UserContextResponse{
|
||||||
|
Status: 200,
|
||||||
|
Message: "Success",
|
||||||
|
IsLoggedIn: contextCtrlTestContext.IsLoggedIn,
|
||||||
|
Username: contextCtrlTestContext.Username,
|
||||||
|
Name: contextCtrlTestContext.Name,
|
||||||
|
Email: contextCtrlTestContext.Email,
|
||||||
|
Provider: contextCtrlTestContext.Provider,
|
||||||
|
OAuth: contextCtrlTestContext.OAuth,
|
||||||
|
TotpPending: contextCtrlTestContext.TotpPending,
|
||||||
|
OAuthName: contextCtrlTestContext.OAuthName,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test with context
|
||||||
|
router, recorder := setupContextController(&[]gin.HandlerFunc{
|
||||||
|
func(c *gin.Context) {
|
||||||
|
c.Set("context", &contextCtrlTestContext)
|
||||||
|
c.Next()
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "/api/context/user", nil)
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
assert.Equal(t, 200, recorder.Code)
|
||||||
|
|
||||||
|
var ctrlRes controller.UserContextResponse
|
||||||
|
|
||||||
|
err := json.Unmarshal(recorder.Body.Bytes(), &ctrlRes)
|
||||||
|
|
||||||
|
assert.NilError(t, err)
|
||||||
|
assert.DeepEqual(t, expectedRes, ctrlRes)
|
||||||
|
|
||||||
|
// Test no context
|
||||||
|
expectedRes = controller.UserContextResponse{
|
||||||
|
Status: 401,
|
||||||
|
Message: "Unauthorized",
|
||||||
|
IsLoggedIn: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
router, recorder = setupContextController(nil)
|
||||||
|
req = httptest.NewRequest("GET", "/api/context/user", nil)
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
assert.Equal(t, 200, recorder.Code)
|
||||||
|
|
||||||
|
err = json.Unmarshal(recorder.Body.Bytes(), &ctrlRes)
|
||||||
|
|
||||||
|
assert.NilError(t, err)
|
||||||
|
assert.DeepEqual(t, expectedRes, ctrlRes)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ func (controller *HealthController) SetupRoutes() {
|
|||||||
|
|
||||||
func (controller *HealthController) healthHandler(c *gin.Context) {
|
func (controller *HealthController) healthHandler(c *gin.Context) {
|
||||||
c.JSON(200, gin.H{
|
c.JSON(200, gin.H{
|
||||||
"status": 200,
|
"status": "ok",
|
||||||
"message": "Healthy",
|
"message": "Healthy",
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,71 +0,0 @@
|
|||||||
package controller_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"net/http"
|
|
||||||
"net/http/httptest"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"github.com/steveiliop56/tinyauth/internal/controller"
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestHealthController(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
description string
|
|
||||||
path string
|
|
||||||
method string
|
|
||||||
expected string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
description: "Ensure health endpoint returns 200 OK",
|
|
||||||
path: "/api/healthz",
|
|
||||||
method: "GET",
|
|
||||||
expected: func() string {
|
|
||||||
expectedHealthResponse := map[string]any{
|
|
||||||
"status": 200,
|
|
||||||
"message": "Healthy",
|
|
||||||
}
|
|
||||||
bytes, err := json.Marshal(expectedHealthResponse)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
return string(bytes)
|
|
||||||
}(),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Ensure health endpoint returns 200 OK for HEAD request",
|
|
||||||
path: "/api/healthz",
|
|
||||||
method: "HEAD",
|
|
||||||
expected: func() string {
|
|
||||||
expectedHealthResponse := map[string]any{
|
|
||||||
"status": 200,
|
|
||||||
"message": "Healthy",
|
|
||||||
}
|
|
||||||
bytes, err := json.Marshal(expectedHealthResponse)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
return string(bytes)
|
|
||||||
}(),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, test := range tests {
|
|
||||||
t.Run(test.description, func(t *testing.T) {
|
|
||||||
router := gin.Default()
|
|
||||||
group := router.Group("/api")
|
|
||||||
gin.SetMode(gin.TestMode)
|
|
||||||
|
|
||||||
healthController := controller.NewHealthController(group)
|
|
||||||
healthController.SetupRoutes()
|
|
||||||
|
|
||||||
recorder := httptest.NewRecorder()
|
|
||||||
|
|
||||||
request, err := http.NewRequest(test.method, test.path, nil)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
router.ServeHTTP(recorder, request)
|
|
||||||
|
|
||||||
assert.Equal(t, http.StatusOK, recorder.Code)
|
|
||||||
assert.Equal(t, test.expected, recorder.Body.String())
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -235,7 +235,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
|
|||||||
|
|
||||||
if !ok {
|
if !ok {
|
||||||
tlog.App.Error().Msg("Missing authorization header")
|
tlog.App.Error().Msg("Missing authorization header")
|
||||||
c.Header("www-authenticate", `Basic realm="Tinyauth OIDC Token Endpoint"`)
|
c.Header("www-authenticate", "basic")
|
||||||
c.JSON(400, gin.H{
|
c.JSON(400, gin.H{
|
||||||
"error": "invalid_client",
|
"error": "invalid_client",
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -2,465 +2,280 @@ package controller_test
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/google/go-querystring/query"
|
||||||
"github.com/steveiliop56/tinyauth/internal/bootstrap"
|
"github.com/steveiliop56/tinyauth/internal/bootstrap"
|
||||||
"github.com/steveiliop56/tinyauth/internal/config"
|
"github.com/steveiliop56/tinyauth/internal/config"
|
||||||
"github.com/steveiliop56/tinyauth/internal/controller"
|
"github.com/steveiliop56/tinyauth/internal/controller"
|
||||||
"github.com/steveiliop56/tinyauth/internal/repository"
|
"github.com/steveiliop56/tinyauth/internal/repository"
|
||||||
"github.com/steveiliop56/tinyauth/internal/service"
|
"github.com/steveiliop56/tinyauth/internal/service"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/steveiliop56/tinyauth/internal/utils/tlog"
|
||||||
|
"gotest.tools/v3/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestOIDCController(t *testing.T) {
|
var oidcServiceConfig = service.OIDCServiceConfig{
|
||||||
oidcServiceCfg := service.OIDCServiceConfig{
|
|
||||||
Clients: map[string]config.OIDCClientConfig{
|
Clients: map[string]config.OIDCClientConfig{
|
||||||
"test": {
|
"client1": {
|
||||||
ClientID: "some-client-id",
|
ClientID: "some-client-id",
|
||||||
ClientSecret: "some-client-secret",
|
ClientSecret: "some-client-secret",
|
||||||
TrustedRedirectURIs: []string{"https://test.example.com/callback"},
|
ClientSecretFile: "",
|
||||||
Name: "Test Client",
|
TrustedRedirectURIs: []string{
|
||||||
|
"https://example.com/oauth/callback",
|
||||||
|
},
|
||||||
|
Name: "Client 1",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
PrivateKeyPath: "/tmp/tinyauth_testing_key.pem",
|
PrivateKeyPath: "/tmp/tinyauth_oidc_key",
|
||||||
PublicKeyPath: "/tmp/tinyauth_testing_key.pub",
|
PublicKeyPath: "/tmp/tinyauth_oidc_key.pub",
|
||||||
Issuer: "https://tinyauth.example.com",
|
Issuer: "https://example.com",
|
||||||
SessionExpiry: 500,
|
SessionExpiry: 3600,
|
||||||
}
|
}
|
||||||
|
|
||||||
controllerCfg := controller.OIDCControllerConfig{}
|
var oidcCtrlTestContext = config.UserContext{
|
||||||
|
|
||||||
simpleCtx := func(c *gin.Context) {
|
|
||||||
c.Set("context", &config.UserContext{
|
|
||||||
Username: "test",
|
Username: "test",
|
||||||
Name: "Test User",
|
Name: "Test",
|
||||||
Email: "test@example.com",
|
Email: "test@example.com",
|
||||||
IsLoggedIn: true,
|
IsLoggedIn: true,
|
||||||
Provider: "local",
|
IsBasicAuth: false,
|
||||||
})
|
OAuth: false,
|
||||||
c.Next()
|
Provider: "ldap", // ldap in order to test the groups
|
||||||
}
|
TotpPending: false,
|
||||||
|
OAuthGroups: "",
|
||||||
|
TotpEnabled: false,
|
||||||
|
OAuthName: "",
|
||||||
|
OAuthSub: "",
|
||||||
|
LdapGroups: "test1,test2",
|
||||||
|
}
|
||||||
|
|
||||||
type testCase struct {
|
// Test is not amazing, but it will confirm the OIDC server works
|
||||||
description string
|
func TestOIDCController(t *testing.T) {
|
||||||
middlewares []gin.HandlerFunc
|
tlog.NewSimpleLogger().Init()
|
||||||
run func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder)
|
|
||||||
}
|
|
||||||
|
|
||||||
var tests []testCase
|
|
||||||
|
|
||||||
getTestByDescription := func(description string) (func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder), bool) {
|
|
||||||
for _, test := range tests {
|
|
||||||
if test.description == description {
|
|
||||||
return test.run, true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil, false
|
|
||||||
}
|
|
||||||
|
|
||||||
tests = []testCase{
|
|
||||||
{
|
|
||||||
description: "Ensure we can fetch the client",
|
|
||||||
middlewares: []gin.HandlerFunc{},
|
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
|
||||||
req := httptest.NewRequest("GET", "/api/oidc/clients/some-client-id", nil)
|
|
||||||
router.ServeHTTP(recorder, req)
|
|
||||||
assert.Equal(t, 200, recorder.Code)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Ensure API fails on non-existent client ID",
|
|
||||||
middlewares: []gin.HandlerFunc{},
|
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
|
||||||
req := httptest.NewRequest("GET", "/api/oidc/clients/non-existent-client-id", nil)
|
|
||||||
router.ServeHTTP(recorder, req)
|
|
||||||
assert.Equal(t, 404, recorder.Code)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Ensure authorize fails with empty context",
|
|
||||||
middlewares: []gin.HandlerFunc{},
|
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
|
||||||
req := httptest.NewRequest("POST", "/api/oidc/authorize", nil)
|
|
||||||
router.ServeHTTP(recorder, req)
|
|
||||||
|
|
||||||
var res map[string]any
|
|
||||||
err := json.Unmarshal(recorder.Body.Bytes(), &res)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
assert.Equal(t, res["redirect_uri"], "https://tinyauth.example.com/error?error=User+is+not+logged+in+or+the+session+is+invalid")
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Ensure authorize fails with an invalid param",
|
|
||||||
middlewares: []gin.HandlerFunc{
|
|
||||||
simpleCtx,
|
|
||||||
},
|
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
|
||||||
reqBody := service.AuthorizeRequest{
|
|
||||||
Scope: "openid",
|
|
||||||
ResponseType: "some_unsupported_response_type",
|
|
||||||
ClientID: "some-client-id",
|
|
||||||
RedirectURI: "https://test.example.com/callback",
|
|
||||||
State: "some-state",
|
|
||||||
Nonce: "some-nonce",
|
|
||||||
}
|
|
||||||
reqBodyBytes, err := json.Marshal(reqBody)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes)))
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
router.ServeHTTP(recorder, req)
|
|
||||||
|
|
||||||
var res map[string]any
|
|
||||||
err = json.Unmarshal(recorder.Body.Bytes(), &res)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
assert.Equal(t, res["redirect_uri"], "https://test.example.com/callback?error=unsupported_response_type&error_description=Invalid+request+parameters&state=some-state")
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Ensure authorize succeeds with valid params",
|
|
||||||
middlewares: []gin.HandlerFunc{
|
|
||||||
simpleCtx,
|
|
||||||
},
|
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
|
||||||
reqBody := service.AuthorizeRequest{
|
|
||||||
Scope: "openid",
|
|
||||||
ResponseType: "code",
|
|
||||||
ClientID: "some-client-id",
|
|
||||||
RedirectURI: "https://test.example.com/callback",
|
|
||||||
State: "some-state",
|
|
||||||
Nonce: "some-nonce",
|
|
||||||
}
|
|
||||||
reqBodyBytes, err := json.Marshal(reqBody)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes)))
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
router.ServeHTTP(recorder, req)
|
|
||||||
assert.Equal(t, 200, recorder.Code)
|
|
||||||
|
|
||||||
var res map[string]any
|
|
||||||
err = json.Unmarshal(recorder.Body.Bytes(), &res)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
redirectURI := res["redirect_uri"].(string)
|
|
||||||
url, err := url.Parse(redirectURI)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
queryParams := url.Query()
|
|
||||||
assert.Equal(t, queryParams.Get("state"), "some-state")
|
|
||||||
|
|
||||||
code := queryParams.Get("code")
|
|
||||||
assert.NotEmpty(t, code)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Ensure token request fails with invalid grant",
|
|
||||||
middlewares: []gin.HandlerFunc{},
|
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
|
||||||
reqBody := controller.TokenRequest{
|
|
||||||
GrantType: "invalid_grant",
|
|
||||||
Code: "",
|
|
||||||
RedirectURI: "https://test.example.com/callback",
|
|
||||||
}
|
|
||||||
reqBodyBytes, err := json.Marshal(reqBody)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(string(reqBodyBytes)))
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
router.ServeHTTP(recorder, req)
|
|
||||||
|
|
||||||
var res map[string]any
|
|
||||||
err = json.Unmarshal(recorder.Body.Bytes(), &res)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
assert.Equal(t, res["error"], "unsupported_grant_type")
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Ensure token endpoint accepts basic auth",
|
|
||||||
middlewares: []gin.HandlerFunc{},
|
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
|
||||||
reqBody := controller.TokenRequest{
|
|
||||||
GrantType: "authorization_code",
|
|
||||||
Code: "some-code",
|
|
||||||
RedirectURI: "https://test.example.com/callback",
|
|
||||||
}
|
|
||||||
reqBodyBytes, err := json.Marshal(reqBody)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(string(reqBodyBytes)))
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
req.SetBasicAuth("some-client-id", "some-client-secret")
|
|
||||||
router.ServeHTTP(recorder, req)
|
|
||||||
|
|
||||||
assert.Empty(t, recorder.Header().Get("www-authenticate"))
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Ensure token endpoint accepts form auth",
|
|
||||||
middlewares: []gin.HandlerFunc{},
|
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
|
||||||
form := url.Values{}
|
|
||||||
form.Set("grant_type", "authorization_code")
|
|
||||||
form.Set("code", "some-code")
|
|
||||||
form.Set("redirect_uri", "https://test.example.com/callback")
|
|
||||||
form.Set("client_id", "some-client-id")
|
|
||||||
form.Set("client_secret", "some-client-secret")
|
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(form.Encode()))
|
|
||||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
||||||
router.ServeHTTP(recorder, req)
|
|
||||||
|
|
||||||
assert.Empty(t, recorder.Header().Get("www-authenticate"))
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Ensure token endpoint sets authenticate header when no auth is available",
|
|
||||||
middlewares: []gin.HandlerFunc{},
|
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
|
||||||
reqBody := controller.TokenRequest{
|
|
||||||
GrantType: "authorization_code",
|
|
||||||
Code: "some-code",
|
|
||||||
RedirectURI: "https://test.example.com/callback",
|
|
||||||
}
|
|
||||||
reqBodyBytes, err := json.Marshal(reqBody)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(string(reqBodyBytes)))
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
router.ServeHTTP(recorder, req)
|
|
||||||
|
|
||||||
authHeader := recorder.Header().Get("www-authenticate")
|
|
||||||
assert.Contains(t, authHeader, "Basic")
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Ensure we can get a token with a valid request",
|
|
||||||
middlewares: []gin.HandlerFunc{
|
|
||||||
simpleCtx,
|
|
||||||
},
|
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
|
||||||
authorizeCodeTest, found := getTestByDescription("Ensure authorize succeeds with valid params")
|
|
||||||
assert.True(t, found, "Authorize test not found")
|
|
||||||
authorizeTestRecorder := httptest.NewRecorder()
|
|
||||||
authorizeCodeTest(t, router, authorizeTestRecorder)
|
|
||||||
|
|
||||||
var authorizeRes map[string]any
|
|
||||||
err := json.Unmarshal(authorizeTestRecorder.Body.Bytes(), &authorizeRes)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
redirectURI := authorizeRes["redirect_uri"].(string)
|
|
||||||
url, err := url.Parse(redirectURI)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
queryParams := url.Query()
|
|
||||||
code := queryParams.Get("code")
|
|
||||||
assert.NotEmpty(t, code)
|
|
||||||
|
|
||||||
reqBody := controller.TokenRequest{
|
|
||||||
GrantType: "authorization_code",
|
|
||||||
Code: code,
|
|
||||||
RedirectURI: "https://test.example.com/callback",
|
|
||||||
}
|
|
||||||
reqBodyBytes, err := json.Marshal(reqBody)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(string(reqBodyBytes)))
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
req.SetBasicAuth("some-client-id", "some-client-secret")
|
|
||||||
router.ServeHTTP(recorder, req)
|
|
||||||
|
|
||||||
assert.Equal(t, 200, recorder.Code)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Ensure we can renew the access token with the refresh token",
|
|
||||||
middlewares: []gin.HandlerFunc{
|
|
||||||
simpleCtx,
|
|
||||||
},
|
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
|
||||||
tokenTest, found := getTestByDescription("Ensure we can get a token with a valid request")
|
|
||||||
assert.True(t, found, "Token test not found")
|
|
||||||
tokenRecorder := httptest.NewRecorder()
|
|
||||||
tokenTest(t, router, tokenRecorder)
|
|
||||||
|
|
||||||
var tokenRes map[string]any
|
|
||||||
err := json.Unmarshal(tokenRecorder.Body.Bytes(), &tokenRes)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
_, ok := tokenRes["refresh_token"]
|
|
||||||
assert.True(t, ok, "Expected refresh token in response")
|
|
||||||
refreshToken := tokenRes["refresh_token"].(string)
|
|
||||||
assert.NotEmpty(t, refreshToken)
|
|
||||||
|
|
||||||
reqBody := controller.TokenRequest{
|
|
||||||
GrantType: "refresh_token",
|
|
||||||
RefreshToken: refreshToken,
|
|
||||||
ClientID: "some-client-id",
|
|
||||||
ClientSecret: "some-client-secret",
|
|
||||||
}
|
|
||||||
reqBodyBytes, err := json.Marshal(reqBody)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(string(reqBodyBytes)))
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
router.ServeHTTP(recorder, req)
|
|
||||||
|
|
||||||
assert.NotEmpty(t, recorder.Header().Get("cache-control"))
|
|
||||||
assert.NotEmpty(t, recorder.Header().Get("pragma"))
|
|
||||||
|
|
||||||
assert.Equal(t, 200, recorder.Code)
|
|
||||||
var refreshRes map[string]any
|
|
||||||
err = json.Unmarshal(recorder.Body.Bytes(), &refreshRes)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
_, ok = refreshRes["access_token"]
|
|
||||||
assert.True(t, ok, "Expected access token in refresh response")
|
|
||||||
assert.NotEqual(t, tokenRes["refresh_token"].(string), refreshRes["access_token"].(string))
|
|
||||||
assert.NotEqual(t, tokenRes["access_token"].(string), refreshRes["access_token"].(string))
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Ensure token endpoint deletes code afer use",
|
|
||||||
middlewares: []gin.HandlerFunc{
|
|
||||||
simpleCtx,
|
|
||||||
},
|
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
|
||||||
authorizeCodeTest, found := getTestByDescription("Ensure authorize succeeds with valid params")
|
|
||||||
assert.True(t, found, "Authorize test not found")
|
|
||||||
authorizeTestRecorder := httptest.NewRecorder()
|
|
||||||
authorizeCodeTest(t, router, authorizeTestRecorder)
|
|
||||||
|
|
||||||
var authorizeRes map[string]any
|
|
||||||
err := json.Unmarshal(authorizeTestRecorder.Body.Bytes(), &authorizeRes)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
redirectURI := authorizeRes["redirect_uri"].(string)
|
|
||||||
url, err := url.Parse(redirectURI)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
queryParams := url.Query()
|
|
||||||
code := queryParams.Get("code")
|
|
||||||
assert.NotEmpty(t, code)
|
|
||||||
|
|
||||||
reqBody := controller.TokenRequest{
|
|
||||||
GrantType: "authorization_code",
|
|
||||||
Code: code,
|
|
||||||
RedirectURI: "https://test.example.com/callback",
|
|
||||||
}
|
|
||||||
reqBodyBytes, err := json.Marshal(reqBody)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(string(reqBodyBytes)))
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
req.SetBasicAuth("some-client-id", "some-client-secret")
|
|
||||||
router.ServeHTTP(recorder, req)
|
|
||||||
|
|
||||||
assert.Equal(t, 200, recorder.Code)
|
|
||||||
|
|
||||||
// Try to use the same code again
|
|
||||||
secondReq := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(string(reqBodyBytes)))
|
|
||||||
secondReq.Header.Set("Content-Type", "application/json")
|
|
||||||
secondReq.SetBasicAuth("some-client-id", "some-client-secret")
|
|
||||||
secondRecorder := httptest.NewRecorder()
|
|
||||||
router.ServeHTTP(secondRecorder, secondReq)
|
|
||||||
|
|
||||||
assert.Equal(t, 400, secondRecorder.Code)
|
|
||||||
|
|
||||||
var secondRes map[string]any
|
|
||||||
err = json.Unmarshal(secondRecorder.Body.Bytes(), &secondRes)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
assert.Equal(t, secondRes["error"], "invalid_grant")
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Ensure userinfo forbids access with invalid access token",
|
|
||||||
middlewares: []gin.HandlerFunc{},
|
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
|
||||||
req := httptest.NewRequest("GET", "/api/oidc/userinfo", nil)
|
|
||||||
req.Header.Set("Authorization", "Bearer invalid-access-token")
|
|
||||||
router.ServeHTTP(recorder, req)
|
|
||||||
assert.Equal(t, 401, recorder.Code)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Ensure access token can be used to access protected resources",
|
|
||||||
middlewares: []gin.HandlerFunc{
|
|
||||||
simpleCtx,
|
|
||||||
},
|
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
|
||||||
tokenTest, found := getTestByDescription("Ensure we can get a token with a valid request")
|
|
||||||
assert.True(t, found, "Token test not found")
|
|
||||||
tokenRecorder := httptest.NewRecorder()
|
|
||||||
tokenTest(t, router, tokenRecorder)
|
|
||||||
|
|
||||||
var tokenRes map[string]any
|
|
||||||
err := json.Unmarshal(tokenRecorder.Body.Bytes(), &tokenRes)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
accessToken := tokenRes["access_token"].(string)
|
|
||||||
assert.NotEmpty(t, accessToken)
|
|
||||||
|
|
||||||
protectedReq := httptest.NewRequest("GET", "/api/oidc/userinfo", nil)
|
|
||||||
protectedReq.Header.Set("Authorization", "Bearer "+accessToken)
|
|
||||||
router.ServeHTTP(recorder, protectedReq)
|
|
||||||
assert.Equal(t, 200, recorder.Code)
|
|
||||||
|
|
||||||
var userInfoRes map[string]any
|
|
||||||
err = json.Unmarshal(recorder.Body.Bytes(), &userInfoRes)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
_, ok := userInfoRes["sub"]
|
|
||||||
assert.True(t, ok, "Expected sub claim in userinfo response")
|
|
||||||
|
|
||||||
// We should not have an email claim since we didn't request it in the scope
|
|
||||||
_, ok = userInfoRes["email"]
|
|
||||||
assert.False(t, ok, "Did not expect email claim in userinfo response")
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
|
// Create an app instance
|
||||||
app := bootstrap.NewBootstrapApp(config.Config{})
|
app := bootstrap.NewBootstrapApp(config.Config{})
|
||||||
|
|
||||||
db, err := app.SetupDatabase("/tmp/tinyauth_test.db")
|
// Get db
|
||||||
|
db, err := app.SetupDatabase("/tmp/tinyauth.db")
|
||||||
if err != nil {
|
assert.NilError(t, err)
|
||||||
t.Fatalf("Failed to set up database: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
|
// Create queries
|
||||||
queries := repository.New(db)
|
queries := repository.New(db)
|
||||||
oidcService := service.NewOIDCService(oidcServiceCfg, queries)
|
|
||||||
|
// Create a new OIDC Servicee
|
||||||
|
oidcService := service.NewOIDCService(oidcServiceConfig, queries)
|
||||||
err = oidcService.Init()
|
err = oidcService.Init()
|
||||||
|
assert.NilError(t, err)
|
||||||
|
|
||||||
if err != nil {
|
// Create test router
|
||||||
t.Fatalf("Failed to initialize OIDC service: %v", err)
|
gin.SetMode(gin.TestMode)
|
||||||
}
|
|
||||||
|
|
||||||
for _, test := range tests {
|
|
||||||
t.Run(test.description, func(t *testing.T) {
|
|
||||||
router := gin.Default()
|
router := gin.Default()
|
||||||
|
|
||||||
for _, middleware := range test.middlewares {
|
router.Use(func(c *gin.Context) {
|
||||||
router.Use(middleware)
|
c.Set("context", &oidcCtrlTestContext)
|
||||||
}
|
c.Next()
|
||||||
|
})
|
||||||
|
|
||||||
group := router.Group("/api")
|
group := router.Group("/api")
|
||||||
gin.SetMode(gin.TestMode)
|
|
||||||
|
|
||||||
oidcController := controller.NewOIDCController(controllerCfg, oidcService, group)
|
// Register oidc controller
|
||||||
|
oidcController := controller.NewOIDCController(controller.OIDCControllerConfig{}, oidcService, group)
|
||||||
oidcController.SetupRoutes()
|
oidcController.SetupRoutes()
|
||||||
|
|
||||||
|
// Get redirect URL test
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
|
|
||||||
test.run(t, router, recorder)
|
marshalled, err := json.Marshal(service.AuthorizeRequest{
|
||||||
|
Scope: "openid profile email groups",
|
||||||
|
ResponseType: "code",
|
||||||
|
ClientID: "some-client-id",
|
||||||
|
RedirectURI: "https://example.com/oauth/callback",
|
||||||
|
State: "some-state",
|
||||||
})
|
})
|
||||||
|
|
||||||
|
assert.NilError(t, err)
|
||||||
|
|
||||||
|
req, err := http.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(marshalled)))
|
||||||
|
assert.NilError(t, err)
|
||||||
|
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||||
|
|
||||||
|
resJson := map[string]any{}
|
||||||
|
|
||||||
|
err = json.Unmarshal(recorder.Body.Bytes(), &resJson)
|
||||||
|
assert.NilError(t, err)
|
||||||
|
|
||||||
|
redirect_uri, ok := resJson["redirect_uri"].(string)
|
||||||
|
assert.Assert(t, ok)
|
||||||
|
|
||||||
|
u, err := url.Parse(redirect_uri)
|
||||||
|
assert.NilError(t, err)
|
||||||
|
|
||||||
|
m, err := url.ParseQuery(u.RawQuery)
|
||||||
|
assert.NilError(t, err)
|
||||||
|
assert.Equal(t, m["state"][0], "some-state")
|
||||||
|
|
||||||
|
code := m["code"][0]
|
||||||
|
|
||||||
|
// Exchange code for token
|
||||||
|
recorder = httptest.NewRecorder()
|
||||||
|
|
||||||
|
params, err := query.Values(controller.TokenRequest{
|
||||||
|
GrantType: "authorization_code",
|
||||||
|
Code: code,
|
||||||
|
RedirectURI: "https://example.com/oauth/callback",
|
||||||
|
})
|
||||||
|
|
||||||
|
assert.NilError(t, err)
|
||||||
|
|
||||||
|
req, err = http.NewRequest("POST", "/api/oidc/token", strings.NewReader(params.Encode()))
|
||||||
|
|
||||||
|
assert.NilError(t, err)
|
||||||
|
|
||||||
|
req.Header.Set("content-type", "application/x-www-form-urlencoded")
|
||||||
|
req.SetBasicAuth("some-client-id", "some-client-secret")
|
||||||
|
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||||
|
|
||||||
|
resJson = map[string]any{}
|
||||||
|
|
||||||
|
err = json.Unmarshal(recorder.Body.Bytes(), &resJson)
|
||||||
|
assert.NilError(t, err)
|
||||||
|
|
||||||
|
accessToken, ok := resJson["access_token"].(string)
|
||||||
|
assert.Assert(t, ok)
|
||||||
|
|
||||||
|
_, ok = resJson["id_token"].(string)
|
||||||
|
assert.Assert(t, ok)
|
||||||
|
|
||||||
|
refreshToken, ok := resJson["refresh_token"].(string)
|
||||||
|
assert.Assert(t, ok)
|
||||||
|
|
||||||
|
expires_in, ok := resJson["expires_in"].(float64)
|
||||||
|
assert.Assert(t, ok)
|
||||||
|
assert.Equal(t, expires_in, float64(oidcServiceConfig.SessionExpiry))
|
||||||
|
|
||||||
|
// Ensure code is expired
|
||||||
|
recorder = httptest.NewRecorder()
|
||||||
|
|
||||||
|
params, err = query.Values(controller.TokenRequest{
|
||||||
|
GrantType: "authorization_code",
|
||||||
|
Code: code,
|
||||||
|
RedirectURI: "https://example.com/oauth/callback",
|
||||||
|
})
|
||||||
|
|
||||||
|
assert.NilError(t, err)
|
||||||
|
|
||||||
|
req, err = http.NewRequest("POST", "/api/oidc/token", strings.NewReader(params.Encode()))
|
||||||
|
|
||||||
|
assert.NilError(t, err)
|
||||||
|
|
||||||
|
req.Header.Set("content-type", "application/x-www-form-urlencoded")
|
||||||
|
req.SetBasicAuth("some-client-id", "some-client-secret")
|
||||||
|
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
assert.Equal(t, http.StatusBadRequest, recorder.Code)
|
||||||
|
|
||||||
|
// Test userinfo
|
||||||
|
recorder = httptest.NewRecorder()
|
||||||
|
|
||||||
|
req, err = http.NewRequest("GET", "/api/oidc/userinfo", nil)
|
||||||
|
assert.NilError(t, err)
|
||||||
|
|
||||||
|
req.Header.Set("authorization", fmt.Sprintf("Bearer %s", accessToken))
|
||||||
|
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||||
|
|
||||||
|
resJson = map[string]any{}
|
||||||
|
|
||||||
|
err = json.Unmarshal(recorder.Body.Bytes(), &resJson)
|
||||||
|
assert.NilError(t, err)
|
||||||
|
|
||||||
|
_, ok = resJson["sub"].(string)
|
||||||
|
assert.Assert(t, ok)
|
||||||
|
|
||||||
|
name, ok := resJson["name"].(string)
|
||||||
|
assert.Assert(t, ok)
|
||||||
|
assert.Equal(t, name, oidcCtrlTestContext.Name)
|
||||||
|
|
||||||
|
email, ok := resJson["email"].(string)
|
||||||
|
assert.Assert(t, ok)
|
||||||
|
assert.Equal(t, email, oidcCtrlTestContext.Email)
|
||||||
|
|
||||||
|
preferred_username, ok := resJson["preferred_username"].(string)
|
||||||
|
assert.Assert(t, ok)
|
||||||
|
assert.Equal(t, preferred_username, oidcCtrlTestContext.Username)
|
||||||
|
|
||||||
|
// Not sure why this is failing, will look into it later
|
||||||
|
igroups, ok := resJson["groups"].([]any)
|
||||||
|
assert.Assert(t, ok)
|
||||||
|
|
||||||
|
groups := make([]string, len(igroups))
|
||||||
|
for i, group := range igroups {
|
||||||
|
groups[i], ok = group.(string)
|
||||||
|
assert.Assert(t, ok)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
assert.DeepEqual(t, strings.Split(oidcCtrlTestContext.LdapGroups, ","), groups)
|
||||||
|
|
||||||
|
// Test refresh token
|
||||||
|
recorder = httptest.NewRecorder()
|
||||||
|
|
||||||
|
params, err = query.Values(controller.TokenRequest{
|
||||||
|
GrantType: "refresh_token",
|
||||||
|
RefreshToken: refreshToken,
|
||||||
|
})
|
||||||
|
|
||||||
|
assert.NilError(t, err)
|
||||||
|
|
||||||
|
req, err = http.NewRequest("POST", "/api/oidc/token", strings.NewReader(params.Encode()))
|
||||||
|
|
||||||
|
assert.NilError(t, err)
|
||||||
|
|
||||||
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
req.SetBasicAuth("some-client-id", "some-client-secret")
|
||||||
|
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||||
|
|
||||||
|
resJson = map[string]any{}
|
||||||
|
|
||||||
|
err = json.Unmarshal(recorder.Body.Bytes(), &resJson)
|
||||||
|
|
||||||
|
assert.NilError(t, err)
|
||||||
|
|
||||||
|
newToken, ok := resJson["access_token"].(string)
|
||||||
|
assert.Assert(t, ok)
|
||||||
|
assert.Assert(t, newToken != accessToken)
|
||||||
|
|
||||||
|
// Ensure old token is invalid
|
||||||
|
recorder = httptest.NewRecorder()
|
||||||
|
req, err = http.NewRequest("GET", "/api/oidc/userinfo", nil)
|
||||||
|
|
||||||
|
assert.NilError(t, err)
|
||||||
|
|
||||||
|
req.Header.Set("authorization", fmt.Sprintf("Bearer %s", accessToken))
|
||||||
|
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
assert.Equal(t, http.StatusUnauthorized, recorder.Code)
|
||||||
|
|
||||||
|
// Test new token
|
||||||
|
recorder = httptest.NewRecorder()
|
||||||
|
req, err = http.NewRequest("GET", "/api/oidc/userinfo", nil)
|
||||||
|
|
||||||
|
assert.NilError(t, err)
|
||||||
|
|
||||||
|
req.Header.Set("authorization", fmt.Sprintf("Bearer %s", newToken))
|
||||||
|
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -21,8 +21,11 @@ import (
|
|||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// hard-defaults, may make configurable in the future if needed,
|
||||||
|
// but for now these are just safety limits to prevent unbounded memory usage
|
||||||
const MaxOAuthPendingSessions = 256
|
const MaxOAuthPendingSessions = 256
|
||||||
const OAuthCleanupCount = 16
|
const OAuthCleanupCount = 16
|
||||||
|
const MaxLoginAttemptRecords = 256
|
||||||
|
|
||||||
type OAuthPendingSession struct {
|
type OAuthPendingSession struct {
|
||||||
State string
|
State string
|
||||||
@@ -43,6 +46,11 @@ type LoginAttempt struct {
|
|||||||
LockedUntil time.Time
|
LockedUntil time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type Lockdown struct {
|
||||||
|
Active bool
|
||||||
|
ActiveUntil time.Time
|
||||||
|
}
|
||||||
|
|
||||||
type AuthServiceConfig struct {
|
type AuthServiceConfig struct {
|
||||||
Users []config.User
|
Users []config.User
|
||||||
OauthWhitelist []string
|
OauthWhitelist []string
|
||||||
@@ -69,6 +77,7 @@ type AuthService struct {
|
|||||||
ldap *LdapService
|
ldap *LdapService
|
||||||
queries *repository.Queries
|
queries *repository.Queries
|
||||||
oauthBroker *OAuthBrokerService
|
oauthBroker *OAuthBrokerService
|
||||||
|
lockdown *Lockdown
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewAuthService(config AuthServiceConfig, docker *DockerService, ldap *LdapService, queries *repository.Queries, oauthBroker *OAuthBrokerService) *AuthService {
|
func NewAuthService(config AuthServiceConfig, docker *DockerService, ldap *LdapService, queries *repository.Queries, oauthBroker *OAuthBrokerService) *AuthService {
|
||||||
@@ -202,6 +211,11 @@ func (auth *AuthService) IsAccountLocked(identifier string) (bool, int) {
|
|||||||
auth.loginMutex.RLock()
|
auth.loginMutex.RLock()
|
||||||
defer auth.loginMutex.RUnlock()
|
defer auth.loginMutex.RUnlock()
|
||||||
|
|
||||||
|
if auth.lockdown != nil && auth.lockdown.Active {
|
||||||
|
remaining := int(time.Until(auth.lockdown.ActiveUntil).Seconds())
|
||||||
|
return true, remaining
|
||||||
|
}
|
||||||
|
|
||||||
if auth.config.LoginMaxRetries <= 0 || auth.config.LoginTimeout <= 0 {
|
if auth.config.LoginMaxRetries <= 0 || auth.config.LoginTimeout <= 0 {
|
||||||
return false, 0
|
return false, 0
|
||||||
}
|
}
|
||||||
@@ -227,6 +241,14 @@ func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) {
|
|||||||
auth.loginMutex.Lock()
|
auth.loginMutex.Lock()
|
||||||
defer auth.loginMutex.Unlock()
|
defer auth.loginMutex.Unlock()
|
||||||
|
|
||||||
|
if len(auth.loginAttempts) >= MaxLoginAttemptRecords {
|
||||||
|
if auth.lockdown != nil && auth.lockdown.Active {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
go auth.lockdownMode()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
attempt, exists := auth.loginAttempts[identifier]
|
attempt, exists := auth.loginAttempts[identifier]
|
||||||
if !exists {
|
if !exists {
|
||||||
attempt = &LoginAttempt{}
|
attempt = &LoginAttempt{}
|
||||||
@@ -746,3 +768,31 @@ func (auth *AuthService) ensureOAuthSessionLimit() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (auth *AuthService) lockdownMode() {
|
||||||
|
auth.loginMutex.Lock()
|
||||||
|
|
||||||
|
tlog.App.Warn().Msg("Multiple login attempts detected, possibly DDOS attack. Activating temporary lockdown.")
|
||||||
|
|
||||||
|
auth.lockdown = &Lockdown{
|
||||||
|
Active: true,
|
||||||
|
ActiveUntil: time.Now().Add(time.Duration(auth.config.LoginTimeout) * time.Second),
|
||||||
|
}
|
||||||
|
|
||||||
|
// At this point all login attemps will also expire so,
|
||||||
|
// we might as well clear them to free up memory
|
||||||
|
auth.loginAttempts = make(map[string]*LoginAttempt)
|
||||||
|
|
||||||
|
timer := time.NewTimer(time.Until(auth.lockdown.ActiveUntil))
|
||||||
|
defer timer.Stop()
|
||||||
|
|
||||||
|
auth.loginMutex.Unlock()
|
||||||
|
|
||||||
|
<-timer.C
|
||||||
|
|
||||||
|
auth.loginMutex.Lock()
|
||||||
|
|
||||||
|
tlog.App.Info().Msg("Lockdown period ended, resuming normal operation")
|
||||||
|
auth.lockdown = nil
|
||||||
|
auth.loginMutex.Unlock()
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user