Compare commits

..

1 Commits

Author SHA1 Message Date
dependabot[bot]
436e6b60e9 chore(deps): bump oven/bun from 1.3.10-alpine to 1.3.11-alpine
Bumps oven/bun from 1.3.10-alpine to 1.3.11-alpine.

---
updated-dependencies:
- dependency-name: oven/bun
  dependency-version: 1.3.11-alpine
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-03-18 08:15:03 +00:00
24 changed files with 949 additions and 1147 deletions

View File

@@ -15,6 +15,8 @@ jobs:
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@v4 uses: actions/checkout@v4
with:
ref: nightly
- name: Generate metadata - name: Generate metadata
id: metadata id: metadata

View File

@@ -1,5 +1,5 @@
# Site builder # Site builder
FROM oven/bun:1.3.10-alpine AS frontend-builder FROM oven/bun:1.3.11-alpine AS frontend-builder
WORKDIR /frontend WORKDIR /frontend

View File

@@ -1,5 +1,5 @@
# Site builder # Site builder
FROM oven/bun:1.3.10-alpine AS frontend-builder FROM oven/bun:1.3.11-alpine AS frontend-builder
WORKDIR /frontend WORKDIR /frontend

3
go.mod
View File

@@ -17,7 +17,6 @@ require (
github.com/mdp/qrterminal/v3 v3.2.1 github.com/mdp/qrterminal/v3 v3.2.1
github.com/pquerna/otp v1.5.0 github.com/pquerna/otp v1.5.0
github.com/rs/zerolog v1.34.0 github.com/rs/zerolog v1.34.0
github.com/stretchr/testify v1.11.1
github.com/traefik/paerser v0.2.2 github.com/traefik/paerser v0.2.2
github.com/weppos/publicsuffix-go v0.50.3 github.com/weppos/publicsuffix-go v0.50.3
golang.org/x/crypto v0.49.0 golang.org/x/crypto v0.49.0
@@ -53,7 +52,6 @@ require (
github.com/containerd/errdefs v1.0.0 // indirect github.com/containerd/errdefs v1.0.0 // indirect
github.com/containerd/errdefs/pkg v0.3.0 // indirect github.com/containerd/errdefs/pkg v0.3.0 // indirect
github.com/containerd/log v0.1.0 // indirect github.com/containerd/log v0.1.0 // indirect
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
github.com/distribution/reference v0.6.0 // indirect github.com/distribution/reference v0.6.0 // indirect
github.com/docker/go-connections v0.5.0 // indirect github.com/docker/go-connections v0.5.0 // indirect
github.com/docker/go-units v0.5.0 // indirect github.com/docker/go-units v0.5.0 // indirect
@@ -98,7 +96,6 @@ require (
github.com/opencontainers/image-spec v1.1.0 // indirect github.com/opencontainers/image-spec v1.1.0 // indirect
github.com/pelletier/go-toml/v2 v2.2.4 // indirect github.com/pelletier/go-toml/v2 v2.2.4 // indirect
github.com/pkg/errors v0.9.1 // indirect github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
github.com/quic-go/qpack v0.6.0 // indirect github.com/quic-go/qpack v0.6.0 // indirect
github.com/quic-go/quic-go v0.59.0 // indirect github.com/quic-go/quic-go v0.59.0 // indirect
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect

View File

@@ -22,17 +22,16 @@ import (
type BootstrapApp struct { type BootstrapApp struct {
config config.Config config config.Config
context struct { context struct {
appUrl string appUrl string
uuid string uuid string
cookieDomain string cookieDomain string
sessionCookieName string sessionCookieName string
csrfCookieName string csrfCookieName string
redirectCookieName string redirectCookieName string
oauthSessionCookieName string users []config.User
users []config.User oauthProviders map[string]config.OAuthServiceConfig
oauthProviders map[string]config.OAuthServiceConfig configuredProviders []controller.Provider
configuredProviders []controller.Provider oidcClients []config.OIDCClientConfig
oidcClients []config.OIDCClientConfig
} }
services Services services Services
} }
@@ -114,7 +113,6 @@ func (app *BootstrapApp) Setup() error {
app.context.sessionCookieName = fmt.Sprintf("%s-%s", config.SessionCookieName, cookieId) app.context.sessionCookieName = fmt.Sprintf("%s-%s", config.SessionCookieName, cookieId)
app.context.csrfCookieName = fmt.Sprintf("%s-%s", config.CSRFCookieName, cookieId) app.context.csrfCookieName = fmt.Sprintf("%s-%s", config.CSRFCookieName, cookieId)
app.context.redirectCookieName = fmt.Sprintf("%s-%s", config.RedirectCookieName, cookieId) app.context.redirectCookieName = fmt.Sprintf("%s-%s", config.RedirectCookieName, cookieId)
app.context.oauthSessionCookieName = fmt.Sprintf("%s-%s", config.OAuthSessionCookieName, cookieId)
// Dumps // Dumps
tlog.App.Trace().Interface("config", app.config).Msg("Config dump") tlog.App.Trace().Interface("config", app.config).Msg("Config dump")
@@ -192,12 +190,12 @@ func (app *BootstrapApp) Setup() error {
// Start db cleanup routine // Start db cleanup routine
tlog.App.Debug().Msg("Starting database cleanup routine") tlog.App.Debug().Msg("Starting database cleanup routine")
go app.dbCleanupRoutine(queries) go app.dbCleanup(queries)
// If analytics are not disabled, start heartbeat // If analytics are not disabled, start heartbeat
if app.config.Analytics.Enabled { if app.config.Analytics.Enabled {
tlog.App.Debug().Msg("Starting heartbeat routine") tlog.App.Debug().Msg("Starting heartbeat routine")
go app.heartbeatRoutine() go app.heartbeat()
} }
// If we have an socket path, bind to it // If we have an socket path, bind to it
@@ -228,7 +226,7 @@ func (app *BootstrapApp) Setup() error {
return nil return nil
} }
func (app *BootstrapApp) heartbeatRoutine() { func (app *BootstrapApp) heartbeat() {
ticker := time.NewTicker(time.Duration(12) * time.Hour) ticker := time.NewTicker(time.Duration(12) * time.Hour)
defer ticker.Stop() defer ticker.Stop()
@@ -282,7 +280,7 @@ func (app *BootstrapApp) heartbeatRoutine() {
} }
} }
func (app *BootstrapApp) dbCleanupRoutine(queries *repository.Queries) { func (app *BootstrapApp) dbCleanup(queries *repository.Queries) {
ticker := time.NewTicker(time.Duration(30) * time.Minute) ticker := time.NewTicker(time.Duration(30) * time.Minute)
defer ticker.Stop() defer ticker.Stop()
ctx := context.Background() ctx := context.Background()

View File

@@ -77,13 +77,12 @@ func (app *BootstrapApp) setupRouter() (*gin.Engine, error) {
contextController.SetupRoutes() contextController.SetupRoutes()
oauthController := controller.NewOAuthController(controller.OAuthControllerConfig{ oauthController := controller.NewOAuthController(controller.OAuthControllerConfig{
AppURL: app.config.AppURL, AppURL: app.config.AppURL,
SecureCookie: app.config.Auth.SecureCookie, SecureCookie: app.config.Auth.SecureCookie,
CSRFCookieName: app.context.csrfCookieName, CSRFCookieName: app.context.csrfCookieName,
RedirectCookieName: app.context.redirectCookieName, RedirectCookieName: app.context.redirectCookieName,
CookieDomain: app.context.cookieDomain, CookieDomain: app.context.cookieDomain,
OAuthSessionCookieName: app.context.oauthSessionCookieName, }, apiRouter, app.services.authService, app.services.oauthBrokerService)
}, apiRouter, app.services.authService)
oauthController.SetupRoutes() oauthController.SetupRoutes()

View File

@@ -58,16 +58,6 @@ func (app *BootstrapApp) initServices(queries *repository.Queries) (Services, er
services.accessControlService = accessControlsService services.accessControlService = accessControlsService
oauthBrokerService := service.NewOAuthBrokerService(app.context.oauthProviders)
err = oauthBrokerService.Init()
if err != nil {
return Services{}, err
}
services.oauthBrokerService = oauthBrokerService
authService := service.NewAuthService(service.AuthServiceConfig{ authService := service.NewAuthService(service.AuthServiceConfig{
Users: app.context.users, Users: app.context.users,
OauthWhitelist: app.config.OAuth.Whitelist, OauthWhitelist: app.config.OAuth.Whitelist,
@@ -80,7 +70,7 @@ func (app *BootstrapApp) initServices(queries *repository.Queries) (Services, er
SessionCookieName: app.context.sessionCookieName, SessionCookieName: app.context.sessionCookieName,
IP: app.config.Auth.IP, IP: app.config.Auth.IP,
LDAPGroupsCacheTTL: app.config.Ldap.GroupCacheTTL, LDAPGroupsCacheTTL: app.config.Ldap.GroupCacheTTL,
}, dockerService, services.ldapService, queries, services.oauthBrokerService) }, dockerService, services.ldapService, queries)
err = authService.Init() err = authService.Init()
@@ -90,6 +80,16 @@ func (app *BootstrapApp) initServices(queries *repository.Queries) (Services, er
services.authService = authService services.authService = authService
oauthBrokerService := service.NewOAuthBrokerService(app.context.oauthProviders)
err = oauthBrokerService.Init()
if err != nil {
return Services{}, err
}
services.oauthBrokerService = oauthBrokerService
oidcService := service.NewOIDCService(service.OIDCServiceConfig{ oidcService := service.NewOIDCService(service.OIDCServiceConfig{
Clients: app.config.OIDC.Clients, Clients: app.config.OIDC.Clients,
PrivateKeyPath: app.config.OIDC.PrivateKeyPath, PrivateKeyPath: app.config.OIDC.PrivateKeyPath,

View File

@@ -73,7 +73,6 @@ var BuildTimestamp = "0000-00-00T00:00:00Z"
var SessionCookieName = "tinyauth-session" var SessionCookieName = "tinyauth-session"
var CSRFCookieName = "tinyauth-csrf" var CSRFCookieName = "tinyauth-csrf"
var RedirectCookieName = "tinyauth-redirect" var RedirectCookieName = "tinyauth-redirect"
var OAuthSessionCookieName = "tinyauth-oauth"
// Main app config // Main app config

View File

@@ -2,131 +2,152 @@ package controller_test
import ( import (
"encoding/json" "encoding/json"
"net/http"
"net/http/httptest" "net/http/httptest"
"testing" "testing"
"github.com/gin-gonic/gin"
"github.com/steveiliop56/tinyauth/internal/config" "github.com/steveiliop56/tinyauth/internal/config"
"github.com/steveiliop56/tinyauth/internal/controller" "github.com/steveiliop56/tinyauth/internal/controller"
"github.com/steveiliop56/tinyauth/internal/utils" "github.com/steveiliop56/tinyauth/internal/utils/tlog"
"github.com/stretchr/testify/assert"
"github.com/gin-gonic/gin"
"gotest.tools/v3/assert"
) )
func TestContextController(t *testing.T) { var contextControllerCfg = controller.ContextControllerConfig{
controllerConfig := controller.ContextControllerConfig{ Providers: []controller.Provider{
Providers: []controller.Provider{
{
Name: "Local",
ID: "local",
OAuth: false,
},
},
Title: "Tinyauth",
AppURL: "https://tinyauth.example.com",
CookieDomain: "example.com",
ForgotPasswordMessage: "foo",
BackgroundImage: "/background.jpg",
OAuthAutoRedirect: "none",
WarningsEnabled: true,
}
tests := []struct {
description string
middlewares []gin.HandlerFunc
expected string
path string
}{
{ {
description: "Ensure context controller returns app context", Name: "Local",
middlewares: []gin.HandlerFunc{}, ID: "local",
path: "/api/context/app", OAuth: false,
expected: func() string {
expectedAppContextResponse := controller.AppContextResponse{
Status: 200,
Message: "Success",
Providers: controllerConfig.Providers,
Title: controllerConfig.Title,
AppURL: controllerConfig.AppURL,
CookieDomain: controllerConfig.CookieDomain,
ForgotPasswordMessage: controllerConfig.ForgotPasswordMessage,
BackgroundImage: controllerConfig.BackgroundImage,
OAuthAutoRedirect: controllerConfig.OAuthAutoRedirect,
WarningsEnabled: controllerConfig.WarningsEnabled,
}
bytes, err := json.Marshal(expectedAppContextResponse)
assert.NoError(t, err)
return string(bytes)
}(),
}, },
{ {
description: "Ensure user context returns 401 when unauthorized", Name: "Google",
middlewares: []gin.HandlerFunc{}, ID: "google",
path: "/api/context/user", OAuth: true,
expected: func() string {
expectedUserContextResponse := controller.UserContextResponse{
Status: 401,
Message: "Unauthorized",
}
bytes, err := json.Marshal(expectedUserContextResponse)
assert.NoError(t, err)
return string(bytes)
}(),
}, },
{ },
description: "Ensure user context returns when authorized", Title: "Test App",
middlewares: []gin.HandlerFunc{ AppURL: "http://localhost:8080",
func(c *gin.Context) { CookieDomain: "localhost",
c.Set("context", &config.UserContext{ ForgotPasswordMessage: "Contact admin to reset your password.",
Username: "johndoe", BackgroundImage: "/assets/bg.jpg",
Name: "John Doe", OAuthAutoRedirect: "google",
Email: utils.CompileUserEmail("johndoe", controllerConfig.CookieDomain), WarningsEnabled: true,
Provider: "local", }
IsLoggedIn: true,
}) var contextCtrlTestContext = config.UserContext{
}, Username: "testuser",
}, Name: "testuser",
path: "/api/context/user", Email: "test@example.com",
expected: func() string { IsLoggedIn: true,
expectedUserContextResponse := controller.UserContextResponse{ IsBasicAuth: false,
Status: 200, OAuth: false,
Message: "Success", Provider: "local",
IsLoggedIn: true, TotpPending: false,
Username: "johndoe", OAuthGroups: "",
Name: "John Doe", TotpEnabled: false,
Email: utils.CompileUserEmail("johndoe", controllerConfig.CookieDomain), OAuthSub: "",
Provider: "local", }
}
bytes, err := json.Marshal(expectedUserContextResponse) func setupContextController(middlewares *[]gin.HandlerFunc) (*gin.Engine, *httptest.ResponseRecorder) {
assert.NoError(t, err) tlog.NewSimpleLogger().Init()
return string(bytes)
}(), // Setup
}, gin.SetMode(gin.TestMode)
} router := gin.Default()
recorder := httptest.NewRecorder()
for _, test := range tests {
t.Run(test.description, func(t *testing.T) { if middlewares != nil {
router := gin.Default() for _, m := range *middlewares {
router.Use(m)
for _, middleware := range test.middlewares { }
router.Use(middleware) }
}
group := router.Group("/api")
group := router.Group("/api")
gin.SetMode(gin.TestMode) ctrl := controller.NewContextController(contextControllerCfg, group)
ctrl.SetupRoutes()
contextController := controller.NewContextController(controllerConfig, group)
contextController.SetupRoutes() return router, recorder
}
recorder := httptest.NewRecorder()
func TestAppContextHandler(t *testing.T) {
request, err := http.NewRequest("GET", test.path, nil) expectedRes := controller.AppContextResponse{
assert.NoError(t, err) Status: 200,
Message: "Success",
router.ServeHTTP(recorder, request) Providers: contextControllerCfg.Providers,
Title: contextControllerCfg.Title,
assert.Equal(t, recorder.Result().StatusCode, http.StatusOK) AppURL: contextControllerCfg.AppURL,
assert.Equal(t, test.expected, recorder.Body.String()) CookieDomain: contextControllerCfg.CookieDomain,
}) ForgotPasswordMessage: contextControllerCfg.ForgotPasswordMessage,
} BackgroundImage: contextControllerCfg.BackgroundImage,
OAuthAutoRedirect: contextControllerCfg.OAuthAutoRedirect,
WarningsEnabled: contextControllerCfg.WarningsEnabled,
}
router, recorder := setupContextController(nil)
req := httptest.NewRequest("GET", "/api/context/app", nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
var ctrlRes controller.AppContextResponse
err := json.Unmarshal(recorder.Body.Bytes(), &ctrlRes)
assert.NilError(t, err)
assert.DeepEqual(t, expectedRes, ctrlRes)
}
func TestUserContextHandler(t *testing.T) {
expectedRes := controller.UserContextResponse{
Status: 200,
Message: "Success",
IsLoggedIn: contextCtrlTestContext.IsLoggedIn,
Username: contextCtrlTestContext.Username,
Name: contextCtrlTestContext.Name,
Email: contextCtrlTestContext.Email,
Provider: contextCtrlTestContext.Provider,
OAuth: contextCtrlTestContext.OAuth,
TotpPending: contextCtrlTestContext.TotpPending,
OAuthName: contextCtrlTestContext.OAuthName,
}
// Test with context
router, recorder := setupContextController(&[]gin.HandlerFunc{
func(c *gin.Context) {
c.Set("context", &contextCtrlTestContext)
c.Next()
},
})
req := httptest.NewRequest("GET", "/api/context/user", nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
var ctrlRes controller.UserContextResponse
err := json.Unmarshal(recorder.Body.Bytes(), &ctrlRes)
assert.NilError(t, err)
assert.DeepEqual(t, expectedRes, ctrlRes)
// Test no context
expectedRes = controller.UserContextResponse{
Status: 401,
Message: "Unauthorized",
IsLoggedIn: false,
}
router, recorder = setupContextController(nil)
req = httptest.NewRequest("GET", "/api/context/user", nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
err = json.Unmarshal(recorder.Body.Bytes(), &ctrlRes)
assert.NilError(t, err)
assert.DeepEqual(t, expectedRes, ctrlRes)
} }

View File

@@ -19,7 +19,7 @@ func (controller *HealthController) SetupRoutes() {
func (controller *HealthController) healthHandler(c *gin.Context) { func (controller *HealthController) healthHandler(c *gin.Context) {
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"status": 200, "status": "ok",
"message": "Healthy", "message": "Healthy",
}) })
} }

View File

@@ -1,71 +0,0 @@
package controller_test
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/steveiliop56/tinyauth/internal/controller"
"github.com/stretchr/testify/assert"
)
func TestHealthController(t *testing.T) {
tests := []struct {
description string
path string
method string
expected string
}{
{
description: "Ensure health endpoint returns 200 OK",
path: "/api/healthz",
method: "GET",
expected: func() string {
expectedHealthResponse := map[string]any{
"status": 200,
"message": "Healthy",
}
bytes, err := json.Marshal(expectedHealthResponse)
assert.NoError(t, err)
return string(bytes)
}(),
},
{
description: "Ensure health endpoint returns 200 OK for HEAD request",
path: "/api/healthz",
method: "HEAD",
expected: func() string {
expectedHealthResponse := map[string]any{
"status": 200,
"message": "Healthy",
}
bytes, err := json.Marshal(expectedHealthResponse)
assert.NoError(t, err)
return string(bytes)
}(),
},
}
for _, test := range tests {
t.Run(test.description, func(t *testing.T) {
router := gin.Default()
group := router.Group("/api")
gin.SetMode(gin.TestMode)
healthController := controller.NewHealthController(group)
healthController.SetupRoutes()
recorder := httptest.NewRecorder()
request, err := http.NewRequest(test.method, test.path, nil)
assert.NoError(t, err)
router.ServeHTTP(recorder, request)
assert.Equal(t, http.StatusOK, recorder.Code)
assert.Equal(t, test.expected, recorder.Body.String())
})
}
}

View File

@@ -21,25 +21,26 @@ type OAuthRequest struct {
} }
type OAuthControllerConfig struct { type OAuthControllerConfig struct {
CSRFCookieName string CSRFCookieName string
OAuthSessionCookieName string RedirectCookieName string
RedirectCookieName string SecureCookie bool
SecureCookie bool AppURL string
AppURL string CookieDomain string
CookieDomain string
} }
type OAuthController struct { type OAuthController struct {
config OAuthControllerConfig config OAuthControllerConfig
router *gin.RouterGroup router *gin.RouterGroup
auth *service.AuthService auth *service.AuthService
broker *service.OAuthBrokerService
} }
func NewOAuthController(config OAuthControllerConfig, router *gin.RouterGroup, auth *service.AuthService) *OAuthController { func NewOAuthController(config OAuthControllerConfig, router *gin.RouterGroup, auth *service.AuthService, broker *service.OAuthBrokerService) *OAuthController {
return &OAuthController{ return &OAuthController{
config: config, config: config,
router: router, router: router,
auth: auth, auth: auth,
broker: broker,
} }
} }
@@ -62,30 +63,21 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
return return
} }
sessionId, session, err := controller.auth.NewOAuthSession(req.Provider) service, exists := controller.broker.GetService(req.Provider)
if err != nil { if !exists {
tlog.App.Error().Err(err).Msg("Failed to create OAuth session") tlog.App.Warn().Msgf("OAuth provider not found: %s", req.Provider)
c.JSON(500, gin.H{ c.JSON(404, gin.H{
"status": 500, "status": 404,
"message": "Internal Server Error", "message": "Not Found",
}) })
return return
} }
authUrl, err := controller.auth.GetOAuthURL(sessionId) service.GenerateVerifier()
state := service.GenerateState()
if err != nil { authURL := service.GetAuthURL(state)
tlog.App.Error().Err(err).Msg("Failed to get OAuth URL") c.SetCookie(controller.config.CSRFCookieName, state, int(time.Hour.Seconds()), "/", fmt.Sprintf(".%s", controller.config.CookieDomain), controller.config.SecureCookie, true)
c.JSON(500, gin.H{
"status": 500,
"message": "Internal Server Error",
})
return
}
c.SetCookie(controller.config.OAuthSessionCookieName, sessionId, int(time.Hour.Seconds()), "/", fmt.Sprintf(".%s", controller.config.CookieDomain), controller.config.SecureCookie, true)
c.SetCookie(controller.config.CSRFCookieName, session.State, int(time.Hour.Seconds()), "/", fmt.Sprintf(".%s", controller.config.CookieDomain), controller.config.SecureCookie, true)
redirectURI := c.Query("redirect_uri") redirectURI := c.Query("redirect_uri")
isRedirectSafe := utils.IsRedirectSafe(redirectURI, controller.config.CookieDomain) isRedirectSafe := utils.IsRedirectSafe(redirectURI, controller.config.CookieDomain)
@@ -103,7 +95,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"status": 200, "status": 200,
"message": "OK", "message": "OK",
"url": authUrl, "url": authURL,
}) })
} }
@@ -120,17 +112,6 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
return return
} }
sessionIdCookie, err := c.Cookie(controller.config.OAuthSessionCookieName)
if err != nil {
tlog.App.Warn().Err(err).Msg("OAuth session cookie missing")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
return
}
c.SetCookie(controller.config.OAuthSessionCookieName, "", -1, "/", fmt.Sprintf(".%s", controller.config.CookieDomain), controller.config.SecureCookie, true)
defer controller.auth.EndOAuthSession(sessionIdCookie)
state := c.Query("state") state := c.Query("state")
csrfCookie, err := c.Cookie(controller.config.CSRFCookieName) csrfCookie, err := c.Cookie(controller.config.CSRFCookieName)
@@ -144,15 +125,28 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
c.SetCookie(controller.config.CSRFCookieName, "", -1, "/", fmt.Sprintf(".%s", controller.config.CookieDomain), controller.config.SecureCookie, true) c.SetCookie(controller.config.CSRFCookieName, "", -1, "/", fmt.Sprintf(".%s", controller.config.CookieDomain), controller.config.SecureCookie, true)
code := c.Query("code") code := c.Query("code")
_, err = controller.auth.GetOAuthToken(sessionIdCookie, code) service, exists := controller.broker.GetService(req.Provider)
if err != nil { if !exists {
tlog.App.Error().Err(err).Msg("Failed to exchange code for token") tlog.App.Warn().Msgf("OAuth provider not found: %s", req.Provider)
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
return return
} }
user, err := controller.auth.GetOAuthUserinfo(sessionIdCookie) err = service.VerifyCode(code)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to verify OAuth code")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
return
}
user, err := controller.broker.GetUser(req.Provider)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to get user from OAuth provider")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
return
}
if user.Email == "" { if user.Email == "" {
tlog.App.Error().Msg("OAuth provider did not return an email") tlog.App.Error().Msg("OAuth provider did not return an email")
@@ -198,21 +192,13 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
username = strings.Replace(user.Email, "@", "_", 1) username = strings.Replace(user.Email, "@", "_", 1)
} }
service, err := controller.auth.GetOAuthService(sessionIdCookie)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to get OAuth service for session")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
return
}
sessionCookie := repository.Session{ sessionCookie := repository.Session{
Username: username, Username: username,
Name: name, Name: name,
Email: user.Email, Email: user.Email,
Provider: req.Provider, Provider: req.Provider,
OAuthGroups: utils.CoalesceToString(user.Groups), OAuthGroups: utils.CoalesceToString(user.Groups),
OAuthName: service.Name(), OAuthName: service.GetName(),
OAuthSub: user.Sub, OAuthSub: user.Sub,
} }

