wip: add middlewares

This commit is contained in:
Stavros
2025-08-25 13:28:30 +03:00
parent 4979121395
commit ace22acdb2
6 changed files with 276 additions and 225 deletions

View File

@@ -1,144 +0,0 @@
package hooks
import (
"fmt"
"strings"
"tinyauth/internal/auth"
"tinyauth/internal/oauth"
"tinyauth/internal/providers"
"tinyauth/internal/types"
"tinyauth/internal/utils"
"github.com/gin-gonic/gin"
"github.com/rs/zerolog/log"
)
type Hooks struct {
Config types.HooksConfig
Auth *auth.Auth
Providers *providers.Providers
}
func NewHooks(config types.HooksConfig, auth *auth.Auth, providers *providers.Providers) *Hooks {
return &Hooks{
Config: config,
Auth: auth,
Providers: providers,
}
}
func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext {
cookie, err := hooks.Auth.GetSessionCookie(c)
var provider *oauth.OAuth
if err != nil {
log.Error().Err(err).Msg("Failed to get session cookie")
goto basic
}
if cookie.TotpPending {
log.Debug().Msg("Totp pending")
return types.UserContext{
Username: cookie.Username,
Name: cookie.Name,
Email: cookie.Email,
Provider: cookie.Provider,
TotpPending: true,
}
}
if cookie.Provider == "username" {
log.Debug().Msg("Provider is username")
userSearch := hooks.Auth.SearchUser(cookie.Username)
if userSearch.Type == "unknown" {
log.Warn().Str("username", cookie.Username).Msg("User does not exist")
goto basic
}
log.Debug().Str("type", userSearch.Type).Msg("User exists")
return types.UserContext{
Username: cookie.Username,
Name: cookie.Name,
Email: cookie.Email,
IsLoggedIn: true,
Provider: "username",
}
}
log.Debug().Msg("Provider is not username")
provider = hooks.Providers.GetProvider(cookie.Provider)
if provider != nil {
log.Debug().Msg("Provider exists")
if !hooks.Auth.EmailWhitelisted(cookie.Email) {
log.Warn().Str("email", cookie.Email).Msg("Email is not whitelisted")
hooks.Auth.DeleteSessionCookie(c)
goto basic
}
log.Debug().Msg("Email is whitelisted")
return types.UserContext{
Username: cookie.Username,
Name: cookie.Name,
Email: cookie.Email,
IsLoggedIn: true,
OAuth: true,
Provider: cookie.Provider,
OAuthGroups: cookie.OAuthGroups,
}
}
basic:
log.Debug().Msg("Trying basic auth")
basic := hooks.Auth.GetBasicAuth(c)
if basic != nil {
log.Debug().Msg("Got basic auth")
userSearch := hooks.Auth.SearchUser(basic.Username)
if userSearch.Type == "unkown" {
log.Error().Str("username", basic.Username).Msg("Basic auth user does not exist")
return types.UserContext{}
}
if !hooks.Auth.VerifyUser(userSearch, basic.Password) {
log.Error().Str("username", basic.Username).Msg("Basic auth user password incorrect")
return types.UserContext{}
}
if userSearch.Type == "ldap" {
log.Debug().Msg("User is LDAP")
return types.UserContext{
Username: basic.Username,
Name: utils.Capitalize(basic.Username),
Email: fmt.Sprintf("%s@%s", strings.ToLower(basic.Username), hooks.Config.Domain),
IsLoggedIn: true,
Provider: "basic",
TotpEnabled: false,
}
}
user := hooks.Auth.GetLocalUser(basic.Username)
return types.UserContext{
Username: basic.Username,
Name: utils.Capitalize(basic.Username),
Email: fmt.Sprintf("%s@%s", strings.ToLower(basic.Username), hooks.Config.Domain),
IsLoggedIn: true,
Provider: "basic",
TotpEnabled: user.TotpSecret != "",
}
}
return types.UserContext{}
}

View File

