fix: more review comments

This commit is contained in:
Stavros
2026-01-26 16:20:49 +02:00
parent e498ee4be0
commit fe391fc571
9 changed files with 270 additions and 57 deletions

View File

@@ -62,9 +62,18 @@
"goToCorrectDomainTitle": "Go to correct domain",
"authorizeTitle": "Authorize",
"authorizeCardTitle": "Continue to {{app}}?",
"authorizeSubtitle": "Would you like to continue to this app? Please keep in mind that this app will have access to your email and other information.",
"authorizeSubtitle": "Would you like to continue to this app? Please carefully review the permissions requested by the app.",
"authorizeSubtitleOAuth": "Would you like to continue to this app?",
"authorizeLoadingTitle": "Loading...",
"authorizeLoadingSubtitle": "Please wait while we load the client information.",
"authorizeSuccessTitle": "Authorized",
"authorizeSuccessSubtitle": "You will be redirected to the app in a few seconds."
"authorizeSuccessSubtitle": "You will be redirected to the app in a few seconds.",
"openidScopeName": "OpenID Connect",
"openidScopeDescription": "Allows the app to access your OpenID Connect information.",
"emailScopeName": "Email",
"emailScopeDescription": "Allows the app to access your email address.",
"profileScopeName": "Profile",
"profileScopeDescription": "Allows the app to access your profile information.",
"groupsScopeName": "Groups",
"groupsScopeDescription": "Allows the app to access the groups in which you are a member."
}

View File

@@ -62,9 +62,18 @@
"goToCorrectDomainTitle": "Go to correct domain",
"authorizeTitle": "Authorize",
"authorizeCardTitle": "Continue to {{app}}?",
"authorizeSubtitle": "Would you like to continue to this app? Please keep in mind that this app will have access to your email and other information.",
"authorizeSubtitle": "Would you like to continue to this app? Please carefully review the permissions requested by the app.",
"authorizeSubtitleOAuth": "Would you like to continue to this app?",
"authorizeLoadingTitle": "Loading...",
"authorizeLoadingSubtitle": "Please wait while we load the client information.",
"authorizeSuccessTitle": "Authorized",
"authorizeSuccessSubtitle": "You will be redirected to the app in a few seconds."
"authorizeSuccessSubtitle": "You will be redirected to the app in a few seconds.",
"openidScopeName": "OpenID Connect",
"openidScopeDescription": "Allows the app to access your OpenID Connect information.",
"emailScopeName": "Email",
"emailScopeDescription": "Allows the app to access your email address.",
"profileScopeName": "Profile",
"profileScopeDescription": "Allows the app to access your profile information.",
"groupsScopeName": "Groups",
"groupsScopeDescription": "Allows the app to access the groups in which you are a member."
}

View File