View File

@@ -235,7 +235,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
if !ok { if !ok {
tlog.App.Error().Msg("Missing authorization header") tlog.App.Error().Msg("Missing authorization header")
c.Header("www-authenticate", `Basic realm="Tinyauth OIDC Token Endpoint"`) c.Header("www-authenticate", "basic")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
"error": "invalid_client", "error": "invalid_client",
}) })

View File

@@ -2,465 +2,280 @@ package controller_test
import ( import (
"encoding/json" "encoding/json"
"fmt"
"net/http"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
"strings" "strings"
"testing" "testing"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/google/go-querystring/query"
"github.com/steveiliop56/tinyauth/internal/bootstrap" "github.com/steveiliop56/tinyauth/internal/bootstrap"
"github.com/steveiliop56/tinyauth/internal/config" "github.com/steveiliop56/tinyauth/internal/config"
"github.com/steveiliop56/tinyauth/internal/controller" "github.com/steveiliop56/tinyauth/internal/controller"
"github.com/steveiliop56/tinyauth/internal/repository" "github.com/steveiliop56/tinyauth/internal/repository"
"github.com/steveiliop56/tinyauth/internal/service" "github.com/steveiliop56/tinyauth/internal/service"
"github.com/stretchr/testify/assert" "github.com/steveiliop56/tinyauth/internal/utils/tlog"
"gotest.tools/v3/assert"
) )
var oidcServiceConfig = service.OIDCServiceConfig{
Clients: map[string]config.OIDCClientConfig{
"client1": {
ClientID: "some-client-id",
ClientSecret: "some-client-secret",
ClientSecretFile: "",
TrustedRedirectURIs: []string{
"https://example.com/oauth/callback",
},
Name: "Client 1",
},
},
PrivateKeyPath: "/tmp/tinyauth_oidc_key",
PublicKeyPath: "/tmp/tinyauth_oidc_key.pub",
Issuer: "https://example.com",
SessionExpiry: 3600,
}
var oidcCtrlTestContext = config.UserContext{
Username: "test",
Name: "Test",
Email: "test@example.com",
IsLoggedIn: true,
IsBasicAuth: false,
OAuth: false,
Provider: "ldap", // ldap in order to test the groups
TotpPending: false,
OAuthGroups: "",
TotpEnabled: false,
OAuthName: "",
OAuthSub: "",
LdapGroups: "test1,test2",
}
// Test is not amazing, but it will confirm the OIDC server works
func TestOIDCController(t *testing.T) { func TestOIDCController(t *testing.T) {
oidcServiceCfg := service.OIDCServiceConfig{ tlog.NewSimpleLogger().Init()
Clients: map[string]config.OIDCClientConfig{
"test": {
ClientID: "some-client-id",
ClientSecret: "some-client-secret",
TrustedRedirectURIs: []string{"https://test.example.com/callback"},
Name: "Test Client",
},
},
PrivateKeyPath: "/tmp/tinyauth_testing_key.pem",
PublicKeyPath: "/tmp/tinyauth_testing_key.pub",
Issuer: "https://tinyauth.example.com",
SessionExpiry: 500,
}
controllerCfg := controller.OIDCControllerConfig{}
simpleCtx := func(c *gin.Context) {
c.Set("context", &config.UserContext{
Username: "test",
Name: "Test User",
Email: "test@example.com",
IsLoggedIn: true,
Provider: "local",
})
c.Next()
}
type testCase struct {
description string
middlewares []gin.HandlerFunc
run func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder)
}
var tests []testCase
getTestByDescription := func(description string) (func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder), bool) {
for _, test := range tests {
if test.description == description {
return test.run, true
}
}
return nil, false
}
tests = []testCase{
{
description: "Ensure we can fetch the client",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/oidc/clients/some-client-id", nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
},
},
{
description: "Ensure API fails on non-existent client ID",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/oidc/clients/non-existent-client-id", nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 404, recorder.Code)
},
},
{
description: "Ensure authorize fails with empty context",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("POST", "/api/oidc/authorize", nil)
router.ServeHTTP(recorder, req)
var res map[string]any
err := json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err)
assert.Equal(t, res["redirect_uri"], "https://tinyauth.example.com/error?error=User+is+not+logged+in+or+the+session+is+invalid")
},
},
{
description: "Ensure authorize fails with an invalid param",
middlewares: []gin.HandlerFunc{
simpleCtx,
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
reqBody := service.AuthorizeRequest{
Scope: "openid",
ResponseType: "some_unsupported_response_type",
ClientID: "some-client-id",
RedirectURI: "https://test.example.com/callback",
State: "some-state",
Nonce: "some-nonce",
}
reqBodyBytes, err := json.Marshal(reqBody)
assert.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes)))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(recorder, req)
var res map[string]any
err = json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err)
assert.Equal(t, res["redirect_uri"], "https://test.example.com/callback?error=unsupported_response_type&error_description=Invalid+request+parameters&state=some-state")
},
},
{
description: "Ensure authorize succeeds with valid params",
middlewares: []gin.HandlerFunc{
simpleCtx,
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
reqBody := service.AuthorizeRequest{
Scope: "openid",
ResponseType: "code",
ClientID: "some-client-id",
RedirectURI: "https://test.example.com/callback",
State: "some-state",
Nonce: "some-nonce",
}
reqBodyBytes, err := json.Marshal(reqBody)
assert.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes)))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
var res map[string]any
err = json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err)
redirectURI := res["redirect_uri"].(string)
url, err := url.Parse(redirectURI)
assert.NoError(t, err)
queryParams := url.Query()
assert.Equal(t, queryParams.Get("state"), "some-state")
code := queryParams.Get("code")
assert.NotEmpty(t, code)
},
},
{
description: "Ensure token request fails with invalid grant",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
reqBody := controller.TokenRequest{
GrantType: "invalid_grant",
Code: "",
RedirectURI: "https://test.example.com/callback",
}
reqBodyBytes, err := json.Marshal(reqBody)
assert.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(string(reqBodyBytes)))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(recorder, req)
var res map[string]any
err = json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err)
assert.Equal(t, res["error"], "unsupported_grant_type")
},
},
{
description: "Ensure token endpoint accepts basic auth",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
reqBody := controller.TokenRequest{
GrantType: "authorization_code",
Code: "some-code",
RedirectURI: "https://test.example.com/callback",
}
reqBodyBytes, err := json.Marshal(reqBody)
assert.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(string(reqBodyBytes)))
req.Header.Set("Content-Type", "application/json")
req.SetBasicAuth("some-client-id", "some-client-secret")
router.ServeHTTP(recorder, req)
assert.Empty(t, recorder.Header().Get("www-authenticate"))
},
},
{
description: "Ensure token endpoint accepts form auth",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
form := url.Values{}
form.Set("grant_type", "authorization_code")
form.Set("code", "some-code")
form.Set("redirect_uri", "https://test.example.com/callback")
form.Set("client_id", "some-client-id")
form.Set("client_secret", "some-client-secret")
req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(form.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
router.ServeHTTP(recorder, req)
assert.Empty(t, recorder.Header().Get("www-authenticate"))
},
},
{
description: "Ensure token endpoint sets authenticate header when no auth is available",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
reqBody := controller.TokenRequest{
GrantType: "authorization_code",
Code: "some-code",
RedirectURI: "https://test.example.com/callback",
}
reqBodyBytes, err := json.Marshal(reqBody)
assert.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(string(reqBodyBytes)))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(recorder, req)
authHeader := recorder.Header().Get("www-authenticate")
assert.Contains(t, authHeader, "Basic")
},
},
{
description: "Ensure we can get a token with a valid request",
middlewares: []gin.HandlerFunc{
simpleCtx,
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
authorizeCodeTest, found := getTestByDescription("Ensure authorize succeeds with valid params")
assert.True(t, found, "Authorize test not found")
authorizeTestRecorder := httptest.NewRecorder()
authorizeCodeTest(t, router, authorizeTestRecorder)
var authorizeRes map[string]any
err := json.Unmarshal(authorizeTestRecorder.Body.Bytes(), &authorizeRes)
assert.NoError(t, err)
redirectURI := authorizeRes["redirect_uri"].(string)
url, err := url.Parse(redirectURI)
assert.NoError(t, err)
queryParams := url.Query()
code := queryParams.Get("code")
assert.NotEmpty(t, code)
reqBody := controller.TokenRequest{
GrantType: "authorization_code",
Code: code,
RedirectURI: "https://test.example.com/callback",
}
reqBodyBytes, err := json.Marshal(reqBody)
assert.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(string(reqBodyBytes)))
req.Header.Set("Content-Type", "application/json")
req.SetBasicAuth("some-client-id", "some-client-secret")
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
},
},
{
description: "Ensure we can renew the access token with the refresh token",
middlewares: []gin.HandlerFunc{
simpleCtx,
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
tokenTest, found := getTestByDescription("Ensure we can get a token with a valid request")
assert.True(t, found, "Token test not found")
tokenRecorder := httptest.NewRecorder()
tokenTest(t, router, tokenRecorder)
var tokenRes map[string]any
err := json.Unmarshal(tokenRecorder.Body.Bytes(), &tokenRes)
assert.NoError(t, err)
_, ok := tokenRes["refresh_token"]
assert.True(t, ok, "Expected refresh token in response")
refreshToken := tokenRes["refresh_token"].(string)
assert.NotEmpty(t, refreshToken)
reqBody := controller.TokenRequest{
GrantType: "refresh_token",
RefreshToken: refreshToken,
ClientID: "some-client-id",
ClientSecret: "some-client-secret",
}
reqBodyBytes, err := json.Marshal(reqBody)
assert.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(string(reqBodyBytes)))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(recorder, req)
assert.NotEmpty(t, recorder.Header().Get("cache-control"))
assert.NotEmpty(t, recorder.Header().Get("pragma"))
assert.Equal(t, 200, recorder.Code)
var refreshRes map[string]any
err = json.Unmarshal(recorder.Body.Bytes(), &refreshRes)
assert.NoError(t, err)
_, ok = refreshRes["access_token"]
assert.True(t, ok, "Expected access token in refresh response")
assert.NotEqual(t, tokenRes["refresh_token"].(string), refreshRes["access_token"].(string))
assert.NotEqual(t, tokenRes["access_token"].(string), refreshRes["access_token"].(string))
},
},
{
description: "Ensure token endpoint deletes code afer use",
middlewares: []gin.HandlerFunc{
simpleCtx,
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
authorizeCodeTest, found := getTestByDescription("Ensure authorize succeeds with valid params")
assert.True(t, found, "Authorize test not found")
authorizeTestRecorder := httptest.NewRecorder()
authorizeCodeTest(t, router, authorizeTestRecorder)
var authorizeRes map[string]any
err := json.Unmarshal(authorizeTestRecorder.Body.Bytes(), &authorizeRes)
assert.NoError(t, err)
redirectURI := authorizeRes["redirect_uri"].(string)
url, err := url.Parse(redirectURI)
assert.NoError(t, err)
queryParams := url.Query()
code := queryParams.Get("code")
assert.NotEmpty(t, code)
reqBody := controller.TokenRequest{
GrantType: "authorization_code",
Code: code,
RedirectURI: "https://test.example.com/callback",
}
reqBodyBytes, err := json.Marshal(reqBody)
assert.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(string(reqBodyBytes)))
req.Header.Set("Content-Type", "application/json")
req.SetBasicAuth("some-client-id", "some-client-secret")
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
// Try to use the same code again
secondReq := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(string(reqBodyBytes)))
secondReq.Header.Set("Content-Type", "application/json")
secondReq.SetBasicAuth("some-client-id", "some-client-secret")
secondRecorder := httptest.NewRecorder()
router.ServeHTTP(secondRecorder, secondReq)
assert.Equal(t, 400, secondRecorder.Code)
var secondRes map[string]any
err = json.Unmarshal(secondRecorder.Body.Bytes(), &secondRes)
assert.NoError(t, err)
assert.Equal(t, secondRes["error"], "invalid_grant")
},
},
{
description: "Ensure userinfo forbids access with invalid access token",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/oidc/userinfo", nil)
req.Header.Set("Authorization", "Bearer invalid-access-token")
router.ServeHTTP(recorder, req)
assert.Equal(t, 401, recorder.Code)
},
},
{
description: "Ensure access token can be used to access protected resources",
middlewares: []gin.HandlerFunc{
simpleCtx,
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
tokenTest, found := getTestByDescription("Ensure we can get a token with a valid request")
assert.True(t, found, "Token test not found")
tokenRecorder := httptest.NewRecorder()
tokenTest(t, router, tokenRecorder)
var tokenRes map[string]any
err := json.Unmarshal(tokenRecorder.Body.Bytes(), &tokenRes)
assert.NoError(t, err)
accessToken := tokenRes["access_token"].(string)
assert.NotEmpty(t, accessToken)
protectedReq := httptest.NewRequest("GET", "/api/oidc/userinfo", nil)
protectedReq.Header.Set("Authorization", "Bearer "+accessToken)
router.ServeHTTP(recorder, protectedReq)
assert.Equal(t, 200, recorder.Code)
var userInfoRes map[string]any
err = json.Unmarshal(recorder.Body.Bytes(), &userInfoRes)
assert.NoError(t, err)
_, ok := userInfoRes["sub"]
assert.True(t, ok, "Expected sub claim in userinfo response")
// We should not have an email claim since we didn't request it in the scope
_, ok = userInfoRes["email"]
assert.False(t, ok, "Did not expect email claim in userinfo response")
},
},
}
// Create an app instance
app := bootstrap.NewBootstrapApp(config.Config{}) app := bootstrap.NewBootstrapApp(config.Config{})
db, err := app.SetupDatabase("/tmp/tinyauth_test.db") // Get db
db, err := app.SetupDatabase("/tmp/tinyauth.db")
if err != nil { assert.NilError(t, err)
t.Fatalf("Failed to set up database: %v", err)
}
// Create queries
queries := repository.New(db) queries := repository.New(db)
oidcService := service.NewOIDCService(oidcServiceCfg, queries)
// Create a new OIDC Servicee
oidcService := service.NewOIDCService(oidcServiceConfig, queries)
err = oidcService.Init() err = oidcService.Init()
assert.NilError(t, err)
if err != nil { // Create test router
t.Fatalf("Failed to initialize OIDC service: %v", err) gin.SetMode(gin.TestMode)
router := gin.Default()
router.Use(func(c *gin.Context) {
c.Set("context", &oidcCtrlTestContext)
c.Next()
})
group := router.Group("/api")
// Register oidc controller
oidcController := controller.NewOIDCController(controller.OIDCControllerConfig{}, oidcService, group)
oidcController.SetupRoutes()
// Get redirect URL test
recorder := httptest.NewRecorder()
marshalled, err := json.Marshal(service.AuthorizeRequest{
Scope: "openid profile email groups",
ResponseType: "code",
ClientID: "some-client-id",
RedirectURI: "https://example.com/oauth/callback",
State: "some-state",
})
assert.NilError(t, err)
req, err := http.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(marshalled)))
assert.NilError(t, err)
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusOK, recorder.Code)
resJson := map[string]any{}
err = json.Unmarshal(recorder.Body.Bytes(), &resJson)
assert.NilError(t, err)
redirect_uri, ok := resJson["redirect_uri"].(string)
assert.Assert(t, ok)
u, err := url.Parse(redirect_uri)
assert.NilError(t, err)
m, err := url.ParseQuery(u.RawQuery)
assert.NilError(t, err)
assert.Equal(t, m["state"][0], "some-state")
code := m["code"][0]
// Exchange code for token
recorder = httptest.NewRecorder()
params, err := query.Values(controller.TokenRequest{
GrantType: "authorization_code",
Code: code,
RedirectURI: "https://example.com/oauth/callback",
})
assert.NilError(t, err)
req, err = http.NewRequest("POST", "/api/oidc/token", strings.NewReader(params.Encode()))
assert.NilError(t, err)
req.Header.Set("content-type", "application/x-www-form-urlencoded")
req.SetBasicAuth("some-client-id", "some-client-secret")
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusOK, recorder.Code)
resJson = map[string]any{}
err = json.Unmarshal(recorder.Body.Bytes(), &resJson)
assert.NilError(t, err)
accessToken, ok := resJson["access_token"].(string)
assert.Assert(t, ok)
_, ok = resJson["id_token"].(string)
assert.Assert(t, ok)
refreshToken, ok := resJson["refresh_token"].(string)
assert.Assert(t, ok)
expires_in, ok := resJson["expires_in"].(float64)
assert.Assert(t, ok)
assert.Equal(t, expires_in, float64(oidcServiceConfig.SessionExpiry))
// Ensure code is expired
recorder = httptest.NewRecorder()
params, err = query.Values(controller.TokenRequest{
GrantType: "authorization_code",
Code: code,
RedirectURI: "https://example.com/oauth/callback",
})
assert.NilError(t, err)
req, err = http.NewRequest("POST", "/api/oidc/token", strings.NewReader(params.Encode()))
assert.NilError(t, err)
req.Header.Set("content-type", "application/x-www-form-urlencoded")
req.SetBasicAuth("some-client-id", "some-client-secret")
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusBadRequest, recorder.Code)
// Test userinfo
recorder = httptest.NewRecorder()
req, err = http.NewRequest("GET", "/api/oidc/userinfo", nil)
assert.NilError(t, err)
req.Header.Set("authorization", fmt.Sprintf("Bearer %s", accessToken))
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusOK, recorder.Code)
resJson = map[string]any{}
err = json.Unmarshal(recorder.Body.Bytes(), &resJson)
assert.NilError(t, err)
_, ok = resJson["sub"].(string)
assert.Assert(t, ok)
name, ok := resJson["name"].(string)
assert.Assert(t, ok)
assert.Equal(t, name, oidcCtrlTestContext.Name)
email, ok := resJson["email"].(string)
assert.Assert(t, ok)
assert.Equal(t, email, oidcCtrlTestContext.Email)
preferred_username, ok := resJson["preferred_username"].(string)
assert.Assert(t, ok)
assert.Equal(t, preferred_username, oidcCtrlTestContext.Username)
// Not sure why this is failing, will look into it later
igroups, ok := resJson["groups"].([]any)
assert.Assert(t, ok)
groups := make([]string, len(igroups))
for i, group := range igroups {
groups[i], ok = group.(string)
assert.Assert(t, ok)
} }
for _, test := range tests { assert.DeepEqual(t, strings.Split(oidcCtrlTestContext.LdapGroups, ","), groups)
t.Run(test.description, func(t *testing.T) {
router := gin.Default()
for _, middleware := range test.middlewares { // Test refresh token
router.Use(middleware) recorder = httptest.NewRecorder()
}
group := router.Group("/api") params, err = query.Values(controller.TokenRequest{
gin.SetMode(gin.TestMode) GrantType: "refresh_token",
RefreshToken: refreshToken,
})
oidcController := controller.NewOIDCController(controllerCfg, oidcService, group) assert.NilError(t, err)
oidcController.SetupRoutes()
recorder := httptest.NewRecorder() req, err = http.NewRequest("POST", "/api/oidc/token", strings.NewReader(params.Encode()))
test.run(t, router, recorder) assert.NilError(t, err)
})
} req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.SetBasicAuth("some-client-id", "some-client-secret")
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusOK, recorder.Code)
resJson = map[string]any{}
err = json.Unmarshal(recorder.Body.Bytes(), &resJson)
assert.NilError(t, err)
newToken, ok := resJson["access_token"].(string)
assert.Assert(t, ok)
assert.Assert(t, newToken != accessToken)
// Ensure old token is invalid
recorder = httptest.NewRecorder()
req, err = http.NewRequest("GET", "/api/oidc/userinfo", nil)
assert.NilError(t, err)
req.Header.Set("authorization", fmt.Sprintf("Bearer %s", accessToken))
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusUnauthorized, recorder.Code)
// Test new token
recorder = httptest.NewRecorder()
req, err = http.NewRequest("GET", "/api/oidc/userinfo", nil)
assert.NilError(t, err)
req.Header.Set("authorization", fmt.Sprintf("Bearer %s", newToken))
router.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusOK, recorder.Code)
} }

