From 04213836a1cbb85c0de088941694c553a63e4cca Mon Sep 17 00:00:00 2001 From: Stavros Date: Mon, 25 Aug 2025 22:13:53 +0300 Subject: [PATCH] refactor: split utils into smaller files --- internal/utils/fs_utils.go | 17 ++ internal/utils/header_utils.go | 29 +++ internal/utils/other_utils.go | 95 ++++++++ internal/utils/sec_utils.go | 124 +++++++++++ internal/utils/string_utils.go | 30 +++ internal/utils/user_utils.go | 82 +++++++ internal/utils/utils.go | 382 --------------------------------- 7 files changed, 377 insertions(+), 382 deletions(-) create mode 100644 internal/utils/fs_utils.go create mode 100644 internal/utils/header_utils.go create mode 100644 internal/utils/other_utils.go create mode 100644 internal/utils/sec_utils.go create mode 100644 internal/utils/string_utils.go create mode 100644 internal/utils/user_utils.go delete mode 100644 internal/utils/utils.go diff --git a/internal/utils/fs_utils.go b/internal/utils/fs_utils.go new file mode 100644 index 0000000..8b9f28b --- /dev/null +++ b/internal/utils/fs_utils.go @@ -0,0 +1,17 @@ +package utils + +import "os" + +func ReadFile(file string) (string, error) { + _, err := os.Stat(file) + if err != nil { + return "", err + } + + data, err := os.ReadFile(file) + if err != nil { + return "", err + } + + return string(data), nil +} diff --git a/internal/utils/header_utils.go b/internal/utils/header_utils.go new file mode 100644 index 0000000..1192de5 --- /dev/null +++ b/internal/utils/header_utils.go @@ -0,0 +1,29 @@ +package utils + +import ( + "strings" +) + +func ParseHeaders(headers []string) map[string]string { + headerMap := make(map[string]string) + for _, header := range headers { + split := strings.SplitN(header, "=", 2) + if len(split) != 2 || strings.TrimSpace(split[0]) == "" || strings.TrimSpace(split[1]) == "" { + continue + } + key := SanitizeHeader(strings.TrimSpace(split[0])) + value := SanitizeHeader(strings.TrimSpace(split[1])) + headerMap[key] = value + } + return headerMap +} + +func SanitizeHeader(header string) string { + return strings.Map(func(r rune) rune { + // Allow only printable ASCII characters (32-126) and safe whitespace (space, tab) + if r == ' ' || r == '\t' || (r >= 32 && r <= 126) { + return r + } + return -1 + }, header) +} diff --git a/internal/utils/other_utils.go b/internal/utils/other_utils.go new file mode 100644 index 0000000..1716725 --- /dev/null +++ b/internal/utils/other_utils.go @@ -0,0 +1,95 @@ +package utils + +import ( + "errors" + "net/url" + "strings" + "tinyauth/internal/config" + + "github.com/gin-gonic/gin" + "github.com/traefik/paerser/parser" + + "github.com/rs/zerolog" +) + +// Get upper domain parses a hostname and returns the upper domain (e.g. sub1.sub2.domain.com -> sub2.domain.com) +func GetUpperDomain(urlSrc string) (string, error) { + urlParsed, err := url.Parse(urlSrc) + if err != nil { + return "", err + } + + urlSplitted := strings.Split(urlParsed.Hostname(), ".") + urlFinal := strings.Join(urlSplitted[1:], ".") + + return urlFinal, nil +} + +func ParseFileToLine(content string) string { + lines := strings.Split(content, "\n") + users := make([]string, 0) + + for _, line := range lines { + if strings.TrimSpace(line) == "" { + continue + } + users = append(users, strings.TrimSpace(line)) + } + + return strings.Join(users, ",") +} + +func GetLabels(labels map[string]string) (config.Labels, error) { + var labelsParsed config.Labels + + err := parser.Decode(labels, &labelsParsed, "tinyauth", "tinyauth.users", "tinyauth.allowed", "tinyauth.headers", "tinyauth.domain", "tinyauth.basic", "tinyauth.oauth", "tinyauth.ip") + if err != nil { + return config.Labels{}, err + } + + return labelsParsed, nil +} + +func Filter[T any](slice []T, test func(T) bool) (res []T) { + for _, value := range slice { + if test(value) { + res = append(res, value) + } + } + return res +} + +func GetContext(c *gin.Context) (config.UserContext, error) { + userContextValue, exists := c.Get("context") + + if !exists { + return config.UserContext{}, errors.New("no user context in request") + } + + userContext, ok := userContextValue.(*config.UserContext) + + if !ok { + return config.UserContext{}, errors.New("invalid user context in request") + } + + return *userContext, nil +} + +func GetLogLevel(level string) zerolog.Level { + switch strings.ToLower(level) { + case "debug": + return zerolog.DebugLevel + case "info": + return zerolog.InfoLevel + case "warn": + return zerolog.WarnLevel + case "error": + return zerolog.ErrorLevel + case "fatal": + return zerolog.FatalLevel + case "panic": + return zerolog.PanicLevel + default: + return zerolog.InfoLevel + } +} diff --git a/internal/utils/sec_utils.go b/internal/utils/sec_utils.go new file mode 100644 index 0000000..4e9e187 --- /dev/null +++ b/internal/utils/sec_utils.go @@ -0,0 +1,124 @@ +package utils + +import ( + "bytes" + "crypto/sha256" + "encoding/base64" + "errors" + "io" + "net" + "regexp" + "strings" + + "github.com/google/uuid" + "golang.org/x/crypto/hkdf" +) + +func GetSecret(conf string, file string) string { + if conf == "" && file == "" { + return "" + } + + if conf != "" { + return conf + } + + contents, err := ReadFile(file) + if err != nil { + return "" + } + + return ParseSecretFile(contents) +} + +func ParseSecretFile(contents string) string { + lines := strings.Split(contents, "\n") + + for _, line := range lines { + if strings.TrimSpace(line) == "" { + continue + } + return strings.TrimSpace(line) + } + + return "" +} + +func GetBasicAuth(username string, password string) string { + auth := username + ":" + password + return base64.StdEncoding.EncodeToString([]byte(auth)) +} + +func DeriveKey(secret string, info string) (string, error) { + hash := sha256.New + hkdf := hkdf.New(hash, []byte(secret), nil, []byte(info)) // I am not using a salt because I just want two different keys from one secret, maybe bad practice + key := make([]byte, 24) + + _, err := io.ReadFull(hkdf, key) + if err != nil { + return "", err + } + + if bytes.Equal(key, make([]byte, 24)) { + return "", errors.New("derived key is empty") + } + + encodedKey := base64.StdEncoding.EncodeToString(key) + return encodedKey, nil +} + +func FilterIP(filter string, ip string) (bool, error) { + ipAddr := net.ParseIP(ip) + + if strings.Contains(filter, "/") { + _, cidr, err := net.ParseCIDR(filter) + if err != nil { + return false, err + } + return cidr.Contains(ipAddr), nil + } + + ipFilter := net.ParseIP(filter) + if ipFilter == nil { + return false, errors.New("invalid IP address in filter") + } + + if ipFilter.Equal(ipAddr) { + return true, nil + } + + return false, nil +} + +func CheckFilter(filter string, str string) bool { + if len(strings.TrimSpace(filter)) == 0 { + return true + } + + if strings.HasPrefix(filter, "/") && strings.HasSuffix(filter, "/") { + re, err := regexp.Compile(filter[1 : len(filter)-1]) + if err != nil { + return false + } + + if re.MatchString(str) { + return true + } + } + + filterSplit := strings.Split(filter, ",") + + for _, item := range filterSplit { + if strings.TrimSpace(item) == str { + return true + } + } + + return false +} + +func GenerateIdentifier(str string) string { + uuid := uuid.NewSHA1(uuid.NameSpaceURL, []byte(str)) + uuidString := uuid.String() + return strings.Split(uuidString, "-")[0] +} diff --git a/internal/utils/string_utils.go b/internal/utils/string_utils.go new file mode 100644 index 0000000..8a629ad --- /dev/null +++ b/internal/utils/string_utils.go @@ -0,0 +1,30 @@ +package utils + +import ( + "strings" +) + +func Capitalize(str string) string { + if len(str) == 0 { + return "" + } + return strings.ToUpper(string([]rune(str)[0])) + string([]rune(str)[1:]) +} + +func CoalesceToString(value any) string { + switch v := value.(type) { + case []any: + strs := make([]string, 0, len(v)) + for _, item := range v { + if str, ok := item.(string); ok { + strs = append(strs, str) + continue + } + } + return strings.Join(strs, ",") + case string: + return v + default: + return "" + } +} diff --git a/internal/utils/user_utils.go b/internal/utils/user_utils.go new file mode 100644 index 0000000..bfcec49 --- /dev/null +++ b/internal/utils/user_utils.go @@ -0,0 +1,82 @@ +package utils + +import ( + "errors" + "strings" + "tinyauth/internal/config" +) + +func ParseUsers(users string) ([]config.User, error) { + var usersParsed []config.User + + userList := strings.Split(users, ",") + + if len(userList) == 0 { + return []config.User{}, errors.New("invalid user format") + } + + for _, user := range userList { + parsed, err := ParseUser(user) + if err != nil { + return []config.User{}, err + } + usersParsed = append(usersParsed, parsed) + } + + return usersParsed, nil +} + +func GetUsers(conf string, file string) ([]config.User, error) { + var users string + + if conf == "" && file == "" { + return []config.User{}, nil + } + + if conf != "" { + users += conf + } + + if file != "" { + contents, err := ReadFile(file) + if err == nil { + if users != "" { + users += "," + } + users += ParseFileToLine(contents) + } + } + + return ParseUsers(users) +} + +func ParseUser(user string) (config.User, error) { + if strings.Contains(user, "$$") { + user = strings.ReplaceAll(user, "$$", "$") + } + + userSplit := strings.Split(user, ":") + + if len(userSplit) < 2 || len(userSplit) > 3 { + return config.User{}, errors.New("invalid user format") + } + + for _, userPart := range userSplit { + if strings.TrimSpace(userPart) == "" { + return config.User{}, errors.New("invalid user format") + } + } + + if len(userSplit) == 2 { + return config.User{ + Username: strings.TrimSpace(userSplit[0]), + Password: strings.TrimSpace(userSplit[1]), + }, nil + } + + return config.User{ + Username: strings.TrimSpace(userSplit[0]), + Password: strings.TrimSpace(userSplit[1]), + TotpSecret: strings.TrimSpace(userSplit[2]), + }, nil +} diff --git a/internal/utils/utils.go b/internal/utils/utils.go deleted file mode 100644 index 7181a26..0000000 --- a/internal/utils/utils.go +++ /dev/null @@ -1,382 +0,0 @@ -package utils - -import ( - "bytes" - "crypto/sha256" - "encoding/base64" - "errors" - "io" - "net" - "net/url" - "os" - "regexp" - "strings" - "tinyauth/internal/config" - - "github.com/gin-gonic/gin" - "github.com/traefik/paerser/parser" - "golang.org/x/crypto/hkdf" - - "github.com/google/uuid" - "github.com/rs/zerolog" - "github.com/rs/zerolog/log" -) - -// Parses a list of comma separated users in a struct -func ParseUsers(users string) ([]config.User, error) { - log.Debug().Msg("Parsing users") - - var usersParsed []config.User - - userList := strings.Split(users, ",") - - if len(userList) == 0 { - return []config.User{}, errors.New("invalid user format") - } - - for _, user := range userList { - parsed, err := ParseUser(user) - if err != nil { - return []config.User{}, err - } - usersParsed = append(usersParsed, parsed) - } - - log.Debug().Msg("Parsed users") - return usersParsed, nil -} - -// Get upper domain parses a hostname and returns the upper domain (e.g. sub1.sub2.domain.com -> sub2.domain.com) -func GetUpperDomain(urlSrc string) (string, error) { - urlParsed, err := url.Parse(urlSrc) - if err != nil { - return "", err - } - - urlSplitted := strings.Split(urlParsed.Hostname(), ".") - urlFinal := strings.Join(urlSplitted[1:], ".") - - return urlFinal, nil -} - -// Reads a file and returns the contents -func ReadFile(file string) (string, error) { - _, err := os.Stat(file) - if err != nil { - return "", err - } - - data, err := os.ReadFile(file) - if err != nil { - return "", err - } - - return string(data), nil -} - -// Parses a file into a comma separated list of users -func ParseFileToLine(content string) string { - lines := strings.Split(content, "\n") - users := make([]string, 0) - - for _, line := range lines { - if strings.TrimSpace(line) == "" { - continue - } - users = append(users, strings.TrimSpace(line)) - } - - return strings.Join(users, ",") -} - -// Get the secret from the config or file -func GetSecret(conf string, file string) string { - if conf == "" && file == "" { - return "" - } - - if conf != "" { - return conf - } - - contents, err := ReadFile(file) - if err != nil { - return "" - } - - return ParseSecretFile(contents) -} - -// Get the users from the config or file -func GetUsers(conf string, file string) ([]config.User, error) { - var users string - - if conf == "" && file == "" { - return []config.User{}, nil - } - - if conf != "" { - log.Debug().Msg("Using users from config") - users += conf - } - - if file != "" { - contents, err := ReadFile(file) - if err == nil { - log.Debug().Msg("Using users from file") - if users != "" { - users += "," - } - users += ParseFileToLine(contents) - } - } - - return ParseUsers(users) -} - -// Parse the headers in a map[string]string format -func ParseHeaders(headers []string) map[string]string { - headerMap := make(map[string]string) - - for _, header := range headers { - split := strings.SplitN(header, "=", 2) - if len(split) != 2 || strings.TrimSpace(split[0]) == "" || strings.TrimSpace(split[1]) == "" { - log.Warn().Str("header", header).Msg("Invalid header format, skipping") - continue - } - key := SanitizeHeader(strings.TrimSpace(split[0])) - value := SanitizeHeader(strings.TrimSpace(split[1])) - headerMap[key] = value - } - - return headerMap -} - -// Get labels parses a map of labels into a struct with only the needed labels -func GetLabels(labels map[string]string) (config.Labels, error) { - var labelsParsed config.Labels - - err := parser.Decode(labels, &labelsParsed, "tinyauth", "tinyauth.users", "tinyauth.allowed", "tinyauth.headers", "tinyauth.domain", "tinyauth.basic", "tinyauth.oauth", "tinyauth.ip") - if err != nil { - log.Error().Err(err).Msg("Error parsing labels") - return config.Labels{}, err - } - - return labelsParsed, nil -} - -// Filter helper function -func Filter[T any](slice []T, test func(T) bool) (res []T) { - for _, value := range slice { - if test(value) { - res = append(res, value) - } - } - return res -} - -// Parse user -func ParseUser(user string) (config.User, error) { - if strings.Contains(user, "$$") { - user = strings.ReplaceAll(user, "$$", "$") - } - - userSplit := strings.Split(user, ":") - - if len(userSplit) < 2 || len(userSplit) > 3 { - return config.User{}, errors.New("invalid user format") - } - - for _, userPart := range userSplit { - if strings.TrimSpace(userPart) == "" { - return config.User{}, errors.New("invalid user format") - } - } - - if len(userSplit) == 2 { - return config.User{ - Username: strings.TrimSpace(userSplit[0]), - Password: strings.TrimSpace(userSplit[1]), - }, nil - } - - return config.User{ - Username: strings.TrimSpace(userSplit[0]), - Password: strings.TrimSpace(userSplit[1]), - TotpSecret: strings.TrimSpace(userSplit[2]), - }, nil -} - -// Parse secret file -func ParseSecretFile(contents string) string { - lines := strings.Split(contents, "\n") - - for _, line := range lines { - if strings.TrimSpace(line) == "" { - continue - } - return strings.TrimSpace(line) - } - - return "" -} - -// Check if a string matches a regex or if it is included in a comma separated list -func CheckFilter(filter string, str string) bool { - if len(strings.TrimSpace(filter)) == 0 { - return true - } - - if strings.HasPrefix(filter, "/") && strings.HasSuffix(filter, "/") { - re, err := regexp.Compile(filter[1 : len(filter)-1]) - if err != nil { - log.Error().Err(err).Msg("Error compiling regex") - return false - } - - if re.MatchString(str) { - return true - } - } - - filterSplit := strings.Split(filter, ",") - - for _, item := range filterSplit { - if strings.TrimSpace(item) == str { - return true - } - } - - return false -} - -// Capitalize just the first letter of a string -func Capitalize(str string) string { - if len(str) == 0 { - return "" - } - return strings.ToUpper(string([]rune(str)[0])) + string([]rune(str)[1:]) -} - -// Sanitize header removes all control characters from a string -func SanitizeHeader(header string) string { - return strings.Map(func(r rune) rune { - // Allow only printable ASCII characters (32-126) and safe whitespace (space, tab) - if r == ' ' || r == '\t' || (r >= 32 && r <= 126) { - return r - } - return -1 - }, header) -} - -// Generate a static identifier from a string -func GenerateIdentifier(str string) string { - uuid := uuid.NewSHA1(uuid.NameSpaceURL, []byte(str)) - uuidString := uuid.String() - log.Debug().Str("uuid", uuidString).Msg("Generated UUID") - return strings.Split(uuidString, "-")[0] -} - -// Get a basic auth header from a username and password -func GetBasicAuth(username string, password string) string { - auth := username + ":" + password - return base64.StdEncoding.EncodeToString([]byte(auth)) -} - -// Check if an IP is contained in a CIDR range/matches a single IP -func FilterIP(filter string, ip string) (bool, error) { - ipAddr := net.ParseIP(ip) - - if strings.Contains(filter, "/") { - _, cidr, err := net.ParseCIDR(filter) - if err != nil { - return false, err - } - return cidr.Contains(ipAddr), nil - } - - ipFilter := net.ParseIP(filter) - if ipFilter == nil { - return false, errors.New("invalid IP address in filter") - } - - if ipFilter.Equal(ipAddr) { - return true, nil - } - - return false, nil -} - -func DeriveKey(secret string, info string) (string, error) { - hash := sha256.New - hkdf := hkdf.New(hash, []byte(secret), nil, []byte(info)) // I am not using a salt because I just want two different keys from one secret, maybe bad practice - key := make([]byte, 24) - - _, err := io.ReadFull(hkdf, key) - if err != nil { - return "", err - } - - if bytes.Equal(key, make([]byte, 24)) { - return "", errors.New("derived key is empty") - } - - encodedKey := base64.StdEncoding.EncodeToString(key) - return encodedKey, nil -} - -func CoalesceToString(value any) string { - switch v := value.(type) { - case []any: - log.Debug().Msg("Coalescing []any to string") - strs := make([]string, 0, len(v)) - for _, item := range v { - if str, ok := item.(string); ok { - strs = append(strs, str) - continue - } - log.Warn().Interface("item", item).Msg("Item in []any is not a string, skipping") - } - return strings.Join(strs, ",") - case string: - return v - default: - log.Warn().Interface("value", value).Interface("type", v).Msg("Unsupported type, returning empty string") - return "" - } -} - -func GetContext(c *gin.Context) (config.UserContext, error) { - userContextValue, exists := c.Get("context") - - if !exists { - return config.UserContext{}, errors.New("no user context in request") - } - - userContext, ok := userContextValue.(*config.UserContext) - - if !ok { - return config.UserContext{}, errors.New("invalid user context in request") - } - - return *userContext, nil -} - -func GetLogLevel(level string) zerolog.Level { - switch strings.ToLower(level) { - case "debug": - return zerolog.DebugLevel - case "info": - return zerolog.InfoLevel - case "warn": - return zerolog.WarnLevel - case "error": - return zerolog.ErrorLevel - case "fatal": - return zerolog.FatalLevel - case "panic": - return zerolog.PanicLevel - default: - return zerolog.InfoLevel - } -}