mirror of
https://github.com/steveiliop56/tinyauth.git
synced 2025-10-28 04:35:40 +00:00
refactor: use dependency injection
This commit is contained in:
@@ -20,8 +20,25 @@ import (
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
func Run(config types.Config, users types.UserList) {
|
||||
func NewAPI(config types.APIConfig, hooks *hooks.Hooks, auth *auth.Auth) (*API) {
|
||||
return &API{
|
||||
Config: config,
|
||||
Hooks: hooks,
|
||||
Auth: auth,
|
||||
Router: nil,
|
||||
}
|
||||
}
|
||||
|
||||
type API struct {
|
||||
Config types.APIConfig
|
||||
Router *gin.Engine
|
||||
Hooks *hooks.Hooks
|
||||
Auth *auth.Auth
|
||||
}
|
||||
|
||||
func (api *API) Init() {
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
|
||||
router := gin.New()
|
||||
router.Use(zerolog())
|
||||
dist, distErr := fs.Sub(assets.Assets, "dist")
|
||||
@@ -32,9 +49,9 @@ func Run(config types.Config, users types.UserList) {
|
||||
}
|
||||
|
||||
fileServer := http.FileServer(http.FS(dist))
|
||||
store := cookie.NewStore([]byte(config.Secret))
|
||||
store := cookie.NewStore([]byte(api.Config.Secret))
|
||||
|
||||
domain, domainErr := utils.GetRootURL(config.AppURL)
|
||||
domain, domainErr := utils.GetRootURL(api.Config.AppURL)
|
||||
|
||||
log.Info().Str("domain", domain).Msg("Using domain for cookies")
|
||||
|
||||
@@ -45,7 +62,7 @@ func Run(config types.Config, users types.UserList) {
|
||||
|
||||
var isSecure bool
|
||||
|
||||
if config.CookieSecure {
|
||||
if api.Config.CookieSecure {
|
||||
isSecure = true
|
||||
} else {
|
||||
isSecure = false
|
||||
@@ -60,7 +77,7 @@ func Run(config types.Config, users types.UserList) {
|
||||
|
||||
router.Use(sessions.Sessions("tinyauth", store))
|
||||
|
||||
router.Use(func(c *gin.Context) {
|
||||
router.Use(func(c *gin.Context) {
|
||||
if !strings.HasPrefix(c.Request.URL.Path, "/api") {
|
||||
_, err := fs.Stat(dist, strings.TrimPrefix(c.Request.URL.Path, "/"))
|
||||
if os.IsNotExist(err) {
|
||||
@@ -71,8 +88,12 @@ func Run(config types.Config, users types.UserList) {
|
||||
}
|
||||
})
|
||||
|
||||
router.GET("/api/auth", func (c *gin.Context) {
|
||||
userContext := hooks.UseUserContext(c, users)
|
||||
api.Router = router
|
||||
}
|
||||
|
||||
func (api *API) SetupRoutes() {
|
||||
api.Router.GET("/api/auth", func (c *gin.Context) {
|
||||
userContext := api.Hooks.UseUserContext(c)
|
||||
|
||||
if userContext.IsLoggedIn {
|
||||
c.JSON(200, gin.H{
|
||||
@@ -97,10 +118,10 @@ func Run(config types.Config, users types.UserList) {
|
||||
return
|
||||
}
|
||||
|
||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/?%s", config.AppURL, queries.Encode()))
|
||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/?%s", api.Config.AppURL, queries.Encode()))
|
||||
})
|
||||
|
||||
router.POST("/api/login", func (c *gin.Context) {
|
||||
api.Router.POST("/api/login", func (c *gin.Context) {
|
||||
var login types.LoginRequest
|
||||
|
||||
err := c.BindJSON(&login)
|
||||
@@ -113,7 +134,7 @@ func Run(config types.Config, users types.UserList) {
|
||||
return
|
||||
}
|
||||
|
||||
user := auth.FindUser(users, login.Username)
|
||||
user := api.Auth.GetUser(login.Username)
|
||||
|
||||
if user == nil {
|
||||
c.JSON(401, gin.H{
|
||||
@@ -123,7 +144,7 @@ func Run(config types.Config, users types.UserList) {
|
||||
return
|
||||
}
|
||||
|
||||
if !auth.CheckPassword(*user, login.Password) {
|
||||
if !api.Auth.CheckPassword(*user, login.Password) {
|
||||
c.JSON(401, gin.H{
|
||||
"status": 401,
|
||||
"message": "Unauthorized",
|
||||
@@ -141,7 +162,7 @@ func Run(config types.Config, users types.UserList) {
|
||||
})
|
||||
})
|
||||
|
||||
router.POST("/api/logout", func (c *gin.Context) {
|
||||
api.Router.POST("/api/logout", func (c *gin.Context) {
|
||||
session := sessions.Default(c)
|
||||
session.Delete("tinyauth")
|
||||
session.Save()
|
||||
@@ -152,8 +173,8 @@ func Run(config types.Config, users types.UserList) {
|
||||
})
|
||||
})
|
||||
|
||||
router.GET("/api/status", func (c *gin.Context) {
|
||||
userContext := hooks.UseUserContext(c, users)
|
||||
api.Router.GET("/api/status", func (c *gin.Context) {
|
||||
userContext := api.Hooks.UseUserContext(c)
|
||||
|
||||
if !userContext.IsLoggedIn {
|
||||
c.JSON(200, gin.H{
|
||||
@@ -173,14 +194,18 @@ func Run(config types.Config, users types.UserList) {
|
||||
})
|
||||
})
|
||||
|
||||
router.GET("/api/healthcheck", func (c *gin.Context) {
|
||||
api.Router.GET("/api/healthcheck", func (c *gin.Context) {
|
||||
c.JSON(200, gin.H{
|
||||
"status": 200,
|
||||
"message": "OK",
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
router.Run(fmt.Sprintf("%s:%d", config.Address, config.Port))
|
||||
|
||||
func (api *API) Run() {
|
||||
log.Info().Str("address", api.Config.Address).Int("port", api.Config.Port).Msg("Starting server")
|
||||
api.Router.Run(fmt.Sprintf("%s:%d", api.Config.Address, api.Config.Port))
|
||||
}
|
||||
|
||||
func zerolog() gin.HandlerFunc {
|
||||
|
||||
@@ -6,8 +6,18 @@ import (
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
func FindUser(userList types.UserList, username string) (*types.User) {
|
||||
for _, user := range userList.Users {
|
||||
func NewAuth(userList types.Users) *Auth {
|
||||
return &Auth{
|
||||
Users: userList,
|
||||
}
|
||||
}
|
||||
|
||||
type Auth struct {
|
||||
Users types.Users
|
||||
}
|
||||
|
||||
func (auth *Auth) GetUser(username string) *types.User {
|
||||
for _, user := range auth.Users {
|
||||
if user.Username == username {
|
||||
return &user
|
||||
}
|
||||
@@ -15,7 +25,7 @@ func FindUser(userList types.UserList, username string) (*types.User) {
|
||||
return nil
|
||||
}
|
||||
|
||||
func CheckPassword(user types.User, password string) bool {
|
||||
func (auth *Auth) CheckPassword(user types.User, password string) bool {
|
||||
hashedPasswordErr := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password))
|
||||
return hashedPasswordErr == nil
|
||||
}
|
||||
@@ -8,7 +8,17 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func UseUserContext(c *gin.Context, userList types.UserList) (types.UserContext) {
|
||||
func NewHooks(auth *auth.Auth) *Hooks {
|
||||
return &Hooks{
|
||||
Auth: auth,
|
||||
}
|
||||
}
|
||||
|
||||
type Hooks struct {
|
||||
Auth *auth.Auth
|
||||
}
|
||||
|
||||
func (hooks *Hooks) UseUserContext(c *gin.Context) (types.UserContext) {
|
||||
session := sessions.Default(c)
|
||||
cookie := session.Get("tinyauth")
|
||||
|
||||
@@ -28,7 +38,7 @@ func UseUserContext(c *gin.Context, userList types.UserList) (types.UserContext)
|
||||
}
|
||||
}
|
||||
|
||||
user := auth.FindUser(userList, username)
|
||||
user := hooks.Auth.GetUser(username)
|
||||
|
||||
if user == nil {
|
||||
return types.UserContext{
|
||||
|
||||
@@ -14,9 +14,7 @@ type User struct {
|
||||
Password string
|
||||
}
|
||||
|
||||
type UserList struct {
|
||||
Users []User
|
||||
}
|
||||
type Users []User
|
||||
|
||||
type Config struct {
|
||||
Port int `validate:"number" mapstructure:"port"`
|
||||
@@ -31,4 +29,12 @@ type Config struct {
|
||||
type UserContext struct {
|
||||
Username string
|
||||
IsLoggedIn bool
|
||||
}
|
||||
|
||||
type APIConfig struct {
|
||||
Port int
|
||||
Address string
|
||||
Secret string
|
||||
AppURL string
|
||||
CookieSecure bool
|
||||
}
|
||||
@@ -8,26 +8,26 @@ import (
|
||||
"tinyauth/internal/types"
|
||||
)
|
||||
|
||||
func ParseUsers(users string) (types.UserList, error) {
|
||||
var userList types.UserList
|
||||
userListString := strings.Split(users, ",")
|
||||
func ParseUsers(users string) (types.Users, error) {
|
||||
var usersParsed types.Users
|
||||
userList := strings.Split(users, ",")
|
||||
|
||||
if len(userListString) == 0 {
|
||||
return types.UserList{}, errors.New("invalid user format")
|
||||
if len(userList) == 0 {
|
||||
return types.Users{}, errors.New("invalid user format")
|
||||
}
|
||||
|
||||
for _, user := range userListString {
|
||||
for _, user := range userList {
|
||||
userSplit := strings.Split(user, ":")
|
||||
if len(userSplit) != 2 {
|
||||
return types.UserList{}, errors.New("invalid user format")
|
||||
return types.Users{}, errors.New("invalid user format")
|
||||
}
|
||||
userList.Users = append(userList.Users, types.User{
|
||||
usersParsed = append(usersParsed, types.User{
|
||||
Username: userSplit[0],
|
||||
Password: userSplit[1],
|
||||
})
|
||||
}
|
||||
|
||||
return userList, nil
|
||||
return usersParsed, nil
|
||||
}
|
||||
|
||||
func GetRootURL(urlSrc string) (string, error) {
|
||||
|
||||
Reference in New Issue
Block a user