diff options
| author | Dan Engelbrecht <[email protected]> | 2025-05-02 15:33:58 +0200 |
|---|---|---|
| committer | GitHub Enterprise <[email protected]> | 2025-05-02 15:33:58 +0200 |
| commit | a9c83b3c1299692923e2c16b696f5b9e211f5737 (patch) | |
| tree | 21e9ea21386672d2b870baf3103ebde2b5fbbb76 /src/zenstore/compactcas.cpp | |
| parent | cbobject validation (#377) (diff) | |
| download | zen-a9c83b3c1299692923e2c16b696f5b9e211f5737.tar.xz zen-a9c83b3c1299692923e2c16b696f5b9e211f5737.zip | |
iterate chunks crash fix (#376)
* Bugfix: Add explicit lambda capture in CasContainer::IterateChunks to avoid accessing state data references
Diffstat (limited to 'src/zenstore/compactcas.cpp')
| -rw-r--r-- | src/zenstore/compactcas.cpp | 281 |
1 file changed, 246 insertions, 35 deletions
diff --git a/src/zenstore/compactcas.cpp b/src/zenstore/compactcas.cpp index 184251da7..90e77e48a 100644 --- a/src/zenstore/compactcas.cpp +++ b/src/zenstore/compactcas.cpp @@ -360,7 +360,11 @@ CasContainerStrategy::IterateChunks(std::span<const IoHash> ChunkHas return true; } - auto DoOneBlock = [&](std::span<const size_t> ChunkIndexes) { + auto DoOneBlock = [this](const std::function<bool(size_t Index, const IoBuffer& Payload)>& AsyncCallback, + uint64_t LargeSizeLimit, + std::span<const size_t> FoundChunkIndexes, + std::span<const BlockStoreLocation> FoundChunkLocations, + std::span<const size_t> ChunkIndexes) { if (ChunkIndexes.size() < 4) { for (size_t ChunkIndex : ChunkIndexes) @@ -376,7 +380,7 @@ CasContainerStrategy::IterateChunks(std::span<const IoHash> ChunkHas return m_BlockStore.IterateBlock( FoundChunkLocations, ChunkIndexes, - [&](size_t ChunkIndex, const void* Data, uint64_t Size) { + [AsyncCallback, FoundChunkIndexes, LargeSizeLimit](size_t ChunkIndex, const void* Data, uint64_t Size) { if (Data == nullptr) { return AsyncCallback(FoundChunkIndexes[ChunkIndex], IoBuffer()); @@ -391,39 +395,59 @@ CasContainerStrategy::IterateChunks(std::span<const IoHash> ChunkHas Latch WorkLatch(1); std::atomic_bool AsyncContinue = true; - bool Continue = m_BlockStore.IterateChunks(FoundChunkLocations, [&](uint32_t BlockIndex, std::span<const size_t> ChunkIndexes) { - if (OptionalWorkerPool && (ChunkIndexes.size() > 3)) - { - WorkLatch.AddCount(1); - OptionalWorkerPool->ScheduleWork([&, ChunkIndexes = std::vector<size_t>(ChunkIndexes.begin(), ChunkIndexes.end())]() { - auto _ = MakeGuard([&WorkLatch]() { WorkLatch.CountDown(); }); - if (!AsyncContinue) - { - return; - } - try - { - bool Continue = DoOneBlock(ChunkIndexes); - if (!Continue) - { - AsyncContinue.store(false); - } - } - catch (const std::exception& Ex) - { - ZEN_WARN("Failed iterating chunks for cas root path {}, block {}. 
Reason: '{}'", - m_RootDirectory, - BlockIndex, - Ex.what()); - } - }); - return AsyncContinue.load(); - } - else - { - return DoOneBlock(ChunkIndexes); - } - }); + bool Continue = m_BlockStore.IterateChunks( + FoundChunkLocations, + [this, + &AsyncContinue, + &WorkLatch, + &AsyncCallback, + LargeSizeLimit, + DoOneBlock, + &FoundChunkIndexes, + &FoundChunkLocations, + OptionalWorkerPool](uint32_t BlockIndex, std::span<const size_t> ChunkIndexes) { + if (OptionalWorkerPool && (ChunkIndexes.size() > 3)) + { + std::vector<size_t> TmpChunkIndexes(ChunkIndexes.begin(), ChunkIndexes.end()); + WorkLatch.AddCount(1); + OptionalWorkerPool->ScheduleWork([this, + &AsyncContinue, + &WorkLatch, + &AsyncCallback, + LargeSizeLimit, + DoOneBlock, + BlockIndex, + &FoundChunkIndexes, + &FoundChunkLocations, + ChunkIndexes = std::move(TmpChunkIndexes)]() { + auto _ = MakeGuard([&WorkLatch]() { WorkLatch.CountDown(); }); + if (!AsyncContinue) + { + return; + } + try + { + bool Continue = DoOneBlock(AsyncCallback, LargeSizeLimit, FoundChunkIndexes, FoundChunkLocations, ChunkIndexes); + if (!Continue) + { + AsyncContinue.store(false); + } + } + catch (const std::exception& Ex) + { + ZEN_WARN("Failed iterating chunks for cas root path {}, block {}. 
Reason: '{}'", + m_RootDirectory, + BlockIndex, + Ex.what()); + } + }); + return AsyncContinue.load(); + } + else + { + return DoOneBlock(AsyncCallback, LargeSizeLimit, FoundChunkIndexes, FoundChunkLocations, ChunkIndexes); + } + }); WorkLatch.CountDown(); WorkLatch.Wait(); return AsyncContinue.load() && Continue; @@ -1573,6 +1597,193 @@ TEST_CASE("compactcas.threadedinsert") } } +TEST_CASE("compactcas.iteratechunks") +{ + std::atomic<size_t> WorkCompleted = 0; + WorkerThreadPool ThreadPool(Max(std::thread::hardware_concurrency() - 1u, 2u), "put"); + + const uint64_t kChunkSize = 1048 + 395; + const size_t kChunkCount = 63840; + + for (uint32_t N = 0; N < 4; N++) + { + GcManager Gc; + CasContainerStrategy Cas(Gc); + ScopedTemporaryDirectory TempDir; + Cas.Initialize(TempDir.Path(), "test", 65536 * 128, 8, true); + + CHECK(Cas.IterateChunks( + {}, + [](size_t Index, const IoBuffer& Payload) { + ZEN_UNUSED(Index, Payload); + return true; + }, + &ThreadPool, + 2048u)); + + uint64_t ExpectedSize = 0; + + std::vector<IoHash> Hashes; + Hashes.reserve(kChunkCount); + + { + Latch WorkLatch(1); + tsl::robin_set<IoHash, IoHash::Hasher> ChunkHashesLookup; + ChunkHashesLookup.reserve(kChunkCount); + RwLock InsertLock; + for (size_t Offset = 0; Offset < kChunkCount;) + { + size_t BatchCount = Min<size_t>(kChunkCount - Offset, 512u); + WorkLatch.AddCount(1); + ThreadPool.ScheduleWork( + [N, &WorkLatch, &InsertLock, &ChunkHashesLookup, &ExpectedSize, &Hashes, &Cas, Offset, BatchCount]() { + auto _ = MakeGuard([&WorkLatch]() { WorkLatch.CountDown(); }); + + std::vector<IoBuffer> BatchBlobs; + std::vector<IoHash> BatchHashes; + BatchBlobs.reserve(BatchCount); + BatchHashes.reserve(BatchCount); + + while (BatchBlobs.size() < BatchCount) + { + IoBuffer Chunk = CreateRandomBlob( + N + kChunkSize + ((BatchHashes.size() % 100) + (BatchHashes.size() % 7) * 315u + Offset % 377)); + IoHash Hash = IoHash::HashBuffer(Chunk); + { + RwLock::ExclusiveLockScope __(InsertLock); + if 
(ChunkHashesLookup.contains(Hash)) + { + continue; + } + ChunkHashesLookup.insert(Hash); + ExpectedSize += Chunk.Size(); + } + + BatchBlobs.emplace_back(std::move(Chunk)); + BatchHashes.push_back(Hash); + } + + Cas.InsertChunks(BatchBlobs, BatchHashes); + { + RwLock::ExclusiveLockScope __(InsertLock); + Hashes.insert(Hashes.end(), BatchHashes.begin(), BatchHashes.end()); + } + }); + Offset += BatchCount; + } + WorkLatch.CountDown(); + WorkLatch.Wait(); + } + + WorkerThreadPool BatchWorkerPool(Max(std::thread::hardware_concurrency() - 1u, 2u), "fetch"); + { + std::vector<std::atomic<bool>> FetchedFlags(Hashes.size()); + std::atomic<uint64_t> FetchedSize = 0; + CHECK(Cas.IterateChunks( + Hashes, + [&Hashes, &FetchedFlags, &FetchedSize](size_t Index, const IoBuffer& Payload) { + CHECK(FetchedFlags[Index].load() == false); + FetchedFlags[Index].store(true); + const IoHash& Hash = Hashes[Index]; + CHECK(Hash == IoHash::HashBuffer(Payload)); + FetchedSize += Payload.GetSize(); + return true; + }, + &BatchWorkerPool, + 2048u)); + for (const auto& Flag : FetchedFlags) + { + CHECK(Flag.load()); + } + CHECK(FetchedSize == ExpectedSize); + } + + Latch WorkLatch(1); + for (size_t I = 0; I < 2; I++) + { + WorkLatch.AddCount(1); + ThreadPool.ScheduleWork([&Cas, &Hashes, &BatchWorkerPool, &WorkLatch, I]() { + auto _ = MakeGuard([&WorkLatch]() { WorkLatch.CountDown(); }); + std::vector<IoHash> PartialHashes; + PartialHashes.reserve(Hashes.size() / 4); + for (size_t Index = 0; Index < Hashes.size(); Index++) + { + size_t TestIndex = Index + I; + if ((TestIndex % 7 == 1) || (TestIndex % 13 == 1) || (TestIndex % 17 == 1)) + { + PartialHashes.push_back(Hashes[Index]); + } + } + std::reverse(PartialHashes.begin(), PartialHashes.end()); + + std::vector<IoHash> NoFoundHashes; + std::vector<size_t> NoFindIndexes; + + NoFoundHashes.reserve(9); + for (size_t J = 0; J < 9; J++) + { + std::string Data = fmt::format("oh no, we don't exist {}", J + 1); + 
NoFoundHashes.push_back(IoHash::HashBuffer(Data.data(), Data.length())); + } + + NoFindIndexes.reserve(9); + + // Sprinkle in chunks that are not found! + auto It = PartialHashes.insert(PartialHashes.begin() + (PartialHashes.size() / 4) * 0, NoFoundHashes[0]); + NoFindIndexes.push_back(std::distance(PartialHashes.begin(), It)); + It = PartialHashes.insert(PartialHashes.begin() + (PartialHashes.size() / 4) * 0 + 1, NoFoundHashes[1]); + NoFindIndexes.push_back(std::distance(PartialHashes.begin(), It)); + It = PartialHashes.insert(PartialHashes.begin() + (PartialHashes.size() / 4) * 1, NoFoundHashes[2]); + NoFindIndexes.push_back(std::distance(PartialHashes.begin(), It)); + It = PartialHashes.insert(PartialHashes.begin() + (PartialHashes.size() / 4) * 1 + 1, NoFoundHashes[3]); + NoFindIndexes.push_back(std::distance(PartialHashes.begin(), It)); + It = PartialHashes.insert(PartialHashes.begin() + (PartialHashes.size() / 4) * 2, NoFoundHashes[4]); + NoFindIndexes.push_back(std::distance(PartialHashes.begin(), It)); + It = PartialHashes.insert(PartialHashes.begin() + (PartialHashes.size() / 4) * 3, NoFoundHashes[5]); + NoFindIndexes.push_back(std::distance(PartialHashes.begin(), It)); + It = PartialHashes.insert(PartialHashes.begin() + (PartialHashes.size() / 4) * 3 + 1, NoFoundHashes[6]); + NoFindIndexes.push_back(std::distance(PartialHashes.begin(), It)); + It = PartialHashes.insert(PartialHashes.begin() + (PartialHashes.size() / 4) * 4, NoFoundHashes[7]); + NoFindIndexes.push_back(std::distance(PartialHashes.begin(), It)); + It = PartialHashes.insert(PartialHashes.end(), NoFoundHashes[8]); + NoFindIndexes.push_back(std::distance(PartialHashes.begin(), It)); + + std::vector<std::atomic<bool>> FoundFlags(PartialHashes.size() + NoFoundHashes.size()); + std::vector<std::atomic<uint32_t>> FetchedCounts(PartialHashes.size() + NoFoundHashes.size()); + + CHECK(Cas.IterateChunks( + PartialHashes, + [&PartialHashes, &FoundFlags, &FetchedCounts, &NoFindIndexes](size_t Index, 
const IoBuffer& Payload) { + CHECK_EQ(NoFindIndexes.end(), std::find(NoFindIndexes.begin(), NoFindIndexes.end(), Index)); + uint32_t PreviousCount = FetchedCounts[Index].fetch_add(1); + CHECK(PreviousCount == 0); + FoundFlags[Index] = !!Payload; + const IoHash& Hash = PartialHashes[Index]; + CHECK(Hash == IoHash::HashBuffer(Payload)); + return true; + }, + &BatchWorkerPool, + 2048u)); + + for (size_t FoundIndex = 0; FoundIndex < PartialHashes.size(); FoundIndex++) + { + CHECK(FetchedCounts[FoundIndex].load() <= 1); + if (std::find(NoFindIndexes.begin(), NoFindIndexes.end(), FoundIndex) == NoFindIndexes.end()) + { + CHECK(FoundFlags[FoundIndex]); + } + else + { + CHECK(!FoundFlags[FoundIndex]); + } + } + }); + } + WorkLatch.CountDown(); + WorkLatch.Wait(); + } +} + #endif void |