diff options
Diffstat (limited to 'src')
84 files changed, 8795 insertions, 2703 deletions
diff --git a/src/zen/cmds/exec_cmd.cpp b/src/zen/cmds/exec_cmd.cpp index cbc153e07..30e860a3f 100644 --- a/src/zen/cmds/exec_cmd.cpp +++ b/src/zen/cmds/exec_cmd.cpp @@ -18,6 +18,7 @@ #include <zencore/stream.h> #include <zencore/string.h> #include <zencore/system.h> +#include <zencore/thread.h> #include <zencore/timer.h> #include <zenhttp/httpclient.h> #include <zenhttp/packageformat.h> @@ -41,255 +42,122 @@ struct hash<zen::IoHash> : public zen::IoHash::Hasher namespace zen { -ExecCommand::ExecCommand() -{ - m_Options.add_options()("h,help", "Print help"); - m_Options.add_option("", "u", "hosturl", kHostUrlHelp, cxxopts::value(m_HostName), "<hosturl>"); - m_Options.add_option("", "", "log", "Action log directory", cxxopts::value(m_RecordingLogPath), "<path>"); - m_Options.add_option("", "p", "path", "Recording path (directory or .actionlog file)", cxxopts::value(m_RecordingPath), "<path>"); - m_Options.add_option("", "", "offset", "Recording replay start offset", cxxopts::value(m_Offset), "<offset>"); - m_Options.add_option("", "", "stride", "Recording replay stride", cxxopts::value(m_Stride), "<stride>"); - m_Options.add_option("", "", "limit", "Recording replay limit", cxxopts::value(m_Limit), "<limit>"); - m_Options.add_option("", "", "beacon", "Beacon path", cxxopts::value(m_BeaconPath), "<path>"); - m_Options.add_option("", "", "orch", "Orchestrator URL for worker discovery", cxxopts::value(m_OrchestratorUrl), "<url>"); - m_Options.add_option("", - "", - "mode", - "Select execution mode (http,inproc,dump,direct,beacon,buildlog)", - cxxopts::value(m_Mode)->default_value("http"), - "<string>"); - m_Options - .add_option("", "", "dump-actions", "Dump each action to console as it is dispatched", cxxopts::value(m_DumpActions), "<bool>"); - m_Options.add_option("", "o", "output", "Save action results to directory", cxxopts::value(m_OutputPath), "<path>"); - m_Options.add_option("", "", "binary", "Write output as binary packages instead of YAML", 
cxxopts::value(m_Binary), "<bool>"); - m_Options.add_option("", "", "quiet", "Quiet mode (less logging)", cxxopts::value(m_Quiet), "<bool>"); - m_Options.parse_positional("mode"); -} - -ExecCommand::~ExecCommand() -{ -} - -void -ExecCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) -{ - // Configure - - if (!ParseOptions(argc, argv)) - { - return; - } - - m_HostName = ResolveTargetHostSpec(m_HostName); - - if (m_RecordingPath.empty()) - { - throw OptionParseException("replay path is required!", m_Options.help()); - } - - m_VerboseLogging = GlobalOptions.IsVerbose; - m_QuietLogging = m_Quiet && !m_VerboseLogging; - - enum ExecMode - { - kHttp, - kDirect, - kInproc, - kDump, - kBeacon, - kBuildLog - } Mode; - - if (m_Mode == "http"sv) - { - Mode = kHttp; - } - else if (m_Mode == "direct"sv) - { - Mode = kDirect; - } - else if (m_Mode == "inproc"sv) - { - Mode = kInproc; - } - else if (m_Mode == "dump"sv) - { - Mode = kDump; - } - else if (m_Mode == "beacon"sv) - { - Mode = kBeacon; - } - else if (m_Mode == "buildlog"sv) - { - Mode = kBuildLog; - } - else - { - throw OptionParseException("invalid mode specified!", m_Options.help()); - } +namespace { - // Gather information from recording path - - std::unique_ptr<zen::compute::RecordingReader> Reader; - std::unique_ptr<zen::compute::UeRecordingReader> UeReader; - - std::filesystem::path RecordingPath{m_RecordingPath}; - - if (!std::filesystem::is_directory(RecordingPath)) - { - throw OptionParseException("replay path should be a directory path!", m_Options.help()); - } - else + static std::string EscapeHtml(std::string_view Input) { - if (std::filesystem::is_directory(RecordingPath / "cid")) + std::string Out; + Out.reserve(Input.size()); + for (char C : Input) { - Reader = std::make_unique<zen::compute::RecordingReader>(RecordingPath); - m_WorkerMap = Reader->ReadWorkers(); - m_ChunkResolver = Reader.get(); - m_RecordingReader = Reader.get(); - } - else - { - UeReader = 
std::make_unique<zen::compute::UeRecordingReader>(RecordingPath); - m_WorkerMap = UeReader->ReadWorkers(); - m_ChunkResolver = UeReader.get(); - m_RecordingReader = UeReader.get(); + switch (C) + { + case '&': + Out += "&"; + break; + case '<': + Out += "<"; + break; + case '>': + Out += ">"; + break; + case '"': + Out += """; + break; + case '\'': + Out += "'"; + break; + default: + Out += C; + } } + return Out; } - ZEN_CONSOLE("found {} workers, {} action items", m_WorkerMap.size(), m_RecordingReader->GetActionCount()); - - for (auto& Kv : m_WorkerMap) + static std::string EscapeJson(std::string_view Input) { - CbObject WorkerDesc = Kv.second.GetObject(); - const IoHash& WorkerId = Kv.first; - - RegisterWorkerFunctionsFromDescription(WorkerDesc, WorkerId); - - if (m_VerboseLogging) + std::string Out; + Out.reserve(Input.size()); + for (char C : Input) { - zen::ExtendableStringBuilder<1024> ObjStr; -# if 0 - zen::CompactBinaryToJson(WorkerDesc, ObjStr); - ZEN_CONSOLE("worker {}: {}", WorkerId, ObjStr); -# else - zen::CompactBinaryToYaml(WorkerDesc, ObjStr); - ZEN_CONSOLE("worker {}:\n{}", WorkerId, ObjStr); -# endif + switch (C) + { + case '"': + Out += "\\\""; + break; + case '\\': + Out += "\\\\"; + break; + case '\n': + Out += "\\n"; + break; + case '\r': + Out += "\\r"; + break; + case '\t': + Out += "\\t"; + break; + default: + if (static_cast<unsigned char>(C) < 0x20) + { + Out += fmt::format("\\u{:04x}", static_cast<unsigned>(static_cast<unsigned char>(C))); + } + else + { + Out += C; + } + } } + return Out; } - if (m_VerboseLogging) - { - EmitFunctionList(m_FunctionList); - } - - // Iterate over work items and dispatch or log them - - int ReturnValue = 0; - - Stopwatch ExecTimer; - - switch (Mode) - { - case kHttp: - // Forward requests to HTTP function service - ReturnValue = HttpExecute(); - break; - - case kDirect: - // Not currently supported - ReturnValue = LocalMessagingExecute(); - break; - - case kInproc: - // Handle execution in-core (by spawning 
child processes) - ReturnValue = InProcessExecute(); - break; - - case kDump: - // Dump high level information about actions to console - ReturnValue = DumpWorkItems(); - break; - - case kBeacon: - ReturnValue = BeaconExecute(); - break; - - case kBuildLog: - ReturnValue = BuildActionsLog(); - break; - - default: - ZEN_ERROR("Unknown operating mode! No work submitted"); - - ReturnValue = 1; - } +} // namespace - ZEN_CONSOLE("complete - took {}", NiceTimeSpanMs(ExecTimer.GetElapsedTimeMs())); - - if (!ReturnValue) - { - ZEN_CONSOLE("all work items completed successfully"); - } - else - { - ZEN_CONSOLE("some work items failed (code {})", ReturnValue); - } -} +////////////////////////////////////////////////////////////////////////// +// ExecSessionConfig — read-only configuration for a session run -int -ExecCommand::InProcessExecute() +struct ExecSessionConfig { - ZEN_ASSERT(m_ChunkResolver); - ChunkResolver& Resolver = *m_ChunkResolver; + zen::ChunkResolver& Resolver; + zen::compute::RecordingReaderBase& RecordingReader; + const std::unordered_map<zen::IoHash, zen::CbPackage>& WorkerMap; + std::vector<ExecFunctionDefinition>& FunctionList; // mutable for EmitFunctionListOnce + std::string_view OrchestratorUrl; + const std::filesystem::path& OutputPath; + int Offset = 0; + int Stride = 1; + int Limit = 0; + bool Verbose = false; + bool Quiet = false; + bool DumpActions = false; + bool Binary = false; +}; - zen::compute::ComputeServiceSession ComputeSession(Resolver); +////////////////////////////////////////////////////////////////////////// +// ExecSessionRunner — owns per-run state, drives the session lifecycle - std::filesystem::path TempPath = std::filesystem::absolute(".zen_temp"); - ComputeSession.AddLocalRunner(Resolver, TempPath); +class ExecSessionRunner +{ +public: + ExecSessionRunner(zen::compute::ComputeServiceSession& Session, const ExecSessionConfig& Config); + int Run(); - return ExecUsingSession(ComputeSession); -} +private: + // Types -int 
-ExecCommand::ExecUsingSession(zen::compute::ComputeServiceSession& ComputeSession) -{ struct JobTracker { public: - inline void Insert(int LsnField) - { - RwLock::ExclusiveLockScope _(Lock); - PendingJobs.insert(LsnField); - } - - inline bool IsEmpty() const - { - RwLock::SharedLockScope _(Lock); - return PendingJobs.empty(); - } - - inline void Remove(int CompleteLsn) - { - RwLock::ExclusiveLockScope _(Lock); - PendingJobs.erase(CompleteLsn); - } - - inline size_t GetSize() const - { - RwLock::SharedLockScope _(Lock); - return PendingJobs.size(); - } + void Insert(int LsnField); + bool IsEmpty() const; + void Remove(int CompleteLsn); + size_t GetSize() const; private: mutable RwLock Lock; std::unordered_set<int> PendingJobs; }; - JobTracker PendingJobs; - struct ActionSummaryEntry { int32_t Lsn = 0; @@ -307,664 +175,471 @@ ExecCommand::ExecUsingSession(zen::compute::ComputeServiceSession& ComputeSessio std::string ExecutionLocation; }; - std::mutex SummaryLock; - std::unordered_map<int32_t, ActionSummaryEntry> SummaryEntries; + // Methods + + std::string RegisterOrchestratorClient(); + void SendOrchestratorHeartbeat(); + void CompleteOrchestratorClient(); + void DrainCompletedJobs(); + void SaveResultOutput(int32_t CompleteLsn, CbPackage& ResultPackage); + void SaveActionOutput(int32_t Lsn, int RecordingIndex, const IoHash& ActionId, const CbObject& ActionObject); + void WriteSummaryFiles(); + void EmitFunctionListOnce(); + + // State + + zen::compute::ComputeServiceSession& m_Session; + ExecSessionConfig m_Config; + JobTracker m_PendingJobs; + std::mutex m_SummaryLock; + std::unordered_map<int32_t, ActionSummaryEntry> m_SummaryEntries; + int m_QueueId = 0; + std::string m_OrchestratorClientId; + Stopwatch m_OrchestratorHeartbeatTimer; + bool m_FunctionListEmittedOnce = false; + std::atomic<int> m_IsDraining{0}; +}; - ComputeSession.WaitUntilReady(); +////////////////////////////////////////////////////////////////////////// +// ExecSessionRunner::JobTracker - // 
Register as a client with the orchestrator (best-effort) +void +ExecSessionRunner::JobTracker::Insert(int LsnField) +{ + RwLock::ExclusiveLockScope _(Lock); + PendingJobs.insert(LsnField); +} - std::string OrchestratorClientId; +bool +ExecSessionRunner::JobTracker::IsEmpty() const +{ + RwLock::SharedLockScope _(Lock); + return PendingJobs.empty(); +} - if (!m_OrchestratorUrl.empty()) +void +ExecSessionRunner::JobTracker::Remove(int CompleteLsn) +{ + RwLock::ExclusiveLockScope _(Lock); + PendingJobs.erase(CompleteLsn); +} + +size_t +ExecSessionRunner::JobTracker::GetSize() const +{ + RwLock::SharedLockScope _(Lock); + return PendingJobs.size(); +} + +////////////////////////////////////////////////////////////////////////// +// ExecSessionRunner implementation + +ExecSessionRunner::ExecSessionRunner(zen::compute::ComputeServiceSession& Session, const ExecSessionConfig& Config) +: m_Session(Session) +, m_Config(Config) +{ +} + +std::string +ExecSessionRunner::RegisterOrchestratorClient() +{ + if (m_Config.OrchestratorUrl.empty()) { - try - { - HttpClient OrchestratorClient(m_OrchestratorUrl); + return {}; + } - CbObjectWriter Ann; - Ann << "session_id"sv << GetSessionId(); - Ann << "hostname"sv << std::string_view(GetMachineName()); + try + { + HttpClient OrchestratorClient(m_Config.OrchestratorUrl); - CbObjectWriter Meta; - Meta << "source"sv - << "zen-exec"sv; - Ann << "metadata"sv << Meta.Save(); + CbObjectWriter Ann; + Ann << "session_id"sv << GetSessionId(); + Ann << "hostname"sv << std::string_view(GetMachineName()); - auto Resp = OrchestratorClient.Post("/orch/clients", Ann.Save()); - if (Resp.IsSuccess()) - { - OrchestratorClientId = std::string(Resp.AsObject()["id"].AsString()); - ZEN_CONSOLE_INFO("registered with orchestrator as {}", OrchestratorClientId); - } - else - { - ZEN_WARN("failed to register with orchestrator (status {})", static_cast<int>(Resp.StatusCode)); - } + CbObjectWriter Meta; + Meta << "source"sv + << "zen-exec"sv; + Ann << "metadata"sv 
<< Meta.Save(); + + auto Resp = OrchestratorClient.Post("/orch/clients", Ann.Save()); + if (Resp.IsSuccess()) + { + std::string ClientId = std::string(Resp.AsObject()["id"].AsString()); + ZEN_CONSOLE_INFO("registered with orchestrator as {}", ClientId); + return ClientId; } - catch (const std::exception& Ex) + else { - ZEN_WARN("failed to register with orchestrator: {}", Ex.what()); + ZEN_WARN("failed to register with orchestrator (status {})", static_cast<int>(Resp.StatusCode)); } } + catch (const std::exception& Ex) + { + ZEN_WARN("failed to register with orchestrator: {}", Ex.what()); + } - Stopwatch OrchestratorHeartbeatTimer; + return {}; +} - auto SendOrchestratorHeartbeat = [&] { - if (OrchestratorClientId.empty() || OrchestratorHeartbeatTimer.GetElapsedTimeMs() < 30'000) - { - return; - } - OrchestratorHeartbeatTimer.Reset(); +void +ExecSessionRunner::SendOrchestratorHeartbeat() +{ + if (m_OrchestratorClientId.empty() || m_OrchestratorHeartbeatTimer.GetElapsedTimeMs() < 30'000) + { + return; + } + m_OrchestratorHeartbeatTimer.Reset(); + try + { + HttpClient OrchestratorClient(m_Config.OrchestratorUrl); + std::ignore = OrchestratorClient.Post(fmt::format("/orch/clients/{}/update", m_OrchestratorClientId)); + } + catch (...) + { + } +} + +void +ExecSessionRunner::CompleteOrchestratorClient() +{ + if (!m_OrchestratorClientId.empty()) + { try { - HttpClient OrchestratorClient(m_OrchestratorUrl); - std::ignore = OrchestratorClient.Post(fmt::format("/orch/clients/{}/update", OrchestratorClientId)); + HttpClient OrchestratorClient(m_Config.OrchestratorUrl); + std::ignore = OrchestratorClient.Post(fmt::format("/orch/clients/{}/complete", m_OrchestratorClientId)); } catch (...) { } - }; - - auto ClientCleanup = MakeGuard([&] { - if (!OrchestratorClientId.empty()) - { - try - { - HttpClient OrchestratorClient(m_OrchestratorUrl); - std::ignore = OrchestratorClient.Post(fmt::format("/orch/clients/{}/complete", OrchestratorClientId)); - } - catch (...) 
- { - } - } - }); - - // Create a queue to group all actions from this exec session - - CbObjectWriter Metadata; - Metadata << "source"sv - << "zen-exec"sv; - - auto QueueResult = ComputeSession.CreateQueue("zen-exec", Metadata.Save()); - const int QueueId = QueueResult.QueueId; - if (!QueueId) - { - ZEN_ERROR("failed to create compute queue"); - return 1; } +} - auto QueueCleanup = MakeGuard([&] { ComputeSession.DeleteQueue(QueueId); }); - - if (!m_OutputPath.empty()) +void +ExecSessionRunner::DrainCompletedJobs() +{ + if (m_IsDraining.exchange(1)) { - zen::CreateDirectories(m_OutputPath); + return; } - std::atomic<int> IsDraining{0}; + auto _ = MakeGuard([&] { m_IsDraining.store(0, std::memory_order_release); }); - auto DrainCompletedJobs = [&] { - if (IsDraining.exchange(1)) - { - return; - } - - auto _ = MakeGuard([&] { IsDraining.store(0, std::memory_order_release); }); + CbObjectWriter Cbo; + m_Session.GetQueueCompleted(m_QueueId, Cbo); - CbObjectWriter Cbo; - ComputeSession.GetQueueCompleted(QueueId, Cbo); - - if (CbObject Completed = Cbo.Save()) + if (CbObject Completed = Cbo.Save()) + { + for (auto& It : Completed["completed"sv]) { - for (auto& It : Completed["completed"sv]) - { - int32_t CompleteLsn = It.AsInt32(); + int32_t CompleteLsn = It.AsInt32(); - CbPackage ResultPackage; - HttpResponseCode Response = ComputeSession.GetActionResult(CompleteLsn, /* out */ ResultPackage); + CbPackage ResultPackage; + HttpResponseCode Response = m_Session.GetActionResult(CompleteLsn, /* out */ ResultPackage); - if (Response == HttpResponseCode::OK) + if (Response == HttpResponseCode::OK) + { + if (!m_Config.OutputPath.empty() && ResultPackage) { - if (!m_OutputPath.empty() && ResultPackage) - { - int OutputAttachments = 0; - uint64_t OutputBytes = 0; - - if (!m_Binary) - { - // Write the root object as YAML - ExtendableStringBuilder<4096> YamlStr; - CompactBinaryToYaml(ResultPackage.GetObject(), YamlStr); - - std::string_view Yaml = YamlStr; - 
zen::WriteFile(m_OutputPath / fmt::format("{}.result.yaml", CompleteLsn), - IoBuffer(IoBuffer::Clone, Yaml.data(), Yaml.size())); - - // Write decompressed attachments - auto Attachments = ResultPackage.GetAttachments(); - - if (!Attachments.empty()) - { - std::filesystem::path AttDir = m_OutputPath / fmt::format("{}.result.attachments", CompleteLsn); - zen::CreateDirectories(AttDir); - - for (const CbAttachment& Att : Attachments) - { - ++OutputAttachments; - - IoHash AttHash = Att.GetHash(); - - if (Att.IsCompressedBinary()) - { - SharedBuffer Decompressed = Att.AsCompressedBinary().Decompress(); - OutputBytes += Decompressed.GetSize(); - zen::WriteFile(AttDir / AttHash.ToHexString(), - IoBuffer(IoBuffer::Clone, Decompressed.GetData(), Decompressed.GetSize())); - } - else - { - SharedBuffer Binary = Att.AsBinary(); - OutputBytes += Binary.GetSize(); - zen::WriteFile(AttDir / AttHash.ToHexString(), - IoBuffer(IoBuffer::Clone, Binary.GetData(), Binary.GetSize())); - } - } - } - - if (!m_QuietLogging) - { - ZEN_CONSOLE("saved result: {}/{}.result.yaml ({} attachments)", - m_OutputPath.string(), - CompleteLsn, - OutputAttachments); - } - } - else - { - CompositeBuffer Serialized = FormatPackageMessageBuffer(ResultPackage); - zen::WriteFile(m_OutputPath / fmt::format("{}.result.pkg", CompleteLsn), std::move(Serialized)); - - for (const CbAttachment& Att : ResultPackage.GetAttachments()) - { - ++OutputAttachments; - OutputBytes += Att.AsBinary().GetSize(); - } - - if (!m_QuietLogging) - { - ZEN_CONSOLE("saved result: {}/{}.result.pkg", m_OutputPath.string(), CompleteLsn); - } - } - - std::lock_guard Lock(SummaryLock); - if (auto It2 = SummaryEntries.find(CompleteLsn); It2 != SummaryEntries.end()) - { - It2->second.OutputAttachments = OutputAttachments; - It2->second.OutputBytes = OutputBytes; - } - } + SaveResultOutput(CompleteLsn, ResultPackage); + } - PendingJobs.Remove(CompleteLsn); + m_PendingJobs.Remove(CompleteLsn); - ZEN_CONSOLE("completed: LSN {} ({} still 
pending)", CompleteLsn, PendingJobs.GetSize()); - } + ZEN_CONSOLE("completed: LSN {} ({} still pending)", CompleteLsn, m_PendingJobs.GetSize()); } } - }; - - // Describe workers + } +} - ZEN_CONSOLE("describing {} workers", m_WorkerMap.size()); +void +ExecSessionRunner::SaveResultOutput(int32_t CompleteLsn, CbPackage& ResultPackage) +{ + int OutputAttachments = 0; + uint64_t OutputBytes = 0; - for (auto Kv : m_WorkerMap) + if (!m_Config.Binary) { - CbPackage WorkerDesc = Kv.second; + // Write the root object as YAML + ExtendableStringBuilder<4096> YamlStr; + CompactBinaryToYaml(ResultPackage.GetObject(), YamlStr); - ComputeSession.RegisterWorker(WorkerDesc); - } + std::string_view Yaml = YamlStr; + zen::WriteFile(m_Config.OutputPath / fmt::format("{}.result.yaml", CompleteLsn), + IoBuffer(IoBuffer::Clone, Yaml.data(), Yaml.size())); - // Then submit work items + // Write decompressed attachments + auto Attachments = ResultPackage.GetAttachments(); - int FailedWorkCounter = 0; - size_t RemainingWorkItems = m_RecordingReader->GetActionCount(); - int SubmittedWorkItems = 0; - - ZEN_CONSOLE("submitting {} work items", RemainingWorkItems); + if (!Attachments.empty()) + { + std::filesystem::path AttDir = m_Config.OutputPath / fmt::format("{}.result.attachments", CompleteLsn); + zen::CreateDirectories(AttDir); - int OffsetCounter = m_Offset; - int StrideCounter = m_Stride; + for (const CbAttachment& Att : Attachments) + { + ++OutputAttachments; - auto ShouldSchedule = [&]() -> bool { - if (m_Limit && SubmittedWorkItems >= m_Limit) - { - // Limit reached, ignore + IoHash AttHash = Att.GetHash(); - return false; + if (Att.IsCompressedBinary()) + { + SharedBuffer Decompressed = Att.AsCompressedBinary().Decompress(); + OutputBytes += Decompressed.GetSize(); + zen::WriteFile(AttDir / AttHash.ToHexString(), + IoBuffer(IoBuffer::Clone, Decompressed.GetData(), Decompressed.GetSize())); + } + else + { + SharedBuffer Binary = Att.AsBinary(); + OutputBytes += Binary.GetSize(); + 
zen::WriteFile(AttDir / AttHash.ToHexString(), IoBuffer(IoBuffer::Clone, Binary.GetData(), Binary.GetSize())); + } + } } - if (OffsetCounter && OffsetCounter--) + if (!m_Config.Quiet) { - // Still in offset, ignore - - return false; + ZEN_CONSOLE("saved result: {}/{}.result.yaml ({} attachments)", m_Config.OutputPath.string(), CompleteLsn, OutputAttachments); } + } + else + { + CompositeBuffer Serialized = FormatPackageMessageBuffer(ResultPackage); + zen::WriteFile(m_Config.OutputPath / fmt::format("{}.result.pkg", CompleteLsn), std::move(Serialized)); - if (--StrideCounter == 0) + for (const CbAttachment& Att : ResultPackage.GetAttachments()) { - StrideCounter = m_Stride; - - return true; + ++OutputAttachments; + OutputBytes += Att.AsBinary().GetSize(); } - return false; - }; - - int TargetParallelism = 8; + if (!m_Config.Quiet) + { + ZEN_CONSOLE("saved result: {}/{}.result.pkg", m_Config.OutputPath.string(), CompleteLsn); + } + } - if (OffsetCounter || StrideCounter || m_Limit) + std::lock_guard Lock(m_SummaryLock); + if (auto It2 = m_SummaryEntries.find(CompleteLsn); It2 != m_SummaryEntries.end()) { - TargetParallelism = 1; + It2->second.OutputAttachments = OutputAttachments; + It2->second.OutputBytes = OutputBytes; } +} - std::atomic<int> RecordingIndex{0}; +void +ExecSessionRunner::SaveActionOutput(int32_t Lsn, int RecordingIndex, const IoHash& ActionId, const CbObject& ActionObject) +{ + ActionSummaryEntry Entry; + Entry.Lsn = Lsn; + Entry.RecordingIndex = RecordingIndex; + Entry.ActionId = ActionId; + Entry.FunctionName = std::string(ActionObject["Function"sv].AsString()); - m_RecordingReader->IterateActions( - [&](CbObject ActionObject, const IoHash& ActionId) { - // Enqueue job + if (!m_Config.Binary) + { + // Write action object as YAML + ExtendableStringBuilder<4096> YamlStr; + CompactBinaryToYaml(ActionObject, YamlStr); - const int CurrentRecordingIndex = RecordingIndex++; + std::string_view Yaml = YamlStr; + zen::WriteFile(m_Config.OutputPath / 
fmt::format("{}.action.yaml", Lsn), IoBuffer(IoBuffer::Clone, Yaml.data(), Yaml.size())); - Stopwatch SubmitTimer; + // Write decompressed input attachments + std::filesystem::path AttDir = m_Config.OutputPath / fmt::format("{}.action.attachments", Lsn); + bool AttDirCreated = false; - const int Priority = 0; + ActionObject.IterateAttachments([&](CbFieldView Field) { + IoHash AttachCid = Field.AsAttachment(); + ++Entry.InputAttachments; - if (ShouldSchedule()) + if (IoBuffer ChunkData = m_Config.Resolver.FindChunkByCid(AttachCid)) { - if (m_VerboseLogging) - { - int AttachmentCount = 0; - uint64_t AttachmentBytes = 0; - eastl::hash_set<IoHash> ReferencedChunks; + IoHash RawHash; + uint64_t RawSize = 0; + CompressedBuffer Compressed = CompressedBuffer::FromCompressed(SharedBuffer(ChunkData), RawHash, RawSize); + SharedBuffer Decompressed = Compressed.Decompress(); - ActionObject.IterateAttachments([&](CbFieldView Field) { - IoHash AttachData = Field.AsAttachment(); + Entry.InputBytes += Decompressed.GetSize(); - ReferencedChunks.insert(AttachData); - ++AttachmentCount; - - if (IoBuffer ChunkData = m_ChunkResolver->FindChunkByCid(AttachData)) - { - AttachmentBytes += ChunkData.GetSize(); - } - }); - - zen::ExtendableStringBuilder<1024> ObjStr; - zen::CompactBinaryToJson(ActionObject, ObjStr); - ZEN_CONSOLE("work item {} ({} attachments, {} bytes): {}", - ActionId, - AttachmentCount, - NiceBytes(AttachmentBytes), - ObjStr); - } - - if (m_DumpActions) + if (!AttDirCreated) { - int AttachmentCount = 0; - uint64_t AttachmentBytes = 0; - - ActionObject.IterateAttachments([&](CbFieldView Field) { - IoHash AttachData = Field.AsAttachment(); - - ++AttachmentCount; - - if (IoBuffer ChunkData = m_ChunkResolver->FindChunkByCid(AttachData)) - { - AttachmentBytes += ChunkData.GetSize(); - } - }); - - zen::ExtendableStringBuilder<1024> ObjStr; - zen::CompactBinaryToYaml(ActionObject, ObjStr); - ZEN_CONSOLE("action {} ({} attachments, {}):\n{}", ActionId, AttachmentCount, 
NiceBytes(AttachmentBytes), ObjStr); + zen::CreateDirectories(AttDir); + AttDirCreated = true; } - if (zen::compute::ComputeServiceSession::EnqueueResult EnqueueResult = - ComputeSession.EnqueueActionToQueue(QueueId, ActionObject, Priority)) - { - const int32_t LsnField = EnqueueResult.Lsn; - - --RemainingWorkItems; - ++SubmittedWorkItems; - - if (!m_QuietLogging) - { - ZEN_CONSOLE("submitted work item #{} - LSN {} - {}. {} remaining", - SubmittedWorkItems, - LsnField, - NiceTimeSpanMs(SubmitTimer.GetElapsedTimeMs()), - RemainingWorkItems); - } - - if (!m_OutputPath.empty()) - { - ActionSummaryEntry Entry; - Entry.Lsn = LsnField; - Entry.RecordingIndex = CurrentRecordingIndex; - Entry.ActionId = ActionId; - Entry.FunctionName = std::string(ActionObject["Function"sv].AsString()); - - if (!m_Binary) - { - // Write action object as YAML - ExtendableStringBuilder<4096> YamlStr; - CompactBinaryToYaml(ActionObject, YamlStr); - - std::string_view Yaml = YamlStr; - zen::WriteFile(m_OutputPath / fmt::format("{}.action.yaml", LsnField), - IoBuffer(IoBuffer::Clone, Yaml.data(), Yaml.size())); - - // Write decompressed input attachments - std::filesystem::path AttDir = m_OutputPath / fmt::format("{}.action.attachments", LsnField); - bool AttDirCreated = false; - - ActionObject.IterateAttachments([&](CbFieldView Field) { - IoHash AttachCid = Field.AsAttachment(); - ++Entry.InputAttachments; - - if (IoBuffer ChunkData = m_ChunkResolver->FindChunkByCid(AttachCid)) - { - IoHash RawHash; - uint64_t RawSize = 0; - CompressedBuffer Compressed = - CompressedBuffer::FromCompressed(SharedBuffer(ChunkData), RawHash, RawSize); - SharedBuffer Decompressed = Compressed.Decompress(); - - Entry.InputBytes += Decompressed.GetSize(); - - if (!AttDirCreated) - { - zen::CreateDirectories(AttDir); - AttDirCreated = true; - } - - zen::WriteFile(AttDir / AttachCid.ToHexString(), - IoBuffer(IoBuffer::Clone, Decompressed.GetData(), Decompressed.GetSize())); - } - }); - - if (!m_QuietLogging) - { - 
ZEN_CONSOLE("saved action: {}/{}.action.yaml ({} attachments)", - m_OutputPath.string(), - LsnField, - Entry.InputAttachments); - } - } - else - { - // Build a CbPackage from the action and write as .pkg - CbPackage ActionPackage; - ActionPackage.SetObject(ActionObject); - - ActionObject.IterateAttachments([&](CbFieldView Field) { - IoHash AttachCid = Field.AsAttachment(); - ++Entry.InputAttachments; - - if (IoBuffer ChunkData = m_ChunkResolver->FindChunkByCid(AttachCid)) - { - IoHash RawHash; - uint64_t RawSize = 0; - CompressedBuffer Compressed = - CompressedBuffer::FromCompressed(SharedBuffer(ChunkData), RawHash, RawSize); - - Entry.InputBytes += ChunkData.GetSize(); - ActionPackage.AddAttachment(CbAttachment(std::move(Compressed), RawHash)); - } - }); - - CompositeBuffer Serialized = FormatPackageMessageBuffer(ActionPackage); - zen::WriteFile(m_OutputPath / fmt::format("{}.action.pkg", LsnField), std::move(Serialized)); - - if (!m_QuietLogging) - { - ZEN_CONSOLE("saved action: {}/{}.action.pkg", m_OutputPath.string(), LsnField); - } - } - - std::lock_guard Lock(SummaryLock); - SummaryEntries.emplace(LsnField, std::move(Entry)); - } + zen::WriteFile(AttDir / AttachCid.ToHexString(), IoBuffer(IoBuffer::Clone, Decompressed.GetData(), Decompressed.GetSize())); + } + }); - PendingJobs.Insert(LsnField); - } - else - { - if (!m_QuietLogging) - { - std::string_view FunctionName = ActionObject["Function"sv].AsString(); - const Guid FunctionVersion = ActionObject["FunctionVersion"sv].AsUuid(); - const Guid BuildSystemVersion = ActionObject["BuildSystemVersion"sv].AsUuid(); + if (!m_Config.Quiet) + { + ZEN_CONSOLE("saved action: {}/{}.action.yaml ({} attachments)", m_Config.OutputPath.string(), Lsn, Entry.InputAttachments); + } + } + else + { + // Build a CbPackage from the action and write as .pkg + CbPackage ActionPackage; + ActionPackage.SetObject(ActionObject); - ZEN_ERROR( - "failed to resolve function for work with 
(Function:{},FunctionVersion:{},BuildSystemVersion:{}). Work " - "descriptor " - "at: 'file://{}'", - std::string(FunctionName), - FunctionVersion, - BuildSystemVersion, - "<null>"); + ActionObject.IterateAttachments([&](CbFieldView Field) { + IoHash AttachCid = Field.AsAttachment(); + ++Entry.InputAttachments; - EmitFunctionListOnce(m_FunctionList); - } + if (IoBuffer ChunkData = m_Config.Resolver.FindChunkByCid(AttachCid)) + { + IoHash RawHash; + uint64_t RawSize = 0; + CompressedBuffer Compressed = CompressedBuffer::FromCompressed(SharedBuffer(ChunkData), RawHash, RawSize); - ++FailedWorkCounter; - } + Entry.InputBytes += ChunkData.GetSize(); + ActionPackage.AddAttachment(CbAttachment(std::move(Compressed), RawHash)); } + }); - // Check for completed work + CompositeBuffer Serialized = FormatPackageMessageBuffer(ActionPackage); + zen::WriteFile(m_Config.OutputPath / fmt::format("{}.action.pkg", Lsn), std::move(Serialized)); - DrainCompletedJobs(); - SendOrchestratorHeartbeat(); - }, - TargetParallelism); + if (!m_Config.Quiet) + { + ZEN_CONSOLE("saved action: {}/{}.action.pkg", m_Config.OutputPath.string(), Lsn); + } + } - // Wait until all pending work is complete + std::lock_guard Lock(m_SummaryLock); + m_SummaryEntries.emplace(Lsn, std::move(Entry)); +} - while (!PendingJobs.IsEmpty()) +void +ExecSessionRunner::WriteSummaryFiles() +{ + if (m_Config.OutputPath.empty() || m_SummaryEntries.empty()) { - // TODO: improve this logic - zen::Sleep(500); - - DrainCompletedJobs(); - SendOrchestratorHeartbeat(); + return; } // Merge timing data from queue history into summary entries - if (!SummaryEntries.empty()) + // RunnerAction::State indices (can't include functionrunner.h from here) + constexpr int kStateNew = 0; + constexpr int kStatePending = 1; + constexpr int kStateRunning = 3; + constexpr int kStateCompleted = 4; // first terminal state + constexpr int kStateCount = 8; + + for (const auto& HistEntry : m_Session.GetQueueHistory(m_QueueId, 0)) { - // 
RunnerAction::State indices (can't include functionrunner.h from here) - constexpr int kStateNew = 0; - constexpr int kStatePending = 1; - constexpr int kStateRunning = 3; - constexpr int kStateCompleted = 4; // first terminal state - constexpr int kStateCount = 8; - - for (const auto& HistEntry : ComputeSession.GetQueueHistory(QueueId, 0)) + std::lock_guard Lock(m_SummaryLock); + if (auto It = m_SummaryEntries.find(HistEntry.Lsn); It != m_SummaryEntries.end()) { - std::lock_guard Lock(SummaryLock); - if (auto It = SummaryEntries.find(HistEntry.Lsn); It != SummaryEntries.end()) + // Find terminal state timestamp (Completed, Failed, Abandoned, or Cancelled) + uint64_t EndTick = 0; + for (int S = kStateCompleted; S < kStateCount; ++S) { - // Find terminal state timestamp (Completed, Failed, Abandoned, or Cancelled) - uint64_t EndTick = 0; - for (int S = kStateCompleted; S < kStateCount; ++S) + if (HistEntry.Timestamps[S] != 0) { - if (HistEntry.Timestamps[S] != 0) - { - EndTick = HistEntry.Timestamps[S]; - break; - } + EndTick = HistEntry.Timestamps[S]; + break; } - uint64_t StartTick = HistEntry.Timestamps[kStateNew]; - if (EndTick > StartTick) - { - It->second.WallSeconds = float(double(EndTick - StartTick) / double(TimeSpan::TicksPerSecond)); - } - It->second.CpuSeconds = HistEntry.CpuSeconds; - It->second.SubmittedTicks = HistEntry.Timestamps[kStatePending]; - It->second.StartedTicks = HistEntry.Timestamps[kStateRunning]; - It->second.ExecutionLocation = HistEntry.ExecutionLocation; } + uint64_t StartTick = HistEntry.Timestamps[kStateNew]; + if (EndTick > StartTick) + { + It->second.WallSeconds = float(double(EndTick - StartTick) / double(TimeSpan::TicksPerSecond)); + } + It->second.CpuSeconds = HistEntry.CpuSeconds; + It->second.SubmittedTicks = HistEntry.Timestamps[kStatePending]; + It->second.StartedTicks = HistEntry.Timestamps[kStateRunning]; + It->second.ExecutionLocation = HistEntry.ExecutionLocation; } } - // Write summary file if output path is set + // 
Sort entries by recording index - if (!m_OutputPath.empty() && !SummaryEntries.empty()) + std::vector<ActionSummaryEntry> Sorted; + Sorted.reserve(m_SummaryEntries.size()); + for (auto& [_, Entry] : m_SummaryEntries) { - std::vector<ActionSummaryEntry> Sorted; - Sorted.reserve(SummaryEntries.size()); - for (auto& [_, Entry] : SummaryEntries) - { - Sorted.push_back(std::move(Entry)); - } - - std::sort(Sorted.begin(), Sorted.end(), [](const ActionSummaryEntry& A, const ActionSummaryEntry& B) { - return A.RecordingIndex < B.RecordingIndex; - }); + Sorted.push_back(std::move(Entry)); + } - auto FormatTimestamp = [](uint64_t Ticks) -> std::string { - if (Ticks == 0) - { - return "-"; - } - return DateTime(Ticks).ToString("%H:%M:%S.%s"); - }; - - ExtendableStringBuilder<4096> Summary; - Summary.Append(fmt::format("{:<8} {:<8} {:<40} {:<40} {:>8} {:>12} {:>8} {:>12} {:>8} {:>8} {:>12} {:>12} {:<24}\n", - "LSN", - "Index", - "ActionId", - "Function", - "InAtt", - "InBytes", - "OutAtt", - "OutBytes", - "Wall(s)", - "CPU(s)", - "Submitted", - "Started", - "Location")); - Summary.Append(fmt::format("{:-<8} {:-<8} {:-<40} {:-<40} {:-<8} {:-<12} {:-<8} {:-<12} {:-<8} {:-<8} {:-<12} {:-<12} {:-<24}\n", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "")); + std::sort(Sorted.begin(), Sorted.end(), [](const ActionSummaryEntry& A, const ActionSummaryEntry& B) { + return A.RecordingIndex < B.RecordingIndex; + }); - for (const ActionSummaryEntry& Entry : Sorted) + auto FormatTimestamp = [](uint64_t Ticks) -> std::string { + if (Ticks == 0) { - Summary.Append(fmt::format("{:<8} {:<8} {:<40} {:<40} {:>8} {:>12} {:>8} {:>12} {:>8.2f} {:>8.2f} {:>12} {:>12} {:<24}\n", - Entry.Lsn, - Entry.RecordingIndex, - Entry.ActionId, - Entry.FunctionName, - Entry.InputAttachments, - NiceBytes(Entry.InputBytes), - Entry.OutputAttachments, - NiceBytes(Entry.OutputBytes), - Entry.WallSeconds, - Entry.CpuSeconds, - FormatTimestamp(Entry.SubmittedTicks), - 
FormatTimestamp(Entry.StartedTicks), - Entry.ExecutionLocation)); + return "-"; } + return DateTime(Ticks).ToString("%H:%M:%S.%s"); + }; - std::filesystem::path SummaryPath = m_OutputPath / "summary.txt"; - std::string_view SummaryStr = Summary; - zen::WriteFile(SummaryPath, IoBuffer(IoBuffer::Clone, SummaryStr.data(), SummaryStr.size())); + // Write summary.txt + + ExtendableStringBuilder<4096> Summary; + Summary.Append(fmt::format("{:<8} {:<8} {:<40} {:<40} {:>8} {:>12} {:>8} {:>12} {:>8} {:>8} {:>12} {:>12} {:<24}\n", + "LSN", + "Index", + "ActionId", + "Function", + "InAtt", + "InBytes", + "OutAtt", + "OutBytes", + "Wall(s)", + "CPU(s)", + "Submitted", + "Started", + "Location")); + Summary.Append(fmt::format("{:-<8} {:-<8} {:-<40} {:-<40} {:-<8} {:-<12} {:-<8} {:-<12} {:-<8} {:-<8} {:-<12} {:-<12} {:-<24}\n", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "")); + + for (const ActionSummaryEntry& Entry : Sorted) + { + Summary.Append(fmt::format("{:<8} {:<8} {:<40} {:<40} {:>8} {:>12} {:>8} {:>12} {:>8.2f} {:>8.2f} {:>12} {:>12} {:<24}\n", + Entry.Lsn, + Entry.RecordingIndex, + Entry.ActionId, + Entry.FunctionName, + Entry.InputAttachments, + NiceBytes(Entry.InputBytes), + Entry.OutputAttachments, + NiceBytes(Entry.OutputBytes), + Entry.WallSeconds, + Entry.CpuSeconds, + FormatTimestamp(Entry.SubmittedTicks), + FormatTimestamp(Entry.StartedTicks), + Entry.ExecutionLocation)); + } - ZEN_CONSOLE("wrote summary to {}", SummaryPath.string()); + std::filesystem::path SummaryPath = m_Config.OutputPath / "summary.txt"; + std::string_view SummaryStr = Summary; + zen::WriteFile(SummaryPath, IoBuffer(IoBuffer::Clone, SummaryStr.data(), SummaryStr.size())); - if (!m_Binary) - { - auto EscapeHtml = [](std::string_view Input) -> std::string { - std::string Out; - Out.reserve(Input.size()); - for (char C : Input) - { - switch (C) - { - case '&': - Out += "&"; - break; - case '<': - Out += "<"; - break; - case '>': - Out += ">"; - break; - case '"': 
- Out += """; - break; - case '\'': - Out += "'"; - break; - default: - Out += C; - } - } - return Out; - }; + ZEN_CONSOLE("wrote summary to {}", SummaryPath.string()); - auto EscapeJson = [](std::string_view Input) -> std::string { - std::string Out; - Out.reserve(Input.size()); - for (char C : Input) - { - switch (C) - { - case '"': - Out += "\\\""; - break; - case '\\': - Out += "\\\\"; - break; - case '\n': - Out += "\\n"; - break; - case '\r': - Out += "\\r"; - break; - case '\t': - Out += "\\t"; - break; - default: - if (static_cast<unsigned char>(C) < 0x20) - { - Out += fmt::format("\\u{:04x}", static_cast<unsigned>(static_cast<unsigned char>(C))); - } - else - { - Out += C; - } - } - } - return Out; - }; + // Write summary.html - ExtendableStringBuilder<8192> Html; + if (!m_Config.Binary) + { + ExtendableStringBuilder<8192> Html; - Html.Append(std::string_view(R"(<!DOCTYPE html> + Html.Append(std::string_view(R"(<!DOCTYPE html> <html><head><meta charset="utf-8"><title>Exec Summary</title> <style> body{font-family:system-ui,sans-serif;margin:20px;background:#fafafa} @@ -1007,51 +682,50 @@ a:hover{text-decoration:underline} const DATA=[ )")); - std::string_view ResultExt = ".result.yaml"; - std::string_view ActionExt = ".action.yaml"; + std::string_view ResultExt = ".result.yaml"; + std::string_view ActionExt = ".action.yaml"; - for (const ActionSummaryEntry& Entry : Sorted) + for (const ActionSummaryEntry& Entry : Sorted) + { + std::string SafeName = EscapeJson(EscapeHtml(Entry.FunctionName)); + std::string ActionIdStr = Entry.ActionId.ToHexString(); + std::string ActionLink; + if (!ActionExt.empty()) { - std::string SafeName = EscapeJson(EscapeHtml(Entry.FunctionName)); - std::string ActionIdStr = Entry.ActionId.ToHexString(); - std::string ActionLink; - if (!ActionExt.empty()) - { - ActionLink = EscapeJson(fmt::format(" <a href=\"{}{}\">[action]</a>", Entry.Lsn, ActionExt)); - } - - // Indices: 0=lsn, 1=idx, 2=actionId, 3=fn, 4=inAtt, 5=inBytes, 6=outAtt, 
7=outBytes, - // 8=wall, 9=cpu, 10=niceBytesIn, 11=niceBytesOut, 12=actionLink, - // 13=submittedTicks, 14=startedTicks, 15=submittedDisplay, 16=startedDisplay, - // 17=location - Html.Append( - fmt::format("[{},{},\"{}\",\"{}\",{},{},{},{},{:.6f},{:.6f},\"{}\",\"{}\",\"{}\",{},{},\"{}\",\"{}\",\"{}\"],\n", - Entry.Lsn, - Entry.RecordingIndex, - ActionIdStr, - SafeName, - Entry.InputAttachments, - Entry.InputBytes, - Entry.OutputAttachments, - Entry.OutputBytes, - Entry.WallSeconds, - Entry.CpuSeconds, - EscapeJson(NiceBytes(Entry.InputBytes)), - EscapeJson(NiceBytes(Entry.OutputBytes)), - ActionLink, - Entry.SubmittedTicks, - Entry.StartedTicks, - FormatTimestamp(Entry.SubmittedTicks), - FormatTimestamp(Entry.StartedTicks), - EscapeJson(EscapeHtml(Entry.ExecutionLocation)))); + ActionLink = EscapeJson(fmt::format(" <a href=\"{}{}\">[action]</a>", Entry.Lsn, ActionExt)); } - Html.Append(fmt::format(R"(]; + // Indices: 0=lsn, 1=idx, 2=actionId, 3=fn, 4=inAtt, 5=inBytes, 6=outAtt, 7=outBytes, + // 8=wall, 9=cpu, 10=niceBytesIn, 11=niceBytesOut, 12=actionLink, + // 13=submittedTicks, 14=startedTicks, 15=submittedDisplay, 16=startedDisplay, + // 17=location + Html.Append(fmt::format("[{},{},\"{}\",\"{}\",{},{},{},{},{:.6f},{:.6f},\"{}\",\"{}\",\"{}\",{},{},\"{}\",\"{}\",\"{}\"],\n", + Entry.Lsn, + Entry.RecordingIndex, + ActionIdStr, + SafeName, + Entry.InputAttachments, + Entry.InputBytes, + Entry.OutputAttachments, + Entry.OutputBytes, + Entry.WallSeconds, + Entry.CpuSeconds, + EscapeJson(NiceBytes(Entry.InputBytes)), + EscapeJson(NiceBytes(Entry.OutputBytes)), + ActionLink, + Entry.SubmittedTicks, + Entry.StartedTicks, + FormatTimestamp(Entry.SubmittedTicks), + FormatTimestamp(Entry.StartedTicks), + EscapeJson(EscapeHtml(Entry.ExecutionLocation)))); + } + + Html.Append(fmt::format(R"(]; const RESULT_EXT="{}"; )", - ResultExt)); + ResultExt)); - Html.Append(std::string_view(R"JS((function(){ + Html.Append(std::string_view(R"JS((function(){ const ROW_H=33,BUF=20; 
const container=document.getElementById("container"); const tbody=container.querySelector("tbody"); @@ -1158,14 +832,244 @@ document.getElementById("csvBtn").addEventListener("click",()=>{ </script></body></html> )JS")); - std::filesystem::path HtmlPath = m_OutputPath / "summary.html"; - std::string_view HtmlStr = Html; - zen::WriteFile(HtmlPath, IoBuffer(IoBuffer::Clone, HtmlStr.data(), HtmlStr.size())); + std::filesystem::path HtmlPath = m_Config.OutputPath / "summary.html"; + std::string_view HtmlStr = Html; + zen::WriteFile(HtmlPath, IoBuffer(IoBuffer::Clone, HtmlStr.data(), HtmlStr.size())); + + ZEN_CONSOLE("wrote HTML summary to {}", HtmlPath.string()); + } +} + +void +ExecSessionRunner::EmitFunctionListOnce() +{ + if (!m_FunctionListEmittedOnce) + { + ExecCommand::EmitFunctionList(m_Config.FunctionList); + m_FunctionListEmittedOnce = true; + } +} + +int +ExecSessionRunner::Run() +{ + m_Session.WaitUntilReady(); + + // Register as a client with the orchestrator (best-effort) + + m_OrchestratorClientId = RegisterOrchestratorClient(); + + auto ClientCleanup = MakeGuard([&] { CompleteOrchestratorClient(); }); + + // Create a queue to group all actions from this exec session + + CbObjectWriter Metadata; + Metadata << "source"sv + << "zen-exec"sv; + + auto QueueResult = m_Session.CreateQueue("zen-exec", Metadata.Save()); + const int QueueId = QueueResult.QueueId; + if (!QueueId) + { + ZEN_ERROR("failed to create compute queue"); + return 1; + } + + m_QueueId = QueueId; + + auto QueueCleanup = MakeGuard([&] { m_Session.DeleteQueue(QueueId); }); + + if (!m_Config.OutputPath.empty()) + { + zen::CreateDirectories(m_Config.OutputPath); + } + + // Describe workers + + ZEN_CONSOLE("describing {} workers", m_Config.WorkerMap.size()); + + for (auto Kv : m_Config.WorkerMap) + { + CbPackage WorkerDesc = Kv.second; + + m_Session.RegisterWorker(WorkerDesc); + } + + // Then submit work items + + int FailedWorkCounter = 0; + size_t RemainingWorkItems = 
m_Config.RecordingReader.GetActionCount(); + int SubmittedWorkItems = 0; + + ZEN_CONSOLE("submitting {} work items", RemainingWorkItems); + + int OffsetCounter = m_Config.Offset; + int StrideCounter = m_Config.Stride; - ZEN_CONSOLE("wrote HTML summary to {}", HtmlPath.string()); + auto ShouldSchedule = [&]() -> bool { + if (m_Config.Limit && SubmittedWorkItems >= m_Config.Limit) + { + // Limit reached, ignore + + return false; + } + + if (OffsetCounter && OffsetCounter--) + { + // Still in offset, ignore + + return false; } + + if (--StrideCounter == 0) + { + StrideCounter = m_Config.Stride; + + return true; + } + + return false; + }; + + int TargetParallelism = 8; + + if (OffsetCounter || StrideCounter || m_Config.Limit) + { + TargetParallelism = 1; + } + + std::atomic<int> RecordingIndex{0}; + + m_Config.RecordingReader.IterateActions( + [&](CbObject ActionObject, const IoHash& ActionId) { + // Enqueue job + + const int CurrentRecordingIndex = RecordingIndex++; + + Stopwatch SubmitTimer; + + const int Priority = 0; + + if (ShouldSchedule()) + { + if (m_Config.Verbose) + { + int AttachmentCount = 0; + uint64_t AttachmentBytes = 0; + eastl::hash_set<IoHash> ReferencedChunks; + + ActionObject.IterateAttachments([&](CbFieldView Field) { + IoHash AttachData = Field.AsAttachment(); + + ReferencedChunks.insert(AttachData); + ++AttachmentCount; + + if (IoBuffer ChunkData = m_Config.Resolver.FindChunkByCid(AttachData)) + { + AttachmentBytes += ChunkData.GetSize(); + } + }); + + zen::ExtendableStringBuilder<1024> ObjStr; + zen::CompactBinaryToJson(ActionObject, ObjStr); + ZEN_CONSOLE("work item {} ({} attachments, {} bytes): {}", + ActionId, + AttachmentCount, + NiceBytes(AttachmentBytes), + ObjStr); + } + + if (m_Config.DumpActions) + { + int AttachmentCount = 0; + uint64_t AttachmentBytes = 0; + + ActionObject.IterateAttachments([&](CbFieldView Field) { + IoHash AttachData = Field.AsAttachment(); + + ++AttachmentCount; + + if (IoBuffer ChunkData = 
m_Config.Resolver.FindChunkByCid(AttachData)) + { + AttachmentBytes += ChunkData.GetSize(); + } + }); + + zen::ExtendableStringBuilder<1024> ObjStr; + zen::CompactBinaryToYaml(ActionObject, ObjStr); + ZEN_CONSOLE("action {} ({} attachments, {}):\n{}", ActionId, AttachmentCount, NiceBytes(AttachmentBytes), ObjStr); + } + + if (zen::compute::ComputeServiceSession::EnqueueResult EnqueueResult = + m_Session.EnqueueActionToQueue(QueueId, ActionObject, Priority)) + { + const int32_t LsnField = EnqueueResult.Lsn; + + --RemainingWorkItems; + ++SubmittedWorkItems; + + if (!m_Config.Quiet) + { + ZEN_CONSOLE("submitted work item #{} - LSN {} - {}. {} remaining", + SubmittedWorkItems, + LsnField, + NiceTimeSpanMs(SubmitTimer.GetElapsedTimeMs()), + RemainingWorkItems); + } + + if (!m_Config.OutputPath.empty()) + { + SaveActionOutput(LsnField, CurrentRecordingIndex, ActionId, ActionObject); + } + + m_PendingJobs.Insert(LsnField); + } + else + { + if (!m_Config.Quiet) + { + std::string_view FunctionName = ActionObject["Function"sv].AsString(); + const Guid FunctionVersion = ActionObject["FunctionVersion"sv].AsUuid(); + const Guid BuildSystemVersion = ActionObject["BuildSystemVersion"sv].AsUuid(); + + ZEN_ERROR( + "failed to resolve function for work with (Function:{},FunctionVersion:{},BuildSystemVersion:{}). 
Work " + "descriptor " + "at: 'file://{}'", + std::string(FunctionName), + FunctionVersion, + BuildSystemVersion, + "<null>"); + + EmitFunctionListOnce(); + } + + ++FailedWorkCounter; + } + } + + // Check for completed work + + DrainCompletedJobs(); + SendOrchestratorHeartbeat(); + }, + TargetParallelism); + + // Wait until all pending work is complete + + while (!m_PendingJobs.IsEmpty()) + { + // TODO: improve this logic + zen::Sleep(500); + + DrainCompletedJobs(); + SendOrchestratorHeartbeat(); } + // Write summary files + + WriteSummaryFiles(); + if (FailedWorkCounter) { return 1; @@ -1174,37 +1078,91 @@ document.getElementById("csvBtn").addEventListener("click",()=>{ return 0; } -int -ExecCommand::LocalMessagingExecute() +////////////////////////////////////////////////////////////////////////// +// ExecHttpSubCmd + +ExecHttpSubCmd::ExecHttpSubCmd(ExecCommand& Parent) : ZenSubCmdBase("http", "Forward requests to HTTP compute service"), m_Parent(Parent) { - // Non-HTTP work submission path + SubOptions().add_option("", "u", "hosturl", ZenCmdBase::kHostUrlHelp, cxxopts::value(m_HostName), "<hosturl>"); +} + +void +ExecHttpSubCmd::Run(const ZenCliOptions& /*GlobalOptions*/) +{ + m_HostName = ZenCmdBase::ResolveTargetHostSpec(m_HostName); - // To be reimplemented using final transport + ZEN_ASSERT(m_Parent.m_ChunkResolver); + ChunkResolver& Resolver = *m_Parent.m_ChunkResolver; - return 0; + std::filesystem::path TempPath = std::filesystem::absolute(".zen_temp"); + + zen::compute::ComputeServiceSession ComputeSession(Resolver); + ComputeSession.AddRemoteRunner(Resolver, TempPath, m_HostName); + + Stopwatch ExecTimer; + int ReturnValue = m_Parent.RunSession(ComputeSession); + + ZEN_CONSOLE("complete - took {}", NiceTimeSpanMs(ExecTimer.GetElapsedTimeMs())); + + if (!ReturnValue) + { + ZEN_CONSOLE("all work items completed successfully"); + } + else + { + ZEN_CONSOLE("some work items failed (code {})", ReturnValue); + } } 
////////////////////////////////////////////////////////////////////////// +// ExecInprocSubCmd -int -ExecCommand::HttpExecute() +ExecInprocSubCmd::ExecInprocSubCmd(ExecCommand& Parent) : ZenSubCmdBase("inproc", "Handle execution in-process"), m_Parent(Parent) { - ZEN_ASSERT(m_ChunkResolver); - ChunkResolver& Resolver = *m_ChunkResolver; +} - std::filesystem::path TempPath = std::filesystem::absolute(".zen_temp"); +void +ExecInprocSubCmd::Run(const ZenCliOptions& /*GlobalOptions*/) +{ + ZEN_ASSERT(m_Parent.m_ChunkResolver); + ChunkResolver& Resolver = *m_Parent.m_ChunkResolver; zen::compute::ComputeServiceSession ComputeSession(Resolver); - ComputeSession.AddRemoteRunner(Resolver, TempPath, m_HostName); - return ExecUsingSession(ComputeSession); + std::filesystem::path TempPath = std::filesystem::absolute(".zen_temp"); + ComputeSession.AddLocalRunner(Resolver, TempPath); + + Stopwatch ExecTimer; + int ReturnValue = m_Parent.RunSession(ComputeSession); + + ZEN_CONSOLE("complete - took {}", NiceTimeSpanMs(ExecTimer.GetElapsedTimeMs())); + + if (!ReturnValue) + { + ZEN_CONSOLE("all work items completed successfully"); + } + else + { + ZEN_CONSOLE("some work items failed (code {})", ReturnValue); + } } -int -ExecCommand::BeaconExecute() +////////////////////////////////////////////////////////////////////////// +// ExecBeaconSubCmd + +ExecBeaconSubCmd::ExecBeaconSubCmd(ExecCommand& Parent) +: ZenSubCmdBase("beacon", "Execute via beacon/orchestrator discovery") +, m_Parent(Parent) { - ZEN_ASSERT(m_ChunkResolver); - ChunkResolver& Resolver = *m_ChunkResolver; + SubOptions().add_option("", "", "orch", "Orchestrator URL for worker discovery", cxxopts::value(m_OrchestratorUrl), "<url>"); + SubOptions().add_option("", "", "beacon", "Beacon path", cxxopts::value(m_BeaconPath), "<path>"); +} + +void +ExecBeaconSubCmd::Run(const ZenCliOptions& /*GlobalOptions*/) +{ + ZEN_ASSERT(m_Parent.m_ChunkResolver); + ChunkResolver& Resolver = *m_Parent.m_ChunkResolver; 
std::filesystem::path TempPath = std::filesystem::absolute(".zen_temp"); zen::compute::ComputeServiceSession ComputeSession(Resolver); @@ -1221,49 +1179,36 @@ ExecCommand::BeaconExecute() ComputeSession.AddRemoteRunner(Resolver, TempPath, "http://localhost:8558"); } - return ExecUsingSession(ComputeSession); -} - -////////////////////////////////////////////////////////////////////////// + Stopwatch ExecTimer; + int ReturnValue = m_Parent.RunSession(ComputeSession, m_OrchestratorUrl); -void -ExecCommand::RegisterWorkerFunctionsFromDescription(const CbObject& WorkerDesc, const IoHash& WorkerId) -{ - const Guid WorkerBuildSystemVersion = WorkerDesc["buildsystem_version"sv].AsUuid(); + ZEN_CONSOLE("complete - took {}", NiceTimeSpanMs(ExecTimer.GetElapsedTimeMs())); - for (auto& Item : WorkerDesc["functions"sv]) + if (!ReturnValue) { - CbObjectView Function = Item.AsObjectView(); - - std::string_view FunctionName = Function["name"sv].AsString(); - const Guid FunctionVersion = Function["version"sv].AsUuid(); - - m_FunctionList.emplace_back(FunctionDefinition{.FunctionName = std::string{FunctionName}, - .FunctionVersion = FunctionVersion, - .BuildSystemVersion = WorkerBuildSystemVersion, - .WorkerId = WorkerId}); + ZEN_CONSOLE("all work items completed successfully"); + } + else + { + ZEN_CONSOLE("some work items failed (code {})", ReturnValue); } } -void -ExecCommand::EmitFunctionListOnce(const std::vector<FunctionDefinition>& FunctionList) -{ - if (m_FunctionListEmittedOnce == false) - { - EmitFunctionList(FunctionList); +////////////////////////////////////////////////////////////////////////// +// ExecDumpSubCmd - m_FunctionListEmittedOnce = true; - } +ExecDumpSubCmd::ExecDumpSubCmd(ExecCommand& Parent) : ZenSubCmdBase("dump", "Dump high level information about actions"), m_Parent(Parent) +{ } -int -ExecCommand::DumpWorkItems() +void +ExecDumpSubCmd::Run(const ZenCliOptions& /*GlobalOptions*/) { std::atomic<int> EmittedCount{0}; eastl::hash_map<IoHash, uint64_t> 
SeenAttachments; // Attachment CID -> count of references - m_RecordingReader->IterateActions( + m_Parent.m_RecordingReader->IterateActions( [&](CbObject ActionObject, const IoHash& ActionId) { eastl::hash_map<IoHash, CompressedBuffer> Attachments; @@ -1272,7 +1217,7 @@ ExecCommand::DumpWorkItems() ActionObject.IterateAttachments([&](const zen::CbFieldView AttachmentField) { const IoHash AttachmentCid = AttachmentField.GetValue().AsHash(); - IoBuffer AttachmentData = m_ChunkResolver->FindChunkByCid(AttachmentCid); + IoBuffer AttachmentData = m_Parent.m_ChunkResolver->FindChunkByCid(AttachmentCid); IoHash RawHash; uint64_t RawSize = 0; CompressedBuffer CompressedData = CompressedBuffer::FromCompressed(SharedBuffer(AttachmentData), RawHash, RawSize); @@ -1322,36 +1267,191 @@ ExecCommand::DumpWorkItems() { ZEN_CONSOLE("{} attachments with {} references", Cids.size(), RefCount); } +} - return 0; +////////////////////////////////////////////////////////////////////////// +// ExecBuildlogSubCmd + +ExecBuildlogSubCmd::ExecBuildlogSubCmd(ExecCommand& Parent) : ZenSubCmdBase("buildlog", "Generate build actions log"), m_Parent(Parent) +{ +} + +void +ExecBuildlogSubCmd::Run(const ZenCliOptions& /*GlobalOptions*/) +{ + ZEN_ASSERT(m_Parent.m_ChunkResolver); + ChunkResolver& Resolver = *m_Parent.m_ChunkResolver; + + if (std::filesystem::exists(m_Parent.m_RecordingLogPath)) + { + m_Parent.ThrowOptionError(fmt::format("recording log directory '{}' already exists!", m_Parent.m_RecordingLogPath)); + } + + ZEN_NOT_IMPLEMENTED("build log generation not implemented yet!"); + + std::filesystem::path TempPath = std::filesystem::absolute(".zen_temp"); + + zen::compute::ComputeServiceSession ComputeSession(Resolver); + ComputeSession.StartRecording(Resolver, m_Parent.m_RecordingLogPath); + + Stopwatch ExecTimer; + int ReturnValue = m_Parent.RunSession(ComputeSession); + + ZEN_CONSOLE("complete - took {}", NiceTimeSpanMs(ExecTimer.GetElapsedTimeMs())); + + if (!ReturnValue) + { + 
ZEN_CONSOLE("all work items completed successfully"); + } + else + { + ZEN_CONSOLE("some work items failed (code {})", ReturnValue); + } } ////////////////////////////////////////////////////////////////////////// +// ExecCommand -int -ExecCommand::BuildActionsLog() +ExecCommand::ExecCommand() { - ZEN_ASSERT(m_ChunkResolver); - ChunkResolver& Resolver = *m_ChunkResolver; + m_Options.add_options()("h,help", "Print help"); + m_Options.add_option("replay", "p", "path", "Recording path (directory or .actionlog file)", cxxopts::value(m_RecordingPath), "<path>"); + m_Options.add_option("replay", "", "log", "Action log directory", cxxopts::value(m_RecordingLogPath), "<path>"); + m_Options.add_option("replay", "", "offset", "Recording replay start offset", cxxopts::value(m_Offset), "<offset>"); + m_Options.add_option("replay", "", "stride", "Recording replay stride", cxxopts::value(m_Stride), "<stride>"); + m_Options.add_option("replay", "", "limit", "Recording replay limit", cxxopts::value(m_Limit), "<limit>"); + m_Options.add_option("", "", "quiet", "Quiet mode (less logging)", cxxopts::value(m_Quiet), "<bool>"); + m_Options.add_option("output", + "", + "dump-actions", + "Dump each action to console as it is dispatched", + cxxopts::value(m_DumpActions), + "<bool>"); + m_Options.add_option("output", "o", "output", "Save action results to directory", cxxopts::value(m_OutputPath), "<path>"); + m_Options.add_option("output", "", "binary", "Write output as binary packages instead of YAML", cxxopts::value(m_Binary), "<bool>"); + m_Options.add_option("__hidden__", "", "subcommand", "", cxxopts::value<std::string>(m_SubCommand)->default_value(""), ""); + m_Options.parse_positional({"subcommand"}); + + AddSubCommand(m_HttpSubCmd); + AddSubCommand(m_InprocSubCmd); + AddSubCommand(m_BeaconSubCmd); + AddSubCommand(m_DumpSubCmd); + AddSubCommand(m_BuildlogSubCmd); +} +ExecCommand::~ExecCommand() +{ +} + +bool +ExecCommand::OnParentOptionsParsed(const ZenCliOptions& GlobalOptions) +{ 
if (m_RecordingPath.empty()) { - throw OptionParseException("need to specify recording path", m_Options.help()); + ThrowOptionError("replay path is required!"); } - if (std::filesystem::exists(m_RecordingLogPath)) + m_VerboseLogging = GlobalOptions.IsVerbose; + m_QuietLogging = m_Quiet && !m_VerboseLogging; + + // Gather information from recording path + + std::filesystem::path RecordingPath{m_RecordingPath}; + + if (!std::filesystem::is_directory(RecordingPath)) { - throw OptionParseException(fmt::format("recording log directory '{}' already exists!", m_RecordingLogPath), m_Options.help()); + ThrowOptionError("replay path should be a directory path!"); + } + else + { + if (std::filesystem::is_directory(RecordingPath / "cid")) + { + m_Reader = std::make_unique<zen::compute::RecordingReader>(RecordingPath); + m_WorkerMap = m_Reader->ReadWorkers(); + m_ChunkResolver = m_Reader.get(); + m_RecordingReader = m_Reader.get(); + } + else + { + m_UeReader = std::make_unique<zen::compute::UeRecordingReader>(RecordingPath); + m_WorkerMap = m_UeReader->ReadWorkers(); + m_ChunkResolver = m_UeReader.get(); + m_RecordingReader = m_UeReader.get(); + } } - ZEN_NOT_IMPLEMENTED("build log generation not implemented yet!"); + ZEN_CONSOLE("found {} workers, {} action items", m_WorkerMap.size(), m_RecordingReader->GetActionCount()); - std::filesystem::path TempPath = std::filesystem::absolute(".zen_temp"); + for (auto& Kv : m_WorkerMap) + { + CbObject WorkerDesc = Kv.second.GetObject(); + const IoHash& WorkerId = Kv.first; - zen::compute::ComputeServiceSession ComputeSession(Resolver); - ComputeSession.StartRecording(Resolver, m_RecordingLogPath); + RegisterWorkerFunctionsFromDescription(WorkerDesc, WorkerId); + + if (m_VerboseLogging) + { + zen::ExtendableStringBuilder<1024> ObjStr; +# if 0 + zen::CompactBinaryToJson(WorkerDesc, ObjStr); + ZEN_CONSOLE("worker {}: {}", WorkerId, ObjStr); +# else + zen::CompactBinaryToYaml(WorkerDesc, ObjStr); + ZEN_CONSOLE("worker {}:\n{}", WorkerId, 
ObjStr); +# endif + } + } + + if (m_VerboseLogging) + { + EmitFunctionList(m_FunctionList); + } + + return true; +} - return ExecUsingSession(ComputeSession); +int +ExecCommand::RunSession(zen::compute::ComputeServiceSession& ComputeSession, std::string_view OrchestratorUrl) +{ + ExecSessionConfig Config{ + .Resolver = *m_ChunkResolver, + .RecordingReader = *m_RecordingReader, + .WorkerMap = m_WorkerMap, + .FunctionList = m_FunctionList, + .OrchestratorUrl = OrchestratorUrl, + .OutputPath = m_OutputPath, + .Offset = m_Offset, + .Stride = m_Stride, + .Limit = m_Limit, + .Verbose = m_VerboseLogging, + .Quiet = m_QuietLogging, + .DumpActions = m_DumpActions, + .Binary = m_Binary, + }; + + ExecSessionRunner Runner(ComputeSession, Config); + return Runner.Run(); +} + +////////////////////////////////////////////////////////////////////////// + +void +ExecCommand::RegisterWorkerFunctionsFromDescription(const CbObject& WorkerDesc, const IoHash& WorkerId) +{ + const Guid WorkerBuildSystemVersion = WorkerDesc["buildsystem_version"sv].AsUuid(); + + for (auto& Item : WorkerDesc["functions"sv]) + { + CbObjectView Function = Item.AsObjectView(); + + std::string_view FunctionName = Function["name"sv].AsString(); + const Guid FunctionVersion = Function["version"sv].AsUuid(); + + m_FunctionList.emplace_back(FunctionDefinition{.FunctionName = std::string{FunctionName}, + .FunctionVersion = FunctionVersion, + .BuildSystemVersion = WorkerBuildSystemVersion, + .WorkerId = WorkerId}); + } } void diff --git a/src/zen/cmds/exec_cmd.h b/src/zen/cmds/exec_cmd.h index 6311354c0..c55412780 100644 --- a/src/zen/cmds/exec_cmd.h +++ b/src/zen/cmds/exec_cmd.h @@ -11,6 +11,7 @@ #include <filesystem> #include <functional> +#include <memory> #include <unordered_map> namespace zen { @@ -28,13 +29,79 @@ class ComputeServiceSession; namespace zen { +class ExecCommand; + +struct ExecFunctionDefinition +{ + std::string FunctionName; + zen::Guid FunctionVersion; + zen::Guid BuildSystemVersion; + 
zen::IoHash WorkerId; +}; + +////////////////////////////////////////////////////////////////////////// +// Subcommands + +class ExecHttpSubCmd : public ZenSubCmdBase +{ +public: + explicit ExecHttpSubCmd(ExecCommand& Parent); + void Run(const ZenCliOptions& GlobalOptions) override; + +private: + ExecCommand& m_Parent; + std::string m_HostName; +}; + +class ExecInprocSubCmd : public ZenSubCmdBase +{ +public: + explicit ExecInprocSubCmd(ExecCommand& Parent); + void Run(const ZenCliOptions& GlobalOptions) override; + +private: + ExecCommand& m_Parent; +}; + +class ExecBeaconSubCmd : public ZenSubCmdBase +{ +public: + explicit ExecBeaconSubCmd(ExecCommand& Parent); + void Run(const ZenCliOptions& GlobalOptions) override; + +private: + ExecCommand& m_Parent; + std::string m_OrchestratorUrl; + std::filesystem::path m_BeaconPath; +}; + +class ExecDumpSubCmd : public ZenSubCmdBase +{ +public: + explicit ExecDumpSubCmd(ExecCommand& Parent); + void Run(const ZenCliOptions& GlobalOptions) override; + +private: + ExecCommand& m_Parent; +}; + +class ExecBuildlogSubCmd : public ZenSubCmdBase +{ +public: + explicit ExecBuildlogSubCmd(ExecCommand& Parent); + void Run(const ZenCliOptions& GlobalOptions) override; + +private: + ExecCommand& m_Parent; +}; + /** * Zen CLI command for executing functions from a recording * * Mostly for testing and debugging purposes */ -class ExecCommand : public ZenCmdBase +class ExecCommand : public ZenCmdWithSubCommands { public: ExecCommand(); @@ -43,57 +110,47 @@ public: static constexpr char Name[] = "exec"; static constexpr char Description[] = "Execute functions from a recording"; - virtual void Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) override; - virtual cxxopts::Options& Options() override { return m_Options; } + cxxopts::Options& Options() override { return m_Options; } + + // Shared state & helpers (public for subcommand access) + + using FunctionDefinition = ExecFunctionDefinition; + + static void 
EmitFunctionList(const std::vector<FunctionDefinition>& FunctionList); + void RegisterWorkerFunctionsFromDescription(const zen::CbObject& WorkerDesc, const zen::IoHash& WorkerId); + + int RunSession(zen::compute::ComputeServiceSession& ComputeSession, std::string_view OrchestratorUrl = {}); + + std::unordered_map<zen::IoHash, zen::CbPackage> m_WorkerMap; + std::vector<FunctionDefinition> m_FunctionList; + zen::ChunkResolver* m_ChunkResolver = nullptr; + zen::compute::RecordingReaderBase* m_RecordingReader = nullptr; + bool m_VerboseLogging = false; + bool m_QuietLogging = false; + bool m_DumpActions = false; + std::filesystem::path m_OutputPath; + bool m_Binary = false; + std::filesystem::path m_RecordingLogPath; private: + bool OnParentOptionsParsed(const ZenCliOptions& GlobalOptions) override; + cxxopts::Options m_Options{Name, Description}; - std::string m_HostName; - std::string m_OrchestratorUrl; - std::filesystem::path m_BeaconPath; + std::string m_SubCommand; std::filesystem::path m_RecordingPath; - std::filesystem::path m_RecordingLogPath; int m_Offset = 0; int m_Stride = 1; int m_Limit = 0; bool m_Quiet = false; - std::string m_Mode{"http"}; - std::filesystem::path m_OutputPath; - bool m_Binary = false; - - struct FunctionDefinition - { - std::string FunctionName; - zen::Guid FunctionVersion; - zen::Guid BuildSystemVersion; - zen::IoHash WorkerId; - }; - - bool m_FunctionListEmittedOnce = false; - void EmitFunctionListOnce(const std::vector<FunctionDefinition>& FunctionList); - void EmitFunctionList(const std::vector<FunctionDefinition>& FunctionList); - - std::unordered_map<zen::IoHash, zen::CbPackage> m_WorkerMap; - std::vector<FunctionDefinition> m_FunctionList; - bool m_VerboseLogging = false; - bool m_QuietLogging = false; - bool m_DumpActions = false; - - zen::ChunkResolver* m_ChunkResolver = nullptr; - zen::compute::RecordingReaderBase* m_RecordingReader = nullptr; - - void RegisterWorkerFunctionsFromDescription(const zen::CbObject& WorkerDesc, 
const zen::IoHash& WorkerId); - - int ExecUsingSession(zen::compute::ComputeServiceSession& ComputeSession); - // Execution modes + std::unique_ptr<zen::compute::RecordingReader> m_Reader; + std::unique_ptr<zen::compute::UeRecordingReader> m_UeReader; - int DumpWorkItems(); - int HttpExecute(); - int InProcessExecute(); - int LocalMessagingExecute(); - int BeaconExecute(); - int BuildActionsLog(); + ExecHttpSubCmd m_HttpSubCmd{*this}; + ExecInprocSubCmd m_InprocSubCmd{*this}; + ExecBeaconSubCmd m_BeaconSubCmd{*this}; + ExecDumpSubCmd m_DumpSubCmd{*this}; + ExecBuildlogSubCmd m_BuildlogSubCmd{*this}; }; } // namespace zen diff --git a/src/zen/cmds/service_cmd.cpp b/src/zen/cmds/service_cmd.cpp index 3347f1afe..37baf5483 100644 --- a/src/zen/cmds/service_cmd.cpp +++ b/src/zen/cmds/service_cmd.cpp @@ -310,17 +310,34 @@ ServiceCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) std::vector<char*> SubCommandArguments; cxxopts::Options* SubOption = nullptr; int ParentCommandArgCount = GetSubCommand(m_Options, argc, argv, m_SubCommands, SubOption, SubCommandArguments); - if (!ParseOptions(ParentCommandArgCount, argv)) + + if (SubOption == nullptr) + { + if (!ParseOptions(ParentCommandArgCount, argv)) + { + return; + } + throw OptionParseException("'verb' option is required", m_Options.help()); + } + + // Parse subcommand permissively — forward unrecognised options to the parent parser. + std::vector<std::string> SubUnmatched; + if (!ParseOptionsPermissive(*SubOption, gsl::narrow<int>(SubCommandArguments.size()), SubCommandArguments.data(), SubUnmatched)) { return; } - if (SubOption == nullptr) + // Build parent arg list: original parent args (without subcommand name) + forwarded unmatched. 
+ std::vector<char*> ParentArgs; + ParentArgs.reserve(static_cast<size_t>(ParentCommandArgCount - 1) + SubUnmatched.size()); + ParentArgs.push_back(argv[0]); + std::copy(argv + 1, argv + ParentCommandArgCount - 1, std::back_inserter(ParentArgs)); + for (std::string& Arg : SubUnmatched) { - throw OptionParseException("'verb' option is required", m_Options.help()); + ParentArgs.push_back(Arg.data()); } - if (!ParseOptions(*SubOption, gsl::narrow<int>(SubCommandArguments.size()), SubCommandArguments.data())) + if (!ParseOptions(static_cast<int>(ParentArgs.size()), ParentArgs.data())) { return; } diff --git a/src/zen/cmds/workspaces_cmd.cpp b/src/zen/cmds/workspaces_cmd.cpp index 220ef6a9e..9e49b464e 100644 --- a/src/zen/cmds/workspaces_cmd.cpp +++ b/src/zen/cmds/workspaces_cmd.cpp @@ -127,14 +127,36 @@ WorkspaceCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) std::vector<char*> SubCommandArguments; cxxopts::Options* SubOption = nullptr; int ParentCommandArgCount = GetSubCommand(m_Options, argc, argv, m_SubCommands, SubOption, SubCommandArguments); - if (!ParseOptions(ParentCommandArgCount, argv)) + + if (SubOption == nullptr) + { + if (!ParseOptions(ParentCommandArgCount, argv)) + { + return; + } + throw OptionParseException("'verb' option is required", m_Options.help()); + } + + // Parse subcommand permissively — forward unrecognised options to the parent parser. + std::vector<std::string> SubUnmatched; + if (!ParseOptionsPermissive(*SubOption, gsl::narrow<int>(SubCommandArguments.size()), SubCommandArguments.data(), SubUnmatched)) { return; } - if (SubOption == nullptr) + // Build parent arg list: original parent args (without subcommand name) + forwarded unmatched. 
+ std::vector<char*> ParentArgs; + ParentArgs.reserve(static_cast<size_t>(ParentCommandArgCount - 1) + SubUnmatched.size()); + ParentArgs.push_back(argv[0]); + std::copy(argv + 1, argv + ParentCommandArgCount - 1, std::back_inserter(ParentArgs)); + for (std::string& Arg : SubUnmatched) { - throw OptionParseException("'verb' option is required", m_Options.help()); + ParentArgs.push_back(Arg.data()); + } + + if (!ParseOptions(static_cast<int>(ParentArgs.size()), ParentArgs.data())) + { + return; } m_HostName = ResolveTargetHostSpec(m_HostName); @@ -150,11 +172,6 @@ WorkspaceCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) std::filesystem::path StatePath = m_SystemRootDir / "workspaces"; - if (!ParseOptions(*SubOption, gsl::narrow<int>(SubCommandArguments.size()), SubCommandArguments.data())) - { - return; - } - if (SubOption == &m_CreateOptions) { if (m_Path.empty()) @@ -376,14 +393,36 @@ WorkspaceShareCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** std::vector<char*> SubCommandArguments; cxxopts::Options* SubOption = nullptr; int ParentCommandArgCount = GetSubCommand(m_Options, argc, argv, m_SubCommands, SubOption, SubCommandArguments); - if (!ParseOptions(ParentCommandArgCount, argv)) + + if (SubOption == nullptr) + { + if (!ParseOptions(ParentCommandArgCount, argv)) + { + return; + } + throw OptionParseException("'verb' option is required", m_Options.help()); + } + + // Parse subcommand permissively — forward unrecognised options to the parent parser. + std::vector<std::string> SubUnmatched; + if (!ParseOptionsPermissive(*SubOption, gsl::narrow<int>(SubCommandArguments.size()), SubCommandArguments.data(), SubUnmatched)) { return; } - if (SubOption == nullptr) + // Build parent arg list: original parent args (without subcommand name) + forwarded unmatched. 
+ std::vector<char*> ParentArgs; + ParentArgs.reserve(static_cast<size_t>(ParentCommandArgCount - 1) + SubUnmatched.size()); + ParentArgs.push_back(argv[0]); + std::copy(argv + 1, argv + ParentCommandArgCount - 1, std::back_inserter(ParentArgs)); + for (std::string& Arg : SubUnmatched) { - throw OptionParseException("'verb' option is required", m_Options.help()); + ParentArgs.push_back(Arg.data()); + } + + if (!ParseOptions(static_cast<int>(ParentArgs.size()), ParentArgs.data())) + { + return; } m_HostName = ResolveTargetHostSpec(m_HostName); @@ -403,11 +442,6 @@ WorkspaceShareCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** std::filesystem::path StatePath = m_SystemRootDir / "workspaces"; - if (!ParseOptions(*SubOption, gsl::narrow<int>(SubCommandArguments.size()), SubCommandArguments.data())) - { - return; - } - if (SubOption == &m_CreateOptions) { if (m_WorkspaceRoot.empty()) diff --git a/src/zen/zen.cpp b/src/zen/zen.cpp index 86154c291..849d7a075 100644 --- a/src/zen/zen.cpp +++ b/src/zen/zen.cpp @@ -115,6 +115,21 @@ ZenCmdBase::CommandCategory() const return DefaultCategory; } +std::string +ZenCmdBase::HelpText() +{ + std::vector<std::string> Groups = Options().groups(); + Groups.erase(std::remove(Groups.begin(), Groups.end(), std::string("__hidden__")), Groups.end()); + Options().set_width(TuiConsoleColumns(80)); + return Options().help(Groups); +} + +void +ZenCmdBase::ThrowOptionError(std::string_view Message) +{ + throw OptionParseException(std::string(Message), HelpText()); +} + bool ZenCmdBase::ParseOptions(int argc, char** argv) { @@ -166,6 +181,35 @@ ZenCmdBase::ParseOptions(cxxopts::Options& CmdOptions, int argc, char** argv) return true; } +bool +ZenCmdBase::ParseOptionsPermissive(cxxopts::Options& CmdOptions, int argc, char** argv, std::vector<std::string>& OutUnmatched) +{ + CmdOptions.set_width(TuiConsoleColumns(80)); + CmdOptions.allow_unrecognised_options(); + + cxxopts::ParseResult Result; + + try + { + Result = 
CmdOptions.parse(argc, argv); + } + catch (const std::exception& Ex) + { + throw zen::OptionParseException(Ex.what(), CmdOptions.help()); + } + + CmdOptions.show_positional_help(); + + if (Result.count("help")) + { + printf("%s\n", CmdOptions.help().c_str()); + return false; + } + + OutUnmatched = Result.unmatched(); + return true; +} + // Get the number of args including the sub command // Build an array for sub command to parse int @@ -243,6 +287,23 @@ ZenCmdWithSubCommands::PrintHelp() } void +ZenCmdWithSubCommands::PrintSubCommandHelp(cxxopts::Options& SubCmdOptions) +{ + // Show the subcommand's own options. + SubCmdOptions.set_width(TuiConsoleColumns(80)); + printf("%s\n", SubCmdOptions.help().c_str()); + + // Show the parent command's options (excluding the hidden positional group). + std::vector<std::string> ParentGroups = Options().groups(); + ParentGroups.erase(std::remove(ParentGroups.begin(), ParentGroups.end(), std::string("__hidden__")), ParentGroups.end()); + + Options().set_width(TuiConsoleColumns(80)); + printf("%s options:\n%s\n", Options().program().c_str(), Options().help(ParentGroups).c_str()); + + printf("For global options run: zen --help\n"); +} + +void ZenCmdWithSubCommands::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) { std::vector<cxxopts::Options*> SubOptionPtrs; @@ -270,42 +331,15 @@ ZenCmdWithSubCommands::Run(const ZenCliOptions& GlobalOptions, int argc, char** } } - // Parse parent options. When a subcommand was matched we strip its name from - // the arg list so the parent parser does not see it as an unmatched positional. 
- if (MatchedSubOption != nullptr) - { - std::vector<char*> ParentArgs; - ParentArgs.reserve(static_cast<size_t>(ParentArgc - 1)); - ParentArgs.push_back(argv[0]); - std::copy(argv + 1, argv + ParentArgc - 1, std::back_inserter(ParentArgs)); - if (!ParseOptions(Options(), static_cast<int>(ParentArgs.size()), ParentArgs.data())) - { - return; - } - } - else + if (MatchedSubOption == nullptr) { if (!ParseOptions(Options(), ParentArgc, argv)) { return; } - } - - if (MatchedSubOption == nullptr) - { PrintHelp(); - - ExtendableStringBuilder<128> VerbList; - for (bool First = true; ZenSubCmdBase * SubCmd : m_SubCommands) - { - if (!First) - { - VerbList.Append(", "); - } - VerbList.Append(SubCmd->SubOptions().program()); - First = false; - } - throw OptionParseException(fmt::format("No subcommand specified. Available subcommands: {}", VerbList.ToView()), {}); + fflush(stdout); + throw OptionParseException("No subcommand specified", {}); } ZenSubCmdBase* MatchedSubCmd = nullptr; @@ -319,9 +353,40 @@ ZenCmdWithSubCommands::Run(const ZenCliOptions& GlobalOptions, int argc, char** } ZEN_ASSERT(MatchedSubCmd != nullptr); - // Parse subcommand args before OnParentOptionsParsed so --help on the subcommand - // works without requiring parent options like --hosturl to be populated. - if (!ParseOptions(*MatchedSubOption, gsl::narrow<int>(SubCommandArguments.size()), SubCommandArguments.data())) + // Intercept --help/-h in the subcommand args so we can show combined help + // (subcommand options + parent options) without requiring parent options to + // be populated. + for (size_t i = 1; i < SubCommandArguments.size(); ++i) + { + std::string_view Arg(SubCommandArguments[i]); + if (Arg == "--help" || Arg == "-h") + { + PrintSubCommandHelp(*MatchedSubOption); + return; + } + } + + // Parse subcommand args permissively — unrecognised options are collected + // and forwarded to the parent parser so that parent options (e.g. 
--path) + // can appear after the subcommand name on the command line. + std::vector<std::string> SubUnmatched; + if (!ParseOptionsPermissive(*MatchedSubOption, gsl::narrow<int>(SubCommandArguments.size()), SubCommandArguments.data(), SubUnmatched)) + { + return; + } + + // Build parent arg list: original parent args (without subcommand name) + forwarded unmatched. + std::vector<std::string> UnmatchedStorage = std::move(SubUnmatched); + std::vector<char*> ParentArgs; + ParentArgs.reserve(static_cast<size_t>(ParentArgc - 1) + UnmatchedStorage.size()); + ParentArgs.push_back(argv[0]); + std::copy(argv + 1, argv + ParentArgc - 1, std::back_inserter(ParentArgs)); + for (std::string& Arg : UnmatchedStorage) + { + ParentArgs.push_back(Arg.data()); + } + + if (!ParseOptions(Options(), static_cast<int>(ParentArgs.size()), ParentArgs.data())) { return; } diff --git a/src/zen/zen.h b/src/zen/zen.h index 05ce32d0a..97cc9af6f 100644 --- a/src/zen/zen.h +++ b/src/zen/zen.h @@ -60,6 +60,7 @@ public: bool ParseOptions(int argc, char** argv); static bool ParseOptions(cxxopts::Options& Options, int argc, char** argv); + static bool ParseOptionsPermissive(cxxopts::Options& Options, int argc, char** argv, std::vector<std::string>& OutUnmatched); static int GetSubCommand(cxxopts::Options& Options, int argc, char** argv, @@ -74,7 +75,9 @@ public: static constexpr const char* kHostUrlHelp = "Host URL or unix:///path/to/socket"; - static void LogExecutableVersionAndPid(); + std::string HelpText(); + [[noreturn]] void ThrowOptionError(std::string_view Message); + static void LogExecutableVersionAndPid(); }; class StorageCommand : public ZenCmdBase @@ -114,6 +117,7 @@ protected: void AddSubCommand(ZenSubCmdBase& SubCmd); virtual bool OnParentOptionsParsed(const ZenCliOptions& GlobalOptions); void PrintHelp(); + void PrintSubCommandHelp(cxxopts::Options& SubCmdOptions); private: std::vector<ZenSubCmdBase*> m_SubCommands; diff --git a/src/zencompute/CLAUDE.md b/src/zencompute/CLAUDE.md index 
f5188123f..a1a39fc3c 100644 --- a/src/zencompute/CLAUDE.md +++ b/src/zencompute/CLAUDE.md @@ -46,9 +46,12 @@ Public API entry point. Uses PIMPL (`struct Impl` in computeservice.cpp). Owns: - Action maps: `m_PendingActions`, `m_RunningMap`, `m_ResultsMap` - Queue map: `m_Queues` (QueueEntry objects) - Action history ring: `m_ActionHistory` (bounded deque, default 1000) +- WebSocket client (`m_OrchestratorWsClient`) subscribed to the orchestrator's `/orch/ws` push for instant worker discovery **Session states:** Created → Ready → Draining → Paused → Abandoned → Sunset. Both Abandoned and Sunset can be jumped to from any earlier state. Abandoned is used for spot instance termination grace periods — on entry, all pending and running actions are immediately marked as `RunnerAction::State::Abandoned` and running processes are best-effort cancelled. Auto-retry is suppressed while the session is Abandoned. `IsHealthy()` returns false for Abandoned and Sunset. +**Convenience helpers:** `Ready()`, `Abandon()`, `SetOrchestrator(Endpoint, BasePath)` are inline wrappers for common state transitions and orchestrator configuration. + ### `RunnerAction` (runners/functionrunner.h) Shared ref-counted struct representing one action through its lifecycle. @@ -67,8 +70,11 @@ New → Pending → Submitting → Running → Completed → Failed → Abandoned → Cancelled + → Retracted ``` -`SetActionState()` rejects non-forward transitions. The one exception is `ResetActionStateToPending()`, which uses CAS to atomically transition from Failed/Abandoned back to Pending for rescheduling. It clears timestamps from Submitting onward, resets execution fields, increments `RetryCount`, and calls `PostUpdate()` to re-enter the scheduler pipeline. +`SetActionState()` rejects non-forward transitions (Retracted has the highest ordinal so runner-side transitions cannot override it). 
`ResetActionStateToPending()` uses CAS to atomically transition from Failed/Abandoned back to Pending for rescheduling — it clears timestamps from Submitting onward, resets execution fields, increments `RetryCount`, and calls `PostUpdate()` to re-enter the scheduler pipeline. + +**Retracted state:** An explicit, instigator-initiated request to pull an action back and reschedule it on a different runner (e.g. capacity opened up elsewhere). Unlike Failed/Abandoned auto-retry, rescheduling from Retracted does not increment `RetryCount` since nothing went wrong. Retraction is idempotent and can target Pending, Submitting, or Running actions. ### `LocalProcessRunner` (runners/localrunner.h) Base for all local execution. Platform runners subclass this and override: @@ -90,10 +96,29 @@ Base for all local execution. Platform runners subclass this and override: - macOS: `proc_pidinfo(PROC_PIDTASKINFO)` pti_total_user+system nanoseconds ÷ 1,000,000,000 ### `FunctionRunner` / `RunnerGroup` (runners/functionrunner.h) -Abstract base for runners. `RunnerGroup<T>` holds a vector of runners and load-balances across them using a round-robin atomic index. `BaseRunnerGroup::SubmitActions()` distributes a batch proportionally based on per-runner capacity. +Abstract base for runners. `RunnerGroup<T>` holds a vector of runners and load-balances across them using a round-robin atomic index. `BaseRunnerGroup::SubmitActions()` distributes a batch proportionally based on per-runner capacity. `SubmitActions()` supports batch submission — actions are grouped and forwarded in chunks. + +### `RemoteHttpRunner` (runners/remotehttprunner.h) +Submits actions to remote zenserver instances over HTTP. Key features: +- **WebSocket completion notifications**: connects a WS client to `/compute/ws` on the remote. When a message arrives (action completed), the monitor thread wakes immediately instead of polling. Falls back to adaptive polling (200ms→50ms) when WS is unavailable. 
+- **Batch submission**: groups actions by queue and submits in configurable chunks (`m_MaxBatchSize`, default 50), falling back to individual submission on failure. +- **Queue cancellation**: `CancelRemoteQueue()` sends cancel requests to the remote. +- **Graceful shutdown**: `Shutdown()` closes the WS client, cancels all remote queues, stops the monitor thread, and marks remaining actions as Failed. ### `HttpComputeService` (include/zencompute/httpcomputeservice.h) -Wraps `ComputeServiceSession` as an HTTP service. All routes are registered in the constructor. Handles CbPackage attachment ingestion from `CidStore` before forwarding to the service. +Wraps `ComputeServiceSession` as an HTTP service. All routes are registered in the constructor. Handles CbPackage attachment ingestion from `CidStore` before forwarding to the service. Supports both single-action and batch (actions array) payloads via a shared `HandleSubmitAction` helper. + +## Orchestrator Discovery + +`ComputeServiceSession` discovers remote workers via the orchestrator endpoint (`SetOrchestratorEndpoint()`). Two complementary mechanisms: + +1. **Polling** (`UpdateCoordinatorState`): `GET /orch/agents` on the scheduler thread, throttled to every 5s (500ms when no workers are known yet). Discovers new workers and removes stale/unreachable ones. + +2. **WebSocket push** (`OrchestratorWsHandler`): connects to `/orch/ws` on the orchestrator at setup time. When the orchestrator broadcasts a state change, the handler sets `m_OrchestratorQueryForced` and signals the scheduler event, bypassing the polling throttle. Falls back silently to polling if the WS connection fails. + +`NotifyOrchestratorChanged()` is the public API to trigger an immediate re-query — useful in tests and for external notification sources. + +Use `HttpToWsUrl(Endpoint, Path)` from `zenhttp/httpwsclient.h` to convert HTTP(S) endpoints to WebSocket URLs. This helper is shared across all WS client setup sites in the codebase. 
## Action Lifecycle (End to End) @@ -118,6 +143,8 @@ Actions that fail or are abandoned can be automatically retried or manually resc **Retry limit:** Default of 3, overridable per-queue via the `max_retries` integer field in the queue's `Config` CbObject (set at `CreateQueue` time). Both automatic and manual paths respect this limit. +**Retraction (API path):** `RetractAction(Lsn)` pulls a Pending/Submitting/Running action back for rescheduling on a different runner. The action transitions to Retracted, then `ResetActionStateToPending()` is called *without* incrementing `RetryCount`. Retraction is idempotent. + **Cancelled actions are never retried** — cancellation is an intentional user action, not a transient failure. ## Queue System @@ -161,8 +188,9 @@ All routes registered in `HttpComputeService` constructor. Prefix is configured | GET | `jobs/running` | In-flight actions with CPU metrics | | GET | `jobs/completed` | Actions with results available | | GET/POST/DELETE | `jobs/{lsn}` | GET: result; POST: reschedule failed action; DELETE: retire | +| POST | `jobs/{lsn}/retract` | Retract a pending/running action for rescheduling (idempotent) | | POST | `jobs/{worker}` | Submit action for specific worker | -| POST | `jobs` | Submit action (worker resolved from descriptor) | +| POST | `jobs` | Submit action (or batch via `actions` array) | | GET | `workers` | List worker IDs | | GET | `workers/all` | All workers with full descriptors | | GET/POST | `workers/{worker}` | Get/register worker | @@ -179,8 +207,9 @@ Queue ref is capture(1) in all `queues/{queueref}/...` routes. 
| GET | `queues/{queueref}/completed` | Queue's completed results | | GET | `queues/{queueref}/history` | Queue's action history | | GET | `queues/{queueref}/running` | Queue's running actions | -| POST | `queues/{queueref}/jobs` | Submit to queue | +| POST | `queues/{queueref}/jobs` | Submit to queue (or batch via `actions` array) | | GET/POST | `queues/{queueref}/jobs/{lsn}` | GET: result; POST: reschedule | +| POST | `queues/{queueref}/jobs/{lsn}/retract` | Retract action for rescheduling | | GET/POST | `queues/{queueref}/workers/...` | Worker endpoints (same as global) | Worker handler logic is extracted into private helpers (`HandleWorkersGet`, `HandleWorkersAllGet`, `HandleWorkerRequest`) shared by top-level and queue-scoped routes. diff --git a/src/zencompute/cloudmetadata.cpp b/src/zencompute/cloudmetadata.cpp index 65bac895f..eb4c05f9f 100644 --- a/src/zencompute/cloudmetadata.cpp +++ b/src/zencompute/cloudmetadata.cpp @@ -23,22 +23,6 @@ static constexpr std::string_view kImdsEndpoint = "http://169.254.169.254"; // is a local service on the hypervisor so 200ms is generous for actual cloud VMs. 
static constexpr auto kImdsTimeout = std::chrono::milliseconds{200}; -std::string_view -ToString(CloudProvider Provider) -{ - switch (Provider) - { - case CloudProvider::AWS: - return "AWS"; - case CloudProvider::Azure: - return "Azure"; - case CloudProvider::GCP: - return "GCP"; - default: - return "None"; - } -} - CloudMetadata::CloudMetadata(std::filesystem::path DataDir) : CloudMetadata(std::move(DataDir), std::string(kImdsEndpoint)) { } @@ -610,7 +594,7 @@ CloudMetadata::PollGCPTermination() #if ZEN_WITH_TESTS -# include <zencompute/mockimds.h> +# include <zenutil/cloud/mockimds.h> # include <zencore/filesystem.h> # include <zencore/testing.h> diff --git a/src/zencompute/computeservice.cpp b/src/zencompute/computeservice.cpp index 838d741b6..92901de64 100644 --- a/src/zencompute/computeservice.cpp +++ b/src/zencompute/computeservice.cpp @@ -33,6 +33,7 @@ # include <zenutil/workerpools.h> # include <zentelemetry/stats.h> # include <zenhttp/httpclient.h> +# include <zenhttp/httpwsclient.h> # include <set> # include <deque> @@ -42,6 +43,7 @@ # include <unordered_set> ZEN_THIRD_PARTY_INCLUDES_START +# include <EASTL/fixed_vector.h> # include <EASTL/hash_set.h> ZEN_THIRD_PARTY_INCLUDES_END @@ -95,6 +97,14 @@ using SessionState = ComputeServiceSession::SessionState; static_assert(ZEN_ARRAY_COUNT(ComputeServiceSession::ActionHistoryEntry::Timestamps) == static_cast<size_t>(RunnerAction::State::_Count)); +static ComputeServiceSession::EnqueueResult +MakeErrorResult(std::string_view Error) +{ + CbObjectWriter Writer; + Writer << "error"sv << Error; + return {0, Writer.Save()}; +} + ////////////////////////////////////////////////////////////////////////// struct ComputeServiceSession::Impl @@ -130,14 +140,40 @@ struct ComputeServiceSession::Impl void SetOrchestratorEndpoint(std::string_view Endpoint); void SetOrchestratorBasePath(std::filesystem::path BasePath); + void NotifyOrchestratorChanged(); std::string m_OrchestratorEndpoint; std::filesystem::path 
m_OrchestratorBasePath; Stopwatch m_OrchestratorQueryTimer; + std::atomic<bool> m_OrchestratorQueryForced{false}; std::unordered_set<std::string> m_KnownWorkerUris; void UpdateCoordinatorState(); + // WebSocket subscription to orchestrator push notifications + struct OrchestratorWsHandler : public IWsClientHandler + { + Impl& Owner; + + explicit OrchestratorWsHandler(Impl& InOwner) : Owner(InOwner) {} + + void OnWsOpen() override + { + ZEN_LOG_INFO(Owner.m_Log, "orchestrator WebSocket connected"); + Owner.NotifyOrchestratorChanged(); + } + + void OnWsMessage(const WebSocketMessage&) override { Owner.NotifyOrchestratorChanged(); } + + void OnWsClose(uint16_t Code, std::string_view Reason) override + { + ZEN_LOG_WARN(Owner.m_Log, "orchestrator WebSocket closed (code {}: {})", Code, Reason); + } + }; + + std::unique_ptr<OrchestratorWsHandler> m_OrchestratorWsHandler; + std::unique_ptr<HttpWsClient> m_OrchestratorWsClient; + // Worker registration and discovery struct FunctionDefinition @@ -157,6 +193,8 @@ struct ComputeServiceSession::Impl std::atomic<int32_t> m_ActionsCounter = 0; // sequence number metrics::Meter m_ArrivalRate; + std::atomic<IComputeCompletionObserver*> m_CompletionObserver{nullptr}; + RwLock m_PendingLock; std::map<int, Ref<RunnerAction>> m_PendingActions; @@ -267,6 +305,8 @@ struct ComputeServiceSession::Impl void DrainQueue(int QueueId); ComputeServiceSession::EnqueueResult EnqueueActionToQueue(int QueueId, CbObject ActionObject, int Priority); ComputeServiceSession::EnqueueResult EnqueueResolvedActionToQueue(int QueueId, WorkerDesc Worker, CbObject ActionObj, int Priority); + ComputeServiceSession::EnqueueResult ValidateQueueForEnqueue(int QueueId, Ref<QueueEntry>& OutQueue); + void ActivateActionInQueue(const Ref<QueueEntry>& Queue, int Lsn); void GetQueueCompleted(int QueueId, CbWriter& Cbo); void NotifyQueueActionComplete(int QueueId, int Lsn, RunnerAction::State ActionState); void ExpireCompletedQueues(); @@ -292,11 +332,13 @@ struct 
ComputeServiceSession::Impl void HandleActionUpdates(); void PostUpdate(RunnerAction* Action); + void RemoveActionFromActiveMaps(int ActionLsn); static constexpr int kDefaultMaxRetries = 3; int GetMaxRetriesForQueue(int QueueId); ComputeServiceSession::RescheduleResult RescheduleAction(int ActionLsn); + ComputeServiceSession::RescheduleResult RetractAction(int ActionLsn); ActionCounts GetActionCounts() { @@ -449,6 +491,28 @@ void ComputeServiceSession::Impl::SetOrchestratorEndpoint(std::string_view Endpoint) { m_OrchestratorEndpoint = Endpoint; + + // Subscribe to orchestrator WebSocket push so we discover worker changes + // immediately instead of waiting for the next polling cycle. + try + { + std::string WsUrl = HttpToWsUrl(Endpoint, "/orch/ws"); + + m_OrchestratorWsHandler = std::make_unique<OrchestratorWsHandler>(*this); + + HttpWsClientSettings WsSettings; + WsSettings.LogCategory = "orch_disc_ws"; + WsSettings.ConnectTimeout = std::chrono::milliseconds{3000}; + + m_OrchestratorWsClient = std::make_unique<HttpWsClient>(WsUrl, *m_OrchestratorWsHandler, WsSettings); + m_OrchestratorWsClient->Connect(); + } + catch (const std::exception& Ex) + { + ZEN_WARN("failed to connect orchestrator WebSocket, falling back to polling: {}", Ex.what()); + m_OrchestratorWsClient.reset(); + m_OrchestratorWsHandler.reset(); + } } void @@ -458,6 +522,13 @@ ComputeServiceSession::Impl::SetOrchestratorBasePath(std::filesystem::path BaseP } void +ComputeServiceSession::Impl::NotifyOrchestratorChanged() +{ + m_OrchestratorQueryForced.store(true, std::memory_order_relaxed); + m_SchedulingThreadEvent.Set(); +} + +void ComputeServiceSession::Impl::UpdateCoordinatorState() { ZEN_TRACE_CPU("ComputeServiceSession::UpdateCoordinatorState"); @@ -467,10 +538,14 @@ ComputeServiceSession::Impl::UpdateCoordinatorState() } // Poll faster when we have no discovered workers yet so remote runners come online quickly - const uint64_t PollIntervalMs = m_KnownWorkerUris.empty() ? 
500 : 5000; - if (m_OrchestratorQueryTimer.GetElapsedTimeMs() < PollIntervalMs) + const bool Forced = m_OrchestratorQueryForced.exchange(false, std::memory_order_relaxed); + if (!Forced) { - return; + const uint64_t PollIntervalMs = m_KnownWorkerUris.empty() ? 500 : 5000; + if (m_OrchestratorQueryTimer.GetElapsedTimeMs() < PollIntervalMs) + { + return; + } } m_OrchestratorQueryTimer.Reset(); @@ -520,7 +595,24 @@ ComputeServiceSession::Impl::UpdateCoordinatorState() continue; } - ZEN_INFO("discovered new worker at {}", UriStr); + std::string_view Hostname = Worker["hostname"sv].AsString(); + std::string_view Platform = Worker["platform"sv].AsString(); + int Cpus = Worker["cpus"sv].AsInt32(); + uint64_t MemTotal = Worker["memory_total"sv].AsUInt64(); + + if (!Hostname.empty()) + { + ZEN_INFO("discovered new worker at {} ({}, {}, {} cpus, {:.1f} GB)", + UriStr, + Hostname, + Platform, + Cpus, + static_cast<double>(MemTotal) / (1024.0 * 1024.0 * 1024.0)); + } + else + { + ZEN_INFO("discovered new worker at {}", UriStr); + } m_KnownWorkerUris.insert(UriStr); @@ -598,6 +690,15 @@ ComputeServiceSession::Impl::Shutdown() { RequestStateTransition(SessionState::Sunset); + // Close orchestrator WebSocket before stopping the scheduler thread + // to prevent callbacks into a shutting-down scheduler. + if (m_OrchestratorWsClient) + { + m_OrchestratorWsClient->Close(); + m_OrchestratorWsClient.reset(); + } + m_OrchestratorWsHandler.reset(); + m_SchedulingThreadEnabled = false; m_SchedulingThreadEvent.Set(); if (m_SchedulingThread.joinable()) @@ -720,8 +821,14 @@ ComputeServiceSession::Impl::RegisterWorker(CbPackage Worker) // different descriptor. 
Thus we only need to call this the first time, when the // worker is added - m_LocalRunnerGroup.RegisterWorker(Worker); - m_RemoteRunnerGroup.RegisterWorker(Worker); + if (!m_LocalRunnerGroup.RegisterWorker(Worker)) + { + ZEN_WARN("failed to register worker {} on one or more local runners", WorkerId); + } + if (!m_RemoteRunnerGroup.RegisterWorker(Worker)) + { + ZEN_WARN("failed to register worker {} on one or more remote runners", WorkerId); + } if (m_Recorder) { @@ -767,7 +874,10 @@ ComputeServiceSession::Impl::SyncWorkersToRunner(FunctionRunner& Runner) for (const CbPackage& Worker : Workers) { - Runner.RegisterWorker(Worker); + if (!Runner.RegisterWorker(Worker)) + { + ZEN_WARN("failed to sync worker {} to runner", Worker.GetObjectHash()); + } } } @@ -868,9 +978,7 @@ ComputeServiceSession::Impl::EnqueueResolvedAction(int QueueId, WorkerDesc Worke if (m_SessionState.load(std::memory_order_relaxed) != SessionState::Ready) { - CbObjectWriter Writer; - Writer << "error"sv << fmt::format("session is not accepting actions (state: {})", ToString(m_SessionState.load())); - return {0, Writer.Save()}; + return MakeErrorResult(fmt::format("session is not accepting actions (state: {})", ToString(m_SessionState.load()))); } const int ActionLsn = ++m_ActionsCounter; @@ -1258,42 +1366,51 @@ ComputeServiceSession::Impl::DrainQueue(int QueueId) } ComputeServiceSession::EnqueueResult -ComputeServiceSession::Impl::EnqueueActionToQueue(int QueueId, CbObject ActionObject, int Priority) +ComputeServiceSession::Impl::ValidateQueueForEnqueue(int QueueId, Ref<QueueEntry>& OutQueue) { - Ref<QueueEntry> Queue = FindQueue(QueueId); + OutQueue = FindQueue(QueueId); - if (!Queue) + if (!OutQueue) { - CbObjectWriter Writer; - Writer << "error"sv - << "queue not found"sv; - return {0, Writer.Save()}; + return MakeErrorResult("queue not found"sv); } - QueueState QState = Queue->State.load(); + QueueState QState = OutQueue->State.load(); if (QState == QueueState::Cancelled) { - CbObjectWriter 
Writer; - Writer << "error"sv - << "queue is cancelled"sv; - return {0, Writer.Save()}; + return MakeErrorResult("queue is cancelled"sv); } if (QState == QueueState::Draining) { - CbObjectWriter Writer; - Writer << "error"sv - << "queue is draining"sv; - return {0, Writer.Save()}; + return MakeErrorResult("queue is draining"sv); + } + + return {}; +} + +void +ComputeServiceSession::Impl::ActivateActionInQueue(const Ref<QueueEntry>& Queue, int Lsn) +{ + Queue->m_Lock.WithExclusiveLock([&] { Queue->ActiveLsns.insert(Lsn); }); + Queue->ActiveCount.fetch_add(1, std::memory_order_relaxed); + Queue->IdleSince.store(0, std::memory_order_relaxed); +} + +ComputeServiceSession::EnqueueResult +ComputeServiceSession::Impl::EnqueueActionToQueue(int QueueId, CbObject ActionObject, int Priority) +{ + Ref<QueueEntry> Queue; + if (EnqueueResult Error = ValidateQueueForEnqueue(QueueId, Queue); Error.ResponseMessage) + { + return Error; } EnqueueResult Result = EnqueueAction(QueueId, ActionObject, Priority); if (Result.Lsn != 0) { - Queue->m_Lock.WithExclusiveLock([&] { Queue->ActiveLsns.insert(Result.Lsn); }); - Queue->ActiveCount.fetch_add(1, std::memory_order_relaxed); - Queue->IdleSince.store(0, std::memory_order_relaxed); + ActivateActionInQueue(Queue, Result.Lsn); } return Result; @@ -1302,40 +1419,17 @@ ComputeServiceSession::Impl::EnqueueActionToQueue(int QueueId, CbObject ActionOb ComputeServiceSession::EnqueueResult ComputeServiceSession::Impl::EnqueueResolvedActionToQueue(int QueueId, WorkerDesc Worker, CbObject ActionObj, int Priority) { - Ref<QueueEntry> Queue = FindQueue(QueueId); - - if (!Queue) - { - CbObjectWriter Writer; - Writer << "error"sv - << "queue not found"sv; - return {0, Writer.Save()}; - } - - QueueState QState = Queue->State.load(); - if (QState == QueueState::Cancelled) + Ref<QueueEntry> Queue; + if (EnqueueResult Error = ValidateQueueForEnqueue(QueueId, Queue); Error.ResponseMessage) { - CbObjectWriter Writer; - Writer << "error"sv - << "queue is 
cancelled"sv; - return {0, Writer.Save()}; - } - - if (QState == QueueState::Draining) - { - CbObjectWriter Writer; - Writer << "error"sv - << "queue is draining"sv; - return {0, Writer.Save()}; + return Error; } EnqueueResult Result = EnqueueResolvedAction(QueueId, Worker, ActionObj, Priority); if (Result.Lsn != 0) { - Queue->m_Lock.WithExclusiveLock([&] { Queue->ActiveLsns.insert(Result.Lsn); }); - Queue->ActiveCount.fetch_add(1, std::memory_order_relaxed); - Queue->IdleSince.store(0, std::memory_order_relaxed); + ActivateActionInQueue(Queue, Result.Lsn); } return Result; @@ -1770,6 +1864,68 @@ ComputeServiceSession::Impl::RescheduleAction(int ActionLsn) return {.Success = true, .RetryCount = NewRetryCount}; } +ComputeServiceSession::RescheduleResult +ComputeServiceSession::Impl::RetractAction(int ActionLsn) +{ + Ref<RunnerAction> Action; + bool WasRunning = false; + + // Look for the action in pending or running maps + m_RunningLock.WithSharedLock([&] { + if (auto It = m_RunningMap.find(ActionLsn); It != m_RunningMap.end()) + { + Action = It->second; + WasRunning = true; + } + }); + + if (!Action) + { + m_PendingLock.WithSharedLock([&] { + if (auto It = m_PendingActions.find(ActionLsn); It != m_PendingActions.end()) + { + Action = It->second; + } + }); + } + + if (!Action) + { + return {.Success = false, .Error = "Action not found in pending or running maps"}; + } + + if (!Action->RetractAction()) + { + return {.Success = false, .Error = "Action cannot be retracted from its current state"}; + } + + // If the action was running, send a cancellation signal to the runner + if (WasRunning) + { + m_LocalRunnerGroup.CancelAction(ActionLsn); + } + + ZEN_INFO("action {} ({}) retract requested", Action->ActionId, ActionLsn); + return {.Success = true, .RetryCount = Action->RetryCount.load(std::memory_order_relaxed)}; +} + +void +ComputeServiceSession::Impl::RemoveActionFromActiveMaps(int ActionLsn) +{ + m_RunningLock.WithExclusiveLock([&] { + 
m_PendingLock.WithExclusiveLock([&] { + if (auto FindIt = m_RunningMap.find(ActionLsn); FindIt == m_RunningMap.end()) + { + m_PendingActions.erase(ActionLsn); + } + else + { + m_RunningMap.erase(FindIt); + } + }); + }); +} + void ComputeServiceSession::Impl::HandleActionUpdates() { @@ -1781,6 +1937,10 @@ ComputeServiceSession::Impl::HandleActionUpdates() std::unordered_set<int> SeenLsn; + // Collect terminal action notifications for the completion observer. + // Inline capacity of 64 avoids heap allocation in the common case. + eastl::fixed_vector<IComputeCompletionObserver::CompletedActionNotification, 64> TerminalBatch; + // Process each action's latest state, deduplicating by LSN. // // This is safe because state transitions are monotonically increasing by enum @@ -1798,7 +1958,23 @@ ComputeServiceSession::Impl::HandleActionUpdates() { // Newly enqueued — add to pending map for scheduling case RunnerAction::State::Pending: - m_PendingLock.WithExclusiveLock([&] { m_PendingActions.insert({ActionLsn, Action}); }); + // Guard against a race where the session is abandoned between + // EnqueueAction (which calls PostUpdate) and this scheduler + // tick. AbandonAllActions() only scans m_PendingActions, so it + // misses actions still in m_UpdatedActions at the time the + // session transitions. Detect that here and immediately abandon + // rather than inserting into the pending map, where they would + // otherwise be stuck indefinitely. + if (m_SessionState.load(std::memory_order_relaxed) >= SessionState::Abandoned) + { + Action->SetActionState(RunnerAction::State::Abandoned); + // SetActionState calls PostUpdate; the Abandoned action + // will be processed as a terminal on the next scheduler pass. 
+ } + else + { + m_PendingLock.WithExclusiveLock([&] { m_PendingActions.insert({ActionLsn, Action}); }); + } break; // Async submission in progress — remains in pending map @@ -1816,6 +1992,15 @@ ComputeServiceSession::Impl::HandleActionUpdates() ZEN_DEBUG("action {} ({}) RUNNING", Action->ActionId, ActionLsn); break; + // Retracted — pull back for rescheduling without counting against retry limit + case RunnerAction::State::Retracted: + { + RemoveActionFromActiveMaps(ActionLsn); + Action->ResetActionStateToPending(); + ZEN_INFO("action {} ({}) retracted for rescheduling", Action->ActionId, ActionLsn); + break; + } + // Terminal states — move to results, record history, notify queue case RunnerAction::State::Completed: case RunnerAction::State::Failed: @@ -1834,19 +2019,7 @@ ComputeServiceSession::Impl::HandleActionUpdates() if (Action->RetryCount.load(std::memory_order_relaxed) < MaxRetries) { - // Remove from whichever active map the action is in before resetting - m_RunningLock.WithExclusiveLock([&] { - m_PendingLock.WithExclusiveLock([&] { - if (auto FindIt = m_RunningMap.find(ActionLsn); FindIt == m_RunningMap.end()) - { - m_PendingActions.erase(ActionLsn); - } - else - { - m_RunningMap.erase(FindIt); - } - }); - }); + RemoveActionFromActiveMaps(ActionLsn); // Reset triggers PostUpdate() which re-enters the action as Pending Action->ResetActionStateToPending(); @@ -1861,19 +2034,14 @@ ComputeServiceSession::Impl::HandleActionUpdates() } } - // Remove from whichever active map the action is in - m_RunningLock.WithExclusiveLock([&] { - m_PendingLock.WithExclusiveLock([&] { - if (auto FindIt = m_RunningMap.find(ActionLsn); FindIt == m_RunningMap.end()) - { - m_PendingActions.erase(ActionLsn); - } - else - { - m_RunningMap.erase(FindIt); - } - }); - }); + RemoveActionFromActiveMaps(ActionLsn); + + // Update queue counters BEFORE publishing the result into + // m_ResultsMap. 
GetActionResult erases from m_ResultsMap + // under m_ResultsLock, so if we updated counters after + // releasing that lock, a caller could observe ActiveCount + // still at 1 immediately after GetActionResult returned OK. + NotifyQueueActionComplete(Action->QueueId, ActionLsn, TerminalState); m_ResultsLock.WithExclusiveLock([&] { m_ResultsMap[ActionLsn] = Action; @@ -1902,16 +2070,46 @@ ComputeServiceSession::Impl::HandleActionUpdates() }); m_RetiredCount.fetch_add(1); m_ResultRate.Mark(1); + { + using ObserverState = IComputeCompletionObserver::ActionState; + ObserverState NotifyState{}; + switch (TerminalState) + { + case RunnerAction::State::Completed: + NotifyState = ObserverState::Completed; + break; + case RunnerAction::State::Failed: + NotifyState = ObserverState::Failed; + break; + case RunnerAction::State::Abandoned: + NotifyState = ObserverState::Abandoned; + break; + case RunnerAction::State::Cancelled: + NotifyState = ObserverState::Cancelled; + break; + default: + break; + } + TerminalBatch.push_back({ActionLsn, NotifyState}); + } ZEN_DEBUG("action {} ({}) RUNNING -> COMPLETED with {}", Action->ActionId, ActionLsn, TerminalState == RunnerAction::State::Completed ? "SUCCESS" : "FAILURE"); - NotifyQueueActionComplete(Action->QueueId, ActionLsn, TerminalState); break; } } } } + + // Notify the completion observer, if any, about all terminal actions in this batch. 
+ if (!TerminalBatch.empty()) + { + if (IComputeCompletionObserver* Observer = m_CompletionObserver.load(std::memory_order_acquire)) + { + Observer->OnActionsCompleted({TerminalBatch.data(), TerminalBatch.size()}); + } + } } size_t @@ -2014,6 +2212,12 @@ ComputeServiceSession::SetOrchestratorBasePath(std::filesystem::path BasePath) } void +ComputeServiceSession::NotifyOrchestratorChanged() +{ + m_Impl->NotifyOrchestratorChanged(); +} + +void ComputeServiceSession::StartRecording(ChunkResolver& InResolver, const std::filesystem::path& RecordingPath) { m_Impl->StartRecording(InResolver, RecordingPath); @@ -2194,6 +2398,12 @@ ComputeServiceSession::RescheduleAction(int ActionLsn) return m_Impl->RescheduleAction(ActionLsn); } +ComputeServiceSession::RescheduleResult +ComputeServiceSession::RetractAction(int ActionLsn) +{ + return m_Impl->RetractAction(ActionLsn); +} + std::vector<ComputeServiceSession::RunningActionInfo> ComputeServiceSession::GetRunningActions() { @@ -2219,6 +2429,12 @@ ComputeServiceSession::GetCompleted(CbWriter& Cbo) } void +ComputeServiceSession::SetCompletionObserver(IComputeCompletionObserver* Observer) +{ + m_Impl->m_CompletionObserver.store(Observer, std::memory_order_release); +} + +void ComputeServiceSession::PostUpdate(RunnerAction* Action) { m_Impl->PostUpdate(Action); diff --git a/src/zencompute/httpcomputeservice.cpp b/src/zencompute/httpcomputeservice.cpp index e82a40781..bdfd9d197 100644 --- a/src/zencompute/httpcomputeservice.cpp +++ b/src/zencompute/httpcomputeservice.cpp @@ -16,6 +16,7 @@ # include <zencore/iobuffer.h> # include <zencore/iohash.h> # include <zencore/logging.h> +# include <zencore/string.h> # include <zencore/system.h> # include <zencore/thread.h> # include <zencore/trace.h> @@ -23,8 +24,10 @@ # include <zenstore/cidstore.h> # include <zentelemetry/stats.h> +# include <algorithm> # include <span> # include <unordered_map> +# include <vector> using namespace std::literals; @@ -50,6 +53,11 @@ struct 
HttpComputeService::Impl ComputeServiceSession m_ComputeService; SystemMetricsTracker m_MetricsTracker; + // WebSocket connections (completion push) + + RwLock m_WsConnectionsLock; + std::vector<Ref<WebSocketConnection>> m_WsConnections; + // Metrics metrics::OperationTiming m_HttpRequests; @@ -91,6 +99,12 @@ struct HttpComputeService::Impl void HandleWorkersAllGet(HttpServerRequest& HttpReq); void WriteQueueDescription(CbWriter& Cbo, int QueueId, const ComputeServiceSession::QueueStatus& Status); void HandleWorkerRequest(HttpServerRequest& HttpReq, const IoHash& WorkerId); + void HandleSubmitAction(HttpServerRequest& HttpReq, int QueueId, int Priority, const WorkerDesc* Worker); + + // WebSocket / observer + void OnWebSocketOpen(Ref<WebSocketConnection> Connection); + void OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code); + void OnActionsCompleted(std::span<const IComputeCompletionObserver::CompletedActionNotification> Actions); void RegisterRoutes(); @@ -110,6 +124,7 @@ struct HttpComputeService::Impl m_ComputeService.WaitUntilReady(); m_StatsService.RegisterHandler("compute", *m_Self); RegisterRoutes(); + m_ComputeService.SetCompletionObserver(m_Self); } }; @@ -149,7 +164,7 @@ HttpComputeService::Impl::RegisterRoutes() return HttpReq.WriteResponse(HttpResponseCode::Forbidden); } - bool Success = m_ComputeService.RequestStateTransition(ComputeServiceSession::SessionState::Abandoned); + bool Success = m_ComputeService.Abandon(); if (Success) { @@ -325,6 +340,29 @@ HttpComputeService::Impl::RegisterRoutes() HttpVerb::kGet | HttpVerb::kPost); m_Router.RegisterRoute( + "jobs/{lsn}/retract", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + const int ActionLsn = ParseInt<int>(Req.GetCapture(1)).value_or(0); + + auto Result = m_ComputeService.RetractAction(ActionLsn); + + CbObjectWriter Cbo; + if (Result.Success) + { + Cbo << "success"sv << true; + Cbo << "lsn"sv << ActionLsn; + 
HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + } + else + { + Cbo << "error"sv << Result.Error; + HttpReq.WriteResponse(HttpResponseCode::Conflict, Cbo.Save()); + } + }, + HttpVerb::kPost); + + m_Router.RegisterRoute( "jobs/{worker}/{action}", // This route is inefficient, and is only here for backwards compatibility. The preferred path is the // one which uses the scheduled action lsn for lookups [this](HttpRouterRequest& Req) { @@ -373,127 +411,7 @@ HttpComputeService::Impl::RegisterRoutes() RequestPriority = ParseInt<int>(PriorityParam).value_or(-1); } - switch (HttpReq.RequestVerb()) - { - case HttpVerb::kGet: - // TODO: return status of all pending or executing jobs - break; - - case HttpVerb::kPost: - switch (HttpReq.RequestContentType()) - { - case HttpContentType::kCbObject: - { - // This operation takes the proposed job spec and identifies which - // chunks are not present on this server. This list is then returned in - // the "need" list in the response - - IoBuffer Payload = HttpReq.ReadPayload(); - CbObject ActionObj = LoadCompactBinaryObject(Payload); - - std::vector<IoHash> NeedList; - - ActionObj.IterateAttachments([&](CbFieldView Field) { - const IoHash FileHash = Field.AsHash(); - - if (!m_CidStore.ContainsChunk(FileHash)) - { - NeedList.push_back(FileHash); - } - }); - - if (NeedList.empty()) - { - // We already have everything, enqueue the action for execution - - if (ComputeServiceSession::EnqueueResult Result = - m_ComputeService.EnqueueResolvedAction(Worker, ActionObj, RequestPriority)) - { - ZEN_DEBUG("action {} accepted (lsn {})", ActionObj.GetHash(), Result.Lsn); - - HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage); - } - - return; - } - - CbObjectWriter Cbo; - Cbo.BeginArray("need"); - - for (const IoHash& Hash : NeedList) - { - Cbo << Hash; - } - - Cbo.EndArray(); - CbObject Response = Cbo.Save(); - - return HttpReq.WriteResponse(HttpResponseCode::NotFound, Response); - } - break; - - case 
HttpContentType::kCbPackage: - { - CbPackage Action = HttpReq.ReadPayloadPackage(); - CbObject ActionObj = Action.GetObject(); - - std::span<const CbAttachment> Attachments = Action.GetAttachments(); - - int AttachmentCount = 0; - int NewAttachmentCount = 0; - uint64_t TotalAttachmentBytes = 0; - uint64_t TotalNewBytes = 0; - - for (const CbAttachment& Attachment : Attachments) - { - ZEN_ASSERT(Attachment.IsCompressedBinary()); - - const IoHash DataHash = Attachment.GetHash(); - CompressedBuffer DataView = Attachment.AsCompressedBinary(); - - ZEN_UNUSED(DataHash); - - const uint64_t CompressedSize = DataView.GetCompressedSize(); - - TotalAttachmentBytes += CompressedSize; - ++AttachmentCount; - - const CidStore::InsertResult InsertResult = - m_CidStore.AddChunk(DataView.GetCompressed().Flatten().AsIoBuffer(), DataHash); - - if (InsertResult.New) - { - TotalNewBytes += CompressedSize; - ++NewAttachmentCount; - } - } - - if (ComputeServiceSession::EnqueueResult Result = - m_ComputeService.EnqueueResolvedAction(Worker, ActionObj, RequestPriority)) - { - ZEN_DEBUG("accepted action {} (lsn {}): {} in {} attachments. {} new ({} attachments)", - ActionObj.GetHash(), - Result.Lsn, - zen::NiceBytes(TotalAttachmentBytes), - AttachmentCount, - zen::NiceBytes(TotalNewBytes), - NewAttachmentCount); - - HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage); - } - - return; - } - break; - - default: - break; - } - break; - - default: - break; - } + HandleSubmitAction(HttpReq, 0, RequestPriority, &Worker); }, HttpVerb::kPost); @@ -511,118 +429,7 @@ HttpComputeService::Impl::RegisterRoutes() RequestPriority = ParseInt<int>(PriorityParam).value_or(-1); } - // Resolve worker - - // - - switch (HttpReq.RequestContentType()) - { - case HttpContentType::kCbObject: - { - // This operation takes the proposed job spec and identifies which - // chunks are not present on this server. 
This list is then returned in - // the "need" list in the response - - IoBuffer Payload = HttpReq.ReadPayload(); - CbObject ActionObj = LoadCompactBinaryObject(Payload); - - std::vector<IoHash> NeedList; - - ActionObj.IterateAttachments([&](CbFieldView Field) { - const IoHash FileHash = Field.AsHash(); - - if (!m_CidStore.ContainsChunk(FileHash)) - { - NeedList.push_back(FileHash); - } - }); - - if (NeedList.empty()) - { - // We already have everything, enqueue the action for execution - - if (ComputeServiceSession::EnqueueResult Result = m_ComputeService.EnqueueAction(ActionObj, RequestPriority)) - { - ZEN_DEBUG("action accepted (lsn {})", Result.Lsn); - - return HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage); - } - else - { - // Could not resolve? - return HttpReq.WriteResponse(HttpResponseCode::FailedDependency, Result.ResponseMessage); - } - } - - CbObjectWriter Cbo; - Cbo.BeginArray("need"); - - for (const IoHash& Hash : NeedList) - { - Cbo << Hash; - } - - Cbo.EndArray(); - CbObject Response = Cbo.Save(); - - return HttpReq.WriteResponse(HttpResponseCode::NotFound, Response); - } - - case HttpContentType::kCbPackage: - { - CbPackage Action = HttpReq.ReadPayloadPackage(); - CbObject ActionObj = Action.GetObject(); - - std::span<const CbAttachment> Attachments = Action.GetAttachments(); - - int AttachmentCount = 0; - int NewAttachmentCount = 0; - uint64_t TotalAttachmentBytes = 0; - uint64_t TotalNewBytes = 0; - - for (const CbAttachment& Attachment : Attachments) - { - ZEN_ASSERT(Attachment.IsCompressedBinary()); - - const IoHash DataHash = Attachment.GetHash(); - CompressedBuffer DataView = Attachment.AsCompressedBinary(); - - ZEN_UNUSED(DataHash); - - const uint64_t CompressedSize = DataView.GetCompressedSize(); - - TotalAttachmentBytes += CompressedSize; - ++AttachmentCount; - - const CidStore::InsertResult InsertResult = - m_CidStore.AddChunk(DataView.GetCompressed().Flatten().AsIoBuffer(), DataHash); - - if (InsertResult.New) - { - 
TotalNewBytes += CompressedSize; - ++NewAttachmentCount; - } - } - - if (ComputeServiceSession::EnqueueResult Result = m_ComputeService.EnqueueAction(ActionObj, RequestPriority)) - { - ZEN_DEBUG("accepted action (lsn {}): {} in {} attachments. {} new ({} attachments)", - Result.Lsn, - zen::NiceBytes(TotalAttachmentBytes), - AttachmentCount, - zen::NiceBytes(TotalNewBytes), - NewAttachmentCount); - - HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage); - } - else - { - // Could not resolve? - return HttpReq.WriteResponse(HttpResponseCode::FailedDependency, Result.ResponseMessage); - } - } - return; - } + HandleSubmitAction(HttpReq, 0, RequestPriority, nullptr); }, HttpVerb::kPost); @@ -1090,72 +897,7 @@ HttpComputeService::Impl::RegisterRoutes() RequestPriority = ParseInt<int>(PriorityParam).value_or(-1); } - switch (HttpReq.RequestContentType()) - { - case HttpContentType::kCbObject: - { - IoBuffer Payload = HttpReq.ReadPayload(); - CbObject ActionObj = LoadCompactBinaryObject(Payload); - - std::vector<IoHash> NeedList; - - if (!CheckAttachments(ActionObj, NeedList)) - { - CbObjectWriter Cbo; - Cbo.BeginArray("need"); - - for (const IoHash& Hash : NeedList) - { - Cbo << Hash; - } - - Cbo.EndArray(); - - return HttpReq.WriteResponse(HttpResponseCode::NotFound, Cbo.Save()); - } - - if (ComputeServiceSession::EnqueueResult Result = - m_ComputeService.EnqueueResolvedActionToQueue(QueueId, Worker, ActionObj, RequestPriority)) - { - ZEN_DEBUG("queue {}: action {} accepted (lsn {})", QueueId, ActionObj.GetHash(), Result.Lsn); - return HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage); - } - else - { - return HttpReq.WriteResponse(HttpResponseCode::FailedDependency, Result.ResponseMessage); - } - } - - case HttpContentType::kCbPackage: - { - CbPackage Action = HttpReq.ReadPayloadPackage(); - CbObject ActionObj = Action.GetObject(); - - IngestStats Stats = IngestPackageAttachments(Action); - - if (ComputeServiceSession::EnqueueResult Result 
= - m_ComputeService.EnqueueResolvedActionToQueue(QueueId, Worker, ActionObj, RequestPriority)) - { - ZEN_DEBUG("queue {}: accepted action {} (lsn {}): {} in {} attachments. {} new ({} attachments)", - QueueId, - ActionObj.GetHash(), - Result.Lsn, - zen::NiceBytes(Stats.Bytes), - Stats.Count, - zen::NiceBytes(Stats.NewBytes), - Stats.NewCount); - - return HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage); - } - else - { - return HttpReq.WriteResponse(HttpResponseCode::FailedDependency, Result.ResponseMessage); - } - } - - default: - break; - } + HandleSubmitAction(HttpReq, QueueId, RequestPriority, &Worker); }, HttpVerb::kPost); @@ -1178,71 +920,7 @@ HttpComputeService::Impl::RegisterRoutes() RequestPriority = ParseInt<int>(PriorityParam).value_or(-1); } - switch (HttpReq.RequestContentType()) - { - case HttpContentType::kCbObject: - { - IoBuffer Payload = HttpReq.ReadPayload(); - CbObject ActionObj = LoadCompactBinaryObject(Payload); - - std::vector<IoHash> NeedList; - - if (!CheckAttachments(ActionObj, NeedList)) - { - CbObjectWriter Cbo; - Cbo.BeginArray("need"); - - for (const IoHash& Hash : NeedList) - { - Cbo << Hash; - } - - Cbo.EndArray(); - - return HttpReq.WriteResponse(HttpResponseCode::NotFound, Cbo.Save()); - } - - if (ComputeServiceSession::EnqueueResult Result = - m_ComputeService.EnqueueActionToQueue(QueueId, ActionObj, RequestPriority)) - { - ZEN_DEBUG("queue {}: action accepted (lsn {})", QueueId, Result.Lsn); - return HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage); - } - else - { - return HttpReq.WriteResponse(HttpResponseCode::FailedDependency, Result.ResponseMessage); - } - } - - case HttpContentType::kCbPackage: - { - CbPackage Action = HttpReq.ReadPayloadPackage(); - CbObject ActionObj = Action.GetObject(); - - IngestStats Stats = IngestPackageAttachments(Action); - - if (ComputeServiceSession::EnqueueResult Result = - m_ComputeService.EnqueueActionToQueue(QueueId, ActionObj, RequestPriority)) - { - 
ZEN_DEBUG("queue {}: accepted action (lsn {}): {} in {} attachments. {} new ({} attachments)", - QueueId, - Result.Lsn, - zen::NiceBytes(Stats.Bytes), - Stats.Count, - zen::NiceBytes(Stats.NewBytes), - Stats.NewCount); - - return HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage); - } - else - { - return HttpReq.WriteResponse(HttpResponseCode::FailedDependency, Result.ResponseMessage); - } - } - - default: - break; - } + HandleSubmitAction(HttpReq, QueueId, RequestPriority, nullptr); }, HttpVerb::kPost); @@ -1306,6 +984,45 @@ HttpComputeService::Impl::RegisterRoutes() } }, HttpVerb::kGet | HttpVerb::kPost); + + m_Router.RegisterRoute( + "queues/{queueref}/jobs/{lsn}/retract", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + const int QueueId = ResolveQueueRef(HttpReq, Req.GetCapture(1)); + const int ActionLsn = ParseInt<int>(Req.GetCapture(2)).value_or(0); + + if (QueueId == 0) + { + return; + } + + ZEN_UNUSED(QueueId); + + auto Result = m_ComputeService.RetractAction(ActionLsn); + + CbObjectWriter Cbo; + if (Result.Success) + { + Cbo << "success"sv << true; + Cbo << "lsn"sv << ActionLsn; + HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + } + else + { + Cbo << "error"sv << Result.Error; + HttpReq.WriteResponse(HttpResponseCode::Conflict, Cbo.Save()); + } + }, + HttpVerb::kPost); + + // WebSocket upgrade endpoint — the handler logic lives in + // HttpComputeService::OnWebSocket* methods; this route merely + // satisfies the router so the upgrade request isn't rejected. 
+ m_Router.RegisterRoute( + "ws", + [this](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::OK); }, + HttpVerb::kGet); } ////////////////////////////////////////////////////////////////////////// @@ -1320,12 +1037,17 @@ HttpComputeService::HttpComputeService(CidStore& InCidStore, HttpComputeService::~HttpComputeService() { + m_Impl->m_ComputeService.SetCompletionObserver(nullptr); m_Impl->m_StatsService.UnregisterHandler("compute", *this); } void HttpComputeService::Shutdown() { + // Null out observer before shutting down the compute session to prevent + // callbacks into a partially-torn-down service. + m_Impl->m_ComputeService.SetCompletionObserver(nullptr); + m_Impl->m_WsConnectionsLock.WithExclusiveLock([&] { m_Impl->m_WsConnections.clear(); }); m_Impl->m_ComputeService.Shutdown(); } @@ -1492,6 +1214,184 @@ HttpComputeService::Impl::CheckAttachments(const CbObject& ActionObj, std::vecto } void +HttpComputeService::Impl::HandleSubmitAction(HttpServerRequest& HttpReq, int QueueId, int Priority, const WorkerDesc* Worker) +{ + // QueueId > 0: queue-scoped enqueue; QueueId == 0: implicit queue (global routes) + auto Enqueue = [&](CbObject ActionObj) -> ComputeServiceSession::EnqueueResult { + if (QueueId > 0) + { + if (Worker) + { + return m_ComputeService.EnqueueResolvedActionToQueue(QueueId, *Worker, ActionObj, Priority); + } + return m_ComputeService.EnqueueActionToQueue(QueueId, ActionObj, Priority); + } + else + { + if (Worker) + { + return m_ComputeService.EnqueueResolvedAction(*Worker, ActionObj, Priority); + } + return m_ComputeService.EnqueueAction(ActionObj, Priority); + } + }; + + // Read payload upfront and handle attachments based on content type + CbObject Body; + IngestStats Stats = {}; + + switch (HttpReq.RequestContentType()) + { + case HttpContentType::kCbObject: + { + IoBuffer Payload = HttpReq.ReadPayload(); + Body = LoadCompactBinaryObject(Payload); + break; + } + + case HttpContentType::kCbPackage: + { + CbPackage 
Package = HttpReq.ReadPayloadPackage(); + Body = Package.GetObject(); + Stats = IngestPackageAttachments(Package); + break; + } + + default: + return HttpReq.WriteResponse(HttpResponseCode::BadRequest); + } + + // Check for "actions" array to determine batch vs single-action path + CbArray Actions = Body.Find("actions"sv).AsArray(); + + if (Actions.Num() > 0) + { + // --- Batch path --- + + // For CbObject payloads, check all attachments upfront before enqueuing anything + if (HttpReq.RequestContentType() == HttpContentType::kCbObject) + { + std::vector<IoHash> NeedList; + + for (CbField ActionField : Actions) + { + CbObject ActionObj = ActionField.AsObject(); + CheckAttachments(ActionObj, NeedList); + } + + if (!NeedList.empty()) + { + CbObjectWriter Cbo; + Cbo.BeginArray("need"); + + for (const IoHash& Hash : NeedList) + { + Cbo << Hash; + } + + Cbo.EndArray(); + + return HttpReq.WriteResponse(HttpResponseCode::NotFound, Cbo.Save()); + } + } + + // Enqueue all actions and collect results + CbObjectWriter Cbo; + int Accepted = 0; + + Cbo.BeginArray("results"); + + for (CbField ActionField : Actions) + { + CbObject ActionObj = ActionField.AsObject(); + + ComputeServiceSession::EnqueueResult Result = Enqueue(ActionObj); + + Cbo.BeginObject(); + + if (Result) + { + Cbo << "lsn"sv << Result.Lsn; + ++Accepted; + } + else + { + Cbo << "error"sv << Result.ResponseMessage; + } + + Cbo.EndObject(); + } + + Cbo.EndArray(); + + if (Stats.Count > 0) + { + ZEN_DEBUG("queue {}: batch accepted {}/{} actions: {} in {} attachments. 
{} new ({} attachments)", + QueueId, + Accepted, + Actions.Num(), + zen::NiceBytes(Stats.Bytes), + Stats.Count, + zen::NiceBytes(Stats.NewBytes), + Stats.NewCount); + } + else + { + ZEN_DEBUG("queue {}: batch accepted {}/{} actions", QueueId, Accepted, Actions.Num()); + } + + return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + } + + // --- Single-action path: Body is the action itself --- + + if (HttpReq.RequestContentType() == HttpContentType::kCbObject) + { + std::vector<IoHash> NeedList; + + if (!CheckAttachments(Body, NeedList)) + { + CbObjectWriter Cbo; + Cbo.BeginArray("need"); + + for (const IoHash& Hash : NeedList) + { + Cbo << Hash; + } + + Cbo.EndArray(); + + return HttpReq.WriteResponse(HttpResponseCode::NotFound, Cbo.Save()); + } + } + + if (ComputeServiceSession::EnqueueResult Result = Enqueue(Body)) + { + if (Stats.Count > 0) + { + ZEN_DEBUG("queue {}: accepted action {} (lsn {}): {} in {} attachments. {} new ({} attachments)", + QueueId, + Body.GetHash(), + Result.Lsn, + zen::NiceBytes(Stats.Bytes), + Stats.Count, + zen::NiceBytes(Stats.NewBytes), + Stats.NewCount); + } + else + { + ZEN_DEBUG("queue {}: action {} accepted (lsn {})", QueueId, Body.GetHash(), Result.Lsn); + } + + return HttpReq.WriteResponse(HttpResponseCode::OK, Result.ResponseMessage); + } + else + { + return HttpReq.WriteResponse(HttpResponseCode::FailedDependency, Result.ResponseMessage); + } +} + +void HttpComputeService::Impl::HandleWorkersGet(HttpServerRequest& HttpReq) { CbObjectWriter Cbo; @@ -1632,6 +1532,136 @@ HttpComputeService::Impl::HandleWorkerRequest(HttpServerRequest& HttpReq, const } ////////////////////////////////////////////////////////////////////////// +// +// IWebSocketHandler +// + +void +HttpComputeService::OnWebSocketOpen(Ref<WebSocketConnection> Connection) +{ + m_Impl->OnWebSocketOpen(std::move(Connection)); +} + +void +HttpComputeService::OnWebSocketMessage([[maybe_unused]] WebSocketConnection& Conn, [[maybe_unused]] const WebSocketMessage& 
Msg) +{ + // Clients are receive-only; ignore any inbound messages. +} + +void +HttpComputeService::OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code, [[maybe_unused]] std::string_view Reason) +{ + m_Impl->OnWebSocketClose(Conn, Code); +} + +void +HttpComputeService::OnActionsCompleted(std::span<const CompletedActionNotification> Actions) +{ + m_Impl->OnActionsCompleted(Actions); +} + +////////////////////////////////////////////////////////////////////////// +// +// Impl — WebSocket / observer +// + +void +HttpComputeService::Impl::OnWebSocketOpen(Ref<WebSocketConnection> Connection) +{ + ZEN_INFO("compute WebSocket client connected"); + m_WsConnectionsLock.WithExclusiveLock([&] { m_WsConnections.push_back(std::move(Connection)); }); +} + +void +HttpComputeService::Impl::OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code) +{ + ZEN_INFO("compute WebSocket client disconnected (code {})", Code); + + m_WsConnectionsLock.WithExclusiveLock([&] { + auto It = std::remove_if(m_WsConnections.begin(), m_WsConnections.end(), [&Conn](const Ref<WebSocketConnection>& C) { + return C.Get() == &Conn; + }); + m_WsConnections.erase(It, m_WsConnections.end()); + }); +} + +void +HttpComputeService::Impl::OnActionsCompleted(std::span<const IComputeCompletionObserver::CompletedActionNotification> Actions) +{ + using ActionState = IComputeCompletionObserver::ActionState; + using CompletedActionNotification = IComputeCompletionObserver::CompletedActionNotification; + + // Snapshot connections under shared lock + eastl::fixed_vector<Ref<WebSocketConnection>, 16> Connections; + m_WsConnectionsLock.WithSharedLock([&] { Connections = {begin(m_WsConnections), end(m_WsConnections)}; }); + + if (Connections.empty()) + { + return; + } + + // Build CompactBinary notification grouped by state: + // {"Completed": [lsn, ...], "Failed": [lsn, ...], ...} + // Each state name becomes an array key containing the LSNs in that state. 
+ CbObjectWriter Cbo; + + // Sort by state so we can emit one array per state in a single pass. + // Copy into a local vector since the span is const. + eastl::fixed_vector<CompletedActionNotification, 16> Sorted(Actions.begin(), Actions.end()); + std::sort(Sorted.begin(), Sorted.end(), [](const auto& A, const auto& B) { return A.State < B.State; }); + + ActionState CurrentState{}; + bool ArrayOpen = false; + + for (const CompletedActionNotification& Action : Sorted) + { + if (!ArrayOpen || Action.State != CurrentState) + { + if (ArrayOpen) + { + Cbo.EndArray(); + } + CurrentState = Action.State; + Cbo.BeginArray(IComputeCompletionObserver::ActionStateToString(CurrentState)); + ArrayOpen = true; + } + Cbo.AddInteger(Action.Lsn); + } + + if (ArrayOpen) + { + Cbo.EndArray(); + } + + CbObject Msg = Cbo.Save(); + MemoryView MsgView = Msg.GetView(); + + // Broadcast to all connected clients, prune closed ones + bool HadClosedConnections = false; + for (auto& Conn : Connections) + { + if (Conn->IsOpen()) + { + Conn->SendBinary(MsgView); + } + else + { + HadClosedConnections = true; + } + } + + if (HadClosedConnections) + { + m_WsConnectionsLock.WithExclusiveLock([&] { + auto It = std::remove_if(m_WsConnections.begin(), m_WsConnections.end(), [](const Ref<WebSocketConnection>& C) { + return !C->IsOpen(); + }); + m_WsConnections.erase(It, m_WsConnections.end()); + }); + } +} + +////////////////////////////////////////////////////////////////////////// void httpcomputeservice_forcelink() diff --git a/src/zencompute/include/zencompute/cloudmetadata.h b/src/zencompute/include/zencompute/cloudmetadata.h index a5bc5a34d..3b9642ac3 100644 --- a/src/zencompute/include/zencompute/cloudmetadata.h +++ b/src/zencompute/include/zencompute/cloudmetadata.h @@ -2,6 +2,8 @@ #pragma once +#include <zenutil/cloud/cloudprovider.h> + #include <zencore/compactbinarybuilder.h> #include <zencore/logging.h> #include <zencore/thread.h> @@ -13,16 +15,6 @@ namespace zen::compute { -enum class 
CloudProvider -{ - None, - AWS, - Azure, - GCP -}; - -std::string_view ToString(CloudProvider Provider); - /** Snapshot of detected cloud instance properties. */ struct CloudInstanceInfo { diff --git a/src/zencompute/include/zencompute/computeservice.h b/src/zencompute/include/zencompute/computeservice.h index 65ec5f9ee..1ca78738a 100644 --- a/src/zencompute/include/zencompute/computeservice.h +++ b/src/zencompute/include/zencompute/computeservice.h @@ -13,6 +13,7 @@ # include <zenhttp/httpcommon.h> # include <filesystem> +# include <span> namespace zen { class ChunkResolver; @@ -29,6 +30,53 @@ class RemoteHttpRunner; struct RunnerAction; struct SubmitResult; +/** + * Observer interface for action completion notifications. + * + * Implementors receive a batch of notifications whenever actions reach a + * terminal state (Completed, Failed, Abandoned, Cancelled). The callback + * fires on the scheduler thread *after* the action result has been placed + * in m_ResultsMap, so GET /jobs/{lsn} will succeed by the time the client + * reacts to the notification. + */ +class IComputeCompletionObserver +{ +public: + virtual ~IComputeCompletionObserver() = default; + + enum class ActionState + { + Completed, + Failed, + Abandoned, + Cancelled, + }; + + struct CompletedActionNotification + { + int Lsn; + ActionState State; + }; + + static constexpr std::string_view ActionStateToString(ActionState State) + { + switch (State) + { + case ActionState::Completed: + return "Completed"; + case ActionState::Failed: + return "Failed"; + case ActionState::Abandoned: + return "Abandoned"; + case ActionState::Cancelled: + return "Cancelled"; + } + return "Unknown"; + } + + virtual void OnActionsCompleted(std::span<const CompletedActionNotification> Actions) = 0; +}; + struct WorkerDesc { CbPackage Descriptor; @@ -91,11 +139,25 @@ public: // Sunset can be reached from any non-Sunset state. 
bool RequestStateTransition(SessionState NewState); + // Convenience helpers for common state transitions. + bool Ready() { return RequestStateTransition(SessionState::Ready); } + bool Abandon() { return RequestStateTransition(SessionState::Abandoned); } + // Orchestration void SetOrchestratorEndpoint(std::string_view Endpoint); void SetOrchestratorBasePath(std::filesystem::path BasePath); + void SetOrchestrator(std::string_view Endpoint, std::filesystem::path BasePath) + { + SetOrchestratorEndpoint(Endpoint); + SetOrchestratorBasePath(std::move(BasePath)); + } + + /// Immediately wake the scheduler to re-poll the orchestrator for worker changes. + /// Resets the polling throttle so the next scheduler tick calls UpdateCoordinatorState(). + void NotifyOrchestratorChanged(); + // Worker registration and discovery void RegisterWorker(CbPackage Worker); @@ -182,6 +244,7 @@ public: }; [[nodiscard]] RescheduleResult RescheduleAction(int ActionLsn); + [[nodiscard]] RescheduleResult RetractAction(int ActionLsn); void GetCompleted(CbWriter&); @@ -215,7 +278,7 @@ public: // sized to match RunnerAction::State::_Count but we can't use the enum here // for dependency reasons, so just use a fixed size array and static assert in // the implementation file - uint64_t Timestamps[8] = {}; + uint64_t Timestamps[9] = {}; }; [[nodiscard]] std::vector<ActionHistoryEntry> GetActionHistory(int Limit = 100); @@ -235,6 +298,10 @@ public: void EmitStats(CbObjectWriter& Cbo); + // Completion observer (used by HttpComputeService for WebSocket push) + + void SetCompletionObserver(IComputeCompletionObserver* Observer); + // Recording void StartRecording(ChunkResolver& InResolver, const std::filesystem::path& RecordingPath); diff --git a/src/zencompute/include/zencompute/httpcomputeservice.h b/src/zencompute/include/zencompute/httpcomputeservice.h index ee1cd2614..b58e73a0d 100644 --- a/src/zencompute/include/zencompute/httpcomputeservice.h +++ 
b/src/zencompute/include/zencompute/httpcomputeservice.h @@ -9,6 +9,7 @@ # include "zencompute/computeservice.h" # include <zenhttp/httpserver.h> +# include <zenhttp/websocket.h> # include <filesystem> # include <memory> @@ -22,7 +23,7 @@ namespace zen::compute { /** * HTTP interface for compute service */ -class HttpComputeService : public HttpService, public IHttpStatsProvider +class HttpComputeService : public HttpService, public IHttpStatsProvider, public IWebSocketHandler, public IComputeCompletionObserver { public: HttpComputeService(CidStore& InCidStore, @@ -42,6 +43,16 @@ public: void HandleStatsRequest(HttpServerRequest& Request) override; + // IWebSocketHandler + + void OnWebSocketOpen(Ref<WebSocketConnection> Connection) override; + void OnWebSocketMessage(WebSocketConnection& Conn, const WebSocketMessage& Msg) override; + void OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code, std::string_view Reason) override; + + // IComputeCompletionObserver + + void OnActionsCompleted(std::span<const CompletedActionNotification> Actions) override; + private: struct Impl; std::unique_ptr<Impl> m_Impl; diff --git a/src/zencompute/include/zencompute/mockimds.h b/src/zencompute/include/zencompute/mockimds.h index 521722e63..704306913 100644 --- a/src/zencompute/include/zencompute/mockimds.h +++ b/src/zencompute/include/zencompute/mockimds.h @@ -1,102 +1,6 @@ // Copyright Epic Games, Inc. All Rights Reserved. +// Moved to zenutil — this header is kept for backward compatibility. #pragma once -#include <zencompute/cloudmetadata.h> -#include <zenhttp/httpserver.h> - -#include <string> - -#if ZEN_WITH_TESTS - -namespace zen::compute { - -/** - * Mock IMDS (Instance Metadata Service) for testing CloudMetadata. - * - * Implements an HttpService that responds to the same URL paths as the real - * cloud provider metadata endpoints (AWS IMDSv2, Azure IMDS, GCP metadata). 
- * Tests configure which provider is "active" and set the desired response - * values, then pass the mock server's address as the ImdsEndpoint to the - * CloudMetadata constructor. - * - * When a request arrives for a provider that is not the ActiveProvider, the - * mock returns 404, causing CloudMetadata to write a sentinel file and move - * on to the next provider — exactly like a failed probe on bare metal. - * - * All config fields are public and can be mutated between poll cycles to - * simulate state changes (e.g. a spot interruption appearing mid-run). - * - * Usage: - * MockImdsService Mock; - * Mock.ActiveProvider = CloudProvider::AWS; - * Mock.Aws.InstanceId = "i-test"; - * // ... stand up ASIO server, register Mock, create CloudMetadata with endpoint - */ -class MockImdsService : public HttpService -{ -public: - /** AWS IMDSv2 response configuration. */ - struct AwsConfig - { - std::string Token = "mock-aws-token-v2"; - std::string InstanceId = "i-0123456789abcdef0"; - std::string AvailabilityZone = "us-east-1a"; - std::string LifeCycle = "on-demand"; // "spot" or "on-demand" - - // Empty string → endpoint returns 404 (instance not in an ASG). - // Non-empty → returned as the response body. "InService" means healthy; - // anything else (e.g. "Terminated:Wait") triggers termination detection. - std::string AutoscalingState; - - // Empty string → endpoint returns 404 (no spot interruption). - // Non-empty → returned as the response body, signalling a spot reclaim. - std::string SpotAction; - }; - - /** Azure IMDS response configuration. */ - struct AzureConfig - { - std::string VmId = "vm-12345678-1234-1234-1234-123456789abc"; - std::string Location = "eastus"; - std::string Priority = "Regular"; // "Spot" or "Regular" - - // Empty → instance is not in a VM Scale Set (no autoscaling). - std::string VmScaleSetName; - - // Empty → no scheduled events. Set to "Preempt", "Terminate", or - // "Reboot" to simulate a termination-class event. 
- std::string ScheduledEventType; - std::string ScheduledEventStatus = "Scheduled"; - }; - - /** GCP metadata response configuration. */ - struct GcpConfig - { - std::string InstanceId = "1234567890123456789"; - std::string Zone = "projects/123456/zones/us-central1-a"; - std::string Preemptible = "FALSE"; // "TRUE" or "FALSE" - std::string MaintenanceEvent = "NONE"; // "NONE" or event description - }; - - /** Which provider's endpoints respond successfully. - * Requests targeting other providers receive 404. - */ - CloudProvider ActiveProvider = CloudProvider::None; - - AwsConfig Aws; - AzureConfig Azure; - GcpConfig Gcp; - - const char* BaseUri() const override; - void HandleRequest(HttpServerRequest& Request) override; - -private: - void HandleAwsRequest(HttpServerRequest& Request); - void HandleAzureRequest(HttpServerRequest& Request); - void HandleGcpRequest(HttpServerRequest& Request); -}; - -} // namespace zen::compute - -#endif // ZEN_WITH_TESTS +#include <zenutil/cloud/mockimds.h> diff --git a/src/zencompute/runners/functionrunner.cpp b/src/zencompute/runners/functionrunner.cpp index 768cdf1e1..4f116e7d8 100644 --- a/src/zencompute/runners/functionrunner.cpp +++ b/src/zencompute/runners/functionrunner.cpp @@ -215,15 +215,22 @@ BaseRunnerGroup::GetSubmittedActionCount() return TotalCount; } -void +bool BaseRunnerGroup::RegisterWorker(CbPackage Worker) { RwLock::SharedLockScope _(m_RunnersLock); + bool AllSucceeded = true; + for (auto& Runner : m_Runners) { - Runner->RegisterWorker(Worker); + if (!Runner->RegisterWorker(Worker)) + { + AllSucceeded = false; + } } + + return AllSucceeded; } void @@ -276,12 +283,34 @@ RunnerAction::~RunnerAction() } bool +RunnerAction::RetractAction() +{ + State CurrentState = m_ActionState.load(); + + do + { + // Only allow retraction from pre-terminal states (idempotent if already retracted) + if (CurrentState > State::Running) + { + return CurrentState == State::Retracted; + } + + if 
(m_ActionState.compare_exchange_strong(CurrentState, State::Retracted)) + { + this->Timestamps[static_cast<int>(State::Retracted)] = DateTime::Now().GetTicks(); + m_OwnerSession->PostUpdate(this); + return true; + } + } while (true); +} + +bool RunnerAction::ResetActionStateToPending() { - // Only allow reset from Failed or Abandoned states + // Only allow reset from Failed, Abandoned, or Retracted states State CurrentState = m_ActionState.load(); - if (CurrentState != State::Failed && CurrentState != State::Abandoned) + if (CurrentState != State::Failed && CurrentState != State::Abandoned && CurrentState != State::Retracted) { return false; } @@ -305,8 +334,11 @@ RunnerAction::ResetActionStateToPending() CpuUsagePercent.store(-1.0f, std::memory_order_relaxed); CpuSeconds.store(0.0f, std::memory_order_relaxed); - // Increment retry count - RetryCount.fetch_add(1, std::memory_order_relaxed); + // Increment retry count (skip for Retracted — nothing failed) + if (CurrentState != State::Retracted) + { + RetryCount.fetch_add(1, std::memory_order_relaxed); + } // Re-enter the scheduler pipeline m_OwnerSession->PostUpdate(this); diff --git a/src/zencompute/runners/functionrunner.h b/src/zencompute/runners/functionrunner.h index f67414dbb..56c3f3af0 100644 --- a/src/zencompute/runners/functionrunner.h +++ b/src/zencompute/runners/functionrunner.h @@ -29,8 +29,8 @@ public: FunctionRunner(std::filesystem::path BasePath); virtual ~FunctionRunner() = 0; - virtual void Shutdown() = 0; - virtual void RegisterWorker(const CbPackage& WorkerPackage) = 0; + virtual void Shutdown() = 0; + [[nodiscard]] virtual bool RegisterWorker(const CbPackage& WorkerPackage) = 0; [[nodiscard]] virtual SubmitResult SubmitAction(Ref<RunnerAction> Action) = 0; [[nodiscard]] virtual size_t GetSubmittedActionCount() = 0; @@ -63,7 +63,7 @@ public: SubmitResult SubmitAction(Ref<RunnerAction> Action); std::vector<SubmitResult> SubmitActions(const std::vector<Ref<RunnerAction>>& Actions); size_t 
GetSubmittedActionCount(); - void RegisterWorker(CbPackage Worker); + [[nodiscard]] bool RegisterWorker(CbPackage Worker); void Shutdown(); bool CancelAction(int ActionLsn); void CancelRemoteQueue(int QueueId); @@ -114,6 +114,30 @@ struct RunnerGroup : public BaseRunnerGroup /** * This represents an action going through different stages of scheduling and execution. + * + * State machine + * ============= + * + * Normal forward flow (enforced by SetActionState rejecting backward transitions): + * + * New -> Pending -> Submitting -> Running -> Completed + * -> Failed + * -> Abandoned + * -> Cancelled + * + * Rescheduling (via ResetActionStateToPending): + * + * Failed ---> Pending (increments RetryCount, subject to retry limit) + * Abandoned ---> Pending (increments RetryCount, subject to retry limit) + * Retracted ---> Pending (does NOT increment RetryCount) + * + * Retraction (via RetractAction, idempotent): + * + * Pending/Submitting/Running -> Retracted -> Pending (rescheduled) + * + * Retracted is placed after Cancelled in enum order so that once set, + * no runner-side transition (Completed/Failed) can override it via + * SetActionState's forward-only rule. */ struct RunnerAction : public RefCounted { @@ -137,16 +161,20 @@ struct RunnerAction : public RefCounted enum class State { - New, - Pending, - Submitting, - Running, - Completed, - Failed, - Abandoned, - Cancelled, + New, // Initial state at construction, before entering the scheduler + Pending, // Queued and waiting for a runner slot + Submitting, // Being handed off to a runner (async submission in progress) + Running, // Executing on a runner process + Completed, // Finished successfully with results available + Failed, // Execution failed (transient error, eligible for retry) + Abandoned, // Infrastructure termination (e.g. 
spot eviction, session abandon) + Cancelled, // Intentional user cancellation (never retried) + Retracted, // Pulled back for rescheduling on a different runner (no retry cost) _Count }; + static_assert(State::Retracted > State::Completed && State::Retracted > State::Failed && State::Retracted > State::Abandoned && + State::Retracted > State::Cancelled, + "Retracted must be the highest terminal ordinal so runner-side transitions cannot override it"); static const char* ToString(State _) { @@ -168,6 +196,8 @@ struct RunnerAction : public RefCounted return "Abandoned"; case State::Cancelled: return "Cancelled"; + case State::Retracted: + return "Retracted"; default: return "Unknown"; } @@ -191,6 +221,7 @@ struct RunnerAction : public RefCounted void SetActionState(State NewState); bool IsSuccess() const { return ActionState() == State::Completed; } + bool RetractAction(); bool ResetActionStateToPending(); bool IsCompleted() const { diff --git a/src/zencompute/runners/localrunner.cpp b/src/zencompute/runners/localrunner.cpp index 7aaefb06e..b61e0a46f 100644 --- a/src/zencompute/runners/localrunner.cpp +++ b/src/zencompute/runners/localrunner.cpp @@ -7,14 +7,16 @@ # include <zencore/compactbinary.h> # include <zencore/compactbinarybuilder.h> # include <zencore/compactbinarypackage.h> +# include <zencore/compactbinaryfile.h> # include <zencore/compress.h> # include <zencore/except_fmt.h> # include <zencore/filesystem.h> # include <zencore/fmtutils.h> # include <zencore/iobuffer.h> # include <zencore/iohash.h> -# include <zencore/system.h> # include <zencore/scopeguard.h> +# include <zencore/stream.h> +# include <zencore/system.h> # include <zencore/timer.h> # include <zencore/trace.h> # include <zenstore/cidstore.h> @@ -152,7 +154,7 @@ LocalProcessRunner::CreateNewSandbox() return Path; } -void +bool LocalProcessRunner::RegisterWorker(const CbPackage& WorkerPackage) { ZEN_TRACE_CPU("LocalProcessRunner::RegisterWorker"); @@ -173,6 +175,8 @@ 
LocalProcessRunner::RegisterWorker(const CbPackage& WorkerPackage) ZEN_INFO("dumped worker '{}' to 'file://{}'", WorkerId, Path); } + + return true; } size_t @@ -301,7 +305,7 @@ LocalProcessRunner::PrepareActionSubmission(Ref<RunnerAction> Action) // Write out action - zen::WriteFile(SandboxPath / "build.action", ActionObj.GetBuffer().AsIoBuffer()); + WriteCompactBinaryObject(SandboxPath / "build.action", ActionObj); // Manifest inputs in sandbox diff --git a/src/zencompute/runners/localrunner.h b/src/zencompute/runners/localrunner.h index 7493e980b..b8cff6826 100644 --- a/src/zencompute/runners/localrunner.h +++ b/src/zencompute/runners/localrunner.h @@ -51,7 +51,7 @@ public: ~LocalProcessRunner(); virtual void Shutdown() override; - virtual void RegisterWorker(const CbPackage& WorkerPackage) override; + [[nodiscard]] virtual bool RegisterWorker(const CbPackage& WorkerPackage) override; [[nodiscard]] virtual SubmitResult SubmitAction(Ref<RunnerAction> Action) override; [[nodiscard]] virtual bool IsHealthy() override { return true; } [[nodiscard]] virtual size_t GetSubmittedActionCount() override; diff --git a/src/zencompute/runners/remotehttprunner.cpp b/src/zencompute/runners/remotehttprunner.cpp index 672636d06..ce6a81173 100644 --- a/src/zencompute/runners/remotehttprunner.cpp +++ b/src/zencompute/runners/remotehttprunner.cpp @@ -42,6 +42,20 @@ RemoteHttpRunner::RemoteHttpRunner(ChunkResolver& InChunkResolver, , m_Http(m_BaseUrl) , m_InstanceId(Oid::NewOid()) { + // Attempt to connect a WebSocket for push-based completion notifications. + // If the remote doesn't support WS, OnWsClose fires and we fall back to polling. 
+ { + std::string WsUrl = HttpToWsUrl(HostName, "/compute/ws"); + + HttpWsClientSettings WsSettings; + WsSettings.LogCategory = "http_exec_ws"; + WsSettings.ConnectTimeout = std::chrono::milliseconds{3000}; + + IWsClientHandler& Handler = *this; + m_WsClient = std::make_unique<HttpWsClient>(WsUrl, Handler, WsSettings); + m_WsClient->Connect(); + } + m_MonitorThread = std::thread{&RemoteHttpRunner::MonitorThreadFunction, this}; } @@ -53,7 +67,29 @@ RemoteHttpRunner::~RemoteHttpRunner() void RemoteHttpRunner::Shutdown() { - // TODO: should cleanly drain/cancel pending work + m_AcceptNewActions = false; + + // Close the WebSocket client first, so no more wakeup signals arrive. + if (m_WsClient) + { + m_WsClient->Close(); + } + + // Cancel all known remote queues so the remote side stops scheduling new + // work and cancels in-flight actions belonging to those queues. + + { + std::vector<std::pair<int, Oid>> Queues; + + m_QueueTokenLock.WithSharedLock([&] { Queues.assign(m_RemoteQueueTokens.begin(), m_RemoteQueueTokens.end()); }); + + for (const auto& [QueueId, Token] : Queues) + { + CancelRemoteQueue(QueueId); + } + } + + // Stop the monitor thread so it no longer polls the remote. m_MonitorThreadEnabled = false; m_MonitorThreadEvent.Set(); @@ -61,9 +97,22 @@ RemoteHttpRunner::Shutdown() { m_MonitorThread.join(); } + + // Drain the running map and mark all remaining actions as Failed so the + // scheduler can reschedule or finalize them. 
+ + std::unordered_map<int, HttpRunningAction> Remaining; + + m_RunningLock.WithExclusiveLock([&] { Remaining.swap(m_RemoteRunningMap); }); + + for (auto& [RemoteLsn, HttpAction] : Remaining) + { + ZEN_DEBUG("shutdown: marking remote action LSN {} (local LSN {}) as Failed", RemoteLsn, HttpAction.Action->ActionLsn); + HttpAction.Action->SetActionState(RunnerAction::State::Failed); + } } -void +bool RemoteHttpRunner::RegisterWorker(const CbPackage& WorkerPackage) { ZEN_TRACE_CPU("RemoteHttpRunner::RegisterWorker"); @@ -125,15 +174,13 @@ RemoteHttpRunner::RegisterWorker(const CbPackage& WorkerPackage) if (!IsHttpSuccessCode(PayloadResponse.StatusCode)) { ZEN_ERROR("ERROR: unable to register payloads for worker {} at {}{}", WorkerId, m_Http.GetBaseUri(), WorkerUrl); - - // TODO: propagate error + return false; } } else if (!IsHttpSuccessCode(DescResponse.StatusCode)) { ZEN_ERROR("ERROR: unable to register worker {} at {}{}", WorkerId, m_Http.GetBaseUri(), WorkerUrl); - - // TODO: propagate error + return false; } else { @@ -152,14 +199,20 @@ RemoteHttpRunner::RegisterWorker(const CbPackage& WorkerPackage) WorkerUrl, (int)WorkerResponse.StatusCode, ToString(WorkerResponse.StatusCode)); - - // TODO: propagate error + return false; } + + return true; } size_t RemoteHttpRunner::QueryCapacity() { + if (!m_AcceptNewActions) + { + return 0; + } + // Estimate how much more work we're ready to accept RwLock::SharedLockScope _{m_RunningLock}; @@ -191,24 +244,68 @@ RemoteHttpRunner::SubmitActions(const std::vector<Ref<RunnerAction>>& Actions) return Results; } - // For larger batches, submit HTTP requests in parallel via the shared worker pool + // Collect distinct QueueIds and ensure remote queues exist once per queue - std::vector<std::future<SubmitResult>> Futures; - Futures.reserve(Actions.size()); + std::unordered_map<int, Oid> QueueTokens; // QueueId → remote token (0 stays as Zero) for (const Ref<RunnerAction>& Action : Actions) { - std::packaged_task<SubmitResult()> 
Task([this, Action]() { return SubmitAction(Action); }); + const int QueueId = Action->QueueId; + if (QueueId != 0 && QueueTokens.find(QueueId) == QueueTokens.end()) + { + CbObject QueueMeta = Action->GetOwnerSession()->GetQueueMetadata(QueueId); + CbObject QueueConfig = Action->GetOwnerSession()->GetQueueConfig(QueueId); + QueueTokens[QueueId] = EnsureRemoteQueue(QueueId, QueueMeta, QueueConfig); + } + } - Futures.push_back(m_WorkerPool.EnqueueTask(std::move(Task), WorkerThreadPool::EMode::EnableBacklog)); + // Group actions by QueueId + + struct QueueGroup + { + std::vector<Ref<RunnerAction>> Actions; + std::vector<size_t> OriginalIndices; + }; + + std::unordered_map<int, QueueGroup> Groups; + + for (size_t i = 0; i < Actions.size(); ++i) + { + auto& Group = Groups[Actions[i]->QueueId]; + Group.Actions.push_back(Actions[i]); + Group.OriginalIndices.push_back(i); } - std::vector<SubmitResult> Results; - Results.reserve(Futures.size()); + // Submit each group as a batch and map results back to original indices - for (auto& Future : Futures) + std::vector<SubmitResult> Results(Actions.size()); + + for (auto& [QueueId, Group] : Groups) { - Results.push_back(Future.get()); + std::string SubmitUrl = "/jobs"; + if (QueueId != 0) + { + if (Oid Token = QueueTokens[QueueId]; Token != Oid::Zero) + { + SubmitUrl = fmt::format("/queues/{}/jobs", Token); + } + } + + const size_t BatchLimit = size_t(m_MaxBatchSize); + + for (size_t Offset = 0; Offset < Group.Actions.size(); Offset += BatchLimit) + { + size_t End = zen::Min(Offset + BatchLimit, Group.Actions.size()); + + std::vector<Ref<RunnerAction>> Chunk(Group.Actions.begin() + Offset, Group.Actions.begin() + End); + + std::vector<SubmitResult> ChunkResults = SubmitActionBatch(SubmitUrl, Chunk); + + for (size_t j = 0; j < ChunkResults.size(); ++j) + { + Results[Group.OriginalIndices[Offset + j]] = std::move(ChunkResults[j]); + } + } } return Results; @@ -221,6 +318,11 @@ RemoteHttpRunner::SubmitAction(Ref<RunnerAction> 
Action) // Verify whether we can accept more work + if (!m_AcceptNewActions) + { + return SubmitResult{.IsAccepted = false, .Reason = "runner is shutting down"}; + } + { RwLock::SharedLockScope _{m_RunningLock}; if (m_RemoteRunningMap.size() >= size_t(m_MaxRunningActions)) @@ -275,7 +377,7 @@ RemoteHttpRunner::SubmitAction(Ref<RunnerAction> Action) m_Http.GetBaseUri(), ActionId); - RegisterWorker(Action->Worker.Descriptor); + (void)RegisterWorker(Action->Worker.Descriptor); } else { @@ -384,6 +486,194 @@ RemoteHttpRunner::SubmitAction(Ref<RunnerAction> Action) return {}; } +std::vector<SubmitResult> +RemoteHttpRunner::SubmitActionBatch(const std::string& SubmitUrl, const std::vector<Ref<RunnerAction>>& Actions) +{ + ZEN_TRACE_CPU("RemoteHttpRunner::SubmitActionBatch"); + + if (!m_AcceptNewActions) + { + return std::vector<SubmitResult>(Actions.size(), SubmitResult{.IsAccepted = false, .Reason = "runner is shutting down"}); + } + + // Capacity check + + { + RwLock::SharedLockScope _{m_RunningLock}; + if (m_RemoteRunningMap.size() >= size_t(m_MaxRunningActions)) + { + std::vector<SubmitResult> Results(Actions.size(), SubmitResult{.IsAccepted = false}); + return Results; + } + } + + // Per-action setup and build batch body + + CbObjectWriter Body; + Body.BeginArray("actions"sv); + + for (const Ref<RunnerAction>& Action : Actions) + { + Action->ExecutionLocation = m_HostName; + MaybeDumpAction(Action->ActionLsn, Action->ActionObj); + Body.AddObject(Action->ActionObj); + } + + Body.EndArray(); + + // POST the batch + + HttpClient::Response Response = m_Http.Post(SubmitUrl, Body.Save()); + + if (Response.StatusCode == HttpResponseCode::OK) + { + return ParseBatchResponse(Response, Actions); + } + + if (Response.StatusCode == HttpResponseCode::NotFound) + { + // Server needs attachments — resolve them and retry with a CbPackage + + CbObject NeedObj = Response.AsObject(); + + CbPackage Pkg; + Pkg.SetObject(Body.Save()); + + for (auto& Item : NeedObj["need"sv]) + { + const 
IoHash NeedHash = Item.AsHash(); + + if (IoBuffer Chunk = m_ChunkResolver.FindChunkByCid(NeedHash)) + { + uint64_t DataRawSize = 0; + IoHash DataRawHash; + CompressedBuffer Compressed = + CompressedBuffer::FromCompressed(SharedBuffer{Chunk}, /* out */ DataRawHash, /* out */ DataRawSize); + + ZEN_ASSERT(DataRawHash == NeedHash); + + Pkg.AddAttachment(CbAttachment(Compressed, NeedHash)); + } + else + { + ZEN_WARN("batch submit: missing attachment {} — falling back to individual submit", NeedHash); + return FallbackToIndividualSubmit(Actions); + } + } + + HttpClient::Response RetryResponse = m_Http.Post(SubmitUrl, Pkg); + + if (RetryResponse.StatusCode == HttpResponseCode::OK) + { + return ParseBatchResponse(RetryResponse, Actions); + } + + ZEN_WARN("batch submit retry failed with {} {} — falling back to individual submit", + (int)RetryResponse.StatusCode, + ToString(RetryResponse.StatusCode)); + return FallbackToIndividualSubmit(Actions); + } + + // Unexpected status or connection error — fall back to individual submission + + if (Response) + { + ZEN_WARN("batch submit to {}{} returned {} {} — falling back to individual submit", + m_Http.GetBaseUri(), + SubmitUrl, + (int)Response.StatusCode, + ToString(Response.StatusCode)); + } + else + { + ZEN_WARN("batch submit to {}{} failed — falling back to individual submit", m_Http.GetBaseUri(), SubmitUrl); + } + + return FallbackToIndividualSubmit(Actions); +} + +std::vector<SubmitResult> +RemoteHttpRunner::ParseBatchResponse(const HttpClient::Response& Response, const std::vector<Ref<RunnerAction>>& Actions) +{ + std::vector<SubmitResult> Results; + Results.reserve(Actions.size()); + + CbObject ResponseObj = Response.AsObject(); + CbArrayView ResultArray = ResponseObj["results"sv].AsArrayView(); + + size_t Index = 0; + for (CbFieldView Field : ResultArray) + { + if (Index >= Actions.size()) + { + break; + } + + CbObjectView Entry = Field.AsObjectView(); + const int32_t LsnField = Entry["lsn"sv].AsInt32(0); + + if (LsnField 
> 0) + { + HttpRunningAction NewAction; + NewAction.Action = Actions[Index]; + NewAction.RemoteActionLsn = LsnField; + + { + RwLock::ExclusiveLockScope _(m_RunningLock); + m_RemoteRunningMap[LsnField] = std::move(NewAction); + } + + ZEN_DEBUG("batch: scheduled action {} with remote LSN {} (local LSN {})", + Actions[Index]->ActionObj.GetHash(), + LsnField, + Actions[Index]->ActionLsn); + + Actions[Index]->SetActionState(RunnerAction::State::Running); + + Results.push_back(SubmitResult{.IsAccepted = true}); + } + else + { + std::string_view ErrorMsg = Entry["error"sv].AsString(); + Results.push_back(SubmitResult{.IsAccepted = false, .Reason = std::string(ErrorMsg)}); + } + + ++Index; + } + + // If the server returned fewer results than actions, mark the rest as not accepted + while (Results.size() < Actions.size()) + { + Results.push_back(SubmitResult{.IsAccepted = false, .Reason = "no result from server"}); + } + + return Results; +} + +std::vector<SubmitResult> +RemoteHttpRunner::FallbackToIndividualSubmit(const std::vector<Ref<RunnerAction>>& Actions) +{ + std::vector<std::future<SubmitResult>> Futures; + Futures.reserve(Actions.size()); + + for (const Ref<RunnerAction>& Action : Actions) + { + std::packaged_task<SubmitResult()> Task([this, Action]() { return SubmitAction(Action); }); + + Futures.push_back(m_WorkerPool.EnqueueTask(std::move(Task), WorkerThreadPool::EMode::EnableBacklog)); + } + + std::vector<SubmitResult> Results; + Results.reserve(Futures.size()); + + for (auto& Future : Futures) + { + Results.push_back(Future.get()); + } + + return Results; +} + Oid RemoteHttpRunner::EnsureRemoteQueue(int QueueId, const CbObject& Metadata, const CbObject& Config) { @@ -481,6 +771,35 @@ RemoteHttpRunner::GetSubmittedActionCount() return m_RemoteRunningMap.size(); } +////////////////////////////////////////////////////////////////////////// +// +// IWsClientHandler +// + +void +RemoteHttpRunner::OnWsOpen() +{ + ZEN_INFO("WebSocket connected to {}", m_HostName); + 
m_WsConnected.store(true, std::memory_order_release); +} + +void +RemoteHttpRunner::OnWsMessage([[maybe_unused]] const WebSocketMessage& Msg) +{ + // The message content is a wakeup signal; no parsing needed. + // Signal the monitor thread to sweep completed actions immediately. + m_MonitorThreadEvent.Set(); +} + +void +RemoteHttpRunner::OnWsClose([[maybe_unused]] uint16_t Code, [[maybe_unused]] std::string_view Reason) +{ + ZEN_WARN("WebSocket disconnected from {} (code {})", m_HostName, Code); + m_WsConnected.store(false, std::memory_order_release); +} + +////////////////////////////////////////////////////////////////////////// + void RemoteHttpRunner::MonitorThreadFunction() { @@ -489,28 +808,40 @@ RemoteHttpRunner::MonitorThreadFunction() do { const int NormalWaitingTime = 200; - int WaitTimeMs = NormalWaitingTime; - auto WaitOnce = [&] { return m_MonitorThreadEvent.Wait(WaitTimeMs); }; - auto SweepOnce = [&] { + const int WsWaitingTime = 2000; // Safety-net interval when WS is connected + + int WaitTimeMs = m_WsConnected.load(std::memory_order_relaxed) ? WsWaitingTime : NormalWaitingTime; + auto WaitOnce = [&] { return m_MonitorThreadEvent.Wait(WaitTimeMs); }; + auto SweepOnce = [&] { const size_t RetiredCount = SweepRunningActions(); - m_RunningLock.WithSharedLock([&] { - if (m_RemoteRunningMap.size() > 16) - { - WaitTimeMs = NormalWaitingTime / 4; - } - else - { - if (RetiredCount) + if (m_WsConnected.load(std::memory_order_relaxed)) + { + // WS connected: use long safety-net interval; the WS message + // will wake us immediately for the real work. 
+ WaitTimeMs = WsWaitingTime; + } + else + { + // No WS: adaptive polling as before + m_RunningLock.WithSharedLock([&] { + if (m_RemoteRunningMap.size() > 16) { - WaitTimeMs = NormalWaitingTime / 2; + WaitTimeMs = NormalWaitingTime / 4; } else { - WaitTimeMs = NormalWaitingTime; + if (RetiredCount) + { + WaitTimeMs = NormalWaitingTime / 2; + } + else + { + WaitTimeMs = NormalWaitingTime; + } } - } - }); + }); + } }; while (!WaitOnce()) @@ -518,7 +849,7 @@ RemoteHttpRunner::MonitorThreadFunction() SweepOnce(); } - // Signal received - this may mean we should quit + // Signal received — may be a WS wakeup or a quit signal SweepOnce(); } while (m_MonitorThreadEnabled); diff --git a/src/zencompute/runners/remotehttprunner.h b/src/zencompute/runners/remotehttprunner.h index 9119992a9..c17d0cf2a 100644 --- a/src/zencompute/runners/remotehttprunner.h +++ b/src/zencompute/runners/remotehttprunner.h @@ -14,9 +14,11 @@ # include <zencore/workthreadpool.h> # include <zencore/zencore.h> # include <zenhttp/httpclient.h> +# include <zenhttp/httpwsclient.h> # include <atomic> # include <filesystem> +# include <memory> # include <thread> # include <unordered_map> @@ -32,7 +34,7 @@ namespace zen::compute { */ -class RemoteHttpRunner : public FunctionRunner +class RemoteHttpRunner : public FunctionRunner, private IWsClientHandler { RemoteHttpRunner(RemoteHttpRunner&&) = delete; RemoteHttpRunner& operator=(RemoteHttpRunner&&) = delete; @@ -45,7 +47,7 @@ public: ~RemoteHttpRunner(); virtual void Shutdown() override; - virtual void RegisterWorker(const CbPackage& WorkerPackage) override; + [[nodiscard]] virtual bool RegisterWorker(const CbPackage& WorkerPackage) override; [[nodiscard]] virtual SubmitResult SubmitAction(Ref<RunnerAction> Action) override; [[nodiscard]] virtual bool IsHealthy() override; [[nodiscard]] virtual size_t GetSubmittedActionCount() override; @@ -66,7 +68,9 @@ private: std::string m_BaseUrl; HttpClient m_Http; - int32_t m_MaxRunningActions = 256; // arbitrary 
limit for testing + std::atomic<bool> m_AcceptNewActions{true}; + int32_t m_MaxRunningActions = 256; // arbitrary limit for testing + int32_t m_MaxBatchSize = 50; struct HttpRunningAction { @@ -92,7 +96,20 @@ private: // creating remote queues. Generated once at construction and never changes. Oid m_InstanceId; + // WebSocket completion notification client + std::unique_ptr<HttpWsClient> m_WsClient; + std::atomic<bool> m_WsConnected{false}; + + // IWsClientHandler + void OnWsOpen() override; + void OnWsMessage(const WebSocketMessage& Msg) override; + void OnWsClose(uint16_t Code, std::string_view Reason) override; + Oid EnsureRemoteQueue(int QueueId, const CbObject& Metadata, const CbObject& Config); + + std::vector<SubmitResult> SubmitActionBatch(const std::string& SubmitUrl, const std::vector<Ref<RunnerAction>>& Actions); + std::vector<SubmitResult> ParseBatchResponse(const HttpClient::Response& Response, const std::vector<Ref<RunnerAction>>& Actions); + std::vector<SubmitResult> FallbackToIndividualSubmit(const std::vector<Ref<RunnerAction>>& Actions); }; } // namespace zen::compute diff --git a/src/zencompute/testing/mockimds.cpp b/src/zencompute/testing/mockimds.cpp index dd09312df..5415f48f3 100644 --- a/src/zencompute/testing/mockimds.cpp +++ b/src/zencompute/testing/mockimds.cpp @@ -1,205 +1,2 @@ // Copyright Epic Games, Inc. All Rights Reserved. 
- -#include <zencompute/mockimds.h> - -#include <zencore/fmtutils.h> - -#if ZEN_WITH_TESTS - -namespace zen::compute { - -const char* -MockImdsService::BaseUri() const -{ - return "/"; -} - -void -MockImdsService::HandleRequest(HttpServerRequest& Request) -{ - std::string_view Uri = Request.RelativeUri(); - - // AWS endpoints live under /latest/ - if (Uri.starts_with("latest/")) - { - if (ActiveProvider == CloudProvider::AWS) - { - HandleAwsRequest(Request); - return; - } - Request.WriteResponse(HttpResponseCode::NotFound); - return; - } - - // Azure endpoints live under /metadata/ - if (Uri.starts_with("metadata/")) - { - if (ActiveProvider == CloudProvider::Azure) - { - HandleAzureRequest(Request); - return; - } - Request.WriteResponse(HttpResponseCode::NotFound); - return; - } - - // GCP endpoints live under /computeMetadata/ - if (Uri.starts_with("computeMetadata/")) - { - if (ActiveProvider == CloudProvider::GCP) - { - HandleGcpRequest(Request); - return; - } - Request.WriteResponse(HttpResponseCode::NotFound); - return; - } - - Request.WriteResponse(HttpResponseCode::NotFound); -} - -// --------------------------------------------------------------------------- -// AWS -// --------------------------------------------------------------------------- - -void -MockImdsService::HandleAwsRequest(HttpServerRequest& Request) -{ - std::string_view Uri = Request.RelativeUri(); - - // IMDSv2 token acquisition (PUT only) - if (Uri == "latest/api/token" && Request.RequestVerb() == HttpVerb::kPut) - { - Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Aws.Token); - return; - } - - // Instance identity - if (Uri == "latest/meta-data/instance-id") - { - Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Aws.InstanceId); - return; - } - - if (Uri == "latest/meta-data/placement/availability-zone") - { - Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Aws.AvailabilityZone); - return; - } - - if (Uri == 
"latest/meta-data/instance-life-cycle") - { - Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Aws.LifeCycle); - return; - } - - // Autoscaling lifecycle state — 404 when not in an ASG - if (Uri == "latest/meta-data/autoscaling/target-lifecycle-state") - { - if (Aws.AutoscalingState.empty()) - { - Request.WriteResponse(HttpResponseCode::NotFound); - return; - } - Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Aws.AutoscalingState); - return; - } - - // Spot interruption notice — 404 when no interruption pending - if (Uri == "latest/meta-data/spot/instance-action") - { - if (Aws.SpotAction.empty()) - { - Request.WriteResponse(HttpResponseCode::NotFound); - return; - } - Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Aws.SpotAction); - return; - } - - Request.WriteResponse(HttpResponseCode::NotFound); -} - -// --------------------------------------------------------------------------- -// Azure -// --------------------------------------------------------------------------- - -void -MockImdsService::HandleAzureRequest(HttpServerRequest& Request) -{ - std::string_view Uri = Request.RelativeUri(); - - // Instance metadata (single JSON document) - if (Uri == "metadata/instance") - { - std::string Json = fmt::format(R"({{"compute":{{"vmId":"{}","location":"{}","priority":"{}","vmScaleSetName":"{}"}}}})", - Azure.VmId, - Azure.Location, - Azure.Priority, - Azure.VmScaleSetName); - - Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Json); - return; - } - - // Scheduled events for termination monitoring - if (Uri == "metadata/scheduledevents") - { - std::string Json; - if (Azure.ScheduledEventType.empty()) - { - Json = R"({"Events":[]})"; - } - else - { - Json = fmt::format(R"({{"Events":[{{"EventType":"{}","EventStatus":"{}"}}]}})", - Azure.ScheduledEventType, - Azure.ScheduledEventStatus); - } - - Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Json); - return; - } - - 
Request.WriteResponse(HttpResponseCode::NotFound); -} - -// --------------------------------------------------------------------------- -// GCP -// --------------------------------------------------------------------------- - -void -MockImdsService::HandleGcpRequest(HttpServerRequest& Request) -{ - std::string_view Uri = Request.RelativeUri(); - - if (Uri == "computeMetadata/v1/instance/id") - { - Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Gcp.InstanceId); - return; - } - - if (Uri == "computeMetadata/v1/instance/zone") - { - Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Gcp.Zone); - return; - } - - if (Uri == "computeMetadata/v1/instance/scheduling/preemptible") - { - Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Gcp.Preemptible); - return; - } - - if (Uri == "computeMetadata/v1/instance/maintenance-event") - { - Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Gcp.MaintenanceEvent); - return; - } - - Request.WriteResponse(HttpResponseCode::NotFound); -} - -} // namespace zen::compute - -#endif // ZEN_WITH_TESTS +// Moved to zenutil/cloud/mockimds.cpp diff --git a/src/zencore/compactbinary.cpp b/src/zencore/compactbinary.cpp index 9c81305d0..f2c46c2bc 100644 --- a/src/zencore/compactbinary.cpp +++ b/src/zencore/compactbinary.cpp @@ -1752,4 +1752,33 @@ TEST_SUITE_END(); #endif +/////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace CompactBinaryPrivate { + + ZEN_NOINLINE void ReferenceDebugTypes(DebugCb*, + DebugCbObject*, + DebugCbUniformObject*, + DebugCbArray*, + DebugCbUniformArray*, + DebugCbBinary*, + DebugCbString*, + DebugCbIntegerPositive*, + DebugCbIntegerNegative*, + DebugCbFloat32*, + DebugCbFloat64*, + DebugCbObjectAttachment*, + DebugCbBinaryAttachment*, + DebugCbHash*, + DebugCbUuid*, + DebugCbDateTime*, + DebugCbTimeSpan*, + DebugCbObjectId*, + DebugCbCustomById*, + 
DebugCbCustomByName*) + { + } + +} // namespace CompactBinaryPrivate + } // namespace zen diff --git a/src/zencore/compactbinarybuilder.cpp b/src/zencore/compactbinarybuilder.cpp index a9ba30750..a82ff5594 100644 --- a/src/zencore/compactbinarybuilder.cpp +++ b/src/zencore/compactbinarybuilder.cpp @@ -1449,6 +1449,8 @@ TEST_CASE("usonbuilder.complex") } Writer.EndArray(); + Writer.AddObjectId("Oid"sv, Oid::FromHexString("0102030405060708090a0b0c")); + Writer.EndObject(); Object = Writer.Save().AsObject(); } diff --git a/src/zencore/compactbinaryfile.cpp b/src/zencore/compactbinaryfile.cpp index ec2fc3cd5..9ddafbe15 100644 --- a/src/zencore/compactbinaryfile.cpp +++ b/src/zencore/compactbinaryfile.cpp @@ -30,4 +30,15 @@ LoadCompactBinaryObject(const std::filesystem::path& FilePath) return {.Hash = IoHash::Zero}; } +void +WriteCompactBinaryObject(const std::filesystem::path& Path, const CbObject& Object) +{ + // We cannot use CbObject::GetView() here because it may not return a complete + // view since the type byte can be omitted in arrays. + CbWriter Writer; + Writer.AddObject(Object); + CbFieldIterator Fields = Writer.Save(); + zen::WriteFile(Path, IoBufferBuilder::MakeFromMemory(Fields.GetRangeView())); +} + } // namespace zen diff --git a/src/zencore/include/zencore/compactbinary.h b/src/zencore/include/zencore/compactbinary.h index b128e4205..74f4cdf8d 100644 --- a/src/zencore/include/zencore/compactbinary.h +++ b/src/zencore/include/zencore/compactbinary.h @@ -1530,4 +1530,98 @@ void uson_forcelink(); // internal void cbjson_forcelink(); // internal void cbyaml_forcelink(); // internal +/////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + +// Types to support debug visualizers (natvis). + +// A field/array/object in memory can be visualized in a debugger supporting natvis by casting to this type. +// (DebugCb*)Pointer works for any field/array/object that starts with its field type. 
+struct DebugCb +{ +}; + +// Types that visualize a pointer to a value, *not* a pointer to a field. +struct DebugCbObject +{ +}; +struct DebugCbUniformObject +{ +}; +struct DebugCbArray +{ +}; +struct DebugCbUniformArray +{ +}; +struct DebugCbBinary +{ +}; +struct DebugCbString +{ +}; +struct DebugCbIntegerPositive +{ +}; +struct DebugCbIntegerNegative +{ +}; +struct DebugCbFloat32 +{ +}; +struct DebugCbFloat64 +{ +}; +struct DebugCbObjectAttachment +{ +}; +struct DebugCbBinaryAttachment +{ +}; +struct DebugCbHash +{ +}; +struct DebugCbUuid +{ +}; +struct DebugCbDateTime +{ +}; +struct DebugCbTimeSpan +{ +}; +struct DebugCbObjectId +{ +}; +struct DebugCbCustomById +{ +}; +struct DebugCbCustomByName +{ +}; + +namespace CompactBinaryPrivate { + + void ReferenceDebugTypes(DebugCb*, + DebugCbObject*, + DebugCbUniformObject*, + DebugCbArray*, + DebugCbUniformArray*, + DebugCbBinary*, + DebugCbString*, + DebugCbIntegerPositive*, + DebugCbIntegerNegative*, + DebugCbFloat32*, + DebugCbFloat64*, + DebugCbObjectAttachment*, + DebugCbBinaryAttachment*, + DebugCbHash*, + DebugCbUuid*, + DebugCbDateTime*, + DebugCbTimeSpan*, + DebugCbObjectId*, + DebugCbCustomById*, + DebugCbCustomByName*); + +} // namespace CompactBinaryPrivate + } // namespace zen diff --git a/src/zencore/include/zencore/compactbinaryfile.h b/src/zencore/include/zencore/compactbinaryfile.h index 33f3e7bea..a06524549 100644 --- a/src/zencore/include/zencore/compactbinaryfile.h +++ b/src/zencore/include/zencore/compactbinaryfile.h @@ -15,5 +15,6 @@ struct CbObjectFromFile }; CbObjectFromFile LoadCompactBinaryObject(const std::filesystem::path& FilePath); +void WriteCompactBinaryObject(const std::filesystem::path& Path, const CbObject& Object); } // namespace zen diff --git a/src/zencore/include/zencore/logging/ansicolorsink.h b/src/zencore/include/zencore/logging/ansicolorsink.h index 5060a8393..939c70d12 100644 --- a/src/zencore/include/zencore/logging/ansicolorsink.h +++ 
b/src/zencore/include/zencore/logging/ansicolorsink.h @@ -15,6 +15,9 @@ enum class ColorMode Auto }; +bool IsColorTerminal(); +bool ResolveColorMode(ColorMode Mode); + class AnsiColorStdoutSink : public Sink { public: diff --git a/src/zencore/include/zencore/logging/formatter.h b/src/zencore/include/zencore/logging/formatter.h index 11904d71d..e605b22b8 100644 --- a/src/zencore/include/zencore/logging/formatter.h +++ b/src/zencore/include/zencore/logging/formatter.h @@ -15,6 +15,12 @@ public: virtual ~Formatter() = default; virtual void Format(const LogMessage& Msg, MemoryBuffer& Dest) = 0; virtual std::unique_ptr<Formatter> Clone() const = 0; + + void SetColorEnabled(bool Enabled) { m_UseColor = Enabled; } + bool IsColorEnabled() const { return m_UseColor; } + +private: + bool m_UseColor = false; }; } // namespace zen::logging diff --git a/src/zencore/include/zencore/logging/helpers.h b/src/zencore/include/zencore/logging/helpers.h index ce021e1a5..765aa59e3 100644 --- a/src/zencore/include/zencore/logging/helpers.h +++ b/src/zencore/include/zencore/logging/helpers.h @@ -119,4 +119,81 @@ LevelToShortString(LogLevel Level) return ToStringView(Level); } +inline std::string_view +AnsiColorForLevel(LogLevel Level) +{ + using namespace std::string_view_literals; + switch (Level) + { + case Trace: + return "\033[37m"sv; // white + case Debug: + return "\033[36m"sv; // cyan + case Info: + return "\033[32m"sv; // green + case Warn: + return "\033[33m\033[1m"sv; // bold yellow + case Err: + return "\033[31m\033[1m"sv; // bold red + case Critical: + return "\033[1m\033[41m"sv; // bold on red background + default: + return "\033[m"sv; + } +} + +inline constexpr std::string_view kAnsiReset = "\033[m"; + +inline void +AppendAnsiColor(LogLevel Level, MemoryBuffer& Dest) +{ + std::string_view Color = AnsiColorForLevel(Level); + Dest.append(Color.data(), Color.data() + Color.size()); +} + +inline void +AppendAnsiReset(MemoryBuffer& Dest) +{ + Dest.append(kAnsiReset.data(), 
kAnsiReset.data() + kAnsiReset.size()); +} + +// Strip ANSI SGR escape sequences (\033[...m) from the buffer in-place. +// Only sequences terminated by 'm' are removed (colors, bold, underline, etc.). +// Other CSI sequences (cursor movement, erase, etc.) are left intact. +inline void +StripAnsiSgrSequences(MemoryBuffer& Buf) +{ + const char* Src = Buf.data(); + const char* End = Src + Buf.size(); + char* Dst = Buf.data(); + + while (Src < End) + { + if (Src[0] == '\033' && (Src + 1) < End && Src[1] == '[') + { + const char* Seq = Src + 2; + while (Seq < End && *Seq != 'm') + { + ++Seq; + } + if (Seq < End) + { + ++Seq; // skip 'm' + } + Src = Seq; + } + else + { + if (Dst != Src) + { + *Dst = *Src; + } + ++Dst; + ++Src; + } + } + + Buf.resize(static_cast<size_t>(Dst - Buf.data())); +} + } // namespace zen::logging::helpers diff --git a/src/zencore/include/zencore/logging/logmsg.h b/src/zencore/include/zencore/logging/logmsg.h index 1d8b6b1b7..a1acb503b 100644 --- a/src/zencore/include/zencore/logging/logmsg.h +++ b/src/zencore/include/zencore/logging/logmsg.h @@ -40,9 +40,6 @@ struct LogMessage void SetTime(LogClock::time_point InTime) { m_Time = InTime; } void SetSource(const SourceLocation& InSource) { m_Source = InSource; } - mutable size_t ColorRangeStart = 0; - mutable size_t ColorRangeEnd = 0; - private: static constexpr LogPoint s_DefaultPoints[LogLevelCount] = { {{}, Trace, {}}, diff --git a/src/zencore/include/zencore/testing.h b/src/zencore/include/zencore/testing.h index 8410216c4..01356fa00 100644 --- a/src/zencore/include/zencore/testing.h +++ b/src/zencore/include/zencore/testing.h @@ -43,9 +43,8 @@ public: TestRunner(); ~TestRunner(); - void SetDefaultSuiteFilter(const char* Pattern); - int ApplyCommandLine(int Argc, char const* const* Argv); - int Run(); + int ApplyCommandLine(int Argc, char const* const* Argv, const char* DefaultSuiteFilter = nullptr); + int Run(); private: struct Impl; diff --git a/src/zencore/logging/ansicolorsink.cpp 
b/src/zencore/logging/ansicolorsink.cpp index 540d22359..03aae068a 100644 --- a/src/zencore/logging/ansicolorsink.cpp +++ b/src/zencore/logging/ansicolorsink.cpp @@ -4,12 +4,14 @@ #include <zencore/logging/helpers.h> #include <zencore/logging/messageonlyformatter.h> +#include <zencore/thread.h> + #include <cstdio> #include <cstdlib> -#include <mutex> #if defined(_WIN32) # include <io.h> +# include <zencore/windows.h> # define ZEN_ISATTY _isatty # define ZEN_FILENO _fileno #else @@ -62,188 +64,225 @@ public: Dest.push_back(' '); } - // level (colored range) + // level Dest.push_back('['); - Msg.ColorRangeStart = Dest.size(); + if (IsColorEnabled()) + { + helpers::AppendAnsiColor(Msg.GetLevel(), Dest); + } helpers::AppendStringView(helpers::LevelToShortString(Msg.GetLevel()), Dest); - Msg.ColorRangeEnd = Dest.size(); + if (IsColorEnabled()) + { + helpers::AppendAnsiReset(Dest); + } Dest.push_back(']'); Dest.push_back(' '); - // message - helpers::AppendStringView(Msg.GetPayload(), Dest); + // message (align continuation lines with the first line) + size_t AnsiBytes = IsColorEnabled() ? 
(helpers::AnsiColorForLevel(Msg.GetLevel()).size() + helpers::kAnsiReset.size()) : 0; + size_t LinePrefixCount = Dest.size() - AnsiBytes; + + auto MsgPayload = Msg.GetPayload(); + auto ItLineBegin = MsgPayload.begin(); + auto ItMessageEnd = MsgPayload.end(); + bool IsFirstLine = true; + + auto ItLineEnd = ItLineBegin; + + auto EmitLine = [&] { + if (IsFirstLine) + { + IsFirstLine = false; + } + else + { + for (size_t i = 0; i < LinePrefixCount; ++i) + { + Dest.push_back(' '); + } + } + helpers::AppendStringView(std::string_view(&*ItLineBegin, ItLineEnd - ItLineBegin), Dest); + }; + + while (ItLineEnd != ItMessageEnd) + { + if (*ItLineEnd++ == '\n') + { + EmitLine(); + ItLineBegin = ItLineEnd; + } + } + + if (ItLineBegin != ItMessageEnd) + { + EmitLine(); + } Dest.push_back('\n'); } - std::unique_ptr<Formatter> Clone() const override { return std::make_unique<DefaultConsoleFormatter>(); } + std::unique_ptr<Formatter> Clone() const override + { + auto Copy = std::make_unique<DefaultConsoleFormatter>(); + Copy->SetColorEnabled(IsColorEnabled()); + return Copy; + } private: std::chrono::seconds m_LastLogSecs{0}; std::tm m_CachedLocalTm{}; }; -static constexpr std::string_view s_Reset = "\033[m"; - -static std::string_view -GetColorForLevel(LogLevel InLevel) +bool +IsColorTerminal() { - using namespace std::string_view_literals; - switch (InLevel) + // If stdout is not a TTY, no color + if (ZEN_ISATTY(ZEN_FILENO(stdout)) == 0) { - case Trace: - return "\033[37m"sv; // white - case Debug: - return "\033[36m"sv; // cyan - case Info: - return "\033[32m"sv; // green - case Warn: - return "\033[33m\033[1m"sv; // bold yellow - case Err: - return "\033[31m\033[1m"sv; // bold red - case Critical: - return "\033[1m\033[41m"sv; // bold on red background - default: - return s_Reset; + return false; } -} -struct AnsiColorStdoutSink::Impl -{ - explicit Impl(ColorMode Mode) : m_Formatter(std::make_unique<DefaultConsoleFormatter>()), m_UseColor(ResolveColorMode(Mode)) {} + // NO_COLOR 
convention (https://no-color.org/) + if (std::getenv("NO_COLOR") != nullptr) + { + return false; + } - static bool IsColorTerminal() + // COLORTERM is set by terminals that support color (e.g. "truecolor", "24bit") + if (std::getenv("COLORTERM") != nullptr) { - // If stdout is not a TTY, no color - if (ZEN_ISATTY(ZEN_FILENO(stdout)) == 0) - { - return false; - } + return true; + } - // NO_COLOR convention (https://no-color.org/) - if (std::getenv("NO_COLOR") != nullptr) + // Check TERM for known color-capable values + const char* Term = std::getenv("TERM"); + if (Term != nullptr) + { + std::string_view TermView(Term); + // "dumb" terminals do not support color + if (TermView == "dumb") { return false; } - - // COLORTERM is set by terminals that support color (e.g. "truecolor", "24bit") - if (std::getenv("COLORTERM") != nullptr) + // Match against known color-capable terminal types. + // TERM often includes suffixes like "-256color", so we use substring matching. + constexpr std::string_view ColorTerms[] = { + "alacritty", + "ansi", + "color", + "console", + "cygwin", + "gnome", + "konsole", + "kterm", + "linux", + "msys", + "putty", + "rxvt", + "screen", + "tmux", + "vt100", + "vt102", + "xterm", + }; + for (std::string_view Candidate : ColorTerms) { - return true; - } - - // Check TERM for known color-capable values - const char* Term = std::getenv("TERM"); - if (Term != nullptr) - { - std::string_view TermView(Term); - // "dumb" terminals do not support color - if (TermView == "dumb") + if (TermView.find(Candidate) != std::string_view::npos) { - return false; - } - // Match against known color-capable terminal types. - // TERM often includes suffixes like "-256color", so we use substring matching. 
- constexpr std::string_view ColorTerms[] = { - "alacritty", - "ansi", - "color", - "console", - "cygwin", - "gnome", - "konsole", - "kterm", - "linux", - "msys", - "putty", - "rxvt", - "screen", - "tmux", - "vt100", - "vt102", - "xterm", - }; - for (std::string_view Candidate : ColorTerms) - { - if (TermView.find(Candidate) != std::string_view::npos) - { - return true; - } + return true; } } + } #if defined(_WIN32) - // Windows console supports ANSI color by default in modern versions - return true; + // Windows console supports ANSI color by default in modern versions + return true; #else - // Unknown terminal — be conservative - return false; + // Unknown terminal — be conservative + return false; #endif - } +} - static bool ResolveColorMode(ColorMode Mode) +bool +ResolveColorMode(ColorMode Mode) +{ + switch (Mode) { - switch (Mode) - { - case ColorMode::On: - return true; - case ColorMode::Off: - return false; - case ColorMode::Auto: - default: - return IsColorTerminal(); - } + case ColorMode::On: + return true; + case ColorMode::Off: + return false; + case ColorMode::Auto: + default: + return IsColorTerminal(); } +} - void Log(const LogMessage& Msg) +struct AnsiColorStdoutSink::Impl +{ + explicit Impl(ColorMode Mode) : m_Formatter(std::make_unique<DefaultConsoleFormatter>()), m_UseColor(ResolveColorMode(Mode)) { - std::lock_guard<std::mutex> Lock(m_Mutex); - - MemoryBuffer Formatted; - m_Formatter->Format(Msg, Formatted); + m_Formatter->SetColorEnabled(m_UseColor); + } - if (m_UseColor && Msg.ColorRangeEnd > Msg.ColorRangeStart) - { - // Print pre-color range - fwrite(Formatted.data(), 1, Msg.ColorRangeStart, m_File); + void WriteOutput(const MemoryBuffer& Buf) + { + RwLock::ExclusiveLockScope Lock(m_Lock); - // Print color - std::string_view Color = GetColorForLevel(Msg.GetLevel()); - fwrite(Color.data(), 1, Color.size(), m_File); +#if defined(_WIN32) + DWORD Written; + WriteFile(m_Handle, Buf.data(), static_cast<DWORD>(Buf.size()), &Written, nullptr); +#else 
+ fwrite(Buf.data(), 1, Buf.size(), m_File); +#endif - // Print colored range - fwrite(Formatted.data() + Msg.ColorRangeStart, 1, Msg.ColorRangeEnd - Msg.ColorRangeStart, m_File); + m_Dirty.store(false, std::memory_order_relaxed); + } - // Reset color - fwrite(s_Reset.data(), 1, s_Reset.size(), m_File); + void Log(const LogMessage& Msg) + { + MemoryBuffer Formatted; + m_Formatter->Format(Msg, Formatted); - // Print remainder - fwrite(Formatted.data() + Msg.ColorRangeEnd, 1, Formatted.size() - Msg.ColorRangeEnd, m_File); - } - else + if (!m_UseColor) { - fwrite(Formatted.data(), 1, Formatted.size(), m_File); + helpers::StripAnsiSgrSequences(Formatted); } - fflush(m_File); + WriteOutput(Formatted); } void Flush() { - std::lock_guard<std::mutex> Lock(m_Mutex); + if (!m_Dirty.load(std::memory_order_relaxed)) + { + return; + } + RwLock::ExclusiveLockScope Lock(m_Lock); + m_Dirty.store(false, std::memory_order_relaxed); +#if defined(_WIN32) + FlushFileBuffers(m_Handle); +#else fflush(m_File); +#endif } void SetFormatter(std::unique_ptr<Formatter> InFormatter) { - std::lock_guard<std::mutex> Lock(m_Mutex); + RwLock::ExclusiveLockScope Lock(m_Lock); + InFormatter->SetColorEnabled(m_UseColor); m_Formatter = std::move(InFormatter); } private: - std::mutex m_Mutex; + RwLock m_Lock; std::unique_ptr<Formatter> m_Formatter; - FILE* m_File = stdout; - bool m_UseColor = true; +#if defined(_WIN32) + HANDLE m_Handle = GetStdHandle(STD_OUTPUT_HANDLE); +#else + FILE* m_File = stdout; +#endif + bool m_UseColor = true; + std::atomic<bool> m_Dirty = false; }; AnsiColorStdoutSink::AnsiColorStdoutSink(ColorMode Mode) : m_Impl(std::make_unique<Impl>(Mode)) diff --git a/src/zencore/process.cpp b/src/zencore/process.cpp index 080607f13..0c55e6c7e 100644 --- a/src/zencore/process.cpp +++ b/src/zencore/process.cpp @@ -492,8 +492,10 @@ CreateProcNormal(const std::filesystem::path& Executable, std::string_view Comma LPSECURITY_ATTRIBUTES ThreadAttributes = nullptr; // Build environment block when 
custom environment variables are specified + ExtendableWideStringBuilder<512> EnvironmentBlock; void* Environment = nullptr; + if (!Options.Environment.empty()) { // Capture current environment into a map diff --git a/src/zencore/sentryintegration.cpp b/src/zencore/sentryintegration.cpp index b7d01003b..634045cfb 100644 --- a/src/zencore/sentryintegration.cpp +++ b/src/zencore/sentryintegration.cpp @@ -248,6 +248,14 @@ SentryIntegration::Initialize(const Config& Conf, const std::string& CommandLine } sentry_options_t* SentryOptions = sentry_options_new(); + if (SentryOptions == nullptr) + { + // OOM — skip sentry entirely rather than crashing on the subsequent set calls + m_SentryErrorCode = -1; + m_IsInitialized = true; + return; + } + sentry_options_set_dsn(SentryOptions, Conf.Dsn.empty() ? sentry::DefaultDsn.c_str() : Conf.Dsn.c_str()); sentry_options_set_database_path(SentryOptions, SentryDatabasePath.c_str()); sentry_options_set_logger(SentryOptions, SentryLogFunction, this); diff --git a/src/zencore/testing.cpp b/src/zencore/testing.cpp index f5bc723b1..c6ee5ee6b 100644 --- a/src/zencore/testing.cpp +++ b/src/zencore/testing.cpp @@ -181,6 +181,15 @@ struct TestListener : public doctest::IReporter void test_case_start(const doctest::TestCaseData& in) override { Current = ∈ + + if (in.m_test_suite && in.m_test_suite != CurrentSuite) + { + CurrentSuite = in.m_test_suite; + ZEN_CONSOLE("{}==============================================================================={}", ColorYellow, ColorNone); + ZEN_CONSOLE("{} TEST_SUITE: {}{}", ColorYellow, CurrentSuite, ColorNone); + ZEN_CONSOLE("{}==============================================================================={}", ColorYellow, ColorNone); + } + ZEN_CONSOLE("{}======== TEST_CASE: {:<50} ========{}", ColorYellow, Current->m_name, ColorNone); } @@ -217,8 +226,9 @@ struct TestListener : public doctest::IReporter void test_case_skipped(const doctest::TestCaseData& /*in*/) override {} - const 
doctest::TestCaseData* Current = nullptr; - std::chrono::steady_clock::time_point RunStart = {}; + const doctest::TestCaseData* Current = nullptr; + std::string_view CurrentSuite = {}; + std::chrono::steady_clock::time_point RunStart = {}; struct FailedTestInfo { @@ -244,15 +254,29 @@ TestRunner::~TestRunner() { } -void -TestRunner::SetDefaultSuiteFilter(const char* Pattern) -{ - m_Impl->Session.setOption("test-suite", Pattern); -} - int -TestRunner::ApplyCommandLine(int Argc, char const* const* Argv) +TestRunner::ApplyCommandLine(int Argc, char const* const* Argv, const char* DefaultSuiteFilter) { + // Apply the default suite filter only when the command line doesn't provide + // an explicit --test-suite / --ts override. + if (DefaultSuiteFilter) + { + bool HasExplicitSuiteFilter = false; + for (int i = 1; i < Argc; ++i) + { + std::string_view Arg = Argv[i]; + if (Arg.starts_with("--test-suite=") || Arg.starts_with("--ts=") || Arg.starts_with("-test-suite=") || Arg.starts_with("-ts=")) + { + HasExplicitSuiteFilter = true; + break; + } + } + if (!HasExplicitSuiteFilter) + { + m_Impl->Session.setOption("test-suite", DefaultSuiteFilter); + } + } + m_Impl->Session.applyCommandLine(Argc, Argv); for (int i = 1; i < Argc; ++i) @@ -316,6 +340,7 @@ RunTestMain(int Argc, char* Argv[], const char* ExecutableName, void (*ForceLink TestRunner Runner; // Derive default suite filter from ExecutableName: "zencore-test" -> "core.*" + std::string DefaultSuiteFilter; if (ExecutableName) { std::string_view Name = ExecutableName; @@ -329,13 +354,12 @@ RunTestMain(int Argc, char* Argv[], const char* ExecutableName, void (*ForceLink } if (!Name.empty()) { - std::string Filter(Name); - Filter += ".*"; - Runner.SetDefaultSuiteFilter(Filter.c_str()); + DefaultSuiteFilter = Name; + DefaultSuiteFilter += ".*"; } } - Runner.ApplyCommandLine(Argc, Argv); + Runner.ApplyCommandLine(Argc, Argv, DefaultSuiteFilter.empty() ? 
nullptr : DefaultSuiteFilter.c_str()); return Runner.Run(); } diff --git a/src/zencore/xmake.lua b/src/zencore/xmake.lua index b08975df1..fe12c14e8 100644 --- a/src/zencore/xmake.lua +++ b/src/zencore/xmake.lua @@ -16,6 +16,7 @@ target('zencore') add_files("**.cpp") add_files("trace.cpp", {unity_ignored = true }) add_files("testing.cpp", {unity_ignored = true }) + add_extrafiles("zencore.natvis") if has_config("zenrpmalloc") then add_deps("rpmalloc") diff --git a/src/zencore/zencore.natvis b/src/zencore/zencore.natvis new file mode 100644 index 000000000..e2da28351 --- /dev/null +++ b/src/zencore/zencore.natvis @@ -0,0 +1,874 @@ +<?xml version="1.0" encoding="utf-8"?> +<AutoVisualizer xmlns="http://schemas.microsoft.com/vstudio/debugger/natvis/2010"> + + <!-- Measure the length in bytes (1-9) of an encoded variable-length integer. See varint.h. --> + <Intrinsic Name="DebugMeasureVarUInt" Expression="(uint32_t)( + (*Data < 0x80) ? 1 : (*Data < 0xc0) ? 2 : (*Data < 0xe0) ? 3 : (*Data < 0xf0) ? 4 : + (*Data < 0xf8) ? 5 : (*Data < 0xfc) ? 6 : (*Data < 0xfe) ? 7 : (*Data == 0xfe) ? 8 : 9)"> + <Parameter Name="Data" Type="const uint8_t*"/> + </Intrinsic> + + <!-- Measure the number of bytes (1-9) required to encode an integer. See varint.h. --> + <Intrinsic Name="DebugMeasureVarUInt" Expression="(uint32_t)( + (Value < 0x000000000000007f) ? 1 : (Value < 0x0000000000003fff) ? 2 : + (Value < 0x00000000001fffff) ? 3 : (Value < 0x000000000fffffff) ? 4 : + (Value < 0x00000007ffffffff) ? 5 : (Value < 0x000003ffffffffff) ? 6 : + (Value < 0x0001ffffffffffff) ? 7 : (Value < 0x00ffffffffffffff) ? 8 : 9)"> + <Parameter Name="Value" Type="uint64_t"/> + </Intrinsic> + + <!-- Read a variable-length unsigned integer. See varint.h. --> + <Intrinsic Name="DebugReadVarUInt" Expression="(uint64_t)( + (*Data < 0x80) ? ((((uint64_t)Data[0]))) : + (*Data < 0xc0) ? ((((uint64_t)Data[0] & 0x3f) << 8) | ((uint64_t)Data[1])) : + (*Data < 0xe0) ? 
((((uint64_t)Data[0] & 0x1f) << 16) | ((uint64_t)Data[1] << 8) | ((uint64_t)Data[2])) : + (*Data < 0xf0) ? ((((uint64_t)Data[0] & 0x0f) << 24) | ((uint64_t)Data[1] << 16) | ((uint64_t)Data[2] << 8) | ((uint64_t)Data[3])) : + (*Data < 0xf8) ? ((((uint64_t)Data[0] & 0x07) << 32) | ((uint64_t)Data[1] << 24) | ((uint64_t)Data[2] << 16) | ((uint64_t)Data[3] << 8) | ((uint64_t)Data[4])) : + (*Data < 0xfc) ? ((((uint64_t)Data[0] & 0x03) << 40) | ((uint64_t)Data[1] << 32) | ((uint64_t)Data[2] << 24) | ((uint64_t)Data[3] << 16) | ((uint64_t)Data[4] << 8) | ((uint64_t)Data[5])) : + (*Data < 0xfe) ? ((((uint64_t)Data[0] & 0x01) << 48) | ((uint64_t)Data[1] << 40) | ((uint64_t)Data[2] << 32) | ((uint64_t)Data[3] << 24) | ((uint64_t)Data[4] << 16) | ((uint64_t)Data[5] << 8) | ((uint64_t)Data[6])) : + (*Data < 0xff) ? ( ((uint64_t)Data[1] << 48) | ((uint64_t)Data[2] << 40) | ((uint64_t)Data[3] << 32) | ((uint64_t)Data[4] << 24) | ((uint64_t)Data[5] << 16) | ((uint64_t)Data[6] << 8) | ((uint64_t)Data[7])) : + ( ((uint64_t)Data[1] << 56) | ((uint64_t)Data[2] << 48) | ((uint64_t)Data[3] << 40) | ((uint64_t)Data[4] << 32) | ((uint64_t)Data[5] << 24) | ((uint64_t)Data[6] << 16) | ((uint64_t)Data[7] << 8) | ((uint64_t)Data[8])))"> + <Parameter Name="Data" Type="const uint8_t*"/> + </Intrinsic> + + <!-- Exact 2^Exp for Exp in the range [-1023, 1023], covering both float32 and float64. --> + <Intrinsic Name="Pow2i" Expression="(double)( + (Exp == 0) ? 1.0 : + (Exp > 0) + ? (((Exp & 1) ? 2.0 : 1.0) * + ((Exp & 2) ? 4.0 : 1.0) * + ((Exp & 4) ? 16.0 : 1.0) * + ((Exp & 8) ? 256.0 : 1.0) * + ((Exp & 16) ? 65536.0 : 1.0) * + ((Exp & 32) ? 4294967296.0 : 1.0) * + ((Exp & 64) ? 18446744073709551616.0 : 1.0) * + ((Exp & 128) ? 3.4028236692093846346337460743177e+38 : 1.0) * + ((Exp & 256) ? 1.1579208923731619542357098500869e+77 : 1.0) * + ((Exp & 512) ? 1.3407807929942597099574024998206e+154 : 1.0)) + : ((((-Exp) & 1) ? 0.5 : 1.0) * + (((-Exp) & 2) ? 0.25 : 1.0) * + (((-Exp) & 4) ? 
0.0625 : 1.0) * + (((-Exp) & 8) ? 0.00390625 : 1.0) * + (((-Exp) & 16) ? 0.0000152587890625 : 1.0) * + (((-Exp) & 32) ? 2.3283064365386963e-10 : 1.0) * + (((-Exp) & 64) ? 5.421010862427522e-20 : 1.0) * + (((-Exp) & 128) ? 2.9387358770557187699218413430556e-39 : 1.0) * + (((-Exp) & 256) ? 8.6361685550944446253863518628004e-78 : 1.0) * + (((-Exp) & 512) ? 7.4583407312002067432909653154629e-155 : 1.0)))"> + <Parameter Name="Exp" Type="int32_t"/> + </Intrinsic> + + <!-- Reverses the byte order of a 32-bit value. --> <Intrinsic Name="ByteSwap32" Expression="(uint32_t)(((Value & 0x000000ff) << 24) | ((Value & 0x0000ff00) << 8) | ((Value & 0x00ff0000) >> 8) | ((Value & 0xff000000) >> 24))"> + <Parameter Name="Value" Type="uint32_t"/> + </Intrinsic> + <!-- Reverses the byte order of a 64-bit value: each 32-bit half is swapped with ByteSwap32 and the halves trade places. --> <Intrinsic Name="ByteSwap64" Expression="(uint64_t)(((uint64_t)ByteSwap32((uint32_t)Value) << 32) | (uint64_t)ByteSwap32((uint32_t)(Value >> 32)))"> + <Parameter Name="Value" Type="uint64_t"/> + </Intrinsic> + + <!-- Float32 bit fields: exponent is bits 23-30, fraction is bits 0-22, and the sign bit 31 maps to -1.0 or 1.0. --> <Intrinsic Name="Float32ExpBits" Expression="(uint32_t)((Value >> 23) & 0xff)"> + <Parameter Name="Value" Type="uint32_t"/> + </Intrinsic> + <Intrinsic Name="Float32FracBits" Expression="(uint32_t)(Value & 0x007fffff)"> + <Parameter Name="Value" Type="uint32_t"/> + </Intrinsic> + <Intrinsic Name="Float32Sign" Expression="(double)((Value & 0x80000000) ? -1.0 : 1.0)"> + <Parameter Name="Value" Type="uint32_t"/> + </Intrinsic> + + <!-- Rebuilds a float from raw bits with arithmetic only. Order of cases: zero, exponent 0xff (NaN if the fraction is non-zero, else signed infinity), exponent 0 (no implicit leading 1, scaled by Pow2i(-126)), otherwise (1 + fraction) scaled by Pow2i(exponent - 127). The constant 1.1920928955078125e-7 equals 2^-23. --> <Intrinsic Name="Float32FromBits" Expression="(float)( + ((Value & 0x7fffffff) == 0) ? ((Value & 0x80000000) ? -0.0 : 0.0) : + (Float32ExpBits(Value) == 0xff) + ? ((Float32FracBits(Value) != 0) ? (0.0 / 0.0) : ((Value & 0x80000000U) ? (-1.0 / 0.0) : (1.0 / 0.0))) + : (Float32Sign(Value) * + ((Float32ExpBits(Value) == 0) + ? 
((float)Float32FracBits(Value) * 1.1920928955078125e-7 * Pow2i(-126)) + : ((1.0 + (float)Float32FracBits(Value) * 1.1920928955078125e-7) * Pow2i((int)Float32ExpBits(Value) - 127)) + )))"> + <Parameter Name="Value" Type="uint32_t"/> + </Intrinsic> + + <!-- Float64 bit fields: exponent is bits 52-62, fraction is bits 0-51, and the sign bit 63 maps to -1.0 or 1.0. --> <Intrinsic Name="Float64ExpBits" Expression="(uint32_t)((Value >> 52) & 0x7ff)"> + <Parameter Name="Value" Type="uint64_t"/> + </Intrinsic> + <Intrinsic Name="Float64FracBits" Expression="(uint64_t)(Value & 0x000fffffffffffff)"> + <Parameter Name="Value" Type="uint64_t"/> + </Intrinsic> + <Intrinsic Name="Float64Sign" Expression="(double)((Value & 0x8000000000000000) ? -1.0 : 1.0)"> + <Parameter Name="Value" Type="uint64_t"/> + </Intrinsic> + + <!-- Same reconstruction as Float32FromBits for 64-bit values: exponent bias 1023, all-ones exponent 0x7ff, and 2.220446049250313e-16 equals 2^-52. --> <Intrinsic Name="Float64FromBits" Expression="(double)( + ((Value & 0x7fffffffffffffff) == 0) ? ((Value & 0x8000000000000000) ? -0.0 : 0.0) : + (Float64ExpBits(Value) == 0x7ff) + ? ((Float64FracBits(Value) != 0) ? (0.0 / 0.0) : ((Value & 0x8000000000000000) ? (-1.0 / 0.0) : (1.0 / 0.0))) + : (Float64Sign(Value) * + ((Float64ExpBits(Value) == 0) + ? ((double)Float64FracBits(Value) * 2.220446049250313e-16 * Pow2i(-1022)) + : ((1.0 + (double)Float64FracBits(Value) * 2.220446049250313e-16) * Pow2i((int)Float64ExpBits(Value) - 1023)) + )))"> + <Parameter Name="Value" Type="uint64_t"/> + </Intrinsic> + + <!-- Measure size in bytes of a compact binary field value. See CbFieldView::GetPayloadSize() in compactbinary.cpp. --> + <Intrinsic Name="MeasureCompactBinaryValue" Expression="(uint64_t)( + (Type == 0x00 || Type == 0x01) ? 0 : + (Type >= 0x02 && Type <= 0x07) ? (DebugMeasureVarUInt(Value) + DebugReadVarUInt(Value)) : + (Type == 0x08 || Type == 0x09) ? DebugMeasureVarUInt(Value) : + (Type == 0x0a) ? 4 : + (Type == 0x0b) ? 8 : + (Type == 0x0c || Type == 0x0d) ? 0 : + (Type == 0x0e || Type == 0x0f || Type == 0x10) ? 20 : + (Type == 0x11) ? 16 : + (Type == 0x12 || Type == 0x13) ? 8 : + (Type == 0x14) ? 12 : + (Type == 0x1e || Type == 0x1f) ? 
(DebugMeasureVarUInt(Value) + DebugReadVarUInt(Value)) : 0)"> + <Parameter Name="Value" Type="const uint8_t*"/> + <Parameter Name="Type" Type="uint8_t"/> + </Intrinsic> + + <!-- Visualizer for zen::CbFieldView: the low five bits of Type select which Debug* overlay is applied to Payload, and the name is shown when flag 0x80 is set. --> <Type Name="zen::CbFieldView"> + <Intrinsic Name="TypeValue" Expression="(uint8_t)((uint8_t)Type & 0x1f)"/> + <Expand> + <Item Name="[Value]" Condition="TypeValue() == 0x01">nullptr</Item> + <ExpandedItem Condition="TypeValue() == 0x02">*(zen::DebugCbObject*)Payload</ExpandedItem> + <ExpandedItem Condition="TypeValue() == 0x03">*(zen::DebugCbUniformObject*)Payload</ExpandedItem> + <ExpandedItem Condition="TypeValue() == 0x04">*(zen::DebugCbArray*)Payload</ExpandedItem> + <ExpandedItem Condition="TypeValue() == 0x05">*(zen::DebugCbUniformArray*)Payload</ExpandedItem> + <Item Name="[Value]" Condition="TypeValue() == 0x06">*(zen::DebugCbBinary*)Payload</Item> + <Item Name="[Value]" Condition="TypeValue() == 0x07">*(zen::DebugCbString*)Payload</Item> + <Item Name="[Value]" Condition="TypeValue() == 0x08">*(zen::DebugCbIntegerPositive*)Payload</Item> + <Item Name="[Value]" Condition="TypeValue() == 0x09">*(zen::DebugCbIntegerNegative*)Payload</Item> + <Item Name="[Value]" Condition="TypeValue() == 0x0a">*(zen::DebugCbFloat32*)Payload</Item> + <Item Name="[Value]" Condition="TypeValue() == 0x0b">*(zen::DebugCbFloat64*)Payload</Item> + <Item Name="[Value]" Condition="TypeValue() == 0x0c">false</Item> + <Item Name="[Value]" Condition="TypeValue() == 0x0d">true</Item> + <Item Name="[Value]" Condition="TypeValue() == 0x0e">*(zen::DebugCbObjectAttachment*)Payload</Item> + <Item Name="[Value]" Condition="TypeValue() == 0x0f">*(zen::DebugCbBinaryAttachment*)Payload</Item> + <Item Name="[Value]" Condition="TypeValue() == 0x10">*(zen::DebugCbHash*)Payload</Item> + <Item Name="[Value]" Condition="TypeValue() == 0x11">*(zen::DebugCbUuid*)Payload</Item> + <Item Name="[Value]" Condition="TypeValue() == 0x12">*(zen::DebugCbDateTime*)Payload</Item> + <Item Name="[Value]" Condition="TypeValue() == 
0x13">*(zen::DebugCbTimeSpan*)Payload</Item> + <Item Name="[Value]" Condition="TypeValue() == 0x14">*(zen::DebugCbObjectId*)Payload</Item> + <ExpandedItem Condition="TypeValue() == 0x1e">*(zen::DebugCbCustomById*)Payload</ExpandedItem> + <ExpandedItem Condition="TypeValue() == 0x1f">*(zen::DebugCbCustomByName*)Payload</ExpandedItem> + + <Item Name="[Name]" Condition="((uint8_t)Type & 0x80) != 0">*(zen::DebugCbString*)((const uint8_t*)Payload - NameLen - DebugMeasureVarUInt((uint64_t)NameLen))</Item> + <Item Name="[Type]">(zen::CbFieldType)TypeValue()</Item> + <Item Name="[Size]" Condition="TypeValue() != 0">1 + (((uint8_t)Type & 0x80) ? NameLen + DebugMeasureVarUInt((uint64_t)NameLen) : 0) + MeasureCompactBinaryValue((const uint8_t*)Payload, TypeValue())</Item> + <Item Name="[Error]" Condition="Error != zen::CbFieldError::None">Error</Item> + </Expand> + </Type> + + <!-- Visualizer for a serialized field viewed in place: byte 0 is the type with flags, then an optional name (a length VarUInt plus characters, present when flag 0x80 is set), then the value. --> <Type Name="zen::DebugCb"> + <Intrinsic Name="TypeWithFlags" Expression="*(const uint8_t*)this"/> + <Intrinsic Name="TypeValue" Expression="(uint8_t)(TypeWithFlags() & 0x1f)"/> + <Intrinsic Name="HasName" Expression="(TypeWithFlags() & 0x80) != 0"/> + <Intrinsic Name="NameLenByteCount" Expression="HasName() ? DebugMeasureVarUInt((const uint8_t*)this + 1) : 0"/> + <Intrinsic Name="NameLen" Expression="HasName() ? 
DebugReadVarUInt((const uint8_t*)this + 1) : 0"/> + <Intrinsic Name="Value" Expression="(const uint8_t*)this + 1 + NameLenByteCount() + NameLen()"/> + <Expand HideRawView="true"> + <Item Name="[Value]" Condition="TypeValue() == 0x01">nullptr</Item> + <ExpandedItem Condition="TypeValue() == 0x02">*(zen::DebugCbObject*)Value()</ExpandedItem> + <ExpandedItem Condition="TypeValue() == 0x03">*(zen::DebugCbUniformObject*)Value()</ExpandedItem> + <ExpandedItem Condition="TypeValue() == 0x04">*(zen::DebugCbArray*)Value()</ExpandedItem> + <ExpandedItem Condition="TypeValue() == 0x05">*(zen::DebugCbUniformArray*)Value()</ExpandedItem> + <Item Name="[Value]" Condition="TypeValue() == 0x06">*(zen::DebugCbBinary*)Value()</Item> + <Item Name="[Value]" Condition="TypeValue() == 0x07">*(zen::DebugCbString*)Value()</Item> + <Item Name="[Value]" Condition="TypeValue() == 0x08">*(zen::DebugCbIntegerPositive*)Value()</Item> + <Item Name="[Value]" Condition="TypeValue() == 0x09">*(zen::DebugCbIntegerNegative*)Value()</Item> + <Item Name="[Value]" Condition="TypeValue() == 0x0a">*(zen::DebugCbFloat32*)Value()</Item> + <Item Name="[Value]" Condition="TypeValue() == 0x0b">*(zen::DebugCbFloat64*)Value()</Item> + <Item Name="[Value]" Condition="TypeValue() == 0x0c">false</Item> + <Item Name="[Value]" Condition="TypeValue() == 0x0d">true</Item> + <Item Name="[Value]" Condition="TypeValue() == 0x0e">*(zen::DebugCbObjectAttachment*)Value()</Item> + <Item Name="[Value]" Condition="TypeValue() == 0x0f">*(zen::DebugCbBinaryAttachment*)Value()</Item> + <Item Name="[Value]" Condition="TypeValue() == 0x10">*(zen::DebugCbHash*)Value()</Item> + <Item Name="[Value]" Condition="TypeValue() == 0x11">*(zen::DebugCbUuid*)Value()</Item> + <Item Name="[Value]" Condition="TypeValue() == 0x12">*(zen::DebugCbDateTime*)Value()</Item> + <Item Name="[Value]" Condition="TypeValue() == 0x13">*(zen::DebugCbTimeSpan*)Value()</Item> + <Item Name="[Value]" Condition="TypeValue() == 
0x14">*(zen::DebugCbObjectId*)Value()</Item> + <ExpandedItem Condition="TypeValue() == 0x1e">*(zen::DebugCbCustomById*)Value()</ExpandedItem> + <ExpandedItem Condition="TypeValue() == 0x1f">*(zen::DebugCbCustomByName*)Value()</ExpandedItem> + + <Item Name="[Name]" Condition="HasName()">*(zen::DebugCbString*)((const uint8_t*)this + 1)</Item> + <Item Name="[Type]">(zen::CbFieldType)TypeValue()</Item> + <Item Name="[Size]">1 + NameLenByteCount() + NameLen() + MeasureCompactBinaryValue(Value(), TypeValue())</Item> + </Expand> + </Type> + + <!-- Object payload: a VarUInt byte size followed by fields, each read as a type byte, a name (length VarUInt plus characters) and a value. The loop walks fields until the end of the payload. --> <Type Name="zen::DebugCbObject"> + <Intrinsic Name="ValueSizeByteCount" Expression="DebugMeasureVarUInt((const uint8_t*)this)"/> + <Intrinsic Name="ValueSize" Expression="DebugReadVarUInt((const uint8_t*)this)"/> + <Expand HideRawView="true"> + <CustomListItems MaxItemsPerView="100"> + <Variable Name="Field" InitialValue="(const uint8_t*)this + ValueSizeByteCount()"/> + <Variable Name="FieldsEnd" InitialValue="(const uint8_t*)this + ValueSizeByteCount() + ValueSize()"/> + <Variable Name="FieldType" InitialValue="(uint8_t)0"/> + <Variable Name="NameLenByteCount" InitialValue="(uint32_t)0"/> + <Variable Name="NameLen" InitialValue="(uint32_t)0"/> + <Variable Name="Name" InitialValue="(zen::DebugCbString*)nullptr"/> + <Variable Name="Value" InitialValue="(const uint8_t*)nullptr"/> + <Variable Name="ValueSize" InitialValue="(uint64_t)0"/> + <Loop Condition="Field < FieldsEnd"> + <!-- Decode --> + <Exec>FieldType = *Field & 0x1f</Exec> + <Exec>NameLenByteCount = DebugMeasureVarUInt(Field + 1)</Exec> + <Exec>NameLen = NameLenByteCount ? 
(uint32_t)DebugReadVarUInt(Field + 1) : 0</Exec> + <Exec>Name = (zen::DebugCbString*)(Field + 1)</Exec> + <Exec>Value = Field + 1 + NameLenByteCount + NameLen</Exec> + <Exec>ValueSize = MeasureCompactBinaryValue(Value, FieldType)</Exec> + <!-- Display --> + <Item Name="{*Name,s8b}" Condition="FieldType == 0x01">nullptr</Item> + <Item Name="{*Name,s8b}" Condition="FieldType == 0x02">*(zen::DebugCbObject*)Value</Item> + <Item Name="{*Name,s8b}" Condition="FieldType == 0x03">*(zen::DebugCbUniformObject*)Value</Item> + <Item Name="{*Name,s8b}" Condition="FieldType == 0x04">*(zen::DebugCbArray*)Value</Item> + <Item Name="{*Name,s8b}" Condition="FieldType == 0x05">*(zen::DebugCbUniformArray*)Value</Item> + <Item Name="{*Name,s8b}" Condition="FieldType == 0x06">*(zen::DebugCbBinary*)Value</Item> + <Item Name="{*Name,s8b}" Condition="FieldType == 0x07">*(zen::DebugCbString*)Value</Item> + <Item Name="{*Name,s8b}" Condition="FieldType == 0x08">*(zen::DebugCbIntegerPositive*)Value</Item> + <Item Name="{*Name,s8b}" Condition="FieldType == 0x09">*(zen::DebugCbIntegerNegative*)Value</Item> + <Item Name="{*Name,s8b}" Condition="FieldType == 0x0a">*(zen::DebugCbFloat32*)Value</Item> + <Item Name="{*Name,s8b}" Condition="FieldType == 0x0b">*(zen::DebugCbFloat64*)Value</Item> + <Item Name="{*Name,s8b}" Condition="FieldType == 0x0c">false</Item> + <Item Name="{*Name,s8b}" Condition="FieldType == 0x0d">true</Item> + <Item Name="{*Name,s8b}" Condition="FieldType == 0x0e">*(zen::DebugCbObjectAttachment*)Value</Item> + <Item Name="{*Name,s8b}" Condition="FieldType == 0x0f">*(zen::DebugCbBinaryAttachment*)Value</Item> + <Item Name="{*Name,s8b}" Condition="FieldType == 0x10">*(zen::DebugCbHash*)Value</Item> + <Item Name="{*Name,s8b}" Condition="FieldType == 0x11">*(zen::DebugCbUuid*)Value</Item> + <Item Name="{*Name,s8b}" Condition="FieldType == 0x12">*(zen::DebugCbDateTime*)Value</Item> + <Item Name="{*Name,s8b}" Condition="FieldType == 0x13">*(zen::DebugCbTimeSpan*)Value</Item> + <Item 
Name="{*Name,s8b}" Condition="FieldType == 0x14">*(zen::DebugCbObjectId*)Value</Item> + <Item Name="{*Name,s8b}" Condition="FieldType == 0x1e">*(zen::DebugCbCustomById*)Value</Item> + <Item Name="{*Name,s8b}" Condition="FieldType == 0x1f">*(zen::DebugCbCustomByName*)Value</Item> + <!-- Advance --> + <Exec>Field = Value + ValueSize</Exec> + </Loop> + </CustomListItems> + </Expand> + </Type> + + <!-- Uniform object payload: a VarUInt byte size, one shared type byte, then fields that are read as a name followed by a value. --> <Type Name="zen::DebugCbUniformObject"> + <Intrinsic Name="ValueSizeByteCount" Expression="DebugMeasureVarUInt((const uint8_t*)this)"/> + <Intrinsic Name="ValueSize" Expression="DebugReadVarUInt((const uint8_t*)this)"/> + <Intrinsic Name="FieldType" Expression="(uint8_t)*((const uint8_t*)this + ValueSizeByteCount()) & 0x1f"/> + <Expand HideRawView="true"> + <CustomListItems MaxItemsPerView="100"> + <Variable Name="Field" InitialValue="(const uint8_t*)this + ValueSizeByteCount() + 1"/> + <Variable Name="FieldsEnd" InitialValue="(const uint8_t*)this + ValueSizeByteCount() + ValueSize()"/> + <Variable Name="NameLenByteCount" InitialValue="(uint32_t)0"/> + <Variable Name="NameLen" InitialValue="(uint32_t)0"/> + <Variable Name="Name" InitialValue="(zen::DebugCbString*)nullptr"/> + <Variable Name="Value" InitialValue="(const uint8_t*)nullptr"/> + <Variable Name="ValueSize" InitialValue="(uint64_t)0"/> + <Loop Condition="Field < FieldsEnd"> + <!-- Decode --> + <Exec>NameLenByteCount = DebugMeasureVarUInt(Field)</Exec> + <Exec>NameLen = NameLenByteCount ? 
(uint32_t)DebugReadVarUInt(Field) : 0</Exec> + <Exec>Name = (zen::DebugCbString*)Field</Exec> + <Exec>Value = Field + NameLenByteCount + NameLen</Exec> + <Exec>ValueSize = MeasureCompactBinaryValue(Value, FieldType())</Exec> + <!-- Display --> + <Item Name="{*Name,s8b}" Condition="FieldType() == 0x01">nullptr</Item> + <Item Name="{*Name,s8b}" Condition="FieldType() == 0x02">*(zen::DebugCbObject*)Value</Item> + <Item Name="{*Name,s8b}" Condition="FieldType() == 0x03">*(zen::DebugCbUniformObject*)Value</Item> + <Item Name="{*Name,s8b}" Condition="FieldType() == 0x04">*(zen::DebugCbArray*)Value</Item> + <Item Name="{*Name,s8b}" Condition="FieldType() == 0x05">*(zen::DebugCbUniformArray*)Value</Item> + <Item Name="{*Name,s8b}" Condition="FieldType() == 0x06">*(zen::DebugCbBinary*)Value</Item> + <Item Name="{*Name,s8b}" Condition="FieldType() == 0x07">*(zen::DebugCbString*)Value</Item> + <Item Name="{*Name,s8b}" Condition="FieldType() == 0x08">*(zen::DebugCbIntegerPositive*)Value</Item> + <Item Name="{*Name,s8b}" Condition="FieldType() == 0x09">*(zen::DebugCbIntegerNegative*)Value</Item> + <Item Name="{*Name,s8b}" Condition="FieldType() == 0x0a">*(zen::DebugCbFloat32*)Value</Item> + <Item Name="{*Name,s8b}" Condition="FieldType() == 0x0b">*(zen::DebugCbFloat64*)Value</Item> + <Item Name="{*Name,s8b}" Condition="FieldType() == 0x0c">false</Item> + <Item Name="{*Name,s8b}" Condition="FieldType() == 0x0d">true</Item> + <Item Name="{*Name,s8b}" Condition="FieldType() == 0x0e">*(zen::DebugCbObjectAttachment*)Value</Item> + <Item Name="{*Name,s8b}" Condition="FieldType() == 0x0f">*(zen::DebugCbBinaryAttachment*)Value</Item> + <Item Name="{*Name,s8b}" Condition="FieldType() == 0x10">*(zen::DebugCbHash*)Value</Item> + <Item Name="{*Name,s8b}" Condition="FieldType() == 0x11">*(zen::DebugCbUuid*)Value</Item> + <Item Name="{*Name,s8b}" Condition="FieldType() == 0x12">*(zen::DebugCbDateTime*)Value</Item> + <Item Name="{*Name,s8b}" Condition="FieldType() == 
0x13">*(zen::DebugCbTimeSpan*)Value</Item> + <Item Name="{*Name,s8b}" Condition="FieldType() == 0x14">*(zen::DebugCbObjectId*)Value</Item> + <Item Name="{*Name,s8b}" Condition="FieldType() == 0x1e">*(zen::DebugCbCustomById*)Value</Item> + <Item Name="{*Name,s8b}" Condition="FieldType() == 0x1f">*(zen::DebugCbCustomByName*)Value</Item> + <!-- Advance --> + <Exec>Field = Value + ValueSize</Exec> + </Loop> + </CustomListItems> + </Expand> + </Type> + + <!-- Array payload: a VarUInt byte size, a VarUInt item count, then items read as a type byte plus a value. The byte size is measured from the end of the size VarUInt, so it spans the count and the items. --> <Type Name="zen::DebugCbArray"> + <Intrinsic Name="ValueSizeByteCount" Expression="DebugMeasureVarUInt((const uint8_t*)this)"/> + <Intrinsic Name="ValueSize" Expression="DebugReadVarUInt((const uint8_t*)this)"/> + <Intrinsic Name="FieldCountByteCount" Expression="DebugMeasureVarUInt((const uint8_t*)this + ValueSizeByteCount())"/> + <Intrinsic Name="FieldCount" Expression="DebugReadVarUInt((const uint8_t*)this + ValueSizeByteCount())"/> + <Expand HideRawView="true"> + <Item Name="[Count]">FieldCount()</Item> + <CustomListItems MaxItemsPerView="100"> + <Variable Name="Field" InitialValue="(const uint8_t*)this + ValueSizeByteCount() + FieldCountByteCount()"/> + <Variable Name="FieldsEnd" InitialValue="(const uint8_t*)this + ValueSizeByteCount() + ValueSize()"/> + <Variable Name="FieldType" InitialValue="(uint8_t)0"/> + <Variable Name="Value" InitialValue="(const uint8_t*)nullptr"/> + <Variable Name="ValueSize" InitialValue="(uint64_t)0"/> + <Variable Name="Index" InitialValue="0"/> + <Loop Condition="Field < FieldsEnd"> + <!-- Decode --> + <Exec>FieldType = *Field & 0x1f</Exec> + <Exec>Value = Field + 1</Exec> + <Exec>ValueSize = MeasureCompactBinaryValue(Value, FieldType)</Exec> + <!-- Display --> + <Item Name="[{Index}]" Condition="FieldType == 0x01">nullptr</Item> + <Item Name="[{Index}]" Condition="FieldType == 0x02">*(zen::DebugCbObject*)Value</Item> + <Item Name="[{Index}]" Condition="FieldType == 0x03">*(zen::DebugCbUniformObject*)Value</Item> + <Item Name="[{Index}]" Condition="FieldType == 
0x04">*(zen::DebugCbArray*)Value</Item> + <Item Name="[{Index}]" Condition="FieldType == 0x05">*(zen::DebugCbUniformArray*)Value</Item> + <Item Name="[{Index}]" Condition="FieldType == 0x06">*(zen::DebugCbBinary*)Value</Item> + <Item Name="[{Index}]" Condition="FieldType == 0x07">*(zen::DebugCbString*)Value</Item> + <Item Name="[{Index}]" Condition="FieldType == 0x08">*(zen::DebugCbIntegerPositive*)Value</Item> + <Item Name="[{Index}]" Condition="FieldType == 0x09">*(zen::DebugCbIntegerNegative*)Value</Item> + <Item Name="[{Index}]" Condition="FieldType == 0x0a">*(zen::DebugCbFloat32*)Value</Item> + <Item Name="[{Index}]" Condition="FieldType == 0x0b">*(zen::DebugCbFloat64*)Value</Item> + <Item Name="[{Index}]" Condition="FieldType == 0x0c">false</Item> + <Item Name="[{Index}]" Condition="FieldType == 0x0d">true</Item> + <Item Name="[{Index}]" Condition="FieldType == 0x0e">*(zen::DebugCbObjectAttachment*)Value</Item> + <Item Name="[{Index}]" Condition="FieldType == 0x0f">*(zen::DebugCbBinaryAttachment*)Value</Item> + <Item Name="[{Index}]" Condition="FieldType == 0x10">*(zen::DebugCbHash*)Value</Item> + <Item Name="[{Index}]" Condition="FieldType == 0x11">*(zen::DebugCbUuid*)Value</Item> + <Item Name="[{Index}]" Condition="FieldType == 0x12">*(zen::DebugCbDateTime*)Value</Item> + <Item Name="[{Index}]" Condition="FieldType == 0x13">*(zen::DebugCbTimeSpan*)Value</Item> + <Item Name="[{Index}]" Condition="FieldType == 0x14">*(zen::DebugCbObjectId*)Value</Item> + <Item Name="[{Index}]" Condition="FieldType == 0x1e">*(zen::DebugCbCustomById*)Value</Item> + <Item Name="[{Index}]" Condition="FieldType == 0x1f">*(zen::DebugCbCustomByName*)Value</Item> + <!-- Advance --> + <Exec>Field = Value + ValueSize</Exec> + <Exec>++Index</Exec> + </Loop> + </CustomListItems> + </Expand> + </Type> + + <!-- Uniform array payload: a VarUInt byte size, a VarUInt item count, one shared type byte, then the item values back to back. --> <Type Name="zen::DebugCbUniformArray"> + <Intrinsic Name="ValueSizeByteCount" Expression="DebugMeasureVarUInt((const uint8_t*)this)"/> + <Intrinsic Name="ValueSize" 
Expression="DebugReadVarUInt((const uint8_t*)this)"/> + <Intrinsic Name="FieldCountByteCount" Expression="DebugMeasureVarUInt((const uint8_t*)this + ValueSizeByteCount())"/> + <Intrinsic Name="FieldCount" Expression="DebugReadVarUInt((const uint8_t*)this + ValueSizeByteCount())"/> + <Intrinsic Name="FieldType" Expression="(uint8_t)*((const uint8_t*)this + ValueSizeByteCount() + FieldCountByteCount()) & 0x1f"/> + <Expand HideRawView="true"> + <Item Name="[Count]">FieldCount()</Item> + <CustomListItems MaxItemsPerView="100"> + <Variable Name="Field" InitialValue="(const uint8_t*)this + ValueSizeByteCount() + FieldCountByteCount() + 1"/> + <Variable Name="FieldsEnd" InitialValue="(const uint8_t*)this + ValueSizeByteCount() + ValueSize()"/> + <Variable Name="Value" InitialValue="(const uint8_t*)nullptr"/> + <Variable Name="ValueSize" InitialValue="(uint64_t)0"/> + <Variable Name="Index" InitialValue="0"/> + <Loop Condition="Field < FieldsEnd"> + <!-- Decode --> + <Exec>Value = Field</Exec> + <Exec>ValueSize = MeasureCompactBinaryValue(Value, FieldType())</Exec> + <!-- Display --> + <Item Name="[{Index}]" Condition="FieldType() == 0x01">nullptr</Item> + <Item Name="[{Index}]" Condition="FieldType() == 0x02">*(zen::DebugCbObject*)Value</Item> + <Item Name="[{Index}]" Condition="FieldType() == 0x03">*(zen::DebugCbUniformObject*)Value</Item> + <Item Name="[{Index}]" Condition="FieldType() == 0x04">*(zen::DebugCbArray*)Value</Item> + <Item Name="[{Index}]" Condition="FieldType() == 0x05">*(zen::DebugCbUniformArray*)Value</Item> + <Item Name="[{Index}]" Condition="FieldType() == 0x06">*(zen::DebugCbBinary*)Value</Item> + <Item Name="[{Index}]" Condition="FieldType() == 0x07">*(zen::DebugCbString*)Value</Item> + <Item Name="[{Index}]" Condition="FieldType() == 0x08">*(zen::DebugCbIntegerPositive*)Value</Item> + <Item Name="[{Index}]" Condition="FieldType() == 0x09">*(zen::DebugCbIntegerNegative*)Value</Item> + <Item Name="[{Index}]" Condition="FieldType() == 
0x0a">*(zen::DebugCbFloat32*)Value</Item> + <Item Name="[{Index}]" Condition="FieldType() == 0x0b">*(zen::DebugCbFloat64*)Value</Item> + <Item Name="[{Index}]" Condition="FieldType() == 0x0c">false</Item> + <Item Name="[{Index}]" Condition="FieldType() == 0x0d">true</Item> + <Item Name="[{Index}]" Condition="FieldType() == 0x0e">*(zen::DebugCbObjectAttachment*)Value</Item> + <Item Name="[{Index}]" Condition="FieldType() == 0x0f">*(zen::DebugCbBinaryAttachment*)Value</Item> + <Item Name="[{Index}]" Condition="FieldType() == 0x10">*(zen::DebugCbHash*)Value</Item> + <Item Name="[{Index}]" Condition="FieldType() == 0x11">*(zen::DebugCbUuid*)Value</Item> + <Item Name="[{Index}]" Condition="FieldType() == 0x12">*(zen::DebugCbDateTime*)Value</Item> + <Item Name="[{Index}]" Condition="FieldType() == 0x13">*(zen::DebugCbTimeSpan*)Value</Item> + <Item Name="[{Index}]" Condition="FieldType() == 0x14">*(zen::DebugCbObjectId*)Value</Item> + <Item Name="[{Index}]" Condition="FieldType() == 0x1e">*(zen::DebugCbCustomById*)Value</Item> + <Item Name="[{Index}]" Condition="FieldType() == 0x1f">*(zen::DebugCbCustomByName*)Value</Item> + <!-- Advance --> + <Exec>Field = Value + ValueSize</Exec> + <Exec>++Index</Exec> + </Loop> + </CustomListItems> + </Expand> + </Type> + + <!-- Binary payload: a VarUInt byte count followed by that many raw bytes. --> <Type Name="zen::DebugCbBinary"> + <Intrinsic Name="ValueSizeByteCount" Expression="DebugMeasureVarUInt((const uint8_t*)this)"/> + <Intrinsic Name="ValueSize" Expression="DebugReadVarUInt((const uint8_t*)this)"/> + <Expand HideRawView="true"> + <Item Name="[Size]">ValueSize()</Item> + <IndexListItems> + <Size>ValueSize()</Size> + <ValueNode>((const uint8_t*)this + ValueSizeByteCount())[$i]</ValueNode> + </IndexListItems> + </Expand> + </Type> + + <!-- String payload: a VarUInt length followed by that many characters, shown with the s8 format specifier. --> <Type Name="zen::DebugCbString"> + <Intrinsic Name="ValueLenByteCount" Expression="DebugMeasureVarUInt((const uint8_t*)this)"/> + <Intrinsic Name="ValueLen" Expression="DebugReadVarUInt((const uint8_t*)this)"/> + <DisplayString>{((const char8_t*)this + 
ValueLenByteCount()),[ValueLen()]s8}</DisplayString> + <StringView>((const char8_t*)this + ValueLenByteCount()),[ValueLen()]s8</StringView> + <Expand HideRawView="true"> + <Item Name="[Size]">ValueLen()</Item> + <IndexListItems> + <Size>ValueLen()</Size> + <ValueNode>((const char8_t*)this + ValueLenByteCount())[$i]</ValueNode> + </IndexListItems> + </Expand> + </Type> + + <!-- Non-negative integer: the payload is read as a VarUInt. --> <Type Name="zen::DebugCbIntegerPositive"> + <DisplayString>{DebugReadVarUInt((const uint8_t*)this)}</DisplayString> + </Type> + + <!-- Negative integer: shown as -(stored VarUInt) - 1. --> <Type Name="zen::DebugCbIntegerNegative"> + <DisplayString>{-(int64_t)DebugReadVarUInt((const uint8_t*)this) - 1}</DisplayString> + </Type> + + <!-- Float32 and Float64 payloads are byte-swapped before their bits are decoded. --> <Type Name="zen::DebugCbFloat32"> + <DisplayString>{Float32FromBits(ByteSwap32(*(uint32_t*)this))}</DisplayString> + </Type> + + <Type Name="zen::DebugCbFloat64"> + <DisplayString>{Float64FromBits(ByteSwap64(*(uint64_t*)this))}</DisplayString> + </Type> + + <!-- Attachments and hashes are shown by reinterpreting the payload as zen::IoHash. --> <Type Name="zen::DebugCbObjectAttachment"> + <DisplayString>{*(const zen::IoHash*)this}</DisplayString> + </Type> + + <Type Name="zen::DebugCbBinaryAttachment"> + <DisplayString>{*(const zen::IoHash*)this}</DisplayString> + </Type> + + <Type Name="zen::DebugCbHash"> + <DisplayString>{*(const zen::IoHash*)this}</DisplayString> + </Type> + + <Type Name="zen::DebugCbUuid"> + <DisplayString>{*(const zen::Guid*)this}</DisplayString> + </Type> + + <!-- Date/time payload: a byte-swapped 64-bit tick count that the intrinsics below split into a day number and a time of day. --> <Type Name="zen::DebugCbDateTime"> + <!-- 100ns ticks since 0001-01-01 00:00:00.0000000 --> + <Intrinsic Name="RawTicks" Expression="(int64_t)ByteSwap64(*(uint64_t*)this)"/> + + <Intrinsic Name="TicksPerMillisecond" Expression="10000ll"/> + <Intrinsic Name="TicksPerSecond" Expression="10000000ll"/> + <Intrinsic Name="TicksPerMinute" Expression="600000000ll"/> + <Intrinsic Name="TicksPerHour" Expression="36000000000ll"/> + <Intrinsic Name="TicksPerDay" Expression="864000000000ll"/> + + <Intrinsic Name="DayNumber" Expression="RawTicks() / TicksPerDay()"/> + <Intrinsic Name="TimeOfDayTicks" Expression="RawTicks() % TicksPerDay()"/> + + 
<!-- Year decomposition of DayNumber(): 146097 days per 400 years, 36524 per 100 years, 1461 per 4 years, 365 per year. Y100 and Y1 are clamped to 3 so the last day of a 400-year or 4-year cycle stays in the final year of that cycle. --> <Intrinsic Name="Y400" Expression="DayNumber() / 146097"/> + <Intrinsic Name="R400" Expression="DayNumber() % 146097"/> + <Intrinsic Name="Y100Raw" Expression="R400() / 36524"/> + <Intrinsic Name="Y100" Expression="Y100Raw() > 3 ? 3 : Y100Raw()"/> + <Intrinsic Name="R100" Expression="R400() - (Y100() * 36524)"/> + <Intrinsic Name="Y4" Expression="R100() / 1461"/> + <Intrinsic Name="R4" Expression="R100() % 1461"/> + <Intrinsic Name="Y1Raw" Expression="R4() / 365"/> + <Intrinsic Name="Y1" Expression="Y1Raw() > 3 ? 3 : Y1Raw()"/> + + <Intrinsic Name="Year" Expression="(int)(Y400() * 400 + Y100() * 100 + Y4() * 4 + Y1() + 1)"/> + <Intrinsic Name="DayOfYear" Expression="(int)(R4() - Y1() * 365)"/> + <Intrinsic Name="IsLeapYear" Expression="((Year() % 4) == 0) && (((Year() % 100) != 0) || ((Year() % 400) == 0))"/> + + <!-- Month (1-12) from the zero-based DayOfYear(), using cumulative month lengths for leap and non-leap years. --> <Intrinsic Name="Month" Expression=" + IsLeapYear() + ? ( DayOfYear() < 31 ? 1 + : DayOfYear() < 60 ? 2 + : DayOfYear() < 91 ? 3 + : DayOfYear() < 121 ? 4 + : DayOfYear() < 152 ? 5 + : DayOfYear() < 182 ? 6 + : DayOfYear() < 213 ? 7 + : DayOfYear() < 244 ? 8 + : DayOfYear() < 274 ? 9 + : DayOfYear() < 305 ? 10 + : DayOfYear() < 335 ? 11 : 12) + : ( DayOfYear() < 31 ? 1 + : DayOfYear() < 59 ? 2 + : DayOfYear() < 90 ? 3 + : DayOfYear() < 120 ? 4 + : DayOfYear() < 151 ? 5 + : DayOfYear() < 181 ? 6 + : DayOfYear() < 212 ? 7 + : DayOfYear() < 243 ? 8 + : DayOfYear() < 273 ? 9 + : DayOfYear() < 304 ? 10 + : DayOfYear() < 334 ? 11 : 12)"/> + + <!-- Day of month (1-based): DayOfYear() minus the number of days before the current month, plus one. --> <Intrinsic Name="Day" Expression=" + IsLeapYear() + ? ( DayOfYear() < 31 ? DayOfYear() + 1 + : DayOfYear() < 60 ? DayOfYear() - 31 + 1 + : DayOfYear() < 91 ? DayOfYear() - 60 + 1 + : DayOfYear() < 121 ? DayOfYear() - 91 + 1 + : DayOfYear() < 152 ? DayOfYear() - 121 + 1 + : DayOfYear() < 182 ? DayOfYear() - 152 + 1 + : DayOfYear() < 213 ? DayOfYear() - 182 + 1 + : DayOfYear() < 244 ? DayOfYear() - 213 + 1 + : DayOfYear() < 274 ? DayOfYear() - 244 + 1 + : DayOfYear() < 305 ? DayOfYear() - 274 + 1 + : DayOfYear() < 335 ? 
DayOfYear() - 305 + 1 + : DayOfYear() - 335 + 1) + : ( DayOfYear() < 31 ? DayOfYear() + 1 + : DayOfYear() < 59 ? DayOfYear() - 31 + 1 + : DayOfYear() < 90 ? DayOfYear() - 59 + 1 + : DayOfYear() < 120 ? DayOfYear() - 90 + 1 + : DayOfYear() < 151 ? DayOfYear() - 120 + 1 + : DayOfYear() < 181 ? DayOfYear() - 151 + 1 + : DayOfYear() < 212 ? DayOfYear() - 181 + 1 + : DayOfYear() < 243 ? DayOfYear() - 212 + 1 + : DayOfYear() < 273 ? DayOfYear() - 243 + 1 + : DayOfYear() < 304 ? DayOfYear() - 273 + 1 + : DayOfYear() < 334 ? DayOfYear() - 304 + 1 + : DayOfYear() - 334 + 1)"/> + + <Intrinsic Name="Hour" Expression="(int32_t)(TimeOfDayTicks() / TicksPerHour())"/> + <Intrinsic Name="Minute" Expression="(int32_t)((TimeOfDayTicks() % TicksPerHour()) / TicksPerMinute())"/> + <Intrinsic Name="Second" Expression="(int32_t)((TimeOfDayTicks() % TicksPerMinute()) / TicksPerSecond())"/> + <Intrinsic Name="Millisecond" Expression="(int32_t)((TimeOfDayTicks() % TicksPerSecond()) / TicksPerMillisecond())"/> + <Intrinsic Name="SubMs100ns" Expression="(int32_t)(TimeOfDayTicks() % TicksPerMillisecond())"/> + + <DisplayString>{Year()}-{Month()}-{Day()} {Hour()}:{Minute()}:{Second()} {Millisecond()}ms</DisplayString> + <Expand HideRawView="true"> + <Item Name="[Ticks] (100ns)">RawTicks()</Item> + <Item Name="[Year]">Year()</Item> + <Item Name="[Month]">Month()</Item> + <Item Name="[Day]">Day()</Item> + <Item Name="[Hour]">Hour()</Item> + <Item Name="[Minute]">Minute()</Item> + <Item Name="[Second]">Second()</Item> + <Item Name="[Millisecond]">Millisecond()</Item> + <Item Name="[100ns Remainder]">SubMs100ns()</Item> + </Expand> + </Type> + + <!-- Time span payload: a byte-swapped signed 64-bit tick count. Components are computed from the absolute value, and a minus sign is prefixed to the display string when the count is negative. --> <Type Name="zen::DebugCbTimeSpan"> + <Intrinsic Name="RawTicks" Expression="(int64_t)ByteSwap64(*(uint64_t*)this)"/> + <Intrinsic Name="AbsTicks" Expression="RawTicks() < 0 ? 
-RawTicks() : RawTicks()"/> + + <Intrinsic Name="TicksPerMillisecond" Expression="10000ll"/> + <Intrinsic Name="TicksPerSecond" Expression="10000000ll"/> + <Intrinsic Name="TicksPerMinute" Expression="600000000ll"/> + <Intrinsic Name="TicksPerHour" Expression="36000000000ll"/> + <Intrinsic Name="TicksPerDay" Expression="864000000000ll"/> + + <Intrinsic Name="Days" Expression="(int64_t)(AbsTicks() / TicksPerDay())"/> + <Intrinsic Name="Hours" Expression="(int32_t)((AbsTicks() % TicksPerDay()) / TicksPerHour())"/> + <Intrinsic Name="Minutes" Expression="(int32_t)((AbsTicks() % TicksPerHour()) / TicksPerMinute())"/> + <Intrinsic Name="Seconds" Expression="(int32_t)((AbsTicks() % TicksPerMinute()) / TicksPerSecond())"/> + <Intrinsic Name="Milliseconds" Expression="(int32_t)((AbsTicks() % TicksPerSecond()) / TicksPerMillisecond())"/> + <Intrinsic Name="SubMs100ns" Expression="(int32_t)(AbsTicks() % TicksPerMillisecond())"/> + + <DisplayString Condition="RawTicks() < 0">-{Days()}d {Hours()}h{Minutes()}m{Seconds()}s {Milliseconds()}ms</DisplayString> + <DisplayString>{Days()}d {Hours()}h{Minutes()}m{Seconds()}s {Milliseconds()}ms</DisplayString> + <Expand HideRawView="true"> + <Item Name="[Ticks] (100ns)">RawTicks()</Item> + <Item Name="[Days]">Days()</Item> + <Item Name="[Hours]">Hours()</Item> + <Item Name="[Minutes]">Minutes()</Item> + <Item Name="[Seconds]">Seconds()</Item> + <Item Name="[Milliseconds]">Milliseconds()</Item> + <Item Name="[100ns Remainder]">SubMs100ns()</Item> + <Item Name="[Total Milliseconds]">(double)RawTicks() / TicksPerMillisecond()</Item> + <Item Name="[Total Seconds]">(double)RawTicks() / TicksPerSecond()</Item> + <Item Name="[Total Minutes]">(double)RawTicks() / TicksPerMinute()</Item> + <Item Name="[Total Hours]">(double)RawTicks() / TicksPerHour()</Item> + <Item Name="[Total Days]">(double)RawTicks() / TicksPerDay()</Item> + </Expand> + </Type> + + <!-- Object id: shown by reinterpreting the payload as zen::Oid. --> <Type Name="zen::DebugCbObjectId"> + <DisplayString>{*(const zen::Oid*)this}</DisplayString> + 
</Type> + + <!-- Custom-by-id payload: a VarUInt total size, a VarUInt type id, then the value bytes. The total size spans the type id and the value. --> <Type Name="zen::DebugCbCustomById"> + <Intrinsic Name="TotalSizeByteCount" Expression="DebugMeasureVarUInt((const uint8_t*)this)"/> + <Intrinsic Name="TotalSize" Expression="DebugReadVarUInt((const uint8_t*)this)"/> + <Intrinsic Name="TypeIdByteCount" Expression="DebugMeasureVarUInt((const uint8_t*)this + TotalSizeByteCount())"/> + <Intrinsic Name="TypeId" Expression="DebugReadVarUInt((const uint8_t*)this + TotalSizeByteCount())"/> + <Intrinsic Name="Value" Expression="(const uint8_t*)this + TotalSizeByteCount() + TypeIdByteCount()"/> + <Intrinsic Name="ValueSize" Expression="TotalSize() - TypeIdByteCount()"/> + <Expand HideRawView="true"> + <Item Name="[TypeId]">TypeId()</Item> + <Item Name="[Value]">Value(),[ValueSize()]</Item> + </Expand> + </Type> + + <!-- Custom-by-name payload: a VarUInt total size, a VarUInt name length, the name characters, then the value bytes. The total size spans the name length, the name and the value. --> <Type Name="zen::DebugCbCustomByName"> + <Intrinsic Name="TotalSizeByteCount" Expression="DebugMeasureVarUInt((const uint8_t*)this)"/> + <Intrinsic Name="TotalSize" Expression="DebugReadVarUInt((const uint8_t*)this)"/> + <Intrinsic Name="TypeNameLenByteCount" Expression="DebugMeasureVarUInt((const uint8_t*)this + TotalSizeByteCount())"/> + <Intrinsic Name="TypeNameLen" Expression="DebugReadVarUInt((const uint8_t*)this + TotalSizeByteCount())"/> + <Intrinsic Name="TypeName" Expression="(const uint8_t*)this + TotalSizeByteCount() + TypeNameLenByteCount()"/> + <Intrinsic Name="Value" Expression="TypeName() + TypeNameLen()"/> + <Intrinsic Name="ValueSize" Expression="TotalSize() - TypeNameLenByteCount() - TypeNameLen()"/> + <Expand HideRawView="true"> + <Item Name="[TypeName]">TypeName(),[TypeNameLen()]s8</Item> + <Item Name="[Value]">Value(),[ValueSize()]</Item> + </Expand> + </Type> + + <!-- Field iterator: expands the current field, then lists the remaining fields up to FieldsEnd. Flag 0x40 in the type means each field starts with its own type byte; flag 0x80 means a name (length VarUInt plus characters) precedes the value. --> <Type Name="zen::TCbFieldIterator<*>"> + <Intrinsic Name="NameFromField" Expression="Field + (((uint8_t)Type & 0x40) ? 1 : 0)"> + <Parameter Name="Field" Type="const uint8_t*"/> + <Parameter Name="Type" Type="zen::CbFieldType"/> + </Intrinsic> + <Intrinsic Name="ValueFromName" Expression="Name + (((uint8_t)Type & 0x80) ? 
DebugMeasureVarUInt(Name) + DebugReadVarUInt(Name) : 0)"> + <Parameter Name="Name" Type="const uint8_t*"/> + <Parameter Name="Type" Type="zen::CbFieldType"/> + </Intrinsic> + <Intrinsic Name="ValueFromField" Expression="ValueFromName(NameFromField(Field, Type), Type)"> + <Parameter Name="Field" Type="const uint8_t*"/> + <Parameter Name="Type" Type="zen::CbFieldType"/> + </Intrinsic> + <Intrinsic Name="NextFieldFromValue" Expression="Value + MeasureCompactBinaryValue(Value, (uint8_t)Type & 0x1f)"> + <Parameter Name="Value" Type="const uint8_t*"/> + <Parameter Name="Type" Type="zen::CbFieldType"/> + </Intrinsic> + <DisplayString Condition="!FieldsEnd">Empty</DisplayString> + <Expand> + <ExpandedItem Condition="!!FieldsEnd">*($T1*)this</ExpandedItem> + <Synthetic Name="[Remaining]" Condition="!!FieldsEnd"> + <DisplayString Condition="NextFieldFromValue((const uint8_t*)Payload, Type) >= FieldsEnd">Empty</DisplayString> + <Expand> + <CustomListItems> + <Variable Name="FieldType" InitialValue="Type"/> + <Variable Name="Field" InitialValue="NextFieldFromValue((const uint8_t*)Payload, Type)"/> + <Variable Name="Value" InitialValue="(const uint8_t*)nullptr"/> + <Variable Name="Index" InitialValue="(uint32_t)0"/> + <Variable Name="NameLen" InitialValue="(uint32_t)0"/> + <Variable Name="Name" InitialValue="(const uint8_t*)nullptr"/> + <Loop Condition="Field < FieldsEnd"> + <!-- Decode --> + <Exec>FieldType = ((uint8_t)Type & 0x40) ? (zen::CbFieldType)*Field : Type</Exec> + <Exec>Name = NameFromField(Field, FieldType)</Exec> + <Exec>NameLen = ((uint8_t)FieldType & 0x80) ? 
(uint32_t)DebugReadVarUInt(Name) : 0</Exec> + <Exec>Value = ValueFromField(Field, FieldType)</Exec> + <!-- Display Name and Index --> + <Item Name="[{Index}] {*(zen::DebugCbString*)Name}" Condition="((uint8_t)FieldType & 0x9f) == 0x81">nullptr</Item> + <Item Name="[{Index}] {*(zen::DebugCbString*)Name}" Condition="((uint8_t)FieldType & 0x9f) == 0x82">*(zen::DebugCbObject*)Value</Item> + <Item Name="[{Index}] {*(zen::DebugCbString*)Name}" Condition="((uint8_t)FieldType & 0x9f) == 0x83">*(zen::DebugCbUniformObject*)Value</Item> + <Item Name="[{Index}] {*(zen::DebugCbString*)Name}" Condition="((uint8_t)FieldType & 0x9f) == 0x84">*(zen::DebugCbArray*)Value</Item> + <Item Name="[{Index}] {*(zen::DebugCbString*)Name}" Condition="((uint8_t)FieldType & 0x9f) == 0x85">*(zen::DebugCbUniformArray*)Value</Item> + <Item Name="[{Index}] {*(zen::DebugCbString*)Name}" Condition="((uint8_t)FieldType & 0x9f) == 0x86">*(zen::DebugCbBinary*)Value</Item> + <Item Name="[{Index}] {*(zen::DebugCbString*)Name}" Condition="((uint8_t)FieldType & 0x9f) == 0x87">*(zen::DebugCbString*)Value</Item> + <Item Name="[{Index}] {*(zen::DebugCbString*)Name}" Condition="((uint8_t)FieldType & 0x9f) == 0x88">*(zen::DebugCbIntegerPositive*)Value</Item> + <Item Name="[{Index}] {*(zen::DebugCbString*)Name}" Condition="((uint8_t)FieldType & 0x9f) == 0x89">*(zen::DebugCbIntegerNegative*)Value</Item> + <Item Name="[{Index}] {*(zen::DebugCbString*)Name}" Condition="((uint8_t)FieldType & 0x9f) == 0x8a">*(zen::DebugCbFloat32*)Value</Item> + <Item Name="[{Index}] {*(zen::DebugCbString*)Name}" Condition="((uint8_t)FieldType & 0x9f) == 0x8b">*(zen::DebugCbFloat64*)Value</Item> + <Item Name="[{Index}] {*(zen::DebugCbString*)Name}" Condition="((uint8_t)FieldType & 0x9f) == 0x8c">false</Item> + <Item Name="[{Index}] {*(zen::DebugCbString*)Name}" Condition="((uint8_t)FieldType & 0x9f) == 0x8d">true</Item> + <Item Name="[{Index}] {*(zen::DebugCbString*)Name}" Condition="((uint8_t)FieldType & 0x9f) == 
0x8e">*(zen::DebugCbObjectAttachment*)Value</Item> + <Item Name="[{Index}] {*(zen::DebugCbString*)Name}" Condition="((uint8_t)FieldType & 0x9f) == 0x8f">*(zen::DebugCbBinaryAttachment*)Value</Item> + <Item Name="[{Index}] {*(zen::DebugCbString*)Name}" Condition="((uint8_t)FieldType & 0x9f) == 0x90">*(zen::DebugCbHash*)Value</Item> + <Item Name="[{Index}] {*(zen::DebugCbString*)Name}" Condition="((uint8_t)FieldType & 0x9f) == 0x91">*(zen::DebugCbUuid*)Value</Item> + <Item Name="[{Index}] {*(zen::DebugCbString*)Name}" Condition="((uint8_t)FieldType & 0x9f) == 0x92">*(zen::DebugCbDateTime*)Value</Item> + <Item Name="[{Index}] {*(zen::DebugCbString*)Name}" Condition="((uint8_t)FieldType & 0x9f) == 0x93">*(zen::DebugCbTimeSpan*)Value</Item> + <Item Name="[{Index}] {*(zen::DebugCbString*)Name}" Condition="((uint8_t)FieldType & 0x9f) == 0x94">*(zen::DebugCbObjectId*)Value</Item> + <Item Name="[{Index}] {*(zen::DebugCbString*)Name}" Condition="((uint8_t)FieldType & 0x9f) == 0x9e">*(zen::DebugCbCustomById*)Value</Item> + <Item Name="[{Index}] {*(zen::DebugCbString*)Name}" Condition="((uint8_t)FieldType & 0x9f) == 0x9f">*(zen::DebugCbCustomByName*)Value</Item> + <!-- Display Index Only --> + <Item Name="[{Index}]" Condition="((uint8_t)FieldType & 0x9f) == 0x01">nullptr</Item> + <Item Name="[{Index}]" Condition="((uint8_t)FieldType & 0x9f) == 0x02">*(zen::DebugCbObject*)Value</Item> + <Item Name="[{Index}]" Condition="((uint8_t)FieldType & 0x9f) == 0x03">*(zen::DebugCbUniformObject*)Value</Item> + <Item Name="[{Index}]" Condition="((uint8_t)FieldType & 0x9f) == 0x04">*(zen::DebugCbArray*)Value</Item> + <Item Name="[{Index}]" Condition="((uint8_t)FieldType & 0x9f) == 0x05">*(zen::DebugCbUniformArray*)Value</Item> + <Item Name="[{Index}]" Condition="((uint8_t)FieldType & 0x9f) == 0x06">*(zen::DebugCbBinary*)Value</Item> + <Item Name="[{Index}]" Condition="((uint8_t)FieldType & 0x9f) == 0x07">*(zen::DebugCbString*)Value</Item> + <Item Name="[{Index}]" 
Condition="((uint8_t)FieldType & 0x9f) == 0x08">*(zen::DebugCbIntegerPositive*)Value</Item> + <Item Name="[{Index}]" Condition="((uint8_t)FieldType & 0x9f) == 0x09">*(zen::DebugCbIntegerNegative*)Value</Item> + <Item Name="[{Index}]" Condition="((uint8_t)FieldType & 0x9f) == 0x0a">*(zen::DebugCbFloat32*)Value</Item> + <Item Name="[{Index}]" Condition="((uint8_t)FieldType & 0x9f) == 0x0b">*(zen::DebugCbFloat64*)Value</Item> + <Item Name="[{Index}]" Condition="((uint8_t)FieldType & 0x9f) == 0x0c">false</Item> + <Item Name="[{Index}]" Condition="((uint8_t)FieldType & 0x9f) == 0x0d">true</Item> + <Item Name="[{Index}]" Condition="((uint8_t)FieldType & 0x9f) == 0x0e">*(zen::DebugCbObjectAttachment*)Value</Item> + <Item Name="[{Index}]" Condition="((uint8_t)FieldType & 0x9f) == 0x0f">*(zen::DebugCbBinaryAttachment*)Value</Item> + <Item Name="[{Index}]" Condition="((uint8_t)FieldType & 0x9f) == 0x10">*(zen::DebugCbHash*)Value</Item> + <Item Name="[{Index}]" Condition="((uint8_t)FieldType & 0x9f) == 0x11">*(zen::DebugCbUuid*)Value</Item> + <Item Name="[{Index}]" Condition="((uint8_t)FieldType & 0x9f) == 0x12">*(zen::DebugCbDateTime*)Value</Item> + <Item Name="[{Index}]" Condition="((uint8_t)FieldType & 0x9f) == 0x13">*(zen::DebugCbTimeSpan*)Value</Item> + <Item Name="[{Index}]" Condition="((uint8_t)FieldType & 0x9f) == 0x14">*(zen::DebugCbObjectId*)Value</Item> + <Item Name="[{Index}]" Condition="((uint8_t)FieldType & 0x9f) == 0x1e">*(zen::DebugCbCustomById*)Value</Item> + <Item Name="[{Index}]" Condition="((uint8_t)FieldType & 0x9f) == 0x1f">*(zen::DebugCbCustomByName*)Value</Item> + <!-- Advance --> + <Exec>++Index</Exec> + <Exec>Field = NextFieldFromValue(Value, FieldType)</Exec> + </Loop> + </CustomListItems> + </Expand> + </Synthetic> + </Expand> + </Type> + + <Type Name="zen::DateTime"> + <Intrinsic Name="RawTicks" Expression="Ticks"/> + + <Intrinsic Name="TicksPerMillisecond" Expression="10000ll"/> + <Intrinsic Name="TicksPerSecond" Expression="10000000ll"/> + 
<Intrinsic Name="TicksPerMinute" Expression="600000000ll"/> + <Intrinsic Name="TicksPerHour" Expression="36000000000ll"/> + <Intrinsic Name="TicksPerDay" Expression="864000000000ll"/> + + <Intrinsic Name="DayNumber" Expression="RawTicks() / TicksPerDay()"/> + <Intrinsic Name="TimeOfDayTicks" Expression="RawTicks() % TicksPerDay()"/> + + <Intrinsic Name="Y400" Expression="DayNumber() / 146097"/> + <Intrinsic Name="R400" Expression="DayNumber() % 146097"/> + <Intrinsic Name="Y100Raw" Expression="R400() / 36524"/> + <Intrinsic Name="Y100" Expression="Y100Raw() > 3 ? 3 : Y100Raw()"/> + <Intrinsic Name="R100" Expression="R400() - (Y100() * 36524)"/> + <Intrinsic Name="Y4" Expression="R100() / 1461"/> + <Intrinsic Name="R4" Expression="R100() % 1461"/> + <Intrinsic Name="Y1Raw" Expression="R4() / 365"/> + <Intrinsic Name="Y1" Expression="Y1Raw() > 3 ? 3 : Y1Raw()"/> + + <Intrinsic Name="Year" Expression="(int)(Y400() * 400 + Y100() * 100 + Y4() * 4 + Y1() + 1)"/> + <Intrinsic Name="DayOfYear" Expression="(int)(R4() - Y1() * 365)"/> + <Intrinsic Name="IsLeapYear" Expression="((Year() % 4) == 0) && (((Year() % 100) != 0) || ((Year() % 400) == 0))"/> + + <Intrinsic Name="Month" Expression=" + IsLeapYear() + ? ( DayOfYear() < 31 ? 1 + : DayOfYear() < 60 ? 2 + : DayOfYear() < 91 ? 3 + : DayOfYear() < 121 ? 4 + : DayOfYear() < 152 ? 5 + : DayOfYear() < 182 ? 6 + : DayOfYear() < 213 ? 7 + : DayOfYear() < 244 ? 8 + : DayOfYear() < 274 ? 9 + : DayOfYear() < 305 ? 10 + : DayOfYear() < 335 ? 11 : 12) + : ( DayOfYear() < 31 ? 1 + : DayOfYear() < 59 ? 2 + : DayOfYear() < 90 ? 3 + : DayOfYear() < 120 ? 4 + : DayOfYear() < 151 ? 5 + : DayOfYear() < 181 ? 6 + : DayOfYear() < 212 ? 7 + : DayOfYear() < 243 ? 8 + : DayOfYear() < 273 ? 9 + : DayOfYear() < 304 ? 10 + : DayOfYear() < 334 ? 11 : 12)"/> + + <Intrinsic Name="Day" Expression=" + IsLeapYear() + ? ( DayOfYear() < 31 ? DayOfYear() + 1 + : DayOfYear() < 60 ? DayOfYear() - 31 + 1 + : DayOfYear() < 91 ? 
DayOfYear() - 60 + 1 + : DayOfYear() < 121 ? DayOfYear() - 91 + 1 + : DayOfYear() < 152 ? DayOfYear() - 121 + 1 + : DayOfYear() < 182 ? DayOfYear() - 152 + 1 + : DayOfYear() < 213 ? DayOfYear() - 182 + 1 + : DayOfYear() < 244 ? DayOfYear() - 213 + 1 + : DayOfYear() < 274 ? DayOfYear() - 244 + 1 + : DayOfYear() < 305 ? DayOfYear() - 274 + 1 + : DayOfYear() < 335 ? DayOfYear() - 305 + 1 + : DayOfYear() - 335 + 1) + : ( DayOfYear() < 31 ? DayOfYear() + 1 + : DayOfYear() < 59 ? DayOfYear() - 31 + 1 + : DayOfYear() < 90 ? DayOfYear() - 59 + 1 + : DayOfYear() < 120 ? DayOfYear() - 90 + 1 + : DayOfYear() < 151 ? DayOfYear() - 120 + 1 + : DayOfYear() < 181 ? DayOfYear() - 151 + 1 + : DayOfYear() < 212 ? DayOfYear() - 181 + 1 + : DayOfYear() < 243 ? DayOfYear() - 212 + 1 + : DayOfYear() < 273 ? DayOfYear() - 243 + 1 + : DayOfYear() < 304 ? DayOfYear() - 273 + 1 + : DayOfYear() < 334 ? DayOfYear() - 304 + 1 + : DayOfYear() - 334 + 1)"/> + + <Intrinsic Name="Hour" Expression="(int32_t)(TimeOfDayTicks() / TicksPerHour())"/> + <Intrinsic Name="Minute" Expression="(int32_t)((TimeOfDayTicks() % TicksPerHour()) / TicksPerMinute())"/> + <Intrinsic Name="Second" Expression="(int32_t)((TimeOfDayTicks() % TicksPerMinute()) / TicksPerSecond())"/> + <Intrinsic Name="Millisecond" Expression="(int32_t)((TimeOfDayTicks() % TicksPerSecond()) / TicksPerMillisecond())"/> + <Intrinsic Name="SubMs100ns" Expression="(int32_t)(TimeOfDayTicks() % TicksPerMillisecond())"/> + + <DisplayString>{Year()}-{Month()}-{Day()} {Hour()}:{Minute()}:{Second()} {Millisecond()}ms</DisplayString> + <Expand> + <Item Name="Ticks">Ticks</Item> + <Item Name="[Year]">Year()</Item> + <Item Name="[Month]">Month()</Item> + <Item Name="[Day]">Day()</Item> + <Item Name="[Hour]">Hour()</Item> + <Item Name="[Minute]">Minute()</Item> + <Item Name="[Second]">Second()</Item> + <Item Name="[Millisecond]">Millisecond()</Item> + <Item Name="[100ns Remainder]">SubMs100ns()</Item> + </Expand> + </Type> + + <Type 
Name="zen::TimeSpan"> + <Intrinsic Name="RawTicks" Expression="Ticks"/> + <Intrinsic Name="AbsTicks" Expression="RawTicks() < 0 ? -RawTicks() : RawTicks()"/> + + <Intrinsic Name="TicksPerMillisecond" Expression="10000ll"/> + <Intrinsic Name="TicksPerSecond" Expression="10000000ll"/> + <Intrinsic Name="TicksPerMinute" Expression="600000000ll"/> + <Intrinsic Name="TicksPerHour" Expression="36000000000ll"/> + <Intrinsic Name="TicksPerDay" Expression="864000000000ll"/> + + <Intrinsic Name="Days" Expression="(int64_t)(AbsTicks() / TicksPerDay())"/> + <Intrinsic Name="Hours" Expression="(int32_t)((AbsTicks() % TicksPerDay()) / TicksPerHour())"/> + <Intrinsic Name="Minutes" Expression="(int32_t)((AbsTicks() % TicksPerHour()) / TicksPerMinute())"/> + <Intrinsic Name="Seconds" Expression="(int32_t)((AbsTicks() % TicksPerMinute()) / TicksPerSecond())"/> + <Intrinsic Name="Milliseconds" Expression="(int32_t)((AbsTicks() % TicksPerSecond()) / TicksPerMillisecond())"/> + <Intrinsic Name="SubMs100ns" Expression="(int32_t)(AbsTicks() % TicksPerMillisecond())"/> + + <DisplayString Condition="RawTicks() < 0">-{Days()}d {Hours()}h{Minutes()}m{Seconds()}s {Milliseconds()}ms</DisplayString> + <DisplayString>{Days()}d {Hours()}h{Minutes()}m{Seconds()}s {Milliseconds()}ms</DisplayString> + <Expand> + <Item Name="Ticks">Ticks</Item> + <Item Name="[Days]">Days()</Item> + <Item Name="[Hours]">Hours()</Item> + <Item Name="[Minutes]">Minutes()</Item> + <Item Name="[Seconds]">Seconds()</Item> + <Item Name="[Milliseconds]">Milliseconds()</Item> + <Item Name="[100ns Remainder]">SubMs100ns()</Item> + <Item Name="[Total Milliseconds]">(double)RawTicks() / TicksPerMillisecond()</Item> + <Item Name="[Total Seconds]">(double)RawTicks() / TicksPerSecond()</Item> + <Item Name="[Total Minutes]">(double)RawTicks() / TicksPerMinute()</Item> + <Item Name="[Total Hours]">(double)RawTicks() / TicksPerHour()</Item> + <Item Name="[Total Days]">(double)RawTicks() / TicksPerDay()</Item> + </Expand> + </Type> + 
+ <Type Name="zen::IoHash"> + <DisplayString>{uint32_t(Hash[0]<<24 | Hash[1]<<16 | Hash[2]<<8 | Hash[3]),xb}{uint32_t(Hash[4]<<24 | Hash[5]<<16 | Hash[6]<<8 | Hash[7]),xb}{uint32_t(Hash[8]<<24 | Hash[9]<<16 | Hash[10]<<8 | Hash[11]),xb}{uint32_t(Hash[12]<<24 | Hash[13]<<16 | Hash[14]<<8 | Hash[15]),xb}{uint32_t(Hash[16]<<24 | Hash[17]<<16 | Hash[18]<<8 | Hash[19]),xb}</DisplayString> + <Expand> + <IndexListItems> + <Size>sizeof(Hash)</Size> + <ValueNode>Hash[$i]</ValueNode> + </IndexListItems> + </Expand> + </Type> + + <Type Name="zen::Guid"> + <DisplayString>{{{A,xb}-{(uint16_t)(B >> 16),xb}-{(uint16_t)B,xb}-{(uint16_t)(C >> 16),xb}-{(uint16_t)C,xb}{D,xb}}}</DisplayString> + <Expand> + <Item Name="A">A,x</Item> + <Item Name="B">B,x</Item> + <Item Name="C">C,x</Item> + <Item Name="D">D,x</Item> + </Expand> + </Type> + + <Type Name="zen::Oid"> + <DisplayString>{ByteSwap32(OidBits[0]),xb}{ByteSwap32(OidBits[1]),xb}{ByteSwap32(OidBits[2]),xb}</DisplayString> + <Expand> + <ExpandedItem>OidBits</ExpandedItem> + </Expand> + </Type> + +</AutoVisualizer> diff --git a/src/zenhttp/clients/httpwsclient.cpp b/src/zenhttp/clients/httpwsclient.cpp index fbae9f5fe..770213738 100644 --- a/src/zenhttp/clients/httpwsclient.cpp +++ b/src/zenhttp/clients/httpwsclient.cpp @@ -638,4 +638,34 @@ HttpWsClient::IsOpen() const return m_Impl->m_IsOpen.load(std::memory_order_relaxed); } +std::string +HttpToWsUrl(std::string_view Endpoint, std::string_view Path) +{ + std::string_view Scheme = "ws://"; + std::string_view Host = Endpoint; + + if (Endpoint.starts_with("http://")) + { + Host = Endpoint.substr(7); + } + else if (Endpoint.starts_with("https://")) + { + Scheme = "wss://"; + Host = Endpoint.substr(8); + } + + // Strip trailing slash from host to avoid double-slash when Path starts with '/' + if (!Host.empty() && Host.back() == '/') + { + Host = Host.substr(0, Host.size() - 1); + } + + std::string Url; + Url.reserve(Scheme.size() + Host.size() + Path.size()); + Url.append(Scheme); + 
Url.append(Host); + Url.append(Path); + return Url; +} + } // namespace zen diff --git a/src/zenhttp/include/zenhttp/httpwsclient.h b/src/zenhttp/include/zenhttp/httpwsclient.h index 2ca9b7ab1..9c3b909a2 100644 --- a/src/zenhttp/include/zenhttp/httpwsclient.h +++ b/src/zenhttp/include/zenhttp/httpwsclient.h @@ -80,4 +80,14 @@ private: std::unique_ptr<Impl> m_Impl; }; +/// Convert an HTTP(S) endpoint to a WebSocket URL by replacing the scheme +/// and appending the given path. If the endpoint has no recognised scheme, +/// it is treated as a plain host:port and gets the ws:// prefix. +/// +/// Examples: +/// HttpToWsUrl("http://host:8080", "/orch/ws") → "ws://host:8080/orch/ws" +/// HttpToWsUrl("https://host", "/foo") → "wss://host/foo" +/// HttpToWsUrl("host:8080", "/bar") → "ws://host:8080/bar" +std::string HttpToWsUrl(std::string_view Endpoint, std::string_view Path); + } // namespace zen diff --git a/src/zenhttp/include/zenhttp/websocket.h b/src/zenhttp/include/zenhttp/websocket.h index bc3293282..710579faa 100644 --- a/src/zenhttp/include/zenhttp/websocket.h +++ b/src/zenhttp/include/zenhttp/websocket.h @@ -43,6 +43,8 @@ public: virtual void SendBinary(std::span<const uint8_t> Data) = 0; virtual void Close(uint16_t Code = 1000, std::string_view Reason = {}) = 0; virtual bool IsOpen() const = 0; + + void SendBinary(MemoryView Data) { SendBinary(std::span<const uint8_t>(static_cast<const uint8_t*>(Data.GetData()), Data.GetSize())); } }; /** diff --git a/src/zens3-testbed/main.cpp b/src/zens3-testbed/main.cpp new file mode 100644 index 000000000..4cd6b411f --- /dev/null +++ b/src/zens3-testbed/main.cpp @@ -0,0 +1,526 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +// Simple test bed for exercising the zens3 module against a real S3 bucket. +// +// Usage: +// zens3-testbed --bucket <name> --region <region> [command] [args...] 
+// +// Credentials are read from environment variables: +// AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY +// +// Commands: +// put <key> <file> Upload a local file +// get <key> [file] Download an object (prints to stdout if no file given) +// head <key> Check if object exists, show metadata +// delete <key> Delete an object +// list [prefix] List objects with optional prefix +// multipart-put <key> <file> [part-size-mb] Upload via multipart +// roundtrip <key> Upload test data, download, verify, delete + +#include <zenutil/cloud/imdscredentials.h> +#include <zenutil/cloud/s3client.h> + +#include <zencore/except_fmt.h> +#include <zencore/filesystem.h> +#include <zencore/iobuffer.h> +#include <zencore/logging.h> +#include <zencore/string.h> + +#include <zencore/memory/newdelete.h> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <fmt/format.h> +#include <cxxopts.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + +#include <cstdlib> +#include <fstream> +#include <iostream> + +namespace { + +using namespace zen; + +std::string +GetEnvVar(const char* Name) +{ + const char* Value = std::getenv(Name); + return Value ? 
std::string(Value) : std::string(); +} + +IoBuffer +ReadFileToBuffer(const std::filesystem::path& Path) +{ + return zen::ReadFile(Path).Flatten(); +} + +void +WriteBufferToFile(const IoBuffer& Buffer, const std::filesystem::path& Path) +{ + std::ofstream File(Path, std::ios::binary); + if (!File) + { + throw zen::runtime_error("failed to open '{}' for writing", Path.string()); + } + File.write(reinterpret_cast<const char*>(Buffer.GetData()), static_cast<std::streamsize>(Buffer.GetSize())); +} + +S3Client +CreateClient(const cxxopts::ParseResult& Args) +{ + S3ClientOptions Options; + Options.BucketName = Args["bucket"].as<std::string>(); + Options.Region = Args["region"].as<std::string>(); + + if (Args.count("imds")) + { + // Use IMDS credential provider for EC2 instances + ImdsCredentialProviderOptions ImdsOpts; + if (Args.count("imds-endpoint")) + { + ImdsOpts.Endpoint = Args["imds-endpoint"].as<std::string>(); + } + Options.CredentialProvider = Ref<ImdsCredentialProvider>(new ImdsCredentialProvider(ImdsOpts)); + } + else + { + std::string AccessKey = GetEnvVar("AWS_ACCESS_KEY_ID"); + std::string SecretKey = GetEnvVar("AWS_SECRET_ACCESS_KEY"); + std::string SessionToken = GetEnvVar("AWS_SESSION_TOKEN"); + + if (AccessKey.empty() || SecretKey.empty()) + { + throw zen::runtime_error("AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY environment variables must be set"); + } + + Options.Credentials.AccessKeyId = std::move(AccessKey); + Options.Credentials.SecretAccessKey = std::move(SecretKey); + Options.Credentials.SessionToken = std::move(SessionToken); + } + + if (Args.count("endpoint")) + { + Options.Endpoint = Args["endpoint"].as<std::string>(); + } + + if (Args.count("path-style")) + { + Options.PathStyle = true; + } + + if (Args.count("timeout")) + { + Options.Timeout = std::chrono::milliseconds(Args["timeout"].as<int>() * 1000); + } + + return S3Client(Options); +} + +int +CmdPut(S3Client& Client, const std::vector<std::string>& Positional) +{ + if 
(Positional.size() < 3) + { + fmt::print(stderr, "Usage: zens3-testbed ... put <key> <file>\n"); + return 1; + } + + const auto& Key = Positional[1]; + const auto& FilePath = Positional[2]; + + IoBuffer Content = ReadFileToBuffer(FilePath); + fmt::print("Uploading '{}' ({} bytes) to s3://{}/{}\n", FilePath, Content.GetSize(), Client.BucketName(), Key); + + S3Result Result = Client.PutObject(Key, Content); + if (!Result) + { + fmt::print(stderr, "PUT failed: {}\n", Result.Error); + return 1; + } + + fmt::print("OK\n"); + return 0; +} + +int +CmdGet(S3Client& Client, const std::vector<std::string>& Positional) +{ + if (Positional.size() < 2) + { + fmt::print(stderr, "Usage: zens3-testbed ... get <key> [file]\n"); + return 1; + } + + const auto& Key = Positional[1]; + + S3GetObjectResult Result = Client.GetObject(Key); + if (!Result) + { + fmt::print(stderr, "GET failed: {}\n", Result.Error); + return 1; + } + + if (Positional.size() >= 3) + { + const auto& FilePath = Positional[2]; + WriteBufferToFile(Result.Content, FilePath); + fmt::print("Downloaded {} bytes to '{}'\n", Result.Content.GetSize(), FilePath); + } + else + { + // Print to stdout + std::string_view Text = Result.AsText(); + std::cout.write(Text.data(), static_cast<std::streamsize>(Text.size())); + std::cout << std::endl; + } + + return 0; +} + +int +CmdHead(S3Client& Client, const std::vector<std::string>& Positional) +{ + if (Positional.size() < 2) + { + fmt::print(stderr, "Usage: zens3-testbed ... 
head <key>\n"); + return 1; + } + + const auto& Key = Positional[1]; + + S3HeadObjectResult Result = Client.HeadObject(Key); + + if (!Result) + { + fmt::print(stderr, "HEAD failed: {}\n", Result.Error); + return 1; + } + + if (Result.Status == HeadObjectResult::NotFound) + { + fmt::print("Object '{}' does not exist\n", Key); + return 1; + } + + fmt::print("Key: {}\n", Result.Info.Key); + fmt::print("Size: {} bytes\n", Result.Info.Size); + fmt::print("ETag: {}\n", Result.Info.ETag); + fmt::print("Last-Modified: {}\n", Result.Info.LastModified); + return 0; +} + +int +CmdDelete(S3Client& Client, const std::vector<std::string>& Positional) +{ + if (Positional.size() < 2) + { + fmt::print(stderr, "Usage: zens3-testbed ... delete <key>\n"); + return 1; + } + + const auto& Key = Positional[1]; + + S3Result Result = Client.DeleteObject(Key); + if (!Result) + { + fmt::print(stderr, "DELETE failed: {}\n", Result.Error); + return 1; + } + + fmt::print("Deleted '{}'\n", Key); + return 0; +} + +int +CmdList(S3Client& Client, const std::vector<std::string>& Positional) +{ + std::string Prefix; + if (Positional.size() >= 2) + { + Prefix = Positional[1]; + } + + S3ListObjectsResult Result = Client.ListObjects(Prefix); + if (!Result) + { + fmt::print(stderr, "LIST failed: {}\n", Result.Error); + return 1; + } + + fmt::print("{} objects found:\n", Result.Objects.size()); + for (const auto& Obj : Result.Objects) + { + fmt::print(" {:>12} {} {}\n", Obj.Size, Obj.LastModified, Obj.Key); + } + + return 0; +} + +int +CmdMultipartPut(S3Client& Client, const std::vector<std::string>& Positional) +{ + if (Positional.size() < 3) + { + fmt::print(stderr, "Usage: zens3-testbed ... 
multipart-put <key> <file> [part-size-mb]\n"); + return 1; + } + + const auto& Key = Positional[1]; + const auto& FilePath = Positional[2]; + + uint64_t PartSize = 8 * 1024 * 1024; // 8 MB default + if (Positional.size() >= 4) + { + PartSize = std::stoull(Positional[3]) * 1024 * 1024; + } + + IoBuffer Content = ReadFileToBuffer(FilePath); + fmt::print("Multipart uploading '{}' ({} bytes, part size {} MB) to s3://{}/{}\n", + FilePath, + Content.GetSize(), + PartSize / (1024 * 1024), + Client.BucketName(), + Key); + + S3Result Result = Client.PutObjectMultipart(Key, Content, PartSize); + if (!Result) + { + fmt::print(stderr, "Multipart PUT failed: {}\n", Result.Error); + return 1; + } + + fmt::print("OK\n"); + return 0; +} + +int +CmdRoundtrip(S3Client& Client, const std::vector<std::string>& Positional) +{ + if (Positional.size() < 2) + { + fmt::print(stderr, "Usage: zens3-testbed ... roundtrip <key>\n"); + return 1; + } + + const auto& Key = Positional[1]; + + // Generate test data + const size_t TestSize = 1024 * 64; // 64 KB + std::vector<uint8_t> TestData(TestSize); + for (size_t i = 0; i < TestSize; ++i) + { + TestData[i] = static_cast<uint8_t>(i & 0xFF); + } + + IoBuffer UploadContent(IoBuffer::Clone, TestData.data(), TestData.size()); + + fmt::print("=== Roundtrip test for key '{}' ===\n\n", Key); + + // PUT + fmt::print("[1/4] PUT {} bytes...\n", TestSize); + S3Result Result = Client.PutObject(Key, UploadContent); + if (!Result) + { + fmt::print(stderr, " FAILED: {}\n", Result.Error); + return 1; + } + fmt::print(" OK\n"); + + // HEAD + fmt::print("[2/4] HEAD...\n"); + S3HeadObjectResult HeadResult = Client.HeadObject(Key); + if (HeadResult.Status != HeadObjectResult::Found) + { + fmt::print(stderr, " FAILED: {}\n", !HeadResult ? 
HeadResult.Error : "not found"); + return 1; + } + fmt::print(" OK (size={}, etag={})\n", HeadResult.Info.Size, HeadResult.Info.ETag); + + if (HeadResult.Info.Size != TestSize) + { + fmt::print(stderr, " SIZE MISMATCH: expected {}, got {}\n", TestSize, HeadResult.Info.Size); + return 1; + } + + // GET + fmt::print("[3/4] GET and verify...\n"); + S3GetObjectResult GetResult = Client.GetObject(Key); + if (!GetResult) + { + fmt::print(stderr, " FAILED: {}\n", GetResult.Error); + return 1; + } + + if (GetResult.Content.GetSize() != TestSize) + { + fmt::print(stderr, " SIZE MISMATCH: expected {}, got {}\n", TestSize, GetResult.Content.GetSize()); + return 1; + } + + if (memcmp(GetResult.Content.GetData(), TestData.data(), TestSize) != 0) + { + fmt::print(stderr, " DATA MISMATCH\n"); + return 1; + } + fmt::print(" OK (verified {} bytes)\n", TestSize); + + // DELETE + fmt::print("[4/4] DELETE...\n"); + Result = Client.DeleteObject(Key); + if (!Result) + { + fmt::print(stderr, " FAILED: {}\n", Result.Error); + return 1; + } + fmt::print(" OK\n"); + + fmt::print("\n=== Roundtrip test PASSED ===\n"); + return 0; +} + +int +CmdPresign(S3Client& Client, const std::vector<std::string>& Positional) +{ + if (Positional.size() < 2) + { + fmt::print(stderr, "Usage: zens3-testbed ... 
presign <key> [method] [expires-seconds]\n"); + return 1; + } + + const auto& Key = Positional[1]; + + std::string Method = "GET"; + if (Positional.size() >= 3) + { + Method = Positional[2]; + } + + std::chrono::seconds ExpiresIn(3600); + if (Positional.size() >= 4) + { + ExpiresIn = std::chrono::seconds(std::stoul(Positional[3])); + } + + std::string Url; + if (Method == "PUT") + { + Url = Client.GeneratePresignedPutUrl(Key, ExpiresIn); + } + else + { + Url = Client.GeneratePresignedGetUrl(Key, ExpiresIn); + } + + fmt::print("{}\n", Url); + return 0; +} + +} // namespace + +int +main(int argc, char* argv[]) +{ + using namespace zen; + + logging::InitializeLogging(); + + cxxopts::Options Options("zens3-testbed", "Test bed for exercising S3 operations via the zens3 module"); + + // clang-format off + Options.add_options() + ("b,bucket", "S3 bucket name", cxxopts::value<std::string>()) + ("r,region", "AWS region", cxxopts::value<std::string>()->default_value("us-east-1")) + ("e,endpoint", "Custom S3 endpoint URL", cxxopts::value<std::string>()) + ("path-style", "Use path-style addressing (for MinIO, etc.)") + ("imds", "Use EC2 IMDS for credentials instead of env vars") + ("imds-endpoint", "Custom IMDS endpoint URL (for testing)", cxxopts::value<std::string>()) + ("timeout", "Request timeout in seconds", cxxopts::value<int>()->default_value("30")) + ("v,verbose", "Enable verbose logging") + ("h,help", "Show help") + ("positional", "Command and arguments", cxxopts::value<std::vector<std::string>>()); + // clang-format on + + Options.parse_positional({"positional"}); + Options.positional_help("<command> [args...]"); + + try + { + auto Result = Options.parse(argc, argv); + + if (Result.count("help") || !Result.count("positional")) + { + fmt::print("{}\n", Options.help()); + fmt::print("Commands:\n"); + fmt::print(" put <key> <file> Upload a local file\n"); + fmt::print(" get <key> [file] Download (to file or stdout)\n"); + fmt::print(" head <key> Show object 
metadata\n"); + fmt::print(" delete <key> Delete an object\n"); + fmt::print(" list [prefix] List objects\n"); + fmt::print(" multipart-put <key> <file> [part-mb] Multipart upload\n"); + fmt::print(" roundtrip <key> Upload/download/verify/delete\n"); + fmt::print(" presign <key> [method] [expires-sec] Generate pre-signed URL\n"); + fmt::print("\nCredentials via AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY env vars,\n"); + fmt::print("or use --imds to fetch from EC2 Instance Metadata Service.\n"); + return 0; + } + + if (!Result.count("bucket")) + { + fmt::print(stderr, "Error: --bucket is required\n"); + return 1; + } + + if (Result.count("verbose")) + { + logging::SetLogLevel(logging::Debug); + } + + auto Client = CreateClient(Result); + + const auto& Positional = Result["positional"].as<std::vector<std::string>>(); + const auto& Command = Positional[0]; + + if (Command == "put") + { + return CmdPut(Client, Positional); + } + else if (Command == "get") + { + return CmdGet(Client, Positional); + } + else if (Command == "head") + { + return CmdHead(Client, Positional); + } + else if (Command == "delete") + { + return CmdDelete(Client, Positional); + } + else if (Command == "list") + { + return CmdList(Client, Positional); + } + else if (Command == "multipart-put") + { + return CmdMultipartPut(Client, Positional); + } + else if (Command == "roundtrip") + { + return CmdRoundtrip(Client, Positional); + } + else if (Command == "presign") + { + return CmdPresign(Client, Positional); + } + else + { + fmt::print(stderr, "Unknown command: '{}'\n", Command); + return 1; + } + } + catch (const std::exception& Ex) + { + fmt::print(stderr, "Error: {}\n", Ex.what()); + return 1; + } +} diff --git a/src/zens3-testbed/xmake.lua b/src/zens3-testbed/xmake.lua new file mode 100644 index 000000000..168ab9de9 --- /dev/null +++ b/src/zens3-testbed/xmake.lua @@ -0,0 +1,8 @@ +-- Copyright Epic Games, Inc. All Rights Reserved. 
+ +target("zens3-testbed") + set_kind("binary") + set_group("tools") + add_files("*.cpp") + add_deps("zenutil", "zencore") + add_deps("cxxopts", "fmt") diff --git a/src/zenserver-test/compute-tests.cpp b/src/zenserver-test/compute-tests.cpp index c90ac5d8b..021052a3b 100644 --- a/src/zenserver-test/compute-tests.cpp +++ b/src/zenserver-test/compute-tests.cpp @@ -19,6 +19,7 @@ # include <zencore/timer.h> # include <zenhttp/httpclient.h> # include <zenhttp/httpserver.h> +# include <zenhttp/websocket.h> # include <zencompute/computeservice.h> # include <zenstore/zenstore.h> # include <zenutil/zenserverprocess.h> @@ -291,7 +292,9 @@ GetRot13Output(const CbPackage& ResultPackage) } // Mock orchestrator HTTP service that serves GET /orch/agents with a controllable response. -class MockOrchestratorService : public HttpService +// Also implements IWebSocketHandler so the compute session's WS subscription receives +// push notifications when the worker list changes. +class MockOrchestratorService : public HttpService, public IWebSocketHandler { public: MockOrchestratorService() @@ -318,13 +321,48 @@ public: void SetWorkerList(CbObject WorkerList) { - RwLock::ExclusiveLockScope Lock(m_Lock); - m_WorkerList = std::move(WorkerList); + { + RwLock::ExclusiveLockScope Lock(m_Lock); + m_WorkerList = std::move(WorkerList); + } + + // Broadcast a poke to all connected WebSocket clients so they + // immediately re-query the orchestrator instead of waiting for the poll. 
+ std::vector<Ref<WebSocketConnection>> Snapshot; + m_WsLock.WithSharedLock([&] { Snapshot = m_WsConnections; }); + for (auto& Conn : Snapshot) + { + if (Conn->IsOpen()) + { + Conn->SendText("updated"sv); + } + } + } + + // IWebSocketHandler + void OnWebSocketOpen(Ref<WebSocketConnection> Connection) override + { + m_WsLock.WithExclusiveLock([&] { m_WsConnections.push_back(std::move(Connection)); }); + } + + void OnWebSocketMessage(WebSocketConnection&, const WebSocketMessage&) override {} + + void OnWebSocketClose(WebSocketConnection& Conn, uint16_t, std::string_view) override + { + m_WsLock.WithExclusiveLock([&] { + auto It = std::remove_if(m_WsConnections.begin(), m_WsConnections.end(), [&](const Ref<WebSocketConnection>& C) { + return C.Get() == &Conn; + }); + m_WsConnections.erase(It, m_WsConnections.end()); + }); } private: RwLock m_Lock; CbObject m_WorkerList; + + RwLock m_WsLock; + std::vector<Ref<WebSocketConnection>> m_WsConnections; }; // Manages in-process ASIO HTTP server lifecycle for mock orchestrator. 
@@ -1089,9 +1127,8 @@ TEST_CASE("function.remote.worker_sync_on_discovery") InMemoryChunkResolver Resolver; ScopedTemporaryDirectory SessionBaseDir; zen::compute::ComputeServiceSession Session(Resolver); - Session.SetOrchestratorEndpoint(MockOrch.GetEndpoint()); - Session.SetOrchestratorBasePath(SessionBaseDir.Path()); - Session.RequestStateTransition(zen::compute::ComputeServiceSession::SessionState::Ready); + Session.SetOrchestrator(MockOrch.GetEndpoint(), SessionBaseDir.Path()); + Session.Ready(); // Register worker on session (stored locally, no runners yet) CbPackage WorkerPackage = BuildWorkerPackage(TestEnv, Resolver); @@ -1100,8 +1137,9 @@ TEST_CASE("function.remote.worker_sync_on_discovery") // Update mock orchestrator to advertise the real server MockOrch.Service.SetWorkerList(BuildAgentListResponse({{"worker-1", ServerUri}})); - // Wait for scheduler to discover the runner (~5s throttle + margin) - Sleep(7'000); + // Trigger immediate orchestrator re-query and wait for runner setup + Session.NotifyOrchestratorChanged(); + Sleep(2'000); // Submit Rot13 action via session CbObject ActionObj = BuildRot13ActionForSession("Hello World"sv, Resolver); @@ -1153,15 +1191,15 @@ TEST_CASE("function.remote.late_runner_discovery") InMemoryChunkResolver Resolver; ScopedTemporaryDirectory SessionBaseDir; zen::compute::ComputeServiceSession Session(Resolver); - Session.SetOrchestratorEndpoint(MockOrch.GetEndpoint()); - Session.SetOrchestratorBasePath(SessionBaseDir.Path()); - Session.RequestStateTransition(zen::compute::ComputeServiceSession::SessionState::Ready); + Session.SetOrchestrator(MockOrch.GetEndpoint(), SessionBaseDir.Path()); + Session.Ready(); CbPackage WorkerPackage = BuildWorkerPackage(TestEnv, Resolver); Session.RegisterWorker(WorkerPackage); // Wait for W1 discovery - Sleep(7'000); + Session.NotifyOrchestratorChanged(); + Sleep(2'000); // Baseline: submit Rot13 action and verify it completes on W1 { @@ -1202,7 +1240,8 @@ 
TEST_CASE("function.remote.late_runner_discovery") MockOrch.Service.SetWorkerList(BuildAgentListResponse({{"worker-1", ServerUri1}, {"worker-2", ServerUri2}})); // Wait for W2 discovery - Sleep(7'000); + Session.NotifyOrchestratorChanged(); + Sleep(2'000); // Verify W2 received the worker by querying its /compute/workers endpoint directly { @@ -1274,16 +1313,16 @@ TEST_CASE("function.remote.queue_association") InMemoryChunkResolver Resolver; ScopedTemporaryDirectory SessionBaseDir; zen::compute::ComputeServiceSession Session(Resolver); - Session.SetOrchestratorEndpoint(MockOrch.GetEndpoint()); - Session.SetOrchestratorBasePath(SessionBaseDir.Path()); - Session.RequestStateTransition(zen::compute::ComputeServiceSession::SessionState::Ready); + Session.SetOrchestrator(MockOrch.GetEndpoint(), SessionBaseDir.Path()); + Session.Ready(); // Register worker on session CbPackage WorkerPackage = BuildWorkerPackage(TestEnv, Resolver); Session.RegisterWorker(WorkerPackage); // Wait for scheduler to discover the runner - Sleep(7'000); + Session.NotifyOrchestratorChanged(); + Sleep(2'000); // Create a local queue and submit action to it auto QueueResult = Session.CreateQueue(); @@ -1353,16 +1392,16 @@ TEST_CASE("function.remote.queue_cancel_propagation") InMemoryChunkResolver Resolver; ScopedTemporaryDirectory SessionBaseDir; zen::compute::ComputeServiceSession Session(Resolver); - Session.SetOrchestratorEndpoint(MockOrch.GetEndpoint()); - Session.SetOrchestratorBasePath(SessionBaseDir.Path()); - Session.RequestStateTransition(zen::compute::ComputeServiceSession::SessionState::Ready); + Session.SetOrchestrator(MockOrch.GetEndpoint(), SessionBaseDir.Path()); + Session.Ready(); // Register worker on session CbPackage WorkerPackage = BuildWorkerPackage(TestEnv, Resolver); Session.RegisterWorker(WorkerPackage); // Wait for scheduler to discover the runner - Sleep(7'000); + Session.NotifyOrchestratorChanged(); + Sleep(2'000); // Create a local queue and submit a long-running Sleep 
action auto QueueResult = Session.CreateQueue(); @@ -1496,7 +1535,7 @@ TEST_CASE("function.session.abandon_pending") InMemoryChunkResolver Resolver; ScopedTemporaryDirectory SessionBaseDir; zen::compute::ComputeServiceSession Session(Resolver); - Session.RequestStateTransition(zen::compute::ComputeServiceSession::SessionState::Ready); + Session.Ready(); CbPackage WorkerPackage = BuildWorkerPackage(TestEnv, Resolver); Session.RegisterWorker(WorkerPackage); @@ -1515,19 +1554,29 @@ TEST_CASE("function.session.abandon_pending") REQUIRE_MESSAGE(Enqueue3, "Failed to enqueue action 3"); // Transition to Abandoned — should mark all pending actions as Abandoned - bool Transitioned = Session.RequestStateTransition(zen::compute::ComputeServiceSession::SessionState::Abandoned); + bool Transitioned = Session.Abandon(); CHECK_MESSAGE(Transitioned, "Failed to transition to Abandoned"); CHECK(Session.GetSessionState() == zen::compute::ComputeServiceSession::SessionState::Abandoned); CHECK(!Session.IsHealthy()); - // Give the scheduler thread time to process the state changes - Sleep(2'000); - - // All three actions should now be in the results map as abandoned + // Poll until the scheduler thread has processed all abandoned actions into + // the results map. The abandon is asynchronous: AbandonAllActions() sets + // action state and posts updates, but HandleActionUpdates() on the + // scheduler thread must run before results are queryable. 
+ Stopwatch Timer; for (int Lsn : {Enqueue1.Lsn, Enqueue2.Lsn, Enqueue3.Lsn}) { CbPackage Result; - HttpResponseCode Code = Session.GetActionResult(Lsn, Result); + HttpResponseCode Code = HttpResponseCode::Accepted; + while (Timer.GetElapsedTimeMs() < 10'000) + { + Code = Session.GetActionResult(Lsn, Result); + if (Code == HttpResponseCode::OK) + { + break; + } + Sleep(100); + } CHECK_MESSAGE(Code == HttpResponseCode::OK, fmt::format("Expected action LSN {} to be in results (got {})", Lsn, int(Code))); } @@ -1561,15 +1610,15 @@ TEST_CASE("function.session.abandon_running") InMemoryChunkResolver Resolver; ScopedTemporaryDirectory SessionBaseDir; zen::compute::ComputeServiceSession Session(Resolver); - Session.SetOrchestratorEndpoint(MockOrch.GetEndpoint()); - Session.SetOrchestratorBasePath(SessionBaseDir.Path()); - Session.RequestStateTransition(zen::compute::ComputeServiceSession::SessionState::Ready); + Session.SetOrchestrator(MockOrch.GetEndpoint(), SessionBaseDir.Path()); + Session.Ready(); CbPackage WorkerPackage = BuildWorkerPackage(TestEnv, Resolver); Session.RegisterWorker(WorkerPackage); // Wait for scheduler to discover the runner - Sleep(7'000); + Session.NotifyOrchestratorChanged(); + Sleep(2'000); // Create a queue and submit a long-running Sleep action auto QueueResult = Session.CreateQueue(); @@ -1585,7 +1634,7 @@ TEST_CASE("function.session.abandon_running") Sleep(2'000); // Transition to Abandoned — should abandon the running action - bool Transitioned = Session.RequestStateTransition(zen::compute::ComputeServiceSession::SessionState::Abandoned); + bool Transitioned = Session.Abandon(); CHECK_MESSAGE(Transitioned, "Failed to transition to Abandoned"); CHECK(!Session.IsHealthy()); @@ -1631,16 +1680,16 @@ TEST_CASE("function.remote.abandon_propagation") InMemoryChunkResolver Resolver; ScopedTemporaryDirectory SessionBaseDir; zen::compute::ComputeServiceSession Session(Resolver); - Session.SetOrchestratorEndpoint(MockOrch.GetEndpoint()); - 
Session.SetOrchestratorBasePath(SessionBaseDir.Path()); - Session.RequestStateTransition(zen::compute::ComputeServiceSession::SessionState::Ready); + Session.SetOrchestrator(MockOrch.GetEndpoint(), SessionBaseDir.Path()); + Session.Ready(); // Register worker on session CbPackage WorkerPackage = BuildWorkerPackage(TestEnv, Resolver); Session.RegisterWorker(WorkerPackage); // Wait for scheduler to discover the runner - Sleep(7'000); + Session.NotifyOrchestratorChanged(); + Sleep(2'000); // Create a local queue and submit a long-running Sleep action auto QueueResult = Session.CreateQueue(); @@ -1656,7 +1705,7 @@ TEST_CASE("function.remote.abandon_propagation") Sleep(2'000); // Transition to Abandoned — should abandon the running action and propagate - bool Transitioned = Session.RequestStateTransition(zen::compute::ComputeServiceSession::SessionState::Abandoned); + bool Transitioned = Session.Abandon(); CHECK_MESSAGE(Transitioned, "Failed to transition to Abandoned"); // Poll for the action to complete @@ -1693,6 +1742,278 @@ TEST_CASE("function.remote.abandon_propagation") Session.Shutdown(); } +TEST_CASE("function.remote.shutdown_cancels_queues") +{ + // Verify that Session.Shutdown() cancels remote queues on the compute node. 
+ ZenServerInstance Instance(TestEnv, ZenServerInstance::ServerMode::kComputeServer); + Instance.SetDataDir(TestEnv.CreateNewTestDir()); + REQUIRE_MESSAGE(Instance.SpawnServerAndWaitUntilReady() != 0, Instance.GetLogOutput()); + + MockOrchestratorFixture MockOrch; + MockOrch.Service.SetWorkerList(BuildAgentListResponse({{"worker-1", Instance.GetBaseUri()}})); + + InMemoryChunkResolver Resolver; + ScopedTemporaryDirectory SessionBaseDir; + zen::compute::ComputeServiceSession Session(Resolver); + Session.SetOrchestrator(MockOrch.GetEndpoint(), SessionBaseDir.Path()); + Session.Ready(); + + CbPackage WorkerPackage = BuildWorkerPackage(TestEnv, Resolver); + Session.RegisterWorker(WorkerPackage); + + Session.NotifyOrchestratorChanged(); + Sleep(2'000); + + // Create a queue and submit a long-running action so the remote queue is established + auto QueueResult = Session.CreateQueue(); + REQUIRE_MESSAGE(QueueResult.QueueId != 0, "Failed to create local queue"); + const int QueueId = QueueResult.QueueId; + + CbObject ActionObj = BuildSleepActionForSession("data"sv, 30'000, Resolver); + + auto EnqueueRes = Session.EnqueueActionToQueue(QueueId, ActionObj, 0); + REQUIRE_MESSAGE(EnqueueRes, "Sleep action enqueue to queue failed"); + + // Wait for the action to start running on the remote + Sleep(2'000); + + // Verify the remote has a non-implicit queue before shutdown + HttpClient RemoteClient(Instance.GetBaseUri() + "/compute"); + { + HttpClient::Response QueuesResp = RemoteClient.Get("/queues"sv); + REQUIRE_MESSAGE(QueuesResp, "Failed to list queues on remote server before shutdown"); + + bool RemoteQueueFound = false; + for (auto& Item : QueuesResp.AsObject()["queues"sv]) + { + if (!Item.AsObjectView()["implicit"sv].AsBool()) + { + RemoteQueueFound = true; + break; + } + } + REQUIRE_MESSAGE(RemoteQueueFound, "Expected remote queue to exist before shutdown"); + } + + // Shut down the session — this should cancel all remote queues + Session.Shutdown(); + + // Verify the 
remote queue is now cancelled + { + HttpClient::Response QueuesResp = RemoteClient.Get("/queues"sv); + REQUIRE_MESSAGE(QueuesResp, "Failed to list queues on remote server after shutdown"); + + bool RemoteQueueCancelled = false; + for (auto& Item : QueuesResp.AsObject()["queues"sv]) + { + if (!Item.AsObjectView()["implicit"sv].AsBool()) + { + RemoteQueueCancelled = std::string(Item.AsObjectView()["state"sv].AsString()) == "cancelled"; + break; + } + } + CHECK_MESSAGE(RemoteQueueCancelled, "Expected remote queue to be cancelled after session shutdown"); + } +} + +TEST_CASE("function.remote.shutdown_rejects_new_work") +{ + // Verify that after Shutdown() the remote runner rejects new submissions. + ZenServerInstance Instance(TestEnv, ZenServerInstance::ServerMode::kComputeServer); + Instance.SetDataDir(TestEnv.CreateNewTestDir()); + REQUIRE_MESSAGE(Instance.SpawnServerAndWaitUntilReady() != 0, Instance.GetLogOutput()); + + MockOrchestratorFixture MockOrch; + MockOrch.Service.SetWorkerList(BuildAgentListResponse({{"worker-1", Instance.GetBaseUri()}})); + + InMemoryChunkResolver Resolver; + ScopedTemporaryDirectory SessionBaseDir; + zen::compute::ComputeServiceSession Session(Resolver); + Session.SetOrchestrator(MockOrch.GetEndpoint(), SessionBaseDir.Path()); + Session.Ready(); + + CbPackage WorkerPackage = BuildWorkerPackage(TestEnv, Resolver); + Session.RegisterWorker(WorkerPackage); + + // Wait for runner discovery + Session.NotifyOrchestratorChanged(); + Sleep(2'000); + + // Baseline: submit an action and verify it completes + { + CbObject ActionObj = BuildRot13ActionForSession("Hello World"sv, Resolver); + + auto EnqueueRes = Session.EnqueueAction(ActionObj, 0); + REQUIRE_MESSAGE(EnqueueRes, "Baseline action enqueue failed"); + + CbPackage ResultPackage; + HttpResponseCode ResultCode = HttpResponseCode::Accepted; + Stopwatch Timer; + + while (Timer.GetElapsedTimeMs() < 30'000) + { + ResultCode = Session.GetActionResult(EnqueueRes.Lsn, ResultPackage); + if 
(ResultCode == HttpResponseCode::OK) + { + break; + } + Sleep(200); + } + + REQUIRE_MESSAGE(ResultCode == HttpResponseCode::OK, + fmt::format("Baseline action did not complete in time\nServer log:\n{}", Instance.GetLogOutput())); + CHECK_EQ(GetRot13Output(ResultPackage), "Uryyb Jbeyq"sv); + } + + // Shut down — the remote runner should now reject new work + Session.Shutdown(); + + // Attempting to enqueue after shutdown should fail (session is in Sunset state) + CbObject ActionObj = BuildRot13ActionForSession("rejected"sv, Resolver); + auto Rejected = Session.EnqueueAction(ActionObj, 0); + CHECK_MESSAGE(!Rejected, "Expected action submission to be rejected after shutdown"); +} + +TEST_CASE("function.session.retract_pending") +{ + // Create a session with no runners so actions stay pending + InMemoryChunkResolver Resolver; + ScopedTemporaryDirectory SessionBaseDir; + zen::compute::ComputeServiceSession Session(Resolver); + Session.Ready(); + + CbPackage WorkerPackage = BuildWorkerPackage(TestEnv, Resolver); + Session.RegisterWorker(WorkerPackage); + + auto QueueResult = Session.CreateQueue(); + REQUIRE_MESSAGE(QueueResult.QueueId != 0, "Failed to create queue"); + + CbObject ActionObj = BuildRot13ActionForSession("retract-test"sv, Resolver); + + auto Enqueued = Session.EnqueueActionToQueue(QueueResult.QueueId, ActionObj, 0); + REQUIRE_MESSAGE(Enqueued, "Failed to enqueue action"); + + // Let the scheduler process the pending action + Sleep(500); + + // Retract the pending action + auto Result = Session.RetractAction(Enqueued.Lsn); + CHECK_MESSAGE(Result.Success, fmt::format("RetractAction failed: {}", Result.Error)); + CHECK_EQ(Result.RetryCount, 0); // Retract should NOT increment retry count + + // The action should be re-enqueued as pending (still no runners, so stays pending). + // Let the scheduler process the retracted action back to pending. 
+ Sleep(500); + + // Queue should still show 1 active (the action was rescheduled, not completed) + auto Status = Session.GetQueueStatus(QueueResult.QueueId); + CHECK_EQ(Status.ActiveCount, 1); + CHECK_EQ(Status.CompletedCount, 0); + CHECK_EQ(Status.AbandonedCount, 0); + CHECK_EQ(Status.CancelledCount, 0); + + // The action result should NOT be in the results map (it's pending again) + CbPackage ResultPackage; + HttpResponseCode Code = Session.GetActionResult(Enqueued.Lsn, ResultPackage); + CHECK(Code != HttpResponseCode::OK); + + Session.Shutdown(); +} + +TEST_CASE("function.session.retract_not_terminal") +{ + // Verify that a completed action cannot be retracted + InMemoryChunkResolver Resolver; + ScopedTemporaryDirectory SessionBaseDir; + zen::compute::ComputeServiceSession Session(Resolver); + Session.SetOrchestratorBasePath(SessionBaseDir.Path()); + Session.AddLocalRunner(Resolver, SessionBaseDir.Path()); + Session.Ready(); + + CbPackage WorkerPackage = BuildWorkerPackage(TestEnv, Resolver); + Session.RegisterWorker(WorkerPackage); + + CbObject ActionObj = BuildRot13ActionForSession("retract-completed"sv, Resolver); + + auto Enqueued = Session.EnqueueAction(ActionObj, 0); + REQUIRE_MESSAGE(Enqueued, "Failed to enqueue action"); + + // Wait for the action to complete + CbPackage ResultPackage; + HttpResponseCode Code = HttpResponseCode::Accepted; + Stopwatch Timer; + + while (Timer.GetElapsedTimeMs() < 30'000) + { + Code = Session.GetActionResult(Enqueued.Lsn, ResultPackage); + if (Code == HttpResponseCode::OK) + { + break; + } + Sleep(200); + } + + REQUIRE_MESSAGE(Code == HttpResponseCode::OK, "Action did not complete within timeout"); + + // Retract should fail — action already completed (no longer in pending/running maps) + auto RetractResult = Session.RetractAction(Enqueued.Lsn); + CHECK(!RetractResult.Success); + + Session.Shutdown(); +} + +TEST_CASE("function.retract_http") +{ + // Use max-actions=1 with a long sleep to hold the slot, then submit a second 
+ // action that will stay pending and can be retracted via the HTTP endpoint. + ZenServerInstance Instance(TestEnv, ZenServerInstance::ServerMode::kComputeServer); + Instance.SetDataDir(TestEnv.CreateNewTestDir()); + const uint16_t Port = Instance.SpawnServerAndWaitUntilReady("--max-actions=1"); + REQUIRE_MESSAGE(Port != 0, Instance.GetLogOutput()); + + const std::string ComputeBaseUri = fmt::format("http://localhost:{}/compute", Port); + HttpClient Client(ComputeBaseUri); + + const IoHash WorkerId = RegisterWorker(Client, TestEnv); + + // Submit a long-running Sleep action to occupy the single execution slot + const std::string BlockerUrl = fmt::format("/jobs/{}", WorkerId.ToHexString()); + HttpClient::Response BlockerResp = Client.Post(BlockerUrl, BuildSleepActionPackage("data"sv, 30'000)); + REQUIRE_MESSAGE(BlockerResp, fmt::format("Blocker submission failed: status={}", int(BlockerResp.StatusCode))); + + // Submit a second action — it will stay pending because the slot is occupied + HttpClient::Response SubmitResp = Client.Post(BlockerUrl, BuildRot13ActionPackage("Retract HTTP Test"sv)); + REQUIRE_MESSAGE(SubmitResp, fmt::format("Job submission failed: status={}", int(SubmitResp.StatusCode))); + + const int Lsn = SubmitResp.AsObject()["lsn"sv].AsInt32(); + REQUIRE_MESSAGE(Lsn != 0, "Expected non-zero LSN from job submission"); + + // Wait for the scheduler to process the pending action into m_PendingActions + Sleep(1'000); + + // Retract the pending action via POST /jobs/{lsn}/retract + const std::string RetractUrl = fmt::format("/jobs/{}/retract", Lsn); + HttpClient::Response RetractResp = Client.Post(RetractUrl); + CHECK_MESSAGE(RetractResp.StatusCode == HttpResponseCode::OK, + fmt::format("Retract failed: status={}, body={}\nServer log:\n{}", + int(RetractResp.StatusCode), + RetractResp.ToText(), + Instance.GetLogOutput())); + + if (RetractResp.StatusCode == HttpResponseCode::OK) + { + CbObject RespObj = RetractResp.AsObject(); + 
CHECK(RespObj["success"sv].AsBool()); + CHECK_EQ(RespObj["lsn"sv].AsInt32(), Lsn); + } + + // A second retract should also succeed (action is back to pending) + Sleep(500); + HttpClient::Response RetractResp2 = Client.Post(RetractUrl); + CHECK_MESSAGE(RetractResp2.StatusCode == HttpResponseCode::OK, + fmt::format("Second retract failed: status={}, body={}", int(RetractResp2.StatusCode), RetractResp2.ToText())); +} + TEST_SUITE_END(); } // namespace zen::tests::compute diff --git a/src/zenserver-test/objectstore-tests.cpp b/src/zenserver-test/objectstore-tests.cpp index f3db5fdf6..1f6a7675c 100644 --- a/src/zenserver-test/objectstore-tests.cpp +++ b/src/zenserver-test/objectstore-tests.cpp @@ -2,10 +2,12 @@ #if ZEN_WITH_TESTS # include "zenserver-test.h" +# include <zencore/memoryview.h> # include <zencore/testing.h> # include <zencore/testutils.h> -# include <zenutil/zenserverprocess.h> # include <zenhttp/httpclient.h> +# include <zenutil/cloud/s3client.h> +# include <zenutil/zenserverprocess.h> ZEN_THIRD_PARTY_INCLUDES_START # include <tsl/robin_set.h> @@ -68,6 +70,94 @@ TEST_CASE("objectstore.blobs") } } +TEST_CASE("objectstore.s3client") +{ + ZenServerInstance Instance(TestEnv); + const uint16_t Port = Instance.SpawnServerAndWaitUntilReady("--objectstore-enabled"); + CHECK_MESSAGE(Port != 0, Instance.GetLogOutput()); + + // S3Client in path-style builds paths as /{bucket}/{key}. + // The objectstore routes objects at bucket/{bucket}/{key} relative to its base. + // Point the S3Client endpoint at {server}/obj/bucket so the paths line up. 
+ S3ClientOptions Opts; + Opts.BucketName = "s3test"; + Opts.Region = "us-east-1"; + Opts.Endpoint = fmt::format("http://localhost:{}/obj/bucket", Port); + Opts.PathStyle = true; + Opts.Credentials.AccessKeyId = "testkey"; + Opts.Credentials.SecretAccessKey = "testsecret"; + + S3Client Client(Opts); + + // -- PUT + GET roundtrip -- + std::string_view TestData = "hello from s3client via objectstore"sv; + IoBuffer Content = IoBufferBuilder::MakeFromMemory(MakeMemoryView(TestData)); + S3Result PutRes = Client.PutObject("test/hello.txt", std::move(Content)); + REQUIRE_MESSAGE(PutRes.IsSuccess(), PutRes.Error); + + S3GetObjectResult GetRes = Client.GetObject("test/hello.txt"); + REQUIRE_MESSAGE(GetRes.IsSuccess(), GetRes.Error); + CHECK(GetRes.AsText() == TestData); + + // -- PUT overwrites -- + IoBuffer Original = IoBufferBuilder::MakeFromMemory(MakeMemoryView("original"sv)); + IoBuffer Overwrite = IoBufferBuilder::MakeFromMemory(MakeMemoryView("overwritten"sv)); + REQUIRE(Client.PutObject("overwrite/file.txt", std::move(Original)).IsSuccess()); + REQUIRE(Client.PutObject("overwrite/file.txt", std::move(Overwrite)).IsSuccess()); + + S3GetObjectResult OverwriteGet = Client.GetObject("overwrite/file.txt"); + REQUIRE(OverwriteGet.IsSuccess()); + CHECK(OverwriteGet.AsText() == "overwritten"sv); + + // -- GET not found -- + S3GetObjectResult NotFoundGet = Client.GetObject("nonexistent/file.dat"); + CHECK_FALSE(NotFoundGet.IsSuccess()); + + // -- HEAD found -- + std::string_view HeadData = "head test data"sv; + IoBuffer HeadContent = IoBufferBuilder::MakeFromMemory(MakeMemoryView(HeadData)); + REQUIRE(Client.PutObject("head/meta.txt", std::move(HeadContent)).IsSuccess()); + + S3HeadObjectResult HeadRes = Client.HeadObject("head/meta.txt"); + REQUIRE_MESSAGE(HeadRes.IsSuccess(), HeadRes.Error); + CHECK(HeadRes.Status == HeadObjectResult::Found); + CHECK(HeadRes.Info.Size == HeadData.size()); + + // -- HEAD not found -- + S3HeadObjectResult HeadNotFound = 
Client.HeadObject("nonexistent/file.dat"); + CHECK(HeadNotFound.IsSuccess()); + CHECK(HeadNotFound.Status == HeadObjectResult::NotFound); + + // -- LIST objects -- + for (int i = 0; i < 3; ++i) + { + std::string Key = fmt::format("listing/item-{}.txt", i); + std::string Payload = fmt::format("content-{}", i); + IoBuffer Buf = IoBufferBuilder::MakeFromMemory(MakeMemoryView(Payload)); + REQUIRE(Client.PutObject(Key, std::move(Buf)).IsSuccess()); + } + + S3ListObjectsResult ListRes = Client.ListObjects("listing/"); + REQUIRE_MESSAGE(ListRes.IsSuccess(), ListRes.Error); + REQUIRE(ListRes.Objects.size() == 3); + + std::vector<std::string> Keys; + for (const S3ObjectInfo& Obj : ListRes.Objects) + { + Keys.push_back(Obj.Key); + CHECK(Obj.Size > 0); + } + std::sort(Keys.begin(), Keys.end()); + CHECK(Keys[0] == "listing/item-0.txt"); + CHECK(Keys[1] == "listing/item-1.txt"); + CHECK(Keys[2] == "listing/item-2.txt"); + + // -- LIST empty prefix -- + S3ListObjectsResult EmptyList = Client.ListObjects("no-such-prefix/"); + REQUIRE(EmptyList.IsSuccess()); + CHECK(EmptyList.Objects.empty()); +} + TEST_SUITE_END(); } // namespace zen::tests diff --git a/src/zenserver-test/zenserver-test.cpp b/src/zenserver-test/zenserver-test.cpp index e89812c1f..42632682b 100644 --- a/src/zenserver-test/zenserver-test.cpp +++ b/src/zenserver-test/zenserver-test.cpp @@ -15,6 +15,7 @@ # include <zencore/testutils.h> # include <zencore/thread.h> # include <zencore/timer.h> +# include <zencore/trace.h> # include <zenhttp/httpclient.h> # include <zenhttp/packageformat.h> # include <zenutil/config/commandlineoptions.h> @@ -134,8 +135,7 @@ main(int argc, char** argv) ZEN_INFO("Running tests...(base dir: '{}')", TestBaseDir); zen::testing::TestRunner Runner; - Runner.SetDefaultSuiteFilter("server.*"); - Runner.ApplyCommandLine(argc, argv); + Runner.ApplyCommandLine(argc, argv, "server.*"); return Runner.Run(); } diff --git a/src/zenserver/compute/computeserver.cpp 
b/src/zenserver/compute/computeserver.cpp index 724ef9ad2..d1875f41a 100644 --- a/src/zenserver/compute/computeserver.cpp +++ b/src/zenserver/compute/computeserver.cpp @@ -721,21 +721,7 @@ ZenComputeServer::InitializeOrchestratorWebSocket() return; } - // Convert http://host:port → ws://host:port/orch/ws - std::string WsUrl = m_CoordinatorEndpoint; - if (WsUrl.starts_with("http://")) - { - WsUrl = "ws://" + WsUrl.substr(7); - } - else if (WsUrl.starts_with("https://")) - { - WsUrl = "wss://" + WsUrl.substr(8); - } - if (!WsUrl.empty() && WsUrl.back() != '/') - { - WsUrl += '/'; - } - WsUrl += "orch/ws"; + std::string WsUrl = HttpToWsUrl(m_CoordinatorEndpoint, "/orch/ws"); ZEN_INFO("establishing WebSocket link to orchestrator at {}", WsUrl); diff --git a/src/zenserver/frontend/html/compute/compute.html b/src/zenserver/frontend/html/compute/compute.html index 66c20175f..c07bbb692 100644 --- a/src/zenserver/frontend/html/compute/compute.html +++ b/src/zenserver/frontend/html/compute/compute.html @@ -6,6 +6,7 @@ <title>Zen Compute Dashboard</title> <script src="https://cdn.jsdelivr.net/npm/[email protected]/dist/chart.umd.min.js"></script> <link rel="stylesheet" type="text/css" href="../zen.css" /> + <script src="../util/sanitize.js"></script> <script src="../theme.js"></script> <script src="../banner.js" defer></script> <script src="../nav.js" defer></script> @@ -456,11 +457,6 @@ }); // Helper functions - function escapeHtml(text) { - var div = document.createElement('div'); - div.textContent = text; - return div.innerHTML; - } function formatBytes(bytes) { if (bytes === 0) return '0 B'; diff --git a/src/zenserver/frontend/html/compute/hub.html b/src/zenserver/frontend/html/compute/hub.html index 32e1b05db..620349a2b 100644 --- a/src/zenserver/frontend/html/compute/hub.html +++ b/src/zenserver/frontend/html/compute/hub.html @@ -4,6 +4,7 @@ <meta charset="UTF-8"> <meta name="viewport" content="width=device-width, initial-scale=1.0"> <link rel="stylesheet" 
type="text/css" href="../zen.css" /> + <script src="../util/sanitize.js"></script> <script src="../theme.js"></script> <script src="../banner.js" defer></script> <script src="../nav.js" defer></script> @@ -62,12 +63,6 @@ const BASE_URL = window.location.origin; const REFRESH_INTERVAL = 2000; - function escapeHtml(text) { - var div = document.createElement('div'); - div.textContent = text; - return div.innerHTML; - } - function showError(message) { document.getElementById('error-container').innerHTML = '<div class="error">Error: ' + escapeHtml(message) + '</div>'; diff --git a/src/zenserver/frontend/html/compute/orchestrator.html b/src/zenserver/frontend/html/compute/orchestrator.html index a519dee18..d1a2bb015 100644 --- a/src/zenserver/frontend/html/compute/orchestrator.html +++ b/src/zenserver/frontend/html/compute/orchestrator.html @@ -4,6 +4,7 @@ <meta charset="UTF-8"> <meta name="viewport" content="width=device-width, initial-scale=1.0"> <link rel="stylesheet" type="text/css" href="../zen.css" /> + <script src="../util/sanitize.js"></script> <script src="../theme.js"></script> <script src="../banner.js" defer></script> <script src="../nav.js" defer></script> @@ -128,12 +129,6 @@ const BASE_URL = window.location.origin; const REFRESH_INTERVAL = 2000; - function escapeHtml(text) { - var div = document.createElement('div'); - div.textContent = text; - return div.innerHTML; - } - function showError(message) { document.getElementById('error-container').innerHTML = '<div class="error">Error: ' + escapeHtml(message) + '</div>'; diff --git a/src/zenserver/frontend/html/util/sanitize.js b/src/zenserver/frontend/html/util/sanitize.js new file mode 100644 index 000000000..1b0f32e38 --- /dev/null +++ b/src/zenserver/frontend/html/util/sanitize.js @@ -0,0 +1,9 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +// Shared utility functions for compute dashboard pages. 
+ +function escapeHtml(text) { + var div = document.createElement('div'); + div.textContent = text; + return div.innerHTML; +} diff --git a/src/zenserver/hub/hub.cpp b/src/zenserver/hub/hub.cpp index 3c9f40eaa..2f3873884 100644 --- a/src/zenserver/hub/hub.cpp +++ b/src/zenserver/hub/hub.cpp @@ -20,6 +20,7 @@ ZEN_THIRD_PARTY_INCLUDES_END # include <zencore/testing.h> # include <zencore/testutils.h> # include <zencore/workthreadpool.h> +# include <zenhttp/httpclient.h> #endif #include <numeric> @@ -232,7 +233,8 @@ Hub::Provision(std::string_view ModuleId, HubProvisionedInstanceInfo& OutInfo, s .HydrationTempPath = m_HydrationTempPath, .FileHydrationPath = m_FileHydrationPath, .HttpThreadCount = m_Config.InstanceHttpThreadCount, - .CoreLimit = m_Config.InstanceCoreLimit}, + .CoreLimit = m_Config.InstanceCoreLimit, + .ConfigPath = m_Config.InstanceConfigPath}, ModuleId); #if ZEN_PLATFORM_WINDOWS if (m_JobObject.IsValid()) @@ -504,6 +506,45 @@ TEST_CASE("hub.provision_basic") CHECK_FALSE(HubInstance->Find("module_a")); } +TEST_CASE("hub.provision_config") +{ + ScopedTemporaryDirectory TempDir; + CreateDirectories(TempDir.Path() / "hub"); + + std::string LuaConfig = + "server = {\n" + " buildstore = {\n" + " enabled = true,\n" + " }\n" + "}\n"; + + WriteFile(TempDir.Path() / "config.lua", IoBuffer(IoBuffer::Wrap, LuaConfig.data(), LuaConfig.length())); + + std::unique_ptr<Hub> HubInstance = + hub_testutils::MakeHub(TempDir.Path() / "hub", Hub::Configuration{.InstanceConfigPath = TempDir.Path() / "config.lua"}); + + CHECK_EQ(HubInstance->GetInstanceCount(), 0); + CHECK_FALSE(HubInstance->Find("module_a")); + + HubProvisionedInstanceInfo Info; + std::string Reason; + const bool ProvisionResult = HubInstance->Provision("module_a", Info, Reason); + REQUIRE_MESSAGE(ProvisionResult, Reason); + CHECK_NE(Info.Port, 0); + CHECK_EQ(HubInstance->GetInstanceCount(), 1); + CHECK(HubInstance->Find("module_a")); + + HttpClient Client(fmt::format("http://127.0.0.1:{}{}", Info.Port, 
Info.BaseUri)); + HttpClient::Response TestResponse = Client.Get("/status/builds"); + CHECK(TestResponse.IsSuccess()); + CHECK(TestResponse.AsObject()["ok"].AsBool()); + + const bool DeprovisionResult = HubInstance->Deprovision("module_a", Reason); + CHECK(DeprovisionResult); + CHECK_EQ(HubInstance->GetInstanceCount(), 0); + CHECK_FALSE(HubInstance->Find("module_a")); +} + TEST_CASE("hub.provision_callbacks") { ScopedTemporaryDirectory TempDir; diff --git a/src/zenserver/hub/hub.h b/src/zenserver/hub/hub.h index 8a84a558b..78be3eda1 100644 --- a/src/zenserver/hub/hub.h +++ b/src/zenserver/hub/hub.h @@ -45,8 +45,9 @@ public: int InstanceLimit = 1000; - uint32_t InstanceHttpThreadCount = 0; // Deduce from core count - int InstanceCoreLimit = 0; // Use hardware core count + uint32_t InstanceHttpThreadCount = 0; // Deduce from core count + int InstanceCoreLimit = 0; // Use hardware core count + std::filesystem::path InstanceConfigPath; }; typedef std::function<void(std::string_view ModuleId, const HubProvisionedInstanceInfo& Info)> ProvisionModuleCallbackFunc; diff --git a/src/zenserver/hub/storageserverinstance.cpp b/src/zenserver/hub/storageserverinstance.cpp index 68de5e274..8e71e7aca 100644 --- a/src/zenserver/hub/storageserverinstance.cpp +++ b/src/zenserver/hub/storageserverinstance.cpp @@ -45,6 +45,10 @@ StorageServerInstance::SpawnServerProcess() { AdditionalOptions << " --corelimit=" << m_Config.CoreLimit; } + if (!m_Config.ConfigPath.empty()) + { + AdditionalOptions << " --config=\"" << MakeSafeAbsolutePath(m_Config.ConfigPath).string() << "\""; + } m_ServerInstance.SpawnServerAndWaitUntilReady(m_Config.BasePort, AdditionalOptions.ToView()); ZEN_DEBUG("Storage server instance for module '{}' started, listening on port {}", m_ModuleId, m_Config.BasePort); diff --git a/src/zenserver/hub/storageserverinstance.h b/src/zenserver/hub/storageserverinstance.h index 23196d835..3b3cae385 100644 --- a/src/zenserver/hub/storageserverinstance.h +++ 
b/src/zenserver/hub/storageserverinstance.h @@ -28,6 +28,7 @@ public: std::filesystem::path FileHydrationPath; uint32_t HttpThreadCount = 0; // Deduce from core count int CoreLimit = 0; // Use hardware core count + std::filesystem::path ConfigPath; }; StorageServerInstance(ZenServerEnvironment& RunEnvironment, const Configuration& Config, std::string_view ModuleId); diff --git a/src/zenserver/hub/zenhubserver.cpp b/src/zenserver/hub/zenhubserver.cpp index 313be977c..b36a0778e 100644 --- a/src/zenserver/hub/zenhubserver.cpp +++ b/src/zenserver/hub/zenhubserver.cpp @@ -94,6 +94,13 @@ ZenHubServerConfigurator::AddCliOptions(cxxopts::Options& Options) cxxopts::value(m_ServerOptions.HubInstanceCoreLimit), "<instance core limit>"); + Options.add_option("hub", + "", + "hub-instance-config", + "Path to Lua config file for provisioned instances", + cxxopts::value(m_ServerOptions.HubInstanceConfigPath), + "<instance config>"); + #if ZEN_PLATFORM_WINDOWS Options.add_option("hub", "", @@ -269,7 +276,8 @@ ZenHubServer::InitializeServices(const ZenHubServerConfig& ServerConfig) .BasePortNumber = ServerConfig.HubBasePortNumber, .InstanceLimit = ServerConfig.HubInstanceLimit, .InstanceHttpThreadCount = ServerConfig.HubInstanceHttpThreadCount, - .InstanceCoreLimit = ServerConfig.HubInstanceCoreLimit}, + .InstanceCoreLimit = ServerConfig.HubInstanceCoreLimit, + .InstanceConfigPath = ServerConfig.HubInstanceConfigPath}, ZenServerEnvironment(ZenServerEnvironment::Hub, ServerConfig.DataDir / "hub", ServerConfig.DataDir / "servers", diff --git a/src/zenserver/hub/zenhubserver.h b/src/zenserver/hub/zenhubserver.h index 1036598bb..7e85159f1 100644 --- a/src/zenserver/hub/zenhubserver.h +++ b/src/zenserver/hub/zenhubserver.h @@ -21,15 +21,16 @@ class HttpHubService; struct ZenHubServerConfig : public ZenServerConfig { - std::string UpstreamNotificationEndpoint; - std::string InstanceId; // For use in notifications - std::string ConsulEndpoint; // If set, enables Consul service registration 
- uint16_t HubBasePortNumber = 21000; - int HubInstanceLimit = 1000; - bool HubUseJobObject = true; - std::string HubInstanceHttpClass = "asio"; - uint32_t HubInstanceHttpThreadCount = 0; // Deduce from core count - int HubInstanceCoreLimit = 0; // Use hardware core count + std::string UpstreamNotificationEndpoint; + std::string InstanceId; // For use in notifications + std::string ConsulEndpoint; // If set, enables Consul service registration + uint16_t HubBasePortNumber = 21000; + int HubInstanceLimit = 1000; + bool HubUseJobObject = true; + std::string HubInstanceHttpClass = "asio"; + uint32_t HubInstanceHttpThreadCount = 0; // Deduce from core count + int HubInstanceCoreLimit = 0; // Use hardware core count + std::filesystem::path HubInstanceConfigPath; // Path to Lua config file }; class Hub; diff --git a/src/zenserver/main.cpp b/src/zenserver/main.cpp index 26ae85ae1..9d786c209 100644 --- a/src/zenserver/main.cpp +++ b/src/zenserver/main.cpp @@ -319,6 +319,11 @@ main(int argc, char* argv[]) { ServerMode = kTest; } + else if (argv[1][0] != '-') + { + fprintf(stderr, "unknown mode '%s'. 
Available modes: hub, store, compute, proxy, test\n", argv[1]); + return 1; + } } switch (ServerMode) diff --git a/src/zenserver/xmake.lua b/src/zenserver/xmake.lua index fe279ebb2..aa306190f 100644 --- a/src/zenserver/xmake.lua +++ b/src/zenserver/xmake.lua @@ -19,7 +19,7 @@ target("zenserver") add_headerfiles("**.h") add_rules("utils.bin2c", {extensions = {".zip"}}) add_files("**.cpp") - add_files("frontend/html.zip") + add_files(path.join(os.projectdir(), get_config("builddir") or get_config("buildir") or "build", "frontend/html.zip")) add_files("zenserver.cpp", {unity_ignored = true }) if is_plat("linux") and not (get_config("toolchain") or ""):find("clang") then @@ -36,6 +36,7 @@ target("zenserver") add_packages("json11") add_packages("lua") add_packages("consul") + add_packages("minio") add_packages("oidctoken") add_packages("nomad") @@ -84,7 +85,8 @@ target("zenserver") on_load(function(target) local html_dir = path.join(os.projectdir(), "src/zenserver/frontend/html") - local zip_path = path.join(os.projectdir(), "src/zenserver/frontend/html.zip") + local zip_dir = path.join(os.projectdir(), get_config("buildir") or "build", "frontend") + local zip_path = path.join(zip_dir, "html.zip") -- Check if zip needs regeneration local need_update = not os.isfile(zip_path) @@ -100,18 +102,19 @@ target("zenserver") if need_update then print("Regenerating frontend zip...") + os.mkdir(zip_dir) os.tryrm(zip_path) import("detect.tools.find_7z") local cmd_7z = find_7z() if cmd_7z then - os.execv(cmd_7z, {"a", "-mx0", zip_path, path.join(html_dir, ".")}) + os.execv(cmd_7z, {"a", "-mx0", "-bso0", zip_path, path.join(html_dir, ".")}) else import("detect.tools.find_zip") local zip_cmd = find_zip() if zip_cmd then local oldir = os.cd(html_dir) - os.execv(zip_cmd, {"-r", "-0", zip_path, "."}) + os.execv(zip_cmd, {"-r", "-0", "-q", zip_path, "."}) os.cd(oldir) else raise("Unable to find a suitable zip tool (need 7z or zip)") @@ -213,6 +216,16 @@ target("zenserver") 
copy_if_newer(path.join(installdir, "bin", consul_bin), path.join(target:targetdir(), consul_bin), consul_bin) end + local minio_pkg = target:pkg("minio") + if minio_pkg then + local installdir = minio_pkg:installdir() + local minio_bin = "minio" + if is_plat("windows") then + minio_bin = "minio.exe" + end + copy_if_newer(path.join(installdir, "bin", minio_bin), path.join(target:targetdir(), minio_bin), minio_bin) + end + local oidctoken_pkg = target:pkg("oidctoken") if oidctoken_pkg then local installdir = oidctoken_pkg:installdir() diff --git a/src/zenserver/zenserver.cpp b/src/zenserver/zenserver.cpp index 1cd8ed846..fe6a5a572 100644 --- a/src/zenserver/zenserver.cpp +++ b/src/zenserver/zenserver.cpp @@ -631,6 +631,10 @@ ZenServerMain::Run() uint32_t AttachSponsorProcessRetriesLeft = 3; ZenServerState::ZenServerEntry* Entry = ServerState.Lookup(m_ServerOptions.BasePort); + // NOTE: ZEN_CONSOLE_WARN/INFO is used in this block and the lock file block below + // (instead of ZEN_WARN/INFO) because InitializeLogging() has not been called yet at + // this point. ZEN_WARN/INFO would silently discard messages before the logging system + // is initialized. while (Entry) { if (m_ServerOptions.OwnerPid) @@ -640,27 +644,29 @@ ZenServerMain::Run() { if (Ec) { - ZEN_WARN(ZEN_APP_NAME - " exiting, sponsor owner pid {} can not be checked for running state, reason: '{}'. Will not add sponsor " - "to process " - "listening to port {} (pid: {})", - m_ServerOptions.OwnerPid, - Ec.message(), - m_ServerOptions.BasePort, - Entry->Pid.load()); + ZEN_CONSOLE_WARN( + ZEN_APP_NAME + " exiting, sponsor owner pid {} can not be checked for running state, reason: '{}'. 
Will not add sponsor " + "to process " + "listening to port {} (pid: {})", + m_ServerOptions.OwnerPid, + Ec.message(), + m_ServerOptions.BasePort, + Entry->Pid.load()); } else { - ZEN_WARN(ZEN_APP_NAME - " exiting, sponsor owner pid {} is no longer running, will not add sponsor to process listening to port " - "{} (pid: {})", - m_ServerOptions.OwnerPid, - m_ServerOptions.BasePort, - Entry->Pid.load()); + ZEN_CONSOLE_WARN( + ZEN_APP_NAME + " exiting, sponsor owner pid {} is no longer running, will not add sponsor to process listening to port " + "{} (pid: {})", + m_ServerOptions.OwnerPid, + m_ServerOptions.BasePort, + Entry->Pid.load()); } std::exit(1); } - ZEN_INFO( + ZEN_CONSOLE_INFO( "Looks like there is already a process listening to this port {} (pid: {}), attaching owner pid {} to running instance", m_ServerOptions.BasePort, Entry->Pid.load(), @@ -678,18 +684,18 @@ ZenServerMain::Run() } else { - ZEN_WARN(ZEN_APP_NAME " exiting, failed to add sponsor owner pid {} to process listening to port {} (pid: {})", - m_ServerOptions.OwnerPid, - m_ServerOptions.BasePort, - Entry->Pid.load()); + ZEN_CONSOLE_WARN(ZEN_APP_NAME " exiting, failed to add sponsor owner pid {} to process listening to port {} (pid: {})", + m_ServerOptions.OwnerPid, + m_ServerOptions.BasePort, + Entry->Pid.load()); std::exit(1); } } else { - ZEN_WARN(ZEN_APP_NAME " exiting, there is already a process listening to port {} (pid: {})", - m_ServerOptions.BasePort, - Entry->Pid.load()); + ZEN_CONSOLE_WARN(ZEN_APP_NAME " exiting, there is already a process listening to port {} (pid: {})", + m_ServerOptions.BasePort, + Entry->Pid.load()); std::exit(1); } } @@ -702,19 +708,19 @@ ZenServerMain::Run() if (Ec) { - ZEN_INFO(ZEN_APP_NAME " unable to grab lock at '{}' (reason: '{}'), retrying", LockFilePath, Ec.message()); + ZEN_CONSOLE_INFO(ZEN_APP_NAME " unable to grab lock at '{}' (reason: '{}'), retrying", LockFilePath, Ec.message()); Sleep(100); m_LockFile.Create(LockFilePath, MakeLockData(false), Ec); if 
(Ec) { - ZEN_INFO(ZEN_APP_NAME " unable to grab lock at '{}' (reason: '{}'), retrying", LockFilePath, Ec.message()); + ZEN_CONSOLE_INFO(ZEN_APP_NAME " unable to grab lock at '{}' (reason: '{}'), retrying", LockFilePath, Ec.message()); Sleep(500); m_LockFile.Create(LockFilePath, MakeLockData(false), Ec); if (Ec) { - ZEN_WARN(ZEN_APP_NAME " exiting, unable to grab lock at '{}' (reason: '{}')", LockFilePath, Ec.message()); + ZEN_CONSOLE_WARN(ZEN_APP_NAME " exiting, unable to grab lock at '{}' (reason: '{}')", LockFilePath, Ec.message()); std::exit(99); } } @@ -736,6 +742,12 @@ ZenServerMain::Run() Entry = ServerState.Register(m_ServerOptions.BasePort); + if (!Entry) + { + throw std::runtime_error( + fmt::format("Failed to register server on port {} in shared state (all slots occupied)", m_ServerOptions.BasePort)); + } + // Publish per-instance extended info (e.g. UDS path) via a small shared memory // section keyed by SessionId so clients can discover it during Snapshot() enumeration. { @@ -762,22 +774,22 @@ ZenServerMain::Run() } catch (const AssertException& AssertEx) { - ZEN_CRITICAL(ZEN_APP_NAME " caught assert exception in main for process {}: {}", - zen::GetCurrentProcessId(), - AssertEx.FullDescription()); + ZEN_CONSOLE_CRITICAL(ZEN_APP_NAME " caught assert exception in main for process {}: {}", + zen::GetCurrentProcessId(), + AssertEx.FullDescription()); RequestApplicationExit(1); } catch (const std::system_error& e) { - ZEN_CRITICAL(ZEN_APP_NAME " caught system error exception in main for process {}: {} ({})", - zen::GetCurrentProcessId(), - e.what(), - e.code().value()); + ZEN_CONSOLE_CRITICAL(ZEN_APP_NAME " caught system error exception in main for process {}: {} ({})", + zen::GetCurrentProcessId(), + e.what(), + e.code().value()); RequestApplicationExit(1); } catch (const std::exception& e) { - ZEN_CRITICAL(ZEN_APP_NAME " caught exception in main for process {}: {}", zen::GetCurrentProcessId(), e.what()); + ZEN_CONSOLE_CRITICAL(ZEN_APP_NAME " caught 
exception in main for process {}: {}", zen::GetCurrentProcessId(), e.what()); RequestApplicationExit(1); } diff --git a/src/zenutil/cloud/cloudprovider.cpp b/src/zenutil/cloud/cloudprovider.cpp new file mode 100644 index 000000000..e32a50c64 --- /dev/null +++ b/src/zenutil/cloud/cloudprovider.cpp @@ -0,0 +1,23 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zenutil/cloud/cloudprovider.h> + +namespace zen::compute { + +std::string_view +ToString(CloudProvider Provider) +{ + switch (Provider) + { + case CloudProvider::AWS: + return "AWS"; + case CloudProvider::Azure: + return "Azure"; + case CloudProvider::GCP: + return "GCP"; + default: + return "None"; + } +} + +} // namespace zen::compute diff --git a/src/zenutil/cloud/imdscredentials.cpp b/src/zenutil/cloud/imdscredentials.cpp new file mode 100644 index 000000000..dde1dc019 --- /dev/null +++ b/src/zenutil/cloud/imdscredentials.cpp @@ -0,0 +1,387 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zenutil/cloud/imdscredentials.h> + +#include <zenutil/cloud/mockimds.h> + +#include <zencore/string.h> +#include <zencore/testing.h> +#include <zencore/testutils.h> +#include <zenhttp/httpserver.h> + +#include <thread> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <fmt/format.h> +#include <json11.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + +namespace zen { + +namespace { + + /// Margin before expiration at which we proactively refresh credentials. + constexpr auto kRefreshMargin = std::chrono::minutes(5); + + /// Parse an ISO 8601 UTC timestamp (e.g. "2026-03-14T20:00:00Z") into a system_clock time_point. + /// Returns epoch on failure. 
+ std::chrono::system_clock::time_point ParseIso8601(std::string_view Timestamp) + { + // Expected format: YYYY-MM-DDTHH:MM:SSZ + if (Timestamp.size() < 19) + { + return {}; + } + + std::tm Tm = {}; + // Manual parse since std::get_time is locale-dependent + Tm.tm_year = ParseInt<int>(Timestamp.substr(0, 4)).value_or(1970) - 1900; + Tm.tm_mon = ParseInt<int>(Timestamp.substr(5, 2)).value_or(1) - 1; + Tm.tm_mday = ParseInt<int>(Timestamp.substr(8, 2)).value_or(1); + Tm.tm_hour = ParseInt<int>(Timestamp.substr(11, 2)).value_or(0); + Tm.tm_min = ParseInt<int>(Timestamp.substr(14, 2)).value_or(0); + Tm.tm_sec = ParseInt<int>(Timestamp.substr(17, 2)).value_or(0); + +#if ZEN_PLATFORM_WINDOWS + time_t EpochSeconds = _mkgmtime(&Tm); +#else + time_t EpochSeconds = timegm(&Tm); +#endif + if (EpochSeconds == -1) + { + return {}; + } + + return std::chrono::system_clock::from_time_t(EpochSeconds); + } + +} // namespace + +ImdsCredentialProvider::ImdsCredentialProvider(const ImdsCredentialProviderOptions& Options) +: m_Log(logging::Get("imds")) +, m_HttpClient(Options.Endpoint, + HttpClientSettings{ + .LogCategory = "imds", + .ConnectTimeout = Options.ConnectTimeout, + .Timeout = Options.RequestTimeout, + }) +{ + ZEN_INFO("IMDS credential provider configured (endpoint: {})", m_HttpClient.GetBaseUri()); +} + +ImdsCredentialProvider::~ImdsCredentialProvider() = default; + +SigV4Credentials +ImdsCredentialProvider::GetCredentials() +{ + // Fast path: shared lock for cache hit + { + RwLock::SharedLockScope SharedLock(m_Lock); + if (!m_CachedCredentials.AccessKeyId.empty() && std::chrono::steady_clock::now() < m_ExpiresAt) + { + return m_CachedCredentials; + } + } + + // Slow path: exclusive lock to refresh + RwLock::ExclusiveLockScope ExclusiveLock(m_Lock); + + // Double-check after acquiring exclusive lock + if (!m_CachedCredentials.AccessKeyId.empty() && std::chrono::steady_clock::now() < m_ExpiresAt) + { + return m_CachedCredentials; + } + + if (!FetchCredentials()) + { + 
ZEN_WARN("failed to fetch credentials from IMDS"); + return {}; + } + + return m_CachedCredentials; +} + +void +ImdsCredentialProvider::InvalidateCache() +{ + RwLock::ExclusiveLockScope ExclusiveLock(m_Lock); + m_CachedCredentials = {}; + m_ExpiresAt = {}; +} + +bool +ImdsCredentialProvider::FetchToken() +{ + HttpClient::KeyValueMap Headers; + Headers->emplace("X-aws-ec2-metadata-token-ttl-seconds", "21600"); + + HttpClient::Response Response = m_HttpClient.Put("/latest/api/token", Headers); + if (!Response.IsSuccess()) + { + ZEN_WARN("IMDS token request failed: {}", Response.ErrorMessage("PUT /latest/api/token")); + return false; + } + + m_ImdsToken = std::string(Response.AsText()); + if (m_ImdsToken.empty()) + { + ZEN_WARN("IMDS returned empty token"); + return false; + } + + return true; +} + +bool +ImdsCredentialProvider::FetchCredentials() +{ + // Step 1: Get IMDSv2 session token + if (!FetchToken()) + { + return false; + } + + HttpClient::KeyValueMap TokenHeader; + TokenHeader->emplace("X-aws-ec2-metadata-token", m_ImdsToken); + + // Step 2: Discover IAM role name (if not already known) + if (m_RoleName.empty()) + { + HttpClient::Response RoleResponse = m_HttpClient.Get("/latest/meta-data/iam/security-credentials/", TokenHeader); + if (!RoleResponse.IsSuccess()) + { + ZEN_WARN("IMDS role discovery failed: {}", RoleResponse.ErrorMessage("GET iam/security-credentials/")); + return false; + } + + m_RoleName = std::string(RoleResponse.AsText()); + // Trim any trailing whitespace/newlines + while (!m_RoleName.empty() && (m_RoleName.back() == '\n' || m_RoleName.back() == '\r' || m_RoleName.back() == ' ')) + { + m_RoleName.pop_back(); + } + + if (m_RoleName.empty()) + { + ZEN_WARN("IMDS returned empty IAM role name"); + return false; + } + + ZEN_INFO("IMDS discovered IAM role: {}", m_RoleName); + } + + // Step 3: Fetch credentials for the role + std::string CredentialPath = fmt::format("/latest/meta-data/iam/security-credentials/{}", m_RoleName); + + 
HttpClient::Response CredResponse = m_HttpClient.Get(CredentialPath, TokenHeader); + if (!CredResponse.IsSuccess()) + { + ZEN_WARN("IMDS credential fetch failed: {}", CredResponse.ErrorMessage("GET iam/security-credentials/" + m_RoleName)); + return false; + } + + // Step 4: Parse JSON response + std::string JsonError; + const json11::Json Json = json11::Json::parse(std::string(CredResponse.AsText()), JsonError); + + if (!JsonError.empty()) + { + ZEN_WARN("IMDS credential response JSON parse error: {}", JsonError); + return false; + } + + std::string AccessKeyId = Json["AccessKeyId"].string_value(); + std::string SecretAccessKey = Json["SecretAccessKey"].string_value(); + std::string SessionToken = Json["Token"].string_value(); + std::string Expiration = Json["Expiration"].string_value(); + + if (AccessKeyId.empty() || SecretAccessKey.empty()) + { + ZEN_WARN("IMDS credential response missing AccessKeyId or SecretAccessKey"); + return false; + } + + // Compute local expiration time based on the Expiration field + auto ExpirationTime = ParseIso8601(Expiration); + auto Now = std::chrono::system_clock::now(); + + std::chrono::steady_clock::time_point NewExpiresAt; + if (ExpirationTime > Now) + { + auto TimeUntilExpiry = ExpirationTime - Now; + NewExpiresAt = std::chrono::steady_clock::now() + TimeUntilExpiry - kRefreshMargin; + } + else + { + // Expiration is in the past or unparseable — force refresh next time + NewExpiresAt = std::chrono::steady_clock::now(); + } + + bool KeyChanged = (m_CachedCredentials.AccessKeyId != AccessKeyId); + + m_CachedCredentials.AccessKeyId = std::move(AccessKeyId); + m_CachedCredentials.SecretAccessKey = std::move(SecretAccessKey); + m_CachedCredentials.SessionToken = std::move(SessionToken); + m_ExpiresAt = NewExpiresAt; + + if (KeyChanged) + { + ZEN_INFO("IMDS credentials refreshed (AccessKeyId: {}...)", m_CachedCredentials.AccessKeyId.substr(0, 8)); + } + else + { + ZEN_DEBUG("IMDS credentials refreshed (unchanged key)"); + } + + 
return true; +} + +////////////////////////////////////////////////////////////////////////// +// Tests + +#if ZEN_WITH_TESTS + +void +imdscredentials_forcelink() +{ +} + +TEST_SUITE_BEGIN("util.cloud.imdscredentials"); + +TEST_CASE("imdscredentials.parse_iso8601") +{ + // Verify basic ISO 8601 parsing + auto Tp = ParseIso8601("2026-03-14T20:00:00Z"); + CHECK(Tp != std::chrono::system_clock::time_point{}); + + auto Epoch = std::chrono::system_clock::to_time_t(Tp); + std::tm Tm; +# if ZEN_PLATFORM_WINDOWS + gmtime_s(&Tm, &Epoch); +# else + gmtime_r(&Epoch, &Tm); +# endif + CHECK(Tm.tm_year + 1900 == 2026); + CHECK(Tm.tm_mon + 1 == 3); + CHECK(Tm.tm_mday == 14); + CHECK(Tm.tm_hour == 20); + CHECK(Tm.tm_min == 0); + CHECK(Tm.tm_sec == 0); + + // Invalid input + auto Bad = ParseIso8601("bad"); + CHECK(Bad == std::chrono::system_clock::time_point{}); +} + +// --------------------------------------------------------------------------- +// Integration test with mock IMDS server +// --------------------------------------------------------------------------- + +struct TestImdsServer +{ + compute::MockImdsService Mock; + + void Start() + { + m_TmpDir.emplace(); + m_Server = CreateHttpServer(HttpServerConfig{.ServerClass = "asio"}); + m_Port = m_Server->Initialize(7576, m_TmpDir->Path() / "http"); + REQUIRE(m_Port != -1); + m_Server->RegisterService(Mock); + m_ServerThread = std::thread([this]() { m_Server->Run(false); }); + } + + std::string Endpoint() const { return fmt::format("http://127.0.0.1:{}", m_Port); } + + ~TestImdsServer() + { + if (m_Server) + { + m_Server->RequestExit(); + } + if (m_ServerThread.joinable()) + { + m_ServerThread.join(); + } + if (m_Server) + { + m_Server->Close(); + } + } + +private: + std::optional<ScopedTemporaryDirectory> m_TmpDir; + Ref<HttpServer> m_Server; + std::thread m_ServerThread; + int m_Port = -1; +}; + +TEST_CASE("imdscredentials.fetch_from_mock") +{ + TestImdsServer Imds; + Imds.Mock.ActiveProvider = compute::CloudProvider::AWS; + 
Imds.Start(); + + ImdsCredentialProviderOptions Opts; + Opts.Endpoint = Imds.Endpoint(); + + Ref<ImdsCredentialProvider> Provider(new ImdsCredentialProvider(Opts)); + + SUBCASE("basic_credential_fetch") + { + SigV4Credentials Creds = Provider->GetCredentials(); + CHECK(!Creds.AccessKeyId.empty()); + CHECK(Creds.AccessKeyId == "ASIAIOSFODNN7EXAMPLE"); + CHECK(Creds.SecretAccessKey == "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"); + CHECK(Creds.SessionToken == "FwoGZXIvYXdzEBYaDEXAMPLETOKEN"); + } + + SUBCASE("credentials_are_cached") + { + SigV4Credentials First = Provider->GetCredentials(); + SigV4Credentials Second = Provider->GetCredentials(); + CHECK(First.AccessKeyId == Second.AccessKeyId); + CHECK(First.SecretAccessKey == Second.SecretAccessKey); + } + + SUBCASE("invalidate_forces_refresh") + { + SigV4Credentials First = Provider->GetCredentials(); + CHECK(!First.AccessKeyId.empty()); + + // Change the credentials on the mock + Imds.Mock.Aws.IamAccessKeyId = "ASIANEWKEYEXAMPLE12"; + + Provider->InvalidateCache(); + SigV4Credentials Second = Provider->GetCredentials(); + CHECK(Second.AccessKeyId == "ASIANEWKEYEXAMPLE12"); + } + + SUBCASE("custom_role_name") + { + Imds.Mock.Aws.IamRoleName = "my-custom-role"; + + Ref<ImdsCredentialProvider> Provider2(new ImdsCredentialProvider(Opts)); + SigV4Credentials Creds = Provider2->GetCredentials(); + CHECK(!Creds.AccessKeyId.empty()); + } +} + +TEST_CASE("imdscredentials.unreachable_endpoint") +{ + // Point at a non-existent server — should return empty credentials, not crash + ImdsCredentialProviderOptions Opts; + Opts.Endpoint = "http://127.0.0.1:1"; // unlikely to have anything listening + Opts.ConnectTimeout = std::chrono::milliseconds(100); + Opts.RequestTimeout = std::chrono::milliseconds(200); + + Ref<ImdsCredentialProvider> Provider(new ImdsCredentialProvider(Opts)); + SigV4Credentials Creds = Provider->GetCredentials(); + CHECK(Creds.AccessKeyId.empty()); +} + +TEST_SUITE_END(); + +#endif + +} // namespace zen 
diff --git a/src/zenutil/cloud/minioprocess.cpp b/src/zenutil/cloud/minioprocess.cpp new file mode 100644 index 000000000..565705731 --- /dev/null +++ b/src/zenutil/cloud/minioprocess.cpp @@ -0,0 +1,174 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zenutil/cloud/minioprocess.h> + +#include <zencore/filesystem.h> +#include <zencore/fmtutils.h> +#include <zencore/logging.h> +#include <zencore/process.h> +#include <zencore/timer.h> +#include <zenhttp/httpclient.h> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <fmt/format.h> +ZEN_THIRD_PARTY_INCLUDES_END + +namespace zen { + +struct MinioProcess::Impl +{ + Impl(const MinioProcessOptions& Options) : m_Options(Options), m_HttpClient(fmt::format("http://localhost:{}/", Options.Port)) {} + ~Impl() = default; + + void SpawnMinioServer() + { + if (m_ProcessHandle.IsValid()) + { + return; + } + + // Create a clean temp data directory, removing any stale data from a previous run + std::error_code Ec; + m_DataDir = std::filesystem::temp_directory_path(Ec) / fmt::format("zen-minio-{}", GetCurrentProcessId()); + if (Ec) + { + ZEN_WARN("MinIO: Failed to get temp directory: {}", Ec.message()); + return; + } + std::filesystem::remove_all(m_DataDir, Ec); + Ec.clear(); + std::filesystem::create_directories(m_DataDir, Ec); + if (Ec) + { + ZEN_WARN("MinIO: Failed to create data directory '{}': {}", m_DataDir.string(), Ec.message()); + return; + } + + CreateProcOptions Options; + Options.Flags |= CreateProcOptions::Flag_Windows_NewProcessGroup; + Options.Environment.emplace_back("MINIO_ROOT_USER", m_Options.RootUser); + Options.Environment.emplace_back("MINIO_ROOT_PASSWORD", m_Options.RootPassword); + + const std::filesystem::path MinioExe = GetRunningExecutablePath().parent_path() / ("minio" ZEN_EXE_SUFFIX_LITERAL); + + std::string CommandLine = + fmt::format("minio" ZEN_EXE_SUFFIX_LITERAL " server {} --address :{} --quiet", m_DataDir.string(), m_Options.Port); + + CreateProcResult Result = CreateProc(MinioExe, 
CommandLine, Options); + + if (Result) + { + m_ProcessHandle.Initialize(Result); + + Stopwatch Timer; + + // Poll to check when the server is ready + do + { + Sleep(100); + HttpClient::Response Resp = m_HttpClient.Get("minio/health/live"); + if (Resp) + { + ZEN_INFO("MinIO server started successfully (waited {})", NiceTimeSpanMs(Timer.GetElapsedTimeMs())); + return; + } + } while (Timer.GetElapsedTimeMs() < 10000); + } + + // Report failure + ZEN_WARN("MinIO server failed to start within timeout period"); + } + + void StopMinioServer() + { + if (!m_ProcessHandle.IsValid()) + { + return; + } + + m_ProcessHandle.Kill(); + + // Clean up temp data directory + std::error_code Ec; + std::filesystem::remove_all(m_DataDir, Ec); + if (Ec) + { + ZEN_WARN("MinIO: Failed to clean up data directory '{}': {}", m_DataDir.string(), Ec.message()); + } + } + + void CreateBucket(std::string_view BucketName) + { + if (m_DataDir.empty()) + { + ZEN_WARN("MinIO: Cannot create bucket before data directory is initialized — call SpawnMinioServer() first"); + return; + } + + std::filesystem::path BucketDir = m_DataDir / std::string(BucketName); + std::error_code Ec; + std::filesystem::create_directories(BucketDir, Ec); + if (Ec) + { + ZEN_WARN("MinIO: Failed to create bucket directory '{}': {}", BucketDir.string(), Ec.message()); + } + } + + MinioProcessOptions m_Options; + ProcessHandle m_ProcessHandle; + HttpClient m_HttpClient; + std::filesystem::path m_DataDir; +}; + +MinioProcess::MinioProcess(const MinioProcessOptions& Options) : m_Impl(std::make_unique<Impl>(Options)) +{ +} + +MinioProcess::~MinioProcess() +{ + m_Impl->StopMinioServer(); +} + +void +MinioProcess::SpawnMinioServer() +{ + m_Impl->SpawnMinioServer(); +} + +void +MinioProcess::StopMinioServer() +{ + m_Impl->StopMinioServer(); +} + +void +MinioProcess::CreateBucket(std::string_view BucketName) +{ + m_Impl->CreateBucket(BucketName); +} + +uint16_t +MinioProcess::Port() const +{ + return m_Impl->m_Options.Port; +} + 
+std::string_view +MinioProcess::RootUser() const +{ + return m_Impl->m_Options.RootUser; +} + +std::string_view +MinioProcess::RootPassword() const +{ + return m_Impl->m_Options.RootPassword; +} + +std::string +MinioProcess::Endpoint() const +{ + return fmt::format("http://localhost:{}", m_Impl->m_Options.Port); +} + +} // namespace zen diff --git a/src/zenutil/cloud/mockimds.cpp b/src/zenutil/cloud/mockimds.cpp new file mode 100644 index 000000000..6919fab4d --- /dev/null +++ b/src/zenutil/cloud/mockimds.cpp @@ -0,0 +1,237 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zenutil/cloud/mockimds.h> + +#include <zencore/fmtutils.h> + +#if ZEN_WITH_TESTS + +namespace zen::compute { + +const char* +MockImdsService::BaseUri() const +{ + return "/"; +} + +void +MockImdsService::HandleRequest(HttpServerRequest& Request) +{ + std::string_view Uri = Request.RelativeUri(); + + // AWS endpoints live under /latest/ + if (Uri.starts_with("latest/")) + { + if (ActiveProvider == CloudProvider::AWS) + { + HandleAwsRequest(Request); + return; + } + Request.WriteResponse(HttpResponseCode::NotFound); + return; + } + + // Azure endpoints live under /metadata/ + if (Uri.starts_with("metadata/")) + { + if (ActiveProvider == CloudProvider::Azure) + { + HandleAzureRequest(Request); + return; + } + Request.WriteResponse(HttpResponseCode::NotFound); + return; + } + + // GCP endpoints live under /computeMetadata/ + if (Uri.starts_with("computeMetadata/")) + { + if (ActiveProvider == CloudProvider::GCP) + { + HandleGcpRequest(Request); + return; + } + Request.WriteResponse(HttpResponseCode::NotFound); + return; + } + + Request.WriteResponse(HttpResponseCode::NotFound); +} + +// --------------------------------------------------------------------------- +// AWS +// --------------------------------------------------------------------------- + +void +MockImdsService::HandleAwsRequest(HttpServerRequest& Request) +{ + std::string_view Uri = Request.RelativeUri(); + + // IMDSv2 
token acquisition (PUT only) + if (Uri == "latest/api/token" && Request.RequestVerb() == HttpVerb::kPut) + { + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Aws.Token); + return; + } + + // Instance identity + if (Uri == "latest/meta-data/instance-id") + { + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Aws.InstanceId); + return; + } + + if (Uri == "latest/meta-data/placement/availability-zone") + { + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Aws.AvailabilityZone); + return; + } + + if (Uri == "latest/meta-data/instance-life-cycle") + { + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Aws.LifeCycle); + return; + } + + // Autoscaling lifecycle state — 404 when not in an ASG + if (Uri == "latest/meta-data/autoscaling/target-lifecycle-state") + { + if (Aws.AutoscalingState.empty()) + { + Request.WriteResponse(HttpResponseCode::NotFound); + return; + } + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Aws.AutoscalingState); + return; + } + + // Spot interruption notice — 404 when no interruption pending + if (Uri == "latest/meta-data/spot/instance-action") + { + if (Aws.SpotAction.empty()) + { + Request.WriteResponse(HttpResponseCode::NotFound); + return; + } + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Aws.SpotAction); + return; + } + + // IAM role discovery — returns the role name + if (Uri == "latest/meta-data/iam/security-credentials/") + { + if (Aws.IamRoleName.empty()) + { + Request.WriteResponse(HttpResponseCode::NotFound); + return; + } + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Aws.IamRoleName); + return; + } + + // IAM credentials for a specific role + constexpr std::string_view kIamCredPrefix = "latest/meta-data/iam/security-credentials/"; + if (Uri.starts_with(kIamCredPrefix) && Uri.size() > kIamCredPrefix.size()) + { + std::string_view RequestedRole = Uri.substr(kIamCredPrefix.size()); + if 
(RequestedRole == Aws.IamRoleName) + { + std::string Json = + fmt::format(R"({{"Code":"Success","AccessKeyId":"{}","SecretAccessKey":"{}","Token":"{}","Expiration":"{}"}})", + Aws.IamAccessKeyId, + Aws.IamSecretAccessKey, + Aws.IamSessionToken, + Aws.IamExpiration); + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Json); + return; + } + Request.WriteResponse(HttpResponseCode::NotFound); + return; + } + + Request.WriteResponse(HttpResponseCode::NotFound); +} + +// --------------------------------------------------------------------------- +// Azure +// --------------------------------------------------------------------------- + +void +MockImdsService::HandleAzureRequest(HttpServerRequest& Request) +{ + std::string_view Uri = Request.RelativeUri(); + + // Instance metadata (single JSON document) + if (Uri == "metadata/instance") + { + std::string Json = fmt::format(R"({{"compute":{{"vmId":"{}","location":"{}","priority":"{}","vmScaleSetName":"{}"}}}})", + Azure.VmId, + Azure.Location, + Azure.Priority, + Azure.VmScaleSetName); + + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Json); + return; + } + + // Scheduled events for termination monitoring + if (Uri == "metadata/scheduledevents") + { + std::string Json; + if (Azure.ScheduledEventType.empty()) + { + Json = R"({"Events":[]})"; + } + else + { + Json = fmt::format(R"({{"Events":[{{"EventType":"{}","EventStatus":"{}"}}]}})", + Azure.ScheduledEventType, + Azure.ScheduledEventStatus); + } + + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Json); + return; + } + + Request.WriteResponse(HttpResponseCode::NotFound); +} + +// --------------------------------------------------------------------------- +// GCP +// --------------------------------------------------------------------------- + +void +MockImdsService::HandleGcpRequest(HttpServerRequest& Request) +{ + std::string_view Uri = Request.RelativeUri(); + + if (Uri == "computeMetadata/v1/instance/id") + 
{ + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Gcp.InstanceId); + return; + } + + if (Uri == "computeMetadata/v1/instance/zone") + { + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Gcp.Zone); + return; + } + + if (Uri == "computeMetadata/v1/instance/scheduling/preemptible") + { + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Gcp.Preemptible); + return; + } + + if (Uri == "computeMetadata/v1/instance/maintenance-event") + { + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Gcp.MaintenanceEvent); + return; + } + + Request.WriteResponse(HttpResponseCode::NotFound); +} + +} // namespace zen::compute + +#endif // ZEN_WITH_TESTS diff --git a/src/zenutil/cloud/s3client.cpp b/src/zenutil/cloud/s3client.cpp new file mode 100644 index 000000000..88d844b61 --- /dev/null +++ b/src/zenutil/cloud/s3client.cpp @@ -0,0 +1,986 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zenutil/cloud/s3client.h> + +#include <zenutil/cloud/imdscredentials.h> +#include <zenutil/cloud/minioprocess.h> + +#include <zencore/except_fmt.h> +#include <zencore/iobuffer.h> +#include <zencore/memoryview.h> +#include <zencore/string.h> +#include <zencore/testing.h> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <fmt/format.h> +ZEN_THIRD_PARTY_INCLUDES_END + +#include <algorithm> + +namespace zen { + +namespace { + + /// The SHA-256 hash of an empty payload, precomputed + constexpr std::string_view EmptyPayloadHash = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"; + + /// Simple XML value extractor. Finds the text content between <Tag> and </Tag>. + /// This is intentionally minimal - we only need to parse ListBucketResult responses. + /// Returns a string_view into the original XML when no entity decoding is needed. 
+ std::string_view ExtractXmlValue(std::string_view Xml, std::string_view Tag) + { + std::string OpenTag = fmt::format("<{}>", Tag); + std::string CloseTag = fmt::format("</{}>", Tag); + + size_t Start = Xml.find(OpenTag); + if (Start == std::string_view::npos) + { + return {}; + } + Start += OpenTag.size(); + + size_t End = Xml.find(CloseTag, Start); + if (End == std::string_view::npos) + { + return {}; + } + + return Xml.substr(Start, End - Start); + } + + /// Decode the five standard XML entities (& < > " ') into a StringBuilderBase. + void DecodeXmlEntities(std::string_view Input, StringBuilderBase& Out) + { + if (Input.find('&') == std::string_view::npos) + { + Out.Append(Input); + return; + } + + for (size_t i = 0; i < Input.size(); ++i) + { + if (Input[i] == '&') + { + std::string_view Remaining = Input.substr(i); + if (Remaining.starts_with("&")) + { + Out.Append('&'); + i += 4; + } + else if (Remaining.starts_with("<")) + { + Out.Append('<'); + i += 3; + } + else if (Remaining.starts_with(">")) + { + Out.Append('>'); + i += 3; + } + else if (Remaining.starts_with(""")) + { + Out.Append('"'); + i += 5; + } + else if (Remaining.starts_with("'")) + { + Out.Append('\''); + i += 5; + } + else + { + Out.Append(Input[i]); + } + } + else + { + Out.Append(Input[i]); + } + } + } + + /// Convenience: decode XML entities and return as std::string. + std::string DecodeXmlEntities(std::string_view Input) + { + if (Input.find('&') == std::string_view::npos) + { + return std::string(Input); + } + + ExtendableStringBuilder<256> Sb; + DecodeXmlEntities(Input, Sb); + return Sb.ToString(); + } + + /// Join a path and canonical query string into a full request path for the HTTP client. + std::string BuildRequestPath(std::string_view Path, std::string_view CanonicalQS) + { + if (CanonicalQS.empty()) + { + return std::string(Path); + } + return fmt::format("{}?{}", Path, CanonicalQS); + } + + /// Case-insensitive header lookup in an HttpClient response header map. 
+ const std::string* FindResponseHeader(const HttpClient::KeyValueMap& Headers, std::string_view Name) + { + for (const auto& [K, V] : *Headers) + { + if (StrCaseCompare(K, Name) == 0) + { + return &V; + } + } + return nullptr; + } + +} // namespace + +S3Client::S3Client(const S3ClientOptions& Options) +: m_Log(logging::Get("s3")) +, m_BucketName(Options.BucketName) +, m_Region(Options.Region) +, m_Endpoint(Options.Endpoint) +, m_PathStyle(Options.PathStyle) +, m_Credentials(Options.Credentials) +, m_CredentialProvider(Options.CredentialProvider) +, m_HttpClient(BuildEndpoint(), + HttpClientSettings{ + .LogCategory = "s3", + .ConnectTimeout = Options.ConnectTimeout, + .Timeout = Options.Timeout, + .RetryCount = Options.RetryCount, + }) +{ + m_Host = BuildHostHeader(); + ZEN_INFO("S3 client configured for bucket '{}' in region '{}' (endpoint: {}, {})", + m_BucketName, + m_Region, + m_HttpClient.GetBaseUri(), + m_PathStyle ? "path-style" : "virtual-hosted"); +} + +S3Client::~S3Client() = default; + +SigV4Credentials +S3Client::GetCurrentCredentials() +{ + if (m_CredentialProvider) + { + SigV4Credentials Creds = m_CredentialProvider->GetCredentials(); + if (!Creds.AccessKeyId.empty()) + { + // Invalidate the signing key cache when the access key changes + if (Creds.AccessKeyId != m_Credentials.AccessKeyId) + { + RwLock::ExclusiveLockScope ExclusiveLock(m_SigningKeyLock); + m_CachedDateStamp.clear(); + } + m_Credentials = Creds; + } + return m_Credentials; + } + return m_Credentials; +} + +std::string +S3Client::BuildEndpoint() const +{ + if (!m_Endpoint.empty()) + { + return m_Endpoint; + } + + if (m_PathStyle) + { + // Path-style: https://s3.region.amazonaws.com + return fmt::format("https://s3.{}.amazonaws.com", m_Region); + } + + // Virtual-hosted style: https://bucket.s3.region.amazonaws.com + return fmt::format("https://{}.s3.{}.amazonaws.com", m_BucketName, m_Region); +} + +std::string +S3Client::BuildHostHeader() const +{ + if (!m_Endpoint.empty()) + { + // 
Extract host from custom endpoint URL (strip scheme) + std::string_view Ep = m_Endpoint; + if (size_t Pos = Ep.find("://"); Pos != std::string_view::npos) + { + Ep = Ep.substr(Pos + 3); + } + // Strip trailing slash + if (!Ep.empty() && Ep.back() == '/') + { + Ep = Ep.substr(0, Ep.size() - 1); + } + return std::string(Ep); + } + + if (m_PathStyle) + { + return fmt::format("s3.{}.amazonaws.com", m_Region); + } + + return fmt::format("{}.s3.{}.amazonaws.com", m_BucketName, m_Region); +} + +std::string +S3Client::KeyToPath(std::string_view Key) const +{ + if (m_PathStyle) + { + return fmt::format("/{}/{}", m_BucketName, Key); + } + return fmt::format("/{}", Key); +} + +std::string +S3Client::BucketRootPath() const +{ + if (m_PathStyle) + { + return fmt::format("/{}/", m_BucketName); + } + return "/"; +} + +Sha256Digest +S3Client::GetSigningKey(std::string_view DateStamp) +{ + // Fast path: shared lock for cache hit (common case — key only changes once per day) + { + RwLock::SharedLockScope SharedLock(m_SigningKeyLock); + if (m_CachedDateStamp == DateStamp) + { + return m_CachedSigningKey; + } + } + + // Slow path: exclusive lock to recompute the signing key + RwLock::ExclusiveLockScope ExclusiveLock(m_SigningKeyLock); + + // Double-check after acquiring exclusive lock (another thread may have updated it) + if (m_CachedDateStamp == DateStamp) + { + return m_CachedSigningKey; + } + + std::string SecretPrefix = fmt::format("AWS4{}", m_Credentials.SecretAccessKey); + + Sha256Digest DateKey = ComputeHmacSha256(SecretPrefix.data(), SecretPrefix.size(), DateStamp.data(), DateStamp.size()); + SecureZeroSecret(SecretPrefix.data(), SecretPrefix.size()); + + Sha256Digest RegionKey = ComputeHmacSha256(DateKey, m_Region); + Sha256Digest ServiceKey = ComputeHmacSha256(RegionKey, "s3"); + m_CachedSigningKey = ComputeHmacSha256(ServiceKey, "aws4_request"); + m_CachedDateStamp = std::string(DateStamp); + + return m_CachedSigningKey; +} + +HttpClient::KeyValueMap 
+S3Client::SignRequest(std::string_view Method, std::string_view Path, std::string_view CanonicalQueryString, std::string_view PayloadHash) +{ + SigV4Credentials Credentials = GetCurrentCredentials(); + + std::string AmzDate = GetAmzTimestamp(); + + // Build sorted headers to sign (must be sorted by lowercase name) + std::vector<std::pair<std::string, std::string>> HeadersToSign; + HeadersToSign.emplace_back("host", m_Host); + HeadersToSign.emplace_back("x-amz-content-sha256", std::string(PayloadHash)); + HeadersToSign.emplace_back("x-amz-date", AmzDate); + if (!Credentials.SessionToken.empty()) + { + HeadersToSign.emplace_back("x-amz-security-token", Credentials.SessionToken); + } + std::sort(HeadersToSign.begin(), HeadersToSign.end()); + + std::string_view DateStamp(AmzDate.data(), 8); + Sha256Digest SigningKey = GetSigningKey(DateStamp); + + SigV4SignedHeaders Signed = + SignRequestV4(Credentials, Method, Path, CanonicalQueryString, m_Region, "s3", AmzDate, HeadersToSign, PayloadHash, &SigningKey); + + HttpClient::KeyValueMap Result; + Result->emplace("Authorization", std::move(Signed.Authorization)); + Result->emplace("x-amz-date", std::move(Signed.AmzDate)); + Result->emplace("x-amz-content-sha256", std::move(Signed.PayloadHash)); + if (!Credentials.SessionToken.empty()) + { + Result->emplace("x-amz-security-token", Credentials.SessionToken); + } + + return Result; +} + +S3Result +S3Client::PutObject(std::string_view Key, IoBuffer Content) +{ + std::string Path = KeyToPath(Key); + + // Hash the payload + std::string PayloadHash = Sha256ToHex(ComputeSha256(Content.GetData(), Content.GetSize())); + + HttpClient::KeyValueMap Headers = SignRequest("PUT", Path, "", PayloadHash); + + HttpClient::Response Response = m_HttpClient.Put(Path, Content, Headers); + if (!Response.IsSuccess()) + { + std::string Err = Response.ErrorMessage("S3 PUT failed"); + ZEN_WARN("S3 PUT '{}' failed: {}", Key, Err); + return S3Result{std::move(Err)}; + } + + ZEN_DEBUG("S3 PUT '{}' 
succeeded ({} bytes)", Key, Content.GetSize()); + return {}; +} + +S3GetObjectResult +S3Client::GetObject(std::string_view Key) +{ + std::string Path = KeyToPath(Key); + + HttpClient::KeyValueMap Headers = SignRequest("GET", Path, "", EmptyPayloadHash); + + HttpClient::Response Response = m_HttpClient.Get(Path, Headers); + if (!Response.IsSuccess()) + { + std::string Err = Response.ErrorMessage("S3 GET failed"); + ZEN_WARN("S3 GET '{}' failed: {}", Key, Err); + return S3GetObjectResult{S3Result{std::move(Err)}, {}}; + } + + ZEN_DEBUG("S3 GET '{}' succeeded ({} bytes)", Key, Response.ResponsePayload.GetSize()); + return S3GetObjectResult{{}, std::move(Response.ResponsePayload)}; +} + +S3Result +S3Client::DeleteObject(std::string_view Key) +{ + std::string Path = KeyToPath(Key); + + HttpClient::KeyValueMap Headers = SignRequest("DELETE", Path, "", EmptyPayloadHash); + + HttpClient::Response Response = m_HttpClient.Delete(Path, Headers); + if (!Response.IsSuccess()) + { + std::string Err = Response.ErrorMessage("S3 DELETE failed"); + ZEN_WARN("S3 DELETE '{}' failed: {}", Key, Err); + return S3Result{std::move(Err)}; + } + + ZEN_DEBUG("S3 DELETE '{}' succeeded", Key); + return {}; +} + +S3HeadObjectResult +S3Client::HeadObject(std::string_view Key) +{ + std::string Path = KeyToPath(Key); + + HttpClient::KeyValueMap Headers = SignRequest("HEAD", Path, "", EmptyPayloadHash); + + HttpClient::Response Response = m_HttpClient.Head(Path, Headers); + if (!Response.IsSuccess()) + { + if (Response.StatusCode == HttpResponseCode::NotFound) + { + return S3HeadObjectResult{{}, {}, HeadObjectResult::NotFound}; + } + + std::string Err = Response.ErrorMessage("S3 HEAD failed"); + ZEN_WARN("S3 HEAD '{}' failed: {}", Key, Err); + return S3HeadObjectResult{S3Result{std::move(Err)}, {}, HeadObjectResult::Error}; + } + + S3ObjectInfo Info; + Info.Key = std::string(Key); + + if (const std::string* V = FindResponseHeader(Response.Header, "content-length")) + { + Info.Size = 
ParseInt<uint64_t>(*V).value_or(0); + } + + if (const std::string* V = FindResponseHeader(Response.Header, "etag")) + { + Info.ETag = *V; + } + + if (const std::string* V = FindResponseHeader(Response.Header, "last-modified")) + { + Info.LastModified = *V; + } + + ZEN_DEBUG("S3 HEAD '{}' succeeded (size={})", Key, Info.Size); + return S3HeadObjectResult{{}, std::move(Info), HeadObjectResult::Found}; +} + +S3ListObjectsResult +S3Client::ListObjects(std::string_view Prefix, uint32_t MaxKeys) +{ + S3ListObjectsResult Result; + + std::string ContinuationToken; + + for (;;) + { + // Build query parameters for ListObjectsV2 + std::vector<std::pair<std::string, std::string>> QueryParams; + QueryParams.emplace_back("list-type", "2"); + if (!Prefix.empty()) + { + QueryParams.emplace_back("prefix", std::string(Prefix)); + } + if (MaxKeys > 0) + { + QueryParams.emplace_back("max-keys", fmt::format("{}", MaxKeys)); + } + if (!ContinuationToken.empty()) + { + QueryParams.emplace_back("continuation-token", ContinuationToken); + } + + std::string CanonicalQS = BuildCanonicalQueryString(std::move(QueryParams)); + std::string RootPath = BucketRootPath(); + HttpClient::KeyValueMap Headers = SignRequest("GET", RootPath, CanonicalQS, EmptyPayloadHash); + + std::string FullPath = BuildRequestPath(RootPath, CanonicalQS); + HttpClient::Response Response = m_HttpClient.Get(FullPath, Headers); + if (!Response.IsSuccess()) + { + std::string Err = Response.ErrorMessage("S3 ListObjectsV2 failed"); + ZEN_WARN("S3 ListObjectsV2 prefix='{}' failed: {}", Prefix, Err); + Result.Error = std::move(Err); + return Result; + } + + // Parse the XML response to extract object keys + std::string_view ResponseBody = Response.AsText(); + + // Find all <Contents> elements + std::string_view Remaining = ResponseBody; + while (true) + { + size_t ContentsStart = Remaining.find("<Contents>"); + if (ContentsStart == std::string_view::npos) + { + break; + } + + size_t ContentsEnd = Remaining.find("</Contents>", 
ContentsStart); + if (ContentsEnd == std::string_view::npos) + { + break; + } + + std::string_view ContentsXml = Remaining.substr(ContentsStart, ContentsEnd - ContentsStart + 11); + + S3ObjectInfo Info; + Info.Key = DecodeXmlEntities(ExtractXmlValue(ContentsXml, "Key")); + Info.ETag = DecodeXmlEntities(ExtractXmlValue(ContentsXml, "ETag")); + Info.LastModified = std::string(ExtractXmlValue(ContentsXml, "LastModified")); + + std::string_view SizeStr = ExtractXmlValue(ContentsXml, "Size"); + if (!SizeStr.empty()) + { + Info.Size = ParseInt<uint64_t>(SizeStr).value_or(0); + } + + if (!Info.Key.empty()) + { + Result.Objects.push_back(std::move(Info)); + } + + Remaining = Remaining.substr(ContentsEnd + 11); + } + + // Check if there are more pages + std::string_view IsTruncated = ExtractXmlValue(ResponseBody, "IsTruncated"); + if (IsTruncated != "true") + { + break; + } + + std::string_view NextToken = ExtractXmlValue(ResponseBody, "NextContinuationToken"); + if (NextToken.empty()) + { + break; + } + + ContinuationToken = std::string(NextToken); + ZEN_DEBUG("S3 ListObjectsV2 prefix='{}' fetching next page ({} objects so far)", Prefix, Result.Objects.size()); + } + + ZEN_DEBUG("S3 ListObjectsV2 prefix='{}' returned {} objects", Prefix, Result.Objects.size()); + return Result; +} + +////////////////////////////////////////////////////////////////////////// +// Multipart Upload + +S3CreateMultipartUploadResult +S3Client::CreateMultipartUpload(std::string_view Key) +{ + std::string Path = KeyToPath(Key); + std::string CanonicalQS = BuildCanonicalQueryString({{"uploads", ""}}); + + HttpClient::KeyValueMap Headers = SignRequest("POST", Path, CanonicalQS, EmptyPayloadHash); + + std::string FullPath = BuildRequestPath(Path, CanonicalQS); + HttpClient::Response Response = m_HttpClient.Post(FullPath, Headers); + if (!Response.IsSuccess()) + { + std::string Err = Response.ErrorMessage("S3 CreateMultipartUpload failed"); + ZEN_WARN("S3 CreateMultipartUpload '{}' failed: {}", Key, 
Err); + return S3CreateMultipartUploadResult{S3Result{std::move(Err)}, {}}; + } + + // Parse UploadId from XML response: + // <InitiateMultipartUploadResult> + // <Bucket>...</Bucket> + // <Key>...</Key> + // <UploadId>...</UploadId> + // </InitiateMultipartUploadResult> + std::string_view ResponseBody = Response.AsText(); + std::string_view UploadId = ExtractXmlValue(ResponseBody, "UploadId"); + if (UploadId.empty()) + { + std::string Err = "failed to parse UploadId from CreateMultipartUpload response"; + ZEN_WARN("S3 CreateMultipartUpload '{}': {}", Key, Err); + return S3CreateMultipartUploadResult{S3Result{std::move(Err)}, {}}; + } + + ZEN_DEBUG("S3 CreateMultipartUpload '{}' succeeded (uploadId={})", Key, UploadId); + return S3CreateMultipartUploadResult{{}, std::string(UploadId)}; +} + +S3UploadPartResult +S3Client::UploadPart(std::string_view Key, std::string_view UploadId, uint32_t PartNumber, IoBuffer Content) +{ + std::string Path = KeyToPath(Key); + std::string CanonicalQS = BuildCanonicalQueryString({ + {"partNumber", fmt::format("{}", PartNumber)}, + {"uploadId", std::string(UploadId)}, + }); + + std::string PayloadHash = Sha256ToHex(ComputeSha256(Content.GetData(), Content.GetSize())); + + HttpClient::KeyValueMap Headers = SignRequest("PUT", Path, CanonicalQS, PayloadHash); + + std::string FullPath = BuildRequestPath(Path, CanonicalQS); + HttpClient::Response Response = m_HttpClient.Put(FullPath, Content, Headers); + if (!Response.IsSuccess()) + { + std::string Err = Response.ErrorMessage(fmt::format("S3 UploadPart {} failed", PartNumber)); + ZEN_WARN("S3 UploadPart '{}' part {} failed: {}", Key, PartNumber, Err); + return S3UploadPartResult{S3Result{std::move(Err)}, {}}; + } + + // Extract ETag from response headers + const std::string* ETag = FindResponseHeader(Response.Header, "etag"); + if (!ETag) + { + std::string Err = "S3 UploadPart response missing ETag header"; + ZEN_WARN("S3 UploadPart '{}' part {}: {}", Key, PartNumber, Err); + return 
S3UploadPartResult{S3Result{std::move(Err)}, {}}; + } + + ZEN_DEBUG("S3 UploadPart '{}' part {} succeeded ({} bytes, etag={})", Key, PartNumber, Content.GetSize(), *ETag); + return S3UploadPartResult{{}, *ETag}; +} + +S3Result +S3Client::CompleteMultipartUpload(std::string_view Key, + std::string_view UploadId, + const std::vector<std::pair<uint32_t, std::string>>& PartETags) +{ + std::string Path = KeyToPath(Key); + std::string CanonicalQS = BuildCanonicalQueryString({{"uploadId", std::string(UploadId)}}); + + // Build the CompleteMultipartUpload XML payload + ExtendableStringBuilder<1024> XmlBody; + XmlBody.Append("<CompleteMultipartUpload>"); + for (const auto& [PartNumber, ETag] : PartETags) + { + XmlBody.Append(fmt::format("<Part><PartNumber>{}</PartNumber><ETag>{}</ETag></Part>", PartNumber, ETag)); + } + XmlBody.Append("</CompleteMultipartUpload>"); + + std::string_view XmlView = XmlBody.ToView(); + std::string PayloadHash = Sha256ToHex(ComputeSha256(XmlView)); + + HttpClient::KeyValueMap Headers = SignRequest("POST", Path, CanonicalQS, PayloadHash); + Headers->emplace("Content-Type", "application/xml"); + + IoBuffer Payload(IoBuffer::Clone, XmlView.data(), XmlView.size()); + + std::string FullPath = BuildRequestPath(Path, CanonicalQS); + HttpClient::Response Response = m_HttpClient.Post(FullPath, Payload, Headers); + if (!Response.IsSuccess()) + { + std::string Err = Response.ErrorMessage("S3 CompleteMultipartUpload failed"); + ZEN_WARN("S3 CompleteMultipartUpload '{}' failed: {}", Key, Err); + return S3Result{std::move(Err)}; + } + + // Check for error in response body - S3 can return 200 with an error in the XML body + std::string_view ResponseBody = Response.AsText(); + if (ResponseBody.find("<Error>") != std::string_view::npos) + { + std::string_view ErrorCode = ExtractXmlValue(ResponseBody, "Code"); + std::string_view ErrorMessage = ExtractXmlValue(ResponseBody, "Message"); + std::string Err = fmt::format("S3 CompleteMultipartUpload '{}' returned 
error: {} - {}", Key, ErrorCode, ErrorMessage); + ZEN_WARN("{}", Err); + return S3Result{std::move(Err)}; + } + + ZEN_DEBUG("S3 CompleteMultipartUpload '{}' succeeded ({} parts)", Key, PartETags.size()); + return {}; +} + +S3Result +S3Client::AbortMultipartUpload(std::string_view Key, std::string_view UploadId) +{ + std::string Path = KeyToPath(Key); + std::string CanonicalQS = BuildCanonicalQueryString({{"uploadId", std::string(UploadId)}}); + + HttpClient::KeyValueMap Headers = SignRequest("DELETE", Path, CanonicalQS, EmptyPayloadHash); + + std::string FullPath = BuildRequestPath(Path, CanonicalQS); + HttpClient::Response Response = m_HttpClient.Delete(FullPath, Headers); + if (!Response.IsSuccess()) + { + std::string Err = Response.ErrorMessage("S3 AbortMultipartUpload failed"); + ZEN_WARN("S3 AbortMultipartUpload '{}' failed: {}", Key, Err); + return S3Result{std::move(Err)}; + } + + ZEN_DEBUG("S3 AbortMultipartUpload '{}' succeeded (uploadId={})", Key, UploadId); + return {}; +} + +std::string +S3Client::GeneratePresignedGetUrl(std::string_view Key, std::chrono::seconds ExpiresIn) +{ + return GeneratePresignedUrlForMethod(Key, "GET", ExpiresIn); +} + +std::string +S3Client::GeneratePresignedPutUrl(std::string_view Key, std::chrono::seconds ExpiresIn) +{ + return GeneratePresignedUrlForMethod(Key, "PUT", ExpiresIn); +} + +std::string +S3Client::GeneratePresignedUrlForMethod(std::string_view Key, std::string_view Method, std::chrono::seconds ExpiresIn) +{ + std::string Path = KeyToPath(Key); + std::string Scheme = "https"; + + if (!m_Endpoint.empty() && m_Endpoint.starts_with("http://")) + { + Scheme = "http"; + } + + SigV4Credentials Credentials = GetCurrentCredentials(); + return GeneratePresignedUrl(Credentials, Method, Scheme, m_Host, Path, m_Region, "s3", ExpiresIn); +} + +S3Result +S3Client::PutObjectMultipart(std::string_view Key, IoBuffer Content, uint64_t PartSize) +{ + const uint64_t ContentSize = Content.GetSize(); + + // If the content fits in a 
single part, just use PutObject + if (ContentSize <= PartSize) + { + return PutObject(Key, Content); + } + + ZEN_INFO("S3 multipart upload '{}': {} bytes in ~{} parts", Key, ContentSize, (ContentSize + PartSize - 1) / PartSize); + + // Initiate multipart upload + + S3CreateMultipartUploadResult InitResult = CreateMultipartUpload(Key); + if (!InitResult) + { + return S3Result{std::move(InitResult.Error)}; + } + + const std::string& UploadId = InitResult.UploadId; + + // Upload parts sequentially + // TODO: upload parts in parallel for improved throughput on large uploads + + std::vector<std::pair<uint32_t, std::string>> PartETags; + uint64_t Offset = 0; + uint32_t PartNumber = 1; + + while (Offset < ContentSize) + { + uint64_t ThisPartSize = std::min(PartSize, ContentSize - Offset); + + // Create a sub-buffer referencing the part data within the original content + IoBuffer PartContent(Content, Offset, ThisPartSize); + + S3UploadPartResult PartResult = UploadPart(Key, UploadId, PartNumber, PartContent); + if (!PartResult) + { + // Attempt to abort the multipart upload on failure + AbortMultipartUpload(Key, UploadId); + return S3Result{std::move(PartResult.Error)}; + } + + PartETags.emplace_back(PartNumber, std::move(PartResult.ETag)); + Offset += ThisPartSize; + PartNumber++; + } + + // Complete multipart upload + S3Result CompleteResult = CompleteMultipartUpload(Key, UploadId, PartETags); + if (!CompleteResult) + { + AbortMultipartUpload(Key, UploadId); + return CompleteResult; + } + + ZEN_INFO("S3 multipart upload '{}' completed ({} parts, {} bytes)", Key, PartETags.size(), ContentSize); + return {}; +} + +////////////////////////////////////////////////////////////////////////// +// Tests + +#if ZEN_WITH_TESTS + +void +s3client_forcelink() +{ +} + +TEST_SUITE_BEGIN("util.cloud.s3client"); + +TEST_CASE("s3client.xml_extract") +{ + std::string_view Xml = + "<Contents><Key>test/file.txt</Key><Size>1234</Size>" + 
"<ETag>\"abc123\"</ETag><LastModified>2024-01-01T00:00:00Z</LastModified></Contents>"; + + CHECK(ExtractXmlValue(Xml, "Key") == "test/file.txt"); + CHECK(ExtractXmlValue(Xml, "Size") == "1234"); + CHECK(ExtractXmlValue(Xml, "ETag") == "\"abc123\""); + CHECK(ExtractXmlValue(Xml, "LastModified") == "2024-01-01T00:00:00Z"); + CHECK(ExtractXmlValue(Xml, "NonExistent") == ""); +} + +TEST_CASE("s3client.xml_entity_decode") +{ + CHECK(DecodeXmlEntities("no entities") == "no entities"); + CHECK(DecodeXmlEntities("a&b") == "a&b"); + CHECK(DecodeXmlEntities("<tag>") == "<tag>"); + CHECK(DecodeXmlEntities(""hello'") == "\"hello'"); + CHECK(DecodeXmlEntities("&&") == "&&"); + CHECK(DecodeXmlEntities("") == ""); + + // Key with entities as S3 would return it + std::string_view Xml = "<Key>path/file&name<1>.txt</Key>"; + CHECK(DecodeXmlEntities(ExtractXmlValue(Xml, "Key")) == "path/file&name<1>.txt"); +} + +TEST_CASE("s3client.path_style_addressing") +{ + // Verify path-style builds /{bucket}/{key} paths + S3ClientOptions Opts; + Opts.BucketName = "test-bucket"; + Opts.Region = "us-east-1"; + Opts.Endpoint = "http://localhost:9000"; + Opts.PathStyle = true; + Opts.Credentials.AccessKeyId = "minioadmin"; + Opts.Credentials.SecretAccessKey = "minioadmin"; + + S3Client Client(Opts); + CHECK(Client.BucketName() == "test-bucket"); + CHECK(Client.Region() == "us-east-1"); +} + +TEST_CASE("s3client.virtual_hosted_addressing") +{ + // Verify virtual-hosted style derives endpoint from region + bucket + S3ClientOptions Opts; + Opts.BucketName = "my-bucket"; + Opts.Region = "eu-west-1"; + Opts.PathStyle = false; + Opts.Credentials.AccessKeyId = "key"; + Opts.Credentials.SecretAccessKey = "secret"; + + S3Client Client(Opts); + CHECK(Client.BucketName() == "my-bucket"); + CHECK(Client.Region() == "eu-west-1"); +} + +TEST_CASE("s3client.minio_integration") +{ + using namespace std::literals; + + // Spawn a local MinIO server + MinioProcessOptions MinioOpts; + MinioOpts.Port = 19000; + 
MinioOpts.RootUser = "testuser"; + MinioOpts.RootPassword = "testpassword"; + + MinioProcess Minio(MinioOpts); + Minio.SpawnMinioServer(); + + // Pre-create the test bucket (creates a subdirectory in MinIO's data dir) + Minio.CreateBucket("integration-test"); + + // Configure S3Client for the test bucket + S3ClientOptions Opts; + Opts.BucketName = "integration-test"; + Opts.Region = "us-east-1"; + Opts.Endpoint = Minio.Endpoint(); + Opts.PathStyle = true; + Opts.Credentials.AccessKeyId = std::string(Minio.RootUser()); + Opts.Credentials.SecretAccessKey = std::string(Minio.RootPassword()); + + S3Client Client(Opts); + + SUBCASE("put_get_delete") + { + // PUT + std::string_view TestData = "hello, minio integration test!"sv; + IoBuffer Content = IoBufferBuilder::MakeFromMemory(MakeMemoryView(TestData)); + S3Result PutRes = Client.PutObject("test/hello.txt", std::move(Content)); + REQUIRE(PutRes.IsSuccess()); + + // GET + S3GetObjectResult GetRes = Client.GetObject("test/hello.txt"); + REQUIRE(GetRes.IsSuccess()); + CHECK(GetRes.AsText() == TestData); + + // HEAD + S3HeadObjectResult HeadRes = Client.HeadObject("test/hello.txt"); + REQUIRE(HeadRes.IsSuccess()); + CHECK(HeadRes.Status == HeadObjectResult::Found); + CHECK(HeadRes.Info.Size == TestData.size()); + + // DELETE + S3Result DelRes = Client.DeleteObject("test/hello.txt"); + REQUIRE(DelRes.IsSuccess()); + + // HEAD after delete + S3HeadObjectResult HeadRes2 = Client.HeadObject("test/hello.txt"); + REQUIRE(HeadRes2.IsSuccess()); + CHECK(HeadRes2.Status == HeadObjectResult::NotFound); + } + + SUBCASE("head_not_found") + { + S3HeadObjectResult Res = Client.HeadObject("nonexistent/key.dat"); + CHECK(Res.IsSuccess()); + CHECK(Res.Status == HeadObjectResult::NotFound); + } + + SUBCASE("list_objects") + { + // Upload several objects with a common prefix + for (int i = 0; i < 3; ++i) + { + std::string Key = fmt::format("list-test/item-{}.txt", i); + std::string Payload = fmt::format("payload-{}", i); + IoBuffer Buf = 
IoBufferBuilder::MakeFromMemory(MakeMemoryView(Payload)); + S3Result Res = Client.PutObject(Key, std::move(Buf)); + REQUIRE(Res.IsSuccess()); + } + + // List with prefix + S3ListObjectsResult ListRes = Client.ListObjects("list-test/"); + REQUIRE(ListRes.IsSuccess()); + CHECK(ListRes.Objects.size() == 3); + + // Verify keys are present + std::vector<std::string> Keys; + for (const S3ObjectInfo& Obj : ListRes.Objects) + { + Keys.push_back(Obj.Key); + } + std::sort(Keys.begin(), Keys.end()); + CHECK(Keys[0] == "list-test/item-0.txt"); + CHECK(Keys[1] == "list-test/item-1.txt"); + CHECK(Keys[2] == "list-test/item-2.txt"); + + // Cleanup + for (int i = 0; i < 3; ++i) + { + Client.DeleteObject(fmt::format("list-test/item-{}.txt", i)); + } + } + + SUBCASE("multipart_upload") + { + // Create a payload large enough to exercise multipart (use minimum part size) + constexpr uint64_t PartSize = 5 * 1024 * 1024; // 5 MB minimum + constexpr uint64_t PayloadSize = PartSize + 1024; // slightly over one part + + std::string LargePayload(PayloadSize, 'X'); + // Add some variation + for (uint64_t i = 0; i < PayloadSize; i += 1024) + { + LargePayload[i] = char('A' + (i / 1024) % 26); + } + + IoBuffer Content = IoBufferBuilder::MakeFromMemory(MakeMemoryView(LargePayload)); + S3Result Res = Client.PutObjectMultipart("multipart/large.bin", std::move(Content), PartSize); + REQUIRE(Res.IsSuccess()); + + // Verify via GET + S3GetObjectResult GetRes = Client.GetObject("multipart/large.bin"); + REQUIRE(GetRes.IsSuccess()); + CHECK(GetRes.Content.GetSize() == PayloadSize); + CHECK(GetRes.AsText() == std::string_view(LargePayload)); + + // Cleanup + Client.DeleteObject("multipart/large.bin"); + } + + SUBCASE("presigned_urls") + { + // Upload an object + std::string_view TestData = "presigned-url-test-data"sv; + IoBuffer Content = IoBufferBuilder::MakeFromMemory(MakeMemoryView(TestData)); + S3Result PutRes = Client.PutObject("presigned/test.txt", std::move(Content)); + 
REQUIRE(PutRes.IsSuccess()); + + // Generate a pre-signed GET URL + std::string Url = Client.GeneratePresignedGetUrl("presigned/test.txt", std::chrono::seconds(60)); + CHECK(!Url.empty()); + CHECK(Url.find("X-Amz-Signature") != std::string::npos); + + // Fetch via the pre-signed URL (no auth headers needed) + HttpClient Hc(Minio.Endpoint()); + // Extract the path+query from the full URL + std::string_view UrlView = Url; + size_t PathStart = UrlView.find('/', UrlView.find("://") + 3); + std::string PathAndQuery(UrlView.substr(PathStart)); + HttpClient::Response Resp = Hc.Get(PathAndQuery); + REQUIRE(Resp.IsSuccess()); + CHECK(Resp.AsText() == TestData); + + // Cleanup + Client.DeleteObject("presigned/test.txt"); + } + + Minio.StopMinioServer(); +} + +TEST_SUITE_END(); + +#endif + +} // namespace zen diff --git a/src/zenutil/cloud/sigv4.cpp b/src/zenutil/cloud/sigv4.cpp new file mode 100644 index 000000000..055ccb2ad --- /dev/null +++ b/src/zenutil/cloud/sigv4.cpp @@ -0,0 +1,531 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#include <zenutil/cloud/sigv4.h> + +#include <zencore/string.h> +#include <zencore/testing.h> + +#include <algorithm> +#include <chrono> +#include <cstring> +#include <ctime> + +// Platform-specific crypto backends +#if ZEN_PLATFORM_WINDOWS +# define ZEN_S3_USE_BCRYPT 1 +#else +# define ZEN_S3_USE_BCRYPT 0 +#endif + +#ifndef ZEN_S3_USE_OPENSSL +# if ZEN_S3_USE_BCRYPT +# define ZEN_S3_USE_OPENSSL 0 +# else +# define ZEN_S3_USE_OPENSSL 1 +# endif +#endif + +ZEN_THIRD_PARTY_INCLUDES_START +#include <fmt/format.h> + +#if ZEN_S3_USE_OPENSSL +# include <openssl/evp.h> +#elif ZEN_S3_USE_BCRYPT +# include <zencore/windows.h> +# include <bcrypt.h> +#endif +ZEN_THIRD_PARTY_INCLUDES_END + +namespace zen { + +////////////////////////////////////////////////////////////////////////// +// SHA-256 + +#if ZEN_S3_USE_OPENSSL + +Sha256Digest +ComputeSha256(const void* Data, size_t Size) +{ + Sha256Digest Result; + unsigned int Len = 0; + EVP_Digest(Data, Size, Result.data(), &Len, EVP_sha256(), nullptr); + ZEN_ASSERT(Len == 32); + return Result; +} + +Sha256Digest +ComputeHmacSha256(const void* Key, size_t KeySize, const void* Data, size_t DataSize) +{ + Sha256Digest Result; + + EVP_MAC* Mac = EVP_MAC_fetch(nullptr, "HMAC", nullptr); + ZEN_ASSERT(Mac != nullptr); + + EVP_MAC_CTX* Ctx = EVP_MAC_CTX_new(Mac); + ZEN_ASSERT(Ctx != nullptr); + + OSSL_PARAM Params[] = { + OSSL_PARAM_construct_utf8_string("digest", const_cast<char*>("SHA256"), 0), + OSSL_PARAM_construct_end(), + }; + + int Rc = EVP_MAC_init(Ctx, reinterpret_cast<const unsigned char*>(Key), KeySize, Params); + ZEN_ASSERT(Rc == 1); + + Rc = EVP_MAC_update(Ctx, reinterpret_cast<const unsigned char*>(Data), DataSize); + ZEN_ASSERT(Rc == 1); + + size_t OutLen = 0; + Rc = EVP_MAC_final(Ctx, Result.data(), &OutLen, Result.size()); + ZEN_ASSERT(Rc == 1); + ZEN_ASSERT(OutLen == 32); + + EVP_MAC_CTX_free(Ctx); + EVP_MAC_free(Mac); + + return Result; +} + +#elif ZEN_S3_USE_BCRYPT + +namespace { + +# define NT_SUCCESS(Status) 
(((NTSTATUS)(Status)) >= 0) + + Sha256Digest BcryptHash(BCRYPT_ALG_HANDLE Algorithm, const void* Data, size_t DataSize) + { + Sha256Digest Result; + BCRYPT_HASH_HANDLE HashHandle = nullptr; + NTSTATUS Status; + + Status = BCryptCreateHash(Algorithm, &HashHandle, nullptr, 0, nullptr, 0, 0); + ZEN_ASSERT(NT_SUCCESS(Status)); + + Status = BCryptHashData(HashHandle, (PUCHAR)Data, (ULONG)DataSize, 0); + ZEN_ASSERT(NT_SUCCESS(Status)); + + Status = BCryptFinishHash(HashHandle, Result.data(), (ULONG)Result.size(), 0); + ZEN_ASSERT(NT_SUCCESS(Status)); + + BCryptDestroyHash(HashHandle); + return Result; + } + + Sha256Digest BcryptHmac(BCRYPT_ALG_HANDLE Algorithm, const void* Key, size_t KeySize, const void* Data, size_t DataSize) + { + Sha256Digest Result; + BCRYPT_HASH_HANDLE HashHandle = nullptr; + NTSTATUS Status; + + Status = BCryptCreateHash(Algorithm, &HashHandle, nullptr, 0, (PUCHAR)Key, (ULONG)KeySize, 0); + ZEN_ASSERT(NT_SUCCESS(Status)); + + Status = BCryptHashData(HashHandle, (PUCHAR)Data, (ULONG)DataSize, 0); + ZEN_ASSERT(NT_SUCCESS(Status)); + + Status = BCryptFinishHash(HashHandle, Result.data(), (ULONG)Result.size(), 0); + ZEN_ASSERT(NT_SUCCESS(Status)); + + BCryptDestroyHash(HashHandle); + return Result; + } + + struct BcryptAlgorithmHandles + { + BCRYPT_ALG_HANDLE Sha256 = nullptr; + BCRYPT_ALG_HANDLE HmacSha256 = nullptr; + + BcryptAlgorithmHandles() + { + NTSTATUS Status; + Status = BCryptOpenAlgorithmProvider(&Sha256, BCRYPT_SHA256_ALGORITHM, nullptr, 0); + ZEN_ASSERT(NT_SUCCESS(Status)); + Status = BCryptOpenAlgorithmProvider(&HmacSha256, BCRYPT_SHA256_ALGORITHM, nullptr, BCRYPT_ALG_HANDLE_HMAC_FLAG); + ZEN_ASSERT(NT_SUCCESS(Status)); + } + + ~BcryptAlgorithmHandles() + { + if (Sha256) + { + BCryptCloseAlgorithmProvider(Sha256, 0); + } + if (HmacSha256) + { + BCryptCloseAlgorithmProvider(HmacSha256, 0); + } + } + }; + + BcryptAlgorithmHandles& GetBcryptHandles() + { + static BcryptAlgorithmHandles s_Handles; + return s_Handles; + } + +} // namespace + 
+Sha256Digest +ComputeSha256(const void* Data, size_t Size) +{ + return BcryptHash(GetBcryptHandles().Sha256, Data, Size); +} + +Sha256Digest +ComputeHmacSha256(const void* Key, size_t KeySize, const void* Data, size_t DataSize) +{ + return BcryptHmac(GetBcryptHandles().HmacSha256, Key, KeySize, Data, DataSize); +} + +#endif + +Sha256Digest +ComputeSha256(std::string_view Data) +{ + return ComputeSha256(Data.data(), Data.size()); +} + +Sha256Digest +ComputeHmacSha256(const Sha256Digest& Key, std::string_view Data) +{ + return ComputeHmacSha256(Key.data(), Key.size(), Data.data(), Data.size()); +} + +std::string +Sha256ToHex(const Sha256Digest& Digest) +{ + std::string Result; + Result.reserve(64); + for (uint8_t Byte : Digest) + { + fmt::format_to(std::back_inserter(Result), "{:02x}", Byte); + } + return Result; +} + +void +SecureZeroSecret(void* Data, size_t Size) +{ +#if ZEN_PLATFORM_WINDOWS + SecureZeroMemory(Data, Size); +#elif ZEN_PLATFORM_LINUX + explicit_bzero(Data, Size); +#else + // Portable fallback: volatile pointer prevents the compiler from optimizing away the memset + static void* (*const volatile VolatileMemset)(void*, int, size_t) = memset; + VolatileMemset(Data, 0, Size); +#endif +} + +////////////////////////////////////////////////////////////////////////// +// SigV4 signing + +namespace { + + std::string GetDateStamp(std::string_view AmzDate) + { + // AmzDate is "YYYYMMDDTHHMMSSZ", date stamp is first 8 chars + return std::string(AmzDate.substr(0, 8)); + } + +} // namespace + +std::string +GetAmzTimestamp() +{ + auto Now = std::chrono::system_clock::now(); + std::time_t NowTime = std::chrono::system_clock::to_time_t(Now); + + struct tm Tm; +#if ZEN_PLATFORM_WINDOWS + gmtime_s(&Tm, &NowTime); +#else + gmtime_r(&NowTime, &Tm); +#endif + + char Buf[32]; + std::strftime(Buf, sizeof(Buf), "%Y%m%dT%H%M%SZ", &Tm); + return std::string(Buf); +} + +std::string +AwsUriEncode(std::string_view Input, bool EncodeSlash) +{ + ExtendableStringBuilder<256> 
Result; + for (char C : Input) + { + if ((C >= 'A' && C <= 'Z') || (C >= 'a' && C <= 'z') || (C >= '0' && C <= '9') || C == '_' || C == '-' || C == '~' || C == '.') + { + Result.Append(C); + } + else if (C == '/' && !EncodeSlash) + { + Result.Append(C); + } + else + { + Result.Append(fmt::format("%{:02X}", static_cast<unsigned char>(C))); + } + } + return std::string(Result.ToView()); +} + +std::string +BuildCanonicalQueryString(std::vector<std::pair<std::string, std::string>> Parameters) +{ + if (Parameters.empty()) + { + return {}; + } + + // Sort by key name, then by value (as required by SigV4) + std::sort(Parameters.begin(), Parameters.end()); + + ExtendableStringBuilder<512> Result; + for (size_t i = 0; i < Parameters.size(); ++i) + { + if (i > 0) + { + Result.Append('&'); + } + Result.Append(AwsUriEncode(Parameters[i].first)); + Result.Append('='); + Result.Append(AwsUriEncode(Parameters[i].second)); + } + return std::string(Result.ToView()); +} + +SigV4SignedHeaders +SignRequestV4(const SigV4Credentials& Credentials, + std::string_view Method, + std::string_view Url, + std::string_view CanonicalQueryString, + std::string_view Region, + std::string_view Service, + std::string_view AmzDate, + const std::vector<std::pair<std::string, std::string>>& Headers, + std::string_view PayloadHash, + const Sha256Digest* SigningKeyPtr) +{ + SigV4SignedHeaders Result; + Result.AmzDate = std::string(AmzDate); + Result.PayloadHash = std::string(PayloadHash); + + std::string DateStamp = GetDateStamp(Result.AmzDate); + + // Step 1: Create canonical request + // CanonicalRequest = + // HTTPRequestMethod + '\n' + + // CanonicalURI + '\n' + + // CanonicalQueryString + '\n' + + // CanonicalHeaders + '\n' + + // SignedHeaders + '\n' + + // HexEncode(Hash(RequestPayload)) + + std::string CanonicalUri = AwsUriEncode(Url, false); + + // Build canonical headers and signed headers (headers must be sorted by lowercase name) + ExtendableStringBuilder<512> CanonicalHeadersSb; + 
ExtendableStringBuilder<256> SignedHeadersSb; + + for (size_t i = 0; i < Headers.size(); ++i) + { + CanonicalHeadersSb.Append(Headers[i].first); + CanonicalHeadersSb.Append(':'); + CanonicalHeadersSb.Append(Headers[i].second); + CanonicalHeadersSb.Append('\n'); + + if (i > 0) + { + SignedHeadersSb.Append(';'); + } + SignedHeadersSb.Append(Headers[i].first); + } + + std::string SignedHeaders = std::string(SignedHeadersSb.ToView()); + + std::string CanonicalRequest = fmt::format("{}\n{}\n{}\n{}\n{}\n{}", + Method, + CanonicalUri, + CanonicalQueryString, + CanonicalHeadersSb.ToView(), + SignedHeaders, + PayloadHash); + + // Step 2: Create the string to sign + std::string CredentialScope = fmt::format("{}/{}/{}/aws4_request", DateStamp, Region, Service); + + Sha256Digest CanonicalRequestHash = ComputeSha256(CanonicalRequest); + std::string CanonicalRequestHex = Sha256ToHex(CanonicalRequestHash); + + std::string StringToSign = fmt::format("AWS4-HMAC-SHA256\n{}\n{}\n{}", Result.AmzDate, CredentialScope, CanonicalRequestHex); + + // Step 3: Calculate the signing key + // kDate = HMAC("AWS4" + SecretKey, DateStamp) + // kRegion = HMAC(kDate, Region) + // kService = HMAC(kRegion, Service) + // kSigning = HMAC(kService, "aws4_request") + + Sha256Digest DerivedSigningKey; + if (!SigningKeyPtr) + { + std::string SecretPrefix = fmt::format("AWS4{}", Credentials.SecretAccessKey); + + Sha256Digest DateKey = ComputeHmacSha256(SecretPrefix.data(), SecretPrefix.size(), DateStamp.data(), DateStamp.size()); + SecureZeroSecret(SecretPrefix.data(), SecretPrefix.size()); + + Sha256Digest RegionKey = ComputeHmacSha256(DateKey, Region); + Sha256Digest ServiceKey = ComputeHmacSha256(RegionKey, Service); + DerivedSigningKey = ComputeHmacSha256(ServiceKey, "aws4_request"); + SigningKeyPtr = &DerivedSigningKey; + } + + // Step 4: Calculate the signature + Sha256Digest Signature = ComputeHmacSha256(*SigningKeyPtr, StringToSign); + std::string SignatureHex = Sha256ToHex(Signature); + + // Step 
5: Build the Authorization header + Result.Authorization = fmt::format("AWS4-HMAC-SHA256 Credential={}/{}, SignedHeaders={}, Signature={}", + Credentials.AccessKeyId, + CredentialScope, + SignedHeaders, + SignatureHex); + + return Result; +} + +std::string +GeneratePresignedUrl(const SigV4Credentials& Credentials, + std::string_view Method, + std::string_view Scheme, + std::string_view Host, + std::string_view Path, + std::string_view Region, + std::string_view Service, + std::chrono::seconds ExpiresIn, + const std::vector<std::pair<std::string, std::string>>& ExtraQueryParams) +{ + // Pre-signed URLs use query string authentication: + // https://docs.aws.amazon.com/AmazonS3/latest/API/sigv4-query-string-auth.html + + std::string AmzDate = GetAmzTimestamp(); + std::string DateStamp = GetDateStamp(AmzDate); + + std::string CredentialScope = fmt::format("{}/{}/{}/aws4_request", DateStamp, Region, Service); + std::string Credential = fmt::format("{}/{}", Credentials.AccessKeyId, CredentialScope); + + // The only signed header for pre-signed URLs is "host" + constexpr std::string_view SignedHeaders = "host"; + + // Build query parameters that will be part of the canonical request. + // These are the auth params (minus X-Amz-Signature which is added after signing). 
+ std::vector<std::pair<std::string, std::string>> QueryParams = ExtraQueryParams; + QueryParams.emplace_back("X-Amz-Algorithm", "AWS4-HMAC-SHA256"); + QueryParams.emplace_back("X-Amz-Credential", Credential); + QueryParams.emplace_back("X-Amz-Date", AmzDate); + QueryParams.emplace_back("X-Amz-Expires", fmt::format("{}", ExpiresIn.count())); + if (!Credentials.SessionToken.empty()) + { + QueryParams.emplace_back("X-Amz-Security-Token", Credentials.SessionToken); + } + QueryParams.emplace_back("X-Amz-SignedHeaders", std::string(SignedHeaders)); + + std::string CanonicalQueryString = BuildCanonicalQueryString(QueryParams); + std::string CanonicalUri = AwsUriEncode(Path, false); + + // For pre-signed URLs, the payload is always UNSIGNED-PAYLOAD + constexpr std::string_view PayloadHash = "UNSIGNED-PAYLOAD"; + + // Build the canonical request + // Only "host" is in the canonical headers for pre-signed URLs + std::string CanonicalHeaders = fmt::format("host:{}\n", Host); + + std::string CanonicalRequest = + fmt::format("{}\n{}\n{}\n{}\n{}\n{}", Method, CanonicalUri, CanonicalQueryString, CanonicalHeaders, SignedHeaders, PayloadHash); + + // Create the string to sign + Sha256Digest CanonicalRequestHash = ComputeSha256(CanonicalRequest); + std::string CanonicalRequestHex = Sha256ToHex(CanonicalRequestHash); + + std::string StringToSign = fmt::format("AWS4-HMAC-SHA256\n{}\n{}\n{}", AmzDate, CredentialScope, CanonicalRequestHex); + + // Calculate the signing key + std::string SecretPrefix = fmt::format("AWS4{}", Credentials.SecretAccessKey); + Sha256Digest DateKey = ComputeHmacSha256(SecretPrefix.data(), SecretPrefix.size(), DateStamp.data(), DateStamp.size()); + SecureZeroSecret(SecretPrefix.data(), SecretPrefix.size()); + Sha256Digest RegionKey = ComputeHmacSha256(DateKey, Region); + Sha256Digest ServiceKey = ComputeHmacSha256(RegionKey, Service); + Sha256Digest SigningKey = ComputeHmacSha256(ServiceKey, "aws4_request"); + + // Calculate the signature + std::string 
SignatureHex = Sha256ToHex(ComputeHmacSha256(SigningKey, StringToSign)); + + // Build the final URL (use the URI-encoded path so special characters are properly escaped) + return fmt::format("{}://{}{}?{}&X-Amz-Signature={}", Scheme, Host, CanonicalUri, CanonicalQueryString, SignatureHex); +} + +////////////////////////////////////////////////////////////////////////// +// Tests + +#if ZEN_WITH_TESTS + +void +sigv4_forcelink() +{ +} + +TEST_SUITE_BEGIN("util.cloud.sigv4"); + +TEST_CASE("sigv4.sha256") +{ + // Test with known test vector (empty string) + Sha256Digest Empty = ComputeSha256("", 0); + std::string Hex = Sha256ToHex(Empty); + CHECK(Hex == "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"); + + // Test with "hello" + Sha256Digest Hello = ComputeSha256("hello"); + std::string HelloHex = Sha256ToHex(Hello); + CHECK(HelloHex == "2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824"); +} + +TEST_CASE("sigv4.hmac_sha256") +{ + // RFC 4231 Test Case 2 + std::string_view Key = "Jefe"; + std::string_view Data = "what do ya want for nothing?"; + + Sha256Digest Result = ComputeHmacSha256(Key.data(), Key.size(), Data.data(), Data.size()); + std::string Hex = Sha256ToHex(Result); + CHECK(Hex == "5bdcc146bf60754e6a042426089575c75a003f089d2739839dec58b964ec3843"); +} + +TEST_CASE("sigv4.signing") +{ + // Based on the AWS SigV4 test suite example + // https://docs.aws.amazon.com/general/latest/gr/sigv4-calculate-signature.html + + SigV4Credentials Creds; + Creds.AccessKeyId = "AKIDEXAMPLE"; + Creds.SecretAccessKey = "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY"; + + // We can't test with a fixed timestamp since SignRequestV4 uses current time, + // but we can verify the crypto primitives produce correct results by testing + // the signing key derivation manually. 
+ + // Test signing key derivation: HMAC chain for "20150830" / "us-east-1" / "iam" + std::string SecretPrefix = "AWS4wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY"; + Sha256Digest DateKey = ComputeHmacSha256(SecretPrefix.data(), SecretPrefix.size(), "20150830", 8); + Sha256Digest RegionKey = ComputeHmacSha256(DateKey, "us-east-1"); + Sha256Digest ServiceKey = ComputeHmacSha256(RegionKey, "iam"); + Sha256Digest SigningKey = ComputeHmacSha256(ServiceKey, "aws4_request"); + + std::string SigningKeyHex = Sha256ToHex(SigningKey); + CHECK(SigningKeyHex == "c4afb1cc5771d871763a393e44b703571b55cc28424d1a5e86da6ed3c154a4b9"); +} + +TEST_SUITE_END(); + +#endif + +} // namespace zen diff --git a/src/zenutil/consoletui.cpp b/src/zenutil/consoletui.cpp index 4410d463d..124132aed 100644 --- a/src/zenutil/consoletui.cpp +++ b/src/zenutil/consoletui.cpp @@ -480,4 +480,69 @@ TuiPollQuit() #endif } +void +TuiSetScrollRegion(uint32_t Top, uint32_t Bottom) +{ + printf("\033[%u;%ur", Top, Bottom); +} + +void +TuiResetScrollRegion() +{ + printf("\033[r"); +} + +void +TuiMoveCursor(uint32_t Row, uint32_t Col) +{ + printf("\033[%u;%uH", Row, Col); +} + +void +TuiSaveCursor() +{ + printf( + "\033" + "7"); +} + +void +TuiRestoreCursor() +{ + printf( + "\033" + "8"); +} + +void +TuiEraseLine() +{ + printf("\033[2K"); +} + +void +TuiWrite(std::string_view Text) +{ + fwrite(Text.data(), 1, Text.size(), stdout); +} + +void +TuiFlush() +{ + fflush(stdout); +} + +void +TuiShowCursor(bool Show) +{ + if (Show) + { + printf("\033[?25h"); + } + else + { + printf("\033[?25l"); + } +} + } // namespace zen diff --git a/src/zenutil/include/zenutil/cloud/cloudprovider.h b/src/zenutil/include/zenutil/cloud/cloudprovider.h new file mode 100644 index 000000000..5825eb308 --- /dev/null +++ b/src/zenutil/include/zenutil/cloud/cloudprovider.h @@ -0,0 +1,19 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#pragma once + +#include <string_view> + +namespace zen::compute { + +enum class CloudProvider +{ + None, + AWS, + Azure, + GCP +}; + +std::string_view ToString(CloudProvider Provider); + +} // namespace zen::compute diff --git a/src/zenutil/include/zenutil/cloud/imdscredentials.h b/src/zenutil/include/zenutil/cloud/imdscredentials.h new file mode 100644 index 000000000..33df5a1e2 --- /dev/null +++ b/src/zenutil/include/zenutil/cloud/imdscredentials.h @@ -0,0 +1,58 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zenutil/cloud/sigv4.h> + +#include <zenbase/refcount.h> +#include <zencore/logging.h> +#include <zencore/thread.h> +#include <zenhttp/httpclient.h> + +#include <chrono> +#include <string> + +namespace zen { + +struct ImdsCredentialProviderOptions +{ + std::string Endpoint = "http://169.254.169.254"; // Override for testing + std::chrono::milliseconds ConnectTimeout{1000}; + std::chrono::milliseconds RequestTimeout{5000}; +}; + +/// Fetches and caches temporary AWS credentials from the EC2 Instance Metadata +/// Service (IMDSv2). Thread-safe; credentials are refreshed automatically before +/// they expire. +class ImdsCredentialProvider : public RefCounted +{ +public: + explicit ImdsCredentialProvider(const ImdsCredentialProviderOptions& Options = {}); + ~ImdsCredentialProvider(); + + /// Fetch or return cached credentials. Thread-safe. + /// Returns empty credentials (empty AccessKeyId) on failure. + SigV4Credentials GetCredentials(); + + /// Force a refresh on next GetCredentials() call. 
+ void InvalidateCache(); + +private: + bool FetchToken(); + bool FetchCredentials(); + + LoggerRef Log() { return m_Log; } + + LoggerRef m_Log; + HttpClient m_HttpClient; + + mutable RwLock m_Lock; + std::string m_ImdsToken; + SigV4Credentials m_CachedCredentials; + std::string m_RoleName; + std::chrono::steady_clock::time_point m_ExpiresAt; +}; + +void imdscredentials_forcelink(); + +} // namespace zen diff --git a/src/zenutil/include/zenutil/cloud/minioprocess.h b/src/zenutil/include/zenutil/cloud/minioprocess.h new file mode 100644 index 000000000..7af350e60 --- /dev/null +++ b/src/zenutil/include/zenutil/cloud/minioprocess.h @@ -0,0 +1,48 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zenbase/zenbase.h> + +#include <cstdint> +#include <memory> +#include <string> +#include <string_view> + +namespace zen { + +struct MinioProcessOptions +{ + uint16_t Port = 9000; + std::string RootUser = "minioadmin"; + std::string RootPassword = "minioadmin"; +}; + +class MinioProcess +{ +public: + explicit MinioProcess(const MinioProcessOptions& Options = {}); + ~MinioProcess(); + + MinioProcess(const MinioProcess&) = delete; + MinioProcess& operator=(const MinioProcess&) = delete; + + void SpawnMinioServer(); + void StopMinioServer(); + + /// Pre-create a bucket by creating a subdirectory in the MinIO data directory. + /// Can be called before or after SpawnMinioServer(). MinIO discovers these at startup + /// and also picks up new directories at runtime. 
+ void CreateBucket(std::string_view BucketName); + + uint16_t Port() const; + std::string_view RootUser() const; + std::string_view RootPassword() const; + std::string Endpoint() const; + +private: + struct Impl; + std::unique_ptr<Impl> m_Impl; +}; + +} // namespace zen diff --git a/src/zenutil/include/zenutil/cloud/mockimds.h b/src/zenutil/include/zenutil/cloud/mockimds.h new file mode 100644 index 000000000..d0c0155b0 --- /dev/null +++ b/src/zenutil/include/zenutil/cloud/mockimds.h @@ -0,0 +1,110 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zenhttp/httpserver.h> +#include <zenutil/cloud/cloudprovider.h> + +#include <string> + +#if ZEN_WITH_TESTS + +namespace zen::compute { + +/** + * Mock IMDS (Instance Metadata Service) for testing cloud metadata and + * credential providers. + * + * Implements an HttpService that responds to the same URL paths as the real + * cloud provider metadata endpoints (AWS IMDSv2, Azure IMDS, GCP metadata). + * Tests configure which provider is "active" and set the desired response + * values, then pass the mock server's address as the ImdsEndpoint to the + * CloudMetadata constructor. + * + * When a request arrives for a provider that is not the ActiveProvider, the + * mock returns 404, causing CloudMetadata to write a sentinel file and move + * on to the next provider — exactly like a failed probe on bare metal. + * + * All config fields are public and can be mutated between poll cycles to + * simulate state changes (e.g. a spot interruption appearing mid-run). + * + * Usage: + * MockImdsService Mock; + * Mock.ActiveProvider = CloudProvider::AWS; + * Mock.Aws.InstanceId = "i-test"; + * // ... stand up ASIO server, register Mock, create CloudMetadata with endpoint + */ +class MockImdsService : public HttpService +{ +public: + /** AWS IMDSv2 response configuration. 
*/ + struct AwsConfig + { + std::string Token = "mock-aws-token-v2"; + std::string InstanceId = "i-0123456789abcdef0"; + std::string AvailabilityZone = "us-east-1a"; + std::string LifeCycle = "on-demand"; // "spot" or "on-demand" + + // Empty string → endpoint returns 404 (instance not in an ASG). + // Non-empty → returned as the response body. "InService" means healthy; + // anything else (e.g. "Terminated:Wait") triggers termination detection. + std::string AutoscalingState; + + // Empty string → endpoint returns 404 (no spot interruption). + // Non-empty → returned as the response body, signalling a spot reclaim. + std::string SpotAction; + + // IAM credential fields for ImdsCredentialProvider testing + std::string IamRoleName = "test-role"; + std::string IamAccessKeyId = "ASIAIOSFODNN7EXAMPLE"; + std::string IamSecretAccessKey = "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"; + std::string IamSessionToken = "FwoGZXIvYXdzEBYaDEXAMPLETOKEN"; + std::string IamExpiration = "2099-01-01T00:00:00Z"; + }; + + /** Azure IMDS response configuration. */ + struct AzureConfig + { + std::string VmId = "vm-12345678-1234-1234-1234-123456789abc"; + std::string Location = "eastus"; + std::string Priority = "Regular"; // "Spot" or "Regular" + + // Empty → instance is not in a VM Scale Set (no autoscaling). + std::string VmScaleSetName; + + // Empty → no scheduled events. Set to "Preempt", "Terminate", or + // "Reboot" to simulate a termination-class event. + std::string ScheduledEventType; + std::string ScheduledEventStatus = "Scheduled"; + }; + + /** GCP metadata response configuration. */ + struct GcpConfig + { + std::string InstanceId = "1234567890123456789"; + std::string Zone = "projects/123456/zones/us-central1-a"; + std::string Preemptible = "FALSE"; // "TRUE" or "FALSE" + std::string MaintenanceEvent = "NONE"; // "NONE" or event description + }; + + /** Which provider's endpoints respond successfully. + * Requests targeting other providers receive 404. 
+ */ + CloudProvider ActiveProvider = CloudProvider::None; + + AwsConfig Aws; + AzureConfig Azure; + GcpConfig Gcp; + + const char* BaseUri() const override; + void HandleRequest(HttpServerRequest& Request) override; + +private: + void HandleAwsRequest(HttpServerRequest& Request); + void HandleAzureRequest(HttpServerRequest& Request); + void HandleGcpRequest(HttpServerRequest& Request); +}; + +} // namespace zen::compute + +#endif // ZEN_WITH_TESTS diff --git a/src/zenutil/include/zenutil/cloud/s3client.h b/src/zenutil/include/zenutil/cloud/s3client.h new file mode 100644 index 000000000..47501c5b5 --- /dev/null +++ b/src/zenutil/include/zenutil/cloud/s3client.h @@ -0,0 +1,215 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zenutil/cloud/imdscredentials.h> +#include <zenutil/cloud/sigv4.h> + +#include <zencore/iobuffer.h> +#include <zencore/logging.h> +#include <zenhttp/httpclient.h> + +#include <zencore/thread.h> + +#include <string> +#include <string_view> +#include <vector> + +namespace zen { + +struct S3ClientOptions +{ + std::string Region = "us-east-1"; + std::string BucketName; + std::string Endpoint; // e.g., "https://s3.us-east-1.amazonaws.com". If empty, derived from Region. + + /// Use path-style addressing (endpoint/bucket/key) instead of virtual-hosted style + /// (bucket.endpoint/key). Required for S3-compatible services like MinIO that don't + /// support virtual-hosted style. + bool PathStyle = false; + + SigV4Credentials Credentials; + + /// When set, credentials are fetched from EC2 IMDS on demand. + /// Overrides the static Credentials field. + Ref<ImdsCredentialProvider> CredentialProvider; + + std::chrono::milliseconds ConnectTimeout{5000}; + std::chrono::milliseconds Timeout{}; + uint8_t RetryCount = 3; +}; + +struct S3ObjectInfo +{ + std::string Key; + uint64_t Size = 0; + std::string ETag; + std::string LastModified; +}; + +/// Result type for S3 operations. Empty Error string indicates success. 
+struct S3Result +{ + std::string Error; + + bool IsSuccess() const { return Error.empty(); } + explicit operator bool() const { return IsSuccess(); } +}; + +enum class HeadObjectResult +{ + Found, + NotFound, + Error, +}; + +/// Result of GetObject — carries the downloaded content. +struct S3GetObjectResult : S3Result +{ + IoBuffer Content; + + std::string_view AsText() const { return std::string_view(reinterpret_cast<const char*>(Content.GetData()), Content.GetSize()); } +}; + +/// Result of HeadObject — carries object metadata and existence status. +struct S3HeadObjectResult : S3Result +{ + S3ObjectInfo Info; + HeadObjectResult Status = HeadObjectResult::NotFound; +}; + +/// Result of ListObjects — carries the list of matching objects. +struct S3ListObjectsResult : S3Result +{ + std::vector<S3ObjectInfo> Objects; +}; + +/// Result of CreateMultipartUpload — carries the upload ID. +struct S3CreateMultipartUploadResult : S3Result +{ + std::string UploadId; +}; + +/// Result of UploadPart — carries the part ETag. +struct S3UploadPartResult : S3Result +{ + std::string ETag; +}; + +/// Client for S3-compatible object storage. +/// +/// Supports basic object operations (GET, PUT, DELETE, HEAD), listing, multipart +/// uploads, and pre-signed URL generation. Requests are authenticated with AWS +/// Signature Version 4; the signing key is cached per day to avoid redundant HMAC +/// derivation. +/// +/// Limitations: +/// - Multipart uploads are sequential (no parallel part upload). +/// - XML responses are parsed with a minimal tag extractor that only decodes the five +/// standard XML entities; CDATA sections and nested/namespaced tags are not handled. +/// - Automatic credential refresh is supported via ImdsCredentialProvider. 
+class S3Client +{ +public: + explicit S3Client(const S3ClientOptions& Options); + ~S3Client(); + + /// Upload an object to S3 + S3Result PutObject(std::string_view Key, IoBuffer Content); + + /// Download an object from S3 + S3GetObjectResult GetObject(std::string_view Key); + + /// Delete an object from S3 + S3Result DeleteObject(std::string_view Key); + + /// Check if an object exists and get its metadata + S3HeadObjectResult HeadObject(std::string_view Key); + + /// List objects with the given prefix + /// @param MaxKeys Maximum number of keys to return (0 = default/1000) + S3ListObjectsResult ListObjects(std::string_view Prefix, uint32_t MaxKeys = 0); + + /// Multipart upload: initiate a multipart upload and return the upload ID + S3CreateMultipartUploadResult CreateMultipartUpload(std::string_view Key); + + /// Multipart upload: upload a single part + /// @param PartNumber Part number (1-based, 1 to 10000) + /// @param Content The part data (minimum 5 MB except for the last part) + S3UploadPartResult UploadPart(std::string_view Key, std::string_view UploadId, uint32_t PartNumber, IoBuffer Content); + + /// Multipart upload: complete a multipart upload by assembling previously uploaded parts + /// @param PartETags List of {part_number, etag} pairs from UploadPart calls + S3Result CompleteMultipartUpload(std::string_view Key, + std::string_view UploadId, + const std::vector<std::pair<uint32_t, std::string>>& PartETags); + + /// Multipart upload: abort an in-progress multipart upload, discarding all uploaded parts + S3Result AbortMultipartUpload(std::string_view Key, std::string_view UploadId); + + /// High-level multipart upload: automatically splits content into parts and uploads + /// @param PartSize Size of each part in bytes (minimum 5 MB, default 8 MB) + S3Result PutObjectMultipart(std::string_view Key, IoBuffer Content, uint64_t PartSize = 8 * 1024 * 1024); + + /// Generate a pre-signed URL for downloading an object (GET) + /// @param Key The object key + 
/// @param ExpiresIn URL validity duration (default 1 hour, max 7 days) + std::string GeneratePresignedGetUrl(std::string_view Key, std::chrono::seconds ExpiresIn = std::chrono::hours(1)); + + /// Generate a pre-signed URL for uploading an object (PUT) + /// @param Key The object key + /// @param ExpiresIn URL validity duration (default 1 hour, max 7 days) + std::string GeneratePresignedPutUrl(std::string_view Key, std::chrono::seconds ExpiresIn = std::chrono::hours(1)); + + std::string_view BucketName() const { return m_BucketName; } + std::string_view Region() const { return m_Region; } + +private: + /// Shared implementation for pre-signed URL generation + std::string GeneratePresignedUrlForMethod(std::string_view Key, std::string_view Method, std::chrono::seconds ExpiresIn); + + LoggerRef Log() { return m_Log; } + + /// Build the endpoint URL for the bucket + std::string BuildEndpoint() const; + + /// Build the host header value + std::string BuildHostHeader() const; + + /// Build the S3 object path from a key, accounting for path-style addressing + std::string KeyToPath(std::string_view Key) const; + + /// Build the bucket root path ("/" for virtual-hosted, "/bucket/" for path-style) + std::string BucketRootPath() const; + + /// Sign a request and return headers with Authorization, x-amz-date, x-amz-content-sha256 + HttpClient::KeyValueMap SignRequest(std::string_view Method, + std::string_view Path, + std::string_view QueryString, + std::string_view PayloadHash); + + /// Get or compute the signing key for the given date stamp, caching across requests on the same day + Sha256Digest GetSigningKey(std::string_view DateStamp); + + /// Get the current credentials, either from the provider or from static config + SigV4Credentials GetCurrentCredentials(); + + LoggerRef m_Log; + std::string m_BucketName; + std::string m_Region; + std::string m_Endpoint; + std::string m_Host; + bool m_PathStyle; + SigV4Credentials m_Credentials; + Ref<ImdsCredentialProvider> 
m_CredentialProvider; + HttpClient m_HttpClient; + + // Cached signing key (only changes once per day, protected by RwLock for thread safety) + mutable RwLock m_SigningKeyLock; + std::string m_CachedDateStamp; + Sha256Digest m_CachedSigningKey{}; +}; + +void s3client_forcelink(); + +} // namespace zen diff --git a/src/zenutil/include/zenutil/cloud/sigv4.h b/src/zenutil/include/zenutil/cloud/sigv4.h new file mode 100644 index 000000000..9ac08df76 --- /dev/null +++ b/src/zenutil/include/zenutil/cloud/sigv4.h @@ -0,0 +1,116 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/zencore.h> + +#include <array> +#include <chrono> +#include <string> +#include <string_view> +#include <vector> + +namespace zen { + +/// SHA-256 digest (32 bytes) +using Sha256Digest = std::array<uint8_t, 32>; + +/// Compute SHA-256 hash of the given data +Sha256Digest ComputeSha256(const void* Data, size_t Size); +Sha256Digest ComputeSha256(std::string_view Data); + +/// Compute HMAC-SHA256 with the given key and data +Sha256Digest ComputeHmacSha256(const void* Key, size_t KeySize, const void* Data, size_t DataSize); +Sha256Digest ComputeHmacSha256(const Sha256Digest& Key, std::string_view Data); + +/// Convert a SHA-256 digest to lowercase hex string +std::string Sha256ToHex(const Sha256Digest& Digest); + +/// Securely zero memory containing secret key material (prevents compiler from optimizing away) +void SecureZeroSecret(void* Data, size_t Size); + +/// AWS Signature Version 4 signing + +struct SigV4Credentials +{ + std::string AccessKeyId; + std::string SecretAccessKey; + std::string SessionToken; // Optional; required for temporary credentials (STS/SSO) +}; + +struct SigV4SignedHeaders +{ + /// The value for the "Authorization" header + std::string Authorization; + + /// The ISO 8601 date-time string used in signing (for x-amz-date header) + std::string AmzDate; + + /// The SHA-256 hex digest of the payload (for x-amz-content-sha256 header) + 
std::string PayloadHash; +}; + +/// Get the current UTC timestamp in ISO 8601 format (YYYYMMDDTHHMMSSZ) +std::string GetAmzTimestamp(); + +/// URI-encode a string per AWS requirements (RFC 3986 unreserved chars are not encoded) +/// @param EncodeSlash If false, '/' is left unencoded (use for URI paths) +std::string AwsUriEncode(std::string_view Input, bool EncodeSlash = true); + +/// Build a canonical query string from key=value pairs. +/// Parameters are URI-encoded and sorted by key name as required by SigV4. +/// Takes parameters by value to sort in-place without copying. +std::string BuildCanonicalQueryString(std::vector<std::pair<std::string, std::string>> Parameters); + +/// Sign an HTTP request using AWS Signature Version 4 +/// +/// @param Credentials AWS access key and secret key +/// @param Method HTTP method (GET, PUT, DELETE, HEAD, etc.) +/// @param Url The path portion of the URL (e.g., "/bucket/key") +/// @param CanonicalQueryString Pre-built canonical query string (use BuildCanonicalQueryString) +/// @param Region The AWS region (e.g., "us-east-1") +/// @param Service The AWS service (e.g., "s3") +/// @param AmzDate The ISO 8601 date-time string (from GetAmzTimestamp()) +/// @param Headers Sorted list of {lowercase-header-name, value} pairs to sign. +/// Must include "host" and "x-amz-content-sha256". +/// Should NOT include "authorization". +/// @param PayloadHash Hex SHA-256 hash of the request payload. Use +/// "UNSIGNED-PAYLOAD" for unsigned payloads. +/// @param SigningKey Optional pre-computed signing key. If null, derived from Credentials + date + Region + Service. 
+SigV4SignedHeaders SignRequestV4(const SigV4Credentials& Credentials, + std::string_view Method, + std::string_view Url, + std::string_view CanonicalQueryString, + std::string_view Region, + std::string_view Service, + std::string_view AmzDate, + const std::vector<std::pair<std::string, std::string>>& Headers, + std::string_view PayloadHash, + const Sha256Digest* SigningKey = nullptr); + +/// Generate a pre-signed URL using AWS Signature Version 4 query string authentication. +/// +/// The returned URL can be used by anyone (no credentials needed) until it expires. +/// +/// @param Credentials AWS access key and secret key +/// @param Method HTTP method the URL will be used with (typically "GET" or "PUT") +/// @param Scheme URL scheme ("https" or "http") +/// @param Host The host (e.g., "bucket.s3.us-east-1.amazonaws.com") +/// @param Path The path portion (e.g., "/key") +/// @param Region The AWS region (e.g., "us-east-1") +/// @param Service The AWS service (e.g., "s3") +/// @param ExpiresIn URL validity duration +/// @param ExtraQueryParams Additional query parameters to include (e.g., response-content-type) +std::string GeneratePresignedUrl(const SigV4Credentials& Credentials, + std::string_view Method, + std::string_view Scheme, + std::string_view Host, + std::string_view Path, + std::string_view Region, + std::string_view Service, + std::chrono::seconds ExpiresIn, + const std::vector<std::pair<std::string, std::string>>& ExtraQueryParams = {}); + +void sigv4_forcelink(); + +} // namespace zen diff --git a/src/zenutil/include/zenutil/consoletui.h b/src/zenutil/include/zenutil/consoletui.h index 5f74fa82b..22737589b 100644 --- a/src/zenutil/include/zenutil/consoletui.h +++ b/src/zenutil/include/zenutil/consoletui.h @@ -57,4 +57,32 @@ uint32_t TuiConsoleRows(uint32_t Default = 40); // Should only be called while in alternate screen mode. bool TuiPollQuit(); +// Set the scrollable region of the terminal to rows [Top, Bottom] (1-based, inclusive). 
+// Emits DECSTBM: ESC[<top>;<bottom>r +void TuiSetScrollRegion(uint32_t Top, uint32_t Bottom); + +// Reset the scroll region to the full terminal. Emits ESC[r +void TuiResetScrollRegion(); + +// Move cursor to the given 1-based row and column. Emits ESC[<row>;<col>H +void TuiMoveCursor(uint32_t Row, uint32_t Col); + +// Save cursor position. Emits ESC 7 +void TuiSaveCursor(); + +// Restore cursor position. Emits ESC 8 +void TuiRestoreCursor(); + +// Erase the entire current line. Emits ESC[2K +void TuiEraseLine(); + +// Write raw text to stdout without any formatting or newline. +void TuiWrite(std::string_view Text); + +// Flush stdout. +void TuiFlush(); + +// Show or hide the cursor. Emits ESC[?25h or ESC[?25l +void TuiShowCursor(bool Show); + } // namespace zen diff --git a/src/zenutil/include/zenutil/logging/fullformatter.h b/src/zenutil/include/zenutil/logging/fullformatter.h index 33cb94dae..0d026ed72 100644 --- a/src/zenutil/include/zenutil/logging/fullformatter.h +++ b/src/zenutil/include/zenutil/logging/fullformatter.h @@ -3,10 +3,8 @@ #pragma once #include <zencore/logging/formatter.h> -#include <zencore/logging/helpers.h> -#include <zencore/memory/llm.h> -#include <zencore/zencore.h> +#include <memory> #include <string_view> namespace zen::logging { @@ -14,195 +12,16 @@ namespace zen::logging { class FullFormatter final : public Formatter { public: - FullFormatter(std::string_view LogId, std::chrono::time_point<std::chrono::system_clock> Epoch) - : m_Epoch(Epoch) - , m_LogId(LogId) - , m_LinePrefix(128, ' ') - , m_UseFullDate(false) - { - } + FullFormatter(std::string_view LogId, std::chrono::time_point<std::chrono::system_clock> Epoch); + explicit FullFormatter(std::string_view LogId); + ~FullFormatter() override; - FullFormatter(std::string_view LogId) : m_LogId(LogId), m_LinePrefix(128, ' '), m_UseFullDate(true) {} - - virtual std::unique_ptr<Formatter> Clone() const override - { - ZEN_MEMSCOPE(ELLMTag::Logging); - if (m_UseFullDate) - { - return 
std::make_unique<FullFormatter>(m_LogId); - } - return std::make_unique<FullFormatter>(m_LogId, m_Epoch); - } - - virtual void Format(const LogMessage& Msg, MemoryBuffer& OutBuffer) override - { - ZEN_MEMSCOPE(ELLMTag::Logging); - - // Note that the sink is responsible for ensuring there is only ever a - // single caller in here - - using namespace std::literals; - - std::chrono::seconds TimestampSeconds; - - std::chrono::milliseconds Millis; - - if (m_UseFullDate) - { - TimestampSeconds = std::chrono::duration_cast<std::chrono::seconds>(Msg.GetTime().time_since_epoch()); - if (TimestampSeconds != m_LastLogSecs) - { - RwLock::ExclusiveLockScope _(m_TimestampLock); - m_LastLogSecs = TimestampSeconds; - - m_CachedLocalTm = helpers::SafeLocaltime(LogClock::to_time_t(Msg.GetTime())); - m_CachedDatetime.clear(); - m_CachedDatetime.push_back('['); - helpers::Pad2(m_CachedLocalTm.tm_year % 100, m_CachedDatetime); - m_CachedDatetime.push_back('-'); - helpers::Pad2(m_CachedLocalTm.tm_mon + 1, m_CachedDatetime); - m_CachedDatetime.push_back('-'); - helpers::Pad2(m_CachedLocalTm.tm_mday, m_CachedDatetime); - m_CachedDatetime.push_back(' '); - helpers::Pad2(m_CachedLocalTm.tm_hour, m_CachedDatetime); - m_CachedDatetime.push_back(':'); - helpers::Pad2(m_CachedLocalTm.tm_min, m_CachedDatetime); - m_CachedDatetime.push_back(':'); - helpers::Pad2(m_CachedLocalTm.tm_sec, m_CachedDatetime); - m_CachedDatetime.push_back('.'); - } - - Millis = helpers::TimeFraction<std::chrono::milliseconds>(Msg.GetTime()); - } - else - { - auto ElapsedTime = Msg.GetTime() - m_Epoch; - TimestampSeconds = std::chrono::duration_cast<std::chrono::seconds>(ElapsedTime); - - if (m_CacheTimestamp.load() != TimestampSeconds) - { - RwLock::ExclusiveLockScope _(m_TimestampLock); - - m_CacheTimestamp = TimestampSeconds; - - int Count = int(TimestampSeconds.count()); - const int LogSecs = Count % 60; - Count /= 60; - const int LogMins = Count % 60; - Count /= 60; - const int LogHours = Count; - - 
m_CachedDatetime.clear(); - m_CachedDatetime.push_back('['); - helpers::Pad2(LogHours, m_CachedDatetime); - m_CachedDatetime.push_back(':'); - helpers::Pad2(LogMins, m_CachedDatetime); - m_CachedDatetime.push_back(':'); - helpers::Pad2(LogSecs, m_CachedDatetime); - m_CachedDatetime.push_back('.'); - } - - Millis = std::chrono::duration_cast<std::chrono::milliseconds>(ElapsedTime - TimestampSeconds); - } - - { - RwLock::SharedLockScope _(m_TimestampLock); - OutBuffer.append(m_CachedDatetime.begin(), m_CachedDatetime.end()); - } - - helpers::Pad3(static_cast<uint32_t>(Millis.count()), OutBuffer); - OutBuffer.push_back(']'); - OutBuffer.push_back(' '); - - if (!m_LogId.empty()) - { - OutBuffer.push_back('['); - helpers::AppendStringView(m_LogId, OutBuffer); - OutBuffer.push_back(']'); - OutBuffer.push_back(' '); - } - - // append logger name if exists - if (Msg.GetLoggerName().size() > 0) - { - OutBuffer.push_back('['); - helpers::AppendStringView(Msg.GetLoggerName(), OutBuffer); - OutBuffer.push_back(']'); - OutBuffer.push_back(' '); - } - - OutBuffer.push_back('['); - // wrap the level name with color - Msg.ColorRangeStart = OutBuffer.size(); - helpers::AppendStringView(helpers::LevelToShortString(Msg.GetLevel()), OutBuffer); - Msg.ColorRangeEnd = OutBuffer.size(); - OutBuffer.push_back(']'); - OutBuffer.push_back(' '); - - // add source location if present - if (Msg.GetSource()) - { - OutBuffer.push_back('['); - const char* Filename = helpers::ShortFilename(Msg.GetSource().Filename); - helpers::AppendStringView(Filename, OutBuffer); - OutBuffer.push_back(':'); - helpers::AppendInt(Msg.GetSource().Line, OutBuffer); - OutBuffer.push_back(']'); - OutBuffer.push_back(' '); - } - - // Handle newlines in single log call by prefixing each additional line to make - // subsequent lines align with the first line in the message - - const size_t LinePrefixCount = Min<size_t>(OutBuffer.size(), m_LinePrefix.size()); - - auto MsgPayload = Msg.GetPayload(); - auto ItLineBegin = 
MsgPayload.begin(); - auto ItMessageEnd = MsgPayload.end(); - bool IsFirstline = true; - - { - auto ItLineEnd = ItLineBegin; - - auto EmitLine = [&] { - if (IsFirstline) - { - IsFirstline = false; - } - else - { - helpers::AppendStringView(std::string_view(m_LinePrefix.data(), LinePrefixCount), OutBuffer); - } - helpers::AppendStringView(std::string_view(&*ItLineBegin, ItLineEnd - ItLineBegin), OutBuffer); - }; - - while (ItLineEnd != ItMessageEnd) - { - if (*ItLineEnd++ == '\n') - { - EmitLine(); - ItLineBegin = ItLineEnd; - } - } - - if (ItLineBegin != ItMessageEnd) - { - EmitLine(); - helpers::AppendStringView("\n"sv, OutBuffer); - } - } - } + std::unique_ptr<Formatter> Clone() const override; + void Format(const LogMessage& Msg, MemoryBuffer& OutBuffer) override; private: - std::chrono::time_point<std::chrono::system_clock> m_Epoch; - std::tm m_CachedLocalTm; - std::chrono::seconds m_LastLogSecs{std::chrono::seconds(87654321)}; - std::atomic<std::chrono::seconds> m_CacheTimestamp{std::chrono::seconds(87654321)}; - MemoryBuffer m_CachedDatetime; - std::string m_LogId; - std::string m_LinePrefix; - bool m_UseFullDate = true; - RwLock m_TimestampLock; + struct Impl; + std::unique_ptr<Impl> m_Impl; }; } // namespace zen::logging diff --git a/src/zenutil/include/zenutil/logging/jsonformatter.h b/src/zenutil/include/zenutil/logging/jsonformatter.h index 216b1b5e5..fb3193f3e 100644 --- a/src/zenutil/include/zenutil/logging/jsonformatter.h +++ b/src/zenutil/include/zenutil/logging/jsonformatter.h @@ -3,158 +3,24 @@ #pragma once #include <zencore/logging/formatter.h> -#include <zencore/logging/helpers.h> -#include <zencore/memory/llm.h> -#include <zencore/zencore.h> +#include <memory> #include <string_view> -#include <unordered_map> namespace zen::logging { -using namespace std::literals; - class JsonFormatter final : public Formatter { public: - JsonFormatter(std::string_view LogId) : m_LogId(LogId) {} - - virtual std::unique_ptr<Formatter> Clone() const override { 
return std::make_unique<JsonFormatter>(m_LogId); } - - virtual void Format(const LogMessage& Msg, MemoryBuffer& Dest) override - { - ZEN_MEMSCOPE(ELLMTag::Logging); - - using std::chrono::duration_cast; - using std::chrono::milliseconds; - using std::chrono::seconds; - - auto Secs = std::chrono::duration_cast<seconds>(Msg.GetTime().time_since_epoch()); - if (Secs != m_LastLogSecs) - { - RwLock::ExclusiveLockScope _(m_TimestampLock); - m_CachedTm = helpers::SafeLocaltime(LogClock::to_time_t(Msg.GetTime())); - m_LastLogSecs = Secs; - - // cache the date/time part for the next second. - m_CachedDatetime.clear(); - - helpers::AppendInt(m_CachedTm.tm_year + 1900, m_CachedDatetime); - m_CachedDatetime.push_back('-'); - - helpers::Pad2(m_CachedTm.tm_mon + 1, m_CachedDatetime); - m_CachedDatetime.push_back('-'); - - helpers::Pad2(m_CachedTm.tm_mday, m_CachedDatetime); - m_CachedDatetime.push_back(' '); - - helpers::Pad2(m_CachedTm.tm_hour, m_CachedDatetime); - m_CachedDatetime.push_back(':'); - - helpers::Pad2(m_CachedTm.tm_min, m_CachedDatetime); - m_CachedDatetime.push_back(':'); - - helpers::Pad2(m_CachedTm.tm_sec, m_CachedDatetime); + explicit JsonFormatter(std::string_view LogId); + ~JsonFormatter() override; - m_CachedDatetime.push_back('.'); - } - helpers::AppendStringView("{"sv, Dest); - helpers::AppendStringView("\"time\": \""sv, Dest); - { - RwLock::SharedLockScope _(m_TimestampLock); - Dest.append(m_CachedDatetime.begin(), m_CachedDatetime.end()); - } - auto Millis = helpers::TimeFraction<milliseconds>(Msg.GetTime()); - helpers::Pad3(static_cast<uint32_t>(Millis.count()), Dest); - helpers::AppendStringView("\", "sv, Dest); - - helpers::AppendStringView("\"status\": \""sv, Dest); - helpers::AppendStringView(helpers::LevelToShortString(Msg.GetLevel()), Dest); - helpers::AppendStringView("\", "sv, Dest); - - helpers::AppendStringView("\"source\": \""sv, Dest); - helpers::AppendStringView("zenserver"sv, Dest); - helpers::AppendStringView("\", "sv, Dest); - - 
helpers::AppendStringView("\"service\": \""sv, Dest); - helpers::AppendStringView("zencache"sv, Dest); - helpers::AppendStringView("\", "sv, Dest); - - if (!m_LogId.empty()) - { - helpers::AppendStringView("\"id\": \""sv, Dest); - helpers::AppendStringView(m_LogId, Dest); - helpers::AppendStringView("\", "sv, Dest); - } - - if (Msg.GetLoggerName().size() > 0) - { - helpers::AppendStringView("\"logger.name\": \""sv, Dest); - helpers::AppendStringView(Msg.GetLoggerName(), Dest); - helpers::AppendStringView("\", "sv, Dest); - } - - if (Msg.GetThreadId() != 0) - { - helpers::AppendStringView("\"logger.thread_name\": \""sv, Dest); - helpers::PadUint(Msg.GetThreadId(), 0, Dest); - helpers::AppendStringView("\", "sv, Dest); - } - - if (Msg.GetSource()) - { - helpers::AppendStringView("\"file\": \""sv, Dest); - WriteEscapedString(Dest, helpers::ShortFilename(Msg.GetSource().Filename)); - helpers::AppendStringView("\","sv, Dest); - - helpers::AppendStringView("\"line\": \""sv, Dest); - helpers::AppendInt(Msg.GetSource().Line, Dest); - helpers::AppendStringView("\","sv, Dest); - } - - helpers::AppendStringView("\"message\": \""sv, Dest); - WriteEscapedString(Dest, Msg.GetPayload()); - helpers::AppendStringView("\""sv, Dest); - - helpers::AppendStringView("}\n"sv, Dest); - } + std::unique_ptr<Formatter> Clone() const override; + void Format(const LogMessage& Msg, MemoryBuffer& Dest) override; private: - static inline const std::unordered_map<char, std::string_view> s_SpecialCharacterMap{{'\b', "\\b"sv}, - {'\f', "\\f"sv}, - {'\n', "\\n"sv}, - {'\r', "\\r"sv}, - {'\t', "\\t"sv}, - {'"', "\\\""sv}, - {'\\', "\\\\"sv}}; - - static void WriteEscapedString(MemoryBuffer& Dest, const std::string_view& Text) - { - const char* RangeStart = Text.data(); - const char* End = Text.data() + Text.size(); - for (const char* It = RangeStart; It != End; ++It) - { - if (auto SpecialIt = s_SpecialCharacterMap.find(*It); SpecialIt != s_SpecialCharacterMap.end()) - { - if (RangeStart != It) - { - 
Dest.append(RangeStart, It); - } - helpers::AppendStringView(SpecialIt->second, Dest); - RangeStart = It + 1; - } - } - if (RangeStart != End) - { - Dest.append(RangeStart, End); - } - }; - - std::tm m_CachedTm{0, 0, 0, 0, 0, 0, 0, 0, 0}; - std::chrono::seconds m_LastLogSecs{0}; - MemoryBuffer m_CachedDatetime; - std::string m_LogId; - RwLock m_TimestampLock; + struct Impl; + std::unique_ptr<Impl> m_Impl; }; } // namespace zen::logging diff --git a/src/zenutil/include/zenutil/logging/rotatingfilesink.h b/src/zenutil/include/zenutil/logging/rotatingfilesink.h index cebc5b110..e0ff7eca1 100644 --- a/src/zenutil/include/zenutil/logging/rotatingfilesink.h +++ b/src/zenutil/include/zenutil/logging/rotatingfilesink.h @@ -2,14 +2,11 @@ #pragma once -#include <zencore/basicfile.h> -#include <zencore/logging/formatter.h> -#include <zencore/logging/messageonlyformatter.h> #include <zencore/logging/sink.h> -#include <zencore/memory/llm.h> -#include <atomic> +#include <cstddef> #include <filesystem> +#include <memory> namespace zen::logging { @@ -19,230 +16,21 @@ namespace zen::logging { class RotatingFileSink : public Sink { public: - RotatingFileSink(const std::filesystem::path& BaseFilename, std::size_t MaxSize, std::size_t MaxFiles, bool RotateOnOpen = false) - : m_BaseFilename(BaseFilename) - , m_MaxSize(MaxSize) - , m_MaxFiles(MaxFiles) - , m_Formatter(std::make_unique<MessageOnlyFormatter>()) - { - ZEN_MEMSCOPE(ELLMTag::Logging); - - std::error_code Ec; - if (RotateOnOpen) - { - RwLock::ExclusiveLockScope RotateLock(m_Lock); - Rotate(RotateLock, Ec); - } - else - { - m_CurrentFile.Open(m_BaseFilename, BasicFile::Mode::kWrite, Ec); - if (!Ec) - { - m_CurrentSize = m_CurrentFile.FileSize(Ec); - } - if (!Ec) - { - if (m_CurrentSize > m_MaxSize) - { - RwLock::ExclusiveLockScope RotateLock(m_Lock); - Rotate(RotateLock, Ec); - } - } - } - - if (Ec) - { - throw std::system_error(Ec, fmt::format("Failed to open log file '{}'", m_BaseFilename.string())); - } - } - - virtual 
~RotatingFileSink() - { - try - { - RwLock::ExclusiveLockScope RotateLock(m_Lock); - m_CurrentFile.Close(); - } - catch (const std::exception&) - { - } - } + RotatingFileSink(const std::filesystem::path& BaseFilename, std::size_t MaxSize, std::size_t MaxFiles, bool RotateOnOpen = false); + ~RotatingFileSink() override; RotatingFileSink(const RotatingFileSink&) = delete; RotatingFileSink(RotatingFileSink&&) = delete; - RotatingFileSink& operator=(const RotatingFileSink&) = delete; RotatingFileSink& operator=(RotatingFileSink&&) = delete; - virtual void Log(const LogMessage& Msg) override - { - ZEN_MEMSCOPE(ELLMTag::Logging); - - try - { - MemoryBuffer Formatted; - if (TrySinkIt(Msg, Formatted)) - { - return; - } - - // This intentionally has no limit on the number of retries, see - // comment above. - for (;;) - { - { - RwLock::ExclusiveLockScope RotateLock(m_Lock); - // Only rotate if no-one else has rotated before us - if (m_CurrentSize > m_MaxSize || !m_CurrentFile.IsOpen()) - { - std::error_code Ec; - Rotate(RotateLock, Ec); - if (Ec) - { - return; - } - } - } - if (TrySinkIt(Formatted)) - { - return; - } - } - } - catch (const std::exception&) - { - // Silently eat errors - } - } - virtual void Flush() override - { - if (!m_NeedFlush) - { - return; - } - - ZEN_MEMSCOPE(ELLMTag::Logging); - - try - { - RwLock::SharedLockScope Lock(m_Lock); - if (m_CurrentFile.IsOpen()) - { - m_CurrentFile.Flush(); - } - } - catch (const std::exception&) - { - // Silently eat errors - } - - m_NeedFlush = false; - } - - virtual void SetFormatter(std::unique_ptr<Formatter> InFormatter) override - { - ZEN_MEMSCOPE(ELLMTag::Logging); - - try - { - RwLock::ExclusiveLockScope _(m_Lock); - m_Formatter = std::move(InFormatter); - } - catch (const std::exception&) - { - // Silently eat errors - } - } + void Log(const LogMessage& Msg) override; + void Flush() override; + void SetFormatter(std::unique_ptr<Formatter> InFormatter) override; private: - void Rotate(RwLock::ExclusiveLockScope&, 
std::error_code& OutEc) - { - ZEN_MEMSCOPE(ELLMTag::Logging); - - m_CurrentFile.Close(); - - OutEc = RotateFiles(m_BaseFilename, m_MaxFiles); - if (OutEc) - { - return; - } - - m_CurrentFile.Open(m_BaseFilename, BasicFile::Mode::kWrite, OutEc); - if (OutEc) - { - return; - } - - m_CurrentSize = m_CurrentFile.FileSize(OutEc); - if (OutEc) - { - // FileSize failed but we have an open file — reset to 0 - // so we can at least attempt writes from the start - m_CurrentSize = 0; - OutEc.clear(); - } - } - - bool TrySinkIt(const LogMessage& Msg, MemoryBuffer& OutFormatted) - { - ZEN_MEMSCOPE(ELLMTag::Logging); - - RwLock::SharedLockScope Lock(m_Lock); - if (!m_CurrentFile.IsOpen()) - { - return false; - } - m_Formatter->Format(Msg, OutFormatted); - size_t AddSize = OutFormatted.size(); - size_t WritePos = m_CurrentSize.fetch_add(AddSize); - if (WritePos + AddSize > m_MaxSize) - { - return false; - } - std::error_code Ec; - m_CurrentFile.Write(OutFormatted.data(), OutFormatted.size(), WritePos, Ec); - if (Ec) - { - return false; - } - m_NeedFlush = true; - return true; - } - - bool TrySinkIt(const MemoryBuffer& Formatted) - { - ZEN_MEMSCOPE(ELLMTag::Logging); - - RwLock::SharedLockScope Lock(m_Lock); - if (!m_CurrentFile.IsOpen()) - { - return false; - } - size_t AddSize = Formatted.size(); - size_t WritePos = m_CurrentSize.fetch_add(AddSize); - if (WritePos + AddSize > m_MaxSize) - { - return false; - } - - std::error_code Ec; - m_CurrentFile.Write(Formatted.data(), Formatted.size(), WritePos, Ec); - if (Ec) - { - return false; - } - m_NeedFlush = true; - return true; - } - - RwLock m_Lock; - const std::filesystem::path m_BaseFilename; - const std::size_t m_MaxSize; - const std::size_t m_MaxFiles; - std::unique_ptr<Formatter> m_Formatter; - std::atomic_size_t m_CurrentSize; - BasicFile m_CurrentFile; - std::atomic<bool> m_NeedFlush = false; + struct Impl; + std::unique_ptr<Impl> m_Impl; }; } // namespace zen::logging diff --git a/src/zenutil/logging/fullformatter.cpp 
b/src/zenutil/logging/fullformatter.cpp new file mode 100644 index 000000000..2a4840241 --- /dev/null +++ b/src/zenutil/logging/fullformatter.cpp @@ -0,0 +1,235 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zenutil/logging/fullformatter.h> + +#include <zencore/intmath.h> +#include <zencore/logging/helpers.h> +#include <zencore/logging/memorybuffer.h> +#include <zencore/memory/llm.h> +#include <zencore/thread.h> +#include <zencore/zencore.h> + +#include <atomic> +#include <chrono> +#include <string> + +namespace zen::logging { + +struct FullFormatter::Impl +{ + Impl(std::string_view LogId, std::chrono::time_point<std::chrono::system_clock> Epoch) + : m_Epoch(Epoch) + , m_LogId(LogId) + , m_LinePrefix(128, ' ') + , m_UseFullDate(false) + { + } + + explicit Impl(std::string_view LogId) : m_LogId(LogId), m_LinePrefix(128, ' '), m_UseFullDate(true) {} + + std::chrono::time_point<std::chrono::system_clock> m_Epoch; + std::tm m_CachedLocalTm{}; + std::chrono::seconds m_LastLogSecs{std::chrono::seconds(87654321)}; + std::atomic<std::chrono::seconds> m_CacheTimestamp{std::chrono::seconds(87654321)}; + MemoryBuffer m_CachedDatetime; + std::string m_LogId; + std::string m_LinePrefix; + bool m_UseFullDate = true; + RwLock m_TimestampLock; +}; + +FullFormatter::FullFormatter(std::string_view LogId, std::chrono::time_point<std::chrono::system_clock> Epoch) +: m_Impl(std::make_unique<Impl>(LogId, Epoch)) +{ +} + +FullFormatter::FullFormatter(std::string_view LogId) : m_Impl(std::make_unique<Impl>(LogId)) +{ +} + +FullFormatter::~FullFormatter() = default; + +std::unique_ptr<Formatter> +FullFormatter::Clone() const +{ + ZEN_MEMSCOPE(ELLMTag::Logging); + std::unique_ptr<FullFormatter> Copy; + if (m_Impl->m_UseFullDate) + { + Copy = std::make_unique<FullFormatter>(m_Impl->m_LogId); + } + else + { + Copy = std::make_unique<FullFormatter>(m_Impl->m_LogId, m_Impl->m_Epoch); + } + Copy->SetColorEnabled(IsColorEnabled()); + return Copy; +} + +void 
+FullFormatter::Format(const LogMessage& Msg, MemoryBuffer& OutBuffer) +{ + ZEN_MEMSCOPE(ELLMTag::Logging); + + // Note that the sink is responsible for ensuring there is only ever a + // single caller in here + + using namespace std::literals; + + std::chrono::seconds TimestampSeconds; + + std::chrono::milliseconds Millis; + + if (m_Impl->m_UseFullDate) + { + TimestampSeconds = std::chrono::duration_cast<std::chrono::seconds>(Msg.GetTime().time_since_epoch()); + if (TimestampSeconds != m_Impl->m_LastLogSecs) + { + RwLock::ExclusiveLockScope _(m_Impl->m_TimestampLock); + m_Impl->m_LastLogSecs = TimestampSeconds; + + m_Impl->m_CachedLocalTm = helpers::SafeLocaltime(LogClock::to_time_t(Msg.GetTime())); + m_Impl->m_CachedDatetime.clear(); + m_Impl->m_CachedDatetime.push_back('['); + helpers::Pad2(m_Impl->m_CachedLocalTm.tm_year % 100, m_Impl->m_CachedDatetime); + m_Impl->m_CachedDatetime.push_back('-'); + helpers::Pad2(m_Impl->m_CachedLocalTm.tm_mon + 1, m_Impl->m_CachedDatetime); + m_Impl->m_CachedDatetime.push_back('-'); + helpers::Pad2(m_Impl->m_CachedLocalTm.tm_mday, m_Impl->m_CachedDatetime); + m_Impl->m_CachedDatetime.push_back(' '); + helpers::Pad2(m_Impl->m_CachedLocalTm.tm_hour, m_Impl->m_CachedDatetime); + m_Impl->m_CachedDatetime.push_back(':'); + helpers::Pad2(m_Impl->m_CachedLocalTm.tm_min, m_Impl->m_CachedDatetime); + m_Impl->m_CachedDatetime.push_back(':'); + helpers::Pad2(m_Impl->m_CachedLocalTm.tm_sec, m_Impl->m_CachedDatetime); + m_Impl->m_CachedDatetime.push_back('.'); + } + + Millis = helpers::TimeFraction<std::chrono::milliseconds>(Msg.GetTime()); + } + else + { + auto ElapsedTime = Msg.GetTime() - m_Impl->m_Epoch; + TimestampSeconds = std::chrono::duration_cast<std::chrono::seconds>(ElapsedTime); + + if (m_Impl->m_CacheTimestamp.load() != TimestampSeconds) + { + RwLock::ExclusiveLockScope _(m_Impl->m_TimestampLock); + + m_Impl->m_CacheTimestamp = TimestampSeconds; + + int Count = int(TimestampSeconds.count()); + const int LogSecs = Count % 60; + 
Count /= 60; + const int LogMins = Count % 60; + Count /= 60; + const int LogHours = Count; + + m_Impl->m_CachedDatetime.clear(); + m_Impl->m_CachedDatetime.push_back('['); + helpers::Pad2(LogHours, m_Impl->m_CachedDatetime); + m_Impl->m_CachedDatetime.push_back(':'); + helpers::Pad2(LogMins, m_Impl->m_CachedDatetime); + m_Impl->m_CachedDatetime.push_back(':'); + helpers::Pad2(LogSecs, m_Impl->m_CachedDatetime); + m_Impl->m_CachedDatetime.push_back('.'); + } + + Millis = std::chrono::duration_cast<std::chrono::milliseconds>(ElapsedTime - TimestampSeconds); + } + + { + RwLock::SharedLockScope _(m_Impl->m_TimestampLock); + OutBuffer.append(m_Impl->m_CachedDatetime.begin(), m_Impl->m_CachedDatetime.end()); + } + + helpers::Pad3(static_cast<uint32_t>(Millis.count()), OutBuffer); + OutBuffer.push_back(']'); + OutBuffer.push_back(' '); + + if (!m_Impl->m_LogId.empty()) + { + OutBuffer.push_back('['); + helpers::AppendStringView(m_Impl->m_LogId, OutBuffer); + OutBuffer.push_back(']'); + OutBuffer.push_back(' '); + } + + // append logger name if exists + if (Msg.GetLoggerName().size() > 0) + { + OutBuffer.push_back('['); + helpers::AppendStringView(Msg.GetLoggerName(), OutBuffer); + OutBuffer.push_back(']'); + OutBuffer.push_back(' '); + } + + OutBuffer.push_back('['); + if (IsColorEnabled()) + { + helpers::AppendAnsiColor(Msg.GetLevel(), OutBuffer); + } + helpers::AppendStringView(helpers::LevelToShortString(Msg.GetLevel()), OutBuffer); + if (IsColorEnabled()) + { + helpers::AppendAnsiReset(OutBuffer); + } + OutBuffer.push_back(']'); + OutBuffer.push_back(' '); + + // add source location if present + if (Msg.GetSource()) + { + OutBuffer.push_back('['); + const char* Filename = helpers::ShortFilename(Msg.GetSource().Filename); + helpers::AppendStringView(Filename, OutBuffer); + OutBuffer.push_back(':'); + helpers::AppendInt(Msg.GetSource().Line, OutBuffer); + OutBuffer.push_back(']'); + OutBuffer.push_back(' '); + } + + // Handle newlines in single log call by prefixing 
each additional line to make + // subsequent lines align with the first line in the message + + size_t AnsiBytes = IsColorEnabled() ? (helpers::AnsiColorForLevel(Msg.GetLevel()).size() + helpers::kAnsiReset.size()) : 0; + const size_t LinePrefixCount = Min<size_t>(OutBuffer.size() - AnsiBytes, m_Impl->m_LinePrefix.size()); + + auto MsgPayload = Msg.GetPayload(); + auto ItLineBegin = MsgPayload.begin(); + auto ItMessageEnd = MsgPayload.end(); + bool IsFirstline = true; + + { + auto ItLineEnd = ItLineBegin; + + auto EmitLine = [&] { + if (IsFirstline) + { + IsFirstline = false; + } + else + { + helpers::AppendStringView(std::string_view(m_Impl->m_LinePrefix.data(), LinePrefixCount), OutBuffer); + } + helpers::AppendStringView(std::string_view(&*ItLineBegin, ItLineEnd - ItLineBegin), OutBuffer); + }; + + while (ItLineEnd != ItMessageEnd) + { + if (*ItLineEnd++ == '\n') + { + EmitLine(); + ItLineBegin = ItLineEnd; + } + } + + if (ItLineBegin != ItMessageEnd) + { + EmitLine(); + helpers::AppendStringView("\n"sv, OutBuffer); + } + } +} + +} // namespace zen::logging diff --git a/src/zenutil/logging/jsonformatter.cpp b/src/zenutil/logging/jsonformatter.cpp new file mode 100644 index 000000000..673a03c94 --- /dev/null +++ b/src/zenutil/logging/jsonformatter.cpp @@ -0,0 +1,198 @@ +// Copyright Epic Games, Inc. All Rights Reserved. 
+ +#include <zenutil/logging/jsonformatter.h> + +#include <zencore/logging/helpers.h> +#include <zencore/memory/llm.h> +#include <zencore/thread.h> +#include <zencore/zencore.h> + +#include <chrono> +#include <string> +#include <unordered_map> + +namespace zen::logging { + +using namespace std::literals; + +static void +WriteEscapedString(MemoryBuffer& Dest, std::string_view Text) +{ + // Strip ANSI SGR sequences before escaping so they don't appear in JSON output + static const auto IsEscapeStart = [](char C) { return C == '\033'; }; + + const char* RangeStart = Text.data(); + const char* End = Text.data() + Text.size(); + + static const std::unordered_map<char, std::string_view> SpecialCharacterMap{ + {'\b', "\\b"sv}, + {'\f', "\\f"sv}, + {'\n', "\\n"sv}, + {'\r', "\\r"sv}, + {'\t', "\\t"sv}, + {'"', "\\\""sv}, + {'\\', "\\\\"sv}, + }; + + for (const char* It = RangeStart; It != End; ++It) + { + // Skip ANSI SGR escape sequences (\033[...m) + if (*It == '\033' && (It + 1) < End && *(It + 1) == '[') + { + if (RangeStart != It) + { + Dest.append(RangeStart, It); + } + const char* Seq = It + 2; + while (Seq < End && *Seq != 'm') + { + ++Seq; + } + if (Seq < End) + { + ++Seq; // skip 'm' + } + It = Seq - 1; // -1 because the for loop will ++It + RangeStart = Seq; + continue; + } + + if (auto SpecialIt = SpecialCharacterMap.find(*It); SpecialIt != SpecialCharacterMap.end()) + { + if (RangeStart != It) + { + Dest.append(RangeStart, It); + } + helpers::AppendStringView(SpecialIt->second, Dest); + RangeStart = It + 1; + } + } + if (RangeStart != End) + { + Dest.append(RangeStart, End); + } +} + +struct JsonFormatter::Impl +{ + explicit Impl(std::string_view LogId) : m_LogId(LogId) {} + + std::tm m_CachedTm{0, 0, 0, 0, 0, 0, 0, 0, 0}; + std::chrono::seconds m_LastLogSecs{0}; + MemoryBuffer m_CachedDatetime; + std::string m_LogId; + RwLock m_TimestampLock; +}; + +JsonFormatter::JsonFormatter(std::string_view LogId) : m_Impl(std::make_unique<Impl>(LogId)) +{ +} + 
+JsonFormatter::~JsonFormatter() = default; + +std::unique_ptr<Formatter> +JsonFormatter::Clone() const +{ + return std::make_unique<JsonFormatter>(m_Impl->m_LogId); +} + +void +JsonFormatter::Format(const LogMessage& Msg, MemoryBuffer& Dest) +{ + ZEN_MEMSCOPE(ELLMTag::Logging); + + using std::chrono::duration_cast; + using std::chrono::milliseconds; + using std::chrono::seconds; + + auto Secs = duration_cast<seconds>(Msg.GetTime().time_since_epoch()); + if (Secs != m_Impl->m_LastLogSecs) + { + RwLock::ExclusiveLockScope _(m_Impl->m_TimestampLock); + m_Impl->m_CachedTm = helpers::SafeLocaltime(LogClock::to_time_t(Msg.GetTime())); + m_Impl->m_LastLogSecs = Secs; + + // cache the date/time part for the next second. + m_Impl->m_CachedDatetime.clear(); + + helpers::AppendInt(m_Impl->m_CachedTm.tm_year + 1900, m_Impl->m_CachedDatetime); + m_Impl->m_CachedDatetime.push_back('-'); + + helpers::Pad2(m_Impl->m_CachedTm.tm_mon + 1, m_Impl->m_CachedDatetime); + m_Impl->m_CachedDatetime.push_back('-'); + + helpers::Pad2(m_Impl->m_CachedTm.tm_mday, m_Impl->m_CachedDatetime); + m_Impl->m_CachedDatetime.push_back(' '); + + helpers::Pad2(m_Impl->m_CachedTm.tm_hour, m_Impl->m_CachedDatetime); + m_Impl->m_CachedDatetime.push_back(':'); + + helpers::Pad2(m_Impl->m_CachedTm.tm_min, m_Impl->m_CachedDatetime); + m_Impl->m_CachedDatetime.push_back(':'); + + helpers::Pad2(m_Impl->m_CachedTm.tm_sec, m_Impl->m_CachedDatetime); + + m_Impl->m_CachedDatetime.push_back('.'); + } + helpers::AppendStringView("{"sv, Dest); + helpers::AppendStringView("\"time\": \""sv, Dest); + { + RwLock::SharedLockScope _(m_Impl->m_TimestampLock); + Dest.append(m_Impl->m_CachedDatetime.begin(), m_Impl->m_CachedDatetime.end()); + } + auto Millis = helpers::TimeFraction<milliseconds>(Msg.GetTime()); + helpers::Pad3(static_cast<uint32_t>(Millis.count()), Dest); + helpers::AppendStringView("\", "sv, Dest); + + helpers::AppendStringView("\"status\": \""sv, Dest); + 
helpers::AppendStringView(helpers::LevelToShortString(Msg.GetLevel()), Dest); + helpers::AppendStringView("\", "sv, Dest); + + helpers::AppendStringView("\"source\": \""sv, Dest); + helpers::AppendStringView("zenserver"sv, Dest); + helpers::AppendStringView("\", "sv, Dest); + + helpers::AppendStringView("\"service\": \""sv, Dest); + helpers::AppendStringView("zencache"sv, Dest); + helpers::AppendStringView("\", "sv, Dest); + + if (!m_Impl->m_LogId.empty()) + { + helpers::AppendStringView("\"id\": \""sv, Dest); + helpers::AppendStringView(m_Impl->m_LogId, Dest); + helpers::AppendStringView("\", "sv, Dest); + } + + if (Msg.GetLoggerName().size() > 0) + { + helpers::AppendStringView("\"logger.name\": \""sv, Dest); + helpers::AppendStringView(Msg.GetLoggerName(), Dest); + helpers::AppendStringView("\", "sv, Dest); + } + + if (Msg.GetThreadId() != 0) + { + helpers::AppendStringView("\"logger.thread_name\": \""sv, Dest); + helpers::PadUint(Msg.GetThreadId(), 0, Dest); + helpers::AppendStringView("\", "sv, Dest); + } + + if (Msg.GetSource()) + { + helpers::AppendStringView("\"file\": \""sv, Dest); + WriteEscapedString(Dest, helpers::ShortFilename(Msg.GetSource().Filename)); + helpers::AppendStringView("\","sv, Dest); + + helpers::AppendStringView("\"line\": \""sv, Dest); + helpers::AppendInt(Msg.GetSource().Line, Dest); + helpers::AppendStringView("\","sv, Dest); + } + + helpers::AppendStringView("\"message\": \""sv, Dest); + WriteEscapedString(Dest, Msg.GetPayload()); + helpers::AppendStringView("\""sv, Dest); + + helpers::AppendStringView("}\n"sv, Dest); +} + +} // namespace zen::logging diff --git a/src/zenutil/logging.cpp b/src/zenutil/logging/logging.cpp index 1258ca155..ea2448a42 100644 --- a/src/zenutil/logging.cpp +++ b/src/zenutil/logging/logging.cpp @@ -238,8 +238,10 @@ FinishInitializeLogging(const LoggingOptions& LogOptions) const std::string StartLogTime = zen::DateTime::Now().ToIso8601(); - static constinit logging::LogPoint LogStartPoint{{}, logging::Info, 
"log starting at {}"}; - logging::Registry::Instance().ApplyAll([&](auto Logger) { Logger->Log(LogStartPoint, fmt::make_format_args(StartLogTime)); }); + logging::Registry::Instance().ApplyAll([&](auto Logger) { + static constinit logging::LogPoint LogStartPoint{{}, logging::Info, "log starting at {}"}; + Logger->Log(LogStartPoint, fmt::make_format_args(StartLogTime)); + }); } g_IsLoggingInitialized = true; diff --git a/src/zenutil/logging/rotatingfilesink.cpp b/src/zenutil/logging/rotatingfilesink.cpp new file mode 100644 index 000000000..23cf60d16 --- /dev/null +++ b/src/zenutil/logging/rotatingfilesink.cpp @@ -0,0 +1,249 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zenutil/logging/rotatingfilesink.h> + +#include <zencore/basicfile.h> +#include <zencore/filesystem.h> +#include <zencore/logging/helpers.h> +#include <zencore/logging/messageonlyformatter.h> +#include <zencore/memory/llm.h> +#include <zencore/thread.h> + +#include <atomic> + +namespace zen::logging { + +struct RotatingFileSink::Impl +{ + Impl(const std::filesystem::path& BaseFilename, std::size_t MaxSize, std::size_t MaxFiles, bool RotateOnOpen) + : m_BaseFilename(BaseFilename) + , m_MaxSize(MaxSize) + , m_MaxFiles(MaxFiles) + , m_Formatter(std::make_unique<MessageOnlyFormatter>()) + { + ZEN_MEMSCOPE(ELLMTag::Logging); + + std::error_code Ec; + if (RotateOnOpen) + { + RwLock::ExclusiveLockScope RotateLock(m_Lock); + Rotate(RotateLock, Ec); + } + else + { + m_CurrentFile.Open(m_BaseFilename, BasicFile::Mode::kWrite, Ec); + if (!Ec) + { + m_CurrentSize = m_CurrentFile.FileSize(Ec); + } + if (!Ec) + { + if (m_CurrentSize > m_MaxSize) + { + RwLock::ExclusiveLockScope RotateLock(m_Lock); + Rotate(RotateLock, Ec); + } + } + } + + if (Ec) + { + throw std::system_error(Ec, fmt::format("Failed to open log file '{}'", m_BaseFilename.string())); + } + } + + ~Impl() + { + try + { + RwLock::ExclusiveLockScope RotateLock(m_Lock); + m_CurrentFile.Close(); + } + catch (const std::exception&) + 
{ + } + } + + void Rotate(RwLock::ExclusiveLockScope&, std::error_code& OutEc) + { + ZEN_MEMSCOPE(ELLMTag::Logging); + + m_CurrentFile.Close(); + + OutEc = RotateFiles(m_BaseFilename, m_MaxFiles); + if (OutEc) + { + return; + } + + m_CurrentFile.Open(m_BaseFilename, BasicFile::Mode::kWrite, OutEc); + if (OutEc) + { + return; + } + + m_CurrentSize = m_CurrentFile.FileSize(OutEc); + if (OutEc) + { + // FileSize failed but we have an open file — reset to 0 + // so we can at least attempt writes from the start + m_CurrentSize = 0; + OutEc.clear(); + } + } + + bool TrySinkIt(const LogMessage& Msg, MemoryBuffer& OutFormatted) + { + ZEN_MEMSCOPE(ELLMTag::Logging); + + RwLock::SharedLockScope Lock(m_Lock); + if (!m_CurrentFile.IsOpen()) + { + return false; + } + m_Formatter->Format(Msg, OutFormatted); + helpers::StripAnsiSgrSequences(OutFormatted); + size_t AddSize = OutFormatted.size(); + size_t WritePos = m_CurrentSize.fetch_add(AddSize); + if (WritePos + AddSize > m_MaxSize) + { + return false; + } + std::error_code Ec; + m_CurrentFile.Write(OutFormatted.data(), OutFormatted.size(), WritePos, Ec); + if (Ec) + { + return false; + } + m_NeedFlush = true; + return true; + } + + bool TrySinkIt(const MemoryBuffer& Formatted) + { + ZEN_MEMSCOPE(ELLMTag::Logging); + + RwLock::SharedLockScope Lock(m_Lock); + if (!m_CurrentFile.IsOpen()) + { + return false; + } + size_t AddSize = Formatted.size(); + size_t WritePos = m_CurrentSize.fetch_add(AddSize); + if (WritePos + AddSize > m_MaxSize) + { + return false; + } + + std::error_code Ec; + m_CurrentFile.Write(Formatted.data(), Formatted.size(), WritePos, Ec); + if (Ec) + { + return false; + } + m_NeedFlush = true; + return true; + } + + RwLock m_Lock; + const std::filesystem::path m_BaseFilename; + const std::size_t m_MaxSize; + const std::size_t m_MaxFiles; + std::unique_ptr<Formatter> m_Formatter; + std::atomic_size_t m_CurrentSize; + BasicFile m_CurrentFile; + std::atomic<bool> m_NeedFlush = false; +}; + 
+RotatingFileSink::RotatingFileSink(const std::filesystem::path& BaseFilename, std::size_t MaxSize, std::size_t MaxFiles, bool RotateOnOpen) +: m_Impl(std::make_unique<Impl>(BaseFilename, MaxSize, MaxFiles, RotateOnOpen)) +{ +} + +RotatingFileSink::~RotatingFileSink() = default; + +void +RotatingFileSink::Log(const LogMessage& Msg) +{ + ZEN_MEMSCOPE(ELLMTag::Logging); + + try + { + MemoryBuffer Formatted; + if (m_Impl->TrySinkIt(Msg, Formatted)) + { + return; + } + + // This intentionally has no limit on the number of retries, see + // comment above. + for (;;) + { + { + RwLock::ExclusiveLockScope RotateLock(m_Impl->m_Lock); + // Only rotate if no-one else has rotated before us + if (m_Impl->m_CurrentSize > m_Impl->m_MaxSize || !m_Impl->m_CurrentFile.IsOpen()) + { + std::error_code Ec; + m_Impl->Rotate(RotateLock, Ec); + if (Ec) + { + return; + } + } + } + if (m_Impl->TrySinkIt(Formatted)) + { + return; + } + } + } + catch (const std::exception&) + { + // Silently eat errors + } +} + +void +RotatingFileSink::Flush() +{ + if (!m_Impl->m_NeedFlush) + { + return; + } + + ZEN_MEMSCOPE(ELLMTag::Logging); + + try + { + RwLock::SharedLockScope Lock(m_Impl->m_Lock); + if (m_Impl->m_CurrentFile.IsOpen()) + { + m_Impl->m_CurrentFile.Flush(); + } + } + catch (const std::exception&) + { + // Silently eat errors + } + + m_Impl->m_NeedFlush = false; +} + +void +RotatingFileSink::SetFormatter(std::unique_ptr<Formatter> InFormatter) +{ + ZEN_MEMSCOPE(ELLMTag::Logging); + + try + { + RwLock::ExclusiveLockScope _(m_Impl->m_Lock); + m_Impl->m_Formatter = std::move(InFormatter); + } + catch (const std::exception&) + { + // Silently eat errors + } +} + +} // namespace zen::logging diff --git a/src/zenutil/rpcrecording.cpp b/src/zenutil/rpcrecording.cpp index 28a0091cb..a9e95b9ce 100644 --- a/src/zenutil/rpcrecording.cpp +++ b/src/zenutil/rpcrecording.cpp @@ -17,7 +17,9 @@ ZEN_THIRD_PARTY_INCLUDES_START #include <gsl/gsl-lite.hpp> ZEN_THIRD_PARTY_INCLUDES_END +#include 
<condition_variable> #include <deque> +#include <mutex> #include <thread> namespace zen::cache { @@ -282,7 +284,6 @@ const uint64_t LooseFileThreshold = 5000; // Somewhat arbitrary, but we try // for performance const uint64_t SegmentByteThreshold = 16ull * 1024 * 1024 * 1024; const TimeSpan SegmentTimeThreshold{/* hours */ 1, /* minutes */ 0, /* seconds */ 0}; -const int64_t MaximumBacklogCount = 2000; std::string MakeSegmentPath(uint64_t SegmentIndex) @@ -366,10 +367,10 @@ private: }; std::unique_ptr<std::thread> m_WriterThread; - std::atomic_bool m_IsWriterReady{false}; std::atomic_bool m_IsActive{false}; - std::atomic_int64_t m_PendingRequests{0}; - RwLock m_RequestQueueLock; + std::mutex m_QueueMutex; + std::condition_variable m_QueueCondition; + bool m_IsWriterReady = false; std::deque<QueuedRequest> m_RequestQueue; void WriterThreadMain(); @@ -660,7 +661,8 @@ RecordedRequestsWriter::BeginWrite(const std::filesystem::path& BasePath) m_WriterThread.reset(new std::thread(&RecordedRequestsWriter::WriterThreadMain, this)); - m_IsWriterReady.wait(false); + std::unique_lock<std::mutex> Lock(m_QueueMutex); + m_QueueCondition.wait(Lock, [this] { return m_IsWriterReady; }); } void @@ -668,15 +670,19 @@ RecordedRequestsWriter::EndWrite() { if (m_WriterThread) { - m_IsActive = false; - const int64_t PendingCount = m_PendingRequests.fetch_add(1); - m_PendingRequests.notify_all(); - - if (PendingCount) { - ZEN_INFO("waiting for RPC recorder writing thread to drain {} pending items", PendingCount); + std::lock_guard<std::mutex> Lock(m_QueueMutex); + m_IsActive = false; + const size_t PendingCount = m_RequestQueue.size(); + + if (PendingCount) + { + ZEN_INFO("waiting for RPC recorder writing thread to drain {} pending items", PendingCount); + } } + m_QueueCondition.notify_all(); + if (m_WriterThread->joinable()) { m_WriterThread->join(); @@ -695,12 +701,11 @@ RecordedRequestsWriter::WriteRequest(const RecordedRequestInfo& RequestInfo, con OwnedRequest.MakeOwned(); { - 
RwLock::ExclusiveLockScope _(m_RequestQueueLock); + std::lock_guard<std::mutex> Lock(m_QueueMutex); m_RequestQueue.push_back(QueuedRequest{RequestInfo, std::move(OwnedRequest)}); - m_PendingRequests.fetch_add(1); } - m_PendingRequests.notify_all(); + m_QueueCondition.notify_one(); } } @@ -710,55 +715,36 @@ RecordedRequestsWriter::WriterThreadMain() SetCurrentThreadName("rpc_writer"); EnsureCurrentSegment(); - m_IsWriterReady.store(true); - m_IsWriterReady.notify_all(); + { + std::lock_guard<std::mutex> Lock(m_QueueMutex); + m_IsWriterReady = true; + } + m_QueueCondition.notify_all(); - while (m_IsActive) + while (true) { - m_PendingRequests.wait(0); + QueuedRequest Request; - while (m_PendingRequests) { - RwLock::ExclusiveLockScope _(m_RequestQueueLock); - if (!m_RequestQueue.empty()) - { - bool DrainBacklog = false; - - do - { - QueuedRequest Request = m_RequestQueue.front(); - - m_RequestQueue.pop_front(); - m_PendingRequests.fetch_sub(1); - - // For a sufficiently large backlog, keep blocking queueing operations - // until we get below the threshold - DrainBacklog = m_RequestQueue.size() >= MaximumBacklogCount; - - if (!DrainBacklog) - { - _.ReleaseNow(); - } - - try - { - RecordedRequestsSegmentWriter& Writer = EnsureCurrentSegment(); - Writer.WriteRequest(Request.RequestInfo, Request.RequestBuffer); - } - catch (const std::exception&) - { - // TODO: what's the right behaviour here? 
The most likely cause would - // be some I/O error and we probably ought to just shut down recording - // at that point - } - } while (DrainBacklog); - } - else + std::unique_lock<std::mutex> Lock(m_QueueMutex); + m_QueueCondition.wait(Lock, [this] { return !m_RequestQueue.empty() || !m_IsActive; }); + + if (m_RequestQueue.empty()) { - // shutdown increments this counter so we need to decrement it - // here even though we didn't process any request - m_PendingRequests.fetch_sub(1); + break; } + + Request = std::move(m_RequestQueue.front()); + m_RequestQueue.pop_front(); + } + + try + { + RecordedRequestsSegmentWriter& Writer = EnsureCurrentSegment(); + Writer.WriteRequest(Request.RequestInfo, Request.RequestBuffer); + } + catch (const std::exception&) + { } } diff --git a/src/zenutil/xmake.lua b/src/zenutil/xmake.lua index 1d5be5977..1e19f7b2f 100644 --- a/src/zenutil/xmake.lua +++ b/src/zenutil/xmake.lua @@ -9,6 +9,7 @@ target('zenutil') add_deps("zencore", "zenhttp") add_deps("cxxopts") add_deps("robin-map") + add_packages("json11") if is_plat("linux") then add_includedirs("$(projectdir)/thirdparty/systemd/include") diff --git a/src/zenutil/zenserverprocess.cpp b/src/zenutil/zenserverprocess.cpp index ac614f779..e0d99c981 100644 --- a/src/zenutil/zenserverprocess.cpp +++ b/src/zenutil/zenserverprocess.cpp @@ -368,14 +368,16 @@ ZenServerState::Sweep() { if (ErrorCode) { - ZEN_WARN("Sweep - can not determine running state for pid {}, skipping entry (port {}). Reason: '{}'", - Entry.Pid.load(), - Entry.DesiredListenPort.load(), - ErrorCode.message()); + ZEN_CONSOLE_WARN("Sweep - can not determine running state for pid {}, skipping entry (port {}). 
Reason: '{}'", + Entry.Pid.load(), + Entry.DesiredListenPort.load(), + ErrorCode.message()); } else { - ZEN_DEBUG("Sweep - pid {} not running, reclaiming entry (port {})", Entry.Pid.load(), Entry.DesiredListenPort.load()); + ZEN_CONSOLE_DEBUG("Sweep - pid {} not running, reclaiming entry (port {})", + Entry.Pid.load(), + Entry.DesiredListenPort.load()); Entry.Reset(); } } @@ -402,10 +404,10 @@ ZenServerState::Snapshot(std::function<void(const ZenServerEntry&)>&& Callback) { if (ErrorCode) { - ZEN_WARN("Snapshot - can not determine running state for pid {}, skipping entry (port {}). Reason: '{}'", - Entry.Pid.load(), - Entry.DesiredListenPort.load(), - ErrorCode.message()); + ZEN_CONSOLE_WARN("Snapshot - can not determine running state for pid {}, skipping entry (port {}). Reason: '{}'", + Entry.Pid.load(), + Entry.DesiredListenPort.load(), + ErrorCode.message()); } else { diff --git a/src/zenutil/zenutil.cpp b/src/zenutil/zenutil.cpp index 291dbeadd..734813b69 100644 --- a/src/zenutil/zenutil.cpp +++ b/src/zenutil/zenutil.cpp @@ -4,6 +4,9 @@ #if ZEN_WITH_TESTS +# include <zenutil/cloud/imdscredentials.h> +# include <zenutil/cloud/s3client.h> +# include <zenutil/cloud/sigv4.h> # include <zenutil/rpcrecording.h> # include <zenutil/config/commandlineoptions.h> # include <zenutil/wildcard.h> @@ -15,6 +18,9 @@ zenutil_forcelinktests() { cache::rpcrecord_forcelink(); commandlineoptions_forcelink(); + imdscredentials_forcelink(); + s3client_forcelink(); + sigv4_forcelink(); wildcard_forcelink(); } |