Merge branch 'main' into feat/tailscale

This commit is contained in:
Stavros
2026-05-10 16:44:39 +03:00
81 changed files with 4474 additions and 2655 deletions
+37 -26
View File
@@ -1,54 +1,65 @@
package service
import (
"errors"
"strings"
"github.com/tinyauthapp/tinyauth/internal/config"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
)
type AccessControlsService struct {
docker *DockerService
static map[string]config.App
type LabelProvider interface {
GetLabels(appDomain string) (*model.App, error)
}
func NewAccessControlsService(docker *DockerService, static map[string]config.App) *AccessControlsService {
type AccessControlsService struct {
log *logger.Logger
labelProvider *LabelProvider
static map[string]model.App
}
func NewAccessControlsService(
log *logger.Logger,
labelProvider *LabelProvider,
static map[string]model.App) *AccessControlsService {
return &AccessControlsService{
docker: docker,
static: static,
log: log,
labelProvider: labelProvider,
static: static,
}
}
func (acls *AccessControlsService) Init() error {
return nil // No initialization needed
}
func (acls *AccessControlsService) lookupStaticACLs(domain string) (config.App, error) {
func (acls *AccessControlsService) lookupStaticACLs(domain string) *model.App {
var appAcls *model.App
for app, config := range acls.static {
if config.Config.Domain == domain {
tlog.App.Debug().Str("name", app).Msg("Found matching container by domain")
return config, nil
acls.log.App.Debug().Str("name", app).Msg("Found matching container by domain")
appAcls = &config
break // If we find a match by domain, we can stop searching
}
if strings.SplitN(domain, ".", 2)[0] == app {
tlog.App.Debug().Str("name", app).Msg("Found matching container by app name")
return config, nil
acls.log.App.Debug().Str("name", app).Msg("Found matching container by app name")
appAcls = &config
break // If we find a match by app name, we can stop searching
}
}
return config.App{}, errors.New("no results")
return appAcls
}
func (acls *AccessControlsService) GetAccessControls(domain string) (config.App, error) {
func (acls *AccessControlsService) GetAccessControls(domain string) (*model.App, error) {
// First check in the static config
app, err := acls.lookupStaticACLs(domain)
app := acls.lookupStaticACLs(domain)
if err == nil {
tlog.App.Debug().Msg("Using ACls from static configuration")
if app != nil {
acls.log.App.Debug().Msg("Using static ACLs for app")
return app, nil
}
// Fallback to Docker labels
tlog.App.Debug().Msg("Falling back to Docker labels for ACLs")
return acls.docker.GetLabels(domain)
// If we have a label provider configured, try to get ACLs from it
if acls.labelProvider != nil {
return (*acls.labelProvider).GetLabels(domain)
}
// no labels
return nil, nil
}
+262 -233
View File
@@ -5,20 +5,22 @@ import (
"database/sql"
"errors"
"fmt"
"net/http"
"regexp"
"strings"
"sync"
"time"
"github.com/tinyauthapp/tinyauth/internal/config"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
"slices"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"golang.org/x/crypto/bcrypt"
"golang.org/x/exp/slices"
"golang.org/x/oauth2"
)
@@ -28,6 +30,10 @@ const MaxOAuthPendingSessions = 256
const OAuthCleanupCount = 16
const MaxLoginAttemptRecords = 256
var (
ErrUserNotFound = errors.New("user not found")
)
// slightly modified version of the AuthorizeRequest from the OIDC service to basically accept all
// parameters and pass them to the authorize page if needed
type OAuthURLParams struct {
@@ -66,41 +72,42 @@ type Lockdown struct {
ActiveUntil time.Time
}
type AuthServiceConfig struct {
Users []config.User
OauthWhitelist []string
SessionExpiry int
SessionMaxLifetime int
SecureCookie bool
CookieDomain string
LoginTimeout int
LoginMaxRetries int
SessionCookieName string
IP config.IPConfig
LDAPGroupsCacheTTL int
}
type AuthService struct {
config AuthServiceConfig
docker *DockerService
log *logger.Logger
config model.Config
runtime model.RuntimeConfig
context context.Context
ldap *LdapService
queries *repository.Queries
oauthBroker *OAuthBrokerService
loginAttempts map[string]*LoginAttempt
ldapGroupsCache map[string]*LdapGroupsCache
oauthPendingSessions map[string]*OAuthPendingSession
oauthMutex sync.RWMutex
loginMutex sync.RWMutex
ldapGroupsMutex sync.RWMutex
ldap *LdapService
queries *repository.Queries
oauthBroker *OAuthBrokerService
lockdown *Lockdown
lockdownCtx context.Context
lockdownCancelFunc context.CancelFunc
}
func NewAuthService(config AuthServiceConfig, docker *DockerService, ldap *LdapService, queries *repository.Queries, oauthBroker *OAuthBrokerService) *AuthService {
return &AuthService{
func NewAuthService(
log *logger.Logger,
config model.Config,
runtime model.RuntimeConfig,
ctx context.Context,
wg *sync.WaitGroup,
ldap *LdapService,
queries *repository.Queries,
oauthBroker *OAuthBrokerService,
) *AuthService {
service := &AuthService{
log: log,
runtime: runtime,
context: ctx,
config: config,
docker: docker,
loginAttempts: make(map[string]*LoginAttempt),
ldapGroupsCache: make(map[string]*LdapGroupsCache),
oauthPendingSessions: make(map[string]*OAuthPendingSession),
@@ -108,86 +115,79 @@ func NewAuthService(config AuthServiceConfig, docker *DockerService, ldap *LdapS
queries: queries,
oauthBroker: oauthBroker,
}
wg.Go(service.CleanupOAuthSessionsRoutine)
return service
}
func (auth *AuthService) Init() error {
go auth.CleanupOAuthSessionsRoutine()
return nil
}
func (auth *AuthService) SearchUser(username string) config.UserSearch {
if auth.GetLocalUser(username).Username != "" {
return config.UserSearch{
func (auth *AuthService) SearchUser(username string) (*model.UserSearch, error) {
if auth.GetLocalUser(username) != nil {
return &model.UserSearch{
Username: username,
Type: "local",
}
Type: model.UserLocal,
}, nil
}
if auth.ldap.IsConfigured() {
if auth.ldap != nil {
userDN, err := auth.ldap.GetUserDN(username)
if err != nil {
tlog.App.Warn().Err(err).Str("username", username).Msg("Failed to search for user in LDAP")
return config.UserSearch{
Type: "unknown",
}
return nil, fmt.Errorf("failed to get ldap user: %w", err)
}
return config.UserSearch{
return &model.UserSearch{
Username: userDN,
Type: "ldap",
}
Type: model.UserLDAP,
}, nil
}
return config.UserSearch{
Type: "unknown",
}
return nil, ErrUserNotFound
}
func (auth *AuthService) VerifyUser(search config.UserSearch, password string) bool {
func (auth *AuthService) CheckUserPassword(search model.UserSearch, password string) error {
switch search.Type {
case "local":
case model.UserLocal:
user := auth.GetLocalUser(search.Username)
return auth.CheckPassword(user, password)
case "ldap":
if auth.ldap.IsConfigured() {
if user == nil {
return ErrUserNotFound
}
return bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password))
case model.UserLDAP:
if auth.ldap != nil {
err := auth.ldap.Bind(search.Username, password)
if err != nil {
tlog.App.Warn().Err(err).Str("username", search.Username).Msg("Failed to bind to LDAP")
return false
return fmt.Errorf("failed to bind to ldap user: %w", err)
}
err = auth.ldap.BindService(true)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to rebind with service account after user authentication")
return false
return fmt.Errorf("failed to bind to ldap service account: %w", err)
}
return true
return nil
}
default:
tlog.App.Debug().Str("type", search.Type).Msg("Unknown user type for authentication")
return false
return errors.New("unknown user search type")
}
tlog.App.Warn().Str("username", search.Username).Msg("User authentication failed")
return false
return errors.New("user authentication failed")
}
func (auth *AuthService) GetLocalUser(username string) config.User {
for _, user := range auth.config.Users {
func (auth *AuthService) GetLocalUser(username string) *model.LocalUser {
if auth.runtime.LocalUsers == nil {
return nil
}
for _, user := range auth.runtime.LocalUsers {
if user.Username == username {
return user
return &user
}
}
tlog.App.Warn().Str("username", username).Msg("Local user not found")
return config.User{}
return nil
}
func (auth *AuthService) GetLdapUser(userDN string) (config.LdapUser, error) {
if !auth.ldap.IsConfigured() {
return config.LdapUser{}, errors.New("LDAP service not initialized")
func (auth *AuthService) GetLDAPUser(userDN string) (*model.LDAPUser, error) {
if auth.ldap == nil {
return nil, errors.New("ldap service not configured")
}
auth.ldapGroupsMutex.RLock()
@@ -195,7 +195,7 @@ func (auth *AuthService) GetLdapUser(userDN string) (config.LdapUser, error) {
auth.ldapGroupsMutex.RUnlock()
if exists && time.Now().Before(entry.Expires) {
return config.LdapUser{
return &model.LDAPUser{
DN: userDN,
Groups: entry.Groups,
}, nil
@@ -204,26 +204,22 @@ func (auth *AuthService) GetLdapUser(userDN string) (config.LdapUser, error) {
groups, err := auth.ldap.GetUserGroups(userDN)
if err != nil {
return config.LdapUser{}, err
return nil, fmt.Errorf("failed to get ldap groups: %w", err)
}
auth.ldapGroupsMutex.Lock()
auth.ldapGroupsCache[userDN] = &LdapGroupsCache{
Groups: groups,
Expires: time.Now().Add(time.Duration(auth.config.LDAPGroupsCacheTTL) * time.Second),
Expires: time.Now().Add(time.Duration(auth.config.LDAP.GroupCacheTTL) * time.Second),
}
auth.ldapGroupsMutex.Unlock()
return config.LdapUser{
return &model.LDAPUser{
DN: userDN,
Groups: groups,
}, nil
}
func (auth *AuthService) CheckPassword(user config.User, password string) bool {
return bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)) == nil
}
func (auth *AuthService) IsAccountLocked(identifier string) (bool, int) {
auth.loginMutex.RLock()
defer auth.loginMutex.RUnlock()
@@ -233,7 +229,7 @@ func (auth *AuthService) IsAccountLocked(identifier string) (bool, int) {
return true, remaining
}
if auth.config.LoginMaxRetries <= 0 || auth.config.LoginTimeout <= 0 {
if auth.config.Auth.LoginMaxRetries <= 0 || auth.config.Auth.LoginTimeout <= 0 {
return false, 0
}
@@ -251,7 +247,7 @@ func (auth *AuthService) IsAccountLocked(identifier string) (bool, int) {
}
func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) {
if auth.config.LoginMaxRetries <= 0 || auth.config.LoginTimeout <= 0 {
if auth.config.Auth.LoginMaxRetries <= 0 || auth.config.Auth.LoginTimeout <= 0 {
return
}
@@ -282,21 +278,21 @@ func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) {
attempt.FailedAttempts++
if attempt.FailedAttempts >= auth.config.LoginMaxRetries {
attempt.LockedUntil = time.Now().Add(time.Duration(auth.config.LoginTimeout) * time.Second)
tlog.App.Warn().Str("identifier", identifier).Int("timeout", auth.config.LoginTimeout).Msg("Account locked due to too many failed login attempts")
if attempt.FailedAttempts >= auth.config.Auth.LoginMaxRetries {
attempt.LockedUntil = time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second)
auth.log.App.Warn().Str("identifier", identifier).Int("failedAttempts", attempt.FailedAttempts).Msg("Account locked due to too many failed login attempts")
}
}
func (auth *AuthService) IsEmailWhitelisted(email string) bool {
return utils.CheckFilter(strings.Join(auth.config.OauthWhitelist, ","), email)
return utils.CheckFilter(strings.Join(auth.runtime.OAuthWhitelist, ","), email)
}
func (auth *AuthService) CreateSessionCookie(c *gin.Context, data *repository.Session) error {
func (auth *AuthService) CreateSession(ctx context.Context, data repository.Session) (*http.Cookie, error) {
uuid, err := uuid.NewRandom()
if err != nil {
return err
return nil, fmt.Errorf("failed to generate session uuid: %w", err)
}
var expiry int
@@ -304,9 +300,11 @@ func (auth *AuthService) CreateSessionCookie(c *gin.Context, data *repository.Se
if data.TotpPending {
expiry = 3600
} else {
expiry = auth.config.SessionExpiry
expiry = auth.config.Auth.SessionExpiry
}
expiresAt := time.Now().Add(time.Duration(expiry) * time.Second)
session := repository.CreateSessionParams{
UUID: uuid.String(),
Username: data.Username,
@@ -315,63 +313,74 @@ func (auth *AuthService) CreateSessionCookie(c *gin.Context, data *repository.Se
Provider: data.Provider,
TotpPending: data.TotpPending,
OAuthGroups: data.OAuthGroups,
Expiry: time.Now().Add(time.Duration(expiry) * time.Second).Unix(),
Expiry: expiresAt.Unix(),
CreatedAt: time.Now().Unix(),
OAuthName: data.OAuthName,
OAuthSub: data.OAuthSub,
}
_, err = auth.queries.CreateSession(c, session)
_, err = auth.queries.CreateSession(ctx, session)
if err != nil {
return err
return nil, fmt.Errorf("failed to create session entry: %w", err)
}
if data.Provider == "tailscale" {
// TODO: use domain from tailscale to set cookie, this is mostly a hack for now
tsCookieDomain, err := utils.GetCookieDomain(fmt.Sprintf("https://%s", c.Request.Host))
if err != nil {
return err
return nil, fmt.Errorf("failed to get cookie domain for tailscale user: %w", err)
}
c.SetCookie(auth.config.SessionCookieName, session.UUID, expiry, "/", fmt.Sprintf(".%s", tsCookieDomain), auth.config.SecureCookie, true)
return nil
return &http.Cookie{
Name: auth.runtime.SessionCookieName,
Value: session.UUID,
Path: "/",
Domain: fmt.Sprintf(".%s", tsCookieDomain),
Expires: expiresAt,
MaxAge: int(time.Until(expiresAt).Seconds()),
Secure: auth.config.Auth.SecureCookie,
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
}, nil
}
c.SetCookie(auth.config.SessionCookieName, session.UUID, expiry, "/", fmt.Sprintf(".%s", auth.config.CookieDomain), auth.config.SecureCookie, true)
return nil
return &http.Cookie{
Name: auth.runtime.SessionCookieName,
Value: session.UUID,
Path: "/",
Domain: fmt.Sprintf(".%s", auth.runtime.CookieDomain),
Expires: expiresAt,
MaxAge: int(time.Until(expiresAt).Seconds()),
Secure: auth.config.Auth.SecureCookie,
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
}, nil
}
func (auth *AuthService) RefreshSessionCookie(c *gin.Context) error {
cookie, err := c.Cookie(auth.config.SessionCookieName)
func (auth *AuthService) RefreshSession(ctx context.Context, uuid string) (*http.Cookie, error) {
session, err := auth.queries.GetSession(ctx, uuid)
if err != nil {
return err
}
session, err := auth.queries.GetSession(c, cookie)
if err != nil {
return err
return nil, fmt.Errorf("failed to retrieve session: %w", err)
}
currentTime := time.Now().Unix()
var refreshThreshold int64
if auth.config.SessionExpiry <= int(time.Hour.Seconds()) {
refreshThreshold = int64(auth.config.SessionExpiry / 2)
if auth.config.Auth.SessionExpiry <= int(time.Hour.Seconds()) {
refreshThreshold = int64(auth.config.Auth.SessionExpiry / 2)
} else {
refreshThreshold = int64(time.Hour.Seconds())
}
if session.Expiry-currentTime > refreshThreshold {
return nil
return nil, nil
}
newExpiry := session.Expiry + refreshThreshold
_, err = auth.queries.UpdateSession(c, repository.UpdateSessionParams{
_, err = auth.queries.UpdateSession(ctx, repository.UpdateSessionParams{
Username: session.Username,
Email: session.Email,
Name: session.Name,
@@ -385,150 +394,160 @@ func (auth *AuthService) RefreshSessionCookie(c *gin.Context) error {
})
if err != nil {
return err
return nil, fmt.Errorf("failed to update session expiry: %w", err)
}
c.SetCookie(auth.config.SessionCookieName, cookie, int(newExpiry-currentTime), "/", fmt.Sprintf(".%s", auth.config.CookieDomain), auth.config.SecureCookie, true)
tlog.App.Trace().Str("username", session.Username).Msg("Session cookie refreshed")
return &http.Cookie{
Name: auth.runtime.SessionCookieName,
Value: session.UUID,
Path: "/",
Domain: fmt.Sprintf(".%s", auth.runtime.CookieDomain),
Expires: time.Now().Add(time.Duration(newExpiry-currentTime) * time.Second),
MaxAge: int(newExpiry - currentTime),
Secure: auth.config.Auth.SecureCookie,
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
}, nil
return nil
}
func (auth *AuthService) DeleteSessionCookie(c *gin.Context) error {
cookie, err := c.Cookie(auth.config.SessionCookieName)
func (auth *AuthService) DeleteSession(ctx context.Context, uuid string) (*http.Cookie, error) {
err := auth.queries.DeleteSession(ctx, uuid)
if err != nil {
return err
auth.log.App.Error().Err(err).Str("uuid", uuid).Msg("Failed to delete session from database")
}
err = auth.queries.DeleteSession(c, cookie)
if err != nil {
return err
}
c.SetCookie(auth.config.SessionCookieName, "", -1, "/", fmt.Sprintf(".%s", auth.config.CookieDomain), auth.config.SecureCookie, true)
return nil
return &http.Cookie{
Name: auth.runtime.SessionCookieName,
Value: "",
Path: "/",
Domain: fmt.Sprintf(".%s", auth.runtime.CookieDomain),
Expires: time.Now(),
MaxAge: -1,
Secure: auth.config.Auth.SecureCookie,
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
}, nil
}
func (auth *AuthService) GetSessionCookie(c *gin.Context) (repository.Session, error) {
cookie, err := c.Cookie(auth.config.SessionCookieName)
if err != nil {
return repository.Session{}, err
}
session, err := auth.queries.GetSession(c, cookie)
func (auth *AuthService) GetSession(ctx context.Context, uuid string) (*repository.Session, error) {
session, err := auth.queries.GetSession(ctx, uuid)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return repository.Session{}, fmt.Errorf("session not found")
return nil, errors.New("session not found")
}
return repository.Session{}, err
return nil, err
}
currentTime := time.Now().Unix()
if auth.config.SessionMaxLifetime != 0 && session.CreatedAt != 0 {
if currentTime-session.CreatedAt > int64(auth.config.SessionMaxLifetime) {
err = auth.queries.DeleteSession(c, cookie)
if auth.config.Auth.SessionMaxLifetime != 0 && session.CreatedAt != 0 {
if currentTime-session.CreatedAt > int64(auth.config.Auth.SessionMaxLifetime) {
err = auth.queries.DeleteSession(ctx, uuid)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to delete session exceeding max lifetime")
return nil, fmt.Errorf("failed to delete expired session: %w", err)
}
return repository.Session{}, fmt.Errorf("session expired due to max lifetime exceeded")
return nil, fmt.Errorf("session max lifetime exceeded")
}
}
if currentTime > session.Expiry {
err = auth.queries.DeleteSession(c, cookie)
err = auth.queries.DeleteSession(ctx, uuid)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to delete expired session")
return nil, fmt.Errorf("failed to delete expired session: %w", err)
}
return repository.Session{}, fmt.Errorf("session expired")
return nil, fmt.Errorf("session expired")
}
return repository.Session{
UUID: session.UUID,
Username: session.Username,
Email: session.Email,
Name: session.Name,
Provider: session.Provider,
TotpPending: session.TotpPending,
OAuthGroups: session.OAuthGroups,
OAuthName: session.OAuthName,
OAuthSub: session.OAuthSub,
}, nil
return &session, nil
}
func (auth *AuthService) LocalAuthConfigured() bool {
return len(auth.config.Users) > 0
return len(auth.runtime.LocalUsers) > 0
}
func (auth *AuthService) LdapAuthConfigured() bool {
return auth.ldap.IsConfigured()
func (auth *AuthService) LDAPAuthConfigured() bool {
return auth.ldap != nil
}
func (auth *AuthService) IsUserAllowed(c *gin.Context, context config.UserContext, acls config.App) bool {
if context.OAuth {
tlog.App.Debug().Msg("Checking OAuth whitelist")
return utils.CheckFilter(acls.OAuth.Whitelist, context.Email)
func (auth *AuthService) IsUserAllowed(c *gin.Context, context model.UserContext, acls *model.App) bool {
if acls == nil {
return true
}
if context.Provider == model.ProviderOAuth {
auth.log.App.Debug().Msg("User is an OAuth user, checking OAuth whitelist")
return utils.CheckFilter(acls.OAuth.Whitelist, context.OAuth.Email)
}
if acls.Users.Block != "" {
tlog.App.Debug().Msg("Checking blocked users")
if utils.CheckFilter(acls.Users.Block, context.Username) {
auth.log.App.Debug().Msg("Checking users block list")
if utils.CheckFilter(acls.Users.Block, context.GetUsername()) {
return false
}
}
tlog.App.Debug().Msg("Checking users")
return utils.CheckFilter(acls.Users.Allow, context.Username)
auth.log.App.Debug().Msg("Checking users allow list")
return utils.CheckFilter(acls.Users.Allow, context.GetUsername())
}
func (auth *AuthService) IsInOAuthGroup(c *gin.Context, context config.UserContext, requiredGroups string) bool {
if requiredGroups == "" {
func (auth *AuthService) IsInOAuthGroup(c *gin.Context, context model.UserContext, acls *model.App) bool {
if acls == nil {
return true
}
for id := range config.OverrideProviders {
if context.Provider == id {
tlog.App.Info().Str("provider", id).Msg("OAuth groups not supported for this provider")
return true
}
if !context.IsOAuth() {
auth.log.App.Debug().Msg("User is not an OAuth user, skipping OAuth group check")
return false
}
for userGroup := range strings.SplitSeq(context.OAuthGroups, ",") {
if utils.CheckFilter(requiredGroups, strings.TrimSpace(userGroup)) {
tlog.App.Trace().Str("group", userGroup).Str("required", requiredGroups).Msg("User group matched")
return true
}
}
tlog.App.Debug().Msg("No groups matched")
return false
}
func (auth *AuthService) IsInLdapGroup(c *gin.Context, context config.UserContext, requiredGroups string) bool {
if requiredGroups == "" {
if _, ok := model.OverrideProviders[context.OAuth.ID]; ok {
auth.log.App.Debug().Str("provider", context.OAuth.ID).Msg("Provider override detected, skipping group check")
return true
}
for userGroup := range strings.SplitSeq(context.LdapGroups, ",") {
if utils.CheckFilter(requiredGroups, strings.TrimSpace(userGroup)) {
tlog.App.Trace().Str("group", userGroup).Str("required", requiredGroups).Msg("User group matched")
for _, userGroup := range context.OAuth.Groups {
if utils.CheckFilter(acls.OAuth.Groups, strings.TrimSpace(userGroup)) {
auth.log.App.Trace().Str("group", userGroup).Str("required", acls.OAuth.Groups).Msg("User group matched")
return true
}
}
tlog.App.Debug().Msg("No groups matched")
auth.log.App.Debug().Msg("No groups matched")
return false
}
func (auth *AuthService) IsAuthEnabled(uri string, path config.AppPath) (bool, error) {
func (auth *AuthService) IsInLDAPGroup(c *gin.Context, context model.UserContext, acls *model.App) bool {
if acls == nil {
return true
}
if !context.IsLDAP() {
auth.log.App.Debug().Msg("User is not an LDAP user, skipping LDAP group check")
return false
}
for _, userGroup := range context.LDAP.Groups {
if utils.CheckFilter(acls.LDAP.Groups, strings.TrimSpace(userGroup)) {
auth.log.App.Trace().Str("group", userGroup).Str("required", acls.LDAP.Groups).Msg("User group matched")
return true
}
}
auth.log.App.Debug().Msg("No groups matched")
return false
}
func (auth *AuthService) IsAuthEnabled(uri string, acls *model.App) (bool, error) {
if acls == nil {
return true, nil
}
// Check for block list
if path.Block != "" {
regex, err := regexp.Compile(path.Block)
if acls.Path.Block != "" {
regex, err := regexp.Compile(acls.Path.Block)
if err != nil {
return true, err
@@ -540,8 +559,8 @@ func (auth *AuthService) IsAuthEnabled(uri string, path config.AppPath) (bool, e
}
// Check for allow list
if path.Allow != "" {
regex, err := regexp.Compile(path.Allow)
if acls.Path.Allow != "" {
regex, err := regexp.Compile(acls.Path.Allow)
if err != nil {
return true, err
@@ -555,31 +574,23 @@ func (auth *AuthService) IsAuthEnabled(uri string, path config.AppPath) (bool, e
return true, nil
}
func (auth *AuthService) GetBasicAuth(c *gin.Context) *config.User {
username, password, ok := c.Request.BasicAuth()
if !ok {
tlog.App.Debug().Msg("No basic auth provided")
return nil
func (auth *AuthService) CheckIP(ip string, acls *model.App) bool {
if acls == nil {
return true
}
return &config.User{
Username: username,
Password: password,
}
}
func (auth *AuthService) CheckIP(acls config.AppIP, ip string) bool {
// Merge the global and app IP filter
blockedIps := append(auth.config.IP.Block, acls.Block...)
allowedIPs := append(auth.config.IP.Allow, acls.Allow...)
blockedIps := append(auth.config.Auth.IP.Block, acls.IP.Block...)
allowedIPs := append(auth.config.Auth.IP.Allow, acls.IP.Allow...)
for _, blocked := range blockedIps {
res, err := utils.FilterIP(blocked, ip)
if err != nil {
tlog.App.Warn().Err(err).Str("item", blocked).Msg("Invalid IP/CIDR in block list")
auth.log.App.Warn().Err(err).Str("item", blocked).Msg("Invalid IP/CIDR in block list")
continue
}
if res {
tlog.App.Debug().Str("ip", ip).Str("item", blocked).Msg("IP is in blocked list, denying access")
auth.log.App.Debug().Str("ip", ip).Str("item", blocked).Msg("IP is in block list, denying access")
return false
}
}
@@ -587,38 +598,42 @@ func (auth *AuthService) CheckIP(acls config.AppIP, ip string) bool {
for _, allowed := range allowedIPs {
res, err := utils.FilterIP(allowed, ip)
if err != nil {
tlog.App.Warn().Err(err).Str("item", allowed).Msg("Invalid IP/CIDR in allow list")
auth.log.App.Warn().Err(err).Str("item", allowed).Msg("Invalid IP/CIDR in allow list")
continue
}
if res {
tlog.App.Debug().Str("ip", ip).Str("item", allowed).Msg("IP is in allowed list, allowing access")
auth.log.App.Debug().Str("ip", ip).Str("item", allowed).Msg("IP is in allow list, allowing access")
return true
}
}
if len(allowedIPs) > 0 {
tlog.App.Debug().Str("ip", ip).Msg("IP not in allow list, denying access")
auth.log.App.Debug().Str("ip", ip).Msg("IP not in allow list, denying access")
return false
}
tlog.App.Debug().Str("ip", ip).Msg("IP not in allow or block list, allowing by default")
auth.log.App.Debug().Str("ip", ip).Msg("IP not in any block or allow list, allowing access by default")
return true
}
func (auth *AuthService) IsBypassedIP(acls config.AppIP, ip string) bool {
for _, bypassed := range acls.Bypass {
func (auth *AuthService) IsBypassedIP(ip string, acls *model.App) bool {
if acls == nil {
return false
}
for _, bypassed := range acls.IP.Bypass {
res, err := utils.FilterIP(bypassed, ip)
if err != nil {
tlog.App.Warn().Err(err).Str("item", bypassed).Msg("Invalid IP/CIDR in bypass list")
auth.log.App.Warn().Err(err).Str("item", bypassed).Msg("Invalid IP/CIDR in bypass list")
continue
}
if res {
tlog.App.Debug().Str("ip", ip).Str("item", bypassed).Msg("IP is in bypass list, allowing access")
auth.log.App.Debug().Str("ip", ip).Str("item", bypassed).Msg("IP is in bypass list, skipping authentication")
return true
}
}
tlog.App.Debug().Str("ip", ip).Msg("IP not in bypass list, continuing with authentication")
auth.log.App.Debug().Str("ip", ip).Msg("IP not in bypass list, proceeding with authentication")
return false
}
@@ -685,21 +700,21 @@ func (auth *AuthService) GetOAuthToken(sessionId string, code string) (*oauth2.T
return token, nil
}
func (auth *AuthService) GetOAuthUserinfo(sessionId string) (config.Claims, error) {
func (auth *AuthService) GetOAuthUserinfo(sessionId string) (*model.Claims, error) {
session, err := auth.GetOAuthPendingSession(sessionId)
if err != nil {
return config.Claims{}, err
return nil, err
}
if session.Token == nil {
return config.Claims{}, fmt.Errorf("oauth token not found for session: %s", sessionId)
return nil, fmt.Errorf("oauth token not found for session: %s", sessionId)
}
userinfo, err := (*session.Service).GetUserinfo(session.Token)
if err != nil {
return config.Claims{}, fmt.Errorf("failed to get userinfo: %w", err)
return nil, fmt.Errorf("failed to get userinfo: %w", err)
}
return userinfo, nil
@@ -722,21 +737,32 @@ func (auth *AuthService) EndOAuthSession(sessionId string) {
}
func (auth *AuthService) CleanupOAuthSessionsRoutine() {
auth.log.App.Debug().Msg("Starting OAuth session cleanup routine")
ticker := time.NewTicker(30 * time.Minute)
defer ticker.Stop()
for range ticker.C {
auth.oauthMutex.Lock()
for {
select {
case <-ticker.C:
auth.log.App.Debug().Msg("Running OAuth session cleanup")
now := time.Now()
auth.oauthMutex.Lock()
for sessionId, session := range auth.oauthPendingSessions {
if now.After(session.ExpiresAt) {
delete(auth.oauthPendingSessions, sessionId)
now := time.Now()
for sessionId, session := range auth.oauthPendingSessions {
if now.After(session.ExpiresAt) {
delete(auth.oauthPendingSessions, sessionId)
}
}
}
auth.oauthMutex.Unlock()
auth.oauthMutex.Unlock()
auth.log.App.Debug().Msg("OAuth session cleanup completed")
case <-auth.context.Done():
auth.log.App.Debug().Msg("Stopping OAuth session cleanup routine")
return
}
}
}
@@ -805,11 +831,11 @@ func (auth *AuthService) lockdownMode() {
auth.loginMutex.Lock()
tlog.App.Warn().Msg("Multiple login attempts detected, possibly DDOS attack. Activating temporary lockdown.")
auth.log.App.Warn().Msg("Too many failed login attempts, entering lockdown mode")
auth.lockdown = &Lockdown{
Active: true,
ActiveUntil: time.Now().Add(time.Duration(auth.config.LoginTimeout) * time.Second),
ActiveUntil: time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second),
}
// At this point all login attemps will also expire so,
@@ -826,11 +852,14 @@ func (auth *AuthService) lockdownMode() {
// Timer expired, end lockdown
case <-ctx.Done():
// Context cancelled, end lockdown
case <-auth.context.Done():
// Service is shutting down, end lockdown
}
auth.loginMutex.Lock()
tlog.App.Info().Msg("Lockdown period ended, resuming normal operation")
auth.log.App.Info().Msg("Exiting lockdown mode")
auth.lockdown = nil
auth.loginMutex.Unlock()
}
+63 -55
View File
@@ -3,104 +3,112 @@ package service
import (
"context"
"strings"
"sync"
"github.com/tinyauthapp/tinyauth/internal/config"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/decoders"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
container "github.com/docker/docker/api/types/container"
"github.com/docker/docker/client"
)
type DockerService struct {
client *client.Client
context context.Context
log *logger.Logger
client *client.Client
context context.Context
isConnected bool
}
func NewDockerService() *DockerService {
return &DockerService{}
}
func NewDockerService(
log *logger.Logger,
ctx context.Context,
wg *sync.WaitGroup,
) (*DockerService, error) {
func (docker *DockerService) Init() error {
client, err := client.NewClientWithOpts(client.FromEnv)
if err != nil {
return err
}
ctx := context.Background()
client.NegotiateAPIVersion(ctx)
docker.client = client
docker.context = ctx
_, err = docker.client.Ping(docker.context)
if err != nil {
tlog.App.Debug().Err(err).Msg("Docker not connected")
docker.isConnected = false
docker.client = nil
docker.context = nil
return nil
}
docker.isConnected = true
tlog.App.Debug().Msg("Docker connected")
return nil
}
func (docker *DockerService) getContainers() ([]container.Summary, error) {
containers, err := docker.client.ContainerList(docker.context, container.ListOptions{})
if err != nil {
return nil, err
}
return containers, nil
client.NegotiateAPIVersion(ctx)
_, err = client.Ping(ctx)
if err != nil {
log.App.Debug().Err(err).Msg("Docker not connected")
return nil, nil
}
service := &DockerService{
log: log,
client: client,
context: ctx,
}
service.isConnected = true
service.log.App.Debug().Msg("Docker connected successfully")
wg.Go(service.watchAndClose)
return service, nil
}
func (docker *DockerService) getContainers() ([]container.Summary, error) {
return docker.client.ContainerList(docker.context, container.ListOptions{})
}
func (docker *DockerService) inspectContainer(containerId string) (container.InspectResponse, error) {
inspect, err := docker.client.ContainerInspect(docker.context, containerId)
if err != nil {
return container.InspectResponse{}, err
}
return inspect, nil
return docker.client.ContainerInspect(docker.context, containerId)
}
func (docker *DockerService) GetLabels(appDomain string) (config.App, error) {
func (docker *DockerService) GetLabels(appDomain string) (*model.App, error) {
if !docker.isConnected {
tlog.App.Debug().Msg("Docker not connected, returning empty labels")
return config.App{}, nil
docker.log.App.Debug().Msg("Docker service not connected, returning empty labels")
return nil, nil
}
containers, err := docker.getContainers()
if err != nil {
return config.App{}, err
return nil, err
}
for _, ctr := range containers {
inspect, err := docker.inspectContainer(ctr.ID)
if err != nil {
return config.App{}, err
return nil, err
}
labels, err := decoders.DecodeLabels[config.Apps](inspect.Config.Labels, "apps")
labels, err := decoders.DecodeLabels[model.Apps](inspect.Config.Labels, "apps")
if err != nil {
return config.App{}, err
return nil, err
}
for appName, appLabels := range labels.Apps {
if appLabels.Config.Domain == appDomain {
tlog.App.Debug().Str("id", inspect.ID).Str("name", inspect.Name).Msg("Found matching container by domain")
return appLabels, nil
docker.log.App.Debug().Str("id", inspect.ID).Str("name", inspect.Name).Msg("Found matching container by domain")
return &appLabels, nil
}
if strings.SplitN(appDomain, ".", 2)[0] == appName {
tlog.App.Debug().Str("id", inspect.ID).Str("name", inspect.Name).Msg("Found matching container by app name")
return appLabels, nil
docker.log.App.Debug().Str("id", inspect.ID).Str("name", inspect.Name).Msg("Found matching container by app name")
return &appLabels, nil
}
}
}
tlog.App.Debug().Msg("No matching container found, returning empty labels")
return config.App{}, nil
docker.log.App.Debug().Str("domain", appDomain).Msg("No matching container found for domain")
return nil, nil
}
func (docker *DockerService) watchAndClose() {
<-docker.context.Done()
docker.log.App.Debug().Msg("Closing Docker client")
if docker.client != nil {
err := docker.client.Close()
if err != nil {
docker.log.App.Error().Err(err).Msg("Error closing Docker client")
}
}
}
+310
View File
@@ -0,0 +1,310 @@
package service
import (
"context"
"fmt"
"strings"
"sync"
"time"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/decoders"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
"k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/apimachinery/pkg/watch"
"k8s.io/client-go/dynamic"
"k8s.io/client-go/rest"
)
type ingressKey struct {
namespace string
name string
}
type ingressAppKey struct {
ingressKey
appName string
}
type ingressApp struct {
domain string
appName string
app model.App
}
type KubernetesService struct {
log *logger.Logger
ctx context.Context
client dynamic.Interface
started bool
mu sync.RWMutex
ingressApps map[ingressKey][]ingressApp
domainIndex map[string]ingressAppKey
appNameIndex map[string]ingressAppKey
}
func NewKubernetesService(
log *logger.Logger,
ctx context.Context,
wg *sync.WaitGroup,
) (*KubernetesService, error) {
cfg, err := rest.InClusterConfig()
if err != nil {
return nil, fmt.Errorf("failed to get in-cluster kubernetes config: %w", err)
}
client, err := dynamic.NewForConfig(cfg)
if err != nil {
return nil, fmt.Errorf("failed to create kubernetes client: %w", err)
}
gvr := schema.GroupVersionResource{
Group: "networking.k8s.io",
Version: "v1",
Resource: "ingresses",
}
accessCtx, accessCancel := context.WithTimeout(ctx, 5*time.Second)
defer accessCancel()
_, err = client.Resource(gvr).List(accessCtx, metav1.ListOptions{Limit: 1})
if err != nil {
log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to access Ingress API, Kubernetes label provider will be disabled")
return nil, fmt.Errorf("failed to access ingress api: %w", err)
}
log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Successfully accessed Ingress API, starting watcher")
service := &KubernetesService{
log: log,
ctx: ctx,
client: client,
ingressApps: make(map[ingressKey][]ingressApp),
domainIndex: make(map[string]ingressAppKey),
appNameIndex: make(map[string]ingressAppKey),
}
wg.Go(func() {
service.watchGVR(gvr)
})
service.started = true
log.App.Debug().Msg("Kubernetes label provider started successfully")
return service, nil
}
func (k *KubernetesService) addIngressApps(namespace, name string, apps []ingressApp) {
k.mu.Lock()
defer k.mu.Unlock()
key := ingressKey{namespace, name}
// Remove existing entries for this ingress
if existing, ok := k.ingressApps[key]; ok {
for _, app := range existing {
delete(k.domainIndex, app.domain)
delete(k.appNameIndex, app.appName)
}
}
// Add new entries
k.ingressApps[key] = apps
for _, app := range apps {
appKey := ingressAppKey{key, app.appName}
k.domainIndex[app.domain] = appKey
k.appNameIndex[app.appName] = appKey
}
}
func (k *KubernetesService) removeIngress(namespace, name string) {
k.mu.Lock()
defer k.mu.Unlock()
key := ingressKey{namespace, name}
if apps, ok := k.ingressApps[key]; ok {
for _, app := range apps {
delete(k.domainIndex, app.domain)
delete(k.appNameIndex, app.appName)
}
delete(k.ingressApps, key)
}
}
func (k *KubernetesService) getByDomain(domain string) *model.App {
k.mu.RLock()
defer k.mu.RUnlock()
if appKey, ok := k.domainIndex[domain]; ok {
if apps, ok := k.ingressApps[appKey.ingressKey]; ok {
for i := range apps {
app := &apps[i]
if app.domain == domain && app.appName == appKey.appName {
return &app.app
}
}
}
}
return nil
}
func (k *KubernetesService) getByAppName(appName string) *model.App {
k.mu.RLock()
defer k.mu.RUnlock()
if appKey, ok := k.appNameIndex[appName]; ok {
if apps, ok := k.ingressApps[appKey.ingressKey]; ok {
for i := range apps {
app := &apps[i]
if app.appName == appName {
return &app.app
}
}
}
}
return nil
}
func (k *KubernetesService) updateFromItem(item *unstructured.Unstructured) {
namespace := item.GetNamespace()
name := item.GetName()
annotations := item.GetAnnotations()
if annotations == nil {
k.removeIngress(namespace, name)
return
}
labels, err := decoders.DecodeLabels[model.Apps](annotations, "apps")
if err != nil {
k.log.App.Warn().Err(err).Str("namespace", namespace).Str("name", name).Msg("Failed to decode ingress labels, skipping")
k.removeIngress(namespace, name)
return
}
var apps []ingressApp
for appName, appLabels := range labels.Apps {
if appLabels.Config.Domain == "" {
continue
}
apps = append(apps, ingressApp{
domain: appLabels.Config.Domain,
appName: appName,
app: appLabels,
})
}
if len(apps) == 0 {
k.removeIngress(namespace, name)
} else {
k.addIngressApps(namespace, name, apps)
}
}
func (k *KubernetesService) resyncGVR(gvr schema.GroupVersionResource) error {
ctx, cancel := context.WithTimeout(k.ctx, 30*time.Second)
defer cancel()
list, err := k.client.Resource(gvr).List(ctx, metav1.ListOptions{})
if err != nil {
k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to list resources for resync")
return err
}
for i := range list.Items {
k.updateFromItem(&list.Items[i])
}
k.log.App.Debug().Str("api", gvr.GroupVersion().String()).Int("count", len(list.Items)).Msg("Resync complete")
return nil
}
// runWatcher drains events from an active watcher until it closes or the context is done.
// Returns true if the caller should restart the watcher, false if it should exit.
func (k *KubernetesService) runWatcher(gvr schema.GroupVersionResource, w watch.Interface, resyncTicker *time.Ticker) bool {
for {
select {
case <-k.ctx.Done():
w.Stop()
return false
case event, ok := <-w.ResultChan():
if !ok {
k.log.App.Warn().Str("api", gvr.GroupVersion().String()).Msg("Watcher channel closed, restarting watcher")
w.Stop()
time.Sleep(5 * time.Second)
return true
}
item, ok := event.Object.(*unstructured.Unstructured)
if !ok {
k.log.App.Warn().Str("api", gvr.GroupVersion().String()).Msg("Received unexpected event object, skipping")
continue
}
switch event.Type {
case watch.Added, watch.Modified:
k.updateFromItem(item)
case watch.Deleted:
k.removeIngress(item.GetNamespace(), item.GetName())
}
case <-resyncTicker.C:
if err := k.resyncGVR(gvr); err != nil {
k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Periodic resync failed during watcher run")
}
}
}
}
func (k *KubernetesService) watchGVR(gvr schema.GroupVersionResource) {
resyncTicker := time.NewTicker(5 * time.Minute)
defer resyncTicker.Stop()
if err := k.resyncGVR(gvr); err != nil {
k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Initial resync failed, will retry")
time.Sleep(30 * time.Second)
}
for {
select {
case <-k.ctx.Done():
k.log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Shutting down kubernetes watcher")
return
case <-resyncTicker.C:
if err := k.resyncGVR(gvr); err != nil {
k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Periodic resync failed, will retry")
}
default:
ctx, cancel := context.WithCancel(k.ctx)
watcher, err := k.client.Resource(gvr).Watch(ctx, metav1.ListOptions{})
if err != nil {
k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to start watcher, will retry")
cancel()
time.Sleep(10 * time.Second)
continue
}
k.log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Watcher started successfully")
if !k.runWatcher(gvr, watcher, resyncTicker) {
cancel()
return
}
cancel()
}
}
}
func (k *KubernetesService) GetLabels(appDomain string) (*model.App, error) {
if !k.started {
k.log.App.Debug().Str("domain", appDomain).Msg("Kubernetes label provider not started, skipping")
return nil, nil
}
// First check cache
app := k.getByDomain(appDomain)
if app != nil {
k.log.App.Debug().Str("domain", appDomain).Msg("Found labels in cache by domain")
return app, nil
}
appName := strings.SplitN(appDomain, ".", 2)[0]
app = k.getByAppName(appName)
if app != nil {
k.log.App.Debug().Str("domain", appDomain).Str("appName", appName).Msg("Found labels in cache by app name")
return app, nil
}
k.log.App.Debug().Str("domain", appDomain).Msg("No labels found for domain")
return nil, nil
}
+191
View File
@@ -0,0 +1,191 @@
package service
import (
"testing"
"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
)
func TestKubernetesService(t *testing.T) {
log := logger.NewLogger().WithTestConfig()
log.Init()
type testCase struct {
description string
run func(t *testing.T, svc *KubernetesService)
}
tests := []testCase{
{
description: "Cache by domain returns app and misses unknown domain",
run: func(t *testing.T, svc *KubernetesService) {
app := model.App{Config: model.AppConfig{Domain: "foo.example.com"}}
svc.addIngressApps("default", "my-ingress", []ingressApp{
{domain: "foo.example.com", appName: "foo", app: app},
})
got := svc.getByDomain("foo.example.com")
require.NotNil(t, got)
assert.Equal(t, "foo.example.com", got.Config.Domain)
got = svc.getByDomain("notfound.example.com")
assert.Nil(t, got)
},
},
{
description: "Cache by app name returns app and misses unknown name",
run: func(t *testing.T, svc *KubernetesService) {
app := model.App{Config: model.AppConfig{Domain: "bar.example.com"}}
svc.addIngressApps("default", "my-ingress", []ingressApp{
{domain: "bar.example.com", appName: "bar", app: app},
})
got := svc.getByAppName("bar")
require.NotNil(t, got)
assert.Equal(t, "bar.example.com", got.Config.Domain)
got = svc.getByAppName("notfound")
assert.Nil(t, got)
},
},
{
description: "RemoveIngress clears domain and app name entries",
run: func(t *testing.T, svc *KubernetesService) {
app := model.App{Config: model.AppConfig{Domain: "baz.example.com"}}
svc.addIngressApps("default", "my-ingress", []ingressApp{
{domain: "baz.example.com", appName: "baz", app: app},
})
svc.removeIngress("default", "my-ingress")
got := svc.getByDomain("baz.example.com")
assert.Nil(t, got)
got = svc.getByAppName("baz")
assert.Nil(t, got)
},
},
{
description: "AddIngressApps replaces stale entries for the same ingress",
run: func(t *testing.T, svc *KubernetesService) {
old := model.App{Config: model.AppConfig{Domain: "old.example.com"}}
svc.addIngressApps("default", "my-ingress", []ingressApp{
{domain: "old.example.com", appName: "old", app: old},
})
updated := model.App{Config: model.AppConfig{Domain: "new.example.com"}}
svc.addIngressApps("default", "my-ingress", []ingressApp{
{domain: "new.example.com", appName: "new", app: updated},
})
got := svc.getByDomain("old.example.com")
assert.Nil(t, got)
got = svc.getByDomain("new.example.com")
require.NotNil(t, got)
assert.Equal(t, "new.example.com", got.Config.Domain)
},
},
{
description: "GetLabels returns app from cache when started",
run: func(t *testing.T, svc *KubernetesService) {
svc.started = true
app := model.App{Config: model.AppConfig{Domain: "hit.example.com"}}
svc.addIngressApps("default", "ing", []ingressApp{
{domain: "hit.example.com", appName: "hit", app: app},
})
got, err := svc.GetLabels("hit.example.com")
require.NoError(t, err)
assert.Equal(t, "hit.example.com", got.Config.Domain)
},
},
{
description: "GetLabels returns empty app on cache miss when started",
run: func(t *testing.T, svc *KubernetesService) {
svc.started = true
got, err := svc.GetLabels("notfound.example.com")
require.NoError(t, err)
assert.Nil(t, got)
},
},
{
description: "GetLabels resolves app by app name",
run: func(t *testing.T, svc *KubernetesService) {
svc.started = true
app := model.App{Config: model.AppConfig{Domain: "myapp.internal.example.com"}}
svc.addIngressApps("default", "ing", []ingressApp{
{domain: "myapp.internal.example.com", appName: "myapp", app: app},
})
got, err := svc.GetLabels("myapp.internal.example.com")
require.NoError(t, err)
assert.Equal(t, "myapp.internal.example.com", got.Config.Domain)
},
},
{
description: "GetLabels returns empty app when service not yet started",
run: func(t *testing.T, svc *KubernetesService) {
got, err := svc.GetLabels("anything.example.com")
require.NoError(t, err)
assert.Nil(t, got)
},
},
{
description: "UpdateFromItem parses annotations and populates cache",
run: func(t *testing.T, svc *KubernetesService) {
item := unstructured.Unstructured{}
item.SetNamespace("default")
item.SetName("test-ingress")
item.SetAnnotations(map[string]string{
"tinyauth.apps.myapp.config.domain": "myapp.example.com",
"tinyauth.apps.myapp.users.allow": "alice",
})
svc.updateFromItem(&item)
got := svc.getByDomain("myapp.example.com")
require.NotNil(t, got)
assert.Equal(t, "myapp.example.com", got.Config.Domain)
assert.Equal(t, "alice", got.Users.Allow)
},
},
{
description: "UpdateFromItem with no annotations removes existing cache entries",
run: func(t *testing.T, svc *KubernetesService) {
app := model.App{Config: model.AppConfig{Domain: "todelete.example.com"}}
svc.addIngressApps("default", "test-ingress", []ingressApp{
{domain: "todelete.example.com", appName: "todelete", app: app},
})
item := unstructured.Unstructured{}
item.SetNamespace("default")
item.SetName("test-ingress")
svc.updateFromItem(&item)
got := svc.getByDomain("todelete.example.com")
assert.Nil(t, got)
},
},
}
for _, test := range tests {
t.Run(test.description, func(t *testing.T) {
svc := &KubernetesService{
ingressApps: make(map[ingressKey][]ingressApp),
domainIndex: make(map[string]ingressAppKey),
appNameIndex: make(map[string]ingressAppKey),
log: log,
}
test.run(t, svc)
})
}
}
+62 -71
View File
@@ -9,69 +9,47 @@ import (
"github.com/cenkalti/backoff/v5"
ldapgo "github.com/go-ldap/ldap/v3"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
)
type LdapServiceConfig struct {
Address string
BindDN string
BindPassword string
BaseDN string
Insecure bool
SearchFilter string
AuthCert string
AuthKey string
}
type LdapService struct {
config LdapServiceConfig
conn *ldapgo.Conn
mutex sync.RWMutex
cert *tls.Certificate
isConfigured bool
log *logger.Logger
config model.Config
context context.Context
conn *ldapgo.Conn
mutex sync.RWMutex
cert *tls.Certificate
}
func NewLdapService(config LdapServiceConfig) *LdapService {
return &LdapService{
config: config,
}
}
func (ldap *LdapService) IsConfigured() bool {
return ldap.isConfigured
}
func (ldap *LdapService) Unconfigure() error {
if !ldap.isConfigured {
return nil
func NewLdapService(
log *logger.Logger,
config model.Config,
ctx context.Context,
wg *sync.WaitGroup,
) (*LdapService, error) {
if config.LDAP.Address == "" {
return nil, nil
}
if ldap.conn != nil {
if err := ldap.conn.Close(); err != nil {
return fmt.Errorf("failed to close LDAP connection: %w", err)
}
ldap := &LdapService{
log: log,
config: config,
context: ctx,
}
ldap.isConfigured = false
return nil
}
func (ldap *LdapService) Init() error {
if ldap.config.Address == "" {
ldap.isConfigured = false
return nil
}
ldap.isConfigured = true
// Check whether authentication with client certificate is possible
if ldap.config.AuthCert != "" && ldap.config.AuthKey != "" {
cert, err := tls.LoadX509KeyPair(ldap.config.AuthCert, ldap.config.AuthKey)
if config.LDAP.AuthCert != "" && config.LDAP.AuthKey != "" {
cert, err := tls.LoadX509KeyPair(config.LDAP.AuthCert, config.LDAP.AuthKey)
if err != nil {
return fmt.Errorf("failed to initialize LDAP with mTLS authentication: %w", err)
return nil, fmt.Errorf("failed to initialize LDAP with mTLS authentication: %w", err)
}
log.App.Info().Msg("LDAP mTLS authentication configured successfully")
ldap.cert = &cert
tlog.App.Info().Msg("Using LDAP with mTLS authentication")
// TODO: Add optional extra CA certificates, instead of `InsecureSkipVerify`
/*
@@ -84,26 +62,39 @@ func (ldap *LdapService) Init() error {
}
*/
}
_, err := ldap.connect()
if err != nil {
return fmt.Errorf("failed to connect to LDAP server: %w", err)
return nil, fmt.Errorf("failed to connect to ldap server: %w", err)
}
go func() {
for range time.Tick(time.Duration(5) * time.Minute) {
err := ldap.heartbeat()
if err != nil {
tlog.App.Error().Err(err).Msg("LDAP connection heartbeat failed")
if reconnectErr := ldap.reconnect(); reconnectErr != nil {
tlog.App.Error().Err(reconnectErr).Msg("Failed to reconnect to LDAP server")
continue
wg.Go(func() {
ldap.log.App.Debug().Msg("Starting LDAP connection heartbeat routine")
ticker := time.NewTicker(5 * time.Minute)
defer ticker.Stop()
for {
select {
case <-ticker.C:
err := ldap.heartbeat()
if err != nil {
ldap.log.App.Warn().Err(err).Msg("LDAP connection heartbeat failed, attempting to reconnect")
if reconnectErr := ldap.reconnect(); reconnectErr != nil {
ldap.log.App.Error().Err(reconnectErr).Msg("Failed to reconnect to LDAP server")
continue
}
ldap.log.App.Info().Msg("Successfully reconnected to LDAP server")
}
tlog.App.Info().Msg("Successfully reconnected to LDAP server")
case <-ldap.context.Done():
ldap.log.App.Debug().Msg("LDAP service context cancelled, stopping heartbeat")
return
}
}
}()
})
return nil
return ldap, nil
}
func (ldap *LdapService) connect() (*ldapgo.Conn, error) {
@@ -120,13 +111,13 @@ func (ldap *LdapService) connect() (*ldapgo.Conn, error) {
// 2. conn.StartTLS(tlsConfig)
// 3. conn.externalBind()
if ldap.cert != nil {
conn, err = ldapgo.DialURL(ldap.config.Address, ldapgo.DialWithTLSConfig(&tls.Config{
conn, err = ldapgo.DialURL(ldap.config.LDAP.Address, ldapgo.DialWithTLSConfig(&tls.Config{
MinVersion: tls.VersionTLS12,
Certificates: []tls.Certificate{*ldap.cert},
}))
} else {
conn, err = ldapgo.DialURL(ldap.config.Address, ldapgo.DialWithTLSConfig(&tls.Config{
InsecureSkipVerify: ldap.config.Insecure,
conn, err = ldapgo.DialURL(ldap.config.LDAP.Address, ldapgo.DialWithTLSConfig(&tls.Config{
InsecureSkipVerify: ldap.config.LDAP.Insecure,
MinVersion: tls.VersionTLS12,
}))
}
@@ -146,10 +137,10 @@ func (ldap *LdapService) connect() (*ldapgo.Conn, error) {
func (ldap *LdapService) GetUserDN(username string) (string, error) {
// Escape the username to prevent LDAP injection
escapedUsername := ldapgo.EscapeFilter(username)
filter := fmt.Sprintf(ldap.config.SearchFilter, escapedUsername)
filter := fmt.Sprintf(ldap.config.LDAP.SearchFilter, escapedUsername)
searchRequest := ldapgo.NewSearchRequest(
ldap.config.BaseDN,
ldap.config.LDAP.BaseDN,
ldapgo.ScopeWholeSubtree, ldapgo.NeverDerefAliases, 0, 0, false,
filter,
[]string{"dn"},
@@ -176,7 +167,7 @@ func (ldap *LdapService) GetUserGroups(userDN string) ([]string, error) {
escapedUserDN := ldapgo.EscapeFilter(userDN)
searchRequest := ldapgo.NewSearchRequest(
ldap.config.BaseDN,
ldap.config.LDAP.BaseDN,
ldapgo.ScopeWholeSubtree, ldapgo.NeverDerefAliases, 0, 0, false,
fmt.Sprintf("(&(objectclass=groupOfUniqueNames)(uniquemember=%s))", escapedUserDN),
[]string{"dn"},
@@ -224,7 +215,7 @@ func (ldap *LdapService) BindService(rebind bool) error {
if ldap.cert != nil {
return ldap.conn.ExternalBind()
}
return ldap.conn.Bind(ldap.config.BindDN, ldap.config.BindPassword)
return ldap.conn.Bind(ldap.config.LDAP.BindDN, ldap.config.LDAP.BindPassword)
}
func (ldap *LdapService) Bind(userDN string, password string) error {
@@ -238,7 +229,7 @@ func (ldap *LdapService) Bind(userDN string, password string) error {
}
func (ldap *LdapService) heartbeat() error {
tlog.App.Debug().Msg("Performing LDAP connection heartbeat")
ldap.log.App.Debug().Msg("Performing LDAP connection heartbeat")
searchRequest := ldapgo.NewSearchRequest(
"",
@@ -260,7 +251,7 @@ func (ldap *LdapService) heartbeat() error {
}
func (ldap *LdapService) reconnect() error {
tlog.App.Info().Msg("Reconnecting to LDAP server")
ldap.log.App.Info().Msg("Attempting to reconnect to LDAP server")
exp := backoff.NewExponentialBackOff()
exp.InitialInterval = 500 * time.Millisecond
+25 -16
View File
@@ -1,10 +1,13 @@
package service
import (
"github.com/tinyauthapp/tinyauth/internal/config"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"context"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
"slices"
"golang.org/x/exp/slices"
"golang.org/x/oauth2"
)
@@ -14,37 +17,43 @@ type OAuthServiceImpl interface {
NewRandom() string
GetAuthURL(state string, verifier string) string
GetToken(code string, verifier string) (*oauth2.Token, error)
GetUserinfo(token *oauth2.Token) (config.Claims, error)
GetUserinfo(token *oauth2.Token) (*model.Claims, error)
}
type OAuthBrokerService struct {
log *logger.Logger
services map[string]OAuthServiceImpl
configs map[string]config.OAuthServiceConfig
configs map[string]model.OAuthServiceConfig
}
var presets = map[string]func(config config.OAuthServiceConfig) *OAuthService{
var presets = map[string]func(config model.OAuthServiceConfig, ctx context.Context) *OAuthService{
"github": newGitHubOAuthService,
"google": newGoogleOAuthService,
}
func NewOAuthBrokerService(configs map[string]config.OAuthServiceConfig) *OAuthBrokerService {
return &OAuthBrokerService{
func NewOAuthBrokerService(
log *logger.Logger,
configs map[string]model.OAuthServiceConfig,
ctx context.Context,
) *OAuthBrokerService {
service := &OAuthBrokerService{
log: log,
services: make(map[string]OAuthServiceImpl),
configs: configs,
}
}
func (broker *OAuthBrokerService) Init() error {
for name, cfg := range broker.configs {
for name, cfg := range configs {
if presetFunc, exists := presets[name]; exists {
broker.services[name] = presetFunc(cfg)
tlog.App.Debug().Str("service", name).Msg("Loaded OAuth service from preset")
service.services[name] = presetFunc(cfg, ctx)
service.log.App.Debug().Str("service", name).Msg("Loaded OAuth service from preset")
} else {
broker.services[name] = NewOAuthService(cfg, name)
tlog.App.Debug().Str("service", name).Msg("Loaded OAuth service from config")
service.services[name] = NewOAuthService(cfg, name, ctx)
service.log.App.Debug().Str("service", name).Msg("Loaded OAuth service from custom config")
}
}
return nil
return service
}
func (broker *OAuthBrokerService) GetConfiguredServices() []string {
+32 -22
View File
@@ -8,12 +8,13 @@ import (
"net/http"
"strconv"
"github.com/tinyauthapp/tinyauth/internal/config"
"github.com/tinyauthapp/tinyauth/internal/model"
)
type GithubEmailResponse []struct {
Email string `json:"email"`
Primary bool `json:"primary"`
Email string `json:"email"`
Primary bool `json:"primary"`
Verified bool `json:"verified"`
}
type GithubUserInfoResponse struct {
@@ -22,33 +23,33 @@ type GithubUserInfoResponse struct {
ID int `json:"id"`
}
func defaultExtractor(client *http.Client, url string) (config.Claims, error) {
return simpleReq[config.Claims](client, url, nil)
func defaultExtractor(client *http.Client, url string) (*model.Claims, error) {
return simpleReq[model.Claims](client, url, nil)
}
func githubExtractor(client *http.Client, url string) (config.Claims, error) {
var user config.Claims
func githubExtractor(client *http.Client, _ string) (*model.Claims, error) {
var user model.Claims
userInfo, err := simpleReq[GithubUserInfoResponse](client, "https://api.github.com/user", map[string]string{
"accept": "application/vnd.github+json",
})
if err != nil {
return config.Claims{}, err
return nil, err
}
userEmails, err := simpleReq[GithubEmailResponse](client, "https://api.github.com/user/emails", map[string]string{
"accept": "application/vnd.github+json",
})
if err != nil {
return config.Claims{}, err
return nil, err
}
if len(userEmails) == 0 {
return user, errors.New("no emails found")
if len(*userEmails) == 0 {
return nil, errors.New("no emails found")
}
for _, email := range userEmails {
if email.Primary {
for _, email := range *userEmails {
if email.Primary && email.Verified {
user.Email = email.Email
break
}
@@ -56,22 +57,31 @@ func githubExtractor(client *http.Client, url string) (config.Claims, error) {
// Use first available email if no primary email was found
if user.Email == "" {
user.Email = userEmails[0].Email
for _, email := range *userEmails {
if email.Verified {
user.Email = email.Email
break
}
}
}
if user.Email == "" {
return nil, errors.New("no verified email found")
}
user.PreferredUsername = userInfo.Login
user.Name = userInfo.Name
user.Sub = strconv.Itoa(userInfo.ID)
return user, nil
return &user, nil
}
func simpleReq[T any](client *http.Client, url string, headers map[string]string) (T, error) {
func simpleReq[T any](client *http.Client, url string, headers map[string]string) (*T, error) {
var decodedRes T
req, err := http.NewRequest("GET", url, nil)
if err != nil {
return decodedRes, err
return nil, err
}
for key, value := range headers {
@@ -80,23 +90,23 @@ func simpleReq[T any](client *http.Client, url string, headers map[string]string
res, err := client.Do(req)
if err != nil {
return decodedRes, err
return nil, err
}
defer res.Body.Close()
if res.StatusCode < 200 || res.StatusCode >= 300 {
return decodedRes, fmt.Errorf("request failed with status: %s", res.Status)
return nil, fmt.Errorf("request failed with status: %s", res.Status)
}
body, err := io.ReadAll(res.Body)
if err != nil {
return decodedRes, err
return nil, err
}
err = json.Unmarshal(body, &decodedRes)
if err != nil {
return decodedRes, err
return nil, err
}
return decodedRes, nil
return &decodedRes, nil
}
+7 -5
View File
@@ -1,23 +1,25 @@
package service
import (
"github.com/tinyauthapp/tinyauth/internal/config"
"context"
"github.com/tinyauthapp/tinyauth/internal/model"
"golang.org/x/oauth2/endpoints"
)
func newGoogleOAuthService(config config.OAuthServiceConfig) *OAuthService {
func newGoogleOAuthService(config model.OAuthServiceConfig, ctx context.Context) *OAuthService {
scopes := []string{"openid", "email", "profile"}
config.Scopes = scopes
config.AuthURL = endpoints.Google.AuthURL
config.TokenURL = endpoints.Google.TokenURL
config.UserinfoURL = "https://openidconnect.googleapis.com/v1/userinfo"
return NewOAuthService(config, "google")
return NewOAuthService(config, "google", ctx)
}
func newGitHubOAuthService(config config.OAuthServiceConfig) *OAuthService {
func newGitHubOAuthService(config model.OAuthServiceConfig, ctx context.Context) *OAuthService {
scopes := []string{"read:user", "user:email"}
config.Scopes = scopes
config.AuthURL = endpoints.GitHub.AuthURL
config.TokenURL = endpoints.GitHub.TokenURL
return NewOAuthService(config, "github").WithUserinfoExtractor(githubExtractor)
return NewOAuthService(config, "github", ctx).WithUserinfoExtractor(githubExtractor)
}
+7 -8
View File
@@ -6,21 +6,21 @@ import (
"net/http"
"time"
"github.com/tinyauthapp/tinyauth/internal/config"
"github.com/tinyauthapp/tinyauth/internal/model"
"golang.org/x/oauth2"
)
type UserinfoExtractor func(client *http.Client, url string) (config.Claims, error)
type UserinfoExtractor func(client *http.Client, url string) (*model.Claims, error)
type OAuthService struct {
serviceCfg config.OAuthServiceConfig
serviceCfg model.OAuthServiceConfig
config *oauth2.Config
ctx context.Context
userinfoExtractor UserinfoExtractor
id string
}
func NewOAuthService(config config.OAuthServiceConfig, id string) *OAuthService {
func NewOAuthService(config model.OAuthServiceConfig, id string, ctx context.Context) *OAuthService {
httpClient := &http.Client{
Timeout: 30 * time.Second,
Transport: &http.Transport{
@@ -29,8 +29,7 @@ func NewOAuthService(config config.OAuthServiceConfig, id string) *OAuthService
},
},
}
ctx := context.Background()
ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient)
vctx := context.WithValue(ctx, oauth2.HTTPClient, httpClient)
return &OAuthService{
serviceCfg: config,
@@ -44,7 +43,7 @@ func NewOAuthService(config config.OAuthServiceConfig, id string) *OAuthService
TokenURL: config.TokenURL,
},
},
ctx: ctx,
ctx: vctx,
userinfoExtractor: defaultExtractor,
id: id,
}
@@ -78,7 +77,7 @@ func (s *OAuthService) GetToken(code string, verifier string) (*oauth2.Token, er
return s.config.Exchange(s.ctx, code, oauth2.VerifierOption(verifier))
}
func (s *OAuthService) GetUserinfo(token *oauth2.Token) (config.Claims, error) {
func (s *OAuthService) GetUserinfo(token *oauth2.Token) (*model.Claims, error) {
client := oauth2.NewClient(s.ctx, oauth2.StaticTokenSource(token))
return s.userinfoExtractor(client, s.serviceCfg.UserinfoURL)
}
+188 -177
View File
@@ -16,15 +16,17 @@ import (
"net/url"
"os"
"strings"
"sync"
"time"
"slices"
"github.com/gin-gonic/gin"
"github.com/go-jose/go-jose/v4"
"github.com/tinyauthapp/tinyauth/internal/config"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"golang.org/x/exp/slices"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
)
var (
@@ -67,27 +69,27 @@ type ClaimSet struct {
}
type UserinfoResponse struct {
Sub string `json:"sub"`
Name string `json:"name,omitempty"`
GivenName string `json:"given_name,omitempty"`
FamilyName string `json:"family_name,omitempty"`
MiddleName string `json:"middle_name,omitempty"`
Nickname string `json:"nickname,omitempty"`
Profile string `json:"profile,omitempty"`
Picture string `json:"picture,omitempty"`
Website string `json:"website,omitempty"`
Gender string `json:"gender,omitempty"`
Birthdate string `json:"birthdate,omitempty"`
Zoneinfo string `json:"zoneinfo,omitempty"`
Locale string `json:"locale,omitempty"`
Email string `json:"email,omitempty"`
PreferredUsername string `json:"preferred_username,omitempty"`
Groups []string `json:"groups,omitempty"`
EmailVerified bool `json:"email_verified,omitempty"`
PhoneNumber string `json:"phone_number,omitempty"`
PhoneNumberVerified *bool `json:"phone_number_verified,omitempty"`
Address *config.AddressClaim `json:"address,omitempty"`
UpdatedAt int64 `json:"updated_at"`
Sub string `json:"sub"`
Name string `json:"name,omitempty"`
GivenName string `json:"given_name,omitempty"`
FamilyName string `json:"family_name,omitempty"`
MiddleName string `json:"middle_name,omitempty"`
Nickname string `json:"nickname,omitempty"`
Profile string `json:"profile,omitempty"`
Picture string `json:"picture,omitempty"`
Website string `json:"website,omitempty"`
Gender string `json:"gender,omitempty"`
Birthdate string `json:"birthdate,omitempty"`
Zoneinfo string `json:"zoneinfo,omitempty"`
Locale string `json:"locale,omitempty"`
Email string `json:"email,omitempty"`
PreferredUsername string `json:"preferred_username,omitempty"`
Groups []string `json:"groups,omitempty"`
EmailVerified bool `json:"email_verified,omitempty"`
PhoneNumber string `json:"phone_number,omitempty"`
PhoneNumberVerified *bool `json:"phone_number_verified,omitempty"`
Address *model.AddressClaim `json:"address,omitempty"`
UpdatedAt int64 `json:"updated_at"`
}
type TokenResponse struct {
@@ -110,179 +112,180 @@ type AuthorizeRequest struct {
CodeChallengeMethod string `json:"code_challenge_method"`
}
type OIDCServiceConfig struct {
Clients map[string]config.OIDCClientConfig
PrivateKeyPath string
PublicKeyPath string
Issuer string
SessionExpiry int
}
type OIDCService struct {
config OIDCServiceConfig
queries *repository.Queries
clients map[string]config.OIDCClientConfig
privateKey *rsa.PrivateKey
publicKey crypto.PublicKey
issuer string
isConfigured bool
log *logger.Logger
config model.Config
runtime model.RuntimeConfig
queries *repository.Queries
context context.Context
clients map[string]model.OIDCClientConfig
privateKey *rsa.PrivateKey
publicKey crypto.PublicKey
issuer string
}
func NewOIDCService(config OIDCServiceConfig, queries *repository.Queries) *OIDCService {
return &OIDCService{
config: config,
queries: queries,
}
}
func (service *OIDCService) IsConfigured() bool {
return service.isConfigured
}
func (service *OIDCService) Init() error {
func NewOIDCService(
log *logger.Logger,
config model.Config,
runtime model.RuntimeConfig,
queries *repository.Queries,
ctx context.Context,
wg *sync.WaitGroup) (*OIDCService, error) {
// If not configured, skip init
if len(service.config.Clients) == 0 {
service.isConfigured = false
return nil
if len(runtime.OIDCClients) == 0 {
return nil, nil
}
service.isConfigured = true
// Ensure issuer is https
uissuer, err := url.Parse(service.config.Issuer)
uissuer, err := url.Parse(runtime.AppURL)
if err != nil {
return err
return nil, fmt.Errorf("failed to parse app url: %w", err)
}
if uissuer.Scheme != "https" {
return errors.New("issuer must be https")
return nil, errors.New("issuer must be https")
}
service.issuer = fmt.Sprintf("%s://%s", uissuer.Scheme, uissuer.Host)
issuer := fmt.Sprintf("%s://%s", uissuer.Scheme, uissuer.Host)
// Create/load private and public keys
if strings.TrimSpace(service.config.PrivateKeyPath) == "" ||
strings.TrimSpace(service.config.PublicKeyPath) == "" {
return errors.New("private key path and public key path are required")
if strings.TrimSpace(config.OIDC.PrivateKeyPath) == "" ||
strings.TrimSpace(config.OIDC.PublicKeyPath) == "" {
return nil, errors.New("private key path and public key path are required")
}
var privateKey *rsa.PrivateKey
fprivateKey, err := os.ReadFile(service.config.PrivateKeyPath)
fprivateKey, err := os.ReadFile(config.OIDC.PrivateKeyPath)
if err != nil && !errors.Is(err, os.ErrNotExist) {
return err
return nil, err
}
if errors.Is(err, os.ErrNotExist) {
privateKey, err = rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return err
return nil, fmt.Errorf("failed to generate private key: %w", err)
}
der := x509.MarshalPKCS1PrivateKey(privateKey)
if der == nil {
return errors.New("failed to marshal private key")
return nil, errors.New("failed to marshal private key")
}
encoded := pem.EncodeToMemory(&pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: der,
})
tlog.App.Trace().Str("type", "RSA PRIVATE KEY").Msg("Generated private RSA key")
err = os.WriteFile(service.config.PrivateKeyPath, encoded, 0600)
log.App.Trace().Str("type", "RSA PRIVATE KEY").Msg("Generated private RSA key")
err = os.WriteFile(config.OIDC.PrivateKeyPath, encoded, 0600)
if err != nil {
return err
return nil, fmt.Errorf("failed to write private key to file: %w", err)
}
service.privateKey = privateKey
} else {
block, _ := pem.Decode(fprivateKey)
if block == nil {
return errors.New("failed to decode private key")
return nil, errors.New("failed to decode private key")
}
tlog.App.Trace().Str("type", block.Type).Msg("Loaded private key")
log.App.Trace().Str("type", block.Type).Msg("Loaded private key")
privateKey, err = x509.ParsePKCS1PrivateKey(block.Bytes)
if err != nil {
return err
return nil, fmt.Errorf("failed to parse private key: %w", err)
}
service.privateKey = privateKey
}
fpublicKey, err := os.ReadFile(service.config.PublicKeyPath)
var publicKey crypto.PublicKey
fpublicKey, err := os.ReadFile(config.OIDC.PublicKeyPath)
if err != nil && !errors.Is(err, os.ErrNotExist) {
return err
return nil, fmt.Errorf("failed to read public key: %w", err)
}
if errors.Is(err, os.ErrNotExist) {
publicKey := service.privateKey.Public()
publicKey = privateKey.Public()
der := x509.MarshalPKCS1PublicKey(publicKey.(*rsa.PublicKey))
if der == nil {
return errors.New("failed to marshal public key")
return nil, errors.New("failed to marshal public key")
}
encoded := pem.EncodeToMemory(&pem.Block{
Type: "RSA PUBLIC KEY",
Bytes: der,
})
tlog.App.Trace().Str("type", "RSA PUBLIC KEY").Msg("Generated public RSA key")
err = os.WriteFile(service.config.PublicKeyPath, encoded, 0644)
log.App.Trace().Str("type", "RSA PUBLIC KEY").Msg("Generated public RSA key")
err = os.WriteFile(config.OIDC.PublicKeyPath, encoded, 0644)
if err != nil {
return err
return nil, err
}
service.publicKey = publicKey
} else {
block, _ := pem.Decode(fpublicKey)
if block == nil {
return errors.New("failed to decode public key")
return nil, errors.New("failed to decode public key")
}
tlog.App.Trace().Str("type", block.Type).Msg("Loaded public key")
log.App.Trace().Str("type", block.Type).Msg("Loaded public key")
switch block.Type {
case "RSA PUBLIC KEY":
publicKey, err := x509.ParsePKCS1PublicKey(block.Bytes)
publicKey, err = x509.ParsePKCS1PublicKey(block.Bytes)
if err != nil {
return err
return nil, fmt.Errorf("failed to parse public key: %w", err)
}
service.publicKey = publicKey
case "PUBLIC KEY":
publicKey, err := x509.ParsePKIXPublicKey(block.Bytes)
publicKey, err = x509.ParsePKIXPublicKey(block.Bytes)
if err != nil {
return err
return nil, fmt.Errorf("failed to parse public key: %w", err)
}
service.publicKey = publicKey.(crypto.PublicKey)
default:
return fmt.Errorf("unsupported public key type: %s", block.Type)
return nil, fmt.Errorf("unsupported public key type: %s", block.Type)
}
}
// We will reorganize the client into a map with the client ID as the key
service.clients = make(map[string]config.OIDCClientConfig)
clients := make(map[string]model.OIDCClientConfig)
for id, client := range service.config.Clients {
for id, client := range config.OIDC.Clients {
client.ID = id
if client.Name == "" {
client.Name = utils.Capitalize(client.ID)
}
service.clients[client.ClientID] = client
clients[client.ClientID] = client
}
// Load the client secrets from files if they exist
for id, client := range service.clients {
for id, client := range clients {
secret := utils.GetSecret(client.ClientSecret, client.ClientSecretFile)
if secret != "" {
client.ClientSecret = secret
}
client.ClientSecretFile = ""
service.clients[id] = client
tlog.App.Info().Str("id", client.ID).Msg("Registered OIDC client")
clients[id] = client
log.App.Debug().Str("clientId", client.ClientID).Msg("Loaded OIDC client configuration")
}
return nil
// Initialize the service
service := &OIDCService{
log: log,
config: config,
runtime: runtime,
queries: queries,
context: ctx,
clients: clients,
privateKey: privateKey,
publicKey: publicKey,
issuer: issuer,
}
// Start cleanup routine
wg.Go(service.cleanupRoutine)
return service, nil
}
func (service *OIDCService) GetIssuer() string {
return service.issuer
}
func (service *OIDCService) GetClient(id string) (config.OIDCClientConfig, bool) {
func (service *OIDCService) GetClient(id string) (model.OIDCClientConfig, bool) {
client, ok := service.clients[id]
return client, ok
}
@@ -306,7 +309,7 @@ func (service *OIDCService) ValidateAuthorizeParams(req AuthorizeRequest) error
return errors.New("invalid_scope")
}
if !slices.Contains(SupportedScopes, scope) {
tlog.App.Warn().Str("scope", scope).Msg("Unsupported OIDC scope, will be ignored")
service.log.App.Warn().Str("scope", scope).Msg("Requested unsupported scope")
}
}
@@ -356,7 +359,7 @@ func (service *OIDCService) StoreCode(c *gin.Context, sub string, code string, r
entry.CodeChallenge = req.CodeChallenge
} else {
entry.CodeChallenge = service.hashAndEncodePKCE(req.CodeChallenge)
tlog.App.Warn().Msg("Received plain PKCE code challenge, it's recommended to use S256 for better security")
service.log.App.Warn().Msg("Using plain PKCE code challenge method is not recommended, consider switching to S256 for better security")
}
}
@@ -366,43 +369,45 @@ func (service *OIDCService) StoreCode(c *gin.Context, sub string, code string, r
return err
}
func (service *OIDCService) StoreUserinfo(c *gin.Context, sub string, userContext config.UserContext, req AuthorizeRequest) error {
addressJSON, err := json.Marshal(userContext.Attributes.Address)
if err != nil {
return err
}
func (service *OIDCService) StoreUserinfo(c *gin.Context, sub string, userContext model.UserContext, req AuthorizeRequest) error {
userInfoParams := repository.CreateOidcUserInfoParams{
Sub: sub,
Name: userContext.Name,
Email: userContext.Email,
PreferredUsername: userContext.Username,
Name: userContext.GetName(),
Email: userContext.GetEmail(),
PreferredUsername: userContext.GetUsername(),
UpdatedAt: time.Now().Unix(),
GivenName: userContext.Attributes.GivenName,
FamilyName: userContext.Attributes.FamilyName,
MiddleName: userContext.Attributes.MiddleName,
Nickname: userContext.Attributes.Nickname,
Profile: userContext.Attributes.Profile,
Picture: userContext.Attributes.Picture,
Website: userContext.Attributes.Website,
Gender: userContext.Attributes.Gender,
Birthdate: userContext.Attributes.Birthdate,
Zoneinfo: userContext.Attributes.Zoneinfo,
Locale: userContext.Attributes.Locale,
PhoneNumber: userContext.Attributes.PhoneNumber,
Address: string(addressJSON),
}
if userContext.IsLocal() {
addressJSON, err := json.Marshal(userContext.Local.Attributes.Address)
if err != nil {
return err
}
userInfoParams.GivenName = userContext.Local.Attributes.GivenName
userInfoParams.FamilyName = userContext.Local.Attributes.FamilyName
userInfoParams.MiddleName = userContext.Local.Attributes.MiddleName
userInfoParams.Nickname = userContext.Local.Attributes.Nickname
userInfoParams.Profile = userContext.Local.Attributes.Profile
userInfoParams.Picture = userContext.Local.Attributes.Picture
userInfoParams.Website = userContext.Local.Attributes.Website
userInfoParams.Gender = userContext.Local.Attributes.Gender
userInfoParams.Birthdate = userContext.Local.Attributes.Birthdate
userInfoParams.Zoneinfo = userContext.Local.Attributes.Zoneinfo
userInfoParams.Locale = userContext.Local.Attributes.Locale
userInfoParams.PhoneNumber = userContext.Local.Attributes.PhoneNumber
userInfoParams.Address = string(addressJSON)
}
// Tinyauth will pass through the groups it got from an LDAP or an OIDC server
if userContext.Provider == "ldap" {
userInfoParams.Groups = userContext.LdapGroups
if userContext.IsLDAP() {
userInfoParams.Groups = strings.Join(userContext.LDAP.Groups, ",")
}
if userContext.OAuth && len(userContext.OAuthGroups) > 0 {
userInfoParams.Groups = userContext.OAuthGroups
if userContext.IsOAuth() {
userInfoParams.Groups = strings.Join(userContext.OAuth.Groups, ",")
}
_, err = service.queries.CreateOidcUserInfo(c, userInfoParams)
_, err := service.queries.CreateOidcUserInfo(c, userInfoParams)
return err
}
@@ -444,9 +449,9 @@ func (service *OIDCService) GetCodeEntry(c *gin.Context, codeHash string, client
return oidcCode, nil
}
func (service *OIDCService) generateIDToken(client config.OIDCClientConfig, user repository.OidcUserinfo, scope string, nonce string) (string, error) {
func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user repository.OidcUserinfo, scope string, nonce string) (string, error) {
createdAt := time.Now().Unix()
expiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry) * time.Second).Unix()
expiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry) * time.Second).Unix()
hasher := sha256.New()
@@ -510,7 +515,7 @@ func (service *OIDCService) generateIDToken(client config.OIDCClientConfig, user
return token, nil
}
func (service *OIDCService) GenerateAccessToken(c *gin.Context, client config.OIDCClientConfig, codeEntry repository.OidcCode) (TokenResponse, error) {
func (service *OIDCService) GenerateAccessToken(c *gin.Context, client model.OIDCClientConfig, codeEntry repository.OidcCode) (TokenResponse, error) {
user, err := service.GetUserinfo(c, codeEntry.Sub)
if err != nil {
@@ -526,16 +531,16 @@ func (service *OIDCService) GenerateAccessToken(c *gin.Context, client config.OI
accessToken := utils.GenerateString(32)
refreshToken := utils.GenerateString(32)
tokenExpiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry) * time.Second).Unix()
tokenExpiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry) * time.Second).Unix()
// Refresh token lives double the time of an access token but can't be used to access userinfo
refrshTokenExpiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry*2) * time.Second).Unix()
refreshTokenExpiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry*2) * time.Second).Unix()
tokenResponse := TokenResponse{
AccessToken: accessToken,
RefreshToken: refreshToken,
TokenType: "Bearer",
ExpiresIn: int64(service.config.SessionExpiry),
ExpiresIn: int64(service.config.Auth.SessionExpiry),
IDToken: idToken,
Scope: strings.ReplaceAll(codeEntry.Scope, ",", " "),
}
@@ -547,7 +552,7 @@ func (service *OIDCService) GenerateAccessToken(c *gin.Context, client config.OI
ClientID: client.ClientID,
Scope: codeEntry.Scope,
TokenExpiresAt: tokenExpiresAt,
RefreshTokenExpiresAt: refrshTokenExpiresAt,
RefreshTokenExpiresAt: refreshTokenExpiresAt,
Nonce: codeEntry.Nonce,
CodeHash: codeEntry.CodeHash,
})
@@ -563,7 +568,7 @@ func (service *OIDCService) RefreshAccessToken(c *gin.Context, refreshToken stri
entry, err := service.queries.GetOidcTokenByRefreshToken(c, service.Hash(refreshToken))
if err != nil {
if err == sql.ErrNoRows {
if errors.Is(err, sql.ErrNoRows) {
return TokenResponse{}, ErrTokenNotFound
}
return TokenResponse{}, err
@@ -584,7 +589,7 @@ func (service *OIDCService) RefreshAccessToken(c *gin.Context, refreshToken stri
return TokenResponse{}, err
}
idToken, err := service.generateIDToken(config.OIDCClientConfig{
idToken, err := service.generateIDToken(model.OIDCClientConfig{
ClientID: entry.ClientID,
}, user, entry.Scope, entry.Nonce)
@@ -595,14 +600,14 @@ func (service *OIDCService) RefreshAccessToken(c *gin.Context, refreshToken stri
accessToken := utils.GenerateString(32)
newRefreshToken := utils.GenerateString(32)
tokenExpiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry) * time.Second).Unix()
refrshTokenExpiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry*2) * time.Second).Unix()
tokenExpiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry) * time.Second).Unix()
refreshTokenExpiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry*2) * time.Second).Unix()
tokenResponse := TokenResponse{
AccessToken: accessToken,
RefreshToken: newRefreshToken,
TokenType: "Bearer",
ExpiresIn: int64(service.config.SessionExpiry),
ExpiresIn: int64(service.config.Auth.SessionExpiry),
IDToken: idToken,
Scope: strings.ReplaceAll(entry.Scope, ",", " "),
}
@@ -611,7 +616,7 @@ func (service *OIDCService) RefreshAccessToken(c *gin.Context, refreshToken stri
AccessTokenHash: service.Hash(accessToken),
RefreshTokenHash: service.Hash(newRefreshToken),
TokenExpiresAt: tokenExpiresAt,
RefreshTokenExpiresAt: refrshTokenExpiresAt,
RefreshTokenExpiresAt: refreshTokenExpiresAt,
RefreshTokenHash_2: service.Hash(refreshToken), // that's the selector, it's not stored in the db
})
@@ -642,7 +647,7 @@ func (service *OIDCService) GetAccessToken(c *gin.Context, tokenHash string) (re
entry, err := service.queries.GetOidcToken(c, tokenHash)
if err != nil {
if err == sql.ErrNoRows {
if errors.Is(err, sql.ErrNoRows) {
return repository.OidcToken{}, ErrTokenNotFound
}
return repository.OidcToken{}, err
@@ -713,7 +718,7 @@ func (service *OIDCService) CompileUserinfo(user repository.OidcUserinfo, scope
}
if slices.Contains(scopes, "address") {
var addr config.AddressClaim
var addr model.AddressClaim
if err := json.Unmarshal([]byte(user.Address), &addr); err == nil {
userInfo.Address = &addr
}
@@ -745,56 +750,62 @@ func (service *OIDCService) DeleteOldSession(ctx context.Context, sub string) er
}
// Cleanup routine - Resource heavy due to the linked tables
func (service *OIDCService) Cleanup() {
// We need a context for the routine
ctx := context.Background()
func (service *OIDCService) cleanupRoutine() {
service.log.App.Debug().Msg("Starting OIDC cleanup routine")
ticker := time.NewTicker(time.Duration(30) * time.Minute)
defer ticker.Stop()
for range ticker.C {
currentTime := time.Now().Unix()
for {
select {
case <-ticker.C:
service.log.App.Debug().Msg("Performing OIDC cleanup routine")
// For the OIDC tokens, if they are expired we delete the userinfo and codes
expiredTokens, err := service.queries.DeleteExpiredOidcTokens(ctx, repository.DeleteExpiredOidcTokensParams{
TokenExpiresAt: currentTime,
RefreshTokenExpiresAt: currentTime,
})
currentTime := time.Now().Unix()
if err != nil {
tlog.App.Warn().Err(err).Msg("Failed to delete expired tokens")
}
// For the OIDC tokens, if they are expired we delete the userinfo and codes
expiredTokens, err := service.queries.DeleteExpiredOidcTokens(service.context, repository.DeleteExpiredOidcTokensParams{
TokenExpiresAt: currentTime,
RefreshTokenExpiresAt: currentTime,
})
for _, expiredToken := range expiredTokens {
err := service.DeleteOldSession(ctx, expiredToken.Sub)
if err != nil {
tlog.App.Warn().Err(err).Msg("Failed to delete old session")
service.log.App.Warn().Err(err).Msg("Failed to delete expired tokens")
}
}
// For expired codes, we need to get the sub, check if tokens are expired and if they are remove everything
expiredCodes, err := service.queries.DeleteExpiredOidcCodes(ctx, currentTime)
for _, expiredToken := range expiredTokens {
err := service.DeleteOldSession(service.context, expiredToken.Sub)
if err != nil {
service.log.App.Warn().Err(err).Msg("Failed to delete session for expired token")
}
}
if err != nil {
tlog.App.Warn().Err(err).Msg("Failed to delete expired codes")
}
for _, expiredCode := range expiredCodes {
token, err := service.queries.GetOidcTokenBySub(ctx, expiredCode.Sub)
// For expired codes, we need to get the sub, check if tokens are expired and if they are remove everything
expiredCodes, err := service.queries.DeleteExpiredOidcCodes(service.context, currentTime)
if err != nil {
if err == sql.ErrNoRows {
service.log.App.Warn().Err(err).Msg("Failed to delete expired codes")
}
for _, expiredCode := range expiredCodes {
token, err := service.queries.GetOidcTokenBySub(service.context, expiredCode.Sub)
if err != nil {
service.log.App.Warn().Err(err).Msg("Failed to get token by sub for expired code")
continue
}
tlog.App.Warn().Err(err).Msg("Failed to get OIDC token by sub")
}
if token.TokenExpiresAt < currentTime && token.RefreshTokenExpiresAt < currentTime {
err := service.DeleteOldSession(ctx, expiredCode.Sub)
if err != nil {
tlog.App.Warn().Err(err).Msg("Failed to delete session")
if token.TokenExpiresAt < currentTime && token.RefreshTokenExpiresAt < currentTime {
err := service.DeleteOldSession(service.context, expiredCode.Sub)
if err != nil {
service.log.App.Warn().Err(err).Msg("Failed to delete session for expired code")
}
}
}
service.log.App.Debug().Msg("Finished OIDC cleanup routine")
case <-service.context.Done():
service.log.App.Debug().Msg("Stopping OIDC cleanup routine")
return
}
}
}
+28 -9
View File
@@ -1,19 +1,22 @@
package service_test
import (
"context"
"encoding/json"
"sync"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/config"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
)
func newTestUser() repository.OidcUserinfo {
addr := config.AddressClaim{
addr := model.AddressClaim{
Formatted: "123 Main St",
StreetAddress: "123 Main St",
Locality: "Springfield",
@@ -48,13 +51,29 @@ func newTestUser() repository.OidcUserinfo {
func TestCompileUserinfo(t *testing.T) {
dir := t.TempDir()
svc := service.NewOIDCService(service.OIDCServiceConfig{
PrivateKeyPath: dir + "/key.pem",
PublicKeyPath: dir + "/key.pub",
Issuer: "https://tinyauth.example.com",
SessionExpiry: 3600,
}, nil)
require.NoError(t, svc.Init())
cfg := model.Config{
OIDC: model.OIDCConfig{
PrivateKeyPath: dir + "/key.pem",
PublicKeyPath: dir + "/key.pub",
},
Auth: model.AuthConfig{
SessionExpiry: 3600,
},
}
runtime := model.RuntimeConfig{
AppURL: "https://tinyauth.example.com",
}
log := logger.NewLogger().WithTestConfig()
log.Init()
ctx := context.TODO()
wg := &sync.WaitGroup{}
svc, err := service.NewOIDCService(log, cfg, runtime, nil, ctx, wg)
require.NoError(t, err)
type testCase struct {
description string