diff --git a/internal/utils/app_utils.go b/internal/utils/app_utils.go index 62b9592..57e40f4 100644 --- a/internal/utils/app_utils.go +++ b/internal/utils/app_utils.go @@ -13,13 +13,13 @@ import ( ) // Get root domain parses a hostname and returns the upper domain (e.g. sub1.sub2.domain.com -> sub2.domain.com) -func GetRootDomain(appUrl string) (string, error) { - appUrlParsed, err := url.Parse(appUrl) +func GetRootDomain(u string) (string, error) { + appUrl, err := url.Parse(u) if err != nil { return "", err } - host := appUrlParsed.Hostname() + host := appUrl.Hostname() if netIP := net.ParseIP(host); netIP != nil { return "", errors.New("IP addresses are not allowed") @@ -27,7 +27,7 @@ func GetRootDomain(appUrl string) (string, error) { urlParts := strings.Split(host, ".") - if len(urlParts) < 2 { + if len(urlParts) < 3 { return "", errors.New("invalid domain, must be at least second level domain") } @@ -49,6 +49,7 @@ func ParseFileToLine(content string) string { } func Filter[T any](slice []T, test func(T) bool) (res []T) { + res = make([]T, 0) for _, value := range slice { if test(value) { res = append(res, value) diff --git a/internal/utils/app_utils_test.go b/internal/utils/app_utils_test.go new file mode 100644 index 0000000..1540c76 --- /dev/null +++ b/internal/utils/app_utils_test.go @@ -0,0 +1,197 @@ +package utils_test + +import ( + "testing" + "tinyauth/internal/config" + "tinyauth/internal/utils" + + "github.com/gin-gonic/gin" + "gotest.tools/v3/assert" +) + +func TestGetRootDomain(t *testing.T) { + // Normal case + domain := "http://sub.example.com" + expected := "example.com" + result, err := utils.GetRootDomain(domain) + assert.NilError(t, err) + assert.Equal(t, expected, result) + + // Domain with multiple subdomains + domain = "http://b.c.example.com" + expected = "c.example.com" + result, err = utils.GetRootDomain(domain) + assert.NilError(t, err) + assert.Equal(t, expected, result) + + // Domain with no subdomain + domain = "http://example.com" + expected = "example.com" + _, err = utils.GetRootDomain(domain) + assert.Error(t, err, "invalid domain, must be at least second level domain") + + // Invalid domain (only TLD) + domain = "com" + _, err = utils.GetRootDomain(domain) + assert.ErrorContains(t, err, "invalid domain") + + // IP address + domain = "http://10.10.10.10" + _, err = utils.GetRootDomain(domain) + assert.ErrorContains(t, err, "IP addresses are not allowed") + + // Invalid URL + domain = "http://[::1]:namedport" + _, err = utils.GetRootDomain(domain) + assert.ErrorContains(t, err, "parse \"http://[::1]:namedport\": invalid port \":namedport\" after host") + + // URL with scheme and path + domain = "https://sub.example.com/path" + expected = "example.com" + result, err = utils.GetRootDomain(domain) + assert.NilError(t, err) + assert.Equal(t, expected, result) + + // URL with port + domain = "http://sub.example.com:8080" + expected = "example.com" + result, err = utils.GetRootDomain(domain) + assert.NilError(t, err) + assert.Equal(t, expected, result) +} + +func TestParseFileToLine(t *testing.T) { + // Normal case + content := "user1\nuser2\nuser3" + expected := "user1,user2,user3" + result := utils.ParseFileToLine(content) + assert.Equal(t, expected, result) + + // Case with empty lines and spaces + content = " user1 \n\n user2 \n user3 \n" + expected = "user1,user2,user3" + result = utils.ParseFileToLine(content) + assert.Equal(t, expected, result) + + // Case with only empty lines + content = "\n\n\n" + expected = "" + result = utils.ParseFileToLine(content) + assert.Equal(t, expected, result) + + // Case with single user + content = "singleuser" + expected = "singleuser" + result = utils.ParseFileToLine(content) + assert.Equal(t, expected, result) + + // Case with trailing newline + content = "user1\nuser2\n" + expected = "user1,user2" + result = utils.ParseFileToLine(content) + assert.Equal(t, expected, result) +} + +func TestFilter(t *testing.T) { + // Normal case + slice := []int{1, 2, 3, 4, 5} + testFunc := func(n int) bool { return n%2 == 0 } + expected := []int{2, 4} + result := utils.Filter(slice, testFunc) + assert.DeepEqual(t, expected, result) + + // Case with no matches + slice = []int{1, 3, 5} + testFunc = func(n int) bool { return n%2 == 0 } + expected = []int{} + result = utils.Filter(slice, testFunc) + assert.DeepEqual(t, expected, result) + + // Case with all matches + slice = []int{2, 4, 6} + testFunc = func(n int) bool { return n%2 == 0 } + expected = []int{2, 4, 6} + result = utils.Filter(slice, testFunc) + assert.DeepEqual(t, expected, result) + + // Case with empty slice + slice = []int{} + testFunc = func(n int) bool { return n%2 == 0 } + expected = []int{} + result = utils.Filter(slice, testFunc) + assert.DeepEqual(t, expected, result) + + // Case with different type (string) + sliceStr := []string{"apple", "banana", "cherry"} + testFuncStr := func(s string) bool { return len(s) > 5 } + expectedStr := []string{"banana", "cherry"} + resultStr := utils.Filter(sliceStr, testFuncStr) + assert.DeepEqual(t, expectedStr, resultStr) +} + +func TestGetContext(t *testing.T) { + // Setup + gin.SetMode(gin.TestMode) + c, _ := gin.CreateTestContext(nil) + + // Normal case + c.Set("context", &config.UserContext{Username: "testuser"}) + result, err := utils.GetContext(c) + assert.NilError(t, err) + assert.Equal(t, "testuser", result.Username) + + // Case with no context + c.Set("context", nil) + _, err = utils.GetContext(c) + assert.Error(t, err, "invalid user context in request") + + // Case with invalid context type + c.Set("context", "invalid type") + _, err = utils.GetContext(c) + assert.Error(t, err, "invalid user context in request") +} + +func TestIsRedirectSafe(t *testing.T) { + // Setup + domain := "example.com" + + // Case with no subdomain + redirectURL := "http://example.com/welcome" + result := utils.IsRedirectSafe(redirectURL, domain) + assert.Equal(t, false, result) + + // Case with different domain + redirectURL = "http://malicious.com/phishing" + result = utils.IsRedirectSafe(redirectURL, domain) + assert.Equal(t, false, result) + + // Case with subdomain + redirectURL = "http://sub.example.com/page" + result = utils.IsRedirectSafe(redirectURL, domain) + assert.Equal(t, true, result) + + // Case with empty redirect URL + redirectURL = "" + result = utils.IsRedirectSafe(redirectURL, domain) + assert.Equal(t, false, result) + + // Case with invalid URL + redirectURL = "http://[::1]:namedport" + result = utils.IsRedirectSafe(redirectURL, domain) + assert.Equal(t, false, result) + + // Case with URL having port + redirectURL = "http://sub.example.com:8080/page" + result = utils.IsRedirectSafe(redirectURL, domain) + assert.Equal(t, true, result) + + // Case with URL having different subdomain + redirectURL = "http://another.example.com/page" + result = utils.IsRedirectSafe(redirectURL, domain) + assert.Equal(t, true, result) + + // Case with URL having different TLD + redirectURL = "http://example.org/page" + result = utils.IsRedirectSafe(redirectURL, domain) + assert.Equal(t, false, result) +} diff --git a/internal/utils/fs_utils_test.go b/internal/utils/fs_utils_test.go new file mode 100644 index 0000000..54033ba --- /dev/null +++ b/internal/utils/fs_utils_test.go @@ -0,0 +1,31 @@ +package utils + +import ( + "os" + "testing" + + "gotest.tools/v3/assert" +) + +func TestReadFile(t *testing.T) { + // Setup + file, err := os.Create("/tmp/tinyauth_test_file") + assert.NilError(t, err) + + _, err = file.WriteString("file content\n") + assert.NilError(t, err) + + err = file.Close() + assert.NilError(t, err) + defer os.Remove("/tmp/tinyauth_test_file") + + // Normal case + content, err := ReadFile("/tmp/tinyauth_test_file") + assert.NilError(t, err) + assert.Equal(t, "file content\n", content) + + // Non-existing file + content, err = ReadFile("/tmp/non_existing_file") + assert.ErrorContains(t, err, "no such file or directory") + assert.Equal(t, "", content) +} diff --git a/internal/utils/label_utils_test.go b/internal/utils/label_utils_test.go new file mode 100644 index 0000000..f38302d --- /dev/null +++ b/internal/utils/label_utils_test.go @@ -0,0 +1,87 @@ +package utils_test + +import ( + "testing" + "tinyauth/internal/utils" + + "gotest.tools/v3/assert" +) + +func TestParseHeaders(t *testing.T) { + // Normal case + headers := []string{ + "X-Custom-Header=Value", + "Another-Header=AnotherValue", + } + expected := map[string]string{ + "X-Custom-Header": "Value", + "Another-Header": "AnotherValue", + } + assert.DeepEqual(t, expected, utils.ParseHeaders(headers)) + + // Case insensitivity and trimming + headers = []string{ + " x-custom-header = Value ", + "ANOTHER-HEADER=AnotherValue", + } + expected = map[string]string{ + "X-Custom-Header": "Value", + "Another-Header": "AnotherValue", + } + assert.DeepEqual(t, expected, utils.ParseHeaders(headers)) + + // Invalid headers (missing '=', empty key/value) + headers = []string{ + "InvalidHeader", + "=NoKey", + "NoValue=", + " = ", + } + expected = map[string]string{} + assert.DeepEqual(t, expected, utils.ParseHeaders(headers)) + + // Headers with unsafe characters + headers = []string{ + "X-Custom-Header=Val\x00ue", // Null byte + "Another-Header=Anoth\x7FerValue", // DEL character + "Good-Header=GoodValue", + } + expected = map[string]string{ + "X-Custom-Header": "Value", + "Another-Header": "AnotherValue", + "Good-Header": "GoodValue", + } + assert.DeepEqual(t, expected, utils.ParseHeaders(headers)) + + // Header with spaces in key (should be ignored) + headers = []string{ + "X Custom Header=Value", + "Valid-Header=ValidValue", + } + expected = map[string]string{ + "Valid-Header": "ValidValue", + } + assert.DeepEqual(t, expected, utils.ParseHeaders(headers)) +} + +func TestSanitizeHeader(t *testing.T) { + // Normal case + header := "X-Custom-Header" + expected := "X-Custom-Header" + assert.Equal(t, expected, utils.SanitizeHeader(header)) + + // Header with unsafe characters + header = "X-Cust\x00om-Hea\x7Fder" // Null byte and DEL character + expected = "X-Custom-Header" + assert.Equal(t, expected, utils.SanitizeHeader(header)) + + // Header with only unsafe characters + header = "\x00\x01\x02\x7F" + expected = "" + assert.Equal(t, expected, utils.SanitizeHeader(header)) + + // Header with spaces and tabs (should be preserved) + header = "X Custom\tHeader" + expected = "X Custom\tHeader" + assert.Equal(t, expected, utils.SanitizeHeader(header)) +} diff --git a/internal/utils/security_utils.go b/internal/utils/security_utils.go index 85a359d..91e17ee 100644 --- a/internal/utils/security_utils.go +++ b/internal/utils/security_utils.go @@ -48,6 +48,12 @@ func GetBasicAuth(username string, password string) string { func FilterIP(filter string, ip string) (bool, error) { ipAddr := net.ParseIP(ip) + if ipAddr == nil { + return false, errors.New("invalid IP address") + } + + filter = strings.Replace(filter, "-", "/", -1) + if strings.Contains(filter, "/") { _, cidr, err := net.ParseCIDR(filter) if err != nil { @@ -73,8 +79,6 @@ func CheckFilter(filter string, str string) bool { return true } - filter = strings.Replace(filter, "-", "/", -1) - if strings.HasPrefix(filter, "/") && strings.HasSuffix(filter, "/") { re, err := regexp.Compile(filter[1 : len(filter)-1]) if err != nil { diff --git a/internal/utils/security_utils_test.go b/internal/utils/security_utils_test.go new file mode 100644 index 0000000..941f853 --- /dev/null +++ b/internal/utils/security_utils_test.go @@ -0,0 +1,151 @@ +package utils_test + +import ( + "os" + "testing" + "tinyauth/internal/utils" + + "gotest.tools/v3/assert" +) + +func TestGetSecret(t *testing.T) { + // Setup + file, err := os.Create("/tmp/tinyauth_test_secret") + assert.NilError(t, err) + + _, err = file.WriteString(" secret \n") + assert.NilError(t, err) + + err = file.Close() + assert.NilError(t, err) + defer os.Remove("/tmp/tinyauth_test_secret") + + // Get from config + assert.Equal(t, "mysecret", utils.GetSecret("mysecret", "")) + + // Get from file + assert.Equal(t, "secret", utils.GetSecret("", "/tmp/tinyauth_test_secret")) + + // Get from both (config should take precedence) + assert.Equal(t, "mysecret", utils.GetSecret("mysecret", "/tmp/tinyauth_test_secret")) + + // Get from none + assert.Equal(t, "", utils.GetSecret("", "")) + + // Get from non-existing file + assert.Equal(t, "", utils.GetSecret("", "/tmp/non_existing_file")) +} + +func TestParseSecretFile(t *testing.T) { + // Normal case + content := " mysecret \n" + assert.Equal(t, "mysecret", utils.ParseSecretFile(content)) + + // Multiple lines (should take the first non-empty line) + content = "\n\n firstsecret \nsecondsecret\n" + assert.Equal(t, "firstsecret", utils.ParseSecretFile(content)) + + // All empty lines + content = "\n \n \n" + assert.Equal(t, "", utils.ParseSecretFile(content)) + + // Empty content + content = "" + assert.Equal(t, "", utils.ParseSecretFile(content)) +} + +func TestGetBasicAuth(t *testing.T) { + // Normal case + username := "user" + password := "pass" + expected := "dXNlcjpwYXNz" // base64 of "user:pass" + assert.Equal(t, expected, utils.GetBasicAuth(username, password)) + + // Empty username + username = "" + password = "pass" + expected = "OnBhc3M=" // base64 of ":pass" + assert.Equal(t, expected, utils.GetBasicAuth(username, password)) + + // Empty password + username = "user" + password = "" + expected = "dXNlcjo=" // base64 of "user:" + assert.Equal(t, expected, utils.GetBasicAuth(username, password)) +} + +func TestFilterIP(t *testing.T) { + // Exact match IPv4 + ok, err := utils.FilterIP("10.10.0.1", "10.10.0.1") + assert.NilError(t, err) + assert.Equal(t, true, ok) + + // Non-match IPv4 + ok, err = utils.FilterIP("10.10.0.1", "10.10.0.2") + assert.NilError(t, err) + assert.Equal(t, false, ok) + + // CIDR match IPv4 + ok, err = utils.FilterIP("10.10.0.0/24", "10.10.0.2") + assert.NilError(t, err) + assert.Equal(t, true, ok) + + // CIDR match IPv4 with '-' instead of '/' + ok, err = utils.FilterIP("10.10.10.0-24", "10.10.10.5") + assert.NilError(t, err) + assert.Equal(t, true, ok) + + // CIDR non-match IPv4 + ok, err = utils.FilterIP("10.10.0.0/24", "10.5.0.1") + assert.NilError(t, err) + assert.Equal(t, false, ok) + + // Invalid CIDR + ok, err = utils.FilterIP("10.10.0.0/222", "10.0.0.1") + assert.ErrorContains(t, err, "invalid CIDR address") + assert.Equal(t, false, ok) + + // Invalid IP in filter + ok, err = utils.FilterIP("invalid_ip", "10.5.5.5") + assert.ErrorContains(t, err, "invalid IP address in filter") + assert.Equal(t, false, ok) + + // Invalid IP to check + ok, err = utils.FilterIP("10.10.10.10", "invalid_ip") + assert.ErrorContains(t, err, "invalid IP address") + assert.Equal(t, false, ok) +} + +func TestCheckFilter(t *testing.T) { + // Empty filter + assert.Equal(t, true, utils.CheckFilter("", "anystring")) + + // Exact match + assert.Equal(t, true, utils.CheckFilter("hello", "hello")) + + // Regex match + assert.Equal(t, true, utils.CheckFilter("/^h.*o$/", "hello")) + + // Invalid regex + assert.Equal(t, false, utils.CheckFilter("/[unclosed", "test")) + + // Comma-separated values + assert.Equal(t, true, utils.CheckFilter("apple, banana, cherry", "banana")) + + // No match + assert.Equal(t, false, utils.CheckFilter("apple, banana, cherry", "grape")) +} + +func TestGenerateIdentifier(t *testing.T) { + // Consistent output for same input + id1 := utils.GenerateIdentifier("teststring") + id2 := utils.GenerateIdentifier("teststring") + assert.Equal(t, id1, id2) + + // Different output for different input + id3 := utils.GenerateIdentifier("differentstring") + assert.Assert(t, id1 != id3) + + // Check length (should be 8 characters from first segment of UUID) + assert.Equal(t, 8, len(id1)) +} diff --git a/internal/utils/string_utils_test.go b/internal/utils/string_utils_test.go new file mode 100644 index 0000000..3677eb6 --- /dev/null +++ b/internal/utils/string_utils_test.go @@ -0,0 +1,50 @@ +package utils_test + +import ( + "testing" + "tinyauth/internal/utils" + + "gotest.tools/v3/assert" +) + +func TestCapitalize(t *testing.T) { + // Test empty string + assert.Equal(t, "", utils.Capitalize("")) + + // Test single character + assert.Equal(t, "A", utils.Capitalize("a")) + + // Test multiple characters + assert.Equal(t, "Hello", utils.Capitalize("hello")) + + // Test already capitalized + assert.Equal(t, "World", utils.Capitalize("World")) + + // Test non-alphabetic first character + assert.Equal(t, "1number", utils.Capitalize("1number")) + + // Test Unicode characters + assert.Equal(t, "Γειά", utils.Capitalize("γειά")) + assert.Equal(t, "Привет", utils.Capitalize("привет")) + +} + +func TestCoalesceToString(t *testing.T) { + // Test with []any containing strings + assert.Equal(t, "a,b,c", utils.CoalesceToString([]any{"a", "b", "c"})) + + // Test with []any containing mixed types + assert.Equal(t, "a,c", utils.CoalesceToString([]any{"a", 1, "c", true})) + + // Test with []any containing no strings + assert.Equal(t, "", utils.CoalesceToString([]any{1, 2, 3})) + + // Test with string input + assert.Equal(t, "hello", utils.CoalesceToString("hello")) + + // Test with non-string, non-[]any input + assert.Equal(t, "", utils.CoalesceToString(123)) + + // Test with nil input + assert.Equal(t, "", utils.CoalesceToString(nil)) +} diff --git a/internal/utils/user_utils_test.go b/internal/utils/user_utils_test.go new file mode 100644 index 0000000..d04636a --- /dev/null +++ b/internal/utils/user_utils_test.go @@ -0,0 +1,163 @@ +package utils_test + +import ( + "os" + "testing" + "tinyauth/internal/utils" + + "gotest.tools/v3/assert" +) + +func TestGetUsers(t *testing.T) { + // Setup + file, err := os.Create("/tmp/tinyauth_users_test.txt") + assert.NilError(t, err) + + _, err = file.WriteString(" user1:$2a$10$Mz5xhkfSJUtPWkzCd/TdaePh9CaXc5QcGII5wIMPLSR46eTwma30G \n user2:$2a$10$Mz5xhkfSJUtPWkzCd/TdaePh9CaXc5QcGII5wIMPLSR46eTwma30G ") // Spacing is on purpose + assert.NilError(t, err) + + err = file.Close() + assert.NilError(t, err) + defer os.Remove("/tmp/tinyauth_users_test.txt") + + // Test file + users, err := utils.GetUsers("", "/tmp/tinyauth_users_test.txt") + + assert.NilError(t, err) + + assert.Equal(t, 2, len(users)) + + assert.Equal(t, "user1", users[0].Username) + assert.Equal(t, "$2a$10$Mz5xhkfSJUtPWkzCd/TdaePh9CaXc5QcGII5wIMPLSR46eTwma30G", users[0].Password) + assert.Equal(t, "user2", users[1].Username) + assert.Equal(t, "$2a$10$Mz5xhkfSJUtPWkzCd/TdaePh9CaXc5QcGII5wIMPLSR46eTwma30G", users[1].Password) + + // Test config + users, err = utils.GetUsers("user3:$2a$10$Mz5xhkfSJUtPWkzCd/TdaePh9CaXc5QcGII5wIMPLSR46eTwma30G,user4:$2a$10$Mz5xhkfSJUtPWkzCd/TdaePh9CaXc5QcGII5wIMPLSR46eTwma30G", "") + + assert.NilError(t, err) + + assert.Equal(t, 2, len(users)) + + assert.Equal(t, "user3", users[0].Username) + assert.Equal(t, "$2a$10$Mz5xhkfSJUtPWkzCd/TdaePh9CaXc5QcGII5wIMPLSR46eTwma30G", users[0].Password) + assert.Equal(t, "user4", users[1].Username) + assert.Equal(t, "$2a$10$Mz5xhkfSJUtPWkzCd/TdaePh9CaXc5QcGII5wIMPLSR46eTwma30G", users[1].Password) + + // Test both + users, err = utils.GetUsers("user5:$2a$10$Mz5xhkfSJUtPWkzCd/TdaePh9CaXc5QcGII5wIMPLSR46eTwma30G", "/tmp/tinyauth_users_test.txt") + + assert.NilError(t, err) + + assert.Equal(t, 3, len(users)) + + assert.Equal(t, "user5", users[0].Username) + assert.Equal(t, "$2a$10$Mz5xhkfSJUtPWkzCd/TdaePh9CaXc5QcGII5wIMPLSR46eTwma30G", users[0].Password) + assert.Equal(t, "user1", users[1].Username) + assert.Equal(t, "$2a$10$Mz5xhkfSJUtPWkzCd/TdaePh9CaXc5QcGII5wIMPLSR46eTwma30G", users[1].Password) + assert.Equal(t, "user2", users[2].Username) + assert.Equal(t, "$2a$10$Mz5xhkfSJUtPWkzCd/TdaePh9CaXc5QcGII5wIMPLSR46eTwma30G", users[2].Password) + + // Test empty + users, err = utils.GetUsers("", "") + + assert.NilError(t, err) + + assert.Equal(t, 0, len(users)) + + // Test non-existent file + users, err = utils.GetUsers("", "/tmp/non_existent_file.txt") + + assert.ErrorContains(t, err, "no such file or directory") + + assert.Equal(t, 0, len(users)) +} + +func TestParseUsers(t *testing.T) { + // Valid users + users, err := utils.ParseUsers("user1:$2a$10$Mz5xhkfSJUtPWkzCd/TdaePh9CaXc5QcGII5wIMPLSR46eTwma30G,user2:$2a$10$Mz5xhkfSJUtPWkzCd/TdaePh9CaXc5QcGII5wIMPLSR46eTwma30G:ABCDEF") // user2 has TOTP + + assert.NilError(t, err) + + assert.Equal(t, 2, len(users)) + + assert.Equal(t, "user1", users[0].Username) + assert.Equal(t, "$2a$10$Mz5xhkfSJUtPWkzCd/TdaePh9CaXc5QcGII5wIMPLSR46eTwma30G", users[0].Password) + assert.Equal(t, "", users[0].TotpSecret) + assert.Equal(t, "user2", users[1].Username) + assert.Equal(t, "$2a$10$Mz5xhkfSJUtPWkzCd/TdaePh9CaXc5QcGII5wIMPLSR46eTwma30G", users[1].Password) + assert.Equal(t, "ABCDEF", users[1].TotpSecret) + + // Valid weirdly spaced users + users, err = utils.ParseUsers(" user1:$2a$10$Mz5xhkfSJUtPWkzCd/TdaePh9CaXc5QcGII5wIMPLSR46eTwma30G , user2:$2a$10$Mz5xhkfSJUtPWkzCd/TdaePh9CaXc5QcGII5wIMPLSR46eTwma30G:ABCDEF ") // Spacing is on purpose + assert.NilError(t, err) + + assert.Equal(t, 2, len(users)) + + assert.Equal(t, "user1", users[0].Username) + assert.Equal(t, "$2a$10$Mz5xhkfSJUtPWkzCd/TdaePh9CaXc5QcGII5wIMPLSR46eTwma30G", users[0].Password) + assert.Equal(t, "", users[0].TotpSecret) + assert.Equal(t, "user2", users[1].Username) + assert.Equal(t, "$2a$10$Mz5xhkfSJUtPWkzCd/TdaePh9CaXc5QcGII5wIMPLSR46eTwma30G", users[1].Password) + assert.Equal(t, "ABCDEF", users[1].TotpSecret) +} + +func TestParseUser(t *testing.T) { + // Valid user without TOTP + user, err := utils.ParseUser("user1:$2a$10$Mz5xhkfSJUtPWkzCd/TdaePh9CaXc5QcGII5wIMPLSR46eTwma30G") + + assert.NilError(t, err) + + assert.Equal(t, "user1", user.Username) + assert.Equal(t, "$2a$10$Mz5xhkfSJUtPWkzCd/TdaePh9CaXc5QcGII5wIMPLSR46eTwma30G", user.Password) + assert.Equal(t, "", user.TotpSecret) + + // Valid user with TOTP + user, err = utils.ParseUser("user2:$2a$10$Mz5xhkfSJUtPWkzCd/TdaePh9CaXc5QcGII5wIMPLSR46eTwma30G:ABCDEF") + + assert.NilError(t, err) + + assert.Equal(t, "user2", user.Username) + assert.Equal(t, "$2a$10$Mz5xhkfSJUtPWkzCd/TdaePh9CaXc5QcGII5wIMPLSR46eTwma30G", user.Password) + assert.Equal(t, "ABCDEF", user.TotpSecret) + + // Valid user with $$ in password + user, err = utils.ParseUser("user3:pa$$word123") + + assert.NilError(t, err) + + assert.Equal(t, "user3", user.Username) + assert.Equal(t, "pa$word123", user.Password) + assert.Equal(t, "", user.TotpSecret) + + // User with spaces + user, err = utils.ParseUser(" user4 : password123 : TOTPSECRET ") + + assert.NilError(t, err) + + assert.Equal(t, "user4", user.Username) + assert.Equal(t, "password123", user.Password) + assert.Equal(t, "TOTPSECRET", user.TotpSecret) + + // Invalid users + _, err = utils.ParseUser("user1") // Missing password + assert.ErrorContains(t, err, "invalid user format") + + _, err = utils.ParseUser("user1:") + assert.ErrorContains(t, err, "invalid user format") + + _, err = utils.ParseUser(":password123") + assert.ErrorContains(t, err, "invalid user format") + + _, err = utils.ParseUser("user1:password123:ABC:EXTRA") // Too many parts + assert.ErrorContains(t, err, "invalid user format") + + _, err = utils.ParseUser("user1::ABC") + assert.ErrorContains(t, err, "invalid user format") + + _, err = utils.ParseUser(":password123:ABC") + assert.ErrorContains(t, err, "invalid user format") + + _, err = utils.ParseUser(" : : ") + assert.ErrorContains(t, err, "invalid user format") +}