refactor: simplify user parsing (#571)

This commit is contained in:
Stavros
2026-01-08 16:03:37 +02:00
committed by GitHub
parent 454612226b
commit e3f92ce4fc
7 changed files with 59 additions and 61 deletions

View File

@@ -2,7 +2,6 @@ package bootstrap
import ( import (
"fmt" "fmt"
"strings"
"github.com/steveiliop56/tinyauth/internal/controller" "github.com/steveiliop56/tinyauth/internal/controller"
"github.com/steveiliop56/tinyauth/internal/middleware" "github.com/steveiliop56/tinyauth/internal/middleware"
@@ -15,7 +14,7 @@ func (app *BootstrapApp) setupRouter() (*gin.Engine, error) {
engine.Use(gin.Recovery()) engine.Use(gin.Recovery())
if len(app.config.Server.TrustedProxies) > 0 { if len(app.config.Server.TrustedProxies) > 0 {
err := engine.SetTrustedProxies(strings.Split(app.config.Server.TrustedProxies, ",")) err := engine.SetTrustedProxies(app.config.Server.TrustedProxies)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to set trusted proxies: %w", err) return nil, fmt.Errorf("failed to set trusted proxies: %w", err)

View File

@@ -33,15 +33,15 @@ type Config struct {
} }
type ServerConfig struct { type ServerConfig struct {
Port int `description:"The port on which the server listens." yaml:"port"` Port int `description:"The port on which the server listens." yaml:"port"`
Address string `description:"The address on which the server listens." yaml:"address"` Address string `description:"The address on which the server listens." yaml:"address"`
SocketPath string `description:"The path to the Unix socket." yaml:"socketPath"` SocketPath string `description:"The path to the Unix socket." yaml:"socketPath"`
TrustedProxies string `description:"Comma-separated list of trusted proxy addresses." yaml:"trustedProxies"` TrustedProxies []string `description:"Comma-separated list of trusted proxy addresses." yaml:"trustedProxies"`
} }
type AuthConfig struct { type AuthConfig struct {
IP IPConfig `description:"IP whitelisting config options." yaml:"ip"` IP IPConfig `description:"IP whitelisting config options." yaml:"ip"`
Users string `description:"Comma-separated list of users (username:hashed_password)." yaml:"users"` Users []string `description:"Comma-separated list of users (username:hashed_password)." yaml:"users"`
UsersFile string `description:"Path to the users file." yaml:"usersFile"` UsersFile string `description:"Path to the users file." yaml:"usersFile"`
SecureCookie bool `description:"Enable secure cookies." yaml:"secureCookie"` SecureCookie bool `description:"Enable secure cookies." yaml:"secureCookie"`
SessionExpiry int `description:"Session expiry time in seconds." yaml:"sessionExpiry"` SessionExpiry int `description:"Session expiry time in seconds." yaml:"sessionExpiry"`
@@ -56,7 +56,7 @@ type IPConfig struct {
} }
type OAuthConfig struct { type OAuthConfig struct {
Whitelist string `description:"Comma-separated list of allowed OAuth domains." yaml:"whitelist"` Whitelist []string `description:"Comma-separated list of allowed OAuth domains." yaml:"whitelist"`
AutoRedirect string `description:"The OAuth provider to use for automatic redirection." yaml:"autoRedirect"` AutoRedirect string `description:"The OAuth provider to use for automatic redirection." yaml:"autoRedirect"`
Providers map[string]OAuthServiceConfig `description:"OAuth providers configuration." yaml:"providers"` Providers map[string]OAuthServiceConfig `description:"OAuth providers configuration." yaml:"providers"`
} }

View File

@@ -57,7 +57,7 @@ func setupProxyController(t *testing.T, middlewares *[]gin.HandlerFunc) (*gin.En
Password: "$2a$10$ne6z693sTgzT3ePoQ05PgOecUHnBjM7sSNj6M.l5CLUP.f6NyCnt.", // test Password: "$2a$10$ne6z693sTgzT3ePoQ05PgOecUHnBjM7sSNj6M.l5CLUP.f6NyCnt.", // test
}, },
}, },
OauthWhitelist: "", OauthWhitelist: []string{},
SessionExpiry: 3600, SessionExpiry: 3600,
SessionMaxLifetime: 0, SessionMaxLifetime: 0,
SecureCookie: false, SecureCookie: false,

View File

@@ -60,7 +60,7 @@ func setupUserController(t *testing.T, middlewares *[]gin.HandlerFunc) (*gin.Eng
TotpSecret: totpSecret, TotpSecret: totpSecret,
}, },
}, },
OauthWhitelist: "", OauthWhitelist: []string{},
SessionExpiry: 3600, SessionExpiry: 3600,
SessionMaxLifetime: 0, SessionMaxLifetime: 0,
SecureCookie: false, SecureCookie: false,

View File

