mirror of
https://github.com/steveiliop56/tinyauth.git
synced 2026-07-01 07:40:14 +00:00
Compare commits
3 Commits
f3965a7470
...
04b2290d73
| Author | SHA1 | Date | |
|---|---|---|---|
| 04b2290d73 | |||
| e04980468f | |||
| d47e4d3d79 |
@@ -29,7 +29,7 @@ type BootstrapApp struct {
|
||||
csrfCookieName string
|
||||
redirectCookieName string
|
||||
oauthSessionCookieName string
|
||||
localUsers []model.LocalUser
|
||||
localUsers *[]model.LocalUser
|
||||
oauthProviders map[string]model.OAuthServiceConfig
|
||||
configuredProviders []controller.Provider
|
||||
oidcClients []model.OIDCClientConfig
|
||||
@@ -69,7 +69,7 @@ func (app *BootstrapApp) Setup() error {
|
||||
return err
|
||||
}
|
||||
|
||||
app.context.localUsers = *users
|
||||
app.context.localUsers = users
|
||||
|
||||
// Setup OAuth providers
|
||||
app.context.oauthProviders = app.config.OAuth.Providers
|
||||
|
||||
@@ -21,7 +21,7 @@ func TestProxyController(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
authServiceCfg := service.AuthServiceConfig{
|
||||
LocalUsers: []model.LocalUser{
|
||||
LocalUsers: &[]model.LocalUser{
|
||||
{
|
||||
Username: "testuser",
|
||||
Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
|
||||
|
||||
@@ -226,17 +226,6 @@ func (controller *UserController) logoutHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
context, err := new(model.UserContext).NewFromGin(c)
|
||||
|
||||
if err != nil {
|
||||
tlog.App.Error().Err(err).Msg("Failed to get user context on logout")
|
||||
c.JSON(500, gin.H{
|
||||
"status": 500,
|
||||
"message": "Internal Server Error",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
cookie, err := controller.auth.DeleteSession(c, uuid)
|
||||
|
||||
if err != nil {
|
||||
@@ -248,7 +237,14 @@ func (controller *UserController) logoutHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
tlog.AuditLogout(c, context.GetUsername(), context.ProviderName())
|
||||
context, err := new(model.UserContext).NewFromGin(c)
|
||||
|
||||
if err == nil {
|
||||
tlog.AuditLogout(c, context.GetUsername(), context.ProviderName())
|
||||
} else {
|
||||
tlog.App.Warn().Err(err).Msg("Failed to get user context for logout audit, proceeding without username")
|
||||
tlog.AuditLogout(c, "unknown", "unknown")
|
||||
}
|
||||
|
||||
http.SetCookie(c.Writer, cookie)
|
||||
|
||||
@@ -308,6 +304,15 @@ func (controller *UserController) totpHandler(c *gin.Context) {
|
||||
|
||||
user := controller.auth.GetLocalUser(context.GetUsername())
|
||||
|
||||
if user == nil {
|
||||
tlog.App.Error().Str("username", context.GetUsername()).Msg("User not found in TOTP handler")
|
||||
c.JSON(401, gin.H{
|
||||
"status": 401,
|
||||
"message": "Unauthorized",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
ok := totp.Validate(req.Code, user.TOTPSecret)
|
||||
|
||||
if !ok {
|
||||
|
||||
@@ -25,7 +25,7 @@ func TestUserController(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
authServiceCfg := service.AuthServiceConfig{
|
||||
LocalUsers: []model.LocalUser{
|
||||
LocalUsers: &[]model.LocalUser{
|
||||
{
|
||||
Username: "testuser",
|
||||
Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
|
||||
@@ -67,7 +67,7 @@ func TestUserController(t *testing.T) {
|
||||
|
||||
totpCtx := func(c *gin.Context) {
|
||||
c.Set("context", &model.UserContext{
|
||||
Authenticated: true,
|
||||
Authenticated: false,
|
||||
Provider: model.ProviderLocal,
|
||||
Local: &model.LocalContext{
|
||||
BaseContext: model.BaseContext{
|
||||
@@ -83,7 +83,7 @@ func TestUserController(t *testing.T) {
|
||||
|
||||
totpAttrCtx := func(c *gin.Context) {
|
||||
c.Set("context", &model.UserContext{
|
||||
Authenticated: true,
|
||||
Authenticated: false,
|
||||
Provider: model.ProviderLocal,
|
||||
Local: &model.LocalContext{
|
||||
BaseContext: model.BaseContext{
|
||||
@@ -141,7 +141,7 @@ func TestUserController(t *testing.T) {
|
||||
assert.Equal(t, "tinyauth-session", cookie.Name)
|
||||
assert.True(t, cookie.HttpOnly)
|
||||
assert.Equal(t, "example.com", cookie.Domain)
|
||||
assert.Equal(t, 10, cookie.MaxAge)
|
||||
assert.Equal(t, 9, cookie.MaxAge)
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -230,7 +230,7 @@ func TestUserController(t *testing.T) {
|
||||
assert.Equal(t, "tinyauth-session", cookie.Name)
|
||||
assert.True(t, cookie.HttpOnly)
|
||||
assert.Equal(t, "example.com", cookie.Domain)
|
||||
assert.Equal(t, 3600, cookie.MaxAge) // 1 hour, default for totp pending sessions
|
||||
assert.Equal(t, 3599, cookie.MaxAge) // 1 hour, default for totp pending sessions
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -306,7 +306,7 @@ func TestUserController(t *testing.T) {
|
||||
assert.Equal(t, "tinyauth-session", totpCookie.Name)
|
||||
assert.True(t, totpCookie.HttpOnly)
|
||||
assert.Equal(t, "example.com", totpCookie.Domain)
|
||||
assert.Equal(t, 10, totpCookie.MaxAge) // should use the regular session expiry time
|
||||
assert.Equal(t, 9, totpCookie.MaxAge) // should use the regular session expiry time
|
||||
},
|
||||
},
|
||||
{
|
||||
|
||||
@@ -70,20 +70,18 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
|
||||
if err == nil {
|
||||
userContext, cookie, err := m.cookieAuth(c.Request.Context(), uuid)
|
||||
|
||||
if err != nil {
|
||||
tlog.App.Error().Msgf("Error authenticating session cookie: %v", err)
|
||||
if err == nil {
|
||||
if cookie != nil {
|
||||
http.SetCookie(c.Writer, cookie)
|
||||
}
|
||||
|
||||
tlog.App.Trace().Msgf("Authenticated user from session cookie: %s", userContext.GetUsername())
|
||||
c.Set("context", userContext)
|
||||
c.Next()
|
||||
return
|
||||
} else {
|
||||
tlog.App.Error().Msgf("Error authenticating session cookie: %v", err)
|
||||
}
|
||||
|
||||
if cookie != nil {
|
||||
http.SetCookie(c.Writer, cookie)
|
||||
}
|
||||
|
||||
tlog.App.Trace().Msgf("Authenticated user from session cookie: %s", userContext.GetUsername())
|
||||
c.Set("context", userContext)
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
username, password, ok := c.Request.BasicAuth()
|
||||
|
||||
@@ -25,7 +25,7 @@ func TestContextMiddleware(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
authServiceCfg := service.AuthServiceConfig{
|
||||
LocalUsers: []model.LocalUser{
|
||||
LocalUsers: &[]model.LocalUser{
|
||||
{
|
||||
Username: "testuser",
|
||||
Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
|
||||
@@ -253,6 +253,18 @@ func TestContextMiddleware(t *testing.T) {
|
||||
req.Header.Set("Authorization", basicAuthHeader("totpuser", "password"))
|
||||
userCtx, _ := args.do(req)
|
||||
|
||||
require.NotNil(t, userCtx)
|
||||
assert.Equal(t, "testuser", userCtx.GetUsername())
|
||||
assert.True(t, userCtx.Authenticated)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Ensure fallback to basic auth when cookie is missing",
|
||||
run: func(t *testing.T, args runArgs) {
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
req.Header.Set("Authorization", basicAuthHeader("testuser", "password"))
|
||||
userCtx, _ := args.do(req)
|
||||
|
||||
require.NotNil(t, userCtx)
|
||||
assert.Equal(t, "testuser", userCtx.GetUsername())
|
||||
assert.True(t, userCtx.Authenticated)
|
||||
|
||||
+57
-12
@@ -56,19 +56,19 @@ func (c *UserContext) IsAuthenticated() bool {
|
||||
}
|
||||
|
||||
func (c *UserContext) IsLocal() bool {
|
||||
return c.Provider == ProviderLocal
|
||||
return c.Provider == ProviderLocal && c.Local != nil
|
||||
}
|
||||
|
||||
func (c *UserContext) IsOAuth() bool {
|
||||
return c.Provider == ProviderOAuth
|
||||
return c.Provider == ProviderOAuth && c.OAuth != nil
|
||||
}
|
||||
|
||||
func (c *UserContext) IsLDAP() bool {
|
||||
return c.Provider == ProviderLDAP
|
||||
return c.Provider == ProviderLDAP && c.LDAP != nil
|
||||
}
|
||||
|
||||
func (c *UserContext) IsBasicAuth() bool {
|
||||
return c.Provider == ProviderBasicAuth
|
||||
return c.Provider == ProviderBasicAuth && c.Local != nil
|
||||
}
|
||||
|
||||
func (c *UserContext) NewFromGin(ginctx *gin.Context) (*UserContext, error) {
|
||||
@@ -80,16 +80,24 @@ func (c *UserContext) NewFromGin(ginctx *gin.Context) (*UserContext, error) {
|
||||
|
||||
userContext, ok := userContextValue.(*UserContext)
|
||||
|
||||
if !ok {
|
||||
if !ok || userContext == nil {
|
||||
return nil, errors.New("invalid user context type")
|
||||
}
|
||||
|
||||
if userContext.LDAP == nil && userContext.Local == nil && userContext.OAuth == nil {
|
||||
return nil, errors.New("incomplete user context")
|
||||
}
|
||||
|
||||
*c = *userContext
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// Compatability layer until we get an excuse to drop in database migrations
|
||||
func (c *UserContext) NewFromSession(session *repository.Session) (*UserContext, error) {
|
||||
*c = UserContext{
|
||||
Authenticated: !session.TotpPending,
|
||||
}
|
||||
|
||||
switch session.Provider {
|
||||
case "local":
|
||||
c.Provider = ProviderLocal
|
||||
@@ -119,29 +127,42 @@ func (c *UserContext) NewFromSession(session *repository.Session) (*UserContext,
|
||||
Name: session.Name,
|
||||
Email: session.Email,
|
||||
},
|
||||
Groups: strings.Split(session.OAuthGroups, ","),
|
||||
Groups: func() []string {
|
||||
if session.OAuthGroups == "" {
|
||||
return nil
|
||||
}
|
||||
return strings.Split(session.OAuthGroups, ",")
|
||||
}(),
|
||||
Sub: session.OAuthSub,
|
||||
DisplayName: session.OAuthName,
|
||||
ID: session.Provider,
|
||||
}
|
||||
}
|
||||
|
||||
if !session.TotpPending {
|
||||
c.Authenticated = true
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (c *UserContext) GetUsername() string {
|
||||
switch c.Provider {
|
||||
case ProviderLocal:
|
||||
if c.Local == nil {
|
||||
return ""
|
||||
}
|
||||
return c.Local.Username
|
||||
case ProviderLDAP:
|
||||
if c.LDAP == nil {
|
||||
return ""
|
||||
}
|
||||
return c.LDAP.Username
|
||||
case ProviderBasicAuth:
|
||||
if c.Local == nil {
|
||||
return ""
|
||||
}
|
||||
return c.Local.Username
|
||||
case ProviderOAuth:
|
||||
if c.OAuth == nil {
|
||||
return ""
|
||||
}
|
||||
return c.OAuth.Username
|
||||
default:
|
||||
return ""
|
||||
@@ -151,12 +172,24 @@ func (c *UserContext) GetUsername() string {
|
||||
func (c *UserContext) GetEmail() string {
|
||||
switch c.Provider {
|
||||
case ProviderLocal:
|
||||
if c.Local == nil {
|
||||
return ""
|
||||
}
|
||||
return c.Local.Email
|
||||
case ProviderLDAP:
|
||||
if c.LDAP == nil {
|
||||
return ""
|
||||
}
|
||||
return c.LDAP.Email
|
||||
case ProviderBasicAuth:
|
||||
if c.Local == nil {
|
||||
return ""
|
||||
}
|
||||
return c.Local.Email
|
||||
case ProviderOAuth:
|
||||
if c.OAuth == nil {
|
||||
return ""
|
||||
}
|
||||
return c.OAuth.Email
|
||||
default:
|
||||
return ""
|
||||
@@ -166,12 +199,24 @@ func (c *UserContext) GetEmail() string {
|
||||
func (c *UserContext) GetName() string {
|
||||
switch c.Provider {
|
||||
case ProviderLocal:
|
||||
if c.Local == nil {
|
||||
return ""
|
||||
}
|
||||
return c.Local.Name
|
||||
case ProviderLDAP:
|
||||
if c.LDAP == nil {
|
||||
return ""
|
||||
}
|
||||
return c.LDAP.Name
|
||||
case ProviderBasicAuth:
|
||||
if c.Local == nil {
|
||||
return ""
|
||||
}
|
||||
return c.Local.Name
|
||||
case ProviderOAuth:
|
||||
if c.OAuth == nil {
|
||||
return ""
|
||||
}
|
||||
return c.OAuth.Name
|
||||
default:
|
||||
return ""
|
||||
@@ -192,14 +237,14 @@ func (c *UserContext) ProviderName() string {
|
||||
}
|
||||
|
||||
func (c *UserContext) TOTPPending() bool {
|
||||
if c.Provider == ProviderLocal {
|
||||
if c.Provider == ProviderLocal && c.Local != nil {
|
||||
return c.Local.TOTPPending
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *UserContext) OAuthName() string {
|
||||
if c.Provider == ProviderOAuth {
|
||||
if c.Provider == ProviderOAuth && c.OAuth != nil {
|
||||
return c.OAuth.DisplayName
|
||||
}
|
||||
return ""
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||
)
|
||||
@@ -22,47 +23,48 @@ func TestContext(t *testing.T) {
|
||||
tests := []struct {
|
||||
description string
|
||||
context *model.UserContext
|
||||
run func(*model.UserContext) any
|
||||
run func(*testing.T, *model.UserContext) any
|
||||
expected any
|
||||
}{
|
||||
{
|
||||
description: "IsAuthenticated reflects Authenticated field",
|
||||
context: &model.UserContext{Authenticated: true},
|
||||
run: func(c *model.UserContext) any { return c.IsAuthenticated() },
|
||||
run: func(t *testing.T, c *model.UserContext) any { return c.IsAuthenticated() },
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
description: "IsLocal returns true for ProviderLocal",
|
||||
context: &model.UserContext{Provider: model.ProviderLocal},
|
||||
run: func(c *model.UserContext) any { return c.IsLocal() },
|
||||
context: &model.UserContext{Provider: model.ProviderLocal, Local: &model.LocalContext{}},
|
||||
run: func(t *testing.T, c *model.UserContext) any { return c.IsLocal() },
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
description: "IsOAuth returns true for ProviderOAuth",
|
||||
context: &model.UserContext{Provider: model.ProviderOAuth},
|
||||
run: func(c *model.UserContext) any { return c.IsOAuth() },
|
||||
context: &model.UserContext{Provider: model.ProviderOAuth, OAuth: &model.OAuthContext{}},
|
||||
run: func(t *testing.T, c *model.UserContext) any { return c.IsOAuth() },
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
description: "IsLDAP returns true for ProviderLDAP",
|
||||
context: &model.UserContext{Provider: model.ProviderLDAP},
|
||||
run: func(c *model.UserContext) any { return c.IsLDAP() },
|
||||
context: &model.UserContext{Provider: model.ProviderLDAP, LDAP: &model.LDAPContext{}},
|
||||
run: func(t *testing.T, c *model.UserContext) any { return c.IsLDAP() },
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
description: "IsBasicAuth returns true for ProviderBasicAuth",
|
||||
context: &model.UserContext{Provider: model.ProviderBasicAuth},
|
||||
run: func(c *model.UserContext) any { return c.IsBasicAuth() },
|
||||
context: &model.UserContext{Provider: model.ProviderBasicAuth, Local: &model.LocalContext{}},
|
||||
run: func(t *testing.T, c *model.UserContext) any { return c.IsBasicAuth() },
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
description: "NewFromSession local session is authenticated and ProviderLocal",
|
||||
context: &model.UserContext{},
|
||||
run: func(c *model.UserContext) any {
|
||||
got, _ := c.NewFromSession(&repository.Session{
|
||||
run: func(t *testing.T, c *model.UserContext) any {
|
||||
got, err := c.NewFromSession(&repository.Session{
|
||||
Username: "alice", Email: "alice@example.com", Name: "Alice",
|
||||
Provider: "local",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
return [2]any{got.Provider, got.Authenticated}
|
||||
},
|
||||
expected: [2]any{model.ProviderLocal, true},
|
||||
@@ -70,10 +72,11 @@ func TestContext(t *testing.T) {
|
||||
{
|
||||
description: "NewFromSession local session with TotpPending is not authenticated",
|
||||
context: &model.UserContext{},
|
||||
run: func(c *model.UserContext) any {
|
||||
got, _ := c.NewFromSession(&repository.Session{
|
||||
run: func(t *testing.T, c *model.UserContext) any {
|
||||
got, err := c.NewFromSession(&repository.Session{
|
||||
Username: "bob", Provider: "local", TotpPending: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
return got.Authenticated
|
||||
},
|
||||
expected: false,
|
||||
@@ -81,10 +84,11 @@ func TestContext(t *testing.T) {
|
||||
{
|
||||
description: "NewFromSession ldap session is ProviderLDAP",
|
||||
context: &model.UserContext{},
|
||||
run: func(c *model.UserContext) any {
|
||||
got, _ := c.NewFromSession(&repository.Session{
|
||||
run: func(t *testing.T, c *model.UserContext) any {
|
||||
got, err := c.NewFromSession(&repository.Session{
|
||||
Username: "carol", Provider: "ldap",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
return got.Provider
|
||||
},
|
||||
expected: model.ProviderLDAP,
|
||||
@@ -92,11 +96,12 @@ func TestContext(t *testing.T) {
|
||||
{
|
||||
description: "NewFromSession unknown provider defaults to OAuth and populates oauth fields",
|
||||
context: &model.UserContext{},
|
||||
run: func(c *model.UserContext) any {
|
||||
got, _ := c.NewFromSession(&repository.Session{
|
||||
run: func(t *testing.T, c *model.UserContext) any {
|
||||
got, err := c.NewFromSession(&repository.Session{
|
||||
Username: "dave", Provider: "github",
|
||||
OAuthGroups: "devs,admins", OAuthSub: "sub-123", OAuthName: "GitHub",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
return [5]any{got.Provider, got.OAuth.ID, got.OAuth.Sub, got.OAuth.DisplayName, got.OAuth.Groups}
|
||||
},
|
||||
expected: [5]any{model.ProviderOAuth, "github", "sub-123", "GitHub", []string{"devs", "admins"}},
|
||||
@@ -107,7 +112,7 @@ func TestContext(t *testing.T) {
|
||||
Provider: model.ProviderLocal,
|
||||
Local: &model.LocalContext{BaseContext: model.BaseContext{Username: "alice", Email: "alice@example.com", Name: "Alice"}},
|
||||
},
|
||||
run: func(c *model.UserContext) any {
|
||||
run: func(t *testing.T, c *model.UserContext) any {
|
||||
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
|
||||
},
|
||||
expected: [3]string{"alice", "alice@example.com", "Alice"},
|
||||
@@ -118,7 +123,7 @@ func TestContext(t *testing.T) {
|
||||
Provider: model.ProviderBasicAuth,
|
||||
Local: &model.LocalContext{BaseContext: model.BaseContext{Username: "bob", Email: "bob@example.com", Name: "Bob"}},
|
||||
},
|
||||
run: func(c *model.UserContext) any {
|
||||
run: func(t *testing.T, c *model.UserContext) any {
|
||||
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
|
||||
},
|
||||
expected: [3]string{"bob", "bob@example.com", "Bob"},
|
||||
@@ -129,7 +134,7 @@ func TestContext(t *testing.T) {
|
||||
Provider: model.ProviderLDAP,
|
||||
LDAP: &model.LDAPContext{BaseContext: model.BaseContext{Username: "carol", Email: "carol@example.com", Name: "Carol"}},
|
||||
},
|
||||
run: func(c *model.UserContext) any {
|
||||
run: func(t *testing.T, c *model.UserContext) any {
|
||||
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
|
||||
},
|
||||
expected: [3]string{"carol", "carol@example.com", "Carol"},
|
||||
@@ -140,7 +145,7 @@ func TestContext(t *testing.T) {
|
||||
Provider: model.ProviderOAuth,
|
||||
OAuth: &model.OAuthContext{BaseContext: model.BaseContext{Username: "dave", Email: "dave@example.com", Name: "Dave"}},
|
||||
},
|
||||
run: func(c *model.UserContext) any {
|
||||
run: func(t *testing.T, c *model.UserContext) any {
|
||||
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
|
||||
},
|
||||
expected: [3]string{"dave", "dave@example.com", "Dave"},
|
||||
@@ -148,19 +153,19 @@ func TestContext(t *testing.T) {
|
||||
{
|
||||
description: "ProviderName returns 'local' for ProviderLocal",
|
||||
context: &model.UserContext{Provider: model.ProviderLocal},
|
||||
run: func(c *model.UserContext) any { return c.ProviderName() },
|
||||
run: func(t *testing.T, c *model.UserContext) any { return c.ProviderName() },
|
||||
expected: "local",
|
||||
},
|
||||
{
|
||||
description: "ProviderName returns 'local' for ProviderBasicAuth",
|
||||
context: &model.UserContext{Provider: model.ProviderBasicAuth},
|
||||
run: func(c *model.UserContext) any { return c.ProviderName() },
|
||||
run: func(t *testing.T, c *model.UserContext) any { return c.ProviderName() },
|
||||
expected: "local",
|
||||
},
|
||||
{
|
||||
description: "ProviderName returns 'ldap' for ProviderLDAP",
|
||||
context: &model.UserContext{Provider: model.ProviderLDAP},
|
||||
run: func(c *model.UserContext) any { return c.ProviderName() },
|
||||
run: func(t *testing.T, c *model.UserContext) any { return c.ProviderName() },
|
||||
expected: "ldap",
|
||||
},
|
||||
{
|
||||
@@ -169,7 +174,7 @@ func TestContext(t *testing.T) {
|
||||
Provider: model.ProviderOAuth,
|
||||
OAuth: &model.OAuthContext{DisplayName: "GitHub"},
|
||||
},
|
||||
run: func(c *model.UserContext) any { return c.ProviderName() },
|
||||
run: func(t *testing.T, c *model.UserContext) any { return c.ProviderName() },
|
||||
expected: "GitHub",
|
||||
},
|
||||
{
|
||||
@@ -178,7 +183,7 @@ func TestContext(t *testing.T) {
|
||||
Provider: model.ProviderLocal,
|
||||
Local: &model.LocalContext{TOTPPending: true},
|
||||
},
|
||||
run: func(c *model.UserContext) any { return c.TOTPPending() },
|
||||
run: func(t *testing.T, c *model.UserContext) any { return c.TOTPPending() },
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
@@ -187,13 +192,13 @@ func TestContext(t *testing.T) {
|
||||
Provider: model.ProviderLocal,
|
||||
Local: &model.LocalContext{TOTPPending: false},
|
||||
},
|
||||
run: func(c *model.UserContext) any { return c.TOTPPending() },
|
||||
run: func(t *testing.T, c *model.UserContext) any { return c.TOTPPending() },
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
description: "TOTPPending returns false for non-local providers",
|
||||
context: &model.UserContext{Provider: model.ProviderOAuth, OAuth: &model.OAuthContext{}},
|
||||
run: func(c *model.UserContext) any { return c.TOTPPending() },
|
||||
run: func(t *testing.T, c *model.UserContext) any { return c.TOTPPending() },
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
@@ -202,28 +207,26 @@ func TestContext(t *testing.T) {
|
||||
Provider: model.ProviderOAuth,
|
||||
OAuth: &model.OAuthContext{DisplayName: "Google"},
|
||||
},
|
||||
run: func(c *model.UserContext) any { return c.OAuthName() },
|
||||
run: func(t *testing.T, c *model.UserContext) any { return c.OAuthName() },
|
||||
expected: "Google",
|
||||
},
|
||||
{
|
||||
description: "OAuthName returns empty string for non-oauth providers",
|
||||
context: &model.UserContext{Provider: model.ProviderLocal, Local: &model.LocalContext{}},
|
||||
run: func(c *model.UserContext) any { return c.OAuthName() },
|
||||
run: func(t *testing.T, c *model.UserContext) any { return c.OAuthName() },
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
description: "NewFromGin populates context from gin value",
|
||||
context: &model.UserContext{},
|
||||
run: func(c *model.UserContext) any {
|
||||
run: func(t *testing.T, c *model.UserContext) any {
|
||||
stored := &model.UserContext{
|
||||
Authenticated: true,
|
||||
Provider: model.ProviderLocal,
|
||||
Local: &model.LocalContext{BaseContext: model.BaseContext{Username: "alice"}},
|
||||
}
|
||||
got, err := c.NewFromGin(newGinCtx(stored, true))
|
||||
if err != nil {
|
||||
return err.Error()
|
||||
}
|
||||
require.NoError(t, err)
|
||||
return [2]any{got.Authenticated, got.GetUsername()}
|
||||
},
|
||||
expected: [2]any{true, "alice"},
|
||||
@@ -231,7 +234,7 @@ func TestContext(t *testing.T) {
|
||||
{
|
||||
description: "NewFromGin returns error when context value is missing",
|
||||
context: &model.UserContext{},
|
||||
run: func(c *model.UserContext) any {
|
||||
run: func(t *testing.T, c *model.UserContext) any {
|
||||
_, err := c.NewFromGin(newGinCtx(nil, false))
|
||||
return err.Error()
|
||||
},
|
||||
@@ -240,17 +243,34 @@ func TestContext(t *testing.T) {
|
||||
{
|
||||
description: "NewFromGin returns error when context value has wrong type",
|
||||
context: &model.UserContext{},
|
||||
run: func(c *model.UserContext) any {
|
||||
run: func(t *testing.T, c *model.UserContext) any {
|
||||
_, err := c.NewFromGin(newGinCtx("not a user context", true))
|
||||
return err.Error()
|
||||
},
|
||||
expected: "invalid user context type",
|
||||
},
|
||||
{
|
||||
description: "NewFromGin returns an error when context doesn't include user information",
|
||||
context: &model.UserContext{},
|
||||
run: func(t *testing.T, c *model.UserContext) any {
|
||||
_, err := c.NewFromGin(newGinCtx(&model.UserContext{Provider: model.ProviderLocal}, true))
|
||||
return err.Error()
|
||||
},
|
||||
expected: "incomplete user context",
|
||||
},
|
||||
{
|
||||
description: "Getters should not panic if provider context is empty",
|
||||
context: &model.UserContext{Provider: model.ProviderLocal},
|
||||
run: func(t *testing.T, c *model.UserContext) any {
|
||||
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
|
||||
},
|
||||
expected: [3]string{"", "", ""},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.description, func(t *testing.T) {
|
||||
assert.Equal(t, test.expected, test.run(test.context))
|
||||
assert.Equal(t, test.expected, test.run(t, test.context))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -28,18 +28,21 @@ func (acls *AccessControlsService) Init() error {
|
||||
}
|
||||
|
||||
func (acls *AccessControlsService) lookupStaticACLs(domain string) *model.App {
|
||||
var appAcls *model.App
|
||||
for app, config := range acls.static {
|
||||
if config.Config.Domain == domain {
|
||||
tlog.App.Debug().Str("name", app).Msg("Found matching container by domain")
|
||||
return &config
|
||||
appAcls = &config
|
||||
break // If we find a match by domain, we can stop searching
|
||||
}
|
||||
|
||||
if strings.SplitN(domain, ".", 2)[0] == app {
|
||||
tlog.App.Debug().Str("name", app).Msg("Found matching container by app name")
|
||||
return &config
|
||||
appAcls = &config
|
||||
break // If we find a match by app name, we can stop searching
|
||||
}
|
||||
}
|
||||
return nil
|
||||
return appAcls
|
||||
}
|
||||
|
||||
func (acls *AccessControlsService) GetAccessControls(domain string) (*model.App, error) {
|
||||
|
||||
@@ -73,7 +73,7 @@ type Lockdown struct {
|
||||
}
|
||||
|
||||
type AuthServiceConfig struct {
|
||||
LocalUsers []model.LocalUser
|
||||
LocalUsers *[]model.LocalUser
|
||||
OauthWhitelist []string
|
||||
SessionExpiry int
|
||||
SessionMaxLifetime int
|
||||
@@ -120,7 +120,7 @@ func (auth *AuthService) Init() error {
|
||||
}
|
||||
|
||||
func (auth *AuthService) SearchUser(username string) (*model.UserSearch, error) {
|
||||
if auth.GetLocalUser(username).Username != "" {
|
||||
if auth.GetLocalUser(username) != nil {
|
||||
return &model.UserSearch{
|
||||
Username: username,
|
||||
Type: model.UserLocal,
|
||||
@@ -147,6 +147,9 @@ func (auth *AuthService) CheckUserPassword(search model.UserSearch, password str
|
||||
switch search.Type {
|
||||
case model.UserLocal:
|
||||
user := auth.GetLocalUser(search.Username)
|
||||
if user == nil {
|
||||
return ErrUserNotFound
|
||||
}
|
||||
return bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password))
|
||||
case model.UserLDAP:
|
||||
if auth.ldap.IsConfigured() {
|
||||
@@ -169,7 +172,10 @@ func (auth *AuthService) CheckUserPassword(search model.UserSearch, password str
|
||||
}
|
||||
|
||||
func (auth *AuthService) GetLocalUser(username string) *model.LocalUser {
|
||||
for _, user := range auth.config.LocalUsers {
|
||||
if auth.config.LocalUsers == nil {
|
||||
return nil
|
||||
}
|
||||
for _, user := range *auth.config.LocalUsers {
|
||||
if user.Username == username {
|
||||
return &user
|
||||
}
|
||||
@@ -295,6 +301,8 @@ func (auth *AuthService) CreateSession(ctx context.Context, data repository.Sess
|
||||
expiry = auth.config.SessionExpiry
|
||||
}
|
||||
|
||||
expiresAt := time.Now().Add(time.Duration(expiry) * time.Second)
|
||||
|
||||
session := repository.CreateSessionParams{
|
||||
UUID: uuid.String(),
|
||||
Username: data.Username,
|
||||
@@ -303,7 +311,7 @@ func (auth *AuthService) CreateSession(ctx context.Context, data repository.Sess
|
||||
Provider: data.Provider,
|
||||
TotpPending: data.TotpPending,
|
||||
OAuthGroups: data.OAuthGroups,
|
||||
Expiry: time.Now().Add(time.Duration(expiry) * time.Second).Unix(),
|
||||
Expiry: expiresAt.Unix(),
|
||||
CreatedAt: time.Now().Unix(),
|
||||
OAuthName: data.OAuthName,
|
||||
OAuthSub: data.OAuthSub,
|
||||
@@ -320,8 +328,8 @@ func (auth *AuthService) CreateSession(ctx context.Context, data repository.Sess
|
||||
Value: session.UUID,
|
||||
Path: "/",
|
||||
Domain: fmt.Sprintf(".%s", auth.config.CookieDomain),
|
||||
Expires: time.Now().Add(time.Duration(expiry) * time.Second),
|
||||
MaxAge: expiry,
|
||||
Expires: expiresAt,
|
||||
MaxAge: int(time.Until(expiresAt).Seconds()),
|
||||
Secure: auth.config.SecureCookie,
|
||||
HttpOnly: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
@@ -374,7 +382,7 @@ func (auth *AuthService) RefreshSession(ctx context.Context, uuid string) (*http
|
||||
Path: "/",
|
||||
Domain: fmt.Sprintf(".%s", auth.config.CookieDomain),
|
||||
Expires: time.Now().Add(time.Duration(newExpiry-currentTime) * time.Second),
|
||||
MaxAge: auth.config.SessionExpiry,
|
||||
MaxAge: int(newExpiry - currentTime),
|
||||
Secure: auth.config.SecureCookie,
|
||||
HttpOnly: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
@@ -436,7 +444,7 @@ func (auth *AuthService) GetSession(ctx context.Context, uuid string) (*reposito
|
||||
}
|
||||
|
||||
func (auth *AuthService) LocalAuthConfigured() bool {
|
||||
return len(auth.config.LocalUsers) > 0
|
||||
return auth.config.LocalUsers != nil && len(*auth.config.LocalUsers) > 0
|
||||
}
|
||||
|
||||
func (auth *AuthService) LDAPAuthConfigured() bool {
|
||||
|
||||
@@ -95,7 +95,8 @@ func (k *KubernetesService) getByDomain(domain string) *model.App {
|
||||
|
||||
if appKey, ok := k.domainIndex[domain]; ok {
|
||||
if apps, ok := k.ingressApps[appKey.ingressKey]; ok {
|
||||
for _, app := range apps {
|
||||
for i := range apps {
|
||||
app := &apps[i]
|
||||
if app.domain == domain && app.appName == appKey.appName {
|
||||
return &app.app
|
||||
}
|
||||
@@ -111,7 +112,8 @@ func (k *KubernetesService) getByAppName(appName string) *model.App {
|
||||
|
||||
if appKey, ok := k.appNameIndex[appName]; ok {
|
||||
if apps, ok := k.ingressApps[appKey.ingressKey]; ok {
|
||||
for _, app := range apps {
|
||||
for i := range apps {
|
||||
app := &apps[i]
|
||||
if app.appName == appName {
|
||||
return &app.app
|
||||
}
|
||||
|
||||
@@ -12,8 +12,9 @@ import (
|
||||
)
|
||||
|
||||
type GithubEmailResponse []struct {
|
||||
Email string `json:"email"`
|
||||
Primary bool `json:"primary"`
|
||||
Email string `json:"email"`
|
||||
Primary bool `json:"primary"`
|
||||
Verified bool `json:"verified"`
|
||||
}
|
||||
|
||||
type GithubUserInfoResponse struct {
|
||||
@@ -56,7 +57,16 @@ func githubExtractor(client *http.Client, url string) (*model.Claims, error) {
|
||||
|
||||
// Use first available email if no primary email was found
|
||||
if user.Email == "" {
|
||||
user.Email = (*userEmails)[0].Email
|
||||
for _, email := range *userEmails {
|
||||
if email.Verified {
|
||||
user.Email = email.Email
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if user.Email == "" {
|
||||
return nil, errors.New("no verified email found")
|
||||
}
|
||||
|
||||
user.PreferredUsername = userInfo.Login
|
||||
|
||||
@@ -5,18 +5,19 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestReadFile(t *testing.T) {
|
||||
// Setup
|
||||
file, err := os.Create("/tmp/tinyauth_test_file")
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = file.WriteString("file content\n")
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = file.Close()
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
defer os.Remove("/tmp/tinyauth_test_file")
|
||||
|
||||
// Normal case
|
||||
|
||||
@@ -5,19 +5,20 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||
)
|
||||
|
||||
func TestGetSecret(t *testing.T) {
|
||||
// Setup
|
||||
file, err := os.Create("/tmp/tinyauth_test_secret")
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = file.WriteString(" secret \n")
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = file.Close()
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
defer os.Remove("/tmp/tinyauth_test_secret")
|
||||
|
||||
// Get from config
|
||||
|
||||
@@ -5,28 +5,31 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||
)
|
||||
|
||||
func TestGetUsers(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
hash := "$2a$10$Mz5xhkfSJUtPWkzCd/TdaePh9CaXc5QcGII5wIMPLSR46eTwma30G"
|
||||
|
||||
// Setup
|
||||
file, err := os.Create("/tmp/tinyauth_users_test.txt")
|
||||
assert.NoError(t, err)
|
||||
file, err := os.Create(tmpDir + "/tinyauth_users_test.txt")
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = file.WriteString(" user1:" + hash + " \n user2:" + hash + " ") // Spacing is on purpose
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = file.Close()
|
||||
assert.NoError(t, err)
|
||||
defer os.Remove("/tmp/tinyauth_users_test.txt")
|
||||
require.NoError(t, err)
|
||||
defer os.Remove(tmpDir + "/tinyauth_users_test.txt")
|
||||
|
||||
noAttrs := map[string]model.UserAttributes{}
|
||||
|
||||
// Test file only
|
||||
users, err := utils.GetUsers([]string{}, "/tmp/tinyauth_users_test.txt", noAttrs)
|
||||
users, err := utils.GetUsers([]string{}, tmpDir+"/tinyauth_users_test.txt", noAttrs)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, users)
|
||||
@@ -47,7 +50,7 @@ func TestGetUsers(t *testing.T) {
|
||||
assert.Equal(t, "user4", (*users)[1].Username)
|
||||
|
||||
// Test both
|
||||
users, err = utils.GetUsers([]string{"user5:" + hash}, "/tmp/tinyauth_users_test.txt", noAttrs)
|
||||
users, err = utils.GetUsers([]string{"user5:" + hash}, tmpDir+"/tinyauth_users_test.txt", noAttrs)
|
||||
|
||||
assert.NoError(t, err)
|
||||
|
||||
@@ -65,7 +68,7 @@ func TestGetUsers(t *testing.T) {
|
||||
attrs := map[string]model.UserAttributes{
|
||||
"user1": {Name: "User One", Email: "user1@example.com"},
|
||||
}
|
||||
users, err = utils.GetUsers([]string{}, "/tmp/tinyauth_users_test.txt", attrs)
|
||||
users, err = utils.GetUsers([]string{}, tmpDir+"/tinyauth_users_test.txt", attrs)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, *users, 2)
|
||||
@@ -87,7 +90,7 @@ func TestGetUsers(t *testing.T) {
|
||||
assert.Nil(t, users)
|
||||
|
||||
// Test non-existent file
|
||||
users, err = utils.GetUsers([]string{}, "/tmp/non_existent_file.txt", noAttrs)
|
||||
users, err = utils.GetUsers([]string{}, tmpDir+"/non_existent_file.txt", noAttrs)
|
||||
|
||||
assert.ErrorContains(t, err, "no such file or directory")
|
||||
assert.Nil(t, users)
|
||||
|
||||
Reference in New Issue
Block a user