diff --git a/internal/auth/auth.go b/internal/auth/auth.go index 521c3c8..9dbd105 100644 --- a/internal/auth/auth.go +++ b/internal/auth/auth.go @@ -3,12 +3,12 @@ package auth import ( "fmt" "regexp" - "slices" "strings" "sync" "time" "tinyauth/internal/docker" "tinyauth/internal/types" + "tinyauth/internal/utils" "github.com/gin-gonic/gin" "github.com/gorilla/sessions" @@ -278,27 +278,14 @@ func (auth *Auth) ResourceAllowed(c *gin.Context, context types.UserContext) (bo // Check if oauth is allowed if context.OAuth { - if len(labels.OAuthWhitelist) == 0 { - return true, nil - } log.Debug().Msg("Checking OAuth whitelist") - if slices.Contains(labels.OAuthWhitelist, context.Username) { - return true, nil - } + return utils.CheckWhitelist(labels.OAuthWhitelist, context.Username), nil } - // Check if user is allowed - if len(labels.Users) != 0 { - log.Debug().Msg("Checking users") - if slices.Contains(labels.Users, context.Username) { - return true, nil - } - } else { - return true, nil - } + // Check users + log.Debug().Msg("Checking users") - // Not allowed - return false, nil + return utils.CheckWhitelist(labels.Users, context.Username), nil } func (auth *Auth) AuthEnabled(c *gin.Context) (bool, error) { diff --git a/internal/types/types.go b/internal/types/types.go index 19d877d..dc6f9c1 100644 --- a/internal/types/types.go +++ b/internal/types/types.go @@ -31,8 +31,8 @@ type SessionCookie struct { // TinyauthLabels is the labels for the tinyauth container type TinyauthLabels struct { - OAuthWhitelist []string - Users []string + OAuthWhitelist string + Users string Allowed string Headers map[string]string } diff --git a/internal/utils/utils.go b/internal/utils/utils.go index b60f8c6..772c125 100644 --- a/internal/utils/utils.go +++ b/internal/utils/utils.go @@ -4,6 +4,7 @@ import ( "errors" "net/url" "os" + "regexp" "slices" "strings" "tinyauth/internal/constants" @@ -188,9 +189,9 @@ func GetTinyauthLabels(labels map[string]string) types.TinyauthLabels { // Add the label value to the tinyauth labels struct switch label { case "tinyauth.oauth.whitelist": - tinyauthLabels.OAuthWhitelist = strings.Split(value, ",") + tinyauthLabels.OAuthWhitelist = value case "tinyauth.users": - tinyauthLabels.Users = strings.Split(value, ",") + tinyauthLabels.Users = value case "tinyauth.allowed": tinyauthLabels.Allowed = value case "tinyauth.headers": @@ -283,3 +284,42 @@ func ParseSecretFile(contents string) string { // Return an empty string return "" } + +// Check if a string matches a regex or a whitelist +func CheckWhitelist(whitelist string, str string) bool { + // Check if the whitelist is empty + if len(whitelist) == 0 { + return true + } + + // Check if the whitelist is a regex + if strings.HasPrefix(whitelist, "/") && strings.HasSuffix(whitelist, "/") { + // Create regex + re, err := regexp.Compile(whitelist[1 : len(whitelist)-1]) + + // Check if there was an error + if err != nil { + log.Error().Err(err).Msg("Error compiling regex") + return false + } + + // Check if the string matches the regex + if re.MatchString(str) { + return true + } + } + + // Split the whitelist by comma + whitelistSplit := strings.Split(whitelist, ",") + + // Loop through the whitelist + for _, item := range whitelistSplit { + // Check if the item matches with the string + if strings.TrimSpace(item) == str { + return true + } + } + + // Return false if no match was found + return false +} diff --git a/internal/utils/utils_test.go b/internal/utils/utils_test.go index e859f8d..041da98 100644 --- a/internal/utils/utils_test.go +++ b/internal/utils/utils_test.go @@ -286,15 +286,19 @@ func TestGetTinyauthLabels(t *testing.T) { // Test the get tinyauth labels function with a valid map labels := map[string]string{ "tinyauth.users": "user1,user2", - "tinyauth.oauth.whitelist": "user1,user2", + "tinyauth.oauth.whitelist": "/regex/", "tinyauth.allowed": "random", "random": "random", + "tinyauth.headers": "X-Header=value", } expected := types.TinyauthLabels{ - Users: []string{"user1", "user2"}, - OAuthWhitelist: []string{"user1", "user2"}, + Users: "user1,user2", + OAuthWhitelist: "/regex/", Allowed: "random", + Headers: map[string]string{ + "X-Header": "value", + }, } result := utils.GetTinyauthLabels(labels) @@ -385,3 +389,81 @@ func TestParseUser(t *testing.T) { t.Fatalf("Expected error parsing user") } } + +// Test the whitelist function +func TestCheckWhitelist(t *testing.T) { + t.Log("Testing check whitelist with a comma whitelist") + + // Create variables + whitelist := "user1,user2,user3" + str := "user1" + expected := true + + // Test the check whitelist function + result := utils.CheckWhitelist(whitelist, str) + + // Check if the result is equal to the expected + if result != expected { + t.Fatalf("Expected %v, got %v", expected, result) + } + + t.Log("Testing check whitelist with a regex whitelist") + + // Create variables + whitelist = "/^user[0-9]+$/" + str = "user1" + expected = true + + // Test the check whitelist function + result = utils.CheckWhitelist(whitelist, str) + + // Check if the result is equal to the expected + if result != expected { + t.Fatalf("Expected %v, got %v", expected, result) + } + + t.Log("Testing check whitelist with an empty whitelist") + + // Create variables + whitelist = "" + str = "user1" + expected = true + + // Test the check whitelist function + result = utils.CheckWhitelist(whitelist, str) + + // Check if the result is equal to the expected + if result != expected { + t.Fatalf("Expected %v, got %v", expected, result) + } + + t.Log("Testing check whitelist with an invalid regex whitelist") + + // Create variables + whitelist = "/^user[0-9+$/" + str = "user1" + expected = false + + // Test the check whitelist function + result = utils.CheckWhitelist(whitelist, str) + + // Check if the result is equal to the expected + if result != expected { + t.Fatalf("Expected %v, got %v", expected, result) + } + + t.Log("Testing check whitelist with a non matching whitelist") + + // Create variables + whitelist = "user1,user2,user3" + str = "user4" + expected = false + + // Test the check whitelist function + result = utils.CheckWhitelist(whitelist, str) + + // Check if the result is equal to the expected + if result != expected { + t.Fatalf("Expected %v, got %v", expected, result) + } +}