Compare commits

..

7 Commits

Author SHA1 Message Date
Stavros
def9e5aaaa chore: fix typo 2026-04-07 19:00:09 +03:00
Stavros
6ffb52a5cd tests: add test for invalid challenge method 2026-04-07 18:48:38 +03:00
Stavros
668348655f chore: remove simple logger from testing 2026-04-07 18:30:05 +03:00
Stavros
482b3c6b57 chore: remove debug line 2026-04-07 18:28:28 +03:00
Stavros
e451b3d62f fix: review comments 2026-04-07 18:27:45 +03:00
Stavros
5bada13919 tests: add test cases for pkce 2026-04-07 17:49:42 +03:00
Stavros
9a6676b054 feat: add pkce support to oidc server 2026-04-06 23:09:14 +03:00
7 changed files with 113 additions and 232 deletions

View File

@@ -1,40 +1,64 @@
import { z } from "zod"; export type OIDCValues = {
scope: string;
export const oidcParamsSchema = z.object({ response_type: string;
scope: z.string().nonempty(), client_id: string;
response_type: z.string().nonempty(), redirect_uri: string;
client_id: z.string().nonempty(), state: string;
redirect_uri: z.string().nonempty(), nonce: string;
state: z.string().optional(), code_challenge: string;
nonce: z.string().optional(), code_challenge_method: string;
code_challenge: z.string().optional(),
code_challenge_method: z.string().optional(),
});
export const useOIDCParams = (
params: URLSearchParams,
): {
values: z.infer<typeof oidcParamsSchema>;
issues: string[];
isOidc: boolean;
compiled: string;
} => {
const obj = Object.fromEntries(params.entries());
const parsed = oidcParamsSchema.safeParse(obj);
if (parsed.success) {
return {
values: parsed.data,
issues: [],
isOidc: true,
compiled: new URLSearchParams(parsed.data).toString(),
}; };
interface IuseOIDCParams {
values: OIDCValues;
compiled: string;
isOidc: boolean;
missingParams: string[];
}
const optionalParams: string[] = [
"state",
"nonce",
"code_challenge",
"code_challenge_method",
];
export function useOIDCParams(params: URLSearchParams): IuseOIDCParams {
let compiled: string = "";
let isOidc = false;
const missingParams: string[] = [];
const values: OIDCValues = {
scope: params.get("scope") ?? "",
response_type: params.get("response_type") ?? "",
client_id: params.get("client_id") ?? "",
redirect_uri: params.get("redirect_uri") ?? "",
state: params.get("state") ?? "",
nonce: params.get("nonce") ?? "",
code_challenge: params.get("code_challenge") ?? "",
code_challenge_method: params.get("code_challenge_method") ?? "",
};
for (const key of Object.keys(values)) {
if (!values[key as keyof OIDCValues]) {
if (!optionalParams.includes(key)) {
missingParams.push(key);
}
}
}
if (missingParams.length === 0) {
isOidc = true;
}
if (isOidc) {
compiled = new URLSearchParams(values).toString();
} }
return { return {
issues: parsed.error.issues.map((issue) => issue.path.toString()), values,
values: {} as z.infer<typeof oidcParamsSchema>, compiled,
isOidc: false, isOidc,
compiled: "", missingParams,
};
}; };
}

View File

