aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorStefan Boberg <[email protected]>2026-05-04 16:46:03 +0200
committerGitHub Enterprise <[email protected]>2026-05-04 16:46:03 +0200
commit10d2a61fe1c848f44033e8450ff3a5ffa7f4322a (patch)
treeaa66c6a068b50d2390bdae5f857c7151f15e5a86 /src
parentTui picker fixes (#1027) (diff)
downloadarchived-zen-10d2a61fe1c848f44033e8450ff3a5ffa7f4322a.tar.xz
archived-zen-10d2a61fe1c848f44033e8450ff3a5ffa7f4322a.zip
zenhttp improvements (robustness / correctness) (#968)
A collection of security, correctness, and robustness fixes in `zenhttp` and `zencore` surfaced by security review. Most items are small, independent commits grouped here because they all tighten trust boundaries or fix UB along the same code paths. ## WebSocket protocol hardening (RFC 6455) - **Enforce the client-side mask bit**. Server-side frame loops now reject unmasked frames with close code 1002 per §5.1. Prevents HTTP intermediary smuggling. - **Validate control frames and RSV bits**. Fragmented control frames, oversized (>125 B) control payloads, and any non-zero RSV bit now fail the connection before allocation. - **Lower per-frame payload cap** from 256 MB → 4 MB. Bounds per-connection accumulator memory. - **Implement message fragmentation**. Continuation frames are coalesced and delivered as a single message; interleaved non-control frames close with 1002; assembled messages are capped at 4 MB (1009 on overflow). Previously partial fragments were delivered to handlers, bypassing payload validation. - **Parse the 101 handshake response properly** in `HttpWsClient`. Status-line, `Upgrade`, `Connection`, and `Sec-WebSocket-Accept` are now matched exactly rather than via substring searches against the full body. ## Auth / OIDC hardening - **Constant-time password compare** in `PasswordSecurity::IsAllowed` (closes a remote length/content timing oracle). Adds a shared `ConstantTimeEquals` helper. - **Harden Basic-auth header parsing**: trim trailing LWS, reject control bytes and DEL in the credential. - **OIDC discovery pinning**: require HTTPS (loopback exempt), verify `issuer` matches `BaseUrl`, require `token_endpoint` / `userinfo_endpoint` / `jwks_uri` to share origin with `BaseUrl`, reject empty `token_endpoint`. - **Restrict `POST /auth/oidc/refreshtoken`** to local-machine requests. Previously unauthenticated in default deployments — remote callers could evict or replace cached tokens. - **Stop logging OIDC provider response bodies** on refresh failure (IdPs echo `refresh_token` back in error bodies). - **Drop the unused `IdentityToken` field** from `OidcClient` / `OpenIdToken` so nothing in the tree accidentally trusts an unverified JWT. ## Auth state encryption migration - Add `AesGcm` AEAD primitive (BCrypt / OpenSSL backends, mbedTLS stubbed) and `CryptoRandom::Fill` CSPRNG helper in `zencore/crypto.h`. - Migrate authstate file from AES-256-CBC with a fixed IV to AES-GCM with a fresh 12-byte random nonce per write and the 4-byte `ZEN1` magic bound as AAD. Legacy-CBC files are transparently read once and rewritten in the new format. ## Filesystem / IO robustness - `IoBufferExtendedCore::Materialize` now checks `MAP_FAILED` on POSIX (was comparing to `nullptr`, which let the failure sentinel propagate into later reads and `munmap(MAP_FAILED, ...)`). - `IoBufferBuilder::MakeFromFile / MakeFromTemporaryFile`: close the FD/HANDLE on exception via a dismissable `ScopeGuard`; actually check the `fstat()` return value (previously used an uninitialized `FileSize`). - `ReadFromFileMaybe`: loop short reads, retry `EINTR`, chunk Windows `ReadFile` at `0xFFFFFFFF` bytes (fixes silent truncation of multi-GiB reads). - `WipeDirectory`: compare `FindFirstFileW` handle against `INVALID_HANDLE_VALUE` rather than `nullptr`. - `RemoveFileNative` (Linux/macOS): report non-`ENOENT` stat failures via the `std::error_code` out-param and stop reading `st_mode` after a failed stat. ## Buffer / compression correctness - Avoid per-copy `IoBufferCore` heap allocations in `CompositeBuffer::CopyTo / ViewOrCopyRange` iterators; add fast path for `BufferHeader::Read` when the 64-byte header fits in the first plain-memory segment. - `BufferHeader`: add `IsHeaderValid()` gate covering `BlockSizeExponent` range, `BlockCount * BlockSize` overflow, and `TotalRawSize` bounds before any arithmetic uses them. Defends against attacker-controlled headers that can pass the CRC and trigger OOB writes in `DecompressBlock`.
Diffstat (limited to 'src')
-rw-r--r--src/zencore/basicfile.cpp65
-rw-r--r--src/zencore/compositebuffer.cpp52
-rw-r--r--src/zencore/compress.cpp81
-rw-r--r--src/zencore/crypto.cpp550
-rw-r--r--src/zencore/filesystem.cpp81
-rw-r--r--src/zencore/include/zencore/crypto.h75
-rw-r--r--src/zencore/include/zencore/filesystem.h13
-rw-r--r--src/zencore/include/zencore/sharedbuffer.h6
-rw-r--r--src/zencore/include/zencore/string.h28
-rw-r--r--src/zencore/iobuffer.cpp109
-rw-r--r--src/zenhttp/auth/authmgr.cpp242
-rw-r--r--src/zenhttp/auth/authservice.cpp14
-rw-r--r--src/zenhttp/auth/oidc.cpp168
-rw-r--r--src/zenhttp/clients/httpwsclient.cpp132
-rw-r--r--src/zenhttp/include/zenhttp/auth/authmgr.h8
-rw-r--r--src/zenhttp/include/zenhttp/auth/oidc.h1
-rw-r--r--src/zenhttp/include/zenhttp/websocket.h11
-rw-r--r--src/zenhttp/security/passwordsecurity.cpp18
-rw-r--r--src/zenhttp/security/passwordsecurityfilter.cpp38
-rw-r--r--src/zenhttp/servers/wsasio.cpp73
-rw-r--r--src/zenhttp/servers/wsasio.h6
-rw-r--r--src/zenhttp/servers/wsframecodec.cpp63
-rw-r--r--src/zenhttp/servers/wsframecodec.h23
-rw-r--r--src/zenhttp/servers/wshttpsys.cpp70
-rw-r--r--src/zenhttp/servers/wshttpsys.h6
-rw-r--r--src/zenhttp/servers/wstest.cpp367
-rw-r--r--src/zenhttp/zenhttp.cpp2
-rw-r--r--src/zenhttp/zipfs.cpp111
-rw-r--r--src/zenhttp/zipfs_test.cpp144
-rw-r--r--src/zenserver/frontend/frontend.cpp36
-rw-r--r--src/zenserver/sessions/httpsessions.cpp26
-rw-r--r--src/zenserver/storage/admin/admin.cpp106
-rw-r--r--src/zenserver/storage/buildstore/httpbuildstore.cpp3
-rw-r--r--src/zenserver/storage/cache/httpstructuredcache.cpp31
-rw-r--r--src/zenserver/storage/objectstore/objectstore.cpp61
-rw-r--r--src/zenserver/storage/projectstore/httpprojectstore.cpp31
-rw-r--r--src/zenserver/storage/workspaces/httpworkspaces.cpp62
-rw-r--r--src/zenstore/include/zenstore/cache/structuredcachestore.h1
-rw-r--r--src/zenutil/include/zenutil/process/subprocessmanager.h2
-rw-r--r--src/zenutil/process/asyncpipereader.cpp48
-rw-r--r--src/zenutil/process/asyncpipereader.h4
-rw-r--r--src/zenutil/process/exitwatcher.cpp61
-rw-r--r--src/zenutil/process/exitwatcher.h4
-rw-r--r--src/zenutil/process/subprocessmanager.cpp64
44 files changed, 2785 insertions, 312 deletions
diff --git a/src/zencore/basicfile.cpp b/src/zencore/basicfile.cpp
index fdf742261..01d550957 100644
--- a/src/zencore/basicfile.cpp
+++ b/src/zencore/basicfile.cpp
@@ -798,11 +798,12 @@ BasicFileWriter::Write(const void* Data, uint64_t Size, uint64_t FileOffset)
{
if (m_Buffer == nullptr || (Size >= m_BufferSize))
{
- if (FileOffset == m_BufferEnd)
- {
- Flush();
- m_BufferStart = m_BufferEnd = FileOffset + Size;
- }
+ // Always flush pending buffered data first. Otherwise a later
+ // Flush() would replay stale bytes at m_BufferStart, clobbering
+ // any range of this direct write that overlaps
+ // [m_BufferStart, m_BufferEnd).
+ Flush();
+ m_BufferStart = m_BufferEnd = FileOffset + Size;
m_Base.Write(Data, Size, FileOffset);
return;
@@ -1200,6 +1201,60 @@ TEST_CASE("BasicFileBuffer")
}
}
+TEST_CASE("BasicFileWriter.LargeDiscontinuousWriteFlushesBuffer")
+{
+ // Regression: BasicFileWriter::Write's large-write branch used to skip
+ // Flush() whenever FileOffset != m_BufferEnd. Any subsequent Flush
+ // (including the one in ~BasicFileWriter) then replayed the stale
+ // buffered bytes at their original offset, clobbering whatever the
+ // caller had just written directly.
+ ScopedCurrentDirectoryChange _;
+
+ constexpr uint64_t BufferSize = 64;
+ constexpr uint64_t SmallSize = 10;
+ constexpr uint64_t LargeSize = 1024; // >= BufferSize, forces the direct-write path
+ constexpr uint64_t LargeOffset = 5; // overlaps the pending small-write region
+ constexpr uint64_t OverlapStart = LargeOffset;
+ constexpr uint64_t OverlapEnd = SmallSize;
+
+ {
+ BasicFile File;
+ File.Open("discontig_write", BasicFile::Mode::kTruncate);
+ BasicFileWriter Writer(File, BufferSize);
+
+ // First: small write buffered at [0, SmallSize) - still pending flush.
+ std::vector<uint8_t> Small(SmallSize, 'A');
+ Writer.Write(Small.data(), Small.size(), 0);
+
+ // Second: large write that overlaps the pending buffered region.
+ // Last-writer-wins means the overlap bytes must end up as 'B'.
+ std::vector<uint8_t> Large(LargeSize, 'B');
+ Writer.Write(Large.data(), Large.size(), LargeOffset);
+ }
+
+ BasicFile Reader;
+ Reader.Open("discontig_write", BasicFile::Mode::kRead);
+ IoBuffer Contents = Reader.ReadAll();
+ REQUIRE_EQ(Contents.Size(), LargeOffset + LargeSize);
+
+ const uint8_t* Bytes = reinterpret_cast<const uint8_t*>(Contents.Data());
+ for (uint64_t I = 0; I < OverlapStart; ++I)
+ {
+ CHECK_EQ(Bytes[I], 'A');
+ }
+ // The bytes in the overlap range [OverlapStart, OverlapEnd) are the
+ // critical check: with the bug, the late Flush() replayed the small
+ // 'A' write and clobbered these with 'A'.
+ for (uint64_t I = OverlapStart; I < OverlapEnd; ++I)
+ {
+ CHECK_EQ(Bytes[I], 'B');
+ }
+ for (uint64_t I = OverlapEnd; I < LargeOffset + LargeSize; ++I)
+ {
+ CHECK_EQ(Bytes[I], 'B');
+ }
+}
+
TEST_SUITE_END();
void
diff --git a/src/zencore/compositebuffer.cpp b/src/zencore/compositebuffer.cpp
index ed2b16384..1dee8477f 100644
--- a/src/zencore/compositebuffer.cpp
+++ b/src/zencore/compositebuffer.cpp
@@ -179,12 +179,11 @@ CompositeBuffer::GetIterator(uint64_t Offset) const
MemoryView
CompositeBuffer::ViewOrCopyRange(Iterator& It, uint64_t Size, UniqueBuffer& CopyBuffer) const
{
- // We use a sub range IoBuffer when we want to copy data from a segment.
- // This means we will only materialize that range of the segment when doing
- // GetView() rather than the full segment.
- // A hot path for this code is when we call CompressedBuffer::FromCompressed which
- // is only interested in reading the header (first 64 bytes or so) and then throws
- // away the materialized data.
+ // A hot path for this code is CompressedBuffer::FromCompressed, which only reads the header
+ // (first 64 bytes or so). For plain memory segments we take a direct view (no allocation);
+ // for extended (file-backed) segments we materialize only the requested slice via a
+ // sub-range IoBuffer whose lifetime must extend across the CopyFrom below — otherwise its
+ // view would dangle into freed memory the moment the IoBuffer goes out of scope.
if (CopyBuffer.GetSize() < Size)
{
CopyBuffer = UniqueBuffer::Alloc(Size);
@@ -198,9 +197,20 @@ CompositeBuffer::ViewOrCopyRange(Iterator& It, uint64_t Size, UniqueBuffer& Copy
const SharedBuffer& Segment = m_Segments[It.SegmentIndex];
size_t SegmentSize = Segment.GetSize();
size_t CopySize = zen::Min(SegmentSize - It.OffsetInSegment, SizeLeft);
- IoBuffer SubSegment(Segment.AsIoBuffer(), It.OffsetInSegment, CopySize);
- MemoryView ReadView = SubSegment.GetView();
- WriteView = WriteView.CopyFrom(ReadView);
+
+ IoBuffer SubSegment; // lifetime holder for the extended-segment view
+ MemoryView ReadView;
+ if (Segment.IsExtended())
+ {
+ SubSegment = IoBuffer(Segment.AsIoBuffer(), It.OffsetInSegment, CopySize);
+ ReadView = SubSegment.GetView();
+ }
+ else
+ {
+ ReadView = Segment.GetView().Mid(It.OffsetInSegment, CopySize);
+ }
+ WriteView = WriteView.CopyFrom(ReadView);
+
It.OffsetInSegment += CopySize;
ZEN_ASSERT_SLOW(It.OffsetInSegment <= SegmentSize);
if (It.OffsetInSegment == SegmentSize)
@@ -216,12 +226,7 @@ CompositeBuffer::ViewOrCopyRange(Iterator& It, uint64_t Size, UniqueBuffer& Copy
void
CompositeBuffer::CopyTo(MutableMemoryView WriteView, Iterator& It) const
{
- // We use a sub range IoBuffer when we want to copy data from a segment.
- // This means we will only materialize that range of the segment when doing
- // GetView() rather than the full segment.
- // A hot path for this code is when we call CompressedBuffer::FromCompressed which
- // is only interested in reading the header (first 64 bytes or so) and then throws
- // away the materialized data.
+ // See ViewOrCopyRange above for rationale on the extended vs. plain segment split.
size_t SizeLeft = WriteView.GetSize();
size_t SegmentCount = m_Segments.size();
@@ -231,9 +236,20 @@ CompositeBuffer::CopyTo(MutableMemoryView WriteView, Iterator& It) const
const SharedBuffer& Segment = m_Segments[It.SegmentIndex];
size_t SegmentSize = Segment.GetSize();
size_t CopySize = zen::Min(SegmentSize - It.OffsetInSegment, SizeLeft);
- IoBuffer SubSegment(Segment.AsIoBuffer(), It.OffsetInSegment, CopySize);
- MemoryView ReadView = SubSegment.GetView();
- WriteView = WriteView.CopyFrom(ReadView);
+
+ IoBuffer SubSegment; // lifetime holder for the extended-segment view
+ MemoryView ReadView;
+ if (Segment.IsExtended())
+ {
+ SubSegment = IoBuffer(Segment.AsIoBuffer(), It.OffsetInSegment, CopySize);
+ ReadView = SubSegment.GetView();
+ }
+ else
+ {
+ ReadView = Segment.GetView().Mid(It.OffsetInSegment, CopySize);
+ }
+ WriteView = WriteView.CopyFrom(ReadView);
+
It.OffsetInSegment += CopySize;
ZEN_ASSERT_SLOW(It.OffsetInSegment <= SegmentSize);
if (It.OffsetInSegment == SegmentSize)
diff --git a/src/zencore/compress.cpp b/src/zencore/compress.cpp
index 6aa0adce0..a0e91f908 100644
--- a/src/zencore/compress.cpp
+++ b/src/zencore/compress.cpp
@@ -78,12 +78,20 @@ struct BufferHeader
BufferHeader Header;
if (sizeof(BufferHeader) <= CompressedData.GetSize())
{
- // if (CompressedData.GetSegments()[0].AsIoBuffer().IsWholeFile())
- // {
- // ZEN_ASSERT(true);
- // }
- CompositeBuffer::Iterator It;
- CompressedData.CopyTo(MakeMutableMemoryView(&Header, &Header + 1), It);
+ // Fast path: the overwhelmingly common case is that the 64-byte header sits entirely
+ // within the first segment and that segment is plain memory. Skip the iterator and
+ // the sub-range IoBuffer wrapper (which would otherwise heap-allocate an IoBufferCore).
+ const std::span<const SharedBuffer> Segments = CompressedData.GetSegments();
+ const SharedBuffer& First = Segments.front();
+ if (sizeof(BufferHeader) <= First.GetSize() && !First.IsExtended())
+ {
+ MakeMutableMemoryView(&Header, &Header + 1).CopyFrom(First.GetView().Left(sizeof(BufferHeader)));
+ }
+ else
+ {
+ CompositeBuffer::Iterator It;
+ CompressedData.CopyTo(MakeMutableMemoryView(&Header, &Header + 1), It);
+ }
Header.ByteSwap();
}
return Header;
@@ -837,6 +845,15 @@ BlockDecoder::DecompressToStream(
{
return false;
}
+ // RawOffset+RawSize-1 below underflows when RawSize is 0, and the
+ // BlockCount-0 / BlockSize-0 arithmetic is only defined when the header
+ // has already been validated (see IsHeaderValid). Guard both here as
+ // defence in depth.
+ if (RawSize == 0 || Header.BlockCount == 0 || Header.BlockSizeExponent >= 32 || RawOffset > Header.TotalRawSize ||
+ RawSize > Header.TotalRawSize - RawOffset)
+ {
+ return false;
+ }
const uint64_t BlockSize = uint64_t(1) << Header.BlockSizeExponent;
@@ -1386,6 +1403,50 @@ GetDecoder(CompressionMethod Method)
}
}
+// Sanity-check a header that was just read from an untrusted buffer before
+// any of the decode arithmetic (1 << BlockSizeExponent, BlockCount*BlockSize,
+// divides by BlockSize, etc.) is performed. Must be called after the magic,
+// decoder and CRC checks pass.
+static bool
+IsHeaderValid(const BufferHeader& Header)
+{
+ // 1 << BlockSizeExponent is UB for Exponent >= 64 and wildly impractical
+ // below that. Real producers use <= 24 (16 MiB blocks); cap at 32 for
+ // headroom while staying well below the UB boundary.
+ if (Header.BlockSizeExponent >= 32)
+ {
+ return false;
+ }
+
+ // Only the block-based methods use BlockCount / BlockSizeExponent. The
+ // None method keeps them zero.
+ if (Header.Method != CompressionMethod::None)
+ {
+ // A non-empty buffer needs at least one block.
+ if (Header.BlockCount == 0)
+ {
+ return Header.TotalRawSize == 0;
+ }
+
+ const uint64_t BlockSize = uint64_t(1) << Header.BlockSizeExponent;
+
+ // BlockCount * BlockSize must not overflow and must fit TotalRawSize
+ // in the half-open range ((BlockCount - 1) * BlockSize, BlockCount * BlockSize].
+ if (Header.BlockCount > (UINT64_MAX / BlockSize))
+ {
+ return false;
+ }
+ const uint64_t MaxRawSize = uint64_t(Header.BlockCount) * BlockSize;
+ const uint64_t MinRawSize = MaxRawSize - BlockSize;
+ if (Header.TotalRawSize > MaxRawSize || Header.TotalRawSize <= MinRawSize)
+ {
+ return false;
+ }
+ }
+
+ return true;
+}
+
//////////////////////////////////////////////////////////////////////////
bool
@@ -1426,6 +1487,10 @@ ReadHeader(const CompositeBuffer& CompressedData, BufferHeader& OutHeader, Uniqu
{
return false;
}
+ if (!IsHeaderValid(OutHeader))
+ {
+ return false;
+ }
uint64_t FullHeaderSize = Decoder->GetHeaderSize(OutHeader);
if (FullHeaderSize > CompressedDataSize)
{
@@ -1520,7 +1585,7 @@ TryReadHeader(DecoderContext& Context, Archive& Ar, FHeader& OutHeader, MemoryVi
FHeader* const HeaderCopy = static_cast<FHeader*>(HeaderView.GetData());
HeaderCopy->ByteSwap();
- if (Header.Crc32 == FHeader::CalculateCrc32(HeaderView))
+ if (Header.Crc32 == FHeader::CalculateCrc32(HeaderView) && IsHeaderValid(Header))
{
Context.HeaderOffset = uint64_t(Offset);
Context.HeaderSize = HeaderSize;
@@ -1560,7 +1625,7 @@ TryReadHeader(DecoderContext& Context, const CompositeBuffer& Buffer, FHeader& O
const MemoryView HeaderView = Buffer.ViewOrCopyRange(0, HeaderSize, Context.Header, [](uint64_t Size) {
return UniqueBuffer::Alloc(zen::Max(NextPow2(Size), DefaultHeaderSize));
});
- if (Header.Crc32 == FHeader::CalculateCrc32(HeaderView))
+ if (Header.Crc32 == FHeader::CalculateCrc32(HeaderView) && IsHeaderValid(Header))
{
Context.HeaderOffset = 0;
Context.HeaderSize = HeaderSize;
diff --git a/src/zencore/crypto.cpp b/src/zencore/crypto.cpp
index 9984f35ac..8a172de3a 100644
--- a/src/zencore/crypto.cpp
+++ b/src/zencore/crypto.cpp
@@ -432,6 +432,285 @@ namespace crypto {
return true;
}
+ //////////////////////////////////////////////////////////////////////////
+ //
+ // AES-256-GCM backends
+ //
+
+#if ZEN_USE_MBEDTLS
+
+ MemoryView GcmTransform(TransformMode Mode,
+ const AesKey256Bit& Key,
+ MemoryView Nonce,
+ MemoryView Aad,
+ MemoryView In,
+ MutableMemoryView Out,
+ MutableMemoryView TagOut, // only used for Encrypt
+ MemoryView TagIn, // only used for Decrypt
+ std::optional<std::string>& Reason)
+ {
+ Reason = "AES-GCM is not implemented on the mbedTLS backend"sv;
+ (void)Mode;
+ (void)Key;
+ (void)Nonce;
+ (void)Aad;
+ (void)In;
+ (void)Out;
+ (void)TagOut;
+ (void)TagIn;
+ return MemoryView();
+ }
+
+#elif ZEN_USE_OPENSSL
+
+ MemoryView GcmTransform(TransformMode Mode,
+ const AesKey256Bit& Key,
+ MemoryView Nonce,
+ MemoryView Aad,
+ MemoryView In,
+ MutableMemoryView Out,
+ MutableMemoryView TagOut,
+ MemoryView TagIn,
+ std::optional<std::string>& Reason)
+ {
+ EvpContext Ctx;
+
+ const EVP_CIPHER* Cipher = EVP_aes_256_gcm();
+ ZEN_ASSERT(Cipher != nullptr);
+
+ const bool Encrypting = (Mode == TransformMode::Encrypt);
+
+ if (EVP_CipherInit_ex(Ctx, Cipher, nullptr, nullptr, nullptr, Encrypting ? 1 : 0) != 1)
+ {
+ Reason = "EVP_CipherInit_ex (algo) failed"sv;
+ return {};
+ }
+
+ // Explicitly set IV length to 12; the default is 12 for GCM but pinning
+ // it prevents any surprise from an OpenSSL build that defaults
+ // differently.
+ if (EVP_CIPHER_CTX_ctrl(Ctx, EVP_CTRL_AEAD_SET_IVLEN, (int)AesGcm::NonceSize, nullptr) != 1)
+ {
+ Reason = "EVP_CTRL_AEAD_SET_IVLEN failed"sv;
+ return {};
+ }
+
+ if (EVP_CipherInit_ex(Ctx,
+ nullptr,
+ nullptr,
+ reinterpret_cast<const unsigned char*>(Key.GetView().GetData()),
+ reinterpret_cast<const unsigned char*>(Nonce.GetData()),
+ Encrypting ? 1 : 0) != 1)
+ {
+ Reason = "EVP_CipherInit_ex (key+nonce) failed"sv;
+ return {};
+ }
+
+ // Feed AAD (if any) before the ciphertext/plaintext.
+ if (!Aad.IsEmpty())
+ {
+ int AadOutLen = 0;
+ if (EVP_CipherUpdate(Ctx,
+ nullptr,
+ &AadOutLen,
+ reinterpret_cast<const unsigned char*>(Aad.GetData()),
+ static_cast<int>(Aad.GetSize())) != 1)
+ {
+ Reason = "EVP_CipherUpdate (AAD) failed"sv;
+ return {};
+ }
+ }
+
+ // For decrypt, set the expected tag before calling Final so the tag
+ // check can fail cleanly.
+ if (!Encrypting)
+ {
+ if (EVP_CIPHER_CTX_ctrl(Ctx, EVP_CTRL_AEAD_SET_TAG, (int)TagIn.GetSize(), (void*)TagIn.GetData()) != 1)
+ {
+ Reason = "EVP_CTRL_AEAD_SET_TAG failed"sv;
+ return {};
+ }
+ }
+
+ int BodyLen = 0;
+ if (EVP_CipherUpdate(Ctx,
+ reinterpret_cast<unsigned char*>(Out.GetData()),
+ &BodyLen,
+ reinterpret_cast<const unsigned char*>(In.GetData()),
+ static_cast<int>(In.GetSize())) != 1)
+ {
+ Reason = "EVP_CipherUpdate (body) failed"sv;
+ return {};
+ }
+
+ int FinalLen = 0;
+ if (EVP_CipherFinal_ex(Ctx, reinterpret_cast<unsigned char*>(Out.GetData()) + BodyLen, &FinalLen) != 1)
+ {
+ // For decrypt, this is the authentication-tag mismatch path.
+ Reason = Encrypting ? std::string("EVP_CipherFinal_ex (encrypt) failed") : std::string("AES-GCM authentication tag mismatch");
+ return {};
+ }
+
+ if (Encrypting)
+ {
+ if (EVP_CIPHER_CTX_ctrl(Ctx, EVP_CTRL_AEAD_GET_TAG, (int)TagOut.GetSize(), TagOut.GetData()) != 1)
+ {
+ Reason = "EVP_CTRL_AEAD_GET_TAG failed"sv;
+ return {};
+ }
+ }
+
+ return Out.Left(static_cast<size_t>(BodyLen + FinalLen));
+ }
+
+#else // ZEN_USE_BCRYPT
+
+ MemoryView GcmTransform(TransformMode Mode,
+ const AesKey256Bit& Key,
+ MemoryView Nonce,
+ MemoryView Aad,
+ MemoryView In,
+ MutableMemoryView Out,
+ MutableMemoryView TagOut,
+ MemoryView TagIn,
+ std::optional<std::string>& Reason)
+ {
+ BCRYPT_ALG_HANDLE hAlg = nullptr;
+ NTSTATUS Status = STATUS_UNSUCCESSFUL;
+
+ if (!NT_SUCCESS(Status = BCryptOpenAlgorithmProvider(&hAlg, BCRYPT_AES_ALGORITHM, nullptr, 0)))
+ {
+ Reason = fmt::format("BCryptOpenAlgorithmProvider failed, 0x{:08x}"sv, Status);
+ return {};
+ }
+ auto CloseAlg = MakeGuard([hAlg] { BCryptCloseAlgorithmProvider(hAlg, 0); });
+
+ if (!NT_SUCCESS(Status =
+ BCryptSetProperty(hAlg, BCRYPT_CHAINING_MODE, (PUCHAR)BCRYPT_CHAIN_MODE_GCM, sizeof(BCRYPT_CHAIN_MODE_GCM), 0)))
+ {
+ Reason = fmt::format("BCryptSetProperty(CHAIN_MODE_GCM) failed, 0x{:08x}"sv, Status);
+ return {};
+ }
+
+ BCRYPT_KEY_HANDLE hKey = nullptr;
+ if (!NT_SUCCESS(Status = BCryptGenerateSymmetricKey(hAlg,
+ &hKey,
+ nullptr,
+ 0,
+ (PUCHAR)Key.GetView().GetData(),
+ (ULONG)Key.GetView().GetSize(),
+ 0)))
+ {
+ Reason = fmt::format("BCryptGenerateSymmetricKey failed, 0x{:08x}"sv, Status);
+ return {};
+ }
+ auto CloseKey = MakeGuard([hKey] { BCryptDestroyKey(hKey); });
+
+ BCRYPT_AUTHENTICATED_CIPHER_MODE_INFO AuthInfo;
+ BCRYPT_INIT_AUTH_MODE_INFO(AuthInfo);
+ AuthInfo.pbNonce = (PUCHAR)Nonce.GetData();
+ AuthInfo.cbNonce = (ULONG)Nonce.GetSize();
+ AuthInfo.pbAuthData = Aad.IsEmpty() ? nullptr : (PUCHAR)Aad.GetData();
+ AuthInfo.cbAuthData = (ULONG)Aad.GetSize();
+
+ ULONG ResultLen = 0;
+
+ if (Mode == TransformMode::Encrypt)
+ {
+ AuthInfo.pbTag = (PUCHAR)TagOut.GetData();
+ AuthInfo.cbTag = (ULONG)TagOut.GetSize();
+
+ Status = BCryptEncrypt(hKey,
+ (PUCHAR)In.GetData(),
+ (ULONG)In.GetSize(),
+ &AuthInfo,
+ nullptr,
+ 0,
+ (PUCHAR)Out.GetData(),
+ (ULONG)Out.GetSize(),
+ &ResultLen,
+ /*flags=*/0);
+
+ if (!NT_SUCCESS(Status))
+ {
+ Reason = fmt::format("BCryptEncrypt (GCM) failed, 0x{:08x}"sv, Status);
+ return {};
+ }
+ }
+ else
+ {
+ AuthInfo.pbTag = (PUCHAR)TagIn.GetData();
+ AuthInfo.cbTag = (ULONG)TagIn.GetSize();
+
+ Status = BCryptDecrypt(hKey,
+ (PUCHAR)In.GetData(),
+ (ULONG)In.GetSize(),
+ &AuthInfo,
+ nullptr,
+ 0,
+ (PUCHAR)Out.GetData(),
+ (ULONG)Out.GetSize(),
+ &ResultLen,
+ /*flags=*/0);
+
+ if (!NT_SUCCESS(Status))
+ {
+ // STATUS_AUTH_TAG_MISMATCH (0xC000A002) is the tag-failure path.
+ Reason = fmt::format("BCryptDecrypt (GCM) failed, 0x{:08x}"sv, Status);
+ return {};
+ }
+ }
+
+ return Out.Left(ResultLen);
+ }
+
+#endif // backend selection
+
+ MemoryView Gcm(TransformMode Mode,
+ const AesKey256Bit& Key,
+ MemoryView Nonce,
+ MemoryView Aad,
+ MemoryView In,
+ MutableMemoryView Out,
+ MutableMemoryView TagOut,
+ MemoryView TagIn,
+ std::optional<std::string>& Reason)
+ {
+ if (!Key.IsValid())
+ {
+ Reason = "invalid key"sv;
+ return {};
+ }
+ if (Nonce.GetSize() != AesGcm::NonceSize)
+ {
+ Reason = fmt::format("AES-GCM nonce must be exactly {} bytes"sv, AesGcm::NonceSize);
+ return {};
+ }
+ if (Out.GetSize() < In.GetSize())
+ {
+ Reason = "AES-GCM output buffer is too small"sv;
+ return {};
+ }
+ if (Mode == TransformMode::Encrypt)
+ {
+ if (TagOut.GetSize() != AesGcm::TagSize)
+ {
+ Reason = fmt::format("AES-GCM tag output must be exactly {} bytes"sv, AesGcm::TagSize);
+ return {};
+ }
+ }
+ else
+ {
+ if (TagIn.GetSize() != AesGcm::TagSize)
+ {
+ Reason = fmt::format("AES-GCM tag input must be exactly {} bytes"sv, AesGcm::TagSize);
+ return {};
+ }
+ }
+
+ return GcmTransform(Mode, Key, Nonce, Aad, In, Out, TagOut, TagIn, Reason);
+ }
+
} // namespace crypto
bool
@@ -741,6 +1020,78 @@ Aes::Decrypt(const AesKey256Bit& Key, const AesIV128Bit& IV, MemoryView In, Muta
return crypto::Transform(crypto::TransformMode::Decrypt, Key.GetView(), IV.GetView(), In, Out, Reason);
}
+MemoryView
+AesGcm::Encrypt(const AesKey256Bit& Key,
+ MemoryView Nonce,
+ MemoryView Aad,
+ MemoryView Plaintext,
+ MutableMemoryView Out,
+ MutableMemoryView OutTag,
+ std::optional<std::string>& Reason)
+{
+ return crypto::Gcm(crypto::TransformMode::Encrypt, Key, Nonce, Aad, Plaintext, Out, OutTag, /*TagIn*/ {}, Reason);
+}
+
+MemoryView
+AesGcm::Decrypt(const AesKey256Bit& Key,
+ MemoryView Nonce,
+ MemoryView Aad,
+ MemoryView Ciphertext,
+ MemoryView Tag,
+ MutableMemoryView Out,
+ std::optional<std::string>& Reason)
+{
+ return crypto::Gcm(crypto::TransformMode::Decrypt, Key, Nonce, Aad, Ciphertext, Out, /*TagOut*/ {}, Tag, Reason);
+}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// CryptoRandom
+//
+
+bool
+CryptoRandom::Fill(MutableMemoryView Buffer, std::optional<std::string>* Reason)
+{
+ if (Buffer.GetSize() == 0)
+ {
+ return true;
+ }
+
+ auto SetReason = [&](std::string Msg) {
+ if (Reason)
+ {
+ *Reason = std::move(Msg);
+ }
+ };
+
+#if ZEN_USE_BCRYPT
+ // BCRYPT_USE_SYSTEM_PREFERRED_RNG draws from the OS CSPRNG without
+ // requiring the caller to manage an algorithm handle.
+ const NTSTATUS Status = BCryptGenRandom(nullptr, (PUCHAR)Buffer.GetData(), (ULONG)Buffer.GetSize(), BCRYPT_USE_SYSTEM_PREFERRED_RNG);
+ if (!NT_SUCCESS(Status))
+ {
+ SetReason(fmt::format("BCryptGenRandom failed, 0x{:08x}", static_cast<uint32_t>(Status)));
+ return false;
+ }
+ return true;
+#elif ZEN_USE_OPENSSL
+ // RAND_bytes returns 1 on success, 0 on failure, -1 if not supported.
+ const int Rc = RAND_bytes(reinterpret_cast<unsigned char*>(Buffer.GetData()), static_cast<int>(Buffer.GetSize()));
+ if (Rc != 1)
+ {
+ SetReason(fmt::format("RAND_bytes failed (rc={})", Rc));
+ return false;
+ }
+ return true;
+#else
+ // mbedTLS: no CSPRNG wired up here yet. Callers on this backend must
+ // provide their own random source until a proper wiring is added.
+ SetReason("CryptoRandom::Fill is not implemented on the mbedTLS backend");
+ (void)Buffer;
+ return false;
+#endif
+}
+
#if ZEN_WITH_TESTS
void
@@ -801,6 +1152,205 @@ TEST_CASE("crypto.aes")
}
}
+TEST_CASE("crypto.aesgcm")
+{
+ const AesKey256Bit Key = AesKey256Bit::FromString("abcdefghijklmnopqrstuvxyz0123456"sv);
+ const uint8_t NonceBytes[AesGcm::NonceSize] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
+ const MemoryView Nonce = MakeMemoryView(NonceBytes);
+
+ SUBCASE("round trip without AAD")
+ {
+ std::string_view Plain = "The quick brown fox jumps over the lazy dog"sv;
+
+ std::vector<uint8_t> Cipher(Plain.size());
+ std::vector<uint8_t> Tag(AesGcm::TagSize);
+ std::optional<std::string> Reason;
+
+ MemoryView CipherView = AesGcm::Encrypt(Key,
+ Nonce,
+ /*Aad*/ {},
+ MakeMemoryView(Plain),
+ MakeMutableMemoryView(Cipher),
+ MakeMutableMemoryView(Tag),
+ Reason);
+ REQUIRE(!Reason.has_value());
+ CHECK_EQ(CipherView.GetSize(), Plain.size());
+
+ std::vector<uint8_t> Decoded(Plain.size());
+ MemoryView DecodedView =
+ AesGcm::Decrypt(Key, Nonce, /*Aad*/ {}, CipherView, MakeMemoryView(Tag), MakeMutableMemoryView(Decoded), Reason);
+ REQUIRE(!Reason.has_value());
+ CHECK_EQ(DecodedView.GetSize(), Plain.size());
+
+ std::string_view DecodedText(reinterpret_cast<const char*>(DecodedView.GetData()), DecodedView.GetSize());
+ CHECK_EQ(DecodedText, Plain);
+ }
+
+ SUBCASE("round trip with AAD")
+ {
+ std::string_view Plain = "payload"sv;
+ std::string_view Aad = "header bits that are authenticated but not encrypted"sv;
+
+ std::vector<uint8_t> Cipher(Plain.size());
+ std::vector<uint8_t> Tag(AesGcm::TagSize);
+ std::optional<std::string> Reason;
+
+ MemoryView CipherView = AesGcm::Encrypt(Key,
+ Nonce,
+ MakeMemoryView(Aad),
+ MakeMemoryView(Plain),
+ MakeMutableMemoryView(Cipher),
+ MakeMutableMemoryView(Tag),
+ Reason);
+ REQUIRE(!Reason.has_value());
+
+ std::vector<uint8_t> Decoded(Plain.size());
+ MemoryView DecodedView =
+ AesGcm::Decrypt(Key, Nonce, MakeMemoryView(Aad), CipherView, MakeMemoryView(Tag), MakeMutableMemoryView(Decoded), Reason);
+ REQUIRE(!Reason.has_value());
+ CHECK_EQ(DecodedView.GetSize(), Plain.size());
+ }
+
+ SUBCASE("tampered ciphertext fails authentication")
+ {
+ std::string_view Plain = "important"sv;
+
+ std::vector<uint8_t> Cipher(Plain.size());
+ std::vector<uint8_t> Tag(AesGcm::TagSize);
+ std::optional<std::string> Reason;
+
+ MemoryView CipherView = AesGcm::Encrypt(Key,
+ Nonce,
+ /*Aad*/ {},
+ MakeMemoryView(Plain),
+ MakeMutableMemoryView(Cipher),
+ MakeMutableMemoryView(Tag),
+ Reason);
+ REQUIRE(!Reason.has_value());
+
+ // Flip a bit in the ciphertext.
+ Cipher[0] ^= 0x01;
+
+ std::vector<uint8_t> Decoded(Plain.size());
+ MemoryView DecodedView =
+ AesGcm::Decrypt(Key, Nonce, /*Aad*/ {}, CipherView, MakeMemoryView(Tag), MakeMutableMemoryView(Decoded), Reason);
+ CHECK(Reason.has_value());
+ CHECK(DecodedView.IsEmpty());
+ }
+
+ SUBCASE("tampered tag fails authentication")
+ {
+ std::string_view Plain = "important"sv;
+
+ std::vector<uint8_t> Cipher(Plain.size());
+ std::vector<uint8_t> Tag(AesGcm::TagSize);
+ std::optional<std::string> Reason;
+
+ MemoryView CipherView = AesGcm::Encrypt(Key,
+ Nonce,
+ /*Aad*/ {},
+ MakeMemoryView(Plain),
+ MakeMutableMemoryView(Cipher),
+ MakeMutableMemoryView(Tag),
+ Reason);
+ REQUIRE(!Reason.has_value());
+
+ Tag[0] ^= 0x80;
+
+ std::vector<uint8_t> Decoded(Plain.size());
+ MemoryView DecodedView =
+ AesGcm::Decrypt(Key, Nonce, /*Aad*/ {}, CipherView, MakeMemoryView(Tag), MakeMutableMemoryView(Decoded), Reason);
+ CHECK(Reason.has_value());
+ CHECK(DecodedView.IsEmpty());
+ }
+
+ SUBCASE("AAD mismatch fails authentication")
+ {
+ std::string_view Plain = "payload"sv;
+ std::string_view AadOk = "expected header"sv;
+ std::string_view AadNo = "different header"sv;
+
+ std::vector<uint8_t> Cipher(Plain.size());
+ std::vector<uint8_t> Tag(AesGcm::TagSize);
+ std::optional<std::string> Reason;
+
+ MemoryView CipherView = AesGcm::Encrypt(Key,
+ Nonce,
+ MakeMemoryView(AadOk),
+ MakeMemoryView(Plain),
+ MakeMutableMemoryView(Cipher),
+ MakeMutableMemoryView(Tag),
+ Reason);
+ REQUIRE(!Reason.has_value());
+
+ std::vector<uint8_t> Decoded(Plain.size());
+ MemoryView DecodedView =
+ AesGcm::Decrypt(Key, Nonce, MakeMemoryView(AadNo), CipherView, MakeMemoryView(Tag), MakeMutableMemoryView(Decoded), Reason);
+ CHECK(Reason.has_value());
+ CHECK(DecodedView.IsEmpty());
+ }
+
+ SUBCASE("wrong nonce size is rejected")
+ {
+ std::string_view Plain = "x"sv;
+
+ const uint8_t TooShort[8] = {0};
+ std::vector<uint8_t> Cipher(Plain.size());
+ std::vector<uint8_t> Tag(AesGcm::TagSize);
+ std::optional<std::string> Reason;
+
+ MemoryView CipherView = AesGcm::Encrypt(Key,
+ MakeMemoryView(TooShort),
+ /*Aad*/ {},
+ MakeMemoryView(Plain),
+ MakeMutableMemoryView(Cipher),
+ MakeMutableMemoryView(Tag),
+ Reason);
+ CHECK(Reason.has_value());
+ CHECK(CipherView.IsEmpty());
+ }
+}
+
+TEST_CASE("crypto.random")
+{
+ SUBCASE("fills buffer with non-zero bytes")
+ {
+ uint8_t Buffer[32] = {};
+ MutableMemoryView View = MakeMutableMemoryView(Buffer);
+ const bool Ok = CryptoRandom::Fill(View);
+ REQUIRE(Ok);
+
+ // Probability of 32 all-zero bytes from a CSPRNG is 2^-256 — we
+ // accept it as "effectively never".
+ bool AnyNonZero = false;
+ for (uint8_t B : Buffer)
+ {
+ if (B != 0)
+ {
+ AnyNonZero = true;
+ break;
+ }
+ }
+ CHECK(AnyNonZero);
+ }
+
+ SUBCASE("two calls produce different output")
+ {
+ uint8_t A[32] = {};
+ uint8_t B[32] = {};
+ CHECK(CryptoRandom::Fill(MakeMutableMemoryView(A)));
+ CHECK(CryptoRandom::Fill(MakeMutableMemoryView(B)));
+ CHECK(memcmp(A, B, 32) != 0);
+ }
+
+ SUBCASE("zero-size buffer is a no-op success")
+ {
+ uint8_t Dummy = 0xAB;
+ CHECK(CryptoRandom::Fill(MutableMemoryView(&Dummy, size_t{0})));
+ CHECK_EQ(Dummy, 0xAB);
+ }
+}
+
TEST_CASE("crypto.securerandom")
{
std::array<uint8_t, 64> A{};
diff --git a/src/zencore/filesystem.cpp b/src/zencore/filesystem.cpp
index e8ceac5c0..1e18ef8be 100644
--- a/src/zencore/filesystem.cpp
+++ b/src/zencore/filesystem.cpp
@@ -181,7 +181,7 @@ WipeDirectory(const wchar_t* DirPath, bool KeepDotFiles)
bool Success = true;
- if (hFind != nullptr)
+ if (hFind != INVALID_HANDLE_VALUE)
{
do
{
@@ -436,15 +436,18 @@ RemoveFileNative(const std::filesystem::path& Path, bool ForceRemoveReadOnlyFile
if (!ForceRemoveReadOnlyFiles)
{
struct stat Stat;
- int err = stat(Path.native().c_str(), &Stat);
- if (err != 0)
+ if (stat(Path.native().c_str(), &Stat) != 0)
{
- int32_t err = errno;
- if (err == ENOENT)
+ const int StatErrno = errno;
+ if (StatErrno == ENOENT)
{
Ec.clear();
- return false;
}
+ else
+ {
+ Ec = MakeErrorCode(StatErrno);
+ }
+ return false;
}
const uint32_t Mode = (uint32_t)Stat.st_mode;
if (IsFileModeReadOnly(Mode))
@@ -3437,6 +3440,37 @@ MakeSafeAbsolutePath(const std::filesystem::path& Path)
return Tmp;
}
+std::optional<std::filesystem::path>
+ResolveSafeRelativePath(const std::filesystem::path& TrustedRoot, std::string_view RelativePath)
+{
+ if (RelativePath.empty())
+ {
+ return std::nullopt;
+ }
+
+ std::filesystem::path Requested(RelativePath);
+ if (Requested.is_absolute() || Requested.has_root_name() || Requested.has_root_directory())
+ {
+ return std::nullopt;
+ }
+ for (const std::filesystem::path& Component : Requested)
+ {
+ if (Component == "..")
+ {
+ return std::nullopt;
+ }
+ }
+
+ const std::filesystem::path NormalizedRoot = TrustedRoot.lexically_normal();
+ const std::filesystem::path Joined = (NormalizedRoot / Requested).lexically_normal();
+ if (std::mismatch(NormalizedRoot.begin(), NormalizedRoot.end(), Joined.begin(), Joined.end()).first != NormalizedRoot.end())
+ {
+ return std::nullopt;
+ }
+
+ return Joined;
+}
+
class SharedMemoryImpl : public SharedMemory
{
public:
@@ -4238,6 +4272,41 @@ TEST_CASE("filesystem.MakeSafeAbsolutePath")
# endif // ZEN_PLATFORM_WINDOWS
}
+TEST_CASE("filesystem.ResolveSafeRelativePath")
+{
+ const std::filesystem::path Root = std::filesystem::path("root") / "traces";
+
+ // Empty input is rejected.
+ CHECK_FALSE(ResolveSafeRelativePath(Root, "").has_value());
+
+ // A plain relative path resolves under the root.
+ {
+ auto Resolved = ResolveSafeRelativePath(Root, "session.utrace");
+ REQUIRE(Resolved.has_value());
+ CHECK_EQ(*Resolved, (Root.lexically_normal() / "session.utrace"));
+ }
+
+ // Nested relative segments are allowed as long as they stay inside the root.
+ {
+ auto Resolved = ResolveSafeRelativePath(Root, "2026-04/session.utrace");
+ REQUIRE(Resolved.has_value());
+ CHECK_EQ(*Resolved, (Root.lexically_normal() / "2026-04" / "session.utrace"));
+ }
+
+ // ".." components are rejected before normalisation can collapse them.
+ CHECK_FALSE(ResolveSafeRelativePath(Root, "..").has_value());
+ CHECK_FALSE(ResolveSafeRelativePath(Root, "../etc/passwd").has_value());
+ CHECK_FALSE(ResolveSafeRelativePath(Root, "foo/../../bar").has_value());
+
+ // Absolute paths are rejected on both platforms.
+ CHECK_FALSE(ResolveSafeRelativePath(Root, "/etc/passwd").has_value());
+# if ZEN_PLATFORM_WINDOWS
+ CHECK_FALSE(ResolveSafeRelativePath(Root, "C:/Windows/win.ini").has_value());
+ CHECK_FALSE(ResolveSafeRelativePath(Root, "C:\\Windows\\win.ini").has_value());
+ CHECK_FALSE(ResolveSafeRelativePath(Root, "\\\\server\\share\\evil").has_value());
+# endif
+}
+
TEST_CASE("ExpandEnvironmentVariables")
{
// No variables - pass-through
diff --git a/src/zencore/include/zencore/crypto.h b/src/zencore/include/zencore/crypto.h
index a5e23135f..36bf80527 100644
--- a/src/zencore/include/zencore/crypto.h
+++ b/src/zencore/include/zencore/crypto.h
@@ -73,6 +73,81 @@ public:
std::optional<std::string>& Reason);
};
+/**
+ * AES-256-GCM authenticated encryption.
+ *
+ * GCM is an AEAD construction: the authentication tag proves both the
+ * ciphertext and any additional associated data (AAD) were not modified.
+ * Use this instead of plain Aes (CBC) whenever a persistent artefact
+ * could be tampered with by a local attacker — CBC has no integrity
+ * check and will happily decrypt bit-flipped ciphertext into corrupted
+ * plaintext.
+ *
+ * Nonce must be exactly 12 bytes (NIST SP 800-38D recommended size).
+ * Tag is always exactly 16 bytes.
+ *
+ * CRITICAL: never reuse a (key, nonce) pair. GCM is catastrophically
+ * broken under nonce reuse — both confidentiality and authenticity fail.
+ * For every encryption, the caller must supply a fresh nonce drawn from
+ * a CSPRNG (a 96-bit random nonce has negligible collision probability
+ * for any reasonable message volume) or from a strictly monotonic
+ * counter under a single key.
+ */
+class AesGcm
+{
+public:
+ static constexpr size_t NonceSize = 12; // bytes
+ static constexpr size_t TagSize = 16; // bytes
+
+ /**
+ * Encrypt Plaintext. Writes ciphertext of the same length as Plaintext
+ * into Out, and the 16-byte authentication tag into OutTag. Aad may be
+ * empty. Returns the ciphertext view on success, or an empty view with
+ * Reason set on failure.
+ */
+ static MemoryView Encrypt(const AesKey256Bit& Key,
+ MemoryView Nonce,
+ MemoryView Aad,
+ MemoryView Plaintext,
+ MutableMemoryView Out,
+ MutableMemoryView OutTag,
+ std::optional<std::string>& Reason);
+
+ /**
+ * Decrypt Ciphertext, verifying against the supplied 16-byte Tag.
+ * Returns the plaintext view on success. On authentication failure or
+ * any other error, returns an empty view with Reason set; the Out
+ * buffer contents are undefined in that case.
+ */
+ static MemoryView Decrypt(const AesKey256Bit& Key,
+ MemoryView Nonce,
+ MemoryView Aad,
+ MemoryView Ciphertext,
+ MemoryView Tag,
+ MutableMemoryView Out,
+ std::optional<std::string>& Reason);
+};
+
+/**
+ * Cryptographically secure random-byte source.
+ *
+ * Uses the OS CSPRNG where available (BCryptGenRandom on Windows,
+ * RAND_bytes on OpenSSL). Suitable for generating keys, nonces, and
+ * any other value whose predictability matters — do NOT use
+ * std::mt19937 or std::random_device alone for these purposes.
+ */
+class CryptoRandom
+{
+public:
+ /**
+ * Fill Buffer with cryptographically random bytes. Returns true on
+ * success; on failure returns false and writes a human-readable cause
+ * to Reason if provided. On failure the buffer contents are
+ * unspecified.
+ */
+ static bool Fill(MutableMemoryView Buffer, std::optional<std::string>* Reason = nullptr);
+};
+
// Fill Out with cryptographically secure random bytes from the platform RNG.
// Returns false if the platform RNG failed; on success Out is filled entirely.
bool SecureRandomBytes(MutableMemoryView Out);
diff --git a/src/zencore/include/zencore/filesystem.h b/src/zencore/include/zencore/filesystem.h
index 6218c3421..8f6209f79 100644
--- a/src/zencore/include/zencore/filesystem.h
+++ b/src/zencore/include/zencore/filesystem.h
@@ -451,6 +451,19 @@ bool SetFileReadOnly(const std::filesystem::path& Filename, bool ReadOnly);
void MakeSafeAbsolutePathInPlace(std::filesystem::path& Path);
[[nodiscard]] std::filesystem::path MakeSafeAbsolutePath(const std::filesystem::path& Path);
+/** Resolve an attacker-supplied relative path underneath a trusted root directory.
+ *
+ * Rejects absolute paths, drive-qualified paths, root-directory paths, and any
+ * component equal to "..". Lexically-normalises the joined path and verifies the
+ * result still lives under the normalised root. Intended for validating HTTP
+ * query-parameter paths before passing them to filesystem APIs.
+ *
+ * Returns nullopt if RelativePath is empty, contains rejected syntax, or escapes
+ * TrustedRoot after normalisation.
+ */
+[[nodiscard]] std::optional<std::filesystem::path> ResolveSafeRelativePath(const std::filesystem::path& TrustedRoot,
+ std::string_view RelativePath);
+
class SharedMemory
{
public:
diff --git a/src/zencore/include/zencore/sharedbuffer.h b/src/zencore/include/zencore/sharedbuffer.h
index 3183c7c0c..0fed11458 100644
--- a/src/zencore/include/zencore/sharedbuffer.h
+++ b/src/zencore/include/zencore/sharedbuffer.h
@@ -145,6 +145,12 @@ public:
[[nodiscard]] inline IoBuffer AsIoBuffer() const { return IoBuffer(m_Buffer.Get()); }
+ /** Returns true if the segment is backed by an extended core (e.g. a file range) whose data
+ * is materialized on demand. For extended segments, sub-ranges should be taken via
+ * IoBuffer(outer, offset, size) to avoid materializing the whole segment. Plain memory
+ * segments (IsExtended() == false) can be sub-viewed directly with GetView().Mid(...). */
+ [[nodiscard]] inline bool IsExtended() const { return m_Buffer && m_Buffer->ExtendedCore() != nullptr; }
+
SharedBuffer& operator=(UniqueBuffer&& Rhs)
{
m_Buffer = std::move(Rhs.m_Buffer);
diff --git a/src/zencore/include/zencore/string.h b/src/zencore/include/zencore/string.h
index fded960f3..53bfd196d 100644
--- a/src/zencore/include/zencore/string.h
+++ b/src/zencore/include/zencore/string.h
@@ -1058,6 +1058,34 @@ StrCaseEndsWith(std::string_view Str, std::string_view Suffix)
}
/**
+ * Constant-time string equality. Unlike std::string_view::operator==, which
+ * short-circuits on the first byte mismatch (and on size mismatch before
+ * inspecting any bytes), this helper always iterates over max(Lhs.size(),
+ * Rhs.size()) bytes and ORs every byte difference into an accumulator. The
+ * length-compare is folded into the final check so that neither length
+ * mismatches nor content differences produce a timing signal an attacker
+ * can observe remotely.
+ *
+ * Use for comparisons of secret-bearing data (passwords, bearer tokens,
+ * HMAC outputs) against an attacker-supplied value.
+ */
+inline bool
+ConstantTimeEquals(std::string_view Lhs, std::string_view Rhs)
+{
+ const size_t N = std::max(Lhs.size(), Rhs.size());
+ uint32_t Accum = static_cast<uint32_t>(Lhs.size() ^ Rhs.size());
+ const size_t LhsLen = Lhs.size();
+ const size_t RhsLen = Rhs.size();
+ for (size_t I = 0; I < N; ++I)
+ {
+ const uint8_t LByte = (I < LhsLen) ? static_cast<uint8_t>(Lhs[I]) : 0;
+ const uint8_t RByte = (I < RhsLen) ? static_cast<uint8_t>(Rhs[I]) : 0;
+ Accum |= uint32_t(LByte ^ RByte);
+ }
+ return Accum == 0;
+}
+
+/**
* @brief
* Helper function to implement case sensitive spaceship operator for strings.
* MacOS clang version we use does not implement <=> for std::string
diff --git a/src/zencore/iobuffer.cpp b/src/zencore/iobuffer.cpp
index 529afe341..d1aa0bb9e 100644
--- a/src/zencore/iobuffer.cpp
+++ b/src/zencore/iobuffer.cpp
@@ -10,6 +10,7 @@
#include <zencore/memory/llm.h>
#include <zencore/memory/memory.h>
#include <zencore/memoryview.h>
+#include <zencore/scopeguard.h>
#include <zencore/testing.h>
#include <zencore/thread.h>
#include <zencore/trace.h>
@@ -375,7 +376,14 @@ IoBufferExtendedCore::Materialize() const
/* offset */ MapOffset);
#endif // ZEN_PLATFORM_WINDOWS
- if (MappedBase == nullptr)
+#if ZEN_PLATFORM_WINDOWS
+ const bool MapFailed = (MappedBase == nullptr);
+#else
+ // mmap returns MAP_FAILED (not nullptr) on failure
+ const bool MapFailed = (MappedBase == MAP_FAILED);
+#endif
+
+ if (MapFailed)
{
int32_t Error = zen::GetLastError();
#if ZEN_PLATFORM_WINDOWS
@@ -540,32 +548,54 @@ IoBufferBuilder::ReadFromFileMaybe(const IoBuffer& InBuffer)
const uint64_t NumberOfBytesToRead = FileRef.FileChunkSize;
const uint64_t FileOffset = FileRef.FileChunkOffset;
+ uint8_t* const Dst = reinterpret_cast<uint8_t*>(OutBuffer.MutableData());
+
#if ZEN_PLATFORM_WINDOWS
- OVERLAPPED Ovl{};
+ while (BytesRead < NumberOfBytesToRead)
+ {
+ const uint64_t Remaining = NumberOfBytesToRead - BytesRead;
+ const uint64_t ChunkStart = FileOffset + BytesRead;
+ const DWORD ChunkSize = DWORD(Min<uint64_t>(Remaining, 0xffff'ffffu));
- Ovl.Offset = DWORD(FileOffset & 0xffff'ffffu);
- Ovl.OffsetHigh = DWORD(FileOffset >> 32);
+ OVERLAPPED Ovl{};
+ Ovl.Offset = DWORD(ChunkStart & 0xffff'ffffu);
+ Ovl.OffsetHigh = DWORD(ChunkStart >> 32);
- DWORD dwNumberOfBytesRead = 0;
- BOOL Success = ::ReadFile(FileRef.FileHandle, OutBuffer.MutableData(), DWORD(NumberOfBytesToRead), &dwNumberOfBytesRead, &Ovl);
- if (Success)
- {
- BytesRead = size_t(dwNumberOfBytesRead);
- }
- else
- {
- Error = zen::GetLastError();
+ DWORD dwBytesRead = 0;
+ if (!::ReadFile(FileRef.FileHandle, Dst + BytesRead, ChunkSize, &dwBytesRead, &Ovl))
+ {
+ Error = zen::GetLastError();
+ break;
+ }
+ if (dwBytesRead == 0)
+ {
+ // Hit EOF before we got everything we asked for
+ break;
+ }
+ BytesRead += size_t(dwBytesRead);
}
#else
int Fd = int(intptr_t(FileRef.FileHandle));
- ssize_t ReadResult = pread(Fd, OutBuffer.MutableData(), size_t(NumberOfBytesToRead), off_t(FileOffset));
- if (ReadResult != -1)
+ while (BytesRead < NumberOfBytesToRead)
{
- BytesRead = size_t(ReadResult);
- }
- else
- {
- Error = zen::GetLastError();
+ const size_t Remaining = size_t(NumberOfBytesToRead - BytesRead);
+ const off_t ChunkStart = off_t(FileOffset + BytesRead);
+ const ssize_t ReadResult = pread(Fd, Dst + BytesRead, Remaining, ChunkStart);
+ if (ReadResult < 0)
+ {
+ if (errno == EINTR)
+ {
+ continue;
+ }
+ Error = zen::GetLastError();
+ break;
+ }
+ if (ReadResult == 0)
+ {
+ // Hit EOF before we got everything we asked for
+ break;
+ }
+ BytesRead += size_t(ReadResult);
}
#endif
@@ -632,10 +662,19 @@ IoBufferBuilder::MakeFromFile(const std::filesystem::path& FileName, uint64_t Of
{
return {};
}
+ auto FdGuard = MakeGuard([&Fd] {
+ if (Fd >= 0)
+ {
+ close(Fd);
+ }
+ });
static_assert(sizeof(decltype(stat::st_size)) == sizeof(uint64_t), "fstat() doesn't support large files");
struct stat Stat;
- fstat(Fd, &Stat);
+ if (fstat(Fd, &Stat) != 0)
+ {
+ return {};
+ }
FileSize = Stat.st_size;
#endif // ZEN_PLATFORM_WINDOWS
@@ -655,18 +694,20 @@ IoBufferBuilder::MakeFromFile(const std::filesystem::path& FileName, uint64_t Of
if (Size)
{
#if ZEN_PLATFORM_WINDOWS
- void* Fd = DataFile.Detach();
-#endif
+ void* Fd = DataFile.Detach();
+ auto HandleGuard = MakeGuard([Fd] { CloseHandle((HANDLE)Fd); });
+ IoBuffer NewBuffer(IoBuffer::File, Fd, Offset, Size, Offset == 0 && Size == FileSize);
+ HandleGuard.Dismiss();
+#else
IoBuffer NewBuffer(IoBuffer::File, (void*)uintptr_t(Fd), Offset, Size, Offset == 0 && Size == FileSize);
+ FdGuard.Dismiss();
+#endif
NewBuffer.SetContentType(ContentType);
return NewBuffer;
}
-#if !ZEN_PLATFORM_WINDOWS
- close(Fd);
-#endif
-
// For an empty file, we may as well just return an empty memory IoBuffer
+ // (FdGuard on non-Windows closes Fd automatically; DataFile destructor closes the HANDLE on Windows)
return IoBuffer(IoBuffer::Wrap, "", 0);
}
@@ -694,23 +735,33 @@ IoBufferBuilder::MakeFromTemporaryFile(const std::filesystem::path& FileName, Ze
DataFile.GetSize((ULONGLONG&)FileSize);
- Handle = DataFile.Detach();
+ Handle = DataFile.Detach();
+ auto HandleGuard = MakeGuard([Handle] { CloseHandle((HANDLE)Handle); });
#else
int Fd = open(FileName.native().c_str(), O_RDONLY);
if (Fd < 0)
{
return {};
}
+ auto FdGuard = MakeGuard([Fd] { close(Fd); });
static_assert(sizeof(decltype(stat::st_size)) == sizeof(uint64_t), "fstat() doesn't support large files");
struct stat Stat;
- fstat(Fd, &Stat);
+ if (fstat(Fd, &Stat) != 0)
+ {
+ return {};
+ }
FileSize = Stat.st_size;
Handle = (void*)uintptr_t(Fd);
#endif // ZEN_PLATFORM_WINDOWS
IoBuffer NewBuffer(IoBuffer::File, Handle, 0, FileSize, /*IsWholeFile*/ true);
+#if ZEN_PLATFORM_WINDOWS
+ HandleGuard.Dismiss();
+#else
+ FdGuard.Dismiss();
+#endif
NewBuffer.SetContentType(ContentType);
return NewBuffer;
}
diff --git a/src/zenhttp/auth/authmgr.cpp b/src/zenhttp/auth/authmgr.cpp
index 2fa22f2c2..9d10b1ba6 100644
--- a/src/zenhttp/auth/authmgr.cpp
+++ b/src/zenhttp/auth/authmgr.cpp
@@ -12,6 +12,11 @@
#include <zencore/trace.h>
#include <zenhttp/auth/oidc.h>
+#if ZEN_WITH_TESTS
+# include <zencore/testing.h>
+# include <zencore/testutils.h>
+#endif
+
#include <condition_variable>
#include <memory>
#include <shared_mutex>
@@ -25,13 +30,37 @@ namespace zen {
using namespace std::literals;
namespace details {
+ // On-disk format for the GCM-encrypted authstate:
+ //
+ // [ 4 bytes: magic "ZEN1" ]
+ // [12 bytes: GCM nonce ] — fresh random value per write
+ // [ N bytes: ciphertext ] — same length as the plaintext
+ // [16 bytes: GCM tag ]
+ //
+ // The magic is bound into the authentication by passing it as AAD to
+ // AesGcm; that ties the tag to the exact header bytes, so an attacker
+ // cannot strip the format indicator or swap to a different format without
+ // failing authentication.
+ //
+ // The magic also doubles as the format version. If the framing ever
+ // needs to change (e.g. adding a key-identifier for rotation), bump to
+ // "ZEN2" and teach the reader to accept both.
+ constexpr std::string_view kAuthStateMagic = "ZEN1";
+ constexpr size_t kAuthStateHeaderSz = 4 + AesGcm::NonceSize;
+
IoBuffer ReadEncryptedFile(std::filesystem::path Path,
const AesKey256Bit& Key,
- const AesIV128Bit& IV,
- std::optional<std::string>& Reason)
+ const AesIV128Bit& LegacyIV,
+ std::optional<std::string>& Reason,
+ bool* OutWasLegacy = nullptr)
{
ZEN_TRACE_CPU("AuthMgr::ReadEncryptedFile");
+ if (OutWasLegacy)
+ {
+ *OutWasLegacy = false;
+ }
+
FileContents Result = ReadFile(Path);
if (Result.ErrorCode)
@@ -46,24 +75,52 @@ namespace details {
return IoBuffer();
}
+ // Current format: magic-prefixed AES-256-GCM
+ if (EncryptedBuffer.GetSize() >= kAuthStateHeaderSz + AesGcm::TagSize &&
+ memcmp(EncryptedBuffer.GetData(), kAuthStateMagic.data(), kAuthStateMagic.size()) == 0)
+ {
+ const uint8_t* Bytes = static_cast<const uint8_t*>(EncryptedBuffer.GetData());
+ const size_t Total = EncryptedBuffer.GetSize();
+ const size_t CipherSize = Total - kAuthStateHeaderSz - AesGcm::TagSize;
+
+ MemoryView Magic(Bytes, kAuthStateMagic.size());
+ MemoryView Nonce(Bytes + kAuthStateMagic.size(), AesGcm::NonceSize);
+ MemoryView Cipher(Bytes + kAuthStateHeaderSz, CipherSize);
+ MemoryView Tag(Bytes + kAuthStateHeaderSz + CipherSize, AesGcm::TagSize);
+
+ std::vector<uint8_t> PlainBuffer(CipherSize);
+ MemoryView PlainView = AesGcm::Decrypt(Key, Nonce, /*Aad=*/Magic, Cipher, Tag, MakeMutableMemoryView(PlainBuffer), Reason);
+
+ if (PlainView.IsEmpty())
+ {
+ return IoBuffer();
+ }
+
+ return IoBufferBuilder::MakeCloneFromMemory(PlainView);
+ }
+
+ // Legacy format: raw AES-256-CBC with the caller-configured IV.
+ // Decrypt it once so we don't lose the user's cached state across
+ // the upgrade; the next SaveState will rewrite in the GCM format.
std::vector<uint8_t> DecryptionBuffer;
DecryptionBuffer.resize(EncryptedBuffer.GetSize() + Aes::BlockSize);
- MemoryView DecryptedView = Aes::Decrypt(Key, IV, EncryptedBuffer, MakeMutableMemoryView(DecryptionBuffer), Reason);
+ MemoryView DecryptedView = Aes::Decrypt(Key, LegacyIV, EncryptedBuffer, MakeMutableMemoryView(DecryptionBuffer), Reason);
if (DecryptedView.IsEmpty())
{
return IoBuffer();
}
+ if (OutWasLegacy)
+ {
+ *OutWasLegacy = true;
+ }
+
return IoBufferBuilder::MakeCloneFromMemory(DecryptedView);
}
- void WriteEncryptedFile(std::filesystem::path Path,
- IoBuffer FileData,
- const AesKey256Bit& Key,
- const AesIV128Bit& IV,
- std::optional<std::string>& Reason)
+ void WriteEncryptedFile(std::filesystem::path Path, IoBuffer FileData, const AesKey256Bit& Key, std::optional<std::string>& Reason)
{
ZEN_TRACE_CPU("AuthMgr::WriteEncryptedFile");
@@ -72,17 +129,44 @@ namespace details {
return;
}
- std::vector<uint8_t> EncryptionBuffer;
- EncryptionBuffer.resize(FileData.GetSize() + Aes::BlockSize);
+ if (!Key.IsValid())
+ {
+ Reason = "invalid key";
+ return;
+ }
+
+ // Fresh nonce per write. Never reuse a (Key, Nonce) pair — GCM is
+ // catastrophically broken under nonce reuse.
+ uint8_t Nonce[AesGcm::NonceSize];
+ if (!CryptoRandom::Fill(MakeMutableMemoryView(Nonce), &Reason))
+ {
+ return;
+ }
+
+ std::vector<uint8_t> FileBuffer;
+ FileBuffer.resize(kAuthStateHeaderSz + FileData.GetSize() + AesGcm::TagSize);
+
+ // [magic][nonce][cipher][tag]
+ memcpy(FileBuffer.data(), kAuthStateMagic.data(), kAuthStateMagic.size());
+ memcpy(FileBuffer.data() + kAuthStateMagic.size(), Nonce, AesGcm::NonceSize);
- MemoryView EncryptedView = Aes::Encrypt(Key, IV, FileData, MakeMutableMemoryView(EncryptionBuffer), Reason);
+ MutableMemoryView CipherOut(FileBuffer.data() + kAuthStateHeaderSz, FileData.GetSize());
+ MutableMemoryView TagOut(FileBuffer.data() + kAuthStateHeaderSz + FileData.GetSize(), AesGcm::TagSize);
- if (EncryptedView.IsEmpty())
+ MemoryView CipherView = AesGcm::Encrypt(Key,
+ MakeMemoryView(Nonce),
+ /*Aad=*/MakeMemoryView(kAuthStateMagic),
+ FileData.GetView(),
+ CipherOut,
+ TagOut,
+ Reason);
+
+ if (CipherView.IsEmpty())
{
return;
}
- TemporaryFile::SafeWriteFile(Path, EncryptedView);
+ TemporaryFile::SafeWriteFile(Path, MakeMemoryView(FileBuffer));
}
} // namespace details
@@ -191,10 +275,9 @@ public:
bool IsNew = false;
{
- auto Token = OpenIdToken{.IdentityToken = RefreshResult.IdentityToken,
- .RefreshToken = RefreshResult.RefreshToken,
- .AccessToken = fmt::format("Bearer {}"sv, RefreshResult.AccessToken),
- .ExpireTime = Clock::now() + Seconds(RefreshResult.ExpiresInSeconds)};
+ auto Token = OpenIdToken{.RefreshToken = RefreshResult.RefreshToken,
+ .AccessToken = fmt::format("Bearer {}"sv, RefreshResult.AccessToken),
+ .ExpireTime = Clock::now() + Seconds(RefreshResult.ExpiresInSeconds)};
std::unique_lock _(m_TokenMutex);
@@ -240,7 +323,6 @@ private:
struct OpenIdToken
{
- std::string IdentityToken;
std::string RefreshToken;
std::string AccessToken;
TimePoint ExpireTime{};
@@ -283,9 +365,18 @@ private:
try
{
std::optional<std::string> Reason;
+ bool WasLegacy = false;
+
+ IoBuffer Buffer = details::ReadEncryptedFile(m_Config.RootDirectory / "authstate"sv,
+ m_Config.EncryptionKey,
+ m_Config.EncryptionIV,
+ Reason,
+ &WasLegacy);
- IoBuffer Buffer =
- details::ReadEncryptedFile(m_Config.RootDirectory / "authstate"sv, m_Config.EncryptionKey, m_Config.EncryptionIV, Reason);
+ if (Buffer && WasLegacy)
+ {
+ ZEN_INFO("authstate read via legacy AES-CBC fallback; next save will migrate to AES-GCM");
+ }
if (!Buffer)
{
@@ -399,7 +490,6 @@ private:
details::WriteEncryptedFile(m_Config.RootDirectory / "authstate"sv,
AuthState.Save().GetBuffer().AsIoBuffer(),
m_Config.EncryptionKey,
- m_Config.EncryptionIV,
Reason);
if (Reason)
@@ -466,10 +556,9 @@ private:
{
ZEN_DEBUG("refresh access token from provider '{}' Ok", Kv.first);
- auto Token = OpenIdToken{.IdentityToken = RefreshResult.IdentityToken,
- .RefreshToken = RefreshResult.RefreshToken,
- .AccessToken = fmt::format("Bearer {}"sv, RefreshResult.AccessToken),
- .ExpireTime = Clock::now() + Seconds(RefreshResult.ExpiresInSeconds)};
+ auto Token = OpenIdToken{.RefreshToken = RefreshResult.RefreshToken,
+ .AccessToken = fmt::format("Bearer {}"sv, RefreshResult.AccessToken),
+ .ExpireTime = Clock::now() + Seconds(RefreshResult.ExpiresInSeconds)};
{
std::unique_lock _(m_TokenMutex);
@@ -534,4 +623,107 @@ AuthMgr::Create(const AuthConfig& Config)
return std::make_unique<AuthMgrImpl>(Config);
}
+#if ZEN_WITH_TESTS
+
+TEST_SUITE_BEGIN("http.authmgr");
+
+TEST_CASE("authmgr.authstate_gcm_roundtrip")
+{
+ ScopedTemporaryDirectory TmpDir;
+ std::filesystem::path Path = TmpDir.Path() / "authstate";
+
+ const AesKey256Bit Key = AesKey256Bit::FromString("abcdefghijklmnopqrstuvxyz0123456"sv);
+ const AesIV128Bit UnusedIv; // ignored on write, unused on GCM read
+
+ const std::string_view Plain = "sensitive compact-binary payload representing cached auth state"sv;
+ IoBuffer InBuf = IoBufferBuilder::MakeCloneFromMemory(MakeMemoryView(Plain));
+
+ std::optional<std::string> WriteReason;
+ details::WriteEncryptedFile(Path, InBuf, Key, WriteReason);
+ REQUIRE_FALSE(WriteReason.has_value());
+
+ std::optional<std::string> ReadReason;
+ bool WasLegacy = true;
+ IoBuffer Out = details::ReadEncryptedFile(Path, Key, UnusedIv, ReadReason, &WasLegacy);
+ REQUIRE_FALSE(ReadReason.has_value());
+ REQUIRE(Out.GetSize() == Plain.size());
+ CHECK(memcmp(Out.GetData(), Plain.data(), Plain.size()) == 0);
+ CHECK_FALSE(WasLegacy);
+}
+
+TEST_CASE("authmgr.authstate_legacy_cbc_fallback")
+{
+ // Hand-craft a legacy-format authstate file (raw AES-CBC, no magic) and
+ // verify the reader falls back, returns the plaintext, and flags the file
+ // as legacy so the caller can rewrite it in the new format.
+ ScopedTemporaryDirectory TmpDir;
+ std::filesystem::path Path = TmpDir.Path() / "authstate";
+
+ const AesKey256Bit Key = AesKey256Bit::FromString("abcdefghijklmnopqrstuvxyz0123456"sv);
+ const uint8_t IvBytes[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};
+ const AesIV128Bit LegacyIv = AesIV128Bit::FromMemoryView(MakeMemoryView(IvBytes));
+
+ const std::string_view Plain = "legacy authstate payload"sv;
+
+ std::vector<uint8_t> CbcBuf(Plain.size() + Aes::BlockSize);
+ std::optional<std::string> CbcReason;
+ MemoryView CbcView = Aes::Encrypt(Key, LegacyIv, MakeMemoryView(Plain), MakeMutableMemoryView(CbcBuf), CbcReason);
+ REQUIRE_FALSE(CbcReason.has_value());
+
+ TemporaryFile::SafeWriteFile(Path, CbcView);
+
+ std::optional<std::string> ReadReason;
+ bool WasLegacy = false;
+ IoBuffer Out = details::ReadEncryptedFile(Path, Key, LegacyIv, ReadReason, &WasLegacy);
+ REQUIRE_FALSE(ReadReason.has_value());
+ REQUIRE(Out.GetSize() == Plain.size());
+ CHECK(memcmp(Out.GetData(), Plain.data(), Plain.size()) == 0);
+ CHECK(WasLegacy);
+}
+
+TEST_CASE("authmgr.authstate_gcm_tamper_detection")
+{
+ ScopedTemporaryDirectory TmpDir;
+ std::filesystem::path Path = TmpDir.Path() / "authstate";
+
+ const AesKey256Bit Key = AesKey256Bit::FromString("abcdefghijklmnopqrstuvxyz0123456"sv);
+ const AesIV128Bit UnusedIv;
+
+ const std::string_view Plain = "payload that should not decrypt after tampering"sv;
+ IoBuffer InBuf = IoBufferBuilder::MakeCloneFromMemory(MakeMemoryView(Plain));
+
+ std::optional<std::string> WriteReason;
+ details::WriteEncryptedFile(Path, InBuf, Key, WriteReason);
+ REQUIRE_FALSE(WriteReason.has_value());
+
+ // Flip a single byte in the middle of the ciphertext region (past magic +
+ // nonce) and verify the GCM tag check rejects the tampered file. Copy
+ // the bytes out first and drop the FileContents / IoBuffer handles before
+ // rewriting so the underlying file isn't still open when we overwrite it.
+ std::vector<uint8_t> Mutated;
+ {
+ FileContents FC = ReadFile(Path);
+ IoBuffer Whole = FC.Flatten();
+ REQUIRE(Whole.GetSize() > 4 + AesGcm::NonceSize);
+ Mutated.assign(static_cast<const uint8_t*>(Whole.GetData()), static_cast<const uint8_t*>(Whole.GetData()) + Whole.GetSize());
+ }
+ Mutated[4 + AesGcm::NonceSize] ^= 0x40;
+ TemporaryFile::SafeWriteFile(Path, MakeMemoryView(Mutated));
+
+ std::optional<std::string> ReadReason;
+ bool WasLegacy = false;
+ IoBuffer Out = details::ReadEncryptedFile(Path, Key, UnusedIv, ReadReason, &WasLegacy);
+ CHECK(ReadReason.has_value());
+ CHECK(Out.GetSize() == 0);
+}
+
+TEST_SUITE_END();
+
+void
+authmgr_forcelink()
+{
+}
+
+#endif // ZEN_WITH_TESTS
+
} // namespace zen
diff --git a/src/zenhttp/auth/authservice.cpp b/src/zenhttp/auth/authservice.cpp
index f89ca91da..d8a1588c6 100644
--- a/src/zenhttp/auth/authservice.cpp
+++ b/src/zenhttp/auth/authservice.cpp
@@ -3,6 +3,7 @@
#include "zenhttp/auth/authservice.h"
#include <zencore/compactbinarybuilder.h>
+#include <zencore/logging.h>
#include <zencore/string.h>
#include <zenhttp/auth/authmgr.h>
@@ -21,6 +22,19 @@ HttpAuthService::HttpAuthService(AuthMgr& AuthMgr) : m_AuthMgr(AuthMgr)
[this](HttpRouterRequest& RouterRequest) {
HttpServerRequest& ServerRequest = RouterRequest.ServerRequest();
+ // The refresh-token endpoint is how the local Unreal Editor / zen CLI
+ // hands off a refresh token obtained via an interactive flow. There
+ // is no legitimate reason for it to be reachable from the network: a
+ // remote caller that can reach it can evict the cached token (DoS)
+ // or — with a valid same-provider token — substitute credentials used
+ // by GetOpenIdAccessToken. Require the origin to be the local
+ // machine; any configured password filter still runs in front.
+ if (!ServerRequest.IsLocalMachineRequest())
+ {
+ ZEN_WARN("rejecting non-local POST to /auth/oidc/refreshtoken");
+ return ServerRequest.WriteResponse(HttpResponseCode::Forbidden);
+ }
+
const HttpContentType ContentType = ServerRequest.RequestContentType();
if ((ContentType == HttpContentType::kUnknownContentType || ContentType == HttpContentType::kJSON) == false)
diff --git a/src/zenhttp/auth/oidc.cpp b/src/zenhttp/auth/oidc.cpp
index 23bbc17e8..63c8d9030 100644
--- a/src/zenhttp/auth/oidc.cpp
+++ b/src/zenhttp/auth/oidc.cpp
@@ -32,6 +32,68 @@ namespace details {
using namespace std::literals;
+// Return the "scheme://host[:port]" prefix of an absolute URL, or empty on
+// malformed input. RFC 8414 §3 requires every endpoint advertised by an
+// OP to live under the issuer's origin, so exact-prefix comparison on this
+// value is sufficient to pin an endpoint to the configured IdP.
+static std::string_view
+OriginOf(std::string_view Url)
+{
+ // scheme://...
+ auto SchemeEnd = Url.find("://");
+ if (SchemeEnd == std::string_view::npos)
+ {
+ return {};
+ }
+ // First path char (or query / fragment) after the authority.
+ const size_t AuthorityStart = SchemeEnd + 3;
+ size_t OriginEnd = Url.size();
+ for (size_t I = AuthorityStart; I < Url.size(); ++I)
+ {
+ const char C = Url[I];
+ if (C == '/' || C == '?' || C == '#')
+ {
+ OriginEnd = I;
+ break;
+ }
+ }
+ // Require at least one character in the authority.
+ if (OriginEnd == AuthorityStart)
+ {
+ return {};
+ }
+ return Url.substr(0, OriginEnd);
+}
+
+// True if the URL uses a scheme / host that we trust for discovery even
+// without HTTPS — narrowly, only loopback for dev / test setups.
+static bool
+IsLoopbackHttp(std::string_view Url)
+{
+ constexpr std::string_view Prefixes[] = {
+ "http://localhost"sv,
+ "http://127.0.0.1"sv,
+ "http://[::1]"sv,
+ };
+ for (std::string_view P : Prefixes)
+ {
+ if (Url.size() >= P.size() && Url.substr(0, P.size()) == P)
+ {
+ // Ensure the next char (if any) ends the authority cleanly.
+ if (Url.size() == P.size())
+ {
+ return true;
+ }
+ const char Next = Url[P.size()];
+ if (Next == ':' || Next == '/' || Next == '?' || Next == '#')
+ {
+ return true;
+ }
+ }
+ }
+ return false;
+}
+
static std::string
FormUrlEncode(std::string_view Input)
{
@@ -60,6 +122,19 @@ OidcClient::OidcClient(const OidcClient::Options& Options)
OidcClient::InitResult
OidcClient::Initialize()
{
+ // The OIDC discovery document determines where we send refresh tokens, so
+ // the transport to the discovery endpoint has to be trustworthy. Require
+ // HTTPS on the configured BaseUrl. Loopback is permitted over plain HTTP
+ // for developer setups that run a local IdP mock — no meaningful attack
+ // surface on the loopback interface.
+ if (m_BaseUrl.size() < 8 || m_BaseUrl.substr(0, 8) != "https://"sv)
+ {
+ if (!IsLoopbackHttp(m_BaseUrl))
+ {
+ return {.Reason = "BaseUrl must use https:// (or a http://localhost / 127.0.0.1 / [::1] loopback)"};
+ }
+ }
+
HttpClient Http{m_BaseUrl};
HttpClient::Response Response = Http.Get("/.well-known/openid-configuration"sv);
@@ -81,14 +156,86 @@ OidcClient::Initialize()
return {.Reason = std::move(JsonError)};
}
- m_Config = {.Issuer = Json["issuer"].string_value(),
+ // RFC 8414 §3: the discovery document's `issuer` value MUST identify the
+ // OP and MUST be the origin used to fetch the document. Without this
+ // check, an attacker who can intercept discovery (or a misconfigured
+ // intermediate) can swap the issuer identity without detection. Accept
+ // a trailing '/' divergence since OPs vary.
+ const std::string Issuer = Json["issuer"].string_value();
+ {
+ std::string_view ExpectedBase = m_BaseUrl;
+ while (!ExpectedBase.empty() && ExpectedBase.back() == '/')
+ {
+ ExpectedBase.remove_suffix(1);
+ }
+ std::string_view ActualIssuer = Issuer;
+ while (!ActualIssuer.empty() && ActualIssuer.back() == '/')
+ {
+ ActualIssuer.remove_suffix(1);
+ }
+ if (ActualIssuer.empty() || ActualIssuer != ExpectedBase)
+ {
+ return {.Reason = fmt::format("discovery issuer mismatch (expected '{}')", ExpectedBase)};
+ }
+ }
+
+ // Pin every endpoint we actually use to the same origin as BaseUrl. This
+ // is the last defense against a tampered discovery document redirecting
+ // token submissions to an attacker-controlled host. We check the
+ // endpoints we may call later; endpoints this client never dispatches to
+ // are left alone so a discovery document with unrelated auxiliary URLs
+ // isn't rejected for no reason.
+ const std::string_view BaseOrigin = OriginOf(m_BaseUrl);
+ if (BaseOrigin.empty())
+ {
+ return {.Reason = "BaseUrl is malformed"};
+ }
+
+ const std::string TokenEndpoint = Json["token_endpoint"].string_value();
+ const std::string UserInfoEndpoint = Json["userinfo_endpoint"].string_value();
+ const std::string JwksUri = Json["jwks_uri"].string_value();
+
+ auto CheckOrigin = [&](std::string_view Name, std::string_view Url) -> std::optional<std::string> {
+ if (Url.empty())
+ {
+ return std::nullopt;
+ }
+ const std::string_view Origin = OriginOf(Url);
+ if (Origin != BaseOrigin)
+ {
+ return fmt::format("discovery endpoint '{}' is off-origin (expected origin '{}')", Name, BaseOrigin);
+ }
+ return std::nullopt;
+ };
+
+ if (auto Err = CheckOrigin("token_endpoint"sv, TokenEndpoint); Err.has_value())
+ {
+ return {.Reason = std::move(*Err)};
+ }
+ if (auto Err = CheckOrigin("userinfo_endpoint"sv, UserInfoEndpoint); Err.has_value())
+ {
+ return {.Reason = std::move(*Err)};
+ }
+ if (auto Err = CheckOrigin("jwks_uri"sv, JwksUri); Err.has_value())
+ {
+ return {.Reason = std::move(*Err)};
+ }
+
+ // token_endpoint is required for the refresh flow we implement; fail early
+ // rather than at RefreshToken time if the OP omitted it.
+ if (TokenEndpoint.empty())
+ {
+ return {.Reason = "discovery document is missing token_endpoint"};
+ }
+
+ m_Config = {.Issuer = Issuer,
.AuthorizationEndpoint = Json["authorization_endpoint"].string_value(),
- .TokenEndpoint = Json["token_endpoint"].string_value(),
- .UserInfoEndpoint = Json["userinfo_endpoint"].string_value(),
+ .TokenEndpoint = TokenEndpoint,
+ .UserInfoEndpoint = UserInfoEndpoint,
.RegistrationEndpoint = Json["registration_endpoint"].string_value(),
.EndSessionEndpoint = Json["end_session_endpoint"].string_value(),
.DeviceAuthorizationEndpoint = Json["device_authorization_endpoint"].string_value(),
- .JwksUri = Json["jwks_uri"].string_value(),
+ .JwksUri = JwksUri,
.SupportedResponseTypes = details::ToStringArray(Json["response_types_supported"]),
.SupportedResponseModes = details::ToStringArray(Json["response_modes_supported"]),
.SupportedGrantTypes = details::ToStringArray(Json["grant_types_supported"]),
@@ -118,7 +265,12 @@ OidcClient::RefreshToken(std::string_view RefreshToken)
if (Response.StatusCode != HttpResponseCode::OK)
{
- return {.Reason = fmt::format("{} ({})", ToString(Response.StatusCode), Response.AsText())};
+ // Do NOT include Response.AsText() in the reason string. Some IdPs
+ // echo the submitted refresh_token (or a prefix of it) in their error
+ // body — plumbing that into the Reason string causes AuthMgrImpl's
+ // ZEN_WARN in the refresh paths to write the token into the log.
+ // Only the status code is safe to surface up to the log sites.
+ return {.Reason = fmt::format("{} (provider returned {} bytes)", ToString(Response.StatusCode), Response.AsText().size())};
}
std::string JsonError;
@@ -129,10 +281,14 @@ OidcClient::RefreshToken(std::string_view RefreshToken)
return {.Reason = std::move(JsonError)};
}
+ // Note: id_token is intentionally not parsed. It is a JWT whose contents
+ // are meaningful only after signature / issuer / audience / expiry
+ // verification against the provider's JWKS, and nothing downstream
+ // currently consumes it. Leaving it unparsed avoids planting an
+ // unauthenticated identity claim in the OpenIdToken cache.
return {.TokenType = Json["token_type"].string_value(),
.AccessToken = Json["access_token"].string_value(),
.RefreshToken = Json["refresh_token"].string_value(),
- .IdentityToken = Json["id_token"].string_value(),
.Scope = Json["scope"].string_value(),
.ExpiresInSeconds = static_cast<int64_t>(Json["expires_in"].int_value()),
.Ok = true};
diff --git a/src/zenhttp/clients/httpwsclient.cpp b/src/zenhttp/clients/httpwsclient.cpp
index 4337fcb79..842bf9d49 100644
--- a/src/zenhttp/clients/httpwsclient.cpp
+++ b/src/zenhttp/clients/httpwsclient.cpp
@@ -23,6 +23,8 @@ ZEN_THIRD_PARTY_INCLUDES_END
namespace zen {
+using namespace std::literals;
+
//////////////////////////////////////////////////////////////////////////
struct HttpWsClient::Impl
@@ -271,6 +273,43 @@ struct HttpWsClient::Impl
});
}
+ // Trim ASCII LWS (space / tab) from both ends of a header value, along with
+ // a trailing CR if the caller didn't strip it.
+ static std::string_view TrimHeaderValue(std::string_view V)
+ {
+ while (!V.empty() && (V.front() == ' ' || V.front() == '\t'))
+ {
+ V.remove_prefix(1);
+ }
+ while (!V.empty() && (V.back() == ' ' || V.back() == '\t' || V.back() == '\r'))
+ {
+ V.remove_suffix(1);
+ }
+ return V;
+ }
+
+ // Return true if a comma-separated header value contains the given token,
+ // case-insensitively. Used for Connection header parsing where the value
+ // may legitimately be "Upgrade, keep-alive" etc.
+ static bool HeaderContainsToken(std::string_view Value, std::string_view Token)
+ {
+ while (!Value.empty())
+ {
+ auto CommaPos = Value.find(',');
+ std::string_view Part = TrimHeaderValue(Value.substr(0, CommaPos));
+ if (Part.size() == Token.size() && StrCaseCompare(Part, Token) == 0)
+ {
+ return true;
+ }
+ if (CommaPos == std::string_view::npos)
+ {
+ break;
+ }
+ Value.remove_prefix(CommaPos + 1);
+ }
+ return false;
+ }
+
void DoReadHandshakeResponse()
{
WithSocket([this](auto& Socket) {
@@ -284,30 +323,105 @@ struct HttpWsClient::Impl
return;
}
- // Parse the response
const auto& Data = m_ReadBuffer.data();
std::string Response(asio::buffers_begin(Data), asio::buffers_end(Data));
// Consume the headers from the read buffer (any extra data stays for frame parsing)
auto HeaderEnd = Response.find("\r\n\r\n");
- if (HeaderEnd != std::string::npos)
+ if (HeaderEnd == std::string::npos)
{
- m_ReadBuffer.consume(HeaderEnd + 4);
+ ZEN_LOG_DEBUG(m_Log, "WebSocket handshake: incomplete headers");
+ m_Handler.OnWsClose(1006, "handshake incomplete");
+ return;
}
+ m_ReadBuffer.consume(HeaderEnd + 4);
+
+ // Parse the status line. Substring matching on "101" anywhere
+ // in the response is unsafe — a server returning
+ // "HTTP/1.1 404 Not Found\r\nX-Retry-After: 101\r\n" would have
+ // satisfied it. We require the first line to start with
+ // "HTTP/1.x 101" followed by end-of-line or space.
+ //
+ // ResponseView spans up through the first "\r\n" of the
+ // terminating "\r\n\r\n" so that every header line — including
+ // the last one — is terminated by "\r\n" in the view.
+ std::string_view ResponseView(Response.data(), HeaderEnd + 2);
+ auto StatusLineEnd = ResponseView.find("\r\n");
+ if (StatusLineEnd == std::string_view::npos)
+ {
+ m_Handler.OnWsClose(1006, "handshake malformed");
+ return;
+ }
+ std::string_view StatusLine = ResponseView.substr(0, StatusLineEnd);
- // Validate 101 response
- if (Response.find("101") == std::string::npos)
+ // Expect: "HTTP/1.x 101" (12 chars min), with 'x' being '0' or '1'.
+ bool StatusOk = StatusLine.size() >= 12 && StatusLine.substr(0, 7) == "HTTP/1." &&
+ (StatusLine[7] == '0' || StatusLine[7] == '1') && StatusLine[8] == ' ' &&
+ StatusLine.substr(9, 3) == "101" && (StatusLine.size() == 12 || StatusLine[12] == ' ');
+ if (!StatusOk)
{
- ZEN_LOG_DEBUG(m_Log, "WebSocket handshake rejected (no 101): {}", Response.substr(0, 80));
+ ZEN_LOG_DEBUG(m_Log, "WebSocket handshake rejected; status line: {}", StatusLine.substr(0, 80));
m_Handler.OnWsClose(1006, "handshake rejected");
return;
}
- // Validate Sec-WebSocket-Accept
+ // Parse headers and extract the three fields RFC 6455 §4.1
+ // requires a client to validate: Upgrade, Connection, and
+ // Sec-WebSocket-Accept. Case-insensitive on header names and
+ // on the Upgrade / Connection token values; exact-match on the
+ // Sec-WebSocket-Accept base64 value.
+ bool UpgradeOk = false;
+ bool ConnectionOk = false;
+ std::string_view AcceptValue;
+
+ std::string_view HeaderBlock = ResponseView.substr(StatusLineEnd + 2);
+ while (!HeaderBlock.empty())
+ {
+ auto NextLineEnd = HeaderBlock.find("\r\n");
+ if (NextLineEnd == std::string_view::npos)
+ {
+ break;
+ }
+ std::string_view Line = HeaderBlock.substr(0, NextLineEnd);
+ HeaderBlock = HeaderBlock.substr(NextLineEnd + 2);
+ if (Line.empty())
+ {
+ break;
+ }
+
+ auto ColonPos = Line.find(':');
+ if (ColonPos == std::string_view::npos)
+ {
+ continue;
+ }
+ std::string_view Name = Line.substr(0, ColonPos);
+ std::string_view Value = TrimHeaderValue(Line.substr(ColonPos + 1));
+
+ if (Name.size() == 7 && StrCaseCompare(Name, "Upgrade"sv) == 0)
+ {
+ UpgradeOk = (Value.size() == 9 && StrCaseCompare(Value, "websocket"sv) == 0);
+ }
+ else if (Name.size() == 10 && StrCaseCompare(Name, "Connection"sv) == 0)
+ {
+ ConnectionOk = HeaderContainsToken(Value, "upgrade"sv);
+ }
+ else if (Name.size() == 20 && StrCaseCompare(Name, "Sec-WebSocket-Accept"sv) == 0)
+ {
+ AcceptValue = Value;
+ }
+ }
+
+ if (!UpgradeOk || !ConnectionOk)
+ {
+ ZEN_LOG_DEBUG(m_Log, "WebSocket handshake missing required Upgrade/Connection headers");
+ m_Handler.OnWsClose(1006, "handshake missing headers");
+ return;
+ }
+
std::string ExpectedAccept = WsFrameCodec::ComputeAcceptKey(m_WebSocketKey);
- if (Response.find(ExpectedAccept) == std::string::npos)
+ if (AcceptValue.size() != ExpectedAccept.size() || AcceptValue != ExpectedAccept)
{
- ZEN_LOG_DEBUG(m_Log, "WebSocket handshake: invalid Sec-WebSocket-Accept");
+ ZEN_LOG_DEBUG(m_Log, "WebSocket handshake: invalid or missing Sec-WebSocket-Accept");
m_Handler.OnWsClose(1006, "invalid accept key");
return;
}
diff --git a/src/zenhttp/include/zenhttp/auth/authmgr.h b/src/zenhttp/include/zenhttp/auth/authmgr.h
index 054588ab9..b61b221f6 100644
--- a/src/zenhttp/include/zenhttp/auth/authmgr.h
+++ b/src/zenhttp/include/zenhttp/auth/authmgr.h
@@ -17,7 +17,13 @@ struct AuthConfig
std::filesystem::path RootDirectory;
std::chrono::seconds UpdateInterval{30};
AesKey256Bit EncryptionKey;
- AesIV128Bit EncryptionIV;
+
+ // LEGACY: consulted only when reading a pre-AES-GCM authstate file written
+ // before the format migration. New writes use AES-GCM with a fresh random
+ // nonce per write and do not consult this field. Kept so existing
+ // deployments roll forward transparently; remove once the next format
+ // version bump lands.
+ AesIV128Bit EncryptionIV;
};
class AuthMgr
diff --git a/src/zenhttp/include/zenhttp/auth/oidc.h b/src/zenhttp/include/zenhttp/auth/oidc.h
index 6f9c3198e..1008367df 100644
--- a/src/zenhttp/include/zenhttp/auth/oidc.h
+++ b/src/zenhttp/include/zenhttp/auth/oidc.h
@@ -39,7 +39,6 @@ public:
std::string TokenType;
std::string AccessToken;
std::string RefreshToken;
- std::string IdentityToken;
std::string Scope;
std::string Reason;
int64_t ExpiresInSeconds{};
diff --git a/src/zenhttp/include/zenhttp/websocket.h b/src/zenhttp/include/zenhttp/websocket.h
index 2d25515d3..22fb419c1 100644
--- a/src/zenhttp/include/zenhttp/websocket.h
+++ b/src/zenhttp/include/zenhttp/websocket.h
@@ -13,11 +13,12 @@ namespace zen {
enum class WebSocketOpcode : uint8_t
{
- kText = 0x1,
- kBinary = 0x2,
- kClose = 0x8,
- kPing = 0x9,
- kPong = 0xA
+ kContinuation = 0x0, // RFC 6455 §5.4 - only seen inside the codec; handlers always receive the coalesced opcode
+ kText = 0x1,
+ kBinary = 0x2,
+ kClose = 0x8,
+ kPing = 0x9,
+ kPong = 0xA
};
struct WebSocketMessage
diff --git a/src/zenhttp/security/passwordsecurity.cpp b/src/zenhttp/security/passwordsecurity.cpp
index 0e3a743c3..70315fa03 100644
--- a/src/zenhttp/security/passwordsecurity.cpp
+++ b/src/zenhttp/security/passwordsecurity.cpp
@@ -67,7 +67,10 @@ PasswordSecurity::IsAllowed(std::string_view InPassword, std::string_view BaseUr
{
return true;
}
- if (Password() == InPassword)
+ // Constant-time compare: a plain == short-circuits on the first byte mismatch
+ // (and on length mismatch before any bytes are inspected), exposing the
+ // configured password to a remote timing oracle across the network.
+ if (ConstantTimeEquals(Password(), InPassword))
{
return true;
}
@@ -148,6 +151,19 @@ TEST_CASE("passwordsecurity.allowsomelocaluris")
CHECK(AllLocal.IsAllowed("123456"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false));
}
+TEST_CASE("passwordsecurity.constanttimeequals")
+{
+ // Documents the contract of the helper used to compare passwords. We
+ // cannot observe timing from here, but we can at least verify equality
+ // and inequality of the three interesting shapes: same length equal,
+ // same length unequal, and length mismatch.
+ CHECK(ConstantTimeEquals("abcdef"sv, "abcdef"sv));
+ CHECK_FALSE(ConstantTimeEquals("abcdef"sv, "abcdeg"sv));
+ CHECK_FALSE(ConstantTimeEquals("abcdef"sv, "abcdefg"sv));
+ CHECK_FALSE(ConstantTimeEquals(""sv, "x"sv));
+ CHECK(ConstantTimeEquals(""sv, ""sv));
+}
+
TEST_CASE("passwordsecurity.conflictingunprotecteduris")
{
try
diff --git a/src/zenhttp/security/passwordsecurityfilter.cpp b/src/zenhttp/security/passwordsecurityfilter.cpp
index 87d8cc275..0428b5437 100644
--- a/src/zenhttp/security/passwordsecurityfilter.cpp
+++ b/src/zenhttp/security/passwordsecurityfilter.cpp
@@ -34,13 +34,41 @@ IHttpRequestFilter::Result
PasswordHttpFilter::FilterRequest(HttpServerRequest& Request)
{
std::string_view Password;
- std::string_view AuthorizationHeader = Request.GetAuthorizationHeader();
- size_t AuthorizationHeaderLength = AuthorizationHeader.length();
- if (AuthorizationHeaderLength > m_AuthenticationTypeString.length())
+ std::string_view AuthorizationHeader = Request.GetAuthorizationHeader();
+
+ // Look for the configured scheme prefix (e.g. "Basic ") case-insensitively.
+ // Only extract a candidate credential when the scheme prefix matches and the
+ // remainder passes two defense-in-depth checks:
+ // * trim any trailing LWS (SP / HTAB) from the credential, so a
+ // well-behaved but "Authorization: Basic <b64> " style header still
+ // works instead of producing a silent 403;
+ // * reject the header outright if the credential contains any control
+ // byte (< 0x20) or DEL (0x7F) — valid Base64 never does, and passing
+ // such bytes downstream is a log-injection / smuggling vector.
+ if (AuthorizationHeader.length() > m_AuthenticationTypeString.length() &&
+ StrCaseCompare(AuthorizationHeader.data(), m_AuthenticationTypeString.c_str(), m_AuthenticationTypeString.length()) == 0)
{
- if (StrCaseCompare(AuthorizationHeader.data(), m_AuthenticationTypeString.c_str(), m_AuthenticationTypeString.length()) == 0)
+ std::string_view Candidate = AuthorizationHeader.substr(m_AuthenticationTypeString.length());
+
+ while (!Candidate.empty() && (Candidate.back() == ' ' || Candidate.back() == '\t'))
+ {
+ Candidate.remove_suffix(1);
+ }
+
+ bool HasControlByte = false;
+ for (char C : Candidate)
+ {
+ const uint8_t B = static_cast<uint8_t>(C);
+ if (B < 0x20 || B == 0x7F)
+ {
+ HasControlByte = true;
+ break;
+ }
+ }
+
+ if (!HasControlByte)
{
- Password = AuthorizationHeader.substr(m_AuthenticationTypeString.length());
+ Password = Candidate;
}
}
diff --git a/src/zenhttp/servers/wsasio.cpp b/src/zenhttp/servers/wsasio.cpp
index 078c21ea1..d223a50c0 100644
--- a/src/zenhttp/servers/wsasio.cpp
+++ b/src/zenhttp/servers/wsasio.cpp
@@ -107,11 +107,17 @@ WsAsioConnectionT<SocketType>::ProcessReceivedData()
const auto* Data = static_cast<const uint8_t*>(InputBuffer.data());
const auto Size = InputBuffer.size();
- WsFrameParseResult Frame = WsFrameCodec::TryParseFrame(Data, Size);
- if (!Frame.IsValid)
+ WsFrameParseResult Frame = WsFrameCodec::TryParseFrame(Data, Size, /*RequireMask=*/true);
+ if (Frame.Status == WsFrameParseStatus::kNeedMoreData)
{
break; // not enough data yet
}
+ if (Frame.Status == WsFrameParseStatus::kProtocolError)
+ {
+ ZEN_LOG_DEBUG(WsLog(), "WebSocket protocol violation; closing with 1002");
+ DoClose(1002, "protocol error");
+ return;
+ }
m_ReadBuffer.consume(Frame.BytesConsumed);
@@ -120,15 +126,70 @@ WsAsioConnectionT<SocketType>::ProcessReceivedData()
m_HttpServer->OnWebSocketFrameReceived(Frame.BytesConsumed);
}
+ // Cap for a coalesced multi-frame message. Matches the single-frame cap
+ // in the codec; any legitimate zenhttp use case fits in either case.
+ static constexpr size_t kMaxAssembledSize = 4 * 1024 * 1024;
+
switch (Frame.Opcode)
{
case WebSocketOpcode::kText:
case WebSocketOpcode::kBinary:
{
- WebSocketMessage Msg;
- Msg.Opcode = Frame.Opcode;
- Msg.Payload = IoBuffer(IoBuffer::Clone, Frame.Payload.data(), Frame.Payload.size());
- m_Handler.OnWebSocketMessage(*this, Msg);
+ // Per RFC 6455 §5.4, a new data frame is a protocol violation when a
+ // fragmented message is already in progress (non-control frames must
+ // not be interleaved).
+ if (m_FragmentInProgress)
+ {
+ ZEN_LOG_DEBUG(WsLog(), "New data frame arrived mid-fragment; closing 1002");
+ DoClose(1002, "interleaved data frame");
+ return;
+ }
+
+ if (Frame.Fin)
+ {
+ WebSocketMessage Msg;
+ Msg.Opcode = Frame.Opcode;
+ Msg.Payload = IoBuffer(IoBuffer::Clone, Frame.Payload.data(), Frame.Payload.size());
+ m_Handler.OnWebSocketMessage(*this, Msg);
+ }
+ else
+ {
+ if (Frame.Payload.size() > kMaxAssembledSize)
+ {
+ DoClose(1009, "message too big");
+ return;
+ }
+ m_FragmentOpcode = Frame.Opcode;
+ m_FragmentBuffer = std::move(Frame.Payload);
+ m_FragmentInProgress = true;
+ }
+ break;
+ }
+
+ case WebSocketOpcode::kContinuation:
+ {
+ if (!m_FragmentInProgress)
+ {
+ ZEN_LOG_DEBUG(WsLog(), "Continuation frame with no message in progress; closing 1002");
+ DoClose(1002, "unexpected continuation");
+ return;
+ }
+ if (m_FragmentBuffer.size() + Frame.Payload.size() > kMaxAssembledSize)
+ {
+ DoClose(1009, "message too big");
+ return;
+ }
+ m_FragmentBuffer.insert(m_FragmentBuffer.end(), Frame.Payload.begin(), Frame.Payload.end());
+ if (Frame.Fin)
+ {
+ WebSocketMessage Msg;
+ Msg.Opcode = m_FragmentOpcode;
+ Msg.Payload = IoBuffer(IoBuffer::Clone, m_FragmentBuffer.data(), m_FragmentBuffer.size());
+ m_FragmentBuffer.clear();
+ m_FragmentBuffer.shrink_to_fit();
+ m_FragmentInProgress = false;
+ m_Handler.OnWebSocketMessage(*this, Msg);
+ }
break;
}
diff --git a/src/zenhttp/servers/wsasio.h b/src/zenhttp/servers/wsasio.h
index 64602ee46..5db38db92 100644
--- a/src/zenhttp/servers/wsasio.h
+++ b/src/zenhttp/servers/wsasio.h
@@ -77,6 +77,12 @@ private:
std::deque<std::vector<uint8_t>> m_WriteQueue;
bool m_IsWriting = false;
+ // Fragmented-message reassembly (RFC 6455 §5.4). Only the reader touches
+ // these; no synchronization required.
+ std::vector<uint8_t> m_FragmentBuffer;
+ WebSocketOpcode m_FragmentOpcode{WebSocketOpcode::kContinuation};
+ bool m_FragmentInProgress = false;
+
std::atomic<bool> m_IsOpen{true};
std::atomic<bool> m_CloseSent{false};
};
diff --git a/src/zenhttp/servers/wsframecodec.cpp b/src/zenhttp/servers/wsframecodec.cpp
index e452141fe..781f04c5e 100644
--- a/src/zenhttp/servers/wsframecodec.cpp
+++ b/src/zenhttp/servers/wsframecodec.cpp
@@ -16,7 +16,7 @@ namespace zen {
//
WsFrameParseResult
-WsFrameCodec::TryParseFrame(const uint8_t* Data, size_t Size)
+WsFrameCodec::TryParseFrame(const uint8_t* Data, size_t Size, bool RequireMask)
{
// Minimum frame: 2 bytes header (unmasked server frames) or 6 bytes (masked client frames)
if (Size < 2)
@@ -24,10 +24,48 @@ WsFrameCodec::TryParseFrame(const uint8_t* Data, size_t Size)
return {};
}
- const bool Fin = (Data[0] & 0x80) != 0;
- const uint8_t OpcodeRaw = Data[0] & 0x0F;
- const bool Masked = (Data[1] & 0x80) != 0;
- uint64_t PayloadLen = Data[1] & 0x7F;
+ const bool Fin = (Data[0] & 0x80) != 0;
+ const uint8_t RsvBits = Data[0] & 0x70;
+ const uint8_t OpcodeRaw = Data[0] & 0x0F;
+ const bool Masked = (Data[1] & 0x80) != 0;
+ const uint8_t ShortLength = Data[1] & 0x7F;
+ uint64_t PayloadLen = ShortLength;
+
+ const bool IsControlFrame = (OpcodeRaw & 0x08) != 0;
+
+ // RFC 6455 section 5.2: RSV1/2/3 must be zero unless a negotiated extension
+ // defines them. We do not negotiate any extensions, so any non-zero RSV bit
+ // is a protocol violation.
+ if (RsvBits != 0)
+ {
+ WsFrameParseResult Error;
+ Error.Status = WsFrameParseStatus::kProtocolError;
+ return Error;
+ }
+
+ // RFC 6455 section 5.5: control frames (Close / Ping / Pong and any opcode
+ // in 0x8..0xF) MUST NOT be fragmented and MUST have a payload of 125 bytes
+ // or less. Rejecting fragmented or oversized control frames prevents a
+ // peer from tying up unbounded memory inside an auto-pong, and closes off
+ // a class of smuggling tricks where handlers might observe partial control
+ // payloads.
+ if (IsControlFrame && (!Fin || ShortLength > 125))
+ {
+ WsFrameParseResult Error;
+ Error.Status = WsFrameParseStatus::kProtocolError;
+ return Error;
+ }
+
+ // RFC 6455 section 5.1: a server MUST close the connection upon receiving an
+ // unmasked client frame. Signal this distinctly from "need more data" so the
+ // server close path can trigger a 1002 close rather than stalling for bytes
+ // that will never satisfy the parse.
+ if (RequireMask && !Masked)
+ {
+ WsFrameParseResult Error;
+ Error.Status = WsFrameParseStatus::kProtocolError;
+ return Error;
+ }
size_t HeaderSize = 2;
@@ -51,11 +89,19 @@ WsFrameCodec::TryParseFrame(const uint8_t* Data, size_t Size)
HeaderSize = 10;
}
- // Reject frames with unreasonable payload sizes to prevent OOM
- static constexpr uint64_t kMaxPayloadSize = 256 * 1024 * 1024; // 256 MB
+ // Reject frames with unreasonable payload sizes to bound per-connection
+ // memory. Parsers accumulate the whole frame before dispatch (see the
+ // read loops in wsasio.cpp / wshttpsys.cpp), so this cap also bounds the
+ // accumulator: a peer that advertises a large frame and streams bytes
+ // slowly cannot grow buffers past this limit. 4 MB is well above anything
+ // the monitoring / stats endpoints produce; raise it if a legitimate use
+ // case emerges.
+ static constexpr uint64_t kMaxPayloadSize = 4 * 1024 * 1024; // 4 MB
if (PayloadLen > kMaxPayloadSize)
{
- return {};
+ WsFrameParseResult Error;
+ Error.Status = WsFrameParseStatus::kProtocolError;
+ return Error;
}
const size_t MaskSize = Masked ? 4 : 0;
@@ -70,6 +116,7 @@ WsFrameCodec::TryParseFrame(const uint8_t* Data, size_t Size)
const uint8_t* PayloadData = Data + HeaderSize + MaskSize;
WsFrameParseResult Result;
+ Result.Status = WsFrameParseStatus::kValid;
Result.IsValid = true;
Result.BytesConsumed = TotalFrame;
Result.Opcode = static_cast<WebSocketOpcode>(OpcodeRaw);
diff --git a/src/zenhttp/servers/wsframecodec.h b/src/zenhttp/servers/wsframecodec.h
index 2d90b6fa1..0a62e6c80 100644
--- a/src/zenhttp/servers/wsframecodec.h
+++ b/src/zenhttp/servers/wsframecodec.h
@@ -15,11 +15,22 @@
namespace zen {
/**
+ * Outcome of a frame parse attempt
+ */
+enum class WsFrameParseStatus : uint8_t
+{
+ kNeedMoreData, // the buffer does not yet contain a full frame
+ kValid, // a complete, well-formed frame was parsed
+ kProtocolError // the frame violates RFC 6455; caller must close the connection (typically with 1002)
+};
+
+/**
* Result of attempting to parse a single WebSocket frame from a byte buffer
*/
struct WsFrameParseResult
{
- bool IsValid = false; // true if a complete frame was successfully parsed
+ WsFrameParseStatus Status = WsFrameParseStatus::kNeedMoreData;
+ bool IsValid = false; // true iff Status == kValid (convenience mirror)
size_t BytesConsumed = 0; // number of bytes consumed from the input buffer
WebSocketOpcode Opcode = WebSocketOpcode::kText;
bool Fin = false;
@@ -37,11 +48,17 @@ struct WsFrameCodec
/**
* Try to parse one complete frame from the front of the buffer.
*
- * Returns a result with IsValid == false and BytesConsumed == 0 when
+ * Returns a result with Status == kNeedMoreData and BytesConsumed == 0 when
* there is not enough data yet. The caller should accumulate more data
* and retry.
+ *
+ * When RequireMask is true (server parsing client-to-server frames), a frame
+ * that is not masked is rejected with Status == kProtocolError per RFC 6455
+ * section 5.1 — the caller must close the connection with status 1002.
+ * When RequireMask is false (client parsing server-to-client frames), the
+ * mask bit is not validated.
*/
- static WsFrameParseResult TryParseFrame(const uint8_t* Data, size_t Size);
+ static WsFrameParseResult TryParseFrame(const uint8_t* Data, size_t Size, bool RequireMask = false);
/**
* Build a server-to-client frame (no masking)
diff --git a/src/zenhttp/servers/wshttpsys.cpp b/src/zenhttp/servers/wshttpsys.cpp
index 8520e9f60..1f2b11bf1 100644
--- a/src/zenhttp/servers/wshttpsys.cpp
+++ b/src/zenhttp/servers/wshttpsys.cpp
@@ -176,11 +176,17 @@ WsHttpSysConnection::ProcessReceivedData()
{
while (!m_Accumulated.empty())
{
- WsFrameParseResult Frame = WsFrameCodec::TryParseFrame(m_Accumulated.data(), m_Accumulated.size());
- if (!Frame.IsValid)
+ WsFrameParseResult Frame = WsFrameCodec::TryParseFrame(m_Accumulated.data(), m_Accumulated.size(), /*RequireMask=*/true);
+ if (Frame.Status == WsFrameParseStatus::kNeedMoreData)
{
break; // not enough data yet
}
+ if (Frame.Status == WsFrameParseStatus::kProtocolError)
+ {
+ ZEN_LOG_DEBUG(WsHttpSysLog(), "WebSocket protocol violation; closing with 1002");
+ DoClose(1002, "protocol error");
+ return;
+ }
// Remove consumed bytes
m_Accumulated.erase(m_Accumulated.begin(), m_Accumulated.begin() + Frame.BytesConsumed);
@@ -190,15 +196,67 @@ WsHttpSysConnection::ProcessReceivedData()
m_HttpServer->OnWebSocketFrameReceived(Frame.BytesConsumed);
}
+ // Cap for a coalesced multi-frame message. Matches the single-frame cap
+ // in the codec; any legitimate zenhttp use case fits in either case.
+ static constexpr size_t kMaxAssembledSize = 4 * 1024 * 1024;
+
switch (Frame.Opcode)
{
case WebSocketOpcode::kText:
case WebSocketOpcode::kBinary:
{
- WebSocketMessage Msg;
- Msg.Opcode = Frame.Opcode;
- Msg.Payload = IoBuffer(IoBuffer::Clone, Frame.Payload.data(), Frame.Payload.size());
- m_Handler.OnWebSocketMessage(*this, Msg);
+ if (m_FragmentInProgress)
+ {
+ ZEN_LOG_DEBUG(WsHttpSysLog(), "New data frame arrived mid-fragment; closing 1002");
+ DoClose(1002, "interleaved data frame");
+ return;
+ }
+
+ if (Frame.Fin)
+ {
+ WebSocketMessage Msg;
+ Msg.Opcode = Frame.Opcode;
+ Msg.Payload = IoBuffer(IoBuffer::Clone, Frame.Payload.data(), Frame.Payload.size());
+ m_Handler.OnWebSocketMessage(*this, Msg);
+ }
+ else
+ {
+ if (Frame.Payload.size() > kMaxAssembledSize)
+ {
+ DoClose(1009, "message too big");
+ return;
+ }
+ m_FragmentOpcode = Frame.Opcode;
+ m_FragmentBuffer = std::move(Frame.Payload);
+ m_FragmentInProgress = true;
+ }
+ break;
+ }
+
+ case WebSocketOpcode::kContinuation:
+ {
+ if (!m_FragmentInProgress)
+ {
+ ZEN_LOG_DEBUG(WsHttpSysLog(), "Continuation frame with no message in progress; closing 1002");
+ DoClose(1002, "unexpected continuation");
+ return;
+ }
+ if (m_FragmentBuffer.size() + Frame.Payload.size() > kMaxAssembledSize)
+ {
+ DoClose(1009, "message too big");
+ return;
+ }
+ m_FragmentBuffer.insert(m_FragmentBuffer.end(), Frame.Payload.begin(), Frame.Payload.end());
+ if (Frame.Fin)
+ {
+ WebSocketMessage Msg;
+ Msg.Opcode = m_FragmentOpcode;
+ Msg.Payload = IoBuffer(IoBuffer::Clone, m_FragmentBuffer.data(), m_FragmentBuffer.size());
+ m_FragmentBuffer.clear();
+ m_FragmentBuffer.shrink_to_fit();
+ m_FragmentInProgress = false;
+ m_Handler.OnWebSocketMessage(*this, Msg);
+ }
break;
}
diff --git a/src/zenhttp/servers/wshttpsys.h b/src/zenhttp/servers/wshttpsys.h
index 6015e3873..49b0cdf67 100644
--- a/src/zenhttp/servers/wshttpsys.h
+++ b/src/zenhttp/servers/wshttpsys.h
@@ -87,6 +87,12 @@ private:
std::vector<uint8_t> m_ReadBuffer;
std::vector<uint8_t> m_Accumulated;
+ // Fragmented-message reassembly (RFC 6455 §5.4). Only the read-completion
+ // callback touches these; no synchronization required.
+ std::vector<uint8_t> m_FragmentBuffer;
+ WebSocketOpcode m_FragmentOpcode{WebSocketOpcode::kContinuation};
+ bool m_FragmentInProgress = false;
+
// Write state
RwLock m_WriteLock;
std::deque<std::vector<uint8_t>> m_WriteQueue;
diff --git a/src/zenhttp/servers/wstest.cpp b/src/zenhttp/servers/wstest.cpp
index a58037fec..d375aeb12 100644
--- a/src/zenhttp/servers/wstest.cpp
+++ b/src/zenhttp/servers/wstest.cpp
@@ -239,6 +239,97 @@ TEST_CASE("websocket.framecodec")
CHECK_EQ(Result.Payload.size(), 70000u);
}
+ SUBCASE("TryParseFrame - declared payload exceeding codec cap is rejected")
+ {
+ // Hand-build a masked Binary frame header that declares a 128 MB payload
+ // via the 64-bit extended length. The codec should reject this as a
+ // protocol error without trying to allocate or wait for the bytes.
+ std::vector<uint8_t> Frame;
+ Frame.push_back(0x82); // FIN + binary
+ Frame.push_back(0x80 | 127); // MASK + 64-bit extended length
+ uint64_t DeclaredLen = 128ull * 1024 * 1024; // 128 MB, well above the 4 MB cap
+ for (int i = 7; i >= 0; --i)
+ {
+ Frame.push_back(static_cast<uint8_t>((DeclaredLen >> (i * 8)) & 0xFF));
+ }
+ uint8_t MaskKey[4] = {0x11, 0x22, 0x33, 0x44};
+ Frame.insert(Frame.end(), MaskKey, MaskKey + 4);
+ // Do not include any payload bytes — the cap should trip before any are needed
+
+ WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size(), /*RequireMask=*/true);
+ CHECK_EQ(Result.Status, WsFrameParseStatus::kProtocolError);
+ }
+
+ SUBCASE("TryParseFrame - fragmented control frame rejected")
+ {
+ // Build a Close frame by hand with FIN=0 (fragmented — illegal for control frames)
+ std::vector<uint8_t> Frame;
+ Frame.push_back(static_cast<uint8_t>(WebSocketOpcode::kClose)); // no FIN bit
+ Frame.push_back(0x80 | 2); // MASK + 2-byte payload
+ uint8_t MaskKey[4] = {0x01, 0x02, 0x03, 0x04};
+ Frame.insert(Frame.end(), MaskKey, MaskKey + 4);
+ Frame.push_back(0x03 ^ MaskKey[0]); // close code high byte (1000 = 0x03E8)
+ Frame.push_back(0xE8 ^ MaskKey[1]); // close code low byte
+
+ WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size(), /*RequireMask=*/true);
+ CHECK_EQ(Result.Status, WsFrameParseStatus::kProtocolError);
+ }
+
+ SUBCASE("TryParseFrame - oversized control payload rejected")
+ {
+ // Build a Ping with 126-byte payload (exceeds the 125-byte control-frame limit).
+ // Payload length = 126 forces the extended length encoding, which is itself
+ // illegal for a control frame — both conditions fail the spec.
+ std::vector<uint8_t> Payload(126, 0xAA);
+ std::vector<uint8_t> Frame = WsFrameCodec::BuildMaskedFrame(WebSocketOpcode::kPing, Payload);
+
+ WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size(), /*RequireMask=*/true);
+ CHECK_EQ(Result.Status, WsFrameParseStatus::kProtocolError);
+ }
+
+ SUBCASE("TryParseFrame - non-zero RSV bits rejected")
+ {
+ // Start from a valid masked text frame and flip RSV1 on byte 0
+ std::string_view Text = "rsv";
+ std::span<const uint8_t> Payload(reinterpret_cast<const uint8_t*>(Text.data()), Text.size());
+ std::vector<uint8_t> Frame = WsFrameCodec::BuildMaskedFrame(WebSocketOpcode::kText, Payload);
+ Frame[0] |= 0x40; // set RSV1
+
+ WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size(), /*RequireMask=*/true);
+ CHECK_EQ(Result.Status, WsFrameParseStatus::kProtocolError);
+ }
+
+ SUBCASE("TryParseFrame - unmasked client frame rejected when RequireMask=true")
+ {
+ // Build a server-style (unmasked) text frame and feed it to the codec
+ // with RequireMask=true. RFC 6455 section 5.1 requires the server to
+ // reject unmasked client frames.
+ std::string_view Text = "unmasked";
+ std::span<const uint8_t> Payload(reinterpret_cast<const uint8_t*>(Text.data()), Text.size());
+ std::vector<uint8_t> Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kText, Payload);
+
+ WsFrameParseResult Strict = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size(), /*RequireMask=*/true);
+ CHECK_EQ(Strict.Status, WsFrameParseStatus::kProtocolError);
+ CHECK_FALSE(Strict.IsValid);
+
+ // Same bytes with RequireMask=false should still parse successfully
+ WsFrameParseResult Lenient = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size(), /*RequireMask=*/false);
+ CHECK_EQ(Lenient.Status, WsFrameParseStatus::kValid);
+ CHECK(Lenient.IsValid);
+ }
+
+ SUBCASE("TryParseFrame - masked client frame accepted when RequireMask=true")
+ {
+ std::string_view Text = "masked";
+ std::span<const uint8_t> Payload(reinterpret_cast<const uint8_t*>(Text.data()), Text.size());
+ std::vector<uint8_t> Frame = WsFrameCodec::BuildMaskedFrame(WebSocketOpcode::kText, Payload);
+
+ WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size(), /*RequireMask=*/true);
+ CHECK_EQ(Result.Status, WsFrameParseStatus::kValid);
+ CHECK(Result.IsValid);
+ CHECK_EQ(Result.Payload.size(), Text.size());
+ }
+
SUBCASE("BuildMaskedCloseFrame roundtrip")
{
std::vector<uint8_t> Frame = WsFrameCodec::BuildMaskedCloseFrame(1000, "normal closure");
@@ -309,6 +400,46 @@ namespace {
return Frame;
}
+ /**
+ * Build a masked frame with an explicit FIN bit (for fragmentation tests)
+ */
+ std::vector<uint8_t> BuildMaskedFrameEx(WebSocketOpcode Opcode, std::span<const uint8_t> Payload, bool Fin)
+ {
+ std::vector<uint8_t> Frame;
+
+ // FIN-or-not + opcode
+ Frame.push_back((Fin ? 0x80 : 0x00) | static_cast<uint8_t>(Opcode));
+
+ if (Payload.size() < 126)
+ {
+ Frame.push_back(0x80 | static_cast<uint8_t>(Payload.size()));
+ }
+ else if (Payload.size() <= 0xFFFF)
+ {
+ Frame.push_back(0x80 | 126);
+ Frame.push_back(static_cast<uint8_t>((Payload.size() >> 8) & 0xFF));
+ Frame.push_back(static_cast<uint8_t>(Payload.size() & 0xFF));
+ }
+ else
+ {
+ Frame.push_back(0x80 | 127);
+ for (int i = 7; i >= 0; --i)
+ {
+ Frame.push_back(static_cast<uint8_t>((Payload.size() >> (i * 8)) & 0xFF));
+ }
+ }
+
+ uint8_t MaskKey[4] = {0x12, 0x34, 0x56, 0x78};
+ Frame.insert(Frame.end(), MaskKey, MaskKey + 4);
+
+ for (size_t i = 0; i < Payload.size(); ++i)
+ {
+ Frame.push_back(Payload[i] ^ MaskKey[i & 3]);
+ }
+
+ return Frame;
+ }
+
std::vector<uint8_t> BuildMaskedTextFrame(std::string_view Text)
{
std::span<const uint8_t> Payload(reinterpret_cast<const uint8_t*>(Text.data()), Text.size());
@@ -744,6 +875,191 @@ TEST_CASE("websocket.integration")
Sock.close();
}
+ SUBCASE("fragmented message coalesced into single OnMessage")
+ {
+ asio::io_context IoCtx;
+ asio::ip::tcp::socket Sock(IoCtx);
+ Sock.connect(asio::ip::tcp::endpoint(asio::ip::make_address("127.0.0.1"), static_cast<uint16_t>(Port)));
+
+ bool Ok = DoWebSocketHandshake(Sock, "/wstest/ws", Port);
+ REQUIRE(Ok);
+ Sleep(50);
+
+ // Split "hello world" into three fragments: first is Text/FIN=0,
+ // two continuations (FIN=0, then FIN=1).
+ std::string_view Part1 = "hello ";
+ std::string_view Part2 = "wor";
+ std::string_view Part3 = "ld";
+
+ std::vector<uint8_t> F1 = BuildMaskedFrameEx(WebSocketOpcode::kText,
+ std::span<const uint8_t>(reinterpret_cast<const uint8_t*>(Part1.data()), Part1.size()),
+ /*Fin=*/false);
+ std::vector<uint8_t> F2 = BuildMaskedFrameEx(WebSocketOpcode::kContinuation,
+ std::span<const uint8_t>(reinterpret_cast<const uint8_t*>(Part2.data()), Part2.size()),
+ /*Fin=*/false);
+ std::vector<uint8_t> F3 = BuildMaskedFrameEx(WebSocketOpcode::kContinuation,
+ std::span<const uint8_t>(reinterpret_cast<const uint8_t*>(Part3.data()), Part3.size()),
+ /*Fin=*/true);
+
+ asio::write(Sock, asio::buffer(F1));
+ asio::write(Sock, asio::buffer(F2));
+ asio::write(Sock, asio::buffer(F3));
+
+ // Server should dispatch a single coalesced message and echo it back.
+ WsFrameParseResult Reply = ReadOneFrame(Sock);
+ REQUIRE(Reply.IsValid);
+ CHECK_EQ(Reply.Opcode, WebSocketOpcode::kText);
+ std::string_view ReplyText(reinterpret_cast<const char*>(Reply.Payload.data()), Reply.Payload.size());
+ CHECK_EQ(ReplyText, "hello world"sv);
+ CHECK_EQ(TestService.m_MessageCount.load(), 1);
+ CHECK_EQ(TestService.m_LastMessage, "hello world");
+
+ Sock.close();
+ }
+
+ SUBCASE("unexpected continuation frame closes connection with 1002")
+ {
+ asio::io_context IoCtx;
+ asio::ip::tcp::socket Sock(IoCtx);
+ Sock.connect(asio::ip::tcp::endpoint(asio::ip::make_address("127.0.0.1"), static_cast<uint16_t>(Port)));
+
+ bool Ok = DoWebSocketHandshake(Sock, "/wstest/ws", Port);
+ REQUIRE(Ok);
+ Sleep(50);
+
+ // Continuation frame with no in-progress message → protocol error
+ std::string_view Part = "orphan";
+ std::vector<uint8_t> F = BuildMaskedFrameEx(WebSocketOpcode::kContinuation,
+ std::span<const uint8_t>(reinterpret_cast<const uint8_t*>(Part.data()), Part.size()),
+ /*Fin=*/true);
+ asio::write(Sock, asio::buffer(F));
+
+ // Server should send a close frame with 1002
+ WsFrameParseResult Reply = ReadOneFrame(Sock);
+ REQUIRE(Reply.IsValid);
+ CHECK_EQ(Reply.Opcode, WebSocketOpcode::kClose);
+ REQUIRE(Reply.Payload.size() >= 2);
+ uint16_t Code = (uint16_t(Reply.Payload[0]) << 8) | uint16_t(Reply.Payload[1]);
+ CHECK_EQ(Code, 1002);
+
+ Sock.close();
+ }
+
+ SUBCASE("interleaved text frame mid-fragment closes connection with 1002")
+ {
+ asio::io_context IoCtx;
+ asio::ip::tcp::socket Sock(IoCtx);
+ Sock.connect(asio::ip::tcp::endpoint(asio::ip::make_address("127.0.0.1"), static_cast<uint16_t>(Port)));
+
+ bool Ok = DoWebSocketHandshake(Sock, "/wstest/ws", Port);
+ REQUIRE(Ok);
+ Sleep(50);
+
+ // Start a fragmented Text message, then send another Text frame without
+ // finishing the first — RFC 6455 §5.4 says this is illegal.
+ std::string_view Part1 = "open";
+ std::vector<uint8_t> F1 = BuildMaskedFrameEx(WebSocketOpcode::kText,
+ std::span<const uint8_t>(reinterpret_cast<const uint8_t*>(Part1.data()), Part1.size()),
+ /*Fin=*/false);
+ std::string_view Part2 = "interrupt";
+ std::vector<uint8_t> F2 = BuildMaskedFrameEx(WebSocketOpcode::kText,
+ std::span<const uint8_t>(reinterpret_cast<const uint8_t*>(Part2.data()), Part2.size()),
+ /*Fin=*/true);
+
+ asio::write(Sock, asio::buffer(F1));
+ asio::write(Sock, asio::buffer(F2));
+
+ WsFrameParseResult Reply = ReadOneFrame(Sock);
+ REQUIRE(Reply.IsValid);
+ CHECK_EQ(Reply.Opcode, WebSocketOpcode::kClose);
+ REQUIRE(Reply.Payload.size() >= 2);
+ uint16_t Code = (uint16_t(Reply.Payload[0]) << 8) | uint16_t(Reply.Payload[1]);
+ CHECK_EQ(Code, 1002);
+
+ Sock.close();
+ }
+
+ SUBCASE("ping interleaved with fragments is allowed")
+ {
+ asio::io_context IoCtx;
+ asio::ip::tcp::socket Sock(IoCtx);
+ Sock.connect(asio::ip::tcp::endpoint(asio::ip::make_address("127.0.0.1"), static_cast<uint16_t>(Port)));
+
+ bool Ok = DoWebSocketHandshake(Sock, "/wstest/ws", Port);
+ REQUIRE(Ok);
+ Sleep(50);
+
+ // Send: Text frame FIN=0, then Ping (control), then Continuation FIN=1.
+ // The control frame must be dispatched immediately and not affect
+ // fragment state; the final continuation should coalesce and echo.
+ std::string_view Part1 = "ab";
+ std::vector<uint8_t> F1 = BuildMaskedFrameEx(WebSocketOpcode::kText,
+ std::span<const uint8_t>(reinterpret_cast<const uint8_t*>(Part1.data()), Part1.size()),
+ /*Fin=*/false);
+ std::string_view PingData = "p";
+ std::vector<uint8_t> PingFrame =
+ BuildMaskedFrame(WebSocketOpcode::kPing,
+ std::span<const uint8_t>(reinterpret_cast<const uint8_t*>(PingData.data()), PingData.size()));
+ std::string_view Part2 = "cd";
+ std::vector<uint8_t> F2 = BuildMaskedFrameEx(WebSocketOpcode::kContinuation,
+ std::span<const uint8_t>(reinterpret_cast<const uint8_t*>(Part2.data()), Part2.size()),
+ /*Fin=*/true);
+
+ asio::write(Sock, asio::buffer(F1));
+ asio::write(Sock, asio::buffer(PingFrame));
+ asio::write(Sock, asio::buffer(F2));
+
+ // We expect two server-to-client frames in sequence: Pong (for the
+ // interleaved Ping) then Text (the coalesced echo). Read bytes into
+ // a growing buffer and parse both frames from the same buffer so
+ // ReadOneFrame's one-shot semantics don't lose the trailing bytes
+ // when both frames arrive in a single TCP segment.
+ std::vector<uint8_t> Buffer;
+ WsFrameParseResult Pong;
+ WsFrameParseResult Echo;
+ auto Deadline = std::chrono::steady_clock::now() + 5s;
+ while (std::chrono::steady_clock::now() < Deadline)
+ {
+ uint8_t Tmp[4096];
+ asio::error_code Ec;
+ size_t BytesRead = Sock.read_some(asio::buffer(Tmp), Ec);
+ if (Ec || BytesRead == 0)
+ {
+ break;
+ }
+ Buffer.insert(Buffer.end(), Tmp, Tmp + BytesRead);
+
+ if (!Pong.IsValid)
+ {
+ WsFrameParseResult F = WsFrameCodec::TryParseFrame(Buffer.data(), Buffer.size());
+ if (F.IsValid)
+ {
+ Pong = std::move(F);
+ Buffer.erase(Buffer.begin(), Buffer.begin() + Pong.BytesConsumed);
+ }
+ }
+ if (Pong.IsValid)
+ {
+ WsFrameParseResult F = WsFrameCodec::TryParseFrame(Buffer.data(), Buffer.size());
+ if (F.IsValid)
+ {
+ Echo = std::move(F);
+ break;
+ }
+ }
+ }
+
+ REQUIRE(Pong.IsValid);
+ CHECK_EQ(Pong.Opcode, WebSocketOpcode::kPong);
+
+ REQUIRE(Echo.IsValid);
+ CHECK_EQ(Echo.Opcode, WebSocketOpcode::kText);
+ std::string_view EchoText(reinterpret_cast<const char*>(Echo.Payload.data()), Echo.Payload.size());
+ CHECK_EQ(EchoText, "abcd"sv);
+
+ Sock.close();
+ }
+
SUBCASE("multiple messages in sequence")
{
asio::io_context IoCtx;
@@ -877,6 +1193,57 @@ TEST_CASE("websocket.client")
CHECK_FALSE(Client.IsOpen());
}
+ SUBCASE("rejects non-101 response that contains '101' in a header value")
+ {
+ // Stand up a tiny TCP listener that replies to any connection with a
+ // 404 response whose headers happen to contain "101". An earlier,
+ // permissive client would accept this as a successful WebSocket
+ // upgrade; the hardened client must reject it.
+ asio::io_context FakeCtx;
+ asio::ip::tcp::acceptor Acceptor(FakeCtx, asio::ip::tcp::endpoint(asio::ip::make_address("127.0.0.1"), 0));
+ uint16_t FakePort = Acceptor.local_endpoint().port();
+
+ std::thread FakeServer([&] {
+ asio::ip::tcp::socket S(FakeCtx);
+ asio::error_code Ec;
+ Acceptor.accept(S, Ec);
+ if (Ec)
+ {
+ return;
+ }
+ asio::streambuf Req;
+ asio::read_until(S, Req, "\r\n\r\n", Ec);
+ std::string_view Resp =
+ "HTTP/1.1 404 Not Found\r\n"
+ "X-Retry-After: 101\r\n"
+ "Content-Length: 0\r\n"
+ "\r\n";
+ asio::write(S, asio::buffer(Resp.data(), Resp.size()), Ec);
+ S.close(Ec);
+ });
+ auto FakeGuard = MakeGuard([&] {
+ Acceptor.close();
+ if (FakeServer.joinable())
+ {
+ FakeServer.join();
+ }
+ });
+
+ TestWsClientHandler Handler;
+ std::string FakeUrl = fmt::format("ws://127.0.0.1:{}/any", FakePort);
+ HttpWsClient FakeClient(FakeUrl, Handler, HttpWsClientSettings{.ConnectTimeout = std::chrono::milliseconds(2000)});
+ FakeClient.Connect();
+
+ auto FakeDeadline = std::chrono::steady_clock::now() + 5s;
+ while (Handler.m_CloseCount.load() == 0 && std::chrono::steady_clock::now() < FakeDeadline)
+ {
+ Sleep(10);
+ }
+
+ CHECK_EQ(Handler.m_OpenCount.load(), 0);
+ CHECK_EQ(Handler.m_CloseCount.load(), 1);
+ }
+
SUBCASE("connect to bad port")
{
TestWsClientHandler Handler;
diff --git a/src/zenhttp/zenhttp.cpp b/src/zenhttp/zenhttp.cpp
index 1317f0159..30a50de12 100644
--- a/src/zenhttp/zenhttp.cpp
+++ b/src/zenhttp/zenhttp.cpp
@@ -12,6 +12,7 @@
namespace zen {
+void authmgr_forcelink();
void zipfs_test_forcelink();
void
@@ -25,6 +26,7 @@ zenhttp_forcelinktests()
forcelink_packageformat();
passwordsecurity_forcelink();
websocket_forcelink();
+ authmgr_forcelink();
zipfs_test_forcelink();
}
diff --git a/src/zenhttp/zipfs.cpp b/src/zenhttp/zipfs.cpp
index c0ffa2052..3bd09091a 100644
--- a/src/zenhttp/zipfs.cpp
+++ b/src/zenhttp/zipfs.cpp
@@ -100,57 +100,105 @@ namespace {
//////////////////////////////////////////////////////////////////////////
ZipFs::ZipFs(IoBuffer&& Buffer)
{
- MemoryView View = Buffer.GetView();
-
- uint8_t* Cursor = (uint8_t*)(View.GetData()) + View.GetSize();
- if (View.GetSize() < sizeof(EocdRecord))
+ // Treat the input buffer as attacker-controlled. Every offset, size, and
+ // trailer length is validated against View.GetSize() before it is used to
+ // form a pointer. All additions are performed in uint64_t to prevent 32-bit
+ // wrap.
+ const MemoryView View = Buffer.GetView();
+ const uint8_t* Base = static_cast<const uint8_t*>(View.GetData());
+ const size_t Size = View.GetSize();
+
+ if (Size < sizeof(EocdRecord))
{
return;
}
- const auto* EocdCursor = (EocdRecord*)(Cursor - sizeof(EocdRecord));
+ const size_t EocdOffset = Size - sizeof(EocdRecord);
+ const EocdRecord* Eocd = reinterpret_cast<const EocdRecord*>(Base + EocdOffset);
- // It is more correct to search backwards for EocdRecord::Magic as the
- // comment can be of a variable length. But here we're not going to support
- // zip files with comments.
- if (EocdCursor->Signature != EocdRecord::Magic)
+ // We only support a zip whose EOCD sits at the very end of the buffer — no
+ // trailing comment, no Zip64.
+ if (Eocd->Signature != EocdRecord::Magic)
+ {
+ return;
+ }
+ if (Eocd->ThisDiskIndex == 0xffff)
{
return;
}
- // Zip64 isn't supported either
- if (EocdCursor->ThisDiskIndex == 0xffff)
+ const uint32_t CdOffsetRel = Eocd->CdOffset;
+ const uint32_t CdSize = Eocd->CdSize;
+ const uint16_t CdRecordCount = Eocd->CdRecordCount;
+
+ // Central directory must fit strictly before the EOCD. Derive the archive
+ // origin from the EOCD's declared layout so any pre-zip padding in the
+ // buffer is accounted for; LFH offsets are relative to this origin.
+ if (uint64_t(CdOffsetRel) + uint64_t(CdSize) > EocdOffset)
{
return;
}
- Cursor = (uint8_t*)EocdCursor - uint32_t(EocdCursor->CdOffset) - uint32_t(EocdCursor->CdSize);
+ const uint8_t* ArchiveStart = Base + (EocdOffset - CdOffsetRel - CdSize);
+ const uint8_t* CdCursor = ArchiveStart + CdOffsetRel;
+ const uint8_t* CdEnd = CdCursor + CdSize;
- const auto* CdCursor = (CentralDirectoryRecord*)(Cursor + EocdCursor->CdOffset);
- for (int i = 0, n = EocdCursor->CdRecordCount; i < n; ++i)
+ for (uint32_t Record = 0; Record < CdRecordCount; ++Record)
{
- const CentralDirectoryRecord& Cd = *CdCursor;
+ if (size_t(CdEnd - CdCursor) < sizeof(CentralDirectoryRecord))
+ {
+ return;
+ }
+ const CentralDirectoryRecord& Cd = *reinterpret_cast<const CentralDirectoryRecord*>(CdCursor);
+ if (Cd.Signature != CentralDirectoryRecord::Magic)
+ {
+ return;
+ }
+
+ const uint16_t NameLen = Cd.FileNameLength;
+ const uint16_t ExtraLen = Cd.ExtraFieldLength;
+ const uint16_t CommentLen = Cd.CommentLength;
+ const uint32_t Trailer = uint32_t(NameLen) + uint32_t(ExtraLen) + uint32_t(CommentLen);
+ if (size_t(CdEnd - CdCursor) - sizeof(CentralDirectoryRecord) < Trailer)
+ {
+ return;
+ }
+
+ const uint16_t Compression = Cd.CompressionMethod;
+ const uint32_t Compressed = Cd.CompressedSize;
+ const uint32_t Original = Cd.OriginalSize;
+ const uint32_t LfhOffset = Cd.Offset;
- bool Acceptable = true;
- Acceptable &= (Cd.OriginalSize > 0); // has some content
- Acceptable &= (Cd.CompressionMethod == 0 || Cd.CompressionMethod == 8); // stored or deflate
- if (Acceptable)
+ const bool AcceptableCompression = (Compression == 0) || (Compression == 8);
+ const bool HasContent = Original > 0;
+ if (AcceptableCompression && HasContent)
{
- const uint8_t* Lfh = Cursor + Cd.Offset;
- if (uintptr_t(Lfh - Cursor) < View.GetSize())
+ // LFH header must fit inside the [ArchiveStart, CdOffsetRel) region
+ // (i.e. the pre-CD body). The LFH's own name + extra + compressed
+ // payload must also fit in that region.
+ const uint64_t LfhEndRel = uint64_t(LfhOffset) + sizeof(LocalFileHeader);
+ if (LfhEndRel <= CdOffsetRel)
{
- std::string_view FileName(Cd.FileName, Cd.FileNameLength);
- FileItem Item;
- Item.View = MemoryView{Lfh, size_t(0)};
- Item.CompressionMethod = Cd.CompressionMethod;
- Item.CompressedSize = Cd.CompressedSize;
- Item.UncompressedSize = Cd.OriginalSize;
- m_Files.insert(std::make_pair(FileName, std::move(Item)));
+ const LocalFileHeader* Lfh = reinterpret_cast<const LocalFileHeader*>(ArchiveStart + LfhOffset);
+ const uint64_t DataStartRel =
+ LfhEndRel + uint64_t(uint16_t(Lfh->FileNameLength)) + uint64_t(uint16_t(Lfh->ExtraFieldLength));
+ const uint64_t DataEndRel = DataStartRel + uint64_t(Compressed);
+ if (DataEndRel <= CdOffsetRel)
+ {
+ const uint8_t* FileData = ArchiveStart + DataStartRel;
+ std::string_view FileName(Cd.FileName, NameLen);
+
+ FileItem Item;
+ Item.View = MemoryView{FileData, size_t(0)};
+ Item.CompressionMethod = Compression;
+ Item.CompressedSize = Compressed;
+ Item.UncompressedSize = Original;
+ m_Files.insert({FileName, std::move(Item)});
+ }
}
}
- uint32_t ExtraBytes = Cd.FileNameLength + Cd.ExtraFieldLength + Cd.CommentLength;
- CdCursor = (CentralDirectoryRecord*)(Cd.FileName + ExtraBytes);
+ CdCursor += sizeof(CentralDirectoryRecord) + Trailer;
}
m_Buffer = std::move(Buffer);
@@ -184,8 +232,7 @@ ZipFs::GetFile(const std::string_view& FileName) const
return IoBuffer(IoBuffer::Wrap, Item.View.GetData(), Item.View.GetSize());
}
- const auto* Lfh = (LocalFileHeader*)(Item.View.GetData());
- const uint8_t* FileData = (const uint8_t*)(Lfh->FileName + Lfh->FileNameLength + Lfh->ExtraFieldLength);
+ const uint8_t* FileData = static_cast<const uint8_t*>(Item.View.GetData());
if (Item.CompressionMethod == 0)
{
diff --git a/src/zenhttp/zipfs_test.cpp b/src/zenhttp/zipfs_test.cpp
index b3a45c408..0ccca7ce2 100644
--- a/src/zenhttp/zipfs_test.cpp
+++ b/src/zenhttp/zipfs_test.cpp
@@ -216,6 +216,150 @@ TEST_CASE("zipfs.not_found")
CHECK(!Result);
}
+//////////////////////////////////////////////////////////////////////////
+// Malformed / attacker-shaped inputs — the parser must refuse these rather
+// than read out of bounds. Field offsets below mirror the in-memory layout:
+//
+// EocdRecord: CentralDirectoryRecord:
+// +0 Signature (4) +0 Signature (4)
+// +4 ThisDiskIndex (2) +20 CompressedSize (4)
+// +6 CdStartDiskIndex (2) +28 FileNameLength (2)
+// +8 CdRecordsThis (2) +42 Offset (4)
+// +10 CdRecords (2)
+// +12 CdSize (4)
+// +16 CdOffset (4)
+// +20 CommentSize (2) = 46 bytes header
+// = 22 bytes
+
+namespace {
+constexpr size_t kEocdSize = 22;
+constexpr size_t kEocdOffCdSz = 12;
+constexpr size_t kEocdOffCdOff = 16;
+constexpr size_t kCdrOffCompSz = 20;
+constexpr size_t kCdrOffNameLn = 28;
+constexpr size_t kCdrOffLfhOff = 42;
+
+template<typename T>
+void
+WriteAt(zen::IoBuffer& Buf, size_t Offset, T Value)
+{
+ std::memcpy(static_cast<uint8_t*>(Buf.GetMutableView().GetData()) + Offset, &Value, sizeof(Value));
+}
+
+template<typename T>
+T
+ReadAt(const zen::IoBuffer& Buf, size_t Offset)
+{
+ T Out;
+ std::memcpy(&Out, static_cast<const uint8_t*>(Buf.GetView().GetData()) + Offset, sizeof(Out));
+ return Out;
+}
+
+size_t
+EocdOffset(const zen::IoBuffer& Buf)
+{
+ return Buf.GetView().GetSize() - kEocdSize;
+}
+
+size_t
+CdOffset(const zen::IoBuffer& Buf)
+{
+ return ReadAt<uint32_t>(Buf, EocdOffset(Buf) + kEocdOffCdOff);
+}
+} // namespace
+
+TEST_CASE("zipfs.buffer_smaller_than_eocd")
+{
+ zen::IoBuffer Tiny(10);
+ std::memset(Tiny.GetMutableView().GetData(), 0, Tiny.GetView().GetSize());
+ zen::ZipFs Fs(std::move(Tiny));
+ CHECK(!Fs.GetFile("anything"));
+}
+
+TEST_CASE("zipfs.bad_eocd_magic")
+{
+ ZipBuilder Zip;
+ Zip.AddFile("test.txt", "x", 1, false);
+ zen::IoBuffer Buf = Zip.Build();
+
+ WriteAt<uint32_t>(Buf, EocdOffset(Buf), 0xdeadbeef);
+
+ zen::ZipFs Fs(std::move(Buf));
+ CHECK(!Fs.GetFile("test.txt"));
+}
+
+TEST_CASE("zipfs.cd_offset_past_eocd")
+{
+ ZipBuilder Zip;
+ Zip.AddFile("test.txt", "x", 1, false);
+ zen::IoBuffer Buf = Zip.Build();
+
+ WriteAt<uint32_t>(Buf, EocdOffset(Buf) + kEocdOffCdOff, 0xFFFF0000u);
+
+ zen::ZipFs Fs(std::move(Buf));
+ CHECK(!Fs.GetFile("test.txt"));
+}
+
+TEST_CASE("zipfs.cd_size_past_eocd")
+{
+ ZipBuilder Zip;
+ Zip.AddFile("test.txt", "x", 1, false);
+ zen::IoBuffer Buf = Zip.Build();
+
+ WriteAt<uint32_t>(Buf, EocdOffset(Buf) + kEocdOffCdSz, 0xFFFF0000u);
+
+ zen::ZipFs Fs(std::move(Buf));
+ CHECK(!Fs.GetFile("test.txt"));
+}
+
+TEST_CASE("zipfs.bad_cd_record_signature")
+{
+ ZipBuilder Zip;
+ Zip.AddFile("test.txt", "x", 1, false);
+ zen::IoBuffer Buf = Zip.Build();
+
+ WriteAt<uint32_t>(Buf, CdOffset(Buf), 0xdeadbeef);
+
+ zen::ZipFs Fs(std::move(Buf));
+ CHECK(!Fs.GetFile("test.txt"));
+}
+
+TEST_CASE("zipfs.oversize_filename_length")
+{
+ ZipBuilder Zip;
+ Zip.AddFile("test.txt", "x", 1, false);
+ zen::IoBuffer Buf = Zip.Build();
+
+ WriteAt<uint16_t>(Buf, CdOffset(Buf) + kCdrOffNameLn, 0xFFFF);
+
+ zen::ZipFs Fs(std::move(Buf));
+ CHECK(!Fs.GetFile("test.txt"));
+}
+
+TEST_CASE("zipfs.lfh_offset_past_buffer")
+{
+ ZipBuilder Zip;
+ Zip.AddFile("test.txt", "x", 1, false);
+ zen::IoBuffer Buf = Zip.Build();
+
+ WriteAt<uint32_t>(Buf, CdOffset(Buf) + kCdrOffLfhOff, 0xFFFFFFFF);
+
+ zen::ZipFs Fs(std::move(Buf));
+ CHECK(!Fs.GetFile("test.txt"));
+}
+
+TEST_CASE("zipfs.compressed_size_past_buffer")
+{
+ ZipBuilder Zip;
+ Zip.AddFile("test.txt", "x", 1, false);
+ zen::IoBuffer Buf = Zip.Build();
+
+ WriteAt<uint32_t>(Buf, CdOffset(Buf) + kCdrOffCompSz, 0xFFFF0000u);
+
+ zen::ZipFs Fs(std::move(Buf));
+ CHECK(!Fs.GetFile("test.txt"));
+}
+
TEST_SUITE_END();
#endif // ZEN_WITH_TESTS
diff --git a/src/zenserver/frontend/frontend.cpp b/src/zenserver/frontend/frontend.cpp
index 812536074..c7c2b0023 100644
--- a/src/zenserver/frontend/frontend.cpp
+++ b/src/zenserver/frontend/frontend.cpp
@@ -143,12 +143,6 @@ HttpFrontendService::HandleRequest(zen::HttpServerRequest& Request)
Uri = UriBuilder;
}
- // Dismiss if the URI contains .. anywhere to prevent arbitrary file reads
- if (Uri.find("..") != Uri.npos)
- {
- return Request.WriteResponse(HttpResponseCode::Forbidden);
- }
-
// Map the file extension to a MIME type. To keep things constrained, only a
// small subset of file extensions is allowed
@@ -184,28 +178,32 @@ HttpFrontendService::HandleRequest(zen::HttpServerRequest& Request)
constexpr std::string_view DataPrefix = "data/";
if (!m_DocsDirectory.empty() && InUri.starts_with(DataPrefix))
{
- std::string_view DocsRelative = InUri.substr(DataPrefix.size());
- auto FullPath = m_DocsDirectory / std::filesystem::path(DocsRelative).make_preferred();
- FileContents File = ReadFile(FullPath);
-
- if (!File.ErrorCode)
+ const std::string_view DocsRelative = InUri.substr(DataPrefix.size());
+ if (std::optional<std::filesystem::path> FullPath = ResolveSafeRelativePath(m_DocsDirectory, DocsRelative))
{
- Request.WriteResponse(ResponseCode, ContentType, File.Data[0]);
- return true;
+ FileContents File = ReadFile(*FullPath);
+
+ if (!File.ErrorCode)
+ {
+ Request.WriteResponse(ResponseCode, ContentType, File.Data[0]);
+ return true;
+ }
}
}
// The given content directory overrides any zip-fs discovered in the binary
if (!m_Directory.empty())
{
- auto FullPath = m_Directory / std::filesystem::path(InUri).make_preferred();
- FileContents File = ReadFile(FullPath);
-
- if (!File.ErrorCode)
+ if (std::optional<std::filesystem::path> FullPath = ResolveSafeRelativePath(m_Directory, InUri))
{
- Request.WriteResponse(ResponseCode, ContentType, File.Data[0]);
+ FileContents File = ReadFile(*FullPath);
- return true;
+ if (!File.ErrorCode)
+ {
+ Request.WriteResponse(ResponseCode, ContentType, File.Data[0]);
+
+ return true;
+ }
}
}
diff --git a/src/zenserver/sessions/httpsessions.cpp b/src/zenserver/sessions/httpsessions.cpp
index 2276cb81a..88db36828 100644
--- a/src/zenserver/sessions/httpsessions.cpp
+++ b/src/zenserver/sessions/httpsessions.cpp
@@ -5,6 +5,7 @@
#include <zencore/compactbinarybuilder.h>
#include <zencore/fmtutils.h>
#include <zencore/logging.h>
+#include <zencore/string.h>
#include <zencore/trace.h>
#include "sessions.h"
@@ -470,9 +471,14 @@ HttpSessionsService::SessionLogRequest(HttpRouterRequest& Req)
std::string_view CursorStr = Params.GetValue("cursor"sv);
if (!CursorStr.empty())
{
- uint64_t AfterCursor = std::strtoull(std::string(CursorStr).c_str(), nullptr, 10);
+ const std::optional<uint64_t> AfterCursor = ParseInt<uint64_t>(CursorStr);
+ if (!AfterCursor)
+ {
+ m_SessionsStats.BadRequestCount++;
+ return ServerRequest.WriteResponse(HttpResponseCode::BadRequest, HttpContentType::kText, "Invalid 'cursor' parameter"sv);
+ }
- SessionsService::Session::CursorResult Result = Session->GetLogEntriesAfter(AfterCursor);
+ SessionsService::Session::CursorResult Result = Session->GetLogEntriesAfter(*AfterCursor);
CbObjectWriter Response;
Response << "cursor" << Result.Cursor;
@@ -495,11 +501,23 @@ HttpSessionsService::SessionLogRequest(HttpRouterRequest& Req)
if (std::string_view LimitStr = Params.GetValue("limit"sv); !LimitStr.empty())
{
- Limit = uint32_t(std::strtoul(std::string(LimitStr).c_str(), nullptr, 10));
+ const std::optional<uint32_t> Parsed = ParseInt<uint32_t>(LimitStr);
+ if (!Parsed)
+ {
+ m_SessionsStats.BadRequestCount++;
+ return ServerRequest.WriteResponse(HttpResponseCode::BadRequest, HttpContentType::kText, "Invalid 'limit' parameter"sv);
+ }
+ Limit = *Parsed;
}
if (std::string_view OffsetStr = Params.GetValue("offset"sv); !OffsetStr.empty())
{
- Offset = uint32_t(std::strtoul(std::string(OffsetStr).c_str(), nullptr, 10));
+ const std::optional<uint32_t> Parsed = ParseInt<uint32_t>(OffsetStr);
+ if (!Parsed)
+ {
+ m_SessionsStats.BadRequestCount++;
+ return ServerRequest.WriteResponse(HttpResponseCode::BadRequest, HttpContentType::kText, "Invalid 'offset' parameter"sv);
+ }
+ Offset = *Parsed;
}
std::vector<SessionsService::LogEntry> Entries = Session->GetLogEntries(Limit, Offset);
diff --git a/src/zenserver/storage/admin/admin.cpp b/src/zenserver/storage/admin/admin.cpp
index 34d9e570e..1de5f74fe 100644
--- a/src/zenserver/storage/admin/admin.cpp
+++ b/src/zenserver/storage/admin/admin.cpp
@@ -26,6 +26,60 @@
namespace zen {
+#if ZEN_WITH_TRACE
+namespace {
+ // Accept only loopback destinations for admin-triggered trace streams. Handles
+ // "localhost", "127.0.0.1", "::1", and bracketed IPv6 ("[::1]"), each optionally
+ // followed by ":<digits>". Rejects any control characters so the value is also
+ // safe to log.
+ bool IsLoopbackTraceHost(std::string_view Host)
+ {
+ if (Host.empty())
+ {
+ return false;
+ }
+ for (char C : Host)
+ {
+ if (static_cast<unsigned char>(C) < 0x20 || C == 0x7F)
+ {
+ return false;
+ }
+ }
+
+ std::string_view HostOnly = Host;
+ if (HostOnly.front() == '[')
+ {
+ const size_t Close = HostOnly.find(']');
+ if (Close == std::string_view::npos)
+ {
+ return false;
+ }
+ const std::string_view Tail = HostOnly.substr(Close + 1);
+ if (!Tail.empty())
+ {
+ if (Tail.front() != ':' || Tail.size() < 2 || Tail.find_first_not_of("0123456789", 1) != std::string_view::npos)
+ {
+ return false;
+ }
+ }
+ HostOnly = HostOnly.substr(1, Close - 1);
+ }
+ else if (const size_t Colon = HostOnly.find(':');
+ Colon != std::string_view::npos && HostOnly.find(':', Colon + 1) == std::string_view::npos)
+ {
+ const std::string_view Port = HostOnly.substr(Colon + 1);
+ if (Port.empty() || Port.find_first_not_of("0123456789") != std::string_view::npos)
+ {
+ return false;
+ }
+ HostOnly = HostOnly.substr(0, Colon);
+ }
+
+ return HostOnly == "localhost" || HostOnly == "127.0.0.1" || HostOnly == "::1";
+ }
+} // namespace
+#endif // ZEN_WITH_TRACE
+
struct DirStats
{
uint64_t FileCount = 0;
@@ -149,17 +203,13 @@ HttpAdminService::HttpAdminService(GcScheduler& Scheduler,
[&](HttpRouterRequest& Req) {
const auto& JobIdString = Req.GetCapture(1);
std::optional<uint64_t> JobIdArg = ParseInt<uint64_t>(JobIdString);
- if (!JobIdArg)
- {
- Req.ServerRequest().WriteResponse(HttpResponseCode::BadRequest);
- }
- JobId Id{.Id = JobIdArg.value_or(0)};
- if (Id.Id == 0)
+ if (!JobIdArg || JobIdArg.value() == 0)
{
return Req.ServerRequest().WriteResponse(HttpResponseCode::BadRequest,
ZenContentType::kText,
- fmt::format("Invalid Job Id: {}", Id.Id));
+ fmt::format("Invalid Job Id: '{}'", JobIdString));
}
+ const JobId Id{.Id = JobIdArg.value()};
std::optional<JobQueue::JobDetails> CurrentState = m_BackgroundJobQueue.Get(Id);
if (!CurrentState)
@@ -271,11 +321,13 @@ HttpAdminService::HttpAdminService(GcScheduler& Scheduler,
[&](HttpRouterRequest& Req) {
const auto& JobIdString = Req.GetCapture(1);
std::optional<uint64_t> JobIdArg = ParseInt<uint64_t>(JobIdString);
- if (!JobIdArg)
+ if (!JobIdArg || JobIdArg.value() == 0)
{
- Req.ServerRequest().WriteResponse(HttpResponseCode::BadRequest);
+ return Req.ServerRequest().WriteResponse(HttpResponseCode::BadRequest,
+ ZenContentType::kText,
+ fmt::format("Invalid Job Id: '{}'", JobIdString));
}
- JobId Id{.Id = JobIdArg.value_or(0)};
+ const JobId Id{.Id = JobIdArg.value()};
if (m_BackgroundJobQueue.CancelJob(Id))
{
Req.ServerRequest().WriteResponse(HttpResponseCode::OK);
@@ -610,11 +662,6 @@ HttpAdminService::HttpAdminService(GcScheduler& Scheduler,
const HttpServerRequest::QueryParams Params = HttpReq.GetQueryParams();
TraceOptions TraceOptions;
- if (!IsTracing())
- {
- TraceInit("zenserver");
- }
-
if (auto Channels = Params.GetValue("channels"); Channels.empty() == false)
{
TraceOptions.Channels = Channels;
@@ -622,22 +669,41 @@ HttpAdminService::HttpAdminService(GcScheduler& Scheduler,
if (auto File = Params.GetValue("file"); File.empty() == false)
{
- TraceOptions.File = File;
+ const std::filesystem::path TracesRoot = m_ServerOptions.DataDir / "traces";
+ std::optional<std::filesystem::path> Resolved = ResolveSafeRelativePath(TracesRoot, File);
+ if (!Resolved)
+ {
+ ZEN_WARN("admin trace/start rejected unsafe 'file' parameter '{}'", File);
+ return HttpReq.WriteResponse(HttpResponseCode::BadRequest, HttpContentType::kText, "Invalid 'file' parameter"sv);
+ }
+ TraceOptions.File = Resolved->string();
}
else if (auto Host = Params.GetValue("host"); Host.empty() == false)
{
+ if (!IsLoopbackTraceHost(Host))
+ {
+ ZEN_WARN("admin trace/start rejected non-loopback 'host' parameter '{}'", Host);
+ return HttpReq.WriteResponse(HttpResponseCode::BadRequest,
+ HttpContentType::kText,
+ "Invalid 'host' parameter (must be a loopback address)"sv);
+ }
TraceOptions.Host = Host;
}
else
{
- return Req.ServerRequest().WriteResponse(HttpResponseCode::BadRequest,
- HttpContentType::kText,
- "Invalid trace type, use `file` or `host`"sv);
+ return HttpReq.WriteResponse(HttpResponseCode::BadRequest,
+ HttpContentType::kText,
+ "Invalid trace type, use `file` or `host`"sv);
+ }
+
+ if (!IsTracing())
+ {
+ TraceInit("zenserver");
}
TraceConfigure(TraceOptions);
- return Req.ServerRequest().WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "Tracing started");
+ return HttpReq.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "Tracing started");
},
HttpVerb::kPost);
diff --git a/src/zenserver/storage/buildstore/httpbuildstore.cpp b/src/zenserver/storage/buildstore/httpbuildstore.cpp
index f935e2c6b..bdaaef327 100644
--- a/src/zenserver/storage/buildstore/httpbuildstore.cpp
+++ b/src/zenserver/storage/buildstore/httpbuildstore.cpp
@@ -138,8 +138,7 @@ HttpBuildStoreService::PutBlobRequest(HttpRouterRequest& Req)
HttpContentType::kText,
fmt::format("Payload blob {} content type {} is invalid", Hash, ToString(Payload.GetContentType())));
}
- m_BuildStore.PutBlob(BlobHash, ServerRequest.ReadPayload());
- // ZEN_INFO("Stored blob {}. Size: {}", BlobHash, ServerRequest.ReadPayload().GetSize());
+ m_BuildStore.PutBlob(BlobHash, std::move(Payload));
return ServerRequest.WriteResponse(HttpResponseCode::OK);
}
diff --git a/src/zenserver/storage/cache/httpstructuredcache.cpp b/src/zenserver/storage/cache/httpstructuredcache.cpp
index 4d3673e70..7ca496dec 100644
--- a/src/zenserver/storage/cache/httpstructuredcache.cpp
+++ b/src/zenserver/storage/cache/httpstructuredcache.cpp
@@ -70,6 +70,7 @@ namespace {
static constinit std::string_view HttpZCacheUtilStopRecording = "exec$/stop-recording"sv;
static constinit std::string_view HttpZCacheUtilReplayRecording = "exec$/replay-recording"sv;
static constinit std::string_view HttpZCacheDetailsPrefix = "details$"sv;
+ static constinit std::string_view HttpZCacheRecordingsDirName = "recordings"sv;
} // namespace
//////////////////////////////////////////////////////////////////////////
@@ -396,17 +397,26 @@ HttpStructuredCacheService::HandleRequest(HttpServerRequest& Request)
{
HttpServerRequest::QueryParams Params = Request.GetQueryParams();
- std::string RecordPath = UrlDecode(Params.GetValue("path"));
+ const std::string RecordPath = UrlDecode(Params.GetValue("path"));
+
+ const std::filesystem::path RecordingsRoot = m_CacheStore.GetBasePath() / HttpZCacheRecordingsDirName;
+ std::optional<std::filesystem::path> ResolvedPath = ResolveSafeRelativePath(RecordingsRoot, RecordPath);
+ if (!ResolvedPath)
+ {
+ m_CacheStats.BadRequestCount++;
+ ZEN_WARN("cache RPC start-recording rejected unsafe path '{}'", RecordPath);
+ return Request.WriteResponse(HttpResponseCode::BadRequest);
+ }
{
RwLock::ExclusiveLockScope _(m_RequestRecordingLock);
m_RequestRecordingEnabled.store(false);
m_RequestRecorder.reset();
- m_RequestRecorder = cache::MakeDiskRequestRecorder(RecordPath);
+ m_RequestRecorder = cache::MakeDiskRequestRecorder(*ResolvedPath);
m_RequestRecordingEnabled.store(true);
}
- ZEN_INFO("cache RPC recording STARTED -> '{}'", RecordPath);
+ ZEN_INFO("cache RPC recording STARTED -> '{}'", *ResolvedPath);
Request.WriteResponse(HttpResponseCode::OK);
return;
}
@@ -435,7 +445,16 @@ HttpStructuredCacheService::HandleRequest(HttpServerRequest& Request)
HttpServerRequest::QueryParams Params = Request.GetQueryParams();
- std::string RecordPath = UrlDecode(Params.GetValue("path"));
+ const std::string RecordPath = UrlDecode(Params.GetValue("path"));
+
+ const std::filesystem::path RecordingsRoot = m_CacheStore.GetBasePath() / HttpZCacheRecordingsDirName;
+ std::optional<std::filesystem::path> ResolvedPath = ResolveSafeRelativePath(RecordingsRoot, RecordPath);
+ if (!ResolvedPath)
+ {
+ m_CacheStats.BadRequestCount++;
+ ZEN_WARN("cache RPC replay-recording rejected unsafe path '{}'", RecordPath);
+ return Request.WriteResponse(HttpResponseCode::BadRequest);
+ }
const uint32_t HardwareConcurrency = GetHardwareConcurrency();
const uint32_t MaxThreadCount = std::max<uint32_t>(HardwareConcurrency, 16u);
@@ -449,9 +468,9 @@ HttpStructuredCacheService::HandleRequest(HttpServerRequest& Request)
}
ThreadCount = std::clamp<uint32_t>(ThreadCount, 1u, MaxThreadCount);
- ZEN_INFO("initiating cache RPC replay using {} threads, from '{}'", ThreadCount, RecordPath);
+ ZEN_INFO("initiating cache RPC replay using {} threads, from '{}'", ThreadCount, *ResolvedPath);
- std::unique_ptr<cache::IRpcRequestReplayer> Replayer(cache::MakeDiskRequestReplayer(RecordPath, false));
+ std::unique_ptr<cache::IRpcRequestReplayer> Replayer(cache::MakeDiskRequestReplayer(*ResolvedPath, false));
ReplayRequestRecorder(RequestContext, *Replayer, ThreadCount);
ZEN_INFO("cache RPC replay COMPLETED");
diff --git a/src/zenserver/storage/objectstore/objectstore.cpp b/src/zenserver/storage/objectstore/objectstore.cpp
index 1115c1cd6..252a381ae 100644
--- a/src/zenserver/storage/objectstore/objectstore.cpp
+++ b/src/zenserver/storage/objectstore/objectstore.cpp
@@ -29,6 +29,15 @@ using namespace std::literals;
ZEN_DEFINE_LOG_CATEGORY_STATIC(LogObj, "obj"sv);
+namespace {
+ // Permitted bucket-name characters. Must stay in sync with the "bucket" URL matcher
+ // registered in HttpObjectStoreService::Initialize so POST / DELETE / PUT with an
+ // explicit bucketname payload uses the same rule as the route matcher.
+ constexpr AsciiSet ValidBucketCharactersSet{"abcdefghijklmnopqrstuvwxyz0123456789-_.ABCDEFGHIJKLMNOPQRSTUVWXYZ"};
+
+ bool IsValidBucketName(std::string_view Name) { return !Name.empty() && AsciiSet::HasOnly(Name, ValidBucketCharactersSet); }
+} // namespace
+
class CbXmlWriter
{
public:
@@ -302,12 +311,10 @@ HttpObjectStoreService::Initialize()
}
static constexpr AsciiSet ValidPathCharactersSet{"abcdefghijklmnopqrstuvwxyz0123456789/_.,;$~{}+-[]() ABCDEFGHIJKLMNOPQRSTUVWXYZ"};
- static constexpr AsciiSet ValidBucketCharactersSet{"abcdefghijklmnopqrstuvwxyz0123456789-_.ABCDEFGHIJKLMNOPQRSTUVWXYZ"};
m_Router.AddMatcher("path",
[](std::string_view Str) -> bool { return !Str.empty() && AsciiSet::HasOnly(Str, ValidPathCharactersSet); });
- m_Router.AddMatcher("bucket",
- [](std::string_view Str) -> bool { return !Str.empty() && AsciiSet::HasOnly(Str, ValidBucketCharactersSet); });
+ m_Router.AddMatcher("bucket", [](std::string_view Str) -> bool { return IsValidBucketName(Str); });
m_Router.RegisterRoute(
"",
@@ -471,8 +478,9 @@ HttpObjectStoreService::CreateBucket(HttpRouterRequest& Request)
const CbObject Params = Request.ServerRequest().ReadPayloadObject();
const std::string_view BucketName = Params["bucketname"].AsString();
- if (BucketName.empty())
+ if (!IsValidBucketName(BucketName))
{
+ ZEN_LOG_WARN(LogObj, "CREATE - rejected invalid bucket name '{}'", BucketName);
return Request.ServerRequest().WriteResponse(HttpResponseCode::BadRequest);
}
@@ -515,9 +523,21 @@ HttpObjectStoreService::ListBucket(HttpRouterRequest& Request, const std::string
BucketPrefix.erase(0, BucketPrefix.find_first_not_of('/'));
BucketPrefix.erase(0, BucketPrefix.find_first_not_of('\\'));
- const fs::path BucketRoot = GetBucketDirectory(BucketName);
- const fs::path RelativeBucketPath = fs::path(BucketPrefix).make_preferred();
- const fs::path FullPath = BucketRoot / RelativeBucketPath;
+ const fs::path BucketRoot = GetBucketDirectory(BucketName);
+
+ fs::path RelativeBucketPath;
+ fs::path FullPath = BucketRoot;
+ if (!BucketPrefix.empty())
+ {
+ std::optional<fs::path> Resolved = ResolveSafeRelativePath(BucketRoot, BucketPrefix);
+ if (!Resolved)
+ {
+ ZEN_LOG_WARN(LogObj, "LIST - bucket '{}' rejected unsafe prefix '{}'", BucketName, BucketPrefix);
+ return Request.ServerRequest().WriteResponse(HttpResponseCode::Forbidden);
+ }
+ FullPath = std::move(*Resolved);
+ RelativeBucketPath = fs::relative(FullPath, BucketRoot);
+ }
struct Visitor : FileSystemTraversal::TreeVisitor
{
@@ -589,8 +609,9 @@ HttpObjectStoreService::DeleteBucket(HttpRouterRequest& Request)
const CbObject Params = Request.ServerRequest().ReadPayloadObject();
const std::string_view BucketName = Params["bucketname"].AsString();
- if (BucketName.empty())
+ if (!IsValidBucketName(BucketName))
{
+ ZEN_LOG_WARN(LogObj, "DELETE - rejected invalid bucket name '{}'", BucketName);
return Request.ServerRequest().WriteResponse(HttpResponseCode::BadRequest);
}
@@ -621,15 +642,14 @@ HttpObjectStoreService::GetObject(HttpRouterRequest& Request, const std::string_
return Request.ServerRequest().WriteResponse(HttpResponseCode::NotFound);
}
- const fs::path RelativeBucketPath = fs::path(BucketPrefix).make_preferred();
-
- if (RelativeBucketPath.is_absolute() || RelativeBucketPath.string().starts_with(".."))
+ std::optional<fs::path> ResolvedFilePath = ResolveSafeRelativePath(BucketDir, BucketPrefix);
+ if (!ResolvedFilePath)
{
- ZEN_LOG_DEBUG(LogObj, "GET - from bucket '{}' [FAILED], invalid file path", BucketName);
+ ZEN_LOG_WARN(LogObj, "GET - from bucket '{}' rejected unsafe path '{}'", BucketName, BucketPrefix);
return Request.ServerRequest().WriteResponse(HttpResponseCode::Forbidden);
}
-
- const fs::path FilePath = BucketDir / RelativeBucketPath;
+ const fs::path FilePath = std::move(*ResolvedFilePath);
+ const fs::path RelativeBucketPath = fs::relative(FilePath, BucketDir);
if (!IsFile(FilePath))
{
ZEN_LOG_DEBUG(LogObj, "GET - '{}/{}' [FAILED], doesn't exist", BucketName, FilePath);
@@ -720,16 +740,17 @@ HttpObjectStoreService::PutObject(HttpRouterRequest& Request)
return Request.ServerRequest().WriteResponse(HttpResponseCode::NotFound);
}
- const fs::path RelativeBucketPath = fs::path(Request.GetCapture(2)).make_preferred();
+ const std::string_view PathCapture = Request.GetCapture(2);
- if (RelativeBucketPath.is_absolute() || RelativeBucketPath.string().starts_with(".."))
+ std::optional<fs::path> ResolvedFilePath = ResolveSafeRelativePath(BucketDir, PathCapture);
+ if (!ResolvedFilePath)
{
- ZEN_LOG_DEBUG(LogObj, "PUT - bucket '{}' [FAILED], invalid file path", BucketName);
+ ZEN_LOG_WARN(LogObj, "PUT - bucket '{}' rejected unsafe path '{}'", BucketName, PathCapture);
return Request.ServerRequest().WriteResponse(HttpResponseCode::Forbidden);
}
-
- const fs::path FilePath = BucketDir / RelativeBucketPath;
- const fs::path FileDirectory = FilePath.parent_path();
+ const fs::path FilePath = std::move(*ResolvedFilePath);
+ const fs::path RelativeBucketPath = fs::relative(FilePath, BucketDir);
+ const fs::path FileDirectory = FilePath.parent_path();
{
std::lock_guard _(m_BucketsMutex);
diff --git a/src/zenserver/storage/projectstore/httpprojectstore.cpp b/src/zenserver/storage/projectstore/httpprojectstore.cpp
index c40690d5f..9c94a0381 100644
--- a/src/zenserver/storage/projectstore/httpprojectstore.cpp
+++ b/src/zenserver/storage/projectstore/httpprojectstore.cpp
@@ -1000,10 +1000,10 @@ HttpProjectService::HandleChunkBatchRequest(HttpRouterRequest& Req)
uint64_t RequestBytes;
};
- if (Payload.Size() <= sizeof(RequestHeader))
+ if (Payload.Size() < sizeof(RequestHeader))
{
m_ProjectStats.BadRequestCount++;
- HttpReq.WriteResponse(HttpResponseCode::BadRequest);
+ return HttpReq.WriteResponse(HttpResponseCode::BadRequest);
}
RequestHeader RequestHdr;
@@ -1012,7 +1012,21 @@ HttpProjectService::HandleChunkBatchRequest(HttpRouterRequest& Req)
if (RequestHdr.Magic != RequestHeader::kMagic)
{
m_ProjectStats.BadRequestCount++;
- HttpReq.WriteResponse(HttpResponseCode::BadRequest);
+ return HttpReq.WriteResponse(HttpResponseCode::BadRequest);
+ }
+
+ constexpr uint32_t kMaxChunkCount = 1u << 16;
+ if (RequestHdr.ChunkCount > kMaxChunkCount)
+ {
+ m_ProjectStats.BadRequestCount++;
+ return HttpReq.WriteResponse(HttpResponseCode::BadRequest);
+ }
+
+ const uint64_t ExpectedChunkBytes = uint64_t(RequestHdr.ChunkCount) * sizeof(RequestChunkEntry);
+ if (ExpectedChunkBytes > Reader.Remaining())
+ {
+ m_ProjectStats.BadRequestCount++;
+ return HttpReq.WriteResponse(HttpResponseCode::BadRequest);
}
std::vector<RequestChunkEntry> RequestedChunks;
@@ -1610,9 +1624,14 @@ HttpProjectService::HandleOplogOpNewRequest(HttpRouterRequest& Req)
if (std::string_view SaltParam = Params.GetValue("salt"sv); SaltParam.empty() == false)
{
- const uint32_t Salt = std::stoi(std::string(SaltParam));
- SaltHash = IoHash::HashBuffer(&Salt, sizeof Salt);
- IsUsingSalt = true;
+ const std::optional<uint32_t> Salt = ParseInt<uint32_t>(SaltParam);
+ if (!Salt)
+ {
+ m_ProjectStats.BadRequestCount++;
+ return HttpReq.WriteResponse(HttpResponseCode::BadRequest, HttpContentType::kText, "Invalid 'salt' parameter"sv);
+ }
+ SaltHash = IoHash::HashBuffer(&Salt.value(), sizeof(uint32_t));
+ IsUsingSalt = true;
}
Ref<ProjectStore::Project> Project = m_ProjectStore->OpenProject(ProjectId);
diff --git a/src/zenserver/storage/workspaces/httpworkspaces.cpp b/src/zenserver/storage/workspaces/httpworkspaces.cpp
index 12e7bae73..ba3bc00dd 100644
--- a/src/zenserver/storage/workspaces/httpworkspaces.cpp
+++ b/src/zenserver/storage/workspaces/httpworkspaces.cpp
@@ -4,6 +4,7 @@
#include <zencore/basicfile.h>
#include <zencore/compactbinarybuilder.h>
+#include <zencore/filesystem.h>
#include <zencore/fmtutils.h>
#include <zencore/logging.h>
#include <zencore/trace.h>
@@ -29,6 +30,48 @@ namespace {
return {};
}
+ // Validate a workspace root_path supplied via HTTP. Rejects empty / non-absolute
+ // paths, Windows UNC (\\server\share) and device-namespace prefixes (\\?\, \\.\),
+ // and strings containing control characters. Canonicalises the result so any later
+ // joins and stored config anchor at a resolved, existing directory — a follow-up
+ // symlink swap on disk can no longer redirect the workspace root.
+ std::optional<std::filesystem::path> ValidateWorkspaceRootPath(std::string_view RawInput)
+ {
+ if (RawInput.empty())
+ {
+ return std::nullopt;
+ }
+ for (char C : RawInput)
+ {
+ if (static_cast<unsigned char>(C) < 0x20 || C == 0x7F)
+ {
+ return std::nullopt;
+ }
+ }
+ if (RawInput.starts_with("\\\\") || RawInput.starts_with("//"))
+ {
+ return std::nullopt;
+ }
+
+ std::filesystem::path Requested(RawInput);
+ if (!Requested.is_absolute())
+ {
+ return std::nullopt;
+ }
+
+ std::error_code Ec;
+ std::filesystem::path Canonical = std::filesystem::canonical(Requested, Ec);
+ if (Ec)
+ {
+ return std::nullopt;
+ }
+ if (!std::filesystem::is_directory(Canonical, Ec) || Ec)
+ {
+ return std::nullopt;
+ }
+ return Canonical;
+ }
+
void WriteWorkspaceConfig(CbWriter& Writer, const Workspaces::WorkspaceConfiguration& Config)
{
Writer << "id" << Config.Id;
@@ -505,14 +548,17 @@ HttpWorkspacesService::WorkspaceRequest(HttpRouterRequest& Req)
{
case HttpVerb::kPut:
{
- std::filesystem::path WorkspacePath = GetPathParameter(ServerRequest, "root_path"sv);
- if (WorkspacePath.empty())
+ const std::string RawRootPath = HttpServerRequest::Decode(ServerRequest.GetQueryParams().GetValue("root_path"sv));
+ std::optional<std::filesystem::path> ValidatedRootPath = ValidateWorkspaceRootPath(RawRootPath);
+ if (!ValidatedRootPath)
{
m_WorkspacesStats.BadRequestCount++;
+ ZEN_WARN("workspace PUT rejected unsafe 'root_path' parameter '{}'", RawRootPath);
return ServerRequest.WriteResponse(HttpResponseCode::BadRequest,
HttpContentType::kText,
"Invalid 'root_path' parameter");
}
+ std::filesystem::path WorkspacePath = std::move(*ValidatedRootPath);
if (Req.GetCapture(1) == Oid::Zero.ToString())
{
@@ -1096,6 +1142,16 @@ HttpWorkspacesService::ShareRequest(HttpRouterRequest& Req, const Oid& Workspace
fmt::format("Workspace '{}' does not exist", WorkspaceId));
}
+ std::optional<std::filesystem::path> ResolvedSharePath = ResolveSafeRelativePath(Workspace.RootPath, SharePath.string());
+ if (!ResolvedSharePath)
+ {
+ m_WorkspacesStats.BadRequestCount++;
+ ZEN_WARN("share PUT in workspace '{}' rejected unsafe 'share_path' parameter '{}'", WorkspaceId, SharePath);
+ return ServerRequest.WriteResponse(HttpResponseCode::BadRequest,
+ HttpContentType::kText,
+ "Invalid 'share_path' parameter");
+ }
+
if (!Workspace.AllowShareCreationFromHttp)
{
if (!MayChangeConfiguration(ServerRequest))
@@ -1143,7 +1199,7 @@ HttpWorkspacesService::ShareRequest(HttpRouterRequest& Req, const Oid& Workspace
}
}
- if (!IsDir(Workspace.RootPath / NewConfig.SharePath))
+ if (!IsDir(*ResolvedSharePath))
{
return ServerRequest.WriteResponse(HttpResponseCode::NotFound,
HttpContentType::kText,
diff --git a/src/zenstore/include/zenstore/cache/structuredcachestore.h b/src/zenstore/include/zenstore/cache/structuredcachestore.h
index 9e0561364..425dd4b79 100644
--- a/src/zenstore/include/zenstore/cache/structuredcachestore.h
+++ b/src/zenstore/include/zenstore/cache/structuredcachestore.h
@@ -274,6 +274,7 @@ public:
Configuration GetConfiguration() const { return m_Configuration; }
void SetLoggingConfig(const Configuration::LogConfig& Loggingconfig);
+ const std::filesystem::path& GetBasePath() const { return m_BasePath; }
Info GetInfo() const;
std::optional<ZenCacheNamespace::Info> GetNamespaceInfo(std::string_view Namespace);
std::optional<ZenCacheNamespace::BucketInfo> GetBucketInfo(std::string_view Namespace, std::string_view Bucket);
diff --git a/src/zenutil/include/zenutil/process/subprocessmanager.h b/src/zenutil/include/zenutil/process/subprocessmanager.h
index 95d7fa43d..e8b4a700c 100644
--- a/src/zenutil/include/zenutil/process/subprocessmanager.h
+++ b/src/zenutil/include/zenutil/process/subprocessmanager.h
@@ -156,7 +156,7 @@ private:
friend class ProcessGroup;
struct Impl;
- std::unique_ptr<Impl> m_Impl;
+ std::shared_ptr<Impl> m_Impl;
};
/// A process managed by SubprocessManager.
diff --git a/src/zenutil/process/asyncpipereader.cpp b/src/zenutil/process/asyncpipereader.cpp
index 8eac350c6..7d603aa3f 100644
--- a/src/zenutil/process/asyncpipereader.cpp
+++ b/src/zenutil/process/asyncpipereader.cpp
@@ -31,7 +31,7 @@ static constexpr size_t kReadBufferSize = 4096;
#if !ZEN_PLATFORM_WINDOWS
-struct AsyncPipeReader::Impl
+struct AsyncPipeReader::Impl : public TRefCounted<Impl>
{
asio::io_context& m_IoContext;
std::unique_ptr<asio::posix::stream_descriptor> m_Descriptor;
@@ -82,22 +82,25 @@ struct AsyncPipeReader::Impl
return;
}
- m_Descriptor->async_read_some(asio::buffer(m_Buffer), [this](const asio::error_code& Ec, size_t BytesRead) {
+ // Capture Self so the Impl outlives any completion the io_context has
+ // already picked up but not yet dispatched, even if the owning
+ // AsyncPipeReader is destroyed in the meantime.
+ m_Descriptor->async_read_some(asio::buffer(m_Buffer), [Self = Ref<Impl>(this)](const asio::error_code& Ec, size_t BytesRead) {
if (Ec)
{
- if (Ec != asio::error::operation_aborted && m_EofCallback)
+ if (Ec != asio::error::operation_aborted && Self->m_EofCallback)
{
- m_EofCallback();
+ Self->m_EofCallback();
}
return;
}
- if (BytesRead > 0 && m_DataCallback)
+ if (BytesRead > 0 && Self->m_DataCallback)
{
- m_DataCallback(std::string_view(m_Buffer.data(), BytesRead));
+ Self->m_DataCallback(std::string_view(Self->m_Buffer.data(), BytesRead));
}
- EnqueueRead();
+ Self->EnqueueRead();
});
}
};
@@ -183,7 +186,7 @@ CreateOverlappedStdoutPipe(StdoutPipeHandles& OutPipe)
return true;
}
-struct AsyncPipeReader::Impl
+struct AsyncPipeReader::Impl : public TRefCounted<Impl>
{
asio::io_context& m_IoContext;
std::unique_ptr<asio::windows::stream_handle> m_StreamHandle;
@@ -229,22 +232,25 @@ struct AsyncPipeReader::Impl
return;
}
- m_StreamHandle->async_read_some(asio::buffer(m_Buffer), [this](const asio::error_code& Ec, size_t BytesRead) {
+ // Capture Self so the Impl outlives any completion the io_context has
+ // already picked up but not yet dispatched, even if the owning
+ // AsyncPipeReader is destroyed in the meantime.
+ m_StreamHandle->async_read_some(asio::buffer(m_Buffer), [Self = Ref<Impl>(this)](const asio::error_code& Ec, size_t BytesRead) {
if (Ec)
{
- if (Ec != asio::error::operation_aborted && m_EofCallback)
+ if (Ec != asio::error::operation_aborted && Self->m_EofCallback)
{
- m_EofCallback();
+ Self->m_EofCallback();
}
return;
}
- if (BytesRead > 0 && m_DataCallback)
+ if (BytesRead > 0 && Self->m_DataCallback)
{
- m_DataCallback(std::string_view(m_Buffer.data(), BytesRead));
+ Self->m_DataCallback(std::string_view(Self->m_Buffer.data(), BytesRead));
}
- EnqueueRead();
+ Self->EnqueueRead();
});
}
};
@@ -255,11 +261,21 @@ struct AsyncPipeReader::Impl
// Common wrapper
// ============================================================================
-AsyncPipeReader::AsyncPipeReader(asio::io_context& IoContext) : m_Impl(std::make_unique<Impl>(IoContext))
+AsyncPipeReader::AsyncPipeReader(asio::io_context& IoContext) : m_Impl(new Impl(IoContext))
{
}
-AsyncPipeReader::~AsyncPipeReader() = default;
+AsyncPipeReader::~AsyncPipeReader()
+{
+ // Explicitly stop pending async reads. The Impl may outlive this call if a
+ // completion is still in the io_context queue (the handler holds a strong
+ // Ref back to Impl to keep it alive). Stop() here guarantees reads stop
+ // even if nobody has called Stop() on the wrapper.
+ if (m_Impl)
+ {
+ m_Impl->Stop();
+ }
+}
void
AsyncPipeReader::Start(StdoutPipeHandles&& Pipe, std::function<void(std::string_view)> DataCallback, std::function<void()> EofCallback)
diff --git a/src/zenutil/process/asyncpipereader.h b/src/zenutil/process/asyncpipereader.h
index ad2ff8455..3e4be6906 100644
--- a/src/zenutil/process/asyncpipereader.h
+++ b/src/zenutil/process/asyncpipereader.h
@@ -2,11 +2,11 @@
#pragma once
+#include <zenbase/refcount.h>
#include <zencore/process.h>
#include <zencore/zencore.h>
#include <functional>
-#include <memory>
#include <string_view>
namespace asio {
@@ -56,7 +56,7 @@ public:
private:
struct Impl;
- std::unique_ptr<Impl> m_Impl;
+ Ref<Impl> m_Impl;
};
} // namespace zen
diff --git a/src/zenutil/process/exitwatcher.cpp b/src/zenutil/process/exitwatcher.cpp
index cef31ebca..2e88dfdeb 100644
--- a/src/zenutil/process/exitwatcher.cpp
+++ b/src/zenutil/process/exitwatcher.cpp
@@ -38,7 +38,7 @@ namespace zen {
#if ZEN_PLATFORM_LINUX
-struct ProcessExitWatcher::Impl
+struct ProcessExitWatcher::Impl : public TRefCounted<Impl>
{
asio::io_context& m_IoContext;
std::unique_ptr<asio::posix::stream_descriptor> m_Descriptor;
@@ -64,8 +64,11 @@ struct ProcessExitWatcher::Impl
m_Descriptor = std::make_unique<asio::posix::stream_descriptor>(m_IoContext, m_PidFd);
+ // Capture a strong Ref so the Impl outlives the pending async_wait even
+ // if the owning ProcessExitWatcher is destroyed while a completion is
+ // in flight on the io_context.
m_Descriptor->async_wait(asio::posix::stream_descriptor::wait_read,
- [this, Callback = std::move(OnExit)](const asio::error_code& Ec) {
+ [Self = Ref<Impl>(this), Callback = std::move(OnExit)](const asio::error_code& Ec) {
if (Ec)
{
return; // Cancelled or error
@@ -74,7 +77,7 @@ struct ProcessExitWatcher::Impl
int ExitCode = -1;
int Status = 0;
// The pidfd told us the process exited. Reap it with waitpid.
- if (waitpid(m_Pid, &Status, WNOHANG) > 0)
+ if (waitpid(Self->m_Pid, &Status, WNOHANG) > 0)
{
if (WIFEXITED(Status))
{
@@ -115,7 +118,7 @@ struct ProcessExitWatcher::Impl
#elif ZEN_PLATFORM_WINDOWS
-struct ProcessExitWatcher::Impl
+struct ProcessExitWatcher::Impl : public TRefCounted<Impl>
{
asio::io_context& m_IoContext;
std::unique_ptr<asio::windows::object_handle> m_ObjectHandle;
@@ -147,16 +150,21 @@ struct ProcessExitWatcher::Impl
// object_handle takes ownership of the handle
m_ObjectHandle = std::make_unique<asio::windows::object_handle>(m_IoContext, m_DuplicatedHandle);
- m_ObjectHandle->async_wait([this, DupHandle = m_DuplicatedHandle, Callback = std::move(OnExit)](const asio::error_code& Ec) {
- if (Ec)
- {
- return;
- }
-
- DWORD ExitCode = 0;
- GetExitCodeProcess(static_cast<HANDLE>(DupHandle), &ExitCode);
- Callback(static_cast<int>(ExitCode));
- });
+ // Capture a strong Ref so the duplicated handle (owned by m_ObjectHandle,
+ // which is owned by *this) cannot be closed out from under a completion
+ // that the io_context is still about to dispatch.
+ m_ObjectHandle->async_wait(
+ [Self = Ref<Impl>(this), DupHandle = m_DuplicatedHandle, Callback = std::move(OnExit)](const asio::error_code& Ec) {
+ (void)Self;
+ if (Ec)
+ {
+ return;
+ }
+
+ DWORD ExitCode = 0;
+ GetExitCodeProcess(static_cast<HANDLE>(DupHandle), &ExitCode);
+ Callback(static_cast<int>(ExitCode));
+ });
}
void Cancel()
@@ -182,7 +190,7 @@ struct ProcessExitWatcher::Impl
#elif ZEN_PLATFORM_MAC
-struct ProcessExitWatcher::Impl
+struct ProcessExitWatcher::Impl : public TRefCounted<Impl>
{
asio::io_context& m_IoContext;
std::unique_ptr<asio::posix::stream_descriptor> m_Descriptor;
@@ -218,8 +226,11 @@ struct ProcessExitWatcher::Impl
m_Descriptor = std::make_unique<asio::posix::stream_descriptor>(m_IoContext, m_KqueueFd);
+ // Capture a strong Ref so the Impl outlives the pending async_wait even
+ // if the owning ProcessExitWatcher is destroyed while a completion is
+ // in flight on the io_context.
m_Descriptor->async_wait(asio::posix::stream_descriptor::wait_read,
- [this, Callback = std::move(OnExit)](const asio::error_code& Ec) {
+ [Self = Ref<Impl>(this), Callback = std::move(OnExit)](const asio::error_code& Ec) {
if (Ec)
{
return;
@@ -228,11 +239,11 @@ struct ProcessExitWatcher::Impl
// Drain the kqueue event
struct kevent Event;
struct timespec Timeout = {0, 0};
- kevent(m_KqueueFd, nullptr, 0, &Event, 1, &Timeout);
+ kevent(Self->m_KqueueFd, nullptr, 0, &Event, 1, &Timeout);
int ExitCode = -1;
int Status = 0;
- if (waitpid(m_Pid, &Status, WNOHANG) > 0)
+ if (waitpid(Self->m_Pid, &Status, WNOHANG) > 0)
{
if (WIFEXITED(Status))
{
@@ -273,11 +284,21 @@ struct ProcessExitWatcher::Impl
// Common wrapper (delegates to Impl)
// ============================================================================
-ProcessExitWatcher::ProcessExitWatcher(asio::io_context& IoContext) : m_Impl(std::make_unique<Impl>(IoContext))
+ProcessExitWatcher::ProcessExitWatcher(asio::io_context& IoContext) : m_Impl(new Impl(IoContext))
{
}
-ProcessExitWatcher::~ProcessExitWatcher() = default;
+ProcessExitWatcher::~ProcessExitWatcher()
+{
+ // Explicitly cancel pending async ops. The Impl may outlive this call if a
+ // completion is still in the io_context queue (the handler holds a strong
+ // Ref back to Impl to keep it alive). Cancel() here guarantees the watch
+ // stops even if nobody has called Cancel() on the wrapper.
+ if (m_Impl)
+ {
+ m_Impl->Cancel();
+ }
+}
void
ProcessExitWatcher::Watch(const ProcessHandle& Handle, std::function<void(int ExitCode)> OnExit)
diff --git a/src/zenutil/process/exitwatcher.h b/src/zenutil/process/exitwatcher.h
index 24906d7d0..c9b23368a 100644
--- a/src/zenutil/process/exitwatcher.h
+++ b/src/zenutil/process/exitwatcher.h
@@ -2,11 +2,11 @@
#pragma once
+#include <zenbase/refcount.h>
#include <zencore/process.h>
#include <zencore/zencore.h>
#include <functional>
-#include <memory>
namespace asio {
class io_context;
@@ -42,7 +42,7 @@ public:
private:
struct Impl;
- std::unique_ptr<Impl> m_Impl;
+ Ref<Impl> m_Impl;
};
} // namespace zen
diff --git a/src/zenutil/process/subprocessmanager.cpp b/src/zenutil/process/subprocessmanager.cpp
index d0b912a0d..acb518808 100644
--- a/src/zenutil/process/subprocessmanager.cpp
+++ b/src/zenutil/process/subprocessmanager.cpp
@@ -236,7 +236,7 @@ ManagedProcess::GetTag() const
// SubprocessManager::Impl
// ============================================================================
-struct SubprocessManager::Impl
+struct SubprocessManager::Impl : public std::enable_shared_from_this<Impl>
{
asio::io_context& m_IoContext;
SubprocessManagerConfig m_Config;
@@ -308,7 +308,10 @@ SubprocessManager::Impl::Impl(asio::io_context& IoContext, SubprocessManagerConf
if (m_Config.MetricsSampleIntervalMs > 0)
{
m_MetricsTimer = std::make_unique<asio::steady_timer>(IoContext);
- EnqueueMetricsTimer();
+ // Don't start the timer here: EnqueueMetricsTimer captures
+ // weak_from_this(), which requires the enclosing shared_ptr to
+ // already own this. The caller (SubprocessManager ctor) invokes
+ // EnqueueMetricsTimer() after the shared_ptr is established.
}
}
@@ -381,8 +384,17 @@ SubprocessManager::Impl::SetupExitWatcher(ManagedProcess* Proc, ProcessExitCallb
{
int Pid = Proc->Pid();
- Proc->m_Impl->m_ExitWatcher.Watch(Proc->m_Impl->m_Handle, [this, Pid, Callback = std::move(OnExit)](int ExitCode) {
- ManagedProcess* Found = FindProcess(Pid);
+ // Capture a weak_ptr so the handler safely no-ops if the manager is
+ // destroyed (or the process has been Remove()'d) before the exit
+ // completion is dispatched on the io_context.
+ Proc->m_Impl->m_ExitWatcher.Watch(Proc->m_Impl->m_Handle, [Self = weak_from_this(), Pid, Callback = std::move(OnExit)](int ExitCode) {
+ auto Locked = Self.lock();
+ if (!Locked)
+ {
+ return;
+ }
+
+ ManagedProcess* Found = Locked->FindProcess(Pid);
if (Found)
{
@@ -399,17 +411,22 @@ SubprocessManager::Impl::SetupStdoutReader(ManagedProcess* Proc, StdoutPipeHandl
Proc->m_Impl->m_StdoutReader = std::make_unique<AsyncPipeReader>(m_IoContext);
Proc->m_Impl->m_StdoutReader->Start(
std::move(Pipe),
- [this, Pid](std::string_view Data) {
- ManagedProcess* Found = FindProcess(Pid);
+ [Self = weak_from_this(), Pid](std::string_view Data) {
+ auto Locked = Self.lock();
+ if (!Locked)
+ {
+ return;
+ }
+ ManagedProcess* Found = Locked->FindProcess(Pid);
if (Found)
{
if (Found->m_Impl->m_StdoutCallback)
{
Found->m_Impl->m_StdoutCallback(*Found, Data);
}
- else if (m_DefaultStdoutCallback)
+ else if (Locked->m_DefaultStdoutCallback)
{
- m_DefaultStdoutCallback(*Found, Data);
+ Locked->m_DefaultStdoutCallback(*Found, Data);
}
else
{
@@ -427,17 +444,22 @@ SubprocessManager::Impl::SetupStderrReader(ManagedProcess* Proc, StdoutPipeHandl
Proc->m_Impl->m_StderrReader = std::make_unique<AsyncPipeReader>(m_IoContext);
Proc->m_Impl->m_StderrReader->Start(
std::move(Pipe),
- [this, Pid](std::string_view Data) {
- ManagedProcess* Found = FindProcess(Pid);
+ [Self = weak_from_this(), Pid](std::string_view Data) {
+ auto Locked = Self.lock();
+ if (!Locked)
+ {
+ return;
+ }
+ ManagedProcess* Found = Locked->FindProcess(Pid);
if (Found)
{
if (Found->m_Impl->m_StderrCallback)
{
Found->m_Impl->m_StderrCallback(*Found, Data);
}
- else if (m_DefaultStderrCallback)
+ else if (Locked->m_DefaultStderrCallback)
{
- m_DefaultStderrCallback(*Found, Data);
+ Locked->m_DefaultStderrCallback(*Found, Data);
}
else
{
@@ -557,14 +579,19 @@ SubprocessManager::Impl::EnqueueMetricsTimer()
}
m_MetricsTimer->expires_after(std::chrono::milliseconds(m_Config.MetricsSampleIntervalMs));
- m_MetricsTimer->async_wait([this](const asio::error_code& Ec) {
- if (Ec || !m_Running.load())
+ m_MetricsTimer->async_wait([Self = weak_from_this()](const asio::error_code& Ec) {
+ auto Locked = Self.lock();
+ if (!Locked)
+ {
+ return;
+ }
+ if (Ec || !Locked->m_Running.load())
{
return;
}
- SampleBatch();
- EnqueueMetricsTimer();
+ Locked->SampleBatch();
+ Locked->EnqueueMetricsTimer();
});
}
@@ -711,8 +738,11 @@ SubprocessManager::Impl::EnumerateGroups(std::function<void(const ProcessGroup&)
// ============================================================================
SubprocessManager::SubprocessManager(asio::io_context& IoContext, SubprocessManagerConfig Config)
-: m_Impl(std::make_unique<Impl>(IoContext, Config))
+: m_Impl(std::make_shared<Impl>(IoContext, Config))
{
+ // Start the metrics timer now that the shared_ptr owns the Impl - only
+ // then does weak_from_this() produce a valid weak_ptr for the handler.
+ m_Impl->EnqueueMetricsTimer();
}
SubprocessManager::~SubprocessManager() = default;