aboutsummaryrefslogtreecommitdiff
path: root/src/zenhorde/hordecomputebuffer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/zenhorde/hordecomputebuffer.cpp')
-rw-r--r--src/zenhorde/hordecomputebuffer.cpp454
1 files changed, 454 insertions, 0 deletions
diff --git a/src/zenhorde/hordecomputebuffer.cpp b/src/zenhorde/hordecomputebuffer.cpp
new file mode 100644
index 000000000..0d032b5d5
--- /dev/null
+++ b/src/zenhorde/hordecomputebuffer.cpp
@@ -0,0 +1,454 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "hordecomputebuffer.h"
+
+#include <algorithm>
+#include <cassert>
+#include <chrono>
+#include <condition_variable>
+#include <cstring>
+
+namespace zen::horde {
+
+// Simplified ring buffer implementation for in-process use only.
+// Uses a single contiguous buffer with write/read cursors and
+// mutex+condvar for synchronization. This is simpler than the UE version
+// which uses lock-free atomics and shared memory, but sufficient for our
+// use case where we're the initiator side of the compute protocol.
+
// Shared state for one compute-buffer stream. A single Detail instance is
// ref-counted (via TRefCounted) and shared by the owning ComputeBuffer and
// every reader/writer handle created from it. All mutable fields below are
// guarded by Mutex; Data's geometry (NumChunks/ChunkLength) is fixed at
// CreateNew time and never changes afterwards.
struct ComputeBuffer::Detail : TRefCounted<Detail>
{
    std::vector<uint8_t> Data;   // Backing storage: NumChunks * ChunkLength contiguous bytes
    size_t NumChunks = 0;        // Number of chunks in the ring
    size_t ChunkLength = 0;      // Capacity of each chunk, in bytes

    // Current write state
    size_t WriteChunkIdx = 0;    // Chunk the writer is currently filling
    size_t WriteOffset = 0;      // Bytes written into the current chunk (reset when the writer claims a chunk)
    bool WriteComplete = false;  // Set by Close/MarkComplete: no further data will be written

    // Current read state
    size_t ReadChunkIdx = 0;     // Chunk the reader is currently consuming
    size_t ReadOffset = 0;       // Bytes already consumed from the current read chunk
    bool Detached = false;       // Reader abandoned the stream; waiters must give up

    // Per-chunk written length
    std::vector<size_t> ChunkWrittenLength;  // Bytes published in each chunk (readable up to this offset)
    std::vector<bool> ChunkFinished; // Writer moved to next chunk; reader may advance past it once drained

    std::mutex Mutex;
    std::condition_variable ReadCV; ///< Signaled when new data is written or stream completes
    std::condition_variable WriteCV; ///< Signaled when reader advances past a chunk, freeing space

    // NOTE(review): these flags are set by CreateReader/CreateWriter but never
    // read in this file — presumably consumed elsewhere (or via the header);
    // confirm before removing.
    bool HasWriter = false;
    bool HasReader = false;

    // Raw pointer to the first byte of chunk ChunkIdx inside Data.
    uint8_t* ChunkPtr(size_t ChunkIdx) { return Data.data() + ChunkIdx * ChunkLength; }
    const uint8_t* ChunkPtr(size_t ChunkIdx) const { return Data.data() + ChunkIdx * ChunkLength; }
};
+
+// ComputeBuffer
+
+ComputeBuffer::ComputeBuffer()
+{
+}
+ComputeBuffer::~ComputeBuffer()
+{
+}
+
+bool
+ComputeBuffer::CreateNew(const Params& InParams)
+{
+ auto* NewDetail = new Detail();
+ NewDetail->NumChunks = InParams.NumChunks;
+ NewDetail->ChunkLength = InParams.ChunkLength;
+ NewDetail->Data.resize(InParams.NumChunks * InParams.ChunkLength, 0);
+ NewDetail->ChunkWrittenLength.resize(InParams.NumChunks, 0);
+ NewDetail->ChunkFinished.resize(InParams.NumChunks, false);
+
+ m_Detail = NewDetail;
+ return true;
+}
+
// Drops this handle's reference to the shared state. Outstanding reader and
// writer handles keep the state alive through their own references.
void
ComputeBuffer::Close()
{
    m_Detail = nullptr;
}
+
+bool
+ComputeBuffer::IsValid() const
+{
+ return static_cast<bool>(m_Detail);
+}
+
+ComputeBufferReader
+ComputeBuffer::CreateReader()
+{
+ assert(m_Detail);
+ m_Detail->HasReader = true;
+ return ComputeBufferReader(m_Detail);
+}
+
+ComputeBufferWriter
+ComputeBuffer::CreateWriter()
+{
+ assert(m_Detail);
+ m_Detail->HasWriter = true;
+ return ComputeBufferWriter(m_Detail);
+}
+
+// ComputeBufferReader
+
+ComputeBufferReader::ComputeBufferReader()
+{
+}
+ComputeBufferReader::~ComputeBufferReader()
+{
+}
+
+ComputeBufferReader::ComputeBufferReader(const ComputeBufferReader& Other) = default;
+ComputeBufferReader::ComputeBufferReader(ComputeBufferReader&& Other) noexcept = default;
+ComputeBufferReader& ComputeBufferReader::operator=(const ComputeBufferReader& Other) = default;
+ComputeBufferReader& ComputeBufferReader::operator=(ComputeBufferReader&& Other) noexcept = default;
+
+ComputeBufferReader::ComputeBufferReader(Ref<ComputeBuffer::Detail> InDetail) : m_Detail(std::move(InDetail))
+{
+}
+
// Releases this reader's reference to the shared state without signalling
// anyone; use Detach() to actively abandon the stream and wake waiters.
void
ComputeBufferReader::Close()
{
    m_Detail = nullptr;
}
+
+void
+ComputeBufferReader::Detach()
+{
+ if (m_Detail)
+ {
+ std::lock_guard<std::mutex> Lock(m_Detail->Mutex);
+ m_Detail->Detached = true;
+ m_Detail->ReadCV.notify_all();
+ }
+}
+
+bool
+ComputeBufferReader::IsValid() const
+{
+ return static_cast<bool>(m_Detail);
+}
+
+bool
+ComputeBufferReader::IsComplete() const
+{
+ if (!m_Detail)
+ {
+ return true;
+ }
+ std::lock_guard<std::mutex> Lock(m_Detail->Mutex);
+ if (m_Detail->Detached)
+ {
+ return true;
+ }
+ return m_Detail->WriteComplete && m_Detail->ReadChunkIdx == m_Detail->WriteChunkIdx &&
+ m_Detail->ReadOffset >= m_Detail->ChunkWrittenLength[m_Detail->ReadChunkIdx];
+}
+
+void
+ComputeBufferReader::AdvanceReadPosition(size_t Size)
+{
+ if (!m_Detail)
+ {
+ return;
+ }
+
+ std::lock_guard<std::mutex> Lock(m_Detail->Mutex);
+
+ m_Detail->ReadOffset += Size;
+
+ // Check if we need to move to next chunk
+ const size_t ReadChunk = m_Detail->ReadChunkIdx;
+ if (m_Detail->ChunkFinished[ReadChunk] && m_Detail->ReadOffset >= m_Detail->ChunkWrittenLength[ReadChunk])
+ {
+ const size_t NextChunk = (ReadChunk + 1) % m_Detail->NumChunks;
+ m_Detail->ReadChunkIdx = NextChunk;
+ m_Detail->ReadOffset = 0;
+ m_Detail->WriteCV.notify_all();
+ }
+
+ m_Detail->ReadCV.notify_all();
+}
+
+size_t
+ComputeBufferReader::GetMaxReadSize() const
+{
+ if (!m_Detail)
+ {
+ return 0;
+ }
+ std::lock_guard<std::mutex> Lock(m_Detail->Mutex);
+ const size_t ReadChunk = m_Detail->ReadChunkIdx;
+ return m_Detail->ChunkWrittenLength[ReadChunk] - m_Detail->ReadOffset;
+}
+
+const uint8_t*
+ComputeBufferReader::WaitToRead(size_t MinSize, int TimeoutMs, bool* OutTimedOut)
+{
+ if (!m_Detail)
+ {
+ return nullptr;
+ }
+
+ std::unique_lock<std::mutex> Lock(m_Detail->Mutex);
+
+ auto Predicate = [&]() -> bool {
+ if (m_Detail->Detached)
+ {
+ return true;
+ }
+
+ const size_t ReadChunk = m_Detail->ReadChunkIdx;
+ const size_t Available = m_Detail->ChunkWrittenLength[ReadChunk] - m_Detail->ReadOffset;
+
+ if (Available >= MinSize)
+ {
+ return true;
+ }
+
+ // If chunk is finished and we've read everything, try to move to next
+ if (m_Detail->ChunkFinished[ReadChunk] && m_Detail->ReadOffset >= m_Detail->ChunkWrittenLength[ReadChunk])
+ {
+ if (m_Detail->WriteComplete)
+ {
+ return true; // End of stream
+ }
+ // Move to next chunk
+ const size_t NextChunk = (ReadChunk + 1) % m_Detail->NumChunks;
+ m_Detail->ReadChunkIdx = NextChunk;
+ m_Detail->ReadOffset = 0;
+ m_Detail->WriteCV.notify_all();
+ return false; // Re-check with new chunk
+ }
+
+ if (m_Detail->WriteComplete)
+ {
+ return true; // End of stream
+ }
+
+ return false;
+ };
+
+ if (TimeoutMs < 0)
+ {
+ m_Detail->ReadCV.wait(Lock, Predicate);
+ }
+ else
+ {
+ if (!m_Detail->ReadCV.wait_for(Lock, std::chrono::milliseconds(TimeoutMs), Predicate))
+ {
+ if (OutTimedOut)
+ {
+ *OutTimedOut = true;
+ }
+ return nullptr;
+ }
+ }
+
+ if (m_Detail->Detached)
+ {
+ return nullptr;
+ }
+
+ const size_t ReadChunk = m_Detail->ReadChunkIdx;
+ const size_t Available = m_Detail->ChunkWrittenLength[ReadChunk] - m_Detail->ReadOffset;
+
+ if (Available < MinSize)
+ {
+ return nullptr; // End of stream
+ }
+
+ return m_Detail->ChunkPtr(ReadChunk) + m_Detail->ReadOffset;
+}
+
+size_t
+ComputeBufferReader::Read(void* Buffer, size_t MaxSize, int TimeoutMs, bool* OutTimedOut)
+{
+ const uint8_t* Data = WaitToRead(1, TimeoutMs, OutTimedOut);
+ if (!Data)
+ {
+ return 0;
+ }
+
+ const size_t Available = GetMaxReadSize();
+ const size_t ToCopy = std::min(Available, MaxSize);
+ memcpy(Buffer, Data, ToCopy);
+ AdvanceReadPosition(ToCopy);
+ return ToCopy;
+}
+
+// ComputeBufferWriter
+
+ComputeBufferWriter::ComputeBufferWriter() = default;
+ComputeBufferWriter::ComputeBufferWriter(const ComputeBufferWriter& Other) = default;
+ComputeBufferWriter::ComputeBufferWriter(ComputeBufferWriter&& Other) noexcept = default;
+ComputeBufferWriter::~ComputeBufferWriter() = default;
+ComputeBufferWriter& ComputeBufferWriter::operator=(const ComputeBufferWriter& Other) = default;
+ComputeBufferWriter& ComputeBufferWriter::operator=(ComputeBufferWriter&& Other) noexcept = default;
+
+ComputeBufferWriter::ComputeBufferWriter(Ref<ComputeBuffer::Detail> InDetail) : m_Detail(std::move(InDetail))
+{
+}
+
+void
+ComputeBufferWriter::Close()
+{
+ if (m_Detail)
+ {
+ {
+ std::lock_guard<std::mutex> Lock(m_Detail->Mutex);
+ if (!m_Detail->WriteComplete)
+ {
+ m_Detail->WriteComplete = true;
+ m_Detail->ReadCV.notify_all();
+ }
+ }
+ m_Detail = nullptr;
+ }
+}
+
+bool
+ComputeBufferWriter::IsValid() const
+{
+ return static_cast<bool>(m_Detail);
+}
+
+void
+ComputeBufferWriter::MarkComplete()
+{
+ if (m_Detail)
+ {
+ std::lock_guard<std::mutex> Lock(m_Detail->Mutex);
+ m_Detail->WriteComplete = true;
+ m_Detail->ReadCV.notify_all();
+ }
+}
+
+void
+ComputeBufferWriter::AdvanceWritePosition(size_t Size)
+{
+ if (!m_Detail || Size == 0)
+ {
+ return;
+ }
+
+ std::lock_guard<std::mutex> Lock(m_Detail->Mutex);
+ const size_t WriteChunk = m_Detail->WriteChunkIdx;
+ m_Detail->ChunkWrittenLength[WriteChunk] += Size;
+ m_Detail->WriteOffset += Size;
+ m_Detail->ReadCV.notify_all();
+}
+
+size_t
+ComputeBufferWriter::GetMaxWriteSize() const
+{
+ if (!m_Detail)
+ {
+ return 0;
+ }
+ std::lock_guard<std::mutex> Lock(m_Detail->Mutex);
+ const size_t WriteChunk = m_Detail->WriteChunkIdx;
+ return m_Detail->ChunkLength - m_Detail->ChunkWrittenLength[WriteChunk];
+}
+
+size_t
+ComputeBufferWriter::GetChunkMaxLength() const
+{
+ if (!m_Detail)
+ {
+ return 0;
+ }
+ return m_Detail->ChunkLength;
+}
+
+size_t
+ComputeBufferWriter::Write(const void* Buffer, size_t MaxSize, int TimeoutMs)
+{
+ uint8_t* Dest = WaitToWrite(1, TimeoutMs);
+ if (!Dest)
+ {
+ return 0;
+ }
+
+ const size_t Available = GetMaxWriteSize();
+ const size_t ToCopy = std::min(Available, MaxSize);
+ memcpy(Dest, Buffer, ToCopy);
+ AdvanceWritePosition(ToCopy);
+ return ToCopy;
+}
+
+uint8_t*
+ComputeBufferWriter::WaitToWrite(size_t MinSize, int TimeoutMs)
+{
+ if (!m_Detail)
+ {
+ return nullptr;
+ }
+
+ std::unique_lock<std::mutex> Lock(m_Detail->Mutex);
+
+ if (m_Detail->WriteComplete)
+ {
+ return nullptr;
+ }
+
+ const size_t WriteChunk = m_Detail->WriteChunkIdx;
+ const size_t Available = m_Detail->ChunkLength - m_Detail->ChunkWrittenLength[WriteChunk];
+
+ // If current chunk has enough space, return pointer
+ if (Available >= MinSize)
+ {
+ return m_Detail->ChunkPtr(WriteChunk) + m_Detail->ChunkWrittenLength[WriteChunk];
+ }
+
+ // Current chunk is full - mark it as finished and move to next.
+ // The writer cannot advance until the reader has fully consumed the next chunk,
+ // preventing the writer from overwriting data the reader hasn't processed yet.
+ m_Detail->ChunkFinished[WriteChunk] = true;
+ m_Detail->ReadCV.notify_all();
+
+ const size_t NextChunk = (WriteChunk + 1) % m_Detail->NumChunks;
+
+ // Wait until reader has consumed the next chunk
+ auto Predicate = [&]() -> bool {
+ // Check if read has moved past this chunk
+ return m_Detail->ReadChunkIdx != NextChunk || m_Detail->Detached;
+ };
+
+ if (TimeoutMs < 0)
+ {
+ m_Detail->WriteCV.wait(Lock, Predicate);
+ }
+ else
+ {
+ if (!m_Detail->WriteCV.wait_for(Lock, std::chrono::milliseconds(TimeoutMs), Predicate))
+ {
+ return nullptr;
+ }
+ }
+
+ if (m_Detail->Detached)
+ {
+ return nullptr;
+ }
+
+ // Reset next chunk
+ m_Detail->ChunkWrittenLength[NextChunk] = 0;
+ m_Detail->ChunkFinished[NextChunk] = false;
+ m_Detail->WriteChunkIdx = NextChunk;
+ m_Detail->WriteOffset = 0;
+
+ return m_Detail->ChunkPtr(NextChunk);
+}
+
+} // namespace zen::horde