diff options
| author | Stefan Boberg <[email protected]> | 2026-05-04 16:46:03 +0200 |
|---|---|---|
| committer | GitHub Enterprise <[email protected]> | 2026-05-04 16:46:03 +0200 |
| commit | 10d2a61fe1c848f44033e8450ff3a5ffa7f4322a (patch) | |
| tree | aa66c6a068b50d2390bdae5f857c7151f15e5a86 /src | |
| parent | Tui picker fixes (#1027) (diff) | |
| download | archived-zen-10d2a61fe1c848f44033e8450ff3a5ffa7f4322a.tar.xz archived-zen-10d2a61fe1c848f44033e8450ff3a5ffa7f4322a.zip | |
zenhttp improvements (robustness / correctness) (#968)
A collection of security, correctness, and robustness fixes in `zenhttp` and `zencore` surfaced by security review. Most items are small, independent commits grouped here because they all tighten trust boundaries or fix UB along the same code paths.
## WebSocket protocol hardening (RFC 6455)
- **Enforce the client-side mask bit**. Server-side frame loops now reject unmasked frames with close code 1002 per §5.1. Prevents HTTP intermediary smuggling.
- **Validate control frames and RSV bits**. Fragmented control frames, oversized (>125 B) control payloads, and any non-zero RSV bit now fail the connection before allocation.
- **Lower per-frame payload cap** from 256 MB → 4 MB. Bounds per-connection accumulator memory.
- **Implement message fragmentation**. Continuation frames are coalesced and delivered as a single message; interleaved non-control frames close with 1002; assembled messages are capped at 4 MB (1009 on overflow). Previously partial fragments were delivered to handlers, bypassing payload validation.
- **Parse the 101 handshake response properly** in `HttpWsClient`. Status-line, `Upgrade`, `Connection`, and `Sec-WebSocket-Accept` are now matched exactly rather than via substring searches against the full body.
## Auth / OIDC hardening
- **Constant-time password compare** in `PasswordSecurity::IsAllowed` (closes a remote length/content timing oracle). Adds a shared `ConstantTimeEquals` helper.
- **Harden Basic-auth header parsing**: trim trailing LWS, reject control bytes and DEL in the credential.
- **OIDC discovery pinning**: require HTTPS (loopback exempt), verify `issuer` matches `BaseUrl`, require `token_endpoint` / `userinfo_endpoint` / `jwks_uri` to share origin with `BaseUrl`, reject empty `token_endpoint`.
- **Restrict `POST /auth/oidc/refreshtoken`** to local-machine requests. Previously unauthenticated in default deployments — remote callers could evict or replace cached tokens.
- **Stop logging OIDC provider response bodies** on refresh failure (IdPs echo `refresh_token` back in error bodies).
- **Drop the unused `IdentityToken` field** from `OidcClient` / `OpenIdToken` so nothing in the tree accidentally trusts an unverified JWT.
## Auth state encryption migration
- Add `AesGcm` AEAD primitive (BCrypt / OpenSSL backends, mbedTLS stubbed) and `CryptoRandom::Fill` CSPRNG helper in `zencore/crypto.h`.
- Migrate authstate file from AES-256-CBC with a fixed IV to AES-GCM with a fresh 12-byte random nonce per write and the 4-byte `ZEN1` magic bound as AAD. Legacy-CBC files are transparently read once and rewritten in the new format.
## Filesystem / IO robustness
- `IoBufferExtendedCore::Materialize` now checks `MAP_FAILED` on POSIX (was comparing to `nullptr`, which let the failure sentinel propagate into later reads and `munmap(MAP_FAILED, ...)`).
- `IoBufferBuilder::MakeFromFile / MakeFromTemporaryFile`: close the FD/HANDLE on exception via a dismissable `ScopeGuard`; actually check the `fstat()` return value (previously used an uninitialized `FileSize`).
- `ReadFromFileMaybe`: loop short reads, retry `EINTR`, chunk Windows `ReadFile` at `0xFFFFFFFF` bytes (fixes silent truncation of multi-GiB reads).
- `WipeDirectory`: compare `FindFirstFileW` handle against `INVALID_HANDLE_VALUE` rather than `nullptr`.
- `RemoveFileNative` (Linux/macOS): report non-`ENOENT` stat failures via the `std::error_code` out-param and stop reading `st_mode` after a failed stat.
## Buffer / compression correctness
- Avoid per-copy `IoBufferCore` heap allocations in `CompositeBuffer::CopyTo / ViewOrCopyRange` iterators; add fast path for `BufferHeader::Read` when the 64-byte header fits in the first plain-memory segment.
- `BufferHeader`: add `IsHeaderValid()` gate covering `BlockSizeExponent` range, `BlockCount * BlockSize` overflow, and `TotalRawSize` bounds before any arithmetic uses them. Defends against attacker-controlled headers that can pass the CRC and trigger OOB writes in `DecompressBlock`.
Diffstat (limited to 'src')
44 files changed, 2785 insertions, 312 deletions
diff --git a/src/zencore/basicfile.cpp b/src/zencore/basicfile.cpp index fdf742261..01d550957 100644 --- a/src/zencore/basicfile.cpp +++ b/src/zencore/basicfile.cpp @@ -798,11 +798,12 @@ BasicFileWriter::Write(const void* Data, uint64_t Size, uint64_t FileOffset) { if (m_Buffer == nullptr || (Size >= m_BufferSize)) { - if (FileOffset == m_BufferEnd) - { - Flush(); - m_BufferStart = m_BufferEnd = FileOffset + Size; - } + // Always flush pending buffered data first. Otherwise a later + // Flush() would replay stale bytes at m_BufferStart, clobbering + // any range of this direct write that overlaps + // [m_BufferStart, m_BufferEnd). + Flush(); + m_BufferStart = m_BufferEnd = FileOffset + Size; m_Base.Write(Data, Size, FileOffset); return; @@ -1200,6 +1201,60 @@ TEST_CASE("BasicFileBuffer") } } +TEST_CASE("BasicFileWriter.LargeDiscontinuousWriteFlushesBuffer") +{ + // Regression: BasicFileWriter::Write's large-write branch used to skip + // Flush() whenever FileOffset != m_BufferEnd. Any subsequent Flush + // (including the one in ~BasicFileWriter) then replayed the stale + // buffered bytes at their original offset, clobbering whatever the + // caller had just written directly. + ScopedCurrentDirectoryChange _; + + constexpr uint64_t BufferSize = 64; + constexpr uint64_t SmallSize = 10; + constexpr uint64_t LargeSize = 1024; // >= BufferSize, forces the direct-write path + constexpr uint64_t LargeOffset = 5; // overlaps the pending small-write region + constexpr uint64_t OverlapStart = LargeOffset; + constexpr uint64_t OverlapEnd = SmallSize; + + { + BasicFile File; + File.Open("discontig_write", BasicFile::Mode::kTruncate); + BasicFileWriter Writer(File, BufferSize); + + // First: small write buffered at [0, SmallSize) - still pending flush. + std::vector<uint8_t> Small(SmallSize, 'A'); + Writer.Write(Small.data(), Small.size(), 0); + + // Second: large write that overlaps the pending buffered region. + // Last-writer-wins means the overlap bytes must end up as 'B'. + std::vector<uint8_t> Large(LargeSize, 'B'); + Writer.Write(Large.data(), Large.size(), LargeOffset); + } + + BasicFile Reader; + Reader.Open("discontig_write", BasicFile::Mode::kRead); + IoBuffer Contents = Reader.ReadAll(); + REQUIRE_EQ(Contents.Size(), LargeOffset + LargeSize); + + const uint8_t* Bytes = reinterpret_cast<const uint8_t*>(Contents.Data()); + for (uint64_t I = 0; I < OverlapStart; ++I) + { + CHECK_EQ(Bytes[I], 'A'); + } + // The bytes in the overlap range [OverlapStart, OverlapEnd) are the + // critical check: with the bug, the late Flush() replayed the small + // 'A' write and clobbered these with 'A'. + for (uint64_t I = OverlapStart; I < OverlapEnd; ++I) + { + CHECK_EQ(Bytes[I], 'B'); + } + for (uint64_t I = OverlapEnd; I < LargeOffset + LargeSize; ++I) + { + CHECK_EQ(Bytes[I], 'B'); + } +} + TEST_SUITE_END(); void diff --git a/src/zencore/compositebuffer.cpp b/src/zencore/compositebuffer.cpp index ed2b16384..1dee8477f 100644 --- a/src/zencore/compositebuffer.cpp +++ b/src/zencore/compositebuffer.cpp @@ -179,12 +179,11 @@ CompositeBuffer::GetIterator(uint64_t Offset) const MemoryView CompositeBuffer::ViewOrCopyRange(Iterator& It, uint64_t Size, UniqueBuffer& CopyBuffer) const { - // We use a sub range IoBuffer when we want to copy data from a segment. - // This means we will only materialize that range of the segment when doing - // GetView() rather than the full segment. - // A hot path for this code is when we call CompressedBuffer::FromCompressed which - // is only interested in reading the header (first 64 bytes or so) and then throws - // away the materialized data. + // A hot path for this code is CompressedBuffer::FromCompressed, which only reads the header + // (first 64 bytes or so). For plain memory segments we take a direct view (no allocation); + // for extended (file-backed) segments we materialize only the requested slice via a + // sub-range IoBuffer whose lifetime must extend across the CopyFrom below — otherwise its + // view would dangle into freed memory the moment the IoBuffer goes out of scope. if (CopyBuffer.GetSize() < Size) { CopyBuffer = UniqueBuffer::Alloc(Size); @@ -198,9 +197,20 @@ CompositeBuffer::ViewOrCopyRange(Iterator& It, uint64_t Size, UniqueBuffer& Copy const SharedBuffer& Segment = m_Segments[It.SegmentIndex]; size_t SegmentSize = Segment.GetSize(); size_t CopySize = zen::Min(SegmentSize - It.OffsetInSegment, SizeLeft); - IoBuffer SubSegment(Segment.AsIoBuffer(), It.OffsetInSegment, CopySize); - MemoryView ReadView = SubSegment.GetView(); - WriteView = WriteView.CopyFrom(ReadView); + + IoBuffer SubSegment; // lifetime holder for the extended-segment view + MemoryView ReadView; + if (Segment.IsExtended()) + { + SubSegment = IoBuffer(Segment.AsIoBuffer(), It.OffsetInSegment, CopySize); + ReadView = SubSegment.GetView(); + } + else + { + ReadView = Segment.GetView().Mid(It.OffsetInSegment, CopySize); + } + WriteView = WriteView.CopyFrom(ReadView); + It.OffsetInSegment += CopySize; ZEN_ASSERT_SLOW(It.OffsetInSegment <= SegmentSize); if (It.OffsetInSegment == SegmentSize) @@ -216,12 +226,7 @@ CompositeBuffer::ViewOrCopyRange(Iterator& It, uint64_t Size, UniqueBuffer& Copy void CompositeBuffer::CopyTo(MutableMemoryView WriteView, Iterator& It) const { - // We use a sub range IoBuffer when we want to copy data from a segment. - // This means we will only materialize that range of the segment when doing - // GetView() rather than the full segment. - // A hot path for this code is when we call CompressedBuffer::FromCompressed which - // is only interested in reading the header (first 64 bytes or so) and then throws - // away the materialized data. + // See ViewOrCopyRange above for rationale on the extended vs. plain segment split. size_t SizeLeft = WriteView.GetSize(); size_t SegmentCount = m_Segments.size(); @@ -231,9 +236,20 @@ CompositeBuffer::CopyTo(MutableMemoryView WriteView, Iterator& It) const const SharedBuffer& Segment = m_Segments[It.SegmentIndex]; size_t SegmentSize = Segment.GetSize(); size_t CopySize = zen::Min(SegmentSize - It.OffsetInSegment, SizeLeft); - IoBuffer SubSegment(Segment.AsIoBuffer(), It.OffsetInSegment, CopySize); - MemoryView ReadView = SubSegment.GetView(); - WriteView = WriteView.CopyFrom(ReadView); + + IoBuffer SubSegment; // lifetime holder for the extended-segment view + MemoryView ReadView; + if (Segment.IsExtended()) + { + SubSegment = IoBuffer(Segment.AsIoBuffer(), It.OffsetInSegment, CopySize); + ReadView = SubSegment.GetView(); + } + else + { + ReadView = Segment.GetView().Mid(It.OffsetInSegment, CopySize); + } + WriteView = WriteView.CopyFrom(ReadView); + It.OffsetInSegment += CopySize; ZEN_ASSERT_SLOW(It.OffsetInSegment <= SegmentSize); if (It.OffsetInSegment == SegmentSize) diff --git a/src/zencore/compress.cpp b/src/zencore/compress.cpp index 6aa0adce0..a0e91f908 100644 --- a/src/zencore/compress.cpp +++ b/src/zencore/compress.cpp @@ -78,12 +78,20 @@ struct BufferHeader BufferHeader Header; if (sizeof(BufferHeader) <= CompressedData.GetSize()) { - // if (CompressedData.GetSegments()[0].AsIoBuffer().IsWholeFile()) - // { - // ZEN_ASSERT(true); - // } - CompositeBuffer::Iterator It; - CompressedData.CopyTo(MakeMutableMemoryView(&Header, &Header + 1), It); + // Fast path: the overwhelmingly common case is that the 64-byte header sits entirely + // within the first segment and that segment is plain memory. Skip the iterator and + // the sub-range IoBuffer wrapper (which would otherwise heap-allocate an IoBufferCore). + const std::span<const SharedBuffer> Segments = CompressedData.GetSegments(); + const SharedBuffer& First = Segments.front(); + if (sizeof(BufferHeader) <= First.GetSize() && !First.IsExtended()) + { + MakeMutableMemoryView(&Header, &Header + 1).CopyFrom(First.GetView().Left(sizeof(BufferHeader))); + } + else + { + CompositeBuffer::Iterator It; + CompressedData.CopyTo(MakeMutableMemoryView(&Header, &Header + 1), It); + } Header.ByteSwap(); } return Header; @@ -837,6 +845,15 @@ BlockDecoder::DecompressToStream( { return false; } + // RawOffset+RawSize-1 below underflows when RawSize is 0, and the + // BlockCount-0 / BlockSize-0 arithmetic is only defined when the header + // has already been validated (see IsHeaderValid). Guard both here as + // defence in depth. + if (RawSize == 0 || Header.BlockCount == 0 || Header.BlockSizeExponent >= 32 || RawOffset > Header.TotalRawSize || + RawSize > Header.TotalRawSize - RawOffset) + { + return false; + } const uint64_t BlockSize = uint64_t(1) << Header.BlockSizeExponent; @@ -1386,6 +1403,50 @@ GetDecoder(CompressionMethod Method) } } +// Sanity-check a header that was just read from an untrusted buffer before +// any of the decode arithmetic (1 << BlockSizeExponent, BlockCount*BlockSize, +// divides by BlockSize, etc.) is performed. Must be called after the magic, +// decoder and CRC checks pass. +static bool +IsHeaderValid(const BufferHeader& Header) +{ + // 1 << BlockSizeExponent is UB for Exponent >= 64 and wildly impractical + // below that. Real producers use <= 24 (16 MiB blocks); cap at 32 for + // headroom while staying well below the UB boundary. + if (Header.BlockSizeExponent >= 32) + { + return false; + } + + // Only the block-based methods use BlockCount / BlockSizeExponent. The + // None method keeps them zero. + if (Header.Method != CompressionMethod::None) + { + // A non-empty buffer needs at least one block. + if (Header.BlockCount == 0) + { + return Header.TotalRawSize == 0; + } + + const uint64_t BlockSize = uint64_t(1) << Header.BlockSizeExponent; + + // BlockCount * BlockSize must not overflow and must fit TotalRawSize + // in the half-open range ((BlockCount - 1) * BlockSize, BlockCount * BlockSize]. + if (Header.BlockCount > (UINT64_MAX / BlockSize)) + { + return false; + } + const uint64_t MaxRawSize = uint64_t(Header.BlockCount) * BlockSize; + const uint64_t MinRawSize = MaxRawSize - BlockSize; + if (Header.TotalRawSize > MaxRawSize || Header.TotalRawSize <= MinRawSize) + { + return false; + } + } + + return true; +} + ////////////////////////////////////////////////////////////////////////// bool @@ -1426,6 +1487,10 @@ ReadHeader(const CompositeBuffer& CompressedData, BufferHeader& OutHeader, Uniqu { return false; } + if (!IsHeaderValid(OutHeader)) + { + return false; + } uint64_t FullHeaderSize = Decoder->GetHeaderSize(OutHeader); if (FullHeaderSize > CompressedDataSize) { @@ -1520,7 +1585,7 @@ TryReadHeader(DecoderContext& Context, Archive& Ar, FHeader& OutHeader, MemoryVi FHeader* const HeaderCopy = static_cast<FHeader*>(HeaderView.GetData()); HeaderCopy->ByteSwap(); - if (Header.Crc32 == FHeader::CalculateCrc32(HeaderView)) + if (Header.Crc32 == FHeader::CalculateCrc32(HeaderView) && IsHeaderValid(Header)) { Context.HeaderOffset = uint64_t(Offset); Context.HeaderSize = HeaderSize; @@ -1560,7 +1625,7 @@ TryReadHeader(DecoderContext& Context, const CompositeBuffer& Buffer, FHeader& O const MemoryView HeaderView = Buffer.ViewOrCopyRange(0, HeaderSize, Context.Header, [](uint64_t Size) { return UniqueBuffer::Alloc(zen::Max(NextPow2(Size), DefaultHeaderSize)); }); - if (Header.Crc32 == FHeader::CalculateCrc32(HeaderView)) + if (Header.Crc32 == FHeader::CalculateCrc32(HeaderView) && IsHeaderValid(Header)) { Context.HeaderOffset = 0; Context.HeaderSize = HeaderSize; diff --git a/src/zencore/crypto.cpp b/src/zencore/crypto.cpp index 9984f35ac..8a172de3a 100644 --- a/src/zencore/crypto.cpp +++ b/src/zencore/crypto.cpp @@ -432,6 +432,285 @@ namespace crypto { return true; } + ////////////////////////////////////////////////////////////////////////// + // + // AES-256-GCM backends + // + +#if ZEN_USE_MBEDTLS + + MemoryView GcmTransform(TransformMode Mode, + const AesKey256Bit& Key, + MemoryView Nonce, + MemoryView Aad, + MemoryView In, + MutableMemoryView Out, + MutableMemoryView TagOut, // only used for Encrypt + MemoryView TagIn, // only used for Decrypt + std::optional<std::string>& Reason) + { + Reason = "AES-GCM is not implemented on the mbedTLS backend"sv; + (void)Mode; + (void)Key; + (void)Nonce; + (void)Aad; + (void)In; + (void)Out; + (void)TagOut; + (void)TagIn; + return MemoryView(); + } + +#elif ZEN_USE_OPENSSL + + MemoryView GcmTransform(TransformMode Mode, + const AesKey256Bit& Key, + MemoryView Nonce, + MemoryView Aad, + MemoryView In, + MutableMemoryView Out, + MutableMemoryView TagOut, + MemoryView TagIn, + std::optional<std::string>& Reason) + { + EvpContext Ctx; + + const EVP_CIPHER* Cipher = EVP_aes_256_gcm(); + ZEN_ASSERT(Cipher != nullptr); + + const bool Encrypting = (Mode == TransformMode::Encrypt); + + if (EVP_CipherInit_ex(Ctx, Cipher, nullptr, nullptr, nullptr, Encrypting ? 1 : 0) != 1) + { + Reason = "EVP_CipherInit_ex (algo) failed"sv; + return {}; + } + + // Explicitly set IV length to 12; the default is 12 for GCM but pinning + // it prevents any surprise from an OpenSSL build that defaults + // differently. + if (EVP_CIPHER_CTX_ctrl(Ctx, EVP_CTRL_AEAD_SET_IVLEN, (int)AesGcm::NonceSize, nullptr) != 1) + { + Reason = "EVP_CTRL_AEAD_SET_IVLEN failed"sv; + return {}; + } + + if (EVP_CipherInit_ex(Ctx, + nullptr, + nullptr, + reinterpret_cast<const unsigned char*>(Key.GetView().GetData()), + reinterpret_cast<const unsigned char*>(Nonce.GetData()), + Encrypting ? 1 : 0) != 1) + { + Reason = "EVP_CipherInit_ex (key+nonce) failed"sv; + return {}; + } + + // Feed AAD (if any) before the ciphertext/plaintext. + if (!Aad.IsEmpty()) + { + int AadOutLen = 0; + if (EVP_CipherUpdate(Ctx, + nullptr, + &AadOutLen, + reinterpret_cast<const unsigned char*>(Aad.GetData()), + static_cast<int>(Aad.GetSize())) != 1) + { + Reason = "EVP_CipherUpdate (AAD) failed"sv; + return {}; + } + } + + // For decrypt, set the expected tag before calling Final so the tag + // check can fail cleanly. + if (!Encrypting) + { + if (EVP_CIPHER_CTX_ctrl(Ctx, EVP_CTRL_AEAD_SET_TAG, (int)TagIn.GetSize(), (void*)TagIn.GetData()) != 1) + { + Reason = "EVP_CTRL_AEAD_SET_TAG failed"sv; + return {}; + } + } + + int BodyLen = 0; + if (EVP_CipherUpdate(Ctx, + reinterpret_cast<unsigned char*>(Out.GetData()), + &BodyLen, + reinterpret_cast<const unsigned char*>(In.GetData()), + static_cast<int>(In.GetSize())) != 1) + { + Reason = "EVP_CipherUpdate (body) failed"sv; + return {}; + } + + int FinalLen = 0; + if (EVP_CipherFinal_ex(Ctx, reinterpret_cast<unsigned char*>(Out.GetData()) + BodyLen, &FinalLen) != 1) + { + // For decrypt, this is the authentication-tag mismatch path. + Reason = Encrypting ? std::string("EVP_CipherFinal_ex (encrypt) failed") : std::string("AES-GCM authentication tag mismatch"); + return {}; + } + + if (Encrypting) + { + if (EVP_CIPHER_CTX_ctrl(Ctx, EVP_CTRL_AEAD_GET_TAG, (int)TagOut.GetSize(), TagOut.GetData()) != 1) + { + Reason = "EVP_CTRL_AEAD_GET_TAG failed"sv; + return {}; + } + } + + return Out.Left(static_cast<size_t>(BodyLen + FinalLen)); + } + +#else // ZEN_USE_BCRYPT + + MemoryView GcmTransform(TransformMode Mode, + const AesKey256Bit& Key, + MemoryView Nonce, + MemoryView Aad, + MemoryView In, + MutableMemoryView Out, + MutableMemoryView TagOut, + MemoryView TagIn, + std::optional<std::string>& Reason) + { + BCRYPT_ALG_HANDLE hAlg = nullptr; + NTSTATUS Status = STATUS_UNSUCCESSFUL; + + if (!NT_SUCCESS(Status = BCryptOpenAlgorithmProvider(&hAlg, BCRYPT_AES_ALGORITHM, nullptr, 0))) + { + Reason = fmt::format("BCryptOpenAlgorithmProvider failed, 0x{:08x}"sv, Status); + return {}; + } + auto CloseAlg = MakeGuard([hAlg] { BCryptCloseAlgorithmProvider(hAlg, 0); }); + + if (!NT_SUCCESS(Status = + BCryptSetProperty(hAlg, BCRYPT_CHAINING_MODE, (PUCHAR)BCRYPT_CHAIN_MODE_GCM, sizeof(BCRYPT_CHAIN_MODE_GCM), 0))) + { + Reason = fmt::format("BCryptSetProperty(CHAIN_MODE_GCM) failed, 0x{:08x}"sv, Status); + return {}; + } + + BCRYPT_KEY_HANDLE hKey = nullptr; + if (!NT_SUCCESS(Status = BCryptGenerateSymmetricKey(hAlg, + &hKey, + nullptr, + 0, + (PUCHAR)Key.GetView().GetData(), + (ULONG)Key.GetView().GetSize(), + 0))) + { + Reason = fmt::format("BCryptGenerateSymmetricKey failed, 0x{:08x}"sv, Status); + return {}; + } + auto CloseKey = MakeGuard([hKey] { BCryptDestroyKey(hKey); }); + + BCRYPT_AUTHENTICATED_CIPHER_MODE_INFO AuthInfo; + BCRYPT_INIT_AUTH_MODE_INFO(AuthInfo); + AuthInfo.pbNonce = (PUCHAR)Nonce.GetData(); + AuthInfo.cbNonce = (ULONG)Nonce.GetSize(); + AuthInfo.pbAuthData = Aad.IsEmpty() ? nullptr : (PUCHAR)Aad.GetData(); + AuthInfo.cbAuthData = (ULONG)Aad.GetSize(); + + ULONG ResultLen = 0; + + if (Mode == TransformMode::Encrypt) + { + AuthInfo.pbTag = (PUCHAR)TagOut.GetData(); + AuthInfo.cbTag = (ULONG)TagOut.GetSize(); + + Status = BCryptEncrypt(hKey, + (PUCHAR)In.GetData(), + (ULONG)In.GetSize(), + &AuthInfo, + nullptr, + 0, + (PUCHAR)Out.GetData(), + (ULONG)Out.GetSize(), + &ResultLen, + /*flags=*/0); + + if (!NT_SUCCESS(Status)) + { + Reason = fmt::format("BCryptEncrypt (GCM) failed, 0x{:08x}"sv, Status); + return {}; + } + } + else + { + AuthInfo.pbTag = (PUCHAR)TagIn.GetData(); + AuthInfo.cbTag = (ULONG)TagIn.GetSize(); + + Status = BCryptDecrypt(hKey, + (PUCHAR)In.GetData(), + (ULONG)In.GetSize(), + &AuthInfo, + nullptr, + 0, + (PUCHAR)Out.GetData(), + (ULONG)Out.GetSize(), + &ResultLen, + /*flags=*/0); + + if (!NT_SUCCESS(Status)) + { + // STATUS_AUTH_TAG_MISMATCH (0xC000A002) is the tag-failure path. + Reason = fmt::format("BCryptDecrypt (GCM) failed, 0x{:08x}"sv, Status); + return {}; + } + } + + return Out.Left(ResultLen); + } + +#endif // backend selection + + MemoryView Gcm(TransformMode Mode, + const AesKey256Bit& Key, + MemoryView Nonce, + MemoryView Aad, + MemoryView In, + MutableMemoryView Out, + MutableMemoryView TagOut, + MemoryView TagIn, + std::optional<std::string>& Reason) + { + if (!Key.IsValid()) + { + Reason = "invalid key"sv; + return {}; + } + if (Nonce.GetSize() != AesGcm::NonceSize) + { + Reason = fmt::format("AES-GCM nonce must be exactly {} bytes"sv, AesGcm::NonceSize); + return {}; + } + if (Out.GetSize() < In.GetSize()) + { + Reason = "AES-GCM output buffer is too small"sv; + return {}; + } + if (Mode == TransformMode::Encrypt) + { + if (TagOut.GetSize() != AesGcm::TagSize) + { + Reason = fmt::format("AES-GCM tag output must be exactly {} bytes"sv, AesGcm::TagSize); + return {}; + } + } + else + { + if (TagIn.GetSize() != AesGcm::TagSize) + { + Reason = fmt::format("AES-GCM tag input must be exactly {} bytes"sv, AesGcm::TagSize); + return {}; + } + } + + return GcmTransform(Mode, Key, Nonce, Aad, In, Out, TagOut, TagIn, Reason); + } + } // namespace crypto bool @@ -741,6 +1020,78 @@ Aes::Decrypt(const AesKey256Bit& Key, const AesIV128Bit& IV, MemoryView In, Muta return crypto::Transform(crypto::TransformMode::Decrypt, Key.GetView(), IV.GetView(), In, Out, Reason); } +MemoryView +AesGcm::Encrypt(const AesKey256Bit& Key, + MemoryView Nonce, + MemoryView Aad, + MemoryView Plaintext, + MutableMemoryView Out, + MutableMemoryView OutTag, + std::optional<std::string>& Reason) +{ + return crypto::Gcm(crypto::TransformMode::Encrypt, Key, Nonce, Aad, Plaintext, Out, OutTag, /*TagIn*/ {}, Reason); +} + +MemoryView +AesGcm::Decrypt(const AesKey256Bit& Key, + MemoryView Nonce, + MemoryView Aad, + MemoryView Ciphertext, + MemoryView Tag, + MutableMemoryView Out, + std::optional<std::string>& Reason) +{ + return crypto::Gcm(crypto::TransformMode::Decrypt, Key, Nonce, Aad, Ciphertext, Out, /*TagOut*/ {}, Tag, Reason); +} + +////////////////////////////////////////////////////////////////////////// +// +// CryptoRandom +// + +bool +CryptoRandom::Fill(MutableMemoryView Buffer, std::optional<std::string>* Reason) +{ + if (Buffer.GetSize() == 0) + { + return true; + } + + auto SetReason = [&](std::string Msg) { + if (Reason) + { + *Reason = std::move(Msg); + } + }; + +#if ZEN_USE_BCRYPT + // BCRYPT_USE_SYSTEM_PREFERRED_RNG draws from the OS CSPRNG without + // requiring the caller to manage an algorithm handle. + const NTSTATUS Status = BCryptGenRandom(nullptr, (PUCHAR)Buffer.GetData(), (ULONG)Buffer.GetSize(), BCRYPT_USE_SYSTEM_PREFERRED_RNG); + if (!NT_SUCCESS(Status)) + { + SetReason(fmt::format("BCryptGenRandom failed, 0x{:08x}", static_cast<uint32_t>(Status))); + return false; + } + return true; +#elif ZEN_USE_OPENSSL + // RAND_bytes returns 1 on success, 0 on failure, -1 if not supported. + const int Rc = RAND_bytes(reinterpret_cast<unsigned char*>(Buffer.GetData()), static_cast<int>(Buffer.GetSize())); + if (Rc != 1) + { + SetReason(fmt::format("RAND_bytes failed (rc={})", Rc)); + return false; + } + return true; +#else + // mbedTLS: no CSPRNG wired up here yet. Callers on this backend must + // provide their own random source until a proper wiring is added. + SetReason("CryptoRandom::Fill is not implemented on the mbedTLS backend"); + (void)Buffer; + return false; +#endif +} + #if ZEN_WITH_TESTS void @@ -801,6 +1152,205 @@ TEST_CASE("crypto.aes") } } +TEST_CASE("crypto.aesgcm") +{ + const AesKey256Bit Key = AesKey256Bit::FromString("abcdefghijklmnopqrstuvxyz0123456"sv); + const uint8_t NonceBytes[AesGcm::NonceSize] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}; + const MemoryView Nonce = MakeMemoryView(NonceBytes); + + SUBCASE("round trip without AAD") + { + std::string_view Plain = "The quick brown fox jumps over the lazy dog"sv; + + std::vector<uint8_t> Cipher(Plain.size()); + std::vector<uint8_t> Tag(AesGcm::TagSize); + std::optional<std::string> Reason; + + MemoryView CipherView = AesGcm::Encrypt(Key, + Nonce, + /*Aad*/ {}, + MakeMemoryView(Plain), + MakeMutableMemoryView(Cipher), + MakeMutableMemoryView(Tag), + Reason); + REQUIRE(!Reason.has_value()); + CHECK_EQ(CipherView.GetSize(), Plain.size()); + + std::vector<uint8_t> Decoded(Plain.size()); + MemoryView DecodedView = + AesGcm::Decrypt(Key, Nonce, /*Aad*/ {}, CipherView, MakeMemoryView(Tag), MakeMutableMemoryView(Decoded), Reason); + REQUIRE(!Reason.has_value()); + CHECK_EQ(DecodedView.GetSize(), Plain.size()); + + std::string_view DecodedText(reinterpret_cast<const char*>(DecodedView.GetData()), DecodedView.GetSize()); + CHECK_EQ(DecodedText, Plain); + } + + SUBCASE("round trip with AAD") + { + std::string_view Plain = "payload"sv; + std::string_view Aad = "header bits that are authenticated but not encrypted"sv; + + std::vector<uint8_t> Cipher(Plain.size()); + std::vector<uint8_t> Tag(AesGcm::TagSize); + std::optional<std::string> Reason; + + MemoryView CipherView = AesGcm::Encrypt(Key, + Nonce, + MakeMemoryView(Aad), + MakeMemoryView(Plain), + MakeMutableMemoryView(Cipher), + MakeMutableMemoryView(Tag), + Reason); + REQUIRE(!Reason.has_value()); + + std::vector<uint8_t> Decoded(Plain.size()); + MemoryView DecodedView = + AesGcm::Decrypt(Key, Nonce, MakeMemoryView(Aad), CipherView, MakeMemoryView(Tag), MakeMutableMemoryView(Decoded), Reason); + REQUIRE(!Reason.has_value()); + CHECK_EQ(DecodedView.GetSize(), Plain.size()); + } + + SUBCASE("tampered ciphertext fails authentication") + { + std::string_view Plain = "important"sv; + + std::vector<uint8_t> Cipher(Plain.size()); + std::vector<uint8_t> Tag(AesGcm::TagSize); + std::optional<std::string> Reason; + + MemoryView CipherView = AesGcm::Encrypt(Key, + Nonce, + /*Aad*/ {}, + MakeMemoryView(Plain), + MakeMutableMemoryView(Cipher), + MakeMutableMemoryView(Tag), + Reason); + REQUIRE(!Reason.has_value()); + + // Flip a bit in the ciphertext. + Cipher[0] ^= 0x01; + + std::vector<uint8_t> Decoded(Plain.size()); + MemoryView DecodedView = + AesGcm::Decrypt(Key, Nonce, /*Aad*/ {}, CipherView, MakeMemoryView(Tag), MakeMutableMemoryView(Decoded), Reason); + CHECK(Reason.has_value()); + CHECK(DecodedView.IsEmpty()); + } + + SUBCASE("tampered tag fails authentication") + { + std::string_view Plain = "important"sv; + + std::vector<uint8_t> Cipher(Plain.size()); + std::vector<uint8_t> Tag(AesGcm::TagSize); + std::optional<std::string> Reason; + + MemoryView CipherView = AesGcm::Encrypt(Key, + Nonce, + /*Aad*/ {}, + MakeMemoryView(Plain), + MakeMutableMemoryView(Cipher), + MakeMutableMemoryView(Tag), + Reason); + REQUIRE(!Reason.has_value()); + + Tag[0] ^= 0x80; + + std::vector<uint8_t> Decoded(Plain.size()); + MemoryView DecodedView = + AesGcm::Decrypt(Key, Nonce, /*Aad*/ {}, CipherView, MakeMemoryView(Tag), MakeMutableMemoryView(Decoded), Reason); + CHECK(Reason.has_value()); + CHECK(DecodedView.IsEmpty()); + } + + SUBCASE("AAD mismatch fails authentication") + { + std::string_view Plain = "payload"sv; + std::string_view AadOk = "expected header"sv; + std::string_view AadNo = "different header"sv; + + std::vector<uint8_t> Cipher(Plain.size()); + std::vector<uint8_t> Tag(AesGcm::TagSize); + std::optional<std::string> Reason; + + MemoryView CipherView = AesGcm::Encrypt(Key, + Nonce, + MakeMemoryView(AadOk), + MakeMemoryView(Plain), + MakeMutableMemoryView(Cipher), + MakeMutableMemoryView(Tag), + Reason); + REQUIRE(!Reason.has_value()); + + std::vector<uint8_t> Decoded(Plain.size()); + MemoryView DecodedView = + AesGcm::Decrypt(Key, Nonce, MakeMemoryView(AadNo), CipherView, MakeMemoryView(Tag), MakeMutableMemoryView(Decoded), Reason); + CHECK(Reason.has_value()); + CHECK(DecodedView.IsEmpty()); + } + + SUBCASE("wrong nonce size is rejected") + { + std::string_view Plain = "x"sv; + + const uint8_t TooShort[8] = {0}; + std::vector<uint8_t> Cipher(Plain.size()); + std::vector<uint8_t> Tag(AesGcm::TagSize); + std::optional<std::string> Reason; + + MemoryView CipherView = AesGcm::Encrypt(Key, + MakeMemoryView(TooShort), + /*Aad*/ {}, + MakeMemoryView(Plain), + MakeMutableMemoryView(Cipher), + MakeMutableMemoryView(Tag), + Reason); + CHECK(Reason.has_value()); + CHECK(CipherView.IsEmpty()); + } +} + +TEST_CASE("crypto.random") +{ + SUBCASE("fills buffer with non-zero bytes") + { + uint8_t Buffer[32] = {}; + MutableMemoryView View = MakeMutableMemoryView(Buffer); + const bool Ok = CryptoRandom::Fill(View); + REQUIRE(Ok); + + // Probability of 32 all-zero bytes from a CSPRNG is 2^-256 — we + // accept it as "effectively never". + bool AnyNonZero = false; + for (uint8_t B : Buffer) + { + if (B != 0) + { + AnyNonZero = true; + break; + } + } + CHECK(AnyNonZero); + } + + SUBCASE("two calls produce different output") + { + uint8_t A[32] = {}; + uint8_t B[32] = {}; + CHECK(CryptoRandom::Fill(MakeMutableMemoryView(A))); + CHECK(CryptoRandom::Fill(MakeMutableMemoryView(B))); + CHECK(memcmp(A, B, 32) != 0); + } + + SUBCASE("zero-size buffer is a no-op success") + { + uint8_t Dummy = 0xAB; + CHECK(CryptoRandom::Fill(MutableMemoryView(&Dummy, size_t{0}))); + CHECK_EQ(Dummy, 0xAB); + } +} + TEST_CASE("crypto.securerandom") { std::array<uint8_t, 64> A{}; diff --git a/src/zencore/filesystem.cpp b/src/zencore/filesystem.cpp index e8ceac5c0..1e18ef8be 100644 --- a/src/zencore/filesystem.cpp +++ b/src/zencore/filesystem.cpp @@ -181,7 +181,7 @@ WipeDirectory(const wchar_t* DirPath, bool KeepDotFiles) bool Success = true; - if (hFind != nullptr) + if (hFind != INVALID_HANDLE_VALUE) { do { @@ -436,15 +436,18 @@ RemoveFileNative(const std::filesystem::path& Path, bool ForceRemoveReadOnlyFile if (!ForceRemoveReadOnlyFiles) { struct stat Stat; - int err = stat(Path.native().c_str(), &Stat); - if (err != 0) + if (stat(Path.native().c_str(), &Stat) != 0) { - int32_t err = errno; - if (err == ENOENT) + const int StatErrno = errno; + if (StatErrno == ENOENT) { Ec.clear(); - return false; } + else + { + Ec = MakeErrorCode(StatErrno); + } + return false; } const uint32_t Mode = (uint32_t)Stat.st_mode; if (IsFileModeReadOnly(Mode)) @@ -3437,6 +3440,37 @@ MakeSafeAbsolutePath(const std::filesystem::path& Path) return Tmp; } +std::optional<std::filesystem::path> +ResolveSafeRelativePath(const std::filesystem::path& TrustedRoot, std::string_view RelativePath) +{ + if (RelativePath.empty()) + { + return std::nullopt; + } + + std::filesystem::path Requested(RelativePath); + if (Requested.is_absolute() || Requested.has_root_name() || Requested.has_root_directory()) + { + return std::nullopt; + } + for (const std::filesystem::path& Component : Requested) + { + if (Component == "..") + { + return std::nullopt; + } + } + + const std::filesystem::path NormalizedRoot = TrustedRoot.lexically_normal(); + const std::filesystem::path Joined = (NormalizedRoot / Requested).lexically_normal(); + if (std::mismatch(NormalizedRoot.begin(), NormalizedRoot.end(), Joined.begin(), Joined.end()).first != NormalizedRoot.end()) + { + return std::nullopt; + } + + return Joined; +} + class SharedMemoryImpl : public SharedMemory { public: @@ -4238,6 +4272,41 @@ TEST_CASE("filesystem.MakeSafeAbsolutePath") # endif // ZEN_PLATFORM_WINDOWS } +TEST_CASE("filesystem.ResolveSafeRelativePath") +{ + const std::filesystem::path Root = std::filesystem::path("root") / "traces"; + + // Empty input is rejected. + CHECK_FALSE(ResolveSafeRelativePath(Root, "").has_value()); + + // A plain relative path resolves under the root. + { + auto Resolved = ResolveSafeRelativePath(Root, "session.utrace"); + REQUIRE(Resolved.has_value()); + CHECK_EQ(*Resolved, (Root.lexically_normal() / "session.utrace")); + } + + // Nested relative segments are allowed as long as they stay inside the root. + { + auto Resolved = ResolveSafeRelativePath(Root, "2026-04/session.utrace"); + REQUIRE(Resolved.has_value()); + CHECK_EQ(*Resolved, (Root.lexically_normal() / "2026-04" / "session.utrace")); + } + + // ".." components are rejected before normalisation can collapse them. + CHECK_FALSE(ResolveSafeRelativePath(Root, "..").has_value()); + CHECK_FALSE(ResolveSafeRelativePath(Root, "../etc/passwd").has_value()); + CHECK_FALSE(ResolveSafeRelativePath(Root, "foo/../../bar").has_value()); + + // Absolute paths are rejected on both platforms. + CHECK_FALSE(ResolveSafeRelativePath(Root, "/etc/passwd").has_value()); +# if ZEN_PLATFORM_WINDOWS + CHECK_FALSE(ResolveSafeRelativePath(Root, "C:/Windows/win.ini").has_value()); + CHECK_FALSE(ResolveSafeRelativePath(Root, "C:\\Windows\\win.ini").has_value()); + CHECK_FALSE(ResolveSafeRelativePath(Root, "\\\\server\\share\\evil").has_value()); +# endif +} + TEST_CASE("ExpandEnvironmentVariables") { // No variables - pass-through diff --git a/src/zencore/include/zencore/crypto.h b/src/zencore/include/zencore/crypto.h index a5e23135f..36bf80527 100644 --- a/src/zencore/include/zencore/crypto.h +++ b/src/zencore/include/zencore/crypto.h @@ -73,6 +73,81 @@ public: std::optional<std::string>& Reason); }; +/** + * AES-256-GCM authenticated encryption. + * + * GCM is an AEAD construction: the authentication tag proves both the + * ciphertext and any additional associated data (AAD) were not modified. + * Use this instead of plain Aes (CBC) whenever a persistent artefact + * could be tampered with by a local attacker — CBC has no integrity + * check and will happily decrypt bit-flipped ciphertext into corrupted + * plaintext. + * + * Nonce must be exactly 12 bytes (NIST SP 800-38D recommended size). + * Tag is always exactly 16 bytes. + * + * CRITICAL: never reuse a (key, nonce) pair. GCM is catastrophically + * broken under nonce reuse — both confidentiality and authenticity fail. + * For every encryption, the caller must supply a fresh nonce drawn from + * a CSPRNG (a 96-bit random nonce has negligible collision probability + * for any reasonable message volume) or from a strictly monotonic + * counter under a single key. + */ +class AesGcm +{ +public: + static constexpr size_t NonceSize = 12; // bytes + static constexpr size_t TagSize = 16; // bytes + + /** + * Encrypt Plaintext. Writes ciphertext of the same length as Plaintext + * into Out, and the 16-byte authentication tag into OutTag. Aad may be + * empty. Returns the ciphertext view on success, or an empty view with + * Reason set on failure. + */ + static MemoryView Encrypt(const AesKey256Bit& Key, + MemoryView Nonce, + MemoryView Aad, + MemoryView Plaintext, + MutableMemoryView Out, + MutableMemoryView OutTag, + std::optional<std::string>& Reason); + + /** + * Decrypt Ciphertext, verifying against the supplied 16-byte Tag. + * Returns the plaintext view on success. On authentication failure or + * any other error, returns an empty view with Reason set; the Out + * buffer contents are undefined in that case. + */ + static MemoryView Decrypt(const AesKey256Bit& Key, + MemoryView Nonce, + MemoryView Aad, + MemoryView Ciphertext, + MemoryView Tag, + MutableMemoryView Out, + std::optional<std::string>& Reason); +}; + +/** + * Cryptographically secure random-byte source. + * + * Uses the OS CSPRNG where available (BCryptGenRandom on Windows, + * RAND_bytes on OpenSSL). Suitable for generating keys, nonces, and + * any other value whose predictability matters — do NOT use + * std::mt19937 or std::random_device alone for these purposes. + */ +class CryptoRandom +{ +public: + /** + * Fill Buffer with cryptographically random bytes. Returns true on + * success; on failure returns false and writes a human-readable cause + * to Reason if provided. On failure the buffer contents are + * unspecified. + */ + static bool Fill(MutableMemoryView Buffer, std::optional<std::string>* Reason = nullptr); +}; + // Fill Out with cryptographically secure random bytes from the platform RNG. // Returns false if the platform RNG failed; on success Out is filled entirely. bool SecureRandomBytes(MutableMemoryView Out); diff --git a/src/zencore/include/zencore/filesystem.h b/src/zencore/include/zencore/filesystem.h index 6218c3421..8f6209f79 100644 --- a/src/zencore/include/zencore/filesystem.h +++ b/src/zencore/include/zencore/filesystem.h @@ -451,6 +451,19 @@ bool SetFileReadOnly(const std::filesystem::path& Filename, bool ReadOnly); void MakeSafeAbsolutePathInPlace(std::filesystem::path& Path); [[nodiscard]] std::filesystem::path MakeSafeAbsolutePath(const std::filesystem::path& Path); +/** Resolve an attacker-supplied relative path underneath a trusted root directory. + * + * Rejects absolute paths, drive-qualified paths, root-directory paths, and any + * component equal to "..". Lexically-normalises the joined path and verifies the + * result still lives under the normalised root. Intended for validating HTTP + * query-parameter paths before passing them to filesystem APIs. + * + * Returns nullopt if RelativePath is empty, contains rejected syntax, or escapes + * TrustedRoot after normalisation. + */ +[[nodiscard]] std::optional<std::filesystem::path> ResolveSafeRelativePath(const std::filesystem::path& TrustedRoot, + std::string_view RelativePath); + class SharedMemory { public: diff --git a/src/zencore/include/zencore/sharedbuffer.h b/src/zencore/include/zencore/sharedbuffer.h index 3183c7c0c..0fed11458 100644 --- a/src/zencore/include/zencore/sharedbuffer.h +++ b/src/zencore/include/zencore/sharedbuffer.h @@ -145,6 +145,12 @@ public: [[nodiscard]] inline IoBuffer AsIoBuffer() const { return IoBuffer(m_Buffer.Get()); } + /** Returns true if the segment is backed by an extended core (e.g. a file range) whose data + * is materialized on demand. For extended segments, sub-ranges should be taken via + * IoBuffer(outer, offset, size) to avoid materializing the whole segment. Plain memory + * segments (IsExtended() == false) can be sub-viewed directly with GetView().Mid(...). */ + [[nodiscard]] inline bool IsExtended() const { return m_Buffer && m_Buffer->ExtendedCore() != nullptr; } + SharedBuffer& operator=(UniqueBuffer&& Rhs) { m_Buffer = std::move(Rhs.m_Buffer); diff --git a/src/zencore/include/zencore/string.h b/src/zencore/include/zencore/string.h index fded960f3..53bfd196d 100644 --- a/src/zencore/include/zencore/string.h +++ b/src/zencore/include/zencore/string.h @@ -1058,6 +1058,34 @@ StrCaseEndsWith(std::string_view Str, std::string_view Suffix) } /** + * Constant-time string equality. Unlike std::string_view::operator==, which + * short-circuits on the first byte mismatch (and on size mismatch before + * inspecting any bytes), this helper always iterates over max(Lhs.size(), + * Rhs.size()) bytes and ORs every byte difference into an accumulator. The + * length-compare is folded into the final check so that neither length + * mismatches nor content differences produce a timing signal an attacker + * can observe remotely. + * + * Use for comparisons of secret-bearing data (passwords, bearer tokens, + * HMAC outputs) against an attacker-supplied value. + */ +inline bool +ConstantTimeEquals(std::string_view Lhs, std::string_view Rhs) +{ + const size_t N = std::max(Lhs.size(), Rhs.size()); + uint32_t Accum = static_cast<uint32_t>(Lhs.size() ^ Rhs.size()); + const size_t LhsLen = Lhs.size(); + const size_t RhsLen = Rhs.size(); + for (size_t I = 0; I < N; ++I) + { + const uint8_t LByte = (I < LhsLen) ? static_cast<uint8_t>(Lhs[I]) : 0; + const uint8_t RByte = (I < RhsLen) ? static_cast<uint8_t>(Rhs[I]) : 0; + Accum |= uint32_t(LByte ^ RByte); + } + return Accum == 0; +} + +/** * @brief * Helper function to implement case sensitive spaceship operator for strings. * MacOS clang version we use does not implement <=> for std::string diff --git a/src/zencore/iobuffer.cpp b/src/zencore/iobuffer.cpp index 529afe341..d1aa0bb9e 100644 --- a/src/zencore/iobuffer.cpp +++ b/src/zencore/iobuffer.cpp @@ -10,6 +10,7 @@ #include <zencore/memory/llm.h> #include <zencore/memory/memory.h> #include <zencore/memoryview.h> +#include <zencore/scopeguard.h> #include <zencore/testing.h> #include <zencore/thread.h> #include <zencore/trace.h> @@ -375,7 +376,14 @@ IoBufferExtendedCore::Materialize() const /* offset */ MapOffset); #endif // ZEN_PLATFORM_WINDOWS - if (MappedBase == nullptr) +#if ZEN_PLATFORM_WINDOWS + const bool MapFailed = (MappedBase == nullptr); +#else + // mmap returns MAP_FAILED (not nullptr) on failure + const bool MapFailed = (MappedBase == MAP_FAILED); +#endif + + if (MapFailed) { int32_t Error = zen::GetLastError(); #if ZEN_PLATFORM_WINDOWS @@ -540,32 +548,54 @@ IoBufferBuilder::ReadFromFileMaybe(const IoBuffer& InBuffer) const uint64_t NumberOfBytesToRead = FileRef.FileChunkSize; const uint64_t FileOffset = FileRef.FileChunkOffset; + uint8_t* const Dst = reinterpret_cast<uint8_t*>(OutBuffer.MutableData()); + #if ZEN_PLATFORM_WINDOWS - OVERLAPPED Ovl{}; + while (BytesRead < NumberOfBytesToRead) + { + const uint64_t Remaining = NumberOfBytesToRead - BytesRead; + const uint64_t ChunkStart = FileOffset + BytesRead; + const DWORD ChunkSize = DWORD(Min<uint64_t>(Remaining, 0xffff'ffffu)); - Ovl.Offset = DWORD(FileOffset & 0xffff'ffffu); - Ovl.OffsetHigh = DWORD(FileOffset >> 32); + OVERLAPPED Ovl{}; + Ovl.Offset = DWORD(ChunkStart & 0xffff'ffffu); + Ovl.OffsetHigh = DWORD(ChunkStart >> 32); - DWORD dwNumberOfBytesRead = 0; - BOOL Success = ::ReadFile(FileRef.FileHandle, OutBuffer.MutableData(), DWORD(NumberOfBytesToRead), &dwNumberOfBytesRead, &Ovl); - if (Success) - { - BytesRead = size_t(dwNumberOfBytesRead); - } - else - { - Error = zen::GetLastError(); + DWORD dwBytesRead = 0; + if (!::ReadFile(FileRef.FileHandle, Dst + BytesRead, ChunkSize, &dwBytesRead, &Ovl)) + { + Error = zen::GetLastError(); + break; + } + if (dwBytesRead == 0) + { + // Hit EOF before we got everything we asked for + break; + } + BytesRead += size_t(dwBytesRead); } #else int Fd = int(intptr_t(FileRef.FileHandle)); - ssize_t ReadResult = pread(Fd, OutBuffer.MutableData(), size_t(NumberOfBytesToRead), off_t(FileOffset)); - if (ReadResult != -1) + while (BytesRead < NumberOfBytesToRead) { - BytesRead = size_t(ReadResult); - } - else - { - Error = zen::GetLastError(); + const size_t Remaining = size_t(NumberOfBytesToRead - BytesRead); + const off_t ChunkStart = off_t(FileOffset + BytesRead); + const ssize_t ReadResult = pread(Fd, Dst + BytesRead, Remaining, ChunkStart); + if (ReadResult < 0) + { + if (errno == EINTR) + { + continue; + } + Error = zen::GetLastError(); + break; + } + if (ReadResult == 0) + { + // Hit EOF before we got everything we asked for + break; + } + BytesRead += size_t(ReadResult); } #endif @@ -632,10 +662,19 @@ IoBufferBuilder::MakeFromFile(const std::filesystem::path& FileName, uint64_t Of { return {}; } + auto FdGuard = MakeGuard([&Fd] { + if (Fd >= 0) + { + close(Fd); + } + }); static_assert(sizeof(decltype(stat::st_size)) == sizeof(uint64_t), "fstat() doesn't support large files"); struct stat Stat; - fstat(Fd, &Stat); + if (fstat(Fd, &Stat) != 0) + { + return {}; + } FileSize = Stat.st_size; #endif // ZEN_PLATFORM_WINDOWS @@ -655,18 +694,20 @@ IoBufferBuilder::MakeFromFile(const std::filesystem::path& FileName, uint64_t Of if (Size) { #if ZEN_PLATFORM_WINDOWS - void* Fd = DataFile.Detach(); -#endif + void* Fd = DataFile.Detach(); + auto HandleGuard = MakeGuard([Fd] { CloseHandle((HANDLE)Fd); }); + IoBuffer NewBuffer(IoBuffer::File, Fd, Offset, Size, Offset == 0 && Size == FileSize); + HandleGuard.Dismiss(); +#else IoBuffer NewBuffer(IoBuffer::File, (void*)uintptr_t(Fd), Offset, Size, Offset == 0 && Size == FileSize); + FdGuard.Dismiss(); +#endif NewBuffer.SetContentType(ContentType); return NewBuffer; } -#if !ZEN_PLATFORM_WINDOWS - close(Fd); -#endif - // For an empty file, we may as well just return an empty memory IoBuffer + // (FdGuard on non-Windows closes Fd automatically; DataFile destructor closes the HANDLE on Windows) return IoBuffer(IoBuffer::Wrap, "", 0); } @@ -694,23 +735,33 @@ IoBufferBuilder::MakeFromTemporaryFile(const std::filesystem::path& FileName, Ze DataFile.GetSize((ULONGLONG&)FileSize); - Handle = DataFile.Detach(); + Handle = DataFile.Detach(); + auto HandleGuard = MakeGuard([Handle] { CloseHandle((HANDLE)Handle); }); #else int Fd = open(FileName.native().c_str(), O_RDONLY); if (Fd < 0) { return {}; } + auto FdGuard = MakeGuard([Fd] { close(Fd); }); static_assert(sizeof(decltype(stat::st_size)) == sizeof(uint64_t), "fstat() doesn't support large files"); struct stat Stat; - fstat(Fd, &Stat); + if (fstat(Fd, &Stat) != 0) + { + return {}; + } FileSize = Stat.st_size; Handle = (void*)uintptr_t(Fd); #endif // ZEN_PLATFORM_WINDOWS IoBuffer NewBuffer(IoBuffer::File, Handle, 0, FileSize, /*IsWholeFile*/ true); +#if ZEN_PLATFORM_WINDOWS + HandleGuard.Dismiss(); +#else + FdGuard.Dismiss(); +#endif NewBuffer.SetContentType(ContentType); return NewBuffer; } diff --git a/src/zenhttp/auth/authmgr.cpp b/src/zenhttp/auth/authmgr.cpp index 2fa22f2c2..9d10b1ba6 100644 --- a/src/zenhttp/auth/authmgr.cpp +++ b/src/zenhttp/auth/authmgr.cpp @@ -12,6 +12,11 @@ #include <zencore/trace.h> #include <zenhttp/auth/oidc.h> +#if ZEN_WITH_TESTS +# include <zencore/testing.h> +# include <zencore/testutils.h> +#endif + #include <condition_variable> #include <memory> #include <shared_mutex> @@ -25,13 +30,37 @@ namespace zen { using namespace std::literals; namespace details { + // On-disk format for the GCM-encrypted authstate: + // + // [ 4 bytes: magic "ZEN1" ] + // [12 bytes: GCM nonce ] — fresh random value per write + // [ N bytes: ciphertext ] — same length as the plaintext + // [16 bytes: GCM tag ] + // + // The magic is bound into the authentication by passing it as AAD to + // AesGcm; that ties the tag to the exact header bytes, so an attacker + // cannot strip the format indicator or swap to a different format without + // failing authentication. + // + // The magic also doubles as the format version. If the framing ever + // needs to change (e.g. adding a key-identifier for rotation), bump to + // "ZEN2" and teach the reader to accept both. + constexpr std::string_view kAuthStateMagic = "ZEN1"; + constexpr size_t kAuthStateHeaderSz = 4 + AesGcm::NonceSize; + IoBuffer ReadEncryptedFile(std::filesystem::path Path, const AesKey256Bit& Key, - const AesIV128Bit& IV, - std::optional<std::string>& Reason) + const AesIV128Bit& LegacyIV, + std::optional<std::string>& Reason, + bool* OutWasLegacy = nullptr) { ZEN_TRACE_CPU("AuthMgr::ReadEncryptedFile"); + if (OutWasLegacy) + { + *OutWasLegacy = false; + } + FileContents Result = ReadFile(Path); if (Result.ErrorCode) @@ -46,24 +75,52 @@ namespace details { return IoBuffer(); } + // Current format: magic-prefixed AES-256-GCM + if (EncryptedBuffer.GetSize() >= kAuthStateHeaderSz + AesGcm::TagSize && + memcmp(EncryptedBuffer.GetData(), kAuthStateMagic.data(), kAuthStateMagic.size()) == 0) + { + const uint8_t* Bytes = static_cast<const uint8_t*>(EncryptedBuffer.GetData()); + const size_t Total = EncryptedBuffer.GetSize(); + const size_t CipherSize = Total - kAuthStateHeaderSz - AesGcm::TagSize; + + MemoryView Magic(Bytes, kAuthStateMagic.size()); + MemoryView Nonce(Bytes + kAuthStateMagic.size(), AesGcm::NonceSize); + MemoryView Cipher(Bytes + kAuthStateHeaderSz, CipherSize); + MemoryView Tag(Bytes + kAuthStateHeaderSz + CipherSize, AesGcm::TagSize); + + std::vector<uint8_t> PlainBuffer(CipherSize); + MemoryView PlainView = AesGcm::Decrypt(Key, Nonce, /*Aad=*/Magic, Cipher, Tag, MakeMutableMemoryView(PlainBuffer), Reason); + + if (PlainView.IsEmpty()) + { + return IoBuffer(); + } + + return IoBufferBuilder::MakeCloneFromMemory(PlainView); + } + + // Legacy format: raw AES-256-CBC with the caller-configured IV. + // Decrypt it once so we don't lose the user's cached state across + // the upgrade; the next SaveState will rewrite in the GCM format. std::vector<uint8_t> DecryptionBuffer; DecryptionBuffer.resize(EncryptedBuffer.GetSize() + Aes::BlockSize); - MemoryView DecryptedView = Aes::Decrypt(Key, IV, EncryptedBuffer, MakeMutableMemoryView(DecryptionBuffer), Reason); + MemoryView DecryptedView = Aes::Decrypt(Key, LegacyIV, EncryptedBuffer, MakeMutableMemoryView(DecryptionBuffer), Reason); if (DecryptedView.IsEmpty()) { return IoBuffer(); } + if (OutWasLegacy) + { + *OutWasLegacy = true; + } + return IoBufferBuilder::MakeCloneFromMemory(DecryptedView); } - void WriteEncryptedFile(std::filesystem::path Path, - IoBuffer FileData, - const AesKey256Bit& Key, - const AesIV128Bit& IV, - std::optional<std::string>& Reason) + void WriteEncryptedFile(std::filesystem::path Path, IoBuffer FileData, const AesKey256Bit& Key, std::optional<std::string>& Reason) { ZEN_TRACE_CPU("AuthMgr::WriteEncryptedFile"); @@ -72,17 +129,44 @@ namespace details { return; } - std::vector<uint8_t> EncryptionBuffer; - EncryptionBuffer.resize(FileData.GetSize() + Aes::BlockSize); + if (!Key.IsValid()) + { + Reason = "invalid key"; + return; + } + + // Fresh nonce per write. Never reuse a (Key, Nonce) pair — GCM is + // catastrophically broken under nonce reuse. + uint8_t Nonce[AesGcm::NonceSize]; + if (!CryptoRandom::Fill(MakeMutableMemoryView(Nonce), &Reason)) + { + return; + } + + std::vector<uint8_t> FileBuffer; + FileBuffer.resize(kAuthStateHeaderSz + FileData.GetSize() + AesGcm::TagSize); + + // [magic][nonce][cipher][tag] + memcpy(FileBuffer.data(), kAuthStateMagic.data(), kAuthStateMagic.size()); + memcpy(FileBuffer.data() + kAuthStateMagic.size(), Nonce, AesGcm::NonceSize); - MemoryView EncryptedView = Aes::Encrypt(Key, IV, FileData, MakeMutableMemoryView(EncryptionBuffer), Reason); + MutableMemoryView CipherOut(FileBuffer.data() + kAuthStateHeaderSz, FileData.GetSize()); + MutableMemoryView TagOut(FileBuffer.data() + kAuthStateHeaderSz + FileData.GetSize(), AesGcm::TagSize); - if (EncryptedView.IsEmpty()) + MemoryView CipherView = AesGcm::Encrypt(Key, + MakeMemoryView(Nonce), + /*Aad=*/MakeMemoryView(kAuthStateMagic), + FileData.GetView(), + CipherOut, + TagOut, + Reason); + + if (CipherView.IsEmpty()) { return; } - TemporaryFile::SafeWriteFile(Path, EncryptedView); + TemporaryFile::SafeWriteFile(Path, MakeMemoryView(FileBuffer)); } } // namespace details @@ -191,10 +275,9 @@ public: bool IsNew = false; { - auto Token = OpenIdToken{.IdentityToken = RefreshResult.IdentityToken, - .RefreshToken = RefreshResult.RefreshToken, - .AccessToken = fmt::format("Bearer {}"sv, RefreshResult.AccessToken), - .ExpireTime = Clock::now() + Seconds(RefreshResult.ExpiresInSeconds)}; + auto Token = OpenIdToken{.RefreshToken = RefreshResult.RefreshToken, + .AccessToken = fmt::format("Bearer {}"sv, RefreshResult.AccessToken), + .ExpireTime = Clock::now() + Seconds(RefreshResult.ExpiresInSeconds)}; std::unique_lock _(m_TokenMutex); @@ -240,7 +323,6 @@ private: struct OpenIdToken { - std::string IdentityToken; std::string RefreshToken; std::string AccessToken; TimePoint ExpireTime{}; @@ -283,9 +365,18 @@ private: try { std::optional<std::string> Reason; + bool WasLegacy = false; + + IoBuffer Buffer = details::ReadEncryptedFile(m_Config.RootDirectory / "authstate"sv, + m_Config.EncryptionKey, + m_Config.EncryptionIV, + Reason, + &WasLegacy); - IoBuffer Buffer = - details::ReadEncryptedFile(m_Config.RootDirectory / "authstate"sv, m_Config.EncryptionKey, m_Config.EncryptionIV, Reason); + if (Buffer && WasLegacy) + { + ZEN_INFO("authstate read via legacy AES-CBC fallback; next save will migrate to AES-GCM"); + } if (!Buffer) { @@ -399,7 +490,6 @@ private: details::WriteEncryptedFile(m_Config.RootDirectory / "authstate"sv, AuthState.Save().GetBuffer().AsIoBuffer(), m_Config.EncryptionKey, - m_Config.EncryptionIV, Reason); if (Reason) @@ -466,10 +556,9 @@ private: { ZEN_DEBUG("refresh access token from provider '{}' Ok", Kv.first); - auto Token = OpenIdToken{.IdentityToken = RefreshResult.IdentityToken, - .RefreshToken = RefreshResult.RefreshToken, - .AccessToken = fmt::format("Bearer {}"sv, RefreshResult.AccessToken), - .ExpireTime = Clock::now() + Seconds(RefreshResult.ExpiresInSeconds)}; + auto Token = OpenIdToken{.RefreshToken = RefreshResult.RefreshToken, + .AccessToken = fmt::format("Bearer {}"sv, RefreshResult.AccessToken), + .ExpireTime = Clock::now() + Seconds(RefreshResult.ExpiresInSeconds)}; { std::unique_lock _(m_TokenMutex); @@ -534,4 +623,107 @@ AuthMgr::Create(const AuthConfig& Config) return std::make_unique<AuthMgrImpl>(Config); } +#if ZEN_WITH_TESTS + +TEST_SUITE_BEGIN("http.authmgr"); + +TEST_CASE("authmgr.authstate_gcm_roundtrip") +{ + ScopedTemporaryDirectory TmpDir; + std::filesystem::path Path = TmpDir.Path() / "authstate"; + + const AesKey256Bit Key = AesKey256Bit::FromString("abcdefghijklmnopqrstuvxyz0123456"sv); + const AesIV128Bit UnusedIv; // ignored on write, unused on GCM read + + const std::string_view Plain = "sensitive compact-binary payload representing cached auth state"sv; + IoBuffer InBuf = IoBufferBuilder::MakeCloneFromMemory(MakeMemoryView(Plain)); + + std::optional<std::string> WriteReason; + details::WriteEncryptedFile(Path, InBuf, Key, WriteReason); + REQUIRE_FALSE(WriteReason.has_value()); + + std::optional<std::string> ReadReason; + bool WasLegacy = true; + IoBuffer Out = details::ReadEncryptedFile(Path, Key, UnusedIv, ReadReason, &WasLegacy); + REQUIRE_FALSE(ReadReason.has_value()); + REQUIRE(Out.GetSize() == Plain.size()); + CHECK(memcmp(Out.GetData(), Plain.data(), Plain.size()) == 0); + CHECK_FALSE(WasLegacy); +} + +TEST_CASE("authmgr.authstate_legacy_cbc_fallback") +{ + // Hand-craft a legacy-format authstate file (raw AES-CBC, no magic) and + // verify the reader falls back, returns the plaintext, and flags the file + // as legacy so the caller can rewrite it in the new format. + ScopedTemporaryDirectory TmpDir; + std::filesystem::path Path = TmpDir.Path() / "authstate"; + + const AesKey256Bit Key = AesKey256Bit::FromString("abcdefghijklmnopqrstuvxyz0123456"sv); + const uint8_t IvBytes[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + const AesIV128Bit LegacyIv = AesIV128Bit::FromMemoryView(MakeMemoryView(IvBytes)); + + const std::string_view Plain = "legacy authstate payload"sv; + + std::vector<uint8_t> CbcBuf(Plain.size() + Aes::BlockSize); + std::optional<std::string> CbcReason; + MemoryView CbcView = Aes::Encrypt(Key, LegacyIv, MakeMemoryView(Plain), MakeMutableMemoryView(CbcBuf), CbcReason); + REQUIRE_FALSE(CbcReason.has_value()); + + TemporaryFile::SafeWriteFile(Path, CbcView); + + std::optional<std::string> ReadReason; + bool WasLegacy = false; + IoBuffer Out = details::ReadEncryptedFile(Path, Key, LegacyIv, ReadReason, &WasLegacy); + REQUIRE_FALSE(ReadReason.has_value()); + REQUIRE(Out.GetSize() == Plain.size()); + CHECK(memcmp(Out.GetData(), Plain.data(), Plain.size()) == 0); + CHECK(WasLegacy); +} + +TEST_CASE("authmgr.authstate_gcm_tamper_detection") +{ + ScopedTemporaryDirectory TmpDir; + std::filesystem::path Path = TmpDir.Path() / "authstate"; + + const AesKey256Bit Key = AesKey256Bit::FromString("abcdefghijklmnopqrstuvxyz0123456"sv); + const AesIV128Bit UnusedIv; + + const std::string_view Plain = "payload that should not decrypt after tampering"sv; + IoBuffer InBuf = IoBufferBuilder::MakeCloneFromMemory(MakeMemoryView(Plain)); + + std::optional<std::string> WriteReason; + details::WriteEncryptedFile(Path, InBuf, Key, WriteReason); + REQUIRE_FALSE(WriteReason.has_value()); + + // Flip a single byte in the middle of the ciphertext region (past magic + + // nonce) and verify the GCM tag check rejects the tampered file. Copy + // the bytes out first and drop the FileContents / IoBuffer handles before + // rewriting so the underlying file isn't still open when we overwrite it. + std::vector<uint8_t> Mutated; + { + FileContents FC = ReadFile(Path); + IoBuffer Whole = FC.Flatten(); + REQUIRE(Whole.GetSize() > 4 + AesGcm::NonceSize); + Mutated.assign(static_cast<const uint8_t*>(Whole.GetData()), static_cast<const uint8_t*>(Whole.GetData()) + Whole.GetSize()); + } + Mutated[4 + AesGcm::NonceSize] ^= 0x40; + TemporaryFile::SafeWriteFile(Path, MakeMemoryView(Mutated)); + + std::optional<std::string> ReadReason; + bool WasLegacy = false; + IoBuffer Out = details::ReadEncryptedFile(Path, Key, UnusedIv, ReadReason, &WasLegacy); + CHECK(ReadReason.has_value()); + CHECK(Out.GetSize() == 0); +} + +TEST_SUITE_END(); + +void +authmgr_forcelink() +{ +} + +#endif // ZEN_WITH_TESTS + } // namespace zen diff --git a/src/zenhttp/auth/authservice.cpp b/src/zenhttp/auth/authservice.cpp index f89ca91da..d8a1588c6 100644 --- a/src/zenhttp/auth/authservice.cpp +++ b/src/zenhttp/auth/authservice.cpp @@ -3,6 +3,7 @@ #include "zenhttp/auth/authservice.h" #include <zencore/compactbinarybuilder.h> +#include <zencore/logging.h> #include <zencore/string.h> #include <zenhttp/auth/authmgr.h> @@ -21,6 +22,19 @@ HttpAuthService::HttpAuthService(AuthMgr& AuthMgr) : m_AuthMgr(AuthMgr) [this](HttpRouterRequest& RouterRequest) { HttpServerRequest& ServerRequest = RouterRequest.ServerRequest(); + // The refresh-token endpoint is how the local Unreal Editor / zen CLI + // hands off a refresh token obtained via an interactive flow. There + // is no legitimate reason for it to be reachable from the network: a + // remote caller that can reach it can evict the cached token (DoS) + // or — with a valid same-provider token — substitute credentials used + // by GetOpenIdAccessToken. Require the origin to be the local + // machine; any configured password filter still runs in front. + if (!ServerRequest.IsLocalMachineRequest()) + { + ZEN_WARN("rejecting non-local POST to /auth/oidc/refreshtoken"); + return ServerRequest.WriteResponse(HttpResponseCode::Forbidden); + } + const HttpContentType ContentType = ServerRequest.RequestContentType(); if ((ContentType == HttpContentType::kUnknownContentType || ContentType == HttpContentType::kJSON) == false) diff --git a/src/zenhttp/auth/oidc.cpp b/src/zenhttp/auth/oidc.cpp index 23bbc17e8..63c8d9030 100644 --- a/src/zenhttp/auth/oidc.cpp +++ b/src/zenhttp/auth/oidc.cpp @@ -32,6 +32,68 @@ namespace details { using namespace std::literals; +// Return the "scheme://host[:port]" prefix of an absolute URL, or empty on +// malformed input. RFC 8414 §3 requires every endpoint advertised by an +// OP to live under the issuer's origin, so exact-prefix comparison on this +// value is sufficient to pin an endpoint to the configured IdP. +static std::string_view +OriginOf(std::string_view Url) +{ + // scheme://... + auto SchemeEnd = Url.find("://"); + if (SchemeEnd == std::string_view::npos) + { + return {}; + } + // First path char (or query / fragment) after the authority. + const size_t AuthorityStart = SchemeEnd + 3; + size_t OriginEnd = Url.size(); + for (size_t I = AuthorityStart; I < Url.size(); ++I) + { + const char C = Url[I]; + if (C == '/' || C == '?' || C == '#') + { + OriginEnd = I; + break; + } + } + // Require at least one character in the authority. + if (OriginEnd == AuthorityStart) + { + return {}; + } + return Url.substr(0, OriginEnd); +} + +// True if the URL uses a scheme / host that we trust for discovery even +// without HTTPS — narrowly, only loopback for dev / test setups. +static bool +IsLoopbackHttp(std::string_view Url) +{ + constexpr std::string_view Prefixes[] = { + "http://localhost"sv, + "http://127.0.0.1"sv, + "http://[::1]"sv, + }; + for (std::string_view P : Prefixes) + { + if (Url.size() >= P.size() && Url.substr(0, P.size()) == P) + { + // Ensure the next char (if any) ends the authority cleanly. + if (Url.size() == P.size()) + { + return true; + } + const char Next = Url[P.size()]; + if (Next == ':' || Next == '/' || Next == '?' || Next == '#') + { + return true; + } + } + } + return false; +} + static std::string FormUrlEncode(std::string_view Input) { @@ -60,6 +122,19 @@ OidcClient::OidcClient(const OidcClient::Options& Options) OidcClient::InitResult OidcClient::Initialize() { + // The OIDC discovery document determines where we send refresh tokens, so + // the transport to the discovery endpoint has to be trustworthy. Require + // HTTPS on the configured BaseUrl. Loopback is permitted over plain HTTP + // for developer setups that run a local IdP mock — no meaningful attack + // surface on the loopback interface. + if (m_BaseUrl.size() < 8 || m_BaseUrl.substr(0, 8) != "https://"sv) + { + if (!IsLoopbackHttp(m_BaseUrl)) + { + return {.Reason = "BaseUrl must use https:// (or a http://localhost / 127.0.0.1 / [::1] loopback)"}; + } + } + HttpClient Http{m_BaseUrl}; HttpClient::Response Response = Http.Get("/.well-known/openid-configuration"sv); @@ -81,14 +156,86 @@ OidcClient::Initialize() return {.Reason = std::move(JsonError)}; } - m_Config = {.Issuer = Json["issuer"].string_value(), + // RFC 8414 §3: the discovery document's `issuer` value MUST identify the + // OP and MUST be the origin used to fetch the document. Without this + // check, an attacker who can intercept discovery (or a misconfigured + // intermediate) can swap the issuer identity without detection. Accept + // a trailing '/' divergence since OPs vary. + const std::string Issuer = Json["issuer"].string_value(); + { + std::string_view ExpectedBase = m_BaseUrl; + while (!ExpectedBase.empty() && ExpectedBase.back() == '/') + { + ExpectedBase.remove_suffix(1); + } + std::string_view ActualIssuer = Issuer; + while (!ActualIssuer.empty() && ActualIssuer.back() == '/') + { + ActualIssuer.remove_suffix(1); + } + if (ActualIssuer.empty() || ActualIssuer != ExpectedBase) + { + return {.Reason = fmt::format("discovery issuer mismatch (expected '{}')", ExpectedBase)}; + } + } + + // Pin every endpoint we actually use to the same origin as BaseUrl. This + // is the last defense against a tampered discovery document redirecting + // token submissions to an attacker-controlled host. We check the + // endpoints we may call later; endpoints this client never dispatches to + // are left alone so a discovery document with unrelated auxiliary URLs + // isn't rejected for no reason. + const std::string_view BaseOrigin = OriginOf(m_BaseUrl); + if (BaseOrigin.empty()) + { + return {.Reason = "BaseUrl is malformed"}; + } + + const std::string TokenEndpoint = Json["token_endpoint"].string_value(); + const std::string UserInfoEndpoint = Json["userinfo_endpoint"].string_value(); + const std::string JwksUri = Json["jwks_uri"].string_value(); + + auto CheckOrigin = [&](std::string_view Name, std::string_view Url) -> std::optional<std::string> { + if (Url.empty()) + { + return std::nullopt; + } + const std::string_view Origin = OriginOf(Url); + if (Origin != BaseOrigin) + { + return fmt::format("discovery endpoint '{}' is off-origin (expected origin '{}')", Name, BaseOrigin); + } + return std::nullopt; + }; + + if (auto Err = CheckOrigin("token_endpoint"sv, TokenEndpoint); Err.has_value()) + { + return {.Reason = std::move(*Err)}; + } + if (auto Err = CheckOrigin("userinfo_endpoint"sv, UserInfoEndpoint); Err.has_value()) + { + return {.Reason = std::move(*Err)}; + } + if (auto Err = CheckOrigin("jwks_uri"sv, JwksUri); Err.has_value()) + { + return {.Reason = std::move(*Err)}; + } + + // token_endpoint is required for the refresh flow we implement; fail early + // rather than at RefreshToken time if the OP omitted it. + if (TokenEndpoint.empty()) + { + return {.Reason = "discovery document is missing token_endpoint"}; + } + + m_Config = {.Issuer = Issuer, .AuthorizationEndpoint = Json["authorization_endpoint"].string_value(), - .TokenEndpoint = Json["token_endpoint"].string_value(), - .UserInfoEndpoint = Json["userinfo_endpoint"].string_value(), + .TokenEndpoint = TokenEndpoint, + .UserInfoEndpoint = UserInfoEndpoint, .RegistrationEndpoint = Json["registration_endpoint"].string_value(), .EndSessionEndpoint = Json["end_session_endpoint"].string_value(), .DeviceAuthorizationEndpoint = Json["device_authorization_endpoint"].string_value(), - .JwksUri = Json["jwks_uri"].string_value(), + .JwksUri = JwksUri, .SupportedResponseTypes = details::ToStringArray(Json["response_types_supported"]), .SupportedResponseModes = details::ToStringArray(Json["response_modes_supported"]), .SupportedGrantTypes = details::ToStringArray(Json["grant_types_supported"]), @@ -118,7 +265,12 @@ OidcClient::RefreshToken(std::string_view RefreshToken) if (Response.StatusCode != HttpResponseCode::OK) { - return {.Reason = fmt::format("{} ({})", ToString(Response.StatusCode), Response.AsText())}; + // Do NOT include Response.AsText() in the reason string. Some IdPs + // echo the submitted refresh_token (or a prefix of it) in their error + // body — plumbing that into the Reason string causes AuthMgrImpl's + // ZEN_WARN in the refresh paths to write the token into the log. + // Only the status code is safe to surface up to the log sites. + return {.Reason = fmt::format("{} (provider returned {} bytes)", ToString(Response.StatusCode), Response.AsText().size())}; } std::string JsonError; @@ -129,10 +281,14 @@ OidcClient::RefreshToken(std::string_view RefreshToken) return {.Reason = std::move(JsonError)}; } + // Note: id_token is intentionally not parsed. It is a JWT whose contents + // are meaningful only after signature / issuer / audience / expiry + // verification against the provider's JWKS, and nothing downstream + // currently consumes it. Leaving it unparsed avoids planting an + // unauthenticated identity claim in the OpenIdToken cache. return {.TokenType = Json["token_type"].string_value(), .AccessToken = Json["access_token"].string_value(), .RefreshToken = Json["refresh_token"].string_value(), - .IdentityToken = Json["id_token"].string_value(), .Scope = Json["scope"].string_value(), .ExpiresInSeconds = static_cast<int64_t>(Json["expires_in"].int_value()), .Ok = true}; diff --git a/src/zenhttp/clients/httpwsclient.cpp b/src/zenhttp/clients/httpwsclient.cpp index 4337fcb79..842bf9d49 100644 --- a/src/zenhttp/clients/httpwsclient.cpp +++ b/src/zenhttp/clients/httpwsclient.cpp @@ -23,6 +23,8 @@ ZEN_THIRD_PARTY_INCLUDES_END namespace zen { +using namespace std::literals; + ////////////////////////////////////////////////////////////////////////// struct HttpWsClient::Impl @@ -271,6 +273,43 @@ struct HttpWsClient::Impl }); } + // Trim ASCII LWS (space / tab) from both ends of a header value, along with + // a trailing CR if the caller didn't strip it. + static std::string_view TrimHeaderValue(std::string_view V) + { + while (!V.empty() && (V.front() == ' ' || V.front() == '\t')) + { + V.remove_prefix(1); + } + while (!V.empty() && (V.back() == ' ' || V.back() == '\t' || V.back() == '\r')) + { + V.remove_suffix(1); + } + return V; + } + + // Return true if a comma-separated header value contains the given token, + // case-insensitively. Used for Connection header parsing where the value + // may legitimately be "Upgrade, keep-alive" etc. + static bool HeaderContainsToken(std::string_view Value, std::string_view Token) + { + while (!Value.empty()) + { + auto CommaPos = Value.find(','); + std::string_view Part = TrimHeaderValue(Value.substr(0, CommaPos)); + if (Part.size() == Token.size() && StrCaseCompare(Part, Token) == 0) + { + return true; + } + if (CommaPos == std::string_view::npos) + { + break; + } + Value.remove_prefix(CommaPos + 1); + } + return false; + } + void DoReadHandshakeResponse() { WithSocket([this](auto& Socket) { @@ -284,30 +323,105 @@ struct HttpWsClient::Impl return; } - // Parse the response const auto& Data = m_ReadBuffer.data(); std::string Response(asio::buffers_begin(Data), asio::buffers_end(Data)); // Consume the headers from the read buffer (any extra data stays for frame parsing) auto HeaderEnd = Response.find("\r\n\r\n"); - if (HeaderEnd != std::string::npos) + if (HeaderEnd == std::string::npos) { - m_ReadBuffer.consume(HeaderEnd + 4); + ZEN_LOG_DEBUG(m_Log, "WebSocket handshake: incomplete headers"); + m_Handler.OnWsClose(1006, "handshake incomplete"); + return; } + m_ReadBuffer.consume(HeaderEnd + 4); + + // Parse the status line. Substring matching on "101" anywhere + // in the response is unsafe — a server returning + // "HTTP/1.1 404 Not Found\r\nX-Retry-After: 101\r\n" would have + // satisfied it. We require the first line to start with + // "HTTP/1.x 101" followed by end-of-line or space. + // + // ResponseView spans up through the first "\r\n" of the + // terminating "\r\n\r\n" so that every header line — including + // the last one — is terminated by "\r\n" in the view. + std::string_view ResponseView(Response.data(), HeaderEnd + 2); + auto StatusLineEnd = ResponseView.find("\r\n"); + if (StatusLineEnd == std::string_view::npos) + { + m_Handler.OnWsClose(1006, "handshake malformed"); + return; + } + std::string_view StatusLine = ResponseView.substr(0, StatusLineEnd); - // Validate 101 response - if (Response.find("101") == std::string::npos) + // Expect: "HTTP/1.x 101" (12 chars min), with 'x' being '0' or '1'. + bool StatusOk = StatusLine.size() >= 12 && StatusLine.substr(0, 7) == "HTTP/1." && + (StatusLine[7] == '0' || StatusLine[7] == '1') && StatusLine[8] == ' ' && + StatusLine.substr(9, 3) == "101" && (StatusLine.size() == 12 || StatusLine[12] == ' '); + if (!StatusOk) { - ZEN_LOG_DEBUG(m_Log, "WebSocket handshake rejected (no 101): {}", Response.substr(0, 80)); + ZEN_LOG_DEBUG(m_Log, "WebSocket handshake rejected; status line: {}", StatusLine.substr(0, 80)); m_Handler.OnWsClose(1006, "handshake rejected"); return; } - // Validate Sec-WebSocket-Accept + // Parse headers and extract the three fields RFC 6455 §4.1 + // requires a client to validate: Upgrade, Connection, and + // Sec-WebSocket-Accept. Case-insensitive on header names and + // on the Upgrade / Connection token values; exact-match on the + // Sec-WebSocket-Accept base64 value. + bool UpgradeOk = false; + bool ConnectionOk = false; + std::string_view AcceptValue; + + std::string_view HeaderBlock = ResponseView.substr(StatusLineEnd + 2); + while (!HeaderBlock.empty()) + { + auto NextLineEnd = HeaderBlock.find("\r\n"); + if (NextLineEnd == std::string_view::npos) + { + break; + } + std::string_view Line = HeaderBlock.substr(0, NextLineEnd); + HeaderBlock = HeaderBlock.substr(NextLineEnd + 2); + if (Line.empty()) + { + break; + } + + auto ColonPos = Line.find(':'); + if (ColonPos == std::string_view::npos) + { + continue; + } + std::string_view Name = Line.substr(0, ColonPos); + std::string_view Value = TrimHeaderValue(Line.substr(ColonPos + 1)); + + if (Name.size() == 7 && StrCaseCompare(Name, "Upgrade"sv) == 0) + { + UpgradeOk = (Value.size() == 9 && StrCaseCompare(Value, "websocket"sv) == 0); + } + else if (Name.size() == 10 && StrCaseCompare(Name, "Connection"sv) == 0) + { + ConnectionOk = HeaderContainsToken(Value, "upgrade"sv); + } + else if (Name.size() == 20 && StrCaseCompare(Name, "Sec-WebSocket-Accept"sv) == 0) + { + AcceptValue = Value; + } + } + + if (!UpgradeOk || !ConnectionOk) + { + ZEN_LOG_DEBUG(m_Log, "WebSocket handshake missing required Upgrade/Connection headers"); + m_Handler.OnWsClose(1006, "handshake missing headers"); + return; + } + std::string ExpectedAccept = WsFrameCodec::ComputeAcceptKey(m_WebSocketKey); - if (Response.find(ExpectedAccept) == std::string::npos) + if (AcceptValue.size() != ExpectedAccept.size() || AcceptValue != ExpectedAccept) { - ZEN_LOG_DEBUG(m_Log, "WebSocket handshake: invalid Sec-WebSocket-Accept"); + ZEN_LOG_DEBUG(m_Log, "WebSocket handshake: invalid or missing Sec-WebSocket-Accept"); m_Handler.OnWsClose(1006, "invalid accept key"); return; } diff --git a/src/zenhttp/include/zenhttp/auth/authmgr.h b/src/zenhttp/include/zenhttp/auth/authmgr.h index 054588ab9..b61b221f6 100644 --- a/src/zenhttp/include/zenhttp/auth/authmgr.h +++ b/src/zenhttp/include/zenhttp/auth/authmgr.h @@ -17,7 +17,13 @@ struct AuthConfig std::filesystem::path RootDirectory; std::chrono::seconds UpdateInterval{30}; AesKey256Bit EncryptionKey; - AesIV128Bit EncryptionIV; + + // LEGACY: consulted only when reading a pre-AES-GCM authstate file written + // before the format migration. New writes use AES-GCM with a fresh random + // nonce per write and do not consult this field. Kept so existing + // deployments roll forward transparently; remove once the next format + // version bump lands. + AesIV128Bit EncryptionIV; }; class AuthMgr diff --git a/src/zenhttp/include/zenhttp/auth/oidc.h b/src/zenhttp/include/zenhttp/auth/oidc.h index 6f9c3198e..1008367df 100644 --- a/src/zenhttp/include/zenhttp/auth/oidc.h +++ b/src/zenhttp/include/zenhttp/auth/oidc.h @@ -39,7 +39,6 @@ public: std::string TokenType; std::string AccessToken; std::string RefreshToken; - std::string IdentityToken; std::string Scope; std::string Reason; int64_t ExpiresInSeconds{}; diff --git a/src/zenhttp/include/zenhttp/websocket.h b/src/zenhttp/include/zenhttp/websocket.h index 2d25515d3..22fb419c1 100644 --- a/src/zenhttp/include/zenhttp/websocket.h +++ b/src/zenhttp/include/zenhttp/websocket.h @@ -13,11 +13,12 @@ namespace zen { enum class WebSocketOpcode : uint8_t { - kText = 0x1, - kBinary = 0x2, - kClose = 0x8, - kPing = 0x9, - kPong = 0xA + kContinuation = 0x0, // RFC 6455 §5.4 - only seen inside the codec; handlers always receive the coalesced opcode + kText = 0x1, + kBinary = 0x2, + kClose = 0x8, + kPing = 0x9, + kPong = 0xA }; struct WebSocketMessage diff --git a/src/zenhttp/security/passwordsecurity.cpp b/src/zenhttp/security/passwordsecurity.cpp index 0e3a743c3..70315fa03 100644 --- a/src/zenhttp/security/passwordsecurity.cpp +++ b/src/zenhttp/security/passwordsecurity.cpp @@ -67,7 +67,10 @@ PasswordSecurity::IsAllowed(std::string_view InPassword, std::string_view BaseUr { return true; } - if (Password() == InPassword) + // Constant-time compare: a plain == short-circuits on the first byte mismatch + // (and on length mismatch before any bytes are inspected), exposing the + // configured password to a remote timing oracle across the network. + if (ConstantTimeEquals(Password(), InPassword)) { return true; } @@ -148,6 +151,19 @@ TEST_CASE("passwordsecurity.allowsomelocaluris") CHECK(AllLocal.IsAllowed("123456"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false)); } +TEST_CASE("passwordsecurity.constanttimeequals") +{ + // Documents the contract of the helper used to compare passwords. We + // cannot observe timing from here, but we can at least verify equality + // and inequality of the three interesting shapes: same length equal, + // same length unequal, and length mismatch. + CHECK(ConstantTimeEquals("abcdef"sv, "abcdef"sv)); + CHECK_FALSE(ConstantTimeEquals("abcdef"sv, "abcdeg"sv)); + CHECK_FALSE(ConstantTimeEquals("abcdef"sv, "abcdefg"sv)); + CHECK_FALSE(ConstantTimeEquals(""sv, "x"sv)); + CHECK(ConstantTimeEquals(""sv, ""sv)); +} + TEST_CASE("passwordsecurity.conflictingunprotecteduris") { try diff --git a/src/zenhttp/security/passwordsecurityfilter.cpp b/src/zenhttp/security/passwordsecurityfilter.cpp index 87d8cc275..0428b5437 100644 --- a/src/zenhttp/security/passwordsecurityfilter.cpp +++ b/src/zenhttp/security/passwordsecurityfilter.cpp @@ -34,13 +34,41 @@ IHttpRequestFilter::Result PasswordHttpFilter::FilterRequest(HttpServerRequest& Request) { std::string_view Password; - std::string_view AuthorizationHeader = Request.GetAuthorizationHeader(); - size_t AuthorizationHeaderLength = AuthorizationHeader.length(); - if (AuthorizationHeaderLength > m_AuthenticationTypeString.length()) + std::string_view AuthorizationHeader = Request.GetAuthorizationHeader(); + + // Look for the configured scheme prefix (e.g. "Basic ") case-insensitively. + // Only extract a candidate credential when the scheme prefix matches and the + // remainder passes two defense-in-depth checks: + // * trim any trailing LWS (SP / HTAB) from the credential, so a + // well-behaved but "Authorization: Basic <b64> " style header still + // works instead of producing a silent 403; + // * reject the header outright if the credential contains any control + // byte (< 0x20) or DEL (0x7F) — valid Base64 never does, and passing + // such bytes downstream is a log-injection / smuggling vector. + if (AuthorizationHeader.length() > m_AuthenticationTypeString.length() && + StrCaseCompare(AuthorizationHeader.data(), m_AuthenticationTypeString.c_str(), m_AuthenticationTypeString.length()) == 0) { - if (StrCaseCompare(AuthorizationHeader.data(), m_AuthenticationTypeString.c_str(), m_AuthenticationTypeString.length()) == 0) + std::string_view Candidate = AuthorizationHeader.substr(m_AuthenticationTypeString.length()); + + while (!Candidate.empty() && (Candidate.back() == ' ' || Candidate.back() == '\t')) + { + Candidate.remove_suffix(1); + } + + bool HasControlByte = false; + for (char C : Candidate) + { + const uint8_t B = static_cast<uint8_t>(C); + if (B < 0x20 || B == 0x7F) + { + HasControlByte = true; + break; + } + } + + if (!HasControlByte) { - Password = AuthorizationHeader.substr(m_AuthenticationTypeString.length()); + Password = Candidate; } } diff --git a/src/zenhttp/servers/wsasio.cpp b/src/zenhttp/servers/wsasio.cpp index 078c21ea1..d223a50c0 100644 --- a/src/zenhttp/servers/wsasio.cpp +++ b/src/zenhttp/servers/wsasio.cpp @@ -107,11 +107,17 @@ WsAsioConnectionT<SocketType>::ProcessReceivedData() const auto* Data = static_cast<const uint8_t*>(InputBuffer.data()); const auto Size = InputBuffer.size(); - WsFrameParseResult Frame = WsFrameCodec::TryParseFrame(Data, Size); - if (!Frame.IsValid) + WsFrameParseResult Frame = WsFrameCodec::TryParseFrame(Data, Size, /*RequireMask=*/true); + if (Frame.Status == WsFrameParseStatus::kNeedMoreData) { break; // not enough data yet } + if (Frame.Status == WsFrameParseStatus::kProtocolError) + { + ZEN_LOG_DEBUG(WsLog(), "WebSocket protocol violation; closing with 1002"); + DoClose(1002, "protocol error"); + return; + } m_ReadBuffer.consume(Frame.BytesConsumed); @@ -120,15 +126,70 @@ WsAsioConnectionT<SocketType>::ProcessReceivedData() m_HttpServer->OnWebSocketFrameReceived(Frame.BytesConsumed); } + // Cap for a coalesced multi-frame message. Matches the single-frame cap + // in the codec; any legitimate zenhttp use case fits in either case. + static constexpr size_t kMaxAssembledSize = 4 * 1024 * 1024; + switch (Frame.Opcode) { case WebSocketOpcode::kText: case WebSocketOpcode::kBinary: { - WebSocketMessage Msg; - Msg.Opcode = Frame.Opcode; - Msg.Payload = IoBuffer(IoBuffer::Clone, Frame.Payload.data(), Frame.Payload.size()); - m_Handler.OnWebSocketMessage(*this, Msg); + // Per RFC 6455 §5.4, a new data frame is a protocol violation when a + // fragmented message is already in progress (non-control frames must + // not be interleaved). + if (m_FragmentInProgress) + { + ZEN_LOG_DEBUG(WsLog(), "New data frame arrived mid-fragment; closing 1002"); + DoClose(1002, "interleaved data frame"); + return; + } + + if (Frame.Fin) + { + WebSocketMessage Msg; + Msg.Opcode = Frame.Opcode; + Msg.Payload = IoBuffer(IoBuffer::Clone, Frame.Payload.data(), Frame.Payload.size()); + m_Handler.OnWebSocketMessage(*this, Msg); + } + else + { + if (Frame.Payload.size() > kMaxAssembledSize) + { + DoClose(1009, "message too big"); + return; + } + m_FragmentOpcode = Frame.Opcode; + m_FragmentBuffer = std::move(Frame.Payload); + m_FragmentInProgress = true; + } + break; + } + + case WebSocketOpcode::kContinuation: + { + if (!m_FragmentInProgress) + { + ZEN_LOG_DEBUG(WsLog(), "Continuation frame with no message in progress; closing 1002"); + DoClose(1002, "unexpected continuation"); + return; + } + if (m_FragmentBuffer.size() + Frame.Payload.size() > kMaxAssembledSize) + { + DoClose(1009, "message too big"); + return; + } + m_FragmentBuffer.insert(m_FragmentBuffer.end(), Frame.Payload.begin(), Frame.Payload.end()); + if (Frame.Fin) + { + WebSocketMessage Msg; + Msg.Opcode = m_FragmentOpcode; + Msg.Payload = IoBuffer(IoBuffer::Clone, m_FragmentBuffer.data(), m_FragmentBuffer.size()); + m_FragmentBuffer.clear(); + m_FragmentBuffer.shrink_to_fit(); + m_FragmentInProgress = false; + m_Handler.OnWebSocketMessage(*this, Msg); + } break; } diff --git a/src/zenhttp/servers/wsasio.h b/src/zenhttp/servers/wsasio.h index 64602ee46..5db38db92 100644 --- a/src/zenhttp/servers/wsasio.h +++ b/src/zenhttp/servers/wsasio.h @@ -77,6 +77,12 @@ private: std::deque<std::vector<uint8_t>> m_WriteQueue; bool m_IsWriting = false; + // Fragmented-message reassembly (RFC 6455 §5.4). Only the reader touches + // these; no synchronization required. + std::vector<uint8_t> m_FragmentBuffer; + WebSocketOpcode m_FragmentOpcode{WebSocketOpcode::kContinuation}; + bool m_FragmentInProgress = false; + std::atomic<bool> m_IsOpen{true}; std::atomic<bool> m_CloseSent{false}; }; diff --git a/src/zenhttp/servers/wsframecodec.cpp b/src/zenhttp/servers/wsframecodec.cpp index e452141fe..781f04c5e 100644 --- a/src/zenhttp/servers/wsframecodec.cpp +++ b/src/zenhttp/servers/wsframecodec.cpp @@ -16,7 +16,7 @@ namespace zen { // WsFrameParseResult -WsFrameCodec::TryParseFrame(const uint8_t* Data, size_t Size) +WsFrameCodec::TryParseFrame(const uint8_t* Data, size_t Size, bool RequireMask) { // Minimum frame: 2 bytes header (unmasked server frames) or 6 bytes (masked client frames) if (Size < 2) @@ -24,10 +24,48 @@ WsFrameCodec::TryParseFrame(const uint8_t* Data, size_t Size) return {}; } - const bool Fin = (Data[0] & 0x80) != 0; - const uint8_t OpcodeRaw = Data[0] & 0x0F; - const bool Masked = (Data[1] & 0x80) != 0; - uint64_t PayloadLen = Data[1] & 0x7F; + const bool Fin = (Data[0] & 0x80) != 0; + const uint8_t RsvBits = Data[0] & 0x70; + const uint8_t OpcodeRaw = Data[0] & 0x0F; + const bool Masked = (Data[1] & 0x80) != 0; + const uint8_t ShortLength = Data[1] & 0x7F; + uint64_t PayloadLen = ShortLength; + + const bool IsControlFrame = (OpcodeRaw & 0x08) != 0; + + // RFC 6455 section 5.2: RSV1/2/3 must be zero unless a negotiated extension + // defines them. We do not negotiate any extensions, so any non-zero RSV bit + // is a protocol violation. + if (RsvBits != 0) + { + WsFrameParseResult Error; + Error.Status = WsFrameParseStatus::kProtocolError; + return Error; + } + + // RFC 6455 section 5.5: control frames (Close / Ping / Pong and any opcode + // in 0x8..0xF) MUST NOT be fragmented and MUST have a payload of 125 bytes + // or less. Rejecting fragmented or oversized control frames prevents a + // peer from tying up unbounded memory inside an auto-pong, and closes off + // a class of smuggling tricks where handlers might observe partial control + // payloads. + if (IsControlFrame && (!Fin || ShortLength > 125)) + { + WsFrameParseResult Error; + Error.Status = WsFrameParseStatus::kProtocolError; + return Error; + } + + // RFC 6455 section 5.1: a server MUST close the connection upon receiving an + // unmasked client frame. Signal this distinctly from "need more data" so the + // server close path can trigger a 1002 close rather than stalling for bytes + // that will never satisfy the parse. + if (RequireMask && !Masked) + { + WsFrameParseResult Error; + Error.Status = WsFrameParseStatus::kProtocolError; + return Error; + } size_t HeaderSize = 2; @@ -51,11 +89,19 @@ WsFrameCodec::TryParseFrame(const uint8_t* Data, size_t Size) HeaderSize = 10; } - // Reject frames with unreasonable payload sizes to prevent OOM - static constexpr uint64_t kMaxPayloadSize = 256 * 1024 * 1024; // 256 MB + // Reject frames with unreasonable payload sizes to bound per-connection + // memory. Parsers accumulate the whole frame before dispatch (see the + // read loops in wsasio.cpp / wshttpsys.cpp), so this cap also bounds the + // accumulator: a peer that advertises a large frame and streams bytes + // slowly cannot grow buffers past this limit. 4 MB is well above anything + // the monitoring / stats endpoints produce; raise it if a legitimate use + // case emerges. + static constexpr uint64_t kMaxPayloadSize = 4 * 1024 * 1024; // 4 MB if (PayloadLen > kMaxPayloadSize) { - return {}; + WsFrameParseResult Error; + Error.Status = WsFrameParseStatus::kProtocolError; + return Error; } const size_t MaskSize = Masked ? 4 : 0; @@ -70,6 +116,7 @@ WsFrameCodec::TryParseFrame(const uint8_t* Data, size_t Size) const uint8_t* PayloadData = Data + HeaderSize + MaskSize; WsFrameParseResult Result; + Result.Status = WsFrameParseStatus::kValid; Result.IsValid = true; Result.BytesConsumed = TotalFrame; Result.Opcode = static_cast<WebSocketOpcode>(OpcodeRaw); diff --git a/src/zenhttp/servers/wsframecodec.h b/src/zenhttp/servers/wsframecodec.h index 2d90b6fa1..0a62e6c80 100644 --- a/src/zenhttp/servers/wsframecodec.h +++ b/src/zenhttp/servers/wsframecodec.h @@ -15,11 +15,22 @@ namespace zen { /** + * Outcome of a frame parse attempt + */ +enum class WsFrameParseStatus : uint8_t +{ + kNeedMoreData, // the buffer does not yet contain a full frame + kValid, // a complete, well-formed frame was parsed + kProtocolError // the frame violates RFC 6455; caller must close the connection (typically with 1002) +}; + +/** * Result of attempting to parse a single WebSocket frame from a byte buffer */ struct WsFrameParseResult { - bool IsValid = false; // true if a complete frame was successfully parsed + WsFrameParseStatus Status = WsFrameParseStatus::kNeedMoreData; + bool IsValid = false; // true iff Status == kValid (convenience mirror) size_t BytesConsumed = 0; // number of bytes consumed from the input buffer WebSocketOpcode Opcode = WebSocketOpcode::kText; bool Fin = false; @@ -37,11 +48,17 @@ struct WsFrameCodec /** * Try to parse one complete frame from the front of the buffer. * - * Returns a result with IsValid == false and BytesConsumed == 0 when + * Returns a result with Status == kNeedMoreData and BytesConsumed == 0 when * there is not enough data yet. The caller should accumulate more data * and retry. + * + * When RequireMask is true (server parsing client-to-server frames), a frame + * that is not masked is rejected with Status == kProtocolError per RFC 6455 + * section 5.1 — the caller must close the connection with status 1002. + * When RequireMask is false (client parsing server-to-client frames), the + * mask bit is not validated. */ - static WsFrameParseResult TryParseFrame(const uint8_t* Data, size_t Size); + static WsFrameParseResult TryParseFrame(const uint8_t* Data, size_t Size, bool RequireMask = false); /** * Build a server-to-client frame (no masking) diff --git a/src/zenhttp/servers/wshttpsys.cpp b/src/zenhttp/servers/wshttpsys.cpp index 8520e9f60..1f2b11bf1 100644 --- a/src/zenhttp/servers/wshttpsys.cpp +++ b/src/zenhttp/servers/wshttpsys.cpp @@ -176,11 +176,17 @@ WsHttpSysConnection::ProcessReceivedData() { while (!m_Accumulated.empty()) { - WsFrameParseResult Frame = WsFrameCodec::TryParseFrame(m_Accumulated.data(), m_Accumulated.size()); - if (!Frame.IsValid) + WsFrameParseResult Frame = WsFrameCodec::TryParseFrame(m_Accumulated.data(), m_Accumulated.size(), /*RequireMask=*/true); + if (Frame.Status == WsFrameParseStatus::kNeedMoreData) { break; // not enough data yet } + if (Frame.Status == WsFrameParseStatus::kProtocolError) + { + ZEN_LOG_DEBUG(WsHttpSysLog(), "WebSocket protocol violation; closing with 1002"); + DoClose(1002, "protocol error"); + return; + } // Remove consumed bytes m_Accumulated.erase(m_Accumulated.begin(), m_Accumulated.begin() + Frame.BytesConsumed); @@ -190,15 +196,67 @@ WsHttpSysConnection::ProcessReceivedData() m_HttpServer->OnWebSocketFrameReceived(Frame.BytesConsumed); } + // Cap for a coalesced multi-frame message. Matches the single-frame cap + // in the codec; any legitimate zenhttp use case fits in either case. + static constexpr size_t kMaxAssembledSize = 4 * 1024 * 1024; + switch (Frame.Opcode) { case WebSocketOpcode::kText: case WebSocketOpcode::kBinary: { - WebSocketMessage Msg; - Msg.Opcode = Frame.Opcode; - Msg.Payload = IoBuffer(IoBuffer::Clone, Frame.Payload.data(), Frame.Payload.size()); - m_Handler.OnWebSocketMessage(*this, Msg); + if (m_FragmentInProgress) + { + ZEN_LOG_DEBUG(WsHttpSysLog(), "New data frame arrived mid-fragment; closing 1002"); + DoClose(1002, "interleaved data frame"); + return; + } + + if (Frame.Fin) + { + WebSocketMessage Msg; + Msg.Opcode = Frame.Opcode; + Msg.Payload = IoBuffer(IoBuffer::Clone, Frame.Payload.data(), Frame.Payload.size()); + m_Handler.OnWebSocketMessage(*this, Msg); + } + else + { + if (Frame.Payload.size() > kMaxAssembledSize) + { + DoClose(1009, "message too big"); + return; + } + m_FragmentOpcode = Frame.Opcode; + m_FragmentBuffer = std::move(Frame.Payload); + m_FragmentInProgress = true; + } + break; + } + + case WebSocketOpcode::kContinuation: + { + if (!m_FragmentInProgress) + { + ZEN_LOG_DEBUG(WsHttpSysLog(), "Continuation frame with no message in progress; closing 1002"); + DoClose(1002, "unexpected continuation"); + return; + } + if (m_FragmentBuffer.size() + Frame.Payload.size() > kMaxAssembledSize) + { + DoClose(1009, "message too big"); + return; + } + m_FragmentBuffer.insert(m_FragmentBuffer.end(), Frame.Payload.begin(), Frame.Payload.end()); + if (Frame.Fin) + { + WebSocketMessage Msg; + Msg.Opcode = m_FragmentOpcode; + Msg.Payload = IoBuffer(IoBuffer::Clone, m_FragmentBuffer.data(), m_FragmentBuffer.size()); + m_FragmentBuffer.clear(); + m_FragmentBuffer.shrink_to_fit(); + m_FragmentInProgress = false; + m_Handler.OnWebSocketMessage(*this, Msg); + } break; } diff --git a/src/zenhttp/servers/wshttpsys.h b/src/zenhttp/servers/wshttpsys.h index 6015e3873..49b0cdf67 100644 --- a/src/zenhttp/servers/wshttpsys.h +++ b/src/zenhttp/servers/wshttpsys.h @@ -87,6 +87,12 @@ private: std::vector<uint8_t> m_ReadBuffer; std::vector<uint8_t> m_Accumulated; + // Fragmented-message reassembly (RFC 6455 §5.4). Only the read-completion + // callback touches these; no synchronization required. + std::vector<uint8_t> m_FragmentBuffer; + WebSocketOpcode m_FragmentOpcode{WebSocketOpcode::kContinuation}; + bool m_FragmentInProgress = false; + // Write state RwLock m_WriteLock; std::deque<std::vector<uint8_t>> m_WriteQueue; diff --git a/src/zenhttp/servers/wstest.cpp b/src/zenhttp/servers/wstest.cpp index a58037fec..d375aeb12 100644 --- a/src/zenhttp/servers/wstest.cpp +++ b/src/zenhttp/servers/wstest.cpp @@ -239,6 +239,97 @@ TEST_CASE("websocket.framecodec") CHECK_EQ(Result.Payload.size(), 70000u); } + SUBCASE("TryParseFrame - declared payload exceeding codec cap is rejected") + { + // Hand-build a masked Binary frame header that declares a 128 MB payload + // via the 64-bit extended length. The codec should reject this as a + // protocol error without trying to allocate or wait for the bytes. + std::vector<uint8_t> Frame; + Frame.push_back(0x82); // FIN + binary + Frame.push_back(0x80 | 127); // MASK + 64-bit extended length + uint64_t DeclaredLen = 128ull * 1024 * 1024; // 128 MB, well above the 4 MB cap + for (int i = 7; i >= 0; --i) + { + Frame.push_back(static_cast<uint8_t>((DeclaredLen >> (i * 8)) & 0xFF)); + } + uint8_t MaskKey[4] = {0x11, 0x22, 0x33, 0x44}; + Frame.insert(Frame.end(), MaskKey, MaskKey + 4); + // Do not include any payload bytes — the cap should trip before any are needed + + WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size(), /*RequireMask=*/true); + CHECK_EQ(Result.Status, WsFrameParseStatus::kProtocolError); + } + + SUBCASE("TryParseFrame - fragmented control frame rejected") + { + // Build a Close frame by hand with FIN=0 (fragmented — illegal for control frames) + std::vector<uint8_t> Frame; + Frame.push_back(static_cast<uint8_t>(WebSocketOpcode::kClose)); // no FIN bit + Frame.push_back(0x80 | 2); // MASK + 2-byte payload + uint8_t MaskKey[4] = {0x01, 0x02, 0x03, 0x04}; + Frame.insert(Frame.end(), MaskKey, MaskKey + 4); + Frame.push_back(0x03 ^ MaskKey[0]); // close code high byte (1000 = 0x03E8) + Frame.push_back(0xE8 ^ MaskKey[1]); // close code low byte + + WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size(), /*RequireMask=*/true); + CHECK_EQ(Result.Status, WsFrameParseStatus::kProtocolError); + } + + SUBCASE("TryParseFrame - oversized control payload rejected") + { + // Build a Ping with 126-byte payload (exceeds the 125-byte control-frame limit). + // Payload length = 126 forces the extended length encoding, which is itself + // illegal for a control frame — both conditions fail the spec. + std::vector<uint8_t> Payload(126, 0xAA); + std::vector<uint8_t> Frame = WsFrameCodec::BuildMaskedFrame(WebSocketOpcode::kPing, Payload); + + WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size(), /*RequireMask=*/true); + CHECK_EQ(Result.Status, WsFrameParseStatus::kProtocolError); + } + + SUBCASE("TryParseFrame - non-zero RSV bits rejected") + { + // Start from a valid masked text frame and flip RSV1 on byte 0 + std::string_view Text = "rsv"; + std::span<const uint8_t> Payload(reinterpret_cast<const uint8_t*>(Text.data()), Text.size()); + std::vector<uint8_t> Frame = WsFrameCodec::BuildMaskedFrame(WebSocketOpcode::kText, Payload); + Frame[0] |= 0x40; // set RSV1 + + WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size(), /*RequireMask=*/true); + CHECK_EQ(Result.Status, WsFrameParseStatus::kProtocolError); + } + + SUBCASE("TryParseFrame - unmasked client frame rejected when RequireMask=true") + { + // Build a server-style (unmasked) text frame and feed it to the codec + // with RequireMask=true. RFC 6455 section 5.1 requires the server to + // reject unmasked client frames. + std::string_view Text = "unmasked"; + std::span<const uint8_t> Payload(reinterpret_cast<const uint8_t*>(Text.data()), Text.size()); + std::vector<uint8_t> Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kText, Payload); + + WsFrameParseResult Strict = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size(), /*RequireMask=*/true); + CHECK_EQ(Strict.Status, WsFrameParseStatus::kProtocolError); + CHECK_FALSE(Strict.IsValid); + + // Same bytes with RequireMask=false should still parse successfully + WsFrameParseResult Lenient = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size(), /*RequireMask=*/false); + CHECK_EQ(Lenient.Status, WsFrameParseStatus::kValid); + CHECK(Lenient.IsValid); + } + + SUBCASE("TryParseFrame - masked client frame accepted when RequireMask=true") + { + std::string_view Text = "masked"; + std::span<const uint8_t> Payload(reinterpret_cast<const uint8_t*>(Text.data()), Text.size()); + std::vector<uint8_t> Frame = WsFrameCodec::BuildMaskedFrame(WebSocketOpcode::kText, Payload); + + WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size(), /*RequireMask=*/true); + CHECK_EQ(Result.Status, WsFrameParseStatus::kValid); + CHECK(Result.IsValid); + CHECK_EQ(Result.Payload.size(), Text.size()); + } + SUBCASE("BuildMaskedCloseFrame roundtrip") { std::vector<uint8_t> Frame = WsFrameCodec::BuildMaskedCloseFrame(1000, "normal closure"); @@ -309,6 +400,46 @@ namespace { return Frame; } + /** + * Build a masked frame with an explicit FIN bit (for fragmentation tests) + */ + std::vector<uint8_t> BuildMaskedFrameEx(WebSocketOpcode Opcode, std::span<const uint8_t> Payload, bool Fin) + { + std::vector<uint8_t> Frame; + + // FIN-or-not + opcode + Frame.push_back((Fin ? 0x80 : 0x00) | static_cast<uint8_t>(Opcode)); + + if (Payload.size() < 126) + { + Frame.push_back(0x80 | static_cast<uint8_t>(Payload.size())); + } + else if (Payload.size() <= 0xFFFF) + { + Frame.push_back(0x80 | 126); + Frame.push_back(static_cast<uint8_t>((Payload.size() >> 8) & 0xFF)); + Frame.push_back(static_cast<uint8_t>(Payload.size() & 0xFF)); + } + else + { + Frame.push_back(0x80 | 127); + for (int i = 7; i >= 0; --i) + { + Frame.push_back(static_cast<uint8_t>((Payload.size() >> (i * 8)) & 0xFF)); + } + } + + uint8_t MaskKey[4] = {0x12, 0x34, 0x56, 0x78}; + Frame.insert(Frame.end(), MaskKey, MaskKey + 4); + + for (size_t i = 0; i < Payload.size(); ++i) + { + Frame.push_back(Payload[i] ^ MaskKey[i & 3]); + } + + return Frame; + } + std::vector<uint8_t> BuildMaskedTextFrame(std::string_view Text) { std::span<const uint8_t> Payload(reinterpret_cast<const uint8_t*>(Text.data()), Text.size()); @@ -744,6 +875,191 @@ TEST_CASE("websocket.integration") Sock.close(); } + SUBCASE("fragmented message coalesced into single OnMessage") + { + asio::io_context IoCtx; + asio::ip::tcp::socket Sock(IoCtx); + Sock.connect(asio::ip::tcp::endpoint(asio::ip::make_address("127.0.0.1"), static_cast<uint16_t>(Port))); + + bool Ok = DoWebSocketHandshake(Sock, "/wstest/ws", Port); + REQUIRE(Ok); + Sleep(50); + + // Split "hello world" into three fragments: first is Text/FIN=0, + // two continuations (FIN=0, then FIN=1). + std::string_view Part1 = "hello "; + std::string_view Part2 = "wor"; + std::string_view Part3 = "ld"; + + std::vector<uint8_t> F1 = BuildMaskedFrameEx(WebSocketOpcode::kText, + std::span<const uint8_t>(reinterpret_cast<const uint8_t*>(Part1.data()), Part1.size()), + /*Fin=*/false); + std::vector<uint8_t> F2 = BuildMaskedFrameEx(WebSocketOpcode::kContinuation, + std::span<const uint8_t>(reinterpret_cast<const uint8_t*>(Part2.data()), Part2.size()), + /*Fin=*/false); + std::vector<uint8_t> F3 = BuildMaskedFrameEx(WebSocketOpcode::kContinuation, + std::span<const uint8_t>(reinterpret_cast<const uint8_t*>(Part3.data()), Part3.size()), + /*Fin=*/true); + + asio::write(Sock, asio::buffer(F1)); + asio::write(Sock, asio::buffer(F2)); + asio::write(Sock, asio::buffer(F3)); + + // Server should dispatch a single coalesced message and echo it back. + WsFrameParseResult Reply = ReadOneFrame(Sock); + REQUIRE(Reply.IsValid); + CHECK_EQ(Reply.Opcode, WebSocketOpcode::kText); + std::string_view ReplyText(reinterpret_cast<const char*>(Reply.Payload.data()), Reply.Payload.size()); + CHECK_EQ(ReplyText, "hello world"sv); + CHECK_EQ(TestService.m_MessageCount.load(), 1); + CHECK_EQ(TestService.m_LastMessage, "hello world"); + + Sock.close(); + } + + SUBCASE("unexpected continuation frame closes connection with 1002") + { + asio::io_context IoCtx; + asio::ip::tcp::socket Sock(IoCtx); + Sock.connect(asio::ip::tcp::endpoint(asio::ip::make_address("127.0.0.1"), static_cast<uint16_t>(Port))); + + bool Ok = DoWebSocketHandshake(Sock, "/wstest/ws", Port); + REQUIRE(Ok); + Sleep(50); + + // Continuation frame with no in-progress message → protocol error + std::string_view Part = "orphan"; + std::vector<uint8_t> F = BuildMaskedFrameEx(WebSocketOpcode::kContinuation, + std::span<const uint8_t>(reinterpret_cast<const uint8_t*>(Part.data()), Part.size()), + /*Fin=*/true); + asio::write(Sock, asio::buffer(F)); + + // Server should send a close frame with 1002 + WsFrameParseResult Reply = ReadOneFrame(Sock); + REQUIRE(Reply.IsValid); + CHECK_EQ(Reply.Opcode, WebSocketOpcode::kClose); + REQUIRE(Reply.Payload.size() >= 2); + uint16_t Code = (uint16_t(Reply.Payload[0]) << 8) | uint16_t(Reply.Payload[1]); + CHECK_EQ(Code, 1002); + + Sock.close(); + } + + SUBCASE("interleaved text frame mid-fragment closes connection with 1002") + { + asio::io_context IoCtx; + asio::ip::tcp::socket Sock(IoCtx); + Sock.connect(asio::ip::tcp::endpoint(asio::ip::make_address("127.0.0.1"), static_cast<uint16_t>(Port))); + + bool Ok = DoWebSocketHandshake(Sock, "/wstest/ws", Port); + REQUIRE(Ok); + Sleep(50); + + // Start a fragmented Text message, then send another Text frame without + // finishing the first — RFC 6455 §5.4 says this is illegal. + std::string_view Part1 = "open"; + std::vector<uint8_t> F1 = BuildMaskedFrameEx(WebSocketOpcode::kText, + std::span<const uint8_t>(reinterpret_cast<const uint8_t*>(Part1.data()), Part1.size()), + /*Fin=*/false); + std::string_view Part2 = "interrupt"; + std::vector<uint8_t> F2 = BuildMaskedFrameEx(WebSocketOpcode::kText, + std::span<const uint8_t>(reinterpret_cast<const uint8_t*>(Part2.data()), Part2.size()), + /*Fin=*/true); + + asio::write(Sock, asio::buffer(F1)); + asio::write(Sock, asio::buffer(F2)); + + WsFrameParseResult Reply = ReadOneFrame(Sock); + REQUIRE(Reply.IsValid); + CHECK_EQ(Reply.Opcode, WebSocketOpcode::kClose); + REQUIRE(Reply.Payload.size() >= 2); + uint16_t Code = (uint16_t(Reply.Payload[0]) << 8) | uint16_t(Reply.Payload[1]); + CHECK_EQ(Code, 1002); + + Sock.close(); + } + + SUBCASE("ping interleaved with fragments is allowed") + { + asio::io_context IoCtx; + asio::ip::tcp::socket Sock(IoCtx); + Sock.connect(asio::ip::tcp::endpoint(asio::ip::make_address("127.0.0.1"), static_cast<uint16_t>(Port))); + + bool Ok = DoWebSocketHandshake(Sock, "/wstest/ws", Port); + REQUIRE(Ok); + Sleep(50); + + // Send: Text frame FIN=0, then Ping (control), then Continuation FIN=1. + // The control frame must be dispatched immediately and not affect + // fragment state; the final continuation should coalesce and echo. + std::string_view Part1 = "ab"; + std::vector<uint8_t> F1 = BuildMaskedFrameEx(WebSocketOpcode::kText, + std::span<const uint8_t>(reinterpret_cast<const uint8_t*>(Part1.data()), Part1.size()), + /*Fin=*/false); + std::string_view PingData = "p"; + std::vector<uint8_t> PingFrame = + BuildMaskedFrame(WebSocketOpcode::kPing, + std::span<const uint8_t>(reinterpret_cast<const uint8_t*>(PingData.data()), PingData.size())); + std::string_view Part2 = "cd"; + std::vector<uint8_t> F2 = BuildMaskedFrameEx(WebSocketOpcode::kContinuation, + std::span<const uint8_t>(reinterpret_cast<const uint8_t*>(Part2.data()), Part2.size()), + /*Fin=*/true); + + asio::write(Sock, asio::buffer(F1)); + asio::write(Sock, asio::buffer(PingFrame)); + asio::write(Sock, asio::buffer(F2)); + + // We expect two server-to-client frames in sequence: Pong (for the + // interleaved Ping) then Text (the coalesced echo). Read bytes into + // a growing buffer and parse both frames from the same buffer so + // ReadOneFrame's one-shot semantics don't lose the trailing bytes + // when both frames arrive in a single TCP segment. + std::vector<uint8_t> Buffer; + WsFrameParseResult Pong; + WsFrameParseResult Echo; + auto Deadline = std::chrono::steady_clock::now() + 5s; + while (std::chrono::steady_clock::now() < Deadline) + { + uint8_t Tmp[4096]; + asio::error_code Ec; + size_t BytesRead = Sock.read_some(asio::buffer(Tmp), Ec); + if (Ec || BytesRead == 0) + { + break; + } + Buffer.insert(Buffer.end(), Tmp, Tmp + BytesRead); + + if (!Pong.IsValid) + { + WsFrameParseResult F = WsFrameCodec::TryParseFrame(Buffer.data(), Buffer.size()); + if (F.IsValid) + { + Pong = std::move(F); + Buffer.erase(Buffer.begin(), Buffer.begin() + Pong.BytesConsumed); + } + } + if (Pong.IsValid) + { + WsFrameParseResult F = WsFrameCodec::TryParseFrame(Buffer.data(), Buffer.size()); + if (F.IsValid) + { + Echo = std::move(F); + break; + } + } + } + + REQUIRE(Pong.IsValid); + CHECK_EQ(Pong.Opcode, WebSocketOpcode::kPong); + + REQUIRE(Echo.IsValid); + CHECK_EQ(Echo.Opcode, WebSocketOpcode::kText); + std::string_view EchoText(reinterpret_cast<const char*>(Echo.Payload.data()), Echo.Payload.size()); + CHECK_EQ(EchoText, "abcd"sv); + + Sock.close(); + } + SUBCASE("multiple messages in sequence") { asio::io_context IoCtx; @@ -877,6 +1193,57 @@ TEST_CASE("websocket.client") CHECK_FALSE(Client.IsOpen()); } + SUBCASE("rejects non-101 response that contains '101' in a header value") + { + // Stand up a tiny TCP listener that replies to any connection with a + // 404 response whose headers happen to contain "101". An earlier, + // permissive client would accept this as a successful WebSocket + // upgrade; the hardened client must reject it. + asio::io_context FakeCtx; + asio::ip::tcp::acceptor Acceptor(FakeCtx, asio::ip::tcp::endpoint(asio::ip::make_address("127.0.0.1"), 0)); + uint16_t FakePort = Acceptor.local_endpoint().port(); + + std::thread FakeServer([&] { + asio::ip::tcp::socket S(FakeCtx); + asio::error_code Ec; + Acceptor.accept(S, Ec); + if (Ec) + { + return; + } + asio::streambuf Req; + asio::read_until(S, Req, "\r\n\r\n", Ec); + std::string_view Resp = + "HTTP/1.1 404 Not Found\r\n" + "X-Retry-After: 101\r\n" + "Content-Length: 0\r\n" + "\r\n"; + asio::write(S, asio::buffer(Resp.data(), Resp.size()), Ec); + S.close(Ec); + }); + auto FakeGuard = MakeGuard([&] { + Acceptor.close(); + if (FakeServer.joinable()) + { + FakeServer.join(); + } + }); + + TestWsClientHandler Handler; + std::string FakeUrl = fmt::format("ws://127.0.0.1:{}/any", FakePort); + HttpWsClient FakeClient(FakeUrl, Handler, HttpWsClientSettings{.ConnectTimeout = std::chrono::milliseconds(2000)}); + FakeClient.Connect(); + + auto FakeDeadline = std::chrono::steady_clock::now() + 5s; + while (Handler.m_CloseCount.load() == 0 && std::chrono::steady_clock::now() < FakeDeadline) + { + Sleep(10); + } + + CHECK_EQ(Handler.m_OpenCount.load(), 0); + CHECK_EQ(Handler.m_CloseCount.load(), 1); + } + SUBCASE("connect to bad port") { TestWsClientHandler Handler; diff --git a/src/zenhttp/zenhttp.cpp b/src/zenhttp/zenhttp.cpp index 1317f0159..30a50de12 100644 --- a/src/zenhttp/zenhttp.cpp +++ b/src/zenhttp/zenhttp.cpp @@ -12,6 +12,7 @@ namespace zen { +void authmgr_forcelink(); void zipfs_test_forcelink(); void @@ -25,6 +26,7 @@ zenhttp_forcelinktests() forcelink_packageformat(); passwordsecurity_forcelink(); websocket_forcelink(); + authmgr_forcelink(); zipfs_test_forcelink(); } diff --git a/src/zenhttp/zipfs.cpp b/src/zenhttp/zipfs.cpp index c0ffa2052..3bd09091a 100644 --- a/src/zenhttp/zipfs.cpp +++ b/src/zenhttp/zipfs.cpp @@ -100,57 +100,105 @@ namespace { ////////////////////////////////////////////////////////////////////////// ZipFs::ZipFs(IoBuffer&& Buffer) { - MemoryView View = Buffer.GetView(); - - uint8_t* Cursor = (uint8_t*)(View.GetData()) + View.GetSize(); - if (View.GetSize() < sizeof(EocdRecord)) + // Treat the input buffer as attacker-controlled. Every offset, size, and + // trailer length is validated against View.GetSize() before it is used to + // form a pointer. All additions are performed in uint64_t to prevent 32-bit + // wrap. + const MemoryView View = Buffer.GetView(); + const uint8_t* Base = static_cast<const uint8_t*>(View.GetData()); + const size_t Size = View.GetSize(); + + if (Size < sizeof(EocdRecord)) { return; } - const auto* EocdCursor = (EocdRecord*)(Cursor - sizeof(EocdRecord)); + const size_t EocdOffset = Size - sizeof(EocdRecord); + const EocdRecord* Eocd = reinterpret_cast<const EocdRecord*>(Base + EocdOffset); - // It is more correct to search backwards for EocdRecord::Magic as the - // comment can be of a variable length. But here we're not going to support - // zip files with comments. - if (EocdCursor->Signature != EocdRecord::Magic) + // We only support a zip whose EOCD sits at the very end of the buffer — no + // trailing comment, no Zip64. + if (Eocd->Signature != EocdRecord::Magic) + { + return; + } + if (Eocd->ThisDiskIndex == 0xffff) { return; } - // Zip64 isn't supported either - if (EocdCursor->ThisDiskIndex == 0xffff) + const uint32_t CdOffsetRel = Eocd->CdOffset; + const uint32_t CdSize = Eocd->CdSize; + const uint16_t CdRecordCount = Eocd->CdRecordCount; + + // Central directory must fit strictly before the EOCD. Derive the archive + // origin from the EOCD's declared layout so any pre-zip padding in the + // buffer is accounted for; LFH offsets are relative to this origin. + if (uint64_t(CdOffsetRel) + uint64_t(CdSize) > EocdOffset) { return; } - Cursor = (uint8_t*)EocdCursor - uint32_t(EocdCursor->CdOffset) - uint32_t(EocdCursor->CdSize); + const uint8_t* ArchiveStart = Base + (EocdOffset - CdOffsetRel - CdSize); + const uint8_t* CdCursor = ArchiveStart + CdOffsetRel; + const uint8_t* CdEnd = CdCursor + CdSize; - const auto* CdCursor = (CentralDirectoryRecord*)(Cursor + EocdCursor->CdOffset); - for (int i = 0, n = EocdCursor->CdRecordCount; i < n; ++i) + for (uint32_t Record = 0; Record < CdRecordCount; ++Record) { - const CentralDirectoryRecord& Cd = *CdCursor; + if (size_t(CdEnd - CdCursor) < sizeof(CentralDirectoryRecord)) + { + return; + } + const CentralDirectoryRecord& Cd = *reinterpret_cast<const CentralDirectoryRecord*>(CdCursor); + if (Cd.Signature != CentralDirectoryRecord::Magic) + { + return; + } + + const uint16_t NameLen = Cd.FileNameLength; + const uint16_t ExtraLen = Cd.ExtraFieldLength; + const uint16_t CommentLen = Cd.CommentLength; + const uint32_t Trailer = uint32_t(NameLen) + uint32_t(ExtraLen) + uint32_t(CommentLen); + if (size_t(CdEnd - CdCursor) - sizeof(CentralDirectoryRecord) < Trailer) + { + return; + } + + const uint16_t Compression = Cd.CompressionMethod; + const uint32_t Compressed = Cd.CompressedSize; + const uint32_t Original = Cd.OriginalSize; + const uint32_t LfhOffset = Cd.Offset; - bool Acceptable = true; - Acceptable &= (Cd.OriginalSize > 0); // has some content - Acceptable &= (Cd.CompressionMethod == 0 || Cd.CompressionMethod == 8); // stored or deflate - if (Acceptable) + const bool AcceptableCompression = (Compression == 0) || (Compression == 8); + const bool HasContent = Original > 0; + if (AcceptableCompression && HasContent) { - const uint8_t* Lfh = Cursor + Cd.Offset; - if (uintptr_t(Lfh - Cursor) < View.GetSize()) + // LFH header must fit inside the [ArchiveStart, CdOffsetRel) region + // (i.e. the pre-CD body). The LFH's own name + extra + compressed + // payload must also fit in that region. + const uint64_t LfhEndRel = uint64_t(LfhOffset) + sizeof(LocalFileHeader); + if (LfhEndRel <= CdOffsetRel) { - std::string_view FileName(Cd.FileName, Cd.FileNameLength); - FileItem Item; - Item.View = MemoryView{Lfh, size_t(0)}; - Item.CompressionMethod = Cd.CompressionMethod; - Item.CompressedSize = Cd.CompressedSize; - Item.UncompressedSize = Cd.OriginalSize; - m_Files.insert(std::make_pair(FileName, std::move(Item))); + const LocalFileHeader* Lfh = reinterpret_cast<const LocalFileHeader*>(ArchiveStart + LfhOffset); + const uint64_t DataStartRel = + LfhEndRel + uint64_t(uint16_t(Lfh->FileNameLength)) + uint64_t(uint16_t(Lfh->ExtraFieldLength)); + const uint64_t DataEndRel = DataStartRel + uint64_t(Compressed); + if (DataEndRel <= CdOffsetRel) + { + const uint8_t* FileData = ArchiveStart + DataStartRel; + std::string_view FileName(Cd.FileName, NameLen); + + FileItem Item; + Item.View = MemoryView{FileData, size_t(0)}; + Item.CompressionMethod = Compression; + Item.CompressedSize = Compressed; + Item.UncompressedSize = Original; + m_Files.insert({FileName, std::move(Item)}); + } } } - uint32_t ExtraBytes = Cd.FileNameLength + Cd.ExtraFieldLength + Cd.CommentLength; - CdCursor = (CentralDirectoryRecord*)(Cd.FileName + ExtraBytes); + CdCursor += sizeof(CentralDirectoryRecord) + Trailer; } m_Buffer = std::move(Buffer); @@ -184,8 +232,7 @@ ZipFs::GetFile(const std::string_view& FileName) const return IoBuffer(IoBuffer::Wrap, Item.View.GetData(), Item.View.GetSize()); } - const auto* Lfh = (LocalFileHeader*)(Item.View.GetData()); - const uint8_t* FileData = (const uint8_t*)(Lfh->FileName + Lfh->FileNameLength + Lfh->ExtraFieldLength); + const uint8_t* FileData = static_cast<const uint8_t*>(Item.View.GetData()); if (Item.CompressionMethod == 0) { diff --git a/src/zenhttp/zipfs_test.cpp b/src/zenhttp/zipfs_test.cpp index b3a45c408..0ccca7ce2 100644 --- a/src/zenhttp/zipfs_test.cpp +++ b/src/zenhttp/zipfs_test.cpp @@ -216,6 +216,150 @@ TEST_CASE("zipfs.not_found") CHECK(!Result); } +////////////////////////////////////////////////////////////////////////// +// Malformed / attacker-shaped inputs — the parser must refuse these rather +// than read out of bounds. Field offsets below mirror the in-memory layout: +// +// EocdRecord: CentralDirectoryRecord: +// +0 Signature (4) +0 Signature (4) +// +4 ThisDiskIndex (2) +20 CompressedSize (4) +// +6 CdStartDiskIndex (2) +28 FileNameLength (2) +// +8 CdRecordsThis (2) +42 Offset (4) +// +10 CdRecords (2) +// +12 CdSize (4) +// +16 CdOffset (4) +// +20 CommentSize (2) = 46 bytes header +// = 22 bytes + +namespace { +constexpr size_t kEocdSize = 22; +constexpr size_t kEocdOffCdSz = 12; +constexpr size_t kEocdOffCdOff = 16; +constexpr size_t kCdrOffCompSz = 20; +constexpr size_t kCdrOffNameLn = 28; +constexpr size_t kCdrOffLfhOff = 42; + +template<typename T> +void +WriteAt(zen::IoBuffer& Buf, size_t Offset, T Value) +{ + std::memcpy(static_cast<uint8_t*>(Buf.GetMutableView().GetData()) + Offset, &Value, sizeof(Value)); +} + +template<typename T> +T +ReadAt(const zen::IoBuffer& Buf, size_t Offset) +{ + T Out; + std::memcpy(&Out, static_cast<const uint8_t*>(Buf.GetView().GetData()) + Offset, sizeof(Out)); + return Out; +} + +size_t +EocdOffset(const zen::IoBuffer& Buf) +{ + return Buf.GetView().GetSize() - kEocdSize; +} + +size_t +CdOffset(const zen::IoBuffer& Buf) +{ + return ReadAt<uint32_t>(Buf, EocdOffset(Buf) + kEocdOffCdOff); +} +} // namespace + +TEST_CASE("zipfs.buffer_smaller_than_eocd") +{ + zen::IoBuffer Tiny(10); + std::memset(Tiny.GetMutableView().GetData(), 0, Tiny.GetView().GetSize()); + zen::ZipFs Fs(std::move(Tiny)); + CHECK(!Fs.GetFile("anything")); +} + +TEST_CASE("zipfs.bad_eocd_magic") +{ + ZipBuilder Zip; + Zip.AddFile("test.txt", "x", 1, false); + zen::IoBuffer Buf = Zip.Build(); + + WriteAt<uint32_t>(Buf, EocdOffset(Buf), 0xdeadbeef); + + zen::ZipFs Fs(std::move(Buf)); + CHECK(!Fs.GetFile("test.txt")); +} + +TEST_CASE("zipfs.cd_offset_past_eocd") +{ + ZipBuilder Zip; + Zip.AddFile("test.txt", "x", 1, false); + zen::IoBuffer Buf = Zip.Build(); + + WriteAt<uint32_t>(Buf, EocdOffset(Buf) + kEocdOffCdOff, 0xFFFF0000u); + + zen::ZipFs Fs(std::move(Buf)); + CHECK(!Fs.GetFile("test.txt")); +} + +TEST_CASE("zipfs.cd_size_past_eocd") +{ + ZipBuilder Zip; + Zip.AddFile("test.txt", "x", 1, false); + zen::IoBuffer Buf = Zip.Build(); + + WriteAt<uint32_t>(Buf, EocdOffset(Buf) + kEocdOffCdSz, 0xFFFF0000u); + + zen::ZipFs Fs(std::move(Buf)); + CHECK(!Fs.GetFile("test.txt")); +} + +TEST_CASE("zipfs.bad_cd_record_signature") +{ + ZipBuilder Zip; + Zip.AddFile("test.txt", "x", 1, false); + zen::IoBuffer Buf = Zip.Build(); + + WriteAt<uint32_t>(Buf, CdOffset(Buf), 0xdeadbeef); + + zen::ZipFs Fs(std::move(Buf)); + CHECK(!Fs.GetFile("test.txt")); +} + +TEST_CASE("zipfs.oversize_filename_length") +{ + ZipBuilder Zip; + Zip.AddFile("test.txt", "x", 1, false); + zen::IoBuffer Buf = Zip.Build(); + + WriteAt<uint16_t>(Buf, CdOffset(Buf) + kCdrOffNameLn, 0xFFFF); + + zen::ZipFs Fs(std::move(Buf)); + CHECK(!Fs.GetFile("test.txt")); +} + +TEST_CASE("zipfs.lfh_offset_past_buffer") +{ + ZipBuilder Zip; + Zip.AddFile("test.txt", "x", 1, false); + zen::IoBuffer Buf = Zip.Build(); + + WriteAt<uint32_t>(Buf, CdOffset(Buf) + kCdrOffLfhOff, 0xFFFFFFFF); + + zen::ZipFs Fs(std::move(Buf)); + CHECK(!Fs.GetFile("test.txt")); +} + +TEST_CASE("zipfs.compressed_size_past_buffer") +{ + ZipBuilder Zip; + Zip.AddFile("test.txt", "x", 1, false); + zen::IoBuffer Buf = Zip.Build(); + + WriteAt<uint32_t>(Buf, CdOffset(Buf) + kCdrOffCompSz, 0xFFFF0000u); + + zen::ZipFs Fs(std::move(Buf)); + CHECK(!Fs.GetFile("test.txt")); +} + TEST_SUITE_END(); #endif // ZEN_WITH_TESTS diff --git a/src/zenserver/frontend/frontend.cpp b/src/zenserver/frontend/frontend.cpp index 812536074..c7c2b0023 100644 --- a/src/zenserver/frontend/frontend.cpp +++ b/src/zenserver/frontend/frontend.cpp @@ -143,12 +143,6 @@ HttpFrontendService::HandleRequest(zen::HttpServerRequest& Request) Uri = UriBuilder; } - // Dismiss if the URI contains .. anywhere to prevent arbitrary file reads - if (Uri.find("..") != Uri.npos) - { - return Request.WriteResponse(HttpResponseCode::Forbidden); - } - // Map the file extension to a MIME type. To keep things constrained, only a // small subset of file extensions is allowed @@ -184,28 +178,32 @@ HttpFrontendService::HandleRequest(zen::HttpServerRequest& Request) constexpr std::string_view DataPrefix = "data/"; if (!m_DocsDirectory.empty() && InUri.starts_with(DataPrefix)) { - std::string_view DocsRelative = InUri.substr(DataPrefix.size()); - auto FullPath = m_DocsDirectory / std::filesystem::path(DocsRelative).make_preferred(); - FileContents File = ReadFile(FullPath); - - if (!File.ErrorCode) + const std::string_view DocsRelative = InUri.substr(DataPrefix.size()); + if (std::optional<std::filesystem::path> FullPath = ResolveSafeRelativePath(m_DocsDirectory, DocsRelative)) { - Request.WriteResponse(ResponseCode, ContentType, File.Data[0]); - return true; + FileContents File = ReadFile(*FullPath); + + if (!File.ErrorCode) + { + Request.WriteResponse(ResponseCode, ContentType, File.Data[0]); + return true; + } } } // The given content directory overrides any zip-fs discovered in the binary if (!m_Directory.empty()) { - auto FullPath = m_Directory / std::filesystem::path(InUri).make_preferred(); - FileContents File = ReadFile(FullPath); - - if (!File.ErrorCode) + if (std::optional<std::filesystem::path> FullPath = ResolveSafeRelativePath(m_Directory, InUri)) { - Request.WriteResponse(ResponseCode, ContentType, File.Data[0]); + FileContents File = ReadFile(*FullPath); - return true; + if (!File.ErrorCode) + { + Request.WriteResponse(ResponseCode, ContentType, File.Data[0]); + + return true; + } } } diff --git a/src/zenserver/sessions/httpsessions.cpp b/src/zenserver/sessions/httpsessions.cpp index 2276cb81a..88db36828 100644 --- a/src/zenserver/sessions/httpsessions.cpp +++ b/src/zenserver/sessions/httpsessions.cpp @@ -5,6 +5,7 @@ #include <zencore/compactbinarybuilder.h> #include <zencore/fmtutils.h> #include <zencore/logging.h> +#include <zencore/string.h> #include <zencore/trace.h> #include "sessions.h" @@ -470,9 +471,14 @@ HttpSessionsService::SessionLogRequest(HttpRouterRequest& Req) std::string_view CursorStr = Params.GetValue("cursor"sv); if (!CursorStr.empty()) { - uint64_t AfterCursor = std::strtoull(std::string(CursorStr).c_str(), nullptr, 10); + const std::optional<uint64_t> AfterCursor = ParseInt<uint64_t>(CursorStr); + if (!AfterCursor) + { + m_SessionsStats.BadRequestCount++; + return ServerRequest.WriteResponse(HttpResponseCode::BadRequest, HttpContentType::kText, "Invalid 'cursor' parameter"sv); + } - SessionsService::Session::CursorResult Result = Session->GetLogEntriesAfter(AfterCursor); + SessionsService::Session::CursorResult Result = Session->GetLogEntriesAfter(*AfterCursor); CbObjectWriter Response; Response << "cursor" << Result.Cursor; @@ -495,11 +501,23 @@ HttpSessionsService::SessionLogRequest(HttpRouterRequest& Req) if (std::string_view LimitStr = Params.GetValue("limit"sv); !LimitStr.empty()) { - Limit = uint32_t(std::strtoul(std::string(LimitStr).c_str(), nullptr, 10)); + const std::optional<uint32_t> Parsed = ParseInt<uint32_t>(LimitStr); + if (!Parsed) + { + m_SessionsStats.BadRequestCount++; + return ServerRequest.WriteResponse(HttpResponseCode::BadRequest, HttpContentType::kText, "Invalid 'limit' parameter"sv); + } + Limit = *Parsed; } if (std::string_view OffsetStr = Params.GetValue("offset"sv); !OffsetStr.empty()) { - Offset = uint32_t(std::strtoul(std::string(OffsetStr).c_str(), nullptr, 10)); + const std::optional<uint32_t> Parsed = ParseInt<uint32_t>(OffsetStr); + if (!Parsed) + { + m_SessionsStats.BadRequestCount++; + return ServerRequest.WriteResponse(HttpResponseCode::BadRequest, HttpContentType::kText, "Invalid 'offset' parameter"sv); + } + Offset = *Parsed; } std::vector<SessionsService::LogEntry> Entries = Session->GetLogEntries(Limit, Offset); diff --git a/src/zenserver/storage/admin/admin.cpp b/src/zenserver/storage/admin/admin.cpp index 34d9e570e..1de5f74fe 100644 --- a/src/zenserver/storage/admin/admin.cpp +++ b/src/zenserver/storage/admin/admin.cpp @@ -26,6 +26,60 @@ namespace zen { +#if ZEN_WITH_TRACE +namespace { + // Accept only loopback destinations for admin-triggered trace streams. Handles + // "localhost", "127.0.0.1", "::1", and bracketed IPv6 ("[::1]"), each optionally + // followed by ":<digits>". Rejects any control characters so the value is also + // safe to log. + bool IsLoopbackTraceHost(std::string_view Host) + { + if (Host.empty()) + { + return false; + } + for (char C : Host) + { + if (static_cast<unsigned char>(C) < 0x20 || C == 0x7F) + { + return false; + } + } + + std::string_view HostOnly = Host; + if (HostOnly.front() == '[') + { + const size_t Close = HostOnly.find(']'); + if (Close == std::string_view::npos) + { + return false; + } + const std::string_view Tail = HostOnly.substr(Close + 1); + if (!Tail.empty()) + { + if (Tail.front() != ':' || Tail.size() < 2 || Tail.find_first_not_of("0123456789", 1) != std::string_view::npos) + { + return false; + } + } + HostOnly = HostOnly.substr(1, Close - 1); + } + else if (const size_t Colon = HostOnly.find(':'); + Colon != std::string_view::npos && HostOnly.find(':', Colon + 1) == std::string_view::npos) + { + const std::string_view Port = HostOnly.substr(Colon + 1); + if (Port.empty() || Port.find_first_not_of("0123456789") != std::string_view::npos) + { + return false; + } + HostOnly = HostOnly.substr(0, Colon); + } + + return HostOnly == "localhost" || HostOnly == "127.0.0.1" || HostOnly == "::1"; + } +} // namespace +#endif // ZEN_WITH_TRACE + struct DirStats { uint64_t FileCount = 0; @@ -149,17 +203,13 @@ HttpAdminService::HttpAdminService(GcScheduler& Scheduler, [&](HttpRouterRequest& Req) { const auto& JobIdString = Req.GetCapture(1); std::optional<uint64_t> JobIdArg = ParseInt<uint64_t>(JobIdString); - if (!JobIdArg) - { - Req.ServerRequest().WriteResponse(HttpResponseCode::BadRequest); - } - JobId Id{.Id = JobIdArg.value_or(0)}; - if (Id.Id == 0) + if (!JobIdArg || JobIdArg.value() == 0) { return Req.ServerRequest().WriteResponse(HttpResponseCode::BadRequest, ZenContentType::kText, - fmt::format("Invalid Job Id: {}", Id.Id)); + fmt::format("Invalid Job Id: '{}'", JobIdString)); } + const JobId Id{.Id = JobIdArg.value()}; std::optional<JobQueue::JobDetails> CurrentState = m_BackgroundJobQueue.Get(Id); if (!CurrentState) @@ -271,11 +321,13 @@ HttpAdminService::HttpAdminService(GcScheduler& Scheduler, [&](HttpRouterRequest& Req) { const auto& JobIdString = Req.GetCapture(1); std::optional<uint64_t> JobIdArg = ParseInt<uint64_t>(JobIdString); - if (!JobIdArg) + if (!JobIdArg || JobIdArg.value() == 0) { - Req.ServerRequest().WriteResponse(HttpResponseCode::BadRequest); + return Req.ServerRequest().WriteResponse(HttpResponseCode::BadRequest, + ZenContentType::kText, + fmt::format("Invalid Job Id: '{}'", JobIdString)); } - JobId Id{.Id = JobIdArg.value_or(0)}; + const JobId Id{.Id = JobIdArg.value()}; if (m_BackgroundJobQueue.CancelJob(Id)) { Req.ServerRequest().WriteResponse(HttpResponseCode::OK); @@ -610,11 +662,6 @@ HttpAdminService::HttpAdminService(GcScheduler& Scheduler, const HttpServerRequest::QueryParams Params = HttpReq.GetQueryParams(); TraceOptions TraceOptions; - if (!IsTracing()) - { - TraceInit("zenserver"); - } - if (auto Channels = Params.GetValue("channels"); Channels.empty() == false) { TraceOptions.Channels = Channels; @@ -622,22 +669,41 @@ HttpAdminService::HttpAdminService(GcScheduler& Scheduler, if (auto File = Params.GetValue("file"); File.empty() == false) { - TraceOptions.File = File; + const std::filesystem::path TracesRoot = m_ServerOptions.DataDir / "traces"; + std::optional<std::filesystem::path> Resolved = ResolveSafeRelativePath(TracesRoot, File); + if (!Resolved) + { + ZEN_WARN("admin trace/start rejected unsafe 'file' parameter '{}'", File); + return HttpReq.WriteResponse(HttpResponseCode::BadRequest, HttpContentType::kText, "Invalid 'file' parameter"sv); + } + TraceOptions.File = Resolved->string(); } else if (auto Host = Params.GetValue("host"); Host.empty() == false) { + if (!IsLoopbackTraceHost(Host)) + { + ZEN_WARN("admin trace/start rejected non-loopback 'host' parameter '{}'", Host); + return HttpReq.WriteResponse(HttpResponseCode::BadRequest, + HttpContentType::kText, + "Invalid 'host' parameter (must be a loopback address)"sv); + } TraceOptions.Host = Host; } else { - return Req.ServerRequest().WriteResponse(HttpResponseCode::BadRequest, - HttpContentType::kText, - "Invalid trace type, use `file` or `host`"sv); + return HttpReq.WriteResponse(HttpResponseCode::BadRequest, + HttpContentType::kText, + "Invalid trace type, use `file` or `host`"sv); + } + + if (!IsTracing()) + { + TraceInit("zenserver"); } TraceConfigure(TraceOptions); - return Req.ServerRequest().WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "Tracing started"); + return HttpReq.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "Tracing started"); }, HttpVerb::kPost); diff --git a/src/zenserver/storage/buildstore/httpbuildstore.cpp b/src/zenserver/storage/buildstore/httpbuildstore.cpp index f935e2c6b..bdaaef327 100644 --- a/src/zenserver/storage/buildstore/httpbuildstore.cpp +++ b/src/zenserver/storage/buildstore/httpbuildstore.cpp @@ -138,8 +138,7 @@ HttpBuildStoreService::PutBlobRequest(HttpRouterRequest& Req) HttpContentType::kText, fmt::format("Payload blob {} content type {} is invalid", Hash, ToString(Payload.GetContentType()))); } - m_BuildStore.PutBlob(BlobHash, ServerRequest.ReadPayload()); - // ZEN_INFO("Stored blob {}. Size: {}", BlobHash, ServerRequest.ReadPayload().GetSize()); + m_BuildStore.PutBlob(BlobHash, std::move(Payload)); return ServerRequest.WriteResponse(HttpResponseCode::OK); } diff --git a/src/zenserver/storage/cache/httpstructuredcache.cpp b/src/zenserver/storage/cache/httpstructuredcache.cpp index 4d3673e70..7ca496dec 100644 --- a/src/zenserver/storage/cache/httpstructuredcache.cpp +++ b/src/zenserver/storage/cache/httpstructuredcache.cpp @@ -70,6 +70,7 @@ namespace { static constinit std::string_view HttpZCacheUtilStopRecording = "exec$/stop-recording"sv; static constinit std::string_view HttpZCacheUtilReplayRecording = "exec$/replay-recording"sv; static constinit std::string_view HttpZCacheDetailsPrefix = "details$"sv; + static constinit std::string_view HttpZCacheRecordingsDirName = "recordings"sv; } // namespace ////////////////////////////////////////////////////////////////////////// @@ -396,17 +397,26 @@ HttpStructuredCacheService::HandleRequest(HttpServerRequest& Request) { HttpServerRequest::QueryParams Params = Request.GetQueryParams(); - std::string RecordPath = UrlDecode(Params.GetValue("path")); + const std::string RecordPath = UrlDecode(Params.GetValue("path")); + + const std::filesystem::path RecordingsRoot = m_CacheStore.GetBasePath() / HttpZCacheRecordingsDirName; + std::optional<std::filesystem::path> ResolvedPath = ResolveSafeRelativePath(RecordingsRoot, RecordPath); + if (!ResolvedPath) + { + m_CacheStats.BadRequestCount++; + ZEN_WARN("cache RPC start-recording rejected unsafe path '{}'", RecordPath); + return Request.WriteResponse(HttpResponseCode::BadRequest); + } { RwLock::ExclusiveLockScope _(m_RequestRecordingLock); m_RequestRecordingEnabled.store(false); m_RequestRecorder.reset(); - m_RequestRecorder = cache::MakeDiskRequestRecorder(RecordPath); + m_RequestRecorder = cache::MakeDiskRequestRecorder(*ResolvedPath); m_RequestRecordingEnabled.store(true); } - ZEN_INFO("cache RPC recording STARTED -> '{}'", RecordPath); + ZEN_INFO("cache RPC recording STARTED -> '{}'", *ResolvedPath); Request.WriteResponse(HttpResponseCode::OK); return; } @@ -435,7 +445,16 @@ HttpStructuredCacheService::HandleRequest(HttpServerRequest& Request) HttpServerRequest::QueryParams Params = Request.GetQueryParams(); - std::string RecordPath = UrlDecode(Params.GetValue("path")); + const std::string RecordPath = UrlDecode(Params.GetValue("path")); + + const std::filesystem::path RecordingsRoot = m_CacheStore.GetBasePath() / HttpZCacheRecordingsDirName; + std::optional<std::filesystem::path> ResolvedPath = ResolveSafeRelativePath(RecordingsRoot, RecordPath); + if (!ResolvedPath) + { + m_CacheStats.BadRequestCount++; + ZEN_WARN("cache RPC replay-recording rejected unsafe path '{}'", RecordPath); + return Request.WriteResponse(HttpResponseCode::BadRequest); + } const uint32_t HardwareConcurrency = GetHardwareConcurrency(); const uint32_t MaxThreadCount = std::max<uint32_t>(HardwareConcurrency, 16u); @@ -449,9 +468,9 @@ HttpStructuredCacheService::HandleRequest(HttpServerRequest& Request) } ThreadCount = std::clamp<uint32_t>(ThreadCount, 1u, MaxThreadCount); - ZEN_INFO("initiating cache RPC replay using {} threads, from '{}'", ThreadCount, RecordPath); + ZEN_INFO("initiating cache RPC replay using {} threads, from '{}'", ThreadCount, *ResolvedPath); - std::unique_ptr<cache::IRpcRequestReplayer> Replayer(cache::MakeDiskRequestReplayer(RecordPath, false)); + std::unique_ptr<cache::IRpcRequestReplayer> Replayer(cache::MakeDiskRequestReplayer(*ResolvedPath, false)); ReplayRequestRecorder(RequestContext, *Replayer, ThreadCount); ZEN_INFO("cache RPC replay COMPLETED"); diff --git a/src/zenserver/storage/objectstore/objectstore.cpp b/src/zenserver/storage/objectstore/objectstore.cpp index 1115c1cd6..252a381ae 100644 --- a/src/zenserver/storage/objectstore/objectstore.cpp +++ b/src/zenserver/storage/objectstore/objectstore.cpp @@ -29,6 +29,15 @@ using namespace std::literals; ZEN_DEFINE_LOG_CATEGORY_STATIC(LogObj, "obj"sv); +namespace { + // Permitted bucket-name characters. Must stay in sync with the "bucket" URL matcher + // registered in HttpObjectStoreService::Initialize so POST / DELETE / PUT with an + // explicit bucketname payload uses the same rule as the route matcher. + constexpr AsciiSet ValidBucketCharactersSet{"abcdefghijklmnopqrstuvwxyz0123456789-_.ABCDEFGHIJKLMNOPQRSTUVWXYZ"}; + + bool IsValidBucketName(std::string_view Name) { return !Name.empty() && AsciiSet::HasOnly(Name, ValidBucketCharactersSet); } +} // namespace + class CbXmlWriter { public: @@ -302,12 +311,10 @@ HttpObjectStoreService::Initialize() } static constexpr AsciiSet ValidPathCharactersSet{"abcdefghijklmnopqrstuvwxyz0123456789/_.,;$~{}+-[]() ABCDEFGHIJKLMNOPQRSTUVWXYZ"}; - static constexpr AsciiSet ValidBucketCharactersSet{"abcdefghijklmnopqrstuvwxyz0123456789-_.ABCDEFGHIJKLMNOPQRSTUVWXYZ"}; m_Router.AddMatcher("path", [](std::string_view Str) -> bool { return !Str.empty() && AsciiSet::HasOnly(Str, ValidPathCharactersSet); }); - m_Router.AddMatcher("bucket", - [](std::string_view Str) -> bool { return !Str.empty() && AsciiSet::HasOnly(Str, ValidBucketCharactersSet); }); + m_Router.AddMatcher("bucket", [](std::string_view Str) -> bool { return IsValidBucketName(Str); }); m_Router.RegisterRoute( "", @@ -471,8 +478,9 @@ HttpObjectStoreService::CreateBucket(HttpRouterRequest& Request) const CbObject Params = Request.ServerRequest().ReadPayloadObject(); const std::string_view BucketName = Params["bucketname"].AsString(); - if (BucketName.empty()) + if (!IsValidBucketName(BucketName)) { + ZEN_LOG_WARN(LogObj, "CREATE - rejected invalid bucket name '{}'", BucketName); return Request.ServerRequest().WriteResponse(HttpResponseCode::BadRequest); } @@ -515,9 +523,21 @@ HttpObjectStoreService::ListBucket(HttpRouterRequest& Request, const std::string BucketPrefix.erase(0, BucketPrefix.find_first_not_of('/')); BucketPrefix.erase(0, BucketPrefix.find_first_not_of('\\')); - const fs::path BucketRoot = GetBucketDirectory(BucketName); - const fs::path RelativeBucketPath = fs::path(BucketPrefix).make_preferred(); - const fs::path FullPath = BucketRoot / RelativeBucketPath; + const fs::path BucketRoot = GetBucketDirectory(BucketName); + + fs::path RelativeBucketPath; + fs::path FullPath = BucketRoot; + if (!BucketPrefix.empty()) + { + std::optional<fs::path> Resolved = ResolveSafeRelativePath(BucketRoot, BucketPrefix); + if (!Resolved) + { + ZEN_LOG_WARN(LogObj, "LIST - bucket '{}' rejected unsafe prefix '{}'", BucketName, BucketPrefix); + return Request.ServerRequest().WriteResponse(HttpResponseCode::Forbidden); + } + FullPath = std::move(*Resolved); + RelativeBucketPath = fs::relative(FullPath, BucketRoot); + } struct Visitor : FileSystemTraversal::TreeVisitor { @@ -589,8 +609,9 @@ HttpObjectStoreService::DeleteBucket(HttpRouterRequest& Request) const CbObject Params = Request.ServerRequest().ReadPayloadObject(); const std::string_view BucketName = Params["bucketname"].AsString(); - if (BucketName.empty()) + if (!IsValidBucketName(BucketName)) { + ZEN_LOG_WARN(LogObj, "DELETE - rejected invalid bucket name '{}'", BucketName); return Request.ServerRequest().WriteResponse(HttpResponseCode::BadRequest); } @@ -621,15 +642,14 @@ HttpObjectStoreService::GetObject(HttpRouterRequest& Request, const std::string_ return Request.ServerRequest().WriteResponse(HttpResponseCode::NotFound); } - const fs::path RelativeBucketPath = fs::path(BucketPrefix).make_preferred(); - - if (RelativeBucketPath.is_absolute() || RelativeBucketPath.string().starts_with("..")) + std::optional<fs::path> ResolvedFilePath = ResolveSafeRelativePath(BucketDir, BucketPrefix); + if (!ResolvedFilePath) { - ZEN_LOG_DEBUG(LogObj, "GET - from bucket '{}' [FAILED], invalid file path", BucketName); + ZEN_LOG_WARN(LogObj, "GET - from bucket '{}' rejected unsafe path '{}'", BucketName, BucketPrefix); return Request.ServerRequest().WriteResponse(HttpResponseCode::Forbidden); } - - const fs::path FilePath = BucketDir / RelativeBucketPath; + const fs::path FilePath = std::move(*ResolvedFilePath); + const fs::path RelativeBucketPath = fs::relative(FilePath, BucketDir); if (!IsFile(FilePath)) { ZEN_LOG_DEBUG(LogObj, "GET - '{}/{}' [FAILED], doesn't exist", BucketName, FilePath); @@ -720,16 +740,17 @@ HttpObjectStoreService::PutObject(HttpRouterRequest& Request) return Request.ServerRequest().WriteResponse(HttpResponseCode::NotFound); } - const fs::path RelativeBucketPath = fs::path(Request.GetCapture(2)).make_preferred(); + const std::string_view PathCapture = Request.GetCapture(2); - if (RelativeBucketPath.is_absolute() || RelativeBucketPath.string().starts_with("..")) + std::optional<fs::path> ResolvedFilePath = ResolveSafeRelativePath(BucketDir, PathCapture); + if (!ResolvedFilePath) { - ZEN_LOG_DEBUG(LogObj, "PUT - bucket '{}' [FAILED], invalid file path", BucketName); + ZEN_LOG_WARN(LogObj, "PUT - bucket '{}' rejected unsafe path '{}'", BucketName, PathCapture); return Request.ServerRequest().WriteResponse(HttpResponseCode::Forbidden); } - - const fs::path FilePath = BucketDir / RelativeBucketPath; - const fs::path FileDirectory = FilePath.parent_path(); + const fs::path FilePath = std::move(*ResolvedFilePath); + const fs::path RelativeBucketPath = fs::relative(FilePath, BucketDir); + const fs::path FileDirectory = FilePath.parent_path(); { std::lock_guard _(m_BucketsMutex); diff --git a/src/zenserver/storage/projectstore/httpprojectstore.cpp b/src/zenserver/storage/projectstore/httpprojectstore.cpp index c40690d5f..9c94a0381 100644 --- a/src/zenserver/storage/projectstore/httpprojectstore.cpp +++ b/src/zenserver/storage/projectstore/httpprojectstore.cpp @@ -1000,10 +1000,10 @@ HttpProjectService::HandleChunkBatchRequest(HttpRouterRequest& Req) uint64_t RequestBytes; }; - if (Payload.Size() <= sizeof(RequestHeader)) + if (Payload.Size() < sizeof(RequestHeader)) { m_ProjectStats.BadRequestCount++; - HttpReq.WriteResponse(HttpResponseCode::BadRequest); + return HttpReq.WriteResponse(HttpResponseCode::BadRequest); } RequestHeader RequestHdr; @@ -1012,7 +1012,21 @@ HttpProjectService::HandleChunkBatchRequest(HttpRouterRequest& Req) if (RequestHdr.Magic != RequestHeader::kMagic) { m_ProjectStats.BadRequestCount++; - HttpReq.WriteResponse(HttpResponseCode::BadRequest); + return HttpReq.WriteResponse(HttpResponseCode::BadRequest); + } + + constexpr uint32_t kMaxChunkCount = 1u << 16; + if (RequestHdr.ChunkCount > kMaxChunkCount) + { + m_ProjectStats.BadRequestCount++; + return HttpReq.WriteResponse(HttpResponseCode::BadRequest); + } + + const uint64_t ExpectedChunkBytes = uint64_t(RequestHdr.ChunkCount) * sizeof(RequestChunkEntry); + if (ExpectedChunkBytes > Reader.Remaining()) + { + m_ProjectStats.BadRequestCount++; + return HttpReq.WriteResponse(HttpResponseCode::BadRequest); } std::vector<RequestChunkEntry> RequestedChunks; @@ -1610,9 +1624,14 @@ HttpProjectService::HandleOplogOpNewRequest(HttpRouterRequest& Req) if (std::string_view SaltParam = Params.GetValue("salt"sv); SaltParam.empty() == false) { - const uint32_t Salt = std::stoi(std::string(SaltParam)); - SaltHash = IoHash::HashBuffer(&Salt, sizeof Salt); - IsUsingSalt = true; + const std::optional<uint32_t> Salt = ParseInt<uint32_t>(SaltParam); + if (!Salt) + { + m_ProjectStats.BadRequestCount++; + return HttpReq.WriteResponse(HttpResponseCode::BadRequest, HttpContentType::kText, "Invalid 'salt' parameter"sv); + } + SaltHash = IoHash::HashBuffer(&Salt.value(), sizeof(uint32_t)); + IsUsingSalt = true; } Ref<ProjectStore::Project> Project = m_ProjectStore->OpenProject(ProjectId); diff --git a/src/zenserver/storage/workspaces/httpworkspaces.cpp b/src/zenserver/storage/workspaces/httpworkspaces.cpp index 12e7bae73..ba3bc00dd 100644 --- a/src/zenserver/storage/workspaces/httpworkspaces.cpp +++ b/src/zenserver/storage/workspaces/httpworkspaces.cpp @@ -4,6 +4,7 @@ #include <zencore/basicfile.h> #include <zencore/compactbinarybuilder.h> +#include <zencore/filesystem.h> #include <zencore/fmtutils.h> #include <zencore/logging.h> #include <zencore/trace.h> @@ -29,6 +30,48 @@ namespace { return {}; } + // Validate a workspace root_path supplied via HTTP. Rejects empty / non-absolute + // paths, Windows UNC (\\server\share) and device-namespace prefixes (\\?\, \\.\), + // and strings containing control characters. Canonicalises the result so any later + // joins and stored config anchor at a resolved, existing directory — a follow-up + // symlink swap on disk can no longer redirect the workspace root. + std::optional<std::filesystem::path> ValidateWorkspaceRootPath(std::string_view RawInput) + { + if (RawInput.empty()) + { + return std::nullopt; + } + for (char C : RawInput) + { + if (static_cast<unsigned char>(C) < 0x20 || C == 0x7F) + { + return std::nullopt; + } + } + if (RawInput.starts_with("\\\\") || RawInput.starts_with("//")) + { + return std::nullopt; + } + + std::filesystem::path Requested(RawInput); + if (!Requested.is_absolute()) + { + return std::nullopt; + } + + std::error_code Ec; + std::filesystem::path Canonical = std::filesystem::canonical(Requested, Ec); + if (Ec) + { + return std::nullopt; + } + if (!std::filesystem::is_directory(Canonical, Ec) || Ec) + { + return std::nullopt; + } + return Canonical; + } + void WriteWorkspaceConfig(CbWriter& Writer, const Workspaces::WorkspaceConfiguration& Config) { Writer << "id" << Config.Id; @@ -505,14 +548,17 @@ HttpWorkspacesService::WorkspaceRequest(HttpRouterRequest& Req) { case HttpVerb::kPut: { - std::filesystem::path WorkspacePath = GetPathParameter(ServerRequest, "root_path"sv); - if (WorkspacePath.empty()) + const std::string RawRootPath = HttpServerRequest::Decode(ServerRequest.GetQueryParams().GetValue("root_path"sv)); + std::optional<std::filesystem::path> ValidatedRootPath = ValidateWorkspaceRootPath(RawRootPath); + if (!ValidatedRootPath) { m_WorkspacesStats.BadRequestCount++; + ZEN_WARN("workspace PUT rejected unsafe 'root_path' parameter '{}'", RawRootPath); return ServerRequest.WriteResponse(HttpResponseCode::BadRequest, HttpContentType::kText, "Invalid 'root_path' parameter"); } + std::filesystem::path WorkspacePath = std::move(*ValidatedRootPath); if (Req.GetCapture(1) == Oid::Zero.ToString()) { @@ -1096,6 +1142,16 @@ HttpWorkspacesService::ShareRequest(HttpRouterRequest& Req, const Oid& Workspace fmt::format("Workspace '{}' does not exist", WorkspaceId)); } + std::optional<std::filesystem::path> ResolvedSharePath = ResolveSafeRelativePath(Workspace.RootPath, SharePath.string()); + if (!ResolvedSharePath) + { + m_WorkspacesStats.BadRequestCount++; + ZEN_WARN("share PUT in workspace '{}' rejected unsafe 'share_path' parameter '{}'", WorkspaceId, SharePath); + return ServerRequest.WriteResponse(HttpResponseCode::BadRequest, + HttpContentType::kText, + "Invalid 'share_path' parameter"); + } + if (!Workspace.AllowShareCreationFromHttp) { if (!MayChangeConfiguration(ServerRequest)) @@ -1143,7 +1199,7 @@ HttpWorkspacesService::ShareRequest(HttpRouterRequest& Req, const Oid& Workspace } } - if (!IsDir(Workspace.RootPath / NewConfig.SharePath)) + if (!IsDir(*ResolvedSharePath)) { return ServerRequest.WriteResponse(HttpResponseCode::NotFound, HttpContentType::kText, diff --git a/src/zenstore/include/zenstore/cache/structuredcachestore.h b/src/zenstore/include/zenstore/cache/structuredcachestore.h index 9e0561364..425dd4b79 100644 --- a/src/zenstore/include/zenstore/cache/structuredcachestore.h +++ b/src/zenstore/include/zenstore/cache/structuredcachestore.h @@ -274,6 +274,7 @@ public: Configuration GetConfiguration() const { return m_Configuration; } void SetLoggingConfig(const Configuration::LogConfig& Loggingconfig); + const std::filesystem::path& GetBasePath() const { return m_BasePath; } Info GetInfo() const; std::optional<ZenCacheNamespace::Info> GetNamespaceInfo(std::string_view Namespace); std::optional<ZenCacheNamespace::BucketInfo> GetBucketInfo(std::string_view Namespace, std::string_view Bucket); diff --git a/src/zenutil/include/zenutil/process/subprocessmanager.h b/src/zenutil/include/zenutil/process/subprocessmanager.h index 95d7fa43d..e8b4a700c 100644 --- a/src/zenutil/include/zenutil/process/subprocessmanager.h +++ b/src/zenutil/include/zenutil/process/subprocessmanager.h @@ -156,7 +156,7 @@ private: friend class ProcessGroup; struct Impl; - std::unique_ptr<Impl> m_Impl; + std::shared_ptr<Impl> m_Impl; }; /// A process managed by SubprocessManager. diff --git a/src/zenutil/process/asyncpipereader.cpp b/src/zenutil/process/asyncpipereader.cpp index 8eac350c6..7d603aa3f 100644 --- a/src/zenutil/process/asyncpipereader.cpp +++ b/src/zenutil/process/asyncpipereader.cpp @@ -31,7 +31,7 @@ static constexpr size_t kReadBufferSize = 4096; #if !ZEN_PLATFORM_WINDOWS -struct AsyncPipeReader::Impl +struct AsyncPipeReader::Impl : public TRefCounted<Impl> { asio::io_context& m_IoContext; std::unique_ptr<asio::posix::stream_descriptor> m_Descriptor; @@ -82,22 +82,25 @@ struct AsyncPipeReader::Impl return; } - m_Descriptor->async_read_some(asio::buffer(m_Buffer), [this](const asio::error_code& Ec, size_t BytesRead) { + // Capture Self so the Impl outlives any completion the io_context has + // already picked up but not yet dispatched, even if the owning + // AsyncPipeReader is destroyed in the meantime. + m_Descriptor->async_read_some(asio::buffer(m_Buffer), [Self = Ref<Impl>(this)](const asio::error_code& Ec, size_t BytesRead) { if (Ec) { - if (Ec != asio::error::operation_aborted && m_EofCallback) + if (Ec != asio::error::operation_aborted && Self->m_EofCallback) { - m_EofCallback(); + Self->m_EofCallback(); } return; } - if (BytesRead > 0 && m_DataCallback) + if (BytesRead > 0 && Self->m_DataCallback) { - m_DataCallback(std::string_view(m_Buffer.data(), BytesRead)); + Self->m_DataCallback(std::string_view(Self->m_Buffer.data(), BytesRead)); } - EnqueueRead(); + Self->EnqueueRead(); }); } }; @@ -183,7 +186,7 @@ CreateOverlappedStdoutPipe(StdoutPipeHandles& OutPipe) return true; } -struct AsyncPipeReader::Impl +struct AsyncPipeReader::Impl : public TRefCounted<Impl> { asio::io_context& m_IoContext; std::unique_ptr<asio::windows::stream_handle> m_StreamHandle; @@ -229,22 +232,25 @@ struct AsyncPipeReader::Impl return; } - m_StreamHandle->async_read_some(asio::buffer(m_Buffer), [this](const asio::error_code& Ec, size_t BytesRead) { + // Capture Self so the Impl outlives any completion the io_context has + // already picked up but not yet dispatched, even if the owning + // AsyncPipeReader is destroyed in the meantime. + m_StreamHandle->async_read_some(asio::buffer(m_Buffer), [Self = Ref<Impl>(this)](const asio::error_code& Ec, size_t BytesRead) { if (Ec) { - if (Ec != asio::error::operation_aborted && m_EofCallback) + if (Ec != asio::error::operation_aborted && Self->m_EofCallback) { - m_EofCallback(); + Self->m_EofCallback(); } return; } - if (BytesRead > 0 && m_DataCallback) + if (BytesRead > 0 && Self->m_DataCallback) { - m_DataCallback(std::string_view(m_Buffer.data(), BytesRead)); + Self->m_DataCallback(std::string_view(Self->m_Buffer.data(), BytesRead)); } - EnqueueRead(); + Self->EnqueueRead(); }); } }; @@ -255,11 +261,21 @@ struct AsyncPipeReader::Impl // Common wrapper // ============================================================================ -AsyncPipeReader::AsyncPipeReader(asio::io_context& IoContext) : m_Impl(std::make_unique<Impl>(IoContext)) +AsyncPipeReader::AsyncPipeReader(asio::io_context& IoContext) : m_Impl(new Impl(IoContext)) { } -AsyncPipeReader::~AsyncPipeReader() = default; +AsyncPipeReader::~AsyncPipeReader() +{ + // Explicitly stop pending async reads. The Impl may outlive this call if a + // completion is still in the io_context queue (the handler holds a strong + // Ref back to Impl to keep it alive). Stop() here guarantees reads stop + // even if nobody has called Stop() on the wrapper. + if (m_Impl) + { + m_Impl->Stop(); + } +} void AsyncPipeReader::Start(StdoutPipeHandles&& Pipe, std::function<void(std::string_view)> DataCallback, std::function<void()> EofCallback) diff --git a/src/zenutil/process/asyncpipereader.h b/src/zenutil/process/asyncpipereader.h index ad2ff8455..3e4be6906 100644 --- a/src/zenutil/process/asyncpipereader.h +++ b/src/zenutil/process/asyncpipereader.h @@ -2,11 +2,11 @@ #pragma once +#include <zenbase/refcount.h> #include <zencore/process.h> #include <zencore/zencore.h> #include <functional> -#include <memory> #include <string_view> namespace asio { @@ -56,7 +56,7 @@ public: private: struct Impl; - std::unique_ptr<Impl> m_Impl; + Ref<Impl> m_Impl; }; } // namespace zen diff --git a/src/zenutil/process/exitwatcher.cpp b/src/zenutil/process/exitwatcher.cpp index cef31ebca..2e88dfdeb 100644 --- a/src/zenutil/process/exitwatcher.cpp +++ b/src/zenutil/process/exitwatcher.cpp @@ -38,7 +38,7 @@ namespace zen { #if ZEN_PLATFORM_LINUX -struct ProcessExitWatcher::Impl +struct ProcessExitWatcher::Impl : public TRefCounted<Impl> { asio::io_context& m_IoContext; std::unique_ptr<asio::posix::stream_descriptor> m_Descriptor; @@ -64,8 +64,11 @@ struct ProcessExitWatcher::Impl m_Descriptor = std::make_unique<asio::posix::stream_descriptor>(m_IoContext, m_PidFd); + // Capture a strong Ref so the Impl outlives the pending async_wait even + // if the owning ProcessExitWatcher is destroyed while a completion is + // in flight on the io_context. m_Descriptor->async_wait(asio::posix::stream_descriptor::wait_read, - [this, Callback = std::move(OnExit)](const asio::error_code& Ec) { + [Self = Ref<Impl>(this), Callback = std::move(OnExit)](const asio::error_code& Ec) { if (Ec) { return; // Cancelled or error @@ -74,7 +77,7 @@ struct ProcessExitWatcher::Impl int ExitCode = -1; int Status = 0; // The pidfd told us the process exited. Reap it with waitpid. - if (waitpid(m_Pid, &Status, WNOHANG) > 0) + if (waitpid(Self->m_Pid, &Status, WNOHANG) > 0) { if (WIFEXITED(Status)) { @@ -115,7 +118,7 @@ struct ProcessExitWatcher::Impl #elif ZEN_PLATFORM_WINDOWS -struct ProcessExitWatcher::Impl +struct ProcessExitWatcher::Impl : public TRefCounted<Impl> { asio::io_context& m_IoContext; std::unique_ptr<asio::windows::object_handle> m_ObjectHandle; @@ -147,16 +150,21 @@ struct ProcessExitWatcher::Impl // object_handle takes ownership of the handle m_ObjectHandle = std::make_unique<asio::windows::object_handle>(m_IoContext, m_DuplicatedHandle); - m_ObjectHandle->async_wait([this, DupHandle = m_DuplicatedHandle, Callback = std::move(OnExit)](const asio::error_code& Ec) { - if (Ec) - { - return; - } - - DWORD ExitCode = 0; - GetExitCodeProcess(static_cast<HANDLE>(DupHandle), &ExitCode); - Callback(static_cast<int>(ExitCode)); - }); + // Capture a strong Ref so the duplicated handle (owned by m_ObjectHandle, + // which is owned by *this) cannot be closed out from under a completion + // that the io_context is still about to dispatch. + m_ObjectHandle->async_wait( + [Self = Ref<Impl>(this), DupHandle = m_DuplicatedHandle, Callback = std::move(OnExit)](const asio::error_code& Ec) { + (void)Self; + if (Ec) + { + return; + } + + DWORD ExitCode = 0; + GetExitCodeProcess(static_cast<HANDLE>(DupHandle), &ExitCode); + Callback(static_cast<int>(ExitCode)); + }); } void Cancel() @@ -182,7 +190,7 @@ struct ProcessExitWatcher::Impl #elif ZEN_PLATFORM_MAC -struct ProcessExitWatcher::Impl +struct ProcessExitWatcher::Impl : public TRefCounted<Impl> { asio::io_context& m_IoContext; std::unique_ptr<asio::posix::stream_descriptor> m_Descriptor; @@ -218,8 +226,11 @@ struct ProcessExitWatcher::Impl m_Descriptor = std::make_unique<asio::posix::stream_descriptor>(m_IoContext, m_KqueueFd); + // Capture a strong Ref so the Impl outlives the pending async_wait even + // if the owning ProcessExitWatcher is destroyed while a completion is + // in flight on the io_context. m_Descriptor->async_wait(asio::posix::stream_descriptor::wait_read, - [this, Callback = std::move(OnExit)](const asio::error_code& Ec) { + [Self = Ref<Impl>(this), Callback = std::move(OnExit)](const asio::error_code& Ec) { if (Ec) { return; @@ -228,11 +239,11 @@ struct ProcessExitWatcher::Impl // Drain the kqueue event struct kevent Event; struct timespec Timeout = {0, 0}; - kevent(m_KqueueFd, nullptr, 0, &Event, 1, &Timeout); + kevent(Self->m_KqueueFd, nullptr, 0, &Event, 1, &Timeout); int ExitCode = -1; int Status = 0; - if (waitpid(m_Pid, &Status, WNOHANG) > 0) + if (waitpid(Self->m_Pid, &Status, WNOHANG) > 0) { if (WIFEXITED(Status)) { @@ -273,11 +284,21 @@ struct ProcessExitWatcher::Impl // Common wrapper (delegates to Impl) // ============================================================================ -ProcessExitWatcher::ProcessExitWatcher(asio::io_context& IoContext) : m_Impl(std::make_unique<Impl>(IoContext)) +ProcessExitWatcher::ProcessExitWatcher(asio::io_context& IoContext) : m_Impl(new Impl(IoContext)) { } -ProcessExitWatcher::~ProcessExitWatcher() = default; +ProcessExitWatcher::~ProcessExitWatcher() +{ + // Explicitly cancel pending async ops. The Impl may outlive this call if a + // completion is still in the io_context queue (the handler holds a strong + // Ref back to Impl to keep it alive). Cancel() here guarantees the watch + // stops even if nobody has called Cancel() on the wrapper. + if (m_Impl) + { + m_Impl->Cancel(); + } +} void ProcessExitWatcher::Watch(const ProcessHandle& Handle, std::function<void(int ExitCode)> OnExit) diff --git a/src/zenutil/process/exitwatcher.h b/src/zenutil/process/exitwatcher.h index 24906d7d0..c9b23368a 100644 --- a/src/zenutil/process/exitwatcher.h +++ b/src/zenutil/process/exitwatcher.h @@ -2,11 +2,11 @@ #pragma once +#include <zenbase/refcount.h> #include <zencore/process.h> #include <zencore/zencore.h> #include <functional> -#include <memory> namespace asio { class io_context; @@ -42,7 +42,7 @@ public: private: struct Impl; - std::unique_ptr<Impl> m_Impl; + Ref<Impl> m_Impl; }; } // namespace zen diff --git a/src/zenutil/process/subprocessmanager.cpp b/src/zenutil/process/subprocessmanager.cpp index d0b912a0d..acb518808 100644 --- a/src/zenutil/process/subprocessmanager.cpp +++ b/src/zenutil/process/subprocessmanager.cpp @@ -236,7 +236,7 @@ ManagedProcess::GetTag() const // SubprocessManager::Impl // ============================================================================ -struct SubprocessManager::Impl +struct SubprocessManager::Impl : public std::enable_shared_from_this<Impl> { asio::io_context& m_IoContext; SubprocessManagerConfig m_Config; @@ -308,7 +308,10 @@ SubprocessManager::Impl::Impl(asio::io_context& IoContext, SubprocessManagerConf if (m_Config.MetricsSampleIntervalMs > 0) { m_MetricsTimer = std::make_unique<asio::steady_timer>(IoContext); - EnqueueMetricsTimer(); + // Don't start the timer here: EnqueueMetricsTimer captures + // weak_from_this(), which requires the enclosing shared_ptr to + // already own this. The caller (SubprocessManager ctor) invokes + // EnqueueMetricsTimer() after the shared_ptr is established. } } @@ -381,8 +384,17 @@ SubprocessManager::Impl::SetupExitWatcher(ManagedProcess* Proc, ProcessExitCallb { int Pid = Proc->Pid(); - Proc->m_Impl->m_ExitWatcher.Watch(Proc->m_Impl->m_Handle, [this, Pid, Callback = std::move(OnExit)](int ExitCode) { - ManagedProcess* Found = FindProcess(Pid); + // Capture a weak_ptr so the handler safely no-ops if the manager is + // destroyed (or the process has been Remove()'d) before the exit + // completion is dispatched on the io_context. + Proc->m_Impl->m_ExitWatcher.Watch(Proc->m_Impl->m_Handle, [Self = weak_from_this(), Pid, Callback = std::move(OnExit)](int ExitCode) { + auto Locked = Self.lock(); + if (!Locked) + { + return; + } + + ManagedProcess* Found = Locked->FindProcess(Pid); if (Found) { @@ -399,17 +411,22 @@ SubprocessManager::Impl::SetupStdoutReader(ManagedProcess* Proc, StdoutPipeHandl Proc->m_Impl->m_StdoutReader = std::make_unique<AsyncPipeReader>(m_IoContext); Proc->m_Impl->m_StdoutReader->Start( std::move(Pipe), - [this, Pid](std::string_view Data) { - ManagedProcess* Found = FindProcess(Pid); + [Self = weak_from_this(), Pid](std::string_view Data) { + auto Locked = Self.lock(); + if (!Locked) + { + return; + } + ManagedProcess* Found = Locked->FindProcess(Pid); if (Found) { if (Found->m_Impl->m_StdoutCallback) { Found->m_Impl->m_StdoutCallback(*Found, Data); } - else if (m_DefaultStdoutCallback) + else if (Locked->m_DefaultStdoutCallback) { - m_DefaultStdoutCallback(*Found, Data); + Locked->m_DefaultStdoutCallback(*Found, Data); } else { @@ -427,17 +444,22 @@ SubprocessManager::Impl::SetupStderrReader(ManagedProcess* Proc, StdoutPipeHandl Proc->m_Impl->m_StderrReader = std::make_unique<AsyncPipeReader>(m_IoContext); Proc->m_Impl->m_StderrReader->Start( std::move(Pipe), - [this, Pid](std::string_view Data) { - ManagedProcess* Found = FindProcess(Pid); + [Self = weak_from_this(), Pid](std::string_view Data) { + auto Locked = Self.lock(); + if (!Locked) + { + return; + } + ManagedProcess* Found = Locked->FindProcess(Pid); if (Found) { if (Found->m_Impl->m_StderrCallback) { Found->m_Impl->m_StderrCallback(*Found, Data); } - else if (m_DefaultStderrCallback) + else if (Locked->m_DefaultStderrCallback) { - m_DefaultStderrCallback(*Found, Data); + Locked->m_DefaultStderrCallback(*Found, Data); } else { @@ -557,14 +579,19 @@ SubprocessManager::Impl::EnqueueMetricsTimer() } m_MetricsTimer->expires_after(std::chrono::milliseconds(m_Config.MetricsSampleIntervalMs)); - m_MetricsTimer->async_wait([this](const asio::error_code& Ec) { - if (Ec || !m_Running.load()) + m_MetricsTimer->async_wait([Self = weak_from_this()](const asio::error_code& Ec) { + auto Locked = Self.lock(); + if (!Locked) + { + return; + } + if (Ec || !Locked->m_Running.load()) { return; } - SampleBatch(); - EnqueueMetricsTimer(); + Locked->SampleBatch(); + Locked->EnqueueMetricsTimer(); }); } @@ -711,8 +738,11 @@ SubprocessManager::Impl::EnumerateGroups(std::function<void(const ProcessGroup&) // ============================================================================ SubprocessManager::SubprocessManager(asio::io_context& IoContext, SubprocessManagerConfig Config) -: m_Impl(std::make_unique<Impl>(IoContext, Config)) +: m_Impl(std::make_shared<Impl>(IoContext, Config)) { + // Start the metrics timer now that the shared_ptr owns the Impl - only + // then does weak_from_this() produce a valid weak_ptr for the handler. + m_Impl->EnqueueMetricsTimer(); } SubprocessManager::~SubprocessManager() = default; |