mirror of
https://github.com/steveiliop56/tinyauth.git
synced 2026-07-02 08:10:15 +00:00
Merge branch 'main' into feat/ldap-reconnect
This commit is contained in:
@@ -5,6 +5,7 @@ import (
|
||||
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||
"go.uber.org/dig"
|
||||
)
|
||||
|
||||
type LabelProvider interface {
|
||||
@@ -13,19 +14,24 @@ type LabelProvider interface {
|
||||
|
||||
type AccessControlsService struct {
|
||||
log *logger.Logger
|
||||
config model.Config
|
||||
labelProvider *LabelProvider
|
||||
config *model.Config
|
||||
labelProvider LabelProvider
|
||||
}
|
||||
|
||||
func NewAccessControlsService(
|
||||
log *logger.Logger,
|
||||
config model.Config,
|
||||
labelProvider *LabelProvider) *AccessControlsService {
|
||||
type AccessControlServiceInput struct {
|
||||
dig.In
|
||||
|
||||
Log *logger.Logger
|
||||
Config *model.Config
|
||||
LabelProvider LabelProvider `optional:"true"`
|
||||
}
|
||||
|
||||
func NewAccessControlsService(i AccessControlServiceInput) *AccessControlsService {
|
||||
|
||||
return &AccessControlsService{
|
||||
log: log,
|
||||
config: config,
|
||||
labelProvider: labelProvider,
|
||||
log: i.Log,
|
||||
config: i.Config,
|
||||
labelProvider: i.LabelProvider,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -57,8 +63,8 @@ func (service *AccessControlsService) GetAccessControls(domain string) (*model.A
|
||||
}
|
||||
|
||||
// If we have a label provider configured, try to get ACLs from it
|
||||
if service.labelProvider != nil && *service.labelProvider != nil {
|
||||
return (*service.labelProvider).GetLabels(domain)
|
||||
if service.labelProvider != nil {
|
||||
return service.labelProvider.GetLabels(domain)
|
||||
}
|
||||
|
||||
// no labels
|
||||
|
||||
@@ -87,7 +87,11 @@ func TestLookupStaticACLs(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
svc := NewAccessControlsService(log, model.Config{Apps: tt.apps}, nil)
|
||||
svc := NewAccessControlsService(AccessControlServiceInput{
|
||||
Log: log,
|
||||
Config: &model.Config{Apps: tt.apps},
|
||||
LabelProvider: nil,
|
||||
})
|
||||
got := svc.lookupStaticACLs(tt.domain)
|
||||
if tt.expectNil {
|
||||
assert.Nil(t, got)
|
||||
@@ -112,7 +116,11 @@ func TestGetAccessControls(t *testing.T) {
|
||||
},
|
||||
},
|
||||
}
|
||||
svc := NewAccessControlsService(log, config, nil)
|
||||
svc := NewAccessControlsService(AccessControlServiceInput{
|
||||
Log: log,
|
||||
Config: &config,
|
||||
LabelProvider: nil,
|
||||
})
|
||||
|
||||
got, err := svc.GetAccessControls("foo.example.com")
|
||||
|
||||
@@ -123,7 +131,11 @@ func TestGetAccessControls(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("returns nil when no static match and no label provider", func(t *testing.T) {
|
||||
svc := NewAccessControlsService(log, model.Config{}, nil)
|
||||
svc := NewAccessControlsService(AccessControlServiceInput{
|
||||
Log: log,
|
||||
Config: &model.Config{},
|
||||
LabelProvider: nil,
|
||||
})
|
||||
|
||||
got, err := svc.GetAccessControls("unknown.example.com")
|
||||
|
||||
@@ -133,7 +145,11 @@ func TestGetAccessControls(t *testing.T) {
|
||||
|
||||
t.Run("returns nil when label provider pointer wraps a nil interface", func(t *testing.T) {
|
||||
var provider LabelProvider
|
||||
svc := NewAccessControlsService(log, model.Config{}, &provider)
|
||||
svc := NewAccessControlsService(AccessControlServiceInput{
|
||||
Log: log,
|
||||
Config: &model.Config{},
|
||||
LabelProvider: provider, // nil provider
|
||||
})
|
||||
|
||||
got, err := svc.GetAccessControls("unknown.example.com")
|
||||
|
||||
@@ -152,7 +168,11 @@ func TestGetAccessControls(t *testing.T) {
|
||||
},
|
||||
}
|
||||
var provider LabelProvider = mock
|
||||
svc := NewAccessControlsService(log, model.Config{}, &provider)
|
||||
svc := NewAccessControlsService(AccessControlServiceInput{
|
||||
Log: log,
|
||||
Config: &model.Config{},
|
||||
LabelProvider: provider,
|
||||
})
|
||||
|
||||
got, err := svc.GetAccessControls("dynamic.example.com")
|
||||
|
||||
@@ -170,7 +190,11 @@ func TestGetAccessControls(t *testing.T) {
|
||||
"foo": {Config: model.AppConfig{Domain: "foo.example.com"}},
|
||||
},
|
||||
}
|
||||
svc := NewAccessControlsService(log, config, &provider)
|
||||
svc := NewAccessControlsService(AccessControlServiceInput{
|
||||
Log: log,
|
||||
Config: &config,
|
||||
LabelProvider: provider,
|
||||
})
|
||||
|
||||
got, err := svc.GetAccessControls("foo.example.com")
|
||||
|
||||
@@ -188,7 +212,11 @@ func TestGetAccessControls(t *testing.T) {
|
||||
},
|
||||
}
|
||||
var provider LabelProvider = mock
|
||||
svc := NewAccessControlsService(log, model.Config{}, &provider)
|
||||
svc := NewAccessControlsService(AccessControlServiceInput{
|
||||
Log: log,
|
||||
Config: &model.Config{},
|
||||
LabelProvider: provider,
|
||||
})
|
||||
|
||||
got, err := svc.GetAccessControls("dynamic.example.com")
|
||||
|
||||
|
||||
@@ -2,8 +2,10 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -14,6 +16,7 @@ import (
|
||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||
"go.uber.org/dig"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
@@ -24,32 +27,28 @@ import (
|
||||
// but for now these are just safety limits to prevent unbounded memory usage
|
||||
const MaxOAuthPendingSessions = 256
|
||||
const OAuthCleanupCount = 16
|
||||
const MaxLoginAttemptRecords = 256
|
||||
|
||||
var (
|
||||
ErrUserNotFound = errors.New("user not found")
|
||||
)
|
||||
|
||||
// slightly modified version of the AuthorizeRequest from the OIDC service to basically accept all
|
||||
// parameters and pass them to the authorize page if needed
|
||||
type OAuthURLParams struct {
|
||||
Scope string `form:"scope" url:"scope"`
|
||||
ResponseType string `form:"response_type" url:"response_type"`
|
||||
ClientID string `form:"client_id" url:"client_id"`
|
||||
RedirectURI string `form:"redirect_uri" url:"redirect_uri"`
|
||||
State string `form:"state" url:"state"`
|
||||
Nonce string `form:"nonce" url:"nonce"`
|
||||
CodeChallenge string `form:"code_challenge" url:"code_challenge"`
|
||||
CodeChallengeMethod string `form:"code_challenge_method" url:"code_challenge_method"`
|
||||
// We either store params for redirecting to an app after OAuth login,
|
||||
// or for redirecting back to the authorize screen to continue OIDC
|
||||
type OAuthCallbackParams struct {
|
||||
LoginFor string `form:"login_for" url:"login_for"`
|
||||
OIDCTicket string `form:"oidc_ticket" url:"oidc_ticket"`
|
||||
OIDCScope string `form:"oidc_scope" url:"oidc_scope"`
|
||||
OIDCName string `form:"oidc_name" url:"oidc_name"`
|
||||
RedirectURI string `form:"redirect_uri" url:"redirect_uri"`
|
||||
}
|
||||
|
||||
type OAuthPendingSession struct {
|
||||
State string
|
||||
Verifier string
|
||||
Token *oauth2.Token
|
||||
Service *OAuthServiceImpl
|
||||
Service IOAuthService
|
||||
ExpiresAt time.Time
|
||||
CallbackParams OAuthURLParams
|
||||
CallbackParams OAuthCallbackParams
|
||||
}
|
||||
|
||||
type LoginAttempt struct {
|
||||
@@ -60,8 +59,8 @@ type LoginAttempt struct {
|
||||
|
||||
type AuthService struct {
|
||||
log *logger.Logger
|
||||
config model.Config
|
||||
runtime model.RuntimeConfig
|
||||
config *model.Config
|
||||
runtime *model.RuntimeConfig
|
||||
ctx context.Context
|
||||
|
||||
ldap *LdapService
|
||||
@@ -83,42 +82,57 @@ type AuthService struct {
|
||||
oauth *CacheStore[OAuthPendingSession]
|
||||
ldap *CacheStore[[]string]
|
||||
}
|
||||
|
||||
maxLoginLimits int
|
||||
}
|
||||
|
||||
func NewAuthService(
|
||||
log *logger.Logger,
|
||||
config model.Config,
|
||||
runtime model.RuntimeConfig,
|
||||
ctx context.Context,
|
||||
dg *ding.Ding,
|
||||
ldap *LdapService,
|
||||
queries repository.Store,
|
||||
oauthBroker *OAuthBrokerService,
|
||||
tailscale *TailscaleService,
|
||||
policy *PolicyEngine,
|
||||
) *AuthService {
|
||||
type AuthServiceInput struct {
|
||||
dig.In
|
||||
|
||||
Log *logger.Logger
|
||||
Config *model.Config
|
||||
Runtime *model.RuntimeConfig
|
||||
Ctx context.Context
|
||||
Ding *ding.Ding
|
||||
LDAP *LdapService `optional:"true"`
|
||||
Queries repository.Store
|
||||
OAuthBroker *OAuthBrokerService
|
||||
Tailscale *TailscaleService `optional:"true"`
|
||||
PolicyEngine *PolicyEngine
|
||||
}
|
||||
|
||||
func NewAuthService(i AuthServiceInput) *AuthService {
|
||||
service := &AuthService{
|
||||
log: log,
|
||||
runtime: runtime,
|
||||
ctx: ctx,
|
||||
config: config,
|
||||
ldap: ldap,
|
||||
queries: queries,
|
||||
oauthBroker: oauthBroker,
|
||||
tailscale: tailscale,
|
||||
policyEngine: policy,
|
||||
log: i.Log,
|
||||
runtime: i.Runtime,
|
||||
ctx: i.Ctx,
|
||||
config: i.Config,
|
||||
ldap: i.LDAP,
|
||||
queries: i.Queries,
|
||||
oauthBroker: i.OAuthBroker,
|
||||
tailscale: i.Tailscale,
|
||||
policyEngine: i.PolicyEngine,
|
||||
}
|
||||
|
||||
// get the max login limits based on the number of users and the configured max retries
|
||||
service.maxLoginLimits = service.calculateLockdownLimit()
|
||||
|
||||
loginCacheSize := 0
|
||||
|
||||
if !service.config.Auth.LockdownEnabled {
|
||||
loginCacheSize = service.maxLoginLimits
|
||||
}
|
||||
|
||||
// caches setup
|
||||
oauthCache := NewCacheStore[OAuthPendingSession](256)
|
||||
loginCache := NewCacheStore[LoginAttempt](1024)
|
||||
loginCache := NewCacheStore[LoginAttempt](loginCacheSize)
|
||||
ldapCache := NewCacheStore[[]string](1024)
|
||||
|
||||
service.caches.oauth = oauthCache
|
||||
service.caches.login = loginCache
|
||||
service.caches.ldap = ldapCache
|
||||
|
||||
dg.Go(func(ctx context.Context) {
|
||||
i.Ding.Go(func(ctx context.Context) {
|
||||
ticker := time.NewTicker(1 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
@@ -257,7 +271,7 @@ func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) {
|
||||
return
|
||||
}
|
||||
|
||||
if auth.caches.login.Size() >= MaxLoginAttemptRecords {
|
||||
if !success && auth.config.Auth.LockdownEnabled && auth.caches.login.Size() >= auth.maxLoginLimits {
|
||||
if locked, _ := auth.IsInLockdown(); locked {
|
||||
return
|
||||
}
|
||||
@@ -366,33 +380,11 @@ func (auth *AuthService) CreateSession(ctx context.Context, data repository.Sess
|
||||
return nil, fmt.Errorf("failed to create session entry: %w", err)
|
||||
}
|
||||
|
||||
if data.Provider == "tailscale" {
|
||||
auth.log.App.Trace().Str("url", fmt.Sprintf("https://%s", auth.tailscale.GetHostname())).Msg("Extracting root domain from Tailscale hostname")
|
||||
|
||||
tsCookieDomain, err := utils.GetCookieDomain(fmt.Sprintf("https://%s", auth.tailscale.GetHostname()))
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get cookie domain for tailscale user: %w", err)
|
||||
}
|
||||
|
||||
return &http.Cookie{
|
||||
Name: auth.runtime.SessionCookieName,
|
||||
Value: session.UUID,
|
||||
Path: "/",
|
||||
Domain: fmt.Sprintf(".%s", tsCookieDomain),
|
||||
Expires: expiresAt,
|
||||
MaxAge: int(time.Until(expiresAt).Seconds()),
|
||||
Secure: auth.config.Auth.SecureCookie,
|
||||
HttpOnly: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
}, nil
|
||||
}
|
||||
|
||||
return &http.Cookie{
|
||||
Name: auth.runtime.SessionCookieName,
|
||||
Value: session.UUID,
|
||||
Path: "/",
|
||||
Domain: fmt.Sprintf(".%s", auth.runtime.CookieDomain),
|
||||
Domain: auth.getCookieDomain(),
|
||||
Expires: expiresAt,
|
||||
MaxAge: int(time.Until(expiresAt).Seconds()),
|
||||
Secure: auth.config.Auth.SecureCookie,
|
||||
@@ -445,7 +437,7 @@ func (auth *AuthService) RefreshSession(ctx context.Context, uuid string) (*http
|
||||
Name: auth.runtime.SessionCookieName,
|
||||
Value: session.UUID,
|
||||
Path: "/",
|
||||
Domain: fmt.Sprintf(".%s", auth.runtime.CookieDomain),
|
||||
Domain: auth.getCookieDomain(),
|
||||
Expires: time.Now().Add(time.Duration(newExpiry-currentTime) * time.Second),
|
||||
MaxAge: int(newExpiry - currentTime),
|
||||
Secure: auth.config.Auth.SecureCookie,
|
||||
@@ -466,7 +458,7 @@ func (auth *AuthService) DeleteSession(ctx context.Context, uuid string) (*http.
|
||||
Name: auth.runtime.SessionCookieName,
|
||||
Value: "",
|
||||
Path: "/",
|
||||
Domain: fmt.Sprintf(".%s", auth.runtime.CookieDomain),
|
||||
Domain: auth.getCookieDomain(),
|
||||
Expires: time.Now(),
|
||||
MaxAge: -1,
|
||||
Secure: auth.config.Auth.SecureCookie,
|
||||
@@ -516,17 +508,17 @@ func (auth *AuthService) LDAPAuthConfigured() bool {
|
||||
return auth.ldap != nil
|
||||
}
|
||||
|
||||
func (auth *AuthService) NewOAuthSession(serviceName string, params OAuthURLParams) (string, OAuthPendingSession, error) {
|
||||
func (auth *AuthService) NewOAuthSession(serviceName string, params OAuthCallbackParams) (string, error) {
|
||||
service, ok := auth.oauthBroker.GetService(serviceName)
|
||||
|
||||
if !ok {
|
||||
return "", OAuthPendingSession{}, fmt.Errorf("oauth service not found: %s", serviceName)
|
||||
return "", fmt.Errorf("oauth service not found: %s", serviceName)
|
||||
}
|
||||
|
||||
sessionId, err := uuid.NewRandom()
|
||||
|
||||
if err != nil {
|
||||
return "", OAuthPendingSession{}, fmt.Errorf("failed to generate session ID: %w", err)
|
||||
return "", fmt.Errorf("failed to generate session ID: %w", err)
|
||||
}
|
||||
|
||||
state := service.NewRandom()
|
||||
@@ -535,14 +527,14 @@ func (auth *AuthService) NewOAuthSession(serviceName string, params OAuthURLPara
|
||||
session := OAuthPendingSession{
|
||||
State: state,
|
||||
Verifier: verifier,
|
||||
Service: &service,
|
||||
Service: service,
|
||||
ExpiresAt: time.Now().Add(1 * time.Hour),
|
||||
CallbackParams: params,
|
||||
}
|
||||
|
||||
auth.caches.oauth.Set(sessionId.String(), session, time.Minute*10)
|
||||
|
||||
return sessionId.String(), session, nil
|
||||
return sessionId.String(), nil
|
||||
}
|
||||
|
||||
func (auth *AuthService) GetOAuthURL(sessionId string) (string, error) {
|
||||
@@ -552,7 +544,7 @@ func (auth *AuthService) GetOAuthURL(sessionId string) (string, error) {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return (*session.Service).GetAuthURL(session.State, session.Verifier), nil
|
||||
return session.Service.GetAuthURL(session.State, session.Verifier), nil
|
||||
}
|
||||
|
||||
func (auth *AuthService) GetOAuthToken(sessionId string, code string) (*oauth2.Token, error) {
|
||||
@@ -562,7 +554,7 @@ func (auth *AuthService) GetOAuthToken(sessionId string, code string) (*oauth2.T
|
||||
return nil, fmt.Errorf("oauth session not found: %s", sessionId)
|
||||
}
|
||||
|
||||
token, err := (*session.Service).GetToken(code, session.Verifier)
|
||||
token, err := session.Service.GetToken(code, session.Verifier)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to exchange code for token: %w", err)
|
||||
@@ -591,7 +583,7 @@ func (auth *AuthService) GetOAuthUserinfo(sessionId string) (*model.Claims, erro
|
||||
return nil, fmt.Errorf("oauth token not found for session: %s", sessionId)
|
||||
}
|
||||
|
||||
userinfo, err := (*session.Service).GetUserinfo(session.Token)
|
||||
userinfo, err := session.Service.GetUserinfo(session.Token)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get userinfo: %w", err)
|
||||
@@ -600,14 +592,14 @@ func (auth *AuthService) GetOAuthUserinfo(sessionId string) (*model.Claims, erro
|
||||
return userinfo, nil
|
||||
}
|
||||
|
||||
func (auth *AuthService) GetOAuthService(sessionId string) (OAuthServiceImpl, error) {
|
||||
func (auth *AuthService) GetOAuthService(sessionId string) (IOAuthService, error) {
|
||||
session, err := auth.GetOAuthPendingSession(sessionId)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return *session.Service, nil
|
||||
return session.Service, nil
|
||||
}
|
||||
|
||||
func (auth *AuthService) EndOAuthSession(sessionId string) {
|
||||
@@ -632,16 +624,17 @@ func (auth *AuthService) lockdownMode() {
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
ctx, cancel := context.WithCancel(auth.ctx)
|
||||
|
||||
auth.log.App.Warn().Msg("Too many failed login attempts, entering lockdown mode")
|
||||
|
||||
auth.lockdown.active = true
|
||||
auth.lockdown.ctx = ctx
|
||||
auth.lockdown.cancelFunc = cancel
|
||||
auth.lockdown.until = time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second)
|
||||
|
||||
timer := time.NewTimer(time.Until(auth.lockdown.until))
|
||||
d := time.Duration(auth.config.Auth.LoginTimeout) * time.Second
|
||||
auth.lockdown.until = time.Now().Add(d)
|
||||
timer := time.NewTimer(d)
|
||||
|
||||
auth.lockdown.mu.Unlock()
|
||||
|
||||
@@ -653,14 +646,13 @@ func (auth *AuthService) lockdownMode() {
|
||||
// Timer expired, end lockdown
|
||||
case <-ctx.Done():
|
||||
// Context cancelled, end lockdown
|
||||
case <-auth.ctx.Done():
|
||||
// Service is shutting down, end lockdown
|
||||
}
|
||||
|
||||
auth.lockdown.mu.Lock()
|
||||
|
||||
auth.log.App.Info().Msg("Exiting lockdown mode")
|
||||
|
||||
auth.caches.login.Clear()
|
||||
auth.lockdown.active = false
|
||||
auth.lockdown.until = time.Time{}
|
||||
auth.lockdown.ctx = nil
|
||||
@@ -683,3 +675,39 @@ func (auth *AuthService) IsInLockdown() (bool, int) {
|
||||
func (auth *AuthService) ClearLoginAttempts() {
|
||||
auth.caches.login.Clear()
|
||||
}
|
||||
|
||||
func (auth *AuthService) calculateLockdownLimit() int {
|
||||
userCount := len(auth.runtime.LocalUsers)
|
||||
|
||||
if auth.ldap != nil {
|
||||
ldapUsers, err := auth.ldap.GetUserCount()
|
||||
if err != nil {
|
||||
auth.log.App.Warn().Err(err).Msg("Failed to get LDAP user count")
|
||||
} else {
|
||||
userCount += ldapUsers
|
||||
}
|
||||
}
|
||||
|
||||
limit := userCount * auth.config.Auth.LoginMaxRetries
|
||||
|
||||
jitter, err := rand.Int(rand.Reader, big.NewInt(64))
|
||||
|
||||
if err != nil {
|
||||
auth.log.App.Warn().Err(err).Msg("Failed to generate jitter for lockdown limit")
|
||||
} else {
|
||||
limit += int(jitter.Int64())
|
||||
}
|
||||
|
||||
if limit < 256 {
|
||||
limit = 256
|
||||
}
|
||||
|
||||
return limit
|
||||
}
|
||||
|
||||
func (auth *AuthService) getCookieDomain() string {
|
||||
if !auth.config.Auth.SubdomainsEnabled {
|
||||
return ""
|
||||
}
|
||||
return auth.runtime.CookieDomain
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||
)
|
||||
@@ -12,9 +13,22 @@ func TestIsEmailWhitelistedUsesProviderSpecificList(t *testing.T) {
|
||||
log := logger.NewLogger().WithTestConfig()
|
||||
log.Init()
|
||||
|
||||
policyEngine, err := NewPolicyEngine(PolicyEngineInput{
|
||||
Log: log,
|
||||
Config: &model.Config{
|
||||
Auth: model.AuthConfig{
|
||||
ACLs: model.ACLsConfig{
|
||||
Policy: string(PolicyAllow),
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
auth := &AuthService{
|
||||
log: log,
|
||||
runtime: model.RuntimeConfig{
|
||||
runtime: &model.RuntimeConfig{
|
||||
OAuthWhitelist: []string{"global@example.com"},
|
||||
OAuthProviders: map[string]model.OAuthServiceConfig{
|
||||
"github": {
|
||||
@@ -28,6 +42,7 @@ func TestIsEmailWhitelistedUsesProviderSpecificList(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
policyEngine: policyEngine,
|
||||
}
|
||||
|
||||
assert.True(t, auth.IsEmailWhitelisted("github", "github@example.com"))
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/decoders"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||
"go.uber.org/dig"
|
||||
|
||||
container "github.com/docker/docker/api/types/container"
|
||||
"github.com/docker/docker/client"
|
||||
@@ -21,36 +22,40 @@ type DockerService struct {
|
||||
isConnected bool
|
||||
}
|
||||
|
||||
func NewDockerService(
|
||||
log *logger.Logger,
|
||||
ctx context.Context,
|
||||
dg *ding.Ding,
|
||||
) (*DockerService, error) {
|
||||
type DockerServiceInput struct {
|
||||
dig.In
|
||||
|
||||
Log *logger.Logger
|
||||
Ctx context.Context
|
||||
Ding *ding.Ding
|
||||
}
|
||||
|
||||
func NewDockerService(i DockerServiceInput) (*DockerService, error) {
|
||||
|
||||
client, err := client.NewClientWithOpts(client.FromEnv)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
client.NegotiateAPIVersion(ctx)
|
||||
client.NegotiateAPIVersion(i.Ctx)
|
||||
|
||||
_, err = client.Ping(ctx)
|
||||
_, err = client.Ping(i.Ctx)
|
||||
|
||||
if err != nil {
|
||||
log.App.Debug().Err(err).Msg("Docker not connected")
|
||||
i.Log.App.Debug().Err(err).Msg("Docker not connected")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
service := &DockerService{
|
||||
log: log,
|
||||
log: i.Log,
|
||||
client: client,
|
||||
context: ctx,
|
||||
context: i.Ctx,
|
||||
}
|
||||
|
||||
service.isConnected = true
|
||||
service.log.App.Debug().Msg("Docker connected successfully")
|
||||
|
||||
dg.Go(service.watchAndClose, ding.RingMajor)
|
||||
i.Ding.Go(service.watchAndClose, ding.RingMajor)
|
||||
|
||||
return service, nil
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/decoders"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||
"go.uber.org/dig"
|
||||
|
||||
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
|
||||
"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
|
||||
@@ -48,11 +49,15 @@ type KubernetesService struct {
|
||||
appNameIndex map[string]ingressAppKey
|
||||
}
|
||||
|
||||
func NewKubernetesService(
|
||||
log *logger.Logger,
|
||||
ctx context.Context,
|
||||
dg *ding.Ding,
|
||||
) (*KubernetesService, error) {
|
||||
type KubernetesServiceInput struct {
|
||||
dig.In
|
||||
|
||||
Log *logger.Logger
|
||||
Ctx context.Context
|
||||
Ding *ding.Ding
|
||||
}
|
||||
|
||||
func NewKubernetesService(i KubernetesServiceInput) (*KubernetesService, error) {
|
||||
cfg, err := rest.InClusterConfig()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get in-cluster kubernetes config: %w", err)
|
||||
@@ -69,31 +74,31 @@ func NewKubernetesService(
|
||||
Resource: "ingresses",
|
||||
}
|
||||
|
||||
accessCtx, accessCancel := context.WithTimeout(ctx, 5*time.Second)
|
||||
accessCtx, accessCancel := context.WithTimeout(i.Ctx, 5*time.Second)
|
||||
defer accessCancel()
|
||||
|
||||
_, err = client.Resource(gvr).List(accessCtx, metav1.ListOptions{Limit: 1})
|
||||
if err != nil {
|
||||
log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to access Ingress API, Kubernetes label provider will be disabled")
|
||||
i.Log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to access Ingress API, Kubernetes label provider will be disabled")
|
||||
return nil, fmt.Errorf("failed to access ingress api: %w", err)
|
||||
}
|
||||
|
||||
log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Successfully accessed Ingress API, starting watcher")
|
||||
i.Log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Successfully accessed Ingress API, starting watcher")
|
||||
|
||||
service := &KubernetesService{
|
||||
log: log,
|
||||
log: i.Log,
|
||||
client: client,
|
||||
ingressApps: make(map[ingressKey][]ingressApp),
|
||||
domainIndex: make(map[string]ingressAppKey),
|
||||
appNameIndex: make(map[string]ingressAppKey),
|
||||
}
|
||||
|
||||
dg.Go(func(ctx context.Context) {
|
||||
i.Ding.Go(func(ctx context.Context) {
|
||||
service.watchGVR(gvr, ctx)
|
||||
}, ding.RingMajor)
|
||||
|
||||
service.started = true
|
||||
log.App.Debug().Msg("Kubernetes label provider started successfully")
|
||||
i.Log.App.Debug().Msg("Kubernetes label provider started successfully")
|
||||
|
||||
return service, nil
|
||||
}
|
||||
|
||||
@@ -11,44 +11,53 @@ import (
|
||||
ldapgo "github.com/go-ldap/ldap/v3"
|
||||
"github.com/steveiliop56/ding"
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||
"go.uber.org/dig"
|
||||
)
|
||||
|
||||
type LdapService struct {
|
||||
log *logger.Logger
|
||||
config model.Config
|
||||
ctx context.Context
|
||||
config *model.Config
|
||||
|
||||
conn *ldapgo.Conn
|
||||
mutex sync.RWMutex
|
||||
cert *tls.Certificate
|
||||
conn *ldapgo.Conn
|
||||
mutex sync.RWMutex
|
||||
cert *tls.Certificate
|
||||
bindPw string
|
||||
}
|
||||
|
||||
func NewLdapService(
|
||||
log *logger.Logger,
|
||||
config model.Config,
|
||||
ctx context.Context,
|
||||
dg *ding.Ding,
|
||||
) (*LdapService, error) {
|
||||
if config.LDAP.Address == "" {
|
||||
type LdapServiceInput struct {
|
||||
dig.In
|
||||
|
||||
Log *logger.Logger
|
||||
Config *model.Config
|
||||
Ding *ding.Ding
|
||||
Ctx context.Context
|
||||
}
|
||||
|
||||
func NewLdapService(i LdapServiceInput) (*LdapService, error) {
|
||||
if i.Config.LDAP.Address == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
ldap := &LdapService{
|
||||
log: log,
|
||||
config: config,
|
||||
ctx: ctx,
|
||||
log: i.Log,
|
||||
config: i.Config,
|
||||
ctx: i.Ctx,
|
||||
}
|
||||
|
||||
ldap.bindPw = utils.GetSecret(i.Config.LDAP.BindPassword, i.Config.LDAP.BindPasswordFile)
|
||||
|
||||
// Check whether authentication with client certificate is possible
|
||||
if config.LDAP.AuthCert != "" && config.LDAP.AuthKey != "" {
|
||||
cert, err := tls.LoadX509KeyPair(config.LDAP.AuthCert, config.LDAP.AuthKey)
|
||||
if i.Config.LDAP.AuthCert != "" && i.Config.LDAP.AuthKey != "" {
|
||||
cert, err := tls.LoadX509KeyPair(i.Config.LDAP.AuthCert, i.Config.LDAP.AuthKey)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize LDAP with mTLS authentication: %w", err)
|
||||
}
|
||||
|
||||
log.App.Info().Msg("LDAP mTLS authentication configured successfully")
|
||||
i.Log.App.Info().Msg("LDAP mTLS authentication configured successfully")
|
||||
|
||||
ldap.cert = &cert
|
||||
|
||||
@@ -72,7 +81,7 @@ func NewLdapService(
|
||||
return nil, fmt.Errorf("failed to connect to ldap server: %w", err)
|
||||
}
|
||||
|
||||
dg.Go(func(ctx context.Context) {
|
||||
i.Ding.Go(func(ctx context.Context) {
|
||||
ldap.log.App.Debug().Msg("Starting LDAP connection heartbeat routine")
|
||||
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
@@ -165,6 +174,26 @@ func (ldap *LdapService) GetUserInfo(username string) (dn string, email string,
|
||||
return entry.DN, entry.GetAttributeValue("mail"), nil
|
||||
}
|
||||
|
||||
func (ldap *LdapService) GetUserCount() (int, error) {
|
||||
searchRequest := ldapgo.NewSearchRequest(
|
||||
ldap.config.LDAP.BaseDN,
|
||||
ldapgo.ScopeWholeSubtree, ldapgo.NeverDerefAliases, 0, 0, false,
|
||||
"(objectClass=person)",
|
||||
[]string{"dn"},
|
||||
nil,
|
||||
)
|
||||
|
||||
ldap.mutex.Lock()
|
||||
defer ldap.mutex.Unlock()
|
||||
|
||||
searchResult, err := ldap.conn.Search(searchRequest)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return len(searchResult.Entries), nil
|
||||
}
|
||||
|
||||
func (ldap *LdapService) GetUserGroups(userDN string) ([]string, error) {
|
||||
escapedUserDN := ldapgo.EscapeFilter(userDN)
|
||||
|
||||
@@ -217,7 +246,7 @@ func (ldap *LdapService) BindService(rebind bool) error {
|
||||
if ldap.cert != nil {
|
||||
return ldap.conn.ExternalBind()
|
||||
}
|
||||
return ldap.conn.Bind(ldap.config.LDAP.BindDN, ldap.config.LDAP.BindPassword)
|
||||
return ldap.conn.Bind(ldap.config.LDAP.BindDN, ldap.bindPw)
|
||||
}
|
||||
|
||||
func (ldap *LdapService) Bind(userDN string, password string) error {
|
||||
|
||||
@@ -5,25 +5,28 @@ import (
|
||||
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||
"go.uber.org/dig"
|
||||
|
||||
"slices"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
type OAuthServiceImpl interface {
|
||||
type IOAuthService interface {
|
||||
Name() string
|
||||
ID() string
|
||||
NewRandom() string
|
||||
GetAuthURL(state string, verifier string) string
|
||||
GetToken(code string, verifier string) (*oauth2.Token, error)
|
||||
GetAuthURL(state, verifier string) string
|
||||
GetToken(code, verifier string) (*oauth2.Token, error)
|
||||
GetUserinfo(token *oauth2.Token) (*model.Claims, error)
|
||||
GetConfig() model.OAuthServiceConfig
|
||||
UpdateConfig(config model.OAuthServiceConfig)
|
||||
}
|
||||
|
||||
type OAuthBrokerService struct {
|
||||
log *logger.Logger
|
||||
|
||||
services map[string]OAuthServiceImpl
|
||||
services map[string]IOAuthService
|
||||
configs map[string]model.OAuthServiceConfig
|
||||
}
|
||||
|
||||
@@ -32,23 +35,27 @@ var presets = map[string]func(config model.OAuthServiceConfig, ctx context.Conte
|
||||
"google": newGoogleOAuthService,
|
||||
}
|
||||
|
||||
func NewOAuthBrokerService(
|
||||
log *logger.Logger,
|
||||
configs map[string]model.OAuthServiceConfig,
|
||||
ctx context.Context,
|
||||
) *OAuthBrokerService {
|
||||
type OAuthBrokerServiceInput struct {
|
||||
dig.In
|
||||
|
||||
Log *logger.Logger
|
||||
Runtime *model.RuntimeConfig
|
||||
Ctx context.Context
|
||||
}
|
||||
|
||||
func NewOAuthBrokerService(i OAuthBrokerServiceInput) *OAuthBrokerService {
|
||||
service := &OAuthBrokerService{
|
||||
log: log,
|
||||
services: make(map[string]OAuthServiceImpl),
|
||||
configs: configs,
|
||||
log: i.Log,
|
||||
services: make(map[string]IOAuthService),
|
||||
configs: i.Runtime.OAuthProviders,
|
||||
}
|
||||
|
||||
for name, cfg := range configs {
|
||||
for name, cfg := range service.configs {
|
||||
if presetFunc, exists := presets[name]; exists {
|
||||
service.services[name] = presetFunc(cfg, ctx)
|
||||
service.services[name] = presetFunc(cfg, i.Ctx)
|
||||
service.log.App.Debug().Str("service", name).Msg("Loaded OAuth service from preset")
|
||||
} else {
|
||||
service.services[name] = NewOAuthService(cfg, name, ctx)
|
||||
service.services[name] = NewOAuthService(cfg, name, i.Ctx)
|
||||
service.log.App.Debug().Str("service", name).Msg("Loaded OAuth service from custom config")
|
||||
}
|
||||
}
|
||||
@@ -65,7 +72,7 @@ func (broker *OAuthBrokerService) GetConfiguredServices() []string {
|
||||
return services
|
||||
}
|
||||
|
||||
func (broker *OAuthBrokerService) GetService(name string) (OAuthServiceImpl, bool) {
|
||||
func (broker *OAuthBrokerService) GetService(name string) (IOAuthService, bool) {
|
||||
service, exists := broker.services[name]
|
||||
return service, exists
|
||||
}
|
||||
|
||||
@@ -70,7 +70,7 @@ func (s *OAuthService) NewRandom() string {
|
||||
return random
|
||||
}
|
||||
|
||||
func (s *OAuthService) GetAuthURL(state string, verifier string) string {
|
||||
func (s *OAuthService) GetAuthURL(state, verifier string) string {
|
||||
return s.config.AuthCodeURL(state, oauth2.AccessTypeOnline, oauth2.S256ChallengeOption(verifier))
|
||||
}
|
||||
|
||||
@@ -82,3 +82,17 @@ func (s *OAuthService) GetUserinfo(token *oauth2.Token) (*model.Claims, error) {
|
||||
client := oauth2.NewClient(s.ctx, oauth2.StaticTokenSource(token))
|
||||
return s.userinfoExtractor(client, s.serviceCfg.UserinfoURL)
|
||||
}
|
||||
|
||||
func (s *OAuthService) GetConfig() model.OAuthServiceConfig {
|
||||
return s.serviceCfg
|
||||
}
|
||||
|
||||
func (s *OAuthService) UpdateConfig(config model.OAuthServiceConfig) {
|
||||
s.serviceCfg = config
|
||||
s.config.ClientID = config.ClientID
|
||||
s.config.ClientSecret = config.ClientSecret
|
||||
s.config.Scopes = config.Scopes
|
||||
s.config.Endpoint.AuthURL = config.AuthURL
|
||||
s.config.Endpoint.TokenURL = config.TokenURL
|
||||
s.config.RedirectURL = config.RedirectURL
|
||||
}
|
||||
|
||||
@@ -14,17 +14,20 @@ import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"slices"
|
||||
|
||||
"github.com/go-jose/go-jose/v4"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/steveiliop56/ding"
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||
"go.uber.org/dig"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -41,6 +44,15 @@ var (
|
||||
ErrInvalidClient = errors.New("invalid_client")
|
||||
)
|
||||
|
||||
type OIDCPrompt string
|
||||
|
||||
const (
|
||||
OIDCPromptLogin OIDCPrompt = "login"
|
||||
OIDCPromptNone OIDCPrompt = "none"
|
||||
)
|
||||
|
||||
var SupportedPrompts = []string{string(OIDCPromptLogin), string(OIDCPromptNone)}
|
||||
|
||||
// This is not spec-compliant, the ID token SHOULD NOT contain user info claims but,
|
||||
// it has became a "standard" and apps are looking for the claims in the ID tokens
|
||||
// instead of calling the userinfo endpoint, so we include them in the ID token as well
|
||||
@@ -51,6 +63,7 @@ type ClaimSet struct {
|
||||
Sub string `json:"sub"`
|
||||
Iat int64 `json:"iat"`
|
||||
Exp int64 `json:"exp"`
|
||||
AuthTime int64 `json:"auth_time,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
GivenName string `json:"given_name,omitempty"`
|
||||
FamilyName string `json:"family_name,omitempty"`
|
||||
@@ -106,14 +119,16 @@ type TokenResponse struct {
|
||||
}
|
||||
|
||||
type AuthorizeRequest struct {
|
||||
Scope string `json:"scope" binding:"required"`
|
||||
ResponseType string `json:"response_type" binding:"required"`
|
||||
ClientID string `json:"client_id" binding:"required"`
|
||||
RedirectURI string `json:"redirect_uri" binding:"required"`
|
||||
State string `json:"state"`
|
||||
Nonce string `json:"nonce"`
|
||||
CodeChallenge string `json:"code_challenge"`
|
||||
CodeChallengeMethod string `json:"code_challenge_method"`
|
||||
Scope string `form:"scope" json:"scope" url:"scope"`
|
||||
ResponseType string `form:"response_type" json:"response_type" url:"response_type"`
|
||||
ClientID string `form:"client_id" json:"client_id" url:"client_id"`
|
||||
RedirectURI string `form:"redirect_uri" json:"redirect_uri" url:"redirect_uri"`
|
||||
State string `form:"state" json:"state" url:"state"`
|
||||
Nonce string `form:"nonce" json:"nonce" url:"nonce"`
|
||||
CodeChallenge string `form:"code_challenge" json:"code_challenge" url:"code_challenge"`
|
||||
CodeChallengeMethod string `form:"code_challenge_method" json:"code_challenge_method" url:"code_challenge_method"`
|
||||
Prompt string `form:"prompt" json:"prompt" url:"prompt"`
|
||||
MaxAge string `form:"max_age" json:"max_age" url:"max_age"`
|
||||
}
|
||||
|
||||
type AuthorizeCodeEntry struct {
|
||||
@@ -124,6 +139,7 @@ type AuthorizeCodeEntry struct {
|
||||
Nonce string
|
||||
CodeChallenge string
|
||||
Userinfo UserinfoResponse
|
||||
AuthTime int64
|
||||
}
|
||||
|
||||
type UsedCodeEntry struct {
|
||||
@@ -132,8 +148,8 @@ type UsedCodeEntry struct {
|
||||
|
||||
type OIDCService struct {
|
||||
log *logger.Logger
|
||||
config model.Config
|
||||
runtime model.RuntimeConfig
|
||||
config *model.Config
|
||||
runtime *model.RuntimeConfig
|
||||
queries repository.Store
|
||||
|
||||
clients map[string]model.OIDCClientConfig
|
||||
@@ -142,24 +158,30 @@ type OIDCService struct {
|
||||
issuer string
|
||||
|
||||
caches struct {
|
||||
code *CacheStore[AuthorizeCodeEntry]
|
||||
usedCode *CacheStore[UsedCodeEntry]
|
||||
code *CacheStore[AuthorizeCodeEntry]
|
||||
usedCode *CacheStore[UsedCodeEntry]
|
||||
authorize *CacheStore[AuthorizeRequest]
|
||||
}
|
||||
}
|
||||
|
||||
func NewOIDCService(
|
||||
log *logger.Logger,
|
||||
config model.Config,
|
||||
runtime model.RuntimeConfig,
|
||||
queries repository.Store,
|
||||
dg *ding.Ding) (*OIDCService, error) {
|
||||
type OIDCServiceInput struct {
|
||||
dig.In
|
||||
|
||||
Log *logger.Logger
|
||||
Config *model.Config
|
||||
Runtime *model.RuntimeConfig
|
||||
Queries repository.Store
|
||||
Ding *ding.Ding
|
||||
}
|
||||
|
||||
func NewOIDCService(i OIDCServiceInput) (*OIDCService, error) {
|
||||
// If not configured, skip init
|
||||
if len(runtime.OIDCClients) == 0 {
|
||||
if len(i.Config.OIDC.Clients) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Ensure issuer is https
|
||||
uissuer, err := url.Parse(runtime.AppURL)
|
||||
uissuer, err := url.Parse(i.Runtime.AppURL)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse app url: %w", err)
|
||||
@@ -172,14 +194,14 @@ func NewOIDCService(
|
||||
issuer := fmt.Sprintf("%s://%s", uissuer.Scheme, uissuer.Host)
|
||||
|
||||
// Create/load private and public keys
|
||||
if strings.TrimSpace(config.OIDC.PrivateKeyPath) == "" ||
|
||||
strings.TrimSpace(config.OIDC.PublicKeyPath) == "" {
|
||||
if strings.TrimSpace(i.Config.OIDC.PrivateKeyPath) == "" ||
|
||||
strings.TrimSpace(i.Config.OIDC.PublicKeyPath) == "" {
|
||||
return nil, errors.New("private key path and public key path are required")
|
||||
}
|
||||
|
||||
var privateKey *rsa.PrivateKey
|
||||
|
||||
fprivateKey, err := os.ReadFile(config.OIDC.PrivateKeyPath)
|
||||
fprivateKey, err := os.ReadFile(i.Config.OIDC.PrivateKeyPath)
|
||||
|
||||
if err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
return nil, err
|
||||
@@ -198,8 +220,12 @@ func NewOIDCService(
|
||||
Type: "RSA PRIVATE KEY",
|
||||
Bytes: der,
|
||||
})
|
||||
log.App.Trace().Str("type", "RSA PRIVATE KEY").Msg("Generated private RSA key")
|
||||
err = os.WriteFile(config.OIDC.PrivateKeyPath, encoded, 0600)
|
||||
i.Log.App.Trace().Str("type", "RSA PRIVATE KEY").Msg("Generated private RSA key")
|
||||
err := os.MkdirAll(filepath.Dir(i.Config.OIDC.PrivateKeyPath), 0700)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create directory for private key: %w", err)
|
||||
}
|
||||
err = os.WriteFile(i.Config.OIDC.PrivateKeyPath, encoded, 0600)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to write private key to file: %w", err)
|
||||
}
|
||||
@@ -208,7 +234,7 @@ func NewOIDCService(
|
||||
if block == nil {
|
||||
return nil, errors.New("failed to decode private key")
|
||||
}
|
||||
log.App.Trace().Str("type", block.Type).Msg("Loaded private key")
|
||||
i.Log.App.Trace().Str("type", block.Type).Msg("Loaded private key")
|
||||
privateKey, err = x509.ParsePKCS1PrivateKey(block.Bytes)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse private key: %w", err)
|
||||
@@ -217,7 +243,7 @@ func NewOIDCService(
|
||||
|
||||
var publicKey crypto.PublicKey
|
||||
|
||||
fpublicKey, err := os.ReadFile(config.OIDC.PublicKeyPath)
|
||||
fpublicKey, err := os.ReadFile(i.Config.OIDC.PublicKeyPath)
|
||||
|
||||
if err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
return nil, fmt.Errorf("failed to read public key: %w", err)
|
||||
@@ -233,8 +259,12 @@ func NewOIDCService(
|
||||
Type: "RSA PUBLIC KEY",
|
||||
Bytes: der,
|
||||
})
|
||||
log.App.Trace().Str("type", "RSA PUBLIC KEY").Msg("Generated public RSA key")
|
||||
err = os.WriteFile(config.OIDC.PublicKeyPath, encoded, 0644)
|
||||
i.Log.App.Trace().Str("type", "RSA PUBLIC KEY").Msg("Generated public RSA key")
|
||||
err := os.MkdirAll(filepath.Dir(i.Config.OIDC.PublicKeyPath), 0700)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create directory for public key: %w", err)
|
||||
}
|
||||
err = os.WriteFile(i.Config.OIDC.PublicKeyPath, encoded, 0644)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -243,7 +273,7 @@ func NewOIDCService(
|
||||
if block == nil {
|
||||
return nil, errors.New("failed to decode public key")
|
||||
}
|
||||
log.App.Trace().Str("type", block.Type).Msg("Loaded public key")
|
||||
i.Log.App.Trace().Str("type", block.Type).Msg("Loaded public key")
|
||||
switch block.Type {
|
||||
case "RSA PUBLIC KEY":
|
||||
publicKey, err = x509.ParsePKCS1PublicKey(block.Bytes)
|
||||
@@ -273,7 +303,7 @@ func NewOIDCService(
|
||||
// We will reorganize the client into a map with the client ID as the key
|
||||
clients := make(map[string]model.OIDCClientConfig)
|
||||
|
||||
for id, client := range config.OIDC.Clients {
|
||||
for id, client := range i.Config.OIDC.Clients {
|
||||
client.ID = id
|
||||
if client.Name == "" {
|
||||
client.Name = utils.Capitalize(client.ID)
|
||||
@@ -289,15 +319,15 @@ func NewOIDCService(
|
||||
}
|
||||
client.ClientSecretFile = ""
|
||||
clients[id] = client
|
||||
log.App.Debug().Str("clientId", client.ClientID).Msg("Loaded OIDC client configuration")
|
||||
i.Log.App.Debug().Str("clientId", client.ClientID).Msg("Loaded OIDC client configuration")
|
||||
}
|
||||
|
||||
// Initialize the service
|
||||
service := &OIDCService{
|
||||
log: log,
|
||||
config: config,
|
||||
runtime: runtime,
|
||||
queries: queries,
|
||||
log: i.Log,
|
||||
config: i.Config,
|
||||
runtime: i.Runtime,
|
||||
queries: i.Queries,
|
||||
|
||||
clients: clients,
|
||||
privateKey: privateKey,
|
||||
@@ -306,16 +336,19 @@ func NewOIDCService(
|
||||
}
|
||||
|
||||
// Start cleanup routine
|
||||
dg.Go(service.cleanupRoutine, ding.RingMinor)
|
||||
i.Ding.Go(service.cleanupRoutine, ding.RingMinor)
|
||||
|
||||
// Create caches
|
||||
codeCash := NewCacheStore[AuthorizeCodeEntry](256)
|
||||
usedCode := NewCacheStore[UsedCodeEntry](256)
|
||||
authorize := NewCacheStore[AuthorizeRequest](256)
|
||||
|
||||
service.caches.code = codeCash
|
||||
service.caches.usedCode = usedCode
|
||||
service.caches.authorize = authorize
|
||||
|
||||
// Start cache cleanup routine
|
||||
dg.Go(func(ctx context.Context) {
|
||||
i.Ding.Go(func(ctx context.Context) {
|
||||
ticker := time.NewTicker(1 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
@@ -324,6 +357,7 @@ func NewOIDCService(
|
||||
case <-ticker.C:
|
||||
service.caches.code.Sweep()
|
||||
service.caches.usedCode.Sweep()
|
||||
service.caches.authorize.Sweep()
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
@@ -402,6 +436,7 @@ func (service *OIDCService) CreateCode(req AuthorizeRequest, userContext model.U
|
||||
ClientID: req.ClientID,
|
||||
Nonce: req.Nonce,
|
||||
Userinfo: service.userinfoFromContext(userContext, sub),
|
||||
AuthTime: userContext.AuthTime,
|
||||
}
|
||||
|
||||
if req.CodeChallenge != "" {
|
||||
@@ -491,7 +526,7 @@ func (service *OIDCService) GetCodeEntry(codeHash string, clientId string) (*Aut
|
||||
return &entry, true
|
||||
}
|
||||
|
||||
func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user UserinfoResponse, scope string, nonce string) (string, error) {
|
||||
func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user UserinfoResponse, scope string, nonce string, authTime *int64) (string, error) {
|
||||
createdAt := time.Now().Unix()
|
||||
expiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry) * time.Second).Unix()
|
||||
|
||||
@@ -536,6 +571,10 @@ func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user
|
||||
Nonce: nonce,
|
||||
}
|
||||
|
||||
if authTime != nil {
|
||||
claims.AuthTime = *authTime
|
||||
}
|
||||
|
||||
payload, err := json.Marshal(claims)
|
||||
|
||||
if err != nil {
|
||||
@@ -557,8 +596,8 @@ func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func (service *OIDCService) GenerateAccessToken(ctx context.Context, client model.OIDCClientConfig, codeEntry AuthorizeCodeEntry) (*TokenResponse, error) {
|
||||
idToken, err := service.generateIDToken(client, codeEntry.Userinfo, codeEntry.Scope, codeEntry.Nonce)
|
||||
func (service *OIDCService) GenerateAccessToken(ctx context.Context, client model.OIDCClientConfig, codeEntry AuthorizeCodeEntry, authTime int64) (*TokenResponse, error) {
|
||||
idToken, err := service.generateIDToken(client, codeEntry.Userinfo, codeEntry.Scope, codeEntry.Nonce, &authTime)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -637,9 +676,10 @@ func (service *OIDCService) RefreshAccessToken(ctx context.Context, refreshToken
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// TODO: store auth time in the database so we can include it in the new ID token, for now we omit it
|
||||
idToken, err := service.generateIDToken(model.OIDCClientConfig{
|
||||
ClientID: entry.ClientID,
|
||||
}, userInfo, entry.Scope, entry.Nonce)
|
||||
}, userInfo, entry.Scope, entry.Nonce, nil)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -856,3 +896,76 @@ func (service *OIDCService) MarkCodeAsUsed(codeHash string, sub string) {
|
||||
func (service *OIDCService) DeleteSessionBySub(ctx context.Context, sub string) error {
|
||||
return service.queries.DeleteOIDCSessionBySub(ctx, sub)
|
||||
}
|
||||
|
||||
func (service *OIDCService) CreateAuthorizeRequestTicket(req AuthorizeRequest) string {
|
||||
ticket := utils.GenerateString(32)
|
||||
|
||||
service.caches.authorize.Set(ticket, req, 10*time.Minute)
|
||||
|
||||
return ticket
|
||||
}
|
||||
|
||||
func (service *OIDCService) GetAuthorizeRequestByTicket(ticket string) (*AuthorizeRequest, bool) {
|
||||
entry, ok := service.caches.authorize.Get(ticket)
|
||||
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
return &entry, true
|
||||
}
|
||||
|
||||
func (service *OIDCService) DeleteAuthorizeRequestTicket(ticket string) {
|
||||
service.caches.authorize.Delete(ticket)
|
||||
}
|
||||
|
||||
// TODO: support signed request objects in the future
|
||||
func (service *OIDCService) DecodeAuthorizeJWT(tokenString string) (*AuthorizeRequest, error) {
|
||||
var claims jwt.MapClaims
|
||||
|
||||
token, _, err := jwt.NewParser().ParseUnverified(tokenString, &claims)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse authorize request jwt: %w", err)
|
||||
}
|
||||
|
||||
alg, ok := token.Header["alg"].(string)
|
||||
|
||||
if !ok || alg != "none" || string(token.Signature) != "" {
|
||||
return nil, fmt.Errorf("only unsigned jwts are supported for authorize requests")
|
||||
}
|
||||
|
||||
get := func(k string) string {
|
||||
v, _ := claims[k].(string)
|
||||
return v
|
||||
}
|
||||
|
||||
return &AuthorizeRequest{
|
||||
Scope: get("scope"),
|
||||
ResponseType: get("response_type"),
|
||||
ClientID: get("client_id"),
|
||||
RedirectURI: get("redirect_uri"),
|
||||
State: get("state"),
|
||||
Nonce: get("nonce"),
|
||||
CodeChallenge: get("code_challenge"),
|
||||
CodeChallengeMethod: get("code_challenge_method"),
|
||||
Prompt: get("prompt"),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (service *OIDCService) GetPrompt(prompt string) []OIDCPrompt {
|
||||
if prompt == "" {
|
||||
return []OIDCPrompt{}
|
||||
}
|
||||
|
||||
parsedPromps := make([]OIDCPrompt, 0)
|
||||
prompts := strings.SplitSeq(prompt, " ")
|
||||
|
||||
for p := range prompts {
|
||||
if !slices.Contains(SupportedPrompts, p) {
|
||||
continue
|
||||
}
|
||||
parsedPromps = append(parsedPromps, OIDCPrompt(p))
|
||||
}
|
||||
|
||||
return parsedPromps
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package service_test
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -9,12 +9,12 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||
)
|
||||
|
||||
func newTestUser() service.UserinfoResponse {
|
||||
return service.UserinfoResponse{
|
||||
func newTestUser() UserinfoResponse {
|
||||
return UserinfoResponse{
|
||||
Sub: "test-sub",
|
||||
Name: "Test User",
|
||||
PreferredUsername: "testuser",
|
||||
@@ -67,21 +67,29 @@ func TestCompileUserinfo(t *testing.T) {
|
||||
ctx := context.TODO()
|
||||
dg := ding.New(ctx)
|
||||
|
||||
svc, err := service.NewOIDCService(log, cfg, runtime, nil, dg)
|
||||
store := memory.New()
|
||||
|
||||
svc, err := NewOIDCService(OIDCServiceInput{
|
||||
Log: log,
|
||||
Config: &cfg,
|
||||
Runtime: &runtime,
|
||||
Queries: store,
|
||||
Ding: dg,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
type testCase struct {
|
||||
description string
|
||||
mutate func(u *service.UserinfoResponse)
|
||||
mutate func(u *UserinfoResponse)
|
||||
scope string
|
||||
run func(t *testing.T, info service.UserinfoResponse)
|
||||
run func(t *testing.T, info UserinfoResponse)
|
||||
}
|
||||
|
||||
tests := []testCase{
|
||||
{
|
||||
description: "openid scope only returns sub and updated_at",
|
||||
scope: "openid",
|
||||
run: func(t *testing.T, info service.UserinfoResponse) {
|
||||
run: func(t *testing.T, info UserinfoResponse) {
|
||||
assert.Equal(t, "test-sub", info.Sub)
|
||||
assert.Equal(t, int64(1234567890), info.UpdatedAt)
|
||||
assert.Empty(t, info.Name)
|
||||
@@ -94,7 +102,7 @@ func TestCompileUserinfo(t *testing.T) {
|
||||
{
|
||||
description: "profile scope returns all profile fields",
|
||||
scope: "openid profile",
|
||||
run: func(t *testing.T, info service.UserinfoResponse) {
|
||||
run: func(t *testing.T, info UserinfoResponse) {
|
||||
assert.Equal(t, "Test User", info.Name)
|
||||
assert.Equal(t, "testuser", info.PreferredUsername)
|
||||
assert.Equal(t, "Test", info.GivenName)
|
||||
@@ -114,7 +122,7 @@ func TestCompileUserinfo(t *testing.T) {
|
||||
{
|
||||
description: "email scope sets email and email_verified true when email present",
|
||||
scope: "openid email",
|
||||
run: func(t *testing.T, info service.UserinfoResponse) {
|
||||
run: func(t *testing.T, info UserinfoResponse) {
|
||||
assert.Equal(t, "test@example.com", info.Email)
|
||||
assert.True(t, info.EmailVerified)
|
||||
assert.Empty(t, info.Name)
|
||||
@@ -123,8 +131,8 @@ func TestCompileUserinfo(t *testing.T) {
|
||||
{
|
||||
description: "email scope sets email_verified false when email absent",
|
||||
scope: "openid email",
|
||||
mutate: func(u *service.UserinfoResponse) { u.Email = "" },
|
||||
run: func(t *testing.T, info service.UserinfoResponse) {
|
||||
mutate: func(u *UserinfoResponse) { u.Email = "" },
|
||||
run: func(t *testing.T, info UserinfoResponse) {
|
||||
assert.Empty(t, info.Email)
|
||||
assert.False(t, info.EmailVerified)
|
||||
},
|
||||
@@ -132,7 +140,7 @@ func TestCompileUserinfo(t *testing.T) {
|
||||
{
|
||||
description: "phone scope sets phone_number_verified true when phone present",
|
||||
scope: "openid phone",
|
||||
run: func(t *testing.T, info service.UserinfoResponse) {
|
||||
run: func(t *testing.T, info UserinfoResponse) {
|
||||
assert.Equal(t, "+15555550100", info.PhoneNumber)
|
||||
require.NotNil(t, info.PhoneNumberVerified)
|
||||
assert.True(t, *info.PhoneNumberVerified)
|
||||
@@ -141,8 +149,8 @@ func TestCompileUserinfo(t *testing.T) {
|
||||
{
|
||||
description: "phone scope sets phone_number_verified false when phone absent",
|
||||
scope: "openid phone",
|
||||
mutate: func(u *service.UserinfoResponse) { u.PhoneNumber = "" },
|
||||
run: func(t *testing.T, info service.UserinfoResponse) {
|
||||
mutate: func(u *UserinfoResponse) { u.PhoneNumber = "" },
|
||||
run: func(t *testing.T, info UserinfoResponse) {
|
||||
require.NotNil(t, info.PhoneNumberVerified)
|
||||
assert.False(t, *info.PhoneNumberVerified)
|
||||
},
|
||||
@@ -150,7 +158,7 @@ func TestCompileUserinfo(t *testing.T) {
|
||||
{
|
||||
description: "address scope returns parsed address",
|
||||
scope: "openid address",
|
||||
run: func(t *testing.T, info service.UserinfoResponse) {
|
||||
run: func(t *testing.T, info UserinfoResponse) {
|
||||
require.NotNil(t, info.Address)
|
||||
assert.Equal(t, "123 Main St", info.Address.Formatted)
|
||||
assert.Equal(t, "123 Main St", info.Address.StreetAddress)
|
||||
@@ -163,14 +171,14 @@ func TestCompileUserinfo(t *testing.T) {
|
||||
{
|
||||
description: "groups scope returns split groups",
|
||||
scope: "openid groups",
|
||||
run: func(t *testing.T, info service.UserinfoResponse) {
|
||||
run: func(t *testing.T, info UserinfoResponse) {
|
||||
assert.Equal(t, []string{"admins", "users"}, info.Groups)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "all scopes return all fields",
|
||||
scope: "openid profile email phone address groups",
|
||||
run: func(t *testing.T, info service.UserinfoResponse) {
|
||||
run: func(t *testing.T, info UserinfoResponse) {
|
||||
assert.Equal(t, "Test User", info.Name)
|
||||
assert.Equal(t, "test@example.com", info.Email)
|
||||
assert.Equal(t, "+15555550100", info.PhoneNumber)
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||
"go.uber.org/dig"
|
||||
)
|
||||
|
||||
type Policy string
|
||||
@@ -40,21 +41,28 @@ type PolicyEngine struct {
|
||||
policy Policy
|
||||
}
|
||||
|
||||
func NewPolicyEngine(config model.Config, log *logger.Logger) (*PolicyEngine, error) {
|
||||
type PolicyEngineInput struct {
|
||||
dig.In
|
||||
|
||||
Log *logger.Logger
|
||||
Config *model.Config
|
||||
}
|
||||
|
||||
func NewPolicyEngine(i PolicyEngineInput) (*PolicyEngine, error) {
|
||||
engine := PolicyEngine{
|
||||
log: log,
|
||||
log: i.Log,
|
||||
rules: make(map[RuleName]Rule),
|
||||
}
|
||||
|
||||
switch config.Auth.ACLs.Policy {
|
||||
switch i.Config.Auth.ACLs.Policy {
|
||||
case string(PolicyAllow):
|
||||
log.App.Debug().Msg("Using 'allow' ACL policy: access to apps will be allowed by default unless explicitly blocked")
|
||||
i.Log.App.Debug().Msg("Using 'allow' ACL policy: access to apps will be allowed by default unless explicitly blocked")
|
||||
engine.policy = PolicyAllow
|
||||
case string(PolicyDeny):
|
||||
log.App.Debug().Msg("Using 'deny' ACL policy: access to apps will be blocked by default unless explicitly allowed")
|
||||
i.Log.App.Debug().Msg("Using 'deny' ACL policy: access to apps will be blocked by default unless explicitly allowed")
|
||||
engine.policy = PolicyDeny
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid acl policy: %s", config.Auth.ACLs.Policy)
|
||||
return nil, fmt.Errorf("invalid acl policy: %s", i.Config.Auth.ACLs.Policy)
|
||||
}
|
||||
|
||||
return &engine, nil
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
package service_test
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||
"github.com/tinyauthapp/tinyauth/internal/test"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||
)
|
||||
@@ -12,14 +11,14 @@ import (
|
||||
// Create test rule
|
||||
type TestRule struct{}
|
||||
|
||||
func (rule *TestRule) Evaluate(ctx *service.ACLContext) service.Effect {
|
||||
func (rule *TestRule) Evaluate(ctx *ACLContext) Effect {
|
||||
switch ctx.Path {
|
||||
case "/allowed":
|
||||
return service.EffectAllow
|
||||
return EffectAllow
|
||||
case "/denied":
|
||||
return service.EffectDeny
|
||||
return EffectDeny
|
||||
default:
|
||||
return service.EffectAbstain
|
||||
return EffectAbstain
|
||||
}
|
||||
}
|
||||
|
||||
@@ -33,36 +32,51 @@ func TestPolicyEngine(t *testing.T) {
|
||||
|
||||
// Engine should fail with invalid policy
|
||||
cfg.Auth.ACLs.Policy = "invalid_policy"
|
||||
_, err := service.NewPolicyEngine(cfg, log)
|
||||
_, err := NewPolicyEngine(PolicyEngineInput{
|
||||
Log: log,
|
||||
Config: &cfg,
|
||||
})
|
||||
assert.Error(t, err)
|
||||
|
||||
// Engine should initialize with 'allow' policy
|
||||
cfg.Auth.ACLs.Policy = string(service.PolicyAllow)
|
||||
engine, err := service.NewPolicyEngine(cfg, log)
|
||||
cfg.Auth.ACLs.Policy = string(PolicyAllow)
|
||||
engine, err := NewPolicyEngine(PolicyEngineInput{
|
||||
Log: log,
|
||||
Config: &cfg,
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, service.PolicyAllow, engine.Policy())
|
||||
assert.Equal(t, PolicyAllow, engine.Policy())
|
||||
|
||||
// Engine should initialize with 'deny' policy
|
||||
cfg.Auth.ACLs.Policy = string(service.PolicyDeny)
|
||||
engine, err = service.NewPolicyEngine(cfg, log)
|
||||
cfg.Auth.ACLs.Policy = string(PolicyDeny)
|
||||
engine, err = NewPolicyEngine(PolicyEngineInput{
|
||||
Log: log,
|
||||
Config: &cfg,
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, service.PolicyDeny, engine.Policy())
|
||||
assert.Equal(t, PolicyDeny, engine.Policy())
|
||||
|
||||
// Engine should allow adding rules
|
||||
engine, err = service.NewPolicyEngine(cfg, log)
|
||||
engine, err = NewPolicyEngine(PolicyEngineInput{
|
||||
Log: log,
|
||||
Config: &cfg,
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
engine.RegisterRule("test-rule", testRule)
|
||||
_, ok := engine.Rules()["test-rule"]
|
||||
assert.True(t, ok)
|
||||
|
||||
// Begin allow policy tests
|
||||
cfg.Auth.ACLs.Policy = string(service.PolicyAllow)
|
||||
engine, err = service.NewPolicyEngine(cfg, log)
|
||||
cfg.Auth.ACLs.Policy = string(PolicyAllow)
|
||||
engine, err = NewPolicyEngine(PolicyEngineInput{
|
||||
Log: log,
|
||||
Config: &cfg,
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
engine.RegisterRule("test-rule", testRule)
|
||||
|
||||
// With allow policy, if rule allows, access should be allowed
|
||||
ctx := &service.ACLContext{Path: "/allowed"}
|
||||
ctx := &ACLContext{Path: "/allowed"}
|
||||
assert.Equal(t, true, engine.Evaluate("test-rule", ctx))
|
||||
|
||||
// With allow policy, if rule denies, access should be denied
|
||||
@@ -74,8 +88,11 @@ func TestPolicyEngine(t *testing.T) {
|
||||
assert.Equal(t, true, engine.Evaluate("test-rule", ctx))
|
||||
|
||||
// Begin deny policy tests
|
||||
cfg.Auth.ACLs.Policy = string(service.PolicyDeny)
|
||||
engine, err = service.NewPolicyEngine(cfg, log)
|
||||
cfg.Auth.ACLs.Policy = string(PolicyDeny)
|
||||
engine, err = NewPolicyEngine(PolicyEngineInput{
|
||||
Log: log,
|
||||
Config: &cfg,
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
engine.RegisterRule("test-rule", testRule)
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"github.com/steveiliop56/ding"
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||
"go.uber.org/dig"
|
||||
"tailscale.com/client/local"
|
||||
"tailscale.com/tsnet"
|
||||
)
|
||||
@@ -25,7 +26,7 @@ type TailscaleWhoisResponse struct {
|
||||
|
||||
type TailscaleService struct {
|
||||
log *logger.Logger
|
||||
config model.Config
|
||||
config *model.Config
|
||||
ctx context.Context
|
||||
|
||||
srv *tsnet.Server
|
||||
@@ -34,22 +35,31 @@ type TailscaleService struct {
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func NewTailscaleService(log *logger.Logger, config model.Config, ctx context.Context, dg *ding.Ding) (*TailscaleService, error) {
|
||||
if !config.Tailscale.Enabled {
|
||||
type TailscaleServiceInput struct {
|
||||
dig.In
|
||||
|
||||
Log *logger.Logger
|
||||
Config *model.Config
|
||||
Ctx context.Context
|
||||
Ding *ding.Ding
|
||||
}
|
||||
|
||||
func NewTailscaleService(i TailscaleServiceInput) (*TailscaleService, error) {
|
||||
if !i.Config.Tailscale.Enabled {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
srv := new(tsnet.Server)
|
||||
|
||||
// node options
|
||||
srv.Dir = config.Tailscale.Dir
|
||||
srv.Hostname = config.Tailscale.Hostname
|
||||
srv.AuthKey = config.Tailscale.AuthKey
|
||||
srv.Ephemeral = config.Tailscale.Ephemeral
|
||||
srv.Dir = i.Config.Tailscale.Dir
|
||||
srv.Hostname = i.Config.Tailscale.Hostname
|
||||
srv.AuthKey = i.Config.Tailscale.AuthKey
|
||||
srv.Ephemeral = i.Config.Tailscale.Ephemeral
|
||||
|
||||
// redirect logs to zerolog
|
||||
srv.Logf = log.App.Printf
|
||||
srv.UserLogf = log.App.Printf
|
||||
srv.Logf = i.Log.App.Printf
|
||||
srv.UserLogf = i.Log.App.Printf
|
||||
|
||||
err := srv.Start()
|
||||
|
||||
@@ -65,14 +75,14 @@ func NewTailscaleService(log *logger.Logger, config model.Config, ctx context.Co
|
||||
}
|
||||
|
||||
service := &TailscaleService{
|
||||
log: log,
|
||||
config: config,
|
||||
ctx: ctx,
|
||||
log: i.Log,
|
||||
config: i.Config,
|
||||
ctx: i.Ctx,
|
||||
srv: srv,
|
||||
lc: lc,
|
||||
}
|
||||
|
||||
connectCtx, cancel := context.WithTimeout(ctx, 2*time.Minute) // large enough timeout to allow for user to manually authenticate with link if needed
|
||||
connectCtx, cancel := context.WithTimeout(i.Ctx, 2*time.Minute) // large enough timeout to allow for user to manually authenticate with link if needed
|
||||
defer cancel()
|
||||
|
||||
err = service.waitForConn(connectCtx)
|
||||
@@ -82,7 +92,11 @@ func NewTailscaleService(log *logger.Logger, config model.Config, ctx context.Co
|
||||
return nil, fmt.Errorf("failed to connect to tailscale network: %w", err)
|
||||
}
|
||||
|
||||
dg.Go(service.watchAndClose, ding.RingMajor)
|
||||
i.Ding.Go(service.watchAndClose, ding.RingMajor)
|
||||
|
||||
if i.Config.Tailscale.Funnel && !i.Config.Tailscale.Listen {
|
||||
service.log.App.Warn().Msg("Tailscale Funnel is enabled but listen is disabled. Funnel will not work without listen enabled.")
|
||||
}
|
||||
|
||||
return service, nil
|
||||
}
|
||||
@@ -128,8 +142,6 @@ func (ts *TailscaleService) Whois(ctx context.Context, addr string) (*TailscaleW
|
||||
NodeName: strings.TrimSuffix(who.Node.Name, "."),
|
||||
}
|
||||
|
||||
ts.log.App.Debug().Interface("res", res).Msg("tailscale")
|
||||
|
||||
return &res, nil
|
||||
}
|
||||
|
||||
@@ -140,6 +152,16 @@ func (ts *TailscaleService) CreateListener() (net.Listener, error) {
|
||||
if ts.ln != nil {
|
||||
return *ts.ln, nil
|
||||
}
|
||||
|
||||
if ts.config.Tailscale.Funnel {
|
||||
ln, err := ts.srv.ListenFunnel("tcp", ":443")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ts.ln = &ln
|
||||
return ln, nil
|
||||
}
|
||||
|
||||
ln, err := ts.srv.ListenTLS("tcp", ":443")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
Reference in New Issue
Block a user