mirror of
https://github.com/steveiliop56/tinyauth.git
synced 2025-10-28 20:55:42 +00:00
fix: coderabbit suggestions
This commit is contained in:
@@ -82,7 +82,7 @@ func init() {
|
||||
{"app-url", "", "The Tinyauth URL."},
|
||||
{"users", "", "Comma separated list of users in the format username:hash."},
|
||||
{"users-file", "", "Path to a file containing users in the format username:hash."},
|
||||
{"cookie-secure", false, "Send cookie over secure connection only."},
|
||||
{"secure-cookie", false, "Send cookie over secure connection only."},
|
||||
{"github-client-id", "", "Github OAuth client ID."},
|
||||
{"github-client-secret", "", "Github OAuth client secret."},
|
||||
{"github-client-secret-file", "", "Github OAuth client secret file."},
|
||||
|
||||
@@ -7,4 +7,4 @@ import (
|
||||
// Frontend assets
|
||||
//
|
||||
//go:embed dist
|
||||
var FontendAssets embed.FS
|
||||
var FrontendAssets embed.FS
|
||||
|
||||
@@ -164,13 +164,13 @@ func (app *BootstrapApp) Setup() error {
|
||||
log.Debug().Str("middleware", fmt.Sprintf("%T", middleware)).Msg("Initializing middleware")
|
||||
err := middleware.Init()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to initialize %s middleware: %T", middleware, err)
|
||||
return fmt.Errorf("failed to initialize middleware %T: %w", middleware, err)
|
||||
}
|
||||
engine.Use(middleware.Middleware())
|
||||
}
|
||||
|
||||
// Create routers
|
||||
mainRouter := engine.Group("/")
|
||||
mainRouter := engine.Group("")
|
||||
apiRouter := engine.Group("/api")
|
||||
|
||||
// Create controllers
|
||||
@@ -190,6 +190,7 @@ func (app *BootstrapApp) Setup() error {
|
||||
SecureCookie: app.Config.SecureCookie,
|
||||
CSRFCookieName: csrfCookieName,
|
||||
RedirectCookieName: redirectCookieName,
|
||||
Domain: domain,
|
||||
}, apiRouter, authService, oauthBrokerService)
|
||||
|
||||
proxyController := controller.NewProxyController(controller.ProxyControllerConfig{
|
||||
|
||||
@@ -65,10 +65,10 @@ type OAuthLabels struct {
|
||||
|
||||
type BasicLabels struct {
|
||||
Username string
|
||||
Password PassowrdLabels
|
||||
Password PasswordLabels
|
||||
}
|
||||
|
||||
type PassowrdLabels struct {
|
||||
type PasswordLabels struct {
|
||||
Plain string
|
||||
File string
|
||||
}
|
||||
|
||||
@@ -23,6 +23,7 @@ type OAuthControllerConfig struct {
|
||||
RedirectCookieName string
|
||||
SecureCookie bool
|
||||
AppURL string
|
||||
Domain string
|
||||
}
|
||||
|
||||
type OAuthController struct {
|
||||
@@ -77,7 +78,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
|
||||
|
||||
redirectURI := c.Query("redirect_uri")
|
||||
|
||||
if redirectURI != "" {
|
||||
if redirectURI != "" && utils.IsRedirectSafe(redirectURI, controller.Config.Domain) {
|
||||
log.Debug().Msg("Setting redirect URI cookie")
|
||||
c.SetCookie(controller.Config.RedirectCookieName, redirectURI, int(time.Hour.Seconds()), "/", "", controller.Config.SecureCookie, true)
|
||||
}
|
||||
@@ -178,7 +179,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
|
||||
|
||||
redirectURI, err := c.Cookie(controller.Config.RedirectCookieName)
|
||||
|
||||
if err != nil {
|
||||
if err != nil || !utils.IsRedirectSafe(redirectURI, controller.Config.Domain) {
|
||||
log.Debug().Msg("No redirect URI cookie found, redirecting to app root")
|
||||
c.Redirect(http.StatusTemporaryRedirect, controller.Config.AppURL)
|
||||
return
|
||||
@@ -195,5 +196,5 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
|
||||
}
|
||||
|
||||
c.SetCookie(controller.Config.RedirectCookieName, "", -1, "/", "", controller.Config.SecureCookie, true)
|
||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/login?%s", controller.Config.AppURL, queries.Encode()))
|
||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/continue?%s", controller.Config.AppURL, queries.Encode()))
|
||||
}
|
||||
|
||||
@@ -128,6 +128,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to encode unauthorized query")
|
||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.Config.AppURL))
|
||||
return
|
||||
}
|
||||
|
||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", controller.Config.AppURL, queries.Encode()))
|
||||
@@ -212,9 +213,9 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
||||
})
|
||||
|
||||
if userContext.OAuth {
|
||||
queries.Set("username", userContext.Username)
|
||||
} else {
|
||||
queries.Set("username", userContext.Email)
|
||||
} else {
|
||||
queries.Set("username", userContext.Username)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
@@ -247,9 +248,9 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
||||
})
|
||||
|
||||
if userContext.OAuth {
|
||||
queries.Set("username", userContext.Username)
|
||||
} else {
|
||||
queries.Set("username", userContext.Email)
|
||||
} else {
|
||||
queries.Set("username", userContext.Username)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
|
||||
@@ -13,12 +13,16 @@ type ResourcesControllerConfig struct {
|
||||
type ResourcesController struct {
|
||||
Config ResourcesControllerConfig
|
||||
Router *gin.RouterGroup
|
||||
FileServer http.Handler
|
||||
}
|
||||
|
||||
func NewResourcesController(config ResourcesControllerConfig, router *gin.RouterGroup) *ResourcesController {
|
||||
fileServer := http.StripPrefix("/resources", http.FileServer(http.Dir(config.ResourcesDir)))
|
||||
|
||||
return &ResourcesController{
|
||||
Config: config,
|
||||
Router: router,
|
||||
FileServer: fileServer,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -27,6 +31,12 @@ func (controller *ResourcesController) SetupRoutes() {
|
||||
}
|
||||
|
||||
func (controller *ResourcesController) resourcesHandler(c *gin.Context) {
|
||||
fileServer := http.StripPrefix("/resources", http.FileServer(http.Dir(controller.Config.ResourcesDir)))
|
||||
fileServer.ServeHTTP(c.Writer, c.Request)
|
||||
if controller.Config.ResourcesDir == "" {
|
||||
c.JSON(404, gin.H{
|
||||
"status": 404,
|
||||
"message": "Resources not found",
|
||||
})
|
||||
return
|
||||
}
|
||||
controller.FileServer.ServeHTTP(c.Writer, c.Request)
|
||||
}
|
||||
|
||||
@@ -112,7 +112,7 @@ func (controller *UserController) loginHandler(c *gin.Context) {
|
||||
if user.TotpSecret != "" {
|
||||
log.Debug().Str("username", req.Username).Msg("User has TOTP enabled, requiring TOTP verification")
|
||||
|
||||
controller.Auth.CreateSessionCookie(c, &config.SessionCookie{
|
||||
err := controller.Auth.CreateSessionCookie(c, &config.SessionCookie{
|
||||
Username: user.Username,
|
||||
Name: utils.Capitalize(req.Username),
|
||||
Email: fmt.Sprintf("%s@%s", strings.ToLower(req.Username), controller.Config.Domain),
|
||||
@@ -120,6 +120,15 @@ func (controller *UserController) loginHandler(c *gin.Context) {
|
||||
TotpPending: true,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to create session cookie")
|
||||
c.JSON(500, gin.H{
|
||||
"status": 500,
|
||||
"message": "Internal Server Error",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(200, gin.H{
|
||||
"status": 200,
|
||||
"message": "TOTP required",
|
||||
@@ -129,13 +138,22 @@ func (controller *UserController) loginHandler(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
controller.Auth.CreateSessionCookie(c, &config.SessionCookie{
|
||||
err = controller.Auth.CreateSessionCookie(c, &config.SessionCookie{
|
||||
Username: req.Username,
|
||||
Name: utils.Capitalize(req.Username),
|
||||
Email: fmt.Sprintf("%s@%s", strings.ToLower(req.Username), controller.Config.Domain),
|
||||
Provider: "username",
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to create session cookie")
|
||||
c.JSON(500, gin.H{
|
||||
"status": 500,
|
||||
"message": "Internal Server Error",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(200, gin.H{
|
||||
"status": 200,
|
||||
"message": "Login successful",
|
||||
@@ -144,7 +162,9 @@ func (controller *UserController) loginHandler(c *gin.Context) {
|
||||
|
||||
func (controller *UserController) logoutHandler(c *gin.Context) {
|
||||
log.Debug().Msg("Logout request received")
|
||||
|
||||
controller.Auth.DeleteSessionCookie(c)
|
||||
|
||||
c.JSON(200, gin.H{
|
||||
"status": 200,
|
||||
"message": "Logout successful",
|
||||
@@ -175,8 +195,8 @@ func (controller *UserController) totpHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if !context.IsLoggedIn {
|
||||
log.Warn().Msg("TOTP attempt without being logged in")
|
||||
if !context.TotpPending {
|
||||
log.Warn().Msg("TOTP attempt without a pending TOTP session")
|
||||
c.JSON(401, gin.H{
|
||||
"status": 401,
|
||||
"message": "Unauthorized",
|
||||
@@ -223,13 +243,22 @@ func (controller *UserController) totpHandler(c *gin.Context) {
|
||||
|
||||
controller.Auth.RecordLoginAttempt(rateIdentifier, true)
|
||||
|
||||
controller.Auth.CreateSessionCookie(c, &config.SessionCookie{
|
||||
err = controller.Auth.CreateSessionCookie(c, &config.SessionCookie{
|
||||
Username: user.Username,
|
||||
Name: utils.Capitalize(user.Username),
|
||||
Email: fmt.Sprintf("%s@%s", strings.ToLower(user.Username), controller.Config.Domain),
|
||||
Provider: "username",
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to create session cookie")
|
||||
c.JSON(500, gin.H{
|
||||
"status": 500,
|
||||
"message": "Internal Server Error",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(200, gin.H{
|
||||
"status": 200,
|
||||
"message": "Login successful",
|
||||
|
||||
@@ -79,6 +79,7 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
|
||||
|
||||
if !exists {
|
||||
log.Debug().Msg("OAuth provider from session cookie not found")
|
||||
m.Auth.DeleteSessionCookie(c)
|
||||
goto basic
|
||||
}
|
||||
|
||||
|
||||
@@ -20,10 +20,10 @@ func NewUIMiddleware() *UIMiddleware {
|
||||
}
|
||||
|
||||
func (m *UIMiddleware) Init() error {
|
||||
ui, err := fs.Sub(assets.FontendAssets, "dist")
|
||||
ui, err := fs.Sub(assets.FrontendAssets, "dist")
|
||||
|
||||
if err != nil {
|
||||
return nil
|
||||
return err
|
||||
}
|
||||
|
||||
m.UIFS = ui
|
||||
|
||||
@@ -10,8 +10,8 @@ import (
|
||||
|
||||
var (
|
||||
loggerSkipPathsPrefix = []string{
|
||||
"GET /api/healthcheck",
|
||||
"HEAD /api/healthcheck",
|
||||
"GET /api/health",
|
||||
"HEAD /api/health",
|
||||
"GET /favicon.ico",
|
||||
}
|
||||
)
|
||||
|
||||
@@ -71,9 +71,9 @@ func (auth *AuthService) GetSession(c *gin.Context) (*sessions.Session, error) {
|
||||
|
||||
// If there was an error getting the session, it might be invalid so let's clear it and retry
|
||||
if err != nil {
|
||||
log.Debug().Err(err).Msg("Error getting session, clearing cookie and retrying")
|
||||
log.Debug().Err(err).Msg("Error getting session, creating a new one")
|
||||
c.SetCookie(auth.Config.SessionCookieName, "", -1, "/", fmt.Sprintf(".%s", auth.Config.Domain), auth.Config.SecureCookie, true)
|
||||
session, err = auth.Store.Get(c.Request, auth.Config.SessionCookieName)
|
||||
session, err = auth.Store.New(c.Request, auth.Config.SessionCookieName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"crypto/tls"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"tinyauth/internal/config"
|
||||
@@ -76,7 +77,7 @@ func (generic *GenericOAuthService) VerifyCode(code string) error {
|
||||
token, err := generic.Config.Exchange(generic.Context, code, oauth2.VerifierOption(generic.Verifier))
|
||||
|
||||
if err != nil {
|
||||
return nil
|
||||
return err
|
||||
}
|
||||
|
||||
generic.Token = token
|
||||
@@ -94,6 +95,10 @@ func (generic *GenericOAuthService) Userinfo() (config.Claims, error) {
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
if res.StatusCode < 200 || res.StatusCode >= 300 {
|
||||
return user, fmt.Errorf("request failed with status: %s", res.Status)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
return user, err
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"tinyauth/internal/config"
|
||||
@@ -71,7 +72,7 @@ func (github *GithubOAuthService) VerifyCode(code string) error {
|
||||
token, err := github.Config.Exchange(github.Context, code, oauth2.VerifierOption(github.Verifier))
|
||||
|
||||
if err != nil {
|
||||
return nil
|
||||
return err
|
||||
}
|
||||
|
||||
github.Token = token
|
||||
@@ -83,12 +84,23 @@ func (github *GithubOAuthService) Userinfo() (config.Claims, error) {
|
||||
|
||||
client := github.Config.Client(github.Context, github.Token)
|
||||
|
||||
res, err := client.Get("https://api.github.com/user")
|
||||
req, err := http.NewRequest("GET", "https://api.github.com/user", nil)
|
||||
if err != nil {
|
||||
return user, err
|
||||
}
|
||||
|
||||
req.Header.Set("Accept", "application/vnd.github+json")
|
||||
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
return user, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
if res.StatusCode < 200 || res.StatusCode >= 300 {
|
||||
return user, fmt.Errorf("request failed with status: %s", res.Status)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
return user, err
|
||||
@@ -101,12 +113,23 @@ func (github *GithubOAuthService) Userinfo() (config.Claims, error) {
|
||||
return user, err
|
||||
}
|
||||
|
||||
res, err = client.Get("https://api.github.com/user/emails")
|
||||
req, err = http.NewRequest("GET", "https://api.github.com/user/emails", nil)
|
||||
if err != nil {
|
||||
return user, err
|
||||
}
|
||||
|
||||
req.Header.Set("Accept", "application/vnd.github+json")
|
||||
|
||||
res, err = client.Do(req)
|
||||
if err != nil {
|
||||
return user, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
if res.StatusCode < 200 || res.StatusCode >= 300 {
|
||||
return user, fmt.Errorf("request failed with status: %s", res.Status)
|
||||
}
|
||||
|
||||
body, err = io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
return user, err
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
@@ -66,7 +67,7 @@ func (google *GoogleOAuthService) VerifyCode(code string) error {
|
||||
token, err := google.Config.Exchange(google.Context, code, oauth2.VerifierOption(google.Verifier))
|
||||
|
||||
if err != nil {
|
||||
return nil
|
||||
return err
|
||||
}
|
||||
|
||||
google.Token = token
|
||||
@@ -84,6 +85,10 @@ func (google *GoogleOAuthService) Userinfo() (config.Claims, error) {
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
if res.StatusCode < 200 || res.StatusCode >= 300 {
|
||||
return user, fmt.Errorf("request failed with status: %s", res.Status)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
return config.Claims{}, err
|
||||
|
||||
@@ -2,6 +2,7 @@ package utils
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"net/url"
|
||||
"strings"
|
||||
"tinyauth/internal/config"
|
||||
@@ -12,16 +13,25 @@ import (
|
||||
)
|
||||
|
||||
// Get upper domain parses a hostname and returns the upper domain (e.g. sub1.sub2.domain.com -> sub2.domain.com)
|
||||
func GetUpperDomain(urlSrc string) (string, error) {
|
||||
urlParsed, err := url.Parse(urlSrc)
|
||||
func GetUpperDomain(appUrl string) (string, error) {
|
||||
appUrlParsed, err := url.Parse(appUrl)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
urlSplitted := strings.Split(urlParsed.Hostname(), ".")
|
||||
urlFinal := strings.Join(urlSplitted[1:], ".")
|
||||
host := appUrlParsed.Hostname()
|
||||
|
||||
return urlFinal, nil
|
||||
if netIP := net.ParseIP(host); netIP != nil {
|
||||
return "", errors.New("IP addresses are not allowed")
|
||||
}
|
||||
|
||||
urlParts := strings.Split(host, ".")
|
||||
|
||||
if len(urlParts) < 2 {
|
||||
return "", errors.New("invalid domain, must be at least second level domain")
|
||||
}
|
||||
|
||||
return strings.Join(urlParts[1:], "."), nil
|
||||
}
|
||||
|
||||
func ParseFileToLine(content string) string {
|
||||
@@ -63,8 +73,38 @@ func GetContext(c *gin.Context) (config.UserContext, error) {
|
||||
return *userContext, nil
|
||||
}
|
||||
|
||||
func IsRedirectSafe(redirectURL string, domain string) bool {
|
||||
if redirectURL == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
parsedURL, err := url.Parse(redirectURL)
|
||||
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if !parsedURL.IsAbs() {
|
||||
return false
|
||||
}
|
||||
|
||||
upper, err := GetUpperDomain(redirectURL)
|
||||
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if upper != domain {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func GetLogLevel(level string) zerolog.Level {
|
||||
switch strings.ToLower(level) {
|
||||
case "trace":
|
||||
return zerolog.TraceLevel
|
||||
case "debug":
|
||||
return zerolog.DebugLevel
|
||||
case "info":
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"tinyauth/internal/config"
|
||||
|
||||
@@ -26,6 +27,10 @@ func ParseHeaders(headers []string) map[string]string {
|
||||
continue
|
||||
}
|
||||
key := SanitizeHeader(strings.TrimSpace(split[0]))
|
||||
if strings.ContainsAny(key, " \t") {
|
||||
continue
|
||||
}
|
||||
key = http.CanonicalHeaderKey(key)
|
||||
value := SanitizeHeader(strings.TrimSpace(split[1]))
|
||||
headerMap[key] = value
|
||||
}
|
||||
|
||||
@@ -9,6 +9,12 @@ import (
|
||||
func ParseUsers(users string) ([]config.User, error) {
|
||||
var usersParsed []config.User
|
||||
|
||||
users = strings.TrimSpace(users)
|
||||
|
||||
if users == "" {
|
||||
return []config.User{}, nil
|
||||
}
|
||||
|
||||
userList := strings.Split(users, ",")
|
||||
|
||||
if len(userList) == 0 {
|
||||
@@ -16,7 +22,10 @@ func ParseUsers(users string) ([]config.User, error) {
|
||||
}
|
||||
|
||||
for _, user := range userList {
|
||||
parsed, err := ParseUser(user)
|
||||
if strings.TrimSpace(user) == "" {
|
||||
continue
|
||||
}
|
||||
parsed, err := ParseUser(strings.TrimSpace(user))
|
||||
if err != nil {
|
||||
return []config.User{}, err
|
||||
}
|
||||
@@ -39,13 +48,14 @@ func GetUsers(conf string, file string) ([]config.User, error) {
|
||||
|
||||
if file != "" {
|
||||
contents, err := ReadFile(file)
|
||||
if err == nil {
|
||||
if err != nil {
|
||||
return []config.User{}, err
|
||||
}
|
||||
if users != "" {
|
||||
users += ","
|
||||
}
|
||||
users += ParseFileToLine(contents)
|
||||
}
|
||||
}
|
||||
|
||||
return ParseUsers(users)
|
||||
}
|
||||
|
||||
2
main.go
2
main.go
@@ -10,6 +10,6 @@ import (
|
||||
)
|
||||
|
||||
func main() {
|
||||
log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr, TimeFormat: time.RFC3339}).With().Timestamp().Caller().Logger().Level(zerolog.FatalLevel)
|
||||
log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr, TimeFormat: time.RFC3339}).With().Timestamp().Caller().Logger()
|
||||
cmd.Execute()
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user