feat: create oauth broker service

This commit is contained in:
Stavros
2025-08-25 19:33:52 +03:00
parent 44f35af3bf
commit dbadb096b4
12 changed files with 184 additions and 683 deletions

View File

@@ -1,7 +1,5 @@
package config package config
import "time"
type Claims struct { type Claims struct {
Name string `json:"name"` Name string `json:"name"`
Email string `json:"email"` Email string `json:"email"`
@@ -100,7 +98,6 @@ type OAuthServiceConfig struct {
TokenURL string TokenURL string
UserinfoURL string UserinfoURL string
InsecureSkipVerify bool InsecureSkipVerify bool
Name string
} }
type User struct { type User struct {
@@ -114,8 +111,6 @@ type UserSearch struct {
Type string // local, ldap or unknown Type string // local, ldap or unknown
} }
type Users []User
type SessionCookie struct { type SessionCookie struct {
Username string Username string
Name string Name string
@@ -137,12 +132,6 @@ type UserContext struct {
TotpEnabled bool TotpEnabled bool
} }
type LoginAttempt struct {
FailedAttempts int
LastAttempt time.Time
LockedUntil time.Time
}
type UnauthorizedQuery struct { type UnauthorizedQuery struct {
Username string `url:"username"` Username string `url:"username"`
Resource string `url:"resource"` Resource string `url:"resource"`

View File

@@ -5,9 +5,8 @@ import (
"net/http" "net/http"
"strings" "strings"
"time" "time"
"tinyauth/internal/auth" "tinyauth/internal/config"
"tinyauth/internal/providers" "tinyauth/internal/service"
"tinyauth/internal/types"
"tinyauth/internal/utils" "tinyauth/internal/utils"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@@ -26,18 +25,18 @@ type OAuthControllerConfig struct {
} }
type OAuthController struct { type OAuthController struct {
Config OAuthControllerConfig Config OAuthControllerConfig
Router *gin.RouterGroup Router *gin.RouterGroup
Auth *auth.Auth Auth *service.AuthService
Providers *providers.Providers Broker *service.OAuthBrokerService
} }
func NewOAuthController(config OAuthControllerConfig, router *gin.RouterGroup, auth *auth.Auth, providers *providers.Providers) *OAuthController { func NewOAuthController(config OAuthControllerConfig, router *gin.RouterGroup, auth *service.AuthService, broker *service.OAuthBrokerService) *OAuthController {
return &OAuthController{ return &OAuthController{
Config: config, Config: config,
Router: router, Router: router,
Auth: auth, Auth: auth,
Providers: providers, Broker: broker,
} }
} }
@@ -59,9 +58,9 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
return return
} }
provider := controller.Providers.GetProvider(req.Provider) service, exists := controller.Broker.GetService(req.Provider)
if provider == nil { if !exists {
c.JSON(404, gin.H{ c.JSON(404, gin.H{
"status": 404, "status": 404,
"message": "Not Found", "message": "Not Found",
@@ -69,8 +68,8 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
return return
} }
state := provider.GenerateState() state := service.GenerateState()
authURL := provider.GetAuthURL(state) authURL := service.GetAuthURL(state)
c.SetCookie(controller.Config.CSRFCookieName, state, int(time.Hour.Seconds()), "/", "", controller.Config.SecureCookie, true) c.SetCookie(controller.Config.CSRFCookieName, state, int(time.Hour.Seconds()), "/", "", controller.Config.SecureCookie, true)
redirectURI := c.Query("redirect_uri") redirectURI := c.Query("redirect_uri")
@@ -109,20 +108,20 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
c.SetCookie(controller.Config.CSRFCookieName, "", -1, "/", "", controller.Config.SecureCookie, true) c.SetCookie(controller.Config.CSRFCookieName, "", -1, "/", "", controller.Config.SecureCookie, true)
code := c.Query("code") code := c.Query("code")
provider := controller.Providers.GetProvider(req.Provider) service, exists := controller.Broker.GetService(req.Provider)
if provider == nil { if !exists {
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.Config.AppURL)) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.Config.AppURL))
return return
} }
_, err = provider.ExchangeToken(code) err = service.VerifyCode(code)
if err != nil { if err != nil {
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.Config.AppURL)) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.Config.AppURL))
return return
} }
user, err := controller.Providers.GetUser(req.Provider) user, err := controller.Broker.GetUser(req.Provider)
if err != nil { if err != nil {
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.Config.AppURL)) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.Config.AppURL))
@@ -135,7 +134,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
} }
if !controller.Auth.EmailWhitelisted(user.Email) { if !controller.Auth.EmailWhitelisted(user.Email) {
queries, err := query.Values(types.UnauthorizedQuery{ queries, err := query.Values(config.UnauthorizedQuery{
Username: user.Email, Username: user.Email,
}) })
@@ -156,7 +155,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
name = fmt.Sprintf("%s (%s)", utils.Capitalize(strings.Split(user.Email, "@")[0]), strings.Split(user.Email, "@")[1]) name = fmt.Sprintf("%s (%s)", utils.Capitalize(strings.Split(user.Email, "@")[0]), strings.Split(user.Email, "@")[1])
} }
controller.Auth.CreateSessionCookie(c, &types.SessionCookie{ controller.Auth.CreateSessionCookie(c, &config.SessionCookie{
Username: user.Email, Username: user.Email,
Name: name, Name: name,
Email: user.Email, Email: user.Email,
@@ -171,7 +170,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
return return
} }
queries, err := query.Values(types.RedirectQuery{ queries, err := query.Values(config.RedirectQuery{
RedirectURI: redirectURI, RedirectURI: redirectURI,
}) })

View File

@@ -4,9 +4,8 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"strings" "strings"
"tinyauth/internal/auth" "tinyauth/internal/config"
"tinyauth/internal/docker" "tinyauth/internal/service"
"tinyauth/internal/types"
"tinyauth/internal/utils" "tinyauth/internal/utils"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@@ -24,11 +23,11 @@ type ProxyControllerConfig struct {
type ProxyController struct { type ProxyController struct {
Config ProxyControllerConfig Config ProxyControllerConfig
Router *gin.RouterGroup Router *gin.RouterGroup
Docker *docker.Docker Docker *service.DockerService
Auth *auth.Auth Auth *service.AuthService
} }
func NewProxyController(config ProxyControllerConfig, router *gin.RouterGroup, docker *docker.Docker, auth *auth.Auth) *ProxyController { func NewProxyController(config ProxyControllerConfig, router *gin.RouterGroup, docker *service.DockerService, auth *service.AuthService) *ProxyController {
return &ProxyController{ return &ProxyController{
Config: config, Config: config,
Router: router, Router: router,
@@ -109,7 +108,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
return return
} }
queries, err := query.Values(types.UnauthorizedQuery{ queries, err := query.Values(config.UnauthorizedQuery{
Resource: strings.Split(host, ".")[0], Resource: strings.Split(host, ".")[0],
IP: clientIP, IP: clientIP,
}) })
@@ -157,12 +156,12 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
return return
} }
var userContext types.UserContext var userContext config.UserContext
context, err := utils.GetContext(c) context, err := utils.GetContext(c)
if err != nil { if err != nil {
userContext = types.UserContext{ userContext = config.UserContext{
IsLoggedIn: false, IsLoggedIn: false,
} }
} else { } else {
@@ -185,7 +184,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
return return
} }
queries, err := query.Values(types.UnauthorizedQuery{ queries, err := query.Values(config.UnauthorizedQuery{
Resource: strings.Split(host, ".")[0], Resource: strings.Split(host, ".")[0],
}) })
@@ -216,7 +215,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
return return
} }
queries, err := query.Values(types.UnauthorizedQuery{ queries, err := query.Values(config.UnauthorizedQuery{
Resource: strings.Split(host, ".")[0], Resource: strings.Split(host, ".")[0],
GroupErr: true, GroupErr: true,
}) })
@@ -268,7 +267,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
return return
} }
queries, err := query.Values(types.RedirectQuery{ queries, err := query.Values(config.RedirectQuery{
RedirectURI: fmt.Sprintf("%s://%s%s", proto, host, uri), RedirectURI: fmt.Sprintf("%s://%s%s", proto, host, uri),
}) })

