mirror of
https://github.com/steveiliop56/tinyauth.git
synced 2025-10-28 12:45:47 +00:00
fix: further coderabbit suggestions
This commit is contained in:
@@ -8,7 +8,7 @@ import (
|
|||||||
"tinyauth/internal/config"
|
"tinyauth/internal/config"
|
||||||
"tinyauth/internal/utils"
|
"tinyauth/internal/utils"
|
||||||
|
|
||||||
"github.com/go-playground/validator"
|
"github.com/go-playground/validator/v10"
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
@@ -34,9 +34,9 @@ var rootCmd = &cobra.Command{
|
|||||||
conf.GenericClientSecret = utils.GetSecret(conf.GenericClientSecret, conf.GenericClientSecretFile)
|
conf.GenericClientSecret = utils.GetSecret(conf.GenericClientSecret, conf.GenericClientSecretFile)
|
||||||
|
|
||||||
// Validate config
|
// Validate config
|
||||||
validator := validator.New()
|
v := validator.New()
|
||||||
|
|
||||||
err = validator.Struct(conf)
|
err = v.Struct(conf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal().Err(err).Msg("Invalid config")
|
log.Fatal().Err(err).Msg("Invalid config")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -74,13 +74,13 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
|
|||||||
|
|
||||||
state := service.GenerateState()
|
state := service.GenerateState()
|
||||||
authURL := service.GetAuthURL(state)
|
authURL := service.GetAuthURL(state)
|
||||||
c.SetCookie(controller.Config.CSRFCookieName, state, int(time.Hour.Seconds()), "/", "", controller.Config.SecureCookie, true)
|
c.SetCookie(controller.Config.CSRFCookieName, state, int(time.Hour.Seconds()), "/", fmt.Sprintf(".%s", controller.Config.Domain), controller.Config.SecureCookie, true)
|
||||||
|
|
||||||
redirectURI := c.Query("redirect_uri")
|
redirectURI := c.Query("redirect_uri")
|
||||||
|
|
||||||
if redirectURI != "" && utils.IsRedirectSafe(redirectURI, controller.Config.Domain) {
|
if redirectURI != "" && utils.IsRedirectSafe(redirectURI, controller.Config.Domain) {
|
||||||
log.Debug().Msg("Setting redirect URI cookie")
|
log.Debug().Msg("Setting redirect URI cookie")
|
||||||
c.SetCookie(controller.Config.RedirectCookieName, redirectURI, int(time.Hour.Seconds()), "/", "", controller.Config.SecureCookie, true)
|
c.SetCookie(controller.Config.RedirectCookieName, redirectURI, int(time.Hour.Seconds()), "/", fmt.Sprintf(".%s", controller.Config.Domain), controller.Config.SecureCookie, true)
|
||||||
}
|
}
|
||||||
|
|
||||||
c.JSON(200, gin.H{
|
c.JSON(200, gin.H{
|
||||||
@@ -112,7 +112,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.SetCookie(controller.Config.CSRFCookieName, "", -1, "/", "", controller.Config.SecureCookie, true)
|
c.SetCookie(controller.Config.CSRFCookieName, "", -1, "/", fmt.Sprintf(".%s", controller.Config.Domain), controller.Config.SecureCookie, true)
|
||||||
|
|
||||||
code := c.Query("code")
|
code := c.Query("code")
|
||||||
service, exists := controller.Broker.GetService(req.Provider)
|
service, exists := controller.Broker.GetService(req.Provider)
|
||||||
@@ -195,6 +195,6 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.SetCookie(controller.Config.RedirectCookieName, "", -1, "/", "", controller.Config.SecureCookie, true)
|
c.SetCookie(controller.Config.RedirectCookieName, "", -1, "/", fmt.Sprintf(".%s", controller.Config.Domain), controller.Config.SecureCookie, true)
|
||||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/continue?%s", controller.Config.AppURL, queries.Encode()))
|
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/continue?%s", controller.Config.AppURL, queries.Encode()))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -49,7 +49,7 @@ func (controller *UserController) SetupRoutes() {
|
|||||||
func (controller *UserController) loginHandler(c *gin.Context) {
|
func (controller *UserController) loginHandler(c *gin.Context) {
|
||||||
var req LoginRequest
|
var req LoginRequest
|
||||||
|
|
||||||
err := c.BindJSON(&req)
|
err := c.ShouldBindJSON(&req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Err(err).Msg("Failed to bind JSON")
|
log.Error().Err(err).Msg("Failed to bind JSON")
|
||||||
c.JSON(400, gin.H{
|
c.JSON(400, gin.H{
|
||||||
@@ -174,7 +174,7 @@ func (controller *UserController) logoutHandler(c *gin.Context) {
|
|||||||
func (controller *UserController) totpHandler(c *gin.Context) {
|
func (controller *UserController) totpHandler(c *gin.Context) {
|
||||||
var req TotpRequest
|
var req TotpRequest
|
||||||
|
|
||||||
err := c.BindJSON(&req)
|
err := c.ShouldBindJSON(&req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Err(err).Msg("Failed to bind JSON")
|
log.Error().Err(err).Msg("Failed to bind JSON")
|
||||||
c.JSON(400, gin.H{
|
c.JSON(400, gin.H{
|
||||||
|
|||||||
@@ -266,6 +266,9 @@ func (auth *AuthService) DeleteSessionCookie(c *gin.Context) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Clear the cookie in the browser
|
||||||
|
c.SetCookie(auth.Config.SessionCookieName, "", -1, "/", fmt.Sprintf(".%s", auth.Config.Domain), auth.Config.SecureCookie, true)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"time"
|
||||||
"tinyauth/internal/config"
|
"tinyauth/internal/config"
|
||||||
|
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
@@ -64,8 +65,11 @@ func (generic *GenericOAuthService) Init() error {
|
|||||||
|
|
||||||
func (generic *GenericOAuthService) GenerateState() string {
|
func (generic *GenericOAuthService) GenerateState() string {
|
||||||
b := make([]byte, 128)
|
b := make([]byte, 128)
|
||||||
rand.Read(b)
|
_, err := rand.Read(b)
|
||||||
state := base64.URLEncoding.EncodeToString(b)
|
if err != nil {
|
||||||
|
return base64.RawURLEncoding.EncodeToString(fmt.Appendf(nil, "state-%d", time.Now().UnixNano()))
|
||||||
|
}
|
||||||
|
state := base64.RawURLEncoding.EncodeToString(b)
|
||||||
return state
|
return state
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"time"
|
||||||
"tinyauth/internal/config"
|
"tinyauth/internal/config"
|
||||||
|
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
@@ -59,8 +60,11 @@ func (github *GithubOAuthService) Init() error {
|
|||||||
|
|
||||||
func (github *GithubOAuthService) GenerateState() string {
|
func (github *GithubOAuthService) GenerateState() string {
|
||||||
b := make([]byte, 128)
|
b := make([]byte, 128)
|
||||||
rand.Read(b)
|
_, err := rand.Read(b)
|
||||||
state := base64.URLEncoding.EncodeToString(b)
|
if err != nil {
|
||||||
|
return base64.RawURLEncoding.EncodeToString(fmt.Appendf(nil, "state-%d", time.Now().UnixNano()))
|
||||||
|
}
|
||||||
|
state := base64.RawURLEncoding.EncodeToString(b)
|
||||||
return state
|
return state
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
"tinyauth/internal/config"
|
"tinyauth/internal/config"
|
||||||
|
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
@@ -54,8 +55,11 @@ func (google *GoogleOAuthService) Init() error {
|
|||||||
|
|
||||||
func (oauth *GoogleOAuthService) GenerateState() string {
|
func (oauth *GoogleOAuthService) GenerateState() string {
|
||||||
b := make([]byte, 128)
|
b := make([]byte, 128)
|
||||||
rand.Read(b)
|
_, err := rand.Read(b)
|
||||||
state := base64.URLEncoding.EncodeToString(b)
|
if err != nil {
|
||||||
|
return base64.RawURLEncoding.EncodeToString(fmt.Appendf(nil, "state-%d", time.Now().UnixNano()))
|
||||||
|
}
|
||||||
|
state := base64.RawURLEncoding.EncodeToString(b)
|
||||||
return state
|
return state
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -140,7 +140,7 @@ func (ldap *LdapService) reconnect() error {
|
|||||||
ldap.Conn.Close()
|
ldap.Conn.Close()
|
||||||
conn, err := ldap.connect()
|
conn, err := ldap.connect()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil
|
return nil, err
|
||||||
}
|
}
|
||||||
return conn, nil
|
return conn, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -101,7 +101,7 @@ func CheckFilter(filter string, str string) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if re.MatchString(str) {
|
if re.MatchString(strings.TrimSpace(str)) {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -109,7 +109,7 @@ func CheckFilter(filter string, str string) bool {
|
|||||||
filterSplit := strings.Split(filter, ",")
|
filterSplit := strings.Split(filter, ",")
|
||||||
|
|
||||||
for _, item := range filterSplit {
|
for _, item := range filterSplit {
|
||||||
if strings.TrimSpace(item) == str {
|
if strings.TrimSpace(item) == strings.TrimSpace(str) {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user