@@ -0,0 +1,143 @@
package middlewares
import (
"fmt"
"strings"
"tinyauth/internal/auth"
"tinyauth/internal/providers"
"tinyauth/internal/types"
"tinyauth/internal/utils"
"github.com/gin-gonic/gin"
)
type ContextMiddlewareConfig struct {
Domain string
}
type ContextMiddleware struct {
Config ContextMiddlewareConfig
Auth *auth.Auth
Providers *providers.Providers
}
func NewContextMiddleware(config ContextMiddlewareConfig, auth *auth.Auth, providers *providers.Providers) *ContextMiddleware {
return &ContextMiddleware{
Config: config,
Auth: auth,
Providers: providers,
}
}
func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
return func(c *gin.Context) {
cookie, err := m.Auth.GetSessionCookie(c)
if err != nil {
goto basic
}
if cookie.TotpPending {
c.Set("context", &types.UserContext{
Username: cookie.Username,
Name: cookie.Name,
Email: cookie.Email,
Provider: "username",
TotpPending: true,
TotpEnabled: true,
})
c.Next()
return
}
switch cookie.Provider {
case "username":
userSearch := m.Auth.SearchUser(cookie.Username)
if userSearch.Type == "unknown" {
goto basic
}
c.Set("context", &types.UserContext{
Username: cookie.Username,
Name: cookie.Name,
Email: cookie.Email,
Provider: "username",
IsLoggedIn: true,
})
c.Next()
return
default:
provider := m.Providers.GetProvider(cookie.Provider)
if provider == nil {
goto basic
}
if !m.Auth.EmailWhitelisted(cookie.Email) {
m.Auth.DeleteSessionCookie(c)
goto basic
}
c.Set("context", &types.UserContext{
Username: cookie.Username,
Name: cookie.Name,
Email: cookie.Email,
Provider: cookie.Provider,
OAuthGroups: cookie.OAuthGroups,
IsLoggedIn: true,
OAuth: true,
})
c.Next()
return
}
basic:
basic := m.Auth.GetBasicAuth(c)
if basic == nil {
c.Next()
return
}
userSearch := m.Auth.SearchUser(basic.Username)
if userSearch.Type == "unknown" {
c.Next()
return
}
if !m.Auth.VerifyUser(userSearch, basic.Password) {
c.Next()
return
}
switch userSearch.Type {
case "local":
user := m.Auth.GetLocalUser(basic.Username)
c.Set("context", &types.UserContext{
Username: user.Username,
Name: utils.Capitalize(user.Username),
Email: fmt.Sprintf("%s@%s", strings.ToLower(user.Username), m.Config.Domain),
Provider: "basic",
IsLoggedIn: true,
TotpEnabled: user.TotpSecret != "",
})
c.Next()
return
case "ldap":
c.Set("context", &types.UserContext{
Username: basic.Username,
Name: utils.Capitalize(basic.Username),
Email: fmt.Sprintf("%s@%s", strings.ToLower(basic.Username), m.Config.Domain),
Provider: "basic",
IsLoggedIn: true,
})
c.Next()
return
}
c.Next()
}
}

View File

@@ -0,0 +1,66 @@
package middlewares
import (
"io/fs"
"net/http"
"os"
"strings"
"tinyauth/internal/assets"
"github.com/gin-gonic/gin"
)
type UIMiddleware struct {
UIFS fs.FS
UIFileServer http.Handler
ResourcesFileServer http.Handler
}
func NewUIMiddleware() (*UIMiddleware, error) {
ui, err := fs.Sub(assets.Assets, "dist")
if err != nil {
return nil, err
}
uiFileServer := http.FileServer(http.FS(ui))
resourcesFileServer := http.FileServer(http.Dir("/data/resources"))
return &UIMiddleware{
UIFS: ui,
UIFileServer: uiFileServer,
ResourcesFileServer: resourcesFileServer,
}, nil
}
func (m UIMiddleware) Middlware() gin.HandlerFunc {
return func(c *gin.Context) {
switch strings.Split(c.Request.URL.Path, "/")[1] {
case "api":
c.Next()
return
case "resources":
_, err := os.Stat("/data/resources/" + strings.TrimPrefix(c.Request.URL.Path, "/resources/"))
if os.IsNotExist(err) {
c.Status(404)
c.Abort()
return
}
m.ResourcesFileServer.ServeHTTP(c.Writer, c.Request)
c.Abort()
return
default:
_, err := fs.Stat(m.UIFS, strings.TrimPrefix(c.Request.URL.Path, "/"))
if os.IsNotExist(err) {
c.Request.URL.Path = "/"
}
m.UIFileServer.ServeHTTP(c.Writer, c.Request)
c.Abort()
return
}
}
}

View File