View File

@@ -85,7 +85,7 @@ func setupProxyController(t *testing.T, middlewares []gin.HandlerFunc) (*gin.Eng
LoginTimeout: 300, LoginTimeout: 300,
LoginMaxRetries: 3, LoginMaxRetries: 3,
SessionCookieName: "tinyauth-session", SessionCookieName: "tinyauth-session",
}, dockerService, nil, queries, &service.OAuthBrokerService{}) }, dockerService, nil, queries)
// Controller // Controller
ctrl := controller.NewProxyController(controller.ProxyControllerConfig{ ctrl := controller.NewProxyController(controller.ProxyControllerConfig{

View File

@@ -71,7 +71,7 @@ func setupUserController(t *testing.T, middlewares *[]gin.HandlerFunc) (*gin.Eng
LoginTimeout: 300, LoginTimeout: 300,
LoginMaxRetries: 3, LoginMaxRetries: 3,
SessionCookieName: "tinyauth-session", SessionCookieName: "tinyauth-session",
}, nil, nil, queries, &service.OAuthBrokerService{}) }, nil, nil, queries)
// Controller // Controller
ctrl := controller.NewUserController(controller.UserControllerConfig{ ctrl := controller.NewUserController(controller.UserControllerConfig{

View File

@@ -17,21 +17,8 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/google/uuid" "github.com/google/uuid"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
"golang.org/x/exp/slices"
"golang.org/x/oauth2"
) )
const MaxOAuthPendingSessions = 256
const OAuthCleanupCount = 16
type OAuthPendingSession struct {
State string
Verifier string
Token *oauth2.Token
Service *OAuthServiceImpl
ExpiresAt time.Time
}
type LdapGroupsCache struct { type LdapGroupsCache struct {
Groups []string Groups []string
Expires time.Time Expires time.Time
@@ -58,34 +45,28 @@ type AuthServiceConfig struct {
} }
type AuthService struct { type AuthService struct {
config AuthServiceConfig config AuthServiceConfig
docker *DockerService docker *DockerService
loginAttempts map[string]*LoginAttempt loginAttempts map[string]*LoginAttempt
ldapGroupsCache map[string]*LdapGroupsCache ldapGroupsCache map[string]*LdapGroupsCache
oauthPendingSessions map[string]*OAuthPendingSession loginMutex sync.RWMutex
oauthMutex sync.RWMutex ldapGroupsMutex sync.RWMutex
loginMutex sync.RWMutex ldap *LdapService
ldapGroupsMutex sync.RWMutex queries *repository.Queries
ldap *LdapService
queries *repository.Queries
oauthBroker *OAuthBrokerService
} }
func NewAuthService(config AuthServiceConfig, docker *DockerService, ldap *LdapService, queries *repository.Queries, oauthBroker *OAuthBrokerService) *AuthService { func NewAuthService(config AuthServiceConfig, docker *DockerService, ldap *LdapService, queries *repository.Queries) *AuthService {
return &AuthService{ return &AuthService{
config: config, config: config,
docker: docker, docker: docker,
loginAttempts: make(map[string]*LoginAttempt), loginAttempts: make(map[string]*LoginAttempt),
ldapGroupsCache: make(map[string]*LdapGroupsCache), ldapGroupsCache: make(map[string]*LdapGroupsCache),
oauthPendingSessions: make(map[string]*OAuthPendingSession), ldap: ldap,
ldap: ldap, queries: queries,
queries: queries,
oauthBroker: oauthBroker,
} }
} }
func (auth *AuthService) Init() error { func (auth *AuthService) Init() error {
go auth.CleanupOAuthSessionsRoutine()
return nil return nil
} }
@@ -572,177 +553,3 @@ func (auth *AuthService) IsBypassedIP(acls config.AppIP, ip string) bool {
tlog.App.Debug().Str("ip", ip).Msg("IP not in bypass list, continuing with authentication") tlog.App.Debug().Str("ip", ip).Msg("IP not in bypass list, continuing with authentication")
return false return false
} }
func (auth *AuthService) NewOAuthSession(serviceName string) (string, OAuthPendingSession, error) {
auth.ensureOAuthSessionLimit()
service, ok := auth.oauthBroker.GetService(serviceName)
if !ok {
return "", OAuthPendingSession{}, fmt.Errorf("oauth service not found: %s", serviceName)
}
sessionId, err := uuid.NewRandom()
if err != nil {
return "", OAuthPendingSession{}, fmt.Errorf("failed to generate session ID: %w", err)
}
state := service.NewRandom()
verifier := service.NewRandom()
session := OAuthPendingSession{
State: state,
Verifier: verifier,
Service: &service,
ExpiresAt: time.Now().Add(1 * time.Hour),
}
auth.oauthMutex.Lock()
auth.oauthPendingSessions[sessionId.String()] = &session
auth.oauthMutex.Unlock()
return sessionId.String(), session, nil
}
func (auth *AuthService) GetOAuthURL(sessionId string) (string, error) {
session, err := auth.getOAuthPendingSession(sessionId)
if err != nil {
return "", err
}
return (*session.Service).GetAuthURL(session.State, session.Verifier), nil
}
func (auth *AuthService) GetOAuthToken(sessionId string, code string) (*oauth2.Token, error) {
session, err := auth.getOAuthPendingSession(sessionId)
if err != nil {
return nil, err
}
token, err := (*session.Service).GetToken(code, session.Verifier)
if err != nil {
return nil, fmt.Errorf("failed to exchange code for token: %w", err)
}
auth.oauthMutex.Lock()
session.Token = token
auth.oauthMutex.Unlock()
return token, nil
}
func (auth *AuthService) GetOAuthUserinfo(sessionId string) (config.Claims, error) {
session, err := auth.getOAuthPendingSession(sessionId)
if err != nil {
return config.Claims{}, err
}
if session.Token == nil {
return config.Claims{}, fmt.Errorf("oauth token not found for session: %s", sessionId)
}
userinfo, err := (*session.Service).GetUserinfo(session.Token)
if err != nil {
return config.Claims{}, fmt.Errorf("failed to get userinfo: %w", err)
}
return userinfo, nil
}
func (auth *AuthService) GetOAuthService(sessionId string) (OAuthServiceImpl, error) {
session, err := auth.getOAuthPendingSession(sessionId)
if err != nil {
return nil, err
}
return *session.Service, nil
}
func (auth *AuthService) EndOAuthSession(sessionId string) {
auth.oauthMutex.Lock()
delete(auth.oauthPendingSessions, sessionId)
auth.oauthMutex.Unlock()
}
func (auth *AuthService) CleanupOAuthSessionsRoutine() {
ticker := time.NewTicker(30 * time.Minute)
defer ticker.Stop()
for range ticker.C {
auth.oauthMutex.Lock()
now := time.Now()
for sessionId, session := range auth.oauthPendingSessions {
if now.After(session.ExpiresAt) {
delete(auth.oauthPendingSessions, sessionId)
}
}
auth.oauthMutex.Unlock()
}
}
func (auth *AuthService) getOAuthPendingSession(sessionId string) (*OAuthPendingSession, error) {
auth.ensureOAuthSessionLimit()
auth.oauthMutex.RLock()
session, exists := auth.oauthPendingSessions[sessionId]
auth.oauthMutex.RUnlock()
if !exists {
return &OAuthPendingSession{}, fmt.Errorf("oauth session not found: %s", sessionId)
}
if time.Now().After(session.ExpiresAt) {
auth.oauthMutex.Lock()
delete(auth.oauthPendingSessions, sessionId)
auth.oauthMutex.Unlock()
return &OAuthPendingSession{}, fmt.Errorf("oauth session expired: %s", sessionId)
}
return session, nil
}
func (auth *AuthService) ensureOAuthSessionLimit() {
auth.oauthMutex.Lock()
defer auth.oauthMutex.Unlock()
if len(auth.oauthPendingSessions) >= MaxOAuthPendingSessions {
cleanupIds := make([]string, 0, OAuthCleanupCount)
for range OAuthCleanupCount {
oldestId := ""
oldestTime := int64(0)
for id, session := range auth.oauthPendingSessions {
if oldestTime == 0 {
oldestId = id
oldestTime = session.ExpiresAt.Unix()
continue
}
if slices.Contains(cleanupIds, id) {
continue
}
if session.ExpiresAt.Unix() < oldestTime {
oldestId = id
oldestTime = session.ExpiresAt.Unix()
}
}
cleanupIds = append(cleanupIds, oldestId)
}
for _, id := range cleanupIds {
delete(auth.oauthPendingSessions, id)
}
}
}

View File

@@ -0,0 +1,132 @@
package service
import (
"context"
"crypto/rand"
"crypto/tls"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"time"
"github.com/steveiliop56/tinyauth/internal/config"
"github.com/steveiliop56/tinyauth/internal/utils/tlog"
"golang.org/x/oauth2"
)
type GenericOAuthService struct {
config oauth2.Config
context context.Context
token *oauth2.Token
verifier string
insecureSkipVerify bool
userinfoUrl string
name string
}
func NewGenericOAuthService(config config.OAuthServiceConfig) *GenericOAuthService {
return &GenericOAuthService{
config: oauth2.Config{
ClientID: config.ClientID,
ClientSecret: config.ClientSecret,
RedirectURL: config.RedirectURL,
Scopes: config.Scopes,
Endpoint: oauth2.Endpoint{
AuthURL: config.AuthURL,
TokenURL: config.TokenURL,
},
},
insecureSkipVerify: config.Insecure,
userinfoUrl: config.UserinfoURL,
name: config.Name,
}
}
func (generic *GenericOAuthService) Init() error {
transport := &http.Transport{
TLSClientConfig: &tls.Config{
InsecureSkipVerify: generic.insecureSkipVerify,
MinVersion: tls.VersionTLS12,
},
}
httpClient := &http.Client{
Transport: transport,
Timeout: 30 * time.Second,
}
ctx := context.Background()
ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient)
generic.context = ctx
return nil
}
func (generic *GenericOAuthService) GenerateState() string {
b := make([]byte, 128)
_, err := rand.Read(b)
if err != nil {
return base64.RawURLEncoding.EncodeToString(fmt.Appendf(nil, "state-%d", time.Now().UnixNano()))
}
state := base64.RawURLEncoding.EncodeToString(b)
return state
}
func (generic *GenericOAuthService) GenerateVerifier() string {
verifier := oauth2.GenerateVerifier()
generic.verifier = verifier
return verifier
}
func (generic *GenericOAuthService) GetAuthURL(state string) string {
return generic.config.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(generic.verifier))
}
func (generic *GenericOAuthService) VerifyCode(code string) error {
token, err := generic.config.Exchange(generic.context, code, oauth2.VerifierOption(generic.verifier))
if err != nil {
return err
}
generic.token = token
return nil
}
func (generic *GenericOAuthService) Userinfo() (config.Claims, error) {
var user config.Claims
client := generic.config.Client(generic.context, generic.token)
res, err := client.Get(generic.userinfoUrl)
if err != nil {
return user, err
}
defer res.Body.Close()
if res.StatusCode < 200 || res.StatusCode >= 300 {
return user, fmt.Errorf("request failed with status: %s", res.Status)
}
body, err := io.ReadAll(res.Body)
if err != nil {
return user, err
}
tlog.App.Trace().Str("body", string(body)).Msg("Userinfo response body")
err = json.Unmarshal(body, &user)
if err != nil {
return user, err
}
return user, nil
}
func (generic *GenericOAuthService) GetName() string {
return generic.name
}

View File

@@ -0,0 +1,184 @@
package service
import (
"context"
"crypto/rand"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strconv"
"time"
"github.com/steveiliop56/tinyauth/internal/config"
"golang.org/x/oauth2"
"golang.org/x/oauth2/endpoints"
)
var GithubOAuthScopes = []string{"user:email", "read:user"}
type GithubEmailResponse []struct {
Email string `json:"email"`
Primary bool `json:"primary"`
}
type GithubUserInfoResponse struct {
Login string `json:"login"`
Name string `json:"name"`
ID int `json:"id"`
}
type GithubOAuthService struct {
config oauth2.Config
context context.Context
token *oauth2.Token
verifier string
name string
}
func NewGithubOAuthService(config config.OAuthServiceConfig) *GithubOAuthService {
return &GithubOAuthService{
config: oauth2.Config{
ClientID: config.ClientID,
ClientSecret: config.ClientSecret,
RedirectURL: config.RedirectURL,
Scopes: GithubOAuthScopes,
Endpoint: endpoints.GitHub,
},
name: config.Name,
}
}
func (github *GithubOAuthService) Init() error {
httpClient := &http.Client{
Timeout: 30 * time.Second,
}
ctx := context.Background()
ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient)
github.context = ctx
return nil
}
func (github *GithubOAuthService) GenerateState() string {
b := make([]byte, 128)
_, err := rand.Read(b)
if err != nil {
return base64.RawURLEncoding.EncodeToString(fmt.Appendf(nil, "state-%d", time.Now().UnixNano()))
}
state := base64.RawURLEncoding.EncodeToString(b)
return state
}
func (github *GithubOAuthService) GenerateVerifier() string {
verifier := oauth2.GenerateVerifier()
github.verifier = verifier
return verifier
}
func (github *GithubOAuthService) GetAuthURL(state string) string {
return github.config.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(github.verifier))
}
func (github *GithubOAuthService) VerifyCode(code string) error {
token, err := github.config.Exchange(github.context, code, oauth2.VerifierOption(github.verifier))
if err != nil {
return err
}
github.token = token
return nil
}
func (github *GithubOAuthService) Userinfo() (config.Claims, error) {
var user config.Claims
client := github.config.Client(github.context, github.token)
req, err := http.NewRequest("GET", "https://api.github.com/user", nil)
if err != nil {
return user, err
}
req.Header.Set("Accept", "application/vnd.github+json")
res, err := client.Do(req)
if err != nil {
return user, err
}
defer res.Body.Close()
if res.StatusCode < 200 || res.StatusCode >= 300 {
return user, fmt.Errorf("request failed with status: %s", res.Status)
}
body, err := io.ReadAll(res.Body)
if err != nil {
return user, err
}
var userInfo GithubUserInfoResponse
err = json.Unmarshal(body, &userInfo)
if err != nil {
return user, err
}
req, err = http.NewRequest("GET", "https://api.github.com/user/emails", nil)
if err != nil {
return user, err
}
req.Header.Set("Accept", "application/vnd.github+json")
res, err = client.Do(req)
if err != nil {
return user, err
}
defer res.Body.Close()
if res.StatusCode < 200 || res.StatusCode >= 300 {
return user, fmt.Errorf("request failed with status: %s", res.Status)
}
body, err = io.ReadAll(res.Body)
if err != nil {
return user, err
}
var emails GithubEmailResponse
err = json.Unmarshal(body, &emails)
if err != nil {
return user, err
}
for _, email := range emails {
if email.Primary {
user.Email = email.Email
break
}
}
if len(emails) == 0 {
return user, errors.New("no emails found")
}
// Use first available email if no primary email was found
if user.Email == "" {
user.Email = emails[0].Email
}
user.PreferredUsername = userInfo.Login
user.Name = userInfo.Name
user.Sub = strconv.Itoa(userInfo.ID)
return user, nil
}
func (github *GithubOAuthService) GetName() string {
return github.name
}