@@ -72,27 +72,36 @@ export const AuthorizePage = () => {
const scopeMap = createScopeMap(t); const scopeMap = createScopeMap(t);
const searchParams = new URLSearchParams(search); const searchParams = new URLSearchParams(search);
const oidcParams = useOIDCParams(searchParams); const {
values: props,
missingParams,
isOidc,
compiled: compiledOIDCParams,
} = useOIDCParams(searchParams);
const scopes = props.scope ? props.scope.split(" ").filter(Boolean) : [];
const getClientInfo = useQuery({ const getClientInfo = useQuery({
queryKey: ["client", oidcParams.values.client_id], queryKey: ["client", props.client_id],
queryFn: async () => { queryFn: async () => {
const res = await fetch( const res = await fetch(`/api/oidc/clients/${props.client_id}`);
`/api/oidc/clients/${encodeURIComponent(oidcParams.values.client_id)}`,
);
const data = await getOidcClientInfoSchema.parseAsync(await res.json()); const data = await getOidcClientInfoSchema.parseAsync(await res.json());
return data; return data;
}, },
enabled: oidcParams.isOidc, enabled: isOidc,
}); });
const authorizeMutation = useMutation({ const authorizeMutation = useMutation({
mutationFn: () => { mutationFn: () => {
return axios.post("/api/oidc/authorize", { return axios.post("/api/oidc/authorize", {
...oidcParams.values, scope: props.scope,
response_type: props.response_type,
client_id: props.client_id,
redirect_uri: props.redirect_uri,
state: props.state,
nonce: props.nonce,
}); });
}, },
mutationKey: ["authorize", oidcParams.values.client_id], mutationKey: ["authorize", props.client_id],
onSuccess: (data) => { onSuccess: (data) => {
toast.info(t("authorizeSuccessTitle"), { toast.info(t("authorizeSuccessTitle"), {
description: t("authorizeSuccessSubtitle"), description: t("authorizeSuccessSubtitle"),
@@ -106,17 +115,17 @@ export const AuthorizePage = () => {
}, },
}); });
if (oidcParams.issues.length > 0) { if (missingParams.length > 0) {
return ( return (
<Navigate <Navigate
to={`/error?error=${encodeURIComponent(t("authorizeErrorMissingParams", { missingParams: oidcParams.issues.join(", ") }))}`} to={`/error?error=${encodeURIComponent(t("authorizeErrorMissingParams", { missingParams: missingParams.join(", ") }))}`}
replace replace
/> />
); );
} }
if (!isLoggedIn) { if (!isLoggedIn) {
return <Navigate to={`/login?${oidcParams.compiled}`} replace />; return <Navigate to={`/login?${compiledOIDCParams}`} replace />;
} }
if (getClientInfo.isLoading) { if (getClientInfo.isLoading) {
@@ -143,9 +152,6 @@ export const AuthorizePage = () => {
); );
} }
const scopes =
oidcParams.values.scope.split(" ").filter((s) => s.trim() !== "") || [];
return ( return (
<Card> <Card>
<CardHeader className="mb-2"> <CardHeader className="mb-2">

View File

@@ -51,12 +51,15 @@ export const LoginPage = () => {
const formId = useId(); const formId = useId();
const searchParams = new URLSearchParams(search); const searchParams = new URLSearchParams(search);
const redirectUri = searchParams.get("redirect_uri") || undefined; const {
const oidcParams = useOIDCParams(searchParams); values: props,
isOidc,
compiled: compiledOIDCParams,
} = useOIDCParams(searchParams);
const [isOauthAutoRedirect, setIsOauthAutoRedirect] = useState( const [isOauthAutoRedirect, setIsOauthAutoRedirect] = useState(
providers.find((provider) => provider.id === oauthAutoRedirect) !== providers.find((provider) => provider.id === oauthAutoRedirect) !==
undefined && redirectUri !== undefined, undefined && props.redirect_uri,
); );
const oauthProviders = providers.filter( const oauthProviders = providers.filter(
@@ -75,7 +78,7 @@ export const LoginPage = () => {
} = useMutation({ } = useMutation({
mutationFn: (provider: string) => mutationFn: (provider: string) =>
axios.get( axios.get(
`/api/oauth/url/${provider}${redirectUri ? `?redirect_uri=${encodeURIComponent(redirectUri)}` : ""}`, `/api/oauth/url/${provider}${props.redirect_uri ? `?redirect_uri=${encodeURIComponent(props.redirect_uri)}` : ""}`,
), ),
mutationKey: ["oauth"], mutationKey: ["oauth"],
onSuccess: (data) => { onSuccess: (data) => {
@@ -106,12 +109,8 @@ export const LoginPage = () => {
mutationKey: ["login"], mutationKey: ["login"],
onSuccess: (data) => { onSuccess: (data) => {
if (data.data.totpPending) { if (data.data.totpPending) {
if (oidcParams.isOidc) {
window.location.replace(`/totp?${oidcParams.compiled}`);
return;
}
window.location.replace( window.location.replace(
`/totp${redirectUri ? `?redirect_uri=${encodeURIComponent(redirectUri)}` : ""}`, `/totp${props.redirect_uri ? `?redirect_uri=${encodeURIComponent(props.redirect_uri)}` : ""}`,
); );
return; return;
} }
@@ -121,12 +120,12 @@ export const LoginPage = () => {
}); });
redirectTimer.current = window.setTimeout(() => { redirectTimer.current = window.setTimeout(() => {
if (oidcParams.isOidc) { if (isOidc) {
window.location.replace(`/authorize?${oidcParams.compiled}`); window.location.replace(`/authorize?${compiledOIDCParams}`);
return; return;
} }
window.location.replace( window.location.replace(
`/continue${redirectUri ? `?redirect_uri=${encodeURIComponent(redirectUri)}` : ""}`, `/continue${props.redirect_uri ? `?redirect_uri=${encodeURIComponent(props.redirect_uri)}` : ""}`,
); );
}, 500); }, 500);
}, },
@@ -145,7 +144,7 @@ export const LoginPage = () => {
!isLoggedIn && !isLoggedIn &&
isOauthAutoRedirect && isOauthAutoRedirect &&
!hasAutoRedirectedRef.current && !hasAutoRedirectedRef.current &&
redirectUri !== undefined props.redirect_uri
) { ) {
hasAutoRedirectedRef.current = true; hasAutoRedirectedRef.current = true;
oauthMutate(oauthAutoRedirect); oauthMutate(oauthAutoRedirect);
@@ -156,7 +155,7 @@ export const LoginPage = () => {
hasAutoRedirectedRef, hasAutoRedirectedRef,
oauthAutoRedirect, oauthAutoRedirect,
isOauthAutoRedirect, isOauthAutoRedirect,
redirectUri, props.redirect_uri,
]); ]);
useEffect(() => { useEffect(() => {
@@ -171,14 +170,14 @@ export const LoginPage = () => {
}; };
}, [redirectTimer, redirectButtonTimer]); }, [redirectTimer, redirectButtonTimer]);
if (isLoggedIn && oidcParams.isOidc) { if (isLoggedIn && isOidc) {
return <Navigate to={`/authorize?${oidcParams.compiled}`} replace />; return <Navigate to={`/authorize?${compiledOIDCParams}`} replace />;
} }
if (isLoggedIn && redirectUri !== undefined) { if (isLoggedIn && props.redirect_uri !== "") {
return ( return (
<Navigate <Navigate
to={`/continue${redirectUri ? `?redirect_uri=${encodeURIComponent(redirectUri)}` : ""}`} to={`/continue${props.redirect_uri ? `?redirect_uri=${encodeURIComponent(props.redirect_uri)}` : ""}`}
replace replace
/> />
); );

View File

@@ -27,8 +27,11 @@ export const TotpPage = () => {
const redirectTimer = useRef<number | null>(null); const redirectTimer = useRef<number | null>(null);
const searchParams = new URLSearchParams(search); const searchParams = new URLSearchParams(search);
const redirectUri = searchParams.get("redirect_uri") || undefined; const {
const oidcParams = useOIDCParams(searchParams); values: props,
isOidc,
compiled: compiledOIDCParams,
} = useOIDCParams(searchParams);
const totpMutation = useMutation({ const totpMutation = useMutation({
mutationFn: (values: TotpSchema) => axios.post("/api/user/totp", values), mutationFn: (values: TotpSchema) => axios.post("/api/user/totp", values),
@@ -39,13 +42,13 @@ export const TotpPage = () => {
}); });
redirectTimer.current = window.setTimeout(() => { redirectTimer.current = window.setTimeout(() => {
if (oidcParams.isOidc) { if (isOidc) {
window.location.replace(`/authorize?${oidcParams.compiled}`); window.location.replace(`/authorize?${compiledOIDCParams}`);
return; return;
} }
window.location.replace( window.location.replace(
`/continue${redirectUri ? `?redirect_uri=${encodeURIComponent(redirectUri)}` : ""}`, `/continue${props.redirect_uri ? `?redirect_uri=${encodeURIComponent(props.redirect_uri)}` : ""}`,
); );
}, 500); }, 500);
}, },

View File

@@ -9,7 +9,6 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/google/go-querystring/query" "github.com/google/go-querystring/query"
"github.com/steveiliop56/tinyauth/internal/service" "github.com/steveiliop56/tinyauth/internal/service"
"github.com/steveiliop56/tinyauth/internal/utils" "github.com/steveiliop56/tinyauth/internal/utils"
"github.com/steveiliop56/tinyauth/internal/utils/tlog" "github.com/steveiliop56/tinyauth/internal/utils/tlog"
@@ -71,7 +70,6 @@ func (controller *OIDCController) SetupRoutes() {
oidcGroup.POST("/authorize", controller.Authorize) oidcGroup.POST("/authorize", controller.Authorize)
oidcGroup.POST("/token", controller.Token) oidcGroup.POST("/token", controller.Token)
oidcGroup.GET("/userinfo", controller.Userinfo) oidcGroup.GET("/userinfo", controller.Userinfo)
oidcGroup.POST("/userinfo", controller.Userinfo)
} }
func (controller *OIDCController) GetClientInfo(c *gin.Context) { func (controller *OIDCController) GetClientInfo(c *gin.Context) {
@@ -377,15 +375,14 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
return return
} }
var token string
authorization := c.GetHeader("Authorization") authorization := c.GetHeader("Authorization")
if authorization != "" {
tokenType, bearerToken, ok := strings.Cut(authorization, " ") tokenType, token, ok := strings.Cut(authorization, " ")
if !ok { if !ok {
tlog.App.Warn().Msg("OIDC userinfo accessed with malformed authorization header") tlog.App.Warn().Msg("OIDC userinfo accessed without authorization header")
c.JSON(401, gin.H{ c.JSON(401, gin.H{
"error": "invalid_request", "error": "invalid_grant",
}) })
return return
} }
@@ -393,32 +390,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
if strings.ToLower(tokenType) != "bearer" { if strings.ToLower(tokenType) != "bearer" {
tlog.App.Warn().Msg("OIDC userinfo accessed with invalid token type") tlog.App.Warn().Msg("OIDC userinfo accessed with invalid token type")
c.JSON(401, gin.H{ c.JSON(401, gin.H{
"error": "invalid_request", "error": "invalid_grant",
})
return
}
token = bearerToken
} else if c.Request.Method == http.MethodPost {
if c.ContentType() != "application/x-www-form-urlencoded" {
tlog.App.Warn().Msg("OIDC userinfo POST accessed with invalid content type")
c.JSON(400, gin.H{
"error": "invalid_request",
})
return
}
token = c.PostForm("access_token")
if token == "" {
tlog.App.Warn().Msg("OIDC userinfo POST accessed without access_token in body")
c.JSON(401, gin.H{
"error": "invalid_request",
})
return
}
} else {
tlog.App.Warn().Msg("OIDC userinfo accessed without authorization header")
c.JSON(401, gin.H{
"error": "invalid_request",
}) })
return return
} }

View File

@@ -435,128 +435,6 @@ func TestOIDCController(t *testing.T) {
assert.False(t, ok, "Did not expect email claim in userinfo response") assert.False(t, ok, "Did not expect email claim in userinfo response")
}, },
}, },
{
description: "Ensure userinfo forbids access with no authorization header",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/oidc/userinfo", nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 401, recorder.Code)
var res map[string]any
err := json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err)
assert.Equal(t, "invalid_request", res["error"])
},
},
{
description: "Ensure userinfo forbids access with malformed authorization header",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/oidc/userinfo", nil)
req.Header.Set("Authorization", "Bearer")
router.ServeHTTP(recorder, req)
assert.Equal(t, 401, recorder.Code)
var res map[string]any
err := json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err)
assert.Equal(t, "invalid_request", res["error"])
},
},
{
description: "Ensure userinfo forbids access with invalid token type",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/oidc/userinfo", nil)
req.Header.Set("Authorization", "Basic some-token")
router.ServeHTTP(recorder, req)
assert.Equal(t, 401, recorder.Code)
var res map[string]any
err := json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err)
assert.Equal(t, "invalid_request", res["error"])
},
},
{
description: "Ensure userinfo forbids access with empty bearer token",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/oidc/userinfo", nil)
req.Header.Set("Authorization", "Bearer ")
router.ServeHTTP(recorder, req)
assert.Equal(t, 401, recorder.Code)
var res map[string]any
err := json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err)
assert.Equal(t, "invalid_grant", res["error"])
},
},
{
description: "Ensure userinfo POST rejects missing access token in body",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("POST", "/api/oidc/userinfo", strings.NewReader(""))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
router.ServeHTTP(recorder, req)
assert.Equal(t, 401, recorder.Code)
var res map[string]any
err := json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err)
assert.Equal(t, "invalid_request", res["error"])
},
},
{
description: "Ensure userinfo POST rejects wrong content type",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("POST", "/api/oidc/userinfo", strings.NewReader(`{"access_token":"some-token"}`))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(recorder, req)
assert.Equal(t, 400, recorder.Code)
var res map[string]any
err := json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err)
assert.Equal(t, "invalid_request", res["error"])
},
},
{
description: "Ensure userinfo accepts access token via POST body",
middlewares: []gin.HandlerFunc{
simpleCtx,
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
tokenTest, found := getTestByDescription("Ensure we can get a token with a valid request")
assert.True(t, found, "Token test not found")
tokenRecorder := httptest.NewRecorder()
tokenTest(t, router, tokenRecorder)
var tokenRes map[string]any
err := json.Unmarshal(tokenRecorder.Body.Bytes(), &tokenRes)
assert.NoError(t, err)
accessToken := tokenRes["access_token"].(string)
assert.NotEmpty(t, accessToken)
body := url.Values{}
body.Set("access_token", accessToken)
req := httptest.NewRequest("POST", "/api/oidc/userinfo", strings.NewReader(body.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
var userInfoRes map[string]any
err = json.Unmarshal(recorder.Body.Bytes(), &userInfoRes)
assert.NoError(t, err)
_, ok := userInfoRes["sub"]
assert.True(t, ok, "Expected sub claim in userinfo response")
},
},
{ {
description: "Ensure plain PKCE succeeds", description: "Ensure plain PKCE succeeds",
middlewares: []gin.HandlerFunc{ middlewares: []gin.HandlerFunc{

View File

@@ -24,7 +24,6 @@ var (
"GET /api/oidc/clients", "GET /api/oidc/clients",
"POST /api/oidc/token", "POST /api/oidc/token",
"GET /api/oidc/userinfo", "GET /api/oidc/userinfo",
"POST /api/oidc/userinfo",
"GET /resources", "GET /resources",
"POST /api/user/login", "POST /api/user/login",
"GET /.well-known/openid-configuration", "GET /.well-known/openid-configuration",