feat: create oauth broker service

This commit is contained in:
Stavros
2025-08-25 19:33:52 +03:00
parent 44f35af3bf
commit dbadb096b4
12 changed files with 184 additions and 683 deletions

View File

@@ -1,7 +1,5 @@
package config
import "time"
type Claims struct {
Name string `json:"name"`
Email string `json:"email"`
@@ -100,7 +98,6 @@ type OAuthServiceConfig struct {
TokenURL string
UserinfoURL string
InsecureSkipVerify bool
Name string
}
type User struct {
@@ -114,8 +111,6 @@ type UserSearch struct {
Type string // local, ldap or unknown
}
type Users []User
type SessionCookie struct {
Username string
Name string
@@ -137,12 +132,6 @@ type UserContext struct {
TotpEnabled bool
}
type LoginAttempt struct {
FailedAttempts int
LastAttempt time.Time
LockedUntil time.Time
}
type UnauthorizedQuery struct {
Username string `url:"username"`
Resource string `url:"resource"`

View File

@@ -5,9 +5,8 @@ import (
"net/http"
"strings"
"time"
"tinyauth/internal/auth"
"tinyauth/internal/providers"
"tinyauth/internal/types"
"tinyauth/internal/config"
"tinyauth/internal/service"
"tinyauth/internal/utils"
"github.com/gin-gonic/gin"
@@ -26,18 +25,18 @@ type OAuthControllerConfig struct {
}
type OAuthController struct {
Config OAuthControllerConfig
Router *gin.RouterGroup
Auth *auth.Auth
Providers *providers.Providers
Config OAuthControllerConfig
Router *gin.RouterGroup
Auth *service.AuthService
Broker *service.OAuthBrokerService
}
func NewOAuthController(config OAuthControllerConfig, router *gin.RouterGroup, auth *auth.Auth, providers *providers.Providers) *OAuthController {
func NewOAuthController(config OAuthControllerConfig, router *gin.RouterGroup, auth *service.AuthService, broker *service.OAuthBrokerService) *OAuthController {
return &OAuthController{
Config: config,
Router: router,
Auth: auth,
Providers: providers,
Config: config,
Router: router,
Auth: auth,
Broker: broker,
}
}
@@ -59,9 +58,9 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
return
}
provider := controller.Providers.GetProvider(req.Provider)
service, exists := controller.Broker.GetService(req.Provider)
if provider == nil {
if !exists {
c.JSON(404, gin.H{
"status": 404,
"message": "Not Found",
@@ -69,8 +68,8 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
return
}
state := provider.GenerateState()
authURL := provider.GetAuthURL(state)
state := service.GenerateState()
authURL := service.GetAuthURL(state)
c.SetCookie(controller.Config.CSRFCookieName, state, int(time.Hour.Seconds()), "/", "", controller.Config.SecureCookie, true)
redirectURI := c.Query("redirect_uri")
@@ -109,20 +108,20 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
c.SetCookie(controller.Config.CSRFCookieName, "", -1, "/", "", controller.Config.SecureCookie, true)
code := c.Query("code")
provider := controller.Providers.GetProvider(req.Provider)
service, exists := controller.Broker.GetService(req.Provider)
if provider == nil {
if !exists {
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.Config.AppURL))
return
}
_, err = provider.ExchangeToken(code)
err = service.VerifyCode(code)
if err != nil {
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.Config.AppURL))
return
}
user, err := controller.Providers.GetUser(req.Provider)
user, err := controller.Broker.GetUser(req.Provider)
if err != nil {
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.Config.AppURL))
@@ -135,7 +134,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
}
if !controller.Auth.EmailWhitelisted(user.Email) {
queries, err := query.Values(types.UnauthorizedQuery{
queries, err := query.Values(config.UnauthorizedQuery{
Username: user.Email,
})
@@ -156,7 +155,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
name = fmt.Sprintf("%s (%s)", utils.Capitalize(strings.Split(user.Email, "@")[0]), strings.Split(user.Email, "@")[1])
}
controller.Auth.CreateSessionCookie(c, &types.SessionCookie{
controller.Auth.CreateSessionCookie(c, &config.SessionCookie{
Username: user.Email,
Name: name,
Email: user.Email,
@@ -171,7 +170,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
return
}
queries, err := query.Values(types.RedirectQuery{
queries, err := query.Values(config.RedirectQuery{
RedirectURI: redirectURI,
})

View File

@@ -4,9 +4,8 @@ import (
"fmt"
"net/http"
"strings"
"tinyauth/internal/auth"
"tinyauth/internal/docker"
"tinyauth/internal/types"
"tinyauth/internal/config"
"tinyauth/internal/service"
"tinyauth/internal/utils"
"github.com/gin-gonic/gin"
@@ -24,11 +23,11 @@ type ProxyControllerConfig struct {
type ProxyController struct {
Config ProxyControllerConfig
Router *gin.RouterGroup
Docker *docker.Docker
Auth *auth.Auth
Docker *service.DockerService
Auth *service.AuthService
}
func NewProxyController(config ProxyControllerConfig, router *gin.RouterGroup, docker *docker.Docker, auth *auth.Auth) *ProxyController {
func NewProxyController(config ProxyControllerConfig, router *gin.RouterGroup, docker *service.DockerService, auth *service.AuthService) *ProxyController {
return &ProxyController{
Config: config,
Router: router,
@@ -109,7 +108,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
return
}
queries, err := query.Values(types.UnauthorizedQuery{
queries, err := query.Values(config.UnauthorizedQuery{
Resource: strings.Split(host, ".")[0],
IP: clientIP,
})
@@ -157,12 +156,12 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
return
}
var userContext types.UserContext
var userContext config.UserContext
context, err := utils.GetContext(c)
if err != nil {
userContext = types.UserContext{
userContext = config.UserContext{
IsLoggedIn: false,
}
} else {
@@ -185,7 +184,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
return
}
queries, err := query.Values(types.UnauthorizedQuery{
queries, err := query.Values(config.UnauthorizedQuery{
Resource: strings.Split(host, ".")[0],
})
@@ -216,7 +215,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
return
}
queries, err := query.Values(types.UnauthorizedQuery{
queries, err := query.Values(config.UnauthorizedQuery{
Resource: strings.Split(host, ".")[0],
GroupErr: true,
})
@@ -268,7 +267,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
return
}
queries, err := query.Values(types.RedirectQuery{
queries, err := query.Values(config.RedirectQuery{
RedirectURI: fmt.Sprintf("%s://%s%s", proto, host, uri),
})

View File

@@ -3,8 +3,8 @@ package controller
import (
"fmt"
"strings"
"tinyauth/internal/auth"
"tinyauth/internal/types"
"tinyauth/internal/config"
"tinyauth/internal/service"
"tinyauth/internal/utils"
"github.com/gin-gonic/gin"
@@ -27,10 +27,10 @@ type UserControllerConfig struct {
type UserController struct {
Config UserControllerConfig
Router *gin.RouterGroup
Auth *auth.Auth
Auth *service.AuthService
}
func NewUserController(config UserControllerConfig, router *gin.RouterGroup, auth *auth.Auth) *UserController {
func NewUserController(config UserControllerConfig, router *gin.RouterGroup, auth *service.AuthService) *UserController {
return &UserController{
Config: config,
Router: router,
@@ -101,7 +101,7 @@ func (controller *UserController) loginHandler(c *gin.Context) {
user := controller.Auth.GetLocalUser(userSearch.Username)
if user.TotpSecret != "" {
controller.Auth.CreateSessionCookie(c, &types.SessionCookie{
controller.Auth.CreateSessionCookie(c, &config.SessionCookie{
Username: user.Username,
Name: utils.Capitalize(req.Username),
Email: fmt.Sprintf("%s@%s", strings.ToLower(req.Username), controller.Config.Domain),
@@ -118,7 +118,7 @@ func (controller *UserController) loginHandler(c *gin.Context) {
}
}
controller.Auth.CreateSessionCookie(c, &types.SessionCookie{
controller.Auth.CreateSessionCookie(c, &config.SessionCookie{
Username: req.Username,
Name: utils.Capitalize(req.Username),
Email: fmt.Sprintf("%s@%s", strings.ToLower(req.Username), controller.Config.Domain),
@@ -202,7 +202,7 @@ func (controller *UserController) totpHandler(c *gin.Context) {
controller.Auth.RecordLoginAttempt(rateIdentifier, true)
controller.Auth.CreateSessionCookie(c, &types.SessionCookie{
controller.Auth.CreateSessionCookie(c, &config.SessionCookie{
Username: user.Username,
Name: utils.Capitalize(user.Username),
Email: fmt.Sprintf("%s@%s", strings.ToLower(user.Username), controller.Config.Domain),

View File

@@ -3,9 +3,8 @@ package middleware
import (
"fmt"
"strings"
"tinyauth/internal/auth"
"tinyauth/internal/providers"
"tinyauth/internal/types"
"tinyauth/internal/config"
"tinyauth/internal/service"
"tinyauth/internal/utils"
"github.com/gin-gonic/gin"
@@ -16,16 +15,16 @@ type ContextMiddlewareConfig struct {
}
type ContextMiddleware struct {
Config ContextMiddlewareConfig
Auth *auth.Auth
Providers *providers.Providers
Config ContextMiddlewareConfig
Auth *service.AuthService
Broker *service.OAuthBrokerService
}
func NewContextMiddleware(config ContextMiddlewareConfig, auth *auth.Auth, providers *providers.Providers) *ContextMiddleware {
func NewContextMiddleware(config ContextMiddlewareConfig, auth *service.AuthService, broker *service.OAuthBrokerService) *ContextMiddleware {
return &ContextMiddleware{
Config: config,
Auth: auth,
Providers: providers,
Config: config,
Auth: auth,
Broker: broker,
}
}
@@ -46,7 +45,7 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
}
if cookie.TotpPending {
c.Set("context", &types.UserContext{
c.Set("context", &config.UserContext{
Username: cookie.Username,
Name: cookie.Name,
Email: cookie.Email,
@@ -66,7 +65,7 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
goto basic
}
c.Set("context", &types.UserContext{
c.Set("context", &config.UserContext{
Username: cookie.Username,
Name: cookie.Name,
Email: cookie.Email,
@@ -76,9 +75,9 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
c.Next()
return
default:
provider := m.Providers.GetProvider(cookie.Provider)
_, exists := m.Broker.GetService(cookie.Provider)
if provider == nil {
if !exists {
goto basic
}
@@ -87,7 +86,7 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
goto basic
}
c.Set("context", &types.UserContext{
c.Set("context", &config.UserContext{
Username: cookie.Username,
Name: cookie.Name,
Email: cookie.Email,
@@ -124,7 +123,7 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
case "local":
user := m.Auth.GetLocalUser(basic.Username)
c.Set("context", &types.UserContext{
c.Set("context", &config.UserContext{
Username: user.Username,
Name: utils.Capitalize(user.Username),
Email: fmt.Sprintf("%s@%s", strings.ToLower(user.Username), m.Config.Domain),
@@ -135,7 +134,7 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
c.Next()
return
case "ldap":
c.Set("context", &types.UserContext{
c.Set("context", &config.UserContext{
Username: basic.Username,
Name: utils.Capitalize(basic.Username),
Email: fmt.Sprintf("%s@%s", strings.ToLower(basic.Username), m.Config.Domain),

View File

@@ -6,7 +6,7 @@ import (
"strings"
"sync"
"time"
"tinyauth/internal/types"
"tinyauth/internal/config"
"tinyauth/internal/utils"
"github.com/gin-gonic/gin"
@@ -15,8 +15,14 @@ import (
"golang.org/x/crypto/bcrypt"
)
type LoginAttempt struct {
FailedAttempts int
LastAttempt time.Time
LockedUntil time.Time
}
type AuthServiceConfig struct {
Users types.Users
Users []config.User
OauthWhitelist string
SessionExpiry int
CookieSecure bool
@@ -31,7 +37,7 @@ type AuthServiceConfig struct {
type AuthService struct {
Config AuthServiceConfig
Docker *DockerService
LoginAttempts map[string]*types.LoginAttempt
LoginAttempts map[string]*LoginAttempt
LoginMutex sync.RWMutex
Store *sessions.CookieStore
LDAP *LdapService
@@ -41,7 +47,7 @@ func NewAuthService(config AuthServiceConfig, docker *DockerService, ldap *LdapS
return &AuthService{
Config: config,
Docker: docker,
LoginAttempts: make(map[string]*types.LoginAttempt),
LoginAttempts: make(map[string]*LoginAttempt),
LDAP: ldap,
}
}
@@ -75,13 +81,13 @@ func (auth *AuthService) GetSession(c *gin.Context) (*sessions.Session, error) {
return session, nil
}
func (auth *AuthService) SearchUser(username string) types.UserSearch {
func (auth *AuthService) SearchUser(username string) config.UserSearch {
log.Debug().Str("username", username).Msg("Searching for user")
// Check local users first
if auth.GetLocalUser(username).Username != "" {
log.Debug().Str("username", username).Msg("Found local user")
return types.UserSearch{
return config.UserSearch{
Username: username,
Type: "local",
}
@@ -93,20 +99,20 @@ func (auth *AuthService) SearchUser(username string) types.UserSearch {
userDN, err := auth.LDAP.Search(username)
if err != nil {
log.Warn().Err(err).Str("username", username).Msg("Failed to find user in LDAP")
return types.UserSearch{}
return config.UserSearch{}
}
return types.UserSearch{
return config.UserSearch{
Username: userDN,
Type: "ldap",
}
}
return types.UserSearch{
return config.UserSearch{
Type: "unknown",
}
}
func (auth *AuthService) VerifyUser(search types.UserSearch, password string) bool {
func (auth *AuthService) VerifyUser(search config.UserSearch, password string) bool {
// Authenticate the user based on the type
switch search.Type {
case "local":
@@ -144,7 +150,7 @@ func (auth *AuthService) VerifyUser(search types.UserSearch, password string) bo
return false
}
func (auth *AuthService) GetLocalUser(username string) types.User {
func (auth *AuthService) GetLocalUser(username string) config.User {
// Loop through users and return the user if the username matches
log.Debug().Str("username", username).Msg("Searching for local user")
@@ -156,10 +162,10 @@ func (auth *AuthService) GetLocalUser(username string) types.User {
// If no user found, return an empty user
log.Warn().Str("username", username).Msg("Local user not found")
return types.User{}
return config.User{}
}
func (auth *AuthService) CheckPassword(user types.User, password string) bool {
func (auth *AuthService) CheckPassword(user config.User, password string) bool {
return bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)) == nil
}
@@ -201,7 +207,7 @@ func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) {
// Get current attempt record or create a new one
attempt, exists := auth.LoginAttempts[identifier]
if !exists {
attempt = &types.LoginAttempt{}
attempt = &LoginAttempt{}
auth.LoginAttempts[identifier] = attempt
}
@@ -229,7 +235,7 @@ func (auth *AuthService) EmailWhitelisted(email string) bool {
return utils.CheckFilter(auth.Config.OauthWhitelist, email)
}
func (auth *AuthService) CreateSessionCookie(c *gin.Context, data *types.SessionCookie) error {
func (auth *AuthService) CreateSessionCookie(c *gin.Context, data *config.SessionCookie) error {
log.Debug().Msg("Creating session cookie")
session, err := auth.GetSession(c)
@@ -288,13 +294,13 @@ func (auth *AuthService) DeleteSessionCookie(c *gin.Context) error {
return nil
}
func (auth *AuthService) GetSessionCookie(c *gin.Context) (types.SessionCookie, error) {
func (auth *AuthService) GetSessionCookie(c *gin.Context) (config.SessionCookie, error) {
log.Debug().Msg("Getting session cookie")
session, err := auth.GetSession(c)
if err != nil {
log.Error().Err(err).Msg("Failed to get session")
return types.SessionCookie{}, err
return config.SessionCookie{}, err
}
log.Debug().Msg("Got session")
@@ -311,18 +317,18 @@ func (auth *AuthService) GetSessionCookie(c *gin.Context) (types.SessionCookie,
if !usernameOk || !providerOK || !expiryOk || !totpPendingOk || !emailOk || !nameOk || !oauthGroupsOk {
log.Warn().Msg("Session cookie is invalid")
auth.DeleteSessionCookie(c)
return types.SessionCookie{}, nil
return config.SessionCookie{}, nil
}
// If the session cookie has expired, delete it
if time.Now().Unix() > expiry {
log.Warn().Msg("Session cookie expired")
auth.DeleteSessionCookie(c)
return types.SessionCookie{}, nil
return config.SessionCookie{}, nil
}
log.Debug().Str("username", username).Str("provider", provider).Int64("expiry", expiry).Bool("totpPending", totpPending).Str("name", name).Str("email", email).Str("oauthGroups", oauthGroups).Msg("Parsed cookie")
return types.SessionCookie{
return config.SessionCookie{
Username: username,
Name: name,
Email: email,
@@ -337,7 +343,7 @@ func (auth *AuthService) UserAuthConfigured() bool {
return len(auth.Config.Users) > 0 || auth.LDAP != nil
}
func (auth *AuthService) ResourceAllowed(c *gin.Context, context types.UserContext, labels types.Labels) bool {
func (auth *AuthService) ResourceAllowed(c *gin.Context, context config.UserContext, labels config.Labels) bool {
if context.OAuth {
log.Debug().Msg("Checking OAuth whitelist")
return utils.CheckFilter(labels.OAuth.Whitelist, context.Email)
@@ -347,7 +353,7 @@ func (auth *AuthService) ResourceAllowed(c *gin.Context, context types.UserConte
return utils.CheckFilter(labels.Users, context.Username)
}
func (auth *AuthService) OAuthGroup(c *gin.Context, context types.UserContext, labels types.Labels) bool {
func (auth *AuthService) OAuthGroup(c *gin.Context, context config.UserContext, labels config.Labels) bool {
if labels.OAuth.Groups == "" {
return true
}
@@ -374,7 +380,7 @@ func (auth *AuthService) OAuthGroup(c *gin.Context, context types.UserContext, l
return false
}
func (auth *AuthService) AuthEnabled(uri string, labels types.Labels) (bool, error) {
func (auth *AuthService) AuthEnabled(uri string, labels config.Labels) (bool, error) {
// If the label is empty, auth is enabled
if labels.Allowed == "" {
return true, nil
@@ -398,18 +404,18 @@ func (auth *AuthService) AuthEnabled(uri string, labels types.Labels) (bool, err
return true, nil
}
func (auth *AuthService) GetBasicAuth(c *gin.Context) *types.User {
func (auth *AuthService) GetBasicAuth(c *gin.Context) *config.User {
username, password, ok := c.Request.BasicAuth()
if !ok {
return nil
}
return &types.User{
return &config.User{
Username: username,
Password: password,
}
}
func (auth *AuthService) CheckIP(labels types.Labels, ip string) bool {
func (auth *AuthService) CheckIP(labels config.Labels, ip string) bool {
// Check if the IP is in block list
for _, blocked := range labels.IP.Block {
res, err := utils.FilterIP(blocked, ip)
@@ -446,7 +452,7 @@ func (auth *AuthService) CheckIP(labels types.Labels, ip string) bool {
return true
}
func (auth *AuthService) BypassedIP(labels types.Labels, ip string) bool {
func (auth *AuthService) BypassedIP(labels config.Labels, ip string) bool {
// For every IP in the bypass list, check if the IP matches
for _, bypassed := range labels.IP.Bypass {
res, err := utils.FilterIP(bypassed, ip)

View File

@@ -19,7 +19,6 @@ type GenericOAuthService struct {
Token *oauth2.Token
Verifier string
InsecureSkipVerify bool
ServiceName string
UserinfoURL string
}
@@ -36,7 +35,6 @@ func NewGenericOAuthService(config config.OAuthServiceConfig) *GenericOAuthServi
},
},
InsecureSkipVerify: config.InsecureSkipVerify,
ServiceName: config.Name,
UserinfoURL: config.UserinfoURL,
}
}
@@ -63,10 +61,6 @@ func (generic *GenericOAuthService) Init() error {
return nil
}
func (generic *GenericOAuthService) Name() string {
return generic.ServiceName
}
func (generic *GenericOAuthService) GenerateState() string {
b := make([]byte, 128)
rand.Read(b)

View File

@@ -54,10 +54,6 @@ func (github *GithubOAuthService) Init() error {
return nil
}
func (github *GithubOAuthService) Name() string {
return "github"
}
func (github *GithubOAuthService) GenerateState() string {
b := make([]byte, 128)
rand.Read(b)

View File

@@ -49,10 +49,6 @@ func (google *GoogleOAuthService) Init() error {
return nil
}
func (google *GoogleOAuthService) Name() string {
return "google"
}
func (oauth *GoogleOAuthService) GenerateState() string {
b := make([]byte, 128)
rand.Read(b)

View File

@@ -0,0 +1,76 @@
package service
import (
"errors"
"tinyauth/internal/config"
"github.com/rs/zerolog/log"
)
type OAuthService interface {
Init() error
GenerateState() string
GetAuthURL(state string) string
VerifyCode(code string) error
Userinfo() (config.Claims, error)
}
type OAuthBrokerService struct {
Services map[string]OAuthService
Configs map[string]config.OAuthServiceConfig
}
func NewOAuthBrokerService(configs map[string]config.OAuthServiceConfig) *OAuthBrokerService {
return &OAuthBrokerService{
Services: make(map[string]OAuthService),
Configs: configs,
}
}
func (broker *OAuthBrokerService) Init() error {
for name, cfg := range broker.Configs {
switch name {
case "github":
service := NewGithubOAuthService(cfg)
broker.Services[name] = service
case "google":
service := NewGoogleOAuthService(cfg)
broker.Services[name] = service
default:
service := NewGenericOAuthService(cfg)
broker.Services[name] = service
}
}
for name, service := range broker.Services {
err := service.Init()
if err != nil {
log.Error().Err(err).Msgf("Failed to initialize OAuth service: %s", name)
return err
}
log.Info().Msgf("Initialized OAuth service: %s", name)
}
return nil
}
func (broker *OAuthBrokerService) GetConfiguredServices() []string {
services := make([]string, 0, len(broker.Services))
for name := range broker.Services {
services = append(services, name)
}
return services
}
func (broker *OAuthBrokerService) GetService(name string) (OAuthService, bool) {
service, exists := broker.Services[name]
return service, exists
}
func (broker *OAuthBrokerService) GetUser(service string) (config.Claims, error) {
oauthService, exists := broker.Services[service]
if !exists {
return config.Claims{}, errors.New("oauth service not found")
}
return oauthService.Userinfo()
}

View File

@@ -11,7 +11,7 @@ import (
"os"
"regexp"
"strings"
"tinyauth/internal/types"
"tinyauth/internal/config"
"github.com/gin-gonic/gin"
"github.com/traefik/paerser/parser"
@@ -22,21 +22,21 @@ import (
)
// Parses a list of comma separated users in a struct
func ParseUsers(users string) (types.Users, error) {
func ParseUsers(users string) ([]config.User, error) {
log.Debug().Msg("Parsing users")
var usersParsed types.Users
var usersParsed []config.User
userList := strings.Split(users, ",")
if len(userList) == 0 {
return types.Users{}, errors.New("invalid user format")
return []config.User{}, errors.New("invalid user format")
}
for _, user := range userList {
parsed, err := ParseUser(user)
if err != nil {
return types.Users{}, err
return []config.User{}, err
}
usersParsed = append(usersParsed, parsed)
}
@@ -107,11 +107,11 @@ func GetSecret(conf string, file string) string {
}
// Get the users from the config or file
func GetUsers(conf string, file string) (types.Users, error) {
func GetUsers(conf string, file string) ([]config.User, error) {
var users string
if conf == "" && file == "" {
return types.Users{}, nil
return []config.User{}, nil
}
if conf != "" {
@@ -152,23 +152,18 @@ func ParseHeaders(headers []string) map[string]string {
}
// Get labels parses a map of labels into a struct with only the needed labels
func GetLabels(labels map[string]string) (types.Labels, error) {
var labelsParsed types.Labels
func GetLabels(labels map[string]string) (config.Labels, error) {
var labelsParsed config.Labels
err := parser.Decode(labels, &labelsParsed, "tinyauth", "tinyauth.users", "tinyauth.allowed", "tinyauth.headers", "tinyauth.domain", "tinyauth.basic", "tinyauth.oauth", "tinyauth.ip")
if err != nil {
log.Error().Err(err).Msg("Error parsing labels")
return types.Labels{}, err
return config.Labels{}, err
}
return labelsParsed, nil
}
// Check if any of the OAuth providers are configured based on the client id and secret
func OAuthConfigured(config types.Config) bool {
return (config.GithubClientId != "" && config.GithubClientSecret != "") || (config.GoogleClientId != "" && config.GoogleClientSecret != "") || (config.GenericClientId != "" && config.GenericClientSecret != "")
}
// Filter helper function
func Filter[T any](slice []T, test func(T) bool) (res []T) {
for _, value := range slice {
@@ -180,7 +175,7 @@ func Filter[T any](slice []T, test func(T) bool) (res []T) {
}
// Parse user
func ParseUser(user string) (types.User, error) {
func ParseUser(user string) (config.User, error) {
if strings.Contains(user, "$$") {
user = strings.ReplaceAll(user, "$$", "$")
}
@@ -188,23 +183,23 @@ func ParseUser(user string) (types.User, error) {
userSplit := strings.Split(user, ":")
if len(userSplit) < 2 || len(userSplit) > 3 {
return types.User{}, errors.New("invalid user format")
return config.User{}, errors.New("invalid user format")
}
for _, userPart := range userSplit {
if strings.TrimSpace(userPart) == "" {
return types.User{}, errors.New("invalid user format")
return config.User{}, errors.New("invalid user format")
}
}
if len(userSplit) == 2 {
return types.User{
return config.User{
Username: strings.TrimSpace(userSplit[0]),
Password: strings.TrimSpace(userSplit[1]),
}, nil
}
return types.User{
return config.User{
Username: strings.TrimSpace(userSplit[0]),
Password: strings.TrimSpace(userSplit[1]),
TotpSecret: strings.TrimSpace(userSplit[2]),
@@ -350,17 +345,17 @@ func CoalesceToString(value any) string {
}
}
func GetContext(c *gin.Context) (types.UserContext, error) {
func GetContext(c *gin.Context) (config.UserContext, error) {
userContextValue, exists := c.Get("context")
if !exists {
return types.UserContext{}, errors.New("no user context in request")
return config.UserContext{}, errors.New("no user context in request")
}
userContext, ok := userContextValue.(*types.UserContext)
userContext, ok := userContextValue.(*config.UserContext)
if !ok {
return types.UserContext{}, errors.New("invalid user context in request")
return config.UserContext{}, errors.New("invalid user context in request")
}
return *userContext, nil

View File

@@ -1,548 +0,0 @@
package utils_test
import (
"fmt"
"os"
"reflect"
"testing"
"tinyauth/internal/types"
"tinyauth/internal/utils"
)
func TestParseUsers(t *testing.T) {
t.Log("Testing parse users with a valid string")
users := "user1:pass1,user2:pass2"
expected := types.Users{
{
Username: "user1",
Password: "pass1",
},
{
Username: "user2",
Password: "pass2",
},
}
result, err := utils.ParseUsers(users)
if err != nil {
t.Fatalf("Error parsing users: %v", err)
}
if !reflect.DeepEqual(expected, result) {
t.Fatalf("Expected %v, got %v", expected, result)
}
}
func TestGetUpperDomain(t *testing.T) {
t.Log("Testing get upper domain with a valid url")
url := "https://sub1.sub2.domain.com:8080"
expected := "sub2.domain.com"
result, err := utils.GetUpperDomain(url)
if err != nil {
t.Fatalf("Error getting root url: %v", err)
}
if expected != result {
t.Fatalf("Expected %v, got %v", expected, result)
}
}
func TestReadFile(t *testing.T) {
t.Log("Creating a test file")
err := os.WriteFile("/tmp/test.txt", []byte("test"), 0644)
if err != nil {
t.Fatalf("Error creating test file: %v", err)
}
t.Log("Testing read file with a valid file")
data, err := utils.ReadFile("/tmp/test.txt")
if err != nil {
t.Fatalf("Error reading file: %v", err)
}
if data != "test" {
t.Fatalf("Expected test, got %v", data)
}
t.Log("Cleaning up test file")
err = os.Remove("/tmp/test.txt")
if err != nil {
t.Fatalf("Error cleaning up test file: %v", err)
}
}
func TestParseFileToLine(t *testing.T) {
t.Log("Testing parse file to line with a valid string")
content := "\nuser1:pass1\nuser2:pass2\n"
expected := "user1:pass1,user2:pass2"
result := utils.ParseFileToLine(content)
if expected != result {
t.Fatalf("Expected %v, got %v", expected, result)
}
}
func TestGetSecret(t *testing.T) {
t.Log("Testing get secret with an empty config and file")
conf := ""
file := "/tmp/test.txt"
expected := "test"
err := os.WriteFile(file, []byte(fmt.Sprintf("\n\n \n\n\n %s \n\n \n ", expected)), 0644)
if err != nil {
t.Fatalf("Error creating test file: %v", err)
}
result := utils.GetSecret(conf, file)
if result != expected {
t.Fatalf("Expected %v, got %v", expected, result)
}
t.Log("Testing get secret with an empty file and a valid config")
result = utils.GetSecret(expected, "")
if result != expected {
t.Fatalf("Expected %v, got %v", expected, result)
}
t.Log("Testing get secret with both a valid config and file")
result = utils.GetSecret(expected, file)
if result != expected {
t.Fatalf("Expected %v, got %v", expected, result)
}
t.Log("Cleaning up test file")
err = os.Remove(file)
if err != nil {
t.Fatalf("Error cleaning up test file: %v", err)
}
}
func TestGetUsers(t *testing.T) {
t.Log("Testing get users with a config and no file")
conf := "user1:pass1,user2:pass2"
file := ""
expected := types.Users{
{
Username: "user1",
Password: "pass1",
},
{
Username: "user2",
Password: "pass2",
},
}
result, err := utils.GetUsers(conf, file)
if err != nil {
t.Fatalf("Error getting users: %v", err)
}
if !reflect.DeepEqual(expected, result) {
t.Fatalf("Expected %v, got %v", expected, result)
}
t.Log("Testing get users with a file and no config")
conf = ""
file = "/tmp/test.txt"
expected = types.Users{
{
Username: "user1",
Password: "pass1",
},
{
Username: "user2",
Password: "pass2",
},
}
err = os.WriteFile(file, []byte("user1:pass1\nuser2:pass2"), 0644)
if err != nil {
t.Fatalf("Error creating test file: %v", err)
}
result, err = utils.GetUsers(conf, file)
if err != nil {
t.Fatalf("Error getting users: %v", err)
}
if !reflect.DeepEqual(expected, result) {
t.Fatalf("Expected %v, got %v", expected, result)
}
t.Log("Testing get users with both a config and file")
conf = "user3:pass3"
expected = types.Users{
{
Username: "user3",
Password: "pass3",
},
{
Username: "user1",
Password: "pass1",
},
{
Username: "user2",
Password: "pass2",
},
}
result, err = utils.GetUsers(conf, file)
if err != nil {
t.Fatalf("Error getting users: %v", err)
}
if !reflect.DeepEqual(expected, result) {
t.Fatalf("Expected %v, got %v", expected, result)
}
t.Log("Cleaning up test file")
err = os.Remove(file)
if err != nil {
t.Fatalf("Error cleaning up test file: %v", err)
}
}
func TestGetLabels(t *testing.T) {
t.Log("Testing get labels with a valid map")
labels := map[string]string{
"tinyauth.users": "user1,user2",
"tinyauth.oauth.whitelist": "/regex/",
"tinyauth.allowed": "random",
"tinyauth.headers": "X-Header=value",
"tinyauth.oauth.groups": "group1,group2",
}
expected := types.Labels{
Users: "user1,user2",
Allowed: "random",
Headers: []string{"X-Header=value"},
OAuth: types.OAuthLabels{
Whitelist: "/regex/",
Groups: "group1,group2",
},
}
result, err := utils.GetLabels(labels)
if err != nil {
t.Fatalf("Error getting labels: %v", err)
}
if !reflect.DeepEqual(expected, result) {
t.Fatalf("Expected %v, got %v", expected, result)
}
}
func TestParseUser(t *testing.T) {
t.Log("Testing parse user with a valid user")
user := "user:pass:secret"
expected := types.User{
Username: "user",
Password: "pass",
TotpSecret: "secret",
}
result, err := utils.ParseUser(user)
if err != nil {
t.Fatalf("Error parsing user: %v", err)
}
if !reflect.DeepEqual(expected, result) {
t.Fatalf("Expected %v, got %v", expected, result)
}
t.Log("Testing parse user with an escaped user")
user = "user:p$$ass$$:secret"
expected = types.User{
Username: "user",
Password: "p$ass$",
TotpSecret: "secret",
}
result, err = utils.ParseUser(user)
if err != nil {
t.Fatalf("Error parsing user: %v", err)
}
if !reflect.DeepEqual(expected, result) {
t.Fatalf("Expected %v, got %v", expected, result)
}
t.Log("Testing parse user with an invalid user")
user = "user::pass"
_, err = utils.ParseUser(user)
if err == nil {
t.Fatalf("Expected error parsing user")
}
}
func TestCheckFilter(t *testing.T) {
t.Log("Testing check filter with a comma separated list")
filter := "user1,user2,user3"
str := "user1"
expected := true
result := utils.CheckFilter(filter, str)
if result != expected {
t.Fatalf("Expected %v, got %v", expected, result)
}
t.Log("Testing check filter with a regex filter")
filter = "/^user[0-9]+$/"
str = "user1"
expected = true
result = utils.CheckFilter(filter, str)
if result != expected {
t.Fatalf("Expected %v, got %v", expected, result)
}
t.Log("Testing check filter with an empty filter")
filter = ""
str = "user1"
expected = true
result = utils.CheckFilter(filter, str)
if result != expected {
t.Fatalf("Expected %v, got %v", expected, result)
}
t.Log("Testing check filter with an invalid regex filter")
filter = "/^user[0-9+$/"
str = "user1"
expected = false
result = utils.CheckFilter(filter, str)
if result != expected {
t.Fatalf("Expected %v, got %v", expected, result)
}
t.Log("Testing check filter with a non matching list")
filter = "user1,user2,user3"
str = "user4"
expected = false
result = utils.CheckFilter(filter, str)
if result != expected {
t.Fatalf("Expected %v, got %v", expected, result)
}
}
func TestSanitizeHeader(t *testing.T) {
t.Log("Testing sanitize header with a valid string")
str := "X-Header=value"
expected := "X-Header=value"
result := utils.SanitizeHeader(str)
if result != expected {
t.Fatalf("Expected %v, got %v", expected, result)
}
t.Log("Testing sanitize header with an invalid string")
str = "X-Header=val\nue"
expected = "X-Header=value"
result = utils.SanitizeHeader(str)
if result != expected {
t.Fatalf("Expected %v, got %v", expected, result)
}
}
func TestParseHeaders(t *testing.T) {
t.Log("Testing parse headers with a valid string")
headers := []string{"X-Hea\x00der1=value1", "X-Header2=value\n2"}
expected := map[string]string{
"X-Header1": "value1",
"X-Header2": "value2",
}
result := utils.ParseHeaders(headers)
if !reflect.DeepEqual(expected, result) {
t.Fatalf("Expected %v, got %v", expected, result)
}
t.Log("Testing parse headers with an invalid string")
headers = []string{"X-Header1=", "X-Header2", "=value", "X-Header3=value3"}
expected = map[string]string{"X-Header3": "value3"}
result = utils.ParseHeaders(headers)
if !reflect.DeepEqual(expected, result) {
t.Fatalf("Expected %v, got %v", expected, result)
}
}
func TestParseSecretFile(t *testing.T) {
t.Log("Testing parse secret file with a valid file")
content := "\n\n \n\n\n secret \n\n \n "
expected := "secret"
result := utils.ParseSecretFile(content)
if result != expected {
t.Fatalf("Expected %v, got %v", expected, result)
}
}
func TestFilterIP(t *testing.T) {
t.Log("Testing filter IP with an IP and a valid CIDR")
ip := "10.10.10.10"
filter := "10.10.10.0/24"
expected := true
result, err := utils.FilterIP(filter, ip)
if err != nil {
t.Fatalf("Error filtering IP: %v", err)
}
if result != expected {
t.Fatalf("Expected %v, got %v", expected, result)
}
t.Log("Testing filter IP with an IP and a valid IP")
filter = "10.10.10.10"
expected = true
result, err = utils.FilterIP(filter, ip)
if err != nil {
t.Fatalf("Error filtering IP: %v", err)
}
if result != expected {
t.Fatalf("Expected %v, got %v", expected, result)
}
t.Log("Testing filter IP with an IP and an non matching CIDR")
filter = "10.10.15.0/24"
expected = false
result, err = utils.FilterIP(filter, ip)
if err != nil {
t.Fatalf("Error filtering IP: %v", err)
}
if result != expected {
t.Fatalf("Expected %v, got %v", expected, result)
}
t.Log("Testing filter IP with a non matching IP and a valid CIDR")
filter = "10.10.10.11"
expected = false
result, err = utils.FilterIP(filter, ip)
if err != nil {
t.Fatalf("Error filtering IP: %v", err)
}
if result != expected {
t.Fatalf("Expected %v, got %v", expected, result)
}
t.Log("Testing filter IP with an IP and an invalid CIDR")
filter = "10.../83"
_, err = utils.FilterIP(filter, ip)
if err == nil {
t.Fatalf("Expected error filtering IP")
}
}
func TestDeriveKey(t *testing.T) {
t.Log("Testing the derive key function")
master := "master"
info := "info"
expected := "gdrdU/fXzclYjiSXRexEatVgV13qQmKl"
result, err := utils.DeriveKey(master, info)
if err != nil {
t.Fatalf("Error deriving key: %v", err)
}
if result != expected {
t.Fatalf("Expected %v, got %v", expected, result)
}
}
func TestCoalesceToString(t *testing.T) {
t.Log("Testing coalesce to string with a string")
value := any("test")
expected := "test"
result := utils.CoalesceToString(value)
if result != expected {
t.Fatalf("Expected %v, got %v", expected, result)
}
t.Log("Testing coalesce to string with a slice of strings")
value = []any{any("test1"), any("test2"), any(123)}
expected = "test1,test2"
result = utils.CoalesceToString(value)
if result != expected {
t.Fatalf("Expected %v, got %v", expected, result)
}
t.Log("Testing coalesce to string with an unsupported type")
value = 12345
expected = ""
result = utils.CoalesceToString(value)
if result != expected {
t.Fatalf("Expected %v, got %v", expected, result)
}
}