Compare commits

..

9 Commits

Author SHA1 Message Date
Stavros
407ca9a047 chore: set permissions for integration workflow 2026-03-14 20:28:05 +02:00
Stavros
a3d310667b feat: add workflow for ci 2026-03-14 20:24:37 +02:00
Stavros
bd57c480f3 chore: fix merge typo 2026-03-14 19:57:50 +02:00
Stavros
fcf0df348a Merge branch 'main' into feat/proxy-integration-tests 2026-03-14 19:57:13 +02:00
Stavros
1feffaa4d2 fix: fix envoy tests 2026-03-14 12:09:18 +02:00
Stavros
c2a57b57b8 Merge branch 'refactor/proxy-controller' into feat/proxy-integration-tests 2026-03-13 23:08:18 +02:00
Stavros
5b2fa47c0e wip 2026-03-13 23:04:47 +02:00
Stavros
13bb3ca682 fix: add extauthz to friendly error messages 2026-03-13 23:03:11 +02:00
Stavros
81cb695d0b wip 2026-03-13 17:45:11 +02:00
31 changed files with 1272 additions and 750 deletions

82
.github/workflows/integration.yml vendored Normal file
View File

@@ -0,0 +1,82 @@
name: Proxy Integration Tests
on:
workflow_dispatch:
inputs:
build:
type: boolean
description: Build from current tag before running tests
permissions:
contents: read
jobs:
determine-tag:
runs-on: ubuntu-latest
outputs:
version: ${{ steps.metadata.outputs.VERSION }}
commit_hash: ${{ steps.metadata.outputs.COMMIT_HASH }}
build_timestamp: ${{ steps.metadata.outputs.BUILD_TIMESTAMP }}
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Determine metadata
id: metadata
run: |
# Generate static metadata
echo "COMMIT_HASH=$(git rev-parse HEAD)" >> "$GITHUB_OUTPUT"
echo "BUILD_TIMESTAMP=$(date '+%Y-%m-%dT%H:%M:%S')" >> "$GITHUB_OUTPUT"
# Determine version
if [ "${{ inputs.build }}" == "true" ]; then
echo "VERSION=development" >> "$GITHUB_OUTPUT"
else
echo "VERSION=${{ github.ref_name }}" >> "$GITHUB_OUTPUT"
fi
integration:
runs-on: ubuntu-latest
needs: determine-tag
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Setup Go
uses: actions/setup-go@v6
with:
go-version: "1.25.0"
- name: Initialize submodules
run: |
git submodule init
git submodule update
- name: Apply patches
run: |
git apply --directory paerser/ patches/nested_maps.diff
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
if: ${{ inputs.build == 'true' }}
- name: Build
uses: docker/build-push-action@v6
if: ${{ inputs.build == 'true' }}
with:
platforms: linux/amd64
tags: ghcr.io/${{ github.repository_owner }}/tinyauth:${{ needs.determine-tag.outputs.version }}
outputs: type=image,push=false
cache-from: type=gha
cache-to: type=gha,mode=max
build-args: |
VERSION=${{ needs.determine-tag.outputs.version }}
COMMIT_HASH=${{ needs.determine-tag.outputs.commit_hash }}
BUILD_TIMESTAMP=${{ needs.determine-tag.outputs.build_timestamp }}
- name: Set tinyauth version
run: |
sed -i "s/TINYAUTH_VERSION=.*/TINYAUTH_VERSION=${{ needs.determine-tag.outputs.version }}/" integration/.env
- name: Test
run: make integration

View File

@@ -81,5 +81,11 @@ sql:
sqlc generate
# Go gen
.PHONY: generate
generate:
go run ./gen
# Proxy integration tests
.PHONY: integration
integration:
go run ./integration -- --log=false

View File

