Files
tinyauth/internal/middleware/context_middleware.go
T
2026-04-29 20:11:43 +03:00

265 lines
6.9 KiB
Go

package middleware
import (
"context"
"fmt"
"net/http"
"strings"
"time"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/gin-gonic/gin"
)
// contextSkipPathsPrefix lists the "METHOD /path" prefixes for which the
// context middleware does no authentication work and simply calls the next
// handler. Entries are compared with strings.HasPrefix, so each one also
// covers every longer path that starts with it (e.g. "GET /resources").
//
// Gin won't let us set a middleware on a specific route (at least it doesn't
// work, see https://github.com/gin-gonic/gin/issues/531), so the middleware
// has to filter routes itself.
var contextSkipPathsPrefix = []string{
	"GET /api/context/app",
	"GET /api/healthz",
	"HEAD /api/healthz",
	"GET /api/oauth/url",
	"GET /api/oauth/callback",
	"GET /api/oidc/clients",
	"POST /api/oidc/token",
	"GET /api/oidc/userinfo",
	"POST /api/oidc/userinfo",
	"GET /resources",
	"POST /api/user/login",
	"GET /.well-known/openid-configuration",
	"GET /.well-known/jwks.json",
}
// ContextMiddlewareConfig holds the static settings the context middleware
// needs at request time.
type ContextMiddlewareConfig struct {
	// CookieDomain is handed to utils.CompileUserEmail when an email address
	// has to be derived from a username.
	CookieDomain string

	// SessionCookieName is the name of the request cookie that carries the
	// session identifier.
	SessionCookieName string
}
// ContextMiddleware authenticates incoming requests (session cookie first,
// then basic auth) and publishes the resulting user context on the gin
// context for later handlers.
type ContextMiddleware struct {
	// config carries the cookie name and cookie domain settings.
	config ContextMiddlewareConfig

	// auth provides session lookup/refresh, user search, password checks and
	// login-attempt bookkeeping.
	auth *service.AuthService

	// broker is consulted to confirm that the OAuth provider recorded in a
	// session is still registered.
	broker *service.OAuthBrokerService
}
// NewContextMiddleware builds a ContextMiddleware from its configuration and
// the two services it depends on. The arguments are stored as given; nothing
// is validated here.
func NewContextMiddleware(config ContextMiddlewareConfig, auth *service.AuthService, broker *service.OAuthBrokerService) *ContextMiddleware {
	mw := new(ContextMiddleware)
	mw.config = config
	mw.auth = auth
	mw.broker = broker
	return mw
}
// Init performs no work and always returns nil.
// NOTE(review): presumably kept so this type matches a shared middleware
// interface with an Init step — confirm against the caller.
func (m *ContextMiddleware) Init() error {
	return nil
}
// Middleware returns the gin handler that tries to authenticate a request
// and, on success, stores the *model.UserContext under the "context" key.
//
// Flow:
//   - Requests whose "METHOD path" matches isIgnorePath are passed through
//     untouched.
//   - If the session cookie is present, only cookie authentication is
//     attempted; a failure is logged and the request continues without a
//     context (basic auth is NOT tried as a fallback).
//   - Otherwise, if basic auth credentials can be read from the request, they
//     are verified. Any headers returned by basicAuth (e.g. account-lock
//     information) are copied onto the response.
//
// The handler never aborts the chain: c.Next() is always called, and
// downstream handlers decide what an absent "context" means.
func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
	return func(c *gin.Context) {
		if m.isIgnorePath(c.Request.Method + " " + c.Request.URL.Path) {
			c.Next()
			return
		}

		uuid, err := c.Cookie(m.config.SessionCookieName)
		if err == nil {
			userContext, cookie, err := m.cookieAuth(c.Request.Context(), uuid)
			if err != nil {
				tlog.App.Error().Msgf("Error authenticating session cookie: %v", err)
				c.Next()
				return
			}

			// cookieAuth returns a nil cookie when the session was not refreshed.
			if cookie != nil {
				http.SetCookie(c.Writer, cookie)
			}

			tlog.App.Trace().Msgf("Authenticated user from session cookie: %s", userContext.GetUsername())
			c.Set("context", userContext)
			c.Next()
			return
		}

		basic, err := m.auth.GetBasicAuth(c.Request)
		if err == nil {
			userContext, headers, err := m.basicAuth(c.Request.Context(), basic)
			if err != nil {
				tlog.App.Error().Msgf("Error authenticating basic auth: %v", err)
				c.Next()
				return
			}

			for k, v := range headers {
				c.Header(k, v)
			}

			// basicAuth reports a locked account as (nil, headers, nil). Storing
			// that nil pointer would make c.Get("context") report the key as
			// present while holding a typed-nil *model.UserContext, so only
			// publish a context that actually exists.
			if userContext != nil {
				c.Set("context", userContext)
			}

			c.Next()
			return
		}

		c.Next()
	}
}
// cookieAuth turns the session identified by uuid into a user context and
// refreshes the session.
//
// The session is loaded and converted with NewFromSession, then enriched
// according to its provider:
//   - local: a session with TOTP still pending is returned immediately with
//     TOTPEnabled set and no refresh cookie. Otherwise the user's attributes
//     are copied in, with Name and Email derived from the username when empty.
//   - LDAP: the user is looked up again, must still be an LDAP user, and the
//     groups, name and email are filled in.
//   - OAuth: the provider must still be registered with the broker and the
//     email must still be whitelisted; a non-whitelisted session is deleted.
//
// On success it returns the context and the value returned by RefreshSession
// (the caller treats a nil cookie as "nothing to set"). On any failure both
// results are nil and the error describes the failing step.
func (m *ContextMiddleware) cookieAuth(ctx context.Context, uuid string) (*model.UserContext, *http.Cookie, error) {
	session, err := m.auth.GetSession(ctx, uuid)
	if err != nil {
		return nil, nil, fmt.Errorf("error retrieving session: %w", err)
	}

	uc, err := new(model.UserContext).NewFromSession(session)
	if err != nil {
		return nil, nil, fmt.Errorf("error creating user context from session: %w", err)
	}

	switch uc.Provider {
	case model.ProviderLocal:
		// A login that still awaits its TOTP code is passed on as-is, without
		// looking up the user or refreshing the session.
		if uc.Local.TOTPPending {
			uc.Local.TOTPEnabled = true
			return uc, nil, nil
		}

		localUser := m.auth.GetLocalUser(uc.Local.Username)
		if localUser == nil {
			return nil, nil, fmt.Errorf("local user not found")
		}

		attrs := localUser.Attributes
		if attrs.Name == "" {
			attrs.Name = utils.Capitalize(localUser.Username)
		}
		if attrs.Email == "" {
			attrs.Email = utils.CompileUserEmail(localUser.Username, m.config.CookieDomain)
		}
		uc.Local.Attributes = attrs

	case model.ProviderLDAP:
		found, err := m.auth.SearchUser(uc.LDAP.Username)
		if err != nil {
			return nil, nil, fmt.Errorf("error searching for ldap user: %w", err)
		}
		if found.Type != model.UserLDAP {
			return nil, nil, fmt.Errorf("user from session cookie is not ldap")
		}

		ldapUser, err := m.auth.GetLDAPUser(found.Username)
		if err != nil {
			return nil, nil, fmt.Errorf("error retrieving ldap user details: %w", err)
		}

		uc.LDAP.Groups = ldapUser.Groups
		uc.LDAP.Name = utils.Capitalize(uc.LDAP.Username)
		uc.LDAP.Email = utils.CompileUserEmail(uc.LDAP.Username, m.config.CookieDomain)

	case model.ProviderOAuth:
		if _, registered := m.broker.GetService(uc.OAuth.ID); !registered {
			return nil, nil, fmt.Errorf("oauth provider from session cookie not found: %s", uc.OAuth.ID)
		}
		if !m.auth.IsEmailWhitelisted(uc.OAuth.Email) {
			// The whitelist no longer covers this email: drop the session.
			m.auth.DeleteSession(ctx, uuid)
			return nil, nil, fmt.Errorf("email from session cookie not whitelisted: %s", uc.OAuth.Email)
		}
	}

	refreshed, err := m.auth.RefreshSession(ctx, uuid)
	if err != nil {
		return nil, nil, fmt.Errorf("error refreshing session: %w", err)
	}

	return uc, refreshed, nil
}
// basicAuth verifies the credentials in basic and builds the matching user
// context.
//
// Results:
//   - Locked account: returns (nil, headers, nil), where headers carries
//     "x-tinyauth-lock-locked" and "x-tinyauth-lock-reset" (RFC 3339). The
//     user context is nil even though the error is nil, so callers must
//     nil-check it.
//   - Success: returns an authenticated context for a local or LDAP user and
//     nil headers.
//   - Failure: returns a nil context, nil headers and an error. This covers
//     unknown users, wrong passwords, local users with TOTP configured (not
//     permitted over basic auth), LDAP lookup errors and unsupported user
//     types.
//
// Every password check is recorded with RecordLoginAttempt. ctx is currently
// unused.
func (m *ContextMiddleware) basicAuth(ctx context.Context, basic *model.LocalUser) (*model.UserContext, map[string]string, error) {
	headers := make(map[string]string)
	userContext := new(model.UserContext)

	locked, remaining := m.auth.IsAccountLocked(basic.Username)
	if locked {
		tlog.App.Debug().Msgf("Account for user %s is locked for %d seconds, denying auth", basic.Username, remaining)
		headers["x-tinyauth-lock-locked"] = "true"
		headers["x-tinyauth-lock-reset"] = time.Now().Add(time.Duration(remaining) * time.Second).Format(time.RFC3339)
		return nil, headers, nil
	}

	search, err := m.auth.SearchUser(basic.Username)
	if err != nil {
		return nil, nil, fmt.Errorf("error searching for user: %w", err)
	}

	err = m.auth.CheckUserPassword(*search, basic.Password)
	if err != nil {
		m.auth.RecordLoginAttempt(basic.Username, false)
		return nil, nil, fmt.Errorf("invalid password for basic auth user: %w", err)
	}

	m.auth.RecordLoginAttempt(basic.Username, true)

	switch search.Type {
	case model.UserLocal:
		user := m.auth.GetLocalUser(basic.Username)
		// GetLocalUser can return nil; guard before touching its fields
		// instead of panicking on a nil dereference.
		if user == nil {
			return nil, nil, fmt.Errorf("local user not found: %s", basic.Username)
		}
		if user.TOTPSecret != "" {
			return nil, nil, fmt.Errorf("user with totp not allowed to login via basic auth: %s", basic.Username)
		}
		userContext.Local = &model.LocalContext{
			BaseContext: model.BaseContext{
				Username: user.Username,
				Name:     utils.Capitalize(user.Username),
				Email:    utils.CompileUserEmail(user.Username, m.config.CookieDomain),
			},
			Attributes: user.Attributes,
		}
		userContext.Provider = model.ProviderLocal
	case model.UserLDAP:
		user, err := m.auth.GetLDAPUser(basic.Username)
		if err != nil {
			return nil, nil, fmt.Errorf("error retrieving ldap user details: %w", err)
		}
		userContext.LDAP = &model.LDAPContext{
			BaseContext: model.BaseContext{
				Username: basic.Username,
				Name:     utils.Capitalize(basic.Username),
				Email:    utils.CompileUserEmail(basic.Username, m.config.CookieDomain),
			},
			Groups: user.Groups,
		}
		userContext.Provider = model.ProviderLDAP
	default:
		// Never hand back an "authenticated" context that carries no
		// provider data.
		return nil, nil, fmt.Errorf("unsupported user type for basic auth: %v", search.Type)
	}

	userContext.Authenticated = true
	return userContext, nil, nil
}
// isIgnorePath reports whether path starts with any entry of
// contextSkipPathsPrefix. The middleware passes path in the form
// "METHOD /url/path", the same shape as the entries it is compared against.
func (m *ContextMiddleware) isIgnorePath(path string) bool {
	for i := range contextSkipPathsPrefix {
		if strings.HasPrefix(path, contextSkipPathsPrefix[i]) {
			return true
		}
	}
	return false
}