feat: persist sessions and auto redirect to app

This commit is contained in:
Stavros
2025-01-24 15:29:46 +02:00
parent 80d25551e0
commit 433e71bd50
11 changed files with 287 additions and 88 deletions

View File

@@ -36,6 +36,7 @@ type API struct {
Hooks *hooks.Hooks Hooks *hooks.Hooks
Auth *auth.Auth Auth *auth.Auth
Providers *providers.Providers Providers *providers.Providers
Domain string
} }
func (api *API) Init() { func (api *API) Init() {
@@ -70,8 +71,10 @@ func (api *API) Init() {
isSecure = false isSecure = false
} }
api.Domain = fmt.Sprintf(".%s", domain)
store.Options(sessions.Options{ store.Options(sessions.Options{
Domain: fmt.Sprintf(".%s", domain), Domain: api.Domain,
Path: "/", Path: "/",
HttpOnly: true, HttpOnly: true,
Secure: isSecure, Secure: isSecure,
@@ -163,8 +166,7 @@ func (api *API) SetupRoutes() {
} }
session := sessions.Default(c) session := sessions.Default(c)
session.Set("tinyauth_sid", user.Email) session.Set("tinyauth_sid", fmt.Sprintf("email:%s", login.Email))
session.Set("tinyauth_oauth_provider", "")
session.Save() session.Save()
c.JSON(200, gin.H{ c.JSON(200, gin.H{
@@ -176,9 +178,10 @@ func (api *API) SetupRoutes() {
api.Router.POST("/api/logout", func(c *gin.Context) { api.Router.POST("/api/logout", func(c *gin.Context) {
session := sessions.Default(c) session := sessions.Default(c)
session.Delete("tinyauth_sid") session.Delete("tinyauth_sid")
session.Delete("tinyauth_oauth_provider")
session.Save() session.Save()
c.SetCookie("tinyauth_redirect_uri", "", -1, "/", api.Domain, api.Config.CookieSecure, true)
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"status": 200, "status": 200,
"message": "Logged out", "message": "Logged out",
@@ -198,23 +201,25 @@ func (api *API) SetupRoutes() {
if !userContext.IsLoggedIn { if !userContext.IsLoggedIn {
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"status": 200, "status": 200,
"message": "Unauthenticated", "message": "Unauthenticated",
"email": "", "email": "",
"isLoggedIn": false, "isLoggedIn": false,
"oauth": false, "oauth": false,
"provider": "", "provider": "",
"configuredProviders": api.Providers.GetConfiguredProviders(),
}) })
return return
} }
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"status": 200, "status": 200,
"message": "Authenticated", "message": "Authenticated",
"email": userContext.Email, "email": userContext.Email,
"isLoggedIn": userContext.IsLoggedIn, "isLoggedIn": userContext.IsLoggedIn,
"oauth": userContext.OAuth, "oauth": userContext.OAuth,
"provider": userContext.Provider, "provider": userContext.Provider,
"configuredProviders": api.Providers.GetConfiguredProviders(),
}) })
}) })
@@ -226,9 +231,9 @@ func (api *API) SetupRoutes() {
}) })
api.Router.GET("/api/oauth/url/:provider", func(c *gin.Context) { api.Router.GET("/api/oauth/url/:provider", func(c *gin.Context) {
var provider types.OAuthBind var request types.OAuthRequest
bindErr := c.BindUri(&provider) bindErr := c.BindUri(&request)
if bindErr != nil { if bindErr != nil {
c.JSON(400, gin.H{ c.JSON(400, gin.H{
@@ -238,16 +243,24 @@ func (api *API) SetupRoutes() {
return return
} }
authURL := api.Providers.GetAuthURL(provider.Provider) provider := api.Providers.GetProvider(request.Provider)
if authURL == "" { if provider == nil {
c.JSON(400, gin.H{ c.JSON(404, gin.H{
"status": 400, "status": 404,
"message": "Bad Request", "message": "Not Found",
}) })
return return
} }
authURL := provider.GetAuthURL()
redirectURI := c.Query("redirect_uri")
if redirectURI != "" {
c.SetCookie("tinyauth_redirect_uri", redirectURI, 3600, "/", api.Domain, api.Config.CookieSecure, true)
}
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"status": 200, "status": 200,
"message": "Ok", "message": "Ok",
@@ -256,9 +269,9 @@ func (api *API) SetupRoutes() {
}) })
api.Router.GET("/api/oauth/callback/:provider", func(c *gin.Context) { api.Router.GET("/api/oauth/callback/:provider", func(c *gin.Context) {
var provider types.OAuthBind var providerName types.OAuthRequest
bindErr := c.BindUri(&provider) bindErr := c.BindUri(&providerName)
if bindErr != nil { if bindErr != nil {
c.JSON(400, gin.H{ c.JSON(400, gin.H{
@@ -278,9 +291,19 @@ func (api *API) SetupRoutes() {
return return
} }
email, emailErr := api.Providers.Login(code, provider.Provider) provider := api.Providers.GetProvider(providerName.Provider)
if emailErr != nil { if provider == nil {
c.JSON(404, gin.H{
"status": 404,
"message": "Not Found",
})
return
}
token, tokenErr := provider.ExchangeToken(code)
if tokenErr != nil {
c.JSON(500, gin.H{ c.JSON(500, gin.H{
"status": 500, "status": 500,
"message": "Internal Server Error", "message": "Internal Server Error",
@@ -289,14 +312,33 @@ func (api *API) SetupRoutes() {
} }
session := sessions.Default(c) session := sessions.Default(c)
session.Set("tinyauth_sid", email) session.Set("tinyauth_sid", fmt.Sprintf("%s:%s", providerName.Provider, token))
session.Set("tinyauth_oauth_provider", provider.Provider)
session.Save() session.Save()
c.JSON(200, gin.H{ redirectURI, redirectURIErr := c.Cookie("tinyauth_redirect_uri")
"status": 200,
"message": "Logged in", if redirectURIErr != nil {
c.JSON(200, gin.H{
"status": 200,
"message": "Logged in",
})
}
c.SetCookie("tinyauth_redirect_uri", "", -1, "/", api.Domain, api.Config.CookieSecure, true)
queries, queryErr := query.Values(types.LoginQuery{
RedirectURI: redirectURI,
}) })
if queryErr != nil {
c.JSON(501, gin.H{
"status": 501,
"message": "Internal Server Error",
})
return
}
c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/continue?%s", api.Config.AppURL, queries.Encode()))
}) })
} }

View File

@@ -1,12 +1,14 @@
package hooks package hooks
import ( import (
"strings"
"tinyauth/internal/auth" "tinyauth/internal/auth"
"tinyauth/internal/providers" "tinyauth/internal/providers"
"tinyauth/internal/types" "tinyauth/internal/types"
"github.com/gin-contrib/sessions" "github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"golang.org/x/oauth2"
) )
func NewHooks(auth *auth.Auth, providers *providers.Providers) *Hooks { func NewHooks(auth *auth.Auth, providers *providers.Providers) *Hooks {
@@ -24,7 +26,6 @@ type Hooks struct {
func (hooks *Hooks) UseUserContext(c *gin.Context) (types.UserContext, error) { func (hooks *Hooks) UseUserContext(c *gin.Context) (types.UserContext, error) {
session := sessions.Default(c) session := sessions.Default(c)
sessionCookie := session.Get("tinyauth_sid") sessionCookie := session.Get("tinyauth_sid")
oauthProviderCookie := session.Get("tinyauth_oauth_provider")
if sessionCookie == nil { if sessionCookie == nil {
return types.UserContext{ return types.UserContext{
@@ -35,19 +36,33 @@ func (hooks *Hooks) UseUserContext(c *gin.Context) (types.UserContext, error) {
}, nil }, nil
} }
email, emailOk := sessionCookie.(string) data, dataOk := sessionCookie.(string)
provider, providerOk := oauthProviderCookie.(string)
if provider == "" || !providerOk { if !dataOk {
if !emailOk { return types.UserContext{
return types.UserContext{ Email: "",
Email: "", IsLoggedIn: false,
IsLoggedIn: false, OAuth: false,
OAuth: false, Provider: "",
Provider: "", }, nil
}, nil }
}
user := hooks.Auth.GetUser(email) split := strings.Split(data, ":")
if len(split) != 2 {
return types.UserContext{
Email: "",
IsLoggedIn: false,
OAuth: false,
Provider: "",
}, nil
}
sessionType := split[0]
sessionValue := split[1]
if sessionType == "email" {
user := hooks.Auth.GetUser(sessionValue)
if user == nil { if user == nil {
return types.UserContext{ return types.UserContext{
Email: "", Email: "",
@@ -57,16 +72,31 @@ func (hooks *Hooks) UseUserContext(c *gin.Context) (types.UserContext, error) {
}, nil }, nil
} }
return types.UserContext{ return types.UserContext{
Email: email, Email: sessionValue,
IsLoggedIn: true, IsLoggedIn: true,
OAuth: false, OAuth: false,
Provider: "", Provider: "",
}, nil }, nil
} }
oauthEmail, oauthEmailErr := hooks.Providers.GetUser(provider) provider := hooks.Providers.GetProvider(sessionType)
if oauthEmailErr != nil { if provider == nil {
return types.UserContext{
Email: "",
IsLoggedIn: false,
OAuth: false,
Provider: "",
}, nil
}
provider.Token = &oauth2.Token{
AccessToken: sessionValue,
}
email, emailErr := hooks.Providers.GetUser(sessionType)
if emailErr != nil {
return types.UserContext{ return types.UserContext{
Email: "", Email: "",
IsLoggedIn: false, IsLoggedIn: false,
@@ -76,9 +106,9 @@ func (hooks *Hooks) UseUserContext(c *gin.Context) (types.UserContext, error) {
} }
return types.UserContext{ return types.UserContext{
Email: oauthEmail, Email: email,
IsLoggedIn: true, IsLoggedIn: true,
OAuth: true, OAuth: true,
Provider: provider, Provider: sessionType,
}, nil }, nil
} }

View File

@@ -30,14 +30,14 @@ func (oauth *OAuth) GetAuthURL() string {
return oauth.Config.AuthCodeURL("state", oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(oauth.Verifier)) return oauth.Config.AuthCodeURL("state", oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(oauth.Verifier))
} }
func (oauth *OAuth) ExchangeToken(code string) error { func (oauth *OAuth) ExchangeToken(code string) (string, error) {
token, err := oauth.Config.Exchange(oauth.Context, code, oauth2.VerifierOption(oauth.Verifier)) token, err := oauth.Config.Exchange(oauth.Context, code, oauth2.VerifierOption(oauth.Verifier))
if err != nil { if err != nil {
log.Error().Err(err).Msg("Failed to exchange code") log.Error().Err(err).Msg("Failed to exchange code")
return err return "", err
} }
oauth.Token = token oauth.Token = token
return nil return oauth.Token.AccessToken, nil
} }
func (oauth *OAuth) GetClient() *http.Client { func (oauth *OAuth) GetClient() *http.Client {

View File

@@ -35,24 +35,12 @@ func (providers *Providers) Init() {
} }
} }
func (providers *Providers) Login(code string, provider string) (string, error) { func (providers *Providers) GetProvider(provider string) *oauth.OAuth {
switch provider { switch provider {
case "github": case "github":
if providers.Github == nil { return providers.Github
return "", nil
}
exchangeErr := providers.Github.ExchangeToken(code)
if exchangeErr != nil {
return "", exchangeErr
}
client := providers.Github.GetClient()
email, emailErr := GetGithubEmail(client)
if emailErr != nil {
return "", emailErr
}
return email, nil
default: default:
return "", nil return nil
} }
} }
@@ -73,14 +61,10 @@ func (providers *Providers) GetUser(provider string) (string, error) {
} }
} }
func (providers *Providers) GetAuthURL(provider string) string { func (provider *Providers) GetConfiguredProviders() []string {
switch provider { providers := []string{}
case "github": if provider.Github != nil {
if providers.Github == nil { providers = append(providers, "github")
return ""
}
return providers.Github.GetAuthURL()
default:
return ""
} }
return providers
} }

View File

@@ -58,7 +58,7 @@ type OAuthConfig struct {
MicrosoftClientSecret string MicrosoftClientSecret string
} }
type OAuthBind struct { type OAuthRequest struct {
Provider string `uri:"provider" binding:"required"` Provider string `uri:"provider" binding:"required"`
} }
@@ -67,8 +67,3 @@ type OAuthProviders struct {
Google *oauth.OAuth Google *oauth.OAuth
Microsoft *oauth.OAuth Microsoft *oauth.OAuth
} }
type OAuthLogin struct {
Email string
Token string
}

View File

@@ -22,7 +22,7 @@ func ParseUsers(users string) (types.Users, error) {
return types.Users{}, errors.New("invalid user format") return types.Users{}, errors.New("invalid user format")
} }
usersParsed = append(usersParsed, types.User{ usersParsed = append(usersParsed, types.User{
Email: userSplit[0], Email: userSplit[0],
Password: userSplit[1], Password: userSplit[1],
}) })
} }

18
site/src/icons/github.tsx Normal file
View File

@@ -0,0 +1,18 @@
import type { SVGProps } from "react";
export function GithubIcon(props: SVGProps<SVGSVGElement>) {
return (
<svg
xmlns="http://www.w3.org/2000/svg"
width={24}
height={24}
viewBox="0 0 24 24"
{...props}
>
<path
fill="currentColor"
d="M12 2A10 10 0 0 0 2 12c0 4.42 2.87 8.17 6.84 9.5c.5.08.66-.23.66-.5v-1.69c-2.77.6-3.36-1.34-3.36-1.34c-.46-1.16-1.11-1.47-1.11-1.47c-.91-.62.07-.6.07-.6c1 .07 1.53 1.03 1.53 1.03c.87 1.52 2.34 1.07 2.91.83c.09-.65.35-1.09.63-1.34c-2.22-.25-4.55-1.11-4.55-4.92c0-1.11.38-2 1.03-2.71c-.1-.25-.45-1.29.1-2.64c0 0 .84-.27 2.75 1.02c.79-.22 1.65-.33 2.5-.33s1.71.11 2.5.33c1.91-1.29 2.75-1.02 2.75-1.02c.55 1.35.2 2.39.1 2.64c.65.71 1.03 1.6 1.03 2.71c0 3.82-2.34 4.66-4.57 4.91c.36.31.69.92.69 1.85V21c0 .27.16.59.67.5C19.14 20.16 22 16.42 22 12A10 10 0 0 0 12 2"
></path>
</svg>
);
}

30
site/src/icons/google.tsx Normal file
View File

@@ -0,0 +1,30 @@
import type { SVGProps } from "react";
export function GoogleIcon(props: SVGProps<SVGSVGElement>) {
return (
<svg
xmlns="http://www.w3.org/2000/svg"
width={48}
height={48}
viewBox="0 0 48 48"
{...props}
>
<path
fill="#ffc107"
d="M43.611 20.083H42V20H24v8h11.303c-1.649 4.657-6.08 8-11.303 8c-6.627 0-12-5.373-12-12s5.373-12 12-12c3.059 0 5.842 1.154 7.961 3.039l5.657-5.657C34.046 6.053 29.268 4 24 4C12.955 4 4 12.955 4 24s8.955 20 20 20s20-8.955 20-20c0-1.341-.138-2.65-.389-3.917"
></path>
<path
fill="#ff3d00"
d="m6.306 14.691l6.571 4.819C14.655 15.108 18.961 12 24 12c3.059 0 5.842 1.154 7.961 3.039l5.657-5.657C34.046 6.053 29.268 4 24 4C16.318 4 9.656 8.337 6.306 14.691"
></path>
<path
fill="#4caf50"
d="M24 44c5.166 0 9.86-1.977 13.409-5.192l-6.19-5.238A11.9 11.9 0 0 1 24 36c-5.202 0-9.619-3.317-11.283-7.946l-6.522 5.025C9.505 39.556 16.227 44 24 44"
></path>
<path
fill="#1976d2"
d="M43.611 20.083H42V20H24v8h11.303a12.04 12.04 0 0 1-4.087 5.571l.003-.002l6.19 5.238C36.971 39.205 44 34 44 24c0-1.341-.138-2.65-.389-3.917"
></path>
</svg>
);
}

View File

@@ -0,0 +1,18 @@
import type { SVGProps } from "react";
export function MicrosoftIcon(props: SVGProps<SVGSVGElement>) {
return (
<svg
xmlns="http://www.w3.org/2000/svg"
width={256}
height={256}
viewBox="0 0 256 256"
{...props}
>
<path fill="#f1511b" d="M121.666 121.666H0V0h121.666z"></path>
<path fill="#80cc28" d="M256 121.666H134.335V0H256z"></path>
<path fill="#00adef" d="M121.663 256.002H0V134.336h121.663z"></path>
<path fill="#fbbc09" d="M256 256.002H134.335V134.336H256z"></path>
</svg>
);
}

View File

@@ -1,4 +1,13 @@
import { Button, Paper, PasswordInput, TextInput, Title } from "@mantine/core"; import {
Button,
Paper,
PasswordInput,
TextInput,
Title,
Text,
Divider,
Grid,
} from "@mantine/core";
import { useForm, zodResolver } from "@mantine/form"; import { useForm, zodResolver } from "@mantine/form";
import { notifications } from "@mantine/notifications"; import { notifications } from "@mantine/notifications";
import { useMutation } from "@tanstack/react-query"; import { useMutation } from "@tanstack/react-query";
@@ -7,13 +16,16 @@ import { z } from "zod";
import { useUserContext } from "../context/user-context"; import { useUserContext } from "../context/user-context";
import { Navigate } from "react-router"; import { Navigate } from "react-router";
import { Layout } from "../components/layouts/layout"; import { Layout } from "../components/layouts/layout";
import { GoogleIcon } from "../icons/google";
import { MicrosoftIcon } from "../icons/microsoft";
import { GithubIcon } from "../icons/github";
export const LoginPage = () => { export const LoginPage = () => {
const queryString = window.location.search; const queryString = window.location.search;
const params = new URLSearchParams(queryString); const params = new URLSearchParams(queryString);
const redirectUri = params.get("redirect_uri"); const redirectUri = params.get("redirect_uri");
const { isLoggedIn } = useUserContext(); const { isLoggedIn, configuredProviders } = useUserContext();
if (isLoggedIn) { if (isLoggedIn) {
return <Navigate to="/logout" />; return <Navigate to="/logout" />;
@@ -58,14 +70,83 @@ export const LoginPage = () => {
}, },
}); });
const loginOAuthMutation = useMutation({
mutationFn: (provider: string) => {
return axios.get(
`/api/oauth/url/${provider}?redirect_uri=${redirectUri}`,
);
},
onError: () => {
notifications.show({
title: "Internal error",
message: "Failed to get OAuth URL",
color: "red",
});
},
onSuccess: (data) => {
window.location.replace(data.data.url);
},
});
const handleSubmit = (values: FormValues) => { const handleSubmit = (values: FormValues) => {
loginMutation.mutate(values); loginMutation.mutate(values);
}; };
return ( return (
<Layout> <Layout>
<Title ta="center">Welcome back!</Title> <Title ta="center">Tinyauth</Title>
<Paper shadow="md" p={30} mt={30} radius="md" withBorder> <Paper shadow="md" p="xl" mt={30} radius="md" withBorder>
<Text size="lg" fw={500} ta="center">
Welcome back, login with
</Text>
<Grid mb="md" mt="md" align="center" justify="center">
{configuredProviders.includes("google") && (
<Grid.Col span="content">
<Button
radius="xl"
leftSection={<GoogleIcon style={{ width: 14, height: 14 }} />}
variant="default"
onClick={() => loginOAuthMutation.mutate("google")}
loading={loginOAuthMutation.isLoading}
>
Google
</Button>
</Grid.Col>
)}
{configuredProviders.includes("microsoft") && (
<Grid.Col span="content">
<Button
radius="xl"
leftSection={
<MicrosoftIcon style={{ width: 14, height: 14 }} />
}
variant="default"
onClick={() => loginOAuthMutation.mutate("microsoft")}
loading={loginOAuthMutation.isLoading}
>
Microsoft
</Button>
</Grid.Col>
)}
{configuredProviders.includes("github") && (
<Grid.Col span="content">
<Button
radius="xl"
leftSection={<GithubIcon style={{ width: 14, height: 14 }} />}
variant="default"
onClick={() => loginOAuthMutation.mutate("github")}
loading={loginOAuthMutation.isLoading}
>
Github
</Button>
</Grid.Col>
)}
</Grid>
<Divider
label="Or continue with email"
labelPosition="center"
my="lg"
/>
<form onSubmit={form.onSubmit(handleSubmit)}> <form onSubmit={form.onSubmit(handleSubmit)}>
<TextInput <TextInput
label="Email" label="Email"

View File

@@ -5,6 +5,7 @@ export const userContextSchema = z.object({
email: z.string(), email: z.string(),
oauth: z.boolean(), oauth: z.boolean(),
provider: z.string(), provider: z.string(),
configuredProviders: z.array(z.string()),
}); });
export type UserContextSchemaType = z.infer<typeof userContextSchema>; export type UserContextSchemaType = z.infer<typeof userContextSchema>;