mirror of
				https://github.com/steveiliop56/tinyauth.git
				synced 2025-11-02 15:15:51 +00:00 
			
		
		
		
	Compare commits
	
		
			8 Commits
		
	
	
		
			2328e17ff4
			...
			feat/oauth
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 
						 | 
					86a18b4cac | ||
| 
						 | 
					d171c5940b | ||
| 
						 | 
					3dff650e71 | ||
| 
						 | 
					1fec583ead | ||
| 
						 | 
					065b9eaf3d | ||
| 
						 | 
					dca09a3d9d | ||
| 
						 | 
					5e4e2ddbd9 | ||
| 
						 | 
					13032e564d | 
@@ -111,6 +111,11 @@ var rootCmd = &cobra.Command{
 | 
				
			|||||||
			LoginMaxRetries: config.LoginMaxRetries,
 | 
								LoginMaxRetries: config.LoginMaxRetries,
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Create hooks config
 | 
				
			||||||
 | 
							hooksConfig := types.HooksConfig{
 | 
				
			||||||
 | 
								Domain: domain,
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		// Create docker service
 | 
							// Create docker service
 | 
				
			||||||
		docker := docker.NewDocker()
 | 
							docker := docker.NewDocker()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -128,7 +133,7 @@ var rootCmd = &cobra.Command{
 | 
				
			|||||||
		providers.Init()
 | 
							providers.Init()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		// Create hooks service
 | 
							// Create hooks service
 | 
				
			||||||
		hooks := hooks.NewHooks(auth, providers)
 | 
							hooks := hooks.NewHooks(hooksConfig, auth, providers)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		// Create handlers
 | 
							// Create handlers
 | 
				
			||||||
		handlers := handlers.NewHandlers(handlersConfig, auth, hooks, providers, docker)
 | 
							handlers := handlers.NewHandlers(handlersConfig, auth, hooks, providers, docker)
 | 
				
			||||||
@@ -189,7 +194,7 @@ func init() {
 | 
				
			|||||||
	rootCmd.Flags().String("generic-auth-url", "", "Generic OAuth auth URL.")
 | 
						rootCmd.Flags().String("generic-auth-url", "", "Generic OAuth auth URL.")
 | 
				
			||||||
	rootCmd.Flags().String("generic-token-url", "", "Generic OAuth token URL.")
 | 
						rootCmd.Flags().String("generic-token-url", "", "Generic OAuth token URL.")
 | 
				
			||||||
	rootCmd.Flags().String("generic-user-url", "", "Generic OAuth user info URL.")
 | 
						rootCmd.Flags().String("generic-user-url", "", "Generic OAuth user info URL.")
 | 
				
			||||||
	rootCmd.Flags().String("generic-name", "Other", "Generic OAuth provider name.")
 | 
						rootCmd.Flags().String("generic-name", "Generic", "Generic OAuth provider name.")
 | 
				
			||||||
	rootCmd.Flags().Bool("disable-continue", false, "Disable continue screen and redirect to app directly.")
 | 
						rootCmd.Flags().Bool("disable-continue", false, "Disable continue screen and redirect to app directly.")
 | 
				
			||||||
	rootCmd.Flags().String("oauth-whitelist", "", "Comma separated list of email addresses to whitelist when using OAuth.")
 | 
						rootCmd.Flags().String("oauth-whitelist", "", "Comma separated list of email addresses to whitelist when using OAuth.")
 | 
				
			||||||
	rootCmd.Flags().Int("session-expiry", 86400, "Session (cookie) expiration time in seconds.")
 | 
						rootCmd.Flags().Int("session-expiry", 86400, "Session (cookie) expiration time in seconds.")
 | 
				
			||||||
 
 | 
				
			|||||||
										
											Binary file not shown.
										
									
								
							
							
								
								
									
										4
									
								
								frontend/src/index.css
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										4
									
								
								frontend/src/index.css
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,4 @@
 | 
				
			|||||||
 | 
					span,
 | 
				
			||||||
 | 
					p {
 | 
				
			||||||
 | 
					  word-break: break-word;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@@ -41,7 +41,8 @@
 | 
				
			|||||||
    "totpTitle": "Enter your TOTP code",
 | 
					    "totpTitle": "Enter your TOTP code",
 | 
				
			||||||
    "unauthorizedTitle": "Unauthorized",
 | 
					    "unauthorizedTitle": "Unauthorized",
 | 
				
			||||||
    "unauthorizedResourceSubtitle": "The user with username <Code>{{username}}</Code> is not authorized to access the resource <Code>{{resource}}</Code>.",
 | 
					    "unauthorizedResourceSubtitle": "The user with username <Code>{{username}}</Code> is not authorized to access the resource <Code>{{resource}}</Code>.",
 | 
				
			||||||
    "unaothorizedLoginSubtitle": "The user with username <Code>{{username}}</Code> is not authorized to login.",
 | 
					    "unauthorizedLoginSubtitle": "The user with username <Code>{{username}}</Code> is not authorized to login.",
 | 
				
			||||||
 | 
					    "unauthorizedGroupsSubtitle": "The user with username <Code>{{username}}</Code> is not in the groups required by the resource <Code>{{resource}}</Code>.",
 | 
				
			||||||
    "unauthorizedButton": "Try again",
 | 
					    "unauthorizedButton": "Try again",
 | 
				
			||||||
    "untrustedRedirectTitle": "Untrusted redirect",
 | 
					    "untrustedRedirectTitle": "Untrusted redirect",
 | 
				
			||||||
    "untrustedRedirectSubtitle": "You are trying to redirect to a domain that does not match your configured domain (<Code>{{domain}}</Code>). Are you sure you want to continue?",
 | 
					    "untrustedRedirectSubtitle": "You are trying to redirect to a domain that does not match your configured domain (<Code>{{domain}}</Code>). Are you sure you want to continue?",
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -41,7 +41,8 @@
 | 
				
			|||||||
    "totpTitle": "Enter your TOTP code",
 | 
					    "totpTitle": "Enter your TOTP code",
 | 
				
			||||||
    "unauthorizedTitle": "Unauthorized",
 | 
					    "unauthorizedTitle": "Unauthorized",
 | 
				
			||||||
    "unauthorizedResourceSubtitle": "The user with username <Code>{{username}}</Code> is not authorized to access the resource <Code>{{resource}}</Code>.",
 | 
					    "unauthorizedResourceSubtitle": "The user with username <Code>{{username}}</Code> is not authorized to access the resource <Code>{{resource}}</Code>.",
 | 
				
			||||||
    "unaothorizedLoginSubtitle": "The user with username <Code>{{username}}</Code> is not authorized to login.",
 | 
					    "unauthorizedLoginSubtitle": "The user with username <Code>{{username}}</Code> is not authorized to login.",
 | 
				
			||||||
 | 
					    "unauthorizedGroupsSubtitle": "The user with username <Code>{{username}}</Code> is not in the groups required by the resource <Code>{{resource}}</Code>.",
 | 
				
			||||||
    "unauthorizedButton": "Try again",
 | 
					    "unauthorizedButton": "Try again",
 | 
				
			||||||
    "untrustedRedirectTitle": "Untrusted redirect",
 | 
					    "untrustedRedirectTitle": "Untrusted redirect",
 | 
				
			||||||
    "untrustedRedirectSubtitle": "You are trying to redirect to a domain that does not match your configured domain (<Code>{{domain}}</Code>). Are you sure you want to continue?",
 | 
					    "untrustedRedirectSubtitle": "You are trying to redirect to a domain that does not match your configured domain (<Code>{{domain}}</Code>). Are you sure you want to continue?",
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -19,6 +19,7 @@ import { TotpPage } from "./pages/totp-page.tsx";
 | 
				
			|||||||
import { AppContextProvider } from "./context/app-context.tsx";
 | 
					import { AppContextProvider } from "./context/app-context.tsx";
 | 
				
			||||||
import "./lib/i18n/i18n.ts";
 | 
					import "./lib/i18n/i18n.ts";
 | 
				
			||||||
import { ForgotPasswordPage } from "./pages/forgot-password-page.tsx";
 | 
					import { ForgotPasswordPage } from "./pages/forgot-password-page.tsx";
 | 
				
			||||||
 | 
					import "./index.css";
 | 
				
			||||||
 | 
					
 | 
				
			||||||
const queryClient = new QueryClient();
 | 
					const queryClient = new QueryClient();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -38,7 +39,10 @@ createRoot(document.getElementById("root")!).render(
 | 
				
			|||||||
                <Route path="/continue" element={<ContinuePage />} />
 | 
					                <Route path="/continue" element={<ContinuePage />} />
 | 
				
			||||||
                <Route path="/unauthorized" element={<UnauthorizedPage />} />
 | 
					                <Route path="/unauthorized" element={<UnauthorizedPage />} />
 | 
				
			||||||
                <Route path="/error" element={<InternalServerError />} />
 | 
					                <Route path="/error" element={<InternalServerError />} />
 | 
				
			||||||
                <Route path="/forgot-password" element={<ForgotPasswordPage />} />
 | 
					                <Route
 | 
				
			||||||
 | 
					                  path="/forgot-password"
 | 
				
			||||||
 | 
					                  element={<ForgotPasswordPage />}
 | 
				
			||||||
 | 
					                />
 | 
				
			||||||
                <Route path="*" element={<NotFoundPage />} />
 | 
					                <Route path="*" element={<NotFoundPage />} />
 | 
				
			||||||
              </Routes>
 | 
					              </Routes>
 | 
				
			||||||
            </BrowserRouter>
 | 
					            </BrowserRouter>
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -10,7 +10,7 @@ import { useAppContext } from "../context/app-context";
 | 
				
			|||||||
import { Trans, useTranslation } from "react-i18next";
 | 
					import { Trans, useTranslation } from "react-i18next";
 | 
				
			||||||
 | 
					
 | 
				
			||||||
export const LogoutPage = () => {
 | 
					export const LogoutPage = () => {
 | 
				
			||||||
  const { isLoggedIn, username, oauth, provider } = useUserContext();
 | 
					  const { isLoggedIn, oauth, provider, email, username } = useUserContext();
 | 
				
			||||||
  const { genericName } = useAppContext();
 | 
					  const { genericName } = useAppContext();
 | 
				
			||||||
  const { t } = useTranslation();
 | 
					  const { t } = useTranslation();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -56,7 +56,7 @@ export const LogoutPage = () => {
 | 
				
			|||||||
              values={{
 | 
					              values={{
 | 
				
			||||||
                provider:
 | 
					                provider:
 | 
				
			||||||
                  provider === "generic" ? genericName : capitalize(provider),
 | 
					                  provider === "generic" ? genericName : capitalize(provider),
 | 
				
			||||||
                username: username,
 | 
					                username: email,
 | 
				
			||||||
              }}
 | 
					              }}
 | 
				
			||||||
            />
 | 
					            />
 | 
				
			||||||
          ) : (
 | 
					          ) : (
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -3,11 +3,13 @@ import { Layout } from "../components/layouts/layout";
 | 
				
			|||||||
import { Navigate } from "react-router";
 | 
					import { Navigate } from "react-router";
 | 
				
			||||||
import { isQueryValid } from "../utils/utils";
 | 
					import { isQueryValid } from "../utils/utils";
 | 
				
			||||||
import { Trans, useTranslation } from "react-i18next";
 | 
					import { Trans, useTranslation } from "react-i18next";
 | 
				
			||||||
 | 
					import React from "react";
 | 
				
			||||||
 | 
					
 | 
				
			||||||
export const UnauthorizedPage = () => {
 | 
					export const UnauthorizedPage = () => {
 | 
				
			||||||
  const queryString = window.location.search;
 | 
					  const queryString = window.location.search;
 | 
				
			||||||
  const params = new URLSearchParams(queryString);
 | 
					  const params = new URLSearchParams(queryString);
 | 
				
			||||||
  const username = params.get("username") ?? "";
 | 
					  const username = params.get("username") ?? "";
 | 
				
			||||||
 | 
					  const groupErr = params.get("groupErr") ?? "";
 | 
				
			||||||
  const resource = params.get("resource") ?? "";
 | 
					  const resource = params.get("resource") ?? "";
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  const { t } = useTranslation();
 | 
					  const { t } = useTranslation();
 | 
				
			||||||
@@ -16,33 +18,54 @@ export const UnauthorizedPage = () => {
 | 
				
			|||||||
    return <Navigate to="/" />;
 | 
					    return <Navigate to="/" />;
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  if (isQueryValid(resource) && !isQueryValid(groupErr)) {
 | 
				
			||||||
    return (
 | 
					    return (
 | 
				
			||||||
    <Layout>
 | 
					      <UnauthorizedLayout>
 | 
				
			||||||
      <Paper shadow="md" p={30} mt={30} radius="md" withBorder>
 | 
					 | 
				
			||||||
        <Text size="xl" fw={700}>
 | 
					 | 
				
			||||||
          {t("Unauthorized")}
 | 
					 | 
				
			||||||
        </Text>
 | 
					 | 
				
			||||||
        <Text>
 | 
					 | 
				
			||||||
          {isQueryValid(resource) ? (
 | 
					 | 
				
			||||||
            <Text>
 | 
					 | 
				
			||||||
        <Trans
 | 
					        <Trans
 | 
				
			||||||
          i18nKey="unauthorizedResourceSubtitle"
 | 
					          i18nKey="unauthorizedResourceSubtitle"
 | 
				
			||||||
          t={t}
 | 
					          t={t}
 | 
				
			||||||
          components={{ Code: <Code /> }}
 | 
					          components={{ Code: <Code /> }}
 | 
				
			||||||
          values={{ resource, username }}
 | 
					          values={{ resource, username }}
 | 
				
			||||||
        />
 | 
					        />
 | 
				
			||||||
            </Text>
 | 
					      </UnauthorizedLayout>
 | 
				
			||||||
          ) : (
 | 
					    );
 | 
				
			||||||
            <Text>
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  if (isQueryValid(groupErr) && isQueryValid(resource)) {
 | 
				
			||||||
 | 
					    return (
 | 
				
			||||||
 | 
					      <UnauthorizedLayout>
 | 
				
			||||||
        <Trans
 | 
					        <Trans
 | 
				
			||||||
                i18nKey="unaothorizedLoginSubtitle"
 | 
					          i18nKey="unauthorizedGroupsSubtitle"
 | 
				
			||||||
 | 
					          t={t}
 | 
				
			||||||
 | 
					          components={{ Code: <Code /> }}
 | 
				
			||||||
 | 
					          values={{ username, resource }}
 | 
				
			||||||
 | 
					        />
 | 
				
			||||||
 | 
					      </UnauthorizedLayout>
 | 
				
			||||||
 | 
					    );
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  return (
 | 
				
			||||||
 | 
					    <UnauthorizedLayout>
 | 
				
			||||||
 | 
					      <Trans
 | 
				
			||||||
 | 
					        i18nKey="unauthorizedLoginSubtitle"
 | 
				
			||||||
        t={t}
 | 
					        t={t}
 | 
				
			||||||
        components={{ Code: <Code /> }}
 | 
					        components={{ Code: <Code /> }}
 | 
				
			||||||
        values={{ username }}
 | 
					        values={{ username }}
 | 
				
			||||||
      />
 | 
					      />
 | 
				
			||||||
 | 
					    </UnauthorizedLayout>
 | 
				
			||||||
 | 
					  );
 | 
				
			||||||
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					const UnauthorizedLayout = ({ children }: { children: React.ReactNode }) => {
 | 
				
			||||||
 | 
					  const { t } = useTranslation();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  return (
 | 
				
			||||||
 | 
					    <Layout>
 | 
				
			||||||
 | 
					      <Paper shadow="md" p={30} mt={30} radius="md" withBorder>
 | 
				
			||||||
 | 
					        <Text size="xl" fw={700}>
 | 
				
			||||||
 | 
					          {t("Unauthorized")}
 | 
				
			||||||
        </Text>
 | 
					        </Text>
 | 
				
			||||||
          )}
 | 
					        <Text>{children}</Text>
 | 
				
			||||||
        </Text>
 | 
					 | 
				
			||||||
        <Button
 | 
					        <Button
 | 
				
			||||||
          fullWidth
 | 
					          fullWidth
 | 
				
			||||||
          mt="xl"
 | 
					          mt="xl"
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -3,6 +3,8 @@ import { z } from "zod";
 | 
				
			|||||||
export const userContextSchema = z.object({
 | 
					export const userContextSchema = z.object({
 | 
				
			||||||
  isLoggedIn: z.boolean(),
 | 
					  isLoggedIn: z.boolean(),
 | 
				
			||||||
  username: z.string(),
 | 
					  username: z.string(),
 | 
				
			||||||
 | 
					  name: z.string(),
 | 
				
			||||||
 | 
					  email: z.string(),
 | 
				
			||||||
  oauth: z.boolean(),
 | 
					  oauth: z.boolean(),
 | 
				
			||||||
  provider: z.string(),
 | 
					  provider: z.string(),
 | 
				
			||||||
  totpPending: z.boolean(),
 | 
					  totpPending: z.boolean(),
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -45,6 +45,11 @@ var authConfig = types.AuthConfig{
 | 
				
			|||||||
	LoginMaxRetries: 0,
 | 
						LoginMaxRetries: 0,
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Simple hooks config for tests
 | 
				
			||||||
 | 
					var hooksConfig = types.HooksConfig{
 | 
				
			||||||
 | 
						Domain: "localhost",
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Cookie
 | 
					// Cookie
 | 
				
			||||||
var cookie string
 | 
					var cookie string
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -83,7 +88,7 @@ func getAPI(t *testing.T) *api.API {
 | 
				
			|||||||
	providers.Init()
 | 
						providers.Init()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Create hooks service
 | 
						// Create hooks service
 | 
				
			||||||
	hooks := hooks.NewHooks(auth, providers)
 | 
						hooks := hooks.NewHooks(hooksConfig, auth, providers)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Create handlers service
 | 
						// Create handlers service
 | 
				
			||||||
	handlers := handlers.NewHandlers(handlersConfig, auth, hooks, providers, docker)
 | 
						handlers := handlers.NewHandlers(handlersConfig, auth, hooks, providers, docker)
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -160,9 +160,12 @@ func (auth *Auth) CreateSessionCookie(c *gin.Context, data *types.SessionCookie)
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	// Set data
 | 
						// Set data
 | 
				
			||||||
	session.Values["username"] = data.Username
 | 
						session.Values["username"] = data.Username
 | 
				
			||||||
 | 
						session.Values["name"] = data.Name
 | 
				
			||||||
 | 
						session.Values["email"] = data.Email
 | 
				
			||||||
	session.Values["provider"] = data.Provider
 | 
						session.Values["provider"] = data.Provider
 | 
				
			||||||
	session.Values["expiry"] = time.Now().Add(time.Duration(sessionExpiry) * time.Second).Unix()
 | 
						session.Values["expiry"] = time.Now().Add(time.Duration(sessionExpiry) * time.Second).Unix()
 | 
				
			||||||
	session.Values["totpPending"] = data.TotpPending
 | 
						session.Values["totpPending"] = data.TotpPending
 | 
				
			||||||
 | 
						session.Values["oauthGroups"] = data.OAuthGroups
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Save session
 | 
						// Save session
 | 
				
			||||||
	err = session.Save(c.Request, c.Writer)
 | 
						err = session.Save(c.Request, c.Writer)
 | 
				
			||||||
@@ -211,14 +214,24 @@ func (auth *Auth) GetSessionCookie(c *gin.Context) (types.SessionCookie, error)
 | 
				
			|||||||
		return types.SessionCookie{}, err
 | 
							return types.SessionCookie{}, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						log.Debug().Msg("Got session")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Get data from session
 | 
						// Get data from session
 | 
				
			||||||
	username, usernameOk := session.Values["username"].(string)
 | 
						username, usernameOk := session.Values["username"].(string)
 | 
				
			||||||
 | 
						email, emailOk := session.Values["email"].(string)
 | 
				
			||||||
 | 
						name, nameOk := session.Values["name"].(string)
 | 
				
			||||||
	provider, providerOK := session.Values["provider"].(string)
 | 
						provider, providerOK := session.Values["provider"].(string)
 | 
				
			||||||
	expiry, expiryOk := session.Values["expiry"].(int64)
 | 
						expiry, expiryOk := session.Values["expiry"].(int64)
 | 
				
			||||||
	totpPending, totpPendingOk := session.Values["totpPending"].(bool)
 | 
						totpPending, totpPendingOk := session.Values["totpPending"].(bool)
 | 
				
			||||||
 | 
						oauthGroups, oauthGroupsOk := session.Values["oauthGroups"].(string)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if !usernameOk || !providerOK || !expiryOk || !totpPendingOk {
 | 
						if !usernameOk || !providerOK || !expiryOk || !totpPendingOk || !emailOk || !nameOk || !oauthGroupsOk {
 | 
				
			||||||
		log.Warn().Msg("Session cookie is missing data")
 | 
							log.Warn().Msg("Session cookie is invalid")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// If any data is missing, delete the session cookie
 | 
				
			||||||
 | 
							auth.DeleteSessionCookie(c)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Return empty cookie
 | 
				
			||||||
		return types.SessionCookie{}, nil
 | 
							return types.SessionCookie{}, nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -233,13 +246,16 @@ func (auth *Auth) GetSessionCookie(c *gin.Context) (types.SessionCookie, error)
 | 
				
			|||||||
		return types.SessionCookie{}, nil
 | 
							return types.SessionCookie{}, nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	log.Debug().Str("username", username).Str("provider", provider).Int64("expiry", expiry).Bool("totpPending", totpPending).Msg("Parsed cookie")
 | 
						log.Debug().Str("username", username).Str("provider", provider).Int64("expiry", expiry).Bool("totpPending", totpPending).Str("name", name).Str("email", email).Str("oauthGroups", oauthGroups).Msg("Parsed cookie")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Return the cookie
 | 
						// Return the cookie
 | 
				
			||||||
	return types.SessionCookie{
 | 
						return types.SessionCookie{
 | 
				
			||||||
		Username:    username,
 | 
							Username:    username,
 | 
				
			||||||
 | 
							Name:        name,
 | 
				
			||||||
 | 
							Email:       email,
 | 
				
			||||||
		Provider:    provider,
 | 
							Provider:    provider,
 | 
				
			||||||
		TotpPending: totpPending,
 | 
							TotpPending: totpPending,
 | 
				
			||||||
 | 
							OAuthGroups: oauthGroups,
 | 
				
			||||||
	}, nil
 | 
						}, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -248,48 +264,52 @@ func (auth *Auth) UserAuthConfigured() bool {
 | 
				
			|||||||
	return len(auth.Config.Users) > 0
 | 
						return len(auth.Config.Users) > 0
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (auth *Auth) ResourceAllowed(c *gin.Context, context types.UserContext) (bool, error) {
 | 
					func (auth *Auth) ResourceAllowed(c *gin.Context, context types.UserContext, labels types.TinyauthLabels) bool {
 | 
				
			||||||
	// Get headers
 | 
					 | 
				
			||||||
	host := c.Request.Header.Get("X-Forwarded-Host")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// Get app id
 | 
					 | 
				
			||||||
	appId := strings.Split(host, ".")[0]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// Get the container labels
 | 
					 | 
				
			||||||
	labels, err := auth.Docker.GetLabels(appId)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// If there is an error, return false
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		return false, err
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// Check if oauth is allowed
 | 
						// Check if oauth is allowed
 | 
				
			||||||
	if context.OAuth {
 | 
						if context.OAuth {
 | 
				
			||||||
		log.Debug().Msg("Checking OAuth whitelist")
 | 
							log.Debug().Msg("Checking OAuth whitelist")
 | 
				
			||||||
		return utils.CheckWhitelist(labels.OAuthWhitelist, context.Username), nil
 | 
							return utils.CheckWhitelist(labels.OAuthWhitelist, context.Email)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Check users
 | 
						// Check users
 | 
				
			||||||
	log.Debug().Msg("Checking users")
 | 
						log.Debug().Msg("Checking users")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return utils.CheckWhitelist(labels.Users, context.Username), nil
 | 
						return utils.CheckWhitelist(labels.Users, context.Username)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (auth *Auth) AuthEnabled(c *gin.Context) (bool, error) {
 | 
					func (auth *Auth) OAuthGroup(c *gin.Context, context types.UserContext, labels types.TinyauthLabels) bool {
 | 
				
			||||||
 | 
						// Check if groups are required
 | 
				
			||||||
 | 
						if labels.OAuthGroups == "" {
 | 
				
			||||||
 | 
							return true
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Check if we are using the generic oauth provider
 | 
				
			||||||
 | 
						if context.Provider != "generic" {
 | 
				
			||||||
 | 
							log.Debug().Msg("Not using generic provider, skipping group check")
 | 
				
			||||||
 | 
							return true
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Split the groups by comma (no need to parse since they are from the API response)
 | 
				
			||||||
 | 
						oauthGroups := strings.Split(context.OAuthGroups, ",")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// For every group check if it is in the required groups
 | 
				
			||||||
 | 
						for _, group := range oauthGroups {
 | 
				
			||||||
 | 
							if utils.CheckWhitelist(labels.OAuthGroups, group) {
 | 
				
			||||||
 | 
								log.Debug().Str("group", group).Msg("Group is in required groups")
 | 
				
			||||||
 | 
								return true
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// No groups matched
 | 
				
			||||||
 | 
						log.Debug().Msg("No groups matched")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Return false
 | 
				
			||||||
 | 
						return false
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (auth *Auth) AuthEnabled(c *gin.Context, labels types.TinyauthLabels) (bool, error) {
 | 
				
			||||||
	// Get headers
 | 
						// Get headers
 | 
				
			||||||
	uri := c.Request.Header.Get("X-Forwarded-Uri")
 | 
						uri := c.Request.Header.Get("X-Forwarded-Uri")
 | 
				
			||||||
	host := c.Request.Header.Get("X-Forwarded-Host")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// Get app id
 | 
					 | 
				
			||||||
	appId := strings.Split(host, ".")[0]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// Get the container labels
 | 
					 | 
				
			||||||
	labels, err := auth.Docker.GetLabels(appId)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// If there is an error, auth enabled
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		return true, err
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Check if the allowed label is empty
 | 
						// Check if the allowed label is empty
 | 
				
			||||||
	if labels.Allowed == "" {
 | 
						if labels.Allowed == "" {
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -6,4 +6,13 @@ var TinyauthLabels = []string{
 | 
				
			|||||||
	"tinyauth.users",
 | 
						"tinyauth.users",
 | 
				
			||||||
	"tinyauth.allowed",
 | 
						"tinyauth.allowed",
 | 
				
			||||||
	"tinyauth.headers",
 | 
						"tinyauth.headers",
 | 
				
			||||||
 | 
						"tinyauth.oauth.groups",
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Claims are the OIDC supported claims (including preferd username for some reason)
 | 
				
			||||||
 | 
					type Claims struct {
 | 
				
			||||||
 | 
						Name              string   `json:"name"`
 | 
				
			||||||
 | 
						Email             string   `json:"email"`
 | 
				
			||||||
 | 
						PreferredUsername string   `json:"preferred_username"`
 | 
				
			||||||
 | 
						Groups            []string `json:"groups"`
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -6,8 +6,7 @@ import (
 | 
				
			|||||||
	"tinyauth/internal/types"
 | 
						"tinyauth/internal/types"
 | 
				
			||||||
	"tinyauth/internal/utils"
 | 
						"tinyauth/internal/utils"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	apiTypes "github.com/docker/docker/api/types"
 | 
						container "github.com/docker/docker/api/types/container"
 | 
				
			||||||
	containerTypes "github.com/docker/docker/api/types/container"
 | 
					 | 
				
			||||||
	"github.com/docker/docker/client"
 | 
						"github.com/docker/docker/client"
 | 
				
			||||||
	"github.com/rs/zerolog/log"
 | 
						"github.com/rs/zerolog/log"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
@@ -38,9 +37,9 @@ func (docker *Docker) Init() error {
 | 
				
			|||||||
	return nil
 | 
						return nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (docker *Docker) GetContainers() ([]apiTypes.Container, error) {
 | 
					func (docker *Docker) GetContainers() ([]container.Summary, error) {
 | 
				
			||||||
	// Get the list of containers
 | 
						// Get the list of containers
 | 
				
			||||||
	containers, err := docker.Client.ContainerList(docker.Context, containerTypes.ListOptions{})
 | 
						containers, err := docker.Client.ContainerList(docker.Context, container.ListOptions{})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Check if there was an error
 | 
						// Check if there was an error
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
@@ -51,13 +50,13 @@ func (docker *Docker) GetContainers() ([]apiTypes.Container, error) {
 | 
				
			|||||||
	return containers, nil
 | 
						return containers, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (docker *Docker) InspectContainer(containerId string) (apiTypes.ContainerJSON, error) {
 | 
					func (docker *Docker) InspectContainer(containerId string) (container.InspectResponse, error) {
 | 
				
			||||||
	// Inspect the container
 | 
						// Inspect the container
 | 
				
			||||||
	inspect, err := docker.Client.ContainerInspect(docker.Context, containerId)
 | 
						inspect, err := docker.Client.ContainerInspect(docker.Context, containerId)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Check if there was an error
 | 
						// Check if there was an error
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return apiTypes.ContainerJSON{}, err
 | 
							return container.InspectResponse{}, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Return the inspect
 | 
						// Return the inspect
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -10,6 +10,7 @@ import (
 | 
				
			|||||||
	"tinyauth/internal/hooks"
 | 
						"tinyauth/internal/hooks"
 | 
				
			||||||
	"tinyauth/internal/providers"
 | 
						"tinyauth/internal/providers"
 | 
				
			||||||
	"tinyauth/internal/types"
 | 
						"tinyauth/internal/types"
 | 
				
			||||||
 | 
						"tinyauth/internal/utils"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"github.com/gin-gonic/gin"
 | 
						"github.com/gin-gonic/gin"
 | 
				
			||||||
	"github.com/google/go-querystring/query"
 | 
						"github.com/google/go-querystring/query"
 | 
				
			||||||
@@ -68,12 +69,15 @@ func (h *Handlers) AuthHandler(c *gin.Context) {
 | 
				
			|||||||
	proto := c.Request.Header.Get("X-Forwarded-Proto")
 | 
						proto := c.Request.Header.Get("X-Forwarded-Proto")
 | 
				
			||||||
	host := c.Request.Header.Get("X-Forwarded-Host")
 | 
						host := c.Request.Header.Get("X-Forwarded-Host")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Check if auth is enabled
 | 
						// Get the app id
 | 
				
			||||||
	authEnabled, err := h.Auth.AuthEnabled(c)
 | 
						appId := strings.Split(host, ".")[0]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Get the container labels
 | 
				
			||||||
 | 
						labels, err := h.Docker.GetLabels(appId)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Check if there was an error
 | 
						// Check if there was an error
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		log.Error().Err(err).Msg("Failed to check if app is allowed")
 | 
							log.Error().Err(err).Msg("Failed to get container labels")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		if proxy.Proxy == "nginx" || !isBrowser {
 | 
							if proxy.Proxy == "nginx" || !isBrowser {
 | 
				
			||||||
			c.JSON(500, gin.H{
 | 
								c.JSON(500, gin.H{
 | 
				
			||||||
@@ -87,11 +91,8 @@ func (h *Handlers) AuthHandler(c *gin.Context) {
 | 
				
			|||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Get the app id
 | 
						// Check if auth is enabled
 | 
				
			||||||
	appId := strings.Split(host, ".")[0]
 | 
						authEnabled, err := h.Auth.AuthEnabled(c, labels)
 | 
				
			||||||
 | 
					 | 
				
			||||||
	// Get the container labels
 | 
					 | 
				
			||||||
	labels, err := h.Docker.GetLabels(appId)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Check if there was an error
 | 
						// Check if there was an error
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
@@ -113,7 +114,7 @@ func (h *Handlers) AuthHandler(c *gin.Context) {
 | 
				
			|||||||
	if !authEnabled {
 | 
						if !authEnabled {
 | 
				
			||||||
		for key, value := range labels.Headers {
 | 
							for key, value := range labels.Headers {
 | 
				
			||||||
			log.Debug().Str("key", key).Str("value", value).Msg("Setting header")
 | 
								log.Debug().Str("key", key).Str("value", value).Msg("Setting header")
 | 
				
			||||||
			c.Header(key, value)
 | 
								c.Header(key, utils.SanitizeHeader(value))
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		c.JSON(200, gin.H{
 | 
							c.JSON(200, gin.H{
 | 
				
			||||||
			"status":  200,
 | 
								"status":  200,
 | 
				
			||||||
@@ -130,23 +131,7 @@ func (h *Handlers) AuthHandler(c *gin.Context) {
 | 
				
			|||||||
		log.Debug().Msg("Authenticated")
 | 
							log.Debug().Msg("Authenticated")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		// Check if user is allowed to access subdomain, if request is nginx.example.com the subdomain (resource) is nginx
 | 
							// Check if user is allowed to access subdomain, if request is nginx.example.com the subdomain (resource) is nginx
 | 
				
			||||||
		appAllowed, err := h.Auth.ResourceAllowed(c, userContext)
 | 
							appAllowed := h.Auth.ResourceAllowed(c, userContext, labels)
 | 
				
			||||||
 | 
					 | 
				
			||||||
		// Check if there was an error
 | 
					 | 
				
			||||||
		if err != nil {
 | 
					 | 
				
			||||||
			log.Error().Err(err).Msg("Failed to check if app is allowed")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
			if proxy.Proxy == "nginx" || !isBrowser {
 | 
					 | 
				
			||||||
				c.JSON(500, gin.H{
 | 
					 | 
				
			||||||
					"status":  500,
 | 
					 | 
				
			||||||
					"message": "Internal Server Error",
 | 
					 | 
				
			||||||
				})
 | 
					 | 
				
			||||||
				return
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
			c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/error", h.Config.AppURL))
 | 
					 | 
				
			||||||
			return
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
		log.Debug().Bool("appAllowed", appAllowed).Msg("Checking if app is allowed")
 | 
							log.Debug().Bool("appAllowed", appAllowed).Msg("Checking if app is allowed")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -165,11 +150,20 @@ func (h *Handlers) AuthHandler(c *gin.Context) {
 | 
				
			|||||||
				return
 | 
									return
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			// Build query
 | 
								// Values
 | 
				
			||||||
			queries, err := query.Values(types.UnauthorizedQuery{
 | 
								values := types.UnauthorizedQuery{
 | 
				
			||||||
				Username: userContext.Username,
 | 
					 | 
				
			||||||
				Resource: strings.Split(host, ".")[0],
 | 
									Resource: strings.Split(host, ".")[0],
 | 
				
			||||||
			})
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								// Use either username or email
 | 
				
			||||||
 | 
								if userContext.OAuth {
 | 
				
			||||||
 | 
									values.Username = userContext.Email
 | 
				
			||||||
 | 
								} else {
 | 
				
			||||||
 | 
									values.Username = userContext.Username
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								// Build query
 | 
				
			||||||
 | 
								queries, err := query.Values(values)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			// Handle error (no need to check for nginx/headers since we are sure we are using caddy/traefik)
 | 
								// Handle error (no need to check for nginx/headers since we are sure we are using caddy/traefik)
 | 
				
			||||||
			if err != nil {
 | 
								if err != nil {
 | 
				
			||||||
@@ -183,13 +177,65 @@ func (h *Handlers) AuthHandler(c *gin.Context) {
 | 
				
			|||||||
			return
 | 
								return
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		// Set the user header
 | 
							log.Debug().Interface("labels", labels).Msg("Got labels")
 | 
				
			||||||
		c.Header("Remote-User", userContext.Username)
 | 
					
 | 
				
			||||||
 | 
							// Check if user is in required groups
 | 
				
			||||||
 | 
							groupOk := h.Auth.OAuthGroup(c, userContext, labels)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							log.Debug().Bool("groupOk", groupOk).Msg("Checking if user is in required groups")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// The user is not allowed to access the app
 | 
				
			||||||
 | 
							if !groupOk {
 | 
				
			||||||
 | 
								log.Warn().Str("username", userContext.Username).Str("host", host).Msg("User is not in required groups")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								// Set WWW-Authenticate header
 | 
				
			||||||
 | 
								c.Header("WWW-Authenticate", "Basic realm=\"tinyauth\"")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								if proxy.Proxy == "nginx" || !isBrowser {
 | 
				
			||||||
 | 
									c.JSON(401, gin.H{
 | 
				
			||||||
 | 
										"status":  401,
 | 
				
			||||||
 | 
										"message": "Unauthorized",
 | 
				
			||||||
 | 
									})
 | 
				
			||||||
 | 
									return
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								// Values
 | 
				
			||||||
 | 
								values := types.UnauthorizedQuery{
 | 
				
			||||||
 | 
									Resource: strings.Split(host, ".")[0],
 | 
				
			||||||
 | 
									GroupErr: true,
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								// Use either username or email
 | 
				
			||||||
 | 
								if userContext.OAuth {
 | 
				
			||||||
 | 
									values.Username = userContext.Email
 | 
				
			||||||
 | 
								} else {
 | 
				
			||||||
 | 
									values.Username = userContext.Username
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								// Build query
 | 
				
			||||||
 | 
								queries, err := query.Values(values)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								// Handle error (no need to check for nginx/headers since we are sure we are using caddy/traefik)
 | 
				
			||||||
 | 
								if err != nil {
 | 
				
			||||||
 | 
									log.Error().Err(err).Msg("Failed to build queries")
 | 
				
			||||||
 | 
									c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/error", h.Config.AppURL))
 | 
				
			||||||
 | 
									return
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								// We are using caddy/traefik so redirect
 | 
				
			||||||
 | 
								c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", h.Config.AppURL, queries.Encode()))
 | 
				
			||||||
 | 
								return
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							c.Header("Remote-User", utils.SanitizeHeader(userContext.Username))
 | 
				
			||||||
 | 
							c.Header("Remote-Name", utils.SanitizeHeader(userContext.Name))
 | 
				
			||||||
 | 
							c.Header("Remote-Email", utils.SanitizeHeader(userContext.Email))
 | 
				
			||||||
 | 
							c.Header("Remote-Groups", utils.SanitizeHeader(userContext.OAuthGroups))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		// Set the rest of the headers
 | 
							// Set the rest of the headers
 | 
				
			||||||
		for key, value := range labels.Headers {
 | 
							for key, value := range labels.Headers {
 | 
				
			||||||
			log.Debug().Str("key", key).Str("value", value).Msg("Setting header")
 | 
								log.Debug().Str("key", key).Str("value", value).Msg("Setting header")
 | 
				
			||||||
			c.Header(key, value)
 | 
								c.Header(key, utils.SanitizeHeader(value))
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		// The user is allowed to access the app
 | 
							// The user is allowed to access the app
 | 
				
			||||||
@@ -310,6 +356,8 @@ func (h *Handlers) LoginHandler(c *gin.Context) {
 | 
				
			|||||||
		// Set totp pending cookie
 | 
							// Set totp pending cookie
 | 
				
			||||||
		h.Auth.CreateSessionCookie(c, &types.SessionCookie{
 | 
							h.Auth.CreateSessionCookie(c, &types.SessionCookie{
 | 
				
			||||||
			Username:    login.Username,
 | 
								Username:    login.Username,
 | 
				
			||||||
 | 
								Name:        utils.Capitalize(login.Username),
 | 
				
			||||||
 | 
								Email:       fmt.Sprintf("%s@%s", strings.ToLower(login.Username), h.Config.Domain),
 | 
				
			||||||
			Provider:    "username",
 | 
								Provider:    "username",
 | 
				
			||||||
			TotpPending: true,
 | 
								TotpPending: true,
 | 
				
			||||||
		})
 | 
							})
 | 
				
			||||||
@@ -328,6 +376,8 @@ func (h *Handlers) LoginHandler(c *gin.Context) {
 | 
				
			|||||||
	// Create session cookie with username as provider
 | 
						// Create session cookie with username as provider
 | 
				
			||||||
	h.Auth.CreateSessionCookie(c, &types.SessionCookie{
 | 
						h.Auth.CreateSessionCookie(c, &types.SessionCookie{
 | 
				
			||||||
		Username: login.Username,
 | 
							Username: login.Username,
 | 
				
			||||||
 | 
							Name:     utils.Capitalize(login.Username),
 | 
				
			||||||
 | 
							Email:    fmt.Sprintf("%s@%s", strings.ToLower(login.Username), h.Config.Domain),
 | 
				
			||||||
		Provider: "username",
 | 
							Provider: "username",
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -402,6 +452,8 @@ func (h *Handlers) TotpHandler(c *gin.Context) {
 | 
				
			|||||||
	// Create session cookie with username as provider
 | 
						// Create session cookie with username as provider
 | 
				
			||||||
	h.Auth.CreateSessionCookie(c, &types.SessionCookie{
 | 
						h.Auth.CreateSessionCookie(c, &types.SessionCookie{
 | 
				
			||||||
		Username: user.Username,
 | 
							Username: user.Username,
 | 
				
			||||||
 | 
							Name:     utils.Capitalize(user.Username),
 | 
				
			||||||
 | 
							Email:    fmt.Sprintf("%s@%s", strings.ToLower(user.Username), h.Config.Domain),
 | 
				
			||||||
		Provider: "username",
 | 
							Provider: "username",
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -465,6 +517,8 @@ func (h *Handlers) UserHandler(c *gin.Context) {
 | 
				
			|||||||
		Status:      200,
 | 
							Status:      200,
 | 
				
			||||||
		IsLoggedIn:  userContext.IsLoggedIn,
 | 
							IsLoggedIn:  userContext.IsLoggedIn,
 | 
				
			||||||
		Username:    userContext.Username,
 | 
							Username:    userContext.Username,
 | 
				
			||||||
 | 
							Name:        userContext.Name,
 | 
				
			||||||
 | 
							Email:       userContext.Email,
 | 
				
			||||||
		Provider:    userContext.Provider,
 | 
							Provider:    userContext.Provider,
 | 
				
			||||||
		Oauth:       userContext.OAuth,
 | 
							Oauth:       userContext.OAuth,
 | 
				
			||||||
		TotpPending: userContext.TotpPending,
 | 
							TotpPending: userContext.TotpPending,
 | 
				
			||||||
@@ -613,25 +667,32 @@ func (h *Handlers) OauthCallbackHandler(c *gin.Context) {
 | 
				
			|||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Get email
 | 
						// Get user
 | 
				
			||||||
	email, err := h.Providers.GetUser(providerName.Provider)
 | 
						user, err := h.Providers.GetUser(providerName.Provider)
 | 
				
			||||||
 | 
					 | 
				
			||||||
	log.Debug().Str("email", email).Msg("Got email")
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Handle error
 | 
						// Handle error
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		log.Error().Err(err).Msg("Failed to get email")
 | 
							log.Error().Msg("Failed to get user")
 | 
				
			||||||
 | 
							c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/error", h.Config.AppURL))
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						log.Debug().Msg("Got user")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Check that email is not empty
 | 
				
			||||||
 | 
						if user.Email == "" {
 | 
				
			||||||
 | 
							log.Error().Msg("Email is empty")
 | 
				
			||||||
		c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/error", h.Config.AppURL))
 | 
							c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/error", h.Config.AppURL))
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Email is not whitelisted
 | 
						// Email is not whitelisted
 | 
				
			||||||
	if !h.Auth.EmailWhitelisted(email) {
 | 
						if !h.Auth.EmailWhitelisted(user.Email) {
 | 
				
			||||||
		log.Warn().Str("email", email).Msg("Email not whitelisted")
 | 
							log.Warn().Str("email", user.Email).Msg("Email not whitelisted")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		// Build query
 | 
							// Build query
 | 
				
			||||||
		queries, err := query.Values(types.UnauthorizedQuery{
 | 
							queries, err := query.Values(types.UnauthorizedQuery{
 | 
				
			||||||
			Username: email,
 | 
								Username: user.Email,
 | 
				
			||||||
		})
 | 
							})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		// Handle error
 | 
							// Handle error
 | 
				
			||||||
@@ -647,10 +708,31 @@ func (h *Handlers) OauthCallbackHandler(c *gin.Context) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	log.Debug().Msg("Email whitelisted")
 | 
						log.Debug().Msg("Email whitelisted")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Get username
 | 
				
			||||||
 | 
						var username string
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if user.PreferredUsername != "" {
 | 
				
			||||||
 | 
							username = user.PreferredUsername
 | 
				
			||||||
 | 
						} else {
 | 
				
			||||||
 | 
							username = fmt.Sprintf("%s_%s", strings.Split(user.Email, "@")[0], strings.Split(user.Email, "@")[1])
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Get name
 | 
				
			||||||
 | 
						var name string
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if user.Name != "" {
 | 
				
			||||||
 | 
							name = user.Name
 | 
				
			||||||
 | 
						} else {
 | 
				
			||||||
 | 
							name = fmt.Sprintf("%s (%s)", utils.Capitalize(strings.Split(user.Email, "@")[0]), strings.Split(user.Email, "@")[1])
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Create session cookie (also cleans up redirect cookie)
 | 
						// Create session cookie (also cleans up redirect cookie)
 | 
				
			||||||
	h.Auth.CreateSessionCookie(c, &types.SessionCookie{
 | 
						h.Auth.CreateSessionCookie(c, &types.SessionCookie{
 | 
				
			||||||
		Username: email,
 | 
							Username:    username,
 | 
				
			||||||
 | 
							Name:        name,
 | 
				
			||||||
 | 
							Email:       user.Email,
 | 
				
			||||||
		Provider:    providerName.Provider,
 | 
							Provider:    providerName.Provider,
 | 
				
			||||||
 | 
							OAuthGroups: strings.Join(user.Groups, ","),
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Check if we have a redirect URI
 | 
						// Check if we have a redirect URI
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,22 +1,27 @@
 | 
				
			|||||||
package hooks
 | 
					package hooks
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
 | 
						"fmt"
 | 
				
			||||||
 | 
						"strings"
 | 
				
			||||||
	"tinyauth/internal/auth"
 | 
						"tinyauth/internal/auth"
 | 
				
			||||||
	"tinyauth/internal/providers"
 | 
						"tinyauth/internal/providers"
 | 
				
			||||||
	"tinyauth/internal/types"
 | 
						"tinyauth/internal/types"
 | 
				
			||||||
 | 
						"tinyauth/internal/utils"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"github.com/gin-gonic/gin"
 | 
						"github.com/gin-gonic/gin"
 | 
				
			||||||
	"github.com/rs/zerolog/log"
 | 
						"github.com/rs/zerolog/log"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func NewHooks(auth *auth.Auth, providers *providers.Providers) *Hooks {
 | 
					func NewHooks(config types.HooksConfig, auth *auth.Auth, providers *providers.Providers) *Hooks {
 | 
				
			||||||
	return &Hooks{
 | 
						return &Hooks{
 | 
				
			||||||
 | 
							Config:    config,
 | 
				
			||||||
		Auth:      auth,
 | 
							Auth:      auth,
 | 
				
			||||||
		Providers: providers,
 | 
							Providers: providers,
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type Hooks struct {
 | 
					type Hooks struct {
 | 
				
			||||||
 | 
						Config    types.HooksConfig
 | 
				
			||||||
	Auth      *auth.Auth
 | 
						Auth      *auth.Auth
 | 
				
			||||||
	Providers *providers.Providers
 | 
						Providers *providers.Providers
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
@@ -37,10 +42,10 @@ func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext {
 | 
				
			|||||||
			// Return user context since we are logged in with basic auth
 | 
								// Return user context since we are logged in with basic auth
 | 
				
			||||||
			return types.UserContext{
 | 
								return types.UserContext{
 | 
				
			||||||
				Username:   basic.Username,
 | 
									Username:   basic.Username,
 | 
				
			||||||
 | 
									Name:       utils.Capitalize(basic.Username),
 | 
				
			||||||
 | 
									Email:      fmt.Sprintf("%s@%s", strings.ToLower(basic.Username), hooks.Config.Domain),
 | 
				
			||||||
				IsLoggedIn: true,
 | 
									IsLoggedIn: true,
 | 
				
			||||||
				OAuth:       false,
 | 
					 | 
				
			||||||
				Provider:   "basic",
 | 
									Provider:   "basic",
 | 
				
			||||||
				TotpPending: false,
 | 
					 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -50,13 +55,7 @@ func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext {
 | 
				
			|||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		log.Error().Err(err).Msg("Failed to get session cookie")
 | 
							log.Error().Err(err).Msg("Failed to get session cookie")
 | 
				
			||||||
		// Return empty context
 | 
							// Return empty context
 | 
				
			||||||
		return types.UserContext{
 | 
							return types.UserContext{}
 | 
				
			||||||
			Username:    "",
 | 
					 | 
				
			||||||
			IsLoggedIn:  false,
 | 
					 | 
				
			||||||
			OAuth:       false,
 | 
					 | 
				
			||||||
			Provider:    "",
 | 
					 | 
				
			||||||
			TotpPending: false,
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Check if session cookie has totp pending
 | 
						// Check if session cookie has totp pending
 | 
				
			||||||
@@ -65,8 +64,8 @@ func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext {
 | 
				
			|||||||
		// Return empty context since we are pending totp
 | 
							// Return empty context since we are pending totp
 | 
				
			||||||
		return types.UserContext{
 | 
							return types.UserContext{
 | 
				
			||||||
			Username:    cookie.Username,
 | 
								Username:    cookie.Username,
 | 
				
			||||||
			IsLoggedIn:  false,
 | 
								Name:        cookie.Name,
 | 
				
			||||||
			OAuth:       false,
 | 
								Email:       cookie.Email,
 | 
				
			||||||
			Provider:    cookie.Provider,
 | 
								Provider:    cookie.Provider,
 | 
				
			||||||
			TotpPending: true,
 | 
								TotpPending: true,
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
@@ -83,10 +82,10 @@ func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext {
 | 
				
			|||||||
			// It exists so we are logged in
 | 
								// It exists so we are logged in
 | 
				
			||||||
			return types.UserContext{
 | 
								return types.UserContext{
 | 
				
			||||||
				Username:   cookie.Username,
 | 
									Username:   cookie.Username,
 | 
				
			||||||
 | 
									Name:       cookie.Name,
 | 
				
			||||||
 | 
									Email:      cookie.Email,
 | 
				
			||||||
				IsLoggedIn: true,
 | 
									IsLoggedIn: true,
 | 
				
			||||||
				OAuth:       false,
 | 
					 | 
				
			||||||
				Provider:   "username",
 | 
									Provider:   "username",
 | 
				
			||||||
				TotpPending: false,
 | 
					 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@@ -108,13 +107,7 @@ func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext {
 | 
				
			|||||||
			hooks.Auth.DeleteSessionCookie(c)
 | 
								hooks.Auth.DeleteSessionCookie(c)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			// Return empty context
 | 
								// Return empty context
 | 
				
			||||||
			return types.UserContext{
 | 
								return types.UserContext{}
 | 
				
			||||||
				Username:    "",
 | 
					 | 
				
			||||||
				IsLoggedIn:  false,
 | 
					 | 
				
			||||||
				OAuth:       false,
 | 
					 | 
				
			||||||
				Provider:    "",
 | 
					 | 
				
			||||||
				TotpPending: false,
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		log.Debug().Msg("Email is whitelisted")
 | 
							log.Debug().Msg("Email is whitelisted")
 | 
				
			||||||
@@ -122,19 +115,15 @@ func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext {
 | 
				
			|||||||
		// Return user context since we are logged in with oauth
 | 
							// Return user context since we are logged in with oauth
 | 
				
			||||||
		return types.UserContext{
 | 
							return types.UserContext{
 | 
				
			||||||
			Username:    cookie.Username,
 | 
								Username:    cookie.Username,
 | 
				
			||||||
 | 
								Name:        cookie.Name,
 | 
				
			||||||
 | 
								Email:       cookie.Email,
 | 
				
			||||||
			IsLoggedIn:  true,
 | 
								IsLoggedIn:  true,
 | 
				
			||||||
			OAuth:       true,
 | 
								OAuth:       true,
 | 
				
			||||||
			Provider:    cookie.Provider,
 | 
								Provider:    cookie.Provider,
 | 
				
			||||||
			TotpPending: false,
 | 
								OAuthGroups: cookie.OAuthGroups,
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Neither basic auth or oauth is set so we return an empty context
 | 
						// Neither basic auth or oauth is set so we return an empty context
 | 
				
			||||||
	return types.UserContext{
 | 
						return types.UserContext{}
 | 
				
			||||||
		Username:    "",
 | 
					 | 
				
			||||||
		IsLoggedIn:  false,
 | 
					 | 
				
			||||||
		OAuth:       false,
 | 
					 | 
				
			||||||
		Provider:    "",
 | 
					 | 
				
			||||||
		TotpPending: false,
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -4,24 +4,25 @@ import (
 | 
				
			|||||||
	"encoding/json"
 | 
						"encoding/json"
 | 
				
			||||||
	"io"
 | 
						"io"
 | 
				
			||||||
	"net/http"
 | 
						"net/http"
 | 
				
			||||||
 | 
						"tinyauth/internal/constants"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"github.com/rs/zerolog/log"
 | 
						"github.com/rs/zerolog/log"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// We are assuming that the generic provider will return a JSON object with an email field
 | 
					func GetGenericUser(client *http.Client, url string) (constants.Claims, error) {
 | 
				
			||||||
type GenericUserInfoResponse struct {
 | 
						// Create user struct
 | 
				
			||||||
	Email string `json:"email"`
 | 
						var user constants.Claims
 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
func GetGenericEmail(client *http.Client, url string) (string, error) {
 | 
					 | 
				
			||||||
	// Using the oauth client get the user info url
 | 
						// Using the oauth client get the user info url
 | 
				
			||||||
	res, err := client.Get(url)
 | 
						res, err := client.Get(url)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Check if there was an error
 | 
						// Check if there was an error
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return "", err
 | 
							return user, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						defer res.Body.Close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	log.Debug().Msg("Got response from generic provider")
 | 
						log.Debug().Msg("Got response from generic provider")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Read the body of the response
 | 
						// Read the body of the response
 | 
				
			||||||
@@ -29,24 +30,21 @@ func GetGenericEmail(client *http.Client, url string) (string, error) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	// Check if there was an error
 | 
						// Check if there was an error
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return "", err
 | 
							return user, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	log.Debug().Msg("Read body from generic provider")
 | 
						log.Debug().Msg("Read body from generic provider")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Parse the body into a user struct
 | 
					 | 
				
			||||||
	var user GenericUserInfoResponse
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// Unmarshal the body into the user struct
 | 
						// Unmarshal the body into the user struct
 | 
				
			||||||
	err = json.Unmarshal(body, &user)
 | 
						err = json.Unmarshal(body, &user)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Check if there was an error
 | 
						// Check if there was an error
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return "", err
 | 
							return user, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	log.Debug().Msg("Parsed user from generic provider")
 | 
						log.Debug().Msg("Parsed user from generic provider")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Return the email
 | 
						// Return the user
 | 
				
			||||||
	return user.Email, nil
 | 
						return user, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -5,51 +5,96 @@ import (
 | 
				
			|||||||
	"errors"
 | 
						"errors"
 | 
				
			||||||
	"io"
 | 
						"io"
 | 
				
			||||||
	"net/http"
 | 
						"net/http"
 | 
				
			||||||
 | 
						"tinyauth/internal/constants"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"github.com/rs/zerolog/log"
 | 
						"github.com/rs/zerolog/log"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Github has a different response than the generic provider
 | 
					// Response for the github email endpoint
 | 
				
			||||||
type GithubUserInfoResponse []struct {
 | 
					type GithubEmailResponse []struct {
 | 
				
			||||||
	Email   string `json:"email"`
 | 
						Email   string `json:"email"`
 | 
				
			||||||
	Primary bool   `json:"primary"`
 | 
						Primary bool   `json:"primary"`
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// The scopes required for the github provider
 | 
					// Response for the github user endpoint
 | 
				
			||||||
func GithubScopes() []string {
 | 
					type GithubUserInfoResponse struct {
 | 
				
			||||||
	return []string{"user:email"}
 | 
						Login string `json:"login"`
 | 
				
			||||||
 | 
						Name  string `json:"name"`
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func GetGithubEmail(client *http.Client) (string, error) {
 | 
					// The scopes required for the github provider
 | 
				
			||||||
	// Get the user emails from github using the oauth http client
 | 
					func GithubScopes() []string {
 | 
				
			||||||
	res, err := client.Get("https://api.github.com/user/emails")
 | 
						return []string{"user:email", "read:user"}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func GetGithubUser(client *http.Client) (constants.Claims, error) {
 | 
				
			||||||
 | 
						// Create user struct
 | 
				
			||||||
 | 
						var user constants.Claims
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Get the user info from github using the oauth http client
 | 
				
			||||||
 | 
						res, err := client.Get("https://api.github.com/user")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Check if there was an error
 | 
						// Check if there was an error
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return "", err
 | 
							return user, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	log.Debug().Msg("Got response from github")
 | 
						defer res.Body.Close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						log.Debug().Msg("Got user response from github")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Read the body of the response
 | 
						// Read the body of the response
 | 
				
			||||||
	body, err := io.ReadAll(res.Body)
 | 
						body, err := io.ReadAll(res.Body)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Check if there was an error
 | 
						// Check if there was an error
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return "", err
 | 
							return user, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	log.Debug().Msg("Read body from github")
 | 
						log.Debug().Msg("Read user body from github")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Parse the body into a user struct
 | 
						// Parse the body into a user struct
 | 
				
			||||||
	var emails GithubUserInfoResponse
 | 
						var userInfo GithubUserInfoResponse
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Unmarshal the body into the user struct
 | 
				
			||||||
 | 
						err = json.Unmarshal(body, &userInfo)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Check if there was an error
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return user, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Get the user emails from github using the oauth http client
 | 
				
			||||||
 | 
						res, err = client.Get("https://api.github.com/user/emails")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Check if there was an error
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return user, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						defer res.Body.Close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						log.Debug().Msg("Got email response from github")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Read the body of the response
 | 
				
			||||||
 | 
						body, err = io.ReadAll(res.Body)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Check if there was an error
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return user, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						log.Debug().Msg("Read email body from github")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Parse the body into a user struct
 | 
				
			||||||
 | 
						var emails GithubEmailResponse
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Unmarshal the body into the user struct
 | 
						// Unmarshal the body into the user struct
 | 
				
			||||||
	err = json.Unmarshal(body, &emails)
 | 
						err = json.Unmarshal(body, &emails)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Check if there was an error
 | 
						// Check if there was an error
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return "", err
 | 
							return user, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	log.Debug().Msg("Parsed emails from github")
 | 
						log.Debug().Msg("Parsed emails from github")
 | 
				
			||||||
@@ -57,10 +102,26 @@ func GetGithubEmail(client *http.Client) (string, error) {
 | 
				
			|||||||
	// Find and return the primary email
 | 
						// Find and return the primary email
 | 
				
			||||||
	for _, email := range emails {
 | 
						for _, email := range emails {
 | 
				
			||||||
		if email.Primary {
 | 
							if email.Primary {
 | 
				
			||||||
			return email.Email, nil
 | 
								// Set the email then exit
 | 
				
			||||||
 | 
								user.Email = email.Email
 | 
				
			||||||
 | 
								break
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// User does not have a primary email?
 | 
						// If no primary email was found, use the first available email
 | 
				
			||||||
	return "", errors.New("no primary email found")
 | 
						if len(emails) == 0 {
 | 
				
			||||||
 | 
							return user, errors.New("no emails found")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Set the email if it is not set picking the first one
 | 
				
			||||||
 | 
						if user.Email == "" {
 | 
				
			||||||
 | 
							user.Email = emails[0].Email
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Set the username and name
 | 
				
			||||||
 | 
						user.PreferredUsername = userInfo.Login
 | 
				
			||||||
 | 
						user.Name = userInfo.Name
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Return
 | 
				
			||||||
 | 
						return user, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -4,29 +4,37 @@ import (
 | 
				
			|||||||
	"encoding/json"
 | 
						"encoding/json"
 | 
				
			||||||
	"io"
 | 
						"io"
 | 
				
			||||||
	"net/http"
 | 
						"net/http"
 | 
				
			||||||
 | 
						"strings"
 | 
				
			||||||
 | 
						"tinyauth/internal/constants"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"github.com/rs/zerolog/log"
 | 
						"github.com/rs/zerolog/log"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Google works the same as the generic provider
 | 
					// Response for the google user endpoint
 | 
				
			||||||
type GoogleUserInfoResponse struct {
 | 
					type GoogleUserInfoResponse struct {
 | 
				
			||||||
	Email string `json:"email"`
 | 
						Email string `json:"email"`
 | 
				
			||||||
 | 
						Name  string `json:"name"`
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// The scopes required for the google provider
 | 
					// The scopes required for the google provider
 | 
				
			||||||
func GoogleScopes() []string {
 | 
					func GoogleScopes() []string {
 | 
				
			||||||
	return []string{"https://www.googleapis.com/auth/userinfo.email"}
 | 
						return []string{"https://www.googleapis.com/auth/userinfo.email", "https://www.googleapis.com/auth/userinfo.profile"}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func GetGoogleEmail(client *http.Client) (string, error) {
 | 
					func GetGoogleUser(client *http.Client) (constants.Claims, error) {
 | 
				
			||||||
 | 
						// Create user struct
 | 
				
			||||||
 | 
						var user constants.Claims
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Get the user info from google using the oauth http client
 | 
						// Get the user info from google using the oauth http client
 | 
				
			||||||
	res, err := client.Get("https://www.googleapis.com/userinfo/v2/me")
 | 
						res, err := client.Get("https://www.googleapis.com/userinfo/v2/me")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Check if there was an error
 | 
						// Check if there was an error
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return "", err
 | 
							return user, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						defer res.Body.Close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	log.Debug().Msg("Got response from google")
 | 
						log.Debug().Msg("Got response from google")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Read the body of the response
 | 
						// Read the body of the response
 | 
				
			||||||
@@ -34,24 +42,29 @@ func GetGoogleEmail(client *http.Client) (string, error) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	// Check if there was an error
 | 
						// Check if there was an error
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return "", err
 | 
							return user, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	log.Debug().Msg("Read body from google")
 | 
						log.Debug().Msg("Read body from google")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Parse the body into a user struct
 | 
						// Create a new user info struct
 | 
				
			||||||
	var user GoogleUserInfoResponse
 | 
						var userInfo GoogleUserInfoResponse
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Unmarshal the body into the user struct
 | 
						// Unmarshal the body into the user struct
 | 
				
			||||||
	err = json.Unmarshal(body, &user)
 | 
						err = json.Unmarshal(body, &userInfo)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Check if there was an error
 | 
						// Check if there was an error
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return "", err
 | 
							return user, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	log.Debug().Msg("Parsed user from google")
 | 
						log.Debug().Msg("Parsed user from google")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Return the email
 | 
						// Map the user info to the user struct
 | 
				
			||||||
	return user.Email, nil
 | 
						user.PreferredUsername = strings.Split(userInfo.Email, "@")[0]
 | 
				
			||||||
 | 
						user.Name = userInfo.Name
 | 
				
			||||||
 | 
						user.Email = userInfo.Email
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Return the user
 | 
				
			||||||
 | 
						return user, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -2,6 +2,7 @@ package providers
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
 | 
						"tinyauth/internal/constants"
 | 
				
			||||||
	"tinyauth/internal/oauth"
 | 
						"tinyauth/internal/oauth"
 | 
				
			||||||
	"tinyauth/internal/types"
 | 
						"tinyauth/internal/types"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -93,14 +94,17 @@ func (providers *Providers) GetProvider(provider string) *oauth.OAuth {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (providers *Providers) GetUser(provider string) (string, error) {
 | 
					func (providers *Providers) GetUser(provider string) (constants.Claims, error) {
 | 
				
			||||||
	// Get the email from the provider
 | 
						// Create user struct
 | 
				
			||||||
 | 
						var user constants.Claims
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Get the user from the provider
 | 
				
			||||||
	switch provider {
 | 
						switch provider {
 | 
				
			||||||
	case "github":
 | 
						case "github":
 | 
				
			||||||
		// If the github provider is not configured, return an error
 | 
							// If the github provider is not configured, return an error
 | 
				
			||||||
		if providers.Github == nil {
 | 
							if providers.Github == nil {
 | 
				
			||||||
			log.Debug().Msg("Github provider not configured")
 | 
								log.Debug().Msg("Github provider not configured")
 | 
				
			||||||
			return "", nil
 | 
								return user, nil
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		// Get the client from the github provider
 | 
							// Get the client from the github provider
 | 
				
			||||||
@@ -108,23 +112,23 @@ func (providers *Providers) GetUser(provider string) (string, error) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
		log.Debug().Msg("Got client from github")
 | 
							log.Debug().Msg("Got client from github")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		// Get the email from the github provider
 | 
							// Get the user from the github provider
 | 
				
			||||||
		email, err := GetGithubEmail(client)
 | 
							user, err := GetGithubUser(client)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		// Check if there was an error
 | 
							// Check if there was an error
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			return "", err
 | 
								return user, err
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		log.Debug().Msg("Got email from github")
 | 
							log.Debug().Msg("Got user from github")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		// Return the email
 | 
							// Return the user
 | 
				
			||||||
		return email, nil
 | 
							return user, nil
 | 
				
			||||||
	case "google":
 | 
						case "google":
 | 
				
			||||||
		// If the google provider is not configured, return an error
 | 
							// If the google provider is not configured, return an error
 | 
				
			||||||
		if providers.Google == nil {
 | 
							if providers.Google == nil {
 | 
				
			||||||
			log.Debug().Msg("Google provider not configured")
 | 
								log.Debug().Msg("Google provider not configured")
 | 
				
			||||||
			return "", nil
 | 
								return user, nil
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		// Get the client from the google provider
 | 
							// Get the client from the google provider
 | 
				
			||||||
@@ -132,23 +136,23 @@ func (providers *Providers) GetUser(provider string) (string, error) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
		log.Debug().Msg("Got client from google")
 | 
							log.Debug().Msg("Got client from google")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		// Get the email from the google provider
 | 
							// Get the user from the google provider
 | 
				
			||||||
		email, err := GetGoogleEmail(client)
 | 
							user, err := GetGoogleUser(client)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		// Check if there was an error
 | 
							// Check if there was an error
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			return "", err
 | 
								return user, err
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		log.Debug().Msg("Got email from google")
 | 
							log.Debug().Msg("Got user from google")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		// Return the email
 | 
							// Return the user
 | 
				
			||||||
		return email, nil
 | 
							return user, nil
 | 
				
			||||||
	case "generic":
 | 
						case "generic":
 | 
				
			||||||
		// If the generic provider is not configured, return an error
 | 
							// If the generic provider is not configured, return an error
 | 
				
			||||||
		if providers.Generic == nil {
 | 
							if providers.Generic == nil {
 | 
				
			||||||
			log.Debug().Msg("Generic provider not configured")
 | 
								log.Debug().Msg("Generic provider not configured")
 | 
				
			||||||
			return "", nil
 | 
								return user, nil
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		// Get the client from the generic provider
 | 
							// Get the client from the generic provider
 | 
				
			||||||
@@ -156,20 +160,20 @@ func (providers *Providers) GetUser(provider string) (string, error) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
		log.Debug().Msg("Got client from generic")
 | 
							log.Debug().Msg("Got client from generic")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		// Get the email from the generic provider
 | 
							// Get the user from the generic provider
 | 
				
			||||||
		email, err := GetGenericEmail(client, providers.Config.GenericUserURL)
 | 
							user, err := GetGenericUser(client, providers.Config.GenericUserURL)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		// Check if there was an error
 | 
							// Check if there was an error
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			return "", err
 | 
								return user, err
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		log.Debug().Msg("Got email from generic")
 | 
							log.Debug().Msg("Got user from generic")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		// Return the email
 | 
							// Return the email
 | 
				
			||||||
		return email, nil
 | 
							return user, nil
 | 
				
			||||||
	default:
 | 
						default:
 | 
				
			||||||
		return "", nil
 | 
							return user, nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -20,6 +20,7 @@ type OAuthRequest struct {
 | 
				
			|||||||
type UnauthorizedQuery struct {
 | 
					type UnauthorizedQuery struct {
 | 
				
			||||||
	Username string `url:"username"`
 | 
						Username string `url:"username"`
 | 
				
			||||||
	Resource string `url:"resource"`
 | 
						Resource string `url:"resource"`
 | 
				
			||||||
 | 
						GroupErr bool   `url:"groupErr"`
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Proxy is the uri parameters for the proxy endpoint
 | 
					// Proxy is the uri parameters for the proxy endpoint
 | 
				
			||||||
@@ -33,6 +34,8 @@ type UserContextResponse struct {
 | 
				
			|||||||
	Message     string `json:"message"`
 | 
						Message     string `json:"message"`
 | 
				
			||||||
	IsLoggedIn  bool   `json:"isLoggedIn"`
 | 
						IsLoggedIn  bool   `json:"isLoggedIn"`
 | 
				
			||||||
	Username    string `json:"username"`
 | 
						Username    string `json:"username"`
 | 
				
			||||||
 | 
						Name        string `json:"name"`
 | 
				
			||||||
 | 
						Email       string `json:"email"`
 | 
				
			||||||
	Provider    string `json:"provider"`
 | 
						Provider    string `json:"provider"`
 | 
				
			||||||
	Oauth       bool   `json:"oauth"`
 | 
						Oauth       bool   `json:"oauth"`
 | 
				
			||||||
	TotpPending bool   `json:"totpPending"`
 | 
						TotpPending bool   `json:"totpPending"`
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -78,3 +78,8 @@ type AuthConfig struct {
 | 
				
			|||||||
	LoginTimeout    int
 | 
						LoginTimeout    int
 | 
				
			||||||
	LoginMaxRetries int
 | 
						LoginMaxRetries int
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// HooksConfig is the configuration for the hooks service
 | 
				
			||||||
 | 
					type HooksConfig struct {
 | 
				
			||||||
 | 
						Domain string
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -25,8 +25,11 @@ type OAuthProviders struct {
 | 
				
			|||||||
// SessionCookie is the cookie for the session (exculding the expiry)
 | 
					// SessionCookie is the cookie for the session (exculding the expiry)
 | 
				
			||||||
type SessionCookie struct {
 | 
					type SessionCookie struct {
 | 
				
			||||||
	Username    string
 | 
						Username    string
 | 
				
			||||||
 | 
						Name        string
 | 
				
			||||||
 | 
						Email       string
 | 
				
			||||||
	Provider    string
 | 
						Provider    string
 | 
				
			||||||
	TotpPending bool
 | 
						TotpPending bool
 | 
				
			||||||
 | 
						OAuthGroups string
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// TinyauthLabels is the labels for the tinyauth container
 | 
					// TinyauthLabels is the labels for the tinyauth container
 | 
				
			||||||
@@ -35,15 +38,19 @@ type TinyauthLabels struct {
 | 
				
			|||||||
	Users          string
 | 
						Users          string
 | 
				
			||||||
	Allowed        string
 | 
						Allowed        string
 | 
				
			||||||
	Headers        map[string]string
 | 
						Headers        map[string]string
 | 
				
			||||||
 | 
						OAuthGroups    string
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// UserContext is the context for the user
 | 
					// UserContext is the context for the user
 | 
				
			||||||
type UserContext struct {
 | 
					type UserContext struct {
 | 
				
			||||||
	Username    string
 | 
						Username    string
 | 
				
			||||||
 | 
						Name        string
 | 
				
			||||||
 | 
						Email       string
 | 
				
			||||||
	IsLoggedIn  bool
 | 
						IsLoggedIn  bool
 | 
				
			||||||
	OAuth       bool
 | 
						OAuth       bool
 | 
				
			||||||
	Provider    string
 | 
						Provider    string
 | 
				
			||||||
	TotpPending bool
 | 
						TotpPending bool
 | 
				
			||||||
 | 
						OAuthGroups string
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// LoginAttempt tracks information about login attempts for rate limiting
 | 
					// LoginAttempt tracks information about login attempts for rate limiting
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -204,6 +204,8 @@ func GetTinyauthLabels(labels map[string]string) types.TinyauthLabels {
 | 
				
			|||||||
					}
 | 
										}
 | 
				
			||||||
					tinyauthLabels.Headers[headerSplit[0]] = headerSplit[1]
 | 
										tinyauthLabels.Headers[headerSplit[0]] = headerSplit[1]
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
 | 
								case "tinyauth.oauth.groups":
 | 
				
			||||||
 | 
									tinyauthLabels.OAuthGroups = value
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@@ -323,3 +325,22 @@ func CheckWhitelist(whitelist string, str string) bool {
 | 
				
			|||||||
	// Return false if no match was found
 | 
						// Return false if no match was found
 | 
				
			||||||
	return false
 | 
						return false
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Capitalize just the first letter of a string
 | 
				
			||||||
 | 
					func Capitalize(str string) string {
 | 
				
			||||||
 | 
						if len(str) == 0 {
 | 
				
			||||||
 | 
							return ""
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return strings.ToUpper(string([]rune(str)[0])) + string([]rune(str)[1:])
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Sanitize header removes all control characters from a string
 | 
				
			||||||
 | 
					func SanitizeHeader(header string) string {
 | 
				
			||||||
 | 
						return strings.Map(func(r rune) rune {
 | 
				
			||||||
 | 
							// Allow only printable ASCII characters (32-126) and safe whitespace (space, tab)
 | 
				
			||||||
 | 
							if r == ' ' || r == '\t' || (r >= 32 && r <= 126) {
 | 
				
			||||||
 | 
								return r
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							return -1
 | 
				
			||||||
 | 
						}, header)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -467,3 +467,65 @@ func TestCheckWhitelist(t *testing.T) {
 | 
				
			|||||||
		t.Fatalf("Expected %v, got %v", expected, result)
 | 
							t.Fatalf("Expected %v, got %v", expected, result)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Test capitalize
 | 
				
			||||||
 | 
					func TestCapitalize(t *testing.T) {
 | 
				
			||||||
 | 
						t.Log("Testing capitalize with a valid string")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Create variables
 | 
				
			||||||
 | 
						str := "test"
 | 
				
			||||||
 | 
						expected := "Test"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Test the capitalize function
 | 
				
			||||||
 | 
						result := utils.Capitalize(str)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Check if the result is equal to the expected
 | 
				
			||||||
 | 
						if result != expected {
 | 
				
			||||||
 | 
							t.Fatalf("Expected %v, got %v", expected, result)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						t.Log("Testing capitalize with an empty string")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Create variables
 | 
				
			||||||
 | 
						str = ""
 | 
				
			||||||
 | 
						expected = ""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Test the capitalize function
 | 
				
			||||||
 | 
						result = utils.Capitalize(str)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Check if the result is equal to the expected
 | 
				
			||||||
 | 
						if result != expected {
 | 
				
			||||||
 | 
							t.Fatalf("Expected %v, got %v", expected, result)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Test the header sanitizer
 | 
				
			||||||
 | 
					func TestSanitizeHeader(t *testing.T) {
 | 
				
			||||||
 | 
						t.Log("Testing sanitize header with a valid string")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Create variables
 | 
				
			||||||
 | 
						str := "X-Header=value"
 | 
				
			||||||
 | 
						expected := "X-Header=value"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Test the sanitize header function
 | 
				
			||||||
 | 
						result := utils.SanitizeHeader(str)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Check if the result is equal to the expected
 | 
				
			||||||
 | 
						if result != expected {
 | 
				
			||||||
 | 
							t.Fatalf("Expected %v, got %v", expected, result)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						t.Log("Testing sanitize header with an invalid string")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Create variables
 | 
				
			||||||
 | 
						str = "X-Header=val\nue"
 | 
				
			||||||
 | 
						expected = "X-Header=value"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Test the sanitize header function
 | 
				
			||||||
 | 
						result = utils.SanitizeHeader(str)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Check if the result is equal to the expected
 | 
				
			||||||
 | 
						if result != expected {
 | 
				
			||||||
 | 
							t.Fatalf("Expected %v, got %v", expected, result)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user