diff --git a/cmd/root.go b/cmd/root.go index db35e82..6b52983 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -105,19 +105,18 @@ var rootCmd = &cobra.Command{ // Create api config apiConfig := types.APIConfig{ - Port: config.Port, - Address: config.Address, - Secret: config.Secret, - CookieSecure: config.CookieSecure, - SessionExpiry: config.SessionExpiry, - Domain: domain, + Port: config.Port, + Address: config.Address, } // Create auth config authConfig := types.AuthConfig{ Users: users, OauthWhitelist: oauthWhitelist, + Secret: config.Secret, + CookieSecure: config.CookieSecure, SessionExpiry: config.SessionExpiry, + Domain: domain, LoginTimeout: config.LoginTimeout, LoginMaxRetries: config.LoginMaxRetries, } diff --git a/go.mod b/go.mod index 0b6e313..6fa1cd2 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,6 @@ module tinyauth go 1.23.2 require ( - github.com/gin-contrib/sessions v1.0.2 github.com/gin-gonic/gin v1.10.0 github.com/go-playground/validator/v10 v10.24.0 github.com/google/go-querystring v1.1.0 @@ -58,9 +57,8 @@ require ( github.com/go-playground/universal-translator v0.18.1 // indirect github.com/goccy/go-json v0.10.4 // indirect github.com/gogo/protobuf v1.3.2 // indirect - github.com/gorilla/context v1.1.2 // indirect github.com/gorilla/securecookie v1.1.2 // indirect - github.com/gorilla/sessions v1.2.2 // indirect + github.com/gorilla/sessions v1.2.2 github.com/hashicorp/hcl v1.0.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/json-iterator/go v1.1.12 // indirect diff --git a/go.sum b/go.sum index 7887980..bd7304d 100644 --- a/go.sum +++ b/go.sum @@ -65,8 +65,6 @@ github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nos github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= github.com/gabriel-vasile/mimetype v1.4.8 h1:FfZ3gj38NjllZIeJAmMhr+qKL8Wu+nOoI3GqacKw1NM= github.com/gabriel-vasile/mimetype v1.4.8/go.mod h1:ByKUIKGjh1ODkGM1asKUbQZOLGrPjydw3hYPU2YU9t8= -github.com/gin-contrib/sessions v1.0.2 h1:UaIjUvTH1cMeOdj3in6dl+Xb6It8RiKRF9Z1anbUyCA= -github.com/gin-contrib/sessions v1.0.2/go.mod h1:KxKxWqWP5LJVDCInulOl4WbLzK2KSPlLesfZ66wRvMs= github.com/gin-contrib/sse v1.0.0 h1:y3bT1mUWUxDpW4JLQg/HnTqV4rozuW4tC9eFKTxYI9E= github.com/gin-contrib/sse v1.0.0/go.mod h1:zNuFdwarAygJBht0NTKiSi3jRf6RbqeILZ9Sp6Slhe0= github.com/gin-gonic/gin v1.10.0 h1:nTuyha1TYqgedzytsKYqna+DfLos46nTv2ygFy86HFU= @@ -99,8 +97,6 @@ github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0= github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/gorilla/context v1.1.2 h1:WRkNAv2uoa03QNIc1A6u4O7DAGMUVoopZhkiXWA2V1o= -github.com/gorilla/context v1.1.2/go.mod h1:KDPwT9i/MeWHiLl90fuTgrt4/wPcv75vFAZLaOOcbxM= github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kXD8ePA= github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo= github.com/gorilla/sessions v1.2.2 h1:lqzMYz6bOfvn2WriPUjNByzeXIlVzURcPmgMczkmTjY= diff --git a/internal/api/api.go b/internal/api/api.go index ba3e4c3..1b6aacf 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -11,8 +11,6 @@ import ( "tinyauth/internal/handlers" "tinyauth/internal/types" - "github.com/gin-contrib/sessions" - "github.com/gin-contrib/sessions/cookie" "github.com/gin-gonic/gin" "github.com/rs/zerolog/log" ) @@ -51,21 +49,6 @@ func (api *API) Init() { log.Debug().Msg("Setting up file server") fileServer := http.FileServer(http.FS(dist)) - // Setup cookie store - log.Debug().Msg("Setting up cookie store") - store := cookie.NewStore([]byte(api.Config.Secret)) - - // Use session middleware - store.Options(sessions.Options{ - Domain: api.Config.Domain, - Path: "/", - HttpOnly: true, - Secure: api.Config.CookieSecure, - MaxAge: api.Config.SessionExpiry, - }) - - router.Use(sessions.Sessions("tinyauth", store)) - // UI middleware router.Use(func(c *gin.Context) { // If not an API request, serve the UI diff --git a/internal/api/api_test.go b/internal/api/api_test.go index 4358b4a..c4477c2 100644 --- a/internal/api/api_test.go +++ b/internal/api/api_test.go @@ -21,11 +21,8 @@ import ( // Simple API config for tests var apiConfig = types.APIConfig{ - Port: 8080, - Address: "0.0.0.0", - Secret: "super-secret-api-thing-for-tests", // It is 32 chars long - CookieSecure: false, - SessionExpiry: 3600, + Port: 8080, + Address: "0.0.0.0", } // Simple handlers config for tests @@ -42,6 +39,8 @@ var handlersConfig = types.HandlersConfig{ var authConfig = types.AuthConfig{ Users: types.Users{}, OauthWhitelist: []string{}, + Secret: "super-secret-api-thing-for-tests", // It is 32 chars long + CookieSecure: false, SessionExpiry: 3600, LoginTimeout: 0, LoginMaxRetries: 0, diff --git a/internal/auth/auth.go b/internal/auth/auth.go index 33a3a39..100a248 100644 --- a/internal/auth/auth.go +++ b/internal/auth/auth.go @@ -1,6 +1,8 @@ package auth import ( + "fmt" + "net/http" "regexp" "slices" "strings" @@ -9,8 +11,8 @@ import ( "tinyauth/internal/docker" "tinyauth/internal/types" - "github.com/gin-contrib/sessions" "github.com/gin-gonic/gin" + "github.com/gorilla/sessions" "github.com/rs/zerolog/log" "golang.org/x/crypto/bcrypt" ) @@ -30,6 +32,30 @@ type Auth struct { LoginMutex sync.RWMutex } +func (auth *Auth) GetSession(c *gin.Context) (*sessions.Session, error) { + // Create cookie store + store := sessions.NewCookieStore([]byte(auth.Config.Secret)) + + // Configure cookie store + store.Options = &sessions.Options{ + Path: "/", + MaxAge: auth.Config.SessionExpiry, + Secure: auth.Config.CookieSecure, + HttpOnly: true, + SameSite: http.SameSiteDefaultMode, + Domain: fmt.Sprintf(".%s", auth.Config.Domain), + } + + // Get session + session, err := store.Get(c.Request, "tinyauth") + if err != nil { + log.Error().Err(err).Msg("Failed to get session") + return nil, err + } + + return session, nil +} + func (auth *Auth) GetUser(username string) *types.User { // Loop through users and return the user if the username matches for _, user := range auth.Config.Users { @@ -126,11 +152,15 @@ func (auth *Auth) EmailWhitelisted(emailSrc string) bool { return false } -func (auth *Auth) CreateSessionCookie(c *gin.Context, data *types.SessionCookie) { +func (auth *Auth) CreateSessionCookie(c *gin.Context, data *types.SessionCookie) error { log.Debug().Msg("Creating session cookie") // Get session - sessions := sessions.Default(c) + session, err := auth.GetSession(c) + if err != nil { + log.Error().Err(err).Msg("Failed to get session") + return err + } log.Debug().Msg("Setting session cookie") @@ -144,39 +174,63 @@ func (auth *Auth) CreateSessionCookie(c *gin.Context, data *types.SessionCookie) } // Set data - sessions.Set("username", data.Username) - sessions.Set("provider", data.Provider) - sessions.Set("expiry", time.Now().Add(time.Duration(sessionExpiry)*time.Second).Unix()) - sessions.Set("totpPending", data.TotpPending) + session.Values["username"] = data.Username + session.Values["provider"] = data.Provider + session.Values["expiry"] = time.Now().Add(time.Duration(sessionExpiry) * time.Second).Unix() + session.Values["totpPending"] = data.TotpPending // Save session - sessions.Save() + err = session.Save(c.Request, c.Writer) + if err != nil { + log.Error().Err(err).Msg("Failed to save session") + return err + } + + // Return nil + return nil } -func (auth *Auth) DeleteSessionCookie(c *gin.Context) { +func (auth *Auth) DeleteSessionCookie(c *gin.Context) error { log.Debug().Msg("Deleting session cookie") // Get session - sessions := sessions.Default(c) + session, err := auth.GetSession(c) + if err != nil { + log.Error().Err(err).Msg("Failed to get session") + return err + } - // Clear session - sessions.Clear() + // Delete all values in the session + for key := range session.Values { + delete(session.Values, key) + } // Save session - sessions.Save() + err = session.Save(c.Request, c.Writer) + if err != nil { + log.Error().Err(err).Msg("Failed to save session") + return err + } + + // Return nil + return nil } -func (auth *Auth) GetSessionCookie(c *gin.Context) types.SessionCookie { +func (auth *Auth) GetSessionCookie(c *gin.Context) (types.SessionCookie, error) { log.Debug().Msg("Getting session cookie") // Get session - sessions := sessions.Default(c) + session, err := auth.GetSession(c) + if err != nil { + log.Error().Err(err).Msg("Failed to get session") + return types.SessionCookie{}, err + } // Get data - cookieUsername := sessions.Get("username") - cookieProvider := sessions.Get("provider") - cookieExpiry := sessions.Get("expiry") - cookieTotpPending := sessions.Get("totpPending") + cookieUsername := session.Values["username"] + cookieProvider := session.Values["provider"] + cookieExpiry := session.Values["expiry"] + cookieTotpPending := session.Values["totpPending"] // Convert interfaces to correct types username, usernameOk := cookieUsername.(string) @@ -187,7 +241,7 @@ func (auth *Auth) GetSessionCookie(c *gin.Context) types.SessionCookie { // Check if the cookie is invalid if !usernameOk || !providerOk || !expiryOk || !totpPendingOk { log.Warn().Msg("Session cookie invalid") - return types.SessionCookie{} + return types.SessionCookie{}, nil } // Check if the cookie has expired @@ -198,7 +252,7 @@ func (auth *Auth) GetSessionCookie(c *gin.Context) types.SessionCookie { auth.DeleteSessionCookie(c) // Return empty cookie - return types.SessionCookie{} + return types.SessionCookie{}, nil } log.Debug().Str("username", username).Str("provider", provider).Int64("expiry", expiry).Bool("totpPending", totpPending).Msg("Parsed cookie") @@ -208,7 +262,7 @@ func (auth *Auth) GetSessionCookie(c *gin.Context) types.SessionCookie { Username: username, Provider: provider, TotpPending: totpPending, - } + }, nil } func (auth *Auth) UserAuthConfigured() bool { diff --git a/internal/hooks/hooks.go b/internal/hooks/hooks.go index 6921372..5e9a689 100644 --- a/internal/hooks/hooks.go +++ b/internal/hooks/hooks.go @@ -23,7 +23,7 @@ type Hooks struct { func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext { // Get session cookie and basic auth - cookie := hooks.Auth.GetSessionCookie(c) + cookie, err := hooks.Auth.GetSessionCookie(c) basic := hooks.Auth.GetBasicAuth(c) // Check if basic auth is set @@ -46,6 +46,19 @@ func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext { } + // Check cookie error after basic auth + if err != nil { + log.Error().Err(err).Msg("Failed to get session cookie") + // Return empty context + return types.UserContext{ + Username: "", + IsLoggedIn: false, + OAuth: false, + Provider: "", + TotpPending: false, + } + } + // Check if session cookie has totp pending if cookie.TotpPending { log.Debug().Msg("Totp pending") diff --git a/internal/types/config.go b/internal/types/config.go index dce9657..ddb1a5d 100644 --- a/internal/types/config.go +++ b/internal/types/config.go @@ -66,12 +66,8 @@ type OAuthConfig struct { // APIConfig is the configuration for the API type APIConfig struct { - Port int - Address string - Secret string - CookieSecure bool - SessionExpiry int - Domain string + Port int + Address string } // AuthConfig is the configuration for the auth service @@ -79,6 +75,9 @@ type AuthConfig struct { Users Users OauthWhitelist []string SessionExpiry int + Secret string + CookieSecure bool + Domain string LoginTimeout int LoginMaxRetries int }