mirror of
https://github.com/steveiliop56/tinyauth.git
synced 2026-05-05 20:08:12 +00:00
fix: review comments batch 1
This commit is contained in:
@@ -226,17 +226,6 @@ func (controller *UserController) logoutHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
context, err := new(model.UserContext).NewFromGin(c)
|
||||
|
||||
if err != nil {
|
||||
tlog.App.Error().Err(err).Msg("Failed to get user context on logout")
|
||||
c.JSON(500, gin.H{
|
||||
"status": 500,
|
||||
"message": "Internal Server Error",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
cookie, err := controller.auth.DeleteSession(c, uuid)
|
||||
|
||||
if err != nil {
|
||||
@@ -248,7 +237,14 @@ func (controller *UserController) logoutHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
tlog.AuditLogout(c, context.GetUsername(), context.ProviderName())
|
||||
context, err := new(model.UserContext).NewFromGin(c)
|
||||
|
||||
if err == nil {
|
||||
tlog.AuditLogout(c, context.GetUsername(), context.ProviderName())
|
||||
} else {
|
||||
tlog.App.Warn().Err(err).Msg("Failed to get user context for logout audit, proceeding without username")
|
||||
tlog.AuditLogout(c, "unknown", "unknown")
|
||||
}
|
||||
|
||||
http.SetCookie(c.Writer, cookie)
|
||||
|
||||
@@ -308,6 +304,15 @@ func (controller *UserController) totpHandler(c *gin.Context) {
|
||||
|
||||
user := controller.auth.GetLocalUser(context.GetUsername())
|
||||
|
||||
if user == nil {
|
||||
tlog.App.Error().Str("username", context.GetUsername()).Msg("User not found in TOTP handler")
|
||||
c.JSON(401, gin.H{
|
||||
"status": 401,
|
||||
"message": "Unauthorized",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
ok := totp.Validate(req.Code, user.TOTPSecret)
|
||||
|
||||
if !ok {
|
||||
|
||||
@@ -67,7 +67,7 @@ func TestUserController(t *testing.T) {
|
||||
|
||||
totpCtx := func(c *gin.Context) {
|
||||
c.Set("context", &model.UserContext{
|
||||
Authenticated: true,
|
||||
Authenticated: false,
|
||||
Provider: model.ProviderLocal,
|
||||
Local: &model.LocalContext{
|
||||
BaseContext: model.BaseContext{
|
||||
@@ -83,7 +83,7 @@ func TestUserController(t *testing.T) {
|
||||
|
||||
totpAttrCtx := func(c *gin.Context) {
|
||||
c.Set("context", &model.UserContext{
|
||||
Authenticated: true,
|
||||
Authenticated: false,
|
||||
Provider: model.ProviderLocal,
|
||||
Local: &model.LocalContext{
|
||||
BaseContext: model.BaseContext{
|
||||
@@ -141,7 +141,7 @@ func TestUserController(t *testing.T) {
|
||||
assert.Equal(t, "tinyauth-session", cookie.Name)
|
||||
assert.True(t, cookie.HttpOnly)
|
||||
assert.Equal(t, "example.com", cookie.Domain)
|
||||
assert.Equal(t, 10, cookie.MaxAge)
|
||||
assert.Equal(t, 9, cookie.MaxAge)
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -230,7 +230,7 @@ func TestUserController(t *testing.T) {
|
||||
assert.Equal(t, "tinyauth-session", cookie.Name)
|
||||
assert.True(t, cookie.HttpOnly)
|
||||
assert.Equal(t, "example.com", cookie.Domain)
|
||||
assert.Equal(t, 3600, cookie.MaxAge) // 1 hour, default for totp pending sessions
|
||||
assert.Equal(t, 3599, cookie.MaxAge) // 1 hour, default for totp pending sessions
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -306,7 +306,7 @@ func TestUserController(t *testing.T) {
|
||||
assert.Equal(t, "tinyauth-session", totpCookie.Name)
|
||||
assert.True(t, totpCookie.HttpOnly)
|
||||
assert.Equal(t, "example.com", totpCookie.Domain)
|
||||
assert.Equal(t, 10, totpCookie.MaxAge) // should use the regular session expiry time
|
||||
assert.Equal(t, 9, totpCookie.MaxAge) // should use the regular session expiry time
|
||||
},
|
||||
},
|
||||
{
|
||||
|
||||
@@ -70,20 +70,18 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
|
||||
if err == nil {
|
||||
userContext, cookie, err := m.cookieAuth(c.Request.Context(), uuid)
|
||||
|
||||
if err != nil {
|
||||
tlog.App.Error().Msgf("Error authenticating session cookie: %v", err)
|
||||
if err == nil {
|
||||
if cookie != nil {
|
||||
http.SetCookie(c.Writer, cookie)
|
||||
}
|
||||
|
||||
tlog.App.Trace().Msgf("Authenticated user from session cookie: %s", userContext.GetUsername())
|
||||
c.Set("context", userContext)
|
||||
c.Next()
|
||||
return
|
||||
} else {
|
||||
tlog.App.Error().Msgf("Error authenticating session cookie: %v", err)
|
||||
}
|
||||
|
||||
if cookie != nil {
|
||||
http.SetCookie(c.Writer, cookie)
|
||||
}
|
||||
|
||||
tlog.App.Trace().Msgf("Authenticated user from session cookie: %s", userContext.GetUsername())
|
||||
c.Set("context", userContext)
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
username, password, ok := c.Request.BasicAuth()
|
||||
|
||||
@@ -253,6 +253,18 @@ func TestContextMiddleware(t *testing.T) {
|
||||
req.Header.Set("Authorization", basicAuthHeader("totpuser", "password"))
|
||||
userCtx, _ := args.do(req)
|
||||
|
||||
require.NotNil(t, userCtx)
|
||||
assert.Equal(t, "testuser", userCtx.GetUsername())
|
||||
assert.True(t, userCtx.Authenticated)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Ensure fallback to basic auth when cookie is missing",
|
||||
run: func(t *testing.T, args runArgs) {
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
req.Header.Set("Authorization", basicAuthHeader("testuser", "password"))
|
||||
userCtx, _ := args.do(req)
|
||||
|
||||
require.NotNil(t, userCtx)
|
||||
assert.Equal(t, "testuser", userCtx.GetUsername())
|
||||
assert.True(t, userCtx.Authenticated)
|
||||
|
||||
@@ -80,16 +80,24 @@ func (c *UserContext) NewFromGin(ginctx *gin.Context) (*UserContext, error) {
|
||||
|
||||
userContext, ok := userContextValue.(*UserContext)
|
||||
|
||||
if !ok {
|
||||
if !ok || userContext == nil {
|
||||
return nil, errors.New("invalid user context type")
|
||||
}
|
||||
|
||||
if userContext.LDAP == nil && userContext.Local == nil && userContext.OAuth == nil {
|
||||
return nil, errors.New("incomplete user context")
|
||||
}
|
||||
|
||||
*c = *userContext
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// Compatability layer until we get an excuse to drop in database migrations
|
||||
func (c *UserContext) NewFromSession(session *repository.Session) (*UserContext, error) {
|
||||
*c = UserContext{
|
||||
Authenticated: !session.TotpPending,
|
||||
}
|
||||
|
||||
switch session.Provider {
|
||||
case "local":
|
||||
c.Provider = ProviderLocal
|
||||
@@ -119,17 +127,18 @@ func (c *UserContext) NewFromSession(session *repository.Session) (*UserContext,
|
||||
Name: session.Name,
|
||||
Email: session.Email,
|
||||
},
|
||||
Groups: strings.Split(session.OAuthGroups, ","),
|
||||
Groups: func() []string {
|
||||
if session.OAuthGroups == "" {
|
||||
return nil
|
||||
}
|
||||
return strings.Split(session.OAuthGroups, ",")
|
||||
}(),
|
||||
Sub: session.OAuthSub,
|
||||
DisplayName: session.OAuthName,
|
||||
ID: session.Provider,
|
||||
}
|
||||
}
|
||||
|
||||
if !session.TotpPending {
|
||||
c.Authenticated = true
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||
)
|
||||
@@ -22,47 +23,48 @@ func TestContext(t *testing.T) {
|
||||
tests := []struct {
|
||||
description string
|
||||
context *model.UserContext
|
||||
run func(*model.UserContext) any
|
||||
run func(*testing.T, *model.UserContext) any
|
||||
expected any
|
||||
}{
|
||||
{
|
||||
description: "IsAuthenticated reflects Authenticated field",
|
||||
context: &model.UserContext{Authenticated: true},
|
||||
run: func(c *model.UserContext) any { return c.IsAuthenticated() },
|
||||
run: func(t *testing.T, c *model.UserContext) any { return c.IsAuthenticated() },
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
description: "IsLocal returns true for ProviderLocal",
|
||||
context: &model.UserContext{Provider: model.ProviderLocal},
|
||||
run: func(c *model.UserContext) any { return c.IsLocal() },
|
||||
run: func(t *testing.T, c *model.UserContext) any { return c.IsLocal() },
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
description: "IsOAuth returns true for ProviderOAuth",
|
||||
context: &model.UserContext{Provider: model.ProviderOAuth},
|
||||
run: func(c *model.UserContext) any { return c.IsOAuth() },
|
||||
run: func(t *testing.T, c *model.UserContext) any { return c.IsOAuth() },
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
description: "IsLDAP returns true for ProviderLDAP",
|
||||
context: &model.UserContext{Provider: model.ProviderLDAP},
|
||||
run: func(c *model.UserContext) any { return c.IsLDAP() },
|
||||
run: func(t *testing.T, c *model.UserContext) any { return c.IsLDAP() },
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
description: "IsBasicAuth returns true for ProviderBasicAuth",
|
||||
context: &model.UserContext{Provider: model.ProviderBasicAuth},
|
||||
run: func(c *model.UserContext) any { return c.IsBasicAuth() },
|
||||
run: func(t *testing.T, c *model.UserContext) any { return c.IsBasicAuth() },
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
description: "NewFromSession local session is authenticated and ProviderLocal",
|
||||
context: &model.UserContext{},
|
||||
run: func(c *model.UserContext) any {
|
||||
got, _ := c.NewFromSession(&repository.Session{
|
||||
run: func(t *testing.T, c *model.UserContext) any {
|
||||
got, err := c.NewFromSession(&repository.Session{
|
||||
Username: "alice", Email: "alice@example.com", Name: "Alice",
|
||||
Provider: "local",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
return [2]any{got.Provider, got.Authenticated}
|
||||
},
|
||||
expected: [2]any{model.ProviderLocal, true},
|
||||
@@ -70,10 +72,11 @@ func TestContext(t *testing.T) {
|
||||
{
|
||||
description: "NewFromSession local session with TotpPending is not authenticated",
|
||||
context: &model.UserContext{},
|
||||
run: func(c *model.UserContext) any {
|
||||
got, _ := c.NewFromSession(&repository.Session{
|
||||
run: func(t *testing.T, c *model.UserContext) any {
|
||||
got, err := c.NewFromSession(&repository.Session{
|
||||
Username: "bob", Provider: "local", TotpPending: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
return got.Authenticated
|
||||
},
|
||||
expected: false,
|
||||
@@ -81,10 +84,11 @@ func TestContext(t *testing.T) {
|
||||
{
|
||||
description: "NewFromSession ldap session is ProviderLDAP",
|
||||
context: &model.UserContext{},
|
||||
run: func(c *model.UserContext) any {
|
||||
got, _ := c.NewFromSession(&repository.Session{
|
||||
run: func(t *testing.T, c *model.UserContext) any {
|
||||
got, err := c.NewFromSession(&repository.Session{
|
||||
Username: "carol", Provider: "ldap",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
return got.Provider
|
||||
},
|
||||
expected: model.ProviderLDAP,
|
||||
@@ -92,11 +96,12 @@ func TestContext(t *testing.T) {
|
||||
{
|
||||
description: "NewFromSession unknown provider defaults to OAuth and populates oauth fields",
|
||||
context: &model.UserContext{},
|
||||
run: func(c *model.UserContext) any {
|
||||
got, _ := c.NewFromSession(&repository.Session{
|
||||
run: func(t *testing.T, c *model.UserContext) any {
|
||||
got, err := c.NewFromSession(&repository.Session{
|
||||
Username: "dave", Provider: "github",
|
||||
OAuthGroups: "devs,admins", OAuthSub: "sub-123", OAuthName: "GitHub",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
return [5]any{got.Provider, got.OAuth.ID, got.OAuth.Sub, got.OAuth.DisplayName, got.OAuth.Groups}
|
||||
},
|
||||
expected: [5]any{model.ProviderOAuth, "github", "sub-123", "GitHub", []string{"devs", "admins"}},
|
||||
@@ -107,7 +112,7 @@ func TestContext(t *testing.T) {
|
||||
Provider: model.ProviderLocal,
|
||||
Local: &model.LocalContext{BaseContext: model.BaseContext{Username: "alice", Email: "alice@example.com", Name: "Alice"}},
|
||||
},
|
||||
run: func(c *model.UserContext) any {
|
||||
run: func(t *testing.T, c *model.UserContext) any {
|
||||
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
|
||||
},
|
||||
expected: [3]string{"alice", "alice@example.com", "Alice"},
|
||||
@@ -118,7 +123,7 @@ func TestContext(t *testing.T) {
|
||||
Provider: model.ProviderBasicAuth,
|
||||
Local: &model.LocalContext{BaseContext: model.BaseContext{Username: "bob", Email: "bob@example.com", Name: "Bob"}},
|
||||
},
|
||||
run: func(c *model.UserContext) any {
|
||||
run: func(t *testing.T, c *model.UserContext) any {
|
||||
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
|
||||
},
|
||||
expected: [3]string{"bob", "bob@example.com", "Bob"},
|
||||
@@ -129,7 +134,7 @@ func TestContext(t *testing.T) {
|
||||
Provider: model.ProviderLDAP,
|
||||
LDAP: &model.LDAPContext{BaseContext: model.BaseContext{Username: "carol", Email: "carol@example.com", Name: "Carol"}},
|
||||
},
|
||||
run: func(c *model.UserContext) any {
|
||||
run: func(t *testing.T, c *model.UserContext) any {
|
||||
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
|
||||
},
|
||||
expected: [3]string{"carol", "carol@example.com", "Carol"},
|
||||
@@ -140,7 +145,7 @@ func TestContext(t *testing.T) {
|
||||
Provider: model.ProviderOAuth,
|
||||
OAuth: &model.OAuthContext{BaseContext: model.BaseContext{Username: "dave", Email: "dave@example.com", Name: "Dave"}},
|
||||
},
|
||||
run: func(c *model.UserContext) any {
|
||||
run: func(t *testing.T, c *model.UserContext) any {
|
||||
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
|
||||
},
|
||||
expected: [3]string{"dave", "dave@example.com", "Dave"},
|
||||
@@ -148,19 +153,19 @@ func TestContext(t *testing.T) {
|
||||
{
|
||||
description: "ProviderName returns 'local' for ProviderLocal",
|
||||
context: &model.UserContext{Provider: model.ProviderLocal},
|
||||
run: func(c *model.UserContext) any { return c.ProviderName() },
|
||||
run: func(t *testing.T, c *model.UserContext) any { return c.ProviderName() },
|
||||
expected: "local",
|
||||
},
|
||||
{
|
||||
description: "ProviderName returns 'local' for ProviderBasicAuth",
|
||||
context: &model.UserContext{Provider: model.ProviderBasicAuth},
|
||||
run: func(c *model.UserContext) any { return c.ProviderName() },
|
||||
run: func(t *testing.T, c *model.UserContext) any { return c.ProviderName() },
|
||||
expected: "local",
|
||||
},
|
||||
{
|
||||
description: "ProviderName returns 'ldap' for ProviderLDAP",
|
||||
context: &model.UserContext{Provider: model.ProviderLDAP},
|
||||
run: func(c *model.UserContext) any { return c.ProviderName() },
|
||||
run: func(t *testing.T, c *model.UserContext) any { return c.ProviderName() },
|
||||
expected: "ldap",
|
||||
},
|
||||
{
|
||||
@@ -169,7 +174,7 @@ func TestContext(t *testing.T) {
|
||||
Provider: model.ProviderOAuth,
|
||||
OAuth: &model.OAuthContext{DisplayName: "GitHub"},
|
||||
},
|
||||
run: func(c *model.UserContext) any { return c.ProviderName() },
|
||||
run: func(t *testing.T, c *model.UserContext) any { return c.ProviderName() },
|
||||
expected: "GitHub",
|
||||
},
|
||||
{
|
||||
@@ -178,7 +183,7 @@ func TestContext(t *testing.T) {
|
||||
Provider: model.ProviderLocal,
|
||||
Local: &model.LocalContext{TOTPPending: true},
|
||||
},
|
||||
run: func(c *model.UserContext) any { return c.TOTPPending() },
|
||||
run: func(t *testing.T, c *model.UserContext) any { return c.TOTPPending() },
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
@@ -187,13 +192,13 @@ func TestContext(t *testing.T) {
|
||||
Provider: model.ProviderLocal,
|
||||
Local: &model.LocalContext{TOTPPending: false},
|
||||
},
|
||||
run: func(c *model.UserContext) any { return c.TOTPPending() },
|
||||
run: func(t *testing.T, c *model.UserContext) any { return c.TOTPPending() },
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
description: "TOTPPending returns false for non-local providers",
|
||||
context: &model.UserContext{Provider: model.ProviderOAuth, OAuth: &model.OAuthContext{}},
|
||||
run: func(c *model.UserContext) any { return c.TOTPPending() },
|
||||
run: func(t *testing.T, c *model.UserContext) any { return c.TOTPPending() },
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
@@ -202,28 +207,26 @@ func TestContext(t *testing.T) {
|
||||
Provider: model.ProviderOAuth,
|
||||
OAuth: &model.OAuthContext{DisplayName: "Google"},
|
||||
},
|
||||
run: func(c *model.UserContext) any { return c.OAuthName() },
|
||||
run: func(t *testing.T, c *model.UserContext) any { return c.OAuthName() },
|
||||
expected: "Google",
|
||||
},
|
||||
{
|
||||
description: "OAuthName returns empty string for non-oauth providers",
|
||||
context: &model.UserContext{Provider: model.ProviderLocal, Local: &model.LocalContext{}},
|
||||
run: func(c *model.UserContext) any { return c.OAuthName() },
|
||||
run: func(t *testing.T, c *model.UserContext) any { return c.OAuthName() },
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
description: "NewFromGin populates context from gin value",
|
||||
context: &model.UserContext{},
|
||||
run: func(c *model.UserContext) any {
|
||||
run: func(t *testing.T, c *model.UserContext) any {
|
||||
stored := &model.UserContext{
|
||||
Authenticated: true,
|
||||
Provider: model.ProviderLocal,
|
||||
Local: &model.LocalContext{BaseContext: model.BaseContext{Username: "alice"}},
|
||||
}
|
||||
got, err := c.NewFromGin(newGinCtx(stored, true))
|
||||
if err != nil {
|
||||
return err.Error()
|
||||
}
|
||||
require.NoError(t, err)
|
||||
return [2]any{got.Authenticated, got.GetUsername()}
|
||||
},
|
||||
expected: [2]any{true, "alice"},
|
||||
@@ -231,7 +234,7 @@ func TestContext(t *testing.T) {
|
||||
{
|
||||
description: "NewFromGin returns error when context value is missing",
|
||||
context: &model.UserContext{},
|
||||
run: func(c *model.UserContext) any {
|
||||
run: func(t *testing.T, c *model.UserContext) any {
|
||||
_, err := c.NewFromGin(newGinCtx(nil, false))
|
||||
return err.Error()
|
||||
},
|
||||
@@ -240,17 +243,26 @@ func TestContext(t *testing.T) {
|
||||
{
|
||||
description: "NewFromGin returns error when context value has wrong type",
|
||||
context: &model.UserContext{},
|
||||
run: func(c *model.UserContext) any {
|
||||
run: func(t *testing.T, c *model.UserContext) any {
|
||||
_, err := c.NewFromGin(newGinCtx("not a user context", true))
|
||||
return err.Error()
|
||||
},
|
||||
expected: "invalid user context type",
|
||||
},
|
||||
{
|
||||
description: "NewFromGin returns an error when context doesn't include user information",
|
||||
context: &model.UserContext{},
|
||||
run: func(t *testing.T, c *model.UserContext) any {
|
||||
_, err := c.NewFromGin(newGinCtx(&model.UserContext{Provider: model.ProviderLocal}, true))
|
||||
return err.Error()
|
||||
},
|
||||
expected: "incomplete user context",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.description, func(t *testing.T) {
|
||||
assert.Equal(t, test.expected, test.run(test.context))
|
||||
assert.Equal(t, test.expected, test.run(t, test.context))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -120,7 +120,7 @@ func (auth *AuthService) Init() error {
|
||||
}
|
||||
|
||||
func (auth *AuthService) SearchUser(username string) (*model.UserSearch, error) {
|
||||
if auth.GetLocalUser(username).Username != "" {
|
||||
if auth.GetLocalUser(username) != nil {
|
||||
return &model.UserSearch{
|
||||
Username: username,
|
||||
Type: model.UserLocal,
|
||||
@@ -295,6 +295,8 @@ func (auth *AuthService) CreateSession(ctx context.Context, data repository.Sess
|
||||
expiry = auth.config.SessionExpiry
|
||||
}
|
||||
|
||||
expiresAt := time.Now().Add(time.Duration(expiry) * time.Second)
|
||||
|
||||
session := repository.CreateSessionParams{
|
||||
UUID: uuid.String(),
|
||||
Username: data.Username,
|
||||
@@ -303,7 +305,7 @@ func (auth *AuthService) CreateSession(ctx context.Context, data repository.Sess
|
||||
Provider: data.Provider,
|
||||
TotpPending: data.TotpPending,
|
||||
OAuthGroups: data.OAuthGroups,
|
||||
Expiry: time.Now().Add(time.Duration(expiry) * time.Second).Unix(),
|
||||
Expiry: expiresAt.Unix(),
|
||||
CreatedAt: time.Now().Unix(),
|
||||
OAuthName: data.OAuthName,
|
||||
OAuthSub: data.OAuthSub,
|
||||
@@ -320,8 +322,8 @@ func (auth *AuthService) CreateSession(ctx context.Context, data repository.Sess
|
||||
Value: session.UUID,
|
||||
Path: "/",
|
||||
Domain: fmt.Sprintf(".%s", auth.config.CookieDomain),
|
||||
Expires: time.Now().Add(time.Duration(expiry) * time.Second),
|
||||
MaxAge: expiry,
|
||||
Expires: expiresAt,
|
||||
MaxAge: int(time.Until(expiresAt).Seconds()),
|
||||
Secure: auth.config.SecureCookie,
|
||||
HttpOnly: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
@@ -374,7 +376,7 @@ func (auth *AuthService) RefreshSession(ctx context.Context, uuid string) (*http
|
||||
Path: "/",
|
||||
Domain: fmt.Sprintf(".%s", auth.config.CookieDomain),
|
||||
Expires: time.Now().Add(time.Duration(newExpiry-currentTime) * time.Second),
|
||||
MaxAge: auth.config.SessionExpiry,
|
||||
MaxAge: int(newExpiry - currentTime),
|
||||
Secure: auth.config.SecureCookie,
|
||||
HttpOnly: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
|
||||
@@ -5,18 +5,19 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestReadFile(t *testing.T) {
|
||||
// Setup
|
||||
file, err := os.Create("/tmp/tinyauth_test_file")
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = file.WriteString("file content\n")
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = file.Close()
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
defer os.Remove("/tmp/tinyauth_test_file")
|
||||
|
||||
// Normal case
|
||||
|
||||
@@ -5,19 +5,20 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||
)
|
||||
|
||||
func TestGetSecret(t *testing.T) {
|
||||
// Setup
|
||||
file, err := os.Create("/tmp/tinyauth_test_secret")
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = file.WriteString(" secret \n")
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = file.Close()
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
defer os.Remove("/tmp/tinyauth_test_secret")
|
||||
|
||||
// Get from config
|
||||
|
||||
@@ -5,28 +5,31 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||
)
|
||||
|
||||
func TestGetUsers(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
hash := "$2a$10$Mz5xhkfSJUtPWkzCd/TdaePh9CaXc5QcGII5wIMPLSR46eTwma30G"
|
||||
|
||||
// Setup
|
||||
file, err := os.Create("/tmp/tinyauth_users_test.txt")
|
||||
assert.NoError(t, err)
|
||||
file, err := os.Create(tmpDir + "/tinyauth_users_test.txt")
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = file.WriteString(" user1:" + hash + " \n user2:" + hash + " ") // Spacing is on purpose
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = file.Close()
|
||||
assert.NoError(t, err)
|
||||
defer os.Remove("/tmp/tinyauth_users_test.txt")
|
||||
require.NoError(t, err)
|
||||
defer os.Remove(tmpDir + "/tinyauth_users_test.txt")
|
||||
|
||||
noAttrs := map[string]model.UserAttributes{}
|
||||
|
||||
// Test file only
|
||||
users, err := utils.GetUsers([]string{}, "/tmp/tinyauth_users_test.txt", noAttrs)
|
||||
users, err := utils.GetUsers([]string{}, tmpDir+"/tinyauth_users_test.txt", noAttrs)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, users)
|
||||
@@ -47,7 +50,7 @@ func TestGetUsers(t *testing.T) {
|
||||
assert.Equal(t, "user4", (*users)[1].Username)
|
||||
|
||||
// Test both
|
||||
users, err = utils.GetUsers([]string{"user5:" + hash}, "/tmp/tinyauth_users_test.txt", noAttrs)
|
||||
users, err = utils.GetUsers([]string{"user5:" + hash}, tmpDir+"/tinyauth_users_test.txt", noAttrs)
|
||||
|
||||
assert.NoError(t, err)
|
||||
|
||||
@@ -65,7 +68,7 @@ func TestGetUsers(t *testing.T) {
|
||||
attrs := map[string]model.UserAttributes{
|
||||
"user1": {Name: "User One", Email: "user1@example.com"},
|
||||
}
|
||||
users, err = utils.GetUsers([]string{}, "/tmp/tinyauth_users_test.txt", attrs)
|
||||
users, err = utils.GetUsers([]string{}, tmpDir+"/tinyauth_users_test.txt", attrs)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, *users, 2)
|
||||
@@ -87,7 +90,7 @@ func TestGetUsers(t *testing.T) {
|
||||
assert.Nil(t, users)
|
||||
|
||||
// Test non-existent file
|
||||
users, err = utils.GetUsers([]string{}, "/tmp/non_existent_file.txt", noAttrs)
|
||||
users, err = utils.GetUsers([]string{}, tmpDir+"/non_existent_file.txt", noAttrs)
|
||||
|
||||
assert.ErrorContains(t, err, "no such file or directory")
|
||||
assert.Nil(t, users)
|
||||
|
||||
Reference in New Issue
Block a user