aboutsummaryrefslogtreecommitdiff
path: root/src/zenserver/upstream/upstreamcache.cpp
diff options
context:
space:
mode:
authorStefan Boberg <[email protected]>2023-05-02 10:01:47 +0200
committerGitHub <[email protected]>2023-05-02 10:01:47 +0200
commit075d17f8ada47e990fe94606c3d21df409223465 (patch)
treee50549b766a2f3c354798a54ff73404217b4c9af /src/zenserver/upstream/upstreamcache.cpp
parentfix: bundle shouldn't append content zip to zen (diff)
downloadzen-075d17f8ada47e990fe94606c3d21df409223465.tar.xz
zen-075d17f8ada47e990fe94606c3d21df409223465.zip
moved source directories into `/src` (#264)
* moved source directories into `/src` * updated bundle.lua for new `src` path * moved some docs, icon * removed old test trees
Diffstat (limited to 'src/zenserver/upstream/upstreamcache.cpp')
-rw-r--r--src/zenserver/upstream/upstreamcache.cpp2112
1 files changed, 2112 insertions, 0 deletions
diff --git a/src/zenserver/upstream/upstreamcache.cpp b/src/zenserver/upstream/upstreamcache.cpp
new file mode 100644
index 000000000..e838b5fe2
--- /dev/null
+++ b/src/zenserver/upstream/upstreamcache.cpp
@@ -0,0 +1,2112 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "upstreamcache.h"
+#include "jupiter.h"
+#include "zen.h"
+
+#include <zencore/blockingqueue.h>
+#include <zencore/compactbinary.h>
+#include <zencore/compactbinarybuilder.h>
+#include <zencore/compactbinarypackage.h>
+#include <zencore/compactbinaryvalidation.h>
+#include <zencore/fmtutils.h>
+#include <zencore/stats.h>
+#include <zencore/stream.h>
+#include <zencore/timer.h>
+#include <zencore/trace.h>
+
+#include <zenhttp/httpshared.h>
+
+#include <zenstore/cidstore.h>
+
+#include <auth/authmgr.h>
+#include "cache/structuredcache.h"
+#include "cache/structuredcachestore.h"
+#include "diag/logging.h"
+
+#include <fmt/format.h>
+
+#include <algorithm>
+#include <atomic>
+#include <shared_mutex>
+#include <thread>
+#include <unordered_map>
+
+namespace zen {
+
+using namespace std::literals;
+
+namespace detail {
+
+ class UpstreamStatus
+ {
+ public:
+ UpstreamEndpointState EndpointState() const { return static_cast<UpstreamEndpointState>(m_State.load(std::memory_order_relaxed)); }
+
+ UpstreamEndpointStatus EndpointStatus() const
+ {
+ const UpstreamEndpointState State = EndpointState();
+ {
+ std::unique_lock _(m_Mutex);
+ return {.Reason = m_ErrorText, .State = State};
+ }
+ }
+
+ void Set(UpstreamEndpointState NewState)
+ {
+ m_State.store(static_cast<uint32_t>(NewState), std::memory_order_relaxed);
+ {
+ std::unique_lock _(m_Mutex);
+ m_ErrorText.clear();
+ }
+ }
+
+ void Set(UpstreamEndpointState NewState, std::string ErrorText)
+ {
+ m_State.store(static_cast<uint32_t>(NewState), std::memory_order_relaxed);
+ {
+ std::unique_lock _(m_Mutex);
+ m_ErrorText = std::move(ErrorText);
+ }
+ }
+
+ void SetFromErrorCode(int32_t ErrorCode, std::string_view ErrorText)
+ {
+ if (ErrorCode != 0)
+ {
+ Set(ErrorCode == 401 ? UpstreamEndpointState::kUnauthorized : UpstreamEndpointState::kError, std::string(ErrorText));
+ }
+ }
+
+ private:
+ mutable std::mutex m_Mutex;
+ std::string m_ErrorText;
+ std::atomic_uint32_t m_State;
+ };
+
+ class JupiterUpstreamEndpoint final : public UpstreamEndpoint
+ {
+ public:
+ JupiterUpstreamEndpoint(const CloudCacheClientOptions& Options, const UpstreamAuthConfig& AuthConfig, AuthMgr& Mgr)
+ : m_AuthMgr(Mgr)
+ , m_Log(zen::logging::Get("upstream"))
+ {
+ ZEN_ASSERT(!Options.Name.empty());
+ m_Info.Name = Options.Name;
+ m_Info.Url = Options.ServiceUrl;
+
+ std::unique_ptr<CloudCacheTokenProvider> TokenProvider;
+
+ if (AuthConfig.OAuthUrl.empty() == false)
+ {
+ TokenProvider = CloudCacheTokenProvider::CreateFromOAuthClientCredentials(
+ {.Url = AuthConfig.OAuthUrl, .ClientId = AuthConfig.OAuthClientId, .ClientSecret = AuthConfig.OAuthClientSecret});
+ }
+ else if (AuthConfig.OpenIdProvider.empty() == false)
+ {
+ TokenProvider =
+ CloudCacheTokenProvider::CreateFromCallback([this, ProviderName = std::string(AuthConfig.OpenIdProvider)]() {
+ AuthMgr::OpenIdAccessToken Token = m_AuthMgr.GetOpenIdAccessToken(ProviderName);
+ return CloudCacheAccessToken{.Value = Token.AccessToken, .ExpireTime = Token.ExpireTime};
+ });
+ }
+ else
+ {
+ CloudCacheAccessToken AccessToken{.Value = std::string(AuthConfig.AccessToken),
+ .ExpireTime = CloudCacheAccessToken::TimePoint::max()};
+
+ TokenProvider = CloudCacheTokenProvider::CreateFromStaticToken(AccessToken);
+ }
+
+ m_Client = new CloudCacheClient(Options, std::move(TokenProvider));
+ }
+
+ virtual ~JupiterUpstreamEndpoint() = default;
+
+ virtual const UpstreamEndpointInfo& GetEndpointInfo() const override { return m_Info; }
+
+ virtual UpstreamEndpointStatus Initialize() override
+ {
+ try
+ {
+ if (m_Status.EndpointState() == UpstreamEndpointState::kOk)
+ {
+ return {.State = UpstreamEndpointState::kOk};
+ }
+
+ CloudCacheSession Session(m_Client);
+ const CloudCacheResult Result = Session.Authenticate();
+
+ if (Result.Success)
+ {
+ m_Status.Set(UpstreamEndpointState::kOk);
+ }
+ else if (Result.ErrorCode != 0)
+ {
+ m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason);
+ }
+ else
+ {
+ m_Status.Set(UpstreamEndpointState::kUnauthorized);
+ }
+
+ return m_Status.EndpointStatus();
+ }
+ catch (std::exception& Err)
+ {
+ m_Status.Set(UpstreamEndpointState::kError, Err.what());
+
+ return {.Reason = Err.what(), .State = GetState()};
+ }
+ }
+
+ std::string_view GetActualDdcNamespace(CloudCacheSession& Session, std::string_view Namespace)
+ {
+ if (Namespace == ZenCacheStore::DefaultNamespace)
+ {
+ return Session.Client().DefaultDdcNamespace();
+ }
+ return Namespace;
+ }
+
+ std::string_view GetActualBlobStoreNamespace(CloudCacheSession& Session, std::string_view Namespace)
+ {
+ if (Namespace == ZenCacheStore::DefaultNamespace)
+ {
+ return Session.Client().DefaultBlobStoreNamespace();
+ }
+ return Namespace;
+ }
+
+ virtual UpstreamEndpointState GetState() override { return m_Status.EndpointState(); }
+
+ virtual UpstreamEndpointStatus GetStatus() override { return m_Status.EndpointStatus(); }
+
+ virtual GetUpstreamCacheSingleResult GetCacheRecord(std::string_view Namespace,
+ const CacheKey& CacheKey,
+ ZenContentType Type) override
+ {
+ ZEN_TRACE_CPU("Upstream::Horde::GetSingleCacheRecord");
+
+ try
+ {
+ CloudCacheSession Session(m_Client);
+ CloudCacheResult Result;
+
+ std::string_view BlobStoreNamespace = GetActualBlobStoreNamespace(Session, Namespace);
+
+ if (Type == ZenContentType::kCompressedBinary)
+ {
+ Result = Session.GetRef(BlobStoreNamespace, CacheKey.Bucket, CacheKey.Hash, ZenContentType::kCbObject);
+
+ if (Result.Success)
+ {
+ const CbValidateError ValidationResult = ValidateCompactBinary(Result.Response, CbValidateMode::All);
+ if (Result.Success = ValidationResult == CbValidateError::None; Result.Success)
+ {
+ CbObject CacheRecord = LoadCompactBinaryObject(Result.Response);
+ IoBuffer ContentBuffer;
+ int NumAttachments = 0;
+
+ CacheRecord.IterateAttachments([&](CbFieldView AttachmentHash) {
+ CloudCacheResult AttachmentResult = Session.GetCompressedBlob(BlobStoreNamespace, AttachmentHash.AsHash());
+ Result.Bytes += AttachmentResult.Bytes;
+ Result.ElapsedSeconds += AttachmentResult.ElapsedSeconds;
+ Result.ErrorCode = AttachmentResult.ErrorCode;
+
+ IoHash RawHash;
+ uint64_t RawSize;
+ if (CompressedBuffer::ValidateCompressedHeader(AttachmentResult.Response, RawHash, RawSize))
+ {
+ Result.Response = AttachmentResult.Response;
+ ++NumAttachments;
+ }
+ else
+ {
+ Result.Success = false;
+ }
+ });
+ if (NumAttachments != 1)
+ {
+ Result.Success = false;
+ }
+ }
+ }
+ }
+ else
+ {
+ const ZenContentType AcceptType = Type == ZenContentType::kCbPackage ? ZenContentType::kCbObject : Type;
+ Result = Session.GetRef(BlobStoreNamespace, CacheKey.Bucket, CacheKey.Hash, AcceptType);
+
+ if (Result.Success && Type == ZenContentType::kCbPackage)
+ {
+ CbPackage Package;
+
+ const CbValidateError ValidationResult = ValidateCompactBinary(Result.Response, CbValidateMode::All);
+ if (Result.Success = ValidationResult == CbValidateError::None; Result.Success)
+ {
+ CbObject CacheRecord = LoadCompactBinaryObject(Result.Response);
+
+ CacheRecord.IterateAttachments([&](CbFieldView AttachmentHash) {
+ CloudCacheResult AttachmentResult = Session.GetCompressedBlob(BlobStoreNamespace, AttachmentHash.AsHash());
+ Result.Bytes += AttachmentResult.Bytes;
+ Result.ElapsedSeconds += AttachmentResult.ElapsedSeconds;
+ Result.ErrorCode = AttachmentResult.ErrorCode;
+
+ IoHash RawHash;
+ uint64_t RawSize;
+ if (CompressedBuffer Chunk =
+ CompressedBuffer::FromCompressed(SharedBuffer(AttachmentResult.Response), RawHash, RawSize))
+ {
+ Package.AddAttachment(CbAttachment(Chunk, AttachmentHash.AsHash()));
+ }
+ else
+ {
+ Result.Success = false;
+ }
+ });
+
+ Package.SetObject(CacheRecord);
+ }
+
+ if (Result.Success)
+ {
+ BinaryWriter MemStream;
+ Package.Save(MemStream);
+
+ Result.Response = IoBuffer(IoBuffer::Clone, MemStream.Data(), MemStream.Size());
+ }
+ }
+ }
+
+ m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason);
+
+ if (Result.ErrorCode == 0)
+ {
+ return {.Status = {.Bytes = Result.Bytes, .ElapsedSeconds = Result.ElapsedSeconds, .Success = Result.Success},
+ .Value = Result.Response,
+ .Source = &m_Info};
+ }
+ else
+ {
+ return {.Status = {.Error{.ErrorCode = Result.ErrorCode, .Reason = std::move(Result.Reason)}}};
+ }
+ }
+ catch (std::exception& Err)
+ {
+ m_Status.Set(UpstreamEndpointState::kError, Err.what());
+
+ return {.Status = {.Error{.ErrorCode = -1, .Reason = Err.what()}}};
+ }
+ }
+
+ virtual GetUpstreamCacheResult GetCacheRecords(std::string_view Namespace,
+ std::span<CacheKeyRequest*> Requests,
+ OnCacheRecordGetComplete&& OnComplete) override
+ {
+ ZEN_TRACE_CPU("Upstream::Horde::GetCacheRecords");
+
+ CloudCacheSession Session(m_Client);
+ GetUpstreamCacheResult Result;
+
+ for (CacheKeyRequest* Request : Requests)
+ {
+ const CacheKey& CacheKey = Request->Key;
+ CbPackage Package;
+ CbObject Record;
+
+ double ElapsedSeconds = 0.0;
+ if (!Result.Error)
+ {
+ std::string_view BlobStoreNamespace = GetActualBlobStoreNamespace(Session, Namespace);
+ CloudCacheResult RefResult =
+ Session.GetRef(BlobStoreNamespace, CacheKey.Bucket, CacheKey.Hash, ZenContentType::kCbObject);
+ AppendResult(RefResult, Result);
+ ElapsedSeconds = RefResult.ElapsedSeconds;
+
+ m_Status.SetFromErrorCode(RefResult.ErrorCode, RefResult.Reason);
+
+ if (RefResult.ErrorCode == 0)
+ {
+ const CbValidateError ValidationResult = ValidateCompactBinary(RefResult.Response, CbValidateMode::All);
+ if (ValidationResult == CbValidateError::None)
+ {
+ Record = LoadCompactBinaryObject(RefResult.Response);
+ Record.IterateAttachments([&](CbFieldView AttachmentHash) {
+ CloudCacheResult BlobResult = Session.GetCompressedBlob(BlobStoreNamespace, AttachmentHash.AsHash());
+ AppendResult(BlobResult, Result);
+
+ m_Status.SetFromErrorCode(BlobResult.ErrorCode, BlobResult.Reason);
+
+ if (BlobResult.ErrorCode == 0)
+ {
+ IoHash RawHash;
+ uint64_t RawSize;
+ if (CompressedBuffer Chunk =
+ CompressedBuffer::FromCompressed(SharedBuffer(BlobResult.Response), RawHash, RawSize))
+ {
+ if (RawHash == AttachmentHash.AsHash())
+ {
+ Package.AddAttachment(CbAttachment(Chunk, RawHash));
+ }
+ }
+ }
+ });
+ }
+ }
+ }
+
+ OnComplete(
+ {.Request = *Request, .Record = Record, .Package = Package, .ElapsedSeconds = ElapsedSeconds, .Source = &m_Info});
+ }
+
+ return Result;
+ }
+
+ virtual GetUpstreamCacheSingleResult GetCacheChunk(std::string_view Namespace,
+ const CacheKey&,
+ const IoHash& ValueContentId) override
+ {
+ ZEN_TRACE_CPU("Upstream::Horde::GetSingleCacheChunk");
+
+ try
+ {
+ CloudCacheSession Session(m_Client);
+ std::string_view BlobStoreNamespace = GetActualBlobStoreNamespace(Session, Namespace);
+ const CloudCacheResult Result = Session.GetCompressedBlob(BlobStoreNamespace, ValueContentId);
+
+ m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason);
+
+ if (Result.ErrorCode == 0)
+ {
+ return {.Status = {.Bytes = Result.Bytes, .ElapsedSeconds = Result.ElapsedSeconds, .Success = Result.Success},
+ .Value = Result.Response,
+ .Source = &m_Info};
+ }
+ else
+ {
+ return {.Status = {.Error{.ErrorCode = Result.ErrorCode, .Reason = std::move(Result.Reason)}}};
+ }
+ }
+ catch (std::exception& Err)
+ {
+ m_Status.Set(UpstreamEndpointState::kError, Err.what());
+
+ return {.Status = {.Error{.ErrorCode = -1, .Reason = Err.what()}}};
+ }
+ }
+
+ virtual GetUpstreamCacheResult GetCacheChunks(std::string_view Namespace,
+ std::span<CacheChunkRequest*> CacheChunkRequests,
+ OnCacheChunksGetComplete&& OnComplete) override final
+ {
+ ZEN_TRACE_CPU("Upstream::Horde::GetCacheChunks");
+
+ CloudCacheSession Session(m_Client);
+ GetUpstreamCacheResult Result;
+
+ for (CacheChunkRequest* RequestPtr : CacheChunkRequests)
+ {
+ CacheChunkRequest& Request = *RequestPtr;
+ IoBuffer Payload;
+ IoHash RawHash = IoHash::Zero;
+ uint64_t RawSize = 0;
+
+ double ElapsedSeconds = 0.0;
+ bool IsCompressed = false;
+ if (!Result.Error)
+ {
+ std::string_view BlobStoreNamespace = GetActualBlobStoreNamespace(Session, Namespace);
+ const CloudCacheResult BlobResult =
+ Request.ChunkId == IoHash::Zero
+ ? Session.GetInlineBlob(BlobStoreNamespace, Request.Key.Bucket, Request.Key.Hash, Request.ChunkId)
+ : Session.GetCompressedBlob(BlobStoreNamespace, Request.ChunkId);
+ ElapsedSeconds = BlobResult.ElapsedSeconds;
+ Payload = BlobResult.Response;
+
+ AppendResult(BlobResult, Result);
+
+ m_Status.SetFromErrorCode(BlobResult.ErrorCode, BlobResult.Reason);
+ if (Payload && IsCompressedBinary(Payload.GetContentType()))
+ {
+ IsCompressed = CompressedBuffer::ValidateCompressedHeader(Payload, RawHash, RawSize);
+ }
+ }
+
+ if (IsCompressed)
+ {
+ OnComplete({.Request = Request,
+ .RawHash = RawHash,
+ .RawSize = RawSize,
+ .Value = Payload,
+ .ElapsedSeconds = ElapsedSeconds,
+ .Source = &m_Info});
+ }
+ else
+ {
+ OnComplete({.Request = Request, .RawHash = IoHash::Zero, .RawSize = 0, .Value = IoBuffer()});
+ }
+ }
+
+ return Result;
+ }
+
+ virtual GetUpstreamCacheResult GetCacheValues(std::string_view Namespace,
+ std::span<CacheValueRequest*> CacheValueRequests,
+ OnCacheValueGetComplete&& OnComplete) override final
+ {
+ ZEN_TRACE_CPU("Upstream::Horde::GetCacheValues");
+
+ CloudCacheSession Session(m_Client);
+ GetUpstreamCacheResult Result;
+
+ for (CacheValueRequest* RequestPtr : CacheValueRequests)
+ {
+ CacheValueRequest& Request = *RequestPtr;
+ IoBuffer Payload;
+ IoHash RawHash = IoHash::Zero;
+ uint64_t RawSize = 0;
+
+ double ElapsedSeconds = 0.0;
+ bool IsCompressed = false;
+ if (!Result.Error)
+ {
+ std::string_view BlobStoreNamespace = GetActualBlobStoreNamespace(Session, Namespace);
+ IoHash PayloadHash;
+ const CloudCacheResult BlobResult =
+ Session.GetInlineBlob(BlobStoreNamespace, Request.Key.Bucket, Request.Key.Hash, PayloadHash);
+ ElapsedSeconds = BlobResult.ElapsedSeconds;
+ Payload = BlobResult.Response;
+
+ AppendResult(BlobResult, Result);
+
+ m_Status.SetFromErrorCode(BlobResult.ErrorCode, BlobResult.Reason);
+ if (Payload)
+ {
+ if (IsCompressedBinary(Payload.GetContentType()))
+ {
+ IsCompressed = CompressedBuffer::ValidateCompressedHeader(Payload, RawHash, RawSize) && RawHash != PayloadHash;
+ }
+ else
+ {
+ CompressedBuffer Compressed = CompressedBuffer::Compress(SharedBuffer(Payload));
+ RawHash = Compressed.DecodeRawHash();
+ if (RawHash == PayloadHash)
+ {
+ IsCompressed = true;
+ }
+ else
+ {
+ ZEN_WARN("Horde request for inline payload of {}/{}/{} has hash {}, expected hash {} from header",
+ Namespace,
+ Request.Key.Bucket,
+ Request.Key.Hash.ToHexString(),
+ RawHash.ToHexString(),
+ PayloadHash.ToHexString());
+ }
+ }
+ }
+ }
+
+ if (IsCompressed)
+ {
+ OnComplete({.Request = Request,
+ .RawHash = RawHash,
+ .RawSize = RawSize,
+ .Value = Payload,
+ .ElapsedSeconds = ElapsedSeconds,
+ .Source = &m_Info});
+ }
+ else
+ {
+ OnComplete({.Request = Request, .RawHash = IoHash::Zero, .RawSize = 0, .Value = IoBuffer()});
+ }
+ }
+
+ return Result;
+ }
+
+ virtual PutUpstreamCacheResult PutCacheRecord(const UpstreamCacheRecord& CacheRecord,
+ IoBuffer RecordValue,
+ std::span<IoBuffer const> Values) override
+ {
+ ZEN_TRACE_CPU("Upstream::Horde::PutCacheRecord");
+
+ ZEN_ASSERT(CacheRecord.ValueContentIds.size() == Values.size());
+ const int32_t MaxAttempts = 3;
+
+ try
+ {
+ CloudCacheSession Session(m_Client);
+
+ if (CacheRecord.Type == ZenContentType::kBinary)
+ {
+ CloudCacheResult Result;
+ for (uint32_t Attempt = 0; Attempt < MaxAttempts && !Result.Success; Attempt++)
+ {
+ std::string_view BlobStoreNamespace = GetActualBlobStoreNamespace(Session, CacheRecord.Namespace);
+ Result = Session.PutRef(BlobStoreNamespace,
+ CacheRecord.Key.Bucket,
+ CacheRecord.Key.Hash,
+ RecordValue,
+ ZenContentType::kBinary);
+ }
+
+ m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason);
+
+ return {.Reason = std::move(Result.Reason),
+ .Bytes = Result.Bytes,
+ .ElapsedSeconds = Result.ElapsedSeconds,
+ .Success = Result.Success};
+ }
+ else if (CacheRecord.Type == ZenContentType::kCompressedBinary)
+ {
+ IoHash RawHash;
+ uint64_t RawSize;
+ if (!CompressedBuffer::ValidateCompressedHeader(RecordValue, RawHash, RawSize))
+ {
+ return {.Reason = std::string("Invalid compressed value buffer"), .Success = false};
+ }
+
+ CbObjectWriter ReferencingObject;
+ ReferencingObject.AddBinaryAttachment("RawHash", RawHash);
+ ReferencingObject.AddInteger("RawSize", RawSize);
+
+ return PerformStructuredPut(
+ Session,
+ CacheRecord.Namespace,
+ CacheRecord.Key,
+ ReferencingObject.Save().GetBuffer().AsIoBuffer(),
+ MaxAttempts,
+ [&](const IoHash& ValueContentId, IoBuffer& OutBuffer, std::string& OutReason) {
+ if (ValueContentId != RawHash)
+ {
+ OutReason =
+ fmt::format("Value '{}' MISMATCHED from compressed buffer raw hash {}", ValueContentId, RawHash);
+ return false;
+ }
+
+ OutBuffer = RecordValue;
+ return true;
+ });
+ }
+ else
+ {
+ return PerformStructuredPut(
+ Session,
+ CacheRecord.Namespace,
+ CacheRecord.Key,
+ RecordValue,
+ MaxAttempts,
+ [&](const IoHash& ValueContentId, IoBuffer& OutBuffer, std::string& OutReason) {
+ const auto It =
+ std::find(std::begin(CacheRecord.ValueContentIds), std::end(CacheRecord.ValueContentIds), ValueContentId);
+
+ if (It == std::end(CacheRecord.ValueContentIds))
+ {
+ OutReason = fmt::format("value '{}' MISSING from local cache", ValueContentId);
+ return false;
+ }
+
+ const size_t Idx = std::distance(std::begin(CacheRecord.ValueContentIds), It);
+
+ OutBuffer = Values[Idx];
+ return true;
+ });
+ }
+ }
+ catch (std::exception& Err)
+ {
+ m_Status.Set(UpstreamEndpointState::kError, Err.what());
+
+ return {.Reason = std::string(Err.what()), .Success = false};
+ }
+ }
+
+ virtual UpstreamEndpointStats& Stats() override { return m_Stats; }
+
+ private:
+ static void AppendResult(const CloudCacheResult& Result, GetUpstreamCacheResult& Out)
+ {
+ Out.Success &= Result.Success;
+ Out.Bytes += Result.Bytes;
+ Out.ElapsedSeconds += Result.ElapsedSeconds;
+
+ if (Result.ErrorCode)
+ {
+ Out.Error = {.ErrorCode = Result.ErrorCode, .Reason = std::move(Result.Reason)};
+ }
+ };
+
+ PutUpstreamCacheResult PerformStructuredPut(
+ CloudCacheSession& Session,
+ std::string_view Namespace,
+ const CacheKey& Key,
+ IoBuffer ObjectBuffer,
+ const int32_t MaxAttempts,
+ std::function<bool(const IoHash& ValueContentId, IoBuffer& OutBuffer, std::string& OutReason)>&& BlobFetchFn)
+ {
+ int64_t TotalBytes = 0ull;
+ double TotalElapsedSeconds = 0.0;
+
+ std::string_view BlobStoreNamespace = GetActualBlobStoreNamespace(Session, Namespace);
+ const auto PutBlobs = [&](std::span<IoHash> ValueContentIds, std::string& OutReason) -> bool {
+ for (const IoHash& ValueContentId : ValueContentIds)
+ {
+ IoBuffer BlobBuffer;
+ if (!BlobFetchFn(ValueContentId, BlobBuffer, OutReason))
+ {
+ return false;
+ }
+
+ CloudCacheResult BlobResult;
+ for (int32_t Attempt = 0; Attempt < MaxAttempts && !BlobResult.Success; Attempt++)
+ {
+ BlobResult = Session.PutCompressedBlob(BlobStoreNamespace, ValueContentId, BlobBuffer);
+ }
+
+ m_Status.SetFromErrorCode(BlobResult.ErrorCode, BlobResult.Reason);
+
+ if (!BlobResult.Success)
+ {
+ OutReason = fmt::format("upload value '{}' FAILED, reason '{}'", ValueContentId, BlobResult.Reason);
+ return false;
+ }
+
+ TotalBytes += BlobResult.Bytes;
+ TotalElapsedSeconds += BlobResult.ElapsedSeconds;
+ }
+
+ return true;
+ };
+
+ PutRefResult RefResult;
+ for (int32_t Attempt = 0; Attempt < MaxAttempts && !RefResult.Success; Attempt++)
+ {
+ RefResult = Session.PutRef(BlobStoreNamespace, Key.Bucket, Key.Hash, ObjectBuffer, ZenContentType::kCbObject);
+ }
+
+ m_Status.SetFromErrorCode(RefResult.ErrorCode, RefResult.Reason);
+
+ if (!RefResult.Success)
+ {
+ return {.Reason = fmt::format("upload cache record '{}/{}' FAILED, reason '{}'", Key.Bucket, Key.Hash, RefResult.Reason),
+ .Success = false};
+ }
+
+ TotalBytes += RefResult.Bytes;
+ TotalElapsedSeconds += RefResult.ElapsedSeconds;
+
+ std::string Reason;
+ if (!PutBlobs(RefResult.Needs, Reason))
+ {
+ return {.Reason = std::move(Reason), .Success = false};
+ }
+
+ const IoHash RefHash = IoHash::HashBuffer(ObjectBuffer);
+ FinalizeRefResult FinalizeResult = Session.FinalizeRef(BlobStoreNamespace, Key.Bucket, Key.Hash, RefHash);
+
+ m_Status.SetFromErrorCode(FinalizeResult.ErrorCode, FinalizeResult.Reason);
+
+ if (!FinalizeResult.Success)
+ {
+ return {
+ .Reason = fmt::format("finalize cache record '{}/{}' FAILED, reason '{}'", Key.Bucket, Key.Hash, FinalizeResult.Reason),
+ .Success = false};
+ }
+
+ if (!FinalizeResult.Needs.empty())
+ {
+ if (!PutBlobs(FinalizeResult.Needs, Reason))
+ {
+ return {.Reason = std::move(Reason), .Success = false};
+ }
+
+ FinalizeResult = Session.FinalizeRef(BlobStoreNamespace, Key.Bucket, Key.Hash, RefHash);
+
+ m_Status.SetFromErrorCode(FinalizeResult.ErrorCode, FinalizeResult.Reason);
+
+ if (!FinalizeResult.Success)
+ {
+ return {.Reason = fmt::format("finalize '{}/{}' FAILED, reason '{}'", Key.Bucket, Key.Hash, FinalizeResult.Reason),
+ .Success = false};
+ }
+
+ if (!FinalizeResult.Needs.empty())
+ {
+ ExtendableStringBuilder<256> Sb;
+ for (const IoHash& MissingHash : FinalizeResult.Needs)
+ {
+ Sb << MissingHash.ToHexString() << ",";
+ }
+
+ return {
+ .Reason = fmt::format("finalize '{}/{}' FAILED, still needs value(s) '{}'", Key.Bucket, Key.Hash, Sb.ToString()),
+ .Success = false};
+ }
+ }
+
+ TotalBytes += FinalizeResult.Bytes;
+ TotalElapsedSeconds += FinalizeResult.ElapsedSeconds;
+
+ return {.Bytes = TotalBytes, .ElapsedSeconds = TotalElapsedSeconds, .Success = true};
+ }
+
+ spdlog::logger& Log() { return m_Log; }
+
+ AuthMgr& m_AuthMgr;
+ spdlog::logger& m_Log;
+ UpstreamEndpointInfo m_Info;
+ UpstreamStatus m_Status;
+ UpstreamEndpointStats m_Stats;
+ RefPtr<CloudCacheClient> m_Client;
+ };
+
+ class ZenUpstreamEndpoint final : public UpstreamEndpoint
+ {
+ struct ZenEndpoint
+ {
+ std::string Url;
+ std::string Reason;
+ double Latency{};
+ bool Ok = false;
+
+ bool operator<(const ZenEndpoint& RHS) const { return Ok && RHS.Ok ? Latency < RHS.Latency : Ok; }
+ };
+
+ public:
+ ZenUpstreamEndpoint(const ZenStructuredCacheClientOptions& Options)
+ : m_Log(zen::logging::Get("upstream"))
+ , m_ConnectTimeout(Options.ConnectTimeout)
+ , m_Timeout(Options.Timeout)
+ {
+ ZEN_ASSERT(!Options.Name.empty());
+ m_Info.Name = Options.Name;
+
+ for (const auto& Url : Options.Urls)
+ {
+ m_Endpoints.push_back({.Url = Url});
+ }
+ }
+
+ ~ZenUpstreamEndpoint() = default;
+
+ virtual const UpstreamEndpointInfo& GetEndpointInfo() const override { return m_Info; }
+
+ virtual UpstreamEndpointStatus Initialize() override
+ {
+ try
+ {
+ if (m_Status.EndpointState() == UpstreamEndpointState::kOk)
+ {
+ return {.State = UpstreamEndpointState::kOk};
+ }
+
+ const ZenEndpoint& Ep = GetEndpoint();
+
+ if (m_Info.Url != Ep.Url)
+ {
+ ZEN_INFO("Setting Zen upstream URL to '{}'", Ep.Url);
+ m_Info.Url = Ep.Url;
+ }
+
+ if (Ep.Ok)
+ {
+ RwLock::ExclusiveLockScope _(m_ClientLock);
+ m_Client = new ZenStructuredCacheClient({.Url = m_Info.Url, .ConnectTimeout = m_ConnectTimeout, .Timeout = m_Timeout});
+ m_Status.Set(UpstreamEndpointState::kOk);
+ }
+ else
+ {
+ m_Status.Set(UpstreamEndpointState::kError, Ep.Reason);
+ }
+
+ return m_Status.EndpointStatus();
+ }
+ catch (std::exception& Err)
+ {
+ m_Status.Set(UpstreamEndpointState::kError, Err.what());
+
+ return {.Reason = Err.what(), .State = GetState()};
+ }
+ }
+
+ virtual UpstreamEndpointState GetState() override { return m_Status.EndpointState(); }
+
+ virtual UpstreamEndpointStatus GetStatus() override { return m_Status.EndpointStatus(); }
+
+ virtual GetUpstreamCacheSingleResult GetCacheRecord(std::string_view Namespace,
+ const CacheKey& CacheKey,
+ ZenContentType Type) override
+ {
+ ZEN_TRACE_CPU("Upstream::Zen::GetSingleCacheRecord");
+
+ try
+ {
+ ZenStructuredCacheSession Session(GetClientRef());
+ const ZenCacheResult Result = Session.GetCacheRecord(Namespace, CacheKey.Bucket, CacheKey.Hash, Type);
+
+ m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason);
+
+ if (Result.ErrorCode == 0)
+ {
+ return {.Status = {.Bytes = Result.Bytes, .ElapsedSeconds = Result.ElapsedSeconds, .Success = Result.Success},
+ .Value = Result.Response,
+ .Source = &m_Info};
+ }
+ else
+ {
+ return {.Status = {.Error{.ErrorCode = Result.ErrorCode, .Reason = std::move(Result.Reason)}}};
+ }
+ }
+ catch (std::exception& Err)
+ {
+ m_Status.Set(UpstreamEndpointState::kError, Err.what());
+
+ return {.Status = {.Error{.ErrorCode = -1, .Reason = Err.what()}}};
+ }
+ }
+
+ virtual GetUpstreamCacheResult GetCacheRecords(std::string_view Namespace,
+ std::span<CacheKeyRequest*> Requests,
+ OnCacheRecordGetComplete&& OnComplete) override
+ {
+ ZEN_TRACE_CPU("Upstream::Zen::GetCacheRecords");
+ ZEN_ASSERT(Requests.size() > 0);
+
+ CbObjectWriter BatchRequest;
+ BatchRequest << "Method"sv
+ << "GetCacheRecords"sv;
+ BatchRequest << "Accept"sv << kCbPkgMagic;
+
+ BatchRequest.BeginObject("Params"sv);
+ {
+ CachePolicy DefaultPolicy = Requests[0]->Policy.GetRecordPolicy();
+ BatchRequest << "DefaultPolicy"sv << WriteToString<128>(DefaultPolicy);
+
+ BatchRequest << "Namespace"sv << Namespace;
+
+ BatchRequest.BeginArray("Requests"sv);
+ for (CacheKeyRequest* Request : Requests)
+ {
+ BatchRequest.BeginObject();
+ {
+ const CacheKey& Key = Request->Key;
+ BatchRequest.BeginObject("Key"sv);
+ {
+ BatchRequest << "Bucket"sv << Key.Bucket;
+ BatchRequest << "Hash"sv << Key.Hash;
+ }
+ BatchRequest.EndObject();
+ if (!Request->Policy.IsUniform() || Request->Policy.GetRecordPolicy() != DefaultPolicy)
+ {
+ BatchRequest.SetName("Policy"sv);
+ Request->Policy.Save(BatchRequest);
+ }
+ }
+ BatchRequest.EndObject();
+ }
+ BatchRequest.EndArray();
+ }
+ BatchRequest.EndObject();
+
+ ZenCacheResult Result;
+
+ {
+ ZenStructuredCacheSession Session(GetClientRef());
+ Result = Session.InvokeRpc(BatchRequest.Save());
+ }
+
+ m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason);
+
+ if (Result.Success)
+ {
+ CbPackage BatchResponse;
+ if (ParsePackageMessageWithLegacyFallback(Result.Response, BatchResponse))
+ {
+ CbArrayView Results = BatchResponse.GetObject()["Result"sv].AsArrayView();
+ if (Results.Num() != Requests.size())
+ {
+ ZEN_WARN("Upstream::Zen::GetCacheRecords invalid number of Response results from Upstream.");
+ }
+ else
+ {
+ for (size_t Index = 0; CbFieldView Record : Results)
+ {
+ CacheKeyRequest* Request = Requests[Index++];
+ OnComplete({.Request = *Request,
+ .Record = Record.AsObjectView(),
+ .Package = BatchResponse,
+ .ElapsedSeconds = Result.ElapsedSeconds,
+ .Source = &m_Info});
+ }
+
+ return {.Bytes = Result.Bytes, .ElapsedSeconds = Result.ElapsedSeconds, .Success = true};
+ }
+ }
+ else
+ {
+ ZEN_WARN("Upstream::Zen::GetCacheRecords invalid Response from Upstream.");
+ }
+ }
+
+ for (CacheKeyRequest* Request : Requests)
+ {
+ OnComplete({.Request = *Request, .Record = CbObjectView(), .Package = CbPackage()});
+ }
+
+ return {.Error{.ErrorCode = Result.ErrorCode, .Reason = std::move(Result.Reason)}};
+ }
+
+ virtual GetUpstreamCacheSingleResult GetCacheChunk(std::string_view Namespace,
+ const CacheKey& CacheKey,
+ const IoHash& ValueContentId) override
+ {
+ ZEN_TRACE_CPU("Upstream::Zen::GetCacheChunk");
+
+ try
+ {
+ ZenStructuredCacheSession Session(GetClientRef());
+ const ZenCacheResult Result = Session.GetCacheChunk(Namespace, CacheKey.Bucket, CacheKey.Hash, ValueContentId);
+
+ m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason);
+
+ if (Result.ErrorCode == 0)
+ {
+ return {.Status = {.Bytes = Result.Bytes, .ElapsedSeconds = Result.ElapsedSeconds, .Success = Result.Success},
+ .Value = Result.Response,
+ .Source = &m_Info};
+ }
+ else
+ {
+ return {.Status = {.Error{.ErrorCode = Result.ErrorCode, .Reason = std::move(Result.Reason)}}};
+ }
+ }
+ catch (std::exception& Err)
+ {
+ m_Status.Set(UpstreamEndpointState::kError, Err.what());
+
+ return {.Status = {.Error{.ErrorCode = -1, .Reason = Err.what()}}};
+ }
+ }
+
+ virtual GetUpstreamCacheResult GetCacheValues(std::string_view Namespace,
+ std::span<CacheValueRequest*> CacheValueRequests,
+ OnCacheValueGetComplete&& OnComplete) override final
+ {
+ ZEN_TRACE_CPU("Upstream::Zen::GetCacheValues");
+ ZEN_ASSERT(!CacheValueRequests.empty());
+
+ CbObjectWriter BatchRequest;
+ BatchRequest << "Method"sv
+ << "GetCacheValues"sv;
+ BatchRequest << "Accept"sv << kCbPkgMagic;
+
+ BatchRequest.BeginObject("Params"sv);
+ {
+ CachePolicy DefaultPolicy = CacheValueRequests[0]->Policy;
+ BatchRequest << "DefaultPolicy"sv << WriteToString<128>(DefaultPolicy).ToView();
+ BatchRequest << "Namespace"sv << Namespace;
+
+ BatchRequest.BeginArray("Requests"sv);
+ {
+ for (CacheValueRequest* RequestPtr : CacheValueRequests)
+ {
+ const CacheValueRequest& Request = *RequestPtr;
+
+ BatchRequest.BeginObject();
+ {
+ BatchRequest.BeginObject("Key"sv);
+ BatchRequest << "Bucket"sv << Request.Key.Bucket;
+ BatchRequest << "Hash"sv << Request.Key.Hash;
+ BatchRequest.EndObject();
+ if (Request.Policy != DefaultPolicy)
+ {
+ BatchRequest << "Policy"sv << WriteToString<128>(Request.Policy).ToView();
+ }
+ }
+ BatchRequest.EndObject();
+ }
+ }
+ BatchRequest.EndArray();
+ }
+ BatchRequest.EndObject();
+
+ ZenCacheResult Result;
+
+ {
+ ZenStructuredCacheSession Session(GetClientRef());
+ Result = Session.InvokeRpc(BatchRequest.Save());
+ }
+
+ m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason);
+
+ if (Result.Success)
+ {
+ CbPackage BatchResponse;
+ if (ParsePackageMessageWithLegacyFallback(Result.Response, BatchResponse))
+ {
+ CbArrayView Results = BatchResponse.GetObject()["Result"sv].AsArrayView();
+ if (CacheValueRequests.size() != Results.Num())
+ {
+ ZEN_WARN("Upstream::Zen::GetCacheValues invalid number of Response results from Upstream.");
+ }
+ else
+ {
+ for (size_t RequestIndex = 0; CbFieldView ChunkField : Results)
+ {
+ CacheValueRequest& Request = *CacheValueRequests[RequestIndex++];
+ CbObjectView ChunkObject = ChunkField.AsObjectView();
+ IoHash RawHash = ChunkObject["RawHash"sv].AsHash();
+ IoBuffer Payload;
+ uint64_t RawSize = 0;
+ if (RawHash != IoHash::Zero)
+ {
+ bool Success = false;
+ const CbAttachment* Attachment = BatchResponse.FindAttachment(RawHash);
+ if (Attachment)
+ {
+ if (const CompressedBuffer& Compressed = Attachment->AsCompressedBinary())
+ {
+ Payload = Compressed.GetCompressed().Flatten().AsIoBuffer();
+ Payload.SetContentType(ZenContentType::kCompressedBinary);
+ RawSize = Compressed.DecodeRawSize();
+ Success = true;
+ }
+ }
+ if (!Success)
+ {
+ CbFieldView RawSizeField = ChunkObject["RawSize"sv];
+ RawSize = RawSizeField.AsUInt64();
+ Success = !RawSizeField.HasError();
+ }
+ if (!Success)
+ {
+ RawHash = IoHash::Zero;
+ }
+ }
+ OnComplete({.Request = Request,
+ .RawHash = RawHash,
+ .RawSize = RawSize,
+ .Value = std::move(Payload),
+ .ElapsedSeconds = Result.ElapsedSeconds,
+ .Source = &m_Info});
+ }
+
+ return {.Bytes = Result.Bytes, .ElapsedSeconds = Result.ElapsedSeconds, .Success = true};
+ }
+ }
+ else
+ {
+ ZEN_WARN("Upstream::Zen::GetCacheValues invalid Response from Upstream.");
+ }
+ }
+
+ for (CacheValueRequest* RequestPtr : CacheValueRequests)
+ {
+ OnComplete({.Request = *RequestPtr, .RawHash = IoHash::Zero, .RawSize = 0, .Value = IoBuffer()});
+ }
+
+ return {.Error{.ErrorCode = Result.ErrorCode, .Reason = std::move(Result.Reason)}};
+ }
+
+ virtual GetUpstreamCacheResult GetCacheChunks(std::string_view Namespace,
+ std::span<CacheChunkRequest*> CacheChunkRequests,
+ OnCacheChunksGetComplete&& OnComplete) override final
+ {
+ ZEN_TRACE_CPU("Upstream::Zen::GetCacheChunks");
+ ZEN_ASSERT(!CacheChunkRequests.empty());
+
+ CbObjectWriter BatchRequest;
+ BatchRequest << "Method"sv
+ << "GetCacheChunks"sv;
+ BatchRequest << "Accept"sv << kCbPkgMagic;
+
+ BatchRequest.BeginObject("Params"sv);
+ {
+ CachePolicy DefaultPolicy = CacheChunkRequests[0]->Policy;
+ BatchRequest << "DefaultPolicy"sv << WriteToString<128>(DefaultPolicy).ToView();
+ BatchRequest << "Namespace"sv << Namespace;
+
+ BatchRequest.BeginArray("ChunkRequests"sv);
+ {
+ for (CacheChunkRequest* RequestPtr : CacheChunkRequests)
+ {
+ const CacheChunkRequest& Request = *RequestPtr;
+
+ BatchRequest.BeginObject();
+ {
+ BatchRequest.BeginObject("Key"sv);
+ BatchRequest << "Bucket"sv << Request.Key.Bucket;
+ BatchRequest << "Hash"sv << Request.Key.Hash;
+ BatchRequest.EndObject();
+ if (Request.ValueId)
+ {
+ BatchRequest.AddObjectId("ValueId"sv, Request.ValueId);
+ }
+ if (Request.ChunkId != Request.ChunkId.Zero)
+ {
+ BatchRequest << "ChunkId"sv << Request.ChunkId;
+ }
+ if (Request.RawOffset != 0)
+ {
+ BatchRequest << "RawOffset"sv << Request.RawOffset;
+ }
+ if (Request.RawSize != UINT64_MAX)
+ {
+ BatchRequest << "RawSize"sv << Request.RawSize;
+ }
+ if (Request.Policy != DefaultPolicy)
+ {
+ BatchRequest << "Policy"sv << WriteToString<128>(Request.Policy).ToView();
+ }
+ }
+ BatchRequest.EndObject();
+ }
+ }
+ BatchRequest.EndArray();
+ }
+ BatchRequest.EndObject();
+
+ ZenCacheResult Result;
+
+ {
+ ZenStructuredCacheSession Session(GetClientRef());
+ Result = Session.InvokeRpc(BatchRequest.Save());
+ }
+
+ m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason);
+
+ if (Result.Success)
+ {
+ CbPackage BatchResponse;
+ if (ParsePackageMessageWithLegacyFallback(Result.Response, BatchResponse))
+ {
+ CbArrayView Results = BatchResponse.GetObject()["Result"sv].AsArrayView();
+ if (CacheChunkRequests.size() != Results.Num())
+ {
+ ZEN_WARN("Upstream::Zen::GetCacheChunks invalid number of Response results from Upstream.");
+ }
+ else
+ {
+ for (size_t RequestIndex = 0; CbFieldView ChunkField : Results)
+ {
+ CacheChunkRequest& Request = *CacheChunkRequests[RequestIndex++];
+ CbObjectView ChunkObject = ChunkField.AsObjectView();
+ IoHash RawHash = ChunkObject["RawHash"sv].AsHash();
+ IoBuffer Payload;
+ uint64_t RawSize = 0;
+ if (RawHash != IoHash::Zero)
+ {
+ bool Success = false;
+ const CbAttachment* Attachment = BatchResponse.FindAttachment(RawHash);
+ if (Attachment)
+ {
+ if (const CompressedBuffer& Compressed = Attachment->AsCompressedBinary())
+ {
+ Payload = Compressed.GetCompressed().Flatten().AsIoBuffer();
+ Payload.SetContentType(ZenContentType::kCompressedBinary);
+ RawSize = Compressed.DecodeRawSize();
+ Success = true;
+ }
+ }
+ if (!Success)
+ {
+ CbFieldView RawSizeField = ChunkObject["RawSize"sv];
+ RawSize = RawSizeField.AsUInt64();
+ Success = !RawSizeField.HasError();
+ }
+ if (!Success)
+ {
+ RawHash = IoHash::Zero;
+ }
+ }
+ OnComplete({.Request = Request,
+ .RawHash = RawHash,
+ .RawSize = RawSize,
+ .Value = std::move(Payload),
+ .ElapsedSeconds = Result.ElapsedSeconds,
+ .Source = &m_Info});
+ }
+
+ return {.Bytes = Result.Bytes, .ElapsedSeconds = Result.ElapsedSeconds, .Success = true};
+ }
+ }
+ else
+ {
+ ZEN_WARN("Upstream::Zen::GetCacheChunks invalid Response from Upstream.");
+ }
+ }
+
+ for (CacheChunkRequest* RequestPtr : CacheChunkRequests)
+ {
+ OnComplete({.Request = *RequestPtr, .RawHash = IoHash::Zero, .RawSize = 0, .Value = IoBuffer()});
+ }
+
+ return {.Error{.ErrorCode = Result.ErrorCode, .Reason = std::move(Result.Reason)}};
+ }
+
+ virtual PutUpstreamCacheResult PutCacheRecord(const UpstreamCacheRecord& CacheRecord,
+ IoBuffer RecordValue,
+ std::span<IoBuffer const> Values) override
+ {
+ ZEN_TRACE_CPU("Upstream::Zen::PutCacheRecord");
+
+ ZEN_ASSERT(CacheRecord.ValueContentIds.size() == Values.size());
+ const int32_t MaxAttempts = 3;
+
+ try
+ {
+ ZenStructuredCacheSession Session(GetClientRef());
+ ZenCacheResult Result;
+ int64_t TotalBytes = 0ull;
+ double TotalElapsedSeconds = 0.0;
+
+ if (CacheRecord.Type == ZenContentType::kCbPackage)
+ {
+ CbPackage Package;
+ Package.SetObject(CbObject(SharedBuffer(RecordValue)));
+
+ for (const IoBuffer& Value : Values)
+ {
+ IoHash RawHash;
+ uint64_t RawSize;
+ if (CompressedBuffer AttachmentBuffer = CompressedBuffer::FromCompressed(SharedBuffer(Value), RawHash, RawSize))
+ {
+ Package.AddAttachment(CbAttachment(AttachmentBuffer, RawHash));
+ }
+ else
+ {
+ return {.Reason = std::string("Invalid value buffer"), .Success = false};
+ }
+ }
+
+ BinaryWriter MemStream;
+ Package.Save(MemStream);
+ IoBuffer PackagePayload(IoBuffer::Wrap, MemStream.Data(), MemStream.Size());
+
+ for (uint32_t Attempt = 0; Attempt < MaxAttempts && !Result.Success; Attempt++)
+ {
+ Result = Session.PutCacheRecord(CacheRecord.Namespace,
+ CacheRecord.Key.Bucket,
+ CacheRecord.Key.Hash,
+ PackagePayload,
+ CacheRecord.Type);
+ }
+
+ m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason);
+
+ TotalBytes = Result.Bytes;
+ TotalElapsedSeconds = Result.ElapsedSeconds;
+ }
+ else if (CacheRecord.Type == ZenContentType::kCompressedBinary)
+ {
+ IoHash RawHash;
+ uint64_t RawSize;
+ CompressedBuffer Compressed = CompressedBuffer::FromCompressed(SharedBuffer(RecordValue), RawHash, RawSize);
+ if (!Compressed)
+ {
+ return {.Reason = std::string("Invalid value compressed buffer"), .Success = false};
+ }
+
+ CbPackage BatchPackage;
+ CbObjectWriter BatchWriter;
+ BatchWriter << "Method"sv
+ << "PutCacheValues"sv;
+ BatchWriter << "Accept"sv << kCbPkgMagic;
+
+ BatchWriter.BeginObject("Params"sv);
+ {
+ // DefaultPolicy unspecified and expected to be Default
+
+ BatchWriter << "Namespace"sv << CacheRecord.Namespace;
+
+ BatchWriter.BeginArray("Requests"sv);
+ {
+ BatchWriter.BeginObject();
+ {
+ const CacheKey& Key = CacheRecord.Key;
+ BatchWriter.BeginObject("Key"sv);
+ {
+ BatchWriter << "Bucket"sv << Key.Bucket;
+ BatchWriter << "Hash"sv << Key.Hash;
+ }
+ BatchWriter.EndObject();
+ // Policy unspecified and expected to be Default
+ BatchWriter.AddBinaryAttachment("RawHash"sv, RawHash);
+ BatchPackage.AddAttachment(CbAttachment(Compressed, RawHash));
+ }
+ BatchWriter.EndObject();
+ }
+ BatchWriter.EndArray();
+ }
+ BatchWriter.EndObject();
+ BatchPackage.SetObject(BatchWriter.Save());
+
+ Result.Success = false;
+ for (uint32_t Attempt = 0; Attempt < MaxAttempts && !Result.Success; Attempt++)
+ {
+ Result = Session.InvokeRpc(BatchPackage);
+ }
+
+ m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason);
+
+ TotalBytes += Result.Bytes;
+ TotalElapsedSeconds += Result.ElapsedSeconds;
+ }
+ else
+ {
+ for (size_t Idx = 0, Count = Values.size(); Idx < Count; Idx++)
+ {
+ Result.Success = false;
+ for (uint32_t Attempt = 0; Attempt < MaxAttempts && !Result.Success; Attempt++)
+ {
+ Result = Session.PutCacheValue(CacheRecord.Namespace,
+ CacheRecord.Key.Bucket,
+ CacheRecord.Key.Hash,
+ CacheRecord.ValueContentIds[Idx],
+ Values[Idx]);
+ }
+
+ m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason);
+
+ TotalBytes += Result.Bytes;
+ TotalElapsedSeconds += Result.ElapsedSeconds;
+
+ if (!Result.Success)
+ {
+ return {.Reason = "Failed to upload value",
+ .Bytes = TotalBytes,
+ .ElapsedSeconds = TotalElapsedSeconds,
+ .Success = false};
+ }
+ }
+
+ Result.Success = false;
+ for (uint32_t Attempt = 0; Attempt < MaxAttempts && !Result.Success; Attempt++)
+ {
+ Result = Session.PutCacheRecord(CacheRecord.Namespace,
+ CacheRecord.Key.Bucket,
+ CacheRecord.Key.Hash,
+ RecordValue,
+ CacheRecord.Type);
+ }
+
+ m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason);
+
+ TotalBytes += Result.Bytes;
+ TotalElapsedSeconds += Result.ElapsedSeconds;
+ }
+
+ return {.Reason = std::move(Result.Reason),
+ .Bytes = TotalBytes,
+ .ElapsedSeconds = TotalElapsedSeconds,
+ .Success = Result.Success};
+ }
+ catch (std::exception& Err)
+ {
+ m_Status.Set(UpstreamEndpointState::kError, Err.what());
+
+ return {.Reason = std::string(Err.what()), .Success = false};
+ }
+ }
+
+ virtual UpstreamEndpointStats& Stats() override { return m_Stats; }
+
+ private:
+ Ref<ZenStructuredCacheClient> GetClientRef()
+ {
+ // m_Client can be modified at any time by a different thread.
+ // Make sure we safely bump the refcount inside a scope lock
+ RwLock::SharedLockScope _(m_ClientLock);
+ ZEN_ASSERT(m_Client);
+ Ref<ZenStructuredCacheClient> ClientRef(m_Client);
+ _.ReleaseNow();
+ return ClientRef;
+ }
+
+ const ZenEndpoint& GetEndpoint()
+ {
+ for (ZenEndpoint& Ep : m_Endpoints)
+ {
+ Ref<ZenStructuredCacheClient> Client(
+ new ZenStructuredCacheClient({.Url = Ep.Url, .ConnectTimeout = std::chrono::milliseconds(1000)}));
+ ZenStructuredCacheSession Session(std::move(Client));
+ const int32_t SampleCount = 2;
+
+ Ep.Ok = false;
+ Ep.Latency = {};
+
+ for (int32_t Sample = 0; Sample < SampleCount; ++Sample)
+ {
+ ZenCacheResult Result = Session.CheckHealth();
+ Ep.Ok = Result.Success;
+ Ep.Reason = std::move(Result.Reason);
+ Ep.Latency += Result.ElapsedSeconds;
+ }
+ Ep.Latency /= double(SampleCount);
+ }
+
+ std::sort(std::begin(m_Endpoints), std::end(m_Endpoints));
+
+ for (const auto& Ep : m_Endpoints)
+ {
+ ZEN_INFO("ping 'Zen' endpoint '{}' latency '{:.3}s' {}", Ep.Url, Ep.Latency, Ep.Ok ? "OK" : Ep.Reason);
+ }
+
+ return m_Endpoints.front();
+ }
+
+ spdlog::logger& Log() { return m_Log; }
+
+ spdlog::logger& m_Log;
+ UpstreamEndpointInfo m_Info;
+ UpstreamStatus m_Status;
+ UpstreamEndpointStats m_Stats;
+ std::vector<ZenEndpoint> m_Endpoints;
+ std::chrono::milliseconds m_ConnectTimeout;
+ std::chrono::milliseconds m_Timeout;
+ RwLock m_ClientLock;
+ RefPtr<ZenStructuredCacheClient> m_Client;
+ };
+
+} // namespace detail
+
+//////////////////////////////////////////////////////////////////////////
+
+class UpstreamCacheImpl final : public UpstreamCache
+{
+public:
+ UpstreamCacheImpl(const UpstreamCacheOptions& Options, ZenCacheStore& CacheStore, CidStore& CidStore)
+ : m_Log(logging::Get("upstream"))
+ , m_Options(Options)
+ , m_CacheStore(CacheStore)
+ , m_CidStore(CidStore)
+ {
+ }
+
+ virtual ~UpstreamCacheImpl() { Shutdown(); }
+
+ virtual void Initialize() override
+ {
+ for (uint32_t Idx = 0; Idx < m_Options.ThreadCount; Idx++)
+ {
+ m_UpstreamThreads.emplace_back(&UpstreamCacheImpl::ProcessUpstreamQueue, this);
+ }
+
+ m_EndpointMonitorThread = std::thread(&UpstreamCacheImpl::MonitorEndpoints, this);
+ m_RunState.IsRunning = true;
+ }
+
+ virtual void RegisterEndpoint(std::unique_ptr<UpstreamEndpoint> Endpoint) override
+ {
+ const UpstreamEndpointStatus Status = Endpoint->Initialize();
+ const UpstreamEndpointInfo& Info = Endpoint->GetEndpointInfo();
+
+ if (Status.State == UpstreamEndpointState::kOk)
+ {
+ ZEN_INFO("register endpoint '{} - {}' {}", Info.Name, Info.Url, ToString(Status.State));
+ }
+ else
+ {
+ ZEN_WARN("register endpoint '{} - {}' {}", Info.Name, Info.Url, ToString(Status.State));
+ }
+
+ // Register endpoint even if it fails, the health monitor thread will probe failing endpoint(s)
+ std::unique_lock<std::shared_mutex> _(m_EndpointsMutex);
+ m_Endpoints.emplace_back(std::move(Endpoint));
+ }
+
+ virtual void IterateEndpoints(std::function<bool(UpstreamEndpoint&)>&& Fn) override
+ {
+ std::shared_lock<std::shared_mutex> _(m_EndpointsMutex);
+
+ for (auto& Ep : m_Endpoints)
+ {
+ if (!Fn(*Ep))
+ {
+ break;
+ }
+ }
+ }
+
+ virtual GetUpstreamCacheSingleResult GetCacheRecord(std::string_view Namespace, const CacheKey& CacheKey, ZenContentType Type) override
+ {
+ ZEN_TRACE_CPU("Upstream::GetCacheRecord");
+
+ std::shared_lock<std::shared_mutex> _(m_EndpointsMutex);
+
+ if (m_Options.ReadUpstream)
+ {
+ for (auto& Endpoint : m_Endpoints)
+ {
+ if (Endpoint->GetState() != UpstreamEndpointState::kOk)
+ {
+ continue;
+ }
+
+ UpstreamEndpointStats& Stats = Endpoint->Stats();
+ metrics::OperationTiming::Scope Scope(Stats.CacheGetRequestTiming);
+ GetUpstreamCacheSingleResult Result = Endpoint->GetCacheRecord(Namespace, CacheKey, Type);
+ Scope.Stop();
+
+ Stats.CacheGetCount.Increment(1);
+ Stats.CacheGetTotalBytes.Increment(Result.Status.Bytes);
+
+ if (Result.Status.Success)
+ {
+ Stats.CacheHitCount.Increment(1);
+
+ return Result;
+ }
+
+ if (Result.Status.Error)
+ {
+ Stats.CacheErrorCount.Increment(1);
+
+ ZEN_WARN("get cache record FAILED, endpoint '{}', reason '{}', error code '{}'",
+ Endpoint->GetEndpointInfo().Url,
+ Result.Status.Error.Reason,
+ Result.Status.Error.ErrorCode);
+ }
+ }
+ }
+
+ return {};
+ }
+
+ virtual void GetCacheRecords(std::string_view Namespace,
+ std::span<CacheKeyRequest*> Requests,
+ OnCacheRecordGetComplete&& OnComplete) override final
+ {
+ ZEN_TRACE_CPU("Upstream::GetCacheRecords");
+
+ std::shared_lock<std::shared_mutex> _(m_EndpointsMutex);
+
+ std::vector<CacheKeyRequest*> RemainingKeys(Requests.begin(), Requests.end());
+
+ if (m_Options.ReadUpstream)
+ {
+ for (auto& Endpoint : m_Endpoints)
+ {
+ if (RemainingKeys.empty())
+ {
+ break;
+ }
+
+ if (Endpoint->GetState() != UpstreamEndpointState::kOk)
+ {
+ continue;
+ }
+
+ UpstreamEndpointStats& Stats = Endpoint->Stats();
+ std::vector<CacheKeyRequest*> Missing;
+ GetUpstreamCacheResult Result;
+ {
+ metrics::OperationTiming::Scope Scope(Stats.CacheGetRequestTiming);
+
+ Result = Endpoint->GetCacheRecords(Namespace, RemainingKeys, [&](CacheRecordGetCompleteParams&& Params) {
+ if (Params.Record)
+ {
+ OnComplete(std::forward<CacheRecordGetCompleteParams>(Params));
+
+ Stats.CacheHitCount.Increment(1);
+ }
+ else
+ {
+ Missing.push_back(&Params.Request);
+ }
+ });
+ }
+
+ Stats.CacheGetCount.Increment(int64_t(RemainingKeys.size()));
+ Stats.CacheGetTotalBytes.Increment(Result.Bytes);
+
+ if (Result.Error)
+ {
+ Stats.CacheErrorCount.Increment(1);
+
+ ZEN_WARN("get cache record(s) (rpc) FAILED, endpoint '{}', reason '{}', error code '{}'",
+ Endpoint->GetEndpointInfo().Url,
+ Result.Error.Reason,
+ Result.Error.ErrorCode);
+ }
+
+ RemainingKeys = std::move(Missing);
+ }
+ }
+
+ const UpstreamEndpointInfo Info;
+ for (CacheKeyRequest* Request : RemainingKeys)
+ {
+ OnComplete({.Request = *Request, .Record = CbObjectView(), .Package = CbPackage()});
+ }
+ }
+
+ virtual void GetCacheChunks(std::string_view Namespace,
+ std::span<CacheChunkRequest*> CacheChunkRequests,
+ OnCacheChunksGetComplete&& OnComplete) override final
+ {
+ ZEN_TRACE_CPU("Upstream::GetCacheChunks");
+
+ std::shared_lock<std::shared_mutex> _(m_EndpointsMutex);
+
+ std::vector<CacheChunkRequest*> RemainingKeys(CacheChunkRequests.begin(), CacheChunkRequests.end());
+
+ if (m_Options.ReadUpstream)
+ {
+ for (auto& Endpoint : m_Endpoints)
+ {
+ if (RemainingKeys.empty())
+ {
+ break;
+ }
+
+ if (Endpoint->GetState() != UpstreamEndpointState::kOk)
+ {
+ continue;
+ }
+
+ UpstreamEndpointStats& Stats = Endpoint->Stats();
+ std::vector<CacheChunkRequest*> Missing;
+ GetUpstreamCacheResult Result;
+ {
+ metrics::OperationTiming::Scope Scope(Endpoint->Stats().CacheGetRequestTiming);
+
+ Result = Endpoint->GetCacheChunks(Namespace, RemainingKeys, [&](CacheChunkGetCompleteParams&& Params) {
+ if (Params.RawHash != Params.RawHash.Zero)
+ {
+ OnComplete(std::forward<CacheChunkGetCompleteParams>(Params));
+
+ Stats.CacheHitCount.Increment(1);
+ }
+ else
+ {
+ Missing.push_back(&Params.Request);
+ }
+ });
+ }
+
+ Stats.CacheGetCount.Increment(int64_t(RemainingKeys.size()));
+ Stats.CacheGetTotalBytes.Increment(Result.Bytes);
+
+ if (Result.Error)
+ {
+ Stats.CacheErrorCount.Increment(1);
+
+ ZEN_WARN("get cache chunks(s) (rpc) FAILED, endpoint '{}', reason '{}', error code '{}'",
+ Endpoint->GetEndpointInfo().Url,
+ Result.Error.Reason,
+ Result.Error.ErrorCode);
+ }
+
+ RemainingKeys = std::move(Missing);
+ }
+ }
+
+ const UpstreamEndpointInfo Info;
+ for (CacheChunkRequest* RequestPtr : RemainingKeys)
+ {
+ OnComplete({.Request = *RequestPtr, .RawHash = IoHash::Zero, .RawSize = 0, .Value = IoBuffer()});
+ }
+ }
+
+ virtual GetUpstreamCacheSingleResult GetCacheChunk(std::string_view Namespace,
+ const CacheKey& CacheKey,
+ const IoHash& ValueContentId) override
+ {
+ ZEN_TRACE_CPU("Upstream::GetCacheChunk");
+
+ if (m_Options.ReadUpstream)
+ {
+ for (auto& Endpoint : m_Endpoints)
+ {
+ if (Endpoint->GetState() != UpstreamEndpointState::kOk)
+ {
+ continue;
+ }
+
+ UpstreamEndpointStats& Stats = Endpoint->Stats();
+ metrics::OperationTiming::Scope Scope(Stats.CacheGetRequestTiming);
+ GetUpstreamCacheSingleResult Result = Endpoint->GetCacheChunk(Namespace, CacheKey, ValueContentId);
+ Scope.Stop();
+
+ Stats.CacheGetCount.Increment(1);
+ Stats.CacheGetTotalBytes.Increment(Result.Status.Bytes);
+
+ if (Result.Status.Success)
+ {
+ Stats.CacheHitCount.Increment(1);
+
+ return Result;
+ }
+
+ if (Result.Status.Error)
+ {
+ Stats.CacheErrorCount.Increment(1);
+
+ ZEN_WARN("get cache chunk FAILED, endpoint '{}', reason '{}', error code '{}'",
+ Endpoint->GetEndpointInfo().Url,
+ Result.Status.Error.Reason,
+ Result.Status.Error.ErrorCode);
+ }
+ }
+ }
+
+ return {};
+ }
+
+ virtual void GetCacheValues(std::string_view Namespace,
+ std::span<CacheValueRequest*> CacheValueRequests,
+ OnCacheValueGetComplete&& OnComplete) override final
+ {
+ ZEN_TRACE_CPU("Upstream::GetCacheValues");
+
+ std::shared_lock<std::shared_mutex> _(m_EndpointsMutex);
+
+ std::vector<CacheValueRequest*> RemainingKeys(CacheValueRequests.begin(), CacheValueRequests.end());
+
+ if (m_Options.ReadUpstream)
+ {
+ for (auto& Endpoint : m_Endpoints)
+ {
+ if (RemainingKeys.empty())
+ {
+ break;
+ }
+
+ if (Endpoint->GetState() != UpstreamEndpointState::kOk)
+ {
+ continue;
+ }
+
+ UpstreamEndpointStats& Stats = Endpoint->Stats();
+ std::vector<CacheValueRequest*> Missing;
+ GetUpstreamCacheResult Result;
+ {
+ metrics::OperationTiming::Scope Scope(Endpoint->Stats().CacheGetRequestTiming);
+
+ Result = Endpoint->GetCacheValues(Namespace, RemainingKeys, [&](CacheValueGetCompleteParams&& Params) {
+ if (Params.RawHash != Params.RawHash.Zero)
+ {
+ OnComplete(std::forward<CacheValueGetCompleteParams>(Params));
+
+ Stats.CacheHitCount.Increment(1);
+ }
+ else
+ {
+ Missing.push_back(&Params.Request);
+ }
+ });
+ }
+
+ Stats.CacheGetCount.Increment(int64_t(RemainingKeys.size()));
+ Stats.CacheGetTotalBytes.Increment(Result.Bytes);
+
+ if (Result.Error)
+ {
+ Stats.CacheErrorCount.Increment(1);
+
+ ZEN_WARN("get cache values(s) (rpc) FAILED, endpoint '{}', reason '{}', error code '{}'",
+ Endpoint->GetEndpointInfo().Url,
+ Result.Error.Reason,
+ Result.Error.ErrorCode);
+ }
+
+ RemainingKeys = std::move(Missing);
+ }
+ }
+
+ const UpstreamEndpointInfo Info;
+ for (CacheValueRequest* RequestPtr : RemainingKeys)
+ {
+ OnComplete({.Request = *RequestPtr, .RawHash = IoHash::Zero, .RawSize = 0, .Value = IoBuffer()});
+ }
+ }
+
+ virtual void EnqueueUpstream(UpstreamCacheRecord CacheRecord) override
+ {
+ if (m_RunState.IsRunning && m_Options.WriteUpstream && m_Endpoints.size() > 0)
+ {
+ if (!m_UpstreamThreads.empty())
+ {
+ m_UpstreamQueue.Enqueue(std::move(CacheRecord));
+ }
+ else
+ {
+ ProcessCacheRecord(std::move(CacheRecord));
+ }
+ }
+ }
+
+ virtual void GetStatus(CbObjectWriter& Status) override
+ {
+ Status << "reading" << m_Options.ReadUpstream;
+ Status << "writing" << m_Options.WriteUpstream;
+ Status << "worker_threads" << m_Options.ThreadCount;
+ Status << "queue_count" << m_UpstreamQueue.Size();
+
+ Status.BeginArray("endpoints");
+ for (const auto& Ep : m_Endpoints)
+ {
+ const UpstreamEndpointInfo& EpInfo = Ep->GetEndpointInfo();
+ const UpstreamEndpointStatus EpStatus = Ep->GetStatus();
+ UpstreamEndpointStats& EpStats = Ep->Stats();
+
+ Status.BeginObject();
+ Status << "name" << EpInfo.Name;
+ Status << "url" << EpInfo.Url;
+ Status << "state" << ToString(EpStatus.State);
+ Status << "reason" << EpStatus.Reason;
+
+ Status.BeginObject("cache"sv);
+ {
+ const int64_t GetCount = EpStats.CacheGetCount.Value();
+ const int64_t HitCount = EpStats.CacheHitCount.Value();
+ const int64_t ErrorCount = EpStats.CacheErrorCount.Value();
+ const double HitRatio = GetCount > 0 ? double(HitCount) / double(GetCount) : 0.0;
+ const double ErrorRatio = GetCount > 0 ? double(ErrorCount) / double(GetCount) : 0.0;
+
+ metrics::EmitSnapshot("get_requests"sv, EpStats.CacheGetRequestTiming, Status);
+ Status << "get_bytes" << EpStats.CacheGetTotalBytes.Value();
+ Status << "get_count" << GetCount;
+ Status << "hit_count" << HitCount;
+ Status << "hit_ratio" << HitRatio;
+ Status << "error_count" << ErrorCount;
+ Status << "error_ratio" << ErrorRatio;
+ metrics::EmitSnapshot("put_requests"sv, EpStats.CachePutRequestTiming, Status);
+ Status << "put_bytes" << EpStats.CachePutTotalBytes.Value();
+ }
+ Status.EndObject();
+
+ Status.EndObject();
+ }
+ Status.EndArray();
+ }
+
+private:
+ void ProcessCacheRecord(UpstreamCacheRecord CacheRecord)
+ {
+ ZEN_TRACE_CPU("Upstream::ProcessCacheRecord");
+
+ ZenCacheValue CacheValue;
+ std::vector<IoBuffer> Payloads;
+
+ if (!m_CacheStore.Get(CacheRecord.Namespace, CacheRecord.Key.Bucket, CacheRecord.Key.Hash, CacheValue))
+ {
+ ZEN_WARN("process upstream FAILED, '{}/{}/{}', cache record doesn't exist",
+ CacheRecord.Namespace,
+ CacheRecord.Key.Bucket,
+ CacheRecord.Key.Hash);
+ return;
+ }
+
+ for (const IoHash& ValueContentId : CacheRecord.ValueContentIds)
+ {
+ if (IoBuffer Payload = m_CidStore.FindChunkByCid(ValueContentId))
+ {
+ Payloads.push_back(Payload);
+ }
+ else
+ {
+ ZEN_WARN("process upstream FAILED, '{}/{}/{}/{}', ValueContentId doesn't exist in CAS",
+ CacheRecord.Namespace,
+ CacheRecord.Key.Bucket,
+ CacheRecord.Key.Hash,
+ ValueContentId);
+ return;
+ }
+ }
+
+ std::shared_lock<std::shared_mutex> _(m_EndpointsMutex);
+
+ for (auto& Endpoint : m_Endpoints)
+ {
+ if (Endpoint->GetState() != UpstreamEndpointState::kOk)
+ {
+ continue;
+ }
+
+ UpstreamEndpointStats& Stats = Endpoint->Stats();
+ PutUpstreamCacheResult Result;
+ {
+ metrics::OperationTiming::Scope Scope(Stats.CachePutRequestTiming);
+ Result = Endpoint->PutCacheRecord(CacheRecord, CacheValue.Value, std::span(Payloads));
+ }
+
+ Stats.CachePutTotalBytes.Increment(Result.Bytes);
+
+ if (!Result.Success)
+ {
+ ZEN_WARN("upload cache record '{}/{}/{}' FAILED, endpoint '{}', reason '{}'",
+ CacheRecord.Namespace,
+ CacheRecord.Key.Bucket,
+ CacheRecord.Key.Hash,
+ Endpoint->GetEndpointInfo().Url,
+ Result.Reason);
+ }
+ }
+ }
+
+ void ProcessUpstreamQueue()
+ {
+ for (;;)
+ {
+ UpstreamCacheRecord CacheRecord;
+ if (m_UpstreamQueue.WaitAndDequeue(CacheRecord))
+ {
+ try
+ {
+ ProcessCacheRecord(std::move(CacheRecord));
+ }
+ catch (std::exception& Err)
+ {
+ ZEN_ERROR("upload cache record '{}/{}/{}' FAILED, reason '{}'",
+ CacheRecord.Namespace,
+ CacheRecord.Key.Bucket,
+ CacheRecord.Key.Hash,
+ Err.what());
+ }
+ }
+
+ if (!m_RunState.IsRunning)
+ {
+ break;
+ }
+ }
+ }
+
+ void MonitorEndpoints()
+ {
+ for (;;)
+ {
+ {
+ std::unique_lock lk(m_RunState.Mutex);
+ if (m_RunState.ExitSignal.wait_for(lk, m_Options.HealthCheckInterval, [this]() { return !m_RunState.IsRunning.load(); }))
+ {
+ break;
+ }
+ }
+
+ try
+ {
+ std::vector<UpstreamEndpoint*> Endpoints;
+
+ {
+ std::shared_lock<std::shared_mutex> _(m_EndpointsMutex);
+
+ for (auto& Endpoint : m_Endpoints)
+ {
+ UpstreamEndpointState State = Endpoint->GetState();
+ if (State == UpstreamEndpointState::kError)
+ {
+ Endpoints.push_back(Endpoint.get());
+ ZEN_WARN("HEALTH - endpoint '{} - {}' is in error state '{}'",
+ Endpoint->GetEndpointInfo().Name,
+ Endpoint->GetEndpointInfo().Url,
+ Endpoint->GetStatus().Reason);
+ }
+ if (State == UpstreamEndpointState::kUnauthorized)
+ {
+ Endpoints.push_back(Endpoint.get());
+ }
+ }
+ }
+
+ for (auto& Endpoint : Endpoints)
+ {
+ const UpstreamEndpointInfo& Info = Endpoint->GetEndpointInfo();
+ const UpstreamEndpointStatus Status = Endpoint->Initialize();
+
+ if (Status.State == UpstreamEndpointState::kOk)
+ {
+ ZEN_INFO("HEALTH - endpoint '{} - {}' Ok", Info.Name, Info.Url);
+ }
+ else
+ {
+ const std::string Reason = Status.Reason.empty() ? "" : fmt::format(", reason '{}'", Status.Reason);
+ ZEN_WARN("HEALTH - endpoint '{} - {}' {} {}", Info.Name, Info.Url, ToString(Status.State), Reason);
+ }
+ }
+ }
+ catch (std::exception& Err)
+ {
+ ZEN_ERROR("check endpoint(s) health FAILED, reason '{}'", Err.what());
+ }
+ }
+ }
+
+ void Shutdown()
+ {
+ if (m_RunState.Stop())
+ {
+ m_UpstreamQueue.CompleteAdding();
+ for (std::thread& Thread : m_UpstreamThreads)
+ {
+ Thread.join();
+ }
+
+ m_EndpointMonitorThread.join();
+ m_UpstreamThreads.clear();
+ m_Endpoints.clear();
+ }
+ }
+
+ spdlog::logger& Log() { return m_Log; }
+
+ using UpstreamQueue = BlockingQueue<UpstreamCacheRecord>;
+
+ struct RunState
+ {
+ std::mutex Mutex;
+ std::condition_variable ExitSignal;
+ std::atomic_bool IsRunning{false};
+
+ bool Stop()
+ {
+ bool Stopped = false;
+ {
+ std::lock_guard _(Mutex);
+ Stopped = IsRunning.exchange(false);
+ }
+ if (Stopped)
+ {
+ ExitSignal.notify_all();
+ }
+ return Stopped;
+ }
+ };
+
+ spdlog::logger& m_Log;
+ UpstreamCacheOptions m_Options;
+ ZenCacheStore& m_CacheStore;
+ CidStore& m_CidStore;
+ UpstreamQueue m_UpstreamQueue;
+ std::shared_mutex m_EndpointsMutex;
+ std::vector<std::unique_ptr<UpstreamEndpoint>> m_Endpoints;
+ std::vector<std::thread> m_UpstreamThreads;
+ std::thread m_EndpointMonitorThread;
+ RunState m_RunState;
+};
+
+//////////////////////////////////////////////////////////////////////////
+
+std::unique_ptr<UpstreamEndpoint>
+UpstreamEndpoint::CreateZenEndpoint(const ZenStructuredCacheClientOptions& Options)
+{
+ return std::make_unique<detail::ZenUpstreamEndpoint>(Options);
+}
+
+std::unique_ptr<UpstreamEndpoint>
+UpstreamEndpoint::CreateJupiterEndpoint(const CloudCacheClientOptions& Options, const UpstreamAuthConfig& AuthConfig, AuthMgr& Mgr)
+{
+ return std::make_unique<detail::JupiterUpstreamEndpoint>(Options, AuthConfig, Mgr);
+}
+
+std::unique_ptr<UpstreamCache>
+UpstreamCache::Create(const UpstreamCacheOptions& Options, ZenCacheStore& CacheStore, CidStore& CidStore)
+{
+ return std::make_unique<UpstreamCacheImpl>(Options, CacheStore, CidStore);
+}
+
+} // namespace zen