fix: further coderabbit suggestions

This commit is contained in:
Stavros
2025-08-26 14:49:55 +03:00
parent a5e1ae096b
commit a1b6ecdd5d
9 changed files with 33 additions and 18 deletions

View File

@@ -8,7 +8,7 @@ import (
"tinyauth/internal/config" "tinyauth/internal/config"
"tinyauth/internal/utils" "tinyauth/internal/utils"
"github.com/go-playground/validator" "github.com/go-playground/validator/v10"
"github.com/rs/zerolog" "github.com/rs/zerolog"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"github.com/spf13/cobra" "github.com/spf13/cobra"
@@ -34,9 +34,9 @@ var rootCmd = &cobra.Command{
conf.GenericClientSecret = utils.GetSecret(conf.GenericClientSecret, conf.GenericClientSecretFile) conf.GenericClientSecret = utils.GetSecret(conf.GenericClientSecret, conf.GenericClientSecretFile)
// Validate config // Validate config
validator := validator.New() v := validator.New()
err = validator.Struct(conf) err = v.Struct(conf)
if err != nil { if err != nil {
log.Fatal().Err(err).Msg("Invalid config") log.Fatal().Err(err).Msg("Invalid config")
} }

View File

@@ -74,13 +74,13 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
state := service.GenerateState() state := service.GenerateState()
authURL := service.GetAuthURL(state) authURL := service.GetAuthURL(state)
c.SetCookie(controller.Config.CSRFCookieName, state, int(time.Hour.Seconds()), "/", "", controller.Config.SecureCookie, true) c.SetCookie(controller.Config.CSRFCookieName, state, int(time.Hour.Seconds()), "/", fmt.Sprintf(".%s", controller.Config.Domain), controller.Config.SecureCookie, true)
redirectURI := c.Query("redirect_uri") redirectURI := c.Query("redirect_uri")
if redirectURI != "" && utils.IsRedirectSafe(redirectURI, controller.Config.Domain) { if redirectURI != "" && utils.IsRedirectSafe(redirectURI, controller.Config.Domain) {
log.Debug().Msg("Setting redirect URI cookie") log.Debug().Msg("Setting redirect URI cookie")
c.SetCookie(controller.Config.RedirectCookieName, redirectURI, int(time.Hour.Seconds()), "/", "", controller.Config.SecureCookie, true) c.SetCookie(controller.Config.RedirectCookieName, redirectURI, int(time.Hour.Seconds()), "/", fmt.Sprintf(".%s", controller.Config.Domain), controller.Config.SecureCookie, true)
} }
c.JSON(200, gin.H{ c.JSON(200, gin.H{
@@ -112,7 +112,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
return return
} }
c.SetCookie(controller.Config.CSRFCookieName, "", -1, "/", "", controller.Config.SecureCookie, true) c.SetCookie(controller.Config.CSRFCookieName, "", -1, "/", fmt.Sprintf(".%s", controller.Config.Domain), controller.Config.SecureCookie, true)
code := c.Query("code") code := c.Query("code")
service, exists := controller.Broker.GetService(req.Provider) service, exists := controller.Broker.GetService(req.Provider)
@@ -195,6 +195,6 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
return return
} }
c.SetCookie(controller.Config.RedirectCookieName, "", -1, "/", "", controller.Config.SecureCookie, true) c.SetCookie(controller.Config.RedirectCookieName, "", -1, "/", fmt.Sprintf(".%s", controller.Config.Domain), controller.Config.SecureCookie, true)
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/continue?%s", controller.Config.AppURL, queries.Encode())) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/continue?%s", controller.Config.AppURL, queries.Encode()))
} }

View File

