mirror of
https://github.com/steveiliop56/tinyauth.git
synced 2026-04-18 19:56:08 +00:00
Compare commits
5 Commits
b2ab3d0f37
...
c7bb6d61af
| Author | SHA1 | Date | |
|---|---|---|---|
| c7bb6d61af | |||
| fa25740546 | |||
| d3cda06a75 | |||
| 23e0da96a6 | |||
| 39beed706b |
@@ -4,6 +4,7 @@ import (
|
||||
"encoding/json"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
@@ -431,18 +432,12 @@ func TestOIDCController(t *testing.T) {
|
||||
app := bootstrap.NewBootstrapApp(config.Config{})
|
||||
|
||||
db, err := app.SetupDatabase("/tmp/tinyauth_test.db")
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to set up database: %v", err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
|
||||
queries := repository.New(db)
|
||||
oidcService := service.NewOIDCService(oidcServiceCfg, queries)
|
||||
err = oidcService.Init()
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to initialize OIDC service: %v", err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.description, func(t *testing.T) {
|
||||
@@ -463,4 +458,16 @@ func TestOIDCController(t *testing.T) {
|
||||
test.run(t, router, recorder)
|
||||
})
|
||||
}
|
||||
|
||||
err = db.Close()
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = os.Remove("/tmp/tinyauth_test.db")
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = os.Remove(oidcServiceCfg.PrivateKeyPath)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = os.Remove(oidcServiceCfg.PublicKeyPath)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
@@ -5,55 +5,84 @@ import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/steveiliop56/tinyauth/internal/controller"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gotest.tools/v3/assert"
|
||||
"github.com/steveiliop56/tinyauth/internal/controller"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestResourcesHandler(t *testing.T) {
|
||||
// Setup
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
group := router.Group("/")
|
||||
|
||||
ctrl := controller.NewResourcesController(controller.ResourcesControllerConfig{
|
||||
Path: "/tmp/tinyauth",
|
||||
func TestResourcesController(t *testing.T) {
|
||||
resourcesControllerCfg := controller.ResourcesControllerConfig{
|
||||
Path: "/tmp/testfiles",
|
||||
Enabled: true,
|
||||
}, group)
|
||||
ctrl.SetupRoutes()
|
||||
}
|
||||
|
||||
// Create test data
|
||||
err := os.Mkdir("/tmp/tinyauth", 0755)
|
||||
assert.NilError(t, err)
|
||||
defer os.RemoveAll("/tmp/tinyauth")
|
||||
type testCase struct {
|
||||
description string
|
||||
run func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder)
|
||||
}
|
||||
|
||||
file, err := os.Create("/tmp/tinyauth/test.txt")
|
||||
assert.NilError(t, err)
|
||||
|
||||
_, err = file.WriteString("This is a test file.")
|
||||
assert.NilError(t, err)
|
||||
file.Close()
|
||||
|
||||
// Test existing file
|
||||
req := httptest.NewRequest("GET", "/resources/test.txt", nil)
|
||||
recorder := httptest.NewRecorder()
|
||||
tests := []testCase{
|
||||
{
|
||||
description: "Ensure resources endpoint returns 200 OK for existing file",
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||
req := httptest.NewRequest("GET", "/resources/testfile.txt", nil)
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, 200, recorder.Code)
|
||||
assert.Equal(t, "This is a test file.", recorder.Body.String())
|
||||
|
||||
// Test non-existing file
|
||||
req = httptest.NewRequest("GET", "/resources/nonexistent.txt", nil)
|
||||
recorder = httptest.NewRecorder()
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Ensure resources endpoint returns 404 Not Found for non-existing file",
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||
req := httptest.NewRequest("GET", "/resources/nonexistent.txt", nil)
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, 404, recorder.Code)
|
||||
|
||||
// Test directory traversal attack
|
||||
req = httptest.NewRequest("GET", "/resources/../etc/passwd", nil)
|
||||
recorder = httptest.NewRecorder()
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Ensure resources controller denies path traversal",
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||
req := httptest.NewRequest("GET", "/resources/../somefile.txt", nil)
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, 404, recorder.Code)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
err := os.MkdirAll(resourcesControllerCfg.Path, 0777)
|
||||
assert.NoError(t, err)
|
||||
|
||||
testFilePath := resourcesControllerCfg.Path + "/testfile.txt"
|
||||
err = os.WriteFile(testFilePath, []byte("This is a test file."), 0777)
|
||||
assert.NoError(t, err)
|
||||
|
||||
testFilePathParent := resourcesControllerCfg.Path + "/../somefile.txt"
|
||||
err = os.WriteFile(testFilePathParent, []byte("This file should not be accessible."), 0777)
|
||||
assert.NoError(t, err)
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.description, func(t *testing.T) {
|
||||
router := gin.Default()
|
||||
group := router.Group("/")
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
resourcesController := controller.NewResourcesController(resourcesControllerCfg, group)
|
||||
resourcesController.SetupRoutes()
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
test.run(t, router, recorder)
|
||||
})
|
||||
}
|
||||
|
||||
err = os.Remove(testFilePath)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = os.Remove(testFilePathParent)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = os.Remove(resourcesControllerCfg.Path)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
@@ -2,305 +2,353 @@ package controller_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/pquerna/otp/totp"
|
||||
"github.com/steveiliop56/tinyauth/internal/bootstrap"
|
||||
"github.com/steveiliop56/tinyauth/internal/config"
|
||||
"github.com/steveiliop56/tinyauth/internal/controller"
|
||||
"github.com/steveiliop56/tinyauth/internal/repository"
|
||||
"github.com/steveiliop56/tinyauth/internal/service"
|
||||
"github.com/steveiliop56/tinyauth/internal/utils/tlog"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/pquerna/otp/totp"
|
||||
"gotest.tools/v3/assert"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
var cookieValue string
|
||||
var totpSecret = "6WFZXPEZRK5MZHHYAFW4DAOUYQMCASBJ"
|
||||
|
||||
func setupUserController(t *testing.T, middlewares *[]gin.HandlerFunc) (*gin.Engine, *httptest.ResponseRecorder) {
|
||||
tlog.NewSimpleLogger().Init()
|
||||
|
||||
// Setup
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.Default()
|
||||
|
||||
if middlewares != nil {
|
||||
for _, m := range *middlewares {
|
||||
router.Use(m)
|
||||
}
|
||||
}
|
||||
|
||||
group := router.Group("/api")
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
// Mock app
|
||||
app := bootstrap.NewBootstrapApp(config.Config{})
|
||||
|
||||
// Database
|
||||
db, err := app.SetupDatabase(":memory:")
|
||||
|
||||
assert.NilError(t, err)
|
||||
|
||||
// Queries
|
||||
queries := repository.New(db)
|
||||
|
||||
// Auth service
|
||||
authService := service.NewAuthService(service.AuthServiceConfig{
|
||||
func TestUserController(t *testing.T) {
|
||||
authServiceCfg := service.AuthServiceConfig{
|
||||
Users: []config.User{
|
||||
{
|
||||
Username: "testuser",
|
||||
Password: "$2a$10$ne6z693sTgzT3ePoQ05PgOecUHnBjM7sSNj6M.l5CLUP.f6NyCnt.", // test
|
||||
Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa",
|
||||
},
|
||||
{
|
||||
Username: "totpuser",
|
||||
Password: "$2a$10$ne6z693sTgzT3ePoQ05PgOecUHnBjM7sSNj6M.l5CLUP.f6NyCnt.", // test
|
||||
TotpSecret: totpSecret,
|
||||
Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa",
|
||||
TotpSecret: "JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK",
|
||||
},
|
||||
},
|
||||
OauthWhitelist: []string{},
|
||||
SessionExpiry: 3600,
|
||||
SessionMaxLifetime: 0,
|
||||
SecureCookie: false,
|
||||
CookieDomain: "localhost",
|
||||
LoginTimeout: 300,
|
||||
SessionExpiry: 10, // 10 seconds, useful for testing
|
||||
CookieDomain: "example.com",
|
||||
LoginTimeout: 10, // 10 seconds, useful for testing
|
||||
LoginMaxRetries: 3,
|
||||
SessionCookieName: "tinyauth-session",
|
||||
}, nil, nil, queries, &service.OAuthBrokerService{})
|
||||
|
||||
// Controller
|
||||
ctrl := controller.NewUserController(controller.UserControllerConfig{
|
||||
CookieDomain: "localhost",
|
||||
}, group, authService)
|
||||
ctrl.SetupRoutes()
|
||||
|
||||
return router, recorder
|
||||
}
|
||||
|
||||
func TestLoginHandler(t *testing.T) {
|
||||
// Setup
|
||||
router, recorder := setupUserController(t, nil)
|
||||
userControllerCfg := controller.UserControllerConfig{
|
||||
CookieDomain: "example.com",
|
||||
}
|
||||
|
||||
type testCase struct {
|
||||
description string
|
||||
middlewares []gin.HandlerFunc
|
||||
run func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder)
|
||||
}
|
||||
|
||||
tests := []testCase{
|
||||
{
|
||||
description: "Should be able to login with valid credentials",
|
||||
middlewares: []gin.HandlerFunc{},
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||
loginReq := controller.LoginRequest{
|
||||
Username: "testuser",
|
||||
Password: "test",
|
||||
Password: "password",
|
||||
}
|
||||
loginReqBody, err := json.Marshal(loginReq)
|
||||
assert.NoError(t, err)
|
||||
|
||||
loginReqJson, err := json.Marshal(loginReq)
|
||||
assert.NilError(t, err)
|
||||
req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqBody)))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
// Test
|
||||
req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqJson)))
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, 200, recorder.Code)
|
||||
assert.Len(t, recorder.Result().Cookies(), 1)
|
||||
|
||||
cookie := recorder.Result().Cookies()[0]
|
||||
|
||||
assert.Equal(t, "tinyauth-session", cookie.Name)
|
||||
assert.Assert(t, cookie.Value != "")
|
||||
|
||||
cookieValue = cookie.Value
|
||||
|
||||
// Test invalid credentials
|
||||
loginReq = controller.LoginRequest{
|
||||
assert.True(t, cookie.HttpOnly)
|
||||
assert.Equal(t, "example.com", cookie.Domain)
|
||||
assert.Equal(t, cookie.MaxAge, 10)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Should reject login with invalid credentials",
|
||||
middlewares: []gin.HandlerFunc{},
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||
loginReq := controller.LoginRequest{
|
||||
Username: "testuser",
|
||||
Password: "invalid",
|
||||
Password: "wrongpassword",
|
||||
}
|
||||
loginReqBody, err := json.Marshal(loginReq)
|
||||
assert.NoError(t, err)
|
||||
|
||||
loginReqJson, err = json.Marshal(loginReq)
|
||||
assert.NilError(t, err)
|
||||
req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqBody)))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
recorder = httptest.NewRecorder()
|
||||
req = httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqJson)))
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, 401, recorder.Code)
|
||||
|
||||
// Test totp required
|
||||
loginReq = controller.LoginRequest{
|
||||
Username: "totpuser",
|
||||
Password: "test",
|
||||
}
|
||||
|
||||
loginReqJson, err = json.Marshal(loginReq)
|
||||
assert.NilError(t, err)
|
||||
|
||||
recorder = httptest.NewRecorder()
|
||||
req = httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqJson)))
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, 200, recorder.Code)
|
||||
|
||||
loginResJson, err := json.Marshal(map[string]any{
|
||||
"message": "TOTP required",
|
||||
"status": 200,
|
||||
"totpPending": true,
|
||||
})
|
||||
|
||||
assert.NilError(t, err)
|
||||
assert.Equal(t, string(loginResJson), recorder.Body.String())
|
||||
|
||||
// Test invalid json
|
||||
recorder = httptest.NewRecorder()
|
||||
req = httptest.NewRequest("POST", "/api/user/login", strings.NewReader("{invalid json}"))
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, 400, recorder.Code)
|
||||
|
||||
// Test rate limiting
|
||||
loginReq = controller.LoginRequest{
|
||||
assert.Len(t, recorder.Result().Cookies(), 0)
|
||||
assert.Contains(t, recorder.Body.String(), "Unauthorized")
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Should rate limit on 3 invalid attempts",
|
||||
middlewares: []gin.HandlerFunc{},
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||
loginReq := controller.LoginRequest{
|
||||
Username: "testuser",
|
||||
Password: "invalid",
|
||||
Password: "wrongpassword",
|
||||
}
|
||||
loginReqBody, err := json.Marshal(loginReq)
|
||||
assert.NoError(t, err)
|
||||
|
||||
loginReqJson, err = json.Marshal(loginReq)
|
||||
assert.NilError(t, err)
|
||||
for range 3 {
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqBody)))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
for range 5 {
|
||||
recorder = httptest.NewRecorder()
|
||||
req = httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqJson)))
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, 401, recorder.Code)
|
||||
assert.Len(t, recorder.Result().Cookies(), 0)
|
||||
assert.Contains(t, recorder.Body.String(), "Unauthorized")
|
||||
}
|
||||
|
||||
// 4th attempt should be rate limited
|
||||
recorder = httptest.NewRecorder()
|
||||
req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqBody)))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, 429, recorder.Code)
|
||||
assert.Contains(t, recorder.Body.String(), "Too many failed login attempts.")
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Should not allow full login with totp",
|
||||
middlewares: []gin.HandlerFunc{},
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||
loginReq := controller.LoginRequest{
|
||||
Username: "totpuser",
|
||||
Password: "password",
|
||||
}
|
||||
loginReqBody, err := json.Marshal(loginReq)
|
||||
assert.NoError(t, err)
|
||||
|
||||
func TestLogoutHandler(t *testing.T) {
|
||||
// Setup
|
||||
router, recorder := setupUserController(t, nil)
|
||||
|
||||
// Test
|
||||
req := httptest.NewRequest("POST", "/api/user/logout", nil)
|
||||
|
||||
req.AddCookie(&http.Cookie{
|
||||
Name: "tinyauth-session",
|
||||
Value: cookieValue,
|
||||
})
|
||||
req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqBody)))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, 200, recorder.Code)
|
||||
|
||||
decodedBody := make(map[string]any)
|
||||
err = json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, decodedBody["totpPending"], true)
|
||||
|
||||
// should set the session cookie
|
||||
assert.Len(t, recorder.Result().Cookies(), 1)
|
||||
cookie := recorder.Result().Cookies()[0]
|
||||
|
||||
assert.Equal(t, "tinyauth-session", cookie.Name)
|
||||
assert.Equal(t, "", cookie.Value)
|
||||
assert.Equal(t, -1, cookie.MaxAge)
|
||||
}
|
||||
|
||||
func TestTotpHandler(t *testing.T) {
|
||||
// Setup
|
||||
router, recorder := setupUserController(t, &[]gin.HandlerFunc{
|
||||
func(c *gin.Context) {
|
||||
c.Set("context", &config.UserContext{
|
||||
Username: "totpuser",
|
||||
Name: "totpuser",
|
||||
Email: "totpuser@example.com",
|
||||
IsLoggedIn: false,
|
||||
OAuth: false,
|
||||
Provider: "local",
|
||||
TotpPending: true,
|
||||
OAuthGroups: "",
|
||||
TotpEnabled: true,
|
||||
})
|
||||
c.Next()
|
||||
assert.True(t, cookie.HttpOnly)
|
||||
assert.Equal(t, "example.com", cookie.Domain)
|
||||
assert.Equal(t, cookie.MaxAge, 3600) // 1 hour, default for totp pending sessions
|
||||
},
|
||||
})
|
||||
},
|
||||
{
|
||||
description: "Should be able to logout",
|
||||
middlewares: []gin.HandlerFunc{},
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||
// First login to get a session cookie
|
||||
loginReq := controller.LoginRequest{
|
||||
Username: "testuser",
|
||||
Password: "password",
|
||||
}
|
||||
loginReqBody, err := json.Marshal(loginReq)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Test
|
||||
code, err := totp.GenerateCode(totpSecret, time.Now())
|
||||
req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqBody)))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
assert.NilError(t, err)
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, 200, recorder.Code)
|
||||
assert.Len(t, recorder.Result().Cookies(), 1)
|
||||
|
||||
cookie := recorder.Result().Cookies()[0]
|
||||
assert.Equal(t, "tinyauth-session", cookie.Name)
|
||||
|
||||
// Now logout using the session cookie
|
||||
recorder = httptest.NewRecorder()
|
||||
req = httptest.NewRequest("POST", "/api/user/logout", nil)
|
||||
req.AddCookie(cookie)
|
||||
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, 200, recorder.Code)
|
||||
assert.Len(t, recorder.Result().Cookies(), 1)
|
||||
|
||||
logoutCookie := recorder.Result().Cookies()[0]
|
||||
assert.Equal(t, "tinyauth-session", logoutCookie.Name)
|
||||
assert.Equal(t, "", logoutCookie.Value)
|
||||
assert.Equal(t, -1, logoutCookie.MaxAge) // MaxAge -1 means delete cookie
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Should be able to login with totp",
|
||||
middlewares: []gin.HandlerFunc{},
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||
code, err := totp.GenerateCode("JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", time.Now())
|
||||
assert.NoError(t, err)
|
||||
|
||||
totpReq := controller.TotpRequest{
|
||||
Code: code,
|
||||
}
|
||||
|
||||
totpReqJson, err := json.Marshal(totpReq)
|
||||
assert.NilError(t, err)
|
||||
totpReqBody, err := json.Marshal(totpReq)
|
||||
assert.NoError(t, err)
|
||||
|
||||
recorder = httptest.NewRecorder()
|
||||
req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(totpReqBody)))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(totpReqJson)))
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, 200, recorder.Code)
|
||||
assert.Len(t, recorder.Result().Cookies(), 1)
|
||||
|
||||
cookie := recorder.Result().Cookies()[0]
|
||||
|
||||
assert.Equal(t, "tinyauth-session", cookie.Name)
|
||||
assert.Assert(t, cookie.Value != "")
|
||||
|
||||
// Test invalid json
|
||||
recorder = httptest.NewRecorder()
|
||||
req = httptest.NewRequest("POST", "/api/user/totp", strings.NewReader("{invalid json}"))
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, 400, recorder.Code)
|
||||
|
||||
// Test rate limiting
|
||||
totpReq = controller.TotpRequest{
|
||||
Code: "000000",
|
||||
// should set a new session cookie with totp pending removed
|
||||
totpCookie := recorder.Result().Cookies()[0]
|
||||
assert.Equal(t, "tinyauth-session", totpCookie.Name)
|
||||
assert.True(t, totpCookie.HttpOnly)
|
||||
assert.Equal(t, "example.com", totpCookie.Domain)
|
||||
assert.Equal(t, totpCookie.MaxAge, 10) // should use the regular session expiry time
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Totp should rate limit on multiple invalid attempts",
|
||||
middlewares: []gin.HandlerFunc{},
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||
for range 3 {
|
||||
totpReq := controller.TotpRequest{
|
||||
Code: "000000", // invalid code
|
||||
}
|
||||
|
||||
totpReqJson, err = json.Marshal(totpReq)
|
||||
assert.NilError(t, err)
|
||||
totpReqBody, err := json.Marshal(totpReq)
|
||||
assert.NoError(t, err)
|
||||
|
||||
for range 5 {
|
||||
recorder = httptest.NewRecorder()
|
||||
req = httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(totpReqJson)))
|
||||
req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(totpReqBody)))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, 401, recorder.Code)
|
||||
assert.Contains(t, recorder.Body.String(), "Unauthorized")
|
||||
}
|
||||
|
||||
// 4th attempt should be rate limited
|
||||
recorder = httptest.NewRecorder()
|
||||
req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(`{"code":"000000"}`)))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, 429, recorder.Code)
|
||||
assert.Contains(t, recorder.Body.String(), "Too many failed TOTP attempts.")
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Test invalid code
|
||||
router, recorder = setupUserController(t, &[]gin.HandlerFunc{
|
||||
func(c *gin.Context) {
|
||||
tlog.NewSimpleLogger().Init()
|
||||
|
||||
oauthBrokerCfgs := make(map[string]config.OAuthServiceConfig)
|
||||
|
||||
app := bootstrap.NewBootstrapApp(config.Config{})
|
||||
|
||||
db, err := app.SetupDatabase("/tmp/tinyauth_test.db")
|
||||
assert.NoError(t, err)
|
||||
|
||||
queries := repository.New(db)
|
||||
|
||||
docker := service.NewDockerService()
|
||||
err = docker.Init()
|
||||
assert.NoError(t, err)
|
||||
|
||||
ldap := service.NewLdapService(service.LdapServiceConfig{})
|
||||
err = ldap.Init()
|
||||
assert.NoError(t, err)
|
||||
|
||||
broker := service.NewOAuthBrokerService(oauthBrokerCfgs)
|
||||
err = broker.Init()
|
||||
assert.NoError(t, err)
|
||||
|
||||
authService := service.NewAuthService(authServiceCfg, docker, ldap, queries, broker)
|
||||
err = authService.Init()
|
||||
assert.NoError(t, err)
|
||||
|
||||
beforeEach := func() {
|
||||
// Clear failed login attempts before each test
|
||||
authService.ClearRateLimitsTestingOnly()
|
||||
}
|
||||
|
||||
setTotpMiddlewareOverrides := []string{
|
||||
"Should be able to login with totp",
|
||||
"Totp should rate limit on multiple invalid attempts",
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
beforeEach()
|
||||
t.Run(test.description, func(t *testing.T) {
|
||||
router := gin.Default()
|
||||
|
||||
for _, middleware := range test.middlewares {
|
||||
router.Use(middleware)
|
||||
}
|
||||
|
||||
// Gin is stupid and doesn't allow setting a middleware after the groups
|
||||
// so we need to do some stupid overrides here
|
||||
if slices.Contains(setTotpMiddlewareOverrides, test.description) {
|
||||
// Assuming the cookie is set, it should be picked up by the
|
||||
// context middleware
|
||||
router.Use(func(c *gin.Context) {
|
||||
c.Set("context", &config.UserContext{
|
||||
Username: "totpuser",
|
||||
Name: "totpuser",
|
||||
Name: "Totpuser",
|
||||
Email: "totpuser@example.com",
|
||||
IsLoggedIn: false,
|
||||
OAuth: false,
|
||||
Provider: "local",
|
||||
TotpPending: true,
|
||||
OAuthGroups: "",
|
||||
TotpEnabled: true,
|
||||
})
|
||||
c.Next()
|
||||
},
|
||||
})
|
||||
|
||||
req = httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(totpReqJson)))
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, 401, recorder.Code)
|
||||
|
||||
// Test no totp pending
|
||||
router, recorder = setupUserController(t, &[]gin.HandlerFunc{
|
||||
func(c *gin.Context) {
|
||||
c.Set("context", &config.UserContext{
|
||||
Username: "totpuser",
|
||||
Name: "totpuser",
|
||||
Email: "totpuser@example.com",
|
||||
IsLoggedIn: false,
|
||||
OAuth: false,
|
||||
Provider: "local",
|
||||
TotpPending: false,
|
||||
OAuthGroups: "",
|
||||
TotpEnabled: false,
|
||||
})
|
||||
c.Next()
|
||||
},
|
||||
})
|
||||
|
||||
req = httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(totpReqJson)))
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, 401, recorder.Code)
|
||||
}
|
||||
|
||||
group := router.Group("/api")
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
userController := controller.NewUserController(userControllerCfg, group, authService)
|
||||
userController.SetupRoutes()
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
test.run(t, router, recorder)
|
||||
})
|
||||
}
|
||||
|
||||
err = db.Close()
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = os.Remove("/tmp/tinyauth_test.db")
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,127 @@
|
||||
package controller_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/steveiliop56/tinyauth/internal/bootstrap"
|
||||
"github.com/steveiliop56/tinyauth/internal/config"
|
||||
"github.com/steveiliop56/tinyauth/internal/controller"
|
||||
"github.com/steveiliop56/tinyauth/internal/repository"
|
||||
"github.com/steveiliop56/tinyauth/internal/service"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestWellKnownController(t *testing.T) {
|
||||
oidcServiceCfg := service.OIDCServiceConfig{
|
||||
Clients: map[string]config.OIDCClientConfig{
|
||||
"test": {
|
||||
ClientID: "some-client-id",
|
||||
ClientSecret: "some-client-secret",
|
||||
TrustedRedirectURIs: []string{"https://test.example.com/callback"},
|
||||
Name: "Test Client",
|
||||
},
|
||||
},
|
||||
PrivateKeyPath: "/tmp/tinyauth_testing_key.pem",
|
||||
PublicKeyPath: "/tmp/tinyauth_testing_key.pub",
|
||||
Issuer: "https://tinyauth.example.com",
|
||||
SessionExpiry: 500,
|
||||
}
|
||||
|
||||
type testCase struct {
|
||||
description string
|
||||
run func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder)
|
||||
}
|
||||
|
||||
tests := []testCase{
|
||||
{
|
||||
description: "Ensure well-known endpoint returns correct OIDC configuration",
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||
req := httptest.NewRequest("GET", "/.well-known/openid-configuration", nil)
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, 200, recorder.Code)
|
||||
|
||||
res := controller.OpenIDConnectConfiguration{}
|
||||
err := json.Unmarshal(recorder.Body.Bytes(), &res)
|
||||
assert.NoError(t, err)
|
||||
|
||||
expected := controller.OpenIDConnectConfiguration{
|
||||
Issuer: oidcServiceCfg.Issuer,
|
||||
AuthorizationEndpoint: fmt.Sprintf("%s/authorize", oidcServiceCfg.Issuer),
|
||||
TokenEndpoint: fmt.Sprintf("%s/api/oidc/token", oidcServiceCfg.Issuer),
|
||||
UserinfoEndpoint: fmt.Sprintf("%s/api/oidc/userinfo", oidcServiceCfg.Issuer),
|
||||
JwksUri: fmt.Sprintf("%s/.well-known/jwks.json", oidcServiceCfg.Issuer),
|
||||
ScopesSupported: service.SupportedScopes,
|
||||
ResponseTypesSupported: service.SupportedResponseTypes,
|
||||
GrantTypesSupported: service.SupportedGrantTypes,
|
||||
SubjectTypesSupported: []string{"pairwise"},
|
||||
IDTokenSigningAlgValuesSupported: []string{"RS256"},
|
||||
TokenEndpointAuthMethodsSupported: []string{"client_secret_basic", "client_secret_post"},
|
||||
ClaimsSupported: []string{"sub", "updated_at", "name", "preferred_username", "email", "email_verified", "groups"},
|
||||
ServiceDocumentation: "https://tinyauth.app/docs/guides/oidc",
|
||||
}
|
||||
|
||||
assert.Equal(t, res, expected)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Ensure well-known endpoint returns correct JWKS",
|
||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||
req := httptest.NewRequest("GET", "/.well-known/jwks.json", nil)
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, 200, recorder.Code)
|
||||
|
||||
decodedBody := make(map[string]any)
|
||||
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
|
||||
assert.NoError(t, err)
|
||||
|
||||
keys, ok := decodedBody["keys"].([]any)
|
||||
assert.True(t, ok)
|
||||
assert.Len(t, keys, 1)
|
||||
|
||||
keyData, ok := keys[0].(map[string]any)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "RSA", keyData["kty"])
|
||||
assert.Equal(t, "sig", keyData["use"])
|
||||
assert.Equal(t, "RS256", keyData["alg"])
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
app := bootstrap.NewBootstrapApp(config.Config{})
|
||||
|
||||
db, err := app.SetupDatabase("/tmp/tinyauth_test.db")
|
||||
assert.NoError(t, err)
|
||||
|
||||
queries := repository.New(db)
|
||||
|
||||
oidcService := service.NewOIDCService(oidcServiceCfg, queries)
|
||||
err = oidcService.Init()
|
||||
assert.NoError(t, err)
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.description, func(t *testing.T) {
|
||||
router := gin.Default()
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
wellKnownController := controller.NewWellKnownController(controller.WellKnownControllerConfig{}, oidcService, router)
|
||||
wellKnownController.SetupRoutes()
|
||||
|
||||
test.run(t, router, recorder)
|
||||
})
|
||||
}
|
||||
|
||||
err = db.Close()
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = os.Remove("/tmp/tinyauth_test.db")
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
@@ -21,8 +21,11 @@ import (
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
// hard-defaults, may make configurable in the future if needed,
|
||||
// but for now these are just safety limits to prevent unbounded memory usage
|
||||
const MaxOAuthPendingSessions = 256
|
||||
const OAuthCleanupCount = 16
|
||||
const MaxLoginAttemptRecords = 256
|
||||
|
||||
type OAuthPendingSession struct {
|
||||
State string
|
||||
@@ -43,6 +46,11 @@ type LoginAttempt struct {
|
||||
LockedUntil time.Time
|
||||
}
|
||||
|
||||
type Lockdown struct {
|
||||
Active bool
|
||||
ActiveUntil time.Time
|
||||
}
|
||||
|
||||
type AuthServiceConfig struct {
|
||||
Users []config.User
|
||||
OauthWhitelist []string
|
||||
@@ -69,6 +77,7 @@ type AuthService struct {
|
||||
ldap *LdapService
|
||||
queries *repository.Queries
|
||||
oauthBroker *OAuthBrokerService
|
||||
lockdown *Lockdown
|
||||
}
|
||||
|
||||
func NewAuthService(config AuthServiceConfig, docker *DockerService, ldap *LdapService, queries *repository.Queries, oauthBroker *OAuthBrokerService) *AuthService {
|
||||
@@ -202,6 +211,11 @@ func (auth *AuthService) IsAccountLocked(identifier string) (bool, int) {
|
||||
auth.loginMutex.RLock()
|
||||
defer auth.loginMutex.RUnlock()
|
||||
|
||||
if auth.lockdown != nil && auth.lockdown.Active {
|
||||
remaining := int(time.Until(auth.lockdown.ActiveUntil).Seconds())
|
||||
return true, remaining
|
||||
}
|
||||
|
||||
if auth.config.LoginMaxRetries <= 0 || auth.config.LoginTimeout <= 0 {
|
||||
return false, 0
|
||||
}
|
||||
@@ -227,6 +241,14 @@ func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) {
|
||||
auth.loginMutex.Lock()
|
||||
defer auth.loginMutex.Unlock()
|
||||
|
||||
if len(auth.loginAttempts) >= MaxLoginAttemptRecords {
|
||||
if auth.lockdown != nil && auth.lockdown.Active {
|
||||
return
|
||||
}
|
||||
go auth.lockdownMode()
|
||||
return
|
||||
}
|
||||
|
||||
attempt, exists := auth.loginAttempts[identifier]
|
||||
if !exists {
|
||||
attempt = &LoginAttempt{}
|
||||
@@ -746,3 +768,38 @@ func (auth *AuthService) ensureOAuthSessionLimit() {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (auth *AuthService) lockdownMode() {
|
||||
auth.loginMutex.Lock()
|
||||
|
||||
tlog.App.Warn().Msg("Multiple login attempts detected, possibly DDOS attack. Activating temporary lockdown.")
|
||||
|
||||
auth.lockdown = &Lockdown{
|
||||
Active: true,
|
||||
ActiveUntil: time.Now().Add(time.Duration(auth.config.LoginTimeout) * time.Second),
|
||||
}
|
||||
|
||||
// At this point all login attemps will also expire so,
|
||||
// we might as well clear them to free up memory
|
||||
auth.loginAttempts = make(map[string]*LoginAttempt)
|
||||
|
||||
timer := time.NewTimer(time.Until(auth.lockdown.ActiveUntil))
|
||||
defer timer.Stop()
|
||||
|
||||
auth.loginMutex.Unlock()
|
||||
|
||||
<-timer.C
|
||||
|
||||
auth.loginMutex.Lock()
|
||||
|
||||
tlog.App.Info().Msg("Lockdown period ended, resuming normal operation")
|
||||
auth.lockdown = nil
|
||||
auth.loginMutex.Unlock()
|
||||
}
|
||||
|
||||
// Function only used for testing - do not use in prod!
|
||||
func (auth *AuthService) ClearRateLimitsTestingOnly() {
|
||||
auth.loginMutex.Lock()
|
||||
auth.loginAttempts = make(map[string]*LoginAttempt)
|
||||
auth.loginMutex.Unlock()
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user