View File

@@ -0,0 +1,116 @@
package service
import (
"context"
"crypto/rand"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/steveiliop56/tinyauth/internal/config"
"golang.org/x/oauth2"
"golang.org/x/oauth2/endpoints"
)
var GoogleOAuthScopes = []string{"openid", "email", "profile"}
type GoogleOAuthService struct {
config oauth2.Config
context context.Context
token *oauth2.Token
verifier string
name string
}
func NewGoogleOAuthService(config config.OAuthServiceConfig) *GoogleOAuthService {
return &GoogleOAuthService{
config: oauth2.Config{
ClientID: config.ClientID,
ClientSecret: config.ClientSecret,
RedirectURL: config.RedirectURL,
Scopes: GoogleOAuthScopes,
Endpoint: endpoints.Google,
},
name: config.Name,
}
}
func (google *GoogleOAuthService) Init() error {
httpClient := &http.Client{
Timeout: 30 * time.Second,
}
ctx := context.Background()
ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient)
google.context = ctx
return nil
}
func (oauth *GoogleOAuthService) GenerateState() string {
b := make([]byte, 128)
_, err := rand.Read(b)
if err != nil {
return base64.RawURLEncoding.EncodeToString(fmt.Appendf(nil, "state-%d", time.Now().UnixNano()))
}
state := base64.RawURLEncoding.EncodeToString(b)
return state
}
func (google *GoogleOAuthService) GenerateVerifier() string {
verifier := oauth2.GenerateVerifier()
google.verifier = verifier
return verifier
}
func (google *GoogleOAuthService) GetAuthURL(state string) string {
return google.config.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(google.verifier))
}
func (google *GoogleOAuthService) VerifyCode(code string) error {
token, err := google.config.Exchange(google.context, code, oauth2.VerifierOption(google.verifier))
if err != nil {
return err
}
google.token = token
return nil
}
func (google *GoogleOAuthService) Userinfo() (config.Claims, error) {
var user config.Claims
client := google.config.Client(google.context, google.token)
res, err := client.Get("https://openidconnect.googleapis.com/v1/userinfo")
if err != nil {
return config.Claims{}, err
}
defer res.Body.Close()
if res.StatusCode < 200 || res.StatusCode >= 300 {
return user, fmt.Errorf("request failed with status: %s", res.Status)
}
body, err := io.ReadAll(res.Body)
if err != nil {
return config.Claims{}, err
}
err = json.Unmarshal(body, &user)
if err != nil {
return config.Claims{}, err
}
user.PreferredUsername = strings.SplitN(user.Email, "@", 2)[0]
return user, nil
}
func (google *GoogleOAuthService) GetName() string {
return google.name
}

