mirror of
https://github.com/steveiliop56/tinyauth.git
synced 2025-10-28 12:45:47 +00:00
feat: persist sessions and auto redirect to app
This commit is contained in:
@@ -36,6 +36,7 @@ type API struct {
|
||||
Hooks *hooks.Hooks
|
||||
Auth *auth.Auth
|
||||
Providers *providers.Providers
|
||||
Domain string
|
||||
}
|
||||
|
||||
func (api *API) Init() {
|
||||
@@ -70,8 +71,10 @@ func (api *API) Init() {
|
||||
isSecure = false
|
||||
}
|
||||
|
||||
api.Domain = fmt.Sprintf(".%s", domain)
|
||||
|
||||
store.Options(sessions.Options{
|
||||
Domain: fmt.Sprintf(".%s", domain),
|
||||
Domain: api.Domain,
|
||||
Path: "/",
|
||||
HttpOnly: true,
|
||||
Secure: isSecure,
|
||||
@@ -163,8 +166,7 @@ func (api *API) SetupRoutes() {
|
||||
}
|
||||
|
||||
session := sessions.Default(c)
|
||||
session.Set("tinyauth_sid", user.Email)
|
||||
session.Set("tinyauth_oauth_provider", "")
|
||||
session.Set("tinyauth_sid", fmt.Sprintf("email:%s", login.Email))
|
||||
session.Save()
|
||||
|
||||
c.JSON(200, gin.H{
|
||||
@@ -176,9 +178,10 @@ func (api *API) SetupRoutes() {
|
||||
api.Router.POST("/api/logout", func(c *gin.Context) {
|
||||
session := sessions.Default(c)
|
||||
session.Delete("tinyauth_sid")
|
||||
session.Delete("tinyauth_oauth_provider")
|
||||
session.Save()
|
||||
|
||||
c.SetCookie("tinyauth_redirect_uri", "", -1, "/", api.Domain, api.Config.CookieSecure, true)
|
||||
|
||||
c.JSON(200, gin.H{
|
||||
"status": 200,
|
||||
"message": "Logged out",
|
||||
@@ -204,6 +207,7 @@ func (api *API) SetupRoutes() {
|
||||
"isLoggedIn": false,
|
||||
"oauth": false,
|
||||
"provider": "",
|
||||
"configuredProviders": api.Providers.GetConfiguredProviders(),
|
||||
})
|
||||
return
|
||||
}
|
||||
@@ -215,6 +219,7 @@ func (api *API) SetupRoutes() {
|
||||
"isLoggedIn": userContext.IsLoggedIn,
|
||||
"oauth": userContext.OAuth,
|
||||
"provider": userContext.Provider,
|
||||
"configuredProviders": api.Providers.GetConfiguredProviders(),
|
||||
})
|
||||
})
|
||||
|
||||
@@ -226,9 +231,9 @@ func (api *API) SetupRoutes() {
|
||||
})
|
||||
|
||||
api.Router.GET("/api/oauth/url/:provider", func(c *gin.Context) {
|
||||
var provider types.OAuthBind
|
||||
var request types.OAuthRequest
|
||||
|
||||
bindErr := c.BindUri(&provider)
|
||||
bindErr := c.BindUri(&request)
|
||||
|
||||
if bindErr != nil {
|
||||
c.JSON(400, gin.H{
|
||||
@@ -238,16 +243,24 @@ func (api *API) SetupRoutes() {
|
||||
return
|
||||
}
|
||||
|
||||
authURL := api.Providers.GetAuthURL(provider.Provider)
|
||||
provider := api.Providers.GetProvider(request.Provider)
|
||||
|
||||
if authURL == "" {
|
||||
c.JSON(400, gin.H{
|
||||
"status": 400,
|
||||
"message": "Bad Request",
|
||||
if provider == nil {
|
||||
c.JSON(404, gin.H{
|
||||
"status": 404,
|
||||
"message": "Not Found",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
authURL := provider.GetAuthURL()
|
||||
|
||||
redirectURI := c.Query("redirect_uri")
|
||||
|
||||
if redirectURI != "" {
|
||||
c.SetCookie("tinyauth_redirect_uri", redirectURI, 3600, "/", api.Domain, api.Config.CookieSecure, true)
|
||||
}
|
||||
|
||||
c.JSON(200, gin.H{
|
||||
"status": 200,
|
||||
"message": "Ok",
|
||||
@@ -256,9 +269,9 @@ func (api *API) SetupRoutes() {
|
||||
})
|
||||
|
||||
api.Router.GET("/api/oauth/callback/:provider", func(c *gin.Context) {
|
||||
var provider types.OAuthBind
|
||||
var providerName types.OAuthRequest
|
||||
|
||||
bindErr := c.BindUri(&provider)
|
||||
bindErr := c.BindUri(&providerName)
|
||||
|
||||
if bindErr != nil {
|
||||
c.JSON(400, gin.H{
|
||||
@@ -278,9 +291,19 @@ func (api *API) SetupRoutes() {
|
||||
return
|
||||
}
|
||||
|
||||
email, emailErr := api.Providers.Login(code, provider.Provider)
|
||||
provider := api.Providers.GetProvider(providerName.Provider)
|
||||
|
||||
if emailErr != nil {
|
||||
if provider == nil {
|
||||
c.JSON(404, gin.H{
|
||||
"status": 404,
|
||||
"message": "Not Found",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
token, tokenErr := provider.ExchangeToken(code)
|
||||
|
||||
if tokenErr != nil {
|
||||
c.JSON(500, gin.H{
|
||||
"status": 500,
|
||||
"message": "Internal Server Error",
|
||||
@@ -289,14 +312,33 @@ func (api *API) SetupRoutes() {
|
||||
}
|
||||
|
||||
session := sessions.Default(c)
|
||||
session.Set("tinyauth_sid", email)
|
||||
session.Set("tinyauth_oauth_provider", provider.Provider)
|
||||
session.Set("tinyauth_sid", fmt.Sprintf("%s:%s", providerName.Provider, token))
|
||||
session.Save()
|
||||
|
||||
redirectURI, redirectURIErr := c.Cookie("tinyauth_redirect_uri")
|
||||
|
||||
if redirectURIErr != nil {
|
||||
c.JSON(200, gin.H{
|
||||
"status": 200,
|
||||
"message": "Logged in",
|
||||
})
|
||||
}
|
||||
|
||||
c.SetCookie("tinyauth_redirect_uri", "", -1, "/", api.Domain, api.Config.CookieSecure, true)
|
||||
|
||||
queries, queryErr := query.Values(types.LoginQuery{
|
||||
RedirectURI: redirectURI,
|
||||
})
|
||||
|
||||
if queryErr != nil {
|
||||
c.JSON(501, gin.H{
|
||||
"status": 501,
|
||||
"message": "Internal Server Error",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/continue?%s", api.Config.AppURL, queries.Encode()))
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
package hooks
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"tinyauth/internal/auth"
|
||||
"tinyauth/internal/providers"
|
||||
"tinyauth/internal/types"
|
||||
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-gonic/gin"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
func NewHooks(auth *auth.Auth, providers *providers.Providers) *Hooks {
|
||||
@@ -24,7 +26,6 @@ type Hooks struct {
|
||||
func (hooks *Hooks) UseUserContext(c *gin.Context) (types.UserContext, error) {
|
||||
session := sessions.Default(c)
|
||||
sessionCookie := session.Get("tinyauth_sid")
|
||||
oauthProviderCookie := session.Get("tinyauth_oauth_provider")
|
||||
|
||||
if sessionCookie == nil {
|
||||
return types.UserContext{
|
||||
@@ -35,11 +36,9 @@ func (hooks *Hooks) UseUserContext(c *gin.Context) (types.UserContext, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
email, emailOk := sessionCookie.(string)
|
||||
provider, providerOk := oauthProviderCookie.(string)
|
||||
data, dataOk := sessionCookie.(string)
|
||||
|
||||
if provider == "" || !providerOk {
|
||||
if !emailOk {
|
||||
if !dataOk {
|
||||
return types.UserContext{
|
||||
Email: "",
|
||||
IsLoggedIn: false,
|
||||
@@ -47,7 +46,23 @@ func (hooks *Hooks) UseUserContext(c *gin.Context) (types.UserContext, error) {
|
||||
Provider: "",
|
||||
}, nil
|
||||
}
|
||||
user := hooks.Auth.GetUser(email)
|
||||
|
||||
split := strings.Split(data, ":")
|
||||
|
||||
if len(split) != 2 {
|
||||
return types.UserContext{
|
||||
Email: "",
|
||||
IsLoggedIn: false,
|
||||
OAuth: false,
|
||||
Provider: "",
|
||||
}, nil
|
||||
}
|
||||
|
||||
sessionType := split[0]
|
||||
sessionValue := split[1]
|
||||
|
||||
if sessionType == "email" {
|
||||
user := hooks.Auth.GetUser(sessionValue)
|
||||
if user == nil {
|
||||
return types.UserContext{
|
||||
Email: "",
|
||||
@@ -57,16 +72,31 @@ func (hooks *Hooks) UseUserContext(c *gin.Context) (types.UserContext, error) {
|
||||
}, nil
|
||||
}
|
||||
return types.UserContext{
|
||||
Email: email,
|
||||
Email: sessionValue,
|
||||
IsLoggedIn: true,
|
||||
OAuth: false,
|
||||
Provider: "",
|
||||
}, nil
|
||||
}
|
||||
|
||||
oauthEmail, oauthEmailErr := hooks.Providers.GetUser(provider)
|
||||
provider := hooks.Providers.GetProvider(sessionType)
|
||||
|
||||
if oauthEmailErr != nil {
|
||||
if provider == nil {
|
||||
return types.UserContext{
|
||||
Email: "",
|
||||
IsLoggedIn: false,
|
||||
OAuth: false,
|
||||
Provider: "",
|
||||
}, nil
|
||||
}
|
||||
|
||||
provider.Token = &oauth2.Token{
|
||||
AccessToken: sessionValue,
|
||||
}
|
||||
|
||||
email, emailErr := hooks.Providers.GetUser(sessionType)
|
||||
|
||||
if emailErr != nil {
|
||||
return types.UserContext{
|
||||
Email: "",
|
||||
IsLoggedIn: false,
|
||||
@@ -76,9 +106,9 @@ func (hooks *Hooks) UseUserContext(c *gin.Context) (types.UserContext, error) {
|
||||
}
|
||||
|
||||
return types.UserContext{
|
||||
Email: oauthEmail,
|
||||
Email: email,
|
||||
IsLoggedIn: true,
|
||||
OAuth: true,
|
||||
Provider: provider,
|
||||
Provider: sessionType,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -30,14 +30,14 @@ func (oauth *OAuth) GetAuthURL() string {
|
||||
return oauth.Config.AuthCodeURL("state", oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(oauth.Verifier))
|
||||
}
|
||||
|
||||
func (oauth *OAuth) ExchangeToken(code string) error {
|
||||
func (oauth *OAuth) ExchangeToken(code string) (string, error) {
|
||||
token, err := oauth.Config.Exchange(oauth.Context, code, oauth2.VerifierOption(oauth.Verifier))
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to exchange code")
|
||||
return err
|
||||
return "", err
|
||||
}
|
||||
oauth.Token = token
|
||||
return nil
|
||||
return oauth.Token.AccessToken, nil
|
||||
}
|
||||
|
||||
func (oauth *OAuth) GetClient() *http.Client {
|
||||
|
||||
@@ -35,24 +35,12 @@ func (providers *Providers) Init() {
|
||||
}
|
||||
}
|
||||
|
||||
func (providers *Providers) Login(code string, provider string) (string, error) {
|
||||
func (providers *Providers) GetProvider(provider string) *oauth.OAuth {
|
||||
switch provider {
|
||||
case "github":
|
||||
if providers.Github == nil {
|
||||
return "", nil
|
||||
}
|
||||
exchangeErr := providers.Github.ExchangeToken(code)
|
||||
if exchangeErr != nil {
|
||||
return "", exchangeErr
|
||||
}
|
||||
client := providers.Github.GetClient()
|
||||
email, emailErr := GetGithubEmail(client)
|
||||
if emailErr != nil {
|
||||
return "", emailErr
|
||||
}
|
||||
return email, nil
|
||||
return providers.Github
|
||||
default:
|
||||
return "", nil
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -73,14 +61,10 @@ func (providers *Providers) GetUser(provider string) (string, error) {
|
||||
}
|
||||
}
|
||||
|
||||
func (providers *Providers) GetAuthURL(provider string) string {
|
||||
switch provider {
|
||||
case "github":
|
||||
if providers.Github == nil {
|
||||
return ""
|
||||
}
|
||||
return providers.Github.GetAuthURL()
|
||||
default:
|
||||
return ""
|
||||
func (provider *Providers) GetConfiguredProviders() []string {
|
||||
providers := []string{}
|
||||
if provider.Github != nil {
|
||||
providers = append(providers, "github")
|
||||
}
|
||||
return providers
|
||||
}
|
||||
|
||||
@@ -58,7 +58,7 @@ type OAuthConfig struct {
|
||||
MicrosoftClientSecret string
|
||||
}
|
||||
|
||||
type OAuthBind struct {
|
||||
type OAuthRequest struct {
|
||||
Provider string `uri:"provider" binding:"required"`
|
||||
}
|
||||
|
||||
@@ -67,8 +67,3 @@ type OAuthProviders struct {
|
||||
Google *oauth.OAuth
|
||||
Microsoft *oauth.OAuth
|
||||
}
|
||||
|
||||
type OAuthLogin struct {
|
||||
Email string
|
||||
Token string
|
||||
}
|
||||
|
||||
18
site/src/icons/github.tsx
Normal file
18
site/src/icons/github.tsx
Normal file
@@ -0,0 +1,18 @@
|
||||
import type { SVGProps } from "react";
|
||||
|
||||
export function GithubIcon(props: SVGProps<SVGSVGElement>) {
|
||||
return (
|
||||
<svg
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
width={24}
|
||||
height={24}
|
||||
viewBox="0 0 24 24"
|
||||
{...props}
|
||||
>
|
||||
<path
|
||||
fill="currentColor"
|
||||
d="M12 2A10 10 0 0 0 2 12c0 4.42 2.87 8.17 6.84 9.5c.5.08.66-.23.66-.5v-1.69c-2.77.6-3.36-1.34-3.36-1.34c-.46-1.16-1.11-1.47-1.11-1.47c-.91-.62.07-.6.07-.6c1 .07 1.53 1.03 1.53 1.03c.87 1.52 2.34 1.07 2.91.83c.09-.65.35-1.09.63-1.34c-2.22-.25-4.55-1.11-4.55-4.92c0-1.11.38-2 1.03-2.71c-.1-.25-.45-1.29.1-2.64c0 0 .84-.27 2.75 1.02c.79-.22 1.65-.33 2.5-.33s1.71.11 2.5.33c1.91-1.29 2.75-1.02 2.75-1.02c.55 1.35.2 2.39.1 2.64c.65.71 1.03 1.6 1.03 2.71c0 3.82-2.34 4.66-4.57 4.91c.36.31.69.92.69 1.85V21c0 .27.16.59.67.5C19.14 20.16 22 16.42 22 12A10 10 0 0 0 12 2"
|
||||
></path>
|
||||
</svg>
|
||||
);
|
||||
}
|
||||
30
site/src/icons/google.tsx
Normal file
30
site/src/icons/google.tsx
Normal file
@@ -0,0 +1,30 @@
|
||||
import type { SVGProps } from "react";
|
||||
|
||||
export function GoogleIcon(props: SVGProps<SVGSVGElement>) {
|
||||
return (
|
||||
<svg
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
width={48}
|
||||
height={48}
|
||||
viewBox="0 0 48 48"
|
||||
{...props}
|
||||
>
|
||||
<path
|
||||
fill="#ffc107"
|
||||
d="M43.611 20.083H42V20H24v8h11.303c-1.649 4.657-6.08 8-11.303 8c-6.627 0-12-5.373-12-12s5.373-12 12-12c3.059 0 5.842 1.154 7.961 3.039l5.657-5.657C34.046 6.053 29.268 4 24 4C12.955 4 4 12.955 4 24s8.955 20 20 20s20-8.955 20-20c0-1.341-.138-2.65-.389-3.917"
|
||||
></path>
|
||||
<path
|
||||
fill="#ff3d00"
|
||||
d="m6.306 14.691l6.571 4.819C14.655 15.108 18.961 12 24 12c3.059 0 5.842 1.154 7.961 3.039l5.657-5.657C34.046 6.053 29.268 4 24 4C16.318 4 9.656 8.337 6.306 14.691"
|
||||
></path>
|
||||
<path
|
||||
fill="#4caf50"
|
||||
d="M24 44c5.166 0 9.86-1.977 13.409-5.192l-6.19-5.238A11.9 11.9 0 0 1 24 36c-5.202 0-9.619-3.317-11.283-7.946l-6.522 5.025C9.505 39.556 16.227 44 24 44"
|
||||
></path>
|
||||
<path
|
||||
fill="#1976d2"
|
||||
d="M43.611 20.083H42V20H24v8h11.303a12.04 12.04 0 0 1-4.087 5.571l.003-.002l6.19 5.238C36.971 39.205 44 34 44 24c0-1.341-.138-2.65-.389-3.917"
|
||||
></path>
|
||||
</svg>
|
||||
);
|
||||
}
|
||||
18
site/src/icons/microsoft.tsx
Normal file
18
site/src/icons/microsoft.tsx
Normal file
@@ -0,0 +1,18 @@
|
||||
import type { SVGProps } from "react";
|
||||
|
||||
export function MicrosoftIcon(props: SVGProps<SVGSVGElement>) {
|
||||
return (
|
||||
<svg
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
width={256}
|
||||
height={256}
|
||||
viewBox="0 0 256 256"
|
||||
{...props}
|
||||
>
|
||||
<path fill="#f1511b" d="M121.666 121.666H0V0h121.666z"></path>
|
||||
<path fill="#80cc28" d="M256 121.666H134.335V0H256z"></path>
|
||||
<path fill="#00adef" d="M121.663 256.002H0V134.336h121.663z"></path>
|
||||
<path fill="#fbbc09" d="M256 256.002H134.335V134.336H256z"></path>
|
||||
</svg>
|
||||
);
|
||||
}
|
||||
@@ -1,4 +1,13 @@
|
||||
import { Button, Paper, PasswordInput, TextInput, Title } from "@mantine/core";
|
||||
import {
|
||||
Button,
|
||||
Paper,
|
||||
PasswordInput,
|
||||
TextInput,
|
||||
Title,
|
||||
Text,
|
||||
Divider,
|
||||
Grid,
|
||||
} from "@mantine/core";
|
||||
import { useForm, zodResolver } from "@mantine/form";
|
||||
import { notifications } from "@mantine/notifications";
|
||||
import { useMutation } from "@tanstack/react-query";
|
||||
@@ -7,13 +16,16 @@ import { z } from "zod";
|
||||
import { useUserContext } from "../context/user-context";
|
||||
import { Navigate } from "react-router";
|
||||
import { Layout } from "../components/layouts/layout";
|
||||
import { GoogleIcon } from "../icons/google";
|
||||
import { MicrosoftIcon } from "../icons/microsoft";
|
||||
import { GithubIcon } from "../icons/github";
|
||||
|
||||
export const LoginPage = () => {
|
||||
const queryString = window.location.search;
|
||||
const params = new URLSearchParams(queryString);
|
||||
const redirectUri = params.get("redirect_uri");
|
||||
|
||||
const { isLoggedIn } = useUserContext();
|
||||
const { isLoggedIn, configuredProviders } = useUserContext();
|
||||
|
||||
if (isLoggedIn) {
|
||||
return <Navigate to="/logout" />;
|
||||
@@ -58,14 +70,83 @@ export const LoginPage = () => {
|
||||
},
|
||||
});
|
||||
|
||||
const loginOAuthMutation = useMutation({
|
||||
mutationFn: (provider: string) => {
|
||||
return axios.get(
|
||||
`/api/oauth/url/${provider}?redirect_uri=${redirectUri}`,
|
||||
);
|
||||
},
|
||||
onError: () => {
|
||||
notifications.show({
|
||||
title: "Internal error",
|
||||
message: "Failed to get OAuth URL",
|
||||
color: "red",
|
||||
});
|
||||
},
|
||||
onSuccess: (data) => {
|
||||
window.location.replace(data.data.url);
|
||||
},
|
||||
});
|
||||
|
||||
const handleSubmit = (values: FormValues) => {
|
||||
loginMutation.mutate(values);
|
||||
};
|
||||
|
||||
return (
|
||||
<Layout>
|
||||
<Title ta="center">Welcome back!</Title>
|
||||
<Paper shadow="md" p={30} mt={30} radius="md" withBorder>
|
||||
<Title ta="center">Tinyauth</Title>
|
||||
<Paper shadow="md" p="xl" mt={30} radius="md" withBorder>
|
||||
<Text size="lg" fw={500} ta="center">
|
||||
Welcome back, login with
|
||||
</Text>
|
||||
<Grid mb="md" mt="md" align="center" justify="center">
|
||||
{configuredProviders.includes("google") && (
|
||||
<Grid.Col span="content">
|
||||
<Button
|
||||
radius="xl"
|
||||
leftSection={<GoogleIcon style={{ width: 14, height: 14 }} />}
|
||||
variant="default"
|
||||
onClick={() => loginOAuthMutation.mutate("google")}
|
||||
loading={loginOAuthMutation.isLoading}
|
||||
>
|
||||
Google
|
||||
</Button>
|
||||
</Grid.Col>
|
||||
)}
|
||||
{configuredProviders.includes("microsoft") && (
|
||||
<Grid.Col span="content">
|
||||
<Button
|
||||
radius="xl"
|
||||
leftSection={
|
||||
<MicrosoftIcon style={{ width: 14, height: 14 }} />
|
||||
}
|
||||
variant="default"
|
||||
onClick={() => loginOAuthMutation.mutate("microsoft")}
|
||||
loading={loginOAuthMutation.isLoading}
|
||||
>
|
||||
Microsoft
|
||||
</Button>
|
||||
</Grid.Col>
|
||||
)}
|
||||
{configuredProviders.includes("github") && (
|
||||
<Grid.Col span="content">
|
||||
<Button
|
||||
radius="xl"
|
||||
leftSection={<GithubIcon style={{ width: 14, height: 14 }} />}
|
||||
variant="default"
|
||||
onClick={() => loginOAuthMutation.mutate("github")}
|
||||
loading={loginOAuthMutation.isLoading}
|
||||
>
|
||||
Github
|
||||
</Button>
|
||||
</Grid.Col>
|
||||
)}
|
||||
</Grid>
|
||||
<Divider
|
||||
label="Or continue with email"
|
||||
labelPosition="center"
|
||||
my="lg"
|
||||
/>
|
||||
<form onSubmit={form.onSubmit(handleSubmit)}>
|
||||
<TextInput
|
||||
label="Email"
|
||||
|
||||
@@ -5,6 +5,7 @@ export const userContextSchema = z.object({
|
||||
email: z.string(),
|
||||
oauth: z.boolean(),
|
||||
provider: z.string(),
|
||||
configuredProviders: z.array(z.string()),
|
||||
});
|
||||
|
||||
export type UserContextSchemaType = z.infer<typeof userContextSchema>;
|
||||
|
||||
Reference in New Issue
Block a user