tests: move to testify for testing in utils

This commit is contained in:
Stavros
2026-05-04 20:25:16 +03:00
parent ff3c25c09d
commit 8f337aaff8
7 changed files with 59 additions and 64 deletions
+20 -21
View File
@@ -3,9 +3,8 @@ package utils_test
import ( import (
"testing" "testing"
"github.com/stretchr/testify/assert"
"github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils"
"gotest.tools/v3/assert"
) )
func TestGetRootDomain(t *testing.T) { func TestGetRootDomain(t *testing.T) {
@@ -13,14 +12,14 @@ func TestGetRootDomain(t *testing.T) {
domain := "http://sub.tinyauth.app" domain := "http://sub.tinyauth.app"
expected := "tinyauth.app" expected := "tinyauth.app"
result, err := utils.GetCookieDomain(domain) result, err := utils.GetCookieDomain(domain)
assert.NilError(t, err) assert.NoError(t, err)
assert.Equal(t, expected, result) assert.Equal(t, expected, result)
// Domain with multiple subdomains // Domain with multiple subdomains
domain = "http://b.c.tinyauth.app" domain = "http://b.c.tinyauth.app"
expected = "c.tinyauth.app" expected = "c.tinyauth.app"
result, err = utils.GetCookieDomain(domain) result, err = utils.GetCookieDomain(domain)
assert.NilError(t, err) assert.NoError(t, err)
assert.Equal(t, expected, result) assert.Equal(t, expected, result)
// Invalid domain (only TLD) // Invalid domain (only TLD)
@@ -42,14 +41,14 @@ func TestGetRootDomain(t *testing.T) {
domain = "https://sub.tinyauth.app/path" domain = "https://sub.tinyauth.app/path"
expected = "tinyauth.app" expected = "tinyauth.app"
result, err = utils.GetCookieDomain(domain) result, err = utils.GetCookieDomain(domain)
assert.NilError(t, err) assert.NoError(t, err)
assert.Equal(t, expected, result) assert.Equal(t, expected, result)
// URL with port // URL with port
domain = "http://sub.tinyauth.app:8080" domain = "http://sub.tinyauth.app:8080"
expected = "tinyauth.app" expected = "tinyauth.app"
result, err = utils.GetCookieDomain(domain) result, err = utils.GetCookieDomain(domain)
assert.NilError(t, err) assert.NoError(t, err)
assert.Equal(t, expected, result) assert.Equal(t, expected, result)
// Domain managed by ICANN // Domain managed by ICANN
@@ -96,35 +95,35 @@ func TestFilter(t *testing.T) {
testFunc := func(n int) bool { return n%2 == 0 } testFunc := func(n int) bool { return n%2 == 0 }
expected := []int{2, 4} expected := []int{2, 4}
result := utils.Filter(slice, testFunc) result := utils.Filter(slice, testFunc)
assert.DeepEqual(t, expected, result) assert.Equal(t, expected, result)
// Case with no matches // Case with no matches
slice = []int{1, 3, 5} slice = []int{1, 3, 5}
testFunc = func(n int) bool { return n%2 == 0 } testFunc = func(n int) bool { return n%2 == 0 }
expected = []int{} expected = []int{}
result = utils.Filter(slice, testFunc) result = utils.Filter(slice, testFunc)
assert.DeepEqual(t, expected, result) assert.Equal(t, expected, result)
// Case with all matches // Case with all matches
slice = []int{2, 4, 6} slice = []int{2, 4, 6}
testFunc = func(n int) bool { return n%2 == 0 } testFunc = func(n int) bool { return n%2 == 0 }
expected = []int{2, 4, 6} expected = []int{2, 4, 6}
result = utils.Filter(slice, testFunc) result = utils.Filter(slice, testFunc)
assert.DeepEqual(t, expected, result) assert.Equal(t, expected, result)
// Case with empty slice // Case with empty slice
slice = []int{} slice = []int{}
testFunc = func(n int) bool { return n%2 == 0 } testFunc = func(n int) bool { return n%2 == 0 }
expected = []int{} expected = []int{}
result = utils.Filter(slice, testFunc) result = utils.Filter(slice, testFunc)
assert.DeepEqual(t, expected, result) assert.Equal(t, expected, result)
// Case with different type (string) // Case with different type (string)
sliceStr := []string{"apple", "banana", "cherry"} sliceStr := []string{"apple", "banana", "cherry"}
testFuncStr := func(s string) bool { return len(s) > 5 } testFuncStr := func(s string) bool { return len(s) > 5 }
expectedStr := []string{"banana", "cherry"} expectedStr := []string{"banana", "cherry"}
resultStr := utils.Filter(sliceStr, testFuncStr) resultStr := utils.Filter(sliceStr, testFuncStr)
assert.DeepEqual(t, expectedStr, resultStr) assert.Equal(t, expectedStr, resultStr)
} }
func TestIsRedirectSafe(t *testing.T) { func TestIsRedirectSafe(t *testing.T) {
@@ -134,50 +133,50 @@ func TestIsRedirectSafe(t *testing.T) {
// Case with no subdomain // Case with no subdomain
redirectURL := "http://example.com/welcome" redirectURL := "http://example.com/welcome"
result := utils.IsRedirectSafe(redirectURL, domain) result := utils.IsRedirectSafe(redirectURL, domain)
assert.Equal(t, true, result) assert.True(t, result)
// Case with different domain // Case with different domain
redirectURL = "http://malicious.com/phishing" redirectURL = "http://malicious.com/phishing"
result = utils.IsRedirectSafe(redirectURL, domain) result = utils.IsRedirectSafe(redirectURL, domain)
assert.Equal(t, false, result) assert.False(t, result)
// Case with subdomain // Case with subdomain
redirectURL = "http://sub.example.com/page" redirectURL = "http://sub.example.com/page"
result = utils.IsRedirectSafe(redirectURL, domain) result = utils.IsRedirectSafe(redirectURL, domain)
assert.Equal(t, true, result) assert.True(t, result)
// Case with sub-subdomain // Case with sub-subdomain
redirectURL = "http://a.b.example.com/home" redirectURL = "http://a.b.example.com/home"
result = utils.IsRedirectSafe(redirectURL, domain) result = utils.IsRedirectSafe(redirectURL, domain)
assert.Equal(t, true, result) assert.True(t, result)
// Case with empty redirect URL // Case with empty redirect URL
redirectURL = "" redirectURL = ""
result = utils.IsRedirectSafe(redirectURL, domain) result = utils.IsRedirectSafe(redirectURL, domain)
assert.Equal(t, false, result) assert.False(t, result)
// Case with invalid URL // Case with invalid URL
redirectURL = "http://[::1]:namedport" redirectURL = "http://[::1]:namedport"
result = utils.IsRedirectSafe(redirectURL, domain) result = utils.IsRedirectSafe(redirectURL, domain)
assert.Equal(t, false, result) assert.False(t, result)
// Case with URL having port // Case with URL having port
redirectURL = "http://sub.example.com:8080/page" redirectURL = "http://sub.example.com:8080/page"
result = utils.IsRedirectSafe(redirectURL, domain) result = utils.IsRedirectSafe(redirectURL, domain)
assert.Equal(t, true, result) assert.True(t, result)
// Case with URL having different subdomain // Case with URL having different subdomain
redirectURL = "http://another.example.com/page" redirectURL = "http://another.example.com/page"
result = utils.IsRedirectSafe(redirectURL, domain) result = utils.IsRedirectSafe(redirectURL, domain)
assert.Equal(t, true, result) assert.True(t, result)
// Case with URL having different TLD // Case with URL having different TLD
redirectURL = "http://example.org/page" redirectURL = "http://example.org/page"
result = utils.IsRedirectSafe(redirectURL, domain) result = utils.IsRedirectSafe(redirectURL, domain)
assert.Equal(t, false, result) assert.False(t, result)
// Case with malicious domain // Case with malicious domain
redirectURL = "https://malicious-example.com/yoyo" redirectURL = "https://malicious-example.com/yoyo"
result = utils.IsRedirectSafe(redirectURL, domain) result = utils.IsRedirectSafe(redirectURL, domain)
assert.Equal(t, false, result) assert.False(t, result)
} }
@@ -3,10 +3,9 @@ package decoders_test
import ( import (
"testing" "testing"
"github.com/stretchr/testify/assert"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/decoders" "github.com/tinyauthapp/tinyauth/internal/utils/decoders"
"gotest.tools/v3/assert"
) )
func TestDecodeLabels(t *testing.T) { func TestDecodeLabels(t *testing.T) {
@@ -64,6 +63,6 @@ func TestDecodeLabels(t *testing.T) {
// Test // Test
result, err := decoders.DecodeLabels[model.Apps](test, "apps") result, err := decoders.DecodeLabels[model.Apps](test, "apps")
assert.NilError(t, err) assert.NoError(t, err)
assert.DeepEqual(t, expected, result) assert.Equal(t, expected, result)
} }
+5 -5
View File
@@ -4,24 +4,24 @@ import (
"os" "os"
"testing" "testing"
"gotest.tools/v3/assert" "github.com/stretchr/testify/assert"
) )
func TestReadFile(t *testing.T) { func TestReadFile(t *testing.T) {
// Setup // Setup
file, err := os.Create("/tmp/tinyauth_test_file") file, err := os.Create("/tmp/tinyauth_test_file")
assert.NilError(t, err) assert.NoError(t, err)
_, err = file.WriteString("file content\n") _, err = file.WriteString("file content\n")
assert.NilError(t, err) assert.NoError(t, err)
err = file.Close() err = file.Close()
assert.NilError(t, err) assert.NoError(t, err)
defer os.Remove("/tmp/tinyauth_test_file") defer os.Remove("/tmp/tinyauth_test_file")
// Normal case // Normal case
content, err := ReadFile("/tmp/tinyauth_test_file") content, err := ReadFile("/tmp/tinyauth_test_file")
assert.NilError(t, err) assert.NoError(t, err)
assert.Equal(t, "file content\n", content) assert.Equal(t, "file content\n", content)
// Non-existing file // Non-existing file
+6 -7
View File
@@ -3,9 +3,8 @@ package utils_test
import ( import (
"testing" "testing"
"github.com/stretchr/testify/assert"
"github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils"
"gotest.tools/v3/assert"
) )
func TestParseHeaders(t *testing.T) { func TestParseHeaders(t *testing.T) {
@@ -18,7 +17,7 @@ func TestParseHeaders(t *testing.T) {
"X-Custom-Header": "Value", "X-Custom-Header": "Value",
"Another-Header": "AnotherValue", "Another-Header": "AnotherValue",
} }
assert.DeepEqual(t, expected, utils.ParseHeaders(headers)) assert.Equal(t, expected, utils.ParseHeaders(headers))
// Case insensitivity and trimming // Case insensitivity and trimming
headers = []string{ headers = []string{
@@ -29,7 +28,7 @@ func TestParseHeaders(t *testing.T) {
"X-Custom-Header": "Value", "X-Custom-Header": "Value",
"Another-Header": "AnotherValue", "Another-Header": "AnotherValue",
} }
assert.DeepEqual(t, expected, utils.ParseHeaders(headers)) assert.Equal(t, expected, utils.ParseHeaders(headers))
// Invalid headers (missing '=', empty key/value) // Invalid headers (missing '=', empty key/value)
headers = []string{ headers = []string{
@@ -39,7 +38,7 @@ func TestParseHeaders(t *testing.T) {
" = ", " = ",
} }
expected = map[string]string{} expected = map[string]string{}
assert.DeepEqual(t, expected, utils.ParseHeaders(headers)) assert.Equal(t, expected, utils.ParseHeaders(headers))
// Headers with unsafe characters // Headers with unsafe characters
headers = []string{ headers = []string{
@@ -52,7 +51,7 @@ func TestParseHeaders(t *testing.T) {
"Another-Header": "AnotherValue", "Another-Header": "AnotherValue",
"Good-Header": "GoodValue", "Good-Header": "GoodValue",
} }
assert.DeepEqual(t, expected, utils.ParseHeaders(headers)) assert.Equal(t, expected, utils.ParseHeaders(headers))
// Header with spaces in key (should be ignored) // Header with spaces in key (should be ignored)
headers = []string{ headers = []string{
@@ -62,7 +61,7 @@ func TestParseHeaders(t *testing.T) {
expected = map[string]string{ expected = map[string]string{
"Valid-Header": "ValidValue", "Valid-Header": "ValidValue",
} }
assert.DeepEqual(t, expected, utils.ParseHeaders(headers)) assert.Equal(t, expected, utils.ParseHeaders(headers))
} }
func TestSanitizeHeader(t *testing.T) { func TestSanitizeHeader(t *testing.T) {
+10 -11
View File
@@ -4,21 +4,20 @@ import (
"os" "os"
"testing" "testing"
"github.com/stretchr/testify/assert"
"github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils"
"gotest.tools/v3/assert"
) )
func TestGetSecret(t *testing.T) { func TestGetSecret(t *testing.T) {
// Setup // Setup
file, err := os.Create("/tmp/tinyauth_test_secret") file, err := os.Create("/tmp/tinyauth_test_secret")
assert.NilError(t, err) assert.NoError(t, err)
_, err = file.WriteString(" secret \n") _, err = file.WriteString(" secret \n")
assert.NilError(t, err) assert.NoError(t, err)
err = file.Close() err = file.Close()
assert.NilError(t, err) assert.NoError(t, err)
defer os.Remove("/tmp/tinyauth_test_secret") defer os.Remove("/tmp/tinyauth_test_secret")
// Get from config // Get from config
@@ -78,27 +77,27 @@ func TestEncodeBasicAuth(t *testing.T) {
func TestFilterIP(t *testing.T) { func TestFilterIP(t *testing.T) {
// Exact match IPv4 // Exact match IPv4
ok, err := utils.FilterIP("10.10.0.1", "10.10.0.1") ok, err := utils.FilterIP("10.10.0.1", "10.10.0.1")
assert.NilError(t, err) assert.NoError(t, err)
assert.Equal(t, true, ok) assert.Equal(t, true, ok)
// Non-match IPv4 // Non-match IPv4
ok, err = utils.FilterIP("10.10.0.1", "10.10.0.2") ok, err = utils.FilterIP("10.10.0.1", "10.10.0.2")
assert.NilError(t, err) assert.NoError(t, err)
assert.Equal(t, false, ok) assert.Equal(t, false, ok)
// CIDR match IPv4 // CIDR match IPv4
ok, err = utils.FilterIP("10.10.0.0/24", "10.10.0.2") ok, err = utils.FilterIP("10.10.0.0/24", "10.10.0.2")
assert.NilError(t, err) assert.NoError(t, err)
assert.Equal(t, true, ok) assert.Equal(t, true, ok)
// CIDR match IPv4 with '-' instead of '/' // CIDR match IPv4 with '-' instead of '/'
ok, err = utils.FilterIP("10.10.10.0-24", "10.10.10.5") ok, err = utils.FilterIP("10.10.10.0-24", "10.10.10.5")
assert.NilError(t, err) assert.NoError(t, err)
assert.Equal(t, true, ok) assert.Equal(t, true, ok)
// CIDR non-match IPv4 // CIDR non-match IPv4
ok, err = utils.FilterIP("10.10.0.0/24", "10.5.0.1") ok, err = utils.FilterIP("10.10.0.0/24", "10.5.0.1")
assert.NilError(t, err) assert.NoError(t, err)
assert.Equal(t, false, ok) assert.Equal(t, false, ok)
// Invalid CIDR // Invalid CIDR
@@ -145,5 +144,5 @@ func TestGenerateUUID(t *testing.T) {
// Different output for different input // Different output for different input
id3 := utils.GenerateUUID("differentstring") id3 := utils.GenerateUUID("differentstring")
assert.Assert(t, id1 != id3) assert.NotEqual(t, id2, id3)
} }
+1 -2
View File
@@ -3,9 +3,8 @@ package utils_test
import ( import (
"testing" "testing"
"github.com/stretchr/testify/assert"
"github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils"
"gotest.tools/v3/assert"
) )
func TestCapitalize(t *testing.T) { func TestCapitalize(t *testing.T) {
+14 -14
View File
@@ -5,11 +5,11 @@ import (
"encoding/json" "encoding/json"
"testing" "testing"
"github.com/stretchr/testify/assert"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog" "github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/rs/zerolog" "github.com/rs/zerolog"
"gotest.tools/v3/assert"
) )
func TestNewLogger(t *testing.T) { func TestNewLogger(t *testing.T) {
@@ -25,25 +25,25 @@ func TestNewLogger(t *testing.T) {
logger := tlog.NewLogger(cfg) logger := tlog.NewLogger(cfg)
assert.Assert(t, logger != nil) assert.NotNil(t, logger)
assert.Assert(t, logger.HTTP.GetLevel() == zerolog.InfoLevel) assert.Equal(t, zerolog.InfoLevel, logger.HTTP.GetLevel())
assert.Assert(t, logger.App.GetLevel() == zerolog.DebugLevel) assert.Equal(t, zerolog.DebugLevel, logger.App.GetLevel())
assert.Assert(t, logger.Audit.GetLevel() == zerolog.Disabled) assert.Equal(t, zerolog.Disabled, logger.Audit.GetLevel())
} }
func TestNewSimpleLogger(t *testing.T) { func TestNewSimpleLogger(t *testing.T) {
logger := tlog.NewSimpleLogger() logger := tlog.NewSimpleLogger()
assert.Assert(t, logger != nil) assert.NotNil(t, logger)
assert.Assert(t, logger.HTTP.GetLevel() == zerolog.InfoLevel) assert.Equal(t, zerolog.InfoLevel, logger.HTTP.GetLevel())
assert.Assert(t, logger.App.GetLevel() == zerolog.InfoLevel) assert.Equal(t, zerolog.DebugLevel, logger.App.GetLevel())
assert.Assert(t, logger.Audit.GetLevel() == zerolog.Disabled) assert.Equal(t, zerolog.Disabled, logger.Audit.GetLevel())
} }
func TestLoggerInit(t *testing.T) { func TestLoggerInit(t *testing.T) {
logger := tlog.NewSimpleLogger() logger := tlog.NewSimpleLogger()
logger.Init() logger.Init()
assert.Assert(t, tlog.App.GetLevel() != zerolog.Disabled) assert.Equal(t, zerolog.Disabled, tlog.App.GetLevel())
} }
func TestLoggerWithDisabledStreams(t *testing.T) { func TestLoggerWithDisabledStreams(t *testing.T) {
@@ -59,9 +59,9 @@ func TestLoggerWithDisabledStreams(t *testing.T) {
logger := tlog.NewLogger(cfg) logger := tlog.NewLogger(cfg)
assert.Assert(t, logger.HTTP.GetLevel() == zerolog.Disabled) assert.Equal(t, zerolog.Disabled, logger.HTTP.GetLevel())
assert.Assert(t, logger.App.GetLevel() == zerolog.Disabled) assert.Equal(t, zerolog.Disabled, logger.App.GetLevel())
assert.Assert(t, logger.Audit.GetLevel() == zerolog.Disabled) assert.Equal(t, zerolog.Disabled, logger.Audit.GetLevel())
} }
func TestLogStreamField(t *testing.T) { func TestLogStreamField(t *testing.T) {
@@ -86,7 +86,7 @@ func TestLogStreamField(t *testing.T) {
var logEntry map[string]interface{} var logEntry map[string]interface{}
err := json.Unmarshal(buf.Bytes(), &logEntry) err := json.Unmarshal(buf.Bytes(), &logEntry)
assert.NilError(t, err) assert.NoError(t, err)
assert.Equal(t, "http", logEntry["log_stream"]) assert.Equal(t, "http", logEntry["log_stream"])
assert.Equal(t, "test message", logEntry["message"]) assert.Equal(t, "test message", logEntry["message"])