refactor: use dependency injection

This commit is contained in:
Stavros
2025-01-21 18:41:06 +02:00
parent 2988b5f22f
commit 5e73d06fcc
6 changed files with 113 additions and 42 deletions

View File

@@ -6,6 +6,8 @@ import (
"time"
"tinyauth/internal/api"
"tinyauth/internal/assets"
"tinyauth/internal/auth"
"tinyauth/internal/hooks"
"tinyauth/internal/types"
"tinyauth/internal/utils"
@@ -45,26 +47,44 @@ var rootCmd = &cobra.Command{
os.Exit(1)
}
users := config.Users
usersString := config.Users
if config.UsersFile != "" {
log.Info().Msg("Reading users from file")
usersFromFile, readErr := utils.GetUsersFromFile(config.UsersFile)
HandleError(readErr, "Failed to read users from file")
usersFromFileParsed := strings.Join(strings.Split(usersFromFile, "\n"), ",")
if users != "" {
users = users + "," + usersFromFileParsed
if usersString != "" {
usersString = usersString + "," + usersFromFileParsed
} else {
users = usersFromFileParsed
usersString = usersFromFileParsed
}
}
userList, createErr := utils.ParseUsers(users)
HandleError(createErr, "Failed to parse users")
users, parseErr := utils.ParseUsers(usersString)
HandleError(parseErr, "Failed to parse users")
// Start server
log.Info().Msg("Starting server")
api.Run(config, userList)
// Create auth service
auth := auth.NewAuth(users)
// Create hooks service
hooks := hooks.NewHooks(auth)
// Create API
api := api.NewAPI(types.APIConfig{
Port: config.Port,
Address: config.Address,
Secret: config.Secret,
AppURL: config.AppURL,
CookieSecure: config.CookieSecure,
}, hooks, auth)
// Setup routes
api.Init()
api.SetupRoutes()
// Start
api.Run()
},
}

View File

