mirror of
https://github.com/steveiliop56/tinyauth.git
synced 2026-06-19 18:00:22 +00:00
Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| e553383a91 |
@@ -13,7 +13,7 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||||
|
|
||||||
- name: Setup pnpm
|
- name: Setup pnpm
|
||||||
uses: pnpm/action-setup@0ebf47130e4866e96fce0953f49152a61190b271 # v6.0.9
|
uses: pnpm/action-setup@0ebf47130e4866e96fce0953f49152a61190b271 # v6.0.9
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||||
|
|
||||||
- name: Delete old release
|
- name: Delete old release
|
||||||
run: gh release delete --cleanup-tag --yes nightly || echo release not found
|
run: gh release delete --cleanup-tag --yes nightly || echo release not found
|
||||||
@@ -37,7 +37,7 @@ jobs:
|
|||||||
BUILD_TIMESTAMP: ${{ steps.metadata.outputs.BUILD_TIMESTAMP }}
|
BUILD_TIMESTAMP: ${{ steps.metadata.outputs.BUILD_TIMESTAMP }}
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||||
with:
|
with:
|
||||||
ref: nightly
|
ref: nightly
|
||||||
|
|
||||||
@@ -55,7 +55,7 @@ jobs:
|
|||||||
- generate-metadata
|
- generate-metadata
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||||
with:
|
with:
|
||||||
ref: nightly
|
ref: nightly
|
||||||
|
|
||||||
@@ -100,7 +100,7 @@ jobs:
|
|||||||
- generate-metadata
|
- generate-metadata
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||||
with:
|
with:
|
||||||
ref: nightly
|
ref: nightly
|
||||||
|
|
||||||
@@ -145,7 +145,7 @@ jobs:
|
|||||||
- generate-metadata
|
- generate-metadata
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||||
with:
|
with:
|
||||||
ref: nightly
|
ref: nightly
|
||||||
|
|
||||||
@@ -203,7 +203,7 @@ jobs:
|
|||||||
- image-build
|
- image-build
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||||
with:
|
with:
|
||||||
ref: nightly
|
ref: nightly
|
||||||
|
|
||||||
@@ -261,7 +261,7 @@ jobs:
|
|||||||
- generate-metadata
|
- generate-metadata
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||||
with:
|
with:
|
||||||
ref: nightly
|
ref: nightly
|
||||||
|
|
||||||
@@ -319,7 +319,7 @@ jobs:
|
|||||||
- image-build-arm
|
- image-build-arm
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||||
with:
|
with:
|
||||||
ref: nightly
|
ref: nightly
|
||||||
|
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ jobs:
|
|||||||
BUILD_TIMESTAMP: ${{ steps.metadata.outputs.BUILD_TIMESTAMP }}
|
BUILD_TIMESTAMP: ${{ steps.metadata.outputs.BUILD_TIMESTAMP }}
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||||
|
|
||||||
- name: Generate metadata
|
- name: Generate metadata
|
||||||
id: metadata
|
id: metadata
|
||||||
@@ -33,7 +33,7 @@ jobs:
|
|||||||
- generate-metadata
|
- generate-metadata
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||||
|
|
||||||
- name: Setup pnpm
|
- name: Setup pnpm
|
||||||
uses: pnpm/action-setup@0ebf47130e4866e96fce0953f49152a61190b271 # v6.0.9
|
uses: pnpm/action-setup@0ebf47130e4866e96fce0953f49152a61190b271 # v6.0.9
|
||||||
@@ -75,7 +75,7 @@ jobs:
|
|||||||
- generate-metadata
|
- generate-metadata
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||||
|
|
||||||
- name: Setup pnpm
|
- name: Setup pnpm
|
||||||
uses: pnpm/action-setup@0ebf47130e4866e96fce0953f49152a61190b271 # v6.0.9
|
uses: pnpm/action-setup@0ebf47130e4866e96fce0953f49152a61190b271 # v6.0.9
|
||||||
@@ -117,7 +117,7 @@ jobs:
|
|||||||
- generate-metadata
|
- generate-metadata
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||||
|
|
||||||
- name: Docker meta
|
- name: Docker meta
|
||||||
id: meta
|
id: meta
|
||||||
@@ -173,7 +173,7 @@ jobs:
|
|||||||
- image-build
|
- image-build
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||||
|
|
||||||
- name: Docker meta
|
- name: Docker meta
|
||||||
id: meta
|
id: meta
|
||||||
@@ -229,7 +229,7 @@ jobs:
|
|||||||
- generate-metadata
|
- generate-metadata
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||||
|
|
||||||
- name: Docker meta
|
- name: Docker meta
|
||||||
id: meta
|
id: meta
|
||||||
@@ -285,7 +285,7 @@ jobs:
|
|||||||
- image-build-arm
|
- image-build-arm
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||||
|
|
||||||
- name: Docker meta
|
- name: Docker meta
|
||||||
id: meta
|
id: meta
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ jobs:
|
|||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10
|
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0
|
||||||
with:
|
with:
|
||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
|
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||||
|
|
||||||
- name: Generate Sponsors
|
- name: Generate Sponsors
|
||||||
uses: JamesIves/github-sponsors-readme-action@2fd9142e765f755780202122261dc85e78459405 # v1
|
uses: JamesIves/github-sponsors-readme-action@2fd9142e765f755780202122261dc85e78459405 # v1
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ type ScreenParams = {
|
|||||||
oidc_ticket?: string;
|
oidc_ticket?: string;
|
||||||
oidc_scope?: string;
|
oidc_scope?: string;
|
||||||
oidc_name?: string;
|
oidc_name?: string;
|
||||||
oidc_prompt?: "none" | "login";
|
|
||||||
};
|
};
|
||||||
|
|
||||||
const zodScreenParams = z.object({
|
const zodScreenParams = z.object({
|
||||||
@@ -15,7 +14,6 @@ const zodScreenParams = z.object({
|
|||||||
oidc_ticket: z.string().optional(),
|
oidc_ticket: z.string().optional(),
|
||||||
oidc_scope: z.string().optional(),
|
oidc_scope: z.string().optional(),
|
||||||
oidc_name: z.string().optional(),
|
oidc_name: z.string().optional(),
|
||||||
oidc_prompt: z.enum(["none", "login"]).optional(),
|
|
||||||
});
|
});
|
||||||
|
|
||||||
export function useScreenParams(params: URLSearchParams): ScreenParams {
|
export function useScreenParams(params: URLSearchParams): ScreenParams {
|
||||||
|
|||||||
@@ -25,7 +25,6 @@ import {
|
|||||||
recompileScreenParams,
|
recompileScreenParams,
|
||||||
useScreenParams,
|
useScreenParams,
|
||||||
} from "@/lib/hooks/screen-params";
|
} from "@/lib/hooks/screen-params";
|
||||||
import { useEffect } from "react";
|
|
||||||
|
|
||||||
type Scope = {
|
type Scope = {
|
||||||
id: string;
|
id: string;
|
||||||
@@ -91,15 +90,7 @@ export const AuthorizePage = () => {
|
|||||||
const isOidc = screenParams.login_for === "oidc";
|
const isOidc = screenParams.login_for === "oidc";
|
||||||
const compiledParams = recompileScreenParams(screenParams);
|
const compiledParams = recompileScreenParams(screenParams);
|
||||||
|
|
||||||
// TODO: maybe a better way to do this
|
const authorizeMutation = useMutation({
|
||||||
const shouldAutoAuthorize =
|
|
||||||
auth.authenticated &&
|
|
||||||
isOidc &&
|
|
||||||
screenParams.oidc_ticket !== undefined &&
|
|
||||||
screenParams.oidc_scope !== undefined &&
|
|
||||||
screenParams.oidc_prompt === "none";
|
|
||||||
|
|
||||||
const { mutate: authorizeMutate, isPending: authorizePending } = useMutation({
|
|
||||||
mutationFn: () => {
|
mutationFn: () => {
|
||||||
return axios.post("/api/oidc/authorize-complete", {
|
return axios.post("/api/oidc/authorize-complete", {
|
||||||
ticket: screenParams.oidc_ticket,
|
ticket: screenParams.oidc_ticket,
|
||||||
@@ -119,12 +110,6 @@ export const AuthorizePage = () => {
|
|||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
useEffect(() => {
|
|
||||||
if (shouldAutoAuthorize) {
|
|
||||||
authorizeMutate();
|
|
||||||
}
|
|
||||||
}, [shouldAutoAuthorize, authorizeMutate]);
|
|
||||||
|
|
||||||
if (!isOidc || !screenParams.oidc_ticket || !screenParams.oidc_scope) {
|
if (!isOidc || !screenParams.oidc_ticket || !screenParams.oidc_scope) {
|
||||||
return (
|
return (
|
||||||
<Navigate
|
<Navigate
|
||||||
@@ -134,7 +119,7 @@ export const AuthorizePage = () => {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!auth.authenticated || screenParams.oidc_prompt === "login") {
|
if (!auth.authenticated) {
|
||||||
return <Navigate to={`/login${compiledParams}`} replace />;
|
return <Navigate to={`/login${compiledParams}`} replace />;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -183,15 +168,14 @@ export const AuthorizePage = () => {
|
|||||||
)}
|
)}
|
||||||
<CardFooter className="flex flex-col items-stretch gap-3">
|
<CardFooter className="flex flex-col items-stretch gap-3">
|
||||||
<Button
|
<Button
|
||||||
onClick={() => authorizeMutate()}
|
onClick={() => authorizeMutation.mutate()}
|
||||||
loading={authorizePending}
|
loading={authorizeMutation.isPending}
|
||||||
disabled={shouldAutoAuthorize}
|
|
||||||
>
|
>
|
||||||
{t("authorizeTitle")}
|
{t("authorizeTitle")}
|
||||||
</Button>
|
</Button>
|
||||||
<Button
|
<Button
|
||||||
onClick={() => navigate(`/logout${compiledParams}`)}
|
onClick={() => navigate(`/logout${compiledParams}`)}
|
||||||
disabled={authorizePending || shouldAutoAuthorize}
|
disabled={authorizeMutation.isPending}
|
||||||
variant="outline"
|
variant="outline"
|
||||||
>
|
>
|
||||||
{t("cancelTitle")}
|
{t("cancelTitle")}
|
||||||
|
|||||||
@@ -63,10 +63,7 @@ export const LoginPage = () => {
|
|||||||
|
|
||||||
const searchParams = new URLSearchParams(search);
|
const searchParams = new URLSearchParams(search);
|
||||||
const screenParams = useScreenParams(searchParams);
|
const screenParams = useScreenParams(searchParams);
|
||||||
const compiledParams = recompileScreenParams({
|
const compiledParams = recompileScreenParams(screenParams);
|
||||||
...screenParams,
|
|
||||||
oidc_prompt: undefined,
|
|
||||||
});
|
|
||||||
const loginForUrl = useLoginFor({
|
const loginForUrl = useLoginFor({
|
||||||
login_for: screenParams.login_for,
|
login_for: screenParams.login_for,
|
||||||
compiledParams,
|
compiledParams,
|
||||||
@@ -199,7 +196,7 @@ export const LoginPage = () => {
|
|||||||
};
|
};
|
||||||
}, [redirectTimer, redirectButtonTimer]);
|
}, [redirectTimer, redirectButtonTimer]);
|
||||||
|
|
||||||
if (auth.authenticated && screenParams.oidc_prompt !== "login") {
|
if (auth.authenticated) {
|
||||||
return <Navigate to={loginForUrl} replace />;
|
return <Navigate to={loginForUrl} replace />;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -6,9 +6,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"slices"
|
"slices"
|
||||||
"strconv"
|
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/gin-gonic/gin/binding"
|
"github.com/gin-gonic/gin/binding"
|
||||||
@@ -71,11 +69,10 @@ type ClientCredentials struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type AuthorizeScreenParams struct {
|
type AuthorizeScreenParams struct {
|
||||||
LoginFor FrontendLoginFor `url:"login_for"`
|
LoginFor FrontendLoginFor `url:"login_for"`
|
||||||
OIDCTicket string `url:"oidc_ticket"`
|
OIDCTicket string `url:"oidc_ticket"`
|
||||||
OIDCScope string `url:"oidc_scope"`
|
OIDCScope string `url:"oidc_scope"`
|
||||||
OIDCName string `url:"oidc_name"`
|
OIDCName string `url:"oidc_name"`
|
||||||
OIDCPrompt service.OIDCPrompt `url:"oidc_prompt,omitempty"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type AuthorizeCompleteRequest struct {
|
type AuthorizeCompleteRequest struct {
|
||||||
@@ -170,87 +167,20 @@ func (controller *OIDCController) authorize(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
prompts := controller.oidc.GetPrompt(req.Prompt)
|
|
||||||
|
|
||||||
if slices.Contains(prompts, service.OIDCPromptNone) && len(prompts) > 1 {
|
|
||||||
controller.authorizeError(c, authorizeErrorParams{
|
|
||||||
err: errors.New("invalid prompt"),
|
|
||||||
reason: "Invalid prompt",
|
|
||||||
reasonPublic: "The prompt parameters are invalid",
|
|
||||||
callback: req.RedirectURI,
|
|
||||||
callbackError: "invalid_request",
|
|
||||||
state: req.State,
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
userContext, err := new(model.UserContext).NewFromGin(c)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
if !errors.Is(err, model.ErrUserContextNotFound) {
|
|
||||||
controller.log.App.Warn().Err(err).Msg("Failed to get user context")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (err != nil || !userContext.Authenticated) && slices.Contains(prompts, service.OIDCPromptNone) {
|
|
||||||
controller.authorizeError(c, authorizeErrorParams{
|
|
||||||
err: errors.New("user not logged in"),
|
|
||||||
reason: "User not logged in",
|
|
||||||
reasonPublic: "The user is not logged in",
|
|
||||||
callback: req.RedirectURI,
|
|
||||||
callbackError: "login_required",
|
|
||||||
state: req.State,
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
ticket := controller.oidc.CreateAuthorizeRequestTicket(*req)
|
ticket := controller.oidc.CreateAuthorizeRequestTicket(*req)
|
||||||
|
|
||||||
values := AuthorizeScreenParams{
|
queries, err := query.Values(AuthorizeScreenParams{
|
||||||
LoginFor: FrontendLoginForOIDC,
|
LoginFor: FrontendLoginForOIDC,
|
||||||
OIDCTicket: ticket,
|
OIDCTicket: ticket,
|
||||||
OIDCScope: req.Scope,
|
OIDCScope: req.Scope,
|
||||||
OIDCName: client.Name,
|
OIDCName: client.Name,
|
||||||
}
|
})
|
||||||
|
|
||||||
if slices.Contains(prompts, service.OIDCPromptLogin) {
|
|
||||||
values.OIDCPrompt = service.OIDCPromptLogin
|
|
||||||
} else if slices.Contains(prompts, service.OIDCPromptNone) {
|
|
||||||
values.OIDCPrompt = service.OIDCPromptNone
|
|
||||||
}
|
|
||||||
|
|
||||||
if req.MaxAge != "" && userContext != nil {
|
|
||||||
maxAge, err := strconv.Atoi(req.MaxAge)
|
|
||||||
if err != nil {
|
|
||||||
controller.authorizeError(c, authorizeErrorParams{
|
|
||||||
err: err,
|
|
||||||
reason: "Invalid max_age",
|
|
||||||
reasonPublic: "The max_age parameter is invalid",
|
|
||||||
callback: req.RedirectURI,
|
|
||||||
callbackError: "invalid_request",
|
|
||||||
state: req.State,
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if userContext.Authenticated {
|
|
||||||
authTime := time.Unix(userContext.AuthTime, 0)
|
|
||||||
if authTime.Add(time.Duration(maxAge) * time.Second).Before(time.Now()) {
|
|
||||||
values.OIDCPrompt = service.OIDCPromptLogin
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
queries, err := query.Values(values)
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.authorizeError(c, authorizeErrorParams{
|
controller.authorizeError(c, authorizeErrorParams{
|
||||||
err: err,
|
err: err,
|
||||||
reason: "Failed to compile authorize queries",
|
reason: "Failed to compile authorize queries",
|
||||||
reasonPublic: "An internal error occured while processing your request",
|
reasonPublic: "An internal error occured while processing your request",
|
||||||
callback: req.RedirectURI,
|
|
||||||
callbackError: "server_error",
|
|
||||||
state: req.State,
|
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -278,12 +208,16 @@ func (controller *OIDCController) authorizeComplete(c *gin.Context) {
|
|||||||
userContext, err := new(model.UserContext).NewFromGin(c)
|
userContext, err := new(model.UserContext).NewFromGin(c)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if !errors.Is(err, model.ErrUserContextNotFound) {
|
controller.authorizeError(c, authorizeErrorParams{
|
||||||
controller.log.App.Warn().Err(err).Msg("Failed to get user context")
|
err: err,
|
||||||
}
|
reason: "Failed to get user context",
|
||||||
|
reasonPublic: "User is not logged in or the session is invalid",
|
||||||
|
json: true,
|
||||||
|
})
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil || !userContext.Authenticated {
|
if !userContext.Authenticated {
|
||||||
controller.authorizeError(c, authorizeErrorParams{
|
controller.authorizeError(c, authorizeErrorParams{
|
||||||
err: errors.New("err user not logged in"),
|
err: errors.New("err user not logged in"),
|
||||||
reason: "User not logged in",
|
reason: "User not logged in",
|
||||||
@@ -491,7 +425,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
tokenRes, err := controller.oidc.GenerateAccessToken(c, client, *entry, entry.AuthTime)
|
tokenRes, err := controller.oidc.GenerateAccessToken(c, client, *entry)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Error().Err(err).Msg("Failed to generate access token")
|
controller.log.App.Error().Err(err).Msg("Failed to generate access token")
|
||||||
|
|||||||
@@ -209,26 +209,6 @@ func TestOIDCController(t *testing.T) {
|
|||||||
},
|
},
|
||||||
|
|
||||||
// --- authorize-complete ---
|
// --- authorize-complete ---
|
||||||
{
|
|
||||||
description: "Should fail if oidc is disabled",
|
|
||||||
oidcDisabled: true,
|
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
|
||||||
body, err := json.Marshal(AuthorizeCompleteRequest{Ticket: "some-ticket"})
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "/api/oidc/authorize-complete", strings.NewReader(string(body)))
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
router.ServeHTTP(recorder, req)
|
|
||||||
|
|
||||||
assert.Equal(t, http.StatusOK, recorder.Code)
|
|
||||||
|
|
||||||
var res map[string]any
|
|
||||||
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &res))
|
|
||||||
redirectURI, ok := res["redirect_uri"].(string)
|
|
||||||
require.True(t, ok)
|
|
||||||
assert.Contains(t, redirectURI, oidcService.GetIssuer()+"/error")
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
description: "Authorize complete returns a JSON error when the user context is missing",
|
description: "Authorize complete returns a JSON error when the user context is missing",
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
|
|||||||
@@ -158,7 +158,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Redirect(http.StatusFound, redirectURL)
|
c.Redirect(http.StatusTemporaryRedirect, redirectURL)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -207,7 +207,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Redirect(http.StatusFound, redirectURL)
|
c.Redirect(http.StatusTemporaryRedirect, redirectURL)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -251,7 +251,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Redirect(http.StatusFound, redirectURL)
|
c.Redirect(http.StatusTemporaryRedirect, redirectURL)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -300,7 +300,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Redirect(http.StatusFound, redirectURL)
|
c.Redirect(http.StatusTemporaryRedirect, redirectURL)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (controller *ProxyController) setHeaders(c *gin.Context, acls *model.App) {
|
func (controller *ProxyController) setHeaders(c *gin.Context, acls *model.App) {
|
||||||
@@ -336,7 +336,7 @@ func (controller *ProxyController) handleError(c *gin.Context, proxyCtx ProxyCon
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Redirect(http.StatusFound, redirectURL)
|
c.Redirect(http.StatusTemporaryRedirect, redirectURL)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (controller *ProxyController) getHeader(c *gin.Context, header string) (string, bool) {
|
func (controller *ProxyController) getHeader(c *gin.Context, header string) (string, bool) {
|
||||||
|
|||||||
@@ -2,9 +2,6 @@ package controller
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/base64"
|
|
||||||
"fmt"
|
|
||||||
"net/http"
|
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"net/url"
|
"net/url"
|
||||||
"testing"
|
"testing"
|
||||||
@@ -66,17 +63,6 @@ func TestProxyController(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
tests := []testCase{
|
tests := []testCase{
|
||||||
{
|
|
||||||
description: "Should get bad request on invalid proxy",
|
|
||||||
middlewares: []gin.HandlerFunc{},
|
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
|
||||||
req := httptest.NewRequest("GET", "/api/auth/invalid", nil)
|
|
||||||
router.ServeHTTP(recorder, req)
|
|
||||||
|
|
||||||
assert.Equal(t, http.StatusBadRequest, recorder.Code)
|
|
||||||
assert.Contains(t, recorder.Body.String(), "Bad request")
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
description: "Default forward auth should be detected and used for traefik",
|
description: "Default forward auth should be detected and used for traefik",
|
||||||
middlewares: []gin.HandlerFunc{},
|
middlewares: []gin.HandlerFunc{},
|
||||||
@@ -88,7 +74,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("user-agent", browserUserAgent)
|
req.Header.Set("user-agent", browserUserAgent)
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
assert.Equal(t, http.StatusFound, recorder.Code)
|
assert.Equal(t, 307, recorder.Code)
|
||||||
location := recorder.Header().Get("Location")
|
location := recorder.Header().Get("Location")
|
||||||
assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
|
assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
|
||||||
assert.Contains(t, location, "login_for=app")
|
assert.Contains(t, location, "login_for=app")
|
||||||
@@ -103,7 +89,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("x-original-url", "https://test.example.com/")
|
req.Header.Set("x-original-url", "https://test.example.com/")
|
||||||
req.Header.Set("user-agent", browserUserAgent)
|
req.Header.Set("user-agent", browserUserAgent)
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
assert.Equal(t, http.StatusUnauthorized, recorder.Code)
|
assert.Equal(t, 401, recorder.Code)
|
||||||
location := recorder.Header().Get("x-tinyauth-location")
|
location := recorder.Header().Get("x-tinyauth-location")
|
||||||
assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
|
assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
|
||||||
assert.Contains(t, location, "login_for=app")
|
assert.Contains(t, location, "login_for=app")
|
||||||
@@ -119,7 +105,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("x-forwarded-proto", "https")
|
req.Header.Set("x-forwarded-proto", "https")
|
||||||
req.Header.Set("user-agent", browserUserAgent)
|
req.Header.Set("user-agent", browserUserAgent)
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
assert.Equal(t, http.StatusFound, recorder.Code)
|
assert.Equal(t, 307, recorder.Code)
|
||||||
location := recorder.Header().Get("Location")
|
location := recorder.Header().Get("Location")
|
||||||
assert.Contains(t, location, url.QueryEscape("https://test.example.com/hello"))
|
assert.Contains(t, location, url.QueryEscape("https://test.example.com/hello"))
|
||||||
assert.Contains(t, location, "login_for=app")
|
assert.Contains(t, location, "login_for=app")
|
||||||
@@ -137,7 +123,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("user-agent", browserUserAgent)
|
req.Header.Set("user-agent", browserUserAgent)
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
assert.Equal(t, http.StatusFound, recorder.Code)
|
assert.Equal(t, 307, recorder.Code)
|
||||||
location := recorder.Header().Get("Location")
|
location := recorder.Header().Get("Location")
|
||||||
assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
|
assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
|
||||||
assert.Contains(t, location, "login_for=app")
|
assert.Contains(t, location, "login_for=app")
|
||||||
@@ -154,7 +140,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("x-forwarded-uri", "/")
|
req.Header.Set("x-forwarded-uri", "/")
|
||||||
req.Header.Set("user-agent", browserUserAgent)
|
req.Header.Set("user-agent", browserUserAgent)
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
assert.Equal(t, http.StatusUnauthorized, recorder.Code)
|
assert.Equal(t, 401, recorder.Code)
|
||||||
location := recorder.Header().Get("x-tinyauth-location")
|
location := recorder.Header().Get("x-tinyauth-location")
|
||||||
assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
|
assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
|
||||||
assert.Contains(t, location, "login_for=app")
|
assert.Contains(t, location, "login_for=app")
|
||||||
@@ -172,7 +158,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("x-forwarded-uri", "/hello")
|
req.Header.Set("x-forwarded-uri", "/hello")
|
||||||
req.Header.Set("user-agent", browserUserAgent)
|
req.Header.Set("user-agent", browserUserAgent)
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
assert.Equal(t, http.StatusFound, recorder.Code)
|
assert.Equal(t, 307, recorder.Code)
|
||||||
location := recorder.Header().Get("Location")
|
location := recorder.Header().Get("Location")
|
||||||
assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
|
assert.Contains(t, location, url.QueryEscape("https://test.example.com/"))
|
||||||
assert.Contains(t, location, "login_for=app")
|
assert.Contains(t, location, "login_for=app")
|
||||||
@@ -189,7 +175,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("x-forwarded-uri", "/")
|
req.Header.Set("x-forwarded-uri", "/")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
assert.Equal(t, http.StatusUnauthorized, recorder.Code)
|
assert.Equal(t, 401, recorder.Code)
|
||||||
assert.Contains(t, recorder.Body.String(), `"status":401`)
|
assert.Contains(t, recorder.Body.String(), `"status":401`)
|
||||||
assert.Contains(t, recorder.Body.String(), `"message":"Unauthorized"`)
|
assert.Contains(t, recorder.Body.String(), `"message":"Unauthorized"`)
|
||||||
},
|
},
|
||||||
@@ -204,7 +190,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("x-forwarded-uri", "/")
|
req.Header.Set("x-forwarded-uri", "/")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
assert.Equal(t, http.StatusUnauthorized, recorder.Code)
|
assert.Equal(t, 401, recorder.Code)
|
||||||
assert.Contains(t, recorder.Body.String(), `"status":401`)
|
assert.Contains(t, recorder.Body.String(), `"status":401`)
|
||||||
assert.Contains(t, recorder.Body.String(), `"message":"Unauthorized"`)
|
assert.Contains(t, recorder.Body.String(), `"message":"Unauthorized"`)
|
||||||
},
|
},
|
||||||
@@ -219,7 +205,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("x-forwarded-uri", "/hello")
|
req.Header.Set("x-forwarded-uri", "/hello")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
assert.Equal(t, http.StatusUnauthorized, recorder.Code)
|
assert.Equal(t, 401, recorder.Code)
|
||||||
assert.Contains(t, recorder.Body.String(), `"status":401`)
|
assert.Contains(t, recorder.Body.String(), `"status":401`)
|
||||||
assert.Contains(t, recorder.Body.String(), `"message":"Unauthorized"`)
|
assert.Contains(t, recorder.Body.String(), `"message":"Unauthorized"`)
|
||||||
},
|
},
|
||||||
@@ -236,7 +222,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("x-forwarded-uri", "/")
|
req.Header.Set("x-forwarded-uri", "/")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
assert.Equal(t, http.StatusOK, recorder.Code)
|
assert.Equal(t, 200, recorder.Code)
|
||||||
assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
|
assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
|
||||||
assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
|
assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
|
||||||
assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
|
assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
|
||||||
@@ -252,7 +238,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("x-original-url", "https://test.example.com/")
|
req.Header.Set("x-original-url", "https://test.example.com/")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
assert.Equal(t, http.StatusOK, recorder.Code)
|
assert.Equal(t, 200, recorder.Code)
|
||||||
assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
|
assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
|
||||||
assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
|
assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
|
||||||
assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
|
assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
|
||||||
@@ -269,7 +255,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("x-forwarded-proto", "https")
|
req.Header.Set("x-forwarded-proto", "https")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
assert.Equal(t, http.StatusOK, recorder.Code)
|
assert.Equal(t, 200, recorder.Code)
|
||||||
assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
|
assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
|
||||||
assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
|
assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
|
||||||
assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
|
assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
|
||||||
@@ -284,7 +270,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("x-forwarded-proto", "https")
|
req.Header.Set("x-forwarded-proto", "https")
|
||||||
req.Header.Set("x-forwarded-uri", "/allowed")
|
req.Header.Set("x-forwarded-uri", "/allowed")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
assert.Equal(t, http.StatusOK, recorder.Code)
|
assert.Equal(t, 200, recorder.Code)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -294,7 +280,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req := httptest.NewRequest("GET", "/api/auth/nginx", nil)
|
req := httptest.NewRequest("GET", "/api/auth/nginx", nil)
|
||||||
req.Header.Set("x-original-url", "https://path-allow.example.com/allowed")
|
req.Header.Set("x-original-url", "https://path-allow.example.com/allowed")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
assert.Equal(t, http.StatusOK, recorder.Code)
|
assert.Equal(t, 200, recorder.Code)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -305,7 +291,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Host = "path-allow.example.com"
|
req.Host = "path-allow.example.com"
|
||||||
req.Header.Set("x-forwarded-proto", "https")
|
req.Header.Set("x-forwarded-proto", "https")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
assert.Equal(t, http.StatusOK, recorder.Code)
|
assert.Equal(t, 200, recorder.Code)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -318,7 +304,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("x-forwarded-uri", "/")
|
req.Header.Set("x-forwarded-uri", "/")
|
||||||
req.Header.Set("x-forwarded-for", "10.10.10.10")
|
req.Header.Set("x-forwarded-for", "10.10.10.10")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
assert.Equal(t, http.StatusOK, recorder.Code)
|
assert.Equal(t, 200, recorder.Code)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -329,7 +315,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("x-original-url", "https://ip-bypass.example.com/")
|
req.Header.Set("x-original-url", "https://ip-bypass.example.com/")
|
||||||
req.Header.Set("x-forwarded-for", "10.10.10.10")
|
req.Header.Set("x-forwarded-for", "10.10.10.10")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
assert.Equal(t, http.StatusOK, recorder.Code)
|
assert.Equal(t, 200, recorder.Code)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -341,7 +327,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("x-forwarded-proto", "https")
|
req.Header.Set("x-forwarded-proto", "https")
|
||||||
req.Header.Set("x-forwarded-for", "10.10.10.10")
|
req.Header.Set("x-forwarded-for", "10.10.10.10")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
assert.Equal(t, http.StatusOK, recorder.Code)
|
assert.Equal(t, 200, recorder.Code)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -355,7 +341,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("x-forwarded-proto", "https")
|
req.Header.Set("x-forwarded-proto", "https")
|
||||||
req.Header.Set("x-forwarded-uri", "/")
|
req.Header.Set("x-forwarded-uri", "/")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
assert.Equal(t, http.StatusOK, recorder.Code)
|
assert.Equal(t, 200, recorder.Code)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -369,301 +355,12 @@ func TestProxyController(t *testing.T) {
|
|||||||
req.Header.Set("x-forwarded-proto", "https")
|
req.Header.Set("x-forwarded-proto", "https")
|
||||||
req.Header.Set("x-forwarded-uri", "/")
|
req.Header.Set("x-forwarded-uri", "/")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
assert.Equal(t, http.StatusForbidden, recorder.Code)
|
assert.Equal(t, 403, recorder.Code)
|
||||||
assert.Equal(t, "", recorder.Header().Get("remote-user"))
|
assert.Equal(t, "", recorder.Header().Get("remote-user"))
|
||||||
assert.Equal(t, "", recorder.Header().Get("remote-name"))
|
assert.Equal(t, "", recorder.Header().Get("remote-name"))
|
||||||
assert.Equal(t, "", recorder.Header().Get("remote-email"))
|
assert.Equal(t, "", recorder.Header().Get("remote-email"))
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
|
||||||
description: "Test IP block rule, with non browser user agent",
|
|
||||||
middlewares: []gin.HandlerFunc{},
|
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
|
||||||
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
|
|
||||||
req.Header.Set("x-forwarded-host", "ip-block.example.com")
|
|
||||||
req.Header.Set("x-forwarded-proto", "https")
|
|
||||||
req.Header.Set("x-forwarded-uri", "/")
|
|
||||||
req.Header.Set("x-forwarded-for", "10.10.10.10")
|
|
||||||
router.ServeHTTP(recorder, req)
|
|
||||||
assert.Equal(t, http.StatusForbidden, recorder.Code)
|
|
||||||
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), runtime.AppURL)
|
|
||||||
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "ip=10.10.10.10")
|
|
||||||
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "ip-block")
|
|
||||||
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Test IP block rule, with browser user agent",
|
|
||||||
middlewares: []gin.HandlerFunc{},
|
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
|
||||||
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
|
|
||||||
req.Header.Set("x-forwarded-host", "ip-block.example.com")
|
|
||||||
req.Header.Set("x-forwarded-proto", "https")
|
|
||||||
req.Header.Set("x-forwarded-uri", "/")
|
|
||||||
req.Header.Set("x-forwarded-for", "10.10.10.10")
|
|
||||||
req.Header.Set("user-agent", browserUserAgent)
|
|
||||||
router.ServeHTTP(recorder, req)
|
|
||||||
assert.Equal(t, http.StatusFound, recorder.Code)
|
|
||||||
location := recorder.Header().Get("Location")
|
|
||||||
assert.Contains(t, location, url.QueryEscape("10.10.10.10"))
|
|
||||||
assert.Contains(t, location, url.QueryEscape("ip-block"))
|
|
||||||
assert.Contains(t, location, runtime.AppURL)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "OAuth allowed group",
|
|
||||||
middlewares: []gin.HandlerFunc{
|
|
||||||
func(ctx *gin.Context) {
|
|
||||||
ctx.Set("context", &model.UserContext{
|
|
||||||
Authenticated: true,
|
|
||||||
Provider: model.ProviderOAuth,
|
|
||||||
OAuth: &model.OAuthContext{
|
|
||||||
BaseContext: model.BaseContext{
|
|
||||||
Username: "testuser",
|
|
||||||
Name: "Testuser",
|
|
||||||
Email: "testuser@example.com",
|
|
||||||
},
|
|
||||||
Groups: []string{"group1"},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
ctx.Next()
|
|
||||||
},
|
|
||||||
},
|
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
|
||||||
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
|
|
||||||
req.Header.Set("x-forwarded-host", "oauth-group.example.com")
|
|
||||||
req.Header.Set("x-forwarded-proto", "https")
|
|
||||||
req.Header.Set("x-forwarded-uri", "/")
|
|
||||||
router.ServeHTTP(recorder, req)
|
|
||||||
assert.Equal(t, http.StatusOK, recorder.Code)
|
|
||||||
assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
|
|
||||||
assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
|
|
||||||
assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
|
|
||||||
assert.Equal(t, "group1", recorder.Header().Get("remote-groups"))
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "OAuth not in required groups and non browser",
|
|
||||||
middlewares: []gin.HandlerFunc{
|
|
||||||
func(ctx *gin.Context) {
|
|
||||||
ctx.Set("context", &model.UserContext{
|
|
||||||
Authenticated: true,
|
|
||||||
Provider: model.ProviderOAuth,
|
|
||||||
OAuth: &model.OAuthContext{
|
|
||||||
BaseContext: model.BaseContext{
|
|
||||||
Username: "testuser",
|
|
||||||
Name: "Testuser",
|
|
||||||
Email: "testuser@example.com",
|
|
||||||
},
|
|
||||||
Groups: []string{"group3"},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
ctx.Next()
|
|
||||||
},
|
|
||||||
},
|
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
|
||||||
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
|
|
||||||
req.Header.Set("x-forwarded-host", "oauth-group.example.com")
|
|
||||||
req.Header.Set("x-forwarded-proto", "https")
|
|
||||||
req.Header.Set("x-forwarded-uri", "/")
|
|
||||||
router.ServeHTTP(recorder, req)
|
|
||||||
assert.Equal(t, http.StatusForbidden, recorder.Code)
|
|
||||||
assert.Equal(t, "", recorder.Header().Get("remote-user"))
|
|
||||||
assert.Equal(t, "", recorder.Header().Get("remote-name"))
|
|
||||||
assert.Equal(t, "", recorder.Header().Get("remote-email"))
|
|
||||||
assert.Equal(t, "", recorder.Header().Get("remote-groups"))
|
|
||||||
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), runtime.AppURL)
|
|
||||||
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "groupErr=true")
|
|
||||||
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "oauth-group")
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "OAuth not in required groups and browser",
|
|
||||||
middlewares: []gin.HandlerFunc{
|
|
||||||
func(ctx *gin.Context) {
|
|
||||||
ctx.Set("context", &model.UserContext{
|
|
||||||
Authenticated: true,
|
|
||||||
Provider: model.ProviderOAuth,
|
|
||||||
OAuth: &model.OAuthContext{
|
|
||||||
BaseContext: model.BaseContext{
|
|
||||||
Username: "testuser",
|
|
||||||
Name: "Testuser",
|
|
||||||
Email: "testuser@example.com",
|
|
||||||
},
|
|
||||||
Groups: []string{"group3"},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
ctx.Next()
|
|
||||||
},
|
|
||||||
},
|
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
|
||||||
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
|
|
||||||
req.Header.Set("x-forwarded-host", "oauth-group.example.com")
|
|
||||||
req.Header.Set("x-forwarded-proto", "https")
|
|
||||||
req.Header.Set("x-forwarded-uri", "/")
|
|
||||||
req.Header.Set("user-agent", browserUserAgent)
|
|
||||||
router.ServeHTTP(recorder, req)
|
|
||||||
assert.Equal(t, http.StatusFound, recorder.Code)
|
|
||||||
location := recorder.Header().Get("Location")
|
|
||||||
assert.Contains(t, location, "groupErr=true")
|
|
||||||
assert.Contains(t, location, "oauth-group")
|
|
||||||
assert.Contains(t, location, runtime.AppURL)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "LDAP allowed group",
|
|
||||||
middlewares: []gin.HandlerFunc{
|
|
||||||
func(ctx *gin.Context) {
|
|
||||||
ctx.Set("context", &model.UserContext{
|
|
||||||
Authenticated: true,
|
|
||||||
Provider: model.ProviderLDAP,
|
|
||||||
LDAP: &model.LDAPContext{
|
|
||||||
BaseContext: model.BaseContext{
|
|
||||||
Username: "testuser",
|
|
||||||
Name: "Testuser",
|
|
||||||
Email: "testuser@example.com",
|
|
||||||
},
|
|
||||||
Groups: []string{"group1"},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
ctx.Next()
|
|
||||||
},
|
|
||||||
},
|
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
|
||||||
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
|
|
||||||
req.Header.Set("x-forwarded-host", "ldap-group.example.com")
|
|
||||||
req.Header.Set("x-forwarded-proto", "https")
|
|
||||||
req.Header.Set("x-forwarded-uri", "/")
|
|
||||||
router.ServeHTTP(recorder, req)
|
|
||||||
assert.Equal(t, http.StatusOK, recorder.Code)
|
|
||||||
assert.Equal(t, "testuser", recorder.Header().Get("remote-user"))
|
|
||||||
assert.Equal(t, "Testuser", recorder.Header().Get("remote-name"))
|
|
||||||
assert.Equal(t, "testuser@example.com", recorder.Header().Get("remote-email"))
|
|
||||||
assert.Equal(t, "group1", recorder.Header().Get("remote-groups"))
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "LDAP not in required groups and non browser",
|
|
||||||
middlewares: []gin.HandlerFunc{
|
|
||||||
func(ctx *gin.Context) {
|
|
||||||
ctx.Set("context", &model.UserContext{
|
|
||||||
Authenticated: true,
|
|
||||||
Provider: model.ProviderLDAP,
|
|
||||||
LDAP: &model.LDAPContext{
|
|
||||||
BaseContext: model.BaseContext{
|
|
||||||
Username: "testuser",
|
|
||||||
Name: "Testuser",
|
|
||||||
Email: "testuser@example.com",
|
|
||||||
},
|
|
||||||
Groups: []string{"group3"},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
ctx.Next()
|
|
||||||
},
|
|
||||||
},
|
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
|
||||||
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
|
|
||||||
req.Header.Set("x-forwarded-host", "ldap-group.example.com")
|
|
||||||
req.Header.Set("x-forwarded-proto", "https")
|
|
||||||
req.Header.Set("x-forwarded-uri", "/")
|
|
||||||
router.ServeHTTP(recorder, req)
|
|
||||||
assert.Equal(t, http.StatusForbidden, recorder.Code)
|
|
||||||
assert.Equal(t, "", recorder.Header().Get("remote-user"))
|
|
||||||
assert.Equal(t, "", recorder.Header().Get("remote-name"))
|
|
||||||
assert.Equal(t, "", recorder.Header().Get("remote-email"))
|
|
||||||
assert.Equal(t, "", recorder.Header().Get("remote-groups"))
|
|
||||||
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), runtime.AppURL)
|
|
||||||
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "groupErr=true")
|
|
||||||
assert.Contains(t, recorder.Header().Get("x-tinyauth-location"), "ldap-group")
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "LDAP not in required groups and browser",
|
|
||||||
middlewares: []gin.HandlerFunc{
|
|
||||||
func(ctx *gin.Context) {
|
|
||||||
ctx.Set("context", &model.UserContext{
|
|
||||||
Authenticated: true,
|
|
||||||
Provider: model.ProviderLDAP,
|
|
||||||
LDAP: &model.LDAPContext{
|
|
||||||
BaseContext: model.BaseContext{
|
|
||||||
Username: "testuser",
|
|
||||||
Name: "Testuser",
|
|
||||||
Email: "testuser@example.com",
|
|
||||||
},
|
|
||||||
Groups: []string{"group3"},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
ctx.Next()
|
|
||||||
},
|
|
||||||
},
|
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
|
||||||
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
|
|
||||||
req.Header.Set("x-forwarded-host", "ldap-group.example.com")
|
|
||||||
req.Header.Set("x-forwarded-proto", "https")
|
|
||||||
req.Header.Set("x-forwarded-uri", "/")
|
|
||||||
req.Header.Set("user-agent", browserUserAgent)
|
|
||||||
router.ServeHTTP(recorder, req)
|
|
||||||
assert.Equal(t, http.StatusFound, recorder.Code)
|
|
||||||
location := recorder.Header().Get("Location")
|
|
||||||
assert.Contains(t, location, "groupErr=true")
|
|
||||||
assert.Contains(t, location, "ldap-group")
|
|
||||||
assert.Contains(t, location, runtime.AppURL)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Should add basic auth if it's in ACLs",
|
|
||||||
middlewares: []gin.HandlerFunc{
|
|
||||||
simpleCtx,
|
|
||||||
},
|
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
|
||||||
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
|
|
||||||
req.Header.Set("x-forwarded-host", "basic-auth.example.com")
|
|
||||||
req.Header.Set("x-forwarded-proto", "https")
|
|
||||||
req.Header.Set("x-forwarded-uri", "/")
|
|
||||||
req.Header.Set("authorization", "foo") // should be overridden by basic auth
|
|
||||||
router.ServeHTTP(recorder, req)
|
|
||||||
|
|
||||||
assert.Equal(t, http.StatusOK, recorder.Code)
|
|
||||||
authorizationHeader := recorder.Header().Get("Authorization")
|
|
||||||
assert.NotEmpty(t, authorizationHeader)
|
|
||||||
assert.Equal(t, fmt.Sprintf("Basic %s", base64.StdEncoding.EncodeToString([]byte("test:password"))), authorizationHeader)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Authorization header should be preserved when not basic auth acls",
|
|
||||||
middlewares: []gin.HandlerFunc{
|
|
||||||
simpleCtx,
|
|
||||||
},
|
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
|
||||||
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
|
|
||||||
req.Header.Set("x-forwarded-host", "test.example.com")
|
|
||||||
req.Header.Set("x-forwarded-proto", "https")
|
|
||||||
req.Header.Set("x-forwarded-uri", "/")
|
|
||||||
req.Header.Set("authorization", "Bearer mytoken")
|
|
||||||
router.ServeHTTP(recorder, req)
|
|
||||||
|
|
||||||
assert.Equal(t, http.StatusOK, recorder.Code)
|
|
||||||
authorizationHeader := recorder.Header().Get("Authorization")
|
|
||||||
assert.NotEmpty(t, authorizationHeader)
|
|
||||||
assert.Equal(t, "Bearer mytoken", authorizationHeader)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Should add response headers if present",
|
|
||||||
middlewares: []gin.HandlerFunc{
|
|
||||||
simpleCtx,
|
|
||||||
},
|
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
|
||||||
req := httptest.NewRequest("GET", "/api/auth/traefik", nil)
|
|
||||||
req.Header.Set("x-forwarded-host", "response-headers.example.com")
|
|
||||||
req.Header.Set("x-forwarded-proto", "https")
|
|
||||||
req.Header.Set("x-forwarded-uri", "/")
|
|
||||||
router.ServeHTTP(recorder, req)
|
|
||||||
|
|
||||||
assert.Equal(t, http.StatusOK, recorder.Code)
|
|
||||||
assert.Equal(t, "bar", recorder.Header().Get("x-foo"))
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
store := memory.New()
|
store := memory.New()
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ import (
|
|||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/test"
|
"github.com/tinyauthapp/tinyauth/internal/test"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -19,12 +18,8 @@ func TestResourcesController(t *testing.T) {
|
|||||||
err := os.MkdirAll(cfg.Resources.Path, 0777)
|
err := os.MkdirAll(cfg.Resources.Path, 0777)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// create a "backup" of the original configuration to restore after each test
|
|
||||||
originalCfg := cfg.Resources
|
|
||||||
|
|
||||||
type testCase struct {
|
type testCase struct {
|
||||||
description string
|
description string
|
||||||
customCfg *model.ResourcesConfig
|
|
||||||
run func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder)
|
run func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -57,32 +52,6 @@ func TestResourcesController(t *testing.T) {
|
|||||||
assert.Equal(t, 404, recorder.Code)
|
assert.Equal(t, 404, recorder.Code)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
|
||||||
description: "Ensure resources controller returns 404 when resources path is empty",
|
|
||||||
customCfg: &model.ResourcesConfig{
|
|
||||||
Path: "",
|
|
||||||
Enabled: true,
|
|
||||||
},
|
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
|
||||||
req := httptest.NewRequest("GET", "/resources/testfile.txt", nil)
|
|
||||||
router.ServeHTTP(recorder, req)
|
|
||||||
|
|
||||||
assert.Equal(t, 404, recorder.Code)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Ensure resources controller returns 403 when resources are disabled",
|
|
||||||
customCfg: &model.ResourcesConfig{
|
|
||||||
Path: cfg.Resources.Path,
|
|
||||||
Enabled: false,
|
|
||||||
},
|
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
|
||||||
req := httptest.NewRequest("GET", "/resources/testfile.txt", nil)
|
|
||||||
router.ServeHTTP(recorder, req)
|
|
||||||
|
|
||||||
assert.Equal(t, 403, recorder.Code)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
testFilePath := cfg.Resources.Path + "/testfile.txt"
|
testFilePath := cfg.Resources.Path + "/testfile.txt"
|
||||||
@@ -99,14 +68,6 @@ func TestResourcesController(t *testing.T) {
|
|||||||
group := router.Group("/")
|
group := router.Group("/")
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
// if custom configuration is provided, override the default config
|
|
||||||
if test.customCfg != nil {
|
|
||||||
cfg.Resources = *test.customCfg
|
|
||||||
} else {
|
|
||||||
// Reset to default configuration for each test
|
|
||||||
cfg.Resources = originalCfg
|
|
||||||
}
|
|
||||||
|
|
||||||
NewResourcesController(ResourcesControllerInput{
|
NewResourcesController(ResourcesControllerInput{
|
||||||
RouterGroup: group,
|
RouterGroup: group,
|
||||||
Config: &cfg,
|
Config: &cfg,
|
||||||
|
|||||||
@@ -41,7 +41,6 @@ func TestUserController(t *testing.T) {
|
|||||||
TOTPPending: true,
|
TOTPPending: true,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
c.Next()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
totpAttrCtx := func(c *gin.Context) {
|
totpAttrCtx := func(c *gin.Context) {
|
||||||
@@ -57,7 +56,6 @@ func TestUserController(t *testing.T) {
|
|||||||
TOTPPending: true,
|
TOTPPending: true,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
c.Next()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
simpleCtx := func(c *gin.Context) {
|
simpleCtx := func(c *gin.Context) {
|
||||||
@@ -72,7 +70,6 @@ func TestUserController(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
c.Next()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
store := memory.New()
|
store := memory.New()
|
||||||
@@ -84,40 +81,6 @@ func TestUserController(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
tests := []testCase{
|
tests := []testCase{
|
||||||
{
|
|
||||||
description: "Login should fail gracefully on invalid json",
|
|
||||||
middlewares: []gin.HandlerFunc{},
|
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
|
||||||
req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(`{"username": "testuser", "password":`))
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
|
|
||||||
router.ServeHTTP(recorder, req)
|
|
||||||
|
|
||||||
assert.Equal(t, 400, recorder.Code)
|
|
||||||
assert.Contains(t, recorder.Body.String(), "Bad Request")
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Should fail on missing user",
|
|
||||||
middlewares: []gin.HandlerFunc{},
|
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
|
||||||
loginReq := LoginRequest{
|
|
||||||
Username: "nonexistentuser",
|
|
||||||
Password: "password",
|
|
||||||
}
|
|
||||||
loginReqBody, err := json.Marshal(loginReq)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqBody)))
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
|
|
||||||
router.ServeHTTP(recorder, req)
|
|
||||||
|
|
||||||
assert.Equal(t, 401, recorder.Code)
|
|
||||||
assert.Len(t, recorder.Result().Cookies(), 0)
|
|
||||||
assert.Contains(t, recorder.Body.String(), "Unauthorized")
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
description: "Should be able to login with valid credentials",
|
description: "Should be able to login with valid credentials",
|
||||||
middlewares: []gin.HandlerFunc{},
|
middlewares: []gin.HandlerFunc{},
|
||||||
@@ -279,87 +242,6 @@ func TestUserController(t *testing.T) {
|
|||||||
assert.Equal(t, -1, cookie.MaxAge) // MaxAge -1 means delete cookie
|
assert.Equal(t, -1, cookie.MaxAge) // MaxAge -1 means delete cookie
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
|
||||||
description: "Logout should be treated as valid without a session cookie",
|
|
||||||
middlewares: []gin.HandlerFunc{},
|
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
|
||||||
req := httptest.NewRequest("POST", "/api/user/logout", nil)
|
|
||||||
router.ServeHTTP(recorder, req)
|
|
||||||
|
|
||||||
assert.Equal(t, 200, recorder.Code)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "TOTP should gracefully reject invalid json",
|
|
||||||
middlewares: []gin.HandlerFunc{},
|
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
|
||||||
req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(`{"code":`))
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
|
|
||||||
router.ServeHTTP(recorder, req)
|
|
||||||
|
|
||||||
assert.Equal(t, 400, recorder.Code)
|
|
||||||
assert.Contains(t, recorder.Body.String(), "Bad Request")
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "TOTP should fail on non-totp context",
|
|
||||||
middlewares: []gin.HandlerFunc{
|
|
||||||
simpleCtx,
|
|
||||||
},
|
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
|
||||||
totpReq := TotpRequest{
|
|
||||||
Code: "123456",
|
|
||||||
}
|
|
||||||
|
|
||||||
totpReqBody, err := json.Marshal(totpReq)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
recorder = httptest.NewRecorder()
|
|
||||||
req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(totpReqBody)))
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
router.ServeHTTP(recorder, req)
|
|
||||||
|
|
||||||
assert.Equal(t, 401, recorder.Code)
|
|
||||||
assert.Contains(t, recorder.Body.String(), "Unauthorized")
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "TOTP should fail when user in context doesn't exist",
|
|
||||||
middlewares: []gin.HandlerFunc{
|
|
||||||
func(ctx *gin.Context) {
|
|
||||||
ctx.Set("context", &model.UserContext{
|
|
||||||
Authenticated: false,
|
|
||||||
Provider: model.ProviderLocal,
|
|
||||||
Local: &model.LocalContext{
|
|
||||||
BaseContext: model.BaseContext{
|
|
||||||
Username: "idontexist",
|
|
||||||
Name: "Totpuser",
|
|
||||||
Email: "totpuser@example.com",
|
|
||||||
},
|
|
||||||
TOTPPending: true,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
ctx.Next()
|
|
||||||
},
|
|
||||||
},
|
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
|
||||||
totpReq := TotpRequest{
|
|
||||||
Code: "123456",
|
|
||||||
}
|
|
||||||
|
|
||||||
totpReqBody, err := json.Marshal(totpReq)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
recorder = httptest.NewRecorder()
|
|
||||||
req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(totpReqBody)))
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
router.ServeHTTP(recorder, req)
|
|
||||||
|
|
||||||
assert.Equal(t, 401, recorder.Code)
|
|
||||||
assert.Contains(t, recorder.Body.String(), "Unauthorized")
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
description: "Should be able to login with totp",
|
description: "Should be able to login with totp",
|
||||||
middlewares: []gin.HandlerFunc{
|
middlewares: []gin.HandlerFunc{
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"net/url"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@@ -26,14 +25,12 @@ func TestWellKnownController(t *testing.T) {
|
|||||||
|
|
||||||
type testCase struct {
|
type testCase struct {
|
||||||
description string
|
description string
|
||||||
oidcEnabled bool
|
|
||||||
run func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder)
|
run func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder)
|
||||||
}
|
}
|
||||||
|
|
||||||
tests := []testCase{
|
tests := []testCase{
|
||||||
{
|
{
|
||||||
description: "Ensure well-known endpoint returns correct OIDC configuration",
|
description: "Ensure well-known endpoint returns correct OIDC configuration",
|
||||||
oidcEnabled: true,
|
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
req := httptest.NewRequest("GET", "/.well-known/openid-configuration", nil)
|
req := httptest.NewRequest("GET", "/.well-known/openid-configuration", nil)
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
@@ -42,7 +39,7 @@ func TestWellKnownController(t *testing.T) {
|
|||||||
|
|
||||||
res := OpenIDConnectConfiguration{}
|
res := OpenIDConnectConfiguration{}
|
||||||
err := json.Unmarshal(recorder.Body.Bytes(), &res)
|
err := json.Unmarshal(recorder.Body.Bytes(), &res)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
expected := OpenIDConnectConfiguration{
|
expected := OpenIDConnectConfiguration{
|
||||||
Issuer: runtime.AppURL,
|
Issuer: runtime.AppURL,
|
||||||
@@ -58,8 +55,8 @@ func TestWellKnownController(t *testing.T) {
|
|||||||
TokenEndpointAuthMethodsSupported: []string{"client_secret_basic", "client_secret_post"},
|
TokenEndpointAuthMethodsSupported: []string{"client_secret_basic", "client_secret_post"},
|
||||||
ClaimsSupported: []string{"sub", "updated_at", "name", "preferred_username", "email", "email_verified", "groups", "phone_number", "phone_number_verified", "address", "given_name", "family_name", "middle_name", "nickname", "profile", "picture", "website", "gender", "birthdate", "zoneinfo", "locale"},
|
ClaimsSupported: []string{"sub", "updated_at", "name", "preferred_username", "email", "email_verified", "groups", "phone_number", "phone_number_verified", "address", "given_name", "family_name", "middle_name", "nickname", "profile", "picture", "website", "gender", "birthdate", "zoneinfo", "locale"},
|
||||||
ServiceDocumentation: "https://tinyauth.app/docs/guides/oidc",
|
ServiceDocumentation: "https://tinyauth.app/docs/guides/oidc",
|
||||||
RequestObjectSigningAlgValuesSupported: []string{"none"},
|
|
||||||
RequestParameterSupported: true,
|
RequestParameterSupported: true,
|
||||||
|
RequestObjectSigningAlgValuesSupported: []string{"none"},
|
||||||
}
|
}
|
||||||
|
|
||||||
assert.Equal(t, expected, res)
|
assert.Equal(t, expected, res)
|
||||||
@@ -67,7 +64,6 @@ func TestWellKnownController(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "Ensure well-known endpoint returns correct JWKS",
|
description: "Ensure well-known endpoint returns correct JWKS",
|
||||||
oidcEnabled: true,
|
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
||||||
req := httptest.NewRequest("GET", "/.well-known/jwks.json", nil)
|
req := httptest.NewRequest("GET", "/.well-known/jwks.json", nil)
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
@@ -76,204 +72,19 @@ func TestWellKnownController(t *testing.T) {
|
|||||||
|
|
||||||
decodedBody := make(map[string]any)
|
decodedBody := make(map[string]any)
|
||||||
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
|
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
|
||||||
require.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
keys, ok := decodedBody["keys"].([]any)
|
keys, ok := decodedBody["keys"].([]any)
|
||||||
require.True(t, ok)
|
assert.True(t, ok)
|
||||||
assert.Len(t, keys, 1)
|
assert.Len(t, keys, 1)
|
||||||
|
|
||||||
keyData, ok := keys[0].(map[string]any)
|
keyData, ok := keys[0].(map[string]any)
|
||||||
require.True(t, ok)
|
assert.True(t, ok)
|
||||||
assert.Equal(t, "RSA", keyData["kty"])
|
assert.Equal(t, "RSA", keyData["kty"])
|
||||||
assert.Equal(t, "sig", keyData["use"])
|
assert.Equal(t, "sig", keyData["use"])
|
||||||
assert.Equal(t, "RS256", keyData["alg"])
|
assert.Equal(t, "RS256", keyData["alg"])
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
|
||||||
description: "Ensure openid configuration returns 500 on nil oidc service",
|
|
||||||
oidcEnabled: false,
|
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
|
||||||
req := httptest.NewRequest("GET", "/.well-known/openid-configuration", nil)
|
|
||||||
router.ServeHTTP(recorder, req)
|
|
||||||
|
|
||||||
assert.Equal(t, 500, recorder.Code)
|
|
||||||
|
|
||||||
decodedBody := make(map[string]any)
|
|
||||||
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
assert.Equal(t, "OIDC service not configured", decodedBody["message"])
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Ensure jwks endpoint returns 500 on nil oidc service",
|
|
||||||
oidcEnabled: false,
|
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
|
||||||
req := httptest.NewRequest("GET", "/.well-known/jwks.json", nil)
|
|
||||||
router.ServeHTTP(recorder, req)
|
|
||||||
|
|
||||||
assert.Equal(t, 500, recorder.Code)
|
|
||||||
|
|
||||||
decodedBody := make(map[string]any)
|
|
||||||
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
assert.Equal(t, "OIDC service not configured", decodedBody["message"])
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Ensure webfinger returns 400 on invalid resource",
|
|
||||||
oidcEnabled: true,
|
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
|
||||||
req := httptest.NewRequest("GET", "/.well-known/webfinger?resource=invalid-resource", nil)
|
|
||||||
router.ServeHTTP(recorder, req)
|
|
||||||
|
|
||||||
assert.Equal(t, 400, recorder.Code)
|
|
||||||
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
|
|
||||||
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
|
|
||||||
|
|
||||||
decodedBody := make(map[string]any)
|
|
||||||
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
assert.Equal(t, "invalid resource", decodedBody["message"])
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Ensure webfinger resource validator allows acct",
|
|
||||||
oidcEnabled: true,
|
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
|
||||||
resource := "acct:testuser@example.com"
|
|
||||||
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s", url.QueryEscape(resource)), nil)
|
|
||||||
router.ServeHTTP(recorder, req)
|
|
||||||
|
|
||||||
assert.Equal(t, 200, recorder.Code)
|
|
||||||
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
|
|
||||||
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Ensure webfinger resource validator allows https",
|
|
||||||
oidcEnabled: true,
|
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
|
||||||
resource := "https://example.com/testuser"
|
|
||||||
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s", url.QueryEscape(resource)), nil)
|
|
||||||
router.ServeHTTP(recorder, req)
|
|
||||||
|
|
||||||
assert.Equal(t, 200, recorder.Code)
|
|
||||||
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
|
|
||||||
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Ensure webfinger resource validator allows http",
|
|
||||||
oidcEnabled: true,
|
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
|
||||||
resource := "http://example.com/testuser"
|
|
||||||
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s", url.QueryEscape(resource)), nil)
|
|
||||||
router.ServeHTTP(recorder, req)
|
|
||||||
|
|
||||||
assert.Equal(t, 200, recorder.Code)
|
|
||||||
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
|
|
||||||
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Webfinger should return no links when oidc is nil",
|
|
||||||
oidcEnabled: false,
|
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
|
||||||
resource := "acct:testuser@example.com"
|
|
||||||
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s", url.QueryEscape(resource)), nil)
|
|
||||||
router.ServeHTTP(recorder, req)
|
|
||||||
|
|
||||||
assert.Equal(t, 200, recorder.Code)
|
|
||||||
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
|
|
||||||
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
|
|
||||||
|
|
||||||
decodedBody := make(map[string]any)
|
|
||||||
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
links, ok := decodedBody["links"].([]any)
|
|
||||||
require.True(t, ok)
|
|
||||||
assert.Len(t, links, 0)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Webfinger should return links when oidc is configured and no rel is provided",
|
|
||||||
oidcEnabled: true,
|
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
|
||||||
resource := "acct:testuser@example.com"
|
|
||||||
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s", url.QueryEscape(resource)), nil)
|
|
||||||
router.ServeHTTP(recorder, req)
|
|
||||||
|
|
||||||
assert.Equal(t, 200, recorder.Code)
|
|
||||||
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
|
|
||||||
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
|
|
||||||
|
|
||||||
decodedBody := make(map[string]any)
|
|
||||||
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
links, ok := decodedBody["links"].([]any)
|
|
||||||
require.True(t, ok)
|
|
||||||
assert.Len(t, links, 1)
|
|
||||||
|
|
||||||
linkData, ok := links[0].(map[string]any)
|
|
||||||
require.True(t, ok)
|
|
||||||
assert.Equal(t, "http://openid.net/specs/connect/1.0/issuer", linkData["rel"])
|
|
||||||
assert.Equal(t, runtime.AppURL, linkData["href"])
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Webfinger should return links when oidc is configured and rel is provided",
|
|
||||||
oidcEnabled: true,
|
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
|
||||||
resource := fmt.Sprintf("acct:%s@%s", "testuser", runtime.AppURL)
|
|
||||||
rel := "http://openid.net/specs/connect/1.0/issuer"
|
|
||||||
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s&rel=%s", url.QueryEscape(resource), url.QueryEscape(rel)), nil)
|
|
||||||
router.ServeHTTP(recorder, req)
|
|
||||||
|
|
||||||
assert.Equal(t, 200, recorder.Code)
|
|
||||||
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
|
|
||||||
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
|
|
||||||
|
|
||||||
decodedBody := make(map[string]any)
|
|
||||||
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
links, ok := decodedBody["links"].([]any)
|
|
||||||
require.True(t, ok)
|
|
||||||
assert.Len(t, links, 1)
|
|
||||||
|
|
||||||
linkData, ok := links[0].(map[string]any)
|
|
||||||
require.True(t, ok)
|
|
||||||
assert.Equal(t, rel, linkData["rel"])
|
|
||||||
assert.Equal(t, runtime.AppURL, linkData["href"])
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "Webfinger should return no links when oidc is configured and rel is provided but does not match",
|
|
||||||
oidcEnabled: true,
|
|
||||||
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
|
|
||||||
resource := "acct:testuser@example.com"
|
|
||||||
rel := "http://example.com/does-not-exist"
|
|
||||||
req := httptest.NewRequest("GET", fmt.Sprintf("/.well-known/webfinger?resource=%s&rel=%s", url.QueryEscape(resource), url.QueryEscape(rel)), nil)
|
|
||||||
router.ServeHTTP(recorder, req)
|
|
||||||
|
|
||||||
assert.Equal(t, 200, recorder.Code)
|
|
||||||
assert.Equal(t, "application/jrd+json", recorder.Header().Get("content-type"))
|
|
||||||
assert.Equal(t, "*", recorder.Header().Get("access-control-allow-origin"))
|
|
||||||
|
|
||||||
decodedBody := make(map[string]any)
|
|
||||||
err := json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
links, ok := decodedBody["links"].([]any)
|
|
||||||
require.True(t, ok)
|
|
||||||
assert.Len(t, links, 0)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx := context.TODO()
|
ctx := context.TODO()
|
||||||
@@ -297,15 +108,10 @@ func TestWellKnownController(t *testing.T) {
|
|||||||
|
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
|
|
||||||
wellKnownControllerInput := WellKnownControllerInput{
|
NewWellKnownController(WellKnownControllerInput{
|
||||||
|
OIDCService: oidcService,
|
||||||
RouterGroup: &router.RouterGroup,
|
RouterGroup: &router.RouterGroup,
|
||||||
}
|
})
|
||||||
|
|
||||||
if test.oidcEnabled {
|
|
||||||
wellKnownControllerInput.OIDCService = oidcService
|
|
||||||
}
|
|
||||||
|
|
||||||
NewWellKnownController(wellKnownControllerInput)
|
|
||||||
|
|
||||||
test.run(t, router, recorder)
|
test.run(t, router, recorder)
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -25,7 +25,6 @@ const (
|
|||||||
type UserContext struct {
|
type UserContext struct {
|
||||||
Authenticated bool
|
Authenticated bool
|
||||||
Provider ProviderType
|
Provider ProviderType
|
||||||
AuthTime int64
|
|
||||||
Local *LocalContext
|
Local *LocalContext
|
||||||
OAuth *OAuthContext
|
OAuth *OAuthContext
|
||||||
LDAP *LDAPContext
|
LDAP *LDAPContext
|
||||||
@@ -111,7 +110,6 @@ func (c *UserContext) NewFromGin(ginctx *gin.Context) (*UserContext, error) {
|
|||||||
func (c *UserContext) NewFromSession(session *repository.Session) (*UserContext, error) {
|
func (c *UserContext) NewFromSession(session *repository.Session) (*UserContext, error) {
|
||||||
*c = UserContext{
|
*c = UserContext{
|
||||||
Authenticated: !session.TotpPending,
|
Authenticated: !session.TotpPending,
|
||||||
AuthTime: session.CreatedAt,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
switch session.Provider {
|
switch session.Provider {
|
||||||
|
|||||||
@@ -44,15 +44,6 @@ var (
|
|||||||
ErrInvalidClient = errors.New("invalid_client")
|
ErrInvalidClient = errors.New("invalid_client")
|
||||||
)
|
)
|
||||||
|
|
||||||
type OIDCPrompt string
|
|
||||||
|
|
||||||
const (
|
|
||||||
OIDCPromptLogin OIDCPrompt = "login"
|
|
||||||
OIDCPromptNone OIDCPrompt = "none"
|
|
||||||
)
|
|
||||||
|
|
||||||
var SupportedPrompts = []string{string(OIDCPromptLogin), string(OIDCPromptNone)}
|
|
||||||
|
|
||||||
// This is not spec-compliant, the ID token SHOULD NOT contain user info claims but,
|
// This is not spec-compliant, the ID token SHOULD NOT contain user info claims but,
|
||||||
// it has became a "standard" and apps are looking for the claims in the ID tokens
|
// it has became a "standard" and apps are looking for the claims in the ID tokens
|
||||||
// instead of calling the userinfo endpoint, so we include them in the ID token as well
|
// instead of calling the userinfo endpoint, so we include them in the ID token as well
|
||||||
@@ -63,7 +54,6 @@ type ClaimSet struct {
|
|||||||
Sub string `json:"sub"`
|
Sub string `json:"sub"`
|
||||||
Iat int64 `json:"iat"`
|
Iat int64 `json:"iat"`
|
||||||
Exp int64 `json:"exp"`
|
Exp int64 `json:"exp"`
|
||||||
AuthTime int64 `json:"auth_time,omitempty"`
|
|
||||||
Name string `json:"name,omitempty"`
|
Name string `json:"name,omitempty"`
|
||||||
GivenName string `json:"given_name,omitempty"`
|
GivenName string `json:"given_name,omitempty"`
|
||||||
FamilyName string `json:"family_name,omitempty"`
|
FamilyName string `json:"family_name,omitempty"`
|
||||||
@@ -127,8 +117,6 @@ type AuthorizeRequest struct {
|
|||||||
Nonce string `form:"nonce" json:"nonce" url:"nonce"`
|
Nonce string `form:"nonce" json:"nonce" url:"nonce"`
|
||||||
CodeChallenge string `form:"code_challenge" json:"code_challenge" url:"code_challenge"`
|
CodeChallenge string `form:"code_challenge" json:"code_challenge" url:"code_challenge"`
|
||||||
CodeChallengeMethod string `form:"code_challenge_method" json:"code_challenge_method" url:"code_challenge_method"`
|
CodeChallengeMethod string `form:"code_challenge_method" json:"code_challenge_method" url:"code_challenge_method"`
|
||||||
Prompt string `form:"prompt" json:"prompt" url:"prompt"`
|
|
||||||
MaxAge string `form:"max_age" json:"max_age" url:"max_age"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type AuthorizeCodeEntry struct {
|
type AuthorizeCodeEntry struct {
|
||||||
@@ -139,7 +127,6 @@ type AuthorizeCodeEntry struct {
|
|||||||
Nonce string
|
Nonce string
|
||||||
CodeChallenge string
|
CodeChallenge string
|
||||||
Userinfo UserinfoResponse
|
Userinfo UserinfoResponse
|
||||||
AuthTime int64
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type UsedCodeEntry struct {
|
type UsedCodeEntry struct {
|
||||||
@@ -436,7 +423,6 @@ func (service *OIDCService) CreateCode(req AuthorizeRequest, userContext model.U
|
|||||||
ClientID: req.ClientID,
|
ClientID: req.ClientID,
|
||||||
Nonce: req.Nonce,
|
Nonce: req.Nonce,
|
||||||
Userinfo: service.userinfoFromContext(userContext, sub),
|
Userinfo: service.userinfoFromContext(userContext, sub),
|
||||||
AuthTime: userContext.AuthTime,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if req.CodeChallenge != "" {
|
if req.CodeChallenge != "" {
|
||||||
@@ -526,7 +512,7 @@ func (service *OIDCService) GetCodeEntry(codeHash string, clientId string) (*Aut
|
|||||||
return &entry, true
|
return &entry, true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user UserinfoResponse, scope string, nonce string, auth_time int64) (string, error) {
|
func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user UserinfoResponse, scope string, nonce string) (string, error) {
|
||||||
createdAt := time.Now().Unix()
|
createdAt := time.Now().Unix()
|
||||||
expiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry) * time.Second).Unix()
|
expiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry) * time.Second).Unix()
|
||||||
|
|
||||||
@@ -563,7 +549,6 @@ func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user
|
|||||||
Sub: user.Sub,
|
Sub: user.Sub,
|
||||||
Iat: createdAt,
|
Iat: createdAt,
|
||||||
Exp: expiresAt,
|
Exp: expiresAt,
|
||||||
AuthTime: auth_time,
|
|
||||||
Name: userInfo.Name,
|
Name: userInfo.Name,
|
||||||
Email: userInfo.Email,
|
Email: userInfo.Email,
|
||||||
EmailVerified: userInfo.EmailVerified,
|
EmailVerified: userInfo.EmailVerified,
|
||||||
@@ -593,8 +578,8 @@ func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user
|
|||||||
return token, nil
|
return token, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (service *OIDCService) GenerateAccessToken(ctx context.Context, client model.OIDCClientConfig, codeEntry AuthorizeCodeEntry, authTime int64) (*TokenResponse, error) {
|
func (service *OIDCService) GenerateAccessToken(ctx context.Context, client model.OIDCClientConfig, codeEntry AuthorizeCodeEntry) (*TokenResponse, error) {
|
||||||
idToken, err := service.generateIDToken(client, codeEntry.Userinfo, codeEntry.Scope, codeEntry.Nonce, authTime)
|
idToken, err := service.generateIDToken(client, codeEntry.Userinfo, codeEntry.Scope, codeEntry.Nonce)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -675,7 +660,7 @@ func (service *OIDCService) RefreshAccessToken(ctx context.Context, refreshToken
|
|||||||
|
|
||||||
idToken, err := service.generateIDToken(model.OIDCClientConfig{
|
idToken, err := service.generateIDToken(model.OIDCClientConfig{
|
||||||
ClientID: entry.ClientID,
|
ClientID: entry.ClientID,
|
||||||
}, userInfo, entry.Scope, entry.Nonce, 0) // auth_time is not available during refresh, so we set it to 0
|
}, userInfo, entry.Scope, entry.Nonce)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -944,24 +929,5 @@ func (service *OIDCService) DecodeAuthorizeJWT(tokenString string) (*AuthorizeRe
|
|||||||
Nonce: get("nonce"),
|
Nonce: get("nonce"),
|
||||||
CodeChallenge: get("code_challenge"),
|
CodeChallenge: get("code_challenge"),
|
||||||
CodeChallengeMethod: get("code_challenge_method"),
|
CodeChallengeMethod: get("code_challenge_method"),
|
||||||
Prompt: get("prompt"),
|
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (service *OIDCService) GetPrompt(prompt string) []OIDCPrompt {
|
|
||||||
if prompt == "" {
|
|
||||||
return []OIDCPrompt{}
|
|
||||||
}
|
|
||||||
|
|
||||||
parsedPromps := make([]OIDCPrompt, 0)
|
|
||||||
prompts := strings.SplitSeq(prompt, " ")
|
|
||||||
|
|
||||||
for p := range prompts {
|
|
||||||
if !slices.Contains(SupportedPrompts, p) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
parsedPromps = append(parsedPromps, OIDCPrompt(p))
|
|
||||||
}
|
|
||||||
|
|
||||||
return parsedPromps
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -76,50 +76,6 @@ func CreateTestConfigs(t *testing.T) (model.Config, model.RuntimeConfig) {
|
|||||||
Bypass: []string{"10.10.10.10"},
|
Bypass: []string{"10.10.10.10"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"ip_block": {
|
|
||||||
Config: model.AppConfig{
|
|
||||||
Domain: "ip-block.example.com",
|
|
||||||
},
|
|
||||||
IP: model.AppIP{
|
|
||||||
Block: []string{"10.10.10.10"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"oauth_group": {
|
|
||||||
Config: model.AppConfig{
|
|
||||||
Domain: "oauth-group.example.com",
|
|
||||||
},
|
|
||||||
OAuth: model.AppOAuth{
|
|
||||||
Whitelist: "testuser@example.com",
|
|
||||||
Groups: "group1,group2",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"ldap_group": {
|
|
||||||
Config: model.AppConfig{
|
|
||||||
Domain: "ldap-group.example.com",
|
|
||||||
},
|
|
||||||
LDAP: model.AppLDAP{
|
|
||||||
Groups: "group1,group2",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"basic_auth": {
|
|
||||||
Config: model.AppConfig{
|
|
||||||
Domain: "basic-auth.example.com",
|
|
||||||
},
|
|
||||||
Response: model.AppResponse{
|
|
||||||
BasicAuth: model.AppBasicAuth{
|
|
||||||
Username: "test",
|
|
||||||
Password: "password",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"response_headers": {
|
|
||||||
Config: model.AppConfig{
|
|
||||||
Domain: "response-headers.example.com",
|
|
||||||
},
|
|
||||||
Response: model.AppResponse{
|
|
||||||
Headers: []string{"x-foo=bar"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user