mirror of
https://github.com/steveiliop56/tinyauth.git
synced 2026-01-16 04:12:29 +00:00
Compare commits
14 Commits
refactor/e
...
74cb8067a8
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
74cb8067a8 | ||
|
|
ba46493a7b | ||
|
|
bb0373758a | ||
|
|
f8836fc964 | ||
|
|
53856e0a70 | ||
|
|
9b7dcfd86f | ||
|
|
7afea8b3fc | ||
|
|
f5ac7eff99 | ||
|
|
b024d5ffda | ||
|
|
773cd6d171 | ||
|
|
f3eb7f69b4 | ||
|
|
f0d2da281a | ||
|
|
9ce16c9652 | ||
|
|
ad4fc7ef5f |
@@ -112,6 +112,7 @@ func init() {
|
||||
{"ldap-search-filter", "(uid=%s)", "LDAP search filter for user lookup."},
|
||||
{"resources-dir", "/data/resources", "Path to a directory containing custom resources (e.g. background image)."},
|
||||
{"database-path", "/data/tinyauth.db", "Path to the Sqlite database file."},
|
||||
{"trusted-proxies", "", "Comma separated list of trusted proxies (IP addresses) for correct client IP detection and for header ACLs."},
|
||||
}
|
||||
|
||||
for _, opt := range configOptions {
|
||||
|
||||
12
go.mod
12
go.mod
@@ -5,6 +5,7 @@ go 1.23.2
|
||||
require (
|
||||
github.com/cenkalti/backoff/v5 v5.0.3
|
||||
github.com/gin-gonic/gin v1.10.1
|
||||
github.com/glebarez/sqlite v1.11.0
|
||||
github.com/go-playground/validator/v10 v10.27.0
|
||||
github.com/golang-migrate/migrate/v4 v4.18.3
|
||||
github.com/google/go-querystring v1.1.0
|
||||
@@ -15,9 +16,9 @@ require (
|
||||
github.com/spf13/viper v1.20.1
|
||||
github.com/traefik/paerser v0.2.2
|
||||
golang.org/x/crypto v0.41.0
|
||||
gorm.io/driver/sqlite v1.6.0
|
||||
golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b
|
||||
gorm.io/gorm v1.30.1
|
||||
modernc.org/sqlite v1.38.2
|
||||
gotest.tools/v3 v3.5.2
|
||||
)
|
||||
|
||||
require (
|
||||
@@ -28,9 +29,9 @@ require (
|
||||
github.com/containerd/errdefs/pkg v0.3.0 // indirect
|
||||
github.com/containerd/log v0.1.0 // indirect
|
||||
github.com/glebarez/go-sqlite v1.21.2 // indirect
|
||||
github.com/glebarez/sqlite v1.11.0 // indirect
|
||||
github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667 // indirect
|
||||
github.com/go-viper/mapstructure/v2 v2.4.0 // indirect
|
||||
github.com/google/go-cmp v0.6.0 // indirect
|
||||
github.com/hashicorp/errwrap v1.1.0 // indirect
|
||||
github.com/hashicorp/go-multierror v1.1.1 // indirect
|
||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||
@@ -44,12 +45,11 @@ require (
|
||||
go.opentelemetry.io/auto/sdk v1.1.0 // indirect
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.34.0 // indirect
|
||||
go.opentelemetry.io/otel/sdk v1.34.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b // indirect
|
||||
golang.org/x/term v0.34.0 // indirect
|
||||
gotest.tools/v3 v3.5.2 // indirect
|
||||
modernc.org/libc v1.66.3 // indirect
|
||||
modernc.org/mathutil v1.7.1 // indirect
|
||||
modernc.org/memory v1.11.0 // indirect
|
||||
modernc.org/sqlite v1.38.2 // indirect
|
||||
rsc.io/qr v0.2.0 // indirect
|
||||
)
|
||||
|
||||
@@ -86,8 +86,6 @@ require (
|
||||
github.com/go-playground/universal-translator v0.18.1 // indirect
|
||||
github.com/goccy/go-json v0.10.4 // indirect
|
||||
github.com/gogo/protobuf v1.3.2 // indirect
|
||||
github.com/gorilla/securecookie v1.1.2 // indirect
|
||||
github.com/gorilla/sessions v1.4.0
|
||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||
github.com/json-iterator/go v1.1.12 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.2.9 // indirect
|
||||
|
||||
8
go.sum
8
go.sum
@@ -132,16 +132,10 @@ github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeN
|
||||
github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD/fhyJ8=
|
||||
github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17icRSOU623lUBU=
|
||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0=
|
||||
github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
|
||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kXD8ePA=
|
||||
github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo=
|
||||
github.com/gorilla/sessions v1.4.0 h1:kpIYOp/oi6MG/p5PgxApU8srsSw9tuFbt46Lt7auzqQ=
|
||||
github.com/gorilla/sessions v1.4.0/go.mod h1:FLWm50oby91+hl7p/wRxDth9bWSuk0qVL2emc7lT5ik=
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.25.1 h1:VNqngBF40hVlDloBruUehVYC3ArSgIyScOAyMRqBxRg=
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.25.1/go.mod h1:RBRO7fro65R6tjKzYgLAFo0t1QEXY1Dp+i/bvpRiqiQ=
|
||||
github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
|
||||
@@ -380,8 +374,6 @@ gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EV
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gorm.io/driver/sqlite v1.6.0 h1:WHRRrIiulaPiPFmDcod6prc4l2VGVWHz80KspNsxSfQ=
|
||||
gorm.io/driver/sqlite v1.6.0/go.mod h1:AO9V1qIQddBESngQUKWL9yoH93HIeA1X6V633rBwyT8=
|
||||
gorm.io/gorm v1.30.1 h1:lSHg33jJTBxs2mgJRfRZeLDG+WZaHYCk3Wtfl6Ngzo4=
|
||||
gorm.io/gorm v1.30.1/go.mod h1:8Z33v652h4//uMA76KjeDH8mJXPm1QNCYrMeatR0DOE=
|
||||
gotest.tools/v3 v3.5.2 h1:7koQfIKdy+I8UTetycgUqXWSDwpgv193Ka+qRsmBY8Q=
|
||||
|
||||
@@ -146,6 +146,7 @@ func (app *BootstrapApp) Setup() error {
|
||||
|
||||
// Create engine
|
||||
engine := gin.New()
|
||||
engine.SetTrustedProxies(strings.Split(app.Config.TrustedProxies, ","))
|
||||
|
||||
if config.Version != "development" {
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
|
||||
@@ -53,6 +53,7 @@ type Config struct {
|
||||
LdapSearchFilter string `mapstructure:"ldap-search-filter"`
|
||||
ResourcesDir string `mapstructure:"resources-dir"`
|
||||
DatabasePath string `mapstructure:"database-path" validate:"required"`
|
||||
TrustedProxies string `mapstructure:"trusted-proxies"`
|
||||
}
|
||||
|
||||
// OAuth/OIDC config
|
||||
@@ -125,51 +126,51 @@ type RedirectQuery struct {
|
||||
|
||||
// Labels
|
||||
|
||||
type Labels struct {
|
||||
Apps map[string]AppLabels
|
||||
type Apps struct {
|
||||
Apps map[string]App
|
||||
}
|
||||
|
||||
type AppLabels struct {
|
||||
Config ConfigLabels
|
||||
Users UsersLabels
|
||||
OAuth OAuthLabels
|
||||
IP IPLabels
|
||||
Response ResponseLabels
|
||||
Path PathLabels
|
||||
type App struct {
|
||||
Config AppConfig
|
||||
Users AppUsers
|
||||
OAuth AppOAuth
|
||||
IP AppIP
|
||||
Response AppResponse
|
||||
Path AppPath
|
||||
}
|
||||
|
||||
type ConfigLabels struct {
|
||||
type AppConfig struct {
|
||||
Domain string
|
||||
}
|
||||
|
||||
type UsersLabels struct {
|
||||
type AppUsers struct {
|
||||
Allow string
|
||||
Block string
|
||||
}
|
||||
|
||||
type OAuthLabels struct {
|
||||
type AppOAuth struct {
|
||||
Whitelist string
|
||||
Groups string
|
||||
}
|
||||
|
||||
type IPLabels struct {
|
||||
type AppIP struct {
|
||||
Allow []string
|
||||
Block []string
|
||||
Bypass []string
|
||||
}
|
||||
|
||||
type ResponseLabels struct {
|
||||
type AppResponse struct {
|
||||
Headers []string
|
||||
BasicAuth BasicAuthLabels
|
||||
BasicAuth AppBasicAuth
|
||||
}
|
||||
|
||||
type BasicAuthLabels struct {
|
||||
type AppBasicAuth struct {
|
||||
Username string
|
||||
Password string
|
||||
PasswordFile string
|
||||
}
|
||||
|
||||
type PathLabels struct {
|
||||
type AppPath struct {
|
||||
Allow string
|
||||
Block string
|
||||
}
|
||||
|
||||
135
internal/controller/context_controller_test.go
Normal file
135
internal/controller/context_controller_test.go
Normal file
@@ -0,0 +1,135 @@
|
||||
package controller_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"tinyauth/internal/config"
|
||||
"tinyauth/internal/controller"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gotest.tools/v3/assert"
|
||||
)
|
||||
|
||||
var controllerCfg = controller.ContextControllerConfig{
|
||||
ConfiguredProviders: []string{"github", "google", "generic"},
|
||||
Title: "Test App",
|
||||
GenericName: "Generic",
|
||||
AppURL: "http://localhost:8080",
|
||||
RootDomain: "localhost",
|
||||
ForgotPasswordMessage: "Contact admin to reset your password.",
|
||||
BackgroundImage: "/assets/bg.jpg",
|
||||
OAuthAutoRedirect: "google",
|
||||
}
|
||||
|
||||
var userContext = config.UserContext{
|
||||
Username: "testuser",
|
||||
Name: "testuser",
|
||||
Email: "test@example.com",
|
||||
IsLoggedIn: true,
|
||||
OAuth: false,
|
||||
Provider: "username",
|
||||
TotpPending: false,
|
||||
OAuthGroups: "",
|
||||
TotpEnabled: false,
|
||||
}
|
||||
|
||||
func setupContextController(middlewares *[]gin.HandlerFunc) (*gin.Engine, *httptest.ResponseRecorder) {
|
||||
// Setup
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.Default()
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
if middlewares != nil {
|
||||
for _, m := range *middlewares {
|
||||
router.Use(m)
|
||||
}
|
||||
}
|
||||
|
||||
group := router.Group("/api")
|
||||
|
||||
ctrl := controller.NewContextController(controllerCfg, group)
|
||||
ctrl.SetupRoutes()
|
||||
|
||||
return router, recorder
|
||||
}
|
||||
|
||||
func TestAppContextHandler(t *testing.T) {
|
||||
expectedRes := controller.AppContextResponse{
|
||||
Status: 200,
|
||||
Message: "Success",
|
||||
ConfiguredProviders: controllerCfg.ConfiguredProviders,
|
||||
Title: controllerCfg.Title,
|
||||
GenericName: controllerCfg.GenericName,
|
||||
AppURL: controllerCfg.AppURL,
|
||||
RootDomain: controllerCfg.RootDomain,
|
||||
ForgotPasswordMessage: controllerCfg.ForgotPasswordMessage,
|
||||
BackgroundImage: controllerCfg.BackgroundImage,
|
||||
OAuthAutoRedirect: controllerCfg.OAuthAutoRedirect,
|
||||
}
|
||||
|
||||
router, recorder := setupContextController(nil)
|
||||
req := httptest.NewRequest("GET", "/api/context/app", nil)
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, 200, recorder.Code)
|
||||
|
||||
var ctrlRes controller.AppContextResponse
|
||||
|
||||
err := json.Unmarshal(recorder.Body.Bytes(), &ctrlRes)
|
||||
|
||||
assert.NilError(t, err)
|
||||
assert.DeepEqual(t, expectedRes, ctrlRes)
|
||||
}
|
||||
|
||||
func TestUserContextHandler(t *testing.T) {
|
||||
expectedRes := controller.UserContextResponse{
|
||||
Status: 200,
|
||||
Message: "Success",
|
||||
IsLoggedIn: userContext.IsLoggedIn,
|
||||
Username: userContext.Username,
|
||||
Name: userContext.Name,
|
||||
Email: userContext.Email,
|
||||
Provider: userContext.Provider,
|
||||
OAuth: userContext.OAuth,
|
||||
TotpPending: userContext.TotpPending,
|
||||
}
|
||||
|
||||
// Test with context
|
||||
router, recorder := setupContextController(&[]gin.HandlerFunc{
|
||||
func(c *gin.Context) {
|
||||
c.Set("context", &userContext)
|
||||
c.Next()
|
||||
},
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/context/user", nil)
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, 200, recorder.Code)
|
||||
|
||||
var ctrlRes controller.UserContextResponse
|
||||
|
||||
err := json.Unmarshal(recorder.Body.Bytes(), &ctrlRes)
|
||||
|
||||
assert.NilError(t, err)
|
||||
assert.DeepEqual(t, expectedRes, ctrlRes)
|
||||
|
||||
// Test no context
|
||||
expectedRes = controller.UserContextResponse{
|
||||
Status: 401,
|
||||
Message: "Unauthorized",
|
||||
IsLoggedIn: false,
|
||||
}
|
||||
|
||||
router, recorder = setupContextController(nil)
|
||||
req = httptest.NewRequest("GET", "/api/context/user", nil)
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, 200, recorder.Code)
|
||||
|
||||
err = json.Unmarshal(recorder.Body.Bytes(), &ctrlRes)
|
||||
|
||||
assert.NilError(t, err)
|
||||
assert.DeepEqual(t, expectedRes, ctrlRes)
|
||||
}
|
||||
@@ -108,6 +108,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
|
||||
|
||||
if err != nil || state != csrfCookie {
|
||||
log.Warn().Err(err).Msg("CSRF token mismatch or cookie missing")
|
||||
c.SetCookie(controller.config.CSRFCookieName, "", -1, "/", fmt.Sprintf(".%s", controller.config.RootDomain), controller.config.SecureCookie, true)
|
||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -55,6 +55,15 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if req.Proxy != "nginx" && req.Proxy != "traefik" && req.Proxy != "caddy" {
|
||||
log.Warn().Str("proxy", req.Proxy).Msg("Invalid proxy")
|
||||
c.JSON(400, gin.H{
|
||||
"status": 400,
|
||||
"message": "Bad Request",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
isBrowser := strings.Contains(c.Request.Header.Get("Accept"), "text/html")
|
||||
|
||||
if isBrowser {
|
||||
@@ -251,7 +260,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/login?%s", controller.config.AppURL, queries.Encode()))
|
||||
}
|
||||
|
||||
func (controller *ProxyController) setHeaders(c *gin.Context, labels config.AppLabels) {
|
||||
func (controller *ProxyController) setHeaders(c *gin.Context, labels config.App) {
|
||||
c.Header("Authorization", c.Request.Header.Get("Authorization"))
|
||||
|
||||
headers := utils.ParseHeaders(labels.Response.Headers)
|
||||
|
||||
164
internal/controller/proxy_controller_test.go
Normal file
164
internal/controller/proxy_controller_test.go
Normal file
@@ -0,0 +1,164 @@
|
||||
package controller_test
|
||||
|
||||
import (
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"tinyauth/internal/config"
|
||||
"tinyauth/internal/controller"
|
||||
"tinyauth/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gotest.tools/v3/assert"
|
||||
)
|
||||
|
||||
func setupProxyController(t *testing.T, middlewares *[]gin.HandlerFunc) (*gin.Engine, *httptest.ResponseRecorder, *service.AuthService) {
|
||||
// Setup
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.Default()
|
||||
|
||||
if middlewares != nil {
|
||||
for _, m := range *middlewares {
|
||||
router.Use(m)
|
||||
}
|
||||
}
|
||||
|
||||
group := router.Group("/api")
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
// Database
|
||||
databaseService := service.NewDatabaseService(service.DatabaseServiceConfig{
|
||||
DatabasePath: "/tmp/tinyauth_test.db",
|
||||
})
|
||||
|
||||
assert.NilError(t, databaseService.Init())
|
||||
|
||||
database := databaseService.GetDatabase()
|
||||
|
||||
// Docker
|
||||
dockerService := service.NewDockerService()
|
||||
|
||||
assert.NilError(t, dockerService.Init())
|
||||
|
||||
// Auth service
|
||||
authService := service.NewAuthService(service.AuthServiceConfig{
|
||||
Users: []config.User{
|
||||
{
|
||||
Username: "testuser",
|
||||
Password: "$2a$10$ne6z693sTgzT3ePoQ05PgOecUHnBjM7sSNj6M.l5CLUP.f6NyCnt.", // test
|
||||
},
|
||||
},
|
||||
OauthWhitelist: "",
|
||||
SessionExpiry: 3600,
|
||||
SecureCookie: false,
|
||||
RootDomain: "localhost",
|
||||
LoginTimeout: 300,
|
||||
LoginMaxRetries: 3,
|
||||
SessionCookieName: "tinyauth-session",
|
||||
}, dockerService, nil, database)
|
||||
|
||||
// Controller
|
||||
ctrl := controller.NewProxyController(controller.ProxyControllerConfig{
|
||||
AppURL: "http://localhost:8080",
|
||||
}, group, dockerService, authService)
|
||||
ctrl.SetupRoutes()
|
||||
|
||||
return router, recorder, authService
|
||||
}
|
||||
|
||||
func TestProxyHandler(t *testing.T) {
|
||||
// Setup
|
||||
router, recorder, authService := setupProxyController(t, nil)
|
||||
|
||||
// Test invalid proxy
|
||||
req := httptest.NewRequest("GET", "/api/auth/invalidproxy", nil)
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, 400, recorder.Code)
|
||||
|
||||
// Test logged out user (traefik/caddy)
|
||||
recorder = httptest.NewRecorder()
|
||||
req = httptest.NewRequest("GET", "/api/auth/traefik", nil)
|
||||
req.Header.Set("X-Forwarded-Proto", "https")
|
||||
req.Header.Set("X-Forwarded-Host", "example.com")
|
||||
req.Header.Set("X-Forwarded-Uri", "/somepath")
|
||||
req.Header.Set("Accept", "text/html")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, 307, recorder.Code)
|
||||
assert.Equal(t, "http://localhost:8080/login?redirect_uri=https%3A%2F%2Fexample.com%2Fsomepath", recorder.Header().Get("Location"))
|
||||
|
||||
// Test logged out user (nginx)
|
||||
recorder = httptest.NewRecorder()
|
||||
req = httptest.NewRequest("GET", "/api/auth/nginx", nil)
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, 401, recorder.Code)
|
||||
|
||||
// Test logged in user
|
||||
c := gin.CreateTestContextOnly(recorder, router)
|
||||
|
||||
err := authService.CreateSessionCookie(c, &config.SessionCookie{
|
||||
Username: "testuser",
|
||||
Name: "testuser",
|
||||
Email: "testuser@example.com",
|
||||
Provider: "username",
|
||||
TotpPending: false,
|
||||
OAuthGroups: "",
|
||||
})
|
||||
|
||||
assert.NilError(t, err)
|
||||
|
||||
cookie := c.Writer.Header().Get("Set-Cookie")
|
||||
|
||||
router, recorder, _ = setupProxyController(t, &[]gin.HandlerFunc{
|
||||
func(c *gin.Context) {
|
||||
c.Set("context", &config.UserContext{
|
||||
Username: "testuser",
|
||||
Name: "testuser",
|
||||
Email: "testuser@example.com",
|
||||
IsLoggedIn: true,
|
||||
OAuth: false,
|
||||
Provider: "username",
|
||||
TotpPending: false,
|
||||
OAuthGroups: "",
|
||||
TotpEnabled: false,
|
||||
})
|
||||
c.Next()
|
||||
},
|
||||
})
|
||||
|
||||
req = httptest.NewRequest("GET", "/api/auth/traefik", nil)
|
||||
req.Header.Set("Cookie", cookie)
|
||||
req.Header.Set("Accept", "text/html")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, 200, recorder.Code)
|
||||
|
||||
assert.Equal(t, "testuser", recorder.Header().Get("Remote-User"))
|
||||
assert.Equal(t, "testuser", recorder.Header().Get("Remote-Name"))
|
||||
assert.Equal(t, "testuser@example.com", recorder.Header().Get("Remote-Email"))
|
||||
|
||||
// Ensure basic auth is disabled for TOTP enabled users
|
||||
router, recorder, _ = setupProxyController(t, &[]gin.HandlerFunc{
|
||||
func(c *gin.Context) {
|
||||
c.Set("context", &config.UserContext{
|
||||
Username: "testuser",
|
||||
Name: "testuser",
|
||||
Email: "testuser@example.com",
|
||||
IsLoggedIn: true,
|
||||
OAuth: false,
|
||||
Provider: "basic",
|
||||
TotpPending: false,
|
||||
OAuthGroups: "",
|
||||
TotpEnabled: true,
|
||||
})
|
||||
c.Next()
|
||||
},
|
||||
})
|
||||
|
||||
req = httptest.NewRequest("GET", "/api/auth/traefik", nil)
|
||||
req.SetBasicAuth("testuser", "test")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, 401, recorder.Code)
|
||||
}
|
||||
56
internal/controller/resources_controller_test.go
Normal file
56
internal/controller/resources_controller_test.go
Normal file
@@ -0,0 +1,56 @@
|
||||
package controller_test
|
||||
|
||||
import (
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"testing"
|
||||
"tinyauth/internal/controller"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gotest.tools/v3/assert"
|
||||
)
|
||||
|
||||
func TestResourcesHandler(t *testing.T) {
|
||||
// Setup
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
group := router.Group("/")
|
||||
|
||||
ctrl := controller.NewResourcesController(controller.ResourcesControllerConfig{
|
||||
ResourcesDir: "/tmp/tinyauth",
|
||||
}, group)
|
||||
ctrl.SetupRoutes()
|
||||
|
||||
// Create test data
|
||||
err := os.Mkdir("/tmp/tinyauth", 0755)
|
||||
assert.NilError(t, err)
|
||||
|
||||
file, err := os.Create("/tmp/tinyauth/test.txt")
|
||||
assert.NilError(t, err)
|
||||
|
||||
_, err = file.WriteString("This is a test file.")
|
||||
assert.NilError(t, err)
|
||||
file.Close()
|
||||
|
||||
// Test existing file
|
||||
req := httptest.NewRequest("GET", "/resources/test.txt", nil)
|
||||
recorder := httptest.NewRecorder()
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, 200, recorder.Code)
|
||||
assert.Equal(t, "This is a test file.", recorder.Body.String())
|
||||
|
||||
// Test non-existing file
|
||||
req = httptest.NewRequest("GET", "/resources/nonexistent.txt", nil)
|
||||
recorder = httptest.NewRecorder()
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, 404, recorder.Code)
|
||||
|
||||
// Test directory traversal attack
|
||||
req = httptest.NewRequest("GET", "/resources/../etc/passwd", nil)
|
||||
recorder = httptest.NewRecorder()
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, 404, recorder.Code)
|
||||
}
|
||||
297
internal/controller/user_controller_test.go
Normal file
297
internal/controller/user_controller_test.go
Normal file
@@ -0,0 +1,297 @@
|
||||
package controller_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
"tinyauth/internal/config"
|
||||
"tinyauth/internal/controller"
|
||||
"tinyauth/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/pquerna/otp/totp"
|
||||
"gotest.tools/v3/assert"
|
||||
)
|
||||
|
||||
var cookieValue string
|
||||
var totpSecret = "6WFZXPEZRK5MZHHYAFW4DAOUYQMCASBJ"
|
||||
|
||||
func setupUserController(t *testing.T, middlewares *[]gin.HandlerFunc) (*gin.Engine, *httptest.ResponseRecorder) {
|
||||
// Setup
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.Default()
|
||||
|
||||
if middlewares != nil {
|
||||
for _, m := range *middlewares {
|
||||
router.Use(m)
|
||||
}
|
||||
}
|
||||
|
||||
group := router.Group("/api")
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
// Database
|
||||
databaseService := service.NewDatabaseService(service.DatabaseServiceConfig{
|
||||
DatabasePath: "/tmp/tinyauth_test.db",
|
||||
})
|
||||
|
||||
assert.NilError(t, databaseService.Init())
|
||||
|
||||
database := databaseService.GetDatabase()
|
||||
|
||||
// Auth service
|
||||
authService := service.NewAuthService(service.AuthServiceConfig{
|
||||
Users: []config.User{
|
||||
{
|
||||
Username: "testuser",
|
||||
Password: "$2a$10$ne6z693sTgzT3ePoQ05PgOecUHnBjM7sSNj6M.l5CLUP.f6NyCnt.", // test
|
||||
},
|
||||
{
|
||||
Username: "totpuser",
|
||||
Password: "$2a$10$ne6z693sTgzT3ePoQ05PgOecUHnBjM7sSNj6M.l5CLUP.f6NyCnt.", // test
|
||||
TotpSecret: totpSecret,
|
||||
},
|
||||
},
|
||||
OauthWhitelist: "",
|
||||
SessionExpiry: 3600,
|
||||
SecureCookie: false,
|
||||
RootDomain: "localhost",
|
||||
LoginTimeout: 300,
|
||||
LoginMaxRetries: 3,
|
||||
SessionCookieName: "tinyauth-session",
|
||||
}, nil, nil, database)
|
||||
|
||||
// Controller
|
||||
ctrl := controller.NewUserController(controller.UserControllerConfig{
|
||||
RootDomain: "localhost",
|
||||
}, group, authService)
|
||||
ctrl.SetupRoutes()
|
||||
|
||||
return router, recorder
|
||||
}
|
||||
|
||||
func TestLoginHandler(t *testing.T) {
|
||||
// Setup
|
||||
router, recorder := setupUserController(t, nil)
|
||||
|
||||
loginReq := controller.LoginRequest{
|
||||
Username: "testuser",
|
||||
Password: "test",
|
||||
}
|
||||
|
||||
loginReqJson, err := json.Marshal(loginReq)
|
||||
assert.NilError(t, err)
|
||||
|
||||
// Test
|
||||
req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqJson)))
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, 200, recorder.Code)
|
||||
|
||||
cookie := recorder.Result().Cookies()[0]
|
||||
|
||||
assert.Equal(t, "tinyauth-session", cookie.Name)
|
||||
assert.Assert(t, cookie.Value != "")
|
||||
|
||||
cookieValue = cookie.Value
|
||||
|
||||
// Test invalid credentials
|
||||
loginReq = controller.LoginRequest{
|
||||
Username: "testuser",
|
||||
Password: "invalid",
|
||||
}
|
||||
|
||||
loginReqJson, err = json.Marshal(loginReq)
|
||||
assert.NilError(t, err)
|
||||
|
||||
recorder = httptest.NewRecorder()
|
||||
req = httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqJson)))
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, 401, recorder.Code)
|
||||
|
||||
// Test totp required
|
||||
loginReq = controller.LoginRequest{
|
||||
Username: "totpuser",
|
||||
Password: "test",
|
||||
}
|
||||
|
||||
loginReqJson, err = json.Marshal(loginReq)
|
||||
assert.NilError(t, err)
|
||||
|
||||
recorder = httptest.NewRecorder()
|
||||
req = httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqJson)))
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, 200, recorder.Code)
|
||||
|
||||
loginResJson, err := json.Marshal(map[string]any{
|
||||
"message": "TOTP required",
|
||||
"status": 200,
|
||||
"totpPending": true,
|
||||
})
|
||||
|
||||
assert.NilError(t, err)
|
||||
assert.Equal(t, string(loginResJson), recorder.Body.String())
|
||||
|
||||
// Test invalid json
|
||||
recorder = httptest.NewRecorder()
|
||||
req = httptest.NewRequest("POST", "/api/user/login", strings.NewReader("{invalid json}"))
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, 400, recorder.Code)
|
||||
|
||||
// Test rate limiting
|
||||
loginReq = controller.LoginRequest{
|
||||
Username: "testuser",
|
||||
Password: "invalid",
|
||||
}
|
||||
|
||||
loginReqJson, err = json.Marshal(loginReq)
|
||||
assert.NilError(t, err)
|
||||
|
||||
for range 5 {
|
||||
recorder = httptest.NewRecorder()
|
||||
req = httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqJson)))
|
||||
router.ServeHTTP(recorder, req)
|
||||
}
|
||||
|
||||
assert.Equal(t, 429, recorder.Code)
|
||||
}
|
||||
|
||||
func TestLogoutHandler(t *testing.T) {
|
||||
// Setup
|
||||
router, recorder := setupUserController(t, nil)
|
||||
|
||||
// Test
|
||||
req := httptest.NewRequest("POST", "/api/user/logout", nil)
|
||||
|
||||
req.AddCookie(&http.Cookie{
|
||||
Name: "tinyauth-session",
|
||||
Value: cookieValue,
|
||||
})
|
||||
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, 200, recorder.Code)
|
||||
|
||||
cookie := recorder.Result().Cookies()[0]
|
||||
|
||||
assert.Equal(t, "tinyauth-session", cookie.Name)
|
||||
assert.Equal(t, "", cookie.Value)
|
||||
assert.Equal(t, -1, cookie.MaxAge)
|
||||
}
|
||||
|
||||
func TestTotpHandler(t *testing.T) {
|
||||
// Setup
|
||||
router, recorder := setupUserController(t, &[]gin.HandlerFunc{
|
||||
func(c *gin.Context) {
|
||||
c.Set("context", &config.UserContext{
|
||||
Username: "totpuser",
|
||||
Name: "totpuser",
|
||||
Email: "totpuser@example.com",
|
||||
IsLoggedIn: false,
|
||||
OAuth: false,
|
||||
Provider: "username",
|
||||
TotpPending: true,
|
||||
OAuthGroups: "",
|
||||
TotpEnabled: true,
|
||||
})
|
||||
c.Next()
|
||||
},
|
||||
})
|
||||
|
||||
// Test
|
||||
code, err := totp.GenerateCode(totpSecret, time.Now())
|
||||
|
||||
assert.NilError(t, err)
|
||||
|
||||
totpReq := controller.TotpRequest{
|
||||
Code: code,
|
||||
}
|
||||
|
||||
totpReqJson, err := json.Marshal(totpReq)
|
||||
assert.NilError(t, err)
|
||||
|
||||
req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(totpReqJson)))
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, 200, recorder.Code)
|
||||
|
||||
cookie := recorder.Result().Cookies()[0]
|
||||
|
||||
assert.Equal(t, "tinyauth-session", cookie.Name)
|
||||
assert.Assert(t, cookie.Value != "")
|
||||
|
||||
// Test invalid json
|
||||
recorder = httptest.NewRecorder()
|
||||
req = httptest.NewRequest("POST", "/api/user/totp", strings.NewReader("{invalid json}"))
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, 400, recorder.Code)
|
||||
|
||||
// Test rate limiting
|
||||
totpReq = controller.TotpRequest{
|
||||
Code: "000000",
|
||||
}
|
||||
|
||||
totpReqJson, err = json.Marshal(totpReq)
|
||||
assert.NilError(t, err)
|
||||
|
||||
for range 5 {
|
||||
recorder = httptest.NewRecorder()
|
||||
req = httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(totpReqJson)))
|
||||
router.ServeHTTP(recorder, req)
|
||||
}
|
||||
|
||||
assert.Equal(t, 429, recorder.Code)
|
||||
|
||||
// Test invalid code
|
||||
router, recorder = setupUserController(t, &[]gin.HandlerFunc{
|
||||
func(c *gin.Context) {
|
||||
c.Set("context", &config.UserContext{
|
||||
Username: "totpuser",
|
||||
Name: "totpuser",
|
||||
Email: "totpuser@example.com",
|
||||
IsLoggedIn: false,
|
||||
OAuth: false,
|
||||
Provider: "username",
|
||||
TotpPending: true,
|
||||
OAuthGroups: "",
|
||||
TotpEnabled: true,
|
||||
})
|
||||
c.Next()
|
||||
},
|
||||
})
|
||||
|
||||
req = httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(totpReqJson)))
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, 401, recorder.Code)
|
||||
|
||||
// Test no totp pending
|
||||
router, recorder = setupUserController(t, &[]gin.HandlerFunc{
|
||||
func(c *gin.Context) {
|
||||
c.Set("context", &config.UserContext{
|
||||
Username: "totpuser",
|
||||
Name: "totpuser",
|
||||
Email: "totpuser@example.com",
|
||||
IsLoggedIn: false,
|
||||
OAuth: false,
|
||||
Provider: "username",
|
||||
TotpPending: false,
|
||||
OAuthGroups: "",
|
||||
TotpEnabled: false,
|
||||
})
|
||||
c.Next()
|
||||
},
|
||||
})
|
||||
|
||||
req = httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(totpReqJson)))
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, 401, recorder.Code)
|
||||
}
|
||||
@@ -285,7 +285,7 @@ func (auth *AuthService) UserAuthConfigured() bool {
|
||||
return len(auth.config.Users) > 0 || auth.ldap != nil
|
||||
}
|
||||
|
||||
func (auth *AuthService) IsResourceAllowed(c *gin.Context, context config.UserContext, labels config.AppLabels) bool {
|
||||
func (auth *AuthService) IsResourceAllowed(c *gin.Context, context config.UserContext, labels config.App) bool {
|
||||
if context.OAuth {
|
||||
log.Debug().Msg("Checking OAuth whitelist")
|
||||
return utils.CheckFilter(labels.OAuth.Whitelist, context.Email)
|
||||
@@ -322,7 +322,7 @@ func (auth *AuthService) IsInOAuthGroup(c *gin.Context, context config.UserConte
|
||||
return false
|
||||
}
|
||||
|
||||
func (auth *AuthService) IsAuthEnabled(uri string, path config.PathLabels) (bool, error) {
|
||||
func (auth *AuthService) IsAuthEnabled(uri string, path config.AppPath) (bool, error) {
|
||||
// Check for block list
|
||||
if path.Block != "" {
|
||||
regex, err := regexp.Compile(path.Block)
|
||||
@@ -364,7 +364,7 @@ func (auth *AuthService) GetBasicAuth(c *gin.Context) *config.User {
|
||||
}
|
||||
}
|
||||
|
||||
func (auth *AuthService) CheckIP(labels config.IPLabels, ip string) bool {
|
||||
func (auth *AuthService) CheckIP(labels config.AppIP, ip string) bool {
|
||||
for _, blocked := range labels.Block {
|
||||
res, err := utils.FilterIP(blocked, ip)
|
||||
if err != nil {
|
||||
@@ -398,7 +398,7 @@ func (auth *AuthService) CheckIP(labels config.IPLabels, ip string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (auth *AuthService) IsBypassedIP(labels config.IPLabels, ip string) bool {
|
||||
func (auth *AuthService) IsBypassedIP(labels config.AppIP, ip string) bool {
|
||||
for _, bypassed := range labels.Bypass {
|
||||
res, err := utils.FilterIP(bypassed, ip)
|
||||
if err != nil {
|
||||
|
||||
@@ -4,7 +4,7 @@ import (
|
||||
"context"
|
||||
"strings"
|
||||
"tinyauth/internal/config"
|
||||
"tinyauth/internal/utils"
|
||||
"tinyauth/internal/utils/decoders"
|
||||
|
||||
container "github.com/docker/docker/api/types/container"
|
||||
"github.com/docker/docker/client"
|
||||
@@ -55,17 +55,17 @@ func (docker *DockerService) DockerConnected() bool {
|
||||
return err == nil
|
||||
}
|
||||
|
||||
func (docker *DockerService) GetLabels(appDomain string) (config.AppLabels, error) {
|
||||
func (docker *DockerService) GetLabels(appDomain string) (config.App, error) {
|
||||
isConnected := docker.DockerConnected()
|
||||
|
||||
if !isConnected {
|
||||
log.Debug().Msg("Docker not connected, returning empty labels")
|
||||
return config.AppLabels{}, nil
|
||||
return config.App{}, nil
|
||||
}
|
||||
|
||||
containers, err := docker.GetContainers()
|
||||
if err != nil {
|
||||
return config.AppLabels{}, err
|
||||
return config.App{}, err
|
||||
}
|
||||
|
||||
for _, ctr := range containers {
|
||||
@@ -75,7 +75,7 @@ func (docker *DockerService) GetLabels(appDomain string) (config.AppLabels, erro
|
||||
continue
|
||||
}
|
||||
|
||||
labels, err := utils.GetLabels(inspect.Config.Labels)
|
||||
labels, err := decoders.DecodeLabels(inspect.Config.Labels)
|
||||
if err != nil {
|
||||
log.Warn().Str("id", ctr.ID).Err(err).Msg("Error getting container labels, skipping")
|
||||
continue
|
||||
@@ -95,5 +95,5 @@ func (docker *DockerService) GetLabels(appDomain string) (config.AppLabels, erro
|
||||
}
|
||||
|
||||
log.Debug().Msg("No matching container found, returning empty labels")
|
||||
return config.AppLabels{}, nil
|
||||
return config.App{}, nil
|
||||
}
|
||||
|
||||
@@ -13,13 +13,13 @@ import (
|
||||
)
|
||||
|
||||
// Get root domain parses a hostname and returns the upper domain (e.g. sub1.sub2.domain.com -> sub2.domain.com)
|
||||
func GetRootDomain(appUrl string) (string, error) {
|
||||
appUrlParsed, err := url.Parse(appUrl)
|
||||
func GetRootDomain(u string) (string, error) {
|
||||
appUrl, err := url.Parse(u)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
host := appUrlParsed.Hostname()
|
||||
host := appUrl.Hostname()
|
||||
|
||||
if netIP := net.ParseIP(host); netIP != nil {
|
||||
return "", errors.New("IP addresses are not allowed")
|
||||
@@ -27,7 +27,7 @@ func GetRootDomain(appUrl string) (string, error) {
|
||||
|
||||
urlParts := strings.Split(host, ".")
|
||||
|
||||
if len(urlParts) < 2 {
|
||||
if len(urlParts) < 3 {
|
||||
return "", errors.New("invalid domain, must be at least second level domain")
|
||||
}
|
||||
|
||||
@@ -49,6 +49,7 @@ func ParseFileToLine(content string) string {
|
||||
}
|
||||
|
||||
func Filter[T any](slice []T, test func(T) bool) (res []T) {
|
||||
res = make([]T, 0)
|
||||
for _, value := range slice {
|
||||
if test(value) {
|
||||
res = append(res, value)
|
||||
|
||||
197
internal/utils/app_utils_test.go
Normal file
197
internal/utils/app_utils_test.go
Normal file
@@ -0,0 +1,197 @@
|
||||
package utils_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"tinyauth/internal/config"
|
||||
"tinyauth/internal/utils"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gotest.tools/v3/assert"
|
||||
)
|
||||
|
||||
func TestGetRootDomain(t *testing.T) {
|
||||
// Normal case
|
||||
domain := "http://sub.example.com"
|
||||
expected := "example.com"
|
||||
result, err := utils.GetRootDomain(domain)
|
||||
assert.NilError(t, err)
|
||||
assert.Equal(t, expected, result)
|
||||
|
||||
// Domain with multiple subdomains
|
||||
domain = "http://b.c.example.com"
|
||||
expected = "c.example.com"
|
||||
result, err = utils.GetRootDomain(domain)
|
||||
assert.NilError(t, err)
|
||||
assert.Equal(t, expected, result)
|
||||
|
||||
// Domain with no subdomain
|
||||
domain = "http://example.com"
|
||||
expected = "example.com"
|
||||
_, err = utils.GetRootDomain(domain)
|
||||
assert.Error(t, err, "invalid domain, must be at least second level domain")
|
||||
|
||||
// Invalid domain (only TLD)
|
||||
domain = "com"
|
||||
_, err = utils.GetRootDomain(domain)
|
||||
assert.ErrorContains(t, err, "invalid domain")
|
||||
|
||||
// IP address
|
||||
domain = "http://10.10.10.10"
|
||||
_, err = utils.GetRootDomain(domain)
|
||||
assert.ErrorContains(t, err, "IP addresses are not allowed")
|
||||
|
||||
// Invalid URL
|
||||
domain = "http://[::1]:namedport"
|
||||
_, err = utils.GetRootDomain(domain)
|
||||
assert.ErrorContains(t, err, "parse \"http://[::1]:namedport\": invalid port \":namedport\" after host")
|
||||
|
||||
// URL with scheme and path
|
||||
domain = "https://sub.example.com/path"
|
||||
expected = "example.com"
|
||||
result, err = utils.GetRootDomain(domain)
|
||||
assert.NilError(t, err)
|
||||
assert.Equal(t, expected, result)
|
||||
|
||||
// URL with port
|
||||
domain = "http://sub.example.com:8080"
|
||||
expected = "example.com"
|
||||
result, err = utils.GetRootDomain(domain)
|
||||
assert.NilError(t, err)
|
||||
assert.Equal(t, expected, result)
|
||||
}
|
||||
|
||||
func TestParseFileToLine(t *testing.T) {
|
||||
// Normal case
|
||||
content := "user1\nuser2\nuser3"
|
||||
expected := "user1,user2,user3"
|
||||
result := utils.ParseFileToLine(content)
|
||||
assert.Equal(t, expected, result)
|
||||
|
||||
// Case with empty lines and spaces
|
||||
content = " user1 \n\n user2 \n user3 \n"
|
||||
expected = "user1,user2,user3"
|
||||
result = utils.ParseFileToLine(content)
|
||||
assert.Equal(t, expected, result)
|
||||
|
||||
// Case with only empty lines
|
||||
content = "\n\n\n"
|
||||
expected = ""
|
||||
result = utils.ParseFileToLine(content)
|
||||
assert.Equal(t, expected, result)
|
||||
|
||||
// Case with single user
|
||||
content = "singleuser"
|
||||
expected = "singleuser"
|
||||
result = utils.ParseFileToLine(content)
|
||||
assert.Equal(t, expected, result)
|
||||
|
||||
// Case with trailing newline
|
||||
content = "user1\nuser2\n"
|
||||
expected = "user1,user2"
|
||||
result = utils.ParseFileToLine(content)
|
||||
assert.Equal(t, expected, result)
|
||||
}
|
||||
|
||||
func TestFilter(t *testing.T) {
|
||||
// Normal case
|
||||
slice := []int{1, 2, 3, 4, 5}
|
||||
testFunc := func(n int) bool { return n%2 == 0 }
|
||||
expected := []int{2, 4}
|
||||
result := utils.Filter(slice, testFunc)
|
||||
assert.DeepEqual(t, expected, result)
|
||||
|
||||
// Case with no matches
|
||||
slice = []int{1, 3, 5}
|
||||
testFunc = func(n int) bool { return n%2 == 0 }
|
||||
expected = []int{}
|
||||
result = utils.Filter(slice, testFunc)
|
||||
assert.DeepEqual(t, expected, result)
|
||||
|
||||
// Case with all matches
|
||||
slice = []int{2, 4, 6}
|
||||
testFunc = func(n int) bool { return n%2 == 0 }
|
||||
expected = []int{2, 4, 6}
|
||||
result = utils.Filter(slice, testFunc)
|
||||
assert.DeepEqual(t, expected, result)
|
||||
|
||||
// Case with empty slice
|
||||
slice = []int{}
|
||||
testFunc = func(n int) bool { return n%2 == 0 }
|
||||
expected = []int{}
|
||||
result = utils.Filter(slice, testFunc)
|
||||
assert.DeepEqual(t, expected, result)
|
||||
|
||||
// Case with different type (string)
|
||||
sliceStr := []string{"apple", "banana", "cherry"}
|
||||
testFuncStr := func(s string) bool { return len(s) > 5 }
|
||||
expectedStr := []string{"banana", "cherry"}
|
||||
resultStr := utils.Filter(sliceStr, testFuncStr)
|
||||
assert.DeepEqual(t, expectedStr, resultStr)
|
||||
}
|
||||
|
||||
func TestGetContext(t *testing.T) {
|
||||
// Setup
|
||||
gin.SetMode(gin.TestMode)
|
||||
c, _ := gin.CreateTestContext(nil)
|
||||
|
||||
// Normal case
|
||||
c.Set("context", &config.UserContext{Username: "testuser"})
|
||||
result, err := utils.GetContext(c)
|
||||
assert.NilError(t, err)
|
||||
assert.Equal(t, "testuser", result.Username)
|
||||
|
||||
// Case with no context
|
||||
c.Set("context", nil)
|
||||
_, err = utils.GetContext(c)
|
||||
assert.Error(t, err, "invalid user context in request")
|
||||
|
||||
// Case with invalid context type
|
||||
c.Set("context", "invalid type")
|
||||
_, err = utils.GetContext(c)
|
||||
assert.Error(t, err, "invalid user context in request")
|
||||
}
|
||||
|
||||
func TestIsRedirectSafe(t *testing.T) {
|
||||
// Setup
|
||||
domain := "example.com"
|
||||
|
||||
// Case with no subdomain
|
||||
redirectURL := "http://example.com/welcome"
|
||||
result := utils.IsRedirectSafe(redirectURL, domain)
|
||||
assert.Equal(t, false, result)
|
||||
|
||||
// Case with different domain
|
||||
redirectURL = "http://malicious.com/phishing"
|
||||
result = utils.IsRedirectSafe(redirectURL, domain)
|
||||
assert.Equal(t, false, result)
|
||||
|
||||
// Case with subdomain
|
||||
redirectURL = "http://sub.example.com/page"
|
||||
result = utils.IsRedirectSafe(redirectURL, domain)
|
||||
assert.Equal(t, true, result)
|
||||
|
||||
// Case with empty redirect URL
|
||||
redirectURL = ""
|
||||
result = utils.IsRedirectSafe(redirectURL, domain)
|
||||
assert.Equal(t, false, result)
|
||||
|
||||
// Case with invalid URL
|
||||
redirectURL = "http://[::1]:namedport"
|
||||
result = utils.IsRedirectSafe(redirectURL, domain)
|
||||
assert.Equal(t, false, result)
|
||||
|
||||
// Case with URL having port
|
||||
redirectURL = "http://sub.example.com:8080/page"
|
||||
result = utils.IsRedirectSafe(redirectURL, domain)
|
||||
assert.Equal(t, true, result)
|
||||
|
||||
// Case with URL having different subdomain
|
||||
redirectURL = "http://another.example.com/page"
|
||||
result = utils.IsRedirectSafe(redirectURL, domain)
|
||||
assert.Equal(t, true, result)
|
||||
|
||||
// Case with URL having different TLD
|
||||
redirectURL = "http://example.org/page"
|
||||
result = utils.IsRedirectSafe(redirectURL, domain)
|
||||
assert.Equal(t, false, result)
|
||||
}
|
||||
19
internal/utils/decoders/label_decoder.go
Normal file
19
internal/utils/decoders/label_decoder.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package decoders
|
||||
|
||||
import (
|
||||
"tinyauth/internal/config"
|
||||
|
||||
"github.com/traefik/paerser/parser"
|
||||
)
|
||||
|
||||
func DecodeLabels(labels map[string]string) (config.Apps, error) {
|
||||
var appLabels config.Apps
|
||||
|
||||
err := parser.Decode(labels, &appLabels, "tinyauth", "tinyauth.apps")
|
||||
|
||||
if err != nil {
|
||||
return config.Apps{}, err
|
||||
}
|
||||
|
||||
return appLabels, nil
|
||||
}
|
||||
73
internal/utils/decoders/label_decoder_test.go
Normal file
73
internal/utils/decoders/label_decoder_test.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package decoders_test
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
"tinyauth/internal/config"
|
||||
"tinyauth/internal/utils/decoders"
|
||||
)
|
||||
|
||||
func TestDecodeLabels(t *testing.T) {
|
||||
// Variables
|
||||
expected := config.Apps{
|
||||
Apps: map[string]config.App{
|
||||
"foo": {
|
||||
Config: config.AppConfig{
|
||||
Domain: "example.com",
|
||||
},
|
||||
Users: config.AppUsers{
|
||||
Allow: "user1,user2",
|
||||
Block: "user3",
|
||||
},
|
||||
OAuth: config.AppOAuth{
|
||||
Whitelist: "somebody@example.com",
|
||||
Groups: "group3",
|
||||
},
|
||||
IP: config.AppIP{
|
||||
Allow: []string{"10.71.0.1/24", "10.71.0.2"},
|
||||
Block: []string{"10.10.10.10", "10.0.0.0/24"},
|
||||
Bypass: []string{"192.168.1.1"},
|
||||
},
|
||||
Response: config.AppResponse{
|
||||
Headers: []string{"X-Foo=Bar", "X-Baz=Qux"},
|
||||
BasicAuth: config.AppBasicAuth{
|
||||
Username: "admin",
|
||||
Password: "password",
|
||||
PasswordFile: "/path/to/passwordfile",
|
||||
},
|
||||
},
|
||||
Path: config.AppPath{
|
||||
Allow: "/public",
|
||||
Block: "/private",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
test := map[string]string{
|
||||
"tinyauth.apps.foo.config.domain": "example.com",
|
||||
"tinyauth.apps.foo.users.allow": "user1,user2",
|
||||
"tinyauth.apps.foo.users.block": "user3",
|
||||
"tinyauth.apps.foo.oauth.whitelist": "somebody@example.com",
|
||||
"tinyauth.apps.foo.oauth.groups": "group3",
|
||||
"tinyauth.apps.foo.ip.allow": "10.71.0.1/24,10.71.0.2",
|
||||
"tinyauth.apps.foo.ip.block": "10.10.10.10,10.0.0.0/24",
|
||||
"tinyauth.apps.foo.ip.bypass": "192.168.1.1",
|
||||
"tinyauth.apps.foo.response.headers": "X-Foo=Bar,X-Baz=Qux",
|
||||
"tinyauth.apps.foo.response.basicauth.username": "admin",
|
||||
"tinyauth.apps.foo.response.basicauth.password": "password",
|
||||
"tinyauth.apps.foo.response.basicauth.passwordfile": "/path/to/passwordfile",
|
||||
"tinyauth.apps.foo.path.allow": "/public",
|
||||
"tinyauth.apps.foo.path.block": "/private",
|
||||
}
|
||||
|
||||
// Test
|
||||
result, err := decoders.DecodeLabels(test)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if reflect.DeepEqual(expected, result) == false {
|
||||
t.Fatalf("Expected %v but got %v", expected, result)
|
||||
}
|
||||
}
|
||||
31
internal/utils/fs_utils_test.go
Normal file
31
internal/utils/fs_utils_test.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"gotest.tools/v3/assert"
|
||||
)
|
||||
|
||||
func TestReadFile(t *testing.T) {
|
||||
// Setup
|
||||
file, err := os.Create("/tmp/tinyauth_test_file")
|
||||
assert.NilError(t, err)
|
||||
|
||||
_, err = file.WriteString("file content\n")
|
||||
assert.NilError(t, err)
|
||||
|
||||
err = file.Close()
|
||||
assert.NilError(t, err)
|
||||
defer os.Remove("/tmp/tinyauth_test_file")
|
||||
|
||||
// Normal case
|
||||
content, err := ReadFile("/tmp/tinyauth_test_file")
|
||||
assert.NilError(t, err)
|
||||
assert.Equal(t, "file content\n", content)
|
||||
|
||||
// Non-existing file
|
||||
content, err = ReadFile("/tmp/non_existing_file")
|
||||
assert.ErrorContains(t, err, "no such file or directory")
|
||||
assert.Equal(t, "", content)
|
||||
}
|
||||
@@ -3,22 +3,8 @@ package utils
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"tinyauth/internal/config"
|
||||
|
||||
"github.com/traefik/paerser/parser"
|
||||
)
|
||||
|
||||
func GetLabels(labels map[string]string) (config.Labels, error) {
|
||||
var labelsParsed config.Labels
|
||||
|
||||
err := parser.Decode(labels, &labelsParsed, "tinyauth", "tinyauth.apps")
|
||||
if err != nil {
|
||||
return config.Labels{}, err
|
||||
}
|
||||
|
||||
return labelsParsed, nil
|
||||
}
|
||||
|
||||
func ParseHeaders(headers []string) map[string]string {
|
||||
headerMap := make(map[string]string)
|
||||
for _, header := range headers {
|
||||
|
||||
87
internal/utils/label_utils_test.go
Normal file
87
internal/utils/label_utils_test.go
Normal file
@@ -0,0 +1,87 @@
|
||||
package utils_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"tinyauth/internal/utils"
|
||||
|
||||
"gotest.tools/v3/assert"
|
||||
)
|
||||
|
||||
func TestParseHeaders(t *testing.T) {
|
||||
// Normal case
|
||||
headers := []string{
|
||||
"X-Custom-Header=Value",
|
||||
"Another-Header=AnotherValue",
|
||||
}
|
||||
expected := map[string]string{
|
||||
"X-Custom-Header": "Value",
|
||||
"Another-Header": "AnotherValue",
|
||||
}
|
||||
assert.DeepEqual(t, expected, utils.ParseHeaders(headers))
|
||||
|
||||
// Case insensitivity and trimming
|
||||
headers = []string{
|
||||
" x-custom-header = Value ",
|
||||
"ANOTHER-HEADER=AnotherValue",
|
||||
}
|
||||
expected = map[string]string{
|
||||
"X-Custom-Header": "Value",
|
||||
"Another-Header": "AnotherValue",
|
||||
}
|
||||
assert.DeepEqual(t, expected, utils.ParseHeaders(headers))
|
||||
|
||||
// Invalid headers (missing '=', empty key/value)
|
||||
headers = []string{
|
||||
"InvalidHeader",
|
||||
"=NoKey",
|
||||
"NoValue=",
|
||||
" = ",
|
||||
}
|
||||
expected = map[string]string{}
|
||||
assert.DeepEqual(t, expected, utils.ParseHeaders(headers))
|
||||
|
||||
// Headers with unsafe characters
|
||||
headers = []string{
|
||||
"X-Custom-Header=Val\x00ue", // Null byte
|
||||
"Another-Header=Anoth\x7FerValue", // DEL character
|
||||
"Good-Header=GoodValue",
|
||||
}
|
||||
expected = map[string]string{
|
||||
"X-Custom-Header": "Value",
|
||||
"Another-Header": "AnotherValue",
|
||||
"Good-Header": "GoodValue",
|
||||
}
|
||||
assert.DeepEqual(t, expected, utils.ParseHeaders(headers))
|
||||
|
||||
// Header with spaces in key (should be ignored)
|
||||
headers = []string{
|
||||
"X Custom Header=Value",
|
||||
"Valid-Header=ValidValue",
|
||||
}
|
||||
expected = map[string]string{
|
||||
"Valid-Header": "ValidValue",
|
||||
}
|
||||
assert.DeepEqual(t, expected, utils.ParseHeaders(headers))
|
||||
}
|
||||
|
||||
func TestSanitizeHeader(t *testing.T) {
|
||||
// Normal case
|
||||
header := "X-Custom-Header"
|
||||
expected := "X-Custom-Header"
|
||||
assert.Equal(t, expected, utils.SanitizeHeader(header))
|
||||
|
||||
// Header with unsafe characters
|
||||
header = "X-Cust\x00om-Hea\x7Fder" // Null byte and DEL character
|
||||
expected = "X-Custom-Header"
|
||||
assert.Equal(t, expected, utils.SanitizeHeader(header))
|
||||
|
||||
// Header with only unsafe characters
|
||||
header = "\x00\x01\x02\x7F"
|
||||
expected = ""
|
||||
assert.Equal(t, expected, utils.SanitizeHeader(header))
|
||||
|
||||
// Header with spaces and tabs (should be preserved)
|
||||
header = "X Custom\tHeader"
|
||||
expected = "X Custom\tHeader"
|
||||
assert.Equal(t, expected, utils.SanitizeHeader(header))
|
||||
}
|
||||
@@ -48,6 +48,12 @@ func GetBasicAuth(username string, password string) string {
|
||||
func FilterIP(filter string, ip string) (bool, error) {
|
||||
ipAddr := net.ParseIP(ip)
|
||||
|
||||
if ipAddr == nil {
|
||||
return false, errors.New("invalid IP address")
|
||||
}
|
||||
|
||||
filter = strings.Replace(filter, "-", "/", -1)
|
||||
|
||||
if strings.Contains(filter, "/") {
|
||||
_, cidr, err := net.ParseCIDR(filter)
|
||||
if err != nil {
|
||||
|
||||
151
internal/utils/security_utils_test.go
Normal file
151
internal/utils/security_utils_test.go
Normal file
@@ -0,0 +1,151 @@
|
||||
package utils_test
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
"tinyauth/internal/utils"
|
||||
|
||||
"gotest.tools/v3/assert"
|
||||
)
|
||||
|
||||
func TestGetSecret(t *testing.T) {
|
||||
// Setup
|
||||
file, err := os.Create("/tmp/tinyauth_test_secret")
|
||||
assert.NilError(t, err)
|
||||
|
||||
_, err = file.WriteString(" secret \n")
|
||||
assert.NilError(t, err)
|
||||
|
||||
err = file.Close()
|
||||
assert.NilError(t, err)
|
||||
defer os.Remove("/tmp/tinyauth_test_secret")
|
||||
|
||||
// Get from config
|
||||
assert.Equal(t, "mysecret", utils.GetSecret("mysecret", ""))
|
||||
|
||||
// Get from file
|
||||
assert.Equal(t, "secret", utils.GetSecret("", "/tmp/tinyauth_test_secret"))
|
||||
|
||||
// Get from both (config should take precedence)
|
||||
assert.Equal(t, "mysecret", utils.GetSecret("mysecret", "/tmp/tinyauth_test_secret"))
|
||||
|
||||
// Get from none
|
||||
assert.Equal(t, "", utils.GetSecret("", ""))
|
||||
|
||||
// Get from non-existing file
|
||||
assert.Equal(t, "", utils.GetSecret("", "/tmp/non_existing_file"))
|
||||
}
|
||||
|
||||
func TestParseSecretFile(t *testing.T) {
|
||||
// Normal case
|
||||
content := " mysecret \n"
|
||||
assert.Equal(t, "mysecret", utils.ParseSecretFile(content))
|
||||
|
||||
// Multiple lines (should take the first non-empty line)
|
||||
content = "\n\n firstsecret \nsecondsecret\n"
|
||||
assert.Equal(t, "firstsecret", utils.ParseSecretFile(content))
|
||||
|
||||
// All empty lines
|
||||
content = "\n \n \n"
|
||||
assert.Equal(t, "", utils.ParseSecretFile(content))
|
||||
|
||||
// Empty content
|
||||
content = ""
|
||||
assert.Equal(t, "", utils.ParseSecretFile(content))
|
||||
}
|
||||
|
||||
func TestGetBasicAuth(t *testing.T) {
|
||||
// Normal case
|
||||
username := "user"
|
||||
password := "pass"
|
||||
expected := "dXNlcjpwYXNz" // base64 of "user:pass"
|
||||
assert.Equal(t, expected, utils.GetBasicAuth(username, password))
|
||||
|
||||
// Empty username
|
||||
username = ""
|
||||
password = "pass"
|
||||
expected = "OnBhc3M=" // base64 of ":pass"
|
||||
assert.Equal(t, expected, utils.GetBasicAuth(username, password))
|
||||
|
||||
// Empty password
|
||||
username = "user"
|
||||
password = ""
|
||||
expected = "dXNlcjo=" // base64 of "user:"
|
||||
assert.Equal(t, expected, utils.GetBasicAuth(username, password))
|
||||
}
|
||||
|
||||
func TestFilterIP(t *testing.T) {
|
||||
// Exact match IPv4
|
||||
ok, err := utils.FilterIP("10.10.0.1", "10.10.0.1")
|
||||
assert.NilError(t, err)
|
||||
assert.Equal(t, true, ok)
|
||||
|
||||
// Non-match IPv4
|
||||
ok, err = utils.FilterIP("10.10.0.1", "10.10.0.2")
|
||||
assert.NilError(t, err)
|
||||
assert.Equal(t, false, ok)
|
||||
|
||||
// CIDR match IPv4
|
||||
ok, err = utils.FilterIP("10.10.0.0/24", "10.10.0.2")
|
||||
assert.NilError(t, err)
|
||||
assert.Equal(t, true, ok)
|
||||
|
||||
// CIDR match IPv4 with '-' instead of '/'
|
||||
ok, err = utils.FilterIP("10.10.10.0-24", "10.10.10.5")
|
||||
assert.NilError(t, err)
|
||||
assert.Equal(t, true, ok)
|
||||
|
||||
// CIDR non-match IPv4
|
||||
ok, err = utils.FilterIP("10.10.0.0/24", "10.5.0.1")
|
||||
assert.NilError(t, err)
|
||||
assert.Equal(t, false, ok)
|
||||
|
||||
// Invalid CIDR
|
||||
ok, err = utils.FilterIP("10.10.0.0/222", "10.0.0.1")
|
||||
assert.ErrorContains(t, err, "invalid CIDR address")
|
||||
assert.Equal(t, false, ok)
|
||||
|
||||
// Invalid IP in filter
|
||||
ok, err = utils.FilterIP("invalid_ip", "10.5.5.5")
|
||||
assert.ErrorContains(t, err, "invalid IP address in filter")
|
||||
assert.Equal(t, false, ok)
|
||||
|
||||
// Invalid IP to check
|
||||
ok, err = utils.FilterIP("10.10.10.10", "invalid_ip")
|
||||
assert.ErrorContains(t, err, "invalid IP address")
|
||||
assert.Equal(t, false, ok)
|
||||
}
|
||||
|
||||
func TestCheckFilter(t *testing.T) {
|
||||
// Empty filter
|
||||
assert.Equal(t, true, utils.CheckFilter("", "anystring"))
|
||||
|
||||
// Exact match
|
||||
assert.Equal(t, true, utils.CheckFilter("hello", "hello"))
|
||||
|
||||
// Regex match
|
||||
assert.Equal(t, true, utils.CheckFilter("/^h.*o$/", "hello"))
|
||||
|
||||
// Invalid regex
|
||||
assert.Equal(t, false, utils.CheckFilter("/[unclosed", "test"))
|
||||
|
||||
// Comma-separated values
|
||||
assert.Equal(t, true, utils.CheckFilter("apple, banana, cherry", "banana"))
|
||||
|
||||
// No match
|
||||
assert.Equal(t, false, utils.CheckFilter("apple, banana, cherry", "grape"))
|
||||
}
|
||||
|
||||
func TestGenerateIdentifier(t *testing.T) {
|
||||
// Consistent output for same input
|
||||
id1 := utils.GenerateIdentifier("teststring")
|
||||
id2 := utils.GenerateIdentifier("teststring")
|
||||
assert.Equal(t, id1, id2)
|
||||
|
||||
// Different output for different input
|
||||
id3 := utils.GenerateIdentifier("differentstring")
|
||||
assert.Assert(t, id1 != id3)
|
||||
|
||||
// Check length (should be 8 characters from first segment of UUID)
|
||||
assert.Equal(t, 8, len(id1))
|
||||
}
|
||||
50
internal/utils/string_utils_test.go
Normal file
50
internal/utils/string_utils_test.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package utils_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"tinyauth/internal/utils"
|
||||
|
||||
"gotest.tools/v3/assert"
|
||||
)
|
||||
|
||||
func TestCapitalize(t *testing.T) {
|
||||
// Test empty string
|
||||
assert.Equal(t, "", utils.Capitalize(""))
|
||||
|
||||
// Test single character
|
||||
assert.Equal(t, "A", utils.Capitalize("a"))
|
||||
|
||||
// Test multiple characters
|
||||
assert.Equal(t, "Hello", utils.Capitalize("hello"))
|
||||
|
||||
// Test already capitalized
|
||||
assert.Equal(t, "World", utils.Capitalize("World"))
|
||||
|
||||
// Test non-alphabetic first character
|
||||
assert.Equal(t, "1number", utils.Capitalize("1number"))
|
||||
|
||||
// Test Unicode characters
|
||||
assert.Equal(t, "Γειά", utils.Capitalize("γειά"))
|
||||
assert.Equal(t, "Привет", utils.Capitalize("привет"))
|
||||
|
||||
}
|
||||
|
||||
func TestCoalesceToString(t *testing.T) {
|
||||
// Test with []any containing strings
|
||||
assert.Equal(t, "a,b,c", utils.CoalesceToString([]any{"a", "b", "c"}))
|
||||
|
||||
// Test with []any containing mixed types
|
||||
assert.Equal(t, "a,c", utils.CoalesceToString([]any{"a", 1, "c", true}))
|
||||
|
||||
// Test with []any containing no strings
|
||||
assert.Equal(t, "", utils.CoalesceToString([]any{1, 2, 3}))
|
||||
|
||||
// Test with string input
|
||||
assert.Equal(t, "hello", utils.CoalesceToString("hello"))
|
||||
|
||||
// Test with non-string, non-[]any input
|
||||
assert.Equal(t, "", utils.CoalesceToString(123))
|
||||
|
||||
// Test with nil input
|
||||
assert.Equal(t, "", utils.CoalesceToString(nil))
|
||||
}
|
||||
163
internal/utils/user_utils_test.go
Normal file
163
internal/utils/user_utils_test.go
Normal file
@@ -0,0 +1,163 @@
|
||||
package utils_test
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
"tinyauth/internal/utils"
|
||||
|
||||
"gotest.tools/v3/assert"
|
||||
)
|
||||
|
||||
func TestGetUsers(t *testing.T) {
|
||||
// Setup
|
||||
file, err := os.Create("/tmp/tinyauth_users_test.txt")
|
||||
assert.NilError(t, err)
|
||||
|
||||
_, err = file.WriteString(" user1:$2a$10$Mz5xhkfSJUtPWkzCd/TdaePh9CaXc5QcGII5wIMPLSR46eTwma30G \n user2:$2a$10$Mz5xhkfSJUtPWkzCd/TdaePh9CaXc5QcGII5wIMPLSR46eTwma30G ") // Spacing is on purpose
|
||||
assert.NilError(t, err)
|
||||
|
||||
err = file.Close()
|
||||
assert.NilError(t, err)
|
||||
defer os.Remove("/tmp/tinyauth_users_test.txt")
|
||||
|
||||
// Test file
|
||||
users, err := utils.GetUsers("", "/tmp/tinyauth_users_test.txt")
|
||||
|
||||
assert.NilError(t, err)
|
||||
|
||||
assert.Equal(t, 2, len(users))
|
||||
|
||||
assert.Equal(t, "user1", users[0].Username)
|
||||
assert.Equal(t, "$2a$10$Mz5xhkfSJUtPWkzCd/TdaePh9CaXc5QcGII5wIMPLSR46eTwma30G", users[0].Password)
|
||||
assert.Equal(t, "user2", users[1].Username)
|
||||
assert.Equal(t, "$2a$10$Mz5xhkfSJUtPWkzCd/TdaePh9CaXc5QcGII5wIMPLSR46eTwma30G", users[1].Password)
|
||||
|
||||
// Test config
|
||||
users, err = utils.GetUsers("user3:$2a$10$Mz5xhkfSJUtPWkzCd/TdaePh9CaXc5QcGII5wIMPLSR46eTwma30G,user4:$2a$10$Mz5xhkfSJUtPWkzCd/TdaePh9CaXc5QcGII5wIMPLSR46eTwma30G", "")
|
||||
|
||||
assert.NilError(t, err)
|
||||
|
||||
assert.Equal(t, 2, len(users))
|
||||
|
||||
assert.Equal(t, "user3", users[0].Username)
|
||||
assert.Equal(t, "$2a$10$Mz5xhkfSJUtPWkzCd/TdaePh9CaXc5QcGII5wIMPLSR46eTwma30G", users[0].Password)
|
||||
assert.Equal(t, "user4", users[1].Username)
|
||||
assert.Equal(t, "$2a$10$Mz5xhkfSJUtPWkzCd/TdaePh9CaXc5QcGII5wIMPLSR46eTwma30G", users[1].Password)
|
||||
|
||||
// Test both
|
||||
users, err = utils.GetUsers("user5:$2a$10$Mz5xhkfSJUtPWkzCd/TdaePh9CaXc5QcGII5wIMPLSR46eTwma30G", "/tmp/tinyauth_users_test.txt")
|
||||
|
||||
assert.NilError(t, err)
|
||||
|
||||
assert.Equal(t, 3, len(users))
|
||||
|
||||
assert.Equal(t, "user5", users[0].Username)
|
||||
assert.Equal(t, "$2a$10$Mz5xhkfSJUtPWkzCd/TdaePh9CaXc5QcGII5wIMPLSR46eTwma30G", users[0].Password)
|
||||
assert.Equal(t, "user1", users[1].Username)
|
||||
assert.Equal(t, "$2a$10$Mz5xhkfSJUtPWkzCd/TdaePh9CaXc5QcGII5wIMPLSR46eTwma30G", users[1].Password)
|
||||
assert.Equal(t, "user2", users[2].Username)
|
||||
assert.Equal(t, "$2a$10$Mz5xhkfSJUtPWkzCd/TdaePh9CaXc5QcGII5wIMPLSR46eTwma30G", users[2].Password)
|
||||
|
||||
// Test empty
|
||||
users, err = utils.GetUsers("", "")
|
||||
|
||||
assert.NilError(t, err)
|
||||
|
||||
assert.Equal(t, 0, len(users))
|
||||
|
||||
// Test non-existent file
|
||||
users, err = utils.GetUsers("", "/tmp/non_existent_file.txt")
|
||||
|
||||
assert.ErrorContains(t, err, "no such file or directory")
|
||||
|
||||
assert.Equal(t, 0, len(users))
|
||||
}
|
||||
|
||||
func TestParseUsers(t *testing.T) {
|
||||
// Valid users
|
||||
users, err := utils.ParseUsers("user1:$2a$10$Mz5xhkfSJUtPWkzCd/TdaePh9CaXc5QcGII5wIMPLSR46eTwma30G,user2:$2a$10$Mz5xhkfSJUtPWkzCd/TdaePh9CaXc5QcGII5wIMPLSR46eTwma30G:ABCDEF") // user2 has TOTP
|
||||
|
||||
assert.NilError(t, err)
|
||||
|
||||
assert.Equal(t, 2, len(users))
|
||||
|
||||
assert.Equal(t, "user1", users[0].Username)
|
||||
assert.Equal(t, "$2a$10$Mz5xhkfSJUtPWkzCd/TdaePh9CaXc5QcGII5wIMPLSR46eTwma30G", users[0].Password)
|
||||
assert.Equal(t, "", users[0].TotpSecret)
|
||||
assert.Equal(t, "user2", users[1].Username)
|
||||
assert.Equal(t, "$2a$10$Mz5xhkfSJUtPWkzCd/TdaePh9CaXc5QcGII5wIMPLSR46eTwma30G", users[1].Password)
|
||||
assert.Equal(t, "ABCDEF", users[1].TotpSecret)
|
||||
|
||||
// Valid weirdly spaced users
|
||||
users, err = utils.ParseUsers(" user1:$2a$10$Mz5xhkfSJUtPWkzCd/TdaePh9CaXc5QcGII5wIMPLSR46eTwma30G , user2:$2a$10$Mz5xhkfSJUtPWkzCd/TdaePh9CaXc5QcGII5wIMPLSR46eTwma30G:ABCDEF ") // Spacing is on purpose
|
||||
assert.NilError(t, err)
|
||||
|
||||
assert.Equal(t, 2, len(users))
|
||||
|
||||
assert.Equal(t, "user1", users[0].Username)
|
||||
assert.Equal(t, "$2a$10$Mz5xhkfSJUtPWkzCd/TdaePh9CaXc5QcGII5wIMPLSR46eTwma30G", users[0].Password)
|
||||
assert.Equal(t, "", users[0].TotpSecret)
|
||||
assert.Equal(t, "user2", users[1].Username)
|
||||
assert.Equal(t, "$2a$10$Mz5xhkfSJUtPWkzCd/TdaePh9CaXc5QcGII5wIMPLSR46eTwma30G", users[1].Password)
|
||||
assert.Equal(t, "ABCDEF", users[1].TotpSecret)
|
||||
}
|
||||
|
||||
func TestParseUser(t *testing.T) {
|
||||
// Valid user without TOTP
|
||||
user, err := utils.ParseUser("user1:$2a$10$Mz5xhkfSJUtPWkzCd/TdaePh9CaXc5QcGII5wIMPLSR46eTwma30G")
|
||||
|
||||
assert.NilError(t, err)
|
||||
|
||||
assert.Equal(t, "user1", user.Username)
|
||||
assert.Equal(t, "$2a$10$Mz5xhkfSJUtPWkzCd/TdaePh9CaXc5QcGII5wIMPLSR46eTwma30G", user.Password)
|
||||
assert.Equal(t, "", user.TotpSecret)
|
||||
|
||||
// Valid user with TOTP
|
||||
user, err = utils.ParseUser("user2:$2a$10$Mz5xhkfSJUtPWkzCd/TdaePh9CaXc5QcGII5wIMPLSR46eTwma30G:ABCDEF")
|
||||
|
||||
assert.NilError(t, err)
|
||||
|
||||
assert.Equal(t, "user2", user.Username)
|
||||
assert.Equal(t, "$2a$10$Mz5xhkfSJUtPWkzCd/TdaePh9CaXc5QcGII5wIMPLSR46eTwma30G", user.Password)
|
||||
assert.Equal(t, "ABCDEF", user.TotpSecret)
|
||||
|
||||
// Valid user with $$ in password
|
||||
user, err = utils.ParseUser("user3:pa$$word123")
|
||||
|
||||
assert.NilError(t, err)
|
||||
|
||||
assert.Equal(t, "user3", user.Username)
|
||||
assert.Equal(t, "pa$word123", user.Password)
|
||||
assert.Equal(t, "", user.TotpSecret)
|
||||
|
||||
// User with spaces
|
||||
user, err = utils.ParseUser(" user4 : password123 : TOTPSECRET ")
|
||||
|
||||
assert.NilError(t, err)
|
||||
|
||||
assert.Equal(t, "user4", user.Username)
|
||||
assert.Equal(t, "password123", user.Password)
|
||||
assert.Equal(t, "TOTPSECRET", user.TotpSecret)
|
||||
|
||||
// Invalid users
|
||||
_, err = utils.ParseUser("user1") // Missing password
|
||||
assert.ErrorContains(t, err, "invalid user format")
|
||||
|
||||
_, err = utils.ParseUser("user1:")
|
||||
assert.ErrorContains(t, err, "invalid user format")
|
||||
|
||||
_, err = utils.ParseUser(":password123")
|
||||
assert.ErrorContains(t, err, "invalid user format")
|
||||
|
||||
_, err = utils.ParseUser("user1:password123:ABC:EXTRA") // Too many parts
|
||||
assert.ErrorContains(t, err, "invalid user format")
|
||||
|
||||
_, err = utils.ParseUser("user1::ABC")
|
||||
assert.ErrorContains(t, err, "invalid user format")
|
||||
|
||||
_, err = utils.ParseUser(":password123:ABC")
|
||||
assert.ErrorContains(t, err, "invalid user format")
|
||||
|
||||
_, err = utils.ParseUser(" : : ")
|
||||
assert.ErrorContains(t, err, "invalid user format")
|
||||
}
|
||||
Reference in New Issue
Block a user