chore: remove meaningless comments

This commit is contained in:
Stavros
2025-07-12 13:17:06 +03:00
parent e742603c15
commit 8ebed0ac9a
24 changed files with 81 additions and 876 deletions

View File

@@ -3,9 +3,7 @@ package cmd
import ( import (
"errors" "errors"
"fmt" "fmt"
"os"
"strings" "strings"
"time"
totpCmd "tinyauth/cmd/totp" totpCmd "tinyauth/cmd/totp"
userCmd "tinyauth/cmd/user" userCmd "tinyauth/cmd/user"
"tinyauth/internal/auth" "tinyauth/internal/auth"
@@ -31,47 +29,37 @@ var rootCmd = &cobra.Command{
Short: "The simplest way to protect your apps with a login screen.", Short: "The simplest way to protect your apps with a login screen.",
Long: `Tinyauth is a simple authentication middleware that adds simple username/password login or OAuth with Google, Github and any generic OAuth provider to all of your docker apps.`, Long: `Tinyauth is a simple authentication middleware that adds simple username/password login or OAuth with Google, Github and any generic OAuth provider to all of your docker apps.`,
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
// Logger
log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr, TimeFormat: time.RFC3339}).With().Timestamp().Logger().Level(zerolog.FatalLevel)
// Get config
var config types.Config var config types.Config
err := viper.Unmarshal(&config) err := viper.Unmarshal(&config)
HandleError(err, "Failed to parse config") HandleError(err, "Failed to parse config")
// Secrets // Check if secrets have a file associated with them
config.Secret = utils.GetSecret(config.Secret, config.SecretFile) config.Secret = utils.GetSecret(config.Secret, config.SecretFile)
config.GithubClientSecret = utils.GetSecret(config.GithubClientSecret, config.GithubClientSecretFile) config.GithubClientSecret = utils.GetSecret(config.GithubClientSecret, config.GithubClientSecretFile)
config.GoogleClientSecret = utils.GetSecret(config.GoogleClientSecret, config.GoogleClientSecretFile) config.GoogleClientSecret = utils.GetSecret(config.GoogleClientSecret, config.GoogleClientSecretFile)
config.GenericClientSecret = utils.GetSecret(config.GenericClientSecret, config.GenericClientSecretFile) config.GenericClientSecret = utils.GetSecret(config.GenericClientSecret, config.GenericClientSecretFile)
// Validate config
validator := validator.New() validator := validator.New()
err = validator.Struct(config) err = validator.Struct(config)
HandleError(err, "Failed to validate config") HandleError(err, "Failed to validate config")
// Logger
log.Logger = log.Level(zerolog.Level(config.LogLevel)) log.Logger = log.Level(zerolog.Level(config.LogLevel))
log.Info().Str("version", strings.TrimSpace(constants.Version)).Msg("Starting tinyauth") log.Info().Str("version", strings.TrimSpace(constants.Version)).Msg("Starting tinyauth")
// Users
log.Info().Msg("Parsing users") log.Info().Msg("Parsing users")
users, err := utils.GetUsers(config.Users, config.UsersFile) users, err := utils.GetUsers(config.Users, config.UsersFile)
HandleError(err, "Failed to parse users") HandleError(err, "Failed to parse users")
// Get domain
log.Debug().Msg("Getting domain") log.Debug().Msg("Getting domain")
domain, err := utils.GetUpperDomain(config.AppURL) domain, err := utils.GetUpperDomain(config.AppURL)
HandleError(err, "Failed to get upper domain") HandleError(err, "Failed to get upper domain")
log.Info().Str("domain", domain).Msg("Using domain for cookie store") log.Info().Str("domain", domain).Msg("Using domain for cookie store")
// Generate cookie name
cookieId := utils.GenerateIdentifier(strings.Split(domain, ".")[0]) cookieId := utils.GenerateIdentifier(strings.Split(domain, ".")[0])
sessionCookieName := fmt.Sprintf("%s-%s", constants.SessionCookieName, cookieId) sessionCookieName := fmt.Sprintf("%s-%s", constants.SessionCookieName, cookieId)
csrfCookieName := fmt.Sprintf("%s-%s", constants.CsrfCookieName, cookieId) csrfCookieName := fmt.Sprintf("%s-%s", constants.CsrfCookieName, cookieId)
redirectCookieName := fmt.Sprintf("%s-%s", constants.RedirectCookieName, cookieId) redirectCookieName := fmt.Sprintf("%s-%s", constants.RedirectCookieName, cookieId)
// Generate HMAC and encryption secrets
log.Debug().Msg("Deriving HMAC and encryption secrets") log.Debug().Msg("Deriving HMAC and encryption secrets")
hmacSecret, err := utils.DeriveKey(config.Secret, "hmac") hmacSecret, err := utils.DeriveKey(config.Secret, "hmac")
@@ -80,7 +68,7 @@ var rootCmd = &cobra.Command{
encryptionSecret, err := utils.DeriveKey(config.Secret, "encryption") encryptionSecret, err := utils.DeriveKey(config.Secret, "encryption")
HandleError(err, "Failed to derive encryption secret") HandleError(err, "Failed to derive encryption secret")
// Create OAuth config // Split the config into service-specific sub-configs
oauthConfig := types.OAuthConfig{ oauthConfig := types.OAuthConfig{
GithubClientId: config.GithubClientId, GithubClientId: config.GithubClientId,
GithubClientSecret: config.GithubClientSecret, GithubClientSecret: config.GithubClientSecret,
@@ -96,7 +84,6 @@ var rootCmd = &cobra.Command{
AppURL: config.AppURL, AppURL: config.AppURL,
} }
// Create handlers config
handlersConfig := types.HandlersConfig{ handlersConfig := types.HandlersConfig{
AppURL: config.AppURL, AppURL: config.AppURL,
DisableContinue: config.DisableContinue, DisableContinue: config.DisableContinue,
@@ -111,13 +98,11 @@ var rootCmd = &cobra.Command{
RedirectCookieName: redirectCookieName, RedirectCookieName: redirectCookieName,
} }
// Create server config
serverConfig := types.ServerConfig{ serverConfig := types.ServerConfig{
Port: config.Port, Port: config.Port,
Address: config.Address, Address: config.Address,
} }
// Create auth config
authConfig := types.AuthConfig{ authConfig := types.AuthConfig{
Users: users, Users: users,
OauthWhitelist: config.OAuthWhitelist, OauthWhitelist: config.OAuthWhitelist,
@@ -131,21 +116,14 @@ var rootCmd = &cobra.Command{
EncryptionSecret: encryptionSecret, EncryptionSecret: encryptionSecret,
} }
// Create hooks config
hooksConfig := types.HooksConfig{ hooksConfig := types.HooksConfig{
Domain: domain, Domain: domain,
} }
// Create docker service
docker, err := docker.NewDocker()
HandleError(err, "Failed to initialize docker")
// Create LDAP service if configured
var ldapService *ldap.LDAP var ldapService *ldap.LDAP
if config.LdapAddress != "" { if config.LdapAddress != "" {
log.Info().Msg("Using LDAP for authentication") log.Info().Msg("Using LDAP for authentication")
ldapConfig := types.LdapConfig{ ldapConfig := types.LdapConfig{
Address: config.LdapAddress, Address: config.LdapAddress,
BindDN: config.LdapBindDN, BindDN: config.LdapBindDN,
@@ -154,36 +132,28 @@ var rootCmd = &cobra.Command{
Insecure: config.LdapInsecure, Insecure: config.LdapInsecure,
SearchFilter: config.LdapSearchFilter, SearchFilter: config.LdapSearchFilter,
} }
// Create LDAP service
ldapService, err = ldap.NewLDAP(ldapConfig) ldapService, err = ldap.NewLDAP(ldapConfig)
HandleError(err, "Failed to create LDAP service") HandleError(err, "Failed to create LDAP service")
} else { } else {
log.Info().Msg("LDAP not configured, using local users or OAuth") log.Info().Msg("LDAP not configured, using local users or OAuth")
} }
// Check if we have any users configured // Check if we have a source of users
if len(users) == 0 && !utils.OAuthConfigured(config) && ldapService == nil { if len(users) == 0 && !utils.OAuthConfigured(config) && ldapService == nil {
HandleError(errors.New("err no users"), "Unable to find a source of users") HandleError(errors.New("err no users"), "Unable to find a source of users")
} }
// Create auth service // Setup the services
docker, err := docker.NewDocker()
HandleError(err, "Failed to initialize docker")
auth := auth.NewAuth(authConfig, docker, ldapService) auth := auth.NewAuth(authConfig, docker, ldapService)
// Create OAuth providers service
providers := providers.NewProviders(oauthConfig) providers := providers.NewProviders(oauthConfig)
// Create hooks service
hooks := hooks.NewHooks(hooksConfig, auth, providers) hooks := hooks.NewHooks(hooksConfig, auth, providers)
// Create handlers
handlers := handlers.NewHandlers(handlersConfig, auth, hooks, providers, docker) handlers := handlers.NewHandlers(handlersConfig, auth, hooks, providers, docker)
// Create server
srv, err := server.NewServer(serverConfig, handlers) srv, err := server.NewServer(serverConfig, handlers)
HandleError(err, "Failed to create server") HandleError(err, "Failed to create server")
// Start server // Start up
err = srv.Start() err = srv.Start()
HandleError(err, "Failed to start server") HandleError(err, "Failed to start server")
}, },
@@ -195,23 +165,17 @@ func Execute() {
} }
func HandleError(err error, msg string) { func HandleError(err error, msg string) {
// If error, log it and exit
if err != nil { if err != nil {
log.Fatal().Err(err).Msg(msg) log.Fatal().Err(err).Msg(msg)
} }
} }
func init() { func init() {
// Add user command
rootCmd.AddCommand(userCmd.UserCmd()) rootCmd.AddCommand(userCmd.UserCmd())
// Add totp command
rootCmd.AddCommand(totpCmd.TotpCmd()) rootCmd.AddCommand(totpCmd.TotpCmd())
// Read environment variables
viper.AutomaticEnv() viper.AutomaticEnv()
// Flags
rootCmd.Flags().Int("port", 3000, "Port to run the server on.") rootCmd.Flags().Int("port", 3000, "Port to run the server on.")
rootCmd.Flags().String("address", "0.0.0.0", "Address to bind the server to.") rootCmd.Flags().String("address", "0.0.0.0", "Address to bind the server to.")
rootCmd.Flags().String("secret", "", "Secret to use for the cookie.") rootCmd.Flags().String("secret", "", "Secret to use for the cookie.")
@@ -252,7 +216,6 @@ func init() {
rootCmd.Flags().Bool("ldap-insecure", false, "Skip certificate verification for the LDAP server.") rootCmd.Flags().Bool("ldap-insecure", false, "Skip certificate verification for the LDAP server.")
rootCmd.Flags().String("ldap-search-filter", "(uid=%s)", "LDAP search filter for user lookup.") rootCmd.Flags().String("ldap-search-filter", "(uid=%s)", "LDAP search filter for user lookup.")
// Bind flags to environment
viper.BindEnv("port", "PORT") viper.BindEnv("port", "PORT")
viper.BindEnv("address", "ADDRESS") viper.BindEnv("address", "ADDRESS")
viper.BindEnv("secret", "SECRET") viper.BindEnv("secret", "SECRET")
@@ -293,6 +256,5 @@ func init() {
viper.BindEnv("ldap-insecure", "LDAP_INSECURE") viper.BindEnv("ldap-insecure", "LDAP_INSECURE")
viper.BindEnv("ldap-search-filter", "LDAP_SEARCH_FILTER") viper.BindEnv("ldap-search-filter", "LDAP_SEARCH_FILTER")
// Bind flags to viper
viper.BindPFlags(rootCmd.Flags()) viper.BindPFlags(rootCmd.Flags())
} }

View File

@@ -15,7 +15,6 @@ import (
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
// Interactive flag
var interactive bool var interactive bool
// Input user // Input user
@@ -25,15 +24,9 @@ var GenerateCmd = &cobra.Command{
Use: "generate", Use: "generate",
Short: "Generate a totp secret", Short: "Generate a totp secret",
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
// Setup logger
log.Logger = log.Level(zerolog.InfoLevel) log.Logger = log.Level(zerolog.InfoLevel)
// Use simple theme
var baseTheme *huh.Theme = huh.ThemeBase()
// Interactive
if interactive { if interactive {
// Create huh form
form := huh.NewForm( form := huh.NewForm(
huh.NewGroup( huh.NewGroup(
huh.NewInput().Title("Current username:hash").Value(&iUser).Validate((func(s string) error { huh.NewInput().Title("Current username:hash").Value(&iUser).Validate((func(s string) error {
@@ -44,51 +37,39 @@ var GenerateCmd = &cobra.Command{
})), })),
), ),
) )
var baseTheme *huh.Theme = huh.ThemeBase()
// Run form
err := form.WithTheme(baseTheme).Run() err := form.WithTheme(baseTheme).Run()
if err != nil { if err != nil {
log.Fatal().Err(err).Msg("Form failed") log.Fatal().Err(err).Msg("Form failed")
} }
} }
// Parse user
user, err := utils.ParseUser(iUser) user, err := utils.ParseUser(iUser)
if err != nil { if err != nil {
log.Fatal().Err(err).Msg("Failed to parse user") log.Fatal().Err(err).Msg("Failed to parse user")
} }
// Check if user was using docker escape
dockerEscape := false dockerEscape := false
if strings.Contains(iUser, "$$") { if strings.Contains(iUser, "$$") {
dockerEscape = true dockerEscape = true
} }
// Check it has totp
if user.TotpSecret != "" { if user.TotpSecret != "" {
log.Fatal().Msg("User already has a totp secret") log.Fatal().Msg("User already has a totp secret")
} }
// Generate totp secret
key, err := totp.Generate(totp.GenerateOpts{ key, err := totp.Generate(totp.GenerateOpts{
Issuer: "Tinyauth", Issuer: "Tinyauth",
AccountName: user.Username, AccountName: user.Username,
}) })
if err != nil { if err != nil {
log.Fatal().Err(err).Msg("Failed to generate totp secret") log.Fatal().Err(err).Msg("Failed to generate totp secret")
} }
// Create secret
secret := key.Secret() secret := key.Secret()
// Print secret and image
log.Info().Str("secret", secret).Msg("Generated totp secret") log.Info().Str("secret", secret).Msg("Generated totp secret")
// Print QR code
log.Info().Msg("Generated QR code") log.Info().Msg("Generated QR code")
config := qrterminal.Config{ config := qrterminal.Config{
@@ -101,7 +82,6 @@ var GenerateCmd = &cobra.Command{
qrterminal.GenerateWithConfig(key.URL(), config) qrterminal.GenerateWithConfig(key.URL(), config)
// Add the secret to the user
user.TotpSecret = secret user.TotpSecret = secret
// If using docker escape re-escape it // If using docker escape re-escape it
@@ -109,13 +89,11 @@ var GenerateCmd = &cobra.Command{
user.Password = strings.ReplaceAll(user.Password, "$", "$$") user.Password = strings.ReplaceAll(user.Password, "$", "$$")
} }
// Print success
log.Info().Str("user", fmt.Sprintf("%s:%s:%s", user.Username, user.Password, user.TotpSecret)).Msg("Add the totp secret to your authenticator app then use the verify command to ensure everything is working correctly.") log.Info().Str("user", fmt.Sprintf("%s:%s:%s", user.Username, user.Password, user.TotpSecret)).Msg("Add the totp secret to your authenticator app then use the verify command to ensure everything is working correctly.")
}, },
} }
func init() { func init() {
// Add interactive flag
GenerateCmd.Flags().BoolVarP(&interactive, "interactive", "i", false, "Run in interactive mode") GenerateCmd.Flags().BoolVarP(&interactive, "interactive", "i", false, "Run in interactive mode")
GenerateCmd.Flags().StringVar(&iUser, "user", "", "Your current username:hash") GenerateCmd.Flags().StringVar(&iUser, "user", "", "Your current username:hash")
} }

View File

@@ -7,16 +7,11 @@ import (
) )
func TotpCmd() *cobra.Command { func TotpCmd() *cobra.Command {
// Create the totp command
totpCmd := &cobra.Command{ totpCmd := &cobra.Command{
Use: "totp", Use: "totp",
Short: "Totp utilities", Short: "Totp utilities",
Long: `Utilities for creating and verifying totp codes.`, Long: `Utilities for creating and verifying totp codes.`,
} }
// Add the generate command
totpCmd.AddCommand(generate.GenerateCmd) totpCmd.AddCommand(generate.GenerateCmd)
// Return the totp command
return totpCmd return totpCmd
} }

View File

@@ -12,10 +12,7 @@ import (
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
) )
// Interactive flag
var interactive bool var interactive bool
// Docker flag
var docker bool var docker bool
// i stands for input // i stands for input
@@ -27,12 +24,9 @@ var CreateCmd = &cobra.Command{
Short: "Create a user", Short: "Create a user",
Long: `Create a user either interactively or by passing flags.`, Long: `Create a user either interactively or by passing flags.`,
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
// Setup logger
log.Logger = log.Level(zerolog.InfoLevel) log.Logger = log.Level(zerolog.InfoLevel)
// Check if interactive
if interactive { if interactive {
// Create huh form
form := huh.NewForm( form := huh.NewForm(
huh.NewGroup( huh.NewGroup(
huh.NewInput().Title("Username").Value(&iUsername).Validate((func(s string) error { huh.NewInput().Title("Username").Value(&iUsername).Validate((func(s string) error {
@@ -50,46 +44,35 @@ var CreateCmd = &cobra.Command{
huh.NewSelect[bool]().Title("Format the output for docker?").Options(huh.NewOption("Yes", true), huh.NewOption("No", false)).Value(&docker), huh.NewSelect[bool]().Title("Format the output for docker?").Options(huh.NewOption("Yes", true), huh.NewOption("No", false)).Value(&docker),
), ),
) )
// Use simple theme
var baseTheme *huh.Theme = huh.ThemeBase() var baseTheme *huh.Theme = huh.ThemeBase()
err := form.WithTheme(baseTheme).Run() err := form.WithTheme(baseTheme).Run()
if err != nil { if err != nil {
log.Fatal().Err(err).Msg("Form failed") log.Fatal().Err(err).Msg("Form failed")
} }
} }
// Do we have username and password?
if iUsername == "" || iPassword == "" { if iUsername == "" || iPassword == "" {
log.Fatal().Err(errors.New("error invalid input")).Msg("Username and password cannot be empty") log.Fatal().Err(errors.New("error invalid input")).Msg("Username and password cannot be empty")
} }
log.Info().Str("username", iUsername).Str("password", iPassword).Bool("docker", docker).Msg("Creating user") log.Info().Str("username", iUsername).Str("password", iPassword).Bool("docker", docker).Msg("Creating user")
// Hash password
password, err := bcrypt.GenerateFromPassword([]byte(iPassword), bcrypt.DefaultCost) password, err := bcrypt.GenerateFromPassword([]byte(iPassword), bcrypt.DefaultCost)
if err != nil { if err != nil {
log.Fatal().Err(err).Msg("Failed to hash password") log.Fatal().Err(err).Msg("Failed to hash password")
} }
// Convert password to string // If docker format is enabled, escape the dollar sign
passwordString := string(password) passwordString := string(password)
// Escape $ for docker
if docker { if docker {
passwordString = strings.ReplaceAll(passwordString, "$", "$$") passwordString = strings.ReplaceAll(passwordString, "$", "$$")
} }
// Log user created
log.Info().Str("user", fmt.Sprintf("%s:%s", iUsername, passwordString)).Msg("User created") log.Info().Str("user", fmt.Sprintf("%s:%s", iUsername, passwordString)).Msg("User created")
}, },
} }
func init() { func init() {
// Flags
CreateCmd.Flags().BoolVarP(&interactive, "interactive", "i", false, "Create a user interactively") CreateCmd.Flags().BoolVarP(&interactive, "interactive", "i", false, "Create a user interactively")
CreateCmd.Flags().BoolVar(&docker, "docker", false, "Format output for docker") CreateCmd.Flags().BoolVar(&docker, "docker", false, "Format output for docker")
CreateCmd.Flags().StringVar(&iUsername, "username", "", "Username") CreateCmd.Flags().StringVar(&iUsername, "username", "", "Username")

View File

@@ -8,17 +8,12 @@ import (
) )
func UserCmd() *cobra.Command { func UserCmd() *cobra.Command {
// Create the user command
userCmd := &cobra.Command{ userCmd := &cobra.Command{
Use: "user", Use: "user",
Short: "User utilities", Short: "User utilities",
Long: `Utilities for creating and verifying tinyauth compatible users.`, Long: `Utilities for creating and verifying tinyauth compatible users.`,
} }
// Add subcommands
userCmd.AddCommand(create.CreateCmd) userCmd.AddCommand(create.CreateCmd)
userCmd.AddCommand(verify.VerifyCmd) userCmd.AddCommand(verify.VerifyCmd)
// Return the user command
return userCmd return userCmd
} }

View File

@@ -12,10 +12,7 @@ import (
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
) )
// Interactive flag
var interactive bool var interactive bool
// Docker flag
var docker bool var docker bool
// i stands for input // i stands for input
@@ -29,15 +26,9 @@ var VerifyCmd = &cobra.Command{
Short: "Verify a user is set up correctly", Short: "Verify a user is set up correctly",
Long: `Verify a user is set up correctly meaning that it has a correct username, password and totp code.`, Long: `Verify a user is set up correctly meaning that it has a correct username, password and totp code.`,
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
// Setup logger
log.Logger = log.Level(zerolog.InfoLevel) log.Logger = log.Level(zerolog.InfoLevel)
// Use simple theme
var baseTheme *huh.Theme = huh.ThemeBase()
// Check if interactive
if interactive { if interactive {
// Create huh form
form := huh.NewForm( form := huh.NewForm(
huh.NewGroup( huh.NewGroup(
huh.NewInput().Title("User (username:hash:totp)").Value(&iUser).Validate((func(s string) error { huh.NewInput().Title("User (username:hash:totp)").Value(&iUser).Validate((func(s string) error {
@@ -61,35 +52,27 @@ var VerifyCmd = &cobra.Command{
huh.NewInput().Title("Totp Code (if setup)").Value(&iTotp), huh.NewInput().Title("Totp Code (if setup)").Value(&iTotp),
), ),
) )
var baseTheme *huh.Theme = huh.ThemeBase()
// Run form
err := form.WithTheme(baseTheme).Run() err := form.WithTheme(baseTheme).Run()
if err != nil { if err != nil {
log.Fatal().Err(err).Msg("Form failed") log.Fatal().Err(err).Msg("Form failed")
} }
} }
// Parse user
user, err := utils.ParseUser(iUser) user, err := utils.ParseUser(iUser)
if err != nil { if err != nil {
log.Fatal().Err(err).Msg("Failed to parse user") log.Fatal().Err(err).Msg("Failed to parse user")
} }
// Compare username
if user.Username != iUsername { if user.Username != iUsername {
log.Fatal().Msg("Username is incorrect") log.Fatal().Msg("Username is incorrect")
} }
// Compare password
err = bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(iPassword)) err = bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(iPassword))
if err != nil { if err != nil {
log.Fatal().Msg("Ppassword is incorrect") log.Fatal().Msg("Ppassword is incorrect")
} }
// Check if user has 2fa code
if user.TotpSecret == "" { if user.TotpSecret == "" {
if iTotp != "" { if iTotp != "" {
log.Warn().Msg("User does not have 2fa secret") log.Warn().Msg("User does not have 2fa secret")
@@ -98,21 +81,17 @@ var VerifyCmd = &cobra.Command{
return return
} }
// Check totp code
ok := totp.Validate(iTotp, user.TotpSecret) ok := totp.Validate(iTotp, user.TotpSecret)
if !ok { if !ok {
log.Fatal().Msg("Totp code incorrect") log.Fatal().Msg("Totp code incorrect")
} }
// Done
log.Info().Msg("User verified") log.Info().Msg("User verified")
}, },
} }
func init() { func init() {
// Flags
VerifyCmd.Flags().BoolVarP(&interactive, "interactive", "i", false, "Create a user interactively") VerifyCmd.Flags().BoolVarP(&interactive, "interactive", "i", false, "Create a user interactively")
VerifyCmd.Flags().BoolVar(&docker, "docker", false, "Is the user formatted for docker?") VerifyCmd.Flags().BoolVar(&docker, "docker", false, "Is the user formatted for docker?")
VerifyCmd.Flags().StringVar(&iUsername, "username", "", "Username") VerifyCmd.Flags().StringVar(&iUsername, "username", "", "Username")

View File

@@ -7,7 +7,6 @@ import (
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
// Create the version command
var versionCmd = &cobra.Command{ var versionCmd = &cobra.Command{
Use: "version", Use: "version",
Short: "Print the version number of Tinyauth", Short: "Print the version number of Tinyauth",

View File

@@ -27,10 +27,8 @@ type Auth struct {
} }
func NewAuth(config types.AuthConfig, docker *docker.Docker, ldap *ldap.LDAP) *Auth { func NewAuth(config types.AuthConfig, docker *docker.Docker, ldap *ldap.LDAP) *Auth {
// Create cookie store // Setup cookie store and create the auth service
store := sessions.NewCookieStore([]byte(config.HMACSecret), []byte(config.EncryptionSecret)) store := sessions.NewCookieStore([]byte(config.HMACSecret), []byte(config.EncryptionSecret))
// Configure cookie store
store.Options = &sessions.Options{ store.Options = &sessions.Options{
Path: "/", Path: "/",
MaxAge: config.SessionExpiry, MaxAge: config.SessionExpiry,
@@ -38,7 +36,6 @@ func NewAuth(config types.AuthConfig, docker *docker.Docker, ldap *ldap.LDAP) *A
HttpOnly: true, HttpOnly: true,
Domain: fmt.Sprintf(".%s", config.Domain), Domain: fmt.Sprintf(".%s", config.Domain),
} }
return &Auth{ return &Auth{
Config: config, Config: config,
Docker: docker, Docker: docker,
@@ -49,20 +46,14 @@ func NewAuth(config types.AuthConfig, docker *docker.Docker, ldap *ldap.LDAP) *A
} }
func (auth *Auth) GetSession(c *gin.Context) (*sessions.Session, error) { func (auth *Auth) GetSession(c *gin.Context) (*sessions.Session, error) {
// Get session
session, err := auth.Store.Get(c.Request, auth.Config.SessionCookieName) session, err := auth.Store.Get(c.Request, auth.Config.SessionCookieName)
// If there was an error getting the session, it might be invalid so let's clear it and retry
if err != nil { if err != nil {
log.Warn().Err(err).Msg("Invalid session, clearing cookie and retrying") log.Warn().Err(err).Msg("Invalid session, clearing cookie and retrying")
// Delete the session cookie if there is an error
c.SetCookie(auth.Config.SessionCookieName, "", -1, "/", fmt.Sprintf(".%s", auth.Config.Domain), auth.Config.CookieSecure, true) c.SetCookie(auth.Config.SessionCookieName, "", -1, "/", fmt.Sprintf(".%s", auth.Config.Domain), auth.Config.CookieSecure, true)
// Try to get the session again
session, err = auth.Store.Get(c.Request, auth.Config.SessionCookieName) session, err = auth.Store.Get(c.Request, auth.Config.SessionCookieName)
if err != nil { if err != nil {
// If we still can't get the session, log the error and return nil
log.Error().Err(err).Msg("Failed to get session") log.Error().Err(err).Msg("Failed to get session")
return nil, err return nil, err
} }
@@ -72,13 +63,11 @@ func (auth *Auth) GetSession(c *gin.Context) (*sessions.Session, error) {
} }
func (auth *Auth) SearchUser(username string) types.UserSearch { func (auth *Auth) SearchUser(username string) types.UserSearch {
// Loop through users and return the user if the username matches
log.Debug().Str("username", username).Msg("Searching for user") log.Debug().Str("username", username).Msg("Searching for user")
// Check local users first
if auth.GetLocalUser(username).Username != "" { if auth.GetLocalUser(username).Username != "" {
log.Debug().Str("username", username).Msg("Found local user") log.Debug().Str("username", username).Msg("Found local user")
// If user found, return a user with the username and type "local"
return types.UserSearch{ return types.UserSearch{
Username: username, Username: username,
Type: "local", Type: "local",
@@ -88,14 +77,11 @@ func (auth *Auth) SearchUser(username string) types.UserSearch {
// If no user found, check LDAP // If no user found, check LDAP
if auth.LDAP != nil { if auth.LDAP != nil {
log.Debug().Str("username", username).Msg("Checking LDAP for user") log.Debug().Str("username", username).Msg("Checking LDAP for user")
userDN, err := auth.LDAP.Search(username) userDN, err := auth.LDAP.Search(username)
if err != nil { if err != nil {
log.Warn().Err(err).Str("username", username).Msg("Failed to find user in LDAP") log.Warn().Err(err).Str("username", username).Msg("Failed to find user in LDAP")
return types.UserSearch{} return types.UserSearch{}
} }
// If user found in LDAP, return a user with the DN as username
return types.UserSearch{ return types.UserSearch{
Username: userDN, Username: userDN,
Type: "ldap", Type: "ldap",
@@ -109,34 +95,28 @@ func (auth *Auth) VerifyUser(search types.UserSearch, password string) bool {
// Authenticate the user based on the type // Authenticate the user based on the type
switch search.Type { switch search.Type {
case "local": case "local":
// Get local user // If local user, get the user and check the password
user := auth.GetLocalUser(search.Username) user := auth.GetLocalUser(search.Username)
// Check if password is correct
return auth.CheckPassword(user, password) return auth.CheckPassword(user, password)
case "ldap": case "ldap":
// If LDAP is configured, bind to the LDAP server with the user DN and password // If LDAP is configured, bind to the LDAP server with the user DN and password
if auth.LDAP != nil { if auth.LDAP != nil {
log.Debug().Str("username", search.Username).Msg("Binding to LDAP for user authentication") log.Debug().Str("username", search.Username).Msg("Binding to LDAP for user authentication")
// Bind to the LDAP server
err := auth.LDAP.Bind(search.Username, password) err := auth.LDAP.Bind(search.Username, password)
if err != nil { if err != nil {
log.Warn().Err(err).Str("username", search.Username).Msg("Failed to bind to LDAP") log.Warn().Err(err).Str("username", search.Username).Msg("Failed to bind to LDAP")
return false return false
} }
// If bind is successful, rebind with the LDAP bind user // Rebind with the service account to reset the connection
err = auth.LDAP.Bind(auth.LDAP.Config.BindDN, auth.LDAP.Config.BindPassword) err = auth.LDAP.Bind(auth.LDAP.Config.BindDN, auth.LDAP.Config.BindPassword)
if err != nil { if err != nil {
log.Error().Err(err).Msg("Failed to rebind with service account after user authentication") log.Error().Err(err).Msg("Failed to rebind with service account after user authentication")
// Consider closing the connection or creating a new one
return false return false
} }
log.Debug().Str("username", search.Username).Msg("LDAP authentication successful") log.Debug().Str("username", search.Username).Msg("LDAP authentication successful")
// Return true if the bind was successful
return true return true
} }
default: default:
@@ -165,11 +145,9 @@ func (auth *Auth) GetLocalUser(username string) types.User {
} }
func (auth *Auth) CheckPassword(user types.User, password string) bool { func (auth *Auth) CheckPassword(user types.User, password string) bool {
// Compare the hashed password with the password provided
return bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)) == nil return bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)) == nil
} }
// IsAccountLocked checks if a username or IP is locked due to too many failed login attempts
func (auth *Auth) IsAccountLocked(identifier string) (bool, int) { func (auth *Auth) IsAccountLocked(identifier string) (bool, int) {
auth.LoginMutex.RLock() auth.LoginMutex.RLock()
defer auth.LoginMutex.RUnlock() defer auth.LoginMutex.RUnlock()
@@ -196,7 +174,6 @@ func (auth *Auth) IsAccountLocked(identifier string) (bool, int) {
return false, 0 return false, 0
} }
// RecordLoginAttempt records a login attempt for rate limiting
func (auth *Auth) RecordLoginAttempt(identifier string, success bool) { func (auth *Auth) RecordLoginAttempt(identifier string, success bool) {
// Skip if rate limiting is not configured // Skip if rate limiting is not configured
if auth.Config.LoginMaxRetries <= 0 || auth.Config.LoginTimeout <= 0 { if auth.Config.LoginMaxRetries <= 0 || auth.Config.LoginTimeout <= 0 {
@@ -240,7 +217,6 @@ func (auth *Auth) EmailWhitelisted(email string) bool {
func (auth *Auth) CreateSessionCookie(c *gin.Context, data *types.SessionCookie) error { func (auth *Auth) CreateSessionCookie(c *gin.Context, data *types.SessionCookie) error {
log.Debug().Msg("Creating session cookie") log.Debug().Msg("Creating session cookie")
// Get session
session, err := auth.GetSession(c) session, err := auth.GetSession(c)
if err != nil { if err != nil {
log.Error().Err(err).Msg("Failed to get session") log.Error().Err(err).Msg("Failed to get session")
@@ -249,7 +225,6 @@ func (auth *Auth) CreateSessionCookie(c *gin.Context, data *types.SessionCookie)
log.Debug().Msg("Setting session cookie") log.Debug().Msg("Setting session cookie")
// Calculate expiry
var sessionExpiry int var sessionExpiry int
if data.TotpPending { if data.TotpPending {
@@ -258,7 +233,6 @@ func (auth *Auth) CreateSessionCookie(c *gin.Context, data *types.SessionCookie)
sessionExpiry = auth.Config.SessionExpiry sessionExpiry = auth.Config.SessionExpiry
} }
// Set data
session.Values["username"] = data.Username session.Values["username"] = data.Username
session.Values["name"] = data.Name session.Values["name"] = data.Name
session.Values["email"] = data.Email session.Values["email"] = data.Email
@@ -267,21 +241,18 @@ func (auth *Auth) CreateSessionCookie(c *gin.Context, data *types.SessionCookie)
session.Values["totpPending"] = data.TotpPending session.Values["totpPending"] = data.TotpPending
session.Values["oauthGroups"] = data.OAuthGroups session.Values["oauthGroups"] = data.OAuthGroups
// Save session
err = session.Save(c.Request, c.Writer) err = session.Save(c.Request, c.Writer)
if err != nil { if err != nil {
log.Error().Err(err).Msg("Failed to save session") log.Error().Err(err).Msg("Failed to save session")
return err return err
} }
// Return nil
return nil return nil
} }
func (auth *Auth) DeleteSessionCookie(c *gin.Context) error { func (auth *Auth) DeleteSessionCookie(c *gin.Context) error {
log.Debug().Msg("Deleting session cookie") log.Debug().Msg("Deleting session cookie")
// Get session
session, err := auth.GetSession(c) session, err := auth.GetSession(c)
if err != nil { if err != nil {
log.Error().Err(err).Msg("Failed to get session") log.Error().Err(err).Msg("Failed to get session")
@@ -293,21 +264,18 @@ func (auth *Auth) DeleteSessionCookie(c *gin.Context) error {
delete(session.Values, key) delete(session.Values, key)
} }
// Save session
err = session.Save(c.Request, c.Writer) err = session.Save(c.Request, c.Writer)
if err != nil { if err != nil {
log.Error().Err(err).Msg("Failed to save session") log.Error().Err(err).Msg("Failed to save session")
return err return err
} }
// Return nil
return nil return nil
} }
func (auth *Auth) GetSessionCookie(c *gin.Context) (types.SessionCookie, error) { func (auth *Auth) GetSessionCookie(c *gin.Context) (types.SessionCookie, error) {
log.Debug().Msg("Getting session cookie") log.Debug().Msg("Getting session cookie")
// Get session
session, err := auth.GetSession(c) session, err := auth.GetSession(c)
if err != nil { if err != nil {
log.Error().Err(err).Msg("Failed to get session") log.Error().Err(err).Msg("Failed to get session")
@@ -316,7 +284,6 @@ func (auth *Auth) GetSessionCookie(c *gin.Context) (types.SessionCookie, error)
log.Debug().Msg("Got session") log.Debug().Msg("Got session")
// Get data from session
username, usernameOk := session.Values["username"].(string) username, usernameOk := session.Values["username"].(string)
email, emailOk := session.Values["email"].(string) email, emailOk := session.Values["email"].(string)
name, nameOk := session.Values["name"].(string) name, nameOk := session.Values["name"].(string)
@@ -325,30 +292,21 @@ func (auth *Auth) GetSessionCookie(c *gin.Context) (types.SessionCookie, error)
totpPending, totpPendingOk := session.Values["totpPending"].(bool) totpPending, totpPendingOk := session.Values["totpPending"].(bool)
oauthGroups, oauthGroupsOk := session.Values["oauthGroups"].(string) oauthGroups, oauthGroupsOk := session.Values["oauthGroups"].(string)
// If any data is missing, delete the session cookie
if !usernameOk || !providerOK || !expiryOk || !totpPendingOk || !emailOk || !nameOk || !oauthGroupsOk { if !usernameOk || !providerOK || !expiryOk || !totpPendingOk || !emailOk || !nameOk || !oauthGroupsOk {
log.Warn().Msg("Session cookie is invalid") log.Warn().Msg("Session cookie is invalid")
// If any data is missing, delete the session cookie
auth.DeleteSessionCookie(c) auth.DeleteSessionCookie(c)
// Return empty cookie
return types.SessionCookie{}, nil return types.SessionCookie{}, nil
} }
// Check if the cookie has expired // If the session cookie has expired, delete it
if time.Now().Unix() > expiry { if time.Now().Unix() > expiry {
log.Warn().Msg("Session cookie expired") log.Warn().Msg("Session cookie expired")
// If it has, delete it
auth.DeleteSessionCookie(c) auth.DeleteSessionCookie(c)
// Return empty cookie
return types.SessionCookie{}, nil return types.SessionCookie{}, nil
} }
log.Debug().Str("username", username).Str("provider", provider).Int64("expiry", expiry).Bool("totpPending", totpPending).Str("name", name).Str("email", email).Str("oauthGroups", oauthGroups).Msg("Parsed cookie") log.Debug().Str("username", username).Str("provider", provider).Int64("expiry", expiry).Bool("totpPending", totpPending).Str("name", name).Str("email", email).Str("oauthGroups", oauthGroups).Msg("Parsed cookie")
// Return the cookie
return types.SessionCookie{ return types.SessionCookie{
Username: username, Username: username,
Name: name, Name: name,
@@ -360,25 +318,21 @@ func (auth *Auth) GetSessionCookie(c *gin.Context) (types.SessionCookie, error)
} }
func (auth *Auth) UserAuthConfigured() bool { func (auth *Auth) UserAuthConfigured() bool {
// If there are users, return true // If there are users or LDAP is configured, return true
return len(auth.Config.Users) > 0 || auth.LDAP != nil return len(auth.Config.Users) > 0 || auth.LDAP != nil
} }
func (auth *Auth) ResourceAllowed(c *gin.Context, context types.UserContext, labels types.Labels) bool { func (auth *Auth) ResourceAllowed(c *gin.Context, context types.UserContext, labels types.Labels) bool {
// Check if oauth is allowed
if context.OAuth { if context.OAuth {
log.Debug().Msg("Checking OAuth whitelist") log.Debug().Msg("Checking OAuth whitelist")
return utils.CheckFilter(labels.OAuth.Whitelist, context.Email) return utils.CheckFilter(labels.OAuth.Whitelist, context.Email)
} }
// Check users
log.Debug().Msg("Checking users") log.Debug().Msg("Checking users")
return utils.CheckFilter(labels.Users, context.Username) return utils.CheckFilter(labels.Users, context.Username)
} }
func (auth *Auth) OAuthGroup(c *gin.Context, context types.UserContext, labels types.Labels) bool { func (auth *Auth) OAuthGroup(c *gin.Context, context types.UserContext, labels types.Labels) bool {
// Check if groups are required
if labels.OAuth.Groups == "" { if labels.OAuth.Groups == "" {
return true return true
} }
@@ -402,18 +356,12 @@ func (auth *Auth) OAuthGroup(c *gin.Context, context types.UserContext, labels t
// No groups matched // No groups matched
log.Debug().Msg("No groups matched") log.Debug().Msg("No groups matched")
// Return false
return false return false
} }
func (auth *Auth) AuthEnabled(c *gin.Context, labels types.Labels) (bool, error) { func (auth *Auth) AuthEnabled(uri string, labels types.Labels) (bool, error) {
// Get headers // If the label is empty, auth is enabled
uri := c.Request.Header.Get("X-Forwarded-Uri")
// Check if the allowed label is empty
if labels.Allowed == "" { if labels.Allowed == "" {
// Auth enabled
return true, nil return true, nil
} }
@@ -426,9 +374,8 @@ func (auth *Auth) AuthEnabled(c *gin.Context, labels types.Labels) (bool, error)
return true, err return true, err
} }
// Check if the uri matches the regex // If the regex matches the URI, auth is not enabled
if regex.MatchString(uri) { if regex.MatchString(uri) {
// Auth disabled
return false, nil return false, nil
} }
@@ -437,15 +384,10 @@ func (auth *Auth) AuthEnabled(c *gin.Context, labels types.Labels) (bool, error)
} }
func (auth *Auth) GetBasicAuth(c *gin.Context) *types.User { func (auth *Auth) GetBasicAuth(c *gin.Context) *types.User {
// Get the Authorization header
username, password, ok := c.Request.BasicAuth() username, password, ok := c.Request.BasicAuth()
// If not ok, return an empty user
if !ok { if !ok {
return nil return nil
} }
// Return the user
return &types.User{ return &types.User{
Username: username, Username: username,
Password: password, Password: password,
@@ -486,7 +428,6 @@ func (auth *Auth) CheckIP(labels types.Labels, ip string) bool {
} }
log.Debug().Str("ip", ip).Msg("IP not in allow or block list, allowing by default") log.Debug().Str("ip", ip).Msg("IP not in allow or block list, allowing by default")
return true return true
} }
@@ -505,6 +446,5 @@ func (auth *Auth) BypassedIP(labels types.Labels, ip string) bool {
} }
log.Debug().Str("ip", ip).Msg("IP not in bypass list, continuing with authentication") log.Debug().Str("ip", ip).Msg("IP not in bypass list, continuing with authentication")
return false return false
} }

View File

@@ -4,7 +4,6 @@ import (
"testing" "testing"
"time" "time"
"tinyauth/internal/auth" "tinyauth/internal/auth"
"tinyauth/internal/docker"
"tinyauth/internal/types" "tinyauth/internal/types"
) )
@@ -18,7 +17,7 @@ func TestLoginRateLimiting(t *testing.T) {
// Initialize a new auth service with 3 max retries and 5 seconds timeout // Initialize a new auth service with 3 max retries and 5 seconds timeout
config.LoginMaxRetries = 3 config.LoginMaxRetries = 3
config.LoginTimeout = 5 config.LoginTimeout = 5
authService := auth.NewAuth(config, &docker.Docker{}, nil) authService := auth.NewAuth(config, nil, nil)
// Test identifier // Test identifier
identifier := "test_user" identifier := "test_user"
@@ -62,7 +61,7 @@ func TestLoginRateLimiting(t *testing.T) {
// Reinitialize auth service with a shorter timeout for testing // Reinitialize auth service with a shorter timeout for testing
config.LoginTimeout = 1 config.LoginTimeout = 1
config.LoginMaxRetries = 3 config.LoginMaxRetries = 3
authService = auth.NewAuth(config, &docker.Docker{}, nil) authService = auth.NewAuth(config, nil, nil)
// Add enough failed attempts to lock the account // Add enough failed attempts to lock the account
for i := 0; i < 3; i++ { for i := 0; i < 3; i++ {
@@ -87,7 +86,7 @@ func TestLoginRateLimiting(t *testing.T) {
t.Log("Testing disabled rate limiting") t.Log("Testing disabled rate limiting")
config.LoginMaxRetries = 0 config.LoginMaxRetries = 0
config.LoginTimeout = 0 config.LoginTimeout = 0
authService = auth.NewAuth(config, &docker.Docker{}, nil) authService = auth.NewAuth(config, nil, nil)
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
authService.RecordLoginAttempt(identifier, false) authService.RecordLoginAttempt(identifier, false)
@@ -103,7 +102,7 @@ func TestConcurrentLoginAttempts(t *testing.T) {
// Initialize a new auth service with 2 max retries and 5 seconds timeout // Initialize a new auth service with 2 max retries and 5 seconds timeout
config.LoginMaxRetries = 2 config.LoginMaxRetries = 2
config.LoginTimeout = 5 config.LoginTimeout = 5
authService := auth.NewAuth(config, &docker.Docker{}, nil) authService := auth.NewAuth(config, nil, nil)
// Test multiple identifiers // Test multiple identifiers
identifiers := []string{"user1", "user2", "user3"} identifiers := []string{"user1", "user2", "user3"}

View File

@@ -1,6 +1,6 @@
package constants package constants
// Claims are the OIDC supported claims (including preferd username for some reason) // Claims are the OIDC supported claims (prefered username is included for convinience)
type Claims struct { type Claims struct {
Name string `json:"name"` Name string `json:"name"`
Email string `json:"email"` Email string `json:"email"`
@@ -13,7 +13,7 @@ var Version = "development"
var CommitHash = "n/a" var CommitHash = "n/a"
var BuildTimestamp = "n/a" var BuildTimestamp = "n/a"
// Cookie names // Base cookie names
var SessionCookieName = "tinyauth-session" var SessionCookieName = "tinyauth-session"
var CsrfCookieName = "tinyauth-csrf" var CsrfCookieName = "tinyauth-csrf"
var RedirectCookieName = "tinyauth-redirect" var RedirectCookieName = "tinyauth-redirect"

View File

@@ -17,18 +17,12 @@ type Docker struct {
} }
func NewDocker() (*Docker, error) { func NewDocker() (*Docker, error) {
// Create a new docker client
client, err := client.NewClientWithOpts(client.FromEnv) client, err := client.NewClientWithOpts(client.FromEnv)
// Check if there was an error
if err != nil { if err != nil {
return nil, err return nil, err
} }
// Create the context
ctx := context.Background() ctx := context.Background()
// Negotiate API version
client.NegotiateAPIVersion(ctx) client.NegotiateAPIVersion(ctx)
return &Docker{ return &Docker{
@@ -38,75 +32,52 @@ func NewDocker() (*Docker, error) {
} }
func (docker *Docker) GetContainers() ([]container.Summary, error) { func (docker *Docker) GetContainers() ([]container.Summary, error) {
// Get the list of containers
containers, err := docker.Client.ContainerList(docker.Context, container.ListOptions{}) containers, err := docker.Client.ContainerList(docker.Context, container.ListOptions{})
// Check if there was an error
if err != nil { if err != nil {
return nil, err return nil, err
} }
// Return the containers
return containers, nil return containers, nil
} }
func (docker *Docker) InspectContainer(containerId string) (container.InspectResponse, error) { func (docker *Docker) InspectContainer(containerId string) (container.InspectResponse, error) {
// Inspect the container
inspect, err := docker.Client.ContainerInspect(docker.Context, containerId) inspect, err := docker.Client.ContainerInspect(docker.Context, containerId)
// Check if there was an error
if err != nil { if err != nil {
return container.InspectResponse{}, err return container.InspectResponse{}, err
} }
// Return the inspect
return inspect, nil return inspect, nil
} }
func (docker *Docker) DockerConnected() bool { func (docker *Docker) DockerConnected() bool {
// Ping the docker client if there is an error it is not connected
_, err := docker.Client.Ping(docker.Context) _, err := docker.Client.Ping(docker.Context)
return err == nil return err == nil
} }
func (docker *Docker) GetLabels(app string, domain string) (types.Labels, error) { func (docker *Docker) GetLabels(app string, domain string) (types.Labels, error) {
// Check if we have access to the Docker API
isConnected := docker.DockerConnected() isConnected := docker.DockerConnected()
// If we don't have access, return an empty struct
if !isConnected { if !isConnected {
log.Debug().Msg("Docker not connected, returning empty labels") log.Debug().Msg("Docker not connected, returning empty labels")
return types.Labels{}, nil return types.Labels{}, nil
} }
// Get the containers
log.Debug().Msg("Getting containers") log.Debug().Msg("Getting containers")
containers, err := docker.GetContainers() containers, err := docker.GetContainers()
// If there is an error, return false
if err != nil { if err != nil {
log.Error().Err(err).Msg("Error getting containers") log.Error().Err(err).Msg("Error getting containers")
return types.Labels{}, err return types.Labels{}, err
} }
// Loop through the containers
for _, container := range containers { for _, container := range containers {
// Inspect the container
inspect, err := docker.InspectContainer(container.ID) inspect, err := docker.InspectContainer(container.ID)
// Check if there was an error
if err != nil { if err != nil {
log.Warn().Str("id", container.ID).Err(err).Msg("Error inspecting container, skipping") log.Warn().Str("id", container.ID).Err(err).Msg("Error inspecting container, skipping")
continue continue
} }
// Get the labels
log.Debug().Str("id", inspect.ID).Msg("Getting labels for container") log.Debug().Str("id", inspect.ID).Msg("Getting labels for container")
labels, err := utils.GetLabels(inspect.Config.Labels) labels, err := utils.GetLabels(inspect.Config.Labels)
// Check if there was an error
if err != nil { if err != nil {
log.Warn().Str("id", container.ID).Err(err).Msg("Error getting container labels, skipping") log.Warn().Str("id", container.ID).Err(err).Msg("Error getting container labels, skipping")
continue continue
@@ -127,7 +98,5 @@ func (docker *Docker) GetLabels(app string, domain string) (types.Labels, error)
} }
log.Debug().Msg("No matching container found, returning empty labels") log.Debug().Msg("No matching container found, returning empty labels")
// If no matching container is found, return empty labels
return types.Labels{}, nil return types.Labels{}, nil
} }

View File

@@ -37,13 +37,9 @@ func NewHandlers(config types.HandlersConfig, auth *auth.Auth, hooks *hooks.Hook
} }
func (h *Handlers) AuthHandler(c *gin.Context) { func (h *Handlers) AuthHandler(c *gin.Context) {
// Create struct for proxy
var proxy types.Proxy var proxy types.Proxy
// Bind URI
err := c.BindUri(&proxy) err := c.BindUri(&proxy)
// Handle error
if err != nil { if err != nil {
log.Error().Err(err).Msg("Failed to bind URI") log.Error().Err(err).Msg("Failed to bind URI")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
@@ -64,7 +60,6 @@ func (h *Handlers) AuthHandler(c *gin.Context) {
log.Debug().Interface("proxy", proxy.Proxy).Msg("Got proxy") log.Debug().Interface("proxy", proxy.Proxy).Msg("Got proxy")
// Get headers
uri := c.Request.Header.Get("X-Forwarded-Uri") uri := c.Request.Header.Get("X-Forwarded-Uri")
proto := c.Request.Header.Get("X-Forwarded-Proto") proto := c.Request.Header.Get("X-Forwarded-Proto")
host := c.Request.Header.Get("X-Forwarded-Host") host := c.Request.Header.Get("X-Forwarded-Host")
@@ -75,12 +70,7 @@ func (h *Handlers) AuthHandler(c *gin.Context) {
// Get the id // Get the id
id := strings.Split(hostPortless, ".")[0] id := strings.Split(hostPortless, ".")[0]
// Get the container labels
labels, err := h.Docker.GetLabels(id, hostPortless) labels, err := h.Docker.GetLabels(id, hostPortless)
log.Debug().Interface("labels", labels).Msg("Got labels")
// Check if there was an error
if err != nil { if err != nil {
log.Error().Err(err).Msg("Failed to get container labels") log.Error().Err(err).Msg("Failed to get container labels")
@@ -96,20 +86,24 @@ func (h *Handlers) AuthHandler(c *gin.Context) {
return return
} }
// Get client IP log.Debug().Interface("labels", labels).Msg("Got labels")
ip := c.ClientIP() ip := c.ClientIP()
// Check if the IP is in bypass list // Check if the IP is in bypass list
if h.Auth.BypassedIP(labels, ip) { if h.Auth.BypassedIP(labels, ip) {
headersParsed := utils.ParseHeaders(labels.Headers) headersParsed := utils.ParseHeaders(labels.Headers)
for key, value := range headersParsed { for key, value := range headersParsed {
log.Debug().Str("key", key).Msg("Setting header") log.Debug().Str("key", key).Msg("Setting header")
c.Header(key, value) c.Header(key, value)
} }
if labels.Basic.Username != "" && utils.GetSecret(labels.Basic.Password.Plain, labels.Basic.Password.File) != "" { if labels.Basic.Username != "" && utils.GetSecret(labels.Basic.Password.Plain, labels.Basic.Password.File) != "" {
log.Debug().Str("username", labels.Basic.Username).Msg("Setting basic auth headers") log.Debug().Str("username", labels.Basic.Username).Msg("Setting basic auth headers")
c.Header("Authorization", fmt.Sprintf("Basic %s", utils.GetBasicAuth(labels.Basic.Username, utils.GetSecret(labels.Basic.Password.Plain, labels.Basic.Password.File)))) c.Header("Authorization", fmt.Sprintf("Basic %s", utils.GetBasicAuth(labels.Basic.Username, utils.GetSecret(labels.Basic.Password.Plain, labels.Basic.Password.File))))
} }
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"status": 200, "status": 200,
"message": "Authenticated", "message": "Authenticated",
@@ -132,10 +126,7 @@ func (h *Handlers) AuthHandler(c *gin.Context) {
IP: ip, IP: ip,
} }
// Build query
queries, err := query.Values(values) queries, err := query.Values(values)
// Handle error
if err != nil { if err != nil {
log.Error().Err(err).Msg("Failed to build queries") log.Error().Err(err).Msg("Failed to build queries")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", h.Config.AppURL))
@@ -147,12 +138,9 @@ func (h *Handlers) AuthHandler(c *gin.Context) {
} }
// Check if auth is enabled // Check if auth is enabled
authEnabled, err := h.Auth.AuthEnabled(c, labels) authEnabled, err := h.Auth.AuthEnabled(uri, labels)
// Check if there was an error
if err != nil { if err != nil {
log.Error().Err(err).Msg("Failed to check if app is allowed") log.Error().Err(err).Msg("Failed to check if app is allowed")
if proxy.Proxy == "nginx" || !isBrowser { if proxy.Proxy == "nginx" || !isBrowser {
c.JSON(500, gin.H{ c.JSON(500, gin.H{
"status": 500, "status": 500,
@@ -172,14 +160,17 @@ func (h *Handlers) AuthHandler(c *gin.Context) {
log.Debug().Str("key", key).Msg("Setting header") log.Debug().Str("key", key).Msg("Setting header")
c.Header(key, value) c.Header(key, value)
} }
if labels.Basic.Username != "" && utils.GetSecret(labels.Basic.Password.Plain, labels.Basic.Password.File) != "" { if labels.Basic.Username != "" && utils.GetSecret(labels.Basic.Password.Plain, labels.Basic.Password.File) != "" {
log.Debug().Str("username", labels.Basic.Username).Msg("Setting basic auth headers") log.Debug().Str("username", labels.Basic.Username).Msg("Setting basic auth headers")
c.Header("Authorization", fmt.Sprintf("Basic %s", utils.GetBasicAuth(labels.Basic.Username, utils.GetSecret(labels.Basic.Password.Plain, labels.Basic.Password.File)))) c.Header("Authorization", fmt.Sprintf("Basic %s", utils.GetBasicAuth(labels.Basic.Username, utils.GetSecret(labels.Basic.Password.Plain, labels.Basic.Password.File))))
} }
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"status": 200, "status": 200,
"message": "Authenticated", "message": "Authenticated",
}) })
return return
} }
@@ -201,7 +192,6 @@ func (h *Handlers) AuthHandler(c *gin.Context) {
log.Debug().Bool("appAllowed", appAllowed).Msg("Checking if app is allowed") log.Debug().Bool("appAllowed", appAllowed).Msg("Checking if app is allowed")
// The user is not allowed to access the app
if !appAllowed { if !appAllowed {
log.Warn().Str("username", userContext.Username).Str("host", host).Msg("User not allowed") log.Warn().Str("username", userContext.Username).Str("host", host).Msg("User not allowed")
@@ -213,44 +203,35 @@ func (h *Handlers) AuthHandler(c *gin.Context) {
return return
} }
// Values
values := types.UnauthorizedQuery{ values := types.UnauthorizedQuery{
Resource: strings.Split(host, ".")[0], Resource: strings.Split(host, ".")[0],
} }
// Use either username or email
if userContext.OAuth { if userContext.OAuth {
values.Username = userContext.Email values.Username = userContext.Email
} else { } else {
values.Username = userContext.Username values.Username = userContext.Username
} }
// Build query
queries, err := query.Values(values) queries, err := query.Values(values)
// Handle error (no need to check for nginx/headers since we are sure we are using caddy/traefik)
if err != nil { if err != nil {
log.Error().Err(err).Msg("Failed to build queries") log.Error().Err(err).Msg("Failed to build queries")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", h.Config.AppURL))
return return
} }
// We are using caddy/traefik so redirect
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", h.Config.AppURL, queries.Encode())) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", h.Config.AppURL, queries.Encode()))
return return
} }
// Check groups if using OAuth // Check groups if using OAuth
if userContext.OAuth { if userContext.OAuth {
// Check if user is in required groups
groupOk := h.Auth.OAuthGroup(c, userContext, labels) groupOk := h.Auth.OAuthGroup(c, userContext, labels)
log.Debug().Bool("groupOk", groupOk).Msg("Checking if user is in required groups") log.Debug().Bool("groupOk", groupOk).Msg("Checking if user is in required groups")
// The user is not allowed to access the app
if !groupOk { if !groupOk {
log.Warn().Str("username", userContext.Username).Str("host", host).Msg("User is not in required groups") log.Warn().Str("username", userContext.Username).Str("host", host).Msg("User is not in required groups")
if proxy.Proxy == "nginx" || !isBrowser { if proxy.Proxy == "nginx" || !isBrowser {
c.JSON(401, gin.H{ c.JSON(401, gin.H{
"status": 401, "status": 401,
@@ -259,30 +240,24 @@ func (h *Handlers) AuthHandler(c *gin.Context) {
return return
} }
// Values
values := types.UnauthorizedQuery{ values := types.UnauthorizedQuery{
Resource: strings.Split(host, ".")[0], Resource: strings.Split(host, ".")[0],
GroupErr: true, GroupErr: true,
} }
// Use either username or email
if userContext.OAuth { if userContext.OAuth {
values.Username = userContext.Email values.Username = userContext.Email
} else { } else {
values.Username = userContext.Username values.Username = userContext.Username
} }
// Build query
queries, err := query.Values(values) queries, err := query.Values(values)
// Handle error (no need to check for nginx/headers since we are sure we are using caddy/traefik)
if err != nil { if err != nil {
log.Error().Err(err).Msg("Failed to build queries") log.Error().Err(err).Msg("Failed to build queries")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", h.Config.AppURL))
return return
} }
// We are using caddy/traefik so redirect
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", h.Config.AppURL, queries.Encode())) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", h.Config.AppURL, queries.Encode()))
return return
} }
@@ -306,7 +281,6 @@ func (h *Handlers) AuthHandler(c *gin.Context) {
c.Header("Authorization", fmt.Sprintf("Basic %s", utils.GetBasicAuth(labels.Basic.Username, utils.GetSecret(labels.Basic.Password.Plain, labels.Basic.Password.File)))) c.Header("Authorization", fmt.Sprintf("Basic %s", utils.GetBasicAuth(labels.Basic.Username, utils.GetSecret(labels.Basic.Password.Plain, labels.Basic.Password.File))))
} }
// The user is allowed to access the app
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"status": 200, "status": 200,
"message": "Authenticated", "message": "Authenticated",
@@ -336,19 +310,13 @@ func (h *Handlers) AuthHandler(c *gin.Context) {
} }
log.Debug().Interface("redirect_uri", fmt.Sprintf("%s://%s%s", proto, host, uri)).Msg("Redirecting to login") log.Debug().Interface("redirect_uri", fmt.Sprintf("%s://%s%s", proto, host, uri)).Msg("Redirecting to login")
// Redirect to login
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/login?%s", h.Config.AppURL, queries.Encode())) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/login?%s", h.Config.AppURL, queries.Encode()))
} }
func (h *Handlers) LoginHandler(c *gin.Context) { func (h *Handlers) LoginHandler(c *gin.Context) {
// Create login struct
var login types.LoginRequest var login types.LoginRequest
// Bind JSON
err := c.BindJSON(&login) err := c.BindJSON(&login)
// Handle error
if err != nil { if err != nil {
log.Error().Err(err).Msg("Failed to bind JSON") log.Error().Err(err).Msg("Failed to bind JSON")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
@@ -360,7 +328,6 @@ func (h *Handlers) LoginHandler(c *gin.Context) {
log.Debug().Msg("Got login request") log.Debug().Msg("Got login request")
// Get client IP for rate limiting
clientIP := c.ClientIP() clientIP := c.ClientIP()
// Create an identifier for rate limiting (username or IP if username doesn't exist yet) // Create an identifier for rate limiting (username or IP if username doesn't exist yet)
@@ -381,9 +348,9 @@ func (h *Handlers) LoginHandler(c *gin.Context) {
} }
// Search for a user based on username // Search for a user based on username
userSearch := h.Auth.SearchUser(login.Username) log.Debug().Interface("username", login.Username).Msg("Searching for user")
log.Debug().Interface("userSearch", userSearch).Msg("Searching for user") userSearch := h.Auth.SearchUser(login.Username)
// User does not exist // User does not exist
if userSearch.Type == "" { if userSearch.Type == "" {
@@ -440,8 +407,6 @@ func (h *Handlers) LoginHandler(c *gin.Context) {
"message": "Waiting for totp", "message": "Waiting for totp",
"totpPending": true, "totpPending": true,
}) })
// Stop further processing
return return
} }
} }
@@ -463,13 +428,9 @@ func (h *Handlers) LoginHandler(c *gin.Context) {
} }
func (h *Handlers) TotpHandler(c *gin.Context) { func (h *Handlers) TotpHandler(c *gin.Context) {
// Create totp struct
var totpReq types.TotpRequest var totpReq types.TotpRequest
// Bind JSON
err := c.BindJSON(&totpReq) err := c.BindJSON(&totpReq)
// Handle error
if err != nil { if err != nil {
log.Error().Err(err).Msg("Failed to bind JSON") log.Error().Err(err).Msg("Failed to bind JSON")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
@@ -500,7 +461,6 @@ func (h *Handlers) TotpHandler(c *gin.Context) {
// Check if totp is correct // Check if totp is correct
ok := totp.Validate(totpReq.Code, user.TotpSecret) ok := totp.Validate(totpReq.Code, user.TotpSecret)
// TOTP is incorrect
if !ok { if !ok {
log.Debug().Msg("Totp incorrect") log.Debug().Msg("Totp incorrect")
c.JSON(401, gin.H{ c.JSON(401, gin.H{
@@ -528,14 +488,10 @@ func (h *Handlers) TotpHandler(c *gin.Context) {
} }
func (h *Handlers) LogoutHandler(c *gin.Context) { func (h *Handlers) LogoutHandler(c *gin.Context) {
log.Debug().Msg("Logging out")
// Delete session cookie
h.Auth.DeleteSessionCookie(c)
log.Debug().Msg("Cleaning up redirect cookie") log.Debug().Msg("Cleaning up redirect cookie")
// Return logged out h.Auth.DeleteSessionCookie(c)
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"status": 200, "status": 200,
"message": "Logged out", "message": "Logged out",
@@ -553,7 +509,7 @@ func (h *Handlers) AppHandler(c *gin.Context) {
configuredProviders = append(configuredProviders, "username") configuredProviders = append(configuredProviders, "username")
} }
// Create app context struct // Return app context
appContext := types.AppContext{ appContext := types.AppContext{
Status: 200, Status: 200,
Message: "OK", Message: "OK",
@@ -566,18 +522,15 @@ func (h *Handlers) AppHandler(c *gin.Context) {
BackgroundImage: h.Config.BackgroundImage, BackgroundImage: h.Config.BackgroundImage,
OAuthAutoRedirect: h.Config.OAuthAutoRedirect, OAuthAutoRedirect: h.Config.OAuthAutoRedirect,
} }
// Return app context
c.JSON(200, appContext) c.JSON(200, appContext)
} }
func (h *Handlers) UserHandler(c *gin.Context) { func (h *Handlers) UserHandler(c *gin.Context) {
log.Debug().Msg("Getting user context") log.Debug().Msg("Getting user context")
// Get user context // Create user context using hooks
userContext := h.Hooks.UseUserContext(c) userContext := h.Hooks.UseUserContext(c)
// Create user context response
userContextResponse := types.UserContextResponse{ userContextResponse := types.UserContextResponse{
Status: 200, Status: 200,
IsLoggedIn: userContext.IsLoggedIn, IsLoggedIn: userContext.IsLoggedIn,
@@ -598,18 +551,13 @@ func (h *Handlers) UserHandler(c *gin.Context) {
userContextResponse.Message = "Authenticated" userContextResponse.Message = "Authenticated"
} }
// Return user context
c.JSON(200, userContextResponse) c.JSON(200, userContextResponse)
} }
func (h *Handlers) OauthUrlHandler(c *gin.Context) { func (h *Handlers) OauthUrlHandler(c *gin.Context) {
// Create struct for OAuth request
var request types.OAuthRequest var request types.OAuthRequest
// Bind URI
err := c.BindUri(&request) err := c.BindUri(&request)
// Handle error
if err != nil { if err != nil {
log.Error().Err(err).Msg("Failed to bind URI") log.Error().Err(err).Msg("Failed to bind URI")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
@@ -624,7 +572,6 @@ func (h *Handlers) OauthUrlHandler(c *gin.Context) {
// Check if provider exists // Check if provider exists
provider := h.Providers.GetProvider(request.Provider) provider := h.Providers.GetProvider(request.Provider)
// Provider does not exist
if provider == nil { if provider == nil {
c.JSON(404, gin.H{ c.JSON(404, gin.H{
"status": 404, "status": 404,
@@ -664,13 +611,9 @@ func (h *Handlers) OauthUrlHandler(c *gin.Context) {
} }
func (h *Handlers) OauthCallbackHandler(c *gin.Context) { func (h *Handlers) OauthCallbackHandler(c *gin.Context) {
// Create struct for OAuth request
var providerName types.OAuthRequest var providerName types.OAuthRequest
// Bind URI
err := c.BindUri(&providerName) err := c.BindUri(&providerName)
// Handle error
if err != nil { if err != nil {
log.Error().Err(err).Msg("Failed to bind URI") log.Error().Err(err).Msg("Failed to bind URI")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", h.Config.AppURL))
@@ -711,30 +654,25 @@ func (h *Handlers) OauthCallbackHandler(c *gin.Context) {
// Get provider // Get provider
provider := h.Providers.GetProvider(providerName.Provider) provider := h.Providers.GetProvider(providerName.Provider)
log.Debug().Str("provider", providerName.Provider).Msg("Got provider")
// Provider does not exist
if provider == nil { if provider == nil {
c.Redirect(http.StatusTemporaryRedirect, "/not-found") c.Redirect(http.StatusTemporaryRedirect, "/not-found")
return return
} }
log.Debug().Str("provider", providerName.Provider).Msg("Got provider")
// Exchange token (authenticates user) // Exchange token (authenticates user)
_, err = provider.ExchangeToken(code) _, err = provider.ExchangeToken(code)
log.Debug().Msg("Got token")
// Handle error
if err != nil { if err != nil {
log.Error().Err(err).Msg("Failed to exchange token") log.Error().Err(err).Msg("Failed to exchange token")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", h.Config.AppURL))
return return
} }
log.Debug().Msg("Got token")
// Get user // Get user
user, err := h.Providers.GetUser(providerName.Provider) user, err := h.Providers.GetUser(providerName.Provider)
// Handle error
if err != nil { if err != nil {
log.Error().Err(err).Msg("Failed to get user") log.Error().Err(err).Msg("Failed to get user")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", h.Config.AppURL))
@@ -753,20 +691,16 @@ func (h *Handlers) OauthCallbackHandler(c *gin.Context) {
// Email is not whitelisted // Email is not whitelisted
if !h.Auth.EmailWhitelisted(user.Email) { if !h.Auth.EmailWhitelisted(user.Email) {
log.Warn().Str("email", user.Email).Msg("Email not whitelisted") log.Warn().Str("email", user.Email).Msg("Email not whitelisted")
// Build query
queries, err := query.Values(types.UnauthorizedQuery{ queries, err := query.Values(types.UnauthorizedQuery{
Username: user.Email, Username: user.Email,
}) })
// Handle error
if err != nil { if err != nil {
log.Error().Err(err).Msg("Failed to build queries") log.Error().Err(err).Msg("Failed to build queries")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", h.Config.AppURL))
return return
} }
// Redirect to unauthorized
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", h.Config.AppURL, queries.Encode())) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", h.Config.AppURL, queries.Encode()))
} }
@@ -790,7 +724,7 @@ func (h *Handlers) OauthCallbackHandler(c *gin.Context) {
name = fmt.Sprintf("%s (%s)", utils.Capitalize(strings.Split(user.Email, "@")[0]), strings.Split(user.Email, "@")[1]) name = fmt.Sprintf("%s (%s)", utils.Capitalize(strings.Split(user.Email, "@")[0]), strings.Split(user.Email, "@")[1])
} }
// Create session cookie (also cleans up redirect cookie) // Create session cookie
h.Auth.CreateSessionCookie(c, &types.SessionCookie{ h.Auth.CreateSessionCookie(c, &types.SessionCookie{
Username: username, Username: username,
Name: name, Name: name,
@@ -810,20 +744,18 @@ func (h *Handlers) OauthCallbackHandler(c *gin.Context) {
log.Debug().Str("redirectURI", redirectCookie).Msg("Got redirect URI") log.Debug().Str("redirectURI", redirectCookie).Msg("Got redirect URI")
// Build query
queries, err := query.Values(types.LoginQuery{ queries, err := query.Values(types.LoginQuery{
RedirectURI: redirectCookie, RedirectURI: redirectCookie,
}) })
log.Debug().Msg("Got redirect query")
// Handle error
if err != nil { if err != nil {
log.Error().Err(err).Msg("Failed to build queries") log.Error().Err(err).Msg("Failed to build queries")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", h.Config.AppURL))
return return
} }
log.Debug().Msg("Got redirect query")
// Clean up redirect cookie // Clean up redirect cookie
c.SetCookie(h.Config.RedirectCookieName, "", -1, "/", "", h.Config.CookieSecure, true) c.SetCookie(h.Config.RedirectCookieName, "", -1, "/", "", h.Config.CookieSecure, true)

View File

@@ -35,21 +35,16 @@ func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext {
if basic != nil { if basic != nil {
log.Debug().Msg("Got basic auth") log.Debug().Msg("Got basic auth")
// Search for a user based on username
userSearch := hooks.Auth.SearchUser(basic.Username) userSearch := hooks.Auth.SearchUser(basic.Username)
if userSearch.Type == "" { if userSearch.Type == "" {
log.Error().Str("username", basic.Username).Msg("User does not exist") log.Error().Str("username", basic.Username).Msg("User does not exist")
// Return empty context
return types.UserContext{} return types.UserContext{}
} }
// Verify the user // Verify the user
if !hooks.Auth.VerifyUser(userSearch, basic.Password) { if !hooks.Auth.VerifyUser(userSearch, basic.Password) {
log.Error().Str("username", basic.Username).Msg("Password incorrect") log.Error().Str("username", basic.Username).Msg("Password incorrect")
// Return empty context
return types.UserContext{} return types.UserContext{}
} }
@@ -83,14 +78,11 @@ func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext {
// Check cookie error after basic auth // Check cookie error after basic auth
if err != nil { if err != nil {
log.Error().Err(err).Msg("Failed to get session cookie") log.Error().Err(err).Msg("Failed to get session cookie")
// Return empty context
return types.UserContext{} return types.UserContext{}
} }
// Check if session cookie has totp pending
if cookie.TotpPending { if cookie.TotpPending {
log.Debug().Msg("Totp pending") log.Debug().Msg("Totp pending")
// Return empty context since we are pending totp
return types.UserContext{ return types.UserContext{
Username: cookie.Username, Username: cookie.Username,
Name: cookie.Name, Name: cookie.Name,
@@ -104,19 +96,15 @@ func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext {
if cookie.Provider == "username" { if cookie.Provider == "username" {
log.Debug().Msg("Provider is username") log.Debug().Msg("Provider is username")
// Search for the user with the username
userSearch := hooks.Auth.SearchUser(cookie.Username) userSearch := hooks.Auth.SearchUser(cookie.Username)
if userSearch.Type == "" { if userSearch.Type == "" {
log.Error().Str("username", cookie.Username).Msg("User does not exist") log.Error().Str("username", cookie.Username).Msg("User does not exist")
// Return empty context
return types.UserContext{} return types.UserContext{}
} }
log.Debug().Str("type", userSearch.Type).Msg("User exists") log.Debug().Str("type", userSearch.Type).Msg("User exists")
// It exists so we are logged in
return types.UserContext{ return types.UserContext{
Username: cookie.Username, Username: cookie.Username,
Name: cookie.Name, Name: cookie.Name,
@@ -135,20 +123,15 @@ func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext {
if provider != nil { if provider != nil {
log.Debug().Msg("Provider exists") log.Debug().Msg("Provider exists")
// Check if the oauth email is whitelisted // If the email is not whitelisted we delete the cookie and return an empty context
if !hooks.Auth.EmailWhitelisted(cookie.Email) { if !hooks.Auth.EmailWhitelisted(cookie.Email) {
log.Error().Str("email", cookie.Email).Msg("Email is not whitelisted") log.Error().Str("email", cookie.Email).Msg("Email is not whitelisted")
// It isn't so we delete the cookie and return an empty context
hooks.Auth.DeleteSessionCookie(c) hooks.Auth.DeleteSessionCookie(c)
// Return empty context
return types.UserContext{} return types.UserContext{}
} }
log.Debug().Msg("Email is whitelisted") log.Debug().Msg("Email is whitelisted")
// Return user context since we are logged in with oauth
return types.UserContext{ return types.UserContext{
Username: cookie.Username, Username: cookie.Username,
Name: cookie.Name, Name: cookie.Name,
@@ -160,6 +143,5 @@ func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext {
} }
} }
// Neither basic auth or oauth is set so we return an empty context
return types.UserContext{} return types.UserContext{}
} }

View File

@@ -16,17 +16,15 @@ type LDAP struct {
} }
func NewLDAP(config types.LdapConfig) (*LDAP, error) { func NewLDAP(config types.LdapConfig) (*LDAP, error) {
// Create a new LDAP instance with the provided configuration
ldap := &LDAP{ ldap := &LDAP{
Config: config, Config: config,
} }
// Connect to the LDAP server _, err := ldap.connect()
if err := ldap.Connect(); err != nil { if err != nil {
return nil, fmt.Errorf("failed to connect to LDAP server: %w", err) return nil, fmt.Errorf("failed to connect to LDAP server: %w", err)
} }
// Start heartbeat goroutine
go func() { go func() {
for range time.Tick(time.Duration(5) * time.Minute) { for range time.Tick(time.Duration(5) * time.Minute) {
err := ldap.heartbeat() err := ldap.heartbeat()
@@ -39,25 +37,23 @@ func NewLDAP(config types.LdapConfig) (*LDAP, error) {
return ldap, nil return ldap, nil
} }
func (l *LDAP) Connect() error { func (l *LDAP) connect() (*ldapgo.Conn, error) {
// Connect to the LDAP server
conn, err := ldapgo.DialURL(l.Config.Address, ldapgo.DialWithTLSConfig(&tls.Config{ conn, err := ldapgo.DialURL(l.Config.Address, ldapgo.DialWithTLSConfig(&tls.Config{
InsecureSkipVerify: l.Config.Insecure, InsecureSkipVerify: l.Config.Insecure,
MinVersion: tls.VersionTLS12, MinVersion: tls.VersionTLS12,
})) }))
if err != nil { if err != nil {
return err return nil, err
} }
// Bind to the LDAP server with the provided credentials
err = conn.Bind(l.Config.BindDN, l.Config.BindPassword) err = conn.Bind(l.Config.BindDN, l.Config.BindPassword)
if err != nil { if err != nil {
return err return nil, err
} }
// Store the connection in the LDAP struct // Set and return the connection
l.Conn = conn l.Conn = conn
return nil return conn, nil
} }
func (l *LDAP) Search(username string) (string, error) { func (l *LDAP) Search(username string) (string, error) {
@@ -65,7 +61,6 @@ func (l *LDAP) Search(username string) (string, error) {
escapedUsername := ldapgo.EscapeFilter(username) escapedUsername := ldapgo.EscapeFilter(username)
filter := fmt.Sprintf(l.Config.SearchFilter, escapedUsername) filter := fmt.Sprintf(l.Config.SearchFilter, escapedUsername)
// Create a search request to find the user by username
searchRequest := ldapgo.NewSearchRequest( searchRequest := ldapgo.NewSearchRequest(
l.Config.BaseDN, l.Config.BaseDN,
ldapgo.ScopeWholeSubtree, ldapgo.NeverDerefAliases, 0, 0, false, ldapgo.ScopeWholeSubtree, ldapgo.NeverDerefAliases, 0, 0, false,
@@ -74,7 +69,6 @@ func (l *LDAP) Search(username string) (string, error) {
nil, nil,
) )
// Perform the search
searchResult, err := l.Conn.Search(searchRequest) searchResult, err := l.Conn.Search(searchRequest)
if err != nil { if err != nil {
return "", err return "", err
@@ -84,14 +78,11 @@ func (l *LDAP) Search(username string) (string, error) {
return "", fmt.Errorf("err multiple or no entries found for user %s", username) return "", fmt.Errorf("err multiple or no entries found for user %s", username)
} }
// User found, return the distinguished name (DN)
userDN := searchResult.Entries[0].DN userDN := searchResult.Entries[0].DN
return userDN, nil return userDN, nil
} }
func (l *LDAP) Bind(userDN string, password string) error { func (l *LDAP) Bind(userDN string, password string) error {
// Bind to the LDAP server with the user's DN and password
err := l.Conn.Bind(userDN, password) err := l.Conn.Bind(userDN, password)
if err != nil { if err != nil {
return err return err
@@ -100,10 +91,8 @@ func (l *LDAP) Bind(userDN string, password string) error {
} }
func (l *LDAP) heartbeat() error { func (l *LDAP) heartbeat() error {
// Perform a simple search to check if the connection is alive
log.Info().Msg("Performing LDAP connection heartbeat") log.Info().Msg("Performing LDAP connection heartbeat")
// Create a search request to find the user by username
searchRequest := ldapgo.NewSearchRequest( searchRequest := ldapgo.NewSearchRequest(
"", "",
ldapgo.ScopeBaseObject, ldapgo.NeverDerefAliases, 0, 0, false, ldapgo.ScopeBaseObject, ldapgo.NeverDerefAliases, 0, 0, false,
@@ -112,11 +101,11 @@ func (l *LDAP) heartbeat() error {
nil, nil,
) )
// Perform the search
_, err := l.Conn.Search(searchRequest) _, err := l.Conn.Search(searchRequest)
if err != nil { if err != nil {
return err return err
} }
// No error means the connection is alive // No error means the connection is alive
return nil return nil
} }

View File

@@ -18,7 +18,6 @@ type OAuth struct {
} }
func NewOAuth(config oauth2.Config, insecureSkipVerify bool) *OAuth { func NewOAuth(config oauth2.Config, insecureSkipVerify bool) *OAuth {
// Create transport with TLS
transport := &http.Transport{ transport := &http.Transport{
TLSClientConfig: &tls.Config{ TLSClientConfig: &tls.Config{
InsecureSkipVerify: insecureSkipVerify, InsecureSkipVerify: insecureSkipVerify,
@@ -26,18 +25,15 @@ func NewOAuth(config oauth2.Config, insecureSkipVerify bool) *OAuth {
}, },
} }
// Create a new context
ctx := context.Background()
// Create the HTTP client with the transport
httpClient := &http.Client{ httpClient := &http.Client{
Transport: transport, Transport: transport,
} }
ctx := context.Background()
// Set the HTTP client in the context // Set the HTTP client in the context
ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient) ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient)
// Create the verifier
verifier := oauth2.GenerateVerifier() verifier := oauth2.GenerateVerifier()
return &OAuth{ return &OAuth{
@@ -48,40 +44,28 @@ func NewOAuth(config oauth2.Config, insecureSkipVerify bool) *OAuth {
} }
func (oauth *OAuth) GetAuthURL(state string) string { func (oauth *OAuth) GetAuthURL(state string) string {
// Return the auth url
return oauth.Config.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(oauth.Verifier)) return oauth.Config.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(oauth.Verifier))
} }
func (oauth *OAuth) ExchangeToken(code string) (string, error) { func (oauth *OAuth) ExchangeToken(code string) (string, error) {
// Exchange the code for a token
token, err := oauth.Config.Exchange(oauth.Context, code, oauth2.VerifierOption(oauth.Verifier)) token, err := oauth.Config.Exchange(oauth.Context, code, oauth2.VerifierOption(oauth.Verifier))
// Check if there was an error
if err != nil { if err != nil {
return "", err return "", err
} }
// Set the token // Set and return the token
oauth.Token = token oauth.Token = token
// Return the access token
return oauth.Token.AccessToken, nil return oauth.Token.AccessToken, nil
} }
func (oauth *OAuth) GetClient() *http.Client { func (oauth *OAuth) GetClient() *http.Client {
// Return the http client with the token set
return oauth.Config.Client(oauth.Context, oauth.Token) return oauth.Config.Client(oauth.Context, oauth.Token)
} }
func (oauth *OAuth) GenerateState() string { func (oauth *OAuth) GenerateState() string {
// Generate a random state string
b := make([]byte, 128) b := make([]byte, 128)
// Fill the byte slice with random data
rand.Read(b) rand.Read(b)
// Encode the byte slice to a base64 string
state := base64.URLEncoding.EncodeToString(b) state := base64.URLEncoding.EncodeToString(b)
return state return state
} }

View File

@@ -10,41 +10,28 @@ import (
) )
func GetGenericUser(client *http.Client, url string) (constants.Claims, error) { func GetGenericUser(client *http.Client, url string) (constants.Claims, error) {
// Create user struct
var user constants.Claims var user constants.Claims
// Using the oauth client get the user info url
res, err := client.Get(url) res, err := client.Get(url)
// Check if there was an error
if err != nil { if err != nil {
return user, err return user, err
} }
defer res.Body.Close() defer res.Body.Close()
log.Debug().Msg("Got response from generic provider") log.Debug().Msg("Got response from generic provider")
// Read the body of the response
body, err := io.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
// Check if there was an error
if err != nil { if err != nil {
return user, err return user, err
} }
log.Debug().Msg("Read body from generic provider") log.Debug().Msg("Read body from generic provider")
// Unmarshal the body into the user struct
err = json.Unmarshal(body, &user) err = json.Unmarshal(body, &user)
// Check if there was an error
if err != nil { if err != nil {
return user, err return user, err
} }
log.Debug().Msg("Parsed user from generic provider") log.Debug().Msg("Parsed user from generic provider")
// Return the user
return user, nil return user, nil
} }

View File

@@ -28,71 +28,48 @@ func GithubScopes() []string {
} }
func GetGithubUser(client *http.Client) (constants.Claims, error) { func GetGithubUser(client *http.Client) (constants.Claims, error) {
// Create user struct
var user constants.Claims var user constants.Claims
// Get the user info from github using the oauth http client
res, err := client.Get("https://api.github.com/user") res, err := client.Get("https://api.github.com/user")
// Check if there was an error
if err != nil { if err != nil {
return user, err return user, err
} }
defer res.Body.Close() defer res.Body.Close()
log.Debug().Msg("Got user response from github") log.Debug().Msg("Got user response from github")
// Read the body of the response
body, err := io.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
// Check if there was an error
if err != nil { if err != nil {
return user, err return user, err
} }
log.Debug().Msg("Read user body from github") log.Debug().Msg("Read user body from github")
// Parse the body into a user struct
var userInfo GithubUserInfoResponse var userInfo GithubUserInfoResponse
// Unmarshal the body into the user struct
err = json.Unmarshal(body, &userInfo) err = json.Unmarshal(body, &userInfo)
// Check if there was an error
if err != nil { if err != nil {
return user, err return user, err
} }
// Get the user emails from github using the oauth http client
res, err = client.Get("https://api.github.com/user/emails") res, err = client.Get("https://api.github.com/user/emails")
// Check if there was an error
if err != nil { if err != nil {
return user, err return user, err
} }
defer res.Body.Close() defer res.Body.Close()
log.Debug().Msg("Got email response from github") log.Debug().Msg("Got email response from github")
// Read the body of the response
body, err = io.ReadAll(res.Body) body, err = io.ReadAll(res.Body)
// Check if there was an error
if err != nil { if err != nil {
return user, err return user, err
} }
log.Debug().Msg("Read email body from github") log.Debug().Msg("Read email body from github")
// Parse the body into a user struct
var emails GithubEmailResponse var emails GithubEmailResponse
// Unmarshal the body into the user struct
err = json.Unmarshal(body, &emails) err = json.Unmarshal(body, &emails)
// Check if there was an error
if err != nil { if err != nil {
return user, err return user, err
} }
@@ -102,28 +79,24 @@ func GetGithubUser(client *http.Client) (constants.Claims, error) {
// Find and return the primary email // Find and return the primary email
for _, email := range emails { for _, email := range emails {
if email.Primary { if email.Primary {
// Set the email then exit
log.Debug().Str("email", email.Email).Msg("Found primary email") log.Debug().Str("email", email.Email).Msg("Found primary email")
user.Email = email.Email user.Email = email.Email
break break
} }
} }
// If no primary email was found, use the first available email
if len(emails) == 0 { if len(emails) == 0 {
return user, errors.New("no emails found") return user, errors.New("no emails found")
} }
// Set the email if it is not set picking the first one // Use first available email if no primary email was found
if user.Email == "" { if user.Email == "" {
log.Warn().Str("email", emails[0].Email).Msg("No primary email found, using first email") log.Warn().Str("email", emails[0].Email).Msg("No primary email found, using first email")
user.Email = emails[0].Email user.Email = emails[0].Email
} }
// Set the username and name
user.PreferredUsername = userInfo.Login user.PreferredUsername = userInfo.Login
user.Name = userInfo.Name user.Name = userInfo.Name
// Return
return user, nil return user, nil
} }

View File

@@ -22,49 +22,35 @@ func GoogleScopes() []string {
} }
func GetGoogleUser(client *http.Client) (constants.Claims, error) { func GetGoogleUser(client *http.Client) (constants.Claims, error) {
// Create user struct
var user constants.Claims var user constants.Claims
// Get the user info from google using the oauth http client
res, err := client.Get("https://www.googleapis.com/userinfo/v2/me") res, err := client.Get("https://www.googleapis.com/userinfo/v2/me")
// Check if there was an error
if err != nil { if err != nil {
return user, err return user, err
} }
defer res.Body.Close() defer res.Body.Close()
log.Debug().Msg("Got response from google") log.Debug().Msg("Got response from google")
// Read the body of the response
body, err := io.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
// Check if there was an error
if err != nil { if err != nil {
return user, err return user, err
} }
log.Debug().Msg("Read body from google") log.Debug().Msg("Read body from google")
// Create a new user info struct
var userInfo GoogleUserInfoResponse var userInfo GoogleUserInfoResponse
// Unmarshal the body into the user struct
err = json.Unmarshal(body, &userInfo) err = json.Unmarshal(body, &userInfo)
// Check if there was an error
if err != nil { if err != nil {
return user, err return user, err
} }
log.Debug().Msg("Parsed user from google") log.Debug().Msg("Parsed user from google")
// Map the user info to the user struct
user.PreferredUsername = strings.Split(userInfo.Email, "@")[0] user.PreferredUsername = strings.Split(userInfo.Email, "@")[0]
user.Name = userInfo.Name user.Name = userInfo.Name
user.Email = userInfo.Email user.Email = userInfo.Email
// Return the user
return user, nil return user, nil
} }

View File

@@ -23,11 +23,8 @@ func NewProviders(config types.OAuthConfig) *Providers {
Config: config, Config: config,
} }
// If we have a client id and secret for github, initialize the oauth provider
if config.GithubClientId != "" && config.GithubClientSecret != "" { if config.GithubClientId != "" && config.GithubClientSecret != "" {
log.Info().Msg("Initializing Github OAuth") log.Info().Msg("Initializing Github OAuth")
// Create a new oauth provider with the github config
providers.Github = oauth.NewOAuth(oauth2.Config{ providers.Github = oauth.NewOAuth(oauth2.Config{
ClientID: config.GithubClientId, ClientID: config.GithubClientId,
ClientSecret: config.GithubClientSecret, ClientSecret: config.GithubClientSecret,
@@ -37,11 +34,8 @@ func NewProviders(config types.OAuthConfig) *Providers {
}, false) }, false)
} }
// If we have a client id and secret for google, initialize the oauth provider
if config.GoogleClientId != "" && config.GoogleClientSecret != "" { if config.GoogleClientId != "" && config.GoogleClientSecret != "" {
log.Info().Msg("Initializing Google OAuth") log.Info().Msg("Initializing Google OAuth")
// Create a new oauth provider with the google config
providers.Google = oauth.NewOAuth(oauth2.Config{ providers.Google = oauth.NewOAuth(oauth2.Config{
ClientID: config.GoogleClientId, ClientID: config.GoogleClientId,
ClientSecret: config.GoogleClientSecret, ClientSecret: config.GoogleClientSecret,
@@ -51,11 +45,8 @@ func NewProviders(config types.OAuthConfig) *Providers {
}, false) }, false)
} }
// If we have a client id and secret for generic oauth, initialize the oauth provider
if config.GenericClientId != "" && config.GenericClientSecret != "" { if config.GenericClientId != "" && config.GenericClientSecret != "" {
log.Info().Msg("Initializing Generic OAuth") log.Info().Msg("Initializing Generic OAuth")
// Create a new oauth provider with the generic config
providers.Generic = oauth.NewOAuth(oauth2.Config{ providers.Generic = oauth.NewOAuth(oauth2.Config{
ClientID: config.GenericClientId, ClientID: config.GenericClientId,
ClientSecret: config.GenericClientSecret, ClientSecret: config.GenericClientSecret,
@@ -72,7 +63,6 @@ func NewProviders(config types.OAuthConfig) *Providers {
} }
func (providers *Providers) GetProvider(provider string) *oauth.OAuth { func (providers *Providers) GetProvider(provider string) *oauth.OAuth {
// Return the provider based on the provider string
switch provider { switch provider {
case "github": case "github":
return providers.Github return providers.Github
@@ -86,82 +76,63 @@ func (providers *Providers) GetProvider(provider string) *oauth.OAuth {
} }
func (providers *Providers) GetUser(provider string) (constants.Claims, error) { func (providers *Providers) GetUser(provider string) (constants.Claims, error) {
// Create user struct
var user constants.Claims var user constants.Claims
// Get the user from the provider // Get the user from the provider
switch provider { switch provider {
case "github": case "github":
// If the github provider is not configured, return an error
if providers.Github == nil { if providers.Github == nil {
log.Debug().Msg("Github provider not configured") log.Debug().Msg("Github provider not configured")
return user, nil return user, nil
} }
// Get the client from the github provider
client := providers.Github.GetClient() client := providers.Github.GetClient()
log.Debug().Msg("Got client from github") log.Debug().Msg("Got client from github")
// Get the user from the github provider
user, err := GetGithubUser(client) user, err := GetGithubUser(client)
// Check if there was an error
if err != nil { if err != nil {
return user, err return user, err
} }
log.Debug().Msg("Got user from github") log.Debug().Msg("Got user from github")
// Return the user
return user, nil return user, nil
case "google": case "google":
// If the google provider is not configured, return an error
if providers.Google == nil { if providers.Google == nil {
log.Debug().Msg("Google provider not configured") log.Debug().Msg("Google provider not configured")
return user, nil return user, nil
} }
// Get the client from the google provider
client := providers.Google.GetClient() client := providers.Google.GetClient()
log.Debug().Msg("Got client from google") log.Debug().Msg("Got client from google")
// Get the user from the google provider
user, err := GetGoogleUser(client) user, err := GetGoogleUser(client)
// Check if there was an error
if err != nil { if err != nil {
return user, err return user, err
} }
log.Debug().Msg("Got user from google") log.Debug().Msg("Got user from google")
// Return the user
return user, nil return user, nil
case "generic": case "generic":
// If the generic provider is not configured, return an error
if providers.Generic == nil { if providers.Generic == nil {
log.Debug().Msg("Generic provider not configured") log.Debug().Msg("Generic provider not configured")
return user, nil return user, nil
} }
// Get the client from the generic provider
client := providers.Generic.GetClient() client := providers.Generic.GetClient()
log.Debug().Msg("Got client from generic") log.Debug().Msg("Got client from generic")
// Get the user from the generic provider
user, err := GetGenericUser(client, providers.Config.GenericUserURL) user, err := GetGenericUser(client, providers.Config.GenericUserURL)
// Check if there was an error
if err != nil { if err != nil {
return user, err return user, err
} }
log.Debug().Msg("Got user from generic") log.Debug().Msg("Got user from generic")
// Return the email
return user, nil return user, nil
default: default:
return user, nil return user, nil
@@ -169,7 +140,6 @@ func (providers *Providers) GetUser(provider string) (constants.Claims, error) {
} }
func (provider *Providers) GetConfiguredProviders() []string { func (provider *Providers) GetConfiguredProviders() []string {
// Create a list of the configured providers
providers := []string{} providers := []string{}
if provider.Github != nil { if provider.Github != nil {
providers = append(providers, "github") providers = append(providers, "github")

View File

@@ -22,23 +22,18 @@ type Server struct {
} }
func NewServer(config types.ServerConfig, handlers *handlers.Handlers) (*Server, error) { func NewServer(config types.ServerConfig, handlers *handlers.Handlers) (*Server, error) {
// Disable gin logs
gin.SetMode(gin.ReleaseMode) gin.SetMode(gin.ReleaseMode)
// Create router and use zerolog for logs
log.Debug().Msg("Setting up router") log.Debug().Msg("Setting up router")
router := gin.New() router := gin.New()
router.Use(zerolog()) router.Use(zerolog())
// Read UI assets
log.Debug().Msg("Setting up assets") log.Debug().Msg("Setting up assets")
dist, err := fs.Sub(assets.Assets, "dist") dist, err := fs.Sub(assets.Assets, "dist")
if err != nil { if err != nil {
return nil, err return nil, err
} }
// Create file server
log.Debug().Msg("Setting up file server") log.Debug().Msg("Setting up file server")
fileServer := http.FileServer(http.FS(dist)) fileServer := http.FileServer(http.FS(dist))
@@ -46,18 +41,11 @@ func NewServer(config types.ServerConfig, handlers *handlers.Handlers) (*Server,
router.Use(func(c *gin.Context) { router.Use(func(c *gin.Context) {
// If not an API request, serve the UI // If not an API request, serve the UI
if !strings.HasPrefix(c.Request.URL.Path, "/api") { if !strings.HasPrefix(c.Request.URL.Path, "/api") {
// Check if the file exists
_, err := fs.Stat(dist, strings.TrimPrefix(c.Request.URL.Path, "/")) _, err := fs.Stat(dist, strings.TrimPrefix(c.Request.URL.Path, "/"))
// If the file doesn't exist, serve the index.html
if os.IsNotExist(err) { if os.IsNotExist(err) {
c.Request.URL.Path = "/" c.Request.URL.Path = "/"
} }
// Serve the file
fileServer.ServeHTTP(c.Writer, c.Request) fileServer.ServeHTTP(c.Writer, c.Request)
// Stop further processing
c.Abort() c.Abort()
} }
}) })
@@ -81,7 +69,6 @@ func NewServer(config types.ServerConfig, handlers *handlers.Handlers) (*Server,
// App routes // App routes
router.GET("/api/healthcheck", handlers.HealthcheckHandler) router.GET("/api/healthcheck", handlers.HealthcheckHandler)
// Return the server
return &Server{ return &Server{
Config: config, Config: config,
Handlers: handlers, Handlers: handlers,
@@ -90,9 +77,7 @@ func NewServer(config types.ServerConfig, handlers *handlers.Handlers) (*Server,
} }
func (s *Server) Start() error { func (s *Server) Start() error {
// Run server
log.Info().Str("address", s.Config.Address).Int("port", s.Config.Port).Msg("Starting server") log.Info().Str("address", s.Config.Address).Int("port", s.Config.Port).Msg("Starting server")
return s.Router.Run(fmt.Sprintf("%s:%d", s.Config.Address, s.Config.Port)) return s.Router.Run(fmt.Sprintf("%s:%d", s.Config.Address, s.Config.Port))
} }

View File

@@ -21,13 +21,13 @@ import (
"github.com/pquerna/otp/totp" "github.com/pquerna/otp/totp"
) )
// Simple server config for tests // Simple server config
var serverConfig = types.ServerConfig{ var serverConfig = types.ServerConfig{
Port: 8080, Port: 8080,
Address: "0.0.0.0", Address: "0.0.0.0",
} }
// Simple handlers config for tests // Simple handlers config
var handlersConfig = types.HandlersConfig{ var handlersConfig = types.HandlersConfig{
AppURL: "http://localhost:8080", AppURL: "http://localhost:8080",
Domain: "localhost", Domain: "localhost",
@@ -42,7 +42,7 @@ var handlersConfig = types.HandlersConfig{
OAuthAutoRedirect: "none", OAuthAutoRedirect: "none",
} }
// Simple auth config for tests // Simple auth config
var authConfig = types.AuthConfig{ var authConfig = types.AuthConfig{
Users: types.Users{}, Users: types.Users{},
OauthWhitelist: "", OauthWhitelist: "",
@@ -56,13 +56,13 @@ var authConfig = types.AuthConfig{
Domain: "localhost", Domain: "localhost",
} }
// Simple hooks config for tests // Simple hooks config
var hooksConfig = types.HooksConfig{ var hooksConfig = types.HooksConfig{
Domain: "localhost", Domain: "localhost",
} }
// Cookie // Cookie
var cookie = "MTc1MTkyMzM5MnxiME9aTzlGQjZMNEJMdDZMc0lHMk9zcXQyME9SR1ZnUmlaYWZNcWplek5vcVNpdkdHRTZqb09YWkVUYUN6NEt4MkEyOGEyX2hFQWZEUEYtbllDX0h5eDBCb3VyT2phQlRpZWFfRFdTMGw2WUg2VWw4RGdNbEhQclotOUJjblJGaWFQcmhyaWFna0dXRWNud2c1akg5eEpLZ3JzS0pfWktscVZyckZFR1VDX0R5QjFOT0hzMTNKb18ySEMxZlluSWNxa1ByM0VhSzNyMkRtdDNORWJXVGFYSnMzWjFGa0lrZlhSTWduRmttMHhQUXN4UFhNbHFXY0lBWjBnUWpKU0xXMHRubjlKbjV0LXBGdjk0MmpJX0xMX1ZYblVJVW9LWUJoWmpNanVXNkNjamhYWlR2V29rY0RNYWkxY2lMQnpqLUI2cHMyYTZkWWgtWnlFdGN0amh2WURUeUNGT3ZLS1FJVUFIb0NWR1RPMlRtY2c9PXwerwFtb9urOXnwA02qXbLeorMloaK_paQd0in4BAesmg==" var cookie string
// User // User
var user = types.User{ var user = types.User{
@@ -72,14 +72,7 @@ var user = types.User{
// Initialize the server for tests // Initialize the server for tests
func getServer(t *testing.T) *server.Server { func getServer(t *testing.T) *server.Server {
// Create docker service // Create services
docker, err := docker.NewDocker()
if err != nil {
t.Fatalf("Failed to initialize docker: %v", err)
}
// Create auth service
authConfig.Users = types.Users{ authConfig.Users = types.Users{
{ {
Username: user.Username, Username: user.Username,
@@ -87,69 +80,51 @@ func getServer(t *testing.T) *server.Server {
TotpSecret: user.TotpSecret, TotpSecret: user.TotpSecret,
}, },
} }
auth := auth.NewAuth(authConfig, docker, nil) docker, err := docker.NewDocker()
if err != nil {
// Create providers service t.Fatalf("Failed to create docker client: %v", err)
}
auth := auth.NewAuth(authConfig, nil, nil)
providers := providers.NewProviders(types.OAuthConfig{}) providers := providers.NewProviders(types.OAuthConfig{})
// Create hooks service
hooks := hooks.NewHooks(hooksConfig, auth, providers) hooks := hooks.NewHooks(hooksConfig, auth, providers)
// Create handlers service
handlers := handlers.NewHandlers(handlersConfig, auth, hooks, providers, docker) handlers := handlers.NewHandlers(handlersConfig, auth, hooks, providers, docker)
// Create server // Create server
srv, err := server.NewServer(serverConfig, handlers) srv, err := server.NewServer(serverConfig, handlers)
if err != nil { if err != nil {
t.Fatalf("Failed to create server: %v", err) t.Fatalf("Failed to create server: %v", err)
} }
// Return the server
return srv return srv
} }
// Test login
func TestLogin(t *testing.T) { func TestLogin(t *testing.T) {
t.Log("Testing login") t.Log("Testing login")
// Get server
srv := getServer(t) srv := getServer(t)
// Create recorder
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
// Create request
user := types.LoginRequest{ user := types.LoginRequest{
Username: "user", Username: "user",
Password: "pass", Password: "pass",
} }
json, err := json.Marshal(user) json, err := json.Marshal(user)
// Check if there was an error
if err != nil { if err != nil {
t.Fatalf("Error marshalling json: %v", err) t.Fatalf("Error marshalling json: %v", err)
} }
// Create request
req, err := http.NewRequest("POST", "/api/login", strings.NewReader(string(json))) req, err := http.NewRequest("POST", "/api/login", strings.NewReader(string(json)))
// Check if there was an error
if err != nil { if err != nil {
t.Fatalf("Error creating request: %v", err) t.Fatalf("Error creating request: %v", err)
} }
// Serve the request
srv.Router.ServeHTTP(recorder, req) srv.Router.ServeHTTP(recorder, req)
// Assert
assert.Equal(t, recorder.Code, http.StatusOK) assert.Equal(t, recorder.Code, http.StatusOK)
// Get the result cookie
cookies := recorder.Result().Cookies() cookies := recorder.Result().Cookies()
// Check if the cookie is set
if len(cookies) == 0 { if len(cookies) == 0 {
t.Fatalf("Cookie not set") t.Fatalf("Cookie not set")
} }
@@ -158,55 +133,42 @@ func TestLogin(t *testing.T) {
cookie = cookies[0].Value cookie = cookies[0].Value
} }
// Test app context
func TestAppContext(t *testing.T) { func TestAppContext(t *testing.T) {
// Refresh the cookie
TestLogin(t)
t.Log("Testing app context") t.Log("Testing app context")
// Get server
srv := getServer(t) srv := getServer(t)
// Create recorder
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
// Create request
req, err := http.NewRequest("GET", "/api/app", nil) req, err := http.NewRequest("GET", "/api/app", nil)
// Check if there was an error
if err != nil { if err != nil {
t.Fatalf("Error creating request: %v", err) t.Fatalf("Error creating request: %v", err)
} }
// Set the cookie // Set the cookie from the previous test
req.AddCookie(&http.Cookie{ req.AddCookie(&http.Cookie{
Name: "tinyauth", Name: "tinyauth",
Value: cookie, Value: cookie,
}) })
// Serve the request
srv.Router.ServeHTTP(recorder, req) srv.Router.ServeHTTP(recorder, req)
// Assert
assert.Equal(t, recorder.Code, http.StatusOK) assert.Equal(t, recorder.Code, http.StatusOK)
// Read the body of the response
body, err := io.ReadAll(recorder.Body) body, err := io.ReadAll(recorder.Body)
// Check if there was an error
if err != nil { if err != nil {
t.Fatalf("Error getting body: %v", err) t.Fatalf("Error getting body: %v", err)
} }
// Unmarshal the body into the user struct
var app types.AppContext var app types.AppContext
err = json.Unmarshal(body, &app) err = json.Unmarshal(body, &app)
// Check if there was an error
if err != nil { if err != nil {
t.Fatalf("Error unmarshalling body: %v", err) t.Fatalf("Error unmarshalling body: %v", err)
} }
// Create tests values
expected := types.AppContext{ expected := types.AppContext{
Status: 200, Status: 200,
Message: "OK", Message: "OK",
@@ -226,48 +188,34 @@ func TestAppContext(t *testing.T) {
} }
} }
// Test user context
func TestUserContext(t *testing.T) { func TestUserContext(t *testing.T) {
// Refresh the cookie // Refresh the cookie
TestLogin(t) TestLogin(t)
t.Log("Testing user context") t.Log("Testing user context")
// Get server
srv := getServer(t) srv := getServer(t)
// Create recorder
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
// Create request
req, err := http.NewRequest("GET", "/api/user", nil) req, err := http.NewRequest("GET", "/api/user", nil)
// Check if there was an error
if err != nil { if err != nil {
t.Fatalf("Error creating request: %v", err) t.Fatalf("Error creating request: %v", err)
} }
// Set the cookie
req.AddCookie(&http.Cookie{ req.AddCookie(&http.Cookie{
Name: "tinyauth-session", Name: "tinyauth-session",
Value: cookie, Value: cookie,
}) })
// Serve the request
srv.Router.ServeHTTP(recorder, req) srv.Router.ServeHTTP(recorder, req)
// Assert
assert.Equal(t, recorder.Code, http.StatusOK) assert.Equal(t, recorder.Code, http.StatusOK)
// Read the body of the response
body, err := io.ReadAll(recorder.Body) body, err := io.ReadAll(recorder.Body)
// Check if there was an error
if err != nil { if err != nil {
t.Fatalf("Error getting body: %v", err) t.Fatalf("Error getting body: %v", err)
} }
// Unmarshal the body into the user struct
type User struct { type User struct {
Username string `json:"username"` Username string `json:"username"`
} }
@@ -275,49 +223,37 @@ func TestUserContext(t *testing.T) {
var user User var user User
err = json.Unmarshal(body, &user) err = json.Unmarshal(body, &user)
// Check if there was an error
if err != nil { if err != nil {
t.Fatalf("Error unmarshalling body: %v", err) t.Fatalf("Error unmarshalling body: %v", err)
} }
// We should get the username back // We should get the user back
if user.Username != "user" { if user.Username != "user" {
t.Fatalf("Expected user, got %s", user.Username) t.Fatalf("Expected user, got %s", user.Username)
} }
} }
// Test logout
func TestLogout(t *testing.T) { func TestLogout(t *testing.T) {
// Refresh the cookie // Refresh the cookie
TestLogin(t) TestLogin(t)
t.Log("Testing logout") t.Log("Testing logout")
// Get server
srv := getServer(t) srv := getServer(t)
// Create recorder
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
// Create request
req, err := http.NewRequest("POST", "/api/logout", nil) req, err := http.NewRequest("POST", "/api/logout", nil)
// Check if there was an error
if err != nil { if err != nil {
t.Fatalf("Error creating request: %v", err) t.Fatalf("Error creating request: %v", err)
} }
// Set the cookie
req.AddCookie(&http.Cookie{ req.AddCookie(&http.Cookie{
Name: "tinyauth-session", Name: "tinyauth-session",
Value: cookie, Value: cookie,
}) })
// Serve the request
srv.Router.ServeHTTP(recorder, req) srv.Router.ServeHTTP(recorder, req)
// Assert
assert.Equal(t, recorder.Code, http.StatusOK) assert.Equal(t, recorder.Code, http.StatusOK)
// Check if the cookie is different (means the cookie is gone) // Check if the cookie is different (means the cookie is gone)
@@ -326,196 +262,133 @@ func TestLogout(t *testing.T) {
} }
} }
// Test auth endpoint
func TestAuth(t *testing.T) { func TestAuth(t *testing.T) {
// Refresh the cookie // Refresh the cookie
TestLogin(t) TestLogin(t)
t.Log("Testing auth endpoint") t.Log("Testing auth endpoint")
// Get server
srv := getServer(t) srv := getServer(t)
// Create recorder
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
// Create request
req, err := http.NewRequest("GET", "/api/auth/traefik", nil) req, err := http.NewRequest("GET", "/api/auth/traefik", nil)
if err != nil {
t.Fatalf("Error creating request: %v", err)
}
// Set the accept header
req.Header.Set("Accept", "text/html") req.Header.Set("Accept", "text/html")
// Check if there was an error
if err != nil {
t.Fatalf("Error creating request: %v", err)
}
// Serve the request
srv.Router.ServeHTTP(recorder, req) srv.Router.ServeHTTP(recorder, req)
// Assert
assert.Equal(t, recorder.Code, http.StatusTemporaryRedirect) assert.Equal(t, recorder.Code, http.StatusTemporaryRedirect)
// Recreate recorder
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
// Recreate the request
req, err = http.NewRequest("GET", "/api/auth/traefik", nil) req, err = http.NewRequest("GET", "/api/auth/traefik", nil)
// Check if there was an error
if err != nil { if err != nil {
t.Fatalf("Error creating request: %v", err) t.Fatalf("Error creating request: %v", err)
} }
// Test with the cookie
req.AddCookie(&http.Cookie{ req.AddCookie(&http.Cookie{
Name: "tinyauth-session", Name: "tinyauth-session",
Value: cookie, Value: cookie,
}) })
// Serve the request again
srv.Router.ServeHTTP(recorder, req) srv.Router.ServeHTTP(recorder, req)
// Assert
assert.Equal(t, recorder.Code, http.StatusOK) assert.Equal(t, recorder.Code, http.StatusOK)
// Recreate recorder
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
// Recreate the request
req, err = http.NewRequest("GET", "/api/auth/nginx", nil) req, err = http.NewRequest("GET", "/api/auth/nginx", nil)
// Check if there was an error
if err != nil { if err != nil {
t.Fatalf("Error creating request: %v", err) t.Fatalf("Error creating request: %v", err)
} }
// Serve the request again
srv.Router.ServeHTTP(recorder, req) srv.Router.ServeHTTP(recorder, req)
// Assert
assert.Equal(t, recorder.Code, http.StatusUnauthorized) assert.Equal(t, recorder.Code, http.StatusUnauthorized)
// Recreate recorder
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
// Recreate the request
req, err = http.NewRequest("GET", "/api/auth/nginx", nil) req, err = http.NewRequest("GET", "/api/auth/nginx", nil)
// Check if there was an error
if err != nil { if err != nil {
t.Fatalf("Error creating request: %v", err) t.Fatalf("Error creating request: %v", err)
} }
// Test with the cookie
req.AddCookie(&http.Cookie{ req.AddCookie(&http.Cookie{
Name: "tinyauth-session", Name: "tinyauth-session",
Value: cookie, Value: cookie,
}) })
// Serve the request again
srv.Router.ServeHTTP(recorder, req) srv.Router.ServeHTTP(recorder, req)
// Assert
assert.Equal(t, recorder.Code, http.StatusOK) assert.Equal(t, recorder.Code, http.StatusOK)
} }
func TestTOTP(t *testing.T) { func TestTOTP(t *testing.T) {
t.Log("Testing TOTP") t.Log("Testing TOTP")
// Generate totp secret
key, err := totp.Generate(totp.GenerateOpts{ key, err := totp.Generate(totp.GenerateOpts{
Issuer: "Tinyauth", Issuer: "Tinyauth",
AccountName: user.Username, AccountName: user.Username,
}) })
if err != nil { if err != nil {
t.Fatalf("Failed to generate TOTP secret: %v", err) t.Fatalf("Failed to generate TOTP secret: %v", err)
} }
// Create secret
secret := key.Secret() secret := key.Secret()
// Set the user's TOTP secret
user.TotpSecret = secret user.TotpSecret = secret
// Get server
srv := getServer(t) srv := getServer(t)
// Create request
user := types.LoginRequest{ user := types.LoginRequest{
Username: "user", Username: "user",
Password: "pass", Password: "pass",
} }
loginJson, err := json.Marshal(user) loginJson, err := json.Marshal(user)
// Check if there was an error
if err != nil { if err != nil {
t.Fatalf("Error marshalling json: %v", err) t.Fatalf("Error marshalling json: %v", err)
} }
// Create recorder
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
// Create request
req, err := http.NewRequest("POST", "/api/login", strings.NewReader(string(loginJson))) req, err := http.NewRequest("POST", "/api/login", strings.NewReader(string(loginJson)))
// Check if there was an error
if err != nil { if err != nil {
t.Fatalf("Error creating request: %v", err) t.Fatalf("Error creating request: %v", err)
} }
// Serve the request
srv.Router.ServeHTTP(recorder, req) srv.Router.ServeHTTP(recorder, req)
// Assert
assert.Equal(t, recorder.Code, http.StatusOK) assert.Equal(t, recorder.Code, http.StatusOK)
// Set the cookie for next test // Set the cookie for next test
cookie = recorder.Result().Cookies()[0].Value cookie = recorder.Result().Cookies()[0].Value
// Create TOTP code
code, err := totp.GenerateCode(secret, time.Now()) code, err := totp.GenerateCode(secret, time.Now())
// Check if there was an error
if err != nil { if err != nil {
t.Fatalf("Failed to generate TOTP code: %v", err) t.Fatalf("Failed to generate TOTP code: %v", err)
} }
// Create TOTP request
totpRequest := types.TotpRequest{ totpRequest := types.TotpRequest{
Code: code, Code: code,
} }
// Marshal the TOTP request
totpJson, err := json.Marshal(totpRequest) totpJson, err := json.Marshal(totpRequest)
// Check if there was an error
if err != nil { if err != nil {
t.Fatalf("Error marshalling TOTP request: %v", err) t.Fatalf("Error marshalling TOTP request: %v", err)
} }
// Create recorder
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
// Create request
req, err = http.NewRequest("POST", "/api/totp", strings.NewReader(string(totpJson))) req, err = http.NewRequest("POST", "/api/totp", strings.NewReader(string(totpJson)))
// Check if there was an error
if err != nil { if err != nil {
t.Fatalf("Error creating request: %v", err) t.Fatalf("Error creating request: %v", err)
} }
// Set the cookie
req.AddCookie(&http.Cookie{ req.AddCookie(&http.Cookie{
Name: "tinyauth-session", Name: "tinyauth-session",
Value: cookie, Value: cookie,
}) })
// Serve the request
srv.Router.ServeHTTP(recorder, req) srv.Router.ServeHTTP(recorder, req)
// Assert
assert.Equal(t, recorder.Code, http.StatusOK) assert.Equal(t, recorder.Code, http.StatusOK)
} }

View File

@@ -24,168 +24,118 @@ import (
func ParseUsers(users string) (types.Users, error) { func ParseUsers(users string) (types.Users, error) {
log.Debug().Msg("Parsing users") log.Debug().Msg("Parsing users")
// Create a new users struct
var usersParsed types.Users var usersParsed types.Users
// Split the users by comma
userList := strings.Split(users, ",") userList := strings.Split(users, ",")
// Check if there are any users
if len(userList) == 0 { if len(userList) == 0 {
return types.Users{}, errors.New("invalid user format") return types.Users{}, errors.New("invalid user format")
} }
// Loop through the users and split them by colon
for _, user := range userList { for _, user := range userList {
parsed, err := ParseUser(user) parsed, err := ParseUser(user)
// Check if there was an error
if err != nil { if err != nil {
return types.Users{}, err return types.Users{}, err
} }
// Append the user to the users struct
usersParsed = append(usersParsed, parsed) usersParsed = append(usersParsed, parsed)
} }
log.Debug().Msg("Parsed users") log.Debug().Msg("Parsed users")
// Return the users struct
return usersParsed, nil return usersParsed, nil
} }
// Get upper domain parses a hostname and returns the upper domain (e.g. sub1.sub2.domain.com -> sub2.domain.com) // Get upper domain parses a hostname and returns the upper domain (e.g. sub1.sub2.domain.com -> sub2.domain.com)
func GetUpperDomain(urlSrc string) (string, error) { func GetUpperDomain(urlSrc string) (string, error) {
// Make sure the url is valid
urlParsed, err := url.Parse(urlSrc) urlParsed, err := url.Parse(urlSrc)
// Check if there was an error
if err != nil { if err != nil {
return "", err return "", err
} }
// Split the hostname by period
urlSplitted := strings.Split(urlParsed.Hostname(), ".") urlSplitted := strings.Split(urlParsed.Hostname(), ".")
// Get the last part of the url
urlFinal := strings.Join(urlSplitted[1:], ".") urlFinal := strings.Join(urlSplitted[1:], ".")
// Return the root domain
return urlFinal, nil return urlFinal, nil
} }
// Reads a file and returns the contents // Reads a file and returns the contents
func ReadFile(file string) (string, error) { func ReadFile(file string) (string, error) {
// Check if the file exists
_, err := os.Stat(file) _, err := os.Stat(file)
// Check if there was an error
if err != nil { if err != nil {
return "", err return "", err
} }
// Read the file
data, err := os.ReadFile(file) data, err := os.ReadFile(file)
// Check if there was an error
if err != nil { if err != nil {
return "", err return "", err
} }
// Return the file contents
return string(data), nil return string(data), nil
} }
// Parses a file into a comma separated list of users // Parses a file into a comma separated list of users
func ParseFileToLine(content string) string { func ParseFileToLine(content string) string {
// Split the content by newline
lines := strings.Split(content, "\n") lines := strings.Split(content, "\n")
// Create a list of users
users := make([]string, 0) users := make([]string, 0)
// Loop through the lines, trimming the whitespace and appending to the users list
for _, line := range lines { for _, line := range lines {
if strings.TrimSpace(line) == "" { if strings.TrimSpace(line) == "" {
continue continue
} }
users = append(users, strings.TrimSpace(line)) users = append(users, strings.TrimSpace(line))
} }
// Return the users as a comma separated string
return strings.Join(users, ",") return strings.Join(users, ",")
} }
// Get the secret from the config or file // Get the secret from the config or file
func GetSecret(conf string, file string) string { func GetSecret(conf string, file string) string {
// If neither the config or file is set, return an empty string
if conf == "" && file == "" { if conf == "" && file == "" {
return "" return ""
} }
// If the config is set, return the config (environment variable)
if conf != "" { if conf != "" {
return conf return conf
} }
// If the file is set, read the file
contents, err := ReadFile(file) contents, err := ReadFile(file)
// Check if there was an error
if err != nil { if err != nil {
return "" return ""
} }
// Return the contents of the file
return ParseSecretFile(contents) return ParseSecretFile(contents)
} }
// Get the users from the config or file // Get the users from the config or file
func GetUsers(conf string, file string) (types.Users, error) { func GetUsers(conf string, file string) (types.Users, error) {
// Create a string to store the users
var users string var users string
// If neither the config or file is set, return an empty users struct
if conf == "" && file == "" { if conf == "" && file == "" {
return types.Users{}, nil return types.Users{}, nil
} }
// If the config (environment) is set, append the users to the users string
if conf != "" { if conf != "" {
log.Debug().Msg("Using users from config") log.Debug().Msg("Using users from config")
users += conf users += conf
} }
// If the file is set, read the file and append the users to the users string
if file != "" { if file != "" {
// Read the file
contents, err := ReadFile(file) contents, err := ReadFile(file)
// If there isn't an error we can append the users to the users string
if err == nil { if err == nil {
log.Debug().Msg("Using users from file") log.Debug().Msg("Using users from file")
// Append the users to the users string
if users != "" { if users != "" {
users += "," users += ","
} }
// Parse the file contents into a comma separated list of users
users += ParseFileToLine(contents) users += ParseFileToLine(contents)
} }
} }
// Return the parsed users
return ParseUsers(users) return ParseUsers(users)
} }
// Parse the headers in a map[string]string format // Parse the headers in a map[string]string format
func ParseHeaders(headers []string) map[string]string { func ParseHeaders(headers []string) map[string]string {
// Create a map to store the headers
headerMap := make(map[string]string) headerMap := make(map[string]string)
// Loop through the headers
for _, header := range headers { for _, header := range headers {
split := strings.SplitN(header, "=", 2) split := strings.SplitN(header, "=", 2)
if len(split) != 2 || strings.TrimSpace(split[0]) == "" || strings.TrimSpace(split[1]) == "" { if len(split) != 2 || strings.TrimSpace(split[0]) == "" || strings.TrimSpace(split[1]) == "" {
@@ -197,25 +147,19 @@ func ParseHeaders(headers []string) map[string]string {
headerMap[key] = value headerMap[key] = value
} }
// Return the header map
return headerMap return headerMap
} }
// Get labels parses a map of labels into a struct with only the needed labels // Get labels parses a map of labels into a struct with only the needed labels
func GetLabels(labels map[string]string) (types.Labels, error) { func GetLabels(labels map[string]string) (types.Labels, error) {
// Create a new labels struct
var labelsParsed types.Labels var labelsParsed types.Labels
// Decode the labels into the labels struct
err := parser.Decode(labels, &labelsParsed, "tinyauth", "tinyauth.users", "tinyauth.allowed", "tinyauth.headers", "tinyauth.domain", "tinyauth.basic", "tinyauth.oauth", "tinyauth.ip") err := parser.Decode(labels, &labelsParsed, "tinyauth", "tinyauth.users", "tinyauth.allowed", "tinyauth.headers", "tinyauth.domain", "tinyauth.basic", "tinyauth.oauth", "tinyauth.ip")
// Check if there was an error
if err != nil { if err != nil {
log.Error().Err(err).Msg("Error parsing labels") log.Error().Err(err).Msg("Error parsing labels")
return types.Labels{}, err return types.Labels{}, err
} }
// Return the labels struct
return labelsParsed, nil return labelsParsed, nil
} }
@@ -236,27 +180,22 @@ func Filter[T any](slice []T, test func(T) bool) (res []T) {
// Parse user // Parse user
func ParseUser(user string) (types.User, error) { func ParseUser(user string) (types.User, error) {
// Check if the user is escaped
if strings.Contains(user, "$$") { if strings.Contains(user, "$$") {
user = strings.ReplaceAll(user, "$$", "$") user = strings.ReplaceAll(user, "$$", "$")
} }
// Split the user by colon
userSplit := strings.Split(user, ":") userSplit := strings.Split(user, ":")
// Check if the user is in the correct format
if len(userSplit) < 2 || len(userSplit) > 3 { if len(userSplit) < 2 || len(userSplit) > 3 {
return types.User{}, errors.New("invalid user format") return types.User{}, errors.New("invalid user format")
} }
// Check for empty strings
for _, userPart := range userSplit { for _, userPart := range userSplit {
if strings.TrimSpace(userPart) == "" { if strings.TrimSpace(userPart) == "" {
return types.User{}, errors.New("invalid user format") return types.User{}, errors.New("invalid user format")
} }
} }
// Check if the user has a totp secret
if len(userSplit) == 2 { if len(userSplit) == 2 {
return types.User{ return types.User{
Username: strings.TrimSpace(userSplit[0]), Username: strings.TrimSpace(userSplit[0]),
@@ -264,7 +203,6 @@ func ParseUser(user string) (types.User, error) {
}, nil }, nil
} }
// Return the user struct
return types.User{ return types.User{
Username: strings.TrimSpace(userSplit[0]), Username: strings.TrimSpace(userSplit[0]),
Password: strings.TrimSpace(userSplit[1]), Password: strings.TrimSpace(userSplit[1]),
@@ -274,60 +212,44 @@ func ParseUser(user string) (types.User, error) {
// Parse secret file // Parse secret file
func ParseSecretFile(contents string) string { func ParseSecretFile(contents string) string {
// Split to lines
lines := strings.Split(contents, "\n") lines := strings.Split(contents, "\n")
// Loop through the lines
for _, line := range lines { for _, line := range lines {
// Check if the line is empty
if strings.TrimSpace(line) == "" { if strings.TrimSpace(line) == "" {
continue continue
} }
// Return the line
return strings.TrimSpace(line) return strings.TrimSpace(line)
} }
// Return an empty string
return "" return ""
} }
// Check if a string matches a regex or if it is included in a comma separated list // Check if a string matches a regex or if it is included in a comma separated list
func CheckFilter(filter string, str string) bool { func CheckFilter(filter string, str string) bool {
// Check if the filter is empty
if len(strings.TrimSpace(filter)) == 0 { if len(strings.TrimSpace(filter)) == 0 {
return true return true
} }
// Check if the filter is a regex
if strings.HasPrefix(filter, "/") && strings.HasSuffix(filter, "/") { if strings.HasPrefix(filter, "/") && strings.HasSuffix(filter, "/") {
// Create regex
re, err := regexp.Compile(filter[1 : len(filter)-1]) re, err := regexp.Compile(filter[1 : len(filter)-1])
// Check if there was an error
if err != nil { if err != nil {
log.Error().Err(err).Msg("Error compiling regex") log.Error().Err(err).Msg("Error compiling regex")
return false return false
} }
// Check if the string matches the regex
if re.MatchString(str) { if re.MatchString(str) {
return true return true
} }
} }
// Split the filter by comma
filterSplit := strings.Split(filter, ",") filterSplit := strings.Split(filter, ",")
// Loop through the filter items
for _, item := range filterSplit { for _, item := range filterSplit {
// Check if the item matches with the string
if strings.TrimSpace(item) == str { if strings.TrimSpace(item) == str {
return true return true
} }
} }
// Return false if no match was found
return false return false
} }
@@ -352,89 +274,56 @@ func SanitizeHeader(header string) string {
// Generate a static identifier from a string // Generate a static identifier from a string
func GenerateIdentifier(str string) string { func GenerateIdentifier(str string) string {
// Create a new UUID
uuid := uuid.NewSHA1(uuid.NameSpaceURL, []byte(str)) uuid := uuid.NewSHA1(uuid.NameSpaceURL, []byte(str))
// Convert the UUID to a string
uuidString := uuid.String() uuidString := uuid.String()
// Show the UUID
log.Debug().Str("uuid", uuidString).Msg("Generated UUID") log.Debug().Str("uuid", uuidString).Msg("Generated UUID")
// Convert the UUID to a string
return strings.Split(uuidString, "-")[0] return strings.Split(uuidString, "-")[0]
} }
// Get a basic auth header from a username and password // Get a basic auth header from a username and password
func GetBasicAuth(username string, password string) string { func GetBasicAuth(username string, password string) string {
// Create the auth string
auth := username + ":" + password auth := username + ":" + password
// Encode the auth string to base64
return base64.StdEncoding.EncodeToString([]byte(auth)) return base64.StdEncoding.EncodeToString([]byte(auth))
} }
// Check if an IP is contained in a CIDR range/matches a single IP // Check if an IP is contained in a CIDR range/matches a single IP
func FilterIP(filter string, ip string) (bool, error) { func FilterIP(filter string, ip string) (bool, error) {
// Convert the check IP to an IP instance
ipAddr := net.ParseIP(ip) ipAddr := net.ParseIP(ip)
// Check if the filter is a CIDR range
if strings.Contains(filter, "/") { if strings.Contains(filter, "/") {
// Parse the CIDR range
_, cidr, err := net.ParseCIDR(filter) _, cidr, err := net.ParseCIDR(filter)
// Check if there was an error
if err != nil { if err != nil {
return false, err return false, err
} }
// Check if the IP is in the CIDR range
return cidr.Contains(ipAddr), nil return cidr.Contains(ipAddr), nil
} }
// Parse the filter as a single IP
ipFilter := net.ParseIP(filter) ipFilter := net.ParseIP(filter)
// Check if the IP is valid
if ipFilter == nil { if ipFilter == nil {
return false, errors.New("invalid IP address in filter") return false, errors.New("invalid IP address in filter")
} }
// Check if the IP matches the filter
if ipFilter.Equal(ipAddr) { if ipFilter.Equal(ipAddr) {
return true, nil return true, nil
} }
// If the filter is not a CIDR range or a single IP, return false
return false, nil return false, nil
} }
func DeriveKey(secret string, info string) (string, error) { func DeriveKey(secret string, info string) (string, error) {
// Create hashing function
hash := sha256.New hash := sha256.New
// Create a new key using the secret and info
hkdf := hkdf.New(hash, []byte(secret), nil, []byte(info)) // I am not using a salt because I just want two different keys from one secret, maybe bad practice hkdf := hkdf.New(hash, []byte(secret), nil, []byte(info)) // I am not using a salt because I just want two different keys from one secret, maybe bad practice
// Create a new key
key := make([]byte, 24) key := make([]byte, 24)
// Read the key from the HKDF
_, err := io.ReadFull(hkdf, key) _, err := io.ReadFull(hkdf, key)
if err != nil { if err != nil {
return "", err return "", err
} }
// Verify the key is not empty
if bytes.Equal(key, make([]byte, 24)) { if bytes.Equal(key, make([]byte, 24)) {
return "", errors.New("derived key is empty") return "", errors.New("derived key is empty")
} }
// Encode the key to base64
encodedKey := base64.StdEncoding.EncodeToString(key) encodedKey := base64.StdEncoding.EncodeToString(key)
// Return the key as a base64 encoded string
return encodedKey, nil return encodedKey, nil
} }

View File

@@ -9,11 +9,9 @@ import (
"tinyauth/internal/utils" "tinyauth/internal/utils"
) )
// Test the parse users function
func TestParseUsers(t *testing.T) { func TestParseUsers(t *testing.T) {
t.Log("Testing parse users with a valid string") t.Log("Testing parse users with a valid string")
// Test the parse users function with a valid string
users := "user1:pass1,user2:pass2" users := "user1:pass1,user2:pass2"
expected := types.Users{ expected := types.Users{
{ {
@@ -27,154 +25,116 @@ func TestParseUsers(t *testing.T) {
} }
result, err := utils.ParseUsers(users) result, err := utils.ParseUsers(users)
// Check if there was an error
if err != nil { if err != nil {
t.Fatalf("Error parsing users: %v", err) t.Fatalf("Error parsing users: %v", err)
} }
// Check if the result is equal to the expected
if !reflect.DeepEqual(expected, result) { if !reflect.DeepEqual(expected, result) {
t.Fatalf("Expected %v, got %v", expected, result) t.Fatalf("Expected %v, got %v", expected, result)
} }
} }
// Test the get upper domain function
func TestGetUpperDomain(t *testing.T) { func TestGetUpperDomain(t *testing.T) {
t.Log("Testing get upper domain with a valid url") t.Log("Testing get upper domain with a valid url")
// Test the get upper domain function with a valid url
url := "https://sub1.sub2.domain.com:8080" url := "https://sub1.sub2.domain.com:8080"
expected := "sub2.domain.com" expected := "sub2.domain.com"
result, err := utils.GetUpperDomain(url) result, err := utils.GetUpperDomain(url)
// Check if there was an error
if err != nil { if err != nil {
t.Fatalf("Error getting root url: %v", err) t.Fatalf("Error getting root url: %v", err)
} }
// Check if the result is equal to the expected
if expected != result { if expected != result {
t.Fatalf("Expected %v, got %v", expected, result) t.Fatalf("Expected %v, got %v", expected, result)
} }
} }
// Test the read file function
func TestReadFile(t *testing.T) { func TestReadFile(t *testing.T) {
t.Log("Creating a test file") t.Log("Creating a test file")
// Create a test file
err := os.WriteFile("/tmp/test.txt", []byte("test"), 0644) err := os.WriteFile("/tmp/test.txt", []byte("test"), 0644)
// Check if there was an error
if err != nil { if err != nil {
t.Fatalf("Error creating test file: %v", err) t.Fatalf("Error creating test file: %v", err)
} }
// Test the read file function
t.Log("Testing read file with a valid file") t.Log("Testing read file with a valid file")
data, err := utils.ReadFile("/tmp/test.txt") data, err := utils.ReadFile("/tmp/test.txt")
// Check if there was an error
if err != nil { if err != nil {
t.Fatalf("Error reading file: %v", err) t.Fatalf("Error reading file: %v", err)
} }
// Check if the data is equal to the expected
if data != "test" { if data != "test" {
t.Fatalf("Expected test, got %v", data) t.Fatalf("Expected test, got %v", data)
} }
// Cleanup the test file
t.Log("Cleaning up test file") t.Log("Cleaning up test file")
err = os.Remove("/tmp/test.txt") err = os.Remove("/tmp/test.txt")
// Check if there was an error
if err != nil { if err != nil {
t.Fatalf("Error cleaning up test file: %v", err) t.Fatalf("Error cleaning up test file: %v", err)
} }
} }
// Test the parse file to line function
func TestParseFileToLine(t *testing.T) { func TestParseFileToLine(t *testing.T) {
t.Log("Testing parse file to line with a valid string") t.Log("Testing parse file to line with a valid string")
// Test the parse file to line function with a valid string
content := "\nuser1:pass1\nuser2:pass2\n" content := "\nuser1:pass1\nuser2:pass2\n"
expected := "user1:pass1,user2:pass2" expected := "user1:pass1,user2:pass2"
result := utils.ParseFileToLine(content) result := utils.ParseFileToLine(content)
// Check if the result is equal to the expected
if expected != result { if expected != result {
t.Fatalf("Expected %v, got %v", expected, result) t.Fatalf("Expected %v, got %v", expected, result)
} }
} }
// Test the get secret function
func TestGetSecret(t *testing.T) { func TestGetSecret(t *testing.T) {
t.Log("Testing get secret with an empty config and file") t.Log("Testing get secret with an empty config and file")
// Test the get secret function with an empty config and file
conf := "" conf := ""
file := "/tmp/test.txt" file := "/tmp/test.txt"
expected := "test" expected := "test"
// Create file
err := os.WriteFile(file, []byte(fmt.Sprintf("\n\n \n\n\n %s \n\n \n ", expected)), 0644) err := os.WriteFile(file, []byte(fmt.Sprintf("\n\n \n\n\n %s \n\n \n ", expected)), 0644)
// Check if there was an error
if err != nil { if err != nil {
t.Fatalf("Error creating test file: %v", err) t.Fatalf("Error creating test file: %v", err)
} }
// Test
result := utils.GetSecret(conf, file) result := utils.GetSecret(conf, file)
// Check if the result is equal to the expected
if result != expected { if result != expected {
t.Fatalf("Expected %v, got %v", expected, result) t.Fatalf("Expected %v, got %v", expected, result)
} }
t.Log("Testing get secret with an empty file and a valid config") t.Log("Testing get secret with an empty file and a valid config")
// Test the get secret function with an empty file and a valid config
result = utils.GetSecret(expected, "") result = utils.GetSecret(expected, "")
// Check if the result is equal to the expected
if result != expected { if result != expected {
t.Fatalf("Expected %v, got %v", expected, result) t.Fatalf("Expected %v, got %v", expected, result)
} }
t.Log("Testing get secret with both a valid config and file") t.Log("Testing get secret with both a valid config and file")
// Test the get secret function with both a valid config and file
result = utils.GetSecret(expected, file) result = utils.GetSecret(expected, file)
// Check if the result is equal to the expected
if result != expected { if result != expected {
t.Fatalf("Expected %v, got %v", expected, result) t.Fatalf("Expected %v, got %v", expected, result)
} }
// Cleanup the test file
t.Log("Cleaning up test file") t.Log("Cleaning up test file")
err = os.Remove(file) err = os.Remove(file)
// Check if there was an error
if err != nil { if err != nil {
t.Fatalf("Error cleaning up test file: %v", err) t.Fatalf("Error cleaning up test file: %v", err)
} }
} }
// Test the get users function
func TestGetUsers(t *testing.T) { func TestGetUsers(t *testing.T) {
t.Log("Testing get users with a config and no file") t.Log("Testing get users with a config and no file")
// Test the get users function with a config and no file
conf := "user1:pass1,user2:pass2" conf := "user1:pass1,user2:pass2"
file := "" file := ""
expected := types.Users{ expected := types.Users{
@@ -189,20 +149,16 @@ func TestGetUsers(t *testing.T) {
} }
result, err := utils.GetUsers(conf, file) result, err := utils.GetUsers(conf, file)
// Check if there was an error
if err != nil { if err != nil {
t.Fatalf("Error getting users: %v", err) t.Fatalf("Error getting users: %v", err)
} }
// Check if the result is equal to the expected
if !reflect.DeepEqual(expected, result) { if !reflect.DeepEqual(expected, result) {
t.Fatalf("Expected %v, got %v", expected, result) t.Fatalf("Expected %v, got %v", expected, result)
} }
t.Log("Testing get users with a file and no config") t.Log("Testing get users with a file and no config")
// Test the get users function with a file and no config
conf = "" conf = ""
file = "/tmp/test.txt" file = "/tmp/test.txt"
expected = types.Users{ expected = types.Users{
@@ -216,28 +172,20 @@ func TestGetUsers(t *testing.T) {
}, },
} }
// Create file
err = os.WriteFile(file, []byte("user1:pass1\nuser2:pass2"), 0644) err = os.WriteFile(file, []byte("user1:pass1\nuser2:pass2"), 0644)
// Check if there was an error
if err != nil { if err != nil {
t.Fatalf("Error creating test file: %v", err) t.Fatalf("Error creating test file: %v", err)
} }
// Test
result, err = utils.GetUsers(conf, file) result, err = utils.GetUsers(conf, file)
// Check if there was an error
if err != nil { if err != nil {
t.Fatalf("Error getting users: %v", err) t.Fatalf("Error getting users: %v", err)
} }
// Check if the result is equal to the expected
if !reflect.DeepEqual(expected, result) { if !reflect.DeepEqual(expected, result) {
t.Fatalf("Expected %v, got %v", expected, result) t.Fatalf("Expected %v, got %v", expected, result)
} }
// Test the get users function with both a config and file
t.Log("Testing get users with both a config and file") t.Log("Testing get users with both a config and file")
conf = "user3:pass3" conf = "user3:pass3"
@@ -257,33 +205,25 @@ func TestGetUsers(t *testing.T) {
} }
result, err = utils.GetUsers(conf, file) result, err = utils.GetUsers(conf, file)
// Check if there was an error
if err != nil { if err != nil {
t.Fatalf("Error getting users: %v", err) t.Fatalf("Error getting users: %v", err)
} }
// Check if the result is equal to the expected
if !reflect.DeepEqual(expected, result) { if !reflect.DeepEqual(expected, result) {
t.Fatalf("Expected %v, got %v", expected, result) t.Fatalf("Expected %v, got %v", expected, result)
} }
// Cleanup the test file
t.Log("Cleaning up test file") t.Log("Cleaning up test file")
err = os.Remove(file) err = os.Remove(file)
// Check if there was an error
if err != nil { if err != nil {
t.Fatalf("Error cleaning up test file: %v", err) t.Fatalf("Error cleaning up test file: %v", err)
} }
} }
// Test the get labels function
func TestGetLabels(t *testing.T) { func TestGetLabels(t *testing.T) {
t.Log("Testing get labels with a valid map") t.Log("Testing get labels with a valid map")
// Test the get tinyauth labels function with a valid map
labels := map[string]string{ labels := map[string]string{
"tinyauth.users": "user1,user2", "tinyauth.users": "user1,user2",
"tinyauth.oauth.whitelist": "/regex/", "tinyauth.oauth.whitelist": "/regex/",
@@ -303,23 +243,18 @@ func TestGetLabels(t *testing.T) {
} }
result, err := utils.GetLabels(labels) result, err := utils.GetLabels(labels)
// Check if there was an error
if err != nil { if err != nil {
t.Fatalf("Error getting labels: %v", err) t.Fatalf("Error getting labels: %v", err)
} }
// Check if the result is equal to the expected
if !reflect.DeepEqual(expected, result) { if !reflect.DeepEqual(expected, result) {
t.Fatalf("Expected %v, got %v", expected, result) t.Fatalf("Expected %v, got %v", expected, result)
} }
} }
// Test parse user
func TestParseUser(t *testing.T) { func TestParseUser(t *testing.T) {
t.Log("Testing parse user with a valid user") t.Log("Testing parse user with a valid user")
// Create variables
user := "user:pass:secret" user := "user:pass:secret"
expected := types.User{ expected := types.User{
Username: "user", Username: "user",
@@ -327,22 +262,17 @@ func TestParseUser(t *testing.T) {
TotpSecret: "secret", TotpSecret: "secret",
} }
// Test the parse user function
result, err := utils.ParseUser(user) result, err := utils.ParseUser(user)
// Check if there was an error
if err != nil { if err != nil {
t.Fatalf("Error parsing user: %v", err) t.Fatalf("Error parsing user: %v", err)
} }
// Check if the result is equal to the expected
if !reflect.DeepEqual(expected, result) { if !reflect.DeepEqual(expected, result) {
t.Fatalf("Expected %v, got %v", expected, result) t.Fatalf("Expected %v, got %v", expected, result)
} }
t.Log("Testing parse user with an escaped user") t.Log("Testing parse user with an escaped user")
// Create variables
user = "user:p$$ass$$:secret" user = "user:p$$ass$$:secret"
expected = types.User{ expected = types.User{
Username: "user", Username: "user",
@@ -350,304 +280,233 @@ func TestParseUser(t *testing.T) {
TotpSecret: "secret", TotpSecret: "secret",
} }
// Test the parse user function
result, err = utils.ParseUser(user) result, err = utils.ParseUser(user)
// Check if there was an error
if err != nil { if err != nil {
t.Fatalf("Error parsing user: %v", err) t.Fatalf("Error parsing user: %v", err)
} }
// Check if the result is equal to the expected
if !reflect.DeepEqual(expected, result) { if !reflect.DeepEqual(expected, result) {
t.Fatalf("Expected %v, got %v", expected, result) t.Fatalf("Expected %v, got %v", expected, result)
} }
t.Log("Testing parse user with an invalid user") t.Log("Testing parse user with an invalid user")
// Create variables
user = "user::pass" user = "user::pass"
// Test the parse user function
_, err = utils.ParseUser(user) _, err = utils.ParseUser(user)
// Check if there was an error
if err == nil { if err == nil {
t.Fatalf("Expected error parsing user") t.Fatalf("Expected error parsing user")
} }
} }
// Test the check filter function
func TestCheckFilter(t *testing.T) { func TestCheckFilter(t *testing.T) {
t.Log("Testing check filter with a comma separated list") t.Log("Testing check filter with a comma separated list")
// Create variables
filter := "user1,user2,user3" filter := "user1,user2,user3"
str := "user1" str := "user1"
expected := true expected := true
// Test the check filter function
result := utils.CheckFilter(filter, str) result := utils.CheckFilter(filter, str)
// Check if the result is equal to the expected
if result != expected { if result != expected {
t.Fatalf("Expected %v, got %v", expected, result) t.Fatalf("Expected %v, got %v", expected, result)
} }
t.Log("Testing check filter with a regex filter") t.Log("Testing check filter with a regex filter")
// Create variables
filter = "/^user[0-9]+$/" filter = "/^user[0-9]+$/"
str = "user1" str = "user1"
expected = true expected = true
// Test the check filter function
result = utils.CheckFilter(filter, str) result = utils.CheckFilter(filter, str)
// Check if the result is equal to the expected
if result != expected { if result != expected {
t.Fatalf("Expected %v, got %v", expected, result) t.Fatalf("Expected %v, got %v", expected, result)
} }
t.Log("Testing check filter with an empty filter") t.Log("Testing check filter with an empty filter")
// Create variables
filter = "" filter = ""
str = "user1" str = "user1"
expected = true expected = true
// Test the check filter function
result = utils.CheckFilter(filter, str) result = utils.CheckFilter(filter, str)
// Check if the result is equal to the expected
if result != expected { if result != expected {
t.Fatalf("Expected %v, got %v", expected, result) t.Fatalf("Expected %v, got %v", expected, result)
} }
t.Log("Testing check filter with an invalid regex filter") t.Log("Testing check filter with an invalid regex filter")
// Create variables
filter = "/^user[0-9+$/" filter = "/^user[0-9+$/"
str = "user1" str = "user1"
expected = false expected = false
// Test the check filter function
result = utils.CheckFilter(filter, str) result = utils.CheckFilter(filter, str)
// Check if the result is equal to the expected
if result != expected { if result != expected {
t.Fatalf("Expected %v, got %v", expected, result) t.Fatalf("Expected %v, got %v", expected, result)
} }
t.Log("Testing check filter with a non matching list") t.Log("Testing check filter with a non matching list")
// Create variables
filter = "user1,user2,user3" filter = "user1,user2,user3"
str = "user4" str = "user4"
expected = false expected = false
// Test the check filter function
result = utils.CheckFilter(filter, str) result = utils.CheckFilter(filter, str)
// Check if the result is equal to the expected
if result != expected { if result != expected {
t.Fatalf("Expected %v, got %v", expected, result) t.Fatalf("Expected %v, got %v", expected, result)
} }
} }
// Test the header sanitizer
func TestSanitizeHeader(t *testing.T) { func TestSanitizeHeader(t *testing.T) {
t.Log("Testing sanitize header with a valid string") t.Log("Testing sanitize header with a valid string")
// Create variables
str := "X-Header=value" str := "X-Header=value"
expected := "X-Header=value" expected := "X-Header=value"
// Test the sanitize header function
result := utils.SanitizeHeader(str) result := utils.SanitizeHeader(str)
// Check if the result is equal to the expected
if result != expected { if result != expected {
t.Fatalf("Expected %v, got %v", expected, result) t.Fatalf("Expected %v, got %v", expected, result)
} }
t.Log("Testing sanitize header with an invalid string") t.Log("Testing sanitize header with an invalid string")
// Create variables
str = "X-Header=val\nue" str = "X-Header=val\nue"
expected = "X-Header=value" expected = "X-Header=value"
// Test the sanitize header function
result = utils.SanitizeHeader(str) result = utils.SanitizeHeader(str)
// Check if the result is equal to the expected
if result != expected { if result != expected {
t.Fatalf("Expected %v, got %v", expected, result) t.Fatalf("Expected %v, got %v", expected, result)
} }
} }
// Test the parse headers function
func TestParseHeaders(t *testing.T) { func TestParseHeaders(t *testing.T) {
t.Log("Testing parse headers with a valid string") t.Log("Testing parse headers with a valid string")
// Create variables
headers := []string{"X-Hea\x00der1=value1", "X-Header2=value\n2"} headers := []string{"X-Hea\x00der1=value1", "X-Header2=value\n2"}
expected := map[string]string{ expected := map[string]string{
"X-Header1": "value1", "X-Header1": "value1",
"X-Header2": "value2", "X-Header2": "value2",
} }
// Test the parse headers function
result := utils.ParseHeaders(headers) result := utils.ParseHeaders(headers)
// Check if the result is equal to the expected
if !reflect.DeepEqual(expected, result) { if !reflect.DeepEqual(expected, result) {
t.Fatalf("Expected %v, got %v", expected, result) t.Fatalf("Expected %v, got %v", expected, result)
} }
t.Log("Testing parse headers with an invalid string") t.Log("Testing parse headers with an invalid string")
// Create variables
headers = []string{"X-Header1=", "X-Header2", "=value", "X-Header3=value3"} headers = []string{"X-Header1=", "X-Header2", "=value", "X-Header3=value3"}
expected = map[string]string{"X-Header3": "value3"} expected = map[string]string{"X-Header3": "value3"}
// Test the parse headers function
result = utils.ParseHeaders(headers) result = utils.ParseHeaders(headers)
// Check if the result is equal to the expected
if !reflect.DeepEqual(expected, result) { if !reflect.DeepEqual(expected, result) {
t.Fatalf("Expected %v, got %v", expected, result) t.Fatalf("Expected %v, got %v", expected, result)
} }
} }
// Test the parse secret file function
func TestParseSecretFile(t *testing.T) { func TestParseSecretFile(t *testing.T) {
t.Log("Testing parse secret file with a valid file") t.Log("Testing parse secret file with a valid file")
// Create variables
content := "\n\n \n\n\n secret \n\n \n " content := "\n\n \n\n\n secret \n\n \n "
expected := "secret" expected := "secret"
// Test the parse secret file function
result := utils.ParseSecretFile(content) result := utils.ParseSecretFile(content)
// Check if the result is equal to the expected
if result != expected { if result != expected {
t.Fatalf("Expected %v, got %v", expected, result) t.Fatalf("Expected %v, got %v", expected, result)
} }
} }
// Test the filter IP function
func TestFilterIP(t *testing.T) { func TestFilterIP(t *testing.T) {
t.Log("Testing filter IP with an IP and a valid CIDR") t.Log("Testing filter IP with an IP and a valid CIDR")
// Create variables
ip := "10.10.10.10" ip := "10.10.10.10"
filter := "10.10.10.0/24" filter := "10.10.10.0/24"
expected := true expected := true
// Test the filter IP function
result, err := utils.FilterIP(filter, ip) result, err := utils.FilterIP(filter, ip)
// Check if there was an error
if err != nil { if err != nil {
t.Fatalf("Error filtering IP: %v", err) t.Fatalf("Error filtering IP: %v", err)
} }
// Check if the result is equal to the expected
if result != expected { if result != expected {
t.Fatalf("Expected %v, got %v", expected, result) t.Fatalf("Expected %v, got %v", expected, result)
} }
t.Log("Testing filter IP with an IP and a valid IP") t.Log("Testing filter IP with an IP and a valid IP")
// Create variables
filter = "10.10.10.10" filter = "10.10.10.10"
expected = true expected = true
// Test the filter IP function
result, err = utils.FilterIP(filter, ip) result, err = utils.FilterIP(filter, ip)
// Check if there was an error
if err != nil { if err != nil {
t.Fatalf("Error filtering IP: %v", err) t.Fatalf("Error filtering IP: %v", err)
} }
// Check if the result is equal to the expected
if result != expected { if result != expected {
t.Fatalf("Expected %v, got %v", expected, result) t.Fatalf("Expected %v, got %v", expected, result)
} }
t.Log("Testing filter IP with an IP and an non matching CIDR") t.Log("Testing filter IP with an IP and an non matching CIDR")
// Create variables
filter = "10.10.15.0/24" filter = "10.10.15.0/24"
expected = false expected = false
// Test the filter IP function
result, err = utils.FilterIP(filter, ip) result, err = utils.FilterIP(filter, ip)
// Check if there was an error
if err != nil { if err != nil {
t.Fatalf("Error filtering IP: %v", err) t.Fatalf("Error filtering IP: %v", err)
} }
// Check if the result is equal to the expected
if result != expected { if result != expected {
t.Fatalf("Expected %v, got %v", expected, result) t.Fatalf("Expected %v, got %v", expected, result)
} }
t.Log("Testing filter IP with a non matching IP and a valid CIDR") t.Log("Testing filter IP with a non matching IP and a valid CIDR")
// Create variables
filter = "10.10.10.11" filter = "10.10.10.11"
expected = false expected = false
// Test the filter IP function
result, err = utils.FilterIP(filter, ip) result, err = utils.FilterIP(filter, ip)
// Check if there was an error
if err != nil { if err != nil {
t.Fatalf("Error filtering IP: %v", err) t.Fatalf("Error filtering IP: %v", err)
} }
// Check if the result is equal to the expected
if result != expected { if result != expected {
t.Fatalf("Expected %v, got %v", expected, result) t.Fatalf("Expected %v, got %v", expected, result)
} }
t.Log("Testing filter IP with an IP and an invalid CIDR") t.Log("Testing filter IP with an IP and an invalid CIDR")
// Create variables
filter = "10.../83" filter = "10.../83"
// Test the filter IP function
_, err = utils.FilterIP(filter, ip) _, err = utils.FilterIP(filter, ip)
// Check if there was an error
if err == nil { if err == nil {
t.Fatalf("Expected error filtering IP") t.Fatalf("Expected error filtering IP")
} }
} }
// Test the derive key function
func TestDeriveKey(t *testing.T) { func TestDeriveKey(t *testing.T) {
t.Log("Testing the derive key function") t.Log("Testing the derive key function")
// Create variables
master := "master" master := "master"
info := "info" info := "info"
expected := "gdrdU/fXzclYjiSXRexEatVgV13qQmKl" expected := "gdrdU/fXzclYjiSXRexEatVgV13qQmKl"
// Test the derive key function
result, err := utils.DeriveKey(master, info) result, err := utils.DeriveKey(master, info)
// Check if there was an error
if err != nil { if err != nil {
t.Fatalf("Error deriving key: %v", err) t.Fatalf("Error deriving key: %v", err)
} }
// Check if the result is equal to the expected
if result != expected { if result != expected {
t.Fatalf("Expected %v, got %v", expected, result) t.Fatalf("Expected %v, got %v", expected, result)
} }

View File

@@ -10,9 +10,6 @@ import (
) )
func main() { func main() {
// Logger
log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr, TimeFormat: time.RFC3339}).With().Timestamp().Logger().Level(zerolog.FatalLevel) log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr, TimeFormat: time.RFC3339}).With().Timestamp().Logger().Level(zerolog.FatalLevel)
// Run cmd
cmd.Execute() cmd.Execute()
} }