@@ -20,8 +20,25 @@ import (
"github.com/rs/zerolog/log"
)
func Run(config types.Config, users types.UserList) {
func NewAPI(config types.APIConfig, hooks *hooks.Hooks, auth *auth.Auth) (*API) {
return &API{
Config: config,
Hooks: hooks,
Auth: auth,
Router: nil,
}
}
type API struct {
Config types.APIConfig
Router *gin.Engine
Hooks *hooks.Hooks
Auth *auth.Auth
}
func (api *API) Init() {
gin.SetMode(gin.ReleaseMode)
router := gin.New()
router.Use(zerolog())
dist, distErr := fs.Sub(assets.Assets, "dist")
@@ -32,9 +49,9 @@ func Run(config types.Config, users types.UserList) {
}
fileServer := http.FileServer(http.FS(dist))
store := cookie.NewStore([]byte(config.Secret))
store := cookie.NewStore([]byte(api.Config.Secret))
domain, domainErr := utils.GetRootURL(config.AppURL)
domain, domainErr := utils.GetRootURL(api.Config.AppURL)
log.Info().Str("domain", domain).Msg("Using domain for cookies")
@@ -45,7 +62,7 @@ func Run(config types.Config, users types.UserList) {
var isSecure bool
if config.CookieSecure {
if api.Config.CookieSecure {
isSecure = true
} else {
isSecure = false
@@ -71,8 +88,12 @@ func Run(config types.Config, users types.UserList) {
}
})
router.GET("/api/auth", func (c *gin.Context) {
userContext := hooks.UseUserContext(c, users)
api.Router = router
}
func (api *API) SetupRoutes() {
api.Router.GET("/api/auth", func (c *gin.Context) {
userContext := api.Hooks.UseUserContext(c)
if userContext.IsLoggedIn {
c.JSON(200, gin.H{
@@ -97,10 +118,10 @@ func Run(config types.Config, users types.UserList) {
return
}
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/?%s", config.AppURL, queries.Encode()))
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/?%s", api.Config.AppURL, queries.Encode()))
})
router.POST("/api/login", func (c *gin.Context) {
api.Router.POST("/api/login", func (c *gin.Context) {
var login types.LoginRequest
err := c.BindJSON(&login)
@@ -113,7 +134,7 @@ func Run(config types.Config, users types.UserList) {
return
}
user := auth.FindUser(users, login.Username)
user := api.Auth.GetUser(login.Username)
if user == nil {
c.JSON(401, gin.H{
@@ -123,7 +144,7 @@ func Run(config types.Config, users types.UserList) {
return
}
if !auth.CheckPassword(*user, login.Password) {
if !api.Auth.CheckPassword(*user, login.Password) {
c.JSON(401, gin.H{
"status": 401,
"message": "Unauthorized",
@@ -141,7 +162,7 @@ func Run(config types.Config, users types.UserList) {
})
})
router.POST("/api/logout", func (c *gin.Context) {
api.Router.POST("/api/logout", func (c *gin.Context) {
session := sessions.Default(c)
session.Delete("tinyauth")
session.Save()
@@ -152,8 +173,8 @@ func Run(config types.Config, users types.UserList) {
})
})
router.GET("/api/status", func (c *gin.Context) {
userContext := hooks.UseUserContext(c, users)
api.Router.GET("/api/status", func (c *gin.Context) {
userContext := api.Hooks.UseUserContext(c)
if !userContext.IsLoggedIn {
c.JSON(200, gin.H{
@@ -173,14 +194,18 @@ func Run(config types.Config, users types.UserList) {
})
})
router.GET("/api/healthcheck", func (c *gin.Context) {
api.Router.GET("/api/healthcheck", func (c *gin.Context) {
c.JSON(200, gin.H{
"status": 200,
"message": "OK",
})
})
}
router.Run(fmt.Sprintf("%s:%d", config.Address, config.Port))
func (api *API) Run() {
log.Info().Str("address", api.Config.Address).Int("port", api.Config.Port).Msg("Starting server")
api.Router.Run(fmt.Sprintf("%s:%d", api.Config.Address, api.Config.Port))
}
func zerolog() gin.HandlerFunc {

View File

@@ -6,8 +6,18 @@ import (
"golang.org/x/crypto/bcrypt"
)
func FindUser(userList types.UserList, username string) (*types.User) {
for _, user := range userList.Users {
func NewAuth(userList types.Users) *Auth {
return &Auth{
Users: userList,
}
}
type Auth struct {
Users types.Users
}
func (auth *Auth) GetUser(username string) *types.User {
for _, user := range auth.Users {
if user.Username == username {
return &user
}
@@ -15,7 +25,7 @@ func FindUser(userList types.UserList, username string) (*types.User) {
return nil
}
func CheckPassword(user types.User, password string) bool {
func (auth *Auth) CheckPassword(user types.User, password string) bool {
hashedPasswordErr := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password))
return hashedPasswordErr == nil
}

View File

@@ -8,7 +8,17 @@ import (
"github.com/gin-gonic/gin"
)
func UseUserContext(c *gin.Context, userList types.UserList) (types.UserContext) {
func NewHooks(auth *auth.Auth) *Hooks {
return &Hooks{
Auth: auth,
}
}
type Hooks struct {
Auth *auth.Auth
}
func (hooks *Hooks) UseUserContext(c *gin.Context) (types.UserContext) {
session := sessions.Default(c)
cookie := session.Get("tinyauth")
@@ -28,7 +38,7 @@ func UseUserContext(c *gin.Context, userList types.UserList) (types.UserContext)
}
}
user := auth.FindUser(userList, username)
user := hooks.Auth.GetUser(username)
if user == nil {
return types.UserContext{

View File

@@ -14,9 +14,7 @@ type User struct {
Password string
}
type UserList struct {
Users []User
}
type Users []User
type Config struct {
Port int `validate:"number" mapstructure:"port"`
@@ -32,3 +30,11 @@ type UserContext struct {
Username string
IsLoggedIn bool
}
type APIConfig struct {
Port int
Address string
Secret string
AppURL string
CookieSecure bool
}

View File

@@ -8,26 +8,26 @@ import (
"tinyauth/internal/types"
)
func ParseUsers(users string) (types.UserList, error) {
var userList types.UserList
userListString := strings.Split(users, ",")
func ParseUsers(users string) (types.Users, error) {
var usersParsed types.Users
userList := strings.Split(users, ",")
if len(userListString) == 0 {
return types.UserList{}, errors.New("invalid user format")
if len(userList) == 0 {
return types.Users{}, errors.New("invalid user format")
}
for _, user := range userListString {
for _, user := range userList {
userSplit := strings.Split(user, ":")
if len(userSplit) != 2 {
return types.UserList{}, errors.New("invalid user format")
return types.Users{}, errors.New("invalid user format")
}
userList.Users = append(userList.Users, types.User{
usersParsed = append(usersParsed, types.User{
Username: userSplit[0],
Password: userSplit[1],
})
}
return userList, nil
return usersParsed, nil
}
func GetRootURL(urlSrc string) (string, error) {