mirror of
https://github.com/steveiliop56/tinyauth.git
synced 2026-06-22 03:10:16 +00:00
Merge branch 'main' into feat/oidc-preserve-consent
This commit is contained in:
@@ -5,6 +5,7 @@ import (
|
||||
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||
"go.uber.org/dig"
|
||||
)
|
||||
|
||||
type LabelProvider interface {
|
||||
@@ -13,19 +14,24 @@ type LabelProvider interface {
|
||||
|
||||
type AccessControlsService struct {
|
||||
log *logger.Logger
|
||||
config model.Config
|
||||
labelProvider *LabelProvider
|
||||
config *model.Config
|
||||
labelProvider LabelProvider
|
||||
}
|
||||
|
||||
func NewAccessControlsService(
|
||||
log *logger.Logger,
|
||||
config model.Config,
|
||||
labelProvider *LabelProvider) *AccessControlsService {
|
||||
type AccessControlServiceInput struct {
|
||||
dig.In
|
||||
|
||||
Log *logger.Logger
|
||||
Config *model.Config
|
||||
LabelProvider LabelProvider `optional:"true"`
|
||||
}
|
||||
|
||||
func NewAccessControlsService(i AccessControlServiceInput) *AccessControlsService {
|
||||
|
||||
return &AccessControlsService{
|
||||
log: log,
|
||||
config: config,
|
||||
labelProvider: labelProvider,
|
||||
log: i.Log,
|
||||
config: i.Config,
|
||||
labelProvider: i.LabelProvider,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -57,8 +63,8 @@ func (service *AccessControlsService) GetAccessControls(domain string) (*model.A
|
||||
}
|
||||
|
||||
// If we have a label provider configured, try to get ACLs from it
|
||||
if service.labelProvider != nil && *service.labelProvider != nil {
|
||||
return (*service.labelProvider).GetLabels(domain)
|
||||
if service.labelProvider != nil {
|
||||
return service.labelProvider.GetLabels(domain)
|
||||
}
|
||||
|
||||
// no labels
|
||||
|
||||
@@ -87,7 +87,11 @@ func TestLookupStaticACLs(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
svc := NewAccessControlsService(log, model.Config{Apps: tt.apps}, nil)
|
||||
svc := NewAccessControlsService(AccessControlServiceInput{
|
||||
Log: log,
|
||||
Config: &model.Config{Apps: tt.apps},
|
||||
LabelProvider: nil,
|
||||
})
|
||||
got := svc.lookupStaticACLs(tt.domain)
|
||||
if tt.expectNil {
|
||||
assert.Nil(t, got)
|
||||
@@ -112,7 +116,11 @@ func TestGetAccessControls(t *testing.T) {
|
||||
},
|
||||
},
|
||||
}
|
||||
svc := NewAccessControlsService(log, config, nil)
|
||||
svc := NewAccessControlsService(AccessControlServiceInput{
|
||||
Log: log,
|
||||
Config: &config,
|
||||
LabelProvider: nil,
|
||||
})
|
||||
|
||||
got, err := svc.GetAccessControls("foo.example.com")
|
||||
|
||||
@@ -123,7 +131,11 @@ func TestGetAccessControls(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("returns nil when no static match and no label provider", func(t *testing.T) {
|
||||
svc := NewAccessControlsService(log, model.Config{}, nil)
|
||||
svc := NewAccessControlsService(AccessControlServiceInput{
|
||||
Log: log,
|
||||
Config: &model.Config{},
|
||||
LabelProvider: nil,
|
||||
})
|
||||
|
||||
got, err := svc.GetAccessControls("unknown.example.com")
|
||||
|
||||
@@ -133,7 +145,11 @@ func TestGetAccessControls(t *testing.T) {
|
||||
|
||||
t.Run("returns nil when label provider pointer wraps a nil interface", func(t *testing.T) {
|
||||
var provider LabelProvider
|
||||
svc := NewAccessControlsService(log, model.Config{}, &provider)
|
||||
svc := NewAccessControlsService(AccessControlServiceInput{
|
||||
Log: log,
|
||||
Config: &model.Config{},
|
||||
LabelProvider: provider, // nil provider
|
||||
})
|
||||
|
||||
got, err := svc.GetAccessControls("unknown.example.com")
|
||||
|
||||
@@ -152,7 +168,11 @@ func TestGetAccessControls(t *testing.T) {
|
||||
},
|
||||
}
|
||||
var provider LabelProvider = mock
|
||||
svc := NewAccessControlsService(log, model.Config{}, &provider)
|
||||
svc := NewAccessControlsService(AccessControlServiceInput{
|
||||
Log: log,
|
||||
Config: &model.Config{},
|
||||
LabelProvider: provider,
|
||||
})
|
||||
|
||||
got, err := svc.GetAccessControls("dynamic.example.com")
|
||||
|
||||
@@ -170,7 +190,11 @@ func TestGetAccessControls(t *testing.T) {
|
||||
"foo": {Config: model.AppConfig{Domain: "foo.example.com"}},
|
||||
},
|
||||
}
|
||||
svc := NewAccessControlsService(log, config, &provider)
|
||||
svc := NewAccessControlsService(AccessControlServiceInput{
|
||||
Log: log,
|
||||
Config: &config,
|
||||
LabelProvider: provider,
|
||||
})
|
||||
|
||||
got, err := svc.GetAccessControls("foo.example.com")
|
||||
|
||||
@@ -188,7 +212,11 @@ func TestGetAccessControls(t *testing.T) {
|
||||
},
|
||||
}
|
||||
var provider LabelProvider = mock
|
||||
svc := NewAccessControlsService(log, model.Config{}, &provider)
|
||||
svc := NewAccessControlsService(AccessControlServiceInput{
|
||||
Log: log,
|
||||
Config: &model.Config{},
|
||||
LabelProvider: provider,
|
||||
})
|
||||
|
||||
got, err := svc.GetAccessControls("dynamic.example.com")
|
||||
|
||||
|
||||
@@ -2,8 +2,10 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -14,6 +16,7 @@ import (
|
||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||
"go.uber.org/dig"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
@@ -24,7 +27,6 @@ import (
|
||||
// but for now these are just safety limits to prevent unbounded memory usage
|
||||
const MaxOAuthPendingSessions = 256
|
||||
const OAuthCleanupCount = 16
|
||||
const MaxLoginAttemptRecords = 256
|
||||
|
||||
var (
|
||||
ErrUserNotFound = errors.New("user not found")
|
||||
@@ -57,9 +59,8 @@ type LoginAttempt struct {
|
||||
|
||||
type AuthService struct {
|
||||
log *logger.Logger
|
||||
config model.Config
|
||||
runtime model.RuntimeConfig
|
||||
helpers *model.RuntimeHelpers
|
||||
config *model.Config
|
||||
runtime *model.RuntimeConfig
|
||||
ctx context.Context
|
||||
|
||||
ldap *LdapService
|
||||
@@ -81,44 +82,57 @@ type AuthService struct {
|
||||
oauth *CacheStore[OAuthPendingSession]
|
||||
ldap *CacheStore[[]string]
|
||||
}
|
||||
|
||||
maxLoginLimits int
|
||||
}
|
||||
|
||||
func NewAuthService(
|
||||
log *logger.Logger,
|
||||
config model.Config,
|
||||
runtime model.RuntimeConfig,
|
||||
helpers *model.RuntimeHelpers,
|
||||
ctx context.Context,
|
||||
dg *ding.Ding,
|
||||
ldap *LdapService,
|
||||
queries repository.Store,
|
||||
oauthBroker *OAuthBrokerService,
|
||||
tailscale *TailscaleService,
|
||||
policy *PolicyEngine,
|
||||
) *AuthService {
|
||||
type AuthServiceInput struct {
|
||||
dig.In
|
||||
|
||||
Log *logger.Logger
|
||||
Config *model.Config
|
||||
Runtime *model.RuntimeConfig
|
||||
Ctx context.Context
|
||||
Ding *ding.Ding
|
||||
LDAP *LdapService `optional:"true"`
|
||||
Queries repository.Store
|
||||
OAuthBroker *OAuthBrokerService
|
||||
Tailscale *TailscaleService `optional:"true"`
|
||||
PolicyEngine *PolicyEngine
|
||||
}
|
||||
|
||||
func NewAuthService(i AuthServiceInput) *AuthService {
|
||||
service := &AuthService{
|
||||
log: log,
|
||||
runtime: runtime,
|
||||
helpers: helpers,
|
||||
ctx: ctx,
|
||||
config: config,
|
||||
ldap: ldap,
|
||||
queries: queries,
|
||||
oauthBroker: oauthBroker,
|
||||
tailscale: tailscale,
|
||||
policyEngine: policy,
|
||||
log: i.Log,
|
||||
runtime: i.Runtime,
|
||||
ctx: i.Ctx,
|
||||
config: i.Config,
|
||||
ldap: i.LDAP,
|
||||
queries: i.Queries,
|
||||
oauthBroker: i.OAuthBroker,
|
||||
tailscale: i.Tailscale,
|
||||
policyEngine: i.PolicyEngine,
|
||||
}
|
||||
|
||||
// get the max login limits based on the number of users and the configured max retries
|
||||
service.maxLoginLimits = service.calculateLockdownLimit()
|
||||
|
||||
loginCacheSize := 0
|
||||
|
||||
if !service.config.Auth.LockdownEnabled {
|
||||
loginCacheSize = service.maxLoginLimits
|
||||
}
|
||||
|
||||
// caches setup
|
||||
oauthCache := NewCacheStore[OAuthPendingSession](256)
|
||||
loginCache := NewCacheStore[LoginAttempt](1024)
|
||||
loginCache := NewCacheStore[LoginAttempt](loginCacheSize)
|
||||
ldapCache := NewCacheStore[[]string](1024)
|
||||
|
||||
service.caches.oauth = oauthCache
|
||||
service.caches.login = loginCache
|
||||
service.caches.ldap = ldapCache
|
||||
|
||||
dg.Go(func(ctx context.Context) {
|
||||
i.Ding.Go(func(ctx context.Context) {
|
||||
ticker := time.NewTicker(1 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
@@ -257,7 +271,7 @@ func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) {
|
||||
return
|
||||
}
|
||||
|
||||
if auth.caches.login.Size() >= MaxLoginAttemptRecords {
|
||||
if !success && auth.config.Auth.LockdownEnabled && auth.caches.login.Size() >= auth.maxLoginLimits {
|
||||
if locked, _ := auth.IsInLockdown(); locked {
|
||||
return
|
||||
}
|
||||
@@ -632,16 +646,17 @@ func (auth *AuthService) lockdownMode() {
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
ctx, cancel := context.WithCancel(auth.ctx)
|
||||
|
||||
auth.log.App.Warn().Msg("Too many failed login attempts, entering lockdown mode")
|
||||
|
||||
auth.lockdown.active = true
|
||||
auth.lockdown.ctx = ctx
|
||||
auth.lockdown.cancelFunc = cancel
|
||||
auth.lockdown.until = time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second)
|
||||
|
||||
timer := time.NewTimer(time.Until(auth.lockdown.until))
|
||||
d := time.Duration(auth.config.Auth.LoginTimeout) * time.Second
|
||||
auth.lockdown.until = time.Now().Add(d)
|
||||
timer := time.NewTimer(d)
|
||||
|
||||
auth.lockdown.mu.Unlock()
|
||||
|
||||
@@ -653,14 +668,13 @@ func (auth *AuthService) lockdownMode() {
|
||||
// Timer expired, end lockdown
|
||||
case <-ctx.Done():
|
||||
// Context cancelled, end lockdown
|
||||
case <-auth.ctx.Done():
|
||||
// Service is shutting down, end lockdown
|
||||
}
|
||||
|
||||
auth.lockdown.mu.Lock()
|
||||
|
||||
auth.log.App.Info().Msg("Exiting lockdown mode")
|
||||
|
||||
auth.caches.login.Clear()
|
||||
auth.lockdown.active = false
|
||||
auth.lockdown.until = time.Time{}
|
||||
auth.lockdown.ctx = nil
|
||||
@@ -683,3 +697,32 @@ func (auth *AuthService) IsInLockdown() (bool, int) {
|
||||
func (auth *AuthService) ClearLoginAttempts() {
|
||||
auth.caches.login.Clear()
|
||||
}
|
||||
|
||||
func (auth *AuthService) calculateLockdownLimit() int {
|
||||
userCount := len(auth.runtime.LocalUsers)
|
||||
|
||||
if auth.ldap != nil {
|
||||
ldapUsers, err := auth.ldap.GetUserCount()
|
||||
if err != nil {
|
||||
auth.log.App.Warn().Err(err).Msg("Failed to get LDAP user count")
|
||||
} else {
|
||||
userCount += ldapUsers
|
||||
}
|
||||
}
|
||||
|
||||
limit := userCount * auth.config.Auth.LoginMaxRetries
|
||||
|
||||
jitter, err := rand.Int(rand.Reader, big.NewInt(64))
|
||||
|
||||
if err != nil {
|
||||
auth.log.App.Warn().Err(err).Msg("Failed to generate jitter for lockdown limit")
|
||||
} else {
|
||||
limit += int(jitter.Int64())
|
||||
}
|
||||
|
||||
if limit < 256 {
|
||||
limit = 256
|
||||
}
|
||||
|
||||
return limit
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||
)
|
||||
@@ -12,9 +13,22 @@ func TestIsEmailWhitelistedUsesProviderSpecificList(t *testing.T) {
|
||||
log := logger.NewLogger().WithTestConfig()
|
||||
log.Init()
|
||||
|
||||
policyEngine, err := NewPolicyEngine(PolicyEngineInput{
|
||||
Log: log,
|
||||
Config: &model.Config{
|
||||
Auth: model.AuthConfig{
|
||||
ACLs: model.ACLsConfig{
|
||||
Policy: string(PolicyAllow),
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
auth := &AuthService{
|
||||
log: log,
|
||||
runtime: model.RuntimeConfig{
|
||||
runtime: &model.RuntimeConfig{
|
||||
OAuthWhitelist: []string{"global@example.com"},
|
||||
OAuthProviders: map[string]model.OAuthServiceConfig{
|
||||
"github": {
|
||||
@@ -28,6 +42,7 @@ func TestIsEmailWhitelistedUsesProviderSpecificList(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
policyEngine: policyEngine,
|
||||
}
|
||||
|
||||
assert.True(t, auth.IsEmailWhitelisted("github", "github@example.com"))
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/decoders"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||
"go.uber.org/dig"
|
||||
|
||||
container "github.com/docker/docker/api/types/container"
|
||||
"github.com/docker/docker/client"
|
||||
@@ -21,36 +22,40 @@ type DockerService struct {
|
||||
isConnected bool
|
||||
}
|
||||
|
||||
func NewDockerService(
|
||||
log *logger.Logger,
|
||||
ctx context.Context,
|
||||
dg *ding.Ding,
|
||||
) (*DockerService, error) {
|
||||
type DockerServiceInput struct {
|
||||
dig.In
|
||||
|
||||
Log *logger.Logger
|
||||
Ctx context.Context
|
||||
Ding *ding.Ding
|
||||
}
|
||||
|
||||
func NewDockerService(i DockerServiceInput) (*DockerService, error) {
|
||||
|
||||
client, err := client.NewClientWithOpts(client.FromEnv)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
client.NegotiateAPIVersion(ctx)
|
||||
client.NegotiateAPIVersion(i.Ctx)
|
||||
|
||||
_, err = client.Ping(ctx)
|
||||
_, err = client.Ping(i.Ctx)
|
||||
|
||||
if err != nil {
|
||||
log.App.Debug().Err(err).Msg("Docker not connected")
|
||||
i.Log.App.Debug().Err(err).Msg("Docker not connected")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
service := &DockerService{
|
||||
log: log,
|
||||
log: i.Log,
|
||||
client: client,
|
||||
context: ctx,
|
||||
context: i.Ctx,
|
||||
}
|
||||
|
||||
service.isConnected = true
|
||||
service.log.App.Debug().Msg("Docker connected successfully")
|
||||
|
||||
dg.Go(service.watchAndClose, ding.RingMajor)
|
||||
i.Ding.Go(service.watchAndClose, ding.RingMajor)
|
||||
|
||||
return service, nil
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/decoders"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||
"go.uber.org/dig"
|
||||
|
||||
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
|
||||
"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
|
||||
@@ -48,11 +49,15 @@ type KubernetesService struct {
|
||||
appNameIndex map[string]ingressAppKey
|
||||
}
|
||||
|
||||
func NewKubernetesService(
|
||||
log *logger.Logger,
|
||||
ctx context.Context,
|
||||
dg *ding.Ding,
|
||||
) (*KubernetesService, error) {
|
||||
type KubernetesServiceInput struct {
|
||||
dig.In
|
||||
|
||||
Log *logger.Logger
|
||||
Ctx context.Context
|
||||
Ding *ding.Ding
|
||||
}
|
||||
|
||||
func NewKubernetesService(i KubernetesServiceInput) (*KubernetesService, error) {
|
||||
cfg, err := rest.InClusterConfig()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get in-cluster kubernetes config: %w", err)
|
||||
@@ -69,31 +74,31 @@ func NewKubernetesService(
|
||||
Resource: "ingresses",
|
||||
}
|
||||
|
||||
accessCtx, accessCancel := context.WithTimeout(ctx, 5*time.Second)
|
||||
accessCtx, accessCancel := context.WithTimeout(i.Ctx, 5*time.Second)
|
||||
defer accessCancel()
|
||||
|
||||
_, err = client.Resource(gvr).List(accessCtx, metav1.ListOptions{Limit: 1})
|
||||
if err != nil {
|
||||
log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to access Ingress API, Kubernetes label provider will be disabled")
|
||||
i.Log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to access Ingress API, Kubernetes label provider will be disabled")
|
||||
return nil, fmt.Errorf("failed to access ingress api: %w", err)
|
||||
}
|
||||
|
||||
log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Successfully accessed Ingress API, starting watcher")
|
||||
i.Log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Successfully accessed Ingress API, starting watcher")
|
||||
|
||||
service := &KubernetesService{
|
||||
log: log,
|
||||
log: i.Log,
|
||||
client: client,
|
||||
ingressApps: make(map[ingressKey][]ingressApp),
|
||||
domainIndex: make(map[string]ingressAppKey),
|
||||
appNameIndex: make(map[string]ingressAppKey),
|
||||
}
|
||||
|
||||
dg.Go(func(ctx context.Context) {
|
||||
i.Ding.Go(func(ctx context.Context) {
|
||||
service.watchGVR(gvr, ctx)
|
||||
}, ding.RingMajor)
|
||||
|
||||
service.started = true
|
||||
log.App.Debug().Msg("Kubernetes label provider started successfully")
|
||||
i.Log.App.Debug().Msg("Kubernetes label provider started successfully")
|
||||
|
||||
return service, nil
|
||||
}
|
||||
|
||||
@@ -13,44 +13,48 @@ import (
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||
"go.uber.org/dig"
|
||||
)
|
||||
|
||||
type LdapService struct {
|
||||
log *logger.Logger
|
||||
config model.Config
|
||||
config *model.Config
|
||||
|
||||
conn *ldapgo.Conn
|
||||
mutex sync.RWMutex
|
||||
cert *tls.Certificate
|
||||
conn *ldapgo.Conn
|
||||
mutex sync.RWMutex
|
||||
cert *tls.Certificate
|
||||
bindPw string
|
||||
}
|
||||
|
||||
func NewLdapService(
|
||||
log *logger.Logger,
|
||||
config model.Config,
|
||||
dg *ding.Ding,
|
||||
) (*LdapService, error) {
|
||||
if config.LDAP.Address == "" {
|
||||
type LdapServiceInput struct {
|
||||
dig.In
|
||||
|
||||
Log *logger.Logger
|
||||
Config *model.Config
|
||||
Ding *ding.Ding
|
||||
}
|
||||
|
||||
func NewLdapService(i LdapServiceInput) (*LdapService, error) {
|
||||
if i.Config.LDAP.Address == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
secret := utils.GetSecret(config.LDAP.BindPassword, config.LDAP.BindPasswordFile)
|
||||
config.LDAP.BindPassword = secret
|
||||
config.LDAP.BindPasswordFile = ""
|
||||
|
||||
ldap := &LdapService{
|
||||
log: log,
|
||||
config: config,
|
||||
log: i.Log,
|
||||
config: i.Config,
|
||||
}
|
||||
|
||||
ldap.bindPw = utils.GetSecret(i.Config.LDAP.BindPassword, i.Config.LDAP.BindPasswordFile)
|
||||
|
||||
// Check whether authentication with client certificate is possible
|
||||
if config.LDAP.AuthCert != "" && config.LDAP.AuthKey != "" {
|
||||
cert, err := tls.LoadX509KeyPair(config.LDAP.AuthCert, config.LDAP.AuthKey)
|
||||
if i.Config.LDAP.AuthCert != "" && i.Config.LDAP.AuthKey != "" {
|
||||
cert, err := tls.LoadX509KeyPair(i.Config.LDAP.AuthCert, i.Config.LDAP.AuthKey)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize LDAP with mTLS authentication: %w", err)
|
||||
}
|
||||
|
||||
log.App.Info().Msg("LDAP mTLS authentication configured successfully")
|
||||
i.Log.App.Info().Msg("LDAP mTLS authentication configured successfully")
|
||||
|
||||
ldap.cert = &cert
|
||||
|
||||
@@ -72,7 +76,7 @@ func NewLdapService(
|
||||
return nil, fmt.Errorf("failed to connect to ldap server: %w", err)
|
||||
}
|
||||
|
||||
dg.Go(func(ctx context.Context) {
|
||||
i.Ding.Go(func(ctx context.Context) {
|
||||
ldap.log.App.Debug().Msg("Starting LDAP connection heartbeat routine")
|
||||
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
@@ -165,6 +169,26 @@ func (ldap *LdapService) GetUserInfo(username string) (dn string, email string,
|
||||
return entry.DN, entry.GetAttributeValue("mail"), nil
|
||||
}
|
||||
|
||||
func (ldap *LdapService) GetUserCount() (int, error) {
|
||||
searchRequest := ldapgo.NewSearchRequest(
|
||||
ldap.config.LDAP.BaseDN,
|
||||
ldapgo.ScopeWholeSubtree, ldapgo.NeverDerefAliases, 0, 0, false,
|
||||
"(objectClass=person)",
|
||||
[]string{"dn"},
|
||||
nil,
|
||||
)
|
||||
|
||||
ldap.mutex.Lock()
|
||||
defer ldap.mutex.Unlock()
|
||||
|
||||
searchResult, err := ldap.conn.Search(searchRequest)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return len(searchResult.Entries), nil
|
||||
}
|
||||
|
||||
func (ldap *LdapService) GetUserGroups(userDN string) ([]string, error) {
|
||||
escapedUserDN := ldapgo.EscapeFilter(userDN)
|
||||
|
||||
@@ -217,7 +241,7 @@ func (ldap *LdapService) BindService(rebind bool) error {
|
||||
if ldap.cert != nil {
|
||||
return ldap.conn.ExternalBind()
|
||||
}
|
||||
return ldap.conn.Bind(ldap.config.LDAP.BindDN, ldap.config.LDAP.BindPassword)
|
||||
return ldap.conn.Bind(ldap.config.LDAP.BindDN, ldap.bindPw)
|
||||
}
|
||||
|
||||
func (ldap *LdapService) Bind(userDN string, password string) error {
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||
"go.uber.org/dig"
|
||||
|
||||
"slices"
|
||||
|
||||
@@ -32,23 +33,27 @@ var presets = map[string]func(config model.OAuthServiceConfig, ctx context.Conte
|
||||
"google": newGoogleOAuthService,
|
||||
}
|
||||
|
||||
func NewOAuthBrokerService(
|
||||
log *logger.Logger,
|
||||
configs map[string]model.OAuthServiceConfig,
|
||||
ctx context.Context,
|
||||
) *OAuthBrokerService {
|
||||
type OAuthBrokerServiceInput struct {
|
||||
dig.In
|
||||
|
||||
Log *logger.Logger
|
||||
Runtime *model.RuntimeConfig
|
||||
Ctx context.Context
|
||||
}
|
||||
|
||||
func NewOAuthBrokerService(i OAuthBrokerServiceInput) *OAuthBrokerService {
|
||||
service := &OAuthBrokerService{
|
||||
log: log,
|
||||
log: i.Log,
|
||||
services: make(map[string]OAuthServiceImpl),
|
||||
configs: configs,
|
||||
configs: i.Runtime.OAuthProviders,
|
||||
}
|
||||
|
||||
for name, cfg := range configs {
|
||||
for name, cfg := range service.configs {
|
||||
if presetFunc, exists := presets[name]; exists {
|
||||
service.services[name] = presetFunc(cfg, ctx)
|
||||
service.services[name] = presetFunc(cfg, i.Ctx)
|
||||
service.log.App.Debug().Str("service", name).Msg("Loaded OAuth service from preset")
|
||||
} else {
|
||||
service.services[name] = NewOAuthService(cfg, name, ctx)
|
||||
service.services[name] = NewOAuthService(cfg, name, i.Ctx)
|
||||
service.log.App.Debug().Str("service", name).Msg("Loaded OAuth service from custom config")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -27,6 +28,7 @@ import (
|
||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||
"go.uber.org/dig"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -43,6 +45,15 @@ var (
|
||||
ErrInvalidClient = errors.New("invalid_client")
|
||||
)
|
||||
|
||||
type OIDCPrompt string
|
||||
|
||||
const (
|
||||
OIDCPromptLogin OIDCPrompt = "login"
|
||||
OIDCPromptNone OIDCPrompt = "none"
|
||||
)
|
||||
|
||||
var SupportedPrompts = []string{string(OIDCPromptLogin), string(OIDCPromptNone)}
|
||||
|
||||
// This is not spec-compliant, the ID token SHOULD NOT contain user info claims but,
|
||||
// it has became a "standard" and apps are looking for the claims in the ID tokens
|
||||
// instead of calling the userinfo endpoint, so we include them in the ID token as well
|
||||
@@ -53,6 +64,7 @@ type ClaimSet struct {
|
||||
Sub string `json:"sub"`
|
||||
Iat int64 `json:"iat"`
|
||||
Exp int64 `json:"exp"`
|
||||
AuthTime int64 `json:"auth_time,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
GivenName string `json:"given_name,omitempty"`
|
||||
FamilyName string `json:"family_name,omitempty"`
|
||||
@@ -108,7 +120,6 @@ type TokenResponse struct {
|
||||
}
|
||||
|
||||
type AuthorizeRequest struct {
|
||||
jwt.Claims
|
||||
Scope string `form:"scope" json:"scope" url:"scope"`
|
||||
ResponseType string `form:"response_type" json:"response_type" url:"response_type"`
|
||||
ClientID string `form:"client_id" json:"client_id" url:"client_id"`
|
||||
@@ -117,6 +128,8 @@ type AuthorizeRequest struct {
|
||||
Nonce string `form:"nonce" json:"nonce" url:"nonce"`
|
||||
CodeChallenge string `form:"code_challenge" json:"code_challenge" url:"code_challenge"`
|
||||
CodeChallengeMethod string `form:"code_challenge_method" json:"code_challenge_method" url:"code_challenge_method"`
|
||||
Prompt string `form:"prompt" json:"prompt" url:"prompt"`
|
||||
MaxAge string `form:"max_age" json:"max_age" url:"max_age"`
|
||||
}
|
||||
|
||||
type AuthorizeCodeEntry struct {
|
||||
@@ -127,6 +140,7 @@ type AuthorizeCodeEntry struct {
|
||||
Nonce string
|
||||
CodeChallenge string
|
||||
Userinfo UserinfoResponse
|
||||
AuthTime int64
|
||||
}
|
||||
|
||||
type UsedCodeEntry struct {
|
||||
@@ -135,8 +149,8 @@ type UsedCodeEntry struct {
|
||||
|
||||
type OIDCService struct {
|
||||
log *logger.Logger
|
||||
config model.Config
|
||||
runtime model.RuntimeConfig
|
||||
config *model.Config
|
||||
runtime *model.RuntimeConfig
|
||||
queries repository.Store
|
||||
|
||||
clients map[string]model.OIDCClientConfig
|
||||
@@ -151,19 +165,24 @@ type OIDCService struct {
|
||||
}
|
||||
}
|
||||
|
||||
func NewOIDCService(
|
||||
log *logger.Logger,
|
||||
config model.Config,
|
||||
runtime model.RuntimeConfig,
|
||||
queries repository.Store,
|
||||
dg *ding.Ding) (*OIDCService, error) {
|
||||
type OIDCServiceInput struct {
|
||||
dig.In
|
||||
|
||||
Log *logger.Logger
|
||||
Config *model.Config
|
||||
Runtime *model.RuntimeConfig
|
||||
Queries repository.Store
|
||||
Ding *ding.Ding
|
||||
}
|
||||
|
||||
func NewOIDCService(i OIDCServiceInput) (*OIDCService, error) {
|
||||
// If not configured, skip init
|
||||
if len(runtime.OIDCClients) == 0 {
|
||||
if len(i.Config.OIDC.Clients) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Ensure issuer is https
|
||||
uissuer, err := url.Parse(runtime.AppURL)
|
||||
uissuer, err := url.Parse(i.Runtime.AppURL)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse app url: %w", err)
|
||||
@@ -176,14 +195,14 @@ func NewOIDCService(
|
||||
issuer := fmt.Sprintf("%s://%s", uissuer.Scheme, uissuer.Host)
|
||||
|
||||
// Create/load private and public keys
|
||||
if strings.TrimSpace(config.OIDC.PrivateKeyPath) == "" ||
|
||||
strings.TrimSpace(config.OIDC.PublicKeyPath) == "" {
|
||||
if strings.TrimSpace(i.Config.OIDC.PrivateKeyPath) == "" ||
|
||||
strings.TrimSpace(i.Config.OIDC.PublicKeyPath) == "" {
|
||||
return nil, errors.New("private key path and public key path are required")
|
||||
}
|
||||
|
||||
var privateKey *rsa.PrivateKey
|
||||
|
||||
fprivateKey, err := os.ReadFile(config.OIDC.PrivateKeyPath)
|
||||
fprivateKey, err := os.ReadFile(i.Config.OIDC.PrivateKeyPath)
|
||||
|
||||
if err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
return nil, err
|
||||
@@ -202,8 +221,12 @@ func NewOIDCService(
|
||||
Type: "RSA PRIVATE KEY",
|
||||
Bytes: der,
|
||||
})
|
||||
log.App.Trace().Str("type", "RSA PRIVATE KEY").Msg("Generated private RSA key")
|
||||
err = os.WriteFile(config.OIDC.PrivateKeyPath, encoded, 0600)
|
||||
i.Log.App.Trace().Str("type", "RSA PRIVATE KEY").Msg("Generated private RSA key")
|
||||
err := os.MkdirAll(filepath.Dir(i.Config.OIDC.PrivateKeyPath), 0700)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create directory for private key: %w", err)
|
||||
}
|
||||
err = os.WriteFile(i.Config.OIDC.PrivateKeyPath, encoded, 0600)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to write private key to file: %w", err)
|
||||
}
|
||||
@@ -212,7 +235,7 @@ func NewOIDCService(
|
||||
if block == nil {
|
||||
return nil, errors.New("failed to decode private key")
|
||||
}
|
||||
log.App.Trace().Str("type", block.Type).Msg("Loaded private key")
|
||||
i.Log.App.Trace().Str("type", block.Type).Msg("Loaded private key")
|
||||
privateKey, err = x509.ParsePKCS1PrivateKey(block.Bytes)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse private key: %w", err)
|
||||
@@ -221,7 +244,7 @@ func NewOIDCService(
|
||||
|
||||
var publicKey crypto.PublicKey
|
||||
|
||||
fpublicKey, err := os.ReadFile(config.OIDC.PublicKeyPath)
|
||||
fpublicKey, err := os.ReadFile(i.Config.OIDC.PublicKeyPath)
|
||||
|
||||
if err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
return nil, fmt.Errorf("failed to read public key: %w", err)
|
||||
@@ -237,8 +260,12 @@ func NewOIDCService(
|
||||
Type: "RSA PUBLIC KEY",
|
||||
Bytes: der,
|
||||
})
|
||||
log.App.Trace().Str("type", "RSA PUBLIC KEY").Msg("Generated public RSA key")
|
||||
err = os.WriteFile(config.OIDC.PublicKeyPath, encoded, 0644)
|
||||
i.Log.App.Trace().Str("type", "RSA PUBLIC KEY").Msg("Generated public RSA key")
|
||||
err := os.MkdirAll(filepath.Dir(i.Config.OIDC.PublicKeyPath), 0700)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create directory for public key: %w", err)
|
||||
}
|
||||
err = os.WriteFile(i.Config.OIDC.PublicKeyPath, encoded, 0644)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -247,7 +274,7 @@ func NewOIDCService(
|
||||
if block == nil {
|
||||
return nil, errors.New("failed to decode public key")
|
||||
}
|
||||
log.App.Trace().Str("type", block.Type).Msg("Loaded public key")
|
||||
i.Log.App.Trace().Str("type", block.Type).Msg("Loaded public key")
|
||||
switch block.Type {
|
||||
case "RSA PUBLIC KEY":
|
||||
publicKey, err = x509.ParsePKCS1PublicKey(block.Bytes)
|
||||
@@ -277,7 +304,7 @@ func NewOIDCService(
|
||||
// We will reorganize the client into a map with the client ID as the key
|
||||
clients := make(map[string]model.OIDCClientConfig)
|
||||
|
||||
for id, client := range config.OIDC.Clients {
|
||||
for id, client := range i.Config.OIDC.Clients {
|
||||
client.ID = id
|
||||
if client.Name == "" {
|
||||
client.Name = utils.Capitalize(client.ID)
|
||||
@@ -293,15 +320,15 @@ func NewOIDCService(
|
||||
}
|
||||
client.ClientSecretFile = ""
|
||||
clients[id] = client
|
||||
log.App.Debug().Str("clientId", client.ClientID).Msg("Loaded OIDC client configuration")
|
||||
i.Log.App.Debug().Str("clientId", client.ClientID).Msg("Loaded OIDC client configuration")
|
||||
}
|
||||
|
||||
// Initialize the service
|
||||
service := &OIDCService{
|
||||
log: log,
|
||||
config: config,
|
||||
runtime: runtime,
|
||||
queries: queries,
|
||||
log: i.Log,
|
||||
config: i.Config,
|
||||
runtime: i.Runtime,
|
||||
queries: i.Queries,
|
||||
|
||||
clients: clients,
|
||||
privateKey: privateKey,
|
||||
@@ -310,7 +337,7 @@ func NewOIDCService(
|
||||
}
|
||||
|
||||
// Start cleanup routine
|
||||
dg.Go(service.cleanupRoutine, ding.RingMinor)
|
||||
i.Ding.Go(service.cleanupRoutine, ding.RingMinor)
|
||||
|
||||
// Create caches
|
||||
codeCash := NewCacheStore[AuthorizeCodeEntry](256)
|
||||
@@ -322,7 +349,7 @@ func NewOIDCService(
|
||||
service.caches.authorize = authorize
|
||||
|
||||
// Start cache cleanup routine
|
||||
dg.Go(func(ctx context.Context) {
|
||||
i.Ding.Go(func(ctx context.Context) {
|
||||
ticker := time.NewTicker(1 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
@@ -410,6 +437,7 @@ func (service *OIDCService) CreateCode(req AuthorizeRequest, userContext model.U
|
||||
ClientID: req.ClientID,
|
||||
Nonce: req.Nonce,
|
||||
Userinfo: service.userinfoFromContext(userContext, sub),
|
||||
AuthTime: userContext.AuthTime,
|
||||
}
|
||||
|
||||
if req.CodeChallenge != "" {
|
||||
@@ -499,7 +527,7 @@ func (service *OIDCService) GetCodeEntry(codeHash string, clientId string) (*Aut
|
||||
return &entry, true
|
||||
}
|
||||
|
||||
func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user UserinfoResponse, scope string, nonce string) (string, error) {
|
||||
func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user UserinfoResponse, scope string, nonce string, authTime *int64) (string, error) {
|
||||
createdAt := time.Now().Unix()
|
||||
expiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry) * time.Second).Unix()
|
||||
|
||||
@@ -544,6 +572,10 @@ func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user
|
||||
Nonce: nonce,
|
||||
}
|
||||
|
||||
if authTime != nil {
|
||||
claims.AuthTime = *authTime
|
||||
}
|
||||
|
||||
payload, err := json.Marshal(claims)
|
||||
|
||||
if err != nil {
|
||||
@@ -565,8 +597,8 @@ func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func (service *OIDCService) GenerateAccessToken(ctx context.Context, client model.OIDCClientConfig, codeEntry AuthorizeCodeEntry) (*TokenResponse, error) {
|
||||
idToken, err := service.generateIDToken(client, codeEntry.Userinfo, codeEntry.Scope, codeEntry.Nonce)
|
||||
func (service *OIDCService) GenerateAccessToken(ctx context.Context, client model.OIDCClientConfig, codeEntry AuthorizeCodeEntry, authTime int64) (*TokenResponse, error) {
|
||||
idToken, err := service.generateIDToken(client, codeEntry.Userinfo, codeEntry.Scope, codeEntry.Nonce, &authTime)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -645,9 +677,10 @@ func (service *OIDCService) RefreshAccessToken(ctx context.Context, refreshToken
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// TODO: store auth time in the database so we can include it in the new ID token, for now we omit it
|
||||
idToken, err := service.generateIDToken(model.OIDCClientConfig{
|
||||
ClientID: entry.ClientID,
|
||||
}, userInfo, entry.Scope, entry.Nonce)
|
||||
}, userInfo, entry.Scope, entry.Nonce, nil)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -889,21 +922,53 @@ func (service *OIDCService) DeleteAuthorizeRequestTicket(ticket string) {
|
||||
|
||||
// TODO: support signed request objects in the future
|
||||
func (service *OIDCService) DecodeAuthorizeJWT(tokenString string) (*AuthorizeRequest, error) {
|
||||
var req AuthorizeRequest
|
||||
|
||||
token, _, err := jwt.NewParser().ParseUnverified(tokenString, &req)
|
||||
var claims jwt.MapClaims
|
||||
|
||||
token, _, err := jwt.NewParser().ParseUnverified(tokenString, &claims)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse authorize request jwt: %w", err)
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(*AuthorizeRequest)
|
||||
alg, ok := token.Header["alg"].(string)
|
||||
|
||||
if !ok {
|
||||
return nil, errors.New("failed to parse claims from authorize request jwt")
|
||||
if !ok || alg != "none" || string(token.Signature) != "" {
|
||||
return nil, fmt.Errorf("only unsigned jwts are supported for authorize requests")
|
||||
}
|
||||
|
||||
return claims, nil
|
||||
get := func(k string) string {
|
||||
v, _ := claims[k].(string)
|
||||
return v
|
||||
}
|
||||
|
||||
return &AuthorizeRequest{
|
||||
Scope: get("scope"),
|
||||
ResponseType: get("response_type"),
|
||||
ClientID: get("client_id"),
|
||||
RedirectURI: get("redirect_uri"),
|
||||
State: get("state"),
|
||||
Nonce: get("nonce"),
|
||||
CodeChallenge: get("code_challenge"),
|
||||
CodeChallengeMethod: get("code_challenge_method"),
|
||||
Prompt: get("prompt"),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (service *OIDCService) GetPrompt(prompt string) []OIDCPrompt {
|
||||
if prompt == "" {
|
||||
return []OIDCPrompt{}
|
||||
}
|
||||
|
||||
parsedPromps := make([]OIDCPrompt, 0)
|
||||
prompts := strings.SplitSeq(prompt, " ")
|
||||
|
||||
for p := range prompts {
|
||||
if !slices.Contains(SupportedPrompts, p) {
|
||||
continue
|
||||
}
|
||||
parsedPromps = append(parsedPromps, OIDCPrompt(p))
|
||||
}
|
||||
|
||||
return parsedPromps
|
||||
}
|
||||
|
||||
func (service *OIDCService) CreateConsentEntry(ctx context.Context, clientId string, scope string) (string, error) {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package service_test
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -9,12 +9,12 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||
)
|
||||
|
||||
func newTestUser() service.UserinfoResponse {
|
||||
return service.UserinfoResponse{
|
||||
func newTestUser() UserinfoResponse {
|
||||
return UserinfoResponse{
|
||||
Sub: "test-sub",
|
||||
Name: "Test User",
|
||||
PreferredUsername: "testuser",
|
||||
@@ -67,21 +67,29 @@ func TestCompileUserinfo(t *testing.T) {
|
||||
ctx := context.TODO()
|
||||
dg := ding.New(ctx)
|
||||
|
||||
svc, err := service.NewOIDCService(log, cfg, runtime, nil, dg)
|
||||
store := memory.New()
|
||||
|
||||
svc, err := NewOIDCService(OIDCServiceInput{
|
||||
Log: log,
|
||||
Config: &cfg,
|
||||
Runtime: &runtime,
|
||||
Queries: store,
|
||||
Ding: dg,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
type testCase struct {
|
||||
description string
|
||||
mutate func(u *service.UserinfoResponse)
|
||||
mutate func(u *UserinfoResponse)
|
||||
scope string
|
||||
run func(t *testing.T, info service.UserinfoResponse)
|
||||
run func(t *testing.T, info UserinfoResponse)
|
||||
}
|
||||
|
||||
tests := []testCase{
|
||||
{
|
||||
description: "openid scope only returns sub and updated_at",
|
||||
scope: "openid",
|
||||
run: func(t *testing.T, info service.UserinfoResponse) {
|
||||
run: func(t *testing.T, info UserinfoResponse) {
|
||||
assert.Equal(t, "test-sub", info.Sub)
|
||||
assert.Equal(t, int64(1234567890), info.UpdatedAt)
|
||||
assert.Empty(t, info.Name)
|
||||
@@ -94,7 +102,7 @@ func TestCompileUserinfo(t *testing.T) {
|
||||
{
|
||||
description: "profile scope returns all profile fields",
|
||||
scope: "openid profile",
|
||||
run: func(t *testing.T, info service.UserinfoResponse) {
|
||||
run: func(t *testing.T, info UserinfoResponse) {
|
||||
assert.Equal(t, "Test User", info.Name)
|
||||
assert.Equal(t, "testuser", info.PreferredUsername)
|
||||
assert.Equal(t, "Test", info.GivenName)
|
||||
@@ -114,7 +122,7 @@ func TestCompileUserinfo(t *testing.T) {
|
||||
{
|
||||
description: "email scope sets email and email_verified true when email present",
|
||||
scope: "openid email",
|
||||
run: func(t *testing.T, info service.UserinfoResponse) {
|
||||
run: func(t *testing.T, info UserinfoResponse) {
|
||||
assert.Equal(t, "test@example.com", info.Email)
|
||||
assert.True(t, info.EmailVerified)
|
||||
assert.Empty(t, info.Name)
|
||||
@@ -123,8 +131,8 @@ func TestCompileUserinfo(t *testing.T) {
|
||||
{
|
||||
description: "email scope sets email_verified false when email absent",
|
||||
scope: "openid email",
|
||||
mutate: func(u *service.UserinfoResponse) { u.Email = "" },
|
||||
run: func(t *testing.T, info service.UserinfoResponse) {
|
||||
mutate: func(u *UserinfoResponse) { u.Email = "" },
|
||||
run: func(t *testing.T, info UserinfoResponse) {
|
||||
assert.Empty(t, info.Email)
|
||||
assert.False(t, info.EmailVerified)
|
||||
},
|
||||
@@ -132,7 +140,7 @@ func TestCompileUserinfo(t *testing.T) {
|
||||
{
|
||||
description: "phone scope sets phone_number_verified true when phone present",
|
||||
scope: "openid phone",
|
||||
run: func(t *testing.T, info service.UserinfoResponse) {
|
||||
run: func(t *testing.T, info UserinfoResponse) {
|
||||
assert.Equal(t, "+15555550100", info.PhoneNumber)
|
||||
require.NotNil(t, info.PhoneNumberVerified)
|
||||
assert.True(t, *info.PhoneNumberVerified)
|
||||
@@ -141,8 +149,8 @@ func TestCompileUserinfo(t *testing.T) {
|
||||
{
|
||||
description: "phone scope sets phone_number_verified false when phone absent",
|
||||
scope: "openid phone",
|
||||
mutate: func(u *service.UserinfoResponse) { u.PhoneNumber = "" },
|
||||
run: func(t *testing.T, info service.UserinfoResponse) {
|
||||
mutate: func(u *UserinfoResponse) { u.PhoneNumber = "" },
|
||||
run: func(t *testing.T, info UserinfoResponse) {
|
||||
require.NotNil(t, info.PhoneNumberVerified)
|
||||
assert.False(t, *info.PhoneNumberVerified)
|
||||
},
|
||||
@@ -150,7 +158,7 @@ func TestCompileUserinfo(t *testing.T) {
|
||||
{
|
||||
description: "address scope returns parsed address",
|
||||
scope: "openid address",
|
||||
run: func(t *testing.T, info service.UserinfoResponse) {
|
||||
run: func(t *testing.T, info UserinfoResponse) {
|
||||
require.NotNil(t, info.Address)
|
||||
assert.Equal(t, "123 Main St", info.Address.Formatted)
|
||||
assert.Equal(t, "123 Main St", info.Address.StreetAddress)
|
||||
@@ -163,14 +171,14 @@ func TestCompileUserinfo(t *testing.T) {
|
||||
{
|
||||
description: "groups scope returns split groups",
|
||||
scope: "openid groups",
|
||||
run: func(t *testing.T, info service.UserinfoResponse) {
|
||||
run: func(t *testing.T, info UserinfoResponse) {
|
||||
assert.Equal(t, []string{"admins", "users"}, info.Groups)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "all scopes return all fields",
|
||||
scope: "openid profile email phone address groups",
|
||||
run: func(t *testing.T, info service.UserinfoResponse) {
|
||||
run: func(t *testing.T, info UserinfoResponse) {
|
||||
assert.Equal(t, "Test User", info.Name)
|
||||
assert.Equal(t, "test@example.com", info.Email)
|
||||
assert.Equal(t, "+15555550100", info.PhoneNumber)
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||
"go.uber.org/dig"
|
||||
)
|
||||
|
||||
type Policy string
|
||||
@@ -40,21 +41,28 @@ type PolicyEngine struct {
|
||||
policy Policy
|
||||
}
|
||||
|
||||
func NewPolicyEngine(config model.Config, log *logger.Logger) (*PolicyEngine, error) {
|
||||
type PolicyEngineInput struct {
|
||||
dig.In
|
||||
|
||||
Log *logger.Logger
|
||||
Config *model.Config
|
||||
}
|
||||
|
||||
func NewPolicyEngine(i PolicyEngineInput) (*PolicyEngine, error) {
|
||||
engine := PolicyEngine{
|
||||
log: log,
|
||||
log: i.Log,
|
||||
rules: make(map[RuleName]Rule),
|
||||
}
|
||||
|
||||
switch config.Auth.ACLs.Policy {
|
||||
switch i.Config.Auth.ACLs.Policy {
|
||||
case string(PolicyAllow):
|
||||
log.App.Debug().Msg("Using 'allow' ACL policy: access to apps will be allowed by default unless explicitly blocked")
|
||||
i.Log.App.Debug().Msg("Using 'allow' ACL policy: access to apps will be allowed by default unless explicitly blocked")
|
||||
engine.policy = PolicyAllow
|
||||
case string(PolicyDeny):
|
||||
log.App.Debug().Msg("Using 'deny' ACL policy: access to apps will be blocked by default unless explicitly allowed")
|
||||
i.Log.App.Debug().Msg("Using 'deny' ACL policy: access to apps will be blocked by default unless explicitly allowed")
|
||||
engine.policy = PolicyDeny
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid acl policy: %s", config.Auth.ACLs.Policy)
|
||||
return nil, fmt.Errorf("invalid acl policy: %s", i.Config.Auth.ACLs.Policy)
|
||||
}
|
||||
|
||||
return &engine, nil
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
package service_test
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||
"github.com/tinyauthapp/tinyauth/internal/test"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||
)
|
||||
@@ -12,14 +11,14 @@ import (
|
||||
// Create test rule
|
||||
type TestRule struct{}
|
||||
|
||||
func (rule *TestRule) Evaluate(ctx *service.ACLContext) service.Effect {
|
||||
func (rule *TestRule) Evaluate(ctx *ACLContext) Effect {
|
||||
switch ctx.Path {
|
||||
case "/allowed":
|
||||
return service.EffectAllow
|
||||
return EffectAllow
|
||||
case "/denied":
|
||||
return service.EffectDeny
|
||||
return EffectDeny
|
||||
default:
|
||||
return service.EffectAbstain
|
||||
return EffectAbstain
|
||||
}
|
||||
}
|
||||
|
||||
@@ -33,36 +32,51 @@ func TestPolicyEngine(t *testing.T) {
|
||||
|
||||
// Engine should fail with invalid policy
|
||||
cfg.Auth.ACLs.Policy = "invalid_policy"
|
||||
_, err := service.NewPolicyEngine(cfg, log)
|
||||
_, err := NewPolicyEngine(PolicyEngineInput{
|
||||
Log: log,
|
||||
Config: &cfg,
|
||||
})
|
||||
assert.Error(t, err)
|
||||
|
||||
// Engine should initialize with 'allow' policy
|
||||
cfg.Auth.ACLs.Policy = string(service.PolicyAllow)
|
||||
engine, err := service.NewPolicyEngine(cfg, log)
|
||||
cfg.Auth.ACLs.Policy = string(PolicyAllow)
|
||||
engine, err := NewPolicyEngine(PolicyEngineInput{
|
||||
Log: log,
|
||||
Config: &cfg,
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, service.PolicyAllow, engine.Policy())
|
||||
assert.Equal(t, PolicyAllow, engine.Policy())
|
||||
|
||||
// Engine should initialize with 'deny' policy
|
||||
cfg.Auth.ACLs.Policy = string(service.PolicyDeny)
|
||||
engine, err = service.NewPolicyEngine(cfg, log)
|
||||
cfg.Auth.ACLs.Policy = string(PolicyDeny)
|
||||
engine, err = NewPolicyEngine(PolicyEngineInput{
|
||||
Log: log,
|
||||
Config: &cfg,
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, service.PolicyDeny, engine.Policy())
|
||||
assert.Equal(t, PolicyDeny, engine.Policy())
|
||||
|
||||
// Engine should allow adding rules
|
||||
engine, err = service.NewPolicyEngine(cfg, log)
|
||||
engine, err = NewPolicyEngine(PolicyEngineInput{
|
||||
Log: log,
|
||||
Config: &cfg,
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
engine.RegisterRule("test-rule", testRule)
|
||||
_, ok := engine.Rules()["test-rule"]
|
||||
assert.True(t, ok)
|
||||
|
||||
// Begin allow policy tests
|
||||
cfg.Auth.ACLs.Policy = string(service.PolicyAllow)
|
||||
engine, err = service.NewPolicyEngine(cfg, log)
|
||||
cfg.Auth.ACLs.Policy = string(PolicyAllow)
|
||||
engine, err = NewPolicyEngine(PolicyEngineInput{
|
||||
Log: log,
|
||||
Config: &cfg,
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
engine.RegisterRule("test-rule", testRule)
|
||||
|
||||
// With allow policy, if rule allows, access should be allowed
|
||||
ctx := &service.ACLContext{Path: "/allowed"}
|
||||
ctx := &ACLContext{Path: "/allowed"}
|
||||
assert.Equal(t, true, engine.Evaluate("test-rule", ctx))
|
||||
|
||||
// With allow policy, if rule denies, access should be denied
|
||||
@@ -74,8 +88,11 @@ func TestPolicyEngine(t *testing.T) {
|
||||
assert.Equal(t, true, engine.Evaluate("test-rule", ctx))
|
||||
|
||||
// Begin deny policy tests
|
||||
cfg.Auth.ACLs.Policy = string(service.PolicyDeny)
|
||||
engine, err = service.NewPolicyEngine(cfg, log)
|
||||
cfg.Auth.ACLs.Policy = string(PolicyDeny)
|
||||
engine, err = NewPolicyEngine(PolicyEngineInput{
|
||||
Log: log,
|
||||
Config: &cfg,
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
engine.RegisterRule("test-rule", testRule)
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"github.com/steveiliop56/ding"
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||
"go.uber.org/dig"
|
||||
"tailscale.com/client/local"
|
||||
"tailscale.com/tsnet"
|
||||
)
|
||||
@@ -25,7 +26,7 @@ type TailscaleWhoisResponse struct {
|
||||
|
||||
type TailscaleService struct {
|
||||
log *logger.Logger
|
||||
config model.Config
|
||||
config *model.Config
|
||||
ctx context.Context
|
||||
|
||||
srv *tsnet.Server
|
||||
@@ -34,22 +35,31 @@ type TailscaleService struct {
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func NewTailscaleService(log *logger.Logger, config model.Config, ctx context.Context, dg *ding.Ding) (*TailscaleService, error) {
|
||||
if !config.Tailscale.Enabled {
|
||||
type TailscaleServiceInput struct {
|
||||
dig.In
|
||||
|
||||
Log *logger.Logger
|
||||
Config *model.Config
|
||||
Ctx context.Context
|
||||
Ding *ding.Ding
|
||||
}
|
||||
|
||||
func NewTailscaleService(i TailscaleServiceInput) (*TailscaleService, error) {
|
||||
if !i.Config.Tailscale.Enabled {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
srv := new(tsnet.Server)
|
||||
|
||||
// node options
|
||||
srv.Dir = config.Tailscale.Dir
|
||||
srv.Hostname = config.Tailscale.Hostname
|
||||
srv.AuthKey = config.Tailscale.AuthKey
|
||||
srv.Ephemeral = config.Tailscale.Ephemeral
|
||||
srv.Dir = i.Config.Tailscale.Dir
|
||||
srv.Hostname = i.Config.Tailscale.Hostname
|
||||
srv.AuthKey = i.Config.Tailscale.AuthKey
|
||||
srv.Ephemeral = i.Config.Tailscale.Ephemeral
|
||||
|
||||
// redirect logs to zerolog
|
||||
srv.Logf = log.App.Printf
|
||||
srv.UserLogf = log.App.Printf
|
||||
srv.Logf = i.Log.App.Printf
|
||||
srv.UserLogf = i.Log.App.Printf
|
||||
|
||||
err := srv.Start()
|
||||
|
||||
@@ -65,14 +75,14 @@ func NewTailscaleService(log *logger.Logger, config model.Config, ctx context.Co
|
||||
}
|
||||
|
||||
service := &TailscaleService{
|
||||
log: log,
|
||||
config: config,
|
||||
ctx: ctx,
|
||||
log: i.Log,
|
||||
config: i.Config,
|
||||
ctx: i.Ctx,
|
||||
srv: srv,
|
||||
lc: lc,
|
||||
}
|
||||
|
||||
connectCtx, cancel := context.WithTimeout(ctx, 2*time.Minute) // large enough timeout to allow for user to manually authenticate with link if needed
|
||||
connectCtx, cancel := context.WithTimeout(i.Ctx, 2*time.Minute) // large enough timeout to allow for user to manually authenticate with link if needed
|
||||
defer cancel()
|
||||
|
||||
err = service.waitForConn(connectCtx)
|
||||
@@ -82,7 +92,7 @@ func NewTailscaleService(log *logger.Logger, config model.Config, ctx context.Co
|
||||
return nil, fmt.Errorf("failed to connect to tailscale network: %w", err)
|
||||
}
|
||||
|
||||
dg.Go(service.watchAndClose, ding.RingMajor)
|
||||
i.Ding.Go(service.watchAndClose, ding.RingMajor)
|
||||
|
||||
return service, nil
|
||||
}
|
||||
@@ -128,8 +138,6 @@ func (ts *TailscaleService) Whois(ctx context.Context, addr string) (*TailscaleW
|
||||
NodeName: strings.TrimSuffix(who.Node.Name, "."),
|
||||
}
|
||||
|
||||
ts.log.App.Debug().Interface("res", res).Msg("tailscale")
|
||||
|
||||
return &res, nil
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user