aboutsummaryrefslogtreecommitdiff
path: root/src/zenhorde/hordeagentmessage.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/zenhorde/hordeagentmessage.cpp')
-rw-r--r--src/zenhorde/hordeagentmessage.cpp340
1 files changed, 340 insertions, 0 deletions
diff --git a/src/zenhorde/hordeagentmessage.cpp b/src/zenhorde/hordeagentmessage.cpp
new file mode 100644
index 000000000..998134a96
--- /dev/null
+++ b/src/zenhorde/hordeagentmessage.cpp
@@ -0,0 +1,340 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "hordeagentmessage.h"
+
+#include <zencore/intmath.h>
+
+#include <cassert>
+#include <cstring>
+
+namespace zen::horde {
+
+AgentMessageChannel::AgentMessageChannel(Ref<ComputeChannel> Channel) : m_Channel(std::move(Channel))
+{
+}
+
+AgentMessageChannel::~AgentMessageChannel() = default;
+
+void
+AgentMessageChannel::Close()
+{
+ CreateMessage(AgentMessageType::None, 0);
+ FlushMessage();
+}
+
+void
+AgentMessageChannel::Ping()
+{
+ CreateMessage(AgentMessageType::Ping, 0);
+ FlushMessage();
+}
+
+void
+AgentMessageChannel::Fork(int ChannelId, int BufferSize)
+{
+ CreateMessage(AgentMessageType::Fork, sizeof(int) + sizeof(int));
+ WriteInt32(ChannelId);
+ WriteInt32(BufferSize);
+ FlushMessage();
+}
+
+void
+AgentMessageChannel::Attach()
+{
+ CreateMessage(AgentMessageType::Attach, 0);
+ FlushMessage();
+}
+
+void
+AgentMessageChannel::UploadFiles(const char* Path, const char* Locator)
+{
+ CreateMessage(AgentMessageType::WriteFiles, strlen(Path) + strlen(Locator) + 20);
+ WriteString(Path);
+ WriteString(Locator);
+ FlushMessage();
+}
+
+void
+AgentMessageChannel::Execute(const char* Exe,
+ const char* const* Args,
+ size_t NumArgs,
+ const char* WorkingDir,
+ const char* const* EnvVars,
+ size_t NumEnvVars,
+ ExecuteProcessFlags Flags)
+{
+ size_t RequiredSize = 50 + strlen(Exe);
+ for (size_t i = 0; i < NumArgs; ++i)
+ {
+ RequiredSize += strlen(Args[i]) + 10;
+ }
+ if (WorkingDir)
+ {
+ RequiredSize += strlen(WorkingDir) + 10;
+ }
+ for (size_t i = 0; i < NumEnvVars; ++i)
+ {
+ RequiredSize += strlen(EnvVars[i]) + 20;
+ }
+
+ CreateMessage(AgentMessageType::ExecuteV2, RequiredSize);
+ WriteString(Exe);
+
+ WriteUnsignedVarInt(NumArgs);
+ for (size_t i = 0; i < NumArgs; ++i)
+ {
+ WriteString(Args[i]);
+ }
+
+ WriteOptionalString(WorkingDir);
+
+ // ExecuteV2 protocol requires env vars as separate key/value pairs.
+ // Callers pass "KEY=VALUE" strings; we split on the first '=' here.
+ WriteUnsignedVarInt(NumEnvVars);
+ for (size_t i = 0; i < NumEnvVars; ++i)
+ {
+ const char* Eq = strchr(EnvVars[i], '=');
+ assert(Eq != nullptr);
+
+ WriteString(std::string_view(EnvVars[i], Eq - EnvVars[i]));
+ if (*(Eq + 1) == '\0')
+ {
+ WriteOptionalString(nullptr);
+ }
+ else
+ {
+ WriteOptionalString(Eq + 1);
+ }
+ }
+
+ WriteInt32(static_cast<int>(Flags));
+ FlushMessage();
+}
+
+void
+AgentMessageChannel::Blob(const uint8_t* Data, size_t Length)
+{
+ // Blob responses are chunked to fit within the compute buffer's chunk size.
+ // The 128-byte margin accounts for the ReadBlobResponse header (offset + total length fields).
+ const size_t MaxChunkSize = m_Channel->Writer.GetChunkMaxLength() - 128 - MessageHeaderLength;
+ for (size_t ChunkOffset = 0; ChunkOffset < Length;)
+ {
+ const size_t ChunkLength = std::min(Length - ChunkOffset, MaxChunkSize);
+
+ CreateMessage(AgentMessageType::ReadBlobResponse, ChunkLength + 128);
+ WriteInt32(static_cast<int>(ChunkOffset));
+ WriteInt32(static_cast<int>(Length));
+ WriteFixedLengthBytes(Data + ChunkOffset, ChunkLength);
+ FlushMessage();
+
+ ChunkOffset += ChunkLength;
+ }
+}
+
+AgentMessageType
+AgentMessageChannel::ReadResponse(int32_t TimeoutMs, bool* OutTimedOut)
+{
+ // Deferred advance: the previous response's buffer is only released when the next
+ // ReadResponse is called. This allows callers to read response data between calls
+ // without copying, since the pointer comes directly from the ring buffer.
+ if (m_ResponseData)
+ {
+ m_Channel->Reader.AdvanceReadPosition(m_ResponseLength + MessageHeaderLength);
+ m_ResponseData = nullptr;
+ m_ResponseLength = 0;
+ }
+
+ const uint8_t* Header = m_Channel->Reader.WaitToRead(MessageHeaderLength, TimeoutMs, OutTimedOut);
+ if (!Header)
+ {
+ return AgentMessageType::None;
+ }
+
+ uint32_t Length;
+ memcpy(&Length, Header + 1, sizeof(uint32_t));
+
+ Header = m_Channel->Reader.WaitToRead(MessageHeaderLength + Length, TimeoutMs, OutTimedOut);
+ if (!Header)
+ {
+ return AgentMessageType::None;
+ }
+
+ m_ResponseType = static_cast<AgentMessageType>(Header[0]);
+ m_ResponseData = Header + MessageHeaderLength;
+ m_ResponseLength = Length;
+
+ return m_ResponseType;
+}
+
+void
+AgentMessageChannel::ReadException(ExceptionInfo& Ex)
+{
+ assert(m_ResponseType == AgentMessageType::Exception);
+ const uint8_t* Pos = m_ResponseData;
+ Ex.Message = ReadString(&Pos);
+ Ex.Description = ReadString(&Pos);
+}
+
+int
+AgentMessageChannel::ReadExecuteResult()
+{
+ assert(m_ResponseType == AgentMessageType::ExecuteResult);
+ const uint8_t* Pos = m_ResponseData;
+ return ReadInt32(&Pos);
+}
+
+void
+AgentMessageChannel::ReadBlobRequest(BlobRequest& Req)
+{
+ assert(m_ResponseType == AgentMessageType::ReadBlob);
+ const uint8_t* Pos = m_ResponseData;
+ Req.Locator = ReadString(&Pos);
+ Req.Offset = ReadUnsignedVarInt(&Pos);
+ Req.Length = ReadUnsignedVarInt(&Pos);
+}
+
+void
+AgentMessageChannel::CreateMessage(AgentMessageType Type, size_t MaxLength)
+{
+ m_RequestData = m_Channel->Writer.WaitToWrite(MessageHeaderLength + MaxLength);
+ m_RequestData[0] = static_cast<uint8_t>(Type);
+ m_MaxRequestSize = MaxLength;
+ m_RequestSize = 0;
+}
+
+void
+AgentMessageChannel::FlushMessage()
+{
+ const uint32_t Size = static_cast<uint32_t>(m_RequestSize);
+ memcpy(&m_RequestData[1], &Size, sizeof(uint32_t));
+ m_Channel->Writer.AdvanceWritePosition(MessageHeaderLength + m_RequestSize);
+ m_RequestSize = 0;
+ m_MaxRequestSize = 0;
+ m_RequestData = nullptr;
+}
+
+void
+AgentMessageChannel::WriteInt32(int Value)
+{
+ WriteFixedLengthBytes(reinterpret_cast<const uint8_t*>(&Value), sizeof(int));
+}
+
+int
+AgentMessageChannel::ReadInt32(const uint8_t** Pos)
+{
+ int Value;
+ memcpy(&Value, *Pos, sizeof(int));
+ *Pos += sizeof(int);
+ return Value;
+}
+
+void
+AgentMessageChannel::WriteFixedLengthBytes(const uint8_t* Data, size_t Length)
+{
+ assert(m_RequestSize + Length <= m_MaxRequestSize);
+ memcpy(&m_RequestData[MessageHeaderLength + m_RequestSize], Data, Length);
+ m_RequestSize += Length;
+}
+
+const uint8_t*
+AgentMessageChannel::ReadFixedLengthBytes(const uint8_t** Pos, size_t Length)
+{
+ const uint8_t* Data = *Pos;
+ *Pos += Length;
+ return Data;
+}
+
+size_t
+AgentMessageChannel::MeasureUnsignedVarInt(size_t Value)
+{
+ if (Value == 0)
+ {
+ return 1;
+ }
+ return (FloorLog2_64(static_cast<uint64_t>(Value)) / 7) + 1;
+}
+
+void
+AgentMessageChannel::WriteUnsignedVarInt(size_t Value)
+{
+ const size_t ByteCount = MeasureUnsignedVarInt(Value);
+ assert(m_RequestSize + ByteCount <= m_MaxRequestSize);
+
+ uint8_t* Output = m_RequestData + MessageHeaderLength + m_RequestSize;
+ for (size_t i = 1; i < ByteCount; ++i)
+ {
+ Output[ByteCount - i] = static_cast<uint8_t>(Value);
+ Value >>= 8;
+ }
+ Output[0] = static_cast<uint8_t>((0xFF << (9 - static_cast<int>(ByteCount))) | static_cast<uint8_t>(Value));
+
+ m_RequestSize += ByteCount;
+}
+
+size_t
+AgentMessageChannel::ReadUnsignedVarInt(const uint8_t** Pos)
+{
+ const uint8_t* Data = *Pos;
+ const uint8_t FirstByte = Data[0];
+ const size_t NumBytes = CountLeadingZeros(0xFF & (~static_cast<unsigned int>(FirstByte))) + 1 - 24;
+
+ size_t Value = static_cast<size_t>(FirstByte & (0xFF >> NumBytes));
+ for (size_t i = 1; i < NumBytes; ++i)
+ {
+ Value <<= 8;
+ Value |= Data[i];
+ }
+
+ *Pos += NumBytes;
+ return Value;
+}
+
+size_t
+AgentMessageChannel::MeasureString(const char* Text) const
+{
+ const size_t Length = strlen(Text);
+ return MeasureUnsignedVarInt(Length) + Length;
+}
+
+void
+AgentMessageChannel::WriteString(const char* Text)
+{
+ const size_t Length = strlen(Text);
+ WriteUnsignedVarInt(Length);
+ WriteFixedLengthBytes(reinterpret_cast<const uint8_t*>(Text), Length);
+}
+
+void
+AgentMessageChannel::WriteString(std::string_view Text)
+{
+ WriteUnsignedVarInt(Text.size());
+ WriteFixedLengthBytes(reinterpret_cast<const uint8_t*>(Text.data()), Text.size());
+}
+
+std::string_view
+AgentMessageChannel::ReadString(const uint8_t** Pos)
+{
+ const size_t Length = ReadUnsignedVarInt(Pos);
+ const char* Start = reinterpret_cast<const char*>(ReadFixedLengthBytes(Pos, Length));
+ return std::string_view(Start, Length);
+}
+
+void
+AgentMessageChannel::WriteOptionalString(const char* Text)
+{
+ // Optional strings use length+1 encoding: 0 means null/absent,
+ // N>0 means a string of length N-1 follows. This matches the UE
+ // FAgentMessageChannel serialization convention.
+ if (!Text)
+ {
+ WriteUnsignedVarInt(0);
+ }
+ else
+ {
+ const size_t Length = strlen(Text);
+ WriteUnsignedVarInt(Length + 1);
+ WriteFixedLengthBytes(reinterpret_cast<const uint8_t*>(Text), Length);
+ }
+}
+
+} // namespace zen::horde