feat: create oauth broker service

This commit is contained in:
Stavros
2025-08-25 19:33:52 +03:00
parent 44f35af3bf
commit dbadb096b4
12 changed files with 184 additions and 683 deletions

View File

@@ -6,7 +6,7 @@ import (
"strings"
"sync"
"time"
"tinyauth/internal/types"
"tinyauth/internal/config"
"tinyauth/internal/utils"
"github.com/gin-gonic/gin"
@@ -15,8 +15,14 @@ import (
"golang.org/x/crypto/bcrypt"
)
type LoginAttempt struct {
FailedAttempts int
LastAttempt time.Time
LockedUntil time.Time
}
type AuthServiceConfig struct {
Users types.Users
Users []config.User
OauthWhitelist string
SessionExpiry int
CookieSecure bool
@@ -31,7 +37,7 @@ type AuthServiceConfig struct {
type AuthService struct {
Config AuthServiceConfig
Docker *DockerService
LoginAttempts map[string]*types.LoginAttempt
LoginAttempts map[string]*LoginAttempt
LoginMutex sync.RWMutex
Store *sessions.CookieStore
LDAP *LdapService
@@ -41,7 +47,7 @@ func NewAuthService(config AuthServiceConfig, docker *DockerService, ldap *LdapS
return &AuthService{
Config: config,
Docker: docker,
LoginAttempts: make(map[string]*types.LoginAttempt),
LoginAttempts: make(map[string]*LoginAttempt),
LDAP: ldap,
}
}
@@ -75,13 +81,13 @@ func (auth *AuthService) GetSession(c *gin.Context) (*sessions.Session, error) {
return session, nil
}
func (auth *AuthService) SearchUser(username string) types.UserSearch {
func (auth *AuthService) SearchUser(username string) config.UserSearch {
log.Debug().Str("username", username).Msg("Searching for user")
// Check local users first
if auth.GetLocalUser(username).Username != "" {
log.Debug().Str("username", username).Msg("Found local user")
return types.UserSearch{
return config.UserSearch{
Username: username,
Type: "local",
}
@@ -93,20 +99,20 @@ func (auth *AuthService) SearchUser(username string) types.UserSearch {
userDN, err := auth.LDAP.Search(username)
if err != nil {
log.Warn().Err(err).Str("username", username).Msg("Failed to find user in LDAP")
return types.UserSearch{}
return config.UserSearch{}
}
return types.UserSearch{
return config.UserSearch{
Username: userDN,
Type: "ldap",
}
}
return types.UserSearch{
return config.UserSearch{
Type: "unknown",
}
}
func (auth *AuthService) VerifyUser(search types.UserSearch, password string) bool {
func (auth *AuthService) VerifyUser(search config.UserSearch, password string) bool {
// Authenticate the user based on the type
switch search.Type {
case "local":
@@ -144,7 +150,7 @@ func (auth *AuthService) VerifyUser(search types.UserSearch, password string) bo
return false
}
func (auth *AuthService) GetLocalUser(username string) types.User {
func (auth *AuthService) GetLocalUser(username string) config.User {
// Loop through users and return the user if the username matches
log.Debug().Str("username", username).Msg("Searching for local user")
@@ -156,10 +162,10 @@ func (auth *AuthService) GetLocalUser(username string) types.User {
// If no user found, return an empty user
log.Warn().Str("username", username).Msg("Local user not found")
return types.User{}
return config.User{}
}
func (auth *AuthService) CheckPassword(user types.User, password string) bool {
func (auth *AuthService) CheckPassword(user config.User, password string) bool {
return bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)) == nil
}
@@ -201,7 +207,7 @@ func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) {
// Get current attempt record or create a new one
attempt, exists := auth.LoginAttempts[identifier]
if !exists {
attempt = &types.LoginAttempt{}
attempt = &LoginAttempt{}
auth.LoginAttempts[identifier] = attempt
}
@@ -229,7 +235,7 @@ func (auth *AuthService) EmailWhitelisted(email string) bool {
return utils.CheckFilter(auth.Config.OauthWhitelist, email)
}
func (auth *AuthService) CreateSessionCookie(c *gin.Context, data *types.SessionCookie) error {
func (auth *AuthService) CreateSessionCookie(c *gin.Context, data *config.SessionCookie) error {
log.Debug().Msg("Creating session cookie")
session, err := auth.GetSession(c)
@@ -288,13 +294,13 @@ func (auth *AuthService) DeleteSessionCookie(c *gin.Context) error {
return nil
}
func (auth *AuthService) GetSessionCookie(c *gin.Context) (types.SessionCookie, error) {
func (auth *AuthService) GetSessionCookie(c *gin.Context) (config.SessionCookie, error) {
log.Debug().Msg("Getting session cookie")
session, err := auth.GetSession(c)
if err != nil {
log.Error().Err(err).Msg("Failed to get session")
return types.SessionCookie{}, err
return config.SessionCookie{}, err
}
log.Debug().Msg("Got session")
@@ -311,18 +317,18 @@ func (auth *AuthService) GetSessionCookie(c *gin.Context) (types.SessionCookie,
if !usernameOk || !providerOK || !expiryOk || !totpPendingOk || !emailOk || !nameOk || !oauthGroupsOk {
log.Warn().Msg("Session cookie is invalid")
auth.DeleteSessionCookie(c)
return types.SessionCookie{}, nil
return config.SessionCookie{}, nil
}
// If the session cookie has expired, delete it
if time.Now().Unix() > expiry {
log.Warn().Msg("Session cookie expired")
auth.DeleteSessionCookie(c)
return types.SessionCookie{}, nil
return config.SessionCookie{}, nil
}
log.Debug().Str("username", username).Str("provider", provider).Int64("expiry", expiry).Bool("totpPending", totpPending).Str("name", name).Str("email", email).Str("oauthGroups", oauthGroups).Msg("Parsed cookie")
return types.SessionCookie{
return config.SessionCookie{
Username: username,
Name: name,
Email: email,
@@ -337,7 +343,7 @@ func (auth *AuthService) UserAuthConfigured() bool {
return len(auth.Config.Users) > 0 || auth.LDAP != nil
}
func (auth *AuthService) ResourceAllowed(c *gin.Context, context types.UserContext, labels types.Labels) bool {
func (auth *AuthService) ResourceAllowed(c *gin.Context, context config.UserContext, labels config.Labels) bool {
if context.OAuth {
log.Debug().Msg("Checking OAuth whitelist")
return utils.CheckFilter(labels.OAuth.Whitelist, context.Email)
@@ -347,7 +353,7 @@ func (auth *AuthService) ResourceAllowed(c *gin.Context, context types.UserConte
return utils.CheckFilter(labels.Users, context.Username)
}
func (auth *AuthService) OAuthGroup(c *gin.Context, context types.UserContext, labels types.Labels) bool {
func (auth *AuthService) OAuthGroup(c *gin.Context, context config.UserContext, labels config.Labels) bool {
if labels.OAuth.Groups == "" {
return true
}
@@ -374,7 +380,7 @@ func (auth *AuthService) OAuthGroup(c *gin.Context, context types.UserContext, l
return false
}
func (auth *AuthService) AuthEnabled(uri string, labels types.Labels) (bool, error) {
func (auth *AuthService) AuthEnabled(uri string, labels config.Labels) (bool, error) {
// If the label is empty, auth is enabled
if labels.Allowed == "" {
return true, nil
@@ -398,18 +404,18 @@ func (auth *AuthService) AuthEnabled(uri string, labels types.Labels) (bool, err
return true, nil
}
func (auth *AuthService) GetBasicAuth(c *gin.Context) *types.User {
func (auth *AuthService) GetBasicAuth(c *gin.Context) *config.User {
username, password, ok := c.Request.BasicAuth()
if !ok {
return nil
}
return &types.User{
return &config.User{
Username: username,
Password: password,
}
}
func (auth *AuthService) CheckIP(labels types.Labels, ip string) bool {
func (auth *AuthService) CheckIP(labels config.Labels, ip string) bool {
// Check if the IP is in block list
for _, blocked := range labels.IP.Block {
res, err := utils.FilterIP(blocked, ip)
@@ -446,7 +452,7 @@ func (auth *AuthService) CheckIP(labels types.Labels, ip string) bool {
return true
}
func (auth *AuthService) BypassedIP(labels types.Labels, ip string) bool {
func (auth *AuthService) BypassedIP(labels config.Labels, ip string) bool {
// For every IP in the bypass list, check if the IP matches
for _, bypassed := range labels.IP.Bypass {
res, err := utils.FilterIP(bypassed, ip)

View File

@@ -19,7 +19,6 @@ type GenericOAuthService struct {
Token *oauth2.Token
Verifier string
InsecureSkipVerify bool
ServiceName string
UserinfoURL string
}
@@ -36,7 +35,6 @@ func NewGenericOAuthService(config config.OAuthServiceConfig) *GenericOAuthServi
},
},
InsecureSkipVerify: config.InsecureSkipVerify,
ServiceName: config.Name,
UserinfoURL: config.UserinfoURL,
}
}
@@ -63,10 +61,6 @@ func (generic *GenericOAuthService) Init() error {
return nil
}
func (generic *GenericOAuthService) Name() string {
return generic.ServiceName
}
func (generic *GenericOAuthService) GenerateState() string {
b := make([]byte, 128)
rand.Read(b)

View File

@@ -54,10 +54,6 @@ func (github *GithubOAuthService) Init() error {
return nil
}
func (github *GithubOAuthService) Name() string {
return "github"
}
func (github *GithubOAuthService) GenerateState() string {
b := make([]byte, 128)
rand.Read(b)

View File

@@ -49,10 +49,6 @@ func (google *GoogleOAuthService) Init() error {
return nil
}
func (google *GoogleOAuthService) Name() string {
return "google"
}
func (oauth *GoogleOAuthService) GenerateState() string {
b := make([]byte, 128)
rand.Read(b)

View File

@@ -0,0 +1,76 @@
package service
import (
"errors"
"tinyauth/internal/config"
"github.com/rs/zerolog/log"
)
type OAuthService interface {
Init() error
GenerateState() string
GetAuthURL(state string) string
VerifyCode(code string) error
Userinfo() (config.Claims, error)
}
type OAuthBrokerService struct {
Services map[string]OAuthService
Configs map[string]config.OAuthServiceConfig
}
func NewOAuthBrokerService(configs map[string]config.OAuthServiceConfig) *OAuthBrokerService {
return &OAuthBrokerService{
Services: make(map[string]OAuthService),
Configs: configs,
}
}
func (broker *OAuthBrokerService) Init() error {
for name, cfg := range broker.Configs {
switch name {
case "github":
service := NewGithubOAuthService(cfg)
broker.Services[name] = service
case "google":
service := NewGoogleOAuthService(cfg)
broker.Services[name] = service
default:
service := NewGenericOAuthService(cfg)
broker.Services[name] = service
}
}
for name, service := range broker.Services {
err := service.Init()
if err != nil {
log.Error().Err(err).Msgf("Failed to initialize OAuth service: %s", name)
return err
}
log.Info().Msgf("Initialized OAuth service: %s", name)
}
return nil
}
func (broker *OAuthBrokerService) GetConfiguredServices() []string {
services := make([]string, 0, len(broker.Services))
for name := range broker.Services {
services = append(services, name)
}
return services
}
func (broker *OAuthBrokerService) GetService(name string) (OAuthService, bool) {
service, exists := broker.Services[name]
return service, exists
}
func (broker *OAuthBrokerService) GetUser(service string) (config.Claims, error) {
oauthService, exists := broker.Services[service]
if !exists {
return config.Claims{}, errors.New("oauth service not found")
}
return oauthService.Userinfo()
}