fix: coderabbit suggestions

This commit is contained in:
Stavros
2025-08-26 14:31:09 +03:00
parent d3c40bb366
commit a5e1ae096b
19 changed files with 178 additions and 47 deletions

View File

@@ -82,7 +82,7 @@ func init() {
{"app-url", "", "The Tinyauth URL."},
{"users", "", "Comma separated list of users in the format username:hash."},
{"users-file", "", "Path to a file containing users in the format username:hash."},
{"cookie-secure", false, "Send cookie over secure connection only."},
{"secure-cookie", false, "Send cookie over secure connection only."},
{"github-client-id", "", "Github OAuth client ID."},
{"github-client-secret", "", "Github OAuth client secret."},
{"github-client-secret-file", "", "Github OAuth client secret file."},

View File

@@ -7,4 +7,4 @@ import (
// Frontend assets
//
//go:embed dist
var FontendAssets embed.FS
var FrontendAssets embed.FS

View File

@@ -164,13 +164,13 @@ func (app *BootstrapApp) Setup() error {
log.Debug().Str("middleware", fmt.Sprintf("%T", middleware)).Msg("Initializing middleware")
err := middleware.Init()
if err != nil {
return fmt.Errorf("failed to initialize %s middleware: %T", middleware, err)
return fmt.Errorf("failed to initialize middleware %T: %w", middleware, err)
}
engine.Use(middleware.Middleware())
}
// Create routers
mainRouter := engine.Group("/")
mainRouter := engine.Group("")
apiRouter := engine.Group("/api")
// Create controllers
@@ -190,6 +190,7 @@ func (app *BootstrapApp) Setup() error {
SecureCookie: app.Config.SecureCookie,
CSRFCookieName: csrfCookieName,
RedirectCookieName: redirectCookieName,
Domain: domain,
}, apiRouter, authService, oauthBrokerService)
proxyController := controller.NewProxyController(controller.ProxyControllerConfig{

View File

@@ -65,10 +65,10 @@ type OAuthLabels struct {
type BasicLabels struct {
Username string
Password PassowrdLabels
Password PasswordLabels
}
type PassowrdLabels struct {
type PasswordLabels struct {
Plain string
File string
}

View File

@@ -23,6 +23,7 @@ type OAuthControllerConfig struct {
RedirectCookieName string
SecureCookie bool
AppURL string
Domain string
}
type OAuthController struct {
@@ -77,7 +78,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
redirectURI := c.Query("redirect_uri")
if redirectURI != "" {
if redirectURI != "" && utils.IsRedirectSafe(redirectURI, controller.Config.Domain) {
log.Debug().Msg("Setting redirect URI cookie")
c.SetCookie(controller.Config.RedirectCookieName, redirectURI, int(time.Hour.Seconds()), "/", "", controller.Config.SecureCookie, true)
}
@@ -178,7 +179,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
redirectURI, err := c.Cookie(controller.Config.RedirectCookieName)
if err != nil {
if err != nil || !utils.IsRedirectSafe(redirectURI, controller.Config.Domain) {
log.Debug().Msg("No redirect URI cookie found, redirecting to app root")
c.Redirect(http.StatusTemporaryRedirect, controller.Config.AppURL)
return
@@ -195,5 +196,5 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
}
c.SetCookie(controller.Config.RedirectCookieName, "", -1, "/", "", controller.Config.SecureCookie, true)
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/login?%s", controller.Config.AppURL, queries.Encode()))
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/continue?%s", controller.Config.AppURL, queries.Encode()))
}

View File

@@ -128,6 +128,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
if err != nil {
log.Error().Err(err).Msg("Failed to encode unauthorized query")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.Config.AppURL))
return
}
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", controller.Config.AppURL, queries.Encode()))
@@ -212,9 +213,9 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
})
if userContext.OAuth {
queries.Set("username", userContext.Username)
} else {
queries.Set("username", userContext.Email)
} else {
queries.Set("username", userContext.Username)
}
if err != nil {
@@ -247,9 +248,9 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
})
if userContext.OAuth {
queries.Set("username", userContext.Username)
} else {
queries.Set("username", userContext.Email)
} else {
queries.Set("username", userContext.Username)
}
if err != nil {

View File

@@ -11,14 +11,18 @@ type ResourcesControllerConfig struct {
}
type ResourcesController struct {
Config ResourcesControllerConfig
Router *gin.RouterGroup
Config ResourcesControllerConfig
Router *gin.RouterGroup
FileServer http.Handler
}
func NewResourcesController(config ResourcesControllerConfig, router *gin.RouterGroup) *ResourcesController {
fileServer := http.StripPrefix("/resources", http.FileServer(http.Dir(config.ResourcesDir)))
return &ResourcesController{
Config: config,
Router: router,
Config: config,
Router: router,
FileServer: fileServer,
}
}
@@ -27,6 +31,12 @@ func (controller *ResourcesController) SetupRoutes() {
}
func (controller *ResourcesController) resourcesHandler(c *gin.Context) {
fileServer := http.StripPrefix("/resources", http.FileServer(http.Dir(controller.Config.ResourcesDir)))
fileServer.ServeHTTP(c.Writer, c.Request)
if controller.Config.ResourcesDir == "" {
c.JSON(404, gin.H{
"status": 404,
"message": "Resources not found",
})
return
}
controller.FileServer.ServeHTTP(c.Writer, c.Request)
}

View File

@@ -112,7 +112,7 @@ func (controller *UserController) loginHandler(c *gin.Context) {
if user.TotpSecret != "" {
log.Debug().Str("username", req.Username).Msg("User has TOTP enabled, requiring TOTP verification")
controller.Auth.CreateSessionCookie(c, &config.SessionCookie{
err := controller.Auth.CreateSessionCookie(c, &config.SessionCookie{
Username: user.Username,
Name: utils.Capitalize(req.Username),
Email: fmt.Sprintf("%s@%s", strings.ToLower(req.Username), controller.Config.Domain),
@@ -120,6 +120,15 @@ func (controller *UserController) loginHandler(c *gin.Context) {
TotpPending: true,
})
if err != nil {
log.Error().Err(err).Msg("Failed to create session cookie")
c.JSON(500, gin.H{
"status": 500,
"message": "Internal Server Error",
})
return
}
c.JSON(200, gin.H{
"status": 200,
"message": "TOTP required",
@@ -129,13 +138,22 @@ func (controller *UserController) loginHandler(c *gin.Context) {
}
}
controller.Auth.CreateSessionCookie(c, &config.SessionCookie{
err = controller.Auth.CreateSessionCookie(c, &config.SessionCookie{
Username: req.Username,
Name: utils.Capitalize(req.Username),
Email: fmt.Sprintf("%s@%s", strings.ToLower(req.Username), controller.Config.Domain),
Provider: "username",
})
if err != nil {
log.Error().Err(err).Msg("Failed to create session cookie")
c.JSON(500, gin.H{
"status": 500,
"message": "Internal Server Error",
})
return
}
c.JSON(200, gin.H{
"status": 200,
"message": "Login successful",
@@ -144,7 +162,9 @@ func (controller *UserController) loginHandler(c *gin.Context) {
func (controller *UserController) logoutHandler(c *gin.Context) {
log.Debug().Msg("Logout request received")
controller.Auth.DeleteSessionCookie(c)
c.JSON(200, gin.H{
"status": 200,
"message": "Logout successful",
@@ -175,8 +195,8 @@ func (controller *UserController) totpHandler(c *gin.Context) {
return
}
if !context.IsLoggedIn {
log.Warn().Msg("TOTP attempt without being logged in")
if !context.TotpPending {
log.Warn().Msg("TOTP attempt without a pending TOTP session")
c.JSON(401, gin.H{
"status": 401,
"message": "Unauthorized",
@@ -223,13 +243,22 @@ func (controller *UserController) totpHandler(c *gin.Context) {
controller.Auth.RecordLoginAttempt(rateIdentifier, true)
controller.Auth.CreateSessionCookie(c, &config.SessionCookie{
err = controller.Auth.CreateSessionCookie(c, &config.SessionCookie{
Username: user.Username,
Name: utils.Capitalize(user.Username),
Email: fmt.Sprintf("%s@%s", strings.ToLower(user.Username), controller.Config.Domain),
Provider: "username",
})
if err != nil {
log.Error().Err(err).Msg("Failed to create session cookie")
c.JSON(500, gin.H{
"status": 500,
"message": "Internal Server Error",
})
return
}
c.JSON(200, gin.H{
"status": 200,
"message": "Login successful",

View File

@@ -79,6 +79,7 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
if !exists {
log.Debug().Msg("OAuth provider from session cookie not found")
m.Auth.DeleteSessionCookie(c)
goto basic
}

View File

@@ -20,10 +20,10 @@ func NewUIMiddleware() *UIMiddleware {
}
func (m *UIMiddleware) Init() error {
ui, err := fs.Sub(assets.FontendAssets, "dist")
ui, err := fs.Sub(assets.FrontendAssets, "dist")
if err != nil {
return nil
return err
}
m.UIFS = ui

View File

@@ -10,8 +10,8 @@ import (
var (
loggerSkipPathsPrefix = []string{
"GET /api/healthcheck",
"HEAD /api/healthcheck",
"GET /api/health",
"HEAD /api/health",
"GET /favicon.ico",
}
)

View File

@@ -71,9 +71,9 @@ func (auth *AuthService) GetSession(c *gin.Context) (*sessions.Session, error) {
// If there was an error getting the session, it might be invalid so let's clear it and retry
if err != nil {
log.Debug().Err(err).Msg("Error getting session, clearing cookie and retrying")
log.Debug().Err(err).Msg("Error getting session, creating a new one")
c.SetCookie(auth.Config.SessionCookieName, "", -1, "/", fmt.Sprintf(".%s", auth.Config.Domain), auth.Config.SecureCookie, true)
session, err = auth.Store.Get(c.Request, auth.Config.SessionCookieName)
session, err = auth.Store.New(c.Request, auth.Config.SessionCookieName)
if err != nil {
return nil, err
}

View File

@@ -6,6 +6,7 @@ import (
"crypto/tls"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"tinyauth/internal/config"
@@ -76,7 +77,7 @@ func (generic *GenericOAuthService) VerifyCode(code string) error {
token, err := generic.Config.Exchange(generic.Context, code, oauth2.VerifierOption(generic.Verifier))
if err != nil {
return nil
return err
}
generic.Token = token
@@ -94,6 +95,10 @@ func (generic *GenericOAuthService) Userinfo() (config.Claims, error) {
}
defer res.Body.Close()
if res.StatusCode < 200 || res.StatusCode >= 300 {
return user, fmt.Errorf("request failed with status: %s", res.Status)
}
body, err := io.ReadAll(res.Body)
if err != nil {
return user, err

View File

@@ -6,6 +6,7 @@ import (
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"tinyauth/internal/config"
@@ -71,7 +72,7 @@ func (github *GithubOAuthService) VerifyCode(code string) error {
token, err := github.Config.Exchange(github.Context, code, oauth2.VerifierOption(github.Verifier))
if err != nil {
return nil
return err
}
github.Token = token
@@ -83,12 +84,23 @@ func (github *GithubOAuthService) Userinfo() (config.Claims, error) {
client := github.Config.Client(github.Context, github.Token)
res, err := client.Get("https://api.github.com/user")
req, err := http.NewRequest("GET", "https://api.github.com/user", nil)
if err != nil {
return user, err
}
req.Header.Set("Accept", "application/vnd.github+json")
res, err := client.Do(req)
if err != nil {
return user, err
}
defer res.Body.Close()
if res.StatusCode < 200 || res.StatusCode >= 300 {
return user, fmt.Errorf("request failed with status: %s", res.Status)
}
body, err := io.ReadAll(res.Body)
if err != nil {
return user, err
@@ -101,12 +113,23 @@ func (github *GithubOAuthService) Userinfo() (config.Claims, error) {
return user, err
}
res, err = client.Get("https://api.github.com/user/emails")
req, err = http.NewRequest("GET", "https://api.github.com/user/emails", nil)
if err != nil {
return user, err
}
req.Header.Set("Accept", "application/vnd.github+json")
res, err = client.Do(req)
if err != nil {
return user, err
}
defer res.Body.Close()
if res.StatusCode < 200 || res.StatusCode >= 300 {
return user, fmt.Errorf("request failed with status: %s", res.Status)
}
body, err = io.ReadAll(res.Body)
if err != nil {
return user, err

View File

@@ -5,6 +5,7 @@ import (
"crypto/rand"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
@@ -66,7 +67,7 @@ func (google *GoogleOAuthService) VerifyCode(code string) error {
token, err := google.Config.Exchange(google.Context, code, oauth2.VerifierOption(google.Verifier))
if err != nil {
return nil
return err
}
google.Token = token
@@ -84,6 +85,10 @@ func (google *GoogleOAuthService) Userinfo() (config.Claims, error) {
}
defer res.Body.Close()
if res.StatusCode < 200 || res.StatusCode >= 300 {
return user, fmt.Errorf("request failed with status: %s", res.Status)
}
body, err := io.ReadAll(res.Body)
if err != nil {
return config.Claims{}, err

View File

@@ -2,6 +2,7 @@ package utils
import (
"errors"
"net"
"net/url"
"strings"
"tinyauth/internal/config"
@@ -12,16 +13,25 @@ import (
)
// Get upper domain parses a hostname and returns the upper domain (e.g. sub1.sub2.domain.com -> sub2.domain.com)
func GetUpperDomain(urlSrc string) (string, error) {
urlParsed, err := url.Parse(urlSrc)
func GetUpperDomain(appUrl string) (string, error) {
appUrlParsed, err := url.Parse(appUrl)
if err != nil {
return "", err
}
urlSplitted := strings.Split(urlParsed.Hostname(), ".")
urlFinal := strings.Join(urlSplitted[1:], ".")
host := appUrlParsed.Hostname()
return urlFinal, nil
if netIP := net.ParseIP(host); netIP != nil {
return "", errors.New("IP addresses are not allowed")
}
urlParts := strings.Split(host, ".")
if len(urlParts) < 2 {
return "", errors.New("invalid domain, must be at least second level domain")
}
return strings.Join(urlParts[1:], "."), nil
}
func ParseFileToLine(content string) string {
@@ -63,8 +73,38 @@ func GetContext(c *gin.Context) (config.UserContext, error) {
return *userContext, nil
}
func IsRedirectSafe(redirectURL string, domain string) bool {
if redirectURL == "" {
return false
}
parsedURL, err := url.Parse(redirectURL)
if err != nil {
return false
}
if !parsedURL.IsAbs() {
return false
}
upper, err := GetUpperDomain(redirectURL)
if err != nil {
return false
}
if upper != domain {
return false
}
return true
}
func GetLogLevel(level string) zerolog.Level {
switch strings.ToLower(level) {
case "trace":
return zerolog.TraceLevel
case "debug":
return zerolog.DebugLevel
case "info":

View File

@@ -1,6 +1,7 @@
package utils
import (
"net/http"
"strings"
"tinyauth/internal/config"
@@ -26,6 +27,10 @@ func ParseHeaders(headers []string) map[string]string {
continue
}
key := SanitizeHeader(strings.TrimSpace(split[0]))
if strings.ContainsAny(key, " \t") {
continue
}
key = http.CanonicalHeaderKey(key)
value := SanitizeHeader(strings.TrimSpace(split[1]))
headerMap[key] = value
}

View File

@@ -9,6 +9,12 @@ import (
func ParseUsers(users string) ([]config.User, error) {
var usersParsed []config.User
users = strings.TrimSpace(users)
if users == "" {
return []config.User{}, nil
}
userList := strings.Split(users, ",")
if len(userList) == 0 {
@@ -16,7 +22,10 @@ func ParseUsers(users string) ([]config.User, error) {
}
for _, user := range userList {
parsed, err := ParseUser(user)
if strings.TrimSpace(user) == "" {
continue
}
parsed, err := ParseUser(strings.TrimSpace(user))
if err != nil {
return []config.User{}, err
}
@@ -39,12 +48,13 @@ func GetUsers(conf string, file string) ([]config.User, error) {
if file != "" {
contents, err := ReadFile(file)
if err == nil {
if users != "" {
users += ","
}
users += ParseFileToLine(contents)
if err != nil {
return []config.User{}, err
}
if users != "" {
users += ","
}
users += ParseFileToLine(contents)
}
return ParseUsers(users)

View File

@@ -10,6 +10,6 @@ import (
)
func main() {
log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr, TimeFormat: time.RFC3339}).With().Timestamp().Caller().Logger().Level(zerolog.FatalLevel)
log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr, TimeFormat: time.RFC3339}).With().Timestamp().Caller().Logger()
cmd.Execute()
}