Compare commits

...

26 Commits

Author SHA1 Message Date
Stavros
627fd05d71 i18n: authorize page error messages 2026-02-01 18:58:03 +02:00
Stavros
fb705eaf07 fix: final review comments 2026-02-01 18:47:18 +02:00
Stavros
673f556fb3 fix: more rabbit nitpicks 2026-02-01 00:16:58 +02:00
Stavros
01e491c3be i18n: fix typo 2026-01-29 16:26:02 +02:00
Stavros
63fcc654f0 feat: jwk endpoint 2026-01-29 15:44:26 +02:00
Stavros
a8f57e584e feat: openid discovery endpoint 2026-01-26 19:50:15 +02:00
Stavros
328064946b refactor: rework oidc error messages 2026-01-26 19:03:20 +02:00
Stavros
fe391fc571 fix: more review comments 2026-01-26 16:20:49 +02:00
Stavros
e498ee4be0 tests: add basic testing 2026-01-25 20:45:56 +02:00
Stavros
9cbcd62c6e fix: fix typo in error screen 2026-01-25 20:04:20 +02:00
Stavros
fae1345a06 feat: frontend i18n 2026-01-25 19:54:39 +02:00
Stavros
8dd731b21e feat: cleanup expired oidc sessions 2026-01-25 19:45:17 +02:00
Stavros
46f25aaa38 feat: refresh token grant type support 2026-01-25 19:15:57 +02:00
Stavros
8af233b78d fix: oidc review comments 2026-01-25 18:32:14 +02:00
Stavros
cf1a613229 fix: review comments 2026-01-24 16:16:26 +02:00
Stavros
71bc3966bc feat: adapt frontend to oidc flow 2026-01-24 15:52:22 +02:00
Stavros
c817e353f6 refactor: implement oidc following tinyauth patterns 2026-01-24 14:31:03 +02:00
Stavros
97e90ea560 feat: implement basic oidc functionality 2026-01-22 22:30:23 +02:00
Stavros
6ae7c1cbda wip: authorize page 2026-01-21 20:12:32 +02:00
Stavros
7dc3525a8d chore: add oidc base config 2026-01-21 18:54:00 +02:00
Stavros
402dfa727b chore: update traefik and add use infisical as an options for secrets in
dev
2026-01-21 12:50:03 +02:00
Stavros
d67c3ab8a4 fix: ensure safe redirect check only accepts actual domains 2026-01-17 20:36:42 +02:00
Stavros
87e2b52a04 fix: set gin mode correctly 2026-01-17 20:26:48 +02:00
dependabot[bot]
f36b62561a chore(deps): bump modernc.org/sqlite in the minor-patch group (#588)
Bumps the minor-patch group with 1 update: [modernc.org/sqlite](https://gitlab.com/cznic/sqlite).


Updates `modernc.org/sqlite` from 1.44.0 to 1.44.1
- [Commits](https://gitlab.com/cznic/sqlite/compare/v1.44.0...v1.44.1)

---
updated-dependencies:
- dependency-name: modernc.org/sqlite
  dependency-version: 1.44.1
  dependency-type: direct:production
  update-type: version-update:semver-patch
  dependency-group: minor-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-01-17 20:04:08 +02:00
dependabot[bot]
d2a146ead0 chore(deps-dev): bump @types/node in /frontend in the minor-patch group (#589)
Bumps the minor-patch group in /frontend with 1 update: [@types/node](https://github.com/DefinitelyTyped/DefinitelyTyped/tree/HEAD/types/node).


Updates `@types/node` from 25.0.8 to 25.0.9
- [Release notes](https://github.com/DefinitelyTyped/DefinitelyTyped/releases)
- [Commits](https://github.com/DefinitelyTyped/DefinitelyTyped/commits/HEAD/types/node)

---
updated-dependencies:
- dependency-name: "@types/node"
  dependency-version: 25.0.9
  dependency-type: direct:development
  update-type: version-update:semver-patch
  dependency-group: minor-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-01-17 20:03:51 +02:00
Stavros
4926e53409 feat: ldap group acls (#590)
* wip

* refactor: remove useless session struct abstraction

* feat: retrieve and store groups from ldap provider

* chore: fix merge issue

* refactor: rework ldap group fetching logic

* feat: store ldap group results in cache

* fix: review nitpicks

* fix: review feedback
2026-01-17 20:03:29 +02:00
58 changed files with 2903 additions and 180 deletions

View File

@@ -89,7 +89,7 @@ jobs:
- name: Build
run: |
cp -r frontend/dist internal/assets/dist
go build -ldflags "-s -w -X tinyauth/internal/config.Version=${{ needs.generate-metadata.outputs.VERSION }} -X tinyauth/internal/config.CommitHash=${{ needs.generate-metadata.outputs.COMMIT_HASH }} -X tinyauth/internal/config.BuildTimestamp=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}" -o tinyauth-amd64 ./cmd/tinyauth
go build -ldflags "-s -w -X github.com/steveiliop56/tinyauth/internal/config.Version=${{ needs.generate-metadata.outputs.VERSION }} -X github.com/steveiliop56/tinyauth/internal/config.CommitHash=${{ needs.generate-metadata.outputs.COMMIT_HASH }} -X github.com/steveiliop56/tinyauth/internal/config.BuildTimestamp=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}" -o tinyauth-amd64 ./cmd/tinyauth
env:
CGO_ENABLED: 0
@@ -144,7 +144,7 @@ jobs:
- name: Build
run: |
cp -r frontend/dist internal/assets/dist
go build -ldflags "-s -w -X tinyauth/internal/config.Version=${{ needs.generate-metadata.outputs.VERSION }} -X tinyauth/internal/config.CommitHash=${{ needs.generate-metadata.outputs.COMMIT_HASH }} -X tinyauth/internal/config.BuildTimestamp=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}" -o tinyauth-arm64 ./cmd/tinyauth
go build -ldflags "-s -w -X github.com/steveiliop56/tinyauth/internal/config.Version=${{ needs.generate-metadata.outputs.VERSION }} -X github.com/steveiliop56/tinyauth/internal/config.CommitHash=${{ needs.generate-metadata.outputs.COMMIT_HASH }} -X github.com/steveiliop56/tinyauth/internal/config.BuildTimestamp=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}" -o tinyauth-arm64 ./cmd/tinyauth
env:
CGO_ENABLED: 0

View File

@@ -67,7 +67,7 @@ jobs:
- name: Build
run: |
cp -r frontend/dist internal/assets/dist
go build -ldflags "-s -w -X tinyauth/internal/config.Version=${{ needs.generate-metadata.outputs.VERSION }} -X tinyauth/internal/config.CommitHash=${{ needs.generate-metadata.outputs.COMMIT_HASH }} -X tinyauth/internal/config.BuildTimestamp=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}" -o tinyauth-amd64 ./cmd/tinyauth
go build -ldflags "-s -w -X github.com/steveiliop56/tinyauth/internal/config.Version=${{ needs.generate-metadata.outputs.VERSION }} -X github.com/steveiliop56/tinyauth/internal/config.CommitHash=${{ needs.generate-metadata.outputs.COMMIT_HASH }} -X github.com/steveiliop56/tinyauth/internal/config.BuildTimestamp=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}" -o tinyauth-amd64 ./cmd/tinyauth
env:
CGO_ENABLED: 0
@@ -119,7 +119,7 @@ jobs:
- name: Build
run: |
cp -r frontend/dist internal/assets/dist
go build -ldflags "-s -w -X tinyauth/internal/config.Version=${{ needs.generate-metadata.outputs.VERSION }} -X tinyauth/internal/config.CommitHash=${{ needs.generate-metadata.outputs.COMMIT_HASH }} -X tinyauth/internal/config.BuildTimestamp=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}" -o tinyauth-arm64 ./cmd/tinyauth
go build -ldflags "-s -w -X github.com/steveiliop56/tinyauth/internal/config.Version=${{ needs.generate-metadata.outputs.VERSION }} -X github.com/steveiliop56/tinyauth/internal/config.CommitHash=${{ needs.generate-metadata.outputs.COMMIT_HASH }} -X github.com/steveiliop56/tinyauth/internal/config.BuildTimestamp=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}" -o tinyauth-arm64 ./cmd/tinyauth
env:
CGO_ENABLED: 0

3
.gitignore vendored
View File

@@ -36,3 +36,6 @@
# debug files
__debug_*
# infisical
/.infisical.json

13
.zed/debug.json Normal file
View File

@@ -0,0 +1,13 @@
[
{
"label": "Attach to remote Delve",
"adapter": "Delve",
"mode": "remote",
"remotePath": "/tinyauth",
"request": "attach",
"tcp_connection": {
"host": "127.0.0.1",
"port": 4000,
},
},
]

View File

@@ -39,7 +39,10 @@ COPY ./cmd ./cmd
COPY ./internal ./internal
COPY --from=frontend-builder /frontend/dist ./internal/assets/dist
RUN CGO_ENABLED=0 go build -ldflags "-s -w -X tinyauth/internal/config.Version=${VERSION} -X tinyauth/internal/config.CommitHash=${COMMIT_HASH} -X tinyauth/internal/config.BuildTimestamp=${BUILD_TIMESTAMP}" ./cmd/tinyauth
RUN CGO_ENABLED=0 go build -ldflags "-s -w \
-X github.com/steveiliop56/tinyauth/internal/config.Version=${VERSION} \
-X github.com/steveiliop56/tinyauth/internal/config.CommitHash=${COMMIT_HASH} \
-X github.com/steveiliop56/tinyauth/internal/config.BuildTimestamp=${BUILD_TIMESTAMP}" ./cmd/tinyauth
# Runner
FROM alpine:3.23 AS runner
@@ -54,11 +57,9 @@ EXPOSE 3000
VOLUME ["/data"]
ENV DATABASEPATH=/data/tinyauth.db
ENV TINYAUTH_DATABASEPATH=/data/tinyauth.db
ENV RESOURCESDIR=/data/resources
ENV GIN_MODE=release
ENV TINYAUTH_RESOURCESDIR=/data/resources
ENV PATH=$PATH:/tinyauth

View File

@@ -41,7 +41,10 @@ COPY --from=frontend-builder /frontend/dist ./internal/assets/dist
RUN mkdir -p data
RUN CGO_ENABLED=0 go build -ldflags "-s -w -X tinyauth/internal/config.Version=${VERSION} -X tinyauth/internal/config.CommitHash=${COMMIT_HASH} -X tinyauth/internal/config.BuildTimestamp=${BUILD_TIMESTAMP}" ./cmd/tinyauth
RUN CGO_ENABLED=0 go build -ldflags "-s -w \
-X github.com/steveiliop56/tinyauth/internal/config.Version=${VERSION} \
-X github.com/steveiliop56/tinyauth/internal/config.CommitHash=${COMMIT_HASH} \
-X github.com/steveiliop56/tinyauth/internal/config.BuildTimestamp=${BUILD_TIMESTAMP}" ./cmd/tinyauth
# Runner
FROM gcr.io/distroless/static-debian12:latest AS runner
@@ -61,8 +64,6 @@ ENV TINYAUTH_DATABASEPATH=/data/tinyauth.db
ENV TINYAUTH_RESOURCESDIR=/data/resources
ENV GIN_MODE=release
ENV PATH=$PATH:/tinyauth
HEALTHCHECK --interval=30s --timeout=5s --start-period=5s --retries=3 CMD ["tinyauth", "healthcheck"]

View File

@@ -18,6 +18,10 @@ deps:
bun install --cwd frontend
go mod download
# Clean data
clean-data:
rm -rf data/
# Clean web UI build
clean-webui:
rm -rf internal/assets/dist
@@ -31,9 +35,9 @@ webui: clean-webui
# Build the binary
binary: webui
CGO_ENABLED=$(CGO_ENABLED) go build -ldflags "-s -w \
-X tinyauth/internal/config.Version=${TAG_NAME} \
-X tinyauth/internal/config.CommitHash=${COMMIT_HASH} \
-X tinyauth/internal/config.BuildTimestamp=${BUILD_TIMESTAMP}" \
-X github.com/steveiliop56/tinyauth/internal/config.Version=${TAG_NAME} \
-X github.com/steveiliop56/tinyauth/internal/config.CommitHash=${COMMIT_HASH} \
-X github.com/steveiliop56/tinyauth/internal/config.BuildTimestamp=${BUILD_TIMESTAMP}" \
-o ${BIN_NAME} ./cmd/tinyauth
# Build for amd64
@@ -57,8 +61,17 @@ test:
# Development
develop:
docker compose -f $(DEV_COMPOSE) up --force-recreate --pull=always --remove-orphans
docker compose -f $(DEV_COMPOSE) up --force-recreate --pull=always --remove-orphans --build
# Development - Infisical
develop-infisical:
infisical run --env=dev -- docker compose -f $(DEV_COMPOSE) up --force-recreate --pull=always --remove-orphans --build
# Production
prod:
docker compose -f $(PROD_COMPOSE) up --force-recreate --pull=always --remove-orphans
# SQL
.PHONY: sql
sql:
sqlc generate

View File

@@ -21,9 +21,9 @@ func NewTinyauthCmdConfiguration() *config.Config {
Address: "0.0.0.0",
},
Auth: config.AuthConfig{
SessionExpiry: 3600,
SessionMaxLifetime: 0,
LoginTimeout: 300,
SessionExpiry: 86400, // 1 day
SessionMaxLifetime: 0, // disabled
LoginTimeout: 300, // 5 minutes
LoginMaxRetries: 3,
},
UI: config.UIConfig{
@@ -32,8 +32,9 @@ func NewTinyauthCmdConfiguration() *config.Config {
BackgroundImage: "/background.jpg",
},
Ldap: config.LdapConfig{
Insecure: false,
SearchFilter: "(uid=%s)",
Insecure: false,
SearchFilter: "(uid=%s)",
GroupCacheTTL: 900, // 15 minutes
},
Log: config.LogConfig{
Level: "info",
@@ -53,6 +54,10 @@ func NewTinyauthCmdConfiguration() *config.Config {
},
},
},
OIDC: config.OIDCConfig{
PrivateKeyPath: "./tinyauth_oidc_key",
PublicKeyPath: "./tinyauth_oidc_key.pub",
},
Experimental: config.ExperimentalConfig{
ConfigFile: "",
},

View File

@@ -1,7 +1,7 @@
services:
traefik:
container_name: traefik
image: traefik:v3.3
image: traefik:v3.6
command: --api.insecure=true --providers.docker
ports:
- 80:80
@@ -50,3 +50,4 @@ services:
labels:
traefik.enable: true
traefik.http.middlewares.tinyauth.forwardauth.address: http://tinyauth-backend:3000/api/auth/traefik
traefik.http.middlewares.tinyauth.forwardauth.authResponseHeaders: remote-user, remote-sub, remote-name, remote-email, remote-groups

View File

@@ -1,7 +1,7 @@
services:
traefik:
container_name: traefik
image: traefik:v3.3
image: traefik:v3.6
command: --api.insecure=true --providers.docker
ports:
- 80:80

View File

@@ -36,7 +36,7 @@
"devDependencies": {
"@eslint/js": "^9.39.2",
"@tanstack/eslint-plugin-query": "^5.91.2",
"@types/node": "^25.0.8",
"@types/node": "^25.0.9",
"@types/react": "^19.2.8",
"@types/react-dom": "^19.2.3",
"@vitejs/plugin-react": "^5.1.2",
@@ -365,7 +365,7 @@
"@types/ms": ["@types/ms@2.1.0", "", {}, "sha512-GsCCIZDE/p3i96vtEqx+7dBUGXrc7zeSK3wwPHIaRThS+9OhWIXRqzs4d6k1SVU8g91DrNRWxWUGhp5KXQb2VA=="],
"@types/node": ["@types/node@25.0.8", "", { "dependencies": { "undici-types": "~7.16.0" } }, "sha512-powIePYMmC3ibL0UJ2i2s0WIbq6cg6UyVFQxSCpaPxxzAaziRfimGivjdF943sSGV6RADVbk0Nvlm5P/FB44Zg=="],
"@types/node": ["@types/node@25.0.9", "", { "dependencies": { "undici-types": "~7.16.0" } }, "sha512-/rpCXHlCWeqClNBwUhDcusJxXYDjZTyE8v5oTO7WbL8eij2nKhUeU89/6xgjU7N4/Vh3He0BtyhJdQbDyhiXAw=="],
"@types/react": ["@types/react@19.2.8", "", { "dependencies": { "csstype": "^3.2.2" } }, "sha512-3MbSL37jEchWZz2p2mjntRZtPt837ij10ApxKfgmXCTuHWagYg7iA5bqPw6C8BMPfwidlvfPI/fxOc42HLhcyg=="],

View File

@@ -42,7 +42,7 @@
"devDependencies": {
"@eslint/js": "^9.39.2",
"@tanstack/eslint-plugin-query": "^5.91.2",
"@types/node": "^25.0.8",
"@types/node": "^25.0.9",
"@types/react": "^19.2.8",
"@types/react-dom": "^19.2.3",
"@vitejs/plugin-react": "^5.1.2",

View File

@@ -159,6 +159,10 @@ code {
@apply relative rounded bg-muted px-[0.2rem] py-[0.1rem] font-mono text-sm font-semibold break-all;
}
pre {
@apply bg-accent border border-border rounded-md p-2;
}
.lead {
@apply text-xl text-muted-foreground;
}

View File

@@ -0,0 +1,53 @@
export type OIDCValues = {
scope: string;
response_type: string;
client_id: string;
redirect_uri: string;
state: string;
};
interface IuseOIDCParams {
values: OIDCValues;
compiled: string;
isOidc: boolean;
missingParams: string[];
}
const optionalParams: string[] = ["state"];
export function useOIDCParams(params: URLSearchParams): IuseOIDCParams {
let compiled: string = "";
let isOidc = false;
const missingParams: string[] = [];
const values: OIDCValues = {
scope: params.get("scope") ?? "",
response_type: params.get("response_type") ?? "",
client_id: params.get("client_id") ?? "",
redirect_uri: params.get("redirect_uri") ?? "",
state: params.get("state") ?? "",
};
for (const key of Object.keys(values)) {
if (!values[key as keyof OIDCValues]) {
if (!optionalParams.includes(key)) {
missingParams.push(key);
}
}
}
if (missingParams.length === 0) {
isOidc = true;
}
if (isOidc) {
compiled = new URLSearchParams(values).toString();
}
return {
values,
compiled,
isOidc,
missingParams,
};
}

View File

@@ -51,12 +51,31 @@
"forgotPasswordTitle": "Forgot your password?",
"failedToFetchProvidersTitle": "Failed to load authentication providers. Please check your configuration.",
"errorTitle": "An error occurred",
"errorSubtitle": "An error occurred while trying to perform this action. Please check the console for more information.",
"errorSubtitleInfo": "The following error occurred while processing your request:",
"errorSubtitle": "An error occurred while trying to perform this action. Please check your browser console or the app logs for more information.",
"forgotPasswordMessage": "You can reset your password by changing the `USERS` environment variable.",
"fieldRequired": "This field is required",
"invalidInput": "Invalid input",
"domainWarningTitle": "Invalid Domain",
"domainWarningSubtitle": "This instance is configured to be accessed from <code>{{appUrl}}</code>, but <code>{{currentUrl}}</code> is being used. If you proceed, you may encounter issues with authentication.",
"ignoreTitle": "Ignore",
"goToCorrectDomainTitle": "Go to correct domain"
}
"goToCorrectDomainTitle": "Go to correct domain",
"authorizeTitle": "Authorize",
"authorizeCardTitle": "Continue to {{app}}?",
"authorizeSubtitle": "Would you like to continue to this app? Please carefully review the permissions requested by the app.",
"authorizeSubtitleOAuth": "Would you like to continue to this app?",
"authorizeLoadingTitle": "Loading...",
"authorizeLoadingSubtitle": "Please wait while we load the client information.",
"authorizeSuccessTitle": "Authorized",
"authorizeSuccessSubtitle": "You will be redirected to the app in a few seconds.",
"authorizeErrorClientInfo": "An error occurred while loading the client information. Please try again later.",
"authorizeErrorMissingParams": "The following parameters are missing: {{missingParams}}",
"openidScopeName": "OpenID Connect",
"openidScopeDescription": "Allows the app to access your OpenID Connect information.",
"emailScopeName": "Email",
"emailScopeDescription": "Allows the app to access your email address.",
"profileScopeName": "Profile",
"profileScopeDescription": "Allows the app to access your profile information.",
"groupsScopeName": "Groups",
"groupsScopeDescription": "Allows the app to access your group information."
}

View File

@@ -51,12 +51,31 @@
"forgotPasswordTitle": "Forgot your password?",
"failedToFetchProvidersTitle": "Failed to load authentication providers. Please check your configuration.",
"errorTitle": "An error occurred",
"errorSubtitle": "An error occurred while trying to perform this action. Please check the console for more information.",
"errorSubtitleInfo": "The following error occurred while processing your request:",
"errorSubtitle": "An error occurred while trying to perform this action. Please check your browser console or the app logs for more information.",
"forgotPasswordMessage": "You can reset your password by changing the `USERS` environment variable.",
"fieldRequired": "This field is required",
"invalidInput": "Invalid input",
"domainWarningTitle": "Invalid Domain",
"domainWarningSubtitle": "This instance is configured to be accessed from <code>{{appUrl}}</code>, but <code>{{currentUrl}}</code> is being used. If you proceed, you may encounter issues with authentication.",
"ignoreTitle": "Ignore",
"goToCorrectDomainTitle": "Go to correct domain"
}
"goToCorrectDomainTitle": "Go to correct domain",
"authorizeTitle": "Authorize",
"authorizeCardTitle": "Continue to {{app}}?",
"authorizeSubtitle": "Would you like to continue to this app? Please carefully review the permissions requested by the app.",
"authorizeSubtitleOAuth": "Would you like to continue to this app?",
"authorizeLoadingTitle": "Loading...",
"authorizeLoadingSubtitle": "Please wait while we load the client information.",
"authorizeSuccessTitle": "Authorized",
"authorizeSuccessSubtitle": "You will be redirected to the app in a few seconds.",
"authorizeErrorClientInfo": "An error occurred while loading the client information. Please try again later.",
"authorizeErrorMissingParams": "The following parameters are missing: {{missingParams}}",
"openidScopeName": "OpenID Connect",
"openidScopeDescription": "Allows the app to access your OpenID Connect information.",
"emailScopeName": "Email",
"emailScopeDescription": "Allows the app to access your email address.",
"profileScopeName": "Profile",
"profileScopeDescription": "Allows the app to access your profile information.",
"groupsScopeName": "Groups",
"groupsScopeDescription": "Allows the app to access your group information."
}

View File

@@ -17,6 +17,7 @@ import { AppContextProvider } from "./context/app-context.tsx";
import { UserContextProvider } from "./context/user-context.tsx";
import { Toaster } from "@/components/ui/sonner";
import { ThemeProvider } from "./components/providers/theme-provider.tsx";
import { AuthorizePage } from "./pages/authorize-page.tsx";
const queryClient = new QueryClient();
@@ -31,6 +32,7 @@ createRoot(document.getElementById("root")!).render(
<Route element={<Layout />} errorElement={<ErrorPage />}>
<Route path="/" element={<App />} />
<Route path="/login" element={<LoginPage />} />
<Route path="/authorize" element={<AuthorizePage />} />
<Route path="/logout" element={<LogoutPage />} />
<Route path="/continue" element={<ContinuePage />} />
<Route path="/totp" element={<TotpPage />} />

View File

@@ -0,0 +1,199 @@
import { useUserContext } from "@/context/user-context";
import { useMutation, useQuery } from "@tanstack/react-query";
import { Navigate, useNavigate } from "react-router";
import { useLocation } from "react-router";
import {
Card,
CardHeader,
CardTitle,
CardDescription,
CardFooter,
CardContent,
} from "@/components/ui/card";
import { getOidcClientInfoSchema } from "@/schemas/oidc-schemas";
import { Button } from "@/components/ui/button";
import axios from "axios";
import { toast } from "sonner";
import { useOIDCParams } from "@/lib/hooks/oidc";
import { useTranslation } from "react-i18next";
import { TFunction } from "i18next";
import { Mail, Shield, User, Users } from "lucide-react";
type Scope = {
id: string;
name: string;
description: string;
icon: React.ReactNode;
};
const scopeMapIconProps = {
className: "stroke-card stroke-2.5",
};
const createScopeMap = (t: TFunction<"translation", undefined>): Scope[] => {
return [
{
id: "openid",
name: t("openidScopeName"),
description: t("openidScopeDescription"),
icon: <Shield {...scopeMapIconProps} />,
},
{
id: "email",
name: t("emailScopeName"),
description: t("emailScopeDescription"),
icon: <Mail {...scopeMapIconProps} />,
},
{
id: "profile",
name: t("profileScopeName"),
description: t("profileScopeDescription"),
icon: <User {...scopeMapIconProps} />,
},
{
id: "groups",
name: t("groupsScopeName"),
description: t("groupsScopeDescription"),
icon: <Users {...scopeMapIconProps} />,
},
];
};
export const AuthorizePage = () => {
const { isLoggedIn } = useUserContext();
const { search } = useLocation();
const { t } = useTranslation();
const navigate = useNavigate();
const scopeMap = createScopeMap(t);
const searchParams = new URLSearchParams(search);
const {
values: props,
missingParams,
isOidc,
compiled: compiledOIDCParams,
} = useOIDCParams(searchParams);
const scopes = props.scope ? props.scope.split(" ").filter(Boolean) : [];
const getClientInfo = useQuery({
queryKey: ["client", props.client_id],
queryFn: async () => {
const res = await fetch(`/api/oidc/clients/${props.client_id}`);
const data = await getOidcClientInfoSchema.parseAsync(await res.json());
return data;
},
enabled: isOidc,
});
const authorizeMutation = useMutation({
mutationFn: () => {
return axios.post("/api/oidc/authorize", {
scope: props.scope,
response_type: props.response_type,
client_id: props.client_id,
redirect_uri: props.redirect_uri,
state: props.state,
});
},
mutationKey: ["authorize", props.client_id],
onSuccess: (data) => {
toast.info(t("authorizeSuccessTitle"), {
description: t("authorizeSuccessSubtitle"),
});
window.location.replace(data.data.redirect_uri);
},
onError: (error) => {
window.location.replace(
`/error?error=${encodeURIComponent(error.message)}`,
);
},
});
if (missingParams.length > 0) {
return (
<Navigate
to={`/error?error=${encodeURIComponent(t("authorizeErrorMissingParams", { missingParams: missingParams.join(", ") }))}`}
replace
/>
);
}
if (!isLoggedIn) {
return <Navigate to={`/login?${compiledOIDCParams}`} replace />;
}
if (getClientInfo.isLoading) {
return (
<Card className="min-w-xs sm:min-w-sm">
<CardHeader>
<CardTitle className="text-3xl">
{t("authorizeLoadingTitle")}
</CardTitle>
<CardDescription>{t("authorizeLoadingSubtitle")}</CardDescription>
</CardHeader>
</Card>
);
}
if (getClientInfo.isError) {
return (
<Navigate
to={`/error?error=${encodeURIComponent(t("authorizeErrorClientInfo"))}`}
replace
/>
);
}
return (
<Card className="min-w-xs sm:min-w-sm mx-4">
<CardHeader>
<CardTitle className="text-3xl">
{t("authorizeCardTitle", {
app: getClientInfo.data?.name || "Unknown",
})}
</CardTitle>
<CardDescription>
{scopes.includes("openid")
? t("authorizeSubtitle")
: t("authorizeSubtitleOAuth")}
</CardDescription>
</CardHeader>
{scopes.includes("openid") && (
<CardContent className="flex flex-col gap-4">
{scopes.map((id) => {
const scope = scopeMap.find((s) => s.id === id);
if (!scope) return null;
return (
<div key={scope.id} className="flex flex-row items-center gap-3">
<div className="p-2 flex flex-col items-center justify-center bg-card-foreground rounded-md">
{scope.icon}
</div>
<div className="flex flex-col gap-0.5">
<div className="text-md">{scope.name}</div>
<div className="text-sm text-muted-foreground">
{scope.description}
</div>
</div>
</div>
);
})}
</CardContent>
)}
<CardFooter className="flex flex-col items-stretch gap-2">
<Button
onClick={() => authorizeMutation.mutate()}
loading={authorizeMutation.isPending}
>
{t("authorizeTitle")}
</Button>
<Button
onClick={() => navigate("/")}
disabled={authorizeMutation.isPending}
variant="outline"
>
{t("cancelTitle")}
</Button>
</CardFooter>
</Card>
);
};

View File

@@ -80,7 +80,7 @@ export const ContinuePage = () => {
clearTimeout(auto);
clearTimeout(reveal);
};
}, []);
});
if (!isLoggedIn) {
return (

View File

@@ -5,15 +5,30 @@ import {
CardTitle,
} from "@/components/ui/card";
import { useTranslation } from "react-i18next";
import { useLocation } from "react-router";
export const ErrorPage = () => {
const { t } = useTranslation();
const { search } = useLocation();
const searchParams = new URLSearchParams(search);
const error = searchParams.get("error") ?? "";
return (
<Card className="min-w-xs sm:min-w-sm">
<CardHeader>
<CardTitle className="text-3xl">{t("errorTitle")}</CardTitle>
<CardDescription>{t("errorSubtitle")}</CardDescription>
<CardDescription className="flex flex-col gap-1.5">
{error ? (
<>
<p>{t("errorSubtitleInfo")}</p>
<pre>{error}</pre>
</>
) : (
<>
<p>{t("errorSubtitle")}</p>
</>
)}
</CardDescription>
</CardHeader>
</Card>
);

View File

@@ -18,6 +18,7 @@ import { OAuthButton } from "@/components/ui/oauth-button";
import { SeperatorWithChildren } from "@/components/ui/separator";
import { useAppContext } from "@/context/app-context";
import { useUserContext } from "@/context/user-context";
import { useOIDCParams } from "@/lib/hooks/oidc";
import { LoginSchema } from "@/schemas/login-schema";
import { useMutation } from "@tanstack/react-query";
import axios, { AxiosError } from "axios";
@@ -47,18 +48,24 @@ export const LoginPage = () => {
const redirectButtonTimer = useRef<number | null>(null);
const searchParams = new URLSearchParams(search);
const redirectUri = searchParams.get("redirect_uri");
const {
values: props,
isOidc,
compiled: compiledOIDCParams,
} = useOIDCParams(searchParams);
const oauthProviders = providers.filter(
(provider) => provider.id !== "username",
(provider) => provider.id !== "local" && provider.id !== "ldap",
);
const userAuthConfigured =
providers.find((provider) => provider.id === "username") !== undefined;
providers.find(
(provider) => provider.id === "local" || provider.id === "ldap",
) !== undefined;
const oauthMutation = useMutation({
mutationFn: (provider: string) =>
axios.get(
`/api/oauth/url/${provider}?redirect_uri=${encodeURIComponent(redirectUri ?? "")}`,
`/api/oauth/url/${provider}?redirect_uri=${encodeURIComponent(props.redirect_uri)}`,
),
mutationKey: ["oauth"],
onSuccess: (data) => {
@@ -84,7 +91,7 @@ export const LoginPage = () => {
onSuccess: (data) => {
if (data.data.totpPending) {
window.location.replace(
`/totp?redirect_uri=${encodeURIComponent(redirectUri ?? "")}`,
`/totp?redirect_uri=${encodeURIComponent(props.redirect_uri)}`,
);
return;
}
@@ -94,8 +101,12 @@ export const LoginPage = () => {
});
redirectTimer.current = window.setTimeout(() => {
if (isOidc) {
window.location.replace(`/authorize?${compiledOIDCParams}`);
return;
}
window.location.replace(
`/continue?redirect_uri=${encodeURIComponent(redirectUri ?? "")}`,
`/continue?redirect_uri=${encodeURIComponent(props.redirect_uri)}`,
);
}, 500);
},
@@ -113,7 +124,7 @@ export const LoginPage = () => {
if (
providers.find((provider) => provider.id === oauthAutoRedirect) &&
!isLoggedIn &&
redirectUri
props.redirect_uri !== ""
) {
// Not sure of a better way to do this
// eslint-disable-next-line react-hooks/set-state-in-effect
@@ -123,7 +134,13 @@ export const LoginPage = () => {
setShowRedirectButton(true);
}, 5000);
}
}, []);
}, [
providers,
isLoggedIn,
props.redirect_uri,
oauthAutoRedirect,
oauthMutation,
]);
useEffect(
() => () => {
@@ -134,10 +151,14 @@ export const LoginPage = () => {
[],
);
if (isLoggedIn && redirectUri) {
if (isLoggedIn && isOidc) {
return <Navigate to={`/authorize?${compiledOIDCParams}`} replace />;
}
if (isLoggedIn && props.redirect_uri !== "") {
return (
<Navigate
to={`/continue?redirect_uri=${encodeURIComponent(redirectUri)}`}
to={`/continue?redirect_uri=${encodeURIComponent(props.redirect_uri)}`}
replace
/>
);

View File

@@ -55,7 +55,7 @@ export const LogoutPage = () => {
<CardHeader>
<CardTitle className="text-3xl">{t("logoutTitle")}</CardTitle>
<CardDescription>
{provider !== "username" ? (
{provider !== "local" && provider !== "ldap" ? (
<Trans
i18nKey="logoutOauthSubtitle"
t={t}

View File

@@ -16,6 +16,7 @@ import { useEffect, useId, useRef } from "react";
import { useTranslation } from "react-i18next";
import { Navigate, useLocation } from "react-router";
import { toast } from "sonner";
import { useOIDCParams } from "@/lib/hooks/oidc";
export const TotpPage = () => {
const { totpPending } = useUserContext();
@@ -26,7 +27,11 @@ export const TotpPage = () => {
const redirectTimer = useRef<number | null>(null);
const searchParams = new URLSearchParams(search);
const redirectUri = searchParams.get("redirect_uri");
const {
values: props,
isOidc,
compiled: compiledOIDCParams,
} = useOIDCParams(searchParams);
const totpMutation = useMutation({
mutationFn: (values: TotpSchema) => axios.post("/api/user/totp", values),
@@ -37,9 +42,14 @@ export const TotpPage = () => {
});
redirectTimer.current = window.setTimeout(() => {
window.location.replace(
`/continue?redirect_uri=${encodeURIComponent(redirectUri ?? "")}`,
);
if (isOidc) {
window.location.replace(`/authorize?${compiledOIDCParams}`);
return;
} else {
window.location.replace(
`/continue?redirect_uri=${encodeURIComponent(props.redirect_uri)}`,
);
}
}, 500);
},
onError: () => {

View File

@@ -0,0 +1,5 @@
import { z } from "zod";
export const getOidcClientInfoSchema = z.object({
name: z.string(),
});

View File

@@ -24,6 +24,11 @@ export default defineConfig({
changeOrigin: true,
rewrite: (path) => path.replace(/^\/resources/, ""),
},
"/.well-known": {
target: "http://tinyauth-backend:3000/.well-known",
changeOrigin: true,
rewrite: (path) => path.replace(/^\/\.well-known/, ""),
},
},
allowedHosts: true,
},

5
go.mod
View File

@@ -24,7 +24,7 @@ require (
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546
golang.org/x/oauth2 v0.34.0
gotest.tools/v3 v3.5.2
modernc.org/sqlite v1.44.0
modernc.org/sqlite v1.44.1
)
require (
@@ -61,6 +61,7 @@ require (
github.com/gabriel-vasile/mimetype v1.4.10 // indirect
github.com/gin-contrib/sse v1.1.0 // indirect
github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667 // indirect
github.com/go-jose/go-jose/v4 v4.1.3 // indirect
github.com/go-logr/logr v1.4.3 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/go-playground/locales v0.14.1 // indirect
@@ -119,7 +120,7 @@ require (
golang.org/x/text v0.33.0 // indirect
google.golang.org/protobuf v1.36.9 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
modernc.org/libc v1.67.4 // indirect
modernc.org/libc v1.67.6 // indirect
modernc.org/mathutil v1.7.1 // indirect
modernc.org/memory v1.11.0 // indirect
rsc.io/qr v0.2.0 // indirect

10
go.sum
View File

@@ -103,6 +103,8 @@ github.com/gin-gonic/gin v1.11.0 h1:OW/6PLjyusp2PPXtyxKHU0RbX6I/l28FTdDlae5ueWk=
github.com/gin-gonic/gin v1.11.0/go.mod h1:+iq/FyxlGzII0KHiBGjuNn4UNENUlKbGlNmc+W50Dls=
github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667 h1:BP4M0CvQ4S3TGls2FvczZtj5Re/2ZzkV9VwqPHH/3Bo=
github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667/go.mod h1:hEBeB/ic+5LoWskz+yKT7vGhhPYkProFKoKdwZRWMe0=
github.com/go-jose/go-jose/v4 v4.1.3 h1:CVLmWDhDVRa6Mi/IgCgaopNosCaHz7zrMeF9MlZRkrs=
github.com/go-jose/go-jose/v4 v4.1.3/go.mod h1:x4oUasVrzR7071A4TnHLGSPpNOm2a21K9Kf04k1rs08=
github.com/go-ldap/ldap/v3 v3.4.12 h1:1b81mv7MagXZ7+1r7cLTWmyuTqVqdwbtJSjC0DAp9s4=
github.com/go-ldap/ldap/v3 v3.4.12/go.mod h1:+SPAGcTtOfmGsCb3h1RFiq4xpp4N636G75OEace8lNo=
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
@@ -383,8 +385,8 @@ modernc.org/gc/v3 v3.1.1 h1:k8T3gkXWY9sEiytKhcgyiZ2L0DTyCQ/nvX+LoCljoRE=
modernc.org/gc/v3 v3.1.1/go.mod h1:HFK/6AGESC7Ex+EZJhJ2Gni6cTaYpSMmU/cT9RmlfYY=
modernc.org/goabi0 v0.2.0 h1:HvEowk7LxcPd0eq6mVOAEMai46V+i7Jrj13t4AzuNks=
modernc.org/goabi0 v0.2.0/go.mod h1:CEFRnnJhKvWT1c1JTI3Avm+tgOWbkOu5oPA8eH8LnMI=
modernc.org/libc v1.67.4 h1:zZGmCMUVPORtKv95c2ReQN5VDjvkoRm9GWPTEPuvlWg=
modernc.org/libc v1.67.4/go.mod h1:QvvnnJ5P7aitu0ReNpVIEyesuhmDLQ8kaEoyMjIFZJA=
modernc.org/libc v1.67.6 h1:eVOQvpModVLKOdT+LvBPjdQqfrZq+pC39BygcT+E7OI=
modernc.org/libc v1.67.6/go.mod h1:JAhxUVlolfYDErnwiqaLvUqc8nfb2r6S6slAgZOnaiE=
modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU=
modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg=
modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI=
@@ -393,8 +395,8 @@ modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8=
modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns=
modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w=
modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE=
modernc.org/sqlite v1.44.0 h1:YjCKJnzZde2mLVy0cMKTSL4PxCmbIguOq9lGp8ZvGOc=
modernc.org/sqlite v1.44.0/go.mod h1:2Dq41ir5/qri7QJJJKNZcP4UF7TsX/KNeykYgPDtGhE=
modernc.org/sqlite v1.44.1 h1:qybx/rNpfQipX/t47OxbHmkkJuv2JWifCMH8SVUiDas=
modernc.org/sqlite v1.44.1/go.mod h1:CzbrU2lSB1DKUusvwGz7rqEKIq+NUd8GWuBBZDs9/nA=
modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0=
modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A=
modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=

View File

@@ -0,0 +1,3 @@
DROP TABLE IF EXISTS "oidc_tokens";
DROP TABLE IF EXISTS "oidc_userinfo";
DROP TABLE IF EXISTS "oidc_codes";

View File

@@ -0,0 +1,27 @@
CREATE TABLE IF NOT EXISTS "oidc_codes" (
"sub" TEXT NOT NULL UNIQUE,
"code_hash" TEXT NOT NULL PRIMARY KEY UNIQUE,
"scope" TEXT NOT NULL,
"redirect_uri" TEXT NOT NULL,
"client_id" TEXT NOT NULL,
"expires_at" INTEGER NOT NULL
);
CREATE TABLE IF NOT EXISTS "oidc_tokens" (
"sub" TEXT NOT NULL UNIQUE,
"access_token_hash" TEXT NOT NULL PRIMARY KEY UNIQUE,
"refresh_token_hash" TEXT NOT NULL,
"scope" TEXT NOT NULL,
"client_id" TEXT NOT NULL,
"token_expires_at" INTEGER NOT NULL,
"refresh_token_expires_at" INTEGER NOT NULL
);
CREATE TABLE IF NOT EXISTS "oidc_userinfo" (
"sub" TEXT NOT NULL UNIQUE PRIMARY KEY,
"name" TEXT NOT NULL,
"preferred_username" TEXT NOT NULL,
"email" TEXT NOT NULL,
"groups" TEXT NOT NULL,
"updated_at" INTEGER NOT NULL
);

View File

@@ -30,6 +30,7 @@ type BootstrapApp struct {
users []config.User
oauthProviders map[string]config.OAuthServiceConfig
configuredProviders []controller.Provider
oidcClients []config.OIDCClientConfig
}
services Services
}
@@ -84,6 +85,12 @@ func (app *BootstrapApp) Setup() error {
app.context.oauthProviders[id] = provider
}
// Setup OIDC clients
for id, client := range app.config.OIDC.Clients {
client.ID = id
app.context.oidcClients = append(app.context.oidcClients, client)
}
// Get cookie domain
cookieDomain, err := utils.GetCookieDomain(app.config.AppURL)
@@ -144,10 +151,18 @@ func (app *BootstrapApp) Setup() error {
return configuredProviders[i].Name < configuredProviders[j].Name
})
if services.authService.UserAuthConfigured() {
if services.authService.LocalAuthConfigured() {
configuredProviders = append(configuredProviders, controller.Provider{
Name: "Username",
ID: "username",
Name: "Local",
ID: "local",
OAuth: false,
})
}
if services.authService.LdapAuthConfigured() {
configuredProviders = append(configuredProviders, controller.Provider{
Name: "LDAP",
ID: "ldap",
OAuth: false,
})
}
@@ -232,7 +247,7 @@ func (app *BootstrapApp) heartbeat() {
heartbeatURL := config.ApiServer + "/v1/instances/heartbeat"
for ; true; <-ticker.C {
for range ticker.C {
tlog.App.Debug().Msg("Sending heartbeat")
req, err := http.NewRequest(http.MethodPost, heartbeatURL, bytes.NewReader(bodyJson))
@@ -264,7 +279,7 @@ func (app *BootstrapApp) dbCleanup(queries *repository.Queries) {
defer ticker.Stop()
ctx := context.Background()
for ; true; <-ticker.C {
for range ticker.C {
tlog.App.Debug().Msg("Cleaning up old database sessions")
err := queries.DeleteExpiredSessions(ctx, time.Now().Unix())
if err != nil {

View File

@@ -2,14 +2,22 @@ package bootstrap
import (
"fmt"
"slices"
"github.com/steveiliop56/tinyauth/internal/config"
"github.com/steveiliop56/tinyauth/internal/controller"
"github.com/steveiliop56/tinyauth/internal/middleware"
"github.com/gin-gonic/gin"
)
var DEV_MODES = []string{"main", "test", "development"}
func (app *BootstrapApp) setupRouter() (*gin.Engine, error) {
if !slices.Contains(DEV_MODES, config.Version) {
gin.SetMode(gin.ReleaseMode)
}
engine := gin.New()
engine.Use(gin.Recovery())
@@ -78,6 +86,10 @@ func (app *BootstrapApp) setupRouter() (*gin.Engine, error) {
oauthController.SetupRoutes()
oidcController := controller.NewOIDCController(controller.OIDCControllerConfig{}, app.services.oidcService, apiRouter)
oidcController.SetupRoutes()
proxyController := controller.NewProxyController(controller.ProxyControllerConfig{
AppURL: app.config.AppURL,
}, apiRouter, app.services.accessControlService, app.services.authService)
@@ -101,5 +113,9 @@ func (app *BootstrapApp) setupRouter() (*gin.Engine, error) {
healthController.SetupRoutes()
wellknownController := controller.NewWellKnownController(controller.WellKnownControllerConfig{}, app.services.oidcService, engine)
wellknownController.SetupRoutes()
return engine, nil
}

View File

@@ -12,6 +12,7 @@ type Services struct {
dockerService *service.DockerService
ldapService *service.LdapService
oauthBrokerService *service.OAuthBrokerService
oidcService *service.OIDCService
}
func (app *BootstrapApp) initServices(queries *repository.Queries) (Services, error) {
@@ -67,6 +68,7 @@ func (app *BootstrapApp) initServices(queries *repository.Queries) (Services, er
LoginMaxRetries: app.config.Auth.LoginMaxRetries,
SessionCookieName: app.context.sessionCookieName,
IP: app.config.Auth.IP,
LDAPGroupsCacheTTL: app.config.Ldap.GroupCacheTTL,
}, dockerService, services.ldapService, queries)
err = authService.Init()
@@ -87,5 +89,21 @@ func (app *BootstrapApp) initServices(queries *repository.Queries) (Services, er
services.oauthBrokerService = oauthBrokerService
oidcService := service.NewOIDCService(service.OIDCServiceConfig{
Clients: app.config.OIDC.Clients,
PrivateKeyPath: app.config.OIDC.PrivateKeyPath,
PublicKeyPath: app.config.OIDC.PublicKeyPath,
Issuer: app.config.AppURL,
SessionExpiry: app.config.Auth.SessionExpiry,
}, queries)
err = oidcService.Init()
if err != nil {
return Services{}, err
}
services.oidcService = oidcService
return services, nil
}

View File

@@ -25,6 +25,7 @@ type Config struct {
Auth AuthConfig `description:"Authentication configuration." yaml:"auth"`
Apps map[string]App `description:"Application ACLs configuration." yaml:"apps"`
OAuth OAuthConfig `description:"OAuth configuration." yaml:"oauth"`
OIDC OIDCConfig `description:"OIDC configuration." yaml:"oidc"`
UI UIConfig `description:"UI customization." yaml:"ui"`
Ldap LdapConfig `description:"LDAP configuration." yaml:"ldap"`
Experimental ExperimentalConfig `description:"Experimental features, use with caution." yaml:"experimental"`
@@ -60,6 +61,12 @@ type OAuthConfig struct {
Providers map[string]OAuthServiceConfig `description:"OAuth providers configuration." yaml:"providers"`
}
type OIDCConfig struct {
PrivateKeyPath string `description:"Path to the private key file." yaml:"privateKeyPath"`
PublicKeyPath string `description:"Path to the public key file." yaml:"publicKeyPath"`
Clients map[string]OIDCClientConfig `description:"OIDC clients configuration." yaml:"clients"`
}
type UIConfig struct {
Title string `description:"The title of the UI." yaml:"title"`
ForgotPasswordMessage string `description:"Message displayed on the forgot password page." yaml:"forgotPasswordMessage"`
@@ -67,14 +74,15 @@ type UIConfig struct {
}
type LdapConfig struct {
Address string `description:"LDAP server address." yaml:"address"`
BindDN string `description:"Bind DN for LDAP authentication." yaml:"bindDn"`
BindPassword string `description:"Bind password for LDAP authentication." yaml:"bindPassword"`
BaseDN string `description:"Base DN for LDAP searches." yaml:"baseDn"`
Insecure bool `description:"Allow insecure LDAP connections." yaml:"insecure"`
SearchFilter string `description:"LDAP search filter." yaml:"searchFilter"`
AuthCert string `description:"Certificate for mTLS authentication." yaml:"authCert"`
AuthKey string `description:"Certificate key for mTLS authentication." yaml:"authKey"`
Address string `description:"LDAP server address." yaml:"address"`
BindDN string `description:"Bind DN for LDAP authentication." yaml:"bindDn"`
BindPassword string `description:"Bind password for LDAP authentication." yaml:"bindPassword"`
BaseDN string `description:"Base DN for LDAP searches." yaml:"baseDn"`
Insecure bool `description:"Allow insecure LDAP connections." yaml:"insecure"`
SearchFilter string `description:"LDAP search filter." yaml:"searchFilter"`
AuthCert string `description:"Certificate for mTLS authentication." yaml:"authCert"`
AuthKey string `description:"Certificate key for mTLS authentication." yaml:"authKey"`
GroupCacheTTL int `description:"Cache duration for LDAP group membership in seconds." yaml:"groupCacheTTL"`
}
type LogConfig struct {
@@ -113,16 +121,25 @@ type Claims struct {
}
type OAuthServiceConfig struct {
ClientID string `description:"OAuth client ID."`
ClientSecret string `description:"OAuth client secret."`
ClientSecretFile string `description:"Path to the file containing the OAuth client secret."`
Scopes []string `description:"OAuth scopes."`
RedirectURL string `description:"OAuth redirect URL."`
AuthURL string `description:"OAuth authorization URL."`
TokenURL string `description:"OAuth token URL."`
UserinfoURL string `description:"OAuth userinfo URL."`
Insecure bool `description:"Allow insecure OAuth connections."`
Name string `description:"Provider name in UI."`
ClientID string `description:"OAuth client ID." yaml:"clientId"`
ClientSecret string `description:"OAuth client secret." yaml:"clientSecret"`
ClientSecretFile string `description:"Path to the file containing the OAuth client secret." yaml:"clientSecretFile"`
Scopes []string `description:"OAuth scopes." yaml:"scopes"`
RedirectURL string `description:"OAuth redirect URL." yaml:"redirectUrl"`
AuthURL string `description:"OAuth authorization URL." yaml:"authUrl"`
TokenURL string `description:"OAuth token URL." yaml:"tokenUrl"`
UserinfoURL string `description:"OAuth userinfo URL." yaml:"userinfoUrl"`
Insecure bool `description:"Allow insecure OAuth connections." yaml:"insecure"`
Name string `description:"Provider name in UI." yaml:"name"`
}
type OIDCClientConfig struct {
ID string `description:"OIDC client ID." yaml:"-"`
ClientID string `description:"OIDC client ID." yaml:"clientId"`
ClientSecret string `description:"OIDC client secret." yaml:"clientSecret"`
ClientSecretFile string `description:"Path to the file containing the OIDC client secret." yaml:"clientSecretFile"`
TrustedRedirectURIs []string `description:"List of trusted redirect URLs." yaml:"trustedRedirectUrls"`
Name string `description:"Client name in UI." yaml:"name"`
}
var OverrideProviders = map[string]string{
@@ -138,28 +155,22 @@ type User struct {
TotpSecret string
}
type LdapUser struct {
DN string
Groups []string
}
type UserSearch struct {
Username string
Type string // local, ldap or unknown
}
type SessionCookie struct {
UUID string
Username string
Name string
Email string
Provider string
TotpPending bool
OAuthGroups string
OAuthName string
OAuthSub string
}
type UserContext struct {
Username string
Name string
Email string
IsLoggedIn bool
IsBasicAuth bool
OAuth bool
Provider string
TotpPending bool
@@ -167,6 +178,7 @@ type UserContext struct {
TotpEnabled bool
OAuthName string
OAuthSub string
LdapGroups string
}
// API responses and queries
@@ -195,6 +207,7 @@ type App struct {
IP AppIP `description:"IP access configuration." yaml:"ip"`
Response AppResponse `description:"Response customization." yaml:"response"`
Path AppPath `description:"Path access configuration." yaml:"path"`
LDAP AppLDAP `description:"LDAP access configuration." yaml:"ldap"`
}
type AppConfig struct {
@@ -211,6 +224,10 @@ type AppOAuth struct {
Groups string `description:"Comma-separated list of required OAuth groups." yaml:"groups"`
}
type AppLDAP struct {
Groups string `description:"Comma-separated list of required LDAP groups." yaml:"groups"`
}
type AppIP struct {
Allow []string `description:"List of allowed IPs or CIDR ranges." yaml:"allow"`
Block []string `description:"List of blocked IPs or CIDR ranges." yaml:"block"`

View File

@@ -21,7 +21,6 @@ type UserContextResponse struct {
OAuth bool `json:"oauth"`
TotpPending bool `json:"totpPending"`
OAuthName string `json:"oauthName"`
OAuthSub string `json:"oauthSub"`
}
type AppContextResponse struct {
@@ -90,7 +89,6 @@ func (controller *ContextController) userContextHandler(c *gin.Context) {
OAuth: context.OAuth,
TotpPending: context.TotpPending,
OAuthName: context.OAuthName,
OAuthSub: context.OAuthSub,
}
if err != nil {

View File

@@ -13,11 +13,11 @@ import (
"gotest.tools/v3/assert"
)
var controllerCfg = controller.ContextControllerConfig{
var contextControllerCfg = controller.ContextControllerConfig{
Providers: []controller.Provider{
{
Name: "Username",
ID: "username",
Name: "Local",
ID: "local",
OAuth: false,
},
{
@@ -35,13 +35,14 @@ var controllerCfg = controller.ContextControllerConfig{
DisableUIWarnings: false,
}
var userContext = config.UserContext{
var contextCtrlTestContext = config.UserContext{
Username: "testuser",
Name: "testuser",
Email: "test@example.com",
IsLoggedIn: true,
IsBasicAuth: false,
OAuth: false,
Provider: "username",
Provider: "local",
TotpPending: false,
OAuthGroups: "",
TotpEnabled: false,
@@ -64,7 +65,7 @@ func setupContextController(middlewares *[]gin.HandlerFunc) (*gin.Engine, *httpt
group := router.Group("/api")
ctrl := controller.NewContextController(controllerCfg, group)
ctrl := controller.NewContextController(contextControllerCfg, group)
ctrl.SetupRoutes()
return router, recorder
@@ -74,14 +75,14 @@ func TestAppContextHandler(t *testing.T) {
expectedRes := controller.AppContextResponse{
Status: 200,
Message: "Success",
Providers: controllerCfg.Providers,
Title: controllerCfg.Title,
AppURL: controllerCfg.AppURL,
CookieDomain: controllerCfg.CookieDomain,
ForgotPasswordMessage: controllerCfg.ForgotPasswordMessage,
BackgroundImage: controllerCfg.BackgroundImage,
OAuthAutoRedirect: controllerCfg.OAuthAutoRedirect,
DisableUIWarnings: controllerCfg.DisableUIWarnings,
Providers: contextControllerCfg.Providers,
Title: contextControllerCfg.Title,
AppURL: contextControllerCfg.AppURL,
CookieDomain: contextControllerCfg.CookieDomain,
ForgotPasswordMessage: contextControllerCfg.ForgotPasswordMessage,
BackgroundImage: contextControllerCfg.BackgroundImage,
OAuthAutoRedirect: contextControllerCfg.OAuthAutoRedirect,
DisableUIWarnings: contextControllerCfg.DisableUIWarnings,
}
router, recorder := setupContextController(nil)
@@ -102,20 +103,20 @@ func TestUserContextHandler(t *testing.T) {
expectedRes := controller.UserContextResponse{
Status: 200,
Message: "Success",
IsLoggedIn: userContext.IsLoggedIn,
Username: userContext.Username,
Name: userContext.Name,
Email: userContext.Email,
Provider: userContext.Provider,
OAuth: userContext.OAuth,
TotpPending: userContext.TotpPending,
OAuthName: userContext.OAuthName,
IsLoggedIn: contextCtrlTestContext.IsLoggedIn,
Username: contextCtrlTestContext.Username,
Name: contextCtrlTestContext.Name,
Email: contextCtrlTestContext.Email,
Provider: contextCtrlTestContext.Provider,
OAuth: contextCtrlTestContext.OAuth,
TotpPending: contextCtrlTestContext.TotpPending,
OAuthName: contextCtrlTestContext.OAuthName,
}
// Test with context
router, recorder := setupContextController(&[]gin.HandlerFunc{
func(c *gin.Context) {
c.Set("context", &userContext)
c.Set("context", &contextCtrlTestContext)
c.Next()
},
})

View File

@@ -7,6 +7,7 @@ import (
"time"
"github.com/steveiliop56/tinyauth/internal/config"
"github.com/steveiliop56/tinyauth/internal/repository"
"github.com/steveiliop56/tinyauth/internal/service"
"github.com/steveiliop56/tinyauth/internal/utils"
"github.com/steveiliop56/tinyauth/internal/utils/tlog"
@@ -188,10 +189,10 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
username = user.PreferredUsername
} else {
tlog.App.Debug().Msg("No preferred username from OAuth provider, using pseudo username")
username = strings.Replace(user.Email, "@", "_", -1)
username = strings.Replace(user.Email, "@", "_", 1)
}
sessionCookie := config.SessionCookie{
sessionCookie := repository.Session{
Username: username,
Name: name,
Email: user.Email,

View File

@@ -0,0 +1,414 @@
package controller
import (
"crypto/rand"
"errors"
"fmt"
"net/http"
"slices"
"strings"
"github.com/gin-gonic/gin"
"github.com/google/go-querystring/query"
"github.com/steveiliop56/tinyauth/internal/service"
"github.com/steveiliop56/tinyauth/internal/utils"
"github.com/steveiliop56/tinyauth/internal/utils/tlog"
)
type OIDCControllerConfig struct{}
type OIDCController struct {
config OIDCControllerConfig
router *gin.RouterGroup
oidc *service.OIDCService
}
type AuthorizeCallback struct {
Code string `url:"code"`
State string `url:"state"`
}
type TokenRequest struct {
GrantType string `form:"grant_type" binding:"required" url:"grant_type"`
Code string `form:"code" url:"code"`
RedirectURI string `form:"redirect_uri" url:"redirect_uri"`
RefreshToken string `form:"refresh_token" url:"refresh_token"`
}
type CallbackError struct {
Error string `url:"error"`
ErrorDescription string `url:"error_description"`
State string `url:"state"`
}
type ErrorScreen struct {
Error string `url:"error"`
}
type ClientRequest struct {
ClientID string `uri:"id" binding:"required"`
}
func NewOIDCController(config OIDCControllerConfig, oidcService *service.OIDCService, router *gin.RouterGroup) *OIDCController {
return &OIDCController{
config: config,
oidc: oidcService,
router: router,
}
}
func (controller *OIDCController) SetupRoutes() {
oidcGroup := controller.router.Group("/oidc")
oidcGroup.GET("/clients/:id", controller.GetClientInfo)
oidcGroup.POST("/authorize", controller.Authorize)
oidcGroup.POST("/token", controller.Token)
oidcGroup.GET("/userinfo", controller.Userinfo)
}
func (controller *OIDCController) GetClientInfo(c *gin.Context) {
var req ClientRequest
err := c.BindUri(&req)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to bind URI")
c.JSON(400, gin.H{
"status": 400,
"message": "Bad Request",
})
return
}
client, ok := controller.oidc.GetClient(req.ClientID)
if !ok {
tlog.App.Warn().Str("client_id", req.ClientID).Msg("Client not found")
c.JSON(404, gin.H{
"status": 404,
"message": "Client not found",
})
return
}
c.JSON(200, gin.H{
"status": 200,
"client": client.ClientID,
"name": client.Name,
})
}
func (controller *OIDCController) Authorize(c *gin.Context) {
userContext, err := utils.GetContext(c)
if err != nil {
controller.authorizeError(c, err, "Failed to get user context", "User is not logged in or the session is invalid", "", "", "")
return
}
var req service.AuthorizeRequest
err = c.BindJSON(&req)
if err != nil {
controller.authorizeError(c, err, "Failed to bind JSON", "The client provided an invalid authorization request", "", "", "")
return
}
client, ok := controller.oidc.GetClient(req.ClientID)
if !ok {
controller.authorizeError(c, err, "Client not found", "The client ID is invalid", "", "", "")
return
}
err = controller.oidc.ValidateAuthorizeParams(req)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to validate authorize params")
if err.Error() != "invalid_request_uri" {
controller.authorizeError(c, err, "Failed validate authorize params", "Invalid request parameters", req.RedirectURI, err.Error(), req.State)
return
}
controller.authorizeError(c, err, "Redirect URI not trusted", "The provided redirect URI is not trusted", "", "", "")
return
}
// WARNING: Since Tinyauth is stateless, we cannot have a sub that never changes. We will just create a uuid out of the username and client name which remains stable, but if username or client name changes then sub changes too.
sub := utils.GenerateUUID(fmt.Sprintf("%s:%s", userContext.Username, client.ID))
code := rand.Text()
// Before storing the code, delete old session
err = controller.oidc.DeleteOldSession(c, sub)
if err != nil {
controller.authorizeError(c, err, "Failed to delete old sessions", "Failed to delete old sessions", req.RedirectURI, "server_error", req.State)
return
}
err = controller.oidc.StoreCode(c, sub, code, req)
if err != nil {
controller.authorizeError(c, err, "Failed to store code", "Failed to store code", req.RedirectURI, "server_error", req.State)
return
}
// We also need a snapshot of the user that authorized this (skip if no openid scope)
if slices.Contains(strings.Fields(req.Scope), "openid") {
err = controller.oidc.StoreUserinfo(c, sub, userContext, req)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to insert user info into database")
controller.authorizeError(c, err, "Failed to store user info", "Failed to store user info", req.RedirectURI, "server_error", req.State)
return
}
}
queries, err := query.Values(AuthorizeCallback{
Code: code,
State: req.State,
})
if err != nil {
controller.authorizeError(c, err, "Failed to build query", "Failed to build query", req.RedirectURI, "server_error", req.State)
return
}
c.JSON(200, gin.H{
"status": 200,
"redirect_uri": fmt.Sprintf("%s?%s", req.RedirectURI, queries.Encode()),
})
}
func (controller *OIDCController) Token(c *gin.Context) {
var req TokenRequest
err := c.Bind(&req)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to bind token request")
c.JSON(400, gin.H{
"error": "invalid_request",
})
return
}
err = controller.oidc.ValidateGrantType(req.GrantType)
if err != nil {
tlog.App.Warn().Str("grant_type", req.GrantType).Msg("Unsupported grant type")
c.JSON(400, gin.H{
"error": err.Error(),
})
return
}
rclientId, rclientSecret, ok := c.Request.BasicAuth()
if !ok {
tlog.App.Error().Msg("Missing authorization header")
c.Header("www-authenticate", "basic")
c.JSON(401, gin.H{
"error": "invalid_client",
})
return
}
client, ok := controller.oidc.GetClient(rclientId)
if !ok {
tlog.App.Warn().Str("client_id", rclientId).Msg("Client not found")
c.JSON(400, gin.H{
"error": "invalid_client",
})
return
}
if client.ClientSecret != rclientSecret {
tlog.App.Warn().Str("client_id", rclientId).Msg("Invalid client secret")
c.JSON(400, gin.H{
"error": "invalid_client",
})
return
}
var tokenResponse service.TokenResponse
switch req.GrantType {
case "authorization_code":
entry, err := controller.oidc.GetCodeEntry(c, controller.oidc.Hash(req.Code))
if err != nil {
if errors.Is(err, service.ErrCodeNotFound) {
tlog.App.Warn().Msg("Code not found")
c.JSON(400, gin.H{
"error": "invalid_grant",
})
return
}
if errors.Is(err, service.ErrCodeExpired) {
tlog.App.Warn().Msg("Code expired")
c.JSON(400, gin.H{
"error": "invalid_grant",
})
return
}
tlog.App.Warn().Err(err).Msg("Failed to get OIDC code entry")
c.JSON(400, gin.H{
"error": "server_error",
})
return
}
if entry.RedirectURI != req.RedirectURI {
tlog.App.Warn().Str("redirect_uri", req.RedirectURI).Msg("Redirect URI mismatch")
c.JSON(400, gin.H{
"error": "invalid_grant",
})
return
}
tokenRes, err := controller.oidc.GenerateAccessToken(c, client, entry.Sub, entry.Scope)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to generate access token")
c.JSON(400, gin.H{
"error": "server_error",
})
return
}
tokenResponse = tokenRes
case "refresh_token":
tokenRes, err := controller.oidc.RefreshAccessToken(c, req.RefreshToken, rclientId)
if err != nil {
if errors.Is(err, service.ErrTokenExpired) {
tlog.App.Error().Err(err).Msg("Refresh token expired")
c.JSON(401, gin.H{
"error": "invalid_grant",
})
return
}
if errors.Is(err, service.ErrInvalidClient) {
tlog.App.Error().Err(err).Msg("Invalid client")
c.JSON(401, gin.H{
"error": "invalid_grant",
})
return
}
tlog.App.Error().Err(err).Msg("Failed to refresh access token")
c.JSON(400, gin.H{
"error": "server_error",
})
return
}
tokenResponse = tokenRes
}
c.JSON(200, tokenResponse)
}
func (controller *OIDCController) Userinfo(c *gin.Context) {
authorization := c.GetHeader("Authorization")
tokenType, token, ok := strings.Cut(authorization, " ")
if !ok {
tlog.App.Warn().Msg("OIDC userinfo accessed without authorization header")
c.JSON(401, gin.H{
"error": "invalid_grant",
})
return
}
if strings.ToLower(tokenType) != "bearer" {
tlog.App.Warn().Msg("OIDC userinfo accessed with invalid token type")
c.JSON(401, gin.H{
"error": "invalid_grant",
})
return
}
entry, err := controller.oidc.GetAccessToken(c, controller.oidc.Hash(token))
if err != nil {
if err == service.ErrTokenNotFound {
tlog.App.Warn().Msg("OIDC userinfo accessed with invalid token")
c.JSON(401, gin.H{
"error": "invalid_grant",
})
return
}
tlog.App.Err(err).Msg("Failed to get token entry")
c.JSON(401, gin.H{
"error": "server_error",
})
return
}
// If we don't have the openid scope, return an error
if !slices.Contains(strings.Split(entry.Scope, ","), "openid") {
tlog.App.Warn().Msg("OIDC userinfo accessed without openid scope")
c.JSON(401, gin.H{
"error": "invalid_scope",
})
return
}
user, err := controller.oidc.GetUserinfo(c, entry.Sub)
if err != nil {
tlog.App.Err(err).Msg("Failed to get user entry")
c.JSON(401, gin.H{
"error": "server_error",
})
return
}
c.JSON(200, controller.oidc.CompileUserinfo(user, entry.Scope))
}
func (controller *OIDCController) authorizeError(c *gin.Context, err error, reason string, reasonUser string, callback string, callbackError string, state string) {
tlog.App.Error().Err(err).Msg(reason)
if callback != "" {
errorQueries := CallbackError{
Error: callbackError,
}
if reasonUser != "" {
errorQueries.ErrorDescription = reasonUser
}
if state != "" {
errorQueries.State = state
}
queries, err := query.Values(errorQueries)
if err != nil {
c.AbortWithStatus(http.StatusInternalServerError)
return
}
c.JSON(200, gin.H{
"status": 200,
"redirect_uri": fmt.Sprintf("%s?%s", callback, queries.Encode()),
})
return
}
errorQueries := ErrorScreen{
Error: reasonUser,
}
queries, err := query.Values(errorQueries)
if err != nil {
c.AbortWithStatus(http.StatusInternalServerError)
return
}
c.JSON(200, gin.H{
"status": 200,
"redirect_uri": fmt.Sprintf("%s/error?%s", controller.oidc.GetIssuer(), queries.Encode()),
})
}

View File

@@ -0,0 +1,281 @@
package controller_test
import (
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"github.com/gin-gonic/gin"
"github.com/google/go-querystring/query"
"github.com/steveiliop56/tinyauth/internal/bootstrap"
"github.com/steveiliop56/tinyauth/internal/config"
"github.com/steveiliop56/tinyauth/internal/controller"
"github.com/steveiliop56/tinyauth/internal/repository"
"github.com/steveiliop56/tinyauth/internal/service"
"github.com/steveiliop56/tinyauth/internal/utils/tlog"
"gotest.tools/v3/assert"
)
var oidcServiceConfig = service.OIDCServiceConfig{
Clients: map[string]config.OIDCClientConfig{
"client1": {
ClientID: "some-client-id",
ClientSecret: "some-client-secret",
ClientSecretFile: "",
TrustedRedirectURIs: []string{
"https://example.com/oauth/callback",
},
Name: "Client 1",
},
},
PrivateKeyPath: "/tmp/tinyauth_oidc_key",
PublicKeyPath: "/tmp/tinyauth_oidc_key.pub",
Issuer: "https://example.com",
SessionExpiry: 3600,
}
var oidcCtrlTestContext = config.UserContext{
Username: "test",
Name: "Test",
Email: "test@example.com",
IsLoggedIn: true,
IsBasicAuth: false,
OAuth: false,
Provider: "ldap", // ldap in order to test the groups
TotpPending: false,
OAuthGroups: "",
TotpEnabled: false,
OAuthName: "",
OAuthSub: "",
LdapGroups: "test1,test2",
}
// Test is not amazing, but it will confirm the OIDC server works
func TestOIDCController(t *testing.T) {
tlog.NewSimpleLogger().Init()
// Create an app instance
app := bootstrap.NewBootstrapApp(config.Config{})
// Get db
db, err := app.SetupDatabase("/tmp/tinyauth.db")
assert.NilError(t, err)
// Create queries
queries := repository.New(db)
// Create a new OIDC Servicee
oidcService := service.NewOIDCService(oidcServiceConfig, queries)
err = oidcService.Init()
assert.NilError(t, err)
// Create test router
gin.SetMode(gin.TestMode)
router := gin.Default()
router.Use(func(c *gin.Context) {
c.Set("context", &oidcCtrlTestContext)
c.Next()
})
group := router.Group("/api")
// Register oidc controller
oidcController := controller.NewOIDCController(controller.OIDCControllerConfig{}, oidcService, group)
oidcController.SetupRoutes()
// Get redirect URL test
recorder := httptest.NewRecorder()
marshalled, err := json.Marshal(service.AuthorizeRequest{
Scope: "openid profile email groups",
ResponseType: "code",
ClientID: "some-client-id",
RedirectURI: "https://example.com/oauth/callback",
State: "some-state",
})
assert.NilError(t, err)
req, err := http.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(marshalled)))
assert.NilError(t, err)
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusOK, recorder.Code)
resJson := map[string]any{}
err = json.Unmarshal(recorder.Body.Bytes(), &resJson)
assert.NilError(t, err)
redirect_uri, ok := resJson["redirect_uri"].(string)
assert.Assert(t, ok)
u, err := url.Parse(redirect_uri)
assert.NilError(t, err)
m, err := url.ParseQuery(u.RawQuery)
assert.NilError(t, err)
assert.Equal(t, m["state"][0], "some-state")
code := m["code"][0]
// Exchange code for token
recorder = httptest.NewRecorder()
params, err := query.Values(controller.TokenRequest{
GrantType: "authorization_code",
Code: code,
RedirectURI: "https://example.com/oauth/callback",
})
assert.NilError(t, err)
req, err = http.NewRequest("POST", "/api/oidc/token", strings.NewReader(params.Encode()))
assert.NilError(t, err)
req.Header.Set("content-type", "application/x-www-form-urlencoded")
req.SetBasicAuth("some-client-id", "some-client-secret")
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusOK, recorder.Code)
resJson = map[string]any{}
err = json.Unmarshal(recorder.Body.Bytes(), &resJson)
assert.NilError(t, err)
accessToken, ok := resJson["access_token"].(string)
assert.Assert(t, ok)
_, ok = resJson["id_token"].(string)
assert.Assert(t, ok)
refreshToken, ok := resJson["refresh_token"].(string)
assert.Assert(t, ok)
expires_in, ok := resJson["expires_in"].(float64)
assert.Assert(t, ok)
assert.Equal(t, expires_in, float64(oidcServiceConfig.SessionExpiry))
// Ensure code is expired
recorder = httptest.NewRecorder()
params, err = query.Values(controller.TokenRequest{
GrantType: "authorization_code",
Code: code,
RedirectURI: "https://example.com/oauth/callback",
})
assert.NilError(t, err)
req, err = http.NewRequest("POST", "/api/oidc/token", strings.NewReader(params.Encode()))
assert.NilError(t, err)
req.Header.Set("content-type", "application/x-www-form-urlencoded")
req.SetBasicAuth("some-client-id", "some-client-secret")
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusBadRequest, recorder.Code)
// Test userinfo
recorder = httptest.NewRecorder()
req, err = http.NewRequest("GET", "/api/oidc/userinfo", nil)
assert.NilError(t, err)
req.Header.Set("authorization", fmt.Sprintf("Bearer %s", accessToken))
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusOK, recorder.Code)
resJson = map[string]any{}
err = json.Unmarshal(recorder.Body.Bytes(), &resJson)
assert.NilError(t, err)
_, ok = resJson["sub"].(string)
assert.Assert(t, ok)
name, ok := resJson["name"].(string)
assert.Assert(t, ok)
assert.Equal(t, name, oidcCtrlTestContext.Name)
email, ok := resJson["email"].(string)
assert.Assert(t, ok)
assert.Equal(t, email, oidcCtrlTestContext.Email)
preferred_username, ok := resJson["preferred_username"].(string)
assert.Assert(t, ok)
assert.Equal(t, preferred_username, oidcCtrlTestContext.Username)
// Not sure why this is failing, will look into it later
igroups, ok := resJson["groups"].([]any)
assert.Assert(t, ok)
groups := make([]string, len(igroups))
for i, group := range igroups {
groups[i], ok = group.(string)
assert.Assert(t, ok)
}
assert.DeepEqual(t, strings.Split(oidcCtrlTestContext.LdapGroups, ","), groups)
// Test refresh token
recorder = httptest.NewRecorder()
params, err = query.Values(controller.TokenRequest{
GrantType: "refresh_token",
RefreshToken: refreshToken,
})
assert.NilError(t, err)
req, err = http.NewRequest("POST", "/api/oidc/token", strings.NewReader(params.Encode()))
assert.NilError(t, err)
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.SetBasicAuth("some-client-id", "some-client-secret")
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusOK, recorder.Code)
resJson = map[string]any{}
err = json.Unmarshal(recorder.Body.Bytes(), &resJson)
assert.NilError(t, err)
newToken, ok := resJson["access_token"].(string)
assert.Assert(t, ok)
assert.Assert(t, newToken != accessToken)
// Ensure old token is invalid
recorder = httptest.NewRecorder()
req, err = http.NewRequest("GET", "/api/oidc/userinfo", nil)
assert.NilError(t, err)
req.Header.Set("authorization", fmt.Sprintf("Bearer %s", accessToken))
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusUnauthorized, recorder.Code)
// Test new token
recorder = httptest.NewRecorder()
req, err = http.NewRequest("GET", "/api/oidc/userinfo", nil)
assert.NilError(t, err)
req.Header.Set("authorization", fmt.Sprintf("Bearer %s", newToken))
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusOK, recorder.Code)
}

View File

@@ -173,7 +173,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
tlog.App.Trace().Interface("context", userContext).Msg("User context from request")
if userContext.Provider == "basic" && userContext.TotpEnabled {
if userContext.IsBasicAuth && userContext.TotpEnabled {
tlog.App.Debug().Msg("User has TOTP enabled, denying basic auth access")
userContext.IsLoggedIn = false
}
@@ -212,11 +212,17 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
return
}
if userContext.OAuth {
groupOK := controller.auth.IsInOAuthGroup(c, userContext, acls.OAuth.Groups)
if userContext.OAuth || userContext.Provider == "ldap" {
var groupOK bool
if userContext.OAuth {
groupOK = controller.auth.IsInOAuthGroup(c, userContext, acls.OAuth.Groups)
} else {
groupOK = controller.auth.IsInLdapGroup(c, userContext, acls.LDAP.Groups)
}
if !groupOK {
tlog.App.Warn().Str("user", userContext.Username).Str("resource", strings.Split(host, ".")[0]).Msg("User OAuth groups do not match resource requirements")
tlog.App.Warn().Str("user", userContext.Username).Str("resource", strings.Split(host, ".")[0]).Msg("User groups do not match resource requirements")
if req.Proxy == "nginx" || !isBrowser {
c.JSON(403, gin.H{
@@ -251,7 +257,13 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
c.Header("Remote-User", utils.SanitizeHeader(userContext.Username))
c.Header("Remote-Name", utils.SanitizeHeader(userContext.Name))
c.Header("Remote-Email", utils.SanitizeHeader(userContext.Email))
c.Header("Remote-Groups", utils.SanitizeHeader(userContext.OAuthGroups))
if userContext.Provider == "ldap" {
c.Header("Remote-Groups", utils.SanitizeHeader(userContext.LdapGroups))
} else if userContext.Provider != "local" {
c.Header("Remote-Groups", utils.SanitizeHeader(userContext.OAuthGroups))
}
c.Header("Remote-Sub", utils.SanitizeHeader(userContext.OAuthSub))
controller.setHeaders(c, acls)

View File

@@ -143,11 +143,11 @@ func TestProxyHandler(t *testing.T) {
// Test logged in user
c := gin.CreateTestContextOnly(recorder, router)
err := authService.CreateSessionCookie(c, &config.SessionCookie{
err := authService.CreateSessionCookie(c, &repository.Session{
Username: "testuser",
Name: "testuser",
Email: "testuser@example.com",
Provider: "username",
Provider: "local",
TotpPending: false,
OAuthGroups: "",
})
@@ -164,7 +164,7 @@ func TestProxyHandler(t *testing.T) {
Email: "testuser@example.com",
IsLoggedIn: true,
OAuth: false,
Provider: "username",
Provider: "local",
TotpPending: false,
OAuthGroups: "",
TotpEnabled: false,
@@ -192,8 +192,9 @@ func TestProxyHandler(t *testing.T) {
Name: "testuser",
Email: "testuser@example.com",
IsLoggedIn: true,
IsBasicAuth: true,
OAuth: false,
Provider: "basic",
Provider: "local",
TotpPending: false,
OAuthGroups: "",
TotpEnabled: true,

View File

@@ -5,7 +5,7 @@ import (
"strings"
"time"
"github.com/steveiliop56/tinyauth/internal/config"
"github.com/steveiliop56/tinyauth/internal/repository"
"github.com/steveiliop56/tinyauth/internal/service"
"github.com/steveiliop56/tinyauth/internal/utils"
"github.com/steveiliop56/tinyauth/internal/utils/tlog"
@@ -112,11 +112,11 @@ func (controller *UserController) loginHandler(c *gin.Context) {
if user.TotpSecret != "" {
tlog.App.Debug().Str("username", req.Username).Msg("User has TOTP enabled, requiring TOTP verification")
err := controller.auth.CreateSessionCookie(c, &config.SessionCookie{
err := controller.auth.CreateSessionCookie(c, &repository.Session{
Username: user.Username,
Name: utils.Capitalize(req.Username),
Email: fmt.Sprintf("%s@%s", strings.ToLower(req.Username), controller.config.CookieDomain),
Provider: "username",
Provider: "local",
TotpPending: true,
})
@@ -138,11 +138,15 @@ func (controller *UserController) loginHandler(c *gin.Context) {
}
}
sessionCookie := config.SessionCookie{
sessionCookie := repository.Session{
Username: req.Username,
Name: utils.Capitalize(req.Username),
Email: fmt.Sprintf("%s@%s", strings.ToLower(req.Username), controller.config.CookieDomain),
Provider: "username",
Provider: "local",
}
if userSearch.Type == "ldap" {
sessionCookie.Provider = "ldap"
}
tlog.App.Trace().Interface("session_cookie", sessionCookie).Msg("Creating session cookie")
@@ -248,11 +252,11 @@ func (controller *UserController) totpHandler(c *gin.Context) {
controller.auth.RecordLoginAttempt(context.Username, true)
sessionCookie := config.SessionCookie{
sessionCookie := repository.Session{
Username: user.Username,
Name: utils.Capitalize(user.Username),
Email: fmt.Sprintf("%s@%s", strings.ToLower(user.Username), controller.config.CookieDomain),
Provider: "username",
Provider: "local",
}
tlog.App.Trace().Interface("session_cookie", sessionCookie).Msg("Creating session cookie")

View File

@@ -204,7 +204,7 @@ func TestTotpHandler(t *testing.T) {
Email: "totpuser@example.com",
IsLoggedIn: false,
OAuth: false,
Provider: "username",
Provider: "local",
TotpPending: true,
OAuthGroups: "",
TotpEnabled: true,
@@ -267,7 +267,7 @@ func TestTotpHandler(t *testing.T) {
Email: "totpuser@example.com",
IsLoggedIn: false,
OAuth: false,
Provider: "username",
Provider: "local",
TotpPending: true,
OAuthGroups: "",
TotpEnabled: true,
@@ -290,7 +290,7 @@ func TestTotpHandler(t *testing.T) {
Email: "totpuser@example.com",
IsLoggedIn: false,
OAuth: false,
Provider: "username",
Provider: "local",
TotpPending: false,
OAuthGroups: "",
TotpEnabled: false,

View File

@@ -0,0 +1,85 @@
package controller
import (
"fmt"
"net/http"
"github.com/gin-gonic/gin"
"github.com/steveiliop56/tinyauth/internal/service"
)
type OpenIDConnectConfiguration struct {
Issuer string `json:"issuer"`
AuthorizationEndpoint string `json:"authorization_endpoint"`
TokenEndpoint string `json:"token_endpoint"`
UserinfoEndpoint string `json:"userinfo_endpoint"`
JwksUri string `json:"jwks_uri"`
ScopesSupported []string `json:"scopes_supported"`
ResponseTypesSupported []string `json:"response_types_supported"`
GrantTypesSupported []string `json:"grant_types_supported"`
SubjectTypesSupported []string `json:"subject_types_supported"`
IDTokenSigningAlgValuesSupported []string `json:"id_token_signing_alg_values_supported"`
TokenEndpointAuthMethodsSupported []string `json:"token_endpoint_auth_methods_supported"`
ClaimsSupported []string `json:"claims_supported"`
ServiceDocumentation string `json:"service_documentation"`
}
type WellKnownControllerConfig struct{}
type WellKnownController struct {
config WellKnownControllerConfig
engine *gin.Engine
oidc *service.OIDCService
}
func NewWellKnownController(config WellKnownControllerConfig, oidc *service.OIDCService, engine *gin.Engine) *WellKnownController {
return &WellKnownController{
config: config,
oidc: oidc,
engine: engine,
}
}
func (controller *WellKnownController) SetupRoutes() {
controller.engine.GET("/.well-known/openid-configuration", controller.OpenIDConnectConfiguration)
controller.engine.GET("/.well-known/jwks.json", controller.JWKS)
}
func (controller *WellKnownController) OpenIDConnectConfiguration(c *gin.Context) {
issuer := controller.oidc.GetIssuer()
c.JSON(200, OpenIDConnectConfiguration{
Issuer: issuer,
AuthorizationEndpoint: fmt.Sprintf("%s/authorize", issuer),
TokenEndpoint: fmt.Sprintf("%s/api/oidc/token", issuer),
UserinfoEndpoint: fmt.Sprintf("%s/api/oidc/userinfo", issuer),
JwksUri: fmt.Sprintf("%s/.well-known/jwks.json", issuer),
ScopesSupported: service.SupportedScopes,
ResponseTypesSupported: service.SupportedResponseTypes,
GrantTypesSupported: service.SupportedGrantTypes,
SubjectTypesSupported: []string{"pairwise"},
IDTokenSigningAlgValuesSupported: []string{"RS256"},
TokenEndpointAuthMethodsSupported: []string{"client_secret_basic"},
ClaimsSupported: []string{"sub", "updated_at", "name", "preferred_username", "email", "groups"},
ServiceDocumentation: "https://tinyauth.app/docs/reference/openid",
})
}
func (controller *WellKnownController) JWKS(c *gin.Context) {
jwks, err := controller.oidc.GetJWK()
if err != nil {
c.JSON(500, gin.H{
"status": "500",
"message": "failed to get JWK",
})
return
}
c.Header("content-type", "application/json")
c.Writer.WriteString(`{"keys":[`)
c.Writer.Write(jwks)
c.Writer.WriteString(`]}`)
c.Status(http.StatusOK)
}

View File

@@ -2,6 +2,7 @@ package middleware
import (
"fmt"
"slices"
"strings"
"time"
@@ -13,6 +14,8 @@ import (
"github.com/gin-gonic/gin"
)
var OIDCIgnorePaths = []string{"/api/oidc/token", "/api/oidc/userinfo"}
type ContextMiddlewareConfig struct {
CookieDomain string
}
@@ -37,6 +40,13 @@ func (m *ContextMiddleware) Init() error {
func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
return func(c *gin.Context) {
// There is no point in trying to get credentials if it's an OIDC endpoint
path := c.Request.URL.Path
if slices.Contains(OIDCIgnorePaths, strings.TrimSuffix(path, "/")) {
c.Next()
return
}
cookie, err := m.auth.GetSessionCookie(c)
if err != nil {
@@ -49,7 +59,7 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
Username: cookie.Username,
Name: cookie.Name,
Email: cookie.Email,
Provider: "username",
Provider: "local",
TotpPending: true,
TotpEnabled: true,
})
@@ -58,22 +68,44 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
}
switch cookie.Provider {
case "username":
case "local", "ldap":
userSearch := m.auth.SearchUser(cookie.Username)
if userSearch.Type == "unknown" || userSearch.Type == "error" {
if userSearch.Type == "unknown" {
tlog.App.Debug().Msg("User from session cookie not found")
m.auth.DeleteSessionCookie(c)
goto basic
}
if userSearch.Type != cookie.Provider {
tlog.App.Warn().Msg("User type from session cookie does not match user search type")
m.auth.DeleteSessionCookie(c)
c.Next()
return
}
var ldapGroups []string
if cookie.Provider == "ldap" {
ldapUser, err := m.auth.GetLdapUser(userSearch.Username)
if err != nil {
tlog.App.Error().Err(err).Msg("Error retrieving LDAP user details")
c.Next()
return
}
ldapGroups = ldapUser.Groups
}
m.auth.RefreshSessionCookie(c)
c.Set("context", &config.UserContext{
Username: cookie.Username,
Name: cookie.Name,
Email: cookie.Email,
Provider: "username",
Provider: cookie.Provider,
IsLoggedIn: true,
LdapGroups: strings.Join(ldapGroups, ","),
})
c.Next()
return
@@ -155,20 +187,32 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
Username: user.Username,
Name: utils.Capitalize(user.Username),
Email: fmt.Sprintf("%s@%s", strings.ToLower(user.Username), m.config.CookieDomain),
Provider: "basic",
Provider: "local",
IsLoggedIn: true,
TotpEnabled: user.TotpSecret != "",
IsBasicAuth: true,
})
c.Next()
return
case "ldap":
tlog.App.Debug().Msg("Basic auth user is LDAP")
ldapUser, err := m.auth.GetLdapUser(basic.Username)
if err != nil {
tlog.App.Debug().Err(err).Msg("Error retrieving LDAP user details")
c.Next()
return
}
c.Set("context", &config.UserContext{
Username: basic.Username,
Name: utils.Capitalize(basic.Username),
Email: fmt.Sprintf("%s@%s", strings.ToLower(basic.Username), m.config.CookieDomain),
Provider: "basic",
IsLoggedIn: true,
Username: basic.Username,
Name: utils.Capitalize(basic.Username),
Email: fmt.Sprintf("%s@%s", strings.ToLower(basic.Username), m.config.CookieDomain),
Provider: "ldap",
IsLoggedIn: true,
LdapGroups: strings.Join(ldapUser.Groups, ","),
IsBasicAuth: true,
})
c.Next()
return

View File

@@ -9,6 +9,7 @@ import (
"time"
"github.com/steveiliop56/tinyauth/internal/assets"
"github.com/steveiliop56/tinyauth/internal/utils/tlog"
"github.com/gin-gonic/gin"
)
@@ -39,11 +40,10 @@ func (m *UIMiddleware) Middleware() gin.HandlerFunc {
return func(c *gin.Context) {
path := strings.TrimPrefix(c.Request.URL.Path, "/")
tlog.App.Debug().Str("path", path).Msg("path")
switch strings.SplitN(path, "/", 2)[0] {
case "api":
c.Next()
return
case "resources":
case "api", "resources", ".well-known":
c.Next()
return
default:

View File

@@ -4,6 +4,34 @@
package repository
type OidcCode struct {
Sub string
CodeHash string
Scope string
RedirectURI string
ClientID string
ExpiresAt int64
}
type OidcToken struct {
Sub string
AccessTokenHash string
RefreshTokenHash string
Scope string
ClientID string
TokenExpiresAt int64
RefreshTokenExpiresAt int64
}
type OidcUserinfo struct {
Sub string
Name string
PreferredUsername string
Email string
Groups string
UpdatedAt int64
}
type Session struct {
UUID string
Username string

View File

@@ -0,0 +1,470 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.30.0
// source: oidc_queries.sql
package repository
import (
"context"
)
const createOidcCode = `-- name: CreateOidcCode :one
INSERT INTO "oidc_codes" (
"sub",
"code_hash",
"scope",
"redirect_uri",
"client_id",
"expires_at"
) VALUES (
?, ?, ?, ?, ?, ?
)
RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at
`
type CreateOidcCodeParams struct {
Sub string
CodeHash string
Scope string
RedirectURI string
ClientID string
ExpiresAt int64
}
func (q *Queries) CreateOidcCode(ctx context.Context, arg CreateOidcCodeParams) (OidcCode, error) {
row := q.db.QueryRowContext(ctx, createOidcCode,
arg.Sub,
arg.CodeHash,
arg.Scope,
arg.RedirectURI,
arg.ClientID,
arg.ExpiresAt,
)
var i OidcCode
err := row.Scan(
&i.Sub,
&i.CodeHash,
&i.Scope,
&i.RedirectURI,
&i.ClientID,
&i.ExpiresAt,
)
return i, err
}
const createOidcToken = `-- name: CreateOidcToken :one
INSERT INTO "oidc_tokens" (
"sub",
"access_token_hash",
"refresh_token_hash",
"scope",
"client_id",
"token_expires_at",
"refresh_token_expires_at"
) VALUES (
?, ?, ?, ?, ?, ?, ?
)
RETURNING sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at
`
type CreateOidcTokenParams struct {
Sub string
AccessTokenHash string
RefreshTokenHash string
Scope string
ClientID string
TokenExpiresAt int64
RefreshTokenExpiresAt int64
}
func (q *Queries) CreateOidcToken(ctx context.Context, arg CreateOidcTokenParams) (OidcToken, error) {
row := q.db.QueryRowContext(ctx, createOidcToken,
arg.Sub,
arg.AccessTokenHash,
arg.RefreshTokenHash,
arg.Scope,
arg.ClientID,
arg.TokenExpiresAt,
arg.RefreshTokenExpiresAt,
)
var i OidcToken
err := row.Scan(
&i.Sub,
&i.AccessTokenHash,
&i.RefreshTokenHash,
&i.Scope,
&i.ClientID,
&i.TokenExpiresAt,
&i.RefreshTokenExpiresAt,
)
return i, err
}
const createOidcUserInfo = `-- name: CreateOidcUserInfo :one
INSERT INTO "oidc_userinfo" (
"sub",
"name",
"preferred_username",
"email",
"groups",
"updated_at"
) VALUES (
?, ?, ?, ?, ?, ?
)
RETURNING sub, name, preferred_username, email, "groups", updated_at
`
type CreateOidcUserInfoParams struct {
Sub string
Name string
PreferredUsername string
Email string
Groups string
UpdatedAt int64
}
func (q *Queries) CreateOidcUserInfo(ctx context.Context, arg CreateOidcUserInfoParams) (OidcUserinfo, error) {
row := q.db.QueryRowContext(ctx, createOidcUserInfo,
arg.Sub,
arg.Name,
arg.PreferredUsername,
arg.Email,
arg.Groups,
arg.UpdatedAt,
)
var i OidcUserinfo
err := row.Scan(
&i.Sub,
&i.Name,
&i.PreferredUsername,
&i.Email,
&i.Groups,
&i.UpdatedAt,
)
return i, err
}
const deleteExpiredOidcCodes = `-- name: DeleteExpiredOidcCodes :many
DELETE FROM "oidc_codes"
WHERE "expires_at" < ?
RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at
`
func (q *Queries) DeleteExpiredOidcCodes(ctx context.Context, expiresAt int64) ([]OidcCode, error) {
rows, err := q.db.QueryContext(ctx, deleteExpiredOidcCodes, expiresAt)
if err != nil {
return nil, err
}
defer rows.Close()
var items []OidcCode
for rows.Next() {
var i OidcCode
if err := rows.Scan(
&i.Sub,
&i.CodeHash,
&i.Scope,
&i.RedirectURI,
&i.ClientID,
&i.ExpiresAt,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const deleteExpiredOidcTokens = `-- name: DeleteExpiredOidcTokens :many
DELETE FROM "oidc_tokens"
WHERE "token_expires_at" < ? AND "refresh_token_expires_at" < ?
RETURNING sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at
`
type DeleteExpiredOidcTokensParams struct {
TokenExpiresAt int64
RefreshTokenExpiresAt int64
}
func (q *Queries) DeleteExpiredOidcTokens(ctx context.Context, arg DeleteExpiredOidcTokensParams) ([]OidcToken, error) {
rows, err := q.db.QueryContext(ctx, deleteExpiredOidcTokens, arg.TokenExpiresAt, arg.RefreshTokenExpiresAt)
if err != nil {
return nil, err
}
defer rows.Close()
var items []OidcToken
for rows.Next() {
var i OidcToken
if err := rows.Scan(
&i.Sub,
&i.AccessTokenHash,
&i.RefreshTokenHash,
&i.Scope,
&i.ClientID,
&i.TokenExpiresAt,
&i.RefreshTokenExpiresAt,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const deleteOidcCode = `-- name: DeleteOidcCode :exec
DELETE FROM "oidc_codes"
WHERE "code_hash" = ?
`
func (q *Queries) DeleteOidcCode(ctx context.Context, codeHash string) error {
_, err := q.db.ExecContext(ctx, deleteOidcCode, codeHash)
return err
}
const deleteOidcCodeBySub = `-- name: DeleteOidcCodeBySub :exec
DELETE FROM "oidc_codes"
WHERE "sub" = ?
`
func (q *Queries) DeleteOidcCodeBySub(ctx context.Context, sub string) error {
_, err := q.db.ExecContext(ctx, deleteOidcCodeBySub, sub)
return err
}
const deleteOidcToken = `-- name: DeleteOidcToken :exec
DELETE FROM "oidc_tokens"
WHERE "access_token_hash" = ?
`
func (q *Queries) DeleteOidcToken(ctx context.Context, accessTokenHash string) error {
_, err := q.db.ExecContext(ctx, deleteOidcToken, accessTokenHash)
return err
}
const deleteOidcTokenBySub = `-- name: DeleteOidcTokenBySub :exec
DELETE FROM "oidc_tokens"
WHERE "sub" = ?
`
func (q *Queries) DeleteOidcTokenBySub(ctx context.Context, sub string) error {
_, err := q.db.ExecContext(ctx, deleteOidcTokenBySub, sub)
return err
}
const deleteOidcUserInfo = `-- name: DeleteOidcUserInfo :exec
DELETE FROM "oidc_userinfo"
WHERE "sub" = ?
`
func (q *Queries) DeleteOidcUserInfo(ctx context.Context, sub string) error {
_, err := q.db.ExecContext(ctx, deleteOidcUserInfo, sub)
return err
}
const getOidcCode = `-- name: GetOidcCode :one
DELETE FROM "oidc_codes"
WHERE "code_hash" = ?
RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at
`
func (q *Queries) GetOidcCode(ctx context.Context, codeHash string) (OidcCode, error) {
row := q.db.QueryRowContext(ctx, getOidcCode, codeHash)
var i OidcCode
err := row.Scan(
&i.Sub,
&i.CodeHash,
&i.Scope,
&i.RedirectURI,
&i.ClientID,
&i.ExpiresAt,
)
return i, err
}
const getOidcCodeBySub = `-- name: GetOidcCodeBySub :one
DELETE FROM "oidc_codes"
WHERE "sub" = ?
RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at
`
func (q *Queries) GetOidcCodeBySub(ctx context.Context, sub string) (OidcCode, error) {
row := q.db.QueryRowContext(ctx, getOidcCodeBySub, sub)
var i OidcCode
err := row.Scan(
&i.Sub,
&i.CodeHash,
&i.Scope,
&i.RedirectURI,
&i.ClientID,
&i.ExpiresAt,
)
return i, err
}
const getOidcCodeBySubUnsafe = `-- name: GetOidcCodeBySubUnsafe :one
SELECT sub, code_hash, scope, redirect_uri, client_id, expires_at FROM "oidc_codes"
WHERE "sub" = ?
`
func (q *Queries) GetOidcCodeBySubUnsafe(ctx context.Context, sub string) (OidcCode, error) {
row := q.db.QueryRowContext(ctx, getOidcCodeBySubUnsafe, sub)
var i OidcCode
err := row.Scan(
&i.Sub,
&i.CodeHash,
&i.Scope,
&i.RedirectURI,
&i.ClientID,
&i.ExpiresAt,
)
return i, err
}
const getOidcCodeUnsafe = `-- name: GetOidcCodeUnsafe :one
SELECT sub, code_hash, scope, redirect_uri, client_id, expires_at FROM "oidc_codes"
WHERE "code_hash" = ?
`
func (q *Queries) GetOidcCodeUnsafe(ctx context.Context, codeHash string) (OidcCode, error) {
row := q.db.QueryRowContext(ctx, getOidcCodeUnsafe, codeHash)
var i OidcCode
err := row.Scan(
&i.Sub,
&i.CodeHash,
&i.Scope,
&i.RedirectURI,
&i.ClientID,
&i.ExpiresAt,
)
return i, err
}
const getOidcToken = `-- name: GetOidcToken :one
SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at FROM "oidc_tokens"
WHERE "access_token_hash" = ?
`
func (q *Queries) GetOidcToken(ctx context.Context, accessTokenHash string) (OidcToken, error) {
row := q.db.QueryRowContext(ctx, getOidcToken, accessTokenHash)
var i OidcToken
err := row.Scan(
&i.Sub,
&i.AccessTokenHash,
&i.RefreshTokenHash,
&i.Scope,
&i.ClientID,
&i.TokenExpiresAt,
&i.RefreshTokenExpiresAt,
)
return i, err
}
const getOidcTokenByRefreshToken = `-- name: GetOidcTokenByRefreshToken :one
SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at FROM "oidc_tokens"
WHERE "refresh_token_hash" = ?
`
func (q *Queries) GetOidcTokenByRefreshToken(ctx context.Context, refreshTokenHash string) (OidcToken, error) {
row := q.db.QueryRowContext(ctx, getOidcTokenByRefreshToken, refreshTokenHash)
var i OidcToken
err := row.Scan(
&i.Sub,
&i.AccessTokenHash,
&i.RefreshTokenHash,
&i.Scope,
&i.ClientID,
&i.TokenExpiresAt,
&i.RefreshTokenExpiresAt,
)
return i, err
}
const getOidcTokenBySub = `-- name: GetOidcTokenBySub :one
SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at FROM "oidc_tokens"
WHERE "sub" = ?
`
func (q *Queries) GetOidcTokenBySub(ctx context.Context, sub string) (OidcToken, error) {
row := q.db.QueryRowContext(ctx, getOidcTokenBySub, sub)
var i OidcToken
err := row.Scan(
&i.Sub,
&i.AccessTokenHash,
&i.RefreshTokenHash,
&i.Scope,
&i.ClientID,
&i.TokenExpiresAt,
&i.RefreshTokenExpiresAt,
)
return i, err
}
const getOidcUserInfo = `-- name: GetOidcUserInfo :one
SELECT sub, name, preferred_username, email, "groups", updated_at FROM "oidc_userinfo"
WHERE "sub" = ?
`
func (q *Queries) GetOidcUserInfo(ctx context.Context, sub string) (OidcUserinfo, error) {
row := q.db.QueryRowContext(ctx, getOidcUserInfo, sub)
var i OidcUserinfo
err := row.Scan(
&i.Sub,
&i.Name,
&i.PreferredUsername,
&i.Email,
&i.Groups,
&i.UpdatedAt,
)
return i, err
}
const updateOidcTokenByRefreshToken = `-- name: UpdateOidcTokenByRefreshToken :one
UPDATE "oidc_tokens" SET
"access_token_hash" = ?,
"refresh_token_hash" = ?,
"token_expires_at" = ?,
"refresh_token_expires_at" = ?
WHERE "refresh_token_hash" = ?
RETURNING sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at
`
type UpdateOidcTokenByRefreshTokenParams struct {
AccessTokenHash string
RefreshTokenHash string
TokenExpiresAt int64
RefreshTokenExpiresAt int64
RefreshTokenHash_2 string
}
func (q *Queries) UpdateOidcTokenByRefreshToken(ctx context.Context, arg UpdateOidcTokenByRefreshTokenParams) (OidcToken, error) {
row := q.db.QueryRowContext(ctx, updateOidcTokenByRefreshToken,
arg.AccessTokenHash,
arg.RefreshTokenHash,
arg.TokenExpiresAt,
arg.RefreshTokenExpiresAt,
arg.RefreshTokenHash_2,
)
var i OidcToken
err := row.Scan(
&i.Sub,
&i.AccessTokenHash,
&i.RefreshTokenHash,
&i.Scope,
&i.ClientID,
&i.TokenExpiresAt,
&i.RefreshTokenExpiresAt,
)
return i, err
}

View File

@@ -1,7 +1,7 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.30.0
// source: queries.sql
// source: session_queries.sql
package repository
@@ -10,7 +10,7 @@ import (
)
const createSession = `-- name: CreateSession :one
INSERT INTO sessions (
INSERT INTO "sessions" (
"uuid",
"username",
"email",

View File

@@ -19,6 +19,11 @@ import (
"golang.org/x/crypto/bcrypt"
)
type LdapGroupsCache struct {
Groups []string
Expires time.Time
}
type LoginAttempt struct {
FailedAttempts int
LastAttempt time.Time
@@ -36,24 +41,28 @@ type AuthServiceConfig struct {
LoginMaxRetries int
SessionCookieName string
IP config.IPConfig
LDAPGroupsCacheTTL int
}
type AuthService struct {
config AuthServiceConfig
docker *DockerService
loginAttempts map[string]*LoginAttempt
loginMutex sync.RWMutex
ldap *LdapService
queries *repository.Queries
config AuthServiceConfig
docker *DockerService
loginAttempts map[string]*LoginAttempt
ldapGroupsCache map[string]*LdapGroupsCache
loginMutex sync.RWMutex
ldapGroupsMutex sync.RWMutex
ldap *LdapService
queries *repository.Queries
}
func NewAuthService(config AuthServiceConfig, docker *DockerService, ldap *LdapService, queries *repository.Queries) *AuthService {
return &AuthService{
config: config,
docker: docker,
loginAttempts: make(map[string]*LoginAttempt),
ldap: ldap,
queries: queries,
config: config,
docker: docker,
loginAttempts: make(map[string]*LoginAttempt),
ldapGroupsCache: make(map[string]*LdapGroupsCache),
ldap: ldap,
queries: queries,
}
}
@@ -70,12 +79,12 @@ func (auth *AuthService) SearchUser(username string) config.UserSearch {
}
if auth.ldap != nil {
userDN, err := auth.ldap.Search(username)
userDN, err := auth.ldap.GetUserDN(username)
if err != nil {
tlog.App.Warn().Err(err).Str("username", username).Msg("Failed to search for user in LDAP")
return config.UserSearch{
Type: "error",
Type: "unknown",
}
}
@@ -131,6 +140,41 @@ func (auth *AuthService) GetLocalUser(username string) config.User {
return config.User{}
}
func (auth *AuthService) GetLdapUser(userDN string) (config.LdapUser, error) {
if auth.ldap == nil {
return config.LdapUser{}, errors.New("LDAP service not initialized")
}
auth.ldapGroupsMutex.RLock()
entry, exists := auth.ldapGroupsCache[userDN]
auth.ldapGroupsMutex.RUnlock()
if exists && time.Now().Before(entry.Expires) {
return config.LdapUser{
DN: userDN,
Groups: entry.Groups,
}, nil
}
groups, err := auth.ldap.GetUserGroups(userDN)
if err != nil {
return config.LdapUser{}, err
}
auth.ldapGroupsMutex.Lock()
auth.ldapGroupsCache[userDN] = &LdapGroupsCache{
Groups: groups,
Expires: time.Now().Add(time.Duration(auth.config.LDAPGroupsCacheTTL) * time.Second),
}
auth.ldapGroupsMutex.Unlock()
return config.LdapUser{
DN: userDN,
Groups: groups,
}, nil
}
func (auth *AuthService) CheckPassword(user config.User, password string) bool {
return bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)) == nil
}
@@ -190,7 +234,7 @@ func (auth *AuthService) IsEmailWhitelisted(email string) bool {
return utils.CheckFilter(strings.Join(auth.config.OauthWhitelist, ","), email)
}
func (auth *AuthService) CreateSessionCookie(c *gin.Context, data *config.SessionCookie) error {
func (auth *AuthService) CreateSessionCookie(c *gin.Context, data *repository.Session) error {
uuid, err := uuid.NewRandom()
if err != nil {
@@ -300,20 +344,20 @@ func (auth *AuthService) DeleteSessionCookie(c *gin.Context) error {
return nil
}
func (auth *AuthService) GetSessionCookie(c *gin.Context) (config.SessionCookie, error) {
func (auth *AuthService) GetSessionCookie(c *gin.Context) (repository.Session, error) {
cookie, err := c.Cookie(auth.config.SessionCookieName)
if err != nil {
return config.SessionCookie{}, err
return repository.Session{}, err
}
session, err := auth.queries.GetSession(c, cookie)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return config.SessionCookie{}, fmt.Errorf("session not found")
return repository.Session{}, fmt.Errorf("session not found")
}
return config.SessionCookie{}, err
return repository.Session{}, err
}
currentTime := time.Now().Unix()
@@ -324,7 +368,7 @@ func (auth *AuthService) GetSessionCookie(c *gin.Context) (config.SessionCookie,
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to delete session exceeding max lifetime")
}
return config.SessionCookie{}, fmt.Errorf("session expired due to max lifetime exceeded")
return repository.Session{}, fmt.Errorf("session expired due to max lifetime exceeded")
}
}
@@ -333,10 +377,10 @@ func (auth *AuthService) GetSessionCookie(c *gin.Context) (config.SessionCookie,
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to delete expired session")
}
return config.SessionCookie{}, fmt.Errorf("session expired")
return repository.Session{}, fmt.Errorf("session expired")
}
return config.SessionCookie{
return repository.Session{
UUID: session.UUID,
Username: session.Username,
Email: session.Email,
@@ -349,8 +393,12 @@ func (auth *AuthService) GetSessionCookie(c *gin.Context) (config.SessionCookie,
}, nil
}
func (auth *AuthService) UserAuthConfigured() bool {
return len(auth.config.Users) > 0 || auth.ldap != nil
func (auth *AuthService) LocalAuthConfigured() bool {
return len(auth.config.Users) > 0
}
func (auth *AuthService) LdapAuthConfigured() bool {
return auth.ldap != nil
}
func (auth *AuthService) IsUserAllowed(c *gin.Context, context config.UserContext, acls config.App) bool {
@@ -393,6 +441,22 @@ func (auth *AuthService) IsInOAuthGroup(c *gin.Context, context config.UserConte
return false
}
func (auth *AuthService) IsInLdapGroup(c *gin.Context, context config.UserContext, requiredGroups string) bool {
if requiredGroups == "" {
return true
}
for userGroup := range strings.SplitSeq(context.LdapGroups, ",") {
if utils.CheckFilter(requiredGroups, strings.TrimSpace(userGroup)) {
tlog.App.Trace().Str("group", userGroup).Str("required", requiredGroups).Msg("User group matched")
return true
}
}
tlog.App.Debug().Msg("No groups matched")
return false
}
func (auth *AuthService) IsAuthEnabled(uri string, path config.AppPath) (bool, error) {
// Check for block list
if path.Block != "" {

View File

@@ -116,7 +116,7 @@ func (ldap *LdapService) connect() (*ldapgo.Conn, error) {
return ldap.conn, nil
}
func (ldap *LdapService) Search(username string) (string, error) {
func (ldap *LdapService) GetUserDN(username string) (string, error) {
// Escape the username to prevent LDAP injection
escapedUsername := ldapgo.EscapeFilter(username)
filter := fmt.Sprintf(ldap.config.SearchFilter, escapedUsername)
@@ -145,6 +145,48 @@ func (ldap *LdapService) Search(username string) (string, error) {
return userDN, nil
}
func (ldap *LdapService) GetUserGroups(userDN string) ([]string, error) {
escapedUserDN := ldapgo.EscapeFilter(userDN)
searchRequest := ldapgo.NewSearchRequest(
ldap.config.BaseDN,
ldapgo.ScopeWholeSubtree, ldapgo.NeverDerefAliases, 0, 0, false,
fmt.Sprintf("(&(objectclass=groupOfUniqueNames)(uniquemember=%s))", escapedUserDN),
[]string{"dn"},
nil,
)
ldap.mutex.Lock()
defer ldap.mutex.Unlock()
searchResult, err := ldap.conn.Search(searchRequest)
if err != nil {
return []string{}, err
}
groupDNs := []string{}
for _, entry := range searchResult.Entries {
groupDNs = append(groupDNs, entry.DN)
}
groups := []string{}
// I guess it should work for most ldap providers
for _, dn := range groupDNs {
rdnParts, err := ldapgo.ParseDN(dn)
if err != nil {
return []string{}, err
}
if len(rdnParts.RDNs) == 0 || len(rdnParts.RDNs[0].Attributes) == 0 {
return []string{}, fmt.Errorf("invalid DN format: %s", dn)
}
groups = append(groups, rdnParts.RDNs[0].Attributes[0].Value)
}
return groups, nil
}
func (ldap *LdapService) BindService(rebind bool) error {
// Locks must not be used for initial binding attempt
if rebind {

View File

@@ -0,0 +1,641 @@
package service
import (
"context"
"crypto"
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"crypto/x509"
"database/sql"
"encoding/json"
"encoding/pem"
"errors"
"fmt"
"net/url"
"os"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/go-jose/go-jose/v4"
"github.com/steveiliop56/tinyauth/internal/config"
"github.com/steveiliop56/tinyauth/internal/repository"
"github.com/steveiliop56/tinyauth/internal/utils"
"github.com/steveiliop56/tinyauth/internal/utils/tlog"
"golang.org/x/exp/slices"
)
var (
SupportedScopes = []string{"openid", "profile", "email", "groups"}
SupportedResponseTypes = []string{"code"}
SupportedGrantTypes = []string{"authorization_code", "refresh_token"}
)
var (
ErrCodeExpired = errors.New("code_expired")
ErrCodeNotFound = errors.New("code_not_found")
ErrTokenNotFound = errors.New("token_not_found")
ErrTokenExpired = errors.New("token_expired")
ErrInvalidClient = errors.New("invalid_client")
)
type ClaimSet struct {
Iss string `json:"iss"`
Aud string `json:"aud"`
Sub string `json:"sub"`
Iat int64 `json:"iat"`
Exp int64 `json:"exp"`
}
type UserinfoResponse struct {
Sub string `json:"sub"`
Name string `json:"name"`
Email string `json:"email"`
PreferredUsername string `json:"preferred_username"`
Groups []string `json:"groups"`
UpdatedAt int64 `json:"updated_at"`
}
type TokenResponse struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
TokenType string `json:"token_type"`
ExpiresIn int64 `json:"expires_in"`
IDToken string `json:"id_token"`
Scope string `json:"scope"`
}
type AuthorizeRequest struct {
Scope string `json:"scope" binding:"required"`
ResponseType string `json:"response_type" binding:"required"`
ClientID string `json:"client_id" binding:"required"`
RedirectURI string `json:"redirect_uri" binding:"required"`
State string `json:"state" binding:"required"`
}
type OIDCServiceConfig struct {
Clients map[string]config.OIDCClientConfig
PrivateKeyPath string
PublicKeyPath string
Issuer string
SessionExpiry int
}
type OIDCService struct {
config OIDCServiceConfig
queries *repository.Queries
clients map[string]config.OIDCClientConfig
privateKey *rsa.PrivateKey
publicKey crypto.PublicKey
issuer string
}
func NewOIDCService(config OIDCServiceConfig, queries *repository.Queries) *OIDCService {
return &OIDCService{
config: config,
queries: queries,
}
}
// TODO: A cleanup routine is needed to clean up expired tokens/code/userinfo
func (service *OIDCService) Init() error {
// Ensure issuer is https
uissuer, err := url.Parse(service.config.Issuer)
if err != nil {
return err
}
if uissuer.Scheme != "https" {
return errors.New("issuer must be https")
}
service.issuer = fmt.Sprintf("%s://%s", uissuer.Scheme, uissuer.Host)
// Create/load private and public keys
if strings.TrimSpace(service.config.PrivateKeyPath) == "" ||
strings.TrimSpace(service.config.PublicKeyPath) == "" {
return errors.New("private key path and public key path are required")
}
var privateKey *rsa.PrivateKey
fprivateKey, err := os.ReadFile(service.config.PrivateKeyPath)
if err != nil && !errors.Is(err, os.ErrNotExist) {
return err
}
if errors.Is(err, os.ErrNotExist) {
privateKey, err = rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return err
}
der := x509.MarshalPKCS1PrivateKey(privateKey)
if der == nil {
return errors.New("failed to marshal private key")
}
encoded := pem.EncodeToMemory(&pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: der,
})
err = os.WriteFile(service.config.PrivateKeyPath, encoded, 0600)
if err != nil {
return err
}
service.privateKey = privateKey
} else {
block, _ := pem.Decode(fprivateKey)
if block == nil {
return errors.New("failed to decode private key")
}
privateKey, err = x509.ParsePKCS1PrivateKey(block.Bytes)
if err != nil {
return err
}
service.privateKey = privateKey
}
fpublicKey, err := os.ReadFile(service.config.PublicKeyPath)
if err != nil && !errors.Is(err, os.ErrNotExist) {
return err
}
if errors.Is(err, os.ErrNotExist) {
publicKey := service.privateKey.Public()
der := x509.MarshalPKCS1PublicKey(publicKey.(*rsa.PublicKey))
if der == nil {
return errors.New("failed to marshal public key")
}
encoded := pem.EncodeToMemory(&pem.Block{
Type: "RSA PUBLIC KEY",
Bytes: der,
})
err = os.WriteFile(service.config.PublicKeyPath, encoded, 0644)
if err != nil {
return err
}
service.publicKey = publicKey
} else {
block, _ := pem.Decode(fpublicKey)
if block == nil {
return errors.New("failed to decode public key")
}
publicKey, err := x509.ParsePKCS1PublicKey(block.Bytes)
if err != nil {
return err
}
service.publicKey = publicKey
}
// We will reorganize the client into a map with the client ID as the key
service.clients = make(map[string]config.OIDCClientConfig)
for id, client := range service.config.Clients {
client.ID = id
service.clients[client.ClientID] = client
}
// Load the client secrets from files if they exist
for id, client := range service.clients {
secret := utils.GetSecret(client.ClientSecret, client.ClientSecretFile)
if secret != "" {
client.ClientSecret = secret
}
client.ClientSecretFile = ""
service.clients[id] = client
}
return nil
}
func (service *OIDCService) GetIssuer() string {
return service.issuer
}
func (service *OIDCService) GetClient(id string) (config.OIDCClientConfig, bool) {
client, ok := service.clients[id]
return client, ok
}
func (service *OIDCService) ValidateAuthorizeParams(req AuthorizeRequest) error {
// Validate client ID
client, ok := service.GetClient(req.ClientID)
if !ok {
return errors.New("access_denied")
}
// Scopes
scopes := strings.Split(req.Scope, " ")
if len(scopes) == 0 || strings.TrimSpace(req.Scope) == "" {
return errors.New("invalid_scope")
}
for _, scope := range scopes {
if strings.TrimSpace(scope) == "" {
return errors.New("invalid_scope")
}
if !slices.Contains(SupportedScopes, scope) {
tlog.App.Warn().Str("scope", scope).Msg("Unsupported OIDC scope, will be ignored")
}
}
// Response type
if !slices.Contains(SupportedResponseTypes, req.ResponseType) {
return errors.New("unsupported_response_type")
}
// Redirect URI
if !slices.Contains(client.TrustedRedirectURIs, req.RedirectURI) {
return errors.New("invalid_request_uri")
}
return nil
}
func (service *OIDCService) filterScopes(scopes []string) []string {
return utils.Filter(scopes, func(scope string) bool {
return slices.Contains(SupportedScopes, scope)
})
}
func (service *OIDCService) StoreCode(c *gin.Context, sub string, code string, req AuthorizeRequest) error {
// Fixed 10 minutes
expiresAt := time.Now().Add(time.Minute * time.Duration(10)).Unix()
// Insert the code into the database
_, err := service.queries.CreateOidcCode(c, repository.CreateOidcCodeParams{
Sub: sub,
CodeHash: service.Hash(code),
// Here it's safe to split and trust the output since, we validated the scopes before
Scope: strings.Join(service.filterScopes(strings.Split(req.Scope, " ")), ","),
RedirectURI: req.RedirectURI,
ClientID: req.ClientID,
ExpiresAt: expiresAt,
})
return err
}
func (service *OIDCService) StoreUserinfo(c *gin.Context, sub string, userContext config.UserContext, req AuthorizeRequest) error {
userInfoParams := repository.CreateOidcUserInfoParams{
Sub: sub,
Name: userContext.Name,
Email: userContext.Email,
PreferredUsername: userContext.Username,
UpdatedAt: time.Now().Unix(),
}
// Tinyauth will pass through the groups it got from an LDAP or an OIDC server
if userContext.Provider == "ldap" {
userInfoParams.Groups = userContext.LdapGroups
}
if userContext.OAuth && len(userContext.OAuthGroups) > 0 {
userInfoParams.Groups = userContext.OAuthGroups
}
_, err := service.queries.CreateOidcUserInfo(c, userInfoParams)
return err
}
func (service *OIDCService) ValidateGrantType(grantType string) error {
if !slices.Contains(SupportedGrantTypes, grantType) {
return errors.New("unsupported_grant_type")
}
return nil
}
func (service *OIDCService) GetCodeEntry(c *gin.Context, codeHash string) (repository.OidcCode, error) {
oidcCode, err := service.queries.GetOidcCode(c, codeHash)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return repository.OidcCode{}, ErrCodeNotFound
}
return repository.OidcCode{}, err
}
if time.Now().Unix() > oidcCode.ExpiresAt {
err = service.queries.DeleteOidcCode(c, codeHash)
if err != nil {
return repository.OidcCode{}, err
}
err = service.DeleteUserinfo(c, oidcCode.Sub)
if err != nil {
return repository.OidcCode{}, err
}
return repository.OidcCode{}, ErrCodeExpired
}
return oidcCode, nil
}
func (service *OIDCService) generateIDToken(client config.OIDCClientConfig, sub string) (string, error) {
createdAt := time.Now().Unix()
expiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry) * time.Second).Unix()
signer, err := jose.NewSigner(jose.SigningKey{
Algorithm: jose.RS256,
Key: service.privateKey,
}, &jose.SignerOptions{
ExtraHeaders: map[jose.HeaderKey]any{
"typ": "jwt",
"jku": fmt.Sprintf("%s/.well-known/jwks.json", service.issuer),
},
})
if err != nil {
return "", err
}
claims := ClaimSet{
Iss: service.issuer,
Aud: client.ClientID,
Sub: sub,
Iat: createdAt,
Exp: expiresAt,
}
payload, err := json.Marshal(claims)
if err != nil {
return "", err
}
object, err := signer.Sign(payload)
if err != nil {
return "", err
}
token, err := object.CompactSerialize()
if err != nil {
return "", err
}
return token, nil
}
func (service *OIDCService) GenerateAccessToken(c *gin.Context, client config.OIDCClientConfig, sub string, scope string) (TokenResponse, error) {
idToken, err := service.generateIDToken(client, sub)
if err != nil {
return TokenResponse{}, err
}
accessToken := rand.Text()
refreshToken := rand.Text()
tokenExpiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry) * time.Second).Unix()
// Refresh token lives double the time of an access token but can't be used to access userinfo
refrshTokenExpiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry*2) * time.Second).Unix()
tokenResponse := TokenResponse{
AccessToken: accessToken,
RefreshToken: refreshToken,
TokenType: "Bearer",
ExpiresIn: int64(service.config.SessionExpiry),
IDToken: idToken,
Scope: strings.ReplaceAll(scope, ",", " "),
}
_, err = service.queries.CreateOidcToken(c, repository.CreateOidcTokenParams{
Sub: sub,
AccessTokenHash: service.Hash(accessToken),
RefreshTokenHash: service.Hash(refreshToken),
ClientID: client.ClientID,
Scope: scope,
TokenExpiresAt: tokenExpiresAt,
RefreshTokenExpiresAt: refrshTokenExpiresAt,
})
if err != nil {
return TokenResponse{}, err
}
return tokenResponse, nil
}
func (service *OIDCService) RefreshAccessToken(c *gin.Context, refreshToken string, reqClientId string) (TokenResponse, error) {
entry, err := service.queries.GetOidcTokenByRefreshToken(c, service.Hash(refreshToken))
if err != nil {
if err == sql.ErrNoRows {
return TokenResponse{}, ErrTokenNotFound
}
return TokenResponse{}, err
}
if entry.RefreshTokenExpiresAt < time.Now().Unix() {
return TokenResponse{}, ErrTokenExpired
}
// Ensure the client ID in the request matches the client ID in the token
if entry.ClientID != reqClientId {
return TokenResponse{}, ErrInvalidClient
}
idToken, err := service.generateIDToken(config.OIDCClientConfig{
ClientID: entry.ClientID,
}, entry.Sub)
if err != nil {
return TokenResponse{}, err
}
accessToken := rand.Text()
newRefreshToken := rand.Text()
tokenExpiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry) * time.Second).Unix()
refrshTokenExpiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry*2) * time.Second).Unix()
tokenResponse := TokenResponse{
AccessToken: accessToken,
RefreshToken: newRefreshToken,
TokenType: "Bearer",
ExpiresIn: int64(service.config.SessionExpiry),
IDToken: idToken,
Scope: strings.ReplaceAll(entry.Scope, ",", " "),
}
_, err = service.queries.UpdateOidcTokenByRefreshToken(c, repository.UpdateOidcTokenByRefreshTokenParams{
AccessTokenHash: service.Hash(accessToken),
RefreshTokenHash: service.Hash(newRefreshToken),
TokenExpiresAt: tokenExpiresAt,
RefreshTokenExpiresAt: refrshTokenExpiresAt,
RefreshTokenHash_2: service.Hash(refreshToken), // that's the selector, it's not stored in the db
})
if err != nil {
return TokenResponse{}, err
}
return tokenResponse, nil
}
func (service *OIDCService) DeleteCodeEntry(c *gin.Context, codeHash string) error {
return service.queries.DeleteOidcCode(c, codeHash)
}
func (service *OIDCService) DeleteUserinfo(c *gin.Context, sub string) error {
return service.queries.DeleteOidcUserInfo(c, sub)
}
func (service *OIDCService) DeleteToken(c *gin.Context, tokenHash string) error {
return service.queries.DeleteOidcToken(c, tokenHash)
}
func (service *OIDCService) GetAccessToken(c *gin.Context, tokenHash string) (repository.OidcToken, error) {
entry, err := service.queries.GetOidcToken(c, tokenHash)
if err != nil {
if err == sql.ErrNoRows {
return repository.OidcToken{}, ErrTokenNotFound
}
return repository.OidcToken{}, err
}
if entry.TokenExpiresAt < time.Now().Unix() {
// If refresh token is expired, delete the token and userinfo since there is no way for the client to access anything anymore
if entry.RefreshTokenExpiresAt < time.Now().Unix() {
err := service.DeleteToken(c, tokenHash)
if err != nil {
return repository.OidcToken{}, err
}
err = service.DeleteUserinfo(c, entry.Sub)
if err != nil {
return repository.OidcToken{}, err
}
}
return repository.OidcToken{}, ErrTokenExpired
}
return entry, nil
}
func (service *OIDCService) GetUserinfo(c *gin.Context, sub string) (repository.OidcUserinfo, error) {
return service.queries.GetOidcUserInfo(c, sub)
}
func (service *OIDCService) CompileUserinfo(user repository.OidcUserinfo, scope string) UserinfoResponse {
scopes := strings.Split(scope, ",") // split by comma since it's a db entry
userInfo := UserinfoResponse{
Sub: user.Sub,
UpdatedAt: user.UpdatedAt,
}
if slices.Contains(scopes, "profile") {
userInfo.Name = user.Name
userInfo.PreferredUsername = user.PreferredUsername
}
if slices.Contains(scopes, "email") {
userInfo.Email = user.Email
}
if slices.Contains(scopes, "groups") {
if user.Groups != "" {
userInfo.Groups = strings.Split(user.Groups, ",")
} else {
userInfo.Groups = []string{}
}
}
return userInfo
}
func (service *OIDCService) Hash(token string) string {
hasher := sha256.New()
hasher.Write([]byte(token))
return fmt.Sprintf("%x", hasher.Sum(nil))
}
func (service *OIDCService) DeleteOldSession(ctx context.Context, sub string) error {
err := service.queries.DeleteOidcCodeBySub(ctx, sub)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return err
}
err = service.queries.DeleteOidcTokenBySub(ctx, sub)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return err
}
err = service.queries.DeleteOidcUserInfo(ctx, sub)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return err
}
return nil
}
// Cleanup routine - Resource heavy due to the linked tables
func (service *OIDCService) Cleanup() {
// We need a context for the routine
ctx := context.Background()
ticker := time.NewTicker(time.Duration(30) * time.Minute)
defer ticker.Stop()
for range ticker.C {
currentTime := time.Now().Unix()
// For the OIDC tokens, if they are expired we delete the userinfo and codes
expiredTokens, err := service.queries.DeleteExpiredOidcTokens(ctx, repository.DeleteExpiredOidcTokensParams{
TokenExpiresAt: currentTime,
RefreshTokenExpiresAt: currentTime,
})
if err != nil {
tlog.App.Warn().Err(err).Msg("Failed to delete expired tokens")
}
for _, expiredToken := range expiredTokens {
err := service.DeleteOldSession(ctx, expiredToken.Sub)
if err != nil {
tlog.App.Warn().Err(err).Msg("Failed to delete old session")
}
}
// For expired codes, we need to get the sub, check if tokens are expired and if they are remove everything
expiredCodes, err := service.queries.DeleteExpiredOidcCodes(ctx, currentTime)
if err != nil {
tlog.App.Warn().Err(err).Msg("Failed to delete expired codes")
}
for _, expiredCode := range expiredCodes {
token, err := service.queries.GetOidcTokenBySub(ctx, expiredCode.Sub)
if err != nil {
if err == sql.ErrNoRows {
continue
}
tlog.App.Warn().Err(err).Msg("Failed to get OIDC token by sub")
}
if token.TokenExpiresAt < currentTime && token.RefreshTokenExpiresAt < currentTime {
err := service.DeleteOldSession(ctx, expiredCode.Sub)
if err != nil {
tlog.App.Warn().Err(err).Msg("Failed to delete session")
}
}
}
}
}
func (service *OIDCService) GetJWK() ([]byte, error) {
jwk := jose.JSONWebKey{
Key: service.privateKey,
Algorithm: string(jose.RS256),
Use: "sig",
}
return jwk.Public().MarshalJSON()
}

View File

@@ -2,6 +2,7 @@ package utils
import (
"errors"
"fmt"
"net"
"net/url"
"strings"
@@ -95,7 +96,7 @@ func IsRedirectSafe(redirectURL string, domain string) bool {
hostname := parsed.Hostname()
if strings.HasSuffix(hostname, domain) {
if strings.HasSuffix(hostname, fmt.Sprintf(".%s", domain)) {
return true
}

View File

@@ -205,4 +205,9 @@ func TestIsRedirectSafe(t *testing.T) {
redirectURL = "http://example.org/page"
result = utils.IsRedirectSafe(redirectURL, domain)
assert.Equal(t, false, result)
// Case with malicious domain
redirectURL = "https://malicious-example.com/yoyo"
result = utils.IsRedirectSafe(redirectURL, domain)
assert.Equal(t, false, result)
}

113
sql/oidc_queries.sql Normal file
View File

@@ -0,0 +1,113 @@
-- name: CreateOidcCode :one
INSERT INTO "oidc_codes" (
"sub",
"code_hash",
"scope",
"redirect_uri",
"client_id",
"expires_at"
) VALUES (
?, ?, ?, ?, ?, ?
)
RETURNING *;
-- name: GetOidcCodeUnsafe :one
SELECT * FROM "oidc_codes"
WHERE "code_hash" = ?;
-- name: GetOidcCode :one
DELETE FROM "oidc_codes"
WHERE "code_hash" = ?
RETURNING *;
-- name: GetOidcCodeBySubUnsafe :one
SELECT * FROM "oidc_codes"
WHERE "sub" = ?;
-- name: GetOidcCodeBySub :one
DELETE FROM "oidc_codes"
WHERE "sub" = ?
RETURNING *;
-- name: DeleteOidcCode :exec
DELETE FROM "oidc_codes"
WHERE "code_hash" = ?;
-- name: DeleteOidcCodeBySub :exec
DELETE FROM "oidc_codes"
WHERE "sub" = ?;
-- name: CreateOidcToken :one
INSERT INTO "oidc_tokens" (
"sub",
"access_token_hash",
"refresh_token_hash",
"scope",
"client_id",
"token_expires_at",
"refresh_token_expires_at"
) VALUES (
?, ?, ?, ?, ?, ?, ?
)
RETURNING *;
-- name: UpdateOidcTokenByRefreshToken :one
UPDATE "oidc_tokens" SET
"access_token_hash" = ?,
"refresh_token_hash" = ?,
"token_expires_at" = ?,
"refresh_token_expires_at" = ?
WHERE "refresh_token_hash" = ?
RETURNING *;
-- name: GetOidcToken :one
SELECT * FROM "oidc_tokens"
WHERE "access_token_hash" = ?;
-- name: GetOidcTokenByRefreshToken :one
SELECT * FROM "oidc_tokens"
WHERE "refresh_token_hash" = ?;
-- name: GetOidcTokenBySub :one
SELECT * FROM "oidc_tokens"
WHERE "sub" = ?;
-- name: DeleteOidcToken :exec
DELETE FROM "oidc_tokens"
WHERE "access_token_hash" = ?;
-- name: DeleteOidcTokenBySub :exec
DELETE FROM "oidc_tokens"
WHERE "sub" = ?;
-- name: CreateOidcUserInfo :one
INSERT INTO "oidc_userinfo" (
"sub",
"name",
"preferred_username",
"email",
"groups",
"updated_at"
) VALUES (
?, ?, ?, ?, ?, ?
)
RETURNING *;
-- name: GetOidcUserInfo :one
SELECT * FROM "oidc_userinfo"
WHERE "sub" = ?;
-- name: DeleteOidcUserInfo :exec
DELETE FROM "oidc_userinfo"
WHERE "sub" = ?;
-- name: DeleteExpiredOidcCodes :many
DELETE FROM "oidc_codes"
WHERE "expires_at" < ?
RETURNING *;
-- name: DeleteExpiredOidcTokens :many
DELETE FROM "oidc_tokens"
WHERE "token_expires_at" < ? AND "refresh_token_expires_at" < ?
RETURNING *;

27
sql/oidc_schemas.sql Normal file
View File

@@ -0,0 +1,27 @@
CREATE TABLE IF NOT EXISTS "oidc_codes" (
"sub" TEXT NOT NULL UNIQUE,
"code_hash" TEXT NOT NULL PRIMARY KEY UNIQUE,
"scope" TEXT NOT NULL,
"redirect_uri" TEXT NOT NULL,
"client_id" TEXT NOT NULL,
"expires_at" INTEGER NOT NULL
);
CREATE TABLE IF NOT EXISTS "oidc_tokens" (
"sub" TEXT NOT NULL UNIQUE,
"access_token_hash" TEXT NOT NULL PRIMARY KEY UNIQUE,
"refresh_token_hash" TEXT NOT NULL,
"scope" TEXT NOT NULL,
"client_id" TEXT NOT NULL,
"token_expires_at" INTEGER NOT NULL,
"refresh_token_expires_at" INTEGER NOT NULL
);
CREATE TABLE IF NOT EXISTS "oidc_userinfo" (
"sub" TEXT NOT NULL UNIQUE PRIMARY KEY,
"name" TEXT NOT NULL,
"preferred_username" TEXT NOT NULL,
"email" TEXT NOT NULL,
"groups" TEXT NOT NULL,
"updated_at" INTEGER NOT NULL
);

View File

@@ -1,5 +1,5 @@
-- name: CreateSession :one
INSERT INTO sessions (
INSERT INTO "sessions" (
"uuid",
"username",
"email",

View File

@@ -1,8 +1,8 @@
version: "2"
sql:
- engine: "sqlite"
queries: "sql/queries.sql"
schema: "sql/schema.sql"
queries: "sql/*_queries.sql"
schema: "sql/*_schemas.sql"
gen:
go:
package: "repository"
@@ -12,6 +12,7 @@ sql:
oauth_groups: "OAuthGroups"
oauth_name: "OAuthName"
oauth_sub: "OAuthSub"
redirect_uri: "RedirectURI"
overrides:
- column: "sessions.oauth_groups"
go_type: "string"
@@ -19,3 +20,5 @@ sql:
go_type: "string"
- column: "sessions.oauth_sub"
go_type: "string"
- column: "sessions.ldap_groups"
go_type: "string"