mirror of
https://github.com/steveiliop56/tinyauth.git
synced 2025-10-28 20:55:42 +00:00
fix: coderabbit suggestions
This commit is contained in:
@@ -82,7 +82,7 @@ func init() {
|
|||||||
{"app-url", "", "The Tinyauth URL."},
|
{"app-url", "", "The Tinyauth URL."},
|
||||||
{"users", "", "Comma separated list of users in the format username:hash."},
|
{"users", "", "Comma separated list of users in the format username:hash."},
|
||||||
{"users-file", "", "Path to a file containing users in the format username:hash."},
|
{"users-file", "", "Path to a file containing users in the format username:hash."},
|
||||||
{"cookie-secure", false, "Send cookie over secure connection only."},
|
{"secure-cookie", false, "Send cookie over secure connection only."},
|
||||||
{"github-client-id", "", "Github OAuth client ID."},
|
{"github-client-id", "", "Github OAuth client ID."},
|
||||||
{"github-client-secret", "", "Github OAuth client secret."},
|
{"github-client-secret", "", "Github OAuth client secret."},
|
||||||
{"github-client-secret-file", "", "Github OAuth client secret file."},
|
{"github-client-secret-file", "", "Github OAuth client secret file."},
|
||||||
|
|||||||
@@ -7,4 +7,4 @@ import (
|
|||||||
// Frontend assets
|
// Frontend assets
|
||||||
//
|
//
|
||||||
//go:embed dist
|
//go:embed dist
|
||||||
var FontendAssets embed.FS
|
var FrontendAssets embed.FS
|
||||||
|
|||||||
@@ -164,13 +164,13 @@ func (app *BootstrapApp) Setup() error {
|
|||||||
log.Debug().Str("middleware", fmt.Sprintf("%T", middleware)).Msg("Initializing middleware")
|
log.Debug().Str("middleware", fmt.Sprintf("%T", middleware)).Msg("Initializing middleware")
|
||||||
err := middleware.Init()
|
err := middleware.Init()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to initialize %s middleware: %T", middleware, err)
|
return fmt.Errorf("failed to initialize middleware %T: %w", middleware, err)
|
||||||
}
|
}
|
||||||
engine.Use(middleware.Middleware())
|
engine.Use(middleware.Middleware())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create routers
|
// Create routers
|
||||||
mainRouter := engine.Group("/")
|
mainRouter := engine.Group("")
|
||||||
apiRouter := engine.Group("/api")
|
apiRouter := engine.Group("/api")
|
||||||
|
|
||||||
// Create controllers
|
// Create controllers
|
||||||
@@ -190,6 +190,7 @@ func (app *BootstrapApp) Setup() error {
|
|||||||
SecureCookie: app.Config.SecureCookie,
|
SecureCookie: app.Config.SecureCookie,
|
||||||
CSRFCookieName: csrfCookieName,
|
CSRFCookieName: csrfCookieName,
|
||||||
RedirectCookieName: redirectCookieName,
|
RedirectCookieName: redirectCookieName,
|
||||||
|
Domain: domain,
|
||||||
}, apiRouter, authService, oauthBrokerService)
|
}, apiRouter, authService, oauthBrokerService)
|
||||||
|
|
||||||
proxyController := controller.NewProxyController(controller.ProxyControllerConfig{
|
proxyController := controller.NewProxyController(controller.ProxyControllerConfig{
|
||||||
|
|||||||
@@ -65,10 +65,10 @@ type OAuthLabels struct {
|
|||||||
|
|
||||||
type BasicLabels struct {
|
type BasicLabels struct {
|
||||||
Username string
|
Username string
|
||||||
Password PassowrdLabels
|
Password PasswordLabels
|
||||||
}
|
}
|
||||||
|
|
||||||
type PassowrdLabels struct {
|
type PasswordLabels struct {
|
||||||
Plain string
|
Plain string
|
||||||
File string
|
File string
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ type OAuthControllerConfig struct {
|
|||||||
RedirectCookieName string
|
RedirectCookieName string
|
||||||
SecureCookie bool
|
SecureCookie bool
|
||||||
AppURL string
|
AppURL string
|
||||||
|
Domain string
|
||||||
}
|
}
|
||||||
|
|
||||||
type OAuthController struct {
|
type OAuthController struct {
|
||||||
@@ -77,7 +78,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
|
|||||||
|
|
||||||
redirectURI := c.Query("redirect_uri")
|
redirectURI := c.Query("redirect_uri")
|
||||||
|
|
||||||
if redirectURI != "" {
|
if redirectURI != "" && utils.IsRedirectSafe(redirectURI, controller.Config.Domain) {
|
||||||
log.Debug().Msg("Setting redirect URI cookie")
|
log.Debug().Msg("Setting redirect URI cookie")
|
||||||
c.SetCookie(controller.Config.RedirectCookieName, redirectURI, int(time.Hour.Seconds()), "/", "", controller.Config.SecureCookie, true)
|
c.SetCookie(controller.Config.RedirectCookieName, redirectURI, int(time.Hour.Seconds()), "/", "", controller.Config.SecureCookie, true)
|
||||||
}
|
}
|
||||||
@@ -178,7 +179,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
|
|||||||
|
|
||||||
redirectURI, err := c.Cookie(controller.Config.RedirectCookieName)
|
redirectURI, err := c.Cookie(controller.Config.RedirectCookieName)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil || !utils.IsRedirectSafe(redirectURI, controller.Config.Domain) {
|
||||||
log.Debug().Msg("No redirect URI cookie found, redirecting to app root")
|
log.Debug().Msg("No redirect URI cookie found, redirecting to app root")
|
||||||
c.Redirect(http.StatusTemporaryRedirect, controller.Config.AppURL)
|
c.Redirect(http.StatusTemporaryRedirect, controller.Config.AppURL)
|
||||||
return
|
return
|
||||||
@@ -195,5 +196,5 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
c.SetCookie(controller.Config.RedirectCookieName, "", -1, "/", "", controller.Config.SecureCookie, true)
|
c.SetCookie(controller.Config.RedirectCookieName, "", -1, "/", "", controller.Config.SecureCookie, true)
|
||||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/login?%s", controller.Config.AppURL, queries.Encode()))
|
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/continue?%s", controller.Config.AppURL, queries.Encode()))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -128,6 +128,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Err(err).Msg("Failed to encode unauthorized query")
|
log.Error().Err(err).Msg("Failed to encode unauthorized query")
|
||||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.Config.AppURL))
|
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.Config.AppURL))
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", controller.Config.AppURL, queries.Encode()))
|
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", controller.Config.AppURL, queries.Encode()))
|
||||||
@@ -212,9 +213,9 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
if userContext.OAuth {
|
if userContext.OAuth {
|
||||||
queries.Set("username", userContext.Username)
|
|
||||||
} else {
|
|
||||||
queries.Set("username", userContext.Email)
|
queries.Set("username", userContext.Email)
|
||||||
|
} else {
|
||||||
|
queries.Set("username", userContext.Username)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -247,9 +248,9 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
if userContext.OAuth {
|
if userContext.OAuth {
|
||||||
queries.Set("username", userContext.Username)
|
|
||||||
} else {
|
|
||||||
queries.Set("username", userContext.Email)
|
queries.Set("username", userContext.Email)
|
||||||
|
} else {
|
||||||
|
queries.Set("username", userContext.Username)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -11,14 +11,18 @@ type ResourcesControllerConfig struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type ResourcesController struct {
|
type ResourcesController struct {
|
||||||
Config ResourcesControllerConfig
|
Config ResourcesControllerConfig
|
||||||
Router *gin.RouterGroup
|
Router *gin.RouterGroup
|
||||||
|
FileServer http.Handler
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewResourcesController(config ResourcesControllerConfig, router *gin.RouterGroup) *ResourcesController {
|
func NewResourcesController(config ResourcesControllerConfig, router *gin.RouterGroup) *ResourcesController {
|
||||||
|
fileServer := http.StripPrefix("/resources", http.FileServer(http.Dir(config.ResourcesDir)))
|
||||||
|
|
||||||
return &ResourcesController{
|
return &ResourcesController{
|
||||||
Config: config,
|
Config: config,
|
||||||
Router: router,
|
Router: router,
|
||||||
|
FileServer: fileServer,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -27,6 +31,12 @@ func (controller *ResourcesController) SetupRoutes() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (controller *ResourcesController) resourcesHandler(c *gin.Context) {
|
func (controller *ResourcesController) resourcesHandler(c *gin.Context) {
|
||||||
fileServer := http.StripPrefix("/resources", http.FileServer(http.Dir(controller.Config.ResourcesDir)))
|
if controller.Config.ResourcesDir == "" {
|
||||||
fileServer.ServeHTTP(c.Writer, c.Request)
|
c.JSON(404, gin.H{
|
||||||
|
"status": 404,
|
||||||
|
"message": "Resources not found",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
controller.FileServer.ServeHTTP(c.Writer, c.Request)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -112,7 +112,7 @@ func (controller *UserController) loginHandler(c *gin.Context) {
|
|||||||
if user.TotpSecret != "" {
|
if user.TotpSecret != "" {
|
||||||
log.Debug().Str("username", req.Username).Msg("User has TOTP enabled, requiring TOTP verification")
|
log.Debug().Str("username", req.Username).Msg("User has TOTP enabled, requiring TOTP verification")
|
||||||
|
|
||||||
controller.Auth.CreateSessionCookie(c, &config.SessionCookie{
|
err := controller.Auth.CreateSessionCookie(c, &config.SessionCookie{
|
||||||
Username: user.Username,
|
Username: user.Username,
|
||||||
Name: utils.Capitalize(req.Username),
|
Name: utils.Capitalize(req.Username),
|
||||||
Email: fmt.Sprintf("%s@%s", strings.ToLower(req.Username), controller.Config.Domain),
|
Email: fmt.Sprintf("%s@%s", strings.ToLower(req.Username), controller.Config.Domain),
|
||||||
@@ -120,6 +120,15 @@ func (controller *UserController) loginHandler(c *gin.Context) {
|
|||||||
TotpPending: true,
|
TotpPending: true,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Err(err).Msg("Failed to create session cookie")
|
||||||
|
c.JSON(500, gin.H{
|
||||||
|
"status": 500,
|
||||||
|
"message": "Internal Server Error",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
c.JSON(200, gin.H{
|
c.JSON(200, gin.H{
|
||||||
"status": 200,
|
"status": 200,
|
||||||
"message": "TOTP required",
|
"message": "TOTP required",
|
||||||
@@ -129,13 +138,22 @@ func (controller *UserController) loginHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
controller.Auth.CreateSessionCookie(c, &config.SessionCookie{
|
err = controller.Auth.CreateSessionCookie(c, &config.SessionCookie{
|
||||||
Username: req.Username,
|
Username: req.Username,
|
||||||
Name: utils.Capitalize(req.Username),
|
Name: utils.Capitalize(req.Username),
|
||||||
Email: fmt.Sprintf("%s@%s", strings.ToLower(req.Username), controller.Config.Domain),
|
Email: fmt.Sprintf("%s@%s", strings.ToLower(req.Username), controller.Config.Domain),
|
||||||
Provider: "username",
|
Provider: "username",
|
||||||
})
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Err(err).Msg("Failed to create session cookie")
|
||||||
|
c.JSON(500, gin.H{
|
||||||
|
"status": 500,
|
||||||
|
"message": "Internal Server Error",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
c.JSON(200, gin.H{
|
c.JSON(200, gin.H{
|
||||||
"status": 200,
|
"status": 200,
|
||||||
"message": "Login successful",
|
"message": "Login successful",
|
||||||
@@ -144,7 +162,9 @@ func (controller *UserController) loginHandler(c *gin.Context) {
|
|||||||
|
|
||||||
func (controller *UserController) logoutHandler(c *gin.Context) {
|
func (controller *UserController) logoutHandler(c *gin.Context) {
|
||||||
log.Debug().Msg("Logout request received")
|
log.Debug().Msg("Logout request received")
|
||||||
|
|
||||||
controller.Auth.DeleteSessionCookie(c)
|
controller.Auth.DeleteSessionCookie(c)
|
||||||
|
|
||||||
c.JSON(200, gin.H{
|
c.JSON(200, gin.H{
|
||||||
"status": 200,
|
"status": 200,
|
||||||
"message": "Logout successful",
|
"message": "Logout successful",
|
||||||
@@ -175,8 +195,8 @@ func (controller *UserController) totpHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if !context.IsLoggedIn {
|
if !context.TotpPending {
|
||||||
log.Warn().Msg("TOTP attempt without being logged in")
|
log.Warn().Msg("TOTP attempt without a pending TOTP session")
|
||||||
c.JSON(401, gin.H{
|
c.JSON(401, gin.H{
|
||||||
"status": 401,
|
"status": 401,
|
||||||
"message": "Unauthorized",
|
"message": "Unauthorized",
|
||||||
@@ -223,13 +243,22 @@ func (controller *UserController) totpHandler(c *gin.Context) {
|
|||||||
|
|
||||||
controller.Auth.RecordLoginAttempt(rateIdentifier, true)
|
controller.Auth.RecordLoginAttempt(rateIdentifier, true)
|
||||||
|
|
||||||
controller.Auth.CreateSessionCookie(c, &config.SessionCookie{
|
err = controller.Auth.CreateSessionCookie(c, &config.SessionCookie{
|
||||||
Username: user.Username,
|
Username: user.Username,
|
||||||
Name: utils.Capitalize(user.Username),
|
Name: utils.Capitalize(user.Username),
|
||||||
Email: fmt.Sprintf("%s@%s", strings.ToLower(user.Username), controller.Config.Domain),
|
Email: fmt.Sprintf("%s@%s", strings.ToLower(user.Username), controller.Config.Domain),
|
||||||
Provider: "username",
|
Provider: "username",
|
||||||
})
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Err(err).Msg("Failed to create session cookie")
|
||||||
|
c.JSON(500, gin.H{
|
||||||
|
"status": 500,
|
||||||
|
"message": "Internal Server Error",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
c.JSON(200, gin.H{
|
c.JSON(200, gin.H{
|
||||||
"status": 200,
|
"status": 200,
|
||||||
"message": "Login successful",
|
"message": "Login successful",
|
||||||
|
|||||||
@@ -79,6 +79,7 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
|
|||||||
|
|
||||||
if !exists {
|
if !exists {
|
||||||
log.Debug().Msg("OAuth provider from session cookie not found")
|
log.Debug().Msg("OAuth provider from session cookie not found")
|
||||||
|
m.Auth.DeleteSessionCookie(c)
|
||||||
goto basic
|
goto basic
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -20,10 +20,10 @@ func NewUIMiddleware() *UIMiddleware {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *UIMiddleware) Init() error {
|
func (m *UIMiddleware) Init() error {
|
||||||
ui, err := fs.Sub(assets.FontendAssets, "dist")
|
ui, err := fs.Sub(assets.FrontendAssets, "dist")
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
m.UIFS = ui
|
m.UIFS = ui
|
||||||
|
|||||||
@@ -10,8 +10,8 @@ import (
|
|||||||
|
|
||||||
var (
|
var (
|
||||||
loggerSkipPathsPrefix = []string{
|
loggerSkipPathsPrefix = []string{
|
||||||
"GET /api/healthcheck",
|
"GET /api/health",
|
||||||
"HEAD /api/healthcheck",
|
"HEAD /api/health",
|
||||||
"GET /favicon.ico",
|
"GET /favicon.ico",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -71,9 +71,9 @@ func (auth *AuthService) GetSession(c *gin.Context) (*sessions.Session, error) {
|
|||||||
|
|
||||||
// If there was an error getting the session, it might be invalid so let's clear it and retry
|
// If there was an error getting the session, it might be invalid so let's clear it and retry
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debug().Err(err).Msg("Error getting session, clearing cookie and retrying")
|
log.Debug().Err(err).Msg("Error getting session, creating a new one")
|
||||||
c.SetCookie(auth.Config.SessionCookieName, "", -1, "/", fmt.Sprintf(".%s", auth.Config.Domain), auth.Config.SecureCookie, true)
|
c.SetCookie(auth.Config.SessionCookieName, "", -1, "/", fmt.Sprintf(".%s", auth.Config.Domain), auth.Config.SecureCookie, true)
|
||||||
session, err = auth.Store.Get(c.Request, auth.Config.SessionCookieName)
|
session, err = auth.Store.New(c.Request, auth.Config.SessionCookieName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"tinyauth/internal/config"
|
"tinyauth/internal/config"
|
||||||
@@ -76,7 +77,7 @@ func (generic *GenericOAuthService) VerifyCode(code string) error {
|
|||||||
token, err := generic.Config.Exchange(generic.Context, code, oauth2.VerifierOption(generic.Verifier))
|
token, err := generic.Config.Exchange(generic.Context, code, oauth2.VerifierOption(generic.Verifier))
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
generic.Token = token
|
generic.Token = token
|
||||||
@@ -94,6 +95,10 @@ func (generic *GenericOAuthService) Userinfo() (config.Claims, error) {
|
|||||||
}
|
}
|
||||||
defer res.Body.Close()
|
defer res.Body.Close()
|
||||||
|
|
||||||
|
if res.StatusCode < 200 || res.StatusCode >= 300 {
|
||||||
|
return user, fmt.Errorf("request failed with status: %s", res.Status)
|
||||||
|
}
|
||||||
|
|
||||||
body, err := io.ReadAll(res.Body)
|
body, err := io.ReadAll(res.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return user, err
|
return user, err
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"tinyauth/internal/config"
|
"tinyauth/internal/config"
|
||||||
@@ -71,7 +72,7 @@ func (github *GithubOAuthService) VerifyCode(code string) error {
|
|||||||
token, err := github.Config.Exchange(github.Context, code, oauth2.VerifierOption(github.Verifier))
|
token, err := github.Config.Exchange(github.Context, code, oauth2.VerifierOption(github.Verifier))
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
github.Token = token
|
github.Token = token
|
||||||
@@ -83,12 +84,23 @@ func (github *GithubOAuthService) Userinfo() (config.Claims, error) {
|
|||||||
|
|
||||||
client := github.Config.Client(github.Context, github.Token)
|
client := github.Config.Client(github.Context, github.Token)
|
||||||
|
|
||||||
res, err := client.Get("https://api.github.com/user")
|
req, err := http.NewRequest("GET", "https://api.github.com/user", nil)
|
||||||
|
if err != nil {
|
||||||
|
return user, err
|
||||||
|
}
|
||||||
|
|
||||||
|
req.Header.Set("Accept", "application/vnd.github+json")
|
||||||
|
|
||||||
|
res, err := client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return user, err
|
return user, err
|
||||||
}
|
}
|
||||||
defer res.Body.Close()
|
defer res.Body.Close()
|
||||||
|
|
||||||
|
if res.StatusCode < 200 || res.StatusCode >= 300 {
|
||||||
|
return user, fmt.Errorf("request failed with status: %s", res.Status)
|
||||||
|
}
|
||||||
|
|
||||||
body, err := io.ReadAll(res.Body)
|
body, err := io.ReadAll(res.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return user, err
|
return user, err
|
||||||
@@ -101,12 +113,23 @@ func (github *GithubOAuthService) Userinfo() (config.Claims, error) {
|
|||||||
return user, err
|
return user, err
|
||||||
}
|
}
|
||||||
|
|
||||||
res, err = client.Get("https://api.github.com/user/emails")
|
req, err = http.NewRequest("GET", "https://api.github.com/user/emails", nil)
|
||||||
|
if err != nil {
|
||||||
|
return user, err
|
||||||
|
}
|
||||||
|
|
||||||
|
req.Header.Set("Accept", "application/vnd.github+json")
|
||||||
|
|
||||||
|
res, err = client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return user, err
|
return user, err
|
||||||
}
|
}
|
||||||
defer res.Body.Close()
|
defer res.Body.Close()
|
||||||
|
|
||||||
|
if res.StatusCode < 200 || res.StatusCode >= 300 {
|
||||||
|
return user, fmt.Errorf("request failed with status: %s", res.Status)
|
||||||
|
}
|
||||||
|
|
||||||
body, err = io.ReadAll(res.Body)
|
body, err = io.ReadAll(res.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return user, err
|
return user, err
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -66,7 +67,7 @@ func (google *GoogleOAuthService) VerifyCode(code string) error {
|
|||||||
token, err := google.Config.Exchange(google.Context, code, oauth2.VerifierOption(google.Verifier))
|
token, err := google.Config.Exchange(google.Context, code, oauth2.VerifierOption(google.Verifier))
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
google.Token = token
|
google.Token = token
|
||||||
@@ -84,6 +85,10 @@ func (google *GoogleOAuthService) Userinfo() (config.Claims, error) {
|
|||||||
}
|
}
|
||||||
defer res.Body.Close()
|
defer res.Body.Close()
|
||||||
|
|
||||||
|
if res.StatusCode < 200 || res.StatusCode >= 300 {
|
||||||
|
return user, fmt.Errorf("request failed with status: %s", res.Status)
|
||||||
|
}
|
||||||
|
|
||||||
body, err := io.ReadAll(res.Body)
|
body, err := io.ReadAll(res.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return config.Claims{}, err
|
return config.Claims{}, err
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package utils
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
|
"net"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
"tinyauth/internal/config"
|
"tinyauth/internal/config"
|
||||||
@@ -12,16 +13,25 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// Get upper domain parses a hostname and returns the upper domain (e.g. sub1.sub2.domain.com -> sub2.domain.com)
|
// Get upper domain parses a hostname and returns the upper domain (e.g. sub1.sub2.domain.com -> sub2.domain.com)
|
||||||
func GetUpperDomain(urlSrc string) (string, error) {
|
func GetUpperDomain(appUrl string) (string, error) {
|
||||||
urlParsed, err := url.Parse(urlSrc)
|
appUrlParsed, err := url.Parse(appUrl)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
urlSplitted := strings.Split(urlParsed.Hostname(), ".")
|
host := appUrlParsed.Hostname()
|
||||||
urlFinal := strings.Join(urlSplitted[1:], ".")
|
|
||||||
|
|
||||||
return urlFinal, nil
|
if netIP := net.ParseIP(host); netIP != nil {
|
||||||
|
return "", errors.New("IP addresses are not allowed")
|
||||||
|
}
|
||||||
|
|
||||||
|
urlParts := strings.Split(host, ".")
|
||||||
|
|
||||||
|
if len(urlParts) < 2 {
|
||||||
|
return "", errors.New("invalid domain, must be at least second level domain")
|
||||||
|
}
|
||||||
|
|
||||||
|
return strings.Join(urlParts[1:], "."), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func ParseFileToLine(content string) string {
|
func ParseFileToLine(content string) string {
|
||||||
@@ -63,8 +73,38 @@ func GetContext(c *gin.Context) (config.UserContext, error) {
|
|||||||
return *userContext, nil
|
return *userContext, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func IsRedirectSafe(redirectURL string, domain string) bool {
|
||||||
|
if redirectURL == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
parsedURL, err := url.Parse(redirectURL)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if !parsedURL.IsAbs() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
upper, err := GetUpperDomain(redirectURL)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if upper != domain {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
func GetLogLevel(level string) zerolog.Level {
|
func GetLogLevel(level string) zerolog.Level {
|
||||||
switch strings.ToLower(level) {
|
switch strings.ToLower(level) {
|
||||||
|
case "trace":
|
||||||
|
return zerolog.TraceLevel
|
||||||
case "debug":
|
case "debug":
|
||||||
return zerolog.DebugLevel
|
return zerolog.DebugLevel
|
||||||
case "info":
|
case "info":
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package utils
|
package utils
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
"tinyauth/internal/config"
|
"tinyauth/internal/config"
|
||||||
|
|
||||||
@@ -26,6 +27,10 @@ func ParseHeaders(headers []string) map[string]string {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
key := SanitizeHeader(strings.TrimSpace(split[0]))
|
key := SanitizeHeader(strings.TrimSpace(split[0]))
|
||||||
|
if strings.ContainsAny(key, " \t") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
key = http.CanonicalHeaderKey(key)
|
||||||
value := SanitizeHeader(strings.TrimSpace(split[1]))
|
value := SanitizeHeader(strings.TrimSpace(split[1]))
|
||||||
headerMap[key] = value
|
headerMap[key] = value
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,6 +9,12 @@ import (
|
|||||||
func ParseUsers(users string) ([]config.User, error) {
|
func ParseUsers(users string) ([]config.User, error) {
|
||||||
var usersParsed []config.User
|
var usersParsed []config.User
|
||||||
|
|
||||||
|
users = strings.TrimSpace(users)
|
||||||
|
|
||||||
|
if users == "" {
|
||||||
|
return []config.User{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
userList := strings.Split(users, ",")
|
userList := strings.Split(users, ",")
|
||||||
|
|
||||||
if len(userList) == 0 {
|
if len(userList) == 0 {
|
||||||
@@ -16,7 +22,10 @@ func ParseUsers(users string) ([]config.User, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, user := range userList {
|
for _, user := range userList {
|
||||||
parsed, err := ParseUser(user)
|
if strings.TrimSpace(user) == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
parsed, err := ParseUser(strings.TrimSpace(user))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return []config.User{}, err
|
return []config.User{}, err
|
||||||
}
|
}
|
||||||
@@ -39,12 +48,13 @@ func GetUsers(conf string, file string) ([]config.User, error) {
|
|||||||
|
|
||||||
if file != "" {
|
if file != "" {
|
||||||
contents, err := ReadFile(file)
|
contents, err := ReadFile(file)
|
||||||
if err == nil {
|
if err != nil {
|
||||||
if users != "" {
|
return []config.User{}, err
|
||||||
users += ","
|
|
||||||
}
|
|
||||||
users += ParseFileToLine(contents)
|
|
||||||
}
|
}
|
||||||
|
if users != "" {
|
||||||
|
users += ","
|
||||||
|
}
|
||||||
|
users += ParseFileToLine(contents)
|
||||||
}
|
}
|
||||||
|
|
||||||
return ParseUsers(users)
|
return ParseUsers(users)
|
||||||
|
|||||||
2
main.go
2
main.go
@@ -10,6 +10,6 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr, TimeFormat: time.RFC3339}).With().Timestamp().Caller().Logger().Level(zerolog.FatalLevel)
|
log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr, TimeFormat: time.RFC3339}).With().Timestamp().Caller().Logger()
|
||||||
cmd.Execute()
|
cmd.Execute()
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user