diff --git a/cmd/totp/generate/generate.go b/cmd/totp/generate/generate.go index cc82da8..3df3ea6 100644 --- a/cmd/totp/generate/generate.go +++ b/cmd/totp/generate/generate.go @@ -63,7 +63,7 @@ var GenerateCmd = &cobra.Command{ // Check if user was using docker escape dockerEscape := false - if strings.Contains(user.Username, "$$") { + if strings.Contains(iUser, "$$") { dockerEscape = true } diff --git a/internal/api/api.go b/internal/api/api.go index f17e148..64bcd69 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -19,6 +19,7 @@ import ( "github.com/gin-contrib/sessions/cookie" "github.com/gin-gonic/gin" "github.com/google/go-querystring/query" + "github.com/pquerna/otp/totp" "github.com/rs/zerolog/log" ) @@ -321,7 +322,29 @@ func (api *API) SetupRoutes() { return } - log.Debug().Msg("Password correct, logging in") + log.Debug().Msg("Password correct, checking totp") + + // Check if user has totp enabled + if user.TotpSecret != "" { + log.Debug().Msg("Totp enabled") + + // Set totp pending cookie + api.Auth.CreateSessionCookie(c, &types.SessionCookie{ + Username: login.Username, + Provider: "username", + TotpPending: true, + }) + + // Return totp required + c.JSON(200, gin.H{ + "status": 200, + "message": "Waiting for totp", + "totpPending": true, + }) + + // Stop further processing + return + } // Create session cookie with username as provider api.Auth.CreateSessionCookie(c, &types.SessionCookie{ @@ -329,6 +352,80 @@ func (api *API) SetupRoutes() { Provider: "username", }) + // Return logged in + c.JSON(200, gin.H{ + "status": 200, + "message": "Logged in", + "totpPending": false, + }) + }) + + api.Router.POST("/api/totp", func(c *gin.Context) { + // Create totp struct + var totpReq types.Totp + + // Bind JSON + err := c.BindJSON(&totpReq) + + // Handle error + if err != nil { + log.Error().Err(err).Msg("Failed to bind JSON") + c.JSON(400, gin.H{ + "status": 400, + "message": "Bad Request", + }) + return + } + + log.Debug().Msg("Checking totp") + + // Get user context + userContext := api.Hooks.UseUserContext(c) + + // Check if we have a user + if userContext.Username == "" { + log.Debug().Msg("No user context") + c.JSON(401, gin.H{ + "status": 401, + "message": "Unauthorized", + }) + return + } + + // Get user + user := api.Auth.GetUser(userContext.Username) + + // Check if user exists + if user == nil { + log.Debug().Msg("User not found") + c.JSON(401, gin.H{ + "status": 401, + "message": "Unauthorized", + }) + return + } + + // Check if totp is correct + totpOk := totp.Validate(totpReq.Code, user.TotpSecret) + + // TOTP is incorrect + if !totpOk { + log.Debug().Msg("Totp incorrect") + c.JSON(401, gin.H{ + "status": 401, + "message": "Unauthorized", + }) + return + } + + log.Debug().Msg("Totp correct") + + // Create session cookie with username as provider + api.Auth.CreateSessionCookie(c, &types.SessionCookie{ + Username: user.Username, + Provider: "username", + }) + // Return logged in c.JSON(200, gin.H{ "status": 200, @@ -378,6 +475,7 @@ func (api *API) SetupRoutes() { DisableContinue: api.Config.DisableContinue, Title: api.Config.Title, GenericName: api.Config.GenericName, + TotpPending: userContext.TotpPending, } // If we are not logged in we set the status to 401 and add the WWW-Authenticate header else we set it to 200 @@ -392,19 +490,6 @@ func (api *API) SetupRoutes() { status.Message = "Authenticated" } - // // Marshall status to JSON - // statusJson, marshalErr := json.Marshal(status) - - // // Handle error - // if marshalErr != nil { - // log.Error().Err(marshalErr).Msg("Failed to marshal status") - // c.JSON(500, gin.H{ - // "status": 500, - // "message": "Internal Server Error", - // }) - // return - // } - // Return data c.JSON(200, status) }) diff --git a/internal/auth/auth.go b/internal/auth/auth.go index 926ee2a..c433648 100644 --- a/internal/auth/auth.go +++ b/internal/auth/auth.go @@ -74,6 +74,7 @@ func (auth *Auth) CreateSessionCookie(c *gin.Context, data *types.SessionCookie) sessions.Set("username", data.Username) sessions.Set("provider", data.Provider) sessions.Set("expiry", time.Now().Add(time.Duration(auth.SessionExpiry)*time.Second).Unix()) + sessions.Set("totpPending", data.TotpPending) // Save session sessions.Save() @@ -102,14 +103,16 @@ func (auth *Auth) GetSessionCookie(c *gin.Context) types.SessionCookie { cookieUsername := sessions.Get("username") cookieProvider := sessions.Get("provider") cookieExpiry := sessions.Get("expiry") + cookieTotpPending := sessions.Get("totpPending") // Convert interfaces to correct types username, usernameOk := cookieUsername.(string) provider, providerOk := cookieProvider.(string) expiry, expiryOk := cookieExpiry.(int64) + totpPending, totpPendingOk := cookieTotpPending.(bool) // Check if the cookie is invalid - if !usernameOk || !providerOk || !expiryOk { + if !usernameOk || !providerOk || !expiryOk || !totpPendingOk { log.Warn().Msg("Session cookie invalid") return types.SessionCookie{} } @@ -125,12 +128,13 @@ func (auth *Auth) GetSessionCookie(c *gin.Context) types.SessionCookie { return types.SessionCookie{} } - log.Debug().Str("username", username).Str("provider", provider).Int64("expiry", expiry).Msg("Parsed cookie") + log.Debug().Str("username", username).Str("provider", provider).Int64("expiry", expiry).Bool("totpPending", totpPending).Msg("Parsed cookie") // Return the cookie return types.SessionCookie{ - Username: username, - Provider: provider, + Username: username, + Provider: provider, + TotpPending: totpPending, } } diff --git a/internal/hooks/hooks.go b/internal/hooks/hooks.go index b8f4e48..6921372 100644 --- a/internal/hooks/hooks.go +++ b/internal/hooks/hooks.go @@ -36,15 +36,29 @@ func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext { if user != nil && hooks.Auth.CheckPassword(*user, basic.Password) { // Return user context since we are logged in with basic auth return types.UserContext{ - Username: basic.Username, - IsLoggedIn: true, - OAuth: false, - Provider: "basic", + Username: basic.Username, + IsLoggedIn: true, + OAuth: false, + Provider: "basic", + TotpPending: false, } } } + // Check if session cookie has totp pending + if cookie.TotpPending { + log.Debug().Msg("Totp pending") + // Return empty context since we are pending totp + return types.UserContext{ + Username: cookie.Username, + IsLoggedIn: false, + OAuth: false, + Provider: cookie.Provider, + TotpPending: true, + } + } + // Check if session cookie is username/password auth if cookie.Provider == "username" { log.Debug().Msg("Provider is username") @@ -55,10 +69,11 @@ func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext { // It exists so we are logged in return types.UserContext{ - Username: cookie.Username, - IsLoggedIn: true, - OAuth: false, - Provider: "username", + Username: cookie.Username, + IsLoggedIn: true, + OAuth: false, + Provider: "username", + TotpPending: false, } } } @@ -81,10 +96,11 @@ func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext { // Return empty context return types.UserContext{ - Username: "", - IsLoggedIn: false, - OAuth: false, - Provider: "", + Username: "", + IsLoggedIn: false, + OAuth: false, + Provider: "", + TotpPending: false, } } @@ -92,18 +108,20 @@ func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext { // Return user context since we are logged in with oauth return types.UserContext{ - Username: cookie.Username, - IsLoggedIn: true, - OAuth: true, - Provider: cookie.Provider, + Username: cookie.Username, + IsLoggedIn: true, + OAuth: true, + Provider: cookie.Provider, + TotpPending: false, } } // Neither basic auth or oauth is set so we return an empty context return types.UserContext{ - Username: "", - IsLoggedIn: false, - OAuth: false, - Provider: "", + Username: "", + IsLoggedIn: false, + OAuth: false, + Provider: "", + TotpPending: false, } } diff --git a/internal/types/types.go b/internal/types/types.go index 1e287f6..c6a067c 100644 --- a/internal/types/types.go +++ b/internal/types/types.go @@ -59,10 +59,11 @@ type Config struct { // UserContext is the context for the user type UserContext struct { - Username string - IsLoggedIn bool - OAuth bool - Provider string + Username string + IsLoggedIn bool + OAuth bool + Provider string + TotpPending bool } // APIConfig is the configuration for the API @@ -115,8 +116,9 @@ type UnauthorizedQuery struct { // SessionCookie is the cookie for the session (exculding the expiry) type SessionCookie struct { - Username string - Provider string + Username string + Provider string + TotpPending bool } // TinyauthLabels is the labels for the tinyauth container @@ -148,4 +150,10 @@ type Status struct { DisableContinue bool `json:"disableContinue"` Title string `json:"title"` GenericName string `json:"genericName"` + TotpPending bool `json:"totpPending"` +} + +// Totp request +type Totp struct { + Code string `json:"code"` } diff --git a/site/src/components/auth/totp-form.tsx b/site/src/components/auth/totp-form.tsx new file mode 100644 index 0000000..d860c6e --- /dev/null +++ b/site/src/components/auth/totp-form.tsx @@ -0,0 +1,40 @@ +import { Button, PinInput } from "@mantine/core"; +import { useForm, zodResolver } from "@mantine/form"; +import { z } from "zod"; + +const schema = z.object({ + code: z.string(), +}); + +type FormValues = z.infer; + +interface TotpFormProps { + onSubmit: (values: FormValues) => void; + isLoading: boolean; +} + +export const TotpForm = (props: TotpFormProps) => { + const { onSubmit, isLoading } = props; + + const form = useForm({ + mode: "uncontrolled", + initialValues: { + code: "", + }, + validate: zodResolver(schema), + }); + + return ( +
+ + + + ); +}; diff --git a/site/src/context/user-context.tsx b/site/src/context/user-context.tsx index eabeca6..cc59627 100644 --- a/site/src/context/user-context.tsx +++ b/site/src/context/user-context.tsx @@ -1,7 +1,7 @@ import { useQuery } from "@tanstack/react-query"; import React, { createContext, useContext } from "react"; -import { UserContextSchemaType } from "../schemas/user-context-schema"; import axios from "axios"; +import { UserContextSchemaType } from "../schemas/user-context-schema"; const UserContext = createContext(null); @@ -15,7 +15,7 @@ export const UserContextProvider = ({ isLoading, error, } = useQuery({ - queryKey: ["isLoggedIn"], + queryKey: ["userContext"], queryFn: async () => { const res = await axios.get("/api/status"); return res.data; diff --git a/site/src/main.tsx b/site/src/main.tsx index a30a5fd..dd88fd8 100644 --- a/site/src/main.tsx +++ b/site/src/main.tsx @@ -15,6 +15,7 @@ import { ContinuePage } from "./pages/continue-page.tsx"; import { NotFoundPage } from "./pages/not-found-page.tsx"; import { UnauthorizedPage } from "./pages/unauthorized-page.tsx"; import { InternalServerError } from "./pages/internal-server-error.tsx"; +import { TotpPage } from "./pages/totp-page.tsx"; const queryClient = new QueryClient({ defaultOptions: { @@ -34,6 +35,7 @@ createRoot(document.getElementById("root")!).render( } /> } /> + } /> } /> } /> } /> diff --git a/site/src/pages/login-page.tsx b/site/src/pages/login-page.tsx index 948c735..3b026e5 100644 --- a/site/src/pages/login-page.tsx +++ b/site/src/pages/login-page.tsx @@ -5,10 +5,10 @@ import axios from "axios"; import { useUserContext } from "../context/user-context"; import { Navigate } from "react-router"; import { Layout } from "../components/layouts/layout"; -import { isQueryValid } from "../utils/utils"; import { OAuthButtons } from "../components/auth/oauth-buttons"; import { LoginFormValues } from "../schemas/login-schema"; import { LoginForm } from "../components/auth/login-forn"; +import { isQueryValid } from "../utils/utils"; export const LoginPage = () => { const queryString = window.location.search; @@ -37,18 +37,25 @@ export const LoginPage = () => { color: "red", }); }, - onSuccess: () => { + onSuccess: async (data) => { + if (data.data.totpPending) { + window.location.replace(`/totp?redirect_uri=${redirectUri}`); + return; + } + notifications.show({ title: "Logged in", message: "Welcome back!", color: "green", }); + setTimeout(() => { if (!isQueryValid(redirectUri)) { window.location.replace("/"); - } else { - window.location.replace(`/continue?redirect_uri=${redirectUri}`); + return; } + + window.location.replace(`/continue?redirect_uri=${redirectUri}`); }, 500); }, }); diff --git a/site/src/pages/totp-page.tsx b/site/src/pages/totp-page.tsx new file mode 100644 index 0000000..33cc5ae --- /dev/null +++ b/site/src/pages/totp-page.tsx @@ -0,0 +1,62 @@ +import { Navigate } from "react-router"; +import { useUserContext } from "../context/user-context"; +import { Title, Paper, Text } from "@mantine/core"; +import { Layout } from "../components/layouts/layout"; +import { TotpForm } from "../components/auth/totp-form"; +import { useMutation } from "@tanstack/react-query"; +import axios from "axios"; +import { notifications } from "@mantine/notifications"; + +export const TotpPage = () => { + const queryString = window.location.search; + const params = new URLSearchParams(queryString); + const redirectUri = params.get("redirect_uri") ?? ""; + + const { totpPending, isLoggedIn, title } = useUserContext(); + + if (isLoggedIn) { + return ; + } + + if (!totpPending) { + return ; + } + + const totpMutation = useMutation({ + mutationFn: async (totp: { code: string }) => { + await axios.post("/api/totp", totp); + }, + onError: () => { + notifications.show({ + title: "Failed to verify code", + message: "Please try again", + color: "red", + }); + }, + onSuccess: () => { + notifications.show({ + title: "Verified", + message: "Redirecting to your app", + color: "green", + }); + setTimeout(() => { + window.location.replace(`/continue?redirect_uri=${redirectUri}`); + }, 500); + }, + }); + + return ( + + {title} + + + Enter your TOTP code + + totpMutation.mutate(values)} + /> + + + ); +}; diff --git a/site/src/schemas/user-context-schema.ts b/site/src/schemas/user-context-schema.ts index 58db405..83eb2e2 100644 --- a/site/src/schemas/user-context-schema.ts +++ b/site/src/schemas/user-context-schema.ts @@ -9,6 +9,7 @@ export const userContextSchema = z.object({ disableContinue: z.boolean(), title: z.string(), genericName: z.string(), + totpPending: z.boolean(), }); export type UserContextSchemaType = z.infer;