From 504a3b87b4dbf8bd5529e68f9040d7a72fce611a Mon Sep 17 00:00:00 2001 From: Stavros Date: Tue, 26 Aug 2025 15:05:03 +0300 Subject: [PATCH] refactor: rework file structure (#325) * wip: add middlewares * refactor: use context fom middleware in handlers * refactor: use controller approach in handlers * refactor: move oauth providers into services (non-working) * feat: create oauth broker service * refactor: use a boostrap service to bootstrap the app * refactor: split utils into smaller files * refactor: use more clear name for frontend assets * feat: allow customizability of resources dir * fix: fix typo in ui middleware * fix: validate resource file paths in ui middleware * refactor: move resource handling to a controller * feat: add some logging * fix: configure middlewares before groups * fix: use correct api path in login mutation * fix: coderabbit suggestions * fix: further coderabbit suggestions --- .env.example | 4 +- .gitignore | 5 +- air.toml | 2 +- cmd/root.go | 305 +++------- cmd/version.go | 8 +- frontend/src/context/app-context.tsx | 2 +- frontend/src/context/user-context.tsx | 2 +- frontend/src/pages/login-page.tsx | 2 +- frontend/src/pages/logout-page.tsx | 2 +- frontend/src/pages/totp-page.tsx | 2 +- frontend/vite.config.ts | 5 + go.mod | 1 + go.sum | 2 + internal/assets/assets.go | 4 +- internal/auth/auth_test.go | 146 ----- internal/bootstrap/app_bootstrap.go | 260 +++++++++ internal/{types => config}/config.go | 149 +++-- internal/constants/constants.go | 19 - internal/controller/context_controller.go | 104 ++++ internal/controller/health_controller.go | 25 + internal/controller/oauth_controller.go | 200 +++++++ internal/controller/proxy_controller.go | 311 ++++++++++ internal/controller/resources_controller.go | 42 ++ internal/controller/user_controller.go | 266 +++++++++ internal/handlers/context.go | 64 -- internal/handlers/handlers.go | 36 -- internal/handlers/handlers_test.go | 394 ------------- internal/handlers/oauth.go | 223 ------- internal/handlers/proxy.go | 282 --------- internal/handlers/user.go | 197 ------- internal/hooks/hooks.go | 144 ----- internal/middleware/context_middleware.go | 159 +++++ internal/middleware/ui_middleware.go | 56 ++ internal/middleware/zerolog_middleware.go | 66 +++ internal/oauth/oauth.go | 71 --- internal/providers/generic.go | 37 -- internal/providers/github.go | 102 ---- internal/providers/google.go | 56 -- internal/providers/providers.go | 154 ----- internal/server/server.go | 130 ----- .../{auth/auth.go => service/auth_service.go} | 193 +++--- .../docker.go => service/docker_service.go} | 50 +- internal/service/generic_oauth_service.go | 117 ++++ internal/service/github_oauth_service.go | 169 ++++++ internal/service/google_oauth_service.go | 113 ++++ .../{ldap/ldap.go => service/ldap_service.go} | 64 +- internal/service/oauth_broker_service.go | 76 +++ internal/types/api.go | 62 -- internal/types/types.go | 59 -- internal/utils/app_utils.go | 123 ++++ internal/utils/fs_utils.go | 17 + internal/utils/label_utils.go | 48 ++ internal/utils/security_utils.go | 124 ++++ internal/utils/string_utils.go | 30 + internal/utils/user_utils.go | 92 +++ internal/utils/utils.go | 350 ----------- internal/utils/utils_test.go | 548 ------------------ main.go | 2 +- 58 files changed, 2737 insertions(+), 3539 deletions(-) delete mode 100644 internal/auth/auth_test.go create mode 100644 internal/bootstrap/app_bootstrap.go rename internal/{types => config}/config.go (56%) delete mode 100644 internal/constants/constants.go create mode 100644 internal/controller/context_controller.go create mode 100644 internal/controller/health_controller.go create mode 100644 internal/controller/oauth_controller.go create mode 100644 internal/controller/proxy_controller.go create mode 100644 internal/controller/resources_controller.go create mode 100644 internal/controller/user_controller.go delete mode 100644 internal/handlers/context.go delete mode 100644 internal/handlers/handlers.go delete mode 100644 internal/handlers/handlers_test.go delete mode 100644 internal/handlers/oauth.go delete mode 100644 internal/handlers/proxy.go delete mode 100644 internal/handlers/user.go delete mode 100644 internal/hooks/hooks.go create mode 100644 internal/middleware/context_middleware.go create mode 100644 internal/middleware/ui_middleware.go create mode 100644 internal/middleware/zerolog_middleware.go delete mode 100644 internal/oauth/oauth.go delete mode 100644 internal/providers/generic.go delete mode 100644 internal/providers/github.go delete mode 100644 internal/providers/google.go delete mode 100644 internal/providers/providers.go delete mode 100644 internal/server/server.go rename internal/{auth/auth.go => service/auth_service.go} (62%) rename internal/{docker/docker.go => service/docker_service.go} (64%) create mode 100644 internal/service/generic_oauth_service.go create mode 100644 internal/service/github_oauth_service.go create mode 100644 internal/service/google_oauth_service.go rename internal/{ldap/ldap.go => service/ldap_service.go} (61%) create mode 100644 internal/service/oauth_broker_service.go delete mode 100644 internal/types/api.go delete mode 100644 internal/types/types.go create mode 100644 internal/utils/app_utils.go create mode 100644 internal/utils/fs_utils.go create mode 100644 internal/utils/label_utils.go create mode 100644 internal/utils/security_utils.go create mode 100644 internal/utils/string_utils.go create mode 100644 internal/utils/user_utils.go delete mode 100644 internal/utils/utils.go delete mode 100644 internal/utils/utils_test.go diff --git a/.env.example b/.env.example index 8edde7b..0f43bf0 100644 --- a/.env.example +++ b/.env.example @@ -5,7 +5,7 @@ SECRET_FILE=app_secret_file APP_URL=http://localhost:3000 USERS=your_user_password_hash USERS_FILE=users_file -COOKIE_SECURE=false +SECURE_COOKIE=false GITHUB_CLIENT_ID=github_client_id GITHUB_CLIENT_SECRET=github_client_secret GITHUB_CLIENT_SECRET_FILE=github_client_secret_file @@ -25,7 +25,7 @@ GENERIC_NAME=My OAuth SESSION_EXPIRY=7200 LOGIN_TIMEOUT=300 LOGIN_MAX_RETRIES=5 -LOG_LEVEL=0 +LOG_LEVEL=debug APP_TITLE=Tinyauth SSO FORGOT_PASSWORD_MESSAGE=Some message about resetting the password OAUTH_AUTO_REDIRECT=none diff --git a/.gitignore b/.gitignore index 0100a13..cb79b93 100644 --- a/.gitignore +++ b/.gitignore @@ -23,4 +23,7 @@ secret* tmp # version files -internal/assets/version \ No newline at end of file +internal/assets/version + +# data directory +data \ No newline at end of file diff --git a/air.toml b/air.toml index 7505b79..f84163b 100644 --- a/air.toml +++ b/air.toml @@ -4,7 +4,7 @@ tmp_dir = "tmp" [build] pre_cmd = ["mkdir -p internal/assets/dist", "echo 'backend running' > internal/assets/dist/index.html", "go install github.com/go-delve/delve/cmd/dlv@v1.25.0"] cmd = "CGO_ENABLED=0 go build -gcflags=\"all=-N -l\" -o tmp/tinyauth ." -bin = "/go/bin/dlv --listen :4000 --headless=true --api-version=2 --accept-multiclient --log=true exec tmp/tinyauth --continue" +bin = "/go/bin/dlv --listen :4000 --headless=true --api-version=2 --accept-multiclient --log=true exec tmp/tinyauth --continue --check-go-version=false" include_ext = ["go"] exclude_dir = ["internal/assets/dist"] exclude_regex = [".*_test\\.go"] diff --git a/cmd/root.go b/cmd/root.go index f96ec6b..ef5733e 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -1,20 +1,11 @@ package cmd import ( - "errors" - "fmt" "strings" totpCmd "tinyauth/cmd/totp" userCmd "tinyauth/cmd/user" - "tinyauth/internal/auth" - "tinyauth/internal/constants" - "tinyauth/internal/docker" - "tinyauth/internal/handlers" - "tinyauth/internal/hooks" - "tinyauth/internal/ldap" - "tinyauth/internal/providers" - "tinyauth/internal/server" - "tinyauth/internal/types" + "tinyauth/internal/bootstrap" + "tinyauth/internal/config" "tinyauth/internal/utils" "github.com/go-playground/validator/v10" @@ -29,147 +20,47 @@ var rootCmd = &cobra.Command{ Short: "The simplest way to protect your apps with a login screen.", Long: `Tinyauth is a simple authentication middleware that adds simple username/password login or OAuth with Google, Github and any generic OAuth provider to all of your docker apps.`, Run: func(cmd *cobra.Command, args []string) { - var config types.Config - err := viper.Unmarshal(&config) - HandleError(err, "Failed to parse config") + var conf config.Config + + err := viper.Unmarshal(&conf) + if err != nil { + log.Fatal().Err(err).Msg("Failed to parse config") + } // Check if secrets have a file associated with them - config.Secret = utils.GetSecret(config.Secret, config.SecretFile) - config.GithubClientSecret = utils.GetSecret(config.GithubClientSecret, config.GithubClientSecretFile) - config.GoogleClientSecret = utils.GetSecret(config.GoogleClientSecret, config.GoogleClientSecretFile) - config.GenericClientSecret = utils.GetSecret(config.GenericClientSecret, config.GenericClientSecretFile) + conf.Secret = utils.GetSecret(conf.Secret, conf.SecretFile) + conf.GithubClientSecret = utils.GetSecret(conf.GithubClientSecret, conf.GithubClientSecretFile) + conf.GoogleClientSecret = utils.GetSecret(conf.GoogleClientSecret, conf.GoogleClientSecretFile) + conf.GenericClientSecret = utils.GetSecret(conf.GenericClientSecret, conf.GenericClientSecretFile) - validator := validator.New() - err = validator.Struct(config) - HandleError(err, "Failed to validate config") + // Validate config + v := validator.New() - log.Logger = log.Level(zerolog.Level(config.LogLevel)) - log.Info().Str("version", strings.TrimSpace(constants.Version)).Msg("Starting tinyauth") - - log.Info().Msg("Parsing users") - users, err := utils.GetUsers(config.Users, config.UsersFile) - HandleError(err, "Failed to parse users") - - log.Debug().Msg("Getting domain") - domain, err := utils.GetUpperDomain(config.AppURL) - HandleError(err, "Failed to get upper domain") - log.Info().Str("domain", domain).Msg("Using domain for cookie store") - - cookieId := utils.GenerateIdentifier(strings.Split(domain, ".")[0]) - sessionCookieName := fmt.Sprintf("%s-%s", constants.SessionCookieName, cookieId) - csrfCookieName := fmt.Sprintf("%s-%s", constants.CsrfCookieName, cookieId) - redirectCookieName := fmt.Sprintf("%s-%s", constants.RedirectCookieName, cookieId) - - log.Debug().Msg("Deriving HMAC and encryption secrets") - - hmacSecret, err := utils.DeriveKey(config.Secret, "hmac") - HandleError(err, "Failed to derive HMAC secret") - - encryptionSecret, err := utils.DeriveKey(config.Secret, "encryption") - HandleError(err, "Failed to derive encryption secret") - - // Split the config into service-specific sub-configs - oauthConfig := types.OAuthConfig{ - GithubClientId: config.GithubClientId, - GithubClientSecret: config.GithubClientSecret, - GoogleClientId: config.GoogleClientId, - GoogleClientSecret: config.GoogleClientSecret, - GenericClientId: config.GenericClientId, - GenericClientSecret: config.GenericClientSecret, - GenericScopes: strings.Split(config.GenericScopes, ","), - GenericAuthURL: config.GenericAuthURL, - GenericTokenURL: config.GenericTokenURL, - GenericUserURL: config.GenericUserURL, - GenericSkipSSL: config.GenericSkipSSL, - AppURL: config.AppURL, + err = v.Struct(conf) + if err != nil { + log.Fatal().Err(err).Msg("Invalid config") } - handlersConfig := types.HandlersConfig{ - AppURL: config.AppURL, - DisableContinue: config.DisableContinue, - Title: config.Title, - GenericName: config.GenericName, - CookieSecure: config.CookieSecure, - Domain: domain, - ForgotPasswordMessage: config.FogotPasswordMessage, - BackgroundImage: config.BackgroundImage, - OAuthAutoRedirect: config.OAuthAutoRedirect, - CsrfCookieName: csrfCookieName, - RedirectCookieName: redirectCookieName, + log.Logger = log.Level(zerolog.Level(utils.GetLogLevel(conf.LogLevel))) + log.Info().Str("version", strings.TrimSpace(config.Version)).Msg("Starting tinyauth") + + // Create bootstrap app + app := bootstrap.NewBootstrapApp(conf) + + // Run + err = app.Setup() + + if err != nil { + log.Fatal().Err(err).Msg("Failed to setup app") } - serverConfig := types.ServerConfig{ - Port: config.Port, - Address: config.Address, - } - - authConfig := types.AuthConfig{ - Users: users, - OauthWhitelist: config.OAuthWhitelist, - CookieSecure: config.CookieSecure, - SessionExpiry: config.SessionExpiry, - Domain: domain, - LoginTimeout: config.LoginTimeout, - LoginMaxRetries: config.LoginMaxRetries, - SessionCookieName: sessionCookieName, - HMACSecret: hmacSecret, - EncryptionSecret: encryptionSecret, - } - - hooksConfig := types.HooksConfig{ - Domain: domain, - } - - var ldapService *ldap.LDAP - - if config.LdapAddress != "" { - log.Info().Msg("Using LDAP for authentication") - ldapConfig := types.LdapConfig{ - Address: config.LdapAddress, - BindDN: config.LdapBindDN, - BindPassword: config.LdapBindPassword, - BaseDN: config.LdapBaseDN, - Insecure: config.LdapInsecure, - SearchFilter: config.LdapSearchFilter, - } - ldapService, err = ldap.NewLDAP(ldapConfig) - if err != nil { - log.Error().Err(err).Msg("Failed to initialize LDAP service, disabling LDAP authentication") - ldapService = nil - } - } else { - log.Info().Msg("LDAP not configured, using local users or OAuth") - } - - // Check if we have a source of users - if len(users) == 0 && !utils.OAuthConfigured(config) && ldapService == nil { - HandleError(errors.New("err no users"), "Unable to find a source of users") - } - - // Setup the services - docker, err := docker.NewDocker() - HandleError(err, "Failed to initialize docker") - auth := auth.NewAuth(authConfig, docker, ldapService) - providers := providers.NewProviders(oauthConfig) - hooks := hooks.NewHooks(hooksConfig, auth, providers) - handlers := handlers.NewHandlers(handlersConfig, auth, hooks, providers, docker) - srv, err := server.NewServer(serverConfig, handlers) - HandleError(err, "Failed to create server") - - // Start up - err = srv.Start() - HandleError(err, "Failed to start server") }, } func Execute() { err := rootCmd.Execute() - HandleError(err, "Failed to execute root command") -} - -func HandleError(err error, msg string) { if err != nil { - log.Fatal().Err(err).Msg(msg) + log.Fatal().Err(err).Msg("Failed to execute command") } } @@ -179,85 +70,67 @@ func init() { viper.AutomaticEnv() - rootCmd.Flags().Int("port", 3000, "Port to run the server on.") - rootCmd.Flags().String("address", "0.0.0.0", "Address to bind the server to.") - rootCmd.Flags().String("secret", "", "Secret to use for the cookie.") - rootCmd.Flags().String("secret-file", "", "Path to a file containing the secret.") - rootCmd.Flags().String("app-url", "", "The tinyauth URL.") - rootCmd.Flags().String("users", "", "Comma separated list of users in the format username:hash.") - rootCmd.Flags().String("users-file", "", "Path to a file containing users in the format username:hash.") - rootCmd.Flags().Bool("cookie-secure", false, "Send cookie over secure connection only.") - rootCmd.Flags().String("github-client-id", "", "Github OAuth client ID.") - rootCmd.Flags().String("github-client-secret", "", "Github OAuth client secret.") - rootCmd.Flags().String("github-client-secret-file", "", "Github OAuth client secret file.") - rootCmd.Flags().String("google-client-id", "", "Google OAuth client ID.") - rootCmd.Flags().String("google-client-secret", "", "Google OAuth client secret.") - rootCmd.Flags().String("google-client-secret-file", "", "Google OAuth client secret file.") - rootCmd.Flags().String("generic-client-id", "", "Generic OAuth client ID.") - rootCmd.Flags().String("generic-client-secret", "", "Generic OAuth client secret.") - rootCmd.Flags().String("generic-client-secret-file", "", "Generic OAuth client secret file.") - rootCmd.Flags().String("generic-scopes", "", "Generic OAuth scopes.") - rootCmd.Flags().String("generic-auth-url", "", "Generic OAuth auth URL.") - rootCmd.Flags().String("generic-token-url", "", "Generic OAuth token URL.") - rootCmd.Flags().String("generic-user-url", "", "Generic OAuth user info URL.") - rootCmd.Flags().String("generic-name", "Generic", "Generic OAuth provider name.") - rootCmd.Flags().Bool("generic-skip-ssl", false, "Skip SSL verification for the generic OAuth provider.") - rootCmd.Flags().Bool("disable-continue", false, "Disable continue screen and redirect to app directly.") - rootCmd.Flags().String("oauth-whitelist", "", "Comma separated list of email addresses to whitelist when using OAuth.") - rootCmd.Flags().String("oauth-auto-redirect", "none", "Auto redirect to the specified OAuth provider if configured. (available providers: github, google, generic)") - rootCmd.Flags().Int("session-expiry", 86400, "Session (cookie) expiration time in seconds.") - rootCmd.Flags().Int("login-timeout", 300, "Login timeout in seconds after max retries reached (0 to disable).") - rootCmd.Flags().Int("login-max-retries", 5, "Maximum login attempts before timeout (0 to disable).") - rootCmd.Flags().Int("log-level", 1, "Log level.") - rootCmd.Flags().String("app-title", "Tinyauth", "Title of the app.") - rootCmd.Flags().String("forgot-password-message", "", "Message to show on the forgot password page.") - rootCmd.Flags().String("background-image", "/background.jpg", "Background image URL for the login page.") - rootCmd.Flags().String("ldap-address", "", "LDAP server address (e.g. ldap://localhost:389).") - rootCmd.Flags().String("ldap-bind-dn", "", "LDAP bind DN (e.g. uid=user,dc=example,dc=com).") - rootCmd.Flags().String("ldap-bind-password", "", "LDAP bind password.") - rootCmd.Flags().String("ldap-base-dn", "", "LDAP base DN (e.g. dc=example,dc=com).") - rootCmd.Flags().Bool("ldap-insecure", false, "Skip certificate verification for the LDAP server.") - rootCmd.Flags().String("ldap-search-filter", "(uid=%s)", "LDAP search filter for user lookup.") + configOptions := []struct { + name string + defaultVal any + description string + }{ + {"port", 3000, "Port to run the server on."}, + {"address", "0.0.0.0", "Address to bind the server to."}, + {"secret", "", "Secret to use for the cookie."}, + {"secret-file", "", "Path to a file containing the secret."}, + {"app-url", "", "The Tinyauth URL."}, + {"users", "", "Comma separated list of users in the format username:hash."}, + {"users-file", "", "Path to a file containing users in the format username:hash."}, + {"secure-cookie", false, "Send cookie over secure connection only."}, + {"github-client-id", "", "Github OAuth client ID."}, + {"github-client-secret", "", "Github OAuth client secret."}, + {"github-client-secret-file", "", "Github OAuth client secret file."}, + {"google-client-id", "", "Google OAuth client ID."}, + {"google-client-secret", "", "Google OAuth client secret."}, + {"google-client-secret-file", "", "Google OAuth client secret file."}, + {"generic-client-id", "", "Generic OAuth client ID."}, + {"generic-client-secret", "", "Generic OAuth client secret."}, + {"generic-client-secret-file", "", "Generic OAuth client secret file."}, + {"generic-scopes", "", "Generic OAuth scopes."}, + {"generic-auth-url", "", "Generic OAuth auth URL."}, + {"generic-token-url", "", "Generic OAuth token URL."}, + {"generic-user-url", "", "Generic OAuth user info URL."}, + {"generic-name", "Generic", "Generic OAuth provider name."}, + {"generic-skip-ssl", false, "Skip SSL verification for the generic OAuth provider."}, + {"disable-continue", false, "Disable continue screen and redirect to app directly."}, + {"oauth-whitelist", "", "Comma separated list of email addresses to whitelist when using OAuth."}, + {"oauth-auto-redirect", "none", "Auto redirect to the specified OAuth provider if configured. (available providers: github, google, generic)"}, + {"session-expiry", 86400, "Session (cookie) expiration time in seconds."}, + {"login-timeout", 300, "Login timeout in seconds after max retries reached (0 to disable)."}, + {"login-max-retries", 5, "Maximum login attempts before timeout (0 to disable)."}, + {"log-level", "info", "Log level."}, + {"app-title", "Tinyauth", "Title of the app."}, + {"forgot-password-message", "", "Message to show on the forgot password page."}, + {"background-image", "/background.jpg", "Background image URL for the login page."}, + {"ldap-address", "", "LDAP server address (e.g. ldap://localhost:389)."}, + {"ldap-bind-dn", "", "LDAP bind DN (e.g. uid=user,dc=example,dc=com)."}, + {"ldap-bind-password", "", "LDAP bind password."}, + {"ldap-base-dn", "", "LDAP base DN (e.g. dc=example,dc=com)."}, + {"ldap-insecure", false, "Skip certificate verification for the LDAP server."}, + {"ldap-search-filter", "(uid=%s)", "LDAP search filter for user lookup."}, + {"resources-dir", "/data/resources", "Path to a directory containing custom resources (e.g. background image)."}, + } - viper.BindEnv("port", "PORT") - viper.BindEnv("address", "ADDRESS") - viper.BindEnv("secret", "SECRET") - viper.BindEnv("secret-file", "SECRET_FILE") - viper.BindEnv("app-url", "APP_URL") - viper.BindEnv("users", "USERS") - viper.BindEnv("users-file", "USERS_FILE") - viper.BindEnv("cookie-secure", "COOKIE_SECURE") - viper.BindEnv("github-client-id", "GITHUB_CLIENT_ID") - viper.BindEnv("github-client-secret", "GITHUB_CLIENT_SECRET") - viper.BindEnv("github-client-secret-file", "GITHUB_CLIENT_SECRET_FILE") - viper.BindEnv("google-client-id", "GOOGLE_CLIENT_ID") - viper.BindEnv("google-client-secret", "GOOGLE_CLIENT_SECRET") - viper.BindEnv("google-client-secret-file", "GOOGLE_CLIENT_SECRET_FILE") - viper.BindEnv("generic-client-id", "GENERIC_CLIENT_ID") - viper.BindEnv("generic-client-secret", "GENERIC_CLIENT_SECRET") - viper.BindEnv("generic-client-secret-file", "GENERIC_CLIENT_SECRET_FILE") - viper.BindEnv("generic-scopes", "GENERIC_SCOPES") - viper.BindEnv("generic-auth-url", "GENERIC_AUTH_URL") - viper.BindEnv("generic-token-url", "GENERIC_TOKEN_URL") - viper.BindEnv("generic-user-url", "GENERIC_USER_URL") - viper.BindEnv("generic-name", "GENERIC_NAME") - viper.BindEnv("generic-skip-ssl", "GENERIC_SKIP_SSL") - viper.BindEnv("disable-continue", "DISABLE_CONTINUE") - viper.BindEnv("oauth-whitelist", "OAUTH_WHITELIST") - viper.BindEnv("oauth-auto-redirect", "OAUTH_AUTO_REDIRECT") - viper.BindEnv("session-expiry", "SESSION_EXPIRY") - viper.BindEnv("log-level", "LOG_LEVEL") - viper.BindEnv("app-title", "APP_TITLE") - viper.BindEnv("login-timeout", "LOGIN_TIMEOUT") - viper.BindEnv("login-max-retries", "LOGIN_MAX_RETRIES") - viper.BindEnv("forgot-password-message", "FORGOT_PASSWORD_MESSAGE") - viper.BindEnv("background-image", "BACKGROUND_IMAGE") - viper.BindEnv("ldap-address", "LDAP_ADDRESS") - viper.BindEnv("ldap-bind-dn", "LDAP_BIND_DN") - viper.BindEnv("ldap-bind-password", "LDAP_BIND_PASSWORD") - viper.BindEnv("ldap-base-dn", "LDAP_BASE_DN") - viper.BindEnv("ldap-insecure", "LDAP_INSECURE") - viper.BindEnv("ldap-search-filter", "LDAP_SEARCH_FILTER") + for _, opt := range configOptions { + switch v := opt.defaultVal.(type) { + case bool: + rootCmd.Flags().Bool(opt.name, v, opt.description) + case int: + rootCmd.Flags().Int(opt.name, v, opt.description) + case string: + rootCmd.Flags().String(opt.name, v, opt.description) + } + + // Create uppercase env var name + envVar := strings.ReplaceAll(strings.ToUpper(opt.name), "-", "_") + viper.BindEnv(opt.name, envVar) + } viper.BindPFlags(rootCmd.Flags()) } diff --git a/cmd/version.go b/cmd/version.go index ffbd6fc..2a1827b 100644 --- a/cmd/version.go +++ b/cmd/version.go @@ -2,7 +2,7 @@ package cmd import ( "fmt" - "tinyauth/internal/constants" + "tinyauth/internal/config" "github.com/spf13/cobra" ) @@ -12,9 +12,9 @@ var versionCmd = &cobra.Command{ Short: "Print the version number of Tinyauth", Long: `All software has versions. This is Tinyauth's`, Run: func(cmd *cobra.Command, args []string) { - fmt.Printf("Version: %s\n", constants.Version) - fmt.Printf("Commit Hash: %s\n", constants.CommitHash) - fmt.Printf("Build Timestamp: %s\n", constants.BuildTimestamp) + fmt.Printf("Version: %s\n", config.Version) + fmt.Printf("Commit Hash: %s\n", config.CommitHash) + fmt.Printf("Build Timestamp: %s\n", config.BuildTimestamp) }, } diff --git a/frontend/src/context/app-context.tsx b/frontend/src/context/app-context.tsx index 13abf50..8f76c11 100644 --- a/frontend/src/context/app-context.tsx +++ b/frontend/src/context/app-context.tsx @@ -15,7 +15,7 @@ export const AppContextProvider = ({ }) => { const { isFetching, data, error } = useSuspenseQuery({ queryKey: ["app"], - queryFn: () => axios.get("/api/app").then((res) => res.data), + queryFn: () => axios.get("/api/context/app").then((res) => res.data), }); if (error && !isFetching) { diff --git a/frontend/src/context/user-context.tsx b/frontend/src/context/user-context.tsx index 43b3c00..a3cfeaa 100644 --- a/frontend/src/context/user-context.tsx +++ b/frontend/src/context/user-context.tsx @@ -15,7 +15,7 @@ export const UserContextProvider = ({ }) => { const { isFetching, data, error } = useSuspenseQuery({ queryKey: ["user"], - queryFn: () => axios.get("/api/user").then((res) => res.data), + queryFn: () => axios.get("/api/context/user").then((res) => res.data), }); if (error && !isFetching) { diff --git a/frontend/src/pages/login-page.tsx b/frontend/src/pages/login-page.tsx index 4828b38..53f183f 100644 --- a/frontend/src/pages/login-page.tsx +++ b/frontend/src/pages/login-page.tsx @@ -65,7 +65,7 @@ export const LoginPage = () => { }); const loginMutation = useMutation({ - mutationFn: (values: LoginSchema) => axios.post("/api/login", values), + mutationFn: (values: LoginSchema) => axios.post("/api/user/login", values), mutationKey: ["login"], onSuccess: (data) => { if (data.data.totpPending) { diff --git a/frontend/src/pages/logout-page.tsx b/frontend/src/pages/logout-page.tsx index 8c28500..30b2af8 100644 --- a/frontend/src/pages/logout-page.tsx +++ b/frontend/src/pages/logout-page.tsx @@ -26,7 +26,7 @@ export const LogoutPage = () => { const { t } = useTranslation(); const logoutMutation = useMutation({ - mutationFn: () => axios.post("/api/logout"), + mutationFn: () => axios.post("/api/user/logout"), mutationKey: ["logout"], onSuccess: () => { toast.success(t("logoutSuccessTitle"), { diff --git a/frontend/src/pages/totp-page.tsx b/frontend/src/pages/totp-page.tsx index e04fb2f..7d4ebad 100644 --- a/frontend/src/pages/totp-page.tsx +++ b/frontend/src/pages/totp-page.tsx @@ -32,7 +32,7 @@ export const TotpPage = () => { const redirectUri = searchParams.get("redirect_uri"); const totpMutation = useMutation({ - mutationFn: (values: TotpSchema) => axios.post("/api/totp", values), + mutationFn: (values: TotpSchema) => axios.post("/api/user/totp", values), mutationKey: ["totp"], onSuccess: () => { toast.success(t("totpSuccessTitle"), { diff --git a/frontend/vite.config.ts b/frontend/vite.config.ts index 07e6e7e..f391a49 100644 --- a/frontend/vite.config.ts +++ b/frontend/vite.config.ts @@ -19,6 +19,11 @@ export default defineConfig({ changeOrigin: true, rewrite: (path) => path.replace(/^\/api/, ""), }, + "/resources": { + target: "http://tinyauth-backend:3000/resources", + changeOrigin: true, + rewrite: (path) => path.replace(/^\/resources/, ""), + }, }, allowedHosts: true, }, diff --git a/go.mod b/go.mod index 0a6f885..8388b2a 100644 --- a/go.mod +++ b/go.mod @@ -68,6 +68,7 @@ require ( github.com/go-logr/stdr v1.2.2 // indirect github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect + github.com/go-playground/validator v9.31.0+incompatible github.com/goccy/go-json v0.10.4 // indirect github.com/gogo/protobuf v1.3.2 // indirect github.com/gorilla/securecookie v1.1.2 // indirect diff --git a/go.sum b/go.sum index dabff47..b43990c 100644 --- a/go.sum +++ b/go.sum @@ -111,6 +111,8 @@ github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/o github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= +github.com/go-playground/validator v9.31.0+incompatible h1:UA72EPEogEnq76ehGdEDp4Mit+3FDh548oRqwVgNsHA= +github.com/go-playground/validator v9.31.0+incompatible/go.mod h1:yrEkQXlcI+PugkyDjY2bRrL/UBU4f3rvrgkN3V8JEig= github.com/go-playground/validator/v10 v10.27.0 h1:w8+XrWVMhGkxOaaowyKH35gFydVHOvC0/uWoy2Fzwn4= github.com/go-playground/validator/v10 v10.27.0/go.mod h1:I5QpIEbmr8On7W0TktmJAumgzX4CA1XNl4ZmDuVHKKo= github.com/go-viper/mapstructure/v2 v2.3.0 h1:27XbWsHIqhbdR5TIC911OfYvgSaW93HM+dX7970Q7jk= diff --git a/internal/assets/assets.go b/internal/assets/assets.go index 6918867..df6e61f 100644 --- a/internal/assets/assets.go +++ b/internal/assets/assets.go @@ -4,7 +4,7 @@ import ( "embed" ) -// UI assets +// Frontend assets // //go:embed dist -var Assets embed.FS +var FrontendAssets embed.FS diff --git a/internal/auth/auth_test.go b/internal/auth/auth_test.go deleted file mode 100644 index 1ab7329..0000000 --- a/internal/auth/auth_test.go +++ /dev/null @@ -1,146 +0,0 @@ -package auth_test - -import ( - "testing" - "time" - "tinyauth/internal/auth" - "tinyauth/internal/types" -) - -var config = types.AuthConfig{ - Users: types.Users{}, - OauthWhitelist: "", - SessionExpiry: 3600, -} - -func TestLoginRateLimiting(t *testing.T) { - // Initialize a new auth service with 3 max retries and 5 seconds timeout - config.LoginMaxRetries = 3 - config.LoginTimeout = 5 - authService := auth.NewAuth(config, nil, nil) - - // Test identifier - identifier := "test_user" - - // Test successful login - should not lock account - t.Log("Testing successful login") - - authService.RecordLoginAttempt(identifier, true) - locked, _ := authService.IsAccountLocked(identifier) - - if locked { - t.Fatalf("Account should not be locked after successful login") - } - - // Test 2 failed attempts - should not lock account yet - t.Log("Testing 2 failed login attempts") - - authService.RecordLoginAttempt(identifier, false) - authService.RecordLoginAttempt(identifier, false) - locked, _ = authService.IsAccountLocked(identifier) - - if locked { - t.Fatalf("Account should not be locked after only 2 failed attempts") - } - - // Add one more failed attempt (total 3) - should lock account with maxRetries=3 - t.Log("Testing 3 failed login attempts") - authService.RecordLoginAttempt(identifier, false) - locked, remainingTime := authService.IsAccountLocked(identifier) - - if !locked { - t.Fatalf("Account should be locked after reaching max retries") - } - if remainingTime <= 0 || remainingTime > 5 { - t.Fatalf("Expected remaining time between 1-5 seconds, got %d", remainingTime) - } - - // Test reset after waiting for timeout - use 1 second timeout for fast testing - t.Log("Testing unlocking after timeout") - - // Reinitialize auth service with a shorter timeout for testing - config.LoginTimeout = 1 - config.LoginMaxRetries = 3 - authService = auth.NewAuth(config, nil, nil) - - // Add enough failed attempts to lock the account - for i := 0; i < 3; i++ { - authService.RecordLoginAttempt(identifier, false) - } - - // Verify it's locked - locked, _ = authService.IsAccountLocked(identifier) - if !locked { - t.Fatalf("Account should be locked initially") - } - - // Wait a bit and verify it gets unlocked after timeout - time.Sleep(1500 * time.Millisecond) // Wait longer than the timeout - locked, _ = authService.IsAccountLocked(identifier) - - if locked { - t.Fatalf("Account should be unlocked after timeout period") - } - - // Test disabled rate limiting - t.Log("Testing disabled rate limiting") - config.LoginMaxRetries = 0 - config.LoginTimeout = 0 - authService = auth.NewAuth(config, nil, nil) - - for i := 0; i < 10; i++ { - authService.RecordLoginAttempt(identifier, false) - } - - locked, _ = authService.IsAccountLocked(identifier) - if locked { - t.Fatalf("Account should not be locked when rate limiting is disabled") - } -} - -func TestConcurrentLoginAttempts(t *testing.T) { - // Initialize a new auth service with 2 max retries and 5 seconds timeout - config.LoginMaxRetries = 2 - config.LoginTimeout = 5 - authService := auth.NewAuth(config, nil, nil) - - // Test multiple identifiers - identifiers := []string{"user1", "user2", "user3"} - - // Test that locking one identifier doesn't affect others - t.Log("Testing multiple identifiers") - - // Add enough failed attempts to lock first user (2 attempts with maxRetries=2) - authService.RecordLoginAttempt(identifiers[0], false) - authService.RecordLoginAttempt(identifiers[0], false) - - // Check if first user is locked - locked, _ := authService.IsAccountLocked(identifiers[0]) - if !locked { - t.Fatalf("User1 should be locked after reaching max retries") - } - - // Check that other users are not affected - for i := 1; i < len(identifiers); i++ { - locked, _ := authService.IsAccountLocked(identifiers[i]) - if locked { - t.Fatalf("User%d should not be locked", i+1) - } - } - - // Test successful login after failed attempts (but before lock) - t.Log("Testing successful login after failed attempts but before lock") - - // One failed attempt for user2 - authService.RecordLoginAttempt(identifiers[1], false) - - // Successful login should reset the counter - authService.RecordLoginAttempt(identifiers[1], true) - - // Now try a failed login again - should not be locked as counter was reset - authService.RecordLoginAttempt(identifiers[1], false) - locked, _ = authService.IsAccountLocked(identifiers[1]) - if locked { - t.Fatalf("User2 should not be locked after successful login reset") - } -} diff --git a/internal/bootstrap/app_bootstrap.go b/internal/bootstrap/app_bootstrap.go new file mode 100644 index 0000000..594c575 --- /dev/null +++ b/internal/bootstrap/app_bootstrap.go @@ -0,0 +1,260 @@ +package bootstrap + +import ( + "fmt" + "strings" + "tinyauth/internal/config" + "tinyauth/internal/controller" + "tinyauth/internal/middleware" + "tinyauth/internal/service" + "tinyauth/internal/utils" + + "github.com/gin-gonic/gin" + "github.com/rs/zerolog/log" +) + +type Controller interface { + SetupRoutes() +} + +type Middleware interface { + Middleware() gin.HandlerFunc + Init() error +} + +type Service interface { + Init() error +} + +type BootstrapApp struct { + Config config.Config +} + +func NewBootstrapApp(config config.Config) *BootstrapApp { + return &BootstrapApp{ + Config: config, + } +} + +func (app *BootstrapApp) Setup() error { + // Parse users + users, err := utils.GetUsers(app.Config.Users, app.Config.UsersFile) + + if err != nil { + return err + } + + // Get domain + domain, err := utils.GetUpperDomain(app.Config.AppURL) + + if err != nil { + return err + } + + // Cookie names + cookieId := utils.GenerateIdentifier(strings.Split(domain, ".")[0]) + sessionCookieName := fmt.Sprintf("%s-%s", config.SessionCookieName, cookieId) + csrfCookieName := fmt.Sprintf("%s-%s", config.CSRFCookieName, cookieId) + redirectCookieName := fmt.Sprintf("%s-%s", config.RedirectCookieName, cookieId) + + // Secrets + encryptionSecret, err := utils.DeriveKey(app.Config.Secret, "encryption") + + if err != nil { + return err + } + + hmacSecret, err := utils.DeriveKey(app.Config.Secret, "hmac") + + if err != nil { + return err + } + + // Create configs + authConfig := service.AuthServiceConfig{ + Users: users, + OauthWhitelist: app.Config.OAuthWhitelist, + SessionExpiry: app.Config.SessionExpiry, + SecureCookie: app.Config.SecureCookie, + Domain: domain, + LoginTimeout: app.Config.LoginTimeout, + LoginMaxRetries: app.Config.LoginMaxRetries, + SessionCookieName: sessionCookieName, + HMACSecret: hmacSecret, + EncryptionSecret: encryptionSecret, + } + + // Setup services + var ldapService *service.LdapService + + if app.Config.LdapAddress != "" { + ldapConfig := service.LdapServiceConfig{ + Address: app.Config.LdapAddress, + BindDN: app.Config.LdapBindDN, + BindPassword: app.Config.LdapBindPassword, + BaseDN: app.Config.LdapBaseDN, + Insecure: app.Config.LdapInsecure, + SearchFilter: app.Config.LdapSearchFilter, + } + + ldapService = service.NewLdapService(ldapConfig) + + err := ldapService.Init() + + if err != nil { + log.Warn().Err(err).Msg("Failed to initialize LDAP service, continuing without LDAP") + ldapService = nil + } + } + + dockerService := service.NewDockerService() + authService := service.NewAuthService(authConfig, dockerService, ldapService) + oauthBrokerService := service.NewOAuthBrokerService(app.getOAuthBrokerConfig()) + + // Initialize services + services := []Service{ + dockerService, + authService, + oauthBrokerService, + } + + for _, svc := range services { + if svc != nil { + log.Debug().Str("service", fmt.Sprintf("%T", svc)).Msg("Initializing service") + err := svc.Init() + if err != nil { + return err + } + } + } + + // Configured providers + var configuredProviders []string + + if authService.UserAuthConfigured() || ldapService != nil { + configuredProviders = append(configuredProviders, "username") + } + + configuredProviders = append(configuredProviders, oauthBrokerService.GetConfiguredServices()...) + + if len(configuredProviders) == 0 { + return fmt.Errorf("no authentication providers configured") + } + + // Create engine + engine := gin.New() + + if config.Version != "development" { + gin.SetMode(gin.ReleaseMode) + } + + // Create middlewares + var middlewares []Middleware + + contextMiddleware := middleware.NewContextMiddleware(middleware.ContextMiddlewareConfig{ + Domain: domain, + }, authService, oauthBrokerService) + + uiMiddleware := middleware.NewUIMiddleware() + zerologMiddleware := middleware.NewZerologMiddleware() + + middlewares = append(middlewares, contextMiddleware, uiMiddleware, zerologMiddleware) + + for _, middleware := range middlewares { + log.Debug().Str("middleware", fmt.Sprintf("%T", middleware)).Msg("Initializing middleware") + err := middleware.Init() + if err != nil { + return fmt.Errorf("failed to initialize middleware %T: %w", middleware, err) + } + engine.Use(middleware.Middleware()) + } + + // Create routers + mainRouter := engine.Group("") + apiRouter := engine.Group("/api") + + // Create controllers + contextController := controller.NewContextController(controller.ContextControllerConfig{ + ConfiguredProviders: configuredProviders, + DisableContinue: app.Config.DisableContinue, + Title: app.Config.Title, + GenericName: app.Config.GenericName, + Domain: domain, + ForgotPasswordMessage: app.Config.FogotPasswordMessage, + BackgroundImage: app.Config.BackgroundImage, + OAuthAutoRedirect: app.Config.OAuthAutoRedirect, + }, apiRouter) + + oauthController := controller.NewOAuthController(controller.OAuthControllerConfig{ + AppURL: app.Config.AppURL, + SecureCookie: app.Config.SecureCookie, + CSRFCookieName: csrfCookieName, + RedirectCookieName: redirectCookieName, + Domain: domain, + }, apiRouter, authService, oauthBrokerService) + + proxyController := controller.NewProxyController(controller.ProxyControllerConfig{ + AppURL: app.Config.AppURL, + }, apiRouter, dockerService, authService) + + userController := controller.NewUserController(controller.UserControllerConfig{ + Domain: domain, + }, apiRouter, authService) + + resourcesController := controller.NewResourcesController(controller.ResourcesControllerConfig{ + ResourcesDir: app.Config.ResourcesDir, + }, mainRouter) + + healthController := controller.NewHealthController(apiRouter) + + // Setup routes + controller := []Controller{ + contextController, + oauthController, + proxyController, + userController, + healthController, + resourcesController, + } + + for _, ctrl := range controller { + log.Debug().Msgf("Setting up %T controller", ctrl) + ctrl.SetupRoutes() + } + + // Start server + address := fmt.Sprintf("%s:%d", app.Config.Address, app.Config.Port) + log.Info().Msgf("Starting server on %s", address) + if err := engine.Run(address); err != nil { + log.Fatal().Err(err).Msg("Failed to start server") + } + + return nil +} + +// Temporary +func (app *BootstrapApp) getOAuthBrokerConfig() map[string]config.OAuthServiceConfig { + return map[string]config.OAuthServiceConfig{ + "google": { + ClientID: app.Config.GoogleClientId, + ClientSecret: app.Config.GoogleClientSecret, + RedirectURL: fmt.Sprintf("%s/api/oauth/callback/google", app.Config.AppURL), + }, + "github": { + ClientID: app.Config.GithubClientId, + ClientSecret: app.Config.GithubClientSecret, + RedirectURL: fmt.Sprintf("%s/api/oauth/callback/github", app.Config.AppURL), + }, + "generic": { + ClientID: app.Config.GenericClientId, + ClientSecret: app.Config.GenericClientSecret, + RedirectURL: fmt.Sprintf("%s/api/oauth/callback/generic", app.Config.AppURL), + Scopes: strings.Split(app.Config.GenericScopes, ","), + AuthURL: app.Config.GenericAuthURL, + TokenURL: app.Config.GenericTokenURL, + UserinfoURL: app.Config.GenericUserURL, + InsecureSkipVerify: app.Config.GenericSkipSSL, + }, + } + +} diff --git a/internal/types/config.go b/internal/config/config.go similarity index 56% rename from internal/types/config.go rename to internal/config/config.go index b53e053..5d4dba8 100644 --- a/internal/types/config.go +++ b/internal/config/config.go @@ -1,6 +1,20 @@ -package types +package config + +type Claims struct { + Name string `json:"name"` + Email string `json:"email"` + PreferredUsername string `json:"preferred_username"` + Groups any `json:"groups"` +} + +var Version = "development" +var CommitHash = "n/a" +var BuildTimestamp = "n/a" + +var SessionCookieName = "tinyauth-session" +var CSRFCookieName = "tinyauth-csrf" +var RedirectCookieName = "tinyauth-redirect" -// Config is the configuration for the tinyauth server type Config struct { Port int `mapstructure:"port" validate:"required"` Address string `validate:"required,ip4_addr" mapstructure:"address"` @@ -9,7 +23,7 @@ type Config struct { AppURL string `validate:"required,url" mapstructure:"app-url"` Users string `mapstructure:"users"` UsersFile string `mapstructure:"users-file"` - CookieSecure bool `mapstructure:"cookie-secure"` + SecureCookie bool `mapstructure:"secure-cookie"` GithubClientId string `mapstructure:"github-client-id"` GithubClientSecret string `mapstructure:"github-client-secret"` GithubClientSecretFile string `mapstructure:"github-client-secret-file"` @@ -29,9 +43,8 @@ type Config struct { OAuthWhitelist string `mapstructure:"oauth-whitelist"` OAuthAutoRedirect string `mapstructure:"oauth-auto-redirect" validate:"oneof=none github google generic"` SessionExpiry int `mapstructure:"session-expiry"` - LogLevel int8 `mapstructure:"log-level" validate:"min=-1,max=5"` + LogLevel string `mapstructure:"log-level" validate:"oneof=trace debug info warn error fatal panic"` Title string `mapstructure:"app-title"` - EnvFile string `mapstructure:"env-file"` LoginTimeout int `mapstructure:"login-timeout"` LoginMaxRetries int `mapstructure:"login-max-retries"` FogotPasswordMessage string `mapstructure:"forgot-password-message"` @@ -42,90 +55,30 @@ type Config struct { LdapBaseDN string `mapstructure:"ldap-base-dn"` LdapInsecure bool `mapstructure:"ldap-insecure"` LdapSearchFilter string `mapstructure:"ldap-search-filter"` + ResourcesDir string `mapstructure:"resources-dir"` } -// Server configuration -type HandlersConfig struct { - AppURL string - Domain string - CookieSecure bool - DisableContinue bool - GenericName string - Title string - ForgotPasswordMessage string - BackgroundImage string - OAuthAutoRedirect string - CsrfCookieName string - RedirectCookieName string -} - -// OAuthConfig is the configuration for the providers -type OAuthConfig struct { - GithubClientId string - GithubClientSecret string - GoogleClientId string - GoogleClientSecret string - GenericClientId string - GenericClientSecret string - GenericScopes []string - GenericAuthURL string - GenericTokenURL string - GenericUserURL string - GenericSkipSSL bool - AppURL string -} - -// ServerConfig is the configuration for the server -type ServerConfig struct { - Port int - Address string -} - -// AuthConfig is the configuration for the auth service -type AuthConfig struct { - Users Users - OauthWhitelist string - SessionExpiry int - CookieSecure bool - Domain string - LoginTimeout int - LoginMaxRetries int - SessionCookieName string - HMACSecret string - EncryptionSecret string -} - -// HooksConfig is the configuration for the hooks service -type HooksConfig struct { - Domain string -} - -// OAuthLabels is a list of labels that can be used in a tinyauth protected container type OAuthLabels struct { Whitelist string Groups string } -// Basic auth labels for a tinyauth protected container type BasicLabels struct { Username string - Password PassowrdLabels + Password PasswordLabels } -// PassowrdLabels is a struct that contains the password labels for a tinyauth protected container -type PassowrdLabels struct { +type PasswordLabels struct { Plain string File string } -// IP labels for a tinyauth protected container type IPLabels struct { Allow []string Block []string Bypass []string } -// Labels is a struct that contains the labels for a tinyauth protected container type Labels struct { Users string Allowed string @@ -136,12 +89,56 @@ type Labels struct { IP IPLabels } -// Ldap config is a struct that contains the configuration for the LDAP service -type LdapConfig struct { - Address string - BindDN string - BindPassword string - BaseDN string - Insecure bool - SearchFilter string +type OAuthServiceConfig struct { + ClientID string + ClientSecret string + Scopes []string + RedirectURL string + AuthURL string + TokenURL string + UserinfoURL string + InsecureSkipVerify bool +} + +type User struct { + Username string + Password string + TotpSecret string +} + +type UserSearch struct { + Username string + Type string // local, ldap or unknown +} + +type SessionCookie struct { + Username string + Name string + Email string + Provider string + TotpPending bool + OAuthGroups string +} + +type UserContext struct { + Username string + Name string + Email string + IsLoggedIn bool + OAuth bool + Provider string + TotpPending bool + OAuthGroups string + TotpEnabled bool +} + +type UnauthorizedQuery struct { + Username string `url:"username"` + Resource string `url:"resource"` + GroupErr bool `url:"groupErr"` + IP string `url:"ip"` +} + +type RedirectQuery struct { + RedirectURI string `url:"redirect_uri"` } diff --git a/internal/constants/constants.go b/internal/constants/constants.go deleted file mode 100644 index d6f64fa..0000000 --- a/internal/constants/constants.go +++ /dev/null @@ -1,19 +0,0 @@ -package constants - -// Claims are the OIDC supported claims (prefered username is included for convinience) -type Claims struct { - Name string `json:"name"` - Email string `json:"email"` - PreferredUsername string `json:"preferred_username"` - Groups any `json:"groups"` -} - -// Version information -var Version = "development" -var CommitHash = "n/a" -var BuildTimestamp = "n/a" - -// Base cookie names -var SessionCookieName = "tinyauth-session" -var CsrfCookieName = "tinyauth-csrf" -var RedirectCookieName = "tinyauth-redirect" diff --git a/internal/controller/context_controller.go b/internal/controller/context_controller.go new file mode 100644 index 0000000..c7570f0 --- /dev/null +++ b/internal/controller/context_controller.go @@ -0,0 +1,104 @@ +package controller + +import ( + "tinyauth/internal/utils" + + "github.com/gin-gonic/gin" + "github.com/rs/zerolog/log" +) + +type UserContextResponse struct { + Status int `json:"status"` + Message string `json:"message"` + IsLoggedIn bool `json:"isLoggedIn"` + Username string `json:"username"` + Name string `json:"name"` + Email string `json:"email"` + Provider string `json:"provider"` + Oauth bool `json:"oauth"` + TotpPending bool `json:"totpPending"` +} + +type AppContextResponse struct { + Status int `json:"status"` + Message string `json:"message"` + ConfiguredProviders []string `json:"configuredProviders"` + DisableContinue bool `json:"disableContinue"` + Title string `json:"title"` + GenericName string `json:"genericName"` + Domain string `json:"domain"` + ForgotPasswordMessage string `json:"forgotPasswordMessage"` + BackgroundImage string `json:"backgroundImage"` + OAuthAutoRedirect string `json:"oauthAutoRedirect"` +} + +type ContextControllerConfig struct { + ConfiguredProviders []string + DisableContinue bool + Title string + GenericName string + Domain string + ForgotPasswordMessage string + BackgroundImage string + OAuthAutoRedirect string +} + +type ContextController struct { + Config ContextControllerConfig + Router *gin.RouterGroup +} + +func NewContextController(config ContextControllerConfig, router *gin.RouterGroup) *ContextController { + return &ContextController{ + Config: config, + Router: router, + } +} + +func (controller *ContextController) SetupRoutes() { + contextGroup := controller.Router.Group("/context") + contextGroup.GET("/user", controller.userContextHandler) + contextGroup.GET("/app", controller.appContextHandler) +} + +func (controller *ContextController) userContextHandler(c *gin.Context) { + context, err := utils.GetContext(c) + + userContext := UserContextResponse{ + Status: 200, + Message: "Success", + IsLoggedIn: context.IsLoggedIn, + Username: context.Username, + Name: context.Name, + Email: context.Email, + Provider: context.Provider, + Oauth: context.OAuth, + TotpPending: context.TotpPending, + } + + if err != nil { + log.Debug().Err(err).Msg("No user context found in request") + userContext.Status = 401 + userContext.Message = "Unauthorized" + userContext.IsLoggedIn = false + c.JSON(200, userContext) + return + } + + c.JSON(200, userContext) +} + +func (controller *ContextController) appContextHandler(c *gin.Context) { + c.JSON(200, AppContextResponse{ + Status: 200, + Message: "Success", + ConfiguredProviders: controller.Config.ConfiguredProviders, + DisableContinue: controller.Config.DisableContinue, + Title: controller.Config.Title, + GenericName: controller.Config.GenericName, + Domain: controller.Config.Domain, + ForgotPasswordMessage: controller.Config.ForgotPasswordMessage, + BackgroundImage: controller.Config.BackgroundImage, + OAuthAutoRedirect: controller.Config.OAuthAutoRedirect, + }) +} diff --git a/internal/controller/health_controller.go b/internal/controller/health_controller.go new file mode 100644 index 0000000..842b3d3 --- /dev/null +++ b/internal/controller/health_controller.go @@ -0,0 +1,25 @@ +package controller + +import "github.com/gin-gonic/gin" + +type HealthController struct { + Router *gin.RouterGroup +} + +func NewHealthController(router *gin.RouterGroup) *HealthController { + return &HealthController{ + Router: router, + } +} + +func (controller *HealthController) SetupRoutes() { + controller.Router.GET("/health", controller.healthHandler) + controller.Router.HEAD("/health", controller.healthHandler) +} + +func (controller *HealthController) healthHandler(c *gin.Context) { + c.JSON(200, gin.H{ + "status": "ok", + "message": "Healthy", + }) +} diff --git a/internal/controller/oauth_controller.go b/internal/controller/oauth_controller.go new file mode 100644 index 0000000..aa3289b --- /dev/null +++ b/internal/controller/oauth_controller.go @@ -0,0 +1,200 @@ +package controller + +import ( + "fmt" + "net/http" + "strings" + "time" + "tinyauth/internal/config" + "tinyauth/internal/service" + "tinyauth/internal/utils" + + "github.com/gin-gonic/gin" + "github.com/google/go-querystring/query" + "github.com/rs/zerolog/log" +) + +type OAuthRequest struct { + Provider string `uri:"provider" binding:"required"` +} + +type OAuthControllerConfig struct { + CSRFCookieName string + RedirectCookieName string + SecureCookie bool + AppURL string + Domain string +} + +type OAuthController struct { + Config OAuthControllerConfig + Router *gin.RouterGroup + Auth *service.AuthService + Broker *service.OAuthBrokerService +} + +func NewOAuthController(config OAuthControllerConfig, router *gin.RouterGroup, auth *service.AuthService, broker *service.OAuthBrokerService) *OAuthController { + return &OAuthController{ + Config: config, + Router: router, + Auth: auth, + Broker: broker, + } +} + +func (controller *OAuthController) SetupRoutes() { + oauthGroup := controller.Router.Group("/oauth") + oauthGroup.GET("/url/:provider", controller.oauthURLHandler) + oauthGroup.GET("/callback/:provider", controller.oauthCallbackHandler) +} + +func (controller *OAuthController) oauthURLHandler(c *gin.Context) { + var req OAuthRequest + + err := c.BindUri(&req) + if err != nil { + log.Error().Err(err).Msg("Failed to bind URI") + c.JSON(400, gin.H{ + "status": 400, + "message": "Bad Request", + }) + return + } + + service, exists := controller.Broker.GetService(req.Provider) + + if !exists { + log.Warn().Msgf("OAuth provider not found: %s", req.Provider) + c.JSON(404, gin.H{ + "status": 404, + "message": "Not Found", + }) + return + } + + state := service.GenerateState() + authURL := service.GetAuthURL(state) + c.SetCookie(controller.Config.CSRFCookieName, state, int(time.Hour.Seconds()), "/", fmt.Sprintf(".%s", controller.Config.Domain), controller.Config.SecureCookie, true) + + redirectURI := c.Query("redirect_uri") + + if redirectURI != "" && utils.IsRedirectSafe(redirectURI, controller.Config.Domain) { + log.Debug().Msg("Setting redirect URI cookie") + c.SetCookie(controller.Config.RedirectCookieName, redirectURI, int(time.Hour.Seconds()), "/", fmt.Sprintf(".%s", controller.Config.Domain), controller.Config.SecureCookie, true) + } + + c.JSON(200, gin.H{ + "status": 200, + "message": "OK", + "url": authURL, + }) +} + +func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) { + var req OAuthRequest + + err := c.BindUri(&req) + if err != nil { + log.Error().Err(err).Msg("Failed to bind URI") + c.JSON(400, gin.H{ + "status": 400, + "message": "Bad Request", + }) + return + } + + state := c.Query("state") + csrfCookie, err := c.Cookie(controller.Config.CSRFCookieName) + + if err != nil || state != csrfCookie { + log.Warn().Err(err).Msg("CSRF token mismatch or cookie missing") + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.Config.AppURL)) + return + } + + c.SetCookie(controller.Config.CSRFCookieName, "", -1, "/", fmt.Sprintf(".%s", controller.Config.Domain), controller.Config.SecureCookie, true) + + code := c.Query("code") + service, exists := controller.Broker.GetService(req.Provider) + + if !exists { + log.Warn().Msgf("OAuth provider not found: %s", req.Provider) + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.Config.AppURL)) + return + } + + err = service.VerifyCode(code) + if err != nil { + log.Error().Err(err).Msg("Failed to verify OAuth code") + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.Config.AppURL)) + return + } + + user, err := controller.Broker.GetUser(req.Provider) + + if err != nil { + log.Error().Err(err).Msg("Failed to get user from OAuth provider") + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.Config.AppURL)) + return + } + + if user.Email == "" { + log.Error().Msg("OAuth provider did not return an email") + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.Config.AppURL)) + return + } + + if !controller.Auth.EmailWhitelisted(user.Email) { + queries, err := query.Values(config.UnauthorizedQuery{ + Username: user.Email, + }) + + if err != nil { + log.Error().Err(err).Msg("Failed to encode unauthorized query") + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.Config.AppURL)) + return + } + + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", controller.Config.AppURL, queries.Encode())) + return + } + + var name string + + if user.Name != "" { + log.Debug().Msg("Using name from OAuth provider") + name = user.Name + } else { + log.Debug().Msg("No name from OAuth provider, using pseudo name") + name = fmt.Sprintf("%s (%s)", utils.Capitalize(strings.Split(user.Email, "@")[0]), strings.Split(user.Email, "@")[1]) + } + + controller.Auth.CreateSessionCookie(c, &config.SessionCookie{ + Username: user.Email, + Name: name, + Email: user.Email, + Provider: req.Provider, + OAuthGroups: utils.CoalesceToString(user.Groups), + }) + + redirectURI, err := c.Cookie(controller.Config.RedirectCookieName) + + if err != nil || !utils.IsRedirectSafe(redirectURI, controller.Config.Domain) { + log.Debug().Msg("No redirect URI cookie found, redirecting to app root") + c.Redirect(http.StatusTemporaryRedirect, controller.Config.AppURL) + return + } + + queries, err := query.Values(config.RedirectQuery{ + RedirectURI: redirectURI, + }) + + if err != nil { + log.Error().Err(err).Msg("Failed to encode redirect URI query") + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.Config.AppURL)) + return + } + + c.SetCookie(controller.Config.RedirectCookieName, "", -1, "/", fmt.Sprintf(".%s", controller.Config.Domain), controller.Config.SecureCookie, true) + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/continue?%s", controller.Config.AppURL, queries.Encode())) +} diff --git a/internal/controller/proxy_controller.go b/internal/controller/proxy_controller.go new file mode 100644 index 0000000..348be65 --- /dev/null +++ b/internal/controller/proxy_controller.go @@ -0,0 +1,311 @@ +package controller + +import ( + "fmt" + "net/http" + "strings" + "tinyauth/internal/config" + "tinyauth/internal/service" + "tinyauth/internal/utils" + + "github.com/gin-gonic/gin" + "github.com/google/go-querystring/query" + "github.com/rs/zerolog/log" +) + +type Proxy struct { + Proxy string `uri:"proxy" binding:"required"` +} + +type ProxyControllerConfig struct { + AppURL string +} + +type ProxyController struct { + Config ProxyControllerConfig + Router *gin.RouterGroup + Docker *service.DockerService + Auth *service.AuthService +} + +func NewProxyController(config ProxyControllerConfig, router *gin.RouterGroup, docker *service.DockerService, auth *service.AuthService) *ProxyController { + return &ProxyController{ + Config: config, + Router: router, + Docker: docker, + Auth: auth, + } +} + +func (controller *ProxyController) SetupRoutes() { + proxyGroup := controller.Router.Group("/auth") + proxyGroup.GET("/:proxy", controller.proxyHandler) +} + +func (controller *ProxyController) proxyHandler(c *gin.Context) { + var req Proxy + + err := c.BindUri(&req) + if err != nil { + log.Error().Err(err).Msg("Failed to bind URI") + c.JSON(400, gin.H{ + "status": 400, + "message": "Bad Request", + }) + return + } + + isBrowser := strings.Contains(c.Request.Header.Get("Accept"), "text/html") + + if isBrowser { + log.Debug().Msg("Request identified as (most likely) coming from a browser") + } else { + log.Debug().Msg("Request identified as (most likely) coming from a non-browser client") + } + + uri := c.Request.Header.Get("X-Forwarded-Uri") + proto := c.Request.Header.Get("X-Forwarded-Proto") + host := c.Request.Header.Get("X-Forwarded-Host") + + hostWithoutPort := strings.Split(host, ":")[0] + id := strings.Split(hostWithoutPort, ".")[0] + + labels, err := controller.Docker.GetLabels(id, hostWithoutPort) + + if err != nil { + log.Error().Err(err).Msg("Failed to get labels from Docker") + + if req.Proxy == "nginx" || !isBrowser { + c.JSON(500, gin.H{ + "status": 500, + "message": "Internal Server Error", + }) + return + } + + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.Config.AppURL)) + return + } + + clientIP := c.ClientIP() + + if controller.Auth.BypassedIP(labels, clientIP) { + c.Header("Authorization", c.Request.Header.Get("Authorization")) + + headers := utils.ParseHeaders(labels.Headers) + + for key, value := range headers { + log.Debug().Str("header", key).Msg("Setting header") + c.Header(key, value) + } + + if labels.Basic.Username != "" && utils.GetSecret(labels.Basic.Password.Plain, labels.Basic.Password.File) != "" { + log.Debug().Str("username", labels.Basic.Username).Msg("Setting basic auth header") + c.Header("Authorization", fmt.Sprintf("Basic %s", utils.GetBasicAuth(labels.Basic.Username, utils.GetSecret(labels.Basic.Password.Plain, labels.Basic.Password.File)))) + } + + c.JSON(200, gin.H{ + "status": 200, + "message": "Authenticated", + }) + return + } + + if !controller.Auth.CheckIP(labels, clientIP) { + if req.Proxy == "nginx" || !isBrowser { + c.JSON(401, gin.H{ + "status": 401, + "message": "Unauthorized", + }) + return + } + + queries, err := query.Values(config.UnauthorizedQuery{ + Resource: strings.Split(host, ".")[0], + IP: clientIP, + }) + + if err != nil { + log.Error().Err(err).Msg("Failed to encode unauthorized query") + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.Config.AppURL)) + return + } + + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", controller.Config.AppURL, queries.Encode())) + return + } + + authEnabled, err := controller.Auth.AuthEnabled(uri, labels) + + if err != nil { + log.Error().Err(err).Msg("Failed to check if auth is enabled for resource") + + if req.Proxy == "nginx" || !isBrowser { + c.JSON(500, gin.H{ + "status": 500, + "message": "Internal Server Error", + }) + return + } + + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.Config.AppURL)) + return + } + + if !authEnabled { + log.Debug().Msg("Authentication disabled for resource, allowing access") + + c.Header("Authorization", c.Request.Header.Get("Authorization")) + + headers := utils.ParseHeaders(labels.Headers) + + for key, value := range headers { + log.Debug().Str("header", key).Msg("Setting header") + c.Header(key, value) + } + + if labels.Basic.Username != "" && utils.GetSecret(labels.Basic.Password.Plain, labels.Basic.Password.File) != "" { + log.Debug().Str("username", labels.Basic.Username).Msg("Setting basic auth header") + c.Header("Authorization", fmt.Sprintf("Basic %s", utils.GetBasicAuth(labels.Basic.Username, utils.GetSecret(labels.Basic.Password.Plain, labels.Basic.Password.File)))) + } + + c.JSON(200, gin.H{ + "status": 200, + "message": "Authenticated", + }) + return + } + + var userContext config.UserContext + + context, err := utils.GetContext(c) + + if err != nil { + log.Debug().Msg("No user context found in request, treating as not logged in") + userContext = config.UserContext{ + IsLoggedIn: false, + } + } else { + userContext = context + } + + if userContext.Provider == "basic" && userContext.TotpEnabled { + log.Debug().Msg("User has TOTP enabled, denying basic auth access") + userContext.IsLoggedIn = false + } + + if userContext.IsLoggedIn { + appAllowed := controller.Auth.ResourceAllowed(c, userContext, labels) + + if !appAllowed { + log.Warn().Str("user", userContext.Username).Str("resource", strings.Split(host, ".")[0]).Msg("User not allowed to access resource") + + if req.Proxy == "nginx" || !isBrowser { + c.JSON(403, gin.H{ + "status": 403, + "message": "Forbidden", + }) + return + } + + queries, err := query.Values(config.UnauthorizedQuery{ + Resource: strings.Split(host, ".")[0], + }) + + if userContext.OAuth { + queries.Set("username", userContext.Email) + } else { + queries.Set("username", userContext.Username) + } + + if err != nil { + log.Error().Err(err).Msg("Failed to encode unauthorized query") + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.Config.AppURL)) + return + } + + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", controller.Config.AppURL, queries.Encode())) + return + } + + if userContext.OAuth { + groupOK := controller.Auth.OAuthGroup(c, userContext, labels) + + if !groupOK { + log.Warn().Str("user", userContext.Username).Str("resource", strings.Split(host, ".")[0]).Msg("User OAuth groups do not match resource requirements") + + if req.Proxy == "nginx" || !isBrowser { + c.JSON(403, gin.H{ + "status": 403, + "message": "Forbidden", + }) + return + } + + queries, err := query.Values(config.UnauthorizedQuery{ + Resource: strings.Split(host, ".")[0], + GroupErr: true, + }) + + if userContext.OAuth { + queries.Set("username", userContext.Email) + } else { + queries.Set("username", userContext.Username) + } + + if err != nil { + log.Error().Err(err).Msg("Failed to encode unauthorized query") + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.Config.AppURL)) + return + } + + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", controller.Config.AppURL, queries.Encode())) + return + } + } + + c.Header("Authorization", c.Request.Header.Get("Authorization")) + c.Header("Remote-User", utils.SanitizeHeader(userContext.Username)) + c.Header("Remote-Name", utils.SanitizeHeader(userContext.Name)) + c.Header("Remote-Email", utils.SanitizeHeader(userContext.Email)) + c.Header("Remote-Groups", utils.SanitizeHeader(userContext.OAuthGroups)) + + headers := utils.ParseHeaders(labels.Headers) + + for key, value := range headers { + log.Debug().Str("header", key).Msg("Setting header") + c.Header(key, value) + } + + if labels.Basic.Username != "" && utils.GetSecret(labels.Basic.Password.Plain, labels.Basic.Password.File) != "" { + log.Debug().Str("username", labels.Basic.Username).Msg("Setting basic auth header") + c.Header("Authorization", fmt.Sprintf("Basic %s", utils.GetBasicAuth(labels.Basic.Username, utils.GetSecret(labels.Basic.Password.Plain, labels.Basic.Password.File)))) + } + + c.JSON(200, gin.H{ + "status": 200, + "message": "Authenticated", + }) + return + } + + if req.Proxy == "nginx" || !isBrowser { + c.JSON(401, gin.H{ + "status": 401, + "message": "Unauthorized", + }) + return + } + + queries, err := query.Values(config.RedirectQuery{ + RedirectURI: fmt.Sprintf("%s://%s%s", proto, host, uri), + }) + + if err != nil { + log.Error().Err(err).Msg("Failed to encode redirect URI query") + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.Config.AppURL)) + return + } + + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/login?%s", controller.Config.AppURL, queries.Encode())) +} diff --git a/internal/controller/resources_controller.go b/internal/controller/resources_controller.go new file mode 100644 index 0000000..56bae87 --- /dev/null +++ b/internal/controller/resources_controller.go @@ -0,0 +1,42 @@ +package controller + +import ( + "net/http" + + "github.com/gin-gonic/gin" +) + +type ResourcesControllerConfig struct { + ResourcesDir string +} + +type ResourcesController struct { + Config ResourcesControllerConfig + Router *gin.RouterGroup + FileServer http.Handler +} + +func NewResourcesController(config ResourcesControllerConfig, router *gin.RouterGroup) *ResourcesController { + fileServer := http.StripPrefix("/resources", http.FileServer(http.Dir(config.ResourcesDir))) + + return &ResourcesController{ + Config: config, + Router: router, + FileServer: fileServer, + } +} + +func (controller *ResourcesController) SetupRoutes() { + controller.Router.GET("/resources/*resource", controller.resourcesHandler) +} + +func (controller *ResourcesController) resourcesHandler(c *gin.Context) { + if controller.Config.ResourcesDir == "" { + c.JSON(404, gin.H{ + "status": 404, + "message": "Resources not found", + }) + return + } + controller.FileServer.ServeHTTP(c.Writer, c.Request) +} diff --git a/internal/controller/user_controller.go b/internal/controller/user_controller.go new file mode 100644 index 0000000..f7f7c9e --- /dev/null +++ b/internal/controller/user_controller.go @@ -0,0 +1,266 @@ +package controller + +import ( + "fmt" + "strings" + "tinyauth/internal/config" + "tinyauth/internal/service" + "tinyauth/internal/utils" + + "github.com/gin-gonic/gin" + "github.com/pquerna/otp/totp" + "github.com/rs/zerolog/log" +) + +type LoginRequest struct { + Username string `json:"username"` + Password string `json:"password"` +} + +type TotpRequest struct { + Code string `json:"code"` +} + +type UserControllerConfig struct { + Domain string +} + +type UserController struct { + Config UserControllerConfig + Router *gin.RouterGroup + Auth *service.AuthService +} + +func NewUserController(config UserControllerConfig, router *gin.RouterGroup, auth *service.AuthService) *UserController { + return &UserController{ + Config: config, + Router: router, + Auth: auth, + } +} + +func (controller *UserController) SetupRoutes() { + userGroup := controller.Router.Group("/user") + userGroup.POST("/login", controller.loginHandler) + userGroup.POST("/logout", controller.logoutHandler) + userGroup.POST("/totp", controller.totpHandler) +} + +func (controller *UserController) loginHandler(c *gin.Context) { + var req LoginRequest + + err := c.ShouldBindJSON(&req) + if err != nil { + log.Error().Err(err).Msg("Failed to bind JSON") + c.JSON(400, gin.H{ + "status": 400, + "message": "Bad Request", + }) + return + } + + clientIP := c.ClientIP() + + rateIdentifier := req.Username + + if rateIdentifier == "" { + rateIdentifier = clientIP + } + + log.Debug().Str("username", req.Username).Str("ip", clientIP).Msg("Login attempt") + + isLocked, remainingTime := controller.Auth.IsAccountLocked(rateIdentifier) + + if isLocked { + log.Warn().Str("username", req.Username).Str("ip", clientIP).Msg("Account is locked due to too many failed login attempts") + c.JSON(429, gin.H{ + "status": 429, + "message": fmt.Sprintf("Too many failed login attempts. Try again in %d seconds", remainingTime), + }) + return + } + + userSearch := controller.Auth.SearchUser(req.Username) + + if userSearch.Type == "" { + log.Warn().Str("username", req.Username).Str("ip", clientIP).Msg("User not found") + controller.Auth.RecordLoginAttempt(rateIdentifier, false) + c.JSON(401, gin.H{ + "status": 401, + "message": "Unauthorized", + }) + return + } + + if !controller.Auth.VerifyUser(userSearch, req.Password) { + log.Warn().Str("username", req.Username).Str("ip", clientIP).Msg("Invalid password") + controller.Auth.RecordLoginAttempt(rateIdentifier, false) + c.JSON(401, gin.H{ + "status": 401, + "message": "Unauthorized", + }) + return + } + + log.Info().Str("username", req.Username).Str("ip", clientIP).Msg("Login successful") + + controller.Auth.RecordLoginAttempt(rateIdentifier, true) + + if userSearch.Type == "local" { + user := controller.Auth.GetLocalUser(userSearch.Username) + + if user.TotpSecret != "" { + log.Debug().Str("username", req.Username).Msg("User has TOTP enabled, requiring TOTP verification") + + err := controller.Auth.CreateSessionCookie(c, &config.SessionCookie{ + Username: user.Username, + Name: utils.Capitalize(req.Username), + Email: fmt.Sprintf("%s@%s", strings.ToLower(req.Username), controller.Config.Domain), + Provider: "username", + TotpPending: true, + }) + + if err != nil { + log.Error().Err(err).Msg("Failed to create session cookie") + c.JSON(500, gin.H{ + "status": 500, + "message": "Internal Server Error", + }) + return + } + + c.JSON(200, gin.H{ + "status": 200, + "message": "TOTP required", + "totpPending": true, + }) + return + } + } + + err = controller.Auth.CreateSessionCookie(c, &config.SessionCookie{ + Username: req.Username, + Name: utils.Capitalize(req.Username), + Email: fmt.Sprintf("%s@%s", strings.ToLower(req.Username), controller.Config.Domain), + Provider: "username", + }) + + if err != nil { + log.Error().Err(err).Msg("Failed to create session cookie") + c.JSON(500, gin.H{ + "status": 500, + "message": "Internal Server Error", + }) + return + } + + c.JSON(200, gin.H{ + "status": 200, + "message": "Login successful", + }) +} + +func (controller *UserController) logoutHandler(c *gin.Context) { + log.Debug().Msg("Logout request received") + + controller.Auth.DeleteSessionCookie(c) + + c.JSON(200, gin.H{ + "status": 200, + "message": "Logout successful", + }) +} + +func (controller *UserController) totpHandler(c *gin.Context) { + var req TotpRequest + + err := c.ShouldBindJSON(&req) + if err != nil { + log.Error().Err(err).Msg("Failed to bind JSON") + c.JSON(400, gin.H{ + "status": 400, + "message": "Bad Request", + }) + return + } + + context, err := utils.GetContext(c) + + if err != nil { + log.Error().Err(err).Msg("Failed to get user context") + c.JSON(500, gin.H{ + "status": 500, + "message": "Internal Server Error", + }) + return + } + + if !context.TotpPending { + log.Warn().Msg("TOTP attempt without a pending TOTP session") + c.JSON(401, gin.H{ + "status": 401, + "message": "Unauthorized", + }) + return + } + + clientIP := c.ClientIP() + + rateIdentifier := context.Username + + if rateIdentifier == "" { + rateIdentifier = clientIP + } + + log.Debug().Str("username", context.Username).Str("ip", clientIP).Msg("TOTP verification attempt") + + isLocked, remainingTime := controller.Auth.IsAccountLocked(rateIdentifier) + + if isLocked { + log.Warn().Str("username", context.Username).Str("ip", clientIP).Msg("Account is locked due to too many failed TOTP attempts") + c.JSON(429, gin.H{ + "status": 429, + "message": fmt.Sprintf("Too many failed login attempts. Try again in %d seconds", remainingTime), + }) + return + } + + user := controller.Auth.GetLocalUser(context.Username) + + ok := totp.Validate(req.Code, user.TotpSecret) + + if !ok { + log.Warn().Str("username", context.Username).Str("ip", clientIP).Msg("Invalid TOTP code") + controller.Auth.RecordLoginAttempt(rateIdentifier, false) + c.JSON(401, gin.H{ + "status": 401, + "message": "Unauthorized", + }) + return + } + + log.Info().Str("username", context.Username).Str("ip", clientIP).Msg("TOTP verification successful") + + controller.Auth.RecordLoginAttempt(rateIdentifier, true) + + err = controller.Auth.CreateSessionCookie(c, &config.SessionCookie{ + Username: user.Username, + Name: utils.Capitalize(user.Username), + Email: fmt.Sprintf("%s@%s", strings.ToLower(user.Username), controller.Config.Domain), + Provider: "username", + }) + + if err != nil { + log.Error().Err(err).Msg("Failed to create session cookie") + c.JSON(500, gin.H{ + "status": 500, + "message": "Internal Server Error", + }) + return + } + + c.JSON(200, gin.H{ + "status": 200, + "message": "Login successful", + }) +} diff --git a/internal/handlers/context.go b/internal/handlers/context.go deleted file mode 100644 index d0fff5e..0000000 --- a/internal/handlers/context.go +++ /dev/null @@ -1,64 +0,0 @@ -package handlers - -import ( - "tinyauth/internal/types" - - "github.com/gin-gonic/gin" - "github.com/rs/zerolog/log" -) - -func (h *Handlers) AppContextHandler(c *gin.Context) { - log.Debug().Msg("Getting app context") - - // Get configured providers - configuredProviders := h.Providers.GetConfiguredProviders() - - // We have username/password configured so add it to our providers - if h.Auth.UserAuthConfigured() { - configuredProviders = append(configuredProviders, "username") - } - - // Return app context - appContext := types.AppContext{ - Status: 200, - Message: "OK", - ConfiguredProviders: configuredProviders, - DisableContinue: h.Config.DisableContinue, - Title: h.Config.Title, - GenericName: h.Config.GenericName, - Domain: h.Config.Domain, - ForgotPasswordMessage: h.Config.ForgotPasswordMessage, - BackgroundImage: h.Config.BackgroundImage, - OAuthAutoRedirect: h.Config.OAuthAutoRedirect, - } - c.JSON(200, appContext) -} - -func (h *Handlers) UserContextHandler(c *gin.Context) { - log.Debug().Msg("Getting user context") - - // Create user context using hooks - userContext := h.Hooks.UseUserContext(c) - - userContextResponse := types.UserContextResponse{ - Status: 200, - IsLoggedIn: userContext.IsLoggedIn, - Username: userContext.Username, - Name: userContext.Name, - Email: userContext.Email, - Provider: userContext.Provider, - Oauth: userContext.OAuth, - TotpPending: userContext.TotpPending, - } - - // If we are not logged in we set the status to 401 else we set it to 200 - if !userContext.IsLoggedIn { - log.Debug().Msg("Unauthorized") - userContextResponse.Message = "Unauthorized" - } else { - log.Debug().Interface("userContext", userContext).Msg("Authenticated") - userContextResponse.Message = "Authenticated" - } - - c.JSON(200, userContextResponse) -} diff --git a/internal/handlers/handlers.go b/internal/handlers/handlers.go deleted file mode 100644 index 0e8ebe2..0000000 --- a/internal/handlers/handlers.go +++ /dev/null @@ -1,36 +0,0 @@ -package handlers - -import ( - "tinyauth/internal/auth" - "tinyauth/internal/docker" - "tinyauth/internal/hooks" - "tinyauth/internal/providers" - "tinyauth/internal/types" - - "github.com/gin-gonic/gin" -) - -type Handlers struct { - Config types.HandlersConfig - Auth *auth.Auth - Hooks *hooks.Hooks - Providers *providers.Providers - Docker *docker.Docker -} - -func NewHandlers(config types.HandlersConfig, auth *auth.Auth, hooks *hooks.Hooks, providers *providers.Providers, docker *docker.Docker) *Handlers { - return &Handlers{ - Config: config, - Auth: auth, - Hooks: hooks, - Providers: providers, - Docker: docker, - } -} - -func (h *Handlers) HealthcheckHandler(c *gin.Context) { - c.JSON(200, gin.H{ - "status": 200, - "message": "OK", - }) -} diff --git a/internal/handlers/handlers_test.go b/internal/handlers/handlers_test.go deleted file mode 100644 index 279534d..0000000 --- a/internal/handlers/handlers_test.go +++ /dev/null @@ -1,394 +0,0 @@ -package handlers_test - -import ( - "encoding/json" - "io" - "net/http" - "net/http/httptest" - "reflect" - "strings" - "testing" - "time" - "tinyauth/internal/auth" - "tinyauth/internal/docker" - "tinyauth/internal/handlers" - "tinyauth/internal/hooks" - "tinyauth/internal/providers" - "tinyauth/internal/server" - "tinyauth/internal/types" - - "github.com/magiconair/properties/assert" - "github.com/pquerna/otp/totp" -) - -// Simple server config -var serverConfig = types.ServerConfig{ - Port: 8080, - Address: "0.0.0.0", -} - -// Simple handlers config -var handlersConfig = types.HandlersConfig{ - AppURL: "http://localhost:8080", - Domain: "localhost", - DisableContinue: false, - CookieSecure: false, - Title: "Tinyauth", - GenericName: "Generic", - ForgotPasswordMessage: "Message", - CsrfCookieName: "tinyauth-csrf", - RedirectCookieName: "tinyauth-redirect", - BackgroundImage: "https://example.com/image.png", - OAuthAutoRedirect: "none", -} - -// Simple auth config -var authConfig = types.AuthConfig{ - Users: types.Users{}, - OauthWhitelist: "", - HMACSecret: "4bZ9K.*:;zH=,9zG!meUxu.B5-S[7.V.", // Complex on purpose - EncryptionSecret: "\\:!R(u[Sbv6ZLm.7es)H|OqH4y}0u\\rj", - CookieSecure: false, - SessionExpiry: 3600, - LoginTimeout: 0, - LoginMaxRetries: 0, - SessionCookieName: "tinyauth-session", - Domain: "localhost", -} - -// Simple hooks config -var hooksConfig = types.HooksConfig{ - Domain: "localhost", -} - -// Cookie -var cookie string - -// User -var user = types.User{ - Username: "user", - Password: "$2a$10$AvGHLTYv3xiRJ0xV9xs3XeVIlkGTygI9nqIamFYB5Xu.5.0UWF7B6", // pass -} - -// Initialize the server for tests -func getServer(t *testing.T) *server.Server { - // Create services - authConfig.Users = types.Users{ - { - Username: user.Username, - Password: user.Password, - TotpSecret: user.TotpSecret, - }, - } - docker, err := docker.NewDocker() - if err != nil { - t.Fatalf("Failed to create docker client: %v", err) - } - auth := auth.NewAuth(authConfig, nil, nil) - providers := providers.NewProviders(types.OAuthConfig{}) - hooks := hooks.NewHooks(hooksConfig, auth, providers) - handlers := handlers.NewHandlers(handlersConfig, auth, hooks, providers, docker) - - // Create server - srv, err := server.NewServer(serverConfig, handlers) - if err != nil { - t.Fatalf("Failed to create server: %v", err) - } - - return srv -} - -func TestLogin(t *testing.T) { - t.Log("Testing login") - - srv := getServer(t) - - recorder := httptest.NewRecorder() - - user := types.LoginRequest{ - Username: "user", - Password: "pass", - } - - json, err := json.Marshal(user) - if err != nil { - t.Fatalf("Error marshalling json: %v", err) - } - - req, err := http.NewRequest("POST", "/api/login", strings.NewReader(string(json))) - if err != nil { - t.Fatalf("Error creating request: %v", err) - } - - srv.Router.ServeHTTP(recorder, req) - assert.Equal(t, recorder.Code, http.StatusOK) - - cookies := recorder.Result().Cookies() - - if len(cookies) == 0 { - t.Fatalf("Cookie not set") - } - - // Set the cookie for further tests - cookie = cookies[0].Value -} - -func TestAppContext(t *testing.T) { - // Refresh the cookie - TestLogin(t) - - t.Log("Testing app context") - - srv := getServer(t) - - recorder := httptest.NewRecorder() - - req, err := http.NewRequest("GET", "/api/app", nil) - if err != nil { - t.Fatalf("Error creating request: %v", err) - } - - // Set the cookie from the previous test - req.AddCookie(&http.Cookie{ - Name: "tinyauth", - Value: cookie, - }) - - srv.Router.ServeHTTP(recorder, req) - assert.Equal(t, recorder.Code, http.StatusOK) - - body, err := io.ReadAll(recorder.Body) - if err != nil { - t.Fatalf("Error getting body: %v", err) - } - - var app types.AppContext - - err = json.Unmarshal(body, &app) - if err != nil { - t.Fatalf("Error unmarshalling body: %v", err) - } - - expected := types.AppContext{ - Status: 200, - Message: "OK", - ConfiguredProviders: []string{"username"}, - DisableContinue: false, - Title: "Tinyauth", - GenericName: "Generic", - ForgotPasswordMessage: "Message", - BackgroundImage: "https://example.com/image.png", - OAuthAutoRedirect: "none", - Domain: "localhost", - } - - // We should get the username back - if !reflect.DeepEqual(app, expected) { - t.Fatalf("Expected %v, got %v", expected, app) - } -} - -func TestUserContext(t *testing.T) { - // Refresh the cookie - TestLogin(t) - - t.Log("Testing user context") - - srv := getServer(t) - - recorder := httptest.NewRecorder() - - req, err := http.NewRequest("GET", "/api/user", nil) - if err != nil { - t.Fatalf("Error creating request: %v", err) - } - - req.AddCookie(&http.Cookie{ - Name: "tinyauth-session", - Value: cookie, - }) - - srv.Router.ServeHTTP(recorder, req) - assert.Equal(t, recorder.Code, http.StatusOK) - - body, err := io.ReadAll(recorder.Body) - if err != nil { - t.Fatalf("Error getting body: %v", err) - } - - type User struct { - Username string `json:"username"` - } - - var user User - - err = json.Unmarshal(body, &user) - if err != nil { - t.Fatalf("Error unmarshalling body: %v", err) - } - - // We should get the user back - if user.Username != "user" { - t.Fatalf("Expected user, got %s", user.Username) - } -} - -func TestLogout(t *testing.T) { - // Refresh the cookie - TestLogin(t) - - t.Log("Testing logout") - - srv := getServer(t) - - recorder := httptest.NewRecorder() - - req, err := http.NewRequest("POST", "/api/logout", nil) - if err != nil { - t.Fatalf("Error creating request: %v", err) - } - - req.AddCookie(&http.Cookie{ - Name: "tinyauth-session", - Value: cookie, - }) - - srv.Router.ServeHTTP(recorder, req) - assert.Equal(t, recorder.Code, http.StatusOK) - - // Check if the cookie is different (means the cookie is gone) - if recorder.Result().Cookies()[0].Value == cookie { - t.Fatalf("Cookie not flushed") - } -} - -func TestAuth(t *testing.T) { - // Refresh the cookie - TestLogin(t) - - t.Log("Testing auth endpoint") - - srv := getServer(t) - - recorder := httptest.NewRecorder() - - req, err := http.NewRequest("GET", "/api/auth/traefik", nil) - if err != nil { - t.Fatalf("Error creating request: %v", err) - } - - req.Header.Set("Accept", "text/html") - - srv.Router.ServeHTTP(recorder, req) - assert.Equal(t, recorder.Code, http.StatusTemporaryRedirect) - - recorder = httptest.NewRecorder() - - req, err = http.NewRequest("GET", "/api/auth/traefik", nil) - if err != nil { - t.Fatalf("Error creating request: %v", err) - } - - req.AddCookie(&http.Cookie{ - Name: "tinyauth-session", - Value: cookie, - }) - - srv.Router.ServeHTTP(recorder, req) - assert.Equal(t, recorder.Code, http.StatusOK) - - recorder = httptest.NewRecorder() - - req, err = http.NewRequest("GET", "/api/auth/nginx", nil) - if err != nil { - t.Fatalf("Error creating request: %v", err) - } - - srv.Router.ServeHTTP(recorder, req) - assert.Equal(t, recorder.Code, http.StatusUnauthorized) - - recorder = httptest.NewRecorder() - - req, err = http.NewRequest("GET", "/api/auth/nginx", nil) - if err != nil { - t.Fatalf("Error creating request: %v", err) - } - - req.AddCookie(&http.Cookie{ - Name: "tinyauth-session", - Value: cookie, - }) - - srv.Router.ServeHTTP(recorder, req) - assert.Equal(t, recorder.Code, http.StatusOK) -} - -func TestTOTP(t *testing.T) { - t.Log("Testing TOTP") - - key, err := totp.Generate(totp.GenerateOpts{ - Issuer: "Tinyauth", - AccountName: user.Username, - }) - if err != nil { - t.Fatalf("Failed to generate TOTP secret: %v", err) - } - - secret := key.Secret() - - user.TotpSecret = secret - - srv := getServer(t) - - user := types.LoginRequest{ - Username: "user", - Password: "pass", - } - - loginJson, err := json.Marshal(user) - if err != nil { - t.Fatalf("Error marshalling json: %v", err) - } - - recorder := httptest.NewRecorder() - - req, err := http.NewRequest("POST", "/api/login", strings.NewReader(string(loginJson))) - if err != nil { - t.Fatalf("Error creating request: %v", err) - } - - srv.Router.ServeHTTP(recorder, req) - assert.Equal(t, recorder.Code, http.StatusOK) - - // Set the cookie for next test - cookie = recorder.Result().Cookies()[0].Value - - code, err := totp.GenerateCode(secret, time.Now()) - if err != nil { - t.Fatalf("Failed to generate TOTP code: %v", err) - } - - totpRequest := types.TotpRequest{ - Code: code, - } - - totpJson, err := json.Marshal(totpRequest) - if err != nil { - t.Fatalf("Error marshalling TOTP request: %v", err) - } - - recorder = httptest.NewRecorder() - - req, err = http.NewRequest("POST", "/api/totp", strings.NewReader(string(totpJson))) - if err != nil { - t.Fatalf("Error creating request: %v", err) - } - - req.AddCookie(&http.Cookie{ - Name: "tinyauth-session", - Value: cookie, - }) - - srv.Router.ServeHTTP(recorder, req) - assert.Equal(t, recorder.Code, http.StatusOK) -} diff --git a/internal/handlers/oauth.go b/internal/handlers/oauth.go deleted file mode 100644 index 13c3a47..0000000 --- a/internal/handlers/oauth.go +++ /dev/null @@ -1,223 +0,0 @@ -package handlers - -import ( - "fmt" - "net/http" - "strings" - "time" - "tinyauth/internal/types" - "tinyauth/internal/utils" - - "github.com/gin-gonic/gin" - "github.com/google/go-querystring/query" - "github.com/rs/zerolog/log" -) - -func (h *Handlers) OAuthURLHandler(c *gin.Context) { - var request types.OAuthRequest - - err := c.BindUri(&request) - if err != nil { - log.Error().Err(err).Msg("Failed to bind URI") - c.JSON(400, gin.H{ - "status": 400, - "message": "Bad Request", - }) - return - } - - log.Debug().Msg("Got OAuth request") - - // Check if provider exists - provider := h.Providers.GetProvider(request.Provider) - - if provider == nil { - c.JSON(404, gin.H{ - "status": 404, - "message": "Not Found", - }) - return - } - - log.Debug().Str("provider", request.Provider).Msg("Got provider") - - // Create state - state := provider.GenerateState() - - // Get auth URL - authURL := provider.GetAuthURL(state) - - log.Debug().Msg("Got auth URL") - - // Set CSRF cookie - c.SetCookie(h.Config.CsrfCookieName, state, int(time.Hour.Seconds()), "/", "", h.Config.CookieSecure, true) - - // Get redirect URI - redirectURI := c.Query("redirect_uri") - - // Set redirect cookie if redirect URI is provided - if redirectURI != "" { - log.Debug().Str("redirectURI", redirectURI).Msg("Setting redirect cookie") - c.SetCookie(h.Config.RedirectCookieName, redirectURI, int(time.Hour.Seconds()), "/", "", h.Config.CookieSecure, true) - } - - // Return auth URL - c.JSON(200, gin.H{ - "status": 200, - "message": "OK", - "url": authURL, - }) -} - -func (h *Handlers) OAuthCallbackHandler(c *gin.Context) { - var providerName types.OAuthRequest - - err := c.BindUri(&providerName) - if err != nil { - log.Error().Err(err).Msg("Failed to bind URI") - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) - return - } - - log.Debug().Interface("provider", providerName.Provider).Msg("Got provider name") - - // Get state - state := c.Query("state") - - // Get CSRF cookie - csrfCookie, err := c.Cookie(h.Config.CsrfCookieName) - - if err != nil { - log.Debug().Msg("No CSRF cookie") - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) - return - } - - log.Debug().Str("csrfCookie", csrfCookie).Msg("Got CSRF cookie") - - // Check if CSRF cookie is valid - if csrfCookie != state { - log.Warn().Msg("Invalid CSRF cookie or CSRF cookie does not match with the state") - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) - return - } - - // Clean up CSRF cookie - c.SetCookie(h.Config.CsrfCookieName, "", -1, "/", "", h.Config.CookieSecure, true) - - // Get code - code := c.Query("code") - - log.Debug().Msg("Got code") - - // Get provider - provider := h.Providers.GetProvider(providerName.Provider) - - if provider == nil { - c.Redirect(http.StatusTemporaryRedirect, "/not-found") - return - } - - log.Debug().Str("provider", providerName.Provider).Msg("Got provider") - - // Exchange token (authenticates user) - _, err = provider.ExchangeToken(code) - if err != nil { - log.Error().Err(err).Msg("Failed to exchange token") - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) - return - } - - log.Debug().Msg("Got token") - - // Get user - user, err := h.Providers.GetUser(providerName.Provider) - if err != nil { - log.Error().Err(err).Msg("Failed to get user") - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) - return - } - - log.Debug().Interface("user", user).Msg("Got user") - - // Check that email is not empty - if user.Email == "" { - log.Error().Msg("Email is empty") - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) - return - } - - // Email is not whitelisted - if !h.Auth.EmailWhitelisted(user.Email) { - log.Warn().Str("email", user.Email).Msg("Email not whitelisted") - queries, err := query.Values(types.UnauthorizedQuery{ - Username: user.Email, - }) - - if err != nil { - log.Error().Err(err).Msg("Failed to build queries") - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) - return - } - - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", h.Config.AppURL, queries.Encode())) - } - - log.Debug().Msg("Email whitelisted") - - // Get username - var username string - - if user.PreferredUsername != "" { - username = user.PreferredUsername - } else { - username = fmt.Sprintf("%s_%s", strings.Split(user.Email, "@")[0], strings.Split(user.Email, "@")[1]) - } - - // Get name - var name string - - if user.Name != "" { - name = user.Name - } else { - name = fmt.Sprintf("%s (%s)", utils.Capitalize(strings.Split(user.Email, "@")[0]), strings.Split(user.Email, "@")[1]) - } - - // Create session cookie - h.Auth.CreateSessionCookie(c, &types.SessionCookie{ - Username: username, - Name: name, - Email: user.Email, - Provider: providerName.Provider, - OAuthGroups: utils.CoalesceToString(user.Groups), - }) - - // Check if we have a redirect URI - redirectCookie, err := c.Cookie(h.Config.RedirectCookieName) - - if err != nil { - log.Debug().Msg("No redirect cookie") - c.Redirect(http.StatusTemporaryRedirect, h.Config.AppURL) - return - } - - log.Debug().Str("redirectURI", redirectCookie).Msg("Got redirect URI") - - queries, err := query.Values(types.LoginQuery{ - RedirectURI: redirectCookie, - }) - - if err != nil { - log.Error().Err(err).Msg("Failed to build queries") - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) - return - } - - log.Debug().Msg("Got redirect query") - - // Clean up redirect cookie - c.SetCookie(h.Config.RedirectCookieName, "", -1, "/", "", h.Config.CookieSecure, true) - - // Redirect to continue with the redirect URI - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/continue?%s", h.Config.AppURL, queries.Encode())) -} diff --git a/internal/handlers/proxy.go b/internal/handlers/proxy.go deleted file mode 100644 index fd87fd1..0000000 --- a/internal/handlers/proxy.go +++ /dev/null @@ -1,282 +0,0 @@ -package handlers - -import ( - "fmt" - "net/http" - "strings" - "tinyauth/internal/types" - "tinyauth/internal/utils" - - "github.com/gin-gonic/gin" - "github.com/google/go-querystring/query" - "github.com/rs/zerolog/log" -) - -func (h *Handlers) ProxyHandler(c *gin.Context) { - var proxy types.Proxy - - err := c.BindUri(&proxy) - if err != nil { - log.Error().Err(err).Msg("Failed to bind URI") - c.JSON(400, gin.H{ - "status": 400, - "message": "Bad Request", - }) - return - } - - // Check if the request is coming from a browser (tools like curl/bruno use */* and they don't include the text/html) - isBrowser := strings.Contains(c.Request.Header.Get("Accept"), "text/html") - - if isBrowser { - log.Debug().Msg("Request is most likely coming from a browser") - } else { - log.Debug().Msg("Request is most likely not coming from a browser") - } - - log.Debug().Interface("proxy", proxy.Proxy).Msg("Got proxy") - - uri := c.Request.Header.Get("X-Forwarded-Uri") - proto := c.Request.Header.Get("X-Forwarded-Proto") - host := c.Request.Header.Get("X-Forwarded-Host") - - hostPortless := strings.Split(host, ":")[0] // *lol* - id := strings.Split(hostPortless, ".")[0] - - labels, err := h.Docker.GetLabels(id, hostPortless) - if err != nil { - log.Error().Err(err).Msg("Failed to get container labels") - - if proxy.Proxy == "nginx" || !isBrowser { - c.JSON(500, gin.H{ - "status": 500, - "message": "Internal Server Error", - }) - return - } - - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) - return - } - - log.Debug().Interface("labels", labels).Msg("Got labels") - - ip := c.ClientIP() - - if h.Auth.BypassedIP(labels, ip) { - c.Header("Authorization", c.Request.Header.Get("Authorization")) - - headersParsed := utils.ParseHeaders(labels.Headers) - for key, value := range headersParsed { - log.Debug().Str("key", key).Msg("Setting header") - c.Header(key, value) - } - - if labels.Basic.Username != "" && utils.GetSecret(labels.Basic.Password.Plain, labels.Basic.Password.File) != "" { - log.Debug().Str("username", labels.Basic.Username).Msg("Setting basic auth headers") - c.Header("Authorization", fmt.Sprintf("Basic %s", utils.GetBasicAuth(labels.Basic.Username, utils.GetSecret(labels.Basic.Password.Plain, labels.Basic.Password.File)))) - } - - c.JSON(200, gin.H{ - "status": 200, - "message": "Authenticated", - }) - return - } - - if !h.Auth.CheckIP(labels, ip) { - if proxy.Proxy == "nginx" || !isBrowser { - c.JSON(403, gin.H{ - "status": 403, - "message": "Forbidden", - }) - return - } - - values := types.UnauthorizedQuery{ - Resource: strings.Split(host, ".")[0], - IP: ip, - } - - queries, err := query.Values(values) - if err != nil { - log.Error().Err(err).Msg("Failed to build queries") - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) - return - } - - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", h.Config.AppURL, queries.Encode())) - return - } - - authEnabled, err := h.Auth.AuthEnabled(uri, labels) - if err != nil { - log.Error().Err(err).Msg("Failed to check if app is allowed") - if proxy.Proxy == "nginx" || !isBrowser { - c.JSON(500, gin.H{ - "status": 500, - "message": "Internal Server Error", - }) - return - } - - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) - return - } - - if !authEnabled { - c.Header("Authorization", c.Request.Header.Get("Authorization")) - - headersParsed := utils.ParseHeaders(labels.Headers) - for key, value := range headersParsed { - log.Debug().Str("key", key).Msg("Setting header") - c.Header(key, value) - } - - if labels.Basic.Username != "" && utils.GetSecret(labels.Basic.Password.Plain, labels.Basic.Password.File) != "" { - log.Debug().Str("username", labels.Basic.Username).Msg("Setting basic auth headers") - c.Header("Authorization", fmt.Sprintf("Basic %s", utils.GetBasicAuth(labels.Basic.Username, utils.GetSecret(labels.Basic.Password.Plain, labels.Basic.Password.File)))) - } - - c.JSON(200, gin.H{ - "status": 200, - "message": "Authenticated", - }) - - return - } - - userContext := h.Hooks.UseUserContext(c) - - // If we are using basic auth, we need to check if the user has totp and if it does then disable basic auth - if userContext.Provider == "basic" && userContext.TotpEnabled { - log.Warn().Str("username", userContext.Username).Msg("User has totp enabled, disabling basic auth") - userContext.IsLoggedIn = false - } - - if userContext.IsLoggedIn { - log.Debug().Msg("Authenticated") - - // Check if user is allowed to access subdomain, if request is nginx.example.com the subdomain (resource) is nginx - appAllowed := h.Auth.ResourceAllowed(c, userContext, labels) - - log.Debug().Bool("appAllowed", appAllowed).Msg("Checking if app is allowed") - - if !appAllowed { - log.Warn().Str("username", userContext.Username).Str("host", host).Msg("User not allowed") - - if proxy.Proxy == "nginx" || !isBrowser { - c.JSON(401, gin.H{ - "status": 401, - "message": "Unauthorized", - }) - return - } - - values := types.UnauthorizedQuery{ - Resource: strings.Split(host, ".")[0], - } - - if userContext.OAuth { - values.Username = userContext.Email - } else { - values.Username = userContext.Username - } - - queries, err := query.Values(values) - if err != nil { - log.Error().Err(err).Msg("Failed to build queries") - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) - return - } - - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", h.Config.AppURL, queries.Encode())) - return - } - - if userContext.OAuth { - groupOk := h.Auth.OAuthGroup(c, userContext, labels) - - log.Debug().Bool("groupOk", groupOk).Msg("Checking if user is in required groups") - - if !groupOk { - log.Warn().Str("username", userContext.Username).Str("host", host).Msg("User is not in required groups") - if proxy.Proxy == "nginx" || !isBrowser { - c.JSON(401, gin.H{ - "status": 401, - "message": "Unauthorized", - }) - return - } - - values := types.UnauthorizedQuery{ - Resource: strings.Split(host, ".")[0], - GroupErr: true, - } - - if userContext.OAuth { - values.Username = userContext.Email - } else { - values.Username = userContext.Username - } - - queries, err := query.Values(values) - if err != nil { - log.Error().Err(err).Msg("Failed to build queries") - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) - return - } - - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", h.Config.AppURL, queries.Encode())) - return - } - } - - c.Header("Authorization", c.Request.Header.Get("Authorization")) - c.Header("Remote-User", utils.SanitizeHeader(userContext.Username)) - c.Header("Remote-Name", utils.SanitizeHeader(userContext.Name)) - c.Header("Remote-Email", utils.SanitizeHeader(userContext.Email)) - c.Header("Remote-Groups", utils.SanitizeHeader(userContext.OAuthGroups)) - - parsedHeaders := utils.ParseHeaders(labels.Headers) - for key, value := range parsedHeaders { - log.Debug().Str("key", key).Msg("Setting header") - c.Header(key, value) - } - - if labels.Basic.Username != "" && utils.GetSecret(labels.Basic.Password.Plain, labels.Basic.Password.File) != "" { - log.Debug().Str("username", labels.Basic.Username).Msg("Setting basic auth headers") - c.Header("Authorization", fmt.Sprintf("Basic %s", utils.GetBasicAuth(labels.Basic.Username, utils.GetSecret(labels.Basic.Password.Plain, labels.Basic.Password.File)))) - } - - c.JSON(200, gin.H{ - "status": 200, - "message": "Authenticated", - }) - return - } - - // The user is not logged in - log.Debug().Msg("Unauthorized") - - if proxy.Proxy == "nginx" || !isBrowser { - c.JSON(401, gin.H{ - "status": 401, - "message": "Unauthorized", - }) - return - } - - queries, err := query.Values(types.LoginQuery{ - RedirectURI: fmt.Sprintf("%s://%s%s", proto, host, uri), - }) - - if err != nil { - log.Error().Err(err).Msg("Failed to build queries") - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) - return - } - - log.Debug().Interface("redirect_uri", fmt.Sprintf("%s://%s%s", proto, host, uri)).Msg("Redirecting to login") - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/login?%s", h.Config.AppURL, queries.Encode())) -} diff --git a/internal/handlers/user.go b/internal/handlers/user.go deleted file mode 100644 index 91d0fef..0000000 --- a/internal/handlers/user.go +++ /dev/null @@ -1,197 +0,0 @@ -package handlers - -import ( - "fmt" - "strings" - "tinyauth/internal/types" - "tinyauth/internal/utils" - - "github.com/gin-gonic/gin" - "github.com/pquerna/otp/totp" - "github.com/rs/zerolog/log" -) - -func (h *Handlers) LoginHandler(c *gin.Context) { - var login types.LoginRequest - - err := c.BindJSON(&login) - if err != nil { - log.Error().Err(err).Msg("Failed to bind JSON") - c.JSON(400, gin.H{ - "status": 400, - "message": "Bad Request", - }) - return - } - - log.Debug().Msg("Got login request") - - clientIP := c.ClientIP() - - // Create an identifier for rate limiting (username or IP if username doesn't exist yet) - rateIdentifier := login.Username - if rateIdentifier == "" { - rateIdentifier = clientIP - } - - // Check if the account is locked due to too many failed attempts - locked, remainingTime := h.Auth.IsAccountLocked(rateIdentifier) - if locked { - log.Warn().Str("identifier", rateIdentifier).Int("remaining_seconds", remainingTime).Msg("Account is locked due to too many failed login attempts") - c.JSON(429, gin.H{ - "status": 429, - "message": fmt.Sprintf("Too many failed login attempts. Try again in %d seconds", remainingTime), - }) - return - } - - // Search for a user based on username - log.Debug().Interface("username", login.Username).Msg("Searching for user") - - userSearch := h.Auth.SearchUser(login.Username) - - // User does not exist - if userSearch.Type == "" { - log.Debug().Str("username", login.Username).Msg("User not found") - // Record failed login attempt - h.Auth.RecordLoginAttempt(rateIdentifier, false) - c.JSON(401, gin.H{ - "status": 401, - "message": "Unauthorized", - }) - return - } - - log.Debug().Msg("Got user") - - // Check if password is correct - if !h.Auth.VerifyUser(userSearch, login.Password) { - log.Debug().Str("username", login.Username).Msg("Password incorrect") - // Record failed login attempt - h.Auth.RecordLoginAttempt(rateIdentifier, false) - c.JSON(401, gin.H{ - "status": 401, - "message": "Unauthorized", - }) - return - } - - log.Debug().Msg("Password correct, checking totp") - - // Record successful login attempt (will reset failed attempt counter) - h.Auth.RecordLoginAttempt(rateIdentifier, true) - - // Check if user is using TOTP - if userSearch.Type == "local" { - // Get local user - localUser := h.Auth.GetLocalUser(login.Username) - - // Check if TOTP is enabled - if localUser.TotpSecret != "" { - log.Debug().Msg("Totp enabled") - - // Set totp pending cookie - h.Auth.CreateSessionCookie(c, &types.SessionCookie{ - Username: login.Username, - Name: utils.Capitalize(login.Username), - Email: fmt.Sprintf("%s@%s", strings.ToLower(login.Username), h.Config.Domain), - Provider: "username", - TotpPending: true, - }) - - // Return totp required - c.JSON(200, gin.H{ - "status": 200, - "message": "Waiting for totp", - "totpPending": true, - }) - return - } - } - - // Create session cookie with username as provider - h.Auth.CreateSessionCookie(c, &types.SessionCookie{ - Username: login.Username, - Name: utils.Capitalize(login.Username), - Email: fmt.Sprintf("%s@%s", strings.ToLower(login.Username), h.Config.Domain), - Provider: "username", - }) - - // Return logged in - c.JSON(200, gin.H{ - "status": 200, - "message": "Logged in", - "totpPending": false, - }) -} - -func (h *Handlers) TOTPHandler(c *gin.Context) { - var totpReq types.TotpRequest - - err := c.BindJSON(&totpReq) - if err != nil { - log.Error().Err(err).Msg("Failed to bind JSON") - c.JSON(400, gin.H{ - "status": 400, - "message": "Bad Request", - }) - return - } - - log.Debug().Msg("Checking totp") - - // Get user context - userContext := h.Hooks.UseUserContext(c) - - // Check if we have a user - if userContext.Username == "" { - log.Debug().Msg("No user context") - c.JSON(401, gin.H{ - "status": 401, - "message": "Unauthorized", - }) - return - } - - // Get user - user := h.Auth.GetLocalUser(userContext.Username) - - // Check if totp is correct - ok := totp.Validate(totpReq.Code, user.TotpSecret) - - if !ok { - log.Debug().Msg("Totp incorrect") - c.JSON(401, gin.H{ - "status": 401, - "message": "Unauthorized", - }) - return - } - - log.Debug().Msg("Totp correct") - - // Create session cookie with username as provider - h.Auth.CreateSessionCookie(c, &types.SessionCookie{ - Username: user.Username, - Name: utils.Capitalize(user.Username), - Email: fmt.Sprintf("%s@%s", strings.ToLower(user.Username), h.Config.Domain), - Provider: "username", - }) - - // Return logged in - c.JSON(200, gin.H{ - "status": 200, - "message": "Logged in", - }) -} - -func (h *Handlers) LogoutHandler(c *gin.Context) { - log.Debug().Msg("Cleaning up redirect cookie") - - h.Auth.DeleteSessionCookie(c) - - c.JSON(200, gin.H{ - "status": 200, - "message": "Logged out", - }) -} diff --git a/internal/hooks/hooks.go b/internal/hooks/hooks.go deleted file mode 100644 index 3083b98..0000000 --- a/internal/hooks/hooks.go +++ /dev/null @@ -1,144 +0,0 @@ -package hooks - -import ( - "fmt" - "strings" - "tinyauth/internal/auth" - "tinyauth/internal/oauth" - "tinyauth/internal/providers" - "tinyauth/internal/types" - "tinyauth/internal/utils" - - "github.com/gin-gonic/gin" - "github.com/rs/zerolog/log" -) - -type Hooks struct { - Config types.HooksConfig - Auth *auth.Auth - Providers *providers.Providers -} - -func NewHooks(config types.HooksConfig, auth *auth.Auth, providers *providers.Providers) *Hooks { - return &Hooks{ - Config: config, - Auth: auth, - Providers: providers, - } -} - -func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext { - cookie, err := hooks.Auth.GetSessionCookie(c) - var provider *oauth.OAuth - - if err != nil { - log.Error().Err(err).Msg("Failed to get session cookie") - goto basic - } - - if cookie.TotpPending { - log.Debug().Msg("Totp pending") - return types.UserContext{ - Username: cookie.Username, - Name: cookie.Name, - Email: cookie.Email, - Provider: cookie.Provider, - TotpPending: true, - } - } - - if cookie.Provider == "username" { - log.Debug().Msg("Provider is username") - - userSearch := hooks.Auth.SearchUser(cookie.Username) - - if userSearch.Type == "unknown" { - log.Warn().Str("username", cookie.Username).Msg("User does not exist") - goto basic - } - - log.Debug().Str("type", userSearch.Type).Msg("User exists") - - return types.UserContext{ - Username: cookie.Username, - Name: cookie.Name, - Email: cookie.Email, - IsLoggedIn: true, - Provider: "username", - } - } - - log.Debug().Msg("Provider is not username") - - provider = hooks.Providers.GetProvider(cookie.Provider) - - if provider != nil { - log.Debug().Msg("Provider exists") - - if !hooks.Auth.EmailWhitelisted(cookie.Email) { - log.Warn().Str("email", cookie.Email).Msg("Email is not whitelisted") - hooks.Auth.DeleteSessionCookie(c) - goto basic - } - - log.Debug().Msg("Email is whitelisted") - - return types.UserContext{ - Username: cookie.Username, - Name: cookie.Name, - Email: cookie.Email, - IsLoggedIn: true, - OAuth: true, - Provider: cookie.Provider, - OAuthGroups: cookie.OAuthGroups, - } - } - -basic: - log.Debug().Msg("Trying basic auth") - - basic := hooks.Auth.GetBasicAuth(c) - - if basic != nil { - log.Debug().Msg("Got basic auth") - - userSearch := hooks.Auth.SearchUser(basic.Username) - - if userSearch.Type == "unkown" { - log.Error().Str("username", basic.Username).Msg("Basic auth user does not exist") - return types.UserContext{} - } - - if !hooks.Auth.VerifyUser(userSearch, basic.Password) { - log.Error().Str("username", basic.Username).Msg("Basic auth user password incorrect") - return types.UserContext{} - } - - if userSearch.Type == "ldap" { - log.Debug().Msg("User is LDAP") - - return types.UserContext{ - Username: basic.Username, - Name: utils.Capitalize(basic.Username), - Email: fmt.Sprintf("%s@%s", strings.ToLower(basic.Username), hooks.Config.Domain), - IsLoggedIn: true, - Provider: "basic", - TotpEnabled: false, - } - } - - user := hooks.Auth.GetLocalUser(basic.Username) - - return types.UserContext{ - Username: basic.Username, - Name: utils.Capitalize(basic.Username), - Email: fmt.Sprintf("%s@%s", strings.ToLower(basic.Username), hooks.Config.Domain), - IsLoggedIn: true, - Provider: "basic", - TotpEnabled: user.TotpSecret != "", - } - - } - - return types.UserContext{} -} diff --git a/internal/middleware/context_middleware.go b/internal/middleware/context_middleware.go new file mode 100644 index 0000000..58e53e1 --- /dev/null +++ b/internal/middleware/context_middleware.go @@ -0,0 +1,159 @@ +package middleware + +import ( + "fmt" + "strings" + "tinyauth/internal/config" + "tinyauth/internal/service" + "tinyauth/internal/utils" + + "github.com/gin-gonic/gin" + "github.com/rs/zerolog/log" +) + +type ContextMiddlewareConfig struct { + Domain string +} + +type ContextMiddleware struct { + Config ContextMiddlewareConfig + Auth *service.AuthService + Broker *service.OAuthBrokerService +} + +func NewContextMiddleware(config ContextMiddlewareConfig, auth *service.AuthService, broker *service.OAuthBrokerService) *ContextMiddleware { + return &ContextMiddleware{ + Config: config, + Auth: auth, + Broker: broker, + } +} + +func (m *ContextMiddleware) Init() error { + return nil +} + +func (m *ContextMiddleware) Middleware() gin.HandlerFunc { + return func(c *gin.Context) { + cookie, err := m.Auth.GetSessionCookie(c) + + if err != nil { + log.Debug().Err(err).Msg("No valid session cookie found") + goto basic + } + + if cookie.TotpPending { + c.Set("context", &config.UserContext{ + Username: cookie.Username, + Name: cookie.Name, + Email: cookie.Email, + Provider: "username", + TotpPending: true, + TotpEnabled: true, + }) + c.Next() + return + } + + switch cookie.Provider { + case "username": + userSearch := m.Auth.SearchUser(cookie.Username) + + if userSearch.Type == "unknown" { + log.Debug().Msg("User from session cookie not found") + m.Auth.DeleteSessionCookie(c) + goto basic + } + + c.Set("context", &config.UserContext{ + Username: cookie.Username, + Name: cookie.Name, + Email: cookie.Email, + Provider: "username", + IsLoggedIn: true, + }) + c.Next() + return + default: + _, exists := m.Broker.GetService(cookie.Provider) + + if !exists { + log.Debug().Msg("OAuth provider from session cookie not found") + m.Auth.DeleteSessionCookie(c) + goto basic + } + + if !m.Auth.EmailWhitelisted(cookie.Email) { + log.Debug().Msg("Email from session cookie not whitelisted") + m.Auth.DeleteSessionCookie(c) + goto basic + } + + c.Set("context", &config.UserContext{ + Username: cookie.Username, + Name: cookie.Name, + Email: cookie.Email, + Provider: cookie.Provider, + OAuthGroups: cookie.OAuthGroups, + IsLoggedIn: true, + OAuth: true, + }) + c.Next() + return + } + + basic: + basic := m.Auth.GetBasicAuth(c) + + if basic == nil { + log.Debug().Msg("No basic auth provided") + c.Next() + return + } + + userSearch := m.Auth.SearchUser(basic.Username) + + if userSearch.Type == "unknown" { + log.Debug().Msg("User from basic auth not found") + c.Next() + return + } + + if !m.Auth.VerifyUser(userSearch, basic.Password) { + log.Debug().Msg("Invalid password for basic auth user") + c.Next() + return + } + + switch userSearch.Type { + case "local": + log.Debug().Msg("Basic auth user is local") + + user := m.Auth.GetLocalUser(basic.Username) + + c.Set("context", &config.UserContext{ + Username: user.Username, + Name: utils.Capitalize(user.Username), + Email: fmt.Sprintf("%s@%s", strings.ToLower(user.Username), m.Config.Domain), + Provider: "basic", + IsLoggedIn: true, + TotpEnabled: user.TotpSecret != "", + }) + c.Next() + return + case "ldap": + log.Debug().Msg("Basic auth user is LDAP") + c.Set("context", &config.UserContext{ + Username: basic.Username, + Name: utils.Capitalize(basic.Username), + Email: fmt.Sprintf("%s@%s", strings.ToLower(basic.Username), m.Config.Domain), + Provider: "basic", + IsLoggedIn: true, + }) + c.Next() + return + } + + c.Next() + } +} diff --git a/internal/middleware/ui_middleware.go b/internal/middleware/ui_middleware.go new file mode 100644 index 0000000..dcfaa35 --- /dev/null +++ b/internal/middleware/ui_middleware.go @@ -0,0 +1,56 @@ +package middleware + +import ( + "io/fs" + "net/http" + "os" + "strings" + "tinyauth/internal/assets" + + "github.com/gin-gonic/gin" +) + +type UIMiddleware struct { + UIFS fs.FS + UIFileServer http.Handler +} + +func NewUIMiddleware() *UIMiddleware { + return &UIMiddleware{} +} + +func (m *UIMiddleware) Init() error { + ui, err := fs.Sub(assets.FrontendAssets, "dist") + + if err != nil { + return err + } + + m.UIFS = ui + m.UIFileServer = http.FileServer(http.FS(ui)) + + return nil +} + +func (m *UIMiddleware) Middleware() gin.HandlerFunc { + return func(c *gin.Context) { + switch strings.Split(c.Request.URL.Path, "/")[1] { + case "api": + c.Next() + return + case "resources": + c.Next() + return + default: + _, err := fs.Stat(m.UIFS, strings.TrimPrefix(c.Request.URL.Path, "/")) + + if os.IsNotExist(err) { + c.Request.URL.Path = "/" + } + + m.UIFileServer.ServeHTTP(c.Writer, c.Request) + c.Abort() + return + } + } +} diff --git a/internal/middleware/zerolog_middleware.go b/internal/middleware/zerolog_middleware.go new file mode 100644 index 0000000..877ad4c --- /dev/null +++ b/internal/middleware/zerolog_middleware.go @@ -0,0 +1,66 @@ +package middleware + +import ( + "strings" + "time" + + "github.com/gin-gonic/gin" + "github.com/rs/zerolog/log" +) + +var ( + loggerSkipPathsPrefix = []string{ + "GET /api/health", + "HEAD /api/health", + "GET /favicon.ico", + } +) + +type ZerologMiddleware struct{} + +func NewZerologMiddleware() *ZerologMiddleware { + return &ZerologMiddleware{} +} + +func (m *ZerologMiddleware) Init() error { + return nil +} + +func (m *ZerologMiddleware) logPath(path string) bool { + for _, prefix := range loggerSkipPathsPrefix { + if strings.HasPrefix(path, prefix) { + return false + } + } + return true +} + +func (m *ZerologMiddleware) Middleware() gin.HandlerFunc { + return func(c *gin.Context) { + tStart := time.Now() + + c.Next() + + code := c.Writer.Status() + address := c.Request.RemoteAddr + clientIP := c.ClientIP() + method := c.Request.Method + path := c.Request.URL.Path + + latency := time.Since(tStart).String() + + // logPath check if the path should be logged normally or with debug + if m.logPath(method + " " + path) { + switch { + case code >= 200 && code < 300: + log.Info().Str("method", method).Str("path", path).Str("address", address).Str("clientIp", clientIP).Int("status", code).Str("latency", latency).Msg("Request") + case code >= 300 && code < 400: + log.Warn().Str("method", method).Str("path", path).Str("address", address).Str("clientIp", clientIP).Int("status", code).Str("latency", latency).Msg("Request") + case code >= 400: + log.Error().Str("method", method).Str("path", path).Str("address", address).Str("clientIp", clientIP).Int("status", code).Str("latency", latency).Msg("Request") + } + } else { + log.Debug().Str("method", method).Str("path", path).Str("address", address).Int("status", code).Str("latency", latency).Msg("Request") + } + } +} diff --git a/internal/oauth/oauth.go b/internal/oauth/oauth.go deleted file mode 100644 index 9529fce..0000000 --- a/internal/oauth/oauth.go +++ /dev/null @@ -1,71 +0,0 @@ -package oauth - -import ( - "context" - "crypto/rand" - "crypto/tls" - "encoding/base64" - "net/http" - - "golang.org/x/oauth2" -) - -type OAuth struct { - Config oauth2.Config - Context context.Context - Token *oauth2.Token - Verifier string -} - -func NewOAuth(config oauth2.Config, insecureSkipVerify bool) *OAuth { - transport := &http.Transport{ - TLSClientConfig: &tls.Config{ - InsecureSkipVerify: insecureSkipVerify, - MinVersion: tls.VersionTLS12, - }, - } - - httpClient := &http.Client{ - Transport: transport, - } - - ctx := context.Background() - - // Set the HTTP client in the context - ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient) - - verifier := oauth2.GenerateVerifier() - - return &OAuth{ - Config: config, - Context: ctx, - Verifier: verifier, - } -} - -func (oauth *OAuth) GetAuthURL(state string) string { - return oauth.Config.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(oauth.Verifier)) -} - -func (oauth *OAuth) ExchangeToken(code string) (string, error) { - token, err := oauth.Config.Exchange(oauth.Context, code, oauth2.VerifierOption(oauth.Verifier)) - - if err != nil { - return "", err - } - - // Set and return the token - oauth.Token = token - return oauth.Token.AccessToken, nil -} - -func (oauth *OAuth) GetClient() *http.Client { - return oauth.Config.Client(oauth.Context, oauth.Token) -} - -func (oauth *OAuth) GenerateState() string { - b := make([]byte, 128) - rand.Read(b) - state := base64.URLEncoding.EncodeToString(b) - return state -} diff --git a/internal/providers/generic.go b/internal/providers/generic.go deleted file mode 100644 index 200f7c4..0000000 --- a/internal/providers/generic.go +++ /dev/null @@ -1,37 +0,0 @@ -package providers - -import ( - "encoding/json" - "io" - "net/http" - "tinyauth/internal/constants" - - "github.com/rs/zerolog/log" -) - -func GetGenericUser(client *http.Client, url string) (constants.Claims, error) { - var user constants.Claims - - res, err := client.Get(url) - if err != nil { - return user, err - } - defer res.Body.Close() - - log.Debug().Msg("Got response from generic provider") - - body, err := io.ReadAll(res.Body) - if err != nil { - return user, err - } - - log.Debug().Msg("Read body from generic provider") - - err = json.Unmarshal(body, &user) - if err != nil { - return user, err - } - - log.Debug().Msg("Parsed user from generic provider") - return user, nil -} diff --git a/internal/providers/github.go b/internal/providers/github.go deleted file mode 100644 index 67f8510..0000000 --- a/internal/providers/github.go +++ /dev/null @@ -1,102 +0,0 @@ -package providers - -import ( - "encoding/json" - "errors" - "io" - "net/http" - "tinyauth/internal/constants" - - "github.com/rs/zerolog/log" -) - -// Response for the github email endpoint -type GithubEmailResponse []struct { - Email string `json:"email"` - Primary bool `json:"primary"` -} - -// Response for the github user endpoint -type GithubUserInfoResponse struct { - Login string `json:"login"` - Name string `json:"name"` -} - -// The scopes required for the github provider -func GithubScopes() []string { - return []string{"user:email", "read:user"} -} - -func GetGithubUser(client *http.Client) (constants.Claims, error) { - var user constants.Claims - - res, err := client.Get("https://api.github.com/user") - if err != nil { - return user, err - } - defer res.Body.Close() - - log.Debug().Msg("Got user response from github") - - body, err := io.ReadAll(res.Body) - if err != nil { - return user, err - } - - log.Debug().Msg("Read user body from github") - - var userInfo GithubUserInfoResponse - - err = json.Unmarshal(body, &userInfo) - if err != nil { - return user, err - } - - res, err = client.Get("https://api.github.com/user/emails") - if err != nil { - return user, err - } - defer res.Body.Close() - - log.Debug().Msg("Got email response from github") - - body, err = io.ReadAll(res.Body) - if err != nil { - return user, err - } - - log.Debug().Msg("Read email body from github") - - var emails GithubEmailResponse - - err = json.Unmarshal(body, &emails) - if err != nil { - return user, err - } - - log.Debug().Msg("Parsed emails from github") - - // Find and return the primary email - for _, email := range emails { - if email.Primary { - log.Debug().Str("email", email.Email).Msg("Found primary email") - user.Email = email.Email - break - } - } - - if len(emails) == 0 { - return user, errors.New("no emails found") - } - - // Use first available email if no primary email was found - if user.Email == "" { - log.Warn().Str("email", emails[0].Email).Msg("No primary email found, using first email") - user.Email = emails[0].Email - } - - user.PreferredUsername = userInfo.Login - user.Name = userInfo.Name - - return user, nil -} diff --git a/internal/providers/google.go b/internal/providers/google.go deleted file mode 100644 index e794bee..0000000 --- a/internal/providers/google.go +++ /dev/null @@ -1,56 +0,0 @@ -package providers - -import ( - "encoding/json" - "io" - "net/http" - "strings" - "tinyauth/internal/constants" - - "github.com/rs/zerolog/log" -) - -// Response for the google user endpoint -type GoogleUserInfoResponse struct { - Email string `json:"email"` - Name string `json:"name"` -} - -// The scopes required for the google provider -func GoogleScopes() []string { - return []string{"https://www.googleapis.com/auth/userinfo.email", "https://www.googleapis.com/auth/userinfo.profile"} -} - -func GetGoogleUser(client *http.Client) (constants.Claims, error) { - var user constants.Claims - - res, err := client.Get("https://www.googleapis.com/userinfo/v2/me") - if err != nil { - return user, err - } - defer res.Body.Close() - - log.Debug().Msg("Got response from google") - - body, err := io.ReadAll(res.Body) - if err != nil { - return user, err - } - - log.Debug().Msg("Read body from google") - - var userInfo GoogleUserInfoResponse - - err = json.Unmarshal(body, &userInfo) - if err != nil { - return user, err - } - - log.Debug().Msg("Parsed user from google") - - user.PreferredUsername = strings.Split(userInfo.Email, "@")[0] - user.Name = userInfo.Name - user.Email = userInfo.Email - - return user, nil -} diff --git a/internal/providers/providers.go b/internal/providers/providers.go deleted file mode 100644 index 7af127e..0000000 --- a/internal/providers/providers.go +++ /dev/null @@ -1,154 +0,0 @@ -package providers - -import ( - "fmt" - "tinyauth/internal/constants" - "tinyauth/internal/oauth" - "tinyauth/internal/types" - - "github.com/rs/zerolog/log" - "golang.org/x/oauth2" - "golang.org/x/oauth2/endpoints" -) - -type Providers struct { - Config types.OAuthConfig - Github *oauth.OAuth - Google *oauth.OAuth - Generic *oauth.OAuth -} - -func NewProviders(config types.OAuthConfig) *Providers { - providers := &Providers{ - Config: config, - } - - if config.GithubClientId != "" && config.GithubClientSecret != "" { - log.Info().Msg("Initializing Github OAuth") - providers.Github = oauth.NewOAuth(oauth2.Config{ - ClientID: config.GithubClientId, - ClientSecret: config.GithubClientSecret, - RedirectURL: fmt.Sprintf("%s/api/oauth/callback/github", config.AppURL), - Scopes: GithubScopes(), - Endpoint: endpoints.GitHub, - }, false) - } - - if config.GoogleClientId != "" && config.GoogleClientSecret != "" { - log.Info().Msg("Initializing Google OAuth") - providers.Google = oauth.NewOAuth(oauth2.Config{ - ClientID: config.GoogleClientId, - ClientSecret: config.GoogleClientSecret, - RedirectURL: fmt.Sprintf("%s/api/oauth/callback/google", config.AppURL), - Scopes: GoogleScopes(), - Endpoint: endpoints.Google, - }, false) - } - - if config.GenericClientId != "" && config.GenericClientSecret != "" { - log.Info().Msg("Initializing Generic OAuth") - providers.Generic = oauth.NewOAuth(oauth2.Config{ - ClientID: config.GenericClientId, - ClientSecret: config.GenericClientSecret, - RedirectURL: fmt.Sprintf("%s/api/oauth/callback/generic", config.AppURL), - Scopes: config.GenericScopes, - Endpoint: oauth2.Endpoint{ - AuthURL: config.GenericAuthURL, - TokenURL: config.GenericTokenURL, - }, - }, config.GenericSkipSSL) - } - - return providers -} - -func (providers *Providers) GetProvider(provider string) *oauth.OAuth { - switch provider { - case "github": - return providers.Github - case "google": - return providers.Google - case "generic": - return providers.Generic - default: - return nil - } -} - -func (providers *Providers) GetUser(provider string) (constants.Claims, error) { - var user constants.Claims - - // Get the user from the provider - switch provider { - case "github": - if providers.Github == nil { - log.Debug().Msg("Github provider not configured") - return user, nil - } - - client := providers.Github.GetClient() - - log.Debug().Msg("Got client from github") - - user, err := GetGithubUser(client) - if err != nil { - return user, err - } - - log.Debug().Msg("Got user from github") - - return user, nil - case "google": - if providers.Google == nil { - log.Debug().Msg("Google provider not configured") - return user, nil - } - - client := providers.Google.GetClient() - - log.Debug().Msg("Got client from google") - - user, err := GetGoogleUser(client) - if err != nil { - return user, err - } - - log.Debug().Msg("Got user from google") - - return user, nil - case "generic": - if providers.Generic == nil { - log.Debug().Msg("Generic provider not configured") - return user, nil - } - - client := providers.Generic.GetClient() - - log.Debug().Msg("Got client from generic") - - user, err := GetGenericUser(client, providers.Config.GenericUserURL) - if err != nil { - return user, err - } - - log.Debug().Msg("Got user from generic") - - return user, nil - default: - return user, nil - } -} - -func (provider *Providers) GetConfiguredProviders() []string { - providers := []string{} - if provider.Github != nil { - providers = append(providers, "github") - } - if provider.Google != nil { - providers = append(providers, "google") - } - if provider.Generic != nil { - providers = append(providers, "generic") - } - return providers -} diff --git a/internal/server/server.go b/internal/server/server.go deleted file mode 100644 index 8826032..0000000 --- a/internal/server/server.go +++ /dev/null @@ -1,130 +0,0 @@ -package server - -import ( - "fmt" - "io/fs" - "net/http" - "os" - "strings" - "time" - "tinyauth/internal/assets" - "tinyauth/internal/handlers" - "tinyauth/internal/types" - - "github.com/gin-gonic/gin" - "github.com/rs/zerolog/log" -) - -type Server struct { - Config types.ServerConfig - Handlers *handlers.Handlers - Router *gin.Engine -} - -var ( - loggerSkipPathsPrefix = []string{ - "GET /api/healthcheck", - "HEAD /api/healthcheck", - "GET /favicon.ico", - } -) - -func logPath(path string) bool { - for _, prefix := range loggerSkipPathsPrefix { - if strings.HasPrefix(path, prefix) { - return false - } - } - return true -} - -func NewServer(config types.ServerConfig, handlers *handlers.Handlers) (*Server, error) { - gin.SetMode(gin.ReleaseMode) - - log.Debug().Msg("Setting up router") - router := gin.New() - router.Use(zerolog()) - - log.Debug().Msg("Setting up assets") - dist, err := fs.Sub(assets.Assets, "dist") - if err != nil { - return nil, err - } - - log.Debug().Msg("Setting up file server") - fileServer := http.FileServer(http.FS(dist)) - - // UI middleware - router.Use(func(c *gin.Context) { - // If not an API request, serve the UI - if !strings.HasPrefix(c.Request.URL.Path, "/api") { - _, err := fs.Stat(dist, strings.TrimPrefix(c.Request.URL.Path, "/")) - if os.IsNotExist(err) { - c.Request.URL.Path = "/" - } - fileServer.ServeHTTP(c.Writer, c.Request) - c.Abort() - } - }) - - // Proxy routes - router.GET("/api/auth/:proxy", handlers.ProxyHandler) - - // Auth routes - router.POST("/api/login", handlers.LoginHandler) - router.POST("/api/totp", handlers.TOTPHandler) - router.POST("/api/logout", handlers.LogoutHandler) - - // Context routes - router.GET("/api/app", handlers.AppContextHandler) - router.GET("/api/user", handlers.UserContextHandler) - - // OAuth routes - router.GET("/api/oauth/url/:provider", handlers.OAuthURLHandler) - router.GET("/api/oauth/callback/:provider", handlers.OAuthCallbackHandler) - - // App routes - router.GET("/api/healthcheck", handlers.HealthcheckHandler) - router.HEAD("/api/healthcheck", handlers.HealthcheckHandler) - - return &Server{ - Config: config, - Handlers: handlers, - Router: router, - }, nil -} - -func (s *Server) Start() error { - log.Info().Str("address", s.Config.Address).Int("port", s.Config.Port).Msg("Starting server") - return s.Router.Run(fmt.Sprintf("%s:%d", s.Config.Address, s.Config.Port)) -} - -// zerolog is a middleware for gin that logs requests using zerolog -func zerolog() gin.HandlerFunc { - return func(c *gin.Context) { - tStart := time.Now() - - c.Next() - - code := c.Writer.Status() - address := c.Request.RemoteAddr - method := c.Request.Method - path := c.Request.URL.Path - - latency := time.Since(tStart).String() - - // logPath check if the path should be logged normally or with debug - if logPath(method + " " + path) { - switch { - case code >= 200 && code < 300: - log.Info().Str("method", method).Str("path", path).Str("address", address).Int("status", code).Str("latency", latency).Msg("Request") - case code >= 300 && code < 400: - log.Warn().Str("method", method).Str("path", path).Str("address", address).Int("status", code).Str("latency", latency).Msg("Request") - case code >= 400: - log.Error().Str("method", method).Str("path", path).Str("address", address).Int("status", code).Str("latency", latency).Msg("Request") - } - } else { - log.Debug().Str("method", method).Str("path", path).Str("address", address).Int("status", code).Str("latency", latency).Msg("Request") - } - } -} diff --git a/internal/auth/auth.go b/internal/service/auth_service.go similarity index 62% rename from internal/auth/auth.go rename to internal/service/auth_service.go index 3f18419..10d49e7 100644 --- a/internal/auth/auth.go +++ b/internal/service/auth_service.go @@ -1,4 +1,4 @@ -package auth +package service import ( "fmt" @@ -6,9 +6,7 @@ import ( "strings" "sync" "time" - "tinyauth/internal/docker" - "tinyauth/internal/ldap" - "tinyauth/internal/types" + "tinyauth/internal/config" "tinyauth/internal/utils" "github.com/gin-gonic/gin" @@ -17,44 +15,66 @@ import ( "golang.org/x/crypto/bcrypt" ) -type Auth struct { - Config types.AuthConfig - Docker *docker.Docker - LoginAttempts map[string]*types.LoginAttempt - LoginMutex sync.RWMutex - Store *sessions.CookieStore - LDAP *ldap.LDAP +type LoginAttempt struct { + FailedAttempts int + LastAttempt time.Time + LockedUntil time.Time } -func NewAuth(config types.AuthConfig, docker *docker.Docker, ldap *ldap.LDAP) *Auth { - // Setup cookie store and create the auth service - store := sessions.NewCookieStore([]byte(config.HMACSecret), []byte(config.EncryptionSecret)) - store.Options = &sessions.Options{ - Path: "/", - MaxAge: config.SessionExpiry, - Secure: config.CookieSecure, - HttpOnly: true, - Domain: fmt.Sprintf(".%s", config.Domain), - } - return &Auth{ +type AuthServiceConfig struct { + Users []config.User + OauthWhitelist string + SessionExpiry int + SecureCookie bool + Domain string + LoginTimeout int + LoginMaxRetries int + SessionCookieName string + HMACSecret string + EncryptionSecret string +} + +type AuthService struct { + Config AuthServiceConfig + Docker *DockerService + LoginAttempts map[string]*LoginAttempt + LoginMutex sync.RWMutex + Store *sessions.CookieStore + LDAP *LdapService +} + +func NewAuthService(config AuthServiceConfig, docker *DockerService, ldap *LdapService) *AuthService { + return &AuthService{ Config: config, Docker: docker, - LoginAttempts: make(map[string]*types.LoginAttempt), - Store: store, + LoginAttempts: make(map[string]*LoginAttempt), LDAP: ldap, } } -func (auth *Auth) GetSession(c *gin.Context) (*sessions.Session, error) { +func (auth *AuthService) Init() error { + store := sessions.NewCookieStore([]byte(auth.Config.HMACSecret), []byte(auth.Config.EncryptionSecret)) + store.Options = &sessions.Options{ + Path: "/", + MaxAge: auth.Config.SessionExpiry, + Secure: auth.Config.SecureCookie, + HttpOnly: true, + Domain: fmt.Sprintf(".%s", auth.Config.Domain), + } + + auth.Store = store + return nil +} + +func (auth *AuthService) GetSession(c *gin.Context) (*sessions.Session, error) { session, err := auth.Store.Get(c.Request, auth.Config.SessionCookieName) // If there was an error getting the session, it might be invalid so let's clear it and retry if err != nil { - log.Error().Err(err).Msg("Invalid session, clearing cookie and retrying") - c.SetCookie(auth.Config.SessionCookieName, "", -1, "/", fmt.Sprintf(".%s", auth.Config.Domain), auth.Config.CookieSecure, true) - session, err = auth.Store.Get(c.Request, auth.Config.SessionCookieName) + log.Debug().Err(err).Msg("Error getting session, creating a new one") + c.SetCookie(auth.Config.SessionCookieName, "", -1, "/", fmt.Sprintf(".%s", auth.Config.Domain), auth.Config.SecureCookie, true) + session, err = auth.Store.New(c.Request, auth.Config.SessionCookieName) if err != nil { - log.Error().Err(err).Msg("Failed to get session") return nil, err } } @@ -62,95 +82,79 @@ func (auth *Auth) GetSession(c *gin.Context) (*sessions.Session, error) { return session, nil } -func (auth *Auth) SearchUser(username string) types.UserSearch { - log.Debug().Str("username", username).Msg("Searching for user") - - // Check local users first +func (auth *AuthService) SearchUser(username string) config.UserSearch { if auth.GetLocalUser(username).Username != "" { - log.Debug().Str("username", username).Msg("Found local user") - return types.UserSearch{ + return config.UserSearch{ Username: username, Type: "local", } } - // If no user found, check LDAP if auth.LDAP != nil { - log.Debug().Str("username", username).Msg("Checking LDAP for user") userDN, err := auth.LDAP.Search(username) + if err != nil { - log.Warn().Err(err).Str("username", username).Msg("Failed to find user in LDAP") - return types.UserSearch{} + log.Warn().Err(err).Str("username", username).Msg("Failed to search for user in LDAP") + return config.UserSearch{} } - return types.UserSearch{ + + return config.UserSearch{ Username: userDN, Type: "ldap", } } - return types.UserSearch{ + return config.UserSearch{ Type: "unknown", } } -func (auth *Auth) VerifyUser(search types.UserSearch, password string) bool { - // Authenticate the user based on the type +func (auth *AuthService) VerifyUser(search config.UserSearch, password string) bool { switch search.Type { case "local": - // If local user, get the user and check the password user := auth.GetLocalUser(search.Username) return auth.CheckPassword(user, password) case "ldap": - // If LDAP is configured, bind to the LDAP server with the user DN and password if auth.LDAP != nil { - log.Debug().Str("username", search.Username).Msg("Binding to LDAP for user authentication") - err := auth.LDAP.Bind(search.Username, password) if err != nil { log.Warn().Err(err).Str("username", search.Username).Msg("Failed to bind to LDAP") return false } - // Rebind with the service account to reset the connection err = auth.LDAP.Bind(auth.LDAP.Config.BindDN, auth.LDAP.Config.BindPassword) if err != nil { log.Error().Err(err).Msg("Failed to rebind with service account after user authentication") return false } - log.Debug().Str("username", search.Username).Msg("LDAP authentication successful") return true } default: - log.Warn().Str("type", search.Type).Msg("Unknown user type for authentication") + log.Debug().Str("type", search.Type).Msg("Unknown user type for authentication") return false } - // If no user found or authentication failed, return false log.Warn().Str("username", search.Username).Msg("User authentication failed") return false } -func (auth *Auth) GetLocalUser(username string) types.User { - // Loop through users and return the user if the username matches - log.Debug().Str("username", username).Msg("Searching for local user") - +func (auth *AuthService) GetLocalUser(username string) config.User { for _, user := range auth.Config.Users { if user.Username == username { return user } } - // If no user found, return an empty user log.Warn().Str("username", username).Msg("Local user not found") - return types.User{} + return config.User{} } -func (auth *Auth) CheckPassword(user types.User, password string) bool { +func (auth *AuthService) CheckPassword(user config.User, password string) bool { return bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)) == nil } -func (auth *Auth) IsAccountLocked(identifier string) (bool, int) { +func (auth *AuthService) IsAccountLocked(identifier string) (bool, int) { auth.LoginMutex.RLock() defer auth.LoginMutex.RUnlock() @@ -176,7 +180,7 @@ func (auth *Auth) IsAccountLocked(identifier string) (bool, int) { return false, 0 } -func (auth *Auth) RecordLoginAttempt(identifier string, success bool) { +func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) { // Skip if rate limiting is not configured if auth.Config.LoginMaxRetries <= 0 || auth.Config.LoginTimeout <= 0 { return @@ -188,7 +192,7 @@ func (auth *Auth) RecordLoginAttempt(identifier string, success bool) { // Get current attempt record or create a new one attempt, exists := auth.LoginAttempts[identifier] if !exists { - attempt = &types.LoginAttempt{} + attempt = &LoginAttempt{} auth.LoginAttempts[identifier] = attempt } @@ -212,21 +216,16 @@ func (auth *Auth) RecordLoginAttempt(identifier string, success bool) { } } -func (auth *Auth) EmailWhitelisted(email string) bool { +func (auth *AuthService) EmailWhitelisted(email string) bool { return utils.CheckFilter(auth.Config.OauthWhitelist, email) } -func (auth *Auth) CreateSessionCookie(c *gin.Context, data *types.SessionCookie) error { - log.Debug().Msg("Creating session cookie") - +func (auth *AuthService) CreateSessionCookie(c *gin.Context, data *config.SessionCookie) error { session, err := auth.GetSession(c) if err != nil { - log.Error().Err(err).Msg("Failed to get session") return err } - log.Debug().Msg("Setting session cookie") - var sessionExpiry int if data.TotpPending { @@ -245,19 +244,15 @@ func (auth *Auth) CreateSessionCookie(c *gin.Context, data *types.SessionCookie) err = session.Save(c.Request, c.Writer) if err != nil { - log.Error().Err(err).Msg("Failed to save session") return err } return nil } -func (auth *Auth) DeleteSessionCookie(c *gin.Context) error { - log.Debug().Msg("Deleting session cookie") - +func (auth *AuthService) DeleteSessionCookie(c *gin.Context) error { session, err := auth.GetSession(c) if err != nil { - log.Error().Err(err).Msg("Failed to get session") return err } @@ -268,24 +263,21 @@ func (auth *Auth) DeleteSessionCookie(c *gin.Context) error { err = session.Save(c.Request, c.Writer) if err != nil { - log.Error().Err(err).Msg("Failed to save session") return err } + // Clear the cookie in the browser + c.SetCookie(auth.Config.SessionCookieName, "", -1, "/", fmt.Sprintf(".%s", auth.Config.Domain), auth.Config.SecureCookie, true) + return nil } -func (auth *Auth) GetSessionCookie(c *gin.Context) (types.SessionCookie, error) { - log.Debug().Msg("Getting session cookie") - +func (auth *AuthService) GetSessionCookie(c *gin.Context) (config.SessionCookie, error) { session, err := auth.GetSession(c) if err != nil { - log.Error().Err(err).Msg("Failed to get session") - return types.SessionCookie{}, err + return config.SessionCookie{}, err } - log.Debug().Msg("Got session") - username, usernameOk := session.Values["username"].(string) email, emailOk := session.Values["email"].(string) name, nameOk := session.Values["name"].(string) @@ -298,18 +290,17 @@ func (auth *Auth) GetSessionCookie(c *gin.Context) (types.SessionCookie, error) if !usernameOk || !providerOK || !expiryOk || !totpPendingOk || !emailOk || !nameOk || !oauthGroupsOk { log.Warn().Msg("Session cookie is invalid") auth.DeleteSessionCookie(c) - return types.SessionCookie{}, nil + return config.SessionCookie{}, nil } // If the session cookie has expired, delete it if time.Now().Unix() > expiry { log.Warn().Msg("Session cookie expired") auth.DeleteSessionCookie(c) - return types.SessionCookie{}, nil + return config.SessionCookie{}, nil } - log.Debug().Str("username", username).Str("provider", provider).Int64("expiry", expiry).Bool("totpPending", totpPending).Str("name", name).Str("email", email).Str("oauthGroups", oauthGroups).Msg("Parsed cookie") - return types.SessionCookie{ + return config.SessionCookie{ Username: username, Name: name, Email: email, @@ -319,12 +310,12 @@ func (auth *Auth) GetSessionCookie(c *gin.Context) (types.SessionCookie, error) }, nil } -func (auth *Auth) UserAuthConfigured() bool { +func (auth *AuthService) UserAuthConfigured() bool { // If there are users or LDAP is configured, return true return len(auth.Config.Users) > 0 || auth.LDAP != nil } -func (auth *Auth) ResourceAllowed(c *gin.Context, context types.UserContext, labels types.Labels) bool { +func (auth *AuthService) ResourceAllowed(c *gin.Context, context config.UserContext, labels config.Labels) bool { if context.OAuth { log.Debug().Msg("Checking OAuth whitelist") return utils.CheckFilter(labels.OAuth.Whitelist, context.Email) @@ -334,12 +325,11 @@ func (auth *Auth) ResourceAllowed(c *gin.Context, context types.UserContext, lab return utils.CheckFilter(labels.Users, context.Username) } -func (auth *Auth) OAuthGroup(c *gin.Context, context types.UserContext, labels types.Labels) bool { +func (auth *AuthService) OAuthGroup(c *gin.Context, context config.UserContext, labels config.Labels) bool { if labels.OAuth.Groups == "" { return true } - // Check if we are using the generic oauth provider if context.Provider != "generic" { log.Debug().Msg("Not using generic provider, skipping group check") return true @@ -351,7 +341,6 @@ func (auth *Auth) OAuthGroup(c *gin.Context, context types.UserContext, labels t // For every group check if it is in the required groups for _, group := range oauthGroups { if utils.CheckFilter(labels.OAuth.Groups, group) { - log.Debug().Str("group", group).Msg("Group is in required groups") return true } } @@ -361,18 +350,15 @@ func (auth *Auth) OAuthGroup(c *gin.Context, context types.UserContext, labels t return false } -func (auth *Auth) AuthEnabled(uri string, labels types.Labels) (bool, error) { +func (auth *AuthService) AuthEnabled(uri string, labels config.Labels) (bool, error) { // If the label is empty, auth is enabled if labels.Allowed == "" { return true, nil } - // Compile regex regex, err := regexp.Compile(labels.Allowed) - // If there is an error, invalid regex, auth enabled if err != nil { - log.Error().Err(err).Msg("Invalid regex") return true, err } @@ -385,27 +371,28 @@ func (auth *Auth) AuthEnabled(uri string, labels types.Labels) (bool, error) { return true, nil } -func (auth *Auth) GetBasicAuth(c *gin.Context) *types.User { +func (auth *AuthService) GetBasicAuth(c *gin.Context) *config.User { username, password, ok := c.Request.BasicAuth() if !ok { + log.Debug().Msg("No basic auth provided") return nil } - return &types.User{ + return &config.User{ Username: username, Password: password, } } -func (auth *Auth) CheckIP(labels types.Labels, ip string) bool { +func (auth *AuthService) CheckIP(labels config.Labels, ip string) bool { // Check if the IP is in block list for _, blocked := range labels.IP.Block { res, err := utils.FilterIP(blocked, ip) if err != nil { - log.Error().Err(err).Str("item", blocked).Msg("Invalid IP/CIDR in block list") + log.Warn().Err(err).Str("item", blocked).Msg("Invalid IP/CIDR in block list") continue } if res { - log.Warn().Str("ip", ip).Str("item", blocked).Msg("IP is in blocked list, denying access") + log.Debug().Str("ip", ip).Str("item", blocked).Msg("IP is in blocked list, denying access") return false } } @@ -414,7 +401,7 @@ func (auth *Auth) CheckIP(labels types.Labels, ip string) bool { for _, allowed := range labels.IP.Allow { res, err := utils.FilterIP(allowed, ip) if err != nil { - log.Error().Err(err).Str("item", allowed).Msg("Invalid IP/CIDR in allow list") + log.Warn().Err(err).Str("item", allowed).Msg("Invalid IP/CIDR in allow list") continue } if res { @@ -425,7 +412,7 @@ func (auth *Auth) CheckIP(labels types.Labels, ip string) bool { // If not in allowed range and allowed range is not empty, deny access if len(labels.IP.Allow) > 0 { - log.Warn().Str("ip", ip).Msg("IP not in allow list, denying access") + log.Debug().Str("ip", ip).Msg("IP not in allow list, denying access") return false } @@ -433,12 +420,12 @@ func (auth *Auth) CheckIP(labels types.Labels, ip string) bool { return true } -func (auth *Auth) BypassedIP(labels types.Labels, ip string) bool { +func (auth *AuthService) BypassedIP(labels config.Labels, ip string) bool { // For every IP in the bypass list, check if the IP matches for _, bypassed := range labels.IP.Bypass { res, err := utils.FilterIP(bypassed, ip) if err != nil { - log.Error().Err(err).Str("item", bypassed).Msg("Invalid IP/CIDR in bypass list") + log.Warn().Err(err).Str("item", bypassed).Msg("Invalid IP/CIDR in bypass list") continue } if res { diff --git a/internal/docker/docker.go b/internal/service/docker_service.go similarity index 64% rename from internal/docker/docker.go rename to internal/service/docker_service.go index f5a0468..41eb07c 100644 --- a/internal/docker/docker.go +++ b/internal/service/docker_service.go @@ -1,37 +1,42 @@ -package docker +package service import ( "context" "strings" - "tinyauth/internal/types" + "tinyauth/internal/config" "tinyauth/internal/utils" + "slices" + container "github.com/docker/docker/api/types/container" "github.com/docker/docker/client" "github.com/rs/zerolog/log" ) -type Docker struct { +type DockerService struct { Client *client.Client Context context.Context } -func NewDocker() (*Docker, error) { +func NewDockerService() *DockerService { + return &DockerService{} +} + +func (docker *DockerService) Init() error { client, err := client.NewClientWithOpts(client.FromEnv) if err != nil { - return nil, err + return err } ctx := context.Background() client.NegotiateAPIVersion(ctx) - return &Docker{ - Client: client, - Context: ctx, - }, nil + docker.Client = client + docker.Context = ctx + return nil } -func (docker *Docker) GetContainers() ([]container.Summary, error) { +func (docker *DockerService) GetContainers() ([]container.Summary, error) { containers, err := docker.Client.ContainerList(docker.Context, container.ListOptions{}) if err != nil { return nil, err @@ -39,7 +44,7 @@ func (docker *Docker) GetContainers() ([]container.Summary, error) { return containers, nil } -func (docker *Docker) InspectContainer(containerId string) (container.InspectResponse, error) { +func (docker *DockerService) InspectContainer(containerId string) (container.InspectResponse, error) { inspect, err := docker.Client.ContainerInspect(docker.Context, containerId) if err != nil { return container.InspectResponse{}, err @@ -47,25 +52,22 @@ func (docker *Docker) InspectContainer(containerId string) (container.InspectRes return inspect, nil } -func (docker *Docker) DockerConnected() bool { +func (docker *DockerService) DockerConnected() bool { _, err := docker.Client.Ping(docker.Context) return err == nil } -func (docker *Docker) GetLabels(app string, domain string) (types.Labels, error) { +func (docker *DockerService) GetLabels(app string, domain string) (config.Labels, error) { isConnected := docker.DockerConnected() if !isConnected { log.Debug().Msg("Docker not connected, returning empty labels") - return types.Labels{}, nil + return config.Labels{}, nil } - log.Debug().Msg("Getting containers") - containers, err := docker.GetContainers() if err != nil { - log.Error().Err(err).Msg("Error getting containers") - return types.Labels{}, err + return config.Labels{}, err } for _, container := range containers { @@ -75,8 +77,6 @@ func (docker *Docker) GetLabels(app string, domain string) (types.Labels, error) continue } - log.Debug().Str("id", inspect.ID).Msg("Getting labels for container") - labels, err := utils.GetLabels(inspect.Config.Labels) if err != nil { log.Warn().Str("id", container.ID).Err(err).Msg("Error getting container labels, skipping") @@ -84,11 +84,9 @@ func (docker *Docker) GetLabels(app string, domain string) (types.Labels, error) } // Check if the container matches the ID or domain - for _, lDomain := range labels.Domain { - if lDomain == domain { - log.Debug().Str("id", inspect.ID).Msg("Found matching container by domain") - return labels, nil - } + if slices.Contains(labels.Domain, domain) { + log.Debug().Str("id", inspect.ID).Msg("Found matching container by domain") + return labels, nil } if strings.TrimPrefix(inspect.Name, "/") == app { @@ -98,5 +96,5 @@ func (docker *Docker) GetLabels(app string, domain string) (types.Labels, error) } log.Debug().Msg("No matching container found, returning empty labels") - return types.Labels{}, nil + return config.Labels{}, nil } diff --git a/internal/service/generic_oauth_service.go b/internal/service/generic_oauth_service.go new file mode 100644 index 0000000..c16384d --- /dev/null +++ b/internal/service/generic_oauth_service.go @@ -0,0 +1,117 @@ +package service + +import ( + "context" + "crypto/rand" + "crypto/tls" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "net/http" + "time" + "tinyauth/internal/config" + + "golang.org/x/oauth2" +) + +type GenericOAuthService struct { + Config oauth2.Config + Context context.Context + Token *oauth2.Token + Verifier string + InsecureSkipVerify bool + UserinfoURL string +} + +func NewGenericOAuthService(config config.OAuthServiceConfig) *GenericOAuthService { + return &GenericOAuthService{ + Config: oauth2.Config{ + ClientID: config.ClientID, + ClientSecret: config.ClientSecret, + RedirectURL: config.RedirectURL, + Scopes: config.Scopes, + Endpoint: oauth2.Endpoint{ + AuthURL: config.AuthURL, + TokenURL: config.TokenURL, + }, + }, + InsecureSkipVerify: config.InsecureSkipVerify, + UserinfoURL: config.UserinfoURL, + } +} + +func (generic *GenericOAuthService) Init() error { + transport := &http.Transport{ + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: generic.InsecureSkipVerify, + MinVersion: tls.VersionTLS12, + }, + } + + httpClient := &http.Client{ + Transport: transport, + } + + ctx := context.Background() + + ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient) + verifier := oauth2.GenerateVerifier() + + generic.Context = ctx + generic.Verifier = verifier + return nil +} + +func (generic *GenericOAuthService) GenerateState() string { + b := make([]byte, 128) + _, err := rand.Read(b) + if err != nil { + return base64.RawURLEncoding.EncodeToString(fmt.Appendf(nil, "state-%d", time.Now().UnixNano())) + } + state := base64.RawURLEncoding.EncodeToString(b) + return state +} + +func (generic *GenericOAuthService) GetAuthURL(state string) string { + return generic.Config.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(generic.Verifier)) +} + +func (generic *GenericOAuthService) VerifyCode(code string) error { + token, err := generic.Config.Exchange(generic.Context, code, oauth2.VerifierOption(generic.Verifier)) + + if err != nil { + return err + } + + generic.Token = token + return nil +} + +func (generic *GenericOAuthService) Userinfo() (config.Claims, error) { + var user config.Claims + + client := generic.Config.Client(generic.Context, generic.Token) + + res, err := client.Get(generic.UserinfoURL) + if err != nil { + return user, err + } + defer res.Body.Close() + + if res.StatusCode < 200 || res.StatusCode >= 300 { + return user, fmt.Errorf("request failed with status: %s", res.Status) + } + + body, err := io.ReadAll(res.Body) + if err != nil { + return user, err + } + + err = json.Unmarshal(body, &user) + if err != nil { + return user, err + } + + return user, nil +} diff --git a/internal/service/github_oauth_service.go b/internal/service/github_oauth_service.go new file mode 100644 index 0000000..7f8466b --- /dev/null +++ b/internal/service/github_oauth_service.go @@ -0,0 +1,169 @@ +package service + +import ( + "context" + "crypto/rand" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "time" + "tinyauth/internal/config" + + "golang.org/x/oauth2" + "golang.org/x/oauth2/endpoints" +) + +var GithubOAuthScopes = []string{"user:email", "read:user"} + +type GithubEmailResponse []struct { + Email string `json:"email"` + Primary bool `json:"primary"` +} + +type GithubUserInfoResponse struct { + Login string `json:"login"` + Name string `json:"name"` +} + +type GithubOAuthService struct { + Config oauth2.Config + Context context.Context + Token *oauth2.Token + Verifier string +} + +func NewGithubOAuthService(config config.OAuthServiceConfig) *GithubOAuthService { + return &GithubOAuthService{ + Config: oauth2.Config{ + ClientID: config.ClientID, + ClientSecret: config.ClientSecret, + RedirectURL: config.RedirectURL, + Scopes: GithubOAuthScopes, + Endpoint: endpoints.GitHub, + }, + } +} + +func (github *GithubOAuthService) Init() error { + httpClient := &http.Client{} + ctx := context.Background() + ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient) + verifier := oauth2.GenerateVerifier() + + github.Context = ctx + github.Verifier = verifier + return nil +} + +func (github *GithubOAuthService) GenerateState() string { + b := make([]byte, 128) + _, err := rand.Read(b) + if err != nil { + return base64.RawURLEncoding.EncodeToString(fmt.Appendf(nil, "state-%d", time.Now().UnixNano())) + } + state := base64.RawURLEncoding.EncodeToString(b) + return state +} + +func (github *GithubOAuthService) GetAuthURL(state string) string { + return github.Config.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(github.Verifier)) +} + +func (github *GithubOAuthService) VerifyCode(code string) error { + token, err := github.Config.Exchange(github.Context, code, oauth2.VerifierOption(github.Verifier)) + + if err != nil { + return err + } + + github.Token = token + return nil +} + +func (github *GithubOAuthService) Userinfo() (config.Claims, error) { + var user config.Claims + + client := github.Config.Client(github.Context, github.Token) + + req, err := http.NewRequest("GET", "https://api.github.com/user", nil) + if err != nil { + return user, err + } + + req.Header.Set("Accept", "application/vnd.github+json") + + res, err := client.Do(req) + if err != nil { + return user, err + } + defer res.Body.Close() + + if res.StatusCode < 200 || res.StatusCode >= 300 { + return user, fmt.Errorf("request failed with status: %s", res.Status) + } + + body, err := io.ReadAll(res.Body) + if err != nil { + return user, err + } + + var userInfo GithubUserInfoResponse + + err = json.Unmarshal(body, &userInfo) + if err != nil { + return user, err + } + + req, err = http.NewRequest("GET", "https://api.github.com/user/emails", nil) + if err != nil { + return user, err + } + + req.Header.Set("Accept", "application/vnd.github+json") + + res, err = client.Do(req) + if err != nil { + return user, err + } + defer res.Body.Close() + + if res.StatusCode < 200 || res.StatusCode >= 300 { + return user, fmt.Errorf("request failed with status: %s", res.Status) + } + + body, err = io.ReadAll(res.Body) + if err != nil { + return user, err + } + + var emails GithubEmailResponse + + err = json.Unmarshal(body, &emails) + if err != nil { + return user, err + } + + for _, email := range emails { + if email.Primary { + user.Email = email.Email + break + } + } + + if len(emails) == 0 { + return user, errors.New("no emails found") + } + + // Use first available email if no primary email was found + if user.Email == "" { + user.Email = emails[0].Email + } + + user.PreferredUsername = userInfo.Login + user.Name = userInfo.Name + + return user, nil +} diff --git a/internal/service/google_oauth_service.go b/internal/service/google_oauth_service.go new file mode 100644 index 0000000..1605a85 --- /dev/null +++ b/internal/service/google_oauth_service.go @@ -0,0 +1,113 @@ +package service + +import ( + "context" + "crypto/rand" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "time" + "tinyauth/internal/config" + + "golang.org/x/oauth2" + "golang.org/x/oauth2/endpoints" +) + +var GoogleOAuthScopes = []string{"https://www.googleapis.com/auth/userinfo.email", "https://www.googleapis.com/auth/userinfo.profile"} + +type GoogleUserInfoResponse struct { + Email string `json:"email"` + Name string `json:"name"` +} + +type GoogleOAuthService struct { + Config oauth2.Config + Context context.Context + Token *oauth2.Token + Verifier string +} + +func NewGoogleOAuthService(config config.OAuthServiceConfig) *GoogleOAuthService { + return &GoogleOAuthService{ + Config: oauth2.Config{ + ClientID: config.ClientID, + ClientSecret: config.ClientSecret, + RedirectURL: config.RedirectURL, + Scopes: GoogleOAuthScopes, + Endpoint: endpoints.Google, + }, + } +} + +func (google *GoogleOAuthService) Init() error { + httpClient := &http.Client{} + ctx := context.Background() + ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient) + verifier := oauth2.GenerateVerifier() + + google.Context = ctx + google.Verifier = verifier + return nil +} + +func (oauth *GoogleOAuthService) GenerateState() string { + b := make([]byte, 128) + _, err := rand.Read(b) + if err != nil { + return base64.RawURLEncoding.EncodeToString(fmt.Appendf(nil, "state-%d", time.Now().UnixNano())) + } + state := base64.RawURLEncoding.EncodeToString(b) + return state +} + +func (google *GoogleOAuthService) GetAuthURL(state string) string { + return google.Config.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(google.Verifier)) +} + +func (google *GoogleOAuthService) VerifyCode(code string) error { + token, err := google.Config.Exchange(google.Context, code, oauth2.VerifierOption(google.Verifier)) + + if err != nil { + return err + } + + google.Token = token + return nil +} + +func (google *GoogleOAuthService) Userinfo() (config.Claims, error) { + var user config.Claims + + client := google.Config.Client(google.Context, google.Token) + + res, err := client.Get("https://www.googleapis.com/userinfo/v2/me") + if err != nil { + return config.Claims{}, err + } + defer res.Body.Close() + + if res.StatusCode < 200 || res.StatusCode >= 300 { + return user, fmt.Errorf("request failed with status: %s", res.Status) + } + + body, err := io.ReadAll(res.Body) + if err != nil { + return config.Claims{}, err + } + + var userInfo GoogleUserInfoResponse + + err = json.Unmarshal(body, &userInfo) + if err != nil { + return config.Claims{}, err + } + + user.PreferredUsername = strings.Split(userInfo.Email, "@")[0] + user.Name = userInfo.Name + user.Email = userInfo.Email + + return user, nil +} diff --git a/internal/ldap/ldap.go b/internal/service/ldap_service.go similarity index 61% rename from internal/ldap/ldap.go rename to internal/service/ldap_service.go index 61578d7..8576c4d 100644 --- a/internal/ldap/ldap.go +++ b/internal/service/ldap_service.go @@ -1,30 +1,40 @@ -package ldap +package service import ( "context" "crypto/tls" "fmt" "time" - "tinyauth/internal/types" "github.com/cenkalti/backoff/v5" ldapgo "github.com/go-ldap/ldap/v3" "github.com/rs/zerolog/log" ) -type LDAP struct { - Config types.LdapConfig +type LdapServiceConfig struct { + Address string + BindDN string + BindPassword string + BaseDN string + Insecure bool + SearchFilter string +} + +type LdapService struct { + Config LdapServiceConfig Conn *ldapgo.Conn } -func NewLDAP(config types.LdapConfig) (*LDAP, error) { - ldap := &LDAP{ +func NewLdapService(config LdapServiceConfig) *LdapService { + return &LdapService{ Config: config, } +} +func (ldap *LdapService) Init() error { _, err := ldap.connect() if err != nil { - return nil, fmt.Errorf("failed to connect to LDAP server: %w", err) + return fmt.Errorf("failed to connect to LDAP server: %w", err) } go func() { @@ -41,65 +51,63 @@ func NewLDAP(config types.LdapConfig) (*LDAP, error) { } }() - return ldap, nil + return nil } -func (l *LDAP) connect() (*ldapgo.Conn, error) { - log.Debug().Msg("Connecting to LDAP server") - conn, err := ldapgo.DialURL(l.Config.Address, ldapgo.DialWithTLSConfig(&tls.Config{ - InsecureSkipVerify: l.Config.Insecure, +func (ldap *LdapService) connect() (*ldapgo.Conn, error) { + conn, err := ldapgo.DialURL(ldap.Config.Address, ldapgo.DialWithTLSConfig(&tls.Config{ + InsecureSkipVerify: ldap.Config.Insecure, MinVersion: tls.VersionTLS12, })) if err != nil { return nil, err } - log.Debug().Msg("Binding to LDAP server") - err = conn.Bind(l.Config.BindDN, l.Config.BindPassword) + err = conn.Bind(ldap.Config.BindDN, ldap.Config.BindPassword) if err != nil { return nil, err } // Set and return the connection - l.Conn = conn + ldap.Conn = conn return conn, nil } -func (l *LDAP) Search(username string) (string, error) { +func (ldap *LdapService) Search(username string) (string, error) { // Escape the username to prevent LDAP injection escapedUsername := ldapgo.EscapeFilter(username) - filter := fmt.Sprintf(l.Config.SearchFilter, escapedUsername) + filter := fmt.Sprintf(ldap.Config.SearchFilter, escapedUsername) searchRequest := ldapgo.NewSearchRequest( - l.Config.BaseDN, + ldap.Config.BaseDN, ldapgo.ScopeWholeSubtree, ldapgo.NeverDerefAliases, 0, 0, false, filter, []string{"dn"}, nil, ) - searchResult, err := l.Conn.Search(searchRequest) + searchResult, err := ldap.Conn.Search(searchRequest) if err != nil { return "", err } if len(searchResult.Entries) != 1 { - return "", fmt.Errorf("err multiple or no entries found for user %s", username) + return "", fmt.Errorf("multiple or no entries found for user %s", username) } userDN := searchResult.Entries[0].DN return userDN, nil } -func (l *LDAP) Bind(userDN string, password string) error { - err := l.Conn.Bind(userDN, password) +func (ldap *LdapService) Bind(userDN string, password string) error { + err := ldap.Conn.Bind(userDN, password) if err != nil { return err } return nil } -func (l *LDAP) heartbeat() error { +func (ldap *LdapService) heartbeat() error { log.Debug().Msg("Performing LDAP connection heartbeat") searchRequest := ldapgo.NewSearchRequest( @@ -110,7 +118,7 @@ func (l *LDAP) heartbeat() error { nil, ) - _, err := l.Conn.Search(searchRequest) + _, err := ldap.Conn.Search(searchRequest) if err != nil { return err } @@ -119,7 +127,7 @@ func (l *LDAP) heartbeat() error { return nil } -func (l *LDAP) reconnect() error { +func (ldap *LdapService) reconnect() error { log.Info().Msg("Reconnecting to LDAP server") exp := backoff.NewExponentialBackOff() @@ -129,10 +137,10 @@ func (l *LDAP) reconnect() error { exp.Reset() operation := func() (*ldapgo.Conn, error) { - l.Conn.Close() - conn, err := l.connect() + ldap.Conn.Close() + conn, err := ldap.connect() if err != nil { - return nil, nil + return nil, err } return conn, nil } diff --git a/internal/service/oauth_broker_service.go b/internal/service/oauth_broker_service.go new file mode 100644 index 0000000..6b5b1e6 --- /dev/null +++ b/internal/service/oauth_broker_service.go @@ -0,0 +1,76 @@ +package service + +import ( + "errors" + "tinyauth/internal/config" + + "github.com/rs/zerolog/log" +) + +type OAuthService interface { + Init() error + GenerateState() string + GetAuthURL(state string) string + VerifyCode(code string) error + Userinfo() (config.Claims, error) +} + +type OAuthBrokerService struct { + Services map[string]OAuthService + Configs map[string]config.OAuthServiceConfig +} + +func NewOAuthBrokerService(configs map[string]config.OAuthServiceConfig) *OAuthBrokerService { + return &OAuthBrokerService{ + Services: make(map[string]OAuthService), + Configs: configs, + } +} + +func (broker *OAuthBrokerService) Init() error { + for name, cfg := range broker.Configs { + switch name { + case "github": + service := NewGithubOAuthService(cfg) + broker.Services[name] = service + case "google": + service := NewGoogleOAuthService(cfg) + broker.Services[name] = service + default: + service := NewGenericOAuthService(cfg) + broker.Services[name] = service + } + } + + for name, service := range broker.Services { + err := service.Init() + if err != nil { + log.Error().Err(err).Msgf("Failed to initialize OAuth service: %s", name) + return err + } + log.Info().Msgf("Initialized OAuth service: %s", name) + } + + return nil +} + +func (broker *OAuthBrokerService) GetConfiguredServices() []string { + services := make([]string, 0, len(broker.Services)) + for name := range broker.Services { + services = append(services, name) + } + return services +} + +func (broker *OAuthBrokerService) GetService(name string) (OAuthService, bool) { + service, exists := broker.Services[name] + return service, exists +} + +func (broker *OAuthBrokerService) GetUser(service string) (config.Claims, error) { + oauthService, exists := broker.Services[service] + if !exists { + return config.Claims{}, errors.New("oauth service not found") + } + return oauthService.Userinfo() +} diff --git a/internal/types/api.go b/internal/types/api.go deleted file mode 100644 index fbf8bf7..0000000 --- a/internal/types/api.go +++ /dev/null @@ -1,62 +0,0 @@ -package types - -// LoginQuery is the query parameters for the login endpoint -type LoginQuery struct { - RedirectURI string `url:"redirect_uri"` -} - -// LoginRequest is the request body for the login endpoint -type LoginRequest struct { - Username string `json:"username"` - Password string `json:"password"` -} - -// OAuthRequest is the request for the OAuth endpoint -type OAuthRequest struct { - Provider string `uri:"provider" binding:"required"` -} - -// UnauthorizedQuery is the query parameters for the unauthorized endpoint -type UnauthorizedQuery struct { - Username string `url:"username"` - Resource string `url:"resource"` - GroupErr bool `url:"groupErr"` - IP string `url:"ip"` -} - -// Proxy is the uri parameters for the proxy endpoint -type Proxy struct { - Proxy string `uri:"proxy" binding:"required"` -} - -// User Context response is the response for the user context endpoint -type UserContextResponse struct { - Status int `json:"status"` - Message string `json:"message"` - IsLoggedIn bool `json:"isLoggedIn"` - Username string `json:"username"` - Name string `json:"name"` - Email string `json:"email"` - Provider string `json:"provider"` - Oauth bool `json:"oauth"` - TotpPending bool `json:"totpPending"` -} - -// App Context is the response for the app context endpoint -type AppContext struct { - Status int `json:"status"` - Message string `json:"message"` - ConfiguredProviders []string `json:"configuredProviders"` - DisableContinue bool `json:"disableContinue"` - Title string `json:"title"` - GenericName string `json:"genericName"` - Domain string `json:"domain"` - ForgotPasswordMessage string `json:"forgotPasswordMessage"` - BackgroundImage string `json:"backgroundImage"` - OAuthAutoRedirect string `json:"oauthAutoRedirect"` -} - -// Totp request is the request for the totp endpoint -type TotpRequest struct { - Code string `json:"code"` -} diff --git a/internal/types/types.go b/internal/types/types.go deleted file mode 100644 index 2c40ae5..0000000 --- a/internal/types/types.go +++ /dev/null @@ -1,59 +0,0 @@ -package types - -import ( - "time" - "tinyauth/internal/oauth" -) - -// User is the struct for a user -type User struct { - Username string - Password string - TotpSecret string -} - -// UserSearch is the response of the get user -type UserSearch struct { - Username string - Type string // "local", "ldap" or empty -} - -// Users is a list of users -type Users []User - -// OAuthProviders is the struct for the OAuth providers -type OAuthProviders struct { - Github *oauth.OAuth - Google *oauth.OAuth - Microsoft *oauth.OAuth -} - -// SessionCookie is the cookie for the session (exculding the expiry) -type SessionCookie struct { - Username string - Name string - Email string - Provider string - TotpPending bool - OAuthGroups string -} - -// UserContext is the context for the user -type UserContext struct { - Username string - Name string - Email string - IsLoggedIn bool - OAuth bool - Provider string - TotpPending bool - OAuthGroups string - TotpEnabled bool -} - -// LoginAttempt tracks information about login attempts for rate limiting -type LoginAttempt struct { - FailedAttempts int - LastAttempt time.Time - LockedUntil time.Time -} diff --git a/internal/utils/app_utils.go b/internal/utils/app_utils.go new file mode 100644 index 0000000..85a8754 --- /dev/null +++ b/internal/utils/app_utils.go @@ -0,0 +1,123 @@ +package utils + +import ( + "errors" + "net" + "net/url" + "strings" + "tinyauth/internal/config" + + "github.com/gin-gonic/gin" + + "github.com/rs/zerolog" +) + +// Get upper domain parses a hostname and returns the upper domain (e.g. sub1.sub2.domain.com -> sub2.domain.com) +func GetUpperDomain(appUrl string) (string, error) { + appUrlParsed, err := url.Parse(appUrl) + if err != nil { + return "", err + } + + host := appUrlParsed.Hostname() + + if netIP := net.ParseIP(host); netIP != nil { + return "", errors.New("IP addresses are not allowed") + } + + urlParts := strings.Split(host, ".") + + if len(urlParts) < 2 { + return "", errors.New("invalid domain, must be at least second level domain") + } + + return strings.Join(urlParts[1:], "."), nil +} + +func ParseFileToLine(content string) string { + lines := strings.Split(content, "\n") + users := make([]string, 0) + + for _, line := range lines { + if strings.TrimSpace(line) == "" { + continue + } + users = append(users, strings.TrimSpace(line)) + } + + return strings.Join(users, ",") +} + +func Filter[T any](slice []T, test func(T) bool) (res []T) { + for _, value := range slice { + if test(value) { + res = append(res, value) + } + } + return res +} + +func GetContext(c *gin.Context) (config.UserContext, error) { + userContextValue, exists := c.Get("context") + + if !exists { + return config.UserContext{}, errors.New("no user context in request") + } + + userContext, ok := userContextValue.(*config.UserContext) + + if !ok { + return config.UserContext{}, errors.New("invalid user context in request") + } + + return *userContext, nil +} + +func IsRedirectSafe(redirectURL string, domain string) bool { + if redirectURL == "" { + return false + } + + parsedURL, err := url.Parse(redirectURL) + + if err != nil { + return false + } + + if !parsedURL.IsAbs() { + return false + } + + upper, err := GetUpperDomain(redirectURL) + + if err != nil { + return false + } + + if upper != domain { + return false + } + + return true +} + +func GetLogLevel(level string) zerolog.Level { + switch strings.ToLower(level) { + case "trace": + return zerolog.TraceLevel + case "debug": + return zerolog.DebugLevel + case "info": + return zerolog.InfoLevel + case "warn": + return zerolog.WarnLevel + case "error": + return zerolog.ErrorLevel + case "fatal": + return zerolog.FatalLevel + case "panic": + return zerolog.PanicLevel + default: + return zerolog.InfoLevel + } +} diff --git a/internal/utils/fs_utils.go b/internal/utils/fs_utils.go new file mode 100644 index 0000000..8b9f28b --- /dev/null +++ b/internal/utils/fs_utils.go @@ -0,0 +1,17 @@ +package utils + +import "os" + +func ReadFile(file string) (string, error) { + _, err := os.Stat(file) + if err != nil { + return "", err + } + + data, err := os.ReadFile(file) + if err != nil { + return "", err + } + + return string(data), nil +} diff --git a/internal/utils/label_utils.go b/internal/utils/label_utils.go new file mode 100644 index 0000000..f10092d --- /dev/null +++ b/internal/utils/label_utils.go @@ -0,0 +1,48 @@ +package utils + +import ( + "net/http" + "strings" + "tinyauth/internal/config" + + "github.com/traefik/paerser/parser" +) + +func GetLabels(labels map[string]string) (config.Labels, error) { + var labelsParsed config.Labels + + err := parser.Decode(labels, &labelsParsed, "tinyauth", "tinyauth.users", "tinyauth.allowed", "tinyauth.headers", "tinyauth.domain", "tinyauth.basic", "tinyauth.oauth", "tinyauth.ip") + if err != nil { + return config.Labels{}, err + } + + return labelsParsed, nil +} + +func ParseHeaders(headers []string) map[string]string { + headerMap := make(map[string]string) + for _, header := range headers { + split := strings.SplitN(header, "=", 2) + if len(split) != 2 || strings.TrimSpace(split[0]) == "" || strings.TrimSpace(split[1]) == "" { + continue + } + key := SanitizeHeader(strings.TrimSpace(split[0])) + if strings.ContainsAny(key, " \t") { + continue + } + key = http.CanonicalHeaderKey(key) + value := SanitizeHeader(strings.TrimSpace(split[1])) + headerMap[key] = value + } + return headerMap +} + +func SanitizeHeader(header string) string { + return strings.Map(func(r rune) rune { + // Allow only printable ASCII characters (32-126) and safe whitespace (space, tab) + if r == ' ' || r == '\t' || (r >= 32 && r <= 126) { + return r + } + return -1 + }, header) +} diff --git a/internal/utils/security_utils.go b/internal/utils/security_utils.go new file mode 100644 index 0000000..a031900 --- /dev/null +++ b/internal/utils/security_utils.go @@ -0,0 +1,124 @@ +package utils + +import ( + "bytes" + "crypto/sha256" + "encoding/base64" + "errors" + "io" + "net" + "regexp" + "strings" + + "github.com/google/uuid" + "golang.org/x/crypto/hkdf" +) + +func GetSecret(conf string, file string) string { + if conf == "" && file == "" { + return "" + } + + if conf != "" { + return conf + } + + contents, err := ReadFile(file) + if err != nil { + return "" + } + + return ParseSecretFile(contents) +} + +func ParseSecretFile(contents string) string { + lines := strings.Split(contents, "\n") + + for _, line := range lines { + if strings.TrimSpace(line) == "" { + continue + } + return strings.TrimSpace(line) + } + + return "" +} + +func GetBasicAuth(username string, password string) string { + auth := username + ":" + password + return base64.StdEncoding.EncodeToString([]byte(auth)) +} + +func DeriveKey(secret string, info string) (string, error) { + hash := sha256.New + hkdf := hkdf.New(hash, []byte(secret), nil, []byte(info)) // I am not using a salt because I just want two different keys from one secret, maybe bad practice + key := make([]byte, 24) + + _, err := io.ReadFull(hkdf, key) + if err != nil { + return "", err + } + + if bytes.Equal(key, make([]byte, 24)) { + return "", errors.New("derived key is empty") + } + + encodedKey := base64.StdEncoding.EncodeToString(key) + return encodedKey, nil +} + +func FilterIP(filter string, ip string) (bool, error) { + ipAddr := net.ParseIP(ip) + + if strings.Contains(filter, "/") { + _, cidr, err := net.ParseCIDR(filter) + if err != nil { + return false, err + } + return cidr.Contains(ipAddr), nil + } + + ipFilter := net.ParseIP(filter) + if ipFilter == nil { + return false, errors.New("invalid IP address in filter") + } + + if ipFilter.Equal(ipAddr) { + return true, nil + } + + return false, nil +} + +func CheckFilter(filter string, str string) bool { + if len(strings.TrimSpace(filter)) == 0 { + return true + } + + if strings.HasPrefix(filter, "/") && strings.HasSuffix(filter, "/") { + re, err := regexp.Compile(filter[1 : len(filter)-1]) + if err != nil { + return false + } + + if re.MatchString(strings.TrimSpace(str)) { + return true + } + } + + filterSplit := strings.Split(filter, ",") + + for _, item := range filterSplit { + if strings.TrimSpace(item) == strings.TrimSpace(str) { + return true + } + } + + return false +} + +func GenerateIdentifier(str string) string { + uuid := uuid.NewSHA1(uuid.NameSpaceURL, []byte(str)) + uuidString := uuid.String() + return strings.Split(uuidString, "-")[0] +} diff --git a/internal/utils/string_utils.go b/internal/utils/string_utils.go new file mode 100644 index 0000000..8a629ad --- /dev/null +++ b/internal/utils/string_utils.go @@ -0,0 +1,30 @@ +package utils + +import ( + "strings" +) + +func Capitalize(str string) string { + if len(str) == 0 { + return "" + } + return strings.ToUpper(string([]rune(str)[0])) + string([]rune(str)[1:]) +} + +func CoalesceToString(value any) string { + switch v := value.(type) { + case []any: + strs := make([]string, 0, len(v)) + for _, item := range v { + if str, ok := item.(string); ok { + strs = append(strs, str) + continue + } + } + return strings.Join(strs, ",") + case string: + return v + default: + return "" + } +} diff --git a/internal/utils/user_utils.go b/internal/utils/user_utils.go new file mode 100644 index 0000000..0044db4 --- /dev/null +++ b/internal/utils/user_utils.go @@ -0,0 +1,92 @@ +package utils + +import ( + "errors" + "strings" + "tinyauth/internal/config" +) + +func ParseUsers(users string) ([]config.User, error) { + var usersParsed []config.User + + users = strings.TrimSpace(users) + + if users == "" { + return []config.User{}, nil + } + + userList := strings.Split(users, ",") + + if len(userList) == 0 { + return []config.User{}, errors.New("invalid user format") + } + + for _, user := range userList { + if strings.TrimSpace(user) == "" { + continue + } + parsed, err := ParseUser(strings.TrimSpace(user)) + if err != nil { + return []config.User{}, err + } + usersParsed = append(usersParsed, parsed) + } + + return usersParsed, nil +} + +func GetUsers(conf string, file string) ([]config.User, error) { + var users string + + if conf == "" && file == "" { + return []config.User{}, nil + } + + if conf != "" { + users += conf + } + + if file != "" { + contents, err := ReadFile(file) + if err != nil { + return []config.User{}, err + } + if users != "" { + users += "," + } + users += ParseFileToLine(contents) + } + + return ParseUsers(users) +} + +func ParseUser(user string) (config.User, error) { + if strings.Contains(user, "$$") { + user = strings.ReplaceAll(user, "$$", "$") + } + + userSplit := strings.Split(user, ":") + + if len(userSplit) < 2 || len(userSplit) > 3 { + return config.User{}, errors.New("invalid user format") + } + + for _, userPart := range userSplit { + if strings.TrimSpace(userPart) == "" { + return config.User{}, errors.New("invalid user format") + } + } + + if len(userSplit) == 2 { + return config.User{ + Username: strings.TrimSpace(userSplit[0]), + Password: strings.TrimSpace(userSplit[1]), + }, nil + } + + return config.User{ + Username: strings.TrimSpace(userSplit[0]), + Password: strings.TrimSpace(userSplit[1]), + TotpSecret: strings.TrimSpace(userSplit[2]), + }, nil +} diff --git a/internal/utils/utils.go b/internal/utils/utils.go deleted file mode 100644 index 39b1518..0000000 --- a/internal/utils/utils.go +++ /dev/null @@ -1,350 +0,0 @@ -package utils - -import ( - "bytes" - "crypto/sha256" - "encoding/base64" - "errors" - "io" - "net" - "net/url" - "os" - "regexp" - "strings" - "tinyauth/internal/types" - - "github.com/traefik/paerser/parser" - "golang.org/x/crypto/hkdf" - - "github.com/google/uuid" - "github.com/rs/zerolog/log" -) - -// Parses a list of comma separated users in a struct -func ParseUsers(users string) (types.Users, error) { - log.Debug().Msg("Parsing users") - - var usersParsed types.Users - - userList := strings.Split(users, ",") - - if len(userList) == 0 { - return types.Users{}, errors.New("invalid user format") - } - - for _, user := range userList { - parsed, err := ParseUser(user) - if err != nil { - return types.Users{}, err - } - usersParsed = append(usersParsed, parsed) - } - - log.Debug().Msg("Parsed users") - return usersParsed, nil -} - -// Get upper domain parses a hostname and returns the upper domain (e.g. sub1.sub2.domain.com -> sub2.domain.com) -func GetUpperDomain(urlSrc string) (string, error) { - urlParsed, err := url.Parse(urlSrc) - if err != nil { - return "", err - } - - urlSplitted := strings.Split(urlParsed.Hostname(), ".") - urlFinal := strings.Join(urlSplitted[1:], ".") - - return urlFinal, nil -} - -// Reads a file and returns the contents -func ReadFile(file string) (string, error) { - _, err := os.Stat(file) - if err != nil { - return "", err - } - - data, err := os.ReadFile(file) - if err != nil { - return "", err - } - - return string(data), nil -} - -// Parses a file into a comma separated list of users -func ParseFileToLine(content string) string { - lines := strings.Split(content, "\n") - users := make([]string, 0) - - for _, line := range lines { - if strings.TrimSpace(line) == "" { - continue - } - users = append(users, strings.TrimSpace(line)) - } - - return strings.Join(users, ",") -} - -// Get the secret from the config or file -func GetSecret(conf string, file string) string { - if conf == "" && file == "" { - return "" - } - - if conf != "" { - return conf - } - - contents, err := ReadFile(file) - if err != nil { - return "" - } - - return ParseSecretFile(contents) -} - -// Get the users from the config or file -func GetUsers(conf string, file string) (types.Users, error) { - var users string - - if conf == "" && file == "" { - return types.Users{}, nil - } - - if conf != "" { - log.Debug().Msg("Using users from config") - users += conf - } - - if file != "" { - contents, err := ReadFile(file) - if err == nil { - log.Debug().Msg("Using users from file") - if users != "" { - users += "," - } - users += ParseFileToLine(contents) - } - } - - return ParseUsers(users) -} - -// Parse the headers in a map[string]string format -func ParseHeaders(headers []string) map[string]string { - headerMap := make(map[string]string) - - for _, header := range headers { - split := strings.SplitN(header, "=", 2) - if len(split) != 2 || strings.TrimSpace(split[0]) == "" || strings.TrimSpace(split[1]) == "" { - log.Warn().Str("header", header).Msg("Invalid header format, skipping") - continue - } - key := SanitizeHeader(strings.TrimSpace(split[0])) - value := SanitizeHeader(strings.TrimSpace(split[1])) - headerMap[key] = value - } - - return headerMap -} - -// Get labels parses a map of labels into a struct with only the needed labels -func GetLabels(labels map[string]string) (types.Labels, error) { - var labelsParsed types.Labels - - err := parser.Decode(labels, &labelsParsed, "tinyauth", "tinyauth.users", "tinyauth.allowed", "tinyauth.headers", "tinyauth.domain", "tinyauth.basic", "tinyauth.oauth", "tinyauth.ip") - if err != nil { - log.Error().Err(err).Msg("Error parsing labels") - return types.Labels{}, err - } - - return labelsParsed, nil -} - -// Check if any of the OAuth providers are configured based on the client id and secret -func OAuthConfigured(config types.Config) bool { - return (config.GithubClientId != "" && config.GithubClientSecret != "") || (config.GoogleClientId != "" && config.GoogleClientSecret != "") || (config.GenericClientId != "" && config.GenericClientSecret != "") -} - -// Filter helper function -func Filter[T any](slice []T, test func(T) bool) (res []T) { - for _, value := range slice { - if test(value) { - res = append(res, value) - } - } - return res -} - -// Parse user -func ParseUser(user string) (types.User, error) { - if strings.Contains(user, "$$") { - user = strings.ReplaceAll(user, "$$", "$") - } - - userSplit := strings.Split(user, ":") - - if len(userSplit) < 2 || len(userSplit) > 3 { - return types.User{}, errors.New("invalid user format") - } - - for _, userPart := range userSplit { - if strings.TrimSpace(userPart) == "" { - return types.User{}, errors.New("invalid user format") - } - } - - if len(userSplit) == 2 { - return types.User{ - Username: strings.TrimSpace(userSplit[0]), - Password: strings.TrimSpace(userSplit[1]), - }, nil - } - - return types.User{ - Username: strings.TrimSpace(userSplit[0]), - Password: strings.TrimSpace(userSplit[1]), - TotpSecret: strings.TrimSpace(userSplit[2]), - }, nil -} - -// Parse secret file -func ParseSecretFile(contents string) string { - lines := strings.Split(contents, "\n") - - for _, line := range lines { - if strings.TrimSpace(line) == "" { - continue - } - return strings.TrimSpace(line) - } - - return "" -} - -// Check if a string matches a regex or if it is included in a comma separated list -func CheckFilter(filter string, str string) bool { - if len(strings.TrimSpace(filter)) == 0 { - return true - } - - if strings.HasPrefix(filter, "/") && strings.HasSuffix(filter, "/") { - re, err := regexp.Compile(filter[1 : len(filter)-1]) - if err != nil { - log.Error().Err(err).Msg("Error compiling regex") - return false - } - - if re.MatchString(str) { - return true - } - } - - filterSplit := strings.Split(filter, ",") - - for _, item := range filterSplit { - if strings.TrimSpace(item) == str { - return true - } - } - - return false -} - -// Capitalize just the first letter of a string -func Capitalize(str string) string { - if len(str) == 0 { - return "" - } - return strings.ToUpper(string([]rune(str)[0])) + string([]rune(str)[1:]) -} - -// Sanitize header removes all control characters from a string -func SanitizeHeader(header string) string { - return strings.Map(func(r rune) rune { - // Allow only printable ASCII characters (32-126) and safe whitespace (space, tab) - if r == ' ' || r == '\t' || (r >= 32 && r <= 126) { - return r - } - return -1 - }, header) -} - -// Generate a static identifier from a string -func GenerateIdentifier(str string) string { - uuid := uuid.NewSHA1(uuid.NameSpaceURL, []byte(str)) - uuidString := uuid.String() - log.Debug().Str("uuid", uuidString).Msg("Generated UUID") - return strings.Split(uuidString, "-")[0] -} - -// Get a basic auth header from a username and password -func GetBasicAuth(username string, password string) string { - auth := username + ":" + password - return base64.StdEncoding.EncodeToString([]byte(auth)) -} - -// Check if an IP is contained in a CIDR range/matches a single IP -func FilterIP(filter string, ip string) (bool, error) { - ipAddr := net.ParseIP(ip) - - if strings.Contains(filter, "/") { - _, cidr, err := net.ParseCIDR(filter) - if err != nil { - return false, err - } - return cidr.Contains(ipAddr), nil - } - - ipFilter := net.ParseIP(filter) - if ipFilter == nil { - return false, errors.New("invalid IP address in filter") - } - - if ipFilter.Equal(ipAddr) { - return true, nil - } - - return false, nil -} - -func DeriveKey(secret string, info string) (string, error) { - hash := sha256.New - hkdf := hkdf.New(hash, []byte(secret), nil, []byte(info)) // I am not using a salt because I just want two different keys from one secret, maybe bad practice - key := make([]byte, 24) - - _, err := io.ReadFull(hkdf, key) - if err != nil { - return "", err - } - - if bytes.Equal(key, make([]byte, 24)) { - return "", errors.New("derived key is empty") - } - - encodedKey := base64.StdEncoding.EncodeToString(key) - return encodedKey, nil -} - -func CoalesceToString(value any) string { - switch v := value.(type) { - case []any: - log.Debug().Msg("Coalescing []any to string") - strs := make([]string, 0, len(v)) - for _, item := range v { - if str, ok := item.(string); ok { - strs = append(strs, str) - continue - } - log.Warn().Interface("item", item).Msg("Item in []any is not a string, skipping") - } - return strings.Join(strs, ",") - case string: - return v - default: - log.Warn().Interface("value", value).Interface("type", v).Msg("Unsupported type, returning empty string") - return "" - } -} diff --git a/internal/utils/utils_test.go b/internal/utils/utils_test.go deleted file mode 100644 index 5ae7e89..0000000 --- a/internal/utils/utils_test.go +++ /dev/null @@ -1,548 +0,0 @@ -package utils_test - -import ( - "fmt" - "os" - "reflect" - "testing" - "tinyauth/internal/types" - "tinyauth/internal/utils" -) - -func TestParseUsers(t *testing.T) { - t.Log("Testing parse users with a valid string") - - users := "user1:pass1,user2:pass2" - expected := types.Users{ - { - Username: "user1", - Password: "pass1", - }, - { - Username: "user2", - Password: "pass2", - }, - } - - result, err := utils.ParseUsers(users) - if err != nil { - t.Fatalf("Error parsing users: %v", err) - } - - if !reflect.DeepEqual(expected, result) { - t.Fatalf("Expected %v, got %v", expected, result) - } -} - -func TestGetUpperDomain(t *testing.T) { - t.Log("Testing get upper domain with a valid url") - - url := "https://sub1.sub2.domain.com:8080" - expected := "sub2.domain.com" - - result, err := utils.GetUpperDomain(url) - if err != nil { - t.Fatalf("Error getting root url: %v", err) - } - - if expected != result { - t.Fatalf("Expected %v, got %v", expected, result) - } -} - -func TestReadFile(t *testing.T) { - t.Log("Creating a test file") - - err := os.WriteFile("/tmp/test.txt", []byte("test"), 0644) - if err != nil { - t.Fatalf("Error creating test file: %v", err) - } - - t.Log("Testing read file with a valid file") - - data, err := utils.ReadFile("/tmp/test.txt") - if err != nil { - t.Fatalf("Error reading file: %v", err) - } - - if data != "test" { - t.Fatalf("Expected test, got %v", data) - } - - t.Log("Cleaning up test file") - - err = os.Remove("/tmp/test.txt") - if err != nil { - t.Fatalf("Error cleaning up test file: %v", err) - } -} - -func TestParseFileToLine(t *testing.T) { - t.Log("Testing parse file to line with a valid string") - - content := "\nuser1:pass1\nuser2:pass2\n" - expected := "user1:pass1,user2:pass2" - - result := utils.ParseFileToLine(content) - - if expected != result { - t.Fatalf("Expected %v, got %v", expected, result) - } -} - -func TestGetSecret(t *testing.T) { - t.Log("Testing get secret with an empty config and file") - - conf := "" - file := "/tmp/test.txt" - expected := "test" - - err := os.WriteFile(file, []byte(fmt.Sprintf("\n\n \n\n\n %s \n\n \n ", expected)), 0644) - if err != nil { - t.Fatalf("Error creating test file: %v", err) - } - - result := utils.GetSecret(conf, file) - - if result != expected { - t.Fatalf("Expected %v, got %v", expected, result) - } - - t.Log("Testing get secret with an empty file and a valid config") - - result = utils.GetSecret(expected, "") - - if result != expected { - t.Fatalf("Expected %v, got %v", expected, result) - } - - t.Log("Testing get secret with both a valid config and file") - - result = utils.GetSecret(expected, file) - - if result != expected { - t.Fatalf("Expected %v, got %v", expected, result) - } - - t.Log("Cleaning up test file") - - err = os.Remove(file) - if err != nil { - t.Fatalf("Error cleaning up test file: %v", err) - } -} - -func TestGetUsers(t *testing.T) { - t.Log("Testing get users with a config and no file") - - conf := "user1:pass1,user2:pass2" - file := "" - expected := types.Users{ - { - Username: "user1", - Password: "pass1", - }, - { - Username: "user2", - Password: "pass2", - }, - } - - result, err := utils.GetUsers(conf, file) - if err != nil { - t.Fatalf("Error getting users: %v", err) - } - - if !reflect.DeepEqual(expected, result) { - t.Fatalf("Expected %v, got %v", expected, result) - } - - t.Log("Testing get users with a file and no config") - - conf = "" - file = "/tmp/test.txt" - expected = types.Users{ - { - Username: "user1", - Password: "pass1", - }, - { - Username: "user2", - Password: "pass2", - }, - } - - err = os.WriteFile(file, []byte("user1:pass1\nuser2:pass2"), 0644) - if err != nil { - t.Fatalf("Error creating test file: %v", err) - } - - result, err = utils.GetUsers(conf, file) - if err != nil { - t.Fatalf("Error getting users: %v", err) - } - - if !reflect.DeepEqual(expected, result) { - t.Fatalf("Expected %v, got %v", expected, result) - } - - t.Log("Testing get users with both a config and file") - - conf = "user3:pass3" - expected = types.Users{ - { - Username: "user3", - Password: "pass3", - }, - { - Username: "user1", - Password: "pass1", - }, - { - Username: "user2", - Password: "pass2", - }, - } - - result, err = utils.GetUsers(conf, file) - if err != nil { - t.Fatalf("Error getting users: %v", err) - } - - if !reflect.DeepEqual(expected, result) { - t.Fatalf("Expected %v, got %v", expected, result) - } - - t.Log("Cleaning up test file") - - err = os.Remove(file) - if err != nil { - t.Fatalf("Error cleaning up test file: %v", err) - } -} - -func TestGetLabels(t *testing.T) { - t.Log("Testing get labels with a valid map") - - labels := map[string]string{ - "tinyauth.users": "user1,user2", - "tinyauth.oauth.whitelist": "/regex/", - "tinyauth.allowed": "random", - "tinyauth.headers": "X-Header=value", - "tinyauth.oauth.groups": "group1,group2", - } - - expected := types.Labels{ - Users: "user1,user2", - Allowed: "random", - Headers: []string{"X-Header=value"}, - OAuth: types.OAuthLabels{ - Whitelist: "/regex/", - Groups: "group1,group2", - }, - } - - result, err := utils.GetLabels(labels) - if err != nil { - t.Fatalf("Error getting labels: %v", err) - } - - if !reflect.DeepEqual(expected, result) { - t.Fatalf("Expected %v, got %v", expected, result) - } -} - -func TestParseUser(t *testing.T) { - t.Log("Testing parse user with a valid user") - - user := "user:pass:secret" - expected := types.User{ - Username: "user", - Password: "pass", - TotpSecret: "secret", - } - - result, err := utils.ParseUser(user) - if err != nil { - t.Fatalf("Error parsing user: %v", err) - } - - if !reflect.DeepEqual(expected, result) { - t.Fatalf("Expected %v, got %v", expected, result) - } - - t.Log("Testing parse user with an escaped user") - - user = "user:p$$ass$$:secret" - expected = types.User{ - Username: "user", - Password: "p$ass$", - TotpSecret: "secret", - } - - result, err = utils.ParseUser(user) - if err != nil { - t.Fatalf("Error parsing user: %v", err) - } - - if !reflect.DeepEqual(expected, result) { - t.Fatalf("Expected %v, got %v", expected, result) - } - - t.Log("Testing parse user with an invalid user") - - user = "user::pass" - - _, err = utils.ParseUser(user) - if err == nil { - t.Fatalf("Expected error parsing user") - } -} - -func TestCheckFilter(t *testing.T) { - t.Log("Testing check filter with a comma separated list") - - filter := "user1,user2,user3" - str := "user1" - expected := true - - result := utils.CheckFilter(filter, str) - - if result != expected { - t.Fatalf("Expected %v, got %v", expected, result) - } - - t.Log("Testing check filter with a regex filter") - - filter = "/^user[0-9]+$/" - str = "user1" - expected = true - - result = utils.CheckFilter(filter, str) - - if result != expected { - t.Fatalf("Expected %v, got %v", expected, result) - } - - t.Log("Testing check filter with an empty filter") - - filter = "" - str = "user1" - expected = true - - result = utils.CheckFilter(filter, str) - - if result != expected { - t.Fatalf("Expected %v, got %v", expected, result) - } - - t.Log("Testing check filter with an invalid regex filter") - - filter = "/^user[0-9+$/" - str = "user1" - expected = false - - result = utils.CheckFilter(filter, str) - - if result != expected { - t.Fatalf("Expected %v, got %v", expected, result) - } - - t.Log("Testing check filter with a non matching list") - - filter = "user1,user2,user3" - str = "user4" - expected = false - - result = utils.CheckFilter(filter, str) - - if result != expected { - t.Fatalf("Expected %v, got %v", expected, result) - } -} - -func TestSanitizeHeader(t *testing.T) { - t.Log("Testing sanitize header with a valid string") - - str := "X-Header=value" - expected := "X-Header=value" - - result := utils.SanitizeHeader(str) - - if result != expected { - t.Fatalf("Expected %v, got %v", expected, result) - } - - t.Log("Testing sanitize header with an invalid string") - - str = "X-Header=val\nue" - expected = "X-Header=value" - - result = utils.SanitizeHeader(str) - - if result != expected { - t.Fatalf("Expected %v, got %v", expected, result) - } -} - -func TestParseHeaders(t *testing.T) { - t.Log("Testing parse headers with a valid string") - - headers := []string{"X-Hea\x00der1=value1", "X-Header2=value\n2"} - expected := map[string]string{ - "X-Header1": "value1", - "X-Header2": "value2", - } - - result := utils.ParseHeaders(headers) - - if !reflect.DeepEqual(expected, result) { - t.Fatalf("Expected %v, got %v", expected, result) - } - - t.Log("Testing parse headers with an invalid string") - - headers = []string{"X-Header1=", "X-Header2", "=value", "X-Header3=value3"} - expected = map[string]string{"X-Header3": "value3"} - - result = utils.ParseHeaders(headers) - - if !reflect.DeepEqual(expected, result) { - t.Fatalf("Expected %v, got %v", expected, result) - } -} - -func TestParseSecretFile(t *testing.T) { - t.Log("Testing parse secret file with a valid file") - - content := "\n\n \n\n\n secret \n\n \n " - expected := "secret" - - result := utils.ParseSecretFile(content) - - if result != expected { - t.Fatalf("Expected %v, got %v", expected, result) - } -} - -func TestFilterIP(t *testing.T) { - t.Log("Testing filter IP with an IP and a valid CIDR") - - ip := "10.10.10.10" - filter := "10.10.10.0/24" - expected := true - - result, err := utils.FilterIP(filter, ip) - if err != nil { - t.Fatalf("Error filtering IP: %v", err) - } - - if result != expected { - t.Fatalf("Expected %v, got %v", expected, result) - } - - t.Log("Testing filter IP with an IP and a valid IP") - - filter = "10.10.10.10" - expected = true - - result, err = utils.FilterIP(filter, ip) - if err != nil { - t.Fatalf("Error filtering IP: %v", err) - } - - if result != expected { - t.Fatalf("Expected %v, got %v", expected, result) - } - - t.Log("Testing filter IP with an IP and an non matching CIDR") - - filter = "10.10.15.0/24" - expected = false - - result, err = utils.FilterIP(filter, ip) - if err != nil { - t.Fatalf("Error filtering IP: %v", err) - } - - if result != expected { - t.Fatalf("Expected %v, got %v", expected, result) - } - - t.Log("Testing filter IP with a non matching IP and a valid CIDR") - - filter = "10.10.10.11" - expected = false - - result, err = utils.FilterIP(filter, ip) - - if err != nil { - t.Fatalf("Error filtering IP: %v", err) - } - - if result != expected { - t.Fatalf("Expected %v, got %v", expected, result) - } - - t.Log("Testing filter IP with an IP and an invalid CIDR") - - filter = "10.../83" - - _, err = utils.FilterIP(filter, ip) - if err == nil { - t.Fatalf("Expected error filtering IP") - } -} - -func TestDeriveKey(t *testing.T) { - t.Log("Testing the derive key function") - - master := "master" - info := "info" - expected := "gdrdU/fXzclYjiSXRexEatVgV13qQmKl" - - result, err := utils.DeriveKey(master, info) - - if err != nil { - t.Fatalf("Error deriving key: %v", err) - } - - if result != expected { - t.Fatalf("Expected %v, got %v", expected, result) - } -} - -func TestCoalesceToString(t *testing.T) { - t.Log("Testing coalesce to string with a string") - - value := any("test") - expected := "test" - - result := utils.CoalesceToString(value) - - if result != expected { - t.Fatalf("Expected %v, got %v", expected, result) - } - - t.Log("Testing coalesce to string with a slice of strings") - - value = []any{any("test1"), any("test2"), any(123)} - expected = "test1,test2" - - result = utils.CoalesceToString(value) - - if result != expected { - t.Fatalf("Expected %v, got %v", expected, result) - } - - t.Log("Testing coalesce to string with an unsupported type") - - value = 12345 - expected = "" - - result = utils.CoalesceToString(value) - - if result != expected { - t.Fatalf("Expected %v, got %v", expected, result) - } -} diff --git a/main.go b/main.go index 27792d8..8126e9e 100644 --- a/main.go +++ b/main.go @@ -10,6 +10,6 @@ import ( ) func main() { - log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr, TimeFormat: time.RFC3339}).With().Timestamp().Logger().Level(zerolog.FatalLevel) + log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr, TimeFormat: time.RFC3339}).With().Timestamp().Caller().Logger() cmd.Execute() }