@@ -37,7 +37,7 @@
"devDependencies": {
"@eslint/js": "^10.0.1",
"@tanstack/eslint-plugin-query": "^5.91.4",
"@types/node": "^25.5.0",
"@types/node": "^25.4.0",
"@types/react": "^19.2.14",
"@types/react-dom": "^19.2.3",
"@vitejs/plugin-react": "^5.1.4",
@@ -417,7 +417,7 @@
"@types/ms": ["@types/ms@2.1.0", "", {}, "sha512-GsCCIZDE/p3i96vtEqx+7dBUGXrc7zeSK3wwPHIaRThS+9OhWIXRqzs4d6k1SVU8g91DrNRWxWUGhp5KXQb2VA=="],
"@types/node": ["@types/node@25.5.0", "", { "dependencies": { "undici-types": "~7.18.0" } }, "sha512-jp2P3tQMSxWugkCUKLRPVUpGaL5MVFwF8RDuSRztfwgN1wmqJeMSbKlnEtQqU8UrhTmzEmZdu2I6v2dpp7XIxw=="],
"@types/node": ["@types/node@25.4.0", "", { "dependencies": { "undici-types": "~7.18.0" } }, "sha512-9wLpoeWuBlcbBpOY3XmzSTG3oscB6xjBEEtn+pYXTfhyXhIxC5FsBer2KTopBlvKEiW9l13po9fq+SJY/5lkhw=="],
"@types/react": ["@types/react@19.2.14", "", { "dependencies": { "csstype": "^3.2.2" } }, "sha512-ilcTH/UniCkMdtexkoCN0bI7pMcJDvmQFPvuPvmEaYA/NSfFTAgdUSLAoVjaRJm7+6PvcM+q1zYOwS4wTYMF9w=="],

View File

@@ -43,7 +43,7 @@
"devDependencies": {
"@eslint/js": "^10.0.1",
"@tanstack/eslint-plugin-query": "^5.91.4",
"@types/node": "^25.5.0",
"@types/node": "^25.4.0",
"@types/react": "^19.2.14",
"@types/react-dom": "^19.2.3",
"@vitejs/plugin-react": "^5.1.4",

16
integration/.env Normal file
View File

@@ -0,0 +1,16 @@
# Nginx configuration
NGINX_VERSION=1.29
# Whoami configuration
WHOAMI_VERSION=latest
# Traefik configuration
TRAEFIK_VERSION=v3.6
TINYAUTH_HOST=tinyauth.127.0.0.1.sslip.io
WHOAMI_HOST=whoami.127.0.0.1.sslip.io
# Envoy configuration
ENVOY_VERSION=v1.33-latest
# Tinyauth configuration
TINYAUTH_VERSION=latest

15
integration/config.yml Normal file
View File

@@ -0,0 +1,15 @@
appUrl: http://tinyauth.127.0.0.1.sslip.io
auth:
users: test:$2a$10$eG88kFow83l5YRSlTSL2o.sZimjxFHrpiKdaSUZqpLBGX7Y2.4PZG
log:
json: true
level: trace
apps:
whoami:
config:
domain: whoami.127.0.0.1.sslip.io
path:
allow: /allow

View File

@@ -0,0 +1,16 @@
services:
envoy:
image: envoyproxy/envoy:${ENVOY_VERSION}
ports:
- 80:80
volumes:
- ./envoy.yml:/etc/envoy/envoy.yaml:ro
whoami:
image: traefik/whoami:${WHOAMI_VERSION}
tinyauth:
image: ghcr.io/steveiliop56/tinyauth:${TINYAUTH_VERSION}
command: --experimental.configfile=/data/config.yml
volumes:
- ./config.yml:/data/config.yml:ro

View File

@@ -0,0 +1,16 @@
services:
nginx:
image: nginx:${NGINX_VERSION}
ports:
- 80:80
volumes:
- ./nginx.conf:/etc/nginx/conf.d/default.conf:ro
whoami:
image: traefik/whoami:${WHOAMI_VERSION}
tinyauth:
image: ghcr.io/steveiliop56/tinyauth:${TINYAUTH_VERSION}
command: --experimental.configfile=/data/config.yml
volumes:
- ./config.yml:/data/config.yml:ro

View File

@@ -0,0 +1,28 @@
services:
traefik:
image: traefik:${TRAEFIK_VERSION}
command: |
--api.insecure=true
--providers.docker
--entryPoints.web.address=:80
ports:
- 80:80
volumes:
- /var/run/docker.sock:/var/run/docker.sock
whoami:
image: traefik/whoami:${WHOAMI_VERSION}
labels:
traefik.enable: true
traefik.http.routers.whoami.rule: Host(`${WHOAMI_HOST}`)
traefik.http.routers.whoami.middlewares: tinyauth
tinyauth:
image: ghcr.io/steveiliop56/tinyauth:${TINYAUTH_VERSION}
command: --experimental.configfile=/data/config.yml
volumes:
- ./config.yml:/data/config.yml:ro
labels:
traefik.enable: true
traefik.http.routers.tinyauth.rule: Host(`${TINYAUTH_HOST}`)
traefik.http.middlewares.tinyauth.forwardauth.address: http://tinyauth:3000/api/auth/traefik

105
integration/envoy.yml Normal file
View File

@@ -0,0 +1,105 @@
static_resources:
listeners:
- name: "listener_http"
address:
socket_address:
address: "0.0.0.0"
port_value: 80
filter_chains:
- filters:
- name: "envoy.filters.network.http_connection_manager"
typed_config:
"@type": "type.googleapis.com/envoy.extensions.filters.network.http_connection_manager.v3.HttpConnectionManager"
stat_prefix: "ingress_http"
use_remote_address: true
skip_xff_append: false
route_config:
name: "local_route"
virtual_hosts:
- name: "whoami_service"
domains: ["whoami.127.0.0.1.sslip.io"]
routes:
- match:
prefix: "/"
route:
cluster: "whoami"
- name: "tinyauth_service"
domains: ["tinyauth.127.0.0.1.sslip.io"]
typed_per_filter_config:
envoy.filters.http.ext_authz:
"@type": "type.googleapis.com/envoy.extensions.filters.http.ext_authz.v3.ExtAuthzPerRoute"
disabled: true
routes:
- match:
prefix: "/"
route:
cluster: "tinyauth"
http_filters:
- name: "envoy.filters.http.ext_authz"
typed_config:
"@type": "type.googleapis.com/envoy.extensions.filters.http.ext_authz.v3.ExtAuthz"
transport_api_version: "v3"
http_service:
path_prefix: "/api/auth/envoy?path="
server_uri:
uri: "tinyauth:3000"
cluster: "tinyauth"
timeout: "0.25s"
authorization_request:
allowed_headers:
patterns:
- exact: "authorization"
- exact: "accept"
- exact: "cookie"
headers_to_add:
- key: "x-forwarded-proto"
value: "%REQ(:SCHEME)%"
- key: "x-forwarded-host"
value: "%REQ(:AUTHORITY)%"
- key: "x-forwarded-uri"
value: "%REQ(:PATH)%"
authorization_response:
allowed_upstream_headers:
patterns:
- prefix: "remote-"
allowed_client_headers:
patterns:
- exact: "set-cookie"
- exact: "location"
allowed_client_headers_on_success:
patterns:
- exact: "set-cookie"
- exact: "location"
failure_mode_allow: false
- name: "envoy.filters.http.router"
typed_config:
"@type": "type.googleapis.com/envoy.extensions.filters.http.router.v3.Router"
clusters:
- name: "whoami"
connect_timeout: "0.25s"
type: "logical_dns"
dns_lookup_family: "v4_only"
lb_policy: "round_robin"
load_assignment:
cluster_name: "whoami"
endpoints:
- lb_endpoints:
- endpoint:
address:
socket_address:
address: "whoami"
port_value: 80
- name: "tinyauth"
connect_timeout: "0.25s"
type: "logical_dns"
dns_lookup_family: "v4_only"
lb_policy: "round_robin"
load_assignment:
cluster_name: "tinyauth"
endpoints:
- lb_endpoints:
- endpoint:
address:
socket_address:
address: "tinyauth"
port_value: 3000

View File

@@ -0,0 +1,62 @@
package main
import (
"fmt"
"io"
"net/http"
"strings"
)
func testUnauthorized(client *http.Client) error {
req, err := http.NewRequest("GET", WhoamiURL, nil)
if err != nil {
return err
}
resp, err := client.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return err
}
// nginx and envoy will throw us at the frontend
if resp.StatusCode != http.StatusUnauthorized && !strings.Contains(string(body), "<div id=\"root\"></div>") {
return fmt.Errorf("expected status code %d or to to contain '<div id=\"root\"></div>', got %d", http.StatusUnauthorized, resp.StatusCode)
}
return nil
}
func testLoggedIn(client *http.Client) error {
req, err := http.NewRequest("GET", WhoamiURL, nil)
if err != nil {
return err
}
req.SetBasicAuth(DefaultUsername, DefaultPassword)
resp, err := client.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("expected status code %d, got %d", http.StatusOK, resp.StatusCode)
}
return nil
}
func testACLAllowed(client *http.Client) error {
req, err := http.NewRequest("GET", WhoamiURL+"/allow", nil)
if err != nil {
return err
}
resp, err := client.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("expected status code %d, got %d", http.StatusOK, resp.StatusCode)
}
return nil
}

191
integration/integration.go Normal file
View File

@@ -0,0 +1,191 @@
package main
import (
"bufio"
"context"
"errors"
"flag"
"fmt"
"log/slog"
"net/http"
"os"
"os/exec"
"path"
"time"
)
var ProxiesToTest = []string{"traefik", "nginx", "envoy"}
const (
EnvFile = ".env"
ConfigFile = "config.yml"
)
const (
TinyauthURL = "http://tinyauth.127.0.0.1.sslip.io"
WhoamiURL = "http://whoami.127.0.0.1.sslip.io"
)
const (
DefaultUsername = "test"
DefaultPassword = "password"
)
func main() {
logFlag := flag.Bool("log", false, "enable stack logging")
flag.Parse()
rootFolder, err := os.Getwd()
if err != nil {
slog.Error("fail", "error", err)
os.Exit(1)
}
slog.Info("root folder", "folder", rootFolder)
integrationRoot := rootFolder
if _, err := os.Stat(path.Join(rootFolder, ".git")); err != nil {
if !errors.Is(err, os.ErrNotExist) {
slog.Error("fail", "error", err)
os.Exit(1)
}
} else {
integrationRoot = path.Join(rootFolder, "integration")
}
slog.Info("integration root", "folder", integrationRoot)
for _, proxy := range ProxiesToTest {
slog.Info("begin", "proxy", proxy)
compose := fmt.Sprintf("docker-compose.%s.yml", proxy)
if _, err := os.Stat(path.Join(integrationRoot, compose)); err != nil {
slog.Error("fail", "proxy", proxy, "error", err)
os.Exit(1)
}
if err := createInstanceAndRunTests(compose, *logFlag, proxy, integrationRoot); err != nil {
slog.Error("fail", "proxy", proxy, "error", err)
os.Exit(1)
}
slog.Info("end", "proxy", proxy)
}
}
func runTests(client *http.Client, name string) error {
if err := testUnauthorized(client); err != nil {
slog.Error("fail unauthorized test", "name", name)
return err
}
slog.Info("unauthorized test passed", "name", name)
if err := testLoggedIn(client); err != nil {
slog.Error("fail logged in test", "name", name)
return err
}
slog.Info("logged in test passed", "name", name)
if err := testACLAllowed(client); err != nil {
slog.Error("fail acl test", "name", name)
return err
}
slog.Info("acl test passed", "name", name)
return nil
}
func createInstanceAndRunTests(compose string, log bool, name string, integrationDir string) error {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
composeFile := path.Join(integrationDir, compose)
envFile := path.Join(integrationDir, EnvFile)
cmdArgs := []string{"compose", "-f", composeFile, "--env-file", envFile, "up", "--build", "--force-recreate", "--remove-orphans"}
cmd := exec.CommandContext(ctx, "docker", cmdArgs...)
if log {
setupCmdLogging(cmd)
}
if err := cmd.Start(); err != nil {
return err
}
defer func() {
cmd.Process.Signal(os.Interrupt)
cmd.Wait()
}()
slog.Info("instance created", "name", name)
if err := waitForHealthy(); err != nil {
return err
}
slog.Info("instance healthy", "name", name)
client := &http.Client{}
if err := runTests(client, name); err != nil {
return err
}
slog.Info("tests passed", "name", name)
cmd.Process.Signal(os.Interrupt)
if err := cmd.Wait(); cmd.ProcessState.ExitCode() != 130 && err != nil {
return err
}
return nil
}
func setupCmdLogging(cmd *exec.Cmd) {
stdout, _ := cmd.StdoutPipe()
stderr, _ := cmd.StderrPipe()
go func() {
scanner := bufio.NewScanner(stdout)
for scanner.Scan() {
slog.Info("docker out", "stdout", scanner.Text())
}
}()
go func() {
scanner := bufio.NewScanner(stderr)
for scanner.Scan() {
slog.Error("docker out", "stderr", scanner.Text())
}
}()
}
func waitForHealthy() error {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer cancel()
ticker := time.NewTicker(10 * time.Second)
client := http.Client{}
for {
select {
case <-ctx.Done():
return errors.New("tinyauth not healthy after 5 minutes")
case <-ticker.C:
req, err := http.NewRequestWithContext(ctx, "GET", TinyauthURL+"/api/healthz", nil)
if err != nil {
return err
}
res, err := client.Do(req)
if err != nil {
continue
}
if res.StatusCode == http.StatusOK {
return nil
}
}
}
}

