From 433e71bd50899bf4f17282a67e73dda28a35452f Mon Sep 17 00:00:00 2001 From: Stavros Date: Fri, 24 Jan 2025 15:29:46 +0200 Subject: [PATCH] feat: persist sessions and auto redirect to app --- internal/api/api.go | 106 +++++++++++++++++------- internal/hooks/hooks.go | 66 +++++++++++---- internal/oauth/oauth.go | 6 +- internal/providers/providers.go | 32 ++----- internal/types/types.go | 7 +- internal/utils/utils.go | 2 +- site/src/icons/github.tsx | 18 ++++ site/src/icons/google.tsx | 30 +++++++ site/src/icons/microsoft.tsx | 18 ++++ site/src/pages/login-page.tsx | 89 +++++++++++++++++++- site/src/schemas/user-context-schema.ts | 1 + 11 files changed, 287 insertions(+), 88 deletions(-) create mode 100644 site/src/icons/github.tsx create mode 100644 site/src/icons/google.tsx create mode 100644 site/src/icons/microsoft.tsx diff --git a/internal/api/api.go b/internal/api/api.go index eb8d46e..86f77c1 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -36,6 +36,7 @@ type API struct { Hooks *hooks.Hooks Auth *auth.Auth Providers *providers.Providers + Domain string } func (api *API) Init() { @@ -70,8 +71,10 @@ func (api *API) Init() { isSecure = false } + api.Domain = fmt.Sprintf(".%s", domain) + store.Options(sessions.Options{ - Domain: fmt.Sprintf(".%s", domain), + Domain: api.Domain, Path: "/", HttpOnly: true, Secure: isSecure, @@ -163,8 +166,7 @@ func (api *API) SetupRoutes() { } session := sessions.Default(c) - session.Set("tinyauth_sid", user.Email) - session.Set("tinyauth_oauth_provider", "") + session.Set("tinyauth_sid", fmt.Sprintf("email:%s", login.Email)) session.Save() c.JSON(200, gin.H{ @@ -176,9 +178,10 @@ func (api *API) SetupRoutes() { api.Router.POST("/api/logout", func(c *gin.Context) { session := sessions.Default(c) session.Delete("tinyauth_sid") - session.Delete("tinyauth_oauth_provider") session.Save() + c.SetCookie("tinyauth_redirect_uri", "", -1, "/", api.Domain, api.Config.CookieSecure, true) + c.JSON(200, gin.H{ "status": 200, "message": "Logged out", @@ -198,23 +201,25 @@ func (api *API) SetupRoutes() { if !userContext.IsLoggedIn { c.JSON(200, gin.H{ - "status": 200, - "message": "Unauthenticated", - "email": "", - "isLoggedIn": false, - "oauth": false, - "provider": "", + "status": 200, + "message": "Unauthenticated", + "email": "", + "isLoggedIn": false, + "oauth": false, + "provider": "", + "configuredProviders": api.Providers.GetConfiguredProviders(), }) return } c.JSON(200, gin.H{ - "status": 200, - "message": "Authenticated", - "email": userContext.Email, - "isLoggedIn": userContext.IsLoggedIn, - "oauth": userContext.OAuth, - "provider": userContext.Provider, + "status": 200, + "message": "Authenticated", + "email": userContext.Email, + "isLoggedIn": userContext.IsLoggedIn, + "oauth": userContext.OAuth, + "provider": userContext.Provider, + "configuredProviders": api.Providers.GetConfiguredProviders(), }) }) @@ -226,9 +231,9 @@ func (api *API) SetupRoutes() { }) api.Router.GET("/api/oauth/url/:provider", func(c *gin.Context) { - var provider types.OAuthBind + var request types.OAuthRequest - bindErr := c.BindUri(&provider) + bindErr := c.BindUri(&request) if bindErr != nil { c.JSON(400, gin.H{ @@ -238,16 +243,24 @@ func (api *API) SetupRoutes() { return } - authURL := api.Providers.GetAuthURL(provider.Provider) + provider := api.Providers.GetProvider(request.Provider) - if authURL == "" { - c.JSON(400, gin.H{ - "status": 400, - "message": "Bad Request", + if provider == nil { + c.JSON(404, gin.H{ + "status": 404, + "message": "Not Found", }) return } + authURL := provider.GetAuthURL() + + redirectURI := c.Query("redirect_uri") + + if redirectURI != "" { + c.SetCookie("tinyauth_redirect_uri", redirectURI, 3600, "/", api.Domain, api.Config.CookieSecure, true) + } + c.JSON(200, gin.H{ "status": 200, "message": "Ok", @@ -256,9 +269,9 @@ func (api *API) SetupRoutes() { }) api.Router.GET("/api/oauth/callback/:provider", func(c *gin.Context) { - var provider types.OAuthBind + var providerName types.OAuthRequest - bindErr := c.BindUri(&provider) + bindErr := c.BindUri(&providerName) if bindErr != nil { c.JSON(400, gin.H{ @@ -278,9 +291,19 @@ func (api *API) SetupRoutes() { return } - email, emailErr := api.Providers.Login(code, provider.Provider) + provider := api.Providers.GetProvider(providerName.Provider) - if emailErr != nil { + if provider == nil { + c.JSON(404, gin.H{ + "status": 404, + "message": "Not Found", + }) + return + } + + token, tokenErr := provider.ExchangeToken(code) + + if tokenErr != nil { c.JSON(500, gin.H{ "status": 500, "message": "Internal Server Error", @@ -289,14 +312,33 @@ func (api *API) SetupRoutes() { } session := sessions.Default(c) - session.Set("tinyauth_sid", email) - session.Set("tinyauth_oauth_provider", provider.Provider) + session.Set("tinyauth_sid", fmt.Sprintf("%s:%s", providerName.Provider, token)) session.Save() - c.JSON(200, gin.H{ - "status": 200, - "message": "Logged in", + redirectURI, redirectURIErr := c.Cookie("tinyauth_redirect_uri") + + if redirectURIErr != nil { + c.JSON(200, gin.H{ + "status": 200, + "message": "Logged in", + }) + } + + c.SetCookie("tinyauth_redirect_uri", "", -1, "/", api.Domain, api.Config.CookieSecure, true) + + queries, queryErr := query.Values(types.LoginQuery{ + RedirectURI: redirectURI, }) + + if queryErr != nil { + c.JSON(501, gin.H{ + "status": 501, + "message": "Internal Server Error", + }) + return + } + + c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/continue?%s", api.Config.AppURL, queries.Encode())) }) } diff --git a/internal/hooks/hooks.go b/internal/hooks/hooks.go index 0390cab..bfb1ac3 100644 --- a/internal/hooks/hooks.go +++ b/internal/hooks/hooks.go @@ -1,12 +1,14 @@ package hooks import ( + "strings" "tinyauth/internal/auth" "tinyauth/internal/providers" "tinyauth/internal/types" "github.com/gin-contrib/sessions" "github.com/gin-gonic/gin" + "golang.org/x/oauth2" ) func NewHooks(auth *auth.Auth, providers *providers.Providers) *Hooks { @@ -24,7 +26,6 @@ type Hooks struct { func (hooks *Hooks) UseUserContext(c *gin.Context) (types.UserContext, error) { session := sessions.Default(c) sessionCookie := session.Get("tinyauth_sid") - oauthProviderCookie := session.Get("tinyauth_oauth_provider") if sessionCookie == nil { return types.UserContext{ @@ -35,19 +36,33 @@ func (hooks *Hooks) UseUserContext(c *gin.Context) (types.UserContext, error) { }, nil } - email, emailOk := sessionCookie.(string) - provider, providerOk := oauthProviderCookie.(string) + data, dataOk := sessionCookie.(string) - if provider == "" || !providerOk { - if !emailOk { - return types.UserContext{ - Email: "", - IsLoggedIn: false, - OAuth: false, - Provider: "", - }, nil - } - user := hooks.Auth.GetUser(email) + if !dataOk { + return types.UserContext{ + Email: "", + IsLoggedIn: false, + OAuth: false, + Provider: "", + }, nil + } + + split := strings.Split(data, ":") + + if len(split) != 2 { + return types.UserContext{ + Email: "", + IsLoggedIn: false, + OAuth: false, + Provider: "", + }, nil + } + + sessionType := split[0] + sessionValue := split[1] + + if sessionType == "email" { + user := hooks.Auth.GetUser(sessionValue) if user == nil { return types.UserContext{ Email: "", @@ -57,16 +72,31 @@ func (hooks *Hooks) UseUserContext(c *gin.Context) (types.UserContext, error) { }, nil } return types.UserContext{ - Email: email, + Email: sessionValue, IsLoggedIn: true, OAuth: false, Provider: "", }, nil } - oauthEmail, oauthEmailErr := hooks.Providers.GetUser(provider) + provider := hooks.Providers.GetProvider(sessionType) - if oauthEmailErr != nil { + if provider == nil { + return types.UserContext{ + Email: "", + IsLoggedIn: false, + OAuth: false, + Provider: "", + }, nil + } + + provider.Token = &oauth2.Token{ + AccessToken: sessionValue, + } + + email, emailErr := hooks.Providers.GetUser(sessionType) + + if emailErr != nil { return types.UserContext{ Email: "", IsLoggedIn: false, @@ -76,9 +106,9 @@ func (hooks *Hooks) UseUserContext(c *gin.Context) (types.UserContext, error) { } return types.UserContext{ - Email: oauthEmail, + Email: email, IsLoggedIn: true, OAuth: true, - Provider: provider, + Provider: sessionType, }, nil } diff --git a/internal/oauth/oauth.go b/internal/oauth/oauth.go index 523cac2..7d13d81 100644 --- a/internal/oauth/oauth.go +++ b/internal/oauth/oauth.go @@ -30,14 +30,14 @@ func (oauth *OAuth) GetAuthURL() string { return oauth.Config.AuthCodeURL("state", oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(oauth.Verifier)) } -func (oauth *OAuth) ExchangeToken(code string) error { +func (oauth *OAuth) ExchangeToken(code string) (string, error) { token, err := oauth.Config.Exchange(oauth.Context, code, oauth2.VerifierOption(oauth.Verifier)) if err != nil { log.Error().Err(err).Msg("Failed to exchange code") - return err + return "", err } oauth.Token = token - return nil + return oauth.Token.AccessToken, nil } func (oauth *OAuth) GetClient() *http.Client { diff --git a/internal/providers/providers.go b/internal/providers/providers.go index 0170335..0e52072 100644 --- a/internal/providers/providers.go +++ b/internal/providers/providers.go @@ -35,24 +35,12 @@ func (providers *Providers) Init() { } } -func (providers *Providers) Login(code string, provider string) (string, error) { +func (providers *Providers) GetProvider(provider string) *oauth.OAuth { switch provider { case "github": - if providers.Github == nil { - return "", nil - } - exchangeErr := providers.Github.ExchangeToken(code) - if exchangeErr != nil { - return "", exchangeErr - } - client := providers.Github.GetClient() - email, emailErr := GetGithubEmail(client) - if emailErr != nil { - return "", emailErr - } - return email, nil + return providers.Github default: - return "", nil + return nil } } @@ -73,14 +61,10 @@ func (providers *Providers) GetUser(provider string) (string, error) { } } -func (providers *Providers) GetAuthURL(provider string) string { - switch provider { - case "github": - if providers.Github == nil { - return "" - } - return providers.Github.GetAuthURL() - default: - return "" +func (provider *Providers) GetConfiguredProviders() []string { + providers := []string{} + if provider.Github != nil { + providers = append(providers, "github") } + return providers } diff --git a/internal/types/types.go b/internal/types/types.go index 8c8d72d..945a19e 100644 --- a/internal/types/types.go +++ b/internal/types/types.go @@ -58,7 +58,7 @@ type OAuthConfig struct { MicrosoftClientSecret string } -type OAuthBind struct { +type OAuthRequest struct { Provider string `uri:"provider" binding:"required"` } @@ -67,8 +67,3 @@ type OAuthProviders struct { Google *oauth.OAuth Microsoft *oauth.OAuth } - -type OAuthLogin struct { - Email string - Token string -} diff --git a/internal/utils/utils.go b/internal/utils/utils.go index 608211f..239f3dc 100644 --- a/internal/utils/utils.go +++ b/internal/utils/utils.go @@ -22,7 +22,7 @@ func ParseUsers(users string) (types.Users, error) { return types.Users{}, errors.New("invalid user format") } usersParsed = append(usersParsed, types.User{ - Email: userSplit[0], + Email: userSplit[0], Password: userSplit[1], }) } diff --git a/site/src/icons/github.tsx b/site/src/icons/github.tsx new file mode 100644 index 0000000..1485f35 --- /dev/null +++ b/site/src/icons/github.tsx @@ -0,0 +1,18 @@ +import type { SVGProps } from "react"; + +export function GithubIcon(props: SVGProps) { + return ( + + + + ); +} diff --git a/site/src/icons/google.tsx b/site/src/icons/google.tsx new file mode 100644 index 0000000..1148569 --- /dev/null +++ b/site/src/icons/google.tsx @@ -0,0 +1,30 @@ +import type { SVGProps } from "react"; + +export function GoogleIcon(props: SVGProps) { + return ( + + + + + + + ); +} diff --git a/site/src/icons/microsoft.tsx b/site/src/icons/microsoft.tsx new file mode 100644 index 0000000..4e072ae --- /dev/null +++ b/site/src/icons/microsoft.tsx @@ -0,0 +1,18 @@ +import type { SVGProps } from "react"; + +export function MicrosoftIcon(props: SVGProps) { + return ( + + + + + + + ); +} diff --git a/site/src/pages/login-page.tsx b/site/src/pages/login-page.tsx index 1a5d5f6..5e9e008 100644 --- a/site/src/pages/login-page.tsx +++ b/site/src/pages/login-page.tsx @@ -1,4 +1,13 @@ -import { Button, Paper, PasswordInput, TextInput, Title } from "@mantine/core"; +import { + Button, + Paper, + PasswordInput, + TextInput, + Title, + Text, + Divider, + Grid, +} from "@mantine/core"; import { useForm, zodResolver } from "@mantine/form"; import { notifications } from "@mantine/notifications"; import { useMutation } from "@tanstack/react-query"; @@ -7,13 +16,16 @@ import { z } from "zod"; import { useUserContext } from "../context/user-context"; import { Navigate } from "react-router"; import { Layout } from "../components/layouts/layout"; +import { GoogleIcon } from "../icons/google"; +import { MicrosoftIcon } from "../icons/microsoft"; +import { GithubIcon } from "../icons/github"; export const LoginPage = () => { const queryString = window.location.search; const params = new URLSearchParams(queryString); const redirectUri = params.get("redirect_uri"); - const { isLoggedIn } = useUserContext(); + const { isLoggedIn, configuredProviders } = useUserContext(); if (isLoggedIn) { return ; @@ -58,14 +70,83 @@ export const LoginPage = () => { }, }); + const loginOAuthMutation = useMutation({ + mutationFn: (provider: string) => { + return axios.get( + `/api/oauth/url/${provider}?redirect_uri=${redirectUri}`, + ); + }, + onError: () => { + notifications.show({ + title: "Internal error", + message: "Failed to get OAuth URL", + color: "red", + }); + }, + onSuccess: (data) => { + window.location.replace(data.data.url); + }, + }); + const handleSubmit = (values: FormValues) => { loginMutation.mutate(values); }; return ( - Welcome back! - + Tinyauth + + + Welcome back, login with + + + {configuredProviders.includes("google") && ( + + + + )} + {configuredProviders.includes("microsoft") && ( + + + + )} + {configuredProviders.includes("github") && ( + + + + )} + +
;