diff options
Diffstat (limited to 'src/zenhttp/httpclient.cpp')
| -rw-r--r-- | src/zenhttp/httpclient.cpp | 192 |
1 files changed, 166 insertions, 26 deletions
diff --git a/src/zenhttp/httpclient.cpp b/src/zenhttp/httpclient.cpp index 1cfddb366..13c86e9ae 100644 --- a/src/zenhttp/httpclient.cpp +++ b/src/zenhttp/httpclient.cpp @@ -36,9 +36,43 @@ namespace zen { +#if ZEN_WITH_CPR extern HttpClientBase* CreateCprHttpClient(std::string_view BaseUri, const HttpClientSettings& ConnectionSettings, std::function<bool()>&& CheckIfAbortFunction); +#endif + +extern HttpClientBase* CreateCurlHttpClient(std::string_view BaseUri, + const HttpClientSettings& ConnectionSettings, + std::function<bool()>&& CheckIfAbortFunction); + +static HttpClientBackend g_DefaultHttpClientBackend = HttpClientBackend::kCurl; + +void +SetDefaultHttpClientBackend(HttpClientBackend Backend) +{ + g_DefaultHttpClientBackend = Backend; +} + +void +SetDefaultHttpClientBackend(std::string_view Backend) +{ +#if ZEN_WITH_CPR + if (Backend == "cpr") + { + g_DefaultHttpClientBackend = HttpClientBackend::kCpr; + } + else +#endif + if (Backend == "curl") + { + g_DefaultHttpClientBackend = HttpClientBackend::kCurl; + } + else + { + g_DefaultHttpClientBackend = HttpClientBackend::kDefault; + } +} using namespace std::literals; @@ -70,15 +104,10 @@ bool HttpClientBase::Authenticate() { ZEN_TRACE_CPU("HttpClientBase::Authenticate"); - std::optional<HttpClientAccessToken> Token = GetAccessToken(); - if (!Token) - { - return false; - } - return Token->IsValid(); + return GetAccessToken().has_value(); } -const std::optional<HttpClientAccessToken> +std::optional<std::string> HttpClientBase::GetAccessToken() { ZEN_TRACE_CPU("HttpClientBase::GetAccessToken"); @@ -88,18 +117,104 @@ HttpClientBase::GetAccessToken() } { RwLock::SharedLockScope _(m_AccessTokenLock); - if (m_CachedAccessToken.IsValid()) + if (!m_CachedAccessToken.NeedsRefresh()) { - return m_CachedAccessToken; + return m_CachedAccessToken.GetValue(); } } RwLock::ExclusiveLockScope _(m_AccessTokenLock); - if (m_CachedAccessToken.IsValid()) + if (!m_CachedAccessToken.NeedsRefresh()) + { + return m_CachedAccessToken.GetValue(); + } + HttpClientAccessToken NewAccessToken = m_ConnectionSettings.AccessTokenProvider.value()(); + if (!NewAccessToken.IsValid()) + { + ZEN_WARN("HttpClient failed to refresh access token, retrying once"); + NewAccessToken = m_ConnectionSettings.AccessTokenProvider.value()(); + } + if (NewAccessToken.IsValid()) + { + m_CachedAccessToken = NewAccessToken; + } + else + { + if (m_CachedAccessToken.HasExpired()) + { + ZEN_WARN("HttpClient refreshed access token is not valid, clearing the cached token as it has expired"); + m_CachedAccessToken = {}; + } + else + { + ZEN_WARN("HttpClient refreshed access token is not valid, keeping existing token, it will expire soon"); + } + } + return m_CachedAccessToken.GetValue(); +} + +////////////////////////////////////////////////////////////////////////// + +HttpClientError::ResponseClass +HttpClientError::GetResponseClass() const +{ + if (m_Error != HttpClientErrorCode::kOK) + { + switch (m_Error) + { + case HttpClientErrorCode::kConnectionFailure: + return ResponseClass::kHttpCantConnectError; + case HttpClientErrorCode::kHostResolutionFailure: + case HttpClientErrorCode::kProxyResolutionFailure: + return ResponseClass::kHttpNoHost; + case HttpClientErrorCode::kInternalError: + case HttpClientErrorCode::kNetworkReceiveError: + case HttpClientErrorCode::kNetworkSendFailure: + case HttpClientErrorCode::kOperationTimedOut: + return ResponseClass::kHttpTimeout; + case HttpClientErrorCode::kSSLConnectError: + case HttpClientErrorCode::kSSLCertificateError: + case HttpClientErrorCode::kSSLCACertError: + case HttpClientErrorCode::kGenericSSLError: + return ResponseClass::kHttpSLLError; + default: + return ResponseClass::kHttpOtherClientError; + } + } + else if (IsHttpSuccessCode(m_ResponseCode)) { - return m_CachedAccessToken; + return ResponseClass::kSuccess; + } + else + { + switch (m_ResponseCode) + { + case HttpResponseCode::Unauthorized: + return ResponseClass::kHttpUnauthorized; + case HttpResponseCode::NotFound: + return ResponseClass::kHttpNotFound; + case HttpResponseCode::Forbidden: + return ResponseClass::kHttpForbidden; + case HttpResponseCode::Conflict: + return ResponseClass::kHttpConflict; + case HttpResponseCode::InternalServerError: + return ResponseClass::kHttpInternalServerError; + case HttpResponseCode::ServiceUnavailable: + return ResponseClass::kHttpServiceUnavailable; + case HttpResponseCode::BadGateway: + return ResponseClass::kHttpBadGateway; + case HttpResponseCode::GatewayTimeout: + return ResponseClass::kHttpGatewayTimeout; + default: + if (m_ResponseCode >= HttpResponseCode::InternalServerError) + { + return ResponseClass::kHttpOtherServerError; + } + else + { + return ResponseClass::kHttpOtherClientError; + } + } } - m_CachedAccessToken = m_ConnectionSettings.AccessTokenProvider.value()(); - return m_CachedAccessToken; } ////////////////////////////////////////////////////////////////////////// @@ -107,17 +222,14 @@ HttpClientBase::GetAccessToken() std::vector<std::pair<uint64_t, uint64_t>> HttpClient::Response::GetRanges(std::span<const std::pair<uint64_t, uint64_t>> OffsetAndLengthPairs) const { - std::vector<std::pair<uint64_t, uint64_t>> Result; - Result.reserve(OffsetAndLengthPairs.size()); if (Ranges.empty()) { - for (const std::pair<uint64_t, uint64_t>& Range : OffsetAndLengthPairs) - { - Result.emplace_back(std::make_pair(Range.first, Range.second)); - } - return Result; + return {}; } + std::vector<std::pair<uint64_t, uint64_t>> Result; + Result.reserve(OffsetAndLengthPairs.size()); + auto BoundaryIt = Ranges.begin(); auto OffsetAndLengthPairIt = OffsetAndLengthPairs.begin(); while (OffsetAndLengthPairIt != OffsetAndLengthPairs.end()) @@ -225,7 +337,11 @@ HttpClient::Response::ErrorMessage(std::string_view Prefix) const { if (Error.has_value()) { - return fmt::format("{}{}HTTP error ({}) '{}'", Prefix, Prefix.empty() ? ""sv : ": "sv, Error->ErrorCode, Error->ErrorMessage); + return fmt::format("{}{}HTTP error ({}) '{}'", + Prefix, + Prefix.empty() ? ""sv : ": "sv, + static_cast<int>(Error->ErrorCode), + Error->ErrorMessage); } else if (StatusCode != HttpResponseCode::ImATeapot && (int)StatusCode) { @@ -248,19 +364,36 @@ HttpClient::Response::ThrowError(std::string_view ErrorPrefix) { if (!IsSuccess()) { - throw HttpClientError(ErrorMessage(ErrorPrefix), Error.has_value() ? Error.value().ErrorCode : 0, StatusCode); + throw HttpClientError(ErrorMessage(ErrorPrefix), + Error.has_value() ? Error.value().ErrorCode : HttpClientErrorCode::kOK, + StatusCode); } } ////////////////////////////////////////////////////////////////////////// HttpClient::HttpClient(std::string_view BaseUri, const HttpClientSettings& ConnectionSettings, std::function<bool()>&& CheckIfAbortFunction) -: m_BaseUri(BaseUri) +: m_Log(zen::logging::Get(ConnectionSettings.LogCategory)) +, m_BaseUri(BaseUri) , m_ConnectionSettings(ConnectionSettings) { m_SessionId = GetSessionIdString(); - m_Inner = CreateCprHttpClient(BaseUri, ConnectionSettings, std::move(CheckIfAbortFunction)); + HttpClientBackend EffectiveBackend = + ConnectionSettings.Backend != HttpClientBackend::kDefault ? ConnectionSettings.Backend : g_DefaultHttpClientBackend; + + switch (EffectiveBackend) + { +#if ZEN_WITH_CPR + case HttpClientBackend::kCpr: + m_Inner = CreateCprHttpClient(BaseUri, ConnectionSettings, std::move(CheckIfAbortFunction)); + break; +#endif + case HttpClientBackend::kCurl: + default: + m_Inner = CreateCurlHttpClient(BaseUri, ConnectionSettings, std::move(CheckIfAbortFunction)); + break; + } } HttpClient::~HttpClient() @@ -330,9 +463,12 @@ HttpClient::Post(std::string_view Url, const IoBuffer& Payload, ZenContentType C } HttpClient::Response -HttpClient::Post(std::string_view Url, CbObject Payload, const HttpClient::KeyValueMap& AdditionalHeader) +HttpClient::Post(std::string_view Url, + CbObject Payload, + const HttpClient::KeyValueMap& AdditionalHeader, + const std::filesystem::path& TempFolderPath) { - return m_Inner->Post(Url, Payload, AdditionalHeader); + return m_Inner->Post(Url, Payload, AdditionalHeader, TempFolderPath); } HttpClient::Response @@ -430,6 +566,8 @@ MeasureLatency(HttpClient& Client, std::string_view Url) #if ZEN_WITH_TESTS +TEST_SUITE_BEGIN("http.httpclient"); + TEST_CASE("responseformat") { using namespace std::literals; @@ -839,6 +977,8 @@ TEST_CASE("httpclient.password") AsioServer->RequestExit(); } } +TEST_SUITE_END(); + void httpclient_forcelink() { |