chore: add comments to code

This commit is contained in:
Stavros
2025-02-08 12:33:58 +02:00
parent e09f241364
commit 7a3a463489
17 changed files with 485 additions and 92 deletions

View File

@@ -125,20 +125,24 @@ var rootCmd = &cobra.Command{
func Execute() { func Execute() {
err := rootCmd.Execute() err := rootCmd.Execute()
if err != nil { HandleError(err, "Failed to execute root command")
log.Fatal().Err(err).Msg("Failed to execute command")
}
} }
func HandleError(err error, msg string) { func HandleError(err error, msg string) {
// If error log it and exit
if err != nil { if err != nil {
log.Fatal().Err(err).Msg(msg) log.Fatal().Err(err).Msg(msg)
} }
} }
func init() { func init() {
// Add user command
rootCmd.AddCommand(cmd.UserCmd()) rootCmd.AddCommand(cmd.UserCmd())
// Read environment variables
viper.AutomaticEnv() viper.AutomaticEnv()
// Flags
rootCmd.Flags().Int("port", 3000, "Port to run the server on.") rootCmd.Flags().Int("port", 3000, "Port to run the server on.")
rootCmd.Flags().String("address", "0.0.0.0", "Address to bind the server to.") rootCmd.Flags().String("address", "0.0.0.0", "Address to bind the server to.")
rootCmd.Flags().String("secret", "", "Secret to use for the cookie.") rootCmd.Flags().String("secret", "", "Secret to use for the cookie.")
@@ -167,6 +171,8 @@ func init() {
rootCmd.Flags().String("oauth-whitelist", "", "Comma separated list of email addresses to whitelist when using OAuth.") rootCmd.Flags().String("oauth-whitelist", "", "Comma separated list of email addresses to whitelist when using OAuth.")
rootCmd.Flags().Int("session-expiry", 86400, "Session (cookie) expiration time in seconds.") rootCmd.Flags().Int("session-expiry", 86400, "Session (cookie) expiration time in seconds.")
rootCmd.Flags().Int("log-level", 1, "Log level.") rootCmd.Flags().Int("log-level", 1, "Log level.")
// Bind flags to environment
viper.BindEnv("port", "PORT") viper.BindEnv("port", "PORT")
viper.BindEnv("address", "ADDRESS") viper.BindEnv("address", "ADDRESS")
viper.BindEnv("secret", "SECRET") viper.BindEnv("secret", "SECRET")
@@ -195,5 +201,7 @@ func init() {
viper.BindEnv("oauth-whitelist", "OAUTH_WHITELIST") viper.BindEnv("oauth-whitelist", "OAUTH_WHITELIST")
viper.BindEnv("session-expiry", "SESSION_EXPIRY") viper.BindEnv("session-expiry", "SESSION_EXPIRY")
viper.BindEnv("log-level", "LOG_LEVEL") viper.BindEnv("log-level", "LOG_LEVEL")
// Bind flags to viper
viper.BindPFlags(rootCmd.Flags()) viper.BindPFlags(rootCmd.Flags())
} }

View File

@@ -22,9 +22,12 @@ var CreateCmd = &cobra.Command{
Short: "Create a user", Short: "Create a user",
Long: `Create a user either interactively or by passing flags.`, Long: `Create a user either interactively or by passing flags.`,
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
// Setup logger
log.Logger = log.Level(zerolog.InfoLevel) log.Logger = log.Level(zerolog.InfoLevel)
// Check if interactive
if interactive { if interactive {
// Create huh form
form := huh.NewForm( form := huh.NewForm(
huh.NewGroup( huh.NewGroup(
huh.NewInput().Title("Username").Value(&username).Validate((func(s string) error { huh.NewInput().Title("Username").Value(&username).Validate((func(s string) error {
@@ -43,6 +46,7 @@ var CreateCmd = &cobra.Command{
), ),
) )
// Use simple theme
var baseTheme *huh.Theme = huh.ThemeBase() var baseTheme *huh.Theme = huh.ThemeBase()
formErr := form.WithTheme(baseTheme).Run() formErr := form.WithTheme(baseTheme).Run()
@@ -52,12 +56,14 @@ var CreateCmd = &cobra.Command{
} }
} }
// Do we have username and password?
if username == "" || password == "" { if username == "" || password == "" {
log.Error().Msg("Username and password cannot be empty") log.Error().Msg("Username and password cannot be empty")
} }
log.Info().Str("username", username).Str("password", password).Bool("docker", docker).Msg("Creating user") log.Info().Str("username", username).Str("password", password).Bool("docker", docker).Msg("Creating user")
// Hash password
passwordByte, passwordErr := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) passwordByte, passwordErr := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if passwordErr != nil { if passwordErr != nil {
@@ -66,15 +72,18 @@ var CreateCmd = &cobra.Command{
passwordString := string(passwordByte) passwordString := string(passwordByte)
// Escape $ for docker
if docker { if docker {
passwordString = strings.ReplaceAll(passwordString, "$", "$$") passwordString = strings.ReplaceAll(passwordString, "$", "$$")
} }
// Log user created
log.Info().Str("user", fmt.Sprintf("%s:%s", username, passwordString)).Msg("User created") log.Info().Str("user", fmt.Sprintf("%s:%s", username, passwordString)).Msg("User created")
}, },
} }
func init() { func init() {
// Flags
CreateCmd.Flags().BoolVar(&interactive, "interactive", false, "Create a user interactively") CreateCmd.Flags().BoolVar(&interactive, "interactive", false, "Create a user interactively")
CreateCmd.Flags().BoolVar(&docker, "docker", false, "Format output for docker") CreateCmd.Flags().BoolVar(&docker, "docker", false, "Format output for docker")
CreateCmd.Flags().StringVar(&username, "username", "", "Username") CreateCmd.Flags().StringVar(&username, "username", "", "Username")

View File

@@ -22,9 +22,12 @@ var VerifyCmd = &cobra.Command{
Short: "Verify a user is set up correctly", Short: "Verify a user is set up correctly",
Long: `Verify a user is set up correctly meaning that it has a correct username and password.`, Long: `Verify a user is set up correctly meaning that it has a correct username and password.`,
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
// Setup logger
log.Logger = log.Level(zerolog.InfoLevel) log.Logger = log.Level(zerolog.InfoLevel)
// Check if interactive
if interactive { if interactive {
// Create huh form
form := huh.NewForm( form := huh.NewForm(
huh.NewGroup( huh.NewGroup(
huh.NewInput().Title("User (username:hash)").Value(&user).Validate((func(s string) error { huh.NewInput().Title("User (username:hash)").Value(&user).Validate((func(s string) error {
@@ -49,6 +52,7 @@ var VerifyCmd = &cobra.Command{
), ),
) )
// Use simple theme
var baseTheme *huh.Theme = huh.ThemeBase() var baseTheme *huh.Theme = huh.ThemeBase()
formErr := form.WithTheme(baseTheme).Run() formErr := form.WithTheme(baseTheme).Run()
@@ -58,22 +62,26 @@ var VerifyCmd = &cobra.Command{
} }
} }
// Do we have username, password and user?
if username == "" || password == "" || user == "" { if username == "" || password == "" || user == "" {
log.Fatal().Msg("Username, password and user cannot be empty") log.Fatal().Msg("Username, password and user cannot be empty")
} }
log.Info().Str("user", user).Str("username", username).Str("password", password).Bool("docker", docker).Msg("Verifying user") log.Info().Str("user", user).Str("username", username).Str("password", password).Bool("docker", docker).Msg("Verifying user")
// Split username and password
userSplit := strings.Split(user, ":") userSplit := strings.Split(user, ":")
if userSplit[1] == "" { if userSplit[1] == "" {
log.Fatal().Msg("User is not formatted correctly") log.Fatal().Msg("User is not formatted correctly")
} }
// Replace $$ with $ if formatted for docker
if docker { if docker {
userSplit[1] = strings.ReplaceAll(userSplit[1], "$$", "$") userSplit[1] = strings.ReplaceAll(userSplit[1], "$$", "$")
} }
// Compare username and password
verifyErr := bcrypt.CompareHashAndPassword([]byte(userSplit[1]), []byte(password)) verifyErr := bcrypt.CompareHashAndPassword([]byte(userSplit[1]), []byte(password))
if verifyErr != nil || username != userSplit[0] { if verifyErr != nil || username != userSplit[0] {
@@ -85,6 +93,7 @@ var VerifyCmd = &cobra.Command{
} }
func init() { func init() {
// Flags
VerifyCmd.Flags().BoolVarP(&interactive, "interactive", "i", false, "Create a user interactively") VerifyCmd.Flags().BoolVarP(&interactive, "interactive", "i", false, "Create a user interactively")
VerifyCmd.Flags().BoolVar(&docker, "docker", false, "Is the user formatted for docker?") VerifyCmd.Flags().BoolVar(&docker, "docker", false, "Is the user formatted for docker?")
VerifyCmd.Flags().StringVar(&username, "username", "", "Username") VerifyCmd.Flags().StringVar(&username, "username", "", "Username")

View File

@@ -41,11 +41,15 @@ type API struct {
} }
func (api *API) Init() { func (api *API) Init() {
// Disable gin logs
gin.SetMode(gin.ReleaseMode) gin.SetMode(gin.ReleaseMode)
// Create router and use zerolog for logs
log.Debug().Msg("Setting up router") log.Debug().Msg("Setting up router")
router := gin.New() router := gin.New()
router.Use(zerolog()) router.Use(zerolog())
// Read UI assets
log.Debug().Msg("Setting up assets") log.Debug().Msg("Setting up assets")
dist, distErr := fs.Sub(assets.Assets, "dist") dist, distErr := fs.Sub(assets.Assets, "dist")
@@ -53,11 +57,15 @@ func (api *API) Init() {
log.Fatal().Err(distErr).Msg("Failed to get UI assets") log.Fatal().Err(distErr).Msg("Failed to get UI assets")
} }
// Create file server
log.Debug().Msg("Setting up file server") log.Debug().Msg("Setting up file server")
fileServer := http.FileServer(http.FS(dist)) fileServer := http.FileServer(http.FS(dist))
// Setup cookie store
log.Debug().Msg("Setting up cookie store") log.Debug().Msg("Setting up cookie store")
store := cookie.NewStore([]byte(api.Config.Secret)) store := cookie.NewStore([]byte(api.Config.Secret))
// Get domain to use for session cookies
log.Debug().Msg("Getting domain") log.Debug().Msg("Getting domain")
domain, domainErr := utils.GetRootURL(api.Config.AppURL) domain, domainErr := utils.GetRootURL(api.Config.AppURL)
@@ -70,6 +78,7 @@ func (api *API) Init() {
api.Domain = fmt.Sprintf(".%s", domain) api.Domain = fmt.Sprintf(".%s", domain)
// Use session middleware
store.Options(sessions.Options{ store.Options(sessions.Options{
Domain: api.Domain, Domain: api.Domain,
Path: "/", Path: "/",
@@ -80,175 +89,169 @@ func (api *API) Init() {
router.Use(sessions.Sessions("tinyauth", store)) router.Use(sessions.Sessions("tinyauth", store))
// UI middleware
router.Use(func(c *gin.Context) { router.Use(func(c *gin.Context) {
// If not an API request, serve the UI
if !strings.HasPrefix(c.Request.URL.Path, "/api") { if !strings.HasPrefix(c.Request.URL.Path, "/api") {
_, err := fs.Stat(dist, strings.TrimPrefix(c.Request.URL.Path, "/")) _, err := fs.Stat(dist, strings.TrimPrefix(c.Request.URL.Path, "/"))
// If the file doesn't exist, serve the index.html
if os.IsNotExist(err) { if os.IsNotExist(err) {
c.Request.URL.Path = "/" c.Request.URL.Path = "/"
} }
// Serve the file
fileServer.ServeHTTP(c.Writer, c.Request) fileServer.ServeHTTP(c.Writer, c.Request)
// Stop further processing
c.Abort() c.Abort()
} }
}) })
// Set router
api.Router = router api.Router = router
} }
func (api *API) SetupRoutes() { func (api *API) SetupRoutes() {
api.Router.GET("/api/auth/:proxy", func(c *gin.Context) { api.Router.GET("/api/auth/:proxy", func(c *gin.Context) {
// Create struct for proxy
var proxy types.Proxy var proxy types.Proxy
// Bind URI
bindErr := c.BindUri(&proxy) bindErr := c.BindUri(&proxy)
// Handle error
if api.handleError(c, "Failed to bind URI", bindErr) { if api.handleError(c, "Failed to bind URI", bindErr) {
return return
} }
log.Debug().Interface("proxy", proxy.Proxy).Msg("Got proxy") log.Debug().Interface("proxy", proxy.Proxy).Msg("Got proxy")
// Get user context
userContext := api.Hooks.UseUserContext(c) userContext := api.Hooks.UseUserContext(c)
// Get headers
uri := c.Request.Header.Get("X-Forwarded-Uri") uri := c.Request.Header.Get("X-Forwarded-Uri")
proto := c.Request.Header.Get("X-Forwarded-Proto") proto := c.Request.Header.Get("X-Forwarded-Proto")
host := c.Request.Header.Get("X-Forwarded-Host") host := c.Request.Header.Get("X-Forwarded-Host")
// Check if user is logged in
if userContext.IsLoggedIn { if userContext.IsLoggedIn {
log.Debug().Msg("Authenticated") log.Debug().Msg("Authenticated")
// Check if user is allowed to access subdomain, if request is nginx.example.com the subdomain (resource) is nginx
appAllowed, appAllowedErr := api.Auth.ResourceAllowed(userContext, host) appAllowed, appAllowedErr := api.Auth.ResourceAllowed(userContext, host)
// Check if there was an error
if appAllowedErr != nil { if appAllowedErr != nil {
switch proxy.Proxy { // Return 501 if nginx is the proxy or if the request is using an Authorization header
case "nginx": if proxy.Proxy == "nginx" || c.GetHeader("Authorization") != "" {
log.Error().Err(appAllowedErr).Msg("Failed to check if resource is allowed") log.Error().Err(appAllowedErr).Msg("Failed to check if app is allowed")
c.JSON(501, gin.H{ c.JSON(501, gin.H{
"status": 501, "status": 501,
"message": "Internal Server Error", "message": "Internal Server Error",
}) })
return return
default: }
if c.GetHeader("Authorization") != "" {
log.Error().Err(appAllowedErr).Msg("Failed to check if resource is allowed") // Return the internal server error page
c.JSON(501, gin.H{ if api.handleError(c, "Failed to check if app is allowed", appAllowedErr) {
"status": 501, return
"message": "Internal Server Error",
})
return
}
if api.handleError(c, "Failed to check if resource is allowed", appAllowedErr) {
return
}
} }
} }
log.Debug().Bool("appAllowed", appAllowed).Msg("Checking if app is allowed") log.Debug().Bool("appAllowed", appAllowed).Msg("Checking if app is allowed")
// The user is not allowed to access the app
if !appAllowed { if !appAllowed {
log.Warn().Str("username", userContext.Username).Str("host", host).Msg("User not allowed") log.Warn().Str("username", userContext.Username).Str("host", host).Msg("User not allowed")
// Build query
queries, queryErr := query.Values(types.UnauthorizedQuery{ queries, queryErr := query.Values(types.UnauthorizedQuery{
Username: userContext.Username, Username: userContext.Username,
Resource: strings.Split(host, ".")[0], Resource: strings.Split(host, ".")[0],
}) })
// Check if there was an error
if queryErr != nil { if queryErr != nil {
switch proxy.Proxy { // Return 501 if nginx is the proxy or if the request is using an Authorization header
case "nginx": if proxy.Proxy == "nginx" || c.GetHeader("Authorization") != "" {
log.Error().Err(queryErr).Msg("Failed to build query") log.Error().Err(queryErr).Msg("Failed to build query")
c.JSON(501, gin.H{ c.JSON(501, gin.H{
"status": 501, "status": 501,
"message": "Internal Server Error", "message": "Internal Server Error",
}) })
return return
default: }
if c.GetHeader("Authorization") != "" {
log.Error().Err(appAllowedErr).Msg("Failed to build query") // Return the internal server error page
c.JSON(501, gin.H{ if api.handleError(c, "Failed to build query", queryErr) {
"status": 501, return
"message": "Internal Server Error",
})
return
}
if api.handleError(c, "Failed to build query", queryErr) {
return
}
} }
} }
switch proxy.Proxy { // Return 401 if nginx is the proxy or if the request is using an Authorization header
case "nginx": if proxy.Proxy == "nginx" || c.GetHeader("Authorization") != "" {
c.JSON(401, gin.H{ c.JSON(401, gin.H{
"status": 401, "status": 401,
"message": "Unauthorized", "message": "Unauthorized",
}) })
return return
default:
if c.GetHeader("Authorization") != "" {
c.JSON(401, gin.H{
"status": 401,
"message": "Unauthorized",
})
return
}
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", api.Config.AppURL, queries.Encode()))
return
} }
// We are using caddy/traefik so redirect
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", api.Config.AppURL, queries.Encode()))
// Stop further processing
return
} }
// The user is allowed to access the app
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"status": 200, "status": 200,
"message": "Authenticated", "message": "Authenticated",
}) })
// Stop further processing
return return
} }
switch proxy.Proxy { // The user is not logged in
case "nginx": log.Debug().Msg("Unauthorized")
// Return 401 if nginx is the proxy or if the request is using an Authorization header
if proxy.Proxy == "nginx" || c.GetHeader("Authorization") != "" {
c.JSON(401, gin.H{ c.JSON(401, gin.H{
"status": 401, "status": 401,
"message": "Unauthorized", "message": "Unauthorized",
}) })
return return
default:
if c.GetHeader("Authorization") != "" {
c.JSON(401, gin.H{
"status": 401,
"message": "Unauthorized",
})
return
}
queries, queryErr := query.Values(types.LoginQuery{
RedirectURI: fmt.Sprintf("%s://%s%s", proto, host, uri),
})
log.Debug().Interface("redirect_uri", fmt.Sprintf("%s://%s%s", proto, host, uri)).Msg("Redirecting to login")
if queryErr != nil {
switch proxy.Proxy {
case "nginx":
log.Error().Err(queryErr).Msg("Failed to build query")
c.JSON(501, gin.H{
"status": 501,
"message": "Internal Server Error",
})
return
default:
if api.handleError(c, "Failed to build query", queryErr) {
return
}
}
}
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/?%s", api.Config.AppURL, queries.Encode()))
} }
// Build query
queries, queryErr := query.Values(types.LoginQuery{
RedirectURI: fmt.Sprintf("%s://%s%s", proto, host, uri),
})
log.Debug().Interface("redirect_uri", fmt.Sprintf("%s://%s%s", proto, host, uri)).Msg("Redirecting to login")
// Handle error (no need to check for nginx/headers since we are sure we are using caddy/traefik)
if api.handleError(c, "Failed to build query", queryErr) {
return
}
// Redirect to login
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/?%s", api.Config.AppURL, queries.Encode()))
}) })
api.Router.POST("/api/login", func(c *gin.Context) { api.Router.POST("/api/login", func(c *gin.Context) {
// Create login struct
var login types.LoginRequest var login types.LoginRequest
// Bind JSON
err := c.BindJSON(&login) err := c.BindJSON(&login)
// Handle error
if err != nil { if err != nil {
log.Error().Err(err).Msg("Failed to bind JSON") log.Error().Err(err).Msg("Failed to bind JSON")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
@@ -260,8 +263,10 @@ func (api *API) SetupRoutes() {
log.Debug().Msg("Got login request") log.Debug().Msg("Got login request")
// Get user based on username
user := api.Auth.GetUser(login.Username) user := api.Auth.GetUser(login.Username)
// User does not exist
if user == nil { if user == nil {
log.Debug().Str("username", login.Username).Msg("User not found") log.Debug().Str("username", login.Username).Msg("User not found")
c.JSON(401, gin.H{ c.JSON(401, gin.H{
@@ -271,6 +276,9 @@ func (api *API) SetupRoutes() {
return return
} }
log.Debug().Msg("Got user")
// Check if password is correct
if !api.Auth.CheckPassword(*user, login.Password) { if !api.Auth.CheckPassword(*user, login.Password) {
log.Debug().Str("username", login.Username).Msg("Password incorrect") log.Debug().Str("username", login.Username).Msg("Password incorrect")
c.JSON(401, gin.H{ c.JSON(401, gin.H{
@@ -282,11 +290,13 @@ func (api *API) SetupRoutes() {
log.Debug().Msg("Password correct, logging in") log.Debug().Msg("Password correct, logging in")
// Create session cookie with username as provider
api.Auth.CreateSessionCookie(c, &types.SessionCookie{ api.Auth.CreateSessionCookie(c, &types.SessionCookie{
Username: login.Username, Username: login.Username,
Provider: "username", Provider: "username",
}) })
// Return logged in
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"status": 200, "status": 200,
"message": "Logged in", "message": "Logged in",
@@ -294,12 +304,17 @@ func (api *API) SetupRoutes() {
}) })
api.Router.POST("/api/logout", func(c *gin.Context) { api.Router.POST("/api/logout", func(c *gin.Context) {
log.Debug().Msg("Logging out")
// Delete session cookie
api.Auth.DeleteSessionCookie(c) api.Auth.DeleteSessionCookie(c)
log.Debug().Msg("Cleaning up redirect cookie") log.Debug().Msg("Cleaning up redirect cookie")
// Clean up redirect cookie if it exists
c.SetCookie("tinyauth_redirect_uri", "", -1, "/", api.Domain, api.Config.CookieSecure, true) c.SetCookie("tinyauth_redirect_uri", "", -1, "/", api.Domain, api.Config.CookieSecure, true)
// Return logged out
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"status": 200, "status": 200,
"message": "Logged out", "message": "Logged out",
@@ -308,19 +323,24 @@ func (api *API) SetupRoutes() {
api.Router.GET("/api/status", func(c *gin.Context) { api.Router.GET("/api/status", func(c *gin.Context) {
log.Debug().Msg("Checking status") log.Debug().Msg("Checking status")
// Get user context
userContext := api.Hooks.UseUserContext(c) userContext := api.Hooks.UseUserContext(c)
// Get configured providers
configuredProviders := api.Providers.GetConfiguredProviders() configuredProviders := api.Providers.GetConfiguredProviders()
// We have username/password configured so add it to our providers
if api.Auth.UserAuthConfigured() { if api.Auth.UserAuthConfigured() {
configuredProviders = append(configuredProviders, "username") configuredProviders = append(configuredProviders, "username")
} }
// We are not logged in so return unauthorized
if !userContext.IsLoggedIn { if !userContext.IsLoggedIn {
log.Debug().Msg("Unauthenticated") log.Debug().Msg("Unauthorized")
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"status": 200, "status": 200,
"message": "Unauthenticated", "message": "Unauthorized",
"username": "", "username": "",
"isLoggedIn": false, "isLoggedIn": false,
"oauth": false, "oauth": false,
@@ -333,6 +353,7 @@ func (api *API) SetupRoutes() {
log.Debug().Interface("userContext", userContext).Strs("configuredProviders", configuredProviders).Bool("disableContinue", api.Config.DisableContinue).Msg("Authenticated") log.Debug().Interface("userContext", userContext).Strs("configuredProviders", configuredProviders).Bool("disableContinue", api.Config.DisableContinue).Msg("Authenticated")
// We are logged in so return our user context
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"status": 200, "status": 200,
"message": "Authenticated", "message": "Authenticated",
@@ -345,18 +366,14 @@ func (api *API) SetupRoutes() {
}) })
}) })
api.Router.GET("/api/healthcheck", func(c *gin.Context) {
c.JSON(200, gin.H{
"status": 200,
"message": "OK",
})
})
api.Router.GET("/api/oauth/url/:provider", func(c *gin.Context) { api.Router.GET("/api/oauth/url/:provider", func(c *gin.Context) {
// Create struct for OAuth request
var request types.OAuthRequest var request types.OAuthRequest
// Bind URI
bindErr := c.BindUri(&request) bindErr := c.BindUri(&request)
// Handle error
if bindErr != nil { if bindErr != nil {
log.Error().Err(bindErr).Msg("Failed to bind URI") log.Error().Err(bindErr).Msg("Failed to bind URI")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
@@ -368,8 +385,10 @@ func (api *API) SetupRoutes() {
log.Debug().Msg("Got OAuth request") log.Debug().Msg("Got OAuth request")
// Check if provider exists
provider := api.Providers.GetProvider(request.Provider) provider := api.Providers.GetProvider(request.Provider)
// Provider does not exist
if provider == nil { if provider == nil {
c.JSON(404, gin.H{ c.JSON(404, gin.H{
"status": 404, "status": 404,
@@ -380,24 +399,38 @@ func (api *API) SetupRoutes() {
log.Debug().Str("provider", request.Provider).Msg("Got provider") log.Debug().Str("provider", request.Provider).Msg("Got provider")
// Get auth URL
authURL := provider.GetAuthURL() authURL := provider.GetAuthURL()
log.Debug().Msg("Got auth URL") log.Debug().Msg("Got auth URL")
// Get redirect URI
redirectURI := c.Query("redirect_uri") redirectURI := c.Query("redirect_uri")
// Set redirect cookie if redirect URI is provided
if redirectURI != "" { if redirectURI != "" {
log.Debug().Str("redirectURI", redirectURI).Msg("Setting redirect cookie") log.Debug().Str("redirectURI", redirectURI).Msg("Setting redirect cookie")
c.SetCookie("tinyauth_redirect_uri", redirectURI, 3600, "/", api.Domain, api.Config.CookieSecure, true) c.SetCookie("tinyauth_redirect_uri", redirectURI, 3600, "/", api.Domain, api.Config.CookieSecure, true)
} }
// Tailscale does not have an auth url so we create a random code (does not need to be secure) to avoid caching and send it
if request.Provider == "tailscale" { if request.Provider == "tailscale" {
// Build tailscale query
tailscaleQuery, tailscaleQueryErr := query.Values(types.TailscaleQuery{ tailscaleQuery, tailscaleQueryErr := query.Values(types.TailscaleQuery{
Code: (1000 + rand.IntN(9000)), // doesn't need to be secure, just there to avoid caching Code: (1000 + rand.IntN(9000)),
}) })
if api.handleError(c, "Failed to build query", tailscaleQueryErr) {
// Handle error
if tailscaleQueryErr != nil {
log.Error().Err(tailscaleQueryErr).Msg("Failed to build query")
c.JSON(500, gin.H{
"status": 500,
"message": "Internal Server Error",
})
return return
} }
// Return tailscale URL (immidiately redirects to the callback)
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"status": 200, "status": 200,
"message": "Ok", "message": "Ok",
@@ -406,6 +439,7 @@ func (api *API) SetupRoutes() {
return return
} }
// Return auth URL
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"status": 200, "status": 200,
"message": "Ok", "message": "Ok",
@@ -414,18 +448,23 @@ func (api *API) SetupRoutes() {
}) })
api.Router.GET("/api/oauth/callback/:provider", func(c *gin.Context) { api.Router.GET("/api/oauth/callback/:provider", func(c *gin.Context) {
// Create struct for OAuth request
var providerName types.OAuthRequest var providerName types.OAuthRequest
// Bind URI
bindErr := c.BindUri(&providerName) bindErr := c.BindUri(&providerName)
// Handle error
if api.handleError(c, "Failed to bind URI", bindErr) { if api.handleError(c, "Failed to bind URI", bindErr) {
return return
} }
log.Debug().Interface("provider", providerName.Provider).Msg("Got provider name") log.Debug().Interface("provider", providerName.Provider).Msg("Got provider name")
// Get code
code := c.Query("code") code := c.Query("code")
// Code empty so redirect to error
if code == "" { if code == "" {
log.Error().Msg("No code provided") log.Error().Msg("No code provided")
c.Redirect(http.StatusPermanentRedirect, "/error") c.Redirect(http.StatusPermanentRedirect, "/error")
@@ -434,51 +473,67 @@ func (api *API) SetupRoutes() {
log.Debug().Msg("Got code") log.Debug().Msg("Got code")
// Get provider
provider := api.Providers.GetProvider(providerName.Provider) provider := api.Providers.GetProvider(providerName.Provider)
log.Debug().Str("provider", providerName.Provider).Msg("Got provider") log.Debug().Str("provider", providerName.Provider).Msg("Got provider")
// Provider does not exist
if provider == nil { if provider == nil {
c.Redirect(http.StatusPermanentRedirect, "/not-found") c.Redirect(http.StatusPermanentRedirect, "/not-found")
return return
} }
// Exchange token (authenticates user)
_, tokenErr := provider.ExchangeToken(code) _, tokenErr := provider.ExchangeToken(code)
log.Debug().Msg("Got token") log.Debug().Msg("Got token")
// Handle error
if api.handleError(c, "Failed to exchange token", tokenErr) { if api.handleError(c, "Failed to exchange token", tokenErr) {
return return
} }
// Get email
email, emailErr := api.Providers.GetUser(providerName.Provider) email, emailErr := api.Providers.GetUser(providerName.Provider)
log.Debug().Str("email", email).Msg("Got email") log.Debug().Str("email", email).Msg("Got email")
// Handle error
if api.handleError(c, "Failed to get user", emailErr) { if api.handleError(c, "Failed to get user", emailErr) {
return return
} }
// Email is not whitelisted
if !api.Auth.EmailWhitelisted(email) { if !api.Auth.EmailWhitelisted(email) {
log.Warn().Str("email", email).Msg("Email not whitelisted") log.Warn().Str("email", email).Msg("Email not whitelisted")
// Build query
unauthorizedQuery, unauthorizedQueryErr := query.Values(types.UnauthorizedQuery{ unauthorizedQuery, unauthorizedQueryErr := query.Values(types.UnauthorizedQuery{
Username: email, Username: email,
}) })
// Handle error
if api.handleError(c, "Failed to build query", unauthorizedQueryErr) { if api.handleError(c, "Failed to build query", unauthorizedQueryErr) {
return return
} }
// Redirect to unauthorized
c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/unauthorized?%s", api.Config.AppURL, unauthorizedQuery.Encode())) c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/unauthorized?%s", api.Config.AppURL, unauthorizedQuery.Encode()))
} }
log.Debug().Msg("Email whitelisted") log.Debug().Msg("Email whitelisted")
// Create session cookie
api.Auth.CreateSessionCookie(c, &types.SessionCookie{ api.Auth.CreateSessionCookie(c, &types.SessionCookie{
Username: email, Username: email,
Provider: providerName.Provider, Provider: providerName.Provider,
}) })
// Get redirect URI
redirectURI, redirectURIErr := c.Cookie("tinyauth_redirect_uri") redirectURI, redirectURIErr := c.Cookie("tinyauth_redirect_uri")
// If it is empty it means that no redirect_uri was provided to the login screen so we just log in
if redirectURIErr != nil { if redirectURIErr != nil {
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"status": 200, "status": 200,
@@ -488,28 +543,44 @@ func (api *API) SetupRoutes() {
log.Debug().Str("redirectURI", redirectURI).Msg("Got redirect URI") log.Debug().Str("redirectURI", redirectURI).Msg("Got redirect URI")
// Clean up redirect cookie since we already have the value
c.SetCookie("tinyauth_redirect_uri", "", -1, "/", api.Domain, api.Config.CookieSecure, true) c.SetCookie("tinyauth_redirect_uri", "", -1, "/", api.Domain, api.Config.CookieSecure, true)
// Build query
redirectQuery, redirectQueryErr := query.Values(types.LoginQuery{ redirectQuery, redirectQueryErr := query.Values(types.LoginQuery{
RedirectURI: redirectURI, RedirectURI: redirectURI,
}) })
log.Debug().Msg("Got redirect query") log.Debug().Msg("Got redirect query")
// Handle error
if api.handleError(c, "Failed to build query", redirectQueryErr) { if api.handleError(c, "Failed to build query", redirectQueryErr) {
return return
} }
// Redirect to continue with the redirect URI
c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/continue?%s", api.Config.AppURL, redirectQuery.Encode())) c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/continue?%s", api.Config.AppURL, redirectQuery.Encode()))
}) })
// Simple healthcheck
api.Router.GET("/api/healthcheck", func(c *gin.Context) {
c.JSON(200, gin.H{
"status": 200,
"message": "OK",
})
})
} }
func (api *API) Run() { func (api *API) Run() {
log.Info().Str("address", api.Config.Address).Int("port", api.Config.Port).Msg("Starting server") log.Info().Str("address", api.Config.Address).Int("port", api.Config.Port).Msg("Starting server")
// Run server
api.Router.Run(fmt.Sprintf("%s:%d", api.Config.Address, api.Config.Port)) api.Router.Run(fmt.Sprintf("%s:%d", api.Config.Address, api.Config.Port))
} }
// handleError logs the error and redirects to the error page (only meant for stuff the user may access does not apply for login paths)
func (api *API) handleError(c *gin.Context, msg string, err error) bool { func (api *API) handleError(c *gin.Context, msg string, err error) bool {
// If error is not nil log it and redirect to error page also return true so we can stop further processing
if err != nil { if err != nil {
log.Error().Err(err).Msg(msg) log.Error().Err(err).Msg(msg)
c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/error", api.Config.AppURL)) c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/error", api.Config.AppURL))
@@ -518,19 +589,25 @@ func (api *API) handleError(c *gin.Context, msg string, err error) bool {
return false return false
} }
// zerolog is a middleware for gin that logs requests using zerolog
func zerolog() gin.HandlerFunc { func zerolog() gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
// Get initial time
tStart := time.Now() tStart := time.Now()
// Process request
c.Next() c.Next()
// Get status code, address, method and path
code := c.Writer.Status() code := c.Writer.Status()
address := c.Request.RemoteAddr address := c.Request.RemoteAddr
method := c.Request.Method method := c.Request.Method
path := c.Request.URL.Path path := c.Request.URL.Path
// Get latency
latency := time.Since(tStart).String() latency := time.Since(tStart).String()
// Log request
switch { switch {
case code >= 200 && code < 300: case code >= 200 && code < 300:
log.Info().Str("method", method).Str("path", path).Str("address", address).Int("status", code).Str("latency", latency).Msg("Request") log.Info().Str("method", method).Str("path", path).Str("address", address).Int("status", code).Str("latency", latency).Msg("Request")

View File

@@ -4,8 +4,12 @@ import (
"embed" "embed"
) )
// UI assets
//
//go:embed dist //go:embed dist
var Assets embed.FS var Assets embed.FS
// Version file
//
//go:embed version //go:embed version
var Version string var Version string

View File

@@ -31,6 +31,7 @@ type Auth struct {
} }
func (auth *Auth) GetUser(username string) *types.User { func (auth *Auth) GetUser(username string) *types.User {
// Loop through users and return the user if the username matches
for _, user := range auth.Users { for _, user := range auth.Users {
if user.Username == username { if user.Username == username {
return &user return &user
@@ -40,64 +41,93 @@ func (auth *Auth) GetUser(username string) *types.User {
} }
func (auth *Auth) CheckPassword(user types.User, password string) bool { func (auth *Auth) CheckPassword(user types.User, password string) bool {
hashedPasswordErr := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)) // Compare the hashed password with the password provided
return hashedPasswordErr == nil return bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)) == nil
} }
func (auth *Auth) EmailWhitelisted(emailSrc string) bool { func (auth *Auth) EmailWhitelisted(emailSrc string) bool {
// If the whitelist is empty, allow all emails
if len(auth.OAuthWhitelist) == 0 { if len(auth.OAuthWhitelist) == 0 {
return true return true
} }
// Loop through the whitelist and return true if the email matches
for _, email := range auth.OAuthWhitelist { for _, email := range auth.OAuthWhitelist {
if email == emailSrc { if email == emailSrc {
return true return true
} }
} }
// If no emails match, return false
return false return false
} }
func (auth *Auth) CreateSessionCookie(c *gin.Context, data *types.SessionCookie) { func (auth *Auth) CreateSessionCookie(c *gin.Context, data *types.SessionCookie) {
log.Debug().Msg("Creating session cookie") log.Debug().Msg("Creating session cookie")
// Get session
sessions := sessions.Default(c) sessions := sessions.Default(c)
log.Debug().Msg("Setting session cookie") log.Debug().Msg("Setting session cookie")
// Set data
sessions.Set("username", data.Username) sessions.Set("username", data.Username)
sessions.Set("provider", data.Provider) sessions.Set("provider", data.Provider)
sessions.Set("expiry", time.Now().Add(time.Duration(auth.SessionExpiry)*time.Second).Unix()) sessions.Set("expiry", time.Now().Add(time.Duration(auth.SessionExpiry)*time.Second).Unix())
// Save session
sessions.Save() sessions.Save()
} }
func (auth *Auth) DeleteSessionCookie(c *gin.Context) { func (auth *Auth) DeleteSessionCookie(c *gin.Context) {
log.Debug().Msg("Deleting session cookie") log.Debug().Msg("Deleting session cookie")
// Get session
sessions := sessions.Default(c) sessions := sessions.Default(c)
// Clear session
sessions.Clear() sessions.Clear()
// Save session
sessions.Save() sessions.Save()
} }
func (auth *Auth) GetSessionCookie(c *gin.Context) types.SessionCookie { func (auth *Auth) GetSessionCookie(c *gin.Context) types.SessionCookie {
log.Debug().Msg("Getting session cookie") log.Debug().Msg("Getting session cookie")
// Get session
sessions := sessions.Default(c) sessions := sessions.Default(c)
// Get data
cookieUsername := sessions.Get("username") cookieUsername := sessions.Get("username")
cookieProvider := sessions.Get("provider") cookieProvider := sessions.Get("provider")
cookieExpiry := sessions.Get("expiry") cookieExpiry := sessions.Get("expiry")
// Convert interfaces to correct types
username, usernameOk := cookieUsername.(string) username, usernameOk := cookieUsername.(string)
provider, providerOk := cookieProvider.(string) provider, providerOk := cookieProvider.(string)
expiry, expiryOk := cookieExpiry.(int64) expiry, expiryOk := cookieExpiry.(int64)
// Check if the cookie is invalid
if !usernameOk || !providerOk || !expiryOk { if !usernameOk || !providerOk || !expiryOk {
log.Warn().Msg("Session cookie invalid") log.Warn().Msg("Session cookie invalid")
return types.SessionCookie{} return types.SessionCookie{}
} }
// Check if the cookie has expired
if time.Now().Unix() > expiry { if time.Now().Unix() > expiry {
log.Warn().Msg("Session cookie expired") log.Warn().Msg("Session cookie expired")
// If it has, delete it
auth.DeleteSessionCookie(c) auth.DeleteSessionCookie(c)
// Return empty cookie
return types.SessionCookie{} return types.SessionCookie{}
} }
log.Debug().Str("username", username).Str("provider", provider).Int64("expiry", expiry).Msg("Parsed cookie") log.Debug().Str("username", username).Str("provider", provider).Int64("expiry", expiry).Msg("Parsed cookie")
// Return the cookie
return types.SessionCookie{ return types.SessionCookie{
Username: username, Username: username,
Provider: provider, Provider: provider,
@@ -105,42 +135,56 @@ func (auth *Auth) GetSessionCookie(c *gin.Context) types.SessionCookie {
} }
func (auth *Auth) UserAuthConfigured() bool { func (auth *Auth) UserAuthConfigured() bool {
// If there are users, return true
return len(auth.Users) > 0 return len(auth.Users) > 0
} }
func (auth *Auth) ResourceAllowed(context types.UserContext, host string) (bool, error) { func (auth *Auth) ResourceAllowed(context types.UserContext, host string) (bool, error) {
// Check if we have access to the Docker API
isConnected := auth.Docker.DockerConnected() isConnected := auth.Docker.DockerConnected()
// If we don't have access, it is assumed that the user has access
if !isConnected { if !isConnected {
log.Debug().Msg("Docker not connected, allowing access") log.Debug().Msg("Docker not connected, allowing access")
return true, nil return true, nil
} }
// Get the app ID from the host
appId := strings.Split(host, ".")[0] appId := strings.Split(host, ".")[0]
// Get the containers
containers, containersErr := auth.Docker.GetContainers() containers, containersErr := auth.Docker.GetContainers()
// If there is an error, return false
if containersErr != nil { if containersErr != nil {
return false, containersErr return false, containersErr
} }
log.Debug().Msg("Got containers") log.Debug().Msg("Got containers")
// Loop through the containers
for _, container := range containers { for _, container := range containers {
// Inspect the container
inspect, inspectErr := auth.Docker.InspectContainer(container.ID) inspect, inspectErr := auth.Docker.InspectContainer(container.ID)
// If there is an error, return false
if inspectErr != nil { if inspectErr != nil {
return false, inspectErr return false, inspectErr
} }
// Get the container name (for some reason it is /name)
containerName := strings.Split(inspect.Name, "/")[1] containerName := strings.Split(inspect.Name, "/")[1]
// There is a container with the same name as the app ID
if containerName == appId { if containerName == appId {
log.Debug().Str("container", containerName).Msg("Found container") log.Debug().Str("container", containerName).Msg("Found container")
// Get only the tinyauth labels in a struct
labels := utils.GetTinyauthLabels(inspect.Config.Labels) labels := utils.GetTinyauthLabels(inspect.Config.Labels)
log.Debug().Msg("Got labels") log.Debug().Msg("Got labels")
// If the container has an oauth whitelist, check if the user is in it
if context.OAuth && len(labels.OAuthWhitelist) != 0 { if context.OAuth && len(labels.OAuthWhitelist) != 0 {
log.Debug().Msg("Checking OAuth whitelist") log.Debug().Msg("Checking OAuth whitelist")
if slices.Contains(labels.OAuthWhitelist, context.Username) { if slices.Contains(labels.OAuthWhitelist, context.Username) {
@@ -149,6 +193,7 @@ func (auth *Auth) ResourceAllowed(context types.UserContext, host string) (bool,
return false, nil return false, nil
} }
// If the container has users, check if the user is in it
if len(labels.Users) != 0 { if len(labels.Users) != 0 {
log.Debug().Msg("Checking users") log.Debug().Msg("Checking users")
if slices.Contains(labels.Users, context.Username) { if slices.Contains(labels.Users, context.Username) {
@@ -162,32 +207,40 @@ func (auth *Auth) ResourceAllowed(context types.UserContext, host string) (bool,
log.Debug().Msg("No matching container found, allowing access") log.Debug().Msg("No matching container found, allowing access")
// If no matching container is found, allow access
return true, nil return true, nil
} }
func (auth *Auth) GetBasicAuth(c *gin.Context) types.User { func (auth *Auth) GetBasicAuth(c *gin.Context) types.User {
// Get the Authorization header
header := c.GetHeader("Authorization") header := c.GetHeader("Authorization")
// If the header is empty, return an empty user
if header == "" { if header == "" {
return types.User{} return types.User{}
} }
// Split the header
headerSplit := strings.Split(header, " ") headerSplit := strings.Split(header, " ")
if len(headerSplit) != 2 { if len(headerSplit) != 2 {
return types.User{} return types.User{}
} }
// Check if the header is Basic
if headerSplit[0] != "Basic" { if headerSplit[0] != "Basic" {
return types.User{} return types.User{}
} }
// Split the credentials
credentials := strings.Split(headerSplit[1], ":") credentials := strings.Split(headerSplit[1], ":")
// If the credentials are not in the correct format, return an empty user
if len(credentials) != 2 { if len(credentials) != 2 {
return types.User{} return types.User{}
} }
// Return the user
return types.User{ return types.User{
Username: credentials[0], Username: credentials[0],
Password: credentials[1], Password: credentials[1],

View File

@@ -1,5 +1,6 @@
package constants package constants
// TinyauthLabels is a list of labels that can be used in a tinyauth protected container
var TinyauthLabels = []string{ var TinyauthLabels = []string{
"tinyauth.oauth.whitelist", "tinyauth.oauth.whitelist",
"tinyauth.users", "tinyauth.users",

View File

@@ -18,39 +18,50 @@ type Docker struct {
} }
func (docker *Docker) Init() error { func (docker *Docker) Init() error {
// Create a new docker client
apiClient, err := client.NewClientWithOpts(client.FromEnv) apiClient, err := client.NewClientWithOpts(client.FromEnv)
// Check if there was an error
if err != nil { if err != nil {
return err return err
} }
// Set the context and api client
docker.Context = context.Background() docker.Context = context.Background()
docker.Client = apiClient docker.Client = apiClient
// Done
return nil return nil
} }
func (docker *Docker) GetContainers() ([]types.Container, error) { func (docker *Docker) GetContainers() ([]types.Container, error) {
// Get the list of containers
containers, err := docker.Client.ContainerList(docker.Context, container.ListOptions{}) containers, err := docker.Client.ContainerList(docker.Context, container.ListOptions{})
// Check if there was an error
if err != nil { if err != nil {
return nil, err return nil, err
} }
// Return the containers
return containers, nil return containers, nil
} }
func (docker *Docker) InspectContainer(containerId string) (types.ContainerJSON, error) { func (docker *Docker) InspectContainer(containerId string) (types.ContainerJSON, error) {
// Inspect the container
inspect, err := docker.Client.ContainerInspect(docker.Context, containerId) inspect, err := docker.Client.ContainerInspect(docker.Context, containerId)
// Check if there was an error
if err != nil { if err != nil {
return types.ContainerJSON{}, err return types.ContainerJSON{}, err
} }
// Return the inspect
return inspect, nil return inspect, nil
} }
func (docker *Docker) DockerConnected() bool { func (docker *Docker) DockerConnected() bool {
// Ping the docker client if there is an error it is not connected
_, err := docker.Client.Ping(docker.Context) _, err := docker.Client.Ping(docker.Context)
return err == nil return err == nil
} }

View File

@@ -22,13 +22,19 @@ type Hooks struct {
} }
func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext { func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext {
// Get session cookie and basic auth
cookie := hooks.Auth.GetSessionCookie(c) cookie := hooks.Auth.GetSessionCookie(c)
basic := hooks.Auth.GetBasicAuth(c) basic := hooks.Auth.GetBasicAuth(c)
// Check if basic auth is set
if basic.Username != "" { if basic.Username != "" {
log.Debug().Msg("Got basic auth") log.Debug().Msg("Got basic auth")
// Check if user exists and password is correct
user := hooks.Auth.GetUser(basic.Username) user := hooks.Auth.GetUser(basic.Username)
if user != nil && hooks.Auth.CheckPassword(*user, basic.Password) { if user != nil && hooks.Auth.CheckPassword(*user, basic.Password) {
// Return user context since we are logged in with basic auth
return types.UserContext{ return types.UserContext{
Username: basic.Username, Username: basic.Username,
IsLoggedIn: true, IsLoggedIn: true,
@@ -39,10 +45,15 @@ func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext {
} }
// Check if session cookie is username/password auth
if cookie.Provider == "username" { if cookie.Provider == "username" {
log.Debug().Msg("Provider is username") log.Debug().Msg("Provider is username")
// Check if user exists
if hooks.Auth.GetUser(cookie.Username) != nil { if hooks.Auth.GetUser(cookie.Username) != nil {
log.Debug().Msg("User exists") log.Debug().Msg("User exists")
// It exists so we are logged in
return types.UserContext{ return types.UserContext{
Username: cookie.Username, Username: cookie.Username,
IsLoggedIn: true, IsLoggedIn: true,
@@ -53,13 +64,22 @@ func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext {
} }
log.Debug().Msg("Provider is not username") log.Debug().Msg("Provider is not username")
// The provider is not username so we need to check if it is an oauth provider
provider := hooks.Providers.GetProvider(cookie.Provider) provider := hooks.Providers.GetProvider(cookie.Provider)
// If we have a provider with this name
if provider != nil { if provider != nil {
log.Debug().Msg("Provider exists") log.Debug().Msg("Provider exists")
// Check if the oauth email is whitelisted
if !hooks.Auth.EmailWhitelisted(cookie.Username) { if !hooks.Auth.EmailWhitelisted(cookie.Username) {
log.Error().Str("email", cookie.Username).Msg("Email is not whitelisted") log.Error().Str("email", cookie.Username).Msg("Email is not whitelisted")
// It isn't so we delete the cookie and return an empty context
hooks.Auth.DeleteSessionCookie(c) hooks.Auth.DeleteSessionCookie(c)
// Return empty context
return types.UserContext{ return types.UserContext{
Username: "", Username: "",
IsLoggedIn: false, IsLoggedIn: false,
@@ -67,7 +87,10 @@ func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext {
Provider: "", Provider: "",
} }
} }
log.Debug().Msg("Email is whitelisted") log.Debug().Msg("Email is whitelisted")
// Return user context since we are logged in with oauth
return types.UserContext{ return types.UserContext{
Username: cookie.Username, Username: cookie.Username,
IsLoggedIn: true, IsLoggedIn: true,
@@ -76,6 +99,7 @@ func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext {
} }
} }
// Neither basic auth or oauth is set so we return an empty context
return types.UserContext{ return types.UserContext{
Username: "", Username: "",
IsLoggedIn: false, IsLoggedIn: false,

View File

@@ -21,23 +21,33 @@ type OAuth struct {
} }
func (oauth *OAuth) Init() { func (oauth *OAuth) Init() {
// Create a new context and verifier
oauth.Context = context.Background() oauth.Context = context.Background()
oauth.Verifier = oauth2.GenerateVerifier() oauth.Verifier = oauth2.GenerateVerifier()
} }
func (oauth *OAuth) GetAuthURL() string { func (oauth *OAuth) GetAuthURL() string {
// Return the auth url
return oauth.Config.AuthCodeURL("state", oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(oauth.Verifier)) return oauth.Config.AuthCodeURL("state", oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(oauth.Verifier))
} }
func (oauth *OAuth) ExchangeToken(code string) (string, error) { func (oauth *OAuth) ExchangeToken(code string) (string, error) {
// Exchange the code for a token
token, err := oauth.Config.Exchange(oauth.Context, code, oauth2.VerifierOption(oauth.Verifier)) token, err := oauth.Config.Exchange(oauth.Context, code, oauth2.VerifierOption(oauth.Verifier))
// Check if there was an error
if err != nil { if err != nil {
return "", err return "", err
} }
// Set the token
oauth.Token = token oauth.Token = token
// Return the access token
return oauth.Token.AccessToken, nil return oauth.Token.AccessToken, nil
} }
func (oauth *OAuth) GetClient() *http.Client { func (oauth *OAuth) GetClient() *http.Client {
// Return the http client with the token set
return oauth.Config.Client(oauth.Context, oauth.Token) return oauth.Config.Client(oauth.Context, oauth.Token)
} }

View File

@@ -8,36 +8,45 @@ import (
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
) )
// We are assuming that the generic provider will return a JSON object with an email field
type GenericUserInfoResponse struct { type GenericUserInfoResponse struct {
Email string `json:"email"` Email string `json:"email"`
} }
func GetGenericEmail(client *http.Client, url string) (string, error) { func GetGenericEmail(client *http.Client, url string) (string, error) {
// Using the oauth client get the user info url
res, resErr := client.Get(url) res, resErr := client.Get(url)
// Check if there was an error
if resErr != nil { if resErr != nil {
return "", resErr return "", resErr
} }
log.Debug().Msg("Got response from generic provider") log.Debug().Msg("Got response from generic provider")
// Read the body of the response
body, bodyErr := io.ReadAll(res.Body) body, bodyErr := io.ReadAll(res.Body)
// Check if there was an error
if bodyErr != nil { if bodyErr != nil {
return "", bodyErr return "", bodyErr
} }
log.Debug().Msg("Read body from generic provider") log.Debug().Msg("Read body from generic provider")
// Parse the body into a user struct
var user GenericUserInfoResponse var user GenericUserInfoResponse
// Unmarshal the body into the user struct
jsonErr := json.Unmarshal(body, &user) jsonErr := json.Unmarshal(body, &user)
// Check if there was an error
if jsonErr != nil { if jsonErr != nil {
return "", jsonErr return "", jsonErr
} }
log.Debug().Msg("Parsed user from generic provider") log.Debug().Msg("Parsed user from generic provider")
// Return the email
return user.Email, nil return user.Email, nil
} }

View File

@@ -9,47 +9,58 @@ import (
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
) )
// Github has a different response than the generic provider
type GithubUserInfoResponse []struct { type GithubUserInfoResponse []struct {
Email string `json:"email"` Email string `json:"email"`
Primary bool `json:"primary"` Primary bool `json:"primary"`
} }
// The scopes required for the github provider
func GithubScopes() []string { func GithubScopes() []string {
return []string{"user:email"} return []string{"user:email"}
} }
func GetGithubEmail(client *http.Client) (string, error) { func GetGithubEmail(client *http.Client) (string, error) {
// Get the user emails from github using the oauth http client
res, resErr := client.Get("https://api.github.com/user/emails") res, resErr := client.Get("https://api.github.com/user/emails")
// Check if there was an error
if resErr != nil { if resErr != nil {
return "", resErr return "", resErr
} }
log.Debug().Msg("Got response from github") log.Debug().Msg("Got response from github")
// Read the body of the response
body, bodyErr := io.ReadAll(res.Body) body, bodyErr := io.ReadAll(res.Body)
// Check if there was an error
if bodyErr != nil { if bodyErr != nil {
return "", bodyErr return "", bodyErr
} }
log.Debug().Msg("Read body from github") log.Debug().Msg("Read body from github")
// Parse the body into a user struct
var emails GithubUserInfoResponse var emails GithubUserInfoResponse
// Unmarshal the body into the user struct
jsonErr := json.Unmarshal(body, &emails) jsonErr := json.Unmarshal(body, &emails)
// Check if there was an error
if jsonErr != nil { if jsonErr != nil {
return "", jsonErr return "", jsonErr
} }
log.Debug().Msg("Parsed emails from github") log.Debug().Msg("Parsed emails from github")
// Find and return the primary email
for _, email := range emails { for _, email := range emails {
if email.Primary { if email.Primary {
return email.Email, nil return email.Email, nil
} }
} }
// User does not have a primary email?
return "", errors.New("no primary email found") return "", errors.New("no primary email found")
} }

View File

@@ -8,40 +8,50 @@ import (
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
) )
// Google works the same as the generic provider
type GoogleUserInfoResponse struct { type GoogleUserInfoResponse struct {
Email string `json:"email"` Email string `json:"email"`
} }
// The scopes required for the google provider
func GoogleScopes() []string { func GoogleScopes() []string {
return []string{"https://www.googleapis.com/auth/userinfo.email"} return []string{"https://www.googleapis.com/auth/userinfo.email"}
} }
func GetGoogleEmail(client *http.Client) (string, error) { func GetGoogleEmail(client *http.Client) (string, error) {
// Get the user info from google using the oauth http client
res, resErr := client.Get("https://www.googleapis.com/userinfo/v2/me") res, resErr := client.Get("https://www.googleapis.com/userinfo/v2/me")
// Check if there was an error
if resErr != nil { if resErr != nil {
return "", resErr return "", resErr
} }
log.Debug().Msg("Got response from google") log.Debug().Msg("Got response from google")
// Read the body of the response
body, bodyErr := io.ReadAll(res.Body) body, bodyErr := io.ReadAll(res.Body)
// Check if there was an error
if bodyErr != nil { if bodyErr != nil {
return "", bodyErr return "", bodyErr
} }
log.Debug().Msg("Read body from google") log.Debug().Msg("Read body from google")
// Parse the body into a user struct
var user GoogleUserInfoResponse var user GoogleUserInfoResponse
// Unmarshal the body into the user struct
jsonErr := json.Unmarshal(body, &user) jsonErr := json.Unmarshal(body, &user)
// Check if there was an error
if jsonErr != nil { if jsonErr != nil {
return "", jsonErr return "", jsonErr
} }
log.Debug().Msg("Parsed user from google") log.Debug().Msg("Parsed user from google")
// Return the email
return user.Email, nil return user.Email, nil
} }

View File

@@ -25,8 +25,11 @@ type Providers struct {
} }
func (providers *Providers) Init() { func (providers *Providers) Init() {
// If we have a client id and secret for github, initialize the oauth provider
if providers.Config.GithubClientId != "" && providers.Config.GithubClientSecret != "" { if providers.Config.GithubClientId != "" && providers.Config.GithubClientSecret != "" {
log.Info().Msg("Initializing Github OAuth") log.Info().Msg("Initializing Github OAuth")
// Create a new oauth provider with the github config
providers.Github = oauth.NewOAuth(oauth2.Config{ providers.Github = oauth.NewOAuth(oauth2.Config{
ClientID: providers.Config.GithubClientId, ClientID: providers.Config.GithubClientId,
ClientSecret: providers.Config.GithubClientSecret, ClientSecret: providers.Config.GithubClientSecret,
@@ -34,10 +37,16 @@ func (providers *Providers) Init() {
Scopes: GithubScopes(), Scopes: GithubScopes(),
Endpoint: endpoints.GitHub, Endpoint: endpoints.GitHub,
}) })
// Initialize the oauth provider
providers.Github.Init() providers.Github.Init()
} }
// If we have a client id and secret for google, initialize the oauth provider
if providers.Config.GoogleClientId != "" && providers.Config.GoogleClientSecret != "" { if providers.Config.GoogleClientId != "" && providers.Config.GoogleClientSecret != "" {
log.Info().Msg("Initializing Google OAuth") log.Info().Msg("Initializing Google OAuth")
// Create a new oauth provider with the google config
providers.Google = oauth.NewOAuth(oauth2.Config{ providers.Google = oauth.NewOAuth(oauth2.Config{
ClientID: providers.Config.GoogleClientId, ClientID: providers.Config.GoogleClientId,
ClientSecret: providers.Config.GoogleClientSecret, ClientSecret: providers.Config.GoogleClientSecret,
@@ -45,10 +54,15 @@ func (providers *Providers) Init() {
Scopes: GoogleScopes(), Scopes: GoogleScopes(),
Endpoint: endpoints.Google, Endpoint: endpoints.Google,
}) })
// Initialize the oauth provider
providers.Google.Init() providers.Google.Init()
} }
if providers.Config.TailscaleClientId != "" && providers.Config.TailscaleClientSecret != "" { if providers.Config.TailscaleClientId != "" && providers.Config.TailscaleClientSecret != "" {
log.Info().Msg("Initializing Tailscale OAuth") log.Info().Msg("Initializing Tailscale OAuth")
// Create a new oauth provider with the tailscale config
providers.Tailscale = oauth.NewOAuth(oauth2.Config{ providers.Tailscale = oauth.NewOAuth(oauth2.Config{
ClientID: providers.Config.TailscaleClientId, ClientID: providers.Config.TailscaleClientId,
ClientSecret: providers.Config.TailscaleClientSecret, ClientSecret: providers.Config.TailscaleClientSecret,
@@ -56,10 +70,16 @@ func (providers *Providers) Init() {
Scopes: TailscaleScopes(), Scopes: TailscaleScopes(),
Endpoint: TailscaleEndpoint, Endpoint: TailscaleEndpoint,
}) })
// Initialize the oauth provider
providers.Tailscale.Init() providers.Tailscale.Init()
} }
// If we have a client id and secret for generic oauth, initialize the oauth provider
if providers.Config.GenericClientId != "" && providers.Config.GenericClientSecret != "" { if providers.Config.GenericClientId != "" && providers.Config.GenericClientSecret != "" {
log.Info().Msg("Initializing Generic OAuth") log.Info().Msg("Initializing Generic OAuth")
// Create a new oauth provider with the generic config
providers.Generic = oauth.NewOAuth(oauth2.Config{ providers.Generic = oauth.NewOAuth(oauth2.Config{
ClientID: providers.Config.GenericClientId, ClientID: providers.Config.GenericClientId,
ClientSecret: providers.Config.GenericClientSecret, ClientSecret: providers.Config.GenericClientSecret,
@@ -70,11 +90,14 @@ func (providers *Providers) Init() {
TokenURL: providers.Config.GenericTokenURL, TokenURL: providers.Config.GenericTokenURL,
}, },
}) })
// Initialize the oauth provider
providers.Generic.Init() providers.Generic.Init()
} }
} }
func (providers *Providers) GetProvider(provider string) *oauth.OAuth { func (providers *Providers) GetProvider(provider string) *oauth.OAuth {
// Return the provider based on the provider string
switch provider { switch provider {
case "github": case "github":
return providers.Github return providers.Github
@@ -90,58 +113,103 @@ func (providers *Providers) GetProvider(provider string) *oauth.OAuth {
} }
func (providers *Providers) GetUser(provider string) (string, error) { func (providers *Providers) GetUser(provider string) (string, error) {
// Get the email from the provider
switch provider { switch provider {
case "github": case "github":
// If the github provider is not configured, return an error
if providers.Github == nil { if providers.Github == nil {
log.Debug().Msg("Github provider not configured") log.Debug().Msg("Github provider not configured")
return "", nil return "", nil
} }
// Get the client from the github provider
client := providers.Github.GetClient() client := providers.Github.GetClient()
log.Debug().Msg("Got client from github") log.Debug().Msg("Got client from github")
// Get the email from the github provider
email, emailErr := GetGithubEmail(client) email, emailErr := GetGithubEmail(client)
// Check if there was an error
if emailErr != nil { if emailErr != nil {
return "", emailErr return "", emailErr
} }
log.Debug().Msg("Got email from github") log.Debug().Msg("Got email from github")
// Return the email
return email, nil return email, nil
case "google": case "google":
// If the google provider is not configured, return an error
if providers.Google == nil { if providers.Google == nil {
log.Debug().Msg("Google provider not configured") log.Debug().Msg("Google provider not configured")
return "", nil return "", nil
} }
// Get the client from the google provider
client := providers.Google.GetClient() client := providers.Google.GetClient()
log.Debug().Msg("Got client from google") log.Debug().Msg("Got client from google")
// Get the email from the google provider
email, emailErr := GetGoogleEmail(client) email, emailErr := GetGoogleEmail(client)
// Check if there was an error
if emailErr != nil { if emailErr != nil {
return "", emailErr return "", emailErr
} }
log.Debug().Msg("Got email from google") log.Debug().Msg("Got email from google")
// Return the email
return email, nil return email, nil
case "tailscale": case "tailscale":
// If the tailscale provider is not configured, return an error
if providers.Tailscale == nil { if providers.Tailscale == nil {
log.Debug().Msg("Tailscale provider not configured") log.Debug().Msg("Tailscale provider not configured")
return "", nil return "", nil
} }
// Get the client from the tailscale provider
client := providers.Tailscale.GetClient() client := providers.Tailscale.GetClient()
log.Debug().Msg("Got client from tailscale") log.Debug().Msg("Got client from tailscale")
// Get the email from the tailscale provider
email, emailErr := GetTailscaleEmail(client) email, emailErr := GetTailscaleEmail(client)
// Check if there was an error
if emailErr != nil { if emailErr != nil {
return "", emailErr return "", emailErr
} }
log.Debug().Msg("Got email from tailscale") log.Debug().Msg("Got email from tailscale")
// Return the email
return email, nil return email, nil
case "generic": case "generic":
// If the generic provider is not configured, return an error
if providers.Generic == nil { if providers.Generic == nil {
log.Debug().Msg("Generic provider not configured") log.Debug().Msg("Generic provider not configured")
return "", nil return "", nil
} }
// Get the client from the generic provider
client := providers.Generic.GetClient() client := providers.Generic.GetClient()
log.Debug().Msg("Got client from generic") log.Debug().Msg("Got client from generic")
// Get the email from the generic provider
email, emailErr := GetGenericEmail(client, providers.Config.GenericUserURL) email, emailErr := GetGenericEmail(client, providers.Config.GenericUserURL)
// Check if there was an error
if emailErr != nil { if emailErr != nil {
return "", emailErr return "", emailErr
} }
log.Debug().Msg("Got email from generic") log.Debug().Msg("Got email from generic")
// Return the email
return email, nil return email, nil
default: default:
return "", nil return "", nil
@@ -149,6 +217,7 @@ func (providers *Providers) GetUser(provider string) (string, error) {
} }
func (provider *Providers) GetConfiguredProviders() []string { func (provider *Providers) GetConfiguredProviders() []string {
// Create a list of the configured providers
providers := []string{} providers := []string{}
if provider.Github != nil { if provider.Github != nil {
providers = append(providers, "github") providers = append(providers, "github")

View File

@@ -9,48 +9,60 @@ import (
"golang.org/x/oauth2" "golang.org/x/oauth2"
) )
// The tailscale email is the loginName
type TailscaleUser struct { type TailscaleUser struct {
LoginName string `json:"loginName"` LoginName string `json:"loginName"`
} }
// The response from the tailscale user info endpoint
type TailscaleUserInfoResponse struct { type TailscaleUserInfoResponse struct {
Users []TailscaleUser `json:"users"` Users []TailscaleUser `json:"users"`
} }
// The scopes required for the tailscale provider
func TailscaleScopes() []string { func TailscaleScopes() []string {
return []string{"users:read"} return []string{"users:read"}
} }
// The tailscale endpoint
var TailscaleEndpoint = oauth2.Endpoint{ var TailscaleEndpoint = oauth2.Endpoint{
TokenURL: "https://api.tailscale.com/api/v2/oauth/token", TokenURL: "https://api.tailscale.com/api/v2/oauth/token",
} }
func GetTailscaleEmail(client *http.Client) (string, error) { func GetTailscaleEmail(client *http.Client) (string, error) {
// Get the user info from tailscale using the oauth http client
res, resErr := client.Get("https://api.tailscale.com/api/v2/tailnet/-/users") res, resErr := client.Get("https://api.tailscale.com/api/v2/tailnet/-/users")
// Check if there was an error
if resErr != nil { if resErr != nil {
return "", resErr return "", resErr
} }
log.Debug().Msg("Got response from tailscale") log.Debug().Msg("Got response from tailscale")
// Read the body of the response
body, bodyErr := io.ReadAll(res.Body) body, bodyErr := io.ReadAll(res.Body)
// Check if there was an error
if bodyErr != nil { if bodyErr != nil {
return "", bodyErr return "", bodyErr
} }
log.Debug().Msg("Read body from tailscale") log.Debug().Msg("Read body from tailscale")
// Parse the body into a user struct
var users TailscaleUserInfoResponse var users TailscaleUserInfoResponse
// Unmarshal the body into the user struct
jsonErr := json.Unmarshal(body, &users) jsonErr := json.Unmarshal(body, &users)
// Check if there was an error
if jsonErr != nil { if jsonErr != nil {
return "", jsonErr return "", jsonErr
} }
log.Debug().Msg("Parsed users from tailscale") log.Debug().Msg("Parsed users from tailscale")
// Return the email of the first user
return users.Users[0].LoginName, nil return users.Users[0].LoginName, nil
} }

View File

@@ -2,22 +2,27 @@ package types
import "tinyauth/internal/oauth" import "tinyauth/internal/oauth"
// LoginQuery is the query parameters for the login endpoint
type LoginQuery struct { type LoginQuery struct {
RedirectURI string `url:"redirect_uri"` RedirectURI string `url:"redirect_uri"`
} }
// LoginRequest is the request body for the login endpoint
type LoginRequest struct { type LoginRequest struct {
Username string `json:"username"` Username string `json:"username"`
Password string `json:"password"` Password string `json:"password"`
} }
// User is the struct for a user
type User struct { type User struct {
Username string Username string
Password string Password string
} }
// Users is a list of users
type Users []User type Users []User
// Config is the configuration for the tinyauth server
type Config struct { type Config struct {
Port int `mapstructure:"port" validate:"required"` Port int `mapstructure:"port" validate:"required"`
Address string `validate:"required,ip4_addr" mapstructure:"address"` Address string `validate:"required,ip4_addr" mapstructure:"address"`
@@ -49,6 +54,7 @@ type Config struct {
LogLevel int8 `mapstructure:"log-level" validate:"min=-1,max=5"` LogLevel int8 `mapstructure:"log-level" validate:"min=-1,max=5"`
} }
// UserContext is the context for the user
type UserContext struct { type UserContext struct {
Username string Username string
IsLoggedIn bool IsLoggedIn bool
@@ -56,6 +62,7 @@ type UserContext struct {
Provider string Provider string
} }
// APIConfig is the configuration for the API
type APIConfig struct { type APIConfig struct {
Port int Port int
Address string Address string
@@ -66,6 +73,7 @@ type APIConfig struct {
DisableContinue bool DisableContinue bool
} }
// OAuthConfig is the configuration for the providers
type OAuthConfig struct { type OAuthConfig struct {
GithubClientId string GithubClientId string
GithubClientSecret string GithubClientSecret string
@@ -82,35 +90,42 @@ type OAuthConfig struct {
AppURL string AppURL string
} }
// OAuthRequest is the request for the OAuth endpoint
type OAuthRequest struct { type OAuthRequest struct {
Provider string `uri:"provider" binding:"required"` Provider string `uri:"provider" binding:"required"`
} }
// OAuthProviders is the struct for the OAuth providers
type OAuthProviders struct { type OAuthProviders struct {
Github *oauth.OAuth Github *oauth.OAuth
Google *oauth.OAuth Google *oauth.OAuth
Microsoft *oauth.OAuth Microsoft *oauth.OAuth
} }
// UnauthorizedQuery is the query parameters for the unauthorized endpoint
type UnauthorizedQuery struct { type UnauthorizedQuery struct {
Username string `url:"username"` Username string `url:"username"`
Resource string `url:"resource"` Resource string `url:"resource"`
} }
// SessionCookie is the cookie for the session (exculding the expiry)
type SessionCookie struct { type SessionCookie struct {
Username string Username string
Provider string Provider string
} }
// TinyauthLabels is the labels for the tinyauth container
type TinyauthLabels struct { type TinyauthLabels struct {
OAuthWhitelist []string OAuthWhitelist []string
Users []string Users []string
} }
// TailscaleQuery is the query parameters for the tailscale endpoint
type TailscaleQuery struct { type TailscaleQuery struct {
Code int `url:"code"` Code int `url:"code"`
} }
// Proxy is the uri parameters for the proxy endpoint
type Proxy struct { type Proxy struct {
Proxy string `uri:"proxy" binding:"required"` Proxy string `uri:"proxy" binding:"required"`
} }

View File

@@ -12,20 +12,32 @@ import (
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
) )
// Parses a list of comma separated users in a struct
func ParseUsers(users string) (types.Users, error) { func ParseUsers(users string) (types.Users, error) {
log.Debug().Msg("Parsing users") log.Debug().Msg("Parsing users")
// Create a new users struct
var usersParsed types.Users var usersParsed types.Users
// Split the users by comma
userList := strings.Split(users, ",") userList := strings.Split(users, ",")
// Check if there are any users
if len(userList) == 0 { if len(userList) == 0 {
return types.Users{}, errors.New("invalid user format") return types.Users{}, errors.New("invalid user format")
} }
// Loop through the users and split them by colon
for _, user := range userList { for _, user := range userList {
// Split the user by colon
userSplit := strings.Split(user, ":") userSplit := strings.Split(user, ":")
// Check if the user is in the correct format
if len(userSplit) != 2 { if len(userSplit) != 2 {
return types.Users{}, errors.New("invalid user format") return types.Users{}, errors.New("invalid user format")
} }
// Append the user to the users struct
usersParsed = append(usersParsed, types.User{ usersParsed = append(usersParsed, types.User{
Username: userSplit[0], Username: userSplit[0],
Password: userSplit[1], Password: userSplit[1],
@@ -34,43 +46,61 @@ func ParseUsers(users string) (types.Users, error) {
log.Debug().Msg("Parsed users") log.Debug().Msg("Parsed users")
// Return the users struct
return usersParsed, nil return usersParsed, nil
} }
// Root url parses parses a hostname and returns the root domain (e.g. sub1.sub2.domain.com -> domain.com)
func GetRootURL(urlSrc string) (string, error) { func GetRootURL(urlSrc string) (string, error) {
// Make sure the url is valid
urlParsed, parseErr := url.Parse(urlSrc) urlParsed, parseErr := url.Parse(urlSrc)
// Check if there was an error
if parseErr != nil { if parseErr != nil {
return "", parseErr return "", parseErr
} }
// Split the hostname by period
urlSplitted := strings.Split(urlParsed.Hostname(), ".") urlSplitted := strings.Split(urlParsed.Hostname(), ".")
// Get the last part of the url
urlFinal := strings.Join(urlSplitted[1:], ".") urlFinal := strings.Join(urlSplitted[1:], ".")
// Return the root domain
return urlFinal, nil return urlFinal, nil
} }
// Reads a file and returns the contents
func ReadFile(file string) (string, error) { func ReadFile(file string) (string, error) {
// Check if the file exists
_, statErr := os.Stat(file) _, statErr := os.Stat(file)
// Check if there was an error
if statErr != nil { if statErr != nil {
return "", statErr return "", statErr
} }
// Read the file
data, readErr := os.ReadFile(file) data, readErr := os.ReadFile(file)
// Check if there was an error
if readErr != nil { if readErr != nil {
return "", readErr return "", readErr
} }
// Return the file contents
return string(data), nil return string(data), nil
} }
// Parses a file into a comma separated list of users
func ParseFileToLine(content string) string { func ParseFileToLine(content string) string {
// Split the content by newline
lines := strings.Split(content, "\n") lines := strings.Split(content, "\n")
// Create a list of users
users := make([]string, 0) users := make([]string, 0)
// Loop through the lines, trimming the whitespace and appending to the users list
for _, line := range lines { for _, line := range lines {
if strings.TrimSpace(line) == "" { if strings.TrimSpace(line) == "" {
continue continue
@@ -79,63 +109,92 @@ func ParseFileToLine(content string) string {
users = append(users, strings.TrimSpace(line)) users = append(users, strings.TrimSpace(line))
} }
// Return the users as a comma separated string
return strings.Join(users, ",") return strings.Join(users, ",")
} }
// Get the secret from the config or file
func GetSecret(conf string, file string) string { func GetSecret(conf string, file string) string {
// If neither the config or file is set, return an empty string
if conf == "" && file == "" { if conf == "" && file == "" {
return "" return ""
} }
// If the config is set, return the config (environment variable)
if conf != "" { if conf != "" {
return conf return conf
} }
// If the file is set, read the file
contents, err := ReadFile(file) contents, err := ReadFile(file)
// Check if there was an error
if err != nil { if err != nil {
return "" return ""
} }
// Return the contents of the file
return contents return contents
} }
// Get the users from the config or file
func GetUsers(conf string, file string) (types.Users, error) { func GetUsers(conf string, file string) (types.Users, error) {
// Create a string to store the users
var users string var users string
// If neither the config or file is set, return an empty users struct
if conf == "" && file == "" { if conf == "" && file == "" {
return types.Users{}, nil return types.Users{}, nil
} }
// If the config (environment) is set, append the users to the users string
if conf != "" { if conf != "" {
log.Debug().Msg("Using users from config") log.Debug().Msg("Using users from config")
users += conf users += conf
} }
// If the file is set, read the file and append the users to the users string
if file != "" { if file != "" {
// Read the file
fileContents, fileErr := ReadFile(file) fileContents, fileErr := ReadFile(file)
// If there isn't an error we can append the users to the users string
if fileErr == nil { if fileErr == nil {
log.Debug().Msg("Using users from file") log.Debug().Msg("Using users from file")
// Append the users to the users string
if users != "" { if users != "" {
users += "," users += ","
} }
// Parse the file contents into a comma separated list of users
users += ParseFileToLine(fileContents) users += ParseFileToLine(fileContents)
} }
} }
// Return the parsed users
return ParseUsers(users) return ParseUsers(users)
} }
// Check if any of the OAuth providers are configured based on the client id and secret
func OAuthConfigured(config types.Config) bool { func OAuthConfigured(config types.Config) bool {
return (config.GithubClientId != "" && config.GithubClientSecret != "") || (config.GoogleClientId != "" && config.GoogleClientSecret != "") || (config.GenericClientId != "" && config.GenericClientSecret != "") || (config.TailscaleClientId != "" && config.TailscaleClientSecret != "") return (config.GithubClientId != "" && config.GithubClientSecret != "") || (config.GoogleClientId != "" && config.GoogleClientSecret != "") || (config.GenericClientId != "" && config.GenericClientSecret != "") || (config.TailscaleClientId != "" && config.TailscaleClientSecret != "")
} }
// Parse the docker labels to the tinyauth labels struct
func GetTinyauthLabels(labels map[string]string) types.TinyauthLabels { func GetTinyauthLabels(labels map[string]string) types.TinyauthLabels {
// Create a new tinyauth labels struct
var tinyauthLabels types.TinyauthLabels var tinyauthLabels types.TinyauthLabels
// Loop through the labels
for label, value := range labels { for label, value := range labels {
// Check if the label is in the tinyauth labels
if slices.Contains(constants.TinyauthLabels, label) { if slices.Contains(constants.TinyauthLabels, label) {
log.Debug().Str("label", label).Msg("Found label") log.Debug().Str("label", label).Msg("Found label")
// Add the label value to the tinyauth labels struct
switch label { switch label {
case "tinyauth.oauth.whitelist": case "tinyauth.oauth.whitelist":
tinyauthLabels.OAuthWhitelist = strings.Split(value, ",") tinyauthLabels.OAuthWhitelist = strings.Split(value, ",")
@@ -144,5 +203,7 @@ func GetTinyauthLabels(labels map[string]string) types.TinyauthLabels {
} }
} }
} }
// Return the tinyauth labels
return tinyauthLabels return tinyauthLabels
} }