Merge branch 'main' into feat/ldap-reconnect

This commit is contained in:
Stavros
2026-06-28 18:08:12 +03:00
81 changed files with 4306 additions and 2279 deletions
+17 -11
View File
@@ -5,6 +5,7 @@ import (
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
"go.uber.org/dig"
)
type LabelProvider interface {
@@ -13,19 +14,24 @@ type LabelProvider interface {
type AccessControlsService struct {
log *logger.Logger
config model.Config
labelProvider *LabelProvider
config *model.Config
labelProvider LabelProvider
}
func NewAccessControlsService(
log *logger.Logger,
config model.Config,
labelProvider *LabelProvider) *AccessControlsService {
type AccessControlServiceInput struct {
dig.In
Log *logger.Logger
Config *model.Config
LabelProvider LabelProvider `optional:"true"`
}
func NewAccessControlsService(i AccessControlServiceInput) *AccessControlsService {
return &AccessControlsService{
log: log,
config: config,
labelProvider: labelProvider,
log: i.Log,
config: i.Config,
labelProvider: i.LabelProvider,
}
}
@@ -57,8 +63,8 @@ func (service *AccessControlsService) GetAccessControls(domain string) (*model.A
}
// If we have a label provider configured, try to get ACLs from it
if service.labelProvider != nil && *service.labelProvider != nil {
return (*service.labelProvider).GetLabels(domain)
if service.labelProvider != nil {
return service.labelProvider.GetLabels(domain)
}
// no labels
@@ -87,7 +87,11 @@ func TestLookupStaticACLs(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
svc := NewAccessControlsService(log, model.Config{Apps: tt.apps}, nil)
svc := NewAccessControlsService(AccessControlServiceInput{
Log: log,
Config: &model.Config{Apps: tt.apps},
LabelProvider: nil,
})
got := svc.lookupStaticACLs(tt.domain)
if tt.expectNil {
assert.Nil(t, got)
@@ -112,7 +116,11 @@ func TestGetAccessControls(t *testing.T) {
},
},
}
svc := NewAccessControlsService(log, config, nil)
svc := NewAccessControlsService(AccessControlServiceInput{
Log: log,
Config: &config,
LabelProvider: nil,
})
got, err := svc.GetAccessControls("foo.example.com")
@@ -123,7 +131,11 @@ func TestGetAccessControls(t *testing.T) {
})
t.Run("returns nil when no static match and no label provider", func(t *testing.T) {
svc := NewAccessControlsService(log, model.Config{}, nil)
svc := NewAccessControlsService(AccessControlServiceInput{
Log: log,
Config: &model.Config{},
LabelProvider: nil,
})
got, err := svc.GetAccessControls("unknown.example.com")
@@ -133,7 +145,11 @@ func TestGetAccessControls(t *testing.T) {
t.Run("returns nil when label provider pointer wraps a nil interface", func(t *testing.T) {
var provider LabelProvider
svc := NewAccessControlsService(log, model.Config{}, &provider)
svc := NewAccessControlsService(AccessControlServiceInput{
Log: log,
Config: &model.Config{},
LabelProvider: provider, // nil provider
})
got, err := svc.GetAccessControls("unknown.example.com")
@@ -152,7 +168,11 @@ func TestGetAccessControls(t *testing.T) {
},
}
var provider LabelProvider = mock
svc := NewAccessControlsService(log, model.Config{}, &provider)
svc := NewAccessControlsService(AccessControlServiceInput{
Log: log,
Config: &model.Config{},
LabelProvider: provider,
})
got, err := svc.GetAccessControls("dynamic.example.com")
@@ -170,7 +190,11 @@ func TestGetAccessControls(t *testing.T) {
"foo": {Config: model.AppConfig{Domain: "foo.example.com"}},
},
}
svc := NewAccessControlsService(log, config, &provider)
svc := NewAccessControlsService(AccessControlServiceInput{
Log: log,
Config: &config,
LabelProvider: provider,
})
got, err := svc.GetAccessControls("foo.example.com")
@@ -188,7 +212,11 @@ func TestGetAccessControls(t *testing.T) {
},
}
var provider LabelProvider = mock
svc := NewAccessControlsService(log, model.Config{}, &provider)
svc := NewAccessControlsService(AccessControlServiceInput{
Log: log,
Config: &model.Config{},
LabelProvider: provider,
})
got, err := svc.GetAccessControls("dynamic.example.com")
+108 -80
View File
@@ -2,8 +2,10 @@ package service
import (
"context"
"crypto/rand"
"errors"
"fmt"
"math/big"
"net/http"
"strings"
"sync"
@@ -14,6 +16,7 @@ import (
"github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
"go.uber.org/dig"
"github.com/google/uuid"
"golang.org/x/crypto/bcrypt"
@@ -24,32 +27,28 @@ import (
// but for now these are just safety limits to prevent unbounded memory usage
const MaxOAuthPendingSessions = 256
const OAuthCleanupCount = 16
const MaxLoginAttemptRecords = 256
var (
ErrUserNotFound = errors.New("user not found")
)
// slightly modified version of the AuthorizeRequest from the OIDC service to basically accept all
// parameters and pass them to the authorize page if needed
type OAuthURLParams struct {
Scope string `form:"scope" url:"scope"`
ResponseType string `form:"response_type" url:"response_type"`
ClientID string `form:"client_id" url:"client_id"`
RedirectURI string `form:"redirect_uri" url:"redirect_uri"`
State string `form:"state" url:"state"`
Nonce string `form:"nonce" url:"nonce"`
CodeChallenge string `form:"code_challenge" url:"code_challenge"`
CodeChallengeMethod string `form:"code_challenge_method" url:"code_challenge_method"`
// We either store params for redirecting to an app after OAuth login,
// or for redirecting back to the authorize screen to continue OIDC
type OAuthCallbackParams struct {
LoginFor string `form:"login_for" url:"login_for"`
OIDCTicket string `form:"oidc_ticket" url:"oidc_ticket"`
OIDCScope string `form:"oidc_scope" url:"oidc_scope"`
OIDCName string `form:"oidc_name" url:"oidc_name"`
RedirectURI string `form:"redirect_uri" url:"redirect_uri"`
}
type OAuthPendingSession struct {
State string
Verifier string
Token *oauth2.Token
Service *OAuthServiceImpl
Service IOAuthService
ExpiresAt time.Time
CallbackParams OAuthURLParams
CallbackParams OAuthCallbackParams
}
type LoginAttempt struct {
@@ -60,8 +59,8 @@ type LoginAttempt struct {
type AuthService struct {
log *logger.Logger
config model.Config
runtime model.RuntimeConfig
config *model.Config
runtime *model.RuntimeConfig
ctx context.Context
ldap *LdapService
@@ -83,42 +82,57 @@ type AuthService struct {
oauth *CacheStore[OAuthPendingSession]
ldap *CacheStore[[]string]
}
maxLoginLimits int
}
func NewAuthService(
log *logger.Logger,
config model.Config,
runtime model.RuntimeConfig,
ctx context.Context,
dg *ding.Ding,
ldap *LdapService,
queries repository.Store,
oauthBroker *OAuthBrokerService,
tailscale *TailscaleService,
policy *PolicyEngine,
) *AuthService {
type AuthServiceInput struct {
dig.In
Log *logger.Logger
Config *model.Config
Runtime *model.RuntimeConfig
Ctx context.Context
Ding *ding.Ding
LDAP *LdapService `optional:"true"`
Queries repository.Store
OAuthBroker *OAuthBrokerService
Tailscale *TailscaleService `optional:"true"`
PolicyEngine *PolicyEngine
}
func NewAuthService(i AuthServiceInput) *AuthService {
service := &AuthService{
log: log,
runtime: runtime,
ctx: ctx,
config: config,
ldap: ldap,
queries: queries,
oauthBroker: oauthBroker,
tailscale: tailscale,
policyEngine: policy,
log: i.Log,
runtime: i.Runtime,
ctx: i.Ctx,
config: i.Config,
ldap: i.LDAP,
queries: i.Queries,
oauthBroker: i.OAuthBroker,
tailscale: i.Tailscale,
policyEngine: i.PolicyEngine,
}
// get the max login limits based on the number of users and the configured max retries
service.maxLoginLimits = service.calculateLockdownLimit()
loginCacheSize := 0
if !service.config.Auth.LockdownEnabled {
loginCacheSize = service.maxLoginLimits
}
// caches setup
oauthCache := NewCacheStore[OAuthPendingSession](256)
loginCache := NewCacheStore[LoginAttempt](1024)
loginCache := NewCacheStore[LoginAttempt](loginCacheSize)
ldapCache := NewCacheStore[[]string](1024)
service.caches.oauth = oauthCache
service.caches.login = loginCache
service.caches.ldap = ldapCache
dg.Go(func(ctx context.Context) {
i.Ding.Go(func(ctx context.Context) {
ticker := time.NewTicker(1 * time.Minute)
defer ticker.Stop()
@@ -257,7 +271,7 @@ func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) {
return
}
if auth.caches.login.Size() >= MaxLoginAttemptRecords {
if !success && auth.config.Auth.LockdownEnabled && auth.caches.login.Size() >= auth.maxLoginLimits {
if locked, _ := auth.IsInLockdown(); locked {
return
}
@@ -366,33 +380,11 @@ func (auth *AuthService) CreateSession(ctx context.Context, data repository.Sess
return nil, fmt.Errorf("failed to create session entry: %w", err)
}
if data.Provider == "tailscale" {
auth.log.App.Trace().Str("url", fmt.Sprintf("https://%s", auth.tailscale.GetHostname())).Msg("Extracting root domain from Tailscale hostname")
tsCookieDomain, err := utils.GetCookieDomain(fmt.Sprintf("https://%s", auth.tailscale.GetHostname()))
if err != nil {
return nil, fmt.Errorf("failed to get cookie domain for tailscale user: %w", err)
}
return &http.Cookie{
Name: auth.runtime.SessionCookieName,
Value: session.UUID,
Path: "/",
Domain: fmt.Sprintf(".%s", tsCookieDomain),
Expires: expiresAt,
MaxAge: int(time.Until(expiresAt).Seconds()),
Secure: auth.config.Auth.SecureCookie,
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
}, nil
}
return &http.Cookie{
Name: auth.runtime.SessionCookieName,
Value: session.UUID,
Path: "/",
Domain: fmt.Sprintf(".%s", auth.runtime.CookieDomain),
Domain: auth.getCookieDomain(),
Expires: expiresAt,
MaxAge: int(time.Until(expiresAt).Seconds()),
Secure: auth.config.Auth.SecureCookie,
@@ -445,7 +437,7 @@ func (auth *AuthService) RefreshSession(ctx context.Context, uuid string) (*http
Name: auth.runtime.SessionCookieName,
Value: session.UUID,
Path: "/",
Domain: fmt.Sprintf(".%s", auth.runtime.CookieDomain),
Domain: auth.getCookieDomain(),
Expires: time.Now().Add(time.Duration(newExpiry-currentTime) * time.Second),
MaxAge: int(newExpiry - currentTime),
Secure: auth.config.Auth.SecureCookie,
@@ -466,7 +458,7 @@ func (auth *AuthService) DeleteSession(ctx context.Context, uuid string) (*http.
Name: auth.runtime.SessionCookieName,
Value: "",
Path: "/",
Domain: fmt.Sprintf(".%s", auth.runtime.CookieDomain),
Domain: auth.getCookieDomain(),
Expires: time.Now(),
MaxAge: -1,
Secure: auth.config.Auth.SecureCookie,
@@ -516,17 +508,17 @@ func (auth *AuthService) LDAPAuthConfigured() bool {
return auth.ldap != nil
}
func (auth *AuthService) NewOAuthSession(serviceName string, params OAuthURLParams) (string, OAuthPendingSession, error) {
func (auth *AuthService) NewOAuthSession(serviceName string, params OAuthCallbackParams) (string, error) {
service, ok := auth.oauthBroker.GetService(serviceName)
if !ok {
return "", OAuthPendingSession{}, fmt.Errorf("oauth service not found: %s", serviceName)
return "", fmt.Errorf("oauth service not found: %s", serviceName)
}
sessionId, err := uuid.NewRandom()
if err != nil {
return "", OAuthPendingSession{}, fmt.Errorf("failed to generate session ID: %w", err)
return "", fmt.Errorf("failed to generate session ID: %w", err)
}
state := service.NewRandom()
@@ -535,14 +527,14 @@ func (auth *AuthService) NewOAuthSession(serviceName string, params OAuthURLPara
session := OAuthPendingSession{
State: state,
Verifier: verifier,
Service: &service,
Service: service,
ExpiresAt: time.Now().Add(1 * time.Hour),
CallbackParams: params,
}
auth.caches.oauth.Set(sessionId.String(), session, time.Minute*10)
return sessionId.String(), session, nil
return sessionId.String(), nil
}
func (auth *AuthService) GetOAuthURL(sessionId string) (string, error) {
@@ -552,7 +544,7 @@ func (auth *AuthService) GetOAuthURL(sessionId string) (string, error) {
return "", err
}
return (*session.Service).GetAuthURL(session.State, session.Verifier), nil
return session.Service.GetAuthURL(session.State, session.Verifier), nil
}
func (auth *AuthService) GetOAuthToken(sessionId string, code string) (*oauth2.Token, error) {
@@ -562,7 +554,7 @@ func (auth *AuthService) GetOAuthToken(sessionId string, code string) (*oauth2.T
return nil, fmt.Errorf("oauth session not found: %s", sessionId)
}
token, err := (*session.Service).GetToken(code, session.Verifier)
token, err := session.Service.GetToken(code, session.Verifier)
if err != nil {
return nil, fmt.Errorf("failed to exchange code for token: %w", err)
@@ -591,7 +583,7 @@ func (auth *AuthService) GetOAuthUserinfo(sessionId string) (*model.Claims, erro
return nil, fmt.Errorf("oauth token not found for session: %s", sessionId)
}
userinfo, err := (*session.Service).GetUserinfo(session.Token)
userinfo, err := session.Service.GetUserinfo(session.Token)
if err != nil {
return nil, fmt.Errorf("failed to get userinfo: %w", err)
@@ -600,14 +592,14 @@ func (auth *AuthService) GetOAuthUserinfo(sessionId string) (*model.Claims, erro
return userinfo, nil
}
func (auth *AuthService) GetOAuthService(sessionId string) (OAuthServiceImpl, error) {
func (auth *AuthService) GetOAuthService(sessionId string) (IOAuthService, error) {
session, err := auth.GetOAuthPendingSession(sessionId)
if err != nil {
return nil, err
}
return *session.Service, nil
return session.Service, nil
}
func (auth *AuthService) EndOAuthSession(sessionId string) {
@@ -632,16 +624,17 @@ func (auth *AuthService) lockdownMode() {
return
}
ctx, cancel := context.WithCancel(context.Background())
ctx, cancel := context.WithCancel(auth.ctx)
auth.log.App.Warn().Msg("Too many failed login attempts, entering lockdown mode")
auth.lockdown.active = true
auth.lockdown.ctx = ctx
auth.lockdown.cancelFunc = cancel
auth.lockdown.until = time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second)
timer := time.NewTimer(time.Until(auth.lockdown.until))
d := time.Duration(auth.config.Auth.LoginTimeout) * time.Second
auth.lockdown.until = time.Now().Add(d)
timer := time.NewTimer(d)
auth.lockdown.mu.Unlock()
@@ -653,14 +646,13 @@ func (auth *AuthService) lockdownMode() {
// Timer expired, end lockdown
case <-ctx.Done():
// Context cancelled, end lockdown
case <-auth.ctx.Done():
// Service is shutting down, end lockdown
}
auth.lockdown.mu.Lock()
auth.log.App.Info().Msg("Exiting lockdown mode")
auth.caches.login.Clear()
auth.lockdown.active = false
auth.lockdown.until = time.Time{}
auth.lockdown.ctx = nil
@@ -683,3 +675,39 @@ func (auth *AuthService) IsInLockdown() (bool, int) {
func (auth *AuthService) ClearLoginAttempts() {
auth.caches.login.Clear()
}
func (auth *AuthService) calculateLockdownLimit() int {
userCount := len(auth.runtime.LocalUsers)
if auth.ldap != nil {
ldapUsers, err := auth.ldap.GetUserCount()
if err != nil {
auth.log.App.Warn().Err(err).Msg("Failed to get LDAP user count")
} else {
userCount += ldapUsers
}
}
limit := userCount * auth.config.Auth.LoginMaxRetries
jitter, err := rand.Int(rand.Reader, big.NewInt(64))
if err != nil {
auth.log.App.Warn().Err(err).Msg("Failed to generate jitter for lockdown limit")
} else {
limit += int(jitter.Int64())
}
if limit < 256 {
limit = 256
}
return limit
}
func (auth *AuthService) getCookieDomain() string {
if !auth.config.Auth.SubdomainsEnabled {
return ""
}
return auth.runtime.CookieDomain
}
+16 -1
View File
@@ -4,6 +4,7 @@ import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
)
@@ -12,9 +13,22 @@ func TestIsEmailWhitelistedUsesProviderSpecificList(t *testing.T) {
log := logger.NewLogger().WithTestConfig()
log.Init()
policyEngine, err := NewPolicyEngine(PolicyEngineInput{
Log: log,
Config: &model.Config{
Auth: model.AuthConfig{
ACLs: model.ACLsConfig{
Policy: string(PolicyAllow),
},
},
},
})
require.NoError(t, err)
auth := &AuthService{
log: log,
runtime: model.RuntimeConfig{
runtime: &model.RuntimeConfig{
OAuthWhitelist: []string{"global@example.com"},
OAuthProviders: map[string]model.OAuthServiceConfig{
"github": {
@@ -28,6 +42,7 @@ func TestIsEmailWhitelistedUsesProviderSpecificList(t *testing.T) {
},
},
},
policyEngine: policyEngine,
}
assert.True(t, auth.IsEmailWhitelisted("github", "github@example.com"))
+16 -11
View File
@@ -8,6 +8,7 @@ import (
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/decoders"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
"go.uber.org/dig"
container "github.com/docker/docker/api/types/container"
"github.com/docker/docker/client"
@@ -21,36 +22,40 @@ type DockerService struct {
isConnected bool
}
func NewDockerService(
log *logger.Logger,
ctx context.Context,
dg *ding.Ding,
) (*DockerService, error) {
type DockerServiceInput struct {
dig.In
Log *logger.Logger
Ctx context.Context
Ding *ding.Ding
}
func NewDockerService(i DockerServiceInput) (*DockerService, error) {
client, err := client.NewClientWithOpts(client.FromEnv)
if err != nil {
return nil, err
}
client.NegotiateAPIVersion(ctx)
client.NegotiateAPIVersion(i.Ctx)
_, err = client.Ping(ctx)
_, err = client.Ping(i.Ctx)
if err != nil {
log.App.Debug().Err(err).Msg("Docker not connected")
i.Log.App.Debug().Err(err).Msg("Docker not connected")
return nil, nil
}
service := &DockerService{
log: log,
log: i.Log,
client: client,
context: ctx,
context: i.Ctx,
}
service.isConnected = true
service.log.App.Debug().Msg("Docker connected successfully")
dg.Go(service.watchAndClose, ding.RingMajor)
i.Ding.Go(service.watchAndClose, ding.RingMajor)
return service, nil
}
+16 -11
View File
@@ -12,6 +12,7 @@ import (
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/decoders"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
"go.uber.org/dig"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
@@ -48,11 +49,15 @@ type KubernetesService struct {
appNameIndex map[string]ingressAppKey
}
func NewKubernetesService(
log *logger.Logger,
ctx context.Context,
dg *ding.Ding,
) (*KubernetesService, error) {
type KubernetesServiceInput struct {
dig.In
Log *logger.Logger
Ctx context.Context
Ding *ding.Ding
}
func NewKubernetesService(i KubernetesServiceInput) (*KubernetesService, error) {
cfg, err := rest.InClusterConfig()
if err != nil {
return nil, fmt.Errorf("failed to get in-cluster kubernetes config: %w", err)
@@ -69,31 +74,31 @@ func NewKubernetesService(
Resource: "ingresses",
}
accessCtx, accessCancel := context.WithTimeout(ctx, 5*time.Second)
accessCtx, accessCancel := context.WithTimeout(i.Ctx, 5*time.Second)
defer accessCancel()
_, err = client.Resource(gvr).List(accessCtx, metav1.ListOptions{Limit: 1})
if err != nil {
log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to access Ingress API, Kubernetes label provider will be disabled")
i.Log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to access Ingress API, Kubernetes label provider will be disabled")
return nil, fmt.Errorf("failed to access ingress api: %w", err)
}
log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Successfully accessed Ingress API, starting watcher")
i.Log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Successfully accessed Ingress API, starting watcher")
service := &KubernetesService{
log: log,
log: i.Log,
client: client,
ingressApps: make(map[ingressKey][]ingressApp),
domainIndex: make(map[string]ingressAppKey),
appNameIndex: make(map[string]ingressAppKey),
}
dg.Go(func(ctx context.Context) {
i.Ding.Go(func(ctx context.Context) {
service.watchGVR(gvr, ctx)
}, ding.RingMajor)
service.started = true
log.App.Debug().Msg("Kubernetes label provider started successfully")
i.Log.App.Debug().Msg("Kubernetes label provider started successfully")
return service, nil
}
+48 -19
View File
@@ -11,44 +11,53 @@ import (
ldapgo "github.com/go-ldap/ldap/v3"
"github.com/steveiliop56/ding"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
"go.uber.org/dig"
)
type LdapService struct {
log *logger.Logger
config model.Config
ctx context.Context
config *model.Config
conn *ldapgo.Conn
mutex sync.RWMutex
cert *tls.Certificate
conn *ldapgo.Conn
mutex sync.RWMutex
cert *tls.Certificate
bindPw string
}
func NewLdapService(
log *logger.Logger,
config model.Config,
ctx context.Context,
dg *ding.Ding,
) (*LdapService, error) {
if config.LDAP.Address == "" {
type LdapServiceInput struct {
dig.In
Log *logger.Logger
Config *model.Config
Ding *ding.Ding
Ctx context.Context
}
func NewLdapService(i LdapServiceInput) (*LdapService, error) {
if i.Config.LDAP.Address == "" {
return nil, nil
}
ldap := &LdapService{
log: log,
config: config,
ctx: ctx,
log: i.Log,
config: i.Config,
ctx: i.Ctx,
}
ldap.bindPw = utils.GetSecret(i.Config.LDAP.BindPassword, i.Config.LDAP.BindPasswordFile)
// Check whether authentication with client certificate is possible
if config.LDAP.AuthCert != "" && config.LDAP.AuthKey != "" {
cert, err := tls.LoadX509KeyPair(config.LDAP.AuthCert, config.LDAP.AuthKey)
if i.Config.LDAP.AuthCert != "" && i.Config.LDAP.AuthKey != "" {
cert, err := tls.LoadX509KeyPair(i.Config.LDAP.AuthCert, i.Config.LDAP.AuthKey)
if err != nil {
return nil, fmt.Errorf("failed to initialize LDAP with mTLS authentication: %w", err)
}
log.App.Info().Msg("LDAP mTLS authentication configured successfully")
i.Log.App.Info().Msg("LDAP mTLS authentication configured successfully")
ldap.cert = &cert
@@ -72,7 +81,7 @@ func NewLdapService(
return nil, fmt.Errorf("failed to connect to ldap server: %w", err)
}
dg.Go(func(ctx context.Context) {
i.Ding.Go(func(ctx context.Context) {
ldap.log.App.Debug().Msg("Starting LDAP connection heartbeat routine")
ticker := time.NewTicker(5 * time.Minute)
@@ -165,6 +174,26 @@ func (ldap *LdapService) GetUserInfo(username string) (dn string, email string,
return entry.DN, entry.GetAttributeValue("mail"), nil
}
func (ldap *LdapService) GetUserCount() (int, error) {
searchRequest := ldapgo.NewSearchRequest(
ldap.config.LDAP.BaseDN,
ldapgo.ScopeWholeSubtree, ldapgo.NeverDerefAliases, 0, 0, false,
"(objectClass=person)",
[]string{"dn"},
nil,
)
ldap.mutex.Lock()
defer ldap.mutex.Unlock()
searchResult, err := ldap.conn.Search(searchRequest)
if err != nil {
return 0, err
}
return len(searchResult.Entries), nil
}
func (ldap *LdapService) GetUserGroups(userDN string) ([]string, error) {
escapedUserDN := ldapgo.EscapeFilter(userDN)
@@ -217,7 +246,7 @@ func (ldap *LdapService) BindService(rebind bool) error {
if ldap.cert != nil {
return ldap.conn.ExternalBind()
}
return ldap.conn.Bind(ldap.config.LDAP.BindDN, ldap.config.LDAP.BindPassword)
return ldap.conn.Bind(ldap.config.LDAP.BindDN, ldap.bindPw)
}
func (ldap *LdapService) Bind(userDN string, password string) error {
+23 -16
View File
@@ -5,25 +5,28 @@ import (
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
"go.uber.org/dig"
"slices"
"golang.org/x/oauth2"
)
type OAuthServiceImpl interface {
type IOAuthService interface {
Name() string
ID() string
NewRandom() string
GetAuthURL(state string, verifier string) string
GetToken(code string, verifier string) (*oauth2.Token, error)
GetAuthURL(state, verifier string) string
GetToken(code, verifier string) (*oauth2.Token, error)
GetUserinfo(token *oauth2.Token) (*model.Claims, error)
GetConfig() model.OAuthServiceConfig
UpdateConfig(config model.OAuthServiceConfig)
}
type OAuthBrokerService struct {
log *logger.Logger
services map[string]OAuthServiceImpl
services map[string]IOAuthService
configs map[string]model.OAuthServiceConfig
}
@@ -32,23 +35,27 @@ var presets = map[string]func(config model.OAuthServiceConfig, ctx context.Conte
"google": newGoogleOAuthService,
}
func NewOAuthBrokerService(
log *logger.Logger,
configs map[string]model.OAuthServiceConfig,
ctx context.Context,
) *OAuthBrokerService {
type OAuthBrokerServiceInput struct {
dig.In
Log *logger.Logger
Runtime *model.RuntimeConfig
Ctx context.Context
}
func NewOAuthBrokerService(i OAuthBrokerServiceInput) *OAuthBrokerService {
service := &OAuthBrokerService{
log: log,
services: make(map[string]OAuthServiceImpl),
configs: configs,
log: i.Log,
services: make(map[string]IOAuthService),
configs: i.Runtime.OAuthProviders,
}
for name, cfg := range configs {
for name, cfg := range service.configs {
if presetFunc, exists := presets[name]; exists {
service.services[name] = presetFunc(cfg, ctx)
service.services[name] = presetFunc(cfg, i.Ctx)
service.log.App.Debug().Str("service", name).Msg("Loaded OAuth service from preset")
} else {
service.services[name] = NewOAuthService(cfg, name, ctx)
service.services[name] = NewOAuthService(cfg, name, i.Ctx)
service.log.App.Debug().Str("service", name).Msg("Loaded OAuth service from custom config")
}
}
@@ -65,7 +72,7 @@ func (broker *OAuthBrokerService) GetConfiguredServices() []string {
return services
}
func (broker *OAuthBrokerService) GetService(name string) (OAuthServiceImpl, bool) {
func (broker *OAuthBrokerService) GetService(name string) (IOAuthService, bool) {
service, exists := broker.services[name]
return service, exists
}
+15 -1
View File
@@ -70,7 +70,7 @@ func (s *OAuthService) NewRandom() string {
return random
}
func (s *OAuthService) GetAuthURL(state string, verifier string) string {
func (s *OAuthService) GetAuthURL(state, verifier string) string {
return s.config.AuthCodeURL(state, oauth2.AccessTypeOnline, oauth2.S256ChallengeOption(verifier))
}
@@ -82,3 +82,17 @@ func (s *OAuthService) GetUserinfo(token *oauth2.Token) (*model.Claims, error) {
client := oauth2.NewClient(s.ctx, oauth2.StaticTokenSource(token))
return s.userinfoExtractor(client, s.serviceCfg.UserinfoURL)
}
func (s *OAuthService) GetConfig() model.OAuthServiceConfig {
return s.serviceCfg
}
func (s *OAuthService) UpdateConfig(config model.OAuthServiceConfig) {
s.serviceCfg = config
s.config.ClientID = config.ClientID
s.config.ClientSecret = config.ClientSecret
s.config.Scopes = config.Scopes
s.config.Endpoint.AuthURL = config.AuthURL
s.config.Endpoint.TokenURL = config.TokenURL
s.config.RedirectURL = config.RedirectURL
}
+155 -42
View File
@@ -14,17 +14,20 @@ import (
"fmt"
"net/url"
"os"
"path/filepath"
"strings"
"time"
"slices"
"github.com/go-jose/go-jose/v4"
"github.com/golang-jwt/jwt/v5"
"github.com/steveiliop56/ding"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
"go.uber.org/dig"
)
var (
@@ -41,6 +44,15 @@ var (
ErrInvalidClient = errors.New("invalid_client")
)
type OIDCPrompt string
const (
OIDCPromptLogin OIDCPrompt = "login"
OIDCPromptNone OIDCPrompt = "none"
)
var SupportedPrompts = []string{string(OIDCPromptLogin), string(OIDCPromptNone)}
// This is not spec-compliant, the ID token SHOULD NOT contain user info claims but,
// it has became a "standard" and apps are looking for the claims in the ID tokens
// instead of calling the userinfo endpoint, so we include them in the ID token as well
@@ -51,6 +63,7 @@ type ClaimSet struct {
Sub string `json:"sub"`
Iat int64 `json:"iat"`
Exp int64 `json:"exp"`
AuthTime int64 `json:"auth_time,omitempty"`
Name string `json:"name,omitempty"`
GivenName string `json:"given_name,omitempty"`
FamilyName string `json:"family_name,omitempty"`
@@ -106,14 +119,16 @@ type TokenResponse struct {
}
type AuthorizeRequest struct {
Scope string `json:"scope" binding:"required"`
ResponseType string `json:"response_type" binding:"required"`
ClientID string `json:"client_id" binding:"required"`
RedirectURI string `json:"redirect_uri" binding:"required"`
State string `json:"state"`
Nonce string `json:"nonce"`
CodeChallenge string `json:"code_challenge"`
CodeChallengeMethod string `json:"code_challenge_method"`
Scope string `form:"scope" json:"scope" url:"scope"`
ResponseType string `form:"response_type" json:"response_type" url:"response_type"`
ClientID string `form:"client_id" json:"client_id" url:"client_id"`
RedirectURI string `form:"redirect_uri" json:"redirect_uri" url:"redirect_uri"`
State string `form:"state" json:"state" url:"state"`
Nonce string `form:"nonce" json:"nonce" url:"nonce"`
CodeChallenge string `form:"code_challenge" json:"code_challenge" url:"code_challenge"`
CodeChallengeMethod string `form:"code_challenge_method" json:"code_challenge_method" url:"code_challenge_method"`
Prompt string `form:"prompt" json:"prompt" url:"prompt"`
MaxAge string `form:"max_age" json:"max_age" url:"max_age"`
}
type AuthorizeCodeEntry struct {
@@ -124,6 +139,7 @@ type AuthorizeCodeEntry struct {
Nonce string
CodeChallenge string
Userinfo UserinfoResponse
AuthTime int64
}
type UsedCodeEntry struct {
@@ -132,8 +148,8 @@ type UsedCodeEntry struct {
type OIDCService struct {
log *logger.Logger
config model.Config
runtime model.RuntimeConfig
config *model.Config
runtime *model.RuntimeConfig
queries repository.Store
clients map[string]model.OIDCClientConfig
@@ -142,24 +158,30 @@ type OIDCService struct {
issuer string
caches struct {
code *CacheStore[AuthorizeCodeEntry]
usedCode *CacheStore[UsedCodeEntry]
code *CacheStore[AuthorizeCodeEntry]
usedCode *CacheStore[UsedCodeEntry]
authorize *CacheStore[AuthorizeRequest]
}
}
func NewOIDCService(
log *logger.Logger,
config model.Config,
runtime model.RuntimeConfig,
queries repository.Store,
dg *ding.Ding) (*OIDCService, error) {
type OIDCServiceInput struct {
dig.In
Log *logger.Logger
Config *model.Config
Runtime *model.RuntimeConfig
Queries repository.Store
Ding *ding.Ding
}
func NewOIDCService(i OIDCServiceInput) (*OIDCService, error) {
// If not configured, skip init
if len(runtime.OIDCClients) == 0 {
if len(i.Config.OIDC.Clients) == 0 {
return nil, nil
}
// Ensure issuer is https
uissuer, err := url.Parse(runtime.AppURL)
uissuer, err := url.Parse(i.Runtime.AppURL)
if err != nil {
return nil, fmt.Errorf("failed to parse app url: %w", err)
@@ -172,14 +194,14 @@ func NewOIDCService(
issuer := fmt.Sprintf("%s://%s", uissuer.Scheme, uissuer.Host)
// Create/load private and public keys
if strings.TrimSpace(config.OIDC.PrivateKeyPath) == "" ||
strings.TrimSpace(config.OIDC.PublicKeyPath) == "" {
if strings.TrimSpace(i.Config.OIDC.PrivateKeyPath) == "" ||
strings.TrimSpace(i.Config.OIDC.PublicKeyPath) == "" {
return nil, errors.New("private key path and public key path are required")
}
var privateKey *rsa.PrivateKey
fprivateKey, err := os.ReadFile(config.OIDC.PrivateKeyPath)
fprivateKey, err := os.ReadFile(i.Config.OIDC.PrivateKeyPath)
if err != nil && !errors.Is(err, os.ErrNotExist) {
return nil, err
@@ -198,8 +220,12 @@ func NewOIDCService(
Type: "RSA PRIVATE KEY",
Bytes: der,
})
log.App.Trace().Str("type", "RSA PRIVATE KEY").Msg("Generated private RSA key")
err = os.WriteFile(config.OIDC.PrivateKeyPath, encoded, 0600)
i.Log.App.Trace().Str("type", "RSA PRIVATE KEY").Msg("Generated private RSA key")
err := os.MkdirAll(filepath.Dir(i.Config.OIDC.PrivateKeyPath), 0700)
if err != nil {
return nil, fmt.Errorf("failed to create directory for private key: %w", err)
}
err = os.WriteFile(i.Config.OIDC.PrivateKeyPath, encoded, 0600)
if err != nil {
return nil, fmt.Errorf("failed to write private key to file: %w", err)
}
@@ -208,7 +234,7 @@ func NewOIDCService(
if block == nil {
return nil, errors.New("failed to decode private key")
}
log.App.Trace().Str("type", block.Type).Msg("Loaded private key")
i.Log.App.Trace().Str("type", block.Type).Msg("Loaded private key")
privateKey, err = x509.ParsePKCS1PrivateKey(block.Bytes)
if err != nil {
return nil, fmt.Errorf("failed to parse private key: %w", err)
@@ -217,7 +243,7 @@ func NewOIDCService(
var publicKey crypto.PublicKey
fpublicKey, err := os.ReadFile(config.OIDC.PublicKeyPath)
fpublicKey, err := os.ReadFile(i.Config.OIDC.PublicKeyPath)
if err != nil && !errors.Is(err, os.ErrNotExist) {
return nil, fmt.Errorf("failed to read public key: %w", err)
@@ -233,8 +259,12 @@ func NewOIDCService(
Type: "RSA PUBLIC KEY",
Bytes: der,
})
log.App.Trace().Str("type", "RSA PUBLIC KEY").Msg("Generated public RSA key")
err = os.WriteFile(config.OIDC.PublicKeyPath, encoded, 0644)
i.Log.App.Trace().Str("type", "RSA PUBLIC KEY").Msg("Generated public RSA key")
err := os.MkdirAll(filepath.Dir(i.Config.OIDC.PublicKeyPath), 0700)
if err != nil {
return nil, fmt.Errorf("failed to create directory for public key: %w", err)
}
err = os.WriteFile(i.Config.OIDC.PublicKeyPath, encoded, 0644)
if err != nil {
return nil, err
}
@@ -243,7 +273,7 @@ func NewOIDCService(
if block == nil {
return nil, errors.New("failed to decode public key")
}
log.App.Trace().Str("type", block.Type).Msg("Loaded public key")
i.Log.App.Trace().Str("type", block.Type).Msg("Loaded public key")
switch block.Type {
case "RSA PUBLIC KEY":
publicKey, err = x509.ParsePKCS1PublicKey(block.Bytes)
@@ -273,7 +303,7 @@ func NewOIDCService(
// We will reorganize the client into a map with the client ID as the key
clients := make(map[string]model.OIDCClientConfig)
for id, client := range config.OIDC.Clients {
for id, client := range i.Config.OIDC.Clients {
client.ID = id
if client.Name == "" {
client.Name = utils.Capitalize(client.ID)
@@ -289,15 +319,15 @@ func NewOIDCService(
}
client.ClientSecretFile = ""
clients[id] = client
log.App.Debug().Str("clientId", client.ClientID).Msg("Loaded OIDC client configuration")
i.Log.App.Debug().Str("clientId", client.ClientID).Msg("Loaded OIDC client configuration")
}
// Initialize the service
service := &OIDCService{
log: log,
config: config,
runtime: runtime,
queries: queries,
log: i.Log,
config: i.Config,
runtime: i.Runtime,
queries: i.Queries,
clients: clients,
privateKey: privateKey,
@@ -306,16 +336,19 @@ func NewOIDCService(
}
// Start cleanup routine
dg.Go(service.cleanupRoutine, ding.RingMinor)
i.Ding.Go(service.cleanupRoutine, ding.RingMinor)
// Create caches
codeCash := NewCacheStore[AuthorizeCodeEntry](256)
usedCode := NewCacheStore[UsedCodeEntry](256)
authorize := NewCacheStore[AuthorizeRequest](256)
service.caches.code = codeCash
service.caches.usedCode = usedCode
service.caches.authorize = authorize
// Start cache cleanup routine
dg.Go(func(ctx context.Context) {
i.Ding.Go(func(ctx context.Context) {
ticker := time.NewTicker(1 * time.Minute)
defer ticker.Stop()
@@ -324,6 +357,7 @@ func NewOIDCService(
case <-ticker.C:
service.caches.code.Sweep()
service.caches.usedCode.Sweep()
service.caches.authorize.Sweep()
case <-ctx.Done():
return
}
@@ -402,6 +436,7 @@ func (service *OIDCService) CreateCode(req AuthorizeRequest, userContext model.U
ClientID: req.ClientID,
Nonce: req.Nonce,
Userinfo: service.userinfoFromContext(userContext, sub),
AuthTime: userContext.AuthTime,
}
if req.CodeChallenge != "" {
@@ -491,7 +526,7 @@ func (service *OIDCService) GetCodeEntry(codeHash string, clientId string) (*Aut
return &entry, true
}
func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user UserinfoResponse, scope string, nonce string) (string, error) {
func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user UserinfoResponse, scope string, nonce string, authTime *int64) (string, error) {
createdAt := time.Now().Unix()
expiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry) * time.Second).Unix()
@@ -536,6 +571,10 @@ func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user
Nonce: nonce,
}
if authTime != nil {
claims.AuthTime = *authTime
}
payload, err := json.Marshal(claims)
if err != nil {
@@ -557,8 +596,8 @@ func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user
return token, nil
}
func (service *OIDCService) GenerateAccessToken(ctx context.Context, client model.OIDCClientConfig, codeEntry AuthorizeCodeEntry) (*TokenResponse, error) {
idToken, err := service.generateIDToken(client, codeEntry.Userinfo, codeEntry.Scope, codeEntry.Nonce)
func (service *OIDCService) GenerateAccessToken(ctx context.Context, client model.OIDCClientConfig, codeEntry AuthorizeCodeEntry, authTime int64) (*TokenResponse, error) {
idToken, err := service.generateIDToken(client, codeEntry.Userinfo, codeEntry.Scope, codeEntry.Nonce, &authTime)
if err != nil {
return nil, err
@@ -637,9 +676,10 @@ func (service *OIDCService) RefreshAccessToken(ctx context.Context, refreshToken
return nil, err
}
// TODO: store auth time in the database so we can include it in the new ID token, for now we omit it
idToken, err := service.generateIDToken(model.OIDCClientConfig{
ClientID: entry.ClientID,
}, userInfo, entry.Scope, entry.Nonce)
}, userInfo, entry.Scope, entry.Nonce, nil)
if err != nil {
return nil, err
@@ -856,3 +896,76 @@ func (service *OIDCService) MarkCodeAsUsed(codeHash string, sub string) {
func (service *OIDCService) DeleteSessionBySub(ctx context.Context, sub string) error {
return service.queries.DeleteOIDCSessionBySub(ctx, sub)
}
func (service *OIDCService) CreateAuthorizeRequestTicket(req AuthorizeRequest) string {
ticket := utils.GenerateString(32)
service.caches.authorize.Set(ticket, req, 10*time.Minute)
return ticket
}
func (service *OIDCService) GetAuthorizeRequestByTicket(ticket string) (*AuthorizeRequest, bool) {
entry, ok := service.caches.authorize.Get(ticket)
if !ok {
return nil, false
}
return &entry, true
}
func (service *OIDCService) DeleteAuthorizeRequestTicket(ticket string) {
service.caches.authorize.Delete(ticket)
}
// TODO: support signed request objects in the future
func (service *OIDCService) DecodeAuthorizeJWT(tokenString string) (*AuthorizeRequest, error) {
var claims jwt.MapClaims
token, _, err := jwt.NewParser().ParseUnverified(tokenString, &claims)
if err != nil {
return nil, fmt.Errorf("failed to parse authorize request jwt: %w", err)
}
alg, ok := token.Header["alg"].(string)
if !ok || alg != "none" || string(token.Signature) != "" {
return nil, fmt.Errorf("only unsigned jwts are supported for authorize requests")
}
get := func(k string) string {
v, _ := claims[k].(string)
return v
}
return &AuthorizeRequest{
Scope: get("scope"),
ResponseType: get("response_type"),
ClientID: get("client_id"),
RedirectURI: get("redirect_uri"),
State: get("state"),
Nonce: get("nonce"),
CodeChallenge: get("code_challenge"),
CodeChallengeMethod: get("code_challenge_method"),
Prompt: get("prompt"),
}, nil
}
func (service *OIDCService) GetPrompt(prompt string) []OIDCPrompt {
if prompt == "" {
return []OIDCPrompt{}
}
parsedPromps := make([]OIDCPrompt, 0)
prompts := strings.SplitSeq(prompt, " ")
for p := range prompts {
if !slices.Contains(SupportedPrompts, p) {
continue
}
parsedPromps = append(parsedPromps, OIDCPrompt(p))
}
return parsedPromps
}
+26 -18
View File
@@ -1,4 +1,4 @@
package service_test
package service
import (
"context"
@@ -9,12 +9,12 @@ import (
"github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
)
func newTestUser() service.UserinfoResponse {
return service.UserinfoResponse{
func newTestUser() UserinfoResponse {
return UserinfoResponse{
Sub: "test-sub",
Name: "Test User",
PreferredUsername: "testuser",
@@ -67,21 +67,29 @@ func TestCompileUserinfo(t *testing.T) {
ctx := context.TODO()
dg := ding.New(ctx)
svc, err := service.NewOIDCService(log, cfg, runtime, nil, dg)
store := memory.New()
svc, err := NewOIDCService(OIDCServiceInput{
Log: log,
Config: &cfg,
Runtime: &runtime,
Queries: store,
Ding: dg,
})
require.NoError(t, err)
type testCase struct {
description string
mutate func(u *service.UserinfoResponse)
mutate func(u *UserinfoResponse)
scope string
run func(t *testing.T, info service.UserinfoResponse)
run func(t *testing.T, info UserinfoResponse)
}
tests := []testCase{
{
description: "openid scope only returns sub and updated_at",
scope: "openid",
run: func(t *testing.T, info service.UserinfoResponse) {
run: func(t *testing.T, info UserinfoResponse) {
assert.Equal(t, "test-sub", info.Sub)
assert.Equal(t, int64(1234567890), info.UpdatedAt)
assert.Empty(t, info.Name)
@@ -94,7 +102,7 @@ func TestCompileUserinfo(t *testing.T) {
{
description: "profile scope returns all profile fields",
scope: "openid profile",
run: func(t *testing.T, info service.UserinfoResponse) {
run: func(t *testing.T, info UserinfoResponse) {
assert.Equal(t, "Test User", info.Name)
assert.Equal(t, "testuser", info.PreferredUsername)
assert.Equal(t, "Test", info.GivenName)
@@ -114,7 +122,7 @@ func TestCompileUserinfo(t *testing.T) {
{
description: "email scope sets email and email_verified true when email present",
scope: "openid email",
run: func(t *testing.T, info service.UserinfoResponse) {
run: func(t *testing.T, info UserinfoResponse) {
assert.Equal(t, "test@example.com", info.Email)
assert.True(t, info.EmailVerified)
assert.Empty(t, info.Name)
@@ -123,8 +131,8 @@ func TestCompileUserinfo(t *testing.T) {
{
description: "email scope sets email_verified false when email absent",
scope: "openid email",
mutate: func(u *service.UserinfoResponse) { u.Email = "" },
run: func(t *testing.T, info service.UserinfoResponse) {
mutate: func(u *UserinfoResponse) { u.Email = "" },
run: func(t *testing.T, info UserinfoResponse) {
assert.Empty(t, info.Email)
assert.False(t, info.EmailVerified)
},
@@ -132,7 +140,7 @@ func TestCompileUserinfo(t *testing.T) {
{
description: "phone scope sets phone_number_verified true when phone present",
scope: "openid phone",
run: func(t *testing.T, info service.UserinfoResponse) {
run: func(t *testing.T, info UserinfoResponse) {
assert.Equal(t, "+15555550100", info.PhoneNumber)
require.NotNil(t, info.PhoneNumberVerified)
assert.True(t, *info.PhoneNumberVerified)
@@ -141,8 +149,8 @@ func TestCompileUserinfo(t *testing.T) {
{
description: "phone scope sets phone_number_verified false when phone absent",
scope: "openid phone",
mutate: func(u *service.UserinfoResponse) { u.PhoneNumber = "" },
run: func(t *testing.T, info service.UserinfoResponse) {
mutate: func(u *UserinfoResponse) { u.PhoneNumber = "" },
run: func(t *testing.T, info UserinfoResponse) {
require.NotNil(t, info.PhoneNumberVerified)
assert.False(t, *info.PhoneNumberVerified)
},
@@ -150,7 +158,7 @@ func TestCompileUserinfo(t *testing.T) {
{
description: "address scope returns parsed address",
scope: "openid address",
run: func(t *testing.T, info service.UserinfoResponse) {
run: func(t *testing.T, info UserinfoResponse) {
require.NotNil(t, info.Address)
assert.Equal(t, "123 Main St", info.Address.Formatted)
assert.Equal(t, "123 Main St", info.Address.StreetAddress)
@@ -163,14 +171,14 @@ func TestCompileUserinfo(t *testing.T) {
{
description: "groups scope returns split groups",
scope: "openid groups",
run: func(t *testing.T, info service.UserinfoResponse) {
run: func(t *testing.T, info UserinfoResponse) {
assert.Equal(t, []string{"admins", "users"}, info.Groups)
},
},
{
description: "all scopes return all fields",
scope: "openid profile email phone address groups",
run: func(t *testing.T, info service.UserinfoResponse) {
run: func(t *testing.T, info UserinfoResponse) {
assert.Equal(t, "Test User", info.Name)
assert.Equal(t, "test@example.com", info.Email)
assert.Equal(t, "+15555550100", info.PhoneNumber)
+14 -6
View File
@@ -6,6 +6,7 @@ import (
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
"go.uber.org/dig"
)
type Policy string
@@ -40,21 +41,28 @@ type PolicyEngine struct {
policy Policy
}
func NewPolicyEngine(config model.Config, log *logger.Logger) (*PolicyEngine, error) {
type PolicyEngineInput struct {
dig.In
Log *logger.Logger
Config *model.Config
}
func NewPolicyEngine(i PolicyEngineInput) (*PolicyEngine, error) {
engine := PolicyEngine{
log: log,
log: i.Log,
rules: make(map[RuleName]Rule),
}
switch config.Auth.ACLs.Policy {
switch i.Config.Auth.ACLs.Policy {
case string(PolicyAllow):
log.App.Debug().Msg("Using 'allow' ACL policy: access to apps will be allowed by default unless explicitly blocked")
i.Log.App.Debug().Msg("Using 'allow' ACL policy: access to apps will be allowed by default unless explicitly blocked")
engine.policy = PolicyAllow
case string(PolicyDeny):
log.App.Debug().Msg("Using 'deny' ACL policy: access to apps will be blocked by default unless explicitly allowed")
i.Log.App.Debug().Msg("Using 'deny' ACL policy: access to apps will be blocked by default unless explicitly allowed")
engine.policy = PolicyDeny
default:
return nil, fmt.Errorf("invalid acl policy: %s", config.Auth.ACLs.Policy)
return nil, fmt.Errorf("invalid acl policy: %s", i.Config.Auth.ACLs.Policy)
}
return &engine, nil
+36 -19
View File
@@ -1,10 +1,9 @@
package service_test
package service
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/test"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
)
@@ -12,14 +11,14 @@ import (
// Create test rule
type TestRule struct{}
func (rule *TestRule) Evaluate(ctx *service.ACLContext) service.Effect {
func (rule *TestRule) Evaluate(ctx *ACLContext) Effect {
switch ctx.Path {
case "/allowed":
return service.EffectAllow
return EffectAllow
case "/denied":
return service.EffectDeny
return EffectDeny
default:
return service.EffectAbstain
return EffectAbstain
}
}
@@ -33,36 +32,51 @@ func TestPolicyEngine(t *testing.T) {
// Engine should fail with invalid policy
cfg.Auth.ACLs.Policy = "invalid_policy"
_, err := service.NewPolicyEngine(cfg, log)
_, err := NewPolicyEngine(PolicyEngineInput{
Log: log,
Config: &cfg,
})
assert.Error(t, err)
// Engine should initialize with 'allow' policy
cfg.Auth.ACLs.Policy = string(service.PolicyAllow)
engine, err := service.NewPolicyEngine(cfg, log)
cfg.Auth.ACLs.Policy = string(PolicyAllow)
engine, err := NewPolicyEngine(PolicyEngineInput{
Log: log,
Config: &cfg,
})
assert.NoError(t, err)
assert.Equal(t, service.PolicyAllow, engine.Policy())
assert.Equal(t, PolicyAllow, engine.Policy())
// Engine should initialize with 'deny' policy
cfg.Auth.ACLs.Policy = string(service.PolicyDeny)
engine, err = service.NewPolicyEngine(cfg, log)
cfg.Auth.ACLs.Policy = string(PolicyDeny)
engine, err = NewPolicyEngine(PolicyEngineInput{
Log: log,
Config: &cfg,
})
assert.NoError(t, err)
assert.Equal(t, service.PolicyDeny, engine.Policy())
assert.Equal(t, PolicyDeny, engine.Policy())
// Engine should allow adding rules
engine, err = service.NewPolicyEngine(cfg, log)
engine, err = NewPolicyEngine(PolicyEngineInput{
Log: log,
Config: &cfg,
})
assert.NoError(t, err)
engine.RegisterRule("test-rule", testRule)
_, ok := engine.Rules()["test-rule"]
assert.True(t, ok)
// Begin allow policy tests
cfg.Auth.ACLs.Policy = string(service.PolicyAllow)
engine, err = service.NewPolicyEngine(cfg, log)
cfg.Auth.ACLs.Policy = string(PolicyAllow)
engine, err = NewPolicyEngine(PolicyEngineInput{
Log: log,
Config: &cfg,
})
assert.NoError(t, err)
engine.RegisterRule("test-rule", testRule)
// With allow policy, if rule allows, access should be allowed
ctx := &service.ACLContext{Path: "/allowed"}
ctx := &ACLContext{Path: "/allowed"}
assert.Equal(t, true, engine.Evaluate("test-rule", ctx))
// With allow policy, if rule denies, access should be denied
@@ -74,8 +88,11 @@ func TestPolicyEngine(t *testing.T) {
assert.Equal(t, true, engine.Evaluate("test-rule", ctx))
// Begin deny policy tests
cfg.Auth.ACLs.Policy = string(service.PolicyDeny)
engine, err = service.NewPolicyEngine(cfg, log)
cfg.Auth.ACLs.Policy = string(PolicyDeny)
engine, err = NewPolicyEngine(PolicyEngineInput{
Log: log,
Config: &cfg,
})
assert.NoError(t, err)
engine.RegisterRule("test-rule", testRule)
+38 -16
View File
@@ -12,6 +12,7 @@ import (
"github.com/steveiliop56/ding"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
"go.uber.org/dig"
"tailscale.com/client/local"
"tailscale.com/tsnet"
)
@@ -25,7 +26,7 @@ type TailscaleWhoisResponse struct {
type TailscaleService struct {
log *logger.Logger
config model.Config
config *model.Config
ctx context.Context
srv *tsnet.Server
@@ -34,22 +35,31 @@ type TailscaleService struct {
mu sync.Mutex
}
func NewTailscaleService(log *logger.Logger, config model.Config, ctx context.Context, dg *ding.Ding) (*TailscaleService, error) {
if !config.Tailscale.Enabled {
type TailscaleServiceInput struct {
dig.In
Log *logger.Logger
Config *model.Config
Ctx context.Context
Ding *ding.Ding
}
func NewTailscaleService(i TailscaleServiceInput) (*TailscaleService, error) {
if !i.Config.Tailscale.Enabled {
return nil, nil
}
srv := new(tsnet.Server)
// node options
srv.Dir = config.Tailscale.Dir
srv.Hostname = config.Tailscale.Hostname
srv.AuthKey = config.Tailscale.AuthKey
srv.Ephemeral = config.Tailscale.Ephemeral
srv.Dir = i.Config.Tailscale.Dir
srv.Hostname = i.Config.Tailscale.Hostname
srv.AuthKey = i.Config.Tailscale.AuthKey
srv.Ephemeral = i.Config.Tailscale.Ephemeral
// redirect logs to zerolog
srv.Logf = log.App.Printf
srv.UserLogf = log.App.Printf
srv.Logf = i.Log.App.Printf
srv.UserLogf = i.Log.App.Printf
err := srv.Start()
@@ -65,14 +75,14 @@ func NewTailscaleService(log *logger.Logger, config model.Config, ctx context.Co
}
service := &TailscaleService{
log: log,
config: config,
ctx: ctx,
log: i.Log,
config: i.Config,
ctx: i.Ctx,
srv: srv,
lc: lc,
}
connectCtx, cancel := context.WithTimeout(ctx, 2*time.Minute) // large enough timeout to allow for user to manually authenticate with link if needed
connectCtx, cancel := context.WithTimeout(i.Ctx, 2*time.Minute) // large enough timeout to allow for user to manually authenticate with link if needed
defer cancel()
err = service.waitForConn(connectCtx)
@@ -82,7 +92,11 @@ func NewTailscaleService(log *logger.Logger, config model.Config, ctx context.Co
return nil, fmt.Errorf("failed to connect to tailscale network: %w", err)
}
dg.Go(service.watchAndClose, ding.RingMajor)
i.Ding.Go(service.watchAndClose, ding.RingMajor)
if i.Config.Tailscale.Funnel && !i.Config.Tailscale.Listen {
service.log.App.Warn().Msg("Tailscale Funnel is enabled but listen is disabled. Funnel will not work without listen enabled.")
}
return service, nil
}
@@ -128,8 +142,6 @@ func (ts *TailscaleService) Whois(ctx context.Context, addr string) (*TailscaleW
NodeName: strings.TrimSuffix(who.Node.Name, "."),
}
ts.log.App.Debug().Interface("res", res).Msg("tailscale")
return &res, nil
}
@@ -140,6 +152,16 @@ func (ts *TailscaleService) CreateListener() (net.Listener, error) {
if ts.ln != nil {
return *ts.ln, nil
}
if ts.config.Tailscale.Funnel {
ln, err := ts.srv.ListenFunnel("tcp", ":443")
if err != nil {
return nil, err
}
ts.ln = &ln
return ln, nil
}
ln, err := ts.srv.ListenTLS("tcp", ":443")
if err != nil {
return nil, err