diff options
| author | Stefan Boberg <[email protected]> | 2026-03-16 10:56:11 +0100 |
|---|---|---|
| committer | GitHub Enterprise <[email protected]> | 2026-03-16 10:56:11 +0100 |
| commit | 8c3ba4e8c522d119df3cb48966e36c0eaa80aeb9 (patch) | |
| tree | cf51b07e097904044b4bf65bc3fe0ad14134074f | |
| parent | Merge branch 'sb/no-network' of https://github.ol.epicgames.net/ue-foundation... (diff) | |
| parent | Enable cross compilation of Windows targets on Linux (#839) (diff) | |
| download | zen-sb/no-network.tar.xz zen-sb/no-network.zip | |
Merge branch 'main' into sb/no-networksb/no-network
100 files changed, 4281 insertions, 2025 deletions
diff --git a/.github/workflows/create_release.yml b/.github/workflows/create_release.yml index 3345573c0..6d6a15bfc 100644 --- a/.github/workflows/create_release.yml +++ b/.github/workflows/create_release.yml @@ -127,6 +127,62 @@ jobs: name: zenserver-macos path: build/zenserver-macos.zip + docker-build: + name: Build Docker Images + runs-on: [linux, x64, zen] + timeout-minutes: 15 + needs: [bundle-linux, bundle-windows] + + steps: + - uses: actions/checkout@v4 + + - name: Read VERSION.txt + id: read_version + uses: ue-foundation/[email protected] + with: + path: "./VERSION.txt" + + - name: Download Linux bundle + uses: actions/download-artifact@v1 + with: + name: zenserver-linux + path: artifacts/linux + + - name: Download Windows bundle + uses: actions/download-artifact@v1 + with: + name: zenserver-win64 + path: artifacts/win64 + + - name: Extract binaries + run: | + mkdir -p build/linux/x86_64/release + unzip artifacts/linux/zenserver-linux.zip -d artifacts/linux-extracted + cp artifacts/linux-extracted/zenserver build/linux/x86_64/release/ + mkdir -p build/win-binary-staging + unzip artifacts/win64/zenserver-win64.zip -d artifacts/win-extracted + cp artifacts/win-extracted/zenserver.exe build/win-binary-staging/ + + - name: Build Docker image (with Wine + Windows binary) + run: | + docker build \ + -t zenserver-compute:${{ steps.read_version.outputs.content }} \ + --build-arg WIN_BINARY_DIR=build/win-binary-staging \ + -f docker/Dockerfile . + + - name: Build Docker image (Linux only, no Wine) + run: | + docker build \ + -t zenserver-compute-linux:${{ steps.read_version.outputs.content }} \ + --build-arg INSTALL_WINE=false \ + -f docker/Dockerfile . + + # TODO: Push images to container registry + # - name: Push images + # run: | + # docker push zenserver-compute:${{ steps.read_version.outputs.content }} + # docker push zenserver-compute-linux:${{ steps.read_version.outputs.content }} + create-release: runs-on: [linux, x64, zen] timeout-minutes: 5 diff --git a/.github/workflows/validate.yml b/.github/workflows/validate.yml index d96645ac9..dfdb9677d 100644 --- a/.github/workflows/validate.yml +++ b/.github/workflows/validate.yml @@ -247,3 +247,48 @@ jobs: with: name: zenserver-macos path: build/zenserver-macos.zip + + docker-build: + name: Build Docker Images + if: github.ref_name == 'main' + runs-on: [linux, x64, zen] + timeout-minutes: 15 + needs: [linux-build, windows-build] + + steps: + - uses: actions/checkout@v4 + + - name: Download Linux bundle + uses: actions/download-artifact@v1 + with: + name: zenserver-linux + path: artifacts/linux + + - name: Download Windows bundle + uses: actions/download-artifact@v1 + with: + name: zenserver-win64 + path: artifacts/win64 + + - name: Extract binaries + run: | + mkdir -p build/linux/x86_64/release + unzip artifacts/linux/zenserver-linux.zip -d artifacts/linux-extracted + cp artifacts/linux-extracted/zenserver build/linux/x86_64/release/ + mkdir -p build/win-binary-staging + unzip artifacts/win64/zenserver-win64.zip -d artifacts/win-extracted + cp artifacts/win-extracted/zenserver.exe build/win-binary-staging/ + + - name: Build Docker image (with Wine + Windows binary) + run: | + docker build \ + -t zenserver-compute:latest \ + --build-arg WIN_BINARY_DIR=build/win-binary-staging \ + -f docker/Dockerfile . + + - name: Build Docker image (Linux only, no Wine) + run: | + docker build \ + -t zenserver-compute-linux:latest \ + --build-arg INSTALL_WINE=false \ + -f docker/Dockerfile . diff --git a/.gitignore b/.gitignore index 0aa028930..5c9195566 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,7 @@ .DS_Store .claude/settings.local.json .profile/ +.xwin-cache/ # User-specific files *.suo @@ -112,3 +113,6 @@ CMake* # Ue tool chain temp directory .tmp-ue-toolchain/ + +# Generated frontend zip (built automatically by xmake) +src/zenserver/frontend/html.zip diff --git a/CHANGELOG.md b/CHANGELOG.md index 3dd9ae35f..930b9938c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,12 +9,15 @@ - Feature: Added `zen down --all` flag to shut down all running zenserver instances - Feature: Added `xmake kill` task to terminate all running zenserver instances - Feature: Support `ZEN_MALLOC` environment variable for default allocator selection; default allocator switched to rpmalloc +- Feature: Added Dockerfile and extended GHA workflows to produce zenserver Docker images, primarily intended for compute nodes +- Improvement: Added support for CoW block-cloning (used by build download) on Linux (tested with: btrfs/ XFS) +- Improvement: Added full-file CoW copying on macOS (APFS) - Improvement: Updated asio to 1.38.0 - Improvement: Updated fmt to 1.12.1 - Bugfix: Fixed sentry-native build to allow LTO on Windows - Bugfix: Minor test stability fixes (flaky hash collisions, per-thread RNG seeds) -## 5.7.21 +## 5.7.22 - Feature: Add `--allow-partial-block-requests` to `zen oplog-import` - Feature: The integrated dashboard has been updated with streamlined more interactive interface with streaming statistics where available. The dashboard can now also be reached without qualification. @@ -173,7 +173,7 @@ The codebase is organized into layered modules with clear dependencies: - Web UI bundled as ZIP in `src/zenserver/frontend/*.zip` - Dashboards for hub, orchestrator, and compute services are located in `src/zenserver/frontent/html/` - These are the files which end up being bundled into the front-end zip mentioned above -- Update with `xmake updatefrontend` after modifying HTML/JS, and check in the resulting zip +- The zip is generated automatically at build time when source files change **Memory Management:** - Can use mimalloc or rpmalloc for performance @@ -309,6 +309,4 @@ When debugging zenserver-test or other multi-process scenarios, use child proces # Create deployable ZIP bundle xmake bundle -# Update frontend ZIP after HTML changes -xmake updatefrontend ``` diff --git a/docker/Dockerfile b/docker/Dockerfile new file mode 100644 index 000000000..060042d85 --- /dev/null +++ b/docker/Dockerfile @@ -0,0 +1,55 @@ +# Copyright Epic Games, Inc. All Rights Reserved. +# +# Runtime image for zenserver compute workers. +# +# Build variants: +# With Wine (for running Windows worker executables via WineProcessRunner): +# docker build -t zenserver-compute -f docker/Dockerfile . +# +# Without Wine (Linux-only workers, smaller image): +# docker build -t zenserver-compute --build-arg INSTALL_WINE=false -f docker/Dockerfile . +# +# The build context must contain the pre-built Linux zenserver binary at: +# build/linux/x86_64/release/zenserver + +FROM ubuntu:24.04 + +# Avoid interactive prompts during package installation +ENV DEBIAN_FRONTEND=noninteractive + +# Set to "false" to skip Wine installation and produce a smaller image +ARG INSTALL_WINE=true + +# Install WineHQ (only when INSTALL_WINE=true) +# Enables i386 architecture required for Wine32 support +RUN if [ "$INSTALL_WINE" = "true" ]; then \ + dpkg --add-architecture i386 \ + && apt-get update \ + && apt-get install -y --no-install-recommends \ + ca-certificates \ + gnupg \ + wget \ + && wget -qO- https://dl.winehq.org/wine-builds/winehq.key \ + | gpg --dearmor -o /usr/share/keyrings/winehq-archive.key \ + && echo "deb [signed-by=/usr/share/keyrings/winehq-archive.key] https://dl.winehq.org/wine-builds/ubuntu/ noble main" \ + > /etc/apt/sources.list.d/winehq.list \ + && apt-get update \ + && apt-get install -y --no-install-recommends winehq-stable \ + && apt-get remove -y gnupg wget \ + && apt-get autoremove -y \ + && rm -rf /var/lib/apt/lists/* ; \ + fi + +# Copy the pre-built zenserver binary +COPY build/linux/x86_64/release/zenserver /opt/zenserver/zenserver +RUN chmod +x /opt/zenserver/zenserver + +# Optional: Windows zenserver binary for Wine-based compute workers. +# Set WIN_BINARY_DIR to a directory containing zenserver.exe to include it. +# Defaults to docker/empty (an empty placeholder) so the COPY is a no-op. +ARG WIN_BINARY_DIR=docker/empty +COPY ${WIN_BINARY_DIR}/ /opt/zenserver/ + +EXPOSE 8558 + +ENTRYPOINT ["/opt/zenserver/zenserver"] diff --git a/docker/empty/.gitkeep b/docker/empty/.gitkeep new file mode 100644 index 000000000..e69de29bb --- /dev/null +++ b/docker/empty/.gitkeep diff --git a/docs/WindowsOnLinux.md b/docs/WindowsOnLinux.md new file mode 100644 index 000000000..540447cb2 --- /dev/null +++ b/docs/WindowsOnLinux.md @@ -0,0 +1,109 @@ +# Cross-Compiling Windows Binaries on Linux + +This document describes how to build `zenserver.exe` and `zen.exe` for Windows +from a Linux host using clang-cl, lld-link, and the Windows SDK fetched via +[xwin](https://github.com/Jake-Shadle/xwin). + +## Prerequisites + +- **LLVM/Clang 17+** with the following tools (typically in `/usr/lib/llvm-<ver>/bin/`): + - `clang` (used as `clang-cl` via `--driver-mode=cl`) + - `lld-link` + - `llvm-lib` + - `llvm-rc` + - `llvm-ml` + - `llvm-mt` + + On Ubuntu/Debian: + ```bash + sudo apt install clang-18 lld-18 llvm-18 + ``` + +- **xmake** (build system) — see [xmake.io](https://xmake.io) for installation. + +## Step 1: Fetch the Windows SDK + +Run the provided script to download the Windows SDK and MSVC CRT headers/libs: + +```bash +./scripts/win_cross/get_win_sdk.sh +``` + +This downloads the SDK to `~/.xwin-sdk` by default. Override the location with +an argument or the `$XWIN_SDK_DIR` environment variable: + +```bash +./scripts/win_cross/get_win_sdk.sh /path/to/sdk +``` + +The script also creates: +- Tool wrapper scripts (clang-cl, lld-link, llvm-lib, etc.) in the SDK `bin/` directory +- An MSVC-compatible directory layout that xmake's built-in clang-cl toolchain can discover +- Debug CRT lib symlinks so cmake package builds succeed + +## Step 2: Configure and Build + +```bash +xmake config -y -p windows -a x64 \ + --toolchain=clang-cl \ + --sdk=$HOME/.xwin-sdk \ + --mrc=$HOME/.xwin-sdk/bin/x64/rc.exe \ + -m release \ + --zensentry=n \ + --zenmimalloc=n + +xmake +``` + +### Notes + +- **`--sdk` must use an absolute path** — `$HOME/.xwin-sdk` works, but + `~/.xwin-sdk` does not (xmake doesn't expand `~`). +- **`--mrc`** points to the resource compiler wrapper. Without it, xmake can't + auto-detect `llvm-rc` (it doesn't support `--version`). +- **`--zensentry=no`** is recommended — crashpad (used by sentry-native) is + difficult to cross-compile. +- **`--zenmimalloc=no`** is recommended initially to reduce the package surface. +- **LTO is automatically disabled** for cross-compilation builds. + +## Step 3: Verify + +```bash +file build/windows/x64/release/zenserver.exe +# PE32+ executable (console) x86-64, for MS Windows, 12 sections + +file build/windows/x64/release/zen.exe +# PE32+ executable (console) x86-64, for MS Windows, 12 sections +``` + +## Running with Wine (optional) + +The resulting binaries can be run under Wine for testing: + +```bash +wine build/windows/x64/release/zenserver.exe --help +``` + +## Troubleshooting + +### Library not found (case sensitivity) + +The Windows SDK ships headers and libs with specific casing (e.g. `DbgHelp.h`, +`ws2_32.lib`). Linux filesystems are case-sensitive, so `#include <Dbghelp.h>` +won't find `DbgHelp.h`. The codebase uses lowercase includes where possible. If +you encounter a missing header/lib, check the actual casing in the SDK directory. + +### `add_ldflags(...) is ignored` + +xmake's auto flag checker may reject MSVC linker flags when using clang-cl. Add +`{force = true}` to the `add_ldflags()` call. + +### Stale build state + +If you hit unexpected errors after changing toolchain settings, clean everything: + +```bash +rm -rf .xmake build ~/.xmake +``` + +Then reconfigure and rebuild. diff --git a/repo/packages/c/consul/xmake.lua b/repo/packages/c/consul/xmake.lua index 6982e6f03..82bd803b6 100644 --- a/repo/packages/c/consul/xmake.lua +++ b/repo/packages/c/consul/xmake.lua @@ -29,10 +29,14 @@ package("consul") end) on_test(function (package) + -- Skip binary verification when cross-compiling (e.g. Windows target on Linux host) + if package:is_cross() then + return + end if is_plat("windows") then os.run("%s version", package:installdir("bin", "consul.exe")) - elseif is_plat("linux") then - -- this should include macosx as well, but needs more logic to differentiate arm64 vs + elseif is_plat("linux") then + -- this should include macosx as well, but needs more logic to differentiate arm64 vs -- amd64 since arm64 binary won't run on amd64 macs. arm64 macs have Rosetta though so -- they can run the amd64 binary. os.run("%s version", package:installdir("bin", "consul")) diff --git a/repo/packages/m/mimalloc/xmake.lua b/repo/packages/m/mimalloc/xmake.lua index 993e4c1a9..54d6613b8 100644 --- a/repo/packages/m/mimalloc/xmake.lua +++ b/repo/packages/m/mimalloc/xmake.lua @@ -35,7 +35,7 @@ package("mimalloc") end on_install("macosx", "windows", "linux", "android", "mingw", function (package) - local configs = {"-DMI_OVERRIDE=OFF"} + local configs = {"-DMI_OVERRIDE=OFF", "-DCMAKE_BUILD_TYPE=" .. (package:is_debug() and "Debug" or "Release")} table.insert(configs, "-DMI_BUILD_STATIC=" .. (package:config("shared") and "OFF" or "ON")) table.insert(configs, "-DMI_BUILD_SHARED=" .. (package:config("shared") and "ON" or "OFF")) table.insert(configs, "-DMI_SECURE=" .. (package:config("secure") and "ON" or "OFF")) @@ -71,5 +71,7 @@ package("mimalloc") end) on_test(function (package) - assert(package:has_cfuncs("mi_malloc", {includes = "mimalloc.h"})) + if not package:is_cross() then + assert(package:has_cfuncs("mi_malloc", {includes = "mimalloc.h"})) + end end) diff --git a/repo/packages/n/nomad/xmake.lua b/repo/packages/n/nomad/xmake.lua index 85ea10985..20380e1a1 100644 --- a/repo/packages/n/nomad/xmake.lua +++ b/repo/packages/n/nomad/xmake.lua @@ -29,6 +29,9 @@ package("nomad") end) on_test(function (package) + if package:is_cross() then + return + end if is_plat("windows") then os.run("%s version", package:installdir("bin", "nomad.exe")) elseif is_plat("linux") then diff --git a/repo/packages/o/oidctoken/xmake.lua b/repo/packages/o/oidctoken/xmake.lua index 76360e7bf..4dc231b21 100644 --- a/repo/packages/o/oidctoken/xmake.lua +++ b/repo/packages/o/oidctoken/xmake.lua @@ -15,6 +15,9 @@ package("oidctoken") end) on_test(function (package) + if package:is_cross() then + return + end if is_plat("windows") then os.run("%s --help", package:installdir("bin", "OidcToken.exe")) else diff --git a/repo/packages/s/sentry-native/patches/0.12.1/crashpad_static_libcxx.patch b/repo/packages/s/sentry-native/patches/0.12.1/crashpad_static_libcxx.patch index 2005ad4ec..8d0a8f11e 100644 --- a/repo/packages/s/sentry-native/patches/0.12.1/crashpad_static_libcxx.patch +++ b/repo/packages/s/sentry-native/patches/0.12.1/crashpad_static_libcxx.patch @@ -1,24 +1,25 @@ --- a/external/crashpad/handler/CMakeLists.txt 2026-03-09 14:47:42.109197582 +0000 +++ b/external/crashpad/handler/CMakeLists.txt 2026-03-09 14:51:45.343538268 +0000 -@@ -120,6 +120,21 @@ +@@ -120,6 +120,22 @@ endif() endif() - + ++ # Statically link libc++ and libc++abi into crashpad_handler so it has ++ # no runtime dependency on libc++.so.1. This is needed when building with ++ # a toolchain that uses libc++ (e.g. UE clang) but deploys to systems ++ # where libc++.so.1 is not available. ++ # Only applied when -stdlib=libc++ is active (i.e. not GCC or system clang ++ # using libstdc++). + if(LINUX) -+ # Statically link libc++ and libc++abi into crashpad_handler so it has -+ # no runtime dependency on libc++.so.1. This is needed when building with -+ # a toolchain that uses libc++ (e.g. UE clang) but deploys to systems -+ # where libc++.so.1 is not available. -+ # -nostdlib++ suppresses clang's automatic -lc++ addition (a linker flag, -+ # added at the end). The explicit -Bstatic libs are added via -+ # target_link_libraries so they appear after crashpad's static archives in -+ # the link order, letting the single-pass linker resolve all libc++ symbols. -+ target_link_options(crashpad_handler PRIVATE -nostdlib++) -+ target_link_libraries(crashpad_handler PRIVATE -+ -Wl,-Bstatic,-lc++,-lc++abi,-Bdynamic -+ ) ++ string(FIND "${CMAKE_CXX_FLAGS}" "-stdlib=libc++" _libcxx_pos) ++ if(NOT _libcxx_pos EQUAL -1) ++ target_link_options(crashpad_handler PRIVATE -nostdlib++) ++ target_link_libraries(crashpad_handler PRIVATE ++ -Wl,-Bstatic,-lc++,-lc++abi,-Bdynamic ++ ) ++ endif() + endif() + set_property(TARGET crashpad_handler PROPERTY EXPORT_NAME crashpad_handler) add_executable(crashpad::handler ALIAS crashpad_handler) - + diff --git a/repo/packages/s/sentry-native/xmake.lua b/repo/packages/s/sentry-native/xmake.lua index 43672b9de..0da513ead 100644 --- a/repo/packages/s/sentry-native/xmake.lua +++ b/repo/packages/s/sentry-native/xmake.lua @@ -37,7 +37,7 @@ package("sentry-native") add_versions("0.4.4", "fe6c711d42861e66e53bfd7ee0b2b226027c64446857f0d1bbb239ca824a3d8d") add_patches("0.4.4", path.join(os.scriptdir(), "patches", "0.4.4", "zlib_fix.patch"), "1a6ac711b7824112a9062ec1716a316facce5055498d1f87090d2cad031b865b") add_patches("0.7.6", path.join(os.scriptdir(), "patches", "0.7.6", "breakpad_exceptions.patch"), "7781bad0404a92252cbad39e865d17ac663eedade03cbd29c899636c7bfab1b5") - add_patches("0.12.1", path.join(os.scriptdir(), "patches", "0.12.1", "crashpad_static_libcxx.patch"), "3c2115b90179808fa639865f6eb23090e2cb6025d816ffb66c2d75c26473ec72") + add_patches("0.12.1", path.join(os.scriptdir(), "patches", "0.12.1", "crashpad_static_libcxx.patch"), "e297c1b9dc58f446edfec5566a73c9e3e6b53c207f7247d45b93c640af2bff1a") add_patches("0.12.1", path.join(os.scriptdir(), "patches", "0.12.1", "breakpad_exceptions.patch"), "9e0cd152192f87b9ce182c8ddff22c0471acb99bd61a872ca48afbbacdf27575") add_deps("cmake") diff --git a/scripts/docker.lua b/scripts/docker.lua new file mode 100644 index 000000000..f66f8db86 --- /dev/null +++ b/scripts/docker.lua @@ -0,0 +1,88 @@ +-- Copyright Epic Games, Inc. All Rights Reserved. + +import("core.base.option") + +-------------------------------------------------------------------------------- +local function _get_version() + local version_file = path.join(os.projectdir(), "VERSION.txt") + local version = io.readfile(version_file) + if version then + version = version:trim() + end + if not version or version == "" then + raise("Failed to read version from VERSION.txt") + end + return version +end + +-------------------------------------------------------------------------------- +function main() + local registry = option.get("registry") + local tag = option.get("tag") + local push = option.get("push") + local no_wine = option.get("no-wine") + local win_binary = option.get("win-binary") + + if not tag then + tag = _get_version() + end + + local image_name = no_wine and "zenserver-compute-linux" or "zenserver-compute" + if registry then + image_name = registry .. "/" .. image_name + end + + local full_tag = image_name .. ":" .. tag + + -- Verify the zenserver binary exists + local binary_path = path.join(os.projectdir(), "build/linux/x86_64/release/zenserver") + if not os.isfile(binary_path) then + raise("zenserver binary not found at %s\nBuild it first with: xmake config -y -m release -a x64 && xmake build -y zenserver", binary_path) + end + + -- Stage Windows binary if provided + local win_staging_dir = nil + if win_binary then + if not os.isfile(win_binary) then + raise("Windows binary not found at %s", win_binary) + end + win_staging_dir = path.join(os.projectdir(), "build/win-binary-staging") + os.mkdir(win_staging_dir) + os.cp(win_binary, path.join(win_staging_dir, "zenserver.exe")) + print("-- Including Windows binary: %s", win_binary) + end + + -- Build the Docker image + local dockerfile = path.join(os.projectdir(), "docker/Dockerfile") + print("-- Building Docker image: %s", full_tag) + local args = {"build", "-t", full_tag, "-f", dockerfile} + if no_wine then + table.insert(args, "--build-arg") + table.insert(args, "INSTALL_WINE=false") + end + if win_staging_dir then + table.insert(args, "--build-arg") + table.insert(args, "WIN_BINARY_DIR=build/win-binary-staging") + end + table.insert(args, os.projectdir()) + local ret = os.execv("docker", args) + if ret > 0 then + raise("Docker build failed") + end + + -- Clean up staging directory + if win_staging_dir then + os.rmdir(win_staging_dir) + end + + print("-- Built image: %s", full_tag) + + if push then + print("-- Pushing image: %s", full_tag) + ret = os.execv("docker", {"push", full_tag}) + if ret > 0 then + raise("Docker push failed") + end + print("-- Pushed image: %s", full_tag) + end +end diff --git a/scripts/test.lua b/scripts/test.lua new file mode 100644 index 000000000..df1218ce8 --- /dev/null +++ b/scripts/test.lua @@ -0,0 +1,401 @@ +-- Copyright Epic Games, Inc. All Rights Reserved. + +function main() + import("core.base.option") + import("core.project.config") + import("core.project.project") + + config.load() + + -- Override table: target name -> short name (for targets that don't follow convention) + local short_name_overrides = { + ["zenserver-test"] = "integration", + } + + -- Build test list from targets in the "tests" group + local available_tests = {} + for name, target in pairs(project.targets()) do + if target:get("group") == "tests" and name:endswith("-test") then + local short = short_name_overrides[name] + if not short then + -- Derive short name: "zencore-test" -> "core" + short = name + if short:startswith("zen") then short = short:sub(4) end + if short:endswith("-test") then short = short:sub(1, -6) end + end + table.insert(available_tests, {short, name}) + end + end + + -- Add non-test-group entries that have a test subcommand + table.insert(available_tests, {"server", "zenserver"}) + + table.sort(available_tests, function(a, b) return a[1] < b[1] end) + + -- Handle --list: print discovered test names and exit + if option.get("list") then + printf("Available tests:\n") + for _, entry in ipairs(available_tests) do + printf(" %-16s -> %s\n", entry[1], entry[2]) + end + return + end + + local testname = option.get("run") + + -- Parse comma-separated test names into a set + local requested = {} + for token in testname:gmatch("[^,]+") do + requested[token:match("^%s*(.-)%s*$")] = true + end + + -- Filter to requested test(s) + local tests = {} + local matched = {} + + for _, entry in ipairs(available_tests) do + local name, target = entry[1], entry[2] + if requested["all"] or requested[name] then + table.insert(tests, {name = name, target = target}) + matched[name] = true + end + end + + -- Check for unknown test names + if not requested["all"] then + for name, _ in pairs(requested) do + if not matched[name] then + raise("no tests match specification: '%s'", name) + end + end + end + + if #tests == 0 then + raise("no tests match specification: '%s'", testname) + end + + local plat, arch + if is_host("windows") then + plat = "windows" + arch = "x64" + elseif is_host("macosx") then + plat = "macosx" + arch = is_arch("arm64") and "arm64" or "x86_64" + else + plat = "linux" + arch = "x86_64" + end + + -- Only reconfigure if current config doesn't already match + if config.get("mode") ~= "debug" or config.get("plat") ~= plat or config.get("arch") ~= arch then + local toolchain_flag = config.get("toolchain") and ("--toolchain=" .. config.get("toolchain")) or "" + local sdk_flag = config.get("sdk") and ("--sdk=" .. config.get("sdk")) or "" + os.exec("xmake config -y -c -m debug -p %s -a %s %s %s", plat, arch, toolchain_flag, sdk_flag) + end + + -- Build targets we're going to run + if requested["all"] then + os.exec("xmake build -y") + else + for _, entry in ipairs(tests) do + os.exec("xmake build -y %s", entry.target) + end + end + + local use_junit_reporting = option.get("junit") + local use_noskip = option.get("noskip") + local use_verbose = option.get("verbose") + local repeat_count = tonumber(option.get("repeat")) or 1 + local extra_args = option.get("arguments") or {} + local junit_report_files = {} + + local junit_report_dir + if use_junit_reporting then + junit_report_dir = path.join(os.projectdir(), config.get("buildir"), "reports") + os.mkdir(junit_report_dir) + end + + -- Results collection for summary table + local results = {} + local any_failed = false + + -- Format a number with thousands separators (e.g. 31103 -> "31,103") + local function format_number(n) + local s = tostring(n) + local pos = #s % 3 + if pos == 0 then pos = 3 end + local result = s:sub(1, pos) + for i = pos + 1, #s, 3 do + result = result .. "," .. s:sub(i, i + 2) + end + return result + end + + -- Center a string within a given width + local function center_str(s, width) + local pad = width - #s + local lpad = math.floor(pad / 2) + local rpad = pad - lpad + return string.rep(" ", lpad) .. s .. string.rep(" ", rpad) + end + + -- Left-align a string within a given width (with 1-space left margin) + local function left_align_str(s, width) + return " " .. s .. string.rep(" ", width - #s - 1) + end + + -- Right-align a string within a given width (with 1-space right margin) + local function right_align_str(s, width) + return string.rep(" ", width - #s - 1) .. s .. " " + end + + -- Format elapsed seconds as a human-readable string + local function format_time(seconds) + if seconds >= 60 then + local mins = math.floor(seconds / 60) + local secs = seconds - mins * 60 + return string.format("%dm %04.1fs", mins, secs) + else + return string.format("%.1fs", seconds) + end + end + + -- Parse test summary file written by TestListener + local function parse_summary_file(filepath) + if not os.isfile(filepath) then return nil end + local content = io.readfile(filepath) + if not content then return nil end + local ct = content:match("cases_total=(%d+)") + local cp = content:match("cases_passed=(%d+)") + local at = content:match("assertions_total=(%d+)") + local ap = content:match("assertions_passed=(%d+)") + if ct then + local failures = {} + for name, file, line in content:gmatch("failed=([^|\n]+)|([^|\n]+)|(%d+)") do + table.insert(failures, {name = name, file = file, line = tonumber(line)}) + end + local es = content:match("elapsed_seconds=([%d%.]+)") + return { + cases_total = tonumber(ct), + cases_passed = tonumber(cp) or 0, + asserts_total = tonumber(at) or 0, + asserts_passed = tonumber(ap) or 0, + elapsed_seconds = tonumber(es) or 0, + failures = failures + } + end + return nil + end + + -- Temp directory for summary files + local summary_dir = path.join(os.tmpdir(), "zen-test-summary") + os.mkdir(summary_dir) + + -- Run each test suite and collect results + for iteration = 1, repeat_count do + if repeat_count > 1 then + printf("\n*** Iteration %d/%d ***\n", iteration, repeat_count) + end + + for _, entry in ipairs(tests) do + local name, target = entry.name, entry.target + printf("=== %s ===\n", target) + + local suite_name = target + if name == "server" then + suite_name = "zenserver (test)" + end + + local cmd = string.format("xmake run %s", target) + if name == "server" then + cmd = string.format("xmake run %s test", target) + end + cmd = string.format("%s --duration=true", cmd) + + if use_junit_reporting then + local junit_report_file = path.join(junit_report_dir, string.format("junit-%s-%s-%s.xml", config.plat(), arch, target)) + junit_report_files[target] = junit_report_file + cmd = string.format("%s --reporters=junit --out=%s", cmd, junit_report_file) + end + if use_noskip then + cmd = string.format("%s --no-skip", cmd) + end + if use_verbose and name == "integration" then + cmd = string.format("%s --verbose", cmd) + end + for _, arg in ipairs(extra_args) do + cmd = string.format("%s %s", cmd, arg) + end + + -- Tell TestListener where to write the summary + local summary_file = path.join(summary_dir, target .. ".txt") + os.setenv("ZEN_TEST_SUMMARY_FILE", summary_file) + + -- Run test with real-time streaming output + local test_ok = true + try { + function() + os.exec(cmd) + end, + catch { + function(errors) + test_ok = false + end + } + } + + -- Read summary written by TestListener + local summary = parse_summary_file(summary_file) + os.tryrm(summary_file) + + if not test_ok then + any_failed = true + end + + table.insert(results, { + suite = suite_name, + cases_passed = summary and summary.cases_passed or 0, + cases_total = summary and summary.cases_total or 0, + asserts_passed = summary and summary.asserts_passed or 0, + asserts_total = summary and summary.asserts_total or 0, + elapsed_seconds = summary and summary.elapsed_seconds or 0, + failures = summary and summary.failures or {}, + passed = test_ok + }) + end + + if any_failed then + if repeat_count > 1 then + printf("\n*** Failure detected on iteration %d, stopping ***\n", iteration) + end + break + end + end + + -- Clean up + os.setenv("ZEN_TEST_SUMMARY_FILE", "") + os.tryrm(summary_dir) + + -- Print JUnit reports if requested + for test, junit_report_file in pairs(junit_report_files) do + printf("=== report - %s ===\n", test) + if os.isfile(junit_report_file) then + local data = io.readfile(junit_report_file) + if data then + print(data) + end + end + end + + -- Print summary table + if #results > 0 then + -- Calculate column widths based on content + local col_suite = #("Suite") + local col_cases = #("Cases") + local col_asserts = #("Assertions") + local col_time = #("Time") + local col_status = #("Status") + + -- Compute totals + local total_cases_passed = 0 + local total_cases_total = 0 + local total_asserts_passed = 0 + local total_asserts_total = 0 + local total_elapsed = 0 + + for _, r in ipairs(results) do + col_suite = math.max(col_suite, #r.suite) + local cases_str = format_number(r.cases_passed) .. "/" .. format_number(r.cases_total) + col_cases = math.max(col_cases, #cases_str) + local asserts_str = format_number(r.asserts_passed) .. "/" .. format_number(r.asserts_total) + col_asserts = math.max(col_asserts, #asserts_str) + col_time = math.max(col_time, #format_time(r.elapsed_seconds)) + local status_str = r.passed and "SUCCESS" or "FAILED" + col_status = math.max(col_status, #status_str) + + total_cases_passed = total_cases_passed + r.cases_passed + total_cases_total = total_cases_total + r.cases_total + total_asserts_passed = total_asserts_passed + r.asserts_passed + total_asserts_total = total_asserts_total + r.asserts_total + total_elapsed = total_elapsed + r.elapsed_seconds + end + + -- Account for totals row in column widths + col_suite = math.max(col_suite, #("Total")) + col_cases = math.max(col_cases, #(format_number(total_cases_passed) .. "/" .. format_number(total_cases_total))) + col_asserts = math.max(col_asserts, #(format_number(total_asserts_passed) .. "/" .. format_number(total_asserts_total))) + col_time = math.max(col_time, #format_time(total_elapsed)) + + -- Add padding (1 space each side) + col_suite = col_suite + 2 + col_cases = col_cases + 2 + col_asserts = col_asserts + 2 + col_time = col_time + 2 + col_status = col_status + 2 + + -- Build horizontal border segments + local h_suite = string.rep("-", col_suite) + local h_cases = string.rep("-", col_cases) + local h_asserts = string.rep("-", col_asserts) + local h_time = string.rep("-", col_time) + local h_status = string.rep("-", col_status) + + local top = "+" .. h_suite .. "+" .. h_cases .. "+" .. h_asserts .. "+" .. h_time .. "+" .. h_status .. "+" + local mid = "+" .. h_suite .. "+" .. h_cases .. "+" .. h_asserts .. "+" .. h_time .. "+" .. h_status .. "+" + local bottom = "+" .. h_suite .. "+" .. h_cases .. "+" .. h_asserts .. "+" .. h_time .. "+" .. h_status .. "+" + local vbar = "|" + + local header_msg = any_failed and "Some tests failed:" or "All tests passed:" + printf("\n* %s\n", header_msg) + printf(" %s\n", top) + printf(" %s%s%s%s%s%s%s%s%s%s%s\n", vbar, center_str("Suite", col_suite), vbar, center_str("Cases", col_cases), vbar, center_str("Assertions", col_asserts), vbar, center_str("Time", col_time), vbar, center_str("Status", col_status), vbar) + + for _, r in ipairs(results) do + printf(" %s\n", mid) + local cases_str = format_number(r.cases_passed) .. "/" .. format_number(r.cases_total) + local asserts_str = format_number(r.asserts_passed) .. "/" .. format_number(r.asserts_total) + local time_str = format_time(r.elapsed_seconds) + local status_str = r.passed and "SUCCESS" or "FAILED" + printf(" %s%s%s%s%s%s%s%s%s%s%s\n", vbar, left_align_str(r.suite, col_suite), vbar, right_align_str(cases_str, col_cases), vbar, right_align_str(asserts_str, col_asserts), vbar, right_align_str(time_str, col_time), vbar, right_align_str(status_str, col_status), vbar) + end + + -- Totals row + if #results > 1 then + local h_suite_eq = string.rep("=", col_suite) + local h_cases_eq = string.rep("=", col_cases) + local h_asserts_eq = string.rep("=", col_asserts) + local h_time_eq = string.rep("=", col_time) + local h_status_eq = string.rep("=", col_status) + local totals_sep = "+" .. h_suite_eq .. "+" .. h_cases_eq .. "+" .. h_asserts_eq .. "+" .. h_time_eq .. "+" .. h_status_eq .. "+" + printf(" %s\n", totals_sep) + + local total_cases_str = format_number(total_cases_passed) .. "/" .. format_number(total_cases_total) + local total_asserts_str = format_number(total_asserts_passed) .. "/" .. format_number(total_asserts_total) + local total_time_str = format_time(total_elapsed) + local total_status_str = any_failed and "FAILED" or "SUCCESS" + printf(" %s%s%s%s%s%s%s%s%s%s%s\n", vbar, left_align_str("Total", col_suite), vbar, right_align_str(total_cases_str, col_cases), vbar, right_align_str(total_asserts_str, col_asserts), vbar, right_align_str(total_time_str, col_time), vbar, right_align_str(total_status_str, col_status), vbar) + end + + printf(" %s\n", bottom) + end + + -- Print list of individual failing tests + if any_failed then + printf("\n Failures:\n") + for _, r in ipairs(results) do + if #r.failures > 0 then + printf(" -- %s --\n", r.suite) + for _, f in ipairs(r.failures) do + printf(" FAILED: %s (%s:%d)\n", f.name, f.file, f.line) + end + elseif not r.passed then + printf(" -- %s --\n", r.suite) + printf(" (test binary exited with error, no failure details available)\n") + end + end + end + + if any_failed then + raise("one or more test suites failed") + end +end diff --git a/scripts/test_scripts/block-clone-test-mac.sh b/scripts/test_scripts/block-clone-test-mac.sh new file mode 100755 index 000000000..a3d3ca4d3 --- /dev/null +++ b/scripts/test_scripts/block-clone-test-mac.sh @@ -0,0 +1,43 @@ +#!/usr/bin/env bash +# Test block-clone functionality on macOS (APFS). +# +# APFS is the default filesystem on modern Macs and natively supports +# clonefile(), so no special setup is needed — just run the tests. +# +# Usage: +# ./scripts/test_scripts/block-clone-test-mac.sh [path-to-zencore-test] +# +# If no path is given, defaults to build/macosx/<arch>/debug/zencore-test +# relative to the repository root. + +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +REPO_ROOT="$(cd "$SCRIPT_DIR/../.." && pwd)" + +ARCH="$(uname -m)" +TEST_BINARY="${1:-$REPO_ROOT/build/macosx/$ARCH/debug/zencore-test}" + +if [ ! -x "$TEST_BINARY" ]; then + echo "error: test binary not found or not executable: $TEST_BINARY" >&2 + echo "hint: build with 'xmake config -m debug && xmake build zencore-test'" >&2 + exit 1 +fi + +# Verify we're on APFS +BINARY_DIR="$(dirname "$TEST_BINARY")" +FS_TYPE="$(diskutil info "$(df "$BINARY_DIR" | tail -1 | awk '{print $1}')" 2>/dev/null | grep "Type (Bundle)" | awk '{print $NF}' || true)" + +if [ "$FS_TYPE" != "apfs" ]; then + echo "warning: filesystem does not appear to be APFS (got: ${FS_TYPE:-unknown}), clone tests may skip" >&2 +fi + +TEST_CASES="TryCloneFile,CopyFile.Clone,SupportsBlockRefCounting,CloneQueryInterface" + +echo "Running block-clone tests ..." +echo "---" +"$TEST_BINARY" \ + --test-suite="core.filesystem" \ + --test-case="$TEST_CASES" +echo "---" +echo "All block-clone tests passed." diff --git a/scripts/test_scripts/block-clone-test-windows.ps1 b/scripts/test_scripts/block-clone-test-windows.ps1 new file mode 100644 index 000000000..df24831a4 --- /dev/null +++ b/scripts/test_scripts/block-clone-test-windows.ps1 @@ -0,0 +1,145 @@ +# Test block-clone functionality on a temporary ReFS VHD. +# +# Requires: +# - Administrator privileges +# - Windows Server, or Windows 10/11 Pro for Workstations (ReFS support) +# - Hyper-V PowerShell module (for New-VHD), or diskpart fallback +# +# Usage: +# # From an elevated PowerShell prompt: +# .\scripts\test_scripts\block-clone-test-windows.ps1 [-TestBinary <path>] +# +# If -TestBinary is not given, defaults to build\windows\x64\debug\zencore-test.exe +# relative to the repository root. + +param( + [string]$TestBinary = "" +) + +$ErrorActionPreference = "Stop" + +$ScriptDir = Split-Path -Parent $MyInvocation.MyCommand.Definition +$RepoRoot = (Resolve-Path "$ScriptDir\..\..").Path + +if (-not $TestBinary) { + $TestBinary = Join-Path $RepoRoot "build\windows\x64\debug\zencore-test.exe" +} + +$ImageSizeMB = 2048 +$TestCases = "TryCloneFile,CopyFile.Clone,SupportsBlockRefCounting,CloneQueryInterface" + +$VhdPath = "" +$MountLetter = "" + +function Cleanup { + $ErrorActionPreference = "SilentlyContinue" + + if ($MountLetter) { + Write-Host "Dismounting VHD ..." + Dismount-VHD -Path $VhdPath -ErrorAction SilentlyContinue + } + if ($VhdPath -and (Test-Path $VhdPath)) { + Remove-Item -Force $VhdPath -ErrorAction SilentlyContinue + } +} + +trap { + Cleanup + throw $_ +} + +# --- Preflight checks --- + +$IsAdmin = ([Security.Principal.WindowsPrincipal] [Security.Principal.WindowsIdentity]::GetCurrent()).IsInRole( + [Security.Principal.WindowsBuiltInRole]::Administrator) +if (-not $IsAdmin) { + Write-Error "This script must be run as Administrator (for VHD mount/format)." + exit 1 +} + +if (-not (Test-Path $TestBinary)) { + Write-Error "Test binary not found: $TestBinary`nHint: build with 'xmake config -m debug && xmake build zencore-test'" + exit 1 +} + +# Check that ReFS formatting is available +$RefsAvailable = $true +try { + # A quick check: on non-Server/Workstation SKUs, Format-Volume -FileSystem ReFS will fail + $OsCaption = (Get-CimInstance Win32_OperatingSystem).Caption + if ($OsCaption -notmatch "Server|Workstation|Enterprise") { + Write-Warning "ReFS may not be available on this Windows edition: $OsCaption" + Write-Warning "Continuing anyway — format step will fail if unsupported." + } +} catch { + # Non-fatal, just proceed +} + +# --- Create and mount ReFS VHD --- + +$VhdPath = Join-Path $env:TEMP "refs-clone-test-$([guid]::NewGuid().ToString('N').Substring(0,8)).vhdx" + +Write-Host "Creating ${ImageSizeMB}MB VHDX at $VhdPath ..." + +try { + # Prefer Hyper-V cmdlet if available + New-VHD -Path $VhdPath -SizeBytes ($ImageSizeMB * 1MB) -Fixed | Out-Null +} catch { + # Fallback to diskpart + Write-Host "New-VHD not available, falling back to diskpart ..." + $DiskpartScript = @" +create vdisk file="$VhdPath" maximum=$ImageSizeMB type=fixed +"@ + $DiskpartScript | diskpart | Out-Null +} + +Write-Host "Mounting and initializing VHD ..." + +Mount-VHD -Path $VhdPath +$Disk = Get-VHD -Path $VhdPath | Get-Disk + +# Suppress Explorer's auto-open / "format disk?" prompts for the raw partition +Stop-Service ShellHWDetection -ErrorAction SilentlyContinue + +try { + Initialize-Disk -Number $Disk.Number -PartitionStyle GPT -ErrorAction SilentlyContinue + $Partition = New-Partition -DiskNumber $Disk.Number -UseMaximumSize -AssignDriveLetter + $MountLetter = $Partition.DriveLetter + + Write-Host "Formatting ${MountLetter}: as ReFS with integrity disabled ..." + Format-Volume -DriveLetter $MountLetter -FileSystem ReFS -NewFileSystemLabel "CloneTest" -Confirm:$false | Out-Null + + # Disable integrity streams (required for block cloning to work on ReFS) + Set-FileIntegrity "${MountLetter}:\" -Enable $false -ErrorAction SilentlyContinue +} finally { + Start-Service ShellHWDetection -ErrorAction SilentlyContinue +} + +$MountRoot = "${MountLetter}:\" + +# --- Copy test binary and run --- + +Write-Host "Copying test binary to ReFS volume ..." +Copy-Item $TestBinary "$MountRoot\zencore-test.exe" + +Write-Host "Running block-clone tests ..." +Write-Host "---" + +$proc = Start-Process -FilePath "$MountRoot\zencore-test.exe" ` + -ArgumentList "--test-suite=core.filesystem", "--test-case=$TestCases" ` + -NoNewWindow -Wait -PassThru + +Write-Host "---" + +if ($proc.ExitCode -ne 0) { + Write-Error "Tests failed with exit code $($proc.ExitCode)" + Cleanup + exit $proc.ExitCode +} + +Write-Host "ReFS: all block-clone tests passed." + +# --- Cleanup --- + +Cleanup +Write-Host "Done." diff --git a/scripts/test_scripts/block-clone-test.sh b/scripts/test_scripts/block-clone-test.sh new file mode 100755 index 000000000..7c6bf5605 --- /dev/null +++ b/scripts/test_scripts/block-clone-test.sh @@ -0,0 +1,143 @@ +#!/usr/bin/env bash +# Test block-clone functionality on temporary Btrfs and XFS loopback filesystems. +# +# Requires: root/sudo, btrfs-progs (mkfs.btrfs), xfsprogs (mkfs.xfs) +# +# Usage: +# sudo ./scripts/test_scripts/block-clone-test.sh [path-to-zencore-test] +# +# If no path is given, defaults to build/linux/x86_64/debug/zencore-test +# relative to the repository root. +# +# Options: +# --btrfs-only Only test Btrfs +# --xfs-only Only test XFS + +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +REPO_ROOT="$(cd "$SCRIPT_DIR/../.." && pwd)" + +TEST_BINARY="" +RUN_BTRFS=true +RUN_XFS=true + +for arg in "$@"; do + case "$arg" in + --btrfs-only) RUN_XFS=false ;; + --xfs-only) RUN_BTRFS=false ;; + *) TEST_BINARY="$arg" ;; + esac +done + +TEST_BINARY="${TEST_BINARY:-$REPO_ROOT/build/linux/x86_64/debug/zencore-test}" +IMAGE_SIZE="512M" +TEST_CASES="TryCloneFile,CopyFile.Clone,SupportsBlockRefCounting,CloneQueryInterface" + +# Track all temp files for cleanup +CLEANUP_MOUNTS=() +CLEANUP_DIRS=() +CLEANUP_FILES=() + +cleanup() { + local exit_code=$? + set +e + + for mnt in "${CLEANUP_MOUNTS[@]}"; do + if mountpoint -q "$mnt" 2>/dev/null; then + umount "$mnt" + fi + done + for dir in "${CLEANUP_DIRS[@]}"; do + [ -d "$dir" ] && rmdir "$dir" + done + for f in "${CLEANUP_FILES[@]}"; do + [ -f "$f" ] && rm -f "$f" + done + + if [ $exit_code -ne 0 ]; then + echo "FAILED (exit code $exit_code)" + fi + exit $exit_code +} +trap cleanup EXIT + +# --- Preflight checks --- + +if [ "$(id -u)" -ne 0 ]; then + echo "error: this script must be run as root (for mount/umount)" >&2 + exit 1 +fi + +if [ ! -x "$TEST_BINARY" ]; then + echo "error: test binary not found or not executable: $TEST_BINARY" >&2 + echo "hint: build with 'xmake config -m debug && xmake build zencore-test'" >&2 + exit 1 +fi + +if $RUN_BTRFS && ! command -v mkfs.btrfs &>/dev/null; then + echo "warning: mkfs.btrfs not found — install btrfs-progs to test Btrfs, skipping" >&2 + RUN_BTRFS=false +fi + +if $RUN_XFS && ! command -v mkfs.xfs &>/dev/null; then + echo "warning: mkfs.xfs not found — install xfsprogs to test XFS, skipping" >&2 + RUN_XFS=false +fi + +if ! $RUN_BTRFS && ! $RUN_XFS; then + echo "error: no filesystems to test" >&2 + exit 1 +fi + +# --- Helper to create, mount, and run tests on a loopback filesystem --- + +run_tests_on_fs() { + local fs_type="$1" + local mkfs_cmd="$2" + + echo "" + echo "========================================" + echo " Testing block-clone on $fs_type" + echo "========================================" + + local image_path mount_path + image_path="$(mktemp "/tmp/${fs_type}-clone-test-XXXXXX.img")" + mount_path="$(mktemp -d "/tmp/${fs_type}-clone-mount-XXXXXX")" + CLEANUP_FILES+=("$image_path") + CLEANUP_DIRS+=("$mount_path") + CLEANUP_MOUNTS+=("$mount_path") + + echo "Creating ${IMAGE_SIZE} ${fs_type} image at ${image_path} ..." + truncate -s "$IMAGE_SIZE" "$image_path" + $mkfs_cmd "$image_path" + + echo "Mounting at ${mount_path} ..." + mount -o loop "$image_path" "$mount_path" + chmod 777 "$mount_path" + + echo "Copying test binary ..." + cp "$TEST_BINARY" "$mount_path/zencore-test" + chmod +x "$mount_path/zencore-test" + + echo "Running tests ..." + echo "---" + "$mount_path/zencore-test" \ + --test-suite="core.filesystem" \ + --test-case="$TEST_CASES" + echo "---" + echo "$fs_type: all block-clone tests passed." +} + +# --- Run --- + +if $RUN_BTRFS; then + run_tests_on_fs "btrfs" "mkfs.btrfs -q" +fi + +if $RUN_XFS; then + run_tests_on_fs "xfs" "mkfs.xfs -q -m reflink=1" +fi + +echo "" +echo "All block-clone tests passed." diff --git a/scripts/test_scripts/builds-download-upload-test.py b/scripts/test_scripts/builds-download-upload-test.py index e4fee7cb8..8ff5245c1 100644 --- a/scripts/test_scripts/builds-download-upload-test.py +++ b/scripts/test_scripts/builds-download-upload-test.py @@ -4,6 +4,8 @@ from __future__ import annotations import argparse +import json +import os import platform import subprocess import sys @@ -15,22 +17,51 @@ _ARCH = "x64" if sys.platform == "win32" else platform.machine().lower() _EXE_SUFFIX = ".exe" if sys.platform == "win32" else "" +def _cache_dir() -> Path: + if sys.platform == "win32": + base = Path(os.environ.get("LOCALAPPDATA", Path.home() / "AppData" / "Local")) + return base / "Temp" / "zen" + elif sys.platform == "darwin": + return Path.home() / "Library" / "Caches" / "zen" + else: + base = Path(os.environ.get("XDG_CACHE_HOME", Path.home() / ".cache")) + return base / "zen" + + +_BUILD_IDS_PATH = _cache_dir() / "builds-download-upload-build-ids.json" + + class Build(NamedTuple): name: str bucket: str id: str -BUILDS = [ - Build("XB1Client", "fortnitegame.staged-build.fortnite-main.xb1-client", "09a7616c1a388dfe6056aa57"), - Build("WindowsClient", "fortnitegame.staged-build.fortnite-main.windows-client", "09a762c81e2cf213142d0ce5"), - Build("SwitchClient", "fortnitegame.staged-build.fortnite-main.switch-client", "09a75bf9c3ce75bce09f644f"), - Build("LinuxServer", "fortnitegame.staged-build.fortnite-main.linux-server", "09a750ac155eb3e3b62e87e0"), - Build("Switch2Client", "fortnitegame.staged-build.fortnite-main.switch2-client", "09a78f3df07b289691ec5710"), - Build("PS4Client", "fortnitegame.staged-build.fortnite-main.ps4-client", "09a76ea92ad301d4724fafad"), - Build("IOSClient", "fortnitegame.staged-build.fortnite-main.ios-client", "09a7816fa26c23362fef0c5d"), - Build("AndroidClient", "fortnitegame.staged-build.fortnite-main.android-client", "09a76725f1620d62c6be06e4"), -] +def load_builds() -> tuple[str, list[Build]]: + if not _BUILD_IDS_PATH.exists(): + print(f"Build IDs file not found: {_BUILD_IDS_PATH}") + answer = input("Run builds-download-upload-update-build-ids.py now to populate it? [y/N] ").strip().lower() + if answer == "y": + update_script = Path(__file__).parent / "builds-download-upload-update-build-ids.py" + subprocess.run([sys.executable, str(update_script)], check=True) + else: + sys.exit("Aborted. Run scripts/test_scripts/builds-download-upload-update-build-ids.py to populate it.") + with _BUILD_IDS_PATH.open() as f: + data: dict = json.load(f) + namespace = data.get("namespace", "") + if not namespace: + sys.exit(f"error: {_BUILD_IDS_PATH} is missing 'namespace'") + builds = [] + for name, entry in data.get("builds", {}).items(): + bucket = entry.get("bucket", "") + build_id = entry.get("buildId", "") + if not bucket or not build_id: + sys.exit(f"error: entry '{name}' in {_BUILD_IDS_PATH} is missing 'bucket' or 'buildId'") + builds.append(Build(name, bucket, build_id)) + if not builds: + sys.exit(f"error: {_BUILD_IDS_PATH} contains no builds") + return namespace, builds + ZEN_EXE: Path = Path(f"./build/{_PLATFORM}/{_ARCH}/release/zen{_EXE_SUFFIX}") ZEN_METADATA_DIR: Path = Path(__file__).resolve().parent / "metadatas" @@ -99,12 +130,12 @@ def wipe_or_create(label: str, path: Path, extra_zen_args: list[str] | None = No print() -def check_prerequisites() -> None: +def check_prerequisites(builds: list[Build]) -> None: if not ZEN_EXE.is_file(): sys.exit(f"error: zen executable not found: {ZEN_EXE}") if not ZEN_METADATA_DIR.is_dir(): sys.exit(f"error: metadata directory not found: {ZEN_METADATA_DIR}") - for build in BUILDS: + for build in builds: metadata = ZEN_METADATA_DIR / f"{build.name}.json" if not metadata.is_file(): sys.exit(f"error: metadata file not found: {metadata}") @@ -145,10 +176,10 @@ def main() -> None: ) parser.add_argument( "--data-path", - default=Path(Path(__file__).stem + "_datadir"), + default=None, type=Path, metavar="PATH", - help=f"root path for all data directories (default: {Path(__file__).stem}_datadir)", + help="root path for all data directories", ) parser.add_argument( "--zen-exe-path", @@ -162,17 +193,24 @@ def main() -> None: data_path = args.positional_path if data_path is None: data_path = args.data_path + if data_path is None: + print("WARNING: This script may require up to 1TB of free disk space.") + raw = input("Enter root path for all data directories: ").strip() + if not raw: + sys.exit("error: data path is required") + data_path = Path(raw) ZEN_EXE = args.zen_exe_positional if ZEN_EXE is None: ZEN_EXE = args.zen_exe_path + namespace, builds = load_builds() zen_system_dir = data_path / "system" zen_download_dir = data_path / "Download" zen_cache_data_dir = data_path / "ZenBuildsCache" zen_upload_dir = data_path / "Upload" zen_chunk_cache_path = data_path / "ChunkCache" - check_prerequisites() + check_prerequisites(builds) start_server("cache zenserver", zen_cache_data_dir, ZEN_CACHE_PORT, extra_zen_args=extra_zen_args, extra_server_args=["--buildstore-enabled"]) @@ -180,12 +218,12 @@ def main() -> None: wipe_or_create("download folder", zen_download_dir, extra_zen_args) wipe_or_create("system folder", zen_system_dir, extra_zen_args) - for build in BUILDS: + for build in builds: print(f"--------- importing {build.name} build") run(zen_cmd( "builds", "download", "--host", "https://jupiter.devtools.epicgames.com", - "--namespace", "fortnite.oplog", + "--namespace", namespace, "--bucket", build.bucket, "--build-id", build.id, "--local-path", zen_download_dir / build.name, @@ -199,7 +237,7 @@ def main() -> None: wipe_or_create("upload folder", zen_upload_dir, extra_zen_args) - for build in BUILDS: + for build in builds: print(f"--------- exporting {build.name} build") run(zen_cmd( "builds", "upload", diff --git a/scripts/test_scripts/builds-download-upload-update-build-ids.py b/scripts/test_scripts/builds-download-upload-update-build-ids.py new file mode 100644 index 000000000..2a63aa44d --- /dev/null +++ b/scripts/test_scripts/builds-download-upload-update-build-ids.py @@ -0,0 +1,150 @@ +#!/usr/bin/env python3 +"""Update builds-download-upload-build-ids.json with build IDs at the highest common changelist across all buckets.""" + +from __future__ import annotations + +import argparse +import json +import os +import platform +import subprocess +import sys +import tempfile +from pathlib import Path + +_PLATFORM = "windows" if sys.platform == "win32" else "macosx" if sys.platform == "darwin" else "linux" +_ARCH = "x64" if sys.platform == "win32" else platform.machine().lower() +_EXE_SUFFIX = ".exe" if sys.platform == "win32" else "" +_DEFAULT_ZEN = Path(f"build/{_PLATFORM}/{_ARCH}/release/zen{_EXE_SUFFIX}") + + +def _cache_dir() -> Path: + if sys.platform == "win32": + base = Path(os.environ.get("LOCALAPPDATA", Path.home() / "AppData" / "Local")) + return base / "Temp" / "zen" + elif sys.platform == "darwin": + return Path.home() / "Library" / "Caches" / "zen" + else: + base = Path(os.environ.get("XDG_CACHE_HOME", Path.home() / ".cache")) + return base / "zen" + + +_OUTPUT_PATH = _cache_dir() / "builds-download-upload-build-ids.json" + +# Maps build name -> Jupiter bucket +_BUILDS: list[tuple[str, str]] = [ + ("XB1Client", "fortnitegame.staged-build.fortnite-main.xb1-client"), + ("WindowsClient", "fortnitegame.staged-build.fortnite-main.windows-client"), + ("SwitchClient", "fortnitegame.staged-build.fortnite-main.switch-client"), + ("LinuxServer", "fortnitegame.staged-build.fortnite-main.linux-server"), + ("Switch2Client", "fortnitegame.staged-build.fortnite-main.switch2-client"), + ("PS4Client", "fortnitegame.staged-build.fortnite-main.ps4-client"), + ("PS5Client", "fortnitegame.staged-build.fortnite-main.ps5-client"), + ("IOSClient", "fortnitegame.staged-build.fortnite-main.ios-client"), + ("AndroidClient", "fortnitegame.staged-build.fortnite-main.android-client"), +] + + +def list_builds_for_bucket(zen: str, host: str, namespace: str, bucket: str) -> list[dict]: + """Run zen builds list for a single bucket and return the results array.""" + with tempfile.NamedTemporaryFile(suffix=".json", delete=False) as tmp: + result_path = Path(tmp.name) + + cmd = [ + zen, "builds", "list", + "--namespace", namespace, + "--bucket", bucket, + "--host", host, + "--result-path", str(result_path), + ] + + try: + subprocess.run(cmd, check=True, capture_output=True) + except FileNotFoundError: + sys.exit(f"error: zen binary not found: {zen}") + except subprocess.CalledProcessError as e: + sys.exit( + f"error: zen builds list failed for bucket '{bucket}' with exit code {e.returncode}\n" + f"stderr: {e.stderr.decode(errors='replace')}" + ) + + with result_path.open() as f: + data = json.load(f) + result_path.unlink(missing_ok=True) + + return data.get("results", []) + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Refresh builds-download-upload-build-ids.json with build IDs at the highest changelist present in all buckets." + ) + parser.add_argument("--host", default="https://jupiter.devtools.epicgames.com", help="Jupiter host URL") + parser.add_argument("--zen", default=str(_DEFAULT_ZEN), help="Path to the zen binary") + parser.add_argument("--namespace", default="fortnite.oplog", help="Builds storage namespace") + args = parser.parse_args() + + # For each bucket, fetch results and build a changelist -> buildId map. + # bucket_cl_map[bucket] = { changelist_int: buildId_str, ... } + bucket_cl_map: dict[str, dict[int, str]] = {} + + for name, bucket in _BUILDS: + print(f"Querying {name} ({bucket}) ...") + results = list_builds_for_bucket(args.zen, args.host, args.namespace, bucket) + if not results: + sys.exit(f"error: no results for bucket '{bucket}' (build '{name}')") + + cl_map: dict[int, str] = {} + for entry in results: + build_id = entry.get("buildId", "") + metadata = entry.get("metadata") or {} + cl = metadata.get("commit") + if build_id and cl is not None: + # Keep first occurrence (most recent) per changelist + if cl not in cl_map: + cl_map[int(cl)] = build_id + + if not cl_map: + sys.exit( + f"error: bucket '{bucket}' (build '{name}') returned {len(results)} entries " + "but none had both buildId and changelist in metadata" + ) + + print(f" {len(cl_map)} distinct changelists, latest CL {max(cl_map)}") + bucket_cl_map[bucket] = cl_map + + # Find the highest changelist present in every bucket's result set. + common_cls = set(next(iter(bucket_cl_map.values())).keys()) + for bucket, cl_map in bucket_cl_map.items(): + common_cls &= set(cl_map.keys()) + + if not common_cls: + sys.exit( + "error: no changelist is present in all buckets.\n" + "Per-bucket CL ranges:\n" + + "\n".join( + f" {name} ({bucket}): {min(bucket_cl_map[bucket])} – {max(bucket_cl_map[bucket])}" + for name, bucket in _BUILDS + ) + ) + + best_cl = max(common_cls) + print(f"\nHighest common changelist: {best_cl}") + + build_ids: dict[str, dict[str, str]] = {} + for name, bucket in _BUILDS: + build_id = bucket_cl_map[bucket][best_cl] + build_ids[name] = {"bucket": bucket, "buildId": build_id} + print(f" {name}: {build_id}") + + output = {"namespace": args.namespace, "builds": build_ids} + _OUTPUT_PATH.parent.mkdir(parents=True, exist_ok=True) + with _OUTPUT_PATH.open("w") as f: + json.dump(output, f, indent=2) + f.write("\n") + + print(f"\nWrote {_OUTPUT_PATH}") + + +if __name__ == "__main__": + main() diff --git a/scripts/test_scripts/oplog-import-export-test.py b/scripts/test_scripts/oplog-import-export-test.py index b2a5ece6c..f913a7351 100644 --- a/scripts/test_scripts/oplog-import-export-test.py +++ b/scripts/test_scripts/oplog-import-export-test.py @@ -4,6 +4,8 @@ from __future__ import annotations import argparse +import json +import os import platform import subprocess import sys @@ -15,23 +17,51 @@ _ARCH = "x64" if sys.platform == "win32" else platform.machine().lower() _EXE_SUFFIX = ".exe" if sys.platform == "win32" else "" +def _cache_dir() -> Path: + if sys.platform == "win32": + base = Path(os.environ.get("LOCALAPPDATA", Path.home() / "AppData" / "Local")) + return base / "Temp" / "zen" + elif sys.platform == "darwin": + return Path.home() / "Library" / "Caches" / "zen" + else: + base = Path(os.environ.get("XDG_CACHE_HOME", Path.home() / ".cache")) + return base / "zen" + + +_BUILD_IDS_PATH = _cache_dir() / "oplog-import-export-build-ids.json" + + class Build(NamedTuple): name: str bucket: str id: str -BUILDS = [ - Build("XB1Client", "fortnitegame.oplog.fortnite-main.xb1client", "09a75f7f3b7517653dcdaaa4"), - Build("WindowsClient", "fortnitegame.oplog.fortnite-main.windowsclient", "09a75d977ef944ecfd0eddfd"), - Build("SwitchClient", "fortnitegame.oplog.fortnite-main.switchclient", "09a74d03b3598ec94cfd2644"), - Build("XSXClient", "fortnitegame.oplog.fortnite-main.xsxclient", "09a76c2bbd6cd78f4d40d9ea"), - Build("Switch2Client", "fortnitegame.oplog.fortnite-main.switch2client", "09a7686b3d9faa78fb24a38f"), - Build("PS4Client", "fortnitegame.oplog.fortnite-main.ps4client", "09a75b72d1c260ed26020140"), - Build("LinuxServer", "fortnitegame.oplog.fortnite-main.linuxserver", "09a747f5e0ee83a04be013e6"), - Build("IOSClient", "fortnitegame.oplog.fortnite-main.iosclient", "09a75f677e883325a209148c"), - Build("Android_ASTCClient", "fortnitegame.oplog.fortnite-main.android_astcclient", "09a7422c08c6f37becc7d37f"), -] +def load_builds() -> tuple[str, list[Build]]: + if not _BUILD_IDS_PATH.exists(): + print(f"Build IDs file not found: {_BUILD_IDS_PATH}") + answer = input("Run oplog-update-build-ids.py now to populate it? [y/N] ").strip().lower() + if answer == "y": + update_script = Path(__file__).parent / "oplog-update-build-ids.py" + subprocess.run([sys.executable, str(update_script)], check=True) + else: + sys.exit("Aborted. Run scripts/test_scripts/oplog-update-build-ids.py to populate it.") + with _BUILD_IDS_PATH.open() as f: + data: dict = json.load(f) + namespace = data.get("namespace", "") + if not namespace: + sys.exit(f"error: {_BUILD_IDS_PATH} is missing 'namespace'") + builds = [] + for name, entry in data.get("builds", {}).items(): + bucket = entry.get("bucket", "") + build_id = entry.get("buildId", "") + if not bucket or not build_id: + sys.exit(f"error: entry '{name}' in {_BUILD_IDS_PATH} is missing 'bucket' or 'buildId'") + builds.append(Build(name, bucket, build_id)) + if not builds: + sys.exit(f"error: {_BUILD_IDS_PATH} contains no builds") + return namespace, builds + ZEN_EXE: Path = Path(f"./build/{_PLATFORM}/{_ARCH}/release/zen{_EXE_SUFFIX}") @@ -50,6 +80,11 @@ SERVER_ARGS: tuple[str, ...] = ( ) +def zen_cmd(*args: str | Path, extra_zen_args: list[str] | None = None) -> list[str | Path]: + """Build a zen CLI command list, inserting extra_zen_args before subcommands.""" + return [ZEN_EXE, *(extra_zen_args or []), *args] + + def run(cmd: list[str | Path]) -> None: try: subprocess.run(cmd, check=True) @@ -59,31 +94,33 @@ def run(cmd: list[str | Path]) -> None: sys.exit(f"error: command failed with exit code {e.returncode}:\n {' '.join(str(x) for x in e.cmd)}") -def stop_server(label: str, port: int) -> None: +def stop_server(label: str, port: int, extra_zen_args: list[str] | None = None) -> None: """Stop a zen server. Tolerates failures so it is safe to call from finally blocks.""" print(f"--------- stopping {label}") try: - subprocess.run([ZEN_EXE, "down", "--port", str(port)]) + subprocess.run(zen_cmd("down", "--port", str(port), extra_zen_args=extra_zen_args)) except OSError as e: print(f"warning: could not stop {label}: {e}", file=sys.stderr) print() -def start_server(label: str, data_dir: Path, port: int, extra_args: list[str] | None = None) -> None: +def start_server(label: str, data_dir: Path, port: int, extra_zen_args: list[str] | None = None, + extra_server_args: list[str] | None = None) -> None: print(f"--------- starting {label} {data_dir}") - run([ - ZEN_EXE, "up", "--port", str(port), "--show-console", "--", + run(zen_cmd( + "up", "--port", str(port), "--show-console", "--", f"--data-dir={data_dir}", *SERVER_ARGS, - *(extra_args or []), - ]) + *(extra_server_args or []), + extra_zen_args=extra_zen_args, + )) print() -def wipe_or_create(label: str, path: Path) -> None: +def wipe_or_create(label: str, path: Path, extra_zen_args: list[str] | None = None) -> None: if path.exists(): print(f"--------- cleaning {label} {path}") - run([ZEN_EXE, "wipe", "-y", path]) + run(zen_cmd("wipe", "-y", path, extra_zen_args=extra_zen_args)) else: print(f"--------- creating {label} {path}") path.mkdir(parents=True, exist_ok=True) @@ -95,24 +132,39 @@ def check_prerequisites() -> None: sys.exit(f"error: zen executable not found: {ZEN_EXE}") -def setup_project(port: int) -> None: +def setup_project(port: int, extra_zen_args: list[str] | None = None) -> None: """Create the FortniteGame project on the server at the given port.""" print("--------- creating FortniteGame project") - run([ZEN_EXE, "project-create", f"--hosturl=127.0.0.1:{port}", "FortniteGame", "--force-update"]) + run(zen_cmd("project-create", f"--hosturl=127.0.0.1:{port}", "FortniteGame", "--force-update", + extra_zen_args=extra_zen_args)) print() -def setup_oplog(port: int, build_name: str) -> None: +def setup_oplog(port: int, build_name: str, extra_zen_args: list[str] | None = None) -> None: """Create the oplog in the FortniteGame project on the server at the given port.""" print(f"--------- creating {build_name} oplog") - run([ZEN_EXE, "oplog-create", f"--hosturl=127.0.0.1:{port}", "FortniteGame", build_name, "--force-update"]) + run(zen_cmd("oplog-create", f"--hosturl=127.0.0.1:{port}", "FortniteGame", build_name, "--force-update", + extra_zen_args=extra_zen_args)) print() def main() -> None: global ZEN_EXE - parser = argparse.ArgumentParser(description=__doc__) + # Split on '--' to separate script args from extra zen CLI args + script_argv: list[str] = [] + extra_zen_args: list[str] = [] + if "--" in sys.argv[1:]: + sep = sys.argv.index("--", 1) + script_argv = sys.argv[1:sep] + extra_zen_args = sys.argv[sep + 1:] + else: + script_argv = sys.argv[1:] + + parser = argparse.ArgumentParser( + description=__doc__, + epilog="Any arguments after '--' are forwarded to every zen CLI invocation.", + ) parser.add_argument( "positional_path", nargs="?", @@ -131,10 +183,10 @@ def main() -> None: ) parser.add_argument( "--data-path", - default=Path(Path(__file__).stem + "_datadir"), + default=None, type=Path, metavar="PATH", - help=f"root path for all data directories (default: {Path(__file__).stem}_datadir)", + help="root path for all data directories", ) parser.add_argument( "--zen-exe-path", @@ -143,15 +195,22 @@ def main() -> None: metavar="PATH", help=f"path to zen executable (default: {ZEN_EXE})", ) - args = parser.parse_args() + args = parser.parse_args(script_argv) data_path = args.positional_path if data_path is None: data_path = args.data_path + if data_path is None: + print("WARNING: This script may require up to 1TB of free disk space.") + raw = input("Enter root path for all data directories: ").strip() + if not raw: + sys.exit("error: data path is required") + data_path = Path(raw) ZEN_EXE = args.zen_exe_positional if ZEN_EXE is None: ZEN_EXE = args.zen_exe_path + namespace, builds = load_builds() zen_data_dir = data_path / "DDC" / "OplogsZen" zen_cache_data_dir = data_path / "DDC" / "ZenBuildsCache" zen_import_data_dir = data_path / "DDC" / "OplogsZenImport" @@ -159,75 +218,81 @@ def main() -> None: check_prerequisites() - start_server("cache zenserver", zen_cache_data_dir, ZEN_CACHE_PORT, ["--buildstore-enabled"]) + start_server("cache zenserver", zen_cache_data_dir, ZEN_CACHE_PORT, + extra_zen_args=extra_zen_args, extra_server_args=["--buildstore-enabled"]) try: - wipe_or_create("zenserver data", zen_data_dir) - start_server("zenserver", zen_data_dir, ZEN_PORT) + wipe_or_create("zenserver data", zen_data_dir, extra_zen_args) + start_server("zenserver", zen_data_dir, ZEN_PORT, extra_zen_args=extra_zen_args) try: - setup_project(ZEN_PORT) + setup_project(ZEN_PORT, extra_zen_args) - for build in BUILDS: - setup_oplog(ZEN_PORT, build.name) + for build in builds: + setup_oplog(ZEN_PORT, build.name, extra_zen_args) print(f"--------- importing {build.name} oplog") - run([ - ZEN_EXE, "oplog-import", + run(zen_cmd( + "oplog-import", f"--hosturl=127.0.0.1:{ZEN_PORT}", "FortniteGame", build.name, "--clean", "--builds", "https://jupiter.devtools.epicgames.com", - "--namespace", "fortnite.oplog", + "--namespace", namespace, "--bucket", build.bucket, "--builds-id", build.id, f"--zen-cache-host={ZEN_CACHE}", f"--zen-cache-upload={ZEN_CACHE_POPULATE}", f"--allow-partial-block-requests={ZEN_PARTIAL_REQUEST_MODE}", - ]) + extra_zen_args=extra_zen_args, + )) print() print(f"--------- validating {build.name} oplog") - run([ZEN_EXE, "oplog-validate", f"--hosturl=127.0.0.1:{ZEN_PORT}", "FortniteGame", build.name]) + run(zen_cmd("oplog-validate", f"--hosturl=127.0.0.1:{ZEN_PORT}", "FortniteGame", build.name, + extra_zen_args=extra_zen_args)) print() - wipe_or_create("export folder", export_dir) + wipe_or_create("export folder", export_dir, extra_zen_args) - for build in BUILDS: + for build in builds: print(f"--------- exporting {build.name} oplog") - run([ - ZEN_EXE, "oplog-export", + run(zen_cmd( + "oplog-export", f"--hosturl=127.0.0.1:{ZEN_PORT}", "FortniteGame", build.name, "--file", export_dir, "--forcetempblocks", - ]) + extra_zen_args=extra_zen_args, + )) print() finally: - stop_server("zenserver", ZEN_PORT) + stop_server("zenserver", ZEN_PORT, extra_zen_args) - wipe_or_create("alternate zenserver data", zen_import_data_dir) - start_server("import zenserver", zen_import_data_dir, ZEN_PORT) + wipe_or_create("alternate zenserver data", zen_import_data_dir, extra_zen_args) + start_server("import zenserver", zen_import_data_dir, ZEN_PORT, extra_zen_args=extra_zen_args) try: - setup_project(ZEN_PORT) + setup_project(ZEN_PORT, extra_zen_args) - for build in BUILDS: - setup_oplog(ZEN_PORT, build.name) + for build in builds: + setup_oplog(ZEN_PORT, build.name, extra_zen_args) print(f"--------- importing {build.name} oplog") - run([ - ZEN_EXE, "oplog-import", + run(zen_cmd( + "oplog-import", f"--hosturl=127.0.0.1:{ZEN_PORT}", "FortniteGame", build.name, "--file", export_dir, - ]) + extra_zen_args=extra_zen_args, + )) print() print(f"--------- validating {build.name} oplog") - run([ZEN_EXE, "oplog-validate", f"--hosturl=127.0.0.1:{ZEN_PORT}", "FortniteGame", build.name]) + run(zen_cmd("oplog-validate", f"--hosturl=127.0.0.1:{ZEN_PORT}", "FortniteGame", build.name, + extra_zen_args=extra_zen_args)) print() finally: - stop_server("alternative zenserver", ZEN_PORT) + stop_server("alternative zenserver", ZEN_PORT, extra_zen_args) finally: - stop_server("cache zenserver", ZEN_CACHE_PORT) + stop_server("cache zenserver", ZEN_CACHE_PORT, extra_zen_args) if __name__ == "__main__": diff --git a/scripts/test_scripts/oplog-update-build-ids.py b/scripts/test_scripts/oplog-update-build-ids.py new file mode 100644 index 000000000..67e128c8e --- /dev/null +++ b/scripts/test_scripts/oplog-update-build-ids.py @@ -0,0 +1,151 @@ +#!/usr/bin/env python3 +"""Update oplog-import-export-build-ids.json with build IDs at the highest common changelist across all buckets.""" + +from __future__ import annotations + +import argparse +import json +import os +import platform +import subprocess +import sys +import tempfile +from pathlib import Path + +_PLATFORM = "windows" if sys.platform == "win32" else "macosx" if sys.platform == "darwin" else "linux" +_ARCH = "x64" if sys.platform == "win32" else platform.machine().lower() +_EXE_SUFFIX = ".exe" if sys.platform == "win32" else "" +_DEFAULT_ZEN = Path(f"build/{_PLATFORM}/{_ARCH}/release/zen{_EXE_SUFFIX}") + + +def _cache_dir() -> Path: + if sys.platform == "win32": + base = Path(os.environ.get("LOCALAPPDATA", Path.home() / "AppData" / "Local")) + return base / "Temp" / "zen" + elif sys.platform == "darwin": + return Path.home() / "Library" / "Caches" / "zen" + else: + base = Path(os.environ.get("XDG_CACHE_HOME", Path.home() / ".cache")) + return base / "zen" + + +_OUTPUT_PATH = _cache_dir() / "oplog-import-export-build-ids.json" + +# Maps build name -> Jupiter bucket +_BUILDS: list[tuple[str, str]] = [ + ("XB1Client", "fortnitegame.oplog.fortnite-main.xb1client"), + ("WindowsClient", "fortnitegame.oplog.fortnite-main.windowsclient"), + ("SwitchClient", "fortnitegame.oplog.fortnite-main.switchclient"), + ("XSXClient", "fortnitegame.oplog.fortnite-main.xsxclient"), + ("Switch2Client", "fortnitegame.oplog.fortnite-main.switch2client"), + ("PS4Client", "fortnitegame.oplog.fortnite-main.ps4client"), + ("PS5Client", "fortnitegame.oplog.fortnite-main.ps5client"), + ("LinuxServer", "fortnitegame.oplog.fortnite-main.linuxserver"), + ("IOSClient", "fortnitegame.oplog.fortnite-main.iosclient"), + ("Android_ASTCClient", "fortnitegame.oplog.fortnite-main.android_astcclient"), +] + + +def list_builds_for_bucket(zen: str, host: str, namespace: str, bucket: str) -> list[dict]: + """Run zen builds list for a single bucket and return the results array.""" + with tempfile.NamedTemporaryFile(suffix=".json", delete=False) as tmp: + result_path = Path(tmp.name) + + cmd = [ + zen, "builds", "list", + "--namespace", namespace, + "--bucket", bucket, + "--host", host, + "--result-path", str(result_path), + ] + + try: + subprocess.run(cmd, check=True, capture_output=True) + except FileNotFoundError: + sys.exit(f"error: zen binary not found: {zen}") + except subprocess.CalledProcessError as e: + sys.exit( + f"error: zen builds list failed for bucket '{bucket}' with exit code {e.returncode}\n" + f"stderr: {e.stderr.decode(errors='replace')}" + ) + + with result_path.open() as f: + data = json.load(f) + result_path.unlink(missing_ok=True) + + return data.get("results", []) + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Refresh oplog-import-export-build-ids.json with build IDs at the highest changelist present in all buckets." + ) + parser.add_argument("--host", default="https://jupiter.devtools.epicgames.com", help="Jupiter host URL") + parser.add_argument("--zen", default=str(_DEFAULT_ZEN), help="Path to the zen binary") + parser.add_argument("--namespace", default="fortnite.oplog", help="Builds storage namespace") + args = parser.parse_args() + + # For each bucket, fetch results and build a changelist -> buildId map. + # bucket_cl_map[bucket] = { changelist_int: buildId_str, ... } + bucket_cl_map: dict[str, dict[int, str]] = {} + + for name, bucket in _BUILDS: + print(f"Querying {name} ({bucket}) ...") + results = list_builds_for_bucket(args.zen, args.host, args.namespace, bucket) + if not results: + sys.exit(f"error: no results for bucket '{bucket}' (build '{name}')") + + cl_map: dict[int, str] = {} + for entry in results: + build_id = entry.get("buildId", "") + metadata = entry.get("metadata") or {} + cl = metadata.get("changelist") + if build_id and cl is not None: + # Keep first occurrence (most recent) per changelist + if cl not in cl_map: + cl_map[int(cl)] = build_id + + if not cl_map: + sys.exit( + f"error: bucket '{bucket}' (build '{name}') returned {len(results)} entries " + "but none had both buildId and changelist in metadata" + ) + + print(f" {len(cl_map)} distinct changelists, latest CL {max(cl_map)}") + bucket_cl_map[bucket] = cl_map + + # Find the highest changelist present in every bucket's result set. + common_cls = set(next(iter(bucket_cl_map.values())).keys()) + for bucket, cl_map in bucket_cl_map.items(): + common_cls &= set(cl_map.keys()) + + if not common_cls: + sys.exit( + "error: no changelist is present in all buckets.\n" + "Per-bucket CL ranges:\n" + + "\n".join( + f" {name} ({bucket}): {min(bucket_cl_map[bucket])} – {max(bucket_cl_map[bucket])}" + for name, bucket in _BUILDS + ) + ) + + best_cl = max(common_cls) + print(f"\nHighest common changelist: {best_cl}") + + build_ids: dict[str, dict[str, str]] = {} + for name, bucket in _BUILDS: + build_id = bucket_cl_map[bucket][best_cl] + build_ids[name] = {"bucket": bucket, "buildId": build_id} + print(f" {name}: {build_id}") + + output = {"namespace": args.namespace, "builds": build_ids} + _OUTPUT_PATH.parent.mkdir(parents=True, exist_ok=True) + with _OUTPUT_PATH.open("w") as f: + json.dump(output, f, indent=2) + f.write("\n") + + print(f"\nWrote {_OUTPUT_PATH}") + + +if __name__ == "__main__": + main() diff --git a/scripts/ue_build_linux/verify_linux_toolchains.sh b/scripts/ue_build_linux/verify_linux_toolchains.sh new file mode 100755 index 000000000..10fad8a82 --- /dev/null +++ b/scripts/ue_build_linux/verify_linux_toolchains.sh @@ -0,0 +1,121 @@ +#!/usr/bin/env bash +# +# Verify that the project builds on Linux with gcc, ue-clang, clang-19 and clang-20. +# Each toolchain gets a clean slate (build dirs + xmake caches wiped). +# +# Usage: +# ./scripts/verify_linux_toolchains.sh # build all four +# ./scripts/verify_linux_toolchains.sh gcc clang-19 # build only specific ones +# ./scripts/verify_linux_toolchains.sh --clean # also wipe ~/.xmake package cache +# +# Installing toolchains (Ubuntu 24.04): +# - gcc: sudo apt install build-essential +# - ue-clang: use scripts/ue_build_linux/get_ue_toolchain.sh +# - clang-19: sudo apt install clang-19 +# - clang-20: sudo apt install clang-20 + +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" +PROJECT_DIR="$(cd "$SCRIPT_DIR/../.." && pwd)" + +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +NC='\033[0m' + +PASSED=() +FAILED=() +declare -A TIMINGS +CLEAN_XMAKE_HOME=false + +clean_build_state() { + echo -e "${YELLOW}Cleaning build state...${NC}" + rm -rf "$PROJECT_DIR/.xmake" "$PROJECT_DIR/build" + if [ "$CLEAN_XMAKE_HOME" = true ]; then + rm -rf ~/.xmake + fi +} + +build_toolchain() { + local NAME="$1" + shift + local CONFIG_ARGS=("$@") + + echo "" + echo "============================================================" + echo -e "${YELLOW}Building with: ${NAME}${NC}" + echo " xmake config args: ${CONFIG_ARGS[*]}" + echo "============================================================" + + clean_build_state + + local START_TIME=$SECONDS + + if ! (cd "$PROJECT_DIR" && xmake config -y -m debug "${CONFIG_ARGS[@]}"); then + TIMINGS[$NAME]=$(( SECONDS - START_TIME )) + echo -e "${RED}FAILED: ${NAME} (config, ${TIMINGS[$NAME]}s)${NC}" + FAILED+=("$NAME") + return 1 + fi + + if ! (cd "$PROJECT_DIR" && xmake -y -j"$(nproc)"); then + TIMINGS[$NAME]=$(( SECONDS - START_TIME )) + echo -e "${RED}FAILED: ${NAME} (build, ${TIMINGS[$NAME]}s)${NC}" + FAILED+=("$NAME") + return 1 + fi + + TIMINGS[$NAME]=$(( SECONDS - START_TIME )) + echo -e "${GREEN}PASSED: ${NAME} (${TIMINGS[$NAME]}s)${NC}" + PASSED+=("$NAME") +} + +# Available toolchain configurations +declare -A TOOLCHAINS +TOOLCHAINS[gcc]="--toolchain=gcc" +TOOLCHAINS[ue-clang]="--toolchain=ue-clang" +TOOLCHAINS[clang-19]="--toolchain=clang-19" +TOOLCHAINS[clang-20]="--toolchain=clang-20" + +# Parse arguments +SELECTED=() +for ARG in "$@"; do + if [ "$ARG" = "--clean" ]; then + CLEAN_XMAKE_HOME=true + else + SELECTED+=("$ARG") + fi +done + +if [ ${#SELECTED[@]} -eq 0 ]; then + SELECTED=(gcc ue-clang clang-19 clang-20) +fi + +TOTAL_START=$SECONDS + +for TC in "${SELECTED[@]}"; do + if [ -z "${TOOLCHAINS[$TC]+x}" ]; then + echo -e "${RED}Unknown toolchain: ${TC}${NC}" + echo "Available: ${!TOOLCHAINS[*]}" + exit 1 + fi + + # shellcheck disable=SC2086 + build_toolchain "$TC" ${TOOLCHAINS[$TC]} || true +done + +TOTAL_ELAPSED=$(( SECONDS - TOTAL_START )) + +echo "" +echo "============================================================" +echo "Results (${TOTAL_ELAPSED}s total):" +echo "============================================================" +for TC in "${PASSED[@]}"; do + echo -e " ${GREEN}PASS${NC} ${TC} (${TIMINGS[$TC]}s)" +done +for TC in "${FAILED[@]}"; do + echo -e " ${RED}FAIL${NC} ${TC} (${TIMINGS[$TC]}s)" +done + +[ ${#FAILED[@]} -eq 0 ] diff --git a/scripts/updatefrontend.lua b/scripts/updatefrontend.lua deleted file mode 100644 index ab37819d7..000000000 --- a/scripts/updatefrontend.lua +++ /dev/null @@ -1,111 +0,0 @@ --- Copyright Epic Games, Inc. All Rights Reserved. - --------------------------------------------------------------------------------- -local function _exec(cmd, ...) - local args = {} - for _, arg in pairs({...}) do - if arg then - table.insert(args, arg) - end - end - - print("--", cmd, table.unpack(args)) - local ret = os.execv(cmd, args) - print() - return ret -end - --------------------------------------------------------------------------------- -local function _zip(store_only, zip_path, ...) - -- Here's the rules; if len(...) is 1 and it is a dir then create a zip with - -- archive paths like this; - -- - -- glob(foo/bar/**) -> foo/bar/abc, foo/bar/dir/123 -> zip(abc, dir/123) - -- - -- Otherwise assume ... is file paths and add without leading directories; - -- - -- foo/abc, bar/123 -> zip(abc, 123) - - zip_path = path.absolute(zip_path) - os.tryrm(zip_path) - - local inputs = {...} - - local source_dir = nil - if #inputs == 1 and os.isdir(inputs[1]) then - source_dir = inputs[1] - end - - import("detect.tools.find_7z") - local cmd_7z = find_7z() - if cmd_7z then - input_paths = {} - if source_dir then - -- Suffixing a directory path with a "/." will have 7z set the path - -- for archived files relative to that directory. - input_paths = { path.join(source_dir, ".") } - else - for _, input_path in pairs(inputs) do - -- If there is a "/./" anywhere in file paths then 7z drops all - -- directory information and just archives the file by name - input_path = path.relative(input_path, ".") - if input_path:sub(2,2) ~= ":" then - input_path = "./"..input_path - end - table.insert(input_paths, input_path) - end - end - - compression_level = "-mx1" - if store_only then - compression_level = "-mx0" - end - - local ret = _exec(cmd_7z, "a", compression_level, zip_path, table.unpack(input_paths)) - if ret > 0 then - raise("Received error from 7z") - end - return - end - - print("7z not found, falling back to zip") - - import("detect.tools.find_zip") - zip_cmd = find_zip() - if zip_cmd then - local input_paths = inputs - local cwd = os.curdir() - if source_dir then - os.cd(source_dir) - input_paths = { "." } - end - - compression_level = "-1" - if store_only then - compression_level = "-0" - end - - local strip_leading_path = nil - if not source_dir then - strip_leading_path = "--junk-paths" - end - - local ret = _exec(zip_cmd, "-r", compression_level, strip_leading_path, zip_path, table.unpack(input_paths)) - if ret > 0 then - raise("Received error from zip") - end - - os.cd(cwd) - return - end - print("zip not found") - - raise("Unable to find a suitable zip tool") -end - --------------------------------------------------------------------------------- -function main() - local zip_path = "src/zenserver/frontend/html.zip" - local content_dir = "src/zenserver/frontend/html/" - _zip(true, zip_path, content_dir) -end diff --git a/scripts/win_cross/get_win_sdk.sh b/scripts/win_cross/get_win_sdk.sh new file mode 100755 index 000000000..b22d1bf3a --- /dev/null +++ b/scripts/win_cross/get_win_sdk.sh @@ -0,0 +1,305 @@ +#!/bin/bash +# +# Downloads xwin and uses it to fetch the Windows SDK and MSVC CRT headers/libs +# needed for cross-compiling Windows binaries from Linux using clang-cl. +# +# Usage: +# ./get_win_sdk.sh [output_dir] +# +# Output defaults to ~/.xwin-sdk (override via $XWIN_SDK_DIR or first argument). + +set -euo pipefail + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) + +die() { echo "ERROR: $1" >&2; exit 1; } + +sdk_dir="${1:-${XWIN_SDK_DIR:-${HOME}/.xwin-sdk}}" + +if [[ "${sdk_dir}" == "--help" ]]; then + echo "usage: $(basename "${BASH_SOURCE[0]}") [output_dir]" + echo "" + echo "Downloads the Windows SDK and MSVC CRT via xwin for cross-compilation." + echo "Default output: ~/.xwin-sdk (override via \$XWIN_SDK_DIR or first argument)" + exit 0 +fi + +# If the directory already has SDK content, skip download +if [ -d "${sdk_dir}/sdk/include/um" ] && [ -d "${sdk_dir}/crt/include" ]; then + echo "SDK already present at '${sdk_dir}', skipping download." + echo "Delete the directory to force re-download." + # Still create the compat layout in case it's missing (e.g. script was updated) + CREATE_COMPAT_ONLY=true +else + CREATE_COMPAT_ONLY=false +fi + +if [ -e "${sdk_dir}" ]; then + # Allow re-use of existing empty or partial directory + if [ -d "${sdk_dir}" ]; then + : + else + die "'${sdk_dir}' exists but is not a directory" + fi +fi + +mkdir -p "${sdk_dir}" + +# ------------------------------------------------------------------------- +# Detect LLVM installation +# ------------------------------------------------------------------------- +LLVM_BIN="${LLVM_BIN_DIR:-}" +if [ -z "${LLVM_BIN}" ]; then + # Try common locations + for candidate in /usr/lib/llvm-19/bin /usr/lib/llvm-18/bin /usr/lib/llvm-17/bin; do + if [ -x "${candidate}/clang" ]; then + LLVM_BIN="${candidate}" + break + fi + done +fi +if [ -z "${LLVM_BIN}" ]; then + # Fallback: try to find clang on PATH + CLANG_PATH=$(command -v clang 2>/dev/null || true) + if [ -n "${CLANG_PATH}" ]; then + LLVM_BIN=$(dirname "$(readlink -f "${CLANG_PATH}")") + fi +fi +if [ -z "${LLVM_BIN}" ]; then + die "Could not find LLVM/clang installation. Set LLVM_BIN_DIR to the bin directory." +fi +echo "Using LLVM at: ${LLVM_BIN}" + +# ------------------------------------------------------------------------- +# Download xwin binary and fetch SDK (skip if already present) +# ------------------------------------------------------------------------- +if [ "${CREATE_COMPAT_ONLY}" = false ]; then + XWIN_VERSION="0.6.5" + XWIN_ARCHIVE="xwin-${XWIN_VERSION}-x86_64-unknown-linux-musl.tar.gz" + XWIN_URL="https://github.com/Jake-Shadle/xwin/releases/download/${XWIN_VERSION}/${XWIN_ARCHIVE}" + + TMPDIR=$(mktemp -d) + trap 'rm -rf "${TMPDIR}"' EXIT + + echo "Downloading xwin ${XWIN_VERSION}..." + if command -v wget &>/dev/null; then + wget -q --show-progress -O "${TMPDIR}/${XWIN_ARCHIVE}" "${XWIN_URL}" + elif command -v curl &>/dev/null; then + curl -fSL --progress-bar -o "${TMPDIR}/${XWIN_ARCHIVE}" "${XWIN_URL}" + else + die "Neither wget nor curl found" + fi + + echo "Extracting xwin..." + tar -xzf "${TMPDIR}/${XWIN_ARCHIVE}" -C "${TMPDIR}" + + XWIN_BIN="${TMPDIR}/xwin-${XWIN_VERSION}-x86_64-unknown-linux-musl/xwin" + if [ ! -x "${XWIN_BIN}" ]; then + die "xwin binary not found after extraction" + fi + + echo "Fetching Windows SDK and CRT (this may take a few minutes)..." + "${XWIN_BIN}" --accept-license splat --output "${sdk_dir}" +fi + +# ------------------------------------------------------------------------- +# Create tool wrapper scripts in bin/ +# ------------------------------------------------------------------------- +BIN_DIR="${sdk_dir}/bin" +mkdir -p "${BIN_DIR}" + +# clang-cl wrapper (since the host may not have a clang-cl symlink) +cat > "${BIN_DIR}/clang-cl" << WRAPPER +#!/bin/bash +exec "${LLVM_BIN}/clang" --driver-mode=cl -D_ALLOW_COMPILER_AND_STL_VERSION_MISMATCH "\$@" +WRAPPER +chmod +x "${BIN_DIR}/clang-cl" + +# clang wrapper for GNU assembly (.S files) +cat > "${BIN_DIR}/clang" << WRAPPER +#!/bin/bash +exec "${LLVM_BIN}/clang" "\$@" +WRAPPER +chmod +x "${BIN_DIR}/clang" + +# ------------------------------------------------------------------------- +# Create MSVC-compatible directory layout for xmake package builds. +# +# xmake's built-in msvc toolchain on Linux uses find_build_tools() which +# expects the following structure: +# <sdk>/VC/Tools/MSVC/<version>/include → CRT headers +# <sdk>/VC/Tools/MSVC/<version>/lib/<arch> → CRT libs +# <sdk>/Windows Kits/10/Include/<ver>/{ucrt,um,shared} → SDK headers +# <sdk>/Windows Kits/10/Lib/<ver>/{ucrt,um}/<arch> → SDK libs +# <sdk>/bin/<arch>/ → tool wrappers +# +# We create this layout using symlinks back to the xwin flat layout. +# ------------------------------------------------------------------------- +echo "Creating MSVC-compatible directory layout..." + +FAKE_VC_VER="14.0.0" +FAKE_SDK_VER="10.0.0.0" + +# --- VC Tools (CRT) --- +VC_DIR="${sdk_dir}/VC/Tools/MSVC/${FAKE_VC_VER}" +mkdir -p "${VC_DIR}" +ln -sfn "${sdk_dir}/crt/include" "${VC_DIR}/include" +mkdir -p "${VC_DIR}/lib" +ln -sfn "${sdk_dir}/crt/lib/x86_64" "${VC_DIR}/lib/x64" + +# --- Windows Kits (SDK headers) --- +WINSDK_INC="${sdk_dir}/Windows Kits/10/Include/${FAKE_SDK_VER}" +mkdir -p "${WINSDK_INC}" +ln -sfn "${sdk_dir}/sdk/include/ucrt" "${WINSDK_INC}/ucrt" +ln -sfn "${sdk_dir}/sdk/include/um" "${WINSDK_INC}/um" +ln -sfn "${sdk_dir}/sdk/include/shared" "${WINSDK_INC}/shared" + +# --- Windows Kits (SDK libs) --- +WINSDK_LIB="${sdk_dir}/Windows Kits/10/Lib/${FAKE_SDK_VER}" +mkdir -p "${WINSDK_LIB}/ucrt" "${WINSDK_LIB}/um" +ln -sfn "${sdk_dir}/sdk/lib/ucrt/x86_64" "${WINSDK_LIB}/ucrt/x64" +ln -sfn "${sdk_dir}/sdk/lib/um/x86_64" "${WINSDK_LIB}/um/x64" + +# --- Tool wrappers in bin/<arch>/ (for msvc toolchain PATH setup) --- +ARCH_BIN="${sdk_dir}/bin/x64" +mkdir -p "${ARCH_BIN}" + +# cl → clang-cl wrapper +cat > "${ARCH_BIN}/cl" << WRAPPER +#!/bin/bash +exec "${LLVM_BIN}/clang" --driver-mode=cl -D_ALLOW_COMPILER_AND_STL_VERSION_MISMATCH "\$@" +WRAPPER +chmod +x "${ARCH_BIN}/cl" +cp "${ARCH_BIN}/cl" "${ARCH_BIN}/cl.exe" + +# link → lld-link (with /lib mode redirecting to llvm-lib for archiver use) +# xmake sets ar=link.exe for non-LTO MSVC builds and may pass linker-only flags +# like /opt:ref to the archiver. We detect /lib mode, filter those flags, and +# redirect to llvm-lib. Also handles response files (@file) that xmake uses +# when the argument list is too long. +cat > "${ARCH_BIN}/link" << WRAPPER +#!/bin/bash +ALL_ARGS=() +for arg in "\$@"; do + if [[ "\$arg" == @* ]]; then + rspfile="\${arg#@}" + while IFS= read -r line; do + [[ -n "\$line" ]] && ALL_ARGS+=("\$line") + done < "\$rspfile" + else + ALL_ARGS+=("\$arg") + fi +done +LIB_MODE=false +HAS_OUT_LIB=false +HAS_OBJ_ONLY=true +ARGS=() +for arg in "\${ALL_ARGS[@]}"; do + lower="\${arg,,}" + case "\$lower" in + /lib|-lib) LIB_MODE=true ;; + /out:*.lib|-out:*.lib) HAS_OUT_LIB=true; ARGS+=("\$arg") ;; + /opt:*|-opt:*) ;; + /subsystem:*|-subsystem:*) HAS_OBJ_ONLY=false; ARGS+=("\$arg") ;; + *.exe) HAS_OBJ_ONLY=false; ARGS+=("\$arg") ;; + *) ARGS+=("\$arg") ;; + esac +done +if [ "\$LIB_MODE" = true ] || ([ "\$HAS_OUT_LIB" = true ] && [ "\$HAS_OBJ_ONLY" = true ]); then + LIB_ARGS=() + for arg in "\${ARGS[@]}"; do + case "\${arg,,}" in + -nodefaultlib:*|/nodefaultlib:*) ;; + *) LIB_ARGS+=("\$arg") ;; + esac + done + exec "${LLVM_BIN}/llvm-lib" "\${LIB_ARGS[@]}" +else + exec "${LLVM_BIN}/lld-link" "\$@" +fi +WRAPPER +chmod +x "${ARCH_BIN}/link" +cp "${ARCH_BIN}/link" "${ARCH_BIN}/link.exe" + +# lib → llvm-lib +cat > "${ARCH_BIN}/lib" << WRAPPER +#!/bin/bash +exec "${LLVM_BIN}/llvm-lib" "\$@" +WRAPPER +chmod +x "${ARCH_BIN}/lib" +cp "${ARCH_BIN}/lib" "${ARCH_BIN}/lib.exe" + +# rc → llvm-rc (with SDK include paths for winres.h etc.) +cat > "${ARCH_BIN}/rc" << WRAPPER +#!/bin/bash +exec "${LLVM_BIN}/llvm-rc" /I "${sdk_dir}/crt/include" /I "${sdk_dir}/sdk/include/ucrt" /I "${sdk_dir}/sdk/include/um" /I "${sdk_dir}/sdk/include/shared" "\$@" +WRAPPER +chmod +x "${ARCH_BIN}/rc" +cp "${ARCH_BIN}/rc" "${ARCH_BIN}/rc.exe" + +# ml64 → llvm-ml (MASM-compatible assembler) +cat > "${ARCH_BIN}/ml64" << WRAPPER +#!/bin/bash +exec "${LLVM_BIN}/llvm-ml" -m64 "\$@" +WRAPPER +chmod +x "${ARCH_BIN}/ml64" +cp "${ARCH_BIN}/ml64" "${ARCH_BIN}/ml64.exe" + +# clang-cl (for xmake's built-in clang-cl toolchain detection) +cat > "${ARCH_BIN}/clang-cl" << WRAPPER +#!/bin/bash +exec "${LLVM_BIN}/clang" --driver-mode=cl -D_ALLOW_COMPILER_AND_STL_VERSION_MISMATCH "\$@" +WRAPPER +chmod +x "${ARCH_BIN}/clang-cl" + +# llvm-ar (cmake's clang-cl driver may use llvm-ar as archiver name but with +# MSVC-style flags like /nologo /out: — redirect to llvm-lib which handles these) +cat > "${ARCH_BIN}/llvm-ar" << WRAPPER +#!/bin/bash +exec "${LLVM_BIN}/llvm-lib" "\$@" +WRAPPER +chmod +x "${ARCH_BIN}/llvm-ar" + +# lld-link (for LTO builds where clang-cl toolchain uses lld-link) +cat > "${ARCH_BIN}/lld-link" << WRAPPER +#!/bin/bash +exec "${LLVM_BIN}/lld-link" "\$@" +WRAPPER +chmod +x "${ARCH_BIN}/lld-link" + +# mt → llvm-mt (manifest tool) +cat > "${ARCH_BIN}/mt" << WRAPPER +#!/bin/bash +exec "${LLVM_BIN}/llvm-mt" "\$@" +WRAPPER +chmod +x "${ARCH_BIN}/mt" +cp "${ARCH_BIN}/mt" "${ARCH_BIN}/mt.exe" + +# ------------------------------------------------------------------------- +# Create debug CRT lib symlinks (cmake's try_compile uses Debug config +# by default, which links against msvcrtd.lib etc. -- these don't exist +# in xwin since it only ships release libs. Symlink to release versions +# so cmake compiler tests pass.) +# ------------------------------------------------------------------------- +CRT_LIB="${sdk_dir}/crt/lib/x86_64" +for lib in msvcrt MSVCRT vcruntime msvcprt libcmt LIBCMT libcpmt libcpmt1 libconcrt libconcrt1 libvcruntime; do + release="${CRT_LIB}/${lib}.lib" + debug="${CRT_LIB}/${lib}d.lib" + if [ -f "${release}" ] && [ ! -e "${debug}" ]; then + ln -sfn "${lib}.lib" "${debug}" + fi +done + +echo "" +echo "Windows SDK installed to: ${sdk_dir}" +echo " SDK headers: ${sdk_dir}/sdk/include/um" +echo " SDK libs: ${sdk_dir}/sdk/lib/um/x86_64" +echo " CRT headers: ${sdk_dir}/crt/include" +echo " CRT libs: ${sdk_dir}/crt/lib/x86_64" +echo " Tool wrappers: ${BIN_DIR}/" +echo " MSVC compat: ${sdk_dir}/VC/ and ${sdk_dir}/Windows Kits/" +echo "" +echo "Usage:" +echo " xmake config -p windows -a x64 --toolchain=clang-cl --sdk=\${sdk_dir}" +echo "" +echo "Done" diff --git a/src/transports/winsock/winsock.cpp b/src/transports/winsock/winsock.cpp index f98984726..c1f4f6abe 100644 --- a/src/transports/winsock/winsock.cpp +++ b/src/transports/winsock/winsock.cpp @@ -271,7 +271,7 @@ WinsockTransportPlugin::Initialize(TransportServer* ServerInterface) m_ServerInterface = ServerInterface; m_ListenSocket = socket(AF_INET6, SOCK_STREAM, 0); - if (m_ListenSocket == SOCKET_ERROR || m_ListenSocket == INVALID_SOCKET) + if (m_ListenSocket == INVALID_SOCKET) { throw std::system_error(std::error_code(WSAGetLastError(), std::system_category()), "socket creation failed in HTTP plugin server init"); @@ -302,7 +302,7 @@ WinsockTransportPlugin::Initialize(TransportServer* ServerInterface) do { - if (SOCKET ClientSocket = accept(m_ListenSocket, NULL, NULL); ClientSocket != SOCKET_ERROR) + if (SOCKET ClientSocket = accept(m_ListenSocket, NULL, NULL); ClientSocket != INVALID_SOCKET) { int Flag = 1; setsockopt(ClientSocket, IPPROTO_TCP, TCP_NODELAY, (char*)&Flag, sizeof(Flag)); diff --git a/src/transports/winsock/xmake.lua b/src/transports/winsock/xmake.lua index c14283546..cdba75885 100644 --- a/src/transports/winsock/xmake.lua +++ b/src/transports/winsock/xmake.lua @@ -5,7 +5,7 @@ target("winsock") set_group("transports") add_headerfiles("**.h") add_files("**.cpp") - add_links("Ws2_32") + add_links("ws2_32") add_includedirs(".") set_symbols("debug") add_deps("zenbase", "transport-sdk") diff --git a/src/zen/cmds/admin_cmd.cpp b/src/zen/cmds/admin_cmd.cpp index 15e854796..034d430fd 100644 --- a/src/zen/cmds/admin_cmd.cpp +++ b/src/zen/cmds/admin_cmd.cpp @@ -17,7 +17,7 @@ namespace zen { ScrubCommand::ScrubCommand() { m_Options.add_options()("h,help", "Print help"); - m_Options.add_option("", "u", "hosturl", "Host URL", cxxopts::value(m_HostName)->default_value(""), "<hosturl>"); + m_Options.add_option("", "u", "hosturl", kHostUrlHelp, cxxopts::value(m_HostName)->default_value(""), "<hosturl>"); m_Options.add_option("", "n", "dry", "Dry run (do not delete any data)", cxxopts::value(m_DryRun), "<bool>"); m_Options.add_option("", "", "no-gc", "Do not perform GC after scrub pass", cxxopts::value(m_NoGc), "<bool>"); m_Options.add_option("", "", "no-cas", "Do not scrub CAS stores", cxxopts::value(m_NoCas), "<bool>"); @@ -48,7 +48,7 @@ ScrubCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) throw OptionParseException("Unable to resolve server specification", m_Options.help()); } - HttpClient Http(m_HostName); + HttpClient Http = CreateHttpClient(m_HostName); HttpClient::KeyValueMap Params{{"skipdelete", ToString(m_DryRun)}, {"skipgc", ToString(m_NoGc)}, @@ -70,7 +70,7 @@ ScrubCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) GcCommand::GcCommand() { m_Options.add_options()("h,help", "Print help"); - m_Options.add_option("", "u", "hosturl", "Host URL", cxxopts::value(m_HostName)->default_value(""), "<hosturl>"); + m_Options.add_option("", "u", "hosturl", kHostUrlHelp, cxxopts::value(m_HostName)->default_value(""), "<hosturl>"); m_Options.add_option("", "s", "smallobjects", @@ -258,7 +258,7 @@ GcCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) } Params.Entries.insert({"enablevalidation", m_EnableValidation ? "true" : "false"}); - HttpClient Http(m_HostName); + HttpClient Http = CreateHttpClient(m_HostName); if (HttpClient::Response Response = Http.Post("/admin/gc"sv, HttpClient::Accept(HttpContentType::kJSON), Params)) { ZEN_CONSOLE("OK: {}", Response.ToText()); @@ -272,7 +272,7 @@ GcCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) GcStatusCommand::GcStatusCommand() { m_Options.add_options()("h,help", "Print help"); - m_Options.add_option("", "u", "hosturl", "Host URL", cxxopts::value(m_HostName)->default_value(""), "<hosturl>"); + m_Options.add_option("", "u", "hosturl", kHostUrlHelp, cxxopts::value(m_HostName)->default_value(""), "<hosturl>"); m_Options.add_option("", "d", "details", "Show detailed GC report", cxxopts::value(m_Details)->default_value("false"), "<details>"); } @@ -297,7 +297,7 @@ GcStatusCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) throw OptionParseException("Unable to resolve server specification", m_Options.help()); } - HttpClient Http(m_HostName); + HttpClient Http = CreateHttpClient(m_HostName); if (HttpClient::Response Response = Http.Get("/admin/gc"sv, HttpClient::Accept(HttpContentType::kJSON))) { ZEN_CONSOLE("OK: {}", Response.ToText()); @@ -311,7 +311,7 @@ GcStatusCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) GcStopCommand::GcStopCommand() { m_Options.add_options()("h,help", "Print help"); - m_Options.add_option("", "u", "hosturl", "Host URL", cxxopts::value(m_HostName)->default_value(""), "<hosturl>"); + m_Options.add_option("", "u", "hosturl", kHostUrlHelp, cxxopts::value(m_HostName)->default_value(""), "<hosturl>"); } GcStopCommand::~GcStopCommand() @@ -335,7 +335,7 @@ GcStopCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) throw OptionParseException("Unable to resolve server specification", m_Options.help()); } - HttpClient Http(m_HostName); + HttpClient Http = CreateHttpClient(m_HostName); if (HttpClient::Response Response = Http.Post("/admin/gc-stop"sv, HttpClient::Accept(HttpContentType::kJSON))) { if (Response.StatusCode == HttpResponseCode::Accepted) @@ -358,7 +358,7 @@ GcStopCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) JobCommand::JobCommand() { m_Options.add_options()("h,help", "Print help"); - m_Options.add_option("", "u", "hosturl", "Host URL", cxxopts::value(m_HostName)->default_value(""), "<hosturl>"); + m_Options.add_option("", "u", "hosturl", kHostUrlHelp, cxxopts::value(m_HostName)->default_value(""), "<hosturl>"); m_Options.add_option("", "j", "jobid", "Job id", cxxopts::value(m_JobId), "<jobid>"); m_Options.add_option("", "c", "cancel", "Cancel job id", cxxopts::value(m_Cancel), "<cancel>"); } @@ -384,7 +384,7 @@ JobCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) throw OptionParseException("Unable to resolve server specification", m_Options.help()); } - HttpClient Http(m_HostName); + HttpClient Http = CreateHttpClient(m_HostName); if (m_Cancel) { @@ -421,7 +421,7 @@ JobCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) LoggingCommand::LoggingCommand() { m_Options.add_options()("h,help", "Print help"); - m_Options.add_option("", "u", "hosturl", "Host URL", cxxopts::value(m_HostName)->default_value(""), "<hosturl>"); + m_Options.add_option("", "u", "hosturl", kHostUrlHelp, cxxopts::value(m_HostName)->default_value(""), "<hosturl>"); m_Options.add_option("", "", "cache-write-log", "Enable cache write logging", cxxopts::value(m_CacheWriteLog), "<enable/disable>"); m_Options.add_option("", "", "cache-access-log", "Enable cache access logging", cxxopts::value(m_CacheAccessLog), "<enable/disable>"); m_Options @@ -467,7 +467,7 @@ LoggingCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) throw OptionParseException("Unable to resolve server specification", m_Options.help()); } - HttpClient Http(m_HostName); + HttpClient Http = CreateHttpClient(m_HostName); HttpClient::KeyValueMap Parameters; @@ -564,7 +564,7 @@ LoggingCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) FlushCommand::FlushCommand() { m_Options.add_options()("h,help", "Print help"); - m_Options.add_option("", "u", "hosturl", "Host URL", cxxopts::value(m_HostName)->default_value(""), "<hosturl>"); + m_Options.add_option("", "u", "hosturl", kHostUrlHelp, cxxopts::value(m_HostName)->default_value(""), "<hosturl>"); } FlushCommand::~FlushCommand() = default; @@ -586,7 +586,7 @@ FlushCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) throw OptionParseException("Unable to resolve server specification", m_Options.help()); } - zen::HttpClient Http(m_HostName); + zen::HttpClient Http = CreateHttpClient(m_HostName); if (zen::HttpClient::Response Response = Http.Post("/admin/flush"sv)) { diff --git a/src/zen/cmds/cache_cmd.cpp b/src/zen/cmds/cache_cmd.cpp index 2aa142973..a8c15f119 100644 --- a/src/zen/cmds/cache_cmd.cpp +++ b/src/zen/cmds/cache_cmd.cpp @@ -59,7 +59,7 @@ namespace { DropCommand::DropCommand() { m_Options.add_options()("h,help", "Print help"); - m_Options.add_option("", "u", "hosturl", "Host URL", cxxopts::value(m_HostName)->default_value(""), "<hosturl>"); + m_Options.add_option("", "u", "hosturl", kHostUrlHelp, cxxopts::value(m_HostName)->default_value(""), "<hosturl>"); m_Options.add_option("", "n", "namespace", "Namespace name", cxxopts::value(m_NamespaceName), "<namespacename>"); m_Options.add_option("", "b", "bucket", "Bucket name", cxxopts::value(m_BucketName), "<bucketname>"); m_Options.parse_positional({"namespace", "bucket"}); @@ -105,7 +105,7 @@ DropCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) ZEN_CONSOLE("Dropping {}", DropDescription); - HttpClient Http(m_HostName); + HttpClient Http = CreateHttpClient(m_HostName); if (HttpClient::Response Response = Http.Delete(Url)) { ZEN_CONSOLE("{}", Response.ToText()); @@ -119,7 +119,7 @@ DropCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) CacheInfoCommand::CacheInfoCommand() { m_Options.add_options()("h,help", "Print help"); - m_Options.add_option("", "u", "hosturl", "Host URL", cxxopts::value(m_HostName)->default_value(""), "<hosturl>"); + m_Options.add_option("", "u", "hosturl", kHostUrlHelp, cxxopts::value(m_HostName)->default_value(""), "<hosturl>"); m_Options.add_option("", "n", "namespace", "Namespace name", cxxopts::value(m_NamespaceName), "<namespacename>"); m_Options.add_option("", "", @@ -196,7 +196,7 @@ CacheInfoCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) Parameters.Entries.insert({"bucketsize", "true"}); } - HttpClient Http(m_HostName); + HttpClient Http = CreateHttpClient(m_HostName); if (HttpClient::Response Response = Http.Get(Url, HttpClient::Accept(ZenContentType::kJSON), Parameters)) { ZEN_CONSOLE("{}", Response.ToText()); @@ -210,7 +210,7 @@ CacheInfoCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) CacheStatsCommand::CacheStatsCommand() { m_Options.add_options()("h,help", "Print help"); - m_Options.add_option("", "u", "hosturl", "Host URL", cxxopts::value(m_HostName)->default_value(""), "<hosturl>"); + m_Options.add_option("", "u", "hosturl", kHostUrlHelp, cxxopts::value(m_HostName)->default_value(""), "<hosturl>"); } CacheStatsCommand::~CacheStatsCommand() = default; @@ -232,7 +232,7 @@ CacheStatsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv throw OptionParseException("Unable to resolve server specification", m_Options.help()); } - HttpClient Http(m_HostName); + HttpClient Http = CreateHttpClient(m_HostName); if (HttpClient::Response Response = Http.Get("/stats/z$", HttpClient::Accept(ZenContentType::kJSON))) { ZEN_CONSOLE("{}", Response.ToText()); @@ -246,7 +246,7 @@ CacheStatsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv CacheDetailsCommand::CacheDetailsCommand() { m_Options.add_options()("h,help", "Print help"); - m_Options.add_option("", "u", "hosturl", "Host URL", cxxopts::value(m_HostName)->default_value(""), "<hosturl>"); + m_Options.add_option("", "u", "hosturl", kHostUrlHelp, cxxopts::value(m_HostName)->default_value(""), "<hosturl>"); m_Options.add_option("", "c", "csv", "Info on csv format", cxxopts::value(m_CSV), "<csv>"); m_Options.add_option("", "d", "details", "Get detailed information about records", cxxopts::value(m_Details), "<details>"); m_Options.add_option("", @@ -329,7 +329,7 @@ CacheDetailsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** ar Url = "/z$/details$"; } - HttpClient Http(m_HostName); + HttpClient Http = CreateHttpClient(m_HostName); if (HttpClient::Response Response = Http.Get(Url, Headers, Parameters)) { ZEN_CONSOLE("{}", Response.ToText()); @@ -343,7 +343,7 @@ CacheDetailsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** ar CacheGenerateCommand::CacheGenerateCommand() { m_Options.add_options()("h,help", "Print help"); - m_Options.add_option("", "u", "hosturl", "Host URL", cxxopts::value(m_HostName)->default_value(""), "<hosturl>"); + m_Options.add_option("", "u", "hosturl", kHostUrlHelp, cxxopts::value(m_HostName)->default_value(""), "<hosturl>"); m_Options .add_option("", "n", "namespace", "Namespace to generate cache values/records for", cxxopts::value(m_Namespace), "<namespace>"); m_Options.add_option("", "b", "bucket", "Bucket name to generate cache values/records for", cxxopts::value(m_Bucket), "<bucket>"); @@ -431,7 +431,7 @@ CacheGenerateCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** a std::uniform_int_distribution<uint64_t> KeyDistribution; - HttpClient Http(m_HostName); + HttpClient Http = CreateHttpClient(m_HostName); auto GeneratePutCacheValueRequest( [this, &KeyDistribution, &Generator](std::span<std::uint64_t> BatchSizes, uint64_t RequestIndex) -> CbPackage { @@ -586,7 +586,7 @@ CacheGenerateCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** a CacheGetCommand::CacheGetCommand() { m_Options.add_options()("h,help", "Print help"); - m_Options.add_option("", "u", "hosturl", "Host URL", cxxopts::value(m_HostName)->default_value(""), "<hosturl>"); + m_Options.add_option("", "u", "hosturl", kHostUrlHelp, cxxopts::value(m_HostName)->default_value(""), "<hosturl>"); m_Options .add_option("", "n", "namespace", "Namespace to generate cache values/records for", cxxopts::value(m_Namespace), "<namespace>"); m_Options.add_option("", "b", "bucket", "Bucket name to generate cache values/records for", cxxopts::value(m_Bucket), "<bucket>"); @@ -656,7 +656,7 @@ CacheGetCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) } } - HttpClient Http(m_HostName); + HttpClient Http = CreateHttpClient(m_HostName); if (!m_OutputPath.empty()) { diff --git a/src/zen/cmds/exec_cmd.cpp b/src/zen/cmds/exec_cmd.cpp index 42c7119e7..cbc153e07 100644 --- a/src/zen/cmds/exec_cmd.cpp +++ b/src/zen/cmds/exec_cmd.cpp @@ -44,7 +44,7 @@ namespace zen { ExecCommand::ExecCommand() { m_Options.add_options()("h,help", "Print help"); - m_Options.add_option("", "u", "hosturl", "Host URL", cxxopts::value(m_HostName), "<hosturl>"); + m_Options.add_option("", "u", "hosturl", kHostUrlHelp, cxxopts::value(m_HostName), "<hosturl>"); m_Options.add_option("", "", "log", "Action log directory", cxxopts::value(m_RecordingLogPath), "<path>"); m_Options.add_option("", "p", "path", "Recording path (directory or .actionlog file)", cxxopts::value(m_RecordingPath), "<path>"); m_Options.add_option("", "", "offset", "Recording replay start offset", cxxopts::value(m_Offset), "<offset>"); diff --git a/src/zen/cmds/info_cmd.cpp b/src/zen/cmds/info_cmd.cpp index 49ad022cf..9faad5691 100644 --- a/src/zen/cmds/info_cmd.cpp +++ b/src/zen/cmds/info_cmd.cpp @@ -14,7 +14,7 @@ namespace zen { InfoCommand::InfoCommand() { m_Options.add_options()("h,help", "Print help"); - m_Options.add_option("", "u", "hosturl", "Host URL", cxxopts::value(m_HostName)->default_value(""), "<hosturl>"); + m_Options.add_option("", "u", "hosturl", kHostUrlHelp, cxxopts::value(m_HostName)->default_value(""), "<hosturl>"); } InfoCommand::~InfoCommand() @@ -38,7 +38,7 @@ InfoCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) throw OptionParseException("Unable to resolve server specification", m_Options.help()); } - HttpClient Http(m_HostName); + HttpClient Http = CreateHttpClient(m_HostName); if (HttpClient::Response Result = Http.Get("/admin/info", HttpClient::Accept(ZenContentType::kJSON))) { diff --git a/src/zen/cmds/projectstore_cmd.cpp b/src/zen/cmds/projectstore_cmd.cpp index c0780c7e8..d31c34fd0 100644 --- a/src/zen/cmds/projectstore_cmd.cpp +++ b/src/zen/cmds/projectstore_cmd.cpp @@ -507,7 +507,7 @@ namespace projectstore_impl { DropProjectCommand::DropProjectCommand() { m_Options.add_options()("h,help", "Print help"); - m_Options.add_option("", "u", "hosturl", "Host URL", cxxopts::value(m_HostName)->default_value(""), "<hosturl>"); + m_Options.add_option("", "u", "hosturl", kHostUrlHelp, cxxopts::value(m_HostName)->default_value(""), "<hosturl>"); m_Options.add_option("", "p", "project", "Project name", cxxopts::value(m_ProjectName), "<projectid>"); m_Options.add_option("", "o", "oplog", "Oplog name", cxxopts::value(m_OplogName), "<oplogid>"); m_Options.add_option("", "", "dryrun", "Dry run - resolve arguments but do not drop", cxxopts::value(m_DryRun), "<dryrun>"); @@ -537,7 +537,7 @@ DropProjectCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** arg throw OptionParseException("Unable to resolve server specification", m_Options.help()); } - HttpClient Http(m_HostName); + HttpClient Http = CreateHttpClient(m_HostName); m_ProjectName = ResolveProject(Http, m_ProjectName); if (m_ProjectName.empty()) @@ -598,7 +598,7 @@ DropProjectCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** arg ProjectInfoCommand::ProjectInfoCommand() { m_Options.add_options()("h,help", "Print help"); - m_Options.add_option("", "u", "hosturl", "Host URL", cxxopts::value(m_HostName)->default_value(""), "<hosturl>"); + m_Options.add_option("", "u", "hosturl", kHostUrlHelp, cxxopts::value(m_HostName)->default_value(""), "<hosturl>"); m_Options.add_option("", "p", "project", "Project name", cxxopts::value(m_ProjectName), "<projectid>"); m_Options.add_option("", "o", "oplog", "Oplog name", cxxopts::value(m_OplogName), "<oplogid>"); m_Options.parse_positional({"project", "oplog"}); @@ -632,7 +632,7 @@ ProjectInfoCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** arg throw OptionParseException("'--project' is required", m_Options.help()); } - HttpClient Http(m_HostName); + HttpClient Http = CreateHttpClient(m_HostName); std::string Url; if (m_ProjectName.empty()) @@ -684,7 +684,7 @@ ProjectInfoCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** arg CreateProjectCommand::CreateProjectCommand() { m_Options.add_options()("h,help", "Print help"); - m_Options.add_option("", "u", "hosturl", "Host URL", cxxopts::value(m_HostName)->default_value(""), "<hosturl>"); + m_Options.add_option("", "u", "hosturl", kHostUrlHelp, cxxopts::value(m_HostName)->default_value(""), "<hosturl>"); m_Options.add_option("", "p", "project", "Project name", cxxopts::value(m_ProjectId), "<projectid>"); m_Options.add_option("", "", "rootdir", "Absolute path to root directory", cxxopts::value(m_RootDir), "<root>"); m_Options.add_option("", "", "enginedir", "Absolute path to engine root directory", cxxopts::value(m_EngineRootDir), "<engineroot>"); @@ -721,7 +721,7 @@ CreateProjectCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** a throw OptionParseException("'--project' is required", m_Options.help()); } - HttpClient Http(m_HostName); + HttpClient Http = CreateHttpClient(m_HostName); std::string Url = fmt::format("/prj/{}", m_ProjectId); @@ -756,7 +756,7 @@ CreateProjectCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** a CreateOplogCommand::CreateOplogCommand() { m_Options.add_options()("h,help", "Print help"); - m_Options.add_option("", "u", "hosturl", "Host URL", cxxopts::value(m_HostName)->default_value(""), "<hosturl>"); + m_Options.add_option("", "u", "hosturl", kHostUrlHelp, cxxopts::value(m_HostName)->default_value(""), "<hosturl>"); m_Options.add_option("", "p", "project", "Project name", cxxopts::value(m_ProjectId), "<projectid>"); m_Options.add_option("", "o", "oplog", "Oplog name", cxxopts::value(m_OplogId), "<oplogid>"); m_Options.add_option("", "", "gcpath", "Absolute path to oplog lifetime marker file", cxxopts::value(m_GcPath), "<path>"); @@ -791,8 +791,8 @@ CreateOplogCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** arg throw OptionParseException("'--project' is required", m_Options.help()); } - HttpClient Http(m_HostName); - m_ProjectId = ResolveProject(Http, m_ProjectId); + HttpClient Http = CreateHttpClient(m_HostName); + m_ProjectId = ResolveProject(Http, m_ProjectId); if (m_ProjectId.empty()) { throw std::runtime_error("Project can not be found"); @@ -835,7 +835,7 @@ CreateOplogCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** arg ExportOplogCommand::ExportOplogCommand() { m_Options.add_options()("h,help", "Print help"); - m_Options.add_option("", "u", "hosturl", "Host URL", cxxopts::value(m_HostName)->default_value(""), "<hosturl>"); + m_Options.add_option("", "u", "hosturl", kHostUrlHelp, cxxopts::value(m_HostName)->default_value(""), "<hosturl>"); m_Options.add_option("", "p", "project", "Project name", cxxopts::value(m_ProjectName), "<projectid>"); m_Options.add_option("", "o", "oplog", "Oplog name", cxxopts::value(m_OplogName), "<oplogid>"); m_Options.add_option("", "", "maxblocksize", "Max size for bundled attachments", cxxopts::value(m_MaxBlockSize), "<blocksize>"); @@ -1022,8 +1022,8 @@ ExportOplogCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** arg throw OptionParseException("'--project' is required", m_Options.help()); } - HttpClient Http(m_HostName); - m_ProjectName = ResolveProject(Http, m_ProjectName); + HttpClient Http = CreateHttpClient(m_HostName); + m_ProjectName = ResolveProject(Http, m_ProjectName); if (m_ProjectName.empty()) { throw std::runtime_error("Project can not be found"); @@ -1368,7 +1368,7 @@ ExportOplogCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** arg ImportOplogCommand::ImportOplogCommand() { m_Options.add_options()("h,help", "Print help"); - m_Options.add_option("", "u", "hosturl", "Host URL", cxxopts::value(m_HostName)->default_value(""), "<hosturl>"); + m_Options.add_option("", "u", "hosturl", kHostUrlHelp, cxxopts::value(m_HostName)->default_value(""), "<hosturl>"); m_Options.add_option("", "p", "project", "Project name", cxxopts::value(m_ProjectName), "<projectid>"); m_Options.add_option("", "o", "oplog", "Oplog name", cxxopts::value(m_OplogName), "<oplogid>"); m_Options.add_option("", @@ -1541,8 +1541,8 @@ ImportOplogCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** arg m_Options.help()); } - HttpClient Http(m_HostName); - m_ProjectName = ResolveProject(Http, m_ProjectName); + HttpClient Http = CreateHttpClient(m_HostName); + m_ProjectName = ResolveProject(Http, m_ProjectName); if (m_ProjectName.empty()) { throw std::runtime_error("Project can not be found"); @@ -1782,7 +1782,7 @@ ImportOplogCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** arg SnapshotOplogCommand::SnapshotOplogCommand() { m_Options.add_options()("h,help", "Print help"); - m_Options.add_option("", "u", "hosturl", "Host URL", cxxopts::value(m_HostName)->default_value(""), "<hosturl>"); + m_Options.add_option("", "u", "hosturl", kHostUrlHelp, cxxopts::value(m_HostName)->default_value(""), "<hosturl>"); m_Options.add_option("", "p", "project", "Project name", cxxopts::value(m_ProjectName), "<projectid>"); m_Options.add_option("", "o", "oplog", "Oplog name", cxxopts::value(m_OplogName), "<oplogid>"); @@ -1813,7 +1813,7 @@ SnapshotOplogCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** a throw OptionParseException("Unable to resolve server specification", m_Options.help()); } - HttpClient Http(m_HostName); + HttpClient Http = CreateHttpClient(m_HostName); if (m_ProjectName.empty()) { @@ -1851,7 +1851,7 @@ SnapshotOplogCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** a ProjectStatsCommand::ProjectStatsCommand() { m_Options.add_options()("h,help", "Print help"); - m_Options.add_option("", "u", "hosturl", "Host URL", cxxopts::value(m_HostName)->default_value(""), "<hosturl>"); + m_Options.add_option("", "u", "hosturl", kHostUrlHelp, cxxopts::value(m_HostName)->default_value(""), "<hosturl>"); } ProjectStatsCommand::~ProjectStatsCommand() @@ -1876,7 +1876,7 @@ ProjectStatsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** ar throw OptionParseException("Unable to resolve server specification", m_Options.help()); } - HttpClient Http(m_HostName); + HttpClient Http = CreateHttpClient(m_HostName); if (HttpClient::Response Result = Http.Get("/stats/prj", HttpClient::Accept(ZenContentType::kJSON))) { ZEN_CONSOLE("{}", Result.ToText()); @@ -1892,7 +1892,7 @@ ProjectStatsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** ar ProjectOpDetailsCommand::ProjectOpDetailsCommand() { m_Options.add_options()("h,help", "Print help"); - m_Options.add_option("", "u", "hosturl", "Host URL", cxxopts::value(m_HostName)->default_value(""), "<hosturl>"); + m_Options.add_option("", "u", "hosturl", kHostUrlHelp, cxxopts::value(m_HostName)->default_value(""), "<hosturl>"); m_Options.add_option("", "c", "csv", "Output in CSV format (default is JSon)", cxxopts::value(m_CSV), "<csv>"); m_Options.add_option("", "d", "details", "Detailed info on oplog", cxxopts::value(m_Details), "<details>"); m_Options.add_option("", "o", "opdetails", "Details info on oplog body", cxxopts::value(m_OpDetails), "<opdetails>"); @@ -1929,7 +1929,7 @@ ProjectOpDetailsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char* throw OptionParseException("Unable to resolve server specification", m_Options.help()); } - HttpClient Http(m_HostName); + HttpClient Http = CreateHttpClient(m_HostName); m_ProjectName = ResolveProject(Http, m_ProjectName); if (m_ProjectName.empty()) @@ -1982,7 +1982,7 @@ ProjectOpDetailsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char* OplogMirrorCommand::OplogMirrorCommand() { m_Options.add_options()("h,help", "Print help"); - m_Options.add_option("", "u", "hosturl", "Host URL", cxxopts::value(m_HostName)->default_value(""), "<hosturl>"); + m_Options.add_option("", "u", "hosturl", kHostUrlHelp, cxxopts::value(m_HostName)->default_value(""), "<hosturl>"); m_Options.add_option("", "p", "project", "Project name to get info from", cxxopts::value(m_ProjectName), "<projectid>"); m_Options.add_option("", "l", "oplog", "Oplog name to get info from", cxxopts::value(m_OplogName), "<oplogid>"); m_Options.add_option("", "t", "target", "Target directory for mirror", cxxopts::value(m_MirrorRootPath), "<path>"); @@ -2045,7 +2045,7 @@ OplogMirrorCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** arg throw OptionParseException("Unable to resolve server specification", m_Options.help()); } - HttpClient Http(m_HostName); + HttpClient Http = CreateHttpClient(m_HostName); m_ProjectName = ResolveProject(Http, m_ProjectName); if (m_ProjectName.empty()) @@ -2283,7 +2283,7 @@ OplogMirrorCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** arg OplogValidateCommand::OplogValidateCommand() { m_Options.add_options()("h,help", "Print help"); - m_Options.add_option("", "u", "hosturl", "Host URL", cxxopts::value(m_HostName)->default_value(""), "<hosturl>"); + m_Options.add_option("", "u", "hosturl", kHostUrlHelp, cxxopts::value(m_HostName)->default_value(""), "<hosturl>"); m_Options.add_option("", "p", "project", "Project name to get info from", cxxopts::value(m_ProjectName), "<projectid>"); m_Options.add_option("", "l", "oplog", "Oplog name to get info from", cxxopts::value(m_OplogName), "<oplogid>"); @@ -2313,7 +2313,7 @@ OplogValidateCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** a throw OptionParseException("Unable to resolve server specification", m_Options.help()); } - HttpClient Http(m_HostName); + HttpClient Http = CreateHttpClient(m_HostName); m_ProjectName = ResolveProject(Http, m_ProjectName); if (m_ProjectName.empty()) diff --git a/src/zen/cmds/rpcreplay_cmd.cpp b/src/zen/cmds/rpcreplay_cmd.cpp index 70e9e5300..3bf81a9df 100644 --- a/src/zen/cmds/rpcreplay_cmd.cpp +++ b/src/zen/cmds/rpcreplay_cmd.cpp @@ -32,7 +32,7 @@ using namespace std::literals; RpcStartRecordingCommand::RpcStartRecordingCommand() { m_Options.add_options()("h,help", "Print help"); - m_Options.add_option("", "u", "hosturl", "Host URL", cxxopts::value(m_HostName)->default_value(""), "<hosturl>"); + m_Options.add_option("", "u", "hosturl", kHostUrlHelp, cxxopts::value(m_HostName)->default_value(""), "<hosturl>"); m_Options.add_option("", "p", "path", "Recording file path", cxxopts::value(m_RecordingPath), "<path>"); m_Options.parse_positional("path"); @@ -61,7 +61,7 @@ RpcStartRecordingCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char throw OptionParseException("'--path' is required", m_Options.help()); } - HttpClient Http(m_HostName); + HttpClient Http = CreateHttpClient(m_HostName); if (HttpClient::Response Response = Http.Post("/z$/exec$/start-recording"sv, HttpClient::KeyValueMap{}, HttpClient::KeyValueMap({{"path", m_RecordingPath}}))) { @@ -78,7 +78,7 @@ RpcStartRecordingCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char RpcStopRecordingCommand::RpcStopRecordingCommand() { m_Options.add_options()("h,help", "Print help"); - m_Options.add_option("", "u", "hosturl", "Host URL", cxxopts::value(m_HostName)->default_value(""), "<hosturl>"); + m_Options.add_option("", "u", "hosturl", kHostUrlHelp, cxxopts::value(m_HostName)->default_value(""), "<hosturl>"); } RpcStopRecordingCommand::~RpcStopRecordingCommand() = default; @@ -100,7 +100,7 @@ RpcStopRecordingCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char* throw OptionParseException("Unable to resolve server specification", m_Options.help()); } - HttpClient Http(m_HostName); + HttpClient Http = CreateHttpClient(m_HostName); if (HttpClient::Response Response = Http.Post("/z$/exec$/stop-recording"sv)) { ZEN_CONSOLE("{}", Response.ToText()); @@ -116,7 +116,7 @@ RpcStopRecordingCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char* RpcReplayCommand::RpcReplayCommand() { m_Options.add_options()("h,help", "Print help"); - m_Options.add_option("", "u", "hosturl", "Host URL", cxxopts::value(m_HostName)->default_value(""), "<hosturl>"); + m_Options.add_option("", "u", "hosturl", kHostUrlHelp, cxxopts::value(m_HostName)->default_value(""), "<hosturl>"); m_Options.add_option("", "p", "path", "Recording file path", cxxopts::value(m_RecordingPath), "<path>"); m_Options.add_option("", "", "dry", "Do a dry run", cxxopts::value(m_DryRun), "<enable>"); m_Options.add_option("", @@ -223,7 +223,7 @@ RpcReplayCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) if (m_OnHost) { - HttpClient Http(m_HostName); + HttpClient Http = CreateHttpClient(m_HostName); if (HttpClient::Response Response = Http.Post("/z$/exec$/replay-recording"sv, HttpClient::KeyValueMap{}, @@ -302,7 +302,7 @@ RpcReplayCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) } }); - HttpClient Http{m_HostName}; + HttpClient Http = CreateHttpClient(m_HostName); uint64_t EntryIndex = EntryOffset.fetch_add(m_Stride); while (EntryIndex < EntryCount) diff --git a/src/zen/cmds/serve_cmd.cpp b/src/zen/cmds/serve_cmd.cpp index 49389bcdf..03007a86c 100644 --- a/src/zen/cmds/serve_cmd.cpp +++ b/src/zen/cmds/serve_cmd.cpp @@ -21,7 +21,7 @@ using namespace std::literals; ServeCommand::ServeCommand() { m_Options.add_options()("h,help", "Print help"); - m_Options.add_option("", "u", "hosturl", "Host URL", cxxopts::value(m_HostName)->default_value(""), "<hosturl>"); + m_Options.add_option("", "u", "hosturl", kHostUrlHelp, cxxopts::value(m_HostName)->default_value(""), "<hosturl>"); m_Options.add_option("", "p", "project", "Project name", cxxopts::value(m_ProjectName), "<projectid>"); m_Options.add_option("", "o", "oplog", "Oplog name", cxxopts::value(m_OplogName), "<oplogid>"); m_Options.add_option("", "", "path", "Root path to directory", cxxopts::value(m_RootPath), "<rootpath>"); @@ -183,7 +183,7 @@ ServeCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) const std::string ProjectUri = fmt::format("/prj/{}", m_ProjectName); const std::string ProjectOplogUri = fmt::format("/prj/{}/oplog/{}", m_ProjectName, m_OplogName); - HttpClient Client(m_HostName); + HttpClient Client = CreateHttpClient(m_HostName); // Ensure project exists diff --git a/src/zen/cmds/service_cmd.cpp b/src/zen/cmds/service_cmd.cpp index a781dc340..3347f1afe 100644 --- a/src/zen/cmds/service_cmd.cpp +++ b/src/zen/cmds/service_cmd.cpp @@ -12,8 +12,8 @@ #if ZEN_PLATFORM_WINDOWS # include <zencore/windows.h> # include <shellapi.h> -# include <Shlwapi.h> -# pragma comment(lib, "Shlwapi.lib") +# include <shlwapi.h> +# pragma comment(lib, "shlwapi.lib") #endif #if ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC diff --git a/src/zen/cmds/status_cmd.cpp b/src/zen/cmds/status_cmd.cpp index c43f85052..6ed3c42e1 100644 --- a/src/zen/cmds/status_cmd.cpp +++ b/src/zen/cmds/status_cmd.cpp @@ -4,6 +4,7 @@ #include <zencore/compactbinary.h> #include <zencore/compactbinaryutil.h> +#include <zencore/filesystem.h> #include <zencore/fmtutils.h> #include <zencore/logging.h> #include <zencore/string.h> @@ -64,7 +65,7 @@ StatusCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) return; } - ZEN_CONSOLE("{:>5} {:>6} {:>24}", "port", "pid", "session"); + ZEN_CONSOLE("{:>5} {:>6} {:>24} {}", "port", "pid", "session", "socket"); State.Snapshot([&](const ZenServerState::ZenServerEntry& Entry) { bool MatchesAnyPort = (m_Port == 0) && (EffectivePort == 0); bool MatchesEffectivePort = (EffectivePort != 0) && (Entry.EffectiveListenPort.load() == EffectivePort); @@ -73,7 +74,22 @@ StatusCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) { StringBuilder<25> SessionStringBuilder; Entry.GetSessionId().ToString(SessionStringBuilder); - ZEN_CONSOLE("{:>5} {:>6} {:>24}", Entry.EffectiveListenPort.load(), Entry.Pid.load(), SessionStringBuilder); + + std::string SocketPath; + if (Entry.HasInstanceInfo()) + { + ZenServerInstanceInfo Info; + if (Info.OpenReadOnly(Entry.GetSessionId())) + { + InstanceInfoData Data = Info.Read(); + if (!Data.UnixSocketPath.empty()) + { + SocketPath = PathToUtf8(Data.UnixSocketPath); + } + } + } + std::string PortStr = Entry.IsNoNetwork() ? std::string("-") : fmt::to_string(Entry.EffectiveListenPort.load()); + ZEN_CONSOLE("{:>5} {:>6} {:>24} {}", PortStr, Entry.Pid.load(), SessionStringBuilder, SocketPath); } }); } diff --git a/src/zen/cmds/top_cmd.cpp b/src/zen/cmds/top_cmd.cpp index f674db6cd..cf2896f0f 100644 --- a/src/zen/cmds/top_cmd.cpp +++ b/src/zen/cmds/top_cmd.cpp @@ -2,6 +2,7 @@ #include "top_cmd.h" +#include <zencore/filesystem.h> #include <zencore/fmtutils.h> #include <zencore/logging.h> #include <zencore/system.h> @@ -81,13 +82,29 @@ TopCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) { if ((n++ % HeaderPeriod) == 0) { - ZEN_CONSOLE("{:>5} {:>6} {:>24}", "port", "pid", "session"); + ZEN_CONSOLE("{:>5} {:>6} {:>24} {}", "port", "pid", "session", "socket"); } State.Snapshot([&](const ZenServerState::ZenServerEntry& Entry) { StringBuilder<25> SessionStringBuilder; Entry.GetSessionId().ToString(SessionStringBuilder); - ZEN_CONSOLE("{:>5} {:>6} {:>24}", Entry.EffectiveListenPort.load(), Entry.Pid.load(), SessionStringBuilder); + + std::string SocketPath; + if (Entry.HasInstanceInfo()) + { + ZenServerInstanceInfo Info; + if (Info.OpenReadOnly(Entry.GetSessionId())) + { + InstanceInfoData Data = Info.Read(); + if (!Data.UnixSocketPath.empty()) + { + SocketPath = PathToUtf8(Data.UnixSocketPath); + } + } + } + + std::string PortStr = Entry.IsNoNetwork() ? std::string("-") : fmt::to_string(Entry.EffectiveListenPort.load()); + ZEN_CONSOLE("{:>5} {:>6} {:>24} {}", PortStr, Entry.Pid.load(), SessionStringBuilder, SocketPath); }); zen::Sleep(1000); @@ -121,7 +138,21 @@ PsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) } State.Snapshot([&](const ZenServerState::ZenServerEntry& Entry) { - ZEN_CONSOLE("Port {} : pid {}", Entry.EffectiveListenPort.load(), Entry.Pid.load()); + std::string Extra; + if (Entry.HasInstanceInfo()) + { + ZenServerInstanceInfo Info; + if (Info.OpenReadOnly(Entry.GetSessionId())) + { + InstanceInfoData Data = Info.Read(); + if (!Data.UnixSocketPath.empty()) + { + Extra = fmt::format(" (unix: {})", Data.UnixSocketPath); + } + } + } + std::string PortStr = Entry.IsNoNetwork() ? std::string("-") : fmt::to_string(Entry.EffectiveListenPort.load()); + ZEN_CONSOLE("Port {} : pid {}{}", PortStr, Entry.Pid.load(), Extra); }); } diff --git a/src/zen/cmds/trace_cmd.cpp b/src/zen/cmds/trace_cmd.cpp index 41a30068c..54c0f080d 100644 --- a/src/zen/cmds/trace_cmd.cpp +++ b/src/zen/cmds/trace_cmd.cpp @@ -12,7 +12,7 @@ namespace zen { TraceCommand::TraceCommand() { m_Options.add_options()("h,help", "Print help"); - m_Options.add_option("", "u", "hosturl", "Host URL", cxxopts::value(m_HostName)->default_value(""), "<hosturl>"); + m_Options.add_option("", "u", "hosturl", kHostUrlHelp, cxxopts::value(m_HostName)->default_value(""), "<hosturl>"); m_Options.add_option("", "s", "stop", "Stop tracing", cxxopts::value(m_Stop)->default_value("false"), "<stop>"); m_Options.add_option("", "", "host", "Start tracing to host", cxxopts::value(m_TraceHost), "<hostip>"); m_Options.add_option("", "", "file", "Start tracing to file", cxxopts::value(m_TraceFile), "<filepath>"); @@ -37,7 +37,7 @@ TraceCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) throw OptionParseException("Unable to resolve server specification", m_Options.help()); } - zen::HttpClient Http(m_HostName); + zen::HttpClient Http = CreateHttpClient(m_HostName); if (m_Stop) { diff --git a/src/zen/cmds/ui_cmd.cpp b/src/zen/cmds/ui_cmd.cpp index da06ce305..4846b4d18 100644 --- a/src/zen/cmds/ui_cmd.cpp +++ b/src/zen/cmds/ui_cmd.cpp @@ -50,7 +50,7 @@ UiCommand::UiCommand() { m_Options.add_options()("h,help", "Print help"); m_Options.add_options()("a,all", "Open dashboard for all running instances", cxxopts::value(m_All)->default_value("false")); - m_Options.add_option("", "u", "hosturl", "Host URL", cxxopts::value(m_HostName)->default_value(""), "<hosturl>"); + m_Options.add_option("", "u", "hosturl", kHostUrlHelp, cxxopts::value(m_HostName)->default_value(""), "<hosturl>"); m_Options.add_option("", "p", "path", @@ -230,6 +230,11 @@ UiCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) throw OptionParseException("Unable to resolve server specification", m_Options.help()); } + if (IsUnixSocketSpec(m_HostName)) + { + throw std::runtime_error("Cannot open browser for a Unix domain socket connection"); + } + OpenBrowser(m_HostName); } diff --git a/src/zen/cmds/version_cmd.cpp b/src/zen/cmds/version_cmd.cpp index ed34d7011..0948de1bb 100644 --- a/src/zen/cmds/version_cmd.cpp +++ b/src/zen/cmds/version_cmd.cpp @@ -20,7 +20,7 @@ using namespace std::literals; VersionCommand::VersionCommand() { m_Options.add_options()("h,help", "Print help"); - m_Options.add_option("", "u", "hosturl", "Host URL", cxxopts::value(m_HostName), "[hosturl]"); + m_Options.add_option("", "u", "hosturl", kHostUrlHelp, cxxopts::value(m_HostName), "[hosturl]"); m_Options.add_option("", "d", "detailed", "Detailed Version", cxxopts::value(m_DetailedVersion), "[detailedversion]"); m_Options.add_option("", "o", "output-path", "Path for output", cxxopts::value(m_OutputPath), "[outputpath]"); m_Options.parse_positional({"hosturl"}); @@ -57,7 +57,7 @@ VersionCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) ZEN_CONSOLE("Querying host {}", m_HostName); } - HttpClient Client(m_HostName, HttpClientSettings{.Timeout = std::chrono::milliseconds(5000)}); + HttpClient Client = CreateHttpClient(m_HostName, {.Timeout = std::chrono::milliseconds(5000)}); HttpClient::KeyValueMap Parameters; if (m_DetailedVersion) diff --git a/src/zen/cmds/vfs_cmd.cpp b/src/zen/cmds/vfs_cmd.cpp index 79ec20ce2..29ad8dc7c 100644 --- a/src/zen/cmds/vfs_cmd.cpp +++ b/src/zen/cmds/vfs_cmd.cpp @@ -18,7 +18,7 @@ using namespace std::literals; VfsCommand::VfsCommand() { m_Options.add_option("", "", "verb", "VFS management verb (mount, unmount, info)", cxxopts::value(m_Verb), "<verb>"); - m_Options.add_option("", "u", "hosturl", "Host URL", cxxopts::value(m_HostName)->default_value(""), "<url>"); + m_Options.add_option("", "u", "hosturl", kHostUrlHelp, cxxopts::value(m_HostName)->default_value(""), "<url>"); m_Options.add_option("", "", "vfs-path", "Specify VFS mount point path", cxxopts::value(m_MountPath), "<path>"); m_Options.parse_positional({"verb", "vfs-path"}); @@ -45,7 +45,7 @@ VfsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) if (m_HostName.empty()) throw OptionParseException("Unable to resolve server specification", m_Options.help()); - HttpClient Http(m_HostName); + HttpClient Http = CreateHttpClient(m_HostName); if (m_Verb == "mount"sv) { diff --git a/src/zen/cmds/workspaces_cmd.cpp b/src/zen/cmds/workspaces_cmd.cpp index af265d898..220ef6a9e 100644 --- a/src/zen/cmds/workspaces_cmd.cpp +++ b/src/zen/cmds/workspaces_cmd.cpp @@ -86,7 +86,7 @@ namespace { WorkspaceCommand::WorkspaceCommand() { m_Options.add_options()("h,help", "Print help"); - m_Options.add_option("", "u", "hosturl", "Host URL", cxxopts::value(m_HostName)->default_value(""), "<hosturl>"); + m_Options.add_option("", "u", "hosturl", kHostUrlHelp, cxxopts::value(m_HostName)->default_value(""), "<hosturl>"); m_Options.add_options()("system-dir", "Specify system root", cxxopts::value(m_SystemRootDir)); m_Options.add_option("", "v", "verb", "Verb for workspace - create, remove, info", cxxopts::value(m_Verb), "<verb>"); m_Options.parse_positional({"verb"}); @@ -182,7 +182,7 @@ WorkspaceCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) { if (!m_HostName.empty()) { - HttpClient Http(m_HostName); + HttpClient Http = CreateHttpClient(m_HostName); if (HttpClient::Response Result = Http.Get("/ws/refresh"); !Result) { ZEN_CONSOLE_WARN("Failed to refresh workspaces for host {}. Reason: '{}'", m_HostName, Result.ErrorMessage(""sv)); @@ -254,7 +254,7 @@ WorkspaceCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) { if (!m_HostName.empty()) { - HttpClient Http(m_HostName); + HttpClient Http = CreateHttpClient(m_HostName); if (HttpClient::Response Result = Http.Get("/ws/refresh"); !Result) { ZEN_CONSOLE_WARN("Failed to refresh workspaces for host {}. Reason: '{}'", m_HostName, Result.ErrorMessage(""sv)); @@ -275,7 +275,7 @@ WorkspaceCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) WorkspaceShareCommand::WorkspaceShareCommand() { m_Options.add_options()("h,help", "Print help"); - m_Options.add_option("", "u", "hosturl", "Host URL", cxxopts::value(m_HostName)->default_value(""), "<hosturl>"); + m_Options.add_option("", "u", "hosturl", kHostUrlHelp, cxxopts::value(m_HostName)->default_value(""), "<hosturl>"); m_Options.add_options()("system-dir", "Specify system root", cxxopts::value(m_SystemRootDir)); m_Options.add_option("", "v", "verb", "Verb for workspace - create, remove, info", cxxopts::value(m_Verb), "<verb>"); m_Options.parse_positional({"verb"}); @@ -475,7 +475,7 @@ WorkspaceShareCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** { if (!m_HostName.empty()) { - HttpClient Http(m_HostName); + HttpClient Http = CreateHttpClient(m_HostName); if (HttpClient::Response Result = Http.Get("/ws/refresh"); !Result) { ZEN_CONSOLE_WARN("Failed to refresh workspaces for host {}. Reason: '{}'", m_HostName, Result.ErrorMessage(""sv)); @@ -592,7 +592,7 @@ WorkspaceShareCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** { if (!m_HostName.empty()) { - HttpClient Http(m_HostName); + HttpClient Http = CreateHttpClient(m_HostName); if (HttpClient::Response Result = Http.Get("/ws/refresh"); !Result) { ZEN_CONSOLE_WARN("Failed to refresh workspaces for host {}. Reason: '{}'", m_HostName, Result.ErrorMessage(""sv)); @@ -645,7 +645,7 @@ WorkspaceShareCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** throw OptionParseException("Unable to resolve server specification", SubOption->help()); } - HttpClient Http(m_HostName); + HttpClient Http = CreateHttpClient(m_HostName); if (HttpClient::Response Result = Http.Get(fmt::format("/ws/{}/files", GetShareIdentityUrl(m_FilesOptions)), {}, Params)) { ZEN_CONSOLE("{}: {}", Result, Result.ToText()); @@ -678,7 +678,7 @@ WorkspaceShareCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** throw OptionParseException("Unable to resolve server specification", SubOption->help()); } - HttpClient Http(m_HostName); + HttpClient Http = CreateHttpClient(m_HostName); if (HttpClient::Response Result = Http.Get(fmt::format("/ws/{}/entries", GetShareIdentityUrl(m_EntriesOptions)), {}, Params)) { ZEN_CONSOLE("{}: {}", Result, Result.ToText()); @@ -753,8 +753,8 @@ WorkspaceShareCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** throw OptionParseException("'--chunk' is required", SubOption->help()); } - HttpClient Http(m_HostName); - m_ChunkId = ChunksToOidStrings(Http, m_WorkspaceId, m_ShareId, std::vector<std::string>{m_ChunkId})[0]; + HttpClient Http = CreateHttpClient(m_HostName); + m_ChunkId = ChunksToOidStrings(Http, m_WorkspaceId, m_ShareId, std::vector<std::string>{m_ChunkId})[0]; HttpClient::KeyValueMap Params; if (m_Offset != 0) @@ -794,8 +794,8 @@ WorkspaceShareCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** throw OptionParseException("'--chunks' is required", SubOption->help()); } - HttpClient Http(m_HostName); - m_ChunkIds = ChunksToOidStrings(Http, m_WorkspaceId, m_ShareId, m_ChunkIds); + HttpClient Http = CreateHttpClient(m_HostName); + m_ChunkIds = ChunksToOidStrings(Http, m_WorkspaceId, m_ShareId, m_ChunkIds); std::vector<RequestChunkEntry> ChunkRequests; ChunkRequests.resize(m_ChunkIds.size()); diff --git a/src/zen/xmake.lua b/src/zen/xmake.lua index 4c134404a..df249ade4 100644 --- a/src/zen/xmake.lua +++ b/src/zen/xmake.lua @@ -14,8 +14,10 @@ target("zen") if is_plat("windows") then add_files("zen.rc") - add_ldflags("/subsystem:console,5.02") - add_ldflags("/LTCG") + add_ldflags("/subsystem:console,5.02", {force = true}) + if not (get_config("toolchain") or ""):find("clang") then + add_ldflags("/LTCG") + end end if is_plat("macosx") then diff --git a/src/zen/zen.cpp b/src/zen/zen.cpp index 7b1b6e7d7..86154c291 100644 --- a/src/zen/zen.cpp +++ b/src/zen/zen.cpp @@ -369,9 +369,31 @@ GetReturnCodeFromHttpResult(const HttpClientError& Ex) } } +bool +ZenCmdBase::IsUnixSocketSpec(std::string_view Spec) +{ + return Spec.starts_with("unix://"); +} + +HttpClient +ZenCmdBase::CreateHttpClient(const std::string& HostSpec, HttpClientSettings Settings) +{ + if (IsUnixSocketSpec(HostSpec)) + { + Settings.UnixSocketPath = HostSpec.substr(7); // strip "unix://" + return HttpClient("http://localhost", Settings); + } + return HttpClient(HostSpec, Settings); +} + std::string ZenCmdBase::ResolveTargetHostSpec(const std::string& InHostSpec, uint16_t& OutEffectivePort) { + if (IsUnixSocketSpec(InHostSpec)) + { + return InHostSpec; // pass through as-is; parsed later in CreateHttpClient + } + if (InHostSpec.empty()) { // If no host is specified then look to see if we have an instance @@ -386,8 +408,30 @@ ZenCmdBase::ResolveTargetHostSpec(const std::string& InHostSpec, uint16_t& OutEf Servers.Snapshot([&](const zen::ZenServerState::ZenServerEntry& Entry) { if (ResolvedSpec.empty()) { - ResolvedSpec = fmt::format("http://localhost:{}", Entry.EffectiveListenPort.load()); OutEffectivePort = Entry.EffectiveListenPort; + + // Check for per-instance info (e.g. UDS path) + if (Entry.HasInstanceInfo()) + { + ZenServerInstanceInfo Info; + if (Info.OpenReadOnly(Entry.GetSessionId())) + { + InstanceInfoData Data = Info.Read(); + if (!Data.UnixSocketPath.empty()) + { + ResolvedSpec = "unix://" + PathToUtf8(Data.UnixSocketPath); + return; + } + } + } + + // Skip servers with --no-network since TCP is not reachable + if (Entry.IsNoNetwork()) + { + return; + } + + ResolvedSpec = fmt::format("http://localhost:{}", Entry.EffectiveListenPort.load()); } }); @@ -685,7 +729,7 @@ main(int argc, char** argv) Options.add_options()("c, command", "Sub command", cxxopts::value<std::string>(SubCommand)); Options.add_options()("httpclient", "Select HTTP client implementation (e.g. 'curl', 'cpr')", - cxxopts::value<std::string>(GlobalOptions.HttpClientBackend)->default_value("cpr")); + cxxopts::value<std::string>(GlobalOptions.HttpClientBackend)->default_value("curl")); int CoreLimit = 0; diff --git a/src/zen/zen.h b/src/zen/zen.h index 3cc06eea6..05ce32d0a 100644 --- a/src/zen/zen.h +++ b/src/zen/zen.h @@ -5,6 +5,7 @@ #include <zencore/except.h> #include <zencore/timer.h> #include <zencore/zencore.h> +#include <zenhttp/httpclient.h> #include <zenutil/config/commandlineoptions.h> #include <zenutil/config/loggingconfig.h> @@ -68,6 +69,11 @@ public: static std::string ResolveTargetHostSpec(const std::string& InHostSpec); static std::string ResolveTargetHostSpec(const std::string& InHostSpec, uint16_t& OutEffectivePort); + static bool IsUnixSocketSpec(std::string_view Spec); + static HttpClient CreateHttpClient(const std::string& HostSpec, HttpClientSettings Settings = {}); + + static constexpr const char* kHostUrlHelp = "Host URL or unix:///path/to/socket"; + static void LogExecutableVersionAndPid(); }; diff --git a/src/zen/zen.rc b/src/zen/zen.rc index 0617681a7..3adf25b72 100644 --- a/src/zen/zen.rc +++ b/src/zen/zen.rc @@ -7,7 +7,7 @@ LANGUAGE LANG_ENGLISH, SUBLANG_ENGLISH_US #pragma code_page(1252) -101 ICON "..\\zen.ico" +101 ICON "../zen.ico" VS_VERSION_INFO VERSIONINFO FILEVERSION ZEN_CFG_VERSION_MAJOR,ZEN_CFG_VERSION_MINOR,ZEN_CFG_VERSION_ALTER,0 diff --git a/src/zenbase/include/zenbase/zenbase.h b/src/zenbase/include/zenbase/zenbase.h index 2aec1f314..1d5051c5b 100644 --- a/src/zenbase/include/zenbase/zenbase.h +++ b/src/zenbase/include/zenbase/zenbase.h @@ -211,7 +211,24 @@ char (&ZenArrayCountHelper(const T (&)[N]))[N + 1]; # define ZEN_EXE_SUFFIX_LITERAL "" #endif -#define ZEN_UNUSED(...) ((void)__VA_ARGS__) +#if ZEN_COMPILER_CLANG +// Clang warns about the comma operator in ((void)a, b) with -Wunused-value. +// Use a fold expression via a helper to suppress each argument individually. +namespace zen::detail { +inline void +unused_impl() +{ +} +template<typename... T> +inline void +unused_impl(T&&...) +{ +} +} // namespace zen::detail +# define ZEN_UNUSED(...) ::zen::detail::unused_impl(__VA_ARGS__) +#else +# define ZEN_UNUSED(...) ((void)__VA_ARGS__) +#endif ////////////////////////////////////////////////////////////////////////// diff --git a/src/zencore-test/targetver.h b/src/zencore-test/targetver.h index d432d6993..4805141de 100644 --- a/src/zencore-test/targetver.h +++ b/src/zencore-test/targetver.h @@ -7,4 +7,4 @@ // If you wish to build your application for a previous Windows platform, include WinSDKVer.h and // set the _WIN32_WINNT macro to the platform you wish to support before including SDKDDKVer.h. -#include <SDKDDKVer.h> +#include <sdkddkver.h> diff --git a/src/zencore/callstack.cpp b/src/zencore/callstack.cpp index ee0b0625a..a16bb3f13 100644 --- a/src/zencore/callstack.cpp +++ b/src/zencore/callstack.cpp @@ -6,7 +6,7 @@ #if ZEN_PLATFORM_WINDOWS # include <zencore/windows.h> -# include <Dbghelp.h> +# include <DbgHelp.h> #endif #if ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC diff --git a/src/zencore/filesystem.cpp b/src/zencore/filesystem.cpp index 8ed63565c..0d361801f 100644 --- a/src/zencore/filesystem.cpp +++ b/src/zencore/filesystem.cpp @@ -17,7 +17,7 @@ #if ZEN_PLATFORM_WINDOWS # include <zencore/windows.h> -# include <ShlObj.h> +# include <shlobj.h> # pragma comment(lib, "shell32.lib") # pragma comment(lib, "ole32.lib") #endif @@ -32,17 +32,27 @@ ZEN_THIRD_PARTY_INCLUDES_END #if ZEN_PLATFORM_LINUX # include <dirent.h> # include <fcntl.h> +# include <linux/fs.h> +# include <linux/magic.h> +# include <sys/ioctl.h> # include <sys/resource.h> # include <sys/mman.h> # include <sys/stat.h> +# include <sys/vfs.h> # include <pwd.h> # include <unistd.h> +// XFS_SUPER_MAGIC is not always defined in linux/magic.h +# ifndef XFS_SUPER_MAGIC +# define XFS_SUPER_MAGIC 0x58465342 +# endif #endif #if ZEN_PLATFORM_MAC # include <dirent.h> # include <fcntl.h> # include <libproc.h> +# include <sys/attr.h> +# include <sys/clonefile.h> # include <sys/resource.h> # include <sys/mman.h> # include <sys/stat.h> @@ -59,6 +69,53 @@ namespace zen { using namespace std::literals; +#if ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC +struct ScopedFd +{ + int Fd = -1; + + ScopedFd() = default; + explicit ScopedFd(int InFd) : Fd(InFd) {} + + ~ScopedFd() + { + if (Fd >= 0) + { + close(Fd); + } + } + + ScopedFd(const ScopedFd&) = delete; + ScopedFd& operator=(const ScopedFd&) = delete; + + ScopedFd(ScopedFd&& Other) noexcept : Fd(Other.Fd) { Other.Fd = -1; } + + ScopedFd& operator=(ScopedFd&& Other) noexcept + { + if (this != &Other) + { + if (Fd >= 0) + { + close(Fd); + } + Fd = Other.Fd; + Other.Fd = -1; + } + return *this; + } + + // Release ownership of the file descriptor, returning it without closing + int Release() + { + int Result = Fd; + Fd = -1; + return Result; + } + + explicit operator bool() const { return Fd >= 0; } +}; +#endif // ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC + #if ZEN_PLATFORM_WINDOWS static bool @@ -615,6 +672,38 @@ SupportsBlockRefCounting(std::filesystem::path Path) } return true; +#elif ZEN_PLATFORM_LINUX + struct statfs Buf; + if (statfs(Path.c_str(), &Buf) != 0) + { + return false; + } + + // Btrfs and XFS (when formatted with reflink support) support FICLONE + return Buf.f_type == BTRFS_SUPER_MAGIC || Buf.f_type == XFS_SUPER_MAGIC; +#elif ZEN_PLATFORM_MAC + struct attrlist AttrList = {}; + AttrList.bitmapcount = ATTR_BIT_MAP_COUNT; + AttrList.volattr = ATTR_VOL_CAPABILITIES; + + struct + { + uint32_t Length; + vol_capabilities_attr_t Capabilities; + } AttrBuf = {}; + + if (getattrlist(Path.c_str(), &AttrList, &AttrBuf, sizeof(AttrBuf), 0) != 0) + { + return false; + } + + // Check that the VOL_CAP_INT_CLONE bit is both valid and set + if (!(AttrBuf.Capabilities.valid[VOL_CAPABILITIES_INTERFACES] & VOL_CAP_INT_CLONE)) + { + return false; + } + + return !!(AttrBuf.Capabilities.capabilities[VOL_CAPABILITIES_INTERFACES] & VOL_CAP_INT_CLONE); #else ZEN_UNUSED(Path); return false; @@ -768,7 +857,115 @@ private: DWORD m_TargetVolumeSerialNumber; }; -#endif // ZEN_PLATFORM_WINDOWS +#elif ZEN_PLATFORM_LINUX + +class LinuxCloneQueryInterface : public CloneQueryInterface +{ +public: + LinuxCloneQueryInterface(uint64_t AlignmentSize, dev_t TargetDevice) : m_AlignmentSize(AlignmentSize), m_TargetDevice(TargetDevice) {} + + virtual bool CanClone(void* SourceNativeHandle) override + { + int Fd = int(uintptr_t(SourceNativeHandle)); + + struct stat St; + if (fstat(Fd, &St) != 0) + { + return false; + } + + // Source must be on the same filesystem as the target + return St.st_dev == m_TargetDevice; + } + + virtual uint64_t GetClonableRange(uint64_t SourceOffset, + uint64_t TargetOffset, + uint64_t Size, + uint64_t& OutPreBytes, + uint64_t& OutPostBytes) override + { + if (Size < m_AlignmentSize) + { + return 0; + } + + uint64_t PreBytes = (m_AlignmentSize - (SourceOffset % m_AlignmentSize)) % m_AlignmentSize; + uint64_t PostBytes = (SourceOffset + Size) % m_AlignmentSize; + ZEN_ASSERT(Size >= PreBytes + PostBytes); + if (Size - (PreBytes + PostBytes) < m_AlignmentSize) + { + return 0; + } + ZEN_ASSERT((PreBytes < Size && PostBytes < Size && Size >= PreBytes + PostBytes + m_AlignmentSize)); + + const uint64_t DestCloneOffset = TargetOffset + PreBytes; + if (DestCloneOffset % m_AlignmentSize != 0) + { + return 0; + } + + OutPreBytes = PreBytes; + OutPostBytes = PostBytes; + uint64_t CloneSize = Size - (PreBytes + PostBytes); + ZEN_ASSERT(CloneSize % m_AlignmentSize == 0); + return CloneSize; + } + + virtual bool TryClone(void* SourceNativeHandle, + void* TargetNativeHandle, + uint64_t AlignedSourceOffset, + uint64_t AlignedTargetOffset, + uint64_t AlignedSize, + uint64_t TargetFinalSize) override + { + ZEN_ASSERT_SLOW(CanClone(SourceNativeHandle)); + ZEN_ASSERT((AlignedSourceOffset % m_AlignmentSize) == 0); + ZEN_ASSERT((AlignedTargetOffset % m_AlignmentSize) == 0); + ZEN_ASSERT(AlignedSize > 0); + ZEN_ASSERT((AlignedSize % m_AlignmentSize) == 0); + + int SourceFd = int(uintptr_t(SourceNativeHandle)); + int TargetFd = int(uintptr_t(TargetNativeHandle)); + + // Ensure the target file is sized to its final size before cloning + struct stat TargetSt; + if (fstat(TargetFd, &TargetSt) != 0 || uint64_t(TargetSt.st_size) != TargetFinalSize) + { + if (ftruncate(TargetFd, TargetFinalSize) != 0) + { + std::error_code DummyEc; + ZEN_DEBUG("Failed setting final size {} for file {}", TargetFinalSize, PathFromHandle(TargetNativeHandle, DummyEc)); + return false; + } + } + + struct file_clone_range Range = {}; + Range.src_fd = SourceFd; + Range.src_offset = AlignedSourceOffset; + Range.src_length = AlignedSize; + Range.dest_offset = AlignedTargetOffset; + + if (ioctl(TargetFd, FICLONERANGE, &Range) != 0) + { + std::error_code DummyEc; + ZEN_DEBUG("Failed cloning {} bytes from file {} at {} to file {} at {}", + AlignedSize, + PathFromHandle(SourceNativeHandle, DummyEc), + AlignedSourceOffset, + PathFromHandle(TargetNativeHandle, DummyEc), + AlignedTargetOffset); + return false; + } + + return true; + } + +private: + uint64_t m_AlignmentSize; + dev_t m_TargetDevice; +}; + +#endif // ZEN_PLATFORM_WINDOWS / ZEN_PLATFORM_LINUX std::unique_ptr<CloneQueryInterface> GetCloneQueryInterface(const std::filesystem::path& TargetDirectory) @@ -819,7 +1016,30 @@ GetCloneQueryInterface(const std::filesystem::path& TargetDirectory) return std::make_unique<WindowsCloneQueryInterface>(SectorsPerCluster * BytesPerSector, DestVolumeSerialNumber); } } -#else // ZEN_PLATFORM_WINDOWS +#elif ZEN_PLATFORM_LINUX + struct statfs FsBuf; + if (statfs(TargetDirectory.c_str(), &FsBuf) != 0) + { + ZEN_DEBUG("Failed to get filesystem info for path {}", TargetDirectory); + return {}; + } + + // Only Btrfs and XFS support FICLONERANGE + if (FsBuf.f_type != BTRFS_SUPER_MAGIC && FsBuf.f_type != XFS_SUPER_MAGIC) + { + return {}; + } + + struct stat StBuf; + if (stat(TargetDirectory.c_str(), &StBuf) != 0) + { + ZEN_DEBUG("Failed to stat path {}", TargetDirectory); + return {}; + } + + uint64_t AlignmentSize = FsBuf.f_bsize; + return std::make_unique<LinuxCloneQueryInterface>(AlignmentSize, StBuf.st_dev); +#else ZEN_UNUSED(TargetDirectory); #endif // ZEN_PLATFORM_WINDOWS return {}; @@ -1000,40 +1220,44 @@ TryCloneFile(const std::filesystem::path& FromPath, const std::filesystem::path& return TryCloneFile((void*)FromFile.m_Handle, (void*)TargetFile.m_Handle); #elif ZEN_PLATFORM_LINUX -# if 0 - struct ScopedFd - { - ~ScopedFd() { close(Fd); } - int Fd; - }; - // The 'from' file - int FromFd = open(FromPath.c_str(), O_RDONLY|O_CLOEXEC); - if (FromFd < 0) + ScopedFd FromFd(open(FromPath.c_str(), O_RDONLY | O_CLOEXEC)); + if (!FromFd) { return false; } - ScopedFd $From = { FromFd }; + + // Remove any existing target so we can create a fresh clone + unlink(ToPath.c_str()); // The 'to' file - int ToFd = open(ToPath.c_str(), O_WRONLY|O_CREAT|O_EXCL|O_CLOEXEC, 0666); - if (ToFd < 0) + ScopedFd ToFd(open(ToPath.c_str(), O_WRONLY | O_CREAT | O_EXCL | O_CLOEXEC, 0666)); + if (!ToFd) { return false; } - fchmod(ToFd, 0666); - ScopedFd $To = { ToFd }; - ioctl(ToFd, FICLONE, FromFd); + if (ioctl(ToFd.Fd, FICLONE, FromFd.Fd) != 0) + { + // Clone not supported by this filesystem or files are on different volumes. + // Remove the empty target file we created. + ToFd = ScopedFd(); + unlink(ToPath.c_str()); + return false; + } - return false; -# endif // 0 - ZEN_UNUSED(FromPath, ToPath); - return false; + return true; #elif ZEN_PLATFORM_MAC - /* clonefile() syscall if APFS */ - ZEN_UNUSED(FromPath, ToPath); - return false; + // Remove any existing target - clonefile() requires the destination not exist + unlink(ToPath.c_str()); + + if (clonefile(FromPath.c_str(), ToPath.c_str(), CLONE_NOFOLLOW) != 0) + { + // Clone not supported (non-APFS) or files are on different volumes + return false; + } + + return true; #endif // ZEN_PLATFORM_WINDOWS } @@ -1069,9 +1293,7 @@ CopyFile(const std::filesystem::path& FromPath, const std::filesystem::path& ToP if (Options.MustClone) { -#if ZEN_PLATFORM_MAC || ZEN_PLATFORM_LINUX - ZEN_ERROR("CloneFile() is not implemented on this platform"); -#endif // ZEN_PLATFORM_MAC || ZEN_PLATFORM_LINUX + ZEN_ERROR("CloneFile() failed for {} -> {}", FromPath, ToPath); return false; } @@ -1084,35 +1306,27 @@ CopyFile(const std::filesystem::path& FromPath, const std::filesystem::path& ToP &CancelFlag, /* dwCopyFlags */ 0); #else - struct ScopedFd - { - ~ScopedFd() { close(Fd); } - int Fd; - }; - // From file - int FromFd = open(FromPath.c_str(), O_RDONLY | O_CLOEXEC); - if (FromFd < 0) + ScopedFd FromFd(open(FromPath.c_str(), O_RDONLY | O_CLOEXEC)); + if (!FromFd) { ThrowLastError(fmt::format("failed to open file {}", FromPath)); } - ScopedFd $From = {FromFd}; // To file - int ToFd = open(ToPath.c_str(), O_WRONLY | O_CREAT | O_CLOEXEC, 0666); - if (ToFd < 0) + ScopedFd ToFd(open(ToPath.c_str(), O_WRONLY | O_CREAT | O_CLOEXEC, 0666)); + if (!ToFd) { ThrowLastError(fmt::format("failed to create file {}", ToPath)); } - fchmod(ToFd, 0666); - ScopedFd $To = {ToFd}; + fchmod(ToFd.Fd, 0666); struct stat Stat; - fstat(FromFd, &Stat); + fstat(FromFd.Fd, &Stat); size_t FileSizeBytes = Stat.st_size; - int $Ignore = fchown(ToFd, Stat.st_uid, Stat.st_gid); + int $Ignore = fchown(ToFd.Fd, Stat.st_uid, Stat.st_gid); ZEN_UNUSED($Ignore); // What's the appropriate error handling here? // Copy impl @@ -1120,14 +1334,14 @@ CopyFile(const std::filesystem::path& FromPath, const std::filesystem::path& ToP void* Buffer = malloc(BufferSize); while (true) { - int BytesRead = read(FromFd, Buffer, BufferSize); + int BytesRead = read(FromFd.Fd, Buffer, BufferSize); if (BytesRead <= 0) { Success = (BytesRead == 0); break; } - if (write(ToFd, Buffer, BytesRead) != BytesRead) + if (write(ToFd.Fd, Buffer, BytesRead) != BytesRead) { Success = false; break; @@ -1371,20 +1585,20 @@ WriteFile(std::filesystem::path Path, const IoBuffer* const* Data, size_t Buffer } #else - int OpenFlags = O_WRONLY | O_CREAT | O_TRUNC | O_CLOEXEC; - int Fd = open(Path.c_str(), OpenFlags, 0666); - if (Fd < 0) + int OpenFlags = O_WRONLY | O_CREAT | O_TRUNC | O_CLOEXEC; + ScopedFd OutFd(open(Path.c_str(), OpenFlags, 0666)); + if (!OutFd) { zen::CreateDirectories(Path.parent_path()); - Fd = open(Path.c_str(), OpenFlags, 0666); + OutFd = ScopedFd(open(Path.c_str(), OpenFlags, 0666)); } - if (Fd < 0) + if (!OutFd) { ThrowLastError(fmt::format("File open failed for '{}'", Path)); } - fchmod(Fd, 0666); + fchmod(OutFd.Fd, 0666); #endif // TODO: this should be block-enlightened @@ -1408,9 +1622,9 @@ WriteFile(std::filesystem::path Path, const IoBuffer* const* Data, size_t Buffer ThrowSystemException(hRes, fmt::format("File write failed for '{}'", Path).c_str()); } #else - if (write(Fd, DataPtr, ChunkSize) != int64_t(ChunkSize)) + if (write(OutFd.Fd, DataPtr, ChunkSize) != int64_t(ChunkSize)) { - close(Fd); + OutFd = ScopedFd(); std::error_code DummyEc; RemoveFile(Path, DummyEc); ThrowLastError(fmt::format("File write failed for '{}'", Path)); @@ -1424,8 +1638,6 @@ WriteFile(std::filesystem::path Path, const IoBuffer* const* Data, size_t Buffer #if ZEN_PLATFORM_WINDOWS Outfile.Close(); -#else - close(Fd); #endif } @@ -1707,8 +1919,8 @@ ScanFile(std::filesystem::path Path, const uint64_t ChunkSize, std::function<voi ProcessFunc(ReadBuffer.data(), dwBytesRead); } #else - int Fd = open(Path.c_str(), O_RDONLY | O_CLOEXEC); - if (Fd < 0) + ScopedFd InFd(open(Path.c_str(), O_RDONLY | O_CLOEXEC)); + if (!InFd) { return false; } @@ -1718,7 +1930,7 @@ ScanFile(std::filesystem::path Path, const uint64_t ChunkSize, std::function<voi void* Buffer = malloc(ChunkSize); while (true) { - int BytesRead = read(Fd, Buffer, ChunkSize); + int BytesRead = read(InFd.Fd, Buffer, ChunkSize); if (BytesRead < 0) { Success = false; @@ -1734,7 +1946,6 @@ ScanFile(std::filesystem::path Path, const uint64_t ChunkSize, std::function<voi } free(Buffer); - close(Fd); if (!Success) { @@ -3123,28 +3334,26 @@ public: ZEN_UNUSED(SystemGlobal); std::string InstanceMapName = fmt::format("/{}", Name); - int Fd = shm_open(InstanceMapName.c_str(), O_RDWR, 0666); - if (Fd < 0) + ScopedFd FdGuard(shm_open(InstanceMapName.c_str(), O_RDWR, 0666)); + if (!FdGuard) { return {}; } - void* hMap = (void*)intptr_t(Fd); struct stat Stat; - fstat(Fd, &Stat); + fstat(FdGuard.Fd, &Stat); if (size_t(Stat.st_size) < Size) { - close(Fd); return {}; } - void* pBuf = mmap(nullptr, Size, PROT_READ | PROT_WRITE, MAP_SHARED, Fd, 0); + void* pBuf = mmap(nullptr, Size, PROT_READ | PROT_WRITE, MAP_SHARED, FdGuard.Fd, 0); if (pBuf == MAP_FAILED) { - close(Fd); return {}; } + void* hMap = (void*)intptr_t(FdGuard.Release()); return Data{.Handle = hMap, .DataPtr = pBuf, .Size = Size, .Name = std::string(Name)}; #endif // ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC } @@ -3199,23 +3408,22 @@ public: ZEN_UNUSED(SystemGlobal); std::string InstanceMapName = fmt::format("/{}", Name); - int Fd = shm_open(InstanceMapName.c_str(), O_RDWR | O_CREAT | O_CLOEXEC, 0666); - if (Fd < 0) + ScopedFd FdGuard(shm_open(InstanceMapName.c_str(), O_RDWR | O_CREAT | O_CLOEXEC, 0666)); + if (!FdGuard) { return {}; } - fchmod(Fd, 0666); - void* hMap = (void*)intptr_t(Fd); + fchmod(FdGuard.Fd, 0666); - int Result = ftruncate(Fd, Size); + int Result = ftruncate(FdGuard.Fd, Size); ZEN_UNUSED(Result); - void* pBuf = mmap(nullptr, Size, PROT_READ | PROT_WRITE, MAP_SHARED, Fd, 0); + void* pBuf = mmap(nullptr, Size, PROT_READ | PROT_WRITE, MAP_SHARED, FdGuard.Fd, 0); if (pBuf == MAP_FAILED) { - close(Fd); return {}; } + void* hMap = (void*)intptr_t(FdGuard.Release()); return Data{.Handle = hMap, .DataPtr = pBuf, .Size = Size, .Name = std::string(Name)}; #endif // ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC } @@ -3590,6 +3798,241 @@ TEST_CASE("RotateDirectories") } } +TEST_CASE("TryCloneFile") +{ + std::filesystem::path TestBaseDir = GetRunningExecutablePath().parent_path() / ".clone_test"; + CleanDirectory(TestBaseDir, true); + + SUBCASE("clone produces identical content") + { + std::filesystem::path SrcPath = TestBaseDir / "src.bin"; + std::filesystem::path DstPath = TestBaseDir / "dst.bin"; + + // Write source file with known content + const char Content[] = "Hello, clone world! This is test data for TryCloneFile."; + WriteFile(SrcPath, IoBuffer(IoBuffer::Wrap, Content, sizeof(Content))); + CHECK(IsFile(SrcPath)); + + bool Cloned = TryCloneFile(SrcPath, DstPath); + + if (Cloned) + { + CHECK(IsFile(DstPath)); + CHECK_EQ(FileSizeFromPath(DstPath), sizeof(Content)); + + FileContents DstContents = ReadFile(DstPath); + CHECK(DstContents); + CHECK_EQ(DstContents.Data[0].GetSize(), sizeof(Content)); + CHECK_EQ(memcmp(DstContents.Data[0].Data(), Content, sizeof(Content)), 0); + } + else + { + // Clone not supported on this filesystem - that's okay, just verify it didn't leave debris + ZEN_INFO("TryCloneFile not supported on this filesystem, skipping content check"); + } + } + + SUBCASE("clone overwrites existing target") + { + std::filesystem::path SrcPath = TestBaseDir / "src_overwrite.bin"; + std::filesystem::path DstPath = TestBaseDir / "dst_overwrite.bin"; + + const char OldContent[] = "old content"; + const char NewContent[] = "new content that is longer than the old one"; + WriteFile(DstPath, IoBuffer(IoBuffer::Wrap, OldContent, sizeof(OldContent))); + WriteFile(SrcPath, IoBuffer(IoBuffer::Wrap, NewContent, sizeof(NewContent))); + + bool Cloned = TryCloneFile(SrcPath, DstPath); + + if (Cloned) + { + CHECK_EQ(FileSizeFromPath(DstPath), sizeof(NewContent)); + + FileContents DstContents = ReadFile(DstPath); + CHECK(DstContents); + CHECK_EQ(memcmp(DstContents.Data[0].Data(), NewContent, sizeof(NewContent)), 0); + } + } + + SUBCASE("clone of nonexistent source fails") + { + std::filesystem::path SrcPath = TestBaseDir / "no_such_file.bin"; + std::filesystem::path DstPath = TestBaseDir / "dst_nosrc.bin"; + + CHECK_FALSE(TryCloneFile(SrcPath, DstPath)); + CHECK_FALSE(IsFile(DstPath)); + } + + DeleteDirectories(TestBaseDir); +} + +TEST_CASE("CopyFile.Clone") +{ + std::filesystem::path TestBaseDir = GetRunningExecutablePath().parent_path() / ".copyfile_clone_test"; + CleanDirectory(TestBaseDir, true); + + const char Content[] = "CopyFile clone test content with some bytes to verify integrity."; + std::filesystem::path SrcPath = TestBaseDir / "src.bin"; + WriteFile(SrcPath, IoBuffer(IoBuffer::Wrap, Content, sizeof(Content))); + + SUBCASE("EnableClone copies file regardless of clone support") + { + std::filesystem::path DstPath = TestBaseDir / "dst_enable.bin"; + + CopyFileOptions Options; + Options.EnableClone = true; + bool Success = CopyFile(SrcPath, DstPath, Options); + CHECK(Success); + CHECK(IsFile(DstPath)); + CHECK_EQ(FileSizeFromPath(DstPath), sizeof(Content)); + + FileContents DstContents = ReadFile(DstPath); + CHECK(DstContents); + CHECK_EQ(memcmp(DstContents.Data[0].Data(), Content, sizeof(Content)), 0); + } + + SUBCASE("DisableClone still copies file") + { + std::filesystem::path DstPath = TestBaseDir / "dst_disable.bin"; + + CopyFileOptions Options; + Options.EnableClone = false; + bool Success = CopyFile(SrcPath, DstPath, Options); + CHECK(Success); + CHECK(IsFile(DstPath)); + CHECK_EQ(FileSizeFromPath(DstPath), sizeof(Content)); + + FileContents DstContents = ReadFile(DstPath); + CHECK(DstContents); + CHECK_EQ(memcmp(DstContents.Data[0].Data(), Content, sizeof(Content)), 0); + } + + DeleteDirectories(TestBaseDir); +} + +TEST_CASE("SupportsBlockRefCounting") +{ + std::filesystem::path BinDir = GetRunningExecutablePath().parent_path(); + + // Should not crash or throw on a valid path + bool Supported = SupportsBlockRefCounting(BinDir); + ZEN_INFO("SupportsBlockRefCounting({}) = {}", BinDir, Supported); + + // Should return false for nonexistent path + CHECK_FALSE(SupportsBlockRefCounting("/no/such/path/anywhere")); +} + +TEST_CASE("CloneQueryInterface") +{ + std::filesystem::path TestBaseDir = GetRunningExecutablePath().parent_path() / ".clonequery_test"; + CleanDirectory(TestBaseDir, true); + + auto CloneQuery = GetCloneQueryInterface(TestBaseDir); + + if (CloneQuery) + { + ZEN_INFO("CloneQueryInterface available for {}", TestBaseDir); + + // Write a source file large enough to exercise alignment + const uint64_t FileSize = 256 * 1024; + IoBuffer SrcBuf(FileSize); + { + uint8_t* Ptr = SrcBuf.MutableData<uint8_t>(); + for (uint64_t i = 0; i < FileSize; i++) + { + Ptr[i] = uint8_t(i * 37 + 7); + } + } + + std::filesystem::path SrcPath = TestBaseDir / "clone_src.bin"; + std::filesystem::path DstPath = TestBaseDir / "clone_dst.bin"; + WriteFile(SrcPath, SrcBuf); + + // Open source and target as native handles +# if ZEN_PLATFORM_WINDOWS + windows::Handle SrcHandle(CreateFileW(SrcPath.c_str(), + GENERIC_READ, + FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE, + nullptr, + OPEN_EXISTING, + 0, + nullptr)); + CHECK(SrcHandle != INVALID_HANDLE_VALUE); + void* SrcNativeHandle = (void*)SrcHandle.m_Handle; + + windows::Handle DstHandle( + CreateFileW(DstPath.c_str(), GENERIC_READ | GENERIC_WRITE, FILE_SHARE_READ, nullptr, OPEN_ALWAYS, 0, nullptr)); + CHECK(DstHandle != INVALID_HANDLE_VALUE); + void* DstNativeHandle = (void*)DstHandle.m_Handle; +# else + ScopedFd SrcFd(open(SrcPath.c_str(), O_RDONLY | O_CLOEXEC)); + CHECK(bool(SrcFd)); + void* SrcNativeHandle = (void*)uintptr_t(SrcFd.Fd); + + ScopedFd DstFd(open(DstPath.c_str(), O_RDWR | O_CREAT | O_CLOEXEC, 0666)); + CHECK(bool(DstFd)); + void* DstNativeHandle = (void*)uintptr_t(DstFd.Fd); +# endif + + SUBCASE("CanClone returns true for same volume") { CHECK(CloneQuery->CanClone(SrcNativeHandle)); } + + SUBCASE("GetClonableRange and TryClone") + { + uint64_t PreBytes = 0; + uint64_t PostBytes = 0; + uint64_t Clonable = CloneQuery->GetClonableRange(0, 0, FileSize, PreBytes, PostBytes); + + if (Clonable > 0) + { + CHECK_EQ(PreBytes, 0); // Offset 0 is always aligned + CHECK(Clonable + PostBytes == FileSize); + + bool Cloned = CloneQuery->TryClone(SrcNativeHandle, DstNativeHandle, 0, 0, Clonable, FileSize); + CHECK(Cloned); + + if (Cloned) + { + // Write the post-alignment tail if any + if (PostBytes > 0) + { + const uint8_t* SrcData = SrcBuf.Data<uint8_t>() + Clonable; +# if ZEN_PLATFORM_WINDOWS + DWORD Written = 0; + OVERLAPPED Ov = {}; + Ov.Offset = (DWORD)(Clonable & 0xFFFFFFFF); + Ov.OffsetHigh = (DWORD)(Clonable >> 32); + ::WriteFile(DstHandle, SrcData, (DWORD)PostBytes, &Written, &Ov); +# else + pwrite(DstFd.Fd, SrcData, PostBytes, Clonable); +# endif + } + + // Close handles before reading back the file for verification +# if ZEN_PLATFORM_WINDOWS + SrcHandle.Close(); + DstHandle.Close(); +# else + SrcFd = ScopedFd(); + DstFd = ScopedFd(); +# endif + + FileContents DstContents = ReadFile(DstPath); + CHECK(DstContents); + IoBuffer DstFlat = DstContents.Flatten(); + CHECK_EQ(DstFlat.GetSize(), FileSize); + CHECK_EQ(memcmp(DstFlat.Data(), SrcBuf.Data(), FileSize), 0); + } + } + } + } + else + { + ZEN_INFO("CloneQueryInterface not available for {} (filesystem does not support block cloning)", TestBaseDir); + } + + DeleteDirectories(TestBaseDir); +} + TEST_CASE("SharedMemory") { CHECK(!OpenSharedMemory("SharedMemoryTest0", 482, false)); diff --git a/src/zencore/include/zencore/filesystem.h b/src/zencore/include/zencore/filesystem.h index 16e2b59f8..6dc159a83 100644 --- a/src/zencore/include/zencore/filesystem.h +++ b/src/zencore/include/zencore/filesystem.h @@ -187,6 +187,14 @@ void ScanFile(void* NativeHandle, void WriteFile(void* NativeHandle, const void* Data, uint64_t Size, uint64_t FileOffset, uint64_t ChunkSize, std::error_code& Ec); void ReadFile(void* NativeHandle, void* Data, uint64_t Size, uint64_t FileOffset, uint64_t ChunkSize, std::error_code& Ec); +// Interface for sub-file range cloning on filesystems that support copy-on-write. +// GetCloneQueryInterface() returns nullptr on platforms without range clone support. +// +// Platform capabilities: +// Windows (ReFS) - True CoW range cloning via FSCTL_DUPLICATE_EXTENTS_TO_FILE. +// Linux (Btrfs/XFS) - True CoW range cloning via FICLONERANGE ioctl. +// macOS (APFS) - Not implemented. No sub-file range clone API exists. +// Whole-file CoW cloning is available via TryCloneFile (clonefile syscall). class CloneQueryInterface { public: diff --git a/src/zencore/include/zencore/logbase.h b/src/zencore/include/zencore/logbase.h index ad2ab218d..046e96db3 100644 --- a/src/zencore/include/zencore/logbase.h +++ b/src/zencore/include/zencore/logbase.h @@ -101,7 +101,7 @@ struct LoggerRef inline logging::Logger* operator->() const; inline logging::Logger& operator*() const; - bool ShouldLog(logging::LogLevel Level) const { return m_Logger->ShouldLog(Level); } + bool ShouldLog(logging::LogLevel Level) const { return m_Logger && m_Logger->ShouldLog(Level); } void SetLogLevel(logging::LogLevel NewLogLevel) { m_Logger->SetLevel(NewLogLevel); } logging::LogLevel GetLogLevel() { return m_Logger->GetLevel(); } diff --git a/src/zencore/include/zencore/process.h b/src/zencore/include/zencore/process.h index 809312c7b..3177f64c1 100644 --- a/src/zencore/include/zencore/process.h +++ b/src/zencore/include/zencore/process.h @@ -6,6 +6,9 @@ #include <zencore/zencore.h> #include <filesystem> +#include <string> +#include <utility> +#include <vector> namespace zen { @@ -68,6 +71,12 @@ struct CreateProcOptions const std::filesystem::path* WorkingDirectory = nullptr; uint32_t Flags = 0; std::filesystem::path StdoutFile; + + /// Additional environment variables for the child process. These are merged + /// with the parent's environment — existing variables are inherited, and + /// entries here override or add to them. + std::vector<std::pair<std::string, std::string>> Environment; + #if ZEN_PLATFORM_WINDOWS JobObject* AssignToJob = nullptr; // When set, the process is created suspended, assigned to the job, then resumed #endif diff --git a/src/zencore/include/zencore/string.h b/src/zencore/include/zencore/string.h index 4deca63ed..60293a313 100644 --- a/src/zencore/include/zencore/string.h +++ b/src/zencore/include/zencore/string.h @@ -331,9 +331,10 @@ public: return AppendAscii(Str); } -#if defined(__clang__) && !defined(__APPLE__) +#if defined(__clang__) && !defined(__APPLE__) && !defined(_MSC_VER) /* UE Toolchain Clang has different types for int64_t and long long so an override is - needed here. Without it, Clang can't disambiguate integer overloads */ + needed here. Without it, Clang can't disambiguate integer overloads. + On MSVC ABI (including clang-cl), int64_t is long long so no separate overload is needed. */ inline StringBuilderImpl& operator<<(long long n) { IntNum Str(n); @@ -953,6 +954,24 @@ StrCaseCompare(const char* Lhs, const char* Rhs, int64_t Length = -1) #endif } +inline int32_t +StrCaseCompare(std::string_view Lhs, std::string_view Rhs) +{ + int32_t Result = StrCaseCompare(Lhs.data(), Rhs.data(), std::min(Lhs.size(), Rhs.size())); + if (Result == 0) + { + if (Lhs.size() < Rhs.size()) + { + return -1; + } + else if (Lhs.size() > Rhs.size()) + { + return 1; + } + } + return Result; +} + /** * @brief * Helper function to implement case sensitive spaceship operator for strings. diff --git a/src/zencore/include/zencore/system.h b/src/zencore/include/zencore/system.h index a67999e52..2e39cc660 100644 --- a/src/zencore/include/zencore/system.h +++ b/src/zencore/include/zencore/system.h @@ -17,6 +17,7 @@ std::string_view GetOperatingSystemName(); std::string GetOperatingSystemVersion(); std::string_view GetRuntimePlatformName(); // "windows", "wine", "linux", or "macos" std::string_view GetCpuName(); +std::string_view GetCompilerName(); struct SystemMetrics { diff --git a/src/zencore/memtrack/callstacktrace.cpp b/src/zencore/memtrack/callstacktrace.cpp index 4a7068568..ccbea1282 100644 --- a/src/zencore/memtrack/callstacktrace.cpp +++ b/src/zencore/memtrack/callstacktrace.cpp @@ -193,8 +193,12 @@ private: std::atomic_uint32_t CallstackIdCounter{1}; // 0 is reserved for "unknown callstack" }; +} // namespace zen + # if UE_CALLSTACK_TRACE_USE_UNWIND_TABLES +namespace zen { + /* * Windows' x64 binaries contain a ".pdata" section that describes the location * and size of its functions and details on how to unwind them. The unwind @@ -908,98 +912,110 @@ FBacktracer::GetBacktraceId(void* AddressOfReturnAddress) // queue (i.e. the processing thread has caught up processing). return CallstackTracer.AddCallstack(BacktraceEntry); } -} +} // namespace zen # else // UE_CALLSTACK_TRACE_USE_UNWIND_TABLES namespace zen { - //////////////////////////////////////////////////////////////////////////////// - class FBacktracer - { - public: - FBacktracer(FMalloc* InMalloc); - ~FBacktracer(); - static FBacktracer* Get(); - inline uint32_t GetBacktraceId(void* AddressOfReturnAddress); - uint32_t GetBacktraceId(uint64_t ReturnAddress); - void AddModule(uintptr_t Base, const char16_t* Name) {} - void RemoveModule(uintptr_t Base) {} - - private: - static FBacktracer* Instance; - FMalloc* Malloc; - FCallstackTracer CallstackTracer; - }; +//////////////////////////////////////////////////////////////////////////////// +class FBacktracer +{ +public: + FBacktracer(FMalloc* InMalloc); + ~FBacktracer(); + static FBacktracer* Get(); + inline uint32_t GetBacktraceId(void* AddressOfReturnAddress); + uint32_t GetBacktraceId(uint64_t ReturnAddress); + void AddModule(uintptr_t /*Base*/, const char16_t* /*Name*/) {} + void RemoveModule(uintptr_t /*Base*/) {} - //////////////////////////////////////////////////////////////////////////////// - FBacktracer* FBacktracer::Instance = nullptr; +private: + static FBacktracer* Instance; + FMalloc* Malloc; + FCallstackTracer CallstackTracer; +}; - //////////////////////////////////////////////////////////////////////////////// - FBacktracer::FBacktracer(FMalloc* InMalloc) : Malloc(InMalloc), CallstackTracer(InMalloc) { Instance = this; } +//////////////////////////////////////////////////////////////////////////////// +FBacktracer* FBacktracer::Instance = nullptr; - //////////////////////////////////////////////////////////////////////////////// - FBacktracer::~FBacktracer() {} +//////////////////////////////////////////////////////////////////////////////// +FBacktracer::FBacktracer(FMalloc* InMalloc) : Malloc(InMalloc), CallstackTracer(InMalloc) +{ + Instance = this; +} - //////////////////////////////////////////////////////////////////////////////// - FBacktracer* FBacktracer::Get() { return Instance; } +//////////////////////////////////////////////////////////////////////////////// +FBacktracer::~FBacktracer() +{ +} - //////////////////////////////////////////////////////////////////////////////// - uint32_t FBacktracer::GetBacktraceId(void* AddressOfReturnAddress) - { - const uint64_t ReturnAddress = *(uint64_t*)AddressOfReturnAddress; - return GetBacktraceId(ReturnAddress); - } +//////////////////////////////////////////////////////////////////////////////// +FBacktracer* +FBacktracer::Get() +{ + return Instance; +} - //////////////////////////////////////////////////////////////////////////////// - uint32_t FBacktracer::GetBacktraceId(uint64_t ReturnAddress) - { +//////////////////////////////////////////////////////////////////////////////// +uint32_t +FBacktracer::GetBacktraceId(void* AddressOfReturnAddress) +{ + const uint64_t ReturnAddress = *(uint64_t*)AddressOfReturnAddress; + return GetBacktraceId(ReturnAddress); +} + +//////////////////////////////////////////////////////////////////////////////// +uint32_t +FBacktracer::GetBacktraceId(uint64_t ReturnAddress) +{ + ZEN_UNUSED(ReturnAddress); # if !UE_BUILD_SHIPPING - uint64_t StackFrames[256]; - int32_t NumStackFrames = FPlatformStackWalk::CaptureStackBackTrace(StackFrames, UE_ARRAY_COUNT(StackFrames)); - if (NumStackFrames > 0) + uint64_t StackFrames[256]; + int32_t NumStackFrames = FPlatformStackWalk::CaptureStackBackTrace(StackFrames, UE_ARRAY_COUNT(StackFrames)); + if (NumStackFrames > 0) + { + FCallstackTracer::FBacktraceEntry BacktraceEntry; + uint64_t BacktraceId = 0; + uint32_t FrameIdx = 0; + bool bUseAddress = false; + for (int32_t Index = 0; Index < NumStackFrames; Index++) { - FCallstackTracer::FBacktraceEntry BacktraceEntry; - uint64_t BacktraceId = 0; - uint32_t FrameIdx = 0; - bool bUseAddress = false; - for (int32_t Index = 0; Index < NumStackFrames; Index++) + if (!bUseAddress) { - if (!bUseAddress) - { - // start using backtrace only after ReturnAddress - if (StackFrames[Index] == (uint64_t)ReturnAddress) - { - bUseAddress = true; - } - } - if (bUseAddress || NumStackFrames == 1) + // start using backtrace only after ReturnAddress + if (StackFrames[Index] == (uint64_t)ReturnAddress) { - uint64_t RetAddr = StackFrames[Index]; - StackFrames[FrameIdx++] = RetAddr; - - // This is a simple order-dependent LCG. Should be sufficient enough - BacktraceId += RetAddr; - BacktraceId *= 0x30be8efa499c249dull; + bUseAddress = true; } } + if (bUseAddress || NumStackFrames == 1) + { + uint64_t RetAddr = StackFrames[Index]; + StackFrames[FrameIdx++] = RetAddr; - // Save the collected id - BacktraceEntry.Hash = BacktraceId; - BacktraceEntry.FrameCount = FrameIdx; - BacktraceEntry.Frames = StackFrames; - - // Add to queue to be processed. This might block until there is room in the - // queue (i.e. the processing thread has caught up processing). - return CallstackTracer.AddCallstack(BacktraceEntry); + // This is a simple order-dependent LCG. Should be sufficient enough + BacktraceId += RetAddr; + BacktraceId *= 0x30be8efa499c249dull; + } } -# endif - return 0; + // Save the collected id + BacktraceEntry.Hash = BacktraceId; + BacktraceEntry.FrameCount = FrameIdx; + BacktraceEntry.Frames = StackFrames; + + // Add to queue to be processed. This might block until there is room in the + // queue (i.e. the processing thread has caught up processing). + return CallstackTracer.AddCallstack(BacktraceEntry); } +# endif + return 0; } +} // namespace zen + # endif // UE_CALLSTACK_TRACE_USE_UNWIND_TABLES namespace zen { @@ -1047,7 +1063,7 @@ CallstackTrace_GetCurrentId() # if PLATFORM_USE_CALLSTACK_ADDRESS_POINTER return Instance->GetBacktraceId(StackAddress); # else - return Instance->GetBacktraceId((uint64_t)StackAddress); + return Instance->GetBacktraceId((uint64_t)StackAddress); # endif } diff --git a/src/zencore/process.cpp b/src/zencore/process.cpp index f657869dc..080607f13 100644 --- a/src/zencore/process.cpp +++ b/src/zencore/process.cpp @@ -11,6 +11,7 @@ #include <zencore/timer.h> #include <zencore/trace.h> +#include <map> #include <thread> ZEN_THIRD_PARTY_INCLUDES_START @@ -20,8 +21,8 @@ ZEN_THIRD_PARTY_INCLUDES_START # include <Psapi.h> # include <shellapi.h> -# include <Shlobj.h> -# include <TlHelp32.h> +# include <shlobj.h> +# include <tlhelp32.h> #else # include <fcntl.h> # include <pthread.h> @@ -487,13 +488,57 @@ CreateProcNormal(const std::filesystem::path& Executable, std::string_view Comma STARTUPINFO StartupInfo{.cb = sizeof(STARTUPINFO)}; bool InheritHandles = false; - void* Environment = nullptr; LPSECURITY_ATTRIBUTES ProcessAttributes = nullptr; LPSECURITY_ATTRIBUTES ThreadAttributes = nullptr; + // Build environment block when custom environment variables are specified + ExtendableWideStringBuilder<512> EnvironmentBlock; + void* Environment = nullptr; + if (!Options.Environment.empty()) + { + // Capture current environment into a map + std::map<std::wstring, std::wstring> EnvMap; + wchar_t* EnvStrings = GetEnvironmentStringsW(); + if (EnvStrings) + { + for (const wchar_t* Ptr = EnvStrings; *Ptr; Ptr += wcslen(Ptr) + 1) + { + std::wstring_view Entry(Ptr); + size_t EqPos = Entry.find(L'='); + if (EqPos != std::wstring_view::npos && EqPos > 0) + { + EnvMap[std::wstring(Entry.substr(0, EqPos))] = std::wstring(Entry.substr(EqPos + 1)); + } + } + FreeEnvironmentStringsW(EnvStrings); + } + + // Apply overrides + for (const auto& [Key, Value] : Options.Environment) + { + EnvMap[Utf8ToWide(Key)] = Utf8ToWide(Value); + } + + // Build double-null-terminated environment block + for (const auto& [Key, Value] : EnvMap) + { + EnvironmentBlock << Key; + EnvironmentBlock.Append(L'='); + EnvironmentBlock << Value; + EnvironmentBlock.Append(L'\0'); + } + EnvironmentBlock.Append(L'\0'); + + Environment = EnvironmentBlock.Data(); + } + const bool AssignToJob = Options.AssignToJob && Options.AssignToJob->IsValid(); DWORD CreationFlags = 0; + if (Environment) + { + CreationFlags |= CREATE_UNICODE_ENVIRONMENT; + } if (Options.Flags & CreateProcOptions::Flag_NewConsole) { CreationFlags |= CREATE_NEW_CONSOLE; @@ -790,6 +835,11 @@ CreateProc(const std::filesystem::path& Executable, std::string_view CommandLine } } + for (const auto& [Key, Value] : Options.Environment) + { + setenv(Key.c_str(), Value.c_str(), 1); + } + if (execv(Executable.c_str(), ArgV.data()) < 0) { ThrowLastError("Failed to exec() a new process image"); diff --git a/src/zencore/sentryintegration.cpp b/src/zencore/sentryintegration.cpp index 58b76783a..b7d01003b 100644 --- a/src/zencore/sentryintegration.cpp +++ b/src/zencore/sentryintegration.cpp @@ -31,6 +31,8 @@ namespace { struct SentryAssertImpl : zen::AssertImpl { + ZEN_DEBUG_SECTION ~SentryAssertImpl() override = default; + virtual void ZEN_FORCENOINLINE ZEN_DEBUG_SECTION OnAssert(const char* Filename, int LineNumber, const char* FunctionName, diff --git a/src/zencore/string.cpp b/src/zencore/string.cpp index ed0ba6f46..358722b0b 100644 --- a/src/zencore/string.cpp +++ b/src/zencore/string.cpp @@ -1181,6 +1181,17 @@ TEST_CASE("string") CHECK(StrCaseCompare("BBr", "Bar", 2) > 0); } + SUBCASE("StrCaseCompare") + { + CHECK(StrCaseCompare("foo"sv, "FoO"sv) == 0); + CHECK(StrCaseCompare("foo"sv, "FoOz"sv) < 0); + CHECK(StrCaseCompare("fooo"sv, "FoO"sv) > 0); + CHECK(StrCaseCompare("Bar"sv, "bAs"sv) < 0); + CHECK(StrCaseCompare("bAr"sv, "Bas"sv) < 0); + CHECK(StrCaseCompare("BBr"sv, "Bar"sv) > 0); + CHECK(StrCaseCompare("Bbr"sv, "BAr"sv) > 0); + } + SUBCASE("ForEachStrTok") { const auto Tokens = "here,is,my,different,tokens"sv; diff --git a/src/zencore/system.cpp b/src/zencore/system.cpp index 141450b84..8985a8a76 100644 --- a/src/zencore/system.cpp +++ b/src/zencore/system.cpp @@ -660,6 +660,24 @@ GetCpuName() #endif } +std::string_view +GetCompilerName() +{ +#define ZEN_STRINGIFY_IMPL(x) #x +#define ZEN_STRINGIFY(x) ZEN_STRINGIFY_IMPL(x) +#if ZEN_COMPILER_CLANG + return "clang " ZEN_STRINGIFY(__clang_major__) "." ZEN_STRINGIFY(__clang_minor__) "." ZEN_STRINGIFY(__clang_patchlevel__); +#elif ZEN_COMPILER_MSC + return "MSVC " ZEN_STRINGIFY(_MSC_VER); +#elif ZEN_COMPILER_GCC + return "GCC " ZEN_STRINGIFY(__GNUC__) "." ZEN_STRINGIFY(__GNUC_MINOR__) "." ZEN_STRINGIFY(__GNUC_PATCHLEVEL__); +#else + return "unknown"; +#endif +#undef ZEN_STRINGIFY +#undef ZEN_STRINGIFY_IMPL +} + void Describe(const SystemMetrics& Metrics, CbWriter& Writer) { diff --git a/src/zencore/testing.cpp b/src/zencore/testing.cpp index d7eb3b17d..f5bc723b1 100644 --- a/src/zencore/testing.cpp +++ b/src/zencore/testing.cpp @@ -24,6 +24,8 @@ # if ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC # include <execinfo.h> # include <unistd.h> +# elif ZEN_PLATFORM_WINDOWS +# include <crtdbg.h> # endif namespace zen::testing { @@ -296,6 +298,17 @@ RunTestMain(int Argc, char* Argv[], const char* ExecutableName, void (*ForceLink } # endif +# if ZEN_PLATFORM_WINDOWS + // Suppress Windows error dialogs (crash/abort/assert) so tests terminate + // immediately instead of blocking on a modal dialog in CI or headless runs. + SetErrorMode(SEM_FAILCRITICALERRORS | SEM_NOGPFAULTERRORBOX); + _set_abort_behavior(0, _WRITE_ABORT_MSG); + _CrtSetReportMode(_CRT_ASSERT, _CRTDBG_MODE_FILE); + _CrtSetReportFile(_CRT_ASSERT, _CRTDBG_FILE_STDERR); + _CrtSetReportMode(_CRT_ERROR, _CRTDBG_MODE_FILE); + _CrtSetReportFile(_CRT_ERROR, _CRTDBG_FILE_STDERR); +# endif + zen::logging::InitializeLogging(); zen::MaximizeOpenFileCount(); InstallCrashSignalHandlers(); diff --git a/src/zenhorde/hordetransportaes.cpp b/src/zenhorde/hordetransportaes.cpp index 986dd3705..505b6bde7 100644 --- a/src/zenhorde/hordetransportaes.cpp +++ b/src/zenhorde/hordetransportaes.cpp @@ -12,7 +12,7 @@ #if ZEN_PLATFORM_WINDOWS # include <zencore/windows.h> # include <bcrypt.h> -# pragma comment(lib, "Bcrypt.lib") +# pragma comment(lib, "bcrypt.lib") #else ZEN_THIRD_PARTY_INCLUDES_START # include <openssl/evp.h> diff --git a/src/zenhttp/clients/httpclientcpr.cpp b/src/zenhttp/clients/httpclientcpr.cpp index a0f5cc38f..a52b8f74b 100644 --- a/src/zenhttp/clients/httpclientcpr.cpp +++ b/src/zenhttp/clients/httpclientcpr.cpp @@ -7,6 +7,7 @@ #include <zencore/compactbinarypackage.h> #include <zencore/compactbinaryutil.h> #include <zencore/compress.h> +#include <zencore/filesystem.h> #include <zencore/iobuffer.h> #include <zencore/iohash.h> #include <zencore/session.h> @@ -513,7 +514,7 @@ CprHttpClient::AllocSession(const std::string_view BaseUrl, if (!ConnectionSettings.UnixSocketPath.empty()) { - CprSession->SetUnixSocket(cpr::UnixSocket(ConnectionSettings.UnixSocketPath)); + CprSession->SetUnixSocket(cpr::UnixSocket(PathToUtf8(ConnectionSettings.UnixSocketPath))); } if (ConnectionSettings.InsecureSsl || !ConnectionSettings.CaBundlePath.empty()) diff --git a/src/zenhttp/clients/httpclientcurl.cpp b/src/zenhttp/clients/httpclientcurl.cpp index 341adc5f7..ec9b7bac6 100644 --- a/src/zenhttp/clients/httpclientcurl.cpp +++ b/src/zenhttp/clients/httpclientcurl.cpp @@ -7,6 +7,8 @@ #include <zencore/compactbinarypackage.h> #include <zencore/compactbinaryutil.h> #include <zencore/compress.h> +#include <zencore/except.h> +#include <zencore/filesystem.h> #include <zencore/iobuffer.h> #include <zencore/iohash.h> #include <zencore/session.h> @@ -93,15 +95,11 @@ struct HeaderCallbackData std::vector<std::pair<std::string, std::string>>* Headers = nullptr; }; -static size_t -CurlHeaderCallback(char* Buffer, size_t Size, size_t Nmemb, void* UserData) +// Trims trailing CRLF, splits on the first colon, and trims whitespace from key and value. +// Returns nullopt for blank lines or lines without a colon (e.g. HTTP status lines). +static std::optional<std::pair<std::string_view, std::string_view>> +ParseHeaderLine(std::string_view Line) { - auto* Data = static_cast<HeaderCallbackData*>(UserData); - size_t TotalBytes = Size * Nmemb; - - std::string_view Line(Buffer, TotalBytes); - - // Trim trailing \r\n while (!Line.empty() && (Line.back() == '\r' || Line.back() == '\n')) { Line.remove_suffix(1); @@ -109,25 +107,39 @@ CurlHeaderCallback(char* Buffer, size_t Size, size_t Nmemb, void* UserData) if (Line.empty()) { - return TotalBytes; + return std::nullopt; } size_t ColonPos = Line.find(':'); - if (ColonPos != std::string_view::npos) + if (ColonPos == std::string_view::npos) { - std::string_view Key = Line.substr(0, ColonPos); - std::string_view Value = Line.substr(ColonPos + 1); + return std::nullopt; + } - // Trim whitespace - while (!Key.empty() && Key.back() == ' ') - { - Key.remove_suffix(1); - } - while (!Value.empty() && Value.front() == ' ') - { - Value.remove_prefix(1); - } + std::string_view Key = Line.substr(0, ColonPos); + std::string_view Value = Line.substr(ColonPos + 1); + while (!Key.empty() && Key.back() == ' ') + { + Key.remove_suffix(1); + } + while (!Value.empty() && Value.front() == ' ') + { + Value.remove_prefix(1); + } + + return std::pair{Key, Value}; +} + +static size_t +CurlHeaderCallback(char* Buffer, size_t Size, size_t Nmemb, void* UserData) +{ + auto* Data = static_cast<HeaderCallbackData*>(UserData); + size_t TotalBytes = Size * Nmemb; + + if (auto Header = ParseHeaderLine(std::string_view(Buffer, TotalBytes))) + { + auto& [Key, Value] = *Header; Data->Headers->emplace_back(std::string(Key), std::string(Value)); } @@ -285,57 +297,102 @@ BuildHeaderList(const HttpClient::KeyValueMap& AdditionalHeader, for (const auto& [Key, Value] : *AdditionalHeader) { - std::string HeaderLine = fmt::format("{}: {}", Key, Value); - Headers = curl_slist_append(Headers, HeaderLine.c_str()); + ExtendableStringBuilder<64> HeaderLine; + HeaderLine << Key << ": " << Value; + Headers = curl_slist_append(Headers, HeaderLine.c_str()); } if (!SessionId.empty()) { - std::string SessionHeader = fmt::format("UE-Session: {}", SessionId); - Headers = curl_slist_append(Headers, SessionHeader.c_str()); + ExtendableStringBuilder<64> SessionHeader; + SessionHeader << "UE-Session: " << SessionId; + Headers = curl_slist_append(Headers, SessionHeader.c_str()); } if (AccessToken) { - std::string AuthHeader = fmt::format("Authorization: {}", AccessToken->Value); - Headers = curl_slist_append(Headers, AuthHeader.c_str()); + ExtendableStringBuilder<128> AuthHeader; + AuthHeader << "Authorization: " << AccessToken->Value; + Headers = curl_slist_append(Headers, AuthHeader.c_str()); } for (const auto& [Key, Value] : ExtraHeaders) { - std::string HeaderLine = fmt::format("{}: {}", Key, Value); - Headers = curl_slist_append(Headers, HeaderLine.c_str()); + ExtendableStringBuilder<128> HeaderLine; + HeaderLine << Key << ": " << Value; + Headers = curl_slist_append(Headers, HeaderLine.c_str()); } return Headers; } -static std::string -BuildUrlWithParameters(std::string_view BaseUrl, std::string_view ResourcePath, const HttpClient::KeyValueMap& Parameters) +static HttpClient::KeyValueMap +BuildHeaderMap(const std::vector<std::pair<std::string, std::string>>& Headers) +{ + HttpClient::KeyValueMap HeaderMap; + for (const auto& [Key, Value] : Headers) + { + HeaderMap->insert_or_assign(Key, Value); + } + return HeaderMap; +} + +// Scans response headers for Content-Type and applies it to the buffer. +static void +ApplyContentTypeFromHeaders(IoBuffer& Buffer, const std::vector<std::pair<std::string, std::string>>& Headers) +{ + for (const auto& [Key, Value] : Headers) + { + if (StrCaseCompare(Key, "Content-Type") == 0) + { + Buffer.SetContentType(ParseContentType(Value)); + break; + } + } +} + +static void +AppendUrlEncoded(StringBuilderBase& Out, std::string_view Input) +{ + static constexpr char HexDigits[] = "0123456789ABCDEF"; + static constexpr AsciiSet Unreserved("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_.~"); + + for (char C : Input) + { + if (Unreserved.Contains(C)) + { + Out.Append(C); + } + else + { + uint8_t Byte = static_cast<uint8_t>(C); + char Encoded[3] = {'%', HexDigits[Byte >> 4], HexDigits[Byte & 0x0F]}; + Out.Append(std::string_view(Encoded, 3)); + } + } +} + +static void +BuildUrlWithParameters(StringBuilderBase& Url, + std::string_view BaseUrl, + std::string_view ResourcePath, + const HttpClient::KeyValueMap& Parameters) { - std::string Url; - Url.reserve(BaseUrl.size() + ResourcePath.size() + 64); - Url.append(BaseUrl); - Url.append(ResourcePath); + Url.Append(BaseUrl); + Url.Append(ResourcePath); if (!Parameters->empty()) { char Separator = '?'; for (const auto& [Key, Value] : *Parameters) { - char* EncodedKey = curl_easy_escape(nullptr, Key.c_str(), static_cast<int>(Key.size())); - char* EncodedValue = curl_easy_escape(nullptr, Value.c_str(), static_cast<int>(Value.size())); - Url += Separator; - Url += EncodedKey; - Url += '='; - Url += EncodedValue; - curl_free(EncodedKey); - curl_free(EncodedValue); + Url.Append(Separator); + AppendUrlEncoded(Url, Key); + Url.Append('='); + AppendUrlEncoded(Url, Value); Separator = '&'; } } - - return Url; } ////////////////////////////////////////////////////////////////////////// @@ -359,6 +416,48 @@ CurlHttpClient::~CurlHttpClient() }); } +CurlHttpClient::Session::~Session() +{ + if (HeaderList) + { + curl_slist_free_all(HeaderList); + } + Outer->ReleaseSession(Handle); +} + +void +CurlHttpClient::Session::SetHeaders(curl_slist* Headers) +{ + if (HeaderList) + { + curl_slist_free_all(HeaderList); + } + HeaderList = Headers; + curl_easy_setopt(Handle, CURLOPT_HTTPHEADER, HeaderList); +} + +CurlHttpClient::CurlResult +CurlHttpClient::Session::PerformWithResponseCallbacks() +{ + std::string Body; + WriteCallbackData WriteData{.Body = &Body, + .CheckIfAbortFunction = Outer->m_CheckIfAbortFunction ? &Outer->m_CheckIfAbortFunction : nullptr}; + HeaderCallbackData HdrData{}; + std::vector<std::pair<std::string, std::string>> ResponseHeaders; + HdrData.Headers = &ResponseHeaders; + + curl_easy_setopt(Handle, CURLOPT_WRITEFUNCTION, CurlWriteCallback); + curl_easy_setopt(Handle, CURLOPT_WRITEDATA, &WriteData); + curl_easy_setopt(Handle, CURLOPT_HEADERFUNCTION, CurlHeaderCallback); + curl_easy_setopt(Handle, CURLOPT_HEADERDATA, &HdrData); + + CurlResult Result = Perform(); + Result.Body = std::move(Body); + Result.Headers = std::move(ResponseHeaders); + + return Result; +} + CurlHttpClient::CurlResult CurlHttpClient::Session::Perform() { @@ -411,15 +510,7 @@ CurlHttpClient::ResponseWithPayload(std::string_view SessionId, { IoBuffer ResponseBuffer = Payload ? std::move(Payload) : IoBuffer(IoBuffer::Clone, Result.Body.data(), Result.Body.size()); - for (const auto& [Key, Value] : Result.Headers) - { - if (StrCaseCompare(Key.c_str(), "Content-Type") == 0) - { - const HttpContentType ContentType = ParseContentType(Value); - ResponseBuffer.SetContentType(ContentType); - break; - } - } + ApplyContentTypeFromHeaders(ResponseBuffer, Result.Headers); if (!IsHttpSuccessCode(WorkResponseCode) && WorkResponseCode != HttpResponseCode::NotFound) { @@ -438,15 +529,9 @@ CurlHttpClient::ResponseWithPayload(std::string_view SessionId, return Lhs.RangeOffset < Rhs.RangeOffset; }); - HttpClient::KeyValueMap HeaderMap; - for (const auto& [Key, Value] : Result.Headers) - { - HeaderMap->insert_or_assign(Key, Value); - } - return HttpClient::Response{.StatusCode = WorkResponseCode, .ResponsePayload = std::move(ResponseBuffer), - .Header = std::move(HeaderMap), + .Header = BuildHeaderMap(Result.Headers), .UploadedBytes = Result.UploadedBytes, .DownloadedBytes = Result.DownloadedBytes, .ElapsedSeconds = Result.ElapsedSeconds, @@ -475,16 +560,10 @@ CurlHttpClient::CommonResponse(std::string_view SessionId, } } - HttpClient::KeyValueMap HeaderMap; - for (const auto& [Key, Value] : Result.Headers) - { - HeaderMap->insert_or_assign(Key, Value); - } - return HttpClient::Response{ .StatusCode = WorkResponseCode, .ResponsePayload = IoBufferBuilder::MakeCloneFromMemory(Result.Body.data(), Result.Body.size()), - .Header = std::move(HeaderMap), + .Header = BuildHeaderMap(Result.Headers), .UploadedBytes = Result.UploadedBytes, .DownloadedBytes = Result.DownloadedBytes, .ElapsedSeconds = Result.ElapsedSeconds, @@ -493,14 +572,8 @@ CurlHttpClient::CommonResponse(std::string_view SessionId, if (WorkResponseCode == HttpResponseCode::NoContent || (Result.Body.empty() && !Payload)) { - HttpClient::KeyValueMap HeaderMap; - for (const auto& [Key, Value] : Result.Headers) - { - HeaderMap->insert_or_assign(Key, Value); - } - return HttpClient::Response{.StatusCode = WorkResponseCode, - .Header = std::move(HeaderMap), + .Header = BuildHeaderMap(Result.Headers), .UploadedBytes = Result.UploadedBytes, .DownloadedBytes = Result.DownloadedBytes, .ElapsedSeconds = Result.ElapsedSeconds}; @@ -519,25 +592,43 @@ CurlHttpClient::ValidatePayload(CurlResult& Result, std::unique_ptr<detail::Temp IoBuffer ResponseBuffer = (Result.Body.empty() && PayloadFile) ? PayloadFile->BorrowIoBuffer() : IoBuffer(IoBuffer::Wrap, Result.Body.data(), Result.Body.size()); - // Find Content-Length in headers + // Collect relevant headers in a single pass + std::string_view ContentLengthValue; + std::string_view IoHashValue; + std::string_view ContentTypeValue; + for (const auto& [Key, Value] : Result.Headers) { - if (StrCaseCompare(Key.c_str(), "Content-Length") == 0) + if (ContentLengthValue.empty() && StrCaseCompare(Key, "Content-Length") == 0) { - std::optional<uint64_t> ExpectedContentSize = ParseInt<uint64_t>(Value); - if (!ExpectedContentSize.has_value()) - { - Result.ErrorCode = CURLE_RECV_ERROR; - Result.ErrorMessage = fmt::format("Can not parse Content-Length header. Value: '{}'", Value); - return false; - } - if (ExpectedContentSize.value() != ResponseBuffer.GetSize()) - { - Result.ErrorCode = CURLE_RECV_ERROR; - Result.ErrorMessage = fmt::format("Payload size {} does not match Content-Length {}", ResponseBuffer.GetSize(), Value); - return false; - } - break; + ContentLengthValue = Value; + } + else if (IoHashValue.empty() && StrCaseCompare(Key, "X-Jupiter-IoHash") == 0) + { + IoHashValue = Value; + } + else if (ContentTypeValue.empty() && StrCaseCompare(Key, "Content-Type") == 0) + { + ContentTypeValue = Value; + } + } + + // Validate Content-Length + if (!ContentLengthValue.empty()) + { + std::optional<uint64_t> ExpectedContentSize = ParseInt<uint64_t>(ContentLengthValue); + if (!ExpectedContentSize.has_value()) + { + Result.ErrorCode = CURLE_RECV_ERROR; + Result.ErrorMessage = fmt::format("Can not parse Content-Length header. Value: '{}'", ContentLengthValue); + return false; + } + if (ExpectedContentSize.value() != ResponseBuffer.GetSize()) + { + Result.ErrorCode = CURLE_RECV_ERROR; + Result.ErrorMessage = + fmt::format("Payload size {} does not match Content-Length {}", ResponseBuffer.GetSize(), ContentLengthValue); + return false; } } @@ -546,66 +637,55 @@ CurlHttpClient::ValidatePayload(CurlResult& Result, std::unique_ptr<detail::Temp return true; } - // Check X-Jupiter-IoHash - for (const auto& [Key, Value] : Result.Headers) + // Validate X-Jupiter-IoHash + if (!IoHashValue.empty()) { - if (StrCaseCompare(Key.c_str(), "X-Jupiter-IoHash") == 0) + IoHash ExpectedPayloadHash; + if (IoHash::TryParse(IoHashValue, ExpectedPayloadHash)) { - IoHash ExpectedPayloadHash; - if (IoHash::TryParse(Value, ExpectedPayloadHash)) + IoHash PayloadHash = IoHash::HashBuffer(ResponseBuffer); + if (PayloadHash != ExpectedPayloadHash) { - IoHash PayloadHash = IoHash::HashBuffer(ResponseBuffer); - if (PayloadHash != ExpectedPayloadHash) - { - Result.ErrorCode = CURLE_RECV_ERROR; - Result.ErrorMessage = fmt::format("Payload hash {} does not match X-Jupiter-IoHash {}", - PayloadHash.ToHexString(), - ExpectedPayloadHash.ToHexString()); - return false; - } + Result.ErrorCode = CURLE_RECV_ERROR; + Result.ErrorMessage = fmt::format("Payload hash {} does not match X-Jupiter-IoHash {}", + PayloadHash.ToHexString(), + ExpectedPayloadHash.ToHexString()); + return false; } - break; } } // Validate content-type specific payload - for (const auto& [Key, Value] : Result.Headers) + if (ContentTypeValue == "application/x-ue-comp") { - if (StrCaseCompare(Key.c_str(), "Content-Type") == 0) + IoHash RawHash; + uint64_t RawSize; + if (CompressedBuffer::ValidateCompressedHeader(ResponseBuffer, + RawHash, + RawSize, + /*OutOptionalTotalCompressedSize*/ nullptr)) { - if (Value == "application/x-ue-comp") - { - IoHash RawHash; - uint64_t RawSize; - if (CompressedBuffer::ValidateCompressedHeader(ResponseBuffer, - RawHash, - RawSize, - /*OutOptionalTotalCompressedSize*/ nullptr)) - { - return true; - } - else - { - Result.ErrorCode = CURLE_RECV_ERROR; - Result.ErrorMessage = "Compressed binary failed validation"; - return false; - } - } - if (Value == "application/x-ue-cb") - { - if (CbValidateError Error = ValidateCompactBinary(ResponseBuffer.GetView(), CbValidateMode::Default); - Error == CbValidateError::None) - { - return true; - } - else - { - Result.ErrorCode = CURLE_RECV_ERROR; - Result.ErrorMessage = fmt::format("Compact binary failed validation: {}", ToString(Error)); - return false; - } - } - break; + return true; + } + else + { + Result.ErrorCode = CURLE_RECV_ERROR; + Result.ErrorMessage = "Compressed binary failed validation"; + return false; + } + } + if (ContentTypeValue == "application/x-ue-cb") + { + if (CbValidateError Error = ValidateCompactBinary(ResponseBuffer.GetView(), CbValidateMode::Default); + Error == CbValidateError::None) + { + return true; + } + else + { + Result.ErrorCode = CURLE_RECV_ERROR; + Result.ErrorMessage = fmt::format("Compact binary failed validation: {}", ToString(Error)); + return false; } } @@ -666,10 +746,24 @@ CurlHttpClient::DoWithRetry(std::string_view SessionId, std::function<CurlResult Attempt++; if (ShouldLogErrorCode(HttpResponseCode(Result.StatusCode))) { - ZEN_INFO("{} Attempt {}/{}", - CommonResponse(SessionId, std::move(Result), {}).ErrorMessage("Retry"), - Attempt, - m_ConnectionSettings.RetryCount + 1); + if (Result.ErrorCode != CURLE_OK) + { + ZEN_INFO("Retry (session: {}): HTTP error ({}) '{}' Attempt {}/{}", + SessionId, + static_cast<int>(MapCurlError(Result.ErrorCode)), + Result.ErrorMessage, + Attempt, + m_ConnectionSettings.RetryCount + 1); + } + else + { + ZEN_INFO("Retry (session: {}): HTTP status ({}) '{}' Attempt {}/{}", + SessionId, + Result.StatusCode, + zen::ToString(HttpResponseCode(Result.StatusCode)), + Attempt, + m_ConnectionSettings.RetryCount + 1); + } } Result = Func(); } @@ -681,51 +775,14 @@ CurlHttpClient::DoWithRetry(std::string_view SessionId, std::function<CurlResult()>&& Func, std::unique_ptr<detail::TempPayloadFile>& PayloadFile) { - uint8_t Attempt = 0; - CurlResult Result = Func(); - while (Attempt < m_ConnectionSettings.RetryCount) - { - if (m_CheckIfAbortFunction && m_CheckIfAbortFunction()) - { - return Result; - } - if (!ShouldRetry(Result)) - { - if (Result.ErrorCode != CURLE_OK || !IsHttpSuccessCode(Result.StatusCode)) - { - break; - } - if (ValidatePayload(Result, PayloadFile)) - { - break; - } - } - Sleep(100 * (Attempt + 1)); - Attempt++; - if (ShouldLogErrorCode(HttpResponseCode(Result.StatusCode))) - { - ZEN_INFO("{} Attempt {}/{}", - CommonResponse(SessionId, std::move(Result), {}).ErrorMessage("Retry"), - Attempt, - m_ConnectionSettings.RetryCount + 1); - } - Result = Func(); - } - return Result; + return DoWithRetry(SessionId, std::move(Func), [&](CurlResult& Result) { return ValidatePayload(Result, PayloadFile); }); } ////////////////////////////////////////////////////////////////////////// CurlHttpClient::Session -CurlHttpClient::AllocSession(std::string_view BaseUrl, - std::string_view ResourcePath, - const HttpClientSettings& ConnectionSettings, - const KeyValueMap& AdditionalHeader, - const KeyValueMap& Parameters, - std::string_view SessionId, - std::optional<HttpClientAccessToken> AccessToken) +CurlHttpClient::AllocSession(std::string_view ResourcePath, const KeyValueMap& Parameters) { - ZEN_UNUSED(AccessToken, SessionId, AdditionalHeader); ZEN_TRACE_CPU("CurlHttpClient::AllocSession"); CURL* Handle = nullptr; m_SessionLock.WithExclusiveLock([&] { @@ -739,6 +796,10 @@ CurlHttpClient::AllocSession(std::string_view BaseUrl, if (Handle == nullptr) { Handle = curl_easy_init(); + if (Handle == nullptr) + { + ThrowOutOfMemory("curl_easy_init"); + } } else { @@ -746,33 +807,35 @@ CurlHttpClient::AllocSession(std::string_view BaseUrl, } // Unix domain socket - if (!ConnectionSettings.UnixSocketPath.empty()) + if (!m_ConnectionSettings.UnixSocketPath.empty()) { - curl_easy_setopt(Handle, CURLOPT_UNIX_SOCKET_PATH, ConnectionSettings.UnixSocketPath.c_str()); + std::string SocketPathUtf8 = PathToUtf8(m_ConnectionSettings.UnixSocketPath); + curl_easy_setopt(Handle, CURLOPT_UNIX_SOCKET_PATH, SocketPathUtf8.c_str()); } // Build URL with parameters - std::string Url = BuildUrlWithParameters(BaseUrl, ResourcePath, Parameters); + ExtendableStringBuilder<256> Url; + BuildUrlWithParameters(Url, m_BaseUri, ResourcePath, Parameters); curl_easy_setopt(Handle, CURLOPT_URL, Url.c_str()); // Timeouts - if (ConnectionSettings.ConnectTimeout.count() > 0) + if (m_ConnectionSettings.ConnectTimeout.count() > 0) { - curl_easy_setopt(Handle, CURLOPT_CONNECTTIMEOUT_MS, static_cast<long>(ConnectionSettings.ConnectTimeout.count())); + curl_easy_setopt(Handle, CURLOPT_CONNECTTIMEOUT_MS, static_cast<long>(m_ConnectionSettings.ConnectTimeout.count())); } - if (ConnectionSettings.Timeout.count() > 0) + if (m_ConnectionSettings.Timeout.count() > 0) { - curl_easy_setopt(Handle, CURLOPT_TIMEOUT_MS, static_cast<long>(ConnectionSettings.Timeout.count())); + curl_easy_setopt(Handle, CURLOPT_TIMEOUT_MS, static_cast<long>(m_ConnectionSettings.Timeout.count())); } // HTTP/2 - if (ConnectionSettings.AssumeHttp2) + if (m_ConnectionSettings.AssumeHttp2) { curl_easy_setopt(Handle, CURLOPT_HTTP_VERSION, CURL_HTTP_VERSION_2_PRIOR_KNOWLEDGE); } // Verbose/debug - if (ConnectionSettings.Verbose) + if (m_ConnectionSettings.Verbose) { curl_easy_setopt(Handle, CURLOPT_VERBOSE, 1L); curl_easy_setopt(Handle, CURLOPT_DEBUGFUNCTION, CurlDebugCallback); @@ -780,27 +843,27 @@ CurlHttpClient::AllocSession(std::string_view BaseUrl, } // SSL options - if (ConnectionSettings.InsecureSsl) + if (m_ConnectionSettings.InsecureSsl) { curl_easy_setopt(Handle, CURLOPT_SSL_VERIFYPEER, 0L); curl_easy_setopt(Handle, CURLOPT_SSL_VERIFYHOST, 0L); } - if (!ConnectionSettings.CaBundlePath.empty()) + if (!m_ConnectionSettings.CaBundlePath.empty()) { - curl_easy_setopt(Handle, CURLOPT_CAINFO, ConnectionSettings.CaBundlePath.c_str()); + curl_easy_setopt(Handle, CURLOPT_CAINFO, m_ConnectionSettings.CaBundlePath.c_str()); } // Disable signal handling for thread safety curl_easy_setopt(Handle, CURLOPT_NOSIGNAL, 1L); - if (ConnectionSettings.ForbidReuseConnection) + if (m_ConnectionSettings.ForbidReuseConnection) { curl_easy_setopt(Handle, CURLOPT_FORBID_REUSE, 1L); } // Note: Headers are NOT set here. Each method builds its own header list - // (potentially adding method-specific headers like Content-Type) and is - // responsible for freeing it with curl_slist_free_all. + // (potentially adding method-specific headers like Content-Type) and passes + // ownership to the Session via SetHeaders(). return Session(this, Handle); } @@ -809,15 +872,13 @@ void CurlHttpClient::ReleaseSession(CURL* Handle) { ZEN_TRACE_CPU("CurlHttpClient::ReleaseSession"); - - // Free any header list that was set - // curl_easy_reset will be called on next AllocSession, which cleans up the handle state. - // We just push the handle back to the pool. m_SessionLock.WithExclusiveLock([&] { m_Sessions.push_back(Handle); }); } ////////////////////////////////////////////////////////////////////////// +// TransactPackage is a two-phase protocol (offer + send) with server-side state +// between phases, so retrying individual phases would be incorrect. CurlHttpClient::Response CurlHttpClient::TransactPackage(std::string_view Url, CbPackage Package, const KeyValueMap& AdditionalHeader) { @@ -831,7 +892,7 @@ CurlHttpClient::TransactPackage(std::string_view Url, CbPackage Package, const K const uint32_t RequestId = ++CurlHttpClientRequestIdCounter; auto RequestIdString = fmt::to_string(RequestId); - if (Attachments.empty() == false) + if (!Attachments.empty()) { CbObjectWriter Writer; Writer.BeginArray("offer"); @@ -850,27 +911,19 @@ CurlHttpClient::TransactPackage(std::string_view Url, CbPackage Package, const K OfferExtraHeaders.emplace_back(HeaderContentType(HttpContentType::kCbPackageOffer)); OfferExtraHeaders.emplace_back("UE-Request", RequestIdString); - Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); + Session Sess = AllocSession(Url, {}); CURL* H = Sess.Get(); - curl_slist* HeaderList = BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), OfferExtraHeaders); - curl_easy_setopt(H, CURLOPT_HTTPHEADER, HeaderList); + Sess.SetHeaders(BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), OfferExtraHeaders)); curl_easy_setopt(H, CURLOPT_POST, 1L); curl_easy_setopt(H, CURLOPT_POSTFIELDS, reinterpret_cast<const char*>(MemWriter.Data())); curl_easy_setopt(H, CURLOPT_POSTFIELDSIZE_LARGE, static_cast<curl_off_t>(MemWriter.Size())); - std::string FilterBody; - WriteCallbackData WriteData{.Body = &FilterBody}; - curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, CurlWriteCallback); - curl_easy_setopt(H, CURLOPT_WRITEDATA, &WriteData); - - CurlResult Result = Sess.Perform(); - - curl_slist_free_all(HeaderList); + CurlResult Result = Sess.PerformWithResponseCallbacks(); - if (Result.ErrorCode == CURLE_OK && Result.StatusCode == 200) + if (Result.ErrorCode == CURLE_OK && IsHttpSuccessCode(Result.StatusCode)) { - IoBuffer ResponseBuffer(IoBuffer::Wrap, FilterBody.data(), FilterBody.size()); + IoBuffer ResponseBuffer(IoBuffer::Wrap, Result.Body.data(), Result.Body.size()); CbValidateError ValidationError = CbValidateError::None; if (CbObject ResponseObject = ValidateAndReadCompactBinaryObject(std::move(ResponseBuffer), ValidationError); ValidationError == CbValidateError::None) @@ -908,41 +961,17 @@ CurlHttpClient::TransactPackage(std::string_view Url, CbPackage Package, const K PkgExtraHeaders.emplace_back(HeaderContentType(HttpContentType::kCbPackage)); PkgExtraHeaders.emplace_back("UE-Request", RequestIdString); - Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); + Session Sess = AllocSession(Url, {}); CURL* H = Sess.Get(); - curl_slist* HeaderList = BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), PkgExtraHeaders); - curl_easy_setopt(H, CURLOPT_HTTPHEADER, HeaderList); + Sess.SetHeaders(BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), PkgExtraHeaders)); curl_easy_setopt(H, CURLOPT_POST, 1L); curl_easy_setopt(H, CURLOPT_POSTFIELDS, reinterpret_cast<const char*>(FlatMessage.GetData())); curl_easy_setopt(H, CURLOPT_POSTFIELDSIZE_LARGE, static_cast<curl_off_t>(FlatMessage.GetSize())); - std::string PkgBody; - WriteCallbackData WriteData{.Body = &PkgBody}; - curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, CurlWriteCallback); - curl_easy_setopt(H, CURLOPT_WRITEDATA, &WriteData); - - CurlResult Result = Sess.Perform(); - - curl_slist_free_all(HeaderList); + CurlResult Result = Sess.PerformWithResponseCallbacks(); - if (Result.ErrorCode != CURLE_OK || !IsHttpSuccessCode(Result.StatusCode)) - { - return {.StatusCode = HttpResponseCode(Result.StatusCode)}; - } - - IoBuffer ResponseBuffer(IoBuffer::Clone, PkgBody.data(), PkgBody.size()); - - for (const auto& [Key, Value] : Result.Headers) - { - if (StrCaseCompare(Key.c_str(), "Content-Type") == 0) - { - ResponseBuffer.SetContentType(ParseContentType(Value)); - break; - } - } - - return {.StatusCode = HttpResponseCode(Result.StatusCode), .ResponsePayload = std::move(ResponseBuffer)}; + return CommonResponse(m_SessionId, std::move(Result), {}, {}); } ////////////////////////////////////////////////////////////////////////// @@ -957,44 +986,26 @@ CurlHttpClient::Put(std::string_view Url, const IoBuffer& Payload, const KeyValu return CommonResponse( m_SessionId, - DoWithRetry(m_SessionId, - [&]() -> CurlResult { - Session Sess = - AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); - CURL* H = Sess.Get(); - - curl_slist* Headers = - BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), {HeaderContentType(Payload.GetContentType())}); - curl_easy_setopt(H, CURLOPT_HTTPHEADER, Headers); - - curl_easy_setopt(H, CURLOPT_UPLOAD, 1L); - curl_easy_setopt(H, CURLOPT_INFILESIZE_LARGE, static_cast<curl_off_t>(Payload.GetSize())); - - ReadCallbackData ReadData{.DataPtr = static_cast<const uint8_t*>(Payload.GetData()), - .DataSize = Payload.GetSize(), - .CheckIfAbortFunction = m_CheckIfAbortFunction ? &m_CheckIfAbortFunction : nullptr}; - curl_easy_setopt(H, CURLOPT_READFUNCTION, CurlReadCallback); - curl_easy_setopt(H, CURLOPT_READDATA, &ReadData); - - std::string Body; - WriteCallbackData WriteData{.Body = &Body}; - HeaderCallbackData HdrData{}; - std::vector<std::pair<std::string, std::string>> ResponseHeaders; - HdrData.Headers = &ResponseHeaders; + DoWithRetry( + m_SessionId, + [&]() -> CurlResult { + Session Sess = AllocSession(Url, {}); + CURL* H = Sess.Get(); - curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, CurlWriteCallback); - curl_easy_setopt(H, CURLOPT_WRITEDATA, &WriteData); - curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, CurlHeaderCallback); - curl_easy_setopt(H, CURLOPT_HEADERDATA, &HdrData); + Sess.SetHeaders( + BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), {HeaderContentType(Payload.GetContentType())})); - CurlResult Result = Sess.Perform(); - Result.Body = std::move(Body); - Result.Headers = std::move(ResponseHeaders); + curl_easy_setopt(H, CURLOPT_UPLOAD, 1L); + curl_easy_setopt(H, CURLOPT_INFILESIZE_LARGE, static_cast<curl_off_t>(Payload.GetSize())); - curl_slist_free_all(Headers); + ReadCallbackData ReadData{.DataPtr = static_cast<const uint8_t*>(Payload.GetData()), + .DataSize = Payload.GetSize(), + .CheckIfAbortFunction = m_CheckIfAbortFunction ? &m_CheckIfAbortFunction : nullptr}; + curl_easy_setopt(H, CURLOPT_READFUNCTION, CurlReadCallback); + curl_easy_setopt(H, CURLOPT_READDATA, &ReadData); - return Result; - }), + return Sess.PerformWithResponseCallbacks(); + }), {}); } @@ -1005,39 +1016,19 @@ CurlHttpClient::Put(std::string_view Url, const KeyValueMap& Parameters) return CommonResponse( m_SessionId, - DoWithRetry( - m_SessionId, - [&]() -> CurlResult { - KeyValueMap HeaderWithContentLength{std::pair<std::string_view, std::string_view>{"Content-Length", "0"}}; - Session Sess = - AllocSession(m_BaseUri, Url, m_ConnectionSettings, HeaderWithContentLength, Parameters, m_SessionId, GetAccessToken()); - CURL* H = Sess.Get(); - - curl_slist* Headers = BuildHeaderList(HeaderWithContentLength, m_SessionId, GetAccessToken()); - curl_easy_setopt(H, CURLOPT_HTTPHEADER, Headers); - - curl_easy_setopt(H, CURLOPT_UPLOAD, 1L); - curl_easy_setopt(H, CURLOPT_INFILESIZE_LARGE, 0LL); - - std::string Body; - WriteCallbackData WriteData{.Body = &Body}; - HeaderCallbackData HdrData{}; - std::vector<std::pair<std::string, std::string>> ResponseHeaders; - HdrData.Headers = &ResponseHeaders; - - curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, CurlWriteCallback); - curl_easy_setopt(H, CURLOPT_WRITEDATA, &WriteData); - curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, CurlHeaderCallback); - curl_easy_setopt(H, CURLOPT_HEADERDATA, &HdrData); + DoWithRetry(m_SessionId, + [&]() -> CurlResult { + KeyValueMap HeaderWithContentLength{std::pair<std::string_view, std::string_view>{"Content-Length", "0"}}; + Session Sess = AllocSession(Url, Parameters); + CURL* H = Sess.Get(); - CurlResult Result = Sess.Perform(); - Result.Body = std::move(Body); - Result.Headers = std::move(ResponseHeaders); + Sess.SetHeaders(BuildHeaderList(HeaderWithContentLength, m_SessionId, GetAccessToken())); - curl_slist_free_all(Headers); + curl_easy_setopt(H, CURLOPT_UPLOAD, 1L); + curl_easy_setopt(H, CURLOPT_INFILESIZE_LARGE, 0LL); - return Result; - }), + return Sess.PerformWithResponseCallbacks(); + }), {}); } @@ -1045,43 +1036,20 @@ CurlHttpClient::Response CurlHttpClient::Get(std::string_view Url, const KeyValueMap& AdditionalHeader, const KeyValueMap& Parameters) { ZEN_TRACE_CPU("CurlHttpClient::Get"); - return CommonResponse( - m_SessionId, - DoWithRetry( - m_SessionId, - [&]() -> CurlResult { - Session Sess = - AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, Parameters, m_SessionId, GetAccessToken()); - CURL* H = Sess.Get(); - - curl_slist* Headers = BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken()); - curl_easy_setopt(H, CURLOPT_HTTPHEADER, Headers); - curl_easy_setopt(H, CURLOPT_HTTPGET, 1L); - - std::string Body; - WriteCallbackData WriteData{.Body = &Body}; - HeaderCallbackData HdrData{}; - std::vector<std::pair<std::string, std::string>> ResponseHeaders; - HdrData.Headers = &ResponseHeaders; - - curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, CurlWriteCallback); - curl_easy_setopt(H, CURLOPT_WRITEDATA, &WriteData); - curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, CurlHeaderCallback); - curl_easy_setopt(H, CURLOPT_HEADERDATA, &HdrData); - - CurlResult Result = Sess.Perform(); - Result.Body = std::move(Body); - Result.Headers = std::move(ResponseHeaders); - - curl_slist_free_all(Headers); - - return Result; - }, - [this](CurlResult& Result) { - std::unique_ptr<detail::TempPayloadFile> NoTempFile; - return ValidatePayload(Result, NoTempFile); - }), - {}); + return CommonResponse(m_SessionId, + DoWithRetry( + m_SessionId, + [&]() -> CurlResult { + Session Sess = AllocSession(Url, Parameters); + Sess.SetHeaders(BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken())); + curl_easy_setopt(Sess.Get(), CURLOPT_HTTPGET, 1L); + return Sess.PerformWithResponseCallbacks(); + }, + [this](CurlResult& Result) { + std::unique_ptr<detail::TempPayloadFile> NoTempFile; + return ValidatePayload(Result, NoTempFile); + }), + {}); } CurlHttpClient::Response @@ -1089,33 +1057,15 @@ CurlHttpClient::Head(std::string_view Url, const KeyValueMap& AdditionalHeader) { ZEN_TRACE_CPU("CurlHttpClient::Head"); - return CommonResponse( - m_SessionId, - DoWithRetry(m_SessionId, - [&]() -> CurlResult { - Session Sess = - AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); - CURL* H = Sess.Get(); - - curl_slist* Headers = BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken()); - curl_easy_setopt(H, CURLOPT_HTTPHEADER, Headers); - curl_easy_setopt(H, CURLOPT_NOBODY, 1L); - - HeaderCallbackData HdrData{}; - std::vector<std::pair<std::string, std::string>> ResponseHeaders; - HdrData.Headers = &ResponseHeaders; - - curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, CurlHeaderCallback); - curl_easy_setopt(H, CURLOPT_HEADERDATA, &HdrData); - - CurlResult Result = Sess.Perform(); - Result.Headers = std::move(ResponseHeaders); - - curl_slist_free_all(Headers); - - return Result; - }), - {}); + return CommonResponse(m_SessionId, + DoWithRetry(m_SessionId, + [&]() -> CurlResult { + Session Sess = AllocSession(Url, {}); + Sess.SetHeaders(BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken())); + curl_easy_setopt(Sess.Get(), CURLOPT_NOBODY, 1L); + return Sess.PerformWithResponseCallbacks(); + }), + {}); } CurlHttpClient::Response @@ -1123,38 +1073,15 @@ CurlHttpClient::Delete(std::string_view Url, const KeyValueMap& AdditionalHeader { ZEN_TRACE_CPU("CurlHttpClient::Delete"); - return CommonResponse( - m_SessionId, - DoWithRetry(m_SessionId, - [&]() -> CurlResult { - Session Sess = - AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); - CURL* H = Sess.Get(); - - curl_slist* Headers = BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken()); - curl_easy_setopt(H, CURLOPT_HTTPHEADER, Headers); - curl_easy_setopt(H, CURLOPT_CUSTOMREQUEST, "DELETE"); - - std::string Body; - WriteCallbackData WriteData{.Body = &Body}; - HeaderCallbackData HdrData{}; - std::vector<std::pair<std::string, std::string>> ResponseHeaders; - HdrData.Headers = &ResponseHeaders; - - curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, CurlWriteCallback); - curl_easy_setopt(H, CURLOPT_WRITEDATA, &WriteData); - curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, CurlHeaderCallback); - curl_easy_setopt(H, CURLOPT_HEADERDATA, &HdrData); - - CurlResult Result = Sess.Perform(); - Result.Body = std::move(Body); - Result.Headers = std::move(ResponseHeaders); - - curl_slist_free_all(Headers); - - return Result; - }), - {}); + return CommonResponse(m_SessionId, + DoWithRetry(m_SessionId, + [&]() -> CurlResult { + Session Sess = AllocSession(Url, {}); + Sess.SetHeaders(BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken())); + curl_easy_setopt(Sess.Get(), CURLOPT_CUSTOMREQUEST, "DELETE"); + return Sess.PerformWithResponseCallbacks(); + }), + {}); } CurlHttpClient::Response @@ -1162,39 +1089,16 @@ CurlHttpClient::Post(std::string_view Url, const KeyValueMap& AdditionalHeader, { ZEN_TRACE_CPU("CurlHttpClient::PostNoPayload"); - return CommonResponse( - m_SessionId, - DoWithRetry(m_SessionId, - [&]() -> CurlResult { - Session Sess = - AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, Parameters, m_SessionId, GetAccessToken()); - CURL* H = Sess.Get(); - - curl_slist* Headers = BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken()); - curl_easy_setopt(H, CURLOPT_HTTPHEADER, Headers); - curl_easy_setopt(H, CURLOPT_POST, 1L); - curl_easy_setopt(H, CURLOPT_POSTFIELDSIZE, 0L); - - std::string Body; - WriteCallbackData WriteData{.Body = &Body}; - HeaderCallbackData HdrData{}; - std::vector<std::pair<std::string, std::string>> ResponseHeaders; - HdrData.Headers = &ResponseHeaders; - - curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, CurlWriteCallback); - curl_easy_setopt(H, CURLOPT_WRITEDATA, &WriteData); - curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, CurlHeaderCallback); - curl_easy_setopt(H, CURLOPT_HEADERDATA, &HdrData); - - CurlResult Result = Sess.Perform(); - Result.Body = std::move(Body); - Result.Headers = std::move(ResponseHeaders); - - curl_slist_free_all(Headers); - - return Result; - }), - {}); + return CommonResponse(m_SessionId, + DoWithRetry(m_SessionId, + [&]() -> CurlResult { + Session Sess = AllocSession(Url, Parameters); + Sess.SetHeaders(BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken())); + curl_easy_setopt(Sess.Get(), CURLOPT_POST, 1L); + curl_easy_setopt(Sess.Get(), CURLOPT_POSTFIELDSIZE, 0L); + return Sess.PerformWithResponseCallbacks(); + }), + {}); } CurlHttpClient::Response @@ -1213,12 +1117,10 @@ CurlHttpClient::Post(std::string_view Url, const IoBuffer& Payload, ZenContentTy DoWithRetry( m_SessionId, [&]() -> CurlResult { - Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); + Session Sess = AllocSession(Url, {}); CURL* H = Sess.Get(); - // Rebuild headers with content type - curl_slist* Headers = BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), {HeaderContentType(ContentType)}); - curl_easy_setopt(H, CURLOPT_HTTPHEADER, Headers); + Sess.SetHeaders(BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), {HeaderContentType(ContentType)})); IoBufferFileReference FileRef = {nullptr, 0, 0}; if (Payload.GetFileReference(FileRef)) @@ -1234,46 +1136,14 @@ CurlHttpClient::Post(std::string_view Url, const IoBuffer& Payload, ZenContentTy curl_easy_setopt(H, CURLOPT_READFUNCTION, CurlFileReadCallback); curl_easy_setopt(H, CURLOPT_READDATA, &ReadData); - std::string Body; - WriteCallbackData WriteData{.Body = &Body}; - HeaderCallbackData HdrData{}; - std::vector<std::pair<std::string, std::string>> ResponseHeaders; - HdrData.Headers = &ResponseHeaders; - - curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, CurlWriteCallback); - curl_easy_setopt(H, CURLOPT_WRITEDATA, &WriteData); - curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, CurlHeaderCallback); - curl_easy_setopt(H, CURLOPT_HEADERDATA, &HdrData); - - CurlResult Result = Sess.Perform(); - Result.Body = std::move(Body); - Result.Headers = std::move(ResponseHeaders); - - curl_slist_free_all(Headers); - return Result; + return Sess.PerformWithResponseCallbacks(); } curl_easy_setopt(H, CURLOPT_POST, 1L); curl_easy_setopt(H, CURLOPT_POSTFIELDS, reinterpret_cast<const char*>(Payload.GetData())); curl_easy_setopt(H, CURLOPT_POSTFIELDSIZE_LARGE, static_cast<curl_off_t>(Payload.GetSize())); - std::string Body; - WriteCallbackData WriteData{.Body = &Body}; - HeaderCallbackData HdrData{}; - std::vector<std::pair<std::string, std::string>> ResponseHeaders; - HdrData.Headers = &ResponseHeaders; - - curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, CurlWriteCallback); - curl_easy_setopt(H, CURLOPT_WRITEDATA, &WriteData); - curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, CurlHeaderCallback); - curl_easy_setopt(H, CURLOPT_HEADERDATA, &HdrData); - - CurlResult Result = Sess.Perform(); - Result.Body = std::move(Body); - Result.Headers = std::move(ResponseHeaders); - - curl_slist_free_all(Headers); - return Result; + return Sess.PerformWithResponseCallbacks(); }), {}); } @@ -1295,12 +1165,11 @@ CurlHttpClient::Post(std::string_view Url, PayloadString.clear(); PayloadFile.reset(); - Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); + Session Sess = AllocSession(Url, {}); CURL* H = Sess.Get(); - curl_slist* Headers = - BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), {HeaderContentType(ZenContentType::kCbObject)}); - curl_easy_setopt(H, CURLOPT_HTTPHEADER, Headers); + Sess.SetHeaders( + BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), {HeaderContentType(ZenContentType::kCbObject)})); curl_easy_setopt(H, CURLOPT_POST, 1L); curl_easy_setopt(H, CURLOPT_POSTFIELDS, reinterpret_cast<const char*>(Payload.GetBuffer().GetData())); @@ -1329,33 +1198,11 @@ CurlHttpClient::Post(std::string_view Url, auto* Data = static_cast<PostHeaderCallbackData*>(UserData); size_t TotalBytes = Size * Nmemb; - std::string_view Line(Buffer, TotalBytes); - while (!Line.empty() && (Line.back() == '\r' || Line.back() == '\n')) + if (auto Header = ParseHeaderLine(std::string_view(Buffer, TotalBytes))) { - Line.remove_suffix(1); - } - - if (Line.empty()) - { - return TotalBytes; - } - - size_t ColonPos = Line.find(':'); - if (ColonPos != std::string_view::npos) - { - std::string_view Key = Line.substr(0, ColonPos); - std::string_view Value = Line.substr(ColonPos + 1); - - while (!Key.empty() && Key.back() == ' ') - { - Key.remove_suffix(1); - } - while (!Value.empty() && Value.front() == ' ') - { - Value.remove_prefix(1); - } + auto& [Key, Value] = *Header; - if (StrCaseCompare(std::string(Key).c_str(), "Content-Length") == 0) + if (StrCaseCompare(Key, "Content-Length") == 0) { std::optional<size_t> ContentLength = ParseInt<size_t>(Value); if (ContentLength.has_value()) @@ -1444,7 +1291,6 @@ CurlHttpClient::Post(std::string_view Url, Res.Body = std::move(PayloadString); } - curl_slist_free_all(Headers); return Res; }, PayloadFile); @@ -1467,13 +1313,10 @@ CurlHttpClient::Post(std::string_view Url, const CompositeBuffer& Payload, ZenCo m_SessionId, DoWithRetry(m_SessionId, [&]() -> CurlResult { - Session Sess = - AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); - CURL* H = Sess.Get(); + Session Sess = AllocSession(Url, {}); + CURL* H = Sess.Get(); - curl_slist* Headers = - BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), {HeaderContentType(ContentType)}); - curl_easy_setopt(H, CURLOPT_HTTPHEADER, Headers); + Sess.SetHeaders(BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), {HeaderContentType(ContentType)})); detail::CompositeBufferReadStream Reader(Payload, 512u * 1024u); @@ -1485,23 +1328,7 @@ CurlHttpClient::Post(std::string_view Url, const CompositeBuffer& Payload, ZenCo curl_easy_setopt(H, CURLOPT_READFUNCTION, CurlStreamReadCallback); curl_easy_setopt(H, CURLOPT_READDATA, &ReadData); - std::string Body; - WriteCallbackData WriteData{.Body = &Body}; - HeaderCallbackData HdrData{}; - std::vector<std::pair<std::string, std::string>> ResponseHeaders; - HdrData.Headers = &ResponseHeaders; - - curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, CurlWriteCallback); - curl_easy_setopt(H, CURLOPT_WRITEDATA, &WriteData); - curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, CurlHeaderCallback); - curl_easy_setopt(H, CURLOPT_HEADERDATA, &HdrData); - - CurlResult Result = Sess.Perform(); - Result.Body = std::move(Body); - Result.Headers = std::move(ResponseHeaders); - - curl_slist_free_all(Headers); - return Result; + return Sess.PerformWithResponseCallbacks(); }), {}); } @@ -1516,12 +1343,11 @@ CurlHttpClient::Upload(std::string_view Url, const IoBuffer& Payload, const KeyV DoWithRetry( m_SessionId, [&]() -> CurlResult { - Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); + Session Sess = AllocSession(Url, {}); CURL* H = Sess.Get(); - curl_slist* Headers = - BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), {HeaderContentType(Payload.GetContentType())}); - curl_easy_setopt(H, CURLOPT_HTTPHEADER, Headers); + Sess.SetHeaders( + BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), {HeaderContentType(Payload.GetContentType())})); curl_easy_setopt(H, CURLOPT_UPLOAD, 1L); curl_easy_setopt(H, CURLOPT_INFILESIZE_LARGE, static_cast<curl_off_t>(Payload.GetSize())); @@ -1538,23 +1364,7 @@ CurlHttpClient::Upload(std::string_view Url, const IoBuffer& Payload, const KeyV curl_easy_setopt(H, CURLOPT_READFUNCTION, CurlFileReadCallback); curl_easy_setopt(H, CURLOPT_READDATA, &ReadData); - std::string Body; - WriteCallbackData WriteData{.Body = &Body}; - HeaderCallbackData HdrData{}; - std::vector<std::pair<std::string, std::string>> ResponseHeaders; - HdrData.Headers = &ResponseHeaders; - - curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, CurlWriteCallback); - curl_easy_setopt(H, CURLOPT_WRITEDATA, &WriteData); - curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, CurlHeaderCallback); - curl_easy_setopt(H, CURLOPT_HEADERDATA, &HdrData); - - CurlResult Result = Sess.Perform(); - Result.Body = std::move(Body); - Result.Headers = std::move(ResponseHeaders); - - curl_slist_free_all(Headers); - return Result; + return Sess.PerformWithResponseCallbacks(); } ReadCallbackData ReadData{.DataPtr = static_cast<const uint8_t*>(Payload.GetData()), @@ -1563,23 +1373,7 @@ CurlHttpClient::Upload(std::string_view Url, const IoBuffer& Payload, const KeyV curl_easy_setopt(H, CURLOPT_READFUNCTION, CurlReadCallback); curl_easy_setopt(H, CURLOPT_READDATA, &ReadData); - std::string Body; - WriteCallbackData WriteData{.Body = &Body}; - HeaderCallbackData HdrData{}; - std::vector<std::pair<std::string, std::string>> ResponseHeaders; - HdrData.Headers = &ResponseHeaders; - - curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, CurlWriteCallback); - curl_easy_setopt(H, CURLOPT_WRITEDATA, &WriteData); - curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, CurlHeaderCallback); - curl_easy_setopt(H, CURLOPT_HEADERDATA, &HdrData); - - CurlResult Result = Sess.Perform(); - Result.Body = std::move(Body); - Result.Headers = std::move(ResponseHeaders); - - curl_slist_free_all(Headers); - return Result; + return Sess.PerformWithResponseCallbacks(); }), {}); } @@ -1596,13 +1390,10 @@ CurlHttpClient::Upload(std::string_view Url, m_SessionId, DoWithRetry(m_SessionId, [&]() -> CurlResult { - Session Sess = - AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); - CURL* H = Sess.Get(); + Session Sess = AllocSession(Url, {}); + CURL* H = Sess.Get(); - curl_slist* Headers = - BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), {HeaderContentType(ContentType)}); - curl_easy_setopt(H, CURLOPT_HTTPHEADER, Headers); + Sess.SetHeaders(BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), {HeaderContentType(ContentType)})); curl_easy_setopt(H, CURLOPT_UPLOAD, 1L); curl_easy_setopt(H, CURLOPT_INFILESIZE_LARGE, static_cast<curl_off_t>(Payload.GetSize())); @@ -1615,23 +1406,7 @@ CurlHttpClient::Upload(std::string_view Url, curl_easy_setopt(H, CURLOPT_READFUNCTION, CurlStreamReadCallback); curl_easy_setopt(H, CURLOPT_READDATA, &ReadData); - std::string Body; - WriteCallbackData WriteData{.Body = &Body}; - HeaderCallbackData HdrData{}; - std::vector<std::pair<std::string, std::string>> ResponseHeaders; - HdrData.Headers = &ResponseHeaders; - - curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, CurlWriteCallback); - curl_easy_setopt(H, CURLOPT_WRITEDATA, &WriteData); - curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, CurlHeaderCallback); - curl_easy_setopt(H, CURLOPT_HEADERDATA, &HdrData); - - CurlResult Result = Sess.Perform(); - Result.Body = std::move(Body); - Result.Headers = std::move(ResponseHeaders); - - curl_slist_free_all(Headers); - return Result; + return Sess.PerformWithResponseCallbacks(); }), {}); } @@ -1651,11 +1426,10 @@ CurlHttpClient::Download(std::string_view Url, const std::filesystem::path& Temp CurlResult Result = DoWithRetry( m_SessionId, [&]() -> CurlResult { - Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); + Session Sess = AllocSession(Url, {}); CURL* H = Sess.Get(); - curl_slist* DlHeaders = BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken()); - curl_easy_setopt(H, CURLOPT_HTTPHEADER, DlHeaders); + Sess.SetHeaders(BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken())); curl_easy_setopt(H, CURLOPT_HTTPGET, 1L); // Reset state from any previous attempt @@ -1673,7 +1447,7 @@ CurlHttpClient::Download(std::string_view Url, const std::filesystem::path& Temp { std::string_view RangeValue(RangeIt->second); size_t RangeStartPos = RangeValue.find('=', 5); - if (RangeStartPos != std::string::npos) + if (RangeStartPos != std::string_view::npos) { RangeStartPos++; while (RangeStartPos < RangeValue.length() && RangeValue[RangeStartPos] == ' ') @@ -1685,14 +1459,14 @@ CurlHttpClient::Download(std::string_view Url, const std::filesystem::path& Temp while (RangeStartPos < RangeValue.length()) { size_t RangeEnd = RangeValue.find_first_of(", \r\n", RangeStartPos); - if (RangeEnd == std::string::npos) + if (RangeEnd == std::string_view::npos) { RangeEnd = RangeValue.length(); } std::string_view RangeString = RangeValue.substr(RangeStartPos, RangeEnd - RangeStartPos); size_t RangeSplitPos = RangeString.find('-'); - if (RangeSplitPos != std::string::npos) + if (RangeSplitPos != std::string_view::npos) { std::optional<size_t> RequestedRangeStart = ParseInt<size_t>(RangeString.substr(0, RangeSplitPos)); std::optional<size_t> RequestedRangeEnd = ParseInt<size_t>(RangeString.substr(RangeSplitPos + 1)); @@ -1742,36 +1516,12 @@ CurlHttpClient::Download(std::string_view Url, const std::filesystem::path& Temp auto* Data = static_cast<DownloadHeaderCallbackData*>(UserData); size_t TotalBytes = Size * Nmemb; - std::string_view Line(Buffer, TotalBytes); - - while (!Line.empty() && (Line.back() == '\r' || Line.back() == '\n')) - { - Line.remove_suffix(1); - } - - if (Line.empty()) + if (auto Header = ParseHeaderLine(std::string_view(Buffer, TotalBytes))) { - return TotalBytes; - } - - size_t ColonPos = Line.find(':'); - if (ColonPos != std::string_view::npos) - { - std::string_view KeyView = Line.substr(0, ColonPos); - std::string_view Value = Line.substr(ColonPos + 1); - - while (!KeyView.empty() && KeyView.back() == ' ') - { - KeyView.remove_suffix(1); - } - while (!Value.empty() && Value.front() == ' ') - { - Value.remove_prefix(1); - } - + auto& [KeyView, Value] = *Header; const std::string Key(KeyView); - if (StrCaseCompare(Key.c_str(), "Content-Length") == 0) + if (StrCaseCompare(Key, "Content-Length") == 0) { std::optional<size_t> ContentLength = ParseInt<size_t>(Value); if (ContentLength.has_value()) @@ -1795,7 +1545,7 @@ CurlHttpClient::Download(std::string_view Url, const std::filesystem::path& Temp } } } - else if (StrCaseCompare(Key.c_str(), "Content-Type") == 0) + else if (StrCaseCompare(Key, "Content-Type") == 0) { *Data->IsMultiRange = Data->BoundaryParser->Init(Value); if (!*Data->IsMultiRange) @@ -1803,7 +1553,7 @@ CurlHttpClient::Download(std::string_view Url, const std::filesystem::path& Temp *Data->ContentTypeOut = ParseContentType(Value); } } - else if (StrCaseCompare(Key.c_str(), "Content-Range") == 0) + else if (StrCaseCompare(Key, "Content-Range") == 0) { if (!*Data->IsMultiRange) { @@ -1819,7 +1569,7 @@ CurlHttpClient::Download(std::string_view Url, const std::filesystem::path& Temp } } - Data->Headers->emplace_back(std::string(Key), std::string(Value)); + Data->Headers->emplace_back(Key, std::string(Value)); } return TotalBytes; @@ -1894,11 +1644,11 @@ CurlHttpClient::Download(std::string_view Url, const std::filesystem::path& Temp auto SupportsRanges = [](const CurlResult& R) -> bool { for (const auto& [K, V] : R.Headers) { - if (StrCaseCompare(K.c_str(), "Content-Range") == 0) + if (StrCaseCompare(K, "Content-Range") == 0) { return true; } - if (StrCaseCompare(K.c_str(), "Accept-Ranges") == 0) + if (StrCaseCompare(K, "Accept-Ranges") == 0) { return V == "bytes"sv; } @@ -1924,7 +1674,7 @@ CurlHttpClient::Download(std::string_view Url, const std::filesystem::path& Temp std::string ContentLengthValue; for (const auto& [K, V] : Res.Headers) { - if (StrCaseCompare(K.c_str(), "Content-Length") == 0) + if (StrCaseCompare(K, "Content-Length") == 0) { ContentLengthValue = V; break; @@ -1943,6 +1693,7 @@ CurlHttpClient::Download(std::string_view Url, const std::filesystem::path& Temp } KeyValueMap HeadersWithRange(AdditionalHeader); + uint8_t ResumeAttempt = 0; do { uint64_t DownloadedSize = PayloadFile ? PayloadFile->GetSize() : PayloadString.length(); @@ -1957,12 +1708,10 @@ CurlHttpClient::Download(std::string_view Url, const std::filesystem::path& Temp } HeadersWithRange.Entries.insert_or_assign("Range", Range); - Session ResumeSess = - AllocSession(m_BaseUri, Url, m_ConnectionSettings, HeadersWithRange, {}, m_SessionId, GetAccessToken()); - CURL* ResumeH = ResumeSess.Get(); + Session ResumeSess = AllocSession(Url, {}); + CURL* ResumeH = ResumeSess.Get(); - curl_slist* ResumeHdrList = BuildHeaderList(HeadersWithRange, m_SessionId, GetAccessToken()); - curl_easy_setopt(ResumeH, CURLOPT_HTTPHEADER, ResumeHdrList); + ResumeSess.SetHeaders(BuildHeaderList(HeadersWithRange, m_SessionId, GetAccessToken())); curl_easy_setopt(ResumeH, CURLOPT_HTTPGET, 1L); std::vector<std::pair<std::string, std::string>> ResumeHeaders; @@ -1983,72 +1732,51 @@ CurlHttpClient::Download(std::string_view Url, const std::filesystem::path& Temp auto* Data = static_cast<ResumeHeaderCbData*>(UserData); size_t TotalBytes = Size * Nmemb; - std::string_view Line(Buffer, TotalBytes); - while (!Line.empty() && (Line.back() == '\r' || Line.back() == '\n')) - { - Line.remove_suffix(1); - } - - if (Line.empty()) + auto Header = ParseHeaderLine(std::string_view(Buffer, TotalBytes)); + if (!Header) { return TotalBytes; } + auto& [Key, Value] = *Header; - size_t ColonPos = Line.find(':'); - if (ColonPos != std::string_view::npos) + if (StrCaseCompare(Key, "Content-Range") == 0) { - std::string_view Key = Line.substr(0, ColonPos); - std::string_view Value = Line.substr(ColonPos + 1); - while (!Key.empty() && Key.back() == ' ') + if (Value.starts_with("bytes "sv)) { - Key.remove_suffix(1); - } - while (!Value.empty() && Value.front() == ' ') - { - Value.remove_prefix(1); - } - - if (StrCaseCompare(std::string(Key).c_str(), "Content-Range") == 0) - { - if (Value.starts_with("bytes "sv)) + size_t RangeStartEnd = Value.find('-', 6); + if (RangeStartEnd != std::string_view::npos) { - size_t RangeStartEnd = Value.find('-', 6); - if (RangeStartEnd != std::string_view::npos) + const std::optional<uint64_t> Start = ParseInt<uint64_t>(Value.substr(6, RangeStartEnd - 6)); + if (Start) { - const std::optional<uint64_t> Start = - ParseInt<uint64_t>(Value.substr(6, RangeStartEnd - 6)); - if (Start) + uint64_t DownloadedSize = + *Data->PayloadFile ? (*Data->PayloadFile)->GetSize() : Data->PayloadString->length(); + if (Start.value() == DownloadedSize) { - uint64_t DownloadedSize = *Data->PayloadFile ? (*Data->PayloadFile)->GetSize() - : Data->PayloadString->length(); - if (Start.value() == DownloadedSize) - { - Data->Headers->emplace_back(std::string(Key), std::string(Value)); - return TotalBytes; - } - else if (Start.value() > DownloadedSize) - { - return 0; - } - if (*Data->PayloadFile) - { - (*Data->PayloadFile)->ResetWritePos(Start.value()); - } - else - { - *Data->PayloadString = Data->PayloadString->substr(0, Start.value()); - } Data->Headers->emplace_back(std::string(Key), std::string(Value)); return TotalBytes; } + else if (Start.value() > DownloadedSize) + { + return 0; + } + if (*Data->PayloadFile) + { + (*Data->PayloadFile)->ResetWritePos(Start.value()); + } + else + { + *Data->PayloadString = Data->PayloadString->substr(0, Start.value()); + } + Data->Headers->emplace_back(std::string(Key), std::string(Value)); + return TotalBytes; } } - return 0; } - - Data->Headers->emplace_back(std::string(Key), std::string(Value)); + return 0; } + Data->Headers->emplace_back(std::string(Key), std::string(Value)); return TotalBytes; }; @@ -2064,8 +1792,8 @@ CurlHttpClient::Download(std::string_view Url, const std::filesystem::path& Temp Res = ResumeSess.Perform(); Res.Headers = std::move(ResumeHeaders); - curl_slist_free_all(ResumeHdrList); - } while (ShouldResumeCheck(Res)); + ResumeAttempt++; + } while (ResumeAttempt < m_ConnectionSettings.RetryCount && ShouldResumeCheck(Res)); } } } @@ -2075,8 +1803,6 @@ CurlHttpClient::Download(std::string_view Url, const std::filesystem::path& Temp Res.Body = std::move(PayloadString); } - curl_slist_free_all(DlHeaders); - return Res; }, PayloadFile); diff --git a/src/zenhttp/clients/httpclientcurl.h b/src/zenhttp/clients/httpclientcurl.h index 871877863..b7fa52e6c 100644 --- a/src/zenhttp/clients/httpclientcurl.h +++ b/src/zenhttp/clients/httpclientcurl.h @@ -75,40 +75,39 @@ private: struct Session { Session(CurlHttpClient* InOuter, CURL* InHandle) : Outer(InOuter), Handle(InHandle) {} - ~Session() { Outer->ReleaseSession(Handle); } + ~Session(); CURL* Get() const { return Handle; } + // Takes ownership of the curl_slist and sets it on the handle. + // The list is freed automatically when the Session is destroyed. + void SetHeaders(curl_slist* Headers); + + // Low-level perform: executes the request and collects status/timing. CurlResult Perform(); + // Sets up standard write+header callbacks, performs the request, and + // moves the collected body and headers into the returned CurlResult. + CurlResult PerformWithResponseCallbacks(); + LoggerRef Log() { return Outer->Log(); } private: CurlHttpClient* Outer; CURL* Handle; + curl_slist* HeaderList = nullptr; Session(Session&&) = delete; Session& operator=(Session&&) = delete; }; - Session AllocSession(std::string_view BaseUrl, - std::string_view Url, - const HttpClientSettings& ConnectionSettings, - const KeyValueMap& AdditionalHeader, - const KeyValueMap& Parameters, - std::string_view SessionId, - std::optional<HttpClientAccessToken> AccessToken); + Session AllocSession(std::string_view ResourcePath, const KeyValueMap& Parameters); RwLock m_SessionLock; std::vector<CURL*> m_Sessions; void ReleaseSession(CURL* Handle); - struct RetryResult - { - CurlResult Result; - }; - CurlResult DoWithRetry(std::string_view SessionId, std::function<CurlResult()>&& Func, std::unique_ptr<detail::TempPayloadFile>& PayloadFile); diff --git a/src/zenhttp/clients/httpwsclient.cpp b/src/zenhttp/clients/httpwsclient.cpp index 2d566ae86..fbae9f5fe 100644 --- a/src/zenhttp/clients/httpwsclient.cpp +++ b/src/zenhttp/clients/httpwsclient.cpp @@ -5,6 +5,8 @@ #include "../servers/wsframecodec.h" #include <zencore/base64.h> +#include <zencore/filesystem.h> +#include <zencore/fmtutils.h> #include <zencore/logging.h> #include <zencore/string.h> @@ -155,7 +157,7 @@ struct HttpWsClient::Impl } }); - asio::local::stream_protocol::endpoint Endpoint(m_Settings.UnixSocketPath); + asio::local::stream_protocol::endpoint Endpoint(PathToUtf8(m_Settings.UnixSocketPath)); m_UnixSocket->async_connect(Endpoint, [this](const asio::error_code& Ec) { if (Ec) { diff --git a/src/zenhttp/httpclient.cpp b/src/zenhttp/httpclient.cpp index deeeb6c85..9f49802a0 100644 --- a/src/zenhttp/httpclient.cpp +++ b/src/zenhttp/httpclient.cpp @@ -36,15 +36,17 @@ namespace zen { +#if ZEN_WITH_CPR extern HttpClientBase* CreateCprHttpClient(std::string_view BaseUri, const HttpClientSettings& ConnectionSettings, std::function<bool()>&& CheckIfAbortFunction); +#endif extern HttpClientBase* CreateCurlHttpClient(std::string_view BaseUri, const HttpClientSettings& ConnectionSettings, std::function<bool()>&& CheckIfAbortFunction); -static HttpClientBackend g_DefaultHttpClientBackend = HttpClientBackend::kCpr; +static HttpClientBackend g_DefaultHttpClientBackend = HttpClientBackend::kCurl; void SetDefaultHttpClientBackend(HttpClientBackend Backend) @@ -55,11 +57,14 @@ SetDefaultHttpClientBackend(HttpClientBackend Backend) void SetDefaultHttpClientBackend(std::string_view Backend) { +#if ZEN_WITH_CPR if (Backend == "cpr") { g_DefaultHttpClientBackend = HttpClientBackend::kCpr; } - else if (Backend == "curl") + else +#endif + if (Backend == "curl") { g_DefaultHttpClientBackend = HttpClientBackend::kCurl; } @@ -363,13 +368,15 @@ HttpClient::HttpClient(std::string_view BaseUri, const HttpClientSettings& Conne switch (EffectiveBackend) { - case HttpClientBackend::kCurl: - m_Inner = CreateCurlHttpClient(BaseUri, ConnectionSettings, std::move(CheckIfAbortFunction)); - break; +#if ZEN_WITH_CPR case HttpClientBackend::kCpr: - default: m_Inner = CreateCprHttpClient(BaseUri, ConnectionSettings, std::move(CheckIfAbortFunction)); break; +#endif + case HttpClientBackend::kCurl: + default: + m_Inner = CreateCurlHttpClient(BaseUri, ConnectionSettings, std::move(CheckIfAbortFunction)); + break; } } diff --git a/src/zenhttp/httpclient_test.cpp b/src/zenhttp/httpclient_test.cpp index 5f3ad2455..3ca586f87 100644 --- a/src/zenhttp/httpclient_test.cpp +++ b/src/zenhttp/httpclient_test.cpp @@ -154,6 +154,42 @@ public: }, HttpVerb::kGet); + m_Router.AddMatcher("anypath", [](std::string_view Str) -> bool { return !Str.empty(); }); + + m_Router.RegisterRoute( + "echo/uri", + [](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + std::string Body = std::string(HttpReq.RelativeUri()); + + auto Params = HttpReq.GetQueryParams(); + for (const auto& [Key, Value] : Params.KvPairs) + { + Body += fmt::format("\n{}={}", Key, Value); + } + + HttpReq.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Body); + }, + HttpVerb::kGet | HttpVerb::kPut); + + m_Router.RegisterRoute( + "echo/uri/{anypath}", + [](HttpRouterRequest& Req) { + // Echo both the RelativeUri and the captured path segment + HttpServerRequest& HttpReq = Req.ServerRequest(); + std::string_view Captured = Req.GetCapture(1); + std::string Body = fmt::format("uri={}\ncapture={}", HttpReq.RelativeUri(), Captured); + + auto Params = HttpReq.GetQueryParams(); + for (const auto& [Key, Value] : Params.KvPairs) + { + Body += fmt::format("\n{}={}", Key, Value); + } + + HttpReq.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Body); + }, + HttpVerb::kGet | HttpVerb::kPut); + m_Router.RegisterRoute( "slow", [](HttpRouterRequest& Req) { @@ -1689,6 +1725,77 @@ TEST_CASE("httpclient.https") # endif // ZEN_USE_OPENSSL +TEST_CASE("httpclient.uri_decoding") +{ + TestServerFixture Fixture; + HttpClient Client = Fixture.MakeClient(); + + // URI without encoding — should pass through unchanged + { + HttpClient::Response Resp = Client.Get("/api/test/echo/uri/hello/world.txt"); + REQUIRE(Resp.IsSuccess()); + CHECK(Resp.AsText() == "uri=echo/uri/hello/world.txt\ncapture=hello/world.txt"); + } + + // Percent-encoded space — server should see decoded path + { + HttpClient::Response Resp = Client.Get("/api/test/echo/uri/hello%20world.txt"); + REQUIRE(Resp.IsSuccess()); + CHECK(Resp.AsText() == "uri=echo/uri/hello world.txt\ncapture=hello world.txt"); + } + + // Percent-encoded slash (%2F) — should be decoded to / + { + HttpClient::Response Resp = Client.Get("/api/test/echo/uri/a%2Fb.txt"); + REQUIRE(Resp.IsSuccess()); + CHECK(Resp.AsText() == "uri=echo/uri/a/b.txt\ncapture=a/b.txt"); + } + + // Multiple encodings in one path + { + HttpClient::Response Resp = Client.Get("/api/test/echo/uri/file%20%26%20name.txt"); + REQUIRE(Resp.IsSuccess()); + CHECK(Resp.AsText() == "uri=echo/uri/file & name.txt\ncapture=file & name.txt"); + } + + // No capture — echo/uri route returns just RelativeUri + { + HttpClient::Response Resp = Client.Get("/api/test/echo/uri"); + REQUIRE(Resp.IsSuccess()); + CHECK(Resp.AsText() == "echo/uri"); + } + + // Literal percent that is not an escape (%ZZ) — should be kept as-is + { + HttpClient::Response Resp = Client.Get("/api/test/echo/uri/100%25done.txt"); + REQUIRE(Resp.IsSuccess()); + CHECK(Resp.AsText() == "uri=echo/uri/100%done.txt\ncapture=100%done.txt"); + } + + // Query params — raw values are returned as-is from GetQueryParams + { + HttpClient::Response Resp = Client.Get("/api/test/echo/uri?key=value&name=test"); + REQUIRE(Resp.IsSuccess()); + CHECK(Resp.AsText() == "echo/uri\nkey=value\nname=test"); + } + + // Query params with percent-encoded values + { + HttpClient::Response Resp = Client.Get("/api/test/echo/uri?prefix=listing%2F&mode=s3"); + REQUIRE(Resp.IsSuccess()); + // GetQueryParams returns raw (still-encoded) values — callers must Decode() explicitly + CHECK(Resp.AsText() == "echo/uri\nprefix=listing%2F\nmode=s3"); + } + + // Query params with path capture and encoding + { + HttpClient::Response Resp = Client.Get("/api/test/echo/uri/hello%20world.txt?tag=a%26b"); + REQUIRE(Resp.IsSuccess()); + // Path is decoded, query values are raw + CHECK(Resp.AsText() == "uri=echo/uri/hello world.txt\ncapture=hello world.txt\ntag=a%26b"); + } +} + TEST_SUITE_END(); void diff --git a/src/zenhttp/httpserver.cpp b/src/zenhttp/httpserver.cpp index 6ba0ca563..4d98e9650 100644 --- a/src/zenhttp/httpserver.cpp +++ b/src/zenhttp/httpserver.cpp @@ -2,6 +2,8 @@ #include <zenhttp/httpserver.h> +#include <zencore/filesystem.h> + #include "servers/httpasio.h" #include "servers/httpmulti.h" #include "servers/httpnull.h" @@ -698,15 +700,6 @@ HttpServerRequest::ReadPayloadPackage() ////////////////////////////////////////////////////////////////////////// void -HttpRequestRouter::AddPattern(const char* Id, const char* Regex) -{ - ZEN_ASSERT(m_PatternMap.find(Id) == m_PatternMap.end()); - ZEN_ASSERT(!m_IsFinalized); - - m_PatternMap.insert({Id, Regex}); -} - -void HttpRequestRouter::AddMatcher(const char* Id, std::function<bool(std::string_view)>&& Matcher) { ZEN_ASSERT(m_MatcherNameMap.find(Id) == m_MatcherNameMap.end()); @@ -722,170 +715,77 @@ HttpRequestRouter::RegisterRoute(const char* UriPattern, HttpRequestRouter::Hand { ZEN_ASSERT(!m_IsFinalized); - if (ExtendableStringBuilder<128> ExpandedRegex; ProcessRegexSubstitutions(UriPattern, ExpandedRegex)) - { - // Regex route - m_RegexHandlers.emplace_back(ExpandedRegex.c_str(), SupportedVerbs, std::move(HandlerFunc), UriPattern); - } - else - { - // New-style regex-free route. More efficient and should be used for everything eventually - - int RegexLen = gsl::narrow_cast<int>(strlen(UriPattern)); + int RegexLen = gsl::narrow_cast<int>(strlen(UriPattern)); - int i = 0; + int i = 0; - std::vector<int> MatcherIndices; + std::vector<int> MatcherIndices; - while (i < RegexLen) + while (i < RegexLen) + { + if (UriPattern[i] == '{') { - if (UriPattern[i] == '{') + bool IsComplete = false; + int PatternStart = i + 1; + while (++i < RegexLen) { - bool IsComplete = false; - int PatternStart = i + 1; - while (++i < RegexLen) + if (UriPattern[i] == '}') { - if (UriPattern[i] == '}') + if (i == PatternStart) { - if (i == PatternStart) - { - throw std::runtime_error(fmt::format("matcher pattern is empty in URI pattern '{}'", UriPattern)); - } - std::string_view Pattern(&UriPattern[PatternStart], i - PatternStart); - if (auto it = m_MatcherNameMap.find(std::string(Pattern)); it != m_MatcherNameMap.end()) - { - // It's a match - MatcherIndices.push_back(it->second); - IsComplete = true; - ++i; - break; - } - else - { - throw std::runtime_error(fmt::format("unknown matcher pattern '{}' in URI pattern '{}'", Pattern, UriPattern)); - } + throw std::runtime_error(fmt::format("matcher pattern is empty in URI pattern '{}'", UriPattern)); } - } - if (!IsComplete) - { - throw std::runtime_error(fmt::format("unterminated matcher pattern in URI pattern '{}'", UriPattern)); - } - } - else - { - if (UriPattern[i] == '/') - { - throw std::runtime_error(fmt::format("unexpected '/' in literal segment of URI pattern '{}'", UriPattern)); - } - - int SegmentStart = i; - while (++i < RegexLen && UriPattern[i] != '/') - ; - - std::string_view Segment(&UriPattern[SegmentStart], (i - SegmentStart)); - int LiteralIndex = gsl::narrow_cast<int>(m_Literals.size()); - m_Literals.push_back(std::string(Segment)); - MatcherIndices.push_back(-1 - LiteralIndex); - } - - if (i < RegexLen && UriPattern[i] == '/') - { - ++i; // skip slash - } - } - - m_MatcherEndpoints.emplace_back(std::move(MatcherIndices), SupportedVerbs, std::move(HandlerFunc), UriPattern); - } -} - -std::string_view -HttpRouterRequest::GetCapture(uint32_t Index) const -{ - if (!m_CapturedSegments.empty()) - { - ZEN_ASSERT(Index < m_CapturedSegments.size()); - return m_CapturedSegments[Index]; - } - - ZEN_ASSERT(Index < m_Match.size()); - - const auto& Match = m_Match[Index]; - - return std::string_view(&*Match.first, Match.second - Match.first); -} - -bool -HttpRequestRouter::ProcessRegexSubstitutions(const char* Regex, StringBuilderBase& OutExpandedRegex) -{ - size_t RegexLen = strlen(Regex); - - bool HasRegex = false; - - std::vector<std::string> UnknownPatterns; - - for (size_t i = 0; i < RegexLen;) - { - bool matched = false; - - if (Regex[i] == '{' && ((i == 0) || (Regex[i - 1] != '\\'))) - { - // Might have a pattern reference - find closing brace - - for (size_t j = i + 1; j < RegexLen; ++j) - { - if (Regex[j] == '}') - { - std::string Pattern(&Regex[i + 1], j - i - 1); - - if (auto it = m_PatternMap.find(Pattern); it != m_PatternMap.end()) + std::string_view Pattern(&UriPattern[PatternStart], i - PatternStart); + if (auto it = m_MatcherNameMap.find(std::string(Pattern)); it != m_MatcherNameMap.end()) { - OutExpandedRegex.Append(it->second.c_str()); - HasRegex = true; + // It's a match + MatcherIndices.push_back(it->second); + IsComplete = true; + ++i; + break; } else { - UnknownPatterns.push_back(Pattern); + throw std::runtime_error(fmt::format("unknown matcher pattern '{}' in URI pattern '{}'", Pattern, UriPattern)); } - - // skip ahead - i = j + 1; - - matched = true; - - break; } } + if (!IsComplete) + { + throw std::runtime_error(fmt::format("unterminated matcher pattern in URI pattern '{}'", UriPattern)); + } } - - if (!matched) - { - OutExpandedRegex.Append(Regex[i++]); - } - } - - if (HasRegex) - { - if (UnknownPatterns.size() > 0) + else { - std::string UnknownList; - for (const auto& Pattern : UnknownPatterns) + if (UriPattern[i] == '/') { - if (!UnknownList.empty()) - { - UnknownList += ", "; - } - UnknownList += "'"; - UnknownList += Pattern; - UnknownList += "'"; + throw std::runtime_error(fmt::format("unexpected '/' in literal segment of URI pattern '{}'", UriPattern)); } - throw std::runtime_error(fmt::format("unknown pattern(s) {} in regex route '{}'", UnknownList, Regex)); + int SegmentStart = i; + while (++i < RegexLen && UriPattern[i] != '/') + ; + + std::string_view Segment(&UriPattern[SegmentStart], (i - SegmentStart)); + int LiteralIndex = gsl::narrow_cast<int>(m_Literals.size()); + m_Literals.push_back(std::string(Segment)); + MatcherIndices.push_back(-1 - LiteralIndex); } - return true; + if (i < RegexLen && UriPattern[i] == '/') + { + ++i; // skip slash + } } - return false; + m_MatcherEndpoints.emplace_back(std::move(MatcherIndices), SupportedVerbs, std::move(HandlerFunc), UriPattern); +} + +std::string_view +HttpRouterRequest::GetCapture(uint32_t Index) const +{ + ZEN_ASSERT(Index < m_CapturedSegments.size()); + return m_CapturedSegments[Index]; } bool @@ -901,8 +801,6 @@ HttpRequestRouter::HandleRequest(zen::HttpServerRequest& Request) std::string_view Uri = Request.RelativeUri(); HttpRouterRequest RouterRequest(Request); - // First try new-style matcher routes - for (const MatcherEndpoint& Handler : m_MatcherEndpoints) { if ((Handler.Verbs & Verb) == Verb) @@ -1000,28 +898,6 @@ HttpRequestRouter::HandleRequest(zen::HttpServerRequest& Request) } } - // Old-style regex routes - - for (const auto& Handler : m_RegexHandlers) - { - if ((Handler.Verbs & Verb) == Verb && regex_match(begin(Uri), end(Uri), RouterRequest.m_Match, Handler.RegEx)) - { -#if ZEN_WITH_OTEL - if (otel::Span* ActiveSpan = otel::Span::GetCurrentSpan()) - { - ExtendableStringBuilder<128> RoutePath; - RoutePath.Append(Request.Service().BaseUri()); - RoutePath.Append(Handler.Pattern); - ActiveSpan->AddAttribute("http.route"sv, RoutePath.ToView()); - } -#endif - - Handler.Handler(RouterRequest); - - return true; // Route matched - } - } - return false; // No route matched } @@ -1157,7 +1033,7 @@ CreateHttpServerClass(const std::string_view ServerClass, const HttpServerConfig ZEN_INFO("using asio HTTP server implementation") return CreateHttpAsioServer(AsioConfig { .ThreadCount = Config.ThreadCount, .ForceLoopback = Config.ForceLoopback, .IsDedicatedServer = Config.IsDedicatedServer, - .NoNetwork = Config.NoNetwork, .UnixSocketPath = Config.UnixSocketPath, + .NoNetwork = Config.NoNetwork, .UnixSocketPath = PathToUtf8(Config.UnixSocketPath), #if ZEN_USE_OPENSSL .HttpsPort = Config.HttpsPort, .CertFile = Config.CertFile, .KeyFile = Config.KeyFile, #endif @@ -1420,72 +1296,6 @@ TEST_CASE("http.common") virtual uint32_t ParseRequestId() const override { return 0; } }; - SUBCASE("router-regex") - { - bool HandledA = false; - bool HandledAA = false; - std::vector<std::string> Captures; - auto Reset = [&] { - Captures.clear(); - HandledA = HandledAA = false; - }; - - TestHttpService Service; - - HttpRequestRouter r; - r.AddPattern("a", "([[:alpha:]]+)"); - r.RegisterRoute( - "{a}", - [&](auto& Req) { - HandledA = true; - Captures = {std::string(Req.GetCapture(0))}; - }, - HttpVerb::kGet); - - r.RegisterRoute( - "{a}/{a}", - [&](auto& Req) { - HandledAA = true; - Captures = {std::string(Req.GetCapture(1)), std::string(Req.GetCapture(2))}; - }, - HttpVerb::kGet); - - { - Reset(); - TestHttpServerRequest req(Service, "abc"sv); - r.HandleRequest(req); - CHECK(HandledA); - CHECK(!HandledAA); - REQUIRE_EQ(Captures.size(), 1); - CHECK_EQ(Captures[0], "abc"sv); - } - - { - Reset(); - TestHttpServerRequest req{Service, "abc/def"sv}; - r.HandleRequest(req); - CHECK(!HandledA); - CHECK(HandledAA); - REQUIRE_EQ(Captures.size(), 2); - CHECK_EQ(Captures[0], "abc"sv); - CHECK_EQ(Captures[1], "def"sv); - } - - { - Reset(); - TestHttpServerRequest req{Service, "123"sv}; - r.HandleRequest(req); - CHECK(!HandledA); - } - - { - Reset(); - TestHttpServerRequest req{Service, "a123"sv}; - r.HandleRequest(req); - CHECK(!HandledA); - } - } - SUBCASE("router-matcher") { bool HandledA = false; diff --git a/src/zenhttp/include/zenhttp/cprutils.h b/src/zenhttp/include/zenhttp/cprutils.h index c252a5d99..3cfe652c5 100644 --- a/src/zenhttp/include/zenhttp/cprutils.h +++ b/src/zenhttp/include/zenhttp/cprutils.h @@ -2,17 +2,19 @@ #pragma once -#include <zencore/compactbinary.h> -#include <zencore/compactbinaryvalidation.h> -#include <zencore/iobuffer.h> -#include <zencore/string.h> -#include <zenhttp/formatters.h> -#include <zenhttp/httpclient.h> -#include <zenhttp/httpcommon.h> +#if ZEN_WITH_CPR + +# include <zencore/compactbinary.h> +# include <zencore/compactbinaryvalidation.h> +# include <zencore/iobuffer.h> +# include <zencore/string.h> +# include <zenhttp/formatters.h> +# include <zenhttp/httpclient.h> +# include <zenhttp/httpcommon.h> ZEN_THIRD_PARTY_INCLUDES_START -#include <cpr/response.h> -#include <fmt/format.h> +# include <cpr/response.h> +# include <fmt/format.h> ZEN_THIRD_PARTY_INCLUDES_END template<> @@ -92,3 +94,5 @@ struct fmt::formatter<cpr::Response> } } }; + +#endif // ZEN_WITH_CPR diff --git a/src/zenhttp/include/zenhttp/httpclient.h b/src/zenhttp/include/zenhttp/httpclient.h index 03c98af7e..e878c900f 100644 --- a/src/zenhttp/include/zenhttp/httpclient.h +++ b/src/zenhttp/include/zenhttp/httpclient.h @@ -10,6 +10,7 @@ #include <zencore/uid.h> #include <zenhttp/httpcommon.h> +#include <filesystem> #include <functional> #include <optional> #include <unordered_map> @@ -51,7 +52,9 @@ enum class HttpClientErrorCode : int enum class HttpClientBackend : uint8_t { kDefault, +#if ZEN_WITH_CPR kCpr, +#endif kCurl, }; @@ -91,7 +94,7 @@ struct HttpClientSettings /// Unix domain socket path. When non-empty, the client connects via this /// socket instead of TCP. BaseUri is still used for the Host header and URL. - std::string UnixSocketPath; + std::filesystem::path UnixSocketPath; /// Disable HTTP keep-alive by closing the connection after each request. /// Useful for testing per-connection overhead. @@ -174,11 +177,14 @@ class HttpClientBase; class HttpClient { public: - HttpClient(std::string_view BaseUri, - const HttpClientSettings& Connectionsettings = {}, - std::function<bool()>&& CheckIfAbortFunction = {}); + explicit HttpClient(std::string_view BaseUri, + const HttpClientSettings& Connectionsettings = {}, + std::function<bool()>&& CheckIfAbortFunction = {}); ~HttpClient(); + HttpClient(const HttpClient&) = delete; + HttpClient& operator=(const HttpClient&) = delete; + struct ErrorContext { HttpClientErrorCode ErrorCode; diff --git a/src/zenhttp/include/zenhttp/httpserver.h b/src/zenhttp/include/zenhttp/httpserver.h index 627e7921f..2a8b2ca94 100644 --- a/src/zenhttp/include/zenhttp/httpserver.h +++ b/src/zenhttp/include/zenhttp/httpserver.h @@ -15,11 +15,11 @@ #include <zentelemetry/stats.h> +#include <filesystem> #include <functional> #include <gsl/gsl-lite.hpp> #include <list> #include <map> -#include <regex> #include <span> #include <unordered_map> @@ -329,7 +329,7 @@ struct HttpServerConfig std::vector<HttpServerPluginConfig> PluginConfigs; bool ForceLoopback = false; unsigned int ThreadCount = 0; - std::string UnixSocketPath; // Unix domain socket path (empty = disabled, non-Windows only) + std::filesystem::path UnixSocketPath; // Unix domain socket path (empty = disabled) bool NoNetwork = false; // Disable TCP/HTTPS listeners; only accept connections via UnixSocketPath int HttpsPort = 0; // HTTPS listen port (0 = disabled, ASIO backend) std::string CertFile; // PEM certificate chain file path @@ -356,9 +356,8 @@ class HttpRouterRequest public: /** Get captured segment from matched URL * - * @param Index Index of captured segment to retrieve. Note that due to - * backwards compatibility with regex-based routes, this index is 1-based - * and index=0 is the full matched URL + * @param Index Index of captured segment to retrieve. Index 0 is the full + * matched URL, subsequent indices are the matched segments in order. * @return Returns string view of captured segment */ std::string_view GetCapture(uint32_t Index) const; @@ -371,11 +370,8 @@ private: HttpRouterRequest(const HttpRouterRequest&) = delete; HttpRouterRequest& operator=(const HttpRouterRequest&) = delete; - using MatchResults_t = std::match_results<std::string_view::const_iterator>; - HttpServerRequest& m_HttpRequest; - MatchResults_t m_Match; - std::vector<std::string_view> m_CapturedSegments; // for matcher-based routes + std::vector<std::string_view> m_CapturedSegments; friend class HttpRequestRouter; }; @@ -383,9 +379,7 @@ private: /** HTTP request router helper * * This helper class allows a service implementer to register one or more - * endpoints using pattern matching. We currently support a legacy regex-based - * matching system, but also a new matcher-function based system which is more - * efficient and should be used whenever possible. + * endpoints using pattern matching with matcher functions. * * This is intended to be initialized once only, there is no thread * safety so you can absolutely not add or remove endpoints once the handler @@ -404,13 +398,6 @@ public: typedef std::function<void(HttpRouterRequest&)> HandlerFunc_t; /** - * @brief Add pattern which can be referenced by name, commonly used for URL components - * @param Id String used to identify patterns for replacement - * @param Regex String which will replace the Id string in any registered URL paths - */ - void AddPattern(const char* Id, const char* Regex); - - /** * @brief Add matcher function which can be referenced by name, used for URL components * @param Id String used to identify matchers in endpoint specifications * @param Matcher Function which will be called to match the component @@ -420,8 +407,8 @@ public: /** * @brief Register an endpoint handler for the given route * @param Pattern Pattern used to match the handler to a request. This should - * only contain literal URI segments and pattern aliases registered - via AddPattern() or AddMatcher() + * only contain literal URI segments and matcher aliases registered + via AddMatcher() * @param HandlerFunc Handler function to call for any matching request * @param SupportedVerbs Supported HTTP verbs for this handler */ @@ -436,36 +423,6 @@ public: bool HandleRequest(zen::HttpServerRequest& Request); private: - bool ProcessRegexSubstitutions(const char* Regex, StringBuilderBase& ExpandedRegex); - - struct RegexEndpoint - { - RegexEndpoint(const char* Regex, HttpVerb SupportedVerbs, HandlerFunc_t&& Handler, const char* Pattern) - : RegEx(Regex, std::regex::icase | std::regex::ECMAScript) - , Verbs(SupportedVerbs) - , Handler(std::move(Handler)) - , Pattern(Pattern) - { - } - - ~RegexEndpoint() = default; - - std::regex RegEx; - HttpVerb Verbs; - HandlerFunc_t Handler; - const char* Pattern; - - private: - RegexEndpoint& operator=(const RegexEndpoint&) = delete; - RegexEndpoint(const RegexEndpoint&) = delete; - }; - - std::list<RegexEndpoint> m_RegexHandlers; - std::unordered_map<std::string, std::string> m_PatternMap; - - // New-style matcher endpoints. Should be preferred over regex endpoints where possible - // as it is considerably more efficient - struct MatcherEndpoint { MatcherEndpoint(std::vector<int>&& ComponentIndices, HttpVerb SupportedVerbs, HandlerFunc_t&& Handler, const char* Pattern) diff --git a/src/zenhttp/include/zenhttp/httpwsclient.h b/src/zenhttp/include/zenhttp/httpwsclient.h index 34d338b1d..2ca9b7ab1 100644 --- a/src/zenhttp/include/zenhttp/httpwsclient.h +++ b/src/zenhttp/include/zenhttp/httpwsclient.h @@ -46,7 +46,7 @@ struct HttpWsClientSettings /// Unix domain socket path. When non-empty, connects via this socket /// instead of TCP. The URL host is still used for the Host header. - std::string UnixSocketPath; + std::filesystem::path UnixSocketPath; }; /** diff --git a/src/zenhttp/servers/httpasio.cpp b/src/zenhttp/servers/httpasio.cpp index 643f33618..9f4875eaf 100644 --- a/src/zenhttp/servers/httpasio.cpp +++ b/src/zenhttp/servers/httpasio.cpp @@ -601,6 +601,7 @@ public: bool m_IsLocalMachineRequest; bool m_AllowZeroCopyFileSend = true; std::string m_RemoteAddress; + std::string m_DecodedUri; // Percent-decoded URI; m_Uri/m_UriWithExtension point into this std::unique_ptr<HttpResponse> m_Response; }; @@ -623,6 +624,7 @@ public: ~HttpResponse() = default; void SetAllowZeroCopyFileSend(bool Allow) { m_AllowZeroCopyFileSend = Allow; } + void SetKeepAlive(bool KeepAlive) { m_IsKeepAlive = KeepAlive; } /** * Initialize the response for sending a payload made up of multiple blobs @@ -780,8 +782,8 @@ public: return m_Headers; } - template<typename SocketType> - void SendResponse(SocketType& Socket, std::function<void(const asio::error_code& Ec, std::size_t ByteCount)>&& Token) + template<typename SocketType, typename Executor> + void SendResponse(SocketType& Socket, Executor& Strand, std::function<void(const asio::error_code& Ec, std::size_t ByteCount)>&& Token) { ZEN_ASSERT(m_State == State::kInitialized); @@ -791,11 +793,11 @@ public: m_SendCb = std::move(Token); m_State = State::kSending; - SendNextChunk(Socket); + SendNextChunk(Socket, Strand); } - template<typename SocketType> - void SendNextChunk(SocketType& Socket) + template<typename SocketType, typename Executor> + void SendNextChunk(SocketType& Socket, Executor& Strand) { ZEN_ASSERT(m_State == State::kSending); @@ -812,12 +814,12 @@ public: auto CompletionToken = [Self = this, Token = std::move(m_SendCb), TotalBytes = m_TotalBytesSent] { Token({}, TotalBytes); }; - asio::defer(Socket.get_executor(), std::move(CompletionToken)); + asio::defer(Strand, std::move(CompletionToken)); return; } - auto OnCompletion = [this, &Socket](const asio::error_code& Ec, std::size_t ByteCount) { + auto OnCompletion = asio::bind_executor(Strand, [this, &Socket, &Strand](const asio::error_code& Ec, std::size_t ByteCount) { ZEN_ASSERT(m_State == State::kSending); m_TotalBytesSent += ByteCount; @@ -828,9 +830,9 @@ public: } else { - SendNextChunk(Socket); + SendNextChunk(Socket, Strand); } - }; + }); const IoVec& Io = m_IoVecs[m_IoVecCursor++]; @@ -982,16 +984,14 @@ private: void CloseConnection(); void SendInlineResponse(uint32_t RequestNumber, std::string_view StatusLine, std::string_view Headers = {}, std::string_view Body = {}); - HttpAsioServerImpl& m_Server; - asio::streambuf m_RequestBuffer; - std::atomic<uint32_t> m_RequestCounter{0}; - uint32_t m_ConnectionId = 0; - Ref<IHttpPackageHandler> m_PackageHandler; - - RwLock m_ActiveResponsesLock; + HttpAsioServerImpl& m_Server; + std::unique_ptr<SocketType> m_Socket; + asio::strand<asio::any_io_executor> m_Strand; + asio::streambuf m_RequestBuffer; + uint32_t m_RequestCounter = 0; + uint32_t m_ConnectionId = 0; + Ref<IHttpPackageHandler> m_PackageHandler; std::deque<std::unique_ptr<HttpResponse>> m_ActiveResponses; - - std::unique_ptr<SocketType> m_Socket; }; std::atomic<uint32_t> g_ConnectionIdCounter{0}; @@ -999,8 +999,9 @@ std::atomic<uint32_t> g_ConnectionIdCounter{0}; template<typename SocketType> HttpServerConnectionT<SocketType>::HttpServerConnectionT(HttpAsioServerImpl& Server, std::unique_ptr<SocketType>&& Socket) : m_Server(Server) -, m_ConnectionId(g_ConnectionIdCounter.fetch_add(1)) , m_Socket(std::move(Socket)) +, m_Strand(asio::make_strand(m_Socket->get_executor())) +, m_ConnectionId(g_ConnectionIdCounter.fetch_add(1)) { ZEN_TRACE_VERBOSE("new connection #{}", m_ConnectionId); } @@ -1008,8 +1009,6 @@ HttpServerConnectionT<SocketType>::HttpServerConnectionT(HttpAsioServerImpl& Ser template<typename SocketType> HttpServerConnectionT<SocketType>::~HttpServerConnectionT() { - RwLock::ExclusiveLockScope _(m_ActiveResponsesLock); - ZEN_TRACE_VERBOSE("destroying connection #{}", m_ConnectionId); } @@ -1017,7 +1016,7 @@ template<typename SocketType> void HttpServerConnectionT<SocketType>::HandleNewRequest() { - EnqueueRead(); + asio::dispatch(m_Strand, [Conn = AsSharedPtr()] { Conn->EnqueueRead(); }); } template<typename SocketType> @@ -1058,7 +1057,9 @@ HttpServerConnectionT<SocketType>::EnqueueRead() asio::async_read(*m_Socket.get(), m_RequestBuffer, asio::transfer_at_least(1), - [Conn = AsSharedPtr()](const asio::error_code& Ec, std::size_t ByteCount) { Conn->OnDataReceived(Ec, ByteCount); }); + asio::bind_executor(m_Strand, [Conn = AsSharedPtr()](const asio::error_code& Ec, std::size_t ByteCount) { + Conn->OnDataReceived(Ec, ByteCount); + })); } template<typename SocketType> @@ -1091,7 +1092,7 @@ HttpServerConnectionT<SocketType>::OnDataReceived(const asio::error_code& Ec, [[ ZEN_TRACE_VERBOSE("on data received, connection: {}, request: {}, thread: {}, bytes: {}", m_ConnectionId, - m_RequestCounter.load(std::memory_order_relaxed), + m_RequestCounter, zen::GetCurrentThreadId(), NiceBytes(ByteCount)); @@ -1153,25 +1154,23 @@ HttpServerConnectionT<SocketType>::OnResponseDataSent(const asio::error_code& if (ResponseToPop) { - m_ActiveResponsesLock.WithExclusiveLock([&] { - // Once a response is sent we can release any referenced resources - // - // completion callbacks may be issued out-of-order so we need to - // remove the relevant entry from our active response list, it may - // not be the first - - if (auto It = find_if(begin(m_ActiveResponses), - end(m_ActiveResponses), - [ResponseToPop](const auto& Item) { return Item.get() == ResponseToPop; }); - It != end(m_ActiveResponses)) - { - m_ActiveResponses.erase(It); - } - else - { - ZEN_WARN("response not found"); - } - }); + // Once a response is sent we can release any referenced resources + // + // completion callbacks may be issued out-of-order so we need to + // remove the relevant entry from our active response list, it may + // not be the first + + if (auto It = find_if(begin(m_ActiveResponses), + end(m_ActiveResponses), + [ResponseToPop](const auto& Item) { return Item.get() == ResponseToPop; }); + It != end(m_ActiveResponses)) + { + m_ActiveResponses.erase(It); + } + else + { + ZEN_WARN("response not found"); + } } if (!m_RequestData.IsKeepAlive()) @@ -1234,9 +1233,11 @@ HttpServerConnectionT<SocketType>::SendInlineResponse(uint32_t RequestNumber asio::async_write( *m_Socket, Buffer, - [Conn = AsSharedPtr(), RequestNumber, Response = std::move(ResponseData)](const asio::error_code& Ec, std::size_t ByteCount) { - Conn->OnResponseDataSent(Ec, ByteCount, RequestNumber, /* ResponseToPop */ nullptr); - }); + asio::bind_executor( + m_Strand, + [Conn = AsSharedPtr(), RequestNumber, Response = std::move(ResponseData)](const asio::error_code& Ec, std::size_t ByteCount) { + Conn->OnResponseDataSent(Ec, ByteCount, RequestNumber, /* ResponseToPop */ nullptr); + })); } template<typename SocketType> @@ -1272,21 +1273,23 @@ HttpServerConnectionT<SocketType>::HandleRequest() asio::async_write( *m_Socket, asio::buffer(ResponseStr->data(), ResponseStr->size()), - [Conn = AsSharedPtr(), WsHandler, OwnedResponse = ResponseStr](const asio::error_code& Ec, std::size_t) { - if (Ec) - { - ZEN_WARN("WebSocket 101 send failed: {}", Ec.message()); - return; - } - - Conn->m_Server.m_HttpServer->OnWebSocketConnectionOpened(); - using WsConnType = WsAsioConnectionT<SocketType>; - Ref<WsConnType> WsConn(new WsConnType(std::move(Conn->m_Socket), *WsHandler, Conn->m_Server.m_HttpServer)); - Ref<WebSocketConnection> WsConnRef(WsConn.Get()); - - WsHandler->OnWebSocketOpen(std::move(WsConnRef)); - WsConn->Start(); - }); + asio::bind_executor( + m_Strand, + [Conn = AsSharedPtr(), WsHandler, OwnedResponse = ResponseStr](const asio::error_code& Ec, std::size_t) { + if (Ec) + { + ZEN_WARN("WebSocket 101 send failed: {}", Ec.message()); + return; + } + + Conn->m_Server.m_HttpServer->OnWebSocketConnectionOpened(); + using WsConnType = WsAsioConnectionT<SocketType>; + Ref<WsConnType> WsConn(new WsConnType(std::move(Conn->m_Socket), *WsHandler, Conn->m_Server.m_HttpServer)); + Ref<WebSocketConnection> WsConnRef(WsConn.Get()); + + WsHandler->OnWebSocketOpen(std::move(WsConnRef)); + WsConn->Start(); + })); m_RequestState = RequestState::kDone; return; @@ -1312,7 +1315,7 @@ HttpServerConnectionT<SocketType>::HandleRequest() m_RequestState = RequestState::kWriting; } - const uint32_t RequestNumber = m_RequestCounter.fetch_add(1); + const uint32_t RequestNumber = m_RequestCounter++; if (HttpService* Service = m_Server.RouteRequest(m_RequestData.Url())) { @@ -1444,31 +1447,34 @@ HttpServerConnectionT<SocketType>::HandleRequest() { ZEN_TRACE_CPU("asio::async_write"); - std::string_view Headers = Response->GetHeaders(); + HttpResponse* ResponseRaw = Response.get(); + m_ActiveResponses.push_back(std::move(Response)); + + std::string_view Headers = ResponseRaw->GetHeaders(); std::vector<asio::const_buffer> AsioBuffers; AsioBuffers.push_back(asio::const_buffer(Headers.data(), Headers.size())); - asio::async_write(*m_Socket.get(), - AsioBuffers, - asio::transfer_all(), - [Conn = AsSharedPtr(), RequestNumber](const asio::error_code& Ec, std::size_t ByteCount) { - Conn->OnResponseDataSent(Ec, ByteCount, RequestNumber, /* ResponseToPop */ nullptr); - }); + asio::async_write( + *m_Socket.get(), + AsioBuffers, + asio::transfer_all(), + asio::bind_executor( + m_Strand, + [Conn = AsSharedPtr(), ResponseRaw, RequestNumber](const asio::error_code& Ec, std::size_t ByteCount) { + Conn->OnResponseDataSent(Ec, ByteCount, RequestNumber, /* ResponseToPop */ ResponseRaw); + })); } else { ZEN_TRACE_CPU("asio::async_write"); HttpResponse* ResponseRaw = Response.get(); - - m_ActiveResponsesLock.WithExclusiveLock([&] { - // Keep referenced resources alive - m_ActiveResponses.push_back(std::move(Response)); - }); + m_ActiveResponses.push_back(std::move(Response)); ResponseRaw->SendResponse( *m_Socket, + m_Strand, [Conn = AsSharedPtr(), ResponseRaw, RequestNumber](const asio::error_code& Ec, std::size_t ByteCount) { Conn->OnResponseDataSent(Ec, ByteCount, RequestNumber, /* ResponseToPop */ ResponseRaw); }); @@ -1982,11 +1988,24 @@ HttpAsioServerRequest::HttpAsioServerRequest(HttpRequestParser& Request, { const int PrefixLength = Service.UriPrefixLength(); - std::string_view Uri = Request.Url(); - Uri.remove_prefix(std::min(PrefixLength, static_cast<int>(Uri.size()))); - m_Uri = Uri; - m_UriWithExtension = Uri; - m_QueryString = Request.QueryString(); + std::string_view RawUri = Request.Url(); + RawUri.remove_prefix(std::min(PrefixLength, static_cast<int>(RawUri.size()))); + + // Percent-decode the URI path so handlers see the same decoded paths regardless + // of whether the ASIO or http.sys backend is used (http.sys pre-decodes via CookedUrl). + // Skip the allocation when there is nothing to decode (common case). + if (RawUri.find('%') != std::string_view::npos) + { + m_DecodedUri = Decode(RawUri); + m_Uri = m_DecodedUri; + m_UriWithExtension = m_DecodedUri; + } + else + { + m_Uri = RawUri; + m_UriWithExtension = RawUri; + } + m_QueryString = Request.QueryString(); m_Verb = Request.RequestVerb(); m_ContentLength = Request.Body().Size(); @@ -2083,6 +2102,7 @@ HttpAsioServerRequest::WriteResponse(HttpResponseCode ResponseCode) m_Response.reset(new HttpResponse(HttpContentType::kBinary, m_RequestNumber)); m_Response->SetAllowZeroCopyFileSend(m_AllowZeroCopyFileSend); + m_Response->SetKeepAlive(m_Request.IsKeepAlive()); std::array<IoBuffer, 0> Empty; m_Response->InitializeForPayload((uint16_t)ResponseCode, Empty); @@ -2097,6 +2117,7 @@ HttpAsioServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentT m_Response.reset(new HttpResponse(ContentType, m_RequestNumber)); m_Response->SetAllowZeroCopyFileSend(m_AllowZeroCopyFileSend); + m_Response->SetKeepAlive(m_Request.IsKeepAlive()); m_Response->InitializeForPayload((uint16_t)ResponseCode, Blobs); } @@ -2108,6 +2129,7 @@ HttpAsioServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentT ZEN_ASSERT(!m_Response); m_Response.reset(new HttpResponse(ContentType, m_RequestNumber)); m_Response->SetAllowZeroCopyFileSend(m_AllowZeroCopyFileSend); + m_Response->SetKeepAlive(m_Request.IsKeepAlive()); IoBuffer MessageBuffer(IoBuffer::Wrap, ResponseString.data(), ResponseString.size()); std::array<IoBuffer, 1> SingleBufferList({MessageBuffer}); diff --git a/src/zenhttp/servers/httpplugin.cpp b/src/zenhttp/servers/httpplugin.cpp index 4bf8c61bb..31b0315d4 100644 --- a/src/zenhttp/servers/httpplugin.cpp +++ b/src/zenhttp/servers/httpplugin.cpp @@ -147,7 +147,7 @@ public: HttpPluginServerRequest& operator=(const HttpPluginServerRequest&) = delete; // As this is plugin transport connection used for specialized connections we assume it is not a machine local connection - virtual bool IsLocalMachineRequest() const /* override*/ { return false; } + virtual bool IsLocalMachineRequest() const override { return false; } virtual std::string_view GetAuthorizationHeader() const override; virtual Oid ParseSessionId() const override; virtual uint32_t ParseRequestId() const override; diff --git a/src/zenhttp/servers/httpsys.cpp b/src/zenhttp/servers/httpsys.cpp index 4d6a53696..f8fb1c9be 100644 --- a/src/zenhttp/servers/httpsys.cpp +++ b/src/zenhttp/servers/httpsys.cpp @@ -1173,7 +1173,7 @@ HttpSysServer::RegisterHttpUrls(int BasePort) { Result = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, WildcardUrlPath.c_str(), HTTP_URL_CONTEXT(0), 0); - if ((Result == ERROR_SHARING_VIOLATION)) + if (Result == ERROR_SHARING_VIOLATION) { ZEN_INFO("Desired port {} is in use (HttpAddUrlToUrlGroup returned: {}), retrying", EffectivePort, diff --git a/src/zenhttp/xmake.lua b/src/zenhttp/xmake.lua index 9b461662e..b4c65ea96 100644 --- a/src/zenhttp/xmake.lua +++ b/src/zenhttp/xmake.lua @@ -8,7 +8,12 @@ target('zenhttp') add_files("servers/httpsys.cpp", {unity_ignored=true}) add_files("servers/wshttpsys.cpp", {unity_ignored=true}) add_includedirs("include", {public=true}) - add_deps("zencore", "zentelemetry", "transport-sdk", "asio", "cpr") + add_deps("zencore", "zentelemetry", "transport-sdk", "asio") + if has_config("zencpr") then + add_deps("cpr") + else + remove_files("clients/httpclientcpr.cpp") + end add_packages("http_parser", "json11") add_options("httpsys") diff --git a/src/zenserver/compute/computeserver.cpp b/src/zenserver/compute/computeserver.cpp index a02ca7be3..0d8550c5b 100644 --- a/src/zenserver/compute/computeserver.cpp +++ b/src/zenserver/compute/computeserver.cpp @@ -865,7 +865,7 @@ ZenComputeServer::Run() ExtendableStringBuilder<256> BuildOptions; GetBuildOptions(BuildOptions, '\n'); - ZEN_INFO("Build options ({}/{}):\n{}", GetOperatingSystemName(), GetCpuName(), BuildOptions); + ZEN_INFO("Build options ({}/{}, {}):\n{}", GetOperatingSystemName(), GetCpuName(), GetCompilerName(), BuildOptions); } ZEN_INFO(ZEN_APP_NAME " now running as COMPUTE (pid: {})", GetCurrentProcessId()); diff --git a/src/zenserver/config/config.cpp b/src/zenserver/config/config.cpp index c550b174c..60ae93853 100644 --- a/src/zenserver/config/config.cpp +++ b/src/zenserver/config/config.cpp @@ -201,6 +201,7 @@ struct ZenServerCmdLineOptions std::string DataDir; std::string BaseSnapshotDir; std::string SecurityConfigPath; + std::string UnixSocketPath; std::string PortStr; ZenLoggingCmdLineOptions LoggingOptions; @@ -320,7 +321,7 @@ ZenServerCmdLineOptions::AddCliOptions(cxxopts::Options& options, ZenServerConfi "", "unix-socket", "Unix domain socket path to listen on (in addition to TCP)", - cxxopts::value<std::string>(ServerOptions.HttpConfig.UnixSocketPath), + cxxopts::value<std::string>(UnixSocketPath), "<path>"); options.add_option("network", @@ -413,7 +414,7 @@ ZenServerCmdLineOptions::AddCliOptions(cxxopts::Options& options, ZenServerConfi "", "httpclient", "Select HTTP client implementation (e.g. 'curl', 'cpr')", - cxxopts::value<std::string>(ServerOptions.HttpClient.Backend)->default_value("cpr"), + cxxopts::value<std::string>(ServerOptions.HttpClient.Backend)->default_value("curl"), "<http client>"); options.add_option("network", @@ -480,6 +481,11 @@ ZenServerCmdLineOptions::ApplyOptions(cxxopts::Options& options, ZenServerConfig ServerOptions.BaseSnapshotDir = MakeSafeAbsolutePath(BaseSnapshotDir); ServerOptions.SecurityConfigPath = MakeSafeAbsolutePath(SecurityConfigPath); + if (!UnixSocketPath.empty()) + { + ServerOptions.HttpConfig.UnixSocketPath = MakeSafeAbsolutePath(UnixSocketPath); + } + if (PortStr != "auto") { int Port = 0; diff --git a/src/zenserver/frontend/html.zip b/src/zenserver/frontend/html.zip Binary files differdeleted file mode 100644 index 58778a592..000000000 --- a/src/zenserver/frontend/html.zip +++ /dev/null diff --git a/src/zenserver/hub/zenhubserver.cpp b/src/zenserver/hub/zenhubserver.cpp index 696991403..b0ae0a8b1 100644 --- a/src/zenserver/hub/zenhubserver.cpp +++ b/src/zenserver/hub/zenhubserver.cpp @@ -337,7 +337,7 @@ ZenHubServer::Run() ExtendableStringBuilder<256> BuildOptions; GetBuildOptions(BuildOptions, '\n'); - ZEN_INFO("Build options ({}/{}):\n{}", GetOperatingSystemName(), GetCpuName(), BuildOptions); + ZEN_INFO("Build options ({}/{}, {}):\n{}", GetOperatingSystemName(), GetCpuName(), GetCompilerName(), BuildOptions); } ZEN_INFO(ZEN_APP_NAME " now running as HUB (pid: {})", GetCurrentProcessId()); diff --git a/src/zenserver/proxy/zenproxyserver.cpp b/src/zenserver/proxy/zenproxyserver.cpp index acfdad45f..c768e940a 100644 --- a/src/zenserver/proxy/zenproxyserver.cpp +++ b/src/zenserver/proxy/zenproxyserver.cpp @@ -359,7 +359,7 @@ ZenProxyServer::Run() ExtendableStringBuilder<256> BuildOptions; GetBuildOptions(BuildOptions, '\n'); - ZEN_INFO("Build options ({}/{}):\n{}", GetOperatingSystemName(), GetCpuName(), BuildOptions); + ZEN_INFO("Build options ({}/{}, {}):\n{}", GetOperatingSystemName(), GetCpuName(), GetCompilerName(), BuildOptions); } ZEN_INFO(ZEN_APP_NAME " now running as PROXY (pid: {})", GetCurrentProcessId()); diff --git a/src/zenserver/storage/objectstore/objectstore.cpp b/src/zenserver/storage/objectstore/objectstore.cpp index 052c3d630..e347e2dfe 100644 --- a/src/zenserver/storage/objectstore/objectstore.cpp +++ b/src/zenserver/storage/objectstore/objectstore.cpp @@ -271,7 +271,7 @@ HttpObjectStoreService::Inititalize() CreateDirectories(BucketsPath); } - static constexpr AsciiSet ValidPathCharactersSet{"abcdefghijklmnopqrstuvwxyz0123456789/_.,;$~{}+-[]%()]ABCDEFGHIJKLMNOPQRSTUVWXYZ"}; + static constexpr AsciiSet ValidPathCharactersSet{"abcdefghijklmnopqrstuvwxyz0123456789/_.,;$~{}+-[]() ABCDEFGHIJKLMNOPQRSTUVWXYZ"}; static constexpr AsciiSet ValidBucketCharactersSet{"abcdefghijklmnopqrstuvwxyz0123456789-_.ABCDEFGHIJKLMNOPQRSTUVWXYZ"}; m_Router.AddMatcher("path", @@ -292,10 +292,9 @@ HttpObjectStoreService::Inititalize() m_Router.RegisterRoute( "bucket/{path}", [this](zen::HttpRouterRequest& Request) { - const std::string_view EncodedPath = Request.GetCapture(1); - const std::string Path = Request.ServerRequest().Decode(EncodedPath); - const auto Sep = Path.find_last_of('.'); - const bool IsObject = Sep != std::string::npos && Path.size() - Sep > 0; + const std::string_view Path = Request.GetCapture(1); + const auto Sep = Path.find_last_of('.'); + const bool IsObject = Sep != std::string_view::npos && Path.size() - Sep > 0; if (IsObject) { @@ -378,7 +377,7 @@ HttpObjectStoreService::ListBucket(zen::HttpRouterRequest& Request, const std::s const auto QueryParms = Request.ServerRequest().GetQueryParams(); if (auto PrefixParam = QueryParms.GetValue("prefix"); PrefixParam.empty() == false) { - BucketPrefix = PrefixParam; + BucketPrefix = HttpServerRequest::Decode(PrefixParam); } } BucketPrefix.erase(0, BucketPrefix.find_first_not_of('/')); diff --git a/src/zenserver/storage/vfs/vfsservice.cpp b/src/zenserver/storage/vfs/vfsservice.cpp index 863ec348a..f418c4131 100644 --- a/src/zenserver/storage/vfs/vfsservice.cpp +++ b/src/zenserver/storage/vfs/vfsservice.cpp @@ -62,7 +62,7 @@ GetContentAsCbObject(HttpServerRequest& HttpReq, CbObject& Cb) // echo {"method": "mount", "params": {"path": "d:\\VFS_ROOT"}} | curl.exe http://localhost:8558/vfs --data-binary @- // echo {"method": "unmount"} | curl.exe http://localhost:8558/vfs --data-binary @- -VfsService::VfsService(HttpStatusService& StatusService, VfsServiceImpl* ServiceImpl) : m_StatusService(StatusService), m_Impl(ServiceImpl) +VfsService::VfsService(HttpStatusService& StatusService, VfsServiceImpl* ServiceImpl) : m_Impl(ServiceImpl), m_StatusService(StatusService) { m_Router.RegisterRoute( "info", diff --git a/src/zenserver/storage/zenstorageserver.cpp b/src/zenserver/storage/zenstorageserver.cpp index 77588bd6c..bba5e0a61 100644 --- a/src/zenserver/storage/zenstorageserver.cpp +++ b/src/zenserver/storage/zenstorageserver.cpp @@ -720,7 +720,7 @@ ZenStorageServer::Run() ExtendableStringBuilder<256> BuildOptions; GetBuildOptions(BuildOptions, '\n'); - ZEN_INFO("Build options ({}/{}):\n{}", GetOperatingSystemName(), GetCpuName(), BuildOptions); + ZEN_INFO("Build options ({}/{}, {}):\n{}", GetOperatingSystemName(), GetCpuName(), GetCompilerName(), BuildOptions); } ZEN_INFO(ZEN_APP_NAME " now running (pid: {})", GetCurrentProcessId()); diff --git a/src/zenserver/targetver.h b/src/zenserver/targetver.h index d432d6993..4805141de 100644 --- a/src/zenserver/targetver.h +++ b/src/zenserver/targetver.h @@ -7,4 +7,4 @@ // If you wish to build your application for a previous Windows platform, include WinSDKVer.h and // set the _WIN32_WINNT macro to the platform you wish to support before including SDKDDKVer.h. -#include <SDKDDKVer.h> +#include <sdkddkver.h> diff --git a/src/zenserver/xmake.lua b/src/zenserver/xmake.lua index 3cfaa956d..6b29dadfb 100644 --- a/src/zenserver/xmake.lua +++ b/src/zenserver/xmake.lua @@ -12,12 +12,14 @@ target("zenserver") "zenremotestore", "zenstore", "zentelemetry", - "zenutil", - "zenvfs") + "zenutil") + if is_plat("windows") then + add_deps("zenvfs") + end add_headerfiles("**.h") add_rules("utils.bin2c", {extensions = {".zip"}}) add_files("**.cpp") - add_files("frontend/*.zip") + add_files("frontend/html.zip") add_files("zenserver.cpp", {unity_ignored = true }) if is_plat("linux") and not (get_config("toolchain") or ""):find("clang") then @@ -58,13 +60,15 @@ target("zenserver") end if is_plat("windows") then - add_ldflags("/subsystem:console,5.02") - add_ldflags("/MANIFEST:EMBED") - add_ldflags("/LTCG") + add_ldflags("/subsystem:console,5.02", {force = true}) + add_ldflags("/MANIFEST:EMBED", {force = true}) + if not (get_config("toolchain") or ""):find("clang") then + add_ldflags("/LTCG") + end add_files("zenserver.rc") add_cxxflags("/bigobj") add_links("delayimp", "projectedfslib") - add_ldflags("/delayload:ProjectedFSLib.dll") + add_ldflags("/delayload:ProjectedFSLib.dll", {force = true}) else remove_files("windows/**") end @@ -78,7 +82,45 @@ target("zenserver") add_ldflags("-framework SystemConfiguration") end - -- to work around some unfortunate Ctrl-C behaviour on Linux/Mac due to + on_load(function(target) + local html_dir = path.join(os.projectdir(), "src/zenserver/frontend/html") + local zip_path = path.join(os.projectdir(), "src/zenserver/frontend/html.zip") + + -- Check if zip needs regeneration + local need_update = not os.isfile(zip_path) + if not need_update then + local zip_mtime = os.mtime(zip_path) + for _, file in ipairs(os.files(path.join(html_dir, "**"))) do + if os.mtime(file) > zip_mtime then + need_update = true + break + end + end + end + + if need_update then + print("Regenerating frontend zip...") + os.tryrm(zip_path) + + import("detect.tools.find_7z") + local cmd_7z = find_7z() + if cmd_7z then + os.execv(cmd_7z, {"a", "-mx0", zip_path, path.join(html_dir, ".")}) + else + import("detect.tools.find_zip") + local zip_cmd = find_zip() + if zip_cmd then + local oldir = os.cd(html_dir) + os.execv(zip_cmd, {"-r", "-0", zip_path, "."}) + os.cd(oldir) + else + raise("Unable to find a suitable zip tool (need 7z or zip)") + end + end + end + end) + + -- to work around some unfortunate Ctrl-C behaviour on Linux/Mac due to -- our use of setsid() at startup we pass in `--no-detach` to zenserver -- ensure that it recieves signals when the user requests termination on_run(function(target) diff --git a/src/zenserver/zenserver.cpp b/src/zenserver/zenserver.cpp index 8283f0cbe..6760e0372 100644 --- a/src/zenserver/zenserver.cpp +++ b/src/zenserver/zenserver.cpp @@ -201,6 +201,9 @@ ZenServerBase::Initialize(const ZenServerConfig& ServerOptions, ZenServerState:: std::chrono::system_clock::now().time_since_epoch()).count(), .BuildOptions = { {"ZEN_ADDRESS_SANITIZER", ZEN_ADDRESS_SANITIZER != 0}, + {"ZEN_THREAD_SANITIZER", ZEN_THREAD_SANITIZER != 0}, + {"ZEN_MEMORY_SANITIZER", ZEN_MEMORY_SANITIZER != 0}, + {"ZEN_LEAK_SANITIZER", ZEN_LEAK_SANITIZER != 0}, {"ZEN_USE_SENTRY", ZEN_USE_SENTRY != 0}, {"ZEN_WITH_TESTS", ZEN_WITH_TESTS != 0}, {"ZEN_USE_MIMALLOC", ZEN_USE_MIMALLOC != 0}, @@ -251,6 +254,12 @@ ZenServerBase::GetBuildOptions(StringBuilderBase& OutOptions, char Separator) co OutOptions << "ZEN_ADDRESS_SANITIZER=" << (ZEN_ADDRESS_SANITIZER ? "1" : "0"); OutOptions << Separator; + OutOptions << "ZEN_THREAD_SANITIZER=" << (ZEN_THREAD_SANITIZER ? "1" : "0"); + OutOptions << Separator; + OutOptions << "ZEN_MEMORY_SANITIZER=" << (ZEN_MEMORY_SANITIZER ? "1" : "0"); + OutOptions << Separator; + OutOptions << "ZEN_LEAK_SANITIZER=" << (ZEN_LEAK_SANITIZER ? "1" : "0"); + OutOptions << Separator; OutOptions << "ZEN_USE_SENTRY=" << (ZEN_USE_SENTRY ? "1" : "0"); OutOptions << Separator; OutOptions << "ZEN_WITH_TESTS=" << (ZEN_WITH_TESTS ? "1" : "0"); @@ -726,6 +735,20 @@ ZenServerMain::Run() Entry = ServerState.Register(m_ServerOptions.BasePort); + // Publish per-instance extended info (e.g. UDS path) via a small shared memory + // section keyed by SessionId so clients can discover it during Snapshot() enumeration. + { + InstanceInfoData InstanceData; + InstanceData.UnixSocketPath = m_ServerOptions.HttpConfig.UnixSocketPath; + m_InstanceInfo.Create(GetSessionId(), InstanceData); + Entry->SignalHasInstanceInfo(); + } + + if (m_ServerOptions.HttpConfig.NoNetwork) + { + Entry->SignalNoNetwork(); + } + if (m_ServerOptions.OwnerPid) { // We are adding a sponsor process to our own entry, can't wait for pick since the code is not run until later @@ -786,7 +809,8 @@ ZenServerMain::MakeLockData(bool IsReady) .EffectiveListenPort = gsl::narrow<uint16_t>(m_ServerOptions.BasePort), .Ready = IsReady, .DataDir = m_ServerOptions.DataDir, - .ExecutablePath = GetRunningExecutablePath()}); + .ExecutablePath = GetRunningExecutablePath(), + .UnixSocketPath = m_ServerOptions.HttpConfig.UnixSocketPath}); }; } // namespace zen diff --git a/src/zenserver/zenserver.h b/src/zenserver/zenserver.h index 374184aa9..830f36e54 100644 --- a/src/zenserver/zenserver.h +++ b/src/zenserver/zenserver.h @@ -148,8 +148,9 @@ public: ZenServerMain& operator=(const ZenServerMain&) = delete; protected: - ZenServerConfig& m_ServerOptions; - LockFile m_LockFile; + ZenServerConfig& m_ServerOptions; + LockFile m_LockFile; + ZenServerInstanceInfo m_InstanceInfo; virtual void InitializeLogging(); virtual void DoRun(ZenServerState::ZenServerEntry* Entry) = 0; diff --git a/src/zenserver/zenserver.rc b/src/zenserver/zenserver.rc index f353bd9cc..abe1acf71 100644 --- a/src/zenserver/zenserver.rc +++ b/src/zenserver/zenserver.rc @@ -28,7 +28,7 @@ LANGUAGE LANG_ENGLISH, SUBLANG_ENGLISH_US // Icon with lowest ID value placed first to ensure application icon // remains consistent on all systems. -IDI_ICON1 ICON "..\\zen.ico" +IDI_ICON1 ICON "../zen.ico" #endif // English (United States) resources ///////////////////////////////////////////////////////////////////////////// diff --git a/src/zenstore/gc.cpp b/src/zenstore/gc.cpp index b3450b805..f3edf804d 100644 --- a/src/zenstore/gc.cpp +++ b/src/zenstore/gc.cpp @@ -1776,11 +1776,13 @@ GcScheduler::Initialize(const GcSchedulerConfig& Config) m_LastGcTime = GcClock::TimePoint(GcClock::Duration(SchedulerState["LastGcTime"sv].AsInt64())); m_LastGcExpireTime = GcClock::TimePoint(GcClock::Duration(SchedulerState["LastGcExpireTime"].AsInt64(GcClock::Duration::min().count()))); - if (m_LastGcTime + m_Config.Interval < GcClock::Now()) + if (m_LastGcTime > GcClock::Now() || m_LastGcTime + m_Config.Interval < GcClock::Now()) { - // TODO: Trigger GC? + // Reset if the stored timestamp is in the future (e.g. clock resolution mismatch + // between the build that wrote gc_state and this build) or too far in the past. m_LastGcTime = GcClock::Now(); m_LastLightweightGcTime = m_LastGcTime; + m_LastGcExpireTime = GcClock::TimePoint::min(); } m_AttachmentPassIndex = SchedulerState["AttachmentPassIndex"sv].AsUInt8(); } @@ -2084,6 +2086,10 @@ GcScheduler::GetState() const { Result.RemainingTimeUntilFullGc = std::chrono::seconds::zero(); } + else if (Result.RemainingTimeUntilFullGc > Result.Config.Interval) + { + Result.RemainingTimeUntilFullGc = Result.Config.Interval; + } Result.RemainingTimeUntilLightweightGc = Result.Config.LightweightInterval.count() == 0 @@ -2094,6 +2100,10 @@ GcScheduler::GetState() const { Result.RemainingTimeUntilLightweightGc = std::chrono::seconds::zero(); } + else if (Result.RemainingTimeUntilLightweightGc > Result.Config.LightweightInterval) + { + Result.RemainingTimeUntilLightweightGc = Result.Config.LightweightInterval; + } } return Result; @@ -2418,6 +2428,10 @@ GcScheduler::SchedulerThread() { RemainingTimeUntilGc = std::chrono::seconds::zero(); } + else if (RemainingTimeUntilGc > GcInterval) + { + RemainingTimeUntilGc = GcInterval; + } std::chrono::seconds RemainingTimeUntilLightweightGc = LightweightGcInterval.count() == 0 ? std::chrono::seconds::max() @@ -2428,6 +2442,10 @@ GcScheduler::SchedulerThread() { RemainingTimeUntilLightweightGc = std::chrono::seconds::zero(); } + else if (RemainingTimeUntilLightweightGc > LightweightGcInterval) + { + RemainingTimeUntilLightweightGc = LightweightGcInterval; + } // Don't schedule a lightweight GC if a full GC is // due quite soon anyway diff --git a/src/zenstore/xmake.lua b/src/zenstore/xmake.lua index ea8155e94..94c2b51ca 100644 --- a/src/zenstore/xmake.lua +++ b/src/zenstore/xmake.lua @@ -6,6 +6,11 @@ target('zenstore') add_headerfiles("**.h") add_files("**.cpp") add_includedirs("include", {public=true}) - add_deps("zencore", "zentelemetry", "zenutil", "zenvfs") + add_deps("zencore", "zentelemetry", "zenutil") + if is_plat("windows") then + add_deps("zenvfs") + else + add_includedirs("$(projectdir)/src/zenvfs/include", {public=true}) + end add_deps("robin-map") add_packages("eastl", {public=true}); diff --git a/src/zenutil/include/zenutil/zenserverprocess.h b/src/zenutil/include/zenutil/zenserverprocess.h index 1b8750628..2f76f0d6c 100644 --- a/src/zenutil/include/zenutil/zenserverprocess.h +++ b/src/zenutil/include/zenutil/zenserverprocess.h @@ -224,8 +224,10 @@ public: enum class FlagsEnum : uint16_t { - kShutdownPlease = 1 << 0, - kIsReady = 1 << 1, + kShutdownPlease = 1 << 0, + kIsReady = 1 << 1, + kHasInstanceInfo = 1 << 2, + kNoNetwork = 1 << 3, }; FRIEND_ENUM_CLASS_FLAGS(FlagsEnum); @@ -236,6 +238,10 @@ public: bool IsShutdownRequested() const; void SignalReady(); bool IsReady() const; + void SignalHasInstanceInfo(); + bool HasInstanceInfo() const; + void SignalNoNetwork(); + bool IsNoNetwork() const; bool AddSponsorProcess(uint32_t Pid, uint64_t Timeout = 0); }; @@ -258,6 +264,51 @@ private: bool m_IsReadOnly = true; }; +/** Per-instance extended data published via a small shared memory section keyed by SessionId. + + Servers create a writable section; clients open it read-only during Snapshot() + enumeration to discover fields that don't fit in the fixed-size ZenServerEntry + (e.g. Unix domain socket path). + + SessionId is preferred over PID for naming because it is unique per server + instance lifetime, avoiding issues with PID reuse on crash/restart. + */ + +struct InstanceInfoData +{ + std::filesystem::path UnixSocketPath; + // Extensible: add more per-instance fields here in the future +}; + +class ZenServerInstanceInfo +{ +public: + ZenServerInstanceInfo(); + ~ZenServerInstanceInfo(); + + ZenServerInstanceInfo(const ZenServerInstanceInfo&) = delete; + ZenServerInstanceInfo& operator=(const ZenServerInstanceInfo&) = delete; + + /// Server-side: create read-write, populate with data + void Create(const Oid& SessionId, const InstanceInfoData& Data); + + /// Client-side: open read-only by SessionId, returns false if not found + [[nodiscard]] bool OpenReadOnly(const Oid& SessionId); + + /// Read the data (valid after Create or successful OpenReadOnly) + [[nodiscard]] InstanceInfoData Read() const; + + bool IsValid() const { return m_Data != nullptr; } + +private: + static std::string MakeName(const Oid& SessionId); + + void* m_hMapFile = nullptr; + uint8_t* m_Data = nullptr; + bool m_IsOwner = false; + Oid m_SessionId; // for POSIX cleanup (shm_unlink) +}; + struct LockFileInfo { int32_t Pid; @@ -266,6 +317,7 @@ struct LockFileInfo bool Ready; std::filesystem::path DataDir; std::filesystem::path ExecutablePath; + std::filesystem::path UnixSocketPath; }; CbObject MakeLockFilePayload(const LockFileInfo& Info); diff --git a/src/zenutil/workerpools.cpp b/src/zenutil/workerpools.cpp index 1bab39b2a..25f961f77 100644 --- a/src/zenutil/workerpools.cpp +++ b/src/zenutil/workerpools.cpp @@ -25,9 +25,9 @@ namespace { struct WorkerPool { - std::unique_ptr<WorkerThreadPool> Pool; const int TreadCount; const std::string_view Name; + std::unique_ptr<WorkerThreadPool> Pool; }; WorkerPool BurstLargeWorkerPool = {.TreadCount = LargeWorkerThreadPoolTreadCount, .Name = "large"}; diff --git a/src/zenutil/zenserverprocess.cpp b/src/zenutil/zenserverprocess.cpp index b9c50be4f..ac614f779 100644 --- a/src/zenutil/zenserverprocess.cpp +++ b/src/zenutil/zenserverprocess.cpp @@ -449,6 +449,30 @@ ZenServerState::ZenServerEntry::IsReady() const return (Flags.load() & static_cast<uint16_t>(FlagsEnum::kIsReady)) != 0; } +void +ZenServerState::ZenServerEntry::SignalHasInstanceInfo() +{ + Flags |= uint16_t(FlagsEnum::kHasInstanceInfo); +} + +bool +ZenServerState::ZenServerEntry::HasInstanceInfo() const +{ + return (Flags.load() & static_cast<uint16_t>(FlagsEnum::kHasInstanceInfo)) != 0; +} + +void +ZenServerState::ZenServerEntry::SignalNoNetwork() +{ + Flags |= uint16_t(FlagsEnum::kNoNetwork); +} + +bool +ZenServerState::ZenServerEntry::IsNoNetwork() const +{ + return (Flags.load() & static_cast<uint16_t>(FlagsEnum::kNoNetwork)) != 0; +} + bool ZenServerState::ZenServerEntry::AddSponsorProcess(uint32_t PidToAdd, uint64_t Timeout) { @@ -492,6 +516,222 @@ ZenServerState::ZenServerEntry::AddSponsorProcess(uint32_t PidToAdd, uint64_t Ti } ////////////////////////////////////////////////////////////////////////// +// ZenServerInstanceInfo +////////////////////////////////////////////////////////////////////////// + +static constexpr size_t kInstanceInfoSize = 4096; + +ZenServerInstanceInfo::ZenServerInstanceInfo() = default; + +ZenServerInstanceInfo::~ZenServerInstanceInfo() +{ +#if ZEN_PLATFORM_WINDOWS + if (m_Data) + { + UnmapViewOfFile(m_Data); + } + if (m_hMapFile) + { + CloseHandle(m_hMapFile); + } +#else + if (m_Data != nullptr) + { + munmap(m_Data, kInstanceInfoSize); + } + if (m_hMapFile != nullptr) + { + int Fd = int(intptr_t(m_hMapFile)); + close(Fd); + } + if (m_IsOwner) + { + std::string Name = MakeName(m_SessionId); + shm_unlink(Name.c_str()); + } +#endif + m_Data = nullptr; +} + +std::string +ZenServerInstanceInfo::MakeName(const Oid& SessionId) +{ +#if ZEN_PLATFORM_WINDOWS + return fmt::format("Global\\ZenInstance_{}", SessionId); +#else + // macOS limits shm_open names to ~31 chars (PSHMNAMLEN), so keep this short. + // "/ZenI_" (6) + 24 hex = 30 chars, within the limit. + return fmt::format("/ZenI_{}", SessionId); +#endif +} + +void +ZenServerInstanceInfo::Create(const Oid& SessionId, const InstanceInfoData& Data) +{ + m_SessionId = SessionId; + m_IsOwner = true; + + // Serialize the data to compact binary + CbObjectWriter Cbo; + if (!Data.UnixSocketPath.empty()) + { + Cbo << "unix_socket" << PathToUtf8(Data.UnixSocketPath); + } + CbObject Payload = Cbo.Save(); + + MemoryView PayloadView = Payload.GetView(); + uint32_t PayloadSize = gsl::narrow<uint32_t>(PayloadView.GetSize()); + + std::string Name = MakeName(SessionId); + +#if ZEN_PLATFORM_WINDOWS + zenutil::AnyUserSecurityAttributes Attrs; + + std::wstring WideName(Name.begin(), Name.end()); + + HANDLE hMap = + CreateFileMappingW(INVALID_HANDLE_VALUE, Attrs.Attributes(), PAGE_READWRITE, 0, DWORD(kInstanceInfoSize), WideName.c_str()); + + if (hMap == NULL) + { + // Fall back to Local namespace + std::string LocalName = fmt::format("Local\\ZenInstance_{}", SessionId); + std::wstring WideLocalName(LocalName.begin(), LocalName.end()); + hMap = CreateFileMappingW(INVALID_HANDLE_VALUE, + Attrs.Attributes(), + PAGE_READWRITE, + 0, + DWORD(kInstanceInfoSize), + WideLocalName.c_str()); + } + + if (hMap == NULL) + { + ThrowLastError("Could not create instance info shared memory"); + } + + void* pBuf = MapViewOfFile(hMap, FILE_MAP_ALL_ACCESS, 0, 0, DWORD(kInstanceInfoSize)); + if (pBuf == NULL) + { + CloseHandle(hMap); + ThrowLastError("Could not map instance info shared memory"); + } +#else + int Fd = shm_open(Name.c_str(), O_RDWR | O_CREAT | O_TRUNC | O_CLOEXEC, 0666); + if (Fd < 0) + { + ThrowLastError("Could not create instance info shared memory"); + } + fchmod(Fd, 0666); + + if (ftruncate(Fd, kInstanceInfoSize) < 0) + { + close(Fd); + shm_unlink(Name.c_str()); + ThrowLastError("Could not resize instance info shared memory"); + } + + void* pBuf = mmap(nullptr, kInstanceInfoSize, PROT_READ | PROT_WRITE, MAP_SHARED, Fd, 0); + if (pBuf == MAP_FAILED) + { + close(Fd); + shm_unlink(Name.c_str()); + ThrowLastError("Could not map instance info shared memory"); + } + + void* hMap = reinterpret_cast<void*>(intptr_t(Fd)); +#endif + + m_hMapFile = hMap; + m_Data = reinterpret_cast<uint8_t*>(pBuf); + + // Write payload: [uint32_t size][compact binary bytes] + memcpy(m_Data, &PayloadSize, sizeof PayloadSize); + if (PayloadSize > 0) + { + memcpy(m_Data + sizeof(uint32_t), PayloadView.GetData(), PayloadSize); + } +} + +bool +ZenServerInstanceInfo::OpenReadOnly(const Oid& SessionId) +{ + m_SessionId = SessionId; + + std::string Name = MakeName(SessionId); + +#if ZEN_PLATFORM_WINDOWS + std::wstring WideName(Name.begin(), Name.end()); + + HANDLE hMap = OpenFileMappingW(FILE_MAP_READ, FALSE, WideName.c_str()); + if (hMap == NULL) + { + // Fall back to Local namespace + std::string LocalName = fmt::format("Local\\ZenInstance_{}", SessionId); + std::wstring WideLocalName(LocalName.begin(), LocalName.end()); + hMap = OpenFileMappingW(FILE_MAP_READ, FALSE, WideLocalName.c_str()); + } + + if (hMap == NULL) + { + return false; + } + + void* pBuf = MapViewOfFile(hMap, FILE_MAP_READ, 0, 0, DWORD(kInstanceInfoSize)); + if (pBuf == NULL) + { + CloseHandle(hMap); + return false; + } +#else + int Fd = shm_open(Name.c_str(), O_RDONLY | O_CLOEXEC, 0666); + if (Fd < 0) + { + return false; + } + + void* pBuf = mmap(nullptr, kInstanceInfoSize, PROT_READ, MAP_SHARED, Fd, 0); + if (pBuf == MAP_FAILED) + { + close(Fd); + return false; + } + + void* hMap = reinterpret_cast<void*>(intptr_t(Fd)); +#endif + + m_hMapFile = hMap; + m_Data = reinterpret_cast<uint8_t*>(pBuf); + m_IsOwner = false; + + return true; +} + +InstanceInfoData +ZenServerInstanceInfo::Read() const +{ + InstanceInfoData Result; + + if (m_Data == nullptr) + { + return Result; + } + + uint32_t PayloadSize = 0; + memcpy(&PayloadSize, m_Data, sizeof PayloadSize); + + if (PayloadSize == 0 || PayloadSize > kInstanceInfoSize - sizeof(uint32_t)) + { + return Result; + } + + CbObject Payload = CbObject::Clone(m_Data + sizeof(uint32_t)); + Result.UnixSocketPath = Payload["unix_socket"].AsU8String(); + + return Result; +} + +////////////////////////////////////////////////////////////////////////// std::atomic<int> ZenServerTestCounter{0}; @@ -1234,6 +1474,10 @@ MakeLockFilePayload(const LockFileInfo& Info) CbObjectWriter Cbo; Cbo << "pid" << Info.Pid << "data" << PathToUtf8(Info.DataDir) << "port" << Info.EffectiveListenPort << "session_id" << Info.SessionId << "ready" << Info.Ready << "executable" << PathToUtf8(Info.ExecutablePath); + if (!Info.UnixSocketPath.empty()) + { + Cbo << "unix_socket" << PathToUtf8(Info.UnixSocketPath); + } return Cbo.Save(); } LockFileInfo @@ -1246,6 +1490,7 @@ ReadLockFilePayload(const CbObject& Payload) Info.Ready = Payload["ready"].AsBool(); Info.DataDir = Payload["data"].AsU8String(); Info.ExecutablePath = Payload["executable"].AsU8String(); + Info.UnixSocketPath = Payload["unix_socket"].AsU8String(); return Info; } @@ -1275,7 +1520,7 @@ ValidateLockFileInfo(const LockFileInfo& Info, std::string& OutReason) OutReason = fmt::format("session id ({}) is not valid", Info.SessionId); return false; } - if (Info.EffectiveListenPort == 0) + if (Info.EffectiveListenPort == 0 && Info.UnixSocketPath.empty()) { OutReason = fmt::format("listen port ({}) is not valid", Info.EffectiveListenPort); return false; diff --git a/thirdparty/asio/asio/include/asio/detail/io_uring_service.hpp b/thirdparty/asio/asio/include/asio/detail/io_uring_service.hpp index 7cc6cc51b..f76ac953a 100644 --- a/thirdparty/asio/asio/include/asio/detail/io_uring_service.hpp +++ b/thirdparty/asio/asio/include/asio/detail/io_uring_service.hpp @@ -50,7 +50,7 @@ private: public: enum op_types { read_op = 0, write_op = 1, except_op = 2, max_ops = 3 }; - class io_object; + struct io_object; // An I/O queue stores operations that must run serially. class io_queue : operation diff --git a/thirdparty/xmake.lua b/thirdparty/xmake.lua index 1fb5acad7..e8832d50a 100644 --- a/thirdparty/xmake.lua +++ b/thirdparty/xmake.lua @@ -36,8 +36,8 @@ target('rpmalloc') set_kind("static") set_group('thirdparty') set_languages("c17", "cxx20") - if is_os("windows") then - add_cflags("/experimental:c11atomics", {force=true}) + if is_os("windows") and not (get_config("toolchain") or ""):find("clang") then + add_cflags("/experimental:c11atomics", {force=true, tools="cl"}) end add_defines("RPMALLOC_FIRST_CLASS_HEAPS=1", "ENABLE_STATISTICS=1", "ENABLE_OVERRIDE=0") add_files("rpmalloc/rpmalloc.c") @@ -67,7 +67,11 @@ target('cpr') target('asio') set_kind('headeronly') set_group('thirdparty') - add_defines("ASIO_STANDLONE", "ASIO_HEADER_ONLY", {public=true}) + add_defines("ASIO_STANDALONE", "ASIO_HEADER_ONLY", {public=true}) + if is_plat("linux") and not (get_config("toolchain") == "ue-clang") then + add_defines("ASIO_HAS_IO_URING", {public=true}) + add_packages("liburing", {public=true}) + end add_headerfiles("asio/asio/include/**.hpp") add_includedirs("asio/asio/include", {public=true}) @@ -78,8 +82,8 @@ target("blake3") add_headerfiles("blake3/c/blake3.h") add_includedirs("blake3/c", {public=true}) - if is_os("windows") then - add_cflags("/experimental:c11atomics") + if is_os("windows") and not (get_config("toolchain") or ""):find("clang") then + add_cflags("/experimental:c11atomics", {tools="cl"}) add_cflags("/wd4245", {force = true}) -- conversion from 'type1' to 'type2', possible loss of data elseif is_os("macosx") then add_cflags("-Wno-unused-function") @@ -135,6 +139,9 @@ target("fmt") set_kind("static") set_group("thirdparty") set_warnings("allextra") + if is_plat("windows") then + add_cxxflags("/wd4834", {force=true}) -- C4834: discarding return value of [[nodiscard]] function + end add_files("fmt/src/format.cc", "fmt/src/os.cc") add_headerfiles("fmt/include/**.h") add_includedirs("fmt/include", {public=true}) @@ -20,3 +20,15 @@ # TSAN reports as a race. This is benign: the slot is always NULL and writing NULL # to it has no observable effect. race:eastl::hashtable*DoFreeNodes* + +# UE::Trace's GetUid() uses a racy static uint32 cache (Uid = Uid ? Uid : Initialize()) +# as a fast path to avoid re-entering Initialize(). The actual initialization is done via +# a thread-safe static (Uid_ThreadSafeInit) inside Initialize(), so the worst case is +# redundant calls to Initialize() which always returns the same value. +race:*Fields::GetUid* + +# TRACE_CPU_SCOPE generates a function-local `static int32 scope_id` that is lazily +# initialized without synchronization (if (0 == scope_id) scope_id = ScopeNew(...)). +# Same benign pattern as GetUid: the worst case is redundant calls to ScopeNew() which +# always returns the same value for a given scope name. +race:*$trace_scope_id* @@ -6,6 +6,12 @@ set_configvar("ZEN_DATA_FORCE_SCRUB_VERSION", 0) set_allowedplats("windows", "linux", "macosx") set_allowedarchs("windows|x64", "linux|x86_64", "macosx|x86_64", "macosx|arm64") +-- Returns true when building for Windows with native MSVC (not clang-cl cross-compilation) +function is_native_msvc() + local tc = get_config("toolchain") or "" + return is_plat("windows") and tc ~= "clang-cl" +end + -------------------------------------------------------------------------- -- We support debug and release modes. On Windows we use static CRT to -- minimize dependencies. @@ -48,12 +54,27 @@ set_policy("build.sanitizer.address", use_asan) -- ThreadSanitizer, MemorySanitizer, LeakSanitizer, and UndefinedBehaviorSanitizer -- are supported on Linux and MacOS only. +-- +-- You can enable these by editing the xmake.lua directly, or by passing the +-- appropriate flags on the command line: +-- +-- `xmake --policies=build.sanitizer.thread:y` for ThreadSanitizer, +-- `xmake --policies=build.sanitizer.memory:y` for MemorySanitizer, etc. ---set_policy("build.sanitizer.thread", true) ---set_policy("build.sanitizer.memory", true) +-- When using TSAN you will want to also use the suppression tile to silence +-- known benign races. You do this by ensuring the the TSAN_OPTIONS environment +-- vriable is set to something like `TSAN_OPTIONS="suppressions=$(projectdir)/tsan.supp"` +-- +-- `prompt> TSAN_OPTIONS="detect_deadlocks=0 suppressions=$(projectdir)/tsan.supp" xmake run zenserver` + +--set_policy("build.sanitizer.thread", true) --set_policy("build.sanitizer.leak", true) --set_policy("build.sanitizer.undefined", true) +-- In practice, this does not work because of the difficulty of compiling +-- dependencies with MemorySanitizer. +--set_policy("build.sanitizer.memory", true) + -------------------------------------------------------------------------- -- Dependencies @@ -131,7 +152,9 @@ enable_unity = false if is_mode("release") then -- LTO does not appear to work with the current Linux UE toolchain -- Also, disabled LTO on Mac to reduce time spent building openssl tests - if not is_plat("linux", "macosx") then + -- Disabled for cross-compilation (clang-cl on Linux) due to cmake package compat issues + local is_cross_win = is_plat("windows") and is_host("linux") + if not is_plat("linux", "macosx") and not is_cross_win then set_policy("build.optimization.lto", true) end set_optimize("fastest") @@ -154,6 +177,12 @@ else set_encodings("source:utf-8", "target:utf-8") end +-- When cross-compiling with clang-cl, the xwin SDK may ship a newer MSVC STL +-- than the host clang version supports. Bypass the version gate. +if is_plat("windows") and not is_native_msvc() then + add_defines("_ALLOW_COMPILER_AND_STL_VERSION_MISMATCH") +end + if is_os("windows") then add_defines( "_CRT_SECURE_NO_WARNINGS", @@ -166,33 +195,49 @@ if is_os("windows") then "_WIN32_WINNT=0x0A00", "_WINSOCK_DEPRECATED_NO_WARNINGS" -- let us use the ANSI functions ) - -- Make builds more deterministic and portable - add_cxxflags("/d1trimfile:$(curdir)\\") -- eliminates the base path from __FILE__ paths - add_cxxflags("/experimental:deterministic") -- (more) deterministic compiler output - add_ldflags("/PDBALTPATH:%_PDB%") -- deterministic pdb reference in exe - - add_cxxflags("/Zc:preprocessor") -- Enable preprocessor conformance mode - add_cxxflags("/Zc:u8EscapeEncoding") -- Enable UTF-8 encoding for u8 string literals - add_cxxflags("/Zc:inline") -- Enforce inline semantics + -- Make builds more deterministic and portable (MSVC-only flags) + if is_native_msvc() then + add_cxxflags("/d1trimfile:$(curdir)\\") -- eliminates the base path from __FILE__ paths + add_cxxflags("/experimental:deterministic") -- (more) deterministic compiler output + add_ldflags("/PDBALTPATH:%_PDB%") -- deterministic pdb reference in exe + add_cxxflags("/Zc:u8EscapeEncoding") -- Enable UTF-8 encoding for u8 string literals (clang does this by default) + add_cxxflags("/Zc:preprocessor") -- Enable preprocessor conformance mode + add_cxxflags("/Zc:inline") -- Enforce inline semantics + end -- add_ldflags("/MAP") end -if is_os("linux") or is_os("macosx") then - add_cxxflags("-Wno-implicit-fallthrough") - add_cxxflags("-Wno-missing-field-initializers") - add_cxxflags("-Wno-strict-aliasing") - add_cxxflags("-Wno-switch") - add_cxxflags("-Wno-unused-lambda-capture") - add_cxxflags("-Wno-unused-private-field") - add_cxxflags("-Wno-unused-value") - add_cxxflags("-Wno-unused-variable") - add_cxxflags("-Wno-vla-cxx-extension") +-- Clang warning suppressions (native clang on Linux/Mac, or clang-cl cross-compile) +if is_os("linux") or is_os("macosx") or not is_native_msvc() then + -- Silence warnings about unrecognized -Wno-* flags on older clang versions + add_cxxflags("-Wno-unknown-warning-option", {force = true}) + add_cxxflags("-Wno-delete-non-abstract-non-virtual-dtor", {force = true}) + add_cxxflags("-Wno-format", {force = true}) + add_cxxflags("-Wno-implicit-fallthrough", {force = true}) + add_cxxflags("-Wno-inconsistent-missing-override", {force = true}) + add_cxxflags("-Wno-missing-field-initializers", {force = true}) + add_cxxflags("-Wno-nonportable-include-path", {force = true}) + add_cxxflags("-Wno-sign-compare", {force = true}) + add_cxxflags("-Wno-strict-aliasing", {force = true}) + add_cxxflags("-Wno-switch", {force = true}) + add_cxxflags("-Wno-unused-lambda-capture", {force = true}) + add_cxxflags("-Wno-unused-private-field", {force = true}) + add_cxxflags("-Wno-unused-value", {force = true}) + add_cxxflags("-Wno-unused-variable", {force = true}) + add_cxxflags("-Wno-vla-cxx-extension", {force = true}) + -- GCC false positive: constinit static locals used by reference are reported as unused-but-set + add_cxxflags("-Wno-unused-but-set-variable", {tools = "gcc"}) end -if is_os("macosx") then - -- silence warnings about -Wno-vla-cxx-extension since to my knowledge we can't - -- detect the clang version used in Xcode and only recent versions contain this flag - add_cxxflags("-Wno-unknown-warning-option") +-- Additional suppressions specific to clang-cl cross-compilation +if get_config("toolchain") == "clang-cl" then + add_cxxflags("-Wno-cast-function-type-mismatch", {force = true}) + add_cxxflags("-Wno-parentheses-equality", {force = true}) + add_cxxflags("-Wno-reorder-ctor", {force = true}) + add_cxxflags("-Wno-unused-but-set-variable", {force = true}) + add_cxxflags("-Wno-unused-parameter", {force = true}) + add_cflags("-Wno-unknown-warning-option", {force = true}) + add_cflags("-Wno-unused-command-line-argument", {force = true}) end if is_os("linux") then @@ -286,6 +331,13 @@ option("zentrace") option_end() add_define_by_config("ZEN_WITH_TRACE", "zentrace") +option("zencpr") + set_default(true) + set_showmenu(true) + set_description("Enable CPR HTTP client backend") +option_end() +add_define_by_config("ZEN_WITH_CPR", "zencpr") + set_warnings("allextra", "error") set_languages("cxx20") @@ -328,7 +380,9 @@ end includes("src/zenstore", "src/zenstore-test") includes("src/zentelemetry", "src/zentelemetry-test") includes("src/zenutil", "src/zenutil-test") -includes("src/zenvfs") +if is_plat("windows") then + includes("src/zenvfs") +end includes("src/zenserver", "src/zenserver-test") includes("src/zen") includes("src/zentest-appstub") @@ -349,6 +403,23 @@ task("bundle") bundle() end) +task("docker") + set_menu { + usage = "xmake docker [--push] [--no-wine] [--win-binary PATH] [--tag TAG] [--registry REGISTRY]", + description = "Build Docker image for zenserver compute workers", + options = { + {nil, "push", "k", nil, "Push the image after building"}, + {nil, "no-wine", "k", nil, "Build without Wine (smaller image, Linux-only workers)"}, + {nil, "win-binary", "v", nil, "Path to Windows zenserver.exe to include in image"}, + {nil, "tag", "v", nil, "Override image tag (default: version from VERSION.txt)"}, + {nil, "registry", "v", nil, "Registry prefix (e.g. ghcr.io/epicgames)"}, + } + } + on_run(function () + import("scripts.docker") + docker() + end) + task("kill") set_menu { usage = "xmake kill", @@ -366,16 +437,6 @@ task("kill") end end) -task("updatefrontend") - set_menu { - usage = "xmake updatefrontend", - description = "Create Zip of the frontend/html folder for bundling with zenserver executable", - } - on_run(function() - import("scripts.updatefrontend") - updatefrontend() - end) - task("precommit") set_menu { usage = "xmake precommit", @@ -430,401 +491,6 @@ task("test") } } on_run(function() - import("core.base.option") - import("core.project.config") - import("core.project.project") - - config.load() - - -- Override table: target name -> short name (for targets that don't follow convention) - local short_name_overrides = { - ["zenserver-test"] = "integration", - } - - -- Build test list from targets in the "tests" group - local available_tests = {} - for name, target in pairs(project.targets()) do - if target:get("group") == "tests" and name:endswith("-test") then - local short = short_name_overrides[name] - if not short then - -- Derive short name: "zencore-test" -> "core" - short = name - if short:startswith("zen") then short = short:sub(4) end - if short:endswith("-test") then short = short:sub(1, -6) end - end - table.insert(available_tests, {short, name}) - end - end - - -- Add non-test-group entries that have a test subcommand - table.insert(available_tests, {"server", "zenserver"}) - - table.sort(available_tests, function(a, b) return a[1] < b[1] end) - - -- Handle --list: print discovered test names and exit - if option.get("list") then - printf("Available tests:\n") - for _, entry in ipairs(available_tests) do - printf(" %-16s -> %s\n", entry[1], entry[2]) - end - return - end - - local testname = option.get("run") - - -- Parse comma-separated test names into a set - local requested = {} - for token in testname:gmatch("[^,]+") do - requested[token:match("^%s*(.-)%s*$")] = true - end - - -- Filter to requested test(s) - local tests = {} - local matched = {} - - for _, entry in ipairs(available_tests) do - local name, target = entry[1], entry[2] - if requested["all"] or requested[name] then - table.insert(tests, {name = name, target = target}) - matched[name] = true - end - end - - -- Check for unknown test names - if not requested["all"] then - for name, _ in pairs(requested) do - if not matched[name] then - raise("no tests match specification: '%s'", name) - end - end - end - - if #tests == 0 then - raise("no tests match specification: '%s'", testname) - end - - local plat, arch - if is_host("windows") then - plat = "windows" - arch = "x64" - elseif is_host("macosx") then - plat = "macosx" - arch = is_arch("arm64") and "arm64" or "x86_64" - else - plat = "linux" - arch = "x86_64" - end - - -- Only reconfigure if current config doesn't already match - if config.get("mode") ~= "debug" or config.get("plat") ~= plat or config.get("arch") ~= arch then - local toolchain_flag = config.get("toolchain") and ("--toolchain=" .. config.get("toolchain")) or "" - local sdk_flag = config.get("sdk") and ("--sdk=" .. config.get("sdk")) or "" - os.exec("xmake config -c -m debug -p %s -a %s %s %s", plat, arch, toolchain_flag, sdk_flag) - end - - -- Build targets we're going to run - if requested["all"] then - os.exec("xmake build -y") - else - for _, entry in ipairs(tests) do - os.exec("xmake build -y %s", entry.target) - end - end - - local use_junit_reporting = option.get("junit") - local use_noskip = option.get("noskip") - local use_verbose = option.get("verbose") - local repeat_count = tonumber(option.get("repeat")) or 1 - local extra_args = option.get("arguments") or {} - local junit_report_files = {} - - local junit_report_dir - if use_junit_reporting then - junit_report_dir = path.join(os.projectdir(), config.get("buildir"), "reports") - os.mkdir(junit_report_dir) - end - - -- Results collection for summary table - local results = {} - local any_failed = false - - -- Format a number with thousands separators (e.g. 31103 -> "31,103") - local function format_number(n) - local s = tostring(n) - local pos = #s % 3 - if pos == 0 then pos = 3 end - local result = s:sub(1, pos) - for i = pos + 1, #s, 3 do - result = result .. "," .. s:sub(i, i + 2) - end - return result - end - - -- Center a string within a given width - local function center_str(s, width) - local pad = width - #s - local lpad = math.floor(pad / 2) - local rpad = pad - lpad - return string.rep(" ", lpad) .. s .. string.rep(" ", rpad) - end - - -- Left-align a string within a given width (with 1-space left margin) - local function left_align_str(s, width) - return " " .. s .. string.rep(" ", width - #s - 1) - end - - -- Right-align a string within a given width (with 1-space right margin) - local function right_align_str(s, width) - return string.rep(" ", width - #s - 1) .. s .. " " - end - - -- Format elapsed seconds as a human-readable string - local function format_time(seconds) - if seconds >= 60 then - local mins = math.floor(seconds / 60) - local secs = seconds - mins * 60 - return string.format("%dm %04.1fs", mins, secs) - else - return string.format("%.1fs", seconds) - end - end - - -- Parse test summary file written by TestListener - local function parse_summary_file(filepath) - if not os.isfile(filepath) then return nil end - local content = io.readfile(filepath) - if not content then return nil end - local ct = content:match("cases_total=(%d+)") - local cp = content:match("cases_passed=(%d+)") - local at = content:match("assertions_total=(%d+)") - local ap = content:match("assertions_passed=(%d+)") - if ct then - local failures = {} - for name, file, line in content:gmatch("failed=([^|\n]+)|([^|\n]+)|(%d+)") do - table.insert(failures, {name = name, file = file, line = tonumber(line)}) - end - local es = content:match("elapsed_seconds=([%d%.]+)") - return { - cases_total = tonumber(ct), - cases_passed = tonumber(cp) or 0, - asserts_total = tonumber(at) or 0, - asserts_passed = tonumber(ap) or 0, - elapsed_seconds = tonumber(es) or 0, - failures = failures - } - end - return nil - end - - -- Temp directory for summary files - local summary_dir = path.join(os.tmpdir(), "zen-test-summary") - os.mkdir(summary_dir) - - -- Run each test suite and collect results - for iteration = 1, repeat_count do - if repeat_count > 1 then - printf("\n*** Iteration %d/%d ***\n", iteration, repeat_count) - end - - for _, entry in ipairs(tests) do - local name, target = entry.name, entry.target - printf("=== %s ===\n", target) - - local suite_name = target - if name == "server" then - suite_name = "zenserver (test)" - end - - local cmd = string.format("xmake run %s", target) - if name == "server" then - cmd = string.format("xmake run %s test", target) - end - cmd = string.format("%s --duration=true", cmd) - - if use_junit_reporting then - local junit_report_file = path.join(junit_report_dir, string.format("junit-%s-%s-%s.xml", config.plat(), arch, target)) - junit_report_files[target] = junit_report_file - cmd = string.format("%s --reporters=junit --out=%s", cmd, junit_report_file) - end - if use_noskip then - cmd = string.format("%s --no-skip", cmd) - end - if use_verbose and name == "integration" then - cmd = string.format("%s --verbose", cmd) - end - for _, arg in ipairs(extra_args) do - cmd = string.format("%s %s", cmd, arg) - end - - -- Tell TestListener where to write the summary - local summary_file = path.join(summary_dir, target .. ".txt") - os.setenv("ZEN_TEST_SUMMARY_FILE", summary_file) - - -- Run test with real-time streaming output - local test_ok = true - try { - function() - os.exec(cmd) - end, - catch { - function(errors) - test_ok = false - end - } - } - - -- Read summary written by TestListener - local summary = parse_summary_file(summary_file) - os.tryrm(summary_file) - - if not test_ok then - any_failed = true - end - - table.insert(results, { - suite = suite_name, - cases_passed = summary and summary.cases_passed or 0, - cases_total = summary and summary.cases_total or 0, - asserts_passed = summary and summary.asserts_passed or 0, - asserts_total = summary and summary.asserts_total or 0, - elapsed_seconds = summary and summary.elapsed_seconds or 0, - failures = summary and summary.failures or {}, - passed = test_ok - }) - end - - if any_failed then - if repeat_count > 1 then - printf("\n*** Failure detected on iteration %d, stopping ***\n", iteration) - end - break - end - end - - -- Clean up - os.setenv("ZEN_TEST_SUMMARY_FILE", "") - os.tryrm(summary_dir) - - -- Print JUnit reports if requested - for test, junit_report_file in pairs(junit_report_files) do - printf("=== report - %s ===\n", test) - if os.isfile(junit_report_file) then - local data = io.readfile(junit_report_file) - if data then - print(data) - end - end - end - - -- Print summary table - if #results > 0 then - -- Calculate column widths based on content - local col_suite = #("Suite") - local col_cases = #("Cases") - local col_asserts = #("Assertions") - local col_time = #("Time") - local col_status = #("Status") - - -- Compute totals - local total_cases_passed = 0 - local total_cases_total = 0 - local total_asserts_passed = 0 - local total_asserts_total = 0 - local total_elapsed = 0 - - for _, r in ipairs(results) do - col_suite = math.max(col_suite, #r.suite) - local cases_str = format_number(r.cases_passed) .. "/" .. format_number(r.cases_total) - col_cases = math.max(col_cases, #cases_str) - local asserts_str = format_number(r.asserts_passed) .. "/" .. format_number(r.asserts_total) - col_asserts = math.max(col_asserts, #asserts_str) - col_time = math.max(col_time, #format_time(r.elapsed_seconds)) - local status_str = r.passed and "SUCCESS" or "FAILED" - col_status = math.max(col_status, #status_str) - - total_cases_passed = total_cases_passed + r.cases_passed - total_cases_total = total_cases_total + r.cases_total - total_asserts_passed = total_asserts_passed + r.asserts_passed - total_asserts_total = total_asserts_total + r.asserts_total - total_elapsed = total_elapsed + r.elapsed_seconds - end - - -- Account for totals row in column widths - col_suite = math.max(col_suite, #("Total")) - col_cases = math.max(col_cases, #(format_number(total_cases_passed) .. "/" .. format_number(total_cases_total))) - col_asserts = math.max(col_asserts, #(format_number(total_asserts_passed) .. "/" .. format_number(total_asserts_total))) - col_time = math.max(col_time, #format_time(total_elapsed)) - - -- Add padding (1 space each side) - col_suite = col_suite + 2 - col_cases = col_cases + 2 - col_asserts = col_asserts + 2 - col_time = col_time + 2 - col_status = col_status + 2 - - -- Build horizontal border segments - local h_suite = string.rep("-", col_suite) - local h_cases = string.rep("-", col_cases) - local h_asserts = string.rep("-", col_asserts) - local h_time = string.rep("-", col_time) - local h_status = string.rep("-", col_status) - - local top = "+" .. h_suite .. "+" .. h_cases .. "+" .. h_asserts .. "+" .. h_time .. "+" .. h_status .. "+" - local mid = "+" .. h_suite .. "+" .. h_cases .. "+" .. h_asserts .. "+" .. h_time .. "+" .. h_status .. "+" - local bottom = "+" .. h_suite .. "+" .. h_cases .. "+" .. h_asserts .. "+" .. h_time .. "+" .. h_status .. "+" - local vbar = "|" - - local header_msg = any_failed and "Some tests failed:" or "All tests passed:" - printf("\n* %s\n", header_msg) - printf(" %s\n", top) - printf(" %s%s%s%s%s%s%s%s%s%s%s\n", vbar, center_str("Suite", col_suite), vbar, center_str("Cases", col_cases), vbar, center_str("Assertions", col_asserts), vbar, center_str("Time", col_time), vbar, center_str("Status", col_status), vbar) - - for _, r in ipairs(results) do - printf(" %s\n", mid) - local cases_str = format_number(r.cases_passed) .. "/" .. format_number(r.cases_total) - local asserts_str = format_number(r.asserts_passed) .. "/" .. format_number(r.asserts_total) - local time_str = format_time(r.elapsed_seconds) - local status_str = r.passed and "SUCCESS" or "FAILED" - printf(" %s%s%s%s%s%s%s%s%s%s%s\n", vbar, left_align_str(r.suite, col_suite), vbar, right_align_str(cases_str, col_cases), vbar, right_align_str(asserts_str, col_asserts), vbar, right_align_str(time_str, col_time), vbar, right_align_str(status_str, col_status), vbar) - end - - -- Totals row - if #results > 1 then - local h_suite_eq = string.rep("=", col_suite) - local h_cases_eq = string.rep("=", col_cases) - local h_asserts_eq = string.rep("=", col_asserts) - local h_time_eq = string.rep("=", col_time) - local h_status_eq = string.rep("=", col_status) - local totals_sep = "+" .. h_suite_eq .. "+" .. h_cases_eq .. "+" .. h_asserts_eq .. "+" .. h_time_eq .. "+" .. h_status_eq .. "+" - printf(" %s\n", totals_sep) - - local total_cases_str = format_number(total_cases_passed) .. "/" .. format_number(total_cases_total) - local total_asserts_str = format_number(total_asserts_passed) .. "/" .. format_number(total_asserts_total) - local total_time_str = format_time(total_elapsed) - local total_status_str = any_failed and "FAILED" or "SUCCESS" - printf(" %s%s%s%s%s%s%s%s%s%s%s\n", vbar, left_align_str("Total", col_suite), vbar, right_align_str(total_cases_str, col_cases), vbar, right_align_str(total_asserts_str, col_asserts), vbar, right_align_str(total_time_str, col_time), vbar, right_align_str(total_status_str, col_status), vbar) - end - - printf(" %s\n", bottom) - end - - -- Print list of individual failing tests - if any_failed then - printf("\n Failures:\n") - for _, r in ipairs(results) do - if #r.failures > 0 then - printf(" -- %s --\n", r.suite) - for _, f in ipairs(r.failures) do - printf(" FAILED: %s (%s:%d)\n", f.name, f.file, f.line) - end - elseif not r.passed then - printf(" -- %s --\n", r.suite) - printf(" (test binary exited with error, no failure details available)\n") - end - end - end - - if any_failed then - raise("one or more test suites failed") - end + import("scripts.test") + test() end) |