Merge branch 'main' into feat/tailscale

This commit is contained in:
Stavros
2026-05-19 18:52:56 +03:00
96 changed files with 8936 additions and 1760 deletions
+249
View File
@@ -0,0 +1,249 @@
package service
import (
"regexp"
"strings"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
)
type RuleName string
const (
RuleUserAllowed RuleName = "rule-user-allowed"
RuleOAuthGroup RuleName = "rule-oauth-group"
RuleLDAPGroup RuleName = "rule-ldap-group"
RuleAuthEnabled RuleName = "rule-auth-enabled"
RuleIPAllowed RuleName = "rule-ip-allowed"
RuleIPBypassed RuleName = "rule-ip-bypassed"
)
type UserAllowedRule struct {
Log *logger.Logger
}
func (rule *UserAllowedRule) Evaluate(ctx *ACLContext) Effect {
if ctx.ACLs == nil || ctx.UserContext == nil {
return EffectAbstain
}
if ctx.UserContext.Provider == model.ProviderOAuth {
rule.Log.App.Debug().Msg("User is an OAuth user, checking OAuth whitelist")
match, err := utils.CheckFilter(ctx.ACLs.OAuth.Whitelist, ctx.UserContext.OAuth.Email)
if err != nil {
rule.Log.App.Warn().Err(err).Str("item", ctx.UserContext.OAuth.Email).Msg("Invalid entry in OAuth whitelist")
return EffectAbstain
}
if match {
rule.Log.App.Debug().Str("email", ctx.UserContext.OAuth.Email).Msg("User is in OAuth whitelist, allowing access")
return EffectAllow
}
return EffectDeny
}
if ctx.ACLs.Users.Block != "" {
rule.Log.App.Debug().Msg("Checking users block list")
match, err := utils.CheckFilter(ctx.ACLs.Users.Block, ctx.UserContext.GetUsername())
if err != nil {
rule.Log.App.Warn().Err(err).Str("item", ctx.UserContext.GetUsername()).Msg("Invalid entry in users block list")
return EffectAbstain
}
if match {
rule.Log.App.Debug().Str("username", ctx.UserContext.GetUsername()).Msg("User is in users block list, denying access")
return EffectDeny
}
return EffectAllow
}
rule.Log.App.Debug().Msg("Checking users allow list")
match, err := utils.CheckFilter(ctx.ACLs.Users.Allow, ctx.UserContext.GetUsername())
if err != nil {
rule.Log.App.Warn().Err(err).Str("item", ctx.UserContext.GetUsername()).Msg("Invalid entry in users allow list")
return EffectAbstain
}
if match {
rule.Log.App.Debug().Str("username", ctx.UserContext.GetUsername()).Msg("User is in users allow list, allowing access")
return EffectAllow
}
rule.Log.App.Debug().Str("username", ctx.UserContext.GetUsername()).Msg("User is not in users allow list, denying access")
return EffectDeny
}
type OAuthGroupRule struct {
Log *logger.Logger
}
func (rule *OAuthGroupRule) Evaluate(ctx *ACLContext) Effect {
if ctx.ACLs == nil || ctx.UserContext == nil {
return EffectAbstain
}
if !ctx.UserContext.IsOAuth() {
rule.Log.App.Debug().Msg("User is not an OAuth user, skipping OAuth group check")
return EffectAbstain
}
if _, ok := model.OverrideProviders[ctx.UserContext.OAuth.ID]; ok {
rule.Log.App.Debug().Str("provider", ctx.UserContext.OAuth.ID).Msg("Provider override detected, skipping group check")
return EffectAllow
}
for _, group := range ctx.UserContext.OAuth.Groups {
match, err := utils.CheckFilter(ctx.ACLs.OAuth.Groups, strings.TrimSpace(group))
if err != nil {
return EffectAbstain
}
if match {
rule.Log.App.Trace().Str("group", group).Str("required", ctx.ACLs.OAuth.Groups).Msg("User group matched, allowing access")
return EffectAllow
}
}
rule.Log.App.Debug().Msg("No groups matched")
return EffectDeny
}
type LDAPGroupRule struct {
Log *logger.Logger
}
func (rule *LDAPGroupRule) Evaluate(ctx *ACLContext) Effect {
if ctx == nil || ctx.UserContext == nil {
return EffectAbstain
}
if !ctx.UserContext.IsLDAP() {
rule.Log.App.Debug().Msg("User is not an LDAP user, skipping LDAP group check")
return EffectAbstain
}
for _, group := range ctx.UserContext.LDAP.Groups {
match, err := utils.CheckFilter(ctx.ACLs.LDAP.Groups, strings.TrimSpace(group))
if err != nil {
return EffectAbstain
}
if match {
rule.Log.App.Trace().Str("group", group).Str("required", ctx.ACLs.LDAP.Groups).Msg("User group matched, allowing access")
return EffectAllow
}
}
rule.Log.App.Debug().Msg("No groups matched")
return EffectDeny
}
type AuthEnabledRule struct {
Log *logger.Logger
}
func (rule *AuthEnabledRule) Evaluate(ctx *ACLContext) Effect {
if ctx.ACLs == nil {
return EffectDeny
}
if ctx.ACLs.Path.Block != "" {
regex, err := regexp.Compile(ctx.ACLs.Path.Block)
if err != nil {
rule.Log.App.Error().Err(err).Msg("Failed to compile block regex")
return EffectDeny
}
if !regex.MatchString(ctx.Path) {
return EffectAllow
}
}
if ctx.ACLs.Path.Allow != "" {
regex, err := regexp.Compile(ctx.ACLs.Path.Allow)
if err != nil {
rule.Log.App.Error().Err(err).Msg("Failed to compile allow regex")
return EffectDeny
}
if regex.MatchString(ctx.Path) {
return EffectAllow
}
}
return EffectDeny
}
type IPAllowedRule struct {
Log *logger.Logger
Config model.Config
}
func (rule *IPAllowedRule) Evaluate(ctx *ACLContext) Effect {
if ctx.ACLs == nil {
return EffectAbstain
}
// Merge the global and app IP filter
blockedIps := append(ctx.ACLs.IP.Block, rule.Config.Auth.IP.Block...)
allowedIPs := append(ctx.ACLs.IP.Allow, rule.Config.Auth.IP.Allow...)
for _, blocked := range blockedIps {
match, err := utils.CheckIPFilter(blocked, ctx.IP.String())
if err != nil {
rule.Log.App.Warn().Err(err).Str("item", blocked).Msg("Invalid IP/CIDR in block list")
continue
}
if match {
rule.Log.App.Debug().Str("ip", ctx.IP.String()).Str("item", blocked).Msg("IP is in block list, denying access")
return EffectDeny
}
}
for _, allowed := range allowedIPs {
match, err := utils.CheckIPFilter(allowed, ctx.IP.String())
if err != nil {
rule.Log.App.Warn().Err(err).Str("item", allowed).Msg("Invalid IP/CIDR in allow list")
continue
}
if match {
rule.Log.App.Debug().Str("ip", ctx.IP.String()).Str("item", allowed).Msg("IP is in allow list, allowing access")
return EffectAllow
}
}
if len(allowedIPs) > 0 {
rule.Log.App.Debug().Str("ip", ctx.IP.String()).Msg("IP not in allow list, denying access")
return EffectDeny
}
rule.Log.App.Debug().Str("ip", ctx.IP.String()).Msg("IP not in block or allow list, allowing access")
return EffectAllow
}
type IPBypassedRule struct {
Log *logger.Logger
}
func (rule *IPBypassedRule) Evaluate(ctx *ACLContext) Effect {
if ctx.ACLs == nil {
return EffectDeny
}
for _, bypassed := range ctx.ACLs.IP.Bypass {
match, err := utils.CheckIPFilter(bypassed, ctx.IP.String())
if err != nil {
rule.Log.App.Warn().Err(err).Str("item", bypassed).Msg("Invalid IP/CIDR in bypass list")
continue
}
if match {
rule.Log.App.Debug().Str("ip", ctx.IP.String()).Str("item", bypassed).Msg("IP is in bypass list, skipping authentication")
return EffectAllow
}
}
rule.Log.App.Debug().Str("ip", ctx.IP.String()).Msg("IP not in bypass list, proceeding with authentication")
return EffectDeny
}
@@ -0,0 +1,732 @@
package service
import (
"net"
"testing"
"github.com/stretchr/testify/assert"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
)
func TestUserAllowedRule(t *testing.T) {
log := logger.NewLogger().WithTestConfig()
log.Init()
rule := &UserAllowedRule{Log: log}
tests := []struct {
name string
ctx *ACLContext
expected Effect
}{
{
name: "abstains when ACLs are nil",
ctx: &ACLContext{
ACLs: nil,
UserContext: &model.UserContext{
Provider: model.ProviderLocal,
Local: &model.LocalContext{
BaseContext: model.BaseContext{Username: "alice"},
},
},
},
expected: EffectAbstain,
},
{
name: "abstains when user context is nil",
ctx: &ACLContext{
ACLs: &model.App{
OAuth: model.AppOAuth{Whitelist: "alice"},
},
UserContext: nil,
},
expected: EffectAbstain,
},
{
name: "allows OAuth user when email matches whitelist",
ctx: &ACLContext{
ACLs: &model.App{
OAuth: model.AppOAuth{Whitelist: "allowed@example.com"},
},
UserContext: &model.UserContext{
Provider: model.ProviderOAuth,
OAuth: &model.OAuthContext{
BaseContext: model.BaseContext{
Username: "different-username",
Email: "allowed@example.com",
},
},
},
},
expected: EffectAllow,
},
{
name: "denies OAuth user when email does not match whitelist",
ctx: &ACLContext{
ACLs: &model.App{
OAuth: model.AppOAuth{Whitelist: "allowed@example.com"},
},
UserContext: &model.UserContext{
Provider: model.ProviderOAuth,
OAuth: &model.OAuthContext{
BaseContext: model.BaseContext{Email: "denied@example.com"},
},
},
},
expected: EffectDeny,
},
{
name: "abstains for OAuth user when whitelist filter is invalid",
ctx: &ACLContext{
ACLs: &model.App{
OAuth: model.AppOAuth{Whitelist: "/[/"},
},
UserContext: &model.UserContext{
Provider: model.ProviderOAuth,
OAuth: &model.OAuthContext{
BaseContext: model.BaseContext{Email: "allowed@example.com"},
},
},
},
expected: EffectAbstain,
},
{
name: "denies local user when username matches block list",
ctx: &ACLContext{
ACLs: &model.App{
Users: model.AppUsers{Block: "alice,bob"},
},
UserContext: &model.UserContext{
Provider: model.ProviderLocal,
Local: &model.LocalContext{
BaseContext: model.BaseContext{Username: "alice"},
},
},
},
expected: EffectDeny,
},
{
name: "allows local user when username does not match block list",
ctx: &ACLContext{
ACLs: &model.App{
Users: model.AppUsers{Block: "alice,bob"},
},
UserContext: &model.UserContext{
Provider: model.ProviderLocal,
Local: &model.LocalContext{
BaseContext: model.BaseContext{Username: "charlie"},
},
},
},
expected: EffectAllow,
},
{
name: "abstains when block list filter is invalid",
ctx: &ACLContext{
ACLs: &model.App{
Users: model.AppUsers{Block: "/[/"},
},
UserContext: &model.UserContext{
Provider: model.ProviderLocal,
Local: &model.LocalContext{
BaseContext: model.BaseContext{Username: "alice"},
},
},
},
expected: EffectAbstain,
},
{
name: "allows local user when username matches allow list",
ctx: &ACLContext{
ACLs: &model.App{
Users: model.AppUsers{Allow: "alice,bob"},
},
UserContext: &model.UserContext{
Provider: model.ProviderLocal,
Local: &model.LocalContext{
BaseContext: model.BaseContext{Username: "alice"},
},
},
},
expected: EffectAllow,
},
{
name: "denies local user when username does not match allow list",
ctx: &ACLContext{
ACLs: &model.App{
Users: model.AppUsers{Allow: "alice,bob"},
},
UserContext: &model.UserContext{
Provider: model.ProviderLocal,
Local: &model.LocalContext{
BaseContext: model.BaseContext{Username: "charlie"},
},
},
},
expected: EffectDeny,
},
{
name: "abstains when allow list filter is invalid",
ctx: &ACLContext{
ACLs: &model.App{
Users: model.AppUsers{Allow: "/[/"},
},
UserContext: &model.UserContext{
Provider: model.ProviderLocal,
Local: &model.LocalContext{
BaseContext: model.BaseContext{Username: "alice"},
},
},
},
expected: EffectAbstain,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equal(t, tt.expected, rule.Evaluate(tt.ctx))
})
}
}
func TestOAuthGroupRule(t *testing.T) {
log := logger.NewLogger().WithTestConfig()
log.Init()
rule := &OAuthGroupRule{Log: log}
tests := []struct {
name string
ctx *ACLContext
expected Effect
}{
{
name: "abstains when ACLs are nil",
ctx: &ACLContext{
ACLs: nil,
UserContext: &model.UserContext{
Provider: model.ProviderOAuth,
OAuth: &model.OAuthContext{
Groups: []string{"admins"},
},
},
},
expected: EffectAbstain,
},
{
name: "abstains when user context is nil",
ctx: &ACLContext{
ACLs: &model.App{
OAuth: model.AppOAuth{Whitelist: "alice"},
},
UserContext: nil,
},
expected: EffectAbstain,
},
{
name: "abstains when user is not OAuth",
ctx: &ACLContext{
ACLs: &model.App{
OAuth: model.AppOAuth{Groups: "admins"},
},
UserContext: &model.UserContext{
Provider: model.ProviderLocal,
Local: &model.LocalContext{
BaseContext: model.BaseContext{Username: "alice"},
},
},
},
expected: EffectAbstain,
},
{
name: "allows when provider is an override provider regardless of groups",
ctx: &ACLContext{
ACLs: &model.App{
OAuth: model.AppOAuth{Groups: "admins"},
},
UserContext: &model.UserContext{
Provider: model.ProviderOAuth,
OAuth: &model.OAuthContext{
ID: "google",
Groups: []string{"unrelated"},
},
},
},
expected: EffectAllow,
},
{
name: "allows OAuth user when a group matches",
ctx: &ACLContext{
ACLs: &model.App{
OAuth: model.AppOAuth{Groups: "admins,users"},
},
UserContext: &model.UserContext{
Provider: model.ProviderOAuth,
OAuth: &model.OAuthContext{
ID: "custom",
Groups: []string{"users"},
},
},
},
expected: EffectAllow,
},
{
name: "denies OAuth user when no group matches",
ctx: &ACLContext{
ACLs: &model.App{
OAuth: model.AppOAuth{Groups: "admins"},
},
UserContext: &model.UserContext{
Provider: model.ProviderOAuth,
OAuth: &model.OAuthContext{
ID: "custom",
Groups: []string{"users", "guests"},
},
},
},
expected: EffectDeny,
},
{
name: "denies OAuth user when user has no groups",
ctx: &ACLContext{
ACLs: &model.App{
OAuth: model.AppOAuth{Groups: "admins"},
},
UserContext: &model.UserContext{
Provider: model.ProviderOAuth,
OAuth: &model.OAuthContext{
ID: "custom",
Groups: nil,
},
},
},
expected: EffectDeny,
},
{
name: "abstains when groups filter is invalid",
ctx: &ACLContext{
ACLs: &model.App{
OAuth: model.AppOAuth{Groups: "/[/"},
},
UserContext: &model.UserContext{
Provider: model.ProviderOAuth,
OAuth: &model.OAuthContext{
ID: "custom",
Groups: []string{"admins"},
},
},
},
expected: EffectAbstain,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equal(t, tt.expected, rule.Evaluate(tt.ctx))
})
}
}
func TestLDAPGroupRule(t *testing.T) {
log := logger.NewLogger().WithTestConfig()
log.Init()
rule := &LDAPGroupRule{Log: log}
tests := []struct {
name string
ctx *ACLContext
expected Effect
}{
{
name: "abstains when context is nil",
ctx: nil,
expected: EffectAbstain,
},
{
name: "abstains when user context is nil",
ctx: &ACLContext{
ACLs: &model.App{
OAuth: model.AppOAuth{Whitelist: "alice"},
},
UserContext: nil,
},
expected: EffectAbstain,
},
{
name: "abstains when user is not LDAP",
ctx: &ACLContext{
ACLs: &model.App{
LDAP: model.AppLDAP{Groups: "admins"},
},
UserContext: &model.UserContext{
Provider: model.ProviderLocal,
Local: &model.LocalContext{
BaseContext: model.BaseContext{Username: "alice"},
},
},
},
expected: EffectAbstain,
},
{
name: "allows LDAP user when a group matches",
ctx: &ACLContext{
ACLs: &model.App{
LDAP: model.AppLDAP{Groups: "admins,users"},
},
UserContext: &model.UserContext{
Provider: model.ProviderLDAP,
LDAP: &model.LDAPContext{
Groups: []string{"users"},
},
},
},
expected: EffectAllow,
},
{
name: "denies LDAP user when no group matches",
ctx: &ACLContext{
ACLs: &model.App{
LDAP: model.AppLDAP{Groups: "admins"},
},
UserContext: &model.UserContext{
Provider: model.ProviderLDAP,
LDAP: &model.LDAPContext{
Groups: []string{"users", "guests"},
},
},
},
expected: EffectDeny,
},
{
name: "denies LDAP user when user has no groups",
ctx: &ACLContext{
ACLs: &model.App{
LDAP: model.AppLDAP{Groups: "admins"},
},
UserContext: &model.UserContext{
Provider: model.ProviderLDAP,
LDAP: &model.LDAPContext{
Groups: nil,
},
},
},
expected: EffectDeny,
},
{
name: "abstains when groups filter is invalid",
ctx: &ACLContext{
ACLs: &model.App{
LDAP: model.AppLDAP{Groups: "/[/"},
},
UserContext: &model.UserContext{
Provider: model.ProviderLDAP,
LDAP: &model.LDAPContext{
Groups: []string{"admins"},
},
},
},
expected: EffectAbstain,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equal(t, tt.expected, rule.Evaluate(tt.ctx))
})
}
}
func TestAuthEnabledRule(t *testing.T) {
log := logger.NewLogger().WithTestConfig()
log.Init()
rule := &AuthEnabledRule{Log: log}
tests := []struct {
name string
ctx *ACLContext
expected Effect
}{
{
name: "deny when ACLs are nil",
ctx: &ACLContext{
ACLs: nil,
Path: "/anything",
},
expected: EffectDeny,
},
{
name: "allows when path does not match block regex",
ctx: &ACLContext{
ACLs: &model.App{
Path: model.AppPath{Block: "^/admin"},
},
Path: "/public",
},
expected: EffectAllow,
},
{
name: "denies when path matches block regex and no allow regex",
ctx: &ACLContext{
ACLs: &model.App{
Path: model.AppPath{Block: "^/admin"},
},
Path: "/admin/users",
},
expected: EffectDeny,
},
{
name: "allows when path matches allow regex",
ctx: &ACLContext{
ACLs: &model.App{
Path: model.AppPath{Allow: "^/public"},
},
Path: "/public/index",
},
expected: EffectAllow,
},
{
name: "denies when path does not match allow regex",
ctx: &ACLContext{
ACLs: &model.App{
Path: model.AppPath{Allow: "^/public"},
},
Path: "/private",
},
expected: EffectDeny,
},
{
name: "allows when blocked path is also explicitly allowed",
ctx: &ACLContext{
ACLs: &model.App{
Path: model.AppPath{
Block: "^/admin",
Allow: "^/admin/public",
},
},
Path: "/admin/public/page",
},
expected: EffectAllow,
},
{
name: "denies when block regex fails to compile",
ctx: &ACLContext{
ACLs: &model.App{
Path: model.AppPath{Block: "[invalid"},
},
Path: "/anything",
},
expected: EffectDeny,
},
{
name: "denies when allow regex fails to compile",
ctx: &ACLContext{
ACLs: &model.App{
Path: model.AppPath{Allow: "[invalid"},
},
Path: "/anything",
},
expected: EffectDeny,
},
{
name: "denies when no path rules are configured",
ctx: &ACLContext{
ACLs: &model.App{},
Path: "/anything",
},
expected: EffectDeny,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equal(t, tt.expected, rule.Evaluate(tt.ctx))
})
}
}
func TestIPAllowedRule(t *testing.T) {
log := logger.NewLogger().WithTestConfig()
log.Init()
tests := []struct {
name string
config model.Config
ctx *ACLContext
expected Effect
}{
{
name: "abstains when ACLs are nil",
ctx: &ACLContext{
ACLs: nil,
IP: net.ParseIP("10.0.0.1"),
},
expected: EffectAbstain,
},
{
name: "denies when IP matches app block list",
ctx: &ACLContext{
ACLs: &model.App{
IP: model.AppIP{Block: []string{"10.0.0.1"}},
},
IP: net.ParseIP("10.0.0.1"),
},
expected: EffectDeny,
},
{
name: "denies when IP matches global block list",
config: model.Config{
Auth: model.AuthConfig{
IP: model.IPConfig{Block: []string{"10.0.0.0/24"}},
},
},
ctx: &ACLContext{
ACLs: &model.App{},
IP: net.ParseIP("10.0.0.5"),
},
expected: EffectDeny,
},
{
name: "allows when IP matches app allow list",
ctx: &ACLContext{
ACLs: &model.App{
IP: model.AppIP{Allow: []string{"192.168.1.0/24"}},
},
IP: net.ParseIP("192.168.1.10"),
},
expected: EffectAllow,
},
{
name: "allows when IP matches global allow list",
config: model.Config{
Auth: model.AuthConfig{
IP: model.IPConfig{Allow: []string{"192.168.1.10"}},
},
},
ctx: &ACLContext{
ACLs: &model.App{},
IP: net.ParseIP("192.168.1.10"),
},
expected: EffectAllow,
},
{
name: "denies when allow list is set and IP does not match",
ctx: &ACLContext{
ACLs: &model.App{
IP: model.AppIP{Allow: []string{"192.168.1.0/24"}},
},
IP: net.ParseIP("10.0.0.1"),
},
expected: EffectDeny,
},
{
name: "allows when no block or allow lists are configured",
ctx: &ACLContext{
ACLs: &model.App{},
IP: net.ParseIP("10.0.0.1"),
},
expected: EffectAllow,
},
{
name: "block list takes precedence over allow list",
ctx: &ACLContext{
ACLs: &model.App{
IP: model.AppIP{
Block: []string{"10.0.0.1"},
Allow: []string{"10.0.0.1"},
},
},
IP: net.ParseIP("10.0.0.1"),
},
expected: EffectDeny,
},
{
name: "skips invalid block entries and continues evaluation",
ctx: &ACLContext{
ACLs: &model.App{
IP: model.AppIP{
Block: []string{"not-an-ip"},
Allow: []string{"10.0.0.1"},
},
},
IP: net.ParseIP("10.0.0.1"),
},
expected: EffectAllow,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rule := &IPAllowedRule{Log: log, Config: tt.config}
assert.Equal(t, tt.expected, rule.Evaluate(tt.ctx))
})
}
}
func TestIPBypassedRule(t *testing.T) {
log := logger.NewLogger().WithTestConfig()
log.Init()
rule := &IPBypassedRule{Log: log}
tests := []struct {
name string
ctx *ACLContext
expected Effect
}{
{
name: "deny when ACLs are nil",
ctx: &ACLContext{
ACLs: nil,
IP: net.ParseIP("10.0.0.1"),
},
expected: EffectDeny,
},
{
name: "allows when IP matches bypass list",
ctx: &ACLContext{
ACLs: &model.App{
IP: model.AppIP{Bypass: []string{"10.0.0.0/24"}},
},
IP: net.ParseIP("10.0.0.5"),
},
expected: EffectAllow,
},
{
name: "denies when IP does not match bypass list",
ctx: &ACLContext{
ACLs: &model.App{
IP: model.AppIP{Bypass: []string{"10.0.0.0/24"}},
},
IP: net.ParseIP("192.168.1.1"),
},
expected: EffectDeny,
},
{
name: "denies when bypass list is empty",
ctx: &ACLContext{
ACLs: &model.App{},
IP: net.ParseIP("10.0.0.1"),
},
expected: EffectDeny,
},
{
name: "skips invalid bypass entries and allows on later match",
ctx: &ACLContext{
ACLs: &model.App{
IP: model.AppIP{Bypass: []string{"not-an-ip", "10.0.0.1"}},
},
IP: net.ParseIP("10.0.0.1"),
},
expected: EffectAllow,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equal(t, tt.expected, rule.Evaluate(tt.ctx))
})
}
}
+21 -20
View File
@@ -13,51 +13,52 @@ type LabelProvider interface {
type AccessControlsService struct {
log *logger.Logger
config model.Config
labelProvider *LabelProvider
static map[string]model.App
}
func NewAccessControlsService(
log *logger.Logger,
labelProvider *LabelProvider,
static map[string]model.App) *AccessControlsService {
config model.Config,
labelProvider *LabelProvider) *AccessControlsService {
return &AccessControlsService{
log: log,
config: config,
labelProvider: labelProvider,
static: static,
}
}
func (acls *AccessControlsService) lookupStaticACLs(domain string) *model.App {
var appAcls *model.App
for app, config := range acls.static {
func (service *AccessControlsService) lookupStaticACLs(domain string) *model.App {
var nameMatch *model.App
// First try to find a matching app by domain, then fallback to matching by app name (subdomain)
for app, config := range service.config.Apps {
if config.Config.Domain == domain {
acls.log.App.Debug().Str("name", app).Msg("Found matching container by domain")
appAcls = &config
break // If we find a match by domain, we can stop searching
service.log.App.Debug().Str("name", app).Msg("Found matching container by domain")
return &config
}
if strings.SplitN(domain, ".", 2)[0] == app {
acls.log.App.Debug().Str("name", app).Msg("Found matching container by app name")
appAcls = &config
break // If we find a match by app name, we can stop searching
service.log.App.Debug().Str("name", app).Msg("Found matching container by app name")
nameMatch = &config
}
}
return appAcls
return nameMatch
}
func (acls *AccessControlsService) GetAccessControls(domain string) (*model.App, error) {
func (service *AccessControlsService) GetAccessControls(domain string) (*model.App, error) {
// First check in the static config
app := acls.lookupStaticACLs(domain)
app := service.lookupStaticACLs(domain)
if app != nil {
acls.log.App.Debug().Msg("Using static ACLs for app")
service.log.App.Debug().Msg("Using static ACLs for app")
return app, nil
}
// If we have a label provider configured, try to get ACLs from it
if acls.labelProvider != nil {
return (*acls.labelProvider).GetLabels(domain)
if service.labelProvider != nil && *service.labelProvider != nil {
return (*service.labelProvider).GetLabels(domain)
}
// no labels
@@ -0,0 +1,199 @@
package service
import (
"errors"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
)
type mockLabelProvider struct {
getLabelsFn func(appDomain string) (*model.App, error)
calledWith string
callCount int
}
func (m *mockLabelProvider) GetLabels(appDomain string) (*model.App, error) {
m.calledWith = appDomain
m.callCount++
if m.getLabelsFn != nil {
return m.getLabelsFn(appDomain)
}
return nil, nil
}
func TestLookupStaticACLs(t *testing.T) {
log := logger.NewLogger().WithTestConfig()
log.Init()
tests := []struct {
name string
apps map[string]model.App
domain string
expectNil bool
expectedDomain string
}{
{
name: "returns nil when no apps are configured",
apps: nil,
domain: "foo.example.com",
expectNil: true,
},
{
name: "returns nil when no app matches",
apps: map[string]model.App{
"foo": {Config: model.AppConfig{Domain: "foo.example.com"}},
},
domain: "bar.example.com",
expectNil: true,
},
{
name: "matches by exact domain",
apps: map[string]model.App{
"foo": {Config: model.AppConfig{Domain: "foo.example.com"}},
},
domain: "foo.example.com",
expectedDomain: "foo.example.com",
},
{
name: "matches by app name when domain does not match any app",
apps: map[string]model.App{
"foo": {Config: model.AppConfig{Domain: "configured.example.com"}},
},
domain: "foo.example.com",
expectedDomain: "configured.example.com",
},
{
name: "matches by app name for nested subdomains",
apps: map[string]model.App{
"foo": {Config: model.AppConfig{Domain: "configured.example.com"}},
},
domain: "foo.sub.example.com",
expectedDomain: "configured.example.com",
},
{
name: "selects the app matching by domain among multiple apps",
apps: map[string]model.App{
"unrelated": {Config: model.AppConfig{Domain: "other.example.com"}},
"target": {Config: model.AppConfig{Domain: "foo.example.com"}},
},
domain: "foo.example.com",
expectedDomain: "foo.example.com",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
svc := NewAccessControlsService(log, model.Config{Apps: tt.apps}, nil)
got := svc.lookupStaticACLs(tt.domain)
if tt.expectNil {
assert.Nil(t, got)
return
}
require.NotNil(t, got)
assert.Equal(t, tt.expectedDomain, got.Config.Domain)
})
}
}
func TestGetAccessControls(t *testing.T) {
log := logger.NewLogger().WithTestConfig()
log.Init()
t.Run("returns static ACLs when domain matches", func(t *testing.T) {
config := model.Config{
Apps: map[string]model.App{
"foo": {
Config: model.AppConfig{Domain: "foo.example.com"},
Users: model.AppUsers{Allow: "alice"},
},
},
}
svc := NewAccessControlsService(log, config, nil)
got, err := svc.GetAccessControls("foo.example.com")
require.NoError(t, err)
require.NotNil(t, got)
assert.Equal(t, "foo.example.com", got.Config.Domain)
assert.Equal(t, "alice", got.Users.Allow)
})
t.Run("returns nil when no static match and no label provider", func(t *testing.T) {
svc := NewAccessControlsService(log, model.Config{}, nil)
got, err := svc.GetAccessControls("unknown.example.com")
require.NoError(t, err)
assert.Nil(t, got)
})
t.Run("returns nil when label provider pointer wraps a nil interface", func(t *testing.T) {
var provider LabelProvider
svc := NewAccessControlsService(log, model.Config{}, &provider)
got, err := svc.GetAccessControls("unknown.example.com")
require.NoError(t, err)
assert.Nil(t, got)
})
t.Run("falls back to label provider when no static match", func(t *testing.T) {
expected := &model.App{
Config: model.AppConfig{Domain: "dynamic.example.com"},
Users: model.AppUsers{Allow: "bob"},
}
mock := &mockLabelProvider{
getLabelsFn: func(appDomain string) (*model.App, error) {
return expected, nil
},
}
var provider LabelProvider = mock
svc := NewAccessControlsService(log, model.Config{}, &provider)
got, err := svc.GetAccessControls("dynamic.example.com")
require.NoError(t, err)
assert.Same(t, expected, got)
assert.Equal(t, "dynamic.example.com", mock.calledWith)
assert.Equal(t, 1, mock.callCount)
})
t.Run("does not call label provider when static match found", func(t *testing.T) {
mock := &mockLabelProvider{}
var provider LabelProvider = mock
config := model.Config{
Apps: map[string]model.App{
"foo": {Config: model.AppConfig{Domain: "foo.example.com"}},
},
}
svc := NewAccessControlsService(log, config, &provider)
got, err := svc.GetAccessControls("foo.example.com")
require.NoError(t, err)
require.NotNil(t, got)
assert.Equal(t, "foo.example.com", got.Config.Domain)
assert.Equal(t, 0, mock.callCount)
})
t.Run("propagates label provider errors", func(t *testing.T) {
providerErr := errors.New("provider boom")
mock := &mockLabelProvider{
getLabelsFn: func(appDomain string) (*model.App, error) {
return nil, providerErr
},
}
var provider LabelProvider = mock
svc := NewAccessControlsService(log, model.Config{}, &provider)
got, err := svc.GetAccessControls("dynamic.example.com")
assert.Nil(t, got)
assert.ErrorIs(t, err, providerErr)
assert.Equal(t, 1, mock.callCount)
})
}
+44 -201
View File
@@ -2,11 +2,9 @@ package service
import (
"context"
"database/sql"
"errors"
"fmt"
"net/http"
"regexp"
"strings"
"sync"
"time"
@@ -18,7 +16,6 @@ import (
"slices"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"golang.org/x/crypto/bcrypt"
"golang.org/x/oauth2"
@@ -79,7 +76,7 @@ type AuthService struct {
context context.Context
ldap *LdapService
queries *repository.Queries
queries repository.Store
oauthBroker *OAuthBrokerService
tailscale *TailscaleService
@@ -101,7 +98,7 @@ func NewAuthService(
ctx context.Context,
wg *sync.WaitGroup,
ldap *LdapService,
queries *repository.Queries,
queries repository.Store,
oauthBroker *OAuthBrokerService,
tailscale *TailscaleService,
) *AuthService {
@@ -133,7 +130,7 @@ func (auth *AuthService) SearchUser(username string) (*model.UserSearch, error)
}
if auth.ldap != nil {
userDN, err := auth.ldap.GetUserDN(username)
userDN, email, err := auth.ldap.GetUserInfo(username)
if err != nil {
return nil, fmt.Errorf("failed to get ldap user: %w", err)
@@ -141,6 +138,7 @@ func (auth *AuthService) SearchUser(username string) (*model.UserSearch, error)
return &model.UserSearch{
Username: userDN,
Email: email,
Type: model.UserLDAP,
}, nil
}
@@ -288,7 +286,12 @@ func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) {
}
func (auth *AuthService) IsEmailWhitelisted(email string) bool {
return utils.CheckFilter(strings.Join(auth.runtime.OAuthWhitelist, ","), email)
match, err := utils.CheckFilter(strings.Join(auth.runtime.OAuthWhitelist, ","), email)
if err != nil {
auth.log.App.Warn().Err(err).Str("email", email).Msg("Invalid email filter pattern")
return false
}
return match
}
func (auth *AuthService) CreateSession(ctx context.Context, data repository.Session) (*http.Cookie, error) {
@@ -445,7 +448,7 @@ func (auth *AuthService) GetSession(ctx context.Context, uuid string) (*reposito
session, err := auth.queries.GetSession(ctx, uuid)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
if errors.Is(err, repository.ErrNotFound) {
return nil, errors.New("session not found")
}
return nil, err
@@ -482,171 +485,6 @@ func (auth *AuthService) LDAPAuthConfigured() bool {
return auth.ldap != nil
}
func (auth *AuthService) IsUserAllowed(c *gin.Context, context model.UserContext, acls *model.App) bool {
if acls == nil {
return true
}
if context.Provider == model.ProviderOAuth {
auth.log.App.Debug().Msg("User is an OAuth user, checking OAuth whitelist")
return utils.CheckFilter(acls.OAuth.Whitelist, context.OAuth.Email)
}
if acls.Users.Block != "" {
auth.log.App.Debug().Msg("Checking users block list")
if utils.CheckFilter(acls.Users.Block, context.GetUsername()) {
return false
}
}
auth.log.App.Debug().Msg("Checking users allow list")
return utils.CheckFilter(acls.Users.Allow, context.GetUsername())
}
func (auth *AuthService) IsInOAuthGroup(c *gin.Context, context model.UserContext, acls *model.App) bool {
if acls == nil {
return true
}
if !context.IsOAuth() {
auth.log.App.Debug().Msg("User is not an OAuth user, skipping OAuth group check")
return false
}
if _, ok := model.OverrideProviders[context.OAuth.ID]; ok {
auth.log.App.Debug().Str("provider", context.OAuth.ID).Msg("Provider override detected, skipping group check")
return true
}
for _, userGroup := range context.OAuth.Groups {
if utils.CheckFilter(acls.OAuth.Groups, strings.TrimSpace(userGroup)) {
auth.log.App.Trace().Str("group", userGroup).Str("required", acls.OAuth.Groups).Msg("User group matched")
return true
}
}
auth.log.App.Debug().Msg("No groups matched")
return false
}
func (auth *AuthService) IsInLDAPGroup(c *gin.Context, context model.UserContext, acls *model.App) bool {
if acls == nil {
return true
}
if !context.IsLDAP() {
auth.log.App.Debug().Msg("User is not an LDAP user, skipping LDAP group check")
return false
}
for _, userGroup := range context.LDAP.Groups {
if utils.CheckFilter(acls.LDAP.Groups, strings.TrimSpace(userGroup)) {
auth.log.App.Trace().Str("group", userGroup).Str("required", acls.LDAP.Groups).Msg("User group matched")
return true
}
}
auth.log.App.Debug().Msg("No groups matched")
return false
}
func (auth *AuthService) IsAuthEnabled(uri string, acls *model.App) (bool, error) {
if acls == nil {
return true, nil
}
// Check for block list
if acls.Path.Block != "" {
regex, err := regexp.Compile(acls.Path.Block)
if err != nil {
return true, err
}
if !regex.MatchString(uri) {
return false, nil
}
}
// Check for allow list
if acls.Path.Allow != "" {
regex, err := regexp.Compile(acls.Path.Allow)
if err != nil {
return true, err
}
if regex.MatchString(uri) {
return false, nil
}
}
return true, nil
}
func (auth *AuthService) CheckIP(ip string, acls *model.App) bool {
if acls == nil {
return true
}
// Merge the global and app IP filter
blockedIps := append(auth.config.Auth.IP.Block, acls.IP.Block...)
allowedIPs := append(auth.config.Auth.IP.Allow, acls.IP.Allow...)
for _, blocked := range blockedIps {
res, err := utils.FilterIP(blocked, ip)
if err != nil {
auth.log.App.Warn().Err(err).Str("item", blocked).Msg("Invalid IP/CIDR in block list")
continue
}
if res {
auth.log.App.Debug().Str("ip", ip).Str("item", blocked).Msg("IP is in block list, denying access")
return false
}
}
for _, allowed := range allowedIPs {
res, err := utils.FilterIP(allowed, ip)
if err != nil {
auth.log.App.Warn().Err(err).Str("item", allowed).Msg("Invalid IP/CIDR in allow list")
continue
}
if res {
auth.log.App.Debug().Str("ip", ip).Str("item", allowed).Msg("IP is in allow list, allowing access")
return true
}
}
if len(allowedIPs) > 0 {
auth.log.App.Debug().Str("ip", ip).Msg("IP not in allow list, denying access")
return false
}
auth.log.App.Debug().Str("ip", ip).Msg("IP not in any block or allow list, allowing access by default")
return true
}
func (auth *AuthService) IsBypassedIP(ip string, acls *model.App) bool {
if acls == nil {
return false
}
for _, bypassed := range acls.IP.Bypass {
res, err := utils.FilterIP(bypassed, ip)
if err != nil {
auth.log.App.Warn().Err(err).Str("item", bypassed).Msg("Invalid IP/CIDR in bypass list")
continue
}
if res {
auth.log.App.Debug().Str("ip", ip).Str("item", bypassed).Msg("IP is in bypass list, skipping authentication")
return true
}
}
auth.log.App.Debug().Str("ip", ip).Msg("IP not in bypass list, proceeding with authentication")
return false
}
func (auth *AuthService) NewOAuthSession(serviceName string, params OAuthURLParams) (string, OAuthPendingSession, error) {
auth.ensureOAuthSessionLimit()
@@ -801,46 +639,49 @@ func (auth *AuthService) ensureOAuthSessionLimit() {
auth.oauthMutex.Lock()
defer auth.oauthMutex.Unlock()
if len(auth.oauthPendingSessions) >= MaxOAuthPendingSessions {
if len(auth.oauthPendingSessions) <= MaxOAuthPendingSessions {
return
}
cleanupIds := make([]string, 0, OAuthCleanupCount)
type entry struct {
id string
expiresAt int64
}
for range OAuthCleanupCount {
oldestId := ""
oldestTime := int64(0)
entries := make([]entry, 0, len(auth.oauthPendingSessions))
for id, session := range auth.oauthPendingSessions {
entries = append(entries, entry{id, session.ExpiresAt.Unix()})
}
for id, session := range auth.oauthPendingSessions {
if oldestTime == 0 {
oldestId = id
oldestTime = session.ExpiresAt.Unix()
continue
}
if slices.Contains(cleanupIds, id) {
continue
}
if session.ExpiresAt.Unix() < oldestTime {
oldestId = id
oldestTime = session.ExpiresAt.Unix()
}
}
cleanupIds = append(cleanupIds, oldestId)
slices.SortFunc(entries, func(a, b entry) int {
if a.expiresAt < b.expiresAt {
return -1
}
for _, id := range cleanupIds {
delete(auth.oauthPendingSessions, id)
if a.expiresAt > b.expiresAt {
return 1
}
return 0
})
for _, e := range entries[:OAuthCleanupCount] {
delete(auth.oauthPendingSessions, e.id)
}
}
func (auth *AuthService) lockdownMode() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
auth.lockdownCtx = ctx
auth.lockdownCancelFunc = cancel
auth.loginMutex.Lock()
if auth.lockdown != nil && auth.lockdown.Active {
auth.loginMutex.Unlock()
cancel()
return
}
auth.lockdownCtx = ctx
auth.lockdownCancelFunc = cancel
auth.log.App.Warn().Msg("Too many failed login attempts, entering lockdown mode")
auth.lockdown = &Lockdown{
@@ -853,10 +694,12 @@ func (auth *AuthService) lockdownMode() {
auth.loginAttempts = make(map[string]*LoginAttempt)
timer := time.NewTimer(time.Until(auth.lockdown.ActiveUntil))
defer timer.Stop()
auth.loginMutex.Unlock()
defer cancel()
defer timer.Stop()
select {
case <-timer.C:
// Timer expired, end lockdown
+8 -2
View File
@@ -85,17 +85,23 @@ func (docker *DockerService) GetLabels(appDomain string) (*model.App, error) {
return nil, err
}
var nameMatch *model.App
// First try to find a matching app by domain, then fallback to matching by app name (subdomain)
for appName, appLabels := range labels.Apps {
if appLabels.Config.Domain == appDomain {
docker.log.App.Debug().Str("id", inspect.ID).Str("name", inspect.Name).Msg("Found matching container by domain")
return &appLabels, nil
}
if strings.SplitN(appDomain, ".", 2)[0] == appName {
docker.log.App.Debug().Str("id", inspect.ID).Str("name", inspect.Name).Msg("Found matching container by app name")
return &appLabels, nil
nameMatch = &appLabels
}
}
if nameMatch != nil {
return nameMatch, nil
}
}
docker.log.App.Debug().Str("domain", appDomain).Msg("No matching container found for domain")
+6 -7
View File
@@ -134,8 +134,7 @@ func (ldap *LdapService) connect() (*ldapgo.Conn, error) {
return ldap.conn, nil
}
func (ldap *LdapService) GetUserDN(username string) (string, error) {
// Escape the username to prevent LDAP injection
func (ldap *LdapService) GetUserInfo(username string) (dn string, email string, err error) {
escapedUsername := ldapgo.EscapeFilter(username)
filter := fmt.Sprintf(ldap.config.LDAP.SearchFilter, escapedUsername)
@@ -143,7 +142,7 @@ func (ldap *LdapService) GetUserDN(username string) (string, error) {
ldap.config.LDAP.BaseDN,
ldapgo.ScopeWholeSubtree, ldapgo.NeverDerefAliases, 0, 0, false,
filter,
[]string{"dn"},
[]string{"dn", "mail"},
nil,
)
@@ -152,15 +151,15 @@ func (ldap *LdapService) GetUserDN(username string) (string, error) {
searchResult, err := ldap.conn.Search(searchRequest)
if err != nil {
return "", err
return "", "", err
}
if len(searchResult.Entries) != 1 {
return "", fmt.Errorf("multiple or no entries found for user %s", username)
return "", "", fmt.Errorf("multiple or no entries found for user %s", username)
}
userDN := searchResult.Entries[0].DN
return userDN, nil
entry := searchResult.Entries[0]
return entry.DN, entry.GetAttributeValue("mail"), nil
}
func (ldap *LdapService) GetUserGroups(userDN string) ([]string, error) {
+1
View File
@@ -26,6 +26,7 @@ func NewOAuthService(config model.OAuthServiceConfig, id string, ctx context.Con
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
InsecureSkipVerify: config.Insecure,
MinVersion: tls.VersionTLS12,
},
},
}
+33 -22
View File
@@ -7,7 +7,6 @@ import (
"crypto/rsa"
"crypto/sha256"
"crypto/x509"
"database/sql"
"encoding/base64"
"encoding/json"
"encoding/pem"
@@ -116,12 +115,12 @@ type OIDCService struct {
log *logger.Logger
config model.Config
runtime model.RuntimeConfig
queries *repository.Queries
queries repository.Store
context context.Context
clients map[string]model.OIDCClientConfig
privateKey *rsa.PrivateKey
publicKey crypto.PublicKey
publicKey *rsa.PublicKey
issuer string
}
@@ -129,7 +128,7 @@ func NewOIDCService(
log *logger.Logger,
config model.Config,
runtime model.RuntimeConfig,
queries *repository.Queries,
queries repository.Store,
ctx context.Context,
wg *sync.WaitGroup) (*OIDCService, error) {
// If not configured, skip init
@@ -239,6 +238,16 @@ func NewOIDCService(
}
}
rPublicKey, ok := publicKey.(*rsa.PublicKey)
if !ok {
return nil, fmt.Errorf("public key is not an rsa public key")
}
if rPublicKey.N.Cmp(privateKey.N) != 0 || rPublicKey.E != privateKey.E {
return nil, fmt.Errorf("public key does not pair with private key")
}
// We will reorganize the client into a map with the client ID as the key
clients := make(map[string]model.OIDCClientConfig)
@@ -271,7 +280,7 @@ func NewOIDCService(
clients: clients,
privateKey: privateKey,
publicKey: publicKey,
publicKey: rPublicKey,
issuer: issuer,
}
@@ -297,6 +306,11 @@ func (service *OIDCService) ValidateAuthorizeParams(req AuthorizeRequest) error
return errors.New("access_denied")
}
// Redirect URI to verify that it's trusted
if !slices.Contains(client.TrustedRedirectURIs, req.RedirectURI) {
return errors.New("invalid_request_uri")
}
// Scopes
scopes := strings.Split(req.Scope, " ")
@@ -318,11 +332,6 @@ func (service *OIDCService) ValidateAuthorizeParams(req AuthorizeRequest) error
return errors.New("unsupported_response_type")
}
// Redirect URI
if !slices.Contains(client.TrustedRedirectURIs, req.RedirectURI) {
return errors.New("invalid_request_uri")
}
// PKCE code challenge method if set
if req.CodeChallenge != "" && req.CodeChallengeMethod != "" {
if req.CodeChallengeMethod != "S256" && req.CodeChallengeMethod != "plain" {
@@ -424,7 +433,7 @@ func (service *OIDCService) GetCodeEntry(c *gin.Context, codeHash string, client
oidcCode, err := service.queries.GetOidcCode(c, codeHash)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
if errors.Is(err, repository.ErrNotFound) {
return repository.OidcCode{}, ErrCodeNotFound
}
return repository.OidcCode{}, err
@@ -455,7 +464,7 @@ func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user
hasher := sha256.New()
der := x509.MarshalPKCS1PublicKey(&service.privateKey.PublicKey)
der := x509.MarshalPKCS1PublicKey(service.publicKey)
if der == nil {
return "", errors.New("failed to marshal public key")
@@ -568,7 +577,7 @@ func (service *OIDCService) RefreshAccessToken(c *gin.Context, refreshToken stri
entry, err := service.queries.GetOidcTokenByRefreshToken(c, service.Hash(refreshToken))
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
if errors.Is(err, repository.ErrNotFound) {
return TokenResponse{}, ErrTokenNotFound
}
return TokenResponse{}, err
@@ -647,7 +656,7 @@ func (service *OIDCService) GetAccessToken(c *gin.Context, tokenHash string) (re
entry, err := service.queries.GetOidcToken(c, tokenHash)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
if errors.Is(err, repository.ErrNotFound) {
return repository.OidcToken{}, ErrTokenNotFound
}
return repository.OidcToken{}, err
@@ -735,15 +744,15 @@ func (service *OIDCService) Hash(token string) string {
func (service *OIDCService) DeleteOldSession(ctx context.Context, sub string) error {
err := service.queries.DeleteOidcCodeBySub(ctx, sub)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
if err != nil && !errors.Is(err, repository.ErrNotFound) {
return err
}
err = service.queries.DeleteOidcTokenBySub(ctx, sub)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
if err != nil && !errors.Is(err, repository.ErrNotFound) {
return err
}
err = service.queries.DeleteOidcUserInfo(ctx, sub)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
if err != nil && !errors.Is(err, repository.ErrNotFound) {
return err
}
return nil
@@ -783,14 +792,16 @@ func (service *OIDCService) cleanupRoutine() {
expiredCodes, err := service.queries.DeleteExpiredOidcCodes(service.context, currentTime)
if err != nil {
service.log.App.Warn().Err(err).Msg("Failed to delete expired codes")
service.log.App.Warn().Err(err).Msg("Failed to delete expired codes")
}
for _, expiredCode := range expiredCodes {
token, err := service.queries.GetOidcTokenBySub(service.context, expiredCode.Sub)
if err != nil {
service.log.App.Warn().Err(err).Msg("Failed to get token by sub for expired code")
if !errors.Is(err, repository.ErrNotFound) {
service.log.App.Warn().Err(err).Msg("Failed to get token by sub for expired code")
}
continue
}
@@ -813,7 +824,7 @@ func (service *OIDCService) cleanupRoutine() {
func (service *OIDCService) GetJWK() ([]byte, error) {
hasher := sha256.New()
der := x509.MarshalPKCS1PublicKey(&service.privateKey.PublicKey)
der := x509.MarshalPKCS1PublicKey(service.publicKey)
if der == nil {
return nil, errors.New("failed to marshal public key")
@@ -822,13 +833,13 @@ func (service *OIDCService) GetJWK() ([]byte, error) {
hasher.Write(der)
jwk := jose.JSONWebKey{
Key: service.privateKey,
Key: service.publicKey,
Algorithm: string(jose.RS256),
Use: "sig",
KeyID: base64.URLEncoding.EncodeToString(hasher.Sum(nil)),
}
return jwk.Public().MarshalJSON()
return jwk.MarshalJSON()
}
func (service *OIDCService) ValidatePKCE(codeChallenge string, codeVerifier string) bool {
+110
View File
@@ -0,0 +1,110 @@
package service
import (
"fmt"
"net"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
)
type Policy string
const (
PolicyAllow Policy = "allow"
PolicyDeny Policy = "deny"
)
type Effect int
const (
EffectAbstain Effect = iota
EffectAllow
EffectDeny
)
type Rule interface {
Evaluate(ctx *ACLContext) Effect
}
type ACLContext struct {
ACLs *model.App
UserContext *model.UserContext
IP net.IP
Path string
}
type PolicyEngine struct {
log *logger.Logger
rules map[RuleName]Rule
policy Policy
}
func NewPolicyEngine(config model.Config, log *logger.Logger) (*PolicyEngine, error) {
engine := PolicyEngine{
log: log,
rules: make(map[RuleName]Rule),
}
switch config.Auth.ACLs.Policy {
case string(PolicyAllow):
log.App.Debug().Msg("Using 'allow' ACL policy: access to apps will be allowed by default unless explicitly blocked")
engine.policy = PolicyAllow
case string(PolicyDeny):
log.App.Debug().Msg("Using 'deny' ACL policy: access to apps will be blocked by default unless explicitly allowed")
engine.policy = PolicyDeny
default:
return nil, fmt.Errorf("invalid acl policy: %s", config.Auth.ACLs.Policy)
}
return &engine, nil
}
func (engine *PolicyEngine) RegisterRule(name RuleName, rule Rule) {
engine.log.App.Debug().Str("rule", string(name)).Msg("Registering ACL rule in policy engine")
engine.rules[name] = rule
}
func (engine *PolicyEngine) evaluateRuleByName(name RuleName, ctx *ACLContext) Effect {
rule, exists := engine.rules[name]
if !exists {
engine.log.App.Warn().Str("rule", string(name)).Msg("Rule not found in policy engine, defaulting to deny")
return EffectDeny
}
return rule.Evaluate(ctx)
}
func (engine *PolicyEngine) effectToAccess(effect Effect) bool {
switch effect {
case EffectAllow:
return true
case EffectDeny:
return false
default:
// If the effect is abstain, we fall back to the default policy
return engine.policy == PolicyAllow
}
}
func (engine *PolicyEngine) Evaluate(name RuleName, ctx *ACLContext) bool {
effect := engine.evaluateRuleByName(name, ctx)
access := engine.effectToAccess(effect)
engine.log.App.Debug().
Str("rule", string(name)).
Int("effect", int(effect)).
Bool("access", access).
Msg("Evaluated ACL rule")
return access
}
func (engine *PolicyEngine) Policy() Policy {
return engine.policy
}
func (engine *PolicyEngine) Rules() map[RuleName]Rule {
return engine.rules
}
+93
View File
@@ -0,0 +1,93 @@
package service_test
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/test"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
)
// Create test rule
type TestRule struct{}
func (rule *TestRule) Evaluate(ctx *service.ACLContext) service.Effect {
switch ctx.Path {
case "/allowed":
return service.EffectAllow
case "/denied":
return service.EffectDeny
default:
return service.EffectAbstain
}
}
func TestPolicyEngine(t *testing.T) {
log := logger.NewLogger().WithTestConfig()
log.Init()
cfg, _ := test.CreateTestConfigs(t)
testRule := &TestRule{}
// Engine should fail with invalid policy
cfg.Auth.ACLs.Policy = "invalid_policy"
_, err := service.NewPolicyEngine(cfg, log)
assert.Error(t, err)
// Engine should initialize with 'allow' policy
cfg.Auth.ACLs.Policy = string(service.PolicyAllow)
engine, err := service.NewPolicyEngine(cfg, log)
assert.NoError(t, err)
assert.Equal(t, service.PolicyAllow, engine.Policy())
// Engine should initialize with 'deny' policy
cfg.Auth.ACLs.Policy = string(service.PolicyDeny)
engine, err = service.NewPolicyEngine(cfg, log)
assert.NoError(t, err)
assert.Equal(t, service.PolicyDeny, engine.Policy())
// Engine should allow adding rules
engine, err = service.NewPolicyEngine(cfg, log)
assert.NoError(t, err)
engine.RegisterRule("test-rule", testRule)
_, ok := engine.Rules()["test-rule"]
assert.True(t, ok)
// Begin allow policy tests
cfg.Auth.ACLs.Policy = string(service.PolicyAllow)
engine, err = service.NewPolicyEngine(cfg, log)
assert.NoError(t, err)
engine.RegisterRule("test-rule", testRule)
// With allow policy, if rule allows, access should be allowed
ctx := &service.ACLContext{Path: "/allowed"}
assert.Equal(t, true, engine.Evaluate("test-rule", ctx))
// With allow policy, if rule denies, access should be denied
ctx.Path = "/denied"
assert.Equal(t, false, engine.Evaluate("test-rule", ctx))
// With allow policy, if rule abstains, access should be allowed (default)
ctx.Path = "/abstain"
assert.Equal(t, true, engine.Evaluate("test-rule", ctx))
// Begin deny policy tests
cfg.Auth.ACLs.Policy = string(service.PolicyDeny)
engine, err = service.NewPolicyEngine(cfg, log)
assert.NoError(t, err)
engine.RegisterRule("test-rule", testRule)
// With deny policy, if rule allows, access should be allowed
ctx.Path = "/allowed"
assert.Equal(t, true, engine.Evaluate("test-rule", ctx))
// With deny policy, if rule denies, access should be denied
ctx.Path = "/denied"
assert.Equal(t, false, engine.Evaluate("test-rule", ctx))
// With deny policy, if rule abstains, access should be denied (default)
ctx.Path = "/abstain"
assert.Equal(t, false, engine.Evaluate("test-rule", ctx))
}