Compare commits

..

2 Commits

Author SHA1 Message Date
Stavros a7f5374acc refactor: use one struct for service deps 2026-06-13 17:14:47 +03:00
Stavros a0e74cd5f2 refactor: move oidc handling to backend and add support for oidc post (#923)
Co-authored-by: Claude <noreply@anthropic.com>
2026-06-13 16:45:12 +03:00
50 changed files with 289 additions and 945 deletions
-2
View File
@@ -206,8 +206,6 @@ TINYAUTH_LDAP_ADDRESS=
TINYAUTH_LDAP_BINDDN= TINYAUTH_LDAP_BINDDN=
# Bind password for LDAP authentication. # Bind password for LDAP authentication.
TINYAUTH_LDAP_BINDPASSWORD= TINYAUTH_LDAP_BINDPASSWORD=
# Path to the Bind password.
TINYAUTH_LDAP_BINDPASSWORDFILE=
# Base DN for LDAP searches. # Base DN for LDAP searches.
TINYAUTH_LDAP_BASEDN= TINYAUTH_LDAP_BASEDN=
# Allow insecure LDAP connections. # Allow insecure LDAP connections.
+1 -1
View File
@@ -15,7 +15,7 @@ export const useRedirectUri = (
let isAllowedProto = false; let isAllowedProto = false;
let isHttpsDowngrade = false; let isHttpsDowngrade = false;
if (redirect_uri === undefined) { if (!redirect_uri) {
return { return {
valid: isValid, valid: isValid,
trusted: isTrusted, trusted: isTrusted,
+1 -5
View File
@@ -110,11 +110,7 @@ export const AuthorizePage = () => {
}, },
}); });
if ( if (!isOidc || !screenParams.oidc_ticket || !screenParams.oidc_scope) {
!isOidc ||
screenParams.oidc_ticket === undefined ||
screenParams.oidc_scope === undefined
) {
return ( return (
<Navigate <Navigate
to={`/error?error=${encodeURIComponent(t("authorizeErrorInvalidParams"))}`} to={`/error?error=${encodeURIComponent(t("authorizeErrorInvalidParams"))}`}
+1 -1
View File
@@ -11,7 +11,7 @@ export const ErrorPage = () => {
const { t } = useTranslation(); const { t } = useTranslation();
const { search } = useLocation(); const { search } = useLocation();
const searchParams = new URLSearchParams(search); const searchParams = new URLSearchParams(search);
const error = searchParams.get("error") ?? ""; const error = searchParams.get("error") || "";
return ( return (
<Card> <Card>
+3 -1
View File
@@ -168,7 +168,8 @@ export const LoginPage = () => {
!auth.authenticated && !auth.authenticated &&
isOauthAutoRedirect && isOauthAutoRedirect &&
!hasAutoRedirectedRef.current && !hasAutoRedirectedRef.current &&
screenParams.login_for !== undefined screenParams.redirect_uri &&
screenParams.login_for
) { ) {
hasAutoRedirectedRef.current = true; hasAutoRedirectedRef.current = true;
oauthMutate(oauth.autoRedirect); oauthMutate(oauth.autoRedirect);
@@ -180,6 +181,7 @@ export const LoginPage = () => {
oauth.autoRedirect, oauth.autoRedirect,
isOauthAutoRedirect, isOauthAutoRedirect,
screenParams.login_for, screenParams.login_for,
screenParams.redirect_uri,
]); ]);
useEffect(() => { useEffect(() => {
+21 -30
View File
@@ -67,24 +67,15 @@ func run() error {
Overlay: map[string][]byte{outPath: stub}, Overlay: map[string][]byte{outPath: stub},
} }
repoPkgPath := parentPkg(*driverPkg) driverTypePkg, err := loadOnePkg(cfg, *driverPkg)
pkgs, err := loadMultiplePkgs(cfg, *driverPkg, repoPkgPath)
if err != nil { if err != nil {
return fmt.Errorf("load packages: %w", err) return fmt.Errorf("load driver package: %w", err)
} }
driverTypePkg, ok := pkgs[*driverPkg] repoPkgPath := parentPkg(*driverPkg)
repoTypePkg, err := loadOnePkg(cfg, repoPkgPath)
if !ok { if err != nil {
return fmt.Errorf("driver package %s not found in loaded packages", *driverPkg) return fmt.Errorf("load repo package: %w", err)
}
repoTypePkg, ok := pkgs[repoPkgPath]
if !ok {
return fmt.Errorf("repository package %s not found in loaded packages", repoPkgPath)
} }
if err := validateStructShapes(driverTypePkg, repoTypePkg); err != nil { if err := validateStructShapes(driverTypePkg, repoTypePkg); err != nil {
@@ -115,25 +106,25 @@ func run() error {
return nil return nil
} }
// loadMultiplePkgs loads multiple packages via cfg and returns a map of import path → *types.Package, // loadOnePkg loads a single package via cfg and returns its *types.Package,
// or an error if any package fails to load or has type errors. // or an error if the package fails to load or has type errors.
func loadMultiplePkgs(cfg *packages.Config, importPaths ...string) (map[string]*types.Package, error) { func loadOnePkg(cfg *packages.Config, importPath string) (*types.Package, error) {
pkgs, err := packages.Load(cfg, importPaths...) pkgs, err := packages.Load(cfg, importPath)
if err != nil { if err != nil {
return nil, fmt.Errorf("load %v: %w", importPaths, err) return nil, fmt.Errorf("load %s: %w", importPath, err)
} }
out := make(map[string]*types.Package) if len(pkgs) != 1 {
for _, pkg := range pkgs { return nil, fmt.Errorf("expected 1 package for %s, got %d", importPath, len(pkgs))
if len(pkg.Errors) > 0 { }
msgs := make([]string, len(pkg.Errors)) pkg := pkgs[0]
for i, e := range pkg.Errors { if len(pkg.Errors) > 0 {
msgs[i] = e.Error() msgs := make([]string, len(pkg.Errors))
} for i, e := range pkg.Errors {
return nil, fmt.Errorf("package %s has errors:\n %s", pkg.PkgPath, strings.Join(msgs, "\n ")) msgs[i] = e.Error()
} }
out[pkg.PkgPath] = pkg.Types return nil, fmt.Errorf("package %s has errors:\n %s", importPath, strings.Join(msgs, "\n "))
} }
return out, nil return pkg.Types, nil
} }
// parentPkg returns the parent import path (everything before the last /). // parentPkg returns the parent import path (everything before the last /).
@@ -1,7 +0,0 @@
CREATE TABLE IF NOT EXISTS "oidc_consent" (
"uuid" TEXT NOT NULL UNIQUE PRIMARY KEY,
"client_id" TEXT NOT NULL,
"scopes" TEXT NOT NULL,
"created_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updated_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
);
@@ -1 +0,0 @@
DROP TABLE IF EXISTS "oidc_consent";
@@ -1 +0,0 @@
DROP TABLE IF EXISTS "oidc_consent";
@@ -1,7 +0,0 @@
CREATE TABLE IF NOT EXISTS "oidc_consent" (
"uuid" TEXT NOT NULL UNIQUE PRIMARY KEY,
"client_id" TEXT NOT NULL,
"scopes" TEXT NOT NULL,
"created_at" DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updated_at" DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
);
+10 -23
View File
@@ -31,24 +31,10 @@ import (
// 2. HTTP server listeners - ding.RingNormal // 2. HTTP server listeners - ding.RingNormal
// 3. Networking layers, user and label providers (e.g. ailscale service, kubernetes service) - ding.RingMajor // 3. Networking layers, user and label providers (e.g. ailscale service, kubernetes service) - ding.RingMajor
// 4. Database connection - ding.RingCritical // 4. Database connection - ding.RingCritical
type Services struct {
accessControlService *service.AccessControlsService
authService *service.AuthService
dockerService *service.DockerService
kubernetesService *service.KubernetesService
ldapService *service.LdapService
oauthBrokerService *service.OAuthBrokerService
oidcService *service.OIDCService
tailscaleService *service.TailscaleService
policyEngine *service.PolicyEngine
}
type BootstrapApp struct { type BootstrapApp struct {
config model.Config config model.Config
runtime model.RuntimeConfig runtime model.RuntimeConfig
helpers model.RuntimeHelpers services service.Services
services Services
log *logger.Logger log *logger.Logger
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
@@ -57,6 +43,9 @@ type BootstrapApp struct {
db *sql.DB db *sql.DB
ding *ding.Ding ding *ding.Ding
listeners []Listener listeners []Listener
deps struct {
service *service.ServiceDependencies
}
} }
func NewBootstrapApp(config model.Config) *BootstrapApp { func NewBootstrapApp(config model.Config) *BootstrapApp {
@@ -186,8 +175,9 @@ func (app *BootstrapApp) Setup() error {
cookieId := strings.Split(app.runtime.UUID, "-")[0] // first 8 characters of the uuid should be good enough cookieId := strings.Split(app.runtime.UUID, "-")[0] // first 8 characters of the uuid should be good enough
app.runtime.SessionCookieName = fmt.Sprintf("%s-%s", model.SessionCookieName, cookieId) app.runtime.SessionCookieName = fmt.Sprintf("%s-%s", model.SessionCookieName, cookieId)
app.runtime.CSRFCookieName = fmt.Sprintf("%s-%s", model.CSRFCookieName, cookieId)
app.runtime.RedirectCookieName = fmt.Sprintf("%s-%s", model.RedirectCookieName, cookieId)
app.runtime.OAuthSessionCookieName = fmt.Sprintf("%s-%s", model.OAuthSessionCookieName, cookieId) app.runtime.OAuthSessionCookieName = fmt.Sprintf("%s-%s", model.OAuthSessionCookieName, cookieId)
app.runtime.ConsentCookieName = fmt.Sprintf("%s-%s", model.ConsentCookieName, cookieId)
// database // database
store, err := app.SetupStore() store, err := app.SetupStore()
@@ -233,7 +223,7 @@ func (app *BootstrapApp) Setup() error {
return configuredProviders[i].Name < configuredProviders[j].Name return configuredProviders[i].Name < configuredProviders[j].Name
}) })
if app.services.authService.LocalAuthConfigured() { if app.services.AuthService.LocalAuthConfigured() {
configuredProviders = append(configuredProviders, model.Provider{ configuredProviders = append(configuredProviders, model.Provider{
Name: "Local", Name: "Local",
ID: "local", ID: "local",
@@ -241,7 +231,7 @@ func (app *BootstrapApp) Setup() error {
}) })
} }
if app.services.authService.LDAPAuthConfigured() { if app.services.AuthService.LDAPAuthConfigured() {
configuredProviders = append(configuredProviders, model.Provider{ configuredProviders = append(configuredProviders, model.Provider{
Name: "LDAP", Name: "LDAP",
ID: "ldap", ID: "ldap",
@@ -260,13 +250,10 @@ func (app *BootstrapApp) Setup() error {
app.runtime.ConfiguredProviders = configuredProviders app.runtime.ConfiguredProviders = configuredProviders
// throw in tailscale if it's configured just before setting up the controllers // throw in tailscale if it's configured just before setting up the controllers
if app.services.tailscaleService != nil { if app.services.TailscaleService != nil {
app.runtime.TrustedDomains = append(app.runtime.TrustedDomains, "https://"+app.services.tailscaleService.GetHostname()) app.runtime.TrustedDomains = append(app.runtime.TrustedDomains, "https://"+app.services.TailscaleService.GetHostname())
} }
// runtime helpers
app.helpers.GetCookieDomain = app.getCookieDomain
// setup router // setup router
err = app.setupRouter() err = app.setupRouter()
-55
View File
@@ -1,55 +0,0 @@
package bootstrap
import (
"context"
"errors"
"fmt"
"github.com/tinyauthapp/tinyauth/internal/utils"
)
// Not really the best place for the helpers to be but it works because bootstrap app provides
// them with everything they need
func (app *BootstrapApp) getCookieDomain(ctx context.Context, ip string) (string, error) {
cookieDomain := app.runtime.CookieDomain
if app.isTailscaleRequest(ctx, ip) {
if app.services.tailscaleService == nil {
return "", errors.New("tailscale service is not configured")
}
tsCookieDomain, err := utils.GetCookieDomain(fmt.Sprintf("https://%s", app.services.tailscaleService.GetHostname()))
if err != nil {
return "", fmt.Errorf("failed to get cookie domain for tailscale user: %w", err)
}
cookieDomain = tsCookieDomain
}
if app.config.Auth.SubdomainsEnabled {
cookieDomain = "." + cookieDomain
}
return cookieDomain, nil
}
func (app *BootstrapApp) isTailscaleRequest(ctx context.Context, ip string) bool {
if app.services.tailscaleService == nil {
return false
}
whois, err := app.services.tailscaleService.Whois(ctx, ip)
if err != nil {
app.log.App.Error().Err(err).Msgf("Error performing Tailscale whois for IP %s: %v", ip, err)
return false
}
if whois == nil {
return false
}
return true
}
+10 -10
View File
@@ -40,7 +40,7 @@ func (app *BootstrapApp) setupRouter() error {
} }
} }
contextMiddleware := middleware.NewContextMiddleware(app.log, app.runtime, app.services.authService, app.services.oauthBrokerService, app.services.tailscaleService) contextMiddleware := middleware.NewContextMiddleware(app.log, app.runtime, app.services.AuthService, app.services.OAuthBrokerService, app.services.TailscaleService)
engine.Use(contextMiddleware.Middleware()) engine.Use(contextMiddleware.Middleware())
uiMiddleware, err := middleware.NewUIMiddleware() uiMiddleware, err := middleware.NewUIMiddleware()
@@ -58,13 +58,13 @@ func (app *BootstrapApp) setupRouter() error {
apiRouter := engine.Group("/api") apiRouter := engine.Group("/api")
controller.NewContextController(app.log, app.config, app.runtime, apiRouter) controller.NewContextController(app.log, app.config, app.runtime, apiRouter)
controller.NewOAuthController(app.log, app.config, app.runtime, app.helpers, apiRouter, app.services.authService) controller.NewOAuthController(app.log, app.config, app.runtime, apiRouter, app.services.AuthService)
controller.NewOIDCController(app.log, app.services.oidcService, app.runtime, app.helpers, app.config, apiRouter, &engine.RouterGroup) controller.NewOIDCController(app.log, app.services.OIDCService, app.runtime, apiRouter, &engine.RouterGroup)
controller.NewProxyController(app.log, app.runtime, apiRouter, app.services.accessControlService, app.services.authService, app.services.policyEngine) controller.NewProxyController(app.log, app.runtime, apiRouter, app.services.AccessControlService, app.services.AuthService, app.services.PolicyEngine)
controller.NewUserController(app.log, app.runtime, apiRouter, app.services.authService) controller.NewUserController(app.log, app.runtime, apiRouter, app.services.AuthService)
controller.NewResourcesController(app.config, &engine.RouterGroup) controller.NewResourcesController(app.config, &engine.RouterGroup)
controller.NewHealthController(apiRouter) controller.NewHealthController(apiRouter)
controller.NewWellKnownController(app.services.oidcService, &engine.RouterGroup) controller.NewWellKnownController(app.services.OIDCService, &engine.RouterGroup)
app.router = engine app.router = engine
return nil return nil
@@ -99,7 +99,7 @@ func (app *BootstrapApp) calculateListenerPolicy() []Listener {
l := []Listener{} l := []Listener{}
if !app.config.Server.ConcurrentListenersEnabled { if !app.config.Server.ConcurrentListenersEnabled {
if app.services.tailscaleService != nil { if app.services.TailscaleService != nil {
l = append(l, ListenerTailscale) l = append(l, ListenerTailscale)
return l return l
} }
@@ -117,7 +117,7 @@ func (app *BootstrapApp) calculateListenerPolicy() []Listener {
l = append(l, ListenerUnix) l = append(l, ListenerUnix)
} }
if app.services.tailscaleService != nil { if app.services.TailscaleService != nil {
l = append(l, ListenerTailscale) l = append(l, ListenerTailscale)
} }
@@ -186,9 +186,9 @@ func (app *BootstrapApp) serveUnix(ctx context.Context) error {
} }
func (app *BootstrapApp) serveTailscale(ctx context.Context) error { func (app *BootstrapApp) serveTailscale(ctx context.Context) error {
app.log.App.Info().Msgf("Starting Tailscale server on %s", fmt.Sprintf("https://%s", app.services.tailscaleService.GetHostname())) app.log.App.Info().Msgf("Starting Tailscale server on %s", fmt.Sprintf("https://%s", app.services.TailscaleService.GetHostname()))
listener, err := app.services.tailscaleService.CreateListener() listener, err := app.services.TailscaleService.CreateListener()
if err != nil { if err != nil {
return fmt.Errorf("failed to create tailscale listener: %w", err) return fmt.Errorf("failed to create tailscale listener: %w", err)
+30 -18
View File
@@ -8,13 +8,23 @@ import (
) )
func (app *BootstrapApp) setupServices() error { func (app *BootstrapApp) setupServices() error {
ldapService, err := service.NewLdapService(app.log, app.config, app.ding) app.deps.service = &service.ServiceDependencies{
Log: app.log,
StaticConfig: &app.config,
RuntimeConfig: &app.runtime,
Ctx: app.ctx,
Ding: app.ding,
Services: &app.services,
Queries: &app.queries,
}
ldap, err := service.NewLdapService(app.deps.service)
if err != nil { if err != nil {
app.log.App.Warn().Err(err).Msg("Failed to initialize LDAP connection, will continue without it") app.log.App.Warn().Err(err).Msg("Failed to initialize LDAP connection, will continue without it")
} }
app.services.ldapService = ldapService app.services.LDAPService = ldap
labelProvider, err := app.getLabelProvider() labelProvider, err := app.getLabelProvider()
@@ -22,16 +32,18 @@ func (app *BootstrapApp) setupServices() error {
return fmt.Errorf("failed to initialize label provider: %w", err) return fmt.Errorf("failed to initialize label provider: %w", err)
} }
tailscaleService, err := service.NewTailscaleService(app.log, app.config, app.ctx, app.ding) app.deps.service.LabelProvider = labelProvider
tailscaleService, err := service.NewTailscaleService(app.deps.service)
if err != nil { if err != nil {
app.log.App.Warn().Err(err).Msg("Failed to initialize Tailscale connection, will continue without it") app.log.App.Warn().Err(err).Msg("Failed to initialize Tailscale connection, will continue without it")
} }
app.services.tailscaleService = tailscaleService app.services.TailscaleService = tailscaleService
accessControlsService := service.NewAccessControlsService(app.log, app.config, &labelProvider) accessControlsService := service.NewAccessControlsService(app.deps.service)
app.services.accessControlService = accessControlsService app.services.AccessControlService = accessControlsService
err = app.setupPolicyEngine() err = app.setupPolicyEngine()
@@ -39,19 +51,19 @@ func (app *BootstrapApp) setupServices() error {
return fmt.Errorf("failed to initialize policy engine: %w", err) return fmt.Errorf("failed to initialize policy engine: %w", err)
} }
oauthBrokerService := service.NewOAuthBrokerService(app.log, app.runtime.OAuthProviders, app.ctx) oauthBrokerService := service.NewOAuthBrokerService(app.deps.service)
app.services.oauthBrokerService = oauthBrokerService app.services.OAuthBrokerService = oauthBrokerService
authService := service.NewAuthService(app.log, app.config, app.runtime, app.helpers, app.ctx, app.ding, app.services.ldapService, app.queries, app.services.oauthBrokerService, app.services.tailscaleService, app.services.policyEngine) authService := service.NewAuthService(app.deps.service)
app.services.authService = authService app.services.AuthService = authService
oidcService, err := service.NewOIDCService(app.log, app.config, app.runtime, app.queries, app.ding) oidcService, err := service.NewOIDCService(app.deps.service)
if err != nil { if err != nil {
return fmt.Errorf("failed to initialize oidc service: %w", err) return fmt.Errorf("failed to initialize oidc service: %w", err)
} }
app.services.oidcService = oidcService app.services.OIDCService = oidcService
return nil return nil
} }
@@ -69,19 +81,19 @@ func (app *BootstrapApp) getLabelProvider() (service.LabelProvider, error) {
if useKubernetes { if useKubernetes {
app.log.App.Debug().Msg("Using Kubernetes label provider") app.log.App.Debug().Msg("Using Kubernetes label provider")
kubernetesService, err := service.NewKubernetesService(app.log, app.ctx, app.ding) kubernetesService, err := service.NewKubernetesService(app.deps.service)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to initialize kubernetes service: %w", err) return nil, fmt.Errorf("failed to initialize kubernetes service: %w", err)
} }
app.services.kubernetesService = kubernetesService app.services.KubernetesService = kubernetesService
return kubernetesService, nil return kubernetesService, nil
} }
app.log.App.Debug().Msg("Using Docker label provider") app.log.App.Debug().Msg("Using Docker label provider")
dockerService, err := service.NewDockerService(app.log, app.ctx, app.ding) dockerService, err := service.NewDockerService(app.deps.service)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to initialize docker service: %w", err) return nil, fmt.Errorf("failed to initialize docker service: %w", err)
@@ -94,7 +106,7 @@ func (app *BootstrapApp) getLabelProvider() (service.LabelProvider, error) {
return nil, nil return nil, nil
} }
app.services.dockerService = dockerService app.services.DockerService = dockerService
return dockerService, nil return dockerService, nil
default: default:
return nil, fmt.Errorf("invalid label provider: %s", app.config.LabelProvider) return nil, fmt.Errorf("invalid label provider: %s", app.config.LabelProvider)
@@ -102,7 +114,7 @@ func (app *BootstrapApp) getLabelProvider() (service.LabelProvider, error) {
} }
func (app *BootstrapApp) setupPolicyEngine() error { func (app *BootstrapApp) setupPolicyEngine() error {
policyEngine, err := service.NewPolicyEngine(app.config, app.log) policyEngine, err := service.NewPolicyEngine(app.deps.service)
if err != nil { if err != nil {
return fmt.Errorf("failed to initialize policy engine: %w", err) return fmt.Errorf("failed to initialize policy engine: %w", err)
@@ -129,6 +141,6 @@ func (app *BootstrapApp) setupPolicyEngine() error {
Config: app.config, Config: app.config,
}) })
app.services.policyEngine = policyEngine app.services.PolicyEngine = policyEngine
return nil return nil
} }
+10 -25
View File
@@ -24,7 +24,6 @@ type OAuthController struct {
log *logger.Logger log *logger.Logger
config model.Config config model.Config
runtime model.RuntimeConfig runtime model.RuntimeConfig
helpers model.RuntimeHelpers
auth *service.AuthService auth *service.AuthService
} }
@@ -32,7 +31,6 @@ func NewOAuthController(
log *logger.Logger, log *logger.Logger,
config model.Config, config model.Config,
runtimeConfig model.RuntimeConfig, runtimeConfig model.RuntimeConfig,
helpers model.RuntimeHelpers,
router *gin.RouterGroup, router *gin.RouterGroup,
auth *service.AuthService, auth *service.AuthService,
) *OAuthController { ) *OAuthController {
@@ -40,7 +38,6 @@ func NewOAuthController(
log: log, log: log,
config: config, config: config,
runtime: runtimeConfig, runtime: runtimeConfig,
helpers: helpers,
auth: auth, auth: auth,
} }
@@ -108,18 +105,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
return return
} }
cookieDomain, err := controller.helpers.GetCookieDomain(c, c.RemoteIP()) c.SetCookie(controller.runtime.OAuthSessionCookieName, sessionId, int(time.Hour.Seconds()), "/", controller.getCookieDomain(), controller.config.Auth.SecureCookie, true)
if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to determine cookie domain")
c.JSON(500, gin.H{
"status": 500,
"message": "Internal Server Error",
})
return
}
c.SetCookie(controller.runtime.OAuthSessionCookieName, sessionId, int(time.Hour.Seconds()), "/", cookieDomain, controller.config.Auth.SecureCookie, true)
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"status": 200, "status": 200,
@@ -149,15 +135,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
return return
} }
cookieDomain, err := controller.helpers.GetCookieDomain(c, c.RemoteIP()) c.SetCookie(controller.runtime.OAuthSessionCookieName, "", -1, "/", controller.getCookieDomain(), controller.config.Auth.SecureCookie, true)
if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to determine cookie domain")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
return
}
c.SetCookie(controller.runtime.OAuthSessionCookieName, "", -1, "/", cookieDomain, controller.config.Auth.SecureCookie, true)
oauthPendingSession, err := controller.auth.GetOAuthPendingSession(sessionIdCookie) oauthPendingSession, err := controller.auth.GetOAuthPendingSession(sessionIdCookie)
@@ -274,7 +252,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
controller.log.App.Debug().Msg("Creating session cookie for user") controller.log.App.Debug().Msg("Creating session cookie for user")
cookie, err := controller.auth.CreateSession(c, sessionCookie, c.RemoteIP()) cookie, err := controller.auth.CreateSession(c, sessionCookie)
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to create session cookie") controller.log.App.Error().Err(err).Msg("Failed to create session cookie")
@@ -320,3 +298,10 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
func (controller *OAuthController) isOidcRequest(params service.OAuthCallbackParams) bool { func (controller *OAuthController) isOidcRequest(params service.OAuthCallbackParams) bool {
return params.LoginFor == string(FrontendLoginForOIDC) return params.LoginFor == string(FrontendLoginForOIDC)
} }
func (controller *OAuthController) getCookieDomain() string {
if controller.config.Auth.SubdomainsEnabled {
return "." + controller.runtime.CookieDomain
}
return controller.runtime.CookieDomain
}
+8 -64
View File
@@ -1,14 +1,12 @@
package controller package controller
import ( import (
"database/sql"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"net/http" "net/http"
"slices" "slices"
"strings" "strings"
"time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/gin-gonic/gin/binding" "github.com/gin-gonic/gin/binding"
@@ -33,8 +31,6 @@ type OIDCController struct {
log *logger.Logger log *logger.Logger
oidc *service.OIDCService oidc *service.OIDCService
runtime model.RuntimeConfig runtime model.RuntimeConfig
helpers model.RuntimeHelpers
config model.Config
} }
type AuthorizeCallback struct { type AuthorizeCallback struct {
@@ -72,11 +68,10 @@ type ClientCredentials struct {
} }
type AuthorizeScreenParams struct { type AuthorizeScreenParams struct {
LoginFor FrontendLoginFor `url:"login_for"` LoginFor FrontendLoginFor `url:"login_for"`
OIDCTicket string `url:"oidc_ticket"` OIDCTicket string `url:"oidc_ticket"`
OIDCScope string `url:"oidc_scope"` OIDCScope string `url:"oidc_scope"`
OIDCName string `url:"oidc_name"` OIDCName string `url:"oidc_name"`
OIDCShowConsent bool `url:"oidc_show_consent"`
} }
type AuthorizeCompleteRequest struct { type AuthorizeCompleteRequest struct {
@@ -87,16 +82,12 @@ func NewOIDCController(
log *logger.Logger, log *logger.Logger,
oidcService *service.OIDCService, oidcService *service.OIDCService,
runtimeConfig model.RuntimeConfig, runtimeConfig model.RuntimeConfig,
helpers model.RuntimeHelpers,
config model.Config,
router *gin.RouterGroup, router *gin.RouterGroup,
mainRouter *gin.RouterGroup) *OIDCController { mainRouter *gin.RouterGroup) *OIDCController {
controller := &OIDCController{ controller := &OIDCController{
log: log, log: log,
oidc: oidcService, oidc: oidcService,
runtime: runtimeConfig, runtime: runtimeConfig,
helpers: helpers,
config: config,
} }
mainRouter.POST("/authorize", controller.authorize) mainRouter.POST("/authorize", controller.authorize)
@@ -172,31 +163,11 @@ func (controller *OIDCController) authorize(c *gin.Context) {
ticket := controller.oidc.CreateAuthorizeRequestTicket(*req) ticket := controller.oidc.CreateAuthorizeRequestTicket(*req)
// Check if we have consented before for this client and scope
consnetCookie, err := c.Cookie(controller.runtime.ConsentCookieName)
showConsent := true
if err == nil {
consentEntry, err := controller.oidc.GetConsentEntry(c, consnetCookie)
if err == nil && consentEntry != nil {
if consentEntry.ClientID == req.ClientID && consentEntry.Scopes == req.Scope {
showConsent = false
}
} else {
if !errors.Is(err, sql.ErrNoRows) {
controller.log.App.Error().Err(err).Msg("Failed to get consent entry for consent cookie")
}
}
}
queries, err := query.Values(AuthorizeScreenParams{ queries, err := query.Values(AuthorizeScreenParams{
LoginFor: FrontendLoginForOIDC, LoginFor: FrontendLoginForOIDC,
OIDCTicket: ticket, OIDCTicket: ticket,
OIDCScope: req.Scope, OIDCScope: req.Scope,
OIDCName: client.Name, OIDCName: client.Name,
OIDCShowConsent: showConsent,
}) })
if err != nil { if err != nil {
@@ -318,33 +289,6 @@ func (controller *OIDCController) authorizeComplete(c *gin.Context) {
return return
} }
// Just before returning let's set the consent cookie
consnetUUID, err := controller.oidc.CreateConsentEntry(c, authorizeReq.ClientID, authorizeReq.Scope)
// If we fail to create the consent entry, we don't want to block the authorization flow,
// but we log the error and move on without setting the cookie
if err == nil {
cookieDomain, err := controller.helpers.GetCookieDomain(c.Request.Context(), c.RemoteIP())
if err == nil {
cookie := &http.Cookie{
Name: controller.runtime.ConsentCookieName,
Value: consnetUUID,
Path: "/",
Domain: cookieDomain,
Expires: time.Now().Add(365 * 24 * time.Hour), // set consent cookie for 1 year
Secure: controller.config.Auth.SecureCookie,
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
}
http.SetCookie(c.Writer, cookie)
} else {
controller.log.App.Error().Err(err).Msg("Failed to determine cookie domain for consent cookie")
}
} else {
controller.log.App.Error().Err(err).Msg("Failed to create consent entry")
}
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"status": 200, "status": 200,
"redirect_uri": fmt.Sprintf("%s?%s", authorizeReq.RedirectURI, queries.Encode()), "redirect_uri": fmt.Sprintf("%s?%s", authorizeReq.RedirectURI, queries.Encode()),
+1 -3
View File
@@ -30,8 +30,6 @@ func TestOIDCController(t *testing.T) {
cfg, runtime := test.CreateTestConfigs(t) cfg, runtime := test.CreateTestConfigs(t)
helpers := test.CreateTestHelpers()
ctx := context.TODO() ctx := context.TODO()
dg := ding.New(ctx) dg := ding.New(ctx)
@@ -833,7 +831,7 @@ func TestOIDCController(t *testing.T) {
svc = nil svc = nil
} }
controller.NewOIDCController(log, svc, runtime, helpers, cfg, group, &router.RouterGroup) controller.NewOIDCController(log, svc, runtime, group, &router.RouterGroup)
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
+1 -3
View File
@@ -24,8 +24,6 @@ func TestProxyController(t *testing.T) {
cfg, runtime := test.CreateTestConfigs(t) cfg, runtime := test.CreateTestConfigs(t)
helpers := test.CreateTestHelpers()
const browserUserAgent = ` const browserUserAgent = `
Mozilla/5.0 (Linux; Android 8.0.0; SM-G955U Build/R16NW) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/146.0.0.0 Mobile Safari/537.36` Mozilla/5.0 (Linux; Android 8.0.0; SM-G955U Build/R16NW) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/146.0.0.0 Mobile Safari/537.36`
@@ -397,7 +395,7 @@ func TestProxyController(t *testing.T) {
Log: log, Log: log,
}) })
authService := service.NewAuthService(log, cfg, runtime, helpers, ctx, dg, nil, store, broker, nil, policyEngine) authService := service.NewAuthService(log, cfg, runtime, ctx, dg, nil, store, broker, nil, policyEngine)
for _, test := range tests { for _, test := range tests {
t.Run(test.description, func(t *testing.T) { t.Run(test.description, func(t *testing.T) {
+6 -6
View File
@@ -150,7 +150,7 @@ func (controller *UserController) loginHandler(c *gin.Context) {
Email: email, Email: email,
Provider: "local", Provider: "local",
TotpPending: true, TotpPending: true,
}, c.RemoteIP()) })
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Str("username", req.Username).Msg("Failed to create pending TOTP session") controller.log.App.Error().Err(err).Str("username", req.Username).Msg("Failed to create pending TOTP session")
@@ -195,7 +195,7 @@ func (controller *UserController) loginHandler(c *gin.Context) {
} }
} }
cookie, err := controller.auth.CreateSession(c, sessionCookie, c.RemoteIP()) cookie, err := controller.auth.CreateSession(c, sessionCookie)
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Str("username", req.Username).Msg("Failed to create session cookie after successful login") controller.log.App.Error().Err(err).Str("username", req.Username).Msg("Failed to create session cookie after successful login")
@@ -246,7 +246,7 @@ func (controller *UserController) logoutHandler(c *gin.Context) {
return return
} }
cookie, err := controller.auth.DeleteSession(c, uuid, c.RemoteIP()) cookie, err := controller.auth.DeleteSession(c, uuid)
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Error deleting session on logout") controller.log.App.Error().Err(err).Msg("Error deleting session on logout")
@@ -350,7 +350,7 @@ func (controller *UserController) totpHandler(c *gin.Context) {
uuid, err := c.Cookie(controller.runtime.SessionCookieName) uuid, err := c.Cookie(controller.runtime.SessionCookieName)
if err == nil { if err == nil {
_, err = controller.auth.DeleteSession(c, uuid, c.RemoteIP()) _, err = controller.auth.DeleteSession(c, uuid)
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to delete pending TOTP session after successful verification") controller.log.App.Error().Err(err).Msg("Failed to delete pending TOTP session after successful verification")
} }
@@ -374,7 +374,7 @@ func (controller *UserController) totpHandler(c *gin.Context) {
sessionCookie.Email = user.Attributes.Email sessionCookie.Email = user.Attributes.Email
} }
cookie, err := controller.auth.CreateSession(c, sessionCookie, c.RemoteIP()) cookie, err := controller.auth.CreateSession(c, sessionCookie)
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Str("username", context.GetUsername()).Msg("Failed to create session cookie after successful TOTP verification") controller.log.App.Error().Err(err).Str("username", context.GetUsername()).Msg("Failed to create session cookie after successful TOTP verification")
@@ -424,7 +424,7 @@ func (controller *UserController) tailscaleHandler(c *gin.Context) {
Provider: "tailscale", Provider: "tailscale",
} }
cookie, err := controller.auth.CreateSession(c, sessionCookie, c.RemoteIP()) cookie, err := controller.auth.CreateSession(c, sessionCookie)
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Str("username", context.GetUsername()).Msg("Failed to create session cookie after successful Tailscale login") controller.log.App.Error().Err(err).Str("username", context.GetUsername()).Msg("Failed to create session cookie after successful Tailscale login")
+1 -3
View File
@@ -29,8 +29,6 @@ func TestUserController(t *testing.T) {
cfg, runtime := test.CreateTestConfigs(t) cfg, runtime := test.CreateTestConfigs(t)
helpers := test.CreateTestHelpers()
totpCtx := func(c *gin.Context) { totpCtx := func(c *gin.Context) {
c.Set("context", &model.UserContext{ c.Set("context", &model.UserContext{
Authenticated: false, Authenticated: false,
@@ -420,7 +418,7 @@ func TestUserController(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx) broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx)
authService := service.NewAuthService(log, cfg, runtime, helpers, ctx, dg, nil, store, broker, nil, policyEngine) authService := service.NewAuthService(log, cfg, runtime, ctx, dg, nil, store, broker, nil, policyEngine)
beforeEach := func() { beforeEach := func() {
// Clear failed login attempts before each test // Clear failed login attempts before each test
+2 -2
View File
@@ -206,12 +206,12 @@ func (m *ContextMiddleware) cookieAuth(ctx context.Context, uuid string, ip stri
} }
if !m.auth.IsEmailWhitelisted(userContext.OAuth.ID, userContext.OAuth.Email) { if !m.auth.IsEmailWhitelisted(userContext.OAuth.ID, userContext.OAuth.Email) {
m.auth.DeleteSession(ctx, uuid, ip) m.auth.DeleteSession(ctx, uuid)
return nil, nil, fmt.Errorf("email from session cookie not whitelisted: %s", userContext.OAuth.Email) return nil, nil, fmt.Errorf("email from session cookie not whitelisted: %s", userContext.OAuth.Email)
} }
} }
cookie, err := m.auth.RefreshSession(ctx, uuid, ip) cookie, err := m.auth.RefreshSession(ctx, uuid)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("error refreshing session: %w", err) return nil, nil, fmt.Errorf("error refreshing session: %w", err)
@@ -27,8 +27,6 @@ func TestContextMiddleware(t *testing.T) {
cfg, runtime := test.CreateTestConfigs(t) cfg, runtime := test.CreateTestConfigs(t)
helpers := test.CreateTestHelpers()
basicAuthHeader := func(username, password string) string { basicAuthHeader := func(username, password string) string {
return "Basic " + base64.StdEncoding.EncodeToString([]byte(username+":"+password)) return "Basic " + base64.StdEncoding.EncodeToString([]byte(username+":"+password))
} }
@@ -260,7 +258,7 @@ func TestContextMiddleware(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx) broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx)
authService := service.NewAuthService(log, cfg, runtime, helpers, ctx, dg, nil, store, broker, nil, policyEngine) authService := service.NewAuthService(log, cfg, runtime, ctx, dg, nil, store, broker, nil, policyEngine)
contextMiddleware := middleware.NewContextMiddleware(log, runtime, authService, broker, nil) contextMiddleware := middleware.NewContextMiddleware(log, runtime, authService, broker, nil)
+2 -1
View File
@@ -18,7 +18,8 @@ var OverrideProviders = map[string]string{
} }
const SessionCookieName = "tinyauth-session" const SessionCookieName = "tinyauth-session"
const CSRFCookieName = "tinyauth-csrf"
const RedirectCookieName = "tinyauth-redirect"
const OAuthSessionCookieName = "tinyauth-oauth" const OAuthSessionCookieName = "tinyauth-oauth"
const ConsentCookieName = "tinyauth-consent"
const GracefulShutdownTimeout = 5 // seconds const GracefulShutdownTimeout = 5 // seconds
+2 -7
View File
@@ -1,14 +1,13 @@
package model package model
import "context"
type RuntimeConfig struct { type RuntimeConfig struct {
AppURL string AppURL string
UUID string UUID string
CookieDomain string CookieDomain string
SessionCookieName string SessionCookieName string
CSRFCookieName string
RedirectCookieName string
OAuthSessionCookieName string OAuthSessionCookieName string
ConsentCookieName string
LocalUsers []LocalUser LocalUsers []LocalUser
OAuthProviders map[string]OAuthServiceConfig OAuthProviders map[string]OAuthServiceConfig
OAuthWhitelist []string OAuthWhitelist []string
@@ -17,10 +16,6 @@ type RuntimeConfig struct {
TrustedDomains []string TrustedDomains []string
} }
type RuntimeHelpers struct {
GetCookieDomain func(ctx context.Context, ip string) (string, error)
}
type Provider struct { type Provider struct {
Name string `json:"name"` Name string `json:"name"`
ID string `json:"id"` ID string `json:"id"`
-72
View File
@@ -277,78 +277,6 @@ func TestMemoryStore(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
}, },
}, },
{
description: "Create and get OIDC consent",
run: func(t *testing.T, s repository.Store) {
consent, err := s.CreateOIDCConsent(ctx, repository.CreateOIDCConsentParams{
UUID: "uuid-1",
ClientID: "client-1",
Scopes: "openid profile",
})
require.NoError(t, err)
assert.Equal(t, "uuid-1", consent.UUID)
assert.Equal(t, "client-1", consent.ClientID)
assert.Equal(t, "openid profile", consent.Scopes)
got, err := s.GetOIDCConsentByUUID(ctx, "uuid-1")
require.NoError(t, err)
assert.Equal(t, consent, got)
},
},
{
description: "Get OIDC consent by UUID not found",
run: func(t *testing.T, s repository.Store) {
_, err := s.GetOIDCConsentByUUID(ctx, "missing")
assert.ErrorIs(t, err, repository.ErrNotFound)
},
},
{
description: "Create OIDC consent unique UUID constraint",
run: func(t *testing.T, s repository.Store) {
_, err := s.CreateOIDCConsent(ctx, repository.CreateOIDCConsentParams{UUID: "uuid-1", ClientID: "client-1", Scopes: "openid"})
require.NoError(t, err)
_, err = s.CreateOIDCConsent(ctx, repository.CreateOIDCConsentParams{UUID: "uuid-1", ClientID: "client-2", Scopes: "profile"})
assert.ErrorContains(t, err, "UNIQUE constraint failed: oidc_consent.uuid")
},
},
{
description: "Update OIDC consent",
run: func(t *testing.T, s repository.Store) {
_, err := s.CreateOIDCConsent(ctx, repository.CreateOIDCConsentParams{UUID: "uuid-1", ClientID: "client-1", Scopes: "openid"})
require.NoError(t, err)
updated, err := s.UpdateOIDCConsent(ctx, repository.UpdateOIDCConsentParams{
UUID: "uuid-1",
Scopes: "profile email",
})
require.NoError(t, err)
assert.Equal(t, "profile email", updated.Scopes)
got, err := s.GetOIDCConsentByUUID(ctx, "uuid-1")
require.NoError(t, err)
assert.Equal(t, updated, got)
},
},
{
description: "Update OIDC consent not found",
run: func(t *testing.T, s repository.Store) {
_, err := s.UpdateOIDCConsent(ctx, repository.UpdateOIDCConsentParams{UUID: "missing"})
assert.ErrorIs(t, err, repository.ErrNotFound)
},
},
{
description: "Delete OIDC consent by UUID",
run: func(t *testing.T, s repository.Store) {
_, err := s.CreateOIDCConsent(ctx, repository.CreateOIDCConsentParams{UUID: "uuid-1", ClientID: "client-1", Scopes: "openid"})
require.NoError(t, err)
require.NoError(t, s.DeleteOIDCConsentByUUID(ctx, "uuid-1"))
_, err = s.GetOIDCConsentByUUID(ctx, "uuid-1")
assert.ErrorIs(t, err, repository.ErrNotFound)
},
},
} }
for _, test := range tests { for _, test := range tests {
@@ -94,47 +94,3 @@ func (s *Store) DeleteExpiredOIDCSessions(_ context.Context, arg repository.Dele
} }
return nil return nil
} }
func (s *Store) CreateOIDCConsent(_ context.Context, arg repository.CreateOIDCConsentParams) (repository.OidcConsent, error) {
s.mu.Lock()
defer s.mu.Unlock()
if _, ok := s.oidcConsent[arg.UUID]; ok {
return repository.OidcConsent{}, fmt.Errorf("UNIQUE constraint failed: oidc_consent.uuid")
}
consent := repository.OidcConsent{
UUID: arg.UUID,
ClientID: arg.ClientID,
Scopes: arg.Scopes,
}
s.oidcConsent[arg.UUID] = consent
return consent, nil
}
func (s *Store) GetOIDCConsentByUUID(_ context.Context, uuid string) (repository.OidcConsent, error) {
s.mu.RLock()
defer s.mu.RUnlock()
consent, ok := s.oidcConsent[uuid]
if !ok {
return repository.OidcConsent{}, repository.ErrNotFound
}
return consent, nil
}
func (s *Store) UpdateOIDCConsent(_ context.Context, arg repository.UpdateOIDCConsentParams) (repository.OidcConsent, error) {
s.mu.Lock()
defer s.mu.Unlock()
consent, ok := s.oidcConsent[arg.UUID]
if !ok {
return repository.OidcConsent{}, repository.ErrNotFound
}
consent.Scopes = arg.Scopes
s.oidcConsent[arg.UUID] = consent
return consent, nil
}
func (s *Store) DeleteOIDCConsentByUUID(_ context.Context, uuid string) error {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.oidcConsent, uuid)
return nil
}
-2
View File
@@ -12,7 +12,6 @@ type Store struct {
mu sync.RWMutex mu sync.RWMutex
sessions map[string]repository.Session sessions map[string]repository.Session
oidcSessions map[string]repository.OidcSession oidcSessions map[string]repository.OidcSession
oidcConsent map[string]repository.OidcConsent
} }
// New returns a new empty in-memory Store. // New returns a new empty in-memory Store.
@@ -20,6 +19,5 @@ func New() repository.Store {
return &Store{ return &Store{
sessions: make(map[string]repository.Session), sessions: make(map[string]repository.Session),
oidcSessions: make(map[string]repository.OidcSession), oidcSessions: make(map[string]repository.OidcSession),
oidcConsent: make(map[string]repository.OidcConsent),
} }
} }
-21
View File
@@ -1,18 +1,8 @@
package repository package repository
import "time"
// Shared model and parameter types for all storage drivers. // Shared model and parameter types for all storage drivers.
// sqlc-generated driver packages use these via the conversion layer in their store.go. // sqlc-generated driver packages use these via the conversion layer in their store.go.
type OidcConsent struct {
UUID string
ClientID string
Scopes string
CreatedAt time.Time
UpdatedAt time.Time
}
type Session struct { type Session struct {
UUID string UUID string
Username string Username string
@@ -94,14 +84,3 @@ type DeleteExpiredOIDCSessionsParams struct {
TokenExpiresAt int64 TokenExpiresAt int64
RefreshTokenExpiresAt int64 RefreshTokenExpiresAt int64
} }
type CreateOIDCConsentParams struct {
UUID string
ClientID string
Scopes string
}
type UpdateOIDCConsentParams struct {
Scopes string
UUID string
}
-12
View File
@@ -4,18 +4,6 @@
package postgres package postgres
import (
"time"
)
type OidcConsent struct {
UUID string
ClientID string
Scopes string
CreatedAt time.Time
UpdatedAt time.Time
}
type OidcSession struct { type OidcSession struct {
Sub string Sub string
AccessTokenHash string AccessTokenHash string
@@ -9,36 +9,6 @@ import (
"context" "context"
) )
const createOIDCConsent = `-- name: CreateOIDCConsent :one
INSERT INTO "oidc_consent" (
"uuid",
"client_id",
"scopes"
) VALUES (
$1, $2, $3
)
RETURNING uuid, client_id, scopes, created_at, updated_at
`
type CreateOIDCConsentParams struct {
UUID string
ClientID string
Scopes string
}
func (q *Queries) CreateOIDCConsent(ctx context.Context, arg CreateOIDCConsentParams) (OidcConsent, error) {
row := q.db.QueryRowContext(ctx, createOIDCConsent, arg.UUID, arg.ClientID, arg.Scopes)
var i OidcConsent
err := row.Scan(
&i.UUID,
&i.ClientID,
&i.Scopes,
&i.CreatedAt,
&i.UpdatedAt,
)
return i, err
}
const createOIDCSession = `-- name: CreateOIDCSession :one const createOIDCSession = `-- name: CreateOIDCSession :one
INSERT INTO "oidc_sessions" ( INSERT INTO "oidc_sessions" (
"sub", "sub",
@@ -110,16 +80,6 @@ func (q *Queries) DeleteExpiredOIDCSessions(ctx context.Context, arg DeleteExpir
return err return err
} }
const deleteOIDCConsentByUUID = `-- name: DeleteOIDCConsentByUUID :exec
DELETE FROM "oidc_consent"
WHERE "uuid" = $1
`
func (q *Queries) DeleteOIDCConsentByUUID(ctx context.Context, uuid string) error {
_, err := q.db.ExecContext(ctx, deleteOIDCConsentByUUID, uuid)
return err
}
const deleteOIDCSessionBySub = `-- name: DeleteOIDCSessionBySub :exec const deleteOIDCSessionBySub = `-- name: DeleteOIDCSessionBySub :exec
DELETE FROM "oidc_sessions" DELETE FROM "oidc_sessions"
WHERE "sub" = $1 WHERE "sub" = $1
@@ -130,24 +90,6 @@ func (q *Queries) DeleteOIDCSessionBySub(ctx context.Context, sub string) error
return err return err
} }
const getOIDCConsentByUUID = `-- name: GetOIDCConsentByUUID :one
SELECT uuid, client_id, scopes, created_at, updated_at FROM "oidc_consent"
WHERE "uuid" = $1
`
func (q *Queries) GetOIDCConsentByUUID(ctx context.Context, uuid string) (OidcConsent, error) {
row := q.db.QueryRowContext(ctx, getOIDCConsentByUUID, uuid)
var i OidcConsent
err := row.Scan(
&i.UUID,
&i.ClientID,
&i.Scopes,
&i.CreatedAt,
&i.UpdatedAt,
)
return i, err
}
const getOIDCSessionByAccessTokenHash = `-- name: GetOIDCSessionByAccessTokenHash :one const getOIDCSessionByAccessTokenHash = `-- name: GetOIDCSessionByAccessTokenHash :one
SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce, userinfo_json FROM "oidc_sessions" SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce, userinfo_json FROM "oidc_sessions"
WHERE "access_token_hash" = $1 WHERE "access_token_hash" = $1
@@ -214,32 +156,6 @@ func (q *Queries) GetOIDCSessionBySub(ctx context.Context, sub string) (OidcSess
return i, err return i, err
} }
const updateOIDCConsent = `-- name: UpdateOIDCConsent :one
UPDATE "oidc_consent" SET
"scopes" = $1,
"updated_at" = CURRENT_TIMESTAMP
WHERE "uuid" = $2
RETURNING uuid, client_id, scopes, created_at, updated_at
`
type UpdateOIDCConsentParams struct {
Scopes string
UUID string
}
func (q *Queries) UpdateOIDCConsent(ctx context.Context, arg UpdateOIDCConsentParams) (OidcConsent, error) {
row := q.db.QueryRowContext(ctx, updateOIDCConsent, arg.Scopes, arg.UUID)
var i OidcConsent
err := row.Scan(
&i.UUID,
&i.ClientID,
&i.Scopes,
&i.CreatedAt,
&i.UpdatedAt,
)
return i, err
}
const updateOIDCSession = `-- name: UpdateOIDCSession :one const updateOIDCSession = `-- name: UpdateOIDCSession :one
UPDATE "oidc_sessions" SET UPDATE "oidc_sessions" SET
"access_token_hash" = $1, "access_token_hash" = $1,
-28
View File
@@ -32,14 +32,6 @@ func mapErr(err error) error {
return err return err
} }
func (s *Store) CreateOIDCConsent(ctx context.Context, arg repository.CreateOIDCConsentParams) (repository.OidcConsent, error) {
r, err := s.q.CreateOIDCConsent(ctx, CreateOIDCConsentParams(arg))
if err != nil {
return repository.OidcConsent{}, mapErr(err)
}
return repository.OidcConsent(r), nil
}
func (s *Store) CreateOIDCSession(ctx context.Context, arg repository.CreateOIDCSessionParams) (repository.OidcSession, error) { func (s *Store) CreateOIDCSession(ctx context.Context, arg repository.CreateOIDCSessionParams) (repository.OidcSession, error) {
r, err := s.q.CreateOIDCSession(ctx, CreateOIDCSessionParams(arg)) r, err := s.q.CreateOIDCSession(ctx, CreateOIDCSessionParams(arg))
if err != nil { if err != nil {
@@ -64,10 +56,6 @@ func (s *Store) DeleteExpiredSessions(ctx context.Context, expiry int64) error {
return mapErr(s.q.DeleteExpiredSessions(ctx, expiry)) return mapErr(s.q.DeleteExpiredSessions(ctx, expiry))
} }
func (s *Store) DeleteOIDCConsentByUUID(ctx context.Context, uuid string) error {
return mapErr(s.q.DeleteOIDCConsentByUUID(ctx, uuid))
}
func (s *Store) DeleteOIDCSessionBySub(ctx context.Context, sub string) error { func (s *Store) DeleteOIDCSessionBySub(ctx context.Context, sub string) error {
return mapErr(s.q.DeleteOIDCSessionBySub(ctx, sub)) return mapErr(s.q.DeleteOIDCSessionBySub(ctx, sub))
} }
@@ -76,14 +64,6 @@ func (s *Store) DeleteSession(ctx context.Context, uuid string) error {
return mapErr(s.q.DeleteSession(ctx, uuid)) return mapErr(s.q.DeleteSession(ctx, uuid))
} }
func (s *Store) GetOIDCConsentByUUID(ctx context.Context, uuid string) (repository.OidcConsent, error) {
r, err := s.q.GetOIDCConsentByUUID(ctx, uuid)
if err != nil {
return repository.OidcConsent{}, mapErr(err)
}
return repository.OidcConsent(r), nil
}
func (s *Store) GetOIDCSessionByAccessTokenHash(ctx context.Context, accessTokenHash string) (repository.OidcSession, error) { func (s *Store) GetOIDCSessionByAccessTokenHash(ctx context.Context, accessTokenHash string) (repository.OidcSession, error) {
r, err := s.q.GetOIDCSessionByAccessTokenHash(ctx, accessTokenHash) r, err := s.q.GetOIDCSessionByAccessTokenHash(ctx, accessTokenHash)
if err != nil { if err != nil {
@@ -116,14 +96,6 @@ func (s *Store) GetSession(ctx context.Context, uuid string) (repository.Session
return repository.Session(r), nil return repository.Session(r), nil
} }
func (s *Store) UpdateOIDCConsent(ctx context.Context, arg repository.UpdateOIDCConsentParams) (repository.OidcConsent, error) {
r, err := s.q.UpdateOIDCConsent(ctx, UpdateOIDCConsentParams(arg))
if err != nil {
return repository.OidcConsent{}, mapErr(err)
}
return repository.OidcConsent(r), nil
}
func (s *Store) UpdateOIDCSession(ctx context.Context, arg repository.UpdateOIDCSessionParams) (repository.OidcSession, error) { func (s *Store) UpdateOIDCSession(ctx context.Context, arg repository.UpdateOIDCSessionParams) (repository.OidcSession, error) {
r, err := s.q.UpdateOIDCSession(ctx, UpdateOIDCSessionParams(arg)) r, err := s.q.UpdateOIDCSession(ctx, UpdateOIDCSessionParams(arg))
if err != nil { if err != nil {
-12
View File
@@ -4,18 +4,6 @@
package sqlite package sqlite
import (
"time"
)
type OidcConsent struct {
UUID string
ClientID string
Scopes string
CreatedAt time.Time
UpdatedAt time.Time
}
type OidcSession struct { type OidcSession struct {
Sub string Sub string
AccessTokenHash string AccessTokenHash string
@@ -9,36 +9,6 @@ import (
"context" "context"
) )
const createOIDCConsent = `-- name: CreateOIDCConsent :one
INSERT INTO "oidc_consent" (
"uuid",
"client_id",
"scopes"
) VALUES (
?, ?, ?
)
RETURNING uuid, client_id, scopes, created_at, updated_at
`
type CreateOIDCConsentParams struct {
UUID string
ClientID string
Scopes string
}
func (q *Queries) CreateOIDCConsent(ctx context.Context, arg CreateOIDCConsentParams) (OidcConsent, error) {
row := q.db.QueryRowContext(ctx, createOIDCConsent, arg.UUID, arg.ClientID, arg.Scopes)
var i OidcConsent
err := row.Scan(
&i.UUID,
&i.ClientID,
&i.Scopes,
&i.CreatedAt,
&i.UpdatedAt,
)
return i, err
}
const createOIDCSession = `-- name: CreateOIDCSession :one const createOIDCSession = `-- name: CreateOIDCSession :one
INSERT INTO "oidc_sessions" ( INSERT INTO "oidc_sessions" (
"sub", "sub",
@@ -110,16 +80,6 @@ func (q *Queries) DeleteExpiredOIDCSessions(ctx context.Context, arg DeleteExpir
return err return err
} }
const deleteOIDCConsentByUUID = `-- name: DeleteOIDCConsentByUUID :exec
DELETE FROM "oidc_consent"
WHERE "uuid" = ?
`
func (q *Queries) DeleteOIDCConsentByUUID(ctx context.Context, uuid string) error {
_, err := q.db.ExecContext(ctx, deleteOIDCConsentByUUID, uuid)
return err
}
const deleteOIDCSessionBySub = `-- name: DeleteOIDCSessionBySub :exec const deleteOIDCSessionBySub = `-- name: DeleteOIDCSessionBySub :exec
DELETE FROM "oidc_sessions" DELETE FROM "oidc_sessions"
WHERE "sub" = ? WHERE "sub" = ?
@@ -130,24 +90,6 @@ func (q *Queries) DeleteOIDCSessionBySub(ctx context.Context, sub string) error
return err return err
} }
const getOIDCConsentByUUID = `-- name: GetOIDCConsentByUUID :one
SELECT uuid, client_id, scopes, created_at, updated_at FROM "oidc_consent"
WHERE "uuid" = ?
`
func (q *Queries) GetOIDCConsentByUUID(ctx context.Context, uuid string) (OidcConsent, error) {
row := q.db.QueryRowContext(ctx, getOIDCConsentByUUID, uuid)
var i OidcConsent
err := row.Scan(
&i.UUID,
&i.ClientID,
&i.Scopes,
&i.CreatedAt,
&i.UpdatedAt,
)
return i, err
}
const getOIDCSessionByAccessTokenHash = `-- name: GetOIDCSessionByAccessTokenHash :one const getOIDCSessionByAccessTokenHash = `-- name: GetOIDCSessionByAccessTokenHash :one
SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce, userinfo_json FROM "oidc_sessions" SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce, userinfo_json FROM "oidc_sessions"
WHERE "access_token_hash" = ? WHERE "access_token_hash" = ?
@@ -214,32 +156,6 @@ func (q *Queries) GetOIDCSessionBySub(ctx context.Context, sub string) (OidcSess
return i, err return i, err
} }
const updateOIDCConsent = `-- name: UpdateOIDCConsent :one
UPDATE "oidc_consent" SET
"scopes" = ?,
"updated_at" = CURRENT_TIMESTAMP
WHERE "uuid" = ?
RETURNING uuid, client_id, scopes, created_at, updated_at
`
type UpdateOIDCConsentParams struct {
Scopes string
UUID string
}
func (q *Queries) UpdateOIDCConsent(ctx context.Context, arg UpdateOIDCConsentParams) (OidcConsent, error) {
row := q.db.QueryRowContext(ctx, updateOIDCConsent, arg.Scopes, arg.UUID)
var i OidcConsent
err := row.Scan(
&i.UUID,
&i.ClientID,
&i.Scopes,
&i.CreatedAt,
&i.UpdatedAt,
)
return i, err
}
const updateOIDCSession = `-- name: UpdateOIDCSession :one const updateOIDCSession = `-- name: UpdateOIDCSession :one
UPDATE "oidc_sessions" SET UPDATE "oidc_sessions" SET
"access_token_hash" = ?, "access_token_hash" = ?,
-28
View File
@@ -32,14 +32,6 @@ func mapErr(err error) error {
return err return err
} }
func (s *Store) CreateOIDCConsent(ctx context.Context, arg repository.CreateOIDCConsentParams) (repository.OidcConsent, error) {
r, err := s.q.CreateOIDCConsent(ctx, CreateOIDCConsentParams(arg))
if err != nil {
return repository.OidcConsent{}, mapErr(err)
}
return repository.OidcConsent(r), nil
}
func (s *Store) CreateOIDCSession(ctx context.Context, arg repository.CreateOIDCSessionParams) (repository.OidcSession, error) { func (s *Store) CreateOIDCSession(ctx context.Context, arg repository.CreateOIDCSessionParams) (repository.OidcSession, error) {
r, err := s.q.CreateOIDCSession(ctx, CreateOIDCSessionParams(arg)) r, err := s.q.CreateOIDCSession(ctx, CreateOIDCSessionParams(arg))
if err != nil { if err != nil {
@@ -64,10 +56,6 @@ func (s *Store) DeleteExpiredSessions(ctx context.Context, expiry int64) error {
return mapErr(s.q.DeleteExpiredSessions(ctx, expiry)) return mapErr(s.q.DeleteExpiredSessions(ctx, expiry))
} }
func (s *Store) DeleteOIDCConsentByUUID(ctx context.Context, uuid string) error {
return mapErr(s.q.DeleteOIDCConsentByUUID(ctx, uuid))
}
func (s *Store) DeleteOIDCSessionBySub(ctx context.Context, sub string) error { func (s *Store) DeleteOIDCSessionBySub(ctx context.Context, sub string) error {
return mapErr(s.q.DeleteOIDCSessionBySub(ctx, sub)) return mapErr(s.q.DeleteOIDCSessionBySub(ctx, sub))
} }
@@ -76,14 +64,6 @@ func (s *Store) DeleteSession(ctx context.Context, uuid string) error {
return mapErr(s.q.DeleteSession(ctx, uuid)) return mapErr(s.q.DeleteSession(ctx, uuid))
} }
func (s *Store) GetOIDCConsentByUUID(ctx context.Context, uuid string) (repository.OidcConsent, error) {
r, err := s.q.GetOIDCConsentByUUID(ctx, uuid)
if err != nil {
return repository.OidcConsent{}, mapErr(err)
}
return repository.OidcConsent(r), nil
}
func (s *Store) GetOIDCSessionByAccessTokenHash(ctx context.Context, accessTokenHash string) (repository.OidcSession, error) { func (s *Store) GetOIDCSessionByAccessTokenHash(ctx context.Context, accessTokenHash string) (repository.OidcSession, error) {
r, err := s.q.GetOIDCSessionByAccessTokenHash(ctx, accessTokenHash) r, err := s.q.GetOIDCSessionByAccessTokenHash(ctx, accessTokenHash)
if err != nil { if err != nil {
@@ -116,14 +96,6 @@ func (s *Store) GetSession(ctx context.Context, uuid string) (repository.Session
return repository.Session(r), nil return repository.Session(r), nil
} }
func (s *Store) UpdateOIDCConsent(ctx context.Context, arg repository.UpdateOIDCConsentParams) (repository.OidcConsent, error) {
r, err := s.q.UpdateOIDCConsent(ctx, UpdateOIDCConsentParams(arg))
if err != nil {
return repository.OidcConsent{}, mapErr(err)
}
return repository.OidcConsent(r), nil
}
func (s *Store) UpdateOIDCSession(ctx context.Context, arg repository.UpdateOIDCSessionParams) (repository.OidcSession, error) { func (s *Store) UpdateOIDCSession(ctx context.Context, arg repository.UpdateOIDCSessionParams) (repository.OidcSession, error) {
r, err := s.q.UpdateOIDCSession(ctx, UpdateOIDCSessionParams(arg)) r, err := s.q.UpdateOIDCSession(ctx, UpdateOIDCSessionParams(arg))
if err != nil { if err != nil {
-6
View File
@@ -27,10 +27,4 @@ type Store interface {
GetOIDCSessionByRefreshTokenHash(ctx context.Context, refreshTokenHash string) (OidcSession, error) GetOIDCSessionByRefreshTokenHash(ctx context.Context, refreshTokenHash string) (OidcSession, error)
GetOIDCSessionBySub(ctx context.Context, sub string) (OidcSession, error) GetOIDCSessionBySub(ctx context.Context, sub string) (OidcSession, error)
UpdateOIDCSession(ctx context.Context, arg UpdateOIDCSessionParams) (OidcSession, error) UpdateOIDCSession(ctx context.Context, arg UpdateOIDCSessionParams) (OidcSession, error)
// OIDC consents
CreateOIDCConsent(ctx context.Context, arg CreateOIDCConsentParams) (OidcConsent, error)
DeleteOIDCConsentByUUID(ctx context.Context, uuid string) error
GetOIDCConsentByUUID(ctx context.Context, uuid string) (OidcConsent, error)
UpdateOIDCConsent(ctx context.Context, arg UpdateOIDCConsentParams) (OidcConsent, error)
} }
+6 -7
View File
@@ -13,19 +13,18 @@ type LabelProvider interface {
type AccessControlsService struct { type AccessControlsService struct {
log *logger.Logger log *logger.Logger
config model.Config config *model.Config
labelProvider *LabelProvider labelProvider *LabelProvider
} }
func NewAccessControlsService( func NewAccessControlsService(
log *logger.Logger, deps *ServiceDependencies,
config model.Config, ) *AccessControlsService {
labelProvider *LabelProvider) *AccessControlsService {
return &AccessControlsService{ return &AccessControlsService{
log: log, log: deps.Log,
config: config, config: deps.StaticConfig,
labelProvider: labelProvider, labelProvider: &deps.LabelProvider,
} }
} }
+38 -50
View File
@@ -57,9 +57,8 @@ type LoginAttempt struct {
type AuthService struct { type AuthService struct {
log *logger.Logger log *logger.Logger
config model.Config config *model.Config
runtime model.RuntimeConfig runtime *model.RuntimeConfig
helpers model.RuntimeHelpers
ctx context.Context ctx context.Context
ldap *LdapService ldap *LdapService
@@ -84,29 +83,18 @@ type AuthService struct {
} }
func NewAuthService( func NewAuthService(
log *logger.Logger, deps *ServiceDependencies,
config model.Config,
runtime model.RuntimeConfig,
helpers model.RuntimeHelpers,
ctx context.Context,
dg *ding.Ding,
ldap *LdapService,
queries repository.Store,
oauthBroker *OAuthBrokerService,
tailscale *TailscaleService,
policy *PolicyEngine,
) *AuthService { ) *AuthService {
service := &AuthService{ service := &AuthService{
log: log, log: deps.Log,
runtime: runtime, runtime: deps.RuntimeConfig,
helpers: helpers, ctx: deps.Ctx,
ctx: ctx, config: deps.StaticConfig,
config: config, ldap: deps.Services.LDAPService,
ldap: ldap, queries: *deps.Queries,
queries: queries, oauthBroker: deps.Services.OAuthBrokerService,
oauthBroker: oauthBroker, tailscale: deps.Services.TailscaleService,
tailscale: tailscale, policyEngine: deps.Services.PolicyEngine,
policyEngine: policy,
} }
// caches setup // caches setup
@@ -118,7 +106,7 @@ func NewAuthService(
service.caches.login = loginCache service.caches.login = loginCache
service.caches.ldap = ldapCache service.caches.ldap = ldapCache
dg.Go(func(ctx context.Context) { deps.Ding.Go(func(ctx context.Context) {
ticker := time.NewTicker(1 * time.Minute) ticker := time.NewTicker(1 * time.Minute)
defer ticker.Stop() defer ticker.Stop()
@@ -325,7 +313,7 @@ func (auth *AuthService) IsEmailWhitelisted(provider string, email string) bool
}) })
} }
func (auth *AuthService) CreateSession(ctx context.Context, data repository.Session, ip string) (*http.Cookie, error) { func (auth *AuthService) CreateSession(ctx context.Context, data repository.Session) (*http.Cookie, error) {
if data.Provider == "tailscale" && auth.tailscale == nil { if data.Provider == "tailscale" && auth.tailscale == nil {
return nil, fmt.Errorf("tailscale service not configured, cannot create session for tailscale user") return nil, fmt.Errorf("tailscale service not configured, cannot create session for tailscale user")
} }
@@ -366,17 +354,33 @@ func (auth *AuthService) CreateSession(ctx context.Context, data repository.Sess
return nil, fmt.Errorf("failed to create session entry: %w", err) return nil, fmt.Errorf("failed to create session entry: %w", err)
} }
cookieDomain, err := auth.helpers.GetCookieDomain(ctx, ip) if data.Provider == "tailscale" {
auth.log.App.Trace().Str("url", fmt.Sprintf("https://%s", auth.tailscale.GetHostname())).Msg("Extracting root domain from Tailscale hostname")
if err != nil { tsCookieDomain, err := utils.GetCookieDomain(fmt.Sprintf("https://%s", auth.tailscale.GetHostname()))
return nil, fmt.Errorf("failed to determine cookie domain: %w", err)
if err != nil {
return nil, fmt.Errorf("failed to get cookie domain for tailscale user: %w", err)
}
return &http.Cookie{
Name: auth.runtime.SessionCookieName,
Value: session.UUID,
Path: "/",
Domain: fmt.Sprintf(".%s", tsCookieDomain),
Expires: expiresAt,
MaxAge: int(time.Until(expiresAt).Seconds()),
Secure: auth.config.Auth.SecureCookie,
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
}, nil
} }
return &http.Cookie{ return &http.Cookie{
Name: auth.runtime.SessionCookieName, Name: auth.runtime.SessionCookieName,
Value: session.UUID, Value: session.UUID,
Path: "/", Path: "/",
Domain: cookieDomain, Domain: fmt.Sprintf(".%s", auth.runtime.CookieDomain),
Expires: expiresAt, Expires: expiresAt,
MaxAge: int(time.Until(expiresAt).Seconds()), MaxAge: int(time.Until(expiresAt).Seconds()),
Secure: auth.config.Auth.SecureCookie, Secure: auth.config.Auth.SecureCookie,
@@ -385,17 +389,13 @@ func (auth *AuthService) CreateSession(ctx context.Context, data repository.Sess
}, nil }, nil
} }
func (auth *AuthService) RefreshSession(ctx context.Context, uuid string, ip string) (*http.Cookie, error) { func (auth *AuthService) RefreshSession(ctx context.Context, uuid string) (*http.Cookie, error) {
session, err := auth.queries.GetSession(ctx, uuid) session, err := auth.queries.GetSession(ctx, uuid)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to retrieve session: %w", err) return nil, fmt.Errorf("failed to retrieve session: %w", err)
} }
if session.Provider == "tailscale" && auth.tailscale == nil {
return nil, fmt.Errorf("tailscale service not configured, cannot create session for tailscale user")
}
currentTime := time.Now().Unix() currentTime := time.Now().Unix()
var refreshThreshold int64 var refreshThreshold int64
@@ -429,17 +429,11 @@ func (auth *AuthService) RefreshSession(ctx context.Context, uuid string, ip str
return nil, fmt.Errorf("failed to update session expiry: %w", err) return nil, fmt.Errorf("failed to update session expiry: %w", err)
} }
cookieDomain, err := auth.helpers.GetCookieDomain(ctx, ip)
if err != nil {
return nil, fmt.Errorf("failed to determine cookie domain: %w", err)
}
return &http.Cookie{ return &http.Cookie{
Name: auth.runtime.SessionCookieName, Name: auth.runtime.SessionCookieName,
Value: session.UUID, Value: session.UUID,
Path: "/", Path: "/",
Domain: cookieDomain, Domain: fmt.Sprintf(".%s", auth.runtime.CookieDomain),
Expires: time.Now().Add(time.Duration(newExpiry-currentTime) * time.Second), Expires: time.Now().Add(time.Duration(newExpiry-currentTime) * time.Second),
MaxAge: int(newExpiry - currentTime), MaxAge: int(newExpiry - currentTime),
Secure: auth.config.Auth.SecureCookie, Secure: auth.config.Auth.SecureCookie,
@@ -449,24 +443,18 @@ func (auth *AuthService) RefreshSession(ctx context.Context, uuid string, ip str
} }
func (auth *AuthService) DeleteSession(ctx context.Context, uuid string, ip string) (*http.Cookie, error) { func (auth *AuthService) DeleteSession(ctx context.Context, uuid string) (*http.Cookie, error) {
err := auth.queries.DeleteSession(ctx, uuid) err := auth.queries.DeleteSession(ctx, uuid)
if err != nil { if err != nil {
auth.log.App.Error().Err(err).Str("uuid", uuid).Msg("Failed to delete session from database") auth.log.App.Error().Err(err).Str("uuid", uuid).Msg("Failed to delete session from database")
} }
cookieDomain, err := auth.helpers.GetCookieDomain(ctx, ip)
if err != nil {
return nil, fmt.Errorf("failed to determine cookie domain: %w", err)
}
return &http.Cookie{ return &http.Cookie{
Name: auth.runtime.SessionCookieName, Name: auth.runtime.SessionCookieName,
Value: "", Value: "",
Path: "/", Path: "/",
Domain: cookieDomain, Domain: fmt.Sprintf(".%s", auth.runtime.CookieDomain),
Expires: time.Now(), Expires: time.Now(),
MaxAge: -1, MaxAge: -1,
Secure: auth.config.Auth.SecureCookie, Secure: auth.config.Auth.SecureCookie,
+7 -9
View File
@@ -22,9 +22,7 @@ type DockerService struct {
} }
func NewDockerService( func NewDockerService(
log *logger.Logger, deps *ServiceDependencies,
ctx context.Context,
dg *ding.Ding,
) (*DockerService, error) { ) (*DockerService, error) {
client, err := client.NewClientWithOpts(client.FromEnv) client, err := client.NewClientWithOpts(client.FromEnv)
@@ -32,25 +30,25 @@ func NewDockerService(
return nil, err return nil, err
} }
client.NegotiateAPIVersion(ctx) client.NegotiateAPIVersion(deps.Ctx)
_, err = client.Ping(ctx) _, err = client.Ping(deps.Ctx)
if err != nil { if err != nil {
log.App.Debug().Err(err).Msg("Docker not connected") deps.Log.App.Debug().Err(err).Msg("Docker not connected")
return nil, nil return nil, nil
} }
service := &DockerService{ service := &DockerService{
log: log, log: deps.Log,
client: client, client: client,
context: ctx, context: deps.Ctx,
} }
service.isConnected = true service.isConnected = true
service.log.App.Debug().Msg("Docker connected successfully") service.log.App.Debug().Msg("Docker connected successfully")
dg.Go(service.watchAndClose, ding.RingMajor) deps.Ding.Go(service.watchAndClose, ding.RingMajor)
return service, nil return service, nil
} }
+7 -9
View File
@@ -49,9 +49,7 @@ type KubernetesService struct {
} }
func NewKubernetesService( func NewKubernetesService(
log *logger.Logger, deps *ServiceDependencies,
ctx context.Context,
dg *ding.Ding,
) (*KubernetesService, error) { ) (*KubernetesService, error) {
cfg, err := rest.InClusterConfig() cfg, err := rest.InClusterConfig()
if err != nil { if err != nil {
@@ -69,31 +67,31 @@ func NewKubernetesService(
Resource: "ingresses", Resource: "ingresses",
} }
accessCtx, accessCancel := context.WithTimeout(ctx, 5*time.Second) accessCtx, accessCancel := context.WithTimeout(deps.Ctx, 5*time.Second)
defer accessCancel() defer accessCancel()
_, err = client.Resource(gvr).List(accessCtx, metav1.ListOptions{Limit: 1}) _, err = client.Resource(gvr).List(accessCtx, metav1.ListOptions{Limit: 1})
if err != nil { if err != nil {
log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to access Ingress API, Kubernetes label provider will be disabled") deps.Log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to access Ingress API, Kubernetes label provider will be disabled")
return nil, fmt.Errorf("failed to access ingress api: %w", err) return nil, fmt.Errorf("failed to access ingress api: %w", err)
} }
log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Successfully accessed Ingress API, starting watcher") deps.Log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Successfully accessed Ingress API, starting watcher")
service := &KubernetesService{ service := &KubernetesService{
log: log, log: deps.Log,
client: client, client: client,
ingressApps: make(map[ingressKey][]ingressApp), ingressApps: make(map[ingressKey][]ingressApp),
domainIndex: make(map[string]ingressAppKey), domainIndex: make(map[string]ingressAppKey),
appNameIndex: make(map[string]ingressAppKey), appNameIndex: make(map[string]ingressAppKey),
} }
dg.Go(func(ctx context.Context) { deps.Ding.Go(func(ctx context.Context) {
service.watchGVR(gvr, ctx) service.watchGVR(gvr, ctx)
}, ding.RingMajor) }, ding.RingMajor)
service.started = true service.started = true
log.App.Debug().Msg("Kubernetes label provider started successfully") deps.Log.App.Debug().Msg("Kubernetes label provider started successfully")
return service, nil return service, nil
} }
+15 -17
View File
@@ -17,40 +17,38 @@ import (
type LdapService struct { type LdapService struct {
log *logger.Logger log *logger.Logger
config model.Config config *model.Config
conn *ldapgo.Conn conn *ldapgo.Conn
mutex sync.RWMutex mutex sync.RWMutex
cert *tls.Certificate cert *tls.Certificate
ldapBindPw string
} }
func NewLdapService( func NewLdapService(
log *logger.Logger, deps *ServiceDependencies,
config model.Config,
dg *ding.Ding,
) (*LdapService, error) { ) (*LdapService, error) {
if config.LDAP.Address == "" { if deps.StaticConfig.LDAP.Address == "" {
return nil, nil return nil, nil
} }
secret := utils.GetSecret(config.LDAP.BindPassword, config.LDAP.BindPasswordFile) ldapBindPw := utils.GetSecret(deps.StaticConfig.LDAP.BindPassword, deps.StaticConfig.LDAP.BindPasswordFile)
config.LDAP.BindPassword = secret
config.LDAP.BindPasswordFile = ""
ldap := &LdapService{ ldap := &LdapService{
log: log, log: deps.Log,
config: config, config: deps.StaticConfig,
ldapBindPw: ldapBindPw,
} }
// Check whether authentication with client certificate is possible // Check whether authentication with client certificate is possible
if config.LDAP.AuthCert != "" && config.LDAP.AuthKey != "" { if deps.StaticConfig.LDAP.AuthCert != "" && deps.StaticConfig.LDAP.AuthKey != "" {
cert, err := tls.LoadX509KeyPair(config.LDAP.AuthCert, config.LDAP.AuthKey) cert, err := tls.LoadX509KeyPair(deps.StaticConfig.LDAP.AuthCert, deps.StaticConfig.LDAP.AuthKey)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to initialize LDAP with mTLS authentication: %w", err) return nil, fmt.Errorf("failed to initialize LDAP with mTLS authentication: %w", err)
} }
log.App.Info().Msg("LDAP mTLS authentication configured successfully") ldap.log.App.Info().Msg("LDAP mTLS authentication configured successfully")
ldap.cert = &cert ldap.cert = &cert
@@ -72,7 +70,7 @@ func NewLdapService(
return nil, fmt.Errorf("failed to connect to ldap server: %w", err) return nil, fmt.Errorf("failed to connect to ldap server: %w", err)
} }
dg.Go(func(ctx context.Context) { deps.Ding.Go(func(ctx context.Context) {
ldap.log.App.Debug().Msg("Starting LDAP connection heartbeat routine") ldap.log.App.Debug().Msg("Starting LDAP connection heartbeat routine")
ticker := time.NewTicker(5 * time.Minute) ticker := time.NewTicker(5 * time.Minute)
+6 -8
View File
@@ -33,22 +33,20 @@ var presets = map[string]func(config model.OAuthServiceConfig, ctx context.Conte
} }
func NewOAuthBrokerService( func NewOAuthBrokerService(
log *logger.Logger, deps *ServiceDependencies,
configs map[string]model.OAuthServiceConfig,
ctx context.Context,
) *OAuthBrokerService { ) *OAuthBrokerService {
service := &OAuthBrokerService{ service := &OAuthBrokerService{
log: log, log: deps.Log,
services: make(map[string]OAuthServiceImpl), services: make(map[string]OAuthServiceImpl),
configs: configs, configs: deps.RuntimeConfig.OAuthProviders,
} }
for name, cfg := range configs { for name, cfg := range service.configs {
if presetFunc, exists := presets[name]; exists { if presetFunc, exists := presets[name]; exists {
service.services[name] = presetFunc(cfg, ctx) service.services[name] = presetFunc(cfg, deps.Ctx)
service.log.App.Debug().Str("service", name).Msg("Loaded OAuth service from preset") service.log.App.Debug().Str("service", name).Msg("Loaded OAuth service from preset")
} else { } else {
service.services[name] = NewOAuthService(cfg, name, ctx) service.services[name] = NewOAuthService(cfg, name, deps.Ctx)
service.log.App.Debug().Str("service", name).Msg("Loaded OAuth service from custom config") service.log.App.Debug().Str("service", name).Msg("Loaded OAuth service from custom config")
} }
} }
+42 -78
View File
@@ -21,7 +21,6 @@ import (
"github.com/go-jose/go-jose/v4" "github.com/go-jose/go-jose/v4"
"github.com/golang-jwt/jwt/v5" "github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
"github.com/steveiliop56/ding" "github.com/steveiliop56/ding"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/repository"
@@ -108,7 +107,6 @@ type TokenResponse struct {
} }
type AuthorizeRequest struct { type AuthorizeRequest struct {
jwt.Claims
Scope string `form:"scope" json:"scope" url:"scope"` Scope string `form:"scope" json:"scope" url:"scope"`
ResponseType string `form:"response_type" json:"response_type" url:"response_type"` ResponseType string `form:"response_type" json:"response_type" url:"response_type"`
ClientID string `form:"client_id" json:"client_id" url:"client_id"` ClientID string `form:"client_id" json:"client_id" url:"client_id"`
@@ -135,8 +133,8 @@ type UsedCodeEntry struct {
type OIDCService struct { type OIDCService struct {
log *logger.Logger log *logger.Logger
config model.Config config *model.Config
runtime model.RuntimeConfig runtime *model.RuntimeConfig
queries repository.Store queries repository.Store
clients map[string]model.OIDCClientConfig clients map[string]model.OIDCClientConfig
@@ -152,18 +150,15 @@ type OIDCService struct {
} }
func NewOIDCService( func NewOIDCService(
log *logger.Logger, deps *ServiceDependencies,
config model.Config, ) (*OIDCService, error) {
runtime model.RuntimeConfig,
queries repository.Store,
dg *ding.Ding) (*OIDCService, error) {
// If not configured, skip init // If not configured, skip init
if len(runtime.OIDCClients) == 0 { if len(deps.RuntimeConfig.OIDCClients) == 0 {
return nil, nil return nil, nil
} }
// Ensure issuer is https // Ensure issuer is https
uissuer, err := url.Parse(runtime.AppURL) uissuer, err := url.Parse(deps.RuntimeConfig.AppURL)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to parse app url: %w", err) return nil, fmt.Errorf("failed to parse app url: %w", err)
@@ -176,14 +171,14 @@ func NewOIDCService(
issuer := fmt.Sprintf("%s://%s", uissuer.Scheme, uissuer.Host) issuer := fmt.Sprintf("%s://%s", uissuer.Scheme, uissuer.Host)
// Create/load private and public keys // Create/load private and public keys
if strings.TrimSpace(config.OIDC.PrivateKeyPath) == "" || if strings.TrimSpace(deps.StaticConfig.OIDC.PrivateKeyPath) == "" ||
strings.TrimSpace(config.OIDC.PublicKeyPath) == "" { strings.TrimSpace(deps.StaticConfig.OIDC.PublicKeyPath) == "" {
return nil, errors.New("private key path and public key path are required") return nil, errors.New("private key path and public key path are required")
} }
var privateKey *rsa.PrivateKey var privateKey *rsa.PrivateKey
fprivateKey, err := os.ReadFile(config.OIDC.PrivateKeyPath) fprivateKey, err := os.ReadFile(deps.StaticConfig.OIDC.PrivateKeyPath)
if err != nil && !errors.Is(err, os.ErrNotExist) { if err != nil && !errors.Is(err, os.ErrNotExist) {
return nil, err return nil, err
@@ -202,8 +197,8 @@ func NewOIDCService(
Type: "RSA PRIVATE KEY", Type: "RSA PRIVATE KEY",
Bytes: der, Bytes: der,
}) })
log.App.Trace().Str("type", "RSA PRIVATE KEY").Msg("Generated private RSA key") deps.Log.App.Trace().Str("type", "RSA PRIVATE KEY").Msg("Generated private RSA key")
err = os.WriteFile(config.OIDC.PrivateKeyPath, encoded, 0600) err = os.WriteFile(deps.StaticConfig.OIDC.PrivateKeyPath, encoded, 0600)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to write private key to file: %w", err) return nil, fmt.Errorf("failed to write private key to file: %w", err)
} }
@@ -212,7 +207,7 @@ func NewOIDCService(
if block == nil { if block == nil {
return nil, errors.New("failed to decode private key") return nil, errors.New("failed to decode private key")
} }
log.App.Trace().Str("type", block.Type).Msg("Loaded private key") deps.Log.App.Trace().Str("type", block.Type).Msg("Loaded private key")
privateKey, err = x509.ParsePKCS1PrivateKey(block.Bytes) privateKey, err = x509.ParsePKCS1PrivateKey(block.Bytes)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to parse private key: %w", err) return nil, fmt.Errorf("failed to parse private key: %w", err)
@@ -221,7 +216,7 @@ func NewOIDCService(
var publicKey crypto.PublicKey var publicKey crypto.PublicKey
fpublicKey, err := os.ReadFile(config.OIDC.PublicKeyPath) fpublicKey, err := os.ReadFile(deps.StaticConfig.OIDC.PublicKeyPath)
if err != nil && !errors.Is(err, os.ErrNotExist) { if err != nil && !errors.Is(err, os.ErrNotExist) {
return nil, fmt.Errorf("failed to read public key: %w", err) return nil, fmt.Errorf("failed to read public key: %w", err)
@@ -237,8 +232,8 @@ func NewOIDCService(
Type: "RSA PUBLIC KEY", Type: "RSA PUBLIC KEY",
Bytes: der, Bytes: der,
}) })
log.App.Trace().Str("type", "RSA PUBLIC KEY").Msg("Generated public RSA key") deps.Log.App.Trace().Str("type", "RSA PUBLIC KEY").Msg("Generated public RSA key")
err = os.WriteFile(config.OIDC.PublicKeyPath, encoded, 0644) err = os.WriteFile(deps.StaticConfig.OIDC.PublicKeyPath, encoded, 0644)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -247,7 +242,7 @@ func NewOIDCService(
if block == nil { if block == nil {
return nil, errors.New("failed to decode public key") return nil, errors.New("failed to decode public key")
} }
log.App.Trace().Str("type", block.Type).Msg("Loaded public key") deps.Log.App.Trace().Str("type", block.Type).Msg("Loaded public key")
switch block.Type { switch block.Type {
case "RSA PUBLIC KEY": case "RSA PUBLIC KEY":
publicKey, err = x509.ParsePKCS1PublicKey(block.Bytes) publicKey, err = x509.ParsePKCS1PublicKey(block.Bytes)
@@ -277,7 +272,7 @@ func NewOIDCService(
// We will reorganize the client into a map with the client ID as the key // We will reorganize the client into a map with the client ID as the key
clients := make(map[string]model.OIDCClientConfig) clients := make(map[string]model.OIDCClientConfig)
for id, client := range config.OIDC.Clients { for id, client := range deps.StaticConfig.OIDC.Clients {
client.ID = id client.ID = id
if client.Name == "" { if client.Name == "" {
client.Name = utils.Capitalize(client.ID) client.Name = utils.Capitalize(client.ID)
@@ -293,15 +288,15 @@ func NewOIDCService(
} }
client.ClientSecretFile = "" client.ClientSecretFile = ""
clients[id] = client clients[id] = client
log.App.Debug().Str("clientId", client.ClientID).Msg("Loaded OIDC client configuration") deps.Log.App.Debug().Str("clientId", client.ClientID).Msg("Loaded OIDC client configuration")
} }
// Initialize the service // Initialize the service
service := &OIDCService{ service := &OIDCService{
log: log, log: deps.Log,
config: config, config: deps.StaticConfig,
runtime: runtime, runtime: deps.RuntimeConfig,
queries: queries, queries: *deps.Queries,
clients: clients, clients: clients,
privateKey: privateKey, privateKey: privateKey,
@@ -310,7 +305,7 @@ func NewOIDCService(
} }
// Start cleanup routine // Start cleanup routine
dg.Go(service.cleanupRoutine, ding.RingMinor) deps.Ding.Go(service.cleanupRoutine, ding.RingMinor)
// Create caches // Create caches
codeCash := NewCacheStore[AuthorizeCodeEntry](256) codeCash := NewCacheStore[AuthorizeCodeEntry](256)
@@ -322,7 +317,7 @@ func NewOIDCService(
service.caches.authorize = authorize service.caches.authorize = authorize
// Start cache cleanup routine // Start cache cleanup routine
dg.Go(func(ctx context.Context) { deps.Ding.Go(func(ctx context.Context) {
ticker := time.NewTicker(1 * time.Minute) ticker := time.NewTicker(1 * time.Minute)
defer ticker.Stop() defer ticker.Stop()
@@ -889,63 +884,32 @@ func (service *OIDCService) DeleteAuthorizeRequestTicket(ticket string) {
// TODO: support signed request objects in the future // TODO: support signed request objects in the future
func (service *OIDCService) DecodeAuthorizeJWT(tokenString string) (*AuthorizeRequest, error) { func (service *OIDCService) DecodeAuthorizeJWT(tokenString string) (*AuthorizeRequest, error) {
var req AuthorizeRequest var claims jwt.MapClaims
token, _, err := jwt.NewParser().ParseUnverified(tokenString, &req)
token, _, err := jwt.NewParser().ParseUnverified(tokenString, &claims)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to parse authorize request jwt: %w", err) return nil, fmt.Errorf("failed to parse authorize request jwt: %w", err)
} }
claims, ok := token.Claims.(*AuthorizeRequest) alg, ok := token.Header["alg"].(string)
if !ok { if !ok || alg != "none" || string(token.Signature) != "" {
return nil, errors.New("failed to parse claims from authorize request jwt") return nil, fmt.Errorf("only unsigned jwts are supported for authorize requests")
} }
return claims, nil get := func(k string) string {
} v, _ := claims[k].(string)
return v
func (service *OIDCService) CreateConsentEntry(ctx context.Context, clientId string, scope string) (string, error) {
u := uuid.New()
entry := repository.CreateOIDCConsentParams{
UUID: u.String(),
ClientID: clientId,
Scopes: scope,
} }
_, err := service.queries.CreateOIDCConsent(ctx, entry) return &AuthorizeRequest{
Scope: get("scope"),
if err != nil { ResponseType: get("response_type"),
return "", err ClientID: get("client_id"),
} RedirectURI: get("redirect_uri"),
State: get("state"),
return entry.UUID, nil Nonce: get("nonce"),
} CodeChallenge: get("code_challenge"),
CodeChallengeMethod: get("code_challenge_method"),
func (service *OIDCService) GetConsentEntry(ctx context.Context, uuid string) (*repository.OidcConsent, error) { }, nil
entry, err := service.queries.GetOIDCConsentByUUID(ctx, uuid)
if err != nil {
if errors.Is(err, repository.ErrNotFound) {
return nil, nil
}
return nil, err
}
return &entry, nil
}
func (service *OIDCService) DeleteConsentEntry(ctx context.Context, uuid string) error {
return service.queries.DeleteOIDCConsentByUUID(ctx, uuid)
}
func (service *OIDCService) UpdateConsentEntry(ctx context.Context, uuid string, scopes string) error {
_, err := service.queries.UpdateOIDCConsent(ctx, repository.UpdateOIDCConsentParams{
UUID: uuid,
Scopes: scopes,
})
return err
} }
+8 -6
View File
@@ -40,21 +40,23 @@ type PolicyEngine struct {
policy Policy policy Policy
} }
func NewPolicyEngine(config model.Config, log *logger.Logger) (*PolicyEngine, error) { func NewPolicyEngine(
deps *ServiceDependencies,
) (*PolicyEngine, error) {
engine := PolicyEngine{ engine := PolicyEngine{
log: log, log: deps.Log,
rules: make(map[RuleName]Rule), rules: make(map[RuleName]Rule),
} }
switch config.Auth.ACLs.Policy { switch deps.StaticConfig.Auth.ACLs.Policy {
case string(PolicyAllow): case string(PolicyAllow):
log.App.Debug().Msg("Using 'allow' ACL policy: access to apps will be allowed by default unless explicitly blocked") deps.Log.App.Debug().Msg("Using 'allow' ACL policy: access to apps will be allowed by default unless explicitly blocked")
engine.policy = PolicyAllow engine.policy = PolicyAllow
case string(PolicyDeny): case string(PolicyDeny):
log.App.Debug().Msg("Using 'deny' ACL policy: access to apps will be blocked by default unless explicitly allowed") deps.Log.App.Debug().Msg("Using 'deny' ACL policy: access to apps will be blocked by default unless explicitly allowed")
engine.policy = PolicyDeny engine.policy = PolicyDeny
default: default:
return nil, fmt.Errorf("invalid acl policy: %s", config.Auth.ACLs.Policy) return nil, fmt.Errorf("invalid acl policy: %s", deps.StaticConfig.Auth.ACLs.Policy)
} }
return &engine, nil return &engine, nil
+33
View File
@@ -0,0 +1,33 @@
package service
import (
"context"
"github.com/steveiliop56/ding"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
)
type Services struct {
AccessControlService *AccessControlsService
AuthService *AuthService
DockerService *DockerService
KubernetesService *KubernetesService
LDAPService *LdapService
OAuthBrokerService *OAuthBrokerService
OIDCService *OIDCService
TailscaleService *TailscaleService
PolicyEngine *PolicyEngine
}
type ServiceDependencies struct {
Log *logger.Logger
StaticConfig *model.Config
RuntimeConfig *model.RuntimeConfig
Ctx context.Context
Ding *ding.Ding
Services *Services
LabelProvider LabelProvider
Queries *repository.Store
}
+16 -14
View File
@@ -25,7 +25,7 @@ type TailscaleWhoisResponse struct {
type TailscaleService struct { type TailscaleService struct {
log *logger.Logger log *logger.Logger
config model.Config config *model.Config
ctx context.Context ctx context.Context
srv *tsnet.Server srv *tsnet.Server
@@ -34,22 +34,24 @@ type TailscaleService struct {
mu sync.Mutex mu sync.Mutex
} }
func NewTailscaleService(log *logger.Logger, config model.Config, ctx context.Context, dg *ding.Ding) (*TailscaleService, error) { func NewTailscaleService(
if !config.Tailscale.Enabled { deps *ServiceDependencies,
) (*TailscaleService, error) {
if !deps.StaticConfig.Tailscale.Enabled {
return nil, nil return nil, nil
} }
srv := new(tsnet.Server) srv := new(tsnet.Server)
// node options // node options
srv.Dir = config.Tailscale.Dir srv.Dir = deps.StaticConfig.Tailscale.Dir
srv.Hostname = config.Tailscale.Hostname srv.Hostname = deps.StaticConfig.Tailscale.Hostname
srv.AuthKey = config.Tailscale.AuthKey srv.AuthKey = deps.StaticConfig.Tailscale.AuthKey
srv.Ephemeral = config.Tailscale.Ephemeral srv.Ephemeral = deps.StaticConfig.Tailscale.Ephemeral
// redirect logs to zerolog // redirect logs to zerolog
srv.Logf = log.App.Printf srv.Logf = deps.Log.App.Printf
srv.UserLogf = log.App.Printf srv.UserLogf = deps.Log.App.Printf
err := srv.Start() err := srv.Start()
@@ -65,14 +67,14 @@ func NewTailscaleService(log *logger.Logger, config model.Config, ctx context.Co
} }
service := &TailscaleService{ service := &TailscaleService{
log: log, log: deps.Log,
config: config, config: deps.StaticConfig,
ctx: ctx, ctx: deps.Ctx,
srv: srv, srv: srv,
lc: lc, lc: lc,
} }
connectCtx, cancel := context.WithTimeout(ctx, 2*time.Minute) // large enough timeout to allow for user to manually authenticate with link if needed connectCtx, cancel := context.WithTimeout(deps.Ctx, 2*time.Minute) // large enough timeout to allow for user to manually authenticate with link if needed
defer cancel() defer cancel()
err = service.waitForConn(connectCtx) err = service.waitForConn(connectCtx)
@@ -82,7 +84,7 @@ func NewTailscaleService(log *logger.Logger, config model.Config, ctx context.Co
return nil, fmt.Errorf("failed to connect to tailscale network: %w", err) return nil, fmt.Errorf("failed to connect to tailscale network: %w", err)
} }
dg.Go(service.watchAndClose, ding.RingMajor) deps.Ding.Go(service.watchAndClose, ding.RingMajor)
return service, nil return service, nil
} }
-9
View File
@@ -1,7 +1,6 @@
package test package test
import ( import (
"context"
"path/filepath" "path/filepath"
"testing" "testing"
@@ -134,11 +133,3 @@ func CreateTestConfigs(t *testing.T) (model.Config, model.RuntimeConfig) {
return config, runtime return config, runtime
} }
func CreateTestHelpers() model.RuntimeHelpers {
return model.RuntimeHelpers{
GetCookieDomain: func(ctx context.Context, ip string) (string, error) {
return "example.com", nil
},
}
}
-25
View File
@@ -46,28 +46,3 @@ UPDATE "oidc_sessions" SET
"userinfo_json" = $8 "userinfo_json" = $8
WHERE "sub" = $9 WHERE "sub" = $9
RETURNING *; RETURNING *;
-- name: CreateOIDCConsent :one
INSERT INTO "oidc_consent" (
"uuid",
"client_id",
"scopes"
) VALUES (
$1, $2, $3
)
RETURNING *;
-- name: GetOIDCConsentByUUID :one
SELECT * FROM "oidc_consent"
WHERE "uuid" = $1;
-- name: UpdateOIDCConsent :one
UPDATE "oidc_consent" SET
"scopes" = $1,
"updated_at" = CURRENT_TIMESTAMP
WHERE "uuid" = $2
RETURNING *;
-- name: DeleteOIDCConsentByUUID :exec
DELETE FROM "oidc_consent"
WHERE "uuid" = $1;
-8
View File
@@ -9,11 +9,3 @@ CREATE TABLE IF NOT EXISTS "oidc_sessions" (
"nonce" TEXT NOT NULL DEFAULT '', "nonce" TEXT NOT NULL DEFAULT '',
"userinfo_json" TEXT NOT NULL "userinfo_json" TEXT NOT NULL
); );
CREATE TABLE IF NOT EXISTS "oidc_consent" (
"uuid" TEXT NOT NULL UNIQUE PRIMARY KEY,
"client_id" TEXT NOT NULL,
"scopes" TEXT NOT NULL,
"created_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updated_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
);
-25
View File
@@ -46,28 +46,3 @@ UPDATE "oidc_sessions" SET
"userinfo_json" = ? "userinfo_json" = ?
WHERE "sub" = ? WHERE "sub" = ?
RETURNING *; RETURNING *;
-- name: CreateOIDCConsent :one
INSERT INTO "oidc_consent" (
"uuid",
"client_id",
"scopes"
) VALUES (
?, ?, ?
)
RETURNING *;
-- name: GetOIDCConsentByUUID :one
SELECT * FROM "oidc_consent"
WHERE "uuid" = ?;
-- name: UpdateOIDCConsent :one
UPDATE "oidc_consent" SET
"scopes" = ?,
"updated_at" = CURRENT_TIMESTAMP
WHERE "uuid" = ?
RETURNING *;
-- name: DeleteOIDCConsentByUUID :exec
DELETE FROM "oidc_consent"
WHERE "uuid" = ?;
-8
View File
@@ -9,11 +9,3 @@ CREATE TABLE IF NOT EXISTS "oidc_sessions" (
"nonce" TEXT NOT NULL DEFAULT "", "nonce" TEXT NOT NULL DEFAULT "",
"userinfo_json" TEXT NOT NULL "userinfo_json" TEXT NOT NULL
); );
CREATE TABLE IF NOT EXISTS "oidc_consent" (
"uuid" TEXT NOT NULL UNIQUE PRIMARY KEY,
"client_id" TEXT NOT NULL,
"scopes" TEXT NOT NULL,
"created_at" DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updated_at" DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
);