Compare commits

..

1 Commits

Author SHA1 Message Date
dependabot[bot] ac7a6ca804 chore(deps): bump the minor-patch group across 1 directory with 6 updates
Bumps the minor-patch group with 5 updates in the / directory:

| Package | From | To |
| --- | --- | --- |
| [golang.org/x/crypto](https://github.com/golang/crypto) | `0.50.0` | `0.51.0` |
| [k8s.io/apimachinery](https://github.com/kubernetes/apimachinery) | `0.36.0` | `0.36.1` |
| [k8s.io/client-go](https://github.com/kubernetes/client-go) | `0.36.0` | `0.36.1` |
| [modernc.org/sqlite](https://gitlab.com/cznic/sqlite) | `1.50.0` | `1.50.1` |
| [tailscale.com](https://github.com/tailscale/tailscale) | `1.96.5` | `1.98.2` |



Updates `golang.org/x/crypto` from 0.50.0 to 0.51.0
- [Commits](https://github.com/golang/crypto/compare/v0.50.0...v0.51.0)

Updates `golang.org/x/tools` from 0.43.0 to 0.44.0
- [Release notes](https://github.com/golang/tools/releases)
- [Commits](https://github.com/golang/tools/compare/v0.43.0...v0.44.0)

Updates `k8s.io/apimachinery` from 0.36.0 to 0.36.1
- [Commits](https://github.com/kubernetes/apimachinery/compare/v0.36.0...v0.36.1)

Updates `k8s.io/client-go` from 0.36.0 to 0.36.1
- [Changelog](https://github.com/kubernetes/client-go/blob/master/CHANGELOG.md)
- [Commits](https://github.com/kubernetes/client-go/compare/v0.36.0...v0.36.1)

Updates `modernc.org/sqlite` from 1.50.0 to 1.50.1
- [Changelog](https://gitlab.com/cznic/sqlite/blob/master/CHANGELOG.md)
- [Commits](https://gitlab.com/cznic/sqlite/compare/v1.50.0...v1.50.1)

Updates `tailscale.com` from 1.96.5 to 1.98.2
- [Release notes](https://github.com/tailscale/tailscale/releases)
- [Commits](https://github.com/tailscale/tailscale/compare/v1.96.5...v1.98.2)

---
updated-dependencies:
- dependency-name: golang.org/x/crypto
  dependency-version: 0.51.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: minor-patch
- dependency-name: golang.org/x/tools
  dependency-version: 0.44.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: minor-patch
- dependency-name: k8s.io/apimachinery
  dependency-version: 0.36.1
  dependency-type: direct:production
  update-type: version-update:semver-patch
  dependency-group: minor-patch
- dependency-name: k8s.io/client-go
  dependency-version: 0.36.1
  dependency-type: direct:production
  update-type: version-update:semver-patch
  dependency-group: minor-patch
- dependency-name: modernc.org/sqlite
  dependency-version: 1.50.1
  dependency-type: direct:production
  update-type: version-update:semver-patch
  dependency-group: minor-patch
- dependency-name: tailscale.com
  dependency-version: 1.98.2
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: minor-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-05-21 08:35:28 +00:00
76 changed files with 2205 additions and 3057 deletions
+1 -70
View File
@@ -7,9 +7,7 @@ TINYAUTH_APPURL=
# database config
# The database driver to use. Valid values: sqlite, postgres, memory.
TINYAUTH_DATABASE_DRIVER="sqlite"
# The path to the SQLite database file, or connection URL when driver is postgres.
# The path to the database, including file name.
TINYAUTH_DATABASE_PATH="./tinyauth.db"
# analytics config
@@ -32,8 +30,6 @@ TINYAUTH_SERVER_PORT=3000
TINYAUTH_SERVER_ADDRESS="0.0.0.0"
# The path to the Unix socket.
TINYAUTH_SERVER_SOCKETPATH=
# Enable listening on both TCP and Unix socket at the same time.
TINYAUTH_SERVER_CONCURRENTLISTENERSENABLED=false
# auth config
@@ -41,52 +37,8 @@ TINYAUTH_SERVER_CONCURRENTLISTENERSENABLED=false
TINYAUTH_AUTH_IP_ALLOW=
# List of blocked IPs or CIDR ranges.
TINYAUTH_AUTH_IP_BLOCK=
# List of IPs or CIDR ranges that bypass authentication entirely.
TINYAUTH_AUTH_IP_BYPASS=
# Comma-separated list of users (username:hashed_password).
TINYAUTH_AUTH_USERS=
# Enable subdomains support.
TINYAUTH_AUTH_SUBDOMAINSENABLED=true
# Full name of the user.
TINYAUTH_AUTH_USERATTRIBUTES_name_NAME=
# Given (first) name of the user.
TINYAUTH_AUTH_USERATTRIBUTES_name_GIVENNAME=
# Family (last) name of the user.
TINYAUTH_AUTH_USERATTRIBUTES_name_FAMILYNAME=
# Middle name of the user.
TINYAUTH_AUTH_USERATTRIBUTES_name_MIDDLENAME=
# Nickname of the user.
TINYAUTH_AUTH_USERATTRIBUTES_name_NICKNAME=
# URL of the user's profile page.
TINYAUTH_AUTH_USERATTRIBUTES_name_PROFILE=
# URL of the user's profile picture.
TINYAUTH_AUTH_USERATTRIBUTES_name_PICTURE=
# URL of the user's website.
TINYAUTH_AUTH_USERATTRIBUTES_name_WEBSITE=
# Email address of the user.
TINYAUTH_AUTH_USERATTRIBUTES_name_EMAIL=
# Gender of the user.
TINYAUTH_AUTH_USERATTRIBUTES_name_GENDER=
# Birthdate of the user (YYYY-MM-DD).
TINYAUTH_AUTH_USERATTRIBUTES_name_BIRTHDATE=
# Time zone of the user (e.g. Europe/Athens).
TINYAUTH_AUTH_USERATTRIBUTES_name_ZONEINFO=
# Locale of the user (e.g. en-US).
TINYAUTH_AUTH_USERATTRIBUTES_name_LOCALE=
# Phone number of the user.
TINYAUTH_AUTH_USERATTRIBUTES_name_PHONENUMBER=
# Full mailing address, formatted for display.
TINYAUTH_AUTH_USERATTRIBUTES_name_ADDRESS_FORMATTED=
# Street address.
TINYAUTH_AUTH_USERATTRIBUTES_name_ADDRESS_STREETADDRESS=
# City or locality.
TINYAUTH_AUTH_USERATTRIBUTES_name_ADDRESS_LOCALITY=
# State, province, or region.
TINYAUTH_AUTH_USERATTRIBUTES_name_ADDRESS_REGION=
# Zip or postal code.
TINYAUTH_AUTH_USERATTRIBUTES_name_ADDRESS_POSTALCODE=
# Country.
TINYAUTH_AUTH_USERATTRIBUTES_name_ADDRESS_COUNTRY=
# Path to the users file.
TINYAUTH_AUTH_USERSFILE=
# Enable secure cookies.
@@ -101,8 +53,6 @@ TINYAUTH_AUTH_LOGINTIMEOUT=300
TINYAUTH_AUTH_LOGINMAXRETRIES=3
# Comma-separated list of trusted proxy addresses.
TINYAUTH_AUTH_TRUSTEDPROXIES=
# ACL policy for allow-by-default or deny-by-default, available options are allow and deny, default is allow.
TINYAUTH_AUTH_ACLS_POLICY="allow"
# apps config
@@ -151,10 +101,6 @@ TINYAUTH_OAUTH_PROVIDERS_name_CLIENTID=
TINYAUTH_OAUTH_PROVIDERS_name_CLIENTSECRET=
# Path to the file containing the OAuth client secret.
TINYAUTH_OAUTH_PROVIDERS_name_CLIENTSECRETFILE=
# Comma-separated list of allowed OAuth domains for this provider.
TINYAUTH_OAUTH_PROVIDERS_name_WHITELIST=
# Path to the OAuth whitelist file for this provider.
TINYAUTH_OAUTH_PROVIDERS_name_WHITELISTFILE=
# OAuth scopes.
TINYAUTH_OAUTH_PROVIDERS_name_SCOPES=
# OAuth redirect URL.
@@ -218,8 +164,6 @@ TINYAUTH_LDAP_AUTHCERT=
TINYAUTH_LDAP_AUTHKEY=
# Cache duration for LDAP group membership in seconds.
TINYAUTH_LDAP_GROUPCACHETTL=900
# Label provider to use for ACLs (auto, docker, kubernetes or none to disable). auto detects the environment.
TINYAUTH_LABELPROVIDER="auto"
# log config
@@ -239,16 +183,3 @@ TINYAUTH_LOG_STREAMS_APP_LEVEL=
TINYAUTH_LOG_STREAMS_AUDIT_ENABLED=false
# Log level for this stream. Use global if empty.
TINYAUTH_LOG_STREAMS_AUDIT_LEVEL=
# tailscale config
# Enable Tailscale integration.
TINYAUTH_TAILSCALE_ENABLED=false
# Tailscale state directory.
TINYAUTH_TAILSCALE_DIR="./tailscale_state"
# Tailscale hostname.
TINYAUTH_TAILSCALE_HOSTNAME=
# Tailscale auth key.
TINYAUTH_TAILSCALE_AUTHKEY=
# Use ephemeral Tailscale node.
TINYAUTH_TAILSCALE_EPHEMERAL=false
+2 -2
View File
@@ -29,7 +29,7 @@ jobs:
run: go mod download
- name: Setup sqlc
uses: sqlc-dev/setup-sqlc@v5
uses: sqlc-dev/setup-sqlc@v4
with:
sqlc-version: "1.31.1"
@@ -62,6 +62,6 @@ jobs:
run: go test -coverprofile=coverage.txt -v ./...
- name: Upload coverage reports to Codecov
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 # v6
uses: codecov/codecov-action@57e3a136b779b570ffcdbf80b3bdc90e7fab3de2 # v6
with:
token: ${{ secrets.CODECOV_TOKEN }}
+6 -6
View File
@@ -83,7 +83,7 @@ jobs:
- name: Build
run: |
cp -r frontend/dist internal/assets/dist
go build -ldflags "-X github.com/tinyauthapp/tinyauth/internal/model.Version=${{ needs.generate-metadata.outputs.VERSION }} -X github.com/tinyauthapp/tinyauth/internal/model.CommitHash=${{ needs.generate-metadata.outputs.COMMIT_HASH }} -X github.com/tinyauthapp/tinyauth/internal/model.BuildTimestamp=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}" -o tinyauth-amd64 ./cmd/tinyauth
go build -ldflags "-s -w -X github.com/tinyauthapp/tinyauth/internal/model.Version=${{ needs.generate-metadata.outputs.VERSION }} -X github.com/tinyauthapp/tinyauth/internal/model.CommitHash=${{ needs.generate-metadata.outputs.COMMIT_HASH }} -X github.com/tinyauthapp/tinyauth/internal/model.BuildTimestamp=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}" -o tinyauth-amd64 ./cmd/tinyauth
env:
CGO_ENABLED: 0
@@ -128,7 +128,7 @@ jobs:
- name: Build
run: |
cp -r frontend/dist internal/assets/dist
go build -ldflags "-X github.com/tinyauthapp/tinyauth/internal/model.Version=${{ needs.generate-metadata.outputs.VERSION }} -X github.com/tinyauthapp/tinyauth/internal/model.CommitHash=${{ needs.generate-metadata.outputs.COMMIT_HASH }} -X github.com/tinyauthapp/tinyauth/internal/model.BuildTimestamp=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}" -o tinyauth-arm64 ./cmd/tinyauth
go build -ldflags "-s -w -X github.com/tinyauthapp/tinyauth/internal/model.Version=${{ needs.generate-metadata.outputs.VERSION }} -X github.com/tinyauthapp/tinyauth/internal/model.CommitHash=${{ needs.generate-metadata.outputs.COMMIT_HASH }} -X github.com/tinyauthapp/tinyauth/internal/model.BuildTimestamp=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}" -o tinyauth-arm64 ./cmd/tinyauth
env:
CGO_ENABLED: 0
@@ -166,7 +166,7 @@ jobs:
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4
- name: Build and push
uses: docker/build-push-action@f9f3042f7e2789586610d6e8b85c8f03e5195baf # v7
uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # v7
id: build
with:
platforms: linux/amd64
@@ -224,7 +224,7 @@ jobs:
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4
- name: Build and push
uses: docker/build-push-action@f9f3042f7e2789586610d6e8b85c8f03e5195baf # v7
uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # v7
id: build
with:
platforms: linux/amd64
@@ -282,7 +282,7 @@ jobs:
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4
- name: Build and push
uses: docker/build-push-action@f9f3042f7e2789586610d6e8b85c8f03e5195baf # v7
uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # v7
id: build
with:
platforms: linux/arm64
@@ -340,7 +340,7 @@ jobs:
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4
- name: Build and push
uses: docker/build-push-action@f9f3042f7e2789586610d6e8b85c8f03e5195baf # v7
uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # v7
id: build
with:
platforms: linux/arm64
+4 -8
View File
@@ -136,7 +136,7 @@ jobs:
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4
- name: Build and push
uses: docker/build-push-action@f9f3042f7e2789586610d6e8b85c8f03e5195baf # v7
uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # v7
id: build
with:
platforms: linux/amd64
@@ -150,7 +150,6 @@ jobs:
VERSION=${{ needs.generate-metadata.outputs.VERSION }}
COMMIT_HASH=${{ needs.generate-metadata.outputs.COMMIT_HASH }}
BUILD_TIMESTAMP=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}
LDFLAGS="-s -w"
- name: Export digest
run: |
@@ -192,7 +191,7 @@ jobs:
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4
- name: Build and push
uses: docker/build-push-action@f9f3042f7e2789586610d6e8b85c8f03e5195baf # v7
uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # v7
id: build
with:
platforms: linux/amd64
@@ -207,7 +206,6 @@ jobs:
VERSION=${{ needs.generate-metadata.outputs.VERSION }}
COMMIT_HASH=${{ needs.generate-metadata.outputs.COMMIT_HASH }}
BUILD_TIMESTAMP=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}
LDFLAGS="-s -w"
- name: Export digest
run: |
@@ -248,7 +246,7 @@ jobs:
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4
- name: Build and push
uses: docker/build-push-action@f9f3042f7e2789586610d6e8b85c8f03e5195baf # v7
uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # v7
id: build
with:
platforms: linux/arm64
@@ -262,7 +260,6 @@ jobs:
VERSION=${{ needs.generate-metadata.outputs.VERSION }}
COMMIT_HASH=${{ needs.generate-metadata.outputs.COMMIT_HASH }}
BUILD_TIMESTAMP=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}
LDFLAGS="-s -w"
- name: Export digest
run: |
@@ -304,7 +301,7 @@ jobs:
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4
- name: Build and push
uses: docker/build-push-action@f9f3042f7e2789586610d6e8b85c8f03e5195baf # v7
uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # v7
id: build
with:
platforms: linux/arm64
@@ -319,7 +316,6 @@ jobs:
VERSION=${{ needs.generate-metadata.outputs.VERSION }}
COMMIT_HASH=${{ needs.generate-metadata.outputs.COMMIT_HASH }}
BUILD_TIMESTAMP=${{ needs.generate-metadata.outputs.BUILD_TIMESTAMP }}
LDFLAGS="-s -w"
- name: Export digest
run: |
+1 -1
View File
@@ -38,6 +38,6 @@ jobs:
retention-days: 5
- name: Upload to code-scanning
uses: github/codeql-action/upload-sarif@9e0d7b8d25671d64c341c19c0152d693099fb5ba # v4
uses: github/codeql-action/upload-sarif@68bde559dea0fdcac2102bfdf6230c5f70eb485e # v4
with:
sarif_file: results.sarif
+2 -3
View File
@@ -1,5 +1,5 @@
# Site builder
FROM node:26.2-alpine3.23 AS frontend-builder
FROM node:26.1-alpine3.23 AS frontend-builder
WORKDIR /frontend
@@ -27,7 +27,6 @@ FROM golang:1.26-alpine3.23 AS builder
ARG VERSION
ARG COMMIT_HASH
ARG BUILD_TIMESTAMP
ARG LDFLAGS
WORKDIR /tinyauth
@@ -40,7 +39,7 @@ COPY ./cmd ./cmd
COPY ./internal ./internal
COPY --from=frontend-builder /frontend/dist ./internal/assets/dist
RUN CGO_ENABLED=0 go build -ldflags "${LDFLAGS} \
RUN CGO_ENABLED=0 go build -ldflags "-s -w \
-X github.com/tinyauthapp/tinyauth/internal/model.Version=${VERSION} \
-X github.com/tinyauthapp/tinyauth/internal/model.CommitHash=${COMMIT_HASH} \
-X github.com/tinyauthapp/tinyauth/internal/model.BuildTimestamp=${BUILD_TIMESTAMP}" ./cmd/tinyauth
+2 -3
View File
@@ -1,5 +1,5 @@
# Site builder
FROM node:26.2-alpine3.23 AS frontend-builder
FROM node:26.1-alpine3.23 AS frontend-builder
WORKDIR /frontend
@@ -27,7 +27,6 @@ FROM golang:1.26-alpine3.23 AS builder
ARG VERSION
ARG COMMIT_HASH
ARG BUILD_TIMESTAMP
ARG LDFLAGS
WORKDIR /tinyauth
@@ -42,7 +41,7 @@ COPY --from=frontend-builder /frontend/dist ./internal/assets/dist
RUN mkdir -p data
RUN CGO_ENABLED=0 go build -ldflags "${LDFLAGS} \
RUN CGO_ENABLED=0 go build -ldflags "-s -w \
-X github.com/tinyauthapp/tinyauth/internal/model.Version=${VERSION} \
-X github.com/tinyauthapp/tinyauth/internal/model.CommitHash=${COMMIT_HASH} \
-X github.com/tinyauthapp/tinyauth/internal/model.BuildTimestamp=${BUILD_TIMESTAMP}" ./cmd/tinyauth
+77 -64
View File
@@ -1,5 +1,5 @@
GNU AFFERO GENERAL PUBLIC LICENSE
Version 3, 19 November 2007
GNU GENERAL PUBLIC LICENSE
Version 3, 29 June 2007
Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
Everyone is permitted to copy and distribute verbatim copies
@@ -7,15 +7,17 @@
Preamble
The GNU Affero General Public License is a free, copyleft license for
software and other kinds of works, specifically designed to ensure
cooperation with the community in the case of network server software.
The GNU General Public License is a free, copyleft license for
software and other kinds of works.
The licenses for most software and other practical works are designed
to take away your freedom to share and change the works. By contrast,
our General Public Licenses are intended to guarantee your freedom to
the GNU General Public License is intended to guarantee your freedom to
share and change all versions of a program--to make sure it remains free
software for all its users.
software for all its users. We, the Free Software Foundation, use the
GNU General Public License for most of our software; it applies also to
any other work released this way by its authors. You can apply it to
your programs, too.
When we speak of free software, we are referring to freedom, not
price. Our General Public Licenses are designed to make sure that you
@@ -24,34 +26,44 @@ them if you wish), that you receive source code or can get it if you
want it, that you can change the software or use pieces of it in new
free programs, and that you know you can do these things.
Developers that use our General Public Licenses protect your rights
with two steps: (1) assert copyright on the software, and (2) offer
you this License which gives you legal permission to copy, distribute
and/or modify the software.
To protect your rights, we need to prevent others from denying you
these rights or asking you to surrender the rights. Therefore, you have
certain responsibilities if you distribute copies of the software, or if
you modify it: responsibilities to respect the freedom of others.
A secondary benefit of defending all users' freedom is that
improvements made in alternate versions of the program, if they
receive widespread use, become available for other developers to
incorporate. Many developers of free software are heartened and
encouraged by the resulting cooperation. However, in the case of
software used on network servers, this result may fail to come about.
The GNU General Public License permits making a modified version and
letting the public access it on a server without ever releasing its
source code to the public.
For example, if you distribute copies of such a program, whether
gratis or for a fee, you must pass on to the recipients the same
freedoms that you received. You must make sure that they, too, receive
or can get the source code. And you must show them these terms so they
know their rights.
The GNU Affero General Public License is designed specifically to
ensure that, in such cases, the modified source code becomes available
to the community. It requires the operator of a network server to
provide the source code of the modified version running there to the
users of that server. Therefore, public use of a modified version, on
a publicly accessible server, gives the public access to the source
code of the modified version.
Developers that use the GNU GPL protect your rights with two steps:
(1) assert copyright on the software, and (2) offer you this License
giving you legal permission to copy, distribute and/or modify it.
An older license, called the Affero General Public License and
published by Affero, was designed to accomplish similar goals. This is
a different license, not a version of the Affero GPL, but Affero has
released a new version of the Affero GPL which permits relicensing under
this license.
For the developers' and authors' protection, the GPL clearly explains
that there is no warranty for this free software. For both users' and
authors' sake, the GPL requires that modified versions be marked as
changed, so that their problems will not be attributed erroneously to
authors of previous versions.
Some devices are designed to deny users access to install or run
modified versions of the software inside them, although the manufacturer
can do so. This is fundamentally incompatible with the aim of
protecting users' freedom to change the software. The systematic
pattern of such abuse occurs in the area of products for individuals to
use, which is precisely where it is most unacceptable. Therefore, we
have designed this version of the GPL to prohibit the practice for those
products. If such problems arise substantially in other domains, we
stand ready to extend this provision to those domains in future versions
of the GPL, as needed to protect the freedom of users.
Finally, every program is threatened constantly by software patents.
States should not allow patents to restrict development and use of
software on general-purpose computers, but in those that do, we wish to
avoid the special danger that patents applied to a free program could
make it effectively proprietary. To prevent this, the GPL assures that
patents cannot be used to render the program non-free.
The precise terms and conditions for copying, distribution and
modification follow.
@@ -60,7 +72,7 @@ modification follow.
0. Definitions.
"This License" refers to version 3 of the GNU Affero General Public License.
"This License" refers to version 3 of the GNU General Public License.
"Copyright" also means copyright-like laws that apply to other kinds of
works, such as semiconductor masks.
@@ -537,45 +549,35 @@ to collect a royalty for further conveying from those to whom you convey
the Program, the only way you could satisfy both those terms and this
License would be to refrain entirely from conveying the Program.
13. Remote Network Interaction; Use with the GNU General Public License.
Notwithstanding any other provision of this License, if you modify the
Program, your modified version must prominently offer all users
interacting with it remotely through a computer network (if your version
supports such interaction) an opportunity to receive the Corresponding
Source of your version by providing access to the Corresponding Source
from a network server at no charge, through some standard or customary
means of facilitating copying of software. This Corresponding Source
shall include the Corresponding Source for any work covered by version 3
of the GNU General Public License that is incorporated pursuant to the
following paragraph.
13. Use with the GNU Affero General Public License.
Notwithstanding any other provision of this License, you have
permission to link or combine any covered work with a work licensed
under version 3 of the GNU General Public License into a single
under version 3 of the GNU Affero General Public License into a single
combined work, and to convey the resulting work. The terms of this
License will continue to apply to the part which is the covered work,
but the work with which it is combined will remain governed by version
3 of the GNU General Public License.
but the special requirements of the GNU Affero General Public License,
section 13, concerning interaction through a network will apply to the
combination as such.
14. Revised Versions of this License.
The Free Software Foundation may publish revised and/or new versions of
the GNU Affero General Public License from time to time. Such new versions
will be similar in spirit to the present version, but may differ in detail to
the GNU General Public License from time to time. Such new versions will
be similar in spirit to the present version, but may differ in detail to
address new problems or concerns.
Each version is given a distinguishing version number. If the
Program specifies that a certain numbered version of the GNU Affero General
Program specifies that a certain numbered version of the GNU General
Public License "or any later version" applies to it, you have the
option of following the terms and conditions either of that numbered
version or of any later version published by the Free Software
Foundation. If the Program does not specify a version number of the
GNU Affero General Public License, you may choose any version ever published
GNU General Public License, you may choose any version ever published
by the Free Software Foundation.
If the Program specifies that a proxy can decide which future
versions of the GNU Affero General Public License can be used, that proxy's
versions of the GNU General Public License can be used, that proxy's
public statement of acceptance of a version permanently authorizes you
to choose that version for the Program.
@@ -633,29 +635,40 @@ the "copyright" line and a pointer to where the full notice is found.
Copyright (C) <year> <name of author>
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
GNU General Public License for more details.
You should have received a copy of the GNU Affero General Public License
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
Also add information on how to contact you by electronic and paper mail.
If your software can interact with users remotely through a computer
network, you should also make sure that it provides a way for users to
get its source. For example, if your program is a web application, its
interface could display a "Source" link that leads users to an archive
of the code. There are many ways you could offer source, and different
solutions will be better for different programs; see section 13 for the
specific requirements.
If the program does terminal interaction, make it output a short
notice like this when it starts in an interactive mode:
<program> Copyright (C) <year> <name of author>
This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
This is free software, and you are welcome to redistribute it
under certain conditions; type `show c' for details.
The hypothetical commands `show w' and `show c' should show the appropriate
parts of the General Public License. Of course, your program's commands
might be different; for a GUI interface, you would use an "about box".
You should also get your employer (if you work as a programmer) or school,
if any, to sign a "copyright disclaimer" for the program, if necessary.
For more information on this, and how to apply and follow the GNU AGPL, see
For more information on this, and how to apply and follow the GNU GPL, see
<https://www.gnu.org/licenses/>.
The GNU General Public License does not permit incorporating your program
into proprietary programs. If your program is a subroutine library, you
may consider it more useful to permit linking proprietary applications with
the library. If this is what you want to do, use the GNU Lesser General
Public License instead of this License. But first, please read
<https://www.gnu.org/licenses/why-not-lgpl.html>.
+1 -11
View File
@@ -8,7 +8,6 @@ TAG_NAME := $(shell git describe --abbrev=0 --exact-match 2> /dev/null || echo "
COMMIT_HASH := $(shell git rev-parse HEAD)
BUILD_TIMESTAMP := $(shell date '+%Y-%m-%dT%H:%M:%S')
BIN_NAME := tinyauth-$(GOARCH)
LDFLAGS := -s -w
# Development vars
DEV_COMPOSE := $(shell test -f "docker-compose.test.yml" && echo "docker-compose.test.yml" || echo "docker-compose.dev.yml" )
@@ -37,7 +36,7 @@ webui: clean-webui
# Build the binary
binary: webui
CGO_ENABLED=$(CGO_ENABLED) go build -ldflags "${LDFLAGS} \
CGO_ENABLED=$(CGO_ENABLED) go build -ldflags "-s -w \
-X github.com/tinyauthapp/tinyauth/internal/model.Version=${TAG_NAME} \
-X github.com/tinyauthapp/tinyauth/internal/model.CommitHash=${COMMIT_HASH} \
-X github.com/tinyauthapp/tinyauth/internal/model.BuildTimestamp=${BUILD_TIMESTAMP}" \
@@ -62,15 +61,6 @@ binary-linux-arm64:
test:
go test -v ./...
# Go vet
.PHONY: vet
vet:
go vet ./...
# Go race
test-race:
go test -race ./...
# Development
dev:
docker compose -f $(DEV_COMPOSE) up --force-recreate --pull=always --remove-orphans --build
+4 -1
View File
@@ -28,6 +28,9 @@ Tinyauth is the simplest and tiniest authentication and authorization server you
> [!NOTE]
> This is the main development branch. For the latest stable release, see the [documentation](https://tinyauth.app) or the latest stable tag.
> [!NOTE]
> Tinyauth is in the process of migrating to the new [tinyauthapp](https://github.com/tinyauthapp) organization. The organization **is official** and it will host all of the Tinyauth related repositories in the future.
## Getting Started
You can get started with Tinyauth by following the guide in the [documentation](https://tinyauth.app/docs/getting-started). There is also an available [docker-compose](./docker-compose.example.yml) file that has Traefik, Whoami and Tinyauth to demonstrate its capabilities (keep in mind that this file lives in the development branch so it may have updates that are not yet released).
@@ -56,7 +59,7 @@ If you like, you can help translate Tinyauth into more languages by visiting the
## License
Tinyauth is licensed under the GNU Affero General Public License v3.0. TL;DR — You may copy, distribute and modify the software as long as you track changes/dates in source files. Any modifications to or software including (via compiler) AGPL-licensed code must also be made available under the AGPL along with build & install instructions. If you run a modified version over a network, you must also make the source available to the users of that service. For more information about the license check the [license](LICENSE) file.
Tinyauth is licensed under the GNU General Public License v3.0. TL;DR — You may copy, distribute and modify the software as long as you track changes/dates in source files. Any modifications to or software including (via compiler) GPL-licensed code must also be made available under the GPL along with build & install instructions. For more information about the license check the [license](./LICENSE) file.
## Sponsors
+2 -2
View File
@@ -58,8 +58,8 @@
"invalidInput": "Input non valido",
"domainWarningTitle": "Dominio non valido",
"domainWarningSubtitle": "Stai accedendo a questa istanza da un dominio errato. Scegliendo di procedere, potresti incontrare problemi con l'autenticazione.",
"domainWarningCurrent": "Attuale:",
"domainWarningExpected": "Previsto:",
"domainWarningCurrent": "Current:",
"domainWarningExpected": "Expected:",
"ignoreTitle": "Ignora",
"goToCorrectDomainTitle": "Vai al dominio corretto",
"authorizeTitle": "Autorizza",
+1 -1
View File
@@ -57,7 +57,7 @@
"fieldRequired": "Ово поље је неопходно",
"invalidInput": "Неисправан унос",
"domainWarningTitle": "Неисправан домен",
"domainWarningSubtitle": "Приступате овој инстанци са неисправног домена. Ако наставите, можете наићи на проблеме са аутентификацијом.",
"domainWarningSubtitle": "You are accessing this instance from an incorrect domain. If you proceed, you may encounter issues with authentication.",
"domainWarningCurrent": "Тренутни:",
"domainWarningExpected": "Очекивани:",
"ignoreTitle": "Игнориши",
+1 -1
View File
@@ -46,7 +46,7 @@ export const LoginPage = () => {
const { t } = useTranslation();
const [showRedirectButton, setShowRedirectButton] = useState(false);
const [useTailscale, setUseTailscale] = useState(tailscale.nodeName !== undefined);
const [useTailscale, setUseTailscale] = useState(tailscale.nodeName !== "");
const hasAutoRedirectedRef = useRef(false);
+1 -1
View File
@@ -18,7 +18,7 @@ const totpSchema = z.object({
});
const tailscaleSchema = z.object({
nodeName: z.string().optional(),
nodeName: z.string(),
});
export const userContextSchema = z.object({
+4 -10
View File
@@ -12,21 +12,19 @@ require (
github.com/golang-migrate/migrate/v4 v4.19.1
github.com/google/go-querystring v1.2.0
github.com/google/uuid v1.6.0
github.com/jackc/pgx/v5 v5.9.2
github.com/mdp/qrterminal/v3 v3.2.1
github.com/pquerna/otp v1.5.0
github.com/rs/zerolog v1.35.1
github.com/steveiliop56/ding v0.2.0
github.com/stretchr/testify v1.11.1
github.com/tinyauthapp/paerser v0.0.0-20260410140347-85c3740d6298
github.com/weppos/publicsuffix-go v0.50.3
golang.org/x/crypto v0.52.0
golang.org/x/crypto v0.51.0
golang.org/x/oauth2 v0.36.0
golang.org/x/tools v0.44.0
k8s.io/apimachinery v0.36.1
k8s.io/client-go v0.36.1
modernc.org/sqlite v1.50.1
tailscale.com v1.98.3
tailscale.com v1.98.2
)
require (
@@ -92,10 +90,6 @@ require (
github.com/hdevalence/ed25519consensus v0.2.0 // indirect
github.com/huandu/xstrings v1.5.0 // indirect
github.com/huin/goupnp v1.3.0 // indirect
github.com/jackc/pgerrcode v0.0.0-20220416144525-469b46aa5efa // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
github.com/jackc/puddle/v2 v2.2.2 // indirect
github.com/jsimonetti/rtnetlink v1.4.0 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/compress v1.18.5 // indirect
@@ -157,9 +151,9 @@ require (
golang.org/x/arch v0.22.0 // indirect
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect
golang.org/x/mod v0.35.0 // indirect
golang.org/x/net v0.54.0 // indirect
golang.org/x/net v0.53.0 // indirect
golang.org/x/sync v0.20.0 // indirect
golang.org/x/sys v0.45.0 // indirect
golang.org/x/sys v0.44.0 // indirect
golang.org/x/term v0.43.0 // indirect
golang.org/x/text v0.37.0 // indirect
golang.org/x/time v0.14.0 // indirect
+8 -23
View File
@@ -143,8 +143,6 @@ github.com/dblohm7/wingoes v0.0.0-20240119213807-a09d6be7affa h1:h8TfIT1xc8FWbww
github.com/dblohm7/wingoes v0.0.0-20240119213807-a09d6be7affa/go.mod h1:Nx87SkVqTKd8UtT+xu7sM/l+LgXs6c0aHrlKusR+2EQ=
github.com/dgryski/go-metro v0.0.0-20180109044635-280f6062b5bc h1:8WFBn63wegobsYAX0YjD+8suexZDga5CctH4CCTx2+8=
github.com/dgryski/go-metro v0.0.0-20180109044635-280f6062b5bc/go.mod h1:c9O8+fpSOX1DM8cPNSkX/qsBWdkD4yd2dpciOWQjpBw=
github.com/dhui/dktest v0.4.6 h1:+DPKyScKSEp3VLtbMDHcUq6V5Lm5zfZZVb0Sk7Ahom4=
github.com/dhui/dktest v0.4.6/go.mod h1:JHTSYDtKkvFNFHJKqCzVzqXecyv+tKt8EzceOmQOgbU=
github.com/digitalocean/go-smbios v0.0.0-20180907143718-390a4f403a8e h1:vUmf0yezR0y7jJ5pceLHthLaYf4bA5T14B6q39S4q2Q=
github.com/digitalocean/go-smbios v0.0.0-20180907143718-390a4f403a8e/go.mod h1:YTIHhz/QFSYnu/EhlF2SpU2Uk+32abacUYA5ZPljz1A=
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
@@ -253,16 +251,6 @@ github.com/illarion/gonotify/v3 v3.0.2 h1:O7S6vcopHexutmpObkeWsnzMJt/r1hONIEogeV
github.com/illarion/gonotify/v3 v3.0.2/go.mod h1:HWGPdPe817GfvY3w7cx6zkbzNZfi3QjcBm/wgVvEL1U=
github.com/insomniacslk/dhcp v0.0.0-20231206064809-8c70d406f6d2 h1:9K06NfxkBh25x56yVhWWlKFE8YpicaSfHwoV8SFbueA=
github.com/insomniacslk/dhcp v0.0.0-20231206064809-8c70d406f6d2/go.mod h1:3A9PQ1cunSDF/1rbTq99Ts4pVnycWg+vlPkfeD2NLFI=
github.com/jackc/pgerrcode v0.0.0-20220416144525-469b46aa5efa h1:s+4MhCQ6YrzisK6hFJUX53drDT4UsSW3DEhKn0ifuHw=
github.com/jackc/pgerrcode v0.0.0-20220416144525-469b46aa5efa/go.mod h1:a/s9Lp5W7n/DD0VrVoyJ00FbP2ytTPDVOivvn2bMlds=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/pgx/v5 v5.9.2 h1:3ZhOzMWnR4yJ+RW1XImIPsD1aNSz4T4fyP7zlQb56hw=
github.com/jackc/pgx/v5 v5.9.2/go.mod h1:mal1tBGAFfLHvZzaYh77YS/eC6IX9OWbRV1QIIM0Jn4=
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
github.com/jcmturner/aescts/v2 v2.0.0 h1:9YKLH6ey7H4eDBXW8khjYslgyqG2xZikXP0EQFKrle8=
github.com/jcmturner/aescts/v2 v2.0.0/go.mod h1:AiaICIRyfYg35RUkr8yESTqvSy7csK90qZ5xfvvsoNs=
github.com/jcmturner/dnsutils/v2 v2.0.0 h1:lltnkeZGL0wILNvrNiVCR6Ro5PGU/SeBvVO/8c/iPbo=
@@ -402,15 +390,12 @@ github.com/spf13/cast v1.10.0 h1:h2x0u2shc1QuLHfxi+cTJvs30+ZAHOGRic8uyGTDWxY=
github.com/spf13/cast v1.10.0/go.mod h1:jNfB8QC9IA6ZuY2ZjDp0KtFO2LZZlg4S/7bzP6qqeHo=
github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk=
github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/steveiliop56/ding v0.2.0 h1:m/Fj99wBpVVLHlpqb2RDJkWubOc5cWJ11ZYCHya3Sk0=
github.com/steveiliop56/ding v0.2.0/go.mod h1:bE2u2XH7CjhPzbb/0Ems+D8YZlf2Ae+eKhj00UR1iAY=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
@@ -493,8 +478,8 @@ go4.org/netipx v0.0.0-20231129151722-fdeea329fbba h1:0b9z3AuHCjxk0x/opv64kcgZLBs
go4.org/netipx v0.0.0-20231129151722-fdeea329fbba/go.mod h1:PLyyIXexvUFg3Owu6p/WfdlivPbZJsZdgWZlrGope/Y=
golang.org/x/arch v0.22.0 h1:c/Zle32i5ttqRXjdLyyHZESLD/bB90DCU1g9l/0YBDI=
golang.org/x/arch v0.22.0/go.mod h1:dNHoOeKiyja7GTvF9NJS1l3Z2yntpQNzgrjh1cU103A=
golang.org/x/crypto v0.52.0 h1:RMs7fP2rXdep0CftQlK8Uf+kibLm7qkCcradZWYz988=
golang.org/x/crypto v0.52.0/go.mod h1:1QgfPxDqh0T2M/elOJtp9RvuR95kVjir0e6/BvEmGbc=
golang.org/x/crypto v0.51.0 h1:IBPXwPfKxY7cWQZ38ZCIRPI50YLeevDLlLnyC5wRGTI=
golang.org/x/crypto v0.51.0/go.mod h1:8AdwkbraGNABw2kOX6YFPs3WM22XqI4EXEd8g+x7Oc8=
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY=
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70=
golang.org/x/exp/typeparams v0.0.0-20240314144324-c7f7c6466f7f h1:phY1HzDcf18Aq9A8KkmRtY9WvOFIxN8wgfvy6Zm1DV8=
@@ -503,8 +488,8 @@ golang.org/x/image v0.27.0 h1:C8gA4oWU/tKkdCfYT6T2u4faJu3MeNS5O8UPWlPF61w=
golang.org/x/image v0.27.0/go.mod h1:xbdrClrAUway1MUTEZDq9mz/UpRwYAkFFNUslZtcB+g=
golang.org/x/mod v0.35.0 h1:Ww1D637e6Pg+Zb2KrWfHQUnH2dQRLBQyAtpr/haaJeM=
golang.org/x/mod v0.35.0/go.mod h1:+GwiRhIInF8wPm+4AoT6L0FA1QWAad3OMdTRx4tFYlU=
golang.org/x/net v0.54.0 h1:2zJIZAxAHV/OHCDTCOHAYehQzLfSXuf/5SoL/Dv6w/w=
golang.org/x/net v0.54.0/go.mod h1:Sj4oj8jK6XmHpBZU/zWHw3BV3abl4Kvi+Ut7cQcY+cQ=
golang.org/x/net v0.53.0 h1:d+qAbo5L0orcWAr0a9JweQpjXF19LMXJE8Ey7hwOdUA=
golang.org/x/net v0.53.0/go.mod h1:JvMuJH7rrdiCfbeHoo3fCQU24Lf5JJwT9W3sJFulfgs=
golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs=
golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@@ -512,8 +497,8 @@ golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
golang.org/x/sys v0.0.0-20220817070843-5a390386f1f2/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.45.0 h1:dO4czNzziLiiXplLQgBCEpCvXQ3dnkn0SdaZSYdQ+FY=
golang.org/x/sys v0.45.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
golang.org/x/sys v0.44.0 h1:ildZl3J4uzeKP07r2F++Op7E9B29JRUy+a27EibtBTQ=
golang.org/x/sys v0.44.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
golang.org/x/term v0.43.0 h1:S4RLU2sB31O/NCl+zFN9Aru9A/Cq2aqKpTZJ6B+DwT4=
golang.org/x/term v0.43.0/go.mod h1:lrhlHNdQJHO+1qVYiHfFKVuVioJIheAc3fBSMFYEIsk=
golang.org/x/text v0.37.0 h1:Cqjiwd9eSg8e0QAkyCaQTNHFIIzWtidPahFWR83rTrc=
@@ -605,5 +590,5 @@ sigs.k8s.io/yaml v1.6.0 h1:G8fkbMSAFqgEFgh4b1wmtzDnioxFCUgTZhlbj5P9QYs=
sigs.k8s.io/yaml v1.6.0/go.mod h1:796bPqUfzR/0jLAl6XjHl3Ck7MiyVv8dbTdyT3/pMf4=
software.sslmate.com/src/go-pkcs12 v0.4.0 h1:H2g08FrTvSFKUj+D309j1DPfk5APnIdAQAB8aEykJ5k=
software.sslmate.com/src/go-pkcs12 v0.4.0/go.mod h1:Qiz0EyvDRJjjxGyUQa2cCNZn/wMyzrRJ/qcDXOQazLI=
tailscale.com v1.98.3 h1:caAbG4UfkKfKPE6b1fj5t4ep5qrwEis5AJu91ruvePw=
tailscale.com v1.98.3/go.mod h1:U23ZwbZlKJMNU7CScy+lCVVlece/S5n09q0nyudncBI=
tailscale.com v1.98.2 h1:HP5gt0qyLKtJoDV7PMUvPpiXjMFk4nzXMbm7JdjttMY=
tailscale.com v1.98.2/go.mod h1:U23ZwbZlKJMNU7CScy+lCVVlece/S5n09q0nyudncBI=
+1 -1
View File
@@ -11,5 +11,5 @@ var FrontendAssets embed.FS
// Migrations
//
//go:embed migrations/sqlite/*.sql migrations/postgres/*.sql
//go:embed migrations/sqlite/*.sql
var Migrations embed.FS
@@ -1,4 +0,0 @@
DROP TABLE IF EXISTS "oidc_tokens";
DROP TABLE IF EXISTS "oidc_userinfo";
DROP TABLE IF EXISTS "oidc_codes";
DROP TABLE IF EXISTS "sessions";
@@ -1,60 +0,0 @@
CREATE TABLE "sessions" (
"uuid" TEXT NOT NULL PRIMARY KEY,
"username" TEXT NOT NULL,
"email" TEXT NOT NULL,
"name" TEXT NOT NULL,
"provider" TEXT NOT NULL,
"totp_pending" BOOLEAN NOT NULL,
"oauth_groups" TEXT NOT NULL DEFAULT '',
"expiry" BIGINT NOT NULL,
"created_at" BIGINT NOT NULL,
"oauth_name" TEXT NOT NULL DEFAULT '',
"oauth_sub" TEXT NOT NULL DEFAULT ''
);
CREATE TABLE "oidc_codes" (
"sub" TEXT NOT NULL UNIQUE,
"code_hash" TEXT NOT NULL PRIMARY KEY,
"scope" TEXT NOT NULL,
"redirect_uri" TEXT NOT NULL,
"client_id" TEXT NOT NULL,
"expires_at" BIGINT NOT NULL,
"nonce" TEXT NOT NULL DEFAULT '',
"code_challenge" TEXT NOT NULL DEFAULT ''
);
CREATE TABLE "oidc_tokens" (
"sub" TEXT NOT NULL UNIQUE,
"access_token_hash" TEXT NOT NULL PRIMARY KEY,
"refresh_token_hash" TEXT NOT NULL,
"code_hash" TEXT NOT NULL,
"scope" TEXT NOT NULL,
"client_id" TEXT NOT NULL,
"token_expires_at" BIGINT NOT NULL,
"refresh_token_expires_at" BIGINT NOT NULL,
"nonce" TEXT NOT NULL DEFAULT ''
);
CREATE TABLE "oidc_userinfo" (
"sub" TEXT NOT NULL PRIMARY KEY,
"name" TEXT NOT NULL,
"preferred_username" TEXT NOT NULL,
"email" TEXT NOT NULL,
"groups" TEXT NOT NULL,
"updated_at" BIGINT NOT NULL,
"given_name" TEXT NOT NULL,
"family_name" TEXT NOT NULL,
"middle_name" TEXT NOT NULL,
"nickname" TEXT NOT NULL,
"profile" TEXT NOT NULL,
"picture" TEXT NOT NULL,
"website" TEXT NOT NULL,
"gender" TEXT NOT NULL,
"birthdate" TEXT NOT NULL,
"zoneinfo" TEXT NOT NULL,
"locale" TEXT NOT NULL,
"phone_number" TEXT NOT NULL,
"address" TEXT NOT NULL
);
CREATE INDEX idx_sessions_expiry ON "sessions" ("expiry");
@@ -1,46 +0,0 @@
DROP TABLE IF EXISTS "oidc_sessions";
CREATE TABLE "oidc_codes" (
"sub" TEXT NOT NULL UNIQUE,
"code_hash" TEXT NOT NULL PRIMARY KEY,
"scope" TEXT NOT NULL,
"redirect_uri" TEXT NOT NULL,
"client_id" TEXT NOT NULL,
"expires_at" BIGINT NOT NULL,
"nonce" TEXT NOT NULL DEFAULT '',
"code_challenge" TEXT NOT NULL DEFAULT ''
);
CREATE TABLE "oidc_tokens" (
"sub" TEXT NOT NULL UNIQUE,
"access_token_hash" TEXT NOT NULL PRIMARY KEY,
"refresh_token_hash" TEXT NOT NULL,
"code_hash" TEXT NOT NULL,
"scope" TEXT NOT NULL,
"client_id" TEXT NOT NULL,
"token_expires_at" BIGINT NOT NULL,
"refresh_token_expires_at" BIGINT NOT NULL,
"nonce" TEXT NOT NULL DEFAULT ''
);
CREATE TABLE "oidc_userinfo" (
"sub" TEXT NOT NULL PRIMARY KEY,
"name" TEXT NOT NULL,
"preferred_username" TEXT NOT NULL,
"email" TEXT NOT NULL,
"groups" TEXT NOT NULL,
"updated_at" BIGINT NOT NULL,
"given_name" TEXT NOT NULL,
"family_name" TEXT NOT NULL,
"middle_name" TEXT NOT NULL,
"nickname" TEXT NOT NULL,
"profile" TEXT NOT NULL,
"picture" TEXT NOT NULL,
"website" TEXT NOT NULL,
"gender" TEXT NOT NULL,
"birthdate" TEXT NOT NULL,
"zoneinfo" TEXT NOT NULL,
"locale" TEXT NOT NULL,
"phone_number" TEXT NOT NULL,
"address" TEXT NOT NULL
);
@@ -1,28 +0,0 @@
/*
This migration will nuke the entire setup of OIDC sessions and merge everything
into one table.
*/
/*
Drop all the old tables. Yes, we will log out all OIDC users, but not really a big deal
*/
DROP TABLE IF EXISTS "oidc_tokens";
DROP TABLE IF EXISTS "oidc_userinfo";
DROP TABLE IF EXISTS "oidc_codes";
/*
Create a new simple OIDC sessions table that will hold tokens + userinfo.
*/
CREATE TABLE IF NOT EXISTS "oidc_sessions" (
"sub" TEXT NOT NULL UNIQUE PRIMARY KEY,
"access_token_hash" TEXT NOT NULL UNIQUE,
"refresh_token_hash" TEXT NOT NULL UNIQUE,
"scope" TEXT NOT NULL,
"client_id" TEXT NOT NULL,
"token_expires_at" BIGINT NOT NULL,
"refresh_token_expires_at" BIGINT NOT NULL,
"nonce" TEXT NOT NULL DEFAULT '',
"userinfo_json" TEXT NOT NULL
);
@@ -1,46 +0,0 @@
DROP TABLE IF EXISTS "oidc_sessions";
CREATE TABLE IF NOT EXISTS "oidc_codes" (
"sub" TEXT NOT NULL UNIQUE,
"code_hash" TEXT NOT NULL PRIMARY KEY UNIQUE,
"scope" TEXT NOT NULL,
"redirect_uri" TEXT NOT NULL,
"client_id" TEXT NOT NULL,
"expires_at" INTEGER NOT NULL,
"nonce" TEXT DEFAULT "",
"code_challenge" TEXT DEFAULT ""
);
CREATE TABLE IF NOT EXISTS "oidc_tokens" (
"sub" TEXT NOT NULL UNIQUE,
"access_token_hash" TEXT NOT NULL PRIMARY KEY UNIQUE,
"refresh_token_hash" TEXT NOT NULL,
"code_hash" TEXT NOT NULL,
"scope" TEXT NOT NULL,
"client_id" TEXT NOT NULL,
"token_expires_at" INTEGER NOT NULL,
"refresh_token_expires_at" INTEGER NOT NULL,
"nonce" TEXT DEFAULT ""
);
CREATE TABLE IF NOT EXISTS "oidc_userinfo" (
"sub" TEXT NOT NULL UNIQUE PRIMARY KEY,
"name" TEXT NOT NULL,
"preferred_username" TEXT NOT NULL,
"email" TEXT NOT NULL,
"groups" TEXT NOT NULL,
"updated_at" INTEGER NOT NULL,
"given_name" TEXT NOT NULL,
"family_name" TEXT NOT NULL,
"middle_name" TEXT NOT NULL,
"nickname" TEXT NOT NULL,
"profile" TEXT NOT NULL,
"picture" TEXT NOT NULL,
"website" TEXT NOT NULL,
"gender" TEXT NOT NULL,
"birthdate" TEXT NOT NULL,
"zoneinfo" TEXT NOT NULL,
"locale" TEXT NOT NULL,
"phone_number" TEXT NOT NULL,
"address" TEXT NOT NULL
);
@@ -1,28 +0,0 @@
/*
This migration will nuke the entire setup of OIDC sessions and merge everything
into one table.
*/
/*
Drop all the old tables. Yes, we will log out all OIDC users, but not really a big deal
*/
DROP TABLE IF EXISTS "oidc_tokens";
DROP TABLE IF EXISTS "oidc_userinfo";
DROP TABLE IF EXISTS "oidc_codes";
/*
Create a new simple OIDC sessions table that will hold tokens + userinfo.
*/
CREATE TABLE IF NOT EXISTS "oidc_sessions" (
"sub" TEXT NOT NULL UNIQUE PRIMARY KEY,
"access_token_hash" TEXT NOT NULL UNIQUE,
"refresh_token_hash" TEXT NOT NULL UNIQUE,
"scope" TEXT NOT NULL,
"client_id" TEXT NOT NULL,
"token_expires_at" INTEGER NOT NULL,
"refresh_token_expires_at" INTEGER NOT NULL,
"nonce" TEXT DEFAULT "",
"userinfo_json" TEXT NOT NULL
);
+18 -43
View File
@@ -13,11 +13,11 @@ import (
"os/signal"
"sort"
"strings"
"sync"
"syscall"
"time"
"github.com/gin-gonic/gin"
"github.com/steveiliop56/ding"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository"
@@ -26,12 +26,6 @@ import (
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
)
// Shutdown order for go routines
// 1. Janitor routines (e.g. database cleanup, heartbeat) - ding.RingMinor
// 2. HTTP server listeners - ding.RingNormal
// 3. Networking layers, user and label providers (e.g. ailscale service, kubernetes service) - ding.RingMajor
// 4. Database connection - ding.RingCritical
type Services struct {
accessControlService *service.AccessControlsService
authService *service.AuthService
@@ -54,7 +48,7 @@ type BootstrapApp struct {
queries repository.Store
router *gin.Engine
db *sql.DB
ding *ding.Ding
wg sync.WaitGroup
listeners []Listener
}
@@ -70,10 +64,6 @@ func (app *BootstrapApp) Setup() error {
app.ctx = ctx
app.cancel = cancel
// Create a ding instance
dg := ding.New(ctx)
app.ding = dg
// setup logger
log := logger.NewLogger().WithConfig(app.config.Log)
log.Init()
@@ -107,12 +97,7 @@ func (app *BootstrapApp) Setup() error {
return fmt.Errorf("failed to load users: %w", err)
}
if users != nil {
app.runtime.LocalUsers = *users
} else {
log.App.Debug().Msg("No local users found, local authentication will not be available")
app.runtime.LocalUsers = []model.LocalUser{}
}
app.runtime.LocalUsers = *users
// load oauth whitelist
oauthWhitelist, err := utils.GetStringList(app.config.OAuth.Whitelist, app.config.OAuth.WhitelistFile)
@@ -127,13 +112,6 @@ func (app *BootstrapApp) Setup() error {
app.runtime.OAuthProviders = app.config.OAuth.Providers
for id, provider := range app.runtime.OAuthProviders {
providerWhitelist, err := utils.GetStringList(provider.Whitelist, provider.WhitelistFile)
if err != nil {
return fmt.Errorf("failed to load oauth whitelist for provider %s: %w", id, err)
}
provider.Whitelist = providerWhitelist
secret := utils.GetSecret(provider.ClientSecret, provider.ClientSecretFile)
provider.ClientSecret = secret
provider.ClientSecretFile = ""
@@ -196,17 +174,15 @@ func (app *BootstrapApp) Setup() error {
return fmt.Errorf("failed to setup database: %w", err)
}
app.ding.Go(func(ctx context.Context) {
<-ctx.Done()
app.log.App.Debug().Msg("Shutting down database connection")
if app.db == nil {
// using memory store, no db instance
return
// after this point, we start initializing dependencies so it's a good time to setup a defer
// to ensure that resources are cleaned up properly in case of an error during initialization
defer func() {
app.cancel()
app.wg.Wait()
if app.db != nil {
app.db.Close()
}
if err := app.db.Close(); err != nil {
app.log.App.Error().Err(err).Msg("Failed to close database connection")
}
}, ding.RingCritical)
}()
// store
app.queries = store
@@ -273,12 +249,12 @@ func (app *BootstrapApp) Setup() error {
// start db cleanup routine
app.log.App.Debug().Msg("Starting database cleanup routine")
app.ding.Go(app.dbCleanupRoutine, ding.RingMinor)
app.wg.Go(app.dbCleanupRoutine)
// if analytics are not disabled, start heartbeat
if app.config.Analytics.Enabled {
app.log.App.Debug().Msg("Starting heartbeat routine")
app.ding.Go(app.heartbeatRoutine, ding.RingMinor)
app.wg.Go(app.heartbeatRoutine)
}
// setup listeners
@@ -299,7 +275,6 @@ func (app *BootstrapApp) Setup() error {
for {
select {
case <-app.ctx.Done():
app.ding.Wait()
app.log.App.Info().Msg("Oh, it's time for me to go, bye!")
return nil
case err := <-lec:
@@ -310,7 +285,7 @@ func (app *BootstrapApp) Setup() error {
}
}
func (app *BootstrapApp) heartbeatRoutine(ctx context.Context) {
func (app *BootstrapApp) heartbeatRoutine() {
ticker := time.NewTicker(time.Duration(12) * time.Hour)
defer ticker.Stop()
@@ -363,7 +338,7 @@ func (app *BootstrapApp) heartbeatRoutine(ctx context.Context) {
if res.StatusCode != 200 && res.StatusCode != 201 {
app.log.App.Debug().Str("status", res.Status).Msg("Heartbeat returned non-200/201 status")
}
case <-ctx.Done():
case <-app.ctx.Done():
app.log.App.Debug().Msg("Stopping heartbeat routine")
ticker.Stop()
return
@@ -371,7 +346,7 @@ func (app *BootstrapApp) heartbeatRoutine(ctx context.Context) {
}
}
func (app *BootstrapApp) dbCleanupRoutine(ctx context.Context) {
func (app *BootstrapApp) dbCleanupRoutine() {
ticker := time.NewTicker(time.Duration(30) * time.Minute)
defer ticker.Stop()
@@ -380,14 +355,14 @@ func (app *BootstrapApp) dbCleanupRoutine(ctx context.Context) {
case <-ticker.C:
app.log.App.Debug().Msg("Running database cleanup")
err := app.queries.DeleteExpiredSessions(ctx, time.Now().Unix())
err := app.queries.DeleteExpiredSessions(app.ctx, time.Now().Unix())
if err != nil {
app.log.App.Error().Err(err).Msg("Failed to delete expired sessions")
}
app.log.App.Debug().Msg("Database cleanup completed")
case <-ctx.Done():
case <-app.ctx.Done():
app.log.App.Debug().Msg("Stopping database cleanup routine")
ticker.Stop()
return
+9 -57
View File
@@ -6,18 +6,15 @@ import (
"os"
"path/filepath"
"github.com/golang-migrate/migrate/v4"
pgxmigrate "github.com/golang-migrate/migrate/v4/database/pgx/v5"
"github.com/golang-migrate/migrate/v4/database/sqlite3"
"github.com/golang-migrate/migrate/v4/source/iofs"
_ "github.com/jackc/pgx/v5/stdlib"
_ "modernc.org/sqlite"
"github.com/tinyauthapp/tinyauth/internal/assets"
"github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
"github.com/tinyauthapp/tinyauth/internal/repository/postgres"
"github.com/tinyauthapp/tinyauth/internal/repository/sqlite"
"github.com/golang-migrate/migrate/v4"
"github.com/golang-migrate/migrate/v4/database/sqlite3"
"github.com/golang-migrate/migrate/v4/source/iofs"
_ "modernc.org/sqlite"
)
func (app *BootstrapApp) SetupStore() (repository.Store, error) {
@@ -26,10 +23,8 @@ func (app *BootstrapApp) SetupStore() (repository.Store, error) {
return memory.New(), nil
case "sqlite", "":
return app.setupSQLite(app.config.Database.Path)
case "postgres":
return app.setupPostgres(app.config.Database.Path)
default:
return nil, fmt.Errorf("unknown database driver %q: valid values are sqlite, postgres, memory", app.config.Database.Driver)
return nil, fmt.Errorf("unknown database driver %q: valid values are sqlite, memory", app.config.Database.Driver)
}
}
@@ -46,9 +41,9 @@ func (app *BootstrapApp) setupSQLite(databasePath string) (repository.Store, err
return nil, fmt.Errorf("failed to open database: %w", err)
}
cleanup := true
// Close the database if there is an error during migration
defer func() {
if cleanup {
if err != nil {
db.Close()
}
}()
@@ -75,54 +70,11 @@ func (app *BootstrapApp) setupSQLite(databasePath string) (repository.Store, err
return nil, fmt.Errorf("failed to create migrator: %w", err)
}
if err = migrator.Up(); err != nil && err != migrate.ErrNoChange {
if err := migrator.Up(); err != nil && err != migrate.ErrNoChange {
return nil, fmt.Errorf("failed to migrate database: %w", err)
}
cleanup = false
app.db = db
return sqlite.NewStore(sqlite.New(db)), nil
}
func (app *BootstrapApp) setupPostgres(databaseURL string) (repository.Store, error) {
db, err := sql.Open("pgx", databaseURL)
if err != nil {
return nil, fmt.Errorf("failed to open database: %w", err)
}
cleanup := true
defer func() {
if cleanup {
db.Close()
}
}()
migrations, err := iofs.New(assets.Migrations, "migrations/postgres")
if err != nil {
return nil, fmt.Errorf("failed to create migrations: %w", err)
}
target, err := pgxmigrate.WithInstance(db, &pgxmigrate.Config{})
if err != nil {
return nil, fmt.Errorf("failed to create postgres instance: %w", err)
}
migrator, err := migrate.NewWithInstance("iofs", migrations, "pgx", target)
if err != nil {
return nil, fmt.Errorf("failed to create migrator: %w", err)
}
if err = migrator.Up(); err != nil && err != migrate.ErrNoChange {
return nil, fmt.Errorf("failed to migrate database: %w", err)
}
cleanup = false
app.db = db
return postgres.NewStore(postgres.New(db)), nil
}
+20 -17
View File
@@ -9,7 +9,6 @@ import (
"os"
"time"
"github.com/steveiliop56/ding"
"github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/middleware"
"github.com/tinyauthapp/tinyauth/internal/model"
@@ -81,9 +80,9 @@ func (app *BootstrapApp) runListeners() (chan error, error) {
return nil, fmt.Errorf("failed to get listener function: %w", err)
}
app.ding.Go(func(ctx context.Context) {
lec <- listenerFunc(ctx)
}, ding.RingNormal)
app.wg.Go(func() {
lec <- listenerFunc()
})
}
return lec, nil
@@ -126,7 +125,7 @@ func (app *BootstrapApp) calculateListenerPolicy() []Listener {
return l
}
func (app *BootstrapApp) listenerFromType(listenerType Listener) (func(ctx context.Context) error, error) {
func (app *BootstrapApp) listenerFromType(listenerType Listener) (func() error, error) {
switch listenerType {
case ListenerHTTP:
return app.serveHTTP, nil
@@ -139,7 +138,7 @@ func (app *BootstrapApp) listenerFromType(listenerType Listener) (func(ctx conte
}
}
func (app *BootstrapApp) serveHTTP(ctx context.Context) error {
func (app *BootstrapApp) serveHTTP() error {
address := fmt.Sprintf("%s:%d", app.config.Server.Address, app.config.Server.Port)
app.log.App.Info().Msgf("Starting server on %s", address)
@@ -155,10 +154,10 @@ func (app *BootstrapApp) serveHTTP(ctx context.Context) error {
Handler: app.router.Handler(),
}
return app.serve(listener, server, ctx, "http")
return app.serve(listener, server, "http")
}
func (app *BootstrapApp) serveUnix(ctx context.Context) error {
func (app *BootstrapApp) serveUnix() error {
_, err := os.Stat(app.config.Server.SocketPath)
if err == nil {
@@ -182,10 +181,10 @@ func (app *BootstrapApp) serveUnix(ctx context.Context) error {
Handler: app.router.Handler(),
}
return app.serve(listener, server, ctx, "unix socket")
return app.serve(listener, server, "unix socket")
}
func (app *BootstrapApp) serveTailscale(ctx context.Context) error {
func (app *BootstrapApp) serveTailscale() error {
app.log.App.Info().Msgf("Starting Tailscale server on %s", fmt.Sprintf("https://%s", app.services.tailscaleService.GetHostname()))
listener, err := app.services.tailscaleService.CreateListener()
@@ -198,23 +197,27 @@ func (app *BootstrapApp) serveTailscale(ctx context.Context) error {
Handler: app.router.Handler(),
}
return app.serve(listener, server, ctx, "tailscale")
return app.serve(listener, server, "tailscale")
}
func (app *BootstrapApp) serve(listener net.Listener, server *http.Server, ctx context.Context, name string) error {
func (app *BootstrapApp) serve(listener net.Listener, server *http.Server, name string) error {
shutdown := func() {
// we use a new context for the shutdown since the main one is cancelled
sctx, cancel := context.WithTimeout(context.Background(), model.GracefulShutdownTimeout*time.Second)
ctx, cancel := context.WithTimeout(context.Background(), model.GracefulShutdownTimeout*time.Second)
defer cancel()
err := server.Shutdown(sctx)
if err != nil {
err := server.Shutdown(ctx)
if err != nil &&
// With tailscale, the goroutine for shutting down the tailscale connection
// runs first and causes the connection the tailscale listener is running on to close
// first so, the shutdown fails
// TODO: add priority to the goroutine shutdowns
!errors.Is(err, net.ErrClosed) {
app.log.App.Error().Err(err).Msgf("Failed to shutdown %s listener gracefully", name)
}
listener.Close()
}
go func() {
<-ctx.Done()
<-app.ctx.Done()
app.log.App.Debug().Msgf("Shutting down %s listener", name)
shutdown()
}()
+8 -8
View File
@@ -2,13 +2,14 @@ package bootstrap
import (
"fmt"
"os"
"github.com/tinyauthapp/tinyauth/internal/service"
)
func (app *BootstrapApp) setupServices() error {
ldapService, err := service.NewLdapService(app.log, app.config, app.ding)
ldapService, err := service.NewLdapService(app.log, app.config, app.ctx, &app.wg)
if err != nil {
app.log.App.Warn().Err(err).Msg("Failed to initialize LDAP connection, will continue without it")
@@ -22,7 +23,7 @@ func (app *BootstrapApp) setupServices() error {
return fmt.Errorf("failed to initialize label provider: %w", err)
}
tailscaleService, err := service.NewTailscaleService(app.log, app.config, app.ctx, app.ding)
tailscaleService, err := service.NewTailscaleService(app.log, app.config, app.ctx, &app.wg)
if err != nil {
app.log.App.Warn().Err(err).Msg("Failed to initialize Tailscale connection, will continue without it")
@@ -42,10 +43,10 @@ func (app *BootstrapApp) setupServices() error {
oauthBrokerService := service.NewOAuthBrokerService(app.log, app.runtime.OAuthProviders, app.ctx)
app.services.oauthBrokerService = oauthBrokerService
authService := service.NewAuthService(app.log, app.config, app.runtime, app.ctx, app.ding, app.services.ldapService, app.queries, app.services.oauthBrokerService, app.services.tailscaleService, app.services.policyEngine)
authService := service.NewAuthService(app.log, app.config, app.runtime, app.ctx, &app.wg, app.services.ldapService, app.queries, app.services.oauthBrokerService, app.services.tailscaleService)
app.services.authService = authService
oidcService, err := service.NewOIDCService(app.log, app.config, app.runtime, app.queries, app.ding)
oidcService, err := service.NewOIDCService(app.log, app.config, app.runtime, app.queries, app.ctx, &app.wg)
if err != nil {
return fmt.Errorf("failed to initialize oidc service: %w", err)
@@ -69,7 +70,7 @@ func (app *BootstrapApp) getLabelProvider() (service.LabelProvider, error) {
if useKubernetes {
app.log.App.Debug().Msg("Using Kubernetes label provider")
kubernetesService, err := service.NewKubernetesService(app.log, app.ctx, app.ding)
kubernetesService, err := service.NewKubernetesService(app.log, app.ctx, &app.wg)
if err != nil {
return nil, fmt.Errorf("failed to initialize kubernetes service: %w", err)
@@ -81,7 +82,7 @@ func (app *BootstrapApp) getLabelProvider() (service.LabelProvider, error) {
app.log.App.Debug().Msg("Using Docker label provider")
dockerService, err := service.NewDockerService(app.log, app.ctx, app.ding)
dockerService, err := service.NewDockerService(app.log, app.ctx, &app.wg)
if err != nil {
return nil, fmt.Errorf("failed to initialize docker service: %w", err)
@@ -125,8 +126,7 @@ func (app *BootstrapApp) setupPolicyEngine() error {
Config: app.config,
})
policyEngine.RegisterRule(service.RuleIPBypassed, &service.IPBypassedRule{
Log: app.log,
Config: app.config,
Log: app.log,
})
app.services.policyEngine = policyEngine
+16 -16
View File
@@ -183,23 +183,9 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
return
}
svc, err := controller.auth.GetOAuthService(sessionIdCookie)
if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to get OAuth service for session")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
return
}
if svc.ID() != req.Provider {
controller.log.App.Warn().Msgf("OAuth provider mismatch: expected %s, got %s", req.Provider, svc.ID())
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
return
}
if !controller.auth.IsEmailWhitelisted(svc.ID(), user.Email) {
if !controller.auth.IsEmailWhitelisted(user.Email) {
controller.log.App.Warn().Str("email", user.Email).Msg("Email not whitelisted, denying access")
controller.log.AuditLoginFailure(user.Email, svc.ID(), c.ClientIP(), "email not whitelisted")
controller.log.AuditLoginFailure(user.Email, req.Provider, c.ClientIP(), "email not whitelisted")
queries, err := query.Values(UnauthorizedQuery{
Username: user.Email,
@@ -240,6 +226,20 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
username = strings.Replace(user.Email, "@", "_", 1)
}
svc, err := controller.auth.GetOAuthService(sessionIdCookie)
if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to get OAuth service for session")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
return
}
if svc.ID() != req.Provider {
controller.log.App.Warn().Msgf("OAuth provider mismatch: expected %s, got %s", req.Provider, svc.ID())
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
return
}
sessionCookie := repository.Session{
Username: username,
Name: name,
+74 -110
View File
@@ -1,7 +1,6 @@
package controller
import (
"encoding/json"
"errors"
"fmt"
"net/http"
@@ -13,18 +12,10 @@ import (
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
)
type authorizeErrorParams struct {
err error
reason string
reasonPublic string
callback string
callbackError string
state string
}
type OIDCController struct {
log *logger.Logger
oidc *service.OIDCService
@@ -128,55 +119,34 @@ func (controller *OIDCController) GetClientInfo(c *gin.Context) {
func (controller *OIDCController) Authorize(c *gin.Context) {
if controller.oidc == nil {
controller.authorizeError(c, authorizeErrorParams{
err: errors.New("err_oidc_not_configured"),
reason: "OIDC not configured",
reasonPublic: "This instance is not configured for OIDC",
})
controller.authorizeError(c, errors.New("err_oidc_not_configured"), "OIDC not configured", "This instance is not configured for OIDC", "", "", "")
return
}
userContext, err := new(model.UserContext).NewFromGin(c)
if err != nil {
controller.authorizeError(c, authorizeErrorParams{
err: err,
reason: "Failed to get user context",
reasonPublic: "User is not logged in or the session is invalid",
})
controller.authorizeError(c, err, "Failed to get user context", "User is not logged in or the session is invalid", "", "", "")
return
}
if !userContext.Authenticated {
controller.authorizeError(c, authorizeErrorParams{
err: errors.New("err user not logged in"),
reason: "User not logged in",
reasonPublic: "The user is not logged in",
})
controller.authorizeError(c, errors.New("err user not logged in"), "User not logged in", "The user is not logged in", "", "", "")
return
}
var req service.AuthorizeRequest
err = c.Bind(&req)
err = c.BindJSON(&req)
if err != nil {
controller.authorizeError(c, authorizeErrorParams{
err: err,
reason: "Failed to bind JSON",
reasonPublic: "The client provided an invalid authorization request",
})
controller.authorizeError(c, err, "Failed to bind JSON", "The client provided an invalid authorization request", "", "", "")
return
}
_, ok := controller.oidc.GetClient(req.ClientID)
client, ok := controller.oidc.GetClient(req.ClientID)
if !ok {
controller.authorizeError(c, authorizeErrorParams{
err: fmt.Errorf("client not found: %s", req.ClientID),
reason: "Client not found",
reasonPublic: "The client ID is invalid",
})
controller.authorizeError(c, fmt.Errorf("client not found: %s", req.ClientID), "Client not found", "The client ID is invalid", "", "", "")
return
}
@@ -185,43 +155,41 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
if err != nil {
controller.log.App.Warn().Err(err).Msg("Failed to validate authorize params")
if err.Error() != "invalid_request_uri" {
controller.authorizeError(c, authorizeErrorParams{
err: err,
reason: "Failed validate authorize params",
reasonPublic: "Invalid request parameters",
callback: req.RedirectURI,
callbackError: err.Error(),
state: req.State,
})
controller.authorizeError(c, err, "Failed validate authorize params", "Invalid request parameters", req.RedirectURI, err.Error(), req.State)
return
}
controller.authorizeError(c, authorizeErrorParams{
err: err,
reason: "Redirect URI not trusted",
reasonPublic: "The provided redirect URI is not trusted",
})
controller.authorizeError(c, err, "Redirect URI not trusted", "The provided redirect URI is not trusted", "", "", "")
return
}
// Create the sub to find and delete old sessions
sub := controller.oidc.CreateSub(*userContext, req.ClientID)
// WARNING: Since Tinyauth is stateless, we cannot have a sub that never changes. We will just create a uuid out of the username and client name which remains stable, but if username or client name changes then sub changes too.
sub := utils.GenerateUUID(fmt.Sprintf("%s:%s", userContext.GetUsername(), client.ID))
code := utils.GenerateString(32)
// Before storing the code, delete old session
err = controller.oidc.DeleteOldSession(c, sub)
if err != nil {
controller.authorizeError(c, authorizeErrorParams{
err: err,
reason: "Failed to delete old sessions",
reasonPublic: "Failed to delete old sessions",
callback: req.RedirectURI,
callbackError: "server_error",
state: req.State,
})
controller.authorizeError(c, err, "Failed to delete old sessions", "Failed to delete old sessions", req.RedirectURI, "server_error", req.State)
return
}
// Create the authorization code
code := controller.oidc.CreateCode(req, *userContext)
err = controller.oidc.StoreCode(c, sub, code, req)
if err != nil {
controller.authorizeError(c, err, "Failed to store code", "Failed to store code", req.RedirectURI, "server_error", req.State)
return
}
// We also need a snapshot of the user that authorized this (skip if no openid scope)
if slices.Contains(strings.Fields(req.Scope), "openid") {
err = controller.oidc.StoreUserinfo(c, sub, *userContext, req)
if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to store user info")
controller.authorizeError(c, err, "Failed to store user info", "Failed to store user info", req.RedirectURI, "server_error", req.State)
return
}
}
queries, err := query.Values(AuthorizeCallback{
Code: code,
@@ -229,14 +197,7 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
})
if err != nil {
controller.authorizeError(c, authorizeErrorParams{
err: err,
reason: "Failed to build query",
reasonPublic: "Failed to build query",
callback: req.RedirectURI,
callbackError: "server_error",
state: req.State,
})
controller.authorizeError(c, err, "Failed to build query", "Failed to build query", req.RedirectURI, "server_error", req.State)
return
}
@@ -324,34 +285,39 @@ func (controller *OIDCController) Token(c *gin.Context) {
switch req.GrantType {
case "authorization_code":
entry, ok := controller.oidc.GetCodeEntry(controller.oidc.Hash(req.Code), client.ClientID)
if !ok {
// ensure no code reuse
usedCodeSub, ok := controller.oidc.IsCodeUsed(controller.oidc.Hash(req.Code))
if ok {
controller.log.App.Warn().Msg("Code reuse detected")
err := controller.oidc.DeleteSessionBySub(c, usedCodeSub)
if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to delete session for reused code")
}
entry, err := controller.oidc.GetCodeEntry(c, controller.oidc.Hash(req.Code), client.ClientID)
if err != nil {
if err := controller.oidc.DeleteTokenByCodeHash(c, controller.oidc.Hash(req.Code)); err != nil {
controller.log.App.Error().Err(err).Msg("Failed to revoke tokens for replayed code")
}
if errors.Is(err, service.ErrCodeNotFound) {
controller.log.App.Warn().Msg("Code not found")
c.JSON(400, gin.H{
"error": "invalid_grant",
})
return
}
controller.log.App.Warn().Msg("Code not found")
if errors.Is(err, service.ErrCodeExpired) {
controller.log.App.Warn().Msg("Code expired")
c.JSON(400, gin.H{
"error": "invalid_grant",
})
return
}
if errors.Is(err, service.ErrInvalidClient) {
controller.log.App.Warn().Msg("Code does not belong to client")
c.JSON(400, gin.H{
"error": "invalid_client",
})
return
}
controller.log.App.Error().Err(err).Msg("Failed to get code entry")
c.JSON(400, gin.H{
"error": "invalid_grant",
"error": "server_error",
})
return
}
// mark code as used to prevent reuse
controller.oidc.MarkCodeAsUsed(controller.oidc.Hash(req.Code), entry.Userinfo.Sub)
if entry.RedirectURI != req.RedirectURI {
controller.log.App.Warn().Msg("Redirect URI does not match")
c.JSON(400, gin.H{
@@ -360,7 +326,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
return
}
ok = controller.oidc.ValidatePKCE(entry.CodeChallenge, req.CodeVerifier)
ok := controller.oidc.ValidatePKCE(entry.CodeChallenge, req.CodeVerifier)
if !ok {
controller.log.App.Warn().Msg("PKCE validation failed")
@@ -370,7 +336,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
return
}
tokenRes, err := controller.oidc.GenerateAccessToken(c, client, *entry)
tokenRes, err := controller.oidc.GenerateAccessToken(c, client, entry)
if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to generate access token")
@@ -380,7 +346,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
return
}
tokenResponse = *tokenRes
tokenResponse = tokenRes
case "refresh_token":
tokenRes, err := controller.oidc.RefreshAccessToken(c, req.RefreshToken, creds.ClientID)
@@ -408,7 +374,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
return
}
tokenResponse = *tokenRes
tokenResponse = tokenRes
}
c.Header("cache-control", "no-store")
@@ -472,7 +438,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
return
}
entry, err := controller.oidc.GetSessionByToken(c, controller.oidc.Hash(token))
entry, err := controller.oidc.GetAccessToken(c, controller.oidc.Hash(token))
if err != nil {
if errors.Is(err, service.ErrTokenNotFound) {
@@ -491,17 +457,15 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
}
// If we don't have the openid scope, return an error
if !slices.Contains(strings.Split(entry.Scope, " "), "openid") {
controller.log.App.Warn().Msg("OIDC userinfo accessed with missing openid scope")
if !slices.Contains(strings.Split(entry.Scope, ","), "openid") {
controller.log.App.Warn().Msg("OIDC userinfo accessed with token missing openid scope")
c.JSON(401, gin.H{
"error": "invalid_scope",
})
return
}
var userinfo service.UserinfoResponse
err = json.Unmarshal([]byte(entry.UserinfoJson), &userinfo)
user, err := controller.oidc.GetUserinfo(c, entry.Sub)
if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to get user info")
@@ -511,23 +475,23 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
return
}
c.JSON(200, controller.oidc.CompileUserinfo(userinfo, entry.Scope))
c.JSON(200, controller.oidc.CompileUserinfo(user, entry.Scope))
}
func (controller *OIDCController) authorizeError(c *gin.Context, params authorizeErrorParams) {
controller.log.App.Error().Err(params.err).Str("reason", params.reason).Msg("Authorization error")
func (controller *OIDCController) authorizeError(c *gin.Context, err error, reason string, reasonUser string, callback string, callbackError string, state string) {
controller.log.App.Warn().Err(err).Str("reason", reason).Msg("Authorization error")
if params.callback != "" {
if callback != "" {
errorQueries := CallbackError{
Error: params.callbackError,
Error: callbackError,
}
if params.reasonPublic != "" {
errorQueries.ErrorDescription = params.reasonPublic
if reasonUser != "" {
errorQueries.ErrorDescription = reasonUser
}
if params.state != "" {
errorQueries.State = params.state
if state != "" {
errorQueries.State = state
}
queries, err := query.Values(errorQueries)
@@ -539,13 +503,13 @@ func (controller *OIDCController) authorizeError(c *gin.Context, params authoriz
c.JSON(200, gin.H{
"status": 200,
"redirect_uri": fmt.Sprintf("%s?%s", params.callback, queries.Encode()),
"redirect_uri": fmt.Sprintf("%s?%s", callback, queries.Encode()),
})
return
}
errorQueries := ErrorScreen{
Error: params.reasonPublic,
Error: reasonUser,
}
queries, err := query.Values(errorQueries)
+3 -3
View File
@@ -8,11 +8,11 @@ import (
"net/http/httptest"
"net/url"
"strings"
"sync"
"testing"
"github.com/gin-gonic/gin"
"github.com/google/go-querystring/query"
"github.com/steveiliop56/ding"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/controller"
@@ -840,9 +840,9 @@ func TestOIDCController(t *testing.T) {
store := memory.New()
dg := ding.New(context.TODO())
wg := &sync.WaitGroup{}
oidcService, err := service.NewOIDCService(log, cfg, runtime, store, dg)
oidcService, err := service.NewOIDCService(log, cfg, runtime, store, context.TODO(), wg)
require.NoError(t, err)
for _, test := range tests {
+1 -4
View File
@@ -160,10 +160,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
userContext, err := new(model.UserContext).NewFromGin(c)
if err != nil {
// No user context found is not an issue
if !errors.Is(err, model.ErrUserContextNotFound) {
controller.log.App.Error().Err(err).Msg("Failed to create user context from request, treating as unauthenticated")
}
controller.log.App.Debug().Err(err).Msg("Failed to create user context from request, treating as unauthenticated")
userContext = &model.UserContext{
Authenticated: false,
}
+3 -4
View File
@@ -3,10 +3,10 @@ package controller_test
import (
"context"
"net/http/httptest"
"sync"
"testing"
"github.com/gin-gonic/gin"
"github.com/steveiliop56/ding"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/controller"
@@ -353,10 +353,11 @@ func TestProxyController(t *testing.T) {
store := memory.New()
wg := &sync.WaitGroup{}
ctx := context.TODO()
dg := ding.New(ctx)
broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx)
authService := service.NewAuthService(log, cfg, runtime, ctx, wg, nil, store, broker, nil)
aclsService := service.NewAccessControlsService(log, cfg, nil)
policyEngine, err := service.NewPolicyEngine(cfg, log)
@@ -382,8 +383,6 @@ func TestProxyController(t *testing.T) {
Log: log,
})
authService := service.NewAuthService(log, cfg, runtime, ctx, dg, nil, store, broker, nil, policyEngine)
for _, test := range tests {
t.Run(test.description, func(t *testing.T) {
router := gin.Default()
+4 -7
View File
@@ -6,12 +6,12 @@ import (
"net/http"
"net/http/httptest"
"strings"
"sync"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/pquerna/otp/totp"
"github.com/steveiliop56/ding"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/controller"
@@ -412,17 +412,14 @@ func TestUserController(t *testing.T) {
}
ctx := context.TODO()
dg := ding.New(ctx)
policyEngine, err := service.NewPolicyEngine(cfg, log)
require.NoError(t, err)
wg := &sync.WaitGroup{}
broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx)
authService := service.NewAuthService(log, cfg, runtime, ctx, dg, nil, store, broker, nil, policyEngine)
authService := service.NewAuthService(log, cfg, runtime, ctx, wg, nil, store, broker, nil)
beforeEach := func() {
// Clear failed login attempts before each test
authService.ClearLoginAttempts()
authService.ClearRateLimitsTestingOnly()
}
for _, test := range tests {
@@ -5,10 +5,10 @@ import (
"encoding/json"
"fmt"
"net/http/httptest"
"sync"
"testing"
"github.com/gin-gonic/gin"
"github.com/steveiliop56/ding"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/controller"
@@ -89,11 +89,11 @@ func TestWellKnownController(t *testing.T) {
}
ctx := context.TODO()
dg := ding.New(ctx)
wg := &sync.WaitGroup{}
store := memory.New()
oidcService, err := service.NewOIDCService(log, cfg, runtime, store, dg)
oidcService, err := service.NewOIDCService(log, cfg, runtime, store, ctx, wg)
require.NoError(t, err)
for _, test := range tests {
+6 -5
View File
@@ -205,7 +205,7 @@ func (m *ContextMiddleware) cookieAuth(ctx context.Context, uuid string, ip stri
return nil, nil, fmt.Errorf("oauth provider from session cookie not found: %s", userContext.OAuth.ID)
}
if !m.auth.IsEmailWhitelisted(userContext.OAuth.ID, userContext.OAuth.Email) {
if !m.auth.IsEmailWhitelisted(userContext.OAuth.Email) {
m.auth.DeleteSession(ctx, uuid)
return nil, nil, fmt.Errorf("email from session cookie not whitelisted: %s", userContext.OAuth.Email)
}
@@ -251,10 +251,6 @@ func (m *ContextMiddleware) basicAuth(username string, password string) (*model.
case model.UserLocal:
user := m.auth.GetLocalUser(username)
if user == nil {
return nil, nil, fmt.Errorf("user not found locally: %s", username)
}
if user.TOTPSecret != "" {
return nil, nil, fmt.Errorf("user with totp not allowed to login via basic auth: %s", username)
}
@@ -326,6 +322,11 @@ func (m *ContextMiddleware) tailscaleWhois(ctx context.Context, ip string) (*mod
Name: whois.DisplayName,
},
UserID: whois.UserID,
Tags: whois.Tags,
}
if !strings.ContainsAny(uctx.Email, "@") {
uctx.Email = utils.CompileUserEmail(uctx.Email+"-tailscale", m.runtime.CookieDomain)
}
return &uctx, nil
@@ -5,11 +5,11 @@ import (
"encoding/base64"
"net/http"
"net/http/httptest"
"sync"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/steveiliop56/ding"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/middleware"
@@ -250,20 +250,17 @@ func TestContextMiddleware(t *testing.T) {
}
ctx := context.TODO()
dg := ding.New(ctx)
wg := &sync.WaitGroup{}
store := memory.New()
policyEngine, err := service.NewPolicyEngine(cfg, log)
require.NoError(t, err)
broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx)
authService := service.NewAuthService(log, cfg, runtime, ctx, dg, nil, store, broker, nil, policyEngine)
authService := service.NewAuthService(log, cfg, runtime, ctx, wg, nil, store, broker, nil)
contextMiddleware := middleware.NewContextMiddleware(log, runtime, authService, broker, nil)
for _, test := range tests {
authService.ClearLoginAttempts()
authService.ClearRateLimitsTestingOnly()
t.Run(test.description, func(t *testing.T) {
gin.SetMode(gin.TestMode)
+10 -10
View File
@@ -62,6 +62,9 @@ func NewDefaultConfiguration() *Config {
PrivateKeyPath: "./tinyauth_oidc_key",
PublicKeyPath: "./tinyauth_oidc_key.pub",
},
Experimental: ExperimentalConfig{
ConfigFile: "",
},
Tailscale: TailscaleConfig{
Dir: "./tailscale_state",
},
@@ -85,12 +88,11 @@ type Config struct {
LabelProvider string `description:"Label provider to use for ACLs (auto, docker, kubernetes or none to disable). auto detects the environment." yaml:"labelProvider"`
Log LogConfig `description:"Logging configuration." yaml:"log"`
Tailscale TailscaleConfig `description:"Tailscale configuration." yaml:"tailscale"`
ConfigFile string `description:"Path to config file." yaml:"-"`
}
type DatabaseConfig struct {
Driver string `description:"The database driver to use. Valid values: sqlite, postgres, memory." yaml:"driver"`
Path string `description:"The path to the SQLite database file, or connection URL when driver is postgres." yaml:"path"`
Driver string `description:"The database driver to use. Valid values: sqlite, memory." yaml:"driver"`
Path string `description:"The path to the SQLite database, including file name. Only used when driver is sqlite." yaml:"path"`
}
type AnalyticsConfig struct {
@@ -152,9 +154,8 @@ type AddressClaim struct {
}
type IPConfig struct {
Allow []string `description:"List of allowed IPs or CIDR ranges." yaml:"allow"`
Block []string `description:"List of blocked IPs or CIDR ranges." yaml:"block"`
Bypass []string `description:"List of IPs or CIDR ranges that bypass authentication entirely." yaml:"bypass"`
Allow []string `description:"List of allowed IPs or CIDR ranges." yaml:"allow"`
Block []string `description:"List of blocked IPs or CIDR ranges." yaml:"block"`
}
type OAuthConfig struct {
@@ -206,8 +207,9 @@ type LogStreamConfig struct {
Level string `description:"Log level for this stream. Use global if empty." yaml:"level"`
}
// no experimental features
type ExperimentalConfig struct{}
type ExperimentalConfig struct {
ConfigFile string `description:"Path to config file." yaml:"-"`
}
type TailscaleConfig struct {
Enabled bool `description:"Enable Tailscale integration." yaml:"enabled"`
@@ -223,8 +225,6 @@ type OAuthServiceConfig struct {
ClientID string `description:"OAuth client ID." yaml:"clientId"`
ClientSecret string `description:"OAuth client secret." yaml:"clientSecret"`
ClientSecretFile string `description:"Path to the file containing the OAuth client secret." yaml:"clientSecretFile"`
Whitelist []string `description:"Comma-separated list of allowed OAuth domains for this provider." yaml:"whitelist"`
WhitelistFile string `description:"Path to the OAuth whitelist file for this provider." yaml:"whitelistFile"`
Scopes []string `description:"OAuth scopes." yaml:"scopes"`
RedirectURL string `description:"OAuth redirect URL." yaml:"redirectUrl"`
AuthURL string `description:"OAuth authorization URL." yaml:"authUrl"`
+2
View File
@@ -59,6 +59,8 @@ type LDAPContext struct {
type TailscaleContext struct {
BaseContext
UserID string
// for future use
Tags []string
}
func (c *UserContext) IsAuthenticated() bool {
+288 -104
View File
@@ -101,182 +101,366 @@ func TestMemoryStore(t *testing.T) {
},
},
{
description: "Create and get OIDC session",
description: "Create and get OIDC code",
run: func(t *testing.T, s repository.Store) {
sess, err := s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{
Sub: "sub-1",
AccessTokenHash: "at-1",
RefreshTokenHash: "rt-1",
Scope: "openid",
code, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{
Sub: "sub-1",
CodeHash: "hash-1",
Scope: "openid",
})
require.NoError(t, err)
assert.Equal(t, "sub-1", sess.Sub)
assert.Equal(t, "sub-1", code.Sub)
got, err := s.GetOIDCSessionBySub(ctx, "sub-1")
// destructive read removes the record
got, err := s.GetOidcCode(ctx, "hash-1")
require.NoError(t, err)
assert.Equal(t, sess, got)
},
},
{
description: "Get OIDC session by sub not found",
run: func(t *testing.T, s repository.Store) {
_, err := s.GetOIDCSessionBySub(ctx, "missing")
assert.Equal(t, code, got)
_, err = s.GetOidcCode(ctx, "hash-1")
assert.ErrorIs(t, err, repository.ErrNotFound)
},
},
{
description: "Get OIDC session by access token hash",
description: "Get OIDC code not found",
run: func(t *testing.T, s repository.Store) {
_, err := s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{
_, err := s.GetOidcCode(ctx, "missing")
assert.ErrorIs(t, err, repository.ErrNotFound)
},
},
{
description: "Get OIDC code by sub",
run: func(t *testing.T, s repository.Store) {
_, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"})
require.NoError(t, err)
got, err := s.GetOidcCodeBySub(ctx, "sub-1")
require.NoError(t, err)
assert.Equal(t, "sub-1", got.Sub)
// destructive — gone after read
_, err = s.GetOidcCodeBySub(ctx, "sub-1")
assert.ErrorIs(t, err, repository.ErrNotFound)
},
},
{
description: "Get OIDC code by sub not found",
run: func(t *testing.T, s repository.Store) {
_, err := s.GetOidcCodeBySub(ctx, "missing")
assert.ErrorIs(t, err, repository.ErrNotFound)
},
},
{
description: "Get OIDC code unsafe",
run: func(t *testing.T, s repository.Store) {
_, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"})
require.NoError(t, err)
got, err := s.GetOidcCodeUnsafe(ctx, "hash-1")
require.NoError(t, err)
assert.Equal(t, "sub-1", got.Sub)
// non-destructive — still present
_, err = s.GetOidcCodeUnsafe(ctx, "hash-1")
assert.NoError(t, err)
},
},
{
description: "Get OIDC code unsafe not found",
run: func(t *testing.T, s repository.Store) {
_, err := s.GetOidcCodeUnsafe(ctx, "missing")
assert.ErrorIs(t, err, repository.ErrNotFound)
},
},
{
description: "Get OIDC code by sub unsafe",
run: func(t *testing.T, s repository.Store) {
_, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"})
require.NoError(t, err)
got, err := s.GetOidcCodeBySubUnsafe(ctx, "sub-1")
require.NoError(t, err)
assert.Equal(t, "hash-1", got.CodeHash)
// non-destructive — still present
_, err = s.GetOidcCodeBySubUnsafe(ctx, "sub-1")
assert.NoError(t, err)
},
},
{
description: "Get OIDC code by sub unsafe not found",
run: func(t *testing.T, s repository.Store) {
_, err := s.GetOidcCodeBySubUnsafe(ctx, "missing")
assert.ErrorIs(t, err, repository.ErrNotFound)
},
},
{
description: "Create OIDC code unique sub constraint",
run: func(t *testing.T, s repository.Store) {
_, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"})
require.NoError(t, err)
_, err = s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-2"})
assert.ErrorContains(t, err, "UNIQUE constraint failed: oidc_codes.sub")
},
},
{
description: "Delete OIDC code",
run: func(t *testing.T, s repository.Store) {
_, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"})
require.NoError(t, err)
require.NoError(t, s.DeleteOidcCode(ctx, "hash-1"))
_, err = s.GetOidcCodeUnsafe(ctx, "hash-1")
assert.ErrorIs(t, err, repository.ErrNotFound)
},
},
{
description: "Delete OIDC code by sub",
run: func(t *testing.T, s repository.Store) {
_, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"})
require.NoError(t, err)
require.NoError(t, s.DeleteOidcCodeBySub(ctx, "sub-1"))
_, err = s.GetOidcCodeUnsafe(ctx, "hash-1")
assert.ErrorIs(t, err, repository.ErrNotFound)
},
},
{
description: "Delete expired OIDC codes",
run: func(t *testing.T, s repository.Store) {
_, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1", ExpiresAt: 10})
require.NoError(t, err)
_, err = s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-2", CodeHash: "hash-2", ExpiresAt: 100})
require.NoError(t, err)
deleted, err := s.DeleteExpiredOidcCodes(ctx, 50)
require.NoError(t, err)
require.Len(t, deleted, 1)
assert.Equal(t, "hash-1", deleted[0].CodeHash)
_, err = s.GetOidcCodeUnsafe(ctx, "hash-2")
assert.NoError(t, err)
},
},
{
description: "Create and get OIDC token",
run: func(t *testing.T, s repository.Store) {
tok, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{
Sub: "sub-1",
AccessTokenHash: "at-hash-1",
CodeHash: "code-hash-1",
})
require.NoError(t, err)
assert.Equal(t, "sub-1", tok.Sub)
got, err := s.GetOidcToken(ctx, "at-hash-1")
require.NoError(t, err)
assert.Equal(t, tok, got)
},
},
{
description: "Get OIDC token not found",
run: func(t *testing.T, s repository.Store) {
_, err := s.GetOidcToken(ctx, "missing")
assert.ErrorIs(t, err, repository.ErrNotFound)
},
},
{
description: "Create OIDC token unique sub constraint",
run: func(t *testing.T, s repository.Store) {
_, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{Sub: "sub-1", AccessTokenHash: "at-1"})
require.NoError(t, err)
_, err = s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{Sub: "sub-1", AccessTokenHash: "at-2"})
assert.ErrorContains(t, err, "UNIQUE constraint failed: oidc_tokens.sub")
},
},
{
description: "Get OIDC token by refresh token",
run: func(t *testing.T, s repository.Store) {
_, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{
Sub: "sub-1",
AccessTokenHash: "at-1",
RefreshTokenHash: "rt-1",
})
require.NoError(t, err)
got, err := s.GetOIDCSessionByAccessTokenHash(ctx, "at-1")
got, err := s.GetOidcTokenByRefreshToken(ctx, "rt-1")
require.NoError(t, err)
assert.Equal(t, "sub-1", got.Sub)
},
},
{
description: "Get OIDC session by access token hash not found",
description: "Get OIDC token by refresh token not found",
run: func(t *testing.T, s repository.Store) {
_, err := s.GetOIDCSessionByAccessTokenHash(ctx, "missing")
_, err := s.GetOidcTokenByRefreshToken(ctx, "missing")
assert.ErrorIs(t, err, repository.ErrNotFound)
},
},
{
description: "Get OIDC session by refresh token hash",
description: "Get OIDC token by sub",
run: func(t *testing.T, s repository.Store) {
_, err := s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{
_, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{
Sub: "sub-1",
AccessTokenHash: "at-1",
})
require.NoError(t, err)
got, err := s.GetOidcTokenBySub(ctx, "sub-1")
require.NoError(t, err)
assert.Equal(t, "at-1", got.AccessTokenHash)
},
},
{
description: "Get OIDC token by sub not found",
run: func(t *testing.T, s repository.Store) {
_, err := s.GetOidcTokenBySub(ctx, "missing")
assert.ErrorIs(t, err, repository.ErrNotFound)
},
},
{
description: "Update OIDC token by refresh token",
run: func(t *testing.T, s repository.Store) {
_, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{
Sub: "sub-1",
AccessTokenHash: "at-1",
RefreshTokenHash: "rt-1",
})
require.NoError(t, err)
got, err := s.GetOIDCSessionByRefreshTokenHash(ctx, "rt-1")
require.NoError(t, err)
assert.Equal(t, "sub-1", got.Sub)
},
},
{
description: "Get OIDC session by refresh token hash not found",
run: func(t *testing.T, s repository.Store) {
_, err := s.GetOIDCSessionByRefreshTokenHash(ctx, "missing")
assert.ErrorIs(t, err, repository.ErrNotFound)
},
},
{
description: "Create OIDC session unique sub constraint",
run: func(t *testing.T, s repository.Store) {
_, err := s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{Sub: "sub-1", AccessTokenHash: "at-1", RefreshTokenHash: "rt-1"})
require.NoError(t, err)
_, err = s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{Sub: "sub-1", AccessTokenHash: "at-2", RefreshTokenHash: "rt-2"})
assert.ErrorContains(t, err, "UNIQUE constraint failed: oidc_sessions.sub")
},
},
{
description: "Create OIDC session unique access token hash constraint",
run: func(t *testing.T, s repository.Store) {
_, err := s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{Sub: "sub-1", AccessTokenHash: "at-1", RefreshTokenHash: "rt-1"})
require.NoError(t, err)
_, err = s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{Sub: "sub-2", AccessTokenHash: "at-1", RefreshTokenHash: "rt-2"})
assert.ErrorContains(t, err, "UNIQUE constraint failed: oidc_sessions.access_token_hash")
},
},
{
description: "Create OIDC session unique refresh token hash constraint",
run: func(t *testing.T, s repository.Store) {
_, err := s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{Sub: "sub-1", AccessTokenHash: "at-1", RefreshTokenHash: "rt-1"})
require.NoError(t, err)
_, err = s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{Sub: "sub-2", AccessTokenHash: "at-2", RefreshTokenHash: "rt-1"})
assert.ErrorContains(t, err, "UNIQUE constraint failed: oidc_sessions.refresh_token_hash")
},
},
{
description: "Update OIDC session",
run: func(t *testing.T, s repository.Store) {
_, err := s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{
Sub: "sub-1",
AccessTokenHash: "at-1",
RefreshTokenHash: "rt-1",
})
require.NoError(t, err)
updated, err := s.UpdateOIDCSession(ctx, repository.UpdateOIDCSessionParams{
Sub: "sub-1",
updated, err := s.UpdateOidcTokenByRefreshToken(ctx, repository.UpdateOidcTokenByRefreshTokenParams{
RefreshTokenHash_2: "rt-1",
AccessTokenHash: "at-2",
RefreshTokenHash: "rt-2",
Scope: "openid profile",
TokenExpiresAt: 200,
RefreshTokenExpiresAt: 400,
})
require.NoError(t, err)
assert.Equal(t, "at-2", updated.AccessTokenHash)
assert.Equal(t, "rt-2", updated.RefreshTokenHash)
assert.Equal(t, "openid profile", updated.Scope)
// updated token hashes are now queryable, old ones are gone
got, err := s.GetOIDCSessionByAccessTokenHash(ctx, "at-2")
// old key gone, new key present
_, err = s.GetOidcToken(ctx, "at-1")
assert.ErrorIs(t, err, repository.ErrNotFound)
got, err := s.GetOidcToken(ctx, "at-2")
require.NoError(t, err)
assert.Equal(t, "sub-1", got.Sub)
_, err = s.GetOIDCSessionByAccessTokenHash(ctx, "at-1")
},
},
{
description: "Update OIDC token by refresh token not found",
run: func(t *testing.T, s repository.Store) {
_, err := s.UpdateOidcTokenByRefreshToken(ctx, repository.UpdateOidcTokenByRefreshTokenParams{
RefreshTokenHash_2: "missing",
})
assert.ErrorIs(t, err, repository.ErrNotFound)
},
},
{
description: "Update OIDC session not found",
description: "Delete OIDC token",
run: func(t *testing.T, s repository.Store) {
_, err := s.UpdateOIDCSession(ctx, repository.UpdateOIDCSessionParams{Sub: "missing"})
assert.ErrorIs(t, err, repository.ErrNotFound)
},
},
{
description: "Delete OIDC session by sub",
run: func(t *testing.T, s repository.Store) {
_, err := s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{Sub: "sub-1", AccessTokenHash: "at-1", RefreshTokenHash: "rt-1"})
_, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{Sub: "sub-1", AccessTokenHash: "at-1"})
require.NoError(t, err)
require.NoError(t, s.DeleteOIDCSessionBySub(ctx, "sub-1"))
require.NoError(t, s.DeleteOidcToken(ctx, "at-1"))
_, err = s.GetOIDCSessionBySub(ctx, "sub-1")
_, err = s.GetOidcToken(ctx, "at-1")
assert.ErrorIs(t, err, repository.ErrNotFound)
},
},
{
description: "Delete expired OIDC sessions",
description: "Delete OIDC token by sub",
run: func(t *testing.T, s repository.Store) {
_, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{Sub: "sub-1", AccessTokenHash: "at-1"})
require.NoError(t, err)
require.NoError(t, s.DeleteOidcTokenBySub(ctx, "sub-1"))
_, err = s.GetOidcToken(ctx, "at-1")
assert.ErrorIs(t, err, repository.ErrNotFound)
},
},
{
description: "Delete OIDC token by code hash",
run: func(t *testing.T, s repository.Store) {
_, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{
Sub: "sub-1",
AccessTokenHash: "at-1",
CodeHash: "code-1",
})
require.NoError(t, err)
require.NoError(t, s.DeleteOidcTokenByCodeHash(ctx, "code-1"))
_, err = s.GetOidcToken(ctx, "at-1")
assert.ErrorIs(t, err, repository.ErrNotFound)
},
},
{
description: "Delete expired OIDC tokens",
run: func(t *testing.T, s repository.Store) {
// both expiries past
_, err := s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{
Sub: "sub-1", AccessTokenHash: "at-1", RefreshTokenHash: "rt-1",
_, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{
Sub: "sub-1", AccessTokenHash: "at-1",
TokenExpiresAt: 10, RefreshTokenExpiresAt: 10,
})
require.NoError(t, err)
// valid
_, err = s.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{
Sub: "sub-2", AccessTokenHash: "at-2", RefreshTokenHash: "rt-2",
_, err = s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{
Sub: "sub-3", AccessTokenHash: "at-3",
TokenExpiresAt: 100, RefreshTokenExpiresAt: 100,
})
require.NoError(t, err)
require.NoError(t, s.DeleteExpiredOIDCSessions(ctx, repository.DeleteExpiredOIDCSessionsParams{
deleted, err := s.DeleteExpiredOidcTokens(ctx, repository.DeleteExpiredOidcTokensParams{
TokenExpiresAt: 50,
RefreshTokenExpiresAt: 50,
}))
})
require.NoError(t, err)
assert.Len(t, deleted, 1)
_, err = s.GetOIDCSessionBySub(ctx, "sub-1")
assert.ErrorIs(t, err, repository.ErrNotFound)
_, err = s.GetOIDCSessionBySub(ctx, "sub-2")
_, err = s.GetOidcToken(ctx, "at-3")
assert.NoError(t, err)
},
},
{
description: "Create and get OIDC user info",
run: func(t *testing.T, s repository.Store) {
u, err := s.CreateOidcUserInfo(ctx, repository.CreateOidcUserInfoParams{
Sub: "sub-1",
Name: "Alice",
Email: "alice@example.com",
})
require.NoError(t, err)
assert.Equal(t, "sub-1", u.Sub)
got, err := s.GetOidcUserInfo(ctx, "sub-1")
require.NoError(t, err)
assert.Equal(t, u, got)
},
},
{
description: "Get OIDC user info not found",
run: func(t *testing.T, s repository.Store) {
_, err := s.GetOidcUserInfo(ctx, "missing")
assert.ErrorIs(t, err, repository.ErrNotFound)
},
},
{
description: "Delete OIDC user info",
run: func(t *testing.T, s repository.Store) {
_, err := s.CreateOidcUserInfo(ctx, repository.CreateOidcUserInfoParams{Sub: "sub-1"})
require.NoError(t, err)
require.NoError(t, s.DeleteOidcUserInfo(ctx, "sub-1"))
_, err = s.GetOidcUserInfo(ctx, "sub-1")
assert.ErrorIs(t, err, repository.ErrNotFound)
},
},
}
for _, test := range tests {
+205 -60
View File
@@ -7,90 +7,235 @@ import (
"github.com/tinyauthapp/tinyauth/internal/repository"
)
func (s *Store) CreateOIDCSession(_ context.Context, arg repository.CreateOIDCSessionParams) (repository.OidcSession, error) {
func (s *Store) CreateOidcCode(_ context.Context, arg repository.CreateOidcCodeParams) (repository.OidcCode, error) {
s.mu.Lock()
defer s.mu.Unlock()
// Enforce UNIQUE constraints (sub is the primary key, access/refresh token hashes are unique).
for _, sess := range s.oidcSessions {
switch {
case sess.Sub == arg.Sub:
return repository.OidcSession{}, fmt.Errorf("UNIQUE constraint failed: oidc_sessions.sub")
case sess.AccessTokenHash == arg.AccessTokenHash:
return repository.OidcSession{}, fmt.Errorf("UNIQUE constraint failed: oidc_sessions.access_token_hash")
case sess.RefreshTokenHash == arg.RefreshTokenHash:
return repository.OidcSession{}, fmt.Errorf("UNIQUE constraint failed: oidc_sessions.refresh_token_hash")
// Enforce sub UNIQUE constraint
for _, c := range s.oidcCodes {
if c.Sub == arg.Sub {
return repository.OidcCode{}, fmt.Errorf("UNIQUE constraint failed: oidc_codes.sub")
}
}
sess := repository.OidcSession(arg)
s.oidcSessions[arg.Sub] = sess
return sess, nil
code := repository.OidcCode(arg)
s.oidcCodes[arg.CodeHash] = code
return code, nil
}
func (s *Store) GetOIDCSessionBySub(_ context.Context, sub string) (repository.OidcSession, error) {
s.mu.RLock()
defer s.mu.RUnlock()
sess, ok := s.oidcSessions[sub]
// GetOidcCode is a destructive read: it deletes and returns the code (mirrors SQLite's DELETE...RETURNING).
func (s *Store) GetOidcCode(_ context.Context, codeHash string) (repository.OidcCode, error) {
s.mu.Lock()
defer s.mu.Unlock()
c, ok := s.oidcCodes[codeHash]
if !ok {
return repository.OidcSession{}, repository.ErrNotFound
return repository.OidcCode{}, repository.ErrNotFound
}
return sess, nil
delete(s.oidcCodes, codeHash)
return c, nil
}
func (s *Store) GetOIDCSessionByAccessTokenHash(_ context.Context, accessTokenHash string) (repository.OidcSession, error) {
s.mu.RLock()
defer s.mu.RUnlock()
for _, sess := range s.oidcSessions {
if sess.AccessTokenHash == accessTokenHash {
return sess, nil
}
}
return repository.OidcSession{}, repository.ErrNotFound
}
func (s *Store) GetOIDCSessionByRefreshTokenHash(_ context.Context, refreshTokenHash string) (repository.OidcSession, error) {
s.mu.RLock()
defer s.mu.RUnlock()
for _, sess := range s.oidcSessions {
if sess.RefreshTokenHash == refreshTokenHash {
return sess, nil
}
}
return repository.OidcSession{}, repository.ErrNotFound
}
func (s *Store) UpdateOIDCSession(_ context.Context, arg repository.UpdateOIDCSessionParams) (repository.OidcSession, error) {
// GetOidcCodeBySub is a destructive read: it deletes and returns the code (mirrors SQLite's DELETE...RETURNING).
func (s *Store) GetOidcCodeBySub(_ context.Context, sub string) (repository.OidcCode, error) {
s.mu.Lock()
defer s.mu.Unlock()
sess, ok := s.oidcSessions[arg.Sub]
for k, c := range s.oidcCodes {
if c.Sub == sub {
delete(s.oidcCodes, k)
return c, nil
}
}
return repository.OidcCode{}, repository.ErrNotFound
}
// GetOidcCodeUnsafe is a non-destructive read (mirrors SQLite's SELECT).
func (s *Store) GetOidcCodeUnsafe(_ context.Context, codeHash string) (repository.OidcCode, error) {
s.mu.RLock()
defer s.mu.RUnlock()
c, ok := s.oidcCodes[codeHash]
if !ok {
return repository.OidcSession{}, repository.ErrNotFound
return repository.OidcCode{}, repository.ErrNotFound
}
sess.AccessTokenHash = arg.AccessTokenHash
sess.RefreshTokenHash = arg.RefreshTokenHash
sess.Scope = arg.Scope
sess.ClientID = arg.ClientID
sess.TokenExpiresAt = arg.TokenExpiresAt
sess.RefreshTokenExpiresAt = arg.RefreshTokenExpiresAt
sess.Nonce = arg.Nonce
sess.UserinfoJson = arg.UserinfoJson
s.oidcSessions[arg.Sub] = sess
return sess, nil
return c, nil
}
func (s *Store) DeleteOIDCSessionBySub(_ context.Context, sub string) error {
// GetOidcCodeBySubUnsafe is a non-destructive read (mirrors SQLite's SELECT).
func (s *Store) GetOidcCodeBySubUnsafe(_ context.Context, sub string) (repository.OidcCode, error) {
s.mu.RLock()
defer s.mu.RUnlock()
for _, c := range s.oidcCodes {
if c.Sub == sub {
return c, nil
}
}
return repository.OidcCode{}, repository.ErrNotFound
}
func (s *Store) DeleteOidcCode(_ context.Context, codeHash string) error {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.oidcSessions, sub)
delete(s.oidcCodes, codeHash)
return nil
}
func (s *Store) DeleteExpiredOIDCSessions(_ context.Context, arg repository.DeleteExpiredOIDCSessionsParams) error {
func (s *Store) DeleteOidcCodeBySub(_ context.Context, sub string) error {
s.mu.Lock()
defer s.mu.Unlock()
for k, sess := range s.oidcSessions {
if sess.TokenExpiresAt < arg.TokenExpiresAt && sess.RefreshTokenExpiresAt < arg.RefreshTokenExpiresAt {
delete(s.oidcSessions, k)
for k, c := range s.oidcCodes {
if c.Sub == sub {
delete(s.oidcCodes, k)
}
}
return nil
}
func (s *Store) DeleteExpiredOidcCodes(_ context.Context, expiresAt int64) ([]repository.OidcCode, error) {
s.mu.Lock()
defer s.mu.Unlock()
var deleted []repository.OidcCode
for k, c := range s.oidcCodes {
if c.ExpiresAt < expiresAt {
deleted = append(deleted, c)
delete(s.oidcCodes, k)
}
}
return deleted, nil
}
func (s *Store) CreateOidcToken(_ context.Context, arg repository.CreateOidcTokenParams) (repository.OidcToken, error) {
s.mu.Lock()
defer s.mu.Unlock()
// Enforce sub UNIQUE constraint
for _, t := range s.oidcTokens {
if t.Sub == arg.Sub {
return repository.OidcToken{}, fmt.Errorf("UNIQUE constraint failed: oidc_tokens.sub")
}
}
tok := repository.OidcToken{
Sub: arg.Sub,
AccessTokenHash: arg.AccessTokenHash,
RefreshTokenHash: arg.RefreshTokenHash,
CodeHash: arg.CodeHash,
Scope: arg.Scope,
ClientID: arg.ClientID,
TokenExpiresAt: arg.TokenExpiresAt,
RefreshTokenExpiresAt: arg.RefreshTokenExpiresAt,
Nonce: arg.Nonce,
}
s.oidcTokens[arg.AccessTokenHash] = tok
return tok, nil
}
func (s *Store) GetOidcToken(_ context.Context, accessTokenHash string) (repository.OidcToken, error) {
s.mu.RLock()
defer s.mu.RUnlock()
t, ok := s.oidcTokens[accessTokenHash]
if !ok {
return repository.OidcToken{}, repository.ErrNotFound
}
return t, nil
}
func (s *Store) GetOidcTokenByRefreshToken(_ context.Context, refreshTokenHash string) (repository.OidcToken, error) {
s.mu.RLock()
defer s.mu.RUnlock()
for _, t := range s.oidcTokens {
if t.RefreshTokenHash == refreshTokenHash {
return t, nil
}
}
return repository.OidcToken{}, repository.ErrNotFound
}
func (s *Store) GetOidcTokenBySub(_ context.Context, sub string) (repository.OidcToken, error) {
s.mu.RLock()
defer s.mu.RUnlock()
for _, t := range s.oidcTokens {
if t.Sub == sub {
return t, nil
}
}
return repository.OidcToken{}, repository.ErrNotFound
}
func (s *Store) UpdateOidcTokenByRefreshToken(_ context.Context, arg repository.UpdateOidcTokenByRefreshTokenParams) (repository.OidcToken, error) {
s.mu.Lock()
defer s.mu.Unlock()
for k, t := range s.oidcTokens {
if t.RefreshTokenHash == arg.RefreshTokenHash_2 {
delete(s.oidcTokens, k)
t.AccessTokenHash = arg.AccessTokenHash
t.RefreshTokenHash = arg.RefreshTokenHash
t.TokenExpiresAt = arg.TokenExpiresAt
t.RefreshTokenExpiresAt = arg.RefreshTokenExpiresAt
s.oidcTokens[arg.AccessTokenHash] = t
return t, nil
}
}
return repository.OidcToken{}, repository.ErrNotFound
}
func (s *Store) DeleteOidcToken(_ context.Context, accessTokenHash string) error {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.oidcTokens, accessTokenHash)
return nil
}
func (s *Store) DeleteOidcTokenBySub(_ context.Context, sub string) error {
s.mu.Lock()
defer s.mu.Unlock()
for k, t := range s.oidcTokens {
if t.Sub == sub {
delete(s.oidcTokens, k)
}
}
return nil
}
func (s *Store) DeleteOidcTokenByCodeHash(_ context.Context, codeHash string) error {
s.mu.Lock()
defer s.mu.Unlock()
for k, t := range s.oidcTokens {
if t.CodeHash == codeHash {
delete(s.oidcTokens, k)
}
}
return nil
}
func (s *Store) DeleteExpiredOidcTokens(_ context.Context, arg repository.DeleteExpiredOidcTokensParams) ([]repository.OidcToken, error) {
s.mu.Lock()
defer s.mu.Unlock()
var deleted []repository.OidcToken
for k, t := range s.oidcTokens {
if t.TokenExpiresAt < arg.TokenExpiresAt && t.RefreshTokenExpiresAt < arg.RefreshTokenExpiresAt {
deleted = append(deleted, t)
delete(s.oidcTokens, k)
}
}
return deleted, nil
}
func (s *Store) CreateOidcUserInfo(_ context.Context, arg repository.CreateOidcUserInfoParams) (repository.OidcUserinfo, error) {
s.mu.Lock()
defer s.mu.Unlock()
u := repository.OidcUserinfo(arg)
s.oidcUsers[arg.Sub] = u
return u, nil
}
func (s *Store) GetOidcUserInfo(_ context.Context, sub string) (repository.OidcUserinfo, error) {
s.mu.RLock()
defer s.mu.RUnlock()
u, ok := s.oidcUsers[sub]
if !ok {
return repository.OidcUserinfo{}, repository.ErrNotFound
}
return u, nil
}
func (s *Store) DeleteOidcUserInfo(_ context.Context, sub string) error {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.oidcUsers, sub)
return nil
}
+9 -5
View File
@@ -9,15 +9,19 @@ import (
// Store is a thread-safe in-memory implementation of repository.Store.
type Store struct {
mu sync.RWMutex
sessions map[string]repository.Session
oidcSessions map[string]repository.OidcSession
mu sync.RWMutex
sessions map[string]repository.Session
oidcCodes map[string]repository.OidcCode
oidcTokens map[string]repository.OidcToken
oidcUsers map[string]repository.OidcUserinfo
}
// New returns a new empty in-memory Store.
func New() repository.Store {
return &Store{
sessions: make(map[string]repository.Session),
oidcSessions: make(map[string]repository.OidcSession),
sessions: make(map[string]repository.Session),
oidcCodes: make(map[string]repository.OidcCode),
oidcTokens: make(map[string]repository.OidcToken),
oidcUsers: make(map[string]repository.OidcUserinfo),
}
}
+73 -11
View File
@@ -17,16 +17,49 @@ type Session struct {
OAuthSub string
}
type OidcSession struct {
type OidcCode struct {
Sub string
CodeHash string
Scope string
RedirectURI string
ClientID string
ExpiresAt int64
Nonce string
CodeChallenge string
}
type OidcToken struct {
Sub string
AccessTokenHash string
RefreshTokenHash string
CodeHash string
Scope string
ClientID string
TokenExpiresAt int64
RefreshTokenExpiresAt int64
Nonce string
UserinfoJson string
}
type OidcUserinfo struct {
Sub string
Name string
PreferredUsername string
Email string
Groups string
UpdatedAt int64
GivenName string
FamilyName string
MiddleName string
Nickname string
Profile string
Picture string
Website string
Gender string
Birthdate string
Zoneinfo string
Locale string
PhoneNumber string
Address string
}
type CreateSessionParams struct {
@@ -56,7 +89,18 @@ type UpdateSessionParams struct {
UUID string
}
type CreateOIDCSessionParams struct {
type CreateOidcCodeParams struct {
Sub string
CodeHash string
Scope string
RedirectURI string
ClientID string
ExpiresAt int64
Nonce string
CodeChallenge string
}
type CreateOidcTokenParams struct {
Sub string
AccessTokenHash string
RefreshTokenHash string
@@ -64,23 +108,41 @@ type CreateOIDCSessionParams struct {
ClientID string
TokenExpiresAt int64
RefreshTokenExpiresAt int64
CodeHash string
Nonce string
UserinfoJson string
}
type UpdateOIDCSessionParams struct {
type UpdateOidcTokenByRefreshTokenParams struct {
AccessTokenHash string
RefreshTokenHash string
Scope string
ClientID string
TokenExpiresAt int64
RefreshTokenExpiresAt int64
Nonce string
UserinfoJson string
Sub string
RefreshTokenHash_2 string
}
type DeleteExpiredOIDCSessionsParams struct {
type DeleteExpiredOidcTokensParams struct {
TokenExpiresAt int64
RefreshTokenExpiresAt int64
}
type CreateOidcUserInfoParams struct {
Sub string
Name string
PreferredUsername string
Email string
Groups string
UpdatedAt int64
GivenName string
FamilyName string
MiddleName string
Nickname string
Profile string
Picture string
Website string
Gender string
Birthdate string
Zoneinfo string
Locale string
PhoneNumber string
Address string
}
-31
View File
@@ -1,31 +0,0 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.31.1
package postgres
import (
"context"
"database/sql"
)
type DBTX interface {
ExecContext(context.Context, string, ...interface{}) (sql.Result, error)
PrepareContext(context.Context, string) (*sql.Stmt, error)
QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error)
QueryRowContext(context.Context, string, ...interface{}) *sql.Row
}
func New(db DBTX) *Queries {
return &Queries{db: db}
}
type Queries struct {
db DBTX
}
func (q *Queries) WithTx(tx *sql.Tx) *Queries {
return &Queries{
db: tx,
}
}
-3
View File
@@ -1,3 +0,0 @@
package postgres
//go:generate go run github.com/tinyauthapp/tinyauth/gen/sqlc-wrapper -pkg github.com/tinyauthapp/tinyauth/internal/repository/postgres
-31
View File
@@ -1,31 +0,0 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.31.1
package postgres
type OidcSession struct {
Sub string
AccessTokenHash string
RefreshTokenHash string
Scope string
ClientID string
TokenExpiresAt int64
RefreshTokenExpiresAt int64
Nonce string
UserinfoJson string
}
type Session struct {
UUID string
Username string
Email string
Name string
Provider string
TotpPending bool
OAuthGroups string
Expiry int64
CreatedAt int64
OAuthName string
OAuthSub string
}
@@ -1,210 +0,0 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.31.1
// source: oidc_queries.sql
package postgres
import (
"context"
)
const createOIDCSession = `-- name: CreateOIDCSession :one
INSERT INTO "oidc_sessions" (
"sub",
"access_token_hash",
"refresh_token_hash",
"scope",
"client_id",
"token_expires_at",
"refresh_token_expires_at",
"nonce",
"userinfo_json"
) VALUES (
$1, $2, $3, $4, $5, $6, $7, $8, $9
)
RETURNING sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce, userinfo_json
`
type CreateOIDCSessionParams struct {
Sub string
AccessTokenHash string
RefreshTokenHash string
Scope string
ClientID string
TokenExpiresAt int64
RefreshTokenExpiresAt int64
Nonce string
UserinfoJson string
}
func (q *Queries) CreateOIDCSession(ctx context.Context, arg CreateOIDCSessionParams) (OidcSession, error) {
row := q.db.QueryRowContext(ctx, createOIDCSession,
arg.Sub,
arg.AccessTokenHash,
arg.RefreshTokenHash,
arg.Scope,
arg.ClientID,
arg.TokenExpiresAt,
arg.RefreshTokenExpiresAt,
arg.Nonce,
arg.UserinfoJson,
)
var i OidcSession
err := row.Scan(
&i.Sub,
&i.AccessTokenHash,
&i.RefreshTokenHash,
&i.Scope,
&i.ClientID,
&i.TokenExpiresAt,
&i.RefreshTokenExpiresAt,
&i.Nonce,
&i.UserinfoJson,
)
return i, err
}
const deleteExpiredOIDCSessions = `-- name: DeleteExpiredOIDCSessions :exec
DELETE FROM "oidc_sessions"
WHERE "token_expires_at" < $1 AND "refresh_token_expires_at" < $2
`
type DeleteExpiredOIDCSessionsParams struct {
TokenExpiresAt int64
RefreshTokenExpiresAt int64
}
func (q *Queries) DeleteExpiredOIDCSessions(ctx context.Context, arg DeleteExpiredOIDCSessionsParams) error {
_, err := q.db.ExecContext(ctx, deleteExpiredOIDCSessions, arg.TokenExpiresAt, arg.RefreshTokenExpiresAt)
return err
}
const deleteOIDCSessionBySub = `-- name: DeleteOIDCSessionBySub :exec
DELETE FROM "oidc_sessions"
WHERE "sub" = $1
`
func (q *Queries) DeleteOIDCSessionBySub(ctx context.Context, sub string) error {
_, err := q.db.ExecContext(ctx, deleteOIDCSessionBySub, sub)
return err
}
const getOIDCSessionByAccessTokenHash = `-- name: GetOIDCSessionByAccessTokenHash :one
SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce, userinfo_json FROM "oidc_sessions"
WHERE "access_token_hash" = $1
`
func (q *Queries) GetOIDCSessionByAccessTokenHash(ctx context.Context, accessTokenHash string) (OidcSession, error) {
row := q.db.QueryRowContext(ctx, getOIDCSessionByAccessTokenHash, accessTokenHash)
var i OidcSession
err := row.Scan(
&i.Sub,
&i.AccessTokenHash,
&i.RefreshTokenHash,
&i.Scope,
&i.ClientID,
&i.TokenExpiresAt,
&i.RefreshTokenExpiresAt,
&i.Nonce,
&i.UserinfoJson,
)
return i, err
}
const getOIDCSessionByRefreshTokenHash = `-- name: GetOIDCSessionByRefreshTokenHash :one
SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce, userinfo_json FROM "oidc_sessions"
WHERE "refresh_token_hash" = $1
`
func (q *Queries) GetOIDCSessionByRefreshTokenHash(ctx context.Context, refreshTokenHash string) (OidcSession, error) {
row := q.db.QueryRowContext(ctx, getOIDCSessionByRefreshTokenHash, refreshTokenHash)
var i OidcSession
err := row.Scan(
&i.Sub,
&i.AccessTokenHash,
&i.RefreshTokenHash,
&i.Scope,
&i.ClientID,
&i.TokenExpiresAt,
&i.RefreshTokenExpiresAt,
&i.Nonce,
&i.UserinfoJson,
)
return i, err
}
const getOIDCSessionBySub = `-- name: GetOIDCSessionBySub :one
SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce, userinfo_json FROM "oidc_sessions"
WHERE "sub" = $1
`
func (q *Queries) GetOIDCSessionBySub(ctx context.Context, sub string) (OidcSession, error) {
row := q.db.QueryRowContext(ctx, getOIDCSessionBySub, sub)
var i OidcSession
err := row.Scan(
&i.Sub,
&i.AccessTokenHash,
&i.RefreshTokenHash,
&i.Scope,
&i.ClientID,
&i.TokenExpiresAt,
&i.RefreshTokenExpiresAt,
&i.Nonce,
&i.UserinfoJson,
)
return i, err
}
const updateOIDCSession = `-- name: UpdateOIDCSession :one
UPDATE "oidc_sessions" SET
"access_token_hash" = $1,
"refresh_token_hash" = $2,
"scope" = $3,
"client_id" = $4,
"token_expires_at" = $5,
"refresh_token_expires_at" = $6,
"nonce" = $7,
"userinfo_json" = $8
WHERE "sub" = $9
RETURNING sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce, userinfo_json
`
type UpdateOIDCSessionParams struct {
AccessTokenHash string
RefreshTokenHash string
Scope string
ClientID string
TokenExpiresAt int64
RefreshTokenExpiresAt int64
Nonce string
UserinfoJson string
Sub string
}
func (q *Queries) UpdateOIDCSession(ctx context.Context, arg UpdateOIDCSessionParams) (OidcSession, error) {
row := q.db.QueryRowContext(ctx, updateOIDCSession,
arg.AccessTokenHash,
arg.RefreshTokenHash,
arg.Scope,
arg.ClientID,
arg.TokenExpiresAt,
arg.RefreshTokenExpiresAt,
arg.Nonce,
arg.UserinfoJson,
arg.Sub,
)
var i OidcSession
err := row.Scan(
&i.Sub,
&i.AccessTokenHash,
&i.RefreshTokenHash,
&i.Scope,
&i.ClientID,
&i.TokenExpiresAt,
&i.RefreshTokenExpiresAt,
&i.Nonce,
&i.UserinfoJson,
)
return i, err
}
@@ -1,176 +0,0 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.31.1
// source: session_queries.sql
package postgres
import (
"context"
)
const createSession = `-- name: CreateSession :one
INSERT INTO "sessions" (
"uuid",
"username",
"email",
"name",
"provider",
"totp_pending",
"oauth_groups",
"expiry",
"created_at",
"oauth_name",
"oauth_sub"
) VALUES (
$1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11
)
RETURNING uuid, username, email, name, provider, totp_pending, oauth_groups, expiry, created_at, oauth_name, oauth_sub
`
type CreateSessionParams struct {
UUID string
Username string
Email string
Name string
Provider string
TotpPending bool
OAuthGroups string
Expiry int64
CreatedAt int64
OAuthName string
OAuthSub string
}
func (q *Queries) CreateSession(ctx context.Context, arg CreateSessionParams) (Session, error) {
row := q.db.QueryRowContext(ctx, createSession,
arg.UUID,
arg.Username,
arg.Email,
arg.Name,
arg.Provider,
arg.TotpPending,
arg.OAuthGroups,
arg.Expiry,
arg.CreatedAt,
arg.OAuthName,
arg.OAuthSub,
)
var i Session
err := row.Scan(
&i.UUID,
&i.Username,
&i.Email,
&i.Name,
&i.Provider,
&i.TotpPending,
&i.OAuthGroups,
&i.Expiry,
&i.CreatedAt,
&i.OAuthName,
&i.OAuthSub,
)
return i, err
}
const deleteExpiredSessions = `-- name: DeleteExpiredSessions :exec
DELETE FROM "sessions"
WHERE "expiry" < $1
`
func (q *Queries) DeleteExpiredSessions(ctx context.Context, expiry int64) error {
_, err := q.db.ExecContext(ctx, deleteExpiredSessions, expiry)
return err
}
const deleteSession = `-- name: DeleteSession :exec
DELETE FROM "sessions"
WHERE "uuid" = $1
`
func (q *Queries) DeleteSession(ctx context.Context, uuid string) error {
_, err := q.db.ExecContext(ctx, deleteSession, uuid)
return err
}
const getSession = `-- name: GetSession :one
SELECT uuid, username, email, name, provider, totp_pending, oauth_groups, expiry, created_at, oauth_name, oauth_sub FROM "sessions"
WHERE "uuid" = $1
`
func (q *Queries) GetSession(ctx context.Context, uuid string) (Session, error) {
row := q.db.QueryRowContext(ctx, getSession, uuid)
var i Session
err := row.Scan(
&i.UUID,
&i.Username,
&i.Email,
&i.Name,
&i.Provider,
&i.TotpPending,
&i.OAuthGroups,
&i.Expiry,
&i.CreatedAt,
&i.OAuthName,
&i.OAuthSub,
)
return i, err
}
const updateSession = `-- name: UpdateSession :one
UPDATE "sessions" SET
"username" = $1,
"email" = $2,
"name" = $3,
"provider" = $4,
"totp_pending" = $5,
"oauth_groups" = $6,
"expiry" = $7,
"oauth_name" = $8,
"oauth_sub" = $9
WHERE "uuid" = $10
RETURNING uuid, username, email, name, provider, totp_pending, oauth_groups, expiry, created_at, oauth_name, oauth_sub
`
type UpdateSessionParams struct {
Username string
Email string
Name string
Provider string
TotpPending bool
OAuthGroups string
Expiry int64
OAuthName string
OAuthSub string
UUID string
}
func (q *Queries) UpdateSession(ctx context.Context, arg UpdateSessionParams) (Session, error) {
row := q.db.QueryRowContext(ctx, updateSession,
arg.Username,
arg.Email,
arg.Name,
arg.Provider,
arg.TotpPending,
arg.OAuthGroups,
arg.Expiry,
arg.OAuthName,
arg.OAuthSub,
arg.UUID,
)
var i Session
err := row.Scan(
&i.UUID,
&i.Username,
&i.Email,
&i.Name,
&i.Provider,
&i.TotpPending,
&i.OAuthGroups,
&i.Expiry,
&i.CreatedAt,
&i.OAuthName,
&i.OAuthSub,
)
return i, err
}
-113
View File
@@ -1,113 +0,0 @@
// Code generated by cmd/gen/sqlc-wrapper. DO NOT EDIT.
package postgres
import (
"context"
"database/sql"
"errors"
"github.com/tinyauthapp/tinyauth/internal/repository"
)
// Store wraps *Queries and implements repository.Store.
type Store struct {
q *Queries
}
// NewStore wraps a *Queries to satisfy repository.Store.
func NewStore(q *Queries) repository.Store {
return &Store{q: q}
}
var errorMap = map[error]error{
sql.ErrNoRows: repository.ErrNotFound,
}
func mapErr(err error) error {
for from, to := range errorMap {
if errors.Is(err, from) {
return to
}
}
return err
}
func (s *Store) CreateOIDCSession(ctx context.Context, arg repository.CreateOIDCSessionParams) (repository.OidcSession, error) {
r, err := s.q.CreateOIDCSession(ctx, CreateOIDCSessionParams(arg))
if err != nil {
return repository.OidcSession{}, mapErr(err)
}
return repository.OidcSession(r), nil
}
func (s *Store) CreateSession(ctx context.Context, arg repository.CreateSessionParams) (repository.Session, error) {
r, err := s.q.CreateSession(ctx, CreateSessionParams(arg))
if err != nil {
return repository.Session{}, mapErr(err)
}
return repository.Session(r), nil
}
func (s *Store) DeleteExpiredOIDCSessions(ctx context.Context, arg repository.DeleteExpiredOIDCSessionsParams) error {
return mapErr(s.q.DeleteExpiredOIDCSessions(ctx, DeleteExpiredOIDCSessionsParams(arg)))
}
func (s *Store) DeleteExpiredSessions(ctx context.Context, expiry int64) error {
return mapErr(s.q.DeleteExpiredSessions(ctx, expiry))
}
func (s *Store) DeleteOIDCSessionBySub(ctx context.Context, sub string) error {
return mapErr(s.q.DeleteOIDCSessionBySub(ctx, sub))
}
func (s *Store) DeleteSession(ctx context.Context, uuid string) error {
return mapErr(s.q.DeleteSession(ctx, uuid))
}
func (s *Store) GetOIDCSessionByAccessTokenHash(ctx context.Context, accessTokenHash string) (repository.OidcSession, error) {
r, err := s.q.GetOIDCSessionByAccessTokenHash(ctx, accessTokenHash)
if err != nil {
return repository.OidcSession{}, mapErr(err)
}
return repository.OidcSession(r), nil
}
func (s *Store) GetOIDCSessionByRefreshTokenHash(ctx context.Context, refreshTokenHash string) (repository.OidcSession, error) {
r, err := s.q.GetOIDCSessionByRefreshTokenHash(ctx, refreshTokenHash)
if err != nil {
return repository.OidcSession{}, mapErr(err)
}
return repository.OidcSession(r), nil
}
func (s *Store) GetOIDCSessionBySub(ctx context.Context, sub string) (repository.OidcSession, error) {
r, err := s.q.GetOIDCSessionBySub(ctx, sub)
if err != nil {
return repository.OidcSession{}, mapErr(err)
}
return repository.OidcSession(r), nil
}
func (s *Store) GetSession(ctx context.Context, uuid string) (repository.Session, error) {
r, err := s.q.GetSession(ctx, uuid)
if err != nil {
return repository.Session{}, mapErr(err)
}
return repository.Session(r), nil
}
func (s *Store) UpdateOIDCSession(ctx context.Context, arg repository.UpdateOIDCSessionParams) (repository.OidcSession, error) {
r, err := s.q.UpdateOIDCSession(ctx, UpdateOIDCSessionParams(arg))
if err != nil {
return repository.OidcSession{}, mapErr(err)
}
return repository.OidcSession(r), nil
}
func (s *Store) UpdateSession(ctx context.Context, arg repository.UpdateSessionParams) (repository.Session, error) {
r, err := s.q.UpdateSession(ctx, UpdateSessionParams(arg))
if err != nil {
return repository.Session{}, mapErr(err)
}
return repository.Session(r), nil
}
+35 -2
View File
@@ -4,16 +4,49 @@
package sqlite
type OidcSession struct {
type OidcCode struct {
Sub string
CodeHash string
Scope string
RedirectURI string
ClientID string
ExpiresAt int64
Nonce string
CodeChallenge string
}
type OidcToken struct {
Sub string
AccessTokenHash string
RefreshTokenHash string
CodeHash string
Scope string
ClientID string
TokenExpiresAt int64
RefreshTokenExpiresAt int64
Nonce string
UserinfoJson string
}
type OidcUserinfo struct {
Sub string
Name string
PreferredUsername string
Email string
Groups string
UpdatedAt int64
GivenName string
FamilyName string
MiddleName string
Nickname string
Profile string
Picture string
Website string
Gender string
Birthdate string
Zoneinfo string
Locale string
PhoneNumber string
Address string
}
type Session struct {
+435 -64
View File
@@ -9,8 +9,60 @@ import (
"context"
)
const createOIDCSession = `-- name: CreateOIDCSession :one
INSERT INTO "oidc_sessions" (
const createOidcCode = `-- name: CreateOidcCode :one
INSERT INTO "oidc_codes" (
"sub",
"code_hash",
"scope",
"redirect_uri",
"client_id",
"expires_at",
"nonce",
"code_challenge"
) VALUES (
?, ?, ?, ?, ?, ?, ?, ?
)
RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge
`
type CreateOidcCodeParams struct {
Sub string
CodeHash string
Scope string
RedirectURI string
ClientID string
ExpiresAt int64
Nonce string
CodeChallenge string
}
func (q *Queries) CreateOidcCode(ctx context.Context, arg CreateOidcCodeParams) (OidcCode, error) {
row := q.db.QueryRowContext(ctx, createOidcCode,
arg.Sub,
arg.CodeHash,
arg.Scope,
arg.RedirectURI,
arg.ClientID,
arg.ExpiresAt,
arg.Nonce,
arg.CodeChallenge,
)
var i OidcCode
err := row.Scan(
&i.Sub,
&i.CodeHash,
&i.Scope,
&i.RedirectURI,
&i.ClientID,
&i.ExpiresAt,
&i.Nonce,
&i.CodeChallenge,
)
return i, err
}
const createOidcToken = `-- name: CreateOidcToken :one
INSERT INTO "oidc_tokens" (
"sub",
"access_token_hash",
"refresh_token_hash",
@@ -18,15 +70,15 @@ INSERT INTO "oidc_sessions" (
"client_id",
"token_expires_at",
"refresh_token_expires_at",
"nonce",
"userinfo_json"
"code_hash",
"nonce"
) VALUES (
?, ?, ?, ?, ?, ?, ?, ?, ?
)
RETURNING sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce, userinfo_json
RETURNING sub, access_token_hash, refresh_token_hash, code_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce
`
type CreateOIDCSessionParams struct {
type CreateOidcTokenParams struct {
Sub string
AccessTokenHash string
RefreshTokenHash string
@@ -34,12 +86,12 @@ type CreateOIDCSessionParams struct {
ClientID string
TokenExpiresAt int64
RefreshTokenExpiresAt int64
CodeHash string
Nonce string
UserinfoJson string
}
func (q *Queries) CreateOIDCSession(ctx context.Context, arg CreateOIDCSessionParams) (OidcSession, error) {
row := q.db.QueryRowContext(ctx, createOIDCSession,
func (q *Queries) CreateOidcToken(ctx context.Context, arg CreateOidcTokenParams) (OidcToken, error) {
row := q.db.QueryRowContext(ctx, createOidcToken,
arg.Sub,
arg.AccessTokenHash,
arg.RefreshTokenHash,
@@ -47,164 +99,483 @@ func (q *Queries) CreateOIDCSession(ctx context.Context, arg CreateOIDCSessionPa
arg.ClientID,
arg.TokenExpiresAt,
arg.RefreshTokenExpiresAt,
arg.CodeHash,
arg.Nonce,
arg.UserinfoJson,
)
var i OidcSession
var i OidcToken
err := row.Scan(
&i.Sub,
&i.AccessTokenHash,
&i.RefreshTokenHash,
&i.CodeHash,
&i.Scope,
&i.ClientID,
&i.TokenExpiresAt,
&i.RefreshTokenExpiresAt,
&i.Nonce,
&i.UserinfoJson,
)
return i, err
}
const deleteExpiredOIDCSessions = `-- name: DeleteExpiredOIDCSessions :exec
DELETE FROM "oidc_sessions"
WHERE "token_expires_at" < ? AND "refresh_token_expires_at" < ?
const createOidcUserInfo = `-- name: CreateOidcUserInfo :one
INSERT INTO "oidc_userinfo" (
"sub",
"name",
"preferred_username",
"email",
"groups",
"updated_at",
"given_name",
"family_name",
"middle_name",
"nickname",
"profile",
"picture",
"website",
"gender",
"birthdate",
"zoneinfo",
"locale",
"phone_number",
"address"
) VALUES (
?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?
)
RETURNING sub, name, preferred_username, email, "groups", updated_at, given_name, family_name, middle_name, nickname, profile, picture, website, gender, birthdate, zoneinfo, locale, phone_number, address
`
type DeleteExpiredOIDCSessionsParams struct {
type CreateOidcUserInfoParams struct {
Sub string
Name string
PreferredUsername string
Email string
Groups string
UpdatedAt int64
GivenName string
FamilyName string
MiddleName string
Nickname string
Profile string
Picture string
Website string
Gender string
Birthdate string
Zoneinfo string
Locale string
PhoneNumber string
Address string
}
func (q *Queries) CreateOidcUserInfo(ctx context.Context, arg CreateOidcUserInfoParams) (OidcUserinfo, error) {
row := q.db.QueryRowContext(ctx, createOidcUserInfo,
arg.Sub,
arg.Name,
arg.PreferredUsername,
arg.Email,
arg.Groups,
arg.UpdatedAt,
arg.GivenName,
arg.FamilyName,
arg.MiddleName,
arg.Nickname,
arg.Profile,
arg.Picture,
arg.Website,
arg.Gender,
arg.Birthdate,
arg.Zoneinfo,
arg.Locale,
arg.PhoneNumber,
arg.Address,
)
var i OidcUserinfo
err := row.Scan(
&i.Sub,
&i.Name,
&i.PreferredUsername,
&i.Email,
&i.Groups,
&i.UpdatedAt,
&i.GivenName,
&i.FamilyName,
&i.MiddleName,
&i.Nickname,
&i.Profile,
&i.Picture,
&i.Website,
&i.Gender,
&i.Birthdate,
&i.Zoneinfo,
&i.Locale,
&i.PhoneNumber,
&i.Address,
)
return i, err
}
const deleteExpiredOidcCodes = `-- name: DeleteExpiredOidcCodes :many
DELETE FROM "oidc_codes"
WHERE "expires_at" < ?
RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge
`
func (q *Queries) DeleteExpiredOidcCodes(ctx context.Context, expiresAt int64) ([]OidcCode, error) {
rows, err := q.db.QueryContext(ctx, deleteExpiredOidcCodes, expiresAt)
if err != nil {
return nil, err
}
defer rows.Close()
var items []OidcCode
for rows.Next() {
var i OidcCode
if err := rows.Scan(
&i.Sub,
&i.CodeHash,
&i.Scope,
&i.RedirectURI,
&i.ClientID,
&i.ExpiresAt,
&i.Nonce,
&i.CodeChallenge,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const deleteExpiredOidcTokens = `-- name: DeleteExpiredOidcTokens :many
DELETE FROM "oidc_tokens"
WHERE "token_expires_at" < ? AND "refresh_token_expires_at" < ?
RETURNING sub, access_token_hash, refresh_token_hash, code_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce
`
type DeleteExpiredOidcTokensParams struct {
TokenExpiresAt int64
RefreshTokenExpiresAt int64
}
func (q *Queries) DeleteExpiredOIDCSessions(ctx context.Context, arg DeleteExpiredOIDCSessionsParams) error {
_, err := q.db.ExecContext(ctx, deleteExpiredOIDCSessions, arg.TokenExpiresAt, arg.RefreshTokenExpiresAt)
func (q *Queries) DeleteExpiredOidcTokens(ctx context.Context, arg DeleteExpiredOidcTokensParams) ([]OidcToken, error) {
rows, err := q.db.QueryContext(ctx, deleteExpiredOidcTokens, arg.TokenExpiresAt, arg.RefreshTokenExpiresAt)
if err != nil {
return nil, err
}
defer rows.Close()
var items []OidcToken
for rows.Next() {
var i OidcToken
if err := rows.Scan(
&i.Sub,
&i.AccessTokenHash,
&i.RefreshTokenHash,
&i.CodeHash,
&i.Scope,
&i.ClientID,
&i.TokenExpiresAt,
&i.RefreshTokenExpiresAt,
&i.Nonce,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const deleteOidcCode = `-- name: DeleteOidcCode :exec
DELETE FROM "oidc_codes"
WHERE "code_hash" = ?
`
func (q *Queries) DeleteOidcCode(ctx context.Context, codeHash string) error {
_, err := q.db.ExecContext(ctx, deleteOidcCode, codeHash)
return err
}
const deleteOIDCSessionBySub = `-- name: DeleteOIDCSessionBySub :exec
DELETE FROM "oidc_sessions"
const deleteOidcCodeBySub = `-- name: DeleteOidcCodeBySub :exec
DELETE FROM "oidc_codes"
WHERE "sub" = ?
`
func (q *Queries) DeleteOIDCSessionBySub(ctx context.Context, sub string) error {
_, err := q.db.ExecContext(ctx, deleteOIDCSessionBySub, sub)
func (q *Queries) DeleteOidcCodeBySub(ctx context.Context, sub string) error {
_, err := q.db.ExecContext(ctx, deleteOidcCodeBySub, sub)
return err
}
const getOIDCSessionByAccessTokenHash = `-- name: GetOIDCSessionByAccessTokenHash :one
SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce, userinfo_json FROM "oidc_sessions"
const deleteOidcToken = `-- name: DeleteOidcToken :exec
DELETE FROM "oidc_tokens"
WHERE "access_token_hash" = ?
`
func (q *Queries) GetOIDCSessionByAccessTokenHash(ctx context.Context, accessTokenHash string) (OidcSession, error) {
row := q.db.QueryRowContext(ctx, getOIDCSessionByAccessTokenHash, accessTokenHash)
var i OidcSession
func (q *Queries) DeleteOidcToken(ctx context.Context, accessTokenHash string) error {
_, err := q.db.ExecContext(ctx, deleteOidcToken, accessTokenHash)
return err
}
const deleteOidcTokenByCodeHash = `-- name: DeleteOidcTokenByCodeHash :exec
DELETE FROM "oidc_tokens"
WHERE "code_hash" = ?
`
func (q *Queries) DeleteOidcTokenByCodeHash(ctx context.Context, codeHash string) error {
_, err := q.db.ExecContext(ctx, deleteOidcTokenByCodeHash, codeHash)
return err
}
const deleteOidcTokenBySub = `-- name: DeleteOidcTokenBySub :exec
DELETE FROM "oidc_tokens"
WHERE "sub" = ?
`
func (q *Queries) DeleteOidcTokenBySub(ctx context.Context, sub string) error {
_, err := q.db.ExecContext(ctx, deleteOidcTokenBySub, sub)
return err
}
const deleteOidcUserInfo = `-- name: DeleteOidcUserInfo :exec
DELETE FROM "oidc_userinfo"
WHERE "sub" = ?
`
func (q *Queries) DeleteOidcUserInfo(ctx context.Context, sub string) error {
_, err := q.db.ExecContext(ctx, deleteOidcUserInfo, sub)
return err
}
const getOidcCode = `-- name: GetOidcCode :one
DELETE FROM "oidc_codes"
WHERE "code_hash" = ?
RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge
`
func (q *Queries) GetOidcCode(ctx context.Context, codeHash string) (OidcCode, error) {
row := q.db.QueryRowContext(ctx, getOidcCode, codeHash)
var i OidcCode
err := row.Scan(
&i.Sub,
&i.CodeHash,
&i.Scope,
&i.RedirectURI,
&i.ClientID,
&i.ExpiresAt,
&i.Nonce,
&i.CodeChallenge,
)
return i, err
}
const getOidcCodeBySub = `-- name: GetOidcCodeBySub :one
DELETE FROM "oidc_codes"
WHERE "sub" = ?
RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge
`
func (q *Queries) GetOidcCodeBySub(ctx context.Context, sub string) (OidcCode, error) {
row := q.db.QueryRowContext(ctx, getOidcCodeBySub, sub)
var i OidcCode
err := row.Scan(
&i.Sub,
&i.CodeHash,
&i.Scope,
&i.RedirectURI,
&i.ClientID,
&i.ExpiresAt,
&i.Nonce,
&i.CodeChallenge,
)
return i, err
}
const getOidcCodeBySubUnsafe = `-- name: GetOidcCodeBySubUnsafe :one
SELECT sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge FROM "oidc_codes"
WHERE "sub" = ?
`
func (q *Queries) GetOidcCodeBySubUnsafe(ctx context.Context, sub string) (OidcCode, error) {
row := q.db.QueryRowContext(ctx, getOidcCodeBySubUnsafe, sub)
var i OidcCode
err := row.Scan(
&i.Sub,
&i.CodeHash,
&i.Scope,
&i.RedirectURI,
&i.ClientID,
&i.ExpiresAt,
&i.Nonce,
&i.CodeChallenge,
)
return i, err
}
const getOidcCodeUnsafe = `-- name: GetOidcCodeUnsafe :one
SELECT sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge FROM "oidc_codes"
WHERE "code_hash" = ?
`
func (q *Queries) GetOidcCodeUnsafe(ctx context.Context, codeHash string) (OidcCode, error) {
row := q.db.QueryRowContext(ctx, getOidcCodeUnsafe, codeHash)
var i OidcCode
err := row.Scan(
&i.Sub,
&i.CodeHash,
&i.Scope,
&i.RedirectURI,
&i.ClientID,
&i.ExpiresAt,
&i.Nonce,
&i.CodeChallenge,
)
return i, err
}
const getOidcToken = `-- name: GetOidcToken :one
SELECT sub, access_token_hash, refresh_token_hash, code_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce FROM "oidc_tokens"
WHERE "access_token_hash" = ?
`
func (q *Queries) GetOidcToken(ctx context.Context, accessTokenHash string) (OidcToken, error) {
row := q.db.QueryRowContext(ctx, getOidcToken, accessTokenHash)
var i OidcToken
err := row.Scan(
&i.Sub,
&i.AccessTokenHash,
&i.RefreshTokenHash,
&i.CodeHash,
&i.Scope,
&i.ClientID,
&i.TokenExpiresAt,
&i.RefreshTokenExpiresAt,
&i.Nonce,
&i.UserinfoJson,
)
return i, err
}
const getOIDCSessionByRefreshTokenHash = `-- name: GetOIDCSessionByRefreshTokenHash :one
SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce, userinfo_json FROM "oidc_sessions"
const getOidcTokenByRefreshToken = `-- name: GetOidcTokenByRefreshToken :one
SELECT sub, access_token_hash, refresh_token_hash, code_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce FROM "oidc_tokens"
WHERE "refresh_token_hash" = ?
`
func (q *Queries) GetOIDCSessionByRefreshTokenHash(ctx context.Context, refreshTokenHash string) (OidcSession, error) {
row := q.db.QueryRowContext(ctx, getOIDCSessionByRefreshTokenHash, refreshTokenHash)
var i OidcSession
func (q *Queries) GetOidcTokenByRefreshToken(ctx context.Context, refreshTokenHash string) (OidcToken, error) {
row := q.db.QueryRowContext(ctx, getOidcTokenByRefreshToken, refreshTokenHash)
var i OidcToken
err := row.Scan(
&i.Sub,
&i.AccessTokenHash,
&i.RefreshTokenHash,
&i.CodeHash,
&i.Scope,
&i.ClientID,
&i.TokenExpiresAt,
&i.RefreshTokenExpiresAt,
&i.Nonce,
&i.UserinfoJson,
)
return i, err
}
const getOIDCSessionBySub = `-- name: GetOIDCSessionBySub :one
SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce, userinfo_json FROM "oidc_sessions"
const getOidcTokenBySub = `-- name: GetOidcTokenBySub :one
SELECT sub, access_token_hash, refresh_token_hash, code_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce FROM "oidc_tokens"
WHERE "sub" = ?
`
func (q *Queries) GetOIDCSessionBySub(ctx context.Context, sub string) (OidcSession, error) {
row := q.db.QueryRowContext(ctx, getOIDCSessionBySub, sub)
var i OidcSession
func (q *Queries) GetOidcTokenBySub(ctx context.Context, sub string) (OidcToken, error) {
row := q.db.QueryRowContext(ctx, getOidcTokenBySub, sub)
var i OidcToken
err := row.Scan(
&i.Sub,
&i.AccessTokenHash,
&i.RefreshTokenHash,
&i.CodeHash,
&i.Scope,
&i.ClientID,
&i.TokenExpiresAt,
&i.RefreshTokenExpiresAt,
&i.Nonce,
&i.UserinfoJson,
)
return i, err
}
const updateOIDCSession = `-- name: UpdateOIDCSession :one
UPDATE "oidc_sessions" SET
const getOidcUserInfo = `-- name: GetOidcUserInfo :one
SELECT sub, name, preferred_username, email, "groups", updated_at, given_name, family_name, middle_name, nickname, profile, picture, website, gender, birthdate, zoneinfo, locale, phone_number, address FROM "oidc_userinfo"
WHERE "sub" = ?
`
func (q *Queries) GetOidcUserInfo(ctx context.Context, sub string) (OidcUserinfo, error) {
row := q.db.QueryRowContext(ctx, getOidcUserInfo, sub)
var i OidcUserinfo
err := row.Scan(
&i.Sub,
&i.Name,
&i.PreferredUsername,
&i.Email,
&i.Groups,
&i.UpdatedAt,
&i.GivenName,
&i.FamilyName,
&i.MiddleName,
&i.Nickname,
&i.Profile,
&i.Picture,
&i.Website,
&i.Gender,
&i.Birthdate,
&i.Zoneinfo,
&i.Locale,
&i.PhoneNumber,
&i.Address,
)
return i, err
}
const updateOidcTokenByRefreshToken = `-- name: UpdateOidcTokenByRefreshToken :one
UPDATE "oidc_tokens" SET
"access_token_hash" = ?,
"refresh_token_hash" = ?,
"scope" = ?,
"client_id" = ?,
"token_expires_at" = ?,
"refresh_token_expires_at" = ?,
"nonce" = ?,
"userinfo_json" = ?
WHERE "sub" = ?
RETURNING sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce, userinfo_json
"refresh_token_expires_at" = ?
WHERE "refresh_token_hash" = ?
RETURNING sub, access_token_hash, refresh_token_hash, code_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce
`
type UpdateOIDCSessionParams struct {
type UpdateOidcTokenByRefreshTokenParams struct {
AccessTokenHash string
RefreshTokenHash string
Scope string
ClientID string
TokenExpiresAt int64
RefreshTokenExpiresAt int64
Nonce string
UserinfoJson string
Sub string
RefreshTokenHash_2 string
}
func (q *Queries) UpdateOIDCSession(ctx context.Context, arg UpdateOIDCSessionParams) (OidcSession, error) {
row := q.db.QueryRowContext(ctx, updateOIDCSession,
func (q *Queries) UpdateOidcTokenByRefreshToken(ctx context.Context, arg UpdateOidcTokenByRefreshTokenParams) (OidcToken, error) {
row := q.db.QueryRowContext(ctx, updateOidcTokenByRefreshToken,
arg.AccessTokenHash,
arg.RefreshTokenHash,
arg.Scope,
arg.ClientID,
arg.TokenExpiresAt,
arg.RefreshTokenExpiresAt,
arg.Nonce,
arg.UserinfoJson,
arg.Sub,
arg.RefreshTokenHash_2,
)
var i OidcSession
var i OidcToken
err := row.Scan(
&i.Sub,
&i.AccessTokenHash,
&i.RefreshTokenHash,
&i.CodeHash,
&i.Scope,
&i.ClientID,
&i.TokenExpiresAt,
&i.RefreshTokenExpiresAt,
&i.Nonce,
&i.UserinfoJson,
)
return i, err
}
+120 -24
View File
@@ -32,12 +32,28 @@ func mapErr(err error) error {
return err
}
func (s *Store) CreateOIDCSession(ctx context.Context, arg repository.CreateOIDCSessionParams) (repository.OidcSession, error) {
r, err := s.q.CreateOIDCSession(ctx, CreateOIDCSessionParams(arg))
func (s *Store) CreateOidcCode(ctx context.Context, arg repository.CreateOidcCodeParams) (repository.OidcCode, error) {
r, err := s.q.CreateOidcCode(ctx, CreateOidcCodeParams(arg))
if err != nil {
return repository.OidcSession{}, mapErr(err)
return repository.OidcCode{}, mapErr(err)
}
return repository.OidcSession(r), nil
return repository.OidcCode(r), nil
}
func (s *Store) CreateOidcToken(ctx context.Context, arg repository.CreateOidcTokenParams) (repository.OidcToken, error) {
r, err := s.q.CreateOidcToken(ctx, CreateOidcTokenParams(arg))
if err != nil {
return repository.OidcToken{}, mapErr(err)
}
return repository.OidcToken(r), nil
}
func (s *Store) CreateOidcUserInfo(ctx context.Context, arg repository.CreateOidcUserInfoParams) (repository.OidcUserinfo, error) {
r, err := s.q.CreateOidcUserInfo(ctx, CreateOidcUserInfoParams(arg))
if err != nil {
return repository.OidcUserinfo{}, mapErr(err)
}
return repository.OidcUserinfo(r), nil
}
func (s *Store) CreateSession(ctx context.Context, arg repository.CreateSessionParams) (repository.Session, error) {
@@ -48,44 +64,124 @@ func (s *Store) CreateSession(ctx context.Context, arg repository.CreateSessionP
return repository.Session(r), nil
}
func (s *Store) DeleteExpiredOIDCSessions(ctx context.Context, arg repository.DeleteExpiredOIDCSessionsParams) error {
return mapErr(s.q.DeleteExpiredOIDCSessions(ctx, DeleteExpiredOIDCSessionsParams(arg)))
func (s *Store) DeleteExpiredOidcCodes(ctx context.Context, expiresAt int64) ([]repository.OidcCode, error) {
rows, err := s.q.DeleteExpiredOidcCodes(ctx, expiresAt)
if err != nil {
return nil, mapErr(err)
}
out := make([]repository.OidcCode, len(rows))
for i, row := range rows {
out[i] = repository.OidcCode(row)
}
return out, nil
}
func (s *Store) DeleteExpiredOidcTokens(ctx context.Context, arg repository.DeleteExpiredOidcTokensParams) ([]repository.OidcToken, error) {
rows, err := s.q.DeleteExpiredOidcTokens(ctx, DeleteExpiredOidcTokensParams(arg))
if err != nil {
return nil, mapErr(err)
}
out := make([]repository.OidcToken, len(rows))
for i, row := range rows {
out[i] = repository.OidcToken(row)
}
return out, nil
}
func (s *Store) DeleteExpiredSessions(ctx context.Context, expiry int64) error {
return mapErr(s.q.DeleteExpiredSessions(ctx, expiry))
}
func (s *Store) DeleteOIDCSessionBySub(ctx context.Context, sub string) error {
return mapErr(s.q.DeleteOIDCSessionBySub(ctx, sub))
func (s *Store) DeleteOidcCode(ctx context.Context, codeHash string) error {
return mapErr(s.q.DeleteOidcCode(ctx, codeHash))
}
func (s *Store) DeleteOidcCodeBySub(ctx context.Context, sub string) error {
return mapErr(s.q.DeleteOidcCodeBySub(ctx, sub))
}
func (s *Store) DeleteOidcToken(ctx context.Context, accessTokenHash string) error {
return mapErr(s.q.DeleteOidcToken(ctx, accessTokenHash))
}
func (s *Store) DeleteOidcTokenByCodeHash(ctx context.Context, codeHash string) error {
return mapErr(s.q.DeleteOidcTokenByCodeHash(ctx, codeHash))
}
func (s *Store) DeleteOidcTokenBySub(ctx context.Context, sub string) error {
return mapErr(s.q.DeleteOidcTokenBySub(ctx, sub))
}
func (s *Store) DeleteOidcUserInfo(ctx context.Context, sub string) error {
return mapErr(s.q.DeleteOidcUserInfo(ctx, sub))
}
func (s *Store) DeleteSession(ctx context.Context, uuid string) error {
return mapErr(s.q.DeleteSession(ctx, uuid))
}
func (s *Store) GetOIDCSessionByAccessTokenHash(ctx context.Context, accessTokenHash string) (repository.OidcSession, error) {
r, err := s.q.GetOIDCSessionByAccessTokenHash(ctx, accessTokenHash)
func (s *Store) GetOidcCode(ctx context.Context, codeHash string) (repository.OidcCode, error) {
r, err := s.q.GetOidcCode(ctx, codeHash)
if err != nil {
return repository.OidcSession{}, mapErr(err)
return repository.OidcCode{}, mapErr(err)
}
return repository.OidcSession(r), nil
return repository.OidcCode(r), nil
}
func (s *Store) GetOIDCSessionByRefreshTokenHash(ctx context.Context, refreshTokenHash string) (repository.OidcSession, error) {
r, err := s.q.GetOIDCSessionByRefreshTokenHash(ctx, refreshTokenHash)
func (s *Store) GetOidcCodeBySub(ctx context.Context, sub string) (repository.OidcCode, error) {
r, err := s.q.GetOidcCodeBySub(ctx, sub)
if err != nil {
return repository.OidcSession{}, mapErr(err)
return repository.OidcCode{}, mapErr(err)
}
return repository.OidcSession(r), nil
return repository.OidcCode(r), nil
}
func (s *Store) GetOIDCSessionBySub(ctx context.Context, sub string) (repository.OidcSession, error) {
r, err := s.q.GetOIDCSessionBySub(ctx, sub)
func (s *Store) GetOidcCodeBySubUnsafe(ctx context.Context, sub string) (repository.OidcCode, error) {
r, err := s.q.GetOidcCodeBySubUnsafe(ctx, sub)
if err != nil {
return repository.OidcSession{}, mapErr(err)
return repository.OidcCode{}, mapErr(err)
}
return repository.OidcSession(r), nil
return repository.OidcCode(r), nil
}
func (s *Store) GetOidcCodeUnsafe(ctx context.Context, codeHash string) (repository.OidcCode, error) {
r, err := s.q.GetOidcCodeUnsafe(ctx, codeHash)
if err != nil {
return repository.OidcCode{}, mapErr(err)
}
return repository.OidcCode(r), nil
}
func (s *Store) GetOidcToken(ctx context.Context, accessTokenHash string) (repository.OidcToken, error) {
r, err := s.q.GetOidcToken(ctx, accessTokenHash)
if err != nil {
return repository.OidcToken{}, mapErr(err)
}
return repository.OidcToken(r), nil
}
func (s *Store) GetOidcTokenByRefreshToken(ctx context.Context, refreshTokenHash string) (repository.OidcToken, error) {
r, err := s.q.GetOidcTokenByRefreshToken(ctx, refreshTokenHash)
if err != nil {
return repository.OidcToken{}, mapErr(err)
}
return repository.OidcToken(r), nil
}
func (s *Store) GetOidcTokenBySub(ctx context.Context, sub string) (repository.OidcToken, error) {
r, err := s.q.GetOidcTokenBySub(ctx, sub)
if err != nil {
return repository.OidcToken{}, mapErr(err)
}
return repository.OidcToken(r), nil
}
func (s *Store) GetOidcUserInfo(ctx context.Context, sub string) (repository.OidcUserinfo, error) {
r, err := s.q.GetOidcUserInfo(ctx, sub)
if err != nil {
return repository.OidcUserinfo{}, mapErr(err)
}
return repository.OidcUserinfo(r), nil
}
func (s *Store) GetSession(ctx context.Context, uuid string) (repository.Session, error) {
@@ -96,12 +192,12 @@ func (s *Store) GetSession(ctx context.Context, uuid string) (repository.Session
return repository.Session(r), nil
}
func (s *Store) UpdateOIDCSession(ctx context.Context, arg repository.UpdateOIDCSessionParams) (repository.OidcSession, error) {
r, err := s.q.UpdateOIDCSession(ctx, UpdateOIDCSessionParams(arg))
func (s *Store) UpdateOidcTokenByRefreshToken(ctx context.Context, arg repository.UpdateOidcTokenByRefreshTokenParams) (repository.OidcToken, error) {
r, err := s.q.UpdateOidcTokenByRefreshToken(ctx, UpdateOidcTokenByRefreshTokenParams(arg))
if err != nil {
return repository.OidcSession{}, mapErr(err)
return repository.OidcToken{}, mapErr(err)
}
return repository.OidcSession(r), nil
return repository.OidcToken(r), nil
}
func (s *Store) UpdateSession(ctx context.Context, arg repository.UpdateSessionParams) (repository.Session, error) {
+25 -8
View File
@@ -19,12 +19,29 @@ type Store interface {
DeleteSession(ctx context.Context, uuid string) error
DeleteExpiredSessions(ctx context.Context, expiry int64) error
// OIDC sessions
CreateOIDCSession(ctx context.Context, arg CreateOIDCSessionParams) (OidcSession, error)
DeleteExpiredOIDCSessions(ctx context.Context, arg DeleteExpiredOIDCSessionsParams) error
DeleteOIDCSessionBySub(ctx context.Context, sub string) error
GetOIDCSessionByAccessTokenHash(ctx context.Context, accessTokenHash string) (OidcSession, error)
GetOIDCSessionByRefreshTokenHash(ctx context.Context, refreshTokenHash string) (OidcSession, error)
GetOIDCSessionBySub(ctx context.Context, sub string) (OidcSession, error)
UpdateOIDCSession(ctx context.Context, arg UpdateOIDCSessionParams) (OidcSession, error)
// OIDC codes
CreateOidcCode(ctx context.Context, arg CreateOidcCodeParams) (OidcCode, error)
GetOidcCode(ctx context.Context, codeHash string) (OidcCode, error)
GetOidcCodeBySub(ctx context.Context, sub string) (OidcCode, error)
GetOidcCodeUnsafe(ctx context.Context, codeHash string) (OidcCode, error)
GetOidcCodeBySubUnsafe(ctx context.Context, sub string) (OidcCode, error)
DeleteOidcCode(ctx context.Context, codeHash string) error
DeleteOidcCodeBySub(ctx context.Context, sub string) error
DeleteExpiredOidcCodes(ctx context.Context, expiresAt int64) ([]OidcCode, error)
// OIDC tokens
CreateOidcToken(ctx context.Context, arg CreateOidcTokenParams) (OidcToken, error)
GetOidcToken(ctx context.Context, accessTokenHash string) (OidcToken, error)
GetOidcTokenByRefreshToken(ctx context.Context, refreshTokenHash string) (OidcToken, error)
GetOidcTokenBySub(ctx context.Context, sub string) (OidcToken, error)
UpdateOidcTokenByRefreshToken(ctx context.Context, arg UpdateOidcTokenByRefreshTokenParams) (OidcToken, error)
DeleteOidcToken(ctx context.Context, accessTokenHash string) error
DeleteOidcTokenBySub(ctx context.Context, sub string) error
DeleteOidcTokenByCodeHash(ctx context.Context, codeHash string) error
DeleteExpiredOidcTokens(ctx context.Context, arg DeleteExpiredOidcTokensParams) ([]OidcToken, error)
// OIDC userinfo
CreateOidcUserInfo(ctx context.Context, arg CreateOidcUserInfoParams) (OidcUserinfo, error)
GetOidcUserInfo(ctx context.Context, sub string) (OidcUserinfo, error)
DeleteOidcUserInfo(ctx context.Context, sub string) error
}
+22 -59
View File
@@ -9,12 +9,6 @@ import (
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
)
// For LDAP and OAuth groups and IP allow/deny, we default to allow even with a deny policy.
// This is because we can't force the user to use groups in LDAP and OAuth if they would like to use
// a deny policy. As for IP checks, we can't reliably get the client IP (most of Tinyauth instances are
// behind a Docker bridge network) so to make it easier for users to use a deny policy without
// issues with IPs we allow by default.
type RuleName string
const (
@@ -31,11 +25,7 @@ type UserAllowedRule struct {
}
func (rule *UserAllowedRule) Evaluate(ctx *ACLContext) Effect {
if ctx.UserContext == nil {
return EffectDeny
}
if ctx.ACLs == nil {
if ctx.ACLs == nil || ctx.UserContext == nil {
return EffectAbstain
}
@@ -44,7 +34,7 @@ func (rule *UserAllowedRule) Evaluate(ctx *ACLContext) Effect {
match, err := utils.CheckFilter(ctx.ACLs.OAuth.Whitelist, ctx.UserContext.OAuth.Email)
if err != nil {
rule.Log.App.Warn().Err(err).Str("item", ctx.UserContext.OAuth.Email).Msg("Invalid entry in OAuth whitelist")
return EffectDeny
return EffectAbstain
}
if match {
rule.Log.App.Debug().Str("email", ctx.UserContext.OAuth.Email).Msg("User is in OAuth whitelist, allowing access")
@@ -58,7 +48,7 @@ func (rule *UserAllowedRule) Evaluate(ctx *ACLContext) Effect {
match, err := utils.CheckFilter(ctx.ACLs.Users.Block, ctx.UserContext.GetUsername())
if err != nil {
rule.Log.App.Warn().Err(err).Str("item", ctx.UserContext.GetUsername()).Msg("Invalid entry in users block list")
return EffectDeny
return EffectAbstain
}
if match {
rule.Log.App.Debug().Str("username", ctx.UserContext.GetUsername()).Msg("User is in users block list, denying access")
@@ -72,11 +62,8 @@ func (rule *UserAllowedRule) Evaluate(ctx *ACLContext) Effect {
match, err := utils.CheckFilter(ctx.ACLs.Users.Allow, ctx.UserContext.GetUsername())
if err != nil {
if err == utils.ErrFilterEmpty {
return EffectAbstain
}
rule.Log.App.Warn().Err(err).Str("item", ctx.UserContext.GetUsername()).Msg("Invalid entry in users allow list")
return EffectDeny
return EffectAbstain
}
if match {
@@ -93,22 +80,13 @@ type OAuthGroupRule struct {
}
func (rule *OAuthGroupRule) Evaluate(ctx *ACLContext) Effect {
if ctx.UserContext == nil {
return EffectDeny
}
if ctx.ACLs == nil {
return EffectAllow
if ctx.ACLs == nil || ctx.UserContext == nil {
return EffectAbstain
}
if !ctx.UserContext.IsOAuth() {
rule.Log.App.Debug().Msg("User is not an OAuth user, skipping OAuth group check")
return EffectAllow
}
if len(ctx.ACLs.OAuth.Groups) == 0 {
rule.Log.App.Debug().Msg("No OAuth groups specified in ACLs, allowing access")
return EffectAllow
return EffectAbstain
}
if _, ok := model.OverrideProviders[ctx.UserContext.OAuth.ID]; ok {
@@ -119,8 +97,7 @@ func (rule *OAuthGroupRule) Evaluate(ctx *ACLContext) Effect {
for _, group := range ctx.UserContext.OAuth.Groups {
match, err := utils.CheckFilter(ctx.ACLs.OAuth.Groups, strings.TrimSpace(group))
if err != nil {
rule.Log.App.Warn().Err(err).Str("item", group).Msg("Invalid entry in OAuth groups ACL")
return EffectDeny
return EffectAbstain
}
if match {
rule.Log.App.Trace().Str("group", group).Str("required", ctx.ACLs.OAuth.Groups).Msg("User group matched, allowing access")
@@ -137,29 +114,19 @@ type LDAPGroupRule struct {
}
func (rule *LDAPGroupRule) Evaluate(ctx *ACLContext) Effect {
if ctx.UserContext == nil {
return EffectDeny
}
if ctx.ACLs == nil {
return EffectAllow
if ctx == nil || ctx.UserContext == nil {
return EffectAbstain
}
if !ctx.UserContext.IsLDAP() {
rule.Log.App.Debug().Msg("User is not an LDAP user, skipping LDAP group check")
return EffectAllow
}
if len(ctx.ACLs.LDAP.Groups) == 0 {
rule.Log.App.Debug().Msg("No LDAP groups specified in ACLs, allowing access")
return EffectAllow
return EffectAbstain
}
for _, group := range ctx.UserContext.LDAP.Groups {
match, err := utils.CheckFilter(ctx.ACLs.LDAP.Groups, strings.TrimSpace(group))
if err != nil {
rule.Log.App.Warn().Err(err).Str("item", group).Msg("Invalid entry in LDAP groups ACL")
return EffectDeny
return EffectAbstain
}
if match {
rule.Log.App.Trace().Str("group", group).Str("required", ctx.ACLs.LDAP.Groups).Msg("User group matched, allowing access")
@@ -215,15 +182,14 @@ type IPAllowedRule struct {
}
func (rule *IPAllowedRule) Evaluate(ctx *ACLContext) Effect {
// merge global and per-app block/allow lists
blockedIps := append([]string{}, rule.Config.Auth.IP.Block...)
allowedIPs := append([]string{}, rule.Config.Auth.IP.Allow...)
if ctx.ACLs != nil {
blockedIps = append(blockedIps, ctx.ACLs.IP.Block...)
allowedIPs = append(allowedIPs, ctx.ACLs.IP.Allow...)
if ctx.ACLs == nil {
return EffectAbstain
}
// Merge the global and app IP filter
blockedIps := append(ctx.ACLs.IP.Block, rule.Config.Auth.IP.Block...)
allowedIPs := append(ctx.ACLs.IP.Allow, rule.Config.Auth.IP.Allow...)
for _, blocked := range blockedIps {
match, err := utils.CheckIPFilter(blocked, ctx.IP.String())
if err != nil {
@@ -258,18 +224,15 @@ func (rule *IPAllowedRule) Evaluate(ctx *ACLContext) Effect {
}
type IPBypassedRule struct {
Log *logger.Logger
Config model.Config
Log *logger.Logger
}
func (rule *IPBypassedRule) Evaluate(ctx *ACLContext) Effect {
// merge global and per-app bypass lists
bypassList := append([]string{}, rule.Config.Auth.IP.Bypass...)
if ctx.ACLs != nil {
bypassList = append(bypassList, ctx.ACLs.IP.Bypass...)
if ctx.ACLs == nil {
return EffectDeny
}
for _, bypassed := range bypassList {
for _, bypassed := range ctx.ACLs.IP.Bypass {
match, err := utils.CheckIPFilter(bypassed, ctx.IP.String())
if err != nil {
rule.Log.App.Warn().Err(err).Str("item", bypassed).Msg("Invalid IP/CIDR in bypass list")
+47 -151
View File
@@ -5,7 +5,6 @@ import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
)
@@ -21,16 +20,6 @@ func TestUserAllowedRule(t *testing.T) {
ctx *ACLContext
expected Effect
}{
{
name: "denies when user context is nil",
ctx: &ACLContext{
ACLs: &model.App{
OAuth: model.AppOAuth{Whitelist: "alice"},
},
UserContext: nil,
},
expected: EffectDeny,
},
{
name: "abstains when ACLs are nil",
ctx: &ACLContext{
@@ -44,6 +33,16 @@ func TestUserAllowedRule(t *testing.T) {
},
expected: EffectAbstain,
},
{
name: "abstains when user context is nil",
ctx: &ACLContext{
ACLs: &model.App{
OAuth: model.AppOAuth{Whitelist: "alice"},
},
UserContext: nil,
},
expected: EffectAbstain,
},
{
name: "allows OAuth user when email matches whitelist",
ctx: &ACLContext{
@@ -78,7 +77,7 @@ func TestUserAllowedRule(t *testing.T) {
expected: EffectDeny,
},
{
name: "denies for OAuth user when whitelist filter is invalid",
name: "abstains for OAuth user when whitelist filter is invalid",
ctx: &ACLContext{
ACLs: &model.App{
OAuth: model.AppOAuth{Whitelist: "/[/"},
@@ -90,7 +89,7 @@ func TestUserAllowedRule(t *testing.T) {
},
},
},
expected: EffectDeny,
expected: EffectAbstain,
},
{
name: "denies local user when username matches block list",
@@ -123,7 +122,7 @@ func TestUserAllowedRule(t *testing.T) {
expected: EffectAllow,
},
{
name: "denies when block list filter is invalid",
name: "abstains when block list filter is invalid",
ctx: &ACLContext{
ACLs: &model.App{
Users: model.AppUsers{Block: "/[/"},
@@ -135,21 +134,6 @@ func TestUserAllowedRule(t *testing.T) {
},
},
},
expected: EffectDeny,
},
{
name: "abstains when allow list is empty",
ctx: &ACLContext{
ACLs: &model.App{
Users: model.AppUsers{Allow: ""},
},
UserContext: &model.UserContext{
Provider: model.ProviderLocal,
Local: &model.LocalContext{
BaseContext: model.BaseContext{Username: "alice"},
},
},
},
expected: EffectAbstain,
},
{
@@ -183,7 +167,7 @@ func TestUserAllowedRule(t *testing.T) {
expected: EffectDeny,
},
{
name: "denies when allow list filter is invalid",
name: "abstains when allow list filter is invalid",
ctx: &ACLContext{
ACLs: &model.App{
Users: model.AppUsers{Allow: "/[/"},
@@ -195,7 +179,7 @@ func TestUserAllowedRule(t *testing.T) {
},
},
},
expected: EffectDeny,
expected: EffectAbstain,
},
}
@@ -218,17 +202,7 @@ func TestOAuthGroupRule(t *testing.T) {
expected Effect
}{
{
name: "denies when user context is nil",
ctx: &ACLContext{
ACLs: &model.App{
OAuth: model.AppOAuth{Whitelist: "alice"},
},
UserContext: nil,
},
expected: EffectDeny,
},
{
name: "allows when ACLs are nil",
name: "abstains when ACLs are nil",
ctx: &ACLContext{
ACLs: nil,
UserContext: &model.UserContext{
@@ -238,10 +212,20 @@ func TestOAuthGroupRule(t *testing.T) {
},
},
},
expected: EffectAllow,
expected: EffectAbstain,
},
{
name: "allows when user is not OAuth",
name: "abstains when user context is nil",
ctx: &ACLContext{
ACLs: &model.App{
OAuth: model.AppOAuth{Whitelist: "alice"},
},
UserContext: nil,
},
expected: EffectAbstain,
},
{
name: "abstains when user is not OAuth",
ctx: &ACLContext{
ACLs: &model.App{
OAuth: model.AppOAuth{Groups: "admins"},
@@ -253,22 +237,7 @@ func TestOAuthGroupRule(t *testing.T) {
},
},
},
expected: EffectAllow,
},
{
name: "allows when group filter is empty",
ctx: &ACLContext{
ACLs: &model.App{
OAuth: model.AppOAuth{Groups: ""},
},
UserContext: &model.UserContext{
Provider: model.ProviderOAuth,
OAuth: &model.OAuthContext{
BaseContext: model.BaseContext{Username: "alice"},
},
},
},
expected: EffectAllow,
expected: EffectAbstain,
},
{
name: "allows when provider is an override provider regardless of groups",
@@ -335,7 +304,7 @@ func TestOAuthGroupRule(t *testing.T) {
expected: EffectDeny,
},
{
name: "denies when groups filter is invalid",
name: "abstains when groups filter is invalid",
ctx: &ACLContext{
ACLs: &model.App{
OAuth: model.AppOAuth{Groups: "/[/"},
@@ -348,7 +317,7 @@ func TestOAuthGroupRule(t *testing.T) {
},
},
},
expected: EffectDeny,
expected: EffectAbstain,
},
}
@@ -371,30 +340,22 @@ func TestLDAPGroupRule(t *testing.T) {
expected Effect
}{
{
name: "denies when user context is nil",
name: "abstains when context is nil",
ctx: nil,
expected: EffectAbstain,
},
{
name: "abstains when user context is nil",
ctx: &ACLContext{
ACLs: &model.App{
OAuth: model.AppOAuth{Whitelist: "alice"},
},
UserContext: nil,
},
expected: EffectDeny,
expected: EffectAbstain,
},
{
name: "allows when acls are nil",
ctx: &ACLContext{
ACLs: nil,
UserContext: &model.UserContext{
Provider: model.ProviderLocal,
Local: &model.LocalContext{
BaseContext: model.BaseContext{Username: "alice"},
},
},
},
expected: EffectAllow,
},
{
name: "allows when user is not LDAP",
name: "abstains when user is not LDAP",
ctx: &ACLContext{
ACLs: &model.App{
LDAP: model.AppLDAP{Groups: "admins"},
@@ -406,22 +367,7 @@ func TestLDAPGroupRule(t *testing.T) {
},
},
},
expected: EffectAllow,
},
{
name: "allows when group filter is empty",
ctx: &ACLContext{
ACLs: &model.App{
LDAP: model.AppLDAP{Groups: ""},
},
UserContext: &model.UserContext{
Provider: model.ProviderLDAP,
LDAP: &model.LDAPContext{
BaseContext: model.BaseContext{Username: "alice"},
},
},
},
expected: EffectAllow,
expected: EffectAbstain,
},
{
name: "allows LDAP user when a group matches",
@@ -469,7 +415,7 @@ func TestLDAPGroupRule(t *testing.T) {
expected: EffectDeny,
},
{
name: "denies when groups filter is invalid",
name: "abstains when groups filter is invalid",
ctx: &ACLContext{
ACLs: &model.App{
LDAP: model.AppLDAP{Groups: "/[/"},
@@ -481,7 +427,7 @@ func TestLDAPGroupRule(t *testing.T) {
},
},
},
expected: EffectDeny,
expected: EffectAbstain,
},
}
@@ -612,12 +558,12 @@ func TestIPAllowedRule(t *testing.T) {
expected Effect
}{
{
name: "allows when ACLs are nil and no global lists configured",
name: "abstains when ACLs are nil",
ctx: &ACLContext{
ACLs: nil,
IP: net.ParseIP("10.0.0.1"),
},
expected: EffectAllow,
expected: EffectAbstain,
},
{
name: "denies when IP matches app block list",
@@ -723,70 +669,23 @@ func TestIPBypassedRule(t *testing.T) {
log := logger.NewLogger().WithTestConfig()
log.Init()
defaultIPBR := &IPBypassedRule{Log: log}
globBypassIPBR := &IPBypassedRule{
Log: log,
Config: model.Config{Auth: model.AuthConfig{IP: model.IPConfig{Bypass: []string{"10.0.0.0/24"}}}},
}
rule := &IPBypassedRule{Log: log}
tests := []struct {
name string
rule *IPBypassedRule
ctx *ACLContext
expected Effect
}{
{
name: "deny when ACLs are nil and no global bypass",
rule: defaultIPBR,
name: "deny when ACLs are nil",
ctx: &ACLContext{
ACLs: nil,
IP: net.ParseIP("10.0.0.1"),
},
expected: EffectDeny,
},
{
name: "allows when ACLs are nil but IP matches global bypass",
rule: globBypassIPBR,
ctx: &ACLContext{
ACLs: nil,
IP: net.ParseIP("10.0.0.5"),
},
expected: EffectAllow,
},
{
name: "denies when ACLs are nil and IP does not match global bypass",
rule: globBypassIPBR,
ctx: &ACLContext{
ACLs: nil,
IP: net.ParseIP("192.168.1.1"),
},
expected: EffectDeny,
},
{
name: "allows when IP matches per-app bypass but not global bypass",
rule: defaultIPBR,
ctx: &ACLContext{
ACLs: &model.App{
IP: model.AppIP{Bypass: []string{"10.0.0.0/24"}},
},
IP: net.ParseIP("10.0.0.5"),
},
expected: EffectAllow,
},
{
name: "allows when IP matches global bypass but not per-app bypass",
rule: globBypassIPBR,
ctx: &ACLContext{
ACLs: &model.App{
IP: model.AppIP{Bypass: []string{"172.16.0.0/24"}},
},
IP: net.ParseIP("10.0.0.5"),
},
expected: EffectAllow,
},
{
name: "allows when IP matches bypass list",
rule: defaultIPBR,
ctx: &ACLContext{
ACLs: &model.App{
IP: model.AppIP{Bypass: []string{"10.0.0.0/24"}},
@@ -797,7 +696,6 @@ func TestIPBypassedRule(t *testing.T) {
},
{
name: "denies when IP does not match bypass list",
rule: defaultIPBR,
ctx: &ACLContext{
ACLs: &model.App{
IP: model.AppIP{Bypass: []string{"10.0.0.0/24"}},
@@ -808,7 +706,6 @@ func TestIPBypassedRule(t *testing.T) {
},
{
name: "denies when bypass list is empty",
rule: defaultIPBR,
ctx: &ACLContext{
ACLs: &model.App{},
IP: net.ParseIP("10.0.0.1"),
@@ -817,7 +714,6 @@ func TestIPBypassedRule(t *testing.T) {
},
{
name: "skips invalid bypass entries and allows on later match",
rule: defaultIPBR,
ctx: &ACLContext{
ACLs: &model.App{
IP: model.AppIP{Bypass: []string{"not-an-ip", "10.0.0.1"}},
@@ -830,7 +726,7 @@ func TestIPBypassedRule(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equal(t, tt.expected, tt.rule.Evaluate(tt.ctx))
assert.Equal(t, tt.expected, rule.Evaluate(tt.ctx))
})
}
}
+202 -159
View File
@@ -9,12 +9,13 @@ import (
"sync"
"time"
"github.com/steveiliop56/ding"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
"slices"
"github.com/google/uuid"
"golang.org/x/crypto/bcrypt"
"golang.org/x/oauth2"
@@ -52,37 +53,42 @@ type OAuthPendingSession struct {
CallbackParams OAuthURLParams
}
type LdapGroupsCache struct {
Groups []string
Expires time.Time
}
type LoginAttempt struct {
FailedAttempts int
LastAttempt time.Time
LockedUntil time.Time
}
type Lockdown struct {
Active bool
ActiveUntil time.Time
}
type AuthService struct {
log *logger.Logger
config model.Config
runtime model.RuntimeConfig
ctx context.Context
context context.Context
ldap *LdapService
queries repository.Store
oauthBroker *OAuthBrokerService
tailscale *TailscaleService
policyEngine *PolicyEngine
ldap *LdapService
queries repository.Store
oauthBroker *OAuthBrokerService
tailscale *TailscaleService
lockdown struct {
active bool
until time.Time
ctx context.Context
cancelFunc context.CancelFunc
mu sync.RWMutex
}
caches struct {
login *CacheStore[LoginAttempt]
oauth *CacheStore[OAuthPendingSession]
ldap *CacheStore[[]string]
}
loginAttempts map[string]*LoginAttempt
ldapGroupsCache map[string]*LdapGroupsCache
oauthPendingSessions map[string]*OAuthPendingSession
oauthMutex sync.RWMutex
loginMutex sync.RWMutex
ldapGroupsMutex sync.RWMutex
lockdown *Lockdown
lockdownCtx context.Context
lockdownCancelFunc context.CancelFunc
}
func NewAuthService(
@@ -90,49 +96,27 @@ func NewAuthService(
config model.Config,
runtime model.RuntimeConfig,
ctx context.Context,
dg *ding.Ding,
wg *sync.WaitGroup,
ldap *LdapService,
queries repository.Store,
oauthBroker *OAuthBrokerService,
tailscale *TailscaleService,
policy *PolicyEngine,
) *AuthService {
service := &AuthService{
log: log,
runtime: runtime,
ctx: ctx,
config: config,
ldap: ldap,
queries: queries,
oauthBroker: oauthBroker,
tailscale: tailscale,
policyEngine: policy,
log: log,
runtime: runtime,
context: ctx,
config: config,
loginAttempts: make(map[string]*LoginAttempt),
ldapGroupsCache: make(map[string]*LdapGroupsCache),
oauthPendingSessions: make(map[string]*OAuthPendingSession),
ldap: ldap,
queries: queries,
oauthBroker: oauthBroker,
tailscale: tailscale,
}
// caches setup
oauthCache := NewCacheStore[OAuthPendingSession](256)
loginCache := NewCacheStore[LoginAttempt](1024)
ldapCache := NewCacheStore[[]string](1024)
service.caches.oauth = oauthCache
service.caches.login = loginCache
service.caches.ldap = ldapCache
dg.Go(func(ctx context.Context) {
ticker := time.NewTicker(1 * time.Minute)
defer ticker.Stop()
for {
select {
case <-ticker.C:
service.caches.oauth.Sweep()
service.caches.login.Sweep()
service.caches.ldap.Sweep()
case <-ctx.Done():
return
}
}
}, ding.RingMinor)
wg.Go(service.CleanupOAuthSessionsRoutine)
return service
}
@@ -207,12 +191,14 @@ func (auth *AuthService) GetLDAPUser(userDN string) (*model.LDAPUser, error) {
return nil, errors.New("ldap service not configured")
}
entry, exists := auth.caches.ldap.Get(userDN)
auth.ldapGroupsMutex.RLock()
entry, exists := auth.ldapGroupsCache[userDN]
auth.ldapGroupsMutex.RUnlock()
if exists {
if exists && time.Now().Before(entry.Expires) {
return &model.LDAPUser{
DN: userDN,
Groups: entry,
Groups: entry.Groups,
}, nil
}
@@ -222,7 +208,12 @@ func (auth *AuthService) GetLDAPUser(userDN string) (*model.LDAPUser, error) {
return nil, fmt.Errorf("failed to get ldap groups: %w", err)
}
auth.caches.ldap.Set(userDN, groups, time.Duration(auth.config.LDAP.GroupCacheTTL)*time.Second)
auth.ldapGroupsMutex.Lock()
auth.ldapGroupsCache[userDN] = &LdapGroupsCache{
Groups: groups,
Expires: time.Now().Add(time.Duration(auth.config.LDAP.GroupCacheTTL) * time.Second),
}
auth.ldapGroupsMutex.Unlock()
return &model.LDAPUser{
DN: userDN,
@@ -231,7 +222,11 @@ func (auth *AuthService) GetLDAPUser(userDN string) (*model.LDAPUser, error) {
}
func (auth *AuthService) IsAccountLocked(identifier string) (bool, int) {
if locked, remaining := auth.IsInLockdown(); locked {
auth.loginMutex.RLock()
defer auth.loginMutex.RUnlock()
if auth.lockdown != nil && auth.lockdown.Active {
remaining := int(time.Until(auth.lockdown.ActiveUntil).Seconds())
return true, remaining
}
@@ -239,7 +234,7 @@ func (auth *AuthService) IsAccountLocked(identifier string) (bool, int) {
return false, 0
}
attempt, exists := auth.caches.login.Get(identifier)
attempt, exists := auth.loginAttempts[identifier]
if !exists {
return false, 0
}
@@ -257,72 +252,46 @@ func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) {
return
}
if auth.caches.login.Size() >= MaxLoginAttemptRecords {
if locked, _ := auth.IsInLockdown(); locked {
auth.loginMutex.Lock()
defer auth.loginMutex.Unlock()
if len(auth.loginAttempts) >= MaxLoginAttemptRecords {
if auth.lockdown != nil && auth.lockdown.Active {
return
}
go auth.lockdownMode()
return
}
auth.caches.login.WithLock(func(actions CacheStoreActions[LoginAttempt]) {
entry, ok := actions.Get(identifier)
attempt, exists := auth.loginAttempts[identifier]
if !exists {
attempt = &LoginAttempt{}
auth.loginAttempts[identifier] = attempt
}
if !ok {
attempt := LoginAttempt{
LastAttempt: time.Now(),
}
if !success {
attempt.FailedAttempts = 1
if attempt.FailedAttempts >= auth.config.Auth.LoginMaxRetries {
attempt.LockedUntil = time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second)
auth.log.App.Warn().Str("identifier", identifier).Int("failedAttempts", attempt.FailedAttempts).Msg("Account locked due to too many failed login attempts")
}
}
// match current tinyauth behavior which doesn't expire rate limits
actions.Set(identifier, attempt, 0)
return
}
attempt.LastAttempt = time.Now()
entry.LastAttempt = time.Now()
if success {
attempt.FailedAttempts = 0
attempt.LockedUntil = time.Time{} // Reset lock time
return
}
if success {
entry.FailedAttempts = 0
entry.LockedUntil = time.Time{}
} else {
entry.FailedAttempts++
attempt.FailedAttempts++
if entry.FailedAttempts >= auth.config.Auth.LoginMaxRetries {
entry.LockedUntil = time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second)
auth.log.App.Warn().Str("identifier", identifier).Int("failedAttempts", entry.FailedAttempts).Msg("Account locked due to too many failed login attempts")
}
}
actions.Set(identifier, entry, 0)
})
if attempt.FailedAttempts >= auth.config.Auth.LoginMaxRetries {
attempt.LockedUntil = time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second)
auth.log.App.Warn().Str("identifier", identifier).Int("failedAttempts", attempt.FailedAttempts).Msg("Account locked due to too many failed login attempts")
}
}
// We could also directly access the policyEngine.effectToAccess but
// I believe it's better to use the exported functions instead
func (auth *AuthService) IsEmailWhitelisted(provider string, email string) bool {
return auth.policyEngine.EvaluateFunc(func() Effect {
whitelist := auth.runtime.OAuthWhitelist
if providerConfig, ok := auth.runtime.OAuthProviders[provider]; ok && len(providerConfig.Whitelist) > 0 {
whitelist = providerConfig.Whitelist
}
match, err := utils.CheckFilter(strings.Join(whitelist, ","), email)
if err != nil {
if err == utils.ErrFilterEmpty {
return EffectAbstain
}
auth.log.App.Error().Err(err).Str("email", email).Msg("Failed to evaluate email whitelist filter, defaulting to deny")
return EffectDeny
}
if match {
return EffectAllow
}
return EffectDeny
})
func (auth *AuthService) IsEmailWhitelisted(email string) bool {
match, err := utils.CheckFilter(strings.Join(auth.runtime.OAuthWhitelist, ","), email)
if err != nil {
auth.log.App.Warn().Err(err).Str("email", email).Msg("Invalid email filter pattern")
return false
}
return match
}
func (auth *AuthService) CreateSession(ctx context.Context, data repository.Session) (*http.Cookie, error) {
@@ -517,6 +486,8 @@ func (auth *AuthService) LDAPAuthConfigured() bool {
}
func (auth *AuthService) NewOAuthSession(serviceName string, params OAuthURLParams) (string, OAuthPendingSession, error) {
auth.ensureOAuthSessionLimit()
service, ok := auth.oauthBroker.GetService(serviceName)
if !ok {
@@ -540,7 +511,9 @@ func (auth *AuthService) NewOAuthSession(serviceName string, params OAuthURLPara
CallbackParams: params,
}
auth.caches.oauth.Set(sessionId.String(), session, time.Minute*10)
auth.oauthMutex.Lock()
auth.oauthPendingSessions[sessionId.String()] = &session
auth.oauthMutex.Unlock()
return sessionId.String(), session, nil
}
@@ -556,10 +529,10 @@ func (auth *AuthService) GetOAuthURL(sessionId string) (string, error) {
}
func (auth *AuthService) GetOAuthToken(sessionId string, code string) (*oauth2.Token, error) {
session, ok := auth.caches.oauth.Get(sessionId)
session, err := auth.GetOAuthPendingSession(sessionId)
if !ok {
return nil, fmt.Errorf("oauth session not found: %s", sessionId)
if err != nil {
return nil, err
}
token, err := (*session.Service).GetToken(code, session.Verifier)
@@ -568,14 +541,9 @@ func (auth *AuthService) GetOAuthToken(sessionId string, code string) (*oauth2.T
return nil, fmt.Errorf("failed to exchange code for token: %w", err)
}
auth.oauthMutex.Lock()
session.Token = token
// ttl 0 means keep current expiration
ok = auth.caches.oauth.Update(sessionId, session, 0)
if !ok {
return nil, fmt.Errorf("failed to update oauth session with token: %s", sessionId)
}
auth.oauthMutex.Unlock()
return token, nil
}
@@ -611,39 +579,123 @@ func (auth *AuthService) GetOAuthService(sessionId string) (OAuthServiceImpl, er
}
func (auth *AuthService) EndOAuthSession(sessionId string) {
auth.caches.oauth.Delete(sessionId)
auth.oauthMutex.Lock()
delete(auth.oauthPendingSessions, sessionId)
auth.oauthMutex.Unlock()
}
func (auth *AuthService) CleanupOAuthSessionsRoutine() {
auth.log.App.Debug().Msg("Starting OAuth session cleanup routine")
ticker := time.NewTicker(30 * time.Minute)
defer ticker.Stop()
for {
select {
case <-ticker.C:
auth.log.App.Debug().Msg("Running OAuth session cleanup")
auth.oauthMutex.Lock()
now := time.Now()
for sessionId, session := range auth.oauthPendingSessions {
if now.After(session.ExpiresAt) {
delete(auth.oauthPendingSessions, sessionId)
}
}
auth.oauthMutex.Unlock()
auth.log.App.Debug().Msg("OAuth session cleanup completed")
case <-auth.context.Done():
auth.log.App.Debug().Msg("Stopping OAuth session cleanup routine")
return
}
}
}
func (auth *AuthService) GetOAuthPendingSession(sessionId string) (*OAuthPendingSession, error) {
session, exists := auth.caches.oauth.Get(sessionId)
auth.ensureOAuthSessionLimit()
auth.oauthMutex.RLock()
session, exists := auth.oauthPendingSessions[sessionId]
auth.oauthMutex.RUnlock()
if !exists {
return &OAuthPendingSession{}, fmt.Errorf("oauth session not found: %s", sessionId)
}
return &session, nil
if time.Now().After(session.ExpiresAt) {
auth.oauthMutex.Lock()
delete(auth.oauthPendingSessions, sessionId)
auth.oauthMutex.Unlock()
return &OAuthPendingSession{}, fmt.Errorf("oauth session expired: %s", sessionId)
}
return session, nil
}
func (auth *AuthService) lockdownMode() {
auth.lockdown.mu.Lock()
func (auth *AuthService) ensureOAuthSessionLimit() {
auth.oauthMutex.Lock()
defer auth.oauthMutex.Unlock()
if auth.lockdown.active {
auth.lockdown.mu.Unlock()
if len(auth.oauthPendingSessions) <= MaxOAuthPendingSessions {
return
}
type entry struct {
id string
expiresAt int64
}
entries := make([]entry, 0, len(auth.oauthPendingSessions))
for id, session := range auth.oauthPendingSessions {
entries = append(entries, entry{id, session.ExpiresAt.Unix()})
}
slices.SortFunc(entries, func(a, b entry) int {
if a.expiresAt < b.expiresAt {
return -1
}
if a.expiresAt > b.expiresAt {
return 1
}
return 0
})
for _, e := range entries[:OAuthCleanupCount] {
delete(auth.oauthPendingSessions, e.id)
}
}
func (auth *AuthService) lockdownMode() {
ctx, cancel := context.WithCancel(context.Background())
auth.loginMutex.Lock()
if auth.lockdown != nil && auth.lockdown.Active {
auth.loginMutex.Unlock()
cancel()
return
}
auth.lockdownCtx = ctx
auth.lockdownCancelFunc = cancel
auth.log.App.Warn().Msg("Too many failed login attempts, entering lockdown mode")
auth.lockdown.active = true
auth.lockdown.ctx = ctx
auth.lockdown.cancelFunc = cancel
auth.lockdown.until = time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second)
auth.lockdown = &Lockdown{
Active: true,
ActiveUntil: time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second),
}
timer := time.NewTimer(time.Until(auth.lockdown.until))
// At this point all login attemps will also expire so,
// we might as well clear them to free up memory
auth.loginAttempts = make(map[string]*LoginAttempt)
auth.lockdown.mu.Unlock()
timer := time.NewTimer(time.Until(auth.lockdown.ActiveUntil))
auth.loginMutex.Unlock()
defer cancel()
defer timer.Stop()
@@ -653,33 +705,24 @@ func (auth *AuthService) lockdownMode() {
// Timer expired, end lockdown
case <-ctx.Done():
// Context cancelled, end lockdown
case <-auth.ctx.Done():
case <-auth.context.Done():
// Service is shutting down, end lockdown
}
auth.lockdown.mu.Lock()
auth.loginMutex.Lock()
auth.log.App.Info().Msg("Exiting lockdown mode")
auth.lockdown.active = false
auth.lockdown.until = time.Time{}
auth.lockdown.ctx = nil
auth.lockdown.cancelFunc = nil
auth.lockdown.mu.Unlock()
auth.lockdown = nil
auth.loginMutex.Unlock()
}
func (auth *AuthService) IsInLockdown() (bool, int) {
auth.lockdown.mu.RLock()
defer auth.lockdown.mu.RUnlock()
if auth.lockdown.active {
remaining := int(time.Until(auth.lockdown.until).Seconds())
return true, remaining
// Function only used for testing - do not use in prod!
func (auth *AuthService) ClearRateLimitsTestingOnly() {
auth.loginMutex.Lock()
auth.loginAttempts = make(map[string]*LoginAttempt)
if auth.lockdown != nil {
auth.lockdownCancelFunc()
}
return false, 0
}
// mostly a testing function, not useful for anything else
func (auth *AuthService) ClearLoginAttempts() {
auth.caches.login.Clear()
auth.loginMutex.Unlock()
}
-39
View File
@@ -1,39 +0,0 @@
package service
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
)
func TestIsEmailWhitelistedUsesProviderSpecificList(t *testing.T) {
log := logger.NewLogger().WithTestConfig()
log.Init()
auth := &AuthService{
log: log,
runtime: model.RuntimeConfig{
OAuthWhitelist: []string{"global@example.com"},
OAuthProviders: map[string]model.OAuthServiceConfig{
"github": {
Whitelist: []string{"github@example.com"},
},
"pocketid": {
Whitelist: []string{"pocket@example.com"},
},
"gitlab": {
Whitelist: []string{},
},
},
},
}
assert.True(t, auth.IsEmailWhitelisted("github", "github@example.com"))
assert.False(t, auth.IsEmailWhitelisted("github", "pocket@example.com"))
assert.True(t, auth.IsEmailWhitelisted("pocketid", "pocket@example.com"))
assert.True(t, auth.IsEmailWhitelisted("google", "global@example.com"))
assert.True(t, auth.IsEmailWhitelisted("gitlab", "global@example.com"))
assert.False(t, auth.IsEmailWhitelisted("gitlab", "unknown@example.com"))
}
-197
View File
@@ -1,197 +0,0 @@
package service
import (
"slices"
"sync"
"time"
)
type CacheStoreActions[T any] struct {
Set func(key string, value T, ttl time.Duration)
Get func(key string) (T, bool)
Delete func(key string)
Update func(key string, value T, ttl time.Duration) bool
}
type cacheEntry[T any] struct {
value T
expiresAt *time.Time
}
type CacheStore[T any] struct {
cache map[string]cacheEntry[T]
order []string
mu sync.RWMutex
maxSize int
}
func NewCacheStore[T any](maxSize int) *CacheStore[T] {
return &CacheStore[T]{
cache: make(map[string]cacheEntry[T]),
order: make([]string, 0),
maxSize: maxSize,
}
}
// With lock allows performing multiple operations on the cache store atomically.
// The provided mutate function receives a set of actions (Set, Get, Delete) that
// can be used to manipulate the cache store within the locked context.
func (cs *CacheStore[T]) WithLock(mutate func(actions CacheStoreActions[T])) {
cs.mu.Lock()
defer cs.mu.Unlock()
actions := CacheStoreActions[T]{
Set: cs.setCallback,
Get: cs.getCallback,
Delete: cs.deleteCallback,
Update: cs.updateCallback,
}
mutate(actions)
}
func (cs *CacheStore[T]) updateCallback(key string, value T, ttl time.Duration) bool {
if currentEntry, exists := cs.cache[key]; exists {
if currentEntry.expiresAt != nil && time.Now().After(*currentEntry.expiresAt) {
return false
}
entry := cacheEntry[T]{
value: value,
expiresAt: currentEntry.expiresAt,
}
if ttl > 0 {
expiration := time.Now().Add(ttl)
entry.expiresAt = &expiration
}
cs.cache[key] = entry
return true
}
return false
}
func (cs *CacheStore[T]) Update(key string, value T, ttl time.Duration) bool {
cs.mu.Lock()
defer cs.mu.Unlock()
return cs.updateCallback(key, value, ttl)
}
func (cs *CacheStore[T]) setCallback(key string, value T, ttl time.Duration) {
if cs.maxSize > 0 {
if _, exists := cs.cache[key]; !exists && len(cs.cache) >= cs.maxSize {
cs.evictOne()
}
}
var expiresAt *time.Time
if ttl > 0 {
expiration := time.Now().Add(ttl)
expiresAt = &expiration
}
cs.cache[key] = cacheEntry[T]{
value: value,
expiresAt: expiresAt,
}
if !slices.Contains(cs.order, key) {
cs.order = append(cs.order, key)
}
}
func (cs *CacheStore[T]) Set(key string, value T, ttl time.Duration) {
cs.mu.Lock()
defer cs.mu.Unlock()
cs.setCallback(key, value, ttl)
}
func (cs *CacheStore[T]) getCallback(key string) (T, bool) {
entry, exists := cs.cache[key]
if !exists {
var zero T
return zero, false
}
if entry.expiresAt != nil && time.Now().After(*entry.expiresAt) {
var zero T
return zero, false
}
return entry.value, true
}
func (cs *CacheStore[T]) Get(key string) (T, bool) {
cs.mu.RLock()
defer cs.mu.RUnlock()
return cs.getCallback(key)
}
func (cs *CacheStore[T]) deleteCallback(key string) {
delete(cs.cache, key)
keyIdx := slices.Index(cs.order, key)
if keyIdx != -1 {
cs.order = append(cs.order[:keyIdx], cs.order[keyIdx+1:]...)
}
}
func (cs *CacheStore[T]) Delete(key string) {
cs.mu.Lock()
defer cs.mu.Unlock()
cs.deleteCallback(key)
}
func (cs *CacheStore[T]) Sweep() {
cs.mu.Lock()
for key, entry := range cs.cache {
if entry.expiresAt != nil && time.Now().After(*entry.expiresAt) {
cs.deleteCallback(key)
}
}
cs.mu.Unlock()
}
func (cs *CacheStore[T]) evictOne() bool {
now := time.Now()
var oldestKey string
var oldestExp *time.Time
for k, e := range cs.cache {
if e.expiresAt != nil && now.After(*e.expiresAt) {
cs.deleteCallback(k)
return true
}
if e.expiresAt != nil && (oldestExp == nil || e.expiresAt.Before(*oldestExp)) {
oldestKey, oldestExp = k, e.expiresAt
}
}
// If we found an oldest key, evict it else we delete the first key in the order list
if oldestKey != "" {
cs.deleteCallback(oldestKey)
return true
} else {
if len(cs.order) > 0 {
cs.deleteCallback(cs.order[0])
return true
}
}
return false
}
func (cs *CacheStore[T]) Size() int {
cs.mu.RLock()
defer cs.mu.RUnlock()
return len(cs.cache)
}
func (cs *CacheStore[T]) Clear() {
cs.mu.Lock()
defer cs.mu.Unlock()
cs.cache = make(map[string]cacheEntry[T])
cs.order = make([]string, 0)
}
-383
View File
@@ -1,383 +0,0 @@
package service
import (
"strconv"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestCacheStoreGet(t *testing.T) {
tests := []struct {
name string
setup func(cs *CacheStore[string])
wantValue string
wantOk bool
}{
{
name: "returns a stored value",
setup: func(cs *CacheStore[string]) { cs.Set("key", "value", 0) },
wantValue: "value",
wantOk: true,
},
{
name: "reports a missing key",
setup: func(cs *CacheStore[string]) {},
wantOk: false,
},
{
name: "returns the latest value after an overwrite",
setup: func(cs *CacheStore[string]) {
cs.Set("key", "first", 0)
cs.Set("key", "second", 0)
},
wantValue: "second",
wantOk: true,
},
{
name: "returns a non-expired entry",
setup: func(cs *CacheStore[string]) { cs.Set("key", "value", time.Minute) },
wantValue: "value",
wantOk: true,
},
{
name: "treats an expired entry as missing",
setup: func(cs *CacheStore[string]) {
cs.Set("key", "value", 10*time.Millisecond)
time.Sleep(20 * time.Millisecond)
},
wantOk: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cs := NewCacheStore[string](0)
tt.setup(cs)
value, ok := cs.Get("key")
assert.Equal(t, tt.wantOk, ok)
if tt.wantOk {
assert.Equal(t, tt.wantValue, value)
}
})
}
}
func TestCacheStoreUpdate(t *testing.T) {
tests := []struct {
name string
setup func(cs *CacheStore[string])
ttl time.Duration
wantOk bool
afterWait time.Duration
wantPresent bool
wantValue string
}{
{
name: "updates an existing entry",
setup: func(cs *CacheStore[string]) { cs.Set("key", "old", 0) },
ttl: 0,
wantOk: true,
wantPresent: true,
wantValue: "new",
},
{
name: "does not create a missing entry",
setup: func(cs *CacheStore[string]) {},
ttl: 0,
wantOk: false,
wantPresent: false,
},
{
name: "preserves the existing expiry when ttl is zero",
setup: func(cs *CacheStore[string]) { cs.Set("key", "old", 30*time.Millisecond) },
ttl: 0,
wantOk: true,
afterWait: 40 * time.Millisecond,
wantPresent: false,
},
{
name: "refreshes the expiry when ttl is provided",
setup: func(cs *CacheStore[string]) { cs.Set("key", "old", 10*time.Millisecond) },
ttl: time.Minute,
wantOk: true,
afterWait: 20 * time.Millisecond,
wantPresent: true,
wantValue: "new",
},
{
name: "does not update an expired entry",
setup: func(cs *CacheStore[string]) {
cs.Set("key", "old", 10*time.Millisecond)
time.Sleep(20 * time.Millisecond)
},
ttl: time.Minute,
wantOk: false,
wantPresent: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cs := NewCacheStore[string](0)
tt.setup(cs)
ok := cs.Update("key", "new", tt.ttl)
assert.Equal(t, tt.wantOk, ok)
time.Sleep(tt.afterWait)
value, present := cs.Get("key")
assert.Equal(t, tt.wantPresent, present)
if tt.wantPresent {
assert.Equal(t, tt.wantValue, value)
}
})
}
}
func TestCacheStoreDelete(t *testing.T) {
tests := []struct {
name string
setup func(cs *CacheStore[string])
key string
wantSize int
}{
{
name: "removes an existing key",
setup: func(cs *CacheStore[string]) {
cs.Set("a", "1", 0)
cs.Set("b", "2", 0)
},
key: "a",
wantSize: 1,
},
{
name: "is a no-op for a missing key",
setup: func(cs *CacheStore[string]) { cs.Set("a", "1", 0) },
key: "missing",
wantSize: 1,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cs := NewCacheStore[string](0)
tt.setup(cs)
cs.Delete(tt.key)
_, ok := cs.Get(tt.key)
assert.False(t, ok)
assert.Equal(t, tt.wantSize, cs.Size())
})
}
}
func TestCacheStoreSweep(t *testing.T) {
tests := []struct {
name string
setup func(cs *CacheStore[string])
present []string
absent []string
wantSize int
}{
{
name: "removes expired entries and keeps the rest",
setup: func(cs *CacheStore[string]) {
cs.Set("permanent", "value", 0)
cs.Set("expired", "value", 10*time.Millisecond)
time.Sleep(20 * time.Millisecond)
},
present: []string{"permanent"},
absent: []string{"expired"},
wantSize: 1,
},
{
name: "keeps all live entries",
setup: func(cs *CacheStore[string]) {
cs.Set("a", "value", 0)
cs.Set("b", "value", time.Minute)
},
present: []string{"a", "b"},
wantSize: 2,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cs := NewCacheStore[string](0)
tt.setup(cs)
cs.Sweep()
for _, key := range tt.present {
_, ok := cs.Get(key)
assert.True(t, ok)
}
for _, key := range tt.absent {
_, ok := cs.Get(key)
assert.False(t, ok)
}
assert.Equal(t, tt.wantSize, cs.Size())
})
}
}
func TestCacheStoreEviction(t *testing.T) {
// Every case uses a cache with maxSize 2; the final Set in setup is the
// insertion that overflows the cache and triggers an eviction.
tests := []struct {
name string
setup func(cs *CacheStore[string])
present []string
absent []string
wantSize int
}{
{
name: "evicts an already expired entry first",
setup: func(cs *CacheStore[string]) {
cs.Set("expired", "value", 10*time.Millisecond)
cs.Set("fresh", "value", time.Minute)
time.Sleep(20 * time.Millisecond)
cs.Set("new", "value", time.Minute)
},
present: []string{"fresh", "new"},
absent: []string{"expired"},
wantSize: 2,
},
{
name: "evicts the entry expiring soonest",
setup: func(cs *CacheStore[string]) {
cs.Set("soon", "value", 50*time.Millisecond)
cs.Set("later", "value", time.Hour)
cs.Set("new", "value", time.Hour)
},
present: []string{"later", "new"},
absent: []string{"soon"},
wantSize: 2,
},
{
name: "evicts the oldest inserted entry when none have a ttl",
setup: func(cs *CacheStore[string]) {
cs.Set("first", "value", 0)
cs.Set("second", "value", 0)
cs.Set("third", "value", 0)
},
present: []string{"second", "third"},
absent: []string{"first"},
wantSize: 2,
},
{
name: "overwriting an existing key does not trigger eviction",
setup: func(cs *CacheStore[string]) {
cs.Set("a", "1", 0)
cs.Set("b", "2", 0)
cs.Set("a", "updated", 0)
},
present: []string{"a", "b"},
wantSize: 2,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cs := NewCacheStore[string](2)
tt.setup(cs)
for _, key := range tt.present {
_, ok := cs.Get(key)
assert.True(t, ok)
}
for _, key := range tt.absent {
_, ok := cs.Get(key)
assert.False(t, ok)
}
assert.Equal(t, tt.wantSize, cs.Size())
})
}
}
func TestCacheStoreSizeAndClear(t *testing.T) {
cs := NewCacheStore[string](0)
assert.Equal(t, 0, cs.Size())
cs.Set("a", "1", 0)
cs.Set("b", "2", 0)
assert.Equal(t, 2, cs.Size())
cs.Clear()
assert.Equal(t, 0, cs.Size())
_, ok := cs.Get("a")
assert.False(t, ok)
}
func TestCacheStoreWithLock(t *testing.T) {
cs := NewCacheStore[int](0)
cs.Set("counter", 1, 0)
// All four actions run atomically under a single lock.
cs.WithLock(func(actions CacheStoreActions[int]) {
current, ok := actions.Get("counter")
assert.True(t, ok)
actions.Set("counter", current+1, 0)
actions.Set("other", 100, 0)
actions.Delete("counter")
updated := actions.Update("other", 200, 0)
assert.True(t, updated)
})
_, ok := cs.Get("counter")
assert.False(t, ok)
value, ok := cs.Get("other")
assert.True(t, ok)
assert.Equal(t, 200, value)
}
// TestCacheStoreConcurrency exercises every locking path concurrently so the
// race detector (go test -race) can flag unsynchronised access.
func TestCacheStoreConcurrency(t *testing.T) {
cs := NewCacheStore[int](64)
const goroutines = 16
const iterations = 200
var wg sync.WaitGroup
wg.Add(goroutines)
for g := range goroutines {
go func(g int) {
defer wg.Done()
for i := range iterations {
key := strconv.Itoa((g*iterations + i) % 32)
switch i % 6 {
case 0:
cs.Set(key, i, time.Minute)
case 1:
cs.Get(key)
case 2:
cs.Update(key, i, time.Minute)
case 3:
cs.Delete(key)
case 4:
cs.Size()
case 5:
cs.WithLock(func(actions CacheStoreActions[int]) {
if v, ok := actions.Get(key); ok {
actions.Set(key, v+1, time.Minute)
}
})
}
}
}(g)
}
wg.Wait()
}
+5 -5
View File
@@ -3,8 +3,8 @@ package service
import (
"context"
"strings"
"sync"
"github.com/steveiliop56/ding"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/decoders"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
@@ -24,7 +24,7 @@ type DockerService struct {
func NewDockerService(
log *logger.Logger,
ctx context.Context,
dg *ding.Ding,
wg *sync.WaitGroup,
) (*DockerService, error) {
client, err := client.NewClientWithOpts(client.FromEnv)
@@ -50,7 +50,7 @@ func NewDockerService(
service.isConnected = true
service.log.App.Debug().Msg("Docker connected successfully")
dg.Go(service.watchAndClose, ding.RingMajor)
wg.Go(service.watchAndClose)
return service, nil
}
@@ -108,8 +108,8 @@ func (docker *DockerService) GetLabels(appDomain string) (*model.App, error) {
return nil, nil
}
func (docker *DockerService) watchAndClose(ctx context.Context) {
<-ctx.Done()
func (docker *DockerService) watchAndClose() {
<-docker.context.Done()
docker.log.App.Debug().Msg("Closing Docker client")
if docker.client != nil {
err := docker.client.Close()
+17 -88
View File
@@ -3,12 +3,10 @@ package service
import (
"context"
"fmt"
"slices"
"strings"
"sync"
"time"
"github.com/steveiliop56/ding"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/decoders"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
@@ -39,6 +37,7 @@ type ingressApp struct {
type KubernetesService struct {
log *logger.Logger
ctx context.Context
client dynamic.Interface
started bool
@@ -51,7 +50,7 @@ type KubernetesService struct {
func NewKubernetesService(
log *logger.Logger,
ctx context.Context,
dg *ding.Ding,
wg *sync.WaitGroup,
) (*KubernetesService, error) {
cfg, err := rest.InClusterConfig()
if err != nil {
@@ -82,15 +81,16 @@ func NewKubernetesService(
service := &KubernetesService{
log: log,
ctx: ctx,
client: client,
ingressApps: make(map[ingressKey][]ingressApp),
domainIndex: make(map[string]ingressAppKey),
appNameIndex: make(map[string]ingressAppKey),
}
dg.Go(func(ctx context.Context) {
service.watchGVR(gvr, ctx)
}, ding.RingMajor)
wg.Go(func() {
service.watchGVR(gvr)
})
service.started = true
log.App.Debug().Msg("Kubernetes label provider started successfully")
@@ -167,68 +167,6 @@ func (k *KubernetesService) getByAppName(appName string) *model.App {
return nil
}
func (k *KubernetesService) extractPaths(rule map[string]any) ([]string, error) {
http, found, err := unstructured.NestedMap(rule, "http")
if err != nil {
return nil, fmt.Errorf("reading http from rule: %w", err)
}
if !found {
return nil, nil
}
paths, found, err := unstructured.NestedSlice(http, "paths")
if err != nil {
return nil, fmt.Errorf("reading http.paths: %w", err)
}
if !found {
return nil, nil
}
var result []string
for _, p := range paths {
path, ok := p.(map[string]any)
if !ok {
continue
}
if p, ok := path["path"].(string); ok && p != "" {
result = append(result, p)
}
}
return result, nil
}
func (k *KubernetesService) extractHosts(item *unstructured.Unstructured) ([]string, error) {
rules, found, err := unstructured.NestedSlice(item.Object, "spec", "rules")
if err != nil {
return nil, fmt.Errorf("reading spec.rules: %w", err)
}
if !found {
return nil, nil
}
var hosts []string
for _, r := range rules {
rule, ok := r.(map[string]any)
if !ok {
continue
}
if host, ok := rule["host"].(string); ok && host != "" {
hosts = append(hosts, host)
}
paths, err := k.extractPaths(rule)
if err != nil {
// This is purely to warn users, it doesn't affect our ability to extract hosts so we won't fail the whole operation
k.log.App.Warn().Err(err).Str("namespace", item.GetNamespace()).Str("name", item.GetName()).Msg("Failed to extract paths from ingress rule")
continue
}
if len(paths) == 0 {
continue
}
if !slices.Contains(paths, "/") {
k.log.App.Warn().Str("namespace", item.GetNamespace()).Str("name", item.GetName()).Strs("paths", paths).Msg("Ingress rule does not contain a catch-all path, another ingress may be able to bypass auth checks if it routes the same host with a different path. Consider adding a catch-all path to this rule to ensure auth checks are applied to all paths for this host.")
}
}
k.log.App.Trace().Strs("hosts", hosts).Msg("Extracted hosts from ingress rules")
return hosts, nil
}
func (k *KubernetesService) updateFromItem(item *unstructured.Unstructured) {
namespace := item.GetNamespace()
name := item.GetName()
@@ -237,11 +175,6 @@ func (k *KubernetesService) updateFromItem(item *unstructured.Unstructured) {
k.removeIngress(namespace, name)
return
}
hosts, err := k.extractHosts(item)
if err != nil {
k.removeIngress(namespace, name)
return
}
labels, err := decoders.DecodeLabels[model.Apps](annotations, "apps")
if err != nil {
k.log.App.Warn().Err(err).Str("namespace", namespace).Str("name", name).Msg("Failed to decode ingress labels, skipping")
@@ -253,10 +186,6 @@ func (k *KubernetesService) updateFromItem(item *unstructured.Unstructured) {
if appLabels.Config.Domain == "" {
continue
}
if len(hosts) > 0 && !slices.Contains(hosts, appLabels.Config.Domain) {
k.log.App.Warn().Str("namespace", namespace).Str("name", name).Str("appName", appName).Str("domain", appLabels.Config.Domain).Msg("App domain does not match any hosts defined in ingress rules, skipping")
continue
}
apps = append(apps, ingressApp{
domain: appLabels.Config.Domain,
appName: appName,
@@ -270,8 +199,8 @@ func (k *KubernetesService) updateFromItem(item *unstructured.Unstructured) {
}
}
func (k *KubernetesService) resyncGVR(gvr schema.GroupVersionResource, ctx context.Context) error {
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
func (k *KubernetesService) resyncGVR(gvr schema.GroupVersionResource) error {
ctx, cancel := context.WithTimeout(k.ctx, 30*time.Second)
defer cancel()
list, err := k.client.Resource(gvr).List(ctx, metav1.ListOptions{})
@@ -288,10 +217,10 @@ func (k *KubernetesService) resyncGVR(gvr schema.GroupVersionResource, ctx conte
// runWatcher drains events from an active watcher until it closes or the context is done.
// Returns true if the caller should restart the watcher, false if it should exit.
func (k *KubernetesService) runWatcher(gvr schema.GroupVersionResource, w watch.Interface, resyncTicker *time.Ticker, ctx context.Context) bool {
func (k *KubernetesService) runWatcher(gvr schema.GroupVersionResource, w watch.Interface, resyncTicker *time.Ticker) bool {
for {
select {
case <-ctx.Done():
case <-k.ctx.Done():
w.Stop()
return false
case event, ok := <-w.ResultChan():
@@ -313,33 +242,33 @@ func (k *KubernetesService) runWatcher(gvr schema.GroupVersionResource, w watch.
k.removeIngress(item.GetNamespace(), item.GetName())
}
case <-resyncTicker.C:
if err := k.resyncGVR(gvr, ctx); err != nil {
if err := k.resyncGVR(gvr); err != nil {
k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Periodic resync failed during watcher run")
}
}
}
}
func (k *KubernetesService) watchGVR(gvr schema.GroupVersionResource, ctx context.Context) {
func (k *KubernetesService) watchGVR(gvr schema.GroupVersionResource) {
resyncTicker := time.NewTicker(5 * time.Minute)
defer resyncTicker.Stop()
if err := k.resyncGVR(gvr, ctx); err != nil {
if err := k.resyncGVR(gvr); err != nil {
k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Initial resync failed, will retry")
time.Sleep(30 * time.Second)
}
for {
select {
case <-ctx.Done():
case <-k.ctx.Done():
k.log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Shutting down kubernetes watcher")
return
case <-resyncTicker.C:
if err := k.resyncGVR(gvr, ctx); err != nil {
if err := k.resyncGVR(gvr); err != nil {
k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Periodic resync failed, will retry")
}
default:
ctx, cancel := context.WithCancel(ctx)
ctx, cancel := context.WithCancel(k.ctx)
watcher, err := k.client.Resource(gvr).Watch(ctx, metav1.ListOptions{})
if err != nil {
k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to start watcher, will retry")
@@ -348,7 +277,7 @@ func (k *KubernetesService) watchGVR(gvr schema.GroupVersionResource, ctx contex
continue
}
k.log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Watcher started successfully")
if !k.runWatcher(gvr, watcher, resyncTicker, ctx) {
if !k.runWatcher(gvr, watcher, resyncTicker) {
cancel()
return
}
+11 -9
View File
@@ -9,14 +9,14 @@ import (
"github.com/cenkalti/backoff/v5"
ldapgo "github.com/go-ldap/ldap/v3"
"github.com/steveiliop56/ding"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
)
type LdapService struct {
log *logger.Logger
config model.Config
log *logger.Logger
config model.Config
context context.Context
conn *ldapgo.Conn
mutex sync.RWMutex
@@ -26,15 +26,17 @@ type LdapService struct {
func NewLdapService(
log *logger.Logger,
config model.Config,
dg *ding.Ding,
ctx context.Context,
wg *sync.WaitGroup,
) (*LdapService, error) {
if config.LDAP.Address == "" {
return nil, nil
}
ldap := &LdapService{
log: log,
config: config,
log: log,
config: config,
context: ctx,
}
// Check whether authentication with client certificate is possible
@@ -67,7 +69,7 @@ func NewLdapService(
return nil, fmt.Errorf("failed to connect to ldap server: %w", err)
}
dg.Go(func(ctx context.Context) {
wg.Go(func() {
ldap.log.App.Debug().Msg("Starting LDAP connection heartbeat routine")
ticker := time.NewTicker(5 * time.Minute)
@@ -85,12 +87,12 @@ func NewLdapService(
}
ldap.log.App.Info().Msg("Successfully reconnected to LDAP server")
}
case <-ctx.Done():
case <-ldap.context.Done():
ldap.log.App.Debug().Msg("LDAP service context cancelled, stopping heartbeat")
return
}
}
}, ding.RingMajor)
})
return ldap, nil
}
+2 -2
View File
@@ -17,7 +17,7 @@ type GithubEmailResponse []struct {
Verified bool `json:"verified"`
}
type GithubUserinfoResponse struct {
type GithubUserInfoResponse struct {
Login string `json:"login"`
Name string `json:"name"`
ID int `json:"id"`
@@ -30,7 +30,7 @@ func defaultExtractor(client *http.Client, url string) (*model.Claims, error) {
func githubExtractor(client *http.Client, _ string) (*model.Claims, error) {
var user model.Claims
userInfo, err := simpleReq[GithubUserinfoResponse](client, "https://api.github.com/user", map[string]string{
userInfo, err := simpleReq[GithubUserInfoResponse](client, "https://api.github.com/user", map[string]string{
"accept": "application/vnd.github+json",
})
if err != nil {
+3 -3
View File
@@ -10,13 +10,13 @@ import (
"golang.org/x/oauth2"
)
type OAuthUserinfoExtractor func(client *http.Client, url string) (*model.Claims, error)
type UserinfoExtractor func(client *http.Client, url string) (*model.Claims, error)
type OAuthService struct {
serviceCfg model.OAuthServiceConfig
config *oauth2.Config
ctx context.Context
userinfoExtractor OAuthUserinfoExtractor
userinfoExtractor UserinfoExtractor
id string
}
@@ -50,7 +50,7 @@ func NewOAuthService(config model.OAuthServiceConfig, id string, ctx context.Con
}
}
func (s *OAuthService) WithUserinfoExtractor(extractor OAuthUserinfoExtractor) *OAuthService {
func (s *OAuthService) WithUserinfoExtractor(extractor UserinfoExtractor) *OAuthService {
s.userinfoExtractor = extractor
return s
}
+189 -191
View File
@@ -15,12 +15,13 @@ import (
"net/url"
"os"
"strings"
"sync"
"time"
"slices"
"github.com/gin-gonic/gin"
"github.com/go-jose/go-jose/v4"
"github.com/steveiliop56/ding"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/utils"
@@ -41,10 +42,6 @@ var (
ErrInvalidClient = errors.New("invalid_client")
)
// This is not spec-compliant, the ID token SHOULD NOT contain user info claims but,
// it has became a "standard" and apps are looking for the claims in the ID tokens
// instead of calling the userinfo endpoint, so we include them in the ID token as well
// for better compatibility with existing apps
type ClaimSet struct {
Iss string `json:"iss"`
Aud string `json:"aud"`
@@ -70,8 +67,6 @@ type ClaimSet struct {
Nonce string `json:"nonce,omitempty"`
}
// We use this struct as both a response struct and a struct to store userinfo
// in the database
type UserinfoResponse struct {
Sub string `json:"sub"`
Name string `json:"name,omitempty"`
@@ -116,35 +111,17 @@ type AuthorizeRequest struct {
CodeChallengeMethod string `json:"code_challenge_method"`
}
type AuthorizeCodeEntry struct {
CodeHash string
Scope string
RedirectURI string
ClientID string
Nonce string
CodeChallenge string
Userinfo UserinfoResponse
}
type UsedCodeEntry struct {
Sub string
}
type OIDCService struct {
log *logger.Logger
config model.Config
runtime model.RuntimeConfig
queries repository.Store
context context.Context
clients map[string]model.OIDCClientConfig
privateKey *rsa.PrivateKey
publicKey *rsa.PublicKey
issuer string
caches struct {
code *CacheStore[AuthorizeCodeEntry]
usedCode *CacheStore[UsedCodeEntry]
}
}
func NewOIDCService(
@@ -152,7 +129,8 @@ func NewOIDCService(
config model.Config,
runtime model.RuntimeConfig,
queries repository.Store,
dg *ding.Ding) (*OIDCService, error) {
ctx context.Context,
wg *sync.WaitGroup) (*OIDCService, error) {
// If not configured, skip init
if len(runtime.OIDCClients) == 0 {
return nil, nil
@@ -298,6 +276,7 @@ func NewOIDCService(
config: config,
runtime: runtime,
queries: queries,
context: ctx,
clients: clients,
privateKey: privateKey,
@@ -306,29 +285,7 @@ func NewOIDCService(
}
// Start cleanup routine
dg.Go(service.cleanupRoutine, ding.RingMinor)
// Create caches
codeCash := NewCacheStore[AuthorizeCodeEntry](256)
usedCode := NewCacheStore[UsedCodeEntry](256)
service.caches.code = codeCash
service.caches.usedCode = usedCode
// Start cache cleanup routine
dg.Go(func(ctx context.Context) {
ticker := time.NewTicker(1 * time.Minute)
defer ticker.Stop()
for {
select {
case <-ticker.C:
service.caches.code.Sweep()
service.caches.usedCode.Sweep()
case <-ctx.Done():
return
}
}
}, ding.RingMinor)
wg.Go(service.cleanupRoutine)
return service, nil
}
@@ -391,17 +348,19 @@ func (service *OIDCService) filterScopes(scopes []string) []string {
})
}
func (service *OIDCService) CreateCode(req AuthorizeRequest, userContext model.UserContext) string {
code := utils.GenerateString(32)
sub := service.CreateSub(userContext, req.ClientID)
func (service *OIDCService) StoreCode(c *gin.Context, sub string, code string, req AuthorizeRequest) error {
// Fixed 10 minutes
expiresAt := time.Now().Add(time.Minute * time.Duration(10)).Unix()
entry := AuthorizeCodeEntry{
CodeHash: service.Hash(code),
Scope: strings.Join(service.filterScopes(strings.Split(req.Scope, " ")), " "),
entry := repository.CreateOidcCodeParams{
Sub: sub,
CodeHash: service.Hash(code),
// Here it's safe to split and trust the output since, we validated the scopes before
Scope: strings.Join(service.filterScopes(strings.Split(req.Scope, " ")), ","),
RedirectURI: req.RedirectURI,
ClientID: req.ClientID,
ExpiresAt: expiresAt,
Nonce: req.Nonce,
Userinfo: service.userinfoFromContext(userContext, sub),
}
if req.CodeChallenge != "" {
@@ -413,14 +372,14 @@ func (service *OIDCService) CreateCode(req AuthorizeRequest, userContext model.U
}
}
// Store the code in the cache
service.caches.code.Set(entry.CodeHash, entry, 1*time.Minute)
// Insert the code into the database
_, err := service.queries.CreateOidcCode(c, entry)
return code
return err
}
func (service *OIDCService) userinfoFromContext(userContext model.UserContext, sub string) UserinfoResponse {
userInfo := UserinfoResponse{
func (service *OIDCService) StoreUserinfo(c *gin.Context, sub string, userContext model.UserContext, req AuthorizeRequest) error {
userInfoParams := repository.CreateOidcUserInfoParams{
Sub: sub,
Name: userContext.GetName(),
Email: userContext.GetEmail(),
@@ -429,31 +388,37 @@ func (service *OIDCService) userinfoFromContext(userContext model.UserContext, s
}
if userContext.IsLocal() {
userInfo.GivenName = userContext.Local.Attributes.GivenName
userInfo.FamilyName = userContext.Local.Attributes.FamilyName
userInfo.MiddleName = userContext.Local.Attributes.MiddleName
userInfo.Nickname = userContext.Local.Attributes.Nickname
userInfo.Profile = userContext.Local.Attributes.Profile
userInfo.Picture = userContext.Local.Attributes.Picture
userInfo.Website = userContext.Local.Attributes.Website
userInfo.Gender = userContext.Local.Attributes.Gender
userInfo.Birthdate = userContext.Local.Attributes.Birthdate
userInfo.Zoneinfo = userContext.Local.Attributes.Zoneinfo
userInfo.Locale = userContext.Local.Attributes.Locale
userInfo.PhoneNumber = userContext.Local.Attributes.PhoneNumber
userInfo.Address = &userContext.Local.Attributes.Address
addressJSON, err := json.Marshal(userContext.Local.Attributes.Address)
if err != nil {
return err
}
userInfoParams.GivenName = userContext.Local.Attributes.GivenName
userInfoParams.FamilyName = userContext.Local.Attributes.FamilyName
userInfoParams.MiddleName = userContext.Local.Attributes.MiddleName
userInfoParams.Nickname = userContext.Local.Attributes.Nickname
userInfoParams.Profile = userContext.Local.Attributes.Profile
userInfoParams.Picture = userContext.Local.Attributes.Picture
userInfoParams.Website = userContext.Local.Attributes.Website
userInfoParams.Gender = userContext.Local.Attributes.Gender
userInfoParams.Birthdate = userContext.Local.Attributes.Birthdate
userInfoParams.Zoneinfo = userContext.Local.Attributes.Zoneinfo
userInfoParams.Locale = userContext.Local.Attributes.Locale
userInfoParams.PhoneNumber = userContext.Local.Attributes.PhoneNumber
userInfoParams.Address = string(addressJSON)
}
// Tinyauth will pass through the groups it got from an LDAP or an OIDC server
if userContext.IsLDAP() {
userInfo.Groups = userContext.LDAP.Groups
userInfoParams.Groups = strings.Join(userContext.LDAP.Groups, ",")
}
if userContext.IsOAuth() {
userInfo.Groups = userContext.OAuth.Groups
userInfoParams.Groups = strings.Join(userContext.OAuth.Groups, ",")
}
return userInfo
_, err := service.queries.CreateOidcUserInfo(c, userInfoParams)
return err
}
func (service *OIDCService) ValidateGrantType(grantType string) error {
@@ -464,34 +429,36 @@ func (service *OIDCService) ValidateGrantType(grantType string) error {
return nil
}
func (service *OIDCService) GetCodeEntry(codeHash string, clientId string) (*AuthorizeCodeEntry, bool) {
var entry AuthorizeCodeEntry
var ok bool
func (service *OIDCService) GetCodeEntry(c *gin.Context, codeHash string, clientId string) (repository.OidcCode, error) {
oidcCode, err := service.queries.GetOidcCode(c, codeHash)
service.caches.code.WithLock(func(actions CacheStoreActions[AuthorizeCodeEntry]) {
entry, ok = actions.Get(codeHash)
if !ok {
return
if err != nil {
if errors.Is(err, repository.ErrNotFound) {
return repository.OidcCode{}, ErrCodeNotFound
}
if entry.ClientID != clientId {
ok = false
return
}
// Since the code can only be used once, we delete it from the cache after retrieving it
actions.Delete(codeHash)
})
if !ok {
return nil, false
return repository.OidcCode{}, err
}
return &entry, true
if time.Now().Unix() > oidcCode.ExpiresAt {
err = service.queries.DeleteOidcCode(c, codeHash)
if err != nil {
return repository.OidcCode{}, err
}
err = service.DeleteUserinfo(c, oidcCode.Sub)
if err != nil {
return repository.OidcCode{}, err
}
return repository.OidcCode{}, ErrCodeExpired
}
if oidcCode.ClientID != clientId {
return repository.OidcCode{}, ErrInvalidClient
}
return oidcCode, nil
}
func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user UserinfoResponse, scope string, nonce string) (string, error) {
func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user repository.OidcUserinfo, scope string, nonce string) (string, error) {
createdAt := time.Now().Unix()
expiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry) * time.Second).Unix()
@@ -557,11 +524,17 @@ func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user
return token, nil
}
func (service *OIDCService) GenerateAccessToken(ctx context.Context, client model.OIDCClientConfig, codeEntry AuthorizeCodeEntry) (*TokenResponse, error) {
idToken, err := service.generateIDToken(client, codeEntry.Userinfo, codeEntry.Scope, codeEntry.Nonce)
func (service *OIDCService) GenerateAccessToken(c *gin.Context, client model.OIDCClientConfig, codeEntry repository.OidcCode) (TokenResponse, error) {
user, err := service.GetUserinfo(c, codeEntry.Sub)
if err != nil {
return nil, err
return TokenResponse{}, err
}
idToken, err := service.generateIDToken(client, user, codeEntry.Scope, codeEntry.Nonce)
if err != nil {
return TokenResponse{}, err
}
accessToken := utils.GenerateString(32)
@@ -581,68 +554,56 @@ func (service *OIDCService) GenerateAccessToken(ctx context.Context, client mode
Scope: strings.ReplaceAll(codeEntry.Scope, ",", " "),
}
var userInfoJson []byte
userInfoJson, err = json.Marshal(codeEntry.Userinfo)
if err != nil {
return nil, err
}
_, err = service.queries.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{
Sub: codeEntry.Userinfo.Sub,
_, err = service.queries.CreateOidcToken(c, repository.CreateOidcTokenParams{
Sub: codeEntry.Sub,
AccessTokenHash: service.Hash(accessToken),
RefreshTokenHash: service.Hash(refreshToken),
Scope: codeEntry.Scope,
ClientID: client.ClientID,
Scope: codeEntry.Scope,
TokenExpiresAt: tokenExpiresAt,
RefreshTokenExpiresAt: refreshTokenExpiresAt,
Nonce: codeEntry.Nonce,
UserinfoJson: string(userInfoJson),
CodeHash: codeEntry.CodeHash,
})
if err != nil {
return nil, err
return TokenResponse{}, err
}
return &tokenResponse, nil
return tokenResponse, nil
}
func (service *OIDCService) RefreshAccessToken(ctx context.Context, refreshToken string, clientId string) (*TokenResponse, error) {
entry, err := service.queries.GetOIDCSessionByRefreshTokenHash(ctx, service.Hash(refreshToken))
func (service *OIDCService) RefreshAccessToken(c *gin.Context, refreshToken string, reqClientId string) (TokenResponse, error) {
entry, err := service.queries.GetOidcTokenByRefreshToken(c, service.Hash(refreshToken))
if err != nil {
if errors.Is(err, repository.ErrNotFound) {
return nil, ErrTokenNotFound
return TokenResponse{}, ErrTokenNotFound
}
return nil, err
return TokenResponse{}, err
}
if entry.RefreshTokenExpiresAt < time.Now().Unix() {
return nil, ErrTokenExpired
return TokenResponse{}, ErrTokenExpired
}
// Ensure the client ID in the request matches the client ID in the token
if entry.ClientID != clientId {
return nil, ErrInvalidClient
if entry.ClientID != reqClientId {
return TokenResponse{}, ErrInvalidClient
}
// we need to unmarshal the userinfo from the database to include it in the new ID token,
// since the ID token includes user claims for better compatibility with existing apps
var userInfo UserinfoResponse
err = json.Unmarshal([]byte(entry.UserinfoJson), &userInfo)
user, err := service.GetUserinfo(c, entry.Sub)
if err != nil {
return nil, err
return TokenResponse{}, err
}
idToken, err := service.generateIDToken(model.OIDCClientConfig{
ClientID: entry.ClientID,
}, userInfo, entry.Scope, entry.Nonce)
}, user, entry.Scope, entry.Nonce)
if err != nil {
return nil, err
return TokenResponse{}, err
}
accessToken := utils.GenerateString(32)
@@ -660,54 +621,71 @@ func (service *OIDCService) RefreshAccessToken(ctx context.Context, refreshToken
Scope: strings.ReplaceAll(entry.Scope, ",", " "),
}
_, err = service.queries.UpdateOIDCSession(ctx, repository.UpdateOIDCSessionParams{
Sub: entry.Sub,
_, err = service.queries.UpdateOidcTokenByRefreshToken(c, repository.UpdateOidcTokenByRefreshTokenParams{
AccessTokenHash: service.Hash(accessToken),
RefreshTokenHash: service.Hash(newRefreshToken),
Scope: entry.Scope,
ClientID: entry.ClientID,
TokenExpiresAt: tokenExpiresAt,
RefreshTokenExpiresAt: refreshTokenExpiresAt,
Nonce: entry.Nonce,
UserinfoJson: entry.UserinfoJson,
RefreshTokenHash_2: service.Hash(refreshToken), // that's the selector, it's not stored in the db
})
if err != nil {
return nil, err
return TokenResponse{}, err
}
return &tokenResponse, nil
return tokenResponse, nil
}
func (service *OIDCService) GetSessionByToken(ctx context.Context, tokenHash string) (*repository.OidcSession, error) {
entry, err := service.queries.GetOIDCSessionByAccessTokenHash(ctx, tokenHash)
func (service *OIDCService) DeleteCodeEntry(c *gin.Context, codeHash string) error {
return service.queries.DeleteOidcCode(c, codeHash)
}
func (service *OIDCService) DeleteUserinfo(c *gin.Context, sub string) error {
return service.queries.DeleteOidcUserInfo(c, sub)
}
func (service *OIDCService) DeleteToken(c *gin.Context, tokenHash string) error {
return service.queries.DeleteOidcToken(c, tokenHash)
}
func (service *OIDCService) DeleteTokenByCodeHash(c *gin.Context, codeHash string) error {
return service.queries.DeleteOidcTokenByCodeHash(c, codeHash)
}
func (service *OIDCService) GetAccessToken(c *gin.Context, tokenHash string) (repository.OidcToken, error) {
entry, err := service.queries.GetOidcToken(c, tokenHash)
if err != nil {
if errors.Is(err, repository.ErrNotFound) {
return nil, ErrTokenNotFound
return repository.OidcToken{}, ErrTokenNotFound
}
return nil, err
return repository.OidcToken{}, err
}
if entry.TokenExpiresAt < time.Now().Unix() {
// If refresh token is expired, delete the session
// since there is no way for the client to access anything anymore
// If refresh token is expired, delete the token and userinfo since there is no way for the client to access anything anymore
if entry.RefreshTokenExpiresAt < time.Now().Unix() {
// Deletes by sub
err := service.queries.DeleteOIDCSessionBySub(ctx, entry.Sub)
err := service.DeleteToken(c, tokenHash)
if err != nil {
return nil, err
return repository.OidcToken{}, err
}
err = service.DeleteUserinfo(c, entry.Sub)
if err != nil {
return repository.OidcToken{}, err
}
return nil, ErrTokenExpired
}
return nil, ErrTokenExpired
return repository.OidcToken{}, ErrTokenExpired
}
return &entry, nil
return entry, nil
}
func (service *OIDCService) CompileUserinfo(user UserinfoResponse, scope string) UserinfoResponse {
scopes := strings.Split(scope, " ")
func (service *OIDCService) GetUserinfo(c *gin.Context, sub string) (repository.OidcUserinfo, error) {
return service.queries.GetOidcUserInfo(c, sub)
}
func (service *OIDCService) CompileUserinfo(user repository.OidcUserinfo, scope string) UserinfoResponse {
scopes := strings.Split(scope, ",") // split by comma since it's a db entry
userInfo := UserinfoResponse{
Sub: user.Sub,
UpdatedAt: user.UpdatedAt,
@@ -735,7 +713,11 @@ func (service *OIDCService) CompileUserinfo(user UserinfoResponse, scope string)
}
if slices.Contains(scopes, "groups") {
userInfo.Groups = user.Groups
if user.Groups != "" {
userInfo.Groups = strings.Split(user.Groups, ",")
} else {
userInfo.Groups = []string{}
}
}
if slices.Contains(scopes, "phone") {
@@ -745,7 +727,10 @@ func (service *OIDCService) CompileUserinfo(user UserinfoResponse, scope string)
}
if slices.Contains(scopes, "address") {
userInfo.Address = user.Address
var addr model.AddressClaim
if err := json.Unmarshal([]byte(user.Address), &addr); err == nil {
userInfo.Address = &addr
}
}
return userInfo
@@ -758,16 +743,25 @@ func (service *OIDCService) Hash(token string) string {
}
func (service *OIDCService) DeleteOldSession(ctx context.Context, sub string) error {
err := service.queries.DeleteOIDCSessionBySub(ctx, sub)
err := service.queries.DeleteOidcCodeBySub(ctx, sub)
if err != nil && !errors.Is(err, repository.ErrNotFound) {
return err
}
err = service.queries.DeleteOidcTokenBySub(ctx, sub)
if err != nil && !errors.Is(err, repository.ErrNotFound) {
return err
}
err = service.queries.DeleteOidcUserInfo(ctx, sub)
if err != nil && !errors.Is(err, repository.ErrNotFound) {
return err
}
return nil
}
func (service *OIDCService) cleanupRoutine(ctx context.Context) {
// Cleanup routine - Resource heavy due to the linked tables
func (service *OIDCService) cleanupRoutine() {
service.log.App.Debug().Msg("Starting OIDC cleanup routine")
ticker := time.NewTicker(30 * time.Minute)
ticker := time.NewTicker(time.Duration(30) * time.Minute)
defer ticker.Stop()
for {
@@ -777,18 +771,50 @@ func (service *OIDCService) cleanupRoutine(ctx context.Context) {
currentTime := time.Now().Unix()
// Limitation of sqlc, meaning we need to specify a timestamp for both token and refresh token expiry
err := service.queries.DeleteExpiredOIDCSessions(ctx, repository.DeleteExpiredOIDCSessionsParams{
// For the OIDC tokens, if they are expired we delete the userinfo and codes
expiredTokens, err := service.queries.DeleteExpiredOidcTokens(service.context, repository.DeleteExpiredOidcTokensParams{
TokenExpiresAt: currentTime,
RefreshTokenExpiresAt: currentTime,
})
if err != nil {
service.log.App.Warn().Err(err).Msg("Failed to delete expired OIDC sessions")
service.log.App.Warn().Err(err).Msg("Failed to delete expired tokens")
}
for _, expiredToken := range expiredTokens {
err := service.DeleteOldSession(service.context, expiredToken.Sub)
if err != nil {
service.log.App.Warn().Err(err).Msg("Failed to delete session for expired token")
}
}
// For expired codes, we need to get the sub, check if tokens are expired and if they are remove everything
expiredCodes, err := service.queries.DeleteExpiredOidcCodes(service.context, currentTime)
if err != nil {
service.log.App.Warn().Err(err).Msg("Failed to delete expired codes")
}
for _, expiredCode := range expiredCodes {
token, err := service.queries.GetOidcTokenBySub(service.context, expiredCode.Sub)
if err != nil {
if !errors.Is(err, repository.ErrNotFound) {
service.log.App.Warn().Err(err).Msg("Failed to get token by sub for expired code")
}
continue
}
if token.TokenExpiresAt < currentTime && token.RefreshTokenExpiresAt < currentTime {
err := service.DeleteOldSession(service.context, expiredCode.Sub)
if err != nil {
service.log.App.Warn().Err(err).Msg("Failed to delete session for expired code")
}
}
}
service.log.App.Debug().Msg("Finished OIDC cleanup routine")
case <-ctx.Done():
case <-service.context.Done():
service.log.App.Debug().Msg("Stopping OIDC cleanup routine")
return
}
@@ -828,31 +854,3 @@ func (service *OIDCService) hashAndEncodePKCE(codeVerifier string) string {
hasher.Write([]byte(codeVerifier))
return base64.RawURLEncoding.EncodeToString(hasher.Sum(nil))
}
// WARNING: Since Tinyauth is stateless, we cannot have a sub that never changes.
// We will just create a uuid out of the username and client name which remains stable,
// but if username or client name changes then sub changes too.
func (service *OIDCService) CreateSub(userContext model.UserContext, clientId string) string {
return utils.GenerateUUID(fmt.Sprintf("%s:%s", userContext.GetUsername(), clientId))
}
func (service *OIDCService) IsCodeUsed(codeHash string) (string, bool) {
entry, ok := service.caches.usedCode.Get(codeHash)
if !ok {
return "", false
}
return entry.Sub, true
}
func (service *OIDCService) MarkCodeAsUsed(codeHash string, sub string) {
entry := UsedCodeEntry{
Sub: sub,
}
service.caches.usedCode.Set(codeHash, entry, 2*time.Minute)
}
func (service *OIDCService) DeleteSessionBySub(ctx context.Context, sub string) error {
return service.queries.DeleteOIDCSessionBySub(ctx, sub)
}
+46 -25
View File
@@ -2,24 +2,36 @@ package service_test
import (
"context"
"encoding/json"
"sync"
"testing"
"github.com/steveiliop56/ding"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
)
func newTestUser() service.UserinfoResponse {
return service.UserinfoResponse{
func newTestUser() repository.OidcUserinfo {
addr := model.AddressClaim{
Formatted: "123 Main St",
StreetAddress: "123 Main St",
Locality: "Springfield",
Region: "IL",
PostalCode: "62701",
Country: "US",
}
addrJSON, _ := json.Marshal(addr)
return repository.OidcUserinfo{
Sub: "test-sub",
Name: "Test User",
PreferredUsername: "testuser",
Email: "test@example.com",
Groups: []string{"admins", "users"},
Groups: "admins,users",
UpdatedAt: 1234567890,
GivenName: "Test",
FamilyName: "User",
@@ -33,14 +45,7 @@ func newTestUser() service.UserinfoResponse {
Zoneinfo: "America/Chicago",
Locale: "en-US",
PhoneNumber: "+15555550100",
Address: &model.AddressClaim{
Formatted: "123 Main St",
StreetAddress: "123 Main St",
Locality: "Springfield",
Region: "IL",
PostalCode: "62701",
Country: "US",
},
Address: string(addrJSON),
}
}
@@ -65,14 +70,14 @@ func TestCompileUserinfo(t *testing.T) {
log.Init()
ctx := context.TODO()
dg := ding.New(ctx)
wg := &sync.WaitGroup{}
svc, err := service.NewOIDCService(log, cfg, runtime, nil, dg)
svc, err := service.NewOIDCService(log, cfg, runtime, nil, ctx, wg)
require.NoError(t, err)
type testCase struct {
description string
mutate func(u *service.UserinfoResponse)
mutate func(u *repository.OidcUserinfo)
scope string
run func(t *testing.T, info service.UserinfoResponse)
}
@@ -93,7 +98,7 @@ func TestCompileUserinfo(t *testing.T) {
},
{
description: "profile scope returns all profile fields",
scope: "openid profile",
scope: "openid,profile",
run: func(t *testing.T, info service.UserinfoResponse) {
assert.Equal(t, "Test User", info.Name)
assert.Equal(t, "testuser", info.PreferredUsername)
@@ -113,7 +118,7 @@ func TestCompileUserinfo(t *testing.T) {
},
{
description: "email scope sets email and email_verified true when email present",
scope: "openid email",
scope: "openid,email",
run: func(t *testing.T, info service.UserinfoResponse) {
assert.Equal(t, "test@example.com", info.Email)
assert.True(t, info.EmailVerified)
@@ -122,8 +127,8 @@ func TestCompileUserinfo(t *testing.T) {
},
{
description: "email scope sets email_verified false when email absent",
scope: "openid email",
mutate: func(u *service.UserinfoResponse) { u.Email = "" },
scope: "openid,email",
mutate: func(u *repository.OidcUserinfo) { u.Email = "" },
run: func(t *testing.T, info service.UserinfoResponse) {
assert.Empty(t, info.Email)
assert.False(t, info.EmailVerified)
@@ -131,7 +136,7 @@ func TestCompileUserinfo(t *testing.T) {
},
{
description: "phone scope sets phone_number_verified true when phone present",
scope: "openid phone",
scope: "openid,phone",
run: func(t *testing.T, info service.UserinfoResponse) {
assert.Equal(t, "+15555550100", info.PhoneNumber)
require.NotNil(t, info.PhoneNumberVerified)
@@ -140,8 +145,8 @@ func TestCompileUserinfo(t *testing.T) {
},
{
description: "phone scope sets phone_number_verified false when phone absent",
scope: "openid phone",
mutate: func(u *service.UserinfoResponse) { u.PhoneNumber = "" },
scope: "openid,phone",
mutate: func(u *repository.OidcUserinfo) { u.PhoneNumber = "" },
run: func(t *testing.T, info service.UserinfoResponse) {
require.NotNil(t, info.PhoneNumberVerified)
assert.False(t, *info.PhoneNumberVerified)
@@ -149,7 +154,7 @@ func TestCompileUserinfo(t *testing.T) {
},
{
description: "address scope returns parsed address",
scope: "openid address",
scope: "openid,address",
run: func(t *testing.T, info service.UserinfoResponse) {
require.NotNil(t, info.Address)
assert.Equal(t, "123 Main St", info.Address.Formatted)
@@ -160,16 +165,32 @@ func TestCompileUserinfo(t *testing.T) {
assert.Equal(t, "US", info.Address.Country)
},
},
{
description: "address scope with invalid JSON omits address",
scope: "openid,address",
mutate: func(u *repository.OidcUserinfo) { u.Address = "not-valid-json" },
run: func(t *testing.T, info service.UserinfoResponse) {
assert.Nil(t, info.Address)
},
},
{
description: "groups scope returns split groups",
scope: "openid groups",
scope: "openid,groups",
run: func(t *testing.T, info service.UserinfoResponse) {
assert.Equal(t, []string{"admins", "users"}, info.Groups)
},
},
{
description: "groups scope returns empty slice when no groups",
scope: "openid,groups",
mutate: func(u *repository.OidcUserinfo) { u.Groups = "" },
run: func(t *testing.T, info service.UserinfoResponse) {
assert.Equal(t, []string{}, info.Groups)
},
},
{
description: "all scopes return all fields",
scope: "openid profile email phone address groups",
scope: "openid,profile,email,phone,address,groups",
run: func(t *testing.T, info service.UserinfoResponse) {
assert.Equal(t, "Test User", info.Name)
assert.Equal(t, "test@example.com", info.Email)
-4
View File
@@ -108,7 +108,3 @@ func (engine *PolicyEngine) Policy() Policy {
func (engine *PolicyEngine) Rules() map[RuleName]Rule {
return engine.rules
}
func (engine *PolicyEngine) EvaluateFunc(f func() Effect) bool {
return engine.effectToAccess(f())
}
+9 -15
View File
@@ -9,7 +9,6 @@ import (
"sync"
"time"
"github.com/steveiliop56/ding"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
"tailscale.com/client/local"
@@ -21,10 +20,12 @@ type TailscaleWhoisResponse struct {
LoginName string
DisplayName string
NodeName string
Tags []string
}
type TailscaleService struct {
log *logger.Logger
wg *sync.WaitGroup
config model.Config
ctx context.Context
@@ -34,7 +35,7 @@ type TailscaleService struct {
mu sync.Mutex
}
func NewTailscaleService(log *logger.Logger, config model.Config, ctx context.Context, dg *ding.Ding) (*TailscaleService, error) {
func NewTailscaleService(log *logger.Logger, config model.Config, ctx context.Context, wg *sync.WaitGroup) (*TailscaleService, error) {
if !config.Tailscale.Enabled {
return nil, nil
}
@@ -66,6 +67,7 @@ func NewTailscaleService(log *logger.Logger, config model.Config, ctx context.Co
service := &TailscaleService{
log: log,
wg: wg,
config: config,
ctx: ctx,
srv: srv,
@@ -82,13 +84,13 @@ func NewTailscaleService(log *logger.Logger, config model.Config, ctx context.Co
return nil, fmt.Errorf("failed to connect to tailscale network: %w", err)
}
dg.Go(service.watchAndClose, ding.RingMajor)
wg.Go(service.watchAndClose)
return service, nil
}
func (ts *TailscaleService) watchAndClose(ctx context.Context) {
<-ctx.Done()
func (ts *TailscaleService) watchAndClose() {
<-ts.ctx.Done()
ts.log.App.Debug().Msg("Shutting down Tailscale service")
ts.mu.Lock()
srv := ts.srv
@@ -114,22 +116,14 @@ func (ts *TailscaleService) Whois(ctx context.Context, addr string) (*TailscaleW
return nil, fmt.Errorf("failed to get client whois: %w", err)
}
if who.Node.IsTagged() {
ts.log.App.Debug().Msgf("Skipping whois for tagged node %s", who.Node.Name)
return nil, nil
}
uid := strings.TrimPrefix(who.UserProfile.ID.String(), "userid:")
res := TailscaleWhoisResponse{
UserID: uid,
UserID: who.UserProfile.ID.String(),
LoginName: who.UserProfile.LoginName,
DisplayName: who.UserProfile.DisplayName,
NodeName: strings.TrimSuffix(who.Node.Name, "."),
Tags: who.Node.Tags,
}
ts.log.App.Debug().Interface("res", res).Msg("tailscale")
return &res, nil
}
+5 -2
View File
@@ -3,6 +3,7 @@ package loaders
import (
"os"
"github.com/rs/zerolog/log"
"github.com/tinyauthapp/paerser/cli"
"github.com/tinyauthapp/paerser/file"
"github.com/tinyauthapp/paerser/flag"
@@ -18,8 +19,8 @@ func (f *FileLoader) Load(args []string, cmd *cli.Command) (bool, error) {
}
// I guess we are using traefik as the root name (we can't change it)
configFileFlag := "traefik.configfile"
envVar := "TINYAUTH_CONFIGFILE"
configFileFlag := "traefik.experimental.configfile"
envVar := "TINYAUTH_EXPERIMENTAL_CONFIGFILE"
if _, ok := flags[configFileFlag]; !ok {
if value := os.Getenv(envVar); value != "" {
@@ -29,6 +30,8 @@ func (f *FileLoader) Load(args []string, cmd *cli.Command) (bool, error) {
}
}
log.Warn().Msg("Using experimental file config loader, this feature is experimental and may change or be removed in future releases")
err = file.Decode(flags[configFileFlag], cmd.Configuration)
if err != nil {
+1 -6
View File
@@ -3,7 +3,6 @@ package utils
import (
"crypto/rand"
"encoding/base64"
"errors"
"fmt"
"net"
"regexp"
@@ -12,10 +11,6 @@ import (
"github.com/google/uuid"
)
var (
ErrFilterEmpty = errors.New("filter is empty")
)
func GetSecret(conf string, file string) string {
if conf == "" && file == "" {
return ""
@@ -83,7 +78,7 @@ func CheckIPFilter(filter string, ip string) (bool, error) {
func CheckFilter(filter string, input string) (bool, error) {
if len(strings.TrimSpace(filter)) == 0 {
return false, ErrFilterEmpty
return false, fmt.Errorf("filter is empty")
}
if strings.HasPrefix(filter, "/") && strings.HasSuffix(filter, "/") {
-48
View File
@@ -1,48 +0,0 @@
-- name: GetOIDCSessionBySub :one
SELECT * FROM "oidc_sessions"
WHERE "sub" = $1;
-- name: GetOIDCSessionByAccessTokenHash :one
SELECT * FROM "oidc_sessions"
WHERE "access_token_hash" = $1;
-- name: GetOIDCSessionByRefreshTokenHash :one
SELECT * FROM "oidc_sessions"
WHERE "refresh_token_hash" = $1;
-- name: CreateOIDCSession :one
INSERT INTO "oidc_sessions" (
"sub",
"access_token_hash",
"refresh_token_hash",
"scope",
"client_id",
"token_expires_at",
"refresh_token_expires_at",
"nonce",
"userinfo_json"
) VALUES (
$1, $2, $3, $4, $5, $6, $7, $8, $9
)
RETURNING *;
-- name: DeleteOIDCSessionBySub :exec
DELETE FROM "oidc_sessions"
WHERE "sub" = $1;
-- name: DeleteExpiredOIDCSessions :exec
DELETE FROM "oidc_sessions"
WHERE "token_expires_at" < $1 AND "refresh_token_expires_at" < $2;
-- name: UpdateOIDCSession :one
UPDATE "oidc_sessions" SET
"access_token_hash" = $1,
"refresh_token_hash" = $2,
"scope" = $3,
"client_id" = $4,
"token_expires_at" = $5,
"refresh_token_expires_at" = $6,
"nonce" = $7,
"userinfo_json" = $8
WHERE "sub" = $9
RETURNING *;
-11
View File
@@ -1,11 +0,0 @@
CREATE TABLE IF NOT EXISTS "oidc_sessions" (
"sub" TEXT NOT NULL UNIQUE PRIMARY KEY,
"access_token_hash" TEXT NOT NULL UNIQUE,
"refresh_token_hash" TEXT NOT NULL UNIQUE,
"scope" TEXT NOT NULL,
"client_id" TEXT NOT NULL,
"token_expires_at" BIGINT NOT NULL,
"refresh_token_expires_at" BIGINT NOT NULL,
"nonce" TEXT NOT NULL DEFAULT '',
"userinfo_json" TEXT NOT NULL
);
-43
View File
@@ -1,43 +0,0 @@
-- name: CreateSession :one
INSERT INTO "sessions" (
"uuid",
"username",
"email",
"name",
"provider",
"totp_pending",
"oauth_groups",
"expiry",
"created_at",
"oauth_name",
"oauth_sub"
) VALUES (
$1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11
)
RETURNING *;
-- name: GetSession :one
SELECT * FROM "sessions"
WHERE "uuid" = $1;
-- name: DeleteSession :exec
DELETE FROM "sessions"
WHERE "uuid" = $1;
-- name: UpdateSession :one
UPDATE "sessions" SET
"username" = $1,
"email" = $2,
"name" = $3,
"provider" = $4,
"totp_pending" = $5,
"oauth_groups" = $6,
"expiry" = $7,
"oauth_name" = $8,
"oauth_sub" = $9
WHERE "uuid" = $10
RETURNING *;
-- name: DeleteExpiredSessions :exec
DELETE FROM "sessions"
WHERE "expiry" < $1;
-13
View File
@@ -1,13 +0,0 @@
CREATE TABLE IF NOT EXISTS "sessions" (
"uuid" TEXT NOT NULL PRIMARY KEY,
"username" TEXT NOT NULL,
"email" TEXT NOT NULL,
"name" TEXT NOT NULL,
"provider" TEXT NOT NULL,
"totp_pending" BOOLEAN NOT NULL,
"oauth_groups" TEXT NOT NULL DEFAULT '',
"expiry" BIGINT NOT NULL,
"created_at" BIGINT NOT NULL,
"oauth_name" TEXT NOT NULL DEFAULT '',
"oauth_sub" TEXT NOT NULL DEFAULT ''
);
+113 -28
View File
@@ -1,17 +1,46 @@
-- name: GetOIDCSessionBySub :one
SELECT * FROM "oidc_sessions"
-- name: CreateOidcCode :one
INSERT INTO "oidc_codes" (
"sub",
"code_hash",
"scope",
"redirect_uri",
"client_id",
"expires_at",
"nonce",
"code_challenge"
) VALUES (
?, ?, ?, ?, ?, ?, ?, ?
)
RETURNING *;
-- name: GetOidcCodeUnsafe :one
SELECT * FROM "oidc_codes"
WHERE "code_hash" = ?;
-- name: GetOidcCode :one
DELETE FROM "oidc_codes"
WHERE "code_hash" = ?
RETURNING *;
-- name: GetOidcCodeBySubUnsafe :one
SELECT * FROM "oidc_codes"
WHERE "sub" = ?;
-- name: GetOIDCSessionByAccessTokenHash :one
SELECT * FROM "oidc_sessions"
WHERE "access_token_hash" = ?;
-- name: GetOidcCodeBySub :one
DELETE FROM "oidc_codes"
WHERE "sub" = ?
RETURNING *;
-- name: GetOIDCSessionByRefreshTokenHash :one
SELECT * FROM "oidc_sessions"
WHERE "refresh_token_hash" = ?;
-- name: DeleteOidcCode :exec
DELETE FROM "oidc_codes"
WHERE "code_hash" = ?;
-- name: CreateOIDCSession :one
INSERT INTO "oidc_sessions" (
-- name: DeleteOidcCodeBySub :exec
DELETE FROM "oidc_codes"
WHERE "sub" = ?;
-- name: CreateOidcToken :one
INSERT INTO "oidc_tokens" (
"sub",
"access_token_hash",
"refresh_token_hash",
@@ -19,30 +48,86 @@ INSERT INTO "oidc_sessions" (
"client_id",
"token_expires_at",
"refresh_token_expires_at",
"nonce",
"userinfo_json"
"code_hash",
"nonce"
) VALUES (
?, ?, ?, ?, ?, ?, ?, ?, ?
)
RETURNING *;
-- name: DeleteOIDCSessionBySub :exec
DELETE FROM "oidc_sessions"
WHERE "sub" = ?;
-- name: DeleteExpiredOIDCSessions :exec
DELETE FROM "oidc_sessions"
WHERE "token_expires_at" < ? AND "refresh_token_expires_at" < ?;
-- name: UpdateOIDCSession :one
UPDATE "oidc_sessions" SET
-- name: UpdateOidcTokenByRefreshToken :one
UPDATE "oidc_tokens" SET
"access_token_hash" = ?,
"refresh_token_hash" = ?,
"scope" = ?,
"client_id" = ?,
"token_expires_at" = ?,
"refresh_token_expires_at" = ?,
"nonce" = ?,
"userinfo_json" = ?
WHERE "sub" = ?
"refresh_token_expires_at" = ?
WHERE "refresh_token_hash" = ?
RETURNING *;
-- name: GetOidcToken :one
SELECT * FROM "oidc_tokens"
WHERE "access_token_hash" = ?;
-- name: GetOidcTokenByRefreshToken :one
SELECT * FROM "oidc_tokens"
WHERE "refresh_token_hash" = ?;
-- name: GetOidcTokenBySub :one
SELECT * FROM "oidc_tokens"
WHERE "sub" = ?;
-- name: DeleteOidcTokenByCodeHash :exec
DELETE FROM "oidc_tokens"
WHERE "code_hash" = ?;
-- name: DeleteOidcToken :exec
DELETE FROM "oidc_tokens"
WHERE "access_token_hash" = ?;
-- name: DeleteOidcTokenBySub :exec
DELETE FROM "oidc_tokens"
WHERE "sub" = ?;
-- name: CreateOidcUserInfo :one
INSERT INTO "oidc_userinfo" (
"sub",
"name",
"preferred_username",
"email",
"groups",
"updated_at",
"given_name",
"family_name",
"middle_name",
"nickname",
"profile",
"picture",
"website",
"gender",
"birthdate",
"zoneinfo",
"locale",
"phone_number",
"address"
) VALUES (
?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?
)
RETURNING *;
-- name: GetOidcUserInfo :one
SELECT * FROM "oidc_userinfo"
WHERE "sub" = ?;
-- name: DeleteOidcUserInfo :exec
DELETE FROM "oidc_userinfo"
WHERE "sub" = ?;
-- name: DeleteExpiredOidcCodes :many
DELETE FROM "oidc_codes"
WHERE "expires_at" < ?
RETURNING *;
-- name: DeleteExpiredOidcTokens :many
DELETE FROM "oidc_tokens"
WHERE "token_expires_at" < ? AND "refresh_token_expires_at" < ?
RETURNING *;
+39 -6
View File
@@ -1,11 +1,44 @@
CREATE TABLE IF NOT EXISTS "oidc_sessions" (
"sub" TEXT NOT NULL UNIQUE PRIMARY KEY,
"access_token_hash" TEXT NOT NULL UNIQUE,
"refresh_token_hash" TEXT NOT NULL UNIQUE,
CREATE TABLE IF NOT EXISTS "oidc_codes" (
"sub" TEXT NOT NULL UNIQUE,
"code_hash" TEXT NOT NULL PRIMARY KEY UNIQUE,
"scope" TEXT NOT NULL,
"redirect_uri" TEXT NOT NULL,
"client_id" TEXT NOT NULL,
"expires_at" INTEGER NOT NULL,
"nonce" TEXT DEFAULT "",
"code_challenge" TEXT DEFAULT ""
);
CREATE TABLE IF NOT EXISTS "oidc_tokens" (
"sub" TEXT NOT NULL UNIQUE,
"access_token_hash" TEXT NOT NULL PRIMARY KEY UNIQUE,
"refresh_token_hash" TEXT NOT NULL,
"code_hash" TEXT NOT NULL,
"scope" TEXT NOT NULL,
"client_id" TEXT NOT NULL,
"token_expires_at" INTEGER NOT NULL,
"refresh_token_expires_at" INTEGER NOT NULL,
"nonce" TEXT NOT NULL DEFAULT "",
"userinfo_json" TEXT NOT NULL
"nonce" TEXT DEFAULT ""
);
CREATE TABLE IF NOT EXISTS "oidc_userinfo" (
"sub" TEXT NOT NULL UNIQUE PRIMARY KEY,
"name" TEXT NOT NULL,
"preferred_username" TEXT NOT NULL,
"email" TEXT NOT NULL,
"groups" TEXT NOT NULL,
"updated_at" INTEGER NOT NULL,
"given_name" TEXT NOT NULL,
"family_name" TEXT NOT NULL,
"middle_name" TEXT NOT NULL,
"nickname" TEXT NOT NULL,
"profile" TEXT NOT NULL,
"picture" TEXT NOT NULL,
"website" TEXT NOT NULL,
"gender" TEXT NOT NULL,
"birthdate" TEXT NOT NULL,
"zoneinfo" TEXT NOT NULL,
"locale" TEXT NOT NULL,
"phone_number" TEXT NOT NULL,
"address" TEXT NOT NULL
);
+5 -14
View File
@@ -22,18 +22,9 @@ sql:
go_type: "string"
- column: "sessions.ldap_groups"
go_type: "string"
- column: "oidc_sessions.nonce"
- column: "oidc_codes.nonce"
go_type: "string"
- column: "oidc_tokens.nonce"
go_type: "string"
- column: "oidc_codes.code_challenge"
go_type: "string"
- engine: "postgresql"
queries: "sql/postgres/*_queries.sql"
schema: "sql/postgres/*_schemas.sql"
gen:
go:
package: "postgres"
out: "internal/repository/postgres"
rename:
uuid: "UUID"
oauth_groups: "OAuthGroups"
oauth_name: "OAuthName"
oauth_sub: "OAuthSub"
redirect_uri: "RedirectURI"