refactor: use dependency injection

This commit is contained in:
Stavros
2025-01-21 18:41:06 +02:00
parent 2988b5f22f
commit 5e73d06fcc
6 changed files with 113 additions and 42 deletions

View File

@@ -6,6 +6,8 @@ import (
"time" "time"
"tinyauth/internal/api" "tinyauth/internal/api"
"tinyauth/internal/assets" "tinyauth/internal/assets"
"tinyauth/internal/auth"
"tinyauth/internal/hooks"
"tinyauth/internal/types" "tinyauth/internal/types"
"tinyauth/internal/utils" "tinyauth/internal/utils"
@@ -45,26 +47,44 @@ var rootCmd = &cobra.Command{
os.Exit(1) os.Exit(1)
} }
users := config.Users usersString := config.Users
if config.UsersFile != "" { if config.UsersFile != "" {
log.Info().Msg("Reading users from file") log.Info().Msg("Reading users from file")
usersFromFile, readErr := utils.GetUsersFromFile(config.UsersFile) usersFromFile, readErr := utils.GetUsersFromFile(config.UsersFile)
HandleError(readErr, "Failed to read users from file") HandleError(readErr, "Failed to read users from file")
usersFromFileParsed := strings.Join(strings.Split(usersFromFile, "\n"), ",") usersFromFileParsed := strings.Join(strings.Split(usersFromFile, "\n"), ",")
if users != "" { if usersString != "" {
users = users + "," + usersFromFileParsed usersString = usersString + "," + usersFromFileParsed
} else { } else {
users = usersFromFileParsed usersString = usersFromFileParsed
} }
} }
userList, createErr := utils.ParseUsers(users) users, parseErr := utils.ParseUsers(usersString)
HandleError(createErr, "Failed to parse users") HandleError(parseErr, "Failed to parse users")
// Start server // Create auth service
log.Info().Msg("Starting server") auth := auth.NewAuth(users)
api.Run(config, userList)
// Create hooks service
hooks := hooks.NewHooks(auth)
// Create API
api := api.NewAPI(types.APIConfig{
Port: config.Port,
Address: config.Address,
Secret: config.Secret,
AppURL: config.AppURL,
CookieSecure: config.CookieSecure,
}, hooks, auth)
// Setup routes
api.Init()
api.SetupRoutes()
// Start
api.Run()
}, },
} }

View File