View File

@@ -3,8 +3,8 @@ package controller
import ( import (
"fmt" "fmt"
"strings" "strings"
"tinyauth/internal/auth" "tinyauth/internal/config"
"tinyauth/internal/types" "tinyauth/internal/service"
"tinyauth/internal/utils" "tinyauth/internal/utils"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@@ -27,10 +27,10 @@ type UserControllerConfig struct {
type UserController struct { type UserController struct {
Config UserControllerConfig Config UserControllerConfig
Router *gin.RouterGroup Router *gin.RouterGroup
Auth *auth.Auth Auth *service.AuthService
} }
func NewUserController(config UserControllerConfig, router *gin.RouterGroup, auth *auth.Auth) *UserController { func NewUserController(config UserControllerConfig, router *gin.RouterGroup, auth *service.AuthService) *UserController {
return &UserController{ return &UserController{
Config: config, Config: config,
Router: router, Router: router,
@@ -101,7 +101,7 @@ func (controller *UserController) loginHandler(c *gin.Context) {
user := controller.Auth.GetLocalUser(userSearch.Username) user := controller.Auth.GetLocalUser(userSearch.Username)
if user.TotpSecret != "" { if user.TotpSecret != "" {
controller.Auth.CreateSessionCookie(c, &types.SessionCookie{ controller.Auth.CreateSessionCookie(c, &config.SessionCookie{
Username: user.Username, Username: user.Username,
Name: utils.Capitalize(req.Username), Name: utils.Capitalize(req.Username),
Email: fmt.Sprintf("%s@%s", strings.ToLower(req.Username), controller.Config.Domain), Email: fmt.Sprintf("%s@%s", strings.ToLower(req.Username), controller.Config.Domain),
@@ -118,7 +118,7 @@ func (controller *UserController) loginHandler(c *gin.Context) {
} }
} }
controller.Auth.CreateSessionCookie(c, &types.SessionCookie{ controller.Auth.CreateSessionCookie(c, &config.SessionCookie{
Username: req.Username, Username: req.Username,
Name: utils.Capitalize(req.Username), Name: utils.Capitalize(req.Username),
Email: fmt.Sprintf("%s@%s", strings.ToLower(req.Username), controller.Config.Domain), Email: fmt.Sprintf("%s@%s", strings.ToLower(req.Username), controller.Config.Domain),
@@ -202,7 +202,7 @@ func (controller *UserController) totpHandler(c *gin.Context) {
controller.Auth.RecordLoginAttempt(rateIdentifier, true) controller.Auth.RecordLoginAttempt(rateIdentifier, true)
controller.Auth.CreateSessionCookie(c, &types.SessionCookie{ controller.Auth.CreateSessionCookie(c, &config.SessionCookie{
Username: user.Username, Username: user.Username,
Name: utils.Capitalize(user.Username), Name: utils.Capitalize(user.Username),
Email: fmt.Sprintf("%s@%s", strings.ToLower(user.Username), controller.Config.Domain), Email: fmt.Sprintf("%s@%s", strings.ToLower(user.Username), controller.Config.Domain),

View File

@@ -3,9 +3,8 @@ package middleware
import ( import (
"fmt" "fmt"
"strings" "strings"
"tinyauth/internal/auth" "tinyauth/internal/config"
"tinyauth/internal/providers" "tinyauth/internal/service"
"tinyauth/internal/types"
"tinyauth/internal/utils" "tinyauth/internal/utils"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@@ -16,16 +15,16 @@ type ContextMiddlewareConfig struct {
} }
type ContextMiddleware struct { type ContextMiddleware struct {
Config ContextMiddlewareConfig Config ContextMiddlewareConfig
Auth *auth.Auth Auth *service.AuthService
Providers *providers.Providers Broker *service.OAuthBrokerService
} }
func NewContextMiddleware(config ContextMiddlewareConfig, auth *auth.Auth, providers *providers.Providers) *ContextMiddleware { func NewContextMiddleware(config ContextMiddlewareConfig, auth *service.AuthService, broker *service.OAuthBrokerService) *ContextMiddleware {
return &ContextMiddleware{ return &ContextMiddleware{
Config: config, Config: config,
Auth: auth, Auth: auth,
Providers: providers, Broker: broker,
} }
} }
@@ -46,7 +45,7 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
} }
if cookie.TotpPending { if cookie.TotpPending {
c.Set("context", &types.UserContext{ c.Set("context", &config.UserContext{
Username: cookie.Username, Username: cookie.Username,
Name: cookie.Name, Name: cookie.Name,
Email: cookie.Email, Email: cookie.Email,
@@ -66,7 +65,7 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
goto basic goto basic
} }
c.Set("context", &types.UserContext{ c.Set("context", &config.UserContext{
Username: cookie.Username, Username: cookie.Username,
Name: cookie.Name, Name: cookie.Name,
Email: cookie.Email, Email: cookie.Email,
@@ -76,9 +75,9 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
c.Next() c.Next()
return return
default: default:
provider := m.Providers.GetProvider(cookie.Provider) _, exists := m.Broker.GetService(cookie.Provider)
if provider == nil { if !exists {
goto basic goto basic
} }
@@ -87,7 +86,7 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
goto basic goto basic
} }
c.Set("context", &types.UserContext{ c.Set("context", &config.UserContext{
Username: cookie.Username, Username: cookie.Username,
Name: cookie.Name, Name: cookie.Name,
Email: cookie.Email, Email: cookie.Email,
@@ -124,7 +123,7 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
case "local": case "local":
user := m.Auth.GetLocalUser(basic.Username) user := m.Auth.GetLocalUser(basic.Username)
c.Set("context", &types.UserContext{ c.Set("context", &config.UserContext{
Username: user.Username, Username: user.Username,
Name: utils.Capitalize(user.Username), Name: utils.Capitalize(user.Username),
Email: fmt.Sprintf("%s@%s", strings.ToLower(user.Username), m.Config.Domain), Email: fmt.Sprintf("%s@%s", strings.ToLower(user.Username), m.Config.Domain),
@@ -135,7 +134,7 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
c.Next() c.Next()
return return
case "ldap": case "ldap":
c.Set("context", &types.UserContext{ c.Set("context", &config.UserContext{
Username: basic.Username, Username: basic.Username,
Name: utils.Capitalize(basic.Username), Name: utils.Capitalize(basic.Username),
Email: fmt.Sprintf("%s@%s", strings.ToLower(basic.Username), m.Config.Domain), Email: fmt.Sprintf("%s@%s", strings.ToLower(basic.Username), m.Config.Domain),

View File

@@ -6,7 +6,7 @@ import (
"strings" "strings"
"sync" "sync"
"time" "time"
"tinyauth/internal/types" "tinyauth/internal/config"
"tinyauth/internal/utils" "tinyauth/internal/utils"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@@ -15,8 +15,14 @@ import (
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
) )
type LoginAttempt struct {
FailedAttempts int
LastAttempt time.Time
LockedUntil time.Time
}
type AuthServiceConfig struct { type AuthServiceConfig struct {
Users types.Users Users []config.User
OauthWhitelist string OauthWhitelist string
SessionExpiry int SessionExpiry int
CookieSecure bool CookieSecure bool
@@ -31,7 +37,7 @@ type AuthServiceConfig struct {
type AuthService struct { type AuthService struct {
Config AuthServiceConfig Config AuthServiceConfig
Docker *DockerService Docker *DockerService
LoginAttempts map[string]*types.LoginAttempt LoginAttempts map[string]*LoginAttempt
LoginMutex sync.RWMutex LoginMutex sync.RWMutex
Store *sessions.CookieStore Store *sessions.CookieStore
LDAP *LdapService LDAP *LdapService
@@ -41,7 +47,7 @@ func NewAuthService(config AuthServiceConfig, docker *DockerService, ldap *LdapS
return &AuthService{ return &AuthService{
Config: config, Config: config,
Docker: docker, Docker: docker,
LoginAttempts: make(map[string]*types.LoginAttempt), LoginAttempts: make(map[string]*LoginAttempt),
LDAP: ldap, LDAP: ldap,
} }
} }
@@ -75,13 +81,13 @@ func (auth *AuthService) GetSession(c *gin.Context) (*sessions.Session, error) {
return session, nil return session, nil
} }
func (auth *AuthService) SearchUser(username string) types.UserSearch { func (auth *AuthService) SearchUser(username string) config.UserSearch {
log.Debug().Str("username", username).Msg("Searching for user") log.Debug().Str("username", username).Msg("Searching for user")
// Check local users first // Check local users first
if auth.GetLocalUser(username).Username != "" { if auth.GetLocalUser(username).Username != "" {
log.Debug().Str("username", username).Msg("Found local user") log.Debug().Str("username", username).Msg("Found local user")
return types.UserSearch{ return config.UserSearch{
Username: username, Username: username,
Type: "local", Type: "local",
} }
@@ -93,20 +99,20 @@ func (auth *AuthService) SearchUser(username string) types.UserSearch {
userDN, err := auth.LDAP.Search(username) userDN, err := auth.LDAP.Search(username)
if err != nil { if err != nil {
log.Warn().Err(err).Str("username", username).Msg("Failed to find user in LDAP") log.Warn().Err(err).Str("username", username).Msg("Failed to find user in LDAP")
return types.UserSearch{} return config.UserSearch{}
} }
return types.UserSearch{ return config.UserSearch{
Username: userDN, Username: userDN,
Type: "ldap", Type: "ldap",
} }
} }
return types.UserSearch{ return config.UserSearch{
Type: "unknown", Type: "unknown",
} }
} }
func (auth *AuthService) VerifyUser(search types.UserSearch, password string) bool { func (auth *AuthService) VerifyUser(search config.UserSearch, password string) bool {
// Authenticate the user based on the type // Authenticate the user based on the type
switch search.Type { switch search.Type {
case "local": case "local":
@@ -144,7 +150,7 @@ func (auth *AuthService) VerifyUser(search types.UserSearch, password string) bo
return false return false
} }
func (auth *AuthService) GetLocalUser(username string) types.User { func (auth *AuthService) GetLocalUser(username string) config.User {
// Loop through users and return the user if the username matches // Loop through users and return the user if the username matches
log.Debug().Str("username", username).Msg("Searching for local user") log.Debug().Str("username", username).Msg("Searching for local user")
@@ -156,10 +162,10 @@ func (auth *AuthService) GetLocalUser(username string) types.User {
// If no user found, return an empty user // If no user found, return an empty user
log.Warn().Str("username", username).Msg("Local user not found") log.Warn().Str("username", username).Msg("Local user not found")
return types.User{} return config.User{}
} }
func (auth *AuthService) CheckPassword(user types.User, password string) bool { func (auth *AuthService) CheckPassword(user config.User, password string) bool {
return bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)) == nil return bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)) == nil
} }
@@ -201,7 +207,7 @@ func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) {
// Get current attempt record or create a new one // Get current attempt record or create a new one
attempt, exists := auth.LoginAttempts[identifier] attempt, exists := auth.LoginAttempts[identifier]
if !exists { if !exists {
attempt = &types.LoginAttempt{} attempt = &LoginAttempt{}
auth.LoginAttempts[identifier] = attempt auth.LoginAttempts[identifier] = attempt
} }
@@ -229,7 +235,7 @@ func (auth *AuthService) EmailWhitelisted(email string) bool {
return utils.CheckFilter(auth.Config.OauthWhitelist, email) return utils.CheckFilter(auth.Config.OauthWhitelist, email)
} }
func (auth *AuthService) CreateSessionCookie(c *gin.Context, data *types.SessionCookie) error { func (auth *AuthService) CreateSessionCookie(c *gin.Context, data *config.SessionCookie) error {
log.Debug().Msg("Creating session cookie") log.Debug().Msg("Creating session cookie")
session, err := auth.GetSession(c) session, err := auth.GetSession(c)
@@ -288,13 +294,13 @@ func (auth *AuthService) DeleteSessionCookie(c *gin.Context) error {
return nil return nil
} }
func (auth *AuthService) GetSessionCookie(c *gin.Context) (types.SessionCookie, error) { func (auth *AuthService) GetSessionCookie(c *gin.Context) (config.SessionCookie, error) {
log.Debug().Msg("Getting session cookie") log.Debug().Msg("Getting session cookie")
session, err := auth.GetSession(c) session, err := auth.GetSession(c)
if err != nil { if err != nil {
log.Error().Err(err).Msg("Failed to get session") log.Error().Err(err).Msg("Failed to get session")
return types.SessionCookie{}, err return config.SessionCookie{}, err
} }
log.Debug().Msg("Got session") log.Debug().Msg("Got session")
@@ -311,18 +317,18 @@ func (auth *AuthService) GetSessionCookie(c *gin.Context) (types.SessionCookie,
if !usernameOk || !providerOK || !expiryOk || !totpPendingOk || !emailOk || !nameOk || !oauthGroupsOk { if !usernameOk || !providerOK || !expiryOk || !totpPendingOk || !emailOk || !nameOk || !oauthGroupsOk {
log.Warn().Msg("Session cookie is invalid") log.Warn().Msg("Session cookie is invalid")
auth.DeleteSessionCookie(c) auth.DeleteSessionCookie(c)
return types.SessionCookie{}, nil return config.SessionCookie{}, nil
} }
// If the session cookie has expired, delete it // If the session cookie has expired, delete it
if time.Now().Unix() > expiry { if time.Now().Unix() > expiry {
log.Warn().Msg("Session cookie expired") log.Warn().Msg("Session cookie expired")
auth.DeleteSessionCookie(c) auth.DeleteSessionCookie(c)
return types.SessionCookie{}, nil return config.SessionCookie{}, nil
} }
log.Debug().Str("username", username).Str("provider", provider).Int64("expiry", expiry).Bool("totpPending", totpPending).Str("name", name).Str("email", email).Str("oauthGroups", oauthGroups).Msg("Parsed cookie") log.Debug().Str("username", username).Str("provider", provider).Int64("expiry", expiry).Bool("totpPending", totpPending).Str("name", name).Str("email", email).Str("oauthGroups", oauthGroups).Msg("Parsed cookie")
return types.SessionCookie{ return config.SessionCookie{
Username: username, Username: username,
Name: name, Name: name,
Email: email, Email: email,
@@ -337,7 +343,7 @@ func (auth *AuthService) UserAuthConfigured() bool {
return len(auth.Config.Users) > 0 || auth.LDAP != nil return len(auth.Config.Users) > 0 || auth.LDAP != nil
} }
func (auth *AuthService) ResourceAllowed(c *gin.Context, context types.UserContext, labels types.Labels) bool { func (auth *AuthService) ResourceAllowed(c *gin.Context, context config.UserContext, labels config.Labels) bool {
if context.OAuth { if context.OAuth {
log.Debug().Msg("Checking OAuth whitelist") log.Debug().Msg("Checking OAuth whitelist")
return utils.CheckFilter(labels.OAuth.Whitelist, context.Email) return utils.CheckFilter(labels.OAuth.Whitelist, context.Email)
@@ -347,7 +353,7 @@ func (auth *AuthService) ResourceAllowed(c *gin.Context, context types.UserConte
return utils.CheckFilter(labels.Users, context.Username) return utils.CheckFilter(labels.Users, context.Username)
} }
func (auth *AuthService) OAuthGroup(c *gin.Context, context types.UserContext, labels types.Labels) bool { func (auth *AuthService) OAuthGroup(c *gin.Context, context config.UserContext, labels config.Labels) bool {
if labels.OAuth.Groups == "" { if labels.OAuth.Groups == "" {
return true return true
} }
@@ -374,7 +380,7 @@ func (auth *AuthService) OAuthGroup(c *gin.Context, context types.UserContext, l
return false return false
} }
func (auth *AuthService) AuthEnabled(uri string, labels types.Labels) (bool, error) { func (auth *AuthService) AuthEnabled(uri string, labels config.Labels) (bool, error) {
// If the label is empty, auth is enabled // If the label is empty, auth is enabled
if labels.Allowed == "" { if labels.Allowed == "" {
return true, nil return true, nil
@@ -398,18 +404,18 @@ func (auth *AuthService) AuthEnabled(uri string, labels types.Labels) (bool, err
return true, nil return true, nil
} }
func (auth *AuthService) GetBasicAuth(c *gin.Context) *types.User { func (auth *AuthService) GetBasicAuth(c *gin.Context) *config.User {
username, password, ok := c.Request.BasicAuth() username, password, ok := c.Request.BasicAuth()
if !ok { if !ok {
return nil return nil
} }
return &types.User{ return &config.User{
Username: username, Username: username,
Password: password, Password: password,
} }
} }
func (auth *AuthService) CheckIP(labels types.Labels, ip string) bool { func (auth *AuthService) CheckIP(labels config.Labels, ip string) bool {
// Check if the IP is in block list // Check if the IP is in block list
for _, blocked := range labels.IP.Block { for _, blocked := range labels.IP.Block {
res, err := utils.FilterIP(blocked, ip) res, err := utils.FilterIP(blocked, ip)
@@ -446,7 +452,7 @@ func (auth *AuthService) CheckIP(labels types.Labels, ip string) bool {
return true return true
} }
func (auth *AuthService) BypassedIP(labels types.Labels, ip string) bool { func (auth *AuthService) BypassedIP(labels config.Labels, ip string) bool {
// For every IP in the bypass list, check if the IP matches // For every IP in the bypass list, check if the IP matches
for _, bypassed := range labels.IP.Bypass { for _, bypassed := range labels.IP.Bypass {
res, err := utils.FilterIP(bypassed, ip) res, err := utils.FilterIP(bypassed, ip)

View File

@@ -19,7 +19,6 @@ type GenericOAuthService struct {
Token *oauth2.Token Token *oauth2.Token
Verifier string Verifier string
InsecureSkipVerify bool InsecureSkipVerify bool
ServiceName string
UserinfoURL string UserinfoURL string
} }
@@ -36,7 +35,6 @@ func NewGenericOAuthService(config config.OAuthServiceConfig) *GenericOAuthServi
}, },
}, },
InsecureSkipVerify: config.InsecureSkipVerify, InsecureSkipVerify: config.InsecureSkipVerify,
ServiceName: config.Name,
UserinfoURL: config.UserinfoURL, UserinfoURL: config.UserinfoURL,
} }
} }
@@ -63,10 +61,6 @@ func (generic *GenericOAuthService) Init() error {
return nil return nil
} }
func (generic *GenericOAuthService) Name() string {
return generic.ServiceName
}
func (generic *GenericOAuthService) GenerateState() string { func (generic *GenericOAuthService) GenerateState() string {
b := make([]byte, 128) b := make([]byte, 128)
rand.Read(b) rand.Read(b)

View File

@@ -54,10 +54,6 @@ func (github *GithubOAuthService) Init() error {
return nil return nil
} }
func (github *GithubOAuthService) Name() string {
return "github"
}
func (github *GithubOAuthService) GenerateState() string { func (github *GithubOAuthService) GenerateState() string {
b := make([]byte, 128) b := make([]byte, 128)
rand.Read(b) rand.Read(b)

View File

@@ -49,10 +49,6 @@ func (google *GoogleOAuthService) Init() error {
return nil return nil
} }
func (google *GoogleOAuthService) Name() string {
return "google"
}
func (oauth *GoogleOAuthService) GenerateState() string { func (oauth *GoogleOAuthService) GenerateState() string {
b := make([]byte, 128) b := make([]byte, 128)
rand.Read(b) rand.Read(b)

View File

@@ -0,0 +1,76 @@
package service
import (
"errors"
"tinyauth/internal/config"
"github.com/rs/zerolog/log"
)
type OAuthService interface {
Init() error
GenerateState() string
GetAuthURL(state string) string
VerifyCode(code string) error
Userinfo() (config.Claims, error)
}
type OAuthBrokerService struct {
Services map[string]OAuthService
Configs map[string]config.OAuthServiceConfig
}
func NewOAuthBrokerService(configs map[string]config.OAuthServiceConfig) *OAuthBrokerService {
return &OAuthBrokerService{
Services: make(map[string]OAuthService),
Configs: configs,
}
}
func (broker *OAuthBrokerService) Init() error {
for name, cfg := range broker.Configs {
switch name {
case "github":
service := NewGithubOAuthService(cfg)
broker.Services[name] = service
case "google":
service := NewGoogleOAuthService(cfg)
broker.Services[name] = service
default:
service := NewGenericOAuthService(cfg)
broker.Services[name] = service
}
}
for name, service := range broker.Services {
err := service.Init()
if err != nil {
log.Error().Err(err).Msgf("Failed to initialize OAuth service: %s", name)
return err
}
log.Info().Msgf("Initialized OAuth service: %s", name)
}
return nil
}
func (broker *OAuthBrokerService) GetConfiguredServices() []string {
services := make([]string, 0, len(broker.Services))
for name := range broker.Services {
services = append(services, name)
}
return services
}
func (broker *OAuthBrokerService) GetService(name string) (OAuthService, bool) {
service, exists := broker.Services[name]
return service, exists
}
func (broker *OAuthBrokerService) GetUser(service string) (config.Claims, error) {
oauthService, exists := broker.Services[service]
if !exists {
return config.Claims{}, errors.New("oauth service not found")
}
return oauthService.Userinfo()
}

View File

@@ -11,7 +11,7 @@ import (
"os" "os"
"regexp" "regexp"
"strings" "strings"
"tinyauth/internal/types" "tinyauth/internal/config"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/traefik/paerser/parser" "github.com/traefik/paerser/parser"
@@ -22,21 +22,21 @@ import (
) )
// Parses a list of comma separated users in a struct // Parses a list of comma separated users in a struct
func ParseUsers(users string) (types.Users, error) { func ParseUsers(users string) ([]config.User, error) {
log.Debug().Msg("Parsing users") log.Debug().Msg("Parsing users")
var usersParsed types.Users var usersParsed []config.User
userList := strings.Split(users, ",") userList := strings.Split(users, ",")
if len(userList) == 0 { if len(userList) == 0 {
return types.Users{}, errors.New("invalid user format") return []config.User{}, errors.New("invalid user format")
} }
for _, user := range userList { for _, user := range userList {
parsed, err := ParseUser(user) parsed, err := ParseUser(user)
if err != nil { if err != nil {
return types.Users{}, err return []config.User{}, err
} }
usersParsed = append(usersParsed, parsed) usersParsed = append(usersParsed, parsed)
} }
@@ -107,11 +107,11 @@ func GetSecret(conf string, file string) string {
} }
// Get the users from the config or file // Get the users from the config or file
func GetUsers(conf string, file string) (types.Users, error) { func GetUsers(conf string, file string) ([]config.User, error) {
var users string var users string
if conf == "" && file == "" { if conf == "" && file == "" {
return types.Users{}, nil return []config.User{}, nil
} }
if conf != "" { if conf != "" {
@@ -152,23 +152,18 @@ func ParseHeaders(headers []string) map[string]string {
} }
// Get labels parses a map of labels into a struct with only the needed labels // Get labels parses a map of labels into a struct with only the needed labels
func GetLabels(labels map[string]string) (types.Labels, error) { func GetLabels(labels map[string]string) (config.Labels, error) {
var labelsParsed types.Labels var labelsParsed config.Labels
err := parser.Decode(labels, &labelsParsed, "tinyauth", "tinyauth.users", "tinyauth.allowed", "tinyauth.headers", "tinyauth.domain", "tinyauth.basic", "tinyauth.oauth", "tinyauth.ip") err := parser.Decode(labels, &labelsParsed, "tinyauth", "tinyauth.users", "tinyauth.allowed", "tinyauth.headers", "tinyauth.domain", "tinyauth.basic", "tinyauth.oauth", "tinyauth.ip")
if err != nil { if err != nil {
log.Error().Err(err).Msg("Error parsing labels") log.Error().Err(err).Msg("Error parsing labels")
return types.Labels{}, err return config.Labels{}, err
} }
return labelsParsed, nil return labelsParsed, nil
} }
// Check if any of the OAuth providers are configured based on the client id and secret
func OAuthConfigured(config types.Config) bool {
return (config.GithubClientId != "" && config.GithubClientSecret != "") || (config.GoogleClientId != "" && config.GoogleClientSecret != "") || (config.GenericClientId != "" && config.GenericClientSecret != "")
}
// Filter helper function // Filter helper function
func Filter[T any](slice []T, test func(T) bool) (res []T) { func Filter[T any](slice []T, test func(T) bool) (res []T) {
for _, value := range slice { for _, value := range slice {
@@ -180,7 +175,7 @@ func Filter[T any](slice []T, test func(T) bool) (res []T) {
} }
// Parse user // Parse user
func ParseUser(user string) (types.User, error) { func ParseUser(user string) (config.User, error) {
if strings.Contains(user, "$$") { if strings.Contains(user, "$$") {
user = strings.ReplaceAll(user, "$$", "$") user = strings.ReplaceAll(user, "$$", "$")
} }
@@ -188,23 +183,23 @@ func ParseUser(user string) (types.User, error) {
userSplit := strings.Split(user, ":") userSplit := strings.Split(user, ":")
if len(userSplit) < 2 || len(userSplit) > 3 { if len(userSplit) < 2 || len(userSplit) > 3 {
return types.User{}, errors.New("invalid user format") return config.User{}, errors.New("invalid user format")
} }
for _, userPart := range userSplit { for _, userPart := range userSplit {
if strings.TrimSpace(userPart) == "" { if strings.TrimSpace(userPart) == "" {
return types.User{}, errors.New("invalid user format") return config.User{}, errors.New("invalid user format")
} }
} }
if len(userSplit) == 2 { if len(userSplit) == 2 {
return types.User{ return config.User{
Username: strings.TrimSpace(userSplit[0]), Username: strings.TrimSpace(userSplit[0]),
Password: strings.TrimSpace(userSplit[1]), Password: strings.TrimSpace(userSplit[1]),
}, nil }, nil
} }
return types.User{ return config.User{
Username: strings.TrimSpace(userSplit[0]), Username: strings.TrimSpace(userSplit[0]),
Password: strings.TrimSpace(userSplit[1]), Password: strings.TrimSpace(userSplit[1]),
TotpSecret: strings.TrimSpace(userSplit[2]), TotpSecret: strings.TrimSpace(userSplit[2]),
@@ -350,17 +345,17 @@ func CoalesceToString(value any) string {
} }
} }
func GetContext(c *gin.Context) (types.UserContext, error) { func GetContext(c *gin.Context) (config.UserContext, error) {
userContextValue, exists := c.Get("context") userContextValue, exists := c.Get("context")
if !exists { if !exists {
return types.UserContext{}, errors.New("no user context in request") return config.UserContext{}, errors.New("no user context in request")
} }
userContext, ok := userContextValue.(*types.UserContext) userContext, ok := userContextValue.(*config.UserContext)
if !ok { if !ok {
return types.UserContext{}, errors.New("invalid user context in request") return config.UserContext{}, errors.New("invalid user context in request")
} }
return *userContext, nil return *userContext, nil

View File

@@ -1,548 +0,0 @@
package utils_test
import (
"fmt"
"os"
"reflect"
"testing"
"tinyauth/internal/types"
"tinyauth/internal/utils"
)
func TestParseUsers(t *testing.T) {
t.Log("Testing parse users with a valid string")
users := "user1:pass1,user2:pass2"
expected := types.Users{
{
Username: "user1",
Password: "pass1",
},
{
Username: "user2",
Password: "pass2",
},
}
result, err := utils.ParseUsers(users)
if err != nil {
t.Fatalf("Error parsing users: %v", err)
}
if !reflect.DeepEqual(expected, result) {
t.Fatalf("Expected %v, got %v", expected, result)
}
}
func TestGetUpperDomain(t *testing.T) {
t.Log("Testing get upper domain with a valid url")
url := "https://sub1.sub2.domain.com:8080"
expected := "sub2.domain.com"
result, err := utils.GetUpperDomain(url)
if err != nil {
t.Fatalf("Error getting root url: %v", err)
}
if expected != result {
t.Fatalf("Expected %v, got %v", expected, result)
}
}
func TestReadFile(t *testing.T) {
t.Log("Creating a test file")
err := os.WriteFile("/tmp/test.txt", []byte("test"), 0644)
if err != nil {
t.Fatalf("Error creating test file: %v", err)
}
t.Log("Testing read file with a valid file")
data, err := utils.ReadFile("/tmp/test.txt")
if err != nil {
t.Fatalf("Error reading file: %v", err)
}
if data != "test" {
t.Fatalf("Expected test, got %v", data)
}
t.Log("Cleaning up test file")
err = os.Remove("/tmp/test.txt")
if err != nil {
t.Fatalf("Error cleaning up test file: %v", err)
}
}
func TestParseFileToLine(t *testing.T) {
t.Log("Testing parse file to line with a valid string")
content := "\nuser1:pass1\nuser2:pass2\n"
expected := "user1:pass1,user2:pass2"
result := utils.ParseFileToLine(content)
if expected != result {
t.Fatalf("Expected %v, got %v", expected, result)
}
}
func TestGetSecret(t *testing.T) {
t.Log("Testing get secret with an empty config and file")
conf := ""
file := "/tmp/test.txt"
expected := "test"
err := os.WriteFile(file, []byte(fmt.Sprintf("\n\n \n\n\n %s \n\n \n ", expected)), 0644)
if err != nil {
t.Fatalf("Error creating test file: %v", err)
}
result := utils.GetSecret(conf, file)
if result != expected {
t.Fatalf("Expected %v, got %v", expected, result)
}
t.Log("Testing get secret with an empty file and a valid config")
result = utils.GetSecret(expected, "")
if result != expected {
t.Fatalf("Expected %v, got %v", expected, result)
}
t.Log("Testing get secret with both a valid config and file")
result = utils.GetSecret(expected, file)
if result != expected {
t.Fatalf("Expected %v, got %v", expected, result)
}
t.Log("Cleaning up test file")
err = os.Remove(file)
if err != nil {
t.Fatalf("Error cleaning up test file: %v", err)
}
}
func TestGetUsers(t *testing.T) {
t.Log("Testing get users with a config and no file")
conf := "user1:pass1,user2:pass2"
file := ""
expected := types.Users{
{
Username: "user1",
Password: "pass1",
},
{
Username: "user2",
Password: "pass2",
},
}
result, err := utils.GetUsers(conf, file)
if err != nil {
t.Fatalf("Error getting users: %v", err)
}
if !reflect.DeepEqual(expected, result) {
t.Fatalf("Expected %v, got %v", expected, result)
}
t.Log("Testing get users with a file and no config")
conf = ""
file = "/tmp/test.txt"
expected = types.Users{
{
Username: "user1",
Password: "pass1",
},
{
Username: "user2",
Password: "pass2",
},
}
err = os.WriteFile(file, []byte("user1:pass1\nuser2:pass2"), 0644)
if err != nil {
t.Fatalf("Error creating test file: %v", err)
}
result, err = utils.GetUsers(conf, file)
if err != nil {
t.Fatalf("Error getting users: %v", err)
}
if !reflect.DeepEqual(expected, result) {
t.Fatalf("Expected %v, got %v", expected, result)
}
t.Log("Testing get users with both a config and file")
conf = "user3:pass3"
expected = types.Users{
{
Username: "user3",
Password: "pass3",
},
{
Username: "user1",
Password: "pass1",
},
{
Username: "user2",
Password: "pass2",
},
}
result, err = utils.GetUsers(conf, file)
if err != nil {
t.Fatalf("Error getting users: %v", err)
}
if !reflect.DeepEqual(expected, result) {
t.Fatalf("Expected %v, got %v", expected, result)
}
t.Log("Cleaning up test file")
err = os.Remove(file)
if err != nil {
t.Fatalf("Error cleaning up test file: %v", err)
}
}
func TestGetLabels(t *testing.T) {
t.Log("Testing get labels with a valid map")
labels := map[string]string{
"tinyauth.users": "user1,user2",
"tinyauth.oauth.whitelist": "/regex/",
"tinyauth.allowed": "random",
"tinyauth.headers": "X-Header=value",
"tinyauth.oauth.groups": "group1,group2",
}
expected := types.Labels{
Users: "user1,user2",
Allowed: "random",
Headers: []string{"X-Header=value"},
OAuth: types.OAuthLabels{
Whitelist: "/regex/",
Groups: "group1,group2",
},
}
result, err := utils.GetLabels(labels)
if err != nil {
t.Fatalf("Error getting labels: %v", err)
}
if !reflect.DeepEqual(expected, result) {
t.Fatalf("Expected %v, got %v", expected, result)
}
}
func TestParseUser(t *testing.T) {
t.Log("Testing parse user with a valid user")
user := "user:pass:secret"
expected := types.User{
Username: "user",
Password: "pass",
TotpSecret: "secret",
}
result, err := utils.ParseUser(user)
if err != nil {
t.Fatalf("Error parsing user: %v", err)
}
if !reflect.DeepEqual(expected, result) {
t.Fatalf("Expected %v, got %v", expected, result)
}
t.Log("Testing parse user with an escaped user")
user = "user:p$$ass$$:secret"
expected = types.User{
Username: "user",
Password: "p$ass$",
TotpSecret: "secret",
}
result, err = utils.ParseUser(user)
if err != nil {
t.Fatalf("Error parsing user: %v", err)
}
if !reflect.DeepEqual(expected, result) {
t.Fatalf("Expected %v, got %v", expected, result)
}
t.Log("Testing parse user with an invalid user")
user = "user::pass"
_, err = utils.ParseUser(user)
if err == nil {
t.Fatalf("Expected error parsing user")
}
}
func TestCheckFilter(t *testing.T) {
t.Log("Testing check filter with a comma separated list")
filter := "user1,user2,user3"
str := "user1"
expected := true
result := utils.CheckFilter(filter, str)
if result != expected {
t.Fatalf("Expected %v, got %v", expected, result)
}
t.Log("Testing check filter with a regex filter")
filter = "/^user[0-9]+$/"
str = "user1"
expected = true
result = utils.CheckFilter(filter, str)
if result != expected {
t.Fatalf("Expected %v, got %v", expected, result)
}
t.Log("Testing check filter with an empty filter")
filter = ""
str = "user1"
expected = true
result = utils.CheckFilter(filter, str)
if result != expected {
t.Fatalf("Expected %v, got %v", expected, result)
}
t.Log("Testing check filter with an invalid regex filter")
filter = "/^user[0-9+$/"
str = "user1"
expected = false
result = utils.CheckFilter(filter, str)
if result != expected {
t.Fatalf("Expected %v, got %v", expected, result)
}
t.Log("Testing check filter with a non matching list")
filter = "user1,user2,user3"
str = "user4"
expected = false
result = utils.CheckFilter(filter, str)
if result != expected {
t.Fatalf("Expected %v, got %v", expected, result)
}
}
func TestSanitizeHeader(t *testing.T) {
t.Log("Testing sanitize header with a valid string")
str := "X-Header=value"
expected := "X-Header=value"
result := utils.SanitizeHeader(str)
if result != expected {
t.Fatalf("Expected %v, got %v", expected, result)
}
t.Log("Testing sanitize header with an invalid string")
str = "X-Header=val\nue"
expected = "X-Header=value"
result = utils.SanitizeHeader(str)
if result != expected {
t.Fatalf("Expected %v, got %v", expected, result)
}
}
func TestParseHeaders(t *testing.T) {
t.Log("Testing parse headers with a valid string")
headers := []string{"X-Hea\x00der1=value1", "X-Header2=value\n2"}
expected := map[string]string{
"X-Header1": "value1",
"X-Header2": "value2",
}
result := utils.ParseHeaders(headers)
if !reflect.DeepEqual(expected, result) {
t.Fatalf("Expected %v, got %v", expected, result)
}
t.Log("Testing parse headers with an invalid string")
headers = []string{"X-Header1=", "X-Header2", "=value", "X-Header3=value3"}
expected = map[string]string{"X-Header3": "value3"}
result = utils.ParseHeaders(headers)
if !reflect.DeepEqual(expected, result) {
t.Fatalf("Expected %v, got %v", expected, result)
}
}
func TestParseSecretFile(t *testing.T) {
t.Log("Testing parse secret file with a valid file")
content := "\n\n \n\n\n secret \n\n \n "
expected := "secret"
result := utils.ParseSecretFile(content)
if result != expected {
t.Fatalf("Expected %v, got %v", expected, result)
}
}
func TestFilterIP(t *testing.T) {
t.Log("Testing filter IP with an IP and a valid CIDR")
ip := "10.10.10.10"
filter := "10.10.10.0/24"
expected := true
result, err := utils.FilterIP(filter, ip)
if err != nil {
t.Fatalf("Error filtering IP: %v", err)
}
if result != expected {
t.Fatalf("Expected %v, got %v", expected, result)
}
t.Log("Testing filter IP with an IP and a valid IP")
filter = "10.10.10.10"
expected = true
result, err = utils.FilterIP(filter, ip)
if err != nil {
t.Fatalf("Error filtering IP: %v", err)
}
if result != expected {
t.Fatalf("Expected %v, got %v", expected, result)
}
t.Log("Testing filter IP with an IP and an non matching CIDR")
filter = "10.10.15.0/24"
expected = false
result, err = utils.FilterIP(filter, ip)
if err != nil {
t.Fatalf("Error filtering IP: %v", err)
}
if result != expected {
t.Fatalf("Expected %v, got %v", expected, result)
}
t.Log("Testing filter IP with a non matching IP and a valid CIDR")
filter = "10.10.10.11"
expected = false
result, err = utils.FilterIP(filter, ip)
if err != nil {
t.Fatalf("Error filtering IP: %v", err)
}
if result != expected {
t.Fatalf("Expected %v, got %v", expected, result)
}
t.Log("Testing filter IP with an IP and an invalid CIDR")
filter = "10.../83"
_, err = utils.FilterIP(filter, ip)
if err == nil {
t.Fatalf("Expected error filtering IP")
}
}
func TestDeriveKey(t *testing.T) {
t.Log("Testing the derive key function")
master := "master"
info := "info"
expected := "gdrdU/fXzclYjiSXRexEatVgV13qQmKl"
result, err := utils.DeriveKey(master, info)
if err != nil {
t.Fatalf("Error deriving key: %v", err)
}
if result != expected {
t.Fatalf("Expected %v, got %v", expected, result)
}
}
func TestCoalesceToString(t *testing.T) {
t.Log("Testing coalesce to string with a string")
value := any("test")
expected := "test"
result := utils.CoalesceToString(value)
if result != expected {
t.Fatalf("Expected %v, got %v", expected, result)
}
t.Log("Testing coalesce to string with a slice of strings")
value = []any{any("test1"), any("test2"), any(123)}
expected = "test1,test2"
result = utils.CoalesceToString(value)
if result != expected {
t.Fatalf("Expected %v, got %v", expected, result)
}
t.Log("Testing coalesce to string with an unsupported type")
value = 12345
expected = ""
result = utils.CoalesceToString(value)
if result != expected {
t.Fatalf("Expected %v, got %v", expected, result)
}
}