// Copyright Epic Games, Inc. All Rights Reserved. #include "rpcreplay_cmd.h" #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include ZEN_THIRD_PARTY_INCLUDES_START #include #include ZEN_THIRD_PARTY_INCLUDES_END #include namespace zen { using namespace std::literals; RpcStartRecordingCommand::RpcStartRecordingCommand() { m_Options.add_options()("h,help", "Print help"); m_Options.add_option("", "u", "hosturl", kHostUrlHelp, cxxopts::value(m_HostName)->default_value(""), ""); m_Options.add_option("", "p", "path", "Recording file path", cxxopts::value(m_RecordingPath), ""); m_Options.parse_positional("path"); } RpcStartRecordingCommand::~RpcStartRecordingCommand() = default; void RpcStartRecordingCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) { ZEN_UNUSED(GlobalOptions, argc, argv); if (!ParseOptions(argc, argv)) { return; } m_HostName = ResolveTargetHostSpec(m_HostName); if (m_HostName.empty()) { throw OptionParseException("Unable to resolve server specification", m_Options.help()); } if (m_RecordingPath.empty()) { throw OptionParseException("'--path' is required", m_Options.help()); } HttpClient Http = CreateHttpClient(m_HostName); if (HttpClient::Response Response = Http.Post("/z$/exec$/start-recording"sv, HttpClient::KeyValueMap{}, HttpClient::KeyValueMap({{"path", m_RecordingPath}}))) { ZEN_CONSOLE("{}", Response.ToText()); } else { Response.ThrowError("Failed to start recording"); } } //////////////////////////////////////////////////// RpcStopRecordingCommand::RpcStopRecordingCommand() { m_Options.add_options()("h,help", "Print help"); m_Options.add_option("", "u", "hosturl", kHostUrlHelp, cxxopts::value(m_HostName)->default_value(""), ""); } RpcStopRecordingCommand::~RpcStopRecordingCommand() = default; void RpcStopRecordingCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) { ZEN_UNUSED(GlobalOptions, argc, argv); if (!ParseOptions(argc, argv)) { return; } m_HostName = ResolveTargetHostSpec(m_HostName); if (m_HostName.empty()) { throw OptionParseException("Unable to resolve server specification", m_Options.help()); } HttpClient Http = CreateHttpClient(m_HostName); if (HttpClient::Response Response = Http.Post("/z$/exec$/stop-recording"sv)) { ZEN_CONSOLE("{}", Response.ToText()); } else { Response.ThrowError("Failed to stop recording"); } } //////////////////////////////////////////////////// RpcReplayCommand::RpcReplayCommand() { m_Options.add_options()("h,help", "Print help"); m_Options.add_option("", "u", "hosturl", kHostUrlHelp, cxxopts::value(m_HostName)->default_value(""), ""); m_Options.add_option("", "p", "path", "Recording file path", cxxopts::value(m_RecordingPath), ""); m_Options.add_option("", "", "dry", "Do a dry run", cxxopts::value(m_DryRun), ""); m_Options.add_option("", "w", "numthreads", "Number of worker threads per process", cxxopts::value(m_ThreadCount)->default_value(fmt::format("{}", GetHardwareConcurrency())), ""); m_Options.add_option("", "", "onhost", "Replay on host, bypassing http/network layer", cxxopts::value(m_OnHost), ""); m_Options.add_option("", "", "showmethodstats", "Show statistics of which RPC methods are used", cxxopts::value(m_ShowMethodStats), ""); m_Options.add_option("", "", "offset", "Offset into request recording to start replay", cxxopts::value(m_Offset)->default_value("0"), ""); m_Options.add_option("", "", "stride", "Stride for request recording when replaying requests", cxxopts::value(m_Stride)->default_value("1"), ""); m_Options.add_option("", "", "numproc", "Number of worker processes", cxxopts::value(m_ProcessCount)->default_value("1"), ""); m_Options.add_option("", "", "forceallowlocalrefs", "Force enable local refs in requests", cxxopts::value(m_ForceAllowLocalRefs), ""); m_Options .add_option("", "", "disablelocalrefs", "Force disable local refs in requests", cxxopts::value(m_DisableLocalRefs), ""); m_Options.add_option("", "", "forceallowlocalhandlerefs", "Force enable local refs as handles in requests", cxxopts::value(m_ForceAllowLocalHandleRef), ""); m_Options.add_option("", "", "disablelocalhandlerefs", "Force disable local refs as handles in requests", cxxopts::value(m_DisableLocalHandleRefs), ""); m_Options.add_option("", "", "forceallowpartiallocalrefs", "Force enable local refs for all sizes", cxxopts::value(m_ForceAllowPartialLocalRefs), ""); m_Options.add_option("", "", "disablepartiallocalrefs", "Force disable local refs for all sizes", cxxopts::value(m_DisablePartialLocalRefs), ""); m_Options.parse_positional("path"); } RpcReplayCommand::~RpcReplayCommand() = default; void RpcReplayCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) { ZEN_UNUSED(GlobalOptions, argc, argv); if (!ParseOptions(argc, argv)) { return; } m_HostName = ResolveTargetHostSpec(m_HostName); if (m_HostName.empty()) { throw OptionParseException("Unable to resolve server specification", m_Options.help()); } if (m_RecordingPath.empty()) { throw OptionParseException("'--path' is required", m_Options.help()); } if (!IsDir(m_RecordingPath)) { throw std::runtime_error(fmt::format("could not find recording at '{}'", m_RecordingPath)); } m_ThreadCount = Max(m_ThreadCount, 1); ZEN_CONSOLE("Replay '{}' (start offset {}, stride {}) to '{}', {} threads", m_RecordingPath, m_Offset, m_Stride, m_HostName, m_ThreadCount); Stopwatch TotalTimer; if (m_OnHost) { HttpClient Http = CreateHttpClient(m_HostName); if (HttpClient::Response Response = Http.Post("/z$/exec$/replay-recording"sv, HttpClient::KeyValueMap{}, HttpClient::KeyValueMap({{"path", m_RecordingPath}, {"thread-count", fmt::format("{}", m_ThreadCount)}}))) { ZEN_CONSOLE("{}", Response.ToText()); return; } else { Response.ThrowError("Failed to start replay"); } } std::unique_ptr Replayer = cache::MakeDiskRequestReplayer(m_RecordingPath, true); uint64_t EntryCount = Replayer->GetRequestCount(); std::atomic_uint64_t EntryOffset = m_Offset; std::atomic_uint64_t BytesSent = 0; std::atomic_uint64_t BytesReceived = 0; Stopwatch Timer; if (m_ProcessCount > 1) { std::vector> WorkerProcesses; WorkerProcesses.resize(m_ProcessCount); ProcessMonitor Monitor; for (int ProcessIndex = 0; ProcessIndex < m_ProcessCount; ++ProcessIndex) { std::string CommandLine = fmt::format("{} rpc-record-replay --hosturl {} --path \"{}\" --offset {} --stride {} --numthreads {} --numproc {}"sv, argv[0], m_HostName, m_RecordingPath, m_Stride == 1 ? 0 : m_Offset + ProcessIndex, m_Stride, m_ThreadCount, 1); CreateProcResult Result(CreateProc(std::filesystem::path(std::string(argv[0])), CommandLine)); WorkerProcesses[ProcessIndex] = std::make_unique(); WorkerProcesses[ProcessIndex]->Initialize(Result); Monitor.AddPid(WorkerProcesses[ProcessIndex]->Pid()); } while (Monitor.IsRunning()) { ZEN_CONSOLE("Waiting for worker processes..."); Sleep(1000); } return; } else { std::map MethodTypes; RwLock MethodTypesLock; WorkerThreadPool WorkerPool(m_ThreadCount); Latch WorkLatch(m_ThreadCount); for (int WorkerIndex = 0; WorkerIndex < m_ThreadCount; ++WorkerIndex) { WorkerPool.ScheduleWork( [this, &WorkLatch, EntryCount, &EntryOffset, &Replayer, &BytesSent, &BytesReceived, &MethodTypes, &MethodTypesLock]() { auto _ = MakeGuard([&WorkLatch]() { WorkLatch.CountDown(); }); std::map LocalMethodTypes; auto ReduceTypes = MakeGuard([&] { RwLock::ExclusiveLockScope __(MethodTypesLock); for (auto& Entry : LocalMethodTypes) { MethodTypes[Entry.first] += Entry.second; } }); HttpClient Http = CreateHttpClient(m_HostName); uint64_t EntryIndex = EntryOffset.fetch_add(m_Stride); while (EntryIndex < EntryCount) { IoBuffer Payload; const zen::cache::RecordedRequestInfo RequestInfo = Replayer->GetRequest(EntryIndex, /* out */ Payload); if (RequestInfo != zen::cache::RecordedRequestInfo::NullRequest) { CbPackage RequestPackage; CbObject Request; switch (RequestInfo.ContentType) { case ZenContentType::kCbPackage: { if (ParsePackageMessageWithLegacyFallback(Payload, RequestPackage)) { Request = RequestPackage.GetObject(); } } break; case ZenContentType::kCbObject: { Request = LoadCompactBinaryObject(Payload); } break; } RpcAcceptOptions OriginalAcceptOptions = static_cast(Request["AcceptFlags"sv].AsUInt16(0u)); int OriginalProcessPid = Request["Pid"sv].AsInt32(0); int AdjustedPid = 0; RpcAcceptOptions AdjustedAcceptOptions = RpcAcceptOptions::kNone; if (!m_DisableLocalRefs) { if (EnumHasAnyFlags(OriginalAcceptOptions, RpcAcceptOptions::kAllowLocalReferences) || m_ForceAllowLocalRefs) { AdjustedAcceptOptions |= RpcAcceptOptions::kAllowLocalReferences; if (!m_DisablePartialLocalRefs) { if (EnumHasAnyFlags(OriginalAcceptOptions, RpcAcceptOptions::kAllowPartialLocalReferences) || m_ForceAllowPartialLocalRefs) { AdjustedAcceptOptions |= RpcAcceptOptions::kAllowPartialLocalReferences; } } if (!m_DisableLocalHandleRefs) { if (OriginalProcessPid != 0 || m_ForceAllowLocalHandleRef) { AdjustedPid = GetCurrentProcessId(); } } } } if (m_ShowMethodStats) { std::string MethodName = std::string(Request["Method"sv].AsString()); if (auto It = LocalMethodTypes.find(MethodName); It != LocalMethodTypes.end()) { It->second++; } else { LocalMethodTypes[MethodName] = 1; } } if (OriginalAcceptOptions != AdjustedAcceptOptions || OriginalProcessPid != AdjustedPid) { CbObjectWriter RequestCopyWriter; for (const CbFieldView& Field : Request) { if (!Field.HasName()) { RequestCopyWriter.AddField(Field); continue; } std::string_view FieldName = Field.GetName(); if (FieldName == "Pid"sv) { continue; } if (FieldName == "AcceptFlags"sv) { continue; } RequestCopyWriter.AddField(FieldName, Field); } if (AdjustedPid != 0) { RequestCopyWriter.AddInteger("Pid"sv, AdjustedPid); } if (AdjustedAcceptOptions != RpcAcceptOptions::kNone) { RequestCopyWriter.AddInteger("AcceptFlags"sv, static_cast(AdjustedAcceptOptions)); } if (RequestInfo.ContentType == ZenContentType::kCbPackage) { RequestPackage.SetObject(RequestCopyWriter.Save()); std::vector Buffers = FormatPackageMessage(RequestPackage); std::vector SharedBuffers(Buffers.begin(), Buffers.end()); Payload = CompositeBuffer(std::move(SharedBuffers)).Flatten().AsIoBuffer(); } else { RequestCopyWriter.Finalize(); Payload = IoBuffer(RequestCopyWriter.GetSaveSize()); RequestCopyWriter.Save(Payload.GetMutableView()); } } if (!m_DryRun) { Http.SetSessionId(RequestInfo.SessionId); Payload.SetContentType(RequestInfo.ContentType); HttpClient::Response Response = Http.Post("/z$/$rpc", Payload, {HttpClient::Accept(RequestInfo.AcceptType)}); BytesSent.fetch_add(Payload.GetSize()); if (!Response) { ZEN_CONSOLE_ERROR("{}", Response); break; } BytesReceived.fetch_add(Response.DownloadedBytes); } } EntryIndex = EntryOffset.fetch_add(m_Stride); } }, WorkerThreadPool::EMode::EnableBacklog); } while (!WorkLatch.Wait(1000)) { const uint64_t RequestsTotal = (EntryCount - m_Offset) / m_Stride; const uint64_t RequestsRemaining = (EntryCount - EntryOffset.load()) / m_Stride; ZEN_CONSOLE("[{:3}%] [{}] {} requests, {} remaining (sent {}, received {})", (RequestsTotal - RequestsRemaining) * 100 / RequestsTotal, NiceTimeSpanMs(Timer.GetElapsedTimeMs()), RequestsTotal, RequestsRemaining, NiceBytes(BytesSent.load()), NiceBytes(BytesReceived.load())); } if (m_ShowMethodStats) { for (const auto& It : MethodTypes) { ZEN_CONSOLE("{:18}: {:10}", It.first, It.second); } } } const uint64_t RequestsSent = (EntryOffset.load() - m_Offset) / m_Stride; const uint64_t ElapsedMS = Timer.GetElapsedTimeMs(); const uint64_t Sent = BytesSent.load(); const uint64_t Received = BytesReceived.load(); ZEN_CONSOLE("Processed requests: {} ({}), payloads sent {} ({}), payloads received {} ({}) in {}.\nTotal runtime: {}", RequestsSent, NiceRate(RequestsSent, ElapsedMS, "req"), NiceBytes(Sent), NiceByteRate(Sent, ElapsedMS), NiceBytes(Received), NiceByteRate(Received, ElapsedMS), NiceTimeSpanMs(ElapsedMS), NiceTimeSpanMs(TotalTimer.GetElapsedTimeMs())); } } // namespace zen