aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorStefan Boberg <[email protected]>2026-03-10 17:27:26 +0100
committerGitHub Enterprise <[email protected]>2026-03-10 17:27:26 +0100
commitd0a07e555577dcd4a8f55f1b45d9e8e4e6366ab7 (patch)
tree2dfe1e3e0b620043d358e0b7f8bdf8320d985491 /src
parentchangelog entry which was inadvertently omitted from PR merge (diff)
downloadzen-d0a07e555577dcd4a8f55f1b45d9e8e4e6366ab7.tar.xz
zen-d0a07e555577dcd4a8f55f1b45d9e8e4e6366ab7.zip
HttpClient using libcurl, Unix Sockets for HTTP. HTTPS support (#770)
The main goal of this change is to eliminate the cpr back-end altogether and replace it with the curl implementation. I would expect to drop cpr as soon as we feel happy with the libcurl back-end. That would leave us with a direct dependency on libcurl only, and cpr can be eliminated as a dependency. ### HttpClient Backend Overhaul - Implemented a new **libcurl-based HttpClient** backend (`httpclientcurl.cpp`, ~2000 lines) as an alternative to the cpr-based one - Made HttpClient backend **configurable at runtime** via constructor arguments and `-httpclient=...` CLI option (for zen, zenserver, and tests) - Extended HttpClient test suite to cover multipart/content-range scenarios ### Unix Domain Socket Support - Added Unix domain socket support to **httpasio** (server side) - Added Unix domain socket support to **HttpClient** - Added Unix domain socket support to **HttpWsClient** (WebSocket client) - Templatized `HttpServerConnectionT<SocketType>` and `WsAsioConnectionT<SocketType>` to handle TCP, Unix, and SSL sockets uniformly via `if constexpr` dispatch ### HTTPS Support - Added **preliminary HTTPS support to httpasio** (for Mac/Linux via OpenSSL) - Added **basic HTTPS support for http.sys** (Windows) - Implemented HTTPS test for httpasio - Split `InitializeServer` into smaller sub-functions for http.sys ### Other Notable Changes - Improved **zenhttp-test stability** with dynamic port allocation - Enhanced port retry logic in http.sys (handles ERROR_ACCESS_DENIED) - Fatal signal/exception handlers for backtrace generation in tests - Added `zen bench http` subcommand to exercise network + HTTP client/server communication stack
Diffstat (limited to 'src')
-rw-r--r--src/zen/bench.cpp47
-rw-r--r--src/zen/cmds/bench_cmd.cpp516
-rw-r--r--src/zen/cmds/bench_cmd.h41
-rw-r--r--src/zen/xmake.lua2
-rw-r--r--src/zen/zen.cpp75
-rw-r--r--src/zen/zen.h7
-rw-r--r--src/zenbase/include/zenbase/concepts.h2
-rw-r--r--src/zencore/sentryintegration.cpp11
-rw-r--r--src/zencore/testing.cpp115
-rw-r--r--src/zenhttp-test/zenhttp-test.cpp12
-rw-r--r--src/zenhttp/clients/httpclientcommon.cpp57
-rw-r--r--src/zenhttp/clients/httpclientcpr.cpp135
-rw-r--r--src/zenhttp/clients/httpclientcurl.cpp1947
-rw-r--r--src/zenhttp/clients/httpclientcurl.h135
-rw-r--r--src/zenhttp/clients/httpwsclient.cpp213
-rw-r--r--src/zenhttp/httpclient.cpp121
-rw-r--r--src/zenhttp/httpclient_test.cpp299
-rw-r--r--src/zenhttp/httpserver.cpp23
-rw-r--r--src/zenhttp/include/zenhttp/formatters.h2
-rw-r--r--src/zenhttp/include/zenhttp/httpclient.h74
-rw-r--r--src/zenhttp/include/zenhttp/httpserver.h15
-rw-r--r--src/zenhttp/include/zenhttp/httpwsclient.h4
-rw-r--r--src/zenhttp/servers/asio_socket_traits.h54
-rw-r--r--src/zenhttp/servers/httpasio.cpp687
-rw-r--r--src/zenhttp/servers/httpasio.h6
-rw-r--r--src/zenhttp/servers/httpsys.cpp409
-rw-r--r--src/zenhttp/servers/httpsys.h4
-rw-r--r--src/zenhttp/servers/wsasio.cpp64
-rw-r--r--src/zenhttp/servers/wsasio.h43
-rw-r--r--src/zenhttp/servers/wstest.cpp73
-rw-r--r--src/zenhttp/xmake.lua5
-rw-r--r--src/zenremotestore/builds/jupiterbuildstorage.cpp7
-rw-r--r--src/zenremotestore/jupiter/jupitersession.cpp2
-rw-r--r--src/zenremotestore/projectstore/buildsremoteprojectstore.cpp12
-rw-r--r--src/zenremotestore/projectstore/zenremoteprojectstore.cpp2
-rw-r--r--src/zenserver-test/zenserver-test.cpp41
-rw-r--r--src/zenserver/compute/computeserver.cpp2
-rw-r--r--src/zenserver/config/config.cpp122
-rw-r--r--src/zenserver/config/config.h6
-rw-r--r--src/zenserver/sessions/httpsessions.cpp4
-rw-r--r--src/zenserver/sessions/sessions.cpp4
-rw-r--r--src/zenserver/storage/cache/httpstructuredcache.cpp4
-rw-r--r--src/zenserver/zenserver.cpp9
43 files changed, 4881 insertions, 532 deletions
diff --git a/src/zen/bench.cpp b/src/zen/bench.cpp
index 614454ed5..2332ce1b8 100644
--- a/src/zen/bench.cpp
+++ b/src/zen/bench.cpp
@@ -119,6 +119,53 @@ EmptyStandByList()
} // namespace zen::bench::util
+#elif ZEN_PLATFORM_LINUX
+
+# include <fcntl.h>
+# include <unistd.h>
+
+namespace zen::bench::util {
+
+void
+EmptyStandByList()
+{
+ sync();
+
+ int Fd = open("/proc/sys/vm/drop_caches", O_WRONLY);
+ if (Fd < 0)
+ {
+ throw std::runtime_error("Failed to open /proc/sys/vm/drop_caches (are you running as root?)");
+ }
+
+ if (write(Fd, "3", 1) != 1)
+ {
+ close(Fd);
+ throw std::runtime_error("Failed to write to /proc/sys/vm/drop_caches");
+ }
+
+ close(Fd);
+}
+
+} // namespace zen::bench::util
+
+#elif ZEN_PLATFORM_MAC
+
+# include <cstdlib>
+
+namespace zen::bench::util {
+
+void
+EmptyStandByList()
+{
+ int Result = system("/usr/sbin/purge");
+ if (Result != 0)
+ {
+ throw std::runtime_error("Failed to run /usr/sbin/purge (are you running as root?)");
+ }
+}
+
+} // namespace zen::bench::util
+
#else
namespace zen::bench::util {
diff --git a/src/zen/cmds/bench_cmd.cpp b/src/zen/cmds/bench_cmd.cpp
index b9c45a328..658b42da6 100644
--- a/src/zen/cmds/bench_cmd.cpp
+++ b/src/zen/cmds/bench_cmd.cpp
@@ -3,6 +3,7 @@
#include "bench_cmd.h"
#include "bench.h"
+#include <zencore/compactbinary.h>
#include <zencore/except.h>
#include <zencore/filesystem.h>
#include <zencore/fmtutils.h>
@@ -11,93 +12,514 @@
#include <zencore/string.h>
#include <zencore/thread.h>
#include <zencore/timer.h>
+#include <zenhttp/httpclient.h>
+#include <zentelemetry/stats.h>
+
+#include <algorithm>
+#include <atomic>
+#include <csignal>
+#include <mutex>
+#include <thread>
+
+static std::atomic<bool> s_BenchAbort{false};
namespace zen {
-BenchCommand::BenchCommand()
+//////////////////////////////////////////////////////////////////////////
+// BenchPurgeSubCmd
+
+BenchPurgeSubCmd::BenchPurgeSubCmd()
+: ZenSubCmdBase("purge", "Purge standby memory (system cache)")
{
- m_Options.add_options()("h,help", "Print help");
- m_Options.add_options()("purge",
- "Purge standby memory (system cache)",
- cxxopts::value<bool>(m_PurgeStandbyLists)->default_value("false"));
- m_Options.add_options()("single", "Do not spawn child processes", cxxopts::value<bool>(m_SingleProcess)->default_value("false"));
+ SubOptions().add_options()("single",
+ "Do not spawn child processes",
+ cxxopts::value<bool>(m_SingleProcess)->default_value("false"));
}
-BenchCommand::~BenchCommand() = default;
-
void
-BenchCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv)
+BenchPurgeSubCmd::Run(const ZenCliOptions& GlobalOptions)
{
ZEN_UNUSED(GlobalOptions);
- if (!ParseOptions(argc, argv))
+ bool Ok = false;
+
+ zen::Stopwatch Timer;
+
+ try
+ {
+ zen::bench::util::EmptyStandByList();
+
+ Ok = true;
+ }
+ catch (const zen::bench::util::elevation_required_exception&)
+ {
+ ZEN_CONSOLE_WARN("Purging standby lists requires elevation. Will try launch as elevated process");
+ }
+ catch (const std::exception& Ex)
{
- return;
+ ZEN_CONSOLE_ERROR("{}", Ex.what());
}
#if ZEN_PLATFORM_WINDOWS
- if (m_PurgeStandbyLists)
+ if (!Ok && !m_SingleProcess)
{
- bool Ok = false;
-
- zen::Stopwatch Timer;
-
try
{
- zen::bench::util::EmptyStandByList();
+ zen::CreateProcOptions Cpo;
+ Cpo.Flags = zen::CreateProcOptions::Flag_Elevated | zen::CreateProcOptions::Flag_NewConsole;
- Ok = true;
- }
- catch (const zen::bench::util::elevation_required_exception&)
- {
- ZEN_CONSOLE_WARN("Purging standby lists requires elevation. Will try launch as elevated process");
+ std::filesystem::path CurExe{zen::GetRunningExecutablePath()};
+
+ if (zen::CreateProcResult Cpr = zen::CreateProc(CurExe, fmt::format("bench purge --single"), Cpo))
+ {
+ zen::ProcessHandle ProcHandle;
+ ProcHandle.Initialize(Cpr);
+
+ int ExitCode = ProcHandle.WaitExitCode();
+
+ if (ExitCode == 0)
+ {
+ Ok = true;
+ }
+ else
+ {
+ ZEN_CONSOLE_ERROR("Elevated child process failed with return code {}", ExitCode);
+ }
+ }
}
catch (const std::exception& Ex)
{
ZEN_CONSOLE_ERROR("{}", Ex.what());
}
+ }
+#endif
+
+ if (Ok)
+ {
+ // TODO: could also add reporting on just how much memory was purged
+ ZEN_CONSOLE("Purged standby lists! (took {})", zen::NiceTimeSpanMs(Timer.GetElapsedTimeMs()));
+ }
+}
+
+//////////////////////////////////////////////////////////////////////////
+// BenchHttpSubCmd
+
+
+BenchHttpSubCmd::BenchHttpSubCmd()
+: ZenSubCmdBase("http", "Benchmark an HTTP server")
+{
+ SubOptions().add_option("", "u", "url", "URL to benchmark", cxxopts::value<std::string>(m_Url), "<url>");
+ SubOptions().add_option("",
+ "n",
+ "count",
+ "Number of requests to send",
+ cxxopts::value<int>(m_Count)->default_value("100"),
+ "<count>");
+ SubOptions().add_option("",
+ "c",
+ "concurrency",
+ "Number of concurrent threads",
+ cxxopts::value<int>(m_Concurrency)->default_value("1"),
+ "<threads>");
+ SubOptions().add_option("",
+ "",
+ "method",
+ "HTTP method to use (GET, HEAD)",
+ cxxopts::value<std::string>(m_Method)->default_value("GET"),
+ "<method>");
+ SubOptions().add_option("",
+ "",
+ "unix-socket",
+ "Unix domain socket path (overrides TCP)",
+ cxxopts::value<std::string>(m_SocketPath),
+ "<path>");
+ SubOptions().add_options()("no-keepalive",
+ "Close connection after each request (disables keep-alive)",
+ cxxopts::value<bool>(m_NoKeepAlive)->default_value("false"));
+ SubOptions().add_options()("continuous",
+ "Run until interrupted (Ctrl+C), printing metrics once per second",
+ cxxopts::value<bool>(m_Continuous)->default_value("false"));
+ SubOptions().parse_positional({"url"});
+}
+
+static std::pair<std::string, std::string>
+SplitUrl(std::string_view Url)
+{
+ size_t SchemeEnd = Url.find("://");
+ size_t SearchFrom = (SchemeEnd != std::string_view::npos) ? SchemeEnd + 3 : 0;
+ size_t PathStart = Url.find('/', SearchFrom);
+
+ if (PathStart == std::string_view::npos)
+ {
+ return {std::string(Url), "/"};
+ }
+
+ return {std::string(Url.substr(0, PathStart)), std::string(Url.substr(PathStart))};
+}
+
+void
+BenchHttpSubCmd::Run(const ZenCliOptions& GlobalOptions)
+{
+ ZEN_UNUSED(GlobalOptions);
+
+ if (m_Url.empty())
+ {
+ throw OptionParseException("URL is required", SubOptions().help());
+ }
+
+ if (!m_Continuous && m_Count <= 0)
+ {
+ throw OptionParseException("--count must be a positive integer", SubOptions().help());
+ }
+
+ if (m_Concurrency <= 0)
+ {
+ throw OptionParseException("--concurrency must be a positive integer", SubOptions().help());
+ }
+
+ if (m_Method != "GET" && m_Method != "HEAD")
+ {
+ throw OptionParseException(fmt::format("Unsupported HTTP method '{}'. Supported: GET, HEAD", m_Method),
+ SubOptions().help());
+ }
+
+ auto [BaseUri, Path] = SplitUrl(m_Url);
+
+ std::string ModeStr = m_Continuous ? "continuous" : fmt::format("count={}", m_Count);
+
+ if (m_SocketPath.empty())
+ {
+ ZEN_CONSOLE("Benchmarking {} {} ({}, concurrency={})", m_Method, m_Url, ModeStr, m_Concurrency);
+ }
+ else
+ {
+ ZEN_CONSOLE("Benchmarking {} {} via {} ({}, concurrency={})", m_Method, m_Url, m_SocketPath, ModeStr, m_Concurrency);
+ }
+
+ // Probe for a zenserver identity. If the target exposes /health/info and the
+ // response contains a BuildVersion field we print a short summary. Any failure
+ // (non-zenserver, timeout, unreachable) is silently ignored.
+ try
+ {
+ HttpClientSettings ProbeSettings{.ConnectTimeout = std::chrono::milliseconds(2000),
+ .Timeout = std::chrono::milliseconds(2000),
+ .UnixSocketPath = m_SocketPath};
+ HttpClient ProbeHttp(BaseUri, ProbeSettings);
+ HttpClient::Response ProbeResp = ProbeHttp.Get("/health/info");
+
+ if (ProbeResp.IsSuccess())
+ {
+ CbObject Info = ProbeResp.AsObject();
+ std::string_view BuildVersion = Info["BuildVersion"].AsString();
+
+ if (!BuildVersion.empty())
+ {
+ std::string_view Hostname = Info["Hostname"].AsString();
+ int64_t Pid = Info["Pid"].AsInt64();
+ std::string_view HttpServerClass = Info["HttpServerClass"].AsString();
+ ZEN_CONSOLE("Remote : zenserver {} on {} (pid {}, {})", BuildVersion, Hostname, Pid, HttpServerClass);
+
+ std::string_view OS = Info["OS"].AsString();
+ std::string_view Arch = Info["Arch"].AsString();
+
+ CbObjectView System = Info["System"].AsObjectView();
+ int64_t LpCount = System["lp_count"].AsInt64();
+ int64_t TotalMemMiB = System["total_memory_mb"].AsInt64();
+
+ ZEN_CONSOLE(" : {}, {}, {} logical processors, {} RAM",
+ OS,
+ Arch,
+ LpCount,
+ NiceBytes(static_cast<uint64_t>(TotalMemMiB) * 1024 * 1024));
+ }
+ }
+ }
+ catch (...)
+ {
+ }
+
+ if (m_Continuous)
+ {
+ RunContinuous(BaseUri, Path);
+ }
+ else
+ {
+ RunFixedCount(BaseUri, Path);
+ }
+}
+
+void
+BenchHttpSubCmd::RunFixedCount(const std::string& BaseUri, const std::string& Path)
+{
+ std::atomic<int> NextRequest{0};
+ std::vector<double> AllLatencies;
+ AllLatencies.reserve(m_Count);
+ std::mutex LatencyMutex;
+ std::atomic<int> ErrorCount{0};
+ std::atomic<int64_t> TotalDownloadedBytes{0};
+ std::atomic<int64_t> TotalUploadedBytes{0};
+
+ Stopwatch Timer;
- if (!Ok && !m_SingleProcess)
+ auto WorkerFn = [&]() {
+ std::vector<double> LocalLatencies;
+
+ HttpClientSettings Settings{.UnixSocketPath = m_SocketPath,
+ .ForbidReuseConnection = m_NoKeepAlive};
+ HttpClient Http(BaseUri, Settings);
+
+ while (true)
{
+ int RequestIndex = NextRequest.fetch_add(1);
+
+ if (RequestIndex >= m_Count)
+ {
+ break;
+ }
+
try
{
- zen::CreateProcOptions Cpo;
- Cpo.Flags = zen::CreateProcOptions::Flag_Elevated | zen::CreateProcOptions::Flag_NewConsole;
+ HttpClient::Response Resp = (m_Method == "HEAD") ? Http.Head(Path) : Http.Get(Path);
+
+ if (Resp.IsSuccess())
+ {
+ LocalLatencies.push_back(Resp.ElapsedSeconds);
+ TotalDownloadedBytes.fetch_add(Resp.DownloadedBytes);
+ TotalUploadedBytes.fetch_add(Resp.UploadedBytes);
+ }
+ else
+ {
+ ErrorCount.fetch_add(1);
+ }
+ }
+ catch (const HttpClientError&)
+ {
+ ErrorCount.fetch_add(1);
+ }
+ }
+
+ std::lock_guard Lock(LatencyMutex);
+ AllLatencies.insert(AllLatencies.end(), LocalLatencies.begin(), LocalLatencies.end());
+ };
+
+ std::vector<std::thread> Threads;
+ Threads.reserve(m_Concurrency);
+
+ for (int i = 0; i < m_Concurrency; ++i)
+ {
+ Threads.emplace_back(WorkerFn);
+ }
+
+ for (std::thread& T : Threads)
+ {
+ T.join();
+ }
+
+ double TotalSeconds = Timer.GetElapsedTimeMs() / 1000.0;
+ int SuccessCount = static_cast<int>(AllLatencies.size());
+ int TotalCount = SuccessCount + ErrorCount.load();
+
+ std::sort(AllLatencies.begin(), AllLatencies.end());
+
+ auto PercentileMs = [&](int Pct) -> double {
+ if (AllLatencies.empty())
+ {
+ return 0.0;
+ }
+
+ size_t Index = std::min(AllLatencies.size() * static_cast<size_t>(Pct) / 100, AllLatencies.size() - 1);
+
+ return AllLatencies[Index] * 1000.0;
+ };
+
+ double SumMs = 0.0;
+ for (double L : AllLatencies)
+ {
+ SumMs += L * 1000.0;
+ }
+
+ double MeanMs = SuccessCount > 0 ? SumMs / SuccessCount : 0.0;
+ double Rps = TotalSeconds > 0.0 ? TotalCount / TotalSeconds : 0.0;
+
+ uint64_t DownBytesPerSec = TotalSeconds > 0.0 ? static_cast<uint64_t>(TotalDownloadedBytes.load() / TotalSeconds) : 0;
+ uint64_t UpBytesPerSec = TotalSeconds > 0.0 ? static_cast<uint64_t>(TotalUploadedBytes.load() / TotalSeconds) : 0;
+
+ ZEN_CONSOLE(" Requests : {:L} total, {:L} success, {:L} errors", TotalCount, SuccessCount, ErrorCount.load());
+ ZEN_CONSOLE(" Latency : min={:.1f}ms mean={:.1f}ms p50={:.1f}ms p95={:.1f}ms p99={:.1f}ms max={:.1f}ms",
+ PercentileMs(0),
+ MeanMs,
+ PercentileMs(50),
+ PercentileMs(95),
+ PercentileMs(99),
+ PercentileMs(100));
+ ZEN_CONSOLE(" Throughput: {:.1f} req/s down: {}/s up: {}/s (elapsed: {:.2f}s)",
+ Rps,
+ NiceBytes(DownBytesPerSec),
+ NiceBytes(UpBytesPerSec),
+ TotalSeconds);
+}
- std::filesystem::path CurExe{zen::GetRunningExecutablePath()};
+void
+BenchHttpSubCmd::RunContinuous(const std::string& BaseUri, const std::string& Path)
+{
+ s_BenchAbort.store(false);
+
+ auto PrevSigInt = std::signal(SIGINT, [](int) { s_BenchAbort.store(true); });
+ auto PrevSigTerm = std::signal(SIGTERM, [](int) { s_BenchAbort.store(true); });
+
+ metrics::Histogram LatencyHistogram;
+ std::atomic<int64_t> IntervalSuccessCount{0};
+ std::atomic<int64_t> IntervalErrorCount{0};
+ std::atomic<int64_t> IntervalDownloadBytes{0};
+ std::atomic<int64_t> IntervalUploadBytes{0};
+ std::atomic<int64_t> TotalSuccessCount{0};
+ std::atomic<int64_t> TotalErrorCount{0};
+ std::atomic<int64_t> TotalDownloadBytes{0};
+ std::atomic<int64_t> TotalUploadBytes{0};
+
+ Stopwatch RunTimer;
+
+ auto WorkerFn = [&]() {
+ HttpClientSettings Settings{.UnixSocketPath = m_SocketPath,
+ .ForbidReuseConnection = m_NoKeepAlive};
+ HttpClient Http(BaseUri, Settings);
- if (zen::CreateProcResult Cpr = zen::CreateProc(CurExe, fmt::format("bench --purge --single"), Cpo))
+ while (!s_BenchAbort.load(std::memory_order_relaxed))
+ {
+ try
+ {
+ HttpClient::Response Resp = (m_Method == "HEAD") ? Http.Head(Path) : Http.Get(Path);
+
+ if (Resp.IsSuccess())
+ {
+ LatencyHistogram.Update(static_cast<int64_t>(Resp.ElapsedSeconds * 1.0e6));
+ IntervalSuccessCount.fetch_add(1, std::memory_order_relaxed);
+ IntervalDownloadBytes.fetch_add(Resp.DownloadedBytes, std::memory_order_relaxed);
+ IntervalUploadBytes.fetch_add(Resp.UploadedBytes, std::memory_order_relaxed);
+ TotalSuccessCount.fetch_add(1, std::memory_order_relaxed);
+ TotalDownloadBytes.fetch_add(Resp.DownloadedBytes, std::memory_order_relaxed);
+ TotalUploadBytes.fetch_add(Resp.UploadedBytes, std::memory_order_relaxed);
+ }
+ else
{
- zen::ProcessHandle ProcHandle;
- ProcHandle.Initialize(Cpr);
-
- int ExitCode = ProcHandle.WaitExitCode();
-
- if (ExitCode == 0)
- {
- Ok = true;
- }
- else
- {
- ZEN_CONSOLE_ERROR("Elevated child process failed with return code {}", ExitCode);
- }
+ IntervalErrorCount.fetch_add(1, std::memory_order_relaxed);
+ TotalErrorCount.fetch_add(1, std::memory_order_relaxed);
}
}
- catch (const std::exception& Ex)
+ catch (const HttpClientError&)
{
- ZEN_CONSOLE_ERROR("{}", Ex.what());
+ IntervalErrorCount.fetch_add(1, std::memory_order_relaxed);
+ TotalErrorCount.fetch_add(1, std::memory_order_relaxed);
}
}
+ };
- if (Ok)
+ auto ReporterFn = [&]() {
+ while (!s_BenchAbort.load(std::memory_order_relaxed))
{
- // TODO: could also add reporting on just how much memory was purged
- ZEN_CONSOLE("Purged standby lists! (took {})", zen::NiceTimeSpanMs(Timer.GetElapsedTimeMs()));
+ // Sleep 1s in short increments to stay responsive to abort
+ for (int i = 0; i < 10 && !s_BenchAbort.load(std::memory_order_relaxed); ++i)
+ {
+ std::this_thread::sleep_for(std::chrono::milliseconds(100));
+ }
+
+ if (s_BenchAbort.load(std::memory_order_relaxed))
+ {
+ break;
+ }
+
+ // Snapshot and reset per-interval counters
+ int64_t Successes = IntervalSuccessCount.exchange(0);
+ int64_t Errors = IntervalErrorCount.exchange(0);
+ int64_t DownBytes = IntervalDownloadBytes.exchange(0);
+ int64_t UpBytes = IntervalUploadBytes.exchange(0);
+
+ // Snapshot and reset latency histogram
+ uint64_t HistCount = LatencyHistogram.Count();
+ int64_t HistMin = LatencyHistogram.Min();
+ int64_t HistMax = LatencyHistogram.Max();
+ double HistMean = LatencyHistogram.Mean();
+ metrics::SampleSnapshot Snap = LatencyHistogram.Snapshot();
+ LatencyHistogram.Clear();
+
+ // Format elapsed as HH:MM:SS
+ int TotalSec = static_cast<int>(RunTimer.GetElapsedTimeMs() / 1000.0);
+ int Hours = TotalSec / 3600;
+ int Minutes = (TotalSec % 3600) / 60;
+ int Secs = TotalSec % 60;
+
+ if (HistCount > 0)
+ {
+ ZEN_CONSOLE(
+ "[{:02d}:{:02d}:{:02d}] req/s: {:L} errors: {:L} lat(ms): min={:.1f} mean={:.1f} p95={:.1f} p99={:.1f} max={:.1f} down: {}/s up: {}/s",
+ Hours,
+ Minutes,
+ Secs,
+ Successes,
+ Errors,
+ HistMin / 1000.0,
+ HistMean / 1000.0,
+ Snap.Get95Percentile() / 1000.0,
+ Snap.Get99Percentile() / 1000.0,
+ HistMax / 1000.0,
+ NiceBytes(static_cast<uint64_t>(std::max(int64_t{0}, DownBytes))),
+ NiceBytes(static_cast<uint64_t>(std::max(int64_t{0}, UpBytes))));
+ }
+ else
+ {
+ ZEN_CONSOLE("[{:02d}:{:02d}:{:02d}] req/s: 0 errors: {:L} (no successful requests)",
+ Hours,
+ Minutes,
+ Secs,
+ Errors);
+ }
}
+ };
+
+ std::vector<std::thread> Threads;
+ Threads.reserve(m_Concurrency + 1);
+ Threads.emplace_back(ReporterFn);
+
+ for (int i = 0; i < m_Concurrency; ++i)
+ {
+ Threads.emplace_back(WorkerFn);
}
-#endif
- return;
+ for (std::thread& T : Threads)
+ {
+ T.join();
+ }
+
+ std::signal(SIGINT, PrevSigInt);
+ std::signal(SIGTERM, PrevSigTerm);
+
+ double TotalSeconds = RunTimer.GetElapsedTimeMs() / 1000.0;
+ int64_t TotalCount = TotalSuccessCount.load() + TotalErrorCount.load();
+ uint64_t DownPerSec = TotalSeconds > 0.0 ? static_cast<uint64_t>(TotalDownloadBytes.load() / TotalSeconds) : 0;
+ uint64_t UpPerSec = TotalSeconds > 0.0 ? static_cast<uint64_t>(TotalUploadBytes.load() / TotalSeconds) : 0;
+
+ ZEN_CONSOLE("Stopped. Total: {:L} requests, {:L} success, {:L} errors avg throughput: down {}/s up {}/s (elapsed: {:.2f}s)",
+ TotalCount,
+ TotalSuccessCount.load(),
+ TotalErrorCount.load(),
+ NiceBytes(DownPerSec),
+ NiceBytes(UpPerSec),
+ TotalSeconds);
+}
+
+//////////////////////////////////////////////////////////////////////////
+// BenchCommand
+
+BenchCommand::BenchCommand()
+{
+ m_Options.add_options()("h,help", "Print help");
+ m_Options.add_option("__hidden__", "", "subcommand", "", cxxopts::value<std::string>(m_SubCommand)->default_value(""), "");
+ m_Options.parse_positional({"subcommand"});
+
+ AddSubCommand(m_PurgeSubCmd);
+ AddSubCommand(m_HttpSubCmd);
}
+BenchCommand::~BenchCommand() = default;
+
} // namespace zen
diff --git a/src/zen/cmds/bench_cmd.h b/src/zen/cmds/bench_cmd.h
index 7fbf85340..f332b3fcc 100644
--- a/src/zen/cmds/bench_cmd.h
+++ b/src/zen/cmds/bench_cmd.h
@@ -6,7 +6,36 @@
namespace zen {
-class BenchCommand : public ZenCmdBase
+class BenchPurgeSubCmd : public ZenSubCmdBase
+{
+public:
+ BenchPurgeSubCmd();
+ void Run(const ZenCliOptions& GlobalOptions) override;
+
+private:
+ bool m_SingleProcess = false;
+};
+
+class BenchHttpSubCmd : public ZenSubCmdBase
+{
+public:
+ BenchHttpSubCmd();
+ void Run(const ZenCliOptions& GlobalOptions) override;
+
+private:
+ void RunFixedCount(const std::string& BaseUri, const std::string& Path);
+ void RunContinuous(const std::string& BaseUri, const std::string& Path);
+
+ std::string m_Url;
+ std::string m_SocketPath;
+ int m_Count = 100;
+ int m_Concurrency = 1;
+ std::string m_Method = "GET";
+ bool m_NoKeepAlive = false;
+ bool m_Continuous = false;
+};
+
+class BenchCommand : public ZenCmdWithSubCommands
{
public:
static constexpr char Name[] = "bench";
@@ -15,14 +44,14 @@ public:
BenchCommand();
~BenchCommand();
- virtual void Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) override;
- virtual cxxopts::Options& Options() override { return m_Options; }
- virtual ZenCmdCategory& CommandCategory() const override { return g_UtilitiesCategory; }
+ cxxopts::Options& Options() override { return m_Options; }
+ ZenCmdCategory& CommandCategory() const override { return g_UtilitiesCategory; }
private:
cxxopts::Options m_Options{Name, Description};
- bool m_PurgeStandbyLists = false;
- bool m_SingleProcess = false;
+ std::string m_SubCommand;
+ BenchPurgeSubCmd m_PurgeSubCmd;
+ BenchHttpSubCmd m_HttpSubCmd;
};
} // namespace zen
diff --git a/src/zen/xmake.lua b/src/zen/xmake.lua
index f889c3296..4c134404a 100644
--- a/src/zen/xmake.lua
+++ b/src/zen/xmake.lua
@@ -6,7 +6,7 @@ target("zen")
add_files("**.cpp")
add_files("zen.cpp", {unity_ignored = true })
add_deps("zencore", "zenhttp", "zenremotestore", "zenstore", "zenutil")
- add_deps("zencompute", "zennet")
+ add_deps("zencompute", "zennet", "zentelemetry")
add_deps("cxxopts", "fmt")
add_packages("json11")
add_includedirs(".")
diff --git a/src/zen/zen.cpp b/src/zen/zen.cpp
index 9a466da2e..86c29344e 100644
--- a/src/zen/zen.cpp
+++ b/src/zen/zen.cpp
@@ -196,6 +196,7 @@ ZenCmdBase::GetSubCommand(cxxopts::Options&,
ZenSubCmdBase::ZenSubCmdBase(std::string_view Name, std::string_view Description)
: m_SubOptions(std::string(Name), std::string(Description))
+, m_Description(Description)
{
m_SubOptions.add_options()("h,help", "Print help");
}
@@ -213,6 +214,35 @@ ZenCmdWithSubCommands::OnParentOptionsParsed(const ZenCliOptions& /*GlobalOption
}
void
+ZenCmdWithSubCommands::PrintHelp()
+{
+ // Show all option groups except the internal "__hidden__" group used to
+ // silently capture positional arguments.
+ std::vector<std::string> Groups = Options().groups();
+ Groups.erase(std::remove(Groups.begin(), Groups.end(), std::string("__hidden__")), Groups.end());
+
+ Options().set_width(TuiConsoleColumns(80));
+ printf("%s\n", Options().help(Groups).c_str());
+
+ // Append subcommand listing.
+ size_t MaxNameLen = 0;
+ for (ZenSubCmdBase* SubCmd : m_SubCommands)
+ {
+ MaxNameLen = std::max(MaxNameLen, SubCmd->SubOptions().program().size());
+ }
+
+ printf("subcommands:\n");
+ for (ZenSubCmdBase* SubCmd : m_SubCommands)
+ {
+ printf(" %-*s %s\n",
+ static_cast<int>(MaxNameLen),
+ SubCmd->SubOptions().program().c_str(),
+ std::string(SubCmd->Description()).c_str());
+ }
+ printf("\nFor global options run: zen --help\n");
+}
+
+void
ZenCmdWithSubCommands::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv)
{
std::vector<cxxopts::Options*> SubOptionPtrs;
@@ -226,15 +256,47 @@ ZenCmdWithSubCommands::Run(const ZenCliOptions& GlobalOptions, int argc, char**
std::vector<char*> SubCommandArguments;
int ParentArgc = GetSubCommand(Options(), argc, argv, SubOptionPtrs, MatchedSubOption, SubCommandArguments);
- if (!ParseOptions(Options(), ParentArgc, argv))
+ // Intercept --help/-h in the parent arg range before calling ParseOptions so
+ // we can append subcommand information to the output. When a subcommand was
+ // found argv[ParentArgc-1] is the subcommand name itself, which we exclude.
+ int ParentArgEnd = (MatchedSubOption != nullptr) ? ParentArgc - 1 : ParentArgc;
+ for (int i = 1; i < ParentArgEnd; ++i)
{
- return;
+ std::string_view Arg(argv[i]);
+ if (Arg == "--help" || Arg == "-h")
+ {
+ PrintHelp();
+ return;
+ }
+ }
+
+ // Parse parent options. When a subcommand was matched we strip its name from
+ // the arg list so the parent parser does not see it as an unmatched positional.
+ if (MatchedSubOption != nullptr)
+ {
+ std::vector<char*> ParentArgs;
+ ParentArgs.reserve(static_cast<size_t>(ParentArgc - 1));
+ ParentArgs.push_back(argv[0]);
+ std::copy(argv + 1, argv + ParentArgc - 1, std::back_inserter(ParentArgs));
+ if (!ParseOptions(Options(), static_cast<int>(ParentArgs.size()), ParentArgs.data()))
+ {
+ return;
+ }
+ }
+ else
+ {
+ if (!ParseOptions(Options(), ParentArgc, argv))
+ {
+ return;
+ }
}
if (MatchedSubOption == nullptr)
{
+ PrintHelp();
+
ExtendableStringBuilder<128> VerbList;
- for (bool First = true; ZenSubCmdBase * SubCmd : m_SubCommands)
+ for (bool First = true; ZenSubCmdBase* SubCmd : m_SubCommands)
{
if (!First)
{
@@ -243,7 +305,7 @@ ZenCmdWithSubCommands::Run(const ZenCliOptions& GlobalOptions, int argc, char**
VerbList.Append(SubCmd->SubOptions().program());
First = false;
}
- throw OptionParseException(fmt::format("No subcommand specified. Available subcommands: {}", VerbList.ToView()), Options().help());
+ throw OptionParseException(fmt::format("No subcommand specified. Available subcommands: {}", VerbList.ToView()), {});
}
ZenSubCmdBase* MatchedSubCmd = nullptr;
@@ -621,6 +683,9 @@ main(int argc, char** argv)
Options.add_options()("malloc", "Configure memory allocator subsystem", cxxopts::value(MemoryOptions)->default_value("mimalloc"));
Options.add_options()("help", "Show command line help");
Options.add_options()("c, command", "Sub command", cxxopts::value<std::string>(SubCommand));
+ Options.add_options()("httpclient",
+ "Select HTTP client implementation (e.g. 'curl', 'cpr')",
+ cxxopts::value<std::string>(GlobalOptions.HttpClientBackend)->default_value("cpr"));
int CoreLimit = 0;
@@ -783,6 +848,8 @@ main(int argc, char** argv)
FreeCallstack(Callstack);
});
+ zen::SetDefaultHttpClientBackend(GlobalOptions.HttpClientBackend);
+
zen::MaximizeOpenFileCount();
//////////////////////////////////////////////////////////////////////////
diff --git a/src/zen/zen.h b/src/zen/zen.h
index 06e5356a6..3cc06eea6 100644
--- a/src/zen/zen.h
+++ b/src/zen/zen.h
@@ -17,6 +17,8 @@ struct ZenCliOptions
ZenLoggingConfig LoggingConfig;
+ std::string HttpClientBackend; // Choice of HTTP client implementation (e.g. "curl", "cpr")
+
// Arguments after " -- " on command line are passed through and not parsed
std::string PassthroughCommandLine;
std::string PassthroughArgs;
@@ -86,10 +88,14 @@ public:
ZenSubCmdBase(std::string_view Name, std::string_view Description);
virtual ~ZenSubCmdBase() = default;
cxxopts::Options& SubOptions() { return m_SubOptions; }
+ std::string_view Description() const { return m_Description; }
virtual void Run(const ZenCliOptions& GlobalOptions) = 0;
protected:
cxxopts::Options m_SubOptions;
+
+private:
+ std::string m_Description;
};
// Base for commands that host subcommands - handles all dispatch boilerplate
@@ -101,6 +107,7 @@ public:
protected:
void AddSubCommand(ZenSubCmdBase& SubCmd);
virtual bool OnParentOptionsParsed(const ZenCliOptions& GlobalOptions);
+ void PrintHelp();
private:
std::vector<ZenSubCmdBase*> m_SubCommands;
diff --git a/src/zenbase/include/zenbase/concepts.h b/src/zenbase/include/zenbase/concepts.h
index d4a9d75e8..1da56cefe 100644
--- a/src/zenbase/include/zenbase/concepts.h
+++ b/src/zenbase/include/zenbase/concepts.h
@@ -4,6 +4,8 @@
#include <zenbase/zenbase.h>
+#include <type_traits>
+
// At the time of writing only ver >= 13 of LLVM's libc++ has an implementation
// of std::integral. Some platforms like Ubuntu and Mac OS are still on 12.
#if defined(__cpp_lib_concepts)
diff --git a/src/zencore/sentryintegration.cpp b/src/zencore/sentryintegration.cpp
index 8d087e8c6..58b76783a 100644
--- a/src/zencore/sentryintegration.cpp
+++ b/src/zencore/sentryintegration.cpp
@@ -60,11 +60,12 @@ public:
}
try
{
- std::string Message = fmt::format("{}\n{}({})", Msg.GetPayload(), Msg.GetSource().Filename, Msg.GetSource().Line);
- sentry_value_t Event = sentry_value_new_message_event(
- /* level */ MapToSentryLevel[Msg.GetLevel()],
- /* logger */ nullptr,
- /* message */ Message.c_str());
+ const char* Filename = Msg.GetSource().Filename ? Msg.GetSource().Filename : "<unknown>";
+ std::string Message = fmt::format("{}\n{}({})", Msg.GetPayload(), Filename, Msg.GetSource().Line);
+ sentry_value_t Event = sentry_value_new_message_event(
+ /* level */ MapToSentryLevel[Msg.GetLevel()],
+ /* logger */ nullptr,
+ /* message */ Message.c_str());
sentry_event_value_add_stacktrace(Event, NULL, 0);
sentry_capture_event(Event);
}
diff --git a/src/zencore/testing.cpp b/src/zencore/testing.cpp
index 089e376bb..d7eb3b17d 100644
--- a/src/zencore/testing.cpp
+++ b/src/zencore/testing.cpp
@@ -11,17 +11,131 @@
#if ZEN_WITH_TESTS
+# include <zencore/callstack.h>
+
# include <chrono>
# include <clocale>
+# include <csignal>
# include <cstdlib>
# include <cstdio>
# include <string>
# include <vector>
+# if ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC
+# include <execinfo.h>
+# include <unistd.h>
+# endif
+
namespace zen::testing {
using namespace std::literals;
+static void
+PrintCrashCallstack([[maybe_unused]] const char* SignalName)
+{
+# if ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC
+ // Use write() + backtrace_symbols_fd() which are async-signal-safe
+ write(STDERR_FILENO, "\n*** Caught ", 12);
+ write(STDERR_FILENO, SignalName, strlen(SignalName));
+ write(STDERR_FILENO, " — callstack:\n", 15);
+
+ void* Frames[64];
+ int FrameCount = backtrace(Frames, 64);
+ backtrace_symbols_fd(Frames, FrameCount, STDERR_FILENO);
+# elif ZEN_PLATFORM_WINDOWS
+ // On Windows we're called from SEH, not a signal handler, so heap/locks are safe
+ void* Addresses[64];
+ uint32_t FrameCount = GetCallstack(2, 64, Addresses);
+ if (FrameCount > 0)
+ {
+ std::vector<std::string> Symbols = GetFrameSymbols(FrameCount, Addresses);
+ fprintf(stderr, "\n*** Caught %s - callstack:\n", SignalName);
+ for (uint32_t i = 0; i < FrameCount; ++i)
+ {
+ fprintf(stderr, " %2u: %s\n", i, Symbols[i].c_str());
+ }
+ }
+# endif
+}
+
+# if ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC
+
+static void
+CrashSignalHandler(int Signal)
+{
+ const char* SignalName = "Unknown signal";
+ switch (Signal)
+ {
+ case SIGSEGV:
+ SignalName = "SIGSEGV";
+ break;
+ case SIGABRT:
+ SignalName = "SIGABRT";
+ break;
+ case SIGFPE:
+ SignalName = "SIGFPE";
+ break;
+ case SIGBUS:
+ SignalName = "SIGBUS";
+ break;
+ case SIGILL:
+ SignalName = "SIGILL";
+ break;
+ }
+
+ PrintCrashCallstack(SignalName);
+
+ // Re-raise with default handler so the process terminates normally
+ signal(Signal, SIG_DFL);
+ raise(Signal);
+}
+
+# endif // ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC
+
+# if ZEN_PLATFORM_WINDOWS
+
+static LONG CALLBACK
+CrashVectoredHandler(PEXCEPTION_POINTERS ExceptionInfo)
+{
+ // Only handle fatal exceptions, not first-chance exceptions used for normal control flow
+ switch (ExceptionInfo->ExceptionRecord->ExceptionCode)
+ {
+ case EXCEPTION_ACCESS_VIOLATION:
+ PrintCrashCallstack("EXCEPTION_ACCESS_VIOLATION");
+ break;
+ case EXCEPTION_STACK_OVERFLOW:
+ PrintCrashCallstack("EXCEPTION_STACK_OVERFLOW");
+ break;
+ case EXCEPTION_ILLEGAL_INSTRUCTION:
+ PrintCrashCallstack("EXCEPTION_ILLEGAL_INSTRUCTION");
+ break;
+ case EXCEPTION_INT_DIVIDE_BY_ZERO:
+ PrintCrashCallstack("EXCEPTION_INT_DIVIDE_BY_ZERO");
+ break;
+ default:
+ break;
+ }
+
+ // Continue search so doctest's handler can report the test case context
+ return EXCEPTION_CONTINUE_SEARCH;
+}
+
+# endif // ZEN_PLATFORM_WINDOWS
+
+static void
+InstallCrashSignalHandlers()
+{
+# if ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC
+ signal(SIGSEGV, CrashSignalHandler);
+ signal(SIGABRT, CrashSignalHandler);
+ signal(SIGFPE, CrashSignalHandler);
+ signal(SIGBUS, CrashSignalHandler);
+ signal(SIGILL, CrashSignalHandler);
+# elif ZEN_PLATFORM_WINDOWS
+ AddVectoredExceptionHandler(0 /*called last among vectored handlers*/, CrashVectoredHandler);
+# endif
+}
+
struct TestListener : public doctest::IReporter
{
const std::string_view ColorYellow = "\033[0;33m"sv;
@@ -184,6 +298,7 @@ RunTestMain(int Argc, char* Argv[], const char* ExecutableName, void (*ForceLink
zen::logging::InitializeLogging();
zen::MaximizeOpenFileCount();
+ InstallCrashSignalHandlers();
TestRunner Runner;
diff --git a/src/zenhttp-test/zenhttp-test.cpp b/src/zenhttp-test/zenhttp-test.cpp
index b4b406ac8..0a6980462 100644
--- a/src/zenhttp-test/zenhttp-test.cpp
+++ b/src/zenhttp-test/zenhttp-test.cpp
@@ -1,6 +1,7 @@
// Copyright Epic Games, Inc. All Rights Reserved.
#include <zencore/testing.h>
+#include <zenhttp/httpclient.h>
#include <zenhttp/zenhttp.h>
#include <zencore/memory/newdelete.h>
@@ -9,6 +10,17 @@ int
main([[maybe_unused]] int argc, [[maybe_unused]] char* argv[])
{
#if ZEN_WITH_TESTS
+ using namespace std::literals;
+ for (int i = 1; i < argc; ++i)
+ {
+ std::string_view Arg(argv[i]);
+ if (Arg.starts_with("--httpclient="sv))
+ {
+ std::string_view Value = Arg.substr(13);
+ zen::SetDefaultHttpClientBackend(Value);
+ }
+ }
+
return zen::testing::RunTestMain(argc, argv, "zenhttp-test", zen::zenhttp_forcelinktests);
#else
return 0;
diff --git a/src/zenhttp/clients/httpclientcommon.cpp b/src/zenhttp/clients/httpclientcommon.cpp
index 6f4c67dd0..e4d11547a 100644
--- a/src/zenhttp/clients/httpclientcommon.cpp
+++ b/src/zenhttp/clients/httpclientcommon.cpp
@@ -646,6 +646,63 @@ TEST_CASE("CompositeBufferReadStream")
CHECK_EQ(IoHash::HashBuffer(Data), testutil::HashComposite(Data));
}
+TEST_CASE("ParseContentRange")
+{
+ SUBCASE("normal range with total size")
+ {
+ auto [Offset, Length] = detail::ParseContentRange("bytes 0-99/500");
+ CHECK_EQ(Offset, 0);
+ CHECK_EQ(Length, 100);
+ }
+
+ SUBCASE("non-zero offset")
+ {
+ auto [Offset, Length] = detail::ParseContentRange("bytes 2638-5111437/44369878");
+ CHECK_EQ(Offset, 2638);
+ CHECK_EQ(Length, 5111437 - 2638 + 1);
+ }
+
+ SUBCASE("wildcard total size")
+ {
+ auto [Offset, Length] = detail::ParseContentRange("bytes 100-199/*");
+ CHECK_EQ(Offset, 100);
+ CHECK_EQ(Length, 100);
+ }
+
+ SUBCASE("no slash (total size omitted)")
+ {
+ auto [Offset, Length] = detail::ParseContentRange("bytes 50-149");
+ CHECK_EQ(Offset, 50);
+ CHECK_EQ(Length, 100);
+ }
+
+ SUBCASE("malformed input returns zeros")
+ {
+ auto [Offset1, Length1] = detail::ParseContentRange("not-bytes 0-99/500");
+ CHECK_EQ(Offset1, 0);
+ CHECK_EQ(Length1, 0);
+
+ auto [Offset2, Length2] = detail::ParseContentRange("bytes abc-def/500");
+ CHECK_EQ(Offset2, 0);
+ CHECK_EQ(Length2, 0);
+
+ auto [Offset3, Length3] = detail::ParseContentRange("");
+ CHECK_EQ(Offset3, 0);
+ CHECK_EQ(Length3, 0);
+
+ auto [Offset4, Length4] = detail::ParseContentRange("bytes 100/500");
+ CHECK_EQ(Offset4, 0);
+ CHECK_EQ(Length4, 0);
+ }
+
+ SUBCASE("single byte range")
+ {
+ auto [Offset, Length] = detail::ParseContentRange("bytes 42-42/1000");
+ CHECK_EQ(Offset, 42);
+ CHECK_EQ(Length, 1);
+ }
+}
+
TEST_CASE("MultipartBoundaryParser")
{
uint64_t Range1Offset = 2638;
diff --git a/src/zenhttp/clients/httpclientcpr.cpp b/src/zenhttp/clients/httpclientcpr.cpp
index 14e40b02a..f3082e0a2 100644
--- a/src/zenhttp/clients/httpclientcpr.cpp
+++ b/src/zenhttp/clients/httpclientcpr.cpp
@@ -14,6 +14,11 @@
#include <zenhttp/packageformat.h>
#include <algorithm>
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <cpr/ssl_options.h>
+#include <cpr/unix_socket.h>
+ZEN_THIRD_PARTY_INCLUDES_END
+
namespace zen {
HttpClientBase*
@@ -24,84 +29,42 @@ CreateCprHttpClient(std::string_view BaseUri, const HttpClientSettings& Connecti
static std::atomic<uint32_t> HttpClientRequestIdCounter{0};
-bool
-HttpClient::ErrorContext::IsConnectionError() const
+//////////////////////////////////////////////////////////////////////////
+
+static HttpClientErrorCode
+MapCprError(cpr::ErrorCode Code)
{
- switch (static_cast<cpr::ErrorCode>(ErrorCode))
+ switch (Code)
{
+ case cpr::ErrorCode::OK:
+ return HttpClientErrorCode::kOK;
case cpr::ErrorCode::CONNECTION_FAILURE:
- case cpr::ErrorCode::OPERATION_TIMEDOUT:
+ return HttpClientErrorCode::kConnectionFailure;
case cpr::ErrorCode::HOST_RESOLUTION_FAILURE:
+ return HttpClientErrorCode::kHostResolutionFailure;
case cpr::ErrorCode::PROXY_RESOLUTION_FAILURE:
- return true;
+ return HttpClientErrorCode::kProxyResolutionFailure;
+ case cpr::ErrorCode::INTERNAL_ERROR:
+ return HttpClientErrorCode::kInternalError;
+ case cpr::ErrorCode::NETWORK_RECEIVE_ERROR:
+ return HttpClientErrorCode::kNetworkReceiveError;
+ case cpr::ErrorCode::NETWORK_SEND_FAILURE:
+ return HttpClientErrorCode::kNetworkSendFailure;
+ case cpr::ErrorCode::OPERATION_TIMEDOUT:
+ return HttpClientErrorCode::kOperationTimedOut;
+ case cpr::ErrorCode::SSL_CONNECT_ERROR:
+ return HttpClientErrorCode::kSSLConnectError;
+ case cpr::ErrorCode::SSL_LOCAL_CERTIFICATE_ERROR:
+ case cpr::ErrorCode::SSL_REMOTE_CERTIFICATE_ERROR:
+ return HttpClientErrorCode::kSSLCertificateError;
+ case cpr::ErrorCode::SSL_CACERT_ERROR:
+ return HttpClientErrorCode::kSSLCACertError;
+ case cpr::ErrorCode::GENERIC_SSL_ERROR:
+ return HttpClientErrorCode::kGenericSSLError;
+ case cpr::ErrorCode::REQUEST_CANCELLED:
+ return HttpClientErrorCode::kRequestCancelled;
default:
- return false;
- }
-}
-
-// If we want to support different HTTP client implementations then we'll need to make this more abstract
-
-HttpClientError::ResponseClass
-HttpClientError::GetResponseClass() const
-{
- if ((cpr::ErrorCode)m_Error != cpr::ErrorCode::OK)
- {
- switch ((cpr::ErrorCode)m_Error)
- {
- case cpr::ErrorCode::CONNECTION_FAILURE:
- return ResponseClass::kHttpCantConnectError;
- case cpr::ErrorCode::HOST_RESOLUTION_FAILURE:
- case cpr::ErrorCode::PROXY_RESOLUTION_FAILURE:
- return ResponseClass::kHttpNoHost;
- case cpr::ErrorCode::INTERNAL_ERROR:
- case cpr::ErrorCode::NETWORK_RECEIVE_ERROR:
- case cpr::ErrorCode::NETWORK_SEND_FAILURE:
- case cpr::ErrorCode::OPERATION_TIMEDOUT:
- return ResponseClass::kHttpTimeout;
- case cpr::ErrorCode::SSL_CONNECT_ERROR:
- case cpr::ErrorCode::SSL_LOCAL_CERTIFICATE_ERROR:
- case cpr::ErrorCode::SSL_REMOTE_CERTIFICATE_ERROR:
- case cpr::ErrorCode::SSL_CACERT_ERROR:
- case cpr::ErrorCode::GENERIC_SSL_ERROR:
- return ResponseClass::kHttpSLLError;
- default:
- return ResponseClass::kHttpOtherClientError;
- }
- }
- else if (IsHttpSuccessCode(m_ResponseCode))
- {
- return ResponseClass::kSuccess;
- }
- else
- {
- switch (m_ResponseCode)
- {
- case HttpResponseCode::Unauthorized:
- return ResponseClass::kHttpUnauthorized;
- case HttpResponseCode::NotFound:
- return ResponseClass::kHttpNotFound;
- case HttpResponseCode::Forbidden:
- return ResponseClass::kHttpForbidden;
- case HttpResponseCode::Conflict:
- return ResponseClass::kHttpConflict;
- case HttpResponseCode::InternalServerError:
- return ResponseClass::kHttpInternalServerError;
- case HttpResponseCode::ServiceUnavailable:
- return ResponseClass::kHttpServiceUnavailable;
- case HttpResponseCode::BadGateway:
- return ResponseClass::kHttpBadGateway;
- case HttpResponseCode::GatewayTimeout:
- return ResponseClass::kHttpGatewayTimeout;
- default:
- if (m_ResponseCode >= HttpResponseCode::InternalServerError)
- {
- return ResponseClass::kHttpOtherServerError;
- }
- else
- {
- return ResponseClass::kHttpOtherClientError;
- }
- }
+ return HttpClientErrorCode::kOtherError;
}
}
@@ -257,8 +220,8 @@ CprHttpClient::CommonResponse(std::string_view SessionId,
.UploadedBytes = gsl::narrow<int64_t>(HttpResponse.uploaded_bytes),
.DownloadedBytes = gsl::narrow<int64_t>(HttpResponse.downloaded_bytes),
.ElapsedSeconds = HttpResponse.elapsed,
- .Error = HttpClient::ErrorContext{.ErrorCode = gsl::narrow<int>(HttpResponse.error.code),
- .ErrorMessage = HttpResponse.error.message}};
+ .Error =
+ HttpClient::ErrorContext{.ErrorCode = MapCprError(HttpResponse.error.code), .ErrorMessage = HttpResponse.error.message}};
}
if (WorkResponseCode == HttpResponseCode::NoContent || (HttpResponse.text.empty() && !Payload))
@@ -526,6 +489,10 @@ CprHttpClient::AllocSession(const std::string_view BaseUrl,
{
CprSession->UpdateHeader({{"UE-Session", std::string(SessionId)}});
}
+ if (ConnectionSettings.ForbidReuseConnection)
+ {
+ CprSession->UpdateHeader({{"Connection", "close"}});
+ }
if (AccessToken)
{
CprSession->UpdateHeader({{"Authorization", AccessToken->Value}});
@@ -544,6 +511,26 @@ CprHttpClient::AllocSession(const std::string_view BaseUrl,
CprSession->SetParameters({});
}
+ if (!ConnectionSettings.UnixSocketPath.empty())
+ {
+ CprSession->SetUnixSocket(cpr::UnixSocket(ConnectionSettings.UnixSocketPath));
+ }
+
+ if (ConnectionSettings.InsecureSsl || !ConnectionSettings.CaBundlePath.empty())
+ {
+ cpr::SslOptions SslOpts;
+ if (ConnectionSettings.InsecureSsl)
+ {
+ SslOpts.SetOption(cpr::ssl::VerifyHost{false});
+ SslOpts.SetOption(cpr::ssl::VerifyPeer{false});
+ }
+ if (!ConnectionSettings.CaBundlePath.empty())
+ {
+ SslOpts.SetOption(cpr::ssl::CaInfo{ConnectionSettings.CaBundlePath});
+ }
+ CprSession->SetSslOptions(SslOpts);
+ }
+
ExtendableStringBuilder<128> UrlBuffer;
UrlBuffer << BaseUrl << ResourcePath;
CprSession->SetUrl(UrlBuffer.c_str());
diff --git a/src/zenhttp/clients/httpclientcurl.cpp b/src/zenhttp/clients/httpclientcurl.cpp
new file mode 100644
index 000000000..3cb749018
--- /dev/null
+++ b/src/zenhttp/clients/httpclientcurl.cpp
@@ -0,0 +1,1947 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "httpclientcurl.h"
+
+#include <zencore/compactbinary.h>
+#include <zencore/compactbinarybuilder.h>
+#include <zencore/compactbinarypackage.h>
+#include <zencore/compactbinaryutil.h>
+#include <zencore/compress.h>
+#include <zencore/iobuffer.h>
+#include <zencore/iohash.h>
+#include <zencore/session.h>
+#include <zencore/stream.h>
+#include <zencore/string.h>
+#include <zenhttp/packageformat.h>
+#include <algorithm>
+
+namespace zen {
+
+HttpClientBase*
+CreateCurlHttpClient(std::string_view BaseUri, const HttpClientSettings& ConnectionSettings, std::function<bool()>&& CheckIfAbortFunction)
+{
+ return new CurlHttpClient(BaseUri, ConnectionSettings, std::move(CheckIfAbortFunction));
+}
+
+static std::atomic<uint32_t> CurlHttpClientRequestIdCounter{0};
+
+//////////////////////////////////////////////////////////////////////////
+
+static HttpClientErrorCode
+MapCurlError(CURLcode Code)
+{
+ switch (Code)
+ {
+ case CURLE_OK:
+ return HttpClientErrorCode::kOK;
+ case CURLE_COULDNT_CONNECT:
+ return HttpClientErrorCode::kConnectionFailure;
+ case CURLE_COULDNT_RESOLVE_HOST:
+ return HttpClientErrorCode::kHostResolutionFailure;
+ case CURLE_COULDNT_RESOLVE_PROXY:
+ return HttpClientErrorCode::kProxyResolutionFailure;
+ case CURLE_RECV_ERROR:
+ return HttpClientErrorCode::kNetworkReceiveError;
+ case CURLE_SEND_ERROR:
+ return HttpClientErrorCode::kNetworkSendFailure;
+ case CURLE_OPERATION_TIMEDOUT:
+ return HttpClientErrorCode::kOperationTimedOut;
+ case CURLE_SSL_CONNECT_ERROR:
+ return HttpClientErrorCode::kSSLConnectError;
+ case CURLE_SSL_CERTPROBLEM:
+ return HttpClientErrorCode::kSSLCertificateError;
+ case CURLE_PEER_FAILED_VERIFICATION:
+ return HttpClientErrorCode::kSSLCACertError;
+ case CURLE_SSL_CIPHER:
+ case CURLE_SSL_ENGINE_NOTFOUND:
+ case CURLE_SSL_ENGINE_SETFAILED:
+ return HttpClientErrorCode::kGenericSSLError;
+ case CURLE_ABORTED_BY_CALLBACK:
+ return HttpClientErrorCode::kRequestCancelled;
+ default:
+ return HttpClientErrorCode::kOtherError;
+ }
+}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Curl callback helpers
+
+struct WriteCallbackData
+{
+ std::string* Body = nullptr;
+ std::function<bool()>* CheckIfAbortFunction = nullptr;
+};
+
+static size_t
+CurlWriteCallback(char* Ptr, size_t Size, size_t Nmemb, void* UserData)
+{
+ auto* Data = static_cast<WriteCallbackData*>(UserData);
+ size_t TotalBytes = Size * Nmemb;
+
+ if (Data->CheckIfAbortFunction && *Data->CheckIfAbortFunction && (*Data->CheckIfAbortFunction)())
+ {
+ return 0; // Signal abort to curl
+ }
+
+ Data->Body->append(Ptr, TotalBytes);
+ return TotalBytes;
+}
+
+struct HeaderCallbackData
+{
+ std::vector<std::pair<std::string, std::string>>* Headers = nullptr;
+};
+
+static size_t
+CurlHeaderCallback(char* Buffer, size_t Size, size_t Nmemb, void* UserData)
+{
+ auto* Data = static_cast<HeaderCallbackData*>(UserData);
+ size_t TotalBytes = Size * Nmemb;
+
+ std::string_view Line(Buffer, TotalBytes);
+
+ // Trim trailing \r\n
+ while (!Line.empty() && (Line.back() == '\r' || Line.back() == '\n'))
+ {
+ Line.remove_suffix(1);
+ }
+
+ if (Line.empty())
+ {
+ return TotalBytes;
+ }
+
+ size_t ColonPos = Line.find(':');
+ if (ColonPos != std::string_view::npos)
+ {
+ std::string_view Key = Line.substr(0, ColonPos);
+ std::string_view Value = Line.substr(ColonPos + 1);
+
+ // Trim whitespace
+ while (!Key.empty() && Key.back() == ' ')
+ {
+ Key.remove_suffix(1);
+ }
+ while (!Value.empty() && Value.front() == ' ')
+ {
+ Value.remove_prefix(1);
+ }
+
+ Data->Headers->emplace_back(std::string(Key), std::string(Value));
+ }
+
+ return TotalBytes;
+}
+
+struct ReadCallbackData
+{
+ const uint8_t* DataPtr = nullptr;
+ size_t DataSize = 0;
+ size_t Offset = 0;
+ std::function<bool()>* CheckIfAbortFunction = nullptr;
+};
+
+static size_t
+CurlReadCallback(char* Buffer, size_t Size, size_t Nmemb, void* UserData)
+{
+ auto* Data = static_cast<ReadCallbackData*>(UserData);
+ size_t MaxRead = Size * Nmemb;
+
+ if (Data->CheckIfAbortFunction && *Data->CheckIfAbortFunction && (*Data->CheckIfAbortFunction)())
+ {
+ return CURL_READFUNC_ABORT;
+ }
+
+ size_t Remaining = Data->DataSize - Data->Offset;
+ size_t ToRead = std::min(MaxRead, Remaining);
+
+ if (ToRead > 0)
+ {
+ memcpy(Buffer, Data->DataPtr + Data->Offset, ToRead);
+ Data->Offset += ToRead;
+ }
+
+ return ToRead;
+}
+
+struct StreamReadCallbackData
+{
+ detail::CompositeBufferReadStream* Reader = nullptr;
+ std::function<bool()>* CheckIfAbortFunction = nullptr;
+};
+
+static size_t
+CurlStreamReadCallback(char* Buffer, size_t Size, size_t Nmemb, void* UserData)
+{
+ auto* Data = static_cast<StreamReadCallbackData*>(UserData);
+ size_t MaxRead = Size * Nmemb;
+
+ if (Data->CheckIfAbortFunction && *Data->CheckIfAbortFunction && (*Data->CheckIfAbortFunction)())
+ {
+ return CURL_READFUNC_ABORT;
+ }
+
+ return Data->Reader->Read(Buffer, MaxRead);
+}
+
+struct FileReadCallbackData
+{
+ detail::BufferedReadFileStream* Buffer = nullptr;
+ uint64_t TotalSize = 0;
+ uint64_t Offset = 0;
+ std::function<bool()>* CheckIfAbortFunction = nullptr;
+};
+
+static size_t
+CurlFileReadCallback(char* Buffer, size_t Size, size_t Nmemb, void* UserData)
+{
+ auto* Data = static_cast<FileReadCallbackData*>(UserData);
+ size_t MaxRead = Size * Nmemb;
+
+ if (Data->CheckIfAbortFunction && *Data->CheckIfAbortFunction && (*Data->CheckIfAbortFunction)())
+ {
+ return CURL_READFUNC_ABORT;
+ }
+
+ size_t Remaining = Data->TotalSize - Data->Offset;
+ size_t ToRead = std::min(MaxRead, Remaining);
+
+ if (ToRead > 0)
+ {
+ Data->Buffer->Read(Buffer, ToRead);
+ Data->Offset += ToRead;
+ }
+
+ return ToRead;
+}
+
+static int
+CurlDebugCallback(CURL* Handle, curl_infotype Type, char* Data, size_t Size, void* UserPtr)
+{
+ ZEN_UNUSED(Handle);
+ LoggerRef LogRef = *static_cast<LoggerRef*>(UserPtr);
+ auto Log = [&]() -> LoggerRef { return LogRef; };
+
+ std::string_view DataView(Data, Size);
+
+ // Trim trailing newlines
+ while (!DataView.empty() && (DataView.back() == '\r' || DataView.back() == '\n'))
+ {
+ DataView.remove_suffix(1);
+ }
+
+ switch (Type)
+ {
+ case CURLINFO_TEXT:
+ if (DataView.find("need more data"sv) == std::string_view::npos)
+ {
+ ZEN_INFO("TEXT: {}", DataView);
+ }
+ break;
+ case CURLINFO_HEADER_IN:
+ ZEN_INFO("HIN : {}", DataView);
+ break;
+ case CURLINFO_HEADER_OUT:
+ if (auto TokenPos = DataView.find("Authorization: Bearer "sv); TokenPos != std::string_view::npos)
+ {
+ std::string Copy(DataView);
+ auto BearerStart = TokenPos + 22;
+ auto BearerEnd = Copy.find_first_of("\r\n", BearerStart);
+ if (BearerEnd == std::string::npos)
+ {
+ BearerEnd = Copy.length();
+ }
+ Copy.replace(Copy.begin() + BearerStart, Copy.begin() + BearerEnd, fmt::format("[{} char token]", BearerEnd - BearerStart));
+ ZEN_INFO("HOUT: {}", Copy);
+ }
+ else
+ {
+ ZEN_INFO("HOUT: {}", DataView);
+ }
+ break;
+ default:
+ break;
+ }
+
+ return 0;
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+static std::pair<std::string, std::string>
+HeaderContentType(ZenContentType ContentType)
+{
+ return std::make_pair("Content-Type", std::string(MapContentTypeToString(ContentType)));
+}
+
+static curl_slist*
+BuildHeaderList(const HttpClient::KeyValueMap& AdditionalHeader,
+ std::string_view SessionId,
+ const std::optional<HttpClientAccessToken>& AccessToken,
+ const std::vector<std::pair<std::string, std::string>>& ExtraHeaders = {})
+{
+ curl_slist* Headers = nullptr;
+
+ for (const auto& [Key, Value] : *AdditionalHeader)
+ {
+ std::string HeaderLine = fmt::format("{}: {}", Key, Value);
+ Headers = curl_slist_append(Headers, HeaderLine.c_str());
+ }
+
+ if (!SessionId.empty())
+ {
+ std::string SessionHeader = fmt::format("UE-Session: {}", SessionId);
+ Headers = curl_slist_append(Headers, SessionHeader.c_str());
+ }
+
+ if (AccessToken)
+ {
+ std::string AuthHeader = fmt::format("Authorization: {}", AccessToken->Value);
+ Headers = curl_slist_append(Headers, AuthHeader.c_str());
+ }
+
+ for (const auto& [Key, Value] : ExtraHeaders)
+ {
+ std::string HeaderLine = fmt::format("{}: {}", Key, Value);
+ Headers = curl_slist_append(Headers, HeaderLine.c_str());
+ }
+
+ return Headers;
+}
+
+static std::string
+BuildUrlWithParameters(std::string_view BaseUrl, std::string_view ResourcePath, const HttpClient::KeyValueMap& Parameters)
+{
+ std::string Url;
+ Url.reserve(BaseUrl.size() + ResourcePath.size() + 64);
+ Url.append(BaseUrl);
+ Url.append(ResourcePath);
+
+ if (!Parameters->empty())
+ {
+ char Separator = '?';
+ for (const auto& [Key, Value] : *Parameters)
+ {
+ char* EncodedKey = curl_easy_escape(nullptr, Key.c_str(), static_cast<int>(Key.size()));
+ char* EncodedValue = curl_easy_escape(nullptr, Value.c_str(), static_cast<int>(Value.size()));
+ Url += Separator;
+ Url += EncodedKey;
+ Url += '=';
+ Url += EncodedValue;
+ curl_free(EncodedKey);
+ curl_free(EncodedValue);
+ Separator = '&';
+ }
+ }
+
+ return Url;
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+CurlHttpClient::CurlHttpClient(std::string_view BaseUri,
+ const HttpClientSettings& ConnectionSettings,
+ std::function<bool()>&& CheckIfAbortFunction)
+: HttpClientBase(BaseUri, ConnectionSettings, std::move(CheckIfAbortFunction))
+{
+}
+
+CurlHttpClient::~CurlHttpClient()
+{
+ ZEN_TRACE_CPU("CurlHttpClient::~CurlHttpClient");
+ m_SessionLock.WithExclusiveLock([&] {
+ for (auto* Handle : m_Sessions)
+ {
+ curl_easy_cleanup(Handle);
+ }
+ m_Sessions.clear();
+ });
+}
+
+CurlHttpClient::CurlResult
+CurlHttpClient::Session::Perform()
+{
+ CurlResult Result;
+
+ char ErrorBuffer[CURL_ERROR_SIZE] = {};
+ curl_easy_setopt(Handle, CURLOPT_ERRORBUFFER, ErrorBuffer);
+
+ Result.ErrorCode = curl_easy_perform(Handle);
+
+ if (Result.ErrorCode != CURLE_OK)
+ {
+ Result.ErrorMessage = ErrorBuffer[0] ? std::string(ErrorBuffer) : curl_easy_strerror(Result.ErrorCode);
+ }
+
+ curl_easy_getinfo(Handle, CURLINFO_RESPONSE_CODE, &Result.StatusCode);
+
+ double Elapsed = 0;
+ curl_easy_getinfo(Handle, CURLINFO_TOTAL_TIME, &Elapsed);
+ Result.ElapsedSeconds = Elapsed;
+
+ curl_off_t UpBytes = 0;
+ curl_easy_getinfo(Handle, CURLINFO_SIZE_UPLOAD_T, &UpBytes);
+ Result.UploadedBytes = static_cast<int64_t>(UpBytes);
+
+ curl_off_t DownBytes = 0;
+ curl_easy_getinfo(Handle, CURLINFO_SIZE_DOWNLOAD_T, &DownBytes);
+ Result.DownloadedBytes = static_cast<int64_t>(DownBytes);
+
+ return Result;
+}
+
+bool
+CurlHttpClient::ShouldLogErrorCode(HttpResponseCode ResponseCode) const
+{
+ if (m_CheckIfAbortFunction && m_CheckIfAbortFunction())
+ {
+ return false;
+ }
+ const auto& Expected = m_ConnectionSettings.ExpectedErrorCodes;
+ return std::find(Expected.begin(), Expected.end(), ResponseCode) == Expected.end();
+}
+
+HttpClient::Response
+CurlHttpClient::ResponseWithPayload(std::string_view SessionId,
+ CurlResult&& Result,
+ const HttpResponseCode WorkResponseCode,
+ IoBuffer&& Payload,
+ std::vector<HttpClient::Response::MultipartBoundary>&& BoundaryPositions)
+{
+ IoBuffer ResponseBuffer = Payload ? std::move(Payload) : IoBuffer(IoBuffer::Clone, Result.Body.data(), Result.Body.size());
+
+ for (const auto& [Key, Value] : Result.Headers)
+ {
+ if (Key == "Content-Type")
+ {
+ const HttpContentType ContentType = ParseContentType(Value);
+ ResponseBuffer.SetContentType(ContentType);
+ break;
+ }
+ }
+
+ if (!IsHttpSuccessCode(WorkResponseCode) && WorkResponseCode != HttpResponseCode::NotFound)
+ {
+ if (ShouldLogErrorCode(WorkResponseCode))
+ {
+ ZEN_WARN("HttpClient request failed (session: {}): status={}, url={}",
+ SessionId,
+ static_cast<int>(WorkResponseCode),
+ m_BaseUri);
+ }
+ }
+
+ std::sort(BoundaryPositions.begin(),
+ BoundaryPositions.end(),
+ [](const HttpClient::Response::MultipartBoundary& Lhs, const HttpClient::Response::MultipartBoundary& Rhs) {
+ return Lhs.RangeOffset < Rhs.RangeOffset;
+ });
+
+ HttpClient::KeyValueMap HeaderMap;
+ for (const auto& [Key, Value] : Result.Headers)
+ {
+ HeaderMap->insert_or_assign(Key, Value);
+ }
+
+ return HttpClient::Response{.StatusCode = WorkResponseCode,
+ .ResponsePayload = std::move(ResponseBuffer),
+ .Header = std::move(HeaderMap),
+ .UploadedBytes = Result.UploadedBytes,
+ .DownloadedBytes = Result.DownloadedBytes,
+ .ElapsedSeconds = Result.ElapsedSeconds,
+ .Ranges = std::move(BoundaryPositions)};
+}
+
+HttpClient::Response
+CurlHttpClient::CommonResponse(std::string_view SessionId,
+ CurlResult&& Result,
+ IoBuffer&& Payload,
+ std::vector<HttpClient::Response::MultipartBoundary>&& BoundaryPositions)
+{
+ const HttpResponseCode WorkResponseCode = HttpResponseCode(Result.StatusCode);
+ if (Result.ErrorCode != CURLE_OK)
+ {
+ const bool Quiet = m_CheckIfAbortFunction && m_CheckIfAbortFunction();
+ if (!Quiet)
+ {
+ if (Result.ErrorCode != CURLE_OPERATION_TIMEDOUT && Result.ErrorCode != CURLE_COULDNT_CONNECT &&
+ Result.ErrorCode != CURLE_ABORTED_BY_CALLBACK)
+ {
+ ZEN_WARN("HttpClient client failure (session: {}): ({}) '{}'",
+ SessionId,
+ static_cast<int>(Result.ErrorCode),
+ Result.ErrorMessage);
+ }
+ }
+
+ HttpClient::KeyValueMap HeaderMap;
+ for (const auto& [Key, Value] : Result.Headers)
+ {
+ HeaderMap->insert_or_assign(Key, Value);
+ }
+
+ return HttpClient::Response{
+ .StatusCode = WorkResponseCode,
+ .ResponsePayload = IoBufferBuilder::MakeCloneFromMemory(Result.Body.data(), Result.Body.size()),
+ .Header = std::move(HeaderMap),
+ .UploadedBytes = Result.UploadedBytes,
+ .DownloadedBytes = Result.DownloadedBytes,
+ .ElapsedSeconds = Result.ElapsedSeconds,
+ .Error = HttpClient::ErrorContext{.ErrorCode = MapCurlError(Result.ErrorCode), .ErrorMessage = Result.ErrorMessage}};
+ }
+
+ if (WorkResponseCode == HttpResponseCode::NoContent || (Result.Body.empty() && !Payload))
+ {
+ HttpClient::KeyValueMap HeaderMap;
+ for (const auto& [Key, Value] : Result.Headers)
+ {
+ HeaderMap->insert_or_assign(Key, Value);
+ }
+
+ return HttpClient::Response{.StatusCode = WorkResponseCode,
+ .Header = std::move(HeaderMap),
+ .UploadedBytes = Result.UploadedBytes,
+ .DownloadedBytes = Result.DownloadedBytes,
+ .ElapsedSeconds = Result.ElapsedSeconds};
+ }
+ else
+ {
+ return ResponseWithPayload(SessionId, std::move(Result), WorkResponseCode, std::move(Payload), std::move(BoundaryPositions));
+ }
+}
+
+bool
+CurlHttpClient::ValidatePayload(CurlResult& Result, std::unique_ptr<detail::TempPayloadFile>& PayloadFile)
+{
+ ZEN_TRACE_CPU("ValidatePayload");
+
+ IoBuffer ResponseBuffer = (Result.Body.empty() && PayloadFile) ? PayloadFile->BorrowIoBuffer()
+ : IoBuffer(IoBuffer::Wrap, Result.Body.data(), Result.Body.size());
+
+ // Find Content-Length in headers
+ for (const auto& [Key, Value] : Result.Headers)
+ {
+ if (Key == "Content-Length")
+ {
+ std::optional<uint64_t> ExpectedContentSize = ParseInt<uint64_t>(Value);
+ if (!ExpectedContentSize.has_value())
+ {
+ Result.ErrorCode = CURLE_RECV_ERROR;
+ Result.ErrorMessage = fmt::format("Can not parse Content-Length header. Value: '{}'", Value);
+ return false;
+ }
+ if (ExpectedContentSize.value() != ResponseBuffer.GetSize())
+ {
+ Result.ErrorCode = CURLE_RECV_ERROR;
+ Result.ErrorMessage = fmt::format("Payload size {} does not match Content-Length {}", ResponseBuffer.GetSize(), Value);
+ return false;
+ }
+ break;
+ }
+ }
+
+ if (Result.StatusCode == static_cast<long>(HttpResponseCode::PartialContent))
+ {
+ return true;
+ }
+
+ // Check X-Jupiter-IoHash
+ for (const auto& [Key, Value] : Result.Headers)
+ {
+ if (Key == "X-Jupiter-IoHash")
+ {
+ IoHash ExpectedPayloadHash;
+ if (IoHash::TryParse(Value, ExpectedPayloadHash))
+ {
+ IoHash PayloadHash = IoHash::HashBuffer(ResponseBuffer);
+ if (PayloadHash != ExpectedPayloadHash)
+ {
+ Result.ErrorCode = CURLE_RECV_ERROR;
+ Result.ErrorMessage = fmt::format("Payload hash {} does not match X-Jupiter-IoHash {}",
+ PayloadHash.ToHexString(),
+ ExpectedPayloadHash.ToHexString());
+ return false;
+ }
+ }
+ break;
+ }
+ }
+
+ // Validate content-type specific payload
+ for (const auto& [Key, Value] : Result.Headers)
+ {
+ if (Key == "Content-Type")
+ {
+ if (Value == "application/x-ue-comp")
+ {
+ IoHash RawHash;
+ uint64_t RawSize;
+ if (CompressedBuffer::ValidateCompressedHeader(ResponseBuffer,
+ RawHash,
+ RawSize,
+ /*OutOptionalTotalCompressedSize*/ nullptr))
+ {
+ return true;
+ }
+ else
+ {
+ Result.ErrorCode = CURLE_RECV_ERROR;
+ Result.ErrorMessage = "Compressed binary failed validation";
+ return false;
+ }
+ }
+ if (Value == "application/x-ue-cb")
+ {
+ if (CbValidateError Error = ValidateCompactBinary(ResponseBuffer.GetView(), CbValidateMode::Default);
+ Error == CbValidateError::None)
+ {
+ return true;
+ }
+ else
+ {
+ Result.ErrorCode = CURLE_RECV_ERROR;
+ Result.ErrorMessage = fmt::format("Compact binary failed validation: {}", ToString(Error));
+ return false;
+ }
+ }
+ break;
+ }
+ }
+
+ return true;
+}
+
+bool
+CurlHttpClient::ShouldRetry(const CurlResult& Result)
+{
+ switch (Result.ErrorCode)
+ {
+ case CURLE_OK:
+ break;
+ case CURLE_RECV_ERROR:
+ case CURLE_SEND_ERROR:
+ case CURLE_OPERATION_TIMEDOUT:
+ return true;
+ default:
+ return false;
+ }
+ switch (static_cast<HttpResponseCode>(Result.StatusCode))
+ {
+ case HttpResponseCode::RequestTimeout:
+ case HttpResponseCode::TooManyRequests:
+ case HttpResponseCode::InternalServerError:
+ case HttpResponseCode::BadGateway:
+ case HttpResponseCode::ServiceUnavailable:
+ case HttpResponseCode::GatewayTimeout:
+ return true;
+ default:
+ return false;
+ }
+}
+
+CurlHttpClient::CurlResult
+CurlHttpClient::DoWithRetry(std::string_view SessionId, std::function<CurlResult()>&& Func, std::function<bool(CurlResult&)>&& Validate)
+{
+ uint8_t Attempt = 0;
+ CurlResult Result = Func();
+ while (Attempt < m_ConnectionSettings.RetryCount)
+ {
+ if (m_CheckIfAbortFunction && m_CheckIfAbortFunction())
+ {
+ return Result;
+ }
+ if (!ShouldRetry(Result))
+ {
+ if (Result.ErrorCode != CURLE_OK || !IsHttpSuccessCode(Result.StatusCode))
+ {
+ break;
+ }
+ if (Validate(Result))
+ {
+ break;
+ }
+ }
+ Sleep(100 * (Attempt + 1));
+ Attempt++;
+ if (ShouldLogErrorCode(HttpResponseCode(Result.StatusCode)))
+ {
+ ZEN_INFO("{} Attempt {}/{}",
+ CommonResponse(SessionId, std::move(Result), {}).ErrorMessage("Retry"),
+ Attempt,
+ m_ConnectionSettings.RetryCount + 1);
+ }
+ Result = Func();
+ }
+ return Result;
+}
+
+CurlHttpClient::CurlResult
+CurlHttpClient::DoWithRetry(std::string_view SessionId,
+ std::function<CurlResult()>&& Func,
+ std::unique_ptr<detail::TempPayloadFile>& PayloadFile)
+{
+ uint8_t Attempt = 0;
+ CurlResult Result = Func();
+ while (Attempt < m_ConnectionSettings.RetryCount)
+ {
+ if (m_CheckIfAbortFunction && m_CheckIfAbortFunction())
+ {
+ return Result;
+ }
+ if (!ShouldRetry(Result))
+ {
+ if (Result.ErrorCode != CURLE_OK || !IsHttpSuccessCode(Result.StatusCode))
+ {
+ break;
+ }
+ if (ValidatePayload(Result, PayloadFile))
+ {
+ break;
+ }
+ }
+ Sleep(100 * (Attempt + 1));
+ Attempt++;
+ if (ShouldLogErrorCode(HttpResponseCode(Result.StatusCode)))
+ {
+ ZEN_INFO("{} Attempt {}/{}",
+ CommonResponse(SessionId, std::move(Result), {}).ErrorMessage("Retry"),
+ Attempt,
+ m_ConnectionSettings.RetryCount + 1);
+ }
+ Result = Func();
+ }
+ return Result;
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+CurlHttpClient::Session
+CurlHttpClient::AllocSession(std::string_view BaseUrl,
+ std::string_view ResourcePath,
+ const HttpClientSettings& ConnectionSettings,
+ const KeyValueMap& AdditionalHeader,
+ const KeyValueMap& Parameters,
+ std::string_view SessionId,
+ std::optional<HttpClientAccessToken> AccessToken)
+{
+ ZEN_UNUSED(AccessToken, SessionId, AdditionalHeader);
+ ZEN_TRACE_CPU("CurlHttpClient::AllocSession");
+ CURL* Handle = nullptr;
+ m_SessionLock.WithExclusiveLock([&] {
+ if (!m_Sessions.empty())
+ {
+ Handle = m_Sessions.back();
+ m_Sessions.pop_back();
+ }
+ });
+
+ if (Handle == nullptr)
+ {
+ Handle = curl_easy_init();
+ }
+ else
+ {
+ curl_easy_reset(Handle);
+ }
+
+ // Unix domain socket
+ if (!ConnectionSettings.UnixSocketPath.empty())
+ {
+ curl_easy_setopt(Handle, CURLOPT_UNIX_SOCKET_PATH, ConnectionSettings.UnixSocketPath.c_str());
+ }
+
+ // Build URL with parameters
+ std::string Url = BuildUrlWithParameters(BaseUrl, ResourcePath, Parameters);
+ curl_easy_setopt(Handle, CURLOPT_URL, Url.c_str());
+
+ // Timeouts
+ if (ConnectionSettings.ConnectTimeout.count() > 0)
+ {
+ curl_easy_setopt(Handle, CURLOPT_CONNECTTIMEOUT_MS, static_cast<long>(ConnectionSettings.ConnectTimeout.count()));
+ }
+ if (ConnectionSettings.Timeout.count() > 0)
+ {
+ curl_easy_setopt(Handle, CURLOPT_TIMEOUT_MS, static_cast<long>(ConnectionSettings.Timeout.count()));
+ }
+
+ // HTTP/2
+ if (ConnectionSettings.AssumeHttp2)
+ {
+ curl_easy_setopt(Handle, CURLOPT_HTTP_VERSION, CURL_HTTP_VERSION_2_PRIOR_KNOWLEDGE);
+ }
+
+ // Verbose/debug
+ if (ConnectionSettings.Verbose)
+ {
+ curl_easy_setopt(Handle, CURLOPT_VERBOSE, 1L);
+ curl_easy_setopt(Handle, CURLOPT_DEBUGFUNCTION, CurlDebugCallback);
+ curl_easy_setopt(Handle, CURLOPT_DEBUGDATA, &m_Log);
+ }
+
+ // SSL options
+ if (ConnectionSettings.InsecureSsl)
+ {
+ curl_easy_setopt(Handle, CURLOPT_SSL_VERIFYPEER, 0L);
+ curl_easy_setopt(Handle, CURLOPT_SSL_VERIFYHOST, 0L);
+ }
+ if (!ConnectionSettings.CaBundlePath.empty())
+ {
+ curl_easy_setopt(Handle, CURLOPT_CAINFO, ConnectionSettings.CaBundlePath.c_str());
+ }
+
+ // Disable signal handling for thread safety
+ curl_easy_setopt(Handle, CURLOPT_NOSIGNAL, 1L);
+
+ if (ConnectionSettings.ForbidReuseConnection)
+ {
+ curl_easy_setopt(Handle, CURLOPT_FORBID_REUSE, 1L);
+ }
+
+ // Note: Headers are NOT set here. Each method builds its own header list
+ // (potentially adding method-specific headers like Content-Type) and is
+ // responsible for freeing it with curl_slist_free_all.
+
+ return Session(this, Handle);
+}
+
+void
+CurlHttpClient::ReleaseSession(CURL* Handle)
+{
+ ZEN_TRACE_CPU("CurlHttpClient::ReleaseSession");
+
+ // Free any header list that was set
+ // curl_easy_reset will be called on next AllocSession, which cleans up the handle state.
+ // We just push the handle back to the pool.
+ m_SessionLock.WithExclusiveLock([&] { m_Sessions.push_back(Handle); });
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+CurlHttpClient::Response
+CurlHttpClient::TransactPackage(std::string_view Url, CbPackage Package, const KeyValueMap& AdditionalHeader)
+{
+ ZEN_TRACE_CPU("CurlHttpClient::TransactPackage");
+
+ // First, list of offered chunks for filtering on the server end
+
+ std::vector<IoHash> AttachmentsToSend;
+ std::span<const CbAttachment> Attachments = Package.GetAttachments();
+
+ const uint32_t RequestId = ++CurlHttpClientRequestIdCounter;
+ auto RequestIdString = fmt::to_string(RequestId);
+
+ if (Attachments.empty() == false)
+ {
+ CbObjectWriter Writer;
+ Writer.BeginArray("offer");
+
+ for (const CbAttachment& Attachment : Attachments)
+ {
+ Writer.AddHash(Attachment.GetHash());
+ }
+
+ Writer.EndArray();
+
+ BinaryWriter MemWriter;
+ Writer.Save(MemWriter);
+
+ std::vector<std::pair<std::string, std::string>> OfferExtraHeaders;
+ OfferExtraHeaders.emplace_back(HeaderContentType(HttpContentType::kCbPackageOffer));
+ OfferExtraHeaders.emplace_back("UE-Request", RequestIdString);
+
+ Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken());
+ CURL* H = Sess.Get();
+
+ curl_slist* HeaderList = BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), OfferExtraHeaders);
+ curl_easy_setopt(H, CURLOPT_HTTPHEADER, HeaderList);
+ curl_easy_setopt(H, CURLOPT_POST, 1L);
+ curl_easy_setopt(H, CURLOPT_POSTFIELDS, reinterpret_cast<const char*>(MemWriter.Data()));
+ curl_easy_setopt(H, CURLOPT_POSTFIELDSIZE_LARGE, static_cast<curl_off_t>(MemWriter.Size()));
+
+ std::string FilterBody;
+ WriteCallbackData WriteData{.Body = &FilterBody};
+ curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, CurlWriteCallback);
+ curl_easy_setopt(H, CURLOPT_WRITEDATA, &WriteData);
+
+ CurlResult Result = Sess.Perform();
+
+ curl_slist_free_all(HeaderList);
+
+ if (Result.ErrorCode == CURLE_OK && Result.StatusCode == 200)
+ {
+ IoBuffer ResponseBuffer(IoBuffer::Wrap, FilterBody.data(), FilterBody.size());
+ CbValidateError ValidationError = CbValidateError::None;
+ if (CbObject ResponseObject = ValidateAndReadCompactBinaryObject(std::move(ResponseBuffer), ValidationError);
+ ValidationError == CbValidateError::None)
+ {
+ for (CbFieldView& Entry : ResponseObject["need"])
+ {
+ ZEN_ASSERT(Entry.IsHash());
+ AttachmentsToSend.push_back(Entry.AsHash());
+ }
+ }
+ }
+ }
+
+ // Prepare package for send
+
+ CbPackage SendPackage;
+ SendPackage.SetObject(Package.GetObject(), Package.GetObjectHash());
+
+ for (const IoHash& AttachmentCid : AttachmentsToSend)
+ {
+ const CbAttachment* Attachment = Package.FindAttachment(AttachmentCid);
+
+ if (Attachment)
+ {
+ SendPackage.AddAttachment(*Attachment);
+ }
+ }
+
+ // Transmit package payload
+
+ CompositeBuffer Message = FormatPackageMessageBuffer(SendPackage);
+ SharedBuffer FlatMessage = Message.Flatten();
+
+ std::vector<std::pair<std::string, std::string>> PkgExtraHeaders;
+ PkgExtraHeaders.emplace_back(HeaderContentType(HttpContentType::kCbPackage));
+ PkgExtraHeaders.emplace_back("UE-Request", RequestIdString);
+
+ Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken());
+ CURL* H = Sess.Get();
+
+ curl_slist* HeaderList = BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), PkgExtraHeaders);
+ curl_easy_setopt(H, CURLOPT_HTTPHEADER, HeaderList);
+ curl_easy_setopt(H, CURLOPT_POST, 1L);
+ curl_easy_setopt(H, CURLOPT_POSTFIELDS, reinterpret_cast<const char*>(FlatMessage.GetData()));
+ curl_easy_setopt(H, CURLOPT_POSTFIELDSIZE_LARGE, static_cast<curl_off_t>(FlatMessage.GetSize()));
+
+ std::string PkgBody;
+ WriteCallbackData WriteData{.Body = &PkgBody};
+ curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, CurlWriteCallback);
+ curl_easy_setopt(H, CURLOPT_WRITEDATA, &WriteData);
+
+ CurlResult Result = Sess.Perform();
+
+ curl_slist_free_all(HeaderList);
+
+ if (Result.ErrorCode != CURLE_OK || !IsHttpSuccessCode(Result.StatusCode))
+ {
+ return {.StatusCode = HttpResponseCode(Result.StatusCode)};
+ }
+
+ IoBuffer ResponseBuffer(IoBuffer::Clone, PkgBody.data(), PkgBody.size());
+
+ return {.StatusCode = HttpResponseCode(Result.StatusCode), .ResponsePayload = ResponseBuffer};
+}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Standard HTTP verbs
+//
+
+CurlHttpClient::Response
+CurlHttpClient::Put(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader)
+{
+ ZEN_TRACE_CPU("CurlHttpClient::Put");
+
+ return CommonResponse(
+ m_SessionId,
+ DoWithRetry(m_SessionId,
+ [&]() -> CurlResult {
+ Session Sess =
+ AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken());
+ CURL* H = Sess.Get();
+
+ curl_slist* Headers =
+ BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), {HeaderContentType(Payload.GetContentType())});
+ curl_easy_setopt(H, CURLOPT_HTTPHEADER, Headers);
+
+ curl_easy_setopt(H, CURLOPT_UPLOAD, 1L);
+ curl_easy_setopt(H, CURLOPT_INFILESIZE_LARGE, static_cast<curl_off_t>(Payload.GetSize()));
+
+ ReadCallbackData ReadData{.DataPtr = static_cast<const uint8_t*>(Payload.GetData()),
+ .DataSize = Payload.GetSize(),
+ .CheckIfAbortFunction = m_CheckIfAbortFunction ? &m_CheckIfAbortFunction : nullptr};
+ curl_easy_setopt(H, CURLOPT_READFUNCTION, CurlReadCallback);
+ curl_easy_setopt(H, CURLOPT_READDATA, &ReadData);
+
+ std::string Body;
+ WriteCallbackData WriteData{.Body = &Body};
+ HeaderCallbackData HdrData{};
+ std::vector<std::pair<std::string, std::string>> ResponseHeaders;
+ HdrData.Headers = &ResponseHeaders;
+
+ curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, CurlWriteCallback);
+ curl_easy_setopt(H, CURLOPT_WRITEDATA, &WriteData);
+ curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, CurlHeaderCallback);
+ curl_easy_setopt(H, CURLOPT_HEADERDATA, &HdrData);
+
+ CurlResult Result = Sess.Perform();
+ Result.Body = std::move(Body);
+ Result.Headers = std::move(ResponseHeaders);
+
+ curl_slist_free_all(Headers);
+
+ return Result;
+ }),
+ {});
+}
+
+CurlHttpClient::Response
+CurlHttpClient::Put(std::string_view Url, const KeyValueMap& Parameters)
+{
+ ZEN_TRACE_CPU("CurlHttpClient::Put");
+
+ return CommonResponse(
+ m_SessionId,
+ DoWithRetry(
+ m_SessionId,
+ [&]() -> CurlResult {
+ KeyValueMap HeaderWithContentLength{std::pair<std::string_view, std::string_view>{"Content-Length", "0"}};
+ Session Sess =
+ AllocSession(m_BaseUri, Url, m_ConnectionSettings, HeaderWithContentLength, Parameters, m_SessionId, GetAccessToken());
+ CURL* H = Sess.Get();
+
+ curl_slist* Headers = BuildHeaderList(HeaderWithContentLength, m_SessionId, GetAccessToken());
+ curl_easy_setopt(H, CURLOPT_HTTPHEADER, Headers);
+
+ curl_easy_setopt(H, CURLOPT_UPLOAD, 1L);
+ curl_easy_setopt(H, CURLOPT_INFILESIZE_LARGE, 0LL);
+
+ std::string Body;
+ WriteCallbackData WriteData{.Body = &Body};
+ HeaderCallbackData HdrData{};
+ std::vector<std::pair<std::string, std::string>> ResponseHeaders;
+ HdrData.Headers = &ResponseHeaders;
+
+ curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, CurlWriteCallback);
+ curl_easy_setopt(H, CURLOPT_WRITEDATA, &WriteData);
+ curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, CurlHeaderCallback);
+ curl_easy_setopt(H, CURLOPT_HEADERDATA, &HdrData);
+
+ CurlResult Result = Sess.Perform();
+ Result.Body = std::move(Body);
+ Result.Headers = std::move(ResponseHeaders);
+
+ curl_slist_free_all(Headers);
+
+ return Result;
+ }),
+ {});
+}
+
+CurlHttpClient::Response
+CurlHttpClient::Get(std::string_view Url, const KeyValueMap& AdditionalHeader, const KeyValueMap& Parameters)
+{
+ ZEN_TRACE_CPU("CurlHttpClient::Get");
+ return CommonResponse(
+ m_SessionId,
+ DoWithRetry(
+ m_SessionId,
+ [&]() -> CurlResult {
+ Session Sess =
+ AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, Parameters, m_SessionId, GetAccessToken());
+ CURL* H = Sess.Get();
+
+ curl_slist* Headers = BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken());
+ curl_easy_setopt(H, CURLOPT_HTTPHEADER, Headers);
+ curl_easy_setopt(H, CURLOPT_HTTPGET, 1L);
+
+ std::string Body;
+ WriteCallbackData WriteData{.Body = &Body};
+ HeaderCallbackData HdrData{};
+ std::vector<std::pair<std::string, std::string>> ResponseHeaders;
+ HdrData.Headers = &ResponseHeaders;
+
+ curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, CurlWriteCallback);
+ curl_easy_setopt(H, CURLOPT_WRITEDATA, &WriteData);
+ curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, CurlHeaderCallback);
+ curl_easy_setopt(H, CURLOPT_HEADERDATA, &HdrData);
+
+ CurlResult Result = Sess.Perform();
+ Result.Body = std::move(Body);
+ Result.Headers = std::move(ResponseHeaders);
+
+ curl_slist_free_all(Headers);
+
+ return Result;
+ },
+ [this](CurlResult& Result) {
+ std::unique_ptr<detail::TempPayloadFile> NoTempFile;
+ return ValidatePayload(Result, NoTempFile);
+ }),
+ {});
+}
+
+CurlHttpClient::Response
+CurlHttpClient::Head(std::string_view Url, const KeyValueMap& AdditionalHeader)
+{
+ ZEN_TRACE_CPU("CurlHttpClient::Head");
+
+ return CommonResponse(
+ m_SessionId,
+ DoWithRetry(m_SessionId,
+ [&]() -> CurlResult {
+ Session Sess =
+ AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken());
+ CURL* H = Sess.Get();
+
+ curl_slist* Headers = BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken());
+ curl_easy_setopt(H, CURLOPT_HTTPHEADER, Headers);
+ curl_easy_setopt(H, CURLOPT_NOBODY, 1L);
+
+ HeaderCallbackData HdrData{};
+ std::vector<std::pair<std::string, std::string>> ResponseHeaders;
+ HdrData.Headers = &ResponseHeaders;
+
+ curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, CurlHeaderCallback);
+ curl_easy_setopt(H, CURLOPT_HEADERDATA, &HdrData);
+
+ CurlResult Result = Sess.Perform();
+ Result.Headers = std::move(ResponseHeaders);
+
+ curl_slist_free_all(Headers);
+
+ return Result;
+ }),
+ {});
+}
+
+CurlHttpClient::Response
+CurlHttpClient::Delete(std::string_view Url, const KeyValueMap& AdditionalHeader)
+{
+ ZEN_TRACE_CPU("CurlHttpClient::Delete");
+
+ return CommonResponse(
+ m_SessionId,
+ DoWithRetry(m_SessionId,
+ [&]() -> CurlResult {
+ Session Sess =
+ AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken());
+ CURL* H = Sess.Get();
+
+ curl_slist* Headers = BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken());
+ curl_easy_setopt(H, CURLOPT_HTTPHEADER, Headers);
+ curl_easy_setopt(H, CURLOPT_CUSTOMREQUEST, "DELETE");
+
+ std::string Body;
+ WriteCallbackData WriteData{.Body = &Body};
+ HeaderCallbackData HdrData{};
+ std::vector<std::pair<std::string, std::string>> ResponseHeaders;
+ HdrData.Headers = &ResponseHeaders;
+
+ curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, CurlWriteCallback);
+ curl_easy_setopt(H, CURLOPT_WRITEDATA, &WriteData);
+ curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, CurlHeaderCallback);
+ curl_easy_setopt(H, CURLOPT_HEADERDATA, &HdrData);
+
+ CurlResult Result = Sess.Perform();
+ Result.Body = std::move(Body);
+ Result.Headers = std::move(ResponseHeaders);
+
+ curl_slist_free_all(Headers);
+
+ return Result;
+ }),
+ {});
+}
+
+CurlHttpClient::Response
+CurlHttpClient::Post(std::string_view Url, const KeyValueMap& AdditionalHeader, const KeyValueMap& Parameters)
+{
+ ZEN_TRACE_CPU("CurlHttpClient::PostNoPayload");
+
+ return CommonResponse(
+ m_SessionId,
+ DoWithRetry(m_SessionId,
+ [&]() -> CurlResult {
+ Session Sess =
+ AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, Parameters, m_SessionId, GetAccessToken());
+ CURL* H = Sess.Get();
+
+ curl_slist* Headers = BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken());
+ curl_easy_setopt(H, CURLOPT_HTTPHEADER, Headers);
+ curl_easy_setopt(H, CURLOPT_POST, 1L);
+ curl_easy_setopt(H, CURLOPT_POSTFIELDSIZE, 0L);
+
+ std::string Body;
+ WriteCallbackData WriteData{.Body = &Body};
+ HeaderCallbackData HdrData{};
+ std::vector<std::pair<std::string, std::string>> ResponseHeaders;
+ HdrData.Headers = &ResponseHeaders;
+
+ curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, CurlWriteCallback);
+ curl_easy_setopt(H, CURLOPT_WRITEDATA, &WriteData);
+ curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, CurlHeaderCallback);
+ curl_easy_setopt(H, CURLOPT_HEADERDATA, &HdrData);
+
+ CurlResult Result = Sess.Perform();
+ Result.Body = std::move(Body);
+ Result.Headers = std::move(ResponseHeaders);
+
+ curl_slist_free_all(Headers);
+
+ return Result;
+ }),
+ {});
+}
+
+CurlHttpClient::Response
+CurlHttpClient::Post(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader)
+{
+ return Post(Url, Payload, Payload.GetContentType(), AdditionalHeader);
+}
+
+CurlHttpClient::Response
+CurlHttpClient::Post(std::string_view Url, const IoBuffer& Payload, ZenContentType ContentType, const KeyValueMap& AdditionalHeader)
+{
+ ZEN_TRACE_CPU("CurlHttpClient::PostWithPayload");
+
+ return CommonResponse(
+ m_SessionId,
+ DoWithRetry(
+ m_SessionId,
+ [&]() -> CurlResult {
+ Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken());
+ CURL* H = Sess.Get();
+
+ // Rebuild headers with content type
+ curl_slist* Headers = BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), {HeaderContentType(ContentType)});
+ curl_easy_setopt(H, CURLOPT_HTTPHEADER, Headers);
+
+ IoBufferFileReference FileRef = {nullptr, 0, 0};
+ if (Payload.GetFileReference(FileRef))
+ {
+ detail::BufferedReadFileStream Buffer(FileRef.FileHandle, FileRef.FileChunkOffset, FileRef.FileChunkSize, 512u * 1024u);
+
+ FileReadCallbackData ReadData{.Buffer = &Buffer,
+ .TotalSize = Payload.GetSize(),
+ .CheckIfAbortFunction = m_CheckIfAbortFunction ? &m_CheckIfAbortFunction : nullptr};
+
+ curl_easy_setopt(H, CURLOPT_POST, 1L);
+ curl_easy_setopt(H, CURLOPT_POSTFIELDSIZE_LARGE, static_cast<curl_off_t>(Payload.GetSize()));
+ curl_easy_setopt(H, CURLOPT_READFUNCTION, CurlFileReadCallback);
+ curl_easy_setopt(H, CURLOPT_READDATA, &ReadData);
+
+ std::string Body;
+ WriteCallbackData WriteData{.Body = &Body};
+ HeaderCallbackData HdrData{};
+ std::vector<std::pair<std::string, std::string>> ResponseHeaders;
+ HdrData.Headers = &ResponseHeaders;
+
+ curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, CurlWriteCallback);
+ curl_easy_setopt(H, CURLOPT_WRITEDATA, &WriteData);
+ curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, CurlHeaderCallback);
+ curl_easy_setopt(H, CURLOPT_HEADERDATA, &HdrData);
+
+ CurlResult Result = Sess.Perform();
+ Result.Body = std::move(Body);
+ Result.Headers = std::move(ResponseHeaders);
+
+ curl_slist_free_all(Headers);
+ return Result;
+ }
+
+ curl_easy_setopt(H, CURLOPT_POST, 1L);
+ curl_easy_setopt(H, CURLOPT_POSTFIELDS, reinterpret_cast<const char*>(Payload.GetData()));
+ curl_easy_setopt(H, CURLOPT_POSTFIELDSIZE_LARGE, static_cast<curl_off_t>(Payload.GetSize()));
+
+ std::string Body;
+ WriteCallbackData WriteData{.Body = &Body};
+ HeaderCallbackData HdrData{};
+ std::vector<std::pair<std::string, std::string>> ResponseHeaders;
+ HdrData.Headers = &ResponseHeaders;
+
+ curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, CurlWriteCallback);
+ curl_easy_setopt(H, CURLOPT_WRITEDATA, &WriteData);
+ curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, CurlHeaderCallback);
+ curl_easy_setopt(H, CURLOPT_HEADERDATA, &HdrData);
+
+ CurlResult Result = Sess.Perform();
+ Result.Body = std::move(Body);
+ Result.Headers = std::move(ResponseHeaders);
+
+ curl_slist_free_all(Headers);
+ return Result;
+ }),
+ {});
+}
+
+CurlHttpClient::Response
+CurlHttpClient::Post(std::string_view Url, CbObject Payload, const KeyValueMap& AdditionalHeader)
+{
+ ZEN_TRACE_CPU("CurlHttpClient::PostObjectPayload");
+
+ return CommonResponse(
+ m_SessionId,
+ DoWithRetry(
+ m_SessionId,
+ [&]() -> CurlResult {
+ Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken());
+ CURL* H = Sess.Get();
+
+ curl_slist* Headers =
+ BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), {HeaderContentType(ZenContentType::kCbObject)});
+ curl_easy_setopt(H, CURLOPT_HTTPHEADER, Headers);
+
+ curl_easy_setopt(H, CURLOPT_POST, 1L);
+ curl_easy_setopt(H, CURLOPT_POSTFIELDS, reinterpret_cast<const char*>(Payload.GetBuffer().GetData()));
+ curl_easy_setopt(H, CURLOPT_POSTFIELDSIZE_LARGE, static_cast<curl_off_t>(Payload.GetBuffer().GetSize()));
+
+ std::string Body;
+ WriteCallbackData WriteData{.Body = &Body};
+ HeaderCallbackData HdrData{};
+ std::vector<std::pair<std::string, std::string>> ResponseHeaders;
+ HdrData.Headers = &ResponseHeaders;
+
+ curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, CurlWriteCallback);
+ curl_easy_setopt(H, CURLOPT_WRITEDATA, &WriteData);
+ curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, CurlHeaderCallback);
+ curl_easy_setopt(H, CURLOPT_HEADERDATA, &HdrData);
+
+ CurlResult Result = Sess.Perform();
+ Result.Body = std::move(Body);
+ Result.Headers = std::move(ResponseHeaders);
+
+ curl_slist_free_all(Headers);
+ return Result;
+ }),
+ {});
+}
+
+CurlHttpClient::Response
+CurlHttpClient::Post(std::string_view Url, CbPackage Pkg, const KeyValueMap& AdditionalHeader)
+{
+ return Post(Url, zen::FormatPackageMessageBuffer(Pkg), ZenContentType::kCbPackage, AdditionalHeader);
+}
+
+CurlHttpClient::Response
+CurlHttpClient::Post(std::string_view Url, const CompositeBuffer& Payload, ZenContentType ContentType, const KeyValueMap& AdditionalHeader)
+{
+ ZEN_TRACE_CPU("CurlHttpClient::Post");
+
+ return CommonResponse(
+ m_SessionId,
+ DoWithRetry(m_SessionId,
+ [&]() -> CurlResult {
+ Session Sess =
+ AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken());
+ CURL* H = Sess.Get();
+
+ curl_slist* Headers =
+ BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), {HeaderContentType(ContentType)});
+ curl_easy_setopt(H, CURLOPT_HTTPHEADER, Headers);
+
+ detail::CompositeBufferReadStream Reader(Payload, 512u * 1024u);
+
+ StreamReadCallbackData ReadData{.Reader = &Reader,
+ .CheckIfAbortFunction = m_CheckIfAbortFunction ? &m_CheckIfAbortFunction : nullptr};
+
+ curl_easy_setopt(H, CURLOPT_POST, 1L);
+ curl_easy_setopt(H, CURLOPT_POSTFIELDSIZE_LARGE, static_cast<curl_off_t>(Payload.GetSize()));
+ curl_easy_setopt(H, CURLOPT_READFUNCTION, CurlStreamReadCallback);
+ curl_easy_setopt(H, CURLOPT_READDATA, &ReadData);
+
+ std::string Body;
+ WriteCallbackData WriteData{.Body = &Body};
+ HeaderCallbackData HdrData{};
+ std::vector<std::pair<std::string, std::string>> ResponseHeaders;
+ HdrData.Headers = &ResponseHeaders;
+
+ curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, CurlWriteCallback);
+ curl_easy_setopt(H, CURLOPT_WRITEDATA, &WriteData);
+ curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, CurlHeaderCallback);
+ curl_easy_setopt(H, CURLOPT_HEADERDATA, &HdrData);
+
+ CurlResult Result = Sess.Perform();
+ Result.Body = std::move(Body);
+ Result.Headers = std::move(ResponseHeaders);
+
+ curl_slist_free_all(Headers);
+ return Result;
+ }),
+ {});
+}
+
+CurlHttpClient::Response
+CurlHttpClient::Upload(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader)
+{
+ ZEN_TRACE_CPU("CurlHttpClient::Upload");
+
+ return CommonResponse(
+ m_SessionId,
+ DoWithRetry(
+ m_SessionId,
+ [&]() -> CurlResult {
+ Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken());
+ CURL* H = Sess.Get();
+
+ curl_slist* Headers =
+ BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), {HeaderContentType(Payload.GetContentType())});
+ curl_easy_setopt(H, CURLOPT_HTTPHEADER, Headers);
+
+ curl_easy_setopt(H, CURLOPT_UPLOAD, 1L);
+ curl_easy_setopt(H, CURLOPT_INFILESIZE_LARGE, static_cast<curl_off_t>(Payload.GetSize()));
+
+ IoBufferFileReference FileRef = {nullptr, 0, 0};
+ if (Payload.GetFileReference(FileRef))
+ {
+ detail::BufferedReadFileStream Buffer(FileRef.FileHandle, FileRef.FileChunkOffset, FileRef.FileChunkSize, 512u * 1024u);
+
+ FileReadCallbackData ReadData{.Buffer = &Buffer,
+ .TotalSize = Payload.GetSize(),
+ .CheckIfAbortFunction = m_CheckIfAbortFunction ? &m_CheckIfAbortFunction : nullptr};
+
+ curl_easy_setopt(H, CURLOPT_READFUNCTION, CurlFileReadCallback);
+ curl_easy_setopt(H, CURLOPT_READDATA, &ReadData);
+
+ std::string Body;
+ WriteCallbackData WriteData{.Body = &Body};
+ HeaderCallbackData HdrData{};
+ std::vector<std::pair<std::string, std::string>> ResponseHeaders;
+ HdrData.Headers = &ResponseHeaders;
+
+ curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, CurlWriteCallback);
+ curl_easy_setopt(H, CURLOPT_WRITEDATA, &WriteData);
+ curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, CurlHeaderCallback);
+ curl_easy_setopt(H, CURLOPT_HEADERDATA, &HdrData);
+
+ CurlResult Result = Sess.Perform();
+ Result.Body = std::move(Body);
+ Result.Headers = std::move(ResponseHeaders);
+
+ curl_slist_free_all(Headers);
+ return Result;
+ }
+
+ ReadCallbackData ReadData{.DataPtr = static_cast<const uint8_t*>(Payload.GetData()),
+ .DataSize = Payload.GetSize(),
+ .CheckIfAbortFunction = m_CheckIfAbortFunction ? &m_CheckIfAbortFunction : nullptr};
+ curl_easy_setopt(H, CURLOPT_READFUNCTION, CurlReadCallback);
+ curl_easy_setopt(H, CURLOPT_READDATA, &ReadData);
+
+ std::string Body;
+ WriteCallbackData WriteData{.Body = &Body};
+ HeaderCallbackData HdrData{};
+ std::vector<std::pair<std::string, std::string>> ResponseHeaders;
+ HdrData.Headers = &ResponseHeaders;
+
+ curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, CurlWriteCallback);
+ curl_easy_setopt(H, CURLOPT_WRITEDATA, &WriteData);
+ curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, CurlHeaderCallback);
+ curl_easy_setopt(H, CURLOPT_HEADERDATA, &HdrData);
+
+ CurlResult Result = Sess.Perform();
+ Result.Body = std::move(Body);
+ Result.Headers = std::move(ResponseHeaders);
+
+ curl_slist_free_all(Headers);
+ return Result;
+ }),
+ {});
+}
+
+CurlHttpClient::Response
+CurlHttpClient::Upload(std::string_view Url,
+ const CompositeBuffer& Payload,
+ ZenContentType ContentType,
+ const KeyValueMap& AdditionalHeader)
+{
+ ZEN_TRACE_CPU("CurlHttpClient::Upload");
+
+ return CommonResponse(
+ m_SessionId,
+ DoWithRetry(m_SessionId,
+ [&]() -> CurlResult {
+ Session Sess =
+ AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken());
+ CURL* H = Sess.Get();
+
+ curl_slist* Headers =
+ BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), {HeaderContentType(ContentType)});
+ curl_easy_setopt(H, CURLOPT_HTTPHEADER, Headers);
+
+ curl_easy_setopt(H, CURLOPT_UPLOAD, 1L);
+ curl_easy_setopt(H, CURLOPT_INFILESIZE_LARGE, static_cast<curl_off_t>(Payload.GetSize()));
+
+ detail::CompositeBufferReadStream Reader(Payload, 512u * 1024u);
+
+ StreamReadCallbackData ReadData{.Reader = &Reader,
+ .CheckIfAbortFunction = m_CheckIfAbortFunction ? &m_CheckIfAbortFunction : nullptr};
+
+ curl_easy_setopt(H, CURLOPT_READFUNCTION, CurlStreamReadCallback);
+ curl_easy_setopt(H, CURLOPT_READDATA, &ReadData);
+
+ std::string Body;
+ WriteCallbackData WriteData{.Body = &Body};
+ HeaderCallbackData HdrData{};
+ std::vector<std::pair<std::string, std::string>> ResponseHeaders;
+ HdrData.Headers = &ResponseHeaders;
+
+ curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, CurlWriteCallback);
+ curl_easy_setopt(H, CURLOPT_WRITEDATA, &WriteData);
+ curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, CurlHeaderCallback);
+ curl_easy_setopt(H, CURLOPT_HEADERDATA, &HdrData);
+
+ CurlResult Result = Sess.Perform();
+ Result.Body = std::move(Body);
+ Result.Headers = std::move(ResponseHeaders);
+
+ curl_slist_free_all(Headers);
+ return Result;
+ }),
+ {});
+}
+
+CurlHttpClient::Response
+CurlHttpClient::Download(std::string_view Url, const std::filesystem::path& TempFolderPath, const KeyValueMap& AdditionalHeader)
+{
+ ZEN_TRACE_CPU("CurlHttpClient::Download");
+
+ std::string PayloadString;
+ std::unique_ptr<detail::TempPayloadFile> PayloadFile;
+
+ HttpContentType ContentType = HttpContentType::kUnknownContentType;
+ detail::MultipartBoundaryParser BoundaryParser;
+ bool IsMultiRangeResponse = false;
+
+ CurlResult Result = DoWithRetry(
+ m_SessionId,
+ [&]() -> CurlResult {
+ Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken());
+ CURL* H = Sess.Get();
+
+ curl_slist* DlHeaders = BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken());
+ curl_easy_setopt(H, CURLOPT_HTTPHEADER, DlHeaders);
+ curl_easy_setopt(H, CURLOPT_HTTPGET, 1L);
+
+ // Reset state from any previous attempt
+ PayloadString.clear();
+ PayloadFile.reset();
+ BoundaryParser.Boundaries.clear();
+ ContentType = HttpContentType::kUnknownContentType;
+ IsMultiRangeResponse = false;
+
+ // Track requested content length from Range header (sum all ranges)
+ uint64_t RequestedContentLength = (uint64_t)-1;
+ if (auto RangeIt = AdditionalHeader.Entries.find("Range"); RangeIt != AdditionalHeader.Entries.end())
+ {
+ if (RangeIt->second.starts_with("bytes"))
+ {
+ std::string_view RangeValue(RangeIt->second);
+ size_t RangeStartPos = RangeValue.find('=', 5);
+ if (RangeStartPos != std::string::npos)
+ {
+ RangeStartPos++;
+ while (RangeStartPos < RangeValue.length() && RangeValue[RangeStartPos] == ' ')
+ {
+ RangeStartPos++;
+ }
+ RequestedContentLength = 0;
+
+ while (RangeStartPos < RangeValue.length())
+ {
+ size_t RangeEnd = RangeValue.find_first_of(", \r\n", RangeStartPos);
+ if (RangeEnd == std::string::npos)
+ {
+ RangeEnd = RangeValue.length();
+ }
+
+ std::string_view RangeString = RangeValue.substr(RangeStartPos, RangeEnd - RangeStartPos);
+ size_t RangeSplitPos = RangeString.find('-');
+ if (RangeSplitPos != std::string::npos)
+ {
+ std::optional<size_t> RequestedRangeStart = ParseInt<size_t>(RangeString.substr(0, RangeSplitPos));
+ std::optional<size_t> RequestedRangeEnd = ParseInt<size_t>(RangeString.substr(RangeSplitPos + 1));
+ if (RequestedRangeStart.has_value() && RequestedRangeEnd.has_value())
+ {
+ RequestedContentLength += RequestedRangeEnd.value() - RequestedRangeStart.value() + 1;
+ }
+ }
+ RangeStartPos = RangeEnd;
+ while (RangeStartPos != RangeValue.length() &&
+ (RangeValue[RangeStartPos] == ',' || RangeValue[RangeStartPos] == ' '))
+ {
+ RangeStartPos++;
+ }
+ }
+ }
+ }
+ }
+
+ // Header callback that detects Content-Length and switches to file-backed storage when needed
+ struct DownloadHeaderCallbackData
+ {
+ std::vector<std::pair<std::string, std::string>>* Headers = nullptr;
+ std::unique_ptr<detail::TempPayloadFile>* PayloadFile = nullptr;
+ std::string* PayloadString = nullptr;
+ const std::filesystem::path* TempFolderPath = nullptr;
+ uint64_t MaxInMemorySize = 0;
+ LoggerRef Log;
+ detail::MultipartBoundaryParser* BoundaryParser = nullptr;
+ bool* IsMultiRange = nullptr;
+ HttpContentType* ContentTypeOut = nullptr;
+ };
+
+ DownloadHeaderCallbackData DlHdrData;
+ std::vector<std::pair<std::string, std::string>> ResponseHeaders;
+ DlHdrData.Headers = &ResponseHeaders;
+ DlHdrData.PayloadFile = &PayloadFile;
+ DlHdrData.PayloadString = &PayloadString;
+ DlHdrData.TempFolderPath = &TempFolderPath;
+ DlHdrData.MaxInMemorySize = m_ConnectionSettings.MaximumInMemoryDownloadSize;
+ DlHdrData.Log = m_Log;
+ DlHdrData.BoundaryParser = &BoundaryParser;
+ DlHdrData.IsMultiRange = &IsMultiRangeResponse;
+ DlHdrData.ContentTypeOut = &ContentType;
+
+ auto HeaderCb = [](char* Buffer, size_t Size, size_t Nmemb, void* UserData) -> size_t {
+ auto* Data = static_cast<DownloadHeaderCallbackData*>(UserData);
+ size_t TotalBytes = Size * Nmemb;
+
+ std::string_view Line(Buffer, TotalBytes);
+
+ while (!Line.empty() && (Line.back() == '\r' || Line.back() == '\n'))
+ {
+ Line.remove_suffix(1);
+ }
+
+ if (Line.empty())
+ {
+ return TotalBytes;
+ }
+
+ size_t ColonPos = Line.find(':');
+ if (ColonPos != std::string_view::npos)
+ {
+ std::string_view Key = Line.substr(0, ColonPos);
+ std::string_view Value = Line.substr(ColonPos + 1);
+
+ while (!Key.empty() && Key.back() == ' ')
+ {
+ Key.remove_suffix(1);
+ }
+ while (!Value.empty() && Value.front() == ' ')
+ {
+ Value.remove_prefix(1);
+ }
+
+ if (Key == "Content-Length"sv)
+ {
+ std::optional<size_t> ContentLength = ParseInt<size_t>(Value);
+ if (ContentLength.has_value())
+ {
+ if (ContentLength.value() > Data->MaxInMemorySize)
+ {
+ *Data->PayloadFile = std::make_unique<detail::TempPayloadFile>();
+ std::error_code Ec = (*Data->PayloadFile)->Open(*Data->TempFolderPath, ContentLength.value());
+ if (Ec)
+ {
+ auto Log = [&]() -> LoggerRef { return Data->Log; };
+ ZEN_WARN("Failed to create temp file in '{}' for HttpClient::Download. Reason: {}",
+ Data->TempFolderPath->string(),
+ Ec.message());
+ Data->PayloadFile->reset();
+ }
+ }
+ else
+ {
+ Data->PayloadString->reserve(ContentLength.value());
+ }
+ }
+ }
+ else if (Key == "Content-Type"sv)
+ {
+ *Data->IsMultiRange = Data->BoundaryParser->Init(Value);
+ if (!*Data->IsMultiRange)
+ {
+ *Data->ContentTypeOut = ParseContentType(Value);
+ }
+ }
+ else if (Key == "Content-Range"sv)
+ {
+ if (!*Data->IsMultiRange)
+ {
+ std::pair<uint64_t, uint64_t> Range = detail::ParseContentRange(Value);
+ if (Range.second != 0)
+ {
+ Data->BoundaryParser->Boundaries.push_back(
+ HttpClient::Response::MultipartBoundary{.OffsetInPayload = 0,
+ .RangeOffset = Range.first,
+ .RangeLength = Range.second,
+ .ContentType = *Data->ContentTypeOut});
+ }
+ }
+ }
+
+ Data->Headers->emplace_back(std::string(Key), std::string(Value));
+ }
+
+ return TotalBytes;
+ };
+
+ curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, static_cast<size_t (*)(char*, size_t, size_t, void*)>(HeaderCb));
+ curl_easy_setopt(H, CURLOPT_HEADERDATA, &DlHdrData);
+
+ // Write callback that directs data to file or string
+ struct DownloadWriteCallbackData
+ {
+ std::string* PayloadString = nullptr;
+ std::unique_ptr<detail::TempPayloadFile>* PayloadFile = nullptr;
+ std::function<bool()>* CheckIfAbortFunction = nullptr;
+ const std::filesystem::path* TempFolderPath = nullptr;
+ LoggerRef Log;
+ detail::MultipartBoundaryParser* BoundaryParser = nullptr;
+ bool* IsMultiRange = nullptr;
+ };
+
+ DownloadWriteCallbackData DlWriteData;
+ DlWriteData.PayloadString = &PayloadString;
+ DlWriteData.PayloadFile = &PayloadFile;
+ DlWriteData.CheckIfAbortFunction = m_CheckIfAbortFunction ? &m_CheckIfAbortFunction : nullptr;
+ DlWriteData.TempFolderPath = &TempFolderPath;
+ DlWriteData.Log = m_Log;
+ DlWriteData.BoundaryParser = &BoundaryParser;
+ DlWriteData.IsMultiRange = &IsMultiRangeResponse;
+
+ auto WriteCb = [](char* Ptr, size_t Size, size_t Nmemb, void* UserData) -> size_t {
+ auto* Data = static_cast<DownloadWriteCallbackData*>(UserData);
+ size_t TotalBytes = Size * Nmemb;
+
+ if (Data->CheckIfAbortFunction && *Data->CheckIfAbortFunction && (*Data->CheckIfAbortFunction)())
+ {
+ return 0;
+ }
+
+ if (*Data->IsMultiRange)
+ {
+ Data->BoundaryParser->ParseInput(std::string_view(Ptr, TotalBytes));
+ }
+
+ if (*Data->PayloadFile)
+ {
+ std::error_code Ec = (*Data->PayloadFile)->Write(std::string_view(Ptr, TotalBytes));
+ if (Ec)
+ {
+ auto Log = [&]() -> LoggerRef { return Data->Log; };
+ ZEN_WARN("Failed to write to temp file in '{}' for HttpClient::Download. Reason: {}",
+ Data->TempFolderPath->string(),
+ Ec.message());
+ return 0;
+ }
+ }
+ else
+ {
+ Data->PayloadString->append(Ptr, TotalBytes);
+ }
+ return TotalBytes;
+ };
+
+ curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, static_cast<size_t (*)(char*, size_t, size_t, void*)>(WriteCb));
+ curl_easy_setopt(H, CURLOPT_WRITEDATA, &DlWriteData);
+
+ CurlResult Res = Sess.Perform();
+ Res.Headers = std::move(ResponseHeaders);
+
+ // Handle resume logic
+ if (m_ConnectionSettings.AllowResume)
+ {
+ auto SupportsRanges = [](const CurlResult& R) -> bool {
+ for (const auto& [K, V] : R.Headers)
+ {
+ if (K == "Content-Range")
+ {
+ return true;
+ }
+ if (K == "Accept-Ranges" && V == "bytes")
+ {
+ return true;
+ }
+ }
+ return false;
+ };
+
+ auto ShouldResumeCheck = [&SupportsRanges, &IsMultiRangeResponse](const CurlResult& R) -> bool {
+ if (IsMultiRangeResponse)
+ {
+ return false;
+ }
+ if (ShouldRetry(R))
+ {
+ return SupportsRanges(R);
+ }
+ return false;
+ };
+
+ if (ShouldResumeCheck(Res))
+ {
+ // Find Content-Length
+ std::string ContentLengthValue;
+ for (const auto& [K, V] : Res.Headers)
+ {
+ if (K == "Content-Length")
+ {
+ ContentLengthValue = V;
+ break;
+ }
+ }
+
+ if (!ContentLengthValue.empty())
+ {
+ uint64_t ContentLength = RequestedContentLength;
+ if (ContentLength == uint64_t(-1))
+ {
+ if (auto ParsedContentLength = ParseInt<int64_t>(ContentLengthValue); ParsedContentLength.has_value())
+ {
+ ContentLength = ParsedContentLength.value();
+ }
+ }
+
+ KeyValueMap HeadersWithRange(AdditionalHeader);
+ do
+ {
+ uint64_t DownloadedSize = PayloadFile ? PayloadFile->GetSize() : PayloadString.length();
+
+ std::string Range = fmt::format("bytes={}-{}", DownloadedSize, DownloadedSize + ContentLength - 1);
+ if (auto RangeIt = HeadersWithRange.Entries.find("Range"); RangeIt != HeadersWithRange.Entries.end())
+ {
+ if (RangeIt->second == Range)
+ {
+ break; // No progress, abort
+ }
+ }
+ HeadersWithRange.Entries.insert_or_assign("Range", Range);
+
+ Session ResumeSess =
+ AllocSession(m_BaseUri, Url, m_ConnectionSettings, HeadersWithRange, {}, m_SessionId, GetAccessToken());
+ CURL* ResumeH = ResumeSess.Get();
+
+ curl_slist* ResumeHdrList = BuildHeaderList(HeadersWithRange, m_SessionId, GetAccessToken());
+ curl_easy_setopt(ResumeH, CURLOPT_HTTPHEADER, ResumeHdrList);
+ curl_easy_setopt(ResumeH, CURLOPT_HTTPGET, 1L);
+
+ std::vector<std::pair<std::string, std::string>> ResumeHeaders;
+
+ struct ResumeHeaderCbData
+ {
+ std::vector<std::pair<std::string, std::string>>* Headers = nullptr;
+ std::unique_ptr<detail::TempPayloadFile>* PayloadFile = nullptr;
+ std::string* PayloadString = nullptr;
+ };
+
+ ResumeHeaderCbData ResumeHdrData;
+ ResumeHdrData.Headers = &ResumeHeaders;
+ ResumeHdrData.PayloadFile = &PayloadFile;
+ ResumeHdrData.PayloadString = &PayloadString;
+
+ auto ResumeHeaderCb = [](char* Buffer, size_t Size, size_t Nmemb, void* UserData) -> size_t {
+ auto* Data = static_cast<ResumeHeaderCbData*>(UserData);
+ size_t TotalBytes = Size * Nmemb;
+
+ std::string_view Line(Buffer, TotalBytes);
+ while (!Line.empty() && (Line.back() == '\r' || Line.back() == '\n'))
+ {
+ Line.remove_suffix(1);
+ }
+
+ if (Line.empty())
+ {
+ return TotalBytes;
+ }
+
+ size_t ColonPos = Line.find(':');
+ if (ColonPos != std::string_view::npos)
+ {
+ std::string_view Key = Line.substr(0, ColonPos);
+ std::string_view Value = Line.substr(ColonPos + 1);
+ while (!Key.empty() && Key.back() == ' ')
+ {
+ Key.remove_suffix(1);
+ }
+ while (!Value.empty() && Value.front() == ' ')
+ {
+ Value.remove_prefix(1);
+ }
+
+ if (Key == "Content-Range"sv)
+ {
+ if (Value.starts_with("bytes "sv))
+ {
+ size_t RangeStartEnd = Value.find('-', 6);
+ if (RangeStartEnd != std::string_view::npos)
+ {
+ const std::optional<uint64_t> Start =
+ ParseInt<uint64_t>(Value.substr(6, RangeStartEnd - 6));
+ if (Start)
+ {
+ uint64_t DownloadedSize = *Data->PayloadFile ? (*Data->PayloadFile)->GetSize()
+ : Data->PayloadString->length();
+ if (Start.value() == DownloadedSize)
+ {
+ Data->Headers->emplace_back(std::string(Key), std::string(Value));
+ return TotalBytes;
+ }
+ else if (Start.value() > DownloadedSize)
+ {
+ return 0;
+ }
+ if (*Data->PayloadFile)
+ {
+ (*Data->PayloadFile)->ResetWritePos(Start.value());
+ }
+ else
+ {
+ *Data->PayloadString = Data->PayloadString->substr(0, Start.value());
+ }
+ Data->Headers->emplace_back(std::string(Key), std::string(Value));
+ return TotalBytes;
+ }
+ }
+ }
+ return 0;
+ }
+
+ Data->Headers->emplace_back(std::string(Key), std::string(Value));
+ }
+
+ return TotalBytes;
+ };
+
+ curl_easy_setopt(ResumeH,
+ CURLOPT_HEADERFUNCTION,
+ static_cast<size_t (*)(char*, size_t, size_t, void*)>(ResumeHeaderCb));
+ curl_easy_setopt(ResumeH, CURLOPT_HEADERDATA, &ResumeHdrData);
+ curl_easy_setopt(ResumeH,
+ CURLOPT_WRITEFUNCTION,
+ static_cast<size_t (*)(char*, size_t, size_t, void*)>(WriteCb));
+ curl_easy_setopt(ResumeH, CURLOPT_WRITEDATA, &DlWriteData);
+
+ Res = ResumeSess.Perform();
+ Res.Headers = std::move(ResumeHeaders);
+
+ curl_slist_free_all(ResumeHdrList);
+ } while (ShouldResumeCheck(Res));
+ }
+ }
+ }
+
+ if (!PayloadString.empty())
+ {
+ Res.Body = std::move(PayloadString);
+ }
+
+ curl_slist_free_all(DlHeaders);
+
+ return Res;
+ },
+ PayloadFile);
+
+ return CommonResponse(m_SessionId,
+ std::move(Result),
+ PayloadFile ? PayloadFile->DetachToIoBuffer() : IoBuffer{},
+ std::move(BoundaryParser.Boundaries));
+}
+
+} // namespace zen
diff --git a/src/zenhttp/clients/httpclientcurl.h b/src/zenhttp/clients/httpclientcurl.h
new file mode 100644
index 000000000..2a49ff308
--- /dev/null
+++ b/src/zenhttp/clients/httpclientcurl.h
@@ -0,0 +1,135 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include "httpclientcommon.h"
+
+#include <zencore/logging.h>
+#include <zenhttp/httpclient.h>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <curl/curl.h>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+namespace zen {
+
+class CurlHttpClient : public HttpClientBase
+{
+public:
+ CurlHttpClient(std::string_view BaseUri, const HttpClientSettings& ConnectionSettings, std::function<bool()>&& CheckIfAbortFunction);
+ ~CurlHttpClient();
+
+ // HttpClientBase
+
+ [[nodiscard]] virtual Response Put(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader = {}) override;
+ [[nodiscard]] virtual Response Put(std::string_view Url, const KeyValueMap& Parameters = {}) override;
+ [[nodiscard]] virtual Response Get(std::string_view Url,
+ const KeyValueMap& AdditionalHeader = {},
+ const KeyValueMap& Parameters = {}) override;
+ [[nodiscard]] virtual Response Head(std::string_view Url, const KeyValueMap& AdditionalHeader = {}) override;
+ [[nodiscard]] virtual Response Delete(std::string_view Url, const KeyValueMap& AdditionalHeader = {}) override;
+ [[nodiscard]] virtual Response Post(std::string_view Url,
+ const KeyValueMap& AdditionalHeader = {},
+ const KeyValueMap& Parameters = {}) override;
+ [[nodiscard]] virtual Response Post(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader = {}) override;
+ [[nodiscard]] virtual Response Post(std::string_view Url,
+ const IoBuffer& Payload,
+ ZenContentType ContentType,
+ const KeyValueMap& AdditionalHeader = {}) override;
+ [[nodiscard]] virtual Response Post(std::string_view Url, CbObject Payload, const KeyValueMap& AdditionalHeader = {}) override;
+ [[nodiscard]] virtual Response Post(std::string_view Url, CbPackage Payload, const KeyValueMap& AdditionalHeader = {}) override;
+ [[nodiscard]] virtual Response Post(std::string_view Url,
+ const CompositeBuffer& Payload,
+ ZenContentType ContentType,
+ const KeyValueMap& AdditionalHeader = {}) override;
+ [[nodiscard]] virtual Response Upload(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader = {}) override;
+ [[nodiscard]] virtual Response Upload(std::string_view Url,
+ const CompositeBuffer& Payload,
+ ZenContentType ContentType,
+ const KeyValueMap& AdditionalHeader = {}) override;
+
+ [[nodiscard]] virtual Response Download(std::string_view Url,
+ const std::filesystem::path& TempFolderPath,
+ const KeyValueMap& AdditionalHeader = {}) override;
+
+ [[nodiscard]] virtual Response TransactPackage(std::string_view Url,
+ CbPackage Package,
+ const KeyValueMap& AdditionalHeader = {}) override;
+
+private:
+ struct CurlResult
+ {
+ long StatusCode = 0;
+ std::string Body;
+ std::vector<std::pair<std::string, std::string>> Headers;
+ double ElapsedSeconds = 0;
+ int64_t UploadedBytes = 0;
+ int64_t DownloadedBytes = 0;
+ CURLcode ErrorCode = CURLE_OK;
+ std::string ErrorMessage;
+ };
+
+ struct Session
+ {
+ Session(CurlHttpClient* InOuter, CURL* InHandle) : Outer(InOuter), Handle(InHandle) {}
+ ~Session() { Outer->ReleaseSession(Handle); }
+
+ CURL* Get() const { return Handle; }
+
+ CurlResult Perform();
+
+ LoggerRef Log() { return Outer->Log(); }
+
+ private:
+ CurlHttpClient* Outer;
+ CURL* Handle;
+
+ Session(Session&&) = delete;
+ Session& operator=(Session&&) = delete;
+ };
+
+ Session AllocSession(std::string_view BaseUrl,
+ std::string_view Url,
+ const HttpClientSettings& ConnectionSettings,
+ const KeyValueMap& AdditionalHeader,
+ const KeyValueMap& Parameters,
+ std::string_view SessionId,
+ std::optional<HttpClientAccessToken> AccessToken);
+
+ RwLock m_SessionLock;
+ std::vector<CURL*> m_Sessions;
+
+ void ReleaseSession(CURL* Handle);
+
+ struct RetryResult
+ {
+ CurlResult Result;
+ };
+
+ CurlResult DoWithRetry(std::string_view SessionId,
+ std::function<CurlResult()>&& Func,
+ std::unique_ptr<detail::TempPayloadFile>& PayloadFile);
+ CurlResult DoWithRetry(
+ std::string_view SessionId,
+ std::function<CurlResult()>&& Func,
+ std::function<bool(CurlResult&)>&& Validate = [](CurlResult&) { return true; });
+
+ bool ValidatePayload(CurlResult& Result, std::unique_ptr<detail::TempPayloadFile>& PayloadFile);
+
+ static bool ShouldRetry(const CurlResult& Result);
+
+ bool ShouldLogErrorCode(HttpResponseCode ResponseCode) const;
+
+ HttpClient::Response CommonResponse(std::string_view SessionId,
+ CurlResult&& Result,
+ IoBuffer&& Payload,
+ std::vector<HttpClient::Response::MultipartBoundary>&& BoundaryPositions = {});
+
+ HttpClient::Response ResponseWithPayload(std::string_view SessionId,
+ CurlResult&& Result,
+ const HttpResponseCode WorkResponseCode,
+ IoBuffer&& Payload,
+ std::vector<HttpClient::Response::MultipartBoundary>&& BoundaryPositions);
+};
+
+} // namespace zen
diff --git a/src/zenhttp/clients/httpwsclient.cpp b/src/zenhttp/clients/httpwsclient.cpp
index 9497dadb8..792848a6b 100644
--- a/src/zenhttp/clients/httpwsclient.cpp
+++ b/src/zenhttp/clients/httpwsclient.cpp
@@ -10,6 +10,9 @@
ZEN_THIRD_PARTY_INCLUDES_START
#include <asio.hpp>
+#if defined(ASIO_HAS_LOCAL_SOCKETS)
+# include <asio/local/stream_protocol.hpp>
+#endif
ZEN_THIRD_PARTY_INCLUDES_END
#include <deque>
@@ -47,11 +50,7 @@ struct HttpWsClient::Impl
m_WorkGuard.reset();
// Close the socket to cancel pending async ops
- if (m_Socket)
- {
- asio::error_code Ec;
- m_Socket->close(Ec);
- }
+ CloseSocket();
if (m_IoThread.joinable())
{
@@ -59,6 +58,35 @@ struct HttpWsClient::Impl
}
}
+ void CloseSocket()
+ {
+ asio::error_code Ec;
+#if defined(ASIO_HAS_LOCAL_SOCKETS)
+ if (m_UnixSocket)
+ {
+ m_UnixSocket->close(Ec);
+ return;
+ }
+#endif
+ if (m_TcpSocket)
+ {
+ m_TcpSocket->close(Ec);
+ }
+ }
+
+ template<typename Fn>
+ void WithSocket(Fn&& Func)
+ {
+#if defined(ASIO_HAS_LOCAL_SOCKETS)
+ if (m_UnixSocket)
+ {
+ Func(*m_UnixSocket);
+ return;
+ }
+#endif
+ Func(*m_TcpSocket);
+ }
+
void ParseUrl(std::string_view Url)
{
// Expected format: ws://host:port/path
@@ -101,9 +129,47 @@ struct HttpWsClient::Impl
m_IoThread = std::thread([this] { m_IoContext.run(); });
}
+#if defined(ASIO_HAS_LOCAL_SOCKETS)
+ if (!m_Settings.UnixSocketPath.empty())
+ {
+ asio::post(m_IoContext, [this] { DoConnectUnix(); });
+ return;
+ }
+#endif
+
asio::post(m_IoContext, [this] { DoResolve(); });
}
+#if defined(ASIO_HAS_LOCAL_SOCKETS)
+ void DoConnectUnix()
+ {
+ m_UnixSocket = std::make_unique<asio::local::stream_protocol::socket>(m_IoContext);
+
+ // Start connect timeout timer
+ m_Timer = std::make_unique<asio::steady_timer>(m_IoContext, m_Settings.ConnectTimeout);
+ m_Timer->async_wait([this](const asio::error_code& Ec) {
+ if (!Ec && !m_IsOpen.load(std::memory_order_relaxed))
+ {
+ ZEN_LOG_DEBUG(m_Log, "WebSocket unix connect timeout for {}", m_Settings.UnixSocketPath);
+ CloseSocket();
+ }
+ });
+
+ asio::local::stream_protocol::endpoint Endpoint(m_Settings.UnixSocketPath);
+ m_UnixSocket->async_connect(Endpoint, [this](const asio::error_code& Ec) {
+ if (Ec)
+ {
+ m_Timer->cancel();
+ ZEN_LOG_DEBUG(m_Log, "WebSocket unix connect failed for {}: {}", m_Settings.UnixSocketPath, Ec.message());
+ m_Handler.OnWsClose(1006, "connect failed");
+ return;
+ }
+
+ DoHandshake();
+ });
+ }
+#endif
+
void DoResolve()
{
m_Resolver = std::make_unique<asio::ip::tcp::resolver>(m_IoContext);
@@ -122,7 +188,7 @@ struct HttpWsClient::Impl
void DoConnect(const asio::ip::tcp::resolver::results_type& Endpoints)
{
- m_Socket = std::make_unique<asio::ip::tcp::socket>(m_IoContext);
+ m_TcpSocket = std::make_unique<asio::ip::tcp::socket>(m_IoContext);
// Start connect timeout timer
m_Timer = std::make_unique<asio::steady_timer>(m_IoContext, m_Settings.ConnectTimeout);
@@ -130,15 +196,11 @@ struct HttpWsClient::Impl
if (!Ec && !m_IsOpen.load(std::memory_order_relaxed))
{
ZEN_LOG_DEBUG(m_Log, "WebSocket connect timeout for {}:{}", m_Host, m_Port);
- if (m_Socket)
- {
- asio::error_code CloseEc;
- m_Socket->close(CloseEc);
- }
+ CloseSocket();
}
});
- asio::async_connect(*m_Socket, Endpoints, [this](const asio::error_code& Ec, const asio::ip::tcp::endpoint&) {
+ asio::async_connect(*m_TcpSocket, Endpoints, [this](const asio::error_code& Ec, const asio::ip::tcp::endpoint&) {
if (Ec)
{
m_Timer->cancel();
@@ -194,64 +256,68 @@ struct HttpWsClient::Impl
m_HandshakeBuffer = std::make_shared<std::string>(ReqStr);
- asio::async_write(*m_Socket,
- asio::buffer(m_HandshakeBuffer->data(), m_HandshakeBuffer->size()),
- [this](const asio::error_code& Ec, std::size_t) {
- if (Ec)
- {
- m_Timer->cancel();
- ZEN_LOG_DEBUG(m_Log, "WebSocket handshake write failed: {}", Ec.message());
- m_Handler.OnWsClose(1006, "handshake write failed");
- return;
- }
-
- DoReadHandshakeResponse();
- });
+ WithSocket([this](auto& Socket) {
+ asio::async_write(Socket,
+ asio::buffer(m_HandshakeBuffer->data(), m_HandshakeBuffer->size()),
+ [this](const asio::error_code& Ec, std::size_t) {
+ if (Ec)
+ {
+ m_Timer->cancel();
+ ZEN_LOG_DEBUG(m_Log, "WebSocket handshake write failed: {}", Ec.message());
+ m_Handler.OnWsClose(1006, "handshake write failed");
+ return;
+ }
+
+ DoReadHandshakeResponse();
+ });
+ });
}
void DoReadHandshakeResponse()
{
- asio::async_read_until(*m_Socket, m_ReadBuffer, "\r\n\r\n", [this](const asio::error_code& Ec, std::size_t) {
- m_Timer->cancel();
+ WithSocket([this](auto& Socket) {
+ asio::async_read_until(Socket, m_ReadBuffer, "\r\n\r\n", [this](const asio::error_code& Ec, std::size_t) {
+ m_Timer->cancel();
- if (Ec)
- {
- ZEN_LOG_DEBUG(m_Log, "WebSocket handshake read failed: {}", Ec.message());
- m_Handler.OnWsClose(1006, "handshake read failed");
- return;
- }
+ if (Ec)
+ {
+ ZEN_LOG_DEBUG(m_Log, "WebSocket handshake read failed: {}", Ec.message());
+ m_Handler.OnWsClose(1006, "handshake read failed");
+ return;
+ }
- // Parse the response
- const auto& Data = m_ReadBuffer.data();
- std::string Response(asio::buffers_begin(Data), asio::buffers_end(Data));
+ // Parse the response
+ const auto& Data = m_ReadBuffer.data();
+ std::string Response(asio::buffers_begin(Data), asio::buffers_end(Data));
- // Consume the headers from the read buffer (any extra data stays for frame parsing)
- auto HeaderEnd = Response.find("\r\n\r\n");
- if (HeaderEnd != std::string::npos)
- {
- m_ReadBuffer.consume(HeaderEnd + 4);
- }
+ // Consume the headers from the read buffer (any extra data stays for frame parsing)
+ auto HeaderEnd = Response.find("\r\n\r\n");
+ if (HeaderEnd != std::string::npos)
+ {
+ m_ReadBuffer.consume(HeaderEnd + 4);
+ }
- // Validate 101 response
- if (Response.find("101") == std::string::npos)
- {
- ZEN_LOG_DEBUG(m_Log, "WebSocket handshake rejected (no 101): {}", Response.substr(0, 80));
- m_Handler.OnWsClose(1006, "handshake rejected");
- return;
- }
+ // Validate 101 response
+ if (Response.find("101") == std::string::npos)
+ {
+ ZEN_LOG_DEBUG(m_Log, "WebSocket handshake rejected (no 101): {}", Response.substr(0, 80));
+ m_Handler.OnWsClose(1006, "handshake rejected");
+ return;
+ }
- // Validate Sec-WebSocket-Accept
- std::string ExpectedAccept = WsFrameCodec::ComputeAcceptKey(m_WebSocketKey);
- if (Response.find(ExpectedAccept) == std::string::npos)
- {
- ZEN_LOG_DEBUG(m_Log, "WebSocket handshake: invalid Sec-WebSocket-Accept");
- m_Handler.OnWsClose(1006, "invalid accept key");
- return;
- }
+ // Validate Sec-WebSocket-Accept
+ std::string ExpectedAccept = WsFrameCodec::ComputeAcceptKey(m_WebSocketKey);
+ if (Response.find(ExpectedAccept) == std::string::npos)
+ {
+ ZEN_LOG_DEBUG(m_Log, "WebSocket handshake: invalid Sec-WebSocket-Accept");
+ m_Handler.OnWsClose(1006, "invalid accept key");
+ return;
+ }
- m_IsOpen.store(true);
- m_Handler.OnWsOpen();
- EnqueueRead();
+ m_IsOpen.store(true);
+ m_Handler.OnWsOpen();
+ EnqueueRead();
+ });
});
}
@@ -267,8 +333,10 @@ struct HttpWsClient::Impl
return;
}
- asio::async_read(*m_Socket, m_ReadBuffer, asio::transfer_at_least(1), [this](const asio::error_code& Ec, std::size_t) {
- OnDataReceived(Ec);
+ WithSocket([this](auto& Socket) {
+ asio::async_read(Socket, m_ReadBuffer, asio::transfer_at_least(1), [this](const asio::error_code& Ec, std::size_t) {
+ OnDataReceived(Ec);
+ });
});
}
@@ -414,9 +482,11 @@ struct HttpWsClient::Impl
auto OwnedFrame = std::make_shared<std::vector<uint8_t>>(std::move(Frame));
- asio::async_write(*m_Socket,
- asio::buffer(OwnedFrame->data(), OwnedFrame->size()),
- [this, OwnedFrame](const asio::error_code& Ec, std::size_t) { OnWriteComplete(Ec); });
+ WithSocket([this, OwnedFrame](auto& Socket) {
+ asio::async_write(Socket,
+ asio::buffer(OwnedFrame->data(), OwnedFrame->size()),
+ [this, OwnedFrame](const asio::error_code& Ec, std::size_t) { OnWriteComplete(Ec); });
+ });
}
void OnWriteComplete(const asio::error_code& Ec)
@@ -501,11 +571,14 @@ struct HttpWsClient::Impl
// Connection state
std::unique_ptr<asio::ip::tcp::resolver> m_Resolver;
- std::unique_ptr<asio::ip::tcp::socket> m_Socket;
- std::unique_ptr<asio::steady_timer> m_Timer;
- asio::streambuf m_ReadBuffer;
- std::string m_WebSocketKey;
- std::shared_ptr<std::string> m_HandshakeBuffer;
+ std::unique_ptr<asio::ip::tcp::socket> m_TcpSocket;
+#if defined(ASIO_HAS_LOCAL_SOCKETS)
+ std::unique_ptr<asio::local::stream_protocol::socket> m_UnixSocket;
+#endif
+ std::unique_ptr<asio::steady_timer> m_Timer;
+ asio::streambuf m_ReadBuffer;
+ std::string m_WebSocketKey;
+ std::shared_ptr<std::string> m_HandshakeBuffer;
// Write queue
RwLock m_WriteLock;
diff --git a/src/zenhttp/httpclient.cpp b/src/zenhttp/httpclient.cpp
index 281d512cf..9baf4346e 100644
--- a/src/zenhttp/httpclient.cpp
+++ b/src/zenhttp/httpclient.cpp
@@ -40,6 +40,35 @@ extern HttpClientBase* CreateCprHttpClient(std::string_view BaseUri,
const HttpClientSettings& ConnectionSettings,
std::function<bool()>&& CheckIfAbortFunction);
+extern HttpClientBase* CreateCurlHttpClient(std::string_view BaseUri,
+ const HttpClientSettings& ConnectionSettings,
+ std::function<bool()>&& CheckIfAbortFunction);
+
+static HttpClientBackend g_DefaultHttpClientBackend = HttpClientBackend::kCpr;
+
+void
+SetDefaultHttpClientBackend(HttpClientBackend Backend)
+{
+ g_DefaultHttpClientBackend = Backend;
+}
+
+void
+SetDefaultHttpClientBackend(std::string_view Backend)
+{
+ if (Backend == "cpr")
+ {
+ g_DefaultHttpClientBackend = HttpClientBackend::kCpr;
+ }
+ else if (Backend == "curl")
+ {
+ g_DefaultHttpClientBackend = HttpClientBackend::kCurl;
+ }
+ else
+ {
+ g_DefaultHttpClientBackend = HttpClientBackend::kDefault;
+ }
+}
+
using namespace std::literals;
//////////////////////////////////////////////////////////////////////////
@@ -104,6 +133,71 @@ HttpClientBase::GetAccessToken()
//////////////////////////////////////////////////////////////////////////
+HttpClientError::ResponseClass
+HttpClientError::GetResponseClass() const
+{
+ if (m_Error != HttpClientErrorCode::kOK)
+ {
+ switch (m_Error)
+ {
+ case HttpClientErrorCode::kConnectionFailure:
+ return ResponseClass::kHttpCantConnectError;
+ case HttpClientErrorCode::kHostResolutionFailure:
+ case HttpClientErrorCode::kProxyResolutionFailure:
+ return ResponseClass::kHttpNoHost;
+ case HttpClientErrorCode::kInternalError:
+ case HttpClientErrorCode::kNetworkReceiveError:
+ case HttpClientErrorCode::kNetworkSendFailure:
+ case HttpClientErrorCode::kOperationTimedOut:
+ return ResponseClass::kHttpTimeout;
+ case HttpClientErrorCode::kSSLConnectError:
+ case HttpClientErrorCode::kSSLCertificateError:
+ case HttpClientErrorCode::kSSLCACertError:
+ case HttpClientErrorCode::kGenericSSLError:
+ return ResponseClass::kHttpSLLError;
+ default:
+ return ResponseClass::kHttpOtherClientError;
+ }
+ }
+ else if (IsHttpSuccessCode(m_ResponseCode))
+ {
+ return ResponseClass::kSuccess;
+ }
+ else
+ {
+ switch (m_ResponseCode)
+ {
+ case HttpResponseCode::Unauthorized:
+ return ResponseClass::kHttpUnauthorized;
+ case HttpResponseCode::NotFound:
+ return ResponseClass::kHttpNotFound;
+ case HttpResponseCode::Forbidden:
+ return ResponseClass::kHttpForbidden;
+ case HttpResponseCode::Conflict:
+ return ResponseClass::kHttpConflict;
+ case HttpResponseCode::InternalServerError:
+ return ResponseClass::kHttpInternalServerError;
+ case HttpResponseCode::ServiceUnavailable:
+ return ResponseClass::kHttpServiceUnavailable;
+ case HttpResponseCode::BadGateway:
+ return ResponseClass::kHttpBadGateway;
+ case HttpResponseCode::GatewayTimeout:
+ return ResponseClass::kHttpGatewayTimeout;
+ default:
+ if (m_ResponseCode >= HttpResponseCode::InternalServerError)
+ {
+ return ResponseClass::kHttpOtherServerError;
+ }
+ else
+ {
+ return ResponseClass::kHttpOtherClientError;
+ }
+ }
+ }
+}
+
+//////////////////////////////////////////////////////////////////////////
+
std::vector<std::pair<uint64_t, uint64_t>>
HttpClient::Response::GetRanges(std::span<const std::pair<uint64_t, uint64_t>> OffsetAndLengthPairs) const
{
@@ -222,7 +316,11 @@ HttpClient::Response::ErrorMessage(std::string_view Prefix) const
{
if (Error.has_value())
{
- return fmt::format("{}{}HTTP error ({}) '{}'", Prefix, Prefix.empty() ? ""sv : ": "sv, Error->ErrorCode, Error->ErrorMessage);
+ return fmt::format("{}{}HTTP error ({}) '{}'",
+ Prefix,
+ Prefix.empty() ? ""sv : ": "sv,
+ static_cast<int>(Error->ErrorCode),
+ Error->ErrorMessage);
}
else if (StatusCode != HttpResponseCode::ImATeapot && (int)StatusCode)
{
@@ -245,19 +343,34 @@ HttpClient::Response::ThrowError(std::string_view ErrorPrefix)
{
if (!IsSuccess())
{
- throw HttpClientError(ErrorMessage(ErrorPrefix), Error.has_value() ? Error.value().ErrorCode : 0, StatusCode);
+ throw HttpClientError(ErrorMessage(ErrorPrefix),
+ Error.has_value() ? Error.value().ErrorCode : HttpClientErrorCode::kOK,
+ StatusCode);
}
}
//////////////////////////////////////////////////////////////////////////
HttpClient::HttpClient(std::string_view BaseUri, const HttpClientSettings& ConnectionSettings, std::function<bool()>&& CheckIfAbortFunction)
-: m_BaseUri(BaseUri)
+: m_Log(zen::logging::Get(ConnectionSettings.LogCategory))
+, m_BaseUri(BaseUri)
, m_ConnectionSettings(ConnectionSettings)
{
m_SessionId = GetSessionIdString();
- m_Inner = CreateCprHttpClient(BaseUri, ConnectionSettings, std::move(CheckIfAbortFunction));
+ HttpClientBackend EffectiveBackend =
+ ConnectionSettings.Backend != HttpClientBackend::kDefault ? ConnectionSettings.Backend : g_DefaultHttpClientBackend;
+
+ switch (EffectiveBackend)
+ {
+ case HttpClientBackend::kCurl:
+ m_Inner = CreateCurlHttpClient(BaseUri, ConnectionSettings, std::move(CheckIfAbortFunction));
+ break;
+ case HttpClientBackend::kCpr:
+ default:
+ m_Inner = CreateCprHttpClient(BaseUri, ConnectionSettings, std::move(CheckIfAbortFunction));
+ break;
+ }
}
HttpClient::~HttpClient()
diff --git a/src/zenhttp/httpclient_test.cpp b/src/zenhttp/httpclient_test.cpp
index 52bf149a7..2d949c546 100644
--- a/src/zenhttp/httpclient_test.cpp
+++ b/src/zenhttp/httpclient_test.cpp
@@ -8,6 +8,7 @@
# include <zencore/compactbinarybuilder.h>
# include <zencore/compactbinaryutil.h>
# include <zencore/compositebuffer.h>
+# include <zencore/filesystem.h>
# include <zencore/iobuffer.h>
# include <zencore/logging.h>
# include <zencore/scopeguard.h>
@@ -232,7 +233,7 @@ struct TestServerFixture
TestServerFixture()
{
Server = CreateHttpAsioServer(AsioConfig{});
- Port = Server->Initialize(7600, TmpDir.Path());
+ Port = Server->Initialize(0, TmpDir.Path());
ZEN_ASSERT(Port != -1);
Server->RegisterService(TestService);
ServerThread = std::thread([this]() { Server->Run(false); });
@@ -1044,13 +1045,22 @@ struct FaultTcpServer
{
m_Port = m_Acceptor.local_endpoint().port();
StartAccept();
- m_Thread = std::thread([this]() { m_IoContext.run(); });
+ m_Thread = std::thread([this]() {
+ try
+ {
+ m_IoContext.run();
+ }
+ catch (...)
+ {
+ }
+ });
}
~FaultTcpServer()
{
- std::error_code Ec;
- m_Acceptor.close(Ec);
+ // io_context::stop() is thread-safe; do NOT call m_Acceptor.close() from this
+ // thread — ASIO I/O objects are not safe for concurrent access and the io_context
+ // thread may be touching the acceptor in StartAccept().
m_IoContext.stop();
if (m_Thread.joinable())
{
@@ -1081,6 +1091,105 @@ struct FaultTcpServer
}
};
+TEST_CASE("httpclient.range-response")
+{
+ ScopedTemporaryDirectory DownloadDir;
+
+ SUBCASE("single range 206 response populates Ranges")
+ {
+ std::string RangeBody(100, 'A');
+
+ FaultTcpServer Server([&](asio::ip::tcp::socket& Socket) {
+ DrainHttpRequest(Socket);
+ std::string Response = fmt::format(
+ "HTTP/1.1 206 Partial Content\r\n"
+ "Content-Type: application/octet-stream\r\n"
+ "Content-Range: bytes 200-299/1000\r\n"
+ "Content-Length: {}\r\n"
+ "\r\n"
+ "{}",
+ RangeBody.size(),
+ RangeBody);
+ std::error_code Ec;
+ asio::write(Socket, asio::buffer(Response), Ec);
+ });
+
+ HttpClient Client = Server.MakeClient();
+ HttpClient::Response Resp = Client.Download("/test", DownloadDir.Path());
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.StatusCode, HttpResponseCode::PartialContent);
+ REQUIRE(Resp.Ranges.size() == 1);
+ CHECK_EQ(Resp.Ranges[0].RangeOffset, 200);
+ CHECK_EQ(Resp.Ranges[0].RangeLength, 100);
+ }
+
+ SUBCASE("multipart byteranges 206 response populates Ranges")
+ {
+ std::string Part1Data(16, 'X');
+ std::string Part2Data(12, 'Y');
+ std::string Boundary = "testboundary123";
+
+ std::string MultipartBody = fmt::format(
+ "\r\n--{}\r\n"
+ "Content-Type: application/octet-stream\r\n"
+ "Content-Range: bytes 100-115/1000\r\n"
+ "\r\n"
+ "{}"
+ "\r\n--{}\r\n"
+ "Content-Type: application/octet-stream\r\n"
+ "Content-Range: bytes 500-511/1000\r\n"
+ "\r\n"
+ "{}"
+ "\r\n--{}--",
+ Boundary,
+ Part1Data,
+ Boundary,
+ Part2Data,
+ Boundary);
+
+ FaultTcpServer Server([&](asio::ip::tcp::socket& Socket) {
+ DrainHttpRequest(Socket);
+ std::string Response = fmt::format(
+ "HTTP/1.1 206 Partial Content\r\n"
+ "Content-Type: multipart/byteranges; boundary={}\r\n"
+ "Content-Length: {}\r\n"
+ "\r\n"
+ "{}",
+ Boundary,
+ MultipartBody.size(),
+ MultipartBody);
+ std::error_code Ec;
+ asio::write(Socket, asio::buffer(Response), Ec);
+ });
+
+ HttpClient Client = Server.MakeClient();
+ HttpClient::Response Resp = Client.Download("/test", DownloadDir.Path());
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.StatusCode, HttpResponseCode::PartialContent);
+ REQUIRE(Resp.Ranges.size() == 2);
+ // Ranges should be sorted by RangeOffset
+ CHECK_EQ(Resp.Ranges[0].RangeOffset, 100);
+ CHECK_EQ(Resp.Ranges[0].RangeLength, 16);
+ CHECK_EQ(Resp.Ranges[1].RangeOffset, 500);
+ CHECK_EQ(Resp.Ranges[1].RangeLength, 12);
+ }
+
+ SUBCASE("non-range 200 response has empty Ranges")
+ {
+ FaultTcpServer Server([&](asio::ip::tcp::socket& Socket) {
+ DrainHttpRequest(Socket);
+ std::string Response = MakeRawHttpResponse(200, "full content");
+ std::error_code Ec;
+ asio::write(Socket, asio::buffer(Response), Ec);
+ });
+
+ HttpClient Client = Server.MakeClient();
+ HttpClient::Response Resp = Client.Download("/test", DownloadDir.Path());
+ CHECK(Resp.IsSuccess());
+ CHECK(Resp.Ranges.empty());
+ }
+}
+
TEST_CASE("httpclient.transport-faults" * doctest::skip())
{
SUBCASE("connection reset before response")
@@ -1354,6 +1463,188 @@ TEST_CASE("httpclient.transport-faults-post" * doctest::skip())
}
}
+TEST_CASE("httpclient.unixsocket")
+{
+ ScopedTemporaryDirectory TmpDir;
+ std::string SocketPath = (TmpDir.Path() / "zen.sock").string();
+
+ HttpClientTestService TestService;
+
+ Ref<HttpServer> Server = CreateHttpAsioServer(AsioConfig{.UnixSocketPath = SocketPath});
+
+ int Port = Server->Initialize(0, TmpDir.Path());
+ REQUIRE(Port != -1);
+
+ Server->RegisterService(TestService);
+
+ std::thread ServerThread([&]() { Server->Run(false); });
+
+ auto _ = MakeGuard([&]() {
+ Server->RequestExit();
+ if (ServerThread.joinable())
+ {
+ ServerThread.join();
+ }
+ Server->Close();
+ });
+
+ HttpClientSettings Settings;
+ Settings.UnixSocketPath = SocketPath;
+
+ HttpClient Client("localhost", Settings, /*CheckIfAbortFunction*/ {});
+
+ SUBCASE("GET over unix socket")
+ {
+ HttpClient::Response Resp = Client.Get("/api/test/hello");
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.AsText(), "hello world");
+ }
+
+ SUBCASE("POST echo over unix socket")
+ {
+ const char* Payload = "unix socket payload";
+ IoBuffer Body(IoBuffer::Clone, Payload, strlen(Payload));
+ Body.SetContentType(ZenContentType::kText);
+
+ HttpClient::Response Resp = Client.Post("/api/test/echo", Body);
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.AsText(), "unix socket payload");
+ }
+}
+
+# if ZEN_USE_OPENSSL
+
+TEST_CASE("httpclient.https")
+{
+ // Self-signed test certificate for localhost/127.0.0.1, valid until 2036
+ static constexpr std::string_view TestCertPem =
+ "-----BEGIN CERTIFICATE-----\n"
+ "MIIDJTCCAg2gAwIBAgIUEtJYMSUmJmvJ157We/qXNVJ7W8gwDQYJKoZIhvcNAQEL\n"
+ "BQAwFDESMBAGA1UEAwwJbG9jYWxob3N0MB4XDTI2MDMwOTIwMjU1M1oXDTM2MDMw\n"
+ "NjIwMjU1M1owFDESMBAGA1UEAwwJbG9jYWxob3N0MIIBIjANBgkqhkiG9w0BAQEF\n"
+ "AAOCAQ8AMIIBCgKCAQEAv9YvZ6WeBz3z/Zuxi6OIivWksDxDZZ5oAXKVwlUXaa7v\n"
+ "iDkm9P5ZsEhN+M5vZMe2Yb9i3cnTUaE6Avs1ddOwTAYNGrE/B5DmibrRWc23R0cv\n"
+ "gdnYQJ+gjsAeMvUWYLK58xW4YoMR5bmfpj1ruqobUNkG/oJYnAUcjgo4J149irW+\n"
+ "4n9uLJvxL+5fI/b/AIkv+4TMe70/d/BPmnixWrrzxUT6S5ghE2Mq7+XLScfpY2Sp\n"
+ "GQ/Xbnj9/ELYLpQnNLuVZwWZDpXj+FLbF1zxgjYdw1cCjbRcOIEW2/GJeJvGXQ6Y\n"
+ "Vld5pCBm9uKPPLWoFCoakK5YvP00h+8X+HghGVSscQIDAQABo28wbTAdBgNVHQ4E\n"
+ "FgQUgM6hjymi6g2EBUg2ENu0nIK8yhMwHwYDVR0jBBgwFoAUgM6hjymi6g2EBUg2\n"
+ "ENu0nIK8yhMwDwYDVR0TAQH/BAUwAwEB/zAaBgNVHREEEzARhwR/AAABgglsb2Nh\n"
+ "bGhvc3QwDQYJKoZIhvcNAQELBQADggEBABY1oaaWwL4RaK/epKvk/IrmVT2mlAai\n"
+ "uvGLfjhc6FGvXaxPGTSUPrVbFornaWZAg7bOWCexWnEm2sWd75V/usvZAPN4aIiD\n"
+ "H66YQipq3OD4F9Gowp01IU4AcGh7MerFpYPk76+wp2ANq71x8axtlZjVn3hSFMmN\n"
+ "i6m9S/eyCl9WjYBT5ZEC4fJV0nOSmNe/+gCAm11/js9zNfXKmUchJtuZpubY3A0k\n"
+ "X2II6qYWf1PH+JJkefNZtt2c66CrEN5eAg4/rGEgsp43zcd4ZHVkpBKFLDEls1ev\n"
+ "drQ45zc4Ht77pHfnHu7YsLcRZ9Wq3COMNZYx5lItqnomX2qBm1pkwjI=\n"
+ "-----END CERTIFICATE-----\n";
+
+ static constexpr std::string_view TestKeyPem =
+ "-----BEGIN PRIVATE KEY-----\n"
+ "MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQC/1i9npZ4HPfP9\n"
+ "m7GLo4iK9aSwPENlnmgBcpXCVRdpru+IOSb0/lmwSE34zm9kx7Zhv2LdydNRoToC\n"
+ "+zV107BMBg0asT8HkOaJutFZzbdHRy+B2dhAn6COwB4y9RZgsrnzFbhigxHluZ+m\n"
+ "PWu6qhtQ2Qb+glicBRyOCjgnXj2Ktb7if24sm/Ev7l8j9v8AiS/7hMx7vT938E+a\n"
+ "eLFauvPFRPpLmCETYyrv5ctJx+ljZKkZD9dueP38QtgulCc0u5VnBZkOleP4UtsX\n"
+ "XPGCNh3DVwKNtFw4gRbb8Yl4m8ZdDphWV3mkIGb24o88tagUKhqQrli8/TSH7xf4\n"
+ "eCEZVKxxAgMBAAECggEAILd9pDaZqfCF8SWhdQgx3Ekiii/s6qLGaCDLq7XpZUvB\n"
+ "bEEbBMNwNmFOcvV6B/0LfMYwLVUjZhOSGjoPlwXAVmbdy0SZVEgBGVI0LBWqgUyB\n"
+ "rKqjd/oBXvci71vfMiSpE+0LYjmqTryGnspw2gfy2qn4yGUgiZNRmGPjycsHweUL\n"
+ "V3FHm3cf0dyE4sJ0mjVqZzRT/unw2QOCE6FlY7M1XxZL88IWfn6G4lckdJTwoOP5\n"
+ "VPR2J3XbyhvCeXeDRCHKRXojWWR2HovWnDXQc95GRgCd0vYdHuIUM6RXVPZQvy3X\n"
+ "l0GwQKHNcVr1uwtYDgGKw0tNCUDvxdfQaWilTFuicQKBgQDvEYp+vL1hnF+AVdu3\n"
+ "elsYsHpFgExkTI8wnUMvGZrFiIQyCyVDU3jkG3kcKacI1bfwopXopaQCjrYk9epm\n"
+ "liOVm3/Xtr6e2ENa7w8TQbdK65PciQNOMxml6g8clRRBl0cwj+aI3nW/Kop1cdrR\n"
+ "A9Vo+8iPTO5gDcxTiIb45a6E3QKBgQDNbE009P6ewx9PU7Llkhb9VBgsb7oQN3EV\n"
+ "TCYd4taiN6FPnTuL/cdijAA8y04hiVT+Efo9TUN9NCl9HdHXQcjj7/n/eFLH0Pkw\n"
+ "OIK3QN49OfR88wivLMtwWxIog0tJjc9+7dR4bR4o1jTlIrasEIvUTuDJQ8MKGc9v\n"
+ "pBITua+SpQKBgE4raSKZqj7hd6Sp7kbnHiRLiB9znQbqtaNKuK4M7DuMsNUAKfYC\n"
+ "tDO5+/bGc9SCtTtcnjHM/3zKlyossrFKhGYlyz6IhXnA8v0nz8EXKsy3jMh+kHMg\n"
+ "aFGE394TrOTphyCM3O+B9fRE/7L5QHg5ja1fLqwUlpkXyejCaoe16kONAoGAYIz9\n"
+ "wN1B67cEOVG6rOI8QfdLoV8mEcctNHhlFfjvLrF89SGOwl6WX0A0QF7CK0sUEpK6\n"
+ "jiOJjAh/U5o3bbgyxsedNjEEn3weE0cMUTuA+UALJMtKEqO4PuffIgGL2ld35k28\n"
+ "ZpnK6iC8HdJyD297eV9VkeNygYXeFLgF8xV8ay0CgYEAh4fmVZt9YhgVByYny2kF\n"
+ "ZUIkGF5h9wxzVOPpQwpizIGFFb3i/ZdGQcuLTfIBVRKf50sT3IwJe65ATv6+Lz0f\n"
+ "wg/pMvosi0/F5KGbVRVdzBMQy58WyyGti4tNl+8EXGvo8+DCmjlTYwfjRoZGg/qJ\n"
+ "EMP3/hTN7dHDRxPK8E0Fh0Y=\n"
+ "-----END PRIVATE KEY-----\n";
+
+ ScopedTemporaryDirectory TmpDir;
+
+ // Write cert and key to temp files
+ const auto CertPath = TmpDir.Path() / "test.crt";
+ const auto KeyPath = TmpDir.Path() / "test.key";
+ WriteFile(CertPath, IoBuffer(IoBuffer::Clone, TestCertPem.data(), TestCertPem.size()));
+ WriteFile(KeyPath, IoBuffer(IoBuffer::Clone, TestKeyPem.data(), TestKeyPem.size()));
+
+ HttpClientTestService TestService;
+
+ AsioConfig Config;
+ Config.CertFile = CertPath.string();
+ Config.KeyFile = KeyPath.string();
+
+ Ref<HttpServer> Server = CreateHttpAsioServer(Config);
+
+ int Port = Server->Initialize(0, TmpDir.Path());
+ REQUIRE(Port != -1);
+
+ Server->RegisterService(TestService);
+
+ std::thread ServerThread([&]() { Server->Run(false); });
+
+ auto _ = MakeGuard([&]() {
+ Server->RequestExit();
+ if (ServerThread.joinable())
+ {
+ ServerThread.join();
+ }
+ Server->Close();
+ });
+
+ int HttpsPort = Server->GetEffectiveHttpsPort();
+ REQUIRE(HttpsPort > 0);
+
+ HttpClientSettings Settings;
+ Settings.InsecureSsl = true;
+
+ HttpClient Client(fmt::format("https://127.0.0.1:{}", HttpsPort), Settings, /*CheckIfAbortFunction*/ {});
+
+ SUBCASE("GET over HTTPS")
+ {
+ HttpClient::Response Resp = Client.Get("/api/test/hello");
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.AsText(), "hello world");
+ }
+
+ SUBCASE("POST echo over HTTPS")
+ {
+ const char* Payload = "https payload";
+ IoBuffer Body(IoBuffer::Clone, Payload, strlen(Payload));
+ Body.SetContentType(ZenContentType::kText);
+
+ HttpClient::Response Resp = Client.Post("/api/test/echo", Body);
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.AsText(), "https payload");
+ }
+
+ SUBCASE("GET JSON over HTTPS")
+ {
+ HttpClient::Response Resp = Client.Get("/api/test/json");
+ CHECK(Resp.IsSuccess());
+ CbObject Obj = Resp.AsObject();
+ CHECK_EQ(Obj["ok"].AsBool(), true);
+ CHECK_EQ(Obj["message"].AsString(), "test");
+ }
+
+ SUBCASE("Large payload over HTTPS")
+ {
+ HttpClient::Response Resp = Client.Get("/api/test/large");
+ CHECK(Resp.IsSuccess());
+ CHECK_EQ(Resp.ResponsePayload.GetSize(), 64u * 1024u);
+ }
+}
+
+# endif // ZEN_USE_OPENSSL
+
TEST_SUITE_END();
void
diff --git a/src/zenhttp/httpserver.cpp b/src/zenhttp/httpserver.cpp
index 9bae95690..1a0018908 100644
--- a/src/zenhttp/httpserver.cpp
+++ b/src/zenhttp/httpserver.cpp
@@ -1044,13 +1044,16 @@ HttpServer::OnGetExternalHost() const
std::string
HttpServer::GetServiceUri(const HttpService* Service) const
{
+ const char* Scheme = (m_EffectiveHttpsPort > 0) ? "https" : "http";
+ int Port = (m_EffectiveHttpsPort > 0) ? m_EffectiveHttpsPort : m_EffectivePort;
+
if (Service)
{
- return fmt::format("http://{}:{}{}", m_ExternalHost, m_EffectivePort, Service->BaseUri());
+ return fmt::format("{}://{}:{}{}", Scheme, m_ExternalHost, Port, Service->BaseUri());
}
else
{
- return fmt::format("http://{}:{}", m_ExternalHost, m_EffectivePort);
+ return fmt::format("{}://{}:{}", Scheme, m_ExternalHost, Port);
}
}
@@ -1152,9 +1155,13 @@ CreateHttpServerClass(const std::string_view ServerClass, const HttpServerConfig
if (ServerClass == "asio"sv)
{
ZEN_INFO("using asio HTTP server implementation")
- return CreateHttpAsioServer(AsioConfig{.ThreadCount = Config.ThreadCount,
- .ForceLoopback = Config.ForceLoopback,
- .IsDedicatedServer = Config.IsDedicatedServer});
+ return CreateHttpAsioServer(AsioConfig {
+ .ThreadCount = Config.ThreadCount, .ForceLoopback = Config.ForceLoopback, .IsDedicatedServer = Config.IsDedicatedServer,
+ .UnixSocketPath = Config.UnixSocketPath,
+#if ZEN_USE_OPENSSL
+ .HttpsPort = Config.HttpsPort, .CertFile = Config.CertFile, .KeyFile = Config.KeyFile,
+#endif
+ });
}
#if ZEN_WITH_HTTPSYS
else if (ServerClass == "httpsys"sv)
@@ -1165,7 +1172,11 @@ CreateHttpServerClass(const std::string_view ServerClass, const HttpServerConfig
.IsAsyncResponseEnabled = Config.HttpSys.IsAsyncResponseEnabled,
.IsRequestLoggingEnabled = Config.HttpSys.IsRequestLoggingEnabled,
.IsDedicatedServer = Config.IsDedicatedServer,
- .ForceLoopback = Config.ForceLoopback}));
+ .ForceLoopback = Config.ForceLoopback,
+ .HttpsPort = Config.HttpSys.HttpsPort,
+ .CertThumbprint = Config.HttpSys.CertThumbprint,
+ .CertStoreName = Config.HttpSys.CertStoreName,
+ .HttpsOnly = Config.HttpSys.HttpsOnly}));
}
#endif
else if (ServerClass == "null"sv)
diff --git a/src/zenhttp/include/zenhttp/formatters.h b/src/zenhttp/include/zenhttp/formatters.h
index 57ab01158..90180391c 100644
--- a/src/zenhttp/include/zenhttp/formatters.h
+++ b/src/zenhttp/include/zenhttp/formatters.h
@@ -84,7 +84,7 @@ struct fmt::formatter<zen::HttpClient::Response>
return fmt::format_to(Ctx.out(),
"Failed: Elapsed: {}, Reason: ({}) '{}",
NiceResponseTime,
- Response.Error.value().ErrorCode,
+ static_cast<int>(Response.Error.value().ErrorCode),
Response.Error.value().ErrorMessage);
}
else
diff --git a/src/zenhttp/include/zenhttp/httpclient.h b/src/zenhttp/include/zenhttp/httpclient.h
index 1bb36a298..2e21e3bd6 100644
--- a/src/zenhttp/include/zenhttp/httpclient.h
+++ b/src/zenhttp/include/zenhttp/httpclient.h
@@ -30,6 +30,34 @@ class CompositeBuffer;
*/
+enum class HttpClientErrorCode : int
+{
+ kOK = 0,
+ kConnectionFailure,
+ kHostResolutionFailure,
+ kProxyResolutionFailure,
+ kInternalError,
+ kNetworkReceiveError,
+ kNetworkSendFailure,
+ kOperationTimedOut,
+ kSSLConnectError,
+ kSSLCertificateError,
+ kSSLCACertError,
+ kGenericSSLError,
+ kRequestCancelled,
+ kOtherError,
+};
+
+enum class HttpClientBackend : uint8_t
+{
+ kDefault,
+ kCpr,
+ kCurl,
+};
+
+void SetDefaultHttpClientBackend(std::string_view Backend);
+void SetDefaultHttpClientBackend(HttpClientBackend Backend);
+
struct HttpClientAccessToken
{
using Clock = std::chrono::system_clock;
@@ -59,6 +87,22 @@ struct HttpClientSettings
Oid SessionId = Oid::Zero;
bool Verbose = false;
uint64_t MaximumInMemoryDownloadSize = 1024u * 1024u;
+ HttpClientBackend Backend = HttpClientBackend::kDefault;
+
+ /// Unix domain socket path. When non-empty, the client connects via this
+ /// socket instead of TCP. BaseUri is still used for the Host header and URL.
+ std::string UnixSocketPath;
+
+ /// Disable HTTP keep-alive by closing the connection after each request.
+ /// Useful for testing per-connection overhead.
+ bool ForbidReuseConnection = false;
+
+ /// Skip TLS certificate verification (for testing with self-signed certs).
+ bool InsecureSsl = false;
+
+ /// CA certificate bundle path for TLS verification. When non-empty, overrides
+ /// the system default CA store.
+ std::string CaBundlePath;
/// HTTP status codes that are expected and should not be logged as warnings.
/// 404 is always treated as expected regardless of this list.
@@ -70,22 +114,22 @@ class HttpClientError : public std::runtime_error
public:
using _Mybase = runtime_error;
- HttpClientError(const std::string& Message, int Error, HttpResponseCode ResponseCode)
+ HttpClientError(const std::string& Message, HttpClientErrorCode Error, HttpResponseCode ResponseCode)
: _Mybase(Message)
, m_Error(Error)
, m_ResponseCode(ResponseCode)
{
}
- HttpClientError(const char* Message, int Error, HttpResponseCode ResponseCode)
+ HttpClientError(const char* Message, HttpClientErrorCode Error, HttpResponseCode ResponseCode)
: _Mybase(Message)
, m_Error(Error)
, m_ResponseCode(ResponseCode)
{
}
- inline int GetInternalErrorCode() const { return m_Error; }
- inline HttpResponseCode GetHttpResponseCode() const { return m_ResponseCode; }
+ inline HttpClientErrorCode GetInternalErrorCode() const { return m_Error; }
+ inline HttpResponseCode GetHttpResponseCode() const { return m_ResponseCode; }
enum class ResponseClass : std::int8_t
{
@@ -112,8 +156,8 @@ public:
ResponseClass GetResponseClass() const;
private:
- const int m_Error = 0;
- const HttpResponseCode m_ResponseCode = HttpResponseCode::ImATeapot;
+ const HttpClientErrorCode m_Error = HttpClientErrorCode::kOK;
+ const HttpResponseCode m_ResponseCode = HttpResponseCode::ImATeapot;
};
class HttpClientBase;
@@ -137,11 +181,23 @@ public:
struct ErrorContext
{
- int ErrorCode = 0;
- std::string ErrorMessage;
+ HttpClientErrorCode ErrorCode;
+ std::string ErrorMessage;
/** True when the error is a transport-level connection failure (connect timeout, refused, DNS) */
- bool IsConnectionError() const;
+ bool IsConnectionError() const
+ {
+ switch (ErrorCode)
+ {
+ case HttpClientErrorCode::kConnectionFailure:
+ case HttpClientErrorCode::kOperationTimedOut:
+ case HttpClientErrorCode::kHostResolutionFailure:
+ case HttpClientErrorCode::kProxyResolutionFailure:
+ return true;
+ default:
+ return false;
+ }
+ }
};
struct KeyValueMap
diff --git a/src/zenhttp/include/zenhttp/httpserver.h b/src/zenhttp/include/zenhttp/httpserver.h
index 0e1714669..d98877d16 100644
--- a/src/zenhttp/include/zenhttp/httpserver.h
+++ b/src/zenhttp/include/zenhttp/httpserver.h
@@ -255,6 +255,9 @@ public:
*/
std::string_view GetExternalHost() const { return m_ExternalHost; }
+ /** Returns the effective HTTPS port, or 0 if HTTPS is not enabled. Only valid after Initialize(). */
+ int GetEffectiveHttpsPort() const { return m_EffectiveHttpsPort; }
+
/** Returns total bytes received and sent across all connections since server start. */
virtual uint64_t GetTotalBytesReceived() const { return 0; }
virtual uint64_t GetTotalBytesSent() const { return 0; }
@@ -290,7 +293,8 @@ public:
private:
std::vector<HttpService*> m_KnownServices;
- int m_EffectivePort = 0;
+ int m_EffectivePort = 0;
+ int m_EffectiveHttpsPort = 0;
std::string m_ExternalHost;
metrics::Meter m_RequestMeter;
std::string m_DefaultRedirect;
@@ -308,6 +312,7 @@ private:
virtual void OnClose() = 0;
protected:
+ void SetEffectiveHttpsPort(int Port) { m_EffectiveHttpsPort = Port; }
virtual std::string OnGetExternalHost() const;
};
@@ -324,12 +329,20 @@ struct HttpServerConfig
std::vector<HttpServerPluginConfig> PluginConfigs;
bool ForceLoopback = false;
unsigned int ThreadCount = 0;
+ std::string UnixSocketPath; // Unix domain socket path (empty = disabled, non-Windows only)
+ int HttpsPort = 0; // HTTPS listen port (0 = disabled, ASIO backend)
+ std::string CertFile; // PEM certificate chain file path
+ std::string KeyFile; // PEM private key file path
struct
{
unsigned int AsyncWorkThreadCount = 0;
bool IsAsyncResponseEnabled = true;
bool IsRequestLoggingEnabled = false;
+ int HttpsPort = 0; // 0 = HTTPS disabled
+ std::string CertThumbprint; // Hex SHA-1 (40 chars) for auto SSL binding
+ std::string CertStoreName = "MY"; // Windows certificate store name
+ bool HttpsOnly = false; // When true, disable HTTP listener
} HttpSys;
};
diff --git a/src/zenhttp/include/zenhttp/httpwsclient.h b/src/zenhttp/include/zenhttp/httpwsclient.h
index 926ec1e3d..34d338b1d 100644
--- a/src/zenhttp/include/zenhttp/httpwsclient.h
+++ b/src/zenhttp/include/zenhttp/httpwsclient.h
@@ -43,6 +43,10 @@ struct HttpWsClientSettings
std::string LogCategory = "wsclient";
std::chrono::milliseconds ConnectTimeout{5000};
std::optional<std::function<HttpClientAccessToken()>> AccessTokenProvider;
+
+ /// Unix domain socket path. When non-empty, connects via this socket
+ /// instead of TCP. The URL host is still used for the Host header.
+ std::string UnixSocketPath;
};
/**
diff --git a/src/zenhttp/servers/asio_socket_traits.h b/src/zenhttp/servers/asio_socket_traits.h
new file mode 100644
index 000000000..25aeaa24e
--- /dev/null
+++ b/src/zenhttp/servers/asio_socket_traits.h
@@ -0,0 +1,54 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <asio.hpp>
+#if ZEN_USE_OPENSSL
+# include <asio/ssl.hpp>
+#endif
+ZEN_THIRD_PARTY_INCLUDES_END
+
+namespace zen::asio_http {
+
+/**
+ * Traits for abstracting socket shutdown/close across plain TCP, Unix domain, and SSL sockets.
+ * SSL sockets need lowest_layer() access and have different shutdown semantics.
+ */
+template<typename SocketType>
+struct SocketTraits
+{
+ /// SSL sockets cannot use zero-copy file send (TransmitFile/sendfile) because
+ /// those bypass the encryption layer. This flag lets templated code fall back
+ /// to reading-into-memory for SSL connections.
+ static constexpr bool IsSslSocket = false;
+
+ static void ShutdownReceive(SocketType& S, std::error_code& Ec) { S.shutdown(asio::socket_base::shutdown_receive, Ec); }
+
+ static void ShutdownBoth(SocketType& S, std::error_code& Ec) { S.shutdown(asio::socket_base::shutdown_both, Ec); }
+
+ static void Close(SocketType& S, std::error_code& Ec) { S.close(Ec); }
+};
+
+#if ZEN_USE_OPENSSL
+using SslSocket = asio::ssl::stream<asio::ip::tcp::socket>;
+
+template<>
+struct SocketTraits<SslSocket>
+{
+ static constexpr bool IsSslSocket = true;
+
+ static void ShutdownReceive(SslSocket& S, std::error_code& Ec) { S.lowest_layer().shutdown(asio::socket_base::shutdown_receive, Ec); }
+
+ static void ShutdownBoth(SslSocket& S, std::error_code& Ec)
+ {
+ // Best-effort SSL close_notify, then TCP shutdown
+ S.shutdown(Ec);
+ S.lowest_layer().shutdown(asio::socket_base::shutdown_both, Ec);
+ }
+
+ static void Close(SslSocket& S, std::error_code& Ec) { S.lowest_layer().close(Ec); }
+};
+#endif
+
+} // namespace zen::asio_http
diff --git a/src/zenhttp/servers/httpasio.cpp b/src/zenhttp/servers/httpasio.cpp
index f5178ebe8..ee8e71256 100644
--- a/src/zenhttp/servers/httpasio.cpp
+++ b/src/zenhttp/servers/httpasio.cpp
@@ -1,6 +1,7 @@
// Copyright Epic Games, Inc. All Rights Reserved.
#include "httpasio.h"
+#include "asio_socket_traits.h"
#include "httptracer.h"
#include <zencore/except.h>
@@ -35,6 +36,12 @@ ZEN_THIRD_PARTY_INCLUDES_START
#endif
#include <asio.hpp>
#include <asio/stream_file.hpp>
+#if defined(ASIO_HAS_LOCAL_SOCKETS)
+# include <asio/local/stream_protocol.hpp>
+#endif
+#if ZEN_USE_OPENSSL
+# include <asio/ssl.hpp>
+#endif
ZEN_THIRD_PARTY_INCLUDES_END
#define ASIO_VERBOSE_TRACE 0
@@ -144,7 +151,17 @@ using namespace std::literals;
struct HttpAcceptor;
struct HttpResponse;
-struct HttpServerConnection;
+template<typename SocketType>
+struct HttpServerConnectionT;
+using HttpServerConnection = HttpServerConnectionT<asio::ip::tcp::socket>;
+#if defined(ASIO_HAS_LOCAL_SOCKETS)
+struct UnixAcceptor;
+using UnixServerConnection = HttpServerConnectionT<asio::local::stream_protocol::socket>;
+#endif
+#if ZEN_USE_OPENSSL
+struct HttpsAcceptor;
+using HttpsSslServerConnection = HttpServerConnectionT<SslSocket>;
+#endif
inline LoggerRef
InitLogger()
@@ -176,9 +193,9 @@ Log()
#endif
#if ZEN_USE_TRANSMITFILE
-template<typename Handler>
+template<typename Handler, typename SocketType>
void
-TransmitFileAsync(asio::ip::tcp::socket& Socket, HANDLE FileHandle, uint64_t ByteOffset, uint32_t ByteSize, Handler&& Cb)
+TransmitFileAsync(SocketType& Socket, HANDLE FileHandle, uint64_t ByteOffset, uint32_t ByteSize, Handler&& Cb)
{
# if ZEN_BUILD_DEBUG
const uint64_t FileSize = FileSizeFromHandle(FileHandle);
@@ -511,11 +528,20 @@ public:
bool IsLoopbackOnly() const;
+ int GetEffectiveHttpsPort() const;
+
asio::io_service m_IoService;
asio::io_service::work m_Work{m_IoService};
std::unique_ptr<asio_http::HttpAcceptor> m_Acceptor;
- std::vector<std::thread> m_ThreadPool;
- std::atomic<IHttpRequestFilter*> m_HttpRequestFilter = nullptr;
+#if defined(ASIO_HAS_LOCAL_SOCKETS)
+ std::unique_ptr<asio_http::UnixAcceptor> m_UnixAcceptor;
+#endif
+#if ZEN_USE_OPENSSL
+ std::unique_ptr<asio::ssl::context> m_SslContext;
+ std::unique_ptr<asio_http::HttpsAcceptor> m_HttpsAcceptor;
+#endif
+ std::vector<std::thread> m_ThreadPool;
+ std::atomic<IHttpRequestFilter*> m_HttpRequestFilter = nullptr;
LoggerRef m_RequestLog;
HttpServerTracer m_RequestTracer;
@@ -573,6 +599,7 @@ public:
uint32_t m_RequestNumber = 0; // Note: different to request ID which is derived from headers
IoBuffer m_PayloadBuffer;
bool m_IsLocalMachineRequest;
+ bool m_AllowZeroCopyFileSend = true;
std::string m_RemoteAddress;
std::unique_ptr<HttpResponse> m_Response;
};
@@ -595,6 +622,8 @@ public:
~HttpResponse() = default;
+ void SetAllowZeroCopyFileSend(bool Allow) { m_AllowZeroCopyFileSend = Allow; }
+
/**
* Initialize the response for sending a payload made up of multiple blobs
*
@@ -636,7 +665,7 @@ public:
bool ChunkHandled = false;
#if ZEN_USE_TRANSMITFILE || ZEN_USE_ASYNC_SENDFILE
- if (OwnedBuffer.IsWholeFile())
+ if (m_AllowZeroCopyFileSend && OwnedBuffer.IsWholeFile())
{
if (IoBufferFileReference FileRef; OwnedBuffer.GetFileReference(/* out */ FileRef))
{
@@ -751,7 +780,8 @@ public:
return m_Headers;
}
- void SendResponse(asio::ip::tcp::socket& TcpSocket, std::function<void(const asio::error_code& Ec, std::size_t ByteCount)>&& Token)
+ template<typename SocketType>
+ void SendResponse(SocketType& Socket, std::function<void(const asio::error_code& Ec, std::size_t ByteCount)>&& Token)
{
ZEN_ASSERT(m_State == State::kInitialized);
@@ -761,10 +791,11 @@ public:
m_SendCb = std::move(Token);
m_State = State::kSending;
- SendNextChunk(TcpSocket);
+ SendNextChunk(Socket);
}
- void SendNextChunk(asio::ip::tcp::socket& TcpSocket)
+ template<typename SocketType>
+ void SendNextChunk(SocketType& Socket)
{
ZEN_ASSERT(m_State == State::kSending);
@@ -781,12 +812,12 @@ public:
auto CompletionToken = [Self = this, Token = std::move(m_SendCb), TotalBytes = m_TotalBytesSent] { Token({}, TotalBytes); };
- asio::defer(TcpSocket.get_executor(), std::move(CompletionToken));
+ asio::defer(Socket.get_executor(), std::move(CompletionToken));
return;
}
- auto OnCompletion = [this, &TcpSocket](const asio::error_code& Ec, std::size_t ByteCount) {
+ auto OnCompletion = [this, &Socket](const asio::error_code& Ec, std::size_t ByteCount) {
ZEN_ASSERT(m_State == State::kSending);
m_TotalBytesSent += ByteCount;
@@ -797,7 +828,7 @@ public:
}
else
{
- SendNextChunk(TcpSocket);
+ SendNextChunk(Socket);
}
};
@@ -811,25 +842,21 @@ public:
Io.Ref.FileRef.FileChunkSize);
#if ZEN_USE_TRANSMITFILE
- TransmitFileAsync(TcpSocket,
+ TransmitFileAsync(Socket,
Io.Ref.FileRef.FileHandle,
Io.Ref.FileRef.FileChunkOffset,
gsl::narrow_cast<uint32_t>(Io.Ref.FileRef.FileChunkSize),
OnCompletion);
+ return;
#elif ZEN_USE_ASYNC_SENDFILE
- SendFileAsync(TcpSocket,
+ SendFileAsync(Socket,
Io.Ref.FileRef.FileHandle,
Io.Ref.FileRef.FileChunkOffset,
Io.Ref.FileRef.FileChunkSize,
64 * 1024,
OnCompletion);
-#else
- // This should never occur unless we compile with one
- // of the options above
- ZEN_WARN("invalid file reference in response");
-#endif
-
return;
+#endif
}
// Send as many consecutive non-file references as possible in one asio operation
@@ -850,7 +877,7 @@ public:
++m_IoVecCursor;
}
- asio::async_write(TcpSocket, std::move(AsioBuffers), asio::transfer_all(), OnCompletion);
+ asio::async_write(Socket, std::move(AsioBuffers), asio::transfer_all(), OnCompletion);
}
private:
@@ -863,12 +890,13 @@ private:
kFailed
};
- uint32_t m_RequestNumber = 0;
- uint16_t m_ResponseCode = 0;
- bool m_IsKeepAlive = true;
- State m_State = State::kUninitialized;
- HttpContentType m_ContentType = HttpContentType::kBinary;
- uint64_t m_ContentLength = 0;
+ uint32_t m_RequestNumber = 0;
+ uint16_t m_ResponseCode = 0;
+ bool m_IsKeepAlive = true;
+ bool m_AllowZeroCopyFileSend = true;
+ State m_State = State::kUninitialized;
+ HttpContentType m_ContentType = HttpContentType::kBinary;
+ uint64_t m_ContentLength = 0;
eastl::fixed_vector<IoBuffer, 8> m_DataBuffers; // This is here to keep the IoBuffer buffers/handles alive
ExtendableStringBuilder<160> m_Headers;
@@ -895,12 +923,13 @@ private:
//////////////////////////////////////////////////////////////////////////
-struct HttpServerConnection : public HttpRequestParserCallbacks, std::enable_shared_from_this<HttpServerConnection>
+template<typename SocketType>
+struct HttpServerConnectionT : public HttpRequestParserCallbacks, std::enable_shared_from_this<HttpServerConnectionT<SocketType>>
{
- HttpServerConnection(HttpAsioServerImpl& Server, std::unique_ptr<asio::ip::tcp::socket>&& Socket);
- ~HttpServerConnection();
+ HttpServerConnectionT(HttpAsioServerImpl& Server, std::unique_ptr<SocketType>&& Socket);
+ ~HttpServerConnectionT();
- std::shared_ptr<HttpServerConnection> AsSharedPtr() { return shared_from_this(); }
+ std::shared_ptr<HttpServerConnectionT> AsSharedPtr() { return this->shared_from_this(); }
// HttpConnectionBase implementation
@@ -962,12 +991,13 @@ private:
RwLock m_ActiveResponsesLock;
std::deque<std::unique_ptr<HttpResponse>> m_ActiveResponses;
- std::unique_ptr<asio::ip::tcp::socket> m_Socket;
+ std::unique_ptr<SocketType> m_Socket;
};
std::atomic<uint32_t> g_ConnectionIdCounter{0};
-HttpServerConnection::HttpServerConnection(HttpAsioServerImpl& Server, std::unique_ptr<asio::ip::tcp::socket>&& Socket)
+template<typename SocketType>
+HttpServerConnectionT<SocketType>::HttpServerConnectionT(HttpAsioServerImpl& Server, std::unique_ptr<SocketType>&& Socket)
: m_Server(Server)
, m_ConnectionId(g_ConnectionIdCounter.fetch_add(1))
, m_Socket(std::move(Socket))
@@ -975,21 +1005,24 @@ HttpServerConnection::HttpServerConnection(HttpAsioServerImpl& Server, std::uniq
ZEN_TRACE_VERBOSE("new connection #{}", m_ConnectionId);
}
-HttpServerConnection::~HttpServerConnection()
+template<typename SocketType>
+HttpServerConnectionT<SocketType>::~HttpServerConnectionT()
{
RwLock::ExclusiveLockScope _(m_ActiveResponsesLock);
ZEN_TRACE_VERBOSE("destroying connection #{}", m_ConnectionId);
}
+template<typename SocketType>
void
-HttpServerConnection::HandleNewRequest()
+HttpServerConnectionT<SocketType>::HandleNewRequest()
{
EnqueueRead();
}
+template<typename SocketType>
void
-HttpServerConnection::TerminateConnection()
+HttpServerConnectionT<SocketType>::TerminateConnection()
{
if (m_RequestState == RequestState::kDone || m_RequestState == RequestState::kTerminated)
{
@@ -1001,12 +1034,13 @@ HttpServerConnection::TerminateConnection()
// Terminating, we don't care about any errors when closing socket
std::error_code Ec;
- m_Socket->shutdown(asio::socket_base::shutdown_both, Ec);
- m_Socket->close(Ec);
+ SocketTraits<SocketType>::ShutdownBoth(*m_Socket, Ec);
+ SocketTraits<SocketType>::Close(*m_Socket, Ec);
}
+template<typename SocketType>
void
-HttpServerConnection::EnqueueRead()
+HttpServerConnectionT<SocketType>::EnqueueRead()
{
ZEN_MEMSCOPE(GetHttpasioTag());
@@ -1027,8 +1061,9 @@ HttpServerConnection::EnqueueRead()
[Conn = AsSharedPtr()](const asio::error_code& Ec, std::size_t ByteCount) { Conn->OnDataReceived(Ec, ByteCount); });
}
+template<typename SocketType>
void
-HttpServerConnection::OnDataReceived(const asio::error_code& Ec, [[maybe_unused]] std::size_t ByteCount)
+HttpServerConnectionT<SocketType>::OnDataReceived(const asio::error_code& Ec, [[maybe_unused]] std::size_t ByteCount)
{
ZEN_MEMSCOPE(GetHttpasioTag());
@@ -1086,11 +1121,12 @@ HttpServerConnection::OnDataReceived(const asio::error_code& Ec, [[maybe_unused]
}
}
+template<typename SocketType>
void
-HttpServerConnection::OnResponseDataSent(const asio::error_code& Ec,
- [[maybe_unused]] std::size_t ByteCount,
- [[maybe_unused]] uint32_t RequestNumber,
- HttpResponse* ResponseToPop)
+HttpServerConnectionT<SocketType>::OnResponseDataSent(const asio::error_code& Ec,
+ [[maybe_unused]] std::size_t ByteCount,
+ [[maybe_unused]] uint32_t RequestNumber,
+ HttpResponse* ResponseToPop)
{
ZEN_MEMSCOPE(GetHttpasioTag());
@@ -1144,8 +1180,9 @@ HttpServerConnection::OnResponseDataSent(const asio::error_code& Ec,
}
}
+template<typename SocketType>
void
-HttpServerConnection::CloseConnection()
+HttpServerConnectionT<SocketType>::CloseConnection()
{
ZEN_MEMSCOPE(GetHttpasioTag());
@@ -1157,23 +1194,24 @@ HttpServerConnection::CloseConnection()
m_RequestState = RequestState::kDone;
std::error_code Ec;
- m_Socket->shutdown(asio::socket_base::shutdown_receive, Ec);
+ SocketTraits<SocketType>::ShutdownReceive(*m_Socket, Ec);
if (Ec)
{
ZEN_WARN("socket shutdown ERROR, reason '{}'", Ec.message());
}
- m_Socket->close(Ec);
+ SocketTraits<SocketType>::Close(*m_Socket, Ec);
if (Ec)
{
ZEN_WARN("socket close ERROR, reason '{}'", Ec.message());
}
}
+template<typename SocketType>
void
-HttpServerConnection::SendInlineResponse(uint32_t RequestNumber,
- std::string_view StatusLine,
- std::string_view Headers,
- std::string_view Body)
+HttpServerConnectionT<SocketType>::SendInlineResponse(uint32_t RequestNumber,
+ std::string_view StatusLine,
+ std::string_view Headers,
+ std::string_view Body)
{
ExtendableStringBuilder<256> ResponseBuilder;
ResponseBuilder << "HTTP/1.1 " << StatusLine << "\r\n";
@@ -1194,15 +1232,16 @@ HttpServerConnection::SendInlineResponse(uint32_t RequestNumber,
IoBuffer ResponseData(IoBuffer::Clone, ResponseView.data(), ResponseView.size());
auto Buffer = asio::buffer(ResponseData.GetData(), ResponseData.GetSize());
asio::async_write(
- *m_Socket.get(),
+ *m_Socket,
Buffer,
[Conn = AsSharedPtr(), RequestNumber, Response = std::move(ResponseData)](const asio::error_code& Ec, std::size_t ByteCount) {
Conn->OnResponseDataSent(Ec, ByteCount, RequestNumber, /* ResponseToPop */ nullptr);
});
}
+template<typename SocketType>
void
-HttpServerConnection::HandleRequest()
+HttpServerConnectionT<SocketType>::HandleRequest()
{
ZEN_MEMSCOPE(GetHttpasioTag());
@@ -1229,24 +1268,25 @@ HttpServerConnection::HandleRequest()
ResponseStr->append("\r\n\r\n");
// Send the 101 response on the current socket, then hand the socket off
- // to a WsAsioConnection for the WebSocket protocol.
- asio::async_write(*m_Socket,
- asio::buffer(ResponseStr->data(), ResponseStr->size()),
- [Conn = AsSharedPtr(), WsHandler, OwnedResponse = ResponseStr](const asio::error_code& Ec, std::size_t) {
- if (Ec)
- {
- ZEN_WARN("WebSocket 101 send failed: {}", Ec.message());
- return;
- }
-
- Conn->m_Server.m_HttpServer->OnWebSocketConnectionOpened();
- Ref<WsAsioConnection> WsConn(
- new WsAsioConnection(std::move(Conn->m_Socket), *WsHandler, Conn->m_Server.m_HttpServer));
- Ref<WebSocketConnection> WsConnRef(WsConn.Get());
-
- WsHandler->OnWebSocketOpen(std::move(WsConnRef));
- WsConn->Start();
- });
+ // to a WsAsioConnectionT for the WebSocket protocol.
+ asio::async_write(
+ *m_Socket,
+ asio::buffer(ResponseStr->data(), ResponseStr->size()),
+ [Conn = AsSharedPtr(), WsHandler, OwnedResponse = ResponseStr](const asio::error_code& Ec, std::size_t) {
+ if (Ec)
+ {
+ ZEN_WARN("WebSocket 101 send failed: {}", Ec.message());
+ return;
+ }
+
+ Conn->m_Server.m_HttpServer->OnWebSocketConnectionOpened();
+ using WsConnType = WsAsioConnectionT<SocketType>;
+ Ref<WsConnType> WsConn(new WsConnType(std::move(Conn->m_Socket), *WsHandler, Conn->m_Server.m_HttpServer));
+ Ref<WebSocketConnection> WsConnRef(WsConn.Get());
+
+ WsHandler->OnWebSocketOpen(std::move(WsConnRef));
+ WsConn->Start();
+ });
m_RequestState = RequestState::kDone;
return;
@@ -1260,7 +1300,7 @@ HttpServerConnection::HandleRequest()
m_RequestState = RequestState::kWritingFinal;
std::error_code Ec;
- m_Socket->shutdown(asio::socket_base::shutdown_receive, Ec);
+ SocketTraits<SocketType>::ShutdownReceive(*m_Socket, Ec);
if (Ec)
{
@@ -1280,15 +1320,36 @@ HttpServerConnection::HandleRequest()
m_Server.m_HttpServer->MarkRequest();
- auto RemoteEndpoint = m_Socket->remote_endpoint();
- bool IsLocalConnection = m_Socket->local_endpoint().address() == RemoteEndpoint.address();
+ bool IsLocalConnection = true;
+ std::string RemoteAddress;
+
+ if constexpr (std::is_same_v<SocketType, asio::ip::tcp::socket>)
+ {
+ auto RemoteEndpoint = m_Socket->remote_endpoint();
+ IsLocalConnection = m_Socket->local_endpoint().address() == RemoteEndpoint.address();
+ RemoteAddress = RemoteEndpoint.address().to_string();
+ }
+#if ZEN_USE_OPENSSL
+ else if constexpr (std::is_same_v<SocketType, SslSocket>)
+ {
+ auto RemoteEndpoint = m_Socket->lowest_layer().remote_endpoint();
+ IsLocalConnection = m_Socket->lowest_layer().local_endpoint().address() == RemoteEndpoint.address();
+ RemoteAddress = RemoteEndpoint.address().to_string();
+ }
+#endif
+ else
+ {
+ RemoteAddress = "unix";
+ }
HttpAsioServerRequest Request(m_RequestData,
*Service,
m_RequestData.Body(),
RequestNumber,
IsLocalConnection,
- RemoteEndpoint.address().to_string());
+ std::move(RemoteAddress));
+
+ Request.m_AllowZeroCopyFileSend = !SocketTraits<SocketType>::IsSslSocket;
ZEN_TRACE_VERBOSE("handle request, connection: {}, request: {}'", m_ConnectionId, RequestNumber);
@@ -1439,14 +1500,23 @@ HttpServerConnection::HandleRequest()
}
//////////////////////////////////////////////////////////////////////////
+// Base class for TCP acceptors that handles socket setup, port binding
+// with probing/retry, and dual-stack (IPv6+IPv4 loopback) support.
+// Subclasses only need to implement OnAccept() to handle new connections.
-struct HttpAcceptor
+struct TcpAcceptorBase
{
- HttpAcceptor(HttpAsioServerImpl& Server, asio::io_service& IoService, uint16_t BasePort, bool ForceLoopback, bool AllowPortProbing)
+ TcpAcceptorBase(HttpAsioServerImpl& Server,
+ asio::io_service& IoService,
+ uint16_t BasePort,
+ bool ForceLoopback,
+ bool AllowPortProbing,
+ std::string_view Label)
: m_Server(Server)
, m_IoService(IoService)
, m_Acceptor(m_IoService, asio::ip::tcp::v6())
, m_AlternateProtocolAcceptor(m_IoService, asio::ip::tcp::v4())
+ , m_Label(Label)
{
const bool IsUsingIPv6 = IsIPv6Capable();
if (!IsUsingIPv6)
@@ -1455,7 +1525,6 @@ struct HttpAcceptor
}
#if ZEN_PLATFORM_WINDOWS
- // Special option for Windows settings as !asio::socket_base::reuse_address is not the same as exclusive access on Windows platforms
typedef asio::detail::socket_option::boolean<ASIO_OS_DEF(SOL_SOCKET), SO_EXCLUSIVEADDRUSE> exclusive_address;
m_Acceptor.set_option(exclusive_address(true));
m_AlternateProtocolAcceptor.set_option(exclusive_address(true));
@@ -1468,83 +1537,54 @@ struct HttpAcceptor
#endif // ZEN_PLATFORM_WINDOWS
m_Acceptor.set_option(asio::ip::tcp::no_delay(true));
- m_Acceptor.set_option(asio::socket_base::receive_buffer_size(128 * 1024));
- m_Acceptor.set_option(asio::socket_base::send_buffer_size(256 * 1024));
-
m_AlternateProtocolAcceptor.set_option(asio::ip::tcp::no_delay(true));
- m_AlternateProtocolAcceptor.set_option(asio::socket_base::receive_buffer_size(128 * 1024));
- m_AlternateProtocolAcceptor.set_option(asio::socket_base::send_buffer_size(256 * 1024));
-
- std::string BoundBaseUrl;
if (IsUsingIPv6)
{
- BoundBaseUrl = BindAcceptor<asio::ip::address_v6>(BasePort, ForceLoopback, AllowPortProbing);
+ BindAcceptor<asio::ip::address_v6>(BasePort, ForceLoopback, AllowPortProbing);
}
else
{
- ZEN_INFO("NOTE: ipv6 support is disabled, binding to ipv4 only");
-
- BoundBaseUrl = BindAcceptor<asio::ip::address_v4>(BasePort, ForceLoopback, AllowPortProbing);
+ ZEN_INFO("{}: ipv6 support is disabled, binding to ipv4 only", m_Label);
+ BindAcceptor<asio::ip::address_v4>(BasePort, ForceLoopback, AllowPortProbing);
}
+ }
- if (!IsValid())
- {
- return;
- }
-
-#if ZEN_PLATFORM_WINDOWS
- // On Windows, loopback connections can take advantage of a faster code path optionally with this flag.
- // This must be used by both the client and server side, and is only effective in the absence of
- // Windows Filtering Platform (WFP) callouts which can be installed by security software.
- // https://docs.microsoft.com/en-us/windows/win32/winsock/sio-loopback-fast-path
- SOCKET NativeSocket = m_Acceptor.native_handle();
- int LoopbackOptionValue = 1;
- DWORD OptionNumberOfBytesReturned = 0;
- WSAIoctl(NativeSocket,
- SIO_LOOPBACK_FAST_PATH,
- &LoopbackOptionValue,
- sizeof(LoopbackOptionValue),
- NULL,
- 0,
- &OptionNumberOfBytesReturned,
- 0,
- 0);
-
- if (m_UseAlternateProtocolAcceptor)
- {
- NativeSocket = m_AlternateProtocolAcceptor.native_handle();
- WSAIoctl(NativeSocket,
- SIO_LOOPBACK_FAST_PATH,
- &LoopbackOptionValue,
- sizeof(LoopbackOptionValue),
- NULL,
- 0,
- &OptionNumberOfBytesReturned,
- 0,
- 0);
- }
-#endif
- m_Acceptor.listen();
+ virtual ~TcpAcceptorBase()
+ {
+ m_Acceptor.close();
if (m_UseAlternateProtocolAcceptor)
{
- m_AlternateProtocolAcceptor.listen();
+ m_AlternateProtocolAcceptor.close();
}
-
- ZEN_INFO("Started asio server at '{}", BoundBaseUrl);
}
- ~HttpAcceptor()
+ void Start()
{
- m_Acceptor.close();
+ ZEN_ASSERT(!m_IsStopped);
+ InitAcceptLoop(m_Acceptor);
if (m_UseAlternateProtocolAcceptor)
{
- m_AlternateProtocolAcceptor.close();
+ InitAcceptLoop(m_AlternateProtocolAcceptor);
}
}
+ void StopAccepting() { m_IsStopped = true; }
+
+ uint16_t GetPort() const { return m_Acceptor.local_endpoint().port(); }
+ bool IsLoopbackOnly() const { return m_Acceptor.local_endpoint().address().is_loopback(); }
+ bool IsValid() const { return m_IsValid; }
+
+protected:
+ /// Called for each accepted TCP socket. Subclasses create the appropriate connection type.
+ virtual void OnAccept(std::unique_ptr<asio::ip::tcp::socket> Socket) = 0;
+
+ HttpAsioServerImpl& m_Server;
+ asio::io_service& m_IoService;
+
+private:
template<typename AddressType>
- std::string BindAcceptor(uint16_t BasePort, bool ForceLoopback, bool AllowPortProbing)
+ void BindAcceptor(uint16_t BasePort, bool ForceLoopback, bool AllowPortProbing)
{
uint16_t EffectivePort = BasePort;
@@ -1571,7 +1611,7 @@ struct HttpAcceptor
if (BindErrorCode == asio::error::access_denied && !BindAddress.is_loopback())
{
- ZEN_INFO("Access denied for public port {}, falling back to loopback", BasePort);
+ ZEN_INFO("{}: Access denied for public port {}, falling back to loopback", m_Label, BasePort);
BindAddress = AddressType::loopback();
@@ -1585,7 +1625,7 @@ struct HttpAcceptor
if (BindErrorCode == asio::error::address_in_use)
{
- ZEN_INFO("Desired port {} is in use (bind returned '{}'), retrying", EffectivePort, BindErrorCode.message());
+ ZEN_INFO("{}: Desired port {} is in use (bind returned '{}'), retrying", m_Label, EffectivePort, BindErrorCode.message());
Sleep(500);
m_Acceptor.bind(asio::ip::tcp::endpoint(BindAddress, EffectivePort), BindErrorCode);
}
@@ -1601,7 +1641,8 @@ struct HttpAcceptor
if (BindErrorCode)
{
- ZEN_INFO("Unable to bind to preferred port range, falling back to automatic assignment (bind returned '{}')",
+ ZEN_INFO("{}: Unable to bind to preferred port range, falling back to automatic assignment (bind returned '{}')",
+ m_Label,
BindErrorCode.message());
EffectivePort = 0;
@@ -1617,7 +1658,7 @@ struct HttpAcceptor
{
for (uint32_t Retries = 0; (BindErrorCode == asio::error::address_in_use) && (Retries < 3); Retries++)
{
- ZEN_INFO("Desired port {} is in use (bind returned '{}'), retrying", EffectivePort, BindErrorCode.message());
+ ZEN_INFO("{}: Desired port {} is in use (bind returned '{}'), retrying", m_Label, EffectivePort, BindErrorCode.message());
Sleep(500);
m_Acceptor.bind(asio::ip::tcp::endpoint(BindAddress, EffectivePort), BindErrorCode);
}
@@ -1625,14 +1666,13 @@ struct HttpAcceptor
if (BindErrorCode)
{
- ZEN_WARN("Unable to initialize asio service, (bind returned '{}')", BindErrorCode.message());
-
- return {};
+ ZEN_WARN("{}: Unable to bind on port {} (bind returned '{}')", m_Label, BasePort, BindErrorCode.message());
+ return;
}
if (EffectivePort != BasePort)
{
- ZEN_WARN("Desired port {} is in use, remapped to port {}", BasePort, EffectivePort);
+ ZEN_WARN("{}: Desired port {} is in use, remapped to port {}", m_Label, BasePort, EffectivePort);
}
if constexpr (std::is_same_v<asio::ip::address_v6, AddressType>)
@@ -1642,55 +1682,64 @@ struct HttpAcceptor
// IPv6 loopback will only respond on the IPv6 loopback address. Not everyone does
// IPv6 though so we also bind to IPv4 loopback (localhost/127.0.0.1)
- m_AlternateProtocolAcceptor.bind(asio::ip::tcp::endpoint(asio::ip::address_v4::loopback(), EffectivePort), BindErrorCode);
+ asio::error_code AltEc;
+ m_AlternateProtocolAcceptor.bind(asio::ip::tcp::endpoint(asio::ip::address_v4::loopback(), EffectivePort), AltEc);
- if (BindErrorCode)
+ if (AltEc)
{
- ZEN_WARN("Failed to register secondary IPv4 local-only handler 'http://{}:{}/'", "localhost", EffectivePort);
+ ZEN_WARN("{}: Failed to register secondary IPv4 local-only handler on port {}", m_Label, EffectivePort);
}
else
{
m_UseAlternateProtocolAcceptor = true;
- ZEN_INFO("Registered local-only handler 'http://{}:{}/' - this is not accessible from remote hosts",
- "localhost",
- EffectivePort);
}
}
}
- m_IsValid = true;
+#if ZEN_PLATFORM_WINDOWS
+ // On Windows, loopback connections can take advantage of a faster code path optionally with this flag.
+ // This must be used by both the client and server side, and is only effective in the absence of
+ // Windows Filtering Platform (WFP) callouts which can be installed by security software.
+ // https://docs.microsoft.com/en-us/windows/win32/winsock/sio-loopback-fast-path
+ SOCKET NativeSocket = m_Acceptor.native_handle();
+ int LoopbackOptionValue = 1;
+ DWORD OptionNumberOfBytesReturned = 0;
+ WSAIoctl(NativeSocket,
+ SIO_LOOPBACK_FAST_PATH,
+ &LoopbackOptionValue,
+ sizeof(LoopbackOptionValue),
+ NULL,
+ 0,
+ &OptionNumberOfBytesReturned,
+ 0,
+ 0);
- if constexpr (std::is_same_v<asio::ip::address_v6, AddressType>)
- {
- return fmt::format("http://{}:{}'", BindAddress.is_loopback() ? "[::1]" : "*", EffectivePort);
- }
- else
+ if (m_UseAlternateProtocolAcceptor)
{
- return fmt::format("http://{}:{}'", BindAddress.is_loopback() ? "127.0.0.1" : "*", EffectivePort);
+ NativeSocket = m_AlternateProtocolAcceptor.native_handle();
+ WSAIoctl(NativeSocket,
+ SIO_LOOPBACK_FAST_PATH,
+ &LoopbackOptionValue,
+ sizeof(LoopbackOptionValue),
+ NULL,
+ 0,
+ &OptionNumberOfBytesReturned,
+ 0,
+ 0);
}
- }
-
- void Start()
- {
- ZEN_MEMSCOPE(GetHttpasioTag());
+#endif
- ZEN_ASSERT(!m_IsStopped);
- InitAcceptInternal(m_Acceptor);
+ m_Acceptor.listen();
if (m_UseAlternateProtocolAcceptor)
{
- InitAcceptInternal(m_AlternateProtocolAcceptor);
+ m_AlternateProtocolAcceptor.listen();
}
- }
- void StopAccepting() { m_IsStopped = true; }
-
- int GetAcceptPort() const { return m_Acceptor.local_endpoint().port(); }
- bool IsLoopbackOnly() const { return m_Acceptor.local_endpoint().address().is_loopback(); }
-
- bool IsValid() const { return m_IsValid; }
+ m_IsValid = true;
+ ZEN_INFO("{}: Listening on port {}", m_Label, m_Acceptor.local_endpoint().port());
+ }
-private:
- void InitAcceptInternal(asio::ip::tcp::acceptor& Acceptor)
+ void InitAcceptLoop(asio::ip::tcp::acceptor& Acceptor)
{
auto SocketPtr = std::make_unique<asio::ip::tcp::socket>(m_IoService);
asio::ip::tcp::socket& SocketRef = *SocketPtr.get();
@@ -1698,29 +1747,19 @@ private:
Acceptor.async_accept(SocketRef, [this, &Acceptor, Socket = std::move(SocketPtr)](const asio::error_code& Ec) mutable {
if (Ec)
{
- ZEN_WARN("asio async_accept, connection failed to '{}:{}' reason '{}'",
- Acceptor.local_endpoint().address().to_string(),
- Acceptor.local_endpoint().port(),
- Ec.message());
+ if (!m_IsStopped.load())
+ {
+ ZEN_WARN("{}: async_accept failed: '{}'", m_Label, Ec.message());
+ }
}
else
{
- // New connection established, pass socket ownership into connection object
- // and initiate request handling loop. The connection lifetime is
- // managed by the async read/write loop by passing the shared
- // reference to the callbacks.
-
- Socket->set_option(asio::ip::tcp::no_delay(true));
- Socket->set_option(asio::socket_base::receive_buffer_size(128 * 1024));
- Socket->set_option(asio::socket_base::send_buffer_size(256 * 1024));
-
- auto Conn = std::make_shared<HttpServerConnection>(m_Server, std::move(Socket));
- Conn->HandleNewRequest();
+ OnAccept(std::move(Socket));
}
if (!m_IsStopped.load())
{
- InitAcceptInternal(Acceptor);
+ InitAcceptLoop(Acceptor);
}
else
{
@@ -1728,21 +1767,204 @@ private:
Acceptor.close(CloseEc);
if (CloseEc)
{
- ZEN_WARN("acceptor close ERROR, reason '{}'", CloseEc.message());
+ ZEN_WARN("{}: acceptor close error: '{}'", m_Label, CloseEc.message());
}
}
});
}
- HttpAsioServerImpl& m_Server;
- asio::io_service& m_IoService;
asio::ip::tcp::acceptor m_Acceptor;
asio::ip::tcp::acceptor m_AlternateProtocolAcceptor;
bool m_UseAlternateProtocolAcceptor{false};
bool m_IsValid{false};
std::atomic<bool> m_IsStopped{false};
+ std::string_view m_Label;
+};
+
+//////////////////////////////////////////////////////////////////////////
+
+struct HttpAcceptor final : TcpAcceptorBase
+{
+ HttpAcceptor(HttpAsioServerImpl& Server, asio::io_service& IoService, uint16_t BasePort, bool ForceLoopback, bool AllowPortProbing)
+ : TcpAcceptorBase(Server, IoService, BasePort, ForceLoopback, AllowPortProbing, "HTTP")
+ {
+ }
+
+ int GetAcceptPort() const { return GetPort(); }
+
+protected:
+ void OnAccept(std::unique_ptr<asio::ip::tcp::socket> Socket) override
+ {
+ Socket->set_option(asio::ip::tcp::no_delay(true));
+ Socket->set_option(asio::socket_base::receive_buffer_size(128 * 1024));
+ Socket->set_option(asio::socket_base::send_buffer_size(256 * 1024));
+
+ auto Conn = std::make_shared<HttpServerConnection>(m_Server, std::move(Socket));
+ Conn->HandleNewRequest();
+ }
};
+#if defined(ASIO_HAS_LOCAL_SOCKETS)
+
+//////////////////////////////////////////////////////////////////////////
+
+struct UnixAcceptor
+{
+ UnixAcceptor(HttpAsioServerImpl& Server, asio::io_service& IoService, const std::string& SocketPath)
+ : m_Server(Server)
+ , m_IoService(IoService)
+ , m_Acceptor(m_IoService)
+ , m_SocketPath(SocketPath)
+ {
+ // Remove any stale socket file from a previous run
+ std::filesystem::remove(m_SocketPath);
+
+ asio::local::stream_protocol::endpoint Endpoint(m_SocketPath);
+
+ asio::error_code Ec;
+ m_Acceptor.open(Endpoint.protocol(), Ec);
+ if (Ec)
+ {
+ ZEN_WARN("failed to open unix domain socket: {}", Ec.message());
+ return;
+ }
+
+ m_Acceptor.bind(Endpoint, Ec);
+ if (Ec)
+ {
+ ZEN_WARN("failed to bind unix domain socket at '{}': {}", m_SocketPath, Ec.message());
+ return;
+ }
+
+ m_Acceptor.listen(asio::socket_base::max_listen_connections, Ec);
+ if (Ec)
+ {
+ ZEN_WARN("failed to listen on unix domain socket at '{}': {}", m_SocketPath, Ec.message());
+ return;
+ }
+
+ m_IsValid = true;
+ ZEN_INFO("Started unix domain socket listener at '{}'", m_SocketPath);
+ }
+
+ ~UnixAcceptor()
+ {
+ asio::error_code Ec;
+ m_Acceptor.close(Ec);
+ std::filesystem::remove(m_SocketPath);
+ }
+
+ void Start()
+ {
+ ZEN_ASSERT(!m_IsStopped);
+ InitAccept();
+ }
+
+ void StopAccepting() { m_IsStopped = true; }
+
+ bool IsValid() const { return m_IsValid; }
+
+private:
+ void InitAccept()
+ {
+ auto SocketPtr = std::make_unique<asio::local::stream_protocol::socket>(m_IoService);
+ asio::local::stream_protocol::socket& SocketRef = *SocketPtr.get();
+
+ m_Acceptor.async_accept(SocketRef, [this, Socket = std::move(SocketPtr)](const asio::error_code& Ec) mutable {
+ if (Ec)
+ {
+ if (!m_IsStopped.load())
+ {
+ ZEN_WARN("unix domain socket async_accept failed: '{}'", Ec.message());
+ }
+ }
+ else
+ {
+ auto Conn = std::make_shared<UnixServerConnection>(m_Server, std::move(Socket));
+ Conn->HandleNewRequest();
+ }
+
+ if (!m_IsStopped.load())
+ {
+ InitAccept();
+ }
+ else
+ {
+ std::error_code CloseEc;
+ m_Acceptor.close(CloseEc);
+ }
+ });
+ }
+
+ HttpAsioServerImpl& m_Server;
+ asio::io_service& m_IoService;
+ asio::local::stream_protocol::acceptor m_Acceptor;
+ std::string m_SocketPath;
+ bool m_IsValid{false};
+ std::atomic<bool> m_IsStopped{false};
+};
+
+#endif // ASIO_HAS_LOCAL_SOCKETS
+
+#if ZEN_USE_OPENSSL
+
+//////////////////////////////////////////////////////////////////////////
+
+struct HttpsAcceptor final : TcpAcceptorBase
+{
+ HttpsAcceptor(HttpAsioServerImpl& Server,
+ asio::io_service& IoService,
+ asio::ssl::context& SslContext,
+ uint16_t Port,
+ bool ForceLoopback,
+ bool AllowPortProbing)
+ : TcpAcceptorBase(Server, IoService, Port, ForceLoopback, AllowPortProbing, "HTTPS")
+ , m_SslContext(SslContext)
+ {
+ }
+
+protected:
+ void OnAccept(std::unique_ptr<asio::ip::tcp::socket> Socket) override
+ {
+ Socket->set_option(asio::ip::tcp::no_delay(true));
+ Socket->set_option(asio::socket_base::receive_buffer_size(128 * 1024));
+ Socket->set_option(asio::socket_base::send_buffer_size(256 * 1024));
+
+ // Wrap accepted TCP socket in an SSL stream and perform the handshake
+ auto SslSocketPtr = std::make_unique<SslSocket>(std::move(*Socket), m_SslContext);
+
+ SslSocket& SslRef = *SslSocketPtr;
+ SslRef.async_handshake(asio::ssl::stream_base::server,
+ [this, SslSocket = std::move(SslSocketPtr)](const asio::error_code& HandshakeEc) mutable {
+ if (HandshakeEc)
+ {
+ ZEN_WARN("SSL handshake failed: '{}'", HandshakeEc.message());
+ std::error_code Ec;
+ SslSocket->lowest_layer().close(Ec);
+ return;
+ }
+
+ auto Conn = std::make_shared<HttpsSslServerConnection>(m_Server, std::move(SslSocket));
+ Conn->HandleNewRequest();
+ });
+ }
+
+private:
+ asio::ssl::context& m_SslContext;
+};
+
+#endif // ZEN_USE_OPENSSL
+
+int
+HttpAsioServerImpl::GetEffectiveHttpsPort() const
+{
+#if ZEN_USE_OPENSSL
+ return m_HttpsAcceptor ? m_HttpsAcceptor->GetPort() : 0;
+#else
+ return 0;
+#endif
+}
+
//////////////////////////////////////////////////////////////////////////
HttpAsioServerRequest::HttpAsioServerRequest(HttpRequestParser& Request,
@@ -1860,6 +2082,7 @@ HttpAsioServerRequest::WriteResponse(HttpResponseCode ResponseCode)
ZEN_ASSERT(!m_Response);
m_Response.reset(new HttpResponse(HttpContentType::kBinary, m_RequestNumber));
+ m_Response->SetAllowZeroCopyFileSend(m_AllowZeroCopyFileSend);
std::array<IoBuffer, 0> Empty;
m_Response->InitializeForPayload((uint16_t)ResponseCode, Empty);
@@ -1873,6 +2096,7 @@ HttpAsioServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentT
ZEN_ASSERT(!m_Response);
m_Response.reset(new HttpResponse(ContentType, m_RequestNumber));
+ m_Response->SetAllowZeroCopyFileSend(m_AllowZeroCopyFileSend);
m_Response->InitializeForPayload((uint16_t)ResponseCode, Blobs);
}
@@ -1883,6 +2107,7 @@ HttpAsioServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentT
ZEN_ASSERT(!m_Response);
m_Response.reset(new HttpResponse(ContentType, m_RequestNumber));
+ m_Response->SetAllowZeroCopyFileSend(m_AllowZeroCopyFileSend);
IoBuffer MessageBuffer(IoBuffer::Wrap, ResponseString.data(), ResponseString.size());
std::array<IoBuffer, 1> SingleBufferList({MessageBuffer});
@@ -1942,6 +2167,51 @@ HttpAsioServerImpl::Start(uint16_t Port, const AsioConfig& Config)
m_Acceptor->Start();
+#if defined(ASIO_HAS_LOCAL_SOCKETS)
+ if (!Config.UnixSocketPath.empty())
+ {
+ m_UnixAcceptor.reset(new asio_http::UnixAcceptor(*this, m_IoService, Config.UnixSocketPath));
+
+ if (m_UnixAcceptor->IsValid())
+ {
+ m_UnixAcceptor->Start();
+ }
+ else
+ {
+ m_UnixAcceptor.reset();
+ }
+ }
+#endif
+
+#if ZEN_USE_OPENSSL
+ if (!Config.CertFile.empty() && !Config.KeyFile.empty())
+ {
+ m_SslContext = std::make_unique<asio::ssl::context>(asio::ssl::context::tlsv12_server);
+ m_SslContext->set_options(asio::ssl::context::default_workarounds | asio::ssl::context::no_sslv2 | asio::ssl::context::no_sslv3 |
+ asio::ssl::context::no_tlsv1 | asio::ssl::context::no_tlsv1_1);
+ m_SslContext->use_certificate_chain_file(Config.CertFile);
+ m_SslContext->use_private_key_file(Config.KeyFile, asio::ssl::context::pem);
+
+ ZEN_INFO("SSL context initialized (cert: '{}', key: '{}')", Config.CertFile, Config.KeyFile);
+
+ m_HttpsAcceptor.reset(new asio_http::HttpsAcceptor(*this,
+ m_IoService,
+ *m_SslContext,
+ gsl::narrow<uint16_t>(Config.HttpsPort),
+ Config.ForceLoopback,
+ /*AllowPortProbing*/ !Config.IsDedicatedServer));
+
+ if (m_HttpsAcceptor->IsValid())
+ {
+ m_HttpsAcceptor->Start();
+ }
+ else
+ {
+ m_HttpsAcceptor.reset();
+ }
+ }
+#endif
+
// This should consist of a set of minimum threads and grow on demand to
// meet concurrency needs? Right now we end up allocating a large number
// of threads even if we never end up using all of them, which seems
@@ -1990,6 +2260,18 @@ HttpAsioServerImpl::Stop()
{
m_Acceptor->StopAccepting();
}
+#if defined(ASIO_HAS_LOCAL_SOCKETS)
+ if (m_UnixAcceptor)
+ {
+ m_UnixAcceptor->StopAccepting();
+ }
+#endif
+#if ZEN_USE_OPENSSL
+ if (m_HttpsAcceptor)
+ {
+ m_HttpsAcceptor->StopAccepting();
+ }
+#endif
m_IoService.stop();
for (auto& Thread : m_ThreadPool)
{
@@ -1999,7 +2281,23 @@ HttpAsioServerImpl::Stop()
}
}
m_ThreadPool.clear();
+
+ // Drain remaining handlers (e.g. cancellation callbacks from active WebSocket
+ // connections) so that their captured Ref<> pointers are released while the
+ // io_service and its epoll reactor are still alive. Without this, sockets
+ // held by external code (e.g. IWebSocketHandler connection lists) can outlive
+ // the reactor and crash during deregistration.
+ m_IoService.restart();
+ m_IoService.poll();
+
m_Acceptor.reset();
+#if defined(ASIO_HAS_LOCAL_SOCKETS)
+ m_UnixAcceptor.reset();
+#endif
+#if ZEN_USE_OPENSSL
+ m_HttpsAcceptor.reset();
+ m_SslContext.reset();
+#endif
}
void
@@ -2166,6 +2464,13 @@ HttpAsioServer::OnInitialize(int BasePort, std::filesystem::path DataDir)
m_BasePort = m_Impl->Start(gsl::narrow<uint16_t>(BasePort), Config);
+#if ZEN_USE_OPENSSL
+ if (int EffectiveHttpsPort = m_Impl->GetEffectiveHttpsPort(); EffectiveHttpsPort > 0)
+ {
+ SetEffectiveHttpsPort(EffectiveHttpsPort);
+ }
+#endif
+
return m_BasePort;
}
diff --git a/src/zenhttp/servers/httpasio.h b/src/zenhttp/servers/httpasio.h
index 3ec1141a7..5adf4d5e8 100644
--- a/src/zenhttp/servers/httpasio.h
+++ b/src/zenhttp/servers/httpasio.h
@@ -11,6 +11,12 @@ struct AsioConfig
unsigned int ThreadCount = 0;
bool ForceLoopback = false;
bool IsDedicatedServer = false;
+ std::string UnixSocketPath;
+#if ZEN_USE_OPENSSL
+ int HttpsPort = 0; // 0 = auto-assign; set CertFile/KeyFile to enable HTTPS
+ std::string CertFile; // PEM certificate chain file (empty = HTTPS disabled)
+ std::string KeyFile; // PEM private key file
+#endif
};
Ref<HttpServer> CreateHttpAsioServer(const AsioConfig& Config);
diff --git a/src/zenhttp/servers/httpsys.cpp b/src/zenhttp/servers/httpsys.cpp
index dfe6bb6aa..83b98013e 100644
--- a/src/zenhttp/servers/httpsys.cpp
+++ b/src/zenhttp/servers/httpsys.cpp
@@ -116,6 +116,12 @@ public:
private:
int InitializeServer(int BasePort);
+ bool CreateSessionAndUrlGroup();
+ bool RegisterLocalUrls(std::u8string_view Scheme, int Port, std::vector<std::wstring>& OutUris);
+ int RegisterHttpUrls(int BasePort);
+ bool RegisterHttpsUrls();
+ bool CreateRequestQueue(int EffectivePort);
+ bool SetupIoCompletionPort();
void Cleanup();
void StartServer();
@@ -125,6 +131,9 @@ private:
void RegisterService(const char* Endpoint, HttpService& Service);
void UnregisterService(const char* Endpoint, HttpService& Service);
+ bool BindSslCertificate(int Port);
+ void UnbindSslCertificate();
+
private:
LoggerRef m_Log;
LoggerRef m_RequestLog;
@@ -140,7 +149,10 @@ private:
RwLock m_AsyncWorkPoolInitLock;
std::atomic<WorkerThreadPool*> m_AsyncWorkPool = nullptr;
- std::vector<std::wstring> m_BaseUris; // eg: http://*:nnnn/
+ std::vector<std::wstring> m_BaseUris; // eg: http://*:nnnn/
+ std::vector<std::wstring> m_HttpsBaseUris; // eg: https://*:nnnn/
+ bool m_DidAutoBindCert = false;
+ int m_HttpsPort = 0;
HTTP_SERVER_SESSION_ID m_HttpSessionId = 0;
HTTP_URL_GROUP_ID m_HttpUrlGroupId = 0;
HANDLE m_RequestQueueHandle = 0;
@@ -1082,39 +1094,63 @@ HttpSysServer::OnClose()
}
}
-int
-HttpSysServer::InitializeServer(int BasePort)
+bool
+HttpSysServer::CreateSessionAndUrlGroup()
{
- ZEN_MEMSCOPE(GetHttpsysTag());
-
- using namespace std::literals;
-
- WideStringBuilder<64> WildcardUrlPath;
- WildcardUrlPath << u8"http://*:"sv << int64_t(BasePort) << u8"/"sv;
-
- m_IsOk = false;
-
ULONG Result = HttpCreateServerSession(HTTPAPI_VERSION_2, &m_HttpSessionId, 0);
if (Result != NO_ERROR)
{
- ZEN_ERROR("Failed to create server session for '{}': {} ({:#x})",
- WideToUtf8(WildcardUrlPath),
- GetSystemErrorAsString(Result),
- Result);
+ ZEN_ERROR("Failed to create server session: {} ({:#x})", GetSystemErrorAsString(Result), Result);
- return 0;
+ return false;
}
Result = HttpCreateUrlGroup(m_HttpSessionId, &m_HttpUrlGroupId, 0);
if (Result != NO_ERROR)
{
- ZEN_ERROR("Failed to create URL group for '{}': {} ({:#x})", WideToUtf8(WildcardUrlPath), GetSystemErrorAsString(Result), Result);
+ ZEN_ERROR("Failed to create URL group: {} ({:#x})", GetSystemErrorAsString(Result), Result);
- return 0;
+ return false;
}
+ return true;
+}
+
+bool
+HttpSysServer::RegisterLocalUrls(std::u8string_view Scheme, int Port, std::vector<std::wstring>& OutUris)
+{
+ using namespace std::literals;
+
+ const std::u8string_view Hosts[] = {u8"[::1]"sv, u8"localhost"sv, u8"127.0.0.1"sv};
+
+ for (const std::u8string_view Host : Hosts)
+ {
+ WideStringBuilder<64> LocalUrl;
+ LocalUrl << Scheme << u8"://"sv << Host << u8":"sv << int64_t(Port) << u8"/"sv;
+
+ ULONG Result = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, LocalUrl.c_str(), HTTP_URL_CONTEXT(0), 0);
+
+ if (Result == NO_ERROR)
+ {
+ ZEN_WARN("Registered local-only handler '{}' - this is not accessible from remote hosts", WideToUtf8(LocalUrl));
+ OutUris.push_back(LocalUrl.c_str());
+ }
+ else
+ {
+ break;
+ }
+ }
+
+ return !OutUris.empty();
+}
+
+int
+HttpSysServer::RegisterHttpUrls(int BasePort)
+{
+ using namespace std::literals;
+
m_BaseUris.clear();
const bool AllowPortProbing = !m_InitialConfig.IsDedicatedServer;
@@ -1122,6 +1158,11 @@ HttpSysServer::InitializeServer(int BasePort)
int EffectivePort = BasePort;
+ WideStringBuilder<64> WildcardUrlPath;
+ WildcardUrlPath << u8"http://*:"sv << int64_t(BasePort) << u8"/"sv;
+
+ ULONG Result;
+
if (m_InitialConfig.ForceLoopback)
{
// Force trigger of opening using local port
@@ -1177,11 +1218,11 @@ HttpSysServer::InitializeServer(int BasePort)
{
if (AllowLocalOnly)
{
- // If we can't register the wildcard path, we fall back to local paths
- // This local paths allow requests originating locally to function, but will not allow
- // remote origin requests to function. This can be remedied by using netsh
+ // If we can't register the wildcard path, we fall back to local paths.
+ // Local paths allow requests originating locally to function, but will not allow
+ // remote origin requests to function. This can be remedied by using netsh
// during an install process to grant permissions to route public access to the appropriate
- // port for the current user. eg:
+ // port for the current user. eg:
// netsh http add urlacl url=http://*:8558/ user=<some_user>
if (!m_InitialConfig.ForceLoopback)
@@ -1246,7 +1287,7 @@ HttpSysServer::InitializeServer(int BasePort)
}
}
- if (m_BaseUris.empty())
+ if (m_BaseUris.empty() && m_InitialConfig.HttpsPort == 0)
{
ZEN_ERROR("Failed to add base URL to URL group for '{}': {} ({:#x})",
WideToUtf8(WildcardUrlPath),
@@ -1256,16 +1297,104 @@ HttpSysServer::InitializeServer(int BasePort)
return 0;
}
+ return EffectivePort;
+}
+
+bool
+HttpSysServer::RegisterHttpsUrls()
+{
+ using namespace std::literals;
+
+ const bool AllowLocalOnly = !m_InitialConfig.IsDedicatedServer;
+ const int HttpsPort = m_InitialConfig.HttpsPort;
+
+ // If HTTPS-only mode, remove HTTP URLs and clear base URIs
+ if (m_InitialConfig.HttpsOnly)
+ {
+ for (const std::wstring& Uri : m_BaseUris)
+ {
+ HttpRemoveUrlFromUrlGroup(m_HttpUrlGroupId, Uri.c_str(), 0);
+ }
+ m_BaseUris.clear();
+ }
+
+ // Auto-bind certificate if thumbprint is provided
+ if (!m_InitialConfig.CertThumbprint.empty())
+ {
+ if (!BindSslCertificate(HttpsPort))
+ {
+ return false;
+ }
+ }
+ else
+ {
+ ZEN_INFO("HTTPS port {} configured without thumbprint - assuming pre-registered SSL certificate", HttpsPort);
+ }
+
+ // Register HTTPS URLs using same pattern as HTTP
+
+ WideStringBuilder<64> HttpsWildcard;
+ HttpsWildcard << u8"https://*:"sv << int64_t(HttpsPort) << u8"/"sv;
+
+ ULONG HttpsResult = NO_ERROR;
+
+ if (m_InitialConfig.ForceLoopback)
+ {
+ HttpsResult = ERROR_ACCESS_DENIED;
+ }
+ else
+ {
+ HttpsResult = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, HttpsWildcard.c_str(), HTTP_URL_CONTEXT(0), 0);
+ }
+
+ if (HttpsResult == NO_ERROR)
+ {
+ m_HttpsBaseUris.push_back(HttpsWildcard.c_str());
+ }
+ else if (HttpsResult == ERROR_ACCESS_DENIED && AllowLocalOnly)
+ {
+ if (!m_InitialConfig.ForceLoopback)
+ {
+ ZEN_WARN(
+ "Unable to register HTTPS handler using '{}' - falling back to local-only. "
+ "Please ensure the appropriate netsh URL reservation and SSL certificate configuration is made.",
+ WideToUtf8(HttpsWildcard));
+ }
+
+ RegisterLocalUrls(u8"https", HttpsPort, m_HttpsBaseUris);
+ }
+ else if (HttpsResult != NO_ERROR)
+ {
+ ZEN_ERROR("Failed to register HTTPS URL '{}': {} ({:#x})",
+ WideToUtf8(HttpsWildcard),
+ GetSystemErrorAsString(HttpsResult),
+ HttpsResult);
+ return false;
+ }
+
+ if (m_HttpsBaseUris.empty())
+ {
+ ZEN_ERROR("Failed to register any HTTPS URL for port {}", HttpsPort);
+ return false;
+ }
+
+ m_HttpsPort = HttpsPort;
+ return true;
+}
+
+bool
+HttpSysServer::CreateRequestQueue(int EffectivePort)
+{
HTTP_BINDING_INFO HttpBindingInfo = {{0}, 0};
WideStringBuilder<64> QueueName;
QueueName << "zenserver_" << EffectivePort;
- Result = HttpCreateRequestQueue(HTTPAPI_VERSION_2,
- /* Name */ QueueName.c_str(),
- /* SecurityAttributes */ nullptr,
- /* Flags */ 0,
- &m_RequestQueueHandle);
+ ULONG Result = HttpCreateRequestQueue(HTTPAPI_VERSION_2,
+ /* Name */ QueueName.c_str(),
+ /* SecurityAttributes */ nullptr,
+ /* Flags */ 0,
+ &m_RequestQueueHandle);
if (Result != NO_ERROR)
{
@@ -1274,7 +1403,7 @@ HttpSysServer::InitializeServer(int BasePort)
GetSystemErrorAsString(Result),
Result);
- return 0;
+ return false;
}
HttpBindingInfo.Flags.Present = 1;
@@ -1289,7 +1418,7 @@ HttpSysServer::InitializeServer(int BasePort)
GetSystemErrorAsString(Result),
Result);
- return 0;
+ return false;
}
// Configure rejection method. Default is to drop the connection, it's better if we
@@ -1323,22 +1452,77 @@ HttpSysServer::InitializeServer(int BasePort)
}
}
- // Create I/O completion port
+ return true;
+}
+bool
+HttpSysServer::SetupIoCompletionPort()
+{
std::error_code ErrorCode;
m_IoThreadPool->CreateIocp(m_RequestQueueHandle, HttpSysTransaction::IoCompletionCallback, /* Context */ this, /* out */ ErrorCode);
if (ErrorCode)
{
- ZEN_ERROR("Failed to create IOCP for '{}': {}", WideToUtf8(m_BaseUris.front()), ErrorCode.message());
+ ZEN_ERROR("Failed to create IOCP: {}", ErrorCode.message());
+ return false;
+ }
+ m_IsOk = true;
+
+ if (!m_BaseUris.empty())
+ {
+ ZEN_INFO("Started http.sys server at '{}'", WideToUtf8(m_BaseUris.front()));
+ }
+ if (!m_HttpsBaseUris.empty())
+ {
+ ZEN_INFO("Started http.sys HTTPS server at '{}'", WideToUtf8(m_HttpsBaseUris.front()));
+ }
+
+ return true;
+}
+
+int
+HttpSysServer::InitializeServer(int BasePort)
+{
+ ZEN_MEMSCOPE(GetHttpsysTag());
+
+ m_IsOk = false;
+
+ if (!CreateSessionAndUrlGroup())
+ {
return 0;
}
- else
+
+ int EffectivePort = RegisterHttpUrls(BasePort);
+
+ if (m_InitialConfig.HttpsPort > 0)
+ {
+ if (!RegisterHttpsUrls())
+ {
+ return 0;
+ }
+ }
+
+ if (m_BaseUris.empty() && m_HttpsBaseUris.empty())
{
- m_IsOk = true;
+ ZEN_ERROR("No HTTP or HTTPS listeners could be registered");
+ return 0;
+ }
- ZEN_INFO("Started http.sys server at '{}'", WideToUtf8(m_BaseUris.front()));
+ if (!CreateRequestQueue(EffectivePort))
+ {
+ return 0;
+ }
+
+ if (!SetupIoCompletionPort())
+ {
+ return 0;
+ }
+
+ // When HTTPS-only, return the HTTPS port as the effective port
+ if (m_InitialConfig.HttpsOnly && m_HttpsPort > 0)
+ {
+ return m_HttpsPort;
}
return EffectivePort;
@@ -1349,6 +1533,8 @@ HttpSysServer::Cleanup()
{
++m_IsShuttingDown;
+ UnbindSslCertificate();
+
if (m_RequestQueueHandle)
{
HttpCloseRequestQueue(m_RequestQueueHandle);
@@ -1368,6 +1554,105 @@ HttpSysServer::Cleanup()
}
}
+// {7E3F4B2A-1C8D-4A6E-B5F0-9D2E8C7A3B1F} - Fixed GUID for zenserver SSL bindings
+static constexpr GUID ZenServerSslAppId = {0x7E3F4B2A, 0x1C8D, 0x4A6E, {0xB5, 0xF0, 0x9D, 0x2E, 0x8C, 0x7A, 0x3B, 0x1F}};
+
+bool
+HttpSysServer::BindSslCertificate(int Port)
+{
+ const std::string& Thumbprint = m_InitialConfig.CertThumbprint;
+ if (Thumbprint.size() != 40)
+ {
+ ZEN_ERROR("SSL certificate thumbprint must be exactly 40 hex characters, got {}", Thumbprint.size());
+ return false;
+ }
+
+ BYTE CertHash[20] = {};
+ if (!ParseHexBytes(Thumbprint, CertHash))
+ {
+ ZEN_ERROR("SSL certificate thumbprint contains invalid hex characters");
+ return false;
+ }
+
+ SOCKADDR_IN Address = {};
+ Address.sin_family = AF_INET;
+ Address.sin_port = htons(static_cast<USHORT>(Port));
+ Address.sin_addr.s_addr = INADDR_ANY;
+
+ const std::wstring StoreNameW = UTF8_to_UTF16(m_InitialConfig.CertStoreName.c_str());
+
+ HTTP_SERVICE_CONFIG_SSL_SET SslConfig = {};
+ SslConfig.KeyDesc.pIpPort = reinterpret_cast<SOCKADDR*>(&Address);
+ SslConfig.ParamDesc.pSslHash = CertHash;
+ SslConfig.ParamDesc.SslHashLength = sizeof(CertHash);
+ SslConfig.ParamDesc.pSslCertStoreName = const_cast<PWSTR>(StoreNameW.c_str());
+ SslConfig.ParamDesc.AppId = ZenServerSslAppId;
+
+ ULONG Result = HttpSetServiceConfiguration(0, HttpServiceConfigSSLCertInfo, &SslConfig, sizeof(SslConfig), nullptr);
+
+ if (Result == ERROR_ALREADY_EXISTS)
+ {
+ // Remove existing binding and retry
+ HTTP_SERVICE_CONFIG_SSL_SET DeleteConfig = {};
+ DeleteConfig.KeyDesc.pIpPort = reinterpret_cast<SOCKADDR*>(&Address);
+
+ HttpDeleteServiceConfiguration(0, HttpServiceConfigSSLCertInfo, &DeleteConfig, sizeof(DeleteConfig), nullptr);
+
+ Result = HttpSetServiceConfiguration(0, HttpServiceConfigSSLCertInfo, &SslConfig, sizeof(SslConfig), nullptr);
+ }
+
+ if (Result != NO_ERROR)
+ {
+ ZEN_ERROR(
+ "Failed to bind SSL certificate to port {}: {} ({:#x}). "
+ "This operation may require running as administrator.",
+ Port,
+ GetSystemErrorAsString(Result),
+ Result);
+ return false;
+ }
+
+ m_DidAutoBindCert = true;
+ m_HttpsPort = Port;
+
+ ZEN_INFO("SSL certificate auto-bound for 0.0.0.0:{} (thumbprint: {}..., store: {})",
+ Port,
+ Thumbprint.substr(0, 8),
+ m_InitialConfig.CertStoreName);
+
+ return true;
+}
+
+void
+HttpSysServer::UnbindSslCertificate()
+{
+ if (!m_DidAutoBindCert)
+ {
+ return;
+ }
+
+ SOCKADDR_IN Address = {};
+ Address.sin_family = AF_INET;
+ Address.sin_port = htons(static_cast<USHORT>(m_HttpsPort));
+ Address.sin_addr.s_addr = INADDR_ANY;
+
+ HTTP_SERVICE_CONFIG_SSL_SET SslConfig = {};
+ SslConfig.KeyDesc.pIpPort = reinterpret_cast<SOCKADDR*>(&Address);
+
+ ULONG Result = HttpDeleteServiceConfiguration(0, HttpServiceConfigSSLCertInfo, &SslConfig, sizeof(SslConfig), nullptr);
+
+ if (Result != NO_ERROR)
+ {
+ ZEN_WARN("Failed to remove SSL certificate binding from port {}: {} ({:#x})", m_HttpsPort, GetSystemErrorAsString(Result), Result);
+ }
+ else
+ {
+ ZEN_INFO("SSL certificate binding removed from port {}", m_HttpsPort);
+ }
+
+ m_DidAutoBindCert = false;
+}
+
WorkerThreadPool&
HttpSysServer::WorkPool()
{
@@ -1495,19 +1780,23 @@ HttpSysServer::RegisterService(const char* UrlPath, HttpService& Service)
// Convert to wide string
- for (const std::wstring& BaseUri : m_BaseUris)
- {
- std::wstring Url16 = BaseUri + PathUtf16;
-
- ULONG Result = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, Url16.c_str(), HTTP_URL_CONTEXT(&Service), 0 /* Reserved */);
-
- if (Result != NO_ERROR)
+ auto RegisterWithBaseUris = [&](const std::vector<std::wstring>& BaseUris) {
+ for (const std::wstring& BaseUri : BaseUris)
{
- ZEN_ERROR("HttpAddUrlToUrlGroup failed with result: '{}'", GetSystemErrorAsString(Result));
+ std::wstring Url16 = BaseUri + PathUtf16;
- return;
+ ULONG Result = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, Url16.c_str(), HTTP_URL_CONTEXT(&Service), 0 /* Reserved */);
+
+ if (Result != NO_ERROR)
+ {
+ ZEN_ERROR("HttpAddUrlToUrlGroup failed with result: '{}'", GetSystemErrorAsString(Result));
+ return;
+ }
}
- }
+ };
+
+ RegisterWithBaseUris(m_BaseUris);
+ RegisterWithBaseUris(m_HttpsBaseUris);
}
void
@@ -1522,19 +1811,22 @@ HttpSysServer::UnregisterService(const char* UrlPath, HttpService& Service)
const std::wstring PathUtf16 = UTF8_to_UTF16(UrlPath);
- // Convert to wide string
-
- for (const std::wstring& BaseUri : m_BaseUris)
- {
- std::wstring Url16 = BaseUri + PathUtf16;
+ auto UnregisterFromBaseUris = [&](const std::vector<std::wstring>& BaseUris) {
+ for (const std::wstring& BaseUri : BaseUris)
+ {
+ std::wstring Url16 = BaseUri + PathUtf16;
- ULONG Result = HttpRemoveUrlFromUrlGroup(m_HttpUrlGroupId, Url16.c_str(), 0);
+ ULONG Result = HttpRemoveUrlFromUrlGroup(m_HttpUrlGroupId, Url16.c_str(), 0);
- if (Result != NO_ERROR)
- {
- ZEN_ERROR("HttpRemoveUrlFromUrlGroup failed with result: '{}'", GetSystemErrorAsString(Result));
+ if (Result != NO_ERROR)
+ {
+ ZEN_ERROR("HttpRemoveUrlFromUrlGroup failed with result: '{}'", GetSystemErrorAsString(Result));
+ }
}
- }
+ };
+
+ UnregisterFromBaseUris(m_BaseUris);
+ UnregisterFromBaseUris(m_HttpsBaseUris);
}
//////////////////////////////////////////////////////////////////////////
@@ -2422,6 +2714,11 @@ HttpSysServer::OnInitialize(int BasePort, std::filesystem::path DataDir)
ZEN_UNUSED(DataDir);
if (int EffectivePort = InitializeServer(BasePort))
{
+ if (m_HttpsPort > 0)
+ {
+ SetEffectiveHttpsPort(m_HttpsPort);
+ }
+
StartServer();
return EffectivePort;
diff --git a/src/zenhttp/servers/httpsys.h b/src/zenhttp/servers/httpsys.h
index b2fe7475b..ca465ad00 100644
--- a/src/zenhttp/servers/httpsys.h
+++ b/src/zenhttp/servers/httpsys.h
@@ -22,6 +22,10 @@ struct HttpSysConfig
bool IsRequestLoggingEnabled = false;
bool IsDedicatedServer = false;
bool ForceLoopback = false;
+ int HttpsPort = 0; // 0 = HTTPS disabled
+ std::string CertThumbprint; // Hex SHA-1 (40 chars) for auto SSL binding
+ std::string CertStoreName = "MY"; // Windows certificate store name
+ bool HttpsOnly = false; // When true, disable HTTP listener
};
Ref<HttpServer> CreateHttpSysServer(HttpSysConfig Config);
diff --git a/src/zenhttp/servers/wsasio.cpp b/src/zenhttp/servers/wsasio.cpp
index b2543277a..5ae48f5b3 100644
--- a/src/zenhttp/servers/wsasio.cpp
+++ b/src/zenhttp/servers/wsasio.cpp
@@ -1,6 +1,7 @@
// Copyright Epic Games, Inc. All Rights Reserved.
#include "wsasio.h"
+#include "asio_socket_traits.h"
#include "wsframecodec.h"
#include <zencore/logging.h>
@@ -17,14 +18,16 @@ WsLog()
//////////////////////////////////////////////////////////////////////////
-WsAsioConnection::WsAsioConnection(std::unique_ptr<asio::ip::tcp::socket> Socket, IWebSocketHandler& Handler, HttpServer* Server)
+template<typename SocketType>
+WsAsioConnectionT<SocketType>::WsAsioConnectionT(std::unique_ptr<SocketType> Socket, IWebSocketHandler& Handler, HttpServer* Server)
: m_Socket(std::move(Socket))
, m_Handler(Handler)
, m_HttpServer(Server)
{
}
-WsAsioConnection::~WsAsioConnection()
+template<typename SocketType>
+WsAsioConnectionT<SocketType>::~WsAsioConnectionT()
{
m_IsOpen.store(false);
if (m_HttpServer)
@@ -33,14 +36,16 @@ WsAsioConnection::~WsAsioConnection()
}
}
+template<typename SocketType>
void
-WsAsioConnection::Start()
+WsAsioConnectionT<SocketType>::Start()
{
EnqueueRead();
}
+template<typename SocketType>
bool
-WsAsioConnection::IsOpen() const
+WsAsioConnectionT<SocketType>::IsOpen() const
{
return m_IsOpen.load(std::memory_order_relaxed);
}
@@ -50,23 +55,25 @@ WsAsioConnection::IsOpen() const
// Read loop
//
+template<typename SocketType>
void
-WsAsioConnection::EnqueueRead()
+WsAsioConnectionT<SocketType>::EnqueueRead()
{
if (!m_IsOpen.load(std::memory_order_relaxed))
{
return;
}
- Ref<WsAsioConnection> Self(this);
+ Ref<WsAsioConnectionT> Self(this);
asio::async_read(*m_Socket, m_ReadBuffer, asio::transfer_at_least(1), [Self](const asio::error_code& Ec, std::size_t ByteCount) {
Self->OnDataReceived(Ec, ByteCount);
});
}
+template<typename SocketType>
void
-WsAsioConnection::OnDataReceived(const asio::error_code& Ec, [[maybe_unused]] std::size_t ByteCount)
+WsAsioConnectionT<SocketType>::OnDataReceived(const asio::error_code& Ec, [[maybe_unused]] std::size_t ByteCount)
{
if (Ec)
{
@@ -90,8 +97,9 @@ WsAsioConnection::OnDataReceived(const asio::error_code& Ec, [[maybe_unused]] st
}
}
+template<typename SocketType>
void
-WsAsioConnection::ProcessReceivedData()
+WsAsioConnectionT<SocketType>::ProcessReceivedData()
{
while (m_ReadBuffer.size() > 0)
{
@@ -162,8 +170,8 @@ WsAsioConnection::ProcessReceivedData()
// Shut down the socket
std::error_code ShutdownEc;
- m_Socket->shutdown(asio::socket_base::shutdown_both, ShutdownEc);
- m_Socket->close(ShutdownEc);
+ SocketTraits<SocketType>::ShutdownBoth(*m_Socket, ShutdownEc);
+ SocketTraits<SocketType>::Close(*m_Socket, ShutdownEc);
return;
}
@@ -179,8 +187,9 @@ WsAsioConnection::ProcessReceivedData()
// Write queue
//
+template<typename SocketType>
void
-WsAsioConnection::SendText(std::string_view Text)
+WsAsioConnectionT<SocketType>::SendText(std::string_view Text)
{
if (!m_IsOpen.load(std::memory_order_relaxed))
{
@@ -192,8 +201,9 @@ WsAsioConnection::SendText(std::string_view Text)
EnqueueWrite(std::move(Frame));
}
+template<typename SocketType>
void
-WsAsioConnection::SendBinary(std::span<const uint8_t> Data)
+WsAsioConnectionT<SocketType>::SendBinary(std::span<const uint8_t> Data)
{
if (!m_IsOpen.load(std::memory_order_relaxed))
{
@@ -204,14 +214,16 @@ WsAsioConnection::SendBinary(std::span<const uint8_t> Data)
EnqueueWrite(std::move(Frame));
}
+template<typename SocketType>
void
-WsAsioConnection::Close(uint16_t Code, std::string_view Reason)
+WsAsioConnectionT<SocketType>::Close(uint16_t Code, std::string_view Reason)
{
DoClose(Code, Reason);
}
+template<typename SocketType>
void
-WsAsioConnection::DoClose(uint16_t Code, std::string_view Reason)
+WsAsioConnectionT<SocketType>::DoClose(uint16_t Code, std::string_view Reason)
{
if (!m_IsOpen.exchange(false))
{
@@ -227,8 +239,9 @@ WsAsioConnection::DoClose(uint16_t Code, std::string_view Reason)
m_Handler.OnWebSocketClose(*this, Code, Reason);
}
+template<typename SocketType>
void
-WsAsioConnection::EnqueueWrite(std::vector<uint8_t> Frame)
+WsAsioConnectionT<SocketType>::EnqueueWrite(std::vector<uint8_t> Frame)
{
if (m_HttpServer)
{
@@ -252,8 +265,9 @@ WsAsioConnection::EnqueueWrite(std::vector<uint8_t> Frame)
}
}
+template<typename SocketType>
void
-WsAsioConnection::FlushWriteQueue()
+WsAsioConnectionT<SocketType>::FlushWriteQueue()
{
std::vector<uint8_t> Frame;
@@ -272,7 +286,7 @@ WsAsioConnection::FlushWriteQueue()
return;
}
- Ref<WsAsioConnection> Self(this);
+ Ref<WsAsioConnectionT> Self(this);
// Move Frame into a shared_ptr so we can create the buffer and capture ownership
// in the same async_write call without evaluation order issues.
@@ -283,8 +297,9 @@ WsAsioConnection::FlushWriteQueue()
[Self, OwnedFrame](const asio::error_code& Ec, std::size_t ByteCount) { Self->OnWriteComplete(Ec, ByteCount); });
}
+template<typename SocketType>
void
-WsAsioConnection::OnWriteComplete(const asio::error_code& Ec, [[maybe_unused]] std::size_t ByteCount)
+WsAsioConnectionT<SocketType>::OnWriteComplete(const asio::error_code& Ec, [[maybe_unused]] std::size_t ByteCount)
{
if (Ec)
{
@@ -308,4 +323,17 @@ WsAsioConnection::OnWriteComplete(const asio::error_code& Ec, [[maybe_unused]] s
FlushWriteQueue();
}
+//////////////////////////////////////////////////////////////////////////
+// Explicit template instantiations
+
+template class WsAsioConnectionT<asio::ip::tcp::socket>;
+
+#if defined(ASIO_HAS_LOCAL_SOCKETS)
+template class WsAsioConnectionT<asio::local::stream_protocol::socket>;
+#endif
+
+#if ZEN_USE_OPENSSL
+template class WsAsioConnectionT<asio::ssl::stream<asio::ip::tcp::socket>>;
+#endif
+
} // namespace zen::asio_http
diff --git a/src/zenhttp/servers/wsasio.h b/src/zenhttp/servers/wsasio.h
index e8bb3b1d2..64602ee46 100644
--- a/src/zenhttp/servers/wsasio.h
+++ b/src/zenhttp/servers/wsasio.h
@@ -8,6 +8,12 @@
ZEN_THIRD_PARTY_INCLUDES_START
#include <asio.hpp>
+#if defined(ASIO_HAS_LOCAL_SOCKETS)
+# include <asio/local/stream_protocol.hpp>
+#endif
+#if ZEN_USE_OPENSSL
+# include <asio/ssl.hpp>
+#endif
ZEN_THIRD_PARTY_INCLUDES_END
#include <deque>
@@ -21,22 +27,23 @@ class HttpServer;
namespace zen::asio_http {
/**
- * WebSocket connection over an ASIO TCP socket
+ * WebSocket connection over an ASIO stream socket
*
- * Owns the TCP socket (moved from HttpServerConnection after the 101 handshake)
+ * Templated on SocketType to support both TCP and Unix domain sockets.
+ * Owns the socket (moved from HttpServerConnection after the 101 handshake)
* and runs an async read/write loop to exchange WebSocket frames.
*
* Lifetime is managed solely through intrusive reference counting (RefCounted).
- * The async read/write callbacks capture Ref<WsAsioConnection> to keep the
- * connection alive for the duration of the async operation. The service layer
- * also holds a Ref<WebSocketConnection>.
+ * The async read/write callbacks capture Ref<> to keep the connection alive
+ * for the duration of the async operation. The service layer also holds a
+ * Ref<WebSocketConnection>.
*/
-
-class WsAsioConnection : public WebSocketConnection
+template<typename SocketType>
+class WsAsioConnectionT : public WebSocketConnection
{
public:
- WsAsioConnection(std::unique_ptr<asio::ip::tcp::socket> Socket, IWebSocketHandler& Handler, HttpServer* Server);
- ~WsAsioConnection() override;
+ WsAsioConnectionT(std::unique_ptr<SocketType> Socket, IWebSocketHandler& Handler, HttpServer* Server);
+ ~WsAsioConnectionT() override;
/**
* Start the async read loop. Must be called once after construction
@@ -61,10 +68,10 @@ private:
void DoClose(uint16_t Code, std::string_view Reason);
- std::unique_ptr<asio::ip::tcp::socket> m_Socket;
- IWebSocketHandler& m_Handler;
- zen::HttpServer* m_HttpServer;
- asio::streambuf m_ReadBuffer;
+ std::unique_ptr<SocketType> m_Socket;
+ IWebSocketHandler& m_Handler;
+ zen::HttpServer* m_HttpServer;
+ asio::streambuf m_ReadBuffer;
RwLock m_WriteLock;
std::deque<std::vector<uint8_t>> m_WriteQueue;
@@ -74,4 +81,14 @@ private:
std::atomic<bool> m_CloseSent{false};
};
+using WsAsioConnection = WsAsioConnectionT<asio::ip::tcp::socket>;
+
+#if defined(ASIO_HAS_LOCAL_SOCKETS)
+using WsAsioUnixConnection = WsAsioConnectionT<asio::local::stream_protocol::socket>;
+#endif
+
+#if ZEN_USE_OPENSSL
+using WsAsioSslConnection = WsAsioConnectionT<asio::ssl::stream<asio::ip::tcp::socket>>;
+#endif
+
} // namespace zen::asio_http
diff --git a/src/zenhttp/servers/wstest.cpp b/src/zenhttp/servers/wstest.cpp
index 2134e4ff1..042afd8ff 100644
--- a/src/zenhttp/servers/wstest.cpp
+++ b/src/zenhttp/servers/wstest.cpp
@@ -485,7 +485,7 @@ TEST_CASE("websocket.integration")
Ref<HttpServer> Server = CreateHttpAsioServer(AsioConfig{});
- int Port = Server->Initialize(7575, TmpDir.Path());
+ int Port = Server->Initialize(0, TmpDir.Path());
REQUIRE(Port != 0);
Server->RegisterService(TestService);
@@ -797,7 +797,7 @@ TEST_CASE("websocket.client")
Ref<HttpServer> Server = CreateHttpAsioServer(AsioConfig{});
- int Port = Server->Initialize(7576, TmpDir.Path());
+ int Port = Server->Initialize(0, TmpDir.Path());
REQUIRE(Port != 0);
Server->RegisterService(TestService);
@@ -913,6 +913,75 @@ TEST_CASE("websocket.client")
}
}
+TEST_CASE("websocket.client.unixsocket")
+{
+ WsTestService TestService;
+ ScopedTemporaryDirectory TmpDir;
+ std::string SocketPath = (TmpDir.Path() / "ws.sock").string();
+
+ Ref<HttpServer> Server = CreateHttpAsioServer(AsioConfig{.UnixSocketPath = SocketPath});
+
+ int Port = Server->Initialize(0, TmpDir.Path());
+ REQUIRE(Port != 0);
+
+ Server->RegisterService(TestService);
+
+ std::thread ServerThread([&]() { Server->Run(false); });
+
+ auto ServerGuard = MakeGuard([&]() {
+ Server->RequestExit();
+ if (ServerThread.joinable())
+ {
+ ServerThread.join();
+ }
+ Server->Close();
+ });
+
+ Sleep(100);
+
+ SUBCASE("connect, echo, close over unix socket")
+ {
+ TestWsClientHandler Handler;
+ HttpWsClientSettings Settings;
+ Settings.UnixSocketPath = SocketPath;
+
+ HttpWsClient Client("ws://localhost/wstest/ws", Handler, Settings);
+ Client.Connect();
+
+ // Wait for OnWsOpen
+ auto Deadline = std::chrono::steady_clock::now() + 5s;
+ while (Handler.m_OpenCount.load() == 0 && std::chrono::steady_clock::now() < Deadline)
+ {
+ Sleep(10);
+ }
+ REQUIRE_EQ(Handler.m_OpenCount.load(), 1);
+ CHECK(Client.IsOpen());
+
+ // Send text, expect echo
+ Client.SendText("hello over unix socket");
+
+ Deadline = std::chrono::steady_clock::now() + 5s;
+ while (Handler.m_MessageCount.load() == 0 && std::chrono::steady_clock::now() < Deadline)
+ {
+ Sleep(10);
+ }
+ CHECK_EQ(Handler.m_MessageCount.load(), 1);
+ CHECK_EQ(Handler.m_LastMessage, "hello over unix socket");
+
+ // Close
+ Client.Close(1000, "done");
+
+ Deadline = std::chrono::steady_clock::now() + 5s;
+ while (Handler.m_CloseCount.load() == 0 && std::chrono::steady_clock::now() < Deadline)
+ {
+ Sleep(10);
+ }
+
+ Sleep(50);
+ CHECK_FALSE(Client.IsOpen());
+ }
+}
+
TEST_SUITE_END();
void
diff --git a/src/zenhttp/xmake.lua b/src/zenhttp/xmake.lua
index e8f87b668..9b461662e 100644
--- a/src/zenhttp/xmake.lua
+++ b/src/zenhttp/xmake.lua
@@ -12,6 +12,11 @@ target('zenhttp')
add_packages("http_parser", "json11")
add_options("httpsys")
+ if is_plat("linux", "macosx") then
+ add_packages("openssl3")
+ end
+
if is_plat("linux") then
add_syslinks("dl") -- TODO: is libdl needed?
end
+
diff --git a/src/zenremotestore/builds/jupiterbuildstorage.cpp b/src/zenremotestore/builds/jupiterbuildstorage.cpp
index 8e16da1a9..c3f7b9e71 100644
--- a/src/zenremotestore/builds/jupiterbuildstorage.cpp
+++ b/src/zenremotestore/builds/jupiterbuildstorage.cpp
@@ -23,9 +23,10 @@ using namespace std::literals;
namespace {
[[noreturn]] void ThrowFromJupiterResult(const JupiterResult& Result, std::string_view Prefix)
{
- int Error = Result.ErrorCode < (int)HttpResponseCode::Continue ? Result.ErrorCode : 0;
- HttpResponseCode Status =
- Result.ErrorCode >= int(HttpResponseCode::Continue) ? HttpResponseCode(Result.ErrorCode) : HttpResponseCode::ImATeapot;
+ HttpClientErrorCode Error = Result.ErrorCode < static_cast<int>(HttpResponseCode::Continue) ? HttpClientErrorCode(Result.ErrorCode)
+ : HttpClientErrorCode::kOK;
+ HttpResponseCode Status = Result.ErrorCode >= static_cast<int>(HttpResponseCode::Continue) ? HttpResponseCode(Result.ErrorCode)
+ : HttpResponseCode::ImATeapot;
throw HttpClientError(fmt::format("{}: {} ({})", Prefix, Result.Reason, Result.ErrorCode), Error, Status);
}
} // namespace
diff --git a/src/zenremotestore/jupiter/jupitersession.cpp b/src/zenremotestore/jupiter/jupitersession.cpp
index 52f9eb678..b5531fa60 100644
--- a/src/zenremotestore/jupiter/jupitersession.cpp
+++ b/src/zenremotestore/jupiter/jupitersession.cpp
@@ -68,7 +68,7 @@ namespace detail {
return {.SentBytes = gsl::narrow<uint64_t>(Response.UploadedBytes),
.ReceivedBytes = gsl::narrow<uint64_t>(Response.DownloadedBytes),
.ElapsedSeconds = Response.ElapsedSeconds,
- .ErrorCode = Response.Error.value().ErrorCode,
+ .ErrorCode = static_cast<int32_t>(Response.Error.value().ErrorCode),
.Reason = Response.ErrorMessage(ErrorPrefix),
.Success = false};
}
diff --git a/src/zenremotestore/projectstore/buildsremoteprojectstore.cpp b/src/zenremotestore/projectstore/buildsremoteprojectstore.cpp
index 2282a31dd..e95d9118c 100644
--- a/src/zenremotestore/projectstore/buildsremoteprojectstore.cpp
+++ b/src/zenremotestore/projectstore/buildsremoteprojectstore.cpp
@@ -284,9 +284,7 @@ public:
}
catch (const HttpClientError& Ex)
{
- Result.ErrorCode = Ex.GetInternalErrorCode() != 0 ? Ex.GetInternalErrorCode()
- : Ex.GetHttpResponseCode() != HttpResponseCode::ImATeapot ? (int)Ex.GetHttpResponseCode()
- : 0;
+ Result.ErrorCode = MakeErrorCode(Ex);
Result.Reason = fmt::format("Failed finalizing oplog container build part to {}/{}/{}/{}/{}. Reason: '{}'",
m_BuildStorageHttp.GetBaseUri(),
m_Namespace,
@@ -315,9 +313,7 @@ public:
}
catch (const HttpClientError& Ex)
{
- Result.ErrorCode = Ex.GetInternalErrorCode() != 0 ? Ex.GetInternalErrorCode()
- : Ex.GetHttpResponseCode() != HttpResponseCode::ImATeapot ? (int)Ex.GetHttpResponseCode()
- : 0;
+ Result.ErrorCode = MakeErrorCode(Ex);
Result.Reason = fmt::format("Failed finalizing oplog container build to {}/{}/{}/{}. Reason: '{}'",
m_BuildStorageHttp.GetBaseUri(),
m_Namespace,
@@ -591,8 +587,8 @@ public:
private:
static int MakeErrorCode(const HttpClientError& Ex)
{
- return Ex.GetInternalErrorCode() != 0 ? Ex.GetInternalErrorCode()
- : Ex.GetHttpResponseCode() != HttpResponseCode::ImATeapot ? (int)Ex.GetHttpResponseCode()
+ return Ex.GetInternalErrorCode() != HttpClientErrorCode::kOK ? static_cast<int>(Ex.GetInternalErrorCode())
+ : Ex.GetHttpResponseCode() != HttpResponseCode::ImATeapot ? static_cast<int>(Ex.GetHttpResponseCode())
: 0;
}
diff --git a/src/zenremotestore/projectstore/zenremoteprojectstore.cpp b/src/zenremotestore/projectstore/zenremoteprojectstore.cpp
index 115d6438d..a08a07fcd 100644
--- a/src/zenremotestore/projectstore/zenremoteprojectstore.cpp
+++ b/src/zenremotestore/projectstore/zenremoteprojectstore.cpp
@@ -329,7 +329,7 @@ private:
{
if (Response.Error)
{
- return {.ErrorCode = Response.Error.value().ErrorCode,
+ return {.ErrorCode = static_cast<int32_t>(Response.Error.value().ErrorCode),
.ElapsedSeconds = Response.ElapsedSeconds,
.Reason = Response.ErrorMessage(""),
.Text = Response.ToText()};
diff --git a/src/zenserver-test/zenserver-test.cpp b/src/zenserver-test/zenserver-test.cpp
index 8d5400294..0b2bc50c0 100644
--- a/src/zenserver-test/zenserver-test.cpp
+++ b/src/zenserver-test/zenserver-test.cpp
@@ -24,6 +24,10 @@
# include <atomic>
# include <filesystem>
+ZEN_THIRD_PARTY_INCLUDES_START
+# include <asio.hpp>
+ZEN_THIRD_PARTY_INCLUDES_END
+
# if ZEN_PLATFORM_WINDOWS
# include <ppl.h>
# include <process.h>
@@ -109,6 +113,11 @@ main(int argc, char** argv)
ServerClass = argv[++i];
}
}
+ else if (std::string_view Arg(argv[i]); Arg.starts_with("--httpclient="sv))
+ {
+ std::string_view Value = Arg.substr(13);
+ zen::SetDefaultHttpClientBackend(Value);
+ }
else if (argv[i] == "--verbose"sv)
{
Verbose = true;
@@ -341,6 +350,38 @@ TEST_CASE("http.package")
CHECK_EQ(ResponsePackage, TestPackage);
}
+# if defined(ASIO_HAS_LOCAL_SOCKETS)
+TEST_CASE("http.unixsocket")
+{
+ std::filesystem::path TestDir = TestEnv.CreateNewTestDir();
+ std::filesystem::path SocketDir = TestEnv.CreateNewTestDir();
+ std::string SocketPath = (SocketDir / "zen.sock").string();
+
+ ZenServerInstance Instance(TestEnv);
+ Instance.SetDataDir(TestDir);
+ const uint16_t PortNumber = Instance.SpawnServerAndWaitUntilReady(fmt::format("--http=asio --unix-socket {}", SocketPath));
+
+ // Connect via Unix socket (BaseUri still needed for Host header)
+ HttpClientSettings Settings;
+ Settings.UnixSocketPath = SocketPath;
+ HttpClient Http{fmt::format("http://localhost:{}", PortNumber), Settings, {}};
+
+ SUBCASE("GET over unix socket")
+ {
+ HttpClient::Response Res = Http.Get("/testing/hello");
+ CHECK(Res.IsSuccess());
+ }
+
+ SUBCASE("POST echo over unix socket")
+ {
+ IoBuffer Body{IoBuffer::Wrap, "unix-test", 9};
+ HttpClient::Response Res = Http.Post("/testing/echo", Body);
+ CHECK(Res.IsSuccess());
+ CHECK(Res.ResponsePayload.GetView().EqualBytes(Body.GetView()));
+ }
+}
+# endif
+
TEST_SUITE_END();
# if 0
diff --git a/src/zenserver/compute/computeserver.cpp b/src/zenserver/compute/computeserver.cpp
index c64f081b3..2ac3de599 100644
--- a/src/zenserver/compute/computeserver.cpp
+++ b/src/zenserver/compute/computeserver.cpp
@@ -674,7 +674,7 @@ ZenComputeServer::PostAnnounce()
{
ZEN_ERROR("failed to notify coordinator at '{}': HTTP error {} - {}",
m_CoordinatorEndpoint,
- Result.Error->ErrorCode,
+ static_cast<int>(Result.Error->ErrorCode),
Result.Error->ErrorMessage);
}
else if (!IsHttpOk(Result.StatusCode))
diff --git a/src/zenserver/config/config.cpp b/src/zenserver/config/config.cpp
index e36352dae..ef9c6b7b8 100644
--- a/src/zenserver/config/config.cpp
+++ b/src/zenserver/config/config.cpp
@@ -144,10 +144,15 @@ ZenServerConfiguratorBase::AddCommonConfigOptions(LuaConfig::Options& LuaOptions
////// network
+ LuaOptions.AddOption("network.httpclientbackend"sv, ServerOptions.HttpClient.Backend, "httpclient"sv);
LuaOptions.AddOption("network.httpserverclass"sv, ServerOptions.HttpConfig.ServerClass, "http"sv);
LuaOptions.AddOption("network.httpserverthreads"sv, ServerOptions.HttpConfig.ThreadCount, "http-threads"sv);
LuaOptions.AddOption("network.port"sv, ServerOptions.BasePort, "port"sv);
LuaOptions.AddOption("network.forceloopback"sv, ServerOptions.HttpConfig.ForceLoopback, "http-forceloopback"sv);
+ LuaOptions.AddOption("network.unixsocket"sv, ServerOptions.HttpConfig.UnixSocketPath, "unix-socket"sv);
+ LuaOptions.AddOption("network.https.port"sv, ServerOptions.HttpConfig.HttpsPort, "https-port"sv);
+ LuaOptions.AddOption("network.https.certfile"sv, ServerOptions.HttpConfig.CertFile, "cert-file"sv);
+ LuaOptions.AddOption("network.https.keyfile"sv, ServerOptions.HttpConfig.KeyFile, "key-file"sv);
#if ZEN_WITH_HTTPSYS
LuaOptions.AddOption("network.httpsys.async.workthreads"sv,
@@ -159,6 +164,10 @@ ZenServerConfiguratorBase::AddCommonConfigOptions(LuaConfig::Options& LuaOptions
LuaOptions.AddOption("network.httpsys.requestlogging"sv,
ServerOptions.HttpConfig.HttpSys.IsRequestLoggingEnabled,
"httpsys-enable-request-logging"sv);
+ LuaOptions.AddOption("network.httpsys.httpsport"sv, ServerOptions.HttpConfig.HttpSys.HttpsPort, "httpsys-https-port"sv);
+ LuaOptions.AddOption("network.httpsys.certthumbprint"sv, ServerOptions.HttpConfig.HttpSys.CertThumbprint, "httpsys-cert-thumbprint"sv);
+ LuaOptions.AddOption("network.httpsys.certstorename"sv, ServerOptions.HttpConfig.HttpSys.CertStoreName, "httpsys-cert-store"sv);
+ LuaOptions.AddOption("network.httpsys.httpsonly"sv, ServerOptions.HttpConfig.HttpSys.HttpsOnly, "httpsys-https-only"sv);
#endif
#if ZEN_WITH_TRACE
@@ -304,6 +313,34 @@ ZenServerCmdLineOptions::AddCliOptions(cxxopts::Options& options, ZenServerConfi
options.add_option("network",
"",
+ "unix-socket",
+ "Unix domain socket path to listen on (in addition to TCP)",
+ cxxopts::value<std::string>(ServerOptions.HttpConfig.UnixSocketPath),
+ "<path>");
+
+ options.add_option("network",
+ "",
+ "https-port",
+ "HTTPS listen port (0 = disabled)",
+ cxxopts::value<int>(ServerOptions.HttpConfig.HttpsPort)->default_value("0"),
+ "<port>");
+
+ options.add_option("network",
+ "",
+ "cert-file",
+ "Path to PEM certificate chain file for HTTPS",
+ cxxopts::value<std::string>(ServerOptions.HttpConfig.CertFile),
+ "<path>");
+
+ options.add_option("network",
+ "",
+ "key-file",
+ "Path to PEM private key file for HTTPS",
+ cxxopts::value<std::string>(ServerOptions.HttpConfig.KeyFile),
+ "<path>");
+
+ options.add_option("network",
+ "",
"security-config-path",
"Path to http security configuration file",
cxxopts::value<std::string>(SecurityConfigPath),
@@ -330,10 +367,45 @@ ZenServerCmdLineOptions::AddCliOptions(cxxopts::Options& options, ZenServerConfi
"Enables Httpsys request logging",
cxxopts::value<bool>(ServerOptions.HttpConfig.HttpSys.IsRequestLoggingEnabled),
"<httpsys request logging>");
+
+ options.add_option("httpsys",
+ "",
+ "httpsys-https-port",
+ "HTTPS listen port for http.sys (0 = disabled)",
+ cxxopts::value<int>(ServerOptions.HttpConfig.HttpSys.HttpsPort)->default_value("0"),
+ "<port>");
+
+ options.add_option("httpsys",
+ "",
+ "httpsys-cert-thumbprint",
+ "SHA-1 certificate thumbprint for auto SSL binding",
+ cxxopts::value<std::string>(ServerOptions.HttpConfig.HttpSys.CertThumbprint),
+ "<thumbprint>");
+
+ options.add_option("httpsys",
+ "",
+ "httpsys-cert-store",
+ "Windows certificate store name for SSL binding",
+ cxxopts::value<std::string>(ServerOptions.HttpConfig.HttpSys.CertStoreName)->default_value("MY"),
+ "<store name>");
+
+ options.add_option("httpsys",
+ "",
+ "httpsys-https-only",
+ "Disable HTTP listener when HTTPS is active",
+ cxxopts::value<bool>(ServerOptions.HttpConfig.HttpSys.HttpsOnly)->default_value("false"),
+ "");
#endif
options.add_option("network",
"",
+ "httpclient",
+ "Select HTTP client implementation (e.g. 'curl', 'cpr')",
+ cxxopts::value<std::string>(ServerOptions.HttpClient.Backend)->default_value("cpr"),
+ "<http client>");
+
+ options.add_option("network",
+ "",
"http",
"Select HTTP server implementation (asio|"
#if ZEN_WITH_HTTPSYS
@@ -397,6 +469,56 @@ ZenServerCmdLineOptions::ApplyOptions(cxxopts::Options& options, ZenServerConfig
ServerOptions.SecurityConfigPath = MakeSafeAbsolutePath(SecurityConfigPath);
LoggingOptions.ApplyOptions(ServerOptions.LoggingConfig);
+
+#if ZEN_WITH_HTTPSYS
+ // Validate HTTPS options
+ const auto& HttpSys = ServerOptions.HttpConfig.HttpSys;
+ if (HttpSys.HttpsOnly && HttpSys.HttpsPort == 0)
+ {
+ throw OptionParseException("'--httpsys-https-only' requires '--httpsys-https-port' to be set", options.help());
+ }
+ if (!HttpSys.CertThumbprint.empty() && HttpSys.CertThumbprint.size() != 40)
+ {
+ throw OptionParseException("'--httpsys-cert-thumbprint' must be exactly 40 hex characters (SHA-1)", options.help());
+ }
+ if (!HttpSys.CertThumbprint.empty())
+ {
+ for (char Ch : HttpSys.CertThumbprint)
+ {
+ if (!((Ch >= '0' && Ch <= '9') || (Ch >= 'a' && Ch <= 'f') || (Ch >= 'A' && Ch <= 'F')))
+ {
+ throw OptionParseException("'--httpsys-cert-thumbprint' contains non-hex characters", options.help());
+ }
+ }
+ }
+ if (HttpSys.HttpsPort > 0 && HttpSys.HttpsPort == ServerOptions.BasePort && !HttpSys.HttpsOnly)
+ {
+ throw OptionParseException("'--httpsys-https-port' must differ from '--port' when both HTTP and HTTPS are active", options.help());
+ }
+#endif
+
+ // Validate generic HTTPS options (used by ASIO backend)
+ if (ServerOptions.HttpConfig.HttpsPort > 0)
+ {
+ if (ServerOptions.HttpConfig.CertFile.empty() || ServerOptions.HttpConfig.KeyFile.empty())
+ {
+ throw OptionParseException("'--https-port' requires both '--cert-file' and '--key-file' to be set", options.help());
+ }
+ if (!std::filesystem::exists(ServerOptions.HttpConfig.CertFile))
+ {
+ throw OptionParseException(fmt::format("'--cert-file' path '{}' does not exist", ServerOptions.HttpConfig.CertFile),
+ options.help());
+ }
+ if (!std::filesystem::exists(ServerOptions.HttpConfig.KeyFile))
+ {
+ throw OptionParseException(fmt::format("'--key-file' path '{}' does not exist", ServerOptions.HttpConfig.KeyFile),
+ options.help());
+ }
+ if (ServerOptions.HttpConfig.HttpsPort == ServerOptions.BasePort)
+ {
+ throw OptionParseException("'--https-port' must differ from '--port'", options.help());
+ }
+ }
}
//////////////////////////////////////////////////////////////////////////
diff --git a/src/zenserver/config/config.h b/src/zenserver/config/config.h
index 55aee07f9..88226f810 100644
--- a/src/zenserver/config/config.h
+++ b/src/zenserver/config/config.h
@@ -38,8 +38,14 @@ struct ZenSentryConfig
bool Debug = false; // Enable debug mode for Sentry
};
+struct HttpClientConfig
+{
+ std::string Backend = "cpr"; // Choice of HTTP client implementation (e.g. "curl", "cpr")
+};
+
struct ZenServerConfig
{
+ HttpClientConfig HttpClient;
HttpServerConfig HttpConfig;
ZenSentryConfig SentryConfig;
ZenStatsConfig StatsConfig;
diff --git a/src/zenserver/sessions/httpsessions.cpp b/src/zenserver/sessions/httpsessions.cpp
index 05be3c814..6cf12bea4 100644
--- a/src/zenserver/sessions/httpsessions.cpp
+++ b/src/zenserver/sessions/httpsessions.cpp
@@ -258,6 +258,10 @@ HttpSessionsService::SessionRequest(HttpRouterRequest& Req)
}
return ServerRequest.WriteResponse(HttpResponseCode::NotFound);
}
+ default:
+ {
+ return ServerRequest.WriteResponse(HttpResponseCode::MethodNotAllowed);
+ }
}
}
diff --git a/src/zenserver/sessions/sessions.cpp b/src/zenserver/sessions/sessions.cpp
index f73aa40ff..d919db6e9 100644
--- a/src/zenserver/sessions/sessions.cpp
+++ b/src/zenserver/sessions/sessions.cpp
@@ -64,6 +64,8 @@ SessionsService::RegisterSession(const Oid& SessionId, std::string AppName, cons
return false;
}
+ ZEN_INFO("Session {} registered (AppName: {}, JobId: {})", SessionId, AppName, JobId);
+
const DateTime Now = DateTime::Now();
m_Sessions.emplace(SessionId,
Ref(new Session(SessionInfo{.Id = SessionId,
@@ -72,8 +74,6 @@ SessionsService::RegisterSession(const Oid& SessionId, std::string AppName, cons
.Metadata = CbObject::Clone(Metadata),
.CreatedAt = Now,
.UpdatedAt = Now})));
-
- ZEN_INFO("Session {} registered (AppName: {}, JobId: {})", SessionId, AppName, JobId);
return true;
}
diff --git a/src/zenserver/storage/cache/httpstructuredcache.cpp b/src/zenserver/storage/cache/httpstructuredcache.cpp
index 06b8f6c27..bbdb03ba4 100644
--- a/src/zenserver/storage/cache/httpstructuredcache.cpp
+++ b/src/zenserver/storage/cache/httpstructuredcache.cpp
@@ -1892,8 +1892,6 @@ HttpStructuredCacheService::CollectStats()
{
Cbo << "upstream_ratio" << (HitCount > 0 ? (double(UpstreamHitCount) / double(HitCount)) : 0.0);
Cbo << "upstream_hits" << m_CacheStats.UpstreamHitCount;
- Cbo << "upstream_ratio" << (HitCount > 0 ? (double(UpstreamHitCount) / double(HitCount)) : 0.0);
- Cbo << "upstream_ratio" << (HitCount > 0 ? (double(UpstreamHitCount) / double(HitCount)) : 0.0);
}
Cbo << "cidhits" << ChunkHitCount << "cidmisses" << ChunkMissCount << "cidwrites" << ChunkWriteCount;
@@ -2025,8 +2023,6 @@ HttpStructuredCacheService::HandleStatsRequest(HttpServerRequest& Request)
{
Cbo << "upstream_ratio" << (HitCount > 0 ? (double(UpstreamHitCount) / double(HitCount)) : 0.0);
Cbo << "upstream_hits" << m_CacheStats.UpstreamHitCount;
- Cbo << "upstream_ratio" << (HitCount > 0 ? (double(UpstreamHitCount) / double(HitCount)) : 0.0);
- Cbo << "upstream_ratio" << (HitCount > 0 ? (double(UpstreamHitCount) / double(HitCount)) : 0.0);
}
Cbo << "cidhits" << ChunkHitCount << "cidmisses" << ChunkMissCount << "cidwrites" << ChunkWriteCount;
diff --git a/src/zenserver/zenserver.cpp b/src/zenserver/zenserver.cpp
index 88b85d7d9..49ae1b6ff 100644
--- a/src/zenserver/zenserver.cpp
+++ b/src/zenserver/zenserver.cpp
@@ -23,6 +23,7 @@
#include <zencore/timer.h>
#include <zencore/trace.h>
#include <zencore/workthreadpool.h>
+#include <zenhttp/httpclient.h>
#include <zenhttp/httpserver.h>
#include <zenhttp/security/passwordsecurityfilter.h>
#include <zentelemetry/otlptrace.h>
@@ -146,6 +147,14 @@ ZenServerBase::Initialize(const ZenServerConfig& ServerOptions, ZenServerState::
EnqueueSigIntTimer();
+ // Configure HTTP client back-end
+
+ const std::string HttpClientBackend = ToLower(ServerOptions.HttpClient.Backend);
+ zen::SetDefaultHttpClientBackend(HttpClientBackend);
+ ZEN_INFO("Using '{}' as HTTP client backend", HttpClientBackend);
+
+ // Initialize HTTP server
+
m_Http = CreateHttpServer(ServerOptions.HttpConfig);
int EffectiveBasePort = m_Http->Initialize(ServerOptions.BasePort, ServerOptions.DataDir);
if (EffectiveBasePort == 0)