mirror of
https://github.com/steveiliop56/tinyauth.git
synced 2026-05-18 02:00:12 +00:00
wip: use policy engine for acls
This commit is contained in:
@@ -3,7 +3,7 @@ package utils
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"regexp"
|
||||
"strings"
|
||||
@@ -46,26 +46,27 @@ func EncodeBasicAuth(username string, password string) string {
|
||||
return base64.StdEncoding.EncodeToString([]byte(auth))
|
||||
}
|
||||
|
||||
func FilterIP(filter string, ip string) (bool, error) {
|
||||
func CheckIPFilter(filter string, ip string) (bool, error) {
|
||||
ipAddr := net.ParseIP(ip)
|
||||
|
||||
if ipAddr == nil {
|
||||
return false, errors.New("invalid IP address")
|
||||
return false, fmt.Errorf("invalid ip address")
|
||||
}
|
||||
|
||||
filter = strings.Replace(filter, "-", "/", -1)
|
||||
filter = strings.ReplaceAll(filter, "-", "/")
|
||||
|
||||
if strings.Contains(filter, "/") {
|
||||
_, cidr, err := net.ParseCIDR(filter)
|
||||
if err != nil {
|
||||
return false, err
|
||||
return false, fmt.Errorf("invalid cidr notation: %w", err)
|
||||
}
|
||||
return cidr.Contains(ipAddr), nil
|
||||
}
|
||||
|
||||
ipFilter := net.ParseIP(filter)
|
||||
|
||||
if ipFilter == nil {
|
||||
return false, errors.New("invalid IP address in filter")
|
||||
return false, fmt.Errorf("invalid ip address")
|
||||
}
|
||||
|
||||
if ipFilter.Equal(ipAddr) {
|
||||
@@ -75,31 +76,29 @@ func FilterIP(filter string, ip string) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func CheckFilter(filter string, str string) bool {
|
||||
func CheckFilter(filter string, input string) (bool, error) {
|
||||
if len(strings.TrimSpace(filter)) == 0 {
|
||||
return true
|
||||
return false, fmt.Errorf("filter is empty")
|
||||
}
|
||||
|
||||
if strings.HasPrefix(filter, "/") && strings.HasSuffix(filter, "/") {
|
||||
re, err := regexp.Compile(filter[1 : len(filter)-1])
|
||||
if err != nil {
|
||||
return false
|
||||
return false, fmt.Errorf("invalid regex filter: %w", err)
|
||||
}
|
||||
|
||||
if re.MatchString(strings.TrimSpace(str)) {
|
||||
return true
|
||||
if re.MatchString(input) {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
filterSplit := strings.Split(filter, ",")
|
||||
|
||||
for _, item := range filterSplit {
|
||||
if strings.TrimSpace(item) == strings.TrimSpace(str) {
|
||||
return true
|
||||
for item := range strings.SplitSeq(filter, ",") {
|
||||
if strings.TrimSpace(item) == input {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func GenerateUUID(str string) string {
|
||||
|
||||
@@ -75,66 +75,77 @@ func TestEncodeBasicAuth(t *testing.T) {
|
||||
assert.Equal(t, expected, utils.EncodeBasicAuth(username, password))
|
||||
}
|
||||
|
||||
func TestFilterIP(t *testing.T) {
|
||||
func TestCheckIPFilter(t *testing.T) {
|
||||
// Exact match IPv4
|
||||
ok, err := utils.FilterIP("10.10.0.1", "10.10.0.1")
|
||||
ok, err := utils.CheckIPFilter("10.10.0.1", "10.10.0.1")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, true, ok)
|
||||
|
||||
// Non-match IPv4
|
||||
ok, err = utils.FilterIP("10.10.0.1", "10.10.0.2")
|
||||
ok, err = utils.CheckIPFilter("10.10.0.1", "10.10.0.2")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, false, ok)
|
||||
|
||||
// CIDR match IPv4
|
||||
ok, err = utils.FilterIP("10.10.0.0/24", "10.10.0.2")
|
||||
ok, err = utils.CheckIPFilter("10.10.0.0/24", "10.10.0.2")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, true, ok)
|
||||
|
||||
// CIDR match IPv4 with '-' instead of '/'
|
||||
ok, err = utils.FilterIP("10.10.10.0-24", "10.10.10.5")
|
||||
ok, err = utils.CheckIPFilter("10.10.10.0-24", "10.10.10.5")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, true, ok)
|
||||
|
||||
// CIDR non-match IPv4
|
||||
ok, err = utils.FilterIP("10.10.0.0/24", "10.5.0.1")
|
||||
ok, err = utils.CheckIPFilter("10.10.0.0/24", "10.5.0.1")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, false, ok)
|
||||
|
||||
// Invalid CIDR
|
||||
ok, err = utils.FilterIP("10.10.0.0/222", "10.0.0.1")
|
||||
assert.ErrorContains(t, err, "invalid CIDR address")
|
||||
ok, err = utils.CheckIPFilter("10.10.0.0/222", "10.0.0.1")
|
||||
assert.ErrorContains(t, err, "invalid cidr notation: invalid CIDR address: 10.10.0.0/222")
|
||||
assert.Equal(t, false, ok)
|
||||
|
||||
// Invalid IP in filter
|
||||
ok, err = utils.FilterIP("invalid_ip", "10.5.5.5")
|
||||
assert.ErrorContains(t, err, "invalid IP address in filter")
|
||||
ok, err = utils.CheckIPFilter("invalid_ip", "10.5.5.5")
|
||||
assert.ErrorContains(t, err, "invalid ip address")
|
||||
assert.Equal(t, false, ok)
|
||||
|
||||
// Invalid IP to check
|
||||
ok, err = utils.FilterIP("10.10.10.10", "invalid_ip")
|
||||
assert.ErrorContains(t, err, "invalid IP address")
|
||||
ok, err = utils.CheckIPFilter("10.10.10.10", "invalid_ip")
|
||||
assert.ErrorContains(t, err, "invalid ip address")
|
||||
assert.Equal(t, false, ok)
|
||||
}
|
||||
|
||||
func TestCheckFilter(t *testing.T) {
|
||||
// Empty filter
|
||||
assert.Equal(t, true, utils.CheckFilter("", "anystring"))
|
||||
_, err := utils.CheckFilter("", "anystring")
|
||||
assert.ErrorContains(t, err, "filter is empty")
|
||||
|
||||
// Exact match
|
||||
assert.Equal(t, true, utils.CheckFilter("hello", "hello"))
|
||||
ok, err := utils.CheckFilter("hello", "hello")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, true, ok)
|
||||
|
||||
// Regex match
|
||||
assert.Equal(t, true, utils.CheckFilter("/^h.*o$/", "hello"))
|
||||
ok, err = utils.CheckFilter("/^h.*o$/", "hello")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, true, ok)
|
||||
|
||||
// Invalid regex
|
||||
assert.Equal(t, false, utils.CheckFilter("/[unclosed", "test"))
|
||||
ok, err = utils.CheckFilter("/[unclosed/", "test")
|
||||
assert.ErrorContains(t, err, "invalid regex")
|
||||
assert.Equal(t, false, ok)
|
||||
|
||||
// Comma-separated values
|
||||
assert.Equal(t, true, utils.CheckFilter("apple, banana, cherry", "banana"))
|
||||
ok, err = utils.CheckFilter("apple, banana, cherry", "banana")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, true, ok)
|
||||
|
||||
// No match
|
||||
assert.Equal(t, false, utils.CheckFilter("apple, banana, cherry", "grape"))
|
||||
ok, err = utils.CheckFilter("apple, banana, cherry", "grape")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, false, ok)
|
||||
}
|
||||
|
||||
func TestGenerateUUID(t *testing.T) {
|
||||
|
||||
Reference in New Issue
Block a user