diff --git a/internal/controller/context_controller_test.go b/internal/controller/context_controller_test.go index cf879645..4d729911 100644 --- a/internal/controller/context_controller_test.go +++ b/internal/controller/context_controller_test.go @@ -121,7 +121,12 @@ func TestContextController(t *testing.T) { group := router.Group("/api") gin.SetMode(gin.TestMode) - controller.NewContextController(log, cfg, runtime, group) + controller.NewContextController(controller.ContextControllerInput{ + Log: log, + Config: &cfg, + Runtime: &runtime, + RouterGroup: group, + }) recorder := httptest.NewRecorder() diff --git a/internal/controller/health_controller_test.go b/internal/controller/health_controller_test.go index 7576d518..9517a0d8 100644 --- a/internal/controller/health_controller_test.go +++ b/internal/controller/health_controller_test.go @@ -55,7 +55,9 @@ func TestHealthController(t *testing.T) { group := router.Group("/api") gin.SetMode(gin.TestMode) - controller.NewHealthController(group) + controller.NewHealthController(controller.HealthControllerInput{ + RouterGroup: group, + }) recorder := httptest.NewRecorder() diff --git a/internal/controller/oidc_controller_test.go b/internal/controller/oidc_controller_test.go index d4a07baa..d8bbef17 100644 --- a/internal/controller/oidc_controller_test.go +++ b/internal/controller/oidc_controller_test.go @@ -35,7 +35,13 @@ func TestOIDCController(t *testing.T) { store := memory.New() - oidcService, err := service.NewOIDCService(log, cfg, runtime, store, dg) + oidcService, err := service.NewOIDCService(service.OIDCServiceInput{ + Log: log, + Config: &cfg, + Runtime: &runtime, + Queries: store, + Ding: dg, + }) require.NoError(t, err) // Middleware that injects an authenticated local user into the gin context, @@ -831,7 +837,13 @@ func TestOIDCController(t *testing.T) { svc = nil } - controller.NewOIDCController(log, svc, runtime, group, &router.RouterGroup) + controller.NewOIDCController(controller.OIDCControllerInput{ + Log: log, + OIDCService: svc, + RuntimeConfig: &runtime, + RouterGroup: group, + MainRouter: &router.RouterGroup, + }) recorder := httptest.NewRecorder() diff --git a/internal/controller/proxy_controller_test.go b/internal/controller/proxy_controller_test.go index c6a358b4..79e3e198 100644 --- a/internal/controller/proxy_controller_test.go +++ b/internal/controller/proxy_controller_test.go @@ -369,10 +369,21 @@ func TestProxyController(t *testing.T) { ctx := context.TODO() dg := ding.New(ctx) - broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx) - aclsService := service.NewAccessControlsService(log, cfg, nil) + broker := service.NewOAuthBrokerService(service.OAuthBrokerServiceInput{ + Log: log, + Runtime: &runtime, + Ctx: ctx, + }) + aclsService := service.NewAccessControlsService(service.AccessControlServiceInput{ + Log: log, + Config: &cfg, + LabelProvider: nil, + }) - policyEngine, err := service.NewPolicyEngine(cfg, log) + policyEngine, err := service.NewPolicyEngine(service.PolicyEngineInput{ + Log: log, + Config: &cfg, + }) require.NoError(t, err) policyEngine.RegisterRule(service.RuleUserAllowed, &service.UserAllowedRule{ @@ -395,7 +406,18 @@ func TestProxyController(t *testing.T) { Log: log, }) - authService := service.NewAuthService(log, cfg, runtime, ctx, dg, nil, store, broker, nil, policyEngine) + authService := service.NewAuthService(service.AuthServiceInput{ + Log: log, + Config: &cfg, + Runtime: &runtime, + Ctx: ctx, + Ding: dg, + LDAP: nil, + Queries: store, + OAuthBroker: broker, + Tailscale: nil, + PolicyEngine: policyEngine, + }) for _, test := range tests { t.Run(test.description, func(t *testing.T) { @@ -410,7 +432,14 @@ func TestProxyController(t *testing.T) { recorder := httptest.NewRecorder() - controller.NewProxyController(log, runtime, group, aclsService, authService, policyEngine) + controller.NewProxyController(controller.ProxyControllerInput{ + Log: log, + RuntimeConfig: &runtime, + RouterGroup: group, + ACLsService: aclsService, + AuthService: authService, + PolicyEngine: policyEngine, + }) test.run(t, router, recorder) }) diff --git a/internal/controller/resources_controller_test.go b/internal/controller/resources_controller_test.go index 68ce463d..ef225ec8 100644 --- a/internal/controller/resources_controller_test.go +++ b/internal/controller/resources_controller_test.go @@ -69,7 +69,10 @@ func TestResourcesController(t *testing.T) { group := router.Group("/") gin.SetMode(gin.TestMode) - controller.NewResourcesController(cfg, group) + controller.NewResourcesController(controller.ResourcesControllerInput{ + RouterGroup: group, + Config: &cfg, + }) recorder := httptest.NewRecorder() test.run(t, router, recorder) diff --git a/internal/controller/user_controller_test.go b/internal/controller/user_controller_test.go index f3c0bed2..18b8772f 100644 --- a/internal/controller/user_controller_test.go +++ b/internal/controller/user_controller_test.go @@ -414,11 +414,29 @@ func TestUserController(t *testing.T) { ctx := context.TODO() dg := ding.New(ctx) - policyEngine, err := service.NewPolicyEngine(cfg, log) + policyEngine, err := service.NewPolicyEngine(service.PolicyEngineInput{ + Log: log, + Config: &cfg, + }) require.NoError(t, err) - broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx) - authService := service.NewAuthService(log, cfg, runtime, ctx, dg, nil, store, broker, nil, policyEngine) + broker := service.NewOAuthBrokerService(service.OAuthBrokerServiceInput{ + Log: log, + Runtime: &runtime, + Ctx: ctx, + }) + authService := service.NewAuthService(service.AuthServiceInput{ + Log: log, + Config: &cfg, + Runtime: &runtime, + Ctx: ctx, + Ding: dg, + LDAP: nil, + Queries: store, + OAuthBroker: broker, + Tailscale: nil, + PolicyEngine: policyEngine, + }) beforeEach := func() { // Clear failed login attempts before each test @@ -437,7 +455,12 @@ func TestUserController(t *testing.T) { group := router.Group("/api") gin.SetMode(gin.TestMode) - controller.NewUserController(log, runtime, group, authService) + controller.NewUserController(controller.UserControllerInput{ + Log: log, + RuntimeConfig: &runtime, + RouterGroup: group, + AuthService: authService, + }) recorder := httptest.NewRecorder() diff --git a/internal/controller/well_known_controller_test.go b/internal/controller/well_known_controller_test.go index f4685723..aa421603 100644 --- a/internal/controller/well_known_controller_test.go +++ b/internal/controller/well_known_controller_test.go @@ -93,7 +93,13 @@ func TestWellKnownController(t *testing.T) { store := memory.New() - oidcService, err := service.NewOIDCService(log, cfg, runtime, store, dg) + oidcService, err := service.NewOIDCService(service.OIDCServiceInput{ + Log: log, + Config: &cfg, + Runtime: &runtime, + Queries: store, + Ding: dg, + }) require.NoError(t, err) for _, test := range tests { @@ -103,7 +109,10 @@ func TestWellKnownController(t *testing.T) { recorder := httptest.NewRecorder() - controller.NewWellKnownController(oidcService, &router.RouterGroup) + controller.NewWellKnownController(controller.WellKnownControllerInput{ + OIDCService: oidcService, + RouterGroup: &router.RouterGroup, + }) test.run(t, router, recorder) }) diff --git a/internal/middleware/context_middleware_test.go b/internal/middleware/context_middleware_test.go index 50ededdb..7468dec0 100644 --- a/internal/middleware/context_middleware_test.go +++ b/internal/middleware/context_middleware_test.go @@ -254,13 +254,37 @@ func TestContextMiddleware(t *testing.T) { store := memory.New() - policyEngine, err := service.NewPolicyEngine(cfg, log) + policyEngine, err := service.NewPolicyEngine(service.PolicyEngineInput{ + Log: log, + Config: &cfg, + }) require.NoError(t, err) - broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx) - authService := service.NewAuthService(log, cfg, runtime, ctx, dg, nil, store, broker, nil, policyEngine) + broker := service.NewOAuthBrokerService(service.OAuthBrokerServiceInput{ + Log: log, + Runtime: &runtime, + Ctx: ctx, + }) + authService := service.NewAuthService(service.AuthServiceInput{ + Log: log, + Config: &cfg, + Runtime: &runtime, + Ctx: ctx, + Ding: dg, + LDAP: nil, + Queries: store, + OAuthBroker: broker, + Tailscale: nil, + PolicyEngine: policyEngine, + }) - contextMiddleware := middleware.NewContextMiddleware(log, runtime, authService, broker, nil) + contextMiddleware := middleware.NewContextMiddleware(middleware.ContextMiddlewareInput{ + Log: log, + RuntimeConfig: &runtime, + AuthService: authService, + BrokerService: broker, + TailscaleService: nil, + }) for _, test := range tests { authService.ClearLoginAttempts() diff --git a/internal/service/access_controls_service_test.go b/internal/service/access_controls_service_test.go index e3d32eb6..f4f4d24c 100644 --- a/internal/service/access_controls_service_test.go +++ b/internal/service/access_controls_service_test.go @@ -87,7 +87,11 @@ func TestLookupStaticACLs(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - svc := NewAccessControlsService(log, model.Config{Apps: tt.apps}, nil) + svc := NewAccessControlsService(AccessControlServiceInput{ + Log: log, + Config: &model.Config{Apps: tt.apps}, + LabelProvider: nil, + }) got := svc.lookupStaticACLs(tt.domain) if tt.expectNil { assert.Nil(t, got) @@ -112,7 +116,11 @@ func TestGetAccessControls(t *testing.T) { }, }, } - svc := NewAccessControlsService(log, config, nil) + svc := NewAccessControlsService(AccessControlServiceInput{ + Log: log, + Config: &config, + LabelProvider: nil, + }) got, err := svc.GetAccessControls("foo.example.com") @@ -123,7 +131,11 @@ func TestGetAccessControls(t *testing.T) { }) t.Run("returns nil when no static match and no label provider", func(t *testing.T) { - svc := NewAccessControlsService(log, model.Config{}, nil) + svc := NewAccessControlsService(AccessControlServiceInput{ + Log: log, + Config: &model.Config{}, + LabelProvider: nil, + }) got, err := svc.GetAccessControls("unknown.example.com") @@ -133,7 +145,11 @@ func TestGetAccessControls(t *testing.T) { t.Run("returns nil when label provider pointer wraps a nil interface", func(t *testing.T) { var provider LabelProvider - svc := NewAccessControlsService(log, model.Config{}, &provider) + svc := NewAccessControlsService(AccessControlServiceInput{ + Log: log, + Config: &model.Config{}, + LabelProvider: provider, // nil provider + }) got, err := svc.GetAccessControls("unknown.example.com") @@ -152,7 +168,11 @@ func TestGetAccessControls(t *testing.T) { }, } var provider LabelProvider = mock - svc := NewAccessControlsService(log, model.Config{}, &provider) + svc := NewAccessControlsService(AccessControlServiceInput{ + Log: log, + Config: &model.Config{}, + LabelProvider: provider, + }) got, err := svc.GetAccessControls("dynamic.example.com") @@ -170,7 +190,11 @@ func TestGetAccessControls(t *testing.T) { "foo": {Config: model.AppConfig{Domain: "foo.example.com"}}, }, } - svc := NewAccessControlsService(log, config, &provider) + svc := NewAccessControlsService(AccessControlServiceInput{ + Log: log, + Config: &config, + LabelProvider: provider, + }) got, err := svc.GetAccessControls("foo.example.com") @@ -188,7 +212,11 @@ func TestGetAccessControls(t *testing.T) { }, } var provider LabelProvider = mock - svc := NewAccessControlsService(log, model.Config{}, &provider) + svc := NewAccessControlsService(AccessControlServiceInput{ + Log: log, + Config: &model.Config{}, + LabelProvider: provider, + }) got, err := svc.GetAccessControls("dynamic.example.com") diff --git a/internal/service/auth_service_test.go b/internal/service/auth_service_test.go index 3000adcc..653db8c4 100644 --- a/internal/service/auth_service_test.go +++ b/internal/service/auth_service_test.go @@ -14,7 +14,7 @@ func TestIsEmailWhitelistedUsesProviderSpecificList(t *testing.T) { auth := &AuthService{ log: log, - runtime: model.RuntimeConfig{ + runtime: &model.RuntimeConfig{ OAuthWhitelist: []string{"global@example.com"}, OAuthProviders: map[string]model.OAuthServiceConfig{ "github": { diff --git a/internal/service/oidc_service_test.go b/internal/service/oidc_service_test.go index 48078a9d..5197d7c7 100644 --- a/internal/service/oidc_service_test.go +++ b/internal/service/oidc_service_test.go @@ -9,6 +9,7 @@ import ( "github.com/stretchr/testify/require" "github.com/tinyauthapp/tinyauth/internal/model" + "github.com/tinyauthapp/tinyauth/internal/repository/memory" "github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/utils/logger" ) @@ -67,7 +68,15 @@ func TestCompileUserinfo(t *testing.T) { ctx := context.TODO() dg := ding.New(ctx) - svc, err := service.NewOIDCService(log, cfg, runtime, nil, dg) + store := memory.New() + + svc, err := service.NewOIDCService(service.OIDCServiceInput{ + Log: log, + Config: &cfg, + Runtime: &runtime, + Queries: store, + Ding: dg, + }) require.NoError(t, err) type testCase struct { diff --git a/internal/service/policy_engine_test.go b/internal/service/policy_engine_test.go index d1ef4796..1c6120a0 100644 --- a/internal/service/policy_engine_test.go +++ b/internal/service/policy_engine_test.go @@ -33,23 +33,35 @@ func TestPolicyEngine(t *testing.T) { // Engine should fail with invalid policy cfg.Auth.ACLs.Policy = "invalid_policy" - _, err := service.NewPolicyEngine(cfg, log) + _, err := service.NewPolicyEngine(service.PolicyEngineInput{ + Log: log, + Config: &cfg, + }) assert.Error(t, err) // Engine should initialize with 'allow' policy cfg.Auth.ACLs.Policy = string(service.PolicyAllow) - engine, err := service.NewPolicyEngine(cfg, log) + engine, err := service.NewPolicyEngine(service.PolicyEngineInput{ + Log: log, + Config: &cfg, + }) assert.NoError(t, err) assert.Equal(t, service.PolicyAllow, engine.Policy()) // Engine should initialize with 'deny' policy cfg.Auth.ACLs.Policy = string(service.PolicyDeny) - engine, err = service.NewPolicyEngine(cfg, log) + engine, err = service.NewPolicyEngine(service.PolicyEngineInput{ + Log: log, + Config: &cfg, + }) assert.NoError(t, err) assert.Equal(t, service.PolicyDeny, engine.Policy()) // Engine should allow adding rules - engine, err = service.NewPolicyEngine(cfg, log) + engine, err = service.NewPolicyEngine(service.PolicyEngineInput{ + Log: log, + Config: &cfg, + }) assert.NoError(t, err) engine.RegisterRule("test-rule", testRule) _, ok := engine.Rules()["test-rule"] @@ -57,7 +69,10 @@ func TestPolicyEngine(t *testing.T) { // Begin allow policy tests cfg.Auth.ACLs.Policy = string(service.PolicyAllow) - engine, err = service.NewPolicyEngine(cfg, log) + engine, err = service.NewPolicyEngine(service.PolicyEngineInput{ + Log: log, + Config: &cfg, + }) assert.NoError(t, err) engine.RegisterRule("test-rule", testRule) @@ -75,7 +90,10 @@ func TestPolicyEngine(t *testing.T) { // Begin deny policy tests cfg.Auth.ACLs.Policy = string(service.PolicyDeny) - engine, err = service.NewPolicyEngine(cfg, log) + engine, err = service.NewPolicyEngine(service.PolicyEngineInput{ + Log: log, + Config: &cfg, + }) assert.NoError(t, err) engine.RegisterRule("test-rule", testRule)