Compare commits

..

1 Commits

Author SHA1 Message Date
Stavros 9a6676b054 feat: add pkce support to oidc server 2026-04-06 23:09:14 +03:00
24 changed files with 229 additions and 618 deletions
+57 -33
View File
@@ -1,40 +1,64 @@
import { z } from "zod"; export type OIDCValues = {
scope: string;
response_type: string;
client_id: string;
redirect_uri: string;
state: string;
nonce: string;
code_challenge: string;
code_challenge_method: string;
};
export const oidcParamsSchema = z.object({ interface IuseOIDCParams {
scope: z.string().min(1), values: OIDCValues;
response_type: z.string().min(1),
client_id: z.string().min(1),
redirect_uri: z.string().min(1),
state: z.string().optional(),
nonce: z.string().optional(),
code_challenge: z.string().optional(),
code_challenge_method: z.string().optional(),
});
export const useOIDCParams = (
params: URLSearchParams,
): {
values: z.infer<typeof oidcParamsSchema>;
issues: string[];
isOidc: boolean;
compiled: string; compiled: string;
} => { isOidc: boolean;
const obj = Object.fromEntries(params.entries()); missingParams: string[];
const parsed = oidcParamsSchema.safeParse(obj); }
if (parsed.success) { const optionalParams: string[] = [
return { "state",
values: parsed.data, "nonce",
issues: [], "code_challenge",
isOidc: true, "code_challenge_method",
compiled: new URLSearchParams(parsed.data).toString(), ];
};
export function useOIDCParams(params: URLSearchParams): IuseOIDCParams {
let compiled: string = "";
let isOidc = false;
const missingParams: string[] = [];
const values: OIDCValues = {
scope: params.get("scope") ?? "",
response_type: params.get("response_type") ?? "",
client_id: params.get("client_id") ?? "",
redirect_uri: params.get("redirect_uri") ?? "",
state: params.get("state") ?? "",
nonce: params.get("nonce") ?? "",
code_challenge: params.get("code_challenge") ?? "",
code_challenge_method: params.get("code_challenge_method") ?? "",
};
for (const key of Object.keys(values)) {
if (!values[key as keyof OIDCValues]) {
if (!optionalParams.includes(key)) {
missingParams.push(key);
}
}
}
if (missingParams.length === 0) {
isOidc = true;
}
if (isOidc) {
compiled = new URLSearchParams(values).toString();
} }
return { return {
issues: parsed.error.issues.map((issue) => issue.path.toString()), values,
values: {} as z.infer<typeof oidcParamsSchema>, compiled,
isOidc: false, isOidc,
compiled: "", missingParams,
}; };
}; }
+20 -14
View File
@@ -72,27 +72,36 @@ export const AuthorizePage = () => {
const scopeMap = createScopeMap(t); const scopeMap = createScopeMap(t);
const searchParams = new URLSearchParams(search); const searchParams = new URLSearchParams(search);
const oidcParams = useOIDCParams(searchParams); const {
values: props,
missingParams,
isOidc,
compiled: compiledOIDCParams,
} = useOIDCParams(searchParams);
const scopes = props.scope ? props.scope.split(" ").filter(Boolean) : [];
const getClientInfo = useQuery({ const getClientInfo = useQuery({
queryKey: ["client", oidcParams.values.client_id], queryKey: ["client", props.client_id],
queryFn: async () => { queryFn: async () => {
const res = await fetch( const res = await fetch(`/api/oidc/clients/${props.client_id}`);
`/api/oidc/clients/${encodeURIComponent(oidcParams.values.client_id)}`,
);
const data = await getOidcClientInfoSchema.parseAsync(await res.json()); const data = await getOidcClientInfoSchema.parseAsync(await res.json());
return data; return data;
}, },
enabled: oidcParams.isOidc, enabled: isOidc,
}); });
const authorizeMutation = useMutation({ const authorizeMutation = useMutation({
mutationFn: () => { mutationFn: () => {
return axios.post("/api/oidc/authorize", { return axios.post("/api/oidc/authorize", {
...oidcParams.values, scope: props.scope,
response_type: props.response_type,
client_id: props.client_id,
redirect_uri: props.redirect_uri,
state: props.state,
nonce: props.nonce,
}); });
}, },
mutationKey: ["authorize", oidcParams.values.client_id], mutationKey: ["authorize", props.client_id],
onSuccess: (data) => { onSuccess: (data) => {
toast.info(t("authorizeSuccessTitle"), { toast.info(t("authorizeSuccessTitle"), {
description: t("authorizeSuccessSubtitle"), description: t("authorizeSuccessSubtitle"),
@@ -106,17 +115,17 @@ export const AuthorizePage = () => {
}, },
}); });
if (oidcParams.issues.length > 0) { if (missingParams.length > 0) {
return ( return (
<Navigate <Navigate
to={`/error?error=${encodeURIComponent(t("authorizeErrorMissingParams", { missingParams: oidcParams.issues.join(", ") }))}`} to={`/error?error=${encodeURIComponent(t("authorizeErrorMissingParams", { missingParams: missingParams.join(", ") }))}`}
replace replace
/> />
); );
} }
if (!isLoggedIn) { if (!isLoggedIn) {
return <Navigate to={`/login?${oidcParams.compiled}`} replace />; return <Navigate to={`/login?${compiledOIDCParams}`} replace />;
} }
if (getClientInfo.isLoading) { if (getClientInfo.isLoading) {
@@ -143,9 +152,6 @@ export const AuthorizePage = () => {
); );
} }
const scopes =
oidcParams.values.scope.split(" ").filter((s) => s.trim() !== "") || [];
return ( return (
<Card> <Card>
<CardHeader className="mb-2"> <CardHeader className="mb-2">
+20 -29
View File
@@ -51,12 +51,15 @@ export const LoginPage = () => {
const formId = useId(); const formId = useId();
const searchParams = new URLSearchParams(search); const searchParams = new URLSearchParams(search);
const redirectUri = searchParams.get("redirect_uri") || undefined; const {
const oidcParams = useOIDCParams(searchParams); values: props,
isOidc,
compiled: compiledOIDCParams,
} = useOIDCParams(searchParams);
const [isOauthAutoRedirect, setIsOauthAutoRedirect] = useState( const [isOauthAutoRedirect, setIsOauthAutoRedirect] = useState(
providers.find((provider) => provider.id === oauthAutoRedirect) !== providers.find((provider) => provider.id === oauthAutoRedirect) !==
undefined && redirectUri !== undefined, undefined && props.redirect_uri,
); );
const oauthProviders = providers.filter( const oauthProviders = providers.filter(
@@ -73,18 +76,10 @@ export const LoginPage = () => {
isPending: oauthIsPending, isPending: oauthIsPending,
variables: oauthVariables, variables: oauthVariables,
} = useMutation({ } = useMutation({
mutationFn: (provider: string) => { mutationFn: (provider: string) =>
const getParams = function (): string { axios.get(
if (oidcParams.isOidc) { `/api/oauth/url/${provider}${props.redirect_uri ? `?redirect_uri=${encodeURIComponent(props.redirect_uri)}` : ""}`,
return `?${oidcParams.compiled}`; ),
}
if (redirectUri) {
return `?redirect_uri=${encodeURIComponent(redirectUri)}`;
}
return "";
};
return axios.get(`/api/oauth/url/${provider}${getParams()}`);
},
mutationKey: ["oauth"], mutationKey: ["oauth"],
onSuccess: (data) => { onSuccess: (data) => {
toast.info(t("loginOauthSuccessTitle"), { toast.info(t("loginOauthSuccessTitle"), {
@@ -114,12 +109,8 @@ export const LoginPage = () => {
mutationKey: ["login"], mutationKey: ["login"],
onSuccess: (data) => { onSuccess: (data) => {
if (data.data.totpPending) { if (data.data.totpPending) {
if (oidcParams.isOidc) {
window.location.replace(`/totp?${oidcParams.compiled}`);
return;
}
window.location.replace( window.location.replace(
`/totp${redirectUri ? `?redirect_uri=${encodeURIComponent(redirectUri)}` : ""}`, `/totp${props.redirect_uri ? `?redirect_uri=${encodeURIComponent(props.redirect_uri)}` : ""}`,
); );
return; return;
} }
@@ -129,12 +120,12 @@ export const LoginPage = () => {
}); });
redirectTimer.current = window.setTimeout(() => { redirectTimer.current = window.setTimeout(() => {
if (oidcParams.isOidc) { if (isOidc) {
window.location.replace(`/authorize?${oidcParams.compiled}`); window.location.replace(`/authorize?${compiledOIDCParams}`);
return; return;
} }
window.location.replace( window.location.replace(
`/continue${redirectUri ? `?redirect_uri=${encodeURIComponent(redirectUri)}` : ""}`, `/continue${props.redirect_uri ? `?redirect_uri=${encodeURIComponent(props.redirect_uri)}` : ""}`,
); );
}, 500); }, 500);
}, },
@@ -153,7 +144,7 @@ export const LoginPage = () => {
!isLoggedIn && !isLoggedIn &&
isOauthAutoRedirect && isOauthAutoRedirect &&
!hasAutoRedirectedRef.current && !hasAutoRedirectedRef.current &&
redirectUri !== undefined props.redirect_uri
) { ) {
hasAutoRedirectedRef.current = true; hasAutoRedirectedRef.current = true;
oauthMutate(oauthAutoRedirect); oauthMutate(oauthAutoRedirect);
@@ -164,7 +155,7 @@ export const LoginPage = () => {
hasAutoRedirectedRef, hasAutoRedirectedRef,
oauthAutoRedirect, oauthAutoRedirect,
isOauthAutoRedirect, isOauthAutoRedirect,
redirectUri, props.redirect_uri,
]); ]);
useEffect(() => { useEffect(() => {
@@ -179,14 +170,14 @@ export const LoginPage = () => {
}; };
}, [redirectTimer, redirectButtonTimer]); }, [redirectTimer, redirectButtonTimer]);
if (isLoggedIn && oidcParams.isOidc) { if (isLoggedIn && isOidc) {
return <Navigate to={`/authorize?${oidcParams.compiled}`} replace />; return <Navigate to={`/authorize?${compiledOIDCParams}`} replace />;
} }
if (isLoggedIn && redirectUri !== undefined) { if (isLoggedIn && props.redirect_uri !== "") {
return ( return (
<Navigate <Navigate
to={`/continue${redirectUri ? `?redirect_uri=${encodeURIComponent(redirectUri)}` : ""}`} to={`/continue${props.redirect_uri ? `?redirect_uri=${encodeURIComponent(props.redirect_uri)}` : ""}`}
replace replace
/> />
); );
+8 -5
View File
@@ -27,8 +27,11 @@ export const TotpPage = () => {
const redirectTimer = useRef<number | null>(null); const redirectTimer = useRef<number | null>(null);
const searchParams = new URLSearchParams(search); const searchParams = new URLSearchParams(search);
const redirectUri = searchParams.get("redirect_uri") || undefined; const {
const oidcParams = useOIDCParams(searchParams); values: props,
isOidc,
compiled: compiledOIDCParams,
} = useOIDCParams(searchParams);
const totpMutation = useMutation({ const totpMutation = useMutation({
mutationFn: (values: TotpSchema) => axios.post("/api/user/totp", values), mutationFn: (values: TotpSchema) => axios.post("/api/user/totp", values),
@@ -39,13 +42,13 @@ export const TotpPage = () => {
}); });
redirectTimer.current = window.setTimeout(() => { redirectTimer.current = window.setTimeout(() => {
if (oidcParams.isOidc) { if (isOidc) {
window.location.replace(`/authorize?${oidcParams.compiled}`); window.location.replace(`/authorize?${compiledOIDCParams}`);
return; return;
} }
window.location.replace( window.location.replace(
`/continue${redirectUri ? `?redirect_uri=${encodeURIComponent(redirectUri)}` : ""}`, `/continue${props.redirect_uri ? `?redirect_uri=${encodeURIComponent(props.redirect_uri)}` : ""}`,
); );
}, 500); }, 500);
}, },
@@ -1 +1,2 @@
ALTER TABLE "oidc_codes" DROP COLUMN "code_challenge"; ALTER TABLE "oidc_codes" DROP COLUMN "code_challenge";
ALTER TABLE "oidc_codes" DROP COLUMN "code_challenge_method";
@@ -1 +1,2 @@
ALTER TABLE "oidc_codes" ADD COLUMN "code_challenge" TEXT DEFAULT ""; ALTER TABLE "oidc_codes" ADD COLUMN "code_challenge" TEXT DEFAULT "";
ALTER TABLE "oidc_codes" ADD COLUMN "code_challenge_method" TEXT DEFAULT "";
@@ -10,12 +10,10 @@ import (
"github.com/steveiliop56/tinyauth/internal/config" "github.com/steveiliop56/tinyauth/internal/config"
"github.com/steveiliop56/tinyauth/internal/controller" "github.com/steveiliop56/tinyauth/internal/controller"
"github.com/steveiliop56/tinyauth/internal/utils" "github.com/steveiliop56/tinyauth/internal/utils"
"github.com/steveiliop56/tinyauth/internal/utils/tlog"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestContextController(t *testing.T) { func TestContextController(t *testing.T) {
tlog.NewTestLogger().Init()
controllerConfig := controller.ContextControllerConfig{ controllerConfig := controller.ContextControllerConfig{
Providers: []controller.Provider{ Providers: []controller.Provider{
{ {
@@ -8,12 +8,10 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/steveiliop56/tinyauth/internal/controller" "github.com/steveiliop56/tinyauth/internal/controller"
"github.com/steveiliop56/tinyauth/internal/utils/tlog"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestHealthController(t *testing.T) { func TestHealthController(t *testing.T) {
tlog.NewTestLogger().Init()
tests := []struct { tests := []struct {
description string description string
path string path string
+40 -67
View File
@@ -62,29 +62,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
return return
} }
var reqParams service.OAuthURLParams sessionId, session, err := controller.auth.NewOAuthSession(req.Provider)
err = c.BindQuery(&reqParams)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to bind query parameters")
c.JSON(400, gin.H{
"status": 400,
"message": "Bad Request",
})
return
}
if !controller.isOidcRequest(reqParams) {
isRedirectSafe := utils.IsRedirectSafe(reqParams.RedirectURI, controller.config.CookieDomain)
if !isRedirectSafe {
tlog.App.Warn().Str("redirect_uri", reqParams.RedirectURI).Msg("Unsafe redirect URI detected, ignoring")
reqParams.RedirectURI = ""
}
}
sessionId, _, err := controller.auth.NewOAuthSession(req.Provider, reqParams)
if err != nil { if err != nil {
tlog.App.Error().Err(err).Msg("Failed to create OAuth session") tlog.App.Error().Err(err).Msg("Failed to create OAuth session")
@@ -107,6 +85,20 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
} }
c.SetCookie(controller.config.OAuthSessionCookieName, sessionId, int(time.Hour.Seconds()), "/", fmt.Sprintf(".%s", controller.config.CookieDomain), controller.config.SecureCookie, true) c.SetCookie(controller.config.OAuthSessionCookieName, sessionId, int(time.Hour.Seconds()), "/", fmt.Sprintf(".%s", controller.config.CookieDomain), controller.config.SecureCookie, true)
c.SetCookie(controller.config.CSRFCookieName, session.State, int(time.Hour.Seconds()), "/", fmt.Sprintf(".%s", controller.config.CookieDomain), controller.config.SecureCookie, true)
redirectURI := c.Query("redirect_uri")
isRedirectSafe := utils.IsRedirectSafe(redirectURI, controller.config.CookieDomain)
if !isRedirectSafe {
tlog.App.Warn().Str("redirect_uri", redirectURI).Msg("Unsafe redirect URI detected, ignoring")
redirectURI = ""
}
if redirectURI != "" && isRedirectSafe {
tlog.App.Debug().Msg("Setting redirect URI cookie")
c.SetCookie(controller.config.RedirectCookieName, redirectURI, int(time.Hour.Seconds()), "/", fmt.Sprintf(".%s", controller.config.CookieDomain), controller.config.SecureCookie, true)
}
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"status": 200, "status": 200,
@@ -137,24 +129,20 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
} }
c.SetCookie(controller.config.OAuthSessionCookieName, "", -1, "/", fmt.Sprintf(".%s", controller.config.CookieDomain), controller.config.SecureCookie, true) c.SetCookie(controller.config.OAuthSessionCookieName, "", -1, "/", fmt.Sprintf(".%s", controller.config.CookieDomain), controller.config.SecureCookie, true)
oauthPendingSession, err := controller.auth.GetOAuthPendingSession(sessionIdCookie)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to get OAuth pending session")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
return
}
defer controller.auth.EndOAuthSession(sessionIdCookie) defer controller.auth.EndOAuthSession(sessionIdCookie)
state := c.Query("state") state := c.Query("state")
if state != oauthPendingSession.State { csrfCookie, err := c.Cookie(controller.config.CSRFCookieName)
tlog.App.Warn().Err(err).Msg("CSRF token mismatch")
if err != nil || state != csrfCookie {
tlog.App.Warn().Err(err).Msg("CSRF token mismatch or cookie missing")
c.SetCookie(controller.config.CSRFCookieName, "", -1, "/", fmt.Sprintf(".%s", controller.config.CookieDomain), controller.config.SecureCookie, true)
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
return return
} }
c.SetCookie(controller.config.CSRFCookieName, "", -1, "/", fmt.Sprintf(".%s", controller.config.CookieDomain), controller.config.SecureCookie, true)
code := c.Query("code") code := c.Query("code")
_, err = controller.auth.GetOAuthToken(sessionIdCookie, code) _, err = controller.auth.GetOAuthToken(sessionIdCookie, code)
@@ -210,7 +198,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
username = strings.Replace(user.Email, "@", "_", 1) username = strings.Replace(user.Email, "@", "_", 1)
} }
svc, err := controller.auth.GetOAuthService(sessionIdCookie) service, err := controller.auth.GetOAuthService(sessionIdCookie)
if err != nil { if err != nil {
tlog.App.Error().Err(err).Msg("Failed to get OAuth service for session") tlog.App.Error().Err(err).Msg("Failed to get OAuth service for session")
@@ -218,8 +206,8 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
return return
} }
if svc.ID() != req.Provider { if service.ID() != req.Provider {
tlog.App.Error().Msgf("OAuth service ID mismatch: expected %s, got %s", svc.ID(), req.Provider) tlog.App.Error().Msgf("OAuth service ID mismatch: expected %s, got %s", service.ID(), req.Provider)
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
return return
} }
@@ -228,9 +216,9 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
Username: username, Username: username,
Name: name, Name: name,
Email: user.Email, Email: user.Email,
Provider: svc.ID(), Provider: service.ID(),
OAuthGroups: utils.CoalesceToString(user.Groups), OAuthGroups: utils.CoalesceToString(user.Groups),
OAuthName: svc.Name(), OAuthName: service.Name(),
OAuthSub: user.Sub, OAuthSub: user.Sub,
} }
@@ -246,39 +234,24 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
tlog.AuditLoginSuccess(c, sessionCookie.Username, sessionCookie.Provider) tlog.AuditLoginSuccess(c, sessionCookie.Username, sessionCookie.Provider)
if controller.isOidcRequest(oauthPendingSession.CallbackParams) { redirectURI, err := c.Cookie(controller.config.RedirectCookieName)
tlog.App.Debug().Msg("OIDC request, redirecting to authorize page")
queries, err := query.Values(oauthPendingSession.CallbackParams) if err != nil || !utils.IsRedirectSafe(redirectURI, controller.config.CookieDomain) {
if err != nil { tlog.App.Debug().Msg("No redirect URI cookie found, redirecting to app root")
tlog.App.Error().Err(err).Msg("Failed to encode OIDC callback query") c.Redirect(http.StatusTemporaryRedirect, controller.config.AppURL)
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
return
}
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/authorize?%s", controller.config.AppURL, queries.Encode()))
return return
} }
if oauthPendingSession.CallbackParams.RedirectURI != "" { queries, err := query.Values(config.RedirectQuery{
queries, err := query.Values(config.RedirectQuery{ RedirectURI: redirectURI,
RedirectURI: oauthPendingSession.CallbackParams.RedirectURI, })
})
if err != nil { if err != nil {
tlog.App.Error().Err(err).Msg("Failed to encode redirect URI query") tlog.App.Error().Err(err).Msg("Failed to encode redirect URI query")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
return
}
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/continue?%s", controller.config.AppURL, queries.Encode()))
return return
} }
c.Redirect(http.StatusTemporaryRedirect, controller.config.AppURL) c.SetCookie(controller.config.RedirectCookieName, "", -1, "/", fmt.Sprintf(".%s", controller.config.CookieDomain), controller.config.SecureCookie, true)
} c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/continue?%s", controller.config.AppURL, queries.Encode()))
func (controller *OAuthController) isOidcRequest(params service.OAuthURLParams) bool {
return params.Scope != "" &&
params.ResponseType != "" &&
params.ClientID != "" &&
params.RedirectURI != ""
} }
+12 -40
View File
@@ -9,7 +9,6 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/google/go-querystring/query" "github.com/google/go-querystring/query"
"github.com/steveiliop56/tinyauth/internal/service" "github.com/steveiliop56/tinyauth/internal/service"
"github.com/steveiliop56/tinyauth/internal/utils" "github.com/steveiliop56/tinyauth/internal/utils"
"github.com/steveiliop56/tinyauth/internal/utils/tlog" "github.com/steveiliop56/tinyauth/internal/utils/tlog"
@@ -71,7 +70,6 @@ func (controller *OIDCController) SetupRoutes() {
oidcGroup.POST("/authorize", controller.Authorize) oidcGroup.POST("/authorize", controller.Authorize)
oidcGroup.POST("/token", controller.Token) oidcGroup.POST("/token", controller.Token)
oidcGroup.GET("/userinfo", controller.Userinfo) oidcGroup.GET("/userinfo", controller.Userinfo)
oidcGroup.POST("/userinfo", controller.Userinfo)
} }
func (controller *OIDCController) GetClientInfo(c *gin.Context) { func (controller *OIDCController) GetClientInfo(c *gin.Context) {
@@ -311,7 +309,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
return return
} }
ok := controller.oidc.ValidatePKCE(entry.CodeChallenge, req.CodeVerifier) ok := controller.oidc.ValidatePKCE(entry.CodeChallenge, entry.CodeChallengeMethod, req.CodeVerifier)
if !ok { if !ok {
tlog.App.Warn().Msg("PKCE validation failed") tlog.App.Warn().Msg("PKCE validation failed")
@@ -377,48 +375,22 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
return return
} }
var token string
authorization := c.GetHeader("Authorization") authorization := c.GetHeader("Authorization")
if authorization != "" {
tokenType, bearerToken, ok := strings.Cut(authorization, " ")
if !ok {
tlog.App.Warn().Msg("OIDC userinfo accessed with malformed authorization header")
c.JSON(401, gin.H{
"error": "invalid_request",
})
return
}
if strings.ToLower(tokenType) != "bearer" { tokenType, token, ok := strings.Cut(authorization, " ")
tlog.App.Warn().Msg("OIDC userinfo accessed with invalid token type")
c.JSON(401, gin.H{
"error": "invalid_request",
})
return
}
token = bearerToken if !ok {
} else if c.Request.Method == http.MethodPost {
if c.ContentType() != "application/x-www-form-urlencoded" {
tlog.App.Warn().Msg("OIDC userinfo POST accessed with invalid content type")
c.JSON(400, gin.H{
"error": "invalid_request",
})
return
}
token = c.PostForm("access_token")
if token == "" {
tlog.App.Warn().Msg("OIDC userinfo POST accessed without access_token in body")
c.JSON(401, gin.H{
"error": "invalid_request",
})
return
}
} else {
tlog.App.Warn().Msg("OIDC userinfo accessed without authorization header") tlog.App.Warn().Msg("OIDC userinfo accessed without authorization header")
c.JSON(401, gin.H{ c.JSON(401, gin.H{
"error": "invalid_request", "error": "invalid_grant",
})
return
}
if strings.ToLower(tokenType) != "bearer" {
tlog.App.Warn().Msg("OIDC userinfo accessed with invalid token type")
c.JSON(401, gin.H{
"error": "invalid_grant",
}) })
return return
} }
-347
View File
@@ -1,8 +1,6 @@
package controller_test package controller_test
import ( import (
"crypto/sha256"
"encoding/base64"
"encoding/json" "encoding/json"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
@@ -17,13 +15,11 @@ import (
"github.com/steveiliop56/tinyauth/internal/controller" "github.com/steveiliop56/tinyauth/internal/controller"
"github.com/steveiliop56/tinyauth/internal/repository" "github.com/steveiliop56/tinyauth/internal/repository"
"github.com/steveiliop56/tinyauth/internal/service" "github.com/steveiliop56/tinyauth/internal/service"
"github.com/steveiliop56/tinyauth/internal/utils/tlog"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func TestOIDCController(t *testing.T) { func TestOIDCController(t *testing.T) {
tlog.NewTestLogger().Init()
tempDir := t.TempDir() tempDir := t.TempDir()
oidcServiceCfg := service.OIDCServiceConfig{ oidcServiceCfg := service.OIDCServiceConfig{
@@ -435,349 +431,6 @@ func TestOIDCController(t *testing.T) {
assert.False(t, ok, "Did not expect email claim in userinfo response") assert.False(t, ok, "Did not expect email claim in userinfo response")
}, },
}, },
{
description: "Ensure userinfo forbids access with no authorization header",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/oidc/userinfo", nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 401, recorder.Code)
var res map[string]any
err := json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err)
assert.Equal(t, "invalid_request", res["error"])
},
},
{
description: "Ensure userinfo forbids access with malformed authorization header",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/oidc/userinfo", nil)
req.Header.Set("Authorization", "Bearer")
router.ServeHTTP(recorder, req)
assert.Equal(t, 401, recorder.Code)
var res map[string]any
err := json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err)
assert.Equal(t, "invalid_request", res["error"])
},
},
{
description: "Ensure userinfo forbids access with invalid token type",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/oidc/userinfo", nil)
req.Header.Set("Authorization", "Basic some-token")
router.ServeHTTP(recorder, req)
assert.Equal(t, 401, recorder.Code)
var res map[string]any
err := json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err)
assert.Equal(t, "invalid_request", res["error"])
},
},
{
description: "Ensure userinfo forbids access with empty bearer token",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/oidc/userinfo", nil)
req.Header.Set("Authorization", "Bearer ")
router.ServeHTTP(recorder, req)
assert.Equal(t, 401, recorder.Code)
var res map[string]any
err := json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err)
assert.Equal(t, "invalid_grant", res["error"])
},
},
{
description: "Ensure userinfo POST rejects missing access token in body",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("POST", "/api/oidc/userinfo", strings.NewReader(""))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
router.ServeHTTP(recorder, req)
assert.Equal(t, 401, recorder.Code)
var res map[string]any
err := json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err)
assert.Equal(t, "invalid_request", res["error"])
},
},
{
description: "Ensure userinfo POST rejects wrong content type",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("POST", "/api/oidc/userinfo", strings.NewReader(`{"access_token":"some-token"}`))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(recorder, req)
assert.Equal(t, 400, recorder.Code)
var res map[string]any
err := json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err)
assert.Equal(t, "invalid_request", res["error"])
},
},
{
description: "Ensure userinfo accepts access token via POST body",
middlewares: []gin.HandlerFunc{
simpleCtx,
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
tokenTest, found := getTestByDescription("Ensure we can get a token with a valid request")
assert.True(t, found, "Token test not found")
tokenRecorder := httptest.NewRecorder()
tokenTest(t, router, tokenRecorder)
var tokenRes map[string]any
err := json.Unmarshal(tokenRecorder.Body.Bytes(), &tokenRes)
assert.NoError(t, err)
accessToken := tokenRes["access_token"].(string)
assert.NotEmpty(t, accessToken)
body := url.Values{}
body.Set("access_token", accessToken)
req := httptest.NewRequest("POST", "/api/oidc/userinfo", strings.NewReader(body.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
var userInfoRes map[string]any
err = json.Unmarshal(recorder.Body.Bytes(), &userInfoRes)
assert.NoError(t, err)
_, ok := userInfoRes["sub"]
assert.True(t, ok, "Expected sub claim in userinfo response")
},
},
{
description: "Ensure plain PKCE succeeds",
middlewares: []gin.HandlerFunc{
simpleCtx,
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
reqBody := service.AuthorizeRequest{
Scope: "openid",
ResponseType: "code",
ClientID: "some-client-id",
RedirectURI: "https://test.example.com/callback",
State: "some-state",
Nonce: "some-nonce",
CodeChallenge: "some-challenge",
// Not setting a code challenge method should default to "plain"
CodeChallengeMethod: "",
}
reqBodyBytes, err := json.Marshal(reqBody)
assert.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes)))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
var res map[string]any
err = json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err)
redirectURI := res["redirect_uri"].(string)
url, err := url.Parse(redirectURI)
assert.NoError(t, err)
queryParams := url.Query()
assert.Equal(t, queryParams.Get("state"), "some-state")
code := queryParams.Get("code")
assert.NotEmpty(t, code)
// Now exchange the code for a token
recorder = httptest.NewRecorder()
tokenReqBody := controller.TokenRequest{
GrantType: "authorization_code",
Code: code,
RedirectURI: "https://test.example.com/callback",
CodeVerifier: "some-challenge",
}
reqBodyEncoded, err := query.Values(tokenReqBody)
assert.NoError(t, err)
req = httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.SetBasicAuth("some-client-id", "some-client-secret")
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
},
},
{
description: "Ensure S256 PKCE succeeds",
middlewares: []gin.HandlerFunc{
simpleCtx,
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
hasher := sha256.New()
hasher.Write([]byte("some-challenge"))
codeChallenge := hasher.Sum(nil)
codeChallengeEncoded := base64.RawURLEncoding.EncodeToString(codeChallenge)
reqBody := service.AuthorizeRequest{
Scope: "openid",
ResponseType: "code",
ClientID: "some-client-id",
RedirectURI: "https://test.example.com/callback",
State: "some-state",
Nonce: "some-nonce",
CodeChallenge: codeChallengeEncoded,
CodeChallengeMethod: "S256",
}
reqBodyBytes, err := json.Marshal(reqBody)
assert.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes)))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
var res map[string]any
err = json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err)
redirectURI := res["redirect_uri"].(string)
url, err := url.Parse(redirectURI)
assert.NoError(t, err)
queryParams := url.Query()
assert.Equal(t, queryParams.Get("state"), "some-state")
code := queryParams.Get("code")
assert.NotEmpty(t, code)
// Now exchange the code for a token
recorder = httptest.NewRecorder()
tokenReqBody := controller.TokenRequest{
GrantType: "authorization_code",
Code: code,
RedirectURI: "https://test.example.com/callback",
CodeVerifier: "some-challenge",
}
reqBodyEncoded, err := query.Values(tokenReqBody)
assert.NoError(t, err)
req = httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.SetBasicAuth("some-client-id", "some-client-secret")
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
},
},
{
description: "Ensure request with invalid PKCE fails",
middlewares: []gin.HandlerFunc{
simpleCtx,
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
hasher := sha256.New()
hasher.Write([]byte("some-challenge"))
codeChallenge := hasher.Sum(nil)
codeChallengeEncoded := base64.RawURLEncoding.EncodeToString(codeChallenge)
reqBody := service.AuthorizeRequest{
Scope: "openid",
ResponseType: "code",
ClientID: "some-client-id",
RedirectURI: "https://test.example.com/callback",
State: "some-state",
Nonce: "some-nonce",
CodeChallenge: codeChallengeEncoded,
CodeChallengeMethod: "S256",
}
reqBodyBytes, err := json.Marshal(reqBody)
assert.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes)))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
var res map[string]any
err = json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err)
redirectURI := res["redirect_uri"].(string)
url, err := url.Parse(redirectURI)
assert.NoError(t, err)
queryParams := url.Query()
assert.Equal(t, queryParams.Get("state"), "some-state")
code := queryParams.Get("code")
assert.NotEmpty(t, code)
// Now exchange the code for a token
recorder = httptest.NewRecorder()
tokenReqBody := controller.TokenRequest{
GrantType: "authorization_code",
Code: code,
RedirectURI: "https://test.example.com/callback",
CodeVerifier: "some-challenge-1",
}
reqBodyEncoded, err := query.Values(tokenReqBody)
assert.NoError(t, err)
req = httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.SetBasicAuth("some-client-id", "some-client-secret")
router.ServeHTTP(recorder, req)
assert.Equal(t, 400, recorder.Code)
},
},
{
description: "Ensure request with invalid challenge method fails",
middlewares: []gin.HandlerFunc{
simpleCtx,
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
hasher := sha256.New()
hasher.Write([]byte("some-challenge"))
codeChallenge := hasher.Sum(nil)
codeChallengeEncoded := base64.RawURLEncoding.EncodeToString(codeChallenge)
reqBody := service.AuthorizeRequest{
Scope: "openid",
ResponseType: "code",
ClientID: "some-client-id",
RedirectURI: "https://test.example.com/callback",
State: "some-state",
Nonce: "some-nonce",
CodeChallenge: codeChallengeEncoded,
CodeChallengeMethod: "foo",
}
reqBodyBytes, err := json.Marshal(reqBody)
assert.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes)))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
var res map[string]any
err = json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err)
redirectURI := res["redirect_uri"].(string)
url, err := url.Parse(redirectURI)
assert.NoError(t, err)
queryParams := url.Query()
error := queryParams.Get("error")
assert.NotEmpty(t, error)
},
},
} }
app := bootstrap.NewBootstrapApp(config.Config{}) app := bootstrap.NewBootstrapApp(config.Config{})
+2 -1
View File
@@ -17,7 +17,6 @@ import (
) )
func TestProxyController(t *testing.T) { func TestProxyController(t *testing.T) {
tlog.NewTestLogger().Init()
tempDir := t.TempDir() tempDir := t.TempDir()
authServiceCfg := service.AuthServiceConfig{ authServiceCfg := service.AuthServiceConfig{
@@ -391,6 +390,8 @@ func TestProxyController(t *testing.T) {
}, },
} }
tlog.NewSimpleLogger().Init()
oauthBrokerCfgs := make(map[string]config.OAuthServiceConfig) oauthBrokerCfgs := make(map[string]config.OAuthServiceConfig)
app := bootstrap.NewBootstrapApp(config.Config{}) app := bootstrap.NewBootstrapApp(config.Config{})
@@ -8,13 +8,11 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/steveiliop56/tinyauth/internal/controller" "github.com/steveiliop56/tinyauth/internal/controller"
"github.com/steveiliop56/tinyauth/internal/utils/tlog"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func TestResourcesController(t *testing.T) { func TestResourcesController(t *testing.T) {
tlog.NewTestLogger().Init()
tempDir := t.TempDir() tempDir := t.TempDir()
resourcesControllerCfg := controller.ResourcesControllerConfig{ resourcesControllerCfg := controller.ResourcesControllerConfig{
+2 -1
View File
@@ -22,7 +22,6 @@ import (
) )
func TestUserController(t *testing.T) { func TestUserController(t *testing.T) {
tlog.NewTestLogger().Init()
tempDir := t.TempDir() tempDir := t.TempDir()
authServiceCfg := service.AuthServiceConfig{ authServiceCfg := service.AuthServiceConfig{
@@ -275,6 +274,8 @@ func TestUserController(t *testing.T) {
}, },
} }
tlog.NewSimpleLogger().Init()
oauthBrokerCfgs := make(map[string]config.OAuthServiceConfig) oauthBrokerCfgs := make(map[string]config.OAuthServiceConfig)
app := bootstrap.NewBootstrapApp(config.Config{}) app := bootstrap.NewBootstrapApp(config.Config{})
@@ -13,13 +13,11 @@ import (
"github.com/steveiliop56/tinyauth/internal/controller" "github.com/steveiliop56/tinyauth/internal/controller"
"github.com/steveiliop56/tinyauth/internal/repository" "github.com/steveiliop56/tinyauth/internal/repository"
"github.com/steveiliop56/tinyauth/internal/service" "github.com/steveiliop56/tinyauth/internal/service"
"github.com/steveiliop56/tinyauth/internal/utils/tlog"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func TestWellKnownController(t *testing.T) { func TestWellKnownController(t *testing.T) {
tlog.NewTestLogger().Init()
tempDir := t.TempDir() tempDir := t.TempDir()
oidcServiceCfg := service.OIDCServiceConfig{ oidcServiceCfg := service.OIDCServiceConfig{
@@ -24,7 +24,6 @@ var (
"GET /api/oidc/clients", "GET /api/oidc/clients",
"POST /api/oidc/token", "POST /api/oidc/token",
"GET /api/oidc/userinfo", "GET /api/oidc/userinfo",
"POST /api/oidc/userinfo",
"GET /resources", "GET /resources",
"POST /api/user/login", "POST /api/user/login",
"GET /.well-known/openid-configuration", "GET /.well-known/openid-configuration",
+9 -8
View File
@@ -5,14 +5,15 @@
package repository package repository
type OidcCode struct { type OidcCode struct {
Sub string Sub string
CodeHash string CodeHash string
Scope string Scope string
RedirectURI string RedirectURI string
ClientID string ClientID string
ExpiresAt int64 ExpiresAt int64
Nonce string Nonce string
CodeChallenge string CodeChallenge string
CodeChallengeMethod string
} }
type OidcToken struct { type OidcToken struct {
+25 -16
View File
@@ -18,22 +18,24 @@ INSERT INTO "oidc_codes" (
"client_id", "client_id",
"expires_at", "expires_at",
"nonce", "nonce",
"code_challenge" "code_challenge",
"code_challenge_method"
) VALUES ( ) VALUES (
?, ?, ?, ?, ?, ?, ?, ? ?, ?, ?, ?, ?, ?, ?, ?, ?
) )
RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge, code_challenge_method
` `
type CreateOidcCodeParams struct { type CreateOidcCodeParams struct {
Sub string Sub string
CodeHash string CodeHash string
Scope string Scope string
RedirectURI string RedirectURI string
ClientID string ClientID string
ExpiresAt int64 ExpiresAt int64
Nonce string Nonce string
CodeChallenge string CodeChallenge string
CodeChallengeMethod string
} }
func (q *Queries) CreateOidcCode(ctx context.Context, arg CreateOidcCodeParams) (OidcCode, error) { func (q *Queries) CreateOidcCode(ctx context.Context, arg CreateOidcCodeParams) (OidcCode, error) {
@@ -46,6 +48,7 @@ func (q *Queries) CreateOidcCode(ctx context.Context, arg CreateOidcCodeParams)
arg.ExpiresAt, arg.ExpiresAt,
arg.Nonce, arg.Nonce,
arg.CodeChallenge, arg.CodeChallenge,
arg.CodeChallengeMethod,
) )
var i OidcCode var i OidcCode
err := row.Scan( err := row.Scan(
@@ -57,6 +60,7 @@ func (q *Queries) CreateOidcCode(ctx context.Context, arg CreateOidcCodeParams)
&i.ExpiresAt, &i.ExpiresAt,
&i.Nonce, &i.Nonce,
&i.CodeChallenge, &i.CodeChallenge,
&i.CodeChallengeMethod,
) )
return i, err return i, err
} }
@@ -160,7 +164,7 @@ func (q *Queries) CreateOidcUserInfo(ctx context.Context, arg CreateOidcUserInfo
const deleteExpiredOidcCodes = `-- name: DeleteExpiredOidcCodes :many const deleteExpiredOidcCodes = `-- name: DeleteExpiredOidcCodes :many
DELETE FROM "oidc_codes" DELETE FROM "oidc_codes"
WHERE "expires_at" < ? WHERE "expires_at" < ?
RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge, code_challenge_method
` `
func (q *Queries) DeleteExpiredOidcCodes(ctx context.Context, expiresAt int64) ([]OidcCode, error) { func (q *Queries) DeleteExpiredOidcCodes(ctx context.Context, expiresAt int64) ([]OidcCode, error) {
@@ -181,6 +185,7 @@ func (q *Queries) DeleteExpiredOidcCodes(ctx context.Context, expiresAt int64) (
&i.ExpiresAt, &i.ExpiresAt,
&i.Nonce, &i.Nonce,
&i.CodeChallenge, &i.CodeChallenge,
&i.CodeChallengeMethod,
); err != nil { ); err != nil {
return nil, err return nil, err
} }
@@ -291,7 +296,7 @@ func (q *Queries) DeleteOidcUserInfo(ctx context.Context, sub string) error {
const getOidcCode = `-- name: GetOidcCode :one const getOidcCode = `-- name: GetOidcCode :one
DELETE FROM "oidc_codes" DELETE FROM "oidc_codes"
WHERE "code_hash" = ? WHERE "code_hash" = ?
RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge, code_challenge_method
` `
func (q *Queries) GetOidcCode(ctx context.Context, codeHash string) (OidcCode, error) { func (q *Queries) GetOidcCode(ctx context.Context, codeHash string) (OidcCode, error) {
@@ -306,6 +311,7 @@ func (q *Queries) GetOidcCode(ctx context.Context, codeHash string) (OidcCode, e
&i.ExpiresAt, &i.ExpiresAt,
&i.Nonce, &i.Nonce,
&i.CodeChallenge, &i.CodeChallenge,
&i.CodeChallengeMethod,
) )
return i, err return i, err
} }
@@ -313,7 +319,7 @@ func (q *Queries) GetOidcCode(ctx context.Context, codeHash string) (OidcCode, e
const getOidcCodeBySub = `-- name: GetOidcCodeBySub :one const getOidcCodeBySub = `-- name: GetOidcCodeBySub :one
DELETE FROM "oidc_codes" DELETE FROM "oidc_codes"
WHERE "sub" = ? WHERE "sub" = ?
RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge, code_challenge_method
` `
func (q *Queries) GetOidcCodeBySub(ctx context.Context, sub string) (OidcCode, error) { func (q *Queries) GetOidcCodeBySub(ctx context.Context, sub string) (OidcCode, error) {
@@ -328,12 +334,13 @@ func (q *Queries) GetOidcCodeBySub(ctx context.Context, sub string) (OidcCode, e
&i.ExpiresAt, &i.ExpiresAt,
&i.Nonce, &i.Nonce,
&i.CodeChallenge, &i.CodeChallenge,
&i.CodeChallengeMethod,
) )
return i, err return i, err
} }
const getOidcCodeBySubUnsafe = `-- name: GetOidcCodeBySubUnsafe :one const getOidcCodeBySubUnsafe = `-- name: GetOidcCodeBySubUnsafe :one
SELECT sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge FROM "oidc_codes" SELECT sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge, code_challenge_method FROM "oidc_codes"
WHERE "sub" = ? WHERE "sub" = ?
` `
@@ -349,12 +356,13 @@ func (q *Queries) GetOidcCodeBySubUnsafe(ctx context.Context, sub string) (OidcC
&i.ExpiresAt, &i.ExpiresAt,
&i.Nonce, &i.Nonce,
&i.CodeChallenge, &i.CodeChallenge,
&i.CodeChallengeMethod,
) )
return i, err return i, err
} }
const getOidcCodeUnsafe = `-- name: GetOidcCodeUnsafe :one const getOidcCodeUnsafe = `-- name: GetOidcCodeUnsafe :one
SELECT sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge FROM "oidc_codes" SELECT sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge, code_challenge_method FROM "oidc_codes"
WHERE "code_hash" = ? WHERE "code_hash" = ?
` `
@@ -370,6 +378,7 @@ func (q *Queries) GetOidcCodeUnsafe(ctx context.Context, codeHash string) (OidcC
&i.ExpiresAt, &i.ExpiresAt,
&i.Nonce, &i.Nonce,
&i.CodeChallenge, &i.CodeChallenge,
&i.CodeChallengeMethod,
) )
return i, err return i, err
} }
+15 -30
View File
@@ -28,26 +28,12 @@ const MaxOAuthPendingSessions = 256
const OAuthCleanupCount = 16 const OAuthCleanupCount = 16
const MaxLoginAttemptRecords = 256 const MaxLoginAttemptRecords = 256
// slightly modified version of the AuthorizeRequest from the OIDC service to basically accept all
// parameters and pass them to the authorize page if needed
type OAuthURLParams struct {
Scope string `form:"scope" url:"scope"`
ResponseType string `form:"response_type" url:"response_type"`
ClientID string `form:"client_id" url:"client_id"`
RedirectURI string `form:"redirect_uri" url:"redirect_uri"`
State string `form:"state" url:"state"`
Nonce string `form:"nonce" url:"nonce"`
CodeChallenge string `form:"code_challenge" url:"code_challenge"`
CodeChallengeMethod string `form:"code_challenge_method" url:"code_challenge_method"`
}
type OAuthPendingSession struct { type OAuthPendingSession struct {
State string State string
Verifier string Verifier string
Token *oauth2.Token Token *oauth2.Token
Service *OAuthServiceImpl Service *OAuthServiceImpl
ExpiresAt time.Time ExpiresAt time.Time
CallbackParams OAuthURLParams
} }
type LdapGroupsCache struct { type LdapGroupsCache struct {
@@ -612,7 +598,7 @@ func (auth *AuthService) IsBypassedIP(acls config.AppIP, ip string) bool {
return false return false
} }
func (auth *AuthService) NewOAuthSession(serviceName string, params OAuthURLParams) (string, OAuthPendingSession, error) { func (auth *AuthService) NewOAuthSession(serviceName string) (string, OAuthPendingSession, error) {
auth.ensureOAuthSessionLimit() auth.ensureOAuthSessionLimit()
service, ok := auth.oauthBroker.GetService(serviceName) service, ok := auth.oauthBroker.GetService(serviceName)
@@ -631,11 +617,10 @@ func (auth *AuthService) NewOAuthSession(serviceName string, params OAuthURLPara
verifier := service.NewRandom() verifier := service.NewRandom()
session := OAuthPendingSession{ session := OAuthPendingSession{
State: state, State: state,
Verifier: verifier, Verifier: verifier,
Service: &service, Service: &service,
ExpiresAt: time.Now().Add(1 * time.Hour), ExpiresAt: time.Now().Add(1 * time.Hour),
CallbackParams: params,
} }
auth.oauthMutex.Lock() auth.oauthMutex.Lock()
@@ -646,7 +631,7 @@ func (auth *AuthService) NewOAuthSession(serviceName string, params OAuthURLPara
} }
func (auth *AuthService) GetOAuthURL(sessionId string) (string, error) { func (auth *AuthService) GetOAuthURL(sessionId string) (string, error) {
session, err := auth.GetOAuthPendingSession(sessionId) session, err := auth.getOAuthPendingSession(sessionId)
if err != nil { if err != nil {
return "", err return "", err
@@ -656,7 +641,7 @@ func (auth *AuthService) GetOAuthURL(sessionId string) (string, error) {
} }
func (auth *AuthService) GetOAuthToken(sessionId string, code string) (*oauth2.Token, error) { func (auth *AuthService) GetOAuthToken(sessionId string, code string) (*oauth2.Token, error) {
session, err := auth.GetOAuthPendingSession(sessionId) session, err := auth.getOAuthPendingSession(sessionId)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -676,7 +661,7 @@ func (auth *AuthService) GetOAuthToken(sessionId string, code string) (*oauth2.T
} }
func (auth *AuthService) GetOAuthUserinfo(sessionId string) (config.Claims, error) { func (auth *AuthService) GetOAuthUserinfo(sessionId string) (config.Claims, error) {
session, err := auth.GetOAuthPendingSession(sessionId) session, err := auth.getOAuthPendingSession(sessionId)
if err != nil { if err != nil {
return config.Claims{}, err return config.Claims{}, err
@@ -696,7 +681,7 @@ func (auth *AuthService) GetOAuthUserinfo(sessionId string) (config.Claims, erro
} }
func (auth *AuthService) GetOAuthService(sessionId string) (OAuthServiceImpl, error) { func (auth *AuthService) GetOAuthService(sessionId string) (OAuthServiceImpl, error) {
session, err := auth.GetOAuthPendingSession(sessionId) session, err := auth.getOAuthPendingSession(sessionId)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -730,7 +715,7 @@ func (auth *AuthService) CleanupOAuthSessionsRoutine() {
} }
} }
func (auth *AuthService) GetOAuthPendingSession(sessionId string) (*OAuthPendingSession, error) { func (auth *AuthService) getOAuthPendingSession(sessionId string) (*OAuthPendingSession, error) {
auth.ensureOAuthSessionLimit() auth.ensureOAuthSessionLimit()
auth.oauthMutex.RLock() auth.oauthMutex.RLock()
+10 -4
View File
@@ -297,7 +297,7 @@ func (service *OIDCService) ValidateAuthorizeParams(req AuthorizeRequest) error
// PKCE code challenge method if set // PKCE code challenge method if set
if req.CodeChallenge != "" && req.CodeChallengeMethod != "" { if req.CodeChallenge != "" && req.CodeChallengeMethod != "" {
if req.CodeChallengeMethod != "S256" && req.CodeChallengeMethod != "plain" { if req.CodeChallengeMethod != "S256" || req.CodeChallenge == "plain" {
return errors.New("invalid_request") return errors.New("invalid_request")
} }
} }
@@ -329,8 +329,10 @@ func (service *OIDCService) StoreCode(c *gin.Context, sub string, code string, r
if req.CodeChallenge != "" { if req.CodeChallenge != "" {
if req.CodeChallengeMethod == "S256" { if req.CodeChallengeMethod == "S256" {
entry.CodeChallenge = req.CodeChallenge entry.CodeChallenge = req.CodeChallenge
entry.CodeChallengeMethod = "S256"
} else { } else {
entry.CodeChallenge = service.hashAndEncodePKCE(req.CodeChallenge) entry.CodeChallenge = service.hashAndEncodePKCE(req.CodeChallenge)
entry.CodeChallengeMethod = "plain"
tlog.App.Warn().Msg("Received plain PKCE code challenge, it's recommended to use S256 for better security") tlog.App.Warn().Msg("Received plain PKCE code challenge, it's recommended to use S256 for better security")
} }
} }
@@ -749,15 +751,19 @@ func (service *OIDCService) GetJWK() ([]byte, error) {
return jwk.Public().MarshalJSON() return jwk.Public().MarshalJSON()
} }
func (service *OIDCService) ValidatePKCE(codeChallenge string, codeVerifier string) bool { func (service *OIDCService) ValidatePKCE(codeChallenge string, codeChallengeMethod string, codeVerifier string) bool {
if codeChallenge == "" { if codeChallenge == "" {
return true return true
} }
return codeChallenge == service.hashAndEncodePKCE(codeVerifier) if codeChallengeMethod == "plain" {
// Code challenge is hashed and encoded in the database for security reasons
return codeChallenge == service.hashAndEncodePKCE(codeVerifier)
}
return codeChallenge == codeVerifier
} }
func (service *OIDCService) hashAndEncodePKCE(codeVerifier string) string { func (service *OIDCService) hashAndEncodePKCE(codeVerifier string) string {
hasher := sha256.New() hasher := sha256.New()
hasher.Write([]byte(codeVerifier)) hasher.Write([]byte(codeVerifier))
return base64.RawURLEncoding.EncodeToString(hasher.Sum(nil)) return base64.URLEncoding.EncodeToString(hasher.Sum(nil))
} }
-11
View File
@@ -55,17 +55,6 @@ func NewSimpleLogger() *Logger {
}) })
} }
func NewTestLogger() *Logger {
return NewLogger(config.LogConfig{
Level: "trace",
Streams: config.LogStreams{
HTTP: config.LogStreamConfig{Enabled: true},
App: config.LogStreamConfig{Enabled: true},
Audit: config.LogStreamConfig{Enabled: true},
},
})
}
func (l *Logger) Init() { func (l *Logger) Init() {
Audit = l.Audit Audit = l.Audit
HTTP = l.HTTP HTTP = l.HTTP
+3 -2
View File
@@ -7,9 +7,10 @@ INSERT INTO "oidc_codes" (
"client_id", "client_id",
"expires_at", "expires_at",
"nonce", "nonce",
"code_challenge" "code_challenge",
"code_challenge_method"
) VALUES ( ) VALUES (
?, ?, ?, ?, ?, ?, ?, ? ?, ?, ?, ?, ?, ?, ?, ?, ?
) )
RETURNING *; RETURNING *;
+2 -1
View File
@@ -6,7 +6,8 @@ CREATE TABLE IF NOT EXISTS "oidc_codes" (
"client_id" TEXT NOT NULL, "client_id" TEXT NOT NULL,
"expires_at" INTEGER NOT NULL, "expires_at" INTEGER NOT NULL,
"nonce" TEXT DEFAULT "", "nonce" TEXT DEFAULT "",
"code_challenge" TEXT DEFAULT "" "code_challenge" TEXT DEFAULT "",
"code_challenge_method" TEXT DEFAULT ""
); );
CREATE TABLE IF NOT EXISTS "oidc_tokens" ( CREATE TABLE IF NOT EXISTS "oidc_tokens" (
+2
View File
@@ -28,3 +28,5 @@ sql:
go_type: "string" go_type: "string"
- column: "oidc_codes.code_challenge" - column: "oidc_codes.code_challenge"
go_type: "string" go_type: "string"
- column: "oidc_codes.code_challenge_method"
go_type: "string"