mirror of
https://github.com/steveiliop56/tinyauth.git
synced 2026-05-08 13:28:12 +00:00
fix: review comments batch 1
This commit is contained in:
@@ -226,17 +226,6 @@ func (controller *UserController) logoutHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
context, err := new(model.UserContext).NewFromGin(c)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
tlog.App.Error().Err(err).Msg("Failed to get user context on logout")
|
|
||||||
c.JSON(500, gin.H{
|
|
||||||
"status": 500,
|
|
||||||
"message": "Internal Server Error",
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
cookie, err := controller.auth.DeleteSession(c, uuid)
|
cookie, err := controller.auth.DeleteSession(c, uuid)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -248,7 +237,14 @@ func (controller *UserController) logoutHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
tlog.AuditLogout(c, context.GetUsername(), context.ProviderName())
|
context, err := new(model.UserContext).NewFromGin(c)
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
tlog.AuditLogout(c, context.GetUsername(), context.ProviderName())
|
||||||
|
} else {
|
||||||
|
tlog.App.Warn().Err(err).Msg("Failed to get user context for logout audit, proceeding without username")
|
||||||
|
tlog.AuditLogout(c, "unknown", "unknown")
|
||||||
|
}
|
||||||
|
|
||||||
http.SetCookie(c.Writer, cookie)
|
http.SetCookie(c.Writer, cookie)
|
||||||
|
|
||||||
@@ -308,6 +304,15 @@ func (controller *UserController) totpHandler(c *gin.Context) {
|
|||||||
|
|
||||||
user := controller.auth.GetLocalUser(context.GetUsername())
|
user := controller.auth.GetLocalUser(context.GetUsername())
|
||||||
|
|
||||||
|
if user == nil {
|
||||||
|
tlog.App.Error().Str("username", context.GetUsername()).Msg("User not found in TOTP handler")
|
||||||
|
c.JSON(401, gin.H{
|
||||||
|
"status": 401,
|
||||||
|
"message": "Unauthorized",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
ok := totp.Validate(req.Code, user.TOTPSecret)
|
ok := totp.Validate(req.Code, user.TOTPSecret)
|
||||||
|
|
||||||
if !ok {
|
if !ok {
|
||||||
|
|||||||
@@ -67,7 +67,7 @@ func TestUserController(t *testing.T) {
|
|||||||
|
|
||||||
totpCtx := func(c *gin.Context) {
|
totpCtx := func(c *gin.Context) {
|
||||||
c.Set("context", &model.UserContext{
|
c.Set("context", &model.UserContext{
|
||||||
Authenticated: true,
|
Authenticated: false,
|
||||||
Provider: model.ProviderLocal,
|
Provider: model.ProviderLocal,
|
||||||
Local: &model.LocalContext{
|
Local: &model.LocalContext{
|
||||||
BaseContext: model.BaseContext{
|
BaseContext: model.BaseContext{
|
||||||
@@ -83,7 +83,7 @@ func TestUserController(t *testing.T) {
|
|||||||
|
|
||||||
totpAttrCtx := func(c *gin.Context) {
|
totpAttrCtx := func(c *gin.Context) {
|
||||||
c.Set("context", &model.UserContext{
|
c.Set("context", &model.UserContext{
|
||||||
Authenticated: true,
|
Authenticated: false,
|
||||||
Provider: model.ProviderLocal,
|
Provider: model.ProviderLocal,
|
||||||
Local: &model.LocalContext{
|
Local: &model.LocalContext{
|
||||||
BaseContext: model.BaseContext{
|
BaseContext: model.BaseContext{
|
||||||
@@ -141,7 +141,7 @@ func TestUserController(t *testing.T) {
|
|||||||
assert.Equal(t, "tinyauth-session", cookie.Name)
|
assert.Equal(t, "tinyauth-session", cookie.Name)
|
||||||
assert.True(t, cookie.HttpOnly)
|
assert.True(t, cookie.HttpOnly)
|
||||||
assert.Equal(t, "example.com", cookie.Domain)
|
assert.Equal(t, "example.com", cookie.Domain)
|
||||||
assert.Equal(t, 10, cookie.MaxAge)
|
assert.Equal(t, 9, cookie.MaxAge)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -230,7 +230,7 @@ func TestUserController(t *testing.T) {
|
|||||||
assert.Equal(t, "tinyauth-session", cookie.Name)
|
assert.Equal(t, "tinyauth-session", cookie.Name)
|
||||||
assert.True(t, cookie.HttpOnly)
|
assert.True(t, cookie.HttpOnly)
|
||||||
assert.Equal(t, "example.com", cookie.Domain)
|
assert.Equal(t, "example.com", cookie.Domain)
|
||||||
assert.Equal(t, 3600, cookie.MaxAge) // 1 hour, default for totp pending sessions
|
assert.Equal(t, 3599, cookie.MaxAge) // 1 hour, default for totp pending sessions
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -306,7 +306,7 @@ func TestUserController(t *testing.T) {
|
|||||||
assert.Equal(t, "tinyauth-session", totpCookie.Name)
|
assert.Equal(t, "tinyauth-session", totpCookie.Name)
|
||||||
assert.True(t, totpCookie.HttpOnly)
|
assert.True(t, totpCookie.HttpOnly)
|
||||||
assert.Equal(t, "example.com", totpCookie.Domain)
|
assert.Equal(t, "example.com", totpCookie.Domain)
|
||||||
assert.Equal(t, 10, totpCookie.MaxAge) // should use the regular session expiry time
|
assert.Equal(t, 9, totpCookie.MaxAge) // should use the regular session expiry time
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -70,20 +70,18 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
|
|||||||
if err == nil {
|
if err == nil {
|
||||||
userContext, cookie, err := m.cookieAuth(c.Request.Context(), uuid)
|
userContext, cookie, err := m.cookieAuth(c.Request.Context(), uuid)
|
||||||
|
|
||||||
if err != nil {
|
if err == nil {
|
||||||
tlog.App.Error().Msgf("Error authenticating session cookie: %v", err)
|
if cookie != nil {
|
||||||
|
http.SetCookie(c.Writer, cookie)
|
||||||
|
}
|
||||||
|
|
||||||
|
tlog.App.Trace().Msgf("Authenticated user from session cookie: %s", userContext.GetUsername())
|
||||||
|
c.Set("context", userContext)
|
||||||
c.Next()
|
c.Next()
|
||||||
return
|
return
|
||||||
|
} else {
|
||||||
|
tlog.App.Error().Msgf("Error authenticating session cookie: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if cookie != nil {
|
|
||||||
http.SetCookie(c.Writer, cookie)
|
|
||||||
}
|
|
||||||
|
|
||||||
tlog.App.Trace().Msgf("Authenticated user from session cookie: %s", userContext.GetUsername())
|
|
||||||
c.Set("context", userContext)
|
|
||||||
c.Next()
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
username, password, ok := c.Request.BasicAuth()
|
username, password, ok := c.Request.BasicAuth()
|
||||||
|
|||||||
@@ -253,6 +253,18 @@ func TestContextMiddleware(t *testing.T) {
|
|||||||
req.Header.Set("Authorization", basicAuthHeader("totpuser", "password"))
|
req.Header.Set("Authorization", basicAuthHeader("totpuser", "password"))
|
||||||
userCtx, _ := args.do(req)
|
userCtx, _ := args.do(req)
|
||||||
|
|
||||||
|
require.NotNil(t, userCtx)
|
||||||
|
assert.Equal(t, "testuser", userCtx.GetUsername())
|
||||||
|
assert.True(t, userCtx.Authenticated)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
description: "Ensure fallback to basic auth when cookie is missing",
|
||||||
|
run: func(t *testing.T, args runArgs) {
|
||||||
|
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||||
|
req.Header.Set("Authorization", basicAuthHeader("testuser", "password"))
|
||||||
|
userCtx, _ := args.do(req)
|
||||||
|
|
||||||
require.NotNil(t, userCtx)
|
require.NotNil(t, userCtx)
|
||||||
assert.Equal(t, "testuser", userCtx.GetUsername())
|
assert.Equal(t, "testuser", userCtx.GetUsername())
|
||||||
assert.True(t, userCtx.Authenticated)
|
assert.True(t, userCtx.Authenticated)
|
||||||
|
|||||||
@@ -80,16 +80,24 @@ func (c *UserContext) NewFromGin(ginctx *gin.Context) (*UserContext, error) {
|
|||||||
|
|
||||||
userContext, ok := userContextValue.(*UserContext)
|
userContext, ok := userContextValue.(*UserContext)
|
||||||
|
|
||||||
if !ok {
|
if !ok || userContext == nil {
|
||||||
return nil, errors.New("invalid user context type")
|
return nil, errors.New("invalid user context type")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if userContext.LDAP == nil && userContext.Local == nil && userContext.OAuth == nil {
|
||||||
|
return nil, errors.New("incomplete user context")
|
||||||
|
}
|
||||||
|
|
||||||
*c = *userContext
|
*c = *userContext
|
||||||
return c, nil
|
return c, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Compatability layer until we get an excuse to drop in database migrations
|
// Compatability layer until we get an excuse to drop in database migrations
|
||||||
func (c *UserContext) NewFromSession(session *repository.Session) (*UserContext, error) {
|
func (c *UserContext) NewFromSession(session *repository.Session) (*UserContext, error) {
|
||||||
|
*c = UserContext{
|
||||||
|
Authenticated: !session.TotpPending,
|
||||||
|
}
|
||||||
|
|
||||||
switch session.Provider {
|
switch session.Provider {
|
||||||
case "local":
|
case "local":
|
||||||
c.Provider = ProviderLocal
|
c.Provider = ProviderLocal
|
||||||
@@ -119,17 +127,18 @@ func (c *UserContext) NewFromSession(session *repository.Session) (*UserContext,
|
|||||||
Name: session.Name,
|
Name: session.Name,
|
||||||
Email: session.Email,
|
Email: session.Email,
|
||||||
},
|
},
|
||||||
Groups: strings.Split(session.OAuthGroups, ","),
|
Groups: func() []string {
|
||||||
|
if session.OAuthGroups == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return strings.Split(session.OAuthGroups, ",")
|
||||||
|
}(),
|
||||||
Sub: session.OAuthSub,
|
Sub: session.OAuthSub,
|
||||||
DisplayName: session.OAuthName,
|
DisplayName: session.OAuthName,
|
||||||
ID: session.Provider,
|
ID: session.Provider,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if !session.TotpPending {
|
|
||||||
c.Authenticated = true
|
|
||||||
}
|
|
||||||
|
|
||||||
return c, nil
|
return c, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||||
)
|
)
|
||||||
@@ -22,47 +23,48 @@ func TestContext(t *testing.T) {
|
|||||||
tests := []struct {
|
tests := []struct {
|
||||||
description string
|
description string
|
||||||
context *model.UserContext
|
context *model.UserContext
|
||||||
run func(*model.UserContext) any
|
run func(*testing.T, *model.UserContext) any
|
||||||
expected any
|
expected any
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
description: "IsAuthenticated reflects Authenticated field",
|
description: "IsAuthenticated reflects Authenticated field",
|
||||||
context: &model.UserContext{Authenticated: true},
|
context: &model.UserContext{Authenticated: true},
|
||||||
run: func(c *model.UserContext) any { return c.IsAuthenticated() },
|
run: func(t *testing.T, c *model.UserContext) any { return c.IsAuthenticated() },
|
||||||
expected: true,
|
expected: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "IsLocal returns true for ProviderLocal",
|
description: "IsLocal returns true for ProviderLocal",
|
||||||
context: &model.UserContext{Provider: model.ProviderLocal},
|
context: &model.UserContext{Provider: model.ProviderLocal},
|
||||||
run: func(c *model.UserContext) any { return c.IsLocal() },
|
run: func(t *testing.T, c *model.UserContext) any { return c.IsLocal() },
|
||||||
expected: true,
|
expected: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "IsOAuth returns true for ProviderOAuth",
|
description: "IsOAuth returns true for ProviderOAuth",
|
||||||
context: &model.UserContext{Provider: model.ProviderOAuth},
|
context: &model.UserContext{Provider: model.ProviderOAuth},
|
||||||
run: func(c *model.UserContext) any { return c.IsOAuth() },
|
run: func(t *testing.T, c *model.UserContext) any { return c.IsOAuth() },
|
||||||
expected: true,
|
expected: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "IsLDAP returns true for ProviderLDAP",
|
description: "IsLDAP returns true for ProviderLDAP",
|
||||||
context: &model.UserContext{Provider: model.ProviderLDAP},
|
context: &model.UserContext{Provider: model.ProviderLDAP},
|
||||||
run: func(c *model.UserContext) any { return c.IsLDAP() },
|
run: func(t *testing.T, c *model.UserContext) any { return c.IsLDAP() },
|
||||||
expected: true,
|
expected: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "IsBasicAuth returns true for ProviderBasicAuth",
|
description: "IsBasicAuth returns true for ProviderBasicAuth",
|
||||||
context: &model.UserContext{Provider: model.ProviderBasicAuth},
|
context: &model.UserContext{Provider: model.ProviderBasicAuth},
|
||||||
run: func(c *model.UserContext) any { return c.IsBasicAuth() },
|
run: func(t *testing.T, c *model.UserContext) any { return c.IsBasicAuth() },
|
||||||
expected: true,
|
expected: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "NewFromSession local session is authenticated and ProviderLocal",
|
description: "NewFromSession local session is authenticated and ProviderLocal",
|
||||||
context: &model.UserContext{},
|
context: &model.UserContext{},
|
||||||
run: func(c *model.UserContext) any {
|
run: func(t *testing.T, c *model.UserContext) any {
|
||||||
got, _ := c.NewFromSession(&repository.Session{
|
got, err := c.NewFromSession(&repository.Session{
|
||||||
Username: "alice", Email: "alice@example.com", Name: "Alice",
|
Username: "alice", Email: "alice@example.com", Name: "Alice",
|
||||||
Provider: "local",
|
Provider: "local",
|
||||||
})
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
return [2]any{got.Provider, got.Authenticated}
|
return [2]any{got.Provider, got.Authenticated}
|
||||||
},
|
},
|
||||||
expected: [2]any{model.ProviderLocal, true},
|
expected: [2]any{model.ProviderLocal, true},
|
||||||
@@ -70,10 +72,11 @@ func TestContext(t *testing.T) {
|
|||||||
{
|
{
|
||||||
description: "NewFromSession local session with TotpPending is not authenticated",
|
description: "NewFromSession local session with TotpPending is not authenticated",
|
||||||
context: &model.UserContext{},
|
context: &model.UserContext{},
|
||||||
run: func(c *model.UserContext) any {
|
run: func(t *testing.T, c *model.UserContext) any {
|
||||||
got, _ := c.NewFromSession(&repository.Session{
|
got, err := c.NewFromSession(&repository.Session{
|
||||||
Username: "bob", Provider: "local", TotpPending: true,
|
Username: "bob", Provider: "local", TotpPending: true,
|
||||||
})
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
return got.Authenticated
|
return got.Authenticated
|
||||||
},
|
},
|
||||||
expected: false,
|
expected: false,
|
||||||
@@ -81,10 +84,11 @@ func TestContext(t *testing.T) {
|
|||||||
{
|
{
|
||||||
description: "NewFromSession ldap session is ProviderLDAP",
|
description: "NewFromSession ldap session is ProviderLDAP",
|
||||||
context: &model.UserContext{},
|
context: &model.UserContext{},
|
||||||
run: func(c *model.UserContext) any {
|
run: func(t *testing.T, c *model.UserContext) any {
|
||||||
got, _ := c.NewFromSession(&repository.Session{
|
got, err := c.NewFromSession(&repository.Session{
|
||||||
Username: "carol", Provider: "ldap",
|
Username: "carol", Provider: "ldap",
|
||||||
})
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
return got.Provider
|
return got.Provider
|
||||||
},
|
},
|
||||||
expected: model.ProviderLDAP,
|
expected: model.ProviderLDAP,
|
||||||
@@ -92,11 +96,12 @@ func TestContext(t *testing.T) {
|
|||||||
{
|
{
|
||||||
description: "NewFromSession unknown provider defaults to OAuth and populates oauth fields",
|
description: "NewFromSession unknown provider defaults to OAuth and populates oauth fields",
|
||||||
context: &model.UserContext{},
|
context: &model.UserContext{},
|
||||||
run: func(c *model.UserContext) any {
|
run: func(t *testing.T, c *model.UserContext) any {
|
||||||
got, _ := c.NewFromSession(&repository.Session{
|
got, err := c.NewFromSession(&repository.Session{
|
||||||
Username: "dave", Provider: "github",
|
Username: "dave", Provider: "github",
|
||||||
OAuthGroups: "devs,admins", OAuthSub: "sub-123", OAuthName: "GitHub",
|
OAuthGroups: "devs,admins", OAuthSub: "sub-123", OAuthName: "GitHub",
|
||||||
})
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
return [5]any{got.Provider, got.OAuth.ID, got.OAuth.Sub, got.OAuth.DisplayName, got.OAuth.Groups}
|
return [5]any{got.Provider, got.OAuth.ID, got.OAuth.Sub, got.OAuth.DisplayName, got.OAuth.Groups}
|
||||||
},
|
},
|
||||||
expected: [5]any{model.ProviderOAuth, "github", "sub-123", "GitHub", []string{"devs", "admins"}},
|
expected: [5]any{model.ProviderOAuth, "github", "sub-123", "GitHub", []string{"devs", "admins"}},
|
||||||
@@ -107,7 +112,7 @@ func TestContext(t *testing.T) {
|
|||||||
Provider: model.ProviderLocal,
|
Provider: model.ProviderLocal,
|
||||||
Local: &model.LocalContext{BaseContext: model.BaseContext{Username: "alice", Email: "alice@example.com", Name: "Alice"}},
|
Local: &model.LocalContext{BaseContext: model.BaseContext{Username: "alice", Email: "alice@example.com", Name: "Alice"}},
|
||||||
},
|
},
|
||||||
run: func(c *model.UserContext) any {
|
run: func(t *testing.T, c *model.UserContext) any {
|
||||||
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
|
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
|
||||||
},
|
},
|
||||||
expected: [3]string{"alice", "alice@example.com", "Alice"},
|
expected: [3]string{"alice", "alice@example.com", "Alice"},
|
||||||
@@ -118,7 +123,7 @@ func TestContext(t *testing.T) {
|
|||||||
Provider: model.ProviderBasicAuth,
|
Provider: model.ProviderBasicAuth,
|
||||||
Local: &model.LocalContext{BaseContext: model.BaseContext{Username: "bob", Email: "bob@example.com", Name: "Bob"}},
|
Local: &model.LocalContext{BaseContext: model.BaseContext{Username: "bob", Email: "bob@example.com", Name: "Bob"}},
|
||||||
},
|
},
|
||||||
run: func(c *model.UserContext) any {
|
run: func(t *testing.T, c *model.UserContext) any {
|
||||||
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
|
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
|
||||||
},
|
},
|
||||||
expected: [3]string{"bob", "bob@example.com", "Bob"},
|
expected: [3]string{"bob", "bob@example.com", "Bob"},
|
||||||
@@ -129,7 +134,7 @@ func TestContext(t *testing.T) {
|
|||||||
Provider: model.ProviderLDAP,
|
Provider: model.ProviderLDAP,
|
||||||
LDAP: &model.LDAPContext{BaseContext: model.BaseContext{Username: "carol", Email: "carol@example.com", Name: "Carol"}},
|
LDAP: &model.LDAPContext{BaseContext: model.BaseContext{Username: "carol", Email: "carol@example.com", Name: "Carol"}},
|
||||||
},
|
},
|
||||||
run: func(c *model.UserContext) any {
|
run: func(t *testing.T, c *model.UserContext) any {
|
||||||
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
|
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
|
||||||
},
|
},
|
||||||
expected: [3]string{"carol", "carol@example.com", "Carol"},
|
expected: [3]string{"carol", "carol@example.com", "Carol"},
|
||||||
@@ -140,7 +145,7 @@ func TestContext(t *testing.T) {
|
|||||||
Provider: model.ProviderOAuth,
|
Provider: model.ProviderOAuth,
|
||||||
OAuth: &model.OAuthContext{BaseContext: model.BaseContext{Username: "dave", Email: "dave@example.com", Name: "Dave"}},
|
OAuth: &model.OAuthContext{BaseContext: model.BaseContext{Username: "dave", Email: "dave@example.com", Name: "Dave"}},
|
||||||
},
|
},
|
||||||
run: func(c *model.UserContext) any {
|
run: func(t *testing.T, c *model.UserContext) any {
|
||||||
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
|
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
|
||||||
},
|
},
|
||||||
expected: [3]string{"dave", "dave@example.com", "Dave"},
|
expected: [3]string{"dave", "dave@example.com", "Dave"},
|
||||||
@@ -148,19 +153,19 @@ func TestContext(t *testing.T) {
|
|||||||
{
|
{
|
||||||
description: "ProviderName returns 'local' for ProviderLocal",
|
description: "ProviderName returns 'local' for ProviderLocal",
|
||||||
context: &model.UserContext{Provider: model.ProviderLocal},
|
context: &model.UserContext{Provider: model.ProviderLocal},
|
||||||
run: func(c *model.UserContext) any { return c.ProviderName() },
|
run: func(t *testing.T, c *model.UserContext) any { return c.ProviderName() },
|
||||||
expected: "local",
|
expected: "local",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "ProviderName returns 'local' for ProviderBasicAuth",
|
description: "ProviderName returns 'local' for ProviderBasicAuth",
|
||||||
context: &model.UserContext{Provider: model.ProviderBasicAuth},
|
context: &model.UserContext{Provider: model.ProviderBasicAuth},
|
||||||
run: func(c *model.UserContext) any { return c.ProviderName() },
|
run: func(t *testing.T, c *model.UserContext) any { return c.ProviderName() },
|
||||||
expected: "local",
|
expected: "local",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "ProviderName returns 'ldap' for ProviderLDAP",
|
description: "ProviderName returns 'ldap' for ProviderLDAP",
|
||||||
context: &model.UserContext{Provider: model.ProviderLDAP},
|
context: &model.UserContext{Provider: model.ProviderLDAP},
|
||||||
run: func(c *model.UserContext) any { return c.ProviderName() },
|
run: func(t *testing.T, c *model.UserContext) any { return c.ProviderName() },
|
||||||
expected: "ldap",
|
expected: "ldap",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -169,7 +174,7 @@ func TestContext(t *testing.T) {
|
|||||||
Provider: model.ProviderOAuth,
|
Provider: model.ProviderOAuth,
|
||||||
OAuth: &model.OAuthContext{DisplayName: "GitHub"},
|
OAuth: &model.OAuthContext{DisplayName: "GitHub"},
|
||||||
},
|
},
|
||||||
run: func(c *model.UserContext) any { return c.ProviderName() },
|
run: func(t *testing.T, c *model.UserContext) any { return c.ProviderName() },
|
||||||
expected: "GitHub",
|
expected: "GitHub",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -178,7 +183,7 @@ func TestContext(t *testing.T) {
|
|||||||
Provider: model.ProviderLocal,
|
Provider: model.ProviderLocal,
|
||||||
Local: &model.LocalContext{TOTPPending: true},
|
Local: &model.LocalContext{TOTPPending: true},
|
||||||
},
|
},
|
||||||
run: func(c *model.UserContext) any { return c.TOTPPending() },
|
run: func(t *testing.T, c *model.UserContext) any { return c.TOTPPending() },
|
||||||
expected: true,
|
expected: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -187,13 +192,13 @@ func TestContext(t *testing.T) {
|
|||||||
Provider: model.ProviderLocal,
|
Provider: model.ProviderLocal,
|
||||||
Local: &model.LocalContext{TOTPPending: false},
|
Local: &model.LocalContext{TOTPPending: false},
|
||||||
},
|
},
|
||||||
run: func(c *model.UserContext) any { return c.TOTPPending() },
|
run: func(t *testing.T, c *model.UserContext) any { return c.TOTPPending() },
|
||||||
expected: false,
|
expected: false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "TOTPPending returns false for non-local providers",
|
description: "TOTPPending returns false for non-local providers",
|
||||||
context: &model.UserContext{Provider: model.ProviderOAuth, OAuth: &model.OAuthContext{}},
|
context: &model.UserContext{Provider: model.ProviderOAuth, OAuth: &model.OAuthContext{}},
|
||||||
run: func(c *model.UserContext) any { return c.TOTPPending() },
|
run: func(t *testing.T, c *model.UserContext) any { return c.TOTPPending() },
|
||||||
expected: false,
|
expected: false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -202,28 +207,26 @@ func TestContext(t *testing.T) {
|
|||||||
Provider: model.ProviderOAuth,
|
Provider: model.ProviderOAuth,
|
||||||
OAuth: &model.OAuthContext{DisplayName: "Google"},
|
OAuth: &model.OAuthContext{DisplayName: "Google"},
|
||||||
},
|
},
|
||||||
run: func(c *model.UserContext) any { return c.OAuthName() },
|
run: func(t *testing.T, c *model.UserContext) any { return c.OAuthName() },
|
||||||
expected: "Google",
|
expected: "Google",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "OAuthName returns empty string for non-oauth providers",
|
description: "OAuthName returns empty string for non-oauth providers",
|
||||||
context: &model.UserContext{Provider: model.ProviderLocal, Local: &model.LocalContext{}},
|
context: &model.UserContext{Provider: model.ProviderLocal, Local: &model.LocalContext{}},
|
||||||
run: func(c *model.UserContext) any { return c.OAuthName() },
|
run: func(t *testing.T, c *model.UserContext) any { return c.OAuthName() },
|
||||||
expected: "",
|
expected: "",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "NewFromGin populates context from gin value",
|
description: "NewFromGin populates context from gin value",
|
||||||
context: &model.UserContext{},
|
context: &model.UserContext{},
|
||||||
run: func(c *model.UserContext) any {
|
run: func(t *testing.T, c *model.UserContext) any {
|
||||||
stored := &model.UserContext{
|
stored := &model.UserContext{
|
||||||
Authenticated: true,
|
Authenticated: true,
|
||||||
Provider: model.ProviderLocal,
|
Provider: model.ProviderLocal,
|
||||||
Local: &model.LocalContext{BaseContext: model.BaseContext{Username: "alice"}},
|
Local: &model.LocalContext{BaseContext: model.BaseContext{Username: "alice"}},
|
||||||
}
|
}
|
||||||
got, err := c.NewFromGin(newGinCtx(stored, true))
|
got, err := c.NewFromGin(newGinCtx(stored, true))
|
||||||
if err != nil {
|
require.NoError(t, err)
|
||||||
return err.Error()
|
|
||||||
}
|
|
||||||
return [2]any{got.Authenticated, got.GetUsername()}
|
return [2]any{got.Authenticated, got.GetUsername()}
|
||||||
},
|
},
|
||||||
expected: [2]any{true, "alice"},
|
expected: [2]any{true, "alice"},
|
||||||
@@ -231,7 +234,7 @@ func TestContext(t *testing.T) {
|
|||||||
{
|
{
|
||||||
description: "NewFromGin returns error when context value is missing",
|
description: "NewFromGin returns error when context value is missing",
|
||||||
context: &model.UserContext{},
|
context: &model.UserContext{},
|
||||||
run: func(c *model.UserContext) any {
|
run: func(t *testing.T, c *model.UserContext) any {
|
||||||
_, err := c.NewFromGin(newGinCtx(nil, false))
|
_, err := c.NewFromGin(newGinCtx(nil, false))
|
||||||
return err.Error()
|
return err.Error()
|
||||||
},
|
},
|
||||||
@@ -240,17 +243,26 @@ func TestContext(t *testing.T) {
|
|||||||
{
|
{
|
||||||
description: "NewFromGin returns error when context value has wrong type",
|
description: "NewFromGin returns error when context value has wrong type",
|
||||||
context: &model.UserContext{},
|
context: &model.UserContext{},
|
||||||
run: func(c *model.UserContext) any {
|
run: func(t *testing.T, c *model.UserContext) any {
|
||||||
_, err := c.NewFromGin(newGinCtx("not a user context", true))
|
_, err := c.NewFromGin(newGinCtx("not a user context", true))
|
||||||
return err.Error()
|
return err.Error()
|
||||||
},
|
},
|
||||||
expected: "invalid user context type",
|
expected: "invalid user context type",
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
description: "NewFromGin returns an error when context doesn't include user information",
|
||||||
|
context: &model.UserContext{},
|
||||||
|
run: func(t *testing.T, c *model.UserContext) any {
|
||||||
|
_, err := c.NewFromGin(newGinCtx(&model.UserContext{Provider: model.ProviderLocal}, true))
|
||||||
|
return err.Error()
|
||||||
|
},
|
||||||
|
expected: "incomplete user context",
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
t.Run(test.description, func(t *testing.T) {
|
t.Run(test.description, func(t *testing.T) {
|
||||||
assert.Equal(t, test.expected, test.run(test.context))
|
assert.Equal(t, test.expected, test.run(t, test.context))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -120,7 +120,7 @@ func (auth *AuthService) Init() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) SearchUser(username string) (*model.UserSearch, error) {
|
func (auth *AuthService) SearchUser(username string) (*model.UserSearch, error) {
|
||||||
if auth.GetLocalUser(username).Username != "" {
|
if auth.GetLocalUser(username) != nil {
|
||||||
return &model.UserSearch{
|
return &model.UserSearch{
|
||||||
Username: username,
|
Username: username,
|
||||||
Type: model.UserLocal,
|
Type: model.UserLocal,
|
||||||
@@ -295,6 +295,8 @@ func (auth *AuthService) CreateSession(ctx context.Context, data repository.Sess
|
|||||||
expiry = auth.config.SessionExpiry
|
expiry = auth.config.SessionExpiry
|
||||||
}
|
}
|
||||||
|
|
||||||
|
expiresAt := time.Now().Add(time.Duration(expiry) * time.Second)
|
||||||
|
|
||||||
session := repository.CreateSessionParams{
|
session := repository.CreateSessionParams{
|
||||||
UUID: uuid.String(),
|
UUID: uuid.String(),
|
||||||
Username: data.Username,
|
Username: data.Username,
|
||||||
@@ -303,7 +305,7 @@ func (auth *AuthService) CreateSession(ctx context.Context, data repository.Sess
|
|||||||
Provider: data.Provider,
|
Provider: data.Provider,
|
||||||
TotpPending: data.TotpPending,
|
TotpPending: data.TotpPending,
|
||||||
OAuthGroups: data.OAuthGroups,
|
OAuthGroups: data.OAuthGroups,
|
||||||
Expiry: time.Now().Add(time.Duration(expiry) * time.Second).Unix(),
|
Expiry: expiresAt.Unix(),
|
||||||
CreatedAt: time.Now().Unix(),
|
CreatedAt: time.Now().Unix(),
|
||||||
OAuthName: data.OAuthName,
|
OAuthName: data.OAuthName,
|
||||||
OAuthSub: data.OAuthSub,
|
OAuthSub: data.OAuthSub,
|
||||||
@@ -320,8 +322,8 @@ func (auth *AuthService) CreateSession(ctx context.Context, data repository.Sess
|
|||||||
Value: session.UUID,
|
Value: session.UUID,
|
||||||
Path: "/",
|
Path: "/",
|
||||||
Domain: fmt.Sprintf(".%s", auth.config.CookieDomain),
|
Domain: fmt.Sprintf(".%s", auth.config.CookieDomain),
|
||||||
Expires: time.Now().Add(time.Duration(expiry) * time.Second),
|
Expires: expiresAt,
|
||||||
MaxAge: expiry,
|
MaxAge: int(time.Until(expiresAt).Seconds()),
|
||||||
Secure: auth.config.SecureCookie,
|
Secure: auth.config.SecureCookie,
|
||||||
HttpOnly: true,
|
HttpOnly: true,
|
||||||
SameSite: http.SameSiteLaxMode,
|
SameSite: http.SameSiteLaxMode,
|
||||||
@@ -374,7 +376,7 @@ func (auth *AuthService) RefreshSession(ctx context.Context, uuid string) (*http
|
|||||||
Path: "/",
|
Path: "/",
|
||||||
Domain: fmt.Sprintf(".%s", auth.config.CookieDomain),
|
Domain: fmt.Sprintf(".%s", auth.config.CookieDomain),
|
||||||
Expires: time.Now().Add(time.Duration(newExpiry-currentTime) * time.Second),
|
Expires: time.Now().Add(time.Duration(newExpiry-currentTime) * time.Second),
|
||||||
MaxAge: auth.config.SessionExpiry,
|
MaxAge: int(newExpiry - currentTime),
|
||||||
Secure: auth.config.SecureCookie,
|
Secure: auth.config.SecureCookie,
|
||||||
HttpOnly: true,
|
HttpOnly: true,
|
||||||
SameSite: http.SameSiteLaxMode,
|
SameSite: http.SameSiteLaxMode,
|
||||||
|
|||||||
@@ -5,18 +5,19 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestReadFile(t *testing.T) {
|
func TestReadFile(t *testing.T) {
|
||||||
// Setup
|
// Setup
|
||||||
file, err := os.Create("/tmp/tinyauth_test_file")
|
file, err := os.Create("/tmp/tinyauth_test_file")
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
_, err = file.WriteString("file content\n")
|
_, err = file.WriteString("file content\n")
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
err = file.Close()
|
err = file.Close()
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer os.Remove("/tmp/tinyauth_test_file")
|
defer os.Remove("/tmp/tinyauth_test_file")
|
||||||
|
|
||||||
// Normal case
|
// Normal case
|
||||||
|
|||||||
@@ -5,19 +5,20 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestGetSecret(t *testing.T) {
|
func TestGetSecret(t *testing.T) {
|
||||||
// Setup
|
// Setup
|
||||||
file, err := os.Create("/tmp/tinyauth_test_secret")
|
file, err := os.Create("/tmp/tinyauth_test_secret")
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
_, err = file.WriteString(" secret \n")
|
_, err = file.WriteString(" secret \n")
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
err = file.Close()
|
err = file.Close()
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer os.Remove("/tmp/tinyauth_test_secret")
|
defer os.Remove("/tmp/tinyauth_test_secret")
|
||||||
|
|
||||||
// Get from config
|
// Get from config
|
||||||
|
|||||||
@@ -5,28 +5,31 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestGetUsers(t *testing.T) {
|
func TestGetUsers(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
|
||||||
hash := "$2a$10$Mz5xhkfSJUtPWkzCd/TdaePh9CaXc5QcGII5wIMPLSR46eTwma30G"
|
hash := "$2a$10$Mz5xhkfSJUtPWkzCd/TdaePh9CaXc5QcGII5wIMPLSR46eTwma30G"
|
||||||
|
|
||||||
// Setup
|
// Setup
|
||||||
file, err := os.Create("/tmp/tinyauth_users_test.txt")
|
file, err := os.Create(tmpDir + "/tinyauth_users_test.txt")
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
_, err = file.WriteString(" user1:" + hash + " \n user2:" + hash + " ") // Spacing is on purpose
|
_, err = file.WriteString(" user1:" + hash + " \n user2:" + hash + " ") // Spacing is on purpose
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
err = file.Close()
|
err = file.Close()
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer os.Remove("/tmp/tinyauth_users_test.txt")
|
defer os.Remove(tmpDir + "/tinyauth_users_test.txt")
|
||||||
|
|
||||||
noAttrs := map[string]model.UserAttributes{}
|
noAttrs := map[string]model.UserAttributes{}
|
||||||
|
|
||||||
// Test file only
|
// Test file only
|
||||||
users, err := utils.GetUsers([]string{}, "/tmp/tinyauth_users_test.txt", noAttrs)
|
users, err := utils.GetUsers([]string{}, tmpDir+"/tinyauth_users_test.txt", noAttrs)
|
||||||
|
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.NotNil(t, users)
|
assert.NotNil(t, users)
|
||||||
@@ -47,7 +50,7 @@ func TestGetUsers(t *testing.T) {
|
|||||||
assert.Equal(t, "user4", (*users)[1].Username)
|
assert.Equal(t, "user4", (*users)[1].Username)
|
||||||
|
|
||||||
// Test both
|
// Test both
|
||||||
users, err = utils.GetUsers([]string{"user5:" + hash}, "/tmp/tinyauth_users_test.txt", noAttrs)
|
users, err = utils.GetUsers([]string{"user5:" + hash}, tmpDir+"/tinyauth_users_test.txt", noAttrs)
|
||||||
|
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
@@ -65,7 +68,7 @@ func TestGetUsers(t *testing.T) {
|
|||||||
attrs := map[string]model.UserAttributes{
|
attrs := map[string]model.UserAttributes{
|
||||||
"user1": {Name: "User One", Email: "user1@example.com"},
|
"user1": {Name: "User One", Email: "user1@example.com"},
|
||||||
}
|
}
|
||||||
users, err = utils.GetUsers([]string{}, "/tmp/tinyauth_users_test.txt", attrs)
|
users, err = utils.GetUsers([]string{}, tmpDir+"/tinyauth_users_test.txt", attrs)
|
||||||
|
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Len(t, *users, 2)
|
assert.Len(t, *users, 2)
|
||||||
@@ -87,7 +90,7 @@ func TestGetUsers(t *testing.T) {
|
|||||||
assert.Nil(t, users)
|
assert.Nil(t, users)
|
||||||
|
|
||||||
// Test non-existent file
|
// Test non-existent file
|
||||||
users, err = utils.GetUsers([]string{}, "/tmp/non_existent_file.txt", noAttrs)
|
users, err = utils.GetUsers([]string{}, tmpDir+"/non_existent_file.txt", noAttrs)
|
||||||
|
|
||||||
assert.ErrorContains(t, err, "no such file or directory")
|
assert.ErrorContains(t, err, "no such file or directory")
|
||||||
assert.Nil(t, users)
|
assert.Nil(t, users)
|
||||||
|
|||||||
Reference in New Issue
Block a user