Compare commits

...

10 Commits

Author SHA1 Message Date
Stavros 548d97fa62 tests: fix don't try to test logger with char size 2026-05-09 14:00:58 +03:00
Stavros 3d9c81d7a0 fix: assign public key correctly in oidc server 2026-05-09 13:56:28 +03:00
Stavros 4e760e8397 feat: add option to enable or disable concurrent listeners 2026-05-09 13:52:49 +03:00
Stavros 02b48aa165 fix: fix typos 2026-05-09 13:42:44 +03:00
Stavros 886f9a84d6 tests: fix context tests 2026-05-09 13:38:04 +03:00
Stavros 74aca0f521 tests: fix service tests 2026-05-09 13:34:34 +03:00
Stavros a76141a99d tests: fix middleware tests 2026-05-09 13:32:08 +03:00
Stavros c7e9fade03 tests: use require instead of assert where previous step is required 2026-05-09 13:28:22 +03:00
Stavros 9fccb63097 tests: fix controller tests 2026-05-09 13:17:35 +03:00
Stavros 8c8d56f87c refactor: simplify middleware, controller and service init 2026-05-09 12:24:10 +03:00
41 changed files with 758 additions and 790 deletions
+36 -15
View File
@@ -214,7 +214,7 @@ func (app *BootstrapApp) Setup() error {
return errors.New("no authentication providers configured")
}
for _, provider := range app.runtime.ConfiguredProviders {
for _, provider := range configuredProviders {
app.log.App.Debug().Str("provider", provider.Name).Msg("Configured authentication provider")
}
@@ -238,21 +238,42 @@ func (app *BootstrapApp) Setup() error {
}
// create err channel to listen for server errors
errChan := make(chan error, 1)
errChanLen := 0
runUnix := app.config.Server.SocketPath != ""
runHTTP := app.config.Server.SocketPath == "" || app.config.Server.ConcurrentListenersEnabled
if runUnix {
errChanLen++
}
if runHTTP {
errChanLen++
}
errChan := make(chan error, errChanLen)
if app.config.Server.ConcurrentListenersEnabled {
app.log.App.Info().Msg("Concurrent listeners enabled, will run on all available listeners")
}
// serve unix
app.wg.Go(func() {
if err := app.serveUnix(); err != nil {
errChan <- err
}
})
if runUnix {
app.wg.Go(func() {
if err := app.serveUnix(); err != nil {
errChan <- err
}
})
}
// serve to http
app.wg.Go(func() {
if err := app.serveHTTP(); err != nil {
errChan <- err
}
})
if runHTTP {
app.wg.Go(func() {
if err := app.serveHTTP(); err != nil {
errChan <- err
}
})
}
// monitor cancellation and server errors
for {
@@ -317,7 +338,7 @@ func (app *BootstrapApp) serveUnix() error {
listener, err := net.Listen("unix", app.config.Server.SocketPath)
if err != nil {
return fmt.Errorf("failed to create unix socket listner: %w", err)
return fmt.Errorf("failed to create unix socket listener: %w", err)
}
server := &http.Server{
@@ -330,7 +351,7 @@ func (app *BootstrapApp) serveUnix() error {
go func() {
<-app.ctx.Done()
app.log.App.Debug().Msg("Shutting down unix sokcet listener")
app.log.App.Debug().Msg("Shutting down unix socket listener")
server.Close()
listener.Close()
os.Remove(app.config.Server.SocketPath)
@@ -338,7 +359,7 @@ func (app *BootstrapApp) serveUnix() error {
err = server.Serve(listener)
if err != nil && (!errors.Is(err, net.ErrClosed) || !errors.Is(err, http.ErrServerClosed)) {
if err != nil && !errors.Is(err, net.ErrClosed) && !errors.Is(err, http.ErrServerClosed) {
return fmt.Errorf("failed to start unix socket listener: %w", err)
}
+4
View File
@@ -56,3 +56,7 @@ func (app *BootstrapApp) SetupDatabase() error {
app.db = db
return nil
}
func (app *BootstrapApp) GetDB() *sql.DB {
return app.db
}
+9 -47
View File
@@ -25,18 +25,9 @@ func (app *BootstrapApp) setupRouter() error {
}
contextMiddleware := middleware.NewContextMiddleware(app.log, app.runtime, app.services.authService, app.services.oauthBrokerService)
err := contextMiddleware.Init()
if err != nil {
return fmt.Errorf("failed to initialize context middleware: %w", err)
}
engine.Use(contextMiddleware.Middleware())
uiMiddleware := middleware.NewUIMiddleware()
err = uiMiddleware.Init()
uiMiddleware, err := middleware.NewUIMiddleware()
if err != nil {
return fmt.Errorf("failed to initialize UI middleware: %w", err)
@@ -46,47 +37,18 @@ func (app *BootstrapApp) setupRouter() error {
zerologMiddleware := middleware.NewZerologMiddleware(app.log)
err = zerologMiddleware.Init()
if err != nil {
return fmt.Errorf("failed to initialize zerolog middleware: %w", err)
}
engine.Use(zerologMiddleware.Middleware())
apiRouter := engine.Group("/api")
contextController := controller.NewContextController(app.log, app.config, app.runtime, apiRouter)
contextController.SetupRoutes()
oauthController := controller.NewOAuthController(app.log, app.config, app.runtime, apiRouter, app.services.authService)
oauthController.SetupRoutes()
oidcController := controller.NewOIDCController(app.log, app.services.oidcService, apiRouter)
oidcController.SetupRoutes()
proxyController := controller.NewProxyController(app.log, app.runtime, apiRouter, app.services.accessControlService, app.services.authService)
proxyController.SetupRoutes()
userController := controller.NewUserController(app.log, app.runtime, apiRouter, app.services.authService)
userController.SetupRoutes()
resourcesController := controller.NewResourcesController(app.config, &engine.RouterGroup)
resourcesController.SetupRoutes()
healthController := controller.NewHealthController(apiRouter)
healthController.SetupRoutes()
wellknownController := controller.NewWellKnownController(app.services.oidcService, &engine.RouterGroup)
wellknownController.SetupRoutes()
controller.NewContextController(app.log, app.config, app.runtime, apiRouter)
controller.NewOAuthController(app.log, app.config, app.runtime, apiRouter, app.services.authService)
controller.NewOIDCController(app.log, app.services.oidcService, apiRouter)
controller.NewProxyController(app.log, app.runtime, apiRouter, app.services.accessControlService, app.services.authService)
controller.NewUserController(app.log, app.runtime, apiRouter, app.services.authService)
controller.NewResourcesController(app.config, &engine.RouterGroup)
controller.NewHealthController(apiRouter)
controller.NewWellKnownController(app.services.oidcService, &engine.RouterGroup)
app.router = engine
return nil
+7 -37
View File
@@ -8,13 +8,10 @@ import (
)
func (app *BootstrapApp) setupServices() error {
ldapService := service.NewLdapService(app.log, app.config, app.ctx, &app.wg)
err := ldapService.Init()
ldapService, err := service.NewLdapService(app.log, app.config, app.ctx, &app.wg)
if err != nil {
app.log.App.Warn().Err(err).Msg("Failed to initialize LDAP connection, will continue without it")
ldapService.Unconfigure()
}
app.services.ldapService = ldapService
@@ -22,14 +19,12 @@ func (app *BootstrapApp) setupServices() error {
useKubernetes := app.config.LabelProvider == "kubernetes" ||
(app.config.LabelProvider == "auto" && os.Getenv("KUBERNETES_SERVICE_HOST") != "")
var labelProvider service.LabelProviderImpl
var labelProvider service.LabelProvider
if useKubernetes {
app.log.App.Debug().Msg("Using Kubernetes label provider")
kubernetesService := service.NewKubernetesService(app.log, app.ctx, &app.wg)
err = kubernetesService.Init()
kubernetesService, err := service.NewKubernetesService(app.log, app.ctx, &app.wg)
if err != nil {
return fmt.Errorf("failed to initialize kubernetes service: %w", err)
@@ -40,9 +35,7 @@ func (app *BootstrapApp) setupServices() error {
} else {
app.log.App.Debug().Msg("Using Docker label provider")
dockerService := service.NewDockerService(app.log, app.ctx, &app.wg)
err = dockerService.Init()
dockerService, err := service.NewDockerService(app.log, app.ctx, &app.wg)
if err != nil {
return fmt.Errorf("failed to initialize docker service: %w", err)
@@ -52,39 +45,16 @@ func (app *BootstrapApp) setupServices() error {
labelProvider = dockerService
}
accessControlsService := service.NewAccessControlsService(app.log, labelProvider, app.config.Apps)
err = accessControlsService.Init()
if err != nil {
return fmt.Errorf("failed to initialize access controls service: %w", err)
}
accessControlsService := service.NewAccessControlsService(app.log, &labelProvider, app.config.Apps)
app.services.accessControlService = accessControlsService
oauthBrokerService := service.NewOAuthBrokerService(app.log, app.runtime.OAuthProviders)
err = oauthBrokerService.Init()
if err != nil {
return fmt.Errorf("failed to initialize oauth broker service: %w", err)
}
oauthBrokerService := service.NewOAuthBrokerService(app.log, app.runtime.OAuthProviders, app.ctx)
app.services.oauthBrokerService = oauthBrokerService
authService := service.NewAuthService(app.log, app.config, app.runtime, app.ctx, &app.wg, app.services.ldapService, app.queries, app.services.oauthBrokerService)
err = authService.Init()
if err != nil {
return fmt.Errorf("failed to initialize auth service: %w", err)
}
app.services.authService = authService
oidcService := service.NewOIDCService(app.log, app.config, app.runtime, app.queries, app.ctx, &app.wg)
err = oidcService.Init()
oidcService, err := service.NewOIDCService(app.log, app.config, app.runtime, app.queries, app.ctx, &app.wg)
if err != nil {
return fmt.Errorf("failed to initialize oidc service: %w", err)
+10 -12
View File
@@ -40,7 +40,6 @@ type ContextController struct {
log *logger.Logger
config model.Config
runtime model.RuntimeConfig
router *gin.RouterGroup
}
func NewContextController(
@@ -49,22 +48,21 @@ func NewContextController(
runtimeConfig model.RuntimeConfig,
router *gin.RouterGroup,
) *ContextController {
controller := &ContextController{
log: log,
config: config,
runtime: runtimeConfig,
}
if !config.UI.WarningsEnabled {
log.App.Warn().Msg("UI warnings are disabled. This may lead to security issues if you are not careful. Make sure to enable warnings in production environments.")
}
return &ContextController{
log: log,
config: config,
runtime: runtimeConfig,
router: router,
}
}
func (controller *ContextController) SetupRoutes() {
contextGroup := controller.router.Group("/context")
contextGroup := router.Group("/context")
contextGroup.GET("/user", controller.userContextHandler)
contextGroup.GET("/app", controller.appContextHandler)
return controller
}
func (controller *ContextController) userContextHandler(c *gin.Context) {
@@ -97,7 +95,7 @@ func (controller *ContextController) userContextHandler(c *gin.Context) {
}
func (controller *ContextController) appContextHandler(c *gin.Context) {
appUrl, err := url.Parse(controller.config.AppURL)
appUrl, err := url.Parse(controller.runtime.AppURL)
if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to parse app URL")
+21 -34
View File
@@ -8,30 +8,18 @@ import (
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
)
func TestContextController(t *testing.T) {
tlog.NewTestLogger().Init()
controllerConfig := controller.ContextControllerConfig{
Providers: []controller.Provider{
{
Name: "Local",
ID: "local",
OAuth: false,
},
},
Title: "Tinyauth",
AppURL: "https://tinyauth.example.com",
CookieDomain: "example.com",
ForgotPasswordMessage: "foo",
BackgroundImage: "/background.jpg",
OAuthAutoRedirect: "none",
WarningsEnabled: true,
}
log := logger.NewLogger().WithTestConfig()
log.Init()
cfg, runtime := createTestConfigs(t)
tests := []struct {
description string
@@ -47,17 +35,17 @@ func TestContextController(t *testing.T) {
expectedAppContextResponse := controller.AppContextResponse{
Status: 200,
Message: "Success",
Providers: controllerConfig.Providers,
Title: controllerConfig.Title,
AppURL: controllerConfig.AppURL,
CookieDomain: controllerConfig.CookieDomain,
ForgotPasswordMessage: controllerConfig.ForgotPasswordMessage,
BackgroundImage: controllerConfig.BackgroundImage,
OAuthAutoRedirect: controllerConfig.OAuthAutoRedirect,
WarningsEnabled: controllerConfig.WarningsEnabled,
Providers: runtime.ConfiguredProviders,
Title: cfg.UI.Title,
AppURL: runtime.AppURL,
CookieDomain: runtime.CookieDomain,
ForgotPasswordMessage: cfg.UI.ForgotPasswordMessage,
BackgroundImage: cfg.UI.BackgroundImage,
OAuthAutoRedirect: cfg.OAuth.AutoRedirect,
WarningsEnabled: cfg.UI.WarningsEnabled,
}
bytes, err := json.Marshal(expectedAppContextResponse)
assert.NoError(t, err)
require.NoError(t, err)
return string(bytes)
}(),
},
@@ -71,7 +59,7 @@ func TestContextController(t *testing.T) {
Message: "Unauthorized",
}
bytes, err := json.Marshal(expectedUserContextResponse)
assert.NoError(t, err)
require.NoError(t, err)
return string(bytes)
}(),
},
@@ -86,7 +74,7 @@ func TestContextController(t *testing.T) {
BaseContext: model.BaseContext{
Username: "johndoe",
Name: "John Doe",
Email: utils.CompileUserEmail("johndoe", controllerConfig.CookieDomain),
Email: utils.CompileUserEmail("johndoe", runtime.CookieDomain),
},
},
})
@@ -100,11 +88,11 @@ func TestContextController(t *testing.T) {
IsLoggedIn: true,
Username: "johndoe",
Name: "John Doe",
Email: utils.CompileUserEmail("johndoe", controllerConfig.CookieDomain),
Email: utils.CompileUserEmail("johndoe", runtime.CookieDomain),
Provider: "local",
}
bytes, err := json.Marshal(expectedUserContextResponse)
assert.NoError(t, err)
require.NoError(t, err)
return string(bytes)
}(),
},
@@ -121,13 +109,12 @@ func TestContextController(t *testing.T) {
group := router.Group("/api")
gin.SetMode(gin.TestMode)
contextController := controller.NewContextController(controllerConfig, group)
contextController.SetupRoutes()
controller.NewContextController(log, cfg, runtime, group)
recorder := httptest.NewRecorder()
request, err := http.NewRequest("GET", test.path, nil)
assert.NoError(t, err)
require.NoError(t, err)
router.ServeHTTP(recorder, request)
+106
View File
@@ -0,0 +1,106 @@
package controller_test
import (
"path"
"testing"
"github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/model"
"golang.org/x/crypto/bcrypt"
)
var testingTOTPSecret = "JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK"
func createTestConfigs(t *testing.T) (model.Config, model.RuntimeConfig) {
tempDir := t.TempDir()
config := model.Config{
UI: model.UIConfig{
Title: "Tinyauth Test",
ForgotPasswordMessage: "foo",
BackgroundImage: "/background.jpg",
WarningsEnabled: true,
},
OAuth: model.OAuthConfig{
AutoRedirect: "none",
},
OIDC: model.OIDCConfig{
Clients: map[string]model.OIDCClientConfig{
"test": {
ClientID: "some-client-id",
ClientSecret: "some-client-secret",
TrustedRedirectURIs: []string{"https://test.example.com/callback"},
Name: "Test Client",
},
},
PrivateKeyPath: path.Join(tempDir, "key.pem"),
PublicKeyPath: path.Join(tempDir, "key.pub"),
},
Auth: model.AuthConfig{
SessionExpiry: 10,
LoginTimeout: 10,
LoginMaxRetries: 3,
},
Database: model.DatabaseConfig{
Path: path.Join(tempDir, "test.db"),
},
Resources: model.ResourcesConfig{
Enabled: true,
Path: path.Join(tempDir, "resources"),
},
}
passwd, err := bcrypt.GenerateFromPassword([]byte("password"), bcrypt.DefaultCost)
require.NoError(t, err)
runtime := model.RuntimeConfig{
ConfiguredProviders: []model.Provider{
{
Name: "Local",
ID: "local",
OAuth: false,
},
},
LocalUsers: []model.LocalUser{
{
Username: "testuser",
Password: string(passwd),
},
{
Username: "totpuser",
Password: string(passwd),
TOTPSecret: testingTOTPSecret,
},
{
Username: "attruser",
Password: string(passwd),
Attributes: model.UserAttributes{
Name: "Alice Smith",
Email: "alice@example.com",
},
},
{
Username: "attrtotpuser",
Password: string(passwd),
TOTPSecret: testingTOTPSecret,
Attributes: model.UserAttributes{
Name: "Bob Jones",
Email: "bob@example.com",
},
},
},
CookieDomain: "example.com",
AppURL: "https://tinyauth.example.com",
SessionCookieName: "tinyauth-session",
OIDCClients: func() []model.OIDCClientConfig {
var clients []model.OIDCClientConfig
for id, client := range config.OIDC.Clients {
client.ID = id
clients = append(clients, client)
}
return clients
}(),
}
return config, runtime
}
+5 -8
View File
@@ -3,18 +3,15 @@ package controller
import "github.com/gin-gonic/gin"
type HealthController struct {
router *gin.RouterGroup
}
func NewHealthController(router *gin.RouterGroup) *HealthController {
return &HealthController{
router: router,
}
}
controller := &HealthController{}
func (controller *HealthController) SetupRoutes() {
controller.router.GET("/healthz", controller.healthHandler)
controller.router.HEAD("/healthz", controller.healthHandler)
router.GET("/healthz", controller.healthHandler)
router.HEAD("/healthz", controller.healthHandler)
return controller
}
func (controller *HealthController) healthHandler(c *gin.Context) {
@@ -7,13 +7,12 @@ import (
"testing"
"github.com/gin-gonic/gin"
"github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/controller"
)
func TestHealthController(t *testing.T) {
tlog.NewTestLogger().Init()
tests := []struct {
description string
path string
@@ -30,7 +29,7 @@ func TestHealthController(t *testing.T) {
"message": "Healthy",
}
bytes, err := json.Marshal(expectedHealthResponse)
assert.NoError(t, err)
require.NoError(t, err)
return string(bytes)
}(),
},
@@ -44,7 +43,7 @@ func TestHealthController(t *testing.T) {
"message": "Healthy",
}
bytes, err := json.Marshal(expectedHealthResponse)
assert.NoError(t, err)
require.NoError(t, err)
return string(bytes)
}(),
},
@@ -56,13 +55,12 @@ func TestHealthController(t *testing.T) {
group := router.Group("/api")
gin.SetMode(gin.TestMode)
healthController := controller.NewHealthController(group)
healthController.SetupRoutes()
controller.NewHealthController(group)
recorder := httptest.NewRecorder()
request, err := http.NewRequest(test.method, test.path, nil)
assert.NoError(t, err)
require.NoError(t, err)
router.ServeHTTP(recorder, request)
+4 -6
View File
@@ -24,7 +24,6 @@ type OAuthController struct {
log *logger.Logger
config model.Config
runtime model.RuntimeConfig
router *gin.RouterGroup
auth *service.AuthService
}
@@ -35,19 +34,18 @@ func NewOAuthController(
router *gin.RouterGroup,
auth *service.AuthService,
) *OAuthController {
return &OAuthController{
controller := &OAuthController{
log: log,
config: config,
runtime: runtimeConfig,
router: router,
auth: auth,
}
}
func (controller *OAuthController) SetupRoutes() {
oauthGroup := controller.router.Group("/oauth")
oauthGroup := router.Group("/oauth")
oauthGroup.GET("/url/:provider", controller.oauthURLHandler)
oauthGroup.GET("/callback/:provider", controller.oauthCallbackHandler)
return controller
}
func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
+11 -13
View File
@@ -17,9 +17,8 @@ import (
)
type OIDCController struct {
log *logger.Logger
router *gin.RouterGroup
oidc *service.OIDCService
log *logger.Logger
oidc *service.OIDCService
}
type AuthorizeCallback struct {
@@ -60,20 +59,19 @@ func NewOIDCController(
log *logger.Logger,
oidcService *service.OIDCService,
router *gin.RouterGroup) *OIDCController {
return &OIDCController{
log: log,
oidc: oidcService,
router: router,
controller := &OIDCController{
log: log,
oidc: oidcService,
}
}
func (controller *OIDCController) SetupRoutes() {
oidcGroup := controller.router.Group("/oidc")
oidcGroup := router.Group("/oidc")
oidcGroup.GET("/clients/:id", controller.GetClientInfo)
oidcGroup.POST("/authorize", controller.Authorize)
oidcGroup.POST("/token", controller.Token)
oidcGroup.GET("/userinfo", controller.Userinfo)
oidcGroup.POST("/userinfo", controller.Userinfo)
return controller
}
func (controller *OIDCController) GetClientInfo(c *gin.Context) {
@@ -108,7 +106,7 @@ func (controller *OIDCController) GetClientInfo(c *gin.Context) {
}
func (controller *OIDCController) Authorize(c *gin.Context) {
if !controller.oidc.IsConfigured() {
if controller.oidc == nil {
controller.authorizeError(c, errors.New("err_oidc_not_configured"), "OIDC not configured", "This instance is not configured for OIDC", "", "", "")
return
}
@@ -198,7 +196,7 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
}
func (controller *OIDCController) Token(c *gin.Context) {
if !controller.oidc.IsConfigured() {
if controller.oidc == nil {
controller.log.App.Warn().Msg("Received OIDC request but OIDC server is not configured")
c.JSON(404, gin.H{
"error": "not_found",
@@ -374,7 +372,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
}
func (controller *OIDCController) Userinfo(c *gin.Context) {
if !controller.oidc.IsConfigured() {
if controller.oidc == nil {
controller.log.App.Warn().Msg("Received OIDC userinfo request but OIDC server is not configured")
c.JSON(404, gin.H{
"error": "not_found",
+65 -79
View File
@@ -1,13 +1,14 @@
package controller_test
import (
"context"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"net/http/httptest"
"net/url"
"path"
"strings"
"sync"
"testing"
"github.com/gin-gonic/gin"
@@ -19,29 +20,14 @@ import (
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
)
func TestOIDCController(t *testing.T) {
tlog.NewTestLogger().Init()
tempDir := t.TempDir()
log := logger.NewLogger().WithTestConfig()
log.Init()
oidcServiceCfg := service.OIDCServiceConfig{
Clients: map[string]model.OIDCClientConfig{
"test": {
ClientID: "some-client-id",
ClientSecret: "some-client-secret",
TrustedRedirectURIs: []string{"https://test.example.com/callback"},
Name: "Test Client",
},
},
PrivateKeyPath: path.Join(tempDir, "key.pem"),
PublicKeyPath: path.Join(tempDir, "key.pub"),
Issuer: "https://tinyauth.example.com",
SessionExpiry: 500,
}
controllerCfg := controller.OIDCControllerConfig{}
cfg, runtime := createTestConfigs(t)
simpleCtx := func(c *gin.Context) {
c.Set("context", &model.UserContext{
@@ -103,7 +89,7 @@ func TestOIDCController(t *testing.T) {
var res map[string]any
err := json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err)
require.NoError(t, err)
assert.Equal(t, res["redirect_uri"], "https://tinyauth.example.com/error?error=User+is+not+logged+in+or+the+session+is+invalid")
},
@@ -123,7 +109,7 @@ func TestOIDCController(t *testing.T) {
Nonce: "some-nonce",
}
reqBodyBytes, err := json.Marshal(reqBody)
assert.NoError(t, err)
require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes)))
req.Header.Set("Content-Type", "application/json")
@@ -131,7 +117,7 @@ func TestOIDCController(t *testing.T) {
var res map[string]any
err = json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err)
require.NoError(t, err)
assert.Equal(t, res["redirect_uri"], "https://test.example.com/callback?error=unsupported_response_type&error_description=Invalid+request+parameters&state=some-state")
},
@@ -151,7 +137,7 @@ func TestOIDCController(t *testing.T) {
Nonce: "some-nonce",
}
reqBodyBytes, err := json.Marshal(reqBody)
assert.NoError(t, err)
require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes)))
req.Header.Set("Content-Type", "application/json")
@@ -160,11 +146,11 @@ func TestOIDCController(t *testing.T) {
var res map[string]any
err = json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err)
require.NoError(t, err)
redirectURI := res["redirect_uri"].(string)
url, err := url.Parse(redirectURI)
assert.NoError(t, err)
require.NoError(t, err)
queryParams := url.Query()
assert.Equal(t, queryParams.Get("state"), "some-state")
@@ -183,7 +169,7 @@ func TestOIDCController(t *testing.T) {
RedirectURI: "https://test.example.com/callback",
}
reqBodyEncoded, err := query.Values(reqBody)
assert.NoError(t, err)
require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
@@ -191,7 +177,7 @@ func TestOIDCController(t *testing.T) {
var res map[string]any
err = json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err)
require.NoError(t, err)
assert.Equal(t, res["error"], "unsupported_grant_type")
},
@@ -206,7 +192,7 @@ func TestOIDCController(t *testing.T) {
RedirectURI: "https://test.example.com/callback",
}
reqBodyEncoded, err := query.Values(reqBody)
assert.NoError(t, err)
require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
@@ -244,7 +230,7 @@ func TestOIDCController(t *testing.T) {
RedirectURI: "https://test.example.com/callback",
}
reqBodyEncoded, err := query.Values(reqBody)
assert.NoError(t, err)
require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
@@ -267,11 +253,11 @@ func TestOIDCController(t *testing.T) {
var authorizeRes map[string]any
err := json.Unmarshal(authorizeTestRecorder.Body.Bytes(), &authorizeRes)
assert.NoError(t, err)
require.NoError(t, err)
redirectURI := authorizeRes["redirect_uri"].(string)
url, err := url.Parse(redirectURI)
assert.NoError(t, err)
require.NoError(t, err)
queryParams := url.Query()
code := queryParams.Get("code")
@@ -283,7 +269,7 @@ func TestOIDCController(t *testing.T) {
RedirectURI: "https://test.example.com/callback",
}
reqBodyEncoded, err := query.Values(reqBody)
assert.NoError(t, err)
require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
@@ -306,7 +292,7 @@ func TestOIDCController(t *testing.T) {
var tokenRes map[string]any
err := json.Unmarshal(tokenRecorder.Body.Bytes(), &tokenRes)
assert.NoError(t, err)
require.NoError(t, err)
_, ok := tokenRes["refresh_token"]
assert.True(t, ok, "Expected refresh token in response")
@@ -320,7 +306,7 @@ func TestOIDCController(t *testing.T) {
ClientSecret: "some-client-secret",
}
reqBodyEncoded, err := query.Values(reqBody)
assert.NoError(t, err)
require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
@@ -332,7 +318,7 @@ func TestOIDCController(t *testing.T) {
assert.Equal(t, 200, recorder.Code)
var refreshRes map[string]any
err = json.Unmarshal(recorder.Body.Bytes(), &refreshRes)
assert.NoError(t, err)
require.NoError(t, err)
_, ok = refreshRes["access_token"]
assert.True(t, ok, "Expected access token in refresh response")
@@ -353,11 +339,11 @@ func TestOIDCController(t *testing.T) {
var authorizeRes map[string]any
err := json.Unmarshal(authorizeTestRecorder.Body.Bytes(), &authorizeRes)
assert.NoError(t, err)
require.NoError(t, err)
redirectURI := authorizeRes["redirect_uri"].(string)
url, err := url.Parse(redirectURI)
assert.NoError(t, err)
require.NoError(t, err)
queryParams := url.Query()
code := queryParams.Get("code")
@@ -369,7 +355,7 @@ func TestOIDCController(t *testing.T) {
RedirectURI: "https://test.example.com/callback",
}
reqBodyEncoded, err := query.Values(reqBody)
assert.NoError(t, err)
require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
@@ -389,7 +375,7 @@ func TestOIDCController(t *testing.T) {
var secondRes map[string]any
err = json.Unmarshal(secondRecorder.Body.Bytes(), &secondRes)
assert.NoError(t, err)
require.NoError(t, err)
assert.Equal(t, "invalid_grant", secondRes["error"])
},
@@ -417,7 +403,7 @@ func TestOIDCController(t *testing.T) {
var tokenRes map[string]any
err := json.Unmarshal(tokenRecorder.Body.Bytes(), &tokenRes)
assert.NoError(t, err)
require.NoError(t, err)
accessToken := tokenRes["access_token"].(string)
assert.NotEmpty(t, accessToken)
@@ -429,7 +415,7 @@ func TestOIDCController(t *testing.T) {
var userInfoRes map[string]any
err = json.Unmarshal(recorder.Body.Bytes(), &userInfoRes)
assert.NoError(t, err)
require.NoError(t, err)
_, ok := userInfoRes["sub"]
assert.True(t, ok, "Expected sub claim in userinfo response")
@@ -449,7 +435,7 @@ func TestOIDCController(t *testing.T) {
var res map[string]any
err := json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err)
require.NoError(t, err)
assert.Equal(t, "invalid_request", res["error"])
},
},
@@ -464,7 +450,7 @@ func TestOIDCController(t *testing.T) {
var res map[string]any
err := json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err)
require.NoError(t, err)
assert.Equal(t, "invalid_request", res["error"])
},
},
@@ -479,7 +465,7 @@ func TestOIDCController(t *testing.T) {
var res map[string]any
err := json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err)
require.NoError(t, err)
assert.Equal(t, "invalid_request", res["error"])
},
},
@@ -494,7 +480,7 @@ func TestOIDCController(t *testing.T) {
var res map[string]any
err := json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err)
require.NoError(t, err)
assert.Equal(t, "invalid_grant", res["error"])
},
},
@@ -509,7 +495,7 @@ func TestOIDCController(t *testing.T) {
var res map[string]any
err := json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err)
require.NoError(t, err)
assert.Equal(t, "invalid_request", res["error"])
},
},
@@ -524,7 +510,7 @@ func TestOIDCController(t *testing.T) {
var res map[string]any
err := json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err)
require.NoError(t, err)
assert.Equal(t, "invalid_request", res["error"])
},
},
@@ -541,7 +527,7 @@ func TestOIDCController(t *testing.T) {
var tokenRes map[string]any
err := json.Unmarshal(tokenRecorder.Body.Bytes(), &tokenRes)
assert.NoError(t, err)
require.NoError(t, err)
accessToken := tokenRes["access_token"].(string)
assert.NotEmpty(t, accessToken)
@@ -555,7 +541,7 @@ func TestOIDCController(t *testing.T) {
var userInfoRes map[string]any
err = json.Unmarshal(recorder.Body.Bytes(), &userInfoRes)
assert.NoError(t, err)
require.NoError(t, err)
_, ok := userInfoRes["sub"]
assert.True(t, ok, "Expected sub claim in userinfo response")
@@ -579,7 +565,7 @@ func TestOIDCController(t *testing.T) {
CodeChallengeMethod: "",
}
reqBodyBytes, err := json.Marshal(reqBody)
assert.NoError(t, err)
require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes)))
req.Header.Set("Content-Type", "application/json")
@@ -588,11 +574,11 @@ func TestOIDCController(t *testing.T) {
var res map[string]any
err = json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err)
require.NoError(t, err)
redirectURI := res["redirect_uri"].(string)
url, err := url.Parse(redirectURI)
assert.NoError(t, err)
require.NoError(t, err)
queryParams := url.Query()
assert.Equal(t, queryParams.Get("state"), "some-state")
@@ -609,7 +595,7 @@ func TestOIDCController(t *testing.T) {
CodeVerifier: "some-challenge",
}
reqBodyEncoded, err := query.Values(tokenReqBody)
assert.NoError(t, err)
require.NoError(t, err)
req = httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
@@ -640,7 +626,7 @@ func TestOIDCController(t *testing.T) {
CodeChallengeMethod: "S256",
}
reqBodyBytes, err := json.Marshal(reqBody)
assert.NoError(t, err)
require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes)))
req.Header.Set("Content-Type", "application/json")
@@ -649,11 +635,11 @@ func TestOIDCController(t *testing.T) {
var res map[string]any
err = json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err)
require.NoError(t, err)
redirectURI := res["redirect_uri"].(string)
url, err := url.Parse(redirectURI)
assert.NoError(t, err)
require.NoError(t, err)
queryParams := url.Query()
assert.Equal(t, queryParams.Get("state"), "some-state")
@@ -670,7 +656,7 @@ func TestOIDCController(t *testing.T) {
CodeVerifier: "some-challenge",
}
reqBodyEncoded, err := query.Values(tokenReqBody)
assert.NoError(t, err)
require.NoError(t, err)
req = httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
@@ -701,7 +687,7 @@ func TestOIDCController(t *testing.T) {
CodeChallengeMethod: "S256",
}
reqBodyBytes, err := json.Marshal(reqBody)
assert.NoError(t, err)
require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes)))
req.Header.Set("Content-Type", "application/json")
@@ -710,11 +696,11 @@ func TestOIDCController(t *testing.T) {
var res map[string]any
err = json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err)
require.NoError(t, err)
redirectURI := res["redirect_uri"].(string)
url, err := url.Parse(redirectURI)
assert.NoError(t, err)
require.NoError(t, err)
queryParams := url.Query()
assert.Equal(t, queryParams.Get("state"), "some-state")
@@ -731,7 +717,7 @@ func TestOIDCController(t *testing.T) {
CodeVerifier: "some-challenge-1",
}
reqBodyEncoded, err := query.Values(tokenReqBody)
assert.NoError(t, err)
require.NoError(t, err)
req = httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
@@ -762,7 +748,7 @@ func TestOIDCController(t *testing.T) {
CodeChallengeMethod: "foo",
}
reqBodyBytes, err := json.Marshal(reqBody)
assert.NoError(t, err)
require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes)))
req.Header.Set("Content-Type", "application/json")
@@ -771,11 +757,11 @@ func TestOIDCController(t *testing.T) {
var res map[string]any
err = json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err)
require.NoError(t, err)
redirectURI := res["redirect_uri"].(string)
url, err := url.Parse(redirectURI)
assert.NoError(t, err)
require.NoError(t, err)
queryParams := url.Query()
error := queryParams.Get("error")
@@ -794,11 +780,11 @@ func TestOIDCController(t *testing.T) {
var res map[string]any
err := json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err)
require.NoError(t, err)
redirectURI := res["redirect_uri"].(string)
url, err := url.Parse(redirectURI)
assert.NoError(t, err)
require.NoError(t, err)
queryParams := url.Query()
code := queryParams.Get("code")
@@ -810,7 +796,7 @@ func TestOIDCController(t *testing.T) {
RedirectURI: "https://test.example.com/callback",
}
reqBodyEncoded, err := query.Values(reqBody)
assert.NoError(t, err)
require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
@@ -821,7 +807,7 @@ func TestOIDCController(t *testing.T) {
assert.Equal(t, 200, recorder.Code)
err = json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err)
require.NoError(t, err)
accessToken := res["access_token"].(string)
assert.NotEmpty(t, accessToken)
@@ -846,20 +832,22 @@ func TestOIDCController(t *testing.T) {
assert.Equal(t, 401, recorder.Code)
err = json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err)
require.NoError(t, err)
assert.Equal(t, "invalid_grant", res["error"])
},
},
}
app := bootstrap.NewBootstrapApp(model.Config{})
app := bootstrap.NewBootstrapApp(cfg)
db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db"))
err := app.SetupDatabase()
require.NoError(t, err)
queries := repository.New(db)
oidcService := service.NewOIDCService(oidcServiceCfg, queries)
err = oidcService.Init()
queries := repository.New(app.GetDB())
wg := &sync.WaitGroup{}
oidcService, err := service.NewOIDCService(log, cfg, runtime, queries, context.TODO(), wg)
require.NoError(t, err)
for _, test := range tests {
@@ -873,8 +861,7 @@ func TestOIDCController(t *testing.T) {
group := router.Group("/api")
gin.SetMode(gin.TestMode)
oidcController := controller.NewOIDCController(controllerCfg, oidcService, group)
oidcController.SetupRoutes()
controller.NewOIDCController(log, oidcService, group)
recorder := httptest.NewRecorder()
@@ -883,7 +870,6 @@ func TestOIDCController(t *testing.T) {
}
t.Cleanup(func() {
err = db.Close()
require.NoError(t, err)
app.GetDB().Close()
})
}
+5 -7
View File
@@ -53,7 +53,6 @@ type ProxyContext struct {
type ProxyController struct {
log *logger.Logger
runtime model.RuntimeConfig
router *gin.RouterGroup
acls *service.AccessControlsService
auth *service.AuthService
}
@@ -65,18 +64,17 @@ func NewProxyController(
acls *service.AccessControlsService,
auth *service.AuthService,
) *ProxyController {
return &ProxyController{
controller := &ProxyController{
log: log,
runtime: runtime,
router: router,
acls: acls,
auth: auth,
}
}
func (controller *ProxyController) SetupRoutes() {
proxyGroup := controller.router.Group("/auth")
proxyGroup := router.Group("/auth")
proxyGroup.Any("/:proxy", controller.proxyHandler)
return controller
}
func (controller *ProxyController) proxyHandler(c *gin.Context) {
@@ -160,7 +158,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
userContext, err := new(model.UserContext).NewFromGin(c)
if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to create user context from request, treating as unauthenticated")
controller.log.App.Debug().Err(err).Msg("Failed to create user context from request, treating as unauthenticated")
userContext = &model.UserContext{
Authenticated: false,
}
+16 -51
View File
@@ -1,8 +1,9 @@
package controller_test
import (
"context"
"net/http/httptest"
"path"
"sync"
"testing"
"github.com/gin-gonic/gin"
@@ -13,35 +14,14 @@ import (
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
)
func TestProxyController(t *testing.T) {
tlog.NewTestLogger().Init()
tempDir := t.TempDir()
log := logger.NewLogger().WithTestConfig()
log.Init()
authServiceCfg := service.AuthServiceConfig{
LocalUsers: &[]model.LocalUser{
{
Username: "testuser",
Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
},
{
Username: "totpuser",
Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
TOTPSecret: "JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK",
},
},
SessionExpiry: 10, // 10 seconds, useful for testing
CookieDomain: "example.com",
LoginTimeout: 10, // 10 seconds, useful for testing
LoginMaxRetries: 3,
SessionCookieName: "tinyauth-session",
}
controllerCfg := controller.ProxyControllerConfig{
AppURL: "https://tinyauth.example.com",
}
cfg, runtime := createTestConfigs(t)
acls := map[string]model.App{
"app_path_allow": {
@@ -398,32 +378,19 @@ func TestProxyController(t *testing.T) {
},
}
oauthBrokerCfgs := make(map[string]model.OAuthServiceConfig)
app := bootstrap.NewBootstrapApp(cfg)
app := bootstrap.NewBootstrapApp(model.Config{})
db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db"))
err := app.SetupDatabase()
require.NoError(t, err)
queries := repository.New(db)
queries := repository.New(app.GetDB())
docker := service.NewDockerService()
err = docker.Init()
require.NoError(t, err)
wg := &sync.WaitGroup{}
ctx := context.TODO()
ldap := service.NewLdapService(service.LdapServiceConfig{})
err = ldap.Init()
require.NoError(t, err)
broker := service.NewOAuthBrokerService(oauthBrokerCfgs)
err = broker.Init()
require.NoError(t, err)
authService := service.NewAuthService(authServiceCfg, ldap, queries, broker)
err = authService.Init()
require.NoError(t, err)
aclsService := service.NewAccessControlsService(docker, acls)
broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx)
authService := service.NewAuthService(log, cfg, runtime, ctx, wg, nil, queries, broker)
aclsService := service.NewAccessControlsService(log, nil, acls)
for _, test := range tests {
t.Run(test.description, func(t *testing.T) {
@@ -438,15 +405,13 @@ func TestProxyController(t *testing.T) {
recorder := httptest.NewRecorder()
proxyController := controller.NewProxyController(controllerCfg, group, aclsService, authService)
proxyController.SetupRoutes()
controller.NewProxyController(log, runtime, group, aclsService, authService)
test.run(t, router, recorder)
})
}
t.Cleanup(func() {
err = db.Close()
require.NoError(t, err)
app.GetDB().Close()
})
}
+4 -6
View File
@@ -9,7 +9,6 @@ import (
type ResourcesController struct {
config model.Config
router *gin.RouterGroup
fileServer http.Handler
}
@@ -19,15 +18,14 @@ func NewResourcesController(
) *ResourcesController {
fileServer := http.StripPrefix("/resources", http.FileServer(http.Dir(config.Resources.Path)))
return &ResourcesController{
controller := &ResourcesController{
config: config,
router: router,
fileServer: fileServer,
}
}
func (controller *ResourcesController) SetupRoutes() {
controller.router.GET("/resources/*resource", controller.resourcesHandler)
router.GET("/resources/*resource", controller.resourcesHandler)
return controller
}
func (controller *ResourcesController) resourcesHandler(c *gin.Context) {
@@ -3,26 +3,19 @@ package controller_test
import (
"net/http/httptest"
"os"
"path"
"path/filepath"
"testing"
"github.com/gin-gonic/gin"
"github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/controller"
)
func TestResourcesController(t *testing.T) {
tlog.NewTestLogger().Init()
tempDir := t.TempDir()
cfg, _ := createTestConfigs(t)
resourcesControllerCfg := controller.ResourcesControllerConfig{
Path: path.Join(tempDir, "resources"),
Enabled: true,
}
err := os.Mkdir(resourcesControllerCfg.Path, 0777)
err := os.MkdirAll(cfg.Resources.Path, 0777)
require.NoError(t, err)
type testCase struct {
@@ -61,11 +54,11 @@ func TestResourcesController(t *testing.T) {
},
}
testFilePath := resourcesControllerCfg.Path + "/testfile.txt"
testFilePath := cfg.Resources.Path + "/testfile.txt"
err = os.WriteFile(testFilePath, []byte("This is a test file."), 0777)
require.NoError(t, err)
testFilePathParent := tempDir + "/somefile.txt"
testFilePathParent := filepath.Dir(cfg.Resources.Path) + "/somefile.txt"
err = os.WriteFile(testFilePathParent, []byte("This file should not be accessible."), 0777)
require.NoError(t, err)
@@ -75,8 +68,7 @@ func TestResourcesController(t *testing.T) {
group := router.Group("/")
gin.SetMode(gin.TestMode)
resourcesController := controller.NewResourcesController(resourcesControllerCfg, group)
resourcesController.SetupRoutes()
controller.NewResourcesController(cfg, group)
recorder := httptest.NewRecorder()
test.run(t, router, recorder)
+5 -7
View File
@@ -28,7 +28,6 @@ type TotpRequest struct {
type UserController struct {
log *logger.Logger
runtime model.RuntimeConfig
router *gin.RouterGroup
auth *service.AuthService
}
@@ -38,19 +37,18 @@ func NewUserController(
router *gin.RouterGroup,
auth *service.AuthService,
) *UserController {
return &UserController{
controller := &UserController{
log: log,
runtime: runtimeConfig,
router: router,
auth: auth,
}
}
func (controller *UserController) SetupRoutes() {
userGroup := controller.router.Group("/user")
userGroup := router.Group("/user")
userGroup.POST("/login", controller.loginHandler)
userGroup.POST("/logout", controller.logoutHandler)
userGroup.POST("/totp", controller.totpHandler)
return controller
}
func (controller *UserController) loginHandler(c *gin.Context) {
@@ -88,7 +86,7 @@ func (controller *UserController) loginHandler(c *gin.Context) {
if errors.Is(err, service.ErrUserNotFound) {
controller.log.App.Warn().Str("username", req.Username).Msg("User not found during login attempt")
controller.auth.RecordLoginAttempt(req.Username, false)
controller.log.AuditLoginFailure(req.Username, "unkown", c.ClientIP(), "user not found")
controller.log.AuditLoginFailure(req.Username, "unknown", c.ClientIP(), "user not found")
c.JSON(401, gin.H{
"status": 401,
"message": "Unauthorized",
+28 -81
View File
@@ -5,8 +5,8 @@ import (
"encoding/json"
"net/http"
"net/http/httptest"
"path"
"strings"
"sync"
"testing"
"time"
@@ -19,53 +19,14 @@ import (
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
)
func TestUserController(t *testing.T) {
tlog.NewTestLogger().Init()
tempDir := t.TempDir()
log := logger.NewLogger().WithTestConfig()
log.Init()
authServiceCfg := service.AuthServiceConfig{
LocalUsers: &[]model.LocalUser{
{
Username: "testuser",
Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
},
{
Username: "totpuser",
Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
TOTPSecret: "JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK",
},
{
Username: "attruser",
Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
Attributes: model.UserAttributes{
Name: "Alice Smith",
Email: "alice@example.com",
},
},
{
Username: "attrtotpuser",
Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
TOTPSecret: "JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK",
Attributes: model.UserAttributes{
Name: "Bob Jones",
Email: "bob@example.com",
},
},
},
SessionExpiry: 10, // 10 seconds, useful for testing
CookieDomain: "example.com",
LoginTimeout: 10, // 10 seconds, useful for testing
LoginMaxRetries: 3,
SessionCookieName: "tinyauth-session",
}
userControllerCfg := controller.UserControllerConfig{
CookieDomain: "example.com",
SessionCookieName: "tinyauth-session",
}
cfg, runtime := createTestConfigs(t)
totpCtx := func(c *gin.Context) {
c.Set("context", &model.UserContext{
@@ -111,14 +72,12 @@ func TestUserController(t *testing.T) {
})
}
oauthBrokerCfgs := make(map[string]model.OAuthServiceConfig)
app := bootstrap.NewBootstrapApp(cfg)
app := bootstrap.NewBootstrapApp(model.Config{})
db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db"))
err := app.SetupDatabase()
require.NoError(t, err)
queries := repository.New(db)
queries := repository.New(app.GetDB())
type testCase struct {
description string
@@ -136,7 +95,7 @@ func TestUserController(t *testing.T) {
Password: "password",
}
loginReqBody, err := json.Marshal(loginReq)
assert.NoError(t, err)
require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqBody)))
req.Header.Set("Content-Type", "application/json")
@@ -144,7 +103,7 @@ func TestUserController(t *testing.T) {
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
assert.Len(t, recorder.Result().Cookies(), 1)
require.Len(t, recorder.Result().Cookies(), 1)
cookie := recorder.Result().Cookies()[0]
assert.Equal(t, "tinyauth-session", cookie.Name)
@@ -164,7 +123,7 @@ func TestUserController(t *testing.T) {
Password: "wrongpassword",
}
loginReqBody, err := json.Marshal(loginReq)
assert.NoError(t, err)
require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqBody)))
req.Header.Set("Content-Type", "application/json")
@@ -185,7 +144,7 @@ func TestUserController(t *testing.T) {
Password: "wrongpassword",
}
loginReqBody, err := json.Marshal(loginReq)
assert.NoError(t, err)
require.NoError(t, err)
for range 3 {
recorder := httptest.NewRecorder()
@@ -220,7 +179,7 @@ func TestUserController(t *testing.T) {
Password: "password",
}
loginReqBody, err := json.Marshal(loginReq)
assert.NoError(t, err)
require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqBody)))
req.Header.Set("Content-Type", "application/json")
@@ -231,12 +190,12 @@ func TestUserController(t *testing.T) {
decodedBody := make(map[string]any)
err = json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
assert.NoError(t, err)
require.NoError(t, err)
assert.Equal(t, decodedBody["totpPending"], true)
// should set the session cookie
assert.Len(t, recorder.Result().Cookies(), 1)
require.Len(t, recorder.Result().Cookies(), 1)
cookie := recorder.Result().Cookies()[0]
assert.Equal(t, "tinyauth-session", cookie.Name)
assert.True(t, cookie.HttpOnly)
@@ -257,7 +216,7 @@ func TestUserController(t *testing.T) {
Password: "password",
}
loginReqBody, err := json.Marshal(loginReq)
assert.NoError(t, err)
require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqBody)))
req.Header.Set("Content-Type", "application/json")
@@ -266,7 +225,7 @@ func TestUserController(t *testing.T) {
assert.Equal(t, 200, recorder.Code)
cookies := recorder.Result().Cookies()
assert.Len(t, cookies, 1)
require.Len(t, cookies, 1)
cookie := cookies[0]
assert.Equal(t, "tinyauth-session", cookie.Name)
@@ -280,7 +239,7 @@ func TestUserController(t *testing.T) {
assert.Equal(t, 200, recorder.Code)
cookies = recorder.Result().Cookies()
assert.Len(t, cookies, 1)
require.Len(t, cookies, 1)
cookie = cookies[0]
assert.Equal(t, "tinyauth-session", cookie.Name)
@@ -307,14 +266,14 @@ func TestUserController(t *testing.T) {
require.NoError(t, err)
code, err := totp.GenerateCode("JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", time.Now())
assert.NoError(t, err)
require.NoError(t, err)
totpReq := controller.TotpRequest{
Code: code,
}
totpReqBody, err := json.Marshal(totpReq)
assert.NoError(t, err)
require.NoError(t, err)
recorder = httptest.NewRecorder()
req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(totpReqBody)))
@@ -329,7 +288,7 @@ func TestUserController(t *testing.T) {
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
assert.Len(t, recorder.Result().Cookies(), 1)
require.Len(t, recorder.Result().Cookies(), 1)
// should set a new session cookie with totp pending removed
totpCookie := recorder.Result().Cookies()[0]
@@ -352,7 +311,7 @@ func TestUserController(t *testing.T) {
}
totpReqBody, err := json.Marshal(totpReq)
assert.NoError(t, err)
require.NoError(t, err)
recorder = httptest.NewRecorder()
req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(totpReqBody)))
@@ -456,21 +415,11 @@ func TestUserController(t *testing.T) {
},
}
docker := service.NewDockerService()
err = docker.Init()
require.NoError(t, err)
ctx := context.TODO()
wg := &sync.WaitGroup{}
ldap := service.NewLdapService(service.LdapServiceConfig{})
err = ldap.Init()
require.NoError(t, err)
broker := service.NewOAuthBrokerService(oauthBrokerCfgs)
err = broker.Init()
require.NoError(t, err)
authService := service.NewAuthService(authServiceCfg, ldap, queries, broker)
err = authService.Init()
require.NoError(t, err)
broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx)
authService := service.NewAuthService(log, cfg, runtime, ctx, wg, nil, queries, broker)
beforeEach := func() {
// Clear failed login attempts before each test
@@ -489,8 +438,7 @@ func TestUserController(t *testing.T) {
group := router.Group("/api")
gin.SetMode(gin.TestMode)
userController := controller.NewUserController(userControllerCfg, group, authService)
userController.SetupRoutes()
controller.NewUserController(log, runtime, group, authService)
recorder := httptest.NewRecorder()
@@ -499,7 +447,6 @@ func TestUserController(t *testing.T) {
}
t.Cleanup(func() {
err = db.Close()
require.NoError(t, err)
app.GetDB().Close()
})
}
+24 -10
View File
@@ -27,23 +27,29 @@ type OpenIDConnectConfiguration struct {
}
type WellKnownController struct {
router *gin.RouterGroup
oidc *service.OIDCService
oidc *service.OIDCService
}
func NewWellKnownController(oidc *service.OIDCService, router *gin.RouterGroup) *WellKnownController {
return &WellKnownController{
oidc: oidc,
router: router,
controller := &WellKnownController{
oidc: oidc,
}
}
func (controller *WellKnownController) SetupRoutes() {
controller.router.GET("/.well-known/openid-configuration", controller.OpenIDConnectConfiguration)
controller.router.GET("/.well-known/jwks.json", controller.JWKS)
router.GET("/.well-known/openid-configuration", controller.OpenIDConnectConfiguration)
router.GET("/.well-known/jwks.json", controller.JWKS)
return controller
}
func (controller *WellKnownController) OpenIDConnectConfiguration(c *gin.Context) {
if controller.oidc == nil {
c.JSON(500, gin.H{
"status": 500,
"message": "OIDC service not configured",
})
return
}
issuer := controller.oidc.GetIssuer()
c.JSON(200, OpenIDConnectConfiguration{
Issuer: issuer,
@@ -65,11 +71,19 @@ func (controller *WellKnownController) OpenIDConnectConfiguration(c *gin.Context
}
func (controller *WellKnownController) JWKS(c *gin.Context) {
if controller.oidc == nil {
c.JSON(500, gin.H{
"status": 500,
"message": "OIDC service not configured",
})
return
}
jwks, err := controller.oidc.GetJWK()
if err != nil {
c.JSON(500, gin.H{
"status": "500",
"status": 500,
"message": "failed to get JWK",
})
return
@@ -1,10 +1,11 @@
package controller_test
import (
"context"
"encoding/json"
"fmt"
"net/http/httptest"
"path"
"sync"
"testing"
"github.com/gin-gonic/gin"
@@ -12,30 +13,16 @@ import (
"github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/bootstrap"
"github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
)
func TestWellKnownController(t *testing.T) {
tlog.NewTestLogger().Init()
tempDir := t.TempDir()
log := logger.NewLogger().WithTestConfig()
log.Init()
oidcServiceCfg := service.OIDCServiceConfig{
Clients: map[string]model.OIDCClientConfig{
"test": {
ClientID: "some-client-id",
ClientSecret: "some-client-secret",
TrustedRedirectURIs: []string{"https://test.example.com/callback"},
Name: "Test Client",
},
},
PrivateKeyPath: path.Join(tempDir, "key.pem"),
PublicKeyPath: path.Join(tempDir, "key.pub"),
Issuer: "https://tinyauth.example.com",
SessionExpiry: 500,
}
cfg, runtime := createTestConfigs(t)
type testCase struct {
description string
@@ -56,11 +43,11 @@ func TestWellKnownController(t *testing.T) {
assert.NoError(t, err)
expected := controller.OpenIDConnectConfiguration{
Issuer: oidcServiceCfg.Issuer,
AuthorizationEndpoint: fmt.Sprintf("%s/authorize", oidcServiceCfg.Issuer),
TokenEndpoint: fmt.Sprintf("%s/api/oidc/token", oidcServiceCfg.Issuer),
UserinfoEndpoint: fmt.Sprintf("%s/api/oidc/userinfo", oidcServiceCfg.Issuer),
JwksUri: fmt.Sprintf("%s/.well-known/jwks.json", oidcServiceCfg.Issuer),
Issuer: runtime.AppURL,
AuthorizationEndpoint: fmt.Sprintf("%s/authorize", runtime.AppURL),
TokenEndpoint: fmt.Sprintf("%s/api/oidc/token", runtime.AppURL),
UserinfoEndpoint: fmt.Sprintf("%s/api/oidc/userinfo", runtime.AppURL),
JwksUri: fmt.Sprintf("%s/.well-known/jwks.json", runtime.AppURL),
ScopesSupported: service.SupportedScopes,
ResponseTypesSupported: service.SupportedResponseTypes,
GrantTypesSupported: service.SupportedGrantTypes,
@@ -101,16 +88,17 @@ func TestWellKnownController(t *testing.T) {
},
}
app := bootstrap.NewBootstrapApp(model.Config{})
ctx := context.TODO()
wg := &sync.WaitGroup{}
db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db"))
app := bootstrap.NewBootstrapApp(cfg)
err := app.SetupDatabase()
require.NoError(t, err)
queries := repository.New(db)
queries := repository.New(app.GetDB())
oidcService := service.NewOIDCService(oidcServiceCfg, queries)
err = oidcService.Init()
require.NoError(t, err)
oidcService, err := service.NewOIDCService(log, cfg, runtime, queries, ctx, wg)
for _, test := range tests {
t.Run(test.description, func(t *testing.T) {
@@ -119,15 +107,13 @@ func TestWellKnownController(t *testing.T) {
recorder := httptest.NewRecorder()
wellKnownController := controller.NewWellKnownController(controller.WellKnownControllerConfig{}, oidcService, router)
wellKnownController.SetupRoutes()
controller.NewWellKnownController(oidcService, &router.RouterGroup)
test.run(t, router, recorder)
})
}
t.Cleanup(func() {
err = db.Close()
require.NoError(t, err)
app.GetDB().Close()
})
}
+1 -5
View File
@@ -56,10 +56,6 @@ func NewContextMiddleware(
}
}
func (m *ContextMiddleware) Init() error {
return nil
}
func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
return func(c *gin.Context) {
if m.isIgnorePath(c.Request.Method + " " + c.Request.URL.Path) {
@@ -82,7 +78,7 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
c.Next()
return
} else {
m.log.App.Error().Msgf("Error authenticating session cookie: %v", err)
m.log.App.Debug().Msgf("Error authenticating session cookie: %v", err)
}
}
+14 -47
View File
@@ -5,7 +5,7 @@ import (
"encoding/base64"
"net/http"
"net/http/httptest"
"path"
"sync"
"testing"
"time"
@@ -17,36 +17,14 @@ import (
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
)
func TestContextMiddleware(t *testing.T) {
tlog.NewTestLogger().Init()
tempDir := t.TempDir()
log := logger.NewLogger().WithTestConfig()
log.Init()
authServiceCfg := service.AuthServiceConfig{
LocalUsers: &[]model.LocalUser{
{
Username: "testuser",
Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
},
{
Username: "totpuser",
Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
TOTPSecret: "JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK",
},
},
SessionExpiry: 10, // 10 seconds, useful for testing
CookieDomain: "example.com",
LoginTimeout: 10, // 10 seconds, useful for testing
LoginMaxRetries: 3,
SessionCookieName: "tinyauth-session",
}
middlewareCfg := middleware.ContextMiddlewareConfig{
CookieDomain: "example.com",
SessionCookieName: "tinyauth-session",
}
cfg, runtime := createTestConfigs(t)
basicAuthHeader := func(username, password string) string {
return "Basic " + base64.StdEncoding.EncodeToString([]byte(username+":"+password))
@@ -270,30 +248,20 @@ func TestContextMiddleware(t *testing.T) {
},
}
oauthBrokerCfgs := make(map[string]model.OAuthServiceConfig)
ctx := context.TODO()
wg := &sync.WaitGroup{}
app := bootstrap.NewBootstrapApp(model.Config{})
app := bootstrap.NewBootstrapApp(cfg)
db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db"))
err := app.SetupDatabase()
require.NoError(t, err)
queries := repository.New(db)
queries := repository.New(app.GetDB())
ldap := service.NewLdapService(service.LdapServiceConfig{})
err = ldap.Init()
require.NoError(t, err)
broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx)
authService := service.NewAuthService(log, cfg, runtime, ctx, wg, nil, queries, broker)
broker := service.NewOAuthBrokerService(oauthBrokerCfgs)
err = broker.Init()
require.NoError(t, err)
authService := service.NewAuthService(authServiceCfg, ldap, queries, broker)
err = authService.Init()
require.NoError(t, err)
contextMiddleware := middleware.NewContextMiddleware(middlewareCfg, authService, broker)
err = contextMiddleware.Init()
require.NoError(t, err)
contextMiddleware := middleware.NewContextMiddleware(log, runtime, authService, broker)
for _, test := range tests {
authService.ClearRateLimitsTestingOnly()
@@ -322,7 +290,6 @@ func TestContextMiddleware(t *testing.T) {
}
t.Cleanup(func() {
err = db.Close()
require.NoError(t, err)
app.GetDB().Close()
})
}
+108
View File
@@ -0,0 +1,108 @@
package middleware_test
import (
"path"
"testing"
"github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/model"
"golang.org/x/crypto/bcrypt"
)
// Note: This code is duplicated from controller_test.go
var testingTOTPSecret = "JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK"
func createTestConfigs(t *testing.T) (model.Config, model.RuntimeConfig) {
tempDir := t.TempDir()
config := model.Config{
UI: model.UIConfig{
Title: "Tinyauth Test",
ForgotPasswordMessage: "foo",
BackgroundImage: "/background.jpg",
WarningsEnabled: true,
},
OAuth: model.OAuthConfig{
AutoRedirect: "none",
},
OIDC: model.OIDCConfig{
Clients: map[string]model.OIDCClientConfig{
"test": {
ClientID: "some-client-id",
ClientSecret: "some-client-secret",
TrustedRedirectURIs: []string{"https://test.example.com/callback"},
Name: "Test Client",
},
},
PrivateKeyPath: path.Join(tempDir, "key.pem"),
PublicKeyPath: path.Join(tempDir, "key.pub"),
},
Auth: model.AuthConfig{
SessionExpiry: 10,
LoginTimeout: 10,
LoginMaxRetries: 3,
},
Database: model.DatabaseConfig{
Path: path.Join(tempDir, "test.db"),
},
Resources: model.ResourcesConfig{
Enabled: true,
Path: path.Join(tempDir, "resources"),
},
}
passwd, err := bcrypt.GenerateFromPassword([]byte("password"), bcrypt.DefaultCost)
require.NoError(t, err)
runtime := model.RuntimeConfig{
ConfiguredProviders: []model.Provider{
{
Name: "Local",
ID: "local",
OAuth: false,
},
},
LocalUsers: []model.LocalUser{
{
Username: "testuser",
Password: string(passwd),
},
{
Username: "totpuser",
Password: string(passwd),
TOTPSecret: testingTOTPSecret,
},
{
Username: "attruser",
Password: string(passwd),
Attributes: model.UserAttributes{
Name: "Alice Smith",
Email: "alice@example.com",
},
},
{
Username: "attrtotpuser",
Password: string(passwd),
TOTPSecret: testingTOTPSecret,
Attributes: model.UserAttributes{
Name: "Bob Jones",
Email: "bob@example.com",
},
},
},
CookieDomain: "example.com",
AppURL: "https://tinyauth.example.com",
SessionCookieName: "tinyauth-session",
OIDCClients: func() []model.OIDCClientConfig {
var clients []model.OIDCClientConfig
for id, client := range config.OIDC.Clients {
client.ID = id
clients = append(clients, client)
}
return clients
}(),
}
return config, runtime
}
+4 -6
View File
@@ -18,21 +18,19 @@ type UIMiddleware struct {
uiFileServer http.Handler
}
func NewUIMiddleware() *UIMiddleware {
return &UIMiddleware{}
}
func NewUIMiddleware() (*UIMiddleware, error) {
m := &UIMiddleware{}
func (m *UIMiddleware) Init() error {
ui, err := fs.Sub(assets.FrontendAssets, "dist")
if err != nil {
return err
return nil, fmt.Errorf("failed to load ui assets: %w", err)
}
m.uiFs = ui
m.uiFileServer = http.FileServerFS(ui)
return nil
return m, nil
}
func (m *UIMiddleware) Middleware() gin.HandlerFunc {
@@ -27,10 +27,6 @@ func NewZerologMiddleware(log *logger.Logger) *ZerologMiddleware {
}
}
func (m *ZerologMiddleware) Init() error {
return nil
}
func (m *ZerologMiddleware) logPath(path string) bool {
for _, prefix := range loggerSkipPathsPrefix {
if strings.HasPrefix(path, prefix) {
+11 -9
View File
@@ -14,8 +14,9 @@ func NewDefaultConfiguration() *Config {
Path: "./resources",
},
Server: ServerConfig{
Port: 3000,
Address: "0.0.0.0",
Port: 3000,
Address: "0.0.0.0",
ConcurrentListenersEnabled: false,
},
Auth: AuthConfig{
SubdomainsEnabled: true,
@@ -95,9 +96,10 @@ type ResourcesConfig struct {
}
type ServerConfig struct {
Port int `description:"The port on which the server listens." yaml:"port"`
Address string `description:"The address on which the server listens." yaml:"address"`
SocketPath string `description:"The path to the Unix socket." yaml:"socketPath"`
Port int `description:"The port on which the server listens." yaml:"port"`
Address string `description:"The address on which the server listens." yaml:"address"`
SocketPath string `description:"The path to the Unix socket." yaml:"socketPath"`
ConcurrentListenersEnabled bool `description:"Enable listening on both TCP and Unix socket at the same time." yaml:"concurrentListenersEnabled"`
}
type AuthConfig struct {
@@ -147,10 +149,10 @@ type IPConfig struct {
}
type OAuthConfig struct {
Whitelist []string `description:"Comma-separated list of allowed OAuth domains." yaml:"whitelist"`
WhitelistFile string `description:"Path to the OAuth whitelist file." yaml:"whitelistFile"`
AutoRedirect string `description:"The OAuth provider to use for automatic redirection." yaml:"autoRedirect"`
Providers map[string]OAuthServiceConfig `description:"OAuth providers configuration." yaml:"providers"`
Whitelist []string `description:"Comma-separated list of allowed OAuth domains." yaml:"whitelist"`
WhitelistFile string `description:"Path to the OAuth whitelist file." yaml:"whitelistFile"`
AutoRedirect string `description:"The OAuth provider to use for automatic redirection." yaml:"autoRedirect"`
Providers map[string]OAuthServiceConfig `description:"OAuth providers configuration." yaml:"providers"`
}
type OIDCConfig struct {
+6 -2
View File
@@ -8,6 +8,10 @@ import (
"github.com/tinyauthapp/tinyauth/internal/repository"
)
var (
ErrUserContextNotFound = errors.New("user context not found")
)
type ProviderType int
const (
@@ -74,7 +78,7 @@ func (c *UserContext) NewFromGin(ginctx *gin.Context) (*UserContext, error) {
userContextValue, exists := ginctx.Get("context")
if !exists {
return nil, errors.New("failed to get user context")
return nil, ErrUserContextNotFound
}
userContext, ok := userContextValue.(*UserContext)
@@ -117,7 +121,7 @@ func (c *UserContext) NewFromSession(session *repository.Session) (*UserContext,
Email: session.Email,
},
}
// By default we assume an unkown name which is oauth
// By default we assume an unknown name which is oauth
default:
c.Provider = ProviderOAuth
c.OAuth = &OAuthContext{
+1 -1
View File
@@ -238,7 +238,7 @@ func TestContext(t *testing.T) {
_, err := c.NewFromGin(newGinCtx(nil, false))
return err.Error()
},
expected: "failed to get user context",
expected: model.ErrUserContextNotFound.Error(),
},
{
description: "NewFromGin returns error when context value has wrong type",
+10 -10
View File
@@ -7,19 +7,19 @@ import (
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
)
type LabelProviderImpl interface {
type LabelProvider interface {
GetLabels(appDomain string) (*model.App, error)
}
type AccessControlsService struct {
log *logger.Logger
labelProvider LabelProviderImpl
labelProvider *LabelProvider
static map[string]model.App
}
func NewAccessControlsService(
log *logger.Logger,
labelProvider LabelProviderImpl,
labelProvider *LabelProvider,
static map[string]model.App) *AccessControlsService {
return &AccessControlsService{
log: log,
@@ -28,10 +28,6 @@ func NewAccessControlsService(
}
}
func (acls *AccessControlsService) Init() error {
return nil // No initialization needed
}
func (acls *AccessControlsService) lookupStaticACLs(domain string) *model.App {
var appAcls *model.App
for app, config := range acls.static {
@@ -59,7 +55,11 @@ func (acls *AccessControlsService) GetAccessControls(domain string) (*model.App,
return app, nil
}
// Fallback to label provider
acls.log.App.Debug().Msg("Using label provider for app")
return acls.labelProvider.GetLabels(domain)
// If we have a label provider configured, try to get ACLs from it
if acls.labelProvider != nil {
return (*acls.labelProvider).GetLabels(domain)
}
// no labels
return nil, nil
}
+10 -19
View File
@@ -77,7 +77,6 @@ type AuthService struct {
config model.Config
runtime model.RuntimeConfig
context context.Context
wg *sync.WaitGroup
ldap *LdapService
queries *repository.Queries
@@ -98,17 +97,16 @@ func NewAuthService(
log *logger.Logger,
config model.Config,
runtime model.RuntimeConfig,
context context.Context,
ctx context.Context,
wg *sync.WaitGroup,
ldap *LdapService,
queries *repository.Queries,
oauthBroker *OAuthBrokerService,
) *AuthService {
return &AuthService{
service := &AuthService{
log: log,
runtime: runtime,
context: context,
wg: wg,
context: ctx,
config: config,
loginAttempts: make(map[string]*LoginAttempt),
ldapGroupsCache: make(map[string]*LdapGroupsCache),
@@ -117,11 +115,10 @@ func NewAuthService(
queries: queries,
oauthBroker: oauthBroker,
}
}
func (auth *AuthService) Init() error {
auth.wg.Go(auth.CleanupOAuthSessionsRoutine)
return nil
wg.Go(service.CleanupOAuthSessionsRoutine)
return service
}
func (auth *AuthService) SearchUser(username string) (*model.UserSearch, error) {
@@ -132,7 +129,7 @@ func (auth *AuthService) SearchUser(username string) (*model.UserSearch, error)
}, nil
}
if auth.ldap.IsConfigured() {
if auth.ldap != nil {
userDN, err := auth.ldap.GetUserDN(username)
if err != nil {
@@ -157,7 +154,7 @@ func (auth *AuthService) CheckUserPassword(search model.UserSearch, password str
}
return bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password))
case model.UserLDAP:
if auth.ldap.IsConfigured() {
if auth.ldap != nil {
err := auth.ldap.Bind(search.Username, password)
if err != nil {
return fmt.Errorf("failed to bind to ldap user: %w", err)
@@ -189,7 +186,7 @@ func (auth *AuthService) GetLocalUser(username string) *model.LocalUser {
}
func (auth *AuthService) GetLDAPUser(userDN string) (*model.LDAPUser, error) {
if !auth.ldap.IsConfigured() {
if auth.ldap == nil {
return nil, errors.New("ldap service not configured")
}
@@ -402,12 +399,6 @@ func (auth *AuthService) DeleteSession(ctx context.Context, uuid string) (*http.
auth.log.App.Error().Err(err).Str("uuid", uuid).Msg("Failed to delete session from database")
}
err = auth.queries.DeleteSession(ctx, uuid)
if err != nil {
return nil, err
}
return &http.Cookie{
Name: auth.runtime.SessionCookieName,
Value: "",
@@ -459,7 +450,7 @@ func (auth *AuthService) LocalAuthConfigured() bool {
}
func (auth *AuthService) LDAPAuthConfigured() bool {
return auth.ldap.IsConfigured()
return auth.ldap != nil
}
func (auth *AuthService) IsUserAllowed(c *gin.Context, context model.UserContext, acls *model.App) bool {
+17 -24
View File
@@ -17,49 +17,42 @@ type DockerService struct {
log *logger.Logger
client *client.Client
context context.Context
wg *sync.WaitGroup
isConnected bool
}
func NewDockerService(
log *logger.Logger,
context context.Context,
ctx context.Context,
wg *sync.WaitGroup,
) *DockerService {
return &DockerService{
log: log,
context: context,
wg: wg,
}
}
) (*DockerService, error) {
func (docker *DockerService) Init() error {
client, err := client.NewClientWithOpts(client.FromEnv)
if err != nil {
return err
return nil, err
}
client.NegotiateAPIVersion(docker.context)
client.NegotiateAPIVersion(ctx)
docker.client = client
_, err = docker.client.Ping(docker.context)
_, err = client.Ping(ctx)
if err != nil {
docker.log.App.Debug().Err(err).Msg("Docker not connected")
docker.isConnected = false
docker.client = nil
docker.context = nil
return nil
log.App.Debug().Err(err).Msg("Docker not connected")
return nil, nil
}
docker.isConnected = true
docker.log.App.Debug().Msg("Docker connected successfully")
service := &DockerService{
log: log,
client: client,
context: ctx,
}
docker.wg.Go(docker.watchAndClose)
service.isConnected = true
service.log.App.Debug().Msg("Docker connected successfully")
return nil
wg.Go(service.watchAndClose)
return service, nil
}
func (docker *DockerService) getContainers() ([]container.Summary, error) {
+42 -48
View File
@@ -38,7 +38,6 @@ type ingressApp struct {
type KubernetesService struct {
log *logger.Logger
ctx context.Context
wg *sync.WaitGroup
client dynamic.Interface
started bool
@@ -50,17 +49,53 @@ type KubernetesService struct {
func NewKubernetesService(
log *logger.Logger,
context context.Context,
ctx context.Context,
wg *sync.WaitGroup,
) *KubernetesService {
return &KubernetesService{
) (*KubernetesService, error) {
cfg, err := rest.InClusterConfig()
if err != nil {
return nil, fmt.Errorf("failed to get in-cluster kubernetes config: %w", err)
}
client, err := dynamic.NewForConfig(cfg)
if err != nil {
return nil, fmt.Errorf("failed to create kubernetes client: %w", err)
}
gvr := schema.GroupVersionResource{
Group: "networking.k8s.io",
Version: "v1",
Resource: "ingresses",
}
accessCtx, accessCancel := context.WithTimeout(ctx, 5*time.Second)
defer accessCancel()
_, err = client.Resource(gvr).List(accessCtx, metav1.ListOptions{Limit: 1})
if err != nil {
log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to access Ingress API, Kubernetes label provider will be disabled")
return nil, fmt.Errorf("failed to access ingress api: %w", err)
}
log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Successfully accessed Ingress API, starting watcher")
service := &KubernetesService{
log: log,
ctx: context,
wg: wg,
ctx: ctx,
client: client,
ingressApps: make(map[ingressKey][]ingressApp),
domainIndex: make(map[string]ingressAppKey),
appNameIndex: make(map[string]ingressAppKey),
}
wg.Go(func() {
service.watchGVR(gvr)
})
service.started = true
log.App.Debug().Msg("Kubernetes label provider started successfully")
return service, nil
}
func (k *KubernetesService) addIngressApps(namespace, name string, apps []ingressApp) {
@@ -226,7 +261,7 @@ func (k *KubernetesService) watchGVR(gvr schema.GroupVersionResource) {
for {
select {
case <-k.ctx.Done():
k.log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Context cancelled, stopping watcher")
k.log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Shutting down kubernetes watcher")
return
case <-resyncTicker.C:
if err := k.resyncGVR(gvr); err != nil {
@@ -251,47 +286,6 @@ func (k *KubernetesService) watchGVR(gvr schema.GroupVersionResource) {
}
}
func (k *KubernetesService) Init() error {
var cfg *rest.Config
var err error
cfg, err = rest.InClusterConfig()
if err != nil {
return fmt.Errorf("failed to get in-cluster Kubernetes config: %w", err)
}
client, err := dynamic.NewForConfig(cfg)
if err != nil {
return fmt.Errorf("failed to create Kubernetes client: %w", err)
}
k.client = client
gvr := schema.GroupVersionResource{
Group: "networking.k8s.io",
Version: "v1",
Resource: "ingresses",
}
accessCtx, accessCancel := context.WithTimeout(k.ctx, 5*time.Second)
defer accessCancel()
_, err = k.client.Resource(gvr).List(accessCtx, metav1.ListOptions{Limit: 1})
if err != nil {
k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to access Ingress API, Kubernetes label provider will be disabled")
k.started = false
return nil
}
k.log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Successfully accessed Ingress API, starting watcher")
k.wg.Go(func() {
k.watchGVR(gvr)
})
k.started = true
k.log.App.Debug().Msg("Kubernetes label provider started successfully")
return nil
}
func (k *KubernetesService) GetLabels(appDomain string) (*model.App, error) {
if !k.started {
k.log.App.Debug().Str("domain", appDomain).Msg("Kubernetes label provider not started, skipping")
@@ -8,9 +8,13 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
)
func TestKubernetesService(t *testing.T) {
log := logger.NewLogger().WithTestConfig()
log.Init()
type testCase struct {
description string
run func(t *testing.T, svc *KubernetesService)
@@ -179,6 +183,7 @@ func TestKubernetesService(t *testing.T) {
ingressApps: make(map[ingressKey][]ingressApp),
domainIndex: make(map[string]ingressAppKey),
appNameIndex: make(map[string]ingressAppKey),
log: log,
}
test.run(t, svc)
})
+23 -45
View File
@@ -17,63 +17,39 @@ type LdapService struct {
log *logger.Logger
config model.Config
context context.Context
wg *sync.WaitGroup
conn *ldapgo.Conn
mutex sync.RWMutex
cert *tls.Certificate
isConfigured bool
conn *ldapgo.Conn
mutex sync.RWMutex
cert *tls.Certificate
}
func NewLdapService(
log *logger.Logger,
config model.Config,
context context.Context,
ctx context.Context,
wg *sync.WaitGroup,
) *LdapService {
return &LdapService{
) (*LdapService, error) {
if config.LDAP.Address == "" {
return nil, nil
}
ldap := &LdapService{
log: log,
config: config,
context: context,
wg: wg,
context: ctx,
}
}
func (ldap *LdapService) IsConfigured() bool {
return ldap.isConfigured
}
func (ldap *LdapService) Unconfigure() error {
if !ldap.isConfigured {
return nil
}
if ldap.conn != nil {
if err := ldap.conn.Close(); err != nil {
return fmt.Errorf("failed to close LDAP connection: %w", err)
}
}
ldap.isConfigured = false
return nil
}
func (ldap *LdapService) Init() error {
if ldap.config.LDAP.Address == "" {
ldap.isConfigured = false
return nil
}
ldap.isConfigured = true
// Check whether authentication with client certificate is possible
if ldap.config.LDAP.AuthCert != "" && ldap.config.LDAP.AuthKey != "" {
cert, err := tls.LoadX509KeyPair(ldap.config.LDAP.AuthCert, ldap.config.LDAP.AuthKey)
if config.LDAP.AuthCert != "" && config.LDAP.AuthKey != "" {
cert, err := tls.LoadX509KeyPair(config.LDAP.AuthCert, config.LDAP.AuthKey)
if err != nil {
return fmt.Errorf("failed to initialize LDAP with mTLS authentication: %w", err)
return nil, fmt.Errorf("failed to initialize LDAP with mTLS authentication: %w", err)
}
log.App.Info().Msg("LDAP mTLS authentication configured successfully")
ldap.cert = &cert
ldap.log.App.Info().Msg("LDAP mTLS authentication configured successfully")
// TODO: Add optional extra CA certificates, instead of `InsecureSkipVerify`
/*
@@ -86,12 +62,14 @@ func (ldap *LdapService) Init() error {
}
*/
}
_, err := ldap.connect()
if err != nil {
return fmt.Errorf("failed to connect to LDAP server: %w", err)
return nil, fmt.Errorf("failed to connect to ldap server: %w", err)
}
ldap.wg.Go(func() {
wg.Go(func() {
ldap.log.App.Debug().Msg("Starting LDAP connection heartbeat routine")
ticker := time.NewTicker(5 * time.Minute)
@@ -116,7 +94,7 @@ func (ldap *LdapService) Init() error {
}
})
return nil
return ldap, nil
}
func (ldap *LdapService) connect() (*ldapgo.Conn, error) {
+12 -10
View File
@@ -1,6 +1,8 @@
package service
import (
"context"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
@@ -25,7 +27,7 @@ type OAuthBrokerService struct {
configs map[string]model.OAuthServiceConfig
}
var presets = map[string]func(config model.OAuthServiceConfig) *OAuthService{
var presets = map[string]func(config model.OAuthServiceConfig, ctx context.Context) *OAuthService{
"github": newGitHubOAuthService,
"google": newGoogleOAuthService,
}
@@ -33,25 +35,25 @@ var presets = map[string]func(config model.OAuthServiceConfig) *OAuthService{
func NewOAuthBrokerService(
log *logger.Logger,
configs map[string]model.OAuthServiceConfig,
ctx context.Context,
) *OAuthBrokerService {
return &OAuthBrokerService{
service := &OAuthBrokerService{
log: log,
services: make(map[string]OAuthServiceImpl),
configs: configs,
}
}
func (broker *OAuthBrokerService) Init() error {
for name, cfg := range broker.configs {
for name, cfg := range configs {
if presetFunc, exists := presets[name]; exists {
broker.services[name] = presetFunc(cfg)
broker.log.App.Debug().Str("service", name).Msg("Loaded OAuth service from preset")
service.services[name] = presetFunc(cfg, ctx)
service.log.App.Debug().Str("service", name).Msg("Loaded OAuth service from preset")
} else {
broker.services[name] = NewOAuthService(cfg, name)
broker.log.App.Debug().Str("service", name).Msg("Loaded OAuth service from custom config")
service.services[name] = NewOAuthService(cfg, name, ctx)
service.log.App.Debug().Str("service", name).Msg("Loaded OAuth service from custom config")
}
}
return nil
return service
}
func (broker *OAuthBrokerService) GetConfiguredServices() []string {
+6 -4
View File
@@ -1,23 +1,25 @@
package service
import (
"context"
"github.com/tinyauthapp/tinyauth/internal/model"
"golang.org/x/oauth2/endpoints"
)
func newGoogleOAuthService(config model.OAuthServiceConfig) *OAuthService {
func newGoogleOAuthService(config model.OAuthServiceConfig, ctx context.Context) *OAuthService {
scopes := []string{"openid", "email", "profile"}
config.Scopes = scopes
config.AuthURL = endpoints.Google.AuthURL
config.TokenURL = endpoints.Google.TokenURL
config.UserinfoURL = "https://openidconnect.googleapis.com/v1/userinfo"
return NewOAuthService(config, "google")
return NewOAuthService(config, "google", ctx)
}
func newGitHubOAuthService(config model.OAuthServiceConfig) *OAuthService {
func newGitHubOAuthService(config model.OAuthServiceConfig, ctx context.Context) *OAuthService {
scopes := []string{"read:user", "user:email"}
config.Scopes = scopes
config.AuthURL = endpoints.GitHub.AuthURL
config.TokenURL = endpoints.GitHub.TokenURL
return NewOAuthService(config, "github").WithUserinfoExtractor(githubExtractor)
return NewOAuthService(config, "github", ctx).WithUserinfoExtractor(githubExtractor)
}
+3 -4
View File
@@ -20,7 +20,7 @@ type OAuthService struct {
id string
}
func NewOAuthService(config model.OAuthServiceConfig, id string) *OAuthService {
func NewOAuthService(config model.OAuthServiceConfig, id string, ctx context.Context) *OAuthService {
httpClient := &http.Client{
Timeout: 30 * time.Second,
Transport: &http.Transport{
@@ -29,8 +29,7 @@ func NewOAuthService(config model.OAuthServiceConfig, id string) *OAuthService {
},
},
}
ctx := context.Background()
ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient)
vctx := context.WithValue(ctx, oauth2.HTTPClient, httpClient)
return &OAuthService{
serviceCfg: config,
@@ -44,7 +43,7 @@ func NewOAuthService(config model.OAuthServiceConfig, id string) *OAuthService {
TokenURL: config.TokenURL,
},
},
ctx: ctx,
ctx: vctx,
userinfoExtractor: defaultExtractor,
id: id,
}
+63 -72
View File
@@ -118,13 +118,11 @@ type OIDCService struct {
runtime model.RuntimeConfig
queries *repository.Queries
context context.Context
wg *sync.WaitGroup
clients map[string]model.OIDCClientConfig
privateKey *rsa.PrivateKey
publicKey crypto.PublicKey
issuer string
isConfigured bool
clients map[string]model.OIDCClientConfig
privateKey *rsa.PrivateKey
publicKey crypto.PublicKey
issuer string
}
func NewOIDCService(
@@ -132,162 +130,155 @@ func NewOIDCService(
config model.Config,
runtime model.RuntimeConfig,
queries *repository.Queries,
context context.Context,
wg *sync.WaitGroup) *OIDCService {
return &OIDCService{
log: log,
config: config,
runtime: runtime,
queries: queries,
context: context,
wg: wg,
}
}
func (service *OIDCService) IsConfigured() bool {
return service.isConfigured
}
func (service *OIDCService) Init() error {
ctx context.Context,
wg *sync.WaitGroup) (*OIDCService, error) {
// If not configured, skip init
if len(service.runtime.OIDCClients) == 0 {
service.isConfigured = false
return nil
if len(runtime.OIDCClients) == 0 {
return nil, nil
}
service.isConfigured = true
// Ensure issuer is https
uissuer, err := url.Parse(service.runtime.AppURL)
uissuer, err := url.Parse(runtime.AppURL)
if err != nil {
return err
return nil, fmt.Errorf("failed to parse app url: %w", err)
}
if uissuer.Scheme != "https" {
return errors.New("issuer must be https")
return nil, errors.New("issuer must be https")
}
service.issuer = fmt.Sprintf("%s://%s", uissuer.Scheme, uissuer.Host)
issuer := fmt.Sprintf("%s://%s", uissuer.Scheme, uissuer.Host)
// Create/load private and public keys
if strings.TrimSpace(service.config.OIDC.PrivateKeyPath) == "" ||
strings.TrimSpace(service.config.OIDC.PublicKeyPath) == "" {
return errors.New("private key path and public key path are required")
if strings.TrimSpace(config.OIDC.PrivateKeyPath) == "" ||
strings.TrimSpace(config.OIDC.PublicKeyPath) == "" {
return nil, errors.New("private key path and public key path are required")
}
var privateKey *rsa.PrivateKey
fprivateKey, err := os.ReadFile(service.config.OIDC.PrivateKeyPath)
fprivateKey, err := os.ReadFile(config.OIDC.PrivateKeyPath)
if err != nil && !errors.Is(err, os.ErrNotExist) {
return err
return nil, err
}
if errors.Is(err, os.ErrNotExist) {
privateKey, err = rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return err
return nil, fmt.Errorf("failed to generate private key: %w", err)
}
der := x509.MarshalPKCS1PrivateKey(privateKey)
if der == nil {
return errors.New("failed to marshal private key")
return nil, errors.New("failed to marshal private key")
}
encoded := pem.EncodeToMemory(&pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: der,
})
service.log.App.Trace().Str("type", "RSA PRIVATE KEY").Msg("Generated private RSA key")
err = os.WriteFile(service.config.OIDC.PrivateKeyPath, encoded, 0600)
log.App.Trace().Str("type", "RSA PRIVATE KEY").Msg("Generated private RSA key")
err = os.WriteFile(config.OIDC.PrivateKeyPath, encoded, 0600)
if err != nil {
return err
return nil, fmt.Errorf("failed to write private key to file: %w", err)
}
service.privateKey = privateKey
} else {
block, _ := pem.Decode(fprivateKey)
if block == nil {
return errors.New("failed to decode private key")
return nil, errors.New("failed to decode private key")
}
service.log.App.Trace().Str("type", block.Type).Msg("Loaded private key")
log.App.Trace().Str("type", block.Type).Msg("Loaded private key")
privateKey, err = x509.ParsePKCS1PrivateKey(block.Bytes)
if err != nil {
return err
return nil, fmt.Errorf("failed to parse private key: %w", err)
}
service.privateKey = privateKey
}
fpublicKey, err := os.ReadFile(service.config.OIDC.PublicKeyPath)
var publicKey crypto.PublicKey
fpublicKey, err := os.ReadFile(config.OIDC.PublicKeyPath)
if err != nil && !errors.Is(err, os.ErrNotExist) {
return err
return nil, fmt.Errorf("failed to read public key: %w", err)
}
if errors.Is(err, os.ErrNotExist) {
publicKey := service.privateKey.Public()
publicKey = privateKey.Public()
der := x509.MarshalPKCS1PublicKey(publicKey.(*rsa.PublicKey))
if der == nil {
return errors.New("failed to marshal public key")
return nil, errors.New("failed to marshal public key")
}
encoded := pem.EncodeToMemory(&pem.Block{
Type: "RSA PUBLIC KEY",
Bytes: der,
})
service.log.App.Trace().Str("type", "RSA PUBLIC KEY").Msg("Generated public RSA key")
err = os.WriteFile(service.config.OIDC.PublicKeyPath, encoded, 0644)
log.App.Trace().Str("type", "RSA PUBLIC KEY").Msg("Generated public RSA key")
err = os.WriteFile(config.OIDC.PublicKeyPath, encoded, 0644)
if err != nil {
return err
return nil, err
}
service.publicKey = publicKey
} else {
block, _ := pem.Decode(fpublicKey)
if block == nil {
return errors.New("failed to decode public key")
return nil, errors.New("failed to decode public key")
}
service.log.App.Trace().Str("type", block.Type).Msg("Loaded public key")
log.App.Trace().Str("type", block.Type).Msg("Loaded public key")
switch block.Type {
case "RSA PUBLIC KEY":
publicKey, err := x509.ParsePKCS1PublicKey(block.Bytes)
publicKey, err = x509.ParsePKCS1PublicKey(block.Bytes)
if err != nil {
return err
return nil, fmt.Errorf("failed to parse public key: %w", err)
}
service.publicKey = publicKey
case "PUBLIC KEY":
publicKey, err := x509.ParsePKIXPublicKey(block.Bytes)
publicKey, err = x509.ParsePKIXPublicKey(block.Bytes)
if err != nil {
return err
return nil, fmt.Errorf("failed to parse public key: %w", err)
}
service.publicKey = publicKey.(crypto.PublicKey)
default:
return fmt.Errorf("unsupported public key type: %s", block.Type)
return nil, fmt.Errorf("unsupported public key type: %s", block.Type)
}
}
// We will reorganize the client into a map with the client ID as the key
service.clients = make(map[string]model.OIDCClientConfig)
clients := make(map[string]model.OIDCClientConfig)
for id, client := range service.config.OIDC.Clients {
for id, client := range config.OIDC.Clients {
client.ID = id
if client.Name == "" {
client.Name = utils.Capitalize(client.ID)
}
service.clients[client.ClientID] = client
clients[client.ClientID] = client
}
// Load the client secrets from files if they exist
for id, client := range service.clients {
for id, client := range clients {
secret := utils.GetSecret(client.ClientSecret, client.ClientSecretFile)
if secret != "" {
client.ClientSecret = secret
}
client.ClientSecretFile = ""
service.clients[id] = client
service.log.App.Debug().Str("clientId", client.ClientID).Msg("Loaded OIDC client configuration")
clients[id] = client
log.App.Debug().Str("clientId", client.ClientID).Msg("Loaded OIDC client configuration")
}
// Initialize the service
service := &OIDCService{
log: log,
config: config,
runtime: runtime,
queries: queries,
context: ctx,
clients: clients,
privateKey: privateKey,
publicKey: publicKey,
issuer: issuer,
}
// Start cleanup routine
service.wg.Go(service.cleanupRoutine)
wg.Go(service.cleanupRoutine)
return nil
return service, nil
}
func (service *OIDCService) GetIssuer() string {
+26 -7
View File
@@ -1,7 +1,9 @@
package service_test
import (
"context"
"encoding/json"
"sync"
"testing"
"github.com/stretchr/testify/assert"
@@ -10,6 +12,7 @@ import (
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
)
func newTestUser() repository.OidcUserinfo {
@@ -48,13 +51,29 @@ func newTestUser() repository.OidcUserinfo {
func TestCompileUserinfo(t *testing.T) {
dir := t.TempDir()
svc := service.NewOIDCService(service.OIDCServiceConfig{
PrivateKeyPath: dir + "/key.pem",
PublicKeyPath: dir + "/key.pub",
Issuer: "https://tinyauth.example.com",
SessionExpiry: 3600,
}, nil)
require.NoError(t, svc.Init())
cfg := model.Config{
OIDC: model.OIDCConfig{
PrivateKeyPath: dir + "/key.pem",
PublicKeyPath: dir + "/key.pub",
},
Auth: model.AuthConfig{
SessionExpiry: 3600,
},
}
runtime := model.RuntimeConfig{
AppURL: "https://tinyauth.example.com",
}
log := logger.NewLogger().WithTestConfig()
log.Init()
ctx := context.TODO()
wg := &sync.WaitGroup{}
svc, err := service.NewOIDCService(log, cfg, runtime, nil, ctx, wg)
require.NoError(t, err)
type testCase struct {
description string
+1 -1
View File
@@ -33,7 +33,7 @@ func NewLogger() *Logger {
App: model.LogStreamConfig{
Enabled: true,
},
// No reason to enabled audit by default since it will be surpressed by the log level
// No reason to enabled audit by default since it will be suppressed by the log level
},
},
}
+2 -2
View File
@@ -159,10 +159,10 @@ func TestLogger(t *testing.T) {
l.App.Info().Msg("test")
l.AuditLoginFailure("test", "test", "test", "test")
l.AuditLoginFailure("test_nop", "test_nop", "test_nop", "test_nop")
assert.NotEmpty(t, buf.String())
assert.Equal(t, 81, buf.Len()) // it's the length of the test log entry
assert.NotContains(t, "test_nop", buf.String())
},
},
}