36
integration/nginx.conf Normal file
View File

@@ -0,0 +1,36 @@
server {
listen 80;
listen [::]:80;
server_name tinyauth.127.0.0.1.sslip.io;
location / {
proxy_pass http://tinyauth:3000;
}
}
server {
listen 80;
listen [::]:80;
server_name whoami.127.0.0.1.sslip.io;
location / {
proxy_pass http://whoami;
auth_request /tinyauth;
error_page 401 = @tinyauth_login;
}
location /tinyauth {
internal;
proxy_pass http://tinyauth:3000/api/auth/nginx;
proxy_set_header x-original-url $scheme://$http_host$request_uri;
proxy_set_header x-forwarded-proto $scheme;
proxy_set_header x-forwarded-host $host;
proxy_set_header x-forwarded-uri $request_uri;
}
location @tinyauth_login {
return 302 http://tinyauth.127.0.0.1.sslip.io/login?redirect_uri=$scheme://$http_host$request_uri;
}
}

View File

@@ -28,7 +28,6 @@ type BootstrapApp struct {
sessionCookieName string
csrfCookieName string
redirectCookieName string
oauthSessionCookieName string
users []config.User
oauthProviders map[string]config.OAuthServiceConfig
configuredProviders []controller.Provider
@@ -114,7 +113,6 @@ func (app *BootstrapApp) Setup() error {
app.context.sessionCookieName = fmt.Sprintf("%s-%s", config.SessionCookieName, cookieId)
app.context.csrfCookieName = fmt.Sprintf("%s-%s", config.CSRFCookieName, cookieId)
app.context.redirectCookieName = fmt.Sprintf("%s-%s", config.RedirectCookieName, cookieId)
app.context.oauthSessionCookieName = fmt.Sprintf("%s-%s", config.OAuthSessionCookieName, cookieId)
// Dumps
tlog.App.Trace().Interface("config", app.config).Msg("Config dump")
@@ -192,12 +190,12 @@ func (app *BootstrapApp) Setup() error {
// Start db cleanup routine
tlog.App.Debug().Msg("Starting database cleanup routine")
go app.dbCleanupRoutine(queries)
go app.dbCleanup(queries)
// If analytics are not disabled, start heartbeat
if app.config.Analytics.Enabled {
tlog.App.Debug().Msg("Starting heartbeat routine")
go app.heartbeatRoutine()
go app.heartbeat()
}
// If we have an socket path, bind to it
@@ -228,7 +226,7 @@ func (app *BootstrapApp) Setup() error {
return nil
}
func (app *BootstrapApp) heartbeatRoutine() {
func (app *BootstrapApp) heartbeat() {
ticker := time.NewTicker(time.Duration(12) * time.Hour)
defer ticker.Stop()
@@ -282,7 +280,7 @@ func (app *BootstrapApp) heartbeatRoutine() {
}
}
func (app *BootstrapApp) dbCleanupRoutine(queries *repository.Queries) {
func (app *BootstrapApp) dbCleanup(queries *repository.Queries) {
ticker := time.NewTicker(time.Duration(30) * time.Minute)
defer ticker.Stop()
ctx := context.Background()

View File

@@ -82,8 +82,7 @@ func (app *BootstrapApp) setupRouter() (*gin.Engine, error) {
CSRFCookieName: app.context.csrfCookieName,
RedirectCookieName: app.context.redirectCookieName,
CookieDomain: app.context.cookieDomain,
OAuthSessionCookieName: app.context.oauthSessionCookieName,
}, apiRouter, app.services.authService)
}, apiRouter, app.services.authService, app.services.oauthBrokerService)
oauthController.SetupRoutes()

View File

@@ -58,16 +58,6 @@ func (app *BootstrapApp) initServices(queries *repository.Queries) (Services, er
services.accessControlService = accessControlsService
oauthBrokerService := service.NewOAuthBrokerService(app.context.oauthProviders)
err = oauthBrokerService.Init()
if err != nil {
return Services{}, err
}
services.oauthBrokerService = oauthBrokerService
authService := service.NewAuthService(service.AuthServiceConfig{
Users: app.context.users,
OauthWhitelist: app.config.OAuth.Whitelist,
@@ -80,7 +70,7 @@ func (app *BootstrapApp) initServices(queries *repository.Queries) (Services, er
SessionCookieName: app.context.sessionCookieName,
IP: app.config.Auth.IP,
LDAPGroupsCacheTTL: app.config.Ldap.GroupCacheTTL,
}, dockerService, services.ldapService, queries, services.oauthBrokerService)
}, dockerService, services.ldapService, queries)
err = authService.Init()
@@ -90,6 +80,16 @@ func (app *BootstrapApp) initServices(queries *repository.Queries) (Services, er
services.authService = authService
oauthBrokerService := service.NewOAuthBrokerService(app.context.oauthProviders)
err = oauthBrokerService.Init()
if err != nil {
return Services{}, err
}
services.oauthBrokerService = oauthBrokerService
oidcService := service.NewOIDCService(service.OIDCServiceConfig{
Clients: app.config.OIDC.Clients,
PrivateKeyPath: app.config.OIDC.PrivateKeyPath,

View File

@@ -73,7 +73,6 @@ var BuildTimestamp = "0000-00-00T00:00:00Z"
var SessionCookieName = "tinyauth-session"
var CSRFCookieName = "tinyauth-csrf"
var RedirectCookieName = "tinyauth-redirect"
var OAuthSessionCookieName = "tinyauth-oauth"
// Main app config

View File

@@ -2,153 +2,152 @@ package controller_test
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/steveiliop56/tinyauth/internal/config"
"github.com/steveiliop56/tinyauth/internal/controller"
"github.com/steveiliop56/tinyauth/internal/utils"
"github.com/steveiliop56/tinyauth/internal/utils/tlog"
"github.com/gin-gonic/gin"
"gotest.tools/v3/assert"
)
func TestContextController(t *testing.T) {
controllerConfig := controller.ContextControllerConfig{
var contextControllerCfg = controller.ContextControllerConfig{
Providers: []controller.Provider{
{
Name: "Local",
ID: "local",
OAuth: false,
},
{
Name: "Google",
ID: "google",
OAuth: true,
},
Title: "Tinyauth",
AppURL: "https://tinyauth.example.com",
CookieDomain: "example.com",
ForgotPasswordMessage: "foo",
BackgroundImage: "/background.jpg",
OAuthAutoRedirect: "none",
},
Title: "Test App",
AppURL: "http://localhost:8080",
CookieDomain: "localhost",
ForgotPasswordMessage: "Contact admin to reset your password.",
BackgroundImage: "/assets/bg.jpg",
OAuthAutoRedirect: "google",
WarningsEnabled: true,
}
}
tests := []struct {
description string
middlewares []gin.HandlerFunc
expected string
path string
}{
{
description: "Ensure context controller returns app context",
middlewares: []gin.HandlerFunc{},
path: "/api/context/app",
expected: func() string {
expectedAppContextResponse := controller.AppContextResponse{
Status: 200,
Message: "Success",
Providers: controllerConfig.Providers,
Title: controllerConfig.Title,
AppURL: controllerConfig.AppURL,
CookieDomain: controllerConfig.CookieDomain,
ForgotPasswordMessage: controllerConfig.ForgotPasswordMessage,
BackgroundImage: controllerConfig.BackgroundImage,
OAuthAutoRedirect: controllerConfig.OAuthAutoRedirect,
WarningsEnabled: controllerConfig.WarningsEnabled,
}
bytes, err := json.Marshal(expectedAppContextResponse)
if err != nil {
t.Fatalf("Failed to marshal expected response: %v", err)
}
return string(bytes)
}(),
},
{
description: "Ensure user context returns 401 when unauthorized",
middlewares: []gin.HandlerFunc{},
path: "/api/context/user",
expected: func() string {
expectedUserContextResponse := controller.UserContextResponse{
Status: 401,
Message: "Unauthorized",
}
bytes, err := json.Marshal(expectedUserContextResponse)
if err != nil {
t.Fatalf("Failed to marshal expected response: %v", err)
}
return string(bytes)
}(),
},
{
description: "Ensure user context returns when authorized",
middlewares: []gin.HandlerFunc{
func(c *gin.Context) {
c.Set("context", &config.UserContext{
Username: "johndoe",
Name: "John Doe",
Email: utils.CompileUserEmail("johndoe", controllerConfig.CookieDomain),
Provider: "local",
var contextCtrlTestContext = config.UserContext{
Username: "testuser",
Name: "testuser",
Email: "test@example.com",
IsLoggedIn: true,
})
},
},
path: "/api/context/user",
expected: func() string {
expectedUserContextResponse := controller.UserContextResponse{
Status: 200,
Message: "Success",
IsLoggedIn: true,
Username: "johndoe",
Name: "John Doe",
Email: utils.CompileUserEmail("johndoe", controllerConfig.CookieDomain),
IsBasicAuth: false,
OAuth: false,
Provider: "local",
}
TotpPending: false,
OAuthGroups: "",
TotpEnabled: false,
OAuthSub: "",
}
bytes, err := json.Marshal(expectedUserContextResponse)
func setupContextController(middlewares *[]gin.HandlerFunc) (*gin.Engine, *httptest.ResponseRecorder) {
tlog.NewSimpleLogger().Init()
if err != nil {
t.Fatalf("Failed to marshal expected response: %v", err)
}
return string(bytes)
}(),
},
}
for _, test := range tests {
t.Run(test.description, func(t *testing.T) {
// Setup
gin.SetMode(gin.TestMode)
router := gin.Default()
recorder := httptest.NewRecorder()
for _, middleware := range test.middlewares {
router.Use(middleware)
if middlewares != nil {
for _, m := range *middlewares {
router.Use(m)
}
}
group := router.Group("/api")
gin.SetMode(gin.TestMode)
contextController := controller.NewContextController(controllerConfig, group)
contextController.SetupRoutes()
ctrl := controller.NewContextController(contextControllerCfg, group)
ctrl.SetupRoutes()
recorder := httptest.NewRecorder()
request, err := http.NewRequest("GET", test.path, nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
router.ServeHTTP(recorder, request)
if recorder.Code != http.StatusOK {
t.Fatalf("Expected status code 200, got %d", recorder.Code)
}
if recorder.Body.String() != test.expected {
t.Fatalf("Expected response body %s, got %s", test.expected, recorder.Body.String())
}
})
}
return router, recorder
}
func TestAppContextHandler(t *testing.T) {
expectedRes := controller.AppContextResponse{
Status: 200,
Message: "Success",
Providers: contextControllerCfg.Providers,
Title: contextControllerCfg.Title,
AppURL: contextControllerCfg.AppURL,
CookieDomain: contextControllerCfg.CookieDomain,
ForgotPasswordMessage: contextControllerCfg.ForgotPasswordMessage,
BackgroundImage: contextControllerCfg.BackgroundImage,
OAuthAutoRedirect: contextControllerCfg.OAuthAutoRedirect,
WarningsEnabled: contextControllerCfg.WarningsEnabled,
}
router, recorder := setupContextController(nil)
req := httptest.NewRequest("GET", "/api/context/app", nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
var ctrlRes controller.AppContextResponse
err := json.Unmarshal(recorder.Body.Bytes(), &ctrlRes)
assert.NilError(t, err)
assert.DeepEqual(t, expectedRes, ctrlRes)
}
func TestUserContextHandler(t *testing.T) {
expectedRes := controller.UserContextResponse{
Status: 200,
Message: "Success",
IsLoggedIn: contextCtrlTestContext.IsLoggedIn,
Username: contextCtrlTestContext.Username,
Name: contextCtrlTestContext.Name,
Email: contextCtrlTestContext.Email,
Provider: contextCtrlTestContext.Provider,
OAuth: contextCtrlTestContext.OAuth,
TotpPending: contextCtrlTestContext.TotpPending,
OAuthName: contextCtrlTestContext.OAuthName,
}
// Test with context
router, recorder := setupContextController(&[]gin.HandlerFunc{
func(c *gin.Context) {
c.Set("context", &contextCtrlTestContext)
c.Next()
},
})
req := httptest.NewRequest("GET", "/api/context/user", nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
var ctrlRes controller.UserContextResponse
err := json.Unmarshal(recorder.Body.Bytes(), &ctrlRes)
assert.NilError(t, err)
assert.DeepEqual(t, expectedRes, ctrlRes)
// Test no context
expectedRes = controller.UserContextResponse{
Status: 401,
Message: "Unauthorized",
IsLoggedIn: false,
}
router, recorder = setupContextController(nil)
req = httptest.NewRequest("GET", "/api/context/user", nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
err = json.Unmarshal(recorder.Body.Bytes(), &ctrlRes)
assert.NilError(t, err)
assert.DeepEqual(t, expectedRes, ctrlRes)
}

View File

@@ -19,7 +19,7 @@ func (controller *HealthController) SetupRoutes() {
func (controller *HealthController) healthHandler(c *gin.Context) {
c.JSON(200, gin.H{
"status": 200,
"status": "ok",
"message": "Healthy",
})
}

View File

@@ -1,88 +0,0 @@
package controller_test
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/steveiliop56/tinyauth/internal/controller"
)
func TestHealthController(t *testing.T) {
tests := []struct {
description string
path string
method string
expected string
}{
{
description: "Ensure health endpoint returns 200 OK",
path: "/api/healthz",
method: "GET",
expected: func() string {
expectedHealthResponse := map[string]any{
"status": 200,
"message": "Healthy",
}
bytes, err := json.Marshal(expectedHealthResponse)
if err != nil {
t.Fatalf("Failed to marshal expected response: %v", err)
}
return string(bytes)
}(),
},
{
description: "Ensure health endpoint returns 200 OK for HEAD request",
path: "/api/healthz",
method: "HEAD",
expected: func() string {
expectedHealthResponse := map[string]any{
"status": 200,
"message": "Healthy",
}
bytes, err := json.Marshal(expectedHealthResponse)
if err != nil {
t.Fatalf("Failed to marshal expected response: %v", err)
}
return string(bytes)
}(),
},
}
for _, test := range tests {
t.Run(test.description, func(t *testing.T) {
router := gin.Default()
group := router.Group("/api")
gin.SetMode(gin.TestMode)
healthController := controller.NewHealthController(group)
healthController.SetupRoutes()
recorder := httptest.NewRecorder()
request, err := http.NewRequest(test.method, test.path, nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
router.ServeHTTP(recorder, request)
if recorder.Code != http.StatusOK {
t.Fatalf("Expected status code 200, got %d", recorder.Code)
}
if recorder.Body.String() != test.expected {
t.Fatalf("Expected response body %s, got %s", test.expected, recorder.Body.String())
}
})
}
}

View File

@@ -22,7 +22,6 @@ type OAuthRequest struct {
type OAuthControllerConfig struct {
CSRFCookieName string
OAuthSessionCookieName string
RedirectCookieName string
SecureCookie bool
AppURL string
@@ -33,13 +32,15 @@ type OAuthController struct {
config OAuthControllerConfig
router *gin.RouterGroup
auth *service.AuthService
broker *service.OAuthBrokerService
}
func NewOAuthController(config OAuthControllerConfig, router *gin.RouterGroup, auth *service.AuthService) *OAuthController {
func NewOAuthController(config OAuthControllerConfig, router *gin.RouterGroup, auth *service.AuthService, broker *service.OAuthBrokerService) *OAuthController {
return &OAuthController{
config: config,
router: router,
auth: auth,
broker: broker,
}
}
@@ -62,30 +63,21 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
return
}
sessionId, session, err := controller.auth.NewOAuthSession(req.Provider)
service, exists := controller.broker.GetService(req.Provider)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to create OAuth session")
c.JSON(500, gin.H{
"status": 500,
"message": "Internal Server Error",
if !exists {
tlog.App.Warn().Msgf("OAuth provider not found: %s", req.Provider)
c.JSON(404, gin.H{
"status": 404,
"message": "Not Found",
})
return
}
authUrl, err := controller.auth.GetOAuthURL(sessionId)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to get OAuth URL")
c.JSON(500, gin.H{
"status": 500,
"message": "Internal Server Error",
})
return
}
c.SetCookie(controller.config.OAuthSessionCookieName, sessionId, int(time.Hour.Seconds()), "/", fmt.Sprintf(".%s", controller.config.CookieDomain), controller.config.SecureCookie, true)
c.SetCookie(controller.config.CSRFCookieName, session.State, int(time.Hour.Seconds()), "/", fmt.Sprintf(".%s", controller.config.CookieDomain), controller.config.SecureCookie, true)
service.GenerateVerifier()
state := service.GenerateState()
authURL := service.GetAuthURL(state)
c.SetCookie(controller.config.CSRFCookieName, state, int(time.Hour.Seconds()), "/", fmt.Sprintf(".%s", controller.config.CookieDomain), controller.config.SecureCookie, true)
redirectURI := c.Query("redirect_uri")
isRedirectSafe := utils.IsRedirectSafe(redirectURI, controller.config.CookieDomain)
@@ -103,7 +95,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
c.JSON(200, gin.H{
"status": 200,
"message": "OK",
"url": authUrl,
"url": authURL,
})
}
@@ -120,17 +112,6 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
return
}
sessionIdCookie, err := c.Cookie(controller.config.OAuthSessionCookieName)
if err != nil {
tlog.App.Warn().Err(err).Msg("OAuth session cookie missing")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
return
}
c.SetCookie(controller.config.OAuthSessionCookieName, "", -1, "/", fmt.Sprintf(".%s", controller.config.CookieDomain), controller.config.SecureCookie, true)
defer controller.auth.EndOAuthSession(sessionIdCookie)
state := c.Query("state")
csrfCookie, err := c.Cookie(controller.config.CSRFCookieName)
@@ -144,15 +125,28 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
c.SetCookie(controller.config.CSRFCookieName, "", -1, "/", fmt.Sprintf(".%s", controller.config.CookieDomain), controller.config.SecureCookie, true)
code := c.Query("code")
_, err = controller.auth.GetOAuthToken(sessionIdCookie, code)
service, exists := controller.broker.GetService(req.Provider)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to exchange code for token")
if !exists {
tlog.App.Warn().Msgf("OAuth provider not found: %s", req.Provider)
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
return
}
user, err := controller.auth.GetOAuthUserinfo(sessionIdCookie)
err = service.VerifyCode(code)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to verify OAuth code")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
return
}
user, err := controller.broker.GetUser(req.Provider)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to get user from OAuth provider")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
return
}
if user.Email == "" {
tlog.App.Error().Msg("OAuth provider did not return an email")
@@ -198,21 +192,13 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
username = strings.Replace(user.Email, "@", "_", 1)
}
service, err := controller.auth.GetOAuthService(sessionIdCookie)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to get OAuth service for session")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
return
}
sessionCookie := repository.Session{
Username: username,
Name: name,
Email: user.Email,
Provider: req.Provider,
OAuthGroups: utils.CoalesceToString(user.Groups),
OAuthName: service.Name(),
OAuthName: service.GetName(),
OAuthSub: user.Sub,
}

View File

@@ -85,7 +85,7 @@ func setupProxyController(t *testing.T, middlewares []gin.HandlerFunc) (*gin.Eng
LoginTimeout: 300,
LoginMaxRetries: 3,
SessionCookieName: "tinyauth-session",
}, dockerService, nil, queries, &service.OAuthBrokerService{})
}, dockerService, nil, queries)
// Controller
ctrl := controller.NewProxyController(controller.ProxyControllerConfig{

View File

@@ -71,7 +71,7 @@ func setupUserController(t *testing.T, middlewares *[]gin.HandlerFunc) (*gin.Eng
LoginTimeout: 300,
LoginMaxRetries: 3,
SessionCookieName: "tinyauth-session",
}, nil, nil, queries, &service.OAuthBrokerService{})
}, nil, nil, queries)
// Controller
ctrl := controller.NewUserController(controller.UserControllerConfig{

View File

@@ -17,21 +17,8 @@ import (
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"golang.org/x/crypto/bcrypt"
"golang.org/x/exp/slices"
"golang.org/x/oauth2"
)
const MaxOAuthPendingSessions = 256
const OAuthCleanupCount = 16
type OAuthPendingSession struct {
State string
Verifier string
Token *oauth2.Token
Service *OAuthServiceImpl
ExpiresAt time.Time
}
type LdapGroupsCache struct {
Groups []string
Expires time.Time
@@ -62,30 +49,24 @@ type AuthService struct {
docker *DockerService
loginAttempts map[string]*LoginAttempt
ldapGroupsCache map[string]*LdapGroupsCache
oauthPendingSessions map[string]*OAuthPendingSession
oauthMutex sync.RWMutex
loginMutex sync.RWMutex
ldapGroupsMutex sync.RWMutex
ldap *LdapService
queries *repository.Queries
oauthBroker *OAuthBrokerService
}
func NewAuthService(config AuthServiceConfig, docker *DockerService, ldap *LdapService, queries *repository.Queries, oauthBroker *OAuthBrokerService) *AuthService {
func NewAuthService(config AuthServiceConfig, docker *DockerService, ldap *LdapService, queries *repository.Queries) *AuthService {
return &AuthService{
config: config,
docker: docker,
loginAttempts: make(map[string]*LoginAttempt),
ldapGroupsCache: make(map[string]*LdapGroupsCache),
oauthPendingSessions: make(map[string]*OAuthPendingSession),
ldap: ldap,
queries: queries,
oauthBroker: oauthBroker,
}
}
func (auth *AuthService) Init() error {
go auth.CleanupOAuthSessionsRoutine()
return nil
}
@@ -572,177 +553,3 @@ func (auth *AuthService) IsBypassedIP(acls config.AppIP, ip string) bool {
tlog.App.Debug().Str("ip", ip).Msg("IP not in bypass list, continuing with authentication")
return false
}
func (auth *AuthService) NewOAuthSession(serviceName string) (string, OAuthPendingSession, error) {
auth.ensureOAuthSessionLimit()
service, ok := auth.oauthBroker.GetService(serviceName)
if !ok {
return "", OAuthPendingSession{}, fmt.Errorf("oauth service not found: %s", serviceName)
}
sessionId, err := uuid.NewRandom()
if err != nil {
return "", OAuthPendingSession{}, fmt.Errorf("failed to generate session ID: %w", err)
}
state := service.NewRandom()
verifier := service.NewRandom()
session := OAuthPendingSession{
State: state,
Verifier: verifier,
Service: &service,
ExpiresAt: time.Now().Add(1 * time.Hour),
}
auth.oauthMutex.Lock()
auth.oauthPendingSessions[sessionId.String()] = &session
auth.oauthMutex.Unlock()
return sessionId.String(), session, nil
}
func (auth *AuthService) GetOAuthURL(sessionId string) (string, error) {
session, err := auth.getOAuthPendingSession(sessionId)
if err != nil {
return "", err
}
return (*session.Service).GetAuthURL(session.State, session.Verifier), nil
}
func (auth *AuthService) GetOAuthToken(sessionId string, code string) (*oauth2.Token, error) {
session, err := auth.getOAuthPendingSession(sessionId)
if err != nil {
return nil, err
}
token, err := (*session.Service).GetToken(code, session.Verifier)
if err != nil {
return nil, fmt.Errorf("failed to exchange code for token: %w", err)
}
auth.oauthMutex.Lock()
session.Token = token
auth.oauthMutex.Unlock()
return token, nil
}
func (auth *AuthService) GetOAuthUserinfo(sessionId string) (config.Claims, error) {
session, err := auth.getOAuthPendingSession(sessionId)
if err != nil {
return config.Claims{}, err
}
if session.Token == nil {
return config.Claims{}, fmt.Errorf("oauth token not found for session: %s", sessionId)
}
userinfo, err := (*session.Service).GetUserinfo(session.Token)
if err != nil {
return config.Claims{}, fmt.Errorf("failed to get userinfo: %w", err)
}
return userinfo, nil
}
func (auth *AuthService) GetOAuthService(sessionId string) (OAuthServiceImpl, error) {
session, err := auth.getOAuthPendingSession(sessionId)
if err != nil {
return nil, err
}
return *session.Service, nil
}
func (auth *AuthService) EndOAuthSession(sessionId string) {
auth.oauthMutex.Lock()
delete(auth.oauthPendingSessions, sessionId)
auth.oauthMutex.Unlock()
}
func (auth *AuthService) CleanupOAuthSessionsRoutine() {
ticker := time.NewTicker(30 * time.Minute)
defer ticker.Stop()
for range ticker.C {
auth.oauthMutex.Lock()
now := time.Now()
for sessionId, session := range auth.oauthPendingSessions {
if now.After(session.ExpiresAt) {
delete(auth.oauthPendingSessions, sessionId)
}
}
auth.oauthMutex.Unlock()
}
}
func (auth *AuthService) getOAuthPendingSession(sessionId string) (*OAuthPendingSession, error) {
auth.ensureOAuthSessionLimit()
auth.oauthMutex.RLock()
session, exists := auth.oauthPendingSessions[sessionId]
auth.oauthMutex.RUnlock()
if !exists {
return &OAuthPendingSession{}, fmt.Errorf("oauth session not found: %s", sessionId)
}
if time.Now().After(session.ExpiresAt) {
auth.oauthMutex.Lock()
delete(auth.oauthPendingSessions, sessionId)
auth.oauthMutex.Unlock()
return &OAuthPendingSession{}, fmt.Errorf("oauth session expired: %s", sessionId)
}
return session, nil
}
func (auth *AuthService) ensureOAuthSessionLimit() {
auth.oauthMutex.Lock()
defer auth.oauthMutex.Unlock()
if len(auth.oauthPendingSessions) >= MaxOAuthPendingSessions {
cleanupIds := make([]string, 0, OAuthCleanupCount)
for range OAuthCleanupCount {
oldestId := ""
oldestTime := int64(0)
for id, session := range auth.oauthPendingSessions {
if oldestTime == 0 {
oldestId = id
oldestTime = session.ExpiresAt.Unix()
continue
}
if slices.Contains(cleanupIds, id) {
continue
}
if session.ExpiresAt.Unix() < oldestTime {
oldestId = id
oldestTime = session.ExpiresAt.Unix()
}
}
cleanupIds = append(cleanupIds, oldestId)
}
for _, id := range cleanupIds {
delete(auth.oauthPendingSessions, id)
}
}
}

View File

@@ -0,0 +1,132 @@
package service
import (
"context"
"crypto/rand"
"crypto/tls"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"time"
"github.com/steveiliop56/tinyauth/internal/config"
"github.com/steveiliop56/tinyauth/internal/utils/tlog"
"golang.org/x/oauth2"
)
type GenericOAuthService struct {
config oauth2.Config
context context.Context
token *oauth2.Token
verifier string
insecureSkipVerify bool
userinfoUrl string
name string
}
func NewGenericOAuthService(config config.OAuthServiceConfig) *GenericOAuthService {
return &GenericOAuthService{
config: oauth2.Config{
ClientID: config.ClientID,
ClientSecret: config.ClientSecret,
RedirectURL: config.RedirectURL,
Scopes: config.Scopes,
Endpoint: oauth2.Endpoint{
AuthURL: config.AuthURL,
TokenURL: config.TokenURL,
},
},
insecureSkipVerify: config.Insecure,
userinfoUrl: config.UserinfoURL,
name: config.Name,
}
}
func (generic *GenericOAuthService) Init() error {
transport := &http.Transport{
TLSClientConfig: &tls.Config{
InsecureSkipVerify: generic.insecureSkipVerify,
MinVersion: tls.VersionTLS12,
},
}
httpClient := &http.Client{
Transport: transport,
Timeout: 30 * time.Second,
}
ctx := context.Background()
ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient)
generic.context = ctx
return nil
}
func (generic *GenericOAuthService) GenerateState() string {
b := make([]byte, 128)
_, err := rand.Read(b)
if err != nil {
return base64.RawURLEncoding.EncodeToString(fmt.Appendf(nil, "state-%d", time.Now().UnixNano()))
}
state := base64.RawURLEncoding.EncodeToString(b)
return state
}
func (generic *GenericOAuthService) GenerateVerifier() string {
verifier := oauth2.GenerateVerifier()
generic.verifier = verifier
return verifier
}
func (generic *GenericOAuthService) GetAuthURL(state string) string {
return generic.config.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(generic.verifier))
}
func (generic *GenericOAuthService) VerifyCode(code string) error {
token, err := generic.config.Exchange(generic.context, code, oauth2.VerifierOption(generic.verifier))
if err != nil {
return err
}
generic.token = token
return nil
}
func (generic *GenericOAuthService) Userinfo() (config.Claims, error) {
var user config.Claims
client := generic.config.Client(generic.context, generic.token)
res, err := client.Get(generic.userinfoUrl)
if err != nil {
return user, err
}
defer res.Body.Close()
if res.StatusCode < 200 || res.StatusCode >= 300 {
return user, fmt.Errorf("request failed with status: %s", res.Status)
}
body, err := io.ReadAll(res.Body)
if err != nil {
return user, err
}
tlog.App.Trace().Str("body", string(body)).Msg("Userinfo response body")
err = json.Unmarshal(body, &user)
if err != nil {
return user, err
}
return user, nil
}
func (generic *GenericOAuthService) GetName() string {
return generic.name
}

View File

@@ -0,0 +1,184 @@
package service
import (
"context"
"crypto/rand"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strconv"
"time"
"github.com/steveiliop56/tinyauth/internal/config"
"golang.org/x/oauth2"
"golang.org/x/oauth2/endpoints"
)
var GithubOAuthScopes = []string{"user:email", "read:user"}
type GithubEmailResponse []struct {
Email string `json:"email"`
Primary bool `json:"primary"`
}
type GithubUserInfoResponse struct {
Login string `json:"login"`
Name string `json:"name"`
ID int `json:"id"`
}
type GithubOAuthService struct {
config oauth2.Config
context context.Context
token *oauth2.Token
verifier string
name string
}
func NewGithubOAuthService(config config.OAuthServiceConfig) *GithubOAuthService {
return &GithubOAuthService{
config: oauth2.Config{
ClientID: config.ClientID,
ClientSecret: config.ClientSecret,
RedirectURL: config.RedirectURL,
Scopes: GithubOAuthScopes,
Endpoint: endpoints.GitHub,
},
name: config.Name,
}
}
func (github *GithubOAuthService) Init() error {
httpClient := &http.Client{
Timeout: 30 * time.Second,
}
ctx := context.Background()
ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient)
github.context = ctx
return nil
}
func (github *GithubOAuthService) GenerateState() string {
b := make([]byte, 128)
_, err := rand.Read(b)
if err != nil {
return base64.RawURLEncoding.EncodeToString(fmt.Appendf(nil, "state-%d", time.Now().UnixNano()))
}
state := base64.RawURLEncoding.EncodeToString(b)
return state
}
func (github *GithubOAuthService) GenerateVerifier() string {
verifier := oauth2.GenerateVerifier()
github.verifier = verifier
return verifier
}
func (github *GithubOAuthService) GetAuthURL(state string) string {
return github.config.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(github.verifier))
}
func (github *GithubOAuthService) VerifyCode(code string) error {
token, err := github.config.Exchange(github.context, code, oauth2.VerifierOption(github.verifier))
if err != nil {
return err
}
github.token = token
return nil
}
func (github *GithubOAuthService) Userinfo() (config.Claims, error) {
var user config.Claims
client := github.config.Client(github.context, github.token)
req, err := http.NewRequest("GET", "https://api.github.com/user", nil)
if err != nil {
return user, err
}
req.Header.Set("Accept", "application/vnd.github+json")
res, err := client.Do(req)
if err != nil {
return user, err
}
defer res.Body.Close()
if res.StatusCode < 200 || res.StatusCode >= 300 {
return user, fmt.Errorf("request failed with status: %s", res.Status)
}
body, err := io.ReadAll(res.Body)
if err != nil {
return user, err
}
var userInfo GithubUserInfoResponse
err = json.Unmarshal(body, &userInfo)
if err != nil {
return user, err
}
req, err = http.NewRequest("GET", "https://api.github.com/user/emails", nil)
if err != nil {
return user, err
}
req.Header.Set("Accept", "application/vnd.github+json")
res, err = client.Do(req)
if err != nil {
return user, err
}
defer res.Body.Close()
if res.StatusCode < 200 || res.StatusCode >= 300 {
return user, fmt.Errorf("request failed with status: %s", res.Status)
}
body, err = io.ReadAll(res.Body)
if err != nil {
return user, err
}
var emails GithubEmailResponse
err = json.Unmarshal(body, &emails)
if err != nil {
return user, err
}
for _, email := range emails {
if email.Primary {
user.Email = email.Email
break
}
}
if len(emails) == 0 {
return user, errors.New("no emails found")
}
// Use first available email if no primary email was found
if user.Email == "" {
user.Email = emails[0].Email
}
user.PreferredUsername = userInfo.Login
user.Name = userInfo.Name
user.Sub = strconv.Itoa(userInfo.ID)
return user, nil
}
func (github *GithubOAuthService) GetName() string {
return github.name
}

View File

@@ -0,0 +1,116 @@
package service
import (
"context"
"crypto/rand"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/steveiliop56/tinyauth/internal/config"
"golang.org/x/oauth2"
"golang.org/x/oauth2/endpoints"
)
var GoogleOAuthScopes = []string{"openid", "email", "profile"}
type GoogleOAuthService struct {
config oauth2.Config
context context.Context
token *oauth2.Token
verifier string
name string
}
func NewGoogleOAuthService(config config.OAuthServiceConfig) *GoogleOAuthService {
return &GoogleOAuthService{
config: oauth2.Config{
ClientID: config.ClientID,
ClientSecret: config.ClientSecret,
RedirectURL: config.RedirectURL,
Scopes: GoogleOAuthScopes,
Endpoint: endpoints.Google,
},
name: config.Name,
}
}
func (google *GoogleOAuthService) Init() error {
httpClient := &http.Client{
Timeout: 30 * time.Second,
}
ctx := context.Background()
ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient)
google.context = ctx
return nil
}
func (oauth *GoogleOAuthService) GenerateState() string {
b := make([]byte, 128)
_, err := rand.Read(b)
if err != nil {
return base64.RawURLEncoding.EncodeToString(fmt.Appendf(nil, "state-%d", time.Now().UnixNano()))
}
state := base64.RawURLEncoding.EncodeToString(b)
return state
}
func (google *GoogleOAuthService) GenerateVerifier() string {
verifier := oauth2.GenerateVerifier()
google.verifier = verifier
return verifier
}
func (google *GoogleOAuthService) GetAuthURL(state string) string {
return google.config.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(google.verifier))
}
func (google *GoogleOAuthService) VerifyCode(code string) error {
token, err := google.config.Exchange(google.context, code, oauth2.VerifierOption(google.verifier))
if err != nil {
return err
}
google.token = token
return nil
}
func (google *GoogleOAuthService) Userinfo() (config.Claims, error) {
var user config.Claims
client := google.config.Client(google.context, google.token)
res, err := client.Get("https://openidconnect.googleapis.com/v1/userinfo")
if err != nil {
return config.Claims{}, err
}
defer res.Body.Close()
if res.StatusCode < 200 || res.StatusCode >= 300 {
return user, fmt.Errorf("request failed with status: %s", res.Status)
}
body, err := io.ReadAll(res.Body)
if err != nil {
return config.Claims{}, err
}
err = json.Unmarshal(body, &user)
if err != nil {
return config.Claims{}, err
}
user.PreferredUsername = strings.SplitN(user.Email, "@", 2)[0]
return user, nil
}
func (google *GoogleOAuthService) GetName() string {
return google.name
}

View File

@@ -1,48 +1,60 @@
package service
import (
"errors"
"github.com/steveiliop56/tinyauth/internal/config"
"github.com/steveiliop56/tinyauth/internal/utils/tlog"
"golang.org/x/exp/slices"
"golang.org/x/oauth2"
)
type OAuthServiceImpl interface {
Name() string
NewRandom() string
GetAuthURL(state string, verifier string) string
GetToken(code string, verifier string) (*oauth2.Token, error)
GetUserinfo(token *oauth2.Token) (config.Claims, error)
type OAuthService interface {
Init() error
GenerateState() string
GenerateVerifier() string
GetAuthURL(state string) string
VerifyCode(code string) error
Userinfo() (config.Claims, error)
GetName() string
}
type OAuthBrokerService struct {
services map[string]OAuthServiceImpl
services map[string]OAuthService
configs map[string]config.OAuthServiceConfig
}
var presets = map[string]func(config config.OAuthServiceConfig) *OAuthService{
"github": newGitHubOAuthService,
"google": newGoogleOAuthService,
}
func NewOAuthBrokerService(configs map[string]config.OAuthServiceConfig) *OAuthBrokerService {
return &OAuthBrokerService{
services: make(map[string]OAuthServiceImpl),
services: make(map[string]OAuthService),
configs: configs,
}
}
func (broker *OAuthBrokerService) Init() error {
for name, cfg := range broker.configs {
if presetFunc, exists := presets[name]; exists {
broker.services[name] = presetFunc(cfg)
tlog.App.Debug().Str("service", name).Msg("Loaded OAuth service from preset")
} else {
broker.services[name] = NewOAuthService(cfg)
tlog.App.Debug().Str("service", name).Msg("Loaded OAuth service from config")
switch name {
case "github":
service := NewGithubOAuthService(cfg)
broker.services[name] = service
case "google":
service := NewGoogleOAuthService(cfg)
broker.services[name] = service
default:
service := NewGenericOAuthService(cfg)
broker.services[name] = service
}
}
for name, service := range broker.services {
err := service.Init()
if err != nil {
tlog.App.Error().Err(err).Msgf("Failed to initialize OAuth service: %s", name)
return err
}
tlog.App.Info().Str("service", name).Msg("Initialized OAuth service")
}
return nil
}
@@ -55,7 +67,15 @@ func (broker *OAuthBrokerService) GetConfiguredServices() []string {
return services
}
func (broker *OAuthBrokerService) GetService(name string) (OAuthServiceImpl, bool) {
func (broker *OAuthBrokerService) GetService(name string) (OAuthService, bool) {
service, exists := broker.services[name]
return service, exists
}
func (broker *OAuthBrokerService) GetUser(service string) (config.Claims, error) {
oauthService, exists := broker.services[service]
if !exists {
return config.Claims{}, errors.New("oauth service not found")
}
return oauthService.Userinfo()
}

View File

@@ -1,102 +0,0 @@
package service
import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strconv"
"github.com/steveiliop56/tinyauth/internal/config"
)
type GithubEmailResponse []struct {
Email string `json:"email"`
Primary bool `json:"primary"`
}
type GithubUserInfoResponse struct {
Login string `json:"login"`
Name string `json:"name"`
ID int `json:"id"`
}
func defaultExtractor(client *http.Client, url string) (config.Claims, error) {
return simpleReq[config.Claims](client, url, nil)
}
func githubExtractor(client *http.Client, url string) (config.Claims, error) {
var user config.Claims
userInfo, err := simpleReq[GithubUserInfoResponse](client, "https://api.github.com/user", map[string]string{
"accept": "application/vnd.github+json",
})
if err != nil {
return config.Claims{}, err
}
userEmails, err := simpleReq[GithubEmailResponse](client, "https://api.github.com/user/emails", map[string]string{
"accept": "application/vnd.github+json",
})
if err != nil {
return config.Claims{}, err
}
if len(userEmails) == 0 {
return user, errors.New("no emails found")
}
for _, email := range userEmails {
if email.Primary {
user.Email = email.Email
break
}
}
// Use first available email if no primary email was found
if user.Email == "" {
user.Email = userEmails[0].Email
}
user.PreferredUsername = userInfo.Login
user.Name = userInfo.Name
user.Sub = strconv.Itoa(userInfo.ID)
return user, nil
}
func simpleReq[T any](client *http.Client, url string, headers map[string]string) (T, error) {
var decodedRes T
req, err := http.NewRequest("GET", url, nil)
if err != nil {
return decodedRes, err
}
for key, value := range headers {
req.Header.Add(key, value)
}
res, err := client.Do(req)
if err != nil {
return decodedRes, err
}
defer res.Body.Close()
if res.StatusCode < 200 || res.StatusCode >= 300 {
return decodedRes, fmt.Errorf("request failed with status: %s", res.Status)
}
body, err := io.ReadAll(res.Body)
if err != nil {
return decodedRes, err
}
err = json.Unmarshal(body, &decodedRes)
if err != nil {
return decodedRes, err
}
return decodedRes, nil
}

View File

@@ -1,23 +0,0 @@
package service
import (
"github.com/steveiliop56/tinyauth/internal/config"
"golang.org/x/oauth2/endpoints"
)
func newGoogleOAuthService(config config.OAuthServiceConfig) *OAuthService {
scopes := []string{"openid", "email", "profile"}
config.Scopes = scopes
config.AuthURL = endpoints.Google.AuthURL
config.TokenURL = endpoints.Google.TokenURL
config.UserinfoURL = "https://openidconnect.googleapis.com/v1/userinfo"
return NewOAuthService(config)
}
func newGitHubOAuthService(config config.OAuthServiceConfig) *OAuthService {
scopes := []string{"read:user", "user:email"}
config.Scopes = scopes
config.AuthURL = endpoints.GitHub.AuthURL
config.TokenURL = endpoints.GitHub.TokenURL
return NewOAuthService(config).WithUserinfoExtractor(githubExtractor)
}

View File

@@ -1,78 +0,0 @@
package service
import (
"context"
"crypto/tls"
"net/http"
"time"
"github.com/steveiliop56/tinyauth/internal/config"
"golang.org/x/oauth2"
)
type UserinfoExtractor func(client *http.Client, url string) (config.Claims, error)
type OAuthService struct {
serviceCfg config.OAuthServiceConfig
config *oauth2.Config
ctx context.Context
userinfoExtractor UserinfoExtractor
}
func NewOAuthService(config config.OAuthServiceConfig) *OAuthService {
httpClient := &http.Client{
Timeout: 30 * time.Second,
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
InsecureSkipVerify: config.Insecure,
},
},
}
ctx := context.Background()
ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient)
return &OAuthService{
serviceCfg: config,
config: &oauth2.Config{
ClientID: config.ClientID,
ClientSecret: config.ClientSecret,
RedirectURL: config.RedirectURL,
Scopes: config.Scopes,
Endpoint: oauth2.Endpoint{
AuthURL: config.AuthURL,
TokenURL: config.TokenURL,
},
},
ctx: ctx,
userinfoExtractor: defaultExtractor,
}
}
func (s *OAuthService) WithUserinfoExtractor(extractor UserinfoExtractor) *OAuthService {
s.userinfoExtractor = extractor
return s
}
func (s *OAuthService) Name() string {
return s.serviceCfg.Name
}
func (s *OAuthService) NewRandom() string {
// The generate verifier function just creates a random string,
// so we can use it to generate a random state as well
random := oauth2.GenerateVerifier()
return random
}
func (s *OAuthService) GetAuthURL(state string, verifier string) string {
return s.config.AuthCodeURL(state, oauth2.AccessTypeOnline, oauth2.S256ChallengeOption(verifier))
}
func (s *OAuthService) GetToken(code string, verifier string) (*oauth2.Token, error) {
return s.config.Exchange(s.ctx, code, oauth2.VerifierOption(verifier))
}
func (s *OAuthService) GetUserinfo(token *oauth2.Token) (config.Claims, error) {
client := oauth2.NewClient(s.ctx, oauth2.StaticTokenSource(token))
return s.userinfoExtractor(client, s.serviceCfg.UserinfoURL)
}