refactor: simplify acls checking logic by passing the entire acl struct

This commit is contained in:
Stavros
2026-05-04 16:13:39 +03:00
parent 62ffd2fd11
commit df56708b9a
6 changed files with 59 additions and 81 deletions
+6 -7
View File
@@ -1,7 +1,6 @@
package service
import (
"errors"
"strings"
"github.com/tinyauthapp/tinyauth/internal/model"
@@ -28,26 +27,26 @@ func (acls *AccessControlsService) Init() error {
return nil // No initialization needed
}
func (acls *AccessControlsService) lookupStaticACLs(domain string) (*model.App, error) {
func (acls *AccessControlsService) lookupStaticACLs(domain string) *model.App {
for app, config := range acls.static {
if config.Config.Domain == domain {
tlog.App.Debug().Str("name", app).Msg("Found matching container by domain")
return &config, nil
return &config
}
if strings.SplitN(domain, ".", 2)[0] == app {
tlog.App.Debug().Str("name", app).Msg("Found matching container by app name")
return &config, nil
return &config
}
}
return nil, errors.New("no results")
return nil
}
func (acls *AccessControlsService) GetAccessControls(domain string) (*model.App, error) {
// First check in the static config
app, err := acls.lookupStaticACLs(domain)
app := acls.lookupStaticACLs(domain)
if err == nil {
if app != nil {
tlog.App.Debug().Msg("Using ACls from static configuration")
return app, nil
}
+20 -35
View File
@@ -464,8 +464,8 @@ func (auth *AuthService) IsUserAllowed(c *gin.Context, context model.UserContext
return utils.CheckFilter(acls.Users.Allow, context.GetUsername())
}
func (auth *AuthService) IsInOAuthGroup(c *gin.Context, context model.UserContext, requiredGroups string) bool {
if requiredGroups == "" {
func (auth *AuthService) IsInOAuthGroup(c *gin.Context, context model.UserContext, acls *model.App) bool {
if acls == nil {
return true
}
@@ -480,8 +480,8 @@ func (auth *AuthService) IsInOAuthGroup(c *gin.Context, context model.UserContex
}
for _, userGroup := range context.OAuth.Groups {
if utils.CheckFilter(requiredGroups, strings.TrimSpace(userGroup)) {
tlog.App.Trace().Str("group", userGroup).Str("required", requiredGroups).Msg("User group matched")
if utils.CheckFilter(acls.OAuth.Groups, strings.TrimSpace(userGroup)) {
tlog.App.Trace().Str("group", userGroup).Str("required", acls.OAuth.Groups).Msg("User group matched")
return true
}
}
@@ -490,8 +490,8 @@ func (auth *AuthService) IsInOAuthGroup(c *gin.Context, context model.UserContex
return false
}
func (auth *AuthService) IsInLDAPGroup(c *gin.Context, context model.UserContext, requiredGroups string) bool {
if requiredGroups == "" {
func (auth *AuthService) IsInLDAPGroup(c *gin.Context, context model.UserContext, acls *model.App) bool {
if acls == nil {
return true
}
@@ -501,8 +501,8 @@ func (auth *AuthService) IsInLDAPGroup(c *gin.Context, context model.UserContext
}
for _, userGroup := range context.LDAP.Groups {
if utils.CheckFilter(requiredGroups, strings.TrimSpace(userGroup)) {
tlog.App.Trace().Str("group", userGroup).Str("required", requiredGroups).Msg("User group matched")
if utils.CheckFilter(acls.LDAP.Groups, strings.TrimSpace(userGroup)) {
tlog.App.Trace().Str("group", userGroup).Str("required", acls.LDAP.Groups).Msg("User group matched")
return true
}
}
@@ -511,14 +511,14 @@ func (auth *AuthService) IsInLDAPGroup(c *gin.Context, context model.UserContext
return false
}
func (auth *AuthService) IsAuthEnabled(uri string, path *model.AppPath) (bool, error) {
if path == nil {
func (auth *AuthService) IsAuthEnabled(uri string, acls *model.App) (bool, error) {
if acls == nil {
return true, nil
}
// Check for block list
if path.Block != "" {
regex, err := regexp.Compile(path.Block)
if acls.Path.Block != "" {
regex, err := regexp.Compile(acls.Path.Block)
if err != nil {
return true, err
@@ -530,8 +530,8 @@ func (auth *AuthService) IsAuthEnabled(uri string, path *model.AppPath) (bool, e
}
// Check for allow list
if path.Allow != "" {
regex, err := regexp.Compile(path.Allow)
if acls.Path.Allow != "" {
regex, err := regexp.Compile(acls.Path.Allow)
if err != nil {
return true, err
@@ -545,29 +545,14 @@ func (auth *AuthService) IsAuthEnabled(uri string, path *model.AppPath) (bool, e
return true, nil
}
// local user is used only as a medium to pass the basic auth credentials, user can be ldap too
func (auth *AuthService) GetBasicAuth(req *http.Request) (*model.LocalUser, error) {
if req == nil {
return nil, errors.New("request is nil")
}
username, password, ok := req.BasicAuth()
if !ok {
return nil, errors.New("no basic auth credentials provided")
}
return &model.LocalUser{
Username: username,
Password: password,
}, nil
}
func (auth *AuthService) CheckIP(acls *model.AppIP, ip string) bool {
func (auth *AuthService) CheckIP(ip string, acls *model.App) bool {
if acls == nil {
acls = &model.AppIP{}
return true
}
// Merge the global and app IP filter
blockedIps := append(auth.config.IP.Block, acls.Block...)
allowedIPs := append(auth.config.IP.Allow, acls.Allow...)
blockedIps := append(auth.config.IP.Block, acls.IP.Block...)
allowedIPs := append(auth.config.IP.Allow, acls.IP.Allow...)
for _, blocked := range blockedIps {
res, err := utils.FilterIP(blocked, ip)
@@ -602,12 +587,12 @@ func (auth *AuthService) CheckIP(acls *model.AppIP, ip string) bool {
return true
}
func (auth *AuthService) IsBypassedIP(acls *model.AppIP, ip string) bool {
func (auth *AuthService) IsBypassedIP(ip string, acls *model.App) bool {
if acls == nil {
return false
}
for _, bypassed := range acls.Bypass {
for _, bypassed := range acls.IP.Bypass {
res, err := utils.FilterIP(bypassed, ip)
if err != nil {
tlog.App.Warn().Err(err).Str("item", bypassed).Msg("Invalid IP/CIDR in bypass list")
+2 -10
View File
@@ -51,19 +51,11 @@ func (docker *DockerService) Init() error {
}
func (docker *DockerService) getContainers() ([]container.Summary, error) {
containers, err := docker.client.ContainerList(docker.context, container.ListOptions{})
if err != nil {
return nil, err
}
return containers, nil
return docker.client.ContainerList(docker.context, container.ListOptions{})
}
func (docker *DockerService) inspectContainer(containerId string) (container.InspectResponse, error) {
inspect, err := docker.client.ContainerInspect(docker.context, containerId)
if err != nil {
return container.InspectResponse{}, err
}
return inspect, nil
return docker.client.ContainerInspect(docker.context, containerId)
}
func (docker *DockerService) GetLabels(appDomain string) (*model.App, error) {
+10 -8
View File
@@ -89,7 +89,7 @@ func (k *KubernetesService) removeIngress(namespace, name string) {
}
}
func (k *KubernetesService) getByDomain(domain string) (*model.App, bool) {
func (k *KubernetesService) getByDomain(domain string) *model.App {
k.mu.RLock()
defer k.mu.RUnlock()
@@ -97,15 +97,15 @@ func (k *KubernetesService) getByDomain(domain string) (*model.App, bool) {
if apps, ok := k.ingressApps[appKey.ingressKey]; ok {
for _, app := range apps {
if app.domain == domain && app.appName == appKey.appName {
return &app.app, true
return &app.app
}
}
}
}
return nil, false
return nil
}
func (k *KubernetesService) getByAppName(appName string) (*model.App, bool) {
func (k *KubernetesService) getByAppName(appName string) *model.App {
k.mu.RLock()
defer k.mu.RUnlock()
@@ -113,12 +113,12 @@ func (k *KubernetesService) getByAppName(appName string) (*model.App, bool) {
if apps, ok := k.ingressApps[appKey.ingressKey]; ok {
for _, app := range apps {
if app.appName == appName {
return &app.app, true
return &app.app
}
}
}
}
return nil, false
return nil
}
func (k *KubernetesService) updateFromItem(item *unstructured.Unstructured) {
@@ -287,12 +287,14 @@ func (k *KubernetesService) GetLabels(appDomain string) (*model.App, error) {
}
// First check cache
if app, found := k.getByDomain(appDomain); found {
app := k.getByDomain(appDomain)
if app != nil {
tlog.App.Debug().Str("domain", appDomain).Msg("Found labels in cache by domain")
return app, nil
}
appName := strings.SplitN(appDomain, ".", 2)[0]
if app, found := k.getByAppName(appName); found {
app = k.getByAppName(appName)
if app != nil {
tlog.App.Debug().Str("domain", appDomain).Str("appName", appName).Msg("Found labels in cache by app name")
return app, nil
}