@@ -27,7 +27,7 @@ type LoginAttempt struct {
type AuthServiceConfig struct { type AuthServiceConfig struct {
Users []config.User Users []config.User
OauthWhitelist string OauthWhitelist []string
SessionExpiry int SessionExpiry int
SessionMaxLifetime int SessionMaxLifetime int
SecureCookie bool SecureCookie bool
@@ -187,7 +187,7 @@ func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) {
} }
func (auth *AuthService) IsEmailWhitelisted(email string) bool { func (auth *AuthService) IsEmailWhitelisted(email string) bool {
return utils.CheckFilter(auth.config.OauthWhitelist, email) return utils.CheckFilter(strings.Join(auth.config.OauthWhitelist, ","), email)
} }
func (auth *AuthService) CreateSessionCookie(c *gin.Context, data *config.SessionCookie) error { func (auth *AuthService) CreateSessionCookie(c *gin.Context, data *config.SessionCookie) error {

View File

@@ -7,22 +7,14 @@ import (
"github.com/steveiliop56/tinyauth/internal/config" "github.com/steveiliop56/tinyauth/internal/config"
) )
func ParseUsers(users string) ([]config.User, error) { func ParseUsers(usersStr []string) ([]config.User, error) {
var usersParsed []config.User var users []config.User
users = strings.TrimSpace(users) if len(usersStr) == 0 {
if users == "" {
return []config.User{}, nil return []config.User{}, nil
} }
userList := strings.Split(users, ",") for _, user := range usersStr {
if len(userList) == 0 {
return []config.User{}, errors.New("invalid user format")
}
for _, user := range userList {
if strings.TrimSpace(user) == "" { if strings.TrimSpace(user) == "" {
continue continue
} }
@@ -30,64 +22,71 @@ func ParseUsers(users string) ([]config.User, error) {
if err != nil { if err != nil {
return []config.User{}, err return []config.User{}, err
} }
usersParsed = append(usersParsed, parsed) users = append(users, parsed)
} }
return usersParsed, nil return users, nil
} }
func GetUsers(conf string, file string) ([]config.User, error) { func GetUsers(usersCfg []string, usersPath string) ([]config.User, error) {
var users string var usersStr []string
if conf == "" && file == "" { if len(usersCfg) == 0 && usersPath == "" {
return []config.User{}, nil return []config.User{}, nil
} }
if conf != "" { if len(usersCfg) > 0 {
users += conf usersStr = append(usersStr, usersCfg...)
} }
if file != "" { if usersPath != "" {
contents, err := ReadFile(file) contents, err := ReadFile(usersPath)
if err != nil { if err != nil {
return []config.User{}, err return []config.User{}, err
} }
if users != "" {
users += "," lines := strings.SplitSeq(contents, "\n")
for line := range lines {
lineTrimmed := strings.TrimSpace(line)
if lineTrimmed == "" {
continue
}
usersStr = append(usersStr, lineTrimmed)
} }
users += ParseFileToLine(contents)
} }
return ParseUsers(users) return ParseUsers(usersStr)
} }
func ParseUser(user string) (config.User, error) { func ParseUser(userStr string) (config.User, error) {
if strings.Contains(user, "$$") { if strings.Contains(userStr, "$$") {
user = strings.ReplaceAll(user, "$$", "$") userStr = strings.ReplaceAll(userStr, "$$", "$")
} }
userSplit := strings.Split(user, ":") parts := strings.SplitN(userStr, ":", 4)
if len(userSplit) < 2 || len(userSplit) > 3 { if len(parts) < 2 || len(parts) > 3 {
return config.User{}, errors.New("invalid user format") return config.User{}, errors.New("invalid user format")
} }
for _, userPart := range userSplit { for i, part := range parts {
if strings.TrimSpace(userPart) == "" { trimmed := strings.TrimSpace(part)
if trimmed == "" {
return config.User{}, errors.New("invalid user format") return config.User{}, errors.New("invalid user format")
} }
parts[i] = trimmed
} }
if len(userSplit) == 2 { user := config.User{
return config.User{ Username: parts[0],
Username: strings.TrimSpace(userSplit[0]), Password: parts[1],
Password: strings.TrimSpace(userSplit[1]),
}, nil
} }
return config.User{ if len(parts) == 3 {
Username: strings.TrimSpace(userSplit[0]), user.TotpSecret = parts[2]
Password: strings.TrimSpace(userSplit[1]), }
TotpSecret: strings.TrimSpace(userSplit[2]),
}, nil return user, nil
} }

View File

@@ -22,7 +22,7 @@ func TestGetUsers(t *testing.T) {
defer os.Remove("/tmp/tinyauth_users_test.txt") defer os.Remove("/tmp/tinyauth_users_test.txt")
// Test file // Test file
users, err := utils.GetUsers("", "/tmp/tinyauth_users_test.txt") users, err := utils.GetUsers([]string{}, "/tmp/tinyauth_users_test.txt")
assert.NilError(t, err) assert.NilError(t, err)
@@ -34,7 +34,7 @@ func TestGetUsers(t *testing.T) {
assert.Equal(t, "$2a$10$Mz5xhkfSJUtPWkzCd/TdaePh9CaXc5QcGII5wIMPLSR46eTwma30G", users[1].Password) assert.Equal(t, "$2a$10$Mz5xhkfSJUtPWkzCd/TdaePh9CaXc5QcGII5wIMPLSR46eTwma30G", users[1].Password)
// Test config // Test config
users, err = utils.GetUsers("user3:$2a$10$Mz5xhkfSJUtPWkzCd/TdaePh9CaXc5QcGII5wIMPLSR46eTwma30G,user4:$2a$10$Mz5xhkfSJUtPWkzCd/TdaePh9CaXc5QcGII5wIMPLSR46eTwma30G", "") users, err = utils.GetUsers([]string{"user3:$2a$10$Mz5xhkfSJUtPWkzCd/TdaePh9CaXc5QcGII5wIMPLSR46eTwma30G", "user4:$2a$10$Mz5xhkfSJUtPWkzCd/TdaePh9CaXc5QcGII5wIMPLSR46eTwma30G"}, "")
assert.NilError(t, err) assert.NilError(t, err)
@@ -46,7 +46,7 @@ func TestGetUsers(t *testing.T) {
assert.Equal(t, "$2a$10$Mz5xhkfSJUtPWkzCd/TdaePh9CaXc5QcGII5wIMPLSR46eTwma30G", users[1].Password) assert.Equal(t, "$2a$10$Mz5xhkfSJUtPWkzCd/TdaePh9CaXc5QcGII5wIMPLSR46eTwma30G", users[1].Password)
// Test both // Test both
users, err = utils.GetUsers("user5:$2a$10$Mz5xhkfSJUtPWkzCd/TdaePh9CaXc5QcGII5wIMPLSR46eTwma30G", "/tmp/tinyauth_users_test.txt") users, err = utils.GetUsers([]string{"user5:$2a$10$Mz5xhkfSJUtPWkzCd/TdaePh9CaXc5QcGII5wIMPLSR46eTwma30G"}, "/tmp/tinyauth_users_test.txt")
assert.NilError(t, err) assert.NilError(t, err)
@@ -60,14 +60,14 @@ func TestGetUsers(t *testing.T) {
assert.Equal(t, "$2a$10$Mz5xhkfSJUtPWkzCd/TdaePh9CaXc5QcGII5wIMPLSR46eTwma30G", users[2].Password) assert.Equal(t, "$2a$10$Mz5xhkfSJUtPWkzCd/TdaePh9CaXc5QcGII5wIMPLSR46eTwma30G", users[2].Password)
// Test empty // Test empty
users, err = utils.GetUsers("", "") users, err = utils.GetUsers([]string{}, "")
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, 0, len(users)) assert.Equal(t, 0, len(users))
// Test non-existent file // Test non-existent file
users, err = utils.GetUsers("", "/tmp/non_existent_file.txt") users, err = utils.GetUsers([]string{}, "/tmp/non_existent_file.txt")
assert.ErrorContains(t, err, "no such file or directory") assert.ErrorContains(t, err, "no such file or directory")
@@ -76,7 +76,7 @@ func TestGetUsers(t *testing.T) {
func TestParseUsers(t *testing.T) { func TestParseUsers(t *testing.T) {
// Valid users // Valid users
users, err := utils.ParseUsers("user1:$2a$10$Mz5xhkfSJUtPWkzCd/TdaePh9CaXc5QcGII5wIMPLSR46eTwma30G,user2:$2a$10$Mz5xhkfSJUtPWkzCd/TdaePh9CaXc5QcGII5wIMPLSR46eTwma30G:ABCDEF") // user2 has TOTP users, err := utils.ParseUsers([]string{"user1:$2a$10$Mz5xhkfSJUtPWkzCd/TdaePh9CaXc5QcGII5wIMPLSR46eTwma30G", "user2:$2a$10$Mz5xhkfSJUtPWkzCd/TdaePh9CaXc5QcGII5wIMPLSR46eTwma30G:ABCDEF"}) // user2 has TOTP
assert.NilError(t, err) assert.NilError(t, err)
@@ -90,7 +90,7 @@ func TestParseUsers(t *testing.T) {
assert.Equal(t, "ABCDEF", users[1].TotpSecret) assert.Equal(t, "ABCDEF", users[1].TotpSecret)
// Valid weirdly spaced users // Valid weirdly spaced users
users, err = utils.ParseUsers(" user1:$2a$10$Mz5xhkfSJUtPWkzCd/TdaePh9CaXc5QcGII5wIMPLSR46eTwma30G , user2:$2a$10$Mz5xhkfSJUtPWkzCd/TdaePh9CaXc5QcGII5wIMPLSR46eTwma30G:ABCDEF ") // Spacing is on purpose users, err = utils.ParseUsers([]string{" user1:$2a$10$Mz5xhkfSJUtPWkzCd/TdaePh9CaXc5QcGII5wIMPLSR46eTwma30G ", " user2:$2a$10$Mz5xhkfSJUtPWkzCd/TdaePh9CaXc5QcGII5wIMPLSR46eTwma30G:ABCDEF "}) // Spacing is on purpose
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, 2, len(users)) assert.Equal(t, 2, len(users))