feat: persist sessions and auto redirect to app

This commit is contained in:
Stavros
2025-01-24 15:29:46 +02:00
parent 80d25551e0
commit 433e71bd50
11 changed files with 287 additions and 88 deletions

View File

@@ -36,6 +36,7 @@ type API struct {
Hooks *hooks.Hooks
Auth *auth.Auth
Providers *providers.Providers
Domain string
}
func (api *API) Init() {
@@ -70,8 +71,10 @@ func (api *API) Init() {
isSecure = false
}
api.Domain = fmt.Sprintf(".%s", domain)
store.Options(sessions.Options{
Domain: fmt.Sprintf(".%s", domain),
Domain: api.Domain,
Path: "/",
HttpOnly: true,
Secure: isSecure,
@@ -163,8 +166,7 @@ func (api *API) SetupRoutes() {
}
session := sessions.Default(c)
session.Set("tinyauth_sid", user.Email)
session.Set("tinyauth_oauth_provider", "")
session.Set("tinyauth_sid", fmt.Sprintf("email:%s", login.Email))
session.Save()
c.JSON(200, gin.H{
@@ -176,9 +178,10 @@ func (api *API) SetupRoutes() {
api.Router.POST("/api/logout", func(c *gin.Context) {
session := sessions.Default(c)
session.Delete("tinyauth_sid")
session.Delete("tinyauth_oauth_provider")
session.Save()
c.SetCookie("tinyauth_redirect_uri", "", -1, "/", api.Domain, api.Config.CookieSecure, true)
c.JSON(200, gin.H{
"status": 200,
"message": "Logged out",
@@ -204,6 +207,7 @@ func (api *API) SetupRoutes() {
"isLoggedIn": false,
"oauth": false,
"provider": "",
"configuredProviders": api.Providers.GetConfiguredProviders(),
})
return
}
@@ -215,6 +219,7 @@ func (api *API) SetupRoutes() {
"isLoggedIn": userContext.IsLoggedIn,
"oauth": userContext.OAuth,
"provider": userContext.Provider,
"configuredProviders": api.Providers.GetConfiguredProviders(),
})
})
@@ -226,9 +231,9 @@ func (api *API) SetupRoutes() {
})
api.Router.GET("/api/oauth/url/:provider", func(c *gin.Context) {
var provider types.OAuthBind
var request types.OAuthRequest
bindErr := c.BindUri(&provider)
bindErr := c.BindUri(&request)
if bindErr != nil {
c.JSON(400, gin.H{
@@ -238,16 +243,24 @@ func (api *API) SetupRoutes() {
return
}
authURL := api.Providers.GetAuthURL(provider.Provider)
provider := api.Providers.GetProvider(request.Provider)
if authURL == "" {
c.JSON(400, gin.H{
"status": 400,
"message": "Bad Request",
if provider == nil {
c.JSON(404, gin.H{
"status": 404,
"message": "Not Found",
})
return
}
authURL := provider.GetAuthURL()
redirectURI := c.Query("redirect_uri")
if redirectURI != "" {
c.SetCookie("tinyauth_redirect_uri", redirectURI, 3600, "/", api.Domain, api.Config.CookieSecure, true)
}
c.JSON(200, gin.H{
"status": 200,
"message": "Ok",
@@ -256,9 +269,9 @@ func (api *API) SetupRoutes() {
})
api.Router.GET("/api/oauth/callback/:provider", func(c *gin.Context) {
var provider types.OAuthBind
var providerName types.OAuthRequest
bindErr := c.BindUri(&provider)
bindErr := c.BindUri(&providerName)
if bindErr != nil {
c.JSON(400, gin.H{
@@ -278,9 +291,19 @@ func (api *API) SetupRoutes() {
return
}
email, emailErr := api.Providers.Login(code, provider.Provider)
provider := api.Providers.GetProvider(providerName.Provider)
if emailErr != nil {
if provider == nil {
c.JSON(404, gin.H{
"status": 404,
"message": "Not Found",
})
return
}
token, tokenErr := provider.ExchangeToken(code)
if tokenErr != nil {
c.JSON(500, gin.H{
"status": 500,
"message": "Internal Server Error",
@@ -289,14 +312,33 @@ func (api *API) SetupRoutes() {
}
session := sessions.Default(c)
session.Set("tinyauth_sid", email)
session.Set("tinyauth_oauth_provider", provider.Provider)
session.Set("tinyauth_sid", fmt.Sprintf("%s:%s", providerName.Provider, token))
session.Save()
redirectURI, redirectURIErr := c.Cookie("tinyauth_redirect_uri")
if redirectURIErr != nil {
c.JSON(200, gin.H{
"status": 200,
"message": "Logged in",
})
}
c.SetCookie("tinyauth_redirect_uri", "", -1, "/", api.Domain, api.Config.CookieSecure, true)
queries, queryErr := query.Values(types.LoginQuery{
RedirectURI: redirectURI,
})
if queryErr != nil {
c.JSON(501, gin.H{
"status": 501,
"message": "Internal Server Error",
})
return
}
c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/continue?%s", api.Config.AppURL, queries.Encode()))
})
}

View File

@@ -1,12 +1,14 @@
package hooks
import (
"strings"
"tinyauth/internal/auth"
"tinyauth/internal/providers"
"tinyauth/internal/types"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
"golang.org/x/oauth2"
)
func NewHooks(auth *auth.Auth, providers *providers.Providers) *Hooks {
@@ -24,7 +26,6 @@ type Hooks struct {
func (hooks *Hooks) UseUserContext(c *gin.Context) (types.UserContext, error) {
session := sessions.Default(c)
sessionCookie := session.Get("tinyauth_sid")
oauthProviderCookie := session.Get("tinyauth_oauth_provider")
if sessionCookie == nil {
return types.UserContext{
@@ -35,11 +36,9 @@ func (hooks *Hooks) UseUserContext(c *gin.Context) (types.UserContext, error) {
}, nil
}
email, emailOk := sessionCookie.(string)
provider, providerOk := oauthProviderCookie.(string)
data, dataOk := sessionCookie.(string)
if provider == "" || !providerOk {
if !emailOk {
if !dataOk {
return types.UserContext{
Email: "",
IsLoggedIn: false,
@@ -47,7 +46,23 @@ func (hooks *Hooks) UseUserContext(c *gin.Context) (types.UserContext, error) {
Provider: "",
}, nil
}
user := hooks.Auth.GetUser(email)
split := strings.Split(data, ":")
if len(split) != 2 {
return types.UserContext{
Email: "",
IsLoggedIn: false,
OAuth: false,
Provider: "",
}, nil
}
sessionType := split[0]
sessionValue := split[1]
if sessionType == "email" {
user := hooks.Auth.GetUser(sessionValue)
if user == nil {
return types.UserContext{
Email: "",
@@ -57,16 +72,31 @@ func (hooks *Hooks) UseUserContext(c *gin.Context) (types.UserContext, error) {
}, nil
}
return types.UserContext{
Email: email,
Email: sessionValue,
IsLoggedIn: true,
OAuth: false,
Provider: "",
}, nil
}
oauthEmail, oauthEmailErr := hooks.Providers.GetUser(provider)
provider := hooks.Providers.GetProvider(sessionType)
if oauthEmailErr != nil {
if provider == nil {
return types.UserContext{
Email: "",
IsLoggedIn: false,
OAuth: false,
Provider: "",
}, nil
}
provider.Token = &oauth2.Token{
AccessToken: sessionValue,
}
email, emailErr := hooks.Providers.GetUser(sessionType)
if emailErr != nil {
return types.UserContext{
Email: "",
IsLoggedIn: false,
@@ -76,9 +106,9 @@ func (hooks *Hooks) UseUserContext(c *gin.Context) (types.UserContext, error) {
}
return types.UserContext{
Email: oauthEmail,
Email: email,
IsLoggedIn: true,
OAuth: true,
Provider: provider,
Provider: sessionType,
}, nil
}

View File

@@ -30,14 +30,14 @@ func (oauth *OAuth) GetAuthURL() string {
return oauth.Config.AuthCodeURL("state", oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(oauth.Verifier))
}
func (oauth *OAuth) ExchangeToken(code string) error {
func (oauth *OAuth) ExchangeToken(code string) (string, error) {
token, err := oauth.Config.Exchange(oauth.Context, code, oauth2.VerifierOption(oauth.Verifier))
if err != nil {
log.Error().Err(err).Msg("Failed to exchange code")
return err
return "", err
}
oauth.Token = token
return nil
return oauth.Token.AccessToken, nil
}
func (oauth *OAuth) GetClient() *http.Client {

View File

@@ -35,24 +35,12 @@ func (providers *Providers) Init() {
}
}
func (providers *Providers) Login(code string, provider string) (string, error) {
func (providers *Providers) GetProvider(provider string) *oauth.OAuth {
switch provider {
case "github":
if providers.Github == nil {
return "", nil
}
exchangeErr := providers.Github.ExchangeToken(code)
if exchangeErr != nil {
return "", exchangeErr
}
client := providers.Github.GetClient()
email, emailErr := GetGithubEmail(client)
if emailErr != nil {
return "", emailErr
}
return email, nil
return providers.Github
default:
return "", nil
return nil
}
}
@@ -73,14 +61,10 @@ func (providers *Providers) GetUser(provider string) (string, error) {
}
}
func (providers *Providers) GetAuthURL(provider string) string {
switch provider {
case "github":
if providers.Github == nil {
return ""
}
return providers.Github.GetAuthURL()
default:
return ""
func (provider *Providers) GetConfiguredProviders() []string {
providers := []string{}
if provider.Github != nil {
providers = append(providers, "github")
}
return providers
}

View File

@@ -58,7 +58,7 @@ type OAuthConfig struct {
MicrosoftClientSecret string
}
type OAuthBind struct {
type OAuthRequest struct {
Provider string `uri:"provider" binding:"required"`
}
@@ -67,8 +67,3 @@ type OAuthProviders struct {
Google *oauth.OAuth
Microsoft *oauth.OAuth
}
type OAuthLogin struct {
Email string
Token string
}

18
site/src/icons/github.tsx Normal file
View File

@@ -0,0 +1,18 @@
import type { SVGProps } from "react";
export function GithubIcon(props: SVGProps<SVGSVGElement>) {
return (
<svg
xmlns="http://www.w3.org/2000/svg"
width={24}
height={24}
viewBox="0 0 24 24"
{...props}
>
<path
fill="currentColor"
d="M12 2A10 10 0 0 0 2 12c0 4.42 2.87 8.17 6.84 9.5c.5.08.66-.23.66-.5v-1.69c-2.77.6-3.36-1.34-3.36-1.34c-.46-1.16-1.11-1.47-1.11-1.47c-.91-.62.07-.6.07-.6c1 .07 1.53 1.03 1.53 1.03c.87 1.52 2.34 1.07 2.91.83c.09-.65.35-1.09.63-1.34c-2.22-.25-4.55-1.11-4.55-4.92c0-1.11.38-2 1.03-2.71c-.1-.25-.45-1.29.1-2.64c0 0 .84-.27 2.75 1.02c.79-.22 1.65-.33 2.5-.33s1.71.11 2.5.33c1.91-1.29 2.75-1.02 2.75-1.02c.55 1.35.2 2.39.1 2.64c.65.71 1.03 1.6 1.03 2.71c0 3.82-2.34 4.66-4.57 4.91c.36.31.69.92.69 1.85V21c0 .27.16.59.67.5C19.14 20.16 22 16.42 22 12A10 10 0 0 0 12 2"
></path>
</svg>
);
}

30
site/src/icons/google.tsx Normal file
View File

@@ -0,0 +1,30 @@
import type { SVGProps } from "react";
export function GoogleIcon(props: SVGProps<SVGSVGElement>) {
return (
<svg
xmlns="http://www.w3.org/2000/svg"
width={48}
height={48}
viewBox="0 0 48 48"
{...props}
>
<path
fill="#ffc107"
d="M43.611 20.083H42V20H24v8h11.303c-1.649 4.657-6.08 8-11.303 8c-6.627 0-12-5.373-12-12s5.373-12 12-12c3.059 0 5.842 1.154 7.961 3.039l5.657-5.657C34.046 6.053 29.268 4 24 4C12.955 4 4 12.955 4 24s8.955 20 20 20s20-8.955 20-20c0-1.341-.138-2.65-.389-3.917"
></path>
<path
fill="#ff3d00"
d="m6.306 14.691l6.571 4.819C14.655 15.108 18.961 12 24 12c3.059 0 5.842 1.154 7.961 3.039l5.657-5.657C34.046 6.053 29.268 4 24 4C16.318 4 9.656 8.337 6.306 14.691"
></path>
<path
fill="#4caf50"
d="M24 44c5.166 0 9.86-1.977 13.409-5.192l-6.19-5.238A11.9 11.9 0 0 1 24 36c-5.202 0-9.619-3.317-11.283-7.946l-6.522 5.025C9.505 39.556 16.227 44 24 44"
></path>
<path
fill="#1976d2"
d="M43.611 20.083H42V20H24v8h11.303a12.04 12.04 0 0 1-4.087 5.571l.003-.002l6.19 5.238C36.971 39.205 44 34 44 24c0-1.341-.138-2.65-.389-3.917"
></path>
</svg>
);
}

View File

@@ -0,0 +1,18 @@
import type { SVGProps } from "react";
export function MicrosoftIcon(props: SVGProps<SVGSVGElement>) {
return (
<svg
xmlns="http://www.w3.org/2000/svg"
width={256}
height={256}
viewBox="0 0 256 256"
{...props}
>
<path fill="#f1511b" d="M121.666 121.666H0V0h121.666z"></path>
<path fill="#80cc28" d="M256 121.666H134.335V0H256z"></path>
<path fill="#00adef" d="M121.663 256.002H0V134.336h121.663z"></path>
<path fill="#fbbc09" d="M256 256.002H134.335V134.336H256z"></path>
</svg>
);
}

View File

@@ -1,4 +1,13 @@
import { Button, Paper, PasswordInput, TextInput, Title } from "@mantine/core";
import {
Button,
Paper,
PasswordInput,
TextInput,
Title,
Text,
Divider,
Grid,
} from "@mantine/core";
import { useForm, zodResolver } from "@mantine/form";
import { notifications } from "@mantine/notifications";
import { useMutation } from "@tanstack/react-query";
@@ -7,13 +16,16 @@ import { z } from "zod";
import { useUserContext } from "../context/user-context";
import { Navigate } from "react-router";
import { Layout } from "../components/layouts/layout";
import { GoogleIcon } from "../icons/google";
import { MicrosoftIcon } from "../icons/microsoft";
import { GithubIcon } from "../icons/github";
export const LoginPage = () => {
const queryString = window.location.search;
const params = new URLSearchParams(queryString);
const redirectUri = params.get("redirect_uri");
const { isLoggedIn } = useUserContext();
const { isLoggedIn, configuredProviders } = useUserContext();
if (isLoggedIn) {
return <Navigate to="/logout" />;
@@ -58,14 +70,83 @@ export const LoginPage = () => {
},
});
const loginOAuthMutation = useMutation({
mutationFn: (provider: string) => {
return axios.get(
`/api/oauth/url/${provider}?redirect_uri=${redirectUri}`,
);
},
onError: () => {
notifications.show({
title: "Internal error",
message: "Failed to get OAuth URL",
color: "red",
});
},
onSuccess: (data) => {
window.location.replace(data.data.url);
},
});
const handleSubmit = (values: FormValues) => {
loginMutation.mutate(values);
};
return (
<Layout>
<Title ta="center">Welcome back!</Title>
<Paper shadow="md" p={30} mt={30} radius="md" withBorder>
<Title ta="center">Tinyauth</Title>
<Paper shadow="md" p="xl" mt={30} radius="md" withBorder>
<Text size="lg" fw={500} ta="center">
Welcome back, login with
</Text>
<Grid mb="md" mt="md" align="center" justify="center">
{configuredProviders.includes("google") && (
<Grid.Col span="content">
<Button
radius="xl"
leftSection={<GoogleIcon style={{ width: 14, height: 14 }} />}
variant="default"
onClick={() => loginOAuthMutation.mutate("google")}
loading={loginOAuthMutation.isLoading}
>
Google
</Button>
</Grid.Col>
)}
{configuredProviders.includes("microsoft") && (
<Grid.Col span="content">
<Button
radius="xl"
leftSection={
<MicrosoftIcon style={{ width: 14, height: 14 }} />
}
variant="default"
onClick={() => loginOAuthMutation.mutate("microsoft")}
loading={loginOAuthMutation.isLoading}
>
Microsoft
</Button>
</Grid.Col>
)}
{configuredProviders.includes("github") && (
<Grid.Col span="content">
<Button
radius="xl"
leftSection={<GithubIcon style={{ width: 14, height: 14 }} />}
variant="default"
onClick={() => loginOAuthMutation.mutate("github")}
loading={loginOAuthMutation.isLoading}
>
Github
</Button>
</Grid.Col>
)}
</Grid>
<Divider
label="Or continue with email"
labelPosition="center"
my="lg"
/>
<form onSubmit={form.onSubmit(handleSubmit)}>
<TextInput
label="Email"

View File

@@ -5,6 +5,7 @@ export const userContextSchema = z.object({
email: z.string(),
oauth: z.boolean(),
provider: z.string(),
configuredProviders: z.array(z.string()),
});
export type UserContextSchemaType = z.infer<typeof userContextSchema>;