diff --git a/cmd/root.go b/cmd/root.go index 7ed86e4..ef5733e 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -8,7 +8,7 @@ import ( "tinyauth/internal/config" "tinyauth/internal/utils" - "github.com/go-playground/validator" + "github.com/go-playground/validator/v10" "github.com/rs/zerolog" "github.com/rs/zerolog/log" "github.com/spf13/cobra" @@ -34,9 +34,9 @@ var rootCmd = &cobra.Command{ conf.GenericClientSecret = utils.GetSecret(conf.GenericClientSecret, conf.GenericClientSecretFile) // Validate config - validator := validator.New() + v := validator.New() - err = validator.Struct(conf) + err = v.Struct(conf) if err != nil { log.Fatal().Err(err).Msg("Invalid config") } diff --git a/internal/controller/oauth_controller.go b/internal/controller/oauth_controller.go index 9802ea1..aa3289b 100644 --- a/internal/controller/oauth_controller.go +++ b/internal/controller/oauth_controller.go @@ -74,13 +74,13 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) { state := service.GenerateState() authURL := service.GetAuthURL(state) - c.SetCookie(controller.Config.CSRFCookieName, state, int(time.Hour.Seconds()), "/", "", controller.Config.SecureCookie, true) + c.SetCookie(controller.Config.CSRFCookieName, state, int(time.Hour.Seconds()), "/", fmt.Sprintf(".%s", controller.Config.Domain), controller.Config.SecureCookie, true) redirectURI := c.Query("redirect_uri") if redirectURI != "" && utils.IsRedirectSafe(redirectURI, controller.Config.Domain) { log.Debug().Msg("Setting redirect URI cookie") - c.SetCookie(controller.Config.RedirectCookieName, redirectURI, int(time.Hour.Seconds()), "/", "", controller.Config.SecureCookie, true) + c.SetCookie(controller.Config.RedirectCookieName, redirectURI, int(time.Hour.Seconds()), "/", fmt.Sprintf(".%s", controller.Config.Domain), controller.Config.SecureCookie, true) } c.JSON(200, gin.H{ @@ -112,7 +112,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) { return } - c.SetCookie(controller.Config.CSRFCookieName, "", -1, "/", "", controller.Config.SecureCookie, true) + c.SetCookie(controller.Config.CSRFCookieName, "", -1, "/", fmt.Sprintf(".%s", controller.Config.Domain), controller.Config.SecureCookie, true) code := c.Query("code") service, exists := controller.Broker.GetService(req.Provider) @@ -195,6 +195,6 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) { return } - c.SetCookie(controller.Config.RedirectCookieName, "", -1, "/", "", controller.Config.SecureCookie, true) + c.SetCookie(controller.Config.RedirectCookieName, "", -1, "/", fmt.Sprintf(".%s", controller.Config.Domain), controller.Config.SecureCookie, true) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/continue?%s", controller.Config.AppURL, queries.Encode())) } diff --git a/internal/controller/user_controller.go b/internal/controller/user_controller.go index 72e22d8..f7f7c9e 100644 --- a/internal/controller/user_controller.go +++ b/internal/controller/user_controller.go @@ -49,7 +49,7 @@ func (controller *UserController) SetupRoutes() { func (controller *UserController) loginHandler(c *gin.Context) { var req LoginRequest - err := c.BindJSON(&req) + err := c.ShouldBindJSON(&req) if err != nil { log.Error().Err(err).Msg("Failed to bind JSON") c.JSON(400, gin.H{ @@ -174,7 +174,7 @@ func (controller *UserController) logoutHandler(c *gin.Context) { func (controller *UserController) totpHandler(c *gin.Context) { var req TotpRequest - err := c.BindJSON(&req) + err := c.ShouldBindJSON(&req) if err != nil { log.Error().Err(err).Msg("Failed to bind JSON") c.JSON(400, gin.H{ diff --git a/internal/service/auth_service.go b/internal/service/auth_service.go index 29f2dd1..10d49e7 100644 --- a/internal/service/auth_service.go +++ b/internal/service/auth_service.go @@ -266,6 +266,9 @@ func (auth *AuthService) DeleteSessionCookie(c *gin.Context) error { return err } + // Clear the cookie in the browser + c.SetCookie(auth.Config.SessionCookieName, "", -1, "/", fmt.Sprintf(".%s", auth.Config.Domain), auth.Config.SecureCookie, true) + return nil } diff --git a/internal/service/generic_oauth_service.go b/internal/service/generic_oauth_service.go index a09fd93..c16384d 100644 --- a/internal/service/generic_oauth_service.go +++ b/internal/service/generic_oauth_service.go @@ -9,6 +9,7 @@ import ( "fmt" "io" "net/http" + "time" "tinyauth/internal/config" "golang.org/x/oauth2" @@ -64,8 +65,11 @@ func (generic *GenericOAuthService) Init() error { func (generic *GenericOAuthService) GenerateState() string { b := make([]byte, 128) - rand.Read(b) - state := base64.URLEncoding.EncodeToString(b) + _, err := rand.Read(b) + if err != nil { + return base64.RawURLEncoding.EncodeToString(fmt.Appendf(nil, "state-%d", time.Now().UnixNano())) + } + state := base64.RawURLEncoding.EncodeToString(b) return state } diff --git a/internal/service/github_oauth_service.go b/internal/service/github_oauth_service.go index 4df4444..7f8466b 100644 --- a/internal/service/github_oauth_service.go +++ b/internal/service/github_oauth_service.go @@ -9,6 +9,7 @@ import ( "fmt" "io" "net/http" + "time" "tinyauth/internal/config" "golang.org/x/oauth2" @@ -59,8 +60,11 @@ func (github *GithubOAuthService) Init() error { func (github *GithubOAuthService) GenerateState() string { b := make([]byte, 128) - rand.Read(b) - state := base64.URLEncoding.EncodeToString(b) + _, err := rand.Read(b) + if err != nil { + return base64.RawURLEncoding.EncodeToString(fmt.Appendf(nil, "state-%d", time.Now().UnixNano())) + } + state := base64.RawURLEncoding.EncodeToString(b) return state } diff --git a/internal/service/google_oauth_service.go b/internal/service/google_oauth_service.go index 4f738e7..1605a85 100644 --- a/internal/service/google_oauth_service.go +++ b/internal/service/google_oauth_service.go @@ -9,6 +9,7 @@ import ( "io" "net/http" "strings" + "time" "tinyauth/internal/config" "golang.org/x/oauth2" @@ -54,8 +55,11 @@ func (google *GoogleOAuthService) Init() error { func (oauth *GoogleOAuthService) GenerateState() string { b := make([]byte, 128) - rand.Read(b) - state := base64.URLEncoding.EncodeToString(b) + _, err := rand.Read(b) + if err != nil { + return base64.RawURLEncoding.EncodeToString(fmt.Appendf(nil, "state-%d", time.Now().UnixNano())) + } + state := base64.RawURLEncoding.EncodeToString(b) return state } diff --git a/internal/service/ldap_service.go b/internal/service/ldap_service.go index 503432f..8576c4d 100644 --- a/internal/service/ldap_service.go +++ b/internal/service/ldap_service.go @@ -140,7 +140,7 @@ func (ldap *LdapService) reconnect() error { ldap.Conn.Close() conn, err := ldap.connect() if err != nil { - return nil, nil + return nil, err } return conn, nil } diff --git a/internal/utils/security_utils.go b/internal/utils/security_utils.go index 4e9e187..a031900 100644 --- a/internal/utils/security_utils.go +++ b/internal/utils/security_utils.go @@ -101,7 +101,7 @@ func CheckFilter(filter string, str string) bool { return false } - if re.MatchString(str) { + if re.MatchString(strings.TrimSpace(str)) { return true } } @@ -109,7 +109,7 @@ func CheckFilter(filter string, str string) bool { filterSplit := strings.Split(filter, ",") for _, item := range filterSplit { - if strings.TrimSpace(item) == str { + if strings.TrimSpace(item) == strings.TrimSpace(str) { return true } }