Compare commits

...

9 Commits

Author SHA1 Message Date
Stavros
97e90ea560 feat: implement basic oidc functionality 2026-01-22 22:30:23 +02:00
Stavros
6ae7c1cbda wip: authorize page 2026-01-21 20:12:32 +02:00
Stavros
7dc3525a8d chore: add oidc base config 2026-01-21 18:54:00 +02:00
Stavros
402dfa727b chore: update traefik and add use infisical as an options for secrets in
dev
2026-01-21 12:50:03 +02:00
Stavros
d67c3ab8a4 fix: ensure safe redirect check only accepts actual domains 2026-01-17 20:36:42 +02:00
Stavros
87e2b52a04 fix: set gin mode correctly 2026-01-17 20:26:48 +02:00
dependabot[bot]
f36b62561a chore(deps): bump modernc.org/sqlite in the minor-patch group (#588)
Bumps the minor-patch group with 1 update: [modernc.org/sqlite](https://gitlab.com/cznic/sqlite).


Updates `modernc.org/sqlite` from 1.44.0 to 1.44.1
- [Commits](https://gitlab.com/cznic/sqlite/compare/v1.44.0...v1.44.1)

---
updated-dependencies:
- dependency-name: modernc.org/sqlite
  dependency-version: 1.44.1
  dependency-type: direct:production
  update-type: version-update:semver-patch
  dependency-group: minor-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-01-17 20:04:08 +02:00
dependabot[bot]
d2a146ead0 chore(deps-dev): bump @types/node in /frontend in the minor-patch group (#589)
Bumps the minor-patch group in /frontend with 1 update: [@types/node](https://github.com/DefinitelyTyped/DefinitelyTyped/tree/HEAD/types/node).


Updates `@types/node` from 25.0.8 to 25.0.9
- [Release notes](https://github.com/DefinitelyTyped/DefinitelyTyped/releases)
- [Commits](https://github.com/DefinitelyTyped/DefinitelyTyped/commits/HEAD/types/node)

---
updated-dependencies:
- dependency-name: "@types/node"
  dependency-version: 25.0.9
  dependency-type: direct:development
  update-type: version-update:semver-patch
  dependency-group: minor-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-01-17 20:03:51 +02:00
Stavros
4926e53409 feat: ldap group acls (#590)
* wip

* refactor: remove useless session struct abstraction

* feat: retrieve and store groups from ldap provider

* chore: fix merge issue

* refactor: rework ldap group fetching logic

* feat: store ldap group results in cache

* fix: review nitpicks

* fix: review feedback
2026-01-17 20:03:29 +02:00
47 changed files with 1450 additions and 134 deletions

View File

@@ -89,7 +89,7 @@ jobs:
- name: Build
run: |
cp -r frontend/dist internal/assets/dist
go build -ldflags "-s -w -X tinyauth/internal/config.Version=${{ needs.generate-metadata.outputs.VERSION }} -X tinyauth/internal/config.CommitHash=${{ needs.generate-metadata.outputs.COMMIT_HASH }} -X tinyauth/internal/config.BuildTimestamp=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}" -o tinyauth-amd64 ./cmd/tinyauth
go build -ldflags "-s -w -X github.com/steveiliop56/tinyauth/internal/config.Version=${{ needs.generate-metadata.outputs.VERSION }} -X github.com/steveiliop56/tinyauth/internal/config.CommitHash=${{ needs.generate-metadata.outputs.COMMIT_HASH }} -X github.com/steveiliop56/tinyauth/internal/config.BuildTimestamp=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}" -o tinyauth-amd64 ./cmd/tinyauth
env:
CGO_ENABLED: 0
@@ -144,7 +144,7 @@ jobs:
- name: Build
run: |
cp -r frontend/dist internal/assets/dist
go build -ldflags "-s -w -X tinyauth/internal/config.Version=${{ needs.generate-metadata.outputs.VERSION }} -X tinyauth/internal/config.CommitHash=${{ needs.generate-metadata.outputs.COMMIT_HASH }} -X tinyauth/internal/config.BuildTimestamp=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}" -o tinyauth-arm64 ./cmd/tinyauth
go build -ldflags "-s -w -X github.com/steveiliop56/tinyauth/internal/config.Version=${{ needs.generate-metadata.outputs.VERSION }} -X github.com/steveiliop56/tinyauth/internal/config.CommitHash=${{ needs.generate-metadata.outputs.COMMIT_HASH }} -X github.com/steveiliop56/tinyauth/internal/config.BuildTimestamp=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}" -o tinyauth-arm64 ./cmd/tinyauth
env:
CGO_ENABLED: 0

View File

@@ -67,7 +67,7 @@ jobs:
- name: Build
run: |
cp -r frontend/dist internal/assets/dist
go build -ldflags "-s -w -X tinyauth/internal/config.Version=${{ needs.generate-metadata.outputs.VERSION }} -X tinyauth/internal/config.CommitHash=${{ needs.generate-metadata.outputs.COMMIT_HASH }} -X tinyauth/internal/config.BuildTimestamp=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}" -o tinyauth-amd64 ./cmd/tinyauth
go build -ldflags "-s -w -X github.com/steveiliop56/tinyauth/internal/config.Version=${{ needs.generate-metadata.outputs.VERSION }} -X github.com/steveiliop56/tinyauth/internal/config.CommitHash=${{ needs.generate-metadata.outputs.COMMIT_HASH }} -X github.com/steveiliop56/tinyauth/internal/config.BuildTimestamp=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}" -o tinyauth-amd64 ./cmd/tinyauth
env:
CGO_ENABLED: 0
@@ -119,7 +119,7 @@ jobs:
- name: Build
run: |
cp -r frontend/dist internal/assets/dist
go build -ldflags "-s -w -X tinyauth/internal/config.Version=${{ needs.generate-metadata.outputs.VERSION }} -X tinyauth/internal/config.CommitHash=${{ needs.generate-metadata.outputs.COMMIT_HASH }} -X tinyauth/internal/config.BuildTimestamp=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}" -o tinyauth-arm64 ./cmd/tinyauth
go build -ldflags "-s -w -X github.com/steveiliop56/tinyauth/internal/config.Version=${{ needs.generate-metadata.outputs.VERSION }} -X github.com/steveiliop56/tinyauth/internal/config.CommitHash=${{ needs.generate-metadata.outputs.COMMIT_HASH }} -X github.com/steveiliop56/tinyauth/internal/config.BuildTimestamp=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}" -o tinyauth-arm64 ./cmd/tinyauth
env:
CGO_ENABLED: 0

3
.gitignore vendored
View File

@@ -36,3 +36,6 @@
# debug files
__debug_*
# infisical
/.infisical.json

13
.zed/debug.json Normal file
View File

@@ -0,0 +1,13 @@
[
{
"label": "Attach to remote Delve",
"adapter": "Delve",
"mode": "remote",
"remotePath": "/tinyauth",
"request": "attach",
"tcp_connection": {
"host": "127.0.0.1",
"port": 4000,
},
},
]

View File

@@ -39,7 +39,10 @@ COPY ./cmd ./cmd
COPY ./internal ./internal
COPY --from=frontend-builder /frontend/dist ./internal/assets/dist
RUN CGO_ENABLED=0 go build -ldflags "-s -w -X tinyauth/internal/config.Version=${VERSION} -X tinyauth/internal/config.CommitHash=${COMMIT_HASH} -X tinyauth/internal/config.BuildTimestamp=${BUILD_TIMESTAMP}" ./cmd/tinyauth
RUN CGO_ENABLED=0 go build -ldflags "-s -w \
-X github.com/steveiliop56/tinyauth/internal/config.Version=${VERSION} \
-X github.com/steveiliop56/tinyauth/internal/config.CommitHash=${COMMIT_HASH} \
-X github.com/steveiliop56/tinyauth/internal/config.BuildTimestamp=${BUILD_TIMESTAMP}" ./cmd/tinyauth
# Runner
FROM alpine:3.23 AS runner
@@ -54,11 +57,9 @@ EXPOSE 3000
VOLUME ["/data"]
ENV DATABASEPATH=/data/tinyauth.db
ENV TINYAUTH_DATABASEPATH=/data/tinyauth.db
ENV RESOURCESDIR=/data/resources
ENV GIN_MODE=release
ENV TINYAUTH_RESOURCESDIR=/data/resources
ENV PATH=$PATH:/tinyauth

View File

@@ -41,7 +41,10 @@ COPY --from=frontend-builder /frontend/dist ./internal/assets/dist
RUN mkdir -p data
RUN CGO_ENABLED=0 go build -ldflags "-s -w -X tinyauth/internal/config.Version=${VERSION} -X tinyauth/internal/config.CommitHash=${COMMIT_HASH} -X tinyauth/internal/config.BuildTimestamp=${BUILD_TIMESTAMP}" ./cmd/tinyauth
RUN CGO_ENABLED=0 go build -ldflags "-s -w \
-X github.com/steveiliop56/tinyauth/internal/config.Version=${VERSION} \
-X github.com/steveiliop56/tinyauth/internal/config.CommitHash=${COMMIT_HASH} \
-X github.com/steveiliop56/tinyauth/internal/config.BuildTimestamp=${BUILD_TIMESTAMP}" ./cmd/tinyauth
# Runner
FROM gcr.io/distroless/static-debian12:latest AS runner
@@ -61,8 +64,6 @@ ENV TINYAUTH_DATABASEPATH=/data/tinyauth.db
ENV TINYAUTH_RESOURCESDIR=/data/resources
ENV GIN_MODE=release
ENV PATH=$PATH:/tinyauth
HEALTHCHECK --interval=30s --timeout=5s --start-period=5s --retries=3 CMD ["tinyauth", "healthcheck"]

View File

@@ -31,9 +31,9 @@ webui: clean-webui
# Build the binary
binary: webui
CGO_ENABLED=$(CGO_ENABLED) go build -ldflags "-s -w \
-X tinyauth/internal/config.Version=${TAG_NAME} \
-X tinyauth/internal/config.CommitHash=${COMMIT_HASH} \
-X tinyauth/internal/config.BuildTimestamp=${BUILD_TIMESTAMP}" \
-X github.com/steveiliop56/tinyauth/internal/config.Version=${TAG_NAME} \
-X github.com/steveiliop56/tinyauth/internal/config.CommitHash=${COMMIT_HASH} \
-X github.com/steveiliop56/tinyauth/internal/config.BuildTimestamp=${BUILD_TIMESTAMP}" \
-o ${BIN_NAME} ./cmd/tinyauth
# Build for amd64
@@ -59,6 +59,15 @@ test:
develop:
docker compose -f $(DEV_COMPOSE) up --force-recreate --pull=always --remove-orphans
# Development - Infisical
develop-infisical:
infisical run --env=dev -- docker compose -f $(DEV_COMPOSE) up --force-recreate --pull=always --remove-orphans
# Production
prod:
docker compose -f $(PROD_COMPOSE) up --force-recreate --pull=always --remove-orphans
# SQL
.PHONY: sql
sql:
sqlc generate

View File

@@ -21,9 +21,9 @@ func NewTinyauthCmdConfiguration() *config.Config {
Address: "0.0.0.0",
},
Auth: config.AuthConfig{
SessionExpiry: 3600,
SessionMaxLifetime: 0,
LoginTimeout: 300,
SessionExpiry: 86400, // 1 day
SessionMaxLifetime: 0, // disabled
LoginTimeout: 300, // 5 minutes
LoginMaxRetries: 3,
},
UI: config.UIConfig{
@@ -32,8 +32,9 @@ func NewTinyauthCmdConfiguration() *config.Config {
BackgroundImage: "/background.jpg",
},
Ldap: config.LdapConfig{
Insecure: false,
SearchFilter: "(uid=%s)",
Insecure: false,
SearchFilter: "(uid=%s)",
GroupCacheTTL: 900, // 15 minutes
},
Log: config.LogConfig{
Level: "info",

View File

@@ -1,7 +1,7 @@
services:
traefik:
container_name: traefik
image: traefik:v3.3
image: traefik:v3.6
command: --api.insecure=true --providers.docker
ports:
- 80:80
@@ -50,3 +50,4 @@ services:
labels:
traefik.enable: true
traefik.http.middlewares.tinyauth.forwardauth.address: http://tinyauth-backend:3000/api/auth/traefik
traefik.http.middlewares.tinyauth.forwardauth.authResponseHeaders: remote-user, remote-sub, remote-name, remote-email, remote-groups

View File

@@ -1,7 +1,7 @@
services:
traefik:
container_name: traefik
image: traefik:v3.3
image: traefik:v3.6
command: --api.insecure=true --providers.docker
ports:
- 80:80

View File

@@ -36,7 +36,7 @@
"devDependencies": {
"@eslint/js": "^9.39.2",
"@tanstack/eslint-plugin-query": "^5.91.2",
"@types/node": "^25.0.8",
"@types/node": "^25.0.9",
"@types/react": "^19.2.8",
"@types/react-dom": "^19.2.3",
"@vitejs/plugin-react": "^5.1.2",
@@ -365,7 +365,7 @@
"@types/ms": ["@types/ms@2.1.0", "", {}, "sha512-GsCCIZDE/p3i96vtEqx+7dBUGXrc7zeSK3wwPHIaRThS+9OhWIXRqzs4d6k1SVU8g91DrNRWxWUGhp5KXQb2VA=="],
"@types/node": ["@types/node@25.0.8", "", { "dependencies": { "undici-types": "~7.16.0" } }, "sha512-powIePYMmC3ibL0UJ2i2s0WIbq6cg6UyVFQxSCpaPxxzAaziRfimGivjdF943sSGV6RADVbk0Nvlm5P/FB44Zg=="],
"@types/node": ["@types/node@25.0.9", "", { "dependencies": { "undici-types": "~7.16.0" } }, "sha512-/rpCXHlCWeqClNBwUhDcusJxXYDjZTyE8v5oTO7WbL8eij2nKhUeU89/6xgjU7N4/Vh3He0BtyhJdQbDyhiXAw=="],
"@types/react": ["@types/react@19.2.8", "", { "dependencies": { "csstype": "^3.2.2" } }, "sha512-3MbSL37jEchWZz2p2mjntRZtPt837ij10ApxKfgmXCTuHWagYg7iA5bqPw6C8BMPfwidlvfPI/fxOc42HLhcyg=="],

View File

@@ -42,7 +42,7 @@
"devDependencies": {
"@eslint/js": "^9.39.2",
"@tanstack/eslint-plugin-query": "^5.91.2",
"@types/node": "^25.0.8",
"@types/node": "^25.0.9",
"@types/react": "^19.2.8",
"@types/react-dom": "^19.2.3",
"@vitejs/plugin-react": "^5.1.2",

View File

@@ -17,6 +17,7 @@ import { AppContextProvider } from "./context/app-context.tsx";
import { UserContextProvider } from "./context/user-context.tsx";
import { Toaster } from "@/components/ui/sonner";
import { ThemeProvider } from "./components/providers/theme-provider.tsx";
import { AuthorizePage } from "./pages/authorize-page.tsx";
const queryClient = new QueryClient();
@@ -31,6 +32,7 @@ createRoot(document.getElementById("root")!).render(
<Route element={<Layout />} errorElement={<ErrorPage />}>
<Route path="/" element={<App />} />
<Route path="/login" element={<LoginPage />} />
<Route path="/authorize" element={<AuthorizePage />} />
<Route path="/logout" element={<LogoutPage />} />
<Route path="/continue" element={<ContinuePage />} />
<Route path="/totp" element={<TotpPage />} />

View File

@@ -0,0 +1,139 @@
import { useUserContext } from "@/context/user-context";
import { useMutation, useQuery } from "@tanstack/react-query";
import { Navigate, useNavigate } from "react-router";
import { useLocation } from "react-router";
import {
Card,
CardHeader,
CardTitle,
CardDescription,
CardFooter,
} from "@/components/ui/card";
import { getOidcClientInfoScehma } from "@/schemas/oidc-schemas";
import { Button } from "@/components/ui/button";
import axios from "axios";
import { toast } from "sonner";
type AuthorizePageProps = {
scope: string;
responseType: string;
clientId: string;
redirectUri: string;
state: string;
};
const optionalAuthorizeProps = ["state"];
export const AuthorizePage = () => {
const { isLoggedIn } = useUserContext();
const { search } = useLocation();
const navigate = useNavigate();
const searchParams = new URLSearchParams(search);
// If there is a better way to do this, please do let me know
const props: AuthorizePageProps = {
scope: searchParams.get("scope") || "",
responseType: searchParams.get("response_type") || "",
clientId: searchParams.get("client_id") || "",
redirectUri: searchParams.get("redirect_uri") || "",
state: searchParams.get("state") || "",
};
const getClientInfo = useQuery({
queryKey: ["client", props.clientId],
queryFn: async () => {
const res = await fetch(`/api/oidc/clients/${props.clientId}`);
const data = await getOidcClientInfoScehma.parseAsync(await res.json());
return data;
},
});
const authorizeMutation = useMutation({
mutationFn: () => {
return axios.post("/api/oidc/authorize", {
scope: props.scope,
response_type: props.responseType,
client_id: props.clientId,
redirect_uri: props.redirectUri,
state: props.state,
});
},
mutationKey: ["authorize", props.clientId],
onSuccess: (data) => {
toast.info("Authorized", {
description: "You will be soon redirected to your application",
});
window.location.replace(
`${data.data.redirect_uri}?code=${encodeURIComponent(data.data.code)}&state=${encodeURIComponent(data.data.state)}`,
);
},
onError: (error) => {
window.location.replace(
`/error?error=${encodeURIComponent(error.message)}`,
);
},
});
if (!isLoggedIn) {
// TODO: Pass the params to the login page, so user can login -> authorize
return <Navigate to="/login" replace />;
}
Object.keys(props).forEach((key) => {
if (
!props[key as keyof AuthorizePageProps] &&
!optionalAuthorizeProps.includes(key)
) {
// TODO: Add reason for error
return <Navigate to="/error" replace />;
}
});
if (getClientInfo.isLoading) {
return (
<Card className="min-w-xs sm:min-w-sm">
<CardHeader>
<CardTitle className="text-3xl">Loading...</CardTitle>
<CardDescription>
Please wait while we load the client information.
</CardDescription>
</CardHeader>
</Card>
);
}
if (getClientInfo.isError) {
// TODO: Add reason for error
return <Navigate to="/error" replace />;
}
return (
<Card className="min-w-xs sm:min-w-sm">
<CardHeader>
<CardTitle className="text-3xl">
Continue to {getClientInfo.data?.name || "Unknown"}?
</CardTitle>
<CardDescription>
Would you like to continue to this app? Please keep in mind that this
app will have access to your email and other information.
</CardDescription>
</CardHeader>
<CardFooter className="flex flex-col items-stretch gap-2">
<Button
onClick={() => authorizeMutation.mutate()}
loading={authorizeMutation.isPending}
>
Authorize
</Button>
<Button
onClick={() => navigate("/")}
disabled={authorizeMutation.isPending}
variant="outline"
>
Cancel
</Button>
</CardFooter>
</Card>
);
};

View File

@@ -50,10 +50,12 @@ export const LoginPage = () => {
const redirectUri = searchParams.get("redirect_uri");
const oauthProviders = providers.filter(
(provider) => provider.id !== "username",
(provider) => provider.id !== "local" && provider.id !== "ldap",
);
const userAuthConfigured =
providers.find((provider) => provider.id === "username") !== undefined;
providers.find(
(provider) => provider.id === "local" || provider.id === "ldap",
) !== undefined;
const oauthMutation = useMutation({
mutationFn: (provider: string) =>

View File

@@ -0,0 +1,5 @@
import { z } from "zod";
export const getOidcClientInfoScehma = z.object({
name: z.string(),
});

4
go.mod
View File

@@ -24,7 +24,7 @@ require (
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546
golang.org/x/oauth2 v0.34.0
gotest.tools/v3 v3.5.2
modernc.org/sqlite v1.44.0
modernc.org/sqlite v1.44.1
)
require (
@@ -119,7 +119,7 @@ require (
golang.org/x/text v0.33.0 // indirect
google.golang.org/protobuf v1.36.9 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
modernc.org/libc v1.67.4 // indirect
modernc.org/libc v1.67.6 // indirect
modernc.org/mathutil v1.7.1 // indirect
modernc.org/memory v1.11.0 // indirect
rsc.io/qr v0.2.0 // indirect

8
go.sum
View File

@@ -383,8 +383,8 @@ modernc.org/gc/v3 v3.1.1 h1:k8T3gkXWY9sEiytKhcgyiZ2L0DTyCQ/nvX+LoCljoRE=
modernc.org/gc/v3 v3.1.1/go.mod h1:HFK/6AGESC7Ex+EZJhJ2Gni6cTaYpSMmU/cT9RmlfYY=
modernc.org/goabi0 v0.2.0 h1:HvEowk7LxcPd0eq6mVOAEMai46V+i7Jrj13t4AzuNks=
modernc.org/goabi0 v0.2.0/go.mod h1:CEFRnnJhKvWT1c1JTI3Avm+tgOWbkOu5oPA8eH8LnMI=
modernc.org/libc v1.67.4 h1:zZGmCMUVPORtKv95c2ReQN5VDjvkoRm9GWPTEPuvlWg=
modernc.org/libc v1.67.4/go.mod h1:QvvnnJ5P7aitu0ReNpVIEyesuhmDLQ8kaEoyMjIFZJA=
modernc.org/libc v1.67.6 h1:eVOQvpModVLKOdT+LvBPjdQqfrZq+pC39BygcT+E7OI=
modernc.org/libc v1.67.6/go.mod h1:JAhxUVlolfYDErnwiqaLvUqc8nfb2r6S6slAgZOnaiE=
modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU=
modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg=
modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI=
@@ -393,8 +393,8 @@ modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8=
modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns=
modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w=
modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE=
modernc.org/sqlite v1.44.0 h1:YjCKJnzZde2mLVy0cMKTSL4PxCmbIguOq9lGp8ZvGOc=
modernc.org/sqlite v1.44.0/go.mod h1:2Dq41ir5/qri7QJJJKNZcP4UF7TsX/KNeykYgPDtGhE=
modernc.org/sqlite v1.44.1 h1:qybx/rNpfQipX/t47OxbHmkkJuv2JWifCMH8SVUiDas=
modernc.org/sqlite v1.44.1/go.mod h1:CzbrU2lSB1DKUusvwGz7rqEKIq+NUd8GWuBBZDs9/nA=
modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0=
modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A=
modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=

View File

@@ -0,0 +1,3 @@
DROP TABLE IF EXISTS "oidc_tokens";
DROP TABLE IF EXISTS "oidc_userinfo";
DROP TABLE IF EXISTS "oidc_codes";

View File

@@ -0,0 +1,25 @@
CREATE TABLE IF NOT EXISTS "oidc_codes" (
"sub" TEXT NOT NULL UNIQUE,
"code" TEXT NOT NULL PRIMARY KEY UNIQUE,
"scope" TEXT NOT NULL,
"redirect_uri" TEXT NOT NULL,
"client_id" TEXT NOT NULL,
"expires_at" INTEGER NOT NULL
);
CREATE TABLE IF NOT EXISTS "oidc_tokens" (
"sub" TEXT NOT NULL UNIQUE,
"access_token" TEXT NOT NULL PRIMARY KEY UNIQUE,
"scope" TEXT NOT NULL,
"client_id" TEXT NOT NULL,
"expires_at" INTEGER NOT NULL
);
CREATE TABLE IF NOT EXISTS "oidc_userinfo" (
"sub" TEXT NOT NULL UNIQUE PRIMARY KEY,
"name" TEXT NOT NULL,
"preferred_username" TEXT NOT NULL,
"email" TEXT NOT NULL,
"groups" TEXT NOT NULL,
"updated_at" INTEGER NOT NULL
);

View File

@@ -30,6 +30,7 @@ type BootstrapApp struct {
users []config.User
oauthProviders map[string]config.OAuthServiceConfig
configuredProviders []controller.Provider
oidcClients []config.OIDCClientConfig
}
services Services
}
@@ -84,6 +85,12 @@ func (app *BootstrapApp) Setup() error {
app.context.oauthProviders[id] = provider
}
// Setup OIDC clients
for id, client := range app.config.OIDC.Clients {
client.ID = id
app.context.oidcClients = append(app.context.oidcClients, client)
}
// Get cookie domain
cookieDomain, err := utils.GetCookieDomain(app.config.AppURL)
@@ -144,10 +151,18 @@ func (app *BootstrapApp) Setup() error {
return configuredProviders[i].Name < configuredProviders[j].Name
})
if services.authService.UserAuthConfigured() {
if services.authService.LocalAuthConfigured() {
configuredProviders = append(configuredProviders, controller.Provider{
Name: "Username",
ID: "username",
Name: "Local",
ID: "local",
OAuth: false,
})
}
if services.authService.LdapAuthConfigured() {
configuredProviders = append(configuredProviders, controller.Provider{
Name: "LDAP",
ID: "ldap",
OAuth: false,
})
}
@@ -161,7 +176,7 @@ func (app *BootstrapApp) Setup() error {
app.context.configuredProviders = configuredProviders
// Setup router
router, err := app.setupRouter()
router, err := app.setupRouter(queries)
if err != nil {
return fmt.Errorf("failed to setup routes: %w", err)

View File

@@ -2,14 +2,23 @@ package bootstrap
import (
"fmt"
"slices"
"github.com/steveiliop56/tinyauth/internal/config"
"github.com/steveiliop56/tinyauth/internal/controller"
"github.com/steveiliop56/tinyauth/internal/middleware"
"github.com/steveiliop56/tinyauth/internal/repository"
"github.com/gin-gonic/gin"
)
func (app *BootstrapApp) setupRouter() (*gin.Engine, error) {
var DEV_MODES = []string{"main", "test", "development"}
func (app *BootstrapApp) setupRouter(queries *repository.Queries) (*gin.Engine, error) {
if !slices.Contains(DEV_MODES, config.Version) {
gin.SetMode(gin.ReleaseMode)
}
engine := gin.New()
engine.Use(gin.Recovery())
@@ -78,6 +87,13 @@ func (app *BootstrapApp) setupRouter() (*gin.Engine, error) {
oauthController.SetupRoutes()
oidcController := controller.NewOIDCController(controller.OIDCControllerConfig{
Clients: app.context.oidcClients,
AppURL: app.config.AppURL,
}, apiRouter, queries)
oidcController.SetupRoutes()
proxyController := controller.NewProxyController(controller.ProxyControllerConfig{
AppURL: app.config.AppURL,
}, apiRouter, app.services.accessControlService, app.services.authService)

View File

@@ -67,6 +67,7 @@ func (app *BootstrapApp) initServices(queries *repository.Queries) (Services, er
LoginMaxRetries: app.config.Auth.LoginMaxRetries,
SessionCookieName: app.context.sessionCookieName,
IP: app.config.Auth.IP,
LDAPGroupsCacheTTL: app.config.Ldap.GroupCacheTTL,
}, dockerService, services.ldapService, queries)
err = authService.Init()

View File

@@ -25,6 +25,7 @@ type Config struct {
Auth AuthConfig `description:"Authentication configuration." yaml:"auth"`
Apps map[string]App `description:"Application ACLs configuration." yaml:"apps"`
OAuth OAuthConfig `description:"OAuth configuration." yaml:"oauth"`
OIDC OIDCConfig `description:"OIDC configuration." yaml:"oidc"`
UI UIConfig `description:"UI customization." yaml:"ui"`
Ldap LdapConfig `description:"LDAP configuration." yaml:"ldap"`
Experimental ExperimentalConfig `description:"Experimental features, use with caution." yaml:"experimental"`
@@ -60,6 +61,10 @@ type OAuthConfig struct {
Providers map[string]OAuthServiceConfig `description:"OAuth providers configuration." yaml:"providers"`
}
type OIDCConfig struct {
Clients map[string]OIDCClientConfig `description:"OIDC clients configuration." yaml:"clients"`
}
type UIConfig struct {
Title string `description:"The title of the UI." yaml:"title"`
ForgotPasswordMessage string `description:"Message displayed on the forgot password page." yaml:"forgotPasswordMessage"`
@@ -67,14 +72,15 @@ type UIConfig struct {
}
type LdapConfig struct {
Address string `description:"LDAP server address." yaml:"address"`
BindDN string `description:"Bind DN for LDAP authentication." yaml:"bindDn"`
BindPassword string `description:"Bind password for LDAP authentication." yaml:"bindPassword"`
BaseDN string `description:"Base DN for LDAP searches." yaml:"baseDn"`
Insecure bool `description:"Allow insecure LDAP connections." yaml:"insecure"`
SearchFilter string `description:"LDAP search filter." yaml:"searchFilter"`
AuthCert string `description:"Certificate for mTLS authentication." yaml:"authCert"`
AuthKey string `description:"Certificate key for mTLS authentication." yaml:"authKey"`
Address string `description:"LDAP server address." yaml:"address"`
BindDN string `description:"Bind DN for LDAP authentication." yaml:"bindDn"`
BindPassword string `description:"Bind password for LDAP authentication." yaml:"bindPassword"`
BaseDN string `description:"Base DN for LDAP searches." yaml:"baseDn"`
Insecure bool `description:"Allow insecure LDAP connections." yaml:"insecure"`
SearchFilter string `description:"LDAP search filter." yaml:"searchFilter"`
AuthCert string `description:"Certificate for mTLS authentication." yaml:"authCert"`
AuthKey string `description:"Certificate key for mTLS authentication." yaml:"authKey"`
GroupCacheTTL int `description:"Cache duration for LDAP group membership in seconds." yaml:"groupCacheTTL"`
}
type LogConfig struct {
@@ -113,16 +119,25 @@ type Claims struct {
}
type OAuthServiceConfig struct {
ClientID string `description:"OAuth client ID."`
ClientSecret string `description:"OAuth client secret."`
ClientSecretFile string `description:"Path to the file containing the OAuth client secret."`
Scopes []string `description:"OAuth scopes."`
RedirectURL string `description:"OAuth redirect URL."`
AuthURL string `description:"OAuth authorization URL."`
TokenURL string `description:"OAuth token URL."`
UserinfoURL string `description:"OAuth userinfo URL."`
Insecure bool `description:"Allow insecure OAuth connections."`
Name string `description:"Provider name in UI."`
ClientID string `description:"OAuth client ID." yaml:"clientId"`
ClientSecret string `description:"OAuth client secret." yaml:"clientSecret"`
ClientSecretFile string `description:"Path to the file containing the OAuth client secret." yaml:"clientSecretFile"`
Scopes []string `description:"OAuth scopes." yaml:"scopes"`
RedirectURL string `description:"OAuth redirect URL." yaml:"redirectUrl"`
AuthURL string `description:"OAuth authorization URL." yaml:"authUrl"`
TokenURL string `description:"OAuth token URL." yaml:"tokenUrl"`
UserinfoURL string `description:"OAuth userinfo URL." yaml:"userinfoUrl"`
Insecure bool `description:"Allow insecure OAuth connections." yaml:"insecure"`
Name string `description:"Provider name in UI." yaml:"name"`
}
type OIDCClientConfig struct {
ID string `description:"OIDC client ID." yaml:"-"`
ClientID string `description:"OIDC client ID." yaml:"clientId"`
ClientSecret string `description:"OIDC client secret." yaml:"clientSecret"`
ClientSecretFile string `description:"Path to the file containing the OIDC client secret." yaml:"clientSecretFile"`
TrustedRedirectURLs []string `description:"List of trusted redirect URLs." yaml:"trustedRedirectUrls"`
Name string `description:"Client name in UI." yaml:"name"`
}
var OverrideProviders = map[string]string{
@@ -138,28 +153,22 @@ type User struct {
TotpSecret string
}
type LdapUser struct {
DN string
Groups []string
}
type UserSearch struct {
Username string
Type string // local, ldap or unknown
}
type SessionCookie struct {
UUID string
Username string
Name string
Email string
Provider string
TotpPending bool
OAuthGroups string
OAuthName string
OAuthSub string
}
type UserContext struct {
Username string
Name string
Email string
IsLoggedIn bool
IsBasicAuth bool
OAuth bool
Provider string
TotpPending bool
@@ -167,6 +176,7 @@ type UserContext struct {
TotpEnabled bool
OAuthName string
OAuthSub string
LdapGroups string
}
// API responses and queries
@@ -195,6 +205,7 @@ type App struct {
IP AppIP `description:"IP access configuration." yaml:"ip"`
Response AppResponse `description:"Response customization." yaml:"response"`
Path AppPath `description:"Path access configuration." yaml:"path"`
LDAP AppLDAP `description:"LDAP access configuration." yaml:"ldap"`
}
type AppConfig struct {
@@ -211,6 +222,10 @@ type AppOAuth struct {
Groups string `description:"Comma-separated list of required OAuth groups." yaml:"groups"`
}
type AppLDAP struct {
Groups string `description:"Comma-separated list of required LDAP groups." yaml:"groups"`
}
type AppIP struct {
Allow []string `description:"List of allowed IPs or CIDR ranges." yaml:"allow"`
Block []string `description:"List of blocked IPs or CIDR ranges." yaml:"block"`

View File

@@ -21,7 +21,6 @@ type UserContextResponse struct {
OAuth bool `json:"oauth"`
TotpPending bool `json:"totpPending"`
OAuthName string `json:"oauthName"`
OAuthSub string `json:"oauthSub"`
}
type AppContextResponse struct {
@@ -90,7 +89,6 @@ func (controller *ContextController) userContextHandler(c *gin.Context) {
OAuth: context.OAuth,
TotpPending: context.TotpPending,
OAuthName: context.OAuthName,
OAuthSub: context.OAuthSub,
}
if err != nil {

View File

@@ -16,8 +16,8 @@ import (
var controllerCfg = controller.ContextControllerConfig{
Providers: []controller.Provider{
{
Name: "Username",
ID: "username",
Name: "Local",
ID: "local",
OAuth: false,
},
{
@@ -40,8 +40,9 @@ var userContext = config.UserContext{
Name: "testuser",
Email: "test@example.com",
IsLoggedIn: true,
IsBasicAuth: false,
OAuth: false,
Provider: "username",
Provider: "local",
TotpPending: false,
OAuthGroups: "",
TotpEnabled: false,

View File

@@ -7,6 +7,7 @@ import (
"time"
"github.com/steveiliop56/tinyauth/internal/config"
"github.com/steveiliop56/tinyauth/internal/repository"
"github.com/steveiliop56/tinyauth/internal/service"
"github.com/steveiliop56/tinyauth/internal/utils"
"github.com/steveiliop56/tinyauth/internal/utils/tlog"
@@ -188,10 +189,10 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
username = user.PreferredUsername
} else {
tlog.App.Debug().Msg("No preferred username from OAuth provider, using pseudo username")
username = strings.Replace(user.Email, "@", "_", -1)
username = strings.Replace(user.Email, "@", "_", 1)
}
sessionCookie := config.SessionCookie{
sessionCookie := repository.Session{
Username: username,
Name: name,
Email: user.Email,

View File

@@ -0,0 +1,501 @@
package controller
import (
"fmt"
"slices"
"strconv"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/google/go-querystring/query"
"github.com/steveiliop56/tinyauth/internal/config"
"github.com/steveiliop56/tinyauth/internal/repository"
"github.com/steveiliop56/tinyauth/internal/utils"
"github.com/steveiliop56/tinyauth/internal/utils/tlog"
)
var (
SupportedResponseTypes = []string{"code"}
SupportedScopes = []string{"openid", "profile", "email", "groups"}
SupportedGrantTypes = []string{"authorization_code"}
)
type OIDCControllerConfig struct {
Clients []config.OIDCClientConfig
AppURL string
}
type OIDCController struct {
config OIDCControllerConfig
router *gin.RouterGroup
queries *repository.Queries
}
type AuthorizeRequest struct {
Scope string `json:"scope" binding:"required"`
ResponseType string `json:"response_type" binding:"required"`
ClientID string `json:"client_id" binding:"required"`
RedirectURI string `json:"redirect_uri" binding:"required"`
State string `json:"state" binding:"required"`
}
type TokenRequest struct {
GrantType string `form:"grant_type" binding:"required"`
Code string `form:"code" binding:"required"`
RedirectURI string `form:"redirect_uri" binding:"required"`
}
type CallbackError struct {
Error string `url:"error"`
ErrorDescription string `url:"error_description"`
State string `url:"state"`
}
func NewOIDCController(config OIDCControllerConfig, router *gin.RouterGroup, queries *repository.Queries) *OIDCController {
return &OIDCController{
config: config,
router: router,
queries: queries,
}
}
func (controller *OIDCController) SetupRoutes() {
oidcGroup := controller.router.Group("/oidc")
oidcGroup.GET("/clients/:id", controller.GetClientInfo)
oidcGroup.POST("/authorize", controller.Authorize)
oidcGroup.POST("/token", controller.Token)
oidcGroup.GET("/userinfo", controller.Userinfo)
}
type ClientRequest struct {
ClientID string `uri:"id" binding:"required"`
}
func (controller *OIDCController) GetClientInfo(c *gin.Context) {
var req ClientRequest
err := c.BindUri(&req)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to bind URI")
c.JSON(400, gin.H{
"status": 400,
"message": "Bad Request",
})
return
}
var client *config.OIDCClientConfig
// Inefficient yeah, but it will be good until we have thousands of clients
for _, clientCfg := range controller.config.Clients {
if clientCfg.ClientID == req.ClientID {
client = &clientCfg
break
}
}
if client == nil {
tlog.App.Warn().Str("client_id", req.ClientID).Msg("Client not found")
c.JSON(404, gin.H{
"status": 404,
"message": "Client not found",
})
return
}
c.JSON(200, gin.H{
"status": 200,
"client": &client.ClientID,
"name": &client.Name,
})
}
func (controller *OIDCController) Authorize(c *gin.Context) {
// Check if we are logged in
userContext, err := utils.GetContext(c)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to get user context")
c.JSON(401, gin.H{
"status": 401,
"message": "Unauthorized",
})
return
}
// OIDC stuff
var req AuthorizeRequest
err = c.BindJSON(&req)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to bind JSON")
c.JSON(400, gin.H{
"status": 400,
"message": "Bad Request",
})
return
}
// TODO: All these errors should redirect to the error page with an explanation
// Validate client ID
var client *config.OIDCClientConfig
for _, clientCfg := range controller.config.Clients {
if clientCfg.ClientID == req.ClientID {
client = &clientCfg
break
}
}
if client == nil {
tlog.App.Warn().Str("client_id", req.ClientID).Msg("Client not found")
c.JSON(404, gin.H{
"status": 404,
"message": "Client not found",
})
return
}
// Validate redirect URI
if !slices.Contains(client.TrustedRedirectURLs, req.RedirectURI) {
tlog.App.Warn().Str("redirect_uri", req.RedirectURI).Msg("Redirect URI not trusted")
c.JSON(400, gin.H{
"status": 400,
"message": "Bad Request",
})
return
}
// Validate scopes
reqScopes := strings.Split(req.Scope, " ")
keptScopes := make([]string, 0)
if len(reqScopes) == 0 || strings.TrimSpace(req.Scope) == "" {
queries, err := query.Values(CallbackError{
Error: "invalid_request",
ErrorDescription: "Missing scope parameter",
State: req.State,
})
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to build query")
c.Redirect(302, fmt.Sprintf("%s/error", controller.config.AppURL))
return
}
c.Redirect(302, fmt.Sprintf("%s/callback?%s", req.RedirectURI, queries.Encode()))
return
}
for _, scope := range reqScopes {
if slices.Contains(SupportedScopes, scope) {
keptScopes = append(keptScopes, scope)
continue
}
tlog.App.Warn().Str("scope", scope).Msg("Scope not supported, ignoring")
}
// Generate a code and a sub
code, err := utils.GetRandomString(32)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to generate random string")
c.Redirect(302, fmt.Sprintf("%s/error", controller.config.AppURL))
return
}
sub, err := utils.GetRandomInt(10)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to generate random integer")
c.Redirect(302, fmt.Sprintf("%s/error", controller.config.AppURL))
return
}
expiresAt := time.Now().Add(time.Minute * time.Duration(10)).Unix()
// Insert the code into the database
_, err = controller.queries.CreateOidcCode(c, repository.CreateOidcCodeParams{
Code: code,
Sub: strconv.Itoa(int(sub)),
Scope: strings.Join(keptScopes, ","),
RedirectURI: req.RedirectURI,
ClientID: client.ClientID,
ExpiresAt: expiresAt,
})
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to insert code into database")
c.Redirect(302, fmt.Sprintf("%s/error", controller.config.AppURL))
return
}
// We also need a snapshot of the user that authorized this
userInfoParams := repository.CreateOidcUserInfoParams{
Sub: strconv.Itoa(int(sub)),
Name: userContext.Name,
Email: userContext.Email,
PreferredUsername: userContext.Username,
UpdatedAt: time.Now().Unix(),
}
if userContext.Provider == "ldap" {
userInfoParams.Groups = userContext.LdapGroups
}
if userContext.OAuth && len(userContext.OAuthGroups) > 0 {
userInfoParams.Groups = userContext.OAuthGroups
}
_, err = controller.queries.CreateOidcUserInfo(c, userInfoParams)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to insert user info into database")
c.Redirect(302, fmt.Sprintf("%s/error", controller.config.AppURL))
return
}
// Return code and done
c.JSON(200, gin.H{
"status": 200,
"message": "Authorized",
"code": code,
"state": req.State,
"redirect_uri": req.RedirectURI,
})
}
func (controller *OIDCController) Token(c *gin.Context) {
// Get basic auth
clientId, clientSecret, ok := c.Request.BasicAuth()
if !ok {
tlog.App.Error().Msg("Missing token verifier")
c.JSON(400, gin.H{
"error": "invalid_request",
})
return
}
// Ensure client exists
var client *config.OIDCClientConfig
for _, clientCfg := range controller.config.Clients {
if clientCfg.ClientID == clientId {
client = &clientCfg
break
}
}
if client == nil {
tlog.App.Warn().Str("client_id", clientId).Msg("Client not found")
c.JSON(400, gin.H{
"error": "invalid_request",
})
return
}
if client.ClientSecret != clientSecret {
tlog.App.Warn().Str("client_id", clientId).Msg("Invalid client secret")
c.JSON(400, gin.H{
"error": "invalid_client",
})
return
}
// Get token
var req TokenRequest
err := c.Bind(&req)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to bind token request")
c.JSON(400, gin.H{
"error": "invalid_request",
})
return
}
// Validate grant type
if !slices.Contains(SupportedGrantTypes, req.GrantType) {
tlog.App.Warn().Str("grant_type", req.GrantType).Msg("Unsupported grant type")
c.JSON(400, gin.H{
"error": "unsupported_grant_type",
})
return
}
// Find pending code entry
entry, err := controller.queries.GetOidcCode(c, req.Code)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to find code in database")
c.JSON(400, gin.H{
"error": "invalid_request",
})
return
}
// Ensure redirect URIs match
if entry.RedirectURI != req.RedirectURI {
tlog.App.Warn().Str("redirect_uri", req.RedirectURI).Msg("Redirect URI mismatch")
c.JSON(400, gin.H{
"error": "invalid_request",
})
return
}
// Generate access token
genToken, err := utils.GetRandomString(29)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to generate access token")
c.JSON(400, gin.H{
"error": "invalid_request",
})
return
}
// Add tinyauth prefix
token := fmt.Sprintf("ta-%s", genToken)
// TODO: either add a refresh token or customize token expiry
expiresAt := time.Now().Add(time.Duration(3600) * time.Second).Unix()
// Create token entry
_, err = controller.queries.CreateOidcToken(c, repository.CreateOidcTokenParams{
Sub: entry.Sub,
AccessToken: token,
Scope: entry.Scope,
ClientID: client.ClientID,
ExpiresAt: expiresAt,
})
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to create token in database")
c.JSON(400, gin.H{
"error": "invalid_request",
})
return
}
// Delete code entry
err = controller.queries.DeleteOidcCode(c, entry.Code)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to delete code in database")
c.JSON(400, gin.H{
"error": "invalid_request",
})
return
}
// Respond with token
c.JSON(200, gin.H{
"access_token": token,
"token_type": "bearer",
"expires_in": 3600,
})
}
func (controller *OIDCController) Userinfo(c *gin.Context) {
// Get bearer
authorizationHeader := c.GetHeader("Authorization")
tokenType, token, ok := strings.Cut(authorizationHeader, " ")
if !ok {
tlog.App.Warn().Msg("OIDC userinfo accessed without authorization header")
c.JSON(401, gin.H{
"error": "invalid_request",
})
return
}
if strings.ToLower(tokenType) != "bearer" {
tlog.App.Warn().Msg("OIDC userinfo accessed with invalid token type")
c.JSON(401, gin.H{
"error": "invalid_request",
})
return
}
// Get token entry
entry, err := controller.queries.GetOidcToken(c, token)
if err != nil {
tlog.App.Err(err).Msg("Failed to get token entry")
c.JSON(401, gin.H{
"error": "invalid_request",
})
return
}
// Get scopes
scopes := strings.Split(entry.Scope, ",")
// Check if token is expired
if time.Now().Unix() > entry.ExpiresAt {
tlog.App.Warn().Msg("OIDC userinfo accessed with expired token")
err = controller.queries.DeleteOidcToken(c, entry.AccessToken)
if err != nil {
tlog.App.Err(err).Msg("Failed to delete expired token")
}
err = controller.queries.DeleteOidcUserInfo(c, entry.Sub)
if err != nil {
tlog.App.Err(err).Msg("Failed to delete oidc user info")
}
c.JSON(401, gin.H{
"error": "invalid_request",
})
return
}
// Get user info
user, err := controller.queries.GetOidcUserInfo(c, entry.Sub)
if err != nil {
tlog.App.Err(err).Msg("Failed to get user entry")
c.JSON(401, gin.H{
"error": "invalid_request",
})
return
}
// If we don't have the openid scope, return an error
if !slices.Contains(scopes, "openid") {
tlog.App.Warn().Msg("OIDC userinfo accessed without openid scope")
c.JSON(401, gin.H{
"error": "invalid_request",
})
return
}
// Let's build the response
res := map[string]any{
"sub": user.Sub,
"updated_at": user.UpdatedAt,
}
// If we have the profile scope, add the profile stuff
if slices.Contains(scopes, "profile") {
res["name"] = user.Name
res["preferred_username"] = user.PreferredUsername
}
// If we have the email scope, add the email stuff
if slices.Contains(scopes, "email") {
res["email"] = user.Email
}
// If we have the groups scope, add the groups stuff
if slices.Contains(scopes, "groups") {
res["groups"] = user.Groups
}
c.JSON(200, res)
}

View File

@@ -173,7 +173,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
tlog.App.Trace().Interface("context", userContext).Msg("User context from request")
if userContext.Provider == "basic" && userContext.TotpEnabled {
if userContext.IsBasicAuth && userContext.TotpEnabled {
tlog.App.Debug().Msg("User has TOTP enabled, denying basic auth access")
userContext.IsLoggedIn = false
}
@@ -212,11 +212,17 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
return
}
if userContext.OAuth {
groupOK := controller.auth.IsInOAuthGroup(c, userContext, acls.OAuth.Groups)
if userContext.OAuth || userContext.Provider == "ldap" {
var groupOK bool
if userContext.OAuth {
groupOK = controller.auth.IsInOAuthGroup(c, userContext, acls.OAuth.Groups)
} else {
groupOK = controller.auth.IsInLdapGroup(c, userContext, acls.LDAP.Groups)
}
if !groupOK {
tlog.App.Warn().Str("user", userContext.Username).Str("resource", strings.Split(host, ".")[0]).Msg("User OAuth groups do not match resource requirements")
tlog.App.Warn().Str("user", userContext.Username).Str("resource", strings.Split(host, ".")[0]).Msg("User groups do not match resource requirements")
if req.Proxy == "nginx" || !isBrowser {
c.JSON(403, gin.H{
@@ -251,7 +257,13 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
c.Header("Remote-User", utils.SanitizeHeader(userContext.Username))
c.Header("Remote-Name", utils.SanitizeHeader(userContext.Name))
c.Header("Remote-Email", utils.SanitizeHeader(userContext.Email))
c.Header("Remote-Groups", utils.SanitizeHeader(userContext.OAuthGroups))
if userContext.Provider == "ldap" {
c.Header("Remote-Groups", utils.SanitizeHeader(userContext.LdapGroups))
} else if userContext.Provider != "local" {
c.Header("Remote-Groups", utils.SanitizeHeader(userContext.OAuthGroups))
}
c.Header("Remote-Sub", utils.SanitizeHeader(userContext.OAuthSub))
controller.setHeaders(c, acls)

View File

@@ -143,11 +143,11 @@ func TestProxyHandler(t *testing.T) {
// Test logged in user
c := gin.CreateTestContextOnly(recorder, router)
err := authService.CreateSessionCookie(c, &config.SessionCookie{
err := authService.CreateSessionCookie(c, &repository.Session{
Username: "testuser",
Name: "testuser",
Email: "testuser@example.com",
Provider: "username",
Provider: "local",
TotpPending: false,
OAuthGroups: "",
})
@@ -164,7 +164,7 @@ func TestProxyHandler(t *testing.T) {
Email: "testuser@example.com",
IsLoggedIn: true,
OAuth: false,
Provider: "username",
Provider: "local",
TotpPending: false,
OAuthGroups: "",
TotpEnabled: false,
@@ -192,8 +192,9 @@ func TestProxyHandler(t *testing.T) {
Name: "testuser",
Email: "testuser@example.com",
IsLoggedIn: true,
IsBasicAuth: true,
OAuth: false,
Provider: "basic",
Provider: "local",
TotpPending: false,
OAuthGroups: "",
TotpEnabled: true,

View File

@@ -5,7 +5,7 @@ import (
"strings"
"time"
"github.com/steveiliop56/tinyauth/internal/config"
"github.com/steveiliop56/tinyauth/internal/repository"
"github.com/steveiliop56/tinyauth/internal/service"
"github.com/steveiliop56/tinyauth/internal/utils"
"github.com/steveiliop56/tinyauth/internal/utils/tlog"
@@ -112,11 +112,11 @@ func (controller *UserController) loginHandler(c *gin.Context) {
if user.TotpSecret != "" {
tlog.App.Debug().Str("username", req.Username).Msg("User has TOTP enabled, requiring TOTP verification")
err := controller.auth.CreateSessionCookie(c, &config.SessionCookie{
err := controller.auth.CreateSessionCookie(c, &repository.Session{
Username: user.Username,
Name: utils.Capitalize(req.Username),
Email: fmt.Sprintf("%s@%s", strings.ToLower(req.Username), controller.config.CookieDomain),
Provider: "username",
Provider: "local",
TotpPending: true,
})
@@ -138,11 +138,15 @@ func (controller *UserController) loginHandler(c *gin.Context) {
}
}
sessionCookie := config.SessionCookie{
sessionCookie := repository.Session{
Username: req.Username,
Name: utils.Capitalize(req.Username),
Email: fmt.Sprintf("%s@%s", strings.ToLower(req.Username), controller.config.CookieDomain),
Provider: "username",
Provider: "local",
}
if userSearch.Type == "ldap" {
sessionCookie.Provider = "ldap"
}
tlog.App.Trace().Interface("session_cookie", sessionCookie).Msg("Creating session cookie")
@@ -248,11 +252,11 @@ func (controller *UserController) totpHandler(c *gin.Context) {
controller.auth.RecordLoginAttempt(context.Username, true)
sessionCookie := config.SessionCookie{
sessionCookie := repository.Session{
Username: user.Username,
Name: utils.Capitalize(user.Username),
Email: fmt.Sprintf("%s@%s", strings.ToLower(user.Username), controller.config.CookieDomain),
Provider: "username",
Provider: "local",
}
tlog.App.Trace().Interface("session_cookie", sessionCookie).Msg("Creating session cookie")

View File

@@ -204,7 +204,7 @@ func TestTotpHandler(t *testing.T) {
Email: "totpuser@example.com",
IsLoggedIn: false,
OAuth: false,
Provider: "username",
Provider: "local",
TotpPending: true,
OAuthGroups: "",
TotpEnabled: true,
@@ -267,7 +267,7 @@ func TestTotpHandler(t *testing.T) {
Email: "totpuser@example.com",
IsLoggedIn: false,
OAuth: false,
Provider: "username",
Provider: "local",
TotpPending: true,
OAuthGroups: "",
TotpEnabled: true,
@@ -290,7 +290,7 @@ func TestTotpHandler(t *testing.T) {
Email: "totpuser@example.com",
IsLoggedIn: false,
OAuth: false,
Provider: "username",
Provider: "local",
TotpPending: false,
OAuthGroups: "",
TotpEnabled: false,

View File

@@ -2,6 +2,7 @@ package middleware
import (
"fmt"
"slices"
"strings"
"time"
@@ -13,6 +14,8 @@ import (
"github.com/gin-gonic/gin"
)
var OIDCIgnorePaths = []string{"/api/oidc/token", "/api/oidc/userinfo"}
type ContextMiddlewareConfig struct {
CookieDomain string
}
@@ -37,6 +40,13 @@ func (m *ContextMiddleware) Init() error {
func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
return func(c *gin.Context) {
// There is no point in trying to get credentials if it's an OIDC endpoint
path := c.Request.URL.Path
if slices.Contains(OIDCIgnorePaths, path) {
c.Next()
return
}
cookie, err := m.auth.GetSessionCookie(c)
if err != nil {
@@ -49,7 +59,7 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
Username: cookie.Username,
Name: cookie.Name,
Email: cookie.Email,
Provider: "username",
Provider: "local",
TotpPending: true,
TotpEnabled: true,
})
@@ -58,22 +68,44 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
}
switch cookie.Provider {
case "username":
case "local", "ldap":
userSearch := m.auth.SearchUser(cookie.Username)
if userSearch.Type == "unknown" || userSearch.Type == "error" {
if userSearch.Type == "unknown" {
tlog.App.Debug().Msg("User from session cookie not found")
m.auth.DeleteSessionCookie(c)
goto basic
}
if userSearch.Type != cookie.Provider {
tlog.App.Warn().Msg("User type from session cookie does not match user search type")
m.auth.DeleteSessionCookie(c)
c.Next()
return
}
var ldapGroups []string
if cookie.Provider == "ldap" {
ldapUser, err := m.auth.GetLdapUser(userSearch.Username)
if err != nil {
tlog.App.Error().Err(err).Msg("Error retrieving LDAP user details")
c.Next()
return
}
ldapGroups = ldapUser.Groups
}
m.auth.RefreshSessionCookie(c)
c.Set("context", &config.UserContext{
Username: cookie.Username,
Name: cookie.Name,
Email: cookie.Email,
Provider: "username",
Provider: cookie.Provider,
IsLoggedIn: true,
LdapGroups: strings.Join(ldapGroups, ","),
})
c.Next()
return
@@ -155,20 +187,32 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
Username: user.Username,
Name: utils.Capitalize(user.Username),
Email: fmt.Sprintf("%s@%s", strings.ToLower(user.Username), m.config.CookieDomain),
Provider: "basic",
Provider: "local",
IsLoggedIn: true,
TotpEnabled: user.TotpSecret != "",
IsBasicAuth: true,
})
c.Next()
return
case "ldap":
tlog.App.Debug().Msg("Basic auth user is LDAP")
ldapUser, err := m.auth.GetLdapUser(basic.Username)
if err != nil {
tlog.App.Debug().Err(err).Msg("Error retrieving LDAP user details")
c.Next()
return
}
c.Set("context", &config.UserContext{
Username: basic.Username,
Name: utils.Capitalize(basic.Username),
Email: fmt.Sprintf("%s@%s", strings.ToLower(basic.Username), m.config.CookieDomain),
Provider: "basic",
IsLoggedIn: true,
Username: basic.Username,
Name: utils.Capitalize(basic.Username),
Email: fmt.Sprintf("%s@%s", strings.ToLower(basic.Username), m.config.CookieDomain),
Provider: "ldap",
IsLoggedIn: true,
LdapGroups: strings.Join(ldapUser.Groups, ","),
IsBasicAuth: true,
})
c.Next()
return

View File

@@ -4,6 +4,32 @@
package repository
type OidcCode struct {
Sub string
Code string
Scope string
RedirectURI string
ClientID string
ExpiresAt int64
}
type OidcToken struct {
Sub string
AccessToken string
Scope string
ClientID string
ExpiresAt int64
}
type OidcUserinfo struct {
Sub string
Name string
PreferredUsername string
Email string
Groups string
UpdatedAt int64
}
type Session struct {
UUID string
Username string

View File

@@ -0,0 +1,224 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.30.0
// source: oidc_queries.sql
package repository
import (
"context"
)
const createOidcCode = `-- name: CreateOidcCode :one
INSERT INTO "oidc_codes" (
"sub",
"code",
"scope",
"redirect_uri",
"client_id",
"expires_at"
) VALUES (
?, ?, ?, ?, ?, ?
)
RETURNING sub, code, scope, redirect_uri, client_id, expires_at
`
type CreateOidcCodeParams struct {
Sub string
Code string
Scope string
RedirectURI string
ClientID string
ExpiresAt int64
}
func (q *Queries) CreateOidcCode(ctx context.Context, arg CreateOidcCodeParams) (OidcCode, error) {
row := q.db.QueryRowContext(ctx, createOidcCode,
arg.Sub,
arg.Code,
arg.Scope,
arg.RedirectURI,
arg.ClientID,
arg.ExpiresAt,
)
var i OidcCode
err := row.Scan(
&i.Sub,
&i.Code,
&i.Scope,
&i.RedirectURI,
&i.ClientID,
&i.ExpiresAt,
)
return i, err
}
const createOidcToken = `-- name: CreateOidcToken :one
INSERT INTO "oidc_tokens" (
"sub",
"access_token",
"scope",
"client_id",
"expires_at"
) VALUES (
?, ?, ?, ?, ?
)
RETURNING sub, access_token, scope, client_id, expires_at
`
type CreateOidcTokenParams struct {
Sub string
AccessToken string
Scope string
ClientID string
ExpiresAt int64
}
func (q *Queries) CreateOidcToken(ctx context.Context, arg CreateOidcTokenParams) (OidcToken, error) {
row := q.db.QueryRowContext(ctx, createOidcToken,
arg.Sub,
arg.AccessToken,
arg.Scope,
arg.ClientID,
arg.ExpiresAt,
)
var i OidcToken
err := row.Scan(
&i.Sub,
&i.AccessToken,
&i.Scope,
&i.ClientID,
&i.ExpiresAt,
)
return i, err
}
const createOidcUserInfo = `-- name: CreateOidcUserInfo :one
INSERT INTO "oidc_userinfo" (
"sub",
"name",
"preferred_username",
"email",
"groups",
"updated_at"
) VALUES (
?, ?, ?, ?, ?, ?
)
RETURNING sub, name, preferred_username, email, "groups", updated_at
`
type CreateOidcUserInfoParams struct {
Sub string
Name string
PreferredUsername string
Email string
Groups string
UpdatedAt int64
}
func (q *Queries) CreateOidcUserInfo(ctx context.Context, arg CreateOidcUserInfoParams) (OidcUserinfo, error) {
row := q.db.QueryRowContext(ctx, createOidcUserInfo,
arg.Sub,
arg.Name,
arg.PreferredUsername,
arg.Email,
arg.Groups,
arg.UpdatedAt,
)
var i OidcUserinfo
err := row.Scan(
&i.Sub,
&i.Name,
&i.PreferredUsername,
&i.Email,
&i.Groups,
&i.UpdatedAt,
)
return i, err
}
const deleteOidcCode = `-- name: DeleteOidcCode :exec
DELETE FROM "oidc_codes"
WHERE "code" = ?
`
func (q *Queries) DeleteOidcCode(ctx context.Context, code string) error {
_, err := q.db.ExecContext(ctx, deleteOidcCode, code)
return err
}
const deleteOidcToken = `-- name: DeleteOidcToken :exec
DELETE FROM "oidc_tokens"
WHERE "access_token" = ?
`
func (q *Queries) DeleteOidcToken(ctx context.Context, accessToken string) error {
_, err := q.db.ExecContext(ctx, deleteOidcToken, accessToken)
return err
}
const deleteOidcUserInfo = `-- name: DeleteOidcUserInfo :exec
DELETE FROM "oidc_userinfo"
WHERE "sub" = ?
`
func (q *Queries) DeleteOidcUserInfo(ctx context.Context, sub string) error {
_, err := q.db.ExecContext(ctx, deleteOidcUserInfo, sub)
return err
}
const getOidcCode = `-- name: GetOidcCode :one
SELECT sub, code, scope, redirect_uri, client_id, expires_at FROM "oidc_codes"
WHERE "code" = ?
`
func (q *Queries) GetOidcCode(ctx context.Context, code string) (OidcCode, error) {
row := q.db.QueryRowContext(ctx, getOidcCode, code)
var i OidcCode
err := row.Scan(
&i.Sub,
&i.Code,
&i.Scope,
&i.RedirectURI,
&i.ClientID,
&i.ExpiresAt,
)
return i, err
}
const getOidcToken = `-- name: GetOidcToken :one
SELECT sub, access_token, scope, client_id, expires_at FROM "oidc_tokens"
WHERE "access_token" = ?
`
func (q *Queries) GetOidcToken(ctx context.Context, accessToken string) (OidcToken, error) {
row := q.db.QueryRowContext(ctx, getOidcToken, accessToken)
var i OidcToken
err := row.Scan(
&i.Sub,
&i.AccessToken,
&i.Scope,
&i.ClientID,
&i.ExpiresAt,
)
return i, err
}
const getOidcUserInfo = `-- name: GetOidcUserInfo :one
SELECT sub, name, preferred_username, email, "groups", updated_at FROM "oidc_userinfo"
WHERE "sub" = ?
`
func (q *Queries) GetOidcUserInfo(ctx context.Context, sub string) (OidcUserinfo, error) {
row := q.db.QueryRowContext(ctx, getOidcUserInfo, sub)
var i OidcUserinfo
err := row.Scan(
&i.Sub,
&i.Name,
&i.PreferredUsername,
&i.Email,
&i.Groups,
&i.UpdatedAt,
)
return i, err
}

View File

@@ -1,7 +1,7 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.30.0
// source: queries.sql
// source: session_queries.sql
package repository
@@ -10,7 +10,7 @@ import (
)
const createSession = `-- name: CreateSession :one
INSERT INTO sessions (
INSERT INTO "sessions" (
"uuid",
"username",
"email",

View File

@@ -19,6 +19,11 @@ import (
"golang.org/x/crypto/bcrypt"
)
type LdapGroupsCache struct {
Groups []string
Expires time.Time
}
type LoginAttempt struct {
FailedAttempts int
LastAttempt time.Time
@@ -36,24 +41,28 @@ type AuthServiceConfig struct {
LoginMaxRetries int
SessionCookieName string
IP config.IPConfig
LDAPGroupsCacheTTL int
}
type AuthService struct {
config AuthServiceConfig
docker *DockerService
loginAttempts map[string]*LoginAttempt
loginMutex sync.RWMutex
ldap *LdapService
queries *repository.Queries
config AuthServiceConfig
docker *DockerService
loginAttempts map[string]*LoginAttempt
ldapGroupsCache map[string]*LdapGroupsCache
loginMutex sync.RWMutex
ldapGroupsMutex sync.RWMutex
ldap *LdapService
queries *repository.Queries
}
func NewAuthService(config AuthServiceConfig, docker *DockerService, ldap *LdapService, queries *repository.Queries) *AuthService {
return &AuthService{
config: config,
docker: docker,
loginAttempts: make(map[string]*LoginAttempt),
ldap: ldap,
queries: queries,
config: config,
docker: docker,
loginAttempts: make(map[string]*LoginAttempt),
ldapGroupsCache: make(map[string]*LdapGroupsCache),
ldap: ldap,
queries: queries,
}
}
@@ -70,12 +79,12 @@ func (auth *AuthService) SearchUser(username string) config.UserSearch {
}
if auth.ldap != nil {
userDN, err := auth.ldap.Search(username)
userDN, err := auth.ldap.GetUserDN(username)
if err != nil {
tlog.App.Warn().Err(err).Str("username", username).Msg("Failed to search for user in LDAP")
return config.UserSearch{
Type: "error",
Type: "unknown",
}
}
@@ -131,6 +140,41 @@ func (auth *AuthService) GetLocalUser(username string) config.User {
return config.User{}
}
func (auth *AuthService) GetLdapUser(userDN string) (config.LdapUser, error) {
if auth.ldap == nil {
return config.LdapUser{}, errors.New("LDAP service not initialized")
}
auth.ldapGroupsMutex.RLock()
entry, exists := auth.ldapGroupsCache[userDN]
auth.ldapGroupsMutex.RUnlock()
if exists && time.Now().Before(entry.Expires) {
return config.LdapUser{
DN: userDN,
Groups: entry.Groups,
}, nil
}
groups, err := auth.ldap.GetUserGroups(userDN)
if err != nil {
return config.LdapUser{}, err
}
auth.ldapGroupsMutex.Lock()
auth.ldapGroupsCache[userDN] = &LdapGroupsCache{
Groups: groups,
Expires: time.Now().Add(time.Duration(auth.config.LDAPGroupsCacheTTL) * time.Second),
}
auth.ldapGroupsMutex.Unlock()
return config.LdapUser{
DN: userDN,
Groups: groups,
}, nil
}
func (auth *AuthService) CheckPassword(user config.User, password string) bool {
return bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)) == nil
}
@@ -190,7 +234,7 @@ func (auth *AuthService) IsEmailWhitelisted(email string) bool {
return utils.CheckFilter(strings.Join(auth.config.OauthWhitelist, ","), email)
}
func (auth *AuthService) CreateSessionCookie(c *gin.Context, data *config.SessionCookie) error {
func (auth *AuthService) CreateSessionCookie(c *gin.Context, data *repository.Session) error {
uuid, err := uuid.NewRandom()
if err != nil {
@@ -300,20 +344,20 @@ func (auth *AuthService) DeleteSessionCookie(c *gin.Context) error {
return nil
}
func (auth *AuthService) GetSessionCookie(c *gin.Context) (config.SessionCookie, error) {
func (auth *AuthService) GetSessionCookie(c *gin.Context) (repository.Session, error) {
cookie, err := c.Cookie(auth.config.SessionCookieName)
if err != nil {
return config.SessionCookie{}, err
return repository.Session{}, err
}
session, err := auth.queries.GetSession(c, cookie)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return config.SessionCookie{}, fmt.Errorf("session not found")
return repository.Session{}, fmt.Errorf("session not found")
}
return config.SessionCookie{}, err
return repository.Session{}, err
}
currentTime := time.Now().Unix()
@@ -324,7 +368,7 @@ func (auth *AuthService) GetSessionCookie(c *gin.Context) (config.SessionCookie,
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to delete session exceeding max lifetime")
}
return config.SessionCookie{}, fmt.Errorf("session expired due to max lifetime exceeded")
return repository.Session{}, fmt.Errorf("session expired due to max lifetime exceeded")
}
}
@@ -333,10 +377,10 @@ func (auth *AuthService) GetSessionCookie(c *gin.Context) (config.SessionCookie,
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to delete expired session")
}
return config.SessionCookie{}, fmt.Errorf("session expired")
return repository.Session{}, fmt.Errorf("session expired")
}
return config.SessionCookie{
return repository.Session{
UUID: session.UUID,
Username: session.Username,
Email: session.Email,
@@ -349,8 +393,12 @@ func (auth *AuthService) GetSessionCookie(c *gin.Context) (config.SessionCookie,
}, nil
}
func (auth *AuthService) UserAuthConfigured() bool {
return len(auth.config.Users) > 0 || auth.ldap != nil
func (auth *AuthService) LocalAuthConfigured() bool {
return len(auth.config.Users) > 0
}
func (auth *AuthService) LdapAuthConfigured() bool {
return auth.ldap != nil
}
func (auth *AuthService) IsUserAllowed(c *gin.Context, context config.UserContext, acls config.App) bool {
@@ -393,6 +441,22 @@ func (auth *AuthService) IsInOAuthGroup(c *gin.Context, context config.UserConte
return false
}
func (auth *AuthService) IsInLdapGroup(c *gin.Context, context config.UserContext, requiredGroups string) bool {
if requiredGroups == "" {
return true
}
for userGroup := range strings.SplitSeq(context.LdapGroups, ",") {
if utils.CheckFilter(requiredGroups, strings.TrimSpace(userGroup)) {
tlog.App.Trace().Str("group", userGroup).Str("required", requiredGroups).Msg("User group matched")
return true
}
}
tlog.App.Debug().Msg("No groups matched")
return false
}
func (auth *AuthService) IsAuthEnabled(uri string, path config.AppPath) (bool, error) {
// Check for block list
if path.Block != "" {

View File

@@ -116,7 +116,7 @@ func (ldap *LdapService) connect() (*ldapgo.Conn, error) {
return ldap.conn, nil
}
func (ldap *LdapService) Search(username string) (string, error) {
func (ldap *LdapService) GetUserDN(username string) (string, error) {
// Escape the username to prevent LDAP injection
escapedUsername := ldapgo.EscapeFilter(username)
filter := fmt.Sprintf(ldap.config.SearchFilter, escapedUsername)
@@ -145,6 +145,48 @@ func (ldap *LdapService) Search(username string) (string, error) {
return userDN, nil
}
func (ldap *LdapService) GetUserGroups(userDN string) ([]string, error) {
escapedUserDN := ldapgo.EscapeFilter(userDN)
searchRequest := ldapgo.NewSearchRequest(
ldap.config.BaseDN,
ldapgo.ScopeWholeSubtree, ldapgo.NeverDerefAliases, 0, 0, false,
fmt.Sprintf("(&(objectclass=groupOfUniqueNames)(uniquemember=%s))", escapedUserDN),
[]string{"dn"},
nil,
)
ldap.mutex.Lock()
defer ldap.mutex.Unlock()
searchResult, err := ldap.conn.Search(searchRequest)
if err != nil {
return []string{}, err
}
groupDNs := []string{}
for _, entry := range searchResult.Entries {
groupDNs = append(groupDNs, entry.DN)
}
groups := []string{}
// I guess it should work for most ldap providers
for _, dn := range groupDNs {
rdnParts, err := ldapgo.ParseDN(dn)
if err != nil {
return []string{}, err
}
if len(rdnParts.RDNs) == 0 || len(rdnParts.RDNs[0].Attributes) == 0 {
return []string{}, fmt.Errorf("invalid DN format: %s", dn)
}
groups = append(groups, rdnParts.RDNs[0].Attributes[0].Value)
}
return groups, nil
}
func (ldap *LdapService) BindService(rebind bool) error {
// Locks must not be used for initial binding attempt
if rebind {

View File

@@ -2,6 +2,7 @@ package utils
import (
"errors"
"fmt"
"net"
"net/url"
"strings"
@@ -95,7 +96,7 @@ func IsRedirectSafe(redirectURL string, domain string) bool {
hostname := parsed.Hostname()
if strings.HasSuffix(hostname, domain) {
if strings.HasSuffix(hostname, fmt.Sprintf(".%s", domain)) {
return true
}

View File

@@ -205,4 +205,9 @@ func TestIsRedirectSafe(t *testing.T) {
redirectURL = "http://example.org/page"
result = utils.IsRedirectSafe(redirectURL, domain)
assert.Equal(t, false, result)
// Case with malicious domain
redirectURL = "https://malicious-example.com/yoyo"
result = utils.IsRedirectSafe(redirectURL, domain)
assert.Equal(t, false, result)
}

View File

@@ -1,8 +1,11 @@
package utils
import (
"crypto/rand"
"encoding/base64"
"errors"
"math"
"math/big"
"net"
"regexp"
"strings"
@@ -105,3 +108,28 @@ func GenerateUUID(str string) string {
uuid := uuid.NewSHA1(uuid.NameSpaceURL, []byte(str))
return uuid.String()
}
// These could definitely be improved A LOT but at least they are cryptographically secure
func GetRandomString(length int) (string, error) {
if length < 1 {
return "", errors.New("length must be greater than 0")
}
b := make([]byte, length)
_, err := rand.Read(b)
if err != nil {
return "", err
}
state := base64.RawURLEncoding.EncodeToString(b)
return state[:length], nil
}
func GetRandomInt(length int) (int64, error) {
if length < 1 {
return 0, errors.New("length must be greater than 0")
}
a, err := rand.Int(rand.Reader, big.NewInt(int64(math.Pow(10, float64(length)))))
if err != nil {
return 0, err
}
return a.Int64(), nil
}

View File

@@ -2,6 +2,7 @@ package utils_test
import (
"os"
"strconv"
"testing"
"github.com/steveiliop56/tinyauth/internal/utils"
@@ -147,3 +148,25 @@ func TestGenerateUUID(t *testing.T) {
id3 := utils.GenerateUUID("differentstring")
assert.Assert(t, id1 != id3)
}
func TestGetRandomString(t *testing.T) {
// Test with normal length
state, err := utils.GetRandomString(16)
assert.NilError(t, err)
assert.Equal(t, 16, len(state))
// Test with zero length
state, err = utils.GetRandomString(0)
assert.Error(t, err, "length must be greater than 0")
}
func TestGetRandomInt(t *testing.T) {
// Test with normal length
state, err := utils.GetRandomInt(16)
assert.NilError(t, err)
assert.Equal(t, 16, len(strconv.Itoa(int(state))))
// Test with zero length
state, err = utils.GetRandomInt(0)
assert.Error(t, err, "length must be greater than 0")
}

61
sql/oidc_queries.sql Normal file
View File

@@ -0,0 +1,61 @@
-- name: CreateOidcCode :one
INSERT INTO "oidc_codes" (
"sub",
"code",
"scope",
"redirect_uri",
"client_id",
"expires_at"
) VALUES (
?, ?, ?, ?, ?, ?
)
RETURNING *;
-- name: DeleteOidcCode :exec
DELETE FROM "oidc_codes"
WHERE "code" = ?;
-- name: GetOidcCode :one
SELECT * FROM "oidc_codes"
WHERE "code" = ?;
-- name: CreateOidcToken :one
INSERT INTO "oidc_tokens" (
"sub",
"access_token",
"scope",
"client_id",
"expires_at"
) VALUES (
?, ?, ?, ?, ?
)
RETURNING *;
-- name: DeleteOidcToken :exec
DELETE FROM "oidc_tokens"
WHERE "access_token" = ?;
-- name: GetOidcToken :one
SELECT * FROM "oidc_tokens"
WHERE "access_token" = ?;
-- name: CreateOidcUserInfo :one
INSERT INTO "oidc_userinfo" (
"sub",
"name",
"preferred_username",
"email",
"groups",
"updated_at"
) VALUES (
?, ?, ?, ?, ?, ?
)
RETURNING *;
-- name: DeleteOidcUserInfo :exec
DELETE FROM "oidc_userinfo"
WHERE "sub" = ?;
-- name: GetOidcUserInfo :one
SELECT * FROM "oidc_userinfo"
WHERE "sub" = ?;

25
sql/oidc_schemas.sql Normal file
View File

@@ -0,0 +1,25 @@
CREATE TABLE IF NOT EXISTS "oidc_codes" (
"sub" TEXT NOT NULL UNIQUE,
"code" TEXT NOT NULL PRIMARY KEY UNIQUE,
"scope" TEXT NOT NULL,
"redirect_uri" TEXT NOT NULL,
"client_id" TEXT NOT NULL,
"expires_at" INTEGER NOT NULL
);
CREATE TABLE IF NOT EXISTS "oidc_tokens" (
"sub" TEXT NOT NULL UNIQUE,
"access_token" TEXT NOT NULL PRIMARY KEY UNIQUE,
"scope" TEXT NOT NULL,
"client_id" TEXT NOT NULL,
"expires_at" INTEGER NOT NULL
);
CREATE TABLE IF NOT EXISTS "oidc_userinfo" (
"sub" TEXT NOT NULL UNIQUE PRIMARY KEY,
"name" TEXT NOT NULL,
"preferred_username" TEXT NOT NULL,
"email" TEXT NOT NULL,
"groups" TEXT NOT NULL,
"updated_at" INTEGER NOT NULL
);

View File

@@ -1,5 +1,5 @@
-- name: CreateSession :one
INSERT INTO sessions (
INSERT INTO "sessions" (
"uuid",
"username",
"email",

View File

@@ -1,8 +1,8 @@
version: "2"
sql:
- engine: "sqlite"
queries: "sql/queries.sql"
schema: "sql/schema.sql"
queries: "sql/*_queries.sql"
schema: "sql/*_schemas.sql"
gen:
go:
package: "repository"
@@ -12,6 +12,7 @@ sql:
oauth_groups: "OAuthGroups"
oauth_name: "OAuthName"
oauth_sub: "OAuthSub"
redirect_uri: "RedirectURI"
overrides:
- column: "sessions.oauth_groups"
go_type: "string"
@@ -19,3 +20,5 @@ sql:
go_type: "string"
- column: "sessions.oauth_sub"
go_type: "string"
- column: "sessions.ldap_groups"
go_type: "string"