mirror of
https://github.com/steveiliop56/tinyauth.git
synced 2026-05-08 13:28:12 +00:00
Merge branch 'main' into feat/oauth-whitelist-file
This commit is contained in:
@@ -1,10 +1,13 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/tinyauthapp/tinyauth/internal/config"
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||
@@ -33,7 +36,8 @@ var (
|
||||
)
|
||||
|
||||
type ContextMiddlewareConfig struct {
|
||||
CookieDomain string
|
||||
CookieDomain string
|
||||
SessionCookieName string
|
||||
}
|
||||
|
||||
type ContextMiddleware struct {
|
||||
@@ -61,194 +65,41 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
cookie, err := m.auth.GetSessionCookie(c)
|
||||
uuid, err := c.Cookie(m.config.SessionCookieName)
|
||||
|
||||
if err != nil {
|
||||
tlog.App.Debug().Err(err).Msg("No valid session cookie found")
|
||||
goto basic
|
||||
}
|
||||
if err == nil {
|
||||
userContext, cookie, err := m.cookieAuth(c.Request.Context(), uuid)
|
||||
|
||||
if cookie.TotpPending {
|
||||
c.Set("context", &config.UserContext{
|
||||
Username: cookie.Username,
|
||||
Name: cookie.Name,
|
||||
Email: cookie.Email,
|
||||
Provider: "local",
|
||||
TotpPending: true,
|
||||
TotpEnabled: true,
|
||||
})
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
switch cookie.Provider {
|
||||
case "local", "ldap":
|
||||
userSearch := m.auth.SearchUser(cookie.Username)
|
||||
|
||||
if userSearch.Type == "unknown" {
|
||||
tlog.App.Debug().Msg("User from session cookie not found")
|
||||
m.auth.DeleteSessionCookie(c)
|
||||
goto basic
|
||||
}
|
||||
|
||||
if userSearch.Type != cookie.Provider {
|
||||
tlog.App.Warn().Msg("User type from session cookie does not match user search type")
|
||||
m.auth.DeleteSessionCookie(c)
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
var ldapGroups []string
|
||||
var localAttributes config.UserAttributes
|
||||
|
||||
if cookie.Provider == "ldap" {
|
||||
ldapUser, err := m.auth.GetLdapUser(userSearch.Username)
|
||||
|
||||
if err != nil {
|
||||
tlog.App.Error().Err(err).Msg("Error retrieving LDAP user details")
|
||||
c.Next()
|
||||
return
|
||||
if err == nil {
|
||||
if cookie != nil {
|
||||
http.SetCookie(c.Writer, cookie)
|
||||
}
|
||||
|
||||
ldapGroups = ldapUser.Groups
|
||||
}
|
||||
|
||||
if cookie.Provider == "local" {
|
||||
localUser := m.auth.GetLocalUser(cookie.Username)
|
||||
localAttributes = localUser.Attributes
|
||||
}
|
||||
|
||||
m.auth.RefreshSessionCookie(c)
|
||||
c.Set("context", &config.UserContext{
|
||||
Username: cookie.Username,
|
||||
Name: cookie.Name,
|
||||
Email: cookie.Email,
|
||||
Provider: cookie.Provider,
|
||||
IsLoggedIn: true,
|
||||
LdapGroups: strings.Join(ldapGroups, ","),
|
||||
Attributes: localAttributes,
|
||||
})
|
||||
c.Next()
|
||||
return
|
||||
default:
|
||||
_, exists := m.broker.GetService(cookie.Provider)
|
||||
|
||||
if !exists {
|
||||
tlog.App.Debug().Msg("OAuth provider from session cookie not found")
|
||||
m.auth.DeleteSessionCookie(c)
|
||||
goto basic
|
||||
}
|
||||
|
||||
if !m.auth.IsEmailWhitelisted(cookie.Email) {
|
||||
tlog.App.Debug().Msg("Email from session cookie not whitelisted")
|
||||
m.auth.DeleteSessionCookie(c)
|
||||
goto basic
|
||||
}
|
||||
|
||||
m.auth.RefreshSessionCookie(c)
|
||||
c.Set("context", &config.UserContext{
|
||||
Username: cookie.Username,
|
||||
Name: cookie.Name,
|
||||
Email: cookie.Email,
|
||||
Provider: cookie.Provider,
|
||||
OAuthGroups: cookie.OAuthGroups,
|
||||
OAuthName: cookie.OAuthName,
|
||||
OAuthSub: cookie.OAuthSub,
|
||||
IsLoggedIn: true,
|
||||
OAuth: true,
|
||||
})
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
basic:
|
||||
basic := m.auth.GetBasicAuth(c)
|
||||
|
||||
if basic == nil {
|
||||
tlog.App.Debug().Msg("No basic auth provided")
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
locked, remaining := m.auth.IsAccountLocked(basic.Username)
|
||||
|
||||
if locked {
|
||||
tlog.App.Debug().Msgf("Account for user %s is locked for %d seconds, denying auth", basic.Username, remaining)
|
||||
c.Writer.Header().Add("x-tinyauth-lock-locked", "true")
|
||||
c.Writer.Header().Add("x-tinyauth-lock-reset", time.Now().Add(time.Duration(remaining)*time.Second).Format(time.RFC3339))
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
userSearch := m.auth.SearchUser(basic.Username)
|
||||
|
||||
if userSearch.Type == "unknown" || userSearch.Type == "error" {
|
||||
m.auth.RecordLoginAttempt(basic.Username, false)
|
||||
tlog.App.Debug().Msg("User from basic auth not found")
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
if !m.auth.VerifyUser(userSearch, basic.Password) {
|
||||
m.auth.RecordLoginAttempt(basic.Username, false)
|
||||
tlog.App.Debug().Msg("Invalid password for basic auth user")
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
m.auth.RecordLoginAttempt(basic.Username, true)
|
||||
|
||||
switch userSearch.Type {
|
||||
case "local":
|
||||
tlog.App.Debug().Msg("Basic auth user is local")
|
||||
|
||||
user := m.auth.GetLocalUser(basic.Username)
|
||||
|
||||
if user.TotpSecret != "" {
|
||||
tlog.App.Debug().Msg("User with TOTP not allowed to login via basic auth")
|
||||
tlog.App.Trace().Msgf("Authenticated user from session cookie: %s", userContext.GetUsername())
|
||||
c.Set("context", userContext)
|
||||
c.Next()
|
||||
return
|
||||
} else {
|
||||
tlog.App.Error().Msgf("Error authenticating session cookie: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
name := utils.Capitalize(user.Username)
|
||||
if user.Attributes.Name != "" {
|
||||
name = user.Attributes.Name
|
||||
}
|
||||
email := utils.CompileUserEmail(user.Username, m.config.CookieDomain)
|
||||
if user.Attributes.Email != "" {
|
||||
email = user.Attributes.Email
|
||||
}
|
||||
username, password, ok := c.Request.BasicAuth()
|
||||
|
||||
c.Set("context", &config.UserContext{
|
||||
Username: user.Username,
|
||||
Name: name,
|
||||
Email: email,
|
||||
Provider: "local",
|
||||
IsLoggedIn: true,
|
||||
IsBasicAuth: true,
|
||||
Attributes: user.Attributes,
|
||||
})
|
||||
c.Next()
|
||||
return
|
||||
case "ldap":
|
||||
tlog.App.Debug().Msg("Basic auth user is LDAP")
|
||||
|
||||
ldapUser, err := m.auth.GetLdapUser(basic.Username)
|
||||
if ok {
|
||||
userContext, headers, err := m.basicAuth(username, password)
|
||||
|
||||
if err != nil {
|
||||
tlog.App.Debug().Err(err).Msg("Error retrieving LDAP user details")
|
||||
tlog.App.Error().Msgf("Error authenticating basic auth: %v", err)
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
c.Set("context", &config.UserContext{
|
||||
Username: basic.Username,
|
||||
Name: utils.Capitalize(basic.Username),
|
||||
Email: utils.CompileUserEmail(basic.Username, m.config.CookieDomain),
|
||||
Provider: "ldap",
|
||||
IsLoggedIn: true,
|
||||
LdapGroups: strings.Join(ldapUser.Groups, ","),
|
||||
IsBasicAuth: true,
|
||||
})
|
||||
for k, v := range headers {
|
||||
c.Header(k, v)
|
||||
}
|
||||
|
||||
c.Set("context", userContext)
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
@@ -257,6 +108,149 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
func (m *ContextMiddleware) cookieAuth(ctx context.Context, uuid string) (*model.UserContext, *http.Cookie, error) {
|
||||
session, err := m.auth.GetSession(ctx, uuid)
|
||||
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("error retrieving session: %w", err)
|
||||
}
|
||||
|
||||
userContext, err := new(model.UserContext).NewFromSession(session)
|
||||
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("error creating user context from session: %w", err)
|
||||
}
|
||||
|
||||
if userContext.Provider == model.ProviderLocal &&
|
||||
userContext.Local.TOTPPending {
|
||||
return userContext, nil, nil
|
||||
}
|
||||
|
||||
switch userContext.Provider {
|
||||
case model.ProviderLocal:
|
||||
user := m.auth.GetLocalUser(userContext.Local.Username)
|
||||
|
||||
if user == nil {
|
||||
return nil, nil, fmt.Errorf("local user not found")
|
||||
}
|
||||
|
||||
userContext.Local.Attributes = user.Attributes
|
||||
|
||||
if userContext.Local.Attributes.Name == "" {
|
||||
userContext.Local.Attributes.Name = utils.Capitalize(user.Username)
|
||||
}
|
||||
|
||||
if userContext.Local.Attributes.Email == "" {
|
||||
userContext.Local.Attributes.Email = utils.CompileUserEmail(user.Username, m.config.CookieDomain)
|
||||
}
|
||||
case model.ProviderLDAP:
|
||||
search, err := m.auth.SearchUser(userContext.LDAP.Username)
|
||||
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("error searching for ldap user: %w", err)
|
||||
}
|
||||
|
||||
if search.Type != model.UserLDAP {
|
||||
return nil, nil, fmt.Errorf("user from session cookie is not ldap")
|
||||
}
|
||||
|
||||
user, err := m.auth.GetLDAPUser(search.Username)
|
||||
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("error retrieving ldap user details: %w", err)
|
||||
}
|
||||
|
||||
userContext.LDAP.Groups = user.Groups
|
||||
userContext.LDAP.Name = utils.Capitalize(userContext.LDAP.Username)
|
||||
userContext.LDAP.Email = utils.CompileUserEmail(userContext.LDAP.Username, m.config.CookieDomain)
|
||||
case model.ProviderOAuth:
|
||||
_, exists := m.broker.GetService(userContext.OAuth.ID)
|
||||
|
||||
if !exists {
|
||||
return nil, nil, fmt.Errorf("oauth provider from session cookie not found: %s", userContext.OAuth.ID)
|
||||
}
|
||||
|
||||
if !m.auth.IsEmailWhitelisted(userContext.OAuth.Email) {
|
||||
m.auth.DeleteSession(ctx, uuid)
|
||||
return nil, nil, fmt.Errorf("email from session cookie not whitelisted: %s", userContext.OAuth.Email)
|
||||
}
|
||||
}
|
||||
|
||||
cookie, err := m.auth.RefreshSession(ctx, uuid)
|
||||
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("error refreshing session: %w", err)
|
||||
}
|
||||
|
||||
return userContext, cookie, nil
|
||||
}
|
||||
|
||||
func (m *ContextMiddleware) basicAuth(username string, password string) (*model.UserContext, map[string]string, error) {
|
||||
headers := make(map[string]string)
|
||||
userContext := new(model.UserContext)
|
||||
locked, remaining := m.auth.IsAccountLocked(username)
|
||||
|
||||
if locked {
|
||||
tlog.App.Debug().Msgf("Account for user %s is locked for %d seconds, denying auth", username, remaining)
|
||||
headers["x-tinyauth-lock-locked"] = "true"
|
||||
headers["x-tinyauth-lock-reset"] = time.Now().Add(time.Duration(remaining) * time.Second).Format(time.RFC3339)
|
||||
return nil, headers, nil
|
||||
}
|
||||
|
||||
search, err := m.auth.SearchUser(username)
|
||||
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("error searching for user: %w", err)
|
||||
}
|
||||
|
||||
err = m.auth.CheckUserPassword(*search, password)
|
||||
|
||||
if err != nil {
|
||||
m.auth.RecordLoginAttempt(username, false)
|
||||
return nil, nil, fmt.Errorf("invalid password for basic auth user: %w", err)
|
||||
}
|
||||
|
||||
m.auth.RecordLoginAttempt(username, true)
|
||||
|
||||
switch search.Type {
|
||||
case model.UserLocal:
|
||||
user := m.auth.GetLocalUser(username)
|
||||
|
||||
if user.TOTPSecret != "" {
|
||||
return nil, nil, fmt.Errorf("user with totp not allowed to login via basic auth: %s", username)
|
||||
}
|
||||
|
||||
userContext.Local = &model.LocalContext{
|
||||
BaseContext: model.BaseContext{
|
||||
Username: user.Username,
|
||||
Name: utils.Capitalize(user.Username),
|
||||
Email: utils.CompileUserEmail(user.Username, m.config.CookieDomain),
|
||||
},
|
||||
Attributes: user.Attributes,
|
||||
}
|
||||
userContext.Provider = model.ProviderLocal
|
||||
case model.UserLDAP:
|
||||
user, err := m.auth.GetLDAPUser(username)
|
||||
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("error retrieving ldap user details: %w", err)
|
||||
}
|
||||
|
||||
userContext.LDAP = &model.LDAPContext{
|
||||
BaseContext: model.BaseContext{
|
||||
Username: username,
|
||||
Name: utils.Capitalize(username),
|
||||
Email: utils.CompileUserEmail(username, m.config.CookieDomain),
|
||||
},
|
||||
Groups: user.Groups,
|
||||
}
|
||||
userContext.Provider = model.ProviderLDAP
|
||||
}
|
||||
|
||||
userContext.Authenticated = true
|
||||
return userContext, nil, nil
|
||||
}
|
||||
|
||||
func (m *ContextMiddleware) isIgnorePath(path string) bool {
|
||||
for _, prefix := range contextSkipPathsPrefix {
|
||||
if strings.HasPrefix(path, prefix) {
|
||||
|
||||
@@ -0,0 +1,328 @@
|
||||
package middleware_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"path"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tinyauthapp/tinyauth/internal/bootstrap"
|
||||
"github.com/tinyauthapp/tinyauth/internal/middleware"
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||
)
|
||||
|
||||
func TestContextMiddleware(t *testing.T) {
|
||||
tlog.NewTestLogger().Init()
|
||||
tempDir := t.TempDir()
|
||||
|
||||
authServiceCfg := service.AuthServiceConfig{
|
||||
LocalUsers: &[]model.LocalUser{
|
||||
{
|
||||
Username: "testuser",
|
||||
Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
|
||||
},
|
||||
{
|
||||
Username: "totpuser",
|
||||
Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
|
||||
TOTPSecret: "JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK",
|
||||
},
|
||||
},
|
||||
SessionExpiry: 10, // 10 seconds, useful for testing
|
||||
CookieDomain: "example.com",
|
||||
LoginTimeout: 10, // 10 seconds, useful for testing
|
||||
LoginMaxRetries: 3,
|
||||
SessionCookieName: "tinyauth-session",
|
||||
}
|
||||
|
||||
middlewareCfg := middleware.ContextMiddlewareConfig{
|
||||
CookieDomain: "example.com",
|
||||
SessionCookieName: "tinyauth-session",
|
||||
}
|
||||
|
||||
basicAuthHeader := func(username, password string) string {
|
||||
return "Basic " + base64.StdEncoding.EncodeToString([]byte(username+":"+password))
|
||||
}
|
||||
|
||||
seedSession := func(t *testing.T, queries *repository.Queries, params repository.CreateSessionParams) {
|
||||
t.Helper()
|
||||
_, err := queries.CreateSession(context.Background(), params)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
type runArgs struct {
|
||||
do func(req *http.Request) (*model.UserContext, *httptest.ResponseRecorder)
|
||||
queries *repository.Queries
|
||||
}
|
||||
|
||||
type testCase struct {
|
||||
description string
|
||||
run func(t *testing.T, args runArgs)
|
||||
}
|
||||
|
||||
tests := []testCase{
|
||||
{
|
||||
description: "Skip path bypasses auth processing",
|
||||
run: func(t *testing.T, args runArgs) {
|
||||
req := httptest.NewRequest("GET", "/api/healthz", nil)
|
||||
req.Header.Set("Authorization", basicAuthHeader("testuser", "password"))
|
||||
userCtx, _ := args.do(req)
|
||||
|
||||
assert.Nil(t, userCtx)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "No credentials yields no context",
|
||||
run: func(t *testing.T, args runArgs) {
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
userCtx, _ := args.do(req)
|
||||
|
||||
assert.Nil(t, userCtx)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Valid session cookie sets authenticated local context",
|
||||
run: func(t *testing.T, args runArgs) {
|
||||
uuid := "session-valid-local"
|
||||
seedSession(t, args.queries, repository.CreateSessionParams{
|
||||
UUID: uuid,
|
||||
Username: "testuser",
|
||||
Provider: "local",
|
||||
Expiry: time.Now().Add(10 * time.Second).Unix(),
|
||||
CreatedAt: time.Now().Unix(),
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
req.AddCookie(&http.Cookie{Name: "tinyauth-session", Value: uuid})
|
||||
userCtx, _ := args.do(req)
|
||||
|
||||
require.NotNil(t, userCtx)
|
||||
assert.Equal(t, model.ProviderLocal, userCtx.Provider)
|
||||
assert.Equal(t, "testuser", userCtx.GetUsername())
|
||||
assert.True(t, userCtx.Authenticated)
|
||||
require.NotNil(t, userCtx.Local)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Session cookie with totp pending sets unauthenticated context with totp enabled",
|
||||
run: func(t *testing.T, args runArgs) {
|
||||
uuid := "session-totp-pending"
|
||||
seedSession(t, args.queries, repository.CreateSessionParams{
|
||||
UUID: uuid,
|
||||
Username: "totpuser",
|
||||
Provider: "local",
|
||||
TotpPending: true,
|
||||
Expiry: time.Now().Add(60 * time.Second).Unix(),
|
||||
CreatedAt: time.Now().Unix(),
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
req.AddCookie(&http.Cookie{Name: "tinyauth-session", Value: uuid})
|
||||
userCtx, _ := args.do(req)
|
||||
|
||||
require.NotNil(t, userCtx)
|
||||
assert.Equal(t, "totpuser", userCtx.GetUsername())
|
||||
assert.False(t, userCtx.Authenticated)
|
||||
require.NotNil(t, userCtx.Local)
|
||||
assert.True(t, userCtx.Local.TOTPPending)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Unknown session cookie yields no context",
|
||||
run: func(t *testing.T, args runArgs) {
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
req.AddCookie(&http.Cookie{Name: "tinyauth-session", Value: "does-not-exist"})
|
||||
userCtx, _ := args.do(req)
|
||||
|
||||
assert.Nil(t, userCtx)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Session for missing local user yields no context",
|
||||
run: func(t *testing.T, args runArgs) {
|
||||
uuid := "session-deleted-user"
|
||||
seedSession(t, args.queries, repository.CreateSessionParams{
|
||||
UUID: uuid,
|
||||
Username: "ghostuser",
|
||||
Provider: "local",
|
||||
Expiry: time.Now().Add(10 * time.Second).Unix(),
|
||||
CreatedAt: time.Now().Unix(),
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
req.AddCookie(&http.Cookie{Name: "tinyauth-session", Value: uuid})
|
||||
userCtx, _ := args.do(req)
|
||||
|
||||
assert.Nil(t, userCtx)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Expired session cookie yields no context",
|
||||
run: func(t *testing.T, args runArgs) {
|
||||
uuid := "session-expired"
|
||||
seedSession(t, args.queries, repository.CreateSessionParams{
|
||||
UUID: uuid,
|
||||
Username: "testuser",
|
||||
Provider: "local",
|
||||
Expiry: time.Now().Add(-1 * time.Second).Unix(),
|
||||
CreatedAt: time.Now().Add(-10 * time.Second).Unix(),
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
req.AddCookie(&http.Cookie{Name: "tinyauth-session", Value: uuid})
|
||||
userCtx, _ := args.do(req)
|
||||
|
||||
assert.Nil(t, userCtx)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Valid basic auth sets authenticated local context",
|
||||
run: func(t *testing.T, args runArgs) {
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
req.Header.Set("Authorization", basicAuthHeader("testuser", "password"))
|
||||
userCtx, _ := args.do(req)
|
||||
|
||||
require.NotNil(t, userCtx)
|
||||
assert.Equal(t, model.ProviderLocal, userCtx.Provider)
|
||||
assert.Equal(t, "testuser", userCtx.GetUsername())
|
||||
assert.True(t, userCtx.Authenticated)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Invalid basic auth password yields no context",
|
||||
run: func(t *testing.T, args runArgs) {
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
req.Header.Set("Authorization", basicAuthHeader("testuser", "wrongpassword"))
|
||||
userCtx, _ := args.do(req)
|
||||
|
||||
assert.Nil(t, userCtx)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Basic auth is rejected for users with totp",
|
||||
run: func(t *testing.T, args runArgs) {
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
req.Header.Set("Authorization", basicAuthHeader("totpuser", "password"))
|
||||
userCtx, _ := args.do(req)
|
||||
|
||||
assert.Nil(t, userCtx)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Locked account on basic auth sets lock headers",
|
||||
run: func(t *testing.T, args runArgs) {
|
||||
for range 3 {
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
req.Header.Set("Authorization", basicAuthHeader("testuser", "wrongpassword"))
|
||||
args.do(req)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
req.Header.Set("Authorization", basicAuthHeader("testuser", "password"))
|
||||
userCtx, recorder := args.do(req)
|
||||
|
||||
assert.Nil(t, userCtx)
|
||||
assert.Equal(t, "true", recorder.Header().Get("x-tinyauth-lock-locked"))
|
||||
assert.NotEmpty(t, recorder.Header().Get("x-tinyauth-lock-reset"))
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Cookie auth takes precedence over basic auth",
|
||||
run: func(t *testing.T, args runArgs) {
|
||||
uuid := "session-precedence"
|
||||
seedSession(t, args.queries, repository.CreateSessionParams{
|
||||
UUID: uuid,
|
||||
Username: "testuser",
|
||||
Provider: "local",
|
||||
Expiry: time.Now().Add(10 * time.Second).Unix(),
|
||||
CreatedAt: time.Now().Unix(),
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
req.AddCookie(&http.Cookie{Name: "tinyauth-session", Value: uuid})
|
||||
req.Header.Set("Authorization", basicAuthHeader("totpuser", "password"))
|
||||
userCtx, _ := args.do(req)
|
||||
|
||||
require.NotNil(t, userCtx)
|
||||
assert.Equal(t, "testuser", userCtx.GetUsername())
|
||||
assert.True(t, userCtx.Authenticated)
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "Ensure fallback to basic auth when cookie is missing",
|
||||
run: func(t *testing.T, args runArgs) {
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
req.Header.Set("Authorization", basicAuthHeader("testuser", "password"))
|
||||
userCtx, _ := args.do(req)
|
||||
|
||||
require.NotNil(t, userCtx)
|
||||
assert.Equal(t, "testuser", userCtx.GetUsername())
|
||||
assert.True(t, userCtx.Authenticated)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
oauthBrokerCfgs := make(map[string]model.OAuthServiceConfig)
|
||||
|
||||
app := bootstrap.NewBootstrapApp(model.Config{})
|
||||
|
||||
db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db"))
|
||||
require.NoError(t, err)
|
||||
|
||||
queries := repository.New(db)
|
||||
|
||||
ldap := service.NewLdapService(service.LdapServiceConfig{})
|
||||
err = ldap.Init()
|
||||
require.NoError(t, err)
|
||||
|
||||
broker := service.NewOAuthBrokerService(oauthBrokerCfgs)
|
||||
err = broker.Init()
|
||||
require.NoError(t, err)
|
||||
|
||||
authService := service.NewAuthService(authServiceCfg, ldap, queries, broker)
|
||||
err = authService.Init()
|
||||
require.NoError(t, err)
|
||||
|
||||
contextMiddleware := middleware.NewContextMiddleware(middlewareCfg, authService, broker)
|
||||
err = contextMiddleware.Init()
|
||||
require.NoError(t, err)
|
||||
|
||||
for _, test := range tests {
|
||||
authService.ClearRateLimitsTestingOnly()
|
||||
t.Run(test.description, func(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
do := func(req *http.Request) (*model.UserContext, *httptest.ResponseRecorder) {
|
||||
var captured *model.UserContext
|
||||
router := gin.New()
|
||||
router.Use(contextMiddleware.Middleware())
|
||||
handler := func(c *gin.Context) {
|
||||
if val, exists := c.Get("context"); exists {
|
||||
captured, _ = val.(*model.UserContext)
|
||||
}
|
||||
}
|
||||
router.GET("/api/test", handler)
|
||||
router.GET("/api/healthz", handler)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
router.ServeHTTP(recorder, req)
|
||||
return captured, recorder
|
||||
}
|
||||
|
||||
test.run(t, runArgs{do: do, queries: queries})
|
||||
})
|
||||
}
|
||||
|
||||
t.Cleanup(func() {
|
||||
err = db.Close()
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
Reference in New Issue
Block a user