package middleware

import (
	"context"
	"fmt"
	"net/http"
	"strings"
	"time"

	"github.com/tinyauthapp/tinyauth/internal/model"
	"github.com/tinyauthapp/tinyauth/internal/service"
	"github.com/tinyauthapp/tinyauth/internal/utils"
	"github.com/tinyauthapp/tinyauth/internal/utils/tlog"

	"github.com/gin-gonic/gin"
)

// Gin won't let us attach a middleware to a specific route (at least it
// doesn't work, see https://github.com/gin-gonic/gin/issues/531), so the
// middleware itself skips the requests listed here.
//
// Each entry has the form "<METHOD> <path>" and is matched as a prefix of the
// request method and URL path joined by a single space.
var (
	contextSkipPathsPrefix = []string{
		"GET /api/context/app",
		"GET /api/healthz",
		"HEAD /api/healthz",
		"GET /api/oauth/url",
		"GET /api/oauth/callback",
		"GET /api/oidc/clients",
		"POST /api/oidc/token",
		"GET /api/oidc/userinfo",
		"POST /api/oidc/userinfo",
		"GET /resources",
		"POST /api/user/login",
		"GET /.well-known/openid-configuration",
		"GET /.well-known/jwks.json",
	}
)

// ContextMiddlewareConfig holds the settings used by ContextMiddleware.
type ContextMiddlewareConfig struct {
	// CookieDomain is passed to utils.CompileUserEmail when an email address
	// has to be derived from a username.
	CookieDomain string
	// SessionCookieName is the name of the cookie the session ID is read from.
	SessionCookieName string
}

// ContextMiddleware resolves the user behind a request, either from a session
// cookie or from basic auth credentials, and stores the resulting
// *model.UserContext in the gin context under the "context" key.
type ContextMiddleware struct {
	config ContextMiddlewareConfig
	auth   *service.AuthService
	broker *service.OAuthBrokerService
}

// NewContextMiddleware returns a ContextMiddleware that uses the given
// configuration, auth service and OAuth broker.
func NewContextMiddleware(config ContextMiddlewareConfig, auth *service.AuthService, broker *service.OAuthBrokerService) *ContextMiddleware {
	return &ContextMiddleware{
		config: config,
		auth:   auth,
		broker: broker,
	}
}

// Init performs no work and always returns nil.
func (m *ContextMiddleware) Init() error {
	return nil
}

// Middleware returns the gin handler that populates the user context.
//
// Requests matching contextSkipPathsPrefix are passed through untouched. For
// all other requests the session cookie is tried first; basic auth is only
// attempted when the session cookie cannot be read. Authentication failures
// are logged and the request continues without a user context, so the
// handler never aborts the chain itself.
func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
	return func(c *gin.Context) {
		if m.isIgnorePath(c.Request.Method + " " + c.Request.URL.Path) {
			c.Next()
			return
		}

		uuid, err := c.Cookie(m.config.SessionCookieName)
		if err == nil {
			userContext, cookie, err := m.cookieAuth(c.Request.Context(), uuid)
			if err != nil {
				tlog.App.Error().Msgf("Error authenticating session cookie: %v", err)
				c.Next()
				return
			}

			// cookieAuth returns a nil cookie when the session was not refreshed.
			if cookie != nil {
				http.SetCookie(c.Writer, cookie)
			}

			c.Set("context", userContext)
			c.Next()
			return
		}

		basic, err := m.auth.GetBasicAuth(c.Request)
		if err == nil {
			userContext, headers, err := m.basicAuth(c.Request.Context(), basic)
			if err != nil {
				tlog.App.Error().Msgf("Error authenticating basic auth: %v", err)
				c.Next()
				return
			}

			for k, v := range headers {
				c.Header(k, v)
			}

			// basicAuth returns a nil context (with lock headers) for a locked
			// account. Storing that typed nil pointer would make the "context"
			// key exist while pointing at nothing, so leave it unset instead.
			if userContext != nil {
				c.Set("context", userContext)
			}

			c.Next()
			return
		}

		c.Next()
	}
}

