aboutsummaryrefslogtreecommitdiff
path: root/src/zencompute/remotehttprunner.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/zencompute/remotehttprunner.cpp')
-rw-r--r--src/zencompute/remotehttprunner.cpp457
1 files changed, 457 insertions, 0 deletions
diff --git a/src/zencompute/remotehttprunner.cpp b/src/zencompute/remotehttprunner.cpp
new file mode 100644
index 000000000..98ced5fe8
--- /dev/null
+++ b/src/zencompute/remotehttprunner.cpp
@@ -0,0 +1,457 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "remotehttprunner.h"
+
+#if ZEN_WITH_COMPUTE_SERVICES
+
+# include <zencore/compactbinary.h>
+# include <zencore/compactbinarybuilder.h>
+# include <zencore/compactbinarypackage.h>
+# include <zencore/compress.h>
+# include <zencore/except.h>
+# include <zencore/filesystem.h>
+# include <zencore/fmtutils.h>
+# include <zencore/iobuffer.h>
+# include <zencore/iohash.h>
+# include <zencore/scopeguard.h>
+# include <zenhttp/httpcommon.h>
+# include <zenstore/cidstore.h>
+
+# include <span>
+
+//////////////////////////////////////////////////////////////////////////
+
+namespace zen::compute {
+
+using namespace std::literals;
+
+//////////////////////////////////////////////////////////////////////////
+
+// Creates a runner which submits work to a remote compute service over HTTP.
+// All requests are issued relative to "<HostName>/apply". Construction also
+// starts the monitor thread (see MonitorThreadFunction), which keeps running
+// until Shutdown() is called.
+//
+// NOTE(review): m_Http is initialized from m_BaseUrl, so m_BaseUrl presumably
+// is declared before m_Http in the class definition - confirm in the header.
+RemoteHttpRunner::RemoteHttpRunner(ChunkResolver& InChunkResolver, const std::filesystem::path& BaseDir, std::string_view HostName)
+: FunctionRunner(BaseDir)
+, m_Log(logging::Get("http_exec"))
+, m_ChunkResolver(InChunkResolver)
+, m_BaseUrl(fmt::format("{}/apply", HostName))
+, m_Http(m_BaseUrl)
+{
+    // Start the monitor thread last, once all members above are initialized
+    m_MonitorThread = std::thread(&RemoteHttpRunner::MonitorThreadFunction, this);
+}
+
+// Stops and joins the monitor thread (via Shutdown) before the members it
+// uses are destroyed.
+RemoteHttpRunner::~RemoteHttpRunner()
+{
+ Shutdown();
+}
+
+// Asks the monitor thread to exit and waits for it to finish. Calling this
+// more than once is harmless: once the thread has been joined it is no longer
+// joinable and the join is skipped.
+void
+RemoteHttpRunner::Shutdown()
+{
+    // TODO: should cleanly drain/cancel pending work
+
+    // Clear the run flag first, then wake the thread so it notices the change
+    m_MonitorThreadEnabled = false;
+    m_MonitorThreadEvent.Set();
+
+    if (!m_MonitorThread.joinable())
+    {
+        return;
+    }
+
+    m_MonitorThread.join();
+}
+
+// Makes sure the remote knows about the worker described by `WorkerPackage`.
+//
+// Sequence as implemented here:
+//   1. GET  /workers/{id}           - OK means the worker is already known
+//   2. POST /workers/{id} (object)  - a NotFound reply carries a "need" list
+//                                     of attachment hashes
+//   3. POST /workers/{id} (package) - worker object plus only the attachments
+//                                     listed in "need"
+//
+// Failures are logged but not reported to the caller (see TODOs below).
+void
+RemoteHttpRunner::RegisterWorker(const CbPackage& WorkerPackage)
+{
+    const IoHash      WorkerId  = WorkerPackage.GetObjectHash();
+    const std::string WorkerUrl = fmt::format("/workers/{}", WorkerId);
+
+    HttpClient::Response WorkerResponse = m_Http.Get(WorkerUrl);
+
+    if (WorkerResponse.StatusCode == HttpResponseCode::NotFound)
+    {
+        // Worker is unknown to the remote: offer the descriptor object first
+
+        HttpClient::Response DescResponse = m_Http.Post(WorkerUrl, WorkerPackage.GetObject());
+
+        if (DescResponse.StatusCode == HttpResponseCode::NotFound)
+        {
+            // Build response package by sending only the attachments
+            // the other end needs. We start with the full package and
+            // remove the attachments which are not needed.
+
+            CbPackage Pkg = WorkerPackage;
+
+            {
+                std::unordered_set<IoHash> Needed;
+
+                CbObject Response = DescResponse.AsObject();
+
+                for (auto& Item : Response["need"sv])
+                {
+                    Needed.insert(Item.AsHash());
+                }
+
+                // Collect first, remove afterwards, so the attachment list is
+                // not modified while it is being iterated
+
+                std::unordered_set<IoHash> ToRemove;
+
+                for (const CbAttachment& Attachment : Pkg.GetAttachments())
+                {
+                    const IoHash& Hash = Attachment.GetHash();
+
+                    if (Needed.find(Hash) == Needed.end())
+                    {
+                        ToRemove.insert(Hash);
+                    }
+                }
+
+                for (const IoHash& Hash : ToRemove)
+                {
+                    int RemovedCount = Pkg.RemoveAttachment(Hash);
+
+                    ZEN_ASSERT(RemovedCount == 1);
+                }
+            }
+
+            // Post resulting package
+
+            HttpClient::Response PayloadResponse = m_Http.Post(WorkerUrl, Pkg);
+
+            if (!IsHttpSuccessCode(PayloadResponse.StatusCode))
+            {
+                ZEN_ERROR("ERROR: unable to register payloads for worker {} at {}{}", WorkerId, m_Http.GetBaseUri(), WorkerUrl);
+
+                // TODO: propagate error
+            }
+        }
+        else if (!IsHttpSuccessCode(DescResponse.StatusCode))
+        {
+            ZEN_ERROR("ERROR: unable to register worker {} at {}{}", WorkerId, m_Http.GetBaseUri(), WorkerUrl);
+
+            // TODO: propagate error
+        }
+        else if (DescResponse.StatusCode != HttpResponseCode::NoContent)
+        {
+            // The status code is controlled by the remote, so a success code
+            // other than NoContent is reported rather than asserted on
+
+            ZEN_WARN("unexpected response registering worker {} at {}{} (status: {} {})",
+                     WorkerId,
+                     m_Http.GetBaseUri(),
+                     WorkerUrl,
+                     static_cast<int>(DescResponse.StatusCode),
+                     ToString(DescResponse.StatusCode));
+        }
+    }
+    else if (WorkerResponse.StatusCode == HttpResponseCode::OK)
+    {
+        // Already known from a previous run
+    }
+    else if (!IsHttpSuccessCode(WorkerResponse.StatusCode))
+    {
+        ZEN_ERROR("ERROR: unable to look up worker {} at {}{} (error: {} {})",
+                  WorkerId,
+                  m_Http.GetBaseUri(),
+                  WorkerUrl,
+                  (int)WorkerResponse.StatusCode,
+                  ToString(WorkerResponse.StatusCode));
+
+        // TODO: propagate error
+    }
+}
+
+// Estimates how many more actions this runner is ready to accept: the
+// configured maximum minus the number of actions currently tracked as running
+// on the remote. Returns 0 when the runner is at or above the maximum.
+size_t
+RemoteHttpRunner::QueryCapacity()
+{
+    RwLock::SharedLockScope _{m_RunningLock};
+
+    // Read the limit exactly once. SweepRunningActions lowers it without
+    // holding m_RunningLock, so reading it separately for the comparison and
+    // for the subtraction could observe two different values and make the
+    // unsigned subtraction wrap around.
+    const size_t MaxRunning   = size_t(m_MaxRunningActions);
+    const size_t RunningCount = m_RemoteRunningMap.size();
+
+    if (RunningCount >= MaxRunning)
+    {
+        return 0;
+    }
+
+    return MaxRunning - RunningCount;
+}
+
+// Submits each action in turn via SubmitAction and returns one result per
+// action, in the same order as `Actions`.
+std::vector<SubmitResult>
+RemoteHttpRunner::SubmitActions(const std::vector<Ref<RunnerAction>>& Actions)
+{
+    std::vector<SubmitResult> Results;
+
+    // The number of results is known up front, so allocate once
+    Results.reserve(Actions.size());
+
+    for (const Ref<RunnerAction>& Action : Actions)
+    {
+        Results.push_back(SubmitAction(Action));
+    }
+
+    return Results;
+}
+
+// Submits a single action to the remote and, on success, starts tracking it
+// under the LSN assigned by the remote.
+//
+// Sequence as implemented here:
+//   1. POST /jobs (action object). OK means the job was accepted.
+//   2. On NotFound the response carries a "need" list of attachment hashes.
+//      Those are resolved through the chunk resolver and the action is
+//      re-posted as a package including them.
+//   3. The "lsn" field of the accepting response is used as the key in
+//      m_RemoteRunningMap and the action is marked as Running.
+//
+// Returns a result with IsAccepted == true only if the remote accepted the
+// job and returned a non-zero LSN. Rejections from the HTTP exchange carry a
+// Reason describing what went wrong; the capacity rejection carries none.
+SubmitResult
+RemoteHttpRunner::SubmitAction(Ref<RunnerAction> Action)
+{
+    // Verify whether we can accept more work
+
+    {
+        RwLock::SharedLockScope _{m_RunningLock};
+        if (m_RemoteRunningMap.size() >= size_t(m_MaxRunningActions))
+        {
+            return SubmitResult{.IsAccepted = false};
+        }
+    }
+
+    // Each enqueued action is assigned an integer index (logical sequence number),
+    // which we use as a key for tracking data structures and as an opaque id which
+    // may be used by clients to reference the scheduled action
+
+    const int32_t   ActionLsn = Action->ActionLsn;
+    const CbObject& ActionObj = Action->ActionObj;
+    const IoHash    ActionId  = ActionObj.GetHash();
+
+    MaybeDumpAction(ActionLsn, ActionObj);
+
+    // Enqueue job
+
+    CbObject Result;
+
+    HttpClient::Response   WorkResponse     = m_Http.Post("/jobs", ActionObj);
+    const HttpResponseCode WorkResponseCode = WorkResponse.StatusCode;
+
+    if (WorkResponseCode == HttpResponseCode::OK)
+    {
+        Result = WorkResponse.AsObject();
+    }
+    else if (WorkResponseCode == HttpResponseCode::NotFound)
+    {
+        // Not all attachments are present
+
+        // Build response package including all required attachments
+
+        CbPackage Pkg;
+        Pkg.SetObject(ActionObj);
+
+        CbObject Response = WorkResponse.AsObject();
+
+        for (auto& Item : Response["need"sv])
+        {
+            const IoHash NeedHash = Item.AsHash();
+
+            if (IoBuffer Chunk = m_ChunkResolver.FindChunkByCid(NeedHash))
+            {
+                uint64_t         DataRawSize = 0;
+                IoHash           DataRawHash;
+                CompressedBuffer Compressed =
+                    CompressedBuffer::FromCompressed(SharedBuffer{Chunk}, /* out */ DataRawHash, /* out */ DataRawSize);
+
+                ZEN_ASSERT(DataRawHash == NeedHash);
+
+                Pkg.AddAttachment(CbAttachment(Compressed, NeedHash));
+            }
+            else
+            {
+                // No such attachment
+
+                return {.IsAccepted = false, .Reason = fmt::format("missing attachment {}", NeedHash)};
+            }
+        }
+
+        // Post resulting package
+
+        HttpClient::Response PayloadResponse = m_Http.Post("/jobs", Pkg);
+
+        if (!PayloadResponse)
+        {
+            ZEN_WARN("unable to register payloads for action {} at {}/jobs", ActionId, m_Http.GetBaseUri());
+
+            // TODO: include more information about the failure in the response
+
+            return {.IsAccepted = false, .Reason = "HTTP request failed"};
+        }
+        else if (PayloadResponse.StatusCode == HttpResponseCode::OK)
+        {
+            Result = PayloadResponse.AsObject();
+        }
+        else
+        {
+            // Unexpected response
+
+            const HttpResponseCode PayloadResponseCode = PayloadResponse.StatusCode;
+            const int              ResponseStatusCode  = static_cast<int>(PayloadResponseCode);
+
+            // The reason text is derived from the status code itself (not the
+            // integer), matching how RegisterWorker reports failures
+
+            ZEN_WARN("unable to register payloads for action {} at {}/jobs (error: {} {})",
+                     ActionId,
+                     m_Http.GetBaseUri(),
+                     ResponseStatusCode,
+                     ToString(PayloadResponseCode));
+
+            return {.IsAccepted = false,
+                    .Reason     = fmt::format("unexpected response code {} {} from {}/jobs",
+                                          ResponseStatusCode,
+                                          ToString(PayloadResponseCode),
+                                          m_Http.GetBaseUri())};
+        }
+    }
+    else
+    {
+        // Neither accepted nor a request for attachments: reject with a
+        // reason instead of silently returning an empty result
+
+        const int ResponseStatusCode = static_cast<int>(WorkResponseCode);
+
+        ZEN_WARN("unable to submit action {} at {}/jobs (error: {} {})",
+                 ActionId,
+                 m_Http.GetBaseUri(),
+                 ResponseStatusCode,
+                 ToString(WorkResponseCode));
+
+        return {.IsAccepted = false,
+                .Reason     = fmt::format("unexpected response code {} {} from {}/jobs",
+                                      ResponseStatusCode,
+                                      ToString(WorkResponseCode),
+                                      m_Http.GetBaseUri())};
+    }
+
+    // A zero or missing LSN is treated as "not scheduled"
+
+    const int32_t LsnField = Result ? Result["lsn"].AsInt32(0) : 0;
+
+    if (LsnField == 0)
+    {
+        ZEN_WARN("no LSN in response for action {} at {}/jobs", ActionId, m_Http.GetBaseUri());
+
+        return {.IsAccepted = false, .Reason = fmt::format("no LSN in response from {}/jobs", m_Http.GetBaseUri())};
+    }
+
+    HttpRunningAction NewAction;
+    NewAction.Action          = Action;
+    NewAction.RemoteActionLsn = LsnField;
+
+    {
+        RwLock::ExclusiveLockScope _(m_RunningLock);
+
+        m_RemoteRunningMap[LsnField] = std::move(NewAction);
+    }
+
+    ZEN_DEBUG("scheduled action {} with remote LSN {} (local LSN {})", ActionId, LsnField, ActionLsn);
+
+    // NOTE(review): the action is visible to the monitor thread as soon as it
+    // is in the map, so a sweep could presumably mark it Completed before the
+    // Running state is set here - confirm how RunnerAction handles that order.
+    Action->SetActionState(RunnerAction::State::Running);
+
+    return SubmitResult{.IsAccepted = true};
+}
+
+// Probes the remote's "/ready" endpoint. Returns true when the response
+// evaluates as successful, false otherwise.
+bool
+RemoteHttpRunner::IsHealthy()
+{
+    HttpClient::Response Ready = m_Http.Get("/ready");
+
+    if (!Ready)
+    {
+        // TODO: use response to propagate context
+        return false;
+    }
+
+    return true;
+}
+
+// Returns the number of actions currently tracked as running on the remote.
+size_t
+RemoteHttpRunner::GetSubmittedActionCount()
+{
+    size_t TrackedCount = 0;
+
+    // The map is shared with the monitor thread, so read it under the lock
+    m_RunningLock.WithSharedLock([&] { TrackedCount = m_RemoteRunningMap.size(); });
+
+    return TrackedCount;
+}
+
+// Body of the monitor thread. Repeatedly waits on the monitor event and sweeps
+// the remote for completed actions after every wait, whether the wait timed
+// out or the event was signalled. The loop ends once the event has been
+// signalled and m_MonitorThreadEnabled is false.
+//
+// The wait interval adapts to load: a quarter of the base interval while more
+// than 16 actions are tracked, half of it if the last sweep retired anything,
+// otherwise the full base interval.
+void
+RemoteHttpRunner::MonitorThreadFunction()
+{
+    SetCurrentThreadName("RemoteHttpRunner_Monitor");
+
+    do
+    {
+        const int BaseIntervalMs = 1000;
+        int       IntervalMs     = BaseIntervalMs;
+
+        auto SweepAndRetune = [&] {
+            const size_t Retired = SweepRunningActions();
+
+            m_RunningLock.WithSharedLock([&] {
+                if (m_RemoteRunningMap.size() > 16)
+                {
+                    IntervalMs = BaseIntervalMs / 4;
+                }
+                else
+                {
+                    IntervalMs = Retired ? BaseIntervalMs / 2 : BaseIntervalMs;
+                }
+            });
+        };
+
+        for (;;)
+        {
+            const bool Signalled = m_MonitorThreadEvent.Wait(IntervalMs);
+
+            // Sweep after every wait, including the one that was signalled
+            SweepAndRetune();
+
+            if (Signalled)
+            {
+                // Signal received - this may mean we should quit
+                break;
+            }
+        }
+    } while (m_MonitorThreadEnabled);
+}
+
+// Polls the remote for completed jobs and retires the ones this runner is
+// tracking:
+//   - GET /jobs/completed returns the list of completed remote LSNs (and
+//     optionally a "metrics" object)
+//   - for every listed LSN the job result is fetched with GET /jobs/{lsn};
+//     if the LSN is present in m_RemoteRunningMap it is removed from the map
+//     and the result is handed to the owning action
+//   - if the metrics report a non-zero "lp_count", m_MaxRunningActions is
+//     lowered to max(4, lp_count) when it currently exceeds that value
+//
+// Returns the number of actions retired by this sweep.
+size_t
+RemoteHttpRunner::SweepRunningActions()
+{
+    std::vector<HttpRunningAction> Retired;
+
+    // Poll remote for list of completed actions
+
+    HttpClient::Response CompletedResponse = m_Http.Get("/jobs/completed"sv);
+
+    if (CbObject CompletedObj = CompletedResponse.AsObject())
+    {
+        for (auto& LsnEntry : CompletedObj["completed"sv])
+        {
+            const int32_t RemoteLsn = LsnEntry.AsInt32();
+
+            HttpClient::Response JobResponse = m_Http.Get(fmt::format("/jobs/{}"sv, RemoteLsn));
+
+            if (!JobResponse)
+            {
+                continue;
+            }
+
+            RwLock::ExclusiveLockScope _(m_RunningLock);
+
+            auto TrackedIt = m_RemoteRunningMap.find(RemoteLsn);
+
+            if (TrackedIt == m_RemoteRunningMap.end())
+            {
+                // we received a completion notice for an action we don't know about,
+                // this can happen if the runner is used by multiple upstream schedulers,
+                // or if this compute node was recently restarted and lost track of
+                // previously scheduled actions
+                continue;
+            }
+
+            HttpRunningAction Finished = std::move(TrackedIt->second);
+            Finished.ActionResults     = JobResponse.AsPackage();
+            Finished.Success           = true;
+
+            Retired.push_back(std::move(Finished));
+            m_RemoteRunningMap.erase(TrackedIt);
+        }
+
+        if (CbObjectView Metrics = CompletedObj["metrics"sv].AsObjectView())
+        {
+            const int32_t LpCount = Metrics["lp_count"].AsInt32(0);
+
+            if (LpCount != 0)
+            {
+                const int32_t Ceiling = zen::Max(4, LpCount);
+
+                if (m_MaxRunningActions > Ceiling)
+                {
+                    // Log before assigning so the previous value is reported
+                    ZEN_DEBUG("capping {} to {} actions (was {})", m_BaseUrl, Ceiling, m_MaxRunningActions);
+
+                    m_MaxRunningActions = Ceiling;
+                }
+            }
+        }
+    }
+
+    // Notify outer. Note that this has to be done without holding any local locks
+    // otherwise we may end up with deadlocks.
+
+    for (HttpRunningAction& Entry : Retired)
+    {
+        const int LocalLsn = Entry.Action->ActionLsn;
+
+        if (!Entry.Success)
+        {
+            Entry.Action->SetActionState(RunnerAction::State::Failed);
+            continue;
+        }
+
+        ZEN_DEBUG("completed: {} LSN {} (remote LSN {})", Entry.Action->ActionId, LocalLsn, Entry.RemoteActionLsn);
+
+        Entry.Action->SetActionState(RunnerAction::State::Completed);
+
+        Entry.Action->SetResult(std::move(Entry.ActionResults));
+    }
+
+    return Retired.size();
+}
+
+} // namespace zen::compute
+
+#endif