@@ -8,6 +8,7 @@ import {
CardTitle,
CardDescription,
CardFooter,
CardContent,
} from "@/components/ui/card";
import { getOidcClientInfoScehma } from "@/schemas/oidc-schemas";
import { Button } from "@/components/ui/button";
@@ -15,12 +16,55 @@ import axios from "axios";
import { toast } from "sonner";
import { useOIDCParams } from "@/lib/hooks/oidc";
import { useTranslation } from "react-i18next";
import { TFunction } from "i18next";
import { Mail, Shield, User, Users } from "lucide-react";
type Scope = {
id: string;
name: string;
description: string;
icon: React.ReactNode;
};
const scopeMapIconProps = {
className: "stroke-card stroke-2.5",
};
const createScopeMap = (t: TFunction<"translation", undefined>): Scope[] => {
return [
{
id: "openid",
name: t("openidScopeName"),
description: t("openidScopeDescription"),
icon: <Shield {...scopeMapIconProps} />,
},
{
id: "email",
name: t("emailScopeName"),
description: t("emailScopeDescription"),
icon: <Mail {...scopeMapIconProps} />,
},
{
id: "profile",
name: t("profileScopeName"),
description: t("profileScopeDescription"),
icon: <User {...scopeMapIconProps} />,
},
{
id: "groups",
name: t("groupsScopeName"),
description: t("groupsScopeDescription"),
icon: <Users {...scopeMapIconProps} />,
},
];
};
export const AuthorizePage = () => {
const { isLoggedIn } = useUserContext();
const { search } = useLocation();
const { t } = useTranslation();
const navigate = useNavigate();
const scopeMap = createScopeMap(t);
const searchParams = new URLSearchParams(search);
const {
@@ -29,6 +73,7 @@ export const AuthorizePage = () => {
isOidc,
compiled: compiledOIDCParams,
} = useOIDCParams(searchParams);
const scopes = props.scope.split(" ");
const getClientInfo = useQuery({
queryKey: ["client", props.client_id],
@@ -100,15 +145,40 @@ export const AuthorizePage = () => {
}
return (
<Card className="min-w-xs sm:min-w-sm">
<Card className="min-w-xs sm:min-w-sm mx-4">
<CardHeader>
<CardTitle className="text-3xl">
{t("authorizeCardTitle", {
app: getClientInfo.data?.name || "Unknown",
})}
</CardTitle>
<CardDescription>{t("authorizeSubtitle")}</CardDescription>
<CardDescription>
{scopes.includes("openid")
? t("authorizeSubtitle")
: t("authorizeSubtitleOAuth")}
</CardDescription>
</CardHeader>
{scopes.includes("openid") && (
<CardContent className="flex flex-col gap-4">
{scopes.map((id) => {
const scope = scopeMap.find((s) => s.id === id);
if (!scope) return null;
return (
<div key={scope.id} className="flex flex-row items-center gap-3">
<div className="p-2 flex flex-col items-center justify-center bg-card-foreground rounded-md">
{scope.icon}
</div>
<div className="flex flex-col gap-0.5">
<div className="text-md">{scope.name}</div>
<div className="text-sm text-muted-foreground">
{scope.description}
</div>
</div>
</div>
);
})}
</CardContent>
)}
<CardFooter className="flex flex-col items-stretch gap-2">
<Button
onClick={() => authorizeMutation.mutate()}

View File

@@ -13,7 +13,7 @@ import (
"gotest.tools/v3/assert"
)
var controllerCfg = controller.ContextControllerConfig{
var contextControllerCfg = controller.ContextControllerConfig{
Providers: []controller.Provider{
{
Name: "Local",
@@ -35,7 +35,7 @@ var controllerCfg = controller.ContextControllerConfig{
DisableUIWarnings: false,
}
var userContext = config.UserContext{
var contextCtrlTestContext = config.UserContext{
Username: "testuser",
Name: "testuser",
Email: "test@example.com",
@@ -65,7 +65,7 @@ func setupContextController(middlewares *[]gin.HandlerFunc) (*gin.Engine, *httpt
group := router.Group("/api")
ctrl := controller.NewContextController(controllerCfg, group)
ctrl := controller.NewContextController(contextControllerCfg, group)
ctrl.SetupRoutes()
return router, recorder
@@ -75,14 +75,14 @@ func TestAppContextHandler(t *testing.T) {
expectedRes := controller.AppContextResponse{
Status: 200,
Message: "Success",
Providers: controllerCfg.Providers,
Title: controllerCfg.Title,
AppURL: controllerCfg.AppURL,
CookieDomain: controllerCfg.CookieDomain,
ForgotPasswordMessage: controllerCfg.ForgotPasswordMessage,
BackgroundImage: controllerCfg.BackgroundImage,
OAuthAutoRedirect: controllerCfg.OAuthAutoRedirect,
DisableUIWarnings: controllerCfg.DisableUIWarnings,
Providers: contextControllerCfg.Providers,
Title: contextControllerCfg.Title,
AppURL: contextControllerCfg.AppURL,
CookieDomain: contextControllerCfg.CookieDomain,
ForgotPasswordMessage: contextControllerCfg.ForgotPasswordMessage,
BackgroundImage: contextControllerCfg.BackgroundImage,
OAuthAutoRedirect: contextControllerCfg.OAuthAutoRedirect,
DisableUIWarnings: contextControllerCfg.DisableUIWarnings,
}
router, recorder := setupContextController(nil)
@@ -103,20 +103,20 @@ func TestUserContextHandler(t *testing.T) {
expectedRes := controller.UserContextResponse{
Status: 200,
Message: "Success",
IsLoggedIn: userContext.IsLoggedIn,
Username: userContext.Username,
Name: userContext.Name,
Email: userContext.Email,
Provider: userContext.Provider,
OAuth: userContext.OAuth,
TotpPending: userContext.TotpPending,
OAuthName: userContext.OAuthName,
IsLoggedIn: contextCtrlTestContext.IsLoggedIn,
Username: contextCtrlTestContext.Username,
Name: contextCtrlTestContext.Name,
Email: contextCtrlTestContext.Email,
Provider: contextCtrlTestContext.Provider,
OAuth: contextCtrlTestContext.OAuth,
TotpPending: contextCtrlTestContext.TotpPending,
OAuthName: contextCtrlTestContext.OAuthName,
}
// Test with context
router, recorder := setupContextController(&[]gin.HandlerFunc{
func(c *gin.Context) {
c.Set("context", &userContext)
c.Set("context", &contextCtrlTestContext)
c.Next()
},
})

View File

@@ -114,7 +114,7 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
return
}
_, ok := controller.oidc.GetClient(req.ClientID)
client, ok := controller.oidc.GetClient(req.ClientID)
if !ok {
controller.authorizeError(c, err, "Client not found", "The client ID is invalid", "", "", "")
@@ -133,8 +133,8 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
return
}
// WARNING: Since Tinyauth is stateless, we cannot have a sub that never changes. We will just create a uuid out of the username which remains stable, but if username changes then sub changes too.
sub := utils.GenerateUUID(userContext.Username)
// WARNING: Since Tinyauth is stateless, we cannot have a sub that never changes. We will just create a uuid out of the username and client name which remains stable, but if username or client name changes then sub changes too.
sub := utils.GenerateUUID(fmt.Sprintf("%s:%s", userContext.Username, client.ID))
code := rand.Text()
// Before storing the code, delete old session
@@ -272,16 +272,6 @@ func (controller *OIDCController) Token(c *gin.Context) {
return
}
err = controller.oidc.DeleteCodeEntry(c, entry.CodeHash)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to delete code in database")
c.JSON(400, gin.H{
"error": "server_error",
})
return
}
tokenResponse = tokenRes
case "refresh_token":
client, ok := controller.oidc.GetClient(req.ClientID)

View File

@@ -20,7 +20,7 @@ import (
"gotest.tools/v3/assert"
)
var serviceConfig = service.OIDCServiceConfig{
var oidcServiceConfig = service.OIDCServiceConfig{
Clients: map[string]config.OIDCClientConfig{
"client1": {
ClientID: "some-client-id",
@@ -38,7 +38,7 @@ var serviceConfig = service.OIDCServiceConfig{
SessionExpiry: 3600,
}
var oidcTestContext = config.UserContext{
var oidcCtrlTestContext = config.UserContext{
Username: "test",
Name: "Test",
Email: "test@example.com",
@@ -69,7 +69,7 @@ func TestOIDCController(t *testing.T) {
queries := repository.New(db)
// Create a new OIDC Servicee
oidcService := service.NewOIDCService(serviceConfig, queries)
oidcService := service.NewOIDCService(oidcServiceConfig, queries)
err = oidcService.Init()
assert.NilError(t, err)
@@ -78,7 +78,7 @@ func TestOIDCController(t *testing.T) {
router := gin.Default()
router.Use(func(c *gin.Context) {
c.Set("context", &oidcTestContext)
c.Set("context", &oidcCtrlTestContext)
c.Next()
})
@@ -137,6 +137,8 @@ func TestOIDCController(t *testing.T) {
req, err = http.NewRequest("POST", "/api/oidc/token", strings.NewReader(params.Encode()))
assert.NilError(t, err)
req.Header.Set("content-type", "application/x-www-form-urlencoded")
req.SetBasicAuth("some-client-id", "some-client-secret")
@@ -154,12 +156,31 @@ func TestOIDCController(t *testing.T) {
_, ok = resJson["id_token"].(string)
assert.Assert(t, ok)
_, ok = resJson["refresh_token"].(string)
refreshToken, ok := resJson["refresh_token"].(string)
assert.Assert(t, ok)
expires_in, ok := resJson["expires_in"].(float64)
assert.Assert(t, ok)
assert.Equal(t, expires_in, float64(serviceConfig.SessionExpiry))
assert.Equal(t, expires_in, float64(oidcServiceConfig.SessionExpiry))
// Ensure code is expired
recorder = httptest.NewRecorder()
params, err = query.Values(controller.TokenRequest{
GrantType: "authorization_code",
Code: code,
RedirectURI: "https://example.com/oauth/callback",
})
assert.NilError(t, err)
req, err = http.NewRequest("POST", "/api/oidc/token", strings.NewReader(params.Encode()))
req.Header.Set("content-type", "application/x-www-form-urlencoded")
req.SetBasicAuth("some-client-id", "some-client-secret")
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusBadRequest, recorder.Code)
// Test userinfo
recorder = httptest.NewRecorder()
@@ -182,18 +203,81 @@ func TestOIDCController(t *testing.T) {
name, ok := resJson["name"].(string)
assert.Assert(t, ok)
assert.Equal(t, name, oidcTestContext.Name)
assert.Equal(t, name, oidcCtrlTestContext.Name)
email, ok := resJson["email"].(string)
assert.Assert(t, ok)
assert.Equal(t, email, oidcTestContext.Email)
assert.Equal(t, email, oidcCtrlTestContext.Email)
preferred_username, ok := resJson["preferred_username"].(string)
assert.Assert(t, ok)
assert.Equal(t, preferred_username, oidcTestContext.Username)
assert.Equal(t, preferred_username, oidcCtrlTestContext.Username)
// Not sure why this is failing, will look into it later
// groups, ok := resJson["groups"].([]string)
// assert.Assert(t, ok)
// assert.Equal(t, strings.Split(oidcTestContext.LdapGroups, ","), groups)
igroups, ok := resJson["groups"].([]any)
assert.Assert(t, ok)
groups := make([]string, len(igroups))
for i, group := range igroups {
groups[i], ok = group.(string)
assert.Assert(t, ok)
}
assert.DeepEqual(t, strings.Split(oidcCtrlTestContext.LdapGroups, ","), groups)
// Test refresh token
recorder = httptest.NewRecorder()
params, err = query.Values(controller.TokenRequest{
GrantType: "refresh_token",
RefreshToken: refreshToken,
ClientID: "some-client-id",
ClientSecret: "some-client-secret",
})
if err != nil {
t.Fatal(err)
}
req, err = http.NewRequest("POST", "/api/oidc/token", strings.NewReader(params.Encode()))
if err != nil {
t.Fatal(err)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusOK, recorder.Code)
resJson = map[string]any{}
err = json.Unmarshal(recorder.Body.Bytes(), &resJson)
if err != nil {
t.Fatal(err)
}
newToken, ok := resJson["access_token"].(string)
assert.Assert(t, ok)
assert.Assert(t, newToken != accessToken)
// Ensure old token is invalid
recorder = httptest.NewRecorder()
req, err = http.NewRequest("GET", "/api/oidc/userinfo", nil)
if err != nil {
t.Fatal(err)
}
req.Header.Set("authorization", fmt.Sprintf("Bearer %s", accessToken))
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusUnauthorized, recorder.Code)
// Test new token
recorder = httptest.NewRecorder()
req, err = http.NewRequest("GET", "/api/oidc/userinfo", nil)
if err != nil {
t.Fatal(err)
}
req.Header.Set("authorization", fmt.Sprintf("Bearer %s", newToken))
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusOK, recorder.Code)
}

View File

@@ -274,8 +274,9 @@ func (q *Queries) DeleteOidcUserInfo(ctx context.Context, sub string) error {
}
const getOidcCode = `-- name: GetOidcCode :one
SELECT sub, code_hash, scope, redirect_uri, client_id, expires_at FROM "oidc_codes"
DELETE FROM "oidc_codes"
WHERE "code_hash" = ?
RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at
`
func (q *Queries) GetOidcCode(ctx context.Context, codeHash string) (OidcCode, error) {
@@ -293,8 +294,9 @@ func (q *Queries) GetOidcCode(ctx context.Context, codeHash string) (OidcCode, e
}
const getOidcCodeBySub = `-- name: GetOidcCodeBySub :one
SELECT sub, code_hash, scope, redirect_uri, client_id, expires_at FROM "oidc_codes"
DELETE FROM "oidc_codes"
WHERE "sub" = ?
RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at
`
func (q *Queries) GetOidcCodeBySub(ctx context.Context, sub string) (OidcCode, error) {
@@ -311,6 +313,44 @@ func (q *Queries) GetOidcCodeBySub(ctx context.Context, sub string) (OidcCode, e
return i, err
}
const getOidcCodeBySubUnsafe = `-- name: GetOidcCodeBySubUnsafe :one
SELECT sub, code_hash, scope, redirect_uri, client_id, expires_at FROM "oidc_codes"
WHERE "sub" = ?
`
func (q *Queries) GetOidcCodeBySubUnsafe(ctx context.Context, sub string) (OidcCode, error) {
row := q.db.QueryRowContext(ctx, getOidcCodeBySubUnsafe, sub)
var i OidcCode
err := row.Scan(
&i.Sub,
&i.CodeHash,
&i.Scope,
&i.RedirectURI,
&i.ClientID,
&i.ExpiresAt,
)
return i, err
}
const getOidcCodeUnsafe = `-- name: GetOidcCodeUnsafe :one
SELECT sub, code_hash, scope, redirect_uri, client_id, expires_at FROM "oidc_codes"
WHERE "code_hash" = ?
`
func (q *Queries) GetOidcCodeUnsafe(ctx context.Context, codeHash string) (OidcCode, error) {
row := q.db.QueryRowContext(ctx, getOidcCodeUnsafe, codeHash)
var i OidcCode
err := row.Scan(
&i.Sub,
&i.CodeHash,
&i.Scope,
&i.RedirectURI,
&i.ClientID,
&i.ExpiresAt,
)
return i, err
}
const getOidcToken = `-- name: GetOidcToken :one
SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at FROM "oidc_tokens"
WHERE "access_token_hash" = ?

View File

@@ -383,6 +383,7 @@ func (service *OIDCService) GenerateAccessToken(c *gin.Context, client config.OI
Sub: sub,
AccessTokenHash: service.Hash(accessToken),
RefreshTokenHash: service.Hash(refreshToken),
ClientID: client.ClientID,
Scope: scope,
TokenExpiresAt: tokenExpiresAt,
RefreshTokenExpiresAt: refrshTokenExpiresAt,
@@ -425,7 +426,7 @@ func (service *OIDCService) RefreshAccessToken(c *gin.Context, refreshToken stri
tokenResponse := TokenResponse{
AccessToken: accessToken,
RefreshToken: refreshToken,
RefreshToken: newRefreshToken,
TokenType: "Bearer",
ExpiresIn: int64(service.config.SessionExpiry),
IDToken: idToken,
@@ -586,7 +587,7 @@ func (service *OIDCService) Cleanup() {
}
if token.TokenExpiresAt < currentTime && token.RefreshTokenExpiresAt < currentTime {
err := service.queries.DeleteSession(ctx, expiredCode.Sub)
err := service.DeleteOldSession(ctx, expiredCode.Sub)
if err != nil {
tlog.App.Warn().Err(err).Msg("Failed to delete session")
}

View File

@@ -11,14 +11,24 @@ INSERT INTO "oidc_codes" (
)
RETURNING *;
-- name: GetOidcCode :one
-- name: GetOidcCodeUnsafe :one
SELECT * FROM "oidc_codes"
WHERE "code_hash" = ?;
-- name: GetOidcCodeBySub :one
-- name: GetOidcCode :one
DELETE FROM "oidc_codes"
WHERE "code_hash" = ?
RETURNING *;
-- name: GetOidcCodeBySubUnsafe :one
SELECT * FROM "oidc_codes"
WHERE "sub" = ?;
-- name: GetOidcCodeBySub :one
DELETE FROM "oidc_codes"
WHERE "sub" = ?
RETURNING *;
-- name: DeleteOidcCode :exec
DELETE FROM "oidc_codes"
WHERE "code_hash" = ?;