@@ -49,7 +49,7 @@ func (controller *UserController) SetupRoutes() {
func (controller *UserController) loginHandler(c *gin.Context) { func (controller *UserController) loginHandler(c *gin.Context) {
var req LoginRequest var req LoginRequest
err := c.BindJSON(&req) err := c.ShouldBindJSON(&req)
if err != nil { if err != nil {
log.Error().Err(err).Msg("Failed to bind JSON") log.Error().Err(err).Msg("Failed to bind JSON")
c.JSON(400, gin.H{ c.JSON(400, gin.H{
@@ -174,7 +174,7 @@ func (controller *UserController) logoutHandler(c *gin.Context) {
func (controller *UserController) totpHandler(c *gin.Context) { func (controller *UserController) totpHandler(c *gin.Context) {
var req TotpRequest var req TotpRequest
err := c.BindJSON(&req) err := c.ShouldBindJSON(&req)
if err != nil { if err != nil {
log.Error().Err(err).Msg("Failed to bind JSON") log.Error().Err(err).Msg("Failed to bind JSON")
c.JSON(400, gin.H{ c.JSON(400, gin.H{

View File

@@ -266,6 +266,9 @@ func (auth *AuthService) DeleteSessionCookie(c *gin.Context) error {
return err return err
} }
// Clear the cookie in the browser
c.SetCookie(auth.Config.SessionCookieName, "", -1, "/", fmt.Sprintf(".%s", auth.Config.Domain), auth.Config.SecureCookie, true)
return nil return nil
} }

View File

@@ -9,6 +9,7 @@ import (
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"time"
"tinyauth/internal/config" "tinyauth/internal/config"
"golang.org/x/oauth2" "golang.org/x/oauth2"
@@ -64,8 +65,11 @@ func (generic *GenericOAuthService) Init() error {
func (generic *GenericOAuthService) GenerateState() string { func (generic *GenericOAuthService) GenerateState() string {
b := make([]byte, 128) b := make([]byte, 128)
rand.Read(b) _, err := rand.Read(b)
state := base64.URLEncoding.EncodeToString(b) if err != nil {
return base64.RawURLEncoding.EncodeToString(fmt.Appendf(nil, "state-%d", time.Now().UnixNano()))
}
state := base64.RawURLEncoding.EncodeToString(b)
return state return state
} }

View File

@@ -9,6 +9,7 @@ import (
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"time"
"tinyauth/internal/config" "tinyauth/internal/config"
"golang.org/x/oauth2" "golang.org/x/oauth2"
@@ -59,8 +60,11 @@ func (github *GithubOAuthService) Init() error {
func (github *GithubOAuthService) GenerateState() string { func (github *GithubOAuthService) GenerateState() string {
b := make([]byte, 128) b := make([]byte, 128)
rand.Read(b) _, err := rand.Read(b)
state := base64.URLEncoding.EncodeToString(b) if err != nil {
return base64.RawURLEncoding.EncodeToString(fmt.Appendf(nil, "state-%d", time.Now().UnixNano()))
}
state := base64.RawURLEncoding.EncodeToString(b)
return state return state
} }

View File

@@ -9,6 +9,7 @@ import (
"io" "io"
"net/http" "net/http"
"strings" "strings"
"time"
"tinyauth/internal/config" "tinyauth/internal/config"
"golang.org/x/oauth2" "golang.org/x/oauth2"
@@ -54,8 +55,11 @@ func (google *GoogleOAuthService) Init() error {
func (oauth *GoogleOAuthService) GenerateState() string { func (oauth *GoogleOAuthService) GenerateState() string {
b := make([]byte, 128) b := make([]byte, 128)
rand.Read(b) _, err := rand.Read(b)
state := base64.URLEncoding.EncodeToString(b) if err != nil {
return base64.RawURLEncoding.EncodeToString(fmt.Appendf(nil, "state-%d", time.Now().UnixNano()))
}
state := base64.RawURLEncoding.EncodeToString(b)
return state return state
} }

View File

@@ -140,7 +140,7 @@ func (ldap *LdapService) reconnect() error {
ldap.Conn.Close() ldap.Conn.Close()
conn, err := ldap.connect() conn, err := ldap.connect()
if err != nil { if err != nil {
return nil, nil return nil, err
} }
return conn, nil return conn, nil
} }

View File

@@ -101,7 +101,7 @@ func CheckFilter(filter string, str string) bool {
return false return false
} }
if re.MatchString(str) { if re.MatchString(strings.TrimSpace(str)) {
return true return true
} }
} }
@@ -109,7 +109,7 @@ func CheckFilter(filter string, str string) bool {
filterSplit := strings.Split(filter, ",") filterSplit := strings.Split(filter, ",")
for _, item := range filterSplit { for _, item := range filterSplit {
if strings.TrimSpace(item) == str { if strings.TrimSpace(item) == strings.TrimSpace(str) {
return true return true
} }
} }