mirror of
https://github.com/steveiliop56/tinyauth.git
synced 2026-05-08 05:18:11 +00:00
refactor: simplify acls checking logic by passing the entire acl struct
This commit is contained in:
@@ -107,7 +107,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
|||||||
|
|
||||||
clientIP := c.ClientIP()
|
clientIP := c.ClientIP()
|
||||||
|
|
||||||
if controller.auth.IsBypassedIP(&acls.IP, clientIP) {
|
if controller.auth.IsBypassedIP(clientIP, acls) {
|
||||||
controller.setHeaders(c, *acls)
|
controller.setHeaders(c, *acls)
|
||||||
c.JSON(200, gin.H{
|
c.JSON(200, gin.H{
|
||||||
"status": 200,
|
"status": 200,
|
||||||
@@ -116,7 +116,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
authEnabled, err := controller.auth.IsAuthEnabled(proxyCtx.Path, &acls.Path)
|
authEnabled, err := controller.auth.IsAuthEnabled(proxyCtx.Path, acls)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tlog.App.Error().Err(err).Msg("Failed to check if auth is enabled for resource")
|
tlog.App.Error().Err(err).Msg("Failed to check if auth is enabled for resource")
|
||||||
@@ -134,7 +134,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if !controller.auth.CheckIP(&acls.IP, clientIP) {
|
if !controller.auth.CheckIP(clientIP, acls) {
|
||||||
queries, err := query.Values(UnauthorizedQuery{
|
queries, err := query.Values(UnauthorizedQuery{
|
||||||
Resource: strings.Split(proxyCtx.Host, ".")[0],
|
Resource: strings.Split(proxyCtx.Host, ".")[0],
|
||||||
IP: clientIP,
|
IP: clientIP,
|
||||||
@@ -213,9 +213,9 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
|||||||
var groupOK bool
|
var groupOK bool
|
||||||
|
|
||||||
if userContext.IsOAuth() {
|
if userContext.IsOAuth() {
|
||||||
groupOK = controller.auth.IsInOAuthGroup(c, *userContext, acls.OAuth.Groups)
|
groupOK = controller.auth.IsInOAuthGroup(c, *userContext, acls)
|
||||||
} else {
|
} else {
|
||||||
groupOK = controller.auth.IsInLDAPGroup(c, *userContext, acls.LDAP.Groups)
|
groupOK = controller.auth.IsInLDAPGroup(c, *userContext, acls)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !groupOK {
|
if !groupOK {
|
||||||
|
|||||||
@@ -86,10 +86,10 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
basic, err := m.auth.GetBasicAuth(c.Request)
|
username, password, ok := c.Request.BasicAuth()
|
||||||
|
|
||||||
if err == nil {
|
if ok {
|
||||||
userContext, headers, err := m.basicAuth(c.Request.Context(), basic)
|
userContext, headers, err := m.basicAuth(username, password)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tlog.App.Error().Msgf("Error authenticating basic auth: %v", err)
|
tlog.App.Error().Msgf("Error authenticating basic auth: %v", err)
|
||||||
@@ -188,39 +188,39 @@ func (m *ContextMiddleware) cookieAuth(ctx context.Context, uuid string) (*model
|
|||||||
return userContext, cookie, nil
|
return userContext, cookie, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *ContextMiddleware) basicAuth(ctx context.Context, basic *model.LocalUser) (*model.UserContext, map[string]string, error) {
|
func (m *ContextMiddleware) basicAuth(username string, password string) (*model.UserContext, map[string]string, error) {
|
||||||
headers := make(map[string]string)
|
headers := make(map[string]string)
|
||||||
userContext := new(model.UserContext)
|
userContext := new(model.UserContext)
|
||||||
locked, remaining := m.auth.IsAccountLocked(basic.Username)
|
locked, remaining := m.auth.IsAccountLocked(username)
|
||||||
|
|
||||||
if locked {
|
if locked {
|
||||||
tlog.App.Debug().Msgf("Account for user %s is locked for %d seconds, denying auth", basic.Username, remaining)
|
tlog.App.Debug().Msgf("Account for user %s is locked for %d seconds, denying auth", username, remaining)
|
||||||
headers["x-tinyauth-lock-locked"] = "true"
|
headers["x-tinyauth-lock-locked"] = "true"
|
||||||
headers["x-tinyauth-lock-reset"] = time.Now().Add(time.Duration(remaining) * time.Second).Format(time.RFC3339)
|
headers["x-tinyauth-lock-reset"] = time.Now().Add(time.Duration(remaining) * time.Second).Format(time.RFC3339)
|
||||||
return nil, headers, nil
|
return nil, headers, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
search, err := m.auth.SearchUser(basic.Username)
|
search, err := m.auth.SearchUser(username)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, fmt.Errorf("error searching for user: %w", err)
|
return nil, nil, fmt.Errorf("error searching for user: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = m.auth.CheckUserPassword(*search, basic.Password)
|
err = m.auth.CheckUserPassword(*search, password)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
m.auth.RecordLoginAttempt(basic.Username, false)
|
m.auth.RecordLoginAttempt(username, false)
|
||||||
return nil, nil, fmt.Errorf("invalid password for basic auth user: %w", err)
|
return nil, nil, fmt.Errorf("invalid password for basic auth user: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
m.auth.RecordLoginAttempt(basic.Username, true)
|
m.auth.RecordLoginAttempt(username, true)
|
||||||
|
|
||||||
switch search.Type {
|
switch search.Type {
|
||||||
case model.UserLocal:
|
case model.UserLocal:
|
||||||
user := m.auth.GetLocalUser(basic.Username)
|
user := m.auth.GetLocalUser(username)
|
||||||
|
|
||||||
if user.TOTPSecret != "" {
|
if user.TOTPSecret != "" {
|
||||||
return nil, nil, fmt.Errorf("user with totp not allowed to login via basic auth: %s", basic.Username)
|
return nil, nil, fmt.Errorf("user with totp not allowed to login via basic auth: %s", username)
|
||||||
}
|
}
|
||||||
|
|
||||||
userContext.Local = &model.LocalContext{
|
userContext.Local = &model.LocalContext{
|
||||||
@@ -233,7 +233,7 @@ func (m *ContextMiddleware) basicAuth(ctx context.Context, basic *model.LocalUse
|
|||||||
}
|
}
|
||||||
userContext.Provider = model.ProviderLocal
|
userContext.Provider = model.ProviderLocal
|
||||||
case model.UserLDAP:
|
case model.UserLDAP:
|
||||||
user, err := m.auth.GetLDAPUser(basic.Username)
|
user, err := m.auth.GetLDAPUser(username)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, fmt.Errorf("error retrieving ldap user details: %w", err)
|
return nil, nil, fmt.Errorf("error retrieving ldap user details: %w", err)
|
||||||
@@ -241,9 +241,9 @@ func (m *ContextMiddleware) basicAuth(ctx context.Context, basic *model.LocalUse
|
|||||||
|
|
||||||
userContext.LDAP = &model.LDAPContext{
|
userContext.LDAP = &model.LDAPContext{
|
||||||
BaseContext: model.BaseContext{
|
BaseContext: model.BaseContext{
|
||||||
Username: basic.Username,
|
Username: username,
|
||||||
Name: utils.Capitalize(basic.Username),
|
Name: utils.Capitalize(username),
|
||||||
Email: utils.CompileUserEmail(basic.Username, m.config.CookieDomain),
|
Email: utils.CompileUserEmail(username, m.config.CookieDomain),
|
||||||
},
|
},
|
||||||
Groups: user.Groups,
|
Groups: user.Groups,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package service
|
package service
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
@@ -28,26 +27,26 @@ func (acls *AccessControlsService) Init() error {
|
|||||||
return nil // No initialization needed
|
return nil // No initialization needed
|
||||||
}
|
}
|
||||||
|
|
||||||
func (acls *AccessControlsService) lookupStaticACLs(domain string) (*model.App, error) {
|
func (acls *AccessControlsService) lookupStaticACLs(domain string) *model.App {
|
||||||
for app, config := range acls.static {
|
for app, config := range acls.static {
|
||||||
if config.Config.Domain == domain {
|
if config.Config.Domain == domain {
|
||||||
tlog.App.Debug().Str("name", app).Msg("Found matching container by domain")
|
tlog.App.Debug().Str("name", app).Msg("Found matching container by domain")
|
||||||
return &config, nil
|
return &config
|
||||||
}
|
}
|
||||||
|
|
||||||
if strings.SplitN(domain, ".", 2)[0] == app {
|
if strings.SplitN(domain, ".", 2)[0] == app {
|
||||||
tlog.App.Debug().Str("name", app).Msg("Found matching container by app name")
|
tlog.App.Debug().Str("name", app).Msg("Found matching container by app name")
|
||||||
return &config, nil
|
return &config
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil, errors.New("no results")
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (acls *AccessControlsService) GetAccessControls(domain string) (*model.App, error) {
|
func (acls *AccessControlsService) GetAccessControls(domain string) (*model.App, error) {
|
||||||
// First check in the static config
|
// First check in the static config
|
||||||
app, err := acls.lookupStaticACLs(domain)
|
app := acls.lookupStaticACLs(domain)
|
||||||
|
|
||||||
if err == nil {
|
if app != nil {
|
||||||
tlog.App.Debug().Msg("Using ACls from static configuration")
|
tlog.App.Debug().Msg("Using ACls from static configuration")
|
||||||
return app, nil
|
return app, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -464,8 +464,8 @@ func (auth *AuthService) IsUserAllowed(c *gin.Context, context model.UserContext
|
|||||||
return utils.CheckFilter(acls.Users.Allow, context.GetUsername())
|
return utils.CheckFilter(acls.Users.Allow, context.GetUsername())
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) IsInOAuthGroup(c *gin.Context, context model.UserContext, requiredGroups string) bool {
|
func (auth *AuthService) IsInOAuthGroup(c *gin.Context, context model.UserContext, acls *model.App) bool {
|
||||||
if requiredGroups == "" {
|
if acls == nil {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -480,8 +480,8 @@ func (auth *AuthService) IsInOAuthGroup(c *gin.Context, context model.UserContex
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, userGroup := range context.OAuth.Groups {
|
for _, userGroup := range context.OAuth.Groups {
|
||||||
if utils.CheckFilter(requiredGroups, strings.TrimSpace(userGroup)) {
|
if utils.CheckFilter(acls.OAuth.Groups, strings.TrimSpace(userGroup)) {
|
||||||
tlog.App.Trace().Str("group", userGroup).Str("required", requiredGroups).Msg("User group matched")
|
tlog.App.Trace().Str("group", userGroup).Str("required", acls.OAuth.Groups).Msg("User group matched")
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -490,8 +490,8 @@ func (auth *AuthService) IsInOAuthGroup(c *gin.Context, context model.UserContex
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) IsInLDAPGroup(c *gin.Context, context model.UserContext, requiredGroups string) bool {
|
func (auth *AuthService) IsInLDAPGroup(c *gin.Context, context model.UserContext, acls *model.App) bool {
|
||||||
if requiredGroups == "" {
|
if acls == nil {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -501,8 +501,8 @@ func (auth *AuthService) IsInLDAPGroup(c *gin.Context, context model.UserContext
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, userGroup := range context.LDAP.Groups {
|
for _, userGroup := range context.LDAP.Groups {
|
||||||
if utils.CheckFilter(requiredGroups, strings.TrimSpace(userGroup)) {
|
if utils.CheckFilter(acls.LDAP.Groups, strings.TrimSpace(userGroup)) {
|
||||||
tlog.App.Trace().Str("group", userGroup).Str("required", requiredGroups).Msg("User group matched")
|
tlog.App.Trace().Str("group", userGroup).Str("required", acls.LDAP.Groups).Msg("User group matched")
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -511,14 +511,14 @@ func (auth *AuthService) IsInLDAPGroup(c *gin.Context, context model.UserContext
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) IsAuthEnabled(uri string, path *model.AppPath) (bool, error) {
|
func (auth *AuthService) IsAuthEnabled(uri string, acls *model.App) (bool, error) {
|
||||||
if path == nil {
|
if acls == nil {
|
||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check for block list
|
// Check for block list
|
||||||
if path.Block != "" {
|
if acls.Path.Block != "" {
|
||||||
regex, err := regexp.Compile(path.Block)
|
regex, err := regexp.Compile(acls.Path.Block)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return true, err
|
return true, err
|
||||||
@@ -530,8 +530,8 @@ func (auth *AuthService) IsAuthEnabled(uri string, path *model.AppPath) (bool, e
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Check for allow list
|
// Check for allow list
|
||||||
if path.Allow != "" {
|
if acls.Path.Allow != "" {
|
||||||
regex, err := regexp.Compile(path.Allow)
|
regex, err := regexp.Compile(acls.Path.Allow)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return true, err
|
return true, err
|
||||||
@@ -545,29 +545,14 @@ func (auth *AuthService) IsAuthEnabled(uri string, path *model.AppPath) (bool, e
|
|||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// local user is used only as a medium to pass the basic auth credentials, user can be ldap too
|
func (auth *AuthService) CheckIP(ip string, acls *model.App) bool {
|
||||||
func (auth *AuthService) GetBasicAuth(req *http.Request) (*model.LocalUser, error) {
|
|
||||||
if req == nil {
|
|
||||||
return nil, errors.New("request is nil")
|
|
||||||
}
|
|
||||||
username, password, ok := req.BasicAuth()
|
|
||||||
if !ok {
|
|
||||||
return nil, errors.New("no basic auth credentials provided")
|
|
||||||
}
|
|
||||||
return &model.LocalUser{
|
|
||||||
Username: username,
|
|
||||||
Password: password,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (auth *AuthService) CheckIP(acls *model.AppIP, ip string) bool {
|
|
||||||
if acls == nil {
|
if acls == nil {
|
||||||
acls = &model.AppIP{}
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// Merge the global and app IP filter
|
// Merge the global and app IP filter
|
||||||
blockedIps := append(auth.config.IP.Block, acls.Block...)
|
blockedIps := append(auth.config.IP.Block, acls.IP.Block...)
|
||||||
allowedIPs := append(auth.config.IP.Allow, acls.Allow...)
|
allowedIPs := append(auth.config.IP.Allow, acls.IP.Allow...)
|
||||||
|
|
||||||
for _, blocked := range blockedIps {
|
for _, blocked := range blockedIps {
|
||||||
res, err := utils.FilterIP(blocked, ip)
|
res, err := utils.FilterIP(blocked, ip)
|
||||||
@@ -602,12 +587,12 @@ func (auth *AuthService) CheckIP(acls *model.AppIP, ip string) bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) IsBypassedIP(acls *model.AppIP, ip string) bool {
|
func (auth *AuthService) IsBypassedIP(ip string, acls *model.App) bool {
|
||||||
if acls == nil {
|
if acls == nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, bypassed := range acls.Bypass {
|
for _, bypassed := range acls.IP.Bypass {
|
||||||
res, err := utils.FilterIP(bypassed, ip)
|
res, err := utils.FilterIP(bypassed, ip)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tlog.App.Warn().Err(err).Str("item", bypassed).Msg("Invalid IP/CIDR in bypass list")
|
tlog.App.Warn().Err(err).Str("item", bypassed).Msg("Invalid IP/CIDR in bypass list")
|
||||||
|
|||||||
@@ -51,19 +51,11 @@ func (docker *DockerService) Init() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (docker *DockerService) getContainers() ([]container.Summary, error) {
|
func (docker *DockerService) getContainers() ([]container.Summary, error) {
|
||||||
containers, err := docker.client.ContainerList(docker.context, container.ListOptions{})
|
return docker.client.ContainerList(docker.context, container.ListOptions{})
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return containers, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (docker *DockerService) inspectContainer(containerId string) (container.InspectResponse, error) {
|
func (docker *DockerService) inspectContainer(containerId string) (container.InspectResponse, error) {
|
||||||
inspect, err := docker.client.ContainerInspect(docker.context, containerId)
|
return docker.client.ContainerInspect(docker.context, containerId)
|
||||||
if err != nil {
|
|
||||||
return container.InspectResponse{}, err
|
|
||||||
}
|
|
||||||
return inspect, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (docker *DockerService) GetLabels(appDomain string) (*model.App, error) {
|
func (docker *DockerService) GetLabels(appDomain string) (*model.App, error) {
|
||||||
|
|||||||
@@ -89,7 +89,7 @@ func (k *KubernetesService) removeIngress(namespace, name string) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (k *KubernetesService) getByDomain(domain string) (*model.App, bool) {
|
func (k *KubernetesService) getByDomain(domain string) *model.App {
|
||||||
k.mu.RLock()
|
k.mu.RLock()
|
||||||
defer k.mu.RUnlock()
|
defer k.mu.RUnlock()
|
||||||
|
|
||||||
@@ -97,15 +97,15 @@ func (k *KubernetesService) getByDomain(domain string) (*model.App, bool) {
|
|||||||
if apps, ok := k.ingressApps[appKey.ingressKey]; ok {
|
if apps, ok := k.ingressApps[appKey.ingressKey]; ok {
|
||||||
for _, app := range apps {
|
for _, app := range apps {
|
||||||
if app.domain == domain && app.appName == appKey.appName {
|
if app.domain == domain && app.appName == appKey.appName {
|
||||||
return &app.app, true
|
return &app.app
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil, false
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (k *KubernetesService) getByAppName(appName string) (*model.App, bool) {
|
func (k *KubernetesService) getByAppName(appName string) *model.App {
|
||||||
k.mu.RLock()
|
k.mu.RLock()
|
||||||
defer k.mu.RUnlock()
|
defer k.mu.RUnlock()
|
||||||
|
|
||||||
@@ -113,12 +113,12 @@ func (k *KubernetesService) getByAppName(appName string) (*model.App, bool) {
|
|||||||
if apps, ok := k.ingressApps[appKey.ingressKey]; ok {
|
if apps, ok := k.ingressApps[appKey.ingressKey]; ok {
|
||||||
for _, app := range apps {
|
for _, app := range apps {
|
||||||
if app.appName == appName {
|
if app.appName == appName {
|
||||||
return &app.app, true
|
return &app.app
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil, false
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (k *KubernetesService) updateFromItem(item *unstructured.Unstructured) {
|
func (k *KubernetesService) updateFromItem(item *unstructured.Unstructured) {
|
||||||
@@ -287,12 +287,14 @@ func (k *KubernetesService) GetLabels(appDomain string) (*model.App, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// First check cache
|
// First check cache
|
||||||
if app, found := k.getByDomain(appDomain); found {
|
app := k.getByDomain(appDomain)
|
||||||
|
if app != nil {
|
||||||
tlog.App.Debug().Str("domain", appDomain).Msg("Found labels in cache by domain")
|
tlog.App.Debug().Str("domain", appDomain).Msg("Found labels in cache by domain")
|
||||||
return app, nil
|
return app, nil
|
||||||
}
|
}
|
||||||
appName := strings.SplitN(appDomain, ".", 2)[0]
|
appName := strings.SplitN(appDomain, ".", 2)[0]
|
||||||
if app, found := k.getByAppName(appName); found {
|
app = k.getByAppName(appName)
|
||||||
|
if app != nil {
|
||||||
tlog.App.Debug().Str("domain", appDomain).Str("appName", appName).Msg("Found labels in cache by app name")
|
tlog.App.Debug().Str("domain", appDomain).Str("appName", appName).Msg("Found labels in cache by app name")
|
||||||
return app, nil
|
return app, nil
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user