From 7a3a46348921633f22a0461389625af97fa32807 Mon Sep 17 00:00:00 2001 From: Stavros Date: Sat, 8 Feb 2025 12:33:58 +0200 Subject: [PATCH] chore: add comments to code --- cmd/root.go | 14 +- cmd/user/create/create.go | 9 ++ cmd/user/verify/verify.go | 9 ++ internal/api/api.go | 249 +++++++++++++++++++++----------- internal/assets/assets.go | 6 +- internal/auth/auth.go | 57 +++++++- internal/constants/constants.go | 1 + internal/docker/docker.go | 11 ++ internal/hooks/hooks.go | 24 +++ internal/oauth/oauth.go | 10 ++ internal/providers/generic.go | 9 ++ internal/providers/github.go | 11 ++ internal/providers/google.go | 10 ++ internal/providers/providers.go | 69 +++++++++ internal/providers/tailscale.go | 12 ++ internal/types/types.go | 15 ++ internal/utils/utils.go | 61 ++++++++ 17 files changed, 485 insertions(+), 92 deletions(-) diff --git a/cmd/root.go b/cmd/root.go index 65eaba4..1d9ec3a 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -125,20 +125,24 @@ var rootCmd = &cobra.Command{ func Execute() { err := rootCmd.Execute() - if err != nil { - log.Fatal().Err(err).Msg("Failed to execute command") - } + HandleError(err, "Failed to execute root command") } func HandleError(err error, msg string) { + // If error log it and exit if err != nil { log.Fatal().Err(err).Msg(msg) } } func init() { + // Add user command rootCmd.AddCommand(cmd.UserCmd()) + + // Read environment variables viper.AutomaticEnv() + + // Flags rootCmd.Flags().Int("port", 3000, "Port to run the server on.") rootCmd.Flags().String("address", "0.0.0.0", "Address to bind the server to.") rootCmd.Flags().String("secret", "", "Secret to use for the cookie.") @@ -167,6 +171,8 @@ func init() { rootCmd.Flags().String("oauth-whitelist", "", "Comma separated list of email addresses to whitelist when using OAuth.") rootCmd.Flags().Int("session-expiry", 86400, "Session (cookie) expiration time in seconds.") rootCmd.Flags().Int("log-level", 1, "Log level.") + + // Bind flags to environment viper.BindEnv("port", "PORT") viper.BindEnv("address", "ADDRESS") viper.BindEnv("secret", "SECRET") @@ -195,5 +201,7 @@ func init() { viper.BindEnv("oauth-whitelist", "OAUTH_WHITELIST") viper.BindEnv("session-expiry", "SESSION_EXPIRY") viper.BindEnv("log-level", "LOG_LEVEL") + + // Bind flags to viper viper.BindPFlags(rootCmd.Flags()) } diff --git a/cmd/user/create/create.go b/cmd/user/create/create.go index 7dcbfcb..fed337c 100644 --- a/cmd/user/create/create.go +++ b/cmd/user/create/create.go @@ -22,9 +22,12 @@ var CreateCmd = &cobra.Command{ Short: "Create a user", Long: `Create a user either interactively or by passing flags.`, Run: func(cmd *cobra.Command, args []string) { + // Setup logger log.Logger = log.Level(zerolog.InfoLevel) + // Check if interactive if interactive { + // Create huh form form := huh.NewForm( huh.NewGroup( huh.NewInput().Title("Username").Value(&username).Validate((func(s string) error { @@ -43,6 +46,7 @@ var CreateCmd = &cobra.Command{ ), ) + // Use simple theme var baseTheme *huh.Theme = huh.ThemeBase() formErr := form.WithTheme(baseTheme).Run() @@ -52,12 +56,14 @@ var CreateCmd = &cobra.Command{ } } + // Do we have username and password? if username == "" || password == "" { log.Error().Msg("Username and password cannot be empty") } log.Info().Str("username", username).Str("password", password).Bool("docker", docker).Msg("Creating user") + // Hash password passwordByte, passwordErr := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) if passwordErr != nil { @@ -66,15 +72,18 @@ var CreateCmd = &cobra.Command{ passwordString := string(passwordByte) + // Escape $ for docker if docker { passwordString = strings.ReplaceAll(passwordString, "$", "$$") } + // Log user created log.Info().Str("user", fmt.Sprintf("%s:%s", username, passwordString)).Msg("User created") }, } func init() { + // Flags CreateCmd.Flags().BoolVar(&interactive, "interactive", false, "Create a user interactively") CreateCmd.Flags().BoolVar(&docker, "docker", false, "Format output for docker") CreateCmd.Flags().StringVar(&username, "username", "", "Username") diff --git a/cmd/user/verify/verify.go b/cmd/user/verify/verify.go index ace1609..167e8dc 100644 --- a/cmd/user/verify/verify.go +++ b/cmd/user/verify/verify.go @@ -22,9 +22,12 @@ var VerifyCmd = &cobra.Command{ Short: "Verify a user is set up correctly", Long: `Verify a user is set up correctly meaning that it has a correct username and password.`, Run: func(cmd *cobra.Command, args []string) { + // Setup logger log.Logger = log.Level(zerolog.InfoLevel) + // Check if interactive if interactive { + // Create huh form form := huh.NewForm( huh.NewGroup( huh.NewInput().Title("User (username:hash)").Value(&user).Validate((func(s string) error { @@ -49,6 +52,7 @@ var VerifyCmd = &cobra.Command{ ), ) + // Use simple theme var baseTheme *huh.Theme = huh.ThemeBase() formErr := form.WithTheme(baseTheme).Run() @@ -58,22 +62,26 @@ var VerifyCmd = &cobra.Command{ } } + // Do we have username, password and user? if username == "" || password == "" || user == "" { log.Fatal().Msg("Username, password and user cannot be empty") } log.Info().Str("user", user).Str("username", username).Str("password", password).Bool("docker", docker).Msg("Verifying user") + // Split username and password userSplit := strings.Split(user, ":") if userSplit[1] == "" { log.Fatal().Msg("User is not formatted correctly") } + // Replace $$ with $ if formatted for docker if docker { userSplit[1] = strings.ReplaceAll(userSplit[1], "$$", "$") } + // Compare username and password verifyErr := bcrypt.CompareHashAndPassword([]byte(userSplit[1]), []byte(password)) if verifyErr != nil || username != userSplit[0] { @@ -85,6 +93,7 @@ var VerifyCmd = &cobra.Command{ } func init() { + // Flags VerifyCmd.Flags().BoolVarP(&interactive, "interactive", "i", false, "Create a user interactively") VerifyCmd.Flags().BoolVar(&docker, "docker", false, "Is the user formatted for docker?") VerifyCmd.Flags().StringVar(&username, "username", "", "Username") diff --git a/internal/api/api.go b/internal/api/api.go index 08a9814..c2758ca 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -41,11 +41,15 @@ type API struct { } func (api *API) Init() { + // Disable gin logs gin.SetMode(gin.ReleaseMode) + // Create router and use zerolog for logs log.Debug().Msg("Setting up router") router := gin.New() router.Use(zerolog()) + + // Read UI assets log.Debug().Msg("Setting up assets") dist, distErr := fs.Sub(assets.Assets, "dist") @@ -53,11 +57,15 @@ func (api *API) Init() { log.Fatal().Err(distErr).Msg("Failed to get UI assets") } + // Create file server log.Debug().Msg("Setting up file server") fileServer := http.FileServer(http.FS(dist)) + + // Setup cookie store log.Debug().Msg("Setting up cookie store") store := cookie.NewStore([]byte(api.Config.Secret)) + // Get domain to use for session cookies log.Debug().Msg("Getting domain") domain, domainErr := utils.GetRootURL(api.Config.AppURL) @@ -70,6 +78,7 @@ func (api *API) Init() { api.Domain = fmt.Sprintf(".%s", domain) + // Use session middleware store.Options(sessions.Options{ Domain: api.Domain, Path: "/", @@ -80,175 +89,169 @@ func (api *API) Init() { router.Use(sessions.Sessions("tinyauth", store)) + // UI middleware router.Use(func(c *gin.Context) { + // If not an API request, serve the UI if !strings.HasPrefix(c.Request.URL.Path, "/api") { _, err := fs.Stat(dist, strings.TrimPrefix(c.Request.URL.Path, "/")) + + // If the file doesn't exist, serve the index.html if os.IsNotExist(err) { c.Request.URL.Path = "/" } + + // Serve the file fileServer.ServeHTTP(c.Writer, c.Request) + + // Stop further processing c.Abort() } }) + // Set router api.Router = router } func (api *API) SetupRoutes() { api.Router.GET("/api/auth/:proxy", func(c *gin.Context) { + // Create struct for proxy var proxy types.Proxy + // Bind URI bindErr := c.BindUri(&proxy) + // Handle error if api.handleError(c, "Failed to bind URI", bindErr) { return } log.Debug().Interface("proxy", proxy.Proxy).Msg("Got proxy") + // Get user context userContext := api.Hooks.UseUserContext(c) + // Get headers uri := c.Request.Header.Get("X-Forwarded-Uri") proto := c.Request.Header.Get("X-Forwarded-Proto") host := c.Request.Header.Get("X-Forwarded-Host") + // Check if user is logged in if userContext.IsLoggedIn { log.Debug().Msg("Authenticated") + // Check if user is allowed to access subdomain, if request is nginx.example.com the subdomain (resource) is nginx appAllowed, appAllowedErr := api.Auth.ResourceAllowed(userContext, host) + // Check if there was an error if appAllowedErr != nil { - switch proxy.Proxy { - case "nginx": - log.Error().Err(appAllowedErr).Msg("Failed to check if resource is allowed") + // Return 501 if nginx is the proxy or if the request is using an Authorization header + if proxy.Proxy == "nginx" || c.GetHeader("Authorization") != "" { + log.Error().Err(appAllowedErr).Msg("Failed to check if app is allowed") c.JSON(501, gin.H{ "status": 501, "message": "Internal Server Error", }) return - default: - if c.GetHeader("Authorization") != "" { - log.Error().Err(appAllowedErr).Msg("Failed to check if resource is allowed") - c.JSON(501, gin.H{ - "status": 501, - "message": "Internal Server Error", - }) - return - } - if api.handleError(c, "Failed to check if resource is allowed", appAllowedErr) { - return - } + } + + // Return the internal server error page + if api.handleError(c, "Failed to check if app is allowed", appAllowedErr) { + return } } log.Debug().Bool("appAllowed", appAllowed).Msg("Checking if app is allowed") + // The user is not allowed to access the app if !appAllowed { log.Warn().Str("username", userContext.Username).Str("host", host).Msg("User not allowed") + // Build query queries, queryErr := query.Values(types.UnauthorizedQuery{ Username: userContext.Username, Resource: strings.Split(host, ".")[0], }) + // Check if there was an error if queryErr != nil { - switch proxy.Proxy { - case "nginx": + // Return 501 if nginx is the proxy or if the request is using an Authorization header + if proxy.Proxy == "nginx" || c.GetHeader("Authorization") != "" { log.Error().Err(queryErr).Msg("Failed to build query") c.JSON(501, gin.H{ "status": 501, "message": "Internal Server Error", }) return - default: - if c.GetHeader("Authorization") != "" { - log.Error().Err(appAllowedErr).Msg("Failed to build query") - c.JSON(501, gin.H{ - "status": 501, - "message": "Internal Server Error", - }) - return - } - if api.handleError(c, "Failed to build query", queryErr) { - return - } + } + + // Return the internal server error page + if api.handleError(c, "Failed to build query", queryErr) { + return } } - switch proxy.Proxy { - case "nginx": + // Return 401 if nginx is the proxy or if the request is using an Authorization header + if proxy.Proxy == "nginx" || c.GetHeader("Authorization") != "" { c.JSON(401, gin.H{ "status": 401, "message": "Unauthorized", }) return - default: - if c.GetHeader("Authorization") != "" { - c.JSON(401, gin.H{ - "status": 401, - "message": "Unauthorized", - }) - return - } - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", api.Config.AppURL, queries.Encode())) - return } + + // We are using caddy/traefik so redirect + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", api.Config.AppURL, queries.Encode())) + + // Stop further processing + return } + // The user is allowed to access the app c.JSON(200, gin.H{ "status": 200, "message": "Authenticated", }) + + // Stop further processing return } - switch proxy.Proxy { - case "nginx": + // The user is not logged in + log.Debug().Msg("Unauthorized") + + // Return 401 if nginx is the proxy or if the request is using an Authorization header + if proxy.Proxy == "nginx" || c.GetHeader("Authorization") != "" { c.JSON(401, gin.H{ "status": 401, "message": "Unauthorized", }) return - default: - if c.GetHeader("Authorization") != "" { - c.JSON(401, gin.H{ - "status": 401, - "message": "Unauthorized", - }) - return - } - - queries, queryErr := query.Values(types.LoginQuery{ - RedirectURI: fmt.Sprintf("%s://%s%s", proto, host, uri), - }) - - log.Debug().Interface("redirect_uri", fmt.Sprintf("%s://%s%s", proto, host, uri)).Msg("Redirecting to login") - - if queryErr != nil { - switch proxy.Proxy { - case "nginx": - log.Error().Err(queryErr).Msg("Failed to build query") - c.JSON(501, gin.H{ - "status": 501, - "message": "Internal Server Error", - }) - return - default: - if api.handleError(c, "Failed to build query", queryErr) { - return - } - } - } - - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/?%s", api.Config.AppURL, queries.Encode())) } + + // Build query + queries, queryErr := query.Values(types.LoginQuery{ + RedirectURI: fmt.Sprintf("%s://%s%s", proto, host, uri), + }) + + log.Debug().Interface("redirect_uri", fmt.Sprintf("%s://%s%s", proto, host, uri)).Msg("Redirecting to login") + + // Handle error (no need to check for nginx/headers since we are sure we are using caddy/traefik) + if api.handleError(c, "Failed to build query", queryErr) { + return + } + + // Redirect to login + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/?%s", api.Config.AppURL, queries.Encode())) }) api.Router.POST("/api/login", func(c *gin.Context) { + // Create login struct var login types.LoginRequest + // Bind JSON err := c.BindJSON(&login) + // Handle error if err != nil { log.Error().Err(err).Msg("Failed to bind JSON") c.JSON(400, gin.H{ @@ -260,8 +263,10 @@ func (api *API) SetupRoutes() { log.Debug().Msg("Got login request") + // Get user based on username user := api.Auth.GetUser(login.Username) + // User does not exist if user == nil { log.Debug().Str("username", login.Username).Msg("User not found") c.JSON(401, gin.H{ @@ -271,6 +276,9 @@ func (api *API) SetupRoutes() { return } + log.Debug().Msg("Got user") + + // Check if password is correct if !api.Auth.CheckPassword(*user, login.Password) { log.Debug().Str("username", login.Username).Msg("Password incorrect") c.JSON(401, gin.H{ @@ -282,11 +290,13 @@ func (api *API) SetupRoutes() { log.Debug().Msg("Password correct, logging in") + // Create session cookie with username as provider api.Auth.CreateSessionCookie(c, &types.SessionCookie{ Username: login.Username, Provider: "username", }) + // Return logged in c.JSON(200, gin.H{ "status": 200, "message": "Logged in", @@ -294,12 +304,17 @@ func (api *API) SetupRoutes() { }) api.Router.POST("/api/logout", func(c *gin.Context) { + log.Debug().Msg("Logging out") + + // Delete session cookie api.Auth.DeleteSessionCookie(c) log.Debug().Msg("Cleaning up redirect cookie") + // Clean up redirect cookie if it exists c.SetCookie("tinyauth_redirect_uri", "", -1, "/", api.Domain, api.Config.CookieSecure, true) + // Return logged out c.JSON(200, gin.H{ "status": 200, "message": "Logged out", @@ -308,19 +323,24 @@ func (api *API) SetupRoutes() { api.Router.GET("/api/status", func(c *gin.Context) { log.Debug().Msg("Checking status") + + // Get user context userContext := api.Hooks.UseUserContext(c) + // Get configured providers configuredProviders := api.Providers.GetConfiguredProviders() + // We have username/password configured so add it to our providers if api.Auth.UserAuthConfigured() { configuredProviders = append(configuredProviders, "username") } + // We are not logged in so return unauthorized if !userContext.IsLoggedIn { - log.Debug().Msg("Unauthenticated") + log.Debug().Msg("Unauthorized") c.JSON(200, gin.H{ "status": 200, - "message": "Unauthenticated", + "message": "Unauthorized", "username": "", "isLoggedIn": false, "oauth": false, @@ -333,6 +353,7 @@ func (api *API) SetupRoutes() { log.Debug().Interface("userContext", userContext).Strs("configuredProviders", configuredProviders).Bool("disableContinue", api.Config.DisableContinue).Msg("Authenticated") + // We are logged in so return our user context c.JSON(200, gin.H{ "status": 200, "message": "Authenticated", @@ -345,18 +366,14 @@ func (api *API) SetupRoutes() { }) }) - api.Router.GET("/api/healthcheck", func(c *gin.Context) { - c.JSON(200, gin.H{ - "status": 200, - "message": "OK", - }) - }) - api.Router.GET("/api/oauth/url/:provider", func(c *gin.Context) { + // Create struct for OAuth request var request types.OAuthRequest + // Bind URI bindErr := c.BindUri(&request) + // Handle error if bindErr != nil { log.Error().Err(bindErr).Msg("Failed to bind URI") c.JSON(400, gin.H{ @@ -368,8 +385,10 @@ func (api *API) SetupRoutes() { log.Debug().Msg("Got OAuth request") + // Check if provider exists provider := api.Providers.GetProvider(request.Provider) + // Provider does not exist if provider == nil { c.JSON(404, gin.H{ "status": 404, @@ -380,24 +399,38 @@ func (api *API) SetupRoutes() { log.Debug().Str("provider", request.Provider).Msg("Got provider") + // Get auth URL authURL := provider.GetAuthURL() log.Debug().Msg("Got auth URL") + // Get redirect URI redirectURI := c.Query("redirect_uri") + // Set redirect cookie if redirect URI is provided if redirectURI != "" { log.Debug().Str("redirectURI", redirectURI).Msg("Setting redirect cookie") c.SetCookie("tinyauth_redirect_uri", redirectURI, 3600, "/", api.Domain, api.Config.CookieSecure, true) } + // Tailscale does not have an auth url so we create a random code (does not need to be secure) to avoid caching and send it if request.Provider == "tailscale" { + // Build tailscale query tailscaleQuery, tailscaleQueryErr := query.Values(types.TailscaleQuery{ - Code: (1000 + rand.IntN(9000)), // doesn't need to be secure, just there to avoid caching + Code: (1000 + rand.IntN(9000)), }) - if api.handleError(c, "Failed to build query", tailscaleQueryErr) { + + // Handle error + if tailscaleQueryErr != nil { + log.Error().Err(tailscaleQueryErr).Msg("Failed to build query") + c.JSON(500, gin.H{ + "status": 500, + "message": "Internal Server Error", + }) return } + + // Return tailscale URL (immidiately redirects to the callback) c.JSON(200, gin.H{ "status": 200, "message": "Ok", @@ -406,6 +439,7 @@ func (api *API) SetupRoutes() { return } + // Return auth URL c.JSON(200, gin.H{ "status": 200, "message": "Ok", @@ -414,18 +448,23 @@ func (api *API) SetupRoutes() { }) api.Router.GET("/api/oauth/callback/:provider", func(c *gin.Context) { + // Create struct for OAuth request var providerName types.OAuthRequest + // Bind URI bindErr := c.BindUri(&providerName) + // Handle error if api.handleError(c, "Failed to bind URI", bindErr) { return } log.Debug().Interface("provider", providerName.Provider).Msg("Got provider name") + // Get code code := c.Query("code") + // Code empty so redirect to error if code == "" { log.Error().Msg("No code provided") c.Redirect(http.StatusPermanentRedirect, "/error") @@ -434,51 +473,67 @@ func (api *API) SetupRoutes() { log.Debug().Msg("Got code") + // Get provider provider := api.Providers.GetProvider(providerName.Provider) log.Debug().Str("provider", providerName.Provider).Msg("Got provider") + // Provider does not exist if provider == nil { c.Redirect(http.StatusPermanentRedirect, "/not-found") return } + // Exchange token (authenticates user) _, tokenErr := provider.ExchangeToken(code) log.Debug().Msg("Got token") + // Handle error if api.handleError(c, "Failed to exchange token", tokenErr) { return } + // Get email email, emailErr := api.Providers.GetUser(providerName.Provider) log.Debug().Str("email", email).Msg("Got email") + // Handle error if api.handleError(c, "Failed to get user", emailErr) { return } + // Email is not whitelisted if !api.Auth.EmailWhitelisted(email) { log.Warn().Str("email", email).Msg("Email not whitelisted") + + // Build query unauthorizedQuery, unauthorizedQueryErr := query.Values(types.UnauthorizedQuery{ Username: email, }) + + // Handle error if api.handleError(c, "Failed to build query", unauthorizedQueryErr) { return } + + // Redirect to unauthorized c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/unauthorized?%s", api.Config.AppURL, unauthorizedQuery.Encode())) } log.Debug().Msg("Email whitelisted") + // Create session cookie api.Auth.CreateSessionCookie(c, &types.SessionCookie{ Username: email, Provider: providerName.Provider, }) + // Get redirect URI redirectURI, redirectURIErr := c.Cookie("tinyauth_redirect_uri") + // If it is empty it means that no redirect_uri was provided to the login screen so we just log in if redirectURIErr != nil { c.JSON(200, gin.H{ "status": 200, @@ -488,28 +543,44 @@ func (api *API) SetupRoutes() { log.Debug().Str("redirectURI", redirectURI).Msg("Got redirect URI") + // Clean up redirect cookie since we already have the value c.SetCookie("tinyauth_redirect_uri", "", -1, "/", api.Domain, api.Config.CookieSecure, true) + // Build query redirectQuery, redirectQueryErr := query.Values(types.LoginQuery{ RedirectURI: redirectURI, }) log.Debug().Msg("Got redirect query") + // Handle error if api.handleError(c, "Failed to build query", redirectQueryErr) { return } + // Redirect to continue with the redirect URI c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/continue?%s", api.Config.AppURL, redirectQuery.Encode())) }) + + // Simple healthcheck + api.Router.GET("/api/healthcheck", func(c *gin.Context) { + c.JSON(200, gin.H{ + "status": 200, + "message": "OK", + }) + }) } func (api *API) Run() { log.Info().Str("address", api.Config.Address).Int("port", api.Config.Port).Msg("Starting server") + + // Run server api.Router.Run(fmt.Sprintf("%s:%d", api.Config.Address, api.Config.Port)) } +// handleError logs the error and redirects to the error page (only meant for stuff the user may access does not apply for login paths) func (api *API) handleError(c *gin.Context, msg string, err error) bool { + // If error is not nil log it and redirect to error page also return true so we can stop further processing if err != nil { log.Error().Err(err).Msg(msg) c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/error", api.Config.AppURL)) @@ -518,19 +589,25 @@ func (api *API) handleError(c *gin.Context, msg string, err error) bool { return false } +// zerolog is a middleware for gin that logs requests using zerolog func zerolog() gin.HandlerFunc { return func(c *gin.Context) { + // Get initial time tStart := time.Now() + // Process request c.Next() + // Get status code, address, method and path code := c.Writer.Status() address := c.Request.RemoteAddr method := c.Request.Method path := c.Request.URL.Path + // Get latency latency := time.Since(tStart).String() + // Log request switch { case code >= 200 && code < 300: log.Info().Str("method", method).Str("path", path).Str("address", address).Int("status", code).Str("latency", latency).Msg("Request") diff --git a/internal/assets/assets.go b/internal/assets/assets.go index 10b9e09..fc88051 100644 --- a/internal/assets/assets.go +++ b/internal/assets/assets.go @@ -4,8 +4,12 @@ import ( "embed" ) +// UI assets +// //go:embed dist var Assets embed.FS +// Version file +// //go:embed version -var Version string \ No newline at end of file +var Version string diff --git a/internal/auth/auth.go b/internal/auth/auth.go index 3448346..273c9ba 100644 --- a/internal/auth/auth.go +++ b/internal/auth/auth.go @@ -31,6 +31,7 @@ type Auth struct { } func (auth *Auth) GetUser(username string) *types.User { + // Loop through users and return the user if the username matches for _, user := range auth.Users { if user.Username == username { return &user @@ -40,64 +41,93 @@ func (auth *Auth) GetUser(username string) *types.User { } func (auth *Auth) CheckPassword(user types.User, password string) bool { - hashedPasswordErr := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)) - return hashedPasswordErr == nil + // Compare the hashed password with the password provided + return bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)) == nil } func (auth *Auth) EmailWhitelisted(emailSrc string) bool { + // If the whitelist is empty, allow all emails if len(auth.OAuthWhitelist) == 0 { return true } + + // Loop through the whitelist and return true if the email matches for _, email := range auth.OAuthWhitelist { if email == emailSrc { return true } } + + // If no emails match, return false return false } func (auth *Auth) CreateSessionCookie(c *gin.Context, data *types.SessionCookie) { log.Debug().Msg("Creating session cookie") + + // Get session sessions := sessions.Default(c) + log.Debug().Msg("Setting session cookie") + + // Set data sessions.Set("username", data.Username) sessions.Set("provider", data.Provider) sessions.Set("expiry", time.Now().Add(time.Duration(auth.SessionExpiry)*time.Second).Unix()) + + // Save session sessions.Save() } func (auth *Auth) DeleteSessionCookie(c *gin.Context) { log.Debug().Msg("Deleting session cookie") + + // Get session sessions := sessions.Default(c) + + // Clear session sessions.Clear() + + // Save session sessions.Save() } func (auth *Auth) GetSessionCookie(c *gin.Context) types.SessionCookie { log.Debug().Msg("Getting session cookie") + + // Get session sessions := sessions.Default(c) + // Get data cookieUsername := sessions.Get("username") cookieProvider := sessions.Get("provider") cookieExpiry := sessions.Get("expiry") + // Convert interfaces to correct types username, usernameOk := cookieUsername.(string) provider, providerOk := cookieProvider.(string) expiry, expiryOk := cookieExpiry.(int64) + // Check if the cookie is invalid if !usernameOk || !providerOk || !expiryOk { log.Warn().Msg("Session cookie invalid") return types.SessionCookie{} } + // Check if the cookie has expired if time.Now().Unix() > expiry { log.Warn().Msg("Session cookie expired") + + // If it has, delete it auth.DeleteSessionCookie(c) + + // Return empty cookie return types.SessionCookie{} } log.Debug().Str("username", username).Str("provider", provider).Int64("expiry", expiry).Msg("Parsed cookie") + // Return the cookie return types.SessionCookie{ Username: username, Provider: provider, @@ -105,42 +135,56 @@ func (auth *Auth) GetSessionCookie(c *gin.Context) types.SessionCookie { } func (auth *Auth) UserAuthConfigured() bool { + // If there are users, return true return len(auth.Users) > 0 } func (auth *Auth) ResourceAllowed(context types.UserContext, host string) (bool, error) { + // Check if we have access to the Docker API isConnected := auth.Docker.DockerConnected() + // If we don't have access, it is assumed that the user has access if !isConnected { log.Debug().Msg("Docker not connected, allowing access") return true, nil } + // Get the app ID from the host appId := strings.Split(host, ".")[0] + + // Get the containers containers, containersErr := auth.Docker.GetContainers() + // If there is an error, return false if containersErr != nil { return false, containersErr } log.Debug().Msg("Got containers") + // Loop through the containers for _, container := range containers { + // Inspect the container inspect, inspectErr := auth.Docker.InspectContainer(container.ID) + // If there is an error, return false if inspectErr != nil { return false, inspectErr } + // Get the container name (for some reason it is /name) containerName := strings.Split(inspect.Name, "/")[1] + // There is a container with the same name as the app ID if containerName == appId { log.Debug().Str("container", containerName).Msg("Found container") + // Get only the tinyauth labels in a struct labels := utils.GetTinyauthLabels(inspect.Config.Labels) log.Debug().Msg("Got labels") + // If the container has an oauth whitelist, check if the user is in it if context.OAuth && len(labels.OAuthWhitelist) != 0 { log.Debug().Msg("Checking OAuth whitelist") if slices.Contains(labels.OAuthWhitelist, context.Username) { @@ -149,6 +193,7 @@ func (auth *Auth) ResourceAllowed(context types.UserContext, host string) (bool, return false, nil } + // If the container has users, check if the user is in it if len(labels.Users) != 0 { log.Debug().Msg("Checking users") if slices.Contains(labels.Users, context.Username) { @@ -162,32 +207,40 @@ func (auth *Auth) ResourceAllowed(context types.UserContext, host string) (bool, log.Debug().Msg("No matching container found, allowing access") + // If no matching container is found, allow access return true, nil } func (auth *Auth) GetBasicAuth(c *gin.Context) types.User { + // Get the Authorization header header := c.GetHeader("Authorization") + // If the header is empty, return an empty user if header == "" { return types.User{} } + // Split the header headerSplit := strings.Split(header, " ") if len(headerSplit) != 2 { return types.User{} } + // Check if the header is Basic if headerSplit[0] != "Basic" { return types.User{} } + // Split the credentials credentials := strings.Split(headerSplit[1], ":") + // If the credentials are not in the correct format, return an empty user if len(credentials) != 2 { return types.User{} } + // Return the user return types.User{ Username: credentials[0], Password: credentials[1], diff --git a/internal/constants/constants.go b/internal/constants/constants.go index 996b879..ce12089 100644 --- a/internal/constants/constants.go +++ b/internal/constants/constants.go @@ -1,5 +1,6 @@ package constants +// TinyauthLabels is a list of labels that can be used in a tinyauth protected container var TinyauthLabels = []string{ "tinyauth.oauth.whitelist", "tinyauth.users", diff --git a/internal/docker/docker.go b/internal/docker/docker.go index c25bcf6..3accea6 100644 --- a/internal/docker/docker.go +++ b/internal/docker/docker.go @@ -18,39 +18,50 @@ type Docker struct { } func (docker *Docker) Init() error { + // Create a new docker client apiClient, err := client.NewClientWithOpts(client.FromEnv) + // Check if there was an error if err != nil { return err } + // Set the context and api client docker.Context = context.Background() docker.Client = apiClient + // Done return nil } func (docker *Docker) GetContainers() ([]types.Container, error) { + // Get the list of containers containers, err := docker.Client.ContainerList(docker.Context, container.ListOptions{}) + // Check if there was an error if err != nil { return nil, err } + // Return the containers return containers, nil } func (docker *Docker) InspectContainer(containerId string) (types.ContainerJSON, error) { + // Inspect the container inspect, err := docker.Client.ContainerInspect(docker.Context, containerId) + // Check if there was an error if err != nil { return types.ContainerJSON{}, err } + // Return the inspect return inspect, nil } func (docker *Docker) DockerConnected() bool { + // Ping the docker client if there is an error it is not connected _, err := docker.Client.Ping(docker.Context) return err == nil } diff --git a/internal/hooks/hooks.go b/internal/hooks/hooks.go index 88869bf..cdbf852 100644 --- a/internal/hooks/hooks.go +++ b/internal/hooks/hooks.go @@ -22,13 +22,19 @@ type Hooks struct { } func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext { + // Get session cookie and basic auth cookie := hooks.Auth.GetSessionCookie(c) basic := hooks.Auth.GetBasicAuth(c) + // Check if basic auth is set if basic.Username != "" { log.Debug().Msg("Got basic auth") + + // Check if user exists and password is correct user := hooks.Auth.GetUser(basic.Username) + if user != nil && hooks.Auth.CheckPassword(*user, basic.Password) { + // Return user context since we are logged in with basic auth return types.UserContext{ Username: basic.Username, IsLoggedIn: true, @@ -39,10 +45,15 @@ func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext { } + // Check if session cookie is username/password auth if cookie.Provider == "username" { log.Debug().Msg("Provider is username") + + // Check if user exists if hooks.Auth.GetUser(cookie.Username) != nil { log.Debug().Msg("User exists") + + // It exists so we are logged in return types.UserContext{ Username: cookie.Username, IsLoggedIn: true, @@ -53,13 +64,22 @@ func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext { } log.Debug().Msg("Provider is not username") + + // The provider is not username so we need to check if it is an oauth provider provider := hooks.Providers.GetProvider(cookie.Provider) + // If we have a provider with this name if provider != nil { log.Debug().Msg("Provider exists") + + // Check if the oauth email is whitelisted if !hooks.Auth.EmailWhitelisted(cookie.Username) { log.Error().Str("email", cookie.Username).Msg("Email is not whitelisted") + + // It isn't so we delete the cookie and return an empty context hooks.Auth.DeleteSessionCookie(c) + + // Return empty context return types.UserContext{ Username: "", IsLoggedIn: false, @@ -67,7 +87,10 @@ func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext { Provider: "", } } + log.Debug().Msg("Email is whitelisted") + + // Return user context since we are logged in with oauth return types.UserContext{ Username: cookie.Username, IsLoggedIn: true, @@ -76,6 +99,7 @@ func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext { } } + // Neither basic auth or oauth is set so we return an empty context return types.UserContext{ Username: "", IsLoggedIn: false, diff --git a/internal/oauth/oauth.go b/internal/oauth/oauth.go index 9ec2690..86ca010 100644 --- a/internal/oauth/oauth.go +++ b/internal/oauth/oauth.go @@ -21,23 +21,33 @@ type OAuth struct { } func (oauth *OAuth) Init() { + // Create a new context and verifier oauth.Context = context.Background() oauth.Verifier = oauth2.GenerateVerifier() } func (oauth *OAuth) GetAuthURL() string { + // Return the auth url return oauth.Config.AuthCodeURL("state", oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(oauth.Verifier)) } func (oauth *OAuth) ExchangeToken(code string) (string, error) { + // Exchange the code for a token token, err := oauth.Config.Exchange(oauth.Context, code, oauth2.VerifierOption(oauth.Verifier)) + + // Check if there was an error if err != nil { return "", err } + + // Set the token oauth.Token = token + + // Return the access token return oauth.Token.AccessToken, nil } func (oauth *OAuth) GetClient() *http.Client { + // Return the http client with the token set return oauth.Config.Client(oauth.Context, oauth.Token) } diff --git a/internal/providers/generic.go b/internal/providers/generic.go index 2dbcf4e..80b32ef 100644 --- a/internal/providers/generic.go +++ b/internal/providers/generic.go @@ -8,36 +8,45 @@ import ( "github.com/rs/zerolog/log" ) +// We are assuming that the generic provider will return a JSON object with an email field type GenericUserInfoResponse struct { Email string `json:"email"` } func GetGenericEmail(client *http.Client, url string) (string, error) { + // Using the oauth client get the user info url res, resErr := client.Get(url) + // Check if there was an error if resErr != nil { return "", resErr } log.Debug().Msg("Got response from generic provider") + // Read the body of the response body, bodyErr := io.ReadAll(res.Body) + // Check if there was an error if bodyErr != nil { return "", bodyErr } log.Debug().Msg("Read body from generic provider") + // Parse the body into a user struct var user GenericUserInfoResponse + // Unmarshal the body into the user struct jsonErr := json.Unmarshal(body, &user) + // Check if there was an error if jsonErr != nil { return "", jsonErr } log.Debug().Msg("Parsed user from generic provider") + // Return the email return user.Email, nil } diff --git a/internal/providers/github.go b/internal/providers/github.go index 515652d..d48d5df 100644 --- a/internal/providers/github.go +++ b/internal/providers/github.go @@ -9,47 +9,58 @@ import ( "github.com/rs/zerolog/log" ) +// Github has a different response than the generic provider type GithubUserInfoResponse []struct { Email string `json:"email"` Primary bool `json:"primary"` } +// The scopes required for the github provider func GithubScopes() []string { return []string{"user:email"} } func GetGithubEmail(client *http.Client) (string, error) { + // Get the user emails from github using the oauth http client res, resErr := client.Get("https://api.github.com/user/emails") + // Check if there was an error if resErr != nil { return "", resErr } log.Debug().Msg("Got response from github") + // Read the body of the response body, bodyErr := io.ReadAll(res.Body) + // Check if there was an error if bodyErr != nil { return "", bodyErr } log.Debug().Msg("Read body from github") + // Parse the body into a user struct var emails GithubUserInfoResponse + // Unmarshal the body into the user struct jsonErr := json.Unmarshal(body, &emails) + // Check if there was an error if jsonErr != nil { return "", jsonErr } log.Debug().Msg("Parsed emails from github") + // Find and return the primary email for _, email := range emails { if email.Primary { return email.Email, nil } } + // User does not have a primary email? return "", errors.New("no primary email found") } diff --git a/internal/providers/google.go b/internal/providers/google.go index d3554fd..4b31891 100644 --- a/internal/providers/google.go +++ b/internal/providers/google.go @@ -8,40 +8,50 @@ import ( "github.com/rs/zerolog/log" ) +// Google works the same as the generic provider type GoogleUserInfoResponse struct { Email string `json:"email"` } +// The scopes required for the google provider func GoogleScopes() []string { return []string{"https://www.googleapis.com/auth/userinfo.email"} } func GetGoogleEmail(client *http.Client) (string, error) { + // Get the user info from google using the oauth http client res, resErr := client.Get("https://www.googleapis.com/userinfo/v2/me") + // Check if there was an error if resErr != nil { return "", resErr } log.Debug().Msg("Got response from google") + // Read the body of the response body, bodyErr := io.ReadAll(res.Body) + // Check if there was an error if bodyErr != nil { return "", bodyErr } log.Debug().Msg("Read body from google") + // Parse the body into a user struct var user GoogleUserInfoResponse + // Unmarshal the body into the user struct jsonErr := json.Unmarshal(body, &user) + // Check if there was an error if jsonErr != nil { return "", jsonErr } log.Debug().Msg("Parsed user from google") + // Return the email return user.Email, nil } diff --git a/internal/providers/providers.go b/internal/providers/providers.go index 58ef44b..e826684 100644 --- a/internal/providers/providers.go +++ b/internal/providers/providers.go @@ -25,8 +25,11 @@ type Providers struct { } func (providers *Providers) Init() { + // If we have a client id and secret for github, initialize the oauth provider if providers.Config.GithubClientId != "" && providers.Config.GithubClientSecret != "" { log.Info().Msg("Initializing Github OAuth") + + // Create a new oauth provider with the github config providers.Github = oauth.NewOAuth(oauth2.Config{ ClientID: providers.Config.GithubClientId, ClientSecret: providers.Config.GithubClientSecret, @@ -34,10 +37,16 @@ func (providers *Providers) Init() { Scopes: GithubScopes(), Endpoint: endpoints.GitHub, }) + + // Initialize the oauth provider providers.Github.Init() } + + // If we have a client id and secret for google, initialize the oauth provider if providers.Config.GoogleClientId != "" && providers.Config.GoogleClientSecret != "" { log.Info().Msg("Initializing Google OAuth") + + // Create a new oauth provider with the google config providers.Google = oauth.NewOAuth(oauth2.Config{ ClientID: providers.Config.GoogleClientId, ClientSecret: providers.Config.GoogleClientSecret, @@ -45,10 +54,15 @@ func (providers *Providers) Init() { Scopes: GoogleScopes(), Endpoint: endpoints.Google, }) + + // Initialize the oauth provider providers.Google.Init() } + if providers.Config.TailscaleClientId != "" && providers.Config.TailscaleClientSecret != "" { log.Info().Msg("Initializing Tailscale OAuth") + + // Create a new oauth provider with the tailscale config providers.Tailscale = oauth.NewOAuth(oauth2.Config{ ClientID: providers.Config.TailscaleClientId, ClientSecret: providers.Config.TailscaleClientSecret, @@ -56,10 +70,16 @@ func (providers *Providers) Init() { Scopes: TailscaleScopes(), Endpoint: TailscaleEndpoint, }) + + // Initialize the oauth provider providers.Tailscale.Init() } + + // If we have a client id and secret for generic oauth, initialize the oauth provider if providers.Config.GenericClientId != "" && providers.Config.GenericClientSecret != "" { log.Info().Msg("Initializing Generic OAuth") + + // Create a new oauth provider with the generic config providers.Generic = oauth.NewOAuth(oauth2.Config{ ClientID: providers.Config.GenericClientId, ClientSecret: providers.Config.GenericClientSecret, @@ -70,11 +90,14 @@ func (providers *Providers) Init() { TokenURL: providers.Config.GenericTokenURL, }, }) + + // Initialize the oauth provider providers.Generic.Init() } } func (providers *Providers) GetProvider(provider string) *oauth.OAuth { + // Return the provider based on the provider string switch provider { case "github": return providers.Github @@ -90,58 +113,103 @@ func (providers *Providers) GetProvider(provider string) *oauth.OAuth { } func (providers *Providers) GetUser(provider string) (string, error) { + // Get the email from the provider switch provider { case "github": + // If the github provider is not configured, return an error if providers.Github == nil { log.Debug().Msg("Github provider not configured") return "", nil } + + // Get the client from the github provider client := providers.Github.GetClient() + log.Debug().Msg("Got client from github") + + // Get the email from the github provider email, emailErr := GetGithubEmail(client) + + // Check if there was an error if emailErr != nil { return "", emailErr } + log.Debug().Msg("Got email from github") + + // Return the email return email, nil case "google": + // If the google provider is not configured, return an error if providers.Google == nil { log.Debug().Msg("Google provider not configured") return "", nil } + + // Get the client from the google provider client := providers.Google.GetClient() + log.Debug().Msg("Got client from google") + + // Get the email from the google provider email, emailErr := GetGoogleEmail(client) + + // Check if there was an error if emailErr != nil { return "", emailErr } + log.Debug().Msg("Got email from google") + + // Return the email return email, nil case "tailscale": + // If the tailscale provider is not configured, return an error if providers.Tailscale == nil { log.Debug().Msg("Tailscale provider not configured") return "", nil } + + // Get the client from the tailscale provider client := providers.Tailscale.GetClient() + log.Debug().Msg("Got client from tailscale") + + // Get the email from the tailscale provider email, emailErr := GetTailscaleEmail(client) + + // Check if there was an error if emailErr != nil { return "", emailErr } + log.Debug().Msg("Got email from tailscale") + + // Return the email return email, nil case "generic": + // If the generic provider is not configured, return an error if providers.Generic == nil { log.Debug().Msg("Generic provider not configured") return "", nil } + + // Get the client from the generic provider client := providers.Generic.GetClient() + log.Debug().Msg("Got client from generic") + + // Get the email from the generic provider email, emailErr := GetGenericEmail(client, providers.Config.GenericUserURL) + + // Check if there was an error if emailErr != nil { return "", emailErr } + log.Debug().Msg("Got email from generic") + + // Return the email return email, nil default: return "", nil @@ -149,6 +217,7 @@ func (providers *Providers) GetUser(provider string) (string, error) { } func (provider *Providers) GetConfiguredProviders() []string { + // Create a list of the configured providers providers := []string{} if provider.Github != nil { providers = append(providers, "github") diff --git a/internal/providers/tailscale.go b/internal/providers/tailscale.go index 99a0346..ec8f08c 100644 --- a/internal/providers/tailscale.go +++ b/internal/providers/tailscale.go @@ -9,48 +9,60 @@ import ( "golang.org/x/oauth2" ) +// The tailscale email is the loginName type TailscaleUser struct { LoginName string `json:"loginName"` } +// The response from the tailscale user info endpoint type TailscaleUserInfoResponse struct { Users []TailscaleUser `json:"users"` } +// The scopes required for the tailscale provider func TailscaleScopes() []string { return []string{"users:read"} } +// The tailscale endpoint var TailscaleEndpoint = oauth2.Endpoint{ TokenURL: "https://api.tailscale.com/api/v2/oauth/token", } func GetTailscaleEmail(client *http.Client) (string, error) { + // Get the user info from tailscale using the oauth http client res, resErr := client.Get("https://api.tailscale.com/api/v2/tailnet/-/users") + // Check if there was an error if resErr != nil { return "", resErr } log.Debug().Msg("Got response from tailscale") + // Read the body of the response body, bodyErr := io.ReadAll(res.Body) + // Check if there was an error if bodyErr != nil { return "", bodyErr } log.Debug().Msg("Read body from tailscale") + // Parse the body into a user struct var users TailscaleUserInfoResponse + // Unmarshal the body into the user struct jsonErr := json.Unmarshal(body, &users) + // Check if there was an error if jsonErr != nil { return "", jsonErr } log.Debug().Msg("Parsed users from tailscale") + // Return the email of the first user return users.Users[0].LoginName, nil } diff --git a/internal/types/types.go b/internal/types/types.go index 81bf568..591ca63 100644 --- a/internal/types/types.go +++ b/internal/types/types.go @@ -2,22 +2,27 @@ package types import "tinyauth/internal/oauth" +// LoginQuery is the query parameters for the login endpoint type LoginQuery struct { RedirectURI string `url:"redirect_uri"` } +// LoginRequest is the request body for the login endpoint type LoginRequest struct { Username string `json:"username"` Password string `json:"password"` } +// User is the struct for a user type User struct { Username string Password string } +// Users is a list of users type Users []User +// Config is the configuration for the tinyauth server type Config struct { Port int `mapstructure:"port" validate:"required"` Address string `validate:"required,ip4_addr" mapstructure:"address"` @@ -49,6 +54,7 @@ type Config struct { LogLevel int8 `mapstructure:"log-level" validate:"min=-1,max=5"` } +// UserContext is the context for the user type UserContext struct { Username string IsLoggedIn bool @@ -56,6 +62,7 @@ type UserContext struct { Provider string } +// APIConfig is the configuration for the API type APIConfig struct { Port int Address string @@ -66,6 +73,7 @@ type APIConfig struct { DisableContinue bool } +// OAuthConfig is the configuration for the providers type OAuthConfig struct { GithubClientId string GithubClientSecret string @@ -82,35 +90,42 @@ type OAuthConfig struct { AppURL string } +// OAuthRequest is the request for the OAuth endpoint type OAuthRequest struct { Provider string `uri:"provider" binding:"required"` } +// OAuthProviders is the struct for the OAuth providers type OAuthProviders struct { Github *oauth.OAuth Google *oauth.OAuth Microsoft *oauth.OAuth } +// UnauthorizedQuery is the query parameters for the unauthorized endpoint type UnauthorizedQuery struct { Username string `url:"username"` Resource string `url:"resource"` } +// SessionCookie is the cookie for the session (exculding the expiry) type SessionCookie struct { Username string Provider string } +// TinyauthLabels is the labels for the tinyauth container type TinyauthLabels struct { OAuthWhitelist []string Users []string } +// TailscaleQuery is the query parameters for the tailscale endpoint type TailscaleQuery struct { Code int `url:"code"` } +// Proxy is the uri parameters for the proxy endpoint type Proxy struct { Proxy string `uri:"proxy" binding:"required"` } diff --git a/internal/utils/utils.go b/internal/utils/utils.go index 17f052f..29344d5 100644 --- a/internal/utils/utils.go +++ b/internal/utils/utils.go @@ -12,20 +12,32 @@ import ( "github.com/rs/zerolog/log" ) +// Parses a list of comma separated users in a struct func ParseUsers(users string) (types.Users, error) { log.Debug().Msg("Parsing users") + + // Create a new users struct var usersParsed types.Users + + // Split the users by comma userList := strings.Split(users, ",") + // Check if there are any users if len(userList) == 0 { return types.Users{}, errors.New("invalid user format") } + // Loop through the users and split them by colon for _, user := range userList { + // Split the user by colon userSplit := strings.Split(user, ":") + + // Check if the user is in the correct format if len(userSplit) != 2 { return types.Users{}, errors.New("invalid user format") } + + // Append the user to the users struct usersParsed = append(usersParsed, types.User{ Username: userSplit[0], Password: userSplit[1], @@ -34,43 +46,61 @@ func ParseUsers(users string) (types.Users, error) { log.Debug().Msg("Parsed users") + // Return the users struct return usersParsed, nil } +// Root url parses parses a hostname and returns the root domain (e.g. sub1.sub2.domain.com -> domain.com) func GetRootURL(urlSrc string) (string, error) { + // Make sure the url is valid urlParsed, parseErr := url.Parse(urlSrc) + // Check if there was an error if parseErr != nil { return "", parseErr } + // Split the hostname by period urlSplitted := strings.Split(urlParsed.Hostname(), ".") + // Get the last part of the url urlFinal := strings.Join(urlSplitted[1:], ".") + // Return the root domain return urlFinal, nil } +// Reads a file and returns the contents func ReadFile(file string) (string, error) { + // Check if the file exists _, statErr := os.Stat(file) + // Check if there was an error if statErr != nil { return "", statErr } + // Read the file data, readErr := os.ReadFile(file) + // Check if there was an error if readErr != nil { return "", readErr } + // Return the file contents return string(data), nil } +// Parses a file into a comma separated list of users func ParseFileToLine(content string) string { + // Split the content by newline lines := strings.Split(content, "\n") + + // Create a list of users users := make([]string, 0) + // Loop through the lines, trimming the whitespace and appending to the users list for _, line := range lines { if strings.TrimSpace(line) == "" { continue @@ -79,63 +109,92 @@ func ParseFileToLine(content string) string { users = append(users, strings.TrimSpace(line)) } + // Return the users as a comma separated string return strings.Join(users, ",") } +// Get the secret from the config or file func GetSecret(conf string, file string) string { + // If neither the config or file is set, return an empty string if conf == "" && file == "" { return "" } + // If the config is set, return the config (environment variable) if conf != "" { return conf } + // If the file is set, read the file contents, err := ReadFile(file) + // Check if there was an error if err != nil { return "" } + // Return the contents of the file return contents } +// Get the users from the config or file func GetUsers(conf string, file string) (types.Users, error) { + // Create a string to store the users var users string + // If neither the config or file is set, return an empty users struct if conf == "" && file == "" { return types.Users{}, nil } + // If the config (environment) is set, append the users to the users string if conf != "" { log.Debug().Msg("Using users from config") users += conf } + // If the file is set, read the file and append the users to the users string if file != "" { + // Read the file fileContents, fileErr := ReadFile(file) + // If there isn't an error we can append the users to the users string if fileErr == nil { log.Debug().Msg("Using users from file") + + // Append the users to the users string if users != "" { users += "," } + + // Parse the file contents into a comma separated list of users users += ParseFileToLine(fileContents) } } + // Return the parsed users return ParseUsers(users) } +// Check if any of the OAuth providers are configured based on the client id and secret func OAuthConfigured(config types.Config) bool { return (config.GithubClientId != "" && config.GithubClientSecret != "") || (config.GoogleClientId != "" && config.GoogleClientSecret != "") || (config.GenericClientId != "" && config.GenericClientSecret != "") || (config.TailscaleClientId != "" && config.TailscaleClientSecret != "") } +// Parse the docker labels to the tinyauth labels struct func GetTinyauthLabels(labels map[string]string) types.TinyauthLabels { + // Create a new tinyauth labels struct var tinyauthLabels types.TinyauthLabels + + // Loop through the labels for label, value := range labels { + + // Check if the label is in the tinyauth labels if slices.Contains(constants.TinyauthLabels, label) { + log.Debug().Str("label", label).Msg("Found label") + + // Add the label value to the tinyauth labels struct switch label { case "tinyauth.oauth.whitelist": tinyauthLabels.OAuthWhitelist = strings.Split(value, ",") @@ -144,5 +203,7 @@ func GetTinyauthLabels(labels map[string]string) types.TinyauthLabels { } } } + + // Return the tinyauth labels return tinyauthLabels }