From dec3f27c488a1dda8a2f1133361e2fda9315e0d2 Mon Sep 17 00:00:00 2001 From: Dan Engelbrecht Date: Fri, 11 Apr 2025 19:22:21 +0200 Subject: fix race condition in multipart download (#358) --- src/zen/cmds/builds_cmd.cpp | 26 +++++++-------- src/zenutil/filebuildstorage.cpp | 26 +++++++++------ src/zenutil/include/zenutil/buildstorage.h | 12 +++---- .../include/zenutil/jupiter/jupitersession.h | 15 +++++---- src/zenutil/jupiter/jupiterbuildstorage.cpp | 20 ++++++++---- src/zenutil/jupiter/jupitersession.cpp | 38 ++++++++++++++-------- 6 files changed, 81 insertions(+), 56 deletions(-) (limited to 'src') diff --git a/src/zen/cmds/builds_cmd.cpp b/src/zen/cmds/builds_cmd.cpp index f7f9e3abb..cdcd79f58 100644 --- a/src/zen/cmds/builds_cmd.cpp +++ b/src/zen/cmds/builds_cmd.cpp @@ -1843,25 +1843,25 @@ namespace { BuildId, ChunkHash, PreferredMultipartChunkSize, - [Workload, &DownloadStats, OnDownloadComplete = std::move(OnDownloadComplete)](uint64_t Offset, - const IoBuffer& Chunk, - uint64_t BytesRemaining) { + [Workload, &DownloadStats](uint64_t Offset, const IoBuffer& Chunk) { DownloadStats.DownloadedChunkByteCount += Chunk.GetSize(); if (!AbortFlag.load()) { ZEN_TRACE_CPU("DownloadLargeBlob_Save"); Workload->TempFile.Write(Chunk.GetView(), Offset); - if (Chunk.GetSize() == BytesRemaining) - { - DownloadStats.DownloadedChunkCount++; - uint64_t PayloadSize = Workload->TempFile.FileSize(); - void* FileHandle = Workload->TempFile.Detach(); - ZEN_ASSERT(FileHandle != nullptr); - IoBuffer Payload(IoBuffer::File, FileHandle, 0, PayloadSize, true); - Payload.SetDeleteOnClose(true); - OnDownloadComplete(std::move(Payload)); - } + } + }, + [Workload, &DownloadStats, OnDownloadComplete = std::move(OnDownloadComplete)]() { + DownloadStats.DownloadedChunkCount++; + if (!AbortFlag.load()) + { + uint64_t PayloadSize = Workload->TempFile.FileSize(); + void* FileHandle = Workload->TempFile.Detach(); + ZEN_ASSERT(FileHandle != nullptr); + IoBuffer Payload(IoBuffer::File, FileHandle, 0, PayloadSize, true); + Payload.SetDeleteOnClose(true); + OnDownloadComplete(std::move(Payload)); } }); if (!WorkItems.empty()) diff --git a/src/zenutil/filebuildstorage.cpp b/src/zenutil/filebuildstorage.cpp index 7aa252e44..f335a03a3 100644 --- a/src/zenutil/filebuildstorage.cpp +++ b/src/zenutil/filebuildstorage.cpp @@ -369,11 +369,11 @@ public: return IoBuffer{}; } - virtual std::vector> GetLargeBuildBlob( - const Oid& BuildId, - const IoHash& RawHash, - uint64_t ChunkSize, - std::function&& Receiver) override + virtual std::vector> GetLargeBuildBlob(const Oid& BuildId, + const IoHash& RawHash, + uint64_t ChunkSize, + std::function&& OnReceive, + std::function&& OnComplete) override { ZEN_TRACE_CPU("FileBuildStorage::GetLargeBuildBlob"); ZEN_UNUSED(BuildId); @@ -387,16 +387,18 @@ public: { struct WorkloadData { - std::atomic BytesRemaining; - BasicFile BlobFile; - std::function Receiver; + std::atomic BytesRemaining; + BasicFile BlobFile; + std::function OnReceive; + std::function OnComplete; }; std::shared_ptr Workload(std::make_shared()); Workload->BlobFile.Open(BlockPath, BasicFile::Mode::kRead); const uint64_t BlobSize = Workload->BlobFile.FileSize(); - Workload->Receiver = std::move(Receiver); + Workload->OnReceive = std::move(OnReceive); + Workload->OnComplete = std::move(OnComplete); Workload->BytesRemaining = BlobSize; std::vector> WorkItems; @@ -410,8 +412,12 @@ public: IoBuffer PartPayload(Size); Workload->BlobFile.Read(PartPayload.GetMutableView().GetData(), Size, Offset); m_Stats.TotalBytesRead += PartPayload.GetSize(); + Workload->OnReceive(Offset, PartPayload); uint64_t ByteRemaning = Workload->BytesRemaining.fetch_sub(Size); - Workload->Receiver(Offset, PartPayload, ByteRemaning); + if (ByteRemaning == Size) + { + Workload->OnComplete(); + } SimulateLatency(Size, PartPayload.GetSize()); }); diff --git a/src/zenutil/include/zenutil/buildstorage.h b/src/zenutil/include/zenutil/buildstorage.h index b0665dbf8..05e3ca22d 100644 --- a/src/zenutil/include/zenutil/buildstorage.h +++ b/src/zenutil/include/zenutil/buildstorage.h @@ -47,12 +47,12 @@ public: virtual IoBuffer GetBuildBlob(const Oid& BuildId, const IoHash& RawHash, uint64_t RangeOffset = 0, - uint64_t RangeBytes = (uint64_t)-1) = 0; - virtual std::vector> GetLargeBuildBlob( - const Oid& BuildId, - const IoHash& RawHash, - uint64_t ChunkSize, - std::function&& Receiver) = 0; + uint64_t RangeBytes = (uint64_t)-1) = 0; + virtual std::vector> GetLargeBuildBlob(const Oid& BuildId, + const IoHash& RawHash, + uint64_t ChunkSize, + std::function&& OnReceive, + std::function&& OnComplete) = 0; virtual void PutBlockMetadata(const Oid& BuildId, const IoHash& BlockRawHash, const CbObject& MetaData) = 0; virtual CbObject FindBlocks(const Oid& BuildId, uint64_t MaxBlockCount) = 0; diff --git a/src/zenutil/include/zenutil/jupiter/jupitersession.h b/src/zenutil/include/zenutil/jupiter/jupitersession.h index 417ed7384..c2886ca4c 100644 --- a/src/zenutil/include/zenutil/jupiter/jupitersession.h +++ b/src/zenutil/include/zenutil/jupiter/jupitersession.h @@ -135,13 +135,14 @@ public: uint64_t PayloadSize, std::function&& Transmitter, std::vector>& OutWorkItems); - JupiterResult GetMultipartBuildBlob(std::string_view Namespace, - std::string_view BucketId, - const Oid& BuildId, - const IoHash& Hash, - uint64_t ChunkSize, - std::function&& Receiver, - std::vector>& OutWorkItems); + JupiterResult GetMultipartBuildBlob(std::string_view Namespace, + std::string_view BucketId, + const Oid& BuildId, + const IoHash& Hash, + uint64_t ChunkSize, + std::function&& OnReceive, + std::function&& OnComplete, + std::vector>& OutWorkItems); JupiterResult PutBlockMetadata(std::string_view Namespace, std::string_view BucketId, const Oid& BuildId, diff --git a/src/zenutil/jupiter/jupiterbuildstorage.cpp b/src/zenutil/jupiter/jupiterbuildstorage.cpp index f2d190408..24e062c7b 100644 --- a/src/zenutil/jupiter/jupiterbuildstorage.cpp +++ b/src/zenutil/jupiter/jupiterbuildstorage.cpp @@ -235,19 +235,25 @@ public: return std::move(GetBuildBlobResult.Response); } - virtual std::vector> GetLargeBuildBlob( - const Oid& BuildId, - const IoHash& RawHash, - uint64_t ChunkSize, - std::function&& Receiver) override + virtual std::vector> GetLargeBuildBlob(const Oid& BuildId, + const IoHash& RawHash, + uint64_t ChunkSize, + std::function&& OnReceive, + std::function&& OnComplete) override { ZEN_TRACE_CPU("Jupiter::GetLargeBuildBlob"); Stopwatch ExecutionTimer; auto _ = MakeGuard([&]() { m_Stats.TotalExecutionTimeUs += ExecutionTimer.GetElapsedTimeUs(); }); std::vector> WorkItems; - JupiterResult GetMultipartBlobResult = - m_Session.GetMultipartBuildBlob(m_Namespace, m_Bucket, BuildId, RawHash, ChunkSize, std::move(Receiver), WorkItems); + JupiterResult GetMultipartBlobResult = m_Session.GetMultipartBuildBlob(m_Namespace, + m_Bucket, + BuildId, + RawHash, + ChunkSize, + std::move(OnReceive), + std::move(OnComplete), + WorkItems); AddStatistic(GetMultipartBlobResult); if (!GetMultipartBlobResult.Success) diff --git a/src/zenutil/jupiter/jupitersession.cpp b/src/zenutil/jupiter/jupitersession.cpp index fde86a478..1f71c29b7 100644 --- a/src/zenutil/jupiter/jupitersession.cpp +++ b/src/zenutil/jupiter/jupitersession.cpp @@ -626,13 +626,14 @@ JupiterSession::PutMultipartBuildBlob(std::string_view Namespace, } JupiterResult -JupiterSession::GetMultipartBuildBlob(std::string_view Namespace, - std::string_view BucketId, - const Oid& BuildId, - const IoHash& Hash, - uint64_t ChunkSize, - std::function&& Receiver, - std::vector>& OutWorkItems) +JupiterSession::GetMultipartBuildBlob(std::string_view Namespace, + std::string_view BucketId, + const Oid& BuildId, + const IoHash& Hash, + uint64_t ChunkSize, + std::function&& OnReceive, + std::function&& OnComplete, + std::vector>& OutWorkItems) { std::string RequestUrl = fmt::format("/api/v2/builds/{}/{}/{}/blobs/{}", Namespace, BucketId, BuildId, Hash.ToHexString()); HttpClient::Response Response = @@ -649,18 +650,20 @@ JupiterSession::GetMultipartBuildBlob(std::string_view Namespa uint64_t TotalSize = TotalSizeMaybe.value(); uint64_t PayloadSize = Response.ResponsePayload.GetSize(); - Receiver(0, Response.ResponsePayload, TotalSize); + OnReceive(0, Response.ResponsePayload); if (TotalSize > PayloadSize) { struct WorkloadData { - std::function Receiver; - std::atomic BytesRemaining; + std::function OnReceive; + std::function OnComplete; + std::atomic BytesRemaining; }; std::shared_ptr Workload(std::make_shared()); - Workload->Receiver = std::move(Receiver); + Workload->OnReceive = std::move(OnReceive); + Workload->OnComplete = std::move(OnComplete); Workload->BytesRemaining = TotalSize - PayloadSize; uint64_t Offset = PayloadSize; @@ -676,19 +679,28 @@ JupiterSession::GetMultipartBuildBlob(std::string_view Namespa HttpClient::KeyValueMap({{"Range", fmt::format("bytes={}-{}", Offset, Offset + PartSize - 1)}})); if (Response.IsSuccess()) { + Workload->OnReceive(Offset, Response.ResponsePayload); uint64_t ByteRemaning = Workload->BytesRemaining.fetch_sub(Response.ResponsePayload.GetSize()); - Workload->Receiver(Offset, Response.ResponsePayload, ByteRemaning); + if (ByteRemaning == Response.ResponsePayload.GetSize()) + { + Workload->OnComplete(); + } } return detail::ConvertResponse(Response, "JupiterSession::GetMultipartBuildBlob"sv); }); Offset += PartSize; } } + else + { + OnComplete(); + } return detail::ConvertResponse(Response, "JupiterSession::GetMultipartBuildBlob"sv); } } } - Receiver(0, Response.ResponsePayload, Response.ResponsePayload.GetSize()); + OnReceive(0, Response.ResponsePayload); + OnComplete(); } return detail::ConvertResponse(Response, "JupiterSession::GetMultipartBuildBlob"sv); } -- cgit v1.2.3