// cookieAuth builds the user context for the session identified by uuid.
//
// It loads the session, validates it against the provider recorded in it and
// refreshes it. The returned cookie is the result of RefreshSession and is nil
// when the session is returned without being refreshed (local session with a
// pending TOTP step). On error both the context and the cookie are nil.
func (m *ContextMiddleware) cookieAuth(ctx context.Context, uuid string) (*model.UserContext, *http.Cookie, error) {
	session, err := m.auth.GetSession(ctx, uuid)
	if err != nil {
		return nil, nil, fmt.Errorf("error retrieving session: %w", err)
	}

	userContext, err := new(model.UserContext).NewFromSession(session)
	if err != nil {
		return nil, nil, fmt.Errorf("error creating user context from session: %w", err)
	}

	// A local session still waiting for its TOTP step is returned as-is:
	// TOTPEnabled is forced on and the session is neither validated further
	// nor refreshed.
	if userContext.Provider == model.ProviderLocal && userContext.Local.TOTPPending {
		userContext.Local.TOTPEnabled = true
		return userContext, nil, nil
	}

	switch userContext.Provider {
	case model.ProviderLocal:
		user := m.auth.GetLocalUser(userContext.Local.Username)
		if user == nil {
			return nil, nil, fmt.Errorf("local user not found")
		}

		// Use the attributes of the looked-up user, falling back to values
		// derived from the username when name or email are empty.
		userContext.Local.Attributes = user.Attributes
		if userContext.Local.Attributes.Name == "" {
			userContext.Local.Attributes.Name = utils.Capitalize(user.Username)
		}
		if userContext.Local.Attributes.Email == "" {
			userContext.Local.Attributes.Email = utils.CompileUserEmail(user.Username, m.config.CookieDomain)
		}

	case model.ProviderLDAP:
		search, err := m.auth.SearchUser(userContext.LDAP.Username)
		if err != nil {
			return nil, nil, fmt.Errorf("error searching for ldap user: %w", err)
		}
		if search.Type != model.UserLDAP {
			return nil, nil, fmt.Errorf("user from session cookie is not ldap")
		}

		user, err := m.auth.GetLDAPUser(search.Username)
		if err != nil {
			return nil, nil, fmt.Errorf("error retrieving ldap user details: %w", err)
		}

		// Groups come from the LDAP lookup; name and email are derived from
		// the username stored in the session.
		userContext.LDAP.Groups = user.Groups
		userContext.LDAP.Name = utils.Capitalize(userContext.LDAP.Username)
		userContext.LDAP.Email = utils.CompileUserEmail(userContext.LDAP.Username, m.config.CookieDomain)

	case model.ProviderOAuth:
		_, exists := m.broker.GetService(userContext.OAuth.ID)
		if !exists {
			return nil, nil, fmt.Errorf("oauth provider from session cookie not found: %s", userContext.OAuth.ID)
		}

		// A session whose email is not whitelisted is deleted before the
		// error is reported.
		if !m.auth.IsEmailWhitelisted(userContext.OAuth.Email) {
			m.auth.DeleteSession(ctx, uuid)
			return nil, nil, fmt.Errorf("email from session cookie not whitelisted: %s", userContext.OAuth.Email)
		}
	}

	cookie, err := m.auth.RefreshSession(ctx, uuid)
	if err != nil {
		return nil, nil, fmt.Errorf("error refreshing session: %w", err)
	}

	return userContext, cookie, nil
}

// basicAuth authenticates the given basic auth credentials and builds the
// matching user context.
//
// Return values:
//   - locked account: nil context, the "x-tinyauth-lock-*" response headers
//     and a nil error; callers must handle the nil context.
//   - success: an authenticated context, nil headers and a nil error.
//   - failure: nil context, nil headers and a non-nil error.
//
// Local users with a TOTP secret are rejected. Every password check is
// recorded through RecordLoginAttempt.
func (m *ContextMiddleware) basicAuth(ctx context.Context, basic *model.LocalUser) (*model.UserContext, map[string]string, error) {
	locked, remaining := m.auth.IsAccountLocked(basic.Username)
	if locked {
		tlog.App.Debug().Msgf("Account for user %s is locked for %d seconds, denying auth", basic.Username, remaining)
		headers := map[string]string{
			"x-tinyauth-lock-locked": "true",
			// Point in time at which the lock ends, formatted as RFC 3339.
			"x-tinyauth-lock-reset": time.Now().Add(time.Duration(remaining) * time.Second).Format(time.RFC3339),
		}
		return nil, headers, nil
	}

	search, err := m.auth.SearchUser(basic.Username)
	if err != nil {
		return nil, nil, fmt.Errorf("error searching for user: %w", err)
	}

	err = m.auth.CheckUserPassword(*search, basic.Password)
	if err != nil {
		m.auth.RecordLoginAttempt(basic.Username, false)
		return nil, nil, fmt.Errorf("invalid password for basic auth user: %w", err)
	}

	m.auth.RecordLoginAttempt(basic.Username, true)

	userContext := new(model.UserContext)

	switch search.Type {
	case model.UserLocal:
		user := m.auth.GetLocalUser(basic.Username)
		// Same guard as in cookieAuth: the lookup may return nil.
		if user == nil {
			return nil, nil, fmt.Errorf("local user not found")
		}

		if user.TOTPSecret != "" {
			return nil, nil, fmt.Errorf("user with totp not allowed to login via basic auth: %s", basic.Username)
		}

		userContext.Local = &model.LocalContext{
			BaseContext: model.BaseContext{
				Username: user.Username,
				Name:     utils.Capitalize(user.Username),
				Email:    utils.CompileUserEmail(user.Username, m.config.CookieDomain),
			},
			Attributes: user.Attributes,
		}
		userContext.Provider = model.ProviderLocal

	case model.UserLDAP:
		user, err := m.auth.GetLDAPUser(basic.Username)
		if err != nil {
			return nil, nil, fmt.Errorf("error retrieving ldap user details: %w", err)
		}

		userContext.LDAP = &model.LDAPContext{
			BaseContext: model.BaseContext{
				Username: basic.Username,
				Name:     utils.Capitalize(basic.Username),
				Email:    utils.CompileUserEmail(basic.Username, m.config.CookieDomain),
			},
			Groups: user.Groups,
		}
		userContext.Provider = model.ProviderLDAP

	default:
		// Fail closed: never mark a context as authenticated when no
		// provider-specific data could be attached to it.
		return nil, nil, fmt.Errorf("unsupported user type for basic auth: %v", search.Type)
	}

	userContext.Authenticated = true

	return userContext, nil, nil
}

// isIgnorePath reports whether path, given as "<METHOD> <url path>", starts
// with one of the entries in contextSkipPathsPrefix.
func (m *ContextMiddleware) isIgnorePath(path string) bool {
	for _, prefix := range contextSkipPathsPrefix {
		if strings.HasPrefix(path, prefix) {
			return true
		}
	}
	return false
}