fix: review comments batch 1

This commit is contained in:
Stavros
2026-05-05 18:43:22 +03:00
parent f3965a7470
commit d47e4d3d79
10 changed files with 131 additions and 88 deletions
+17 -12
View File
@@ -226,17 +226,6 @@ func (controller *UserController) logoutHandler(c *gin.Context) {
return
}
context, err := new(model.UserContext).NewFromGin(c)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to get user context on logout")
c.JSON(500, gin.H{
"status": 500,
"message": "Internal Server Error",
})
return
}
cookie, err := controller.auth.DeleteSession(c, uuid)
if err != nil {
@@ -248,7 +237,14 @@ func (controller *UserController) logoutHandler(c *gin.Context) {
return
}
tlog.AuditLogout(c, context.GetUsername(), context.ProviderName())
context, err := new(model.UserContext).NewFromGin(c)
if err == nil {
tlog.AuditLogout(c, context.GetUsername(), context.ProviderName())
} else {
tlog.App.Warn().Err(err).Msg("Failed to get user context for logout audit, proceeding without username")
tlog.AuditLogout(c, "unknown", "unknown")
}
http.SetCookie(c.Writer, cookie)
@@ -308,6 +304,15 @@ func (controller *UserController) totpHandler(c *gin.Context) {
user := controller.auth.GetLocalUser(context.GetUsername())
if user == nil {
tlog.App.Error().Str("username", context.GetUsername()).Msg("User not found in TOTP handler")
c.JSON(401, gin.H{
"status": 401,
"message": "Unauthorized",
})
return
}
ok := totp.Validate(req.Code, user.TOTPSecret)
if !ok {
+5 -5
View File
@@ -67,7 +67,7 @@ func TestUserController(t *testing.T) {
totpCtx := func(c *gin.Context) {
c.Set("context", &model.UserContext{
Authenticated: true,
Authenticated: false,
Provider: model.ProviderLocal,
Local: &model.LocalContext{
BaseContext: model.BaseContext{
@@ -83,7 +83,7 @@ func TestUserController(t *testing.T) {
totpAttrCtx := func(c *gin.Context) {
c.Set("context", &model.UserContext{
Authenticated: true,
Authenticated: false,
Provider: model.ProviderLocal,
Local: &model.LocalContext{
BaseContext: model.BaseContext{
@@ -141,7 +141,7 @@ func TestUserController(t *testing.T) {
assert.Equal(t, "tinyauth-session", cookie.Name)
assert.True(t, cookie.HttpOnly)
assert.Equal(t, "example.com", cookie.Domain)
assert.Equal(t, 10, cookie.MaxAge)
assert.Equal(t, 9, cookie.MaxAge)
},
},
{
@@ -230,7 +230,7 @@ func TestUserController(t *testing.T) {
assert.Equal(t, "tinyauth-session", cookie.Name)
assert.True(t, cookie.HttpOnly)
assert.Equal(t, "example.com", cookie.Domain)
assert.Equal(t, 3600, cookie.MaxAge) // 1 hour, default for totp pending sessions
assert.Equal(t, 3599, cookie.MaxAge) // 1 hour, default for totp pending sessions
},
},
{
@@ -306,7 +306,7 @@ func TestUserController(t *testing.T) {
assert.Equal(t, "tinyauth-session", totpCookie.Name)
assert.True(t, totpCookie.HttpOnly)
assert.Equal(t, "example.com", totpCookie.Domain)
assert.Equal(t, 10, totpCookie.MaxAge) // should use the regular session expiry time
assert.Equal(t, 9, totpCookie.MaxAge) // should use the regular session expiry time
},
},
{
+9 -11
View File
@@ -70,20 +70,18 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
if err == nil {
userContext, cookie, err := m.cookieAuth(c.Request.Context(), uuid)
if err != nil {
tlog.App.Error().Msgf("Error authenticating session cookie: %v", err)
if err == nil {
if cookie != nil {
http.SetCookie(c.Writer, cookie)
}
tlog.App.Trace().Msgf("Authenticated user from session cookie: %s", userContext.GetUsername())
c.Set("context", userContext)
c.Next()
return
} else {
tlog.App.Error().Msgf("Error authenticating session cookie: %v", err)
}
if cookie != nil {
http.SetCookie(c.Writer, cookie)
}
tlog.App.Trace().Msgf("Authenticated user from session cookie: %s", userContext.GetUsername())
c.Set("context", userContext)
c.Next()
return
}
username, password, ok := c.Request.BasicAuth()
@@ -253,6 +253,18 @@ func TestContextMiddleware(t *testing.T) {
req.Header.Set("Authorization", basicAuthHeader("totpuser", "password"))
userCtx, _ := args.do(req)
require.NotNil(t, userCtx)
assert.Equal(t, "testuser", userCtx.GetUsername())
assert.True(t, userCtx.Authenticated)
},
},
{
description: "Ensure fallback to basic auth when cookie is missing",
run: func(t *testing.T, args runArgs) {
req := httptest.NewRequest("GET", "/api/test", nil)
req.Header.Set("Authorization", basicAuthHeader("testuser", "password"))
userCtx, _ := args.do(req)
require.NotNil(t, userCtx)
assert.Equal(t, "testuser", userCtx.GetUsername())
assert.True(t, userCtx.Authenticated)
+15 -6
View File
@@ -80,16 +80,24 @@ func (c *UserContext) NewFromGin(ginctx *gin.Context) (*UserContext, error) {
userContext, ok := userContextValue.(*UserContext)
if !ok {
if !ok || userContext == nil {
return nil, errors.New("invalid user context type")
}
if userContext.LDAP == nil && userContext.Local == nil && userContext.OAuth == nil {
return nil, errors.New("incomplete user context")
}
*c = *userContext
return c, nil
}
// Compatability layer until we get an excuse to drop in database migrations
func (c *UserContext) NewFromSession(session *repository.Session) (*UserContext, error) {
*c = UserContext{
Authenticated: !session.TotpPending,
}
switch session.Provider {
case "local":
c.Provider = ProviderLocal
@@ -119,17 +127,18 @@ func (c *UserContext) NewFromSession(session *repository.Session) (*UserContext,
Name: session.Name,
Email: session.Email,
},
Groups: strings.Split(session.OAuthGroups, ","),
Groups: func() []string {
if session.OAuthGroups == "" {
return nil
}
return strings.Split(session.OAuthGroups, ",")
}(),
Sub: session.OAuthSub,
DisplayName: session.OAuthName,
ID: session.Provider,
}
}
if !session.TotpPending {
c.Authenticated = true
}
return c, nil
}
+46 -34
View File
@@ -6,6 +6,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository"
)
@@ -22,47 +23,48 @@ func TestContext(t *testing.T) {
tests := []struct {
description string
context *model.UserContext
run func(*model.UserContext) any
run func(*testing.T, *model.UserContext) any
expected any
}{
{
description: "IsAuthenticated reflects Authenticated field",
context: &model.UserContext{Authenticated: true},
run: func(c *model.UserContext) any { return c.IsAuthenticated() },
run: func(t *testing.T, c *model.UserContext) any { return c.IsAuthenticated() },
expected: true,
},
{
description: "IsLocal returns true for ProviderLocal",
context: &model.UserContext{Provider: model.ProviderLocal},
run: func(c *model.UserContext) any { return c.IsLocal() },
run: func(t *testing.T, c *model.UserContext) any { return c.IsLocal() },
expected: true,
},
{
description: "IsOAuth returns true for ProviderOAuth",
context: &model.UserContext{Provider: model.ProviderOAuth},
run: func(c *model.UserContext) any { return c.IsOAuth() },
run: func(t *testing.T, c *model.UserContext) any { return c.IsOAuth() },
expected: true,
},
{
description: "IsLDAP returns true for ProviderLDAP",
context: &model.UserContext{Provider: model.ProviderLDAP},
run: func(c *model.UserContext) any { return c.IsLDAP() },
run: func(t *testing.T, c *model.UserContext) any { return c.IsLDAP() },
expected: true,
},
{
description: "IsBasicAuth returns true for ProviderBasicAuth",
context: &model.UserContext{Provider: model.ProviderBasicAuth},
run: func(c *model.UserContext) any { return c.IsBasicAuth() },
run: func(t *testing.T, c *model.UserContext) any { return c.IsBasicAuth() },
expected: true,
},
{
description: "NewFromSession local session is authenticated and ProviderLocal",
context: &model.UserContext{},
run: func(c *model.UserContext) any {
got, _ := c.NewFromSession(&repository.Session{
run: func(t *testing.T, c *model.UserContext) any {
got, err := c.NewFromSession(&repository.Session{
Username: "alice", Email: "alice@example.com", Name: "Alice",
Provider: "local",
})
require.NoError(t, err)
return [2]any{got.Provider, got.Authenticated}
},
expected: [2]any{model.ProviderLocal, true},
@@ -70,10 +72,11 @@ func TestContext(t *testing.T) {
{
description: "NewFromSession local session with TotpPending is not authenticated",
context: &model.UserContext{},
run: func(c *model.UserContext) any {
got, _ := c.NewFromSession(&repository.Session{
run: func(t *testing.T, c *model.UserContext) any {
got, err := c.NewFromSession(&repository.Session{
Username: "bob", Provider: "local", TotpPending: true,
})
require.NoError(t, err)
return got.Authenticated
},
expected: false,
@@ -81,10 +84,11 @@ func TestContext(t *testing.T) {
{
description: "NewFromSession ldap session is ProviderLDAP",
context: &model.UserContext{},
run: func(c *model.UserContext) any {
got, _ := c.NewFromSession(&repository.Session{
run: func(t *testing.T, c *model.UserContext) any {
got, err := c.NewFromSession(&repository.Session{
Username: "carol", Provider: "ldap",
})
require.NoError(t, err)
return got.Provider
},
expected: model.ProviderLDAP,
@@ -92,11 +96,12 @@ func TestContext(t *testing.T) {
{
description: "NewFromSession unknown provider defaults to OAuth and populates oauth fields",
context: &model.UserContext{},
run: func(c *model.UserContext) any {
got, _ := c.NewFromSession(&repository.Session{
run: func(t *testing.T, c *model.UserContext) any {
got, err := c.NewFromSession(&repository.Session{
Username: "dave", Provider: "github",
OAuthGroups: "devs,admins", OAuthSub: "sub-123", OAuthName: "GitHub",
})
require.NoError(t, err)
return [5]any{got.Provider, got.OAuth.ID, got.OAuth.Sub, got.OAuth.DisplayName, got.OAuth.Groups}
},
expected: [5]any{model.ProviderOAuth, "github", "sub-123", "GitHub", []string{"devs", "admins"}},
@@ -107,7 +112,7 @@ func TestContext(t *testing.T) {
Provider: model.ProviderLocal,
Local: &model.LocalContext{BaseContext: model.BaseContext{Username: "alice", Email: "alice@example.com", Name: "Alice"}},
},
run: func(c *model.UserContext) any {
run: func(t *testing.T, c *model.UserContext) any {
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
},
expected: [3]string{"alice", "alice@example.com", "Alice"},
@@ -118,7 +123,7 @@ func TestContext(t *testing.T) {
Provider: model.ProviderBasicAuth,
Local: &model.LocalContext{BaseContext: model.BaseContext{Username: "bob", Email: "bob@example.com", Name: "Bob"}},
},
run: func(c *model.UserContext) any {
run: func(t *testing.T, c *model.UserContext) any {
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
},
expected: [3]string{"bob", "bob@example.com", "Bob"},
@@ -129,7 +134,7 @@ func TestContext(t *testing.T) {
Provider: model.ProviderLDAP,
LDAP: &model.LDAPContext{BaseContext: model.BaseContext{Username: "carol", Email: "carol@example.com", Name: "Carol"}},
},
run: func(c *model.UserContext) any {
run: func(t *testing.T, c *model.UserContext) any {
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
},
expected: [3]string{"carol", "carol@example.com", "Carol"},
@@ -140,7 +145,7 @@ func TestContext(t *testing.T) {
Provider: model.ProviderOAuth,
OAuth: &model.OAuthContext{BaseContext: model.BaseContext{Username: "dave", Email: "dave@example.com", Name: "Dave"}},
},
run: func(c *model.UserContext) any {
run: func(t *testing.T, c *model.UserContext) any {
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
},
expected: [3]string{"dave", "dave@example.com", "Dave"},
@@ -148,19 +153,19 @@ func TestContext(t *testing.T) {
{
description: "ProviderName returns 'local' for ProviderLocal",
context: &model.UserContext{Provider: model.ProviderLocal},
run: func(c *model.UserContext) any { return c.ProviderName() },
run: func(t *testing.T, c *model.UserContext) any { return c.ProviderName() },
expected: "local",
},
{
description: "ProviderName returns 'local' for ProviderBasicAuth",
context: &model.UserContext{Provider: model.ProviderBasicAuth},
run: func(c *model.UserContext) any { return c.ProviderName() },
run: func(t *testing.T, c *model.UserContext) any { return c.ProviderName() },
expected: "local",
},
{
description: "ProviderName returns 'ldap' for ProviderLDAP",
context: &model.UserContext{Provider: model.ProviderLDAP},
run: func(c *model.UserContext) any { return c.ProviderName() },
run: func(t *testing.T, c *model.UserContext) any { return c.ProviderName() },
expected: "ldap",
},
{
@@ -169,7 +174,7 @@ func TestContext(t *testing.T) {
Provider: model.ProviderOAuth,
OAuth: &model.OAuthContext{DisplayName: "GitHub"},
},
run: func(c *model.UserContext) any { return c.ProviderName() },
run: func(t *testing.T, c *model.UserContext) any { return c.ProviderName() },
expected: "GitHub",
},
{
@@ -178,7 +183,7 @@ func TestContext(t *testing.T) {
Provider: model.ProviderLocal,
Local: &model.LocalContext{TOTPPending: true},
},
run: func(c *model.UserContext) any { return c.TOTPPending() },
run: func(t *testing.T, c *model.UserContext) any { return c.TOTPPending() },
expected: true,
},
{
@@ -187,13 +192,13 @@ func TestContext(t *testing.T) {
Provider: model.ProviderLocal,
Local: &model.LocalContext{TOTPPending: false},
},
run: func(c *model.UserContext) any { return c.TOTPPending() },
run: func(t *testing.T, c *model.UserContext) any { return c.TOTPPending() },
expected: false,
},
{
description: "TOTPPending returns false for non-local providers",
context: &model.UserContext{Provider: model.ProviderOAuth, OAuth: &model.OAuthContext{}},
run: func(c *model.UserContext) any { return c.TOTPPending() },
run: func(t *testing.T, c *model.UserContext) any { return c.TOTPPending() },
expected: false,
},
{
@@ -202,28 +207,26 @@ func TestContext(t *testing.T) {
Provider: model.ProviderOAuth,
OAuth: &model.OAuthContext{DisplayName: "Google"},
},
run: func(c *model.UserContext) any { return c.OAuthName() },
run: func(t *testing.T, c *model.UserContext) any { return c.OAuthName() },
expected: "Google",
},
{
description: "OAuthName returns empty string for non-oauth providers",
context: &model.UserContext{Provider: model.ProviderLocal, Local: &model.LocalContext{}},
run: func(c *model.UserContext) any { return c.OAuthName() },
run: func(t *testing.T, c *model.UserContext) any { return c.OAuthName() },
expected: "",
},
{
description: "NewFromGin populates context from gin value",
context: &model.UserContext{},
run: func(c *model.UserContext) any {
run: func(t *testing.T, c *model.UserContext) any {
stored := &model.UserContext{
Authenticated: true,
Provider: model.ProviderLocal,
Local: &model.LocalContext{BaseContext: model.BaseContext{Username: "alice"}},
}
got, err := c.NewFromGin(newGinCtx(stored, true))
if err != nil {
return err.Error()
}
require.NoError(t, err)
return [2]any{got.Authenticated, got.GetUsername()}
},
expected: [2]any{true, "alice"},
@@ -231,7 +234,7 @@ func TestContext(t *testing.T) {
{
description: "NewFromGin returns error when context value is missing",
context: &model.UserContext{},
run: func(c *model.UserContext) any {
run: func(t *testing.T, c *model.UserContext) any {
_, err := c.NewFromGin(newGinCtx(nil, false))
return err.Error()
},
@@ -240,17 +243,26 @@ func TestContext(t *testing.T) {
{
description: "NewFromGin returns error when context value has wrong type",
context: &model.UserContext{},
run: func(c *model.UserContext) any {
run: func(t *testing.T, c *model.UserContext) any {
_, err := c.NewFromGin(newGinCtx("not a user context", true))
return err.Error()
},
expected: "invalid user context type",
},
{
description: "NewFromGin returns an error when context doesn't include user information",
context: &model.UserContext{},
run: func(t *testing.T, c *model.UserContext) any {
_, err := c.NewFromGin(newGinCtx(&model.UserContext{Provider: model.ProviderLocal}, true))
return err.Error()
},
expected: "incomplete user context",
},
}
for _, test := range tests {
t.Run(test.description, func(t *testing.T) {
assert.Equal(t, test.expected, test.run(test.context))
assert.Equal(t, test.expected, test.run(t, test.context))
})
}
}
+7 -5
View File
@@ -120,7 +120,7 @@ func (auth *AuthService) Init() error {
}
func (auth *AuthService) SearchUser(username string) (*model.UserSearch, error) {
if auth.GetLocalUser(username).Username != "" {
if auth.GetLocalUser(username) != nil {
return &model.UserSearch{
Username: username,
Type: model.UserLocal,
@@ -295,6 +295,8 @@ func (auth *AuthService) CreateSession(ctx context.Context, data repository.Sess
expiry = auth.config.SessionExpiry
}
expiresAt := time.Now().Add(time.Duration(expiry) * time.Second)
session := repository.CreateSessionParams{
UUID: uuid.String(),
Username: data.Username,
@@ -303,7 +305,7 @@ func (auth *AuthService) CreateSession(ctx context.Context, data repository.Sess
Provider: data.Provider,
TotpPending: data.TotpPending,
OAuthGroups: data.OAuthGroups,
Expiry: time.Now().Add(time.Duration(expiry) * time.Second).Unix(),
Expiry: expiresAt.Unix(),
CreatedAt: time.Now().Unix(),
OAuthName: data.OAuthName,
OAuthSub: data.OAuthSub,
@@ -320,8 +322,8 @@ func (auth *AuthService) CreateSession(ctx context.Context, data repository.Sess
Value: session.UUID,
Path: "/",
Domain: fmt.Sprintf(".%s", auth.config.CookieDomain),
Expires: time.Now().Add(time.Duration(expiry) * time.Second),
MaxAge: expiry,
Expires: expiresAt,
MaxAge: int(time.Until(expiresAt).Seconds()),
Secure: auth.config.SecureCookie,
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
@@ -374,7 +376,7 @@ func (auth *AuthService) RefreshSession(ctx context.Context, uuid string) (*http
Path: "/",
Domain: fmt.Sprintf(".%s", auth.config.CookieDomain),
Expires: time.Now().Add(time.Duration(newExpiry-currentTime) * time.Second),
MaxAge: auth.config.SessionExpiry,
MaxAge: int(newExpiry - currentTime),
Secure: auth.config.SecureCookie,
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
+4 -3
View File
@@ -5,18 +5,19 @@ import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestReadFile(t *testing.T) {
// Setup
file, err := os.Create("/tmp/tinyauth_test_file")
assert.NoError(t, err)
require.NoError(t, err)
_, err = file.WriteString("file content\n")
assert.NoError(t, err)
require.NoError(t, err)
err = file.Close()
assert.NoError(t, err)
require.NoError(t, err)
defer os.Remove("/tmp/tinyauth_test_file")
// Normal case
+4 -3
View File
@@ -5,19 +5,20 @@ import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/utils"
)
func TestGetSecret(t *testing.T) {
// Setup
file, err := os.Create("/tmp/tinyauth_test_secret")
assert.NoError(t, err)
require.NoError(t, err)
_, err = file.WriteString(" secret \n")
assert.NoError(t, err)
require.NoError(t, err)
err = file.Close()
assert.NoError(t, err)
require.NoError(t, err)
defer os.Remove("/tmp/tinyauth_test_secret")
// Get from config
+12 -9
View File
@@ -5,28 +5,31 @@ import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils"
)
func TestGetUsers(t *testing.T) {
tmpDir := t.TempDir()
hash := "$2a$10$Mz5xhkfSJUtPWkzCd/TdaePh9CaXc5QcGII5wIMPLSR46eTwma30G"
// Setup
file, err := os.Create("/tmp/tinyauth_users_test.txt")
assert.NoError(t, err)
file, err := os.Create(tmpDir + "/tinyauth_users_test.txt")
require.NoError(t, err)
_, err = file.WriteString(" user1:" + hash + " \n user2:" + hash + " ") // Spacing is on purpose
assert.NoError(t, err)
require.NoError(t, err)
err = file.Close()
assert.NoError(t, err)
defer os.Remove("/tmp/tinyauth_users_test.txt")
require.NoError(t, err)
defer os.Remove(tmpDir + "/tinyauth_users_test.txt")
noAttrs := map[string]model.UserAttributes{}
// Test file only
users, err := utils.GetUsers([]string{}, "/tmp/tinyauth_users_test.txt", noAttrs)
users, err := utils.GetUsers([]string{}, tmpDir+"/tinyauth_users_test.txt", noAttrs)
assert.NoError(t, err)
assert.NotNil(t, users)
@@ -47,7 +50,7 @@ func TestGetUsers(t *testing.T) {
assert.Equal(t, "user4", (*users)[1].Username)
// Test both
users, err = utils.GetUsers([]string{"user5:" + hash}, "/tmp/tinyauth_users_test.txt", noAttrs)
users, err = utils.GetUsers([]string{"user5:" + hash}, tmpDir+"/tinyauth_users_test.txt", noAttrs)
assert.NoError(t, err)
@@ -65,7 +68,7 @@ func TestGetUsers(t *testing.T) {
attrs := map[string]model.UserAttributes{
"user1": {Name: "User One", Email: "user1@example.com"},
}
users, err = utils.GetUsers([]string{}, "/tmp/tinyauth_users_test.txt", attrs)
users, err = utils.GetUsers([]string{}, tmpDir+"/tinyauth_users_test.txt", attrs)
assert.NoError(t, err)
assert.Len(t, *users, 2)
@@ -87,7 +90,7 @@ func TestGetUsers(t *testing.T) {
assert.Nil(t, users)
// Test non-existent file
users, err = utils.GetUsers([]string{}, "/tmp/non_existent_file.txt", noAttrs)
users, err = utils.GetUsers([]string{}, tmpDir+"/non_existent_file.txt", noAttrs)
assert.ErrorContains(t, err, "no such file or directory")
assert.Nil(t, users)