@@ -20,8 +20,25 @@ import (
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
) )
func Run(config types.Config, users types.UserList) { func NewAPI(config types.APIConfig, hooks *hooks.Hooks, auth *auth.Auth) (*API) {
return &API{
Config: config,
Hooks: hooks,
Auth: auth,
Router: nil,
}
}
type API struct {
Config types.APIConfig
Router *gin.Engine
Hooks *hooks.Hooks
Auth *auth.Auth
}
func (api *API) Init() {
gin.SetMode(gin.ReleaseMode) gin.SetMode(gin.ReleaseMode)
router := gin.New() router := gin.New()
router.Use(zerolog()) router.Use(zerolog())
dist, distErr := fs.Sub(assets.Assets, "dist") dist, distErr := fs.Sub(assets.Assets, "dist")
@@ -32,9 +49,9 @@ func Run(config types.Config, users types.UserList) {
} }
fileServer := http.FileServer(http.FS(dist)) fileServer := http.FileServer(http.FS(dist))
store := cookie.NewStore([]byte(config.Secret)) store := cookie.NewStore([]byte(api.Config.Secret))
domain, domainErr := utils.GetRootURL(config.AppURL) domain, domainErr := utils.GetRootURL(api.Config.AppURL)
log.Info().Str("domain", domain).Msg("Using domain for cookies") log.Info().Str("domain", domain).Msg("Using domain for cookies")
@@ -45,7 +62,7 @@ func Run(config types.Config, users types.UserList) {
var isSecure bool var isSecure bool
if config.CookieSecure { if api.Config.CookieSecure {
isSecure = true isSecure = true
} else { } else {
isSecure = false isSecure = false
@@ -60,7 +77,7 @@ func Run(config types.Config, users types.UserList) {
router.Use(sessions.Sessions("tinyauth", store)) router.Use(sessions.Sessions("tinyauth", store))
router.Use(func(c *gin.Context) { router.Use(func(c *gin.Context) {
if !strings.HasPrefix(c.Request.URL.Path, "/api") { if !strings.HasPrefix(c.Request.URL.Path, "/api") {
_, err := fs.Stat(dist, strings.TrimPrefix(c.Request.URL.Path, "/")) _, err := fs.Stat(dist, strings.TrimPrefix(c.Request.URL.Path, "/"))
if os.IsNotExist(err) { if os.IsNotExist(err) {
@@ -71,8 +88,12 @@ func Run(config types.Config, users types.UserList) {
} }
}) })
router.GET("/api/auth", func (c *gin.Context) { api.Router = router
userContext := hooks.UseUserContext(c, users) }
func (api *API) SetupRoutes() {
api.Router.GET("/api/auth", func (c *gin.Context) {
userContext := api.Hooks.UseUserContext(c)
if userContext.IsLoggedIn { if userContext.IsLoggedIn {
c.JSON(200, gin.H{ c.JSON(200, gin.H{
@@ -97,10 +118,10 @@ func Run(config types.Config, users types.UserList) {
return return
} }
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/?%s", config.AppURL, queries.Encode())) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/?%s", api.Config.AppURL, queries.Encode()))
}) })
router.POST("/api/login", func (c *gin.Context) { api.Router.POST("/api/login", func (c *gin.Context) {
var login types.LoginRequest var login types.LoginRequest
err := c.BindJSON(&login) err := c.BindJSON(&login)
@@ -113,7 +134,7 @@ func Run(config types.Config, users types.UserList) {
return return
} }
user := auth.FindUser(users, login.Username) user := api.Auth.GetUser(login.Username)
if user == nil { if user == nil {
c.JSON(401, gin.H{ c.JSON(401, gin.H{
@@ -123,7 +144,7 @@ func Run(config types.Config, users types.UserList) {
return return
} }
if !auth.CheckPassword(*user, login.Password) { if !api.Auth.CheckPassword(*user, login.Password) {
c.JSON(401, gin.H{ c.JSON(401, gin.H{
"status": 401, "status": 401,
"message": "Unauthorized", "message": "Unauthorized",
@@ -141,7 +162,7 @@ func Run(config types.Config, users types.UserList) {
}) })
}) })
router.POST("/api/logout", func (c *gin.Context) { api.Router.POST("/api/logout", func (c *gin.Context) {
session := sessions.Default(c) session := sessions.Default(c)
session.Delete("tinyauth") session.Delete("tinyauth")
session.Save() session.Save()
@@ -152,8 +173,8 @@ func Run(config types.Config, users types.UserList) {
}) })
}) })
router.GET("/api/status", func (c *gin.Context) { api.Router.GET("/api/status", func (c *gin.Context) {
userContext := hooks.UseUserContext(c, users) userContext := api.Hooks.UseUserContext(c)
if !userContext.IsLoggedIn { if !userContext.IsLoggedIn {
c.JSON(200, gin.H{ c.JSON(200, gin.H{
@@ -173,14 +194,18 @@ func Run(config types.Config, users types.UserList) {
}) })
}) })
router.GET("/api/healthcheck", func (c *gin.Context) { api.Router.GET("/api/healthcheck", func (c *gin.Context) {
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"status": 200, "status": 200,
"message": "OK", "message": "OK",
}) })
}) })
}
router.Run(fmt.Sprintf("%s:%d", config.Address, config.Port))
func (api *API) Run() {
log.Info().Str("address", api.Config.Address).Int("port", api.Config.Port).Msg("Starting server")
api.Router.Run(fmt.Sprintf("%s:%d", api.Config.Address, api.Config.Port))
} }
func zerolog() gin.HandlerFunc { func zerolog() gin.HandlerFunc {

View File

@@ -6,8 +6,18 @@ import (
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
) )
func FindUser(userList types.UserList, username string) (*types.User) { func NewAuth(userList types.Users) *Auth {
for _, user := range userList.Users { return &Auth{
Users: userList,
}
}
type Auth struct {
Users types.Users
}
func (auth *Auth) GetUser(username string) *types.User {
for _, user := range auth.Users {
if user.Username == username { if user.Username == username {
return &user return &user
} }
@@ -15,7 +25,7 @@ func FindUser(userList types.UserList, username string) (*types.User) {
return nil return nil
} }
func CheckPassword(user types.User, password string) bool { func (auth *Auth) CheckPassword(user types.User, password string) bool {
hashedPasswordErr := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)) hashedPasswordErr := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password))
return hashedPasswordErr == nil return hashedPasswordErr == nil
} }

View File

@@ -8,7 +8,17 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
func UseUserContext(c *gin.Context, userList types.UserList) (types.UserContext) { func NewHooks(auth *auth.Auth) *Hooks {
return &Hooks{
Auth: auth,
}
}
type Hooks struct {
Auth *auth.Auth
}
func (hooks *Hooks) UseUserContext(c *gin.Context) (types.UserContext) {
session := sessions.Default(c) session := sessions.Default(c)
cookie := session.Get("tinyauth") cookie := session.Get("tinyauth")
@@ -28,7 +38,7 @@ func UseUserContext(c *gin.Context, userList types.UserList) (types.UserContext)
} }
} }
user := auth.FindUser(userList, username) user := hooks.Auth.GetUser(username)
if user == nil { if user == nil {
return types.UserContext{ return types.UserContext{

View File

@@ -14,9 +14,7 @@ type User struct {
Password string Password string
} }
type UserList struct { type Users []User
Users []User
}
type Config struct { type Config struct {
Port int `validate:"number" mapstructure:"port"` Port int `validate:"number" mapstructure:"port"`
@@ -32,3 +30,11 @@ type UserContext struct {
Username string Username string
IsLoggedIn bool IsLoggedIn bool
} }
type APIConfig struct {
Port int
Address string
Secret string
AppURL string
CookieSecure bool
}

View File

@@ -8,26 +8,26 @@ import (
"tinyauth/internal/types" "tinyauth/internal/types"
) )
func ParseUsers(users string) (types.UserList, error) { func ParseUsers(users string) (types.Users, error) {
var userList types.UserList var usersParsed types.Users
userListString := strings.Split(users, ",") userList := strings.Split(users, ",")
if len(userListString) == 0 { if len(userList) == 0 {
return types.UserList{}, errors.New("invalid user format") return types.Users{}, errors.New("invalid user format")
} }
for _, user := range userListString { for _, user := range userList {
userSplit := strings.Split(user, ":") userSplit := strings.Split(user, ":")
if len(userSplit) != 2 { if len(userSplit) != 2 {
return types.UserList{}, errors.New("invalid user format") return types.Users{}, errors.New("invalid user format")
} }
userList.Users = append(userList.Users, types.User{ usersParsed = append(usersParsed, types.User{
Username: userSplit[0], Username: userSplit[0],
Password: userSplit[1], Password: userSplit[1],
}) })
} }
return userList, nil return usersParsed, nil
} }
func GetRootURL(urlSrc string) (string, error) { func GetRootURL(urlSrc string) (string, error) {