From 5e73d06fccf73d2bb0b3c4b4a8470f65baab6578 Mon Sep 17 00:00:00 2001 From: Stavros Date: Tue, 21 Jan 2025 18:41:06 +0200 Subject: [PATCH] refactor: use dependency injection --- cmd/root.go | 38 ++++++++++++++++++++------- internal/api/api.go | 57 +++++++++++++++++++++++++++++------------ internal/auth/auth.go | 16 +++++++++--- internal/hooks/hooks.go | 14 ++++++++-- internal/types/types.go | 12 ++++++--- internal/utils/utils.go | 18 ++++++------- 6 files changed, 113 insertions(+), 42 deletions(-) diff --git a/cmd/root.go b/cmd/root.go index 06a6232..996bb9a 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -6,6 +6,8 @@ import ( "time" "tinyauth/internal/api" "tinyauth/internal/assets" + "tinyauth/internal/auth" + "tinyauth/internal/hooks" "tinyauth/internal/types" "tinyauth/internal/utils" @@ -45,26 +47,44 @@ var rootCmd = &cobra.Command{ os.Exit(1) } - users := config.Users + usersString := config.Users if config.UsersFile != "" { log.Info().Msg("Reading users from file") usersFromFile, readErr := utils.GetUsersFromFile(config.UsersFile) HandleError(readErr, "Failed to read users from file") usersFromFileParsed := strings.Join(strings.Split(usersFromFile, "\n"), ",") - if users != "" { - users = users + "," + usersFromFileParsed + if usersString != "" { + usersString = usersString + "," + usersFromFileParsed } else { - users = usersFromFileParsed + usersString = usersFromFileParsed } } - userList, createErr := utils.ParseUsers(users) - HandleError(createErr, "Failed to parse users") + users, parseErr := utils.ParseUsers(usersString) + HandleError(parseErr, "Failed to parse users") - // Start server - log.Info().Msg("Starting server") - api.Run(config, userList) + // Create auth service + auth := auth.NewAuth(users) + + // Create hooks service + hooks := hooks.NewHooks(auth) + + // Create API + api := api.NewAPI(types.APIConfig{ + Port: config.Port, + Address: config.Address, + Secret: config.Secret, + AppURL: config.AppURL, + CookieSecure: config.CookieSecure, + }, hooks, auth) + + // Setup routes + api.Init() + api.SetupRoutes() + + // Start + api.Run() }, } diff --git a/internal/api/api.go b/internal/api/api.go index 780e309..360aa3c 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -20,8 +20,25 @@ import ( "github.com/rs/zerolog/log" ) -func Run(config types.Config, users types.UserList) { +func NewAPI(config types.APIConfig, hooks *hooks.Hooks, auth *auth.Auth) (*API) { + return &API{ + Config: config, + Hooks: hooks, + Auth: auth, + Router: nil, + } +} + +type API struct { + Config types.APIConfig + Router *gin.Engine + Hooks *hooks.Hooks + Auth *auth.Auth +} + +func (api *API) Init() { gin.SetMode(gin.ReleaseMode) + router := gin.New() router.Use(zerolog()) dist, distErr := fs.Sub(assets.Assets, "dist") @@ -32,9 +49,9 @@ func Run(config types.Config, users types.UserList) { } fileServer := http.FileServer(http.FS(dist)) - store := cookie.NewStore([]byte(config.Secret)) + store := cookie.NewStore([]byte(api.Config.Secret)) - domain, domainErr := utils.GetRootURL(config.AppURL) + domain, domainErr := utils.GetRootURL(api.Config.AppURL) log.Info().Str("domain", domain).Msg("Using domain for cookies") @@ -45,7 +62,7 @@ func Run(config types.Config, users types.UserList) { var isSecure bool - if config.CookieSecure { + if api.Config.CookieSecure { isSecure = true } else { isSecure = false @@ -60,7 +77,7 @@ func Run(config types.Config, users types.UserList) { router.Use(sessions.Sessions("tinyauth", store)) - router.Use(func(c *gin.Context) { + router.Use(func(c *gin.Context) { if !strings.HasPrefix(c.Request.URL.Path, "/api") { _, err := fs.Stat(dist, strings.TrimPrefix(c.Request.URL.Path, "/")) if os.IsNotExist(err) { @@ -71,8 +88,12 @@ func Run(config types.Config, users types.UserList) { } }) - router.GET("/api/auth", func (c *gin.Context) { - userContext := hooks.UseUserContext(c, users) + api.Router = router +} + +func (api *API) SetupRoutes() { + api.Router.GET("/api/auth", func (c *gin.Context) { + userContext := api.Hooks.UseUserContext(c) if userContext.IsLoggedIn { c.JSON(200, gin.H{ @@ -97,10 +118,10 @@ func Run(config types.Config, users types.UserList) { return } - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/?%s", config.AppURL, queries.Encode())) + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/?%s", api.Config.AppURL, queries.Encode())) }) - router.POST("/api/login", func (c *gin.Context) { + api.Router.POST("/api/login", func (c *gin.Context) { var login types.LoginRequest err := c.BindJSON(&login) @@ -113,7 +134,7 @@ func Run(config types.Config, users types.UserList) { return } - user := auth.FindUser(users, login.Username) + user := api.Auth.GetUser(login.Username) if user == nil { c.JSON(401, gin.H{ @@ -123,7 +144,7 @@ func Run(config types.Config, users types.UserList) { return } - if !auth.CheckPassword(*user, login.Password) { + if !api.Auth.CheckPassword(*user, login.Password) { c.JSON(401, gin.H{ "status": 401, "message": "Unauthorized", @@ -141,7 +162,7 @@ func Run(config types.Config, users types.UserList) { }) }) - router.POST("/api/logout", func (c *gin.Context) { + api.Router.POST("/api/logout", func (c *gin.Context) { session := sessions.Default(c) session.Delete("tinyauth") session.Save() @@ -152,8 +173,8 @@ func Run(config types.Config, users types.UserList) { }) }) - router.GET("/api/status", func (c *gin.Context) { - userContext := hooks.UseUserContext(c, users) + api.Router.GET("/api/status", func (c *gin.Context) { + userContext := api.Hooks.UseUserContext(c) if !userContext.IsLoggedIn { c.JSON(200, gin.H{ @@ -173,14 +194,18 @@ func Run(config types.Config, users types.UserList) { }) }) - router.GET("/api/healthcheck", func (c *gin.Context) { + api.Router.GET("/api/healthcheck", func (c *gin.Context) { c.JSON(200, gin.H{ "status": 200, "message": "OK", }) }) +} - router.Run(fmt.Sprintf("%s:%d", config.Address, config.Port)) + +func (api *API) Run() { + log.Info().Str("address", api.Config.Address).Int("port", api.Config.Port).Msg("Starting server") + api.Router.Run(fmt.Sprintf("%s:%d", api.Config.Address, api.Config.Port)) } func zerolog() gin.HandlerFunc { diff --git a/internal/auth/auth.go b/internal/auth/auth.go index 24f168b..a5f0cef 100644 --- a/internal/auth/auth.go +++ b/internal/auth/auth.go @@ -6,8 +6,18 @@ import ( "golang.org/x/crypto/bcrypt" ) -func FindUser(userList types.UserList, username string) (*types.User) { - for _, user := range userList.Users { +func NewAuth(userList types.Users) *Auth { + return &Auth{ + Users: userList, + } +} + +type Auth struct { + Users types.Users +} + +func (auth *Auth) GetUser(username string) *types.User { + for _, user := range auth.Users { if user.Username == username { return &user } @@ -15,7 +25,7 @@ func FindUser(userList types.UserList, username string) (*types.User) { return nil } -func CheckPassword(user types.User, password string) bool { +func (auth *Auth) CheckPassword(user types.User, password string) bool { hashedPasswordErr := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)) return hashedPasswordErr == nil } \ No newline at end of file diff --git a/internal/hooks/hooks.go b/internal/hooks/hooks.go index 0790eb2..d45d84f 100644 --- a/internal/hooks/hooks.go +++ b/internal/hooks/hooks.go @@ -8,7 +8,17 @@ import ( "github.com/gin-gonic/gin" ) -func UseUserContext(c *gin.Context, userList types.UserList) (types.UserContext) { +func NewHooks(auth *auth.Auth) *Hooks { + return &Hooks{ + Auth: auth, + } +} + +type Hooks struct { + Auth *auth.Auth +} + +func (hooks *Hooks) UseUserContext(c *gin.Context) (types.UserContext) { session := sessions.Default(c) cookie := session.Get("tinyauth") @@ -28,7 +38,7 @@ func UseUserContext(c *gin.Context, userList types.UserList) (types.UserContext) } } - user := auth.FindUser(userList, username) + user := hooks.Auth.GetUser(username) if user == nil { return types.UserContext{ diff --git a/internal/types/types.go b/internal/types/types.go index 2fa9aac..e1e3034 100644 --- a/internal/types/types.go +++ b/internal/types/types.go @@ -14,9 +14,7 @@ type User struct { Password string } -type UserList struct { - Users []User -} +type Users []User type Config struct { Port int `validate:"number" mapstructure:"port"` @@ -31,4 +29,12 @@ type Config struct { type UserContext struct { Username string IsLoggedIn bool +} + +type APIConfig struct { + Port int + Address string + Secret string + AppURL string + CookieSecure bool } \ No newline at end of file diff --git a/internal/utils/utils.go b/internal/utils/utils.go index 8d48094..e237301 100644 --- a/internal/utils/utils.go +++ b/internal/utils/utils.go @@ -8,26 +8,26 @@ import ( "tinyauth/internal/types" ) -func ParseUsers(users string) (types.UserList, error) { - var userList types.UserList - userListString := strings.Split(users, ",") +func ParseUsers(users string) (types.Users, error) { + var usersParsed types.Users + userList := strings.Split(users, ",") - if len(userListString) == 0 { - return types.UserList{}, errors.New("invalid user format") + if len(userList) == 0 { + return types.Users{}, errors.New("invalid user format") } - for _, user := range userListString { + for _, user := range userList { userSplit := strings.Split(user, ":") if len(userSplit) != 2 { - return types.UserList{}, errors.New("invalid user format") + return types.Users{}, errors.New("invalid user format") } - userList.Users = append(userList.Users, types.User{ + usersParsed = append(usersParsed, types.User{ Username: userSplit[0], Password: userSplit[1], }) } - return userList, nil + return usersParsed, nil } func GetRootURL(urlSrc string) (string, error) {