diff options
| author | Stefan Boberg <[email protected]> | 2026-03-10 17:27:26 +0100 |
|---|---|---|
| committer | GitHub Enterprise <[email protected]> | 2026-03-10 17:27:26 +0100 |
| commit | d0a07e555577dcd4a8f55f1b45d9e8e4e6366ab7 (patch) | |
| tree | 2dfe1e3e0b620043d358e0b7f8bdf8320d985491 /src | |
| parent | changelog entry which was inadvertently omitted from PR merge (diff) | |
| download | zen-d0a07e555577dcd4a8f55f1b45d9e8e4e6366ab7.tar.xz zen-d0a07e555577dcd4a8f55f1b45d9e8e4e6366ab7.zip | |
HttpClient using libcurl, Unix Sockets for HTTP. HTTPS support (#770)
The main goal of this change is to eliminate the cpr back-end altogether and replace it with the curl implementation. I would expect to drop cpr as soon as we feel happy with the libcurl back-end. That would leave us with a direct dependency on libcurl only, and cpr can be eliminated as a dependency.
### HttpClient Backend Overhaul
- Implemented a new **libcurl-based HttpClient** backend (`httpclientcurl.cpp`, ~2000 lines)
as an alternative to the cpr-based one
- Made HttpClient backend **configurable at runtime** via constructor arguments
and `-httpclient=...` CLI option (for zen, zenserver, and tests)
- Extended HttpClient test suite to cover multipart/content-range scenarios
### Unix Domain Socket Support
- Added Unix domain socket support to **httpasio** (server side)
- Added Unix domain socket support to **HttpClient**
- Added Unix domain socket support to **HttpWsClient** (WebSocket client)
- Templatized `HttpServerConnectionT<SocketType>` and `WsAsioConnectionT<SocketType>`
to handle TCP, Unix, and SSL sockets uniformly via `if constexpr` dispatch
### HTTPS Support
- Added **preliminary HTTPS support to httpasio** (for Mac/Linux via OpenSSL)
- Added **basic HTTPS support for http.sys** (Windows)
- Implemented HTTPS test for httpasio
- Split `InitializeServer` into smaller sub-functions for http.sys
### Other Notable Changes
- Improved **zenhttp-test stability** with dynamic port allocation
- Enhanced port retry logic in http.sys (handles ERROR_ACCESS_DENIED)
- Fatal signal/exception handlers for backtrace generation in tests
- Added `zen bench http` subcommand to exercise network + HTTP client/server communication stack
Diffstat (limited to 'src')
43 files changed, 4881 insertions, 532 deletions
diff --git a/src/zen/bench.cpp b/src/zen/bench.cpp index 614454ed5..2332ce1b8 100644 --- a/src/zen/bench.cpp +++ b/src/zen/bench.cpp @@ -119,6 +119,53 @@ EmptyStandByList() } // namespace zen::bench::util +#elif ZEN_PLATFORM_LINUX + +# include <fcntl.h> +# include <unistd.h> + +namespace zen::bench::util { + +void +EmptyStandByList() +{ + sync(); + + int Fd = open("/proc/sys/vm/drop_caches", O_WRONLY); + if (Fd < 0) + { + throw std::runtime_error("Failed to open /proc/sys/vm/drop_caches (are you running as root?)"); + } + + if (write(Fd, "3", 1) != 1) + { + close(Fd); + throw std::runtime_error("Failed to write to /proc/sys/vm/drop_caches"); + } + + close(Fd); +} + +} // namespace zen::bench::util + +#elif ZEN_PLATFORM_MAC + +# include <cstdlib> + +namespace zen::bench::util { + +void +EmptyStandByList() +{ + int Result = system("/usr/sbin/purge"); + if (Result != 0) + { + throw std::runtime_error("Failed to run /usr/sbin/purge (are you running as root?)"); + } +} + +} // namespace zen::bench::util + #else namespace zen::bench::util { diff --git a/src/zen/cmds/bench_cmd.cpp b/src/zen/cmds/bench_cmd.cpp index b9c45a328..658b42da6 100644 --- a/src/zen/cmds/bench_cmd.cpp +++ b/src/zen/cmds/bench_cmd.cpp @@ -3,6 +3,7 @@ #include "bench_cmd.h" #include "bench.h" +#include <zencore/compactbinary.h> #include <zencore/except.h> #include <zencore/filesystem.h> #include <zencore/fmtutils.h> @@ -11,93 +12,514 @@ #include <zencore/string.h> #include <zencore/thread.h> #include <zencore/timer.h> +#include <zenhttp/httpclient.h> +#include <zentelemetry/stats.h> + +#include <algorithm> +#include <atomic> +#include <csignal> +#include <mutex> +#include <thread> + +static std::atomic<bool> s_BenchAbort{false}; namespace zen { -BenchCommand::BenchCommand() +////////////////////////////////////////////////////////////////////////// +// BenchPurgeSubCmd + +BenchPurgeSubCmd::BenchPurgeSubCmd() +: ZenSubCmdBase("purge", "Purge standby memory (system cache)") { - m_Options.add_options()("h,help", "Print help"); - m_Options.add_options()("purge", - "Purge standby memory (system cache)", - cxxopts::value<bool>(m_PurgeStandbyLists)->default_value("false")); - m_Options.add_options()("single", "Do not spawn child processes", cxxopts::value<bool>(m_SingleProcess)->default_value("false")); + SubOptions().add_options()("single", + "Do not spawn child processes", + cxxopts::value<bool>(m_SingleProcess)->default_value("false")); } -BenchCommand::~BenchCommand() = default; - void -BenchCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) +BenchPurgeSubCmd::Run(const ZenCliOptions& GlobalOptions) { ZEN_UNUSED(GlobalOptions); - if (!ParseOptions(argc, argv)) + bool Ok = false; + + zen::Stopwatch Timer; + + try + { + zen::bench::util::EmptyStandByList(); + + Ok = true; + } + catch (const zen::bench::util::elevation_required_exception&) + { + ZEN_CONSOLE_WARN("Purging standby lists requires elevation. Will try launch as elevated process"); + } + catch (const std::exception& Ex) { - return; + ZEN_CONSOLE_ERROR("{}", Ex.what()); } #if ZEN_PLATFORM_WINDOWS - if (m_PurgeStandbyLists) + if (!Ok && !m_SingleProcess) { - bool Ok = false; - - zen::Stopwatch Timer; - try { - zen::bench::util::EmptyStandByList(); + zen::CreateProcOptions Cpo; + Cpo.Flags = zen::CreateProcOptions::Flag_Elevated | zen::CreateProcOptions::Flag_NewConsole; - Ok = true; - } - catch (const zen::bench::util::elevation_required_exception&) - { - ZEN_CONSOLE_WARN("Purging standby lists requires elevation. Will try launch as elevated process"); + std::filesystem::path CurExe{zen::GetRunningExecutablePath()}; + + if (zen::CreateProcResult Cpr = zen::CreateProc(CurExe, fmt::format("bench purge --single"), Cpo)) + { + zen::ProcessHandle ProcHandle; + ProcHandle.Initialize(Cpr); + + int ExitCode = ProcHandle.WaitExitCode(); + + if (ExitCode == 0) + { + Ok = true; + } + else + { + ZEN_CONSOLE_ERROR("Elevated child process failed with return code {}", ExitCode); + } + } } catch (const std::exception& Ex) { ZEN_CONSOLE_ERROR("{}", Ex.what()); } + } +#endif + + if (Ok) + { + // TODO: could also add reporting on just how much memory was purged + ZEN_CONSOLE("Purged standby lists! (took {})", zen::NiceTimeSpanMs(Timer.GetElapsedTimeMs())); + } +} + +////////////////////////////////////////////////////////////////////////// +// BenchHttpSubCmd + + +BenchHttpSubCmd::BenchHttpSubCmd() +: ZenSubCmdBase("http", "Benchmark an HTTP server") +{ + SubOptions().add_option("", "u", "url", "URL to benchmark", cxxopts::value<std::string>(m_Url), "<url>"); + SubOptions().add_option("", + "n", + "count", + "Number of requests to send", + cxxopts::value<int>(m_Count)->default_value("100"), + "<count>"); + SubOptions().add_option("", + "c", + "concurrency", + "Number of concurrent threads", + cxxopts::value<int>(m_Concurrency)->default_value("1"), + "<threads>"); + SubOptions().add_option("", + "", + "method", + "HTTP method to use (GET, HEAD)", + cxxopts::value<std::string>(m_Method)->default_value("GET"), + "<method>"); + SubOptions().add_option("", + "", + "unix-socket", + "Unix domain socket path (overrides TCP)", + cxxopts::value<std::string>(m_SocketPath), + "<path>"); + SubOptions().add_options()("no-keepalive", + "Close connection after each request (disables keep-alive)", + cxxopts::value<bool>(m_NoKeepAlive)->default_value("false")); + SubOptions().add_options()("continuous", + "Run until interrupted (Ctrl+C), printing metrics once per second", + cxxopts::value<bool>(m_Continuous)->default_value("false")); + SubOptions().parse_positional({"url"}); +} + +static std::pair<std::string, std::string> +SplitUrl(std::string_view Url) +{ + size_t SchemeEnd = Url.find("://"); + size_t SearchFrom = (SchemeEnd != std::string_view::npos) ? SchemeEnd + 3 : 0; + size_t PathStart = Url.find('/', SearchFrom); + + if (PathStart == std::string_view::npos) + { + return {std::string(Url), "/"}; + } + + return {std::string(Url.substr(0, PathStart)), std::string(Url.substr(PathStart))}; +} + +void +BenchHttpSubCmd::Run(const ZenCliOptions& GlobalOptions) +{ + ZEN_UNUSED(GlobalOptions); + + if (m_Url.empty()) + { + throw OptionParseException("URL is required", SubOptions().help()); + } + + if (!m_Continuous && m_Count <= 0) + { + throw OptionParseException("--count must be a positive integer", SubOptions().help()); + } + + if (m_Concurrency <= 0) + { + throw OptionParseException("--concurrency must be a positive integer", SubOptions().help()); + } + + if (m_Method != "GET" && m_Method != "HEAD") + { + throw OptionParseException(fmt::format("Unsupported HTTP method '{}'. Supported: GET, HEAD", m_Method), + SubOptions().help()); + } + + auto [BaseUri, Path] = SplitUrl(m_Url); + + std::string ModeStr = m_Continuous ? "continuous" : fmt::format("count={}", m_Count); + + if (m_SocketPath.empty()) + { + ZEN_CONSOLE("Benchmarking {} {} ({}, concurrency={})", m_Method, m_Url, ModeStr, m_Concurrency); + } + else + { + ZEN_CONSOLE("Benchmarking {} {} via {} ({}, concurrency={})", m_Method, m_Url, m_SocketPath, ModeStr, m_Concurrency); + } + + // Probe for a zenserver identity. If the target exposes /health/info and the + // response contains a BuildVersion field we print a short summary. Any failure + // (non-zenserver, timeout, unreachable) is silently ignored. + try + { + HttpClientSettings ProbeSettings{.ConnectTimeout = std::chrono::milliseconds(2000), + .Timeout = std::chrono::milliseconds(2000), + .UnixSocketPath = m_SocketPath}; + HttpClient ProbeHttp(BaseUri, ProbeSettings); + HttpClient::Response ProbeResp = ProbeHttp.Get("/health/info"); + + if (ProbeResp.IsSuccess()) + { + CbObject Info = ProbeResp.AsObject(); + std::string_view BuildVersion = Info["BuildVersion"].AsString(); + + if (!BuildVersion.empty()) + { + std::string_view Hostname = Info["Hostname"].AsString(); + int64_t Pid = Info["Pid"].AsInt64(); + std::string_view HttpServerClass = Info["HttpServerClass"].AsString(); + ZEN_CONSOLE("Remote : zenserver {} on {} (pid {}, {})", BuildVersion, Hostname, Pid, HttpServerClass); + + std::string_view OS = Info["OS"].AsString(); + std::string_view Arch = Info["Arch"].AsString(); + + CbObjectView System = Info["System"].AsObjectView(); + int64_t LpCount = System["lp_count"].AsInt64(); + int64_t TotalMemMiB = System["total_memory_mb"].AsInt64(); + + ZEN_CONSOLE(" : {}, {}, {} logical processors, {} RAM", + OS, + Arch, + LpCount, + NiceBytes(static_cast<uint64_t>(TotalMemMiB) * 1024 * 1024)); + } + } + } + catch (...) + { + } + + if (m_Continuous) + { + RunContinuous(BaseUri, Path); + } + else + { + RunFixedCount(BaseUri, Path); + } +} + +void +BenchHttpSubCmd::RunFixedCount(const std::string& BaseUri, const std::string& Path) +{ + std::atomic<int> NextRequest{0}; + std::vector<double> AllLatencies; + AllLatencies.reserve(m_Count); + std::mutex LatencyMutex; + std::atomic<int> ErrorCount{0}; + std::atomic<int64_t> TotalDownloadedBytes{0}; + std::atomic<int64_t> TotalUploadedBytes{0}; + + Stopwatch Timer; - if (!Ok && !m_SingleProcess) + auto WorkerFn = [&]() { + std::vector<double> LocalLatencies; + + HttpClientSettings Settings{.UnixSocketPath = m_SocketPath, + .ForbidReuseConnection = m_NoKeepAlive}; + HttpClient Http(BaseUri, Settings); + + while (true) { + int RequestIndex = NextRequest.fetch_add(1); + + if (RequestIndex >= m_Count) + { + break; + } + try { - zen::CreateProcOptions Cpo; - Cpo.Flags = zen::CreateProcOptions::Flag_Elevated | zen::CreateProcOptions::Flag_NewConsole; + HttpClient::Response Resp = (m_Method == "HEAD") ? Http.Head(Path) : Http.Get(Path); + + if (Resp.IsSuccess()) + { + LocalLatencies.push_back(Resp.ElapsedSeconds); + TotalDownloadedBytes.fetch_add(Resp.DownloadedBytes); + TotalUploadedBytes.fetch_add(Resp.UploadedBytes); + } + else + { + ErrorCount.fetch_add(1); + } + } + catch (const HttpClientError&) + { + ErrorCount.fetch_add(1); + } + } + + std::lock_guard Lock(LatencyMutex); + AllLatencies.insert(AllLatencies.end(), LocalLatencies.begin(), LocalLatencies.end()); + }; + + std::vector<std::thread> Threads; + Threads.reserve(m_Concurrency); + + for (int i = 0; i < m_Concurrency; ++i) + { + Threads.emplace_back(WorkerFn); + } + + for (std::thread& T : Threads) + { + T.join(); + } + + double TotalSeconds = Timer.GetElapsedTimeMs() / 1000.0; + int SuccessCount = static_cast<int>(AllLatencies.size()); + int TotalCount = SuccessCount + ErrorCount.load(); + + std::sort(AllLatencies.begin(), AllLatencies.end()); + + auto PercentileMs = [&](int Pct) -> double { + if (AllLatencies.empty()) + { + return 0.0; + } + + size_t Index = std::min(AllLatencies.size() * static_cast<size_t>(Pct) / 100, AllLatencies.size() - 1); + + return AllLatencies[Index] * 1000.0; + }; + + double SumMs = 0.0; + for (double L : AllLatencies) + { + SumMs += L * 1000.0; + } + + double MeanMs = SuccessCount > 0 ? SumMs / SuccessCount : 0.0; + double Rps = TotalSeconds > 0.0 ? TotalCount / TotalSeconds : 0.0; + + uint64_t DownBytesPerSec = TotalSeconds > 0.0 ? static_cast<uint64_t>(TotalDownloadedBytes.load() / TotalSeconds) : 0; + uint64_t UpBytesPerSec = TotalSeconds > 0.0 ? static_cast<uint64_t>(TotalUploadedBytes.load() / TotalSeconds) : 0; + + ZEN_CONSOLE(" Requests : {:L} total, {:L} success, {:L} errors", TotalCount, SuccessCount, ErrorCount.load()); + ZEN_CONSOLE(" Latency : min={:.1f}ms mean={:.1f}ms p50={:.1f}ms p95={:.1f}ms p99={:.1f}ms max={:.1f}ms", + PercentileMs(0), + MeanMs, + PercentileMs(50), + PercentileMs(95), + PercentileMs(99), + PercentileMs(100)); + ZEN_CONSOLE(" Throughput: {:.1f} req/s down: {}/s up: {}/s (elapsed: {:.2f}s)", + Rps, + NiceBytes(DownBytesPerSec), + NiceBytes(UpBytesPerSec), + TotalSeconds); +} - std::filesystem::path CurExe{zen::GetRunningExecutablePath()}; +void +BenchHttpSubCmd::RunContinuous(const std::string& BaseUri, const std::string& Path) +{ + s_BenchAbort.store(false); + + auto PrevSigInt = std::signal(SIGINT, [](int) { s_BenchAbort.store(true); }); + auto PrevSigTerm = std::signal(SIGTERM, [](int) { s_BenchAbort.store(true); }); + + metrics::Histogram LatencyHistogram; + std::atomic<int64_t> IntervalSuccessCount{0}; + std::atomic<int64_t> IntervalErrorCount{0}; + std::atomic<int64_t> IntervalDownloadBytes{0}; + std::atomic<int64_t> IntervalUploadBytes{0}; + std::atomic<int64_t> TotalSuccessCount{0}; + std::atomic<int64_t> TotalErrorCount{0}; + std::atomic<int64_t> TotalDownloadBytes{0}; + std::atomic<int64_t> TotalUploadBytes{0}; + + Stopwatch RunTimer; + + auto WorkerFn = [&]() { + HttpClientSettings Settings{.UnixSocketPath = m_SocketPath, + .ForbidReuseConnection = m_NoKeepAlive}; + HttpClient Http(BaseUri, Settings); - if (zen::CreateProcResult Cpr = zen::CreateProc(CurExe, fmt::format("bench --purge --single"), Cpo)) + while (!s_BenchAbort.load(std::memory_order_relaxed)) + { + try + { + HttpClient::Response Resp = (m_Method == "HEAD") ? Http.Head(Path) : Http.Get(Path); + + if (Resp.IsSuccess()) + { + LatencyHistogram.Update(static_cast<int64_t>(Resp.ElapsedSeconds * 1.0e6)); + IntervalSuccessCount.fetch_add(1, std::memory_order_relaxed); + IntervalDownloadBytes.fetch_add(Resp.DownloadedBytes, std::memory_order_relaxed); + IntervalUploadBytes.fetch_add(Resp.UploadedBytes, std::memory_order_relaxed); + TotalSuccessCount.fetch_add(1, std::memory_order_relaxed); + TotalDownloadBytes.fetch_add(Resp.DownloadedBytes, std::memory_order_relaxed); + TotalUploadBytes.fetch_add(Resp.UploadedBytes, std::memory_order_relaxed); + } + else { - zen::ProcessHandle ProcHandle; - ProcHandle.Initialize(Cpr); - - int ExitCode = ProcHandle.WaitExitCode(); - - if (ExitCode == 0) - { - Ok = true; - } - else - { - ZEN_CONSOLE_ERROR("Elevated child process failed with return code {}", ExitCode); - } + IntervalErrorCount.fetch_add(1, std::memory_order_relaxed); + TotalErrorCount.fetch_add(1, std::memory_order_relaxed); } } - catch (const std::exception& Ex) + catch (const HttpClientError&) { - ZEN_CONSOLE_ERROR("{}", Ex.what()); + IntervalErrorCount.fetch_add(1, std::memory_order_relaxed); + TotalErrorCount.fetch_add(1, std::memory_order_relaxed); } } + }; - if (Ok) + auto ReporterFn = [&]() { + while (!s_BenchAbort.load(std::memory_order_relaxed)) { - // TODO: could also add reporting on just how much memory was purged - ZEN_CONSOLE("Purged standby lists! (took {})", zen::NiceTimeSpanMs(Timer.GetElapsedTimeMs())); + // Sleep 1s in short increments to stay responsive to abort + for (int i = 0; i < 10 && !s_BenchAbort.load(std::memory_order_relaxed); ++i) + { + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + } + + if (s_BenchAbort.load(std::memory_order_relaxed)) + { + break; + } + + // Snapshot and reset per-interval counters + int64_t Successes = IntervalSuccessCount.exchange(0); + int64_t Errors = IntervalErrorCount.exchange(0); + int64_t DownBytes = IntervalDownloadBytes.exchange(0); + int64_t UpBytes = IntervalUploadBytes.exchange(0); + + // Snapshot and reset latency histogram + uint64_t HistCount = LatencyHistogram.Count(); + int64_t HistMin = LatencyHistogram.Min(); + int64_t HistMax = LatencyHistogram.Max(); + double HistMean = LatencyHistogram.Mean(); + metrics::SampleSnapshot Snap = LatencyHistogram.Snapshot(); + LatencyHistogram.Clear(); + + // Format elapsed as HH:MM:SS + int TotalSec = static_cast<int>(RunTimer.GetElapsedTimeMs() / 1000.0); + int Hours = TotalSec / 3600; + int Minutes = (TotalSec % 3600) / 60; + int Secs = TotalSec % 60; + + if (HistCount > 0) + { + ZEN_CONSOLE( + "[{:02d}:{:02d}:{:02d}] req/s: {:L} errors: {:L} lat(ms): min={:.1f} mean={:.1f} p95={:.1f} p99={:.1f} max={:.1f} down: {}/s up: {}/s", + Hours, + Minutes, + Secs, + Successes, + Errors, + HistMin / 1000.0, + HistMean / 1000.0, + Snap.Get95Percentile() / 1000.0, + Snap.Get99Percentile() / 1000.0, + HistMax / 1000.0, + NiceBytes(static_cast<uint64_t>(std::max(int64_t{0}, DownBytes))), + NiceBytes(static_cast<uint64_t>(std::max(int64_t{0}, UpBytes)))); + } + else + { + ZEN_CONSOLE("[{:02d}:{:02d}:{:02d}] req/s: 0 errors: {:L} (no successful requests)", + Hours, + Minutes, + Secs, + Errors); + } } + }; + + std::vector<std::thread> Threads; + Threads.reserve(m_Concurrency + 1); + Threads.emplace_back(ReporterFn); + + for (int i = 0; i < m_Concurrency; ++i) + { + Threads.emplace_back(WorkerFn); } -#endif - return; + for (std::thread& T : Threads) + { + T.join(); + } + + std::signal(SIGINT, PrevSigInt); + std::signal(SIGTERM, PrevSigTerm); + + double TotalSeconds = RunTimer.GetElapsedTimeMs() / 1000.0; + int64_t TotalCount = TotalSuccessCount.load() + TotalErrorCount.load(); + uint64_t DownPerSec = TotalSeconds > 0.0 ? static_cast<uint64_t>(TotalDownloadBytes.load() / TotalSeconds) : 0; + uint64_t UpPerSec = TotalSeconds > 0.0 ? static_cast<uint64_t>(TotalUploadBytes.load() / TotalSeconds) : 0; + + ZEN_CONSOLE("Stopped. Total: {:L} requests, {:L} success, {:L} errors avg throughput: down {}/s up {}/s (elapsed: {:.2f}s)", + TotalCount, + TotalSuccessCount.load(), + TotalErrorCount.load(), + NiceBytes(DownPerSec), + NiceBytes(UpPerSec), + TotalSeconds); +} + +////////////////////////////////////////////////////////////////////////// +// BenchCommand + +BenchCommand::BenchCommand() +{ + m_Options.add_options()("h,help", "Print help"); + m_Options.add_option("__hidden__", "", "subcommand", "", cxxopts::value<std::string>(m_SubCommand)->default_value(""), ""); + m_Options.parse_positional({"subcommand"}); + + AddSubCommand(m_PurgeSubCmd); + AddSubCommand(m_HttpSubCmd); } +BenchCommand::~BenchCommand() = default; + } // namespace zen diff --git a/src/zen/cmds/bench_cmd.h b/src/zen/cmds/bench_cmd.h index 7fbf85340..f332b3fcc 100644 --- a/src/zen/cmds/bench_cmd.h +++ b/src/zen/cmds/bench_cmd.h @@ -6,7 +6,36 @@ namespace zen { -class BenchCommand : public ZenCmdBase +class BenchPurgeSubCmd : public ZenSubCmdBase +{ +public: + BenchPurgeSubCmd(); + void Run(const ZenCliOptions& GlobalOptions) override; + +private: + bool m_SingleProcess = false; +}; + +class BenchHttpSubCmd : public ZenSubCmdBase +{ +public: + BenchHttpSubCmd(); + void Run(const ZenCliOptions& GlobalOptions) override; + +private: + void RunFixedCount(const std::string& BaseUri, const std::string& Path); + void RunContinuous(const std::string& BaseUri, const std::string& Path); + + std::string m_Url; + std::string m_SocketPath; + int m_Count = 100; + int m_Concurrency = 1; + std::string m_Method = "GET"; + bool m_NoKeepAlive = false; + bool m_Continuous = false; +}; + +class BenchCommand : public ZenCmdWithSubCommands { public: static constexpr char Name[] = "bench"; @@ -15,14 +44,14 @@ public: BenchCommand(); ~BenchCommand(); - virtual void Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) override; - virtual cxxopts::Options& Options() override { return m_Options; } - virtual ZenCmdCategory& CommandCategory() const override { return g_UtilitiesCategory; } + cxxopts::Options& Options() override { return m_Options; } + ZenCmdCategory& CommandCategory() const override { return g_UtilitiesCategory; } private: cxxopts::Options m_Options{Name, Description}; - bool m_PurgeStandbyLists = false; - bool m_SingleProcess = false; + std::string m_SubCommand; + BenchPurgeSubCmd m_PurgeSubCmd; + BenchHttpSubCmd m_HttpSubCmd; }; } // namespace zen diff --git a/src/zen/xmake.lua b/src/zen/xmake.lua index f889c3296..4c134404a 100644 --- a/src/zen/xmake.lua +++ b/src/zen/xmake.lua @@ -6,7 +6,7 @@ target("zen") add_files("**.cpp") add_files("zen.cpp", {unity_ignored = true }) add_deps("zencore", "zenhttp", "zenremotestore", "zenstore", "zenutil") - add_deps("zencompute", "zennet") + add_deps("zencompute", "zennet", "zentelemetry") add_deps("cxxopts", "fmt") add_packages("json11") add_includedirs(".") diff --git a/src/zen/zen.cpp b/src/zen/zen.cpp index 9a466da2e..86c29344e 100644 --- a/src/zen/zen.cpp +++ b/src/zen/zen.cpp @@ -196,6 +196,7 @@ ZenCmdBase::GetSubCommand(cxxopts::Options&, ZenSubCmdBase::ZenSubCmdBase(std::string_view Name, std::string_view Description) : m_SubOptions(std::string(Name), std::string(Description)) +, m_Description(Description) { m_SubOptions.add_options()("h,help", "Print help"); } @@ -213,6 +214,35 @@ ZenCmdWithSubCommands::OnParentOptionsParsed(const ZenCliOptions& /*GlobalOption } void +ZenCmdWithSubCommands::PrintHelp() +{ + // Show all option groups except the internal "__hidden__" group used to + // silently capture positional arguments. + std::vector<std::string> Groups = Options().groups(); + Groups.erase(std::remove(Groups.begin(), Groups.end(), std::string("__hidden__")), Groups.end()); + + Options().set_width(TuiConsoleColumns(80)); + printf("%s\n", Options().help(Groups).c_str()); + + // Append subcommand listing. + size_t MaxNameLen = 0; + for (ZenSubCmdBase* SubCmd : m_SubCommands) + { + MaxNameLen = std::max(MaxNameLen, SubCmd->SubOptions().program().size()); + } + + printf("subcommands:\n"); + for (ZenSubCmdBase* SubCmd : m_SubCommands) + { + printf(" %-*s %s\n", + static_cast<int>(MaxNameLen), + SubCmd->SubOptions().program().c_str(), + std::string(SubCmd->Description()).c_str()); + } + printf("\nFor global options run: zen --help\n"); +} + +void ZenCmdWithSubCommands::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) { std::vector<cxxopts::Options*> SubOptionPtrs; @@ -226,15 +256,47 @@ ZenCmdWithSubCommands::Run(const ZenCliOptions& GlobalOptions, int argc, char** std::vector<char*> SubCommandArguments; int ParentArgc = GetSubCommand(Options(), argc, argv, SubOptionPtrs, MatchedSubOption, SubCommandArguments); - if (!ParseOptions(Options(), ParentArgc, argv)) + // Intercept --help/-h in the parent arg range before calling ParseOptions so + // we can append subcommand information to the output. When a subcommand was + // found argv[ParentArgc-1] is the subcommand name itself, which we exclude. + int ParentArgEnd = (MatchedSubOption != nullptr) ? ParentArgc - 1 : ParentArgc; + for (int i = 1; i < ParentArgEnd; ++i) { - return; + std::string_view Arg(argv[i]); + if (Arg == "--help" || Arg == "-h") + { + PrintHelp(); + return; + } + } + + // Parse parent options. When a subcommand was matched we strip its name from + // the arg list so the parent parser does not see it as an unmatched positional. + if (MatchedSubOption != nullptr) + { + std::vector<char*> ParentArgs; + ParentArgs.reserve(static_cast<size_t>(ParentArgc - 1)); + ParentArgs.push_back(argv[0]); + std::copy(argv + 1, argv + ParentArgc - 1, std::back_inserter(ParentArgs)); + if (!ParseOptions(Options(), static_cast<int>(ParentArgs.size()), ParentArgs.data())) + { + return; + } + } + else + { + if (!ParseOptions(Options(), ParentArgc, argv)) + { + return; + } } if (MatchedSubOption == nullptr) { + PrintHelp(); + ExtendableStringBuilder<128> VerbList; - for (bool First = true; ZenSubCmdBase * SubCmd : m_SubCommands) + for (bool First = true; ZenSubCmdBase* SubCmd : m_SubCommands) { if (!First) { @@ -243,7 +305,7 @@ ZenCmdWithSubCommands::Run(const ZenCliOptions& GlobalOptions, int argc, char** VerbList.Append(SubCmd->SubOptions().program()); First = false; } - throw OptionParseException(fmt::format("No subcommand specified. Available subcommands: {}", VerbList.ToView()), Options().help()); + throw OptionParseException(fmt::format("No subcommand specified. Available subcommands: {}", VerbList.ToView()), {}); } ZenSubCmdBase* MatchedSubCmd = nullptr; @@ -621,6 +683,9 @@ main(int argc, char** argv) Options.add_options()("malloc", "Configure memory allocator subsystem", cxxopts::value(MemoryOptions)->default_value("mimalloc")); Options.add_options()("help", "Show command line help"); Options.add_options()("c, command", "Sub command", cxxopts::value<std::string>(SubCommand)); + Options.add_options()("httpclient", + "Select HTTP client implementation (e.g. 'curl', 'cpr')", + cxxopts::value<std::string>(GlobalOptions.HttpClientBackend)->default_value("cpr")); int CoreLimit = 0; @@ -783,6 +848,8 @@ main(int argc, char** argv) FreeCallstack(Callstack); }); + zen::SetDefaultHttpClientBackend(GlobalOptions.HttpClientBackend); + zen::MaximizeOpenFileCount(); ////////////////////////////////////////////////////////////////////////// diff --git a/src/zen/zen.h b/src/zen/zen.h index 06e5356a6..3cc06eea6 100644 --- a/src/zen/zen.h +++ b/src/zen/zen.h @@ -17,6 +17,8 @@ struct ZenCliOptions ZenLoggingConfig LoggingConfig; + std::string HttpClientBackend; // Choice of HTTP client implementation (e.g. "curl", "cpr") + // Arguments after " -- " on command line are passed through and not parsed std::string PassthroughCommandLine; std::string PassthroughArgs; @@ -86,10 +88,14 @@ public: ZenSubCmdBase(std::string_view Name, std::string_view Description); virtual ~ZenSubCmdBase() = default; cxxopts::Options& SubOptions() { return m_SubOptions; } + std::string_view Description() const { return m_Description; } virtual void Run(const ZenCliOptions& GlobalOptions) = 0; protected: cxxopts::Options m_SubOptions; + +private: + std::string m_Description; }; // Base for commands that host subcommands - handles all dispatch boilerplate @@ -101,6 +107,7 @@ public: protected: void AddSubCommand(ZenSubCmdBase& SubCmd); virtual bool OnParentOptionsParsed(const ZenCliOptions& GlobalOptions); + void PrintHelp(); private: std::vector<ZenSubCmdBase*> m_SubCommands; diff --git a/src/zenbase/include/zenbase/concepts.h b/src/zenbase/include/zenbase/concepts.h index d4a9d75e8..1da56cefe 100644 --- a/src/zenbase/include/zenbase/concepts.h +++ b/src/zenbase/include/zenbase/concepts.h @@ -4,6 +4,8 @@ #include <zenbase/zenbase.h> +#include <type_traits> + // At the time of writing only ver >= 13 of LLVM's libc++ has an implementation // of std::integral. Some platforms like Ubuntu and Mac OS are still on 12. #if defined(__cpp_lib_concepts) diff --git a/src/zencore/sentryintegration.cpp b/src/zencore/sentryintegration.cpp index 8d087e8c6..58b76783a 100644 --- a/src/zencore/sentryintegration.cpp +++ b/src/zencore/sentryintegration.cpp @@ -60,11 +60,12 @@ public: } try { - std::string Message = fmt::format("{}\n{}({})", Msg.GetPayload(), Msg.GetSource().Filename, Msg.GetSource().Line); - sentry_value_t Event = sentry_value_new_message_event( - /* level */ MapToSentryLevel[Msg.GetLevel()], - /* logger */ nullptr, - /* message */ Message.c_str()); + const char* Filename = Msg.GetSource().Filename ? Msg.GetSource().Filename : "<unknown>"; + std::string Message = fmt::format("{}\n{}({})", Msg.GetPayload(), Filename, Msg.GetSource().Line); + sentry_value_t Event = sentry_value_new_message_event( + /* level */ MapToSentryLevel[Msg.GetLevel()], + /* logger */ nullptr, + /* message */ Message.c_str()); sentry_event_value_add_stacktrace(Event, NULL, 0); sentry_capture_event(Event); } diff --git a/src/zencore/testing.cpp b/src/zencore/testing.cpp index 089e376bb..d7eb3b17d 100644 --- a/src/zencore/testing.cpp +++ b/src/zencore/testing.cpp @@ -11,17 +11,131 @@ #if ZEN_WITH_TESTS +# include <zencore/callstack.h> + # include <chrono> # include <clocale> +# include <csignal> # include <cstdlib> # include <cstdio> # include <string> # include <vector> +# if ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC +# include <execinfo.h> +# include <unistd.h> +# endif + namespace zen::testing { using namespace std::literals; +static void +PrintCrashCallstack([[maybe_unused]] const char* SignalName) +{ +# if ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC + // Use write() + backtrace_symbols_fd() which are async-signal-safe + write(STDERR_FILENO, "\n*** Caught ", 12); + write(STDERR_FILENO, SignalName, strlen(SignalName)); + write(STDERR_FILENO, " — callstack:\n", 15); + + void* Frames[64]; + int FrameCount = backtrace(Frames, 64); + backtrace_symbols_fd(Frames, FrameCount, STDERR_FILENO); +# elif ZEN_PLATFORM_WINDOWS + // On Windows we're called from SEH, not a signal handler, so heap/locks are safe + void* Addresses[64]; + uint32_t FrameCount = GetCallstack(2, 64, Addresses); + if (FrameCount > 0) + { + std::vector<std::string> Symbols = GetFrameSymbols(FrameCount, Addresses); + fprintf(stderr, "\n*** Caught %s - callstack:\n", SignalName); + for (uint32_t i = 0; i < FrameCount; ++i) + { + fprintf(stderr, " %2u: %s\n", i, Symbols[i].c_str()); + } + } +# endif +} + +# if ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC + +static void +CrashSignalHandler(int Signal) +{ + const char* SignalName = "Unknown signal"; + switch (Signal) + { + case SIGSEGV: + SignalName = "SIGSEGV"; + break; + case SIGABRT: + SignalName = "SIGABRT"; + break; + case SIGFPE: + SignalName = "SIGFPE"; + break; + case SIGBUS: + SignalName = "SIGBUS"; + break; + case SIGILL: + SignalName = "SIGILL"; + break; + } + + PrintCrashCallstack(SignalName); + + // Re-raise with default handler so the process terminates normally + signal(Signal, SIG_DFL); + raise(Signal); +} + +# endif // ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC + +# if ZEN_PLATFORM_WINDOWS + +static LONG CALLBACK +CrashVectoredHandler(PEXCEPTION_POINTERS ExceptionInfo) +{ + // Only handle fatal exceptions, not first-chance exceptions used for normal control flow + switch (ExceptionInfo->ExceptionRecord->ExceptionCode) + { + case EXCEPTION_ACCESS_VIOLATION: + PrintCrashCallstack("EXCEPTION_ACCESS_VIOLATION"); + break; + case EXCEPTION_STACK_OVERFLOW: + PrintCrashCallstack("EXCEPTION_STACK_OVERFLOW"); + break; + case EXCEPTION_ILLEGAL_INSTRUCTION: + PrintCrashCallstack("EXCEPTION_ILLEGAL_INSTRUCTION"); + break; + case EXCEPTION_INT_DIVIDE_BY_ZERO: + PrintCrashCallstack("EXCEPTION_INT_DIVIDE_BY_ZERO"); + break; + default: + break; + } + + // Continue search so doctest's handler can report the test case context + return EXCEPTION_CONTINUE_SEARCH; +} + +# endif // ZEN_PLATFORM_WINDOWS + +static void +InstallCrashSignalHandlers() +{ +# if ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC + signal(SIGSEGV, CrashSignalHandler); + signal(SIGABRT, CrashSignalHandler); + signal(SIGFPE, CrashSignalHandler); + signal(SIGBUS, CrashSignalHandler); + signal(SIGILL, CrashSignalHandler); +# elif ZEN_PLATFORM_WINDOWS + AddVectoredExceptionHandler(0 /*called last among vectored handlers*/, CrashVectoredHandler); +# endif +} + struct TestListener : public doctest::IReporter { const std::string_view ColorYellow = "\033[0;33m"sv; @@ -184,6 +298,7 @@ RunTestMain(int Argc, char* Argv[], const char* ExecutableName, void (*ForceLink zen::logging::InitializeLogging(); zen::MaximizeOpenFileCount(); + InstallCrashSignalHandlers(); TestRunner Runner; diff --git a/src/zenhttp-test/zenhttp-test.cpp b/src/zenhttp-test/zenhttp-test.cpp index b4b406ac8..0a6980462 100644 --- a/src/zenhttp-test/zenhttp-test.cpp +++ b/src/zenhttp-test/zenhttp-test.cpp @@ -1,6 +1,7 @@ // Copyright Epic Games, Inc. All Rights Reserved. #include <zencore/testing.h> +#include <zenhttp/httpclient.h> #include <zenhttp/zenhttp.h> #include <zencore/memory/newdelete.h> @@ -9,6 +10,17 @@ int main([[maybe_unused]] int argc, [[maybe_unused]] char* argv[]) { #if ZEN_WITH_TESTS + using namespace std::literals; + for (int i = 1; i < argc; ++i) + { + std::string_view Arg(argv[i]); + if (Arg.starts_with("--httpclient="sv)) + { + std::string_view Value = Arg.substr(13); + zen::SetDefaultHttpClientBackend(Value); + } + } + return zen::testing::RunTestMain(argc, argv, "zenhttp-test", zen::zenhttp_forcelinktests); #else return 0; diff --git a/src/zenhttp/clients/httpclientcommon.cpp b/src/zenhttp/clients/httpclientcommon.cpp index 6f4c67dd0..e4d11547a 100644 --- a/src/zenhttp/clients/httpclientcommon.cpp +++ b/src/zenhttp/clients/httpclientcommon.cpp @@ -646,6 +646,63 @@ TEST_CASE("CompositeBufferReadStream") CHECK_EQ(IoHash::HashBuffer(Data), testutil::HashComposite(Data)); } +TEST_CASE("ParseContentRange") +{ + SUBCASE("normal range with total size") + { + auto [Offset, Length] = detail::ParseContentRange("bytes 0-99/500"); + CHECK_EQ(Offset, 0); + CHECK_EQ(Length, 100); + } + + SUBCASE("non-zero offset") + { + auto [Offset, Length] = detail::ParseContentRange("bytes 2638-5111437/44369878"); + CHECK_EQ(Offset, 2638); + CHECK_EQ(Length, 5111437 - 2638 + 1); + } + + SUBCASE("wildcard total size") + { + auto [Offset, Length] = detail::ParseContentRange("bytes 100-199/*"); + CHECK_EQ(Offset, 100); + CHECK_EQ(Length, 100); + } + + SUBCASE("no slash (total size omitted)") + { + auto [Offset, Length] = detail::ParseContentRange("bytes 50-149"); + CHECK_EQ(Offset, 50); + CHECK_EQ(Length, 100); + } + + SUBCASE("malformed input returns zeros") + { + auto [Offset1, Length1] = detail::ParseContentRange("not-bytes 0-99/500"); + CHECK_EQ(Offset1, 0); + CHECK_EQ(Length1, 0); + + auto [Offset2, Length2] = detail::ParseContentRange("bytes abc-def/500"); + CHECK_EQ(Offset2, 0); + CHECK_EQ(Length2, 0); + + auto [Offset3, Length3] = detail::ParseContentRange(""); + CHECK_EQ(Offset3, 0); + CHECK_EQ(Length3, 0); + + auto [Offset4, Length4] = detail::ParseContentRange("bytes 100/500"); + CHECK_EQ(Offset4, 0); + CHECK_EQ(Length4, 0); + } + + SUBCASE("single byte range") + { + auto [Offset, Length] = detail::ParseContentRange("bytes 42-42/1000"); + CHECK_EQ(Offset, 42); + CHECK_EQ(Length, 1); + } +} + TEST_CASE("MultipartBoundaryParser") { uint64_t Range1Offset = 2638; diff --git a/src/zenhttp/clients/httpclientcpr.cpp b/src/zenhttp/clients/httpclientcpr.cpp index 14e40b02a..f3082e0a2 100644 --- a/src/zenhttp/clients/httpclientcpr.cpp +++ b/src/zenhttp/clients/httpclientcpr.cpp @@ -14,6 +14,11 @@ #include <zenhttp/packageformat.h> #include <algorithm> +ZEN_THIRD_PARTY_INCLUDES_START +#include <cpr/ssl_options.h> +#include <cpr/unix_socket.h> +ZEN_THIRD_PARTY_INCLUDES_END + namespace zen { HttpClientBase* @@ -24,84 +29,42 @@ CreateCprHttpClient(std::string_view BaseUri, const HttpClientSettings& Connecti static std::atomic<uint32_t> HttpClientRequestIdCounter{0}; -bool -HttpClient::ErrorContext::IsConnectionError() const +////////////////////////////////////////////////////////////////////////// + +static HttpClientErrorCode +MapCprError(cpr::ErrorCode Code) { - switch (static_cast<cpr::ErrorCode>(ErrorCode)) + switch (Code) { + case cpr::ErrorCode::OK: + return HttpClientErrorCode::kOK; case cpr::ErrorCode::CONNECTION_FAILURE: - case cpr::ErrorCode::OPERATION_TIMEDOUT: + return HttpClientErrorCode::kConnectionFailure; case cpr::ErrorCode::HOST_RESOLUTION_FAILURE: + return HttpClientErrorCode::kHostResolutionFailure; case cpr::ErrorCode::PROXY_RESOLUTION_FAILURE: - return true; + return HttpClientErrorCode::kProxyResolutionFailure; + case cpr::ErrorCode::INTERNAL_ERROR: + return HttpClientErrorCode::kInternalError; + case cpr::ErrorCode::NETWORK_RECEIVE_ERROR: + return HttpClientErrorCode::kNetworkReceiveError; + case cpr::ErrorCode::NETWORK_SEND_FAILURE: + return HttpClientErrorCode::kNetworkSendFailure; + case cpr::ErrorCode::OPERATION_TIMEDOUT: + return HttpClientErrorCode::kOperationTimedOut; + case cpr::ErrorCode::SSL_CONNECT_ERROR: + return HttpClientErrorCode::kSSLConnectError; + case cpr::ErrorCode::SSL_LOCAL_CERTIFICATE_ERROR: + case cpr::ErrorCode::SSL_REMOTE_CERTIFICATE_ERROR: + return HttpClientErrorCode::kSSLCertificateError; + case cpr::ErrorCode::SSL_CACERT_ERROR: + return HttpClientErrorCode::kSSLCACertError; + case cpr::ErrorCode::GENERIC_SSL_ERROR: + return HttpClientErrorCode::kGenericSSLError; + case cpr::ErrorCode::REQUEST_CANCELLED: + return HttpClientErrorCode::kRequestCancelled; default: - return false; - } -} - -// If we want to support different HTTP client implementations then we'll need to make this more abstract - -HttpClientError::ResponseClass -HttpClientError::GetResponseClass() const -{ - if ((cpr::ErrorCode)m_Error != cpr::ErrorCode::OK) - { - switch ((cpr::ErrorCode)m_Error) - { - case cpr::ErrorCode::CONNECTION_FAILURE: - return ResponseClass::kHttpCantConnectError; - case cpr::ErrorCode::HOST_RESOLUTION_FAILURE: - case cpr::ErrorCode::PROXY_RESOLUTION_FAILURE: - return ResponseClass::kHttpNoHost; - case cpr::ErrorCode::INTERNAL_ERROR: - case cpr::ErrorCode::NETWORK_RECEIVE_ERROR: - case cpr::ErrorCode::NETWORK_SEND_FAILURE: - case cpr::ErrorCode::OPERATION_TIMEDOUT: - return ResponseClass::kHttpTimeout; - case cpr::ErrorCode::SSL_CONNECT_ERROR: - case cpr::ErrorCode::SSL_LOCAL_CERTIFICATE_ERROR: - case cpr::ErrorCode::SSL_REMOTE_CERTIFICATE_ERROR: - case cpr::ErrorCode::SSL_CACERT_ERROR: - case cpr::ErrorCode::GENERIC_SSL_ERROR: - return ResponseClass::kHttpSLLError; - default: - return ResponseClass::kHttpOtherClientError; - } - } - else if (IsHttpSuccessCode(m_ResponseCode)) - { - return ResponseClass::kSuccess; - } - else - { - switch (m_ResponseCode) - { - case HttpResponseCode::Unauthorized: - return ResponseClass::kHttpUnauthorized; - case HttpResponseCode::NotFound: - return ResponseClass::kHttpNotFound; - case HttpResponseCode::Forbidden: - return ResponseClass::kHttpForbidden; - case HttpResponseCode::Conflict: - return ResponseClass::kHttpConflict; - case HttpResponseCode::InternalServerError: - return ResponseClass::kHttpInternalServerError; - case HttpResponseCode::ServiceUnavailable: - return ResponseClass::kHttpServiceUnavailable; - case HttpResponseCode::BadGateway: - return ResponseClass::kHttpBadGateway; - case HttpResponseCode::GatewayTimeout: - return ResponseClass::kHttpGatewayTimeout; - default: - if (m_ResponseCode >= HttpResponseCode::InternalServerError) - { - return ResponseClass::kHttpOtherServerError; - } - else - { - return ResponseClass::kHttpOtherClientError; - } - } + return HttpClientErrorCode::kOtherError; } } @@ -257,8 +220,8 @@ CprHttpClient::CommonResponse(std::string_view SessionId, .UploadedBytes = gsl::narrow<int64_t>(HttpResponse.uploaded_bytes), .DownloadedBytes = gsl::narrow<int64_t>(HttpResponse.downloaded_bytes), .ElapsedSeconds = HttpResponse.elapsed, - .Error = HttpClient::ErrorContext{.ErrorCode = gsl::narrow<int>(HttpResponse.error.code), - .ErrorMessage = HttpResponse.error.message}}; + .Error = + HttpClient::ErrorContext{.ErrorCode = MapCprError(HttpResponse.error.code), .ErrorMessage = HttpResponse.error.message}}; } if (WorkResponseCode == HttpResponseCode::NoContent || (HttpResponse.text.empty() && !Payload)) @@ -526,6 +489,10 @@ CprHttpClient::AllocSession(const std::string_view BaseUrl, { CprSession->UpdateHeader({{"UE-Session", std::string(SessionId)}}); } + if (ConnectionSettings.ForbidReuseConnection) + { + CprSession->UpdateHeader({{"Connection", "close"}}); + } if (AccessToken) { CprSession->UpdateHeader({{"Authorization", AccessToken->Value}}); @@ -544,6 +511,26 @@ CprHttpClient::AllocSession(const std::string_view BaseUrl, CprSession->SetParameters({}); } + if (!ConnectionSettings.UnixSocketPath.empty()) + { + CprSession->SetUnixSocket(cpr::UnixSocket(ConnectionSettings.UnixSocketPath)); + } + + if (ConnectionSettings.InsecureSsl || !ConnectionSettings.CaBundlePath.empty()) + { + cpr::SslOptions SslOpts; + if (ConnectionSettings.InsecureSsl) + { + SslOpts.SetOption(cpr::ssl::VerifyHost{false}); + SslOpts.SetOption(cpr::ssl::VerifyPeer{false}); + } + if (!ConnectionSettings.CaBundlePath.empty()) + { + SslOpts.SetOption(cpr::ssl::CaInfo{ConnectionSettings.CaBundlePath}); + } + CprSession->SetSslOptions(SslOpts); + } + ExtendableStringBuilder<128> UrlBuffer; UrlBuffer << BaseUrl << ResourcePath; CprSession->SetUrl(UrlBuffer.c_str()); diff --git a/src/zenhttp/clients/httpclientcurl.cpp b/src/zenhttp/clients/httpclientcurl.cpp new file mode 100644 index 000000000..3cb749018 --- /dev/null +++ b/src/zenhttp/clients/httpclientcurl.cpp @@ -0,0 +1,1947 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "httpclientcurl.h" + +#include <zencore/compactbinary.h> +#include <zencore/compactbinarybuilder.h> +#include <zencore/compactbinarypackage.h> +#include <zencore/compactbinaryutil.h> +#include <zencore/compress.h> +#include <zencore/iobuffer.h> +#include <zencore/iohash.h> +#include <zencore/session.h> +#include <zencore/stream.h> +#include <zencore/string.h> +#include <zenhttp/packageformat.h> +#include <algorithm> + +namespace zen { + +HttpClientBase* +CreateCurlHttpClient(std::string_view BaseUri, const HttpClientSettings& ConnectionSettings, std::function<bool()>&& CheckIfAbortFunction) +{ + return new CurlHttpClient(BaseUri, ConnectionSettings, std::move(CheckIfAbortFunction)); +} + +static std::atomic<uint32_t> CurlHttpClientRequestIdCounter{0}; + +////////////////////////////////////////////////////////////////////////// + +static HttpClientErrorCode +MapCurlError(CURLcode Code) +{ + switch (Code) + { + case CURLE_OK: + return HttpClientErrorCode::kOK; + case CURLE_COULDNT_CONNECT: + return HttpClientErrorCode::kConnectionFailure; + case CURLE_COULDNT_RESOLVE_HOST: + return HttpClientErrorCode::kHostResolutionFailure; + case CURLE_COULDNT_RESOLVE_PROXY: + return HttpClientErrorCode::kProxyResolutionFailure; + case CURLE_RECV_ERROR: + return HttpClientErrorCode::kNetworkReceiveError; + case CURLE_SEND_ERROR: + return HttpClientErrorCode::kNetworkSendFailure; + case CURLE_OPERATION_TIMEDOUT: + return HttpClientErrorCode::kOperationTimedOut; + case CURLE_SSL_CONNECT_ERROR: + return HttpClientErrorCode::kSSLConnectError; + case CURLE_SSL_CERTPROBLEM: + return HttpClientErrorCode::kSSLCertificateError; + case CURLE_PEER_FAILED_VERIFICATION: + return HttpClientErrorCode::kSSLCACertError; + case CURLE_SSL_CIPHER: + case CURLE_SSL_ENGINE_NOTFOUND: + case CURLE_SSL_ENGINE_SETFAILED: + return HttpClientErrorCode::kGenericSSLError; + case CURLE_ABORTED_BY_CALLBACK: + return HttpClientErrorCode::kRequestCancelled; + default: + return HttpClientErrorCode::kOtherError; + } +} + +////////////////////////////////////////////////////////////////////////// +// +// Curl callback helpers + +struct WriteCallbackData +{ + std::string* Body = nullptr; + std::function<bool()>* CheckIfAbortFunction = nullptr; +}; + +static size_t +CurlWriteCallback(char* Ptr, size_t Size, size_t Nmemb, void* UserData) +{ + auto* Data = static_cast<WriteCallbackData*>(UserData); + size_t TotalBytes = Size * Nmemb; + + if (Data->CheckIfAbortFunction && *Data->CheckIfAbortFunction && (*Data->CheckIfAbortFunction)()) + { + return 0; // Signal abort to curl + } + + Data->Body->append(Ptr, TotalBytes); + return TotalBytes; +} + +struct HeaderCallbackData +{ + std::vector<std::pair<std::string, std::string>>* Headers = nullptr; +}; + +static size_t +CurlHeaderCallback(char* Buffer, size_t Size, size_t Nmemb, void* UserData) +{ + auto* Data = static_cast<HeaderCallbackData*>(UserData); + size_t TotalBytes = Size * Nmemb; + + std::string_view Line(Buffer, TotalBytes); + + // Trim trailing \r\n + while (!Line.empty() && (Line.back() == '\r' || Line.back() == '\n')) + { + Line.remove_suffix(1); + } + + if (Line.empty()) + { + return TotalBytes; + } + + size_t ColonPos = Line.find(':'); + if (ColonPos != std::string_view::npos) + { + std::string_view Key = Line.substr(0, ColonPos); + std::string_view Value = Line.substr(ColonPos + 1); + + // Trim whitespace + while (!Key.empty() && Key.back() == ' ') + { + Key.remove_suffix(1); + } + while (!Value.empty() && Value.front() == ' ') + { + Value.remove_prefix(1); + } + + Data->Headers->emplace_back(std::string(Key), std::string(Value)); + } + + return TotalBytes; +} + +struct ReadCallbackData +{ + const uint8_t* DataPtr = nullptr; + size_t DataSize = 0; + size_t Offset = 0; + std::function<bool()>* CheckIfAbortFunction = nullptr; +}; + +static size_t +CurlReadCallback(char* Buffer, size_t Size, size_t Nmemb, void* UserData) +{ + auto* Data = static_cast<ReadCallbackData*>(UserData); + size_t MaxRead = Size * Nmemb; + + if (Data->CheckIfAbortFunction && *Data->CheckIfAbortFunction && (*Data->CheckIfAbortFunction)()) + { + return CURL_READFUNC_ABORT; + } + + size_t Remaining = Data->DataSize - Data->Offset; + size_t ToRead = std::min(MaxRead, Remaining); + + if (ToRead > 0) + { + memcpy(Buffer, Data->DataPtr + Data->Offset, ToRead); + Data->Offset += ToRead; + } + + return ToRead; +} + +struct StreamReadCallbackData +{ + detail::CompositeBufferReadStream* Reader = nullptr; + std::function<bool()>* CheckIfAbortFunction = nullptr; +}; + +static size_t +CurlStreamReadCallback(char* Buffer, size_t Size, size_t Nmemb, void* UserData) +{ + auto* Data = static_cast<StreamReadCallbackData*>(UserData); + size_t MaxRead = Size * Nmemb; + + if (Data->CheckIfAbortFunction && *Data->CheckIfAbortFunction && (*Data->CheckIfAbortFunction)()) + { + return CURL_READFUNC_ABORT; + } + + return Data->Reader->Read(Buffer, MaxRead); +} + +struct FileReadCallbackData +{ + detail::BufferedReadFileStream* Buffer = nullptr; + uint64_t TotalSize = 0; + uint64_t Offset = 0; + std::function<bool()>* CheckIfAbortFunction = nullptr; +}; + +static size_t +CurlFileReadCallback(char* Buffer, size_t Size, size_t Nmemb, void* UserData) +{ + auto* Data = static_cast<FileReadCallbackData*>(UserData); + size_t MaxRead = Size * Nmemb; + + if (Data->CheckIfAbortFunction && *Data->CheckIfAbortFunction && (*Data->CheckIfAbortFunction)()) + { + return CURL_READFUNC_ABORT; + } + + size_t Remaining = Data->TotalSize - Data->Offset; + size_t ToRead = std::min(MaxRead, Remaining); + + if (ToRead > 0) + { + Data->Buffer->Read(Buffer, ToRead); + Data->Offset += ToRead; + } + + return ToRead; +} + +static int +CurlDebugCallback(CURL* Handle, curl_infotype Type, char* Data, size_t Size, void* UserPtr) +{ + ZEN_UNUSED(Handle); + LoggerRef LogRef = *static_cast<LoggerRef*>(UserPtr); + auto Log = [&]() -> LoggerRef { return LogRef; }; + + std::string_view DataView(Data, Size); + + // Trim trailing newlines + while (!DataView.empty() && (DataView.back() == '\r' || DataView.back() == '\n')) + { + DataView.remove_suffix(1); + } + + switch (Type) + { + case CURLINFO_TEXT: + if (DataView.find("need more data"sv) == std::string_view::npos) + { + ZEN_INFO("TEXT: {}", DataView); + } + break; + case CURLINFO_HEADER_IN: + ZEN_INFO("HIN : {}", DataView); + break; + case CURLINFO_HEADER_OUT: + if (auto TokenPos = DataView.find("Authorization: Bearer "sv); TokenPos != std::string_view::npos) + { + std::string Copy(DataView); + auto BearerStart = TokenPos + 22; + auto BearerEnd = Copy.find_first_of("\r\n", BearerStart); + if (BearerEnd == std::string::npos) + { + BearerEnd = Copy.length(); + } + Copy.replace(Copy.begin() + BearerStart, Copy.begin() + BearerEnd, fmt::format("[{} char token]", BearerEnd - BearerStart)); + ZEN_INFO("HOUT: {}", Copy); + } + else + { + ZEN_INFO("HOUT: {}", DataView); + } + break; + default: + break; + } + + return 0; +} + +////////////////////////////////////////////////////////////////////////// + +static std::pair<std::string, std::string> +HeaderContentType(ZenContentType ContentType) +{ + return std::make_pair("Content-Type", std::string(MapContentTypeToString(ContentType))); +} + +static curl_slist* +BuildHeaderList(const HttpClient::KeyValueMap& AdditionalHeader, + std::string_view SessionId, + const std::optional<HttpClientAccessToken>& AccessToken, + const std::vector<std::pair<std::string, std::string>>& ExtraHeaders = {}) +{ + curl_slist* Headers = nullptr; + + for (const auto& [Key, Value] : *AdditionalHeader) + { + std::string HeaderLine = fmt::format("{}: {}", Key, Value); + Headers = curl_slist_append(Headers, HeaderLine.c_str()); + } + + if (!SessionId.empty()) + { + std::string SessionHeader = fmt::format("UE-Session: {}", SessionId); + Headers = curl_slist_append(Headers, SessionHeader.c_str()); + } + + if (AccessToken) + { + std::string AuthHeader = fmt::format("Authorization: {}", AccessToken->Value); + Headers = curl_slist_append(Headers, AuthHeader.c_str()); + } + + for (const auto& [Key, Value] : ExtraHeaders) + { + std::string HeaderLine = fmt::format("{}: {}", Key, Value); + Headers = curl_slist_append(Headers, HeaderLine.c_str()); + } + + return Headers; +} + +static std::string +BuildUrlWithParameters(std::string_view BaseUrl, std::string_view ResourcePath, const HttpClient::KeyValueMap& Parameters) +{ + std::string Url; + Url.reserve(BaseUrl.size() + ResourcePath.size() + 64); + Url.append(BaseUrl); + Url.append(ResourcePath); + + if (!Parameters->empty()) + { + char Separator = '?'; + for (const auto& [Key, Value] : *Parameters) + { + char* EncodedKey = curl_easy_escape(nullptr, Key.c_str(), static_cast<int>(Key.size())); + char* EncodedValue = curl_easy_escape(nullptr, Value.c_str(), static_cast<int>(Value.size())); + Url += Separator; + Url += EncodedKey; + Url += '='; + Url += EncodedValue; + curl_free(EncodedKey); + curl_free(EncodedValue); + Separator = '&'; + } + } + + return Url; +} + +////////////////////////////////////////////////////////////////////////// + +CurlHttpClient::CurlHttpClient(std::string_view BaseUri, + const HttpClientSettings& ConnectionSettings, + std::function<bool()>&& CheckIfAbortFunction) +: HttpClientBase(BaseUri, ConnectionSettings, std::move(CheckIfAbortFunction)) +{ +} + +CurlHttpClient::~CurlHttpClient() +{ + ZEN_TRACE_CPU("CurlHttpClient::~CurlHttpClient"); + m_SessionLock.WithExclusiveLock([&] { + for (auto* Handle : m_Sessions) + { + curl_easy_cleanup(Handle); + } + m_Sessions.clear(); + }); +} + +CurlHttpClient::CurlResult +CurlHttpClient::Session::Perform() +{ + CurlResult Result; + + char ErrorBuffer[CURL_ERROR_SIZE] = {}; + curl_easy_setopt(Handle, CURLOPT_ERRORBUFFER, ErrorBuffer); + + Result.ErrorCode = curl_easy_perform(Handle); + + if (Result.ErrorCode != CURLE_OK) + { + Result.ErrorMessage = ErrorBuffer[0] ? std::string(ErrorBuffer) : curl_easy_strerror(Result.ErrorCode); + } + + curl_easy_getinfo(Handle, CURLINFO_RESPONSE_CODE, &Result.StatusCode); + + double Elapsed = 0; + curl_easy_getinfo(Handle, CURLINFO_TOTAL_TIME, &Elapsed); + Result.ElapsedSeconds = Elapsed; + + curl_off_t UpBytes = 0; + curl_easy_getinfo(Handle, CURLINFO_SIZE_UPLOAD_T, &UpBytes); + Result.UploadedBytes = static_cast<int64_t>(UpBytes); + + curl_off_t DownBytes = 0; + curl_easy_getinfo(Handle, CURLINFO_SIZE_DOWNLOAD_T, &DownBytes); + Result.DownloadedBytes = static_cast<int64_t>(DownBytes); + + return Result; +} + +bool +CurlHttpClient::ShouldLogErrorCode(HttpResponseCode ResponseCode) const +{ + if (m_CheckIfAbortFunction && m_CheckIfAbortFunction()) + { + return false; + } + const auto& Expected = m_ConnectionSettings.ExpectedErrorCodes; + return std::find(Expected.begin(), Expected.end(), ResponseCode) == Expected.end(); +} + +HttpClient::Response +CurlHttpClient::ResponseWithPayload(std::string_view SessionId, + CurlResult&& Result, + const HttpResponseCode WorkResponseCode, + IoBuffer&& Payload, + std::vector<HttpClient::Response::MultipartBoundary>&& BoundaryPositions) +{ + IoBuffer ResponseBuffer = Payload ? std::move(Payload) : IoBuffer(IoBuffer::Clone, Result.Body.data(), Result.Body.size()); + + for (const auto& [Key, Value] : Result.Headers) + { + if (Key == "Content-Type") + { + const HttpContentType ContentType = ParseContentType(Value); + ResponseBuffer.SetContentType(ContentType); + break; + } + } + + if (!IsHttpSuccessCode(WorkResponseCode) && WorkResponseCode != HttpResponseCode::NotFound) + { + if (ShouldLogErrorCode(WorkResponseCode)) + { + ZEN_WARN("HttpClient request failed (session: {}): status={}, url={}", + SessionId, + static_cast<int>(WorkResponseCode), + m_BaseUri); + } + } + + std::sort(BoundaryPositions.begin(), + BoundaryPositions.end(), + [](const HttpClient::Response::MultipartBoundary& Lhs, const HttpClient::Response::MultipartBoundary& Rhs) { + return Lhs.RangeOffset < Rhs.RangeOffset; + }); + + HttpClient::KeyValueMap HeaderMap; + for (const auto& [Key, Value] : Result.Headers) + { + HeaderMap->insert_or_assign(Key, Value); + } + + return HttpClient::Response{.StatusCode = WorkResponseCode, + .ResponsePayload = std::move(ResponseBuffer), + .Header = std::move(HeaderMap), + .UploadedBytes = Result.UploadedBytes, + .DownloadedBytes = Result.DownloadedBytes, + .ElapsedSeconds = Result.ElapsedSeconds, + .Ranges = std::move(BoundaryPositions)}; +} + +HttpClient::Response +CurlHttpClient::CommonResponse(std::string_view SessionId, + CurlResult&& Result, + IoBuffer&& Payload, + std::vector<HttpClient::Response::MultipartBoundary>&& BoundaryPositions) +{ + const HttpResponseCode WorkResponseCode = HttpResponseCode(Result.StatusCode); + if (Result.ErrorCode != CURLE_OK) + { + const bool Quiet = m_CheckIfAbortFunction && m_CheckIfAbortFunction(); + if (!Quiet) + { + if (Result.ErrorCode != CURLE_OPERATION_TIMEDOUT && Result.ErrorCode != CURLE_COULDNT_CONNECT && + Result.ErrorCode != CURLE_ABORTED_BY_CALLBACK) + { + ZEN_WARN("HttpClient client failure (session: {}): ({}) '{}'", + SessionId, + static_cast<int>(Result.ErrorCode), + Result.ErrorMessage); + } + } + + HttpClient::KeyValueMap HeaderMap; + for (const auto& [Key, Value] : Result.Headers) + { + HeaderMap->insert_or_assign(Key, Value); + } + + return HttpClient::Response{ + .StatusCode = WorkResponseCode, + .ResponsePayload = IoBufferBuilder::MakeCloneFromMemory(Result.Body.data(), Result.Body.size()), + .Header = std::move(HeaderMap), + .UploadedBytes = Result.UploadedBytes, + .DownloadedBytes = Result.DownloadedBytes, + .ElapsedSeconds = Result.ElapsedSeconds, + .Error = HttpClient::ErrorContext{.ErrorCode = MapCurlError(Result.ErrorCode), .ErrorMessage = Result.ErrorMessage}}; + } + + if (WorkResponseCode == HttpResponseCode::NoContent || (Result.Body.empty() && !Payload)) + { + HttpClient::KeyValueMap HeaderMap; + for (const auto& [Key, Value] : Result.Headers) + { + HeaderMap->insert_or_assign(Key, Value); + } + + return HttpClient::Response{.StatusCode = WorkResponseCode, + .Header = std::move(HeaderMap), + .UploadedBytes = Result.UploadedBytes, + .DownloadedBytes = Result.DownloadedBytes, + .ElapsedSeconds = Result.ElapsedSeconds}; + } + else + { + return ResponseWithPayload(SessionId, std::move(Result), WorkResponseCode, std::move(Payload), std::move(BoundaryPositions)); + } +} + +bool +CurlHttpClient::ValidatePayload(CurlResult& Result, std::unique_ptr<detail::TempPayloadFile>& PayloadFile) +{ + ZEN_TRACE_CPU("ValidatePayload"); + + IoBuffer ResponseBuffer = (Result.Body.empty() && PayloadFile) ? PayloadFile->BorrowIoBuffer() + : IoBuffer(IoBuffer::Wrap, Result.Body.data(), Result.Body.size()); + + // Find Content-Length in headers + for (const auto& [Key, Value] : Result.Headers) + { + if (Key == "Content-Length") + { + std::optional<uint64_t> ExpectedContentSize = ParseInt<uint64_t>(Value); + if (!ExpectedContentSize.has_value()) + { + Result.ErrorCode = CURLE_RECV_ERROR; + Result.ErrorMessage = fmt::format("Can not parse Content-Length header. Value: '{}'", Value); + return false; + } + if (ExpectedContentSize.value() != ResponseBuffer.GetSize()) + { + Result.ErrorCode = CURLE_RECV_ERROR; + Result.ErrorMessage = fmt::format("Payload size {} does not match Content-Length {}", ResponseBuffer.GetSize(), Value); + return false; + } + break; + } + } + + if (Result.StatusCode == static_cast<long>(HttpResponseCode::PartialContent)) + { + return true; + } + + // Check X-Jupiter-IoHash + for (const auto& [Key, Value] : Result.Headers) + { + if (Key == "X-Jupiter-IoHash") + { + IoHash ExpectedPayloadHash; + if (IoHash::TryParse(Value, ExpectedPayloadHash)) + { + IoHash PayloadHash = IoHash::HashBuffer(ResponseBuffer); + if (PayloadHash != ExpectedPayloadHash) + { + Result.ErrorCode = CURLE_RECV_ERROR; + Result.ErrorMessage = fmt::format("Payload hash {} does not match X-Jupiter-IoHash {}", + PayloadHash.ToHexString(), + ExpectedPayloadHash.ToHexString()); + return false; + } + } + break; + } + } + + // Validate content-type specific payload + for (const auto& [Key, Value] : Result.Headers) + { + if (Key == "Content-Type") + { + if (Value == "application/x-ue-comp") + { + IoHash RawHash; + uint64_t RawSize; + if (CompressedBuffer::ValidateCompressedHeader(ResponseBuffer, + RawHash, + RawSize, + /*OutOptionalTotalCompressedSize*/ nullptr)) + { + return true; + } + else + { + Result.ErrorCode = CURLE_RECV_ERROR; + Result.ErrorMessage = "Compressed binary failed validation"; + return false; + } + } + if (Value == "application/x-ue-cb") + { + if (CbValidateError Error = ValidateCompactBinary(ResponseBuffer.GetView(), CbValidateMode::Default); + Error == CbValidateError::None) + { + return true; + } + else + { + Result.ErrorCode = CURLE_RECV_ERROR; + Result.ErrorMessage = fmt::format("Compact binary failed validation: {}", ToString(Error)); + return false; + } + } + break; + } + } + + return true; +} + +bool +CurlHttpClient::ShouldRetry(const CurlResult& Result) +{ + switch (Result.ErrorCode) + { + case CURLE_OK: + break; + case CURLE_RECV_ERROR: + case CURLE_SEND_ERROR: + case CURLE_OPERATION_TIMEDOUT: + return true; + default: + return false; + } + switch (static_cast<HttpResponseCode>(Result.StatusCode)) + { + case HttpResponseCode::RequestTimeout: + case HttpResponseCode::TooManyRequests: + case HttpResponseCode::InternalServerError: + case HttpResponseCode::BadGateway: + case HttpResponseCode::ServiceUnavailable: + case HttpResponseCode::GatewayTimeout: + return true; + default: + return false; + } +} + +CurlHttpClient::CurlResult +CurlHttpClient::DoWithRetry(std::string_view SessionId, std::function<CurlResult()>&& Func, std::function<bool(CurlResult&)>&& Validate) +{ + uint8_t Attempt = 0; + CurlResult Result = Func(); + while (Attempt < m_ConnectionSettings.RetryCount) + { + if (m_CheckIfAbortFunction && m_CheckIfAbortFunction()) + { + return Result; + } + if (!ShouldRetry(Result)) + { + if (Result.ErrorCode != CURLE_OK || !IsHttpSuccessCode(Result.StatusCode)) + { + break; + } + if (Validate(Result)) + { + break; + } + } + Sleep(100 * (Attempt + 1)); + Attempt++; + if (ShouldLogErrorCode(HttpResponseCode(Result.StatusCode))) + { + ZEN_INFO("{} Attempt {}/{}", + CommonResponse(SessionId, std::move(Result), {}).ErrorMessage("Retry"), + Attempt, + m_ConnectionSettings.RetryCount + 1); + } + Result = Func(); + } + return Result; +} + +CurlHttpClient::CurlResult +CurlHttpClient::DoWithRetry(std::string_view SessionId, + std::function<CurlResult()>&& Func, + std::unique_ptr<detail::TempPayloadFile>& PayloadFile) +{ + uint8_t Attempt = 0; + CurlResult Result = Func(); + while (Attempt < m_ConnectionSettings.RetryCount) + { + if (m_CheckIfAbortFunction && m_CheckIfAbortFunction()) + { + return Result; + } + if (!ShouldRetry(Result)) + { + if (Result.ErrorCode != CURLE_OK || !IsHttpSuccessCode(Result.StatusCode)) + { + break; + } + if (ValidatePayload(Result, PayloadFile)) + { + break; + } + } + Sleep(100 * (Attempt + 1)); + Attempt++; + if (ShouldLogErrorCode(HttpResponseCode(Result.StatusCode))) + { + ZEN_INFO("{} Attempt {}/{}", + CommonResponse(SessionId, std::move(Result), {}).ErrorMessage("Retry"), + Attempt, + m_ConnectionSettings.RetryCount + 1); + } + Result = Func(); + } + return Result; +} + +////////////////////////////////////////////////////////////////////////// + +CurlHttpClient::Session +CurlHttpClient::AllocSession(std::string_view BaseUrl, + std::string_view ResourcePath, + const HttpClientSettings& ConnectionSettings, + const KeyValueMap& AdditionalHeader, + const KeyValueMap& Parameters, + std::string_view SessionId, + std::optional<HttpClientAccessToken> AccessToken) +{ + ZEN_UNUSED(AccessToken, SessionId, AdditionalHeader); + ZEN_TRACE_CPU("CurlHttpClient::AllocSession"); + CURL* Handle = nullptr; + m_SessionLock.WithExclusiveLock([&] { + if (!m_Sessions.empty()) + { + Handle = m_Sessions.back(); + m_Sessions.pop_back(); + } + }); + + if (Handle == nullptr) + { + Handle = curl_easy_init(); + } + else + { + curl_easy_reset(Handle); + } + + // Unix domain socket + if (!ConnectionSettings.UnixSocketPath.empty()) + { + curl_easy_setopt(Handle, CURLOPT_UNIX_SOCKET_PATH, ConnectionSettings.UnixSocketPath.c_str()); + } + + // Build URL with parameters + std::string Url = BuildUrlWithParameters(BaseUrl, ResourcePath, Parameters); + curl_easy_setopt(Handle, CURLOPT_URL, Url.c_str()); + + // Timeouts + if (ConnectionSettings.ConnectTimeout.count() > 0) + { + curl_easy_setopt(Handle, CURLOPT_CONNECTTIMEOUT_MS, static_cast<long>(ConnectionSettings.ConnectTimeout.count())); + } + if (ConnectionSettings.Timeout.count() > 0) + { + curl_easy_setopt(Handle, CURLOPT_TIMEOUT_MS, static_cast<long>(ConnectionSettings.Timeout.count())); + } + + // HTTP/2 + if (ConnectionSettings.AssumeHttp2) + { + curl_easy_setopt(Handle, CURLOPT_HTTP_VERSION, CURL_HTTP_VERSION_2_PRIOR_KNOWLEDGE); + } + + // Verbose/debug + if (ConnectionSettings.Verbose) + { + curl_easy_setopt(Handle, CURLOPT_VERBOSE, 1L); + curl_easy_setopt(Handle, CURLOPT_DEBUGFUNCTION, CurlDebugCallback); + curl_easy_setopt(Handle, CURLOPT_DEBUGDATA, &m_Log); + } + + // SSL options + if (ConnectionSettings.InsecureSsl) + { + curl_easy_setopt(Handle, CURLOPT_SSL_VERIFYPEER, 0L); + curl_easy_setopt(Handle, CURLOPT_SSL_VERIFYHOST, 0L); + } + if (!ConnectionSettings.CaBundlePath.empty()) + { + curl_easy_setopt(Handle, CURLOPT_CAINFO, ConnectionSettings.CaBundlePath.c_str()); + } + + // Disable signal handling for thread safety + curl_easy_setopt(Handle, CURLOPT_NOSIGNAL, 1L); + + if (ConnectionSettings.ForbidReuseConnection) + { + curl_easy_setopt(Handle, CURLOPT_FORBID_REUSE, 1L); + } + + // Note: Headers are NOT set here. Each method builds its own header list + // (potentially adding method-specific headers like Content-Type) and is + // responsible for freeing it with curl_slist_free_all. + + return Session(this, Handle); +} + +void +CurlHttpClient::ReleaseSession(CURL* Handle) +{ + ZEN_TRACE_CPU("CurlHttpClient::ReleaseSession"); + + // Free any header list that was set + // curl_easy_reset will be called on next AllocSession, which cleans up the handle state. + // We just push the handle back to the pool. + m_SessionLock.WithExclusiveLock([&] { m_Sessions.push_back(Handle); }); +} + +////////////////////////////////////////////////////////////////////////// + +CurlHttpClient::Response +CurlHttpClient::TransactPackage(std::string_view Url, CbPackage Package, const KeyValueMap& AdditionalHeader) +{ + ZEN_TRACE_CPU("CurlHttpClient::TransactPackage"); + + // First, list of offered chunks for filtering on the server end + + std::vector<IoHash> AttachmentsToSend; + std::span<const CbAttachment> Attachments = Package.GetAttachments(); + + const uint32_t RequestId = ++CurlHttpClientRequestIdCounter; + auto RequestIdString = fmt::to_string(RequestId); + + if (Attachments.empty() == false) + { + CbObjectWriter Writer; + Writer.BeginArray("offer"); + + for (const CbAttachment& Attachment : Attachments) + { + Writer.AddHash(Attachment.GetHash()); + } + + Writer.EndArray(); + + BinaryWriter MemWriter; + Writer.Save(MemWriter); + + std::vector<std::pair<std::string, std::string>> OfferExtraHeaders; + OfferExtraHeaders.emplace_back(HeaderContentType(HttpContentType::kCbPackageOffer)); + OfferExtraHeaders.emplace_back("UE-Request", RequestIdString); + + Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); + CURL* H = Sess.Get(); + + curl_slist* HeaderList = BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), OfferExtraHeaders); + curl_easy_setopt(H, CURLOPT_HTTPHEADER, HeaderList); + curl_easy_setopt(H, CURLOPT_POST, 1L); + curl_easy_setopt(H, CURLOPT_POSTFIELDS, reinterpret_cast<const char*>(MemWriter.Data())); + curl_easy_setopt(H, CURLOPT_POSTFIELDSIZE_LARGE, static_cast<curl_off_t>(MemWriter.Size())); + + std::string FilterBody; + WriteCallbackData WriteData{.Body = &FilterBody}; + curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, CurlWriteCallback); + curl_easy_setopt(H, CURLOPT_WRITEDATA, &WriteData); + + CurlResult Result = Sess.Perform(); + + curl_slist_free_all(HeaderList); + + if (Result.ErrorCode == CURLE_OK && Result.StatusCode == 200) + { + IoBuffer ResponseBuffer(IoBuffer::Wrap, FilterBody.data(), FilterBody.size()); + CbValidateError ValidationError = CbValidateError::None; + if (CbObject ResponseObject = ValidateAndReadCompactBinaryObject(std::move(ResponseBuffer), ValidationError); + ValidationError == CbValidateError::None) + { + for (CbFieldView& Entry : ResponseObject["need"]) + { + ZEN_ASSERT(Entry.IsHash()); + AttachmentsToSend.push_back(Entry.AsHash()); + } + } + } + } + + // Prepare package for send + + CbPackage SendPackage; + SendPackage.SetObject(Package.GetObject(), Package.GetObjectHash()); + + for (const IoHash& AttachmentCid : AttachmentsToSend) + { + const CbAttachment* Attachment = Package.FindAttachment(AttachmentCid); + + if (Attachment) + { + SendPackage.AddAttachment(*Attachment); + } + } + + // Transmit package payload + + CompositeBuffer Message = FormatPackageMessageBuffer(SendPackage); + SharedBuffer FlatMessage = Message.Flatten(); + + std::vector<std::pair<std::string, std::string>> PkgExtraHeaders; + PkgExtraHeaders.emplace_back(HeaderContentType(HttpContentType::kCbPackage)); + PkgExtraHeaders.emplace_back("UE-Request", RequestIdString); + + Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); + CURL* H = Sess.Get(); + + curl_slist* HeaderList = BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), PkgExtraHeaders); + curl_easy_setopt(H, CURLOPT_HTTPHEADER, HeaderList); + curl_easy_setopt(H, CURLOPT_POST, 1L); + curl_easy_setopt(H, CURLOPT_POSTFIELDS, reinterpret_cast<const char*>(FlatMessage.GetData())); + curl_easy_setopt(H, CURLOPT_POSTFIELDSIZE_LARGE, static_cast<curl_off_t>(FlatMessage.GetSize())); + + std::string PkgBody; + WriteCallbackData WriteData{.Body = &PkgBody}; + curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, CurlWriteCallback); + curl_easy_setopt(H, CURLOPT_WRITEDATA, &WriteData); + + CurlResult Result = Sess.Perform(); + + curl_slist_free_all(HeaderList); + + if (Result.ErrorCode != CURLE_OK || !IsHttpSuccessCode(Result.StatusCode)) + { + return {.StatusCode = HttpResponseCode(Result.StatusCode)}; + } + + IoBuffer ResponseBuffer(IoBuffer::Clone, PkgBody.data(), PkgBody.size()); + + return {.StatusCode = HttpResponseCode(Result.StatusCode), .ResponsePayload = ResponseBuffer}; +} + +////////////////////////////////////////////////////////////////////////// +// +// Standard HTTP verbs +// + +CurlHttpClient::Response +CurlHttpClient::Put(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader) +{ + ZEN_TRACE_CPU("CurlHttpClient::Put"); + + return CommonResponse( + m_SessionId, + DoWithRetry(m_SessionId, + [&]() -> CurlResult { + Session Sess = + AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); + CURL* H = Sess.Get(); + + curl_slist* Headers = + BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), {HeaderContentType(Payload.GetContentType())}); + curl_easy_setopt(H, CURLOPT_HTTPHEADER, Headers); + + curl_easy_setopt(H, CURLOPT_UPLOAD, 1L); + curl_easy_setopt(H, CURLOPT_INFILESIZE_LARGE, static_cast<curl_off_t>(Payload.GetSize())); + + ReadCallbackData ReadData{.DataPtr = static_cast<const uint8_t*>(Payload.GetData()), + .DataSize = Payload.GetSize(), + .CheckIfAbortFunction = m_CheckIfAbortFunction ? &m_CheckIfAbortFunction : nullptr}; + curl_easy_setopt(H, CURLOPT_READFUNCTION, CurlReadCallback); + curl_easy_setopt(H, CURLOPT_READDATA, &ReadData); + + std::string Body; + WriteCallbackData WriteData{.Body = &Body}; + HeaderCallbackData HdrData{}; + std::vector<std::pair<std::string, std::string>> ResponseHeaders; + HdrData.Headers = &ResponseHeaders; + + curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, CurlWriteCallback); + curl_easy_setopt(H, CURLOPT_WRITEDATA, &WriteData); + curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, CurlHeaderCallback); + curl_easy_setopt(H, CURLOPT_HEADERDATA, &HdrData); + + CurlResult Result = Sess.Perform(); + Result.Body = std::move(Body); + Result.Headers = std::move(ResponseHeaders); + + curl_slist_free_all(Headers); + + return Result; + }), + {}); +} + +CurlHttpClient::Response +CurlHttpClient::Put(std::string_view Url, const KeyValueMap& Parameters) +{ + ZEN_TRACE_CPU("CurlHttpClient::Put"); + + return CommonResponse( + m_SessionId, + DoWithRetry( + m_SessionId, + [&]() -> CurlResult { + KeyValueMap HeaderWithContentLength{std::pair<std::string_view, std::string_view>{"Content-Length", "0"}}; + Session Sess = + AllocSession(m_BaseUri, Url, m_ConnectionSettings, HeaderWithContentLength, Parameters, m_SessionId, GetAccessToken()); + CURL* H = Sess.Get(); + + curl_slist* Headers = BuildHeaderList(HeaderWithContentLength, m_SessionId, GetAccessToken()); + curl_easy_setopt(H, CURLOPT_HTTPHEADER, Headers); + + curl_easy_setopt(H, CURLOPT_UPLOAD, 1L); + curl_easy_setopt(H, CURLOPT_INFILESIZE_LARGE, 0LL); + + std::string Body; + WriteCallbackData WriteData{.Body = &Body}; + HeaderCallbackData HdrData{}; + std::vector<std::pair<std::string, std::string>> ResponseHeaders; + HdrData.Headers = &ResponseHeaders; + + curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, CurlWriteCallback); + curl_easy_setopt(H, CURLOPT_WRITEDATA, &WriteData); + curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, CurlHeaderCallback); + curl_easy_setopt(H, CURLOPT_HEADERDATA, &HdrData); + + CurlResult Result = Sess.Perform(); + Result.Body = std::move(Body); + Result.Headers = std::move(ResponseHeaders); + + curl_slist_free_all(Headers); + + return Result; + }), + {}); +} + +CurlHttpClient::Response +CurlHttpClient::Get(std::string_view Url, const KeyValueMap& AdditionalHeader, const KeyValueMap& Parameters) +{ + ZEN_TRACE_CPU("CurlHttpClient::Get"); + return CommonResponse( + m_SessionId, + DoWithRetry( + m_SessionId, + [&]() -> CurlResult { + Session Sess = + AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, Parameters, m_SessionId, GetAccessToken()); + CURL* H = Sess.Get(); + + curl_slist* Headers = BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken()); + curl_easy_setopt(H, CURLOPT_HTTPHEADER, Headers); + curl_easy_setopt(H, CURLOPT_HTTPGET, 1L); + + std::string Body; + WriteCallbackData WriteData{.Body = &Body}; + HeaderCallbackData HdrData{}; + std::vector<std::pair<std::string, std::string>> ResponseHeaders; + HdrData.Headers = &ResponseHeaders; + + curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, CurlWriteCallback); + curl_easy_setopt(H, CURLOPT_WRITEDATA, &WriteData); + curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, CurlHeaderCallback); + curl_easy_setopt(H, CURLOPT_HEADERDATA, &HdrData); + + CurlResult Result = Sess.Perform(); + Result.Body = std::move(Body); + Result.Headers = std::move(ResponseHeaders); + + curl_slist_free_all(Headers); + + return Result; + }, + [this](CurlResult& Result) { + std::unique_ptr<detail::TempPayloadFile> NoTempFile; + return ValidatePayload(Result, NoTempFile); + }), + {}); +} + +CurlHttpClient::Response +CurlHttpClient::Head(std::string_view Url, const KeyValueMap& AdditionalHeader) +{ + ZEN_TRACE_CPU("CurlHttpClient::Head"); + + return CommonResponse( + m_SessionId, + DoWithRetry(m_SessionId, + [&]() -> CurlResult { + Session Sess = + AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); + CURL* H = Sess.Get(); + + curl_slist* Headers = BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken()); + curl_easy_setopt(H, CURLOPT_HTTPHEADER, Headers); + curl_easy_setopt(H, CURLOPT_NOBODY, 1L); + + HeaderCallbackData HdrData{}; + std::vector<std::pair<std::string, std::string>> ResponseHeaders; + HdrData.Headers = &ResponseHeaders; + + curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, CurlHeaderCallback); + curl_easy_setopt(H, CURLOPT_HEADERDATA, &HdrData); + + CurlResult Result = Sess.Perform(); + Result.Headers = std::move(ResponseHeaders); + + curl_slist_free_all(Headers); + + return Result; + }), + {}); +} + +CurlHttpClient::Response +CurlHttpClient::Delete(std::string_view Url, const KeyValueMap& AdditionalHeader) +{ + ZEN_TRACE_CPU("CurlHttpClient::Delete"); + + return CommonResponse( + m_SessionId, + DoWithRetry(m_SessionId, + [&]() -> CurlResult { + Session Sess = + AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); + CURL* H = Sess.Get(); + + curl_slist* Headers = BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken()); + curl_easy_setopt(H, CURLOPT_HTTPHEADER, Headers); + curl_easy_setopt(H, CURLOPT_CUSTOMREQUEST, "DELETE"); + + std::string Body; + WriteCallbackData WriteData{.Body = &Body}; + HeaderCallbackData HdrData{}; + std::vector<std::pair<std::string, std::string>> ResponseHeaders; + HdrData.Headers = &ResponseHeaders; + + curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, CurlWriteCallback); + curl_easy_setopt(H, CURLOPT_WRITEDATA, &WriteData); + curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, CurlHeaderCallback); + curl_easy_setopt(H, CURLOPT_HEADERDATA, &HdrData); + + CurlResult Result = Sess.Perform(); + Result.Body = std::move(Body); + Result.Headers = std::move(ResponseHeaders); + + curl_slist_free_all(Headers); + + return Result; + }), + {}); +} + +CurlHttpClient::Response +CurlHttpClient::Post(std::string_view Url, const KeyValueMap& AdditionalHeader, const KeyValueMap& Parameters) +{ + ZEN_TRACE_CPU("CurlHttpClient::PostNoPayload"); + + return CommonResponse( + m_SessionId, + DoWithRetry(m_SessionId, + [&]() -> CurlResult { + Session Sess = + AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, Parameters, m_SessionId, GetAccessToken()); + CURL* H = Sess.Get(); + + curl_slist* Headers = BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken()); + curl_easy_setopt(H, CURLOPT_HTTPHEADER, Headers); + curl_easy_setopt(H, CURLOPT_POST, 1L); + curl_easy_setopt(H, CURLOPT_POSTFIELDSIZE, 0L); + + std::string Body; + WriteCallbackData WriteData{.Body = &Body}; + HeaderCallbackData HdrData{}; + std::vector<std::pair<std::string, std::string>> ResponseHeaders; + HdrData.Headers = &ResponseHeaders; + + curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, CurlWriteCallback); + curl_easy_setopt(H, CURLOPT_WRITEDATA, &WriteData); + curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, CurlHeaderCallback); + curl_easy_setopt(H, CURLOPT_HEADERDATA, &HdrData); + + CurlResult Result = Sess.Perform(); + Result.Body = std::move(Body); + Result.Headers = std::move(ResponseHeaders); + + curl_slist_free_all(Headers); + + return Result; + }), + {}); +} + +CurlHttpClient::Response +CurlHttpClient::Post(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader) +{ + return Post(Url, Payload, Payload.GetContentType(), AdditionalHeader); +} + +CurlHttpClient::Response +CurlHttpClient::Post(std::string_view Url, const IoBuffer& Payload, ZenContentType ContentType, const KeyValueMap& AdditionalHeader) +{ + ZEN_TRACE_CPU("CurlHttpClient::PostWithPayload"); + + return CommonResponse( + m_SessionId, + DoWithRetry( + m_SessionId, + [&]() -> CurlResult { + Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); + CURL* H = Sess.Get(); + + // Rebuild headers with content type + curl_slist* Headers = BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), {HeaderContentType(ContentType)}); + curl_easy_setopt(H, CURLOPT_HTTPHEADER, Headers); + + IoBufferFileReference FileRef = {nullptr, 0, 0}; + if (Payload.GetFileReference(FileRef)) + { + detail::BufferedReadFileStream Buffer(FileRef.FileHandle, FileRef.FileChunkOffset, FileRef.FileChunkSize, 512u * 1024u); + + FileReadCallbackData ReadData{.Buffer = &Buffer, + .TotalSize = Payload.GetSize(), + .CheckIfAbortFunction = m_CheckIfAbortFunction ? &m_CheckIfAbortFunction : nullptr}; + + curl_easy_setopt(H, CURLOPT_POST, 1L); + curl_easy_setopt(H, CURLOPT_POSTFIELDSIZE_LARGE, static_cast<curl_off_t>(Payload.GetSize())); + curl_easy_setopt(H, CURLOPT_READFUNCTION, CurlFileReadCallback); + curl_easy_setopt(H, CURLOPT_READDATA, &ReadData); + + std::string Body; + WriteCallbackData WriteData{.Body = &Body}; + HeaderCallbackData HdrData{}; + std::vector<std::pair<std::string, std::string>> ResponseHeaders; + HdrData.Headers = &ResponseHeaders; + + curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, CurlWriteCallback); + curl_easy_setopt(H, CURLOPT_WRITEDATA, &WriteData); + curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, CurlHeaderCallback); + curl_easy_setopt(H, CURLOPT_HEADERDATA, &HdrData); + + CurlResult Result = Sess.Perform(); + Result.Body = std::move(Body); + Result.Headers = std::move(ResponseHeaders); + + curl_slist_free_all(Headers); + return Result; + } + + curl_easy_setopt(H, CURLOPT_POST, 1L); + curl_easy_setopt(H, CURLOPT_POSTFIELDS, reinterpret_cast<const char*>(Payload.GetData())); + curl_easy_setopt(H, CURLOPT_POSTFIELDSIZE_LARGE, static_cast<curl_off_t>(Payload.GetSize())); + + std::string Body; + WriteCallbackData WriteData{.Body = &Body}; + HeaderCallbackData HdrData{}; + std::vector<std::pair<std::string, std::string>> ResponseHeaders; + HdrData.Headers = &ResponseHeaders; + + curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, CurlWriteCallback); + curl_easy_setopt(H, CURLOPT_WRITEDATA, &WriteData); + curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, CurlHeaderCallback); + curl_easy_setopt(H, CURLOPT_HEADERDATA, &HdrData); + + CurlResult Result = Sess.Perform(); + Result.Body = std::move(Body); + Result.Headers = std::move(ResponseHeaders); + + curl_slist_free_all(Headers); + return Result; + }), + {}); +} + +CurlHttpClient::Response +CurlHttpClient::Post(std::string_view Url, CbObject Payload, const KeyValueMap& AdditionalHeader) +{ + ZEN_TRACE_CPU("CurlHttpClient::PostObjectPayload"); + + return CommonResponse( + m_SessionId, + DoWithRetry( + m_SessionId, + [&]() -> CurlResult { + Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); + CURL* H = Sess.Get(); + + curl_slist* Headers = + BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), {HeaderContentType(ZenContentType::kCbObject)}); + curl_easy_setopt(H, CURLOPT_HTTPHEADER, Headers); + + curl_easy_setopt(H, CURLOPT_POST, 1L); + curl_easy_setopt(H, CURLOPT_POSTFIELDS, reinterpret_cast<const char*>(Payload.GetBuffer().GetData())); + curl_easy_setopt(H, CURLOPT_POSTFIELDSIZE_LARGE, static_cast<curl_off_t>(Payload.GetBuffer().GetSize())); + + std::string Body; + WriteCallbackData WriteData{.Body = &Body}; + HeaderCallbackData HdrData{}; + std::vector<std::pair<std::string, std::string>> ResponseHeaders; + HdrData.Headers = &ResponseHeaders; + + curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, CurlWriteCallback); + curl_easy_setopt(H, CURLOPT_WRITEDATA, &WriteData); + curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, CurlHeaderCallback); + curl_easy_setopt(H, CURLOPT_HEADERDATA, &HdrData); + + CurlResult Result = Sess.Perform(); + Result.Body = std::move(Body); + Result.Headers = std::move(ResponseHeaders); + + curl_slist_free_all(Headers); + return Result; + }), + {}); +} + +CurlHttpClient::Response +CurlHttpClient::Post(std::string_view Url, CbPackage Pkg, const KeyValueMap& AdditionalHeader) +{ + return Post(Url, zen::FormatPackageMessageBuffer(Pkg), ZenContentType::kCbPackage, AdditionalHeader); +} + +CurlHttpClient::Response +CurlHttpClient::Post(std::string_view Url, const CompositeBuffer& Payload, ZenContentType ContentType, const KeyValueMap& AdditionalHeader) +{ + ZEN_TRACE_CPU("CurlHttpClient::Post"); + + return CommonResponse( + m_SessionId, + DoWithRetry(m_SessionId, + [&]() -> CurlResult { + Session Sess = + AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); + CURL* H = Sess.Get(); + + curl_slist* Headers = + BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), {HeaderContentType(ContentType)}); + curl_easy_setopt(H, CURLOPT_HTTPHEADER, Headers); + + detail::CompositeBufferReadStream Reader(Payload, 512u * 1024u); + + StreamReadCallbackData ReadData{.Reader = &Reader, + .CheckIfAbortFunction = m_CheckIfAbortFunction ? &m_CheckIfAbortFunction : nullptr}; + + curl_easy_setopt(H, CURLOPT_POST, 1L); + curl_easy_setopt(H, CURLOPT_POSTFIELDSIZE_LARGE, static_cast<curl_off_t>(Payload.GetSize())); + curl_easy_setopt(H, CURLOPT_READFUNCTION, CurlStreamReadCallback); + curl_easy_setopt(H, CURLOPT_READDATA, &ReadData); + + std::string Body; + WriteCallbackData WriteData{.Body = &Body}; + HeaderCallbackData HdrData{}; + std::vector<std::pair<std::string, std::string>> ResponseHeaders; + HdrData.Headers = &ResponseHeaders; + + curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, CurlWriteCallback); + curl_easy_setopt(H, CURLOPT_WRITEDATA, &WriteData); + curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, CurlHeaderCallback); + curl_easy_setopt(H, CURLOPT_HEADERDATA, &HdrData); + + CurlResult Result = Sess.Perform(); + Result.Body = std::move(Body); + Result.Headers = std::move(ResponseHeaders); + + curl_slist_free_all(Headers); + return Result; + }), + {}); +} + +CurlHttpClient::Response +CurlHttpClient::Upload(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader) +{ + ZEN_TRACE_CPU("CurlHttpClient::Upload"); + + return CommonResponse( + m_SessionId, + DoWithRetry( + m_SessionId, + [&]() -> CurlResult { + Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); + CURL* H = Sess.Get(); + + curl_slist* Headers = + BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), {HeaderContentType(Payload.GetContentType())}); + curl_easy_setopt(H, CURLOPT_HTTPHEADER, Headers); + + curl_easy_setopt(H, CURLOPT_UPLOAD, 1L); + curl_easy_setopt(H, CURLOPT_INFILESIZE_LARGE, static_cast<curl_off_t>(Payload.GetSize())); + + IoBufferFileReference FileRef = {nullptr, 0, 0}; + if (Payload.GetFileReference(FileRef)) + { + detail::BufferedReadFileStream Buffer(FileRef.FileHandle, FileRef.FileChunkOffset, FileRef.FileChunkSize, 512u * 1024u); + + FileReadCallbackData ReadData{.Buffer = &Buffer, + .TotalSize = Payload.GetSize(), + .CheckIfAbortFunction = m_CheckIfAbortFunction ? &m_CheckIfAbortFunction : nullptr}; + + curl_easy_setopt(H, CURLOPT_READFUNCTION, CurlFileReadCallback); + curl_easy_setopt(H, CURLOPT_READDATA, &ReadData); + + std::string Body; + WriteCallbackData WriteData{.Body = &Body}; + HeaderCallbackData HdrData{}; + std::vector<std::pair<std::string, std::string>> ResponseHeaders; + HdrData.Headers = &ResponseHeaders; + + curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, CurlWriteCallback); + curl_easy_setopt(H, CURLOPT_WRITEDATA, &WriteData); + curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, CurlHeaderCallback); + curl_easy_setopt(H, CURLOPT_HEADERDATA, &HdrData); + + CurlResult Result = Sess.Perform(); + Result.Body = std::move(Body); + Result.Headers = std::move(ResponseHeaders); + + curl_slist_free_all(Headers); + return Result; + } + + ReadCallbackData ReadData{.DataPtr = static_cast<const uint8_t*>(Payload.GetData()), + .DataSize = Payload.GetSize(), + .CheckIfAbortFunction = m_CheckIfAbortFunction ? &m_CheckIfAbortFunction : nullptr}; + curl_easy_setopt(H, CURLOPT_READFUNCTION, CurlReadCallback); + curl_easy_setopt(H, CURLOPT_READDATA, &ReadData); + + std::string Body; + WriteCallbackData WriteData{.Body = &Body}; + HeaderCallbackData HdrData{}; + std::vector<std::pair<std::string, std::string>> ResponseHeaders; + HdrData.Headers = &ResponseHeaders; + + curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, CurlWriteCallback); + curl_easy_setopt(H, CURLOPT_WRITEDATA, &WriteData); + curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, CurlHeaderCallback); + curl_easy_setopt(H, CURLOPT_HEADERDATA, &HdrData); + + CurlResult Result = Sess.Perform(); + Result.Body = std::move(Body); + Result.Headers = std::move(ResponseHeaders); + + curl_slist_free_all(Headers); + return Result; + }), + {}); +} + +CurlHttpClient::Response +CurlHttpClient::Upload(std::string_view Url, + const CompositeBuffer& Payload, + ZenContentType ContentType, + const KeyValueMap& AdditionalHeader) +{ + ZEN_TRACE_CPU("CurlHttpClient::Upload"); + + return CommonResponse( + m_SessionId, + DoWithRetry(m_SessionId, + [&]() -> CurlResult { + Session Sess = + AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); + CURL* H = Sess.Get(); + + curl_slist* Headers = + BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), {HeaderContentType(ContentType)}); + curl_easy_setopt(H, CURLOPT_HTTPHEADER, Headers); + + curl_easy_setopt(H, CURLOPT_UPLOAD, 1L); + curl_easy_setopt(H, CURLOPT_INFILESIZE_LARGE, static_cast<curl_off_t>(Payload.GetSize())); + + detail::CompositeBufferReadStream Reader(Payload, 512u * 1024u); + + StreamReadCallbackData ReadData{.Reader = &Reader, + .CheckIfAbortFunction = m_CheckIfAbortFunction ? &m_CheckIfAbortFunction : nullptr}; + + curl_easy_setopt(H, CURLOPT_READFUNCTION, CurlStreamReadCallback); + curl_easy_setopt(H, CURLOPT_READDATA, &ReadData); + + std::string Body; + WriteCallbackData WriteData{.Body = &Body}; + HeaderCallbackData HdrData{}; + std::vector<std::pair<std::string, std::string>> ResponseHeaders; + HdrData.Headers = &ResponseHeaders; + + curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, CurlWriteCallback); + curl_easy_setopt(H, CURLOPT_WRITEDATA, &WriteData); + curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, CurlHeaderCallback); + curl_easy_setopt(H, CURLOPT_HEADERDATA, &HdrData); + + CurlResult Result = Sess.Perform(); + Result.Body = std::move(Body); + Result.Headers = std::move(ResponseHeaders); + + curl_slist_free_all(Headers); + return Result; + }), + {}); +} + +CurlHttpClient::Response +CurlHttpClient::Download(std::string_view Url, const std::filesystem::path& TempFolderPath, const KeyValueMap& AdditionalHeader) +{ + ZEN_TRACE_CPU("CurlHttpClient::Download"); + + std::string PayloadString; + std::unique_ptr<detail::TempPayloadFile> PayloadFile; + + HttpContentType ContentType = HttpContentType::kUnknownContentType; + detail::MultipartBoundaryParser BoundaryParser; + bool IsMultiRangeResponse = false; + + CurlResult Result = DoWithRetry( + m_SessionId, + [&]() -> CurlResult { + Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); + CURL* H = Sess.Get(); + + curl_slist* DlHeaders = BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken()); + curl_easy_setopt(H, CURLOPT_HTTPHEADER, DlHeaders); + curl_easy_setopt(H, CURLOPT_HTTPGET, 1L); + + // Reset state from any previous attempt + PayloadString.clear(); + PayloadFile.reset(); + BoundaryParser.Boundaries.clear(); + ContentType = HttpContentType::kUnknownContentType; + IsMultiRangeResponse = false; + + // Track requested content length from Range header (sum all ranges) + uint64_t RequestedContentLength = (uint64_t)-1; + if (auto RangeIt = AdditionalHeader.Entries.find("Range"); RangeIt != AdditionalHeader.Entries.end()) + { + if (RangeIt->second.starts_with("bytes")) + { + std::string_view RangeValue(RangeIt->second); + size_t RangeStartPos = RangeValue.find('=', 5); + if (RangeStartPos != std::string::npos) + { + RangeStartPos++; + while (RangeStartPos < RangeValue.length() && RangeValue[RangeStartPos] == ' ') + { + RangeStartPos++; + } + RequestedContentLength = 0; + + while (RangeStartPos < RangeValue.length()) + { + size_t RangeEnd = RangeValue.find_first_of(", \r\n", RangeStartPos); + if (RangeEnd == std::string::npos) + { + RangeEnd = RangeValue.length(); + } + + std::string_view RangeString = RangeValue.substr(RangeStartPos, RangeEnd - RangeStartPos); + size_t RangeSplitPos = RangeString.find('-'); + if (RangeSplitPos != std::string::npos) + { + std::optional<size_t> RequestedRangeStart = ParseInt<size_t>(RangeString.substr(0, RangeSplitPos)); + std::optional<size_t> RequestedRangeEnd = ParseInt<size_t>(RangeString.substr(RangeSplitPos + 1)); + if (RequestedRangeStart.has_value() && RequestedRangeEnd.has_value()) + { + RequestedContentLength += RequestedRangeEnd.value() - RequestedRangeStart.value() + 1; + } + } + RangeStartPos = RangeEnd; + while (RangeStartPos != RangeValue.length() && + (RangeValue[RangeStartPos] == ',' || RangeValue[RangeStartPos] == ' ')) + { + RangeStartPos++; + } + } + } + } + } + + // Header callback that detects Content-Length and switches to file-backed storage when needed + struct DownloadHeaderCallbackData + { + std::vector<std::pair<std::string, std::string>>* Headers = nullptr; + std::unique_ptr<detail::TempPayloadFile>* PayloadFile = nullptr; + std::string* PayloadString = nullptr; + const std::filesystem::path* TempFolderPath = nullptr; + uint64_t MaxInMemorySize = 0; + LoggerRef Log; + detail::MultipartBoundaryParser* BoundaryParser = nullptr; + bool* IsMultiRange = nullptr; + HttpContentType* ContentTypeOut = nullptr; + }; + + DownloadHeaderCallbackData DlHdrData; + std::vector<std::pair<std::string, std::string>> ResponseHeaders; + DlHdrData.Headers = &ResponseHeaders; + DlHdrData.PayloadFile = &PayloadFile; + DlHdrData.PayloadString = &PayloadString; + DlHdrData.TempFolderPath = &TempFolderPath; + DlHdrData.MaxInMemorySize = m_ConnectionSettings.MaximumInMemoryDownloadSize; + DlHdrData.Log = m_Log; + DlHdrData.BoundaryParser = &BoundaryParser; + DlHdrData.IsMultiRange = &IsMultiRangeResponse; + DlHdrData.ContentTypeOut = &ContentType; + + auto HeaderCb = [](char* Buffer, size_t Size, size_t Nmemb, void* UserData) -> size_t { + auto* Data = static_cast<DownloadHeaderCallbackData*>(UserData); + size_t TotalBytes = Size * Nmemb; + + std::string_view Line(Buffer, TotalBytes); + + while (!Line.empty() && (Line.back() == '\r' || Line.back() == '\n')) + { + Line.remove_suffix(1); + } + + if (Line.empty()) + { + return TotalBytes; + } + + size_t ColonPos = Line.find(':'); + if (ColonPos != std::string_view::npos) + { + std::string_view Key = Line.substr(0, ColonPos); + std::string_view Value = Line.substr(ColonPos + 1); + + while (!Key.empty() && Key.back() == ' ') + { + Key.remove_suffix(1); + } + while (!Value.empty() && Value.front() == ' ') + { + Value.remove_prefix(1); + } + + if (Key == "Content-Length"sv) + { + std::optional<size_t> ContentLength = ParseInt<size_t>(Value); + if (ContentLength.has_value()) + { + if (ContentLength.value() > Data->MaxInMemorySize) + { + *Data->PayloadFile = std::make_unique<detail::TempPayloadFile>(); + std::error_code Ec = (*Data->PayloadFile)->Open(*Data->TempFolderPath, ContentLength.value()); + if (Ec) + { + auto Log = [&]() -> LoggerRef { return Data->Log; }; + ZEN_WARN("Failed to create temp file in '{}' for HttpClient::Download. Reason: {}", + Data->TempFolderPath->string(), + Ec.message()); + Data->PayloadFile->reset(); + } + } + else + { + Data->PayloadString->reserve(ContentLength.value()); + } + } + } + else if (Key == "Content-Type"sv) + { + *Data->IsMultiRange = Data->BoundaryParser->Init(Value); + if (!*Data->IsMultiRange) + { + *Data->ContentTypeOut = ParseContentType(Value); + } + } + else if (Key == "Content-Range"sv) + { + if (!*Data->IsMultiRange) + { + std::pair<uint64_t, uint64_t> Range = detail::ParseContentRange(Value); + if (Range.second != 0) + { + Data->BoundaryParser->Boundaries.push_back( + HttpClient::Response::MultipartBoundary{.OffsetInPayload = 0, + .RangeOffset = Range.first, + .RangeLength = Range.second, + .ContentType = *Data->ContentTypeOut}); + } + } + } + + Data->Headers->emplace_back(std::string(Key), std::string(Value)); + } + + return TotalBytes; + }; + + curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, static_cast<size_t (*)(char*, size_t, size_t, void*)>(HeaderCb)); + curl_easy_setopt(H, CURLOPT_HEADERDATA, &DlHdrData); + + // Write callback that directs data to file or string + struct DownloadWriteCallbackData + { + std::string* PayloadString = nullptr; + std::unique_ptr<detail::TempPayloadFile>* PayloadFile = nullptr; + std::function<bool()>* CheckIfAbortFunction = nullptr; + const std::filesystem::path* TempFolderPath = nullptr; + LoggerRef Log; + detail::MultipartBoundaryParser* BoundaryParser = nullptr; + bool* IsMultiRange = nullptr; + }; + + DownloadWriteCallbackData DlWriteData; + DlWriteData.PayloadString = &PayloadString; + DlWriteData.PayloadFile = &PayloadFile; + DlWriteData.CheckIfAbortFunction = m_CheckIfAbortFunction ? &m_CheckIfAbortFunction : nullptr; + DlWriteData.TempFolderPath = &TempFolderPath; + DlWriteData.Log = m_Log; + DlWriteData.BoundaryParser = &BoundaryParser; + DlWriteData.IsMultiRange = &IsMultiRangeResponse; + + auto WriteCb = [](char* Ptr, size_t Size, size_t Nmemb, void* UserData) -> size_t { + auto* Data = static_cast<DownloadWriteCallbackData*>(UserData); + size_t TotalBytes = Size * Nmemb; + + if (Data->CheckIfAbortFunction && *Data->CheckIfAbortFunction && (*Data->CheckIfAbortFunction)()) + { + return 0; + } + + if (*Data->IsMultiRange) + { + Data->BoundaryParser->ParseInput(std::string_view(Ptr, TotalBytes)); + } + + if (*Data->PayloadFile) + { + std::error_code Ec = (*Data->PayloadFile)->Write(std::string_view(Ptr, TotalBytes)); + if (Ec) + { + auto Log = [&]() -> LoggerRef { return Data->Log; }; + ZEN_WARN("Failed to write to temp file in '{}' for HttpClient::Download. Reason: {}", + Data->TempFolderPath->string(), + Ec.message()); + return 0; + } + } + else + { + Data->PayloadString->append(Ptr, TotalBytes); + } + return TotalBytes; + }; + + curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, static_cast<size_t (*)(char*, size_t, size_t, void*)>(WriteCb)); + curl_easy_setopt(H, CURLOPT_WRITEDATA, &DlWriteData); + + CurlResult Res = Sess.Perform(); + Res.Headers = std::move(ResponseHeaders); + + // Handle resume logic + if (m_ConnectionSettings.AllowResume) + { + auto SupportsRanges = [](const CurlResult& R) -> bool { + for (const auto& [K, V] : R.Headers) + { + if (K == "Content-Range") + { + return true; + } + if (K == "Accept-Ranges" && V == "bytes") + { + return true; + } + } + return false; + }; + + auto ShouldResumeCheck = [&SupportsRanges, &IsMultiRangeResponse](const CurlResult& R) -> bool { + if (IsMultiRangeResponse) + { + return false; + } + if (ShouldRetry(R)) + { + return SupportsRanges(R); + } + return false; + }; + + if (ShouldResumeCheck(Res)) + { + // Find Content-Length + std::string ContentLengthValue; + for (const auto& [K, V] : Res.Headers) + { + if (K == "Content-Length") + { + ContentLengthValue = V; + break; + } + } + + if (!ContentLengthValue.empty()) + { + uint64_t ContentLength = RequestedContentLength; + if (ContentLength == uint64_t(-1)) + { + if (auto ParsedContentLength = ParseInt<int64_t>(ContentLengthValue); ParsedContentLength.has_value()) + { + ContentLength = ParsedContentLength.value(); + } + } + + KeyValueMap HeadersWithRange(AdditionalHeader); + do + { + uint64_t DownloadedSize = PayloadFile ? PayloadFile->GetSize() : PayloadString.length(); + + std::string Range = fmt::format("bytes={}-{}", DownloadedSize, DownloadedSize + ContentLength - 1); + if (auto RangeIt = HeadersWithRange.Entries.find("Range"); RangeIt != HeadersWithRange.Entries.end()) + { + if (RangeIt->second == Range) + { + break; // No progress, abort + } + } + HeadersWithRange.Entries.insert_or_assign("Range", Range); + + Session ResumeSess = + AllocSession(m_BaseUri, Url, m_ConnectionSettings, HeadersWithRange, {}, m_SessionId, GetAccessToken()); + CURL* ResumeH = ResumeSess.Get(); + + curl_slist* ResumeHdrList = BuildHeaderList(HeadersWithRange, m_SessionId, GetAccessToken()); + curl_easy_setopt(ResumeH, CURLOPT_HTTPHEADER, ResumeHdrList); + curl_easy_setopt(ResumeH, CURLOPT_HTTPGET, 1L); + + std::vector<std::pair<std::string, std::string>> ResumeHeaders; + + struct ResumeHeaderCbData + { + std::vector<std::pair<std::string, std::string>>* Headers = nullptr; + std::unique_ptr<detail::TempPayloadFile>* PayloadFile = nullptr; + std::string* PayloadString = nullptr; + }; + + ResumeHeaderCbData ResumeHdrData; + ResumeHdrData.Headers = &ResumeHeaders; + ResumeHdrData.PayloadFile = &PayloadFile; + ResumeHdrData.PayloadString = &PayloadString; + + auto ResumeHeaderCb = [](char* Buffer, size_t Size, size_t Nmemb, void* UserData) -> size_t { + auto* Data = static_cast<ResumeHeaderCbData*>(UserData); + size_t TotalBytes = Size * Nmemb; + + std::string_view Line(Buffer, TotalBytes); + while (!Line.empty() && (Line.back() == '\r' || Line.back() == '\n')) + { + Line.remove_suffix(1); + } + + if (Line.empty()) + { + return TotalBytes; + } + + size_t ColonPos = Line.find(':'); + if (ColonPos != std::string_view::npos) + { + std::string_view Key = Line.substr(0, ColonPos); + std::string_view Value = Line.substr(ColonPos + 1); + while (!Key.empty() && Key.back() == ' ') + { + Key.remove_suffix(1); + } + while (!Value.empty() && Value.front() == ' ') + { + Value.remove_prefix(1); + } + + if (Key == "Content-Range"sv) + { + if (Value.starts_with("bytes "sv)) + { + size_t RangeStartEnd = Value.find('-', 6); + if (RangeStartEnd != std::string_view::npos) + { + const std::optional<uint64_t> Start = + ParseInt<uint64_t>(Value.substr(6, RangeStartEnd - 6)); + if (Start) + { + uint64_t DownloadedSize = *Data->PayloadFile ? (*Data->PayloadFile)->GetSize() + : Data->PayloadString->length(); + if (Start.value() == DownloadedSize) + { + Data->Headers->emplace_back(std::string(Key), std::string(Value)); + return TotalBytes; + } + else if (Start.value() > DownloadedSize) + { + return 0; + } + if (*Data->PayloadFile) + { + (*Data->PayloadFile)->ResetWritePos(Start.value()); + } + else + { + *Data->PayloadString = Data->PayloadString->substr(0, Start.value()); + } + Data->Headers->emplace_back(std::string(Key), std::string(Value)); + return TotalBytes; + } + } + } + return 0; + } + + Data->Headers->emplace_back(std::string(Key), std::string(Value)); + } + + return TotalBytes; + }; + + curl_easy_setopt(ResumeH, + CURLOPT_HEADERFUNCTION, + static_cast<size_t (*)(char*, size_t, size_t, void*)>(ResumeHeaderCb)); + curl_easy_setopt(ResumeH, CURLOPT_HEADERDATA, &ResumeHdrData); + curl_easy_setopt(ResumeH, + CURLOPT_WRITEFUNCTION, + static_cast<size_t (*)(char*, size_t, size_t, void*)>(WriteCb)); + curl_easy_setopt(ResumeH, CURLOPT_WRITEDATA, &DlWriteData); + + Res = ResumeSess.Perform(); + Res.Headers = std::move(ResumeHeaders); + + curl_slist_free_all(ResumeHdrList); + } while (ShouldResumeCheck(Res)); + } + } + } + + if (!PayloadString.empty()) + { + Res.Body = std::move(PayloadString); + } + + curl_slist_free_all(DlHeaders); + + return Res; + }, + PayloadFile); + + return CommonResponse(m_SessionId, + std::move(Result), + PayloadFile ? PayloadFile->DetachToIoBuffer() : IoBuffer{}, + std::move(BoundaryParser.Boundaries)); +} + +} // namespace zen diff --git a/src/zenhttp/clients/httpclientcurl.h b/src/zenhttp/clients/httpclientcurl.h new file mode 100644 index 000000000..2a49ff308 --- /dev/null +++ b/src/zenhttp/clients/httpclientcurl.h @@ -0,0 +1,135 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "httpclientcommon.h" + +#include <zencore/logging.h> +#include <zenhttp/httpclient.h> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <curl/curl.h> +ZEN_THIRD_PARTY_INCLUDES_END + +namespace zen { + +class CurlHttpClient : public HttpClientBase +{ +public: + CurlHttpClient(std::string_view BaseUri, const HttpClientSettings& ConnectionSettings, std::function<bool()>&& CheckIfAbortFunction); + ~CurlHttpClient(); + + // HttpClientBase + + [[nodiscard]] virtual Response Put(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader = {}) override; + [[nodiscard]] virtual Response Put(std::string_view Url, const KeyValueMap& Parameters = {}) override; + [[nodiscard]] virtual Response Get(std::string_view Url, + const KeyValueMap& AdditionalHeader = {}, + const KeyValueMap& Parameters = {}) override; + [[nodiscard]] virtual Response Head(std::string_view Url, const KeyValueMap& AdditionalHeader = {}) override; + [[nodiscard]] virtual Response Delete(std::string_view Url, const KeyValueMap& AdditionalHeader = {}) override; + [[nodiscard]] virtual Response Post(std::string_view Url, + const KeyValueMap& AdditionalHeader = {}, + const KeyValueMap& Parameters = {}) override; + [[nodiscard]] virtual Response Post(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader = {}) override; + [[nodiscard]] virtual Response Post(std::string_view Url, + const IoBuffer& Payload, + ZenContentType ContentType, + const KeyValueMap& AdditionalHeader = {}) override; + [[nodiscard]] virtual Response Post(std::string_view Url, CbObject Payload, const KeyValueMap& AdditionalHeader = {}) override; + [[nodiscard]] virtual Response Post(std::string_view Url, CbPackage Payload, const KeyValueMap& AdditionalHeader = {}) override; + [[nodiscard]] virtual Response Post(std::string_view Url, + const CompositeBuffer& Payload, + ZenContentType ContentType, + const KeyValueMap& AdditionalHeader = {}) override; + [[nodiscard]] virtual Response Upload(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader = {}) override; + [[nodiscard]] virtual Response Upload(std::string_view Url, + const CompositeBuffer& Payload, + ZenContentType ContentType, + const KeyValueMap& AdditionalHeader = {}) override; + + [[nodiscard]] virtual Response Download(std::string_view Url, + const std::filesystem::path& TempFolderPath, + const KeyValueMap& AdditionalHeader = {}) override; + + [[nodiscard]] virtual Response TransactPackage(std::string_view Url, + CbPackage Package, + const KeyValueMap& AdditionalHeader = {}) override; + +private: + struct CurlResult + { + long StatusCode = 0; + std::string Body; + std::vector<std::pair<std::string, std::string>> Headers; + double ElapsedSeconds = 0; + int64_t UploadedBytes = 0; + int64_t DownloadedBytes = 0; + CURLcode ErrorCode = CURLE_OK; + std::string ErrorMessage; + }; + + struct Session + { + Session(CurlHttpClient* InOuter, CURL* InHandle) : Outer(InOuter), Handle(InHandle) {} + ~Session() { Outer->ReleaseSession(Handle); } + + CURL* Get() const { return Handle; } + + CurlResult Perform(); + + LoggerRef Log() { return Outer->Log(); } + + private: + CurlHttpClient* Outer; + CURL* Handle; + + Session(Session&&) = delete; + Session& operator=(Session&&) = delete; + }; + + Session AllocSession(std::string_view BaseUrl, + std::string_view Url, + const HttpClientSettings& ConnectionSettings, + const KeyValueMap& AdditionalHeader, + const KeyValueMap& Parameters, + std::string_view SessionId, + std::optional<HttpClientAccessToken> AccessToken); + + RwLock m_SessionLock; + std::vector<CURL*> m_Sessions; + + void ReleaseSession(CURL* Handle); + + struct RetryResult + { + CurlResult Result; + }; + + CurlResult DoWithRetry(std::string_view SessionId, + std::function<CurlResult()>&& Func, + std::unique_ptr<detail::TempPayloadFile>& PayloadFile); + CurlResult DoWithRetry( + std::string_view SessionId, + std::function<CurlResult()>&& Func, + std::function<bool(CurlResult&)>&& Validate = [](CurlResult&) { return true; }); + + bool ValidatePayload(CurlResult& Result, std::unique_ptr<detail::TempPayloadFile>& PayloadFile); + + static bool ShouldRetry(const CurlResult& Result); + + bool ShouldLogErrorCode(HttpResponseCode ResponseCode) const; + + HttpClient::Response CommonResponse(std::string_view SessionId, + CurlResult&& Result, + IoBuffer&& Payload, + std::vector<HttpClient::Response::MultipartBoundary>&& BoundaryPositions = {}); + + HttpClient::Response ResponseWithPayload(std::string_view SessionId, + CurlResult&& Result, + const HttpResponseCode WorkResponseCode, + IoBuffer&& Payload, + std::vector<HttpClient::Response::MultipartBoundary>&& BoundaryPositions); +}; + +} // namespace zen diff --git a/src/zenhttp/clients/httpwsclient.cpp b/src/zenhttp/clients/httpwsclient.cpp index 9497dadb8..792848a6b 100644 --- a/src/zenhttp/clients/httpwsclient.cpp +++ b/src/zenhttp/clients/httpwsclient.cpp @@ -10,6 +10,9 @@ ZEN_THIRD_PARTY_INCLUDES_START #include <asio.hpp> +#if defined(ASIO_HAS_LOCAL_SOCKETS) +# include <asio/local/stream_protocol.hpp> +#endif ZEN_THIRD_PARTY_INCLUDES_END #include <deque> @@ -47,11 +50,7 @@ struct HttpWsClient::Impl m_WorkGuard.reset(); // Close the socket to cancel pending async ops - if (m_Socket) - { - asio::error_code Ec; - m_Socket->close(Ec); - } + CloseSocket(); if (m_IoThread.joinable()) { @@ -59,6 +58,35 @@ struct HttpWsClient::Impl } } + void CloseSocket() + { + asio::error_code Ec; +#if defined(ASIO_HAS_LOCAL_SOCKETS) + if (m_UnixSocket) + { + m_UnixSocket->close(Ec); + return; + } +#endif + if (m_TcpSocket) + { + m_TcpSocket->close(Ec); + } + } + + template<typename Fn> + void WithSocket(Fn&& Func) + { +#if defined(ASIO_HAS_LOCAL_SOCKETS) + if (m_UnixSocket) + { + Func(*m_UnixSocket); + return; + } +#endif + Func(*m_TcpSocket); + } + void ParseUrl(std::string_view Url) { // Expected format: ws://host:port/path @@ -101,9 +129,47 @@ struct HttpWsClient::Impl m_IoThread = std::thread([this] { m_IoContext.run(); }); } +#if defined(ASIO_HAS_LOCAL_SOCKETS) + if (!m_Settings.UnixSocketPath.empty()) + { + asio::post(m_IoContext, [this] { DoConnectUnix(); }); + return; + } +#endif + asio::post(m_IoContext, [this] { DoResolve(); }); } +#if defined(ASIO_HAS_LOCAL_SOCKETS) + void DoConnectUnix() + { + m_UnixSocket = std::make_unique<asio::local::stream_protocol::socket>(m_IoContext); + + // Start connect timeout timer + m_Timer = std::make_unique<asio::steady_timer>(m_IoContext, m_Settings.ConnectTimeout); + m_Timer->async_wait([this](const asio::error_code& Ec) { + if (!Ec && !m_IsOpen.load(std::memory_order_relaxed)) + { + ZEN_LOG_DEBUG(m_Log, "WebSocket unix connect timeout for {}", m_Settings.UnixSocketPath); + CloseSocket(); + } + }); + + asio::local::stream_protocol::endpoint Endpoint(m_Settings.UnixSocketPath); + m_UnixSocket->async_connect(Endpoint, [this](const asio::error_code& Ec) { + if (Ec) + { + m_Timer->cancel(); + ZEN_LOG_DEBUG(m_Log, "WebSocket unix connect failed for {}: {}", m_Settings.UnixSocketPath, Ec.message()); + m_Handler.OnWsClose(1006, "connect failed"); + return; + } + + DoHandshake(); + }); + } +#endif + void DoResolve() { m_Resolver = std::make_unique<asio::ip::tcp::resolver>(m_IoContext); @@ -122,7 +188,7 @@ struct HttpWsClient::Impl void DoConnect(const asio::ip::tcp::resolver::results_type& Endpoints) { - m_Socket = std::make_unique<asio::ip::tcp::socket>(m_IoContext); + m_TcpSocket = std::make_unique<asio::ip::tcp::socket>(m_IoContext); // Start connect timeout timer m_Timer = std::make_unique<asio::steady_timer>(m_IoContext, m_Settings.ConnectTimeout); @@ -130,15 +196,11 @@ struct HttpWsClient::Impl if (!Ec && !m_IsOpen.load(std::memory_order_relaxed)) { ZEN_LOG_DEBUG(m_Log, "WebSocket connect timeout for {}:{}", m_Host, m_Port); - if (m_Socket) - { - asio::error_code CloseEc; - m_Socket->close(CloseEc); - } + CloseSocket(); } }); - asio::async_connect(*m_Socket, Endpoints, [this](const asio::error_code& Ec, const asio::ip::tcp::endpoint&) { + asio::async_connect(*m_TcpSocket, Endpoints, [this](const asio::error_code& Ec, const asio::ip::tcp::endpoint&) { if (Ec) { m_Timer->cancel(); @@ -194,64 +256,68 @@ struct HttpWsClient::Impl m_HandshakeBuffer = std::make_shared<std::string>(ReqStr); - asio::async_write(*m_Socket, - asio::buffer(m_HandshakeBuffer->data(), m_HandshakeBuffer->size()), - [this](const asio::error_code& Ec, std::size_t) { - if (Ec) - { - m_Timer->cancel(); - ZEN_LOG_DEBUG(m_Log, "WebSocket handshake write failed: {}", Ec.message()); - m_Handler.OnWsClose(1006, "handshake write failed"); - return; - } - - DoReadHandshakeResponse(); - }); + WithSocket([this](auto& Socket) { + asio::async_write(Socket, + asio::buffer(m_HandshakeBuffer->data(), m_HandshakeBuffer->size()), + [this](const asio::error_code& Ec, std::size_t) { + if (Ec) + { + m_Timer->cancel(); + ZEN_LOG_DEBUG(m_Log, "WebSocket handshake write failed: {}", Ec.message()); + m_Handler.OnWsClose(1006, "handshake write failed"); + return; + } + + DoReadHandshakeResponse(); + }); + }); } void DoReadHandshakeResponse() { - asio::async_read_until(*m_Socket, m_ReadBuffer, "\r\n\r\n", [this](const asio::error_code& Ec, std::size_t) { - m_Timer->cancel(); + WithSocket([this](auto& Socket) { + asio::async_read_until(Socket, m_ReadBuffer, "\r\n\r\n", [this](const asio::error_code& Ec, std::size_t) { + m_Timer->cancel(); - if (Ec) - { - ZEN_LOG_DEBUG(m_Log, "WebSocket handshake read failed: {}", Ec.message()); - m_Handler.OnWsClose(1006, "handshake read failed"); - return; - } + if (Ec) + { + ZEN_LOG_DEBUG(m_Log, "WebSocket handshake read failed: {}", Ec.message()); + m_Handler.OnWsClose(1006, "handshake read failed"); + return; + } - // Parse the response - const auto& Data = m_ReadBuffer.data(); - std::string Response(asio::buffers_begin(Data), asio::buffers_end(Data)); + // Parse the response + const auto& Data = m_ReadBuffer.data(); + std::string Response(asio::buffers_begin(Data), asio::buffers_end(Data)); - // Consume the headers from the read buffer (any extra data stays for frame parsing) - auto HeaderEnd = Response.find("\r\n\r\n"); - if (HeaderEnd != std::string::npos) - { - m_ReadBuffer.consume(HeaderEnd + 4); - } + // Consume the headers from the read buffer (any extra data stays for frame parsing) + auto HeaderEnd = Response.find("\r\n\r\n"); + if (HeaderEnd != std::string::npos) + { + m_ReadBuffer.consume(HeaderEnd + 4); + } - // Validate 101 response - if (Response.find("101") == std::string::npos) - { - ZEN_LOG_DEBUG(m_Log, "WebSocket handshake rejected (no 101): {}", Response.substr(0, 80)); - m_Handler.OnWsClose(1006, "handshake rejected"); - return; - } + // Validate 101 response + if (Response.find("101") == std::string::npos) + { + ZEN_LOG_DEBUG(m_Log, "WebSocket handshake rejected (no 101): {}", Response.substr(0, 80)); + m_Handler.OnWsClose(1006, "handshake rejected"); + return; + } - // Validate Sec-WebSocket-Accept - std::string ExpectedAccept = WsFrameCodec::ComputeAcceptKey(m_WebSocketKey); - if (Response.find(ExpectedAccept) == std::string::npos) - { - ZEN_LOG_DEBUG(m_Log, "WebSocket handshake: invalid Sec-WebSocket-Accept"); - m_Handler.OnWsClose(1006, "invalid accept key"); - return; - } + // Validate Sec-WebSocket-Accept + std::string ExpectedAccept = WsFrameCodec::ComputeAcceptKey(m_WebSocketKey); + if (Response.find(ExpectedAccept) == std::string::npos) + { + ZEN_LOG_DEBUG(m_Log, "WebSocket handshake: invalid Sec-WebSocket-Accept"); + m_Handler.OnWsClose(1006, "invalid accept key"); + return; + } - m_IsOpen.store(true); - m_Handler.OnWsOpen(); - EnqueueRead(); + m_IsOpen.store(true); + m_Handler.OnWsOpen(); + EnqueueRead(); + }); }); } @@ -267,8 +333,10 @@ struct HttpWsClient::Impl return; } - asio::async_read(*m_Socket, m_ReadBuffer, asio::transfer_at_least(1), [this](const asio::error_code& Ec, std::size_t) { - OnDataReceived(Ec); + WithSocket([this](auto& Socket) { + asio::async_read(Socket, m_ReadBuffer, asio::transfer_at_least(1), [this](const asio::error_code& Ec, std::size_t) { + OnDataReceived(Ec); + }); }); } @@ -414,9 +482,11 @@ struct HttpWsClient::Impl auto OwnedFrame = std::make_shared<std::vector<uint8_t>>(std::move(Frame)); - asio::async_write(*m_Socket, - asio::buffer(OwnedFrame->data(), OwnedFrame->size()), - [this, OwnedFrame](const asio::error_code& Ec, std::size_t) { OnWriteComplete(Ec); }); + WithSocket([this, OwnedFrame](auto& Socket) { + asio::async_write(Socket, + asio::buffer(OwnedFrame->data(), OwnedFrame->size()), + [this, OwnedFrame](const asio::error_code& Ec, std::size_t) { OnWriteComplete(Ec); }); + }); } void OnWriteComplete(const asio::error_code& Ec) @@ -501,11 +571,14 @@ struct HttpWsClient::Impl // Connection state std::unique_ptr<asio::ip::tcp::resolver> m_Resolver; - std::unique_ptr<asio::ip::tcp::socket> m_Socket; - std::unique_ptr<asio::steady_timer> m_Timer; - asio::streambuf m_ReadBuffer; - std::string m_WebSocketKey; - std::shared_ptr<std::string> m_HandshakeBuffer; + std::unique_ptr<asio::ip::tcp::socket> m_TcpSocket; +#if defined(ASIO_HAS_LOCAL_SOCKETS) + std::unique_ptr<asio::local::stream_protocol::socket> m_UnixSocket; +#endif + std::unique_ptr<asio::steady_timer> m_Timer; + asio::streambuf m_ReadBuffer; + std::string m_WebSocketKey; + std::shared_ptr<std::string> m_HandshakeBuffer; // Write queue RwLock m_WriteLock; diff --git a/src/zenhttp/httpclient.cpp b/src/zenhttp/httpclient.cpp index 281d512cf..9baf4346e 100644 --- a/src/zenhttp/httpclient.cpp +++ b/src/zenhttp/httpclient.cpp @@ -40,6 +40,35 @@ extern HttpClientBase* CreateCprHttpClient(std::string_view BaseUri, const HttpClientSettings& ConnectionSettings, std::function<bool()>&& CheckIfAbortFunction); +extern HttpClientBase* CreateCurlHttpClient(std::string_view BaseUri, + const HttpClientSettings& ConnectionSettings, + std::function<bool()>&& CheckIfAbortFunction); + +static HttpClientBackend g_DefaultHttpClientBackend = HttpClientBackend::kCpr; + +void +SetDefaultHttpClientBackend(HttpClientBackend Backend) +{ + g_DefaultHttpClientBackend = Backend; +} + +void +SetDefaultHttpClientBackend(std::string_view Backend) +{ + if (Backend == "cpr") + { + g_DefaultHttpClientBackend = HttpClientBackend::kCpr; + } + else if (Backend == "curl") + { + g_DefaultHttpClientBackend = HttpClientBackend::kCurl; + } + else + { + g_DefaultHttpClientBackend = HttpClientBackend::kDefault; + } +} + using namespace std::literals; ////////////////////////////////////////////////////////////////////////// @@ -104,6 +133,71 @@ HttpClientBase::GetAccessToken() ////////////////////////////////////////////////////////////////////////// +HttpClientError::ResponseClass +HttpClientError::GetResponseClass() const +{ + if (m_Error != HttpClientErrorCode::kOK) + { + switch (m_Error) + { + case HttpClientErrorCode::kConnectionFailure: + return ResponseClass::kHttpCantConnectError; + case HttpClientErrorCode::kHostResolutionFailure: + case HttpClientErrorCode::kProxyResolutionFailure: + return ResponseClass::kHttpNoHost; + case HttpClientErrorCode::kInternalError: + case HttpClientErrorCode::kNetworkReceiveError: + case HttpClientErrorCode::kNetworkSendFailure: + case HttpClientErrorCode::kOperationTimedOut: + return ResponseClass::kHttpTimeout; + case HttpClientErrorCode::kSSLConnectError: + case HttpClientErrorCode::kSSLCertificateError: + case HttpClientErrorCode::kSSLCACertError: + case HttpClientErrorCode::kGenericSSLError: + return ResponseClass::kHttpSLLError; + default: + return ResponseClass::kHttpOtherClientError; + } + } + else if (IsHttpSuccessCode(m_ResponseCode)) + { + return ResponseClass::kSuccess; + } + else + { + switch (m_ResponseCode) + { + case HttpResponseCode::Unauthorized: + return ResponseClass::kHttpUnauthorized; + case HttpResponseCode::NotFound: + return ResponseClass::kHttpNotFound; + case HttpResponseCode::Forbidden: + return ResponseClass::kHttpForbidden; + case HttpResponseCode::Conflict: + return ResponseClass::kHttpConflict; + case HttpResponseCode::InternalServerError: + return ResponseClass::kHttpInternalServerError; + case HttpResponseCode::ServiceUnavailable: + return ResponseClass::kHttpServiceUnavailable; + case HttpResponseCode::BadGateway: + return ResponseClass::kHttpBadGateway; + case HttpResponseCode::GatewayTimeout: + return ResponseClass::kHttpGatewayTimeout; + default: + if (m_ResponseCode >= HttpResponseCode::InternalServerError) + { + return ResponseClass::kHttpOtherServerError; + } + else + { + return ResponseClass::kHttpOtherClientError; + } + } + } +} + +////////////////////////////////////////////////////////////////////////// + std::vector<std::pair<uint64_t, uint64_t>> HttpClient::Response::GetRanges(std::span<const std::pair<uint64_t, uint64_t>> OffsetAndLengthPairs) const { @@ -222,7 +316,11 @@ HttpClient::Response::ErrorMessage(std::string_view Prefix) const { if (Error.has_value()) { - return fmt::format("{}{}HTTP error ({}) '{}'", Prefix, Prefix.empty() ? ""sv : ": "sv, Error->ErrorCode, Error->ErrorMessage); + return fmt::format("{}{}HTTP error ({}) '{}'", + Prefix, + Prefix.empty() ? ""sv : ": "sv, + static_cast<int>(Error->ErrorCode), + Error->ErrorMessage); } else if (StatusCode != HttpResponseCode::ImATeapot && (int)StatusCode) { @@ -245,19 +343,34 @@ HttpClient::Response::ThrowError(std::string_view ErrorPrefix) { if (!IsSuccess()) { - throw HttpClientError(ErrorMessage(ErrorPrefix), Error.has_value() ? Error.value().ErrorCode : 0, StatusCode); + throw HttpClientError(ErrorMessage(ErrorPrefix), + Error.has_value() ? Error.value().ErrorCode : HttpClientErrorCode::kOK, + StatusCode); } } ////////////////////////////////////////////////////////////////////////// HttpClient::HttpClient(std::string_view BaseUri, const HttpClientSettings& ConnectionSettings, std::function<bool()>&& CheckIfAbortFunction) -: m_BaseUri(BaseUri) +: m_Log(zen::logging::Get(ConnectionSettings.LogCategory)) +, m_BaseUri(BaseUri) , m_ConnectionSettings(ConnectionSettings) { m_SessionId = GetSessionIdString(); - m_Inner = CreateCprHttpClient(BaseUri, ConnectionSettings, std::move(CheckIfAbortFunction)); + HttpClientBackend EffectiveBackend = + ConnectionSettings.Backend != HttpClientBackend::kDefault ? ConnectionSettings.Backend : g_DefaultHttpClientBackend; + + switch (EffectiveBackend) + { + case HttpClientBackend::kCurl: + m_Inner = CreateCurlHttpClient(BaseUri, ConnectionSettings, std::move(CheckIfAbortFunction)); + break; + case HttpClientBackend::kCpr: + default: + m_Inner = CreateCprHttpClient(BaseUri, ConnectionSettings, std::move(CheckIfAbortFunction)); + break; + } } HttpClient::~HttpClient() diff --git a/src/zenhttp/httpclient_test.cpp b/src/zenhttp/httpclient_test.cpp index 52bf149a7..2d949c546 100644 --- a/src/zenhttp/httpclient_test.cpp +++ b/src/zenhttp/httpclient_test.cpp @@ -8,6 +8,7 @@ # include <zencore/compactbinarybuilder.h> # include <zencore/compactbinaryutil.h> # include <zencore/compositebuffer.h> +# include <zencore/filesystem.h> # include <zencore/iobuffer.h> # include <zencore/logging.h> # include <zencore/scopeguard.h> @@ -232,7 +233,7 @@ struct TestServerFixture TestServerFixture() { Server = CreateHttpAsioServer(AsioConfig{}); - Port = Server->Initialize(7600, TmpDir.Path()); + Port = Server->Initialize(0, TmpDir.Path()); ZEN_ASSERT(Port != -1); Server->RegisterService(TestService); ServerThread = std::thread([this]() { Server->Run(false); }); @@ -1044,13 +1045,22 @@ struct FaultTcpServer { m_Port = m_Acceptor.local_endpoint().port(); StartAccept(); - m_Thread = std::thread([this]() { m_IoContext.run(); }); + m_Thread = std::thread([this]() { + try + { + m_IoContext.run(); + } + catch (...) + { + } + }); } ~FaultTcpServer() { - std::error_code Ec; - m_Acceptor.close(Ec); + // io_context::stop() is thread-safe; do NOT call m_Acceptor.close() from this + // thread — ASIO I/O objects are not safe for concurrent access and the io_context + // thread may be touching the acceptor in StartAccept(). m_IoContext.stop(); if (m_Thread.joinable()) { @@ -1081,6 +1091,105 @@ struct FaultTcpServer } }; +TEST_CASE("httpclient.range-response") +{ + ScopedTemporaryDirectory DownloadDir; + + SUBCASE("single range 206 response populates Ranges") + { + std::string RangeBody(100, 'A'); + + FaultTcpServer Server([&](asio::ip::tcp::socket& Socket) { + DrainHttpRequest(Socket); + std::string Response = fmt::format( + "HTTP/1.1 206 Partial Content\r\n" + "Content-Type: application/octet-stream\r\n" + "Content-Range: bytes 200-299/1000\r\n" + "Content-Length: {}\r\n" + "\r\n" + "{}", + RangeBody.size(), + RangeBody); + std::error_code Ec; + asio::write(Socket, asio::buffer(Response), Ec); + }); + + HttpClient Client = Server.MakeClient(); + HttpClient::Response Resp = Client.Download("/test", DownloadDir.Path()); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.StatusCode, HttpResponseCode::PartialContent); + REQUIRE(Resp.Ranges.size() == 1); + CHECK_EQ(Resp.Ranges[0].RangeOffset, 200); + CHECK_EQ(Resp.Ranges[0].RangeLength, 100); + } + + SUBCASE("multipart byteranges 206 response populates Ranges") + { + std::string Part1Data(16, 'X'); + std::string Part2Data(12, 'Y'); + std::string Boundary = "testboundary123"; + + std::string MultipartBody = fmt::format( + "\r\n--{}\r\n" + "Content-Type: application/octet-stream\r\n" + "Content-Range: bytes 100-115/1000\r\n" + "\r\n" + "{}" + "\r\n--{}\r\n" + "Content-Type: application/octet-stream\r\n" + "Content-Range: bytes 500-511/1000\r\n" + "\r\n" + "{}" + "\r\n--{}--", + Boundary, + Part1Data, + Boundary, + Part2Data, + Boundary); + + FaultTcpServer Server([&](asio::ip::tcp::socket& Socket) { + DrainHttpRequest(Socket); + std::string Response = fmt::format( + "HTTP/1.1 206 Partial Content\r\n" + "Content-Type: multipart/byteranges; boundary={}\r\n" + "Content-Length: {}\r\n" + "\r\n" + "{}", + Boundary, + MultipartBody.size(), + MultipartBody); + std::error_code Ec; + asio::write(Socket, asio::buffer(Response), Ec); + }); + + HttpClient Client = Server.MakeClient(); + HttpClient::Response Resp = Client.Download("/test", DownloadDir.Path()); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.StatusCode, HttpResponseCode::PartialContent); + REQUIRE(Resp.Ranges.size() == 2); + // Ranges should be sorted by RangeOffset + CHECK_EQ(Resp.Ranges[0].RangeOffset, 100); + CHECK_EQ(Resp.Ranges[0].RangeLength, 16); + CHECK_EQ(Resp.Ranges[1].RangeOffset, 500); + CHECK_EQ(Resp.Ranges[1].RangeLength, 12); + } + + SUBCASE("non-range 200 response has empty Ranges") + { + FaultTcpServer Server([&](asio::ip::tcp::socket& Socket) { + DrainHttpRequest(Socket); + std::string Response = MakeRawHttpResponse(200, "full content"); + std::error_code Ec; + asio::write(Socket, asio::buffer(Response), Ec); + }); + + HttpClient Client = Server.MakeClient(); + HttpClient::Response Resp = Client.Download("/test", DownloadDir.Path()); + CHECK(Resp.IsSuccess()); + CHECK(Resp.Ranges.empty()); + } +} + TEST_CASE("httpclient.transport-faults" * doctest::skip()) { SUBCASE("connection reset before response") @@ -1354,6 +1463,188 @@ TEST_CASE("httpclient.transport-faults-post" * doctest::skip()) } } +TEST_CASE("httpclient.unixsocket") +{ + ScopedTemporaryDirectory TmpDir; + std::string SocketPath = (TmpDir.Path() / "zen.sock").string(); + + HttpClientTestService TestService; + + Ref<HttpServer> Server = CreateHttpAsioServer(AsioConfig{.UnixSocketPath = SocketPath}); + + int Port = Server->Initialize(0, TmpDir.Path()); + REQUIRE(Port != -1); + + Server->RegisterService(TestService); + + std::thread ServerThread([&]() { Server->Run(false); }); + + auto _ = MakeGuard([&]() { + Server->RequestExit(); + if (ServerThread.joinable()) + { + ServerThread.join(); + } + Server->Close(); + }); + + HttpClientSettings Settings; + Settings.UnixSocketPath = SocketPath; + + HttpClient Client("localhost", Settings, /*CheckIfAbortFunction*/ {}); + + SUBCASE("GET over unix socket") + { + HttpClient::Response Resp = Client.Get("/api/test/hello"); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), "hello world"); + } + + SUBCASE("POST echo over unix socket") + { + const char* Payload = "unix socket payload"; + IoBuffer Body(IoBuffer::Clone, Payload, strlen(Payload)); + Body.SetContentType(ZenContentType::kText); + + HttpClient::Response Resp = Client.Post("/api/test/echo", Body); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), "unix socket payload"); + } +} + +# if ZEN_USE_OPENSSL + +TEST_CASE("httpclient.https") +{ + // Self-signed test certificate for localhost/127.0.0.1, valid until 2036 + static constexpr std::string_view TestCertPem = + "-----BEGIN CERTIFICATE-----\n" + "MIIDJTCCAg2gAwIBAgIUEtJYMSUmJmvJ157We/qXNVJ7W8gwDQYJKoZIhvcNAQEL\n" + "BQAwFDESMBAGA1UEAwwJbG9jYWxob3N0MB4XDTI2MDMwOTIwMjU1M1oXDTM2MDMw\n" + "NjIwMjU1M1owFDESMBAGA1UEAwwJbG9jYWxob3N0MIIBIjANBgkqhkiG9w0BAQEF\n" + "AAOCAQ8AMIIBCgKCAQEAv9YvZ6WeBz3z/Zuxi6OIivWksDxDZZ5oAXKVwlUXaa7v\n" + "iDkm9P5ZsEhN+M5vZMe2Yb9i3cnTUaE6Avs1ddOwTAYNGrE/B5DmibrRWc23R0cv\n" + "gdnYQJ+gjsAeMvUWYLK58xW4YoMR5bmfpj1ruqobUNkG/oJYnAUcjgo4J149irW+\n" + "4n9uLJvxL+5fI/b/AIkv+4TMe70/d/BPmnixWrrzxUT6S5ghE2Mq7+XLScfpY2Sp\n" + "GQ/Xbnj9/ELYLpQnNLuVZwWZDpXj+FLbF1zxgjYdw1cCjbRcOIEW2/GJeJvGXQ6Y\n" + "Vld5pCBm9uKPPLWoFCoakK5YvP00h+8X+HghGVSscQIDAQABo28wbTAdBgNVHQ4E\n" + "FgQUgM6hjymi6g2EBUg2ENu0nIK8yhMwHwYDVR0jBBgwFoAUgM6hjymi6g2EBUg2\n" + "ENu0nIK8yhMwDwYDVR0TAQH/BAUwAwEB/zAaBgNVHREEEzARhwR/AAABgglsb2Nh\n" + "bGhvc3QwDQYJKoZIhvcNAQELBQADggEBABY1oaaWwL4RaK/epKvk/IrmVT2mlAai\n" + "uvGLfjhc6FGvXaxPGTSUPrVbFornaWZAg7bOWCexWnEm2sWd75V/usvZAPN4aIiD\n" + "H66YQipq3OD4F9Gowp01IU4AcGh7MerFpYPk76+wp2ANq71x8axtlZjVn3hSFMmN\n" + "i6m9S/eyCl9WjYBT5ZEC4fJV0nOSmNe/+gCAm11/js9zNfXKmUchJtuZpubY3A0k\n" + "X2II6qYWf1PH+JJkefNZtt2c66CrEN5eAg4/rGEgsp43zcd4ZHVkpBKFLDEls1ev\n" + "drQ45zc4Ht77pHfnHu7YsLcRZ9Wq3COMNZYx5lItqnomX2qBm1pkwjI=\n" + "-----END CERTIFICATE-----\n"; + + static constexpr std::string_view TestKeyPem = + "-----BEGIN PRIVATE KEY-----\n" + "MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQC/1i9npZ4HPfP9\n" + "m7GLo4iK9aSwPENlnmgBcpXCVRdpru+IOSb0/lmwSE34zm9kx7Zhv2LdydNRoToC\n" + "+zV107BMBg0asT8HkOaJutFZzbdHRy+B2dhAn6COwB4y9RZgsrnzFbhigxHluZ+m\n" + "PWu6qhtQ2Qb+glicBRyOCjgnXj2Ktb7if24sm/Ev7l8j9v8AiS/7hMx7vT938E+a\n" + "eLFauvPFRPpLmCETYyrv5ctJx+ljZKkZD9dueP38QtgulCc0u5VnBZkOleP4UtsX\n" + "XPGCNh3DVwKNtFw4gRbb8Yl4m8ZdDphWV3mkIGb24o88tagUKhqQrli8/TSH7xf4\n" + "eCEZVKxxAgMBAAECggEAILd9pDaZqfCF8SWhdQgx3Ekiii/s6qLGaCDLq7XpZUvB\n" + "bEEbBMNwNmFOcvV6B/0LfMYwLVUjZhOSGjoPlwXAVmbdy0SZVEgBGVI0LBWqgUyB\n" + "rKqjd/oBXvci71vfMiSpE+0LYjmqTryGnspw2gfy2qn4yGUgiZNRmGPjycsHweUL\n" + "V3FHm3cf0dyE4sJ0mjVqZzRT/unw2QOCE6FlY7M1XxZL88IWfn6G4lckdJTwoOP5\n" + "VPR2J3XbyhvCeXeDRCHKRXojWWR2HovWnDXQc95GRgCd0vYdHuIUM6RXVPZQvy3X\n" + "l0GwQKHNcVr1uwtYDgGKw0tNCUDvxdfQaWilTFuicQKBgQDvEYp+vL1hnF+AVdu3\n" + "elsYsHpFgExkTI8wnUMvGZrFiIQyCyVDU3jkG3kcKacI1bfwopXopaQCjrYk9epm\n" + "liOVm3/Xtr6e2ENa7w8TQbdK65PciQNOMxml6g8clRRBl0cwj+aI3nW/Kop1cdrR\n" + "A9Vo+8iPTO5gDcxTiIb45a6E3QKBgQDNbE009P6ewx9PU7Llkhb9VBgsb7oQN3EV\n" + "TCYd4taiN6FPnTuL/cdijAA8y04hiVT+Efo9TUN9NCl9HdHXQcjj7/n/eFLH0Pkw\n" + "OIK3QN49OfR88wivLMtwWxIog0tJjc9+7dR4bR4o1jTlIrasEIvUTuDJQ8MKGc9v\n" + "pBITua+SpQKBgE4raSKZqj7hd6Sp7kbnHiRLiB9znQbqtaNKuK4M7DuMsNUAKfYC\n" + "tDO5+/bGc9SCtTtcnjHM/3zKlyossrFKhGYlyz6IhXnA8v0nz8EXKsy3jMh+kHMg\n" + "aFGE394TrOTphyCM3O+B9fRE/7L5QHg5ja1fLqwUlpkXyejCaoe16kONAoGAYIz9\n" + "wN1B67cEOVG6rOI8QfdLoV8mEcctNHhlFfjvLrF89SGOwl6WX0A0QF7CK0sUEpK6\n" + "jiOJjAh/U5o3bbgyxsedNjEEn3weE0cMUTuA+UALJMtKEqO4PuffIgGL2ld35k28\n" + "ZpnK6iC8HdJyD297eV9VkeNygYXeFLgF8xV8ay0CgYEAh4fmVZt9YhgVByYny2kF\n" + "ZUIkGF5h9wxzVOPpQwpizIGFFb3i/ZdGQcuLTfIBVRKf50sT3IwJe65ATv6+Lz0f\n" + "wg/pMvosi0/F5KGbVRVdzBMQy58WyyGti4tNl+8EXGvo8+DCmjlTYwfjRoZGg/qJ\n" + "EMP3/hTN7dHDRxPK8E0Fh0Y=\n" + "-----END PRIVATE KEY-----\n"; + + ScopedTemporaryDirectory TmpDir; + + // Write cert and key to temp files + const auto CertPath = TmpDir.Path() / "test.crt"; + const auto KeyPath = TmpDir.Path() / "test.key"; + WriteFile(CertPath, IoBuffer(IoBuffer::Clone, TestCertPem.data(), TestCertPem.size())); + WriteFile(KeyPath, IoBuffer(IoBuffer::Clone, TestKeyPem.data(), TestKeyPem.size())); + + HttpClientTestService TestService; + + AsioConfig Config; + Config.CertFile = CertPath.string(); + Config.KeyFile = KeyPath.string(); + + Ref<HttpServer> Server = CreateHttpAsioServer(Config); + + int Port = Server->Initialize(0, TmpDir.Path()); + REQUIRE(Port != -1); + + Server->RegisterService(TestService); + + std::thread ServerThread([&]() { Server->Run(false); }); + + auto _ = MakeGuard([&]() { + Server->RequestExit(); + if (ServerThread.joinable()) + { + ServerThread.join(); + } + Server->Close(); + }); + + int HttpsPort = Server->GetEffectiveHttpsPort(); + REQUIRE(HttpsPort > 0); + + HttpClientSettings Settings; + Settings.InsecureSsl = true; + + HttpClient Client(fmt::format("https://127.0.0.1:{}", HttpsPort), Settings, /*CheckIfAbortFunction*/ {}); + + SUBCASE("GET over HTTPS") + { + HttpClient::Response Resp = Client.Get("/api/test/hello"); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), "hello world"); + } + + SUBCASE("POST echo over HTTPS") + { + const char* Payload = "https payload"; + IoBuffer Body(IoBuffer::Clone, Payload, strlen(Payload)); + Body.SetContentType(ZenContentType::kText); + + HttpClient::Response Resp = Client.Post("/api/test/echo", Body); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), "https payload"); + } + + SUBCASE("GET JSON over HTTPS") + { + HttpClient::Response Resp = Client.Get("/api/test/json"); + CHECK(Resp.IsSuccess()); + CbObject Obj = Resp.AsObject(); + CHECK_EQ(Obj["ok"].AsBool(), true); + CHECK_EQ(Obj["message"].AsString(), "test"); + } + + SUBCASE("Large payload over HTTPS") + { + HttpClient::Response Resp = Client.Get("/api/test/large"); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.ResponsePayload.GetSize(), 64u * 1024u); + } +} + +# endif // ZEN_USE_OPENSSL + TEST_SUITE_END(); void diff --git a/src/zenhttp/httpserver.cpp b/src/zenhttp/httpserver.cpp index 9bae95690..1a0018908 100644 --- a/src/zenhttp/httpserver.cpp +++ b/src/zenhttp/httpserver.cpp @@ -1044,13 +1044,16 @@ HttpServer::OnGetExternalHost() const std::string HttpServer::GetServiceUri(const HttpService* Service) const { + const char* Scheme = (m_EffectiveHttpsPort > 0) ? "https" : "http"; + int Port = (m_EffectiveHttpsPort > 0) ? m_EffectiveHttpsPort : m_EffectivePort; + if (Service) { - return fmt::format("http://{}:{}{}", m_ExternalHost, m_EffectivePort, Service->BaseUri()); + return fmt::format("{}://{}:{}{}", Scheme, m_ExternalHost, Port, Service->BaseUri()); } else { - return fmt::format("http://{}:{}", m_ExternalHost, m_EffectivePort); + return fmt::format("{}://{}:{}", Scheme, m_ExternalHost, Port); } } @@ -1152,9 +1155,13 @@ CreateHttpServerClass(const std::string_view ServerClass, const HttpServerConfig if (ServerClass == "asio"sv) { ZEN_INFO("using asio HTTP server implementation") - return CreateHttpAsioServer(AsioConfig{.ThreadCount = Config.ThreadCount, - .ForceLoopback = Config.ForceLoopback, - .IsDedicatedServer = Config.IsDedicatedServer}); + return CreateHttpAsioServer(AsioConfig { + .ThreadCount = Config.ThreadCount, .ForceLoopback = Config.ForceLoopback, .IsDedicatedServer = Config.IsDedicatedServer, + .UnixSocketPath = Config.UnixSocketPath, +#if ZEN_USE_OPENSSL + .HttpsPort = Config.HttpsPort, .CertFile = Config.CertFile, .KeyFile = Config.KeyFile, +#endif + }); } #if ZEN_WITH_HTTPSYS else if (ServerClass == "httpsys"sv) @@ -1165,7 +1172,11 @@ CreateHttpServerClass(const std::string_view ServerClass, const HttpServerConfig .IsAsyncResponseEnabled = Config.HttpSys.IsAsyncResponseEnabled, .IsRequestLoggingEnabled = Config.HttpSys.IsRequestLoggingEnabled, .IsDedicatedServer = Config.IsDedicatedServer, - .ForceLoopback = Config.ForceLoopback})); + .ForceLoopback = Config.ForceLoopback, + .HttpsPort = Config.HttpSys.HttpsPort, + .CertThumbprint = Config.HttpSys.CertThumbprint, + .CertStoreName = Config.HttpSys.CertStoreName, + .HttpsOnly = Config.HttpSys.HttpsOnly})); } #endif else if (ServerClass == "null"sv) diff --git a/src/zenhttp/include/zenhttp/formatters.h b/src/zenhttp/include/zenhttp/formatters.h index 57ab01158..90180391c 100644 --- a/src/zenhttp/include/zenhttp/formatters.h +++ b/src/zenhttp/include/zenhttp/formatters.h @@ -84,7 +84,7 @@ struct fmt::formatter<zen::HttpClient::Response> return fmt::format_to(Ctx.out(), "Failed: Elapsed: {}, Reason: ({}) '{}", NiceResponseTime, - Response.Error.value().ErrorCode, + static_cast<int>(Response.Error.value().ErrorCode), Response.Error.value().ErrorMessage); } else diff --git a/src/zenhttp/include/zenhttp/httpclient.h b/src/zenhttp/include/zenhttp/httpclient.h index 1bb36a298..2e21e3bd6 100644 --- a/src/zenhttp/include/zenhttp/httpclient.h +++ b/src/zenhttp/include/zenhttp/httpclient.h @@ -30,6 +30,34 @@ class CompositeBuffer; */ +enum class HttpClientErrorCode : int +{ + kOK = 0, + kConnectionFailure, + kHostResolutionFailure, + kProxyResolutionFailure, + kInternalError, + kNetworkReceiveError, + kNetworkSendFailure, + kOperationTimedOut, + kSSLConnectError, + kSSLCertificateError, + kSSLCACertError, + kGenericSSLError, + kRequestCancelled, + kOtherError, +}; + +enum class HttpClientBackend : uint8_t +{ + kDefault, + kCpr, + kCurl, +}; + +void SetDefaultHttpClientBackend(std::string_view Backend); +void SetDefaultHttpClientBackend(HttpClientBackend Backend); + struct HttpClientAccessToken { using Clock = std::chrono::system_clock; @@ -59,6 +87,22 @@ struct HttpClientSettings Oid SessionId = Oid::Zero; bool Verbose = false; uint64_t MaximumInMemoryDownloadSize = 1024u * 1024u; + HttpClientBackend Backend = HttpClientBackend::kDefault; + + /// Unix domain socket path. When non-empty, the client connects via this + /// socket instead of TCP. BaseUri is still used for the Host header and URL. + std::string UnixSocketPath; + + /// Disable HTTP keep-alive by closing the connection after each request. + /// Useful for testing per-connection overhead. + bool ForbidReuseConnection = false; + + /// Skip TLS certificate verification (for testing with self-signed certs). + bool InsecureSsl = false; + + /// CA certificate bundle path for TLS verification. When non-empty, overrides + /// the system default CA store. + std::string CaBundlePath; /// HTTP status codes that are expected and should not be logged as warnings. /// 404 is always treated as expected regardless of this list. @@ -70,22 +114,22 @@ class HttpClientError : public std::runtime_error public: using _Mybase = runtime_error; - HttpClientError(const std::string& Message, int Error, HttpResponseCode ResponseCode) + HttpClientError(const std::string& Message, HttpClientErrorCode Error, HttpResponseCode ResponseCode) : _Mybase(Message) , m_Error(Error) , m_ResponseCode(ResponseCode) { } - HttpClientError(const char* Message, int Error, HttpResponseCode ResponseCode) + HttpClientError(const char* Message, HttpClientErrorCode Error, HttpResponseCode ResponseCode) : _Mybase(Message) , m_Error(Error) , m_ResponseCode(ResponseCode) { } - inline int GetInternalErrorCode() const { return m_Error; } - inline HttpResponseCode GetHttpResponseCode() const { return m_ResponseCode; } + inline HttpClientErrorCode GetInternalErrorCode() const { return m_Error; } + inline HttpResponseCode GetHttpResponseCode() const { return m_ResponseCode; } enum class ResponseClass : std::int8_t { @@ -112,8 +156,8 @@ public: ResponseClass GetResponseClass() const; private: - const int m_Error = 0; - const HttpResponseCode m_ResponseCode = HttpResponseCode::ImATeapot; + const HttpClientErrorCode m_Error = HttpClientErrorCode::kOK; + const HttpResponseCode m_ResponseCode = HttpResponseCode::ImATeapot; }; class HttpClientBase; @@ -137,11 +181,23 @@ public: struct ErrorContext { - int ErrorCode = 0; - std::string ErrorMessage; + HttpClientErrorCode ErrorCode; + std::string ErrorMessage; /** True when the error is a transport-level connection failure (connect timeout, refused, DNS) */ - bool IsConnectionError() const; + bool IsConnectionError() const + { + switch (ErrorCode) + { + case HttpClientErrorCode::kConnectionFailure: + case HttpClientErrorCode::kOperationTimedOut: + case HttpClientErrorCode::kHostResolutionFailure: + case HttpClientErrorCode::kProxyResolutionFailure: + return true; + default: + return false; + } + } }; struct KeyValueMap diff --git a/src/zenhttp/include/zenhttp/httpserver.h b/src/zenhttp/include/zenhttp/httpserver.h index 0e1714669..d98877d16 100644 --- a/src/zenhttp/include/zenhttp/httpserver.h +++ b/src/zenhttp/include/zenhttp/httpserver.h @@ -255,6 +255,9 @@ public: */ std::string_view GetExternalHost() const { return m_ExternalHost; } + /** Returns the effective HTTPS port, or 0 if HTTPS is not enabled. Only valid after Initialize(). */ + int GetEffectiveHttpsPort() const { return m_EffectiveHttpsPort; } + /** Returns total bytes received and sent across all connections since server start. */ virtual uint64_t GetTotalBytesReceived() const { return 0; } virtual uint64_t GetTotalBytesSent() const { return 0; } @@ -290,7 +293,8 @@ public: private: std::vector<HttpService*> m_KnownServices; - int m_EffectivePort = 0; + int m_EffectivePort = 0; + int m_EffectiveHttpsPort = 0; std::string m_ExternalHost; metrics::Meter m_RequestMeter; std::string m_DefaultRedirect; @@ -308,6 +312,7 @@ private: virtual void OnClose() = 0; protected: + void SetEffectiveHttpsPort(int Port) { m_EffectiveHttpsPort = Port; } virtual std::string OnGetExternalHost() const; }; @@ -324,12 +329,20 @@ struct HttpServerConfig std::vector<HttpServerPluginConfig> PluginConfigs; bool ForceLoopback = false; unsigned int ThreadCount = 0; + std::string UnixSocketPath; // Unix domain socket path (empty = disabled, non-Windows only) + int HttpsPort = 0; // HTTPS listen port (0 = disabled, ASIO backend) + std::string CertFile; // PEM certificate chain file path + std::string KeyFile; // PEM private key file path struct { unsigned int AsyncWorkThreadCount = 0; bool IsAsyncResponseEnabled = true; bool IsRequestLoggingEnabled = false; + int HttpsPort = 0; // 0 = HTTPS disabled + std::string CertThumbprint; // Hex SHA-1 (40 chars) for auto SSL binding + std::string CertStoreName = "MY"; // Windows certificate store name + bool HttpsOnly = false; // When true, disable HTTP listener } HttpSys; }; diff --git a/src/zenhttp/include/zenhttp/httpwsclient.h b/src/zenhttp/include/zenhttp/httpwsclient.h index 926ec1e3d..34d338b1d 100644 --- a/src/zenhttp/include/zenhttp/httpwsclient.h +++ b/src/zenhttp/include/zenhttp/httpwsclient.h @@ -43,6 +43,10 @@ struct HttpWsClientSettings std::string LogCategory = "wsclient"; std::chrono::milliseconds ConnectTimeout{5000}; std::optional<std::function<HttpClientAccessToken()>> AccessTokenProvider; + + /// Unix domain socket path. When non-empty, connects via this socket + /// instead of TCP. The URL host is still used for the Host header. + std::string UnixSocketPath; }; /** diff --git a/src/zenhttp/servers/asio_socket_traits.h b/src/zenhttp/servers/asio_socket_traits.h new file mode 100644 index 000000000..25aeaa24e --- /dev/null +++ b/src/zenhttp/servers/asio_socket_traits.h @@ -0,0 +1,54 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +ZEN_THIRD_PARTY_INCLUDES_START +#include <asio.hpp> +#if ZEN_USE_OPENSSL +# include <asio/ssl.hpp> +#endif +ZEN_THIRD_PARTY_INCLUDES_END + +namespace zen::asio_http { + +/** + * Traits for abstracting socket shutdown/close across plain TCP, Unix domain, and SSL sockets. + * SSL sockets need lowest_layer() access and have different shutdown semantics. + */ +template<typename SocketType> +struct SocketTraits +{ + /// SSL sockets cannot use zero-copy file send (TransmitFile/sendfile) because + /// those bypass the encryption layer. This flag lets templated code fall back + /// to reading-into-memory for SSL connections. + static constexpr bool IsSslSocket = false; + + static void ShutdownReceive(SocketType& S, std::error_code& Ec) { S.shutdown(asio::socket_base::shutdown_receive, Ec); } + + static void ShutdownBoth(SocketType& S, std::error_code& Ec) { S.shutdown(asio::socket_base::shutdown_both, Ec); } + + static void Close(SocketType& S, std::error_code& Ec) { S.close(Ec); } +}; + +#if ZEN_USE_OPENSSL +using SslSocket = asio::ssl::stream<asio::ip::tcp::socket>; + +template<> +struct SocketTraits<SslSocket> +{ + static constexpr bool IsSslSocket = true; + + static void ShutdownReceive(SslSocket& S, std::error_code& Ec) { S.lowest_layer().shutdown(asio::socket_base::shutdown_receive, Ec); } + + static void ShutdownBoth(SslSocket& S, std::error_code& Ec) + { + // Best-effort SSL close_notify, then TCP shutdown + S.shutdown(Ec); + S.lowest_layer().shutdown(asio::socket_base::shutdown_both, Ec); + } + + static void Close(SslSocket& S, std::error_code& Ec) { S.lowest_layer().close(Ec); } +}; +#endif + +} // namespace zen::asio_http diff --git a/src/zenhttp/servers/httpasio.cpp b/src/zenhttp/servers/httpasio.cpp index f5178ebe8..ee8e71256 100644 --- a/src/zenhttp/servers/httpasio.cpp +++ b/src/zenhttp/servers/httpasio.cpp @@ -1,6 +1,7 @@ // Copyright Epic Games, Inc. All Rights Reserved. #include "httpasio.h" +#include "asio_socket_traits.h" #include "httptracer.h" #include <zencore/except.h> @@ -35,6 +36,12 @@ ZEN_THIRD_PARTY_INCLUDES_START #endif #include <asio.hpp> #include <asio/stream_file.hpp> +#if defined(ASIO_HAS_LOCAL_SOCKETS) +# include <asio/local/stream_protocol.hpp> +#endif +#if ZEN_USE_OPENSSL +# include <asio/ssl.hpp> +#endif ZEN_THIRD_PARTY_INCLUDES_END #define ASIO_VERBOSE_TRACE 0 @@ -144,7 +151,17 @@ using namespace std::literals; struct HttpAcceptor; struct HttpResponse; -struct HttpServerConnection; +template<typename SocketType> +struct HttpServerConnectionT; +using HttpServerConnection = HttpServerConnectionT<asio::ip::tcp::socket>; +#if defined(ASIO_HAS_LOCAL_SOCKETS) +struct UnixAcceptor; +using UnixServerConnection = HttpServerConnectionT<asio::local::stream_protocol::socket>; +#endif +#if ZEN_USE_OPENSSL +struct HttpsAcceptor; +using HttpsSslServerConnection = HttpServerConnectionT<SslSocket>; +#endif inline LoggerRef InitLogger() @@ -176,9 +193,9 @@ Log() #endif #if ZEN_USE_TRANSMITFILE -template<typename Handler> +template<typename Handler, typename SocketType> void -TransmitFileAsync(asio::ip::tcp::socket& Socket, HANDLE FileHandle, uint64_t ByteOffset, uint32_t ByteSize, Handler&& Cb) +TransmitFileAsync(SocketType& Socket, HANDLE FileHandle, uint64_t ByteOffset, uint32_t ByteSize, Handler&& Cb) { # if ZEN_BUILD_DEBUG const uint64_t FileSize = FileSizeFromHandle(FileHandle); @@ -511,11 +528,20 @@ public: bool IsLoopbackOnly() const; + int GetEffectiveHttpsPort() const; + asio::io_service m_IoService; asio::io_service::work m_Work{m_IoService}; std::unique_ptr<asio_http::HttpAcceptor> m_Acceptor; - std::vector<std::thread> m_ThreadPool; - std::atomic<IHttpRequestFilter*> m_HttpRequestFilter = nullptr; +#if defined(ASIO_HAS_LOCAL_SOCKETS) + std::unique_ptr<asio_http::UnixAcceptor> m_UnixAcceptor; +#endif +#if ZEN_USE_OPENSSL + std::unique_ptr<asio::ssl::context> m_SslContext; + std::unique_ptr<asio_http::HttpsAcceptor> m_HttpsAcceptor; +#endif + std::vector<std::thread> m_ThreadPool; + std::atomic<IHttpRequestFilter*> m_HttpRequestFilter = nullptr; LoggerRef m_RequestLog; HttpServerTracer m_RequestTracer; @@ -573,6 +599,7 @@ public: uint32_t m_RequestNumber = 0; // Note: different to request ID which is derived from headers IoBuffer m_PayloadBuffer; bool m_IsLocalMachineRequest; + bool m_AllowZeroCopyFileSend = true; std::string m_RemoteAddress; std::unique_ptr<HttpResponse> m_Response; }; @@ -595,6 +622,8 @@ public: ~HttpResponse() = default; + void SetAllowZeroCopyFileSend(bool Allow) { m_AllowZeroCopyFileSend = Allow; } + /** * Initialize the response for sending a payload made up of multiple blobs * @@ -636,7 +665,7 @@ public: bool ChunkHandled = false; #if ZEN_USE_TRANSMITFILE || ZEN_USE_ASYNC_SENDFILE - if (OwnedBuffer.IsWholeFile()) + if (m_AllowZeroCopyFileSend && OwnedBuffer.IsWholeFile()) { if (IoBufferFileReference FileRef; OwnedBuffer.GetFileReference(/* out */ FileRef)) { @@ -751,7 +780,8 @@ public: return m_Headers; } - void SendResponse(asio::ip::tcp::socket& TcpSocket, std::function<void(const asio::error_code& Ec, std::size_t ByteCount)>&& Token) + template<typename SocketType> + void SendResponse(SocketType& Socket, std::function<void(const asio::error_code& Ec, std::size_t ByteCount)>&& Token) { ZEN_ASSERT(m_State == State::kInitialized); @@ -761,10 +791,11 @@ public: m_SendCb = std::move(Token); m_State = State::kSending; - SendNextChunk(TcpSocket); + SendNextChunk(Socket); } - void SendNextChunk(asio::ip::tcp::socket& TcpSocket) + template<typename SocketType> + void SendNextChunk(SocketType& Socket) { ZEN_ASSERT(m_State == State::kSending); @@ -781,12 +812,12 @@ public: auto CompletionToken = [Self = this, Token = std::move(m_SendCb), TotalBytes = m_TotalBytesSent] { Token({}, TotalBytes); }; - asio::defer(TcpSocket.get_executor(), std::move(CompletionToken)); + asio::defer(Socket.get_executor(), std::move(CompletionToken)); return; } - auto OnCompletion = [this, &TcpSocket](const asio::error_code& Ec, std::size_t ByteCount) { + auto OnCompletion = [this, &Socket](const asio::error_code& Ec, std::size_t ByteCount) { ZEN_ASSERT(m_State == State::kSending); m_TotalBytesSent += ByteCount; @@ -797,7 +828,7 @@ public: } else { - SendNextChunk(TcpSocket); + SendNextChunk(Socket); } }; @@ -811,25 +842,21 @@ public: Io.Ref.FileRef.FileChunkSize); #if ZEN_USE_TRANSMITFILE - TransmitFileAsync(TcpSocket, + TransmitFileAsync(Socket, Io.Ref.FileRef.FileHandle, Io.Ref.FileRef.FileChunkOffset, gsl::narrow_cast<uint32_t>(Io.Ref.FileRef.FileChunkSize), OnCompletion); + return; #elif ZEN_USE_ASYNC_SENDFILE - SendFileAsync(TcpSocket, + SendFileAsync(Socket, Io.Ref.FileRef.FileHandle, Io.Ref.FileRef.FileChunkOffset, Io.Ref.FileRef.FileChunkSize, 64 * 1024, OnCompletion); -#else - // This should never occur unless we compile with one - // of the options above - ZEN_WARN("invalid file reference in response"); -#endif - return; +#endif } // Send as many consecutive non-file references as possible in one asio operation @@ -850,7 +877,7 @@ public: ++m_IoVecCursor; } - asio::async_write(TcpSocket, std::move(AsioBuffers), asio::transfer_all(), OnCompletion); + asio::async_write(Socket, std::move(AsioBuffers), asio::transfer_all(), OnCompletion); } private: @@ -863,12 +890,13 @@ private: kFailed }; - uint32_t m_RequestNumber = 0; - uint16_t m_ResponseCode = 0; - bool m_IsKeepAlive = true; - State m_State = State::kUninitialized; - HttpContentType m_ContentType = HttpContentType::kBinary; - uint64_t m_ContentLength = 0; + uint32_t m_RequestNumber = 0; + uint16_t m_ResponseCode = 0; + bool m_IsKeepAlive = true; + bool m_AllowZeroCopyFileSend = true; + State m_State = State::kUninitialized; + HttpContentType m_ContentType = HttpContentType::kBinary; + uint64_t m_ContentLength = 0; eastl::fixed_vector<IoBuffer, 8> m_DataBuffers; // This is here to keep the IoBuffer buffers/handles alive ExtendableStringBuilder<160> m_Headers; @@ -895,12 +923,13 @@ private: ////////////////////////////////////////////////////////////////////////// -struct HttpServerConnection : public HttpRequestParserCallbacks, std::enable_shared_from_this<HttpServerConnection> +template<typename SocketType> +struct HttpServerConnectionT : public HttpRequestParserCallbacks, std::enable_shared_from_this<HttpServerConnectionT<SocketType>> { - HttpServerConnection(HttpAsioServerImpl& Server, std::unique_ptr<asio::ip::tcp::socket>&& Socket); - ~HttpServerConnection(); + HttpServerConnectionT(HttpAsioServerImpl& Server, std::unique_ptr<SocketType>&& Socket); + ~HttpServerConnectionT(); - std::shared_ptr<HttpServerConnection> AsSharedPtr() { return shared_from_this(); } + std::shared_ptr<HttpServerConnectionT> AsSharedPtr() { return this->shared_from_this(); } // HttpConnectionBase implementation @@ -962,12 +991,13 @@ private: RwLock m_ActiveResponsesLock; std::deque<std::unique_ptr<HttpResponse>> m_ActiveResponses; - std::unique_ptr<asio::ip::tcp::socket> m_Socket; + std::unique_ptr<SocketType> m_Socket; }; std::atomic<uint32_t> g_ConnectionIdCounter{0}; -HttpServerConnection::HttpServerConnection(HttpAsioServerImpl& Server, std::unique_ptr<asio::ip::tcp::socket>&& Socket) +template<typename SocketType> +HttpServerConnectionT<SocketType>::HttpServerConnectionT(HttpAsioServerImpl& Server, std::unique_ptr<SocketType>&& Socket) : m_Server(Server) , m_ConnectionId(g_ConnectionIdCounter.fetch_add(1)) , m_Socket(std::move(Socket)) @@ -975,21 +1005,24 @@ HttpServerConnection::HttpServerConnection(HttpAsioServerImpl& Server, std::uniq ZEN_TRACE_VERBOSE("new connection #{}", m_ConnectionId); } -HttpServerConnection::~HttpServerConnection() +template<typename SocketType> +HttpServerConnectionT<SocketType>::~HttpServerConnectionT() { RwLock::ExclusiveLockScope _(m_ActiveResponsesLock); ZEN_TRACE_VERBOSE("destroying connection #{}", m_ConnectionId); } +template<typename SocketType> void -HttpServerConnection::HandleNewRequest() +HttpServerConnectionT<SocketType>::HandleNewRequest() { EnqueueRead(); } +template<typename SocketType> void -HttpServerConnection::TerminateConnection() +HttpServerConnectionT<SocketType>::TerminateConnection() { if (m_RequestState == RequestState::kDone || m_RequestState == RequestState::kTerminated) { @@ -1001,12 +1034,13 @@ HttpServerConnection::TerminateConnection() // Terminating, we don't care about any errors when closing socket std::error_code Ec; - m_Socket->shutdown(asio::socket_base::shutdown_both, Ec); - m_Socket->close(Ec); + SocketTraits<SocketType>::ShutdownBoth(*m_Socket, Ec); + SocketTraits<SocketType>::Close(*m_Socket, Ec); } +template<typename SocketType> void -HttpServerConnection::EnqueueRead() +HttpServerConnectionT<SocketType>::EnqueueRead() { ZEN_MEMSCOPE(GetHttpasioTag()); @@ -1027,8 +1061,9 @@ HttpServerConnection::EnqueueRead() [Conn = AsSharedPtr()](const asio::error_code& Ec, std::size_t ByteCount) { Conn->OnDataReceived(Ec, ByteCount); }); } +template<typename SocketType> void -HttpServerConnection::OnDataReceived(const asio::error_code& Ec, [[maybe_unused]] std::size_t ByteCount) +HttpServerConnectionT<SocketType>::OnDataReceived(const asio::error_code& Ec, [[maybe_unused]] std::size_t ByteCount) { ZEN_MEMSCOPE(GetHttpasioTag()); @@ -1086,11 +1121,12 @@ HttpServerConnection::OnDataReceived(const asio::error_code& Ec, [[maybe_unused] } } +template<typename SocketType> void -HttpServerConnection::OnResponseDataSent(const asio::error_code& Ec, - [[maybe_unused]] std::size_t ByteCount, - [[maybe_unused]] uint32_t RequestNumber, - HttpResponse* ResponseToPop) +HttpServerConnectionT<SocketType>::OnResponseDataSent(const asio::error_code& Ec, + [[maybe_unused]] std::size_t ByteCount, + [[maybe_unused]] uint32_t RequestNumber, + HttpResponse* ResponseToPop) { ZEN_MEMSCOPE(GetHttpasioTag()); @@ -1144,8 +1180,9 @@ HttpServerConnection::OnResponseDataSent(const asio::error_code& Ec, } } +template<typename SocketType> void -HttpServerConnection::CloseConnection() +HttpServerConnectionT<SocketType>::CloseConnection() { ZEN_MEMSCOPE(GetHttpasioTag()); @@ -1157,23 +1194,24 @@ HttpServerConnection::CloseConnection() m_RequestState = RequestState::kDone; std::error_code Ec; - m_Socket->shutdown(asio::socket_base::shutdown_receive, Ec); + SocketTraits<SocketType>::ShutdownReceive(*m_Socket, Ec); if (Ec) { ZEN_WARN("socket shutdown ERROR, reason '{}'", Ec.message()); } - m_Socket->close(Ec); + SocketTraits<SocketType>::Close(*m_Socket, Ec); if (Ec) { ZEN_WARN("socket close ERROR, reason '{}'", Ec.message()); } } +template<typename SocketType> void -HttpServerConnection::SendInlineResponse(uint32_t RequestNumber, - std::string_view StatusLine, - std::string_view Headers, - std::string_view Body) +HttpServerConnectionT<SocketType>::SendInlineResponse(uint32_t RequestNumber, + std::string_view StatusLine, + std::string_view Headers, + std::string_view Body) { ExtendableStringBuilder<256> ResponseBuilder; ResponseBuilder << "HTTP/1.1 " << StatusLine << "\r\n"; @@ -1194,15 +1232,16 @@ HttpServerConnection::SendInlineResponse(uint32_t RequestNumber, IoBuffer ResponseData(IoBuffer::Clone, ResponseView.data(), ResponseView.size()); auto Buffer = asio::buffer(ResponseData.GetData(), ResponseData.GetSize()); asio::async_write( - *m_Socket.get(), + *m_Socket, Buffer, [Conn = AsSharedPtr(), RequestNumber, Response = std::move(ResponseData)](const asio::error_code& Ec, std::size_t ByteCount) { Conn->OnResponseDataSent(Ec, ByteCount, RequestNumber, /* ResponseToPop */ nullptr); }); } +template<typename SocketType> void -HttpServerConnection::HandleRequest() +HttpServerConnectionT<SocketType>::HandleRequest() { ZEN_MEMSCOPE(GetHttpasioTag()); @@ -1229,24 +1268,25 @@ HttpServerConnection::HandleRequest() ResponseStr->append("\r\n\r\n"); // Send the 101 response on the current socket, then hand the socket off - // to a WsAsioConnection for the WebSocket protocol. - asio::async_write(*m_Socket, - asio::buffer(ResponseStr->data(), ResponseStr->size()), - [Conn = AsSharedPtr(), WsHandler, OwnedResponse = ResponseStr](const asio::error_code& Ec, std::size_t) { - if (Ec) - { - ZEN_WARN("WebSocket 101 send failed: {}", Ec.message()); - return; - } - - Conn->m_Server.m_HttpServer->OnWebSocketConnectionOpened(); - Ref<WsAsioConnection> WsConn( - new WsAsioConnection(std::move(Conn->m_Socket), *WsHandler, Conn->m_Server.m_HttpServer)); - Ref<WebSocketConnection> WsConnRef(WsConn.Get()); - - WsHandler->OnWebSocketOpen(std::move(WsConnRef)); - WsConn->Start(); - }); + // to a WsAsioConnectionT for the WebSocket protocol. + asio::async_write( + *m_Socket, + asio::buffer(ResponseStr->data(), ResponseStr->size()), + [Conn = AsSharedPtr(), WsHandler, OwnedResponse = ResponseStr](const asio::error_code& Ec, std::size_t) { + if (Ec) + { + ZEN_WARN("WebSocket 101 send failed: {}", Ec.message()); + return; + } + + Conn->m_Server.m_HttpServer->OnWebSocketConnectionOpened(); + using WsConnType = WsAsioConnectionT<SocketType>; + Ref<WsConnType> WsConn(new WsConnType(std::move(Conn->m_Socket), *WsHandler, Conn->m_Server.m_HttpServer)); + Ref<WebSocketConnection> WsConnRef(WsConn.Get()); + + WsHandler->OnWebSocketOpen(std::move(WsConnRef)); + WsConn->Start(); + }); m_RequestState = RequestState::kDone; return; @@ -1260,7 +1300,7 @@ HttpServerConnection::HandleRequest() m_RequestState = RequestState::kWritingFinal; std::error_code Ec; - m_Socket->shutdown(asio::socket_base::shutdown_receive, Ec); + SocketTraits<SocketType>::ShutdownReceive(*m_Socket, Ec); if (Ec) { @@ -1280,15 +1320,36 @@ HttpServerConnection::HandleRequest() m_Server.m_HttpServer->MarkRequest(); - auto RemoteEndpoint = m_Socket->remote_endpoint(); - bool IsLocalConnection = m_Socket->local_endpoint().address() == RemoteEndpoint.address(); + bool IsLocalConnection = true; + std::string RemoteAddress; + + if constexpr (std::is_same_v<SocketType, asio::ip::tcp::socket>) + { + auto RemoteEndpoint = m_Socket->remote_endpoint(); + IsLocalConnection = m_Socket->local_endpoint().address() == RemoteEndpoint.address(); + RemoteAddress = RemoteEndpoint.address().to_string(); + } +#if ZEN_USE_OPENSSL + else if constexpr (std::is_same_v<SocketType, SslSocket>) + { + auto RemoteEndpoint = m_Socket->lowest_layer().remote_endpoint(); + IsLocalConnection = m_Socket->lowest_layer().local_endpoint().address() == RemoteEndpoint.address(); + RemoteAddress = RemoteEndpoint.address().to_string(); + } +#endif + else + { + RemoteAddress = "unix"; + } HttpAsioServerRequest Request(m_RequestData, *Service, m_RequestData.Body(), RequestNumber, IsLocalConnection, - RemoteEndpoint.address().to_string()); + std::move(RemoteAddress)); + + Request.m_AllowZeroCopyFileSend = !SocketTraits<SocketType>::IsSslSocket; ZEN_TRACE_VERBOSE("handle request, connection: {}, request: {}'", m_ConnectionId, RequestNumber); @@ -1439,14 +1500,23 @@ HttpServerConnection::HandleRequest() } ////////////////////////////////////////////////////////////////////////// +// Base class for TCP acceptors that handles socket setup, port binding +// with probing/retry, and dual-stack (IPv6+IPv4 loopback) support. +// Subclasses only need to implement OnAccept() to handle new connections. -struct HttpAcceptor +struct TcpAcceptorBase { - HttpAcceptor(HttpAsioServerImpl& Server, asio::io_service& IoService, uint16_t BasePort, bool ForceLoopback, bool AllowPortProbing) + TcpAcceptorBase(HttpAsioServerImpl& Server, + asio::io_service& IoService, + uint16_t BasePort, + bool ForceLoopback, + bool AllowPortProbing, + std::string_view Label) : m_Server(Server) , m_IoService(IoService) , m_Acceptor(m_IoService, asio::ip::tcp::v6()) , m_AlternateProtocolAcceptor(m_IoService, asio::ip::tcp::v4()) + , m_Label(Label) { const bool IsUsingIPv6 = IsIPv6Capable(); if (!IsUsingIPv6) @@ -1455,7 +1525,6 @@ struct HttpAcceptor } #if ZEN_PLATFORM_WINDOWS - // Special option for Windows settings as !asio::socket_base::reuse_address is not the same as exclusive access on Windows platforms typedef asio::detail::socket_option::boolean<ASIO_OS_DEF(SOL_SOCKET), SO_EXCLUSIVEADDRUSE> exclusive_address; m_Acceptor.set_option(exclusive_address(true)); m_AlternateProtocolAcceptor.set_option(exclusive_address(true)); @@ -1468,83 +1537,54 @@ struct HttpAcceptor #endif // ZEN_PLATFORM_WINDOWS m_Acceptor.set_option(asio::ip::tcp::no_delay(true)); - m_Acceptor.set_option(asio::socket_base::receive_buffer_size(128 * 1024)); - m_Acceptor.set_option(asio::socket_base::send_buffer_size(256 * 1024)); - m_AlternateProtocolAcceptor.set_option(asio::ip::tcp::no_delay(true)); - m_AlternateProtocolAcceptor.set_option(asio::socket_base::receive_buffer_size(128 * 1024)); - m_AlternateProtocolAcceptor.set_option(asio::socket_base::send_buffer_size(256 * 1024)); - - std::string BoundBaseUrl; if (IsUsingIPv6) { - BoundBaseUrl = BindAcceptor<asio::ip::address_v6>(BasePort, ForceLoopback, AllowPortProbing); + BindAcceptor<asio::ip::address_v6>(BasePort, ForceLoopback, AllowPortProbing); } else { - ZEN_INFO("NOTE: ipv6 support is disabled, binding to ipv4 only"); - - BoundBaseUrl = BindAcceptor<asio::ip::address_v4>(BasePort, ForceLoopback, AllowPortProbing); + ZEN_INFO("{}: ipv6 support is disabled, binding to ipv4 only", m_Label); + BindAcceptor<asio::ip::address_v4>(BasePort, ForceLoopback, AllowPortProbing); } + } - if (!IsValid()) - { - return; - } - -#if ZEN_PLATFORM_WINDOWS - // On Windows, loopback connections can take advantage of a faster code path optionally with this flag. - // This must be used by both the client and server side, and is only effective in the absence of - // Windows Filtering Platform (WFP) callouts which can be installed by security software. - // https://docs.microsoft.com/en-us/windows/win32/winsock/sio-loopback-fast-path - SOCKET NativeSocket = m_Acceptor.native_handle(); - int LoopbackOptionValue = 1; - DWORD OptionNumberOfBytesReturned = 0; - WSAIoctl(NativeSocket, - SIO_LOOPBACK_FAST_PATH, - &LoopbackOptionValue, - sizeof(LoopbackOptionValue), - NULL, - 0, - &OptionNumberOfBytesReturned, - 0, - 0); - - if (m_UseAlternateProtocolAcceptor) - { - NativeSocket = m_AlternateProtocolAcceptor.native_handle(); - WSAIoctl(NativeSocket, - SIO_LOOPBACK_FAST_PATH, - &LoopbackOptionValue, - sizeof(LoopbackOptionValue), - NULL, - 0, - &OptionNumberOfBytesReturned, - 0, - 0); - } -#endif - m_Acceptor.listen(); + virtual ~TcpAcceptorBase() + { + m_Acceptor.close(); if (m_UseAlternateProtocolAcceptor) { - m_AlternateProtocolAcceptor.listen(); + m_AlternateProtocolAcceptor.close(); } - - ZEN_INFO("Started asio server at '{}", BoundBaseUrl); } - ~HttpAcceptor() + void Start() { - m_Acceptor.close(); + ZEN_ASSERT(!m_IsStopped); + InitAcceptLoop(m_Acceptor); if (m_UseAlternateProtocolAcceptor) { - m_AlternateProtocolAcceptor.close(); + InitAcceptLoop(m_AlternateProtocolAcceptor); } } + void StopAccepting() { m_IsStopped = true; } + + uint16_t GetPort() const { return m_Acceptor.local_endpoint().port(); } + bool IsLoopbackOnly() const { return m_Acceptor.local_endpoint().address().is_loopback(); } + bool IsValid() const { return m_IsValid; } + +protected: + /// Called for each accepted TCP socket. Subclasses create the appropriate connection type. + virtual void OnAccept(std::unique_ptr<asio::ip::tcp::socket> Socket) = 0; + + HttpAsioServerImpl& m_Server; + asio::io_service& m_IoService; + +private: template<typename AddressType> - std::string BindAcceptor(uint16_t BasePort, bool ForceLoopback, bool AllowPortProbing) + void BindAcceptor(uint16_t BasePort, bool ForceLoopback, bool AllowPortProbing) { uint16_t EffectivePort = BasePort; @@ -1571,7 +1611,7 @@ struct HttpAcceptor if (BindErrorCode == asio::error::access_denied && !BindAddress.is_loopback()) { - ZEN_INFO("Access denied for public port {}, falling back to loopback", BasePort); + ZEN_INFO("{}: Access denied for public port {}, falling back to loopback", m_Label, BasePort); BindAddress = AddressType::loopback(); @@ -1585,7 +1625,7 @@ struct HttpAcceptor if (BindErrorCode == asio::error::address_in_use) { - ZEN_INFO("Desired port {} is in use (bind returned '{}'), retrying", EffectivePort, BindErrorCode.message()); + ZEN_INFO("{}: Desired port {} is in use (bind returned '{}'), retrying", m_Label, EffectivePort, BindErrorCode.message()); Sleep(500); m_Acceptor.bind(asio::ip::tcp::endpoint(BindAddress, EffectivePort), BindErrorCode); } @@ -1601,7 +1641,8 @@ struct HttpAcceptor if (BindErrorCode) { - ZEN_INFO("Unable to bind to preferred port range, falling back to automatic assignment (bind returned '{}')", + ZEN_INFO("{}: Unable to bind to preferred port range, falling back to automatic assignment (bind returned '{}')", + m_Label, BindErrorCode.message()); EffectivePort = 0; @@ -1617,7 +1658,7 @@ struct HttpAcceptor { for (uint32_t Retries = 0; (BindErrorCode == asio::error::address_in_use) && (Retries < 3); Retries++) { - ZEN_INFO("Desired port {} is in use (bind returned '{}'), retrying", EffectivePort, BindErrorCode.message()); + ZEN_INFO("{}: Desired port {} is in use (bind returned '{}'), retrying", m_Label, EffectivePort, BindErrorCode.message()); Sleep(500); m_Acceptor.bind(asio::ip::tcp::endpoint(BindAddress, EffectivePort), BindErrorCode); } @@ -1625,14 +1666,13 @@ struct HttpAcceptor if (BindErrorCode) { - ZEN_WARN("Unable to initialize asio service, (bind returned '{}')", BindErrorCode.message()); - - return {}; + ZEN_WARN("{}: Unable to bind on port {} (bind returned '{}')", m_Label, BasePort, BindErrorCode.message()); + return; } if (EffectivePort != BasePort) { - ZEN_WARN("Desired port {} is in use, remapped to port {}", BasePort, EffectivePort); + ZEN_WARN("{}: Desired port {} is in use, remapped to port {}", m_Label, BasePort, EffectivePort); } if constexpr (std::is_same_v<asio::ip::address_v6, AddressType>) @@ -1642,55 +1682,64 @@ struct HttpAcceptor // IPv6 loopback will only respond on the IPv6 loopback address. Not everyone does // IPv6 though so we also bind to IPv4 loopback (localhost/127.0.0.1) - m_AlternateProtocolAcceptor.bind(asio::ip::tcp::endpoint(asio::ip::address_v4::loopback(), EffectivePort), BindErrorCode); + asio::error_code AltEc; + m_AlternateProtocolAcceptor.bind(asio::ip::tcp::endpoint(asio::ip::address_v4::loopback(), EffectivePort), AltEc); - if (BindErrorCode) + if (AltEc) { - ZEN_WARN("Failed to register secondary IPv4 local-only handler 'http://{}:{}/'", "localhost", EffectivePort); + ZEN_WARN("{}: Failed to register secondary IPv4 local-only handler on port {}", m_Label, EffectivePort); } else { m_UseAlternateProtocolAcceptor = true; - ZEN_INFO("Registered local-only handler 'http://{}:{}/' - this is not accessible from remote hosts", - "localhost", - EffectivePort); } } } - m_IsValid = true; +#if ZEN_PLATFORM_WINDOWS + // On Windows, loopback connections can take advantage of a faster code path optionally with this flag. + // This must be used by both the client and server side, and is only effective in the absence of + // Windows Filtering Platform (WFP) callouts which can be installed by security software. + // https://docs.microsoft.com/en-us/windows/win32/winsock/sio-loopback-fast-path + SOCKET NativeSocket = m_Acceptor.native_handle(); + int LoopbackOptionValue = 1; + DWORD OptionNumberOfBytesReturned = 0; + WSAIoctl(NativeSocket, + SIO_LOOPBACK_FAST_PATH, + &LoopbackOptionValue, + sizeof(LoopbackOptionValue), + NULL, + 0, + &OptionNumberOfBytesReturned, + 0, + 0); - if constexpr (std::is_same_v<asio::ip::address_v6, AddressType>) - { - return fmt::format("http://{}:{}'", BindAddress.is_loopback() ? "[::1]" : "*", EffectivePort); - } - else + if (m_UseAlternateProtocolAcceptor) { - return fmt::format("http://{}:{}'", BindAddress.is_loopback() ? "127.0.0.1" : "*", EffectivePort); + NativeSocket = m_AlternateProtocolAcceptor.native_handle(); + WSAIoctl(NativeSocket, + SIO_LOOPBACK_FAST_PATH, + &LoopbackOptionValue, + sizeof(LoopbackOptionValue), + NULL, + 0, + &OptionNumberOfBytesReturned, + 0, + 0); } - } - - void Start() - { - ZEN_MEMSCOPE(GetHttpasioTag()); +#endif - ZEN_ASSERT(!m_IsStopped); - InitAcceptInternal(m_Acceptor); + m_Acceptor.listen(); if (m_UseAlternateProtocolAcceptor) { - InitAcceptInternal(m_AlternateProtocolAcceptor); + m_AlternateProtocolAcceptor.listen(); } - } - void StopAccepting() { m_IsStopped = true; } - - int GetAcceptPort() const { return m_Acceptor.local_endpoint().port(); } - bool IsLoopbackOnly() const { return m_Acceptor.local_endpoint().address().is_loopback(); } - - bool IsValid() const { return m_IsValid; } + m_IsValid = true; + ZEN_INFO("{}: Listening on port {}", m_Label, m_Acceptor.local_endpoint().port()); + } -private: - void InitAcceptInternal(asio::ip::tcp::acceptor& Acceptor) + void InitAcceptLoop(asio::ip::tcp::acceptor& Acceptor) { auto SocketPtr = std::make_unique<asio::ip::tcp::socket>(m_IoService); asio::ip::tcp::socket& SocketRef = *SocketPtr.get(); @@ -1698,29 +1747,19 @@ private: Acceptor.async_accept(SocketRef, [this, &Acceptor, Socket = std::move(SocketPtr)](const asio::error_code& Ec) mutable { if (Ec) { - ZEN_WARN("asio async_accept, connection failed to '{}:{}' reason '{}'", - Acceptor.local_endpoint().address().to_string(), - Acceptor.local_endpoint().port(), - Ec.message()); + if (!m_IsStopped.load()) + { + ZEN_WARN("{}: async_accept failed: '{}'", m_Label, Ec.message()); + } } else { - // New connection established, pass socket ownership into connection object - // and initiate request handling loop. The connection lifetime is - // managed by the async read/write loop by passing the shared - // reference to the callbacks. - - Socket->set_option(asio::ip::tcp::no_delay(true)); - Socket->set_option(asio::socket_base::receive_buffer_size(128 * 1024)); - Socket->set_option(asio::socket_base::send_buffer_size(256 * 1024)); - - auto Conn = std::make_shared<HttpServerConnection>(m_Server, std::move(Socket)); - Conn->HandleNewRequest(); + OnAccept(std::move(Socket)); } if (!m_IsStopped.load()) { - InitAcceptInternal(Acceptor); + InitAcceptLoop(Acceptor); } else { @@ -1728,21 +1767,204 @@ private: Acceptor.close(CloseEc); if (CloseEc) { - ZEN_WARN("acceptor close ERROR, reason '{}'", CloseEc.message()); + ZEN_WARN("{}: acceptor close error: '{}'", m_Label, CloseEc.message()); } } }); } - HttpAsioServerImpl& m_Server; - asio::io_service& m_IoService; asio::ip::tcp::acceptor m_Acceptor; asio::ip::tcp::acceptor m_AlternateProtocolAcceptor; bool m_UseAlternateProtocolAcceptor{false}; bool m_IsValid{false}; std::atomic<bool> m_IsStopped{false}; + std::string_view m_Label; +}; + +////////////////////////////////////////////////////////////////////////// + +struct HttpAcceptor final : TcpAcceptorBase +{ + HttpAcceptor(HttpAsioServerImpl& Server, asio::io_service& IoService, uint16_t BasePort, bool ForceLoopback, bool AllowPortProbing) + : TcpAcceptorBase(Server, IoService, BasePort, ForceLoopback, AllowPortProbing, "HTTP") + { + } + + int GetAcceptPort() const { return GetPort(); } + +protected: + void OnAccept(std::unique_ptr<asio::ip::tcp::socket> Socket) override + { + Socket->set_option(asio::ip::tcp::no_delay(true)); + Socket->set_option(asio::socket_base::receive_buffer_size(128 * 1024)); + Socket->set_option(asio::socket_base::send_buffer_size(256 * 1024)); + + auto Conn = std::make_shared<HttpServerConnection>(m_Server, std::move(Socket)); + Conn->HandleNewRequest(); + } }; +#if defined(ASIO_HAS_LOCAL_SOCKETS) + +////////////////////////////////////////////////////////////////////////// + +struct UnixAcceptor +{ + UnixAcceptor(HttpAsioServerImpl& Server, asio::io_service& IoService, const std::string& SocketPath) + : m_Server(Server) + , m_IoService(IoService) + , m_Acceptor(m_IoService) + , m_SocketPath(SocketPath) + { + // Remove any stale socket file from a previous run + std::filesystem::remove(m_SocketPath); + + asio::local::stream_protocol::endpoint Endpoint(m_SocketPath); + + asio::error_code Ec; + m_Acceptor.open(Endpoint.protocol(), Ec); + if (Ec) + { + ZEN_WARN("failed to open unix domain socket: {}", Ec.message()); + return; + } + + m_Acceptor.bind(Endpoint, Ec); + if (Ec) + { + ZEN_WARN("failed to bind unix domain socket at '{}': {}", m_SocketPath, Ec.message()); + return; + } + + m_Acceptor.listen(asio::socket_base::max_listen_connections, Ec); + if (Ec) + { + ZEN_WARN("failed to listen on unix domain socket at '{}': {}", m_SocketPath, Ec.message()); + return; + } + + m_IsValid = true; + ZEN_INFO("Started unix domain socket listener at '{}'", m_SocketPath); + } + + ~UnixAcceptor() + { + asio::error_code Ec; + m_Acceptor.close(Ec); + std::filesystem::remove(m_SocketPath); + } + + void Start() + { + ZEN_ASSERT(!m_IsStopped); + InitAccept(); + } + + void StopAccepting() { m_IsStopped = true; } + + bool IsValid() const { return m_IsValid; } + +private: + void InitAccept() + { + auto SocketPtr = std::make_unique<asio::local::stream_protocol::socket>(m_IoService); + asio::local::stream_protocol::socket& SocketRef = *SocketPtr.get(); + + m_Acceptor.async_accept(SocketRef, [this, Socket = std::move(SocketPtr)](const asio::error_code& Ec) mutable { + if (Ec) + { + if (!m_IsStopped.load()) + { + ZEN_WARN("unix domain socket async_accept failed: '{}'", Ec.message()); + } + } + else + { + auto Conn = std::make_shared<UnixServerConnection>(m_Server, std::move(Socket)); + Conn->HandleNewRequest(); + } + + if (!m_IsStopped.load()) + { + InitAccept(); + } + else + { + std::error_code CloseEc; + m_Acceptor.close(CloseEc); + } + }); + } + + HttpAsioServerImpl& m_Server; + asio::io_service& m_IoService; + asio::local::stream_protocol::acceptor m_Acceptor; + std::string m_SocketPath; + bool m_IsValid{false}; + std::atomic<bool> m_IsStopped{false}; +}; + +#endif // ASIO_HAS_LOCAL_SOCKETS + +#if ZEN_USE_OPENSSL + +////////////////////////////////////////////////////////////////////////// + +struct HttpsAcceptor final : TcpAcceptorBase +{ + HttpsAcceptor(HttpAsioServerImpl& Server, + asio::io_service& IoService, + asio::ssl::context& SslContext, + uint16_t Port, + bool ForceLoopback, + bool AllowPortProbing) + : TcpAcceptorBase(Server, IoService, Port, ForceLoopback, AllowPortProbing, "HTTPS") + , m_SslContext(SslContext) + { + } + +protected: + void OnAccept(std::unique_ptr<asio::ip::tcp::socket> Socket) override + { + Socket->set_option(asio::ip::tcp::no_delay(true)); + Socket->set_option(asio::socket_base::receive_buffer_size(128 * 1024)); + Socket->set_option(asio::socket_base::send_buffer_size(256 * 1024)); + + // Wrap accepted TCP socket in an SSL stream and perform the handshake + auto SslSocketPtr = std::make_unique<SslSocket>(std::move(*Socket), m_SslContext); + + SslSocket& SslRef = *SslSocketPtr; + SslRef.async_handshake(asio::ssl::stream_base::server, + [this, SslSocket = std::move(SslSocketPtr)](const asio::error_code& HandshakeEc) mutable { + if (HandshakeEc) + { + ZEN_WARN("SSL handshake failed: '{}'", HandshakeEc.message()); + std::error_code Ec; + SslSocket->lowest_layer().close(Ec); + return; + } + + auto Conn = std::make_shared<HttpsSslServerConnection>(m_Server, std::move(SslSocket)); + Conn->HandleNewRequest(); + }); + } + +private: + asio::ssl::context& m_SslContext; +}; + +#endif // ZEN_USE_OPENSSL + +int +HttpAsioServerImpl::GetEffectiveHttpsPort() const +{ +#if ZEN_USE_OPENSSL + return m_HttpsAcceptor ? m_HttpsAcceptor->GetPort() : 0; +#else + return 0; +#endif +} + ////////////////////////////////////////////////////////////////////////// HttpAsioServerRequest::HttpAsioServerRequest(HttpRequestParser& Request, @@ -1860,6 +2082,7 @@ HttpAsioServerRequest::WriteResponse(HttpResponseCode ResponseCode) ZEN_ASSERT(!m_Response); m_Response.reset(new HttpResponse(HttpContentType::kBinary, m_RequestNumber)); + m_Response->SetAllowZeroCopyFileSend(m_AllowZeroCopyFileSend); std::array<IoBuffer, 0> Empty; m_Response->InitializeForPayload((uint16_t)ResponseCode, Empty); @@ -1873,6 +2096,7 @@ HttpAsioServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentT ZEN_ASSERT(!m_Response); m_Response.reset(new HttpResponse(ContentType, m_RequestNumber)); + m_Response->SetAllowZeroCopyFileSend(m_AllowZeroCopyFileSend); m_Response->InitializeForPayload((uint16_t)ResponseCode, Blobs); } @@ -1883,6 +2107,7 @@ HttpAsioServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentT ZEN_ASSERT(!m_Response); m_Response.reset(new HttpResponse(ContentType, m_RequestNumber)); + m_Response->SetAllowZeroCopyFileSend(m_AllowZeroCopyFileSend); IoBuffer MessageBuffer(IoBuffer::Wrap, ResponseString.data(), ResponseString.size()); std::array<IoBuffer, 1> SingleBufferList({MessageBuffer}); @@ -1942,6 +2167,51 @@ HttpAsioServerImpl::Start(uint16_t Port, const AsioConfig& Config) m_Acceptor->Start(); +#if defined(ASIO_HAS_LOCAL_SOCKETS) + if (!Config.UnixSocketPath.empty()) + { + m_UnixAcceptor.reset(new asio_http::UnixAcceptor(*this, m_IoService, Config.UnixSocketPath)); + + if (m_UnixAcceptor->IsValid()) + { + m_UnixAcceptor->Start(); + } + else + { + m_UnixAcceptor.reset(); + } + } +#endif + +#if ZEN_USE_OPENSSL + if (!Config.CertFile.empty() && !Config.KeyFile.empty()) + { + m_SslContext = std::make_unique<asio::ssl::context>(asio::ssl::context::tlsv12_server); + m_SslContext->set_options(asio::ssl::context::default_workarounds | asio::ssl::context::no_sslv2 | asio::ssl::context::no_sslv3 | + asio::ssl::context::no_tlsv1 | asio::ssl::context::no_tlsv1_1); + m_SslContext->use_certificate_chain_file(Config.CertFile); + m_SslContext->use_private_key_file(Config.KeyFile, asio::ssl::context::pem); + + ZEN_INFO("SSL context initialized (cert: '{}', key: '{}')", Config.CertFile, Config.KeyFile); + + m_HttpsAcceptor.reset(new asio_http::HttpsAcceptor(*this, + m_IoService, + *m_SslContext, + gsl::narrow<uint16_t>(Config.HttpsPort), + Config.ForceLoopback, + /*AllowPortProbing*/ !Config.IsDedicatedServer)); + + if (m_HttpsAcceptor->IsValid()) + { + m_HttpsAcceptor->Start(); + } + else + { + m_HttpsAcceptor.reset(); + } + } +#endif + // This should consist of a set of minimum threads and grow on demand to // meet concurrency needs? Right now we end up allocating a large number // of threads even if we never end up using all of them, which seems @@ -1990,6 +2260,18 @@ HttpAsioServerImpl::Stop() { m_Acceptor->StopAccepting(); } +#if defined(ASIO_HAS_LOCAL_SOCKETS) + if (m_UnixAcceptor) + { + m_UnixAcceptor->StopAccepting(); + } +#endif +#if ZEN_USE_OPENSSL + if (m_HttpsAcceptor) + { + m_HttpsAcceptor->StopAccepting(); + } +#endif m_IoService.stop(); for (auto& Thread : m_ThreadPool) { @@ -1999,7 +2281,23 @@ HttpAsioServerImpl::Stop() } } m_ThreadPool.clear(); + + // Drain remaining handlers (e.g. cancellation callbacks from active WebSocket + // connections) so that their captured Ref<> pointers are released while the + // io_service and its epoll reactor are still alive. Without this, sockets + // held by external code (e.g. IWebSocketHandler connection lists) can outlive + // the reactor and crash during deregistration. + m_IoService.restart(); + m_IoService.poll(); + m_Acceptor.reset(); +#if defined(ASIO_HAS_LOCAL_SOCKETS) + m_UnixAcceptor.reset(); +#endif +#if ZEN_USE_OPENSSL + m_HttpsAcceptor.reset(); + m_SslContext.reset(); +#endif } void @@ -2166,6 +2464,13 @@ HttpAsioServer::OnInitialize(int BasePort, std::filesystem::path DataDir) m_BasePort = m_Impl->Start(gsl::narrow<uint16_t>(BasePort), Config); +#if ZEN_USE_OPENSSL + if (int EffectiveHttpsPort = m_Impl->GetEffectiveHttpsPort(); EffectiveHttpsPort > 0) + { + SetEffectiveHttpsPort(EffectiveHttpsPort); + } +#endif + return m_BasePort; } diff --git a/src/zenhttp/servers/httpasio.h b/src/zenhttp/servers/httpasio.h index 3ec1141a7..5adf4d5e8 100644 --- a/src/zenhttp/servers/httpasio.h +++ b/src/zenhttp/servers/httpasio.h @@ -11,6 +11,12 @@ struct AsioConfig unsigned int ThreadCount = 0; bool ForceLoopback = false; bool IsDedicatedServer = false; + std::string UnixSocketPath; +#if ZEN_USE_OPENSSL + int HttpsPort = 0; // 0 = auto-assign; set CertFile/KeyFile to enable HTTPS + std::string CertFile; // PEM certificate chain file (empty = HTTPS disabled) + std::string KeyFile; // PEM private key file +#endif }; Ref<HttpServer> CreateHttpAsioServer(const AsioConfig& Config); diff --git a/src/zenhttp/servers/httpsys.cpp b/src/zenhttp/servers/httpsys.cpp index dfe6bb6aa..83b98013e 100644 --- a/src/zenhttp/servers/httpsys.cpp +++ b/src/zenhttp/servers/httpsys.cpp @@ -116,6 +116,12 @@ public: private: int InitializeServer(int BasePort); + bool CreateSessionAndUrlGroup(); + bool RegisterLocalUrls(std::u8string_view Scheme, int Port, std::vector<std::wstring>& OutUris); + int RegisterHttpUrls(int BasePort); + bool RegisterHttpsUrls(); + bool CreateRequestQueue(int EffectivePort); + bool SetupIoCompletionPort(); void Cleanup(); void StartServer(); @@ -125,6 +131,9 @@ private: void RegisterService(const char* Endpoint, HttpService& Service); void UnregisterService(const char* Endpoint, HttpService& Service); + bool BindSslCertificate(int Port); + void UnbindSslCertificate(); + private: LoggerRef m_Log; LoggerRef m_RequestLog; @@ -140,7 +149,10 @@ private: RwLock m_AsyncWorkPoolInitLock; std::atomic<WorkerThreadPool*> m_AsyncWorkPool = nullptr; - std::vector<std::wstring> m_BaseUris; // eg: http://*:nnnn/ + std::vector<std::wstring> m_BaseUris; // eg: http://*:nnnn/ + std::vector<std::wstring> m_HttpsBaseUris; // eg: https://*:nnnn/ + bool m_DidAutoBindCert = false; + int m_HttpsPort = 0; HTTP_SERVER_SESSION_ID m_HttpSessionId = 0; HTTP_URL_GROUP_ID m_HttpUrlGroupId = 0; HANDLE m_RequestQueueHandle = 0; @@ -1082,39 +1094,63 @@ HttpSysServer::OnClose() } } -int -HttpSysServer::InitializeServer(int BasePort) +bool +HttpSysServer::CreateSessionAndUrlGroup() { - ZEN_MEMSCOPE(GetHttpsysTag()); - - using namespace std::literals; - - WideStringBuilder<64> WildcardUrlPath; - WildcardUrlPath << u8"http://*:"sv << int64_t(BasePort) << u8"/"sv; - - m_IsOk = false; - ULONG Result = HttpCreateServerSession(HTTPAPI_VERSION_2, &m_HttpSessionId, 0); if (Result != NO_ERROR) { - ZEN_ERROR("Failed to create server session for '{}': {} ({:#x})", - WideToUtf8(WildcardUrlPath), - GetSystemErrorAsString(Result), - Result); + ZEN_ERROR("Failed to create server session: {} ({:#x})", GetSystemErrorAsString(Result), Result); - return 0; + return false; } Result = HttpCreateUrlGroup(m_HttpSessionId, &m_HttpUrlGroupId, 0); if (Result != NO_ERROR) { - ZEN_ERROR("Failed to create URL group for '{}': {} ({:#x})", WideToUtf8(WildcardUrlPath), GetSystemErrorAsString(Result), Result); + ZEN_ERROR("Failed to create URL group: {} ({:#x})", GetSystemErrorAsString(Result), Result); - return 0; + return false; } + return true; +} + +bool +HttpSysServer::RegisterLocalUrls(std::u8string_view Scheme, int Port, std::vector<std::wstring>& OutUris) +{ + using namespace std::literals; + + const std::u8string_view Hosts[] = {u8"[::1]"sv, u8"localhost"sv, u8"127.0.0.1"sv}; + + for (const std::u8string_view Host : Hosts) + { + WideStringBuilder<64> LocalUrl; + LocalUrl << Scheme << u8"://"sv << Host << u8":"sv << int64_t(Port) << u8"/"sv; + + ULONG Result = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, LocalUrl.c_str(), HTTP_URL_CONTEXT(0), 0); + + if (Result == NO_ERROR) + { + ZEN_WARN("Registered local-only handler '{}' - this is not accessible from remote hosts", WideToUtf8(LocalUrl)); + OutUris.push_back(LocalUrl.c_str()); + } + else + { + break; + } + } + + return !OutUris.empty(); +} + +int +HttpSysServer::RegisterHttpUrls(int BasePort) +{ + using namespace std::literals; + m_BaseUris.clear(); const bool AllowPortProbing = !m_InitialConfig.IsDedicatedServer; @@ -1122,6 +1158,11 @@ HttpSysServer::InitializeServer(int BasePort) int EffectivePort = BasePort; + WideStringBuilder<64> WildcardUrlPath; + WildcardUrlPath << u8"http://*:"sv << int64_t(BasePort) << u8"/"sv; + + ULONG Result; + if (m_InitialConfig.ForceLoopback) { // Force trigger of opening using local port @@ -1177,11 +1218,11 @@ HttpSysServer::InitializeServer(int BasePort) { if (AllowLocalOnly) { - // If we can't register the wildcard path, we fall back to local paths - // This local paths allow requests originating locally to function, but will not allow - // remote origin requests to function. This can be remedied by using netsh + // If we can't register the wildcard path, we fall back to local paths. + // Local paths allow requests originating locally to function, but will not allow + // remote origin requests to function. This can be remedied by using netsh // during an install process to grant permissions to route public access to the appropriate - // port for the current user. eg: + // port for the current user. eg: // netsh http add urlacl url=http://*:8558/ user=<some_user> if (!m_InitialConfig.ForceLoopback) @@ -1246,7 +1287,7 @@ HttpSysServer::InitializeServer(int BasePort) } } - if (m_BaseUris.empty()) + if (m_BaseUris.empty() && m_InitialConfig.HttpsPort == 0) { ZEN_ERROR("Failed to add base URL to URL group for '{}': {} ({:#x})", WideToUtf8(WildcardUrlPath), @@ -1256,16 +1297,104 @@ HttpSysServer::InitializeServer(int BasePort) return 0; } + return EffectivePort; +} + +bool +HttpSysServer::RegisterHttpsUrls() +{ + using namespace std::literals; + + const bool AllowLocalOnly = !m_InitialConfig.IsDedicatedServer; + const int HttpsPort = m_InitialConfig.HttpsPort; + + // If HTTPS-only mode, remove HTTP URLs and clear base URIs + if (m_InitialConfig.HttpsOnly) + { + for (const std::wstring& Uri : m_BaseUris) + { + HttpRemoveUrlFromUrlGroup(m_HttpUrlGroupId, Uri.c_str(), 0); + } + m_BaseUris.clear(); + } + + // Auto-bind certificate if thumbprint is provided + if (!m_InitialConfig.CertThumbprint.empty()) + { + if (!BindSslCertificate(HttpsPort)) + { + return false; + } + } + else + { + ZEN_INFO("HTTPS port {} configured without thumbprint - assuming pre-registered SSL certificate", HttpsPort); + } + + // Register HTTPS URLs using same pattern as HTTP + + WideStringBuilder<64> HttpsWildcard; + HttpsWildcard << u8"https://*:"sv << int64_t(HttpsPort) << u8"/"sv; + + ULONG HttpsResult = NO_ERROR; + + if (m_InitialConfig.ForceLoopback) + { + HttpsResult = ERROR_ACCESS_DENIED; + } + else + { + HttpsResult = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, HttpsWildcard.c_str(), HTTP_URL_CONTEXT(0), 0); + } + + if (HttpsResult == NO_ERROR) + { + m_HttpsBaseUris.push_back(HttpsWildcard.c_str()); + } + else if (HttpsResult == ERROR_ACCESS_DENIED && AllowLocalOnly) + { + if (!m_InitialConfig.ForceLoopback) + { + ZEN_WARN( + "Unable to register HTTPS handler using '{}' - falling back to local-only. " + "Please ensure the appropriate netsh URL reservation and SSL certificate configuration is made.", + WideToUtf8(HttpsWildcard)); + } + + RegisterLocalUrls(u8"https", HttpsPort, m_HttpsBaseUris); + } + else if (HttpsResult != NO_ERROR) + { + ZEN_ERROR("Failed to register HTTPS URL '{}': {} ({:#x})", + WideToUtf8(HttpsWildcard), + GetSystemErrorAsString(HttpsResult), + HttpsResult); + return false; + } + + if (m_HttpsBaseUris.empty()) + { + ZEN_ERROR("Failed to register any HTTPS URL for port {}", HttpsPort); + return false; + } + + m_HttpsPort = HttpsPort; + return true; +} + +bool +HttpSysServer::CreateRequestQueue(int EffectivePort) +{ HTTP_BINDING_INFO HttpBindingInfo = {{0}, 0}; WideStringBuilder<64> QueueName; QueueName << "zenserver_" << EffectivePort; - Result = HttpCreateRequestQueue(HTTPAPI_VERSION_2, - /* Name */ QueueName.c_str(), - /* SecurityAttributes */ nullptr, - /* Flags */ 0, - &m_RequestQueueHandle); + ULONG Result = HttpCreateRequestQueue(HTTPAPI_VERSION_2, + /* Name */ QueueName.c_str(), + /* SecurityAttributes */ nullptr, + /* Flags */ 0, + &m_RequestQueueHandle); if (Result != NO_ERROR) { @@ -1274,7 +1403,7 @@ HttpSysServer::InitializeServer(int BasePort) GetSystemErrorAsString(Result), Result); - return 0; + return false; } HttpBindingInfo.Flags.Present = 1; @@ -1289,7 +1418,7 @@ HttpSysServer::InitializeServer(int BasePort) GetSystemErrorAsString(Result), Result); - return 0; + return false; } // Configure rejection method. Default is to drop the connection, it's better if we @@ -1323,22 +1452,77 @@ HttpSysServer::InitializeServer(int BasePort) } } - // Create I/O completion port + return true; +} +bool +HttpSysServer::SetupIoCompletionPort() +{ std::error_code ErrorCode; m_IoThreadPool->CreateIocp(m_RequestQueueHandle, HttpSysTransaction::IoCompletionCallback, /* Context */ this, /* out */ ErrorCode); if (ErrorCode) { - ZEN_ERROR("Failed to create IOCP for '{}': {}", WideToUtf8(m_BaseUris.front()), ErrorCode.message()); + ZEN_ERROR("Failed to create IOCP: {}", ErrorCode.message()); + return false; + } + m_IsOk = true; + + if (!m_BaseUris.empty()) + { + ZEN_INFO("Started http.sys server at '{}'", WideToUtf8(m_BaseUris.front())); + } + if (!m_HttpsBaseUris.empty()) + { + ZEN_INFO("Started http.sys HTTPS server at '{}'", WideToUtf8(m_HttpsBaseUris.front())); + } + + return true; +} + +int +HttpSysServer::InitializeServer(int BasePort) +{ + ZEN_MEMSCOPE(GetHttpsysTag()); + + m_IsOk = false; + + if (!CreateSessionAndUrlGroup()) + { return 0; } - else + + int EffectivePort = RegisterHttpUrls(BasePort); + + if (m_InitialConfig.HttpsPort > 0) + { + if (!RegisterHttpsUrls()) + { + return 0; + } + } + + if (m_BaseUris.empty() && m_HttpsBaseUris.empty()) { - m_IsOk = true; + ZEN_ERROR("No HTTP or HTTPS listeners could be registered"); + return 0; + } - ZEN_INFO("Started http.sys server at '{}'", WideToUtf8(m_BaseUris.front())); + if (!CreateRequestQueue(EffectivePort)) + { + return 0; + } + + if (!SetupIoCompletionPort()) + { + return 0; + } + + // When HTTPS-only, return the HTTPS port as the effective port + if (m_InitialConfig.HttpsOnly && m_HttpsPort > 0) + { + return m_HttpsPort; } return EffectivePort; @@ -1349,6 +1533,8 @@ HttpSysServer::Cleanup() { ++m_IsShuttingDown; + UnbindSslCertificate(); + if (m_RequestQueueHandle) { HttpCloseRequestQueue(m_RequestQueueHandle); @@ -1368,6 +1554,105 @@ HttpSysServer::Cleanup() } } +// {7E3F4B2A-1C8D-4A6E-B5F0-9D2E8C7A3B1F} - Fixed GUID for zenserver SSL bindings +static constexpr GUID ZenServerSslAppId = {0x7E3F4B2A, 0x1C8D, 0x4A6E, {0xB5, 0xF0, 0x9D, 0x2E, 0x8C, 0x7A, 0x3B, 0x1F}}; + +bool +HttpSysServer::BindSslCertificate(int Port) +{ + const std::string& Thumbprint = m_InitialConfig.CertThumbprint; + if (Thumbprint.size() != 40) + { + ZEN_ERROR("SSL certificate thumbprint must be exactly 40 hex characters, got {}", Thumbprint.size()); + return false; + } + + BYTE CertHash[20] = {}; + if (!ParseHexBytes(Thumbprint, CertHash)) + { + ZEN_ERROR("SSL certificate thumbprint contains invalid hex characters"); + return false; + } + + SOCKADDR_IN Address = {}; + Address.sin_family = AF_INET; + Address.sin_port = htons(static_cast<USHORT>(Port)); + Address.sin_addr.s_addr = INADDR_ANY; + + const std::wstring StoreNameW = UTF8_to_UTF16(m_InitialConfig.CertStoreName.c_str()); + + HTTP_SERVICE_CONFIG_SSL_SET SslConfig = {}; + SslConfig.KeyDesc.pIpPort = reinterpret_cast<SOCKADDR*>(&Address); + SslConfig.ParamDesc.pSslHash = CertHash; + SslConfig.ParamDesc.SslHashLength = sizeof(CertHash); + SslConfig.ParamDesc.pSslCertStoreName = const_cast<PWSTR>(StoreNameW.c_str()); + SslConfig.ParamDesc.AppId = ZenServerSslAppId; + + ULONG Result = HttpSetServiceConfiguration(0, HttpServiceConfigSSLCertInfo, &SslConfig, sizeof(SslConfig), nullptr); + + if (Result == ERROR_ALREADY_EXISTS) + { + // Remove existing binding and retry + HTTP_SERVICE_CONFIG_SSL_SET DeleteConfig = {}; + DeleteConfig.KeyDesc.pIpPort = reinterpret_cast<SOCKADDR*>(&Address); + + HttpDeleteServiceConfiguration(0, HttpServiceConfigSSLCertInfo, &DeleteConfig, sizeof(DeleteConfig), nullptr); + + Result = HttpSetServiceConfiguration(0, HttpServiceConfigSSLCertInfo, &SslConfig, sizeof(SslConfig), nullptr); + } + + if (Result != NO_ERROR) + { + ZEN_ERROR( + "Failed to bind SSL certificate to port {}: {} ({:#x}). " + "This operation may require running as administrator.", + Port, + GetSystemErrorAsString(Result), + Result); + return false; + } + + m_DidAutoBindCert = true; + m_HttpsPort = Port; + + ZEN_INFO("SSL certificate auto-bound for 0.0.0.0:{} (thumbprint: {}..., store: {})", + Port, + Thumbprint.substr(0, 8), + m_InitialConfig.CertStoreName); + + return true; +} + +void +HttpSysServer::UnbindSslCertificate() +{ + if (!m_DidAutoBindCert) + { + return; + } + + SOCKADDR_IN Address = {}; + Address.sin_family = AF_INET; + Address.sin_port = htons(static_cast<USHORT>(m_HttpsPort)); + Address.sin_addr.s_addr = INADDR_ANY; + + HTTP_SERVICE_CONFIG_SSL_SET SslConfig = {}; + SslConfig.KeyDesc.pIpPort = reinterpret_cast<SOCKADDR*>(&Address); + + ULONG Result = HttpDeleteServiceConfiguration(0, HttpServiceConfigSSLCertInfo, &SslConfig, sizeof(SslConfig), nullptr); + + if (Result != NO_ERROR) + { + ZEN_WARN("Failed to remove SSL certificate binding from port {}: {} ({:#x})", m_HttpsPort, GetSystemErrorAsString(Result), Result); + } + else + { + ZEN_INFO("SSL certificate binding removed from port {}", m_HttpsPort); + } + + m_DidAutoBindCert = false; +} + WorkerThreadPool& HttpSysServer::WorkPool() { @@ -1495,19 +1780,23 @@ HttpSysServer::RegisterService(const char* UrlPath, HttpService& Service) // Convert to wide string - for (const std::wstring& BaseUri : m_BaseUris) - { - std::wstring Url16 = BaseUri + PathUtf16; - - ULONG Result = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, Url16.c_str(), HTTP_URL_CONTEXT(&Service), 0 /* Reserved */); - - if (Result != NO_ERROR) + auto RegisterWithBaseUris = [&](const std::vector<std::wstring>& BaseUris) { + for (const std::wstring& BaseUri : BaseUris) { - ZEN_ERROR("HttpAddUrlToUrlGroup failed with result: '{}'", GetSystemErrorAsString(Result)); + std::wstring Url16 = BaseUri + PathUtf16; - return; + ULONG Result = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, Url16.c_str(), HTTP_URL_CONTEXT(&Service), 0 /* Reserved */); + + if (Result != NO_ERROR) + { + ZEN_ERROR("HttpAddUrlToUrlGroup failed with result: '{}'", GetSystemErrorAsString(Result)); + return; + } } - } + }; + + RegisterWithBaseUris(m_BaseUris); + RegisterWithBaseUris(m_HttpsBaseUris); } void @@ -1522,19 +1811,22 @@ HttpSysServer::UnregisterService(const char* UrlPath, HttpService& Service) const std::wstring PathUtf16 = UTF8_to_UTF16(UrlPath); - // Convert to wide string - - for (const std::wstring& BaseUri : m_BaseUris) - { - std::wstring Url16 = BaseUri + PathUtf16; + auto UnregisterFromBaseUris = [&](const std::vector<std::wstring>& BaseUris) { + for (const std::wstring& BaseUri : BaseUris) + { + std::wstring Url16 = BaseUri + PathUtf16; - ULONG Result = HttpRemoveUrlFromUrlGroup(m_HttpUrlGroupId, Url16.c_str(), 0); + ULONG Result = HttpRemoveUrlFromUrlGroup(m_HttpUrlGroupId, Url16.c_str(), 0); - if (Result != NO_ERROR) - { - ZEN_ERROR("HttpRemoveUrlFromUrlGroup failed with result: '{}'", GetSystemErrorAsString(Result)); + if (Result != NO_ERROR) + { + ZEN_ERROR("HttpRemoveUrlFromUrlGroup failed with result: '{}'", GetSystemErrorAsString(Result)); + } } - } + }; + + UnregisterFromBaseUris(m_BaseUris); + UnregisterFromBaseUris(m_HttpsBaseUris); } ////////////////////////////////////////////////////////////////////////// @@ -2422,6 +2714,11 @@ HttpSysServer::OnInitialize(int BasePort, std::filesystem::path DataDir) ZEN_UNUSED(DataDir); if (int EffectivePort = InitializeServer(BasePort)) { + if (m_HttpsPort > 0) + { + SetEffectiveHttpsPort(m_HttpsPort); + } + StartServer(); return EffectivePort; diff --git a/src/zenhttp/servers/httpsys.h b/src/zenhttp/servers/httpsys.h index b2fe7475b..ca465ad00 100644 --- a/src/zenhttp/servers/httpsys.h +++ b/src/zenhttp/servers/httpsys.h @@ -22,6 +22,10 @@ struct HttpSysConfig bool IsRequestLoggingEnabled = false; bool IsDedicatedServer = false; bool ForceLoopback = false; + int HttpsPort = 0; // 0 = HTTPS disabled + std::string CertThumbprint; // Hex SHA-1 (40 chars) for auto SSL binding + std::string CertStoreName = "MY"; // Windows certificate store name + bool HttpsOnly = false; // When true, disable HTTP listener }; Ref<HttpServer> CreateHttpSysServer(HttpSysConfig Config); diff --git a/src/zenhttp/servers/wsasio.cpp b/src/zenhttp/servers/wsasio.cpp index b2543277a..5ae48f5b3 100644 --- a/src/zenhttp/servers/wsasio.cpp +++ b/src/zenhttp/servers/wsasio.cpp @@ -1,6 +1,7 @@ // Copyright Epic Games, Inc. All Rights Reserved. #include "wsasio.h" +#include "asio_socket_traits.h" #include "wsframecodec.h" #include <zencore/logging.h> @@ -17,14 +18,16 @@ WsLog() ////////////////////////////////////////////////////////////////////////// -WsAsioConnection::WsAsioConnection(std::unique_ptr<asio::ip::tcp::socket> Socket, IWebSocketHandler& Handler, HttpServer* Server) +template<typename SocketType> +WsAsioConnectionT<SocketType>::WsAsioConnectionT(std::unique_ptr<SocketType> Socket, IWebSocketHandler& Handler, HttpServer* Server) : m_Socket(std::move(Socket)) , m_Handler(Handler) , m_HttpServer(Server) { } -WsAsioConnection::~WsAsioConnection() +template<typename SocketType> +WsAsioConnectionT<SocketType>::~WsAsioConnectionT() { m_IsOpen.store(false); if (m_HttpServer) @@ -33,14 +36,16 @@ WsAsioConnection::~WsAsioConnection() } } +template<typename SocketType> void -WsAsioConnection::Start() +WsAsioConnectionT<SocketType>::Start() { EnqueueRead(); } +template<typename SocketType> bool -WsAsioConnection::IsOpen() const +WsAsioConnectionT<SocketType>::IsOpen() const { return m_IsOpen.load(std::memory_order_relaxed); } @@ -50,23 +55,25 @@ WsAsioConnection::IsOpen() const // Read loop // +template<typename SocketType> void -WsAsioConnection::EnqueueRead() +WsAsioConnectionT<SocketType>::EnqueueRead() { if (!m_IsOpen.load(std::memory_order_relaxed)) { return; } - Ref<WsAsioConnection> Self(this); + Ref<WsAsioConnectionT> Self(this); asio::async_read(*m_Socket, m_ReadBuffer, asio::transfer_at_least(1), [Self](const asio::error_code& Ec, std::size_t ByteCount) { Self->OnDataReceived(Ec, ByteCount); }); } +template<typename SocketType> void -WsAsioConnection::OnDataReceived(const asio::error_code& Ec, [[maybe_unused]] std::size_t ByteCount) +WsAsioConnectionT<SocketType>::OnDataReceived(const asio::error_code& Ec, [[maybe_unused]] std::size_t ByteCount) { if (Ec) { @@ -90,8 +97,9 @@ WsAsioConnection::OnDataReceived(const asio::error_code& Ec, [[maybe_unused]] st } } +template<typename SocketType> void -WsAsioConnection::ProcessReceivedData() +WsAsioConnectionT<SocketType>::ProcessReceivedData() { while (m_ReadBuffer.size() > 0) { @@ -162,8 +170,8 @@ WsAsioConnection::ProcessReceivedData() // Shut down the socket std::error_code ShutdownEc; - m_Socket->shutdown(asio::socket_base::shutdown_both, ShutdownEc); - m_Socket->close(ShutdownEc); + SocketTraits<SocketType>::ShutdownBoth(*m_Socket, ShutdownEc); + SocketTraits<SocketType>::Close(*m_Socket, ShutdownEc); return; } @@ -179,8 +187,9 @@ WsAsioConnection::ProcessReceivedData() // Write queue // +template<typename SocketType> void -WsAsioConnection::SendText(std::string_view Text) +WsAsioConnectionT<SocketType>::SendText(std::string_view Text) { if (!m_IsOpen.load(std::memory_order_relaxed)) { @@ -192,8 +201,9 @@ WsAsioConnection::SendText(std::string_view Text) EnqueueWrite(std::move(Frame)); } +template<typename SocketType> void -WsAsioConnection::SendBinary(std::span<const uint8_t> Data) +WsAsioConnectionT<SocketType>::SendBinary(std::span<const uint8_t> Data) { if (!m_IsOpen.load(std::memory_order_relaxed)) { @@ -204,14 +214,16 @@ WsAsioConnection::SendBinary(std::span<const uint8_t> Data) EnqueueWrite(std::move(Frame)); } +template<typename SocketType> void -WsAsioConnection::Close(uint16_t Code, std::string_view Reason) +WsAsioConnectionT<SocketType>::Close(uint16_t Code, std::string_view Reason) { DoClose(Code, Reason); } +template<typename SocketType> void -WsAsioConnection::DoClose(uint16_t Code, std::string_view Reason) +WsAsioConnectionT<SocketType>::DoClose(uint16_t Code, std::string_view Reason) { if (!m_IsOpen.exchange(false)) { @@ -227,8 +239,9 @@ WsAsioConnection::DoClose(uint16_t Code, std::string_view Reason) m_Handler.OnWebSocketClose(*this, Code, Reason); } +template<typename SocketType> void -WsAsioConnection::EnqueueWrite(std::vector<uint8_t> Frame) +WsAsioConnectionT<SocketType>::EnqueueWrite(std::vector<uint8_t> Frame) { if (m_HttpServer) { @@ -252,8 +265,9 @@ WsAsioConnection::EnqueueWrite(std::vector<uint8_t> Frame) } } +template<typename SocketType> void -WsAsioConnection::FlushWriteQueue() +WsAsioConnectionT<SocketType>::FlushWriteQueue() { std::vector<uint8_t> Frame; @@ -272,7 +286,7 @@ WsAsioConnection::FlushWriteQueue() return; } - Ref<WsAsioConnection> Self(this); + Ref<WsAsioConnectionT> Self(this); // Move Frame into a shared_ptr so we can create the buffer and capture ownership // in the same async_write call without evaluation order issues. @@ -283,8 +297,9 @@ WsAsioConnection::FlushWriteQueue() [Self, OwnedFrame](const asio::error_code& Ec, std::size_t ByteCount) { Self->OnWriteComplete(Ec, ByteCount); }); } +template<typename SocketType> void -WsAsioConnection::OnWriteComplete(const asio::error_code& Ec, [[maybe_unused]] std::size_t ByteCount) +WsAsioConnectionT<SocketType>::OnWriteComplete(const asio::error_code& Ec, [[maybe_unused]] std::size_t ByteCount) { if (Ec) { @@ -308,4 +323,17 @@ WsAsioConnection::OnWriteComplete(const asio::error_code& Ec, [[maybe_unused]] s FlushWriteQueue(); } +////////////////////////////////////////////////////////////////////////// +// Explicit template instantiations + +template class WsAsioConnectionT<asio::ip::tcp::socket>; + +#if defined(ASIO_HAS_LOCAL_SOCKETS) +template class WsAsioConnectionT<asio::local::stream_protocol::socket>; +#endif + +#if ZEN_USE_OPENSSL +template class WsAsioConnectionT<asio::ssl::stream<asio::ip::tcp::socket>>; +#endif + } // namespace zen::asio_http diff --git a/src/zenhttp/servers/wsasio.h b/src/zenhttp/servers/wsasio.h index e8bb3b1d2..64602ee46 100644 --- a/src/zenhttp/servers/wsasio.h +++ b/src/zenhttp/servers/wsasio.h @@ -8,6 +8,12 @@ ZEN_THIRD_PARTY_INCLUDES_START #include <asio.hpp> +#if defined(ASIO_HAS_LOCAL_SOCKETS) +# include <asio/local/stream_protocol.hpp> +#endif +#if ZEN_USE_OPENSSL +# include <asio/ssl.hpp> +#endif ZEN_THIRD_PARTY_INCLUDES_END #include <deque> @@ -21,22 +27,23 @@ class HttpServer; namespace zen::asio_http { /** - * WebSocket connection over an ASIO TCP socket + * WebSocket connection over an ASIO stream socket * - * Owns the TCP socket (moved from HttpServerConnection after the 101 handshake) + * Templated on SocketType to support both TCP and Unix domain sockets. + * Owns the socket (moved from HttpServerConnection after the 101 handshake) * and runs an async read/write loop to exchange WebSocket frames. * * Lifetime is managed solely through intrusive reference counting (RefCounted). - * The async read/write callbacks capture Ref<WsAsioConnection> to keep the - * connection alive for the duration of the async operation. The service layer - * also holds a Ref<WebSocketConnection>. + * The async read/write callbacks capture Ref<> to keep the connection alive + * for the duration of the async operation. The service layer also holds a + * Ref<WebSocketConnection>. */ - -class WsAsioConnection : public WebSocketConnection +template<typename SocketType> +class WsAsioConnectionT : public WebSocketConnection { public: - WsAsioConnection(std::unique_ptr<asio::ip::tcp::socket> Socket, IWebSocketHandler& Handler, HttpServer* Server); - ~WsAsioConnection() override; + WsAsioConnectionT(std::unique_ptr<SocketType> Socket, IWebSocketHandler& Handler, HttpServer* Server); + ~WsAsioConnectionT() override; /** * Start the async read loop. Must be called once after construction @@ -61,10 +68,10 @@ private: void DoClose(uint16_t Code, std::string_view Reason); - std::unique_ptr<asio::ip::tcp::socket> m_Socket; - IWebSocketHandler& m_Handler; - zen::HttpServer* m_HttpServer; - asio::streambuf m_ReadBuffer; + std::unique_ptr<SocketType> m_Socket; + IWebSocketHandler& m_Handler; + zen::HttpServer* m_HttpServer; + asio::streambuf m_ReadBuffer; RwLock m_WriteLock; std::deque<std::vector<uint8_t>> m_WriteQueue; @@ -74,4 +81,14 @@ private: std::atomic<bool> m_CloseSent{false}; }; +using WsAsioConnection = WsAsioConnectionT<asio::ip::tcp::socket>; + +#if defined(ASIO_HAS_LOCAL_SOCKETS) +using WsAsioUnixConnection = WsAsioConnectionT<asio::local::stream_protocol::socket>; +#endif + +#if ZEN_USE_OPENSSL +using WsAsioSslConnection = WsAsioConnectionT<asio::ssl::stream<asio::ip::tcp::socket>>; +#endif + } // namespace zen::asio_http diff --git a/src/zenhttp/servers/wstest.cpp b/src/zenhttp/servers/wstest.cpp index 2134e4ff1..042afd8ff 100644 --- a/src/zenhttp/servers/wstest.cpp +++ b/src/zenhttp/servers/wstest.cpp @@ -485,7 +485,7 @@ TEST_CASE("websocket.integration") Ref<HttpServer> Server = CreateHttpAsioServer(AsioConfig{}); - int Port = Server->Initialize(7575, TmpDir.Path()); + int Port = Server->Initialize(0, TmpDir.Path()); REQUIRE(Port != 0); Server->RegisterService(TestService); @@ -797,7 +797,7 @@ TEST_CASE("websocket.client") Ref<HttpServer> Server = CreateHttpAsioServer(AsioConfig{}); - int Port = Server->Initialize(7576, TmpDir.Path()); + int Port = Server->Initialize(0, TmpDir.Path()); REQUIRE(Port != 0); Server->RegisterService(TestService); @@ -913,6 +913,75 @@ TEST_CASE("websocket.client") } } +TEST_CASE("websocket.client.unixsocket") +{ + WsTestService TestService; + ScopedTemporaryDirectory TmpDir; + std::string SocketPath = (TmpDir.Path() / "ws.sock").string(); + + Ref<HttpServer> Server = CreateHttpAsioServer(AsioConfig{.UnixSocketPath = SocketPath}); + + int Port = Server->Initialize(0, TmpDir.Path()); + REQUIRE(Port != 0); + + Server->RegisterService(TestService); + + std::thread ServerThread([&]() { Server->Run(false); }); + + auto ServerGuard = MakeGuard([&]() { + Server->RequestExit(); + if (ServerThread.joinable()) + { + ServerThread.join(); + } + Server->Close(); + }); + + Sleep(100); + + SUBCASE("connect, echo, close over unix socket") + { + TestWsClientHandler Handler; + HttpWsClientSettings Settings; + Settings.UnixSocketPath = SocketPath; + + HttpWsClient Client("ws://localhost/wstest/ws", Handler, Settings); + Client.Connect(); + + // Wait for OnWsOpen + auto Deadline = std::chrono::steady_clock::now() + 5s; + while (Handler.m_OpenCount.load() == 0 && std::chrono::steady_clock::now() < Deadline) + { + Sleep(10); + } + REQUIRE_EQ(Handler.m_OpenCount.load(), 1); + CHECK(Client.IsOpen()); + + // Send text, expect echo + Client.SendText("hello over unix socket"); + + Deadline = std::chrono::steady_clock::now() + 5s; + while (Handler.m_MessageCount.load() == 0 && std::chrono::steady_clock::now() < Deadline) + { + Sleep(10); + } + CHECK_EQ(Handler.m_MessageCount.load(), 1); + CHECK_EQ(Handler.m_LastMessage, "hello over unix socket"); + + // Close + Client.Close(1000, "done"); + + Deadline = std::chrono::steady_clock::now() + 5s; + while (Handler.m_CloseCount.load() == 0 && std::chrono::steady_clock::now() < Deadline) + { + Sleep(10); + } + + Sleep(50); + CHECK_FALSE(Client.IsOpen()); + } +} + TEST_SUITE_END(); void diff --git a/src/zenhttp/xmake.lua b/src/zenhttp/xmake.lua index e8f87b668..9b461662e 100644 --- a/src/zenhttp/xmake.lua +++ b/src/zenhttp/xmake.lua @@ -12,6 +12,11 @@ target('zenhttp') add_packages("http_parser", "json11") add_options("httpsys") + if is_plat("linux", "macosx") then + add_packages("openssl3") + end + if is_plat("linux") then add_syslinks("dl") -- TODO: is libdl needed? end + diff --git a/src/zenremotestore/builds/jupiterbuildstorage.cpp b/src/zenremotestore/builds/jupiterbuildstorage.cpp index 8e16da1a9..c3f7b9e71 100644 --- a/src/zenremotestore/builds/jupiterbuildstorage.cpp +++ b/src/zenremotestore/builds/jupiterbuildstorage.cpp @@ -23,9 +23,10 @@ using namespace std::literals; namespace { [[noreturn]] void ThrowFromJupiterResult(const JupiterResult& Result, std::string_view Prefix) { - int Error = Result.ErrorCode < (int)HttpResponseCode::Continue ? Result.ErrorCode : 0; - HttpResponseCode Status = - Result.ErrorCode >= int(HttpResponseCode::Continue) ? HttpResponseCode(Result.ErrorCode) : HttpResponseCode::ImATeapot; + HttpClientErrorCode Error = Result.ErrorCode < static_cast<int>(HttpResponseCode::Continue) ? HttpClientErrorCode(Result.ErrorCode) + : HttpClientErrorCode::kOK; + HttpResponseCode Status = Result.ErrorCode >= static_cast<int>(HttpResponseCode::Continue) ? HttpResponseCode(Result.ErrorCode) + : HttpResponseCode::ImATeapot; throw HttpClientError(fmt::format("{}: {} ({})", Prefix, Result.Reason, Result.ErrorCode), Error, Status); } } // namespace diff --git a/src/zenremotestore/jupiter/jupitersession.cpp b/src/zenremotestore/jupiter/jupitersession.cpp index 52f9eb678..b5531fa60 100644 --- a/src/zenremotestore/jupiter/jupitersession.cpp +++ b/src/zenremotestore/jupiter/jupitersession.cpp @@ -68,7 +68,7 @@ namespace detail { return {.SentBytes = gsl::narrow<uint64_t>(Response.UploadedBytes), .ReceivedBytes = gsl::narrow<uint64_t>(Response.DownloadedBytes), .ElapsedSeconds = Response.ElapsedSeconds, - .ErrorCode = Response.Error.value().ErrorCode, + .ErrorCode = static_cast<int32_t>(Response.Error.value().ErrorCode), .Reason = Response.ErrorMessage(ErrorPrefix), .Success = false}; } diff --git a/src/zenremotestore/projectstore/buildsremoteprojectstore.cpp b/src/zenremotestore/projectstore/buildsremoteprojectstore.cpp index 2282a31dd..e95d9118c 100644 --- a/src/zenremotestore/projectstore/buildsremoteprojectstore.cpp +++ b/src/zenremotestore/projectstore/buildsremoteprojectstore.cpp @@ -284,9 +284,7 @@ public: } catch (const HttpClientError& Ex) { - Result.ErrorCode = Ex.GetInternalErrorCode() != 0 ? Ex.GetInternalErrorCode() - : Ex.GetHttpResponseCode() != HttpResponseCode::ImATeapot ? (int)Ex.GetHttpResponseCode() - : 0; + Result.ErrorCode = MakeErrorCode(Ex); Result.Reason = fmt::format("Failed finalizing oplog container build part to {}/{}/{}/{}/{}. Reason: '{}'", m_BuildStorageHttp.GetBaseUri(), m_Namespace, @@ -315,9 +313,7 @@ public: } catch (const HttpClientError& Ex) { - Result.ErrorCode = Ex.GetInternalErrorCode() != 0 ? Ex.GetInternalErrorCode() - : Ex.GetHttpResponseCode() != HttpResponseCode::ImATeapot ? (int)Ex.GetHttpResponseCode() - : 0; + Result.ErrorCode = MakeErrorCode(Ex); Result.Reason = fmt::format("Failed finalizing oplog container build to {}/{}/{}/{}. Reason: '{}'", m_BuildStorageHttp.GetBaseUri(), m_Namespace, @@ -591,8 +587,8 @@ public: private: static int MakeErrorCode(const HttpClientError& Ex) { - return Ex.GetInternalErrorCode() != 0 ? Ex.GetInternalErrorCode() - : Ex.GetHttpResponseCode() != HttpResponseCode::ImATeapot ? (int)Ex.GetHttpResponseCode() + return Ex.GetInternalErrorCode() != HttpClientErrorCode::kOK ? static_cast<int>(Ex.GetInternalErrorCode()) + : Ex.GetHttpResponseCode() != HttpResponseCode::ImATeapot ? static_cast<int>(Ex.GetHttpResponseCode()) : 0; } diff --git a/src/zenremotestore/projectstore/zenremoteprojectstore.cpp b/src/zenremotestore/projectstore/zenremoteprojectstore.cpp index 115d6438d..a08a07fcd 100644 --- a/src/zenremotestore/projectstore/zenremoteprojectstore.cpp +++ b/src/zenremotestore/projectstore/zenremoteprojectstore.cpp @@ -329,7 +329,7 @@ private: { if (Response.Error) { - return {.ErrorCode = Response.Error.value().ErrorCode, + return {.ErrorCode = static_cast<int32_t>(Response.Error.value().ErrorCode), .ElapsedSeconds = Response.ElapsedSeconds, .Reason = Response.ErrorMessage(""), .Text = Response.ToText()}; diff --git a/src/zenserver-test/zenserver-test.cpp b/src/zenserver-test/zenserver-test.cpp index 8d5400294..0b2bc50c0 100644 --- a/src/zenserver-test/zenserver-test.cpp +++ b/src/zenserver-test/zenserver-test.cpp @@ -24,6 +24,10 @@ # include <atomic> # include <filesystem> +ZEN_THIRD_PARTY_INCLUDES_START +# include <asio.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + # if ZEN_PLATFORM_WINDOWS # include <ppl.h> # include <process.h> @@ -109,6 +113,11 @@ main(int argc, char** argv) ServerClass = argv[++i]; } } + else if (std::string_view Arg(argv[i]); Arg.starts_with("--httpclient="sv)) + { + std::string_view Value = Arg.substr(13); + zen::SetDefaultHttpClientBackend(Value); + } else if (argv[i] == "--verbose"sv) { Verbose = true; @@ -341,6 +350,38 @@ TEST_CASE("http.package") CHECK_EQ(ResponsePackage, TestPackage); } +# if defined(ASIO_HAS_LOCAL_SOCKETS) +TEST_CASE("http.unixsocket") +{ + std::filesystem::path TestDir = TestEnv.CreateNewTestDir(); + std::filesystem::path SocketDir = TestEnv.CreateNewTestDir(); + std::string SocketPath = (SocketDir / "zen.sock").string(); + + ZenServerInstance Instance(TestEnv); + Instance.SetDataDir(TestDir); + const uint16_t PortNumber = Instance.SpawnServerAndWaitUntilReady(fmt::format("--http=asio --unix-socket {}", SocketPath)); + + // Connect via Unix socket (BaseUri still needed for Host header) + HttpClientSettings Settings; + Settings.UnixSocketPath = SocketPath; + HttpClient Http{fmt::format("http://localhost:{}", PortNumber), Settings, {}}; + + SUBCASE("GET over unix socket") + { + HttpClient::Response Res = Http.Get("/testing/hello"); + CHECK(Res.IsSuccess()); + } + + SUBCASE("POST echo over unix socket") + { + IoBuffer Body{IoBuffer::Wrap, "unix-test", 9}; + HttpClient::Response Res = Http.Post("/testing/echo", Body); + CHECK(Res.IsSuccess()); + CHECK(Res.ResponsePayload.GetView().EqualBytes(Body.GetView())); + } +} +# endif + TEST_SUITE_END(); # if 0 diff --git a/src/zenserver/compute/computeserver.cpp b/src/zenserver/compute/computeserver.cpp index c64f081b3..2ac3de599 100644 --- a/src/zenserver/compute/computeserver.cpp +++ b/src/zenserver/compute/computeserver.cpp @@ -674,7 +674,7 @@ ZenComputeServer::PostAnnounce() { ZEN_ERROR("failed to notify coordinator at '{}': HTTP error {} - {}", m_CoordinatorEndpoint, - Result.Error->ErrorCode, + static_cast<int>(Result.Error->ErrorCode), Result.Error->ErrorMessage); } else if (!IsHttpOk(Result.StatusCode)) diff --git a/src/zenserver/config/config.cpp b/src/zenserver/config/config.cpp index e36352dae..ef9c6b7b8 100644 --- a/src/zenserver/config/config.cpp +++ b/src/zenserver/config/config.cpp @@ -144,10 +144,15 @@ ZenServerConfiguratorBase::AddCommonConfigOptions(LuaConfig::Options& LuaOptions ////// network + LuaOptions.AddOption("network.httpclientbackend"sv, ServerOptions.HttpClient.Backend, "httpclient"sv); LuaOptions.AddOption("network.httpserverclass"sv, ServerOptions.HttpConfig.ServerClass, "http"sv); LuaOptions.AddOption("network.httpserverthreads"sv, ServerOptions.HttpConfig.ThreadCount, "http-threads"sv); LuaOptions.AddOption("network.port"sv, ServerOptions.BasePort, "port"sv); LuaOptions.AddOption("network.forceloopback"sv, ServerOptions.HttpConfig.ForceLoopback, "http-forceloopback"sv); + LuaOptions.AddOption("network.unixsocket"sv, ServerOptions.HttpConfig.UnixSocketPath, "unix-socket"sv); + LuaOptions.AddOption("network.https.port"sv, ServerOptions.HttpConfig.HttpsPort, "https-port"sv); + LuaOptions.AddOption("network.https.certfile"sv, ServerOptions.HttpConfig.CertFile, "cert-file"sv); + LuaOptions.AddOption("network.https.keyfile"sv, ServerOptions.HttpConfig.KeyFile, "key-file"sv); #if ZEN_WITH_HTTPSYS LuaOptions.AddOption("network.httpsys.async.workthreads"sv, @@ -159,6 +164,10 @@ ZenServerConfiguratorBase::AddCommonConfigOptions(LuaConfig::Options& LuaOptions LuaOptions.AddOption("network.httpsys.requestlogging"sv, ServerOptions.HttpConfig.HttpSys.IsRequestLoggingEnabled, "httpsys-enable-request-logging"sv); + LuaOptions.AddOption("network.httpsys.httpsport"sv, ServerOptions.HttpConfig.HttpSys.HttpsPort, "httpsys-https-port"sv); + LuaOptions.AddOption("network.httpsys.certthumbprint"sv, ServerOptions.HttpConfig.HttpSys.CertThumbprint, "httpsys-cert-thumbprint"sv); + LuaOptions.AddOption("network.httpsys.certstorename"sv, ServerOptions.HttpConfig.HttpSys.CertStoreName, "httpsys-cert-store"sv); + LuaOptions.AddOption("network.httpsys.httpsonly"sv, ServerOptions.HttpConfig.HttpSys.HttpsOnly, "httpsys-https-only"sv); #endif #if ZEN_WITH_TRACE @@ -304,6 +313,34 @@ ZenServerCmdLineOptions::AddCliOptions(cxxopts::Options& options, ZenServerConfi options.add_option("network", "", + "unix-socket", + "Unix domain socket path to listen on (in addition to TCP)", + cxxopts::value<std::string>(ServerOptions.HttpConfig.UnixSocketPath), + "<path>"); + + options.add_option("network", + "", + "https-port", + "HTTPS listen port (0 = disabled)", + cxxopts::value<int>(ServerOptions.HttpConfig.HttpsPort)->default_value("0"), + "<port>"); + + options.add_option("network", + "", + "cert-file", + "Path to PEM certificate chain file for HTTPS", + cxxopts::value<std::string>(ServerOptions.HttpConfig.CertFile), + "<path>"); + + options.add_option("network", + "", + "key-file", + "Path to PEM private key file for HTTPS", + cxxopts::value<std::string>(ServerOptions.HttpConfig.KeyFile), + "<path>"); + + options.add_option("network", + "", "security-config-path", "Path to http security configuration file", cxxopts::value<std::string>(SecurityConfigPath), @@ -330,10 +367,45 @@ ZenServerCmdLineOptions::AddCliOptions(cxxopts::Options& options, ZenServerConfi "Enables Httpsys request logging", cxxopts::value<bool>(ServerOptions.HttpConfig.HttpSys.IsRequestLoggingEnabled), "<httpsys request logging>"); + + options.add_option("httpsys", + "", + "httpsys-https-port", + "HTTPS listen port for http.sys (0 = disabled)", + cxxopts::value<int>(ServerOptions.HttpConfig.HttpSys.HttpsPort)->default_value("0"), + "<port>"); + + options.add_option("httpsys", + "", + "httpsys-cert-thumbprint", + "SHA-1 certificate thumbprint for auto SSL binding", + cxxopts::value<std::string>(ServerOptions.HttpConfig.HttpSys.CertThumbprint), + "<thumbprint>"); + + options.add_option("httpsys", + "", + "httpsys-cert-store", + "Windows certificate store name for SSL binding", + cxxopts::value<std::string>(ServerOptions.HttpConfig.HttpSys.CertStoreName)->default_value("MY"), + "<store name>"); + + options.add_option("httpsys", + "", + "httpsys-https-only", + "Disable HTTP listener when HTTPS is active", + cxxopts::value<bool>(ServerOptions.HttpConfig.HttpSys.HttpsOnly)->default_value("false"), + ""); #endif options.add_option("network", "", + "httpclient", + "Select HTTP client implementation (e.g. 'curl', 'cpr')", + cxxopts::value<std::string>(ServerOptions.HttpClient.Backend)->default_value("cpr"), + "<http client>"); + + options.add_option("network", + "", "http", "Select HTTP server implementation (asio|" #if ZEN_WITH_HTTPSYS @@ -397,6 +469,56 @@ ZenServerCmdLineOptions::ApplyOptions(cxxopts::Options& options, ZenServerConfig ServerOptions.SecurityConfigPath = MakeSafeAbsolutePath(SecurityConfigPath); LoggingOptions.ApplyOptions(ServerOptions.LoggingConfig); + +#if ZEN_WITH_HTTPSYS + // Validate HTTPS options + const auto& HttpSys = ServerOptions.HttpConfig.HttpSys; + if (HttpSys.HttpsOnly && HttpSys.HttpsPort == 0) + { + throw OptionParseException("'--httpsys-https-only' requires '--httpsys-https-port' to be set", options.help()); + } + if (!HttpSys.CertThumbprint.empty() && HttpSys.CertThumbprint.size() != 40) + { + throw OptionParseException("'--httpsys-cert-thumbprint' must be exactly 40 hex characters (SHA-1)", options.help()); + } + if (!HttpSys.CertThumbprint.empty()) + { + for (char Ch : HttpSys.CertThumbprint) + { + if (!((Ch >= '0' && Ch <= '9') || (Ch >= 'a' && Ch <= 'f') || (Ch >= 'A' && Ch <= 'F'))) + { + throw OptionParseException("'--httpsys-cert-thumbprint' contains non-hex characters", options.help()); + } + } + } + if (HttpSys.HttpsPort > 0 && HttpSys.HttpsPort == ServerOptions.BasePort && !HttpSys.HttpsOnly) + { + throw OptionParseException("'--httpsys-https-port' must differ from '--port' when both HTTP and HTTPS are active", options.help()); + } +#endif + + // Validate generic HTTPS options (used by ASIO backend) + if (ServerOptions.HttpConfig.HttpsPort > 0) + { + if (ServerOptions.HttpConfig.CertFile.empty() || ServerOptions.HttpConfig.KeyFile.empty()) + { + throw OptionParseException("'--https-port' requires both '--cert-file' and '--key-file' to be set", options.help()); + } + if (!std::filesystem::exists(ServerOptions.HttpConfig.CertFile)) + { + throw OptionParseException(fmt::format("'--cert-file' path '{}' does not exist", ServerOptions.HttpConfig.CertFile), + options.help()); + } + if (!std::filesystem::exists(ServerOptions.HttpConfig.KeyFile)) + { + throw OptionParseException(fmt::format("'--key-file' path '{}' does not exist", ServerOptions.HttpConfig.KeyFile), + options.help()); + } + if (ServerOptions.HttpConfig.HttpsPort == ServerOptions.BasePort) + { + throw OptionParseException("'--https-port' must differ from '--port'", options.help()); + } + } } ////////////////////////////////////////////////////////////////////////// diff --git a/src/zenserver/config/config.h b/src/zenserver/config/config.h index 55aee07f9..88226f810 100644 --- a/src/zenserver/config/config.h +++ b/src/zenserver/config/config.h @@ -38,8 +38,14 @@ struct ZenSentryConfig bool Debug = false; // Enable debug mode for Sentry }; +struct HttpClientConfig +{ + std::string Backend = "cpr"; // Choice of HTTP client implementation (e.g. "curl", "cpr") +}; + struct ZenServerConfig { + HttpClientConfig HttpClient; HttpServerConfig HttpConfig; ZenSentryConfig SentryConfig; ZenStatsConfig StatsConfig; diff --git a/src/zenserver/sessions/httpsessions.cpp b/src/zenserver/sessions/httpsessions.cpp index 05be3c814..6cf12bea4 100644 --- a/src/zenserver/sessions/httpsessions.cpp +++ b/src/zenserver/sessions/httpsessions.cpp @@ -258,6 +258,10 @@ HttpSessionsService::SessionRequest(HttpRouterRequest& Req) } return ServerRequest.WriteResponse(HttpResponseCode::NotFound); } + default: + { + return ServerRequest.WriteResponse(HttpResponseCode::MethodNotAllowed); + } } } diff --git a/src/zenserver/sessions/sessions.cpp b/src/zenserver/sessions/sessions.cpp index f73aa40ff..d919db6e9 100644 --- a/src/zenserver/sessions/sessions.cpp +++ b/src/zenserver/sessions/sessions.cpp @@ -64,6 +64,8 @@ SessionsService::RegisterSession(const Oid& SessionId, std::string AppName, cons return false; } + ZEN_INFO("Session {} registered (AppName: {}, JobId: {})", SessionId, AppName, JobId); + const DateTime Now = DateTime::Now(); m_Sessions.emplace(SessionId, Ref(new Session(SessionInfo{.Id = SessionId, @@ -72,8 +74,6 @@ SessionsService::RegisterSession(const Oid& SessionId, std::string AppName, cons .Metadata = CbObject::Clone(Metadata), .CreatedAt = Now, .UpdatedAt = Now}))); - - ZEN_INFO("Session {} registered (AppName: {}, JobId: {})", SessionId, AppName, JobId); return true; } diff --git a/src/zenserver/storage/cache/httpstructuredcache.cpp b/src/zenserver/storage/cache/httpstructuredcache.cpp index 06b8f6c27..bbdb03ba4 100644 --- a/src/zenserver/storage/cache/httpstructuredcache.cpp +++ b/src/zenserver/storage/cache/httpstructuredcache.cpp @@ -1892,8 +1892,6 @@ HttpStructuredCacheService::CollectStats() { Cbo << "upstream_ratio" << (HitCount > 0 ? (double(UpstreamHitCount) / double(HitCount)) : 0.0); Cbo << "upstream_hits" << m_CacheStats.UpstreamHitCount; - Cbo << "upstream_ratio" << (HitCount > 0 ? (double(UpstreamHitCount) / double(HitCount)) : 0.0); - Cbo << "upstream_ratio" << (HitCount > 0 ? (double(UpstreamHitCount) / double(HitCount)) : 0.0); } Cbo << "cidhits" << ChunkHitCount << "cidmisses" << ChunkMissCount << "cidwrites" << ChunkWriteCount; @@ -2025,8 +2023,6 @@ HttpStructuredCacheService::HandleStatsRequest(HttpServerRequest& Request) { Cbo << "upstream_ratio" << (HitCount > 0 ? (double(UpstreamHitCount) / double(HitCount)) : 0.0); Cbo << "upstream_hits" << m_CacheStats.UpstreamHitCount; - Cbo << "upstream_ratio" << (HitCount > 0 ? (double(UpstreamHitCount) / double(HitCount)) : 0.0); - Cbo << "upstream_ratio" << (HitCount > 0 ? (double(UpstreamHitCount) / double(HitCount)) : 0.0); } Cbo << "cidhits" << ChunkHitCount << "cidmisses" << ChunkMissCount << "cidwrites" << ChunkWriteCount; diff --git a/src/zenserver/zenserver.cpp b/src/zenserver/zenserver.cpp index 88b85d7d9..49ae1b6ff 100644 --- a/src/zenserver/zenserver.cpp +++ b/src/zenserver/zenserver.cpp @@ -23,6 +23,7 @@ #include <zencore/timer.h> #include <zencore/trace.h> #include <zencore/workthreadpool.h> +#include <zenhttp/httpclient.h> #include <zenhttp/httpserver.h> #include <zenhttp/security/passwordsecurityfilter.h> #include <zentelemetry/otlptrace.h> @@ -146,6 +147,14 @@ ZenServerBase::Initialize(const ZenServerConfig& ServerOptions, ZenServerState:: EnqueueSigIntTimer(); + // Configure HTTP client back-end + + const std::string HttpClientBackend = ToLower(ServerOptions.HttpClient.Backend); + zen::SetDefaultHttpClientBackend(HttpClientBackend); + ZEN_INFO("Using '{}' as HTTP client backend", HttpClientBackend); + + // Initialize HTTP server + m_Http = CreateHttpServer(ServerOptions.HttpConfig); int EffectiveBasePort = m_Http->Initialize(ServerOptions.BasePort, ServerOptions.DataDir); if (EffectiveBasePort == 0) |