mirror of
https://github.com/steveiliop56/tinyauth.git
synced 2025-10-28 12:45:47 +00:00
refactor: use dependency injection
This commit is contained in:
38
cmd/root.go
38
cmd/root.go
@@ -6,6 +6,8 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
"tinyauth/internal/api"
|
"tinyauth/internal/api"
|
||||||
"tinyauth/internal/assets"
|
"tinyauth/internal/assets"
|
||||||
|
"tinyauth/internal/auth"
|
||||||
|
"tinyauth/internal/hooks"
|
||||||
"tinyauth/internal/types"
|
"tinyauth/internal/types"
|
||||||
"tinyauth/internal/utils"
|
"tinyauth/internal/utils"
|
||||||
|
|
||||||
@@ -45,26 +47,44 @@ var rootCmd = &cobra.Command{
|
|||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
users := config.Users
|
usersString := config.Users
|
||||||
|
|
||||||
if config.UsersFile != "" {
|
if config.UsersFile != "" {
|
||||||
log.Info().Msg("Reading users from file")
|
log.Info().Msg("Reading users from file")
|
||||||
usersFromFile, readErr := utils.GetUsersFromFile(config.UsersFile)
|
usersFromFile, readErr := utils.GetUsersFromFile(config.UsersFile)
|
||||||
HandleError(readErr, "Failed to read users from file")
|
HandleError(readErr, "Failed to read users from file")
|
||||||
usersFromFileParsed := strings.Join(strings.Split(usersFromFile, "\n"), ",")
|
usersFromFileParsed := strings.Join(strings.Split(usersFromFile, "\n"), ",")
|
||||||
if users != "" {
|
if usersString != "" {
|
||||||
users = users + "," + usersFromFileParsed
|
usersString = usersString + "," + usersFromFileParsed
|
||||||
} else {
|
} else {
|
||||||
users = usersFromFileParsed
|
usersString = usersFromFileParsed
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
userList, createErr := utils.ParseUsers(users)
|
users, parseErr := utils.ParseUsers(usersString)
|
||||||
HandleError(createErr, "Failed to parse users")
|
HandleError(parseErr, "Failed to parse users")
|
||||||
|
|
||||||
// Start server
|
// Create auth service
|
||||||
log.Info().Msg("Starting server")
|
auth := auth.NewAuth(users)
|
||||||
api.Run(config, userList)
|
|
||||||
|
// Create hooks service
|
||||||
|
hooks := hooks.NewHooks(auth)
|
||||||
|
|
||||||
|
// Create API
|
||||||
|
api := api.NewAPI(types.APIConfig{
|
||||||
|
Port: config.Port,
|
||||||
|
Address: config.Address,
|
||||||
|
Secret: config.Secret,
|
||||||
|
AppURL: config.AppURL,
|
||||||
|
CookieSecure: config.CookieSecure,
|
||||||
|
}, hooks, auth)
|
||||||
|
|
||||||
|
// Setup routes
|
||||||
|
api.Init()
|
||||||
|
api.SetupRoutes()
|
||||||
|
|
||||||
|
// Start
|
||||||
|
api.Run()
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -20,8 +20,25 @@ import (
|
|||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Run(config types.Config, users types.UserList) {
|
func NewAPI(config types.APIConfig, hooks *hooks.Hooks, auth *auth.Auth) (*API) {
|
||||||
|
return &API{
|
||||||
|
Config: config,
|
||||||
|
Hooks: hooks,
|
||||||
|
Auth: auth,
|
||||||
|
Router: nil,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type API struct {
|
||||||
|
Config types.APIConfig
|
||||||
|
Router *gin.Engine
|
||||||
|
Hooks *hooks.Hooks
|
||||||
|
Auth *auth.Auth
|
||||||
|
}
|
||||||
|
|
||||||
|
func (api *API) Init() {
|
||||||
gin.SetMode(gin.ReleaseMode)
|
gin.SetMode(gin.ReleaseMode)
|
||||||
|
|
||||||
router := gin.New()
|
router := gin.New()
|
||||||
router.Use(zerolog())
|
router.Use(zerolog())
|
||||||
dist, distErr := fs.Sub(assets.Assets, "dist")
|
dist, distErr := fs.Sub(assets.Assets, "dist")
|
||||||
@@ -32,9 +49,9 @@ func Run(config types.Config, users types.UserList) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fileServer := http.FileServer(http.FS(dist))
|
fileServer := http.FileServer(http.FS(dist))
|
||||||
store := cookie.NewStore([]byte(config.Secret))
|
store := cookie.NewStore([]byte(api.Config.Secret))
|
||||||
|
|
||||||
domain, domainErr := utils.GetRootURL(config.AppURL)
|
domain, domainErr := utils.GetRootURL(api.Config.AppURL)
|
||||||
|
|
||||||
log.Info().Str("domain", domain).Msg("Using domain for cookies")
|
log.Info().Str("domain", domain).Msg("Using domain for cookies")
|
||||||
|
|
||||||
@@ -45,7 +62,7 @@ func Run(config types.Config, users types.UserList) {
|
|||||||
|
|
||||||
var isSecure bool
|
var isSecure bool
|
||||||
|
|
||||||
if config.CookieSecure {
|
if api.Config.CookieSecure {
|
||||||
isSecure = true
|
isSecure = true
|
||||||
} else {
|
} else {
|
||||||
isSecure = false
|
isSecure = false
|
||||||
@@ -60,7 +77,7 @@ func Run(config types.Config, users types.UserList) {
|
|||||||
|
|
||||||
router.Use(sessions.Sessions("tinyauth", store))
|
router.Use(sessions.Sessions("tinyauth", store))
|
||||||
|
|
||||||
router.Use(func(c *gin.Context) {
|
router.Use(func(c *gin.Context) {
|
||||||
if !strings.HasPrefix(c.Request.URL.Path, "/api") {
|
if !strings.HasPrefix(c.Request.URL.Path, "/api") {
|
||||||
_, err := fs.Stat(dist, strings.TrimPrefix(c.Request.URL.Path, "/"))
|
_, err := fs.Stat(dist, strings.TrimPrefix(c.Request.URL.Path, "/"))
|
||||||
if os.IsNotExist(err) {
|
if os.IsNotExist(err) {
|
||||||
@@ -71,8 +88,12 @@ func Run(config types.Config, users types.UserList) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
router.GET("/api/auth", func (c *gin.Context) {
|
api.Router = router
|
||||||
userContext := hooks.UseUserContext(c, users)
|
}
|
||||||
|
|
||||||
|
func (api *API) SetupRoutes() {
|
||||||
|
api.Router.GET("/api/auth", func (c *gin.Context) {
|
||||||
|
userContext := api.Hooks.UseUserContext(c)
|
||||||
|
|
||||||
if userContext.IsLoggedIn {
|
if userContext.IsLoggedIn {
|
||||||
c.JSON(200, gin.H{
|
c.JSON(200, gin.H{
|
||||||
@@ -97,10 +118,10 @@ func Run(config types.Config, users types.UserList) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/?%s", config.AppURL, queries.Encode()))
|
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/?%s", api.Config.AppURL, queries.Encode()))
|
||||||
})
|
})
|
||||||
|
|
||||||
router.POST("/api/login", func (c *gin.Context) {
|
api.Router.POST("/api/login", func (c *gin.Context) {
|
||||||
var login types.LoginRequest
|
var login types.LoginRequest
|
||||||
|
|
||||||
err := c.BindJSON(&login)
|
err := c.BindJSON(&login)
|
||||||
@@ -113,7 +134,7 @@ func Run(config types.Config, users types.UserList) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
user := auth.FindUser(users, login.Username)
|
user := api.Auth.GetUser(login.Username)
|
||||||
|
|
||||||
if user == nil {
|
if user == nil {
|
||||||
c.JSON(401, gin.H{
|
c.JSON(401, gin.H{
|
||||||
@@ -123,7 +144,7 @@ func Run(config types.Config, users types.UserList) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if !auth.CheckPassword(*user, login.Password) {
|
if !api.Auth.CheckPassword(*user, login.Password) {
|
||||||
c.JSON(401, gin.H{
|
c.JSON(401, gin.H{
|
||||||
"status": 401,
|
"status": 401,
|
||||||
"message": "Unauthorized",
|
"message": "Unauthorized",
|
||||||
@@ -141,7 +162,7 @@ func Run(config types.Config, users types.UserList) {
|
|||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
router.POST("/api/logout", func (c *gin.Context) {
|
api.Router.POST("/api/logout", func (c *gin.Context) {
|
||||||
session := sessions.Default(c)
|
session := sessions.Default(c)
|
||||||
session.Delete("tinyauth")
|
session.Delete("tinyauth")
|
||||||
session.Save()
|
session.Save()
|
||||||
@@ -152,8 +173,8 @@ func Run(config types.Config, users types.UserList) {
|
|||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
router.GET("/api/status", func (c *gin.Context) {
|
api.Router.GET("/api/status", func (c *gin.Context) {
|
||||||
userContext := hooks.UseUserContext(c, users)
|
userContext := api.Hooks.UseUserContext(c)
|
||||||
|
|
||||||
if !userContext.IsLoggedIn {
|
if !userContext.IsLoggedIn {
|
||||||
c.JSON(200, gin.H{
|
c.JSON(200, gin.H{
|
||||||
@@ -173,14 +194,18 @@ func Run(config types.Config, users types.UserList) {
|
|||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
router.GET("/api/healthcheck", func (c *gin.Context) {
|
api.Router.GET("/api/healthcheck", func (c *gin.Context) {
|
||||||
c.JSON(200, gin.H{
|
c.JSON(200, gin.H{
|
||||||
"status": 200,
|
"status": 200,
|
||||||
"message": "OK",
|
"message": "OK",
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
}
|
||||||
|
|
||||||
router.Run(fmt.Sprintf("%s:%d", config.Address, config.Port))
|
|
||||||
|
func (api *API) Run() {
|
||||||
|
log.Info().Str("address", api.Config.Address).Int("port", api.Config.Port).Msg("Starting server")
|
||||||
|
api.Router.Run(fmt.Sprintf("%s:%d", api.Config.Address, api.Config.Port))
|
||||||
}
|
}
|
||||||
|
|
||||||
func zerolog() gin.HandlerFunc {
|
func zerolog() gin.HandlerFunc {
|
||||||
|
|||||||
@@ -6,8 +6,18 @@ import (
|
|||||||
"golang.org/x/crypto/bcrypt"
|
"golang.org/x/crypto/bcrypt"
|
||||||
)
|
)
|
||||||
|
|
||||||
func FindUser(userList types.UserList, username string) (*types.User) {
|
func NewAuth(userList types.Users) *Auth {
|
||||||
for _, user := range userList.Users {
|
return &Auth{
|
||||||
|
Users: userList,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type Auth struct {
|
||||||
|
Users types.Users
|
||||||
|
}
|
||||||
|
|
||||||
|
func (auth *Auth) GetUser(username string) *types.User {
|
||||||
|
for _, user := range auth.Users {
|
||||||
if user.Username == username {
|
if user.Username == username {
|
||||||
return &user
|
return &user
|
||||||
}
|
}
|
||||||
@@ -15,7 +25,7 @@ func FindUser(userList types.UserList, username string) (*types.User) {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func CheckPassword(user types.User, password string) bool {
|
func (auth *Auth) CheckPassword(user types.User, password string) bool {
|
||||||
hashedPasswordErr := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password))
|
hashedPasswordErr := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password))
|
||||||
return hashedPasswordErr == nil
|
return hashedPasswordErr == nil
|
||||||
}
|
}
|
||||||
@@ -8,7 +8,17 @@ import (
|
|||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
func UseUserContext(c *gin.Context, userList types.UserList) (types.UserContext) {
|
func NewHooks(auth *auth.Auth) *Hooks {
|
||||||
|
return &Hooks{
|
||||||
|
Auth: auth,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type Hooks struct {
|
||||||
|
Auth *auth.Auth
|
||||||
|
}
|
||||||
|
|
||||||
|
func (hooks *Hooks) UseUserContext(c *gin.Context) (types.UserContext) {
|
||||||
session := sessions.Default(c)
|
session := sessions.Default(c)
|
||||||
cookie := session.Get("tinyauth")
|
cookie := session.Get("tinyauth")
|
||||||
|
|
||||||
@@ -28,7 +38,7 @@ func UseUserContext(c *gin.Context, userList types.UserList) (types.UserContext)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
user := auth.FindUser(userList, username)
|
user := hooks.Auth.GetUser(username)
|
||||||
|
|
||||||
if user == nil {
|
if user == nil {
|
||||||
return types.UserContext{
|
return types.UserContext{
|
||||||
|
|||||||
@@ -14,9 +14,7 @@ type User struct {
|
|||||||
Password string
|
Password string
|
||||||
}
|
}
|
||||||
|
|
||||||
type UserList struct {
|
type Users []User
|
||||||
Users []User
|
|
||||||
}
|
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
Port int `validate:"number" mapstructure:"port"`
|
Port int `validate:"number" mapstructure:"port"`
|
||||||
@@ -31,4 +29,12 @@ type Config struct {
|
|||||||
type UserContext struct {
|
type UserContext struct {
|
||||||
Username string
|
Username string
|
||||||
IsLoggedIn bool
|
IsLoggedIn bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type APIConfig struct {
|
||||||
|
Port int
|
||||||
|
Address string
|
||||||
|
Secret string
|
||||||
|
AppURL string
|
||||||
|
CookieSecure bool
|
||||||
}
|
}
|
||||||
@@ -8,26 +8,26 @@ import (
|
|||||||
"tinyauth/internal/types"
|
"tinyauth/internal/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
func ParseUsers(users string) (types.UserList, error) {
|
func ParseUsers(users string) (types.Users, error) {
|
||||||
var userList types.UserList
|
var usersParsed types.Users
|
||||||
userListString := strings.Split(users, ",")
|
userList := strings.Split(users, ",")
|
||||||
|
|
||||||
if len(userListString) == 0 {
|
if len(userList) == 0 {
|
||||||
return types.UserList{}, errors.New("invalid user format")
|
return types.Users{}, errors.New("invalid user format")
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, user := range userListString {
|
for _, user := range userList {
|
||||||
userSplit := strings.Split(user, ":")
|
userSplit := strings.Split(user, ":")
|
||||||
if len(userSplit) != 2 {
|
if len(userSplit) != 2 {
|
||||||
return types.UserList{}, errors.New("invalid user format")
|
return types.Users{}, errors.New("invalid user format")
|
||||||
}
|
}
|
||||||
userList.Users = append(userList.Users, types.User{
|
usersParsed = append(usersParsed, types.User{
|
||||||
Username: userSplit[0],
|
Username: userSplit[0],
|
||||||
Password: userSplit[1],
|
Password: userSplit[1],
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
return userList, nil
|
return usersParsed, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetRootURL(urlSrc string) (string, error) {
|
func GetRootURL(urlSrc string) (string, error) {
|
||||||
|
|||||||
Reference in New Issue
Block a user