diff options
| author | Stefan Boberg <[email protected]> | 2023-05-02 10:01:47 +0200 |
|---|---|---|
| committer | GitHub <[email protected]> | 2023-05-02 10:01:47 +0200 |
| commit | 075d17f8ada47e990fe94606c3d21df409223465 (patch) | |
| tree | e50549b766a2f3c354798a54ff73404217b4c9af /src/zenserver/upstream/upstreamcache.cpp | |
| parent | fix: bundle shouldn't append content zip to zen (diff) | |
| download | zen-075d17f8ada47e990fe94606c3d21df409223465.tar.xz zen-075d17f8ada47e990fe94606c3d21df409223465.zip | |
moved source directories into `/src` (#264)
* moved source directories into `/src`
* updated bundle.lua for new `src` path
* moved some docs, icon
* removed old test trees
Diffstat (limited to 'src/zenserver/upstream/upstreamcache.cpp')
| -rw-r--r-- | src/zenserver/upstream/upstreamcache.cpp | 2112 |
1 files changed, 2112 insertions, 0 deletions
diff --git a/src/zenserver/upstream/upstreamcache.cpp b/src/zenserver/upstream/upstreamcache.cpp new file mode 100644 index 000000000..e838b5fe2 --- /dev/null +++ b/src/zenserver/upstream/upstreamcache.cpp @@ -0,0 +1,2112 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "upstreamcache.h" +#include "jupiter.h" +#include "zen.h" + +#include <zencore/blockingqueue.h> +#include <zencore/compactbinary.h> +#include <zencore/compactbinarybuilder.h> +#include <zencore/compactbinarypackage.h> +#include <zencore/compactbinaryvalidation.h> +#include <zencore/fmtutils.h> +#include <zencore/stats.h> +#include <zencore/stream.h> +#include <zencore/timer.h> +#include <zencore/trace.h> + +#include <zenhttp/httpshared.h> + +#include <zenstore/cidstore.h> + +#include <auth/authmgr.h> +#include "cache/structuredcache.h" +#include "cache/structuredcachestore.h" +#include "diag/logging.h" + +#include <fmt/format.h> + +#include <algorithm> +#include <atomic> +#include <shared_mutex> +#include <thread> +#include <unordered_map> + +namespace zen { + +using namespace std::literals; + +namespace detail { + + class UpstreamStatus + { + public: + UpstreamEndpointState EndpointState() const { return static_cast<UpstreamEndpointState>(m_State.load(std::memory_order_relaxed)); } + + UpstreamEndpointStatus EndpointStatus() const + { + const UpstreamEndpointState State = EndpointState(); + { + std::unique_lock _(m_Mutex); + return {.Reason = m_ErrorText, .State = State}; + } + } + + void Set(UpstreamEndpointState NewState) + { + m_State.store(static_cast<uint32_t>(NewState), std::memory_order_relaxed); + { + std::unique_lock _(m_Mutex); + m_ErrorText.clear(); + } + } + + void Set(UpstreamEndpointState NewState, std::string ErrorText) + { + m_State.store(static_cast<uint32_t>(NewState), std::memory_order_relaxed); + { + std::unique_lock _(m_Mutex); + m_ErrorText = std::move(ErrorText); + } + } + + void SetFromErrorCode(int32_t ErrorCode, std::string_view ErrorText) + { + if (ErrorCode != 0) + { + Set(ErrorCode == 401 ? UpstreamEndpointState::kUnauthorized : UpstreamEndpointState::kError, std::string(ErrorText)); + } + } + + private: + mutable std::mutex m_Mutex; + std::string m_ErrorText; + std::atomic_uint32_t m_State; + }; + + class JupiterUpstreamEndpoint final : public UpstreamEndpoint + { + public: + JupiterUpstreamEndpoint(const CloudCacheClientOptions& Options, const UpstreamAuthConfig& AuthConfig, AuthMgr& Mgr) + : m_AuthMgr(Mgr) + , m_Log(zen::logging::Get("upstream")) + { + ZEN_ASSERT(!Options.Name.empty()); + m_Info.Name = Options.Name; + m_Info.Url = Options.ServiceUrl; + + std::unique_ptr<CloudCacheTokenProvider> TokenProvider; + + if (AuthConfig.OAuthUrl.empty() == false) + { + TokenProvider = CloudCacheTokenProvider::CreateFromOAuthClientCredentials( + {.Url = AuthConfig.OAuthUrl, .ClientId = AuthConfig.OAuthClientId, .ClientSecret = AuthConfig.OAuthClientSecret}); + } + else if (AuthConfig.OpenIdProvider.empty() == false) + { + TokenProvider = + CloudCacheTokenProvider::CreateFromCallback([this, ProviderName = std::string(AuthConfig.OpenIdProvider)]() { + AuthMgr::OpenIdAccessToken Token = m_AuthMgr.GetOpenIdAccessToken(ProviderName); + return CloudCacheAccessToken{.Value = Token.AccessToken, .ExpireTime = Token.ExpireTime}; + }); + } + else + { + CloudCacheAccessToken AccessToken{.Value = std::string(AuthConfig.AccessToken), + .ExpireTime = CloudCacheAccessToken::TimePoint::max()}; + + TokenProvider = CloudCacheTokenProvider::CreateFromStaticToken(AccessToken); + } + + m_Client = new CloudCacheClient(Options, std::move(TokenProvider)); + } + + virtual ~JupiterUpstreamEndpoint() = default; + + virtual const UpstreamEndpointInfo& GetEndpointInfo() const override { return m_Info; } + + virtual UpstreamEndpointStatus Initialize() override + { + try + { + if (m_Status.EndpointState() == UpstreamEndpointState::kOk) + { + return {.State = UpstreamEndpointState::kOk}; + } + + CloudCacheSession Session(m_Client); + const CloudCacheResult Result = Session.Authenticate(); + + if (Result.Success) + { + m_Status.Set(UpstreamEndpointState::kOk); + } + else if (Result.ErrorCode != 0) + { + m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason); + } + else + { + m_Status.Set(UpstreamEndpointState::kUnauthorized); + } + + return m_Status.EndpointStatus(); + } + catch (std::exception& Err) + { + m_Status.Set(UpstreamEndpointState::kError, Err.what()); + + return {.Reason = Err.what(), .State = GetState()}; + } + } + + std::string_view GetActualDdcNamespace(CloudCacheSession& Session, std::string_view Namespace) + { + if (Namespace == ZenCacheStore::DefaultNamespace) + { + return Session.Client().DefaultDdcNamespace(); + } + return Namespace; + } + + std::string_view GetActualBlobStoreNamespace(CloudCacheSession& Session, std::string_view Namespace) + { + if (Namespace == ZenCacheStore::DefaultNamespace) + { + return Session.Client().DefaultBlobStoreNamespace(); + } + return Namespace; + } + + virtual UpstreamEndpointState GetState() override { return m_Status.EndpointState(); } + + virtual UpstreamEndpointStatus GetStatus() override { return m_Status.EndpointStatus(); } + + virtual GetUpstreamCacheSingleResult GetCacheRecord(std::string_view Namespace, + const CacheKey& CacheKey, + ZenContentType Type) override + { + ZEN_TRACE_CPU("Upstream::Horde::GetSingleCacheRecord"); + + try + { + CloudCacheSession Session(m_Client); + CloudCacheResult Result; + + std::string_view BlobStoreNamespace = GetActualBlobStoreNamespace(Session, Namespace); + + if (Type == ZenContentType::kCompressedBinary) + { + Result = Session.GetRef(BlobStoreNamespace, CacheKey.Bucket, CacheKey.Hash, ZenContentType::kCbObject); + + if (Result.Success) + { + const CbValidateError ValidationResult = ValidateCompactBinary(Result.Response, CbValidateMode::All); + if (Result.Success = ValidationResult == CbValidateError::None; Result.Success) + { + CbObject CacheRecord = LoadCompactBinaryObject(Result.Response); + IoBuffer ContentBuffer; + int NumAttachments = 0; + + CacheRecord.IterateAttachments([&](CbFieldView AttachmentHash) { + CloudCacheResult AttachmentResult = Session.GetCompressedBlob(BlobStoreNamespace, AttachmentHash.AsHash()); + Result.Bytes += AttachmentResult.Bytes; + Result.ElapsedSeconds += AttachmentResult.ElapsedSeconds; + Result.ErrorCode = AttachmentResult.ErrorCode; + + IoHash RawHash; + uint64_t RawSize; + if (CompressedBuffer::ValidateCompressedHeader(AttachmentResult.Response, RawHash, RawSize)) + { + Result.Response = AttachmentResult.Response; + ++NumAttachments; + } + else + { + Result.Success = false; + } + }); + if (NumAttachments != 1) + { + Result.Success = false; + } + } + } + } + else + { + const ZenContentType AcceptType = Type == ZenContentType::kCbPackage ? ZenContentType::kCbObject : Type; + Result = Session.GetRef(BlobStoreNamespace, CacheKey.Bucket, CacheKey.Hash, AcceptType); + + if (Result.Success && Type == ZenContentType::kCbPackage) + { + CbPackage Package; + + const CbValidateError ValidationResult = ValidateCompactBinary(Result.Response, CbValidateMode::All); + if (Result.Success = ValidationResult == CbValidateError::None; Result.Success) + { + CbObject CacheRecord = LoadCompactBinaryObject(Result.Response); + + CacheRecord.IterateAttachments([&](CbFieldView AttachmentHash) { + CloudCacheResult AttachmentResult = Session.GetCompressedBlob(BlobStoreNamespace, AttachmentHash.AsHash()); + Result.Bytes += AttachmentResult.Bytes; + Result.ElapsedSeconds += AttachmentResult.ElapsedSeconds; + Result.ErrorCode = AttachmentResult.ErrorCode; + + IoHash RawHash; + uint64_t RawSize; + if (CompressedBuffer Chunk = + CompressedBuffer::FromCompressed(SharedBuffer(AttachmentResult.Response), RawHash, RawSize)) + { + Package.AddAttachment(CbAttachment(Chunk, AttachmentHash.AsHash())); + } + else + { + Result.Success = false; + } + }); + + Package.SetObject(CacheRecord); + } + + if (Result.Success) + { + BinaryWriter MemStream; + Package.Save(MemStream); + + Result.Response = IoBuffer(IoBuffer::Clone, MemStream.Data(), MemStream.Size()); + } + } + } + + m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason); + + if (Result.ErrorCode == 0) + { + return {.Status = {.Bytes = Result.Bytes, .ElapsedSeconds = Result.ElapsedSeconds, .Success = Result.Success}, + .Value = Result.Response, + .Source = &m_Info}; + } + else + { + return {.Status = {.Error{.ErrorCode = Result.ErrorCode, .Reason = std::move(Result.Reason)}}}; + } + } + catch (std::exception& Err) + { + m_Status.Set(UpstreamEndpointState::kError, Err.what()); + + return {.Status = {.Error{.ErrorCode = -1, .Reason = Err.what()}}}; + } + } + + virtual GetUpstreamCacheResult GetCacheRecords(std::string_view Namespace, + std::span<CacheKeyRequest*> Requests, + OnCacheRecordGetComplete&& OnComplete) override + { + ZEN_TRACE_CPU("Upstream::Horde::GetCacheRecords"); + + CloudCacheSession Session(m_Client); + GetUpstreamCacheResult Result; + + for (CacheKeyRequest* Request : Requests) + { + const CacheKey& CacheKey = Request->Key; + CbPackage Package; + CbObject Record; + + double ElapsedSeconds = 0.0; + if (!Result.Error) + { + std::string_view BlobStoreNamespace = GetActualBlobStoreNamespace(Session, Namespace); + CloudCacheResult RefResult = + Session.GetRef(BlobStoreNamespace, CacheKey.Bucket, CacheKey.Hash, ZenContentType::kCbObject); + AppendResult(RefResult, Result); + ElapsedSeconds = RefResult.ElapsedSeconds; + + m_Status.SetFromErrorCode(RefResult.ErrorCode, RefResult.Reason); + + if (RefResult.ErrorCode == 0) + { + const CbValidateError ValidationResult = ValidateCompactBinary(RefResult.Response, CbValidateMode::All); + if (ValidationResult == CbValidateError::None) + { + Record = LoadCompactBinaryObject(RefResult.Response); + Record.IterateAttachments([&](CbFieldView AttachmentHash) { + CloudCacheResult BlobResult = Session.GetCompressedBlob(BlobStoreNamespace, AttachmentHash.AsHash()); + AppendResult(BlobResult, Result); + + m_Status.SetFromErrorCode(BlobResult.ErrorCode, BlobResult.Reason); + + if (BlobResult.ErrorCode == 0) + { + IoHash RawHash; + uint64_t RawSize; + if (CompressedBuffer Chunk = + CompressedBuffer::FromCompressed(SharedBuffer(BlobResult.Response), RawHash, RawSize)) + { + if (RawHash == AttachmentHash.AsHash()) + { + Package.AddAttachment(CbAttachment(Chunk, RawHash)); + } + } + } + }); + } + } + } + + OnComplete( + {.Request = *Request, .Record = Record, .Package = Package, .ElapsedSeconds = ElapsedSeconds, .Source = &m_Info}); + } + + return Result; + } + + virtual GetUpstreamCacheSingleResult GetCacheChunk(std::string_view Namespace, + const CacheKey&, + const IoHash& ValueContentId) override + { + ZEN_TRACE_CPU("Upstream::Horde::GetSingleCacheChunk"); + + try + { + CloudCacheSession Session(m_Client); + std::string_view BlobStoreNamespace = GetActualBlobStoreNamespace(Session, Namespace); + const CloudCacheResult Result = Session.GetCompressedBlob(BlobStoreNamespace, ValueContentId); + + m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason); + + if (Result.ErrorCode == 0) + { + return {.Status = {.Bytes = Result.Bytes, .ElapsedSeconds = Result.ElapsedSeconds, .Success = Result.Success}, + .Value = Result.Response, + .Source = &m_Info}; + } + else + { + return {.Status = {.Error{.ErrorCode = Result.ErrorCode, .Reason = std::move(Result.Reason)}}}; + } + } + catch (std::exception& Err) + { + m_Status.Set(UpstreamEndpointState::kError, Err.what()); + + return {.Status = {.Error{.ErrorCode = -1, .Reason = Err.what()}}}; + } + } + + virtual GetUpstreamCacheResult GetCacheChunks(std::string_view Namespace, + std::span<CacheChunkRequest*> CacheChunkRequests, + OnCacheChunksGetComplete&& OnComplete) override final + { + ZEN_TRACE_CPU("Upstream::Horde::GetCacheChunks"); + + CloudCacheSession Session(m_Client); + GetUpstreamCacheResult Result; + + for (CacheChunkRequest* RequestPtr : CacheChunkRequests) + { + CacheChunkRequest& Request = *RequestPtr; + IoBuffer Payload; + IoHash RawHash = IoHash::Zero; + uint64_t RawSize = 0; + + double ElapsedSeconds = 0.0; + bool IsCompressed = false; + if (!Result.Error) + { + std::string_view BlobStoreNamespace = GetActualBlobStoreNamespace(Session, Namespace); + const CloudCacheResult BlobResult = + Request.ChunkId == IoHash::Zero + ? Session.GetInlineBlob(BlobStoreNamespace, Request.Key.Bucket, Request.Key.Hash, Request.ChunkId) + : Session.GetCompressedBlob(BlobStoreNamespace, Request.ChunkId); + ElapsedSeconds = BlobResult.ElapsedSeconds; + Payload = BlobResult.Response; + + AppendResult(BlobResult, Result); + + m_Status.SetFromErrorCode(BlobResult.ErrorCode, BlobResult.Reason); + if (Payload && IsCompressedBinary(Payload.GetContentType())) + { + IsCompressed = CompressedBuffer::ValidateCompressedHeader(Payload, RawHash, RawSize); + } + } + + if (IsCompressed) + { + OnComplete({.Request = Request, + .RawHash = RawHash, + .RawSize = RawSize, + .Value = Payload, + .ElapsedSeconds = ElapsedSeconds, + .Source = &m_Info}); + } + else + { + OnComplete({.Request = Request, .RawHash = IoHash::Zero, .RawSize = 0, .Value = IoBuffer()}); + } + } + + return Result; + } + + virtual GetUpstreamCacheResult GetCacheValues(std::string_view Namespace, + std::span<CacheValueRequest*> CacheValueRequests, + OnCacheValueGetComplete&& OnComplete) override final + { + ZEN_TRACE_CPU("Upstream::Horde::GetCacheValues"); + + CloudCacheSession Session(m_Client); + GetUpstreamCacheResult Result; + + for (CacheValueRequest* RequestPtr : CacheValueRequests) + { + CacheValueRequest& Request = *RequestPtr; + IoBuffer Payload; + IoHash RawHash = IoHash::Zero; + uint64_t RawSize = 0; + + double ElapsedSeconds = 0.0; + bool IsCompressed = false; + if (!Result.Error) + { + std::string_view BlobStoreNamespace = GetActualBlobStoreNamespace(Session, Namespace); + IoHash PayloadHash; + const CloudCacheResult BlobResult = + Session.GetInlineBlob(BlobStoreNamespace, Request.Key.Bucket, Request.Key.Hash, PayloadHash); + ElapsedSeconds = BlobResult.ElapsedSeconds; + Payload = BlobResult.Response; + + AppendResult(BlobResult, Result); + + m_Status.SetFromErrorCode(BlobResult.ErrorCode, BlobResult.Reason); + if (Payload) + { + if (IsCompressedBinary(Payload.GetContentType())) + { + IsCompressed = CompressedBuffer::ValidateCompressedHeader(Payload, RawHash, RawSize) && RawHash != PayloadHash; + } + else + { + CompressedBuffer Compressed = CompressedBuffer::Compress(SharedBuffer(Payload)); + RawHash = Compressed.DecodeRawHash(); + if (RawHash == PayloadHash) + { + IsCompressed = true; + } + else + { + ZEN_WARN("Horde request for inline payload of {}/{}/{} has hash {}, expected hash {} from header", + Namespace, + Request.Key.Bucket, + Request.Key.Hash.ToHexString(), + RawHash.ToHexString(), + PayloadHash.ToHexString()); + } + } + } + } + + if (IsCompressed) + { + OnComplete({.Request = Request, + .RawHash = RawHash, + .RawSize = RawSize, + .Value = Payload, + .ElapsedSeconds = ElapsedSeconds, + .Source = &m_Info}); + } + else + { + OnComplete({.Request = Request, .RawHash = IoHash::Zero, .RawSize = 0, .Value = IoBuffer()}); + } + } + + return Result; + } + + virtual PutUpstreamCacheResult PutCacheRecord(const UpstreamCacheRecord& CacheRecord, + IoBuffer RecordValue, + std::span<IoBuffer const> Values) override + { + ZEN_TRACE_CPU("Upstream::Horde::PutCacheRecord"); + + ZEN_ASSERT(CacheRecord.ValueContentIds.size() == Values.size()); + const int32_t MaxAttempts = 3; + + try + { + CloudCacheSession Session(m_Client); + + if (CacheRecord.Type == ZenContentType::kBinary) + { + CloudCacheResult Result; + for (uint32_t Attempt = 0; Attempt < MaxAttempts && !Result.Success; Attempt++) + { + std::string_view BlobStoreNamespace = GetActualBlobStoreNamespace(Session, CacheRecord.Namespace); + Result = Session.PutRef(BlobStoreNamespace, + CacheRecord.Key.Bucket, + CacheRecord.Key.Hash, + RecordValue, + ZenContentType::kBinary); + } + + m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason); + + return {.Reason = std::move(Result.Reason), + .Bytes = Result.Bytes, + .ElapsedSeconds = Result.ElapsedSeconds, + .Success = Result.Success}; + } + else if (CacheRecord.Type == ZenContentType::kCompressedBinary) + { + IoHash RawHash; + uint64_t RawSize; + if (!CompressedBuffer::ValidateCompressedHeader(RecordValue, RawHash, RawSize)) + { + return {.Reason = std::string("Invalid compressed value buffer"), .Success = false}; + } + + CbObjectWriter ReferencingObject; + ReferencingObject.AddBinaryAttachment("RawHash", RawHash); + ReferencingObject.AddInteger("RawSize", RawSize); + + return PerformStructuredPut( + Session, + CacheRecord.Namespace, + CacheRecord.Key, + ReferencingObject.Save().GetBuffer().AsIoBuffer(), + MaxAttempts, + [&](const IoHash& ValueContentId, IoBuffer& OutBuffer, std::string& OutReason) { + if (ValueContentId != RawHash) + { + OutReason = + fmt::format("Value '{}' MISMATCHED from compressed buffer raw hash {}", ValueContentId, RawHash); + return false; + } + + OutBuffer = RecordValue; + return true; + }); + } + else + { + return PerformStructuredPut( + Session, + CacheRecord.Namespace, + CacheRecord.Key, + RecordValue, + MaxAttempts, + [&](const IoHash& ValueContentId, IoBuffer& OutBuffer, std::string& OutReason) { + const auto It = + std::find(std::begin(CacheRecord.ValueContentIds), std::end(CacheRecord.ValueContentIds), ValueContentId); + + if (It == std::end(CacheRecord.ValueContentIds)) + { + OutReason = fmt::format("value '{}' MISSING from local cache", ValueContentId); + return false; + } + + const size_t Idx = std::distance(std::begin(CacheRecord.ValueContentIds), It); + + OutBuffer = Values[Idx]; + return true; + }); + } + } + catch (std::exception& Err) + { + m_Status.Set(UpstreamEndpointState::kError, Err.what()); + + return {.Reason = std::string(Err.what()), .Success = false}; + } + } + + virtual UpstreamEndpointStats& Stats() override { return m_Stats; } + + private: + static void AppendResult(const CloudCacheResult& Result, GetUpstreamCacheResult& Out) + { + Out.Success &= Result.Success; + Out.Bytes += Result.Bytes; + Out.ElapsedSeconds += Result.ElapsedSeconds; + + if (Result.ErrorCode) + { + Out.Error = {.ErrorCode = Result.ErrorCode, .Reason = std::move(Result.Reason)}; + } + }; + + PutUpstreamCacheResult PerformStructuredPut( + CloudCacheSession& Session, + std::string_view Namespace, + const CacheKey& Key, + IoBuffer ObjectBuffer, + const int32_t MaxAttempts, + std::function<bool(const IoHash& ValueContentId, IoBuffer& OutBuffer, std::string& OutReason)>&& BlobFetchFn) + { + int64_t TotalBytes = 0ull; + double TotalElapsedSeconds = 0.0; + + std::string_view BlobStoreNamespace = GetActualBlobStoreNamespace(Session, Namespace); + const auto PutBlobs = [&](std::span<IoHash> ValueContentIds, std::string& OutReason) -> bool { + for (const IoHash& ValueContentId : ValueContentIds) + { + IoBuffer BlobBuffer; + if (!BlobFetchFn(ValueContentId, BlobBuffer, OutReason)) + { + return false; + } + + CloudCacheResult BlobResult; + for (int32_t Attempt = 0; Attempt < MaxAttempts && !BlobResult.Success; Attempt++) + { + BlobResult = Session.PutCompressedBlob(BlobStoreNamespace, ValueContentId, BlobBuffer); + } + + m_Status.SetFromErrorCode(BlobResult.ErrorCode, BlobResult.Reason); + + if (!BlobResult.Success) + { + OutReason = fmt::format("upload value '{}' FAILED, reason '{}'", ValueContentId, BlobResult.Reason); + return false; + } + + TotalBytes += BlobResult.Bytes; + TotalElapsedSeconds += BlobResult.ElapsedSeconds; + } + + return true; + }; + + PutRefResult RefResult; + for (int32_t Attempt = 0; Attempt < MaxAttempts && !RefResult.Success; Attempt++) + { + RefResult = Session.PutRef(BlobStoreNamespace, Key.Bucket, Key.Hash, ObjectBuffer, ZenContentType::kCbObject); + } + + m_Status.SetFromErrorCode(RefResult.ErrorCode, RefResult.Reason); + + if (!RefResult.Success) + { + return {.Reason = fmt::format("upload cache record '{}/{}' FAILED, reason '{}'", Key.Bucket, Key.Hash, RefResult.Reason), + .Success = false}; + } + + TotalBytes += RefResult.Bytes; + TotalElapsedSeconds += RefResult.ElapsedSeconds; + + std::string Reason; + if (!PutBlobs(RefResult.Needs, Reason)) + { + return {.Reason = std::move(Reason), .Success = false}; + } + + const IoHash RefHash = IoHash::HashBuffer(ObjectBuffer); + FinalizeRefResult FinalizeResult = Session.FinalizeRef(BlobStoreNamespace, Key.Bucket, Key.Hash, RefHash); + + m_Status.SetFromErrorCode(FinalizeResult.ErrorCode, FinalizeResult.Reason); + + if (!FinalizeResult.Success) + { + return { + .Reason = fmt::format("finalize cache record '{}/{}' FAILED, reason '{}'", Key.Bucket, Key.Hash, FinalizeResult.Reason), + .Success = false}; + } + + if (!FinalizeResult.Needs.empty()) + { + if (!PutBlobs(FinalizeResult.Needs, Reason)) + { + return {.Reason = std::move(Reason), .Success = false}; + } + + FinalizeResult = Session.FinalizeRef(BlobStoreNamespace, Key.Bucket, Key.Hash, RefHash); + + m_Status.SetFromErrorCode(FinalizeResult.ErrorCode, FinalizeResult.Reason); + + if (!FinalizeResult.Success) + { + return {.Reason = fmt::format("finalize '{}/{}' FAILED, reason '{}'", Key.Bucket, Key.Hash, FinalizeResult.Reason), + .Success = false}; + } + + if (!FinalizeResult.Needs.empty()) + { + ExtendableStringBuilder<256> Sb; + for (const IoHash& MissingHash : FinalizeResult.Needs) + { + Sb << MissingHash.ToHexString() << ","; + } + + return { + .Reason = fmt::format("finalize '{}/{}' FAILED, still needs value(s) '{}'", Key.Bucket, Key.Hash, Sb.ToString()), + .Success = false}; + } + } + + TotalBytes += FinalizeResult.Bytes; + TotalElapsedSeconds += FinalizeResult.ElapsedSeconds; + + return {.Bytes = TotalBytes, .ElapsedSeconds = TotalElapsedSeconds, .Success = true}; + } + + spdlog::logger& Log() { return m_Log; } + + AuthMgr& m_AuthMgr; + spdlog::logger& m_Log; + UpstreamEndpointInfo m_Info; + UpstreamStatus m_Status; + UpstreamEndpointStats m_Stats; + RefPtr<CloudCacheClient> m_Client; + }; + + class ZenUpstreamEndpoint final : public UpstreamEndpoint + { + struct ZenEndpoint + { + std::string Url; + std::string Reason; + double Latency{}; + bool Ok = false; + + bool operator<(const ZenEndpoint& RHS) const { return Ok && RHS.Ok ? Latency < RHS.Latency : Ok; } + }; + + public: + ZenUpstreamEndpoint(const ZenStructuredCacheClientOptions& Options) + : m_Log(zen::logging::Get("upstream")) + , m_ConnectTimeout(Options.ConnectTimeout) + , m_Timeout(Options.Timeout) + { + ZEN_ASSERT(!Options.Name.empty()); + m_Info.Name = Options.Name; + + for (const auto& Url : Options.Urls) + { + m_Endpoints.push_back({.Url = Url}); + } + } + + ~ZenUpstreamEndpoint() = default; + + virtual const UpstreamEndpointInfo& GetEndpointInfo() const override { return m_Info; } + + virtual UpstreamEndpointStatus Initialize() override + { + try + { + if (m_Status.EndpointState() == UpstreamEndpointState::kOk) + { + return {.State = UpstreamEndpointState::kOk}; + } + + const ZenEndpoint& Ep = GetEndpoint(); + + if (m_Info.Url != Ep.Url) + { + ZEN_INFO("Setting Zen upstream URL to '{}'", Ep.Url); + m_Info.Url = Ep.Url; + } + + if (Ep.Ok) + { + RwLock::ExclusiveLockScope _(m_ClientLock); + m_Client = new ZenStructuredCacheClient({.Url = m_Info.Url, .ConnectTimeout = m_ConnectTimeout, .Timeout = m_Timeout}); + m_Status.Set(UpstreamEndpointState::kOk); + } + else + { + m_Status.Set(UpstreamEndpointState::kError, Ep.Reason); + } + + return m_Status.EndpointStatus(); + } + catch (std::exception& Err) + { + m_Status.Set(UpstreamEndpointState::kError, Err.what()); + + return {.Reason = Err.what(), .State = GetState()}; + } + } + + virtual UpstreamEndpointState GetState() override { return m_Status.EndpointState(); } + + virtual UpstreamEndpointStatus GetStatus() override { return m_Status.EndpointStatus(); } + + virtual GetUpstreamCacheSingleResult GetCacheRecord(std::string_view Namespace, + const CacheKey& CacheKey, + ZenContentType Type) override + { + ZEN_TRACE_CPU("Upstream::Zen::GetSingleCacheRecord"); + + try + { + ZenStructuredCacheSession Session(GetClientRef()); + const ZenCacheResult Result = Session.GetCacheRecord(Namespace, CacheKey.Bucket, CacheKey.Hash, Type); + + m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason); + + if (Result.ErrorCode == 0) + { + return {.Status = {.Bytes = Result.Bytes, .ElapsedSeconds = Result.ElapsedSeconds, .Success = Result.Success}, + .Value = Result.Response, + .Source = &m_Info}; + } + else + { + return {.Status = {.Error{.ErrorCode = Result.ErrorCode, .Reason = std::move(Result.Reason)}}}; + } + } + catch (std::exception& Err) + { + m_Status.Set(UpstreamEndpointState::kError, Err.what()); + + return {.Status = {.Error{.ErrorCode = -1, .Reason = Err.what()}}}; + } + } + + virtual GetUpstreamCacheResult GetCacheRecords(std::string_view Namespace, + std::span<CacheKeyRequest*> Requests, + OnCacheRecordGetComplete&& OnComplete) override + { + ZEN_TRACE_CPU("Upstream::Zen::GetCacheRecords"); + ZEN_ASSERT(Requests.size() > 0); + + CbObjectWriter BatchRequest; + BatchRequest << "Method"sv + << "GetCacheRecords"sv; + BatchRequest << "Accept"sv << kCbPkgMagic; + + BatchRequest.BeginObject("Params"sv); + { + CachePolicy DefaultPolicy = Requests[0]->Policy.GetRecordPolicy(); + BatchRequest << "DefaultPolicy"sv << WriteToString<128>(DefaultPolicy); + + BatchRequest << "Namespace"sv << Namespace; + + BatchRequest.BeginArray("Requests"sv); + for (CacheKeyRequest* Request : Requests) + { + BatchRequest.BeginObject(); + { + const CacheKey& Key = Request->Key; + BatchRequest.BeginObject("Key"sv); + { + BatchRequest << "Bucket"sv << Key.Bucket; + BatchRequest << "Hash"sv << Key.Hash; + } + BatchRequest.EndObject(); + if (!Request->Policy.IsUniform() || Request->Policy.GetRecordPolicy() != DefaultPolicy) + { + BatchRequest.SetName("Policy"sv); + Request->Policy.Save(BatchRequest); + } + } + BatchRequest.EndObject(); + } + BatchRequest.EndArray(); + } + BatchRequest.EndObject(); + + ZenCacheResult Result; + + { + ZenStructuredCacheSession Session(GetClientRef()); + Result = Session.InvokeRpc(BatchRequest.Save()); + } + + m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason); + + if (Result.Success) + { + CbPackage BatchResponse; + if (ParsePackageMessageWithLegacyFallback(Result.Response, BatchResponse)) + { + CbArrayView Results = BatchResponse.GetObject()["Result"sv].AsArrayView(); + if (Results.Num() != Requests.size()) + { + ZEN_WARN("Upstream::Zen::GetCacheRecords invalid number of Response results from Upstream."); + } + else + { + for (size_t Index = 0; CbFieldView Record : Results) + { + CacheKeyRequest* Request = Requests[Index++]; + OnComplete({.Request = *Request, + .Record = Record.AsObjectView(), + .Package = BatchResponse, + .ElapsedSeconds = Result.ElapsedSeconds, + .Source = &m_Info}); + } + + return {.Bytes = Result.Bytes, .ElapsedSeconds = Result.ElapsedSeconds, .Success = true}; + } + } + else + { + ZEN_WARN("Upstream::Zen::GetCacheRecords invalid Response from Upstream."); + } + } + + for (CacheKeyRequest* Request : Requests) + { + OnComplete({.Request = *Request, .Record = CbObjectView(), .Package = CbPackage()}); + } + + return {.Error{.ErrorCode = Result.ErrorCode, .Reason = std::move(Result.Reason)}}; + } + + virtual GetUpstreamCacheSingleResult GetCacheChunk(std::string_view Namespace, + const CacheKey& CacheKey, + const IoHash& ValueContentId) override + { + ZEN_TRACE_CPU("Upstream::Zen::GetCacheChunk"); + + try + { + ZenStructuredCacheSession Session(GetClientRef()); + const ZenCacheResult Result = Session.GetCacheChunk(Namespace, CacheKey.Bucket, CacheKey.Hash, ValueContentId); + + m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason); + + if (Result.ErrorCode == 0) + { + return {.Status = {.Bytes = Result.Bytes, .ElapsedSeconds = Result.ElapsedSeconds, .Success = Result.Success}, + .Value = Result.Response, + .Source = &m_Info}; + } + else + { + return {.Status = {.Error{.ErrorCode = Result.ErrorCode, .Reason = std::move(Result.Reason)}}}; + } + } + catch (std::exception& Err) + { + m_Status.Set(UpstreamEndpointState::kError, Err.what()); + + return {.Status = {.Error{.ErrorCode = -1, .Reason = Err.what()}}}; + } + } + + virtual GetUpstreamCacheResult GetCacheValues(std::string_view Namespace, + std::span<CacheValueRequest*> CacheValueRequests, + OnCacheValueGetComplete&& OnComplete) override final + { + ZEN_TRACE_CPU("Upstream::Zen::GetCacheValues"); + ZEN_ASSERT(!CacheValueRequests.empty()); + + CbObjectWriter BatchRequest; + BatchRequest << "Method"sv + << "GetCacheValues"sv; + BatchRequest << "Accept"sv << kCbPkgMagic; + + BatchRequest.BeginObject("Params"sv); + { + CachePolicy DefaultPolicy = CacheValueRequests[0]->Policy; + BatchRequest << "DefaultPolicy"sv << WriteToString<128>(DefaultPolicy).ToView(); + BatchRequest << "Namespace"sv << Namespace; + + BatchRequest.BeginArray("Requests"sv); + { + for (CacheValueRequest* RequestPtr : CacheValueRequests) + { + const CacheValueRequest& Request = *RequestPtr; + + BatchRequest.BeginObject(); + { + BatchRequest.BeginObject("Key"sv); + BatchRequest << "Bucket"sv << Request.Key.Bucket; + BatchRequest << "Hash"sv << Request.Key.Hash; + BatchRequest.EndObject(); + if (Request.Policy != DefaultPolicy) + { + BatchRequest << "Policy"sv << WriteToString<128>(Request.Policy).ToView(); + } + } + BatchRequest.EndObject(); + } + } + BatchRequest.EndArray(); + } + BatchRequest.EndObject(); + + ZenCacheResult Result; + + { + ZenStructuredCacheSession Session(GetClientRef()); + Result = Session.InvokeRpc(BatchRequest.Save()); + } + + m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason); + + if (Result.Success) + { + CbPackage BatchResponse; + if (ParsePackageMessageWithLegacyFallback(Result.Response, BatchResponse)) + { + CbArrayView Results = BatchResponse.GetObject()["Result"sv].AsArrayView(); + if (CacheValueRequests.size() != Results.Num()) + { + ZEN_WARN("Upstream::Zen::GetCacheValues invalid number of Response results from Upstream."); + } + else + { + for (size_t RequestIndex = 0; CbFieldView ChunkField : Results) + { + CacheValueRequest& Request = *CacheValueRequests[RequestIndex++]; + CbObjectView ChunkObject = ChunkField.AsObjectView(); + IoHash RawHash = ChunkObject["RawHash"sv].AsHash(); + IoBuffer Payload; + uint64_t RawSize = 0; + if (RawHash != IoHash::Zero) + { + bool Success = false; + const CbAttachment* Attachment = BatchResponse.FindAttachment(RawHash); + if (Attachment) + { + if (const CompressedBuffer& Compressed = Attachment->AsCompressedBinary()) + { + Payload = Compressed.GetCompressed().Flatten().AsIoBuffer(); + Payload.SetContentType(ZenContentType::kCompressedBinary); + RawSize = Compressed.DecodeRawSize(); + Success = true; + } + } + if (!Success) + { + CbFieldView RawSizeField = ChunkObject["RawSize"sv]; + RawSize = RawSizeField.AsUInt64(); + Success = !RawSizeField.HasError(); + } + if (!Success) + { + RawHash = IoHash::Zero; + } + } + OnComplete({.Request = Request, + .RawHash = RawHash, + .RawSize = RawSize, + .Value = std::move(Payload), + .ElapsedSeconds = Result.ElapsedSeconds, + .Source = &m_Info}); + } + + return {.Bytes = Result.Bytes, .ElapsedSeconds = Result.ElapsedSeconds, .Success = true}; + } + } + else + { + ZEN_WARN("Upstream::Zen::GetCacheValues invalid Response from Upstream."); + } + } + + for (CacheValueRequest* RequestPtr : CacheValueRequests) + { + OnComplete({.Request = *RequestPtr, .RawHash = IoHash::Zero, .RawSize = 0, .Value = IoBuffer()}); + } + + return {.Error{.ErrorCode = Result.ErrorCode, .Reason = std::move(Result.Reason)}}; + } + + virtual GetUpstreamCacheResult GetCacheChunks(std::string_view Namespace, + std::span<CacheChunkRequest*> CacheChunkRequests, + OnCacheChunksGetComplete&& OnComplete) override final + { + ZEN_TRACE_CPU("Upstream::Zen::GetCacheChunks"); + ZEN_ASSERT(!CacheChunkRequests.empty()); + + CbObjectWriter BatchRequest; + BatchRequest << "Method"sv + << "GetCacheChunks"sv; + BatchRequest << "Accept"sv << kCbPkgMagic; + + BatchRequest.BeginObject("Params"sv); + { + CachePolicy DefaultPolicy = CacheChunkRequests[0]->Policy; + BatchRequest << "DefaultPolicy"sv << WriteToString<128>(DefaultPolicy).ToView(); + BatchRequest << "Namespace"sv << Namespace; + + BatchRequest.BeginArray("ChunkRequests"sv); + { + for (CacheChunkRequest* RequestPtr : CacheChunkRequests) + { + const CacheChunkRequest& Request = *RequestPtr; + + BatchRequest.BeginObject(); + { + BatchRequest.BeginObject("Key"sv); + BatchRequest << "Bucket"sv << Request.Key.Bucket; + BatchRequest << "Hash"sv << Request.Key.Hash; + BatchRequest.EndObject(); + if (Request.ValueId) + { + BatchRequest.AddObjectId("ValueId"sv, Request.ValueId); + } + if (Request.ChunkId != Request.ChunkId.Zero) + { + BatchRequest << "ChunkId"sv << Request.ChunkId; + } + if (Request.RawOffset != 0) + { + BatchRequest << "RawOffset"sv << Request.RawOffset; + } + if (Request.RawSize != UINT64_MAX) + { + BatchRequest << "RawSize"sv << Request.RawSize; + } + if (Request.Policy != DefaultPolicy) + { + BatchRequest << "Policy"sv << WriteToString<128>(Request.Policy).ToView(); + } + } + BatchRequest.EndObject(); + } + } + BatchRequest.EndArray(); + } + BatchRequest.EndObject(); + + ZenCacheResult Result; + + { + ZenStructuredCacheSession Session(GetClientRef()); + Result = Session.InvokeRpc(BatchRequest.Save()); + } + + m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason); + + if (Result.Success) + { + CbPackage BatchResponse; + if (ParsePackageMessageWithLegacyFallback(Result.Response, BatchResponse)) + { + CbArrayView Results = BatchResponse.GetObject()["Result"sv].AsArrayView(); + if (CacheChunkRequests.size() != Results.Num()) + { + ZEN_WARN("Upstream::Zen::GetCacheChunks invalid number of Response results from Upstream."); + } + else + { + for (size_t RequestIndex = 0; CbFieldView ChunkField : Results) + { + CacheChunkRequest& Request = *CacheChunkRequests[RequestIndex++]; + CbObjectView ChunkObject = ChunkField.AsObjectView(); + IoHash RawHash = ChunkObject["RawHash"sv].AsHash(); + IoBuffer Payload; + uint64_t RawSize = 0; + if (RawHash != IoHash::Zero) + { + bool Success = false; + const CbAttachment* Attachment = BatchResponse.FindAttachment(RawHash); + if (Attachment) + { + if (const CompressedBuffer& Compressed = Attachment->AsCompressedBinary()) + { + Payload = Compressed.GetCompressed().Flatten().AsIoBuffer(); + Payload.SetContentType(ZenContentType::kCompressedBinary); + RawSize = Compressed.DecodeRawSize(); + Success = true; + } + } + if (!Success) + { + CbFieldView RawSizeField = ChunkObject["RawSize"sv]; + RawSize = RawSizeField.AsUInt64(); + Success = !RawSizeField.HasError(); + } + if (!Success) + { + RawHash = IoHash::Zero; + } + } + OnComplete({.Request = Request, + .RawHash = RawHash, + .RawSize = RawSize, + .Value = std::move(Payload), + .ElapsedSeconds = Result.ElapsedSeconds, + .Source = &m_Info}); + } + + return {.Bytes = Result.Bytes, .ElapsedSeconds = Result.ElapsedSeconds, .Success = true}; + } + } + else + { + ZEN_WARN("Upstream::Zen::GetCacheChunks invalid Response from Upstream."); + } + } + + for (CacheChunkRequest* RequestPtr : CacheChunkRequests) + { + OnComplete({.Request = *RequestPtr, .RawHash = IoHash::Zero, .RawSize = 0, .Value = IoBuffer()}); + } + + return {.Error{.ErrorCode = Result.ErrorCode, .Reason = std::move(Result.Reason)}}; + } + + virtual PutUpstreamCacheResult PutCacheRecord(const UpstreamCacheRecord& CacheRecord, + IoBuffer RecordValue, + std::span<IoBuffer const> Values) override + { + ZEN_TRACE_CPU("Upstream::Zen::PutCacheRecord"); + + ZEN_ASSERT(CacheRecord.ValueContentIds.size() == Values.size()); + const int32_t MaxAttempts = 3; + + try + { + ZenStructuredCacheSession Session(GetClientRef()); + ZenCacheResult Result; + int64_t TotalBytes = 0ull; + double TotalElapsedSeconds = 0.0; + + if (CacheRecord.Type == ZenContentType::kCbPackage) + { + CbPackage Package; + Package.SetObject(CbObject(SharedBuffer(RecordValue))); + + for (const IoBuffer& Value : Values) + { + IoHash RawHash; + uint64_t RawSize; + if (CompressedBuffer AttachmentBuffer = CompressedBuffer::FromCompressed(SharedBuffer(Value), RawHash, RawSize)) + { + Package.AddAttachment(CbAttachment(AttachmentBuffer, RawHash)); + } + else + { + return {.Reason = std::string("Invalid value buffer"), .Success = false}; + } + } + + BinaryWriter MemStream; + Package.Save(MemStream); + IoBuffer PackagePayload(IoBuffer::Wrap, MemStream.Data(), MemStream.Size()); + + for (uint32_t Attempt = 0; Attempt < MaxAttempts && !Result.Success; Attempt++) + { + Result = Session.PutCacheRecord(CacheRecord.Namespace, + CacheRecord.Key.Bucket, + CacheRecord.Key.Hash, + PackagePayload, + CacheRecord.Type); + } + + m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason); + + TotalBytes = Result.Bytes; + TotalElapsedSeconds = Result.ElapsedSeconds; + } + else if (CacheRecord.Type == ZenContentType::kCompressedBinary) + { + IoHash RawHash; + uint64_t RawSize; + CompressedBuffer Compressed = CompressedBuffer::FromCompressed(SharedBuffer(RecordValue), RawHash, RawSize); + if (!Compressed) + { + return {.Reason = std::string("Invalid value compressed buffer"), .Success = false}; + } + + CbPackage BatchPackage; + CbObjectWriter BatchWriter; + BatchWriter << "Method"sv + << "PutCacheValues"sv; + BatchWriter << "Accept"sv << kCbPkgMagic; + + BatchWriter.BeginObject("Params"sv); + { + // DefaultPolicy unspecified and expected to be Default + + BatchWriter << "Namespace"sv << CacheRecord.Namespace; + + BatchWriter.BeginArray("Requests"sv); + { + BatchWriter.BeginObject(); + { + const CacheKey& Key = CacheRecord.Key; + BatchWriter.BeginObject("Key"sv); + { + BatchWriter << "Bucket"sv << Key.Bucket; + BatchWriter << "Hash"sv << Key.Hash; + } + BatchWriter.EndObject(); + // Policy unspecified and expected to be Default + BatchWriter.AddBinaryAttachment("RawHash"sv, RawHash); + BatchPackage.AddAttachment(CbAttachment(Compressed, RawHash)); + } + BatchWriter.EndObject(); + } + BatchWriter.EndArray(); + } + BatchWriter.EndObject(); + BatchPackage.SetObject(BatchWriter.Save()); + + Result.Success = false; + for (uint32_t Attempt = 0; Attempt < MaxAttempts && !Result.Success; Attempt++) + { + Result = Session.InvokeRpc(BatchPackage); + } + + m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason); + + TotalBytes += Result.Bytes; + TotalElapsedSeconds += Result.ElapsedSeconds; + } + else + { + for (size_t Idx = 0, Count = Values.size(); Idx < Count; Idx++) + { + Result.Success = false; + for (uint32_t Attempt = 0; Attempt < MaxAttempts && !Result.Success; Attempt++) + { + Result = Session.PutCacheValue(CacheRecord.Namespace, + CacheRecord.Key.Bucket, + CacheRecord.Key.Hash, + CacheRecord.ValueContentIds[Idx], + Values[Idx]); + } + + m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason); + + TotalBytes += Result.Bytes; + TotalElapsedSeconds += Result.ElapsedSeconds; + + if (!Result.Success) + { + return {.Reason = "Failed to upload value", + .Bytes = TotalBytes, + .ElapsedSeconds = TotalElapsedSeconds, + .Success = false}; + } + } + + Result.Success = false; + for (uint32_t Attempt = 0; Attempt < MaxAttempts && !Result.Success; Attempt++) + { + Result = Session.PutCacheRecord(CacheRecord.Namespace, + CacheRecord.Key.Bucket, + CacheRecord.Key.Hash, + RecordValue, + CacheRecord.Type); + } + + m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason); + + TotalBytes += Result.Bytes; + TotalElapsedSeconds += Result.ElapsedSeconds; + } + + return {.Reason = std::move(Result.Reason), + .Bytes = TotalBytes, + .ElapsedSeconds = TotalElapsedSeconds, + .Success = Result.Success}; + } + catch (std::exception& Err) + { + m_Status.Set(UpstreamEndpointState::kError, Err.what()); + + return {.Reason = std::string(Err.what()), .Success = false}; + } + } + + virtual UpstreamEndpointStats& Stats() override { return m_Stats; } + + private: + Ref<ZenStructuredCacheClient> GetClientRef() + { + // m_Client can be modified at any time by a different thread. + // Make sure we safely bump the refcount inside a scope lock + RwLock::SharedLockScope _(m_ClientLock); + ZEN_ASSERT(m_Client); + Ref<ZenStructuredCacheClient> ClientRef(m_Client); + _.ReleaseNow(); + return ClientRef; + } + + const ZenEndpoint& GetEndpoint() + { + for (ZenEndpoint& Ep : m_Endpoints) + { + Ref<ZenStructuredCacheClient> Client( + new ZenStructuredCacheClient({.Url = Ep.Url, .ConnectTimeout = std::chrono::milliseconds(1000)})); + ZenStructuredCacheSession Session(std::move(Client)); + const int32_t SampleCount = 2; + + Ep.Ok = false; + Ep.Latency = {}; + + for (int32_t Sample = 0; Sample < SampleCount; ++Sample) + { + ZenCacheResult Result = Session.CheckHealth(); + Ep.Ok = Result.Success; + Ep.Reason = std::move(Result.Reason); + Ep.Latency += Result.ElapsedSeconds; + } + Ep.Latency /= double(SampleCount); + } + + std::sort(std::begin(m_Endpoints), std::end(m_Endpoints)); + + for (const auto& Ep : m_Endpoints) + { + ZEN_INFO("ping 'Zen' endpoint '{}' latency '{:.3}s' {}", Ep.Url, Ep.Latency, Ep.Ok ? "OK" : Ep.Reason); + } + + return m_Endpoints.front(); + } + + spdlog::logger& Log() { return m_Log; } + + spdlog::logger& m_Log; + UpstreamEndpointInfo m_Info; + UpstreamStatus m_Status; + UpstreamEndpointStats m_Stats; + std::vector<ZenEndpoint> m_Endpoints; + std::chrono::milliseconds m_ConnectTimeout; + std::chrono::milliseconds m_Timeout; + RwLock m_ClientLock; + RefPtr<ZenStructuredCacheClient> m_Client; + }; + +} // namespace detail + +////////////////////////////////////////////////////////////////////////// + +class UpstreamCacheImpl final : public UpstreamCache +{ +public: + UpstreamCacheImpl(const UpstreamCacheOptions& Options, ZenCacheStore& CacheStore, CidStore& CidStore) + : m_Log(logging::Get("upstream")) + , m_Options(Options) + , m_CacheStore(CacheStore) + , m_CidStore(CidStore) + { + } + + virtual ~UpstreamCacheImpl() { Shutdown(); } + + virtual void Initialize() override + { + for (uint32_t Idx = 0; Idx < m_Options.ThreadCount; Idx++) + { + m_UpstreamThreads.emplace_back(&UpstreamCacheImpl::ProcessUpstreamQueue, this); + } + + m_EndpointMonitorThread = std::thread(&UpstreamCacheImpl::MonitorEndpoints, this); + m_RunState.IsRunning = true; + } + + virtual void RegisterEndpoint(std::unique_ptr<UpstreamEndpoint> Endpoint) override + { + const UpstreamEndpointStatus Status = Endpoint->Initialize(); + const UpstreamEndpointInfo& Info = Endpoint->GetEndpointInfo(); + + if (Status.State == UpstreamEndpointState::kOk) + { + ZEN_INFO("register endpoint '{} - {}' {}", Info.Name, Info.Url, ToString(Status.State)); + } + else + { + ZEN_WARN("register endpoint '{} - {}' {}", Info.Name, Info.Url, ToString(Status.State)); + } + + // Register endpoint even if it fails, the health monitor thread will probe failing endpoint(s) + std::unique_lock<std::shared_mutex> _(m_EndpointsMutex); + m_Endpoints.emplace_back(std::move(Endpoint)); + } + + virtual void IterateEndpoints(std::function<bool(UpstreamEndpoint&)>&& Fn) override + { + std::shared_lock<std::shared_mutex> _(m_EndpointsMutex); + + for (auto& Ep : m_Endpoints) + { + if (!Fn(*Ep)) + { + break; + } + } + } + + virtual GetUpstreamCacheSingleResult GetCacheRecord(std::string_view Namespace, const CacheKey& CacheKey, ZenContentType Type) override + { + ZEN_TRACE_CPU("Upstream::GetCacheRecord"); + + std::shared_lock<std::shared_mutex> _(m_EndpointsMutex); + + if (m_Options.ReadUpstream) + { + for (auto& Endpoint : m_Endpoints) + { + if (Endpoint->GetState() != UpstreamEndpointState::kOk) + { + continue; + } + + UpstreamEndpointStats& Stats = Endpoint->Stats(); + metrics::OperationTiming::Scope Scope(Stats.CacheGetRequestTiming); + GetUpstreamCacheSingleResult Result = Endpoint->GetCacheRecord(Namespace, CacheKey, Type); + Scope.Stop(); + + Stats.CacheGetCount.Increment(1); + Stats.CacheGetTotalBytes.Increment(Result.Status.Bytes); + + if (Result.Status.Success) + { + Stats.CacheHitCount.Increment(1); + + return Result; + } + + if (Result.Status.Error) + { + Stats.CacheErrorCount.Increment(1); + + ZEN_WARN("get cache record FAILED, endpoint '{}', reason '{}', error code '{}'", + Endpoint->GetEndpointInfo().Url, + Result.Status.Error.Reason, + Result.Status.Error.ErrorCode); + } + } + } + + return {}; + } + + virtual void GetCacheRecords(std::string_view Namespace, + std::span<CacheKeyRequest*> Requests, + OnCacheRecordGetComplete&& OnComplete) override final + { + ZEN_TRACE_CPU("Upstream::GetCacheRecords"); + + std::shared_lock<std::shared_mutex> _(m_EndpointsMutex); + + std::vector<CacheKeyRequest*> RemainingKeys(Requests.begin(), Requests.end()); + + if (m_Options.ReadUpstream) + { + for (auto& Endpoint : m_Endpoints) + { + if (RemainingKeys.empty()) + { + break; + } + + if (Endpoint->GetState() != UpstreamEndpointState::kOk) + { + continue; + } + + UpstreamEndpointStats& Stats = Endpoint->Stats(); + std::vector<CacheKeyRequest*> Missing; + GetUpstreamCacheResult Result; + { + metrics::OperationTiming::Scope Scope(Stats.CacheGetRequestTiming); + + Result = Endpoint->GetCacheRecords(Namespace, RemainingKeys, [&](CacheRecordGetCompleteParams&& Params) { + if (Params.Record) + { + OnComplete(std::forward<CacheRecordGetCompleteParams>(Params)); + + Stats.CacheHitCount.Increment(1); + } + else + { + Missing.push_back(&Params.Request); + } + }); + } + + Stats.CacheGetCount.Increment(int64_t(RemainingKeys.size())); + Stats.CacheGetTotalBytes.Increment(Result.Bytes); + + if (Result.Error) + { + Stats.CacheErrorCount.Increment(1); + + ZEN_WARN("get cache record(s) (rpc) FAILED, endpoint '{}', reason '{}', error code '{}'", + Endpoint->GetEndpointInfo().Url, + Result.Error.Reason, + Result.Error.ErrorCode); + } + + RemainingKeys = std::move(Missing); + } + } + + const UpstreamEndpointInfo Info; + for (CacheKeyRequest* Request : RemainingKeys) + { + OnComplete({.Request = *Request, .Record = CbObjectView(), .Package = CbPackage()}); + } + } + + virtual void GetCacheChunks(std::string_view Namespace, + std::span<CacheChunkRequest*> CacheChunkRequests, + OnCacheChunksGetComplete&& OnComplete) override final + { + ZEN_TRACE_CPU("Upstream::GetCacheChunks"); + + std::shared_lock<std::shared_mutex> _(m_EndpointsMutex); + + std::vector<CacheChunkRequest*> RemainingKeys(CacheChunkRequests.begin(), CacheChunkRequests.end()); + + if (m_Options.ReadUpstream) + { + for (auto& Endpoint : m_Endpoints) + { + if (RemainingKeys.empty()) + { + break; + } + + if (Endpoint->GetState() != UpstreamEndpointState::kOk) + { + continue; + } + + UpstreamEndpointStats& Stats = Endpoint->Stats(); + std::vector<CacheChunkRequest*> Missing; + GetUpstreamCacheResult Result; + { + metrics::OperationTiming::Scope Scope(Endpoint->Stats().CacheGetRequestTiming); + + Result = Endpoint->GetCacheChunks(Namespace, RemainingKeys, [&](CacheChunkGetCompleteParams&& Params) { + if (Params.RawHash != Params.RawHash.Zero) + { + OnComplete(std::forward<CacheChunkGetCompleteParams>(Params)); + + Stats.CacheHitCount.Increment(1); + } + else + { + Missing.push_back(&Params.Request); + } + }); + } + + Stats.CacheGetCount.Increment(int64_t(RemainingKeys.size())); + Stats.CacheGetTotalBytes.Increment(Result.Bytes); + + if (Result.Error) + { + Stats.CacheErrorCount.Increment(1); + + ZEN_WARN("get cache chunks(s) (rpc) FAILED, endpoint '{}', reason '{}', error code '{}'", + Endpoint->GetEndpointInfo().Url, + Result.Error.Reason, + Result.Error.ErrorCode); + } + + RemainingKeys = std::move(Missing); + } + } + + const UpstreamEndpointInfo Info; + for (CacheChunkRequest* RequestPtr : RemainingKeys) + { + OnComplete({.Request = *RequestPtr, .RawHash = IoHash::Zero, .RawSize = 0, .Value = IoBuffer()}); + } + } + + virtual GetUpstreamCacheSingleResult GetCacheChunk(std::string_view Namespace, + const CacheKey& CacheKey, + const IoHash& ValueContentId) override + { + ZEN_TRACE_CPU("Upstream::GetCacheChunk"); + + if (m_Options.ReadUpstream) + { + for (auto& Endpoint : m_Endpoints) + { + if (Endpoint->GetState() != UpstreamEndpointState::kOk) + { + continue; + } + + UpstreamEndpointStats& Stats = Endpoint->Stats(); + metrics::OperationTiming::Scope Scope(Stats.CacheGetRequestTiming); + GetUpstreamCacheSingleResult Result = Endpoint->GetCacheChunk(Namespace, CacheKey, ValueContentId); + Scope.Stop(); + + Stats.CacheGetCount.Increment(1); + Stats.CacheGetTotalBytes.Increment(Result.Status.Bytes); + + if (Result.Status.Success) + { + Stats.CacheHitCount.Increment(1); + + return Result; + } + + if (Result.Status.Error) + { + Stats.CacheErrorCount.Increment(1); + + ZEN_WARN("get cache chunk FAILED, endpoint '{}', reason '{}', error code '{}'", + Endpoint->GetEndpointInfo().Url, + Result.Status.Error.Reason, + Result.Status.Error.ErrorCode); + } + } + } + + return {}; + } + + virtual void GetCacheValues(std::string_view Namespace, + std::span<CacheValueRequest*> CacheValueRequests, + OnCacheValueGetComplete&& OnComplete) override final + { + ZEN_TRACE_CPU("Upstream::GetCacheValues"); + + std::shared_lock<std::shared_mutex> _(m_EndpointsMutex); + + std::vector<CacheValueRequest*> RemainingKeys(CacheValueRequests.begin(), CacheValueRequests.end()); + + if (m_Options.ReadUpstream) + { + for (auto& Endpoint : m_Endpoints) + { + if (RemainingKeys.empty()) + { + break; + } + + if (Endpoint->GetState() != UpstreamEndpointState::kOk) + { + continue; + } + + UpstreamEndpointStats& Stats = Endpoint->Stats(); + std::vector<CacheValueRequest*> Missing; + GetUpstreamCacheResult Result; + { + metrics::OperationTiming::Scope Scope(Endpoint->Stats().CacheGetRequestTiming); + + Result = Endpoint->GetCacheValues(Namespace, RemainingKeys, [&](CacheValueGetCompleteParams&& Params) { + if (Params.RawHash != Params.RawHash.Zero) + { + OnComplete(std::forward<CacheValueGetCompleteParams>(Params)); + + Stats.CacheHitCount.Increment(1); + } + else + { + Missing.push_back(&Params.Request); + } + }); + } + + Stats.CacheGetCount.Increment(int64_t(RemainingKeys.size())); + Stats.CacheGetTotalBytes.Increment(Result.Bytes); + + if (Result.Error) + { + Stats.CacheErrorCount.Increment(1); + + ZEN_WARN("get cache values(s) (rpc) FAILED, endpoint '{}', reason '{}', error code '{}'", + Endpoint->GetEndpointInfo().Url, + Result.Error.Reason, + Result.Error.ErrorCode); + } + + RemainingKeys = std::move(Missing); + } + } + + const UpstreamEndpointInfo Info; + for (CacheValueRequest* RequestPtr : RemainingKeys) + { + OnComplete({.Request = *RequestPtr, .RawHash = IoHash::Zero, .RawSize = 0, .Value = IoBuffer()}); + } + } + + virtual void EnqueueUpstream(UpstreamCacheRecord CacheRecord) override + { + if (m_RunState.IsRunning && m_Options.WriteUpstream && m_Endpoints.size() > 0) + { + if (!m_UpstreamThreads.empty()) + { + m_UpstreamQueue.Enqueue(std::move(CacheRecord)); + } + else + { + ProcessCacheRecord(std::move(CacheRecord)); + } + } + } + + virtual void GetStatus(CbObjectWriter& Status) override + { + Status << "reading" << m_Options.ReadUpstream; + Status << "writing" << m_Options.WriteUpstream; + Status << "worker_threads" << m_Options.ThreadCount; + Status << "queue_count" << m_UpstreamQueue.Size(); + + Status.BeginArray("endpoints"); + for (const auto& Ep : m_Endpoints) + { + const UpstreamEndpointInfo& EpInfo = Ep->GetEndpointInfo(); + const UpstreamEndpointStatus EpStatus = Ep->GetStatus(); + UpstreamEndpointStats& EpStats = Ep->Stats(); + + Status.BeginObject(); + Status << "name" << EpInfo.Name; + Status << "url" << EpInfo.Url; + Status << "state" << ToString(EpStatus.State); + Status << "reason" << EpStatus.Reason; + + Status.BeginObject("cache"sv); + { + const int64_t GetCount = EpStats.CacheGetCount.Value(); + const int64_t HitCount = EpStats.CacheHitCount.Value(); + const int64_t ErrorCount = EpStats.CacheErrorCount.Value(); + const double HitRatio = GetCount > 0 ? double(HitCount) / double(GetCount) : 0.0; + const double ErrorRatio = GetCount > 0 ? double(ErrorCount) / double(GetCount) : 0.0; + + metrics::EmitSnapshot("get_requests"sv, EpStats.CacheGetRequestTiming, Status); + Status << "get_bytes" << EpStats.CacheGetTotalBytes.Value(); + Status << "get_count" << GetCount; + Status << "hit_count" << HitCount; + Status << "hit_ratio" << HitRatio; + Status << "error_count" << ErrorCount; + Status << "error_ratio" << ErrorRatio; + metrics::EmitSnapshot("put_requests"sv, EpStats.CachePutRequestTiming, Status); + Status << "put_bytes" << EpStats.CachePutTotalBytes.Value(); + } + Status.EndObject(); + + Status.EndObject(); + } + Status.EndArray(); + } + +private: + void ProcessCacheRecord(UpstreamCacheRecord CacheRecord) + { + ZEN_TRACE_CPU("Upstream::ProcessCacheRecord"); + + ZenCacheValue CacheValue; + std::vector<IoBuffer> Payloads; + + if (!m_CacheStore.Get(CacheRecord.Namespace, CacheRecord.Key.Bucket, CacheRecord.Key.Hash, CacheValue)) + { + ZEN_WARN("process upstream FAILED, '{}/{}/{}', cache record doesn't exist", + CacheRecord.Namespace, + CacheRecord.Key.Bucket, + CacheRecord.Key.Hash); + return; + } + + for (const IoHash& ValueContentId : CacheRecord.ValueContentIds) + { + if (IoBuffer Payload = m_CidStore.FindChunkByCid(ValueContentId)) + { + Payloads.push_back(Payload); + } + else + { + ZEN_WARN("process upstream FAILED, '{}/{}/{}/{}', ValueContentId doesn't exist in CAS", + CacheRecord.Namespace, + CacheRecord.Key.Bucket, + CacheRecord.Key.Hash, + ValueContentId); + return; + } + } + + std::shared_lock<std::shared_mutex> _(m_EndpointsMutex); + + for (auto& Endpoint : m_Endpoints) + { + if (Endpoint->GetState() != UpstreamEndpointState::kOk) + { + continue; + } + + UpstreamEndpointStats& Stats = Endpoint->Stats(); + PutUpstreamCacheResult Result; + { + metrics::OperationTiming::Scope Scope(Stats.CachePutRequestTiming); + Result = Endpoint->PutCacheRecord(CacheRecord, CacheValue.Value, std::span(Payloads)); + } + + Stats.CachePutTotalBytes.Increment(Result.Bytes); + + if (!Result.Success) + { + ZEN_WARN("upload cache record '{}/{}/{}' FAILED, endpoint '{}', reason '{}'", + CacheRecord.Namespace, + CacheRecord.Key.Bucket, + CacheRecord.Key.Hash, + Endpoint->GetEndpointInfo().Url, + Result.Reason); + } + } + } + + void ProcessUpstreamQueue() + { + for (;;) + { + UpstreamCacheRecord CacheRecord; + if (m_UpstreamQueue.WaitAndDequeue(CacheRecord)) + { + try + { + ProcessCacheRecord(std::move(CacheRecord)); + } + catch (std::exception& Err) + { + ZEN_ERROR("upload cache record '{}/{}/{}' FAILED, reason '{}'", + CacheRecord.Namespace, + CacheRecord.Key.Bucket, + CacheRecord.Key.Hash, + Err.what()); + } + } + + if (!m_RunState.IsRunning) + { + break; + } + } + } + + void MonitorEndpoints() + { + for (;;) + { + { + std::unique_lock lk(m_RunState.Mutex); + if (m_RunState.ExitSignal.wait_for(lk, m_Options.HealthCheckInterval, [this]() { return !m_RunState.IsRunning.load(); })) + { + break; + } + } + + try + { + std::vector<UpstreamEndpoint*> Endpoints; + + { + std::shared_lock<std::shared_mutex> _(m_EndpointsMutex); + + for (auto& Endpoint : m_Endpoints) + { + UpstreamEndpointState State = Endpoint->GetState(); + if (State == UpstreamEndpointState::kError) + { + Endpoints.push_back(Endpoint.get()); + ZEN_WARN("HEALTH - endpoint '{} - {}' is in error state '{}'", + Endpoint->GetEndpointInfo().Name, + Endpoint->GetEndpointInfo().Url, + Endpoint->GetStatus().Reason); + } + if (State == UpstreamEndpointState::kUnauthorized) + { + Endpoints.push_back(Endpoint.get()); + } + } + } + + for (auto& Endpoint : Endpoints) + { + const UpstreamEndpointInfo& Info = Endpoint->GetEndpointInfo(); + const UpstreamEndpointStatus Status = Endpoint->Initialize(); + + if (Status.State == UpstreamEndpointState::kOk) + { + ZEN_INFO("HEALTH - endpoint '{} - {}' Ok", Info.Name, Info.Url); + } + else + { + const std::string Reason = Status.Reason.empty() ? "" : fmt::format(", reason '{}'", Status.Reason); + ZEN_WARN("HEALTH - endpoint '{} - {}' {} {}", Info.Name, Info.Url, ToString(Status.State), Reason); + } + } + } + catch (std::exception& Err) + { + ZEN_ERROR("check endpoint(s) health FAILED, reason '{}'", Err.what()); + } + } + } + + void Shutdown() + { + if (m_RunState.Stop()) + { + m_UpstreamQueue.CompleteAdding(); + for (std::thread& Thread : m_UpstreamThreads) + { + Thread.join(); + } + + m_EndpointMonitorThread.join(); + m_UpstreamThreads.clear(); + m_Endpoints.clear(); + } + } + + spdlog::logger& Log() { return m_Log; } + + using UpstreamQueue = BlockingQueue<UpstreamCacheRecord>; + + struct RunState + { + std::mutex Mutex; + std::condition_variable ExitSignal; + std::atomic_bool IsRunning{false}; + + bool Stop() + { + bool Stopped = false; + { + std::lock_guard _(Mutex); + Stopped = IsRunning.exchange(false); + } + if (Stopped) + { + ExitSignal.notify_all(); + } + return Stopped; + } + }; + + spdlog::logger& m_Log; + UpstreamCacheOptions m_Options; + ZenCacheStore& m_CacheStore; + CidStore& m_CidStore; + UpstreamQueue m_UpstreamQueue; + std::shared_mutex m_EndpointsMutex; + std::vector<std::unique_ptr<UpstreamEndpoint>> m_Endpoints; + std::vector<std::thread> m_UpstreamThreads; + std::thread m_EndpointMonitorThread; + RunState m_RunState; +}; + +////////////////////////////////////////////////////////////////////////// + +std::unique_ptr<UpstreamEndpoint> +UpstreamEndpoint::CreateZenEndpoint(const ZenStructuredCacheClientOptions& Options) +{ + return std::make_unique<detail::ZenUpstreamEndpoint>(Options); +} + +std::unique_ptr<UpstreamEndpoint> +UpstreamEndpoint::CreateJupiterEndpoint(const CloudCacheClientOptions& Options, const UpstreamAuthConfig& AuthConfig, AuthMgr& Mgr) +{ + return std::make_unique<detail::JupiterUpstreamEndpoint>(Options, AuthConfig, Mgr); +} + +std::unique_ptr<UpstreamCache> +UpstreamCache::Create(const UpstreamCacheOptions& Options, ZenCacheStore& CacheStore, CidStore& CidStore) +{ + return std::make_unique<UpstreamCacheImpl>(Options, CacheStore, CidStore); +} + +} // namespace zen |