diff --git a/internal/service/access_control_rules_test.go b/internal/service/access_controls_rules_test.go similarity index 100% rename from internal/service/access_control_rules_test.go rename to internal/service/access_controls_rules_test.go diff --git a/internal/service/access_controls_service_test.go b/internal/service/access_controls_service_test.go new file mode 100644 index 00000000..e3d32eb6 --- /dev/null +++ b/internal/service/access_controls_service_test.go @@ -0,0 +1,199 @@ +package service + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/tinyauthapp/tinyauth/internal/model" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" +) + +type mockLabelProvider struct { + getLabelsFn func(appDomain string) (*model.App, error) + calledWith string + callCount int +} + +func (m *mockLabelProvider) GetLabels(appDomain string) (*model.App, error) { + m.calledWith = appDomain + m.callCount++ + if m.getLabelsFn != nil { + return m.getLabelsFn(appDomain) + } + return nil, nil +} + +func TestLookupStaticACLs(t *testing.T) { + log := logger.NewLogger().WithTestConfig() + log.Init() + + tests := []struct { + name string + apps map[string]model.App + domain string + expectNil bool + expectedDomain string + }{ + { + name: "returns nil when no apps are configured", + apps: nil, + domain: "foo.example.com", + expectNil: true, + }, + { + name: "returns nil when no app matches", + apps: map[string]model.App{ + "foo": {Config: model.AppConfig{Domain: "foo.example.com"}}, + }, + domain: "bar.example.com", + expectNil: true, + }, + { + name: "matches by exact domain", + apps: map[string]model.App{ + "foo": {Config: model.AppConfig{Domain: "foo.example.com"}}, + }, + domain: "foo.example.com", + expectedDomain: "foo.example.com", + }, + { + name: "matches by app name when domain does not match any app", + apps: map[string]model.App{ + "foo": {Config: model.AppConfig{Domain: "configured.example.com"}}, + }, + domain: "foo.example.com", + expectedDomain: "configured.example.com", + }, + { + name: "matches by app name for nested subdomains", + apps: map[string]model.App{ + "foo": {Config: model.AppConfig{Domain: "configured.example.com"}}, + }, + domain: "foo.sub.example.com", + expectedDomain: "configured.example.com", + }, + { + name: "selects the app matching by domain among multiple apps", + apps: map[string]model.App{ + "unrelated": {Config: model.AppConfig{Domain: "other.example.com"}}, + "target": {Config: model.AppConfig{Domain: "foo.example.com"}}, + }, + domain: "foo.example.com", + expectedDomain: "foo.example.com", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + svc := NewAccessControlsService(log, model.Config{Apps: tt.apps}, nil) + got := svc.lookupStaticACLs(tt.domain) + if tt.expectNil { + assert.Nil(t, got) + return + } + require.NotNil(t, got) + assert.Equal(t, tt.expectedDomain, got.Config.Domain) + }) + } +} + +func TestGetAccessControls(t *testing.T) { + log := logger.NewLogger().WithTestConfig() + log.Init() + + t.Run("returns static ACLs when domain matches", func(t *testing.T) { + config := model.Config{ + Apps: map[string]model.App{ + "foo": { + Config: model.AppConfig{Domain: "foo.example.com"}, + Users: model.AppUsers{Allow: "alice"}, + }, + }, + } + svc := NewAccessControlsService(log, config, nil) + + got, err := svc.GetAccessControls("foo.example.com") + + require.NoError(t, err) + require.NotNil(t, got) + assert.Equal(t, "foo.example.com", got.Config.Domain) + assert.Equal(t, "alice", got.Users.Allow) + }) + + t.Run("returns nil when no static match and no label provider", func(t *testing.T) { + svc := NewAccessControlsService(log, model.Config{}, nil) + + got, err := svc.GetAccessControls("unknown.example.com") + + require.NoError(t, err) + assert.Nil(t, got) + }) + + t.Run("returns nil when label provider pointer wraps a nil interface", func(t *testing.T) { + var provider LabelProvider + svc := NewAccessControlsService(log, model.Config{}, &provider) + + got, err := svc.GetAccessControls("unknown.example.com") + + require.NoError(t, err) + assert.Nil(t, got) + }) + + t.Run("falls back to label provider when no static match", func(t *testing.T) { + expected := &model.App{ + Config: model.AppConfig{Domain: "dynamic.example.com"}, + Users: model.AppUsers{Allow: "bob"}, + } + mock := &mockLabelProvider{ + getLabelsFn: func(appDomain string) (*model.App, error) { + return expected, nil + }, + } + var provider LabelProvider = mock + svc := NewAccessControlsService(log, model.Config{}, &provider) + + got, err := svc.GetAccessControls("dynamic.example.com") + + require.NoError(t, err) + assert.Same(t, expected, got) + assert.Equal(t, "dynamic.example.com", mock.calledWith) + assert.Equal(t, 1, mock.callCount) + }) + + t.Run("does not call label provider when static match found", func(t *testing.T) { + mock := &mockLabelProvider{} + var provider LabelProvider = mock + config := model.Config{ + Apps: map[string]model.App{ + "foo": {Config: model.AppConfig{Domain: "foo.example.com"}}, + }, + } + svc := NewAccessControlsService(log, config, &provider) + + got, err := svc.GetAccessControls("foo.example.com") + + require.NoError(t, err) + require.NotNil(t, got) + assert.Equal(t, "foo.example.com", got.Config.Domain) + assert.Equal(t, 0, mock.callCount) + }) + + t.Run("propagates label provider errors", func(t *testing.T) { + providerErr := errors.New("provider boom") + mock := &mockLabelProvider{ + getLabelsFn: func(appDomain string) (*model.App, error) { + return nil, providerErr + }, + } + var provider LabelProvider = mock + svc := NewAccessControlsService(log, model.Config{}, &provider) + + got, err := svc.GetAccessControls("dynamic.example.com") + + assert.Nil(t, got) + assert.ErrorIs(t, err, providerErr) + assert.Equal(t, 1, mock.callCount) + }) +} diff --git a/internal/service/policy_engine.go b/internal/service/policy_engine.go index 870ce4a4..4250d8a0 100644 --- a/internal/service/policy_engine.go +++ b/internal/service/policy_engine.go @@ -61,6 +61,7 @@ func NewPolicyEngine(config model.Config, log *logger.Logger) (*PolicyEngine, er } func (engine *PolicyEngine) RegisterRule(name RuleName, rule Rule) { + engine.log.App.Debug().Str("rule", string(name)).Msg("Registering ACL rule in policy engine") engine.rules[name] = rule } @@ -99,3 +100,11 @@ func (engine *PolicyEngine) Evaluate(name RuleName, ctx *ACLContext) bool { return access } + +func (engine *PolicyEngine) Policy() Policy { + return engine.policy +} + +func (engine *PolicyEngine) Rules() map[RuleName]Rule { + return engine.rules +} diff --git a/internal/service/policy_engine_test.go b/internal/service/policy_engine_test.go new file mode 100644 index 00000000..d1ef4796 --- /dev/null +++ b/internal/service/policy_engine_test.go @@ -0,0 +1,93 @@ +package service_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/tinyauthapp/tinyauth/internal/service" + "github.com/tinyauthapp/tinyauth/internal/test" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" +) + +// Create test rule +type TestRule struct{} + +func (rule *TestRule) Evaluate(ctx *service.ACLContext) service.Effect { + switch ctx.Path { + case "/allowed": + return service.EffectAllow + case "/denied": + return service.EffectDeny + default: + return service.EffectAbstain + } +} + +func TestPolicyEngine(t *testing.T) { + log := logger.NewLogger().WithTestConfig() + log.Init() + + cfg, _ := test.CreateTestConfigs(t) + + testRule := &TestRule{} + + // Engine should fail with invalid policy + cfg.Auth.ACLs.Policy = "invalid_policy" + _, err := service.NewPolicyEngine(cfg, log) + assert.Error(t, err) + + // Engine should initialize with 'allow' policy + cfg.Auth.ACLs.Policy = string(service.PolicyAllow) + engine, err := service.NewPolicyEngine(cfg, log) + assert.NoError(t, err) + assert.Equal(t, service.PolicyAllow, engine.Policy()) + + // Engine should initialize with 'deny' policy + cfg.Auth.ACLs.Policy = string(service.PolicyDeny) + engine, err = service.NewPolicyEngine(cfg, log) + assert.NoError(t, err) + assert.Equal(t, service.PolicyDeny, engine.Policy()) + + // Engine should allow adding rules + engine, err = service.NewPolicyEngine(cfg, log) + assert.NoError(t, err) + engine.RegisterRule("test-rule", testRule) + _, ok := engine.Rules()["test-rule"] + assert.True(t, ok) + + // Begin allow policy tests + cfg.Auth.ACLs.Policy = string(service.PolicyAllow) + engine, err = service.NewPolicyEngine(cfg, log) + assert.NoError(t, err) + engine.RegisterRule("test-rule", testRule) + + // With allow policy, if rule allows, access should be allowed + ctx := &service.ACLContext{Path: "/allowed"} + assert.Equal(t, true, engine.Evaluate("test-rule", ctx)) + + // With allow policy, if rule denies, access should be denied + ctx.Path = "/denied" + assert.Equal(t, false, engine.Evaluate("test-rule", ctx)) + + // With allow policy, if rule abstains, access should be allowed (default) + ctx.Path = "/abstain" + assert.Equal(t, true, engine.Evaluate("test-rule", ctx)) + + // Begin deny policy tests + cfg.Auth.ACLs.Policy = string(service.PolicyDeny) + engine, err = service.NewPolicyEngine(cfg, log) + assert.NoError(t, err) + engine.RegisterRule("test-rule", testRule) + + // With deny policy, if rule allows, access should be allowed + ctx.Path = "/allowed" + assert.Equal(t, true, engine.Evaluate("test-rule", ctx)) + + // With deny policy, if rule denies, access should be denied + ctx.Path = "/denied" + assert.Equal(t, false, engine.Evaluate("test-rule", ctx)) + + // With deny policy, if rule abstains, access should be denied (default) + ctx.Path = "/abstain" + assert.Equal(t, false, engine.Evaluate("test-rule", ctx)) +}