@@ -0,0 +1,62 @@
package middlewares
import (
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/rs/zerolog/log"
)
var (
loggerSkipPathsPrefix = []string{
"GET /api/healthcheck",
"HEAD /api/healthcheck",
"GET /favicon.ico",
}
)
type ZerologMiddleware struct{}
func NewZerologMiddleware() *ZerologMiddleware {
return &ZerologMiddleware{}
}
func (m ZerologMiddleware) logPath(path string) bool {
for _, prefix := range loggerSkipPathsPrefix {
if strings.HasPrefix(path, prefix) {
return false
}
}
return true
}
func (m ZerologMiddleware) Middlware() gin.HandlerFunc {
return func(c *gin.Context) {
tStart := time.Now()
c.Next()
code := c.Writer.Status()
address := c.Request.RemoteAddr
clientIP := c.ClientIP()
method := c.Request.Method
path := c.Request.URL.Path
latency := time.Since(tStart).String()
// logPath check if the path should be logged normally or with debug
if m.logPath(method + " " + path) {
switch {
case code >= 200 && code < 300:
log.Info().Str("method", method).Str("path", path).Str("address", address).Str("clientIp", clientIP).Int("status", code).Str("latency", latency).Msg("Request")
case code >= 300 && code < 400:
log.Warn().Str("method", method).Str("path", path).Str("address", address).Str("clientIp", clientIP).Int("status", code).Str("latency", latency).Msg("Request")
case code >= 400:
log.Error().Str("method", method).Str("path", path).Str("address", address).Str("clientIp", clientIP).Int("status", code).Str("latency", latency).Msg("Request")
}
} else {
log.Debug().Str("method", method).Str("path", path).Str("address", address).Int("status", code).Str("latency", latency).Msg("Request")
}
}
}

View File

@@ -2,12 +2,6 @@ package server
import (
"fmt"
"io/fs"
"net/http"
"os"
"strings"
"time"
"tinyauth/internal/assets"
"tinyauth/internal/handlers"
"tinyauth/internal/types"
@@ -21,52 +15,17 @@ type Server struct {
Router *gin.Engine
}
var (
loggerSkipPathsPrefix = []string{
"GET /api/healthcheck",
"HEAD /api/healthcheck",
"GET /favicon.ico",
}
)
func logPath(path string) bool {
for _, prefix := range loggerSkipPathsPrefix {
if strings.HasPrefix(path, prefix) {
return false
}
}
return true
type Middlware interface {
Middlware() gin.HandlerFunc
}
func NewServer(config types.ServerConfig, handlers *handlers.Handlers) (*Server, error) {
gin.SetMode(gin.ReleaseMode)
log.Debug().Msg("Setting up router")
func NewServer(config types.ServerConfig, handlers *handlers.Handlers, middlewares []Middlware) (*Server, error) {
router := gin.New()
router.Use(zerolog())
log.Debug().Msg("Setting up assets")
dist, err := fs.Sub(assets.Assets, "dist")
if err != nil {
return nil, err
for _, middleware := range middlewares {
router.Use(middleware.Middlware())
}
log.Debug().Msg("Setting up file server")
fileServer := http.FileServer(http.FS(dist))
// UI middleware
router.Use(func(c *gin.Context) {
// If not an API request, serve the UI
if !strings.HasPrefix(c.Request.URL.Path, "/api") {
_, err := fs.Stat(dist, strings.TrimPrefix(c.Request.URL.Path, "/"))
if os.IsNotExist(err) {
c.Request.URL.Path = "/"
}
fileServer.ServeHTTP(c.Writer, c.Request)
c.Abort()
}
})
// Proxy routes
router.GET("/api/auth/:proxy", handlers.ProxyHandler)
@@ -98,33 +57,3 @@ func (s *Server) Start() error {
log.Info().Str("address", s.Config.Address).Int("port", s.Config.Port).Msg("Starting server")
return s.Router.Run(fmt.Sprintf("%s:%d", s.Config.Address, s.Config.Port))
}
// zerolog is a middleware for gin that logs requests using zerolog
func zerolog() gin.HandlerFunc {
return func(c *gin.Context) {
tStart := time.Now()
c.Next()
code := c.Writer.Status()
address := c.Request.RemoteAddr
method := c.Request.Method
path := c.Request.URL.Path
latency := time.Since(tStart).String()
// logPath check if the path should be logged normally or with debug
if logPath(method + " " + path) {
switch {
case code >= 200 && code < 300:
log.Info().Str("method", method).Str("path", path).Str("address", address).Int("status", code).Str("latency", latency).Msg("Request")
case code >= 300 && code < 400:
log.Warn().Str("method", method).Str("path", path).Str("address", address).Int("status", code).Str("latency", latency).Msg("Request")
case code >= 400:
log.Error().Str("method", method).Str("path", path).Str("address", address).Int("status", code).Str("latency", latency).Msg("Request")
}
} else {
log.Debug().Str("method", method).Str("path", path).Str("address", address).Int("status", code).Str("latency", latency).Msg("Request")
}
}
}

View File

@@ -95,11 +95,6 @@ type AuthConfig struct {
EncryptionSecret string
}
// HooksConfig is the configuration for the hooks service
type HooksConfig struct {
Domain string
}
// OAuthLabels is a list of labels that can be used in a tinyauth protected container
type OAuthLabels struct {
Whitelist string