Compare commits

..

3 Commits

Author SHA1 Message Date
Stavros
2491d453cf feat: add oauth session impl in auth service 2026-03-21 12:55:22 +02:00
Stavros
1a1712eaeb wip 2026-03-21 12:42:05 +02:00
Stavros
dc3fa58d21 refactor: refactor proxy controller to handle proxy auth modules better (#714)
* wip

* fix: add extauthz to friendly error messages

* refactor: better module handling per proxy

* fix: get envoy host from the gin request

* tests: rework tests for proxy controller

* fix: review comments
2026-03-14 19:56:15 +02:00
10 changed files with 831 additions and 654 deletions

View File

@@ -1,9 +1,11 @@
package controller
import (
"errors"
"fmt"
"net/http"
"slices"
"net/url"
"regexp"
"strings"
"github.com/steveiliop56/tinyauth/internal/config"
@@ -15,12 +17,29 @@ import (
"github.com/google/go-querystring/query"
)
var SupportedProxies = []string{"nginx", "traefik", "caddy", "envoy"}
type AuthModuleType int
const (
AuthRequest AuthModuleType = iota
ExtAuthz
ForwardAuth
)
var BrowserUserAgentRegex = regexp.MustCompile("Chrome|Gecko|AppleWebKit|Opera|Edge")
type Proxy struct {
Proxy string `uri:"proxy" binding:"required"`
}
type ProxyContext struct {
Host string
Proto string
Path string
Method string
Type AuthModuleType
IsBrowser bool
}
type ProxyControllerConfig struct {
AppURL string
}
@@ -43,82 +62,30 @@ func NewProxyController(config ProxyControllerConfig, router *gin.RouterGroup, a
func (controller *ProxyController) SetupRoutes() {
proxyGroup := controller.router.Group("/auth")
// There is a later check to control allowed methods per proxy
proxyGroup.Any("/:proxy", controller.proxyHandler)
}
func (controller *ProxyController) proxyHandler(c *gin.Context) {
var req Proxy
// Load proxy context based on the request type
proxyCtx, err := controller.getProxyContext(c)
err := c.BindUri(&req)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to bind URI")
tlog.App.Warn().Err(err).Msg("Failed to get proxy context")
c.JSON(400, gin.H{
"status": 400,
"message": "Bad Request",
"message": "Bad request",
})
return
}
if !slices.Contains(SupportedProxies, req.Proxy) {
tlog.App.Warn().Str("proxy", req.Proxy).Msg("Invalid proxy")
c.JSON(400, gin.H{
"status": 400,
"message": "Bad Request",
})
return
}
// Only allow GET for non-envoy proxies.
// Envoy uses the original client method for the external auth request
// so we allow Any standard HTTP method for /api/auth/envoy
if req.Proxy != "envoy" && c.Request.Method != http.MethodGet {
tlog.App.Warn().Str("method", c.Request.Method).Msg("Invalid method for proxy")
c.Header("Allow", "GET")
c.JSON(405, gin.H{
"status": 405,
"message": "Method Not Allowed",
})
return
}
isBrowser := strings.Contains(c.Request.Header.Get("Accept"), "text/html")
if isBrowser {
tlog.App.Debug().Msg("Request identified as (most likely) coming from a browser")
} else {
tlog.App.Debug().Msg("Request identified as (most likely) coming from a non-browser client")
}
// We are not marking the URI as a required header because it may be missing
// and we only use it for the auth enabled check which will simply not match
// if the header is missing. For deployments like Kubernetes, we use the
// x-original-uri header instead.
uri, ok := controller.getHeader(c, "x-forwarded-uri")
if !ok {
originalUri, ok := controller.getHeader(c, "x-original-uri")
if ok {
uri = originalUri
}
}
host, ok := controller.requireHeader(c, "x-forwarded-host")
if !ok {
return
}
proto, ok := controller.requireHeader(c, "x-forwarded-proto")
if !ok {
return
}
tlog.App.Trace().Interface("ctx", proxyCtx).Msg("Got proxy context")
// Get acls
acls, err := controller.acls.GetAccessControls(host)
acls, err := controller.acls.GetAccessControls(proxyCtx.Host)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to get access controls for resource")
controller.handleError(c, req, isBrowser)
controller.handleError(c, proxyCtx)
return
}
@@ -135,11 +102,11 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
return
}
authEnabled, err := controller.auth.IsAuthEnabled(uri, acls.Path)
authEnabled, err := controller.auth.IsAuthEnabled(proxyCtx.Path, acls.Path)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to check if auth is enabled for resource")
controller.handleError(c, req, isBrowser)
controller.handleError(c, proxyCtx)
return
}
@@ -154,7 +121,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
}
if !controller.auth.CheckIP(acls.IP, clientIP) {
if req.Proxy == "nginx" || !isBrowser {
if !controller.useFriendlyError(proxyCtx) {
c.JSON(401, gin.H{
"status": 401,
"message": "Unauthorized",
@@ -163,7 +130,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
}
queries, err := query.Values(config.UnauthorizedQuery{
Resource: strings.Split(host, ".")[0],
Resource: strings.Split(proxyCtx.Host, ".")[0],
IP: clientIP,
})
@@ -196,9 +163,9 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
userAllowed := controller.auth.IsUserAllowed(c, userContext, acls)
if !userAllowed {
tlog.App.Warn().Str("user", userContext.Username).Str("resource", strings.Split(host, ".")[0]).Msg("User not allowed to access resource")
tlog.App.Warn().Str("user", userContext.Username).Str("resource", strings.Split(proxyCtx.Host, ".")[0]).Msg("User not allowed to access resource")
if req.Proxy == "nginx" || !isBrowser {
if !controller.useFriendlyError(proxyCtx) {
c.JSON(403, gin.H{
"status": 403,
"message": "Forbidden",
@@ -207,7 +174,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
}
queries, err := query.Values(config.UnauthorizedQuery{
Resource: strings.Split(host, ".")[0],
Resource: strings.Split(proxyCtx.Host, ".")[0],
})
if err != nil {
@@ -236,9 +203,9 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
}
if !groupOK {
tlog.App.Warn().Str("user", userContext.Username).Str("resource", strings.Split(host, ".")[0]).Msg("User groups do not match resource requirements")
tlog.App.Warn().Str("user", userContext.Username).Str("resource", strings.Split(proxyCtx.Host, ".")[0]).Msg("User groups do not match resource requirements")
if req.Proxy == "nginx" || !isBrowser {
if !controller.useFriendlyError(proxyCtx) {
c.JSON(403, gin.H{
"status": 403,
"message": "Forbidden",
@@ -247,7 +214,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
}
queries, err := query.Values(config.UnauthorizedQuery{
Resource: strings.Split(host, ".")[0],
Resource: strings.Split(proxyCtx.Host, ".")[0],
GroupErr: true,
})
@@ -289,7 +256,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
return
}
if req.Proxy == "nginx" || !isBrowser {
if !controller.useFriendlyError(proxyCtx) {
c.JSON(401, gin.H{
"status": 401,
"message": "Unauthorized",
@@ -298,7 +265,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
}
queries, err := query.Values(config.RedirectQuery{
RedirectURI: fmt.Sprintf("%s://%s%s", proto, host, uri),
RedirectURI: fmt.Sprintf("%s://%s%s", proxyCtx.Proto, proxyCtx.Host, proxyCtx.Path),
})
if err != nil {
@@ -328,8 +295,8 @@ func (controller *ProxyController) setHeaders(c *gin.Context, acls config.App) {
}
}
func (controller *ProxyController) handleError(c *gin.Context, req Proxy, isBrowser bool) {
if req.Proxy == "nginx" || !isBrowser {
func (controller *ProxyController) handleError(c *gin.Context, proxyCtx ProxyContext) {
if !controller.useFriendlyError(proxyCtx) {
c.JSON(500, gin.H{
"status": 500,
"message": "Internal Server Error",
@@ -340,20 +307,195 @@ func (controller *ProxyController) handleError(c *gin.Context, req Proxy, isBrow
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
}
func (controller *ProxyController) requireHeader(c *gin.Context, header string) (string, bool) {
val, ok := controller.getHeader(c, header)
if !ok {
tlog.App.Error().Str("header", header).Msg("Header not found")
c.JSON(400, gin.H{
"status": 400,
"message": "Bad Request",
})
return "", false
}
return val, true
}
func (controller *ProxyController) getHeader(c *gin.Context, header string) (string, bool) {
val := c.Request.Header.Get(header)
return val, strings.TrimSpace(val) != ""
}
func (controller *ProxyController) useFriendlyError(proxyCtx ProxyContext) bool {
return (proxyCtx.Type == ForwardAuth || proxyCtx.Type == ExtAuthz) && proxyCtx.IsBrowser
}
// Code below is inspired from https://github.com/authelia/authelia/blob/master/internal/handlers/handler_authz.go
// and thus it may be subject to Apache 2.0 License
func (controller *ProxyController) getForwardAuthContext(c *gin.Context) (ProxyContext, error) {
host, ok := controller.getHeader(c, "x-forwarded-host")
if !ok {
return ProxyContext{}, errors.New("x-forwarded-host not found")
}
uri, ok := controller.getHeader(c, "x-forwarded-uri")
if !ok {
return ProxyContext{}, errors.New("x-forwarded-uri not found")
}
proto, ok := controller.getHeader(c, "x-forwarded-proto")
if !ok {
return ProxyContext{}, errors.New("x-forwarded-proto not found")
}
// Normally we should only allow GET for forward auth but since it's a fallback
// for envoy we should allow everything, not a big deal
method := c.Request.Method
return ProxyContext{
Host: host,
Proto: proto,
Path: uri,
Method: method,
Type: ForwardAuth,
}, nil
}
func (controller *ProxyController) getAuthRequestContext(c *gin.Context) (ProxyContext, error) {
xOriginalUrl, ok := controller.getHeader(c, "x-original-url")
if !ok {
return ProxyContext{}, errors.New("x-original-url not found")
}
url, err := url.Parse(xOriginalUrl)
if err != nil {
return ProxyContext{}, err
}
host := url.Host
if strings.TrimSpace(host) == "" {
return ProxyContext{}, errors.New("host not found")
}
proto := url.Scheme
if strings.TrimSpace(proto) == "" {
return ProxyContext{}, errors.New("proto not found")
}
path := url.Path
method := c.Request.Method
return ProxyContext{
Host: host,
Proto: proto,
Path: path,
Method: method,
Type: AuthRequest,
}, nil
}
func (controller *ProxyController) getExtAuthzContext(c *gin.Context) (ProxyContext, error) {
// We hope for the someone to set the x-forwarded-proto header
proto, ok := controller.getHeader(c, "x-forwarded-proto")
if !ok {
return ProxyContext{}, errors.New("x-forwarded-proto not found")
}
// It sets the host to the original host, not the forwarded host
host := c.Request.Host
if strings.TrimSpace(host) == "" {
return ProxyContext{}, errors.New("host not found")
}
// We get the path from the query string
path := c.Query("path")
// For envoy we need to support every method
method := c.Request.Method
return ProxyContext{
Host: host,
Proto: proto,
Path: path,
Method: method,
Type: ExtAuthz,
}, nil
}
func (controller *ProxyController) determineAuthModules(proxy string) []AuthModuleType {
switch proxy {
case "traefik", "caddy":
return []AuthModuleType{ForwardAuth}
case "envoy":
return []AuthModuleType{ExtAuthz, ForwardAuth}
case "nginx":
return []AuthModuleType{AuthRequest, ForwardAuth}
default:
return []AuthModuleType{}
}
}
func (controller *ProxyController) getContextFromAuthModule(c *gin.Context, module AuthModuleType) (ProxyContext, error) {
switch module {
case ForwardAuth:
ctx, err := controller.getForwardAuthContext(c)
if err != nil {
return ProxyContext{}, err
}
return ctx, nil
case ExtAuthz:
ctx, err := controller.getExtAuthzContext(c)
if err != nil {
return ProxyContext{}, err
}
return ctx, nil
case AuthRequest:
ctx, err := controller.getAuthRequestContext(c)
if err != nil {
return ProxyContext{}, err
}
return ctx, nil
}
return ProxyContext{}, fmt.Errorf("unsupported auth module: %v", module)
}
func (controller *ProxyController) getProxyContext(c *gin.Context) (ProxyContext, error) {
var req Proxy
err := c.BindUri(&req)
if err != nil {
return ProxyContext{}, err
}
tlog.App.Debug().Msgf("Proxy: %v", req.Proxy)
authModules := controller.determineAuthModules(req.Proxy)
if len(authModules) == 0 {
return ProxyContext{}, fmt.Errorf("no auth modules supported for proxy: %v", req.Proxy)
}
var ctx ProxyContext
for _, module := range authModules {
tlog.App.Debug().Msgf("Trying auth module: %v", module)
ctx, err = controller.getContextFromAuthModule(c, module)
if err == nil {
tlog.App.Debug().Msgf("Auth module %v succeeded", module)
break
}
tlog.App.Debug().Err(err).Msgf("Auth module %v failed", module)
}
if err != nil {
return ProxyContext{}, err
}
// We don't care if the header is empty, we will just assume it's not a browser
userAgent, _ := controller.getHeader(c, "user-agent")
isBrowser := BrowserUserAgentRegex.MatchString(userAgent)
if isBrowser {
tlog.App.Debug().Msg("Request identified as coming from a browser")
} else {
tlog.App.Debug().Msg("Request identified as coming from a non-browser client")
}
ctx.IsBrowser = isBrowser
return ctx, nil
}

View File

@@ -1,6 +1,7 @@
package controller_test
import (
"net/http"
"net/http/httptest"
"testing"
@@ -9,21 +10,26 @@ import (
"github.com/steveiliop56/tinyauth/internal/controller"
"github.com/steveiliop56/tinyauth/internal/repository"
"github.com/steveiliop56/tinyauth/internal/service"
"github.com/steveiliop56/tinyauth/internal/utils/tlog"
"github.com/gin-gonic/gin"
"gotest.tools/v3/assert"
)
func setupProxyController(t *testing.T, middlewares *[]gin.HandlerFunc) (*gin.Engine, *httptest.ResponseRecorder, *service.AuthService) {
tlog.NewSimpleLogger().Init()
var loggedInCtx = config.UserContext{
Username: "test",
Name: "Test",
Email: "test@example.com",
IsLoggedIn: true,
Provider: "local",
}
func setupProxyController(t *testing.T, middlewares []gin.HandlerFunc) (*gin.Engine, *httptest.ResponseRecorder) {
// Setup
gin.SetMode(gin.TestMode)
router := gin.Default()
if middlewares != nil {
for _, m := range *middlewares {
if len(middlewares) > 0 {
for _, m := range middlewares {
router.Use(m)
}
}
@@ -48,7 +54,13 @@ func setupProxyController(t *testing.T, middlewares *[]gin.HandlerFunc) (*gin.En
assert.NilError(t, dockerService.Init())
// Access controls
accessControlsService := service.NewAccessControlsService(dockerService, map[string]config.App{})
accessControlsService := service.NewAccessControlsService(dockerService, map[string]config.App{
"whoami": {
Path: config.AppPath{
Allow: "/allow",
},
},
})
assert.NilError(t, accessControlsService.Init())
@@ -77,107 +89,214 @@ func setupProxyController(t *testing.T, middlewares *[]gin.HandlerFunc) (*gin.En
// Controller
ctrl := controller.NewProxyController(controller.ProxyControllerConfig{
AppURL: "http://localhost:8080",
AppURL: "http://tinyauth.example.com",
}, group, accessControlsService, authService)
ctrl.SetupRoutes()
return router, recorder, authService
return router, recorder
}
// TODO: Needs tests for context middleware
func TestProxyHandler(t *testing.T) {
// Setup
router, recorder, _ := setupProxyController(t, nil)
// Test logged out user traefik/caddy (forward_auth)
router, recorder := setupProxyController(t, nil)
req, err := http.NewRequest("GET", "/api/auth/traefik", nil)
assert.NilError(t, err)
req.Header.Set("x-forwarded-host", "whoami.example.com")
req.Header.Set("x-forwarded-proto", "http")
req.Header.Set("x-forwarded-uri", "/")
// Test invalid proxy
req := httptest.NewRequest("GET", "/api/auth/invalidproxy", nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, recorder.Code, http.StatusUnauthorized)
assert.Equal(t, 400, recorder.Code)
// Test logged out user nginx (auth_request)
router, recorder = setupProxyController(t, nil)
req, err = http.NewRequest("GET", "/api/auth/nginx", nil)
assert.NilError(t, err)
req.Header.Set("x-original-url", "http://whoami.example.com/")
// Test invalid method for non-envoy proxy
recorder = httptest.NewRecorder()
req = httptest.NewRequest("POST", "/api/auth/traefik", nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, recorder.Code, http.StatusUnauthorized)
assert.Equal(t, 405, recorder.Code)
assert.Equal(t, "GET", recorder.Header().Get("Allow"))
// Test logged out user envoy (ext_authz)
router, recorder = setupProxyController(t, nil)
req, err = http.NewRequest("GET", "/api/auth/envoy?path=/", nil)
assert.NilError(t, err)
req.Host = "whoami.example.com"
req.Header.Set("x-forwarded-proto", "http")
// Test logged out user (traefik/caddy)
recorder = httptest.NewRecorder()
req = httptest.NewRequest("GET", "/api/auth/traefik", nil)
req.Header.Set("X-Forwarded-Proto", "https")
req.Header.Set("X-Forwarded-Host", "example.com")
req.Header.Set("X-Forwarded-Uri", "/somepath")
req.Header.Set("Accept", "text/html")
router.ServeHTTP(recorder, req)
assert.Equal(t, recorder.Code, http.StatusUnauthorized)
assert.Equal(t, 307, recorder.Code)
assert.Equal(t, "http://localhost:8080/login?redirect_uri=https%3A%2F%2Fexample.com%2Fsomepath", recorder.Header().Get("Location"))
// Test logged out user (envoy - POST method)
recorder = httptest.NewRecorder()
req = httptest.NewRequest("POST", "/api/auth/envoy", nil)
req.Header.Set("X-Forwarded-Proto", "https")
req.Header.Set("X-Forwarded-Host", "example.com")
req.Header.Set("X-Forwarded-Uri", "/somepath")
req.Header.Set("Accept", "text/html")
router.ServeHTTP(recorder, req)
assert.Equal(t, 307, recorder.Code)
assert.Equal(t, "http://localhost:8080/login?redirect_uri=https%3A%2F%2Fexample.com%2Fsomepath", recorder.Header().Get("Location"))
// Test logged out user (envoy - DELETE method)
recorder = httptest.NewRecorder()
req = httptest.NewRequest("DELETE", "/api/auth/envoy", nil)
req.Header.Set("X-Forwarded-Proto", "https")
req.Header.Set("X-Forwarded-Host", "example.com")
req.Header.Set("X-Forwarded-Uri", "/somepath")
req.Header.Set("Accept", "text/html")
router.ServeHTTP(recorder, req)
assert.Equal(t, 307, recorder.Code)
assert.Equal(t, "http://localhost:8080/login?redirect_uri=https%3A%2F%2Fexample.com%2Fsomepath", recorder.Header().Get("Location"))
// Test logged out user (nginx)
recorder = httptest.NewRecorder()
req = httptest.NewRequest("GET", "/api/auth/nginx", nil)
req.Header.Set("X-Forwarded-Proto", "https")
req.Header.Set("X-Forwarded-Host", "example.com")
// we won't set X-Forwarded-Uri to test that the controller can work without it
router.ServeHTTP(recorder, req)
assert.Equal(t, 401, recorder.Code)
// Test logged in user
router, recorder, _ = setupProxyController(t, &[]gin.HandlerFunc{
// Test logged in user traefik/caddy (forward_auth)
router, recorder = setupProxyController(t, []gin.HandlerFunc{
func(c *gin.Context) {
c.Set("context", &config.UserContext{
Username: "testuser",
Name: "testuser",
Email: "testuser@example.com",
IsLoggedIn: true,
OAuth: false,
Provider: "local",
TotpPending: false,
OAuthGroups: "",
TotpEnabled: false,
})
c.Set("context", &loggedInCtx)
c.Next()
},
})
req = httptest.NewRequest("GET", "/api/auth/traefik", nil)
req.Header.Set("X-Forwarded-Proto", "https")
req.Header.Set("X-Forwarded-Host", "example.com")
req.Header.Set("X-Original-Uri", "/somepath") // Test with original URI for kubernetes ingress
req.Header.Set("Accept", "text/html")
req, err = http.NewRequest("GET", "/api/auth/traefik", nil)
assert.NilError(t, err)
req.Header.Set("x-forwarded-host", "whoami.example.com")
req.Header.Set("x-forwarded-proto", "http")
req.Header.Set("x-forwarded-uri", "/")
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
assert.Equal(t, recorder.Code, http.StatusOK)
assert.Equal(t, "testuser", recorder.Header().Get("Remote-User"))
assert.Equal(t, "testuser", recorder.Header().Get("Remote-Name"))
assert.Equal(t, "testuser@example.com", recorder.Header().Get("Remote-Email"))
// Test logged in user nginx (auth_request)
router, recorder = setupProxyController(t, []gin.HandlerFunc{
func(c *gin.Context) {
c.Set("context", &loggedInCtx)
c.Next()
},
})
req, err = http.NewRequest("GET", "/api/auth/nginx", nil)
assert.NilError(t, err)
req.Header.Set("x-original-url", "http://whoami.example.com/")
router.ServeHTTP(recorder, req)
assert.Equal(t, recorder.Code, http.StatusOK)
// Test logged in user envoy (ext_authz)
router, recorder = setupProxyController(t, []gin.HandlerFunc{
func(c *gin.Context) {
c.Set("context", &loggedInCtx)
c.Next()
},
})
req, err = http.NewRequest("GET", "/api/auth/envoy?path=/", nil)
assert.NilError(t, err)
req.Host = "whoami.example.com"
req.Header.Set("x-forwarded-proto", "http")
router.ServeHTTP(recorder, req)
assert.Equal(t, recorder.Code, http.StatusOK)
// Test ACL allow caddy/traefik (forward_auth)
router, recorder = setupProxyController(t, nil)
req, err = http.NewRequest("GET", "/api/auth/traefik", nil)
assert.NilError(t, err)
req.Header.Set("x-forwarded-host", "whoami.example.com")
req.Header.Set("x-forwarded-proto", "http")
req.Header.Set("x-forwarded-uri", "/allow")
router.ServeHTTP(recorder, req)
assert.Equal(t, recorder.Code, http.StatusOK)
// Test ACL allow nginx
router, recorder = setupProxyController(t, nil)
req, err = http.NewRequest("GET", "/api/auth/nginx", nil)
assert.NilError(t, err)
req.Header.Set("x-original-url", "http://whoami.example.com/allow")
router.ServeHTTP(recorder, req)
assert.Equal(t, recorder.Code, http.StatusOK)
// Test ACL allow envoy
router, recorder = setupProxyController(t, nil)
req, err = http.NewRequest("GET", "/api/auth/envoy?path=/allow", nil)
assert.NilError(t, err)
req.Host = "whoami.example.com"
req.Header.Set("x-forwarded-proto", "http")
router.ServeHTTP(recorder, req)
assert.Equal(t, recorder.Code, http.StatusOK)
// Test traefik/caddy (forward_auth) without required headers
router, recorder = setupProxyController(t, nil)
req, err = http.NewRequest("GET", "/api/auth/traefik", nil)
assert.NilError(t, err)
router.ServeHTTP(recorder, req)
assert.Equal(t, recorder.Code, http.StatusBadRequest)
// Test nginx (forward_auth) without required headers
router, recorder = setupProxyController(t, nil)
req, err = http.NewRequest("GET", "/api/auth/nginx", nil)
assert.NilError(t, err)
router.ServeHTTP(recorder, req)
assert.Equal(t, recorder.Code, http.StatusBadRequest)
// Test envoy (forward_auth) without required headers
router, recorder = setupProxyController(t, nil)
req, err = http.NewRequest("GET", "/api/auth/envoy", nil)
assert.NilError(t, err)
router.ServeHTTP(recorder, req)
assert.Equal(t, recorder.Code, http.StatusBadRequest)
// Test nginx (auth_request) with forward_auth fallback with ACLs
router, recorder = setupProxyController(t, nil)
req, err = http.NewRequest("GET", "/api/auth/nginx", nil)
assert.NilError(t, err)
req.Header.Set("x-forwarded-host", "whoami.example.com")
req.Header.Set("x-forwarded-proto", "http")
req.Header.Set("x-forwarded-uri", "/allow")
router.ServeHTTP(recorder, req)
assert.Equal(t, recorder.Code, http.StatusOK)
// Test envoy (ext_authz) with forward_auth fallback with ACLs
router, recorder = setupProxyController(t, nil)
req, err = http.NewRequest("GET", "/api/auth/envoy", nil)
assert.NilError(t, err)
req.Header.Set("x-forwarded-host", "whoami.example.com")
req.Header.Set("x-forwarded-proto", "http")
req.Header.Set("x-forwarded-uri", "/allow")
router.ServeHTTP(recorder, req)
assert.Equal(t, recorder.Code, http.StatusOK)
// Test envoy (ext_authz) with empty path
router, recorder = setupProxyController(t, nil)
req, err = http.NewRequest("GET", "/api/auth/envoy", nil)
assert.NilError(t, err)
req.Host = "whoami.example.com"
req.Header.Set("x-forwarded-proto", "http")
router.ServeHTTP(recorder, req)
assert.Equal(t, recorder.Code, http.StatusUnauthorized)
// Ensure forward_auth fallback works with path (should ignore)
router, recorder = setupProxyController(t, nil)
req, err = http.NewRequest("GET", "/api/auth/traefik?path=/allow", nil)
assert.NilError(t, err)
req.Header.Set("x-forwarded-proto", "http")
req.Header.Set("x-forwarded-host", "whoami.example.com")
req.Header.Set("x-forwarded-uri", "/allow")
router.ServeHTTP(recorder, req)
assert.Equal(t, recorder.Code, http.StatusOK)
}

View File

@@ -17,8 +17,17 @@ import (
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"golang.org/x/crypto/bcrypt"
"golang.org/x/oauth2"
)
type OAuthPendingSession struct {
State string
Verifier string
Token *oauth2.Token
Service *OAuthServiceImpl
ExpiresAt time.Time
}
type LdapGroupsCache struct {
Groups []string
Expires time.Time
@@ -45,17 +54,20 @@ type AuthServiceConfig struct {
}
type AuthService struct {
config AuthServiceConfig
docker *DockerService
loginAttempts map[string]*LoginAttempt
ldapGroupsCache map[string]*LdapGroupsCache
loginMutex sync.RWMutex
ldapGroupsMutex sync.RWMutex
ldap *LdapService
queries *repository.Queries
config AuthServiceConfig
docker *DockerService
loginAttempts map[string]*LoginAttempt
ldapGroupsCache map[string]*LdapGroupsCache
oauthPendingSessions map[string]*OAuthPendingSession
oauthMutex sync.RWMutex
loginMutex sync.RWMutex
ldapGroupsMutex sync.RWMutex
ldap *LdapService
queries *repository.Queries
oauthBroker *OAuthBrokerService
}
func NewAuthService(config AuthServiceConfig, docker *DockerService, ldap *LdapService, queries *repository.Queries) *AuthService {
func NewAuthService(config AuthServiceConfig, docker *DockerService, ldap *LdapService, queries *repository.Queries, oauthBroker *OAuthBrokerService) *AuthService {
return &AuthService{
config: config,
docker: docker,
@@ -63,10 +75,12 @@ func NewAuthService(config AuthServiceConfig, docker *DockerService, ldap *LdapS
ldapGroupsCache: make(map[string]*LdapGroupsCache),
ldap: ldap,
queries: queries,
oauthBroker: oauthBroker,
}
}
func (auth *AuthService) Init() error {
go auth.CleanupOAuthSessionsRoutine()
return nil
}
@@ -553,3 +567,135 @@ func (auth *AuthService) IsBypassedIP(acls config.AppIP, ip string) bool {
tlog.App.Debug().Str("ip", ip).Msg("IP not in bypass list, continuing with authentication")
return false
}
func (auth *AuthService) NewOAuthSession(serviceName string) (string, error) {
service, ok := auth.oauthBroker.GetService(serviceName)
if !ok {
return "", fmt.Errorf("oauth service not found: %s", serviceName)
}
sessionId, err := uuid.NewRandom()
if err != nil {
return "", fmt.Errorf("failed to generate session ID: %w", err)
}
state := uuid.New().String()
verifier := uuid.New().String()
auth.oauthMutex.Lock()
auth.oauthPendingSessions[sessionId.String()] = &OAuthPendingSession{
State: state,
Verifier: verifier,
Service: &service,
ExpiresAt: time.Now().Add(1 * time.Hour),
}
auth.oauthMutex.Unlock()
return sessionId.String(), nil
}
func (auth *AuthService) GetOAuthURL(sessionId string) (string, error) {
auth.oauthMutex.RLock()
defer auth.oauthMutex.RUnlock()
session, exists := auth.oauthPendingSessions[sessionId]
if !exists {
return "", fmt.Errorf("oauth session not found: %s", sessionId)
}
if time.Now().After(session.ExpiresAt) {
delete(auth.oauthPendingSessions, sessionId)
return "", fmt.Errorf("oauth session expired: %s", sessionId)
}
return (*session.Service).GetAuthURL(session.State, session.Verifier), nil
}
func (auth *AuthService) GetOAuthToken(sessionId string, code string) (*oauth2.Token, error) {
auth.oauthMutex.RLock()
session, exists := auth.oauthPendingSessions[sessionId]
auth.oauthMutex.RUnlock()
if !exists {
return nil, fmt.Errorf("oauth session not found: %s", sessionId)
}
if time.Now().After(session.ExpiresAt) {
auth.oauthMutex.Lock()
delete(auth.oauthPendingSessions, sessionId)
auth.oauthMutex.Unlock()
return nil, fmt.Errorf("oauth session expired: %s", sessionId)
}
token, err := (*session.Service).GetToken(code, session.Verifier)
if err != nil {
return nil, fmt.Errorf("failed to exchange code for token: %w", err)
}
auth.oauthMutex.Lock()
session.Token = token
auth.oauthMutex.Unlock()
return token, nil
}
func (auth *AuthService) GetOAuthUserinfo(sessionId string) (config.Claims, error) {
auth.oauthMutex.RLock()
session, exists := auth.oauthPendingSessions[sessionId]
auth.oauthMutex.RUnlock()
if !exists {
return config.Claims{}, fmt.Errorf("oauth session not found: %s", sessionId)
}
if time.Now().After(session.ExpiresAt) {
auth.oauthMutex.Lock()
delete(auth.oauthPendingSessions, sessionId)
auth.oauthMutex.Unlock()
return config.Claims{}, fmt.Errorf("oauth session expired: %s", sessionId)
}
if session.Token == nil {
return config.Claims{}, fmt.Errorf("oauth token not found for session: %s", sessionId)
}
userinfo, err := (*session.Service).GetUserinfo(session.Token)
if err != nil {
return config.Claims{}, fmt.Errorf("failed to get userinfo: %w", err)
}
auth.oauthMutex.Lock()
delete(auth.oauthPendingSessions, sessionId)
auth.oauthMutex.Unlock()
return userinfo, nil
}
func (auth *AuthService) EndOAuthSession(sessionId string) {
auth.oauthMutex.Lock()
delete(auth.oauthPendingSessions, sessionId)
auth.oauthMutex.Unlock()
}
func (auth *AuthService) CleanupOAuthSessionsRoutine() {
ticker := time.NewTicker(30 * time.Minute)
defer ticker.Stop()
for range ticker.C {
auth.oauthMutex.Lock()
defer auth.oauthMutex.Unlock()
now := time.Now()
for sessionId, session := range auth.oauthPendingSessions {
if now.After(session.ExpiresAt) {
delete(auth.oauthPendingSessions, sessionId)
}
}
}
}

View File

@@ -1,132 +0,0 @@
package service
import (
"context"
"crypto/rand"
"crypto/tls"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"time"
"github.com/steveiliop56/tinyauth/internal/config"
"github.com/steveiliop56/tinyauth/internal/utils/tlog"
"golang.org/x/oauth2"
)
type GenericOAuthService struct {
config oauth2.Config
context context.Context
token *oauth2.Token
verifier string
insecureSkipVerify bool
userinfoUrl string
name string
}
func NewGenericOAuthService(config config.OAuthServiceConfig) *GenericOAuthService {
return &GenericOAuthService{
config: oauth2.Config{
ClientID: config.ClientID,
ClientSecret: config.ClientSecret,
RedirectURL: config.RedirectURL,
Scopes: config.Scopes,
Endpoint: oauth2.Endpoint{
AuthURL: config.AuthURL,
TokenURL: config.TokenURL,
},
},
insecureSkipVerify: config.Insecure,
userinfoUrl: config.UserinfoURL,
name: config.Name,
}
}
func (generic *GenericOAuthService) Init() error {
transport := &http.Transport{
TLSClientConfig: &tls.Config{
InsecureSkipVerify: generic.insecureSkipVerify,
MinVersion: tls.VersionTLS12,
},
}
httpClient := &http.Client{
Transport: transport,
Timeout: 30 * time.Second,
}
ctx := context.Background()
ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient)
generic.context = ctx
return nil
}
func (generic *GenericOAuthService) GenerateState() string {
b := make([]byte, 128)
_, err := rand.Read(b)
if err != nil {
return base64.RawURLEncoding.EncodeToString(fmt.Appendf(nil, "state-%d", time.Now().UnixNano()))
}
state := base64.RawURLEncoding.EncodeToString(b)
return state
}
func (generic *GenericOAuthService) GenerateVerifier() string {
verifier := oauth2.GenerateVerifier()
generic.verifier = verifier
return verifier
}
func (generic *GenericOAuthService) GetAuthURL(state string) string {
return generic.config.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(generic.verifier))
}
func (generic *GenericOAuthService) VerifyCode(code string) error {
token, err := generic.config.Exchange(generic.context, code, oauth2.VerifierOption(generic.verifier))
if err != nil {
return err
}
generic.token = token
return nil
}
func (generic *GenericOAuthService) Userinfo() (config.Claims, error) {
var user config.Claims
client := generic.config.Client(generic.context, generic.token)
res, err := client.Get(generic.userinfoUrl)
if err != nil {
return user, err
}
defer res.Body.Close()
if res.StatusCode < 200 || res.StatusCode >= 300 {
return user, fmt.Errorf("request failed with status: %s", res.Status)
}
body, err := io.ReadAll(res.Body)
if err != nil {
return user, err
}
tlog.App.Trace().Str("body", string(body)).Msg("Userinfo response body")
err = json.Unmarshal(body, &user)
if err != nil {
return user, err
}
return user, nil
}
func (generic *GenericOAuthService) GetName() string {
return generic.name
}

View File

@@ -1,184 +0,0 @@
package service
import (
"context"
"crypto/rand"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strconv"
"time"
"github.com/steveiliop56/tinyauth/internal/config"
"golang.org/x/oauth2"
"golang.org/x/oauth2/endpoints"
)
var GithubOAuthScopes = []string{"user:email", "read:user"}
type GithubEmailResponse []struct {
Email string `json:"email"`
Primary bool `json:"primary"`
}
type GithubUserInfoResponse struct {
Login string `json:"login"`
Name string `json:"name"`
ID int `json:"id"`
}
type GithubOAuthService struct {
config oauth2.Config
context context.Context
token *oauth2.Token
verifier string
name string
}
func NewGithubOAuthService(config config.OAuthServiceConfig) *GithubOAuthService {
return &GithubOAuthService{
config: oauth2.Config{
ClientID: config.ClientID,
ClientSecret: config.ClientSecret,
RedirectURL: config.RedirectURL,
Scopes: GithubOAuthScopes,
Endpoint: endpoints.GitHub,
},
name: config.Name,
}
}
func (github *GithubOAuthService) Init() error {
httpClient := &http.Client{
Timeout: 30 * time.Second,
}
ctx := context.Background()
ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient)
github.context = ctx
return nil
}
func (github *GithubOAuthService) GenerateState() string {
b := make([]byte, 128)
_, err := rand.Read(b)
if err != nil {
return base64.RawURLEncoding.EncodeToString(fmt.Appendf(nil, "state-%d", time.Now().UnixNano()))
}
state := base64.RawURLEncoding.EncodeToString(b)
return state
}
func (github *GithubOAuthService) GenerateVerifier() string {
verifier := oauth2.GenerateVerifier()
github.verifier = verifier
return verifier
}
func (github *GithubOAuthService) GetAuthURL(state string) string {
return github.config.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(github.verifier))
}
func (github *GithubOAuthService) VerifyCode(code string) error {
token, err := github.config.Exchange(github.context, code, oauth2.VerifierOption(github.verifier))
if err != nil {
return err
}
github.token = token
return nil
}
func (github *GithubOAuthService) Userinfo() (config.Claims, error) {
var user config.Claims
client := github.config.Client(github.context, github.token)
req, err := http.NewRequest("GET", "https://api.github.com/user", nil)
if err != nil {
return user, err
}
req.Header.Set("Accept", "application/vnd.github+json")
res, err := client.Do(req)
if err != nil {
return user, err
}
defer res.Body.Close()
if res.StatusCode < 200 || res.StatusCode >= 300 {
return user, fmt.Errorf("request failed with status: %s", res.Status)
}
body, err := io.ReadAll(res.Body)
if err != nil {
return user, err
}
var userInfo GithubUserInfoResponse
err = json.Unmarshal(body, &userInfo)
if err != nil {
return user, err
}
req, err = http.NewRequest("GET", "https://api.github.com/user/emails", nil)
if err != nil {
return user, err
}
req.Header.Set("Accept", "application/vnd.github+json")
res, err = client.Do(req)
if err != nil {
return user, err
}
defer res.Body.Close()
if res.StatusCode < 200 || res.StatusCode >= 300 {
return user, fmt.Errorf("request failed with status: %s", res.Status)
}
body, err = io.ReadAll(res.Body)
if err != nil {
return user, err
}
var emails GithubEmailResponse
err = json.Unmarshal(body, &emails)
if err != nil {
return user, err
}
for _, email := range emails {
if email.Primary {
user.Email = email.Email
break
}
}
if len(emails) == 0 {
return user, errors.New("no emails found")
}
// Use first available email if no primary email was found
if user.Email == "" {
user.Email = emails[0].Email
}
user.PreferredUsername = userInfo.Login
user.Name = userInfo.Name
user.Sub = strconv.Itoa(userInfo.ID)
return user, nil
}
func (github *GithubOAuthService) GetName() string {
return github.name
}

View File

@@ -1,116 +0,0 @@
package service
import (
"context"
"crypto/rand"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/steveiliop56/tinyauth/internal/config"
"golang.org/x/oauth2"
"golang.org/x/oauth2/endpoints"
)
var GoogleOAuthScopes = []string{"openid", "email", "profile"}
type GoogleOAuthService struct {
config oauth2.Config
context context.Context
token *oauth2.Token
verifier string
name string
}
func NewGoogleOAuthService(config config.OAuthServiceConfig) *GoogleOAuthService {
return &GoogleOAuthService{
config: oauth2.Config{
ClientID: config.ClientID,
ClientSecret: config.ClientSecret,
RedirectURL: config.RedirectURL,
Scopes: GoogleOAuthScopes,
Endpoint: endpoints.Google,
},
name: config.Name,
}
}
func (google *GoogleOAuthService) Init() error {
httpClient := &http.Client{
Timeout: 30 * time.Second,
}
ctx := context.Background()
ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient)
google.context = ctx
return nil
}
func (oauth *GoogleOAuthService) GenerateState() string {
b := make([]byte, 128)
_, err := rand.Read(b)
if err != nil {
return base64.RawURLEncoding.EncodeToString(fmt.Appendf(nil, "state-%d", time.Now().UnixNano()))
}
state := base64.RawURLEncoding.EncodeToString(b)
return state
}
func (google *GoogleOAuthService) GenerateVerifier() string {
verifier := oauth2.GenerateVerifier()
google.verifier = verifier
return verifier
}
func (google *GoogleOAuthService) GetAuthURL(state string) string {
return google.config.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(google.verifier))
}
func (google *GoogleOAuthService) VerifyCode(code string) error {
token, err := google.config.Exchange(google.context, code, oauth2.VerifierOption(google.verifier))
if err != nil {
return err
}
google.token = token
return nil
}
func (google *GoogleOAuthService) Userinfo() (config.Claims, error) {
var user config.Claims
client := google.config.Client(google.context, google.token)
res, err := client.Get("https://openidconnect.googleapis.com/v1/userinfo")
if err != nil {
return config.Claims{}, err
}
defer res.Body.Close()
if res.StatusCode < 200 || res.StatusCode >= 300 {
return user, fmt.Errorf("request failed with status: %s", res.Status)
}
body, err := io.ReadAll(res.Body)
if err != nil {
return config.Claims{}, err
}
err = json.Unmarshal(body, &user)
if err != nil {
return config.Claims{}, err
}
user.PreferredUsername = strings.SplitN(user.Email, "@", 2)[0]
return user, nil
}
func (google *GoogleOAuthService) GetName() string {
return google.name
}

View File

@@ -1,60 +1,48 @@
package service
import (
"errors"
"github.com/steveiliop56/tinyauth/internal/config"
"github.com/steveiliop56/tinyauth/internal/utils/tlog"
"golang.org/x/exp/slices"
"golang.org/x/oauth2"
)
type OAuthService interface {
Init() error
GenerateState() string
GenerateVerifier() string
GetAuthURL(state string) string
VerifyCode(code string) error
Userinfo() (config.Claims, error)
GetName() string
type OAuthServiceImpl interface {
Name() string
NewRandom() string
GetAuthURL(state string, verifier string) string
GetToken(code string, verifier string) (*oauth2.Token, error)
GetUserinfo(token *oauth2.Token) (config.Claims, error)
}
type OAuthBrokerService struct {
services map[string]OAuthService
services map[string]OAuthServiceImpl
configs map[string]config.OAuthServiceConfig
}
var presets = map[string]func(config config.OAuthServiceConfig) *OAuthService{
"github": newGitHubOAuthService,
"google": newGoogleOAuthService,
}
func NewOAuthBrokerService(configs map[string]config.OAuthServiceConfig) *OAuthBrokerService {
return &OAuthBrokerService{
services: make(map[string]OAuthService),
services: make(map[string]OAuthServiceImpl),
configs: configs,
}
}
func (broker *OAuthBrokerService) Init() error {
for name, cfg := range broker.configs {
switch name {
case "github":
service := NewGithubOAuthService(cfg)
broker.services[name] = service
case "google":
service := NewGoogleOAuthService(cfg)
broker.services[name] = service
default:
service := NewGenericOAuthService(cfg)
broker.services[name] = service
if presetFunc, exists := presets[name]; exists {
broker.services[name] = presetFunc(cfg)
tlog.App.Debug().Str("service", name).Msg("Loaded OAuth service from preset")
} else {
broker.services[name] = NewOAuthService(cfg)
tlog.App.Debug().Str("service", name).Msg("Loaded OAuth service from config")
}
}
for name, service := range broker.services {
err := service.Init()
if err != nil {
tlog.App.Error().Err(err).Msgf("Failed to initialize OAuth service: %s", name)
return err
}
tlog.App.Info().Str("service", name).Msg("Initialized OAuth service")
}
return nil
}
@@ -67,15 +55,7 @@ func (broker *OAuthBrokerService) GetConfiguredServices() []string {
return services
}
func (broker *OAuthBrokerService) GetService(name string) (OAuthService, bool) {
func (broker *OAuthBrokerService) GetService(name string) (OAuthServiceImpl, bool) {
service, exists := broker.services[name]
return service, exists
}
func (broker *OAuthBrokerService) GetUser(service string) (config.Claims, error) {
oauthService, exists := broker.services[service]
if !exists {
return config.Claims{}, errors.New("oauth service not found")
}
return oauthService.Userinfo()
}

View File

@@ -0,0 +1,121 @@
package service
import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strconv"
"github.com/steveiliop56/tinyauth/internal/config"
"github.com/steveiliop56/tinyauth/internal/utils/tlog"
)
type GithubEmailResponse []struct {
Email string `json:"email"`
Primary bool `json:"primary"`
}
type GithubUserInfoResponse struct {
Login string `json:"login"`
Name string `json:"name"`
ID int `json:"id"`
}
func defaultExtractor(client *http.Client, url string) (config.Claims, error) {
var claims config.Claims
res, err := client.Get(url)
if err != nil {
return config.Claims{}, err
}
defer res.Body.Close()
if res.StatusCode < 200 || res.StatusCode >= 300 {
return config.Claims{}, fmt.Errorf("request failed with status: %s", res.Status)
}
body, err := io.ReadAll(res.Body)
if err != nil {
return config.Claims{}, err
}
tlog.App.Trace().Str("body", string(body)).Msg("Userinfo response body")
err = json.Unmarshal(body, &claims)
if err != nil {
return config.Claims{}, err
}
return claims, nil
}
func githubExtractor(client *http.Client, url string) (config.Claims, error) {
var user config.Claims
userInfo, err := githubRequest[GithubUserInfoResponse](client, "https://api.github.com/user")
if err != nil {
return config.Claims{}, err
}
userEmails, err := githubRequest[GithubEmailResponse](client, "https://api.github.com/user/emails")
if err != nil {
return config.Claims{}, err
}
if len(userEmails) == 0 {
return user, errors.New("no emails found")
}
for _, email := range userEmails {
if email.Primary {
user.Email = email.Email
break
}
}
// Use first available email if no primary email was found
if user.Email == "" {
user.Email = userEmails[0].Email
}
user.PreferredUsername = userInfo.Login
user.Name = userInfo.Name
user.Sub = strconv.Itoa(userInfo.ID)
return user, nil
}
func githubRequest[T any](client *http.Client, url string) (T, error) {
var githubRes T
req, err := http.NewRequest("GET", "https://api.github.com/user", nil)
if err != nil {
return githubRes, err
}
req.Header.Set("Accept", "application/vnd.github+json")
res, err := client.Do(req)
if err != nil {
return githubRes, err
}
defer res.Body.Close()
if res.StatusCode < 200 || res.StatusCode >= 300 {
return githubRes, fmt.Errorf("request failed with status: %s", res.Status)
}
body, err := io.ReadAll(res.Body)
if err != nil {
return githubRes, err
}
err = json.Unmarshal(body, &githubRes)
if err != nil {
return githubRes, err
}
return githubRes, nil
}

View File

@@ -0,0 +1,23 @@
package service
import (
"github.com/steveiliop56/tinyauth/internal/config"
"golang.org/x/oauth2/endpoints"
)
func newGoogleOAuthService(config config.OAuthServiceConfig) *OAuthService {
scopes := []string{"openid", "email", "profile"}
config.Scopes = scopes
config.AuthURL = endpoints.Google.AuthURL
config.TokenURL = endpoints.Google.TokenURL
config.UserinfoURL = "https://openidconnect.googleapis.com/v1/userinfo"
return NewOAuthService(config)
}
func newGitHubOAuthService(config config.OAuthServiceConfig) *OAuthService {
scopes := []string{"read:user", "user:email"}
config.Scopes = scopes
config.AuthURL = endpoints.GitHub.AuthURL
config.TokenURL = endpoints.GitHub.TokenURL
return NewOAuthService(config).WithUserinfoExtractor(githubExtractor)
}

View File

@@ -0,0 +1,78 @@
package service
import (
"context"
"crypto/tls"
"net/http"
"time"
"github.com/steveiliop56/tinyauth/internal/config"
"golang.org/x/oauth2"
)
type UserinfoExtractor func(client *http.Client, url string) (config.Claims, error)
type OAuthService struct {
serviceCfg config.OAuthServiceConfig
config *oauth2.Config
ctx context.Context
userinfoExtractor UserinfoExtractor
}
func NewOAuthService(config config.OAuthServiceConfig) *OAuthService {
httpClient := &http.Client{
Timeout: 30 * time.Second,
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
InsecureSkipVerify: config.Insecure,
},
},
}
ctx := context.Background()
ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient)
return &OAuthService{
serviceCfg: config,
config: &oauth2.Config{
ClientID: config.ClientID,
ClientSecret: config.ClientSecret,
RedirectURL: config.RedirectURL,
Scopes: config.Scopes,
Endpoint: oauth2.Endpoint{
AuthURL: config.AuthURL,
TokenURL: config.TokenURL,
},
},
ctx: ctx,
userinfoExtractor: defaultExtractor,
}
}
func (s *OAuthService) WithUserinfoExtractor(extractor UserinfoExtractor) *OAuthService {
s.userinfoExtractor = extractor
return s
}
func (s *OAuthService) Name() string {
return s.serviceCfg.Name
}
func (s *OAuthService) NewRandom() string {
// The generate verifier function just creates a random string,
// so we can use it to generate a random state as well
random := oauth2.GenerateVerifier()
return random
}
func (s *OAuthService) GetAuthURL(state string, verifier string) string {
return s.config.AuthCodeURL(state, oauth2.AccessTypeOnline, oauth2.S256ChallengeOption(verifier))
}
func (s *OAuthService) GetToken(code string, verifier string) (*oauth2.Token, error) {
return s.config.Exchange(s.ctx, code, oauth2.VerifierOption(verifier))
}
func (s *OAuthService) GetUserinfo(token *oauth2.Token) (config.Claims, error) {
client := oauth2.NewClient(s.ctx, oauth2.StaticTokenSource(token))
return s.userinfoExtractor(client, s.serviceCfg.UserinfoURL)
}