diff --git a/cmd/root.go b/cmd/root.go index 2b0c172..7ed86e4 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -82,7 +82,7 @@ func init() { {"app-url", "", "The Tinyauth URL."}, {"users", "", "Comma separated list of users in the format username:hash."}, {"users-file", "", "Path to a file containing users in the format username:hash."}, - {"cookie-secure", false, "Send cookie over secure connection only."}, + {"secure-cookie", false, "Send cookie over secure connection only."}, {"github-client-id", "", "Github OAuth client ID."}, {"github-client-secret", "", "Github OAuth client secret."}, {"github-client-secret-file", "", "Github OAuth client secret file."}, diff --git a/internal/assets/assets.go b/internal/assets/assets.go index 0b572e0..df6e61f 100644 --- a/internal/assets/assets.go +++ b/internal/assets/assets.go @@ -7,4 +7,4 @@ import ( // Frontend assets // //go:embed dist -var FontendAssets embed.FS +var FrontendAssets embed.FS diff --git a/internal/bootstrap/app_bootstrap.go b/internal/bootstrap/app_bootstrap.go index 4401172..594c575 100644 --- a/internal/bootstrap/app_bootstrap.go +++ b/internal/bootstrap/app_bootstrap.go @@ -164,13 +164,13 @@ func (app *BootstrapApp) Setup() error { log.Debug().Str("middleware", fmt.Sprintf("%T", middleware)).Msg("Initializing middleware") err := middleware.Init() if err != nil { - return fmt.Errorf("failed to initialize %s middleware: %T", middleware, err) + return fmt.Errorf("failed to initialize middleware %T: %w", middleware, err) } engine.Use(middleware.Middleware()) } // Create routers - mainRouter := engine.Group("/") + mainRouter := engine.Group("") apiRouter := engine.Group("/api") // Create controllers @@ -190,6 +190,7 @@ func (app *BootstrapApp) Setup() error { SecureCookie: app.Config.SecureCookie, CSRFCookieName: csrfCookieName, RedirectCookieName: redirectCookieName, + Domain: domain, }, apiRouter, authService, oauthBrokerService) proxyController := controller.NewProxyController(controller.ProxyControllerConfig{ diff --git a/internal/config/config.go b/internal/config/config.go index 48961d6..5d4dba8 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -65,10 +65,10 @@ type OAuthLabels struct { type BasicLabels struct { Username string - Password PassowrdLabels + Password PasswordLabels } -type PassowrdLabels struct { +type PasswordLabels struct { Plain string File string } diff --git a/internal/controller/oauth_controller.go b/internal/controller/oauth_controller.go index 025db1b..9802ea1 100644 --- a/internal/controller/oauth_controller.go +++ b/internal/controller/oauth_controller.go @@ -23,6 +23,7 @@ type OAuthControllerConfig struct { RedirectCookieName string SecureCookie bool AppURL string + Domain string } type OAuthController struct { @@ -77,7 +78,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) { redirectURI := c.Query("redirect_uri") - if redirectURI != "" { + if redirectURI != "" && utils.IsRedirectSafe(redirectURI, controller.Config.Domain) { log.Debug().Msg("Setting redirect URI cookie") c.SetCookie(controller.Config.RedirectCookieName, redirectURI, int(time.Hour.Seconds()), "/", "", controller.Config.SecureCookie, true) } @@ -178,7 +179,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) { redirectURI, err := c.Cookie(controller.Config.RedirectCookieName) - if err != nil { + if err != nil || !utils.IsRedirectSafe(redirectURI, controller.Config.Domain) { log.Debug().Msg("No redirect URI cookie found, redirecting to app root") c.Redirect(http.StatusTemporaryRedirect, controller.Config.AppURL) return @@ -195,5 +196,5 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) { } c.SetCookie(controller.Config.RedirectCookieName, "", -1, "/", "", controller.Config.SecureCookie, true) - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/login?%s", controller.Config.AppURL, queries.Encode())) + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/continue?%s", controller.Config.AppURL, queries.Encode())) } diff --git a/internal/controller/proxy_controller.go b/internal/controller/proxy_controller.go index ae7a101..348be65 100644 --- a/internal/controller/proxy_controller.go +++ b/internal/controller/proxy_controller.go @@ -128,6 +128,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { if err != nil { log.Error().Err(err).Msg("Failed to encode unauthorized query") c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.Config.AppURL)) + return } c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", controller.Config.AppURL, queries.Encode())) @@ -212,9 +213,9 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { }) if userContext.OAuth { - queries.Set("username", userContext.Username) - } else { queries.Set("username", userContext.Email) + } else { + queries.Set("username", userContext.Username) } if err != nil { @@ -247,9 +248,9 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { }) if userContext.OAuth { - queries.Set("username", userContext.Username) - } else { queries.Set("username", userContext.Email) + } else { + queries.Set("username", userContext.Username) } if err != nil { diff --git a/internal/controller/resources_controller.go b/internal/controller/resources_controller.go index f0c2009..56bae87 100644 --- a/internal/controller/resources_controller.go +++ b/internal/controller/resources_controller.go @@ -11,14 +11,18 @@ type ResourcesControllerConfig struct { } type ResourcesController struct { - Config ResourcesControllerConfig - Router *gin.RouterGroup + Config ResourcesControllerConfig + Router *gin.RouterGroup + FileServer http.Handler } func NewResourcesController(config ResourcesControllerConfig, router *gin.RouterGroup) *ResourcesController { + fileServer := http.StripPrefix("/resources", http.FileServer(http.Dir(config.ResourcesDir))) + return &ResourcesController{ - Config: config, - Router: router, + Config: config, + Router: router, + FileServer: fileServer, } } @@ -27,6 +31,12 @@ func (controller *ResourcesController) SetupRoutes() { } func (controller *ResourcesController) resourcesHandler(c *gin.Context) { - fileServer := http.StripPrefix("/resources", http.FileServer(http.Dir(controller.Config.ResourcesDir))) - fileServer.ServeHTTP(c.Writer, c.Request) + if controller.Config.ResourcesDir == "" { + c.JSON(404, gin.H{ + "status": 404, + "message": "Resources not found", + }) + return + } + controller.FileServer.ServeHTTP(c.Writer, c.Request) } diff --git a/internal/controller/user_controller.go b/internal/controller/user_controller.go index 7f307e3..72e22d8 100644 --- a/internal/controller/user_controller.go +++ b/internal/controller/user_controller.go @@ -112,7 +112,7 @@ func (controller *UserController) loginHandler(c *gin.Context) { if user.TotpSecret != "" { log.Debug().Str("username", req.Username).Msg("User has TOTP enabled, requiring TOTP verification") - controller.Auth.CreateSessionCookie(c, &config.SessionCookie{ + err := controller.Auth.CreateSessionCookie(c, &config.SessionCookie{ Username: user.Username, Name: utils.Capitalize(req.Username), Email: fmt.Sprintf("%s@%s", strings.ToLower(req.Username), controller.Config.Domain), @@ -120,6 +120,15 @@ func (controller *UserController) loginHandler(c *gin.Context) { TotpPending: true, }) + if err != nil { + log.Error().Err(err).Msg("Failed to create session cookie") + c.JSON(500, gin.H{ + "status": 500, + "message": "Internal Server Error", + }) + return + } + c.JSON(200, gin.H{ "status": 200, "message": "TOTP required", @@ -129,13 +138,22 @@ func (controller *UserController) loginHandler(c *gin.Context) { } } - controller.Auth.CreateSessionCookie(c, &config.SessionCookie{ + err = controller.Auth.CreateSessionCookie(c, &config.SessionCookie{ Username: req.Username, Name: utils.Capitalize(req.Username), Email: fmt.Sprintf("%s@%s", strings.ToLower(req.Username), controller.Config.Domain), Provider: "username", }) + if err != nil { + log.Error().Err(err).Msg("Failed to create session cookie") + c.JSON(500, gin.H{ + "status": 500, + "message": "Internal Server Error", + }) + return + } + c.JSON(200, gin.H{ "status": 200, "message": "Login successful", @@ -144,7 +162,9 @@ func (controller *UserController) loginHandler(c *gin.Context) { func (controller *UserController) logoutHandler(c *gin.Context) { log.Debug().Msg("Logout request received") + controller.Auth.DeleteSessionCookie(c) + c.JSON(200, gin.H{ "status": 200, "message": "Logout successful", @@ -175,8 +195,8 @@ func (controller *UserController) totpHandler(c *gin.Context) { return } - if !context.IsLoggedIn { - log.Warn().Msg("TOTP attempt without being logged in") + if !context.TotpPending { + log.Warn().Msg("TOTP attempt without a pending TOTP session") c.JSON(401, gin.H{ "status": 401, "message": "Unauthorized", @@ -223,13 +243,22 @@ func (controller *UserController) totpHandler(c *gin.Context) { controller.Auth.RecordLoginAttempt(rateIdentifier, true) - controller.Auth.CreateSessionCookie(c, &config.SessionCookie{ + err = controller.Auth.CreateSessionCookie(c, &config.SessionCookie{ Username: user.Username, Name: utils.Capitalize(user.Username), Email: fmt.Sprintf("%s@%s", strings.ToLower(user.Username), controller.Config.Domain), Provider: "username", }) + if err != nil { + log.Error().Err(err).Msg("Failed to create session cookie") + c.JSON(500, gin.H{ + "status": 500, + "message": "Internal Server Error", + }) + return + } + c.JSON(200, gin.H{ "status": 200, "message": "Login successful", diff --git a/internal/middleware/context_middleware.go b/internal/middleware/context_middleware.go index e11f80c..58e53e1 100644 --- a/internal/middleware/context_middleware.go +++ b/internal/middleware/context_middleware.go @@ -79,6 +79,7 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc { if !exists { log.Debug().Msg("OAuth provider from session cookie not found") + m.Auth.DeleteSessionCookie(c) goto basic } diff --git a/internal/middleware/ui_middleware.go b/internal/middleware/ui_middleware.go index 6c03e4f..dcfaa35 100644 --- a/internal/middleware/ui_middleware.go +++ b/internal/middleware/ui_middleware.go @@ -20,10 +20,10 @@ func NewUIMiddleware() *UIMiddleware { } func (m *UIMiddleware) Init() error { - ui, err := fs.Sub(assets.FontendAssets, "dist") + ui, err := fs.Sub(assets.FrontendAssets, "dist") if err != nil { - return nil + return err } m.UIFS = ui diff --git a/internal/middleware/zerolog_middleware.go b/internal/middleware/zerolog_middleware.go index 95f5821..877ad4c 100644 --- a/internal/middleware/zerolog_middleware.go +++ b/internal/middleware/zerolog_middleware.go @@ -10,8 +10,8 @@ import ( var ( loggerSkipPathsPrefix = []string{ - "GET /api/healthcheck", - "HEAD /api/healthcheck", + "GET /api/health", + "HEAD /api/health", "GET /favicon.ico", } ) diff --git a/internal/service/auth_service.go b/internal/service/auth_service.go index 8c91e79..29f2dd1 100644 --- a/internal/service/auth_service.go +++ b/internal/service/auth_service.go @@ -71,9 +71,9 @@ func (auth *AuthService) GetSession(c *gin.Context) (*sessions.Session, error) { // If there was an error getting the session, it might be invalid so let's clear it and retry if err != nil { - log.Debug().Err(err).Msg("Error getting session, clearing cookie and retrying") + log.Debug().Err(err).Msg("Error getting session, creating a new one") c.SetCookie(auth.Config.SessionCookieName, "", -1, "/", fmt.Sprintf(".%s", auth.Config.Domain), auth.Config.SecureCookie, true) - session, err = auth.Store.Get(c.Request, auth.Config.SessionCookieName) + session, err = auth.Store.New(c.Request, auth.Config.SessionCookieName) if err != nil { return nil, err } diff --git a/internal/service/generic_oauth_service.go b/internal/service/generic_oauth_service.go index c68d150..a09fd93 100644 --- a/internal/service/generic_oauth_service.go +++ b/internal/service/generic_oauth_service.go @@ -6,6 +6,7 @@ import ( "crypto/tls" "encoding/base64" "encoding/json" + "fmt" "io" "net/http" "tinyauth/internal/config" @@ -76,7 +77,7 @@ func (generic *GenericOAuthService) VerifyCode(code string) error { token, err := generic.Config.Exchange(generic.Context, code, oauth2.VerifierOption(generic.Verifier)) if err != nil { - return nil + return err } generic.Token = token @@ -94,6 +95,10 @@ func (generic *GenericOAuthService) Userinfo() (config.Claims, error) { } defer res.Body.Close() + if res.StatusCode < 200 || res.StatusCode >= 300 { + return user, fmt.Errorf("request failed with status: %s", res.Status) + } + body, err := io.ReadAll(res.Body) if err != nil { return user, err diff --git a/internal/service/github_oauth_service.go b/internal/service/github_oauth_service.go index 2f9e27f..4df4444 100644 --- a/internal/service/github_oauth_service.go +++ b/internal/service/github_oauth_service.go @@ -6,6 +6,7 @@ import ( "encoding/base64" "encoding/json" "errors" + "fmt" "io" "net/http" "tinyauth/internal/config" @@ -71,7 +72,7 @@ func (github *GithubOAuthService) VerifyCode(code string) error { token, err := github.Config.Exchange(github.Context, code, oauth2.VerifierOption(github.Verifier)) if err != nil { - return nil + return err } github.Token = token @@ -83,12 +84,23 @@ func (github *GithubOAuthService) Userinfo() (config.Claims, error) { client := github.Config.Client(github.Context, github.Token) - res, err := client.Get("https://api.github.com/user") + req, err := http.NewRequest("GET", "https://api.github.com/user", nil) + if err != nil { + return user, err + } + + req.Header.Set("Accept", "application/vnd.github+json") + + res, err := client.Do(req) if err != nil { return user, err } defer res.Body.Close() + if res.StatusCode < 200 || res.StatusCode >= 300 { + return user, fmt.Errorf("request failed with status: %s", res.Status) + } + body, err := io.ReadAll(res.Body) if err != nil { return user, err @@ -101,12 +113,23 @@ func (github *GithubOAuthService) Userinfo() (config.Claims, error) { return user, err } - res, err = client.Get("https://api.github.com/user/emails") + req, err = http.NewRequest("GET", "https://api.github.com/user/emails", nil) + if err != nil { + return user, err + } + + req.Header.Set("Accept", "application/vnd.github+json") + + res, err = client.Do(req) if err != nil { return user, err } defer res.Body.Close() + if res.StatusCode < 200 || res.StatusCode >= 300 { + return user, fmt.Errorf("request failed with status: %s", res.Status) + } + body, err = io.ReadAll(res.Body) if err != nil { return user, err diff --git a/internal/service/google_oauth_service.go b/internal/service/google_oauth_service.go index 776aeca..4f738e7 100644 --- a/internal/service/google_oauth_service.go +++ b/internal/service/google_oauth_service.go @@ -5,6 +5,7 @@ import ( "crypto/rand" "encoding/base64" "encoding/json" + "fmt" "io" "net/http" "strings" @@ -66,7 +67,7 @@ func (google *GoogleOAuthService) VerifyCode(code string) error { token, err := google.Config.Exchange(google.Context, code, oauth2.VerifierOption(google.Verifier)) if err != nil { - return nil + return err } google.Token = token @@ -84,6 +85,10 @@ func (google *GoogleOAuthService) Userinfo() (config.Claims, error) { } defer res.Body.Close() + if res.StatusCode < 200 || res.StatusCode >= 300 { + return user, fmt.Errorf("request failed with status: %s", res.Status) + } + body, err := io.ReadAll(res.Body) if err != nil { return config.Claims{}, err diff --git a/internal/utils/app_utils.go b/internal/utils/app_utils.go index 1ed8d4c..85a8754 100644 --- a/internal/utils/app_utils.go +++ b/internal/utils/app_utils.go @@ -2,6 +2,7 @@ package utils import ( "errors" + "net" "net/url" "strings" "tinyauth/internal/config" @@ -12,16 +13,25 @@ import ( ) // Get upper domain parses a hostname and returns the upper domain (e.g. sub1.sub2.domain.com -> sub2.domain.com) -func GetUpperDomain(urlSrc string) (string, error) { - urlParsed, err := url.Parse(urlSrc) +func GetUpperDomain(appUrl string) (string, error) { + appUrlParsed, err := url.Parse(appUrl) if err != nil { return "", err } - urlSplitted := strings.Split(urlParsed.Hostname(), ".") - urlFinal := strings.Join(urlSplitted[1:], ".") + host := appUrlParsed.Hostname() - return urlFinal, nil + if netIP := net.ParseIP(host); netIP != nil { + return "", errors.New("IP addresses are not allowed") + } + + urlParts := strings.Split(host, ".") + + if len(urlParts) < 2 { + return "", errors.New("invalid domain, must be at least second level domain") + } + + return strings.Join(urlParts[1:], "."), nil } func ParseFileToLine(content string) string { @@ -63,8 +73,38 @@ func GetContext(c *gin.Context) (config.UserContext, error) { return *userContext, nil } +func IsRedirectSafe(redirectURL string, domain string) bool { + if redirectURL == "" { + return false + } + + parsedURL, err := url.Parse(redirectURL) + + if err != nil { + return false + } + + if !parsedURL.IsAbs() { + return false + } + + upper, err := GetUpperDomain(redirectURL) + + if err != nil { + return false + } + + if upper != domain { + return false + } + + return true +} + func GetLogLevel(level string) zerolog.Level { switch strings.ToLower(level) { + case "trace": + return zerolog.TraceLevel case "debug": return zerolog.DebugLevel case "info": diff --git a/internal/utils/label_utils.go b/internal/utils/label_utils.go index a01685b..f10092d 100644 --- a/internal/utils/label_utils.go +++ b/internal/utils/label_utils.go @@ -1,6 +1,7 @@ package utils import ( + "net/http" "strings" "tinyauth/internal/config" @@ -26,6 +27,10 @@ func ParseHeaders(headers []string) map[string]string { continue } key := SanitizeHeader(strings.TrimSpace(split[0])) + if strings.ContainsAny(key, " \t") { + continue + } + key = http.CanonicalHeaderKey(key) value := SanitizeHeader(strings.TrimSpace(split[1])) headerMap[key] = value } diff --git a/internal/utils/user_utils.go b/internal/utils/user_utils.go index bfcec49..0044db4 100644 --- a/internal/utils/user_utils.go +++ b/internal/utils/user_utils.go @@ -9,6 +9,12 @@ import ( func ParseUsers(users string) ([]config.User, error) { var usersParsed []config.User + users = strings.TrimSpace(users) + + if users == "" { + return []config.User{}, nil + } + userList := strings.Split(users, ",") if len(userList) == 0 { @@ -16,7 +22,10 @@ func ParseUsers(users string) ([]config.User, error) { } for _, user := range userList { - parsed, err := ParseUser(user) + if strings.TrimSpace(user) == "" { + continue + } + parsed, err := ParseUser(strings.TrimSpace(user)) if err != nil { return []config.User{}, err } @@ -39,12 +48,13 @@ func GetUsers(conf string, file string) ([]config.User, error) { if file != "" { contents, err := ReadFile(file) - if err == nil { - if users != "" { - users += "," - } - users += ParseFileToLine(contents) + if err != nil { + return []config.User{}, err } + if users != "" { + users += "," + } + users += ParseFileToLine(contents) } return ParseUsers(users) diff --git a/main.go b/main.go index eac789e..8126e9e 100644 --- a/main.go +++ b/main.go @@ -10,6 +10,6 @@ import ( ) func main() { - log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr, TimeFormat: time.RFC3339}).With().Timestamp().Caller().Logger().Level(zerolog.FatalLevel) + log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr, TimeFormat: time.RFC3339}).With().Timestamp().Caller().Logger() cmd.Execute() }