mirror of
https://github.com/steveiliop56/tinyauth.git
synced 2025-10-29 05:05:42 +00:00
feat: create oauth broker service
This commit is contained in:
@@ -1,7 +1,5 @@
|
|||||||
package config
|
package config
|
||||||
|
|
||||||
import "time"
|
|
||||||
|
|
||||||
type Claims struct {
|
type Claims struct {
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Email string `json:"email"`
|
Email string `json:"email"`
|
||||||
@@ -100,7 +98,6 @@ type OAuthServiceConfig struct {
|
|||||||
TokenURL string
|
TokenURL string
|
||||||
UserinfoURL string
|
UserinfoURL string
|
||||||
InsecureSkipVerify bool
|
InsecureSkipVerify bool
|
||||||
Name string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type User struct {
|
type User struct {
|
||||||
@@ -114,8 +111,6 @@ type UserSearch struct {
|
|||||||
Type string // local, ldap or unknown
|
Type string // local, ldap or unknown
|
||||||
}
|
}
|
||||||
|
|
||||||
type Users []User
|
|
||||||
|
|
||||||
type SessionCookie struct {
|
type SessionCookie struct {
|
||||||
Username string
|
Username string
|
||||||
Name string
|
Name string
|
||||||
@@ -137,12 +132,6 @@ type UserContext struct {
|
|||||||
TotpEnabled bool
|
TotpEnabled bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type LoginAttempt struct {
|
|
||||||
FailedAttempts int
|
|
||||||
LastAttempt time.Time
|
|
||||||
LockedUntil time.Time
|
|
||||||
}
|
|
||||||
|
|
||||||
type UnauthorizedQuery struct {
|
type UnauthorizedQuery struct {
|
||||||
Username string `url:"username"`
|
Username string `url:"username"`
|
||||||
Resource string `url:"resource"`
|
Resource string `url:"resource"`
|
||||||
|
|||||||
@@ -5,9 +5,8 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
"tinyauth/internal/auth"
|
"tinyauth/internal/config"
|
||||||
"tinyauth/internal/providers"
|
"tinyauth/internal/service"
|
||||||
"tinyauth/internal/types"
|
|
||||||
"tinyauth/internal/utils"
|
"tinyauth/internal/utils"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@@ -26,18 +25,18 @@ type OAuthControllerConfig struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type OAuthController struct {
|
type OAuthController struct {
|
||||||
Config OAuthControllerConfig
|
Config OAuthControllerConfig
|
||||||
Router *gin.RouterGroup
|
Router *gin.RouterGroup
|
||||||
Auth *auth.Auth
|
Auth *service.AuthService
|
||||||
Providers *providers.Providers
|
Broker *service.OAuthBrokerService
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewOAuthController(config OAuthControllerConfig, router *gin.RouterGroup, auth *auth.Auth, providers *providers.Providers) *OAuthController {
|
func NewOAuthController(config OAuthControllerConfig, router *gin.RouterGroup, auth *service.AuthService, broker *service.OAuthBrokerService) *OAuthController {
|
||||||
return &OAuthController{
|
return &OAuthController{
|
||||||
Config: config,
|
Config: config,
|
||||||
Router: router,
|
Router: router,
|
||||||
Auth: auth,
|
Auth: auth,
|
||||||
Providers: providers,
|
Broker: broker,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -59,9 +58,9 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
provider := controller.Providers.GetProvider(req.Provider)
|
service, exists := controller.Broker.GetService(req.Provider)
|
||||||
|
|
||||||
if provider == nil {
|
if !exists {
|
||||||
c.JSON(404, gin.H{
|
c.JSON(404, gin.H{
|
||||||
"status": 404,
|
"status": 404,
|
||||||
"message": "Not Found",
|
"message": "Not Found",
|
||||||
@@ -69,8 +68,8 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
state := provider.GenerateState()
|
state := service.GenerateState()
|
||||||
authURL := provider.GetAuthURL(state)
|
authURL := service.GetAuthURL(state)
|
||||||
c.SetCookie(controller.Config.CSRFCookieName, state, int(time.Hour.Seconds()), "/", "", controller.Config.SecureCookie, true)
|
c.SetCookie(controller.Config.CSRFCookieName, state, int(time.Hour.Seconds()), "/", "", controller.Config.SecureCookie, true)
|
||||||
|
|
||||||
redirectURI := c.Query("redirect_uri")
|
redirectURI := c.Query("redirect_uri")
|
||||||
@@ -109,20 +108,20 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
|
|||||||
c.SetCookie(controller.Config.CSRFCookieName, "", -1, "/", "", controller.Config.SecureCookie, true)
|
c.SetCookie(controller.Config.CSRFCookieName, "", -1, "/", "", controller.Config.SecureCookie, true)
|
||||||
|
|
||||||
code := c.Query("code")
|
code := c.Query("code")
|
||||||
provider := controller.Providers.GetProvider(req.Provider)
|
service, exists := controller.Broker.GetService(req.Provider)
|
||||||
|
|
||||||
if provider == nil {
|
if !exists {
|
||||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.Config.AppURL))
|
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.Config.AppURL))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = provider.ExchangeToken(code)
|
err = service.VerifyCode(code)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.Config.AppURL))
|
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.Config.AppURL))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := controller.Providers.GetUser(req.Provider)
|
user, err := controller.Broker.GetUser(req.Provider)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.Config.AppURL))
|
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.Config.AppURL))
|
||||||
@@ -135,7 +134,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !controller.Auth.EmailWhitelisted(user.Email) {
|
if !controller.Auth.EmailWhitelisted(user.Email) {
|
||||||
queries, err := query.Values(types.UnauthorizedQuery{
|
queries, err := query.Values(config.UnauthorizedQuery{
|
||||||
Username: user.Email,
|
Username: user.Email,
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -156,7 +155,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
|
|||||||
name = fmt.Sprintf("%s (%s)", utils.Capitalize(strings.Split(user.Email, "@")[0]), strings.Split(user.Email, "@")[1])
|
name = fmt.Sprintf("%s (%s)", utils.Capitalize(strings.Split(user.Email, "@")[0]), strings.Split(user.Email, "@")[1])
|
||||||
}
|
}
|
||||||
|
|
||||||
controller.Auth.CreateSessionCookie(c, &types.SessionCookie{
|
controller.Auth.CreateSessionCookie(c, &config.SessionCookie{
|
||||||
Username: user.Email,
|
Username: user.Email,
|
||||||
Name: name,
|
Name: name,
|
||||||
Email: user.Email,
|
Email: user.Email,
|
||||||
@@ -171,7 +170,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
queries, err := query.Values(types.RedirectQuery{
|
queries, err := query.Values(config.RedirectQuery{
|
||||||
RedirectURI: redirectURI,
|
RedirectURI: redirectURI,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|||||||
@@ -4,9 +4,8 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
"tinyauth/internal/auth"
|
"tinyauth/internal/config"
|
||||||
"tinyauth/internal/docker"
|
"tinyauth/internal/service"
|
||||||
"tinyauth/internal/types"
|
|
||||||
"tinyauth/internal/utils"
|
"tinyauth/internal/utils"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@@ -24,11 +23,11 @@ type ProxyControllerConfig struct {
|
|||||||
type ProxyController struct {
|
type ProxyController struct {
|
||||||
Config ProxyControllerConfig
|
Config ProxyControllerConfig
|
||||||
Router *gin.RouterGroup
|
Router *gin.RouterGroup
|
||||||
Docker *docker.Docker
|
Docker *service.DockerService
|
||||||
Auth *auth.Auth
|
Auth *service.AuthService
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewProxyController(config ProxyControllerConfig, router *gin.RouterGroup, docker *docker.Docker, auth *auth.Auth) *ProxyController {
|
func NewProxyController(config ProxyControllerConfig, router *gin.RouterGroup, docker *service.DockerService, auth *service.AuthService) *ProxyController {
|
||||||
return &ProxyController{
|
return &ProxyController{
|
||||||
Config: config,
|
Config: config,
|
||||||
Router: router,
|
Router: router,
|
||||||
@@ -109,7 +108,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
queries, err := query.Values(types.UnauthorizedQuery{
|
queries, err := query.Values(config.UnauthorizedQuery{
|
||||||
Resource: strings.Split(host, ".")[0],
|
Resource: strings.Split(host, ".")[0],
|
||||||
IP: clientIP,
|
IP: clientIP,
|
||||||
})
|
})
|
||||||
@@ -157,12 +156,12 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var userContext types.UserContext
|
var userContext config.UserContext
|
||||||
|
|
||||||
context, err := utils.GetContext(c)
|
context, err := utils.GetContext(c)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
userContext = types.UserContext{
|
userContext = config.UserContext{
|
||||||
IsLoggedIn: false,
|
IsLoggedIn: false,
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@@ -185,7 +184,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
queries, err := query.Values(types.UnauthorizedQuery{
|
queries, err := query.Values(config.UnauthorizedQuery{
|
||||||
Resource: strings.Split(host, ".")[0],
|
Resource: strings.Split(host, ".")[0],
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -216,7 +215,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
queries, err := query.Values(types.UnauthorizedQuery{
|
queries, err := query.Values(config.UnauthorizedQuery{
|
||||||
Resource: strings.Split(host, ".")[0],
|
Resource: strings.Split(host, ".")[0],
|
||||||
GroupErr: true,
|
GroupErr: true,
|
||||||
})
|
})
|
||||||
@@ -268,7 +267,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
queries, err := query.Values(types.RedirectQuery{
|
queries, err := query.Values(config.RedirectQuery{
|
||||||
RedirectURI: fmt.Sprintf("%s://%s%s", proto, host, uri),
|
RedirectURI: fmt.Sprintf("%s://%s%s", proto, host, uri),
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|||||||
@@ -3,8 +3,8 @@ package controller
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
"tinyauth/internal/auth"
|
"tinyauth/internal/config"
|
||||||
"tinyauth/internal/types"
|
"tinyauth/internal/service"
|
||||||
"tinyauth/internal/utils"
|
"tinyauth/internal/utils"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@@ -27,10 +27,10 @@ type UserControllerConfig struct {
|
|||||||
type UserController struct {
|
type UserController struct {
|
||||||
Config UserControllerConfig
|
Config UserControllerConfig
|
||||||
Router *gin.RouterGroup
|
Router *gin.RouterGroup
|
||||||
Auth *auth.Auth
|
Auth *service.AuthService
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewUserController(config UserControllerConfig, router *gin.RouterGroup, auth *auth.Auth) *UserController {
|
func NewUserController(config UserControllerConfig, router *gin.RouterGroup, auth *service.AuthService) *UserController {
|
||||||
return &UserController{
|
return &UserController{
|
||||||
Config: config,
|
Config: config,
|
||||||
Router: router,
|
Router: router,
|
||||||
@@ -101,7 +101,7 @@ func (controller *UserController) loginHandler(c *gin.Context) {
|
|||||||
user := controller.Auth.GetLocalUser(userSearch.Username)
|
user := controller.Auth.GetLocalUser(userSearch.Username)
|
||||||
|
|
||||||
if user.TotpSecret != "" {
|
if user.TotpSecret != "" {
|
||||||
controller.Auth.CreateSessionCookie(c, &types.SessionCookie{
|
controller.Auth.CreateSessionCookie(c, &config.SessionCookie{
|
||||||
Username: user.Username,
|
Username: user.Username,
|
||||||
Name: utils.Capitalize(req.Username),
|
Name: utils.Capitalize(req.Username),
|
||||||
Email: fmt.Sprintf("%s@%s", strings.ToLower(req.Username), controller.Config.Domain),
|
Email: fmt.Sprintf("%s@%s", strings.ToLower(req.Username), controller.Config.Domain),
|
||||||
@@ -118,7 +118,7 @@ func (controller *UserController) loginHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
controller.Auth.CreateSessionCookie(c, &types.SessionCookie{
|
controller.Auth.CreateSessionCookie(c, &config.SessionCookie{
|
||||||
Username: req.Username,
|
Username: req.Username,
|
||||||
Name: utils.Capitalize(req.Username),
|
Name: utils.Capitalize(req.Username),
|
||||||
Email: fmt.Sprintf("%s@%s", strings.ToLower(req.Username), controller.Config.Domain),
|
Email: fmt.Sprintf("%s@%s", strings.ToLower(req.Username), controller.Config.Domain),
|
||||||
@@ -202,7 +202,7 @@ func (controller *UserController) totpHandler(c *gin.Context) {
|
|||||||
|
|
||||||
controller.Auth.RecordLoginAttempt(rateIdentifier, true)
|
controller.Auth.RecordLoginAttempt(rateIdentifier, true)
|
||||||
|
|
||||||
controller.Auth.CreateSessionCookie(c, &types.SessionCookie{
|
controller.Auth.CreateSessionCookie(c, &config.SessionCookie{
|
||||||
Username: user.Username,
|
Username: user.Username,
|
||||||
Name: utils.Capitalize(user.Username),
|
Name: utils.Capitalize(user.Username),
|
||||||
Email: fmt.Sprintf("%s@%s", strings.ToLower(user.Username), controller.Config.Domain),
|
Email: fmt.Sprintf("%s@%s", strings.ToLower(user.Username), controller.Config.Domain),
|
||||||
|
|||||||
@@ -3,9 +3,8 @@ package middleware
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
"tinyauth/internal/auth"
|
"tinyauth/internal/config"
|
||||||
"tinyauth/internal/providers"
|
"tinyauth/internal/service"
|
||||||
"tinyauth/internal/types"
|
|
||||||
"tinyauth/internal/utils"
|
"tinyauth/internal/utils"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@@ -16,16 +15,16 @@ type ContextMiddlewareConfig struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type ContextMiddleware struct {
|
type ContextMiddleware struct {
|
||||||
Config ContextMiddlewareConfig
|
Config ContextMiddlewareConfig
|
||||||
Auth *auth.Auth
|
Auth *service.AuthService
|
||||||
Providers *providers.Providers
|
Broker *service.OAuthBrokerService
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewContextMiddleware(config ContextMiddlewareConfig, auth *auth.Auth, providers *providers.Providers) *ContextMiddleware {
|
func NewContextMiddleware(config ContextMiddlewareConfig, auth *service.AuthService, broker *service.OAuthBrokerService) *ContextMiddleware {
|
||||||
return &ContextMiddleware{
|
return &ContextMiddleware{
|
||||||
Config: config,
|
Config: config,
|
||||||
Auth: auth,
|
Auth: auth,
|
||||||
Providers: providers,
|
Broker: broker,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -46,7 +45,7 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if cookie.TotpPending {
|
if cookie.TotpPending {
|
||||||
c.Set("context", &types.UserContext{
|
c.Set("context", &config.UserContext{
|
||||||
Username: cookie.Username,
|
Username: cookie.Username,
|
||||||
Name: cookie.Name,
|
Name: cookie.Name,
|
||||||
Email: cookie.Email,
|
Email: cookie.Email,
|
||||||
@@ -66,7 +65,7 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
|
|||||||
goto basic
|
goto basic
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Set("context", &types.UserContext{
|
c.Set("context", &config.UserContext{
|
||||||
Username: cookie.Username,
|
Username: cookie.Username,
|
||||||
Name: cookie.Name,
|
Name: cookie.Name,
|
||||||
Email: cookie.Email,
|
Email: cookie.Email,
|
||||||
@@ -76,9 +75,9 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
|
|||||||
c.Next()
|
c.Next()
|
||||||
return
|
return
|
||||||
default:
|
default:
|
||||||
provider := m.Providers.GetProvider(cookie.Provider)
|
_, exists := m.Broker.GetService(cookie.Provider)
|
||||||
|
|
||||||
if provider == nil {
|
if !exists {
|
||||||
goto basic
|
goto basic
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -87,7 +86,7 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
|
|||||||
goto basic
|
goto basic
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Set("context", &types.UserContext{
|
c.Set("context", &config.UserContext{
|
||||||
Username: cookie.Username,
|
Username: cookie.Username,
|
||||||
Name: cookie.Name,
|
Name: cookie.Name,
|
||||||
Email: cookie.Email,
|
Email: cookie.Email,
|
||||||
@@ -124,7 +123,7 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
|
|||||||
case "local":
|
case "local":
|
||||||
user := m.Auth.GetLocalUser(basic.Username)
|
user := m.Auth.GetLocalUser(basic.Username)
|
||||||
|
|
||||||
c.Set("context", &types.UserContext{
|
c.Set("context", &config.UserContext{
|
||||||
Username: user.Username,
|
Username: user.Username,
|
||||||
Name: utils.Capitalize(user.Username),
|
Name: utils.Capitalize(user.Username),
|
||||||
Email: fmt.Sprintf("%s@%s", strings.ToLower(user.Username), m.Config.Domain),
|
Email: fmt.Sprintf("%s@%s", strings.ToLower(user.Username), m.Config.Domain),
|
||||||
@@ -135,7 +134,7 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
|
|||||||
c.Next()
|
c.Next()
|
||||||
return
|
return
|
||||||
case "ldap":
|
case "ldap":
|
||||||
c.Set("context", &types.UserContext{
|
c.Set("context", &config.UserContext{
|
||||||
Username: basic.Username,
|
Username: basic.Username,
|
||||||
Name: utils.Capitalize(basic.Username),
|
Name: utils.Capitalize(basic.Username),
|
||||||
Email: fmt.Sprintf("%s@%s", strings.ToLower(basic.Username), m.Config.Domain),
|
Email: fmt.Sprintf("%s@%s", strings.ToLower(basic.Username), m.Config.Domain),
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
"tinyauth/internal/types"
|
"tinyauth/internal/config"
|
||||||
"tinyauth/internal/utils"
|
"tinyauth/internal/utils"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@@ -15,8 +15,14 @@ import (
|
|||||||
"golang.org/x/crypto/bcrypt"
|
"golang.org/x/crypto/bcrypt"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type LoginAttempt struct {
|
||||||
|
FailedAttempts int
|
||||||
|
LastAttempt time.Time
|
||||||
|
LockedUntil time.Time
|
||||||
|
}
|
||||||
|
|
||||||
type AuthServiceConfig struct {
|
type AuthServiceConfig struct {
|
||||||
Users types.Users
|
Users []config.User
|
||||||
OauthWhitelist string
|
OauthWhitelist string
|
||||||
SessionExpiry int
|
SessionExpiry int
|
||||||
CookieSecure bool
|
CookieSecure bool
|
||||||
@@ -31,7 +37,7 @@ type AuthServiceConfig struct {
|
|||||||
type AuthService struct {
|
type AuthService struct {
|
||||||
Config AuthServiceConfig
|
Config AuthServiceConfig
|
||||||
Docker *DockerService
|
Docker *DockerService
|
||||||
LoginAttempts map[string]*types.LoginAttempt
|
LoginAttempts map[string]*LoginAttempt
|
||||||
LoginMutex sync.RWMutex
|
LoginMutex sync.RWMutex
|
||||||
Store *sessions.CookieStore
|
Store *sessions.CookieStore
|
||||||
LDAP *LdapService
|
LDAP *LdapService
|
||||||
@@ -41,7 +47,7 @@ func NewAuthService(config AuthServiceConfig, docker *DockerService, ldap *LdapS
|
|||||||
return &AuthService{
|
return &AuthService{
|
||||||
Config: config,
|
Config: config,
|
||||||
Docker: docker,
|
Docker: docker,
|
||||||
LoginAttempts: make(map[string]*types.LoginAttempt),
|
LoginAttempts: make(map[string]*LoginAttempt),
|
||||||
LDAP: ldap,
|
LDAP: ldap,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -75,13 +81,13 @@ func (auth *AuthService) GetSession(c *gin.Context) (*sessions.Session, error) {
|
|||||||
return session, nil
|
return session, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) SearchUser(username string) types.UserSearch {
|
func (auth *AuthService) SearchUser(username string) config.UserSearch {
|
||||||
log.Debug().Str("username", username).Msg("Searching for user")
|
log.Debug().Str("username", username).Msg("Searching for user")
|
||||||
|
|
||||||
// Check local users first
|
// Check local users first
|
||||||
if auth.GetLocalUser(username).Username != "" {
|
if auth.GetLocalUser(username).Username != "" {
|
||||||
log.Debug().Str("username", username).Msg("Found local user")
|
log.Debug().Str("username", username).Msg("Found local user")
|
||||||
return types.UserSearch{
|
return config.UserSearch{
|
||||||
Username: username,
|
Username: username,
|
||||||
Type: "local",
|
Type: "local",
|
||||||
}
|
}
|
||||||
@@ -93,20 +99,20 @@ func (auth *AuthService) SearchUser(username string) types.UserSearch {
|
|||||||
userDN, err := auth.LDAP.Search(username)
|
userDN, err := auth.LDAP.Search(username)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warn().Err(err).Str("username", username).Msg("Failed to find user in LDAP")
|
log.Warn().Err(err).Str("username", username).Msg("Failed to find user in LDAP")
|
||||||
return types.UserSearch{}
|
return config.UserSearch{}
|
||||||
}
|
}
|
||||||
return types.UserSearch{
|
return config.UserSearch{
|
||||||
Username: userDN,
|
Username: userDN,
|
||||||
Type: "ldap",
|
Type: "ldap",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return types.UserSearch{
|
return config.UserSearch{
|
||||||
Type: "unknown",
|
Type: "unknown",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) VerifyUser(search types.UserSearch, password string) bool {
|
func (auth *AuthService) VerifyUser(search config.UserSearch, password string) bool {
|
||||||
// Authenticate the user based on the type
|
// Authenticate the user based on the type
|
||||||
switch search.Type {
|
switch search.Type {
|
||||||
case "local":
|
case "local":
|
||||||
@@ -144,7 +150,7 @@ func (auth *AuthService) VerifyUser(search types.UserSearch, password string) bo
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) GetLocalUser(username string) types.User {
|
func (auth *AuthService) GetLocalUser(username string) config.User {
|
||||||
// Loop through users and return the user if the username matches
|
// Loop through users and return the user if the username matches
|
||||||
log.Debug().Str("username", username).Msg("Searching for local user")
|
log.Debug().Str("username", username).Msg("Searching for local user")
|
||||||
|
|
||||||
@@ -156,10 +162,10 @@ func (auth *AuthService) GetLocalUser(username string) types.User {
|
|||||||
|
|
||||||
// If no user found, return an empty user
|
// If no user found, return an empty user
|
||||||
log.Warn().Str("username", username).Msg("Local user not found")
|
log.Warn().Str("username", username).Msg("Local user not found")
|
||||||
return types.User{}
|
return config.User{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) CheckPassword(user types.User, password string) bool {
|
func (auth *AuthService) CheckPassword(user config.User, password string) bool {
|
||||||
return bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)) == nil
|
return bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)) == nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -201,7 +207,7 @@ func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) {
|
|||||||
// Get current attempt record or create a new one
|
// Get current attempt record or create a new one
|
||||||
attempt, exists := auth.LoginAttempts[identifier]
|
attempt, exists := auth.LoginAttempts[identifier]
|
||||||
if !exists {
|
if !exists {
|
||||||
attempt = &types.LoginAttempt{}
|
attempt = &LoginAttempt{}
|
||||||
auth.LoginAttempts[identifier] = attempt
|
auth.LoginAttempts[identifier] = attempt
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -229,7 +235,7 @@ func (auth *AuthService) EmailWhitelisted(email string) bool {
|
|||||||
return utils.CheckFilter(auth.Config.OauthWhitelist, email)
|
return utils.CheckFilter(auth.Config.OauthWhitelist, email)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) CreateSessionCookie(c *gin.Context, data *types.SessionCookie) error {
|
func (auth *AuthService) CreateSessionCookie(c *gin.Context, data *config.SessionCookie) error {
|
||||||
log.Debug().Msg("Creating session cookie")
|
log.Debug().Msg("Creating session cookie")
|
||||||
|
|
||||||
session, err := auth.GetSession(c)
|
session, err := auth.GetSession(c)
|
||||||
@@ -288,13 +294,13 @@ func (auth *AuthService) DeleteSessionCookie(c *gin.Context) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) GetSessionCookie(c *gin.Context) (types.SessionCookie, error) {
|
func (auth *AuthService) GetSessionCookie(c *gin.Context) (config.SessionCookie, error) {
|
||||||
log.Debug().Msg("Getting session cookie")
|
log.Debug().Msg("Getting session cookie")
|
||||||
|
|
||||||
session, err := auth.GetSession(c)
|
session, err := auth.GetSession(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Err(err).Msg("Failed to get session")
|
log.Error().Err(err).Msg("Failed to get session")
|
||||||
return types.SessionCookie{}, err
|
return config.SessionCookie{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debug().Msg("Got session")
|
log.Debug().Msg("Got session")
|
||||||
@@ -311,18 +317,18 @@ func (auth *AuthService) GetSessionCookie(c *gin.Context) (types.SessionCookie,
|
|||||||
if !usernameOk || !providerOK || !expiryOk || !totpPendingOk || !emailOk || !nameOk || !oauthGroupsOk {
|
if !usernameOk || !providerOK || !expiryOk || !totpPendingOk || !emailOk || !nameOk || !oauthGroupsOk {
|
||||||
log.Warn().Msg("Session cookie is invalid")
|
log.Warn().Msg("Session cookie is invalid")
|
||||||
auth.DeleteSessionCookie(c)
|
auth.DeleteSessionCookie(c)
|
||||||
return types.SessionCookie{}, nil
|
return config.SessionCookie{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// If the session cookie has expired, delete it
|
// If the session cookie has expired, delete it
|
||||||
if time.Now().Unix() > expiry {
|
if time.Now().Unix() > expiry {
|
||||||
log.Warn().Msg("Session cookie expired")
|
log.Warn().Msg("Session cookie expired")
|
||||||
auth.DeleteSessionCookie(c)
|
auth.DeleteSessionCookie(c)
|
||||||
return types.SessionCookie{}, nil
|
return config.SessionCookie{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debug().Str("username", username).Str("provider", provider).Int64("expiry", expiry).Bool("totpPending", totpPending).Str("name", name).Str("email", email).Str("oauthGroups", oauthGroups).Msg("Parsed cookie")
|
log.Debug().Str("username", username).Str("provider", provider).Int64("expiry", expiry).Bool("totpPending", totpPending).Str("name", name).Str("email", email).Str("oauthGroups", oauthGroups).Msg("Parsed cookie")
|
||||||
return types.SessionCookie{
|
return config.SessionCookie{
|
||||||
Username: username,
|
Username: username,
|
||||||
Name: name,
|
Name: name,
|
||||||
Email: email,
|
Email: email,
|
||||||
@@ -337,7 +343,7 @@ func (auth *AuthService) UserAuthConfigured() bool {
|
|||||||
return len(auth.Config.Users) > 0 || auth.LDAP != nil
|
return len(auth.Config.Users) > 0 || auth.LDAP != nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) ResourceAllowed(c *gin.Context, context types.UserContext, labels types.Labels) bool {
|
func (auth *AuthService) ResourceAllowed(c *gin.Context, context config.UserContext, labels config.Labels) bool {
|
||||||
if context.OAuth {
|
if context.OAuth {
|
||||||
log.Debug().Msg("Checking OAuth whitelist")
|
log.Debug().Msg("Checking OAuth whitelist")
|
||||||
return utils.CheckFilter(labels.OAuth.Whitelist, context.Email)
|
return utils.CheckFilter(labels.OAuth.Whitelist, context.Email)
|
||||||
@@ -347,7 +353,7 @@ func (auth *AuthService) ResourceAllowed(c *gin.Context, context types.UserConte
|
|||||||
return utils.CheckFilter(labels.Users, context.Username)
|
return utils.CheckFilter(labels.Users, context.Username)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) OAuthGroup(c *gin.Context, context types.UserContext, labels types.Labels) bool {
|
func (auth *AuthService) OAuthGroup(c *gin.Context, context config.UserContext, labels config.Labels) bool {
|
||||||
if labels.OAuth.Groups == "" {
|
if labels.OAuth.Groups == "" {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
@@ -374,7 +380,7 @@ func (auth *AuthService) OAuthGroup(c *gin.Context, context types.UserContext, l
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) AuthEnabled(uri string, labels types.Labels) (bool, error) {
|
func (auth *AuthService) AuthEnabled(uri string, labels config.Labels) (bool, error) {
|
||||||
// If the label is empty, auth is enabled
|
// If the label is empty, auth is enabled
|
||||||
if labels.Allowed == "" {
|
if labels.Allowed == "" {
|
||||||
return true, nil
|
return true, nil
|
||||||
@@ -398,18 +404,18 @@ func (auth *AuthService) AuthEnabled(uri string, labels types.Labels) (bool, err
|
|||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) GetBasicAuth(c *gin.Context) *types.User {
|
func (auth *AuthService) GetBasicAuth(c *gin.Context) *config.User {
|
||||||
username, password, ok := c.Request.BasicAuth()
|
username, password, ok := c.Request.BasicAuth()
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return &types.User{
|
return &config.User{
|
||||||
Username: username,
|
Username: username,
|
||||||
Password: password,
|
Password: password,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) CheckIP(labels types.Labels, ip string) bool {
|
func (auth *AuthService) CheckIP(labels config.Labels, ip string) bool {
|
||||||
// Check if the IP is in block list
|
// Check if the IP is in block list
|
||||||
for _, blocked := range labels.IP.Block {
|
for _, blocked := range labels.IP.Block {
|
||||||
res, err := utils.FilterIP(blocked, ip)
|
res, err := utils.FilterIP(blocked, ip)
|
||||||
@@ -446,7 +452,7 @@ func (auth *AuthService) CheckIP(labels types.Labels, ip string) bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) BypassedIP(labels types.Labels, ip string) bool {
|
func (auth *AuthService) BypassedIP(labels config.Labels, ip string) bool {
|
||||||
// For every IP in the bypass list, check if the IP matches
|
// For every IP in the bypass list, check if the IP matches
|
||||||
for _, bypassed := range labels.IP.Bypass {
|
for _, bypassed := range labels.IP.Bypass {
|
||||||
res, err := utils.FilterIP(bypassed, ip)
|
res, err := utils.FilterIP(bypassed, ip)
|
||||||
|
|||||||
@@ -19,7 +19,6 @@ type GenericOAuthService struct {
|
|||||||
Token *oauth2.Token
|
Token *oauth2.Token
|
||||||
Verifier string
|
Verifier string
|
||||||
InsecureSkipVerify bool
|
InsecureSkipVerify bool
|
||||||
ServiceName string
|
|
||||||
UserinfoURL string
|
UserinfoURL string
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -36,7 +35,6 @@ func NewGenericOAuthService(config config.OAuthServiceConfig) *GenericOAuthServi
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
InsecureSkipVerify: config.InsecureSkipVerify,
|
InsecureSkipVerify: config.InsecureSkipVerify,
|
||||||
ServiceName: config.Name,
|
|
||||||
UserinfoURL: config.UserinfoURL,
|
UserinfoURL: config.UserinfoURL,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -63,10 +61,6 @@ func (generic *GenericOAuthService) Init() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (generic *GenericOAuthService) Name() string {
|
|
||||||
return generic.ServiceName
|
|
||||||
}
|
|
||||||
|
|
||||||
func (generic *GenericOAuthService) GenerateState() string {
|
func (generic *GenericOAuthService) GenerateState() string {
|
||||||
b := make([]byte, 128)
|
b := make([]byte, 128)
|
||||||
rand.Read(b)
|
rand.Read(b)
|
||||||
|
|||||||
@@ -54,10 +54,6 @@ func (github *GithubOAuthService) Init() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (github *GithubOAuthService) Name() string {
|
|
||||||
return "github"
|
|
||||||
}
|
|
||||||
|
|
||||||
func (github *GithubOAuthService) GenerateState() string {
|
func (github *GithubOAuthService) GenerateState() string {
|
||||||
b := make([]byte, 128)
|
b := make([]byte, 128)
|
||||||
rand.Read(b)
|
rand.Read(b)
|
||||||
|
|||||||
@@ -49,10 +49,6 @@ func (google *GoogleOAuthService) Init() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (google *GoogleOAuthService) Name() string {
|
|
||||||
return "google"
|
|
||||||
}
|
|
||||||
|
|
||||||
func (oauth *GoogleOAuthService) GenerateState() string {
|
func (oauth *GoogleOAuthService) GenerateState() string {
|
||||||
b := make([]byte, 128)
|
b := make([]byte, 128)
|
||||||
rand.Read(b)
|
rand.Read(b)
|
||||||
|
|||||||
76
internal/service/oauth_broker_service.go
Normal file
76
internal/service/oauth_broker_service.go
Normal file
@@ -0,0 +1,76 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"tinyauth/internal/config"
|
||||||
|
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
|
)
|
||||||
|
|
||||||
|
type OAuthService interface {
|
||||||
|
Init() error
|
||||||
|
GenerateState() string
|
||||||
|
GetAuthURL(state string) string
|
||||||
|
VerifyCode(code string) error
|
||||||
|
Userinfo() (config.Claims, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type OAuthBrokerService struct {
|
||||||
|
Services map[string]OAuthService
|
||||||
|
Configs map[string]config.OAuthServiceConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewOAuthBrokerService(configs map[string]config.OAuthServiceConfig) *OAuthBrokerService {
|
||||||
|
return &OAuthBrokerService{
|
||||||
|
Services: make(map[string]OAuthService),
|
||||||
|
Configs: configs,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (broker *OAuthBrokerService) Init() error {
|
||||||
|
for name, cfg := range broker.Configs {
|
||||||
|
switch name {
|
||||||
|
case "github":
|
||||||
|
service := NewGithubOAuthService(cfg)
|
||||||
|
broker.Services[name] = service
|
||||||
|
case "google":
|
||||||
|
service := NewGoogleOAuthService(cfg)
|
||||||
|
broker.Services[name] = service
|
||||||
|
default:
|
||||||
|
service := NewGenericOAuthService(cfg)
|
||||||
|
broker.Services[name] = service
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, service := range broker.Services {
|
||||||
|
err := service.Init()
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Err(err).Msgf("Failed to initialize OAuth service: %s", name)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
log.Info().Msgf("Initialized OAuth service: %s", name)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (broker *OAuthBrokerService) GetConfiguredServices() []string {
|
||||||
|
services := make([]string, 0, len(broker.Services))
|
||||||
|
for name := range broker.Services {
|
||||||
|
services = append(services, name)
|
||||||
|
}
|
||||||
|
return services
|
||||||
|
}
|
||||||
|
|
||||||
|
func (broker *OAuthBrokerService) GetService(name string) (OAuthService, bool) {
|
||||||
|
service, exists := broker.Services[name]
|
||||||
|
return service, exists
|
||||||
|
}
|
||||||
|
|
||||||
|
func (broker *OAuthBrokerService) GetUser(service string) (config.Claims, error) {
|
||||||
|
oauthService, exists := broker.Services[service]
|
||||||
|
if !exists {
|
||||||
|
return config.Claims{}, errors.New("oauth service not found")
|
||||||
|
}
|
||||||
|
return oauthService.Userinfo()
|
||||||
|
}
|
||||||
@@ -11,7 +11,7 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"regexp"
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
"tinyauth/internal/types"
|
"tinyauth/internal/config"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/traefik/paerser/parser"
|
"github.com/traefik/paerser/parser"
|
||||||
@@ -22,21 +22,21 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// Parses a list of comma separated users in a struct
|
// Parses a list of comma separated users in a struct
|
||||||
func ParseUsers(users string) (types.Users, error) {
|
func ParseUsers(users string) ([]config.User, error) {
|
||||||
log.Debug().Msg("Parsing users")
|
log.Debug().Msg("Parsing users")
|
||||||
|
|
||||||
var usersParsed types.Users
|
var usersParsed []config.User
|
||||||
|
|
||||||
userList := strings.Split(users, ",")
|
userList := strings.Split(users, ",")
|
||||||
|
|
||||||
if len(userList) == 0 {
|
if len(userList) == 0 {
|
||||||
return types.Users{}, errors.New("invalid user format")
|
return []config.User{}, errors.New("invalid user format")
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, user := range userList {
|
for _, user := range userList {
|
||||||
parsed, err := ParseUser(user)
|
parsed, err := ParseUser(user)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return types.Users{}, err
|
return []config.User{}, err
|
||||||
}
|
}
|
||||||
usersParsed = append(usersParsed, parsed)
|
usersParsed = append(usersParsed, parsed)
|
||||||
}
|
}
|
||||||
@@ -107,11 +107,11 @@ func GetSecret(conf string, file string) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Get the users from the config or file
|
// Get the users from the config or file
|
||||||
func GetUsers(conf string, file string) (types.Users, error) {
|
func GetUsers(conf string, file string) ([]config.User, error) {
|
||||||
var users string
|
var users string
|
||||||
|
|
||||||
if conf == "" && file == "" {
|
if conf == "" && file == "" {
|
||||||
return types.Users{}, nil
|
return []config.User{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if conf != "" {
|
if conf != "" {
|
||||||
@@ -152,23 +152,18 @@ func ParseHeaders(headers []string) map[string]string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Get labels parses a map of labels into a struct with only the needed labels
|
// Get labels parses a map of labels into a struct with only the needed labels
|
||||||
func GetLabels(labels map[string]string) (types.Labels, error) {
|
func GetLabels(labels map[string]string) (config.Labels, error) {
|
||||||
var labelsParsed types.Labels
|
var labelsParsed config.Labels
|
||||||
|
|
||||||
err := parser.Decode(labels, &labelsParsed, "tinyauth", "tinyauth.users", "tinyauth.allowed", "tinyauth.headers", "tinyauth.domain", "tinyauth.basic", "tinyauth.oauth", "tinyauth.ip")
|
err := parser.Decode(labels, &labelsParsed, "tinyauth", "tinyauth.users", "tinyauth.allowed", "tinyauth.headers", "tinyauth.domain", "tinyauth.basic", "tinyauth.oauth", "tinyauth.ip")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Err(err).Msg("Error parsing labels")
|
log.Error().Err(err).Msg("Error parsing labels")
|
||||||
return types.Labels{}, err
|
return config.Labels{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return labelsParsed, nil
|
return labelsParsed, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if any of the OAuth providers are configured based on the client id and secret
|
|
||||||
func OAuthConfigured(config types.Config) bool {
|
|
||||||
return (config.GithubClientId != "" && config.GithubClientSecret != "") || (config.GoogleClientId != "" && config.GoogleClientSecret != "") || (config.GenericClientId != "" && config.GenericClientSecret != "")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Filter helper function
|
// Filter helper function
|
||||||
func Filter[T any](slice []T, test func(T) bool) (res []T) {
|
func Filter[T any](slice []T, test func(T) bool) (res []T) {
|
||||||
for _, value := range slice {
|
for _, value := range slice {
|
||||||
@@ -180,7 +175,7 @@ func Filter[T any](slice []T, test func(T) bool) (res []T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Parse user
|
// Parse user
|
||||||
func ParseUser(user string) (types.User, error) {
|
func ParseUser(user string) (config.User, error) {
|
||||||
if strings.Contains(user, "$$") {
|
if strings.Contains(user, "$$") {
|
||||||
user = strings.ReplaceAll(user, "$$", "$")
|
user = strings.ReplaceAll(user, "$$", "$")
|
||||||
}
|
}
|
||||||
@@ -188,23 +183,23 @@ func ParseUser(user string) (types.User, error) {
|
|||||||
userSplit := strings.Split(user, ":")
|
userSplit := strings.Split(user, ":")
|
||||||
|
|
||||||
if len(userSplit) < 2 || len(userSplit) > 3 {
|
if len(userSplit) < 2 || len(userSplit) > 3 {
|
||||||
return types.User{}, errors.New("invalid user format")
|
return config.User{}, errors.New("invalid user format")
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, userPart := range userSplit {
|
for _, userPart := range userSplit {
|
||||||
if strings.TrimSpace(userPart) == "" {
|
if strings.TrimSpace(userPart) == "" {
|
||||||
return types.User{}, errors.New("invalid user format")
|
return config.User{}, errors.New("invalid user format")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(userSplit) == 2 {
|
if len(userSplit) == 2 {
|
||||||
return types.User{
|
return config.User{
|
||||||
Username: strings.TrimSpace(userSplit[0]),
|
Username: strings.TrimSpace(userSplit[0]),
|
||||||
Password: strings.TrimSpace(userSplit[1]),
|
Password: strings.TrimSpace(userSplit[1]),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return types.User{
|
return config.User{
|
||||||
Username: strings.TrimSpace(userSplit[0]),
|
Username: strings.TrimSpace(userSplit[0]),
|
||||||
Password: strings.TrimSpace(userSplit[1]),
|
Password: strings.TrimSpace(userSplit[1]),
|
||||||
TotpSecret: strings.TrimSpace(userSplit[2]),
|
TotpSecret: strings.TrimSpace(userSplit[2]),
|
||||||
@@ -350,17 +345,17 @@ func CoalesceToString(value any) string {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetContext(c *gin.Context) (types.UserContext, error) {
|
func GetContext(c *gin.Context) (config.UserContext, error) {
|
||||||
userContextValue, exists := c.Get("context")
|
userContextValue, exists := c.Get("context")
|
||||||
|
|
||||||
if !exists {
|
if !exists {
|
||||||
return types.UserContext{}, errors.New("no user context in request")
|
return config.UserContext{}, errors.New("no user context in request")
|
||||||
}
|
}
|
||||||
|
|
||||||
userContext, ok := userContextValue.(*types.UserContext)
|
userContext, ok := userContextValue.(*config.UserContext)
|
||||||
|
|
||||||
if !ok {
|
if !ok {
|
||||||
return types.UserContext{}, errors.New("invalid user context in request")
|
return config.UserContext{}, errors.New("invalid user context in request")
|
||||||
}
|
}
|
||||||
|
|
||||||
return *userContext, nil
|
return *userContext, nil
|
||||||
|
|||||||
@@ -1,548 +0,0 @@
|
|||||||
package utils_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
"reflect"
|
|
||||||
"testing"
|
|
||||||
"tinyauth/internal/types"
|
|
||||||
"tinyauth/internal/utils"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestParseUsers(t *testing.T) {
|
|
||||||
t.Log("Testing parse users with a valid string")
|
|
||||||
|
|
||||||
users := "user1:pass1,user2:pass2"
|
|
||||||
expected := types.Users{
|
|
||||||
{
|
|
||||||
Username: "user1",
|
|
||||||
Password: "pass1",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Username: "user2",
|
|
||||||
Password: "pass2",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
result, err := utils.ParseUsers(users)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Error parsing users: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !reflect.DeepEqual(expected, result) {
|
|
||||||
t.Fatalf("Expected %v, got %v", expected, result)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGetUpperDomain(t *testing.T) {
|
|
||||||
t.Log("Testing get upper domain with a valid url")
|
|
||||||
|
|
||||||
url := "https://sub1.sub2.domain.com:8080"
|
|
||||||
expected := "sub2.domain.com"
|
|
||||||
|
|
||||||
result, err := utils.GetUpperDomain(url)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Error getting root url: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if expected != result {
|
|
||||||
t.Fatalf("Expected %v, got %v", expected, result)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestReadFile(t *testing.T) {
|
|
||||||
t.Log("Creating a test file")
|
|
||||||
|
|
||||||
err := os.WriteFile("/tmp/test.txt", []byte("test"), 0644)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Error creating test file: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Log("Testing read file with a valid file")
|
|
||||||
|
|
||||||
data, err := utils.ReadFile("/tmp/test.txt")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Error reading file: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if data != "test" {
|
|
||||||
t.Fatalf("Expected test, got %v", data)
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Log("Cleaning up test file")
|
|
||||||
|
|
||||||
err = os.Remove("/tmp/test.txt")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Error cleaning up test file: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestParseFileToLine(t *testing.T) {
|
|
||||||
t.Log("Testing parse file to line with a valid string")
|
|
||||||
|
|
||||||
content := "\nuser1:pass1\nuser2:pass2\n"
|
|
||||||
expected := "user1:pass1,user2:pass2"
|
|
||||||
|
|
||||||
result := utils.ParseFileToLine(content)
|
|
||||||
|
|
||||||
if expected != result {
|
|
||||||
t.Fatalf("Expected %v, got %v", expected, result)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGetSecret(t *testing.T) {
|
|
||||||
t.Log("Testing get secret with an empty config and file")
|
|
||||||
|
|
||||||
conf := ""
|
|
||||||
file := "/tmp/test.txt"
|
|
||||||
expected := "test"
|
|
||||||
|
|
||||||
err := os.WriteFile(file, []byte(fmt.Sprintf("\n\n \n\n\n %s \n\n \n ", expected)), 0644)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Error creating test file: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
result := utils.GetSecret(conf, file)
|
|
||||||
|
|
||||||
if result != expected {
|
|
||||||
t.Fatalf("Expected %v, got %v", expected, result)
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Log("Testing get secret with an empty file and a valid config")
|
|
||||||
|
|
||||||
result = utils.GetSecret(expected, "")
|
|
||||||
|
|
||||||
if result != expected {
|
|
||||||
t.Fatalf("Expected %v, got %v", expected, result)
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Log("Testing get secret with both a valid config and file")
|
|
||||||
|
|
||||||
result = utils.GetSecret(expected, file)
|
|
||||||
|
|
||||||
if result != expected {
|
|
||||||
t.Fatalf("Expected %v, got %v", expected, result)
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Log("Cleaning up test file")
|
|
||||||
|
|
||||||
err = os.Remove(file)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Error cleaning up test file: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGetUsers(t *testing.T) {
|
|
||||||
t.Log("Testing get users with a config and no file")
|
|
||||||
|
|
||||||
conf := "user1:pass1,user2:pass2"
|
|
||||||
file := ""
|
|
||||||
expected := types.Users{
|
|
||||||
{
|
|
||||||
Username: "user1",
|
|
||||||
Password: "pass1",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Username: "user2",
|
|
||||||
Password: "pass2",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
result, err := utils.GetUsers(conf, file)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Error getting users: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !reflect.DeepEqual(expected, result) {
|
|
||||||
t.Fatalf("Expected %v, got %v", expected, result)
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Log("Testing get users with a file and no config")
|
|
||||||
|
|
||||||
conf = ""
|
|
||||||
file = "/tmp/test.txt"
|
|
||||||
expected = types.Users{
|
|
||||||
{
|
|
||||||
Username: "user1",
|
|
||||||
Password: "pass1",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Username: "user2",
|
|
||||||
Password: "pass2",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
err = os.WriteFile(file, []byte("user1:pass1\nuser2:pass2"), 0644)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Error creating test file: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
result, err = utils.GetUsers(conf, file)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Error getting users: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !reflect.DeepEqual(expected, result) {
|
|
||||||
t.Fatalf("Expected %v, got %v", expected, result)
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Log("Testing get users with both a config and file")
|
|
||||||
|
|
||||||
conf = "user3:pass3"
|
|
||||||
expected = types.Users{
|
|
||||||
{
|
|
||||||
Username: "user3",
|
|
||||||
Password: "pass3",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Username: "user1",
|
|
||||||
Password: "pass1",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Username: "user2",
|
|
||||||
Password: "pass2",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
result, err = utils.GetUsers(conf, file)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Error getting users: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !reflect.DeepEqual(expected, result) {
|
|
||||||
t.Fatalf("Expected %v, got %v", expected, result)
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Log("Cleaning up test file")
|
|
||||||
|
|
||||||
err = os.Remove(file)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Error cleaning up test file: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGetLabels(t *testing.T) {
|
|
||||||
t.Log("Testing get labels with a valid map")
|
|
||||||
|
|
||||||
labels := map[string]string{
|
|
||||||
"tinyauth.users": "user1,user2",
|
|
||||||
"tinyauth.oauth.whitelist": "/regex/",
|
|
||||||
"tinyauth.allowed": "random",
|
|
||||||
"tinyauth.headers": "X-Header=value",
|
|
||||||
"tinyauth.oauth.groups": "group1,group2",
|
|
||||||
}
|
|
||||||
|
|
||||||
expected := types.Labels{
|
|
||||||
Users: "user1,user2",
|
|
||||||
Allowed: "random",
|
|
||||||
Headers: []string{"X-Header=value"},
|
|
||||||
OAuth: types.OAuthLabels{
|
|
||||||
Whitelist: "/regex/",
|
|
||||||
Groups: "group1,group2",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
result, err := utils.GetLabels(labels)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Error getting labels: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !reflect.DeepEqual(expected, result) {
|
|
||||||
t.Fatalf("Expected %v, got %v", expected, result)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestParseUser(t *testing.T) {
|
|
||||||
t.Log("Testing parse user with a valid user")
|
|
||||||
|
|
||||||
user := "user:pass:secret"
|
|
||||||
expected := types.User{
|
|
||||||
Username: "user",
|
|
||||||
Password: "pass",
|
|
||||||
TotpSecret: "secret",
|
|
||||||
}
|
|
||||||
|
|
||||||
result, err := utils.ParseUser(user)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Error parsing user: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !reflect.DeepEqual(expected, result) {
|
|
||||||
t.Fatalf("Expected %v, got %v", expected, result)
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Log("Testing parse user with an escaped user")
|
|
||||||
|
|
||||||
user = "user:p$$ass$$:secret"
|
|
||||||
expected = types.User{
|
|
||||||
Username: "user",
|
|
||||||
Password: "p$ass$",
|
|
||||||
TotpSecret: "secret",
|
|
||||||
}
|
|
||||||
|
|
||||||
result, err = utils.ParseUser(user)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Error parsing user: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !reflect.DeepEqual(expected, result) {
|
|
||||||
t.Fatalf("Expected %v, got %v", expected, result)
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Log("Testing parse user with an invalid user")
|
|
||||||
|
|
||||||
user = "user::pass"
|
|
||||||
|
|
||||||
_, err = utils.ParseUser(user)
|
|
||||||
if err == nil {
|
|
||||||
t.Fatalf("Expected error parsing user")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCheckFilter(t *testing.T) {
|
|
||||||
t.Log("Testing check filter with a comma separated list")
|
|
||||||
|
|
||||||
filter := "user1,user2,user3"
|
|
||||||
str := "user1"
|
|
||||||
expected := true
|
|
||||||
|
|
||||||
result := utils.CheckFilter(filter, str)
|
|
||||||
|
|
||||||
if result != expected {
|
|
||||||
t.Fatalf("Expected %v, got %v", expected, result)
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Log("Testing check filter with a regex filter")
|
|
||||||
|
|
||||||
filter = "/^user[0-9]+$/"
|
|
||||||
str = "user1"
|
|
||||||
expected = true
|
|
||||||
|
|
||||||
result = utils.CheckFilter(filter, str)
|
|
||||||
|
|
||||||
if result != expected {
|
|
||||||
t.Fatalf("Expected %v, got %v", expected, result)
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Log("Testing check filter with an empty filter")
|
|
||||||
|
|
||||||
filter = ""
|
|
||||||
str = "user1"
|
|
||||||
expected = true
|
|
||||||
|
|
||||||
result = utils.CheckFilter(filter, str)
|
|
||||||
|
|
||||||
if result != expected {
|
|
||||||
t.Fatalf("Expected %v, got %v", expected, result)
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Log("Testing check filter with an invalid regex filter")
|
|
||||||
|
|
||||||
filter = "/^user[0-9+$/"
|
|
||||||
str = "user1"
|
|
||||||
expected = false
|
|
||||||
|
|
||||||
result = utils.CheckFilter(filter, str)
|
|
||||||
|
|
||||||
if result != expected {
|
|
||||||
t.Fatalf("Expected %v, got %v", expected, result)
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Log("Testing check filter with a non matching list")
|
|
||||||
|
|
||||||
filter = "user1,user2,user3"
|
|
||||||
str = "user4"
|
|
||||||
expected = false
|
|
||||||
|
|
||||||
result = utils.CheckFilter(filter, str)
|
|
||||||
|
|
||||||
if result != expected {
|
|
||||||
t.Fatalf("Expected %v, got %v", expected, result)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSanitizeHeader(t *testing.T) {
|
|
||||||
t.Log("Testing sanitize header with a valid string")
|
|
||||||
|
|
||||||
str := "X-Header=value"
|
|
||||||
expected := "X-Header=value"
|
|
||||||
|
|
||||||
result := utils.SanitizeHeader(str)
|
|
||||||
|
|
||||||
if result != expected {
|
|
||||||
t.Fatalf("Expected %v, got %v", expected, result)
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Log("Testing sanitize header with an invalid string")
|
|
||||||
|
|
||||||
str = "X-Header=val\nue"
|
|
||||||
expected = "X-Header=value"
|
|
||||||
|
|
||||||
result = utils.SanitizeHeader(str)
|
|
||||||
|
|
||||||
if result != expected {
|
|
||||||
t.Fatalf("Expected %v, got %v", expected, result)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestParseHeaders(t *testing.T) {
|
|
||||||
t.Log("Testing parse headers with a valid string")
|
|
||||||
|
|
||||||
headers := []string{"X-Hea\x00der1=value1", "X-Header2=value\n2"}
|
|
||||||
expected := map[string]string{
|
|
||||||
"X-Header1": "value1",
|
|
||||||
"X-Header2": "value2",
|
|
||||||
}
|
|
||||||
|
|
||||||
result := utils.ParseHeaders(headers)
|
|
||||||
|
|
||||||
if !reflect.DeepEqual(expected, result) {
|
|
||||||
t.Fatalf("Expected %v, got %v", expected, result)
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Log("Testing parse headers with an invalid string")
|
|
||||||
|
|
||||||
headers = []string{"X-Header1=", "X-Header2", "=value", "X-Header3=value3"}
|
|
||||||
expected = map[string]string{"X-Header3": "value3"}
|
|
||||||
|
|
||||||
result = utils.ParseHeaders(headers)
|
|
||||||
|
|
||||||
if !reflect.DeepEqual(expected, result) {
|
|
||||||
t.Fatalf("Expected %v, got %v", expected, result)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestParseSecretFile(t *testing.T) {
|
|
||||||
t.Log("Testing parse secret file with a valid file")
|
|
||||||
|
|
||||||
content := "\n\n \n\n\n secret \n\n \n "
|
|
||||||
expected := "secret"
|
|
||||||
|
|
||||||
result := utils.ParseSecretFile(content)
|
|
||||||
|
|
||||||
if result != expected {
|
|
||||||
t.Fatalf("Expected %v, got %v", expected, result)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestFilterIP(t *testing.T) {
|
|
||||||
t.Log("Testing filter IP with an IP and a valid CIDR")
|
|
||||||
|
|
||||||
ip := "10.10.10.10"
|
|
||||||
filter := "10.10.10.0/24"
|
|
||||||
expected := true
|
|
||||||
|
|
||||||
result, err := utils.FilterIP(filter, ip)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Error filtering IP: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if result != expected {
|
|
||||||
t.Fatalf("Expected %v, got %v", expected, result)
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Log("Testing filter IP with an IP and a valid IP")
|
|
||||||
|
|
||||||
filter = "10.10.10.10"
|
|
||||||
expected = true
|
|
||||||
|
|
||||||
result, err = utils.FilterIP(filter, ip)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Error filtering IP: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if result != expected {
|
|
||||||
t.Fatalf("Expected %v, got %v", expected, result)
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Log("Testing filter IP with an IP and an non matching CIDR")
|
|
||||||
|
|
||||||
filter = "10.10.15.0/24"
|
|
||||||
expected = false
|
|
||||||
|
|
||||||
result, err = utils.FilterIP(filter, ip)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Error filtering IP: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if result != expected {
|
|
||||||
t.Fatalf("Expected %v, got %v", expected, result)
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Log("Testing filter IP with a non matching IP and a valid CIDR")
|
|
||||||
|
|
||||||
filter = "10.10.10.11"
|
|
||||||
expected = false
|
|
||||||
|
|
||||||
result, err = utils.FilterIP(filter, ip)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Error filtering IP: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if result != expected {
|
|
||||||
t.Fatalf("Expected %v, got %v", expected, result)
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Log("Testing filter IP with an IP and an invalid CIDR")
|
|
||||||
|
|
||||||
filter = "10.../83"
|
|
||||||
|
|
||||||
_, err = utils.FilterIP(filter, ip)
|
|
||||||
if err == nil {
|
|
||||||
t.Fatalf("Expected error filtering IP")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDeriveKey(t *testing.T) {
|
|
||||||
t.Log("Testing the derive key function")
|
|
||||||
|
|
||||||
master := "master"
|
|
||||||
info := "info"
|
|
||||||
expected := "gdrdU/fXzclYjiSXRexEatVgV13qQmKl"
|
|
||||||
|
|
||||||
result, err := utils.DeriveKey(master, info)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Error deriving key: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if result != expected {
|
|
||||||
t.Fatalf("Expected %v, got %v", expected, result)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCoalesceToString(t *testing.T) {
|
|
||||||
t.Log("Testing coalesce to string with a string")
|
|
||||||
|
|
||||||
value := any("test")
|
|
||||||
expected := "test"
|
|
||||||
|
|
||||||
result := utils.CoalesceToString(value)
|
|
||||||
|
|
||||||
if result != expected {
|
|
||||||
t.Fatalf("Expected %v, got %v", expected, result)
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Log("Testing coalesce to string with a slice of strings")
|
|
||||||
|
|
||||||
value = []any{any("test1"), any("test2"), any(123)}
|
|
||||||
expected = "test1,test2"
|
|
||||||
|
|
||||||
result = utils.CoalesceToString(value)
|
|
||||||
|
|
||||||
if result != expected {
|
|
||||||
t.Fatalf("Expected %v, got %v", expected, result)
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Log("Testing coalesce to string with an unsupported type")
|
|
||||||
|
|
||||||
value = 12345
|
|
||||||
expected = ""
|
|
||||||
|
|
||||||
result = utils.CoalesceToString(value)
|
|
||||||
|
|
||||||
if result != expected {
|
|
||||||
t.Fatalf("Expected %v, got %v", expected, result)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Reference in New Issue
Block a user