diff options
Diffstat (limited to 'src/zencompute/remotehttprunner.cpp')
| -rw-r--r-- | src/zencompute/remotehttprunner.cpp | 457 |
1 files changed, 457 insertions, 0 deletions
diff --git a/src/zencompute/remotehttprunner.cpp b/src/zencompute/remotehttprunner.cpp new file mode 100644 index 000000000..98ced5fe8 --- /dev/null +++ b/src/zencompute/remotehttprunner.cpp @@ -0,0 +1,457 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "remotehttprunner.h" + +#if ZEN_WITH_COMPUTE_SERVICES + +# include <zencore/compactbinary.h> +# include <zencore/compactbinarybuilder.h> +# include <zencore/compactbinarypackage.h> +# include <zencore/compress.h> +# include <zencore/except.h> +# include <zencore/filesystem.h> +# include <zencore/fmtutils.h> +# include <zencore/iobuffer.h> +# include <zencore/iohash.h> +# include <zencore/scopeguard.h> +# include <zenhttp/httpcommon.h> +# include <zenstore/cidstore.h> + +# include <span> + +////////////////////////////////////////////////////////////////////////// + +namespace zen::compute { + +using namespace std::literals; + +////////////////////////////////////////////////////////////////////////// + +RemoteHttpRunner::RemoteHttpRunner(ChunkResolver& InChunkResolver, const std::filesystem::path& BaseDir, std::string_view HostName) +: FunctionRunner(BaseDir) +, m_Log(logging::Get("http_exec")) +, m_ChunkResolver{InChunkResolver} +, m_BaseUrl{fmt::format("{}/apply", HostName)} +, m_Http(m_BaseUrl) +{ + m_MonitorThread = std::thread{&RemoteHttpRunner::MonitorThreadFunction, this}; +} + +RemoteHttpRunner::~RemoteHttpRunner() +{ + Shutdown(); +} + +void +RemoteHttpRunner::Shutdown() +{ + // TODO: should cleanly drain/cancel pending work + + m_MonitorThreadEnabled = false; + m_MonitorThreadEvent.Set(); + if (m_MonitorThread.joinable()) + { + m_MonitorThread.join(); + } +} + +void +RemoteHttpRunner::RegisterWorker(const CbPackage& WorkerPackage) +{ + const IoHash WorkerId = WorkerPackage.GetObjectHash(); + CbPackage WorkerDesc = WorkerPackage; + + std::string WorkerUrl = fmt::format("/workers/{}", WorkerId); + + HttpClient::Response WorkerResponse = m_Http.Get(WorkerUrl); + + if (WorkerResponse.StatusCode == HttpResponseCode::NotFound) + { + HttpClient::Response DescResponse = m_Http.Post(WorkerUrl, WorkerDesc.GetObject()); + + if (DescResponse.StatusCode == HttpResponseCode::NotFound) + { + CbPackage Pkg = WorkerDesc; + + // Build response package by sending only the attachments + // the other end needs. We start with the full package and + // remove the attachments which are not needed. + + { + std::unordered_set<IoHash> Needed; + + CbObject Response = DescResponse.AsObject(); + + for (auto& Item : Response["need"sv]) + { + const IoHash NeedHash = Item.AsHash(); + + Needed.insert(NeedHash); + } + + std::unordered_set<IoHash> ToRemove; + + for (const CbAttachment& Attachment : Pkg.GetAttachments()) + { + const IoHash& Hash = Attachment.GetHash(); + + if (Needed.find(Hash) == Needed.end()) + { + ToRemove.insert(Hash); + } + } + + for (const IoHash& Hash : ToRemove) + { + int RemovedCount = Pkg.RemoveAttachment(Hash); + + ZEN_ASSERT(RemovedCount == 1); + } + } + + // Post resulting package + + HttpClient::Response PayloadResponse = m_Http.Post(WorkerUrl, Pkg); + + if (!IsHttpSuccessCode(PayloadResponse.StatusCode)) + { + ZEN_ERROR("ERROR: unable to register payloads for worker {} at {}{}", WorkerId, m_Http.GetBaseUri(), WorkerUrl); + + // TODO: propagate error + } + } + else if (!IsHttpSuccessCode(DescResponse.StatusCode)) + { + ZEN_ERROR("ERROR: unable to register worker {} at {}{}", WorkerId, m_Http.GetBaseUri(), WorkerUrl); + + // TODO: propagate error + } + else + { + ZEN_ASSERT(DescResponse.StatusCode == HttpResponseCode::NoContent); + } + } + else if (WorkerResponse.StatusCode == HttpResponseCode::OK) + { + // Already known from a previous run + } + else if (!IsHttpSuccessCode(WorkerResponse.StatusCode)) + { + ZEN_ERROR("ERROR: unable to look up worker {} at {}{} (error: {} {})", + WorkerId, + m_Http.GetBaseUri(), + WorkerUrl, + (int)WorkerResponse.StatusCode, + ToString(WorkerResponse.StatusCode)); + + // TODO: propagate error + } +} + +size_t +RemoteHttpRunner::QueryCapacity() +{ + // Estimate how much more work we're ready to accept + + RwLock::SharedLockScope _{m_RunningLock}; + + size_t RunningCount = m_RemoteRunningMap.size(); + + if (RunningCount >= size_t(m_MaxRunningActions)) + { + return 0; + } + + return m_MaxRunningActions - RunningCount; +} + +std::vector<SubmitResult> +RemoteHttpRunner::SubmitActions(const std::vector<Ref<RunnerAction>>& Actions) +{ + std::vector<SubmitResult> Results; + + for (const Ref<RunnerAction>& Action : Actions) + { + Results.push_back(SubmitAction(Action)); + } + + return Results; +} + +SubmitResult +RemoteHttpRunner::SubmitAction(Ref<RunnerAction> Action) +{ + // Verify whether we can accept more work + + { + RwLock::SharedLockScope _{m_RunningLock}; + if (m_RemoteRunningMap.size() >= size_t(m_MaxRunningActions)) + { + return SubmitResult{.IsAccepted = false}; + } + } + + using namespace std::literals; + + // Each enqueued action is assigned an integer index (logical sequence number), + // which we use as a key for tracking data structures and as an opaque id which + // may be used by clients to reference the scheduled action + + const int32_t ActionLsn = Action->ActionLsn; + const CbObject& ActionObj = Action->ActionObj; + const IoHash ActionId = ActionObj.GetHash(); + + MaybeDumpAction(ActionLsn, ActionObj); + + // Enqueue job + + CbObject Result; + + HttpClient::Response WorkResponse = m_Http.Post("/jobs", ActionObj); + HttpResponseCode WorkResponseCode = WorkResponse.StatusCode; + + if (WorkResponseCode == HttpResponseCode::OK) + { + Result = WorkResponse.AsObject(); + } + else if (WorkResponseCode == HttpResponseCode::NotFound) + { + // Not all attachments are present + + // Build response package including all required attachments + + CbPackage Pkg; + Pkg.SetObject(ActionObj); + + CbObject Response = WorkResponse.AsObject(); + + for (auto& Item : Response["need"sv]) + { + const IoHash NeedHash = Item.AsHash(); + + if (IoBuffer Chunk = m_ChunkResolver.FindChunkByCid(NeedHash)) + { + uint64_t DataRawSize = 0; + IoHash DataRawHash; + CompressedBuffer Compressed = + CompressedBuffer::FromCompressed(SharedBuffer{Chunk}, /* out */ DataRawHash, /* out */ DataRawSize); + + ZEN_ASSERT(DataRawHash == NeedHash); + + Pkg.AddAttachment(CbAttachment(Compressed, NeedHash)); + } + else + { + // No such attachment + + return {.IsAccepted = false, .Reason = fmt::format("missing attachment {}", NeedHash)}; + } + } + + // Post resulting package + + HttpClient::Response PayloadResponse = m_Http.Post("/jobs", Pkg); + + if (!PayloadResponse) + { + ZEN_WARN("unable to register payloads for action {} at {}/jobs", ActionId, m_Http.GetBaseUri()); + + // TODO: include more information about the failure in the response + + return {.IsAccepted = false, .Reason = "HTTP request failed"}; + } + else if (PayloadResponse.StatusCode == HttpResponseCode::OK) + { + Result = PayloadResponse.AsObject(); + } + else + { + // Unexpected response + + const int ResponseStatusCode = (int)PayloadResponse.StatusCode; + + ZEN_WARN("unable to register payloads for action {} at {}/jobs (error: {} {})", + ActionId, + m_Http.GetBaseUri(), + ResponseStatusCode, + ToString(ResponseStatusCode)); + + return {.IsAccepted = false, + .Reason = fmt::format("unexpected response code {} {} from {}/jobs", + ResponseStatusCode, + ToString(ResponseStatusCode), + m_Http.GetBaseUri())}; + } + } + + if (Result) + { + if (const int32_t LsnField = Result["lsn"].AsInt32(0)) + { + HttpRunningAction NewAction; + NewAction.Action = Action; + NewAction.RemoteActionLsn = LsnField; + + { + RwLock::ExclusiveLockScope _(m_RunningLock); + + m_RemoteRunningMap[LsnField] = std::move(NewAction); + } + + ZEN_DEBUG("scheduled action {} with remote LSN {} (local LSN {})", ActionId, LsnField, ActionLsn); + + Action->SetActionState(RunnerAction::State::Running); + + return SubmitResult{.IsAccepted = true}; + } + } + + return {}; +} + +bool +RemoteHttpRunner::IsHealthy() +{ + if (HttpClient::Response Ready = m_Http.Get("/ready")) + { + return true; + } + else + { + // TODO: use response to propagate context + return false; + } +} + +size_t +RemoteHttpRunner::GetSubmittedActionCount() +{ + RwLock::SharedLockScope _(m_RunningLock); + return m_RemoteRunningMap.size(); +} + +void +RemoteHttpRunner::MonitorThreadFunction() +{ + SetCurrentThreadName("RemoteHttpRunner_Monitor"); + + do + { + const int NormalWaitingTime = 1000; + int WaitTimeMs = NormalWaitingTime; + auto WaitOnce = [&] { return m_MonitorThreadEvent.Wait(WaitTimeMs); }; + auto SweepOnce = [&] { + const size_t RetiredCount = SweepRunningActions(); + + m_RunningLock.WithSharedLock([&] { + if (m_RemoteRunningMap.size() > 16) + { + WaitTimeMs = NormalWaitingTime / 4; + } + else + { + if (RetiredCount) + { + WaitTimeMs = NormalWaitingTime / 2; + } + else + { + WaitTimeMs = NormalWaitingTime; + } + } + }); + }; + + while (!WaitOnce()) + { + SweepOnce(); + } + + // Signal received - this may mean we should quit + + SweepOnce(); + } while (m_MonitorThreadEnabled); +} + +size_t +RemoteHttpRunner::SweepRunningActions() +{ + std::vector<HttpRunningAction> CompletedActions; + + // Poll remote for list of completed actions + + HttpClient::Response ResponseCompleted = m_Http.Get("/jobs/completed"sv); + + if (CbObject Completed = ResponseCompleted.AsObject()) + { + for (auto& FieldIt : Completed["completed"sv]) + { + const int32_t CompleteLsn = FieldIt.AsInt32(); + + if (HttpClient::Response ResponseJob = m_Http.Get(fmt::format("/jobs/{}"sv, CompleteLsn))) + { + m_RunningLock.WithExclusiveLock([&] { + if (auto CompleteIt = m_RemoteRunningMap.find(CompleteLsn); CompleteIt != m_RemoteRunningMap.end()) + { + HttpRunningAction CompletedAction = std::move(CompleteIt->second); + CompletedAction.ActionResults = ResponseJob.AsPackage(); + CompletedAction.Success = true; + + CompletedActions.push_back(std::move(CompletedAction)); + m_RemoteRunningMap.erase(CompleteIt); + } + else + { + // we received a completion notice for an action we don't know about, + // this can happen if the runner is used by multiple upstream schedulers, + // or if this compute node was recently restarted and lost track of + // previously scheduled actions + } + }); + } + } + + if (CbObjectView Metrics = Completed["metrics"sv].AsObjectView()) + { + // if (const size_t CpuCount = Metrics["core_count"].AsInt32(0)) + if (const int32_t CpuCount = Metrics["lp_count"].AsInt32(0)) + { + const int32_t NewCap = zen::Max(4, CpuCount); + + if (m_MaxRunningActions > NewCap) + { + ZEN_DEBUG("capping {} to {} actions (was {})", m_BaseUrl, NewCap, m_MaxRunningActions); + + m_MaxRunningActions = NewCap; + } + } + } + } + + // Notify outer. Note that this has to be done without holding any local locks + // otherwise we may end up with deadlocks. + + for (HttpRunningAction& HttpAction : CompletedActions) + { + const int ActionLsn = HttpAction.Action->ActionLsn; + + if (HttpAction.Success) + { + ZEN_DEBUG("completed: {} LSN {} (remote LSN {})", HttpAction.Action->ActionId, ActionLsn, HttpAction.RemoteActionLsn); + + HttpAction.Action->SetActionState(RunnerAction::State::Completed); + + HttpAction.Action->SetResult(std::move(HttpAction.ActionResults)); + } + else + { + HttpAction.Action->SetActionState(RunnerAction::State::Failed); + } + } + + return CompletedActions.size(); +} + +} // namespace zen::compute + +#endif |