diff options
| author | Liam Mitchell <[email protected]> | 2026-03-09 19:06:36 -0700 |
|---|---|---|
| committer | Liam Mitchell <[email protected]> | 2026-03-09 19:06:36 -0700 |
| commit | d1abc50ee9d4fb72efc646e17decafea741caa34 (patch) | |
| tree | e4288e00f2f7ca0391b83d986efcb69d3ba66a83 /src/zenhorde/hordecomputebuffer.cpp | |
| parent | Allow requests with invalid content-types unless specified in command line or... (diff) | |
| parent | updated chunk–block analyser (#818) (diff) | |
| download | zen-d1abc50ee9d4fb72efc646e17decafea741caa34.tar.xz zen-d1abc50ee9d4fb72efc646e17decafea741caa34.zip | |
Merge branch 'main' into lm/restrict-content-type
Diffstat (limited to 'src/zenhorde/hordecomputebuffer.cpp')
| -rw-r--r-- | src/zenhorde/hordecomputebuffer.cpp | 454 |
1 files changed, 454 insertions, 0 deletions
diff --git a/src/zenhorde/hordecomputebuffer.cpp b/src/zenhorde/hordecomputebuffer.cpp new file mode 100644 index 000000000..0d032b5d5 --- /dev/null +++ b/src/zenhorde/hordecomputebuffer.cpp @@ -0,0 +1,454 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "hordecomputebuffer.h" + +#include <algorithm> +#include <cassert> +#include <chrono> +#include <condition_variable> +#include <cstring> + +namespace zen::horde { + +// Simplified ring buffer implementation for in-process use only. +// Uses a single contiguous buffer with write/read cursors and +// mutex+condvar for synchronization. This is simpler than the UE version +// which uses lock-free atomics and shared memory, but sufficient for our +// use case where we're the initiator side of the compute protocol. + +struct ComputeBuffer::Detail : TRefCounted<Detail> +{ + std::vector<uint8_t> Data; + size_t NumChunks = 0; + size_t ChunkLength = 0; + + // Current write state + size_t WriteChunkIdx = 0; + size_t WriteOffset = 0; + bool WriteComplete = false; + + // Current read state + size_t ReadChunkIdx = 0; + size_t ReadOffset = 0; + bool Detached = false; + + // Per-chunk written length + std::vector<size_t> ChunkWrittenLength; + std::vector<bool> ChunkFinished; // Writer moved to next chunk + + std::mutex Mutex; + std::condition_variable ReadCV; ///< Signaled when new data is written or stream completes + std::condition_variable WriteCV; ///< Signaled when reader advances past a chunk, freeing space + + bool HasWriter = false; + bool HasReader = false; + + uint8_t* ChunkPtr(size_t ChunkIdx) { return Data.data() + ChunkIdx * ChunkLength; } + const uint8_t* ChunkPtr(size_t ChunkIdx) const { return Data.data() + ChunkIdx * ChunkLength; } +}; + +// ComputeBuffer + +ComputeBuffer::ComputeBuffer() +{ +} +ComputeBuffer::~ComputeBuffer() +{ +} + +bool +ComputeBuffer::CreateNew(const Params& InParams) +{ + auto* NewDetail = new Detail(); + NewDetail->NumChunks = InParams.NumChunks; + NewDetail->ChunkLength = InParams.ChunkLength; + NewDetail->Data.resize(InParams.NumChunks * InParams.ChunkLength, 0); + NewDetail->ChunkWrittenLength.resize(InParams.NumChunks, 0); + NewDetail->ChunkFinished.resize(InParams.NumChunks, false); + + m_Detail = NewDetail; + return true; +} + +void +ComputeBuffer::Close() +{ + m_Detail = nullptr; +} + +bool +ComputeBuffer::IsValid() const +{ + return static_cast<bool>(m_Detail); +} + +ComputeBufferReader +ComputeBuffer::CreateReader() +{ + assert(m_Detail); + m_Detail->HasReader = true; + return ComputeBufferReader(m_Detail); +} + +ComputeBufferWriter +ComputeBuffer::CreateWriter() +{ + assert(m_Detail); + m_Detail->HasWriter = true; + return ComputeBufferWriter(m_Detail); +} + +// ComputeBufferReader + +ComputeBufferReader::ComputeBufferReader() +{ +} +ComputeBufferReader::~ComputeBufferReader() +{ +} + +ComputeBufferReader::ComputeBufferReader(const ComputeBufferReader& Other) = default; +ComputeBufferReader::ComputeBufferReader(ComputeBufferReader&& Other) noexcept = default; +ComputeBufferReader& ComputeBufferReader::operator=(const ComputeBufferReader& Other) = default; +ComputeBufferReader& ComputeBufferReader::operator=(ComputeBufferReader&& Other) noexcept = default; + +ComputeBufferReader::ComputeBufferReader(Ref<ComputeBuffer::Detail> InDetail) : m_Detail(std::move(InDetail)) +{ +} + +void +ComputeBufferReader::Close() +{ + m_Detail = nullptr; +} + +void +ComputeBufferReader::Detach() +{ + if (m_Detail) + { + std::lock_guard<std::mutex> Lock(m_Detail->Mutex); + m_Detail->Detached = true; + m_Detail->ReadCV.notify_all(); + } +} + +bool +ComputeBufferReader::IsValid() const +{ + return static_cast<bool>(m_Detail); +} + +bool +ComputeBufferReader::IsComplete() const +{ + if (!m_Detail) + { + return true; + } + std::lock_guard<std::mutex> Lock(m_Detail->Mutex); + if (m_Detail->Detached) + { + return true; + } + return m_Detail->WriteComplete && m_Detail->ReadChunkIdx == m_Detail->WriteChunkIdx && + m_Detail->ReadOffset >= m_Detail->ChunkWrittenLength[m_Detail->ReadChunkIdx]; +} + +void +ComputeBufferReader::AdvanceReadPosition(size_t Size) +{ + if (!m_Detail) + { + return; + } + + std::lock_guard<std::mutex> Lock(m_Detail->Mutex); + + m_Detail->ReadOffset += Size; + + // Check if we need to move to next chunk + const size_t ReadChunk = m_Detail->ReadChunkIdx; + if (m_Detail->ChunkFinished[ReadChunk] && m_Detail->ReadOffset >= m_Detail->ChunkWrittenLength[ReadChunk]) + { + const size_t NextChunk = (ReadChunk + 1) % m_Detail->NumChunks; + m_Detail->ReadChunkIdx = NextChunk; + m_Detail->ReadOffset = 0; + m_Detail->WriteCV.notify_all(); + } + + m_Detail->ReadCV.notify_all(); +} + +size_t +ComputeBufferReader::GetMaxReadSize() const +{ + if (!m_Detail) + { + return 0; + } + std::lock_guard<std::mutex> Lock(m_Detail->Mutex); + const size_t ReadChunk = m_Detail->ReadChunkIdx; + return m_Detail->ChunkWrittenLength[ReadChunk] - m_Detail->ReadOffset; +} + +const uint8_t* +ComputeBufferReader::WaitToRead(size_t MinSize, int TimeoutMs, bool* OutTimedOut) +{ + if (!m_Detail) + { + return nullptr; + } + + std::unique_lock<std::mutex> Lock(m_Detail->Mutex); + + auto Predicate = [&]() -> bool { + if (m_Detail->Detached) + { + return true; + } + + const size_t ReadChunk = m_Detail->ReadChunkIdx; + const size_t Available = m_Detail->ChunkWrittenLength[ReadChunk] - m_Detail->ReadOffset; + + if (Available >= MinSize) + { + return true; + } + + // If chunk is finished and we've read everything, try to move to next + if (m_Detail->ChunkFinished[ReadChunk] && m_Detail->ReadOffset >= m_Detail->ChunkWrittenLength[ReadChunk]) + { + if (m_Detail->WriteComplete) + { + return true; // End of stream + } + // Move to next chunk + const size_t NextChunk = (ReadChunk + 1) % m_Detail->NumChunks; + m_Detail->ReadChunkIdx = NextChunk; + m_Detail->ReadOffset = 0; + m_Detail->WriteCV.notify_all(); + return false; // Re-check with new chunk + } + + if (m_Detail->WriteComplete) + { + return true; // End of stream + } + + return false; + }; + + if (TimeoutMs < 0) + { + m_Detail->ReadCV.wait(Lock, Predicate); + } + else + { + if (!m_Detail->ReadCV.wait_for(Lock, std::chrono::milliseconds(TimeoutMs), Predicate)) + { + if (OutTimedOut) + { + *OutTimedOut = true; + } + return nullptr; + } + } + + if (m_Detail->Detached) + { + return nullptr; + } + + const size_t ReadChunk = m_Detail->ReadChunkIdx; + const size_t Available = m_Detail->ChunkWrittenLength[ReadChunk] - m_Detail->ReadOffset; + + if (Available < MinSize) + { + return nullptr; // End of stream + } + + return m_Detail->ChunkPtr(ReadChunk) + m_Detail->ReadOffset; +} + +size_t +ComputeBufferReader::Read(void* Buffer, size_t MaxSize, int TimeoutMs, bool* OutTimedOut) +{ + const uint8_t* Data = WaitToRead(1, TimeoutMs, OutTimedOut); + if (!Data) + { + return 0; + } + + const size_t Available = GetMaxReadSize(); + const size_t ToCopy = std::min(Available, MaxSize); + memcpy(Buffer, Data, ToCopy); + AdvanceReadPosition(ToCopy); + return ToCopy; +} + +// ComputeBufferWriter + +ComputeBufferWriter::ComputeBufferWriter() = default; +ComputeBufferWriter::ComputeBufferWriter(const ComputeBufferWriter& Other) = default; +ComputeBufferWriter::ComputeBufferWriter(ComputeBufferWriter&& Other) noexcept = default; +ComputeBufferWriter::~ComputeBufferWriter() = default; +ComputeBufferWriter& ComputeBufferWriter::operator=(const ComputeBufferWriter& Other) = default; +ComputeBufferWriter& ComputeBufferWriter::operator=(ComputeBufferWriter&& Other) noexcept = default; + +ComputeBufferWriter::ComputeBufferWriter(Ref<ComputeBuffer::Detail> InDetail) : m_Detail(std::move(InDetail)) +{ +} + +void +ComputeBufferWriter::Close() +{ + if (m_Detail) + { + { + std::lock_guard<std::mutex> Lock(m_Detail->Mutex); + if (!m_Detail->WriteComplete) + { + m_Detail->WriteComplete = true; + m_Detail->ReadCV.notify_all(); + } + } + m_Detail = nullptr; + } +} + +bool +ComputeBufferWriter::IsValid() const +{ + return static_cast<bool>(m_Detail); +} + +void +ComputeBufferWriter::MarkComplete() +{ + if (m_Detail) + { + std::lock_guard<std::mutex> Lock(m_Detail->Mutex); + m_Detail->WriteComplete = true; + m_Detail->ReadCV.notify_all(); + } +} + +void +ComputeBufferWriter::AdvanceWritePosition(size_t Size) +{ + if (!m_Detail || Size == 0) + { + return; + } + + std::lock_guard<std::mutex> Lock(m_Detail->Mutex); + const size_t WriteChunk = m_Detail->WriteChunkIdx; + m_Detail->ChunkWrittenLength[WriteChunk] += Size; + m_Detail->WriteOffset += Size; + m_Detail->ReadCV.notify_all(); +} + +size_t +ComputeBufferWriter::GetMaxWriteSize() const +{ + if (!m_Detail) + { + return 0; + } + std::lock_guard<std::mutex> Lock(m_Detail->Mutex); + const size_t WriteChunk = m_Detail->WriteChunkIdx; + return m_Detail->ChunkLength - m_Detail->ChunkWrittenLength[WriteChunk]; +} + +size_t +ComputeBufferWriter::GetChunkMaxLength() const +{ + if (!m_Detail) + { + return 0; + } + return m_Detail->ChunkLength; +} + +size_t +ComputeBufferWriter::Write(const void* Buffer, size_t MaxSize, int TimeoutMs) +{ + uint8_t* Dest = WaitToWrite(1, TimeoutMs); + if (!Dest) + { + return 0; + } + + const size_t Available = GetMaxWriteSize(); + const size_t ToCopy = std::min(Available, MaxSize); + memcpy(Dest, Buffer, ToCopy); + AdvanceWritePosition(ToCopy); + return ToCopy; +} + +uint8_t* +ComputeBufferWriter::WaitToWrite(size_t MinSize, int TimeoutMs) +{ + if (!m_Detail) + { + return nullptr; + } + + std::unique_lock<std::mutex> Lock(m_Detail->Mutex); + + if (m_Detail->WriteComplete) + { + return nullptr; + } + + const size_t WriteChunk = m_Detail->WriteChunkIdx; + const size_t Available = m_Detail->ChunkLength - m_Detail->ChunkWrittenLength[WriteChunk]; + + // If current chunk has enough space, return pointer + if (Available >= MinSize) + { + return m_Detail->ChunkPtr(WriteChunk) + m_Detail->ChunkWrittenLength[WriteChunk]; + } + + // Current chunk is full - mark it as finished and move to next. + // The writer cannot advance until the reader has fully consumed the next chunk, + // preventing the writer from overwriting data the reader hasn't processed yet. + m_Detail->ChunkFinished[WriteChunk] = true; + m_Detail->ReadCV.notify_all(); + + const size_t NextChunk = (WriteChunk + 1) % m_Detail->NumChunks; + + // Wait until reader has consumed the next chunk + auto Predicate = [&]() -> bool { + // Check if read has moved past this chunk + return m_Detail->ReadChunkIdx != NextChunk || m_Detail->Detached; + }; + + if (TimeoutMs < 0) + { + m_Detail->WriteCV.wait(Lock, Predicate); + } + else + { + if (!m_Detail->WriteCV.wait_for(Lock, std::chrono::milliseconds(TimeoutMs), Predicate)) + { + return nullptr; + } + } + + if (m_Detail->Detached) + { + return nullptr; + } + + // Reset next chunk + m_Detail->ChunkWrittenLength[NextChunk] = 0; + m_Detail->ChunkFinished[NextChunk] = false; + m_Detail->WriteChunkIdx = NextChunk; + m_Detail->WriteOffset = 0; + + return m_Detail->ChunkPtr(NextChunk); +} + +} // namespace zen::horde |