mirror of
https://github.com/steveiliop56/tinyauth.git
synced 2025-10-28 04:35:40 +00:00
refactor: rework file structure (#325)
* wip: add middlewares * refactor: use context fom middleware in handlers * refactor: use controller approach in handlers * refactor: move oauth providers into services (non-working) * feat: create oauth broker service * refactor: use a boostrap service to bootstrap the app * refactor: split utils into smaller files * refactor: use more clear name for frontend assets * feat: allow customizability of resources dir * fix: fix typo in ui middleware * fix: validate resource file paths in ui middleware * refactor: move resource handling to a controller * feat: add some logging * fix: configure middlewares before groups * fix: use correct api path in login mutation * fix: coderabbit suggestions * fix: further coderabbit suggestions
This commit is contained in:
123
internal/utils/app_utils.go
Normal file
123
internal/utils/app_utils.go
Normal file
@@ -0,0 +1,123 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"net/url"
|
||||
"strings"
|
||||
"tinyauth/internal/config"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
// Get upper domain parses a hostname and returns the upper domain (e.g. sub1.sub2.domain.com -> sub2.domain.com)
|
||||
func GetUpperDomain(appUrl string) (string, error) {
|
||||
appUrlParsed, err := url.Parse(appUrl)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
host := appUrlParsed.Hostname()
|
||||
|
||||
if netIP := net.ParseIP(host); netIP != nil {
|
||||
return "", errors.New("IP addresses are not allowed")
|
||||
}
|
||||
|
||||
urlParts := strings.Split(host, ".")
|
||||
|
||||
if len(urlParts) < 2 {
|
||||
return "", errors.New("invalid domain, must be at least second level domain")
|
||||
}
|
||||
|
||||
return strings.Join(urlParts[1:], "."), nil
|
||||
}
|
||||
|
||||
func ParseFileToLine(content string) string {
|
||||
lines := strings.Split(content, "\n")
|
||||
users := make([]string, 0)
|
||||
|
||||
for _, line := range lines {
|
||||
if strings.TrimSpace(line) == "" {
|
||||
continue
|
||||
}
|
||||
users = append(users, strings.TrimSpace(line))
|
||||
}
|
||||
|
||||
return strings.Join(users, ",")
|
||||
}
|
||||
|
||||
func Filter[T any](slice []T, test func(T) bool) (res []T) {
|
||||
for _, value := range slice {
|
||||
if test(value) {
|
||||
res = append(res, value)
|
||||
}
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
func GetContext(c *gin.Context) (config.UserContext, error) {
|
||||
userContextValue, exists := c.Get("context")
|
||||
|
||||
if !exists {
|
||||
return config.UserContext{}, errors.New("no user context in request")
|
||||
}
|
||||
|
||||
userContext, ok := userContextValue.(*config.UserContext)
|
||||
|
||||
if !ok {
|
||||
return config.UserContext{}, errors.New("invalid user context in request")
|
||||
}
|
||||
|
||||
return *userContext, nil
|
||||
}
|
||||
|
||||
func IsRedirectSafe(redirectURL string, domain string) bool {
|
||||
if redirectURL == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
parsedURL, err := url.Parse(redirectURL)
|
||||
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if !parsedURL.IsAbs() {
|
||||
return false
|
||||
}
|
||||
|
||||
upper, err := GetUpperDomain(redirectURL)
|
||||
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if upper != domain {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func GetLogLevel(level string) zerolog.Level {
|
||||
switch strings.ToLower(level) {
|
||||
case "trace":
|
||||
return zerolog.TraceLevel
|
||||
case "debug":
|
||||
return zerolog.DebugLevel
|
||||
case "info":
|
||||
return zerolog.InfoLevel
|
||||
case "warn":
|
||||
return zerolog.WarnLevel
|
||||
case "error":
|
||||
return zerolog.ErrorLevel
|
||||
case "fatal":
|
||||
return zerolog.FatalLevel
|
||||
case "panic":
|
||||
return zerolog.PanicLevel
|
||||
default:
|
||||
return zerolog.InfoLevel
|
||||
}
|
||||
}
|
||||
17
internal/utils/fs_utils.go
Normal file
17
internal/utils/fs_utils.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package utils
|
||||
|
||||
import "os"
|
||||
|
||||
func ReadFile(file string) (string, error) {
|
||||
_, err := os.Stat(file)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(file)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return string(data), nil
|
||||
}
|
||||
48
internal/utils/label_utils.go
Normal file
48
internal/utils/label_utils.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"tinyauth/internal/config"
|
||||
|
||||
"github.com/traefik/paerser/parser"
|
||||
)
|
||||
|
||||
func GetLabels(labels map[string]string) (config.Labels, error) {
|
||||
var labelsParsed config.Labels
|
||||
|
||||
err := parser.Decode(labels, &labelsParsed, "tinyauth", "tinyauth.users", "tinyauth.allowed", "tinyauth.headers", "tinyauth.domain", "tinyauth.basic", "tinyauth.oauth", "tinyauth.ip")
|
||||
if err != nil {
|
||||
return config.Labels{}, err
|
||||
}
|
||||
|
||||
return labelsParsed, nil
|
||||
}
|
||||
|
||||
func ParseHeaders(headers []string) map[string]string {
|
||||
headerMap := make(map[string]string)
|
||||
for _, header := range headers {
|
||||
split := strings.SplitN(header, "=", 2)
|
||||
if len(split) != 2 || strings.TrimSpace(split[0]) == "" || strings.TrimSpace(split[1]) == "" {
|
||||
continue
|
||||
}
|
||||
key := SanitizeHeader(strings.TrimSpace(split[0]))
|
||||
if strings.ContainsAny(key, " \t") {
|
||||
continue
|
||||
}
|
||||
key = http.CanonicalHeaderKey(key)
|
||||
value := SanitizeHeader(strings.TrimSpace(split[1]))
|
||||
headerMap[key] = value
|
||||
}
|
||||
return headerMap
|
||||
}
|
||||
|
||||
func SanitizeHeader(header string) string {
|
||||
return strings.Map(func(r rune) rune {
|
||||
// Allow only printable ASCII characters (32-126) and safe whitespace (space, tab)
|
||||
if r == ' ' || r == '\t' || (r >= 32 && r <= 126) {
|
||||
return r
|
||||
}
|
||||
return -1
|
||||
}, header)
|
||||
}
|
||||
124
internal/utils/security_utils.go
Normal file
124
internal/utils/security_utils.go
Normal file
@@ -0,0 +1,124 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/crypto/hkdf"
|
||||
)
|
||||
|
||||
func GetSecret(conf string, file string) string {
|
||||
if conf == "" && file == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
if conf != "" {
|
||||
return conf
|
||||
}
|
||||
|
||||
contents, err := ReadFile(file)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
return ParseSecretFile(contents)
|
||||
}
|
||||
|
||||
func ParseSecretFile(contents string) string {
|
||||
lines := strings.Split(contents, "\n")
|
||||
|
||||
for _, line := range lines {
|
||||
if strings.TrimSpace(line) == "" {
|
||||
continue
|
||||
}
|
||||
return strings.TrimSpace(line)
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func GetBasicAuth(username string, password string) string {
|
||||
auth := username + ":" + password
|
||||
return base64.StdEncoding.EncodeToString([]byte(auth))
|
||||
}
|
||||
|
||||
func DeriveKey(secret string, info string) (string, error) {
|
||||
hash := sha256.New
|
||||
hkdf := hkdf.New(hash, []byte(secret), nil, []byte(info)) // I am not using a salt because I just want two different keys from one secret, maybe bad practice
|
||||
key := make([]byte, 24)
|
||||
|
||||
_, err := io.ReadFull(hkdf, key)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if bytes.Equal(key, make([]byte, 24)) {
|
||||
return "", errors.New("derived key is empty")
|
||||
}
|
||||
|
||||
encodedKey := base64.StdEncoding.EncodeToString(key)
|
||||
return encodedKey, nil
|
||||
}
|
||||
|
||||
func FilterIP(filter string, ip string) (bool, error) {
|
||||
ipAddr := net.ParseIP(ip)
|
||||
|
||||
if strings.Contains(filter, "/") {
|
||||
_, cidr, err := net.ParseCIDR(filter)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return cidr.Contains(ipAddr), nil
|
||||
}
|
||||
|
||||
ipFilter := net.ParseIP(filter)
|
||||
if ipFilter == nil {
|
||||
return false, errors.New("invalid IP address in filter")
|
||||
}
|
||||
|
||||
if ipFilter.Equal(ipAddr) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func CheckFilter(filter string, str string) bool {
|
||||
if len(strings.TrimSpace(filter)) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
if strings.HasPrefix(filter, "/") && strings.HasSuffix(filter, "/") {
|
||||
re, err := regexp.Compile(filter[1 : len(filter)-1])
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if re.MatchString(strings.TrimSpace(str)) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
filterSplit := strings.Split(filter, ",")
|
||||
|
||||
for _, item := range filterSplit {
|
||||
if strings.TrimSpace(item) == strings.TrimSpace(str) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func GenerateIdentifier(str string) string {
|
||||
uuid := uuid.NewSHA1(uuid.NameSpaceURL, []byte(str))
|
||||
uuidString := uuid.String()
|
||||
return strings.Split(uuidString, "-")[0]
|
||||
}
|
||||
30
internal/utils/string_utils.go
Normal file
30
internal/utils/string_utils.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
func Capitalize(str string) string {
|
||||
if len(str) == 0 {
|
||||
return ""
|
||||
}
|
||||
return strings.ToUpper(string([]rune(str)[0])) + string([]rune(str)[1:])
|
||||
}
|
||||
|
||||
func CoalesceToString(value any) string {
|
||||
switch v := value.(type) {
|
||||
case []any:
|
||||
strs := make([]string, 0, len(v))
|
||||
for _, item := range v {
|
||||
if str, ok := item.(string); ok {
|
||||
strs = append(strs, str)
|
||||
continue
|
||||
}
|
||||
}
|
||||
return strings.Join(strs, ",")
|
||||
case string:
|
||||
return v
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
92
internal/utils/user_utils.go
Normal file
92
internal/utils/user_utils.go
Normal file
@@ -0,0 +1,92 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
"tinyauth/internal/config"
|
||||
)
|
||||
|
||||
func ParseUsers(users string) ([]config.User, error) {
|
||||
var usersParsed []config.User
|
||||
|
||||
users = strings.TrimSpace(users)
|
||||
|
||||
if users == "" {
|
||||
return []config.User{}, nil
|
||||
}
|
||||
|
||||
userList := strings.Split(users, ",")
|
||||
|
||||
if len(userList) == 0 {
|
||||
return []config.User{}, errors.New("invalid user format")
|
||||
}
|
||||
|
||||
for _, user := range userList {
|
||||
if strings.TrimSpace(user) == "" {
|
||||
continue
|
||||
}
|
||||
parsed, err := ParseUser(strings.TrimSpace(user))
|
||||
if err != nil {
|
||||
return []config.User{}, err
|
||||
}
|
||||
usersParsed = append(usersParsed, parsed)
|
||||
}
|
||||
|
||||
return usersParsed, nil
|
||||
}
|
||||
|
||||
func GetUsers(conf string, file string) ([]config.User, error) {
|
||||
var users string
|
||||
|
||||
if conf == "" && file == "" {
|
||||
return []config.User{}, nil
|
||||
}
|
||||
|
||||
if conf != "" {
|
||||
users += conf
|
||||
}
|
||||
|
||||
if file != "" {
|
||||
contents, err := ReadFile(file)
|
||||
if err != nil {
|
||||
return []config.User{}, err
|
||||
}
|
||||
if users != "" {
|
||||
users += ","
|
||||
}
|
||||
users += ParseFileToLine(contents)
|
||||
}
|
||||
|
||||
return ParseUsers(users)
|
||||
}
|
||||
|
||||
func ParseUser(user string) (config.User, error) {
|
||||
if strings.Contains(user, "$$") {
|
||||
user = strings.ReplaceAll(user, "$$", "$")
|
||||
}
|
||||
|
||||
userSplit := strings.Split(user, ":")
|
||||
|
||||
if len(userSplit) < 2 || len(userSplit) > 3 {
|
||||
return config.User{}, errors.New("invalid user format")
|
||||
}
|
||||
|
||||
for _, userPart := range userSplit {
|
||||
if strings.TrimSpace(userPart) == "" {
|
||||
return config.User{}, errors.New("invalid user format")
|
||||
}
|
||||
}
|
||||
|
||||
if len(userSplit) == 2 {
|
||||
return config.User{
|
||||
Username: strings.TrimSpace(userSplit[0]),
|
||||
Password: strings.TrimSpace(userSplit[1]),
|
||||
}, nil
|
||||
}
|
||||
|
||||
return config.User{
|
||||
Username: strings.TrimSpace(userSplit[0]),
|
||||
Password: strings.TrimSpace(userSplit[1]),
|
||||
TotpSecret: strings.TrimSpace(userSplit[2]),
|
||||
}, nil
|
||||
}
|
||||
@@ -1,350 +0,0 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"net/url"
|
||||
"os"
|
||||
"regexp"
|
||||
"strings"
|
||||
"tinyauth/internal/types"
|
||||
|
||||
"github.com/traefik/paerser/parser"
|
||||
"golang.org/x/crypto/hkdf"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// Parses a list of comma separated users in a struct
|
||||
func ParseUsers(users string) (types.Users, error) {
|
||||
log.Debug().Msg("Parsing users")
|
||||
|
||||
var usersParsed types.Users
|
||||
|
||||
userList := strings.Split(users, ",")
|
||||
|
||||
if len(userList) == 0 {
|
||||
return types.Users{}, errors.New("invalid user format")
|
||||
}
|
||||
|
||||
for _, user := range userList {
|
||||
parsed, err := ParseUser(user)
|
||||
if err != nil {
|
||||
return types.Users{}, err
|
||||
}
|
||||
usersParsed = append(usersParsed, parsed)
|
||||
}
|
||||
|
||||
log.Debug().Msg("Parsed users")
|
||||
return usersParsed, nil
|
||||
}
|
||||
|
||||
// Get upper domain parses a hostname and returns the upper domain (e.g. sub1.sub2.domain.com -> sub2.domain.com)
|
||||
func GetUpperDomain(urlSrc string) (string, error) {
|
||||
urlParsed, err := url.Parse(urlSrc)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
urlSplitted := strings.Split(urlParsed.Hostname(), ".")
|
||||
urlFinal := strings.Join(urlSplitted[1:], ".")
|
||||
|
||||
return urlFinal, nil
|
||||
}
|
||||
|
||||
// Reads a file and returns the contents
|
||||
func ReadFile(file string) (string, error) {
|
||||
_, err := os.Stat(file)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(file)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return string(data), nil
|
||||
}
|
||||
|
||||
// Parses a file into a comma separated list of users
|
||||
func ParseFileToLine(content string) string {
|
||||
lines := strings.Split(content, "\n")
|
||||
users := make([]string, 0)
|
||||
|
||||
for _, line := range lines {
|
||||
if strings.TrimSpace(line) == "" {
|
||||
continue
|
||||
}
|
||||
users = append(users, strings.TrimSpace(line))
|
||||
}
|
||||
|
||||
return strings.Join(users, ",")
|
||||
}
|
||||
|
||||
// Get the secret from the config or file
|
||||
func GetSecret(conf string, file string) string {
|
||||
if conf == "" && file == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
if conf != "" {
|
||||
return conf
|
||||
}
|
||||
|
||||
contents, err := ReadFile(file)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
return ParseSecretFile(contents)
|
||||
}
|
||||
|
||||
// Get the users from the config or file
|
||||
func GetUsers(conf string, file string) (types.Users, error) {
|
||||
var users string
|
||||
|
||||
if conf == "" && file == "" {
|
||||
return types.Users{}, nil
|
||||
}
|
||||
|
||||
if conf != "" {
|
||||
log.Debug().Msg("Using users from config")
|
||||
users += conf
|
||||
}
|
||||
|
||||
if file != "" {
|
||||
contents, err := ReadFile(file)
|
||||
if err == nil {
|
||||
log.Debug().Msg("Using users from file")
|
||||
if users != "" {
|
||||
users += ","
|
||||
}
|
||||
users += ParseFileToLine(contents)
|
||||
}
|
||||
}
|
||||
|
||||
return ParseUsers(users)
|
||||
}
|
||||
|
||||
// Parse the headers in a map[string]string format
|
||||
func ParseHeaders(headers []string) map[string]string {
|
||||
headerMap := make(map[string]string)
|
||||
|
||||
for _, header := range headers {
|
||||
split := strings.SplitN(header, "=", 2)
|
||||
if len(split) != 2 || strings.TrimSpace(split[0]) == "" || strings.TrimSpace(split[1]) == "" {
|
||||
log.Warn().Str("header", header).Msg("Invalid header format, skipping")
|
||||
continue
|
||||
}
|
||||
key := SanitizeHeader(strings.TrimSpace(split[0]))
|
||||
value := SanitizeHeader(strings.TrimSpace(split[1]))
|
||||
headerMap[key] = value
|
||||
}
|
||||
|
||||
return headerMap
|
||||
}
|
||||
|
||||
// Get labels parses a map of labels into a struct with only the needed labels
|
||||
func GetLabels(labels map[string]string) (types.Labels, error) {
|
||||
var labelsParsed types.Labels
|
||||
|
||||
err := parser.Decode(labels, &labelsParsed, "tinyauth", "tinyauth.users", "tinyauth.allowed", "tinyauth.headers", "tinyauth.domain", "tinyauth.basic", "tinyauth.oauth", "tinyauth.ip")
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Error parsing labels")
|
||||
return types.Labels{}, err
|
||||
}
|
||||
|
||||
return labelsParsed, nil
|
||||
}
|
||||
|
||||
// Check if any of the OAuth providers are configured based on the client id and secret
|
||||
func OAuthConfigured(config types.Config) bool {
|
||||
return (config.GithubClientId != "" && config.GithubClientSecret != "") || (config.GoogleClientId != "" && config.GoogleClientSecret != "") || (config.GenericClientId != "" && config.GenericClientSecret != "")
|
||||
}
|
||||
|
||||
// Filter helper function
|
||||
func Filter[T any](slice []T, test func(T) bool) (res []T) {
|
||||
for _, value := range slice {
|
||||
if test(value) {
|
||||
res = append(res, value)
|
||||
}
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
// Parse user
|
||||
func ParseUser(user string) (types.User, error) {
|
||||
if strings.Contains(user, "$$") {
|
||||
user = strings.ReplaceAll(user, "$$", "$")
|
||||
}
|
||||
|
||||
userSplit := strings.Split(user, ":")
|
||||
|
||||
if len(userSplit) < 2 || len(userSplit) > 3 {
|
||||
return types.User{}, errors.New("invalid user format")
|
||||
}
|
||||
|
||||
for _, userPart := range userSplit {
|
||||
if strings.TrimSpace(userPart) == "" {
|
||||
return types.User{}, errors.New("invalid user format")
|
||||
}
|
||||
}
|
||||
|
||||
if len(userSplit) == 2 {
|
||||
return types.User{
|
||||
Username: strings.TrimSpace(userSplit[0]),
|
||||
Password: strings.TrimSpace(userSplit[1]),
|
||||
}, nil
|
||||
}
|
||||
|
||||
return types.User{
|
||||
Username: strings.TrimSpace(userSplit[0]),
|
||||
Password: strings.TrimSpace(userSplit[1]),
|
||||
TotpSecret: strings.TrimSpace(userSplit[2]),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Parse secret file
|
||||
func ParseSecretFile(contents string) string {
|
||||
lines := strings.Split(contents, "\n")
|
||||
|
||||
for _, line := range lines {
|
||||
if strings.TrimSpace(line) == "" {
|
||||
continue
|
||||
}
|
||||
return strings.TrimSpace(line)
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// Check if a string matches a regex or if it is included in a comma separated list
|
||||
func CheckFilter(filter string, str string) bool {
|
||||
if len(strings.TrimSpace(filter)) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
if strings.HasPrefix(filter, "/") && strings.HasSuffix(filter, "/") {
|
||||
re, err := regexp.Compile(filter[1 : len(filter)-1])
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Error compiling regex")
|
||||
return false
|
||||
}
|
||||
|
||||
if re.MatchString(str) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
filterSplit := strings.Split(filter, ",")
|
||||
|
||||
for _, item := range filterSplit {
|
||||
if strings.TrimSpace(item) == str {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// Capitalize just the first letter of a string
|
||||
func Capitalize(str string) string {
|
||||
if len(str) == 0 {
|
||||
return ""
|
||||
}
|
||||
return strings.ToUpper(string([]rune(str)[0])) + string([]rune(str)[1:])
|
||||
}
|
||||
|
||||
// Sanitize header removes all control characters from a string
|
||||
func SanitizeHeader(header string) string {
|
||||
return strings.Map(func(r rune) rune {
|
||||
// Allow only printable ASCII characters (32-126) and safe whitespace (space, tab)
|
||||
if r == ' ' || r == '\t' || (r >= 32 && r <= 126) {
|
||||
return r
|
||||
}
|
||||
return -1
|
||||
}, header)
|
||||
}
|
||||
|
||||
// Generate a static identifier from a string
|
||||
func GenerateIdentifier(str string) string {
|
||||
uuid := uuid.NewSHA1(uuid.NameSpaceURL, []byte(str))
|
||||
uuidString := uuid.String()
|
||||
log.Debug().Str("uuid", uuidString).Msg("Generated UUID")
|
||||
return strings.Split(uuidString, "-")[0]
|
||||
}
|
||||
|
||||
// Get a basic auth header from a username and password
|
||||
func GetBasicAuth(username string, password string) string {
|
||||
auth := username + ":" + password
|
||||
return base64.StdEncoding.EncodeToString([]byte(auth))
|
||||
}
|
||||
|
||||
// Check if an IP is contained in a CIDR range/matches a single IP
|
||||
func FilterIP(filter string, ip string) (bool, error) {
|
||||
ipAddr := net.ParseIP(ip)
|
||||
|
||||
if strings.Contains(filter, "/") {
|
||||
_, cidr, err := net.ParseCIDR(filter)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return cidr.Contains(ipAddr), nil
|
||||
}
|
||||
|
||||
ipFilter := net.ParseIP(filter)
|
||||
if ipFilter == nil {
|
||||
return false, errors.New("invalid IP address in filter")
|
||||
}
|
||||
|
||||
if ipFilter.Equal(ipAddr) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func DeriveKey(secret string, info string) (string, error) {
|
||||
hash := sha256.New
|
||||
hkdf := hkdf.New(hash, []byte(secret), nil, []byte(info)) // I am not using a salt because I just want two different keys from one secret, maybe bad practice
|
||||
key := make([]byte, 24)
|
||||
|
||||
_, err := io.ReadFull(hkdf, key)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if bytes.Equal(key, make([]byte, 24)) {
|
||||
return "", errors.New("derived key is empty")
|
||||
}
|
||||
|
||||
encodedKey := base64.StdEncoding.EncodeToString(key)
|
||||
return encodedKey, nil
|
||||
}
|
||||
|
||||
func CoalesceToString(value any) string {
|
||||
switch v := value.(type) {
|
||||
case []any:
|
||||
log.Debug().Msg("Coalescing []any to string")
|
||||
strs := make([]string, 0, len(v))
|
||||
for _, item := range v {
|
||||
if str, ok := item.(string); ok {
|
||||
strs = append(strs, str)
|
||||
continue
|
||||
}
|
||||
log.Warn().Interface("item", item).Msg("Item in []any is not a string, skipping")
|
||||
}
|
||||
return strings.Join(strs, ",")
|
||||
case string:
|
||||
return v
|
||||
default:
|
||||
log.Warn().Interface("value", value).Interface("type", v).Msg("Unsupported type, returning empty string")
|
||||
return ""
|
||||
}
|
||||
}
|
||||
@@ -1,548 +0,0 @@
|
||||
package utils_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"reflect"
|
||||
"testing"
|
||||
"tinyauth/internal/types"
|
||||
"tinyauth/internal/utils"
|
||||
)
|
||||
|
||||
func TestParseUsers(t *testing.T) {
|
||||
t.Log("Testing parse users with a valid string")
|
||||
|
||||
users := "user1:pass1,user2:pass2"
|
||||
expected := types.Users{
|
||||
{
|
||||
Username: "user1",
|
||||
Password: "pass1",
|
||||
},
|
||||
{
|
||||
Username: "user2",
|
||||
Password: "pass2",
|
||||
},
|
||||
}
|
||||
|
||||
result, err := utils.ParseUsers(users)
|
||||
if err != nil {
|
||||
t.Fatalf("Error parsing users: %v", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(expected, result) {
|
||||
t.Fatalf("Expected %v, got %v", expected, result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetUpperDomain(t *testing.T) {
|
||||
t.Log("Testing get upper domain with a valid url")
|
||||
|
||||
url := "https://sub1.sub2.domain.com:8080"
|
||||
expected := "sub2.domain.com"
|
||||
|
||||
result, err := utils.GetUpperDomain(url)
|
||||
if err != nil {
|
||||
t.Fatalf("Error getting root url: %v", err)
|
||||
}
|
||||
|
||||
if expected != result {
|
||||
t.Fatalf("Expected %v, got %v", expected, result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadFile(t *testing.T) {
|
||||
t.Log("Creating a test file")
|
||||
|
||||
err := os.WriteFile("/tmp/test.txt", []byte("test"), 0644)
|
||||
if err != nil {
|
||||
t.Fatalf("Error creating test file: %v", err)
|
||||
}
|
||||
|
||||
t.Log("Testing read file with a valid file")
|
||||
|
||||
data, err := utils.ReadFile("/tmp/test.txt")
|
||||
if err != nil {
|
||||
t.Fatalf("Error reading file: %v", err)
|
||||
}
|
||||
|
||||
if data != "test" {
|
||||
t.Fatalf("Expected test, got %v", data)
|
||||
}
|
||||
|
||||
t.Log("Cleaning up test file")
|
||||
|
||||
err = os.Remove("/tmp/test.txt")
|
||||
if err != nil {
|
||||
t.Fatalf("Error cleaning up test file: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseFileToLine(t *testing.T) {
|
||||
t.Log("Testing parse file to line with a valid string")
|
||||
|
||||
content := "\nuser1:pass1\nuser2:pass2\n"
|
||||
expected := "user1:pass1,user2:pass2"
|
||||
|
||||
result := utils.ParseFileToLine(content)
|
||||
|
||||
if expected != result {
|
||||
t.Fatalf("Expected %v, got %v", expected, result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetSecret(t *testing.T) {
|
||||
t.Log("Testing get secret with an empty config and file")
|
||||
|
||||
conf := ""
|
||||
file := "/tmp/test.txt"
|
||||
expected := "test"
|
||||
|
||||
err := os.WriteFile(file, []byte(fmt.Sprintf("\n\n \n\n\n %s \n\n \n ", expected)), 0644)
|
||||
if err != nil {
|
||||
t.Fatalf("Error creating test file: %v", err)
|
||||
}
|
||||
|
||||
result := utils.GetSecret(conf, file)
|
||||
|
||||
if result != expected {
|
||||
t.Fatalf("Expected %v, got %v", expected, result)
|
||||
}
|
||||
|
||||
t.Log("Testing get secret with an empty file and a valid config")
|
||||
|
||||
result = utils.GetSecret(expected, "")
|
||||
|
||||
if result != expected {
|
||||
t.Fatalf("Expected %v, got %v", expected, result)
|
||||
}
|
||||
|
||||
t.Log("Testing get secret with both a valid config and file")
|
||||
|
||||
result = utils.GetSecret(expected, file)
|
||||
|
||||
if result != expected {
|
||||
t.Fatalf("Expected %v, got %v", expected, result)
|
||||
}
|
||||
|
||||
t.Log("Cleaning up test file")
|
||||
|
||||
err = os.Remove(file)
|
||||
if err != nil {
|
||||
t.Fatalf("Error cleaning up test file: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetUsers(t *testing.T) {
|
||||
t.Log("Testing get users with a config and no file")
|
||||
|
||||
conf := "user1:pass1,user2:pass2"
|
||||
file := ""
|
||||
expected := types.Users{
|
||||
{
|
||||
Username: "user1",
|
||||
Password: "pass1",
|
||||
},
|
||||
{
|
||||
Username: "user2",
|
||||
Password: "pass2",
|
||||
},
|
||||
}
|
||||
|
||||
result, err := utils.GetUsers(conf, file)
|
||||
if err != nil {
|
||||
t.Fatalf("Error getting users: %v", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(expected, result) {
|
||||
t.Fatalf("Expected %v, got %v", expected, result)
|
||||
}
|
||||
|
||||
t.Log("Testing get users with a file and no config")
|
||||
|
||||
conf = ""
|
||||
file = "/tmp/test.txt"
|
||||
expected = types.Users{
|
||||
{
|
||||
Username: "user1",
|
||||
Password: "pass1",
|
||||
},
|
||||
{
|
||||
Username: "user2",
|
||||
Password: "pass2",
|
||||
},
|
||||
}
|
||||
|
||||
err = os.WriteFile(file, []byte("user1:pass1\nuser2:pass2"), 0644)
|
||||
if err != nil {
|
||||
t.Fatalf("Error creating test file: %v", err)
|
||||
}
|
||||
|
||||
result, err = utils.GetUsers(conf, file)
|
||||
if err != nil {
|
||||
t.Fatalf("Error getting users: %v", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(expected, result) {
|
||||
t.Fatalf("Expected %v, got %v", expected, result)
|
||||
}
|
||||
|
||||
t.Log("Testing get users with both a config and file")
|
||||
|
||||
conf = "user3:pass3"
|
||||
expected = types.Users{
|
||||
{
|
||||
Username: "user3",
|
||||
Password: "pass3",
|
||||
},
|
||||
{
|
||||
Username: "user1",
|
||||
Password: "pass1",
|
||||
},
|
||||
{
|
||||
Username: "user2",
|
||||
Password: "pass2",
|
||||
},
|
||||
}
|
||||
|
||||
result, err = utils.GetUsers(conf, file)
|
||||
if err != nil {
|
||||
t.Fatalf("Error getting users: %v", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(expected, result) {
|
||||
t.Fatalf("Expected %v, got %v", expected, result)
|
||||
}
|
||||
|
||||
t.Log("Cleaning up test file")
|
||||
|
||||
err = os.Remove(file)
|
||||
if err != nil {
|
||||
t.Fatalf("Error cleaning up test file: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetLabels(t *testing.T) {
|
||||
t.Log("Testing get labels with a valid map")
|
||||
|
||||
labels := map[string]string{
|
||||
"tinyauth.users": "user1,user2",
|
||||
"tinyauth.oauth.whitelist": "/regex/",
|
||||
"tinyauth.allowed": "random",
|
||||
"tinyauth.headers": "X-Header=value",
|
||||
"tinyauth.oauth.groups": "group1,group2",
|
||||
}
|
||||
|
||||
expected := types.Labels{
|
||||
Users: "user1,user2",
|
||||
Allowed: "random",
|
||||
Headers: []string{"X-Header=value"},
|
||||
OAuth: types.OAuthLabels{
|
||||
Whitelist: "/regex/",
|
||||
Groups: "group1,group2",
|
||||
},
|
||||
}
|
||||
|
||||
result, err := utils.GetLabels(labels)
|
||||
if err != nil {
|
||||
t.Fatalf("Error getting labels: %v", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(expected, result) {
|
||||
t.Fatalf("Expected %v, got %v", expected, result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseUser(t *testing.T) {
|
||||
t.Log("Testing parse user with a valid user")
|
||||
|
||||
user := "user:pass:secret"
|
||||
expected := types.User{
|
||||
Username: "user",
|
||||
Password: "pass",
|
||||
TotpSecret: "secret",
|
||||
}
|
||||
|
||||
result, err := utils.ParseUser(user)
|
||||
if err != nil {
|
||||
t.Fatalf("Error parsing user: %v", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(expected, result) {
|
||||
t.Fatalf("Expected %v, got %v", expected, result)
|
||||
}
|
||||
|
||||
t.Log("Testing parse user with an escaped user")
|
||||
|
||||
user = "user:p$$ass$$:secret"
|
||||
expected = types.User{
|
||||
Username: "user",
|
||||
Password: "p$ass$",
|
||||
TotpSecret: "secret",
|
||||
}
|
||||
|
||||
result, err = utils.ParseUser(user)
|
||||
if err != nil {
|
||||
t.Fatalf("Error parsing user: %v", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(expected, result) {
|
||||
t.Fatalf("Expected %v, got %v", expected, result)
|
||||
}
|
||||
|
||||
t.Log("Testing parse user with an invalid user")
|
||||
|
||||
user = "user::pass"
|
||||
|
||||
_, err = utils.ParseUser(user)
|
||||
if err == nil {
|
||||
t.Fatalf("Expected error parsing user")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckFilter(t *testing.T) {
|
||||
t.Log("Testing check filter with a comma separated list")
|
||||
|
||||
filter := "user1,user2,user3"
|
||||
str := "user1"
|
||||
expected := true
|
||||
|
||||
result := utils.CheckFilter(filter, str)
|
||||
|
||||
if result != expected {
|
||||
t.Fatalf("Expected %v, got %v", expected, result)
|
||||
}
|
||||
|
||||
t.Log("Testing check filter with a regex filter")
|
||||
|
||||
filter = "/^user[0-9]+$/"
|
||||
str = "user1"
|
||||
expected = true
|
||||
|
||||
result = utils.CheckFilter(filter, str)
|
||||
|
||||
if result != expected {
|
||||
t.Fatalf("Expected %v, got %v", expected, result)
|
||||
}
|
||||
|
||||
t.Log("Testing check filter with an empty filter")
|
||||
|
||||
filter = ""
|
||||
str = "user1"
|
||||
expected = true
|
||||
|
||||
result = utils.CheckFilter(filter, str)
|
||||
|
||||
if result != expected {
|
||||
t.Fatalf("Expected %v, got %v", expected, result)
|
||||
}
|
||||
|
||||
t.Log("Testing check filter with an invalid regex filter")
|
||||
|
||||
filter = "/^user[0-9+$/"
|
||||
str = "user1"
|
||||
expected = false
|
||||
|
||||
result = utils.CheckFilter(filter, str)
|
||||
|
||||
if result != expected {
|
||||
t.Fatalf("Expected %v, got %v", expected, result)
|
||||
}
|
||||
|
||||
t.Log("Testing check filter with a non matching list")
|
||||
|
||||
filter = "user1,user2,user3"
|
||||
str = "user4"
|
||||
expected = false
|
||||
|
||||
result = utils.CheckFilter(filter, str)
|
||||
|
||||
if result != expected {
|
||||
t.Fatalf("Expected %v, got %v", expected, result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeHeader(t *testing.T) {
|
||||
t.Log("Testing sanitize header with a valid string")
|
||||
|
||||
str := "X-Header=value"
|
||||
expected := "X-Header=value"
|
||||
|
||||
result := utils.SanitizeHeader(str)
|
||||
|
||||
if result != expected {
|
||||
t.Fatalf("Expected %v, got %v", expected, result)
|
||||
}
|
||||
|
||||
t.Log("Testing sanitize header with an invalid string")
|
||||
|
||||
str = "X-Header=val\nue"
|
||||
expected = "X-Header=value"
|
||||
|
||||
result = utils.SanitizeHeader(str)
|
||||
|
||||
if result != expected {
|
||||
t.Fatalf("Expected %v, got %v", expected, result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseHeaders(t *testing.T) {
|
||||
t.Log("Testing parse headers with a valid string")
|
||||
|
||||
headers := []string{"X-Hea\x00der1=value1", "X-Header2=value\n2"}
|
||||
expected := map[string]string{
|
||||
"X-Header1": "value1",
|
||||
"X-Header2": "value2",
|
||||
}
|
||||
|
||||
result := utils.ParseHeaders(headers)
|
||||
|
||||
if !reflect.DeepEqual(expected, result) {
|
||||
t.Fatalf("Expected %v, got %v", expected, result)
|
||||
}
|
||||
|
||||
t.Log("Testing parse headers with an invalid string")
|
||||
|
||||
headers = []string{"X-Header1=", "X-Header2", "=value", "X-Header3=value3"}
|
||||
expected = map[string]string{"X-Header3": "value3"}
|
||||
|
||||
result = utils.ParseHeaders(headers)
|
||||
|
||||
if !reflect.DeepEqual(expected, result) {
|
||||
t.Fatalf("Expected %v, got %v", expected, result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseSecretFile(t *testing.T) {
|
||||
t.Log("Testing parse secret file with a valid file")
|
||||
|
||||
content := "\n\n \n\n\n secret \n\n \n "
|
||||
expected := "secret"
|
||||
|
||||
result := utils.ParseSecretFile(content)
|
||||
|
||||
if result != expected {
|
||||
t.Fatalf("Expected %v, got %v", expected, result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilterIP(t *testing.T) {
|
||||
t.Log("Testing filter IP with an IP and a valid CIDR")
|
||||
|
||||
ip := "10.10.10.10"
|
||||
filter := "10.10.10.0/24"
|
||||
expected := true
|
||||
|
||||
result, err := utils.FilterIP(filter, ip)
|
||||
if err != nil {
|
||||
t.Fatalf("Error filtering IP: %v", err)
|
||||
}
|
||||
|
||||
if result != expected {
|
||||
t.Fatalf("Expected %v, got %v", expected, result)
|
||||
}
|
||||
|
||||
t.Log("Testing filter IP with an IP and a valid IP")
|
||||
|
||||
filter = "10.10.10.10"
|
||||
expected = true
|
||||
|
||||
result, err = utils.FilterIP(filter, ip)
|
||||
if err != nil {
|
||||
t.Fatalf("Error filtering IP: %v", err)
|
||||
}
|
||||
|
||||
if result != expected {
|
||||
t.Fatalf("Expected %v, got %v", expected, result)
|
||||
}
|
||||
|
||||
t.Log("Testing filter IP with an IP and an non matching CIDR")
|
||||
|
||||
filter = "10.10.15.0/24"
|
||||
expected = false
|
||||
|
||||
result, err = utils.FilterIP(filter, ip)
|
||||
if err != nil {
|
||||
t.Fatalf("Error filtering IP: %v", err)
|
||||
}
|
||||
|
||||
if result != expected {
|
||||
t.Fatalf("Expected %v, got %v", expected, result)
|
||||
}
|
||||
|
||||
t.Log("Testing filter IP with a non matching IP and a valid CIDR")
|
||||
|
||||
filter = "10.10.10.11"
|
||||
expected = false
|
||||
|
||||
result, err = utils.FilterIP(filter, ip)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Error filtering IP: %v", err)
|
||||
}
|
||||
|
||||
if result != expected {
|
||||
t.Fatalf("Expected %v, got %v", expected, result)
|
||||
}
|
||||
|
||||
t.Log("Testing filter IP with an IP and an invalid CIDR")
|
||||
|
||||
filter = "10.../83"
|
||||
|
||||
_, err = utils.FilterIP(filter, ip)
|
||||
if err == nil {
|
||||
t.Fatalf("Expected error filtering IP")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeriveKey(t *testing.T) {
|
||||
t.Log("Testing the derive key function")
|
||||
|
||||
master := "master"
|
||||
info := "info"
|
||||
expected := "gdrdU/fXzclYjiSXRexEatVgV13qQmKl"
|
||||
|
||||
result, err := utils.DeriveKey(master, info)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Error deriving key: %v", err)
|
||||
}
|
||||
|
||||
if result != expected {
|
||||
t.Fatalf("Expected %v, got %v", expected, result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCoalesceToString(t *testing.T) {
|
||||
t.Log("Testing coalesce to string with a string")
|
||||
|
||||
value := any("test")
|
||||
expected := "test"
|
||||
|
||||
result := utils.CoalesceToString(value)
|
||||
|
||||
if result != expected {
|
||||
t.Fatalf("Expected %v, got %v", expected, result)
|
||||
}
|
||||
|
||||
t.Log("Testing coalesce to string with a slice of strings")
|
||||
|
||||
value = []any{any("test1"), any("test2"), any(123)}
|
||||
expected = "test1,test2"
|
||||
|
||||
result = utils.CoalesceToString(value)
|
||||
|
||||
if result != expected {
|
||||
t.Fatalf("Expected %v, got %v", expected, result)
|
||||
}
|
||||
|
||||
t.Log("Testing coalesce to string with an unsupported type")
|
||||
|
||||
value = 12345
|
||||
expected = ""
|
||||
|
||||
result = utils.CoalesceToString(value)
|
||||
|
||||
if result != expected {
|
||||
t.Fatalf("Expected %v, got %v", expected, result)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user