View File

@@ -1,48 +1,60 @@
package service package service
import ( import (
"errors"
"github.com/steveiliop56/tinyauth/internal/config" "github.com/steveiliop56/tinyauth/internal/config"
"github.com/steveiliop56/tinyauth/internal/utils/tlog" "github.com/steveiliop56/tinyauth/internal/utils/tlog"
"golang.org/x/exp/slices" "golang.org/x/exp/slices"
"golang.org/x/oauth2"
) )
type OAuthServiceImpl interface { type OAuthService interface {
Name() string Init() error
NewRandom() string GenerateState() string
GetAuthURL(state string, verifier string) string GenerateVerifier() string
GetToken(code string, verifier string) (*oauth2.Token, error) GetAuthURL(state string) string
GetUserinfo(token *oauth2.Token) (config.Claims, error) VerifyCode(code string) error
Userinfo() (config.Claims, error)
GetName() string
} }
type OAuthBrokerService struct { type OAuthBrokerService struct {
services map[string]OAuthServiceImpl services map[string]OAuthService
configs map[string]config.OAuthServiceConfig configs map[string]config.OAuthServiceConfig
} }
var presets = map[string]func(config config.OAuthServiceConfig) *OAuthService{
"github": newGitHubOAuthService,
"google": newGoogleOAuthService,
}
func NewOAuthBrokerService(configs map[string]config.OAuthServiceConfig) *OAuthBrokerService { func NewOAuthBrokerService(configs map[string]config.OAuthServiceConfig) *OAuthBrokerService {
return &OAuthBrokerService{ return &OAuthBrokerService{
services: make(map[string]OAuthServiceImpl), services: make(map[string]OAuthService),
configs: configs, configs: configs,
} }
} }
func (broker *OAuthBrokerService) Init() error { func (broker *OAuthBrokerService) Init() error {
for name, cfg := range broker.configs { for name, cfg := range broker.configs {
if presetFunc, exists := presets[name]; exists { switch name {
broker.services[name] = presetFunc(cfg) case "github":
tlog.App.Debug().Str("service", name).Msg("Loaded OAuth service from preset") service := NewGithubOAuthService(cfg)
} else { broker.services[name] = service
broker.services[name] = NewOAuthService(cfg) case "google":
tlog.App.Debug().Str("service", name).Msg("Loaded OAuth service from config") service := NewGoogleOAuthService(cfg)
broker.services[name] = service
default:
service := NewGenericOAuthService(cfg)
broker.services[name] = service
} }
} }
for name, service := range broker.services {
err := service.Init()
if err != nil {
tlog.App.Error().Err(err).Msgf("Failed to initialize OAuth service: %s", name)
return err
}
tlog.App.Info().Str("service", name).Msg("Initialized OAuth service")
}
return nil return nil
} }
@@ -55,7 +67,15 @@ func (broker *OAuthBrokerService) GetConfiguredServices() []string {
return services return services
} }
func (broker *OAuthBrokerService) GetService(name string) (OAuthServiceImpl, bool) { func (broker *OAuthBrokerService) GetService(name string) (OAuthService, bool) {
service, exists := broker.services[name] service, exists := broker.services[name]
return service, exists return service, exists
} }
func (broker *OAuthBrokerService) GetUser(service string) (config.Claims, error) {
oauthService, exists := broker.services[service]
if !exists {
return config.Claims{}, errors.New("oauth service not found")
}
return oauthService.Userinfo()
}

View File

@@ -1,102 +0,0 @@
package service
import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strconv"
"github.com/steveiliop56/tinyauth/internal/config"
)
type GithubEmailResponse []struct {
Email string `json:"email"`
Primary bool `json:"primary"`
}
type GithubUserInfoResponse struct {
Login string `json:"login"`
Name string `json:"name"`
ID int `json:"id"`
}
func defaultExtractor(client *http.Client, url string) (config.Claims, error) {
return simpleReq[config.Claims](client, url, nil)
}
func githubExtractor(client *http.Client, url string) (config.Claims, error) {
var user config.Claims
userInfo, err := simpleReq[GithubUserInfoResponse](client, "https://api.github.com/user", map[string]string{
"accept": "application/vnd.github+json",
})
if err != nil {
return config.Claims{}, err
}
userEmails, err := simpleReq[GithubEmailResponse](client, "https://api.github.com/user/emails", map[string]string{
"accept": "application/vnd.github+json",
})
if err != nil {
return config.Claims{}, err
}
if len(userEmails) == 0 {
return user, errors.New("no emails found")
}
for _, email := range userEmails {
if email.Primary {
user.Email = email.Email
break
}
}
// Use first available email if no primary email was found
if user.Email == "" {
user.Email = userEmails[0].Email
}
user.PreferredUsername = userInfo.Login
user.Name = userInfo.Name
user.Sub = strconv.Itoa(userInfo.ID)
return user, nil
}
func simpleReq[T any](client *http.Client, url string, headers map[string]string) (T, error) {
var decodedRes T
req, err := http.NewRequest("GET", url, nil)
if err != nil {
return decodedRes, err
}
for key, value := range headers {
req.Header.Add(key, value)
}
res, err := client.Do(req)
if err != nil {
return decodedRes, err
}
defer res.Body.Close()
if res.StatusCode < 200 || res.StatusCode >= 300 {
return decodedRes, fmt.Errorf("request failed with status: %s", res.Status)
}
body, err := io.ReadAll(res.Body)
if err != nil {
return decodedRes, err
}
err = json.Unmarshal(body, &decodedRes)
if err != nil {
return decodedRes, err
}
return decodedRes, nil
}

View File

@@ -1,23 +0,0 @@
package service
import (
"github.com/steveiliop56/tinyauth/internal/config"
"golang.org/x/oauth2/endpoints"
)
func newGoogleOAuthService(config config.OAuthServiceConfig) *OAuthService {
scopes := []string{"openid", "email", "profile"}
config.Scopes = scopes
config.AuthURL = endpoints.Google.AuthURL
config.TokenURL = endpoints.Google.TokenURL
config.UserinfoURL = "https://openidconnect.googleapis.com/v1/userinfo"
return NewOAuthService(config)
}
func newGitHubOAuthService(config config.OAuthServiceConfig) *OAuthService {
scopes := []string{"read:user", "user:email"}
config.Scopes = scopes
config.AuthURL = endpoints.GitHub.AuthURL
config.TokenURL = endpoints.GitHub.TokenURL
return NewOAuthService(config).WithUserinfoExtractor(githubExtractor)
}

View File

@@ -1,78 +0,0 @@
package service
import (
"context"
"crypto/tls"
"net/http"
"time"
"github.com/steveiliop56/tinyauth/internal/config"
"golang.org/x/oauth2"
)
type UserinfoExtractor func(client *http.Client, url string) (config.Claims, error)
type OAuthService struct {
serviceCfg config.OAuthServiceConfig
config *oauth2.Config
ctx context.Context
userinfoExtractor UserinfoExtractor
}
func NewOAuthService(config config.OAuthServiceConfig) *OAuthService {
httpClient := &http.Client{
Timeout: 30 * time.Second,
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
InsecureSkipVerify: config.Insecure,
},
},
}
ctx := context.Background()
ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient)
return &OAuthService{
serviceCfg: config,
config: &oauth2.Config{
ClientID: config.ClientID,
ClientSecret: config.ClientSecret,
RedirectURL: config.RedirectURL,
Scopes: config.Scopes,
Endpoint: oauth2.Endpoint{
AuthURL: config.AuthURL,
TokenURL: config.TokenURL,
},
},
ctx: ctx,
userinfoExtractor: defaultExtractor,
}
}
func (s *OAuthService) WithUserinfoExtractor(extractor UserinfoExtractor) *OAuthService {
s.userinfoExtractor = extractor
return s
}
func (s *OAuthService) Name() string {
return s.serviceCfg.Name
}
func (s *OAuthService) NewRandom() string {
// The generate verifier function just creates a random string,
// so we can use it to generate a random state as well
random := oauth2.GenerateVerifier()
return random
}
func (s *OAuthService) GetAuthURL(state string, verifier string) string {
return s.config.AuthCodeURL(state, oauth2.AccessTypeOnline, oauth2.S256ChallengeOption(verifier))
}
func (s *OAuthService) GetToken(code string, verifier string) (*oauth2.Token, error) {
return s.config.Exchange(s.ctx, code, oauth2.VerifierOption(verifier))
}
func (s *OAuthService) GetUserinfo(token *oauth2.Token) (config.Claims, error) {
client := oauth2.NewClient(s.ctx, oauth2.StaticTokenSource(token))
return s.userinfoExtractor(client, s.serviceCfg.UserinfoURL)
}