mirror of
https://github.com/steveiliop56/tinyauth.git
synced 2025-10-28 12:45:47 +00:00
Previously IsRedirectSafe rejected redirects to the exact cookie domain when AppURL had multiple subdomain levels, because it stripped the first label twice.
211 lines
4.3 KiB
Go
211 lines
4.3 KiB
Go
package utils
|
|
|
|
import (
|
|
"errors"
|
|
"net"
|
|
"net/url"
|
|
"strings"
|
|
"tinyauth/internal/config"
|
|
"tinyauth/internal/utils/decoders"
|
|
|
|
"maps"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/rs/zerolog"
|
|
"github.com/weppos/publicsuffix-go/publicsuffix"
|
|
)
|
|
|
|
// Get cookie domain parses a hostname and returns the upper domain (e.g. sub1.sub2.domain.com -> sub2.domain.com)
|
|
func GetCookieDomain(u string) (string, error) {
|
|
parsed, err := url.Parse(u)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
host := parsed.Hostname()
|
|
|
|
if netIP := net.ParseIP(host); netIP != nil {
|
|
return "", errors.New("IP addresses not allowed")
|
|
}
|
|
|
|
parts := strings.Split(host, ".")
|
|
|
|
if len(parts) < 3 {
|
|
return "", errors.New("invalid app url, must be at least second level domain")
|
|
}
|
|
|
|
domain := strings.Join(parts[1:], ".")
|
|
|
|
_, err = publicsuffix.DomainFromListWithOptions(publicsuffix.DefaultList, domain, nil)
|
|
|
|
if err != nil {
|
|
return "", errors.New("domain in public suffix list, cannot set cookies")
|
|
}
|
|
|
|
return domain, nil
|
|
}
|
|
|
|
// ParseFileToLine collapses multi-line content into one comma-separated string.
// Every line is trimmed of surrounding whitespace, and lines that are empty
// after trimming are skipped. Empty or all-blank input yields "".
func ParseFileToLine(content string) string {
	var entries []string

	for _, raw := range strings.Split(content, "\n") {
		trimmed := strings.TrimSpace(raw)
		if trimmed == "" {
			continue
		}
		entries = append(entries, trimmed)
	}

	return strings.Join(entries, ",")
}
|
|
|
|
// Filter returns a new slice holding, in their original order, the elements
// of slice for which test reports true. The result is never nil: when nothing
// matches (or slice is empty) an empty, non-nil slice is returned.
func Filter[T any](slice []T, test func(T) bool) (res []T) {
	res = []T{}

	for i := range slice {
		if !test(slice[i]) {
			continue
		}
		res = append(res, slice[i])
	}

	return res
}
|
|
|
|
func GetContext(c *gin.Context) (config.UserContext, error) {
|
|
userContextValue, exists := c.Get("context")
|
|
|
|
if !exists {
|
|
return config.UserContext{}, errors.New("no user context in request")
|
|
}
|
|
|
|
userContext, ok := userContextValue.(*config.UserContext)
|
|
|
|
if !ok {
|
|
return config.UserContext{}, errors.New("invalid user context in request")
|
|
}
|
|
|
|
return *userContext, nil
|
|
}
|
|
|
|
func IsRedirectSafe(redirectURL string, domain string) bool {
|
|
if redirectURL == "" {
|
|
return false
|
|
}
|
|
|
|
parsedURL, err := url.Parse(redirectURL)
|
|
|
|
if err != nil {
|
|
return false
|
|
}
|
|
|
|
if !parsedURL.IsAbs() {
|
|
return false
|
|
}
|
|
|
|
host := parsedURL.Hostname()
|
|
if host == domain {
|
|
return true
|
|
}
|
|
|
|
cookieDomain, err := GetCookieDomain(redirectURL)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
|
|
return cookieDomain == domain
|
|
}
|
|
|
|
func GetLogLevel(level string) zerolog.Level {
|
|
switch strings.ToLower(level) {
|
|
case "trace":
|
|
return zerolog.TraceLevel
|
|
case "debug":
|
|
return zerolog.DebugLevel
|
|
case "info":
|
|
return zerolog.InfoLevel
|
|
case "warn":
|
|
return zerolog.WarnLevel
|
|
case "error":
|
|
return zerolog.ErrorLevel
|
|
case "fatal":
|
|
return zerolog.FatalLevel
|
|
case "panic":
|
|
return zerolog.PanicLevel
|
|
default:
|
|
return zerolog.InfoLevel
|
|
}
|
|
}
|
|
|
|
func GetOAuthProvidersConfig(env []string, args []string, appUrl string) (map[string]config.OAuthServiceConfig, error) {
|
|
providers := make(map[string]config.OAuthServiceConfig)
|
|
|
|
// Get from environment variables
|
|
envMap := make(map[string]string)
|
|
|
|
for _, e := range env {
|
|
pair := strings.SplitN(e, "=", 2)
|
|
if len(pair) == 2 {
|
|
envMap[pair[0]] = pair[1]
|
|
}
|
|
}
|
|
|
|
envProviders, err := decoders.DecodeEnv(envMap)
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
maps.Copy(providers, envProviders.Providers)
|
|
|
|
// Get from flags
|
|
flagsMap := make(map[string]string)
|
|
|
|
for _, arg := range args[1:] {
|
|
if strings.HasPrefix(arg, "--") {
|
|
pair := strings.SplitN(arg[2:], "=", 2)
|
|
if len(pair) == 2 {
|
|
flagsMap[pair[0]] = pair[1]
|
|
}
|
|
}
|
|
}
|
|
|
|
flagProviders, err := decoders.DecodeFlags(flagsMap)
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
maps.Copy(providers, flagProviders.Providers)
|
|
|
|
// For every provider get correct secret from file if set
|
|
for name, provider := range providers {
|
|
secret := GetSecret(provider.ClientSecret, provider.ClientSecretFile)
|
|
provider.ClientSecret = secret
|
|
provider.ClientSecretFile = ""
|
|
providers[name] = provider
|
|
}
|
|
|
|
// If we have google/github providers and no redirect URL then set a default
|
|
for id := range config.OverrideProviders {
|
|
if provider, exists := providers[id]; exists {
|
|
if provider.RedirectURL == "" {
|
|
provider.RedirectURL = appUrl + "/api/oauth/callback/" + id
|
|
providers[id] = provider
|
|
}
|
|
}
|
|
}
|
|
|
|
// Set names
|
|
for id, provider := range providers {
|
|
if provider.Name == "" {
|
|
if name, ok := config.OverrideProviders[id]; ok {
|
|
provider.Name = name
|
|
} else {
|
|
provider.Name = Capitalize(id)
|
|
}
|
|
}
|
|
providers[id] = provider
|
|
}
|
|
|
|
// Return combined providers
|
|
return providers, nil
|
|
}
|