diff options
| author | Stefan Boberg <[email protected]> | 2026-03-15 20:42:36 +0100 |
|---|---|---|
| committer | Stefan Boberg <[email protected]> | 2026-03-15 20:42:36 +0100 |
| commit | 9c724efbf6b38466a9b6bfde37236369f1e85cb8 (patch) | |
| tree | 214e1ec00c5bfca0704ce52789017ade734fd054 /src/zenhttp | |
| parent | reduced WaitForThreads time to see how it behaves with explicit thread pools (diff) | |
| parent | add buildid updates to oplog and builds test scripts (#838) (diff) | |
| download | zen-9c724efbf6b38466a9b6bfde37236369f1e85cb8.tar.xz zen-9c724efbf6b38466a9b6bfde37236369f1e85cb8.zip | |
Merge remote-tracking branch 'origin/main' into sb/threadpool
Diffstat (limited to 'src/zenhttp')
47 files changed, 10274 insertions, 879 deletions
diff --git a/src/zenhttp/auth/oidc.cpp b/src/zenhttp/auth/oidc.cpp index 38e7586ad..23bbc17e8 100644 --- a/src/zenhttp/auth/oidc.cpp +++ b/src/zenhttp/auth/oidc.cpp @@ -32,6 +32,25 @@ namespace details { using namespace std::literals; +static std::string +FormUrlEncode(std::string_view Input) +{ + std::string Result; + Result.reserve(Input.size()); + for (char C : Input) + { + if ((C >= 'A' && C <= 'Z') || (C >= 'a' && C <= 'z') || (C >= '0' && C <= '9') || C == '-' || C == '_' || C == '.' || C == '~') + { + Result.push_back(C); + } + else + { + Result.append(fmt::format("%{:02X}", static_cast<uint8_t>(C))); + } + } + return Result; +} + OidcClient::OidcClient(const OidcClient::Options& Options) { m_BaseUrl = std::string(Options.BaseUrl); @@ -67,6 +86,8 @@ OidcClient::Initialize() .TokenEndpoint = Json["token_endpoint"].string_value(), .UserInfoEndpoint = Json["userinfo_endpoint"].string_value(), .RegistrationEndpoint = Json["registration_endpoint"].string_value(), + .EndSessionEndpoint = Json["end_session_endpoint"].string_value(), + .DeviceAuthorizationEndpoint = Json["device_authorization_endpoint"].string_value(), .JwksUri = Json["jwks_uri"].string_value(), .SupportedResponseTypes = details::ToStringArray(Json["response_types_supported"]), .SupportedResponseModes = details::ToStringArray(Json["response_modes_supported"]), @@ -81,7 +102,8 @@ OidcClient::Initialize() OidcClient::RefreshTokenResult OidcClient::RefreshToken(std::string_view RefreshToken) { - const std::string Body = fmt::format("grant_type=refresh_token&refresh_token={}&client_id={}", RefreshToken, m_ClientId); + const std::string Body = + fmt::format("grant_type=refresh_token&refresh_token={}&client_id={}", FormUrlEncode(RefreshToken), FormUrlEncode(m_ClientId)); HttpClient Http{m_Config.TokenEndpoint}; diff --git a/src/zenhttp/clients/httpclientcommon.cpp b/src/zenhttp/clients/httpclientcommon.cpp index 47425e014..e4d11547a 100644 --- a/src/zenhttp/clients/httpclientcommon.cpp +++ b/src/zenhttp/clients/httpclientcommon.cpp @@ -142,7 +142,10 @@ namespace detail { DataSize -= CopySize; if (m_CacheBufferOffset == CacheBufferSize) { - AppendData(m_CacheBuffer, CacheBufferSize); + if (std::error_code Ec = AppendData(m_CacheBuffer, CacheBufferSize)) + { + return Ec; + } if (DataSize > 0) { ZEN_ASSERT(DataSize < CacheBufferSize); @@ -382,6 +385,177 @@ namespace detail { return Result; } + MultipartBoundaryParser::MultipartBoundaryParser() : BoundaryEndMatcher("--"), HeaderEndMatcher("\r\n\r\n") {} + + bool MultipartBoundaryParser::Init(const std::string_view ContentTypeHeaderValue) + { + std::string LowerCaseValue = ToLower(ContentTypeHeaderValue); + if (LowerCaseValue.starts_with("multipart/byteranges")) + { + size_t BoundaryPos = LowerCaseValue.find("boundary="); + if (BoundaryPos != std::string::npos) + { + // Yes, we do a substring of the non-lowercase value string as we want the exact boundary string + std::string_view BoundaryName = std::string_view(ContentTypeHeaderValue).substr(BoundaryPos + 9); + size_t BoundaryEnd = std::string::npos; + while (!BoundaryName.empty() && BoundaryName[0] == ' ') + { + BoundaryName = BoundaryName.substr(1); + } + if (!BoundaryName.empty()) + { + if (BoundaryName.size() > 2 && BoundaryName.front() == '"' && BoundaryName.back() == '"') + { + BoundaryEnd = BoundaryName.find('"', 1); + if (BoundaryEnd != std::string::npos) + { + BoundaryBeginMatcher.Init(fmt::format("\r\n--{}", BoundaryName.substr(1, BoundaryEnd - 1))); + return true; + } + } + else + { + BoundaryEnd = BoundaryName.find_first_of(" \r\n"); + BoundaryBeginMatcher.Init(fmt::format("\r\n--{}", BoundaryName.substr(0, BoundaryEnd))); + return true; + } + } + } + } + return false; + } + + void MultipartBoundaryParser::ParseInput(std::string_view data) + { + const char* InputPtr = data.data(); + size_t InputLength = data.length(); + size_t ScanPos = 0; + while (ScanPos < InputLength) + { + const char ScanChar = InputPtr[ScanPos]; + if (BoundaryBeginMatcher.MatchState == IncrementalStringMatcher::EMatchState::Complete) + { + if (PayloadOffset + ScanPos < (BoundaryBeginMatcher.GetMatchEndOffset() + BoundaryEndMatcher.GetMatchString().length())) + { + BoundaryEndMatcher.Match(PayloadOffset + ScanPos, ScanChar); + if (BoundaryEndMatcher.MatchState == IncrementalStringMatcher::EMatchState::Complete) + { + BoundaryBeginMatcher.Reset(); + HeaderEndMatcher.Reset(); + BoundaryEndMatcher.Reset(); + BoundaryHeader.Reset(); + break; + } + } + + BoundaryHeader.Append(ScanChar); + + HeaderEndMatcher.Match(PayloadOffset + ScanPos, ScanChar); + + if (HeaderEndMatcher.MatchState == IncrementalStringMatcher::EMatchState::Complete) + { + const uint64_t HeaderStartOffset = BoundaryBeginMatcher.GetMatchEndOffset(); + const uint64_t HeaderEndOffset = HeaderEndMatcher.GetMatchStartOffset(); + const uint64_t HeaderLength = HeaderEndOffset - HeaderStartOffset; + std::string_view HeaderText(BoundaryHeader.ToView().substr(0, HeaderLength)); + + uint64_t OffsetInPayload = PayloadOffset + ScanPos + 1; + + uint64_t RangeOffset = 0; + uint64_t RangeLength = 0; + HttpContentType ContentType = HttpContentType::kBinary; + + ForEachStrTok(HeaderText, "\r\n", [&](std::string_view Line) { + const std::pair<std::string_view, std::string_view> KeyAndValue = GetHeaderKeyAndValue(Line); + const std::string_view Key = KeyAndValue.first; + const std::string_view Value = KeyAndValue.second; + if (Key == "Content-Range") + { + std::pair<uint64_t, uint64_t> ContentRange = ParseContentRange(Value); + if (ContentRange.second != 0) + { + RangeOffset = ContentRange.first; + RangeLength = ContentRange.second; + } + } + else if (Key == "Content-Type") + { + ContentType = ParseContentType(Value); + } + + return true; + }); + + if (RangeLength > 0) + { + Boundaries.push_back(HttpClient::Response::MultipartBoundary{.OffsetInPayload = OffsetInPayload, + .RangeOffset = RangeOffset, + .RangeLength = RangeLength, + .ContentType = ContentType}); + } + + BoundaryBeginMatcher.Reset(); + HeaderEndMatcher.Reset(); + BoundaryEndMatcher.Reset(); + BoundaryHeader.Reset(); + } + } + else + { + BoundaryBeginMatcher.Match(PayloadOffset + ScanPos, ScanChar); + } + ScanPos++; + } + PayloadOffset += InputLength; + } + + std::pair<std::string_view, std::string_view> GetHeaderKeyAndValue(std::string_view HeaderString) + { + size_t DelimiterPos = HeaderString.find(':'); + if (DelimiterPos != std::string::npos) + { + std::string_view Key = HeaderString.substr(0, DelimiterPos); + constexpr AsciiSet WhitespaceCharacters(" \v\f\t\r\n"); + Key = AsciiSet::TrimSuffixWith(Key, WhitespaceCharacters); + Key = AsciiSet::TrimPrefixWith(Key, WhitespaceCharacters); + + std::string_view Value = HeaderString.substr(DelimiterPos + 1); + Value = AsciiSet::TrimSuffixWith(Value, WhitespaceCharacters); + Value = AsciiSet::TrimPrefixWith(Value, WhitespaceCharacters); + return std::make_pair(Key, Value); + } + return std::make_pair(HeaderString, std::string_view{}); + } + + std::pair<uint64_t, uint64_t> ParseContentRange(std::string_view Value) + { + if (Value.starts_with("bytes ")) + { + size_t RangeSplitPos = Value.find('-', 6); + if (RangeSplitPos != std::string::npos) + { + size_t RangeEndLength = Value.find('/', RangeSplitPos + 1); + if (RangeEndLength == std::string::npos) + { + RangeEndLength = Value.length() - (RangeSplitPos + 1); + } + else + { + RangeEndLength = RangeEndLength - (RangeSplitPos + 1); + } + std::optional<size_t> RequestedRangeStart = ParseInt<size_t>(Value.substr(6, RangeSplitPos - 6)); + std::optional<size_t> RequestedRangeEnd = ParseInt<size_t>(Value.substr(RangeSplitPos + 1, RangeEndLength)); + if (RequestedRangeStart.has_value() && RequestedRangeEnd.has_value()) + { + uint64_t RangeOffset = RequestedRangeStart.value(); + uint64_t RangeLength = RequestedRangeEnd.value() - RangeOffset + 1; + return std::make_pair(RangeOffset, RangeLength); + } + } + } + return {0, 0}; + } + } // namespace detail } // namespace zen @@ -423,6 +597,8 @@ namespace testutil { } // namespace testutil +TEST_SUITE_BEGIN("http.httpclientcommon"); + TEST_CASE("BufferedReadFileStream") { ScopedTemporaryDirectory TmpDir; @@ -470,5 +646,207 @@ TEST_CASE("CompositeBufferReadStream") CHECK_EQ(IoHash::HashBuffer(Data), testutil::HashComposite(Data)); } +TEST_CASE("ParseContentRange") +{ + SUBCASE("normal range with total size") + { + auto [Offset, Length] = detail::ParseContentRange("bytes 0-99/500"); + CHECK_EQ(Offset, 0); + CHECK_EQ(Length, 100); + } + + SUBCASE("non-zero offset") + { + auto [Offset, Length] = detail::ParseContentRange("bytes 2638-5111437/44369878"); + CHECK_EQ(Offset, 2638); + CHECK_EQ(Length, 5111437 - 2638 + 1); + } + + SUBCASE("wildcard total size") + { + auto [Offset, Length] = detail::ParseContentRange("bytes 100-199/*"); + CHECK_EQ(Offset, 100); + CHECK_EQ(Length, 100); + } + + SUBCASE("no slash (total size omitted)") + { + auto [Offset, Length] = detail::ParseContentRange("bytes 50-149"); + CHECK_EQ(Offset, 50); + CHECK_EQ(Length, 100); + } + + SUBCASE("malformed input returns zeros") + { + auto [Offset1, Length1] = detail::ParseContentRange("not-bytes 0-99/500"); + CHECK_EQ(Offset1, 0); + CHECK_EQ(Length1, 0); + + auto [Offset2, Length2] = detail::ParseContentRange("bytes abc-def/500"); + CHECK_EQ(Offset2, 0); + CHECK_EQ(Length2, 0); + + auto [Offset3, Length3] = detail::ParseContentRange(""); + CHECK_EQ(Offset3, 0); + CHECK_EQ(Length3, 0); + + auto [Offset4, Length4] = detail::ParseContentRange("bytes 100/500"); + CHECK_EQ(Offset4, 0); + CHECK_EQ(Length4, 0); + } + + SUBCASE("single byte range") + { + auto [Offset, Length] = detail::ParseContentRange("bytes 42-42/1000"); + CHECK_EQ(Offset, 42); + CHECK_EQ(Length, 1); + } +} + +TEST_CASE("MultipartBoundaryParser") +{ + uint64_t Range1Offset = 2638; + uint64_t Range1Length = (5111437 - Range1Offset) + 1; + + uint64_t Range2Offset = 5118199; + uint64_t Range2Length = (9147741 - Range2Offset) + 1; + + std::string_view ContentTypeHeaderValue1 = "multipart/byteranges; boundary=00000000000000019229"; + std::string_view ContentTypeHeaderValue2 = "multipart/byteranges; boundary=\"00000000000000019229\""; + + { + std::string_view Example1 = + "\r\n--00000000000000019229\r\n" + "Content-Type: application/x-ue-comp\r\n" + "Content-Range: bytes 2638-5111437/44369878\r\n" + "\r\n" + "datadatadatadata" + "\r\n--00000000000000019229\r\n" + "Content-Type: application/x-ue-comp\r\n" + "Content-Range: bytes 5118199-9147741/44369878\r\n" + "\r\n" + "ditaditadita" + "\r\n--00000000000000019229--"; + + detail::MultipartBoundaryParser ParserExample1; + ParserExample1.Init(ContentTypeHeaderValue1); + + const size_t InputWindow = 7; + for (size_t Offset = 0; Offset < Example1.length(); Offset += InputWindow) + { + ParserExample1.ParseInput(Example1.substr(Offset, Min(Example1.length() - Offset, InputWindow))); + } + + CHECK(ParserExample1.Boundaries.size() == 2); + + CHECK(ParserExample1.Boundaries[0].RangeOffset == Range1Offset); + CHECK(ParserExample1.Boundaries[0].RangeLength == Range1Length); + CHECK(ParserExample1.Boundaries[1].RangeOffset == Range2Offset); + CHECK(ParserExample1.Boundaries[1].RangeLength == Range2Length); + } + + { + std::string_view Example2 = + "\r\n--00000000000000019229\r\n" + "Content-Type: application/x-ue-comp\r\n" + "Content-Range: bytes 2638-5111437/*\r\n" + "\r\n" + "datadatadatadata" + "\r\n--00000000000000019229\r\n" + "Content-Type: application/x-ue-comp\r\n" + "Content-Range: bytes 5118199-9147741/44369878\r\n" + "\r\n" + "ditaditadita" + "\r\n--00000000000000019229--"; + + detail::MultipartBoundaryParser ParserExample2; + ParserExample2.Init(ContentTypeHeaderValue1); + + const size_t InputWindow = 3; + for (size_t Offset = 0; Offset < Example2.length(); Offset += InputWindow) + { + std::string_view Window = Example2.substr(Offset, Min(Example2.length() - Offset, InputWindow)); + ParserExample2.ParseInput(Window); + } + + CHECK(ParserExample2.Boundaries.size() == 2); + + CHECK(ParserExample2.Boundaries[0].RangeOffset == Range1Offset); + CHECK(ParserExample2.Boundaries[0].RangeLength == Range1Length); + CHECK(ParserExample2.Boundaries[1].RangeOffset == Range2Offset); + CHECK(ParserExample2.Boundaries[1].RangeLength == Range2Length); + } + + { + std::string_view Example3 = + "\r\n--00000000000000019229\r\n" + "Content-Type: application/x-ue-comp\r\n" + "Content-Range: bytes 2638-5111437/*\r\n" + "\r\n" + "datadatadatadata" + "\r\n--00000000000000019229\r\n" + "Content-Type: application/x-ue-comp\r\n" + "Content-Range: bytes 5118199-9147741/44369878\r\n" + "\r\n" + "ditaditadita"; + + detail::MultipartBoundaryParser ParserExample3; + ParserExample3.Init(ContentTypeHeaderValue2); + + const size_t InputWindow = 31; + for (size_t Offset = 0; Offset < Example3.length(); Offset += InputWindow) + { + ParserExample3.ParseInput(Example3.substr(Offset, Min(Example3.length() - Offset, InputWindow))); + } + + CHECK(ParserExample3.Boundaries.size() == 2); + + CHECK(ParserExample3.Boundaries[0].RangeOffset == Range1Offset); + CHECK(ParserExample3.Boundaries[0].RangeLength == Range1Length); + CHECK(ParserExample3.Boundaries[1].RangeOffset == Range2Offset); + CHECK(ParserExample3.Boundaries[1].RangeLength == Range2Length); + } + + { + std::string_view Example4 = + "\r\n--00000000000000019229\r\n" + "Content-Type: application/x-ue-comp\r\n" + "Content-Range: bytes 2638-5111437/*\r\n" + "Not: really\r\n" + "\r\n" + "datadatadatadata" + "\r\n--000000000bait0019229\r\n" + "\r\n--00\r\n--000000000bait001922\r\n" + "\r\n\r\n\r\r\n--00000000000000019229\r\n" + "Content-Type: application/x-ue-comp\r\n" + "Content-Range: bytes 5118199-9147741/44369878\r\n" + "\r\n" + "ditaditadita" + "Content-Type: application/x-ue-comp\r\n" + "ditaditadita" + "Content-Range: bytes 5118199-9147741/44369878\r\n" + "\r\n---\r\n--00000000000000019229--"; + + detail::MultipartBoundaryParser ParserExample4; + ParserExample4.Init(ContentTypeHeaderValue1); + + const size_t InputWindow = 3; + for (size_t Offset = 0; Offset < Example4.length(); Offset += InputWindow) + { + std::string_view Window = Example4.substr(Offset, Min(Example4.length() - Offset, InputWindow)); + ParserExample4.ParseInput(Window); + } + + CHECK(ParserExample4.Boundaries.size() == 2); + + CHECK(ParserExample4.Boundaries[0].RangeOffset == Range1Offset); + CHECK(ParserExample4.Boundaries[0].RangeLength == Range1Length); + CHECK(ParserExample4.Boundaries[1].RangeOffset == Range2Offset); + CHECK(ParserExample4.Boundaries[1].RangeLength == Range2Length); + } +} + +TEST_SUITE_END(); + } // namespace zen #endif diff --git a/src/zenhttp/clients/httpclientcommon.h b/src/zenhttp/clients/httpclientcommon.h index 1d0b7f9ea..e95e3a253 100644 --- a/src/zenhttp/clients/httpclientcommon.h +++ b/src/zenhttp/clients/httpclientcommon.h @@ -3,6 +3,7 @@ #pragma once #include <zencore/compositebuffer.h> +#include <zencore/string.h> #include <zencore/trace.h> #include <zenhttp/httpclient.h> @@ -35,7 +36,10 @@ public: const IoBuffer& Payload, ZenContentType ContentType, const KeyValueMap& AdditionalHeader = {}) = 0; - [[nodiscard]] virtual Response Post(std::string_view Url, CbObject Payload, const KeyValueMap& AdditionalHeader = {}) = 0; + [[nodiscard]] virtual Response Post(std::string_view Url, + CbObject Payload, + const KeyValueMap& AdditionalHeader = {}, + const std::filesystem::path& TempFolderPath = {}) = 0; [[nodiscard]] virtual Response Post(std::string_view Url, CbPackage Payload, const KeyValueMap& AdditionalHeader = {}) = 0; [[nodiscard]] virtual Response Post(std::string_view Url, const CompositeBuffer& Payload, @@ -87,7 +91,7 @@ namespace detail { std::error_code Write(std::string_view DataString); IoBuffer DetachToIoBuffer(); IoBuffer BorrowIoBuffer(); - inline uint64_t GetSize() const { return m_WriteOffset; } + inline uint64_t GetSize() const { return m_WriteOffset + m_CacheBufferOffset; } void ResetWritePos(uint64_t WriteOffset); private: @@ -143,6 +147,118 @@ namespace detail { uint64_t m_BytesLeftInSegment; }; + class IncrementalStringMatcher + { + public: + enum class EMatchState + { + None, + Partial, + Complete + }; + + EMatchState MatchState = EMatchState::None; + + IncrementalStringMatcher() {} + + IncrementalStringMatcher(std::string&& InMatchString) : MatchString(std::move(InMatchString)) + { + RawMatchString = MatchString.data(); + } + + void Init(std::string&& InMatchString) + { + MatchString = std::move(InMatchString); + RawMatchString = MatchString.data(); + } + + inline void Reset() + { + MatchLength = 0; + MatchStartOffset = 0; + MatchState = EMatchState::None; + } + + inline uint64_t GetMatchEndOffset() const + { + if (MatchState == EMatchState::Complete) + { + return MatchStartOffset + MatchString.length(); + } + return 0; + } + + inline uint64_t GetMatchStartOffset() const + { + ZEN_ASSERT(MatchState == EMatchState::Complete); + return MatchStartOffset; + } + + void Match(uint64_t Offset, char C) + { + ZEN_ASSERT_SLOW(RawMatchString != nullptr); + + if (MatchState == EMatchState::Complete) + { + Reset(); + } + if (C == RawMatchString[MatchLength]) + { + if (MatchLength == 0) + { + MatchStartOffset = Offset; + } + MatchLength++; + if (MatchLength == MatchString.length()) + { + MatchState = EMatchState::Complete; + } + else + { + MatchState = EMatchState::Partial; + } + } + else if (MatchLength != 0) + { + Reset(); + Match(Offset, C); + } + else + { + Reset(); + } + } + inline const std::string& GetMatchString() const { return MatchString; } + + private: + std::string MatchString; + const char* RawMatchString = nullptr; + uint64_t MatchLength = 0; + + uint64_t MatchStartOffset = 0; + }; + + class MultipartBoundaryParser + { + public: + std::vector<HttpClient::Response::MultipartBoundary> Boundaries; + + MultipartBoundaryParser(); + bool Init(const std::string_view ContentTypeHeaderValue); + void ParseInput(std::string_view data); + + private: + IncrementalStringMatcher BoundaryBeginMatcher; + IncrementalStringMatcher BoundaryEndMatcher; + IncrementalStringMatcher HeaderEndMatcher; + + ExtendableStringBuilder<64> BoundaryHeader; + uint64_t PayloadOffset = 0; + }; + + std::pair<std::string_view, std::string_view> GetHeaderKeyAndValue(std::string_view HeaderString); + std::pair<uint64_t, uint64_t> ParseContentRange(std::string_view Value); + } // namespace detail } // namespace zen diff --git a/src/zenhttp/clients/httpclientcpr.cpp b/src/zenhttp/clients/httpclientcpr.cpp index 5d92b3b6b..a52b8f74b 100644 --- a/src/zenhttp/clients/httpclientcpr.cpp +++ b/src/zenhttp/clients/httpclientcpr.cpp @@ -7,11 +7,18 @@ #include <zencore/compactbinarypackage.h> #include <zencore/compactbinaryutil.h> #include <zencore/compress.h> +#include <zencore/filesystem.h> #include <zencore/iobuffer.h> #include <zencore/iohash.h> #include <zencore/session.h> #include <zencore/stream.h> #include <zenhttp/packageformat.h> +#include <algorithm> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <cpr/ssl_options.h> +#include <cpr/unix_socket.h> +ZEN_THIRD_PARTY_INCLUDES_END namespace zen { @@ -23,69 +30,42 @@ CreateCprHttpClient(std::string_view BaseUri, const HttpClientSettings& Connecti static std::atomic<uint32_t> HttpClientRequestIdCounter{0}; -// If we want to support different HTTP client implementations then we'll need to make this more abstract +////////////////////////////////////////////////////////////////////////// -HttpClientError::ResponseClass -HttpClientError::GetResponseClass() const +static HttpClientErrorCode +MapCprError(cpr::ErrorCode Code) { - if ((cpr::ErrorCode)m_Error != cpr::ErrorCode::OK) - { - switch ((cpr::ErrorCode)m_Error) - { - case cpr::ErrorCode::CONNECTION_FAILURE: - return ResponseClass::kHttpCantConnectError; - case cpr::ErrorCode::HOST_RESOLUTION_FAILURE: - case cpr::ErrorCode::PROXY_RESOLUTION_FAILURE: - return ResponseClass::kHttpNoHost; - case cpr::ErrorCode::INTERNAL_ERROR: - case cpr::ErrorCode::NETWORK_RECEIVE_ERROR: - case cpr::ErrorCode::NETWORK_SEND_FAILURE: - case cpr::ErrorCode::OPERATION_TIMEDOUT: - return ResponseClass::kHttpTimeout; - case cpr::ErrorCode::SSL_CONNECT_ERROR: - case cpr::ErrorCode::SSL_LOCAL_CERTIFICATE_ERROR: - case cpr::ErrorCode::SSL_REMOTE_CERTIFICATE_ERROR: - case cpr::ErrorCode::SSL_CACERT_ERROR: - case cpr::ErrorCode::GENERIC_SSL_ERROR: - return ResponseClass::kHttpSLLError; - default: - return ResponseClass::kHttpOtherClientError; - } - } - else if (IsHttpSuccessCode(m_ResponseCode)) + switch (Code) { - return ResponseClass::kSuccess; - } - else - { - switch (m_ResponseCode) - { - case HttpResponseCode::Unauthorized: - return ResponseClass::kHttpUnauthorized; - case HttpResponseCode::NotFound: - return ResponseClass::kHttpNotFound; - case HttpResponseCode::Forbidden: - return ResponseClass::kHttpForbidden; - case HttpResponseCode::Conflict: - return ResponseClass::kHttpConflict; - case HttpResponseCode::InternalServerError: - return ResponseClass::kHttpInternalServerError; - case HttpResponseCode::ServiceUnavailable: - return ResponseClass::kHttpServiceUnavailable; - case HttpResponseCode::BadGateway: - return ResponseClass::kHttpBadGateway; - case HttpResponseCode::GatewayTimeout: - return ResponseClass::kHttpGatewayTimeout; - default: - if (m_ResponseCode >= HttpResponseCode::InternalServerError) - { - return ResponseClass::kHttpOtherServerError; - } - else - { - return ResponseClass::kHttpOtherClientError; - } - } + case cpr::ErrorCode::OK: + return HttpClientErrorCode::kOK; + case cpr::ErrorCode::CONNECTION_FAILURE: + return HttpClientErrorCode::kConnectionFailure; + case cpr::ErrorCode::HOST_RESOLUTION_FAILURE: + return HttpClientErrorCode::kHostResolutionFailure; + case cpr::ErrorCode::PROXY_RESOLUTION_FAILURE: + return HttpClientErrorCode::kProxyResolutionFailure; + case cpr::ErrorCode::INTERNAL_ERROR: + return HttpClientErrorCode::kInternalError; + case cpr::ErrorCode::NETWORK_RECEIVE_ERROR: + return HttpClientErrorCode::kNetworkReceiveError; + case cpr::ErrorCode::NETWORK_SEND_FAILURE: + return HttpClientErrorCode::kNetworkSendFailure; + case cpr::ErrorCode::OPERATION_TIMEDOUT: + return HttpClientErrorCode::kOperationTimedOut; + case cpr::ErrorCode::SSL_CONNECT_ERROR: + return HttpClientErrorCode::kSSLConnectError; + case cpr::ErrorCode::SSL_LOCAL_CERTIFICATE_ERROR: + case cpr::ErrorCode::SSL_REMOTE_CERTIFICATE_ERROR: + return HttpClientErrorCode::kSSLCertificateError; + case cpr::ErrorCode::SSL_CACERT_ERROR: + return HttpClientErrorCode::kSSLCACertError; + case cpr::ErrorCode::GENERIC_SSL_ERROR: + return HttpClientErrorCode::kGenericSSLError; + case cpr::ErrorCode::REQUEST_CANCELLED: + return HttpClientErrorCode::kRequestCancelled; + default: + return HttpClientErrorCode::kOtherError; } } @@ -149,6 +129,18 @@ CprHttpClient::CprHttpClient(std::string_view BaseUri, { } +bool +CprHttpClient::ShouldLogErrorCode(HttpResponseCode ResponseCode) const +{ + if (m_CheckIfAbortFunction && m_CheckIfAbortFunction()) + { + // Quiet + return false; + } + const auto& Expected = m_ConnectionSettings.ExpectedErrorCodes; + return std::find(Expected.begin(), Expected.end(), ResponseCode) == Expected.end(); +} + CprHttpClient::~CprHttpClient() { ZEN_TRACE_CPU("CprHttpClient::~CprHttpClient"); @@ -162,10 +154,11 @@ CprHttpClient::~CprHttpClient() } HttpClient::Response -CprHttpClient::ResponseWithPayload(std::string_view SessionId, - cpr::Response&& HttpResponse, - const HttpResponseCode WorkResponseCode, - IoBuffer&& Payload) +CprHttpClient::ResponseWithPayload(std::string_view SessionId, + cpr::Response&& HttpResponse, + const HttpResponseCode WorkResponseCode, + IoBuffer&& Payload, + std::vector<HttpClient::Response::MultipartBoundary>&& BoundaryPositions) { // This ends up doing a memcpy, would be good to get rid of it by streaming results // into buffer directly @@ -174,30 +167,37 @@ CprHttpClient::ResponseWithPayload(std::string_view SessionId, if (auto It = HttpResponse.header.find("Content-Type"); It != HttpResponse.header.end()) { const HttpContentType ContentType = ParseContentType(It->second); - ResponseBuffer.SetContentType(ContentType); } - const bool Quiet = m_CheckIfAbortFunction && m_CheckIfAbortFunction(); - - if (!Quiet) + if (!IsHttpSuccessCode(WorkResponseCode) && WorkResponseCode != HttpResponseCode::NotFound) { - if (!IsHttpSuccessCode(WorkResponseCode) && WorkResponseCode != HttpResponseCode::NotFound) + if (ShouldLogErrorCode(WorkResponseCode)) { ZEN_WARN("HttpClient request failed (session: {}): {}", SessionId, HttpResponse); } } + std::sort(BoundaryPositions.begin(), + BoundaryPositions.end(), + [](const HttpClient::Response::MultipartBoundary& Lhs, const HttpClient::Response::MultipartBoundary& Rhs) { + return Lhs.RangeOffset < Rhs.RangeOffset; + }); + return HttpClient::Response{.StatusCode = WorkResponseCode, .ResponsePayload = std::move(ResponseBuffer), .Header = HttpClient::KeyValueMap(HttpResponse.header.begin(), HttpResponse.header.end()), .UploadedBytes = gsl::narrow<int64_t>(HttpResponse.uploaded_bytes), .DownloadedBytes = gsl::narrow<int64_t>(HttpResponse.downloaded_bytes), - .ElapsedSeconds = HttpResponse.elapsed}; + .ElapsedSeconds = HttpResponse.elapsed, + .Ranges = std::move(BoundaryPositions)}; } HttpClient::Response -CprHttpClient::CommonResponse(std::string_view SessionId, cpr::Response&& HttpResponse, IoBuffer&& Payload) +CprHttpClient::CommonResponse(std::string_view SessionId, + cpr::Response&& HttpResponse, + IoBuffer&& Payload, + std::vector<HttpClient::Response::MultipartBoundary>&& BoundaryPositions) { const HttpResponseCode WorkResponseCode = HttpResponseCode(HttpResponse.status_code); if (HttpResponse.error) @@ -221,8 +221,8 @@ CprHttpClient::CommonResponse(std::string_view SessionId, cpr::Response&& HttpRe .UploadedBytes = gsl::narrow<int64_t>(HttpResponse.uploaded_bytes), .DownloadedBytes = gsl::narrow<int64_t>(HttpResponse.downloaded_bytes), .ElapsedSeconds = HttpResponse.elapsed, - .Error = HttpClient::ErrorContext{.ErrorCode = gsl::narrow<int>(HttpResponse.error.code), - .ErrorMessage = HttpResponse.error.message}}; + .Error = + HttpClient::ErrorContext{.ErrorCode = MapCprError(HttpResponse.error.code), .ErrorMessage = HttpResponse.error.message}}; } if (WorkResponseCode == HttpResponseCode::NoContent || (HttpResponse.text.empty() && !Payload)) @@ -235,7 +235,7 @@ CprHttpClient::CommonResponse(std::string_view SessionId, cpr::Response&& HttpRe } else { - return ResponseWithPayload(SessionId, std::move(HttpResponse), WorkResponseCode, std::move(Payload)); + return ResponseWithPayload(SessionId, std::move(HttpResponse), WorkResponseCode, std::move(Payload), std::move(BoundaryPositions)); } } @@ -346,8 +346,7 @@ CprHttpClient::DoWithRetry(std::string_view SessionId, } Sleep(100 * (Attempt + 1)); Attempt++; - const bool Quiet = m_CheckIfAbortFunction && m_CheckIfAbortFunction(); - if (!Quiet) + if (ShouldLogErrorCode(HttpResponseCode(Result.status_code))) { ZEN_INFO("{} Attempt {}/{}", CommonResponse(SessionId, std::move(Result), {}).ErrorMessage("Retry"), @@ -385,8 +384,7 @@ CprHttpClient::DoWithRetry(std::string_view SessionId, } Sleep(100 * (Attempt + 1)); Attempt++; - const bool Quiet = m_CheckIfAbortFunction && m_CheckIfAbortFunction(); - if (!Quiet) + if (ShouldLogErrorCode(HttpResponseCode(Result.status_code))) { ZEN_INFO("{} Attempt {}/{}", CommonResponse(SessionId, std::move(Result), {}).ErrorMessage("Retry"), @@ -492,6 +490,10 @@ CprHttpClient::AllocSession(const std::string_view BaseUrl, { CprSession->UpdateHeader({{"UE-Session", std::string(SessionId)}}); } + if (ConnectionSettings.ForbidReuseConnection) + { + CprSession->UpdateHeader({{"Connection", "close"}}); + } if (AccessToken) { CprSession->UpdateHeader({{"Authorization", AccessToken->Value}}); @@ -510,6 +512,26 @@ CprHttpClient::AllocSession(const std::string_view BaseUrl, CprSession->SetParameters({}); } + if (!ConnectionSettings.UnixSocketPath.empty()) + { + CprSession->SetUnixSocket(cpr::UnixSocket(PathToUtf8(ConnectionSettings.UnixSocketPath))); + } + + if (ConnectionSettings.InsecureSsl || !ConnectionSettings.CaBundlePath.empty()) + { + cpr::SslOptions SslOpts; + if (ConnectionSettings.InsecureSsl) + { + SslOpts.SetOption(cpr::ssl::VerifyHost{false}); + SslOpts.SetOption(cpr::ssl::VerifyPeer{false}); + } + if (!ConnectionSettings.CaBundlePath.empty()) + { + SslOpts.SetOption(cpr::ssl::CaInfo{ConnectionSettings.CaBundlePath}); + } + CprSession->SetSslOptions(SslOpts); + } + ExtendableStringBuilder<128> UrlBuffer; UrlBuffer << BaseUrl << ResourcePath; CprSession->SetUrl(UrlBuffer.c_str()); @@ -621,7 +643,7 @@ CprHttpClient::TransactPackage(std::string_view Url, CbPackage Package, const Ke ResponseBuffer.SetContentType(ContentType); } - return {.StatusCode = HttpResponseCode(FilterResponse.status_code), .ResponsePayload = ResponseBuffer}; + return {.StatusCode = HttpResponseCode(FilterResponse.status_code), .ResponsePayload = std::move(ResponseBuffer)}; } ////////////////////////////////////////////////////////////////////////// @@ -774,22 +796,97 @@ CprHttpClient::Post(std::string_view Url, const IoBuffer& Payload, ZenContentTyp } CprHttpClient::Response -CprHttpClient::Post(std::string_view Url, CbObject Payload, const KeyValueMap& AdditionalHeader) +CprHttpClient::Post(std::string_view Url, + CbObject Payload, + const KeyValueMap& AdditionalHeader, + const std::filesystem::path& TempFolderPath) { ZEN_TRACE_CPU("CprHttpClient::PostObjectPayload"); - return CommonResponse( + std::string PayloadString; + std::unique_ptr<detail::TempPayloadFile> PayloadFile; + + cpr::Response Response = DoWithRetry( m_SessionId, - DoWithRetry(m_SessionId, - [&]() { - Session Sess = - AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); + [&]() { + PayloadString.clear(); + PayloadFile.reset(); - Sess->SetBody(AsCprBody(Payload)); - Sess->UpdateHeader({HeaderContentType(ZenContentType::kCbObject)}); - return Sess.Post(); - }), - {}); + Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); + + Sess->SetBody(AsCprBody(Payload)); + Sess->UpdateHeader({HeaderContentType(ZenContentType::kCbObject)}); + + std::vector<std::pair<std::string, std::string>> ReceivedHeaders; + auto HeaderCallback = [&](std::string header, intptr_t) { + const std::pair<std::string_view, std::string_view> Header = detail::GetHeaderKeyAndValue(header); + if (StrCaseCompare(std::string(Header.first).c_str(), "Content-Length") == 0) + { + std::optional<size_t> ContentLength = ParseInt<size_t>(Header.second); + if (ContentLength.has_value()) + { + if (!TempFolderPath.empty() && ContentLength.value() > m_ConnectionSettings.MaximumInMemoryDownloadSize) + { + PayloadFile = std::make_unique<detail::TempPayloadFile>(); + std::error_code Ec = PayloadFile->Open(TempFolderPath, ContentLength.value()); + if (Ec) + { + ZEN_WARN("Failed to create temp file in '{}' for HttpClient::Post. Reason: {}", + TempFolderPath.string(), + Ec.message()); + PayloadFile.reset(); + } + } + else + { + PayloadString.reserve(ContentLength.value()); + } + } + } + if (!Header.first.empty()) + { + ReceivedHeaders.emplace_back(std::move(Header)); + } + return 1; + }; + + auto DownloadCallback = [&](std::string data, intptr_t) { + if (m_CheckIfAbortFunction && m_CheckIfAbortFunction()) + { + return false; + } + + if (PayloadFile) + { + ZEN_ASSERT(PayloadString.empty()); + std::error_code Ec = PayloadFile->Write(data); + if (Ec) + { + ZEN_WARN("Failed to write to temp file in '{}' for HttpClient::Post. Reason: {}", + TempFolderPath.string(), + Ec.message()); + return false; + } + } + else + { + PayloadString.append(data); + } + return true; + }; + cpr::Response Response = Sess.Post({}, cpr::WriteCallback{DownloadCallback}, cpr::HeaderCallback{HeaderCallback}); + for (const std::pair<std::string, std::string>& H : ReceivedHeaders) + { + Response.header.insert_or_assign(H.first, H.second); + } + if (!PayloadString.empty()) + { + Response.text = std::move(PayloadString); + } + return Response; + }, + PayloadFile); + return CommonResponse(m_SessionId, std::move(Response), PayloadFile ? PayloadFile->DetachToIoBuffer() : IoBuffer{}); } CprHttpClient::Response @@ -896,236 +993,292 @@ CprHttpClient::Download(std::string_view Url, const std::filesystem::path& TempF std::string PayloadString; std::unique_ptr<detail::TempPayloadFile> PayloadFile; - cpr::Response Response = DoWithRetry( - m_SessionId, - [&]() { - auto GetHeader = [&](std::string header) -> std::pair<std::string, std::string> { - size_t DelimiterPos = header.find(':'); - if (DelimiterPos != std::string::npos) - { - std::string Key = header.substr(0, DelimiterPos); - constexpr AsciiSet WhitespaceCharacters(" \v\f\t\r\n"); - Key = AsciiSet::TrimSuffixWith(Key, WhitespaceCharacters); - Key = AsciiSet::TrimPrefixWith(Key, WhitespaceCharacters); - - std::string Value = header.substr(DelimiterPos + 1); - Value = AsciiSet::TrimSuffixWith(Value, WhitespaceCharacters); - Value = AsciiSet::TrimPrefixWith(Value, WhitespaceCharacters); - - return std::make_pair(Key, Value); - } - return std::make_pair(header, ""); - }; - - auto DownloadCallback = [&](std::string data, intptr_t) { - if (m_CheckIfAbortFunction && m_CheckIfAbortFunction()) - { - return false; - } - if (PayloadFile) - { - ZEN_ASSERT(PayloadString.empty()); - std::error_code Ec = PayloadFile->Write(data); - if (Ec) - { - ZEN_WARN("Failed to write to temp file in '{}' for HttpClient::Download. Reason: {}", - TempFolderPath.string(), - Ec.message()); - return false; - } - } - else - { - PayloadString.append(data); - } - return true; - }; - - uint64_t RequestedContentLength = (uint64_t)-1; - if (auto RangeIt = AdditionalHeader.Entries.find("Range"); RangeIt != AdditionalHeader.Entries.end()) - { - if (RangeIt->second.starts_with("bytes")) - { - size_t RangeStartPos = RangeIt->second.find('=', 5); - if (RangeStartPos != std::string::npos) - { - RangeStartPos++; - size_t RangeSplitPos = RangeIt->second.find('-', RangeStartPos); - if (RangeSplitPos != std::string::npos) - { - std::optional<size_t> RequestedRangeStart = - ParseInt<size_t>(RangeIt->second.substr(RangeStartPos, RangeSplitPos - RangeStartPos)); - std::optional<size_t> RequestedRangeEnd = ParseInt<size_t>(RangeIt->second.substr(RangeStartPos + 1)); - if (RequestedRangeStart.has_value() && RequestedRangeEnd.has_value()) - { - RequestedContentLength = RequestedRangeEnd.value() - 1; - } - } - } - } - } - - cpr::Response Response; - { - std::vector<std::pair<std::string, std::string>> ReceivedHeaders; - auto HeaderCallback = [&](std::string header, intptr_t) { - std::pair<std::string, std::string> Header = GetHeader(header); - if (Header.first == "Content-Length"sv) - { - std::optional<size_t> ContentLength = ParseInt<size_t>(Header.second); - if (ContentLength.has_value()) - { - if (ContentLength.value() > m_ConnectionSettings.MaximumInMemoryDownloadSize) - { - PayloadFile = std::make_unique<detail::TempPayloadFile>(); - std::error_code Ec = PayloadFile->Open(TempFolderPath, ContentLength.value()); - if (Ec) - { - ZEN_WARN("Failed to create temp file in '{}' for HttpClient::Download. Reason: {}", - TempFolderPath.string(), - Ec.message()); - PayloadFile.reset(); - } - } - else - { - PayloadString.reserve(ContentLength.value()); - } - } - } - if (!Header.first.empty()) - { - ReceivedHeaders.emplace_back(std::move(Header)); - } - return 1; - }; - - Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); - Response = Sess.Download(cpr::WriteCallback{DownloadCallback}, cpr::HeaderCallback{HeaderCallback}); - for (const std::pair<std::string, std::string>& H : ReceivedHeaders) - { - Response.header.insert_or_assign(H.first, H.second); - } - } - if (m_ConnectionSettings.AllowResume) - { - auto SupportsRanges = [](const cpr::Response& Response) -> bool { - if (Response.header.find("Content-Range") != Response.header.end()) - { - return true; - } - if (auto It = Response.header.find("Accept-Ranges"); It != Response.header.end()) - { - return It->second == "bytes"sv; - } - return false; - }; - - auto ShouldResume = [&SupportsRanges](const cpr::Response& Response) -> bool { - if (ShouldRetry(Response)) - { - return SupportsRanges(Response); - } - return false; - }; - - if (ShouldResume(Response)) - { - auto It = Response.header.find("Content-Length"); - if (It != Response.header.end()) - { - uint64_t ContentLength = RequestedContentLength; - if (ContentLength == uint64_t(-1)) - { - if (auto ParsedContentLength = ParseInt<int64_t>(It->second); ParsedContentLength.has_value()) - { - ContentLength = ParsedContentLength.value(); - } - } - - std::vector<std::pair<std::string, std::string>> ReceivedHeaders; - - auto HeaderCallback = [&](std::string header, intptr_t) { - std::pair<std::string, std::string> Header = GetHeader(header); - if (!Header.first.empty()) - { - ReceivedHeaders.emplace_back(std::move(Header)); - } - - if (Header.first == "Content-Range"sv) - { - if (Header.second.starts_with("bytes "sv)) - { - size_t RangeStartEnd = Header.second.find('-', 6); - if (RangeStartEnd != std::string::npos) - { - const auto Start = ParseInt<uint64_t>(Header.second.substr(6, RangeStartEnd - 6)); - if (Start) - { - uint64_t DownloadedSize = PayloadFile ? PayloadFile->GetSize() : PayloadString.length(); - if (Start.value() == DownloadedSize) - { - return 1; - } - else if (Start.value() > DownloadedSize) - { - return 0; - } - if (PayloadFile) - { - PayloadFile->ResetWritePos(Start.value()); - } - else - { - PayloadString = PayloadString.substr(0, Start.value()); - } - return 1; - } - } - } - return 0; - } - return 1; - }; - - KeyValueMap HeadersWithRange(AdditionalHeader); - do - { - uint64_t DownloadedSize = PayloadFile ? PayloadFile->GetSize() : PayloadString.length(); - - std::string Range = fmt::format("bytes={}-{}", DownloadedSize, DownloadedSize + ContentLength - 1); - if (auto RangeIt = HeadersWithRange.Entries.find("Range"); RangeIt != HeadersWithRange.Entries.end()) - { - if (RangeIt->second == Range) - { - // If we didn't make any progress, abort - break; - } - } - HeadersWithRange.Entries.insert_or_assign("Range", Range); - - Session Sess = - AllocSession(m_BaseUri, Url, m_ConnectionSettings, HeadersWithRange, {}, m_SessionId, GetAccessToken()); - Response = Sess.Download(cpr::WriteCallback{DownloadCallback}, cpr::HeaderCallback{HeaderCallback}); - for (const std::pair<std::string, std::string>& H : ReceivedHeaders) - { - Response.header.insert_or_assign(H.first, H.second); - } - ReceivedHeaders.clear(); - } while (ShouldResume(Response)); - } - } - } - - if (!PayloadString.empty()) - { - Response.text = std::move(PayloadString); - } - return Response; - }, - PayloadFile); - return CommonResponse(m_SessionId, std::move(Response), PayloadFile ? PayloadFile->DetachToIoBuffer() : IoBuffer{}); + HttpContentType ContentType = HttpContentType::kUnknownContentType; + detail::MultipartBoundaryParser BoundaryParser; + bool IsMultiRangeResponse = false; + + cpr::Response Response = DoWithRetry( + m_SessionId, + [&]() { + // Reset state from any previous attempt + PayloadString.clear(); + PayloadFile.reset(); + BoundaryParser.Boundaries.clear(); + ContentType = HttpContentType::kUnknownContentType; + IsMultiRangeResponse = false; + + auto DownloadCallback = [&](std::string data, intptr_t) { + if (m_CheckIfAbortFunction && m_CheckIfAbortFunction()) + { + return false; + } + + if (IsMultiRangeResponse) + { + BoundaryParser.ParseInput(data); + } + + if (PayloadFile) + { + ZEN_ASSERT(PayloadString.empty()); + std::error_code Ec = PayloadFile->Write(data); + if (Ec) + { + ZEN_WARN("Failed to write to temp file in '{}' for HttpClient::Download. Reason: {}", + TempFolderPath.string(), + Ec.message()); + return false; + } + } + else + { + PayloadString.append(data); + } + return true; + }; + + uint64_t RequestedContentLength = (uint64_t)-1; + if (auto RangeIt = AdditionalHeader.Entries.find("Range"); RangeIt != AdditionalHeader.Entries.end()) + { + if (RangeIt->second.starts_with("bytes")) + { + std::string_view RangeValue(RangeIt->second); + size_t RangeStartPos = RangeValue.find('=', 5); + if (RangeStartPos != std::string::npos) + { + RangeStartPos++; + while (RangeStartPos < RangeValue.length() && RangeValue[RangeStartPos] == ' ') + { + RangeStartPos++; + } + RequestedContentLength = 0; + + while (RangeStartPos < RangeValue.length()) + { + size_t RangeEnd = RangeValue.find_first_of(", \r\n", RangeStartPos); + if (RangeEnd == std::string::npos) + { + RangeEnd = RangeValue.length(); + } + + std::string_view RangeString = RangeValue.substr(RangeStartPos, RangeEnd - RangeStartPos); + size_t RangeSplitPos = RangeString.find('-'); + if (RangeSplitPos != std::string::npos) + { + std::optional<size_t> RequestedRangeStart = ParseInt<size_t>(RangeString.substr(0, RangeSplitPos)); + std::optional<size_t> RequestedRangeEnd = ParseInt<size_t>(RangeString.substr(RangeSplitPos + 1)); + if (RequestedRangeStart.has_value() && RequestedRangeEnd.has_value()) + { + RequestedContentLength += RequestedRangeEnd.value() - RequestedRangeStart.value() + 1; + } + } + RangeStartPos = RangeEnd; + while (RangeStartPos != RangeValue.length() && + (RangeValue[RangeStartPos] == ',' || RangeValue[RangeStartPos] == ' ')) + { + RangeStartPos++; + } + } + } + } + } + + cpr::Response Response; + { + std::vector<std::pair<std::string, std::string>> ReceivedHeaders; + auto HeaderCallback = [&](std::string header, intptr_t) { + if (RequestedContentLength != (uint64_t)-1 && RequestedContentLength > m_ConnectionSettings.MaximumInMemoryDownloadSize) + { + ZEN_DEBUG("Multirange request"); + } + const std::pair<std::string_view, std::string_view> Header = detail::GetHeaderKeyAndValue(header); + const std::string Key(Header.first); + if (StrCaseCompare(Key.c_str(), "Content-Length") == 0) + { + std::optional<size_t> ContentLength = ParseInt<size_t>(Header.second); + if (ContentLength.has_value()) + { + if (!TempFolderPath.empty() && ContentLength.value() > m_ConnectionSettings.MaximumInMemoryDownloadSize) + { + PayloadFile = std::make_unique<detail::TempPayloadFile>(); + std::error_code Ec = PayloadFile->Open(TempFolderPath, ContentLength.value()); + if (Ec) + { + ZEN_WARN("Failed to create temp file in '{}' for HttpClient::Download. Reason: {}", + TempFolderPath.string(), + Ec.message()); + PayloadFile.reset(); + } + } + else + { + PayloadString.reserve(ContentLength.value()); + } + } + } + else if (StrCaseCompare(Key.c_str(), "Content-Type") == 0) + { + IsMultiRangeResponse = BoundaryParser.Init(Header.second); + if (!IsMultiRangeResponse) + { + ContentType = ParseContentType(Header.second); + } + } + else if (StrCaseCompare(Key.c_str(), "Content-Range") == 0) + { + if (!IsMultiRangeResponse) + { + std::pair<uint64_t, uint64_t> Range = detail::ParseContentRange(Header.second); + if (Range.second != 0) + { + BoundaryParser.Boundaries.push_back(HttpClient::Response::MultipartBoundary{.OffsetInPayload = 0, + .RangeOffset = Range.first, + .RangeLength = Range.second, + .ContentType = ContentType}); + } + } + } + if (!Header.first.empty()) + { + ReceivedHeaders.emplace_back(std::move(Header)); + } + return 1; + }; + + Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); + Response = Sess.Download(cpr::WriteCallback{DownloadCallback}, cpr::HeaderCallback{HeaderCallback}); + for (const std::pair<std::string, std::string>& H : ReceivedHeaders) + { + Response.header.insert_or_assign(H.first, H.second); + } + } + if (m_ConnectionSettings.AllowResume) + { + auto SupportsRanges = [](const cpr::Response& Response) -> bool { + if (Response.header.find("Content-Range") != Response.header.end()) + { + return true; + } + if (auto It = Response.header.find("Accept-Ranges"); It != Response.header.end()) + { + return It->second == "bytes"sv; + } + return false; + }; + + auto ShouldResume = [&SupportsRanges, &IsMultiRangeResponse](const cpr::Response& Response) -> bool { + if (IsMultiRangeResponse) + { + return false; + } + if (ShouldRetry(Response)) + { + return SupportsRanges(Response); + } + return false; + }; + + if (ShouldResume(Response)) + { + auto It = Response.header.find("Content-Length"); + if (It != Response.header.end()) + { + uint64_t ContentLength = RequestedContentLength; + if (ContentLength == uint64_t(-1)) + { + if (auto ParsedContentLength = ParseInt<int64_t>(It->second); ParsedContentLength.has_value()) + { + ContentLength = ParsedContentLength.value(); + } + } + + std::vector<std::pair<std::string, std::string>> ReceivedHeaders; + + auto HeaderCallback = [&](std::string header, intptr_t) { + const std::pair<std::string_view, std::string_view> Header = detail::GetHeaderKeyAndValue(header); + if (!Header.first.empty()) + { + ReceivedHeaders.emplace_back(std::move(Header)); + } + + if (StrCaseCompare(std::string(Header.first).c_str(), "Content-Range") == 0) + { + if (Header.second.starts_with("bytes "sv)) + { + size_t RangeStartEnd = Header.second.find('-', 6); + if (RangeStartEnd != std::string::npos) + { + const auto Start = ParseInt<uint64_t>(Header.second.substr(6, RangeStartEnd - 6)); + if (Start) + { + uint64_t DownloadedSize = PayloadFile ? PayloadFile->GetSize() : PayloadString.length(); + if (Start.value() == DownloadedSize) + { + return 1; + } + else if (Start.value() > DownloadedSize) + { + return 0; + } + if (PayloadFile) + { + PayloadFile->ResetWritePos(Start.value()); + } + else + { + PayloadString = PayloadString.substr(0, Start.value()); + } + return 1; + } + } + } + return 0; + } + return 1; + }; + + KeyValueMap HeadersWithRange(AdditionalHeader); + do + { + uint64_t DownloadedSize = PayloadFile ? PayloadFile->GetSize() : PayloadString.length(); + + std::string Range = fmt::format("bytes={}-{}", DownloadedSize, DownloadedSize + ContentLength - 1); + if (auto RangeIt = HeadersWithRange.Entries.find("Range"); RangeIt != HeadersWithRange.Entries.end()) + { + if (RangeIt->second == Range) + { + // If we didn't make any progress, abort + break; + } + } + HeadersWithRange.Entries.insert_or_assign("Range", Range); + + Session Sess = + AllocSession(m_BaseUri, Url, m_ConnectionSettings, HeadersWithRange, {}, m_SessionId, GetAccessToken()); + Response = Sess.Download(cpr::WriteCallback{DownloadCallback}, cpr::HeaderCallback{HeaderCallback}); + for (const std::pair<std::string, std::string>& H : ReceivedHeaders) + { + Response.header.insert_or_assign(H.first, H.second); + } + ReceivedHeaders.clear(); + } while (ShouldResume(Response)); + } + } + } + + if (!PayloadString.empty()) + { + Response.text = std::move(PayloadString); + } + return Response; + }, + PayloadFile); + + return CommonResponse(m_SessionId, + std::move(Response), + PayloadFile ? PayloadFile->DetachToIoBuffer() : IoBuffer{}, + std::move(BoundaryParser.Boundaries)); } } // namespace zen diff --git a/src/zenhttp/clients/httpclientcpr.h b/src/zenhttp/clients/httpclientcpr.h index 40af53b5d..009e6fb7a 100644 --- a/src/zenhttp/clients/httpclientcpr.h +++ b/src/zenhttp/clients/httpclientcpr.h @@ -38,7 +38,10 @@ public: const IoBuffer& Payload, ZenContentType ContentType, const KeyValueMap& AdditionalHeader = {}) override; - [[nodiscard]] virtual Response Post(std::string_view Url, CbObject Payload, const KeyValueMap& AdditionalHeader = {}) override; + [[nodiscard]] virtual Response Post(std::string_view Url, + CbObject Payload, + const KeyValueMap& AdditionalHeader = {}, + const std::filesystem::path& TempFolderPath = {}) override; [[nodiscard]] virtual Response Post(std::string_view Url, CbPackage Payload, const KeyValueMap& AdditionalHeader = {}) override; [[nodiscard]] virtual Response Post(std::string_view Url, const CompositeBuffer& Payload, @@ -104,15 +107,27 @@ private: CprSession->SetReadCallback({}); return Result; } - inline cpr::Response Post(std::optional<cpr::ReadCallback>&& Read = {}) + inline cpr::Response Post(std::optional<cpr::ReadCallback>&& Read = {}, + std::optional<cpr::WriteCallback>&& Write = {}, + std::optional<cpr::HeaderCallback>&& Header = {}) { ZEN_TRACE_CPU("HttpClient::Impl::Post"); if (Read) { CprSession->SetReadCallback(std::move(Read.value())); } + if (Write) + { + CprSession->SetWriteCallback(std::move(Write.value())); + } + if (Header) + { + CprSession->SetHeaderCallback(std::move(Header.value())); + } cpr::Response Result = CprSession->Post(); ZEN_TRACE("POST {}", Result); + CprSession->SetHeaderCallback({}); + CprSession->SetWriteCallback({}); CprSession->SetReadCallback({}); return Result; } @@ -155,14 +170,19 @@ private: std::function<cpr::Response()>&& Func, std::function<bool(cpr::Response& Result)>&& Validate = [](cpr::Response&) { return true; }); + bool ShouldLogErrorCode(HttpResponseCode ResponseCode) const; bool ValidatePayload(cpr::Response& Response, std::unique_ptr<detail::TempPayloadFile>& PayloadFile); - HttpClient::Response CommonResponse(std::string_view SessionId, cpr::Response&& HttpResponse, IoBuffer&& Payload); + HttpClient::Response CommonResponse(std::string_view SessionId, + cpr::Response&& HttpResponse, + IoBuffer&& Payload, + std::vector<HttpClient::Response::MultipartBoundary>&& BoundaryPositions = {}); - HttpClient::Response ResponseWithPayload(std::string_view SessionId, - cpr::Response&& HttpResponse, - const HttpResponseCode WorkResponseCode, - IoBuffer&& Payload); + HttpClient::Response ResponseWithPayload(std::string_view SessionId, + cpr::Response&& HttpResponse, + const HttpResponseCode WorkResponseCode, + IoBuffer&& Payload, + std::vector<HttpClient::Response::MultipartBoundary>&& BoundaryPositions); }; } // namespace zen diff --git a/src/zenhttp/clients/httpclientcurl.cpp b/src/zenhttp/clients/httpclientcurl.cpp new file mode 100644 index 000000000..ec9b7bac6 --- /dev/null +++ b/src/zenhttp/clients/httpclientcurl.cpp @@ -0,0 +1,1816 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "httpclientcurl.h" + +#include <zencore/compactbinary.h> +#include <zencore/compactbinarybuilder.h> +#include <zencore/compactbinarypackage.h> +#include <zencore/compactbinaryutil.h> +#include <zencore/compress.h> +#include <zencore/except.h> +#include <zencore/filesystem.h> +#include <zencore/iobuffer.h> +#include <zencore/iohash.h> +#include <zencore/session.h> +#include <zencore/stream.h> +#include <zencore/string.h> +#include <zenhttp/packageformat.h> +#include <algorithm> + +namespace zen { + +HttpClientBase* +CreateCurlHttpClient(std::string_view BaseUri, const HttpClientSettings& ConnectionSettings, std::function<bool()>&& CheckIfAbortFunction) +{ + return new CurlHttpClient(BaseUri, ConnectionSettings, std::move(CheckIfAbortFunction)); +} + +static std::atomic<uint32_t> CurlHttpClientRequestIdCounter{0}; + +////////////////////////////////////////////////////////////////////////// + +static HttpClientErrorCode +MapCurlError(CURLcode Code) +{ + switch (Code) + { + case CURLE_OK: + return HttpClientErrorCode::kOK; + case CURLE_COULDNT_CONNECT: + return HttpClientErrorCode::kConnectionFailure; + case CURLE_COULDNT_RESOLVE_HOST: + return HttpClientErrorCode::kHostResolutionFailure; + case CURLE_COULDNT_RESOLVE_PROXY: + return HttpClientErrorCode::kProxyResolutionFailure; + case CURLE_RECV_ERROR: + return HttpClientErrorCode::kNetworkReceiveError; + case CURLE_SEND_ERROR: + return HttpClientErrorCode::kNetworkSendFailure; + case CURLE_OPERATION_TIMEDOUT: + return HttpClientErrorCode::kOperationTimedOut; + case CURLE_SSL_CONNECT_ERROR: + return HttpClientErrorCode::kSSLConnectError; + case CURLE_SSL_CERTPROBLEM: + return HttpClientErrorCode::kSSLCertificateError; + case CURLE_PEER_FAILED_VERIFICATION: + return HttpClientErrorCode::kSSLCACertError; + case CURLE_SSL_CIPHER: + case CURLE_SSL_ENGINE_NOTFOUND: + case CURLE_SSL_ENGINE_SETFAILED: + return HttpClientErrorCode::kGenericSSLError; + case CURLE_ABORTED_BY_CALLBACK: + return HttpClientErrorCode::kRequestCancelled; + default: + return HttpClientErrorCode::kOtherError; + } +} + +////////////////////////////////////////////////////////////////////////// +// +// Curl callback helpers + +struct WriteCallbackData +{ + std::string* Body = nullptr; + std::function<bool()>* CheckIfAbortFunction = nullptr; +}; + +static size_t +CurlWriteCallback(char* Ptr, size_t Size, size_t Nmemb, void* UserData) +{ + auto* Data = static_cast<WriteCallbackData*>(UserData); + size_t TotalBytes = Size * Nmemb; + + if (Data->CheckIfAbortFunction && *Data->CheckIfAbortFunction && (*Data->CheckIfAbortFunction)()) + { + return 0; // Signal abort to curl + } + + Data->Body->append(Ptr, TotalBytes); + return TotalBytes; +} + +struct HeaderCallbackData +{ + std::vector<std::pair<std::string, std::string>>* Headers = nullptr; +}; + +// Trims trailing CRLF, splits on the first colon, and trims whitespace from key and value. +// Returns nullopt for blank lines or lines without a colon (e.g. HTTP status lines). +static std::optional<std::pair<std::string_view, std::string_view>> +ParseHeaderLine(std::string_view Line) +{ + while (!Line.empty() && (Line.back() == '\r' || Line.back() == '\n')) + { + Line.remove_suffix(1); + } + + if (Line.empty()) + { + return std::nullopt; + } + + size_t ColonPos = Line.find(':'); + if (ColonPos == std::string_view::npos) + { + return std::nullopt; + } + + std::string_view Key = Line.substr(0, ColonPos); + std::string_view Value = Line.substr(ColonPos + 1); + + while (!Key.empty() && Key.back() == ' ') + { + Key.remove_suffix(1); + } + while (!Value.empty() && Value.front() == ' ') + { + Value.remove_prefix(1); + } + + return std::pair{Key, Value}; +} + +static size_t +CurlHeaderCallback(char* Buffer, size_t Size, size_t Nmemb, void* UserData) +{ + auto* Data = static_cast<HeaderCallbackData*>(UserData); + size_t TotalBytes = Size * Nmemb; + + if (auto Header = ParseHeaderLine(std::string_view(Buffer, TotalBytes))) + { + auto& [Key, Value] = *Header; + Data->Headers->emplace_back(std::string(Key), std::string(Value)); + } + + return TotalBytes; +} + +struct ReadCallbackData +{ + const uint8_t* DataPtr = nullptr; + size_t DataSize = 0; + size_t Offset = 0; + std::function<bool()>* CheckIfAbortFunction = nullptr; +}; + +static size_t +CurlReadCallback(char* Buffer, size_t Size, size_t Nmemb, void* UserData) +{ + auto* Data = static_cast<ReadCallbackData*>(UserData); + size_t MaxRead = Size * Nmemb; + + if (Data->CheckIfAbortFunction && *Data->CheckIfAbortFunction && (*Data->CheckIfAbortFunction)()) + { + return CURL_READFUNC_ABORT; + } + + size_t Remaining = Data->DataSize - Data->Offset; + size_t ToRead = std::min(MaxRead, Remaining); + + if (ToRead > 0) + { + memcpy(Buffer, Data->DataPtr + Data->Offset, ToRead); + Data->Offset += ToRead; + } + + return ToRead; +} + +struct StreamReadCallbackData +{ + detail::CompositeBufferReadStream* Reader = nullptr; + std::function<bool()>* CheckIfAbortFunction = nullptr; +}; + +static size_t +CurlStreamReadCallback(char* Buffer, size_t Size, size_t Nmemb, void* UserData) +{ + auto* Data = static_cast<StreamReadCallbackData*>(UserData); + size_t MaxRead = Size * Nmemb; + + if (Data->CheckIfAbortFunction && *Data->CheckIfAbortFunction && (*Data->CheckIfAbortFunction)()) + { + return CURL_READFUNC_ABORT; + } + + return Data->Reader->Read(Buffer, MaxRead); +} + +struct FileReadCallbackData +{ + detail::BufferedReadFileStream* Buffer = nullptr; + uint64_t TotalSize = 0; + uint64_t Offset = 0; + std::function<bool()>* CheckIfAbortFunction = nullptr; +}; + +static size_t +CurlFileReadCallback(char* Buffer, size_t Size, size_t Nmemb, void* UserData) +{ + auto* Data = static_cast<FileReadCallbackData*>(UserData); + size_t MaxRead = Size * Nmemb; + + if (Data->CheckIfAbortFunction && *Data->CheckIfAbortFunction && (*Data->CheckIfAbortFunction)()) + { + return CURL_READFUNC_ABORT; + } + + size_t Remaining = Data->TotalSize - Data->Offset; + size_t ToRead = std::min(MaxRead, Remaining); + + if (ToRead > 0) + { + Data->Buffer->Read(Buffer, ToRead); + Data->Offset += ToRead; + } + + return ToRead; +} + +static int +CurlDebugCallback(CURL* Handle, curl_infotype Type, char* Data, size_t Size, void* UserPtr) +{ + ZEN_UNUSED(Handle); + LoggerRef LogRef = *static_cast<LoggerRef*>(UserPtr); + auto Log = [&]() -> LoggerRef { return LogRef; }; + + std::string_view DataView(Data, Size); + + // Trim trailing newlines + while (!DataView.empty() && (DataView.back() == '\r' || DataView.back() == '\n')) + { + DataView.remove_suffix(1); + } + + switch (Type) + { + case CURLINFO_TEXT: + if (DataView.find("need more data"sv) == std::string_view::npos) + { + ZEN_INFO("TEXT: {}", DataView); + } + break; + case CURLINFO_HEADER_IN: + ZEN_INFO("HIN : {}", DataView); + break; + case CURLINFO_HEADER_OUT: + if (auto TokenPos = DataView.find("Authorization: Bearer "sv); TokenPos != std::string_view::npos) + { + std::string Copy(DataView); + auto BearerStart = TokenPos + 22; + auto BearerEnd = Copy.find_first_of("\r\n", BearerStart); + if (BearerEnd == std::string::npos) + { + BearerEnd = Copy.length(); + } + Copy.replace(Copy.begin() + BearerStart, Copy.begin() + BearerEnd, fmt::format("[{} char token]", BearerEnd - BearerStart)); + ZEN_INFO("HOUT: {}", Copy); + } + else + { + ZEN_INFO("HOUT: {}", DataView); + } + break; + default: + break; + } + + return 0; +} + +////////////////////////////////////////////////////////////////////////// + +static std::pair<std::string, std::string> +HeaderContentType(ZenContentType ContentType) +{ + return std::make_pair("Content-Type", std::string(MapContentTypeToString(ContentType))); +} + +static curl_slist* +BuildHeaderList(const HttpClient::KeyValueMap& AdditionalHeader, + std::string_view SessionId, + const std::optional<HttpClientAccessToken>& AccessToken, + const std::vector<std::pair<std::string, std::string>>& ExtraHeaders = {}) +{ + curl_slist* Headers = nullptr; + + for (const auto& [Key, Value] : *AdditionalHeader) + { + ExtendableStringBuilder<64> HeaderLine; + HeaderLine << Key << ": " << Value; + Headers = curl_slist_append(Headers, HeaderLine.c_str()); + } + + if (!SessionId.empty()) + { + ExtendableStringBuilder<64> SessionHeader; + SessionHeader << "UE-Session: " << SessionId; + Headers = curl_slist_append(Headers, SessionHeader.c_str()); + } + + if (AccessToken) + { + ExtendableStringBuilder<128> AuthHeader; + AuthHeader << "Authorization: " << AccessToken->Value; + Headers = curl_slist_append(Headers, AuthHeader.c_str()); + } + + for (const auto& [Key, Value] : ExtraHeaders) + { + ExtendableStringBuilder<128> HeaderLine; + HeaderLine << Key << ": " << Value; + Headers = curl_slist_append(Headers, HeaderLine.c_str()); + } + + return Headers; +} + +static HttpClient::KeyValueMap +BuildHeaderMap(const std::vector<std::pair<std::string, std::string>>& Headers) +{ + HttpClient::KeyValueMap HeaderMap; + for (const auto& [Key, Value] : Headers) + { + HeaderMap->insert_or_assign(Key, Value); + } + return HeaderMap; +} + +// Scans response headers for Content-Type and applies it to the buffer. +static void +ApplyContentTypeFromHeaders(IoBuffer& Buffer, const std::vector<std::pair<std::string, std::string>>& Headers) +{ + for (const auto& [Key, Value] : Headers) + { + if (StrCaseCompare(Key, "Content-Type") == 0) + { + Buffer.SetContentType(ParseContentType(Value)); + break; + } + } +} + +static void +AppendUrlEncoded(StringBuilderBase& Out, std::string_view Input) +{ + static constexpr char HexDigits[] = "0123456789ABCDEF"; + static constexpr AsciiSet Unreserved("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_.~"); + + for (char C : Input) + { + if (Unreserved.Contains(C)) + { + Out.Append(C); + } + else + { + uint8_t Byte = static_cast<uint8_t>(C); + char Encoded[3] = {'%', HexDigits[Byte >> 4], HexDigits[Byte & 0x0F]}; + Out.Append(std::string_view(Encoded, 3)); + } + } +} + +static void +BuildUrlWithParameters(StringBuilderBase& Url, + std::string_view BaseUrl, + std::string_view ResourcePath, + const HttpClient::KeyValueMap& Parameters) +{ + Url.Append(BaseUrl); + Url.Append(ResourcePath); + + if (!Parameters->empty()) + { + char Separator = '?'; + for (const auto& [Key, Value] : *Parameters) + { + Url.Append(Separator); + AppendUrlEncoded(Url, Key); + Url.Append('='); + AppendUrlEncoded(Url, Value); + Separator = '&'; + } + } +} + +////////////////////////////////////////////////////////////////////////// + +CurlHttpClient::CurlHttpClient(std::string_view BaseUri, + const HttpClientSettings& ConnectionSettings, + std::function<bool()>&& CheckIfAbortFunction) +: HttpClientBase(BaseUri, ConnectionSettings, std::move(CheckIfAbortFunction)) +{ +} + +CurlHttpClient::~CurlHttpClient() +{ + ZEN_TRACE_CPU("CurlHttpClient::~CurlHttpClient"); + m_SessionLock.WithExclusiveLock([&] { + for (auto* Handle : m_Sessions) + { + curl_easy_cleanup(Handle); + } + m_Sessions.clear(); + }); +} + +CurlHttpClient::Session::~Session() +{ + if (HeaderList) + { + curl_slist_free_all(HeaderList); + } + Outer->ReleaseSession(Handle); +} + +void +CurlHttpClient::Session::SetHeaders(curl_slist* Headers) +{ + if (HeaderList) + { + curl_slist_free_all(HeaderList); + } + HeaderList = Headers; + curl_easy_setopt(Handle, CURLOPT_HTTPHEADER, HeaderList); +} + +CurlHttpClient::CurlResult +CurlHttpClient::Session::PerformWithResponseCallbacks() +{ + std::string Body; + WriteCallbackData WriteData{.Body = &Body, + .CheckIfAbortFunction = Outer->m_CheckIfAbortFunction ? &Outer->m_CheckIfAbortFunction : nullptr}; + HeaderCallbackData HdrData{}; + std::vector<std::pair<std::string, std::string>> ResponseHeaders; + HdrData.Headers = &ResponseHeaders; + + curl_easy_setopt(Handle, CURLOPT_WRITEFUNCTION, CurlWriteCallback); + curl_easy_setopt(Handle, CURLOPT_WRITEDATA, &WriteData); + curl_easy_setopt(Handle, CURLOPT_HEADERFUNCTION, CurlHeaderCallback); + curl_easy_setopt(Handle, CURLOPT_HEADERDATA, &HdrData); + + CurlResult Result = Perform(); + Result.Body = std::move(Body); + Result.Headers = std::move(ResponseHeaders); + + return Result; +} + +CurlHttpClient::CurlResult +CurlHttpClient::Session::Perform() +{ + CurlResult Result; + + char ErrorBuffer[CURL_ERROR_SIZE] = {}; + curl_easy_setopt(Handle, CURLOPT_ERRORBUFFER, ErrorBuffer); + + Result.ErrorCode = curl_easy_perform(Handle); + + if (Result.ErrorCode != CURLE_OK) + { + Result.ErrorMessage = ErrorBuffer[0] ? std::string(ErrorBuffer) : curl_easy_strerror(Result.ErrorCode); + } + + curl_easy_getinfo(Handle, CURLINFO_RESPONSE_CODE, &Result.StatusCode); + + double Elapsed = 0; + curl_easy_getinfo(Handle, CURLINFO_TOTAL_TIME, &Elapsed); + Result.ElapsedSeconds = Elapsed; + + curl_off_t UpBytes = 0; + curl_easy_getinfo(Handle, CURLINFO_SIZE_UPLOAD_T, &UpBytes); + Result.UploadedBytes = static_cast<int64_t>(UpBytes); + + curl_off_t DownBytes = 0; + curl_easy_getinfo(Handle, CURLINFO_SIZE_DOWNLOAD_T, &DownBytes); + Result.DownloadedBytes = static_cast<int64_t>(DownBytes); + + return Result; +} + +bool +CurlHttpClient::ShouldLogErrorCode(HttpResponseCode ResponseCode) const +{ + if (m_CheckIfAbortFunction && m_CheckIfAbortFunction()) + { + return false; + } + const auto& Expected = m_ConnectionSettings.ExpectedErrorCodes; + return std::find(Expected.begin(), Expected.end(), ResponseCode) == Expected.end(); +} + +HttpClient::Response +CurlHttpClient::ResponseWithPayload(std::string_view SessionId, + CurlResult&& Result, + const HttpResponseCode WorkResponseCode, + IoBuffer&& Payload, + std::vector<HttpClient::Response::MultipartBoundary>&& BoundaryPositions) +{ + IoBuffer ResponseBuffer = Payload ? std::move(Payload) : IoBuffer(IoBuffer::Clone, Result.Body.data(), Result.Body.size()); + + ApplyContentTypeFromHeaders(ResponseBuffer, Result.Headers); + + if (!IsHttpSuccessCode(WorkResponseCode) && WorkResponseCode != HttpResponseCode::NotFound) + { + if (ShouldLogErrorCode(WorkResponseCode)) + { + ZEN_WARN("HttpClient request failed (session: {}): status={}, url={}", + SessionId, + static_cast<int>(WorkResponseCode), + m_BaseUri); + } + } + + std::sort(BoundaryPositions.begin(), + BoundaryPositions.end(), + [](const HttpClient::Response::MultipartBoundary& Lhs, const HttpClient::Response::MultipartBoundary& Rhs) { + return Lhs.RangeOffset < Rhs.RangeOffset; + }); + + return HttpClient::Response{.StatusCode = WorkResponseCode, + .ResponsePayload = std::move(ResponseBuffer), + .Header = BuildHeaderMap(Result.Headers), + .UploadedBytes = Result.UploadedBytes, + .DownloadedBytes = Result.DownloadedBytes, + .ElapsedSeconds = Result.ElapsedSeconds, + .Ranges = std::move(BoundaryPositions)}; +} + +HttpClient::Response +CurlHttpClient::CommonResponse(std::string_view SessionId, + CurlResult&& Result, + IoBuffer&& Payload, + std::vector<HttpClient::Response::MultipartBoundary>&& BoundaryPositions) +{ + const HttpResponseCode WorkResponseCode = HttpResponseCode(Result.StatusCode); + if (Result.ErrorCode != CURLE_OK) + { + const bool Quiet = m_CheckIfAbortFunction && m_CheckIfAbortFunction(); + if (!Quiet) + { + if (Result.ErrorCode != CURLE_OPERATION_TIMEDOUT && Result.ErrorCode != CURLE_COULDNT_CONNECT && + Result.ErrorCode != CURLE_ABORTED_BY_CALLBACK) + { + ZEN_WARN("HttpClient client failure (session: {}): ({}) '{}'", + SessionId, + static_cast<int>(Result.ErrorCode), + Result.ErrorMessage); + } + } + + return HttpClient::Response{ + .StatusCode = WorkResponseCode, + .ResponsePayload = IoBufferBuilder::MakeCloneFromMemory(Result.Body.data(), Result.Body.size()), + .Header = BuildHeaderMap(Result.Headers), + .UploadedBytes = Result.UploadedBytes, + .DownloadedBytes = Result.DownloadedBytes, + .ElapsedSeconds = Result.ElapsedSeconds, + .Error = HttpClient::ErrorContext{.ErrorCode = MapCurlError(Result.ErrorCode), .ErrorMessage = Result.ErrorMessage}}; + } + + if (WorkResponseCode == HttpResponseCode::NoContent || (Result.Body.empty() && !Payload)) + { + return HttpClient::Response{.StatusCode = WorkResponseCode, + .Header = BuildHeaderMap(Result.Headers), + .UploadedBytes = Result.UploadedBytes, + .DownloadedBytes = Result.DownloadedBytes, + .ElapsedSeconds = Result.ElapsedSeconds}; + } + else + { + return ResponseWithPayload(SessionId, std::move(Result), WorkResponseCode, std::move(Payload), std::move(BoundaryPositions)); + } +} + +bool +CurlHttpClient::ValidatePayload(CurlResult& Result, std::unique_ptr<detail::TempPayloadFile>& PayloadFile) +{ + ZEN_TRACE_CPU("ValidatePayload"); + + IoBuffer ResponseBuffer = (Result.Body.empty() && PayloadFile) ? PayloadFile->BorrowIoBuffer() + : IoBuffer(IoBuffer::Wrap, Result.Body.data(), Result.Body.size()); + + // Collect relevant headers in a single pass + std::string_view ContentLengthValue; + std::string_view IoHashValue; + std::string_view ContentTypeValue; + + for (const auto& [Key, Value] : Result.Headers) + { + if (ContentLengthValue.empty() && StrCaseCompare(Key, "Content-Length") == 0) + { + ContentLengthValue = Value; + } + else if (IoHashValue.empty() && StrCaseCompare(Key, "X-Jupiter-IoHash") == 0) + { + IoHashValue = Value; + } + else if (ContentTypeValue.empty() && StrCaseCompare(Key, "Content-Type") == 0) + { + ContentTypeValue = Value; + } + } + + // Validate Content-Length + if (!ContentLengthValue.empty()) + { + std::optional<uint64_t> ExpectedContentSize = ParseInt<uint64_t>(ContentLengthValue); + if (!ExpectedContentSize.has_value()) + { + Result.ErrorCode = CURLE_RECV_ERROR; + Result.ErrorMessage = fmt::format("Can not parse Content-Length header. Value: '{}'", ContentLengthValue); + return false; + } + if (ExpectedContentSize.value() != ResponseBuffer.GetSize()) + { + Result.ErrorCode = CURLE_RECV_ERROR; + Result.ErrorMessage = + fmt::format("Payload size {} does not match Content-Length {}", ResponseBuffer.GetSize(), ContentLengthValue); + return false; + } + } + + if (Result.StatusCode == static_cast<long>(HttpResponseCode::PartialContent)) + { + return true; + } + + // Validate X-Jupiter-IoHash + if (!IoHashValue.empty()) + { + IoHash ExpectedPayloadHash; + if (IoHash::TryParse(IoHashValue, ExpectedPayloadHash)) + { + IoHash PayloadHash = IoHash::HashBuffer(ResponseBuffer); + if (PayloadHash != ExpectedPayloadHash) + { + Result.ErrorCode = CURLE_RECV_ERROR; + Result.ErrorMessage = fmt::format("Payload hash {} does not match X-Jupiter-IoHash {}", + PayloadHash.ToHexString(), + ExpectedPayloadHash.ToHexString()); + return false; + } + } + } + + // Validate content-type specific payload + if (ContentTypeValue == "application/x-ue-comp") + { + IoHash RawHash; + uint64_t RawSize; + if (CompressedBuffer::ValidateCompressedHeader(ResponseBuffer, + RawHash, + RawSize, + /*OutOptionalTotalCompressedSize*/ nullptr)) + { + return true; + } + else + { + Result.ErrorCode = CURLE_RECV_ERROR; + Result.ErrorMessage = "Compressed binary failed validation"; + return false; + } + } + if (ContentTypeValue == "application/x-ue-cb") + { + if (CbValidateError Error = ValidateCompactBinary(ResponseBuffer.GetView(), CbValidateMode::Default); + Error == CbValidateError::None) + { + return true; + } + else + { + Result.ErrorCode = CURLE_RECV_ERROR; + Result.ErrorMessage = fmt::format("Compact binary failed validation: {}", ToString(Error)); + return false; + } + } + + return true; +} + +bool +CurlHttpClient::ShouldRetry(const CurlResult& Result) +{ + switch (Result.ErrorCode) + { + case CURLE_OK: + break; + case CURLE_RECV_ERROR: + case CURLE_SEND_ERROR: + case CURLE_OPERATION_TIMEDOUT: + return true; + default: + return false; + } + switch (static_cast<HttpResponseCode>(Result.StatusCode)) + { + case HttpResponseCode::RequestTimeout: + case HttpResponseCode::TooManyRequests: + case HttpResponseCode::InternalServerError: + case HttpResponseCode::BadGateway: + case HttpResponseCode::ServiceUnavailable: + case HttpResponseCode::GatewayTimeout: + return true; + default: + return false; + } +} + +CurlHttpClient::CurlResult +CurlHttpClient::DoWithRetry(std::string_view SessionId, std::function<CurlResult()>&& Func, std::function<bool(CurlResult&)>&& Validate) +{ + uint8_t Attempt = 0; + CurlResult Result = Func(); + while (Attempt < m_ConnectionSettings.RetryCount) + { + if (m_CheckIfAbortFunction && m_CheckIfAbortFunction()) + { + return Result; + } + if (!ShouldRetry(Result)) + { + if (Result.ErrorCode != CURLE_OK || !IsHttpSuccessCode(Result.StatusCode)) + { + break; + } + if (Validate(Result)) + { + break; + } + } + Sleep(100 * (Attempt + 1)); + Attempt++; + if (ShouldLogErrorCode(HttpResponseCode(Result.StatusCode))) + { + if (Result.ErrorCode != CURLE_OK) + { + ZEN_INFO("Retry (session: {}): HTTP error ({}) '{}' Attempt {}/{}", + SessionId, + static_cast<int>(MapCurlError(Result.ErrorCode)), + Result.ErrorMessage, + Attempt, + m_ConnectionSettings.RetryCount + 1); + } + else + { + ZEN_INFO("Retry (session: {}): HTTP status ({}) '{}' Attempt {}/{}", + SessionId, + Result.StatusCode, + zen::ToString(HttpResponseCode(Result.StatusCode)), + Attempt, + m_ConnectionSettings.RetryCount + 1); + } + } + Result = Func(); + } + return Result; +} + +CurlHttpClient::CurlResult +CurlHttpClient::DoWithRetry(std::string_view SessionId, + std::function<CurlResult()>&& Func, + std::unique_ptr<detail::TempPayloadFile>& PayloadFile) +{ + return DoWithRetry(SessionId, std::move(Func), [&](CurlResult& Result) { return ValidatePayload(Result, PayloadFile); }); +} + +////////////////////////////////////////////////////////////////////////// + +CurlHttpClient::Session +CurlHttpClient::AllocSession(std::string_view ResourcePath, const KeyValueMap& Parameters) +{ + ZEN_TRACE_CPU("CurlHttpClient::AllocSession"); + CURL* Handle = nullptr; + m_SessionLock.WithExclusiveLock([&] { + if (!m_Sessions.empty()) + { + Handle = m_Sessions.back(); + m_Sessions.pop_back(); + } + }); + + if (Handle == nullptr) + { + Handle = curl_easy_init(); + if (Handle == nullptr) + { + ThrowOutOfMemory("curl_easy_init"); + } + } + else + { + curl_easy_reset(Handle); + } + + // Unix domain socket + if (!m_ConnectionSettings.UnixSocketPath.empty()) + { + std::string SocketPathUtf8 = PathToUtf8(m_ConnectionSettings.UnixSocketPath); + curl_easy_setopt(Handle, CURLOPT_UNIX_SOCKET_PATH, SocketPathUtf8.c_str()); + } + + // Build URL with parameters + ExtendableStringBuilder<256> Url; + BuildUrlWithParameters(Url, m_BaseUri, ResourcePath, Parameters); + curl_easy_setopt(Handle, CURLOPT_URL, Url.c_str()); + + // Timeouts + if (m_ConnectionSettings.ConnectTimeout.count() > 0) + { + curl_easy_setopt(Handle, CURLOPT_CONNECTTIMEOUT_MS, static_cast<long>(m_ConnectionSettings.ConnectTimeout.count())); + } + if (m_ConnectionSettings.Timeout.count() > 0) + { + curl_easy_setopt(Handle, CURLOPT_TIMEOUT_MS, static_cast<long>(m_ConnectionSettings.Timeout.count())); + } + + // HTTP/2 + if (m_ConnectionSettings.AssumeHttp2) + { + curl_easy_setopt(Handle, CURLOPT_HTTP_VERSION, CURL_HTTP_VERSION_2_PRIOR_KNOWLEDGE); + } + + // Verbose/debug + if (m_ConnectionSettings.Verbose) + { + curl_easy_setopt(Handle, CURLOPT_VERBOSE, 1L); + curl_easy_setopt(Handle, CURLOPT_DEBUGFUNCTION, CurlDebugCallback); + curl_easy_setopt(Handle, CURLOPT_DEBUGDATA, &m_Log); + } + + // SSL options + if (m_ConnectionSettings.InsecureSsl) + { + curl_easy_setopt(Handle, CURLOPT_SSL_VERIFYPEER, 0L); + curl_easy_setopt(Handle, CURLOPT_SSL_VERIFYHOST, 0L); + } + if (!m_ConnectionSettings.CaBundlePath.empty()) + { + curl_easy_setopt(Handle, CURLOPT_CAINFO, m_ConnectionSettings.CaBundlePath.c_str()); + } + + // Disable signal handling for thread safety + curl_easy_setopt(Handle, CURLOPT_NOSIGNAL, 1L); + + if (m_ConnectionSettings.ForbidReuseConnection) + { + curl_easy_setopt(Handle, CURLOPT_FORBID_REUSE, 1L); + } + + // Note: Headers are NOT set here. Each method builds its own header list + // (potentially adding method-specific headers like Content-Type) and passes + // ownership to the Session via SetHeaders(). + + return Session(this, Handle); +} + +void +CurlHttpClient::ReleaseSession(CURL* Handle) +{ + ZEN_TRACE_CPU("CurlHttpClient::ReleaseSession"); + m_SessionLock.WithExclusiveLock([&] { m_Sessions.push_back(Handle); }); +} + +////////////////////////////////////////////////////////////////////////// + +// TransactPackage is a two-phase protocol (offer + send) with server-side state +// between phases, so retrying individual phases would be incorrect. +CurlHttpClient::Response +CurlHttpClient::TransactPackage(std::string_view Url, CbPackage Package, const KeyValueMap& AdditionalHeader) +{ + ZEN_TRACE_CPU("CurlHttpClient::TransactPackage"); + + // First, list of offered chunks for filtering on the server end + + std::vector<IoHash> AttachmentsToSend; + std::span<const CbAttachment> Attachments = Package.GetAttachments(); + + const uint32_t RequestId = ++CurlHttpClientRequestIdCounter; + auto RequestIdString = fmt::to_string(RequestId); + + if (!Attachments.empty()) + { + CbObjectWriter Writer; + Writer.BeginArray("offer"); + + for (const CbAttachment& Attachment : Attachments) + { + Writer.AddHash(Attachment.GetHash()); + } + + Writer.EndArray(); + + BinaryWriter MemWriter; + Writer.Save(MemWriter); + + std::vector<std::pair<std::string, std::string>> OfferExtraHeaders; + OfferExtraHeaders.emplace_back(HeaderContentType(HttpContentType::kCbPackageOffer)); + OfferExtraHeaders.emplace_back("UE-Request", RequestIdString); + + Session Sess = AllocSession(Url, {}); + CURL* H = Sess.Get(); + + Sess.SetHeaders(BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), OfferExtraHeaders)); + curl_easy_setopt(H, CURLOPT_POST, 1L); + curl_easy_setopt(H, CURLOPT_POSTFIELDS, reinterpret_cast<const char*>(MemWriter.Data())); + curl_easy_setopt(H, CURLOPT_POSTFIELDSIZE_LARGE, static_cast<curl_off_t>(MemWriter.Size())); + + CurlResult Result = Sess.PerformWithResponseCallbacks(); + + if (Result.ErrorCode == CURLE_OK && IsHttpSuccessCode(Result.StatusCode)) + { + IoBuffer ResponseBuffer(IoBuffer::Wrap, Result.Body.data(), Result.Body.size()); + CbValidateError ValidationError = CbValidateError::None; + if (CbObject ResponseObject = ValidateAndReadCompactBinaryObject(std::move(ResponseBuffer), ValidationError); + ValidationError == CbValidateError::None) + { + for (CbFieldView& Entry : ResponseObject["need"]) + { + ZEN_ASSERT(Entry.IsHash()); + AttachmentsToSend.push_back(Entry.AsHash()); + } + } + } + } + + // Prepare package for send + + CbPackage SendPackage; + SendPackage.SetObject(Package.GetObject(), Package.GetObjectHash()); + + for (const IoHash& AttachmentCid : AttachmentsToSend) + { + const CbAttachment* Attachment = Package.FindAttachment(AttachmentCid); + + if (Attachment) + { + SendPackage.AddAttachment(*Attachment); + } + } + + // Transmit package payload + + CompositeBuffer Message = FormatPackageMessageBuffer(SendPackage); + SharedBuffer FlatMessage = Message.Flatten(); + + std::vector<std::pair<std::string, std::string>> PkgExtraHeaders; + PkgExtraHeaders.emplace_back(HeaderContentType(HttpContentType::kCbPackage)); + PkgExtraHeaders.emplace_back("UE-Request", RequestIdString); + + Session Sess = AllocSession(Url, {}); + CURL* H = Sess.Get(); + + Sess.SetHeaders(BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), PkgExtraHeaders)); + curl_easy_setopt(H, CURLOPT_POST, 1L); + curl_easy_setopt(H, CURLOPT_POSTFIELDS, reinterpret_cast<const char*>(FlatMessage.GetData())); + curl_easy_setopt(H, CURLOPT_POSTFIELDSIZE_LARGE, static_cast<curl_off_t>(FlatMessage.GetSize())); + + CurlResult Result = Sess.PerformWithResponseCallbacks(); + + return CommonResponse(m_SessionId, std::move(Result), {}, {}); +} + +////////////////////////////////////////////////////////////////////////// +// +// Standard HTTP verbs +// + +CurlHttpClient::Response +CurlHttpClient::Put(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader) +{ + ZEN_TRACE_CPU("CurlHttpClient::Put"); + + return CommonResponse( + m_SessionId, + DoWithRetry( + m_SessionId, + [&]() -> CurlResult { + Session Sess = AllocSession(Url, {}); + CURL* H = Sess.Get(); + + Sess.SetHeaders( + BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), {HeaderContentType(Payload.GetContentType())})); + + curl_easy_setopt(H, CURLOPT_UPLOAD, 1L); + curl_easy_setopt(H, CURLOPT_INFILESIZE_LARGE, static_cast<curl_off_t>(Payload.GetSize())); + + ReadCallbackData ReadData{.DataPtr = static_cast<const uint8_t*>(Payload.GetData()), + .DataSize = Payload.GetSize(), + .CheckIfAbortFunction = m_CheckIfAbortFunction ? &m_CheckIfAbortFunction : nullptr}; + curl_easy_setopt(H, CURLOPT_READFUNCTION, CurlReadCallback); + curl_easy_setopt(H, CURLOPT_READDATA, &ReadData); + + return Sess.PerformWithResponseCallbacks(); + }), + {}); +} + +CurlHttpClient::Response +CurlHttpClient::Put(std::string_view Url, const KeyValueMap& Parameters) +{ + ZEN_TRACE_CPU("CurlHttpClient::Put"); + + return CommonResponse( + m_SessionId, + DoWithRetry(m_SessionId, + [&]() -> CurlResult { + KeyValueMap HeaderWithContentLength{std::pair<std::string_view, std::string_view>{"Content-Length", "0"}}; + Session Sess = AllocSession(Url, Parameters); + CURL* H = Sess.Get(); + + Sess.SetHeaders(BuildHeaderList(HeaderWithContentLength, m_SessionId, GetAccessToken())); + + curl_easy_setopt(H, CURLOPT_UPLOAD, 1L); + curl_easy_setopt(H, CURLOPT_INFILESIZE_LARGE, 0LL); + + return Sess.PerformWithResponseCallbacks(); + }), + {}); +} + +CurlHttpClient::Response +CurlHttpClient::Get(std::string_view Url, const KeyValueMap& AdditionalHeader, const KeyValueMap& Parameters) +{ + ZEN_TRACE_CPU("CurlHttpClient::Get"); + return CommonResponse(m_SessionId, + DoWithRetry( + m_SessionId, + [&]() -> CurlResult { + Session Sess = AllocSession(Url, Parameters); + Sess.SetHeaders(BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken())); + curl_easy_setopt(Sess.Get(), CURLOPT_HTTPGET, 1L); + return Sess.PerformWithResponseCallbacks(); + }, + [this](CurlResult& Result) { + std::unique_ptr<detail::TempPayloadFile> NoTempFile; + return ValidatePayload(Result, NoTempFile); + }), + {}); +} + +CurlHttpClient::Response +CurlHttpClient::Head(std::string_view Url, const KeyValueMap& AdditionalHeader) +{ + ZEN_TRACE_CPU("CurlHttpClient::Head"); + + return CommonResponse(m_SessionId, + DoWithRetry(m_SessionId, + [&]() -> CurlResult { + Session Sess = AllocSession(Url, {}); + Sess.SetHeaders(BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken())); + curl_easy_setopt(Sess.Get(), CURLOPT_NOBODY, 1L); + return Sess.PerformWithResponseCallbacks(); + }), + {}); +} + +CurlHttpClient::Response +CurlHttpClient::Delete(std::string_view Url, const KeyValueMap& AdditionalHeader) +{ + ZEN_TRACE_CPU("CurlHttpClient::Delete"); + + return CommonResponse(m_SessionId, + DoWithRetry(m_SessionId, + [&]() -> CurlResult { + Session Sess = AllocSession(Url, {}); + Sess.SetHeaders(BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken())); + curl_easy_setopt(Sess.Get(), CURLOPT_CUSTOMREQUEST, "DELETE"); + return Sess.PerformWithResponseCallbacks(); + }), + {}); +} + +CurlHttpClient::Response +CurlHttpClient::Post(std::string_view Url, const KeyValueMap& AdditionalHeader, const KeyValueMap& Parameters) +{ + ZEN_TRACE_CPU("CurlHttpClient::PostNoPayload"); + + return CommonResponse(m_SessionId, + DoWithRetry(m_SessionId, + [&]() -> CurlResult { + Session Sess = AllocSession(Url, Parameters); + Sess.SetHeaders(BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken())); + curl_easy_setopt(Sess.Get(), CURLOPT_POST, 1L); + curl_easy_setopt(Sess.Get(), CURLOPT_POSTFIELDSIZE, 0L); + return Sess.PerformWithResponseCallbacks(); + }), + {}); +} + +CurlHttpClient::Response +CurlHttpClient::Post(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader) +{ + return Post(Url, Payload, Payload.GetContentType(), AdditionalHeader); +} + +CurlHttpClient::Response +CurlHttpClient::Post(std::string_view Url, const IoBuffer& Payload, ZenContentType ContentType, const KeyValueMap& AdditionalHeader) +{ + ZEN_TRACE_CPU("CurlHttpClient::PostWithPayload"); + + return CommonResponse( + m_SessionId, + DoWithRetry( + m_SessionId, + [&]() -> CurlResult { + Session Sess = AllocSession(Url, {}); + CURL* H = Sess.Get(); + + Sess.SetHeaders(BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), {HeaderContentType(ContentType)})); + + IoBufferFileReference FileRef = {nullptr, 0, 0}; + if (Payload.GetFileReference(FileRef)) + { + detail::BufferedReadFileStream Buffer(FileRef.FileHandle, FileRef.FileChunkOffset, FileRef.FileChunkSize, 512u * 1024u); + + FileReadCallbackData ReadData{.Buffer = &Buffer, + .TotalSize = Payload.GetSize(), + .CheckIfAbortFunction = m_CheckIfAbortFunction ? &m_CheckIfAbortFunction : nullptr}; + + curl_easy_setopt(H, CURLOPT_POST, 1L); + curl_easy_setopt(H, CURLOPT_POSTFIELDSIZE_LARGE, static_cast<curl_off_t>(Payload.GetSize())); + curl_easy_setopt(H, CURLOPT_READFUNCTION, CurlFileReadCallback); + curl_easy_setopt(H, CURLOPT_READDATA, &ReadData); + + return Sess.PerformWithResponseCallbacks(); + } + + curl_easy_setopt(H, CURLOPT_POST, 1L); + curl_easy_setopt(H, CURLOPT_POSTFIELDS, reinterpret_cast<const char*>(Payload.GetData())); + curl_easy_setopt(H, CURLOPT_POSTFIELDSIZE_LARGE, static_cast<curl_off_t>(Payload.GetSize())); + + return Sess.PerformWithResponseCallbacks(); + }), + {}); +} + +CurlHttpClient::Response +CurlHttpClient::Post(std::string_view Url, + CbObject Payload, + const KeyValueMap& AdditionalHeader, + const std::filesystem::path& TempFolderPath) +{ + ZEN_TRACE_CPU("CurlHttpClient::PostObjectPayload"); + + std::string PayloadString; + std::unique_ptr<detail::TempPayloadFile> PayloadFile; + + CurlResult Result = DoWithRetry( + m_SessionId, + [&]() -> CurlResult { + PayloadString.clear(); + PayloadFile.reset(); + + Session Sess = AllocSession(Url, {}); + CURL* H = Sess.Get(); + + Sess.SetHeaders( + BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), {HeaderContentType(ZenContentType::kCbObject)})); + + curl_easy_setopt(H, CURLOPT_POST, 1L); + curl_easy_setopt(H, CURLOPT_POSTFIELDS, reinterpret_cast<const char*>(Payload.GetBuffer().GetData())); + curl_easy_setopt(H, CURLOPT_POSTFIELDSIZE_LARGE, static_cast<curl_off_t>(Payload.GetBuffer().GetSize())); + + struct PostHeaderCallbackData + { + std::vector<std::pair<std::string, std::string>>* Headers = nullptr; + std::unique_ptr<detail::TempPayloadFile>* PayloadFile = nullptr; + std::string* PayloadString = nullptr; + const std::filesystem::path* TempFolderPath = nullptr; + uint64_t MaxInMemorySize = 0; + LoggerRef Log; + }; + + PostHeaderCallbackData PostHdrData; + std::vector<std::pair<std::string, std::string>> ResponseHeaders; + PostHdrData.Headers = &ResponseHeaders; + PostHdrData.PayloadFile = &PayloadFile; + PostHdrData.PayloadString = &PayloadString; + PostHdrData.TempFolderPath = &TempFolderPath; + PostHdrData.MaxInMemorySize = m_ConnectionSettings.MaximumInMemoryDownloadSize; + PostHdrData.Log = m_Log; + + auto HeaderCb = [](char* Buffer, size_t Size, size_t Nmemb, void* UserData) -> size_t { + auto* Data = static_cast<PostHeaderCallbackData*>(UserData); + size_t TotalBytes = Size * Nmemb; + + if (auto Header = ParseHeaderLine(std::string_view(Buffer, TotalBytes))) + { + auto& [Key, Value] = *Header; + + if (StrCaseCompare(Key, "Content-Length") == 0) + { + std::optional<size_t> ContentLength = ParseInt<size_t>(Value); + if (ContentLength.has_value()) + { + if (!Data->TempFolderPath->empty() && ContentLength.value() > Data->MaxInMemorySize) + { + *Data->PayloadFile = std::make_unique<detail::TempPayloadFile>(); + std::error_code Ec = (*Data->PayloadFile)->Open(*Data->TempFolderPath, ContentLength.value()); + if (Ec) + { + auto Log = [&]() -> LoggerRef { return Data->Log; }; + ZEN_WARN("Failed to create temp file in '{}' for HttpClient::Post. Reason: {}", + Data->TempFolderPath->string(), + Ec.message()); + Data->PayloadFile->reset(); + } + } + else + { + Data->PayloadString->reserve(ContentLength.value()); + } + } + } + + Data->Headers->emplace_back(std::string(Key), std::string(Value)); + } + + return TotalBytes; + }; + + curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, static_cast<size_t (*)(char*, size_t, size_t, void*)>(HeaderCb)); + curl_easy_setopt(H, CURLOPT_HEADERDATA, &PostHdrData); + + struct PostWriteCallbackData + { + std::string* PayloadString = nullptr; + std::unique_ptr<detail::TempPayloadFile>* PayloadFile = nullptr; + std::function<bool()>* CheckIfAbortFunction = nullptr; + const std::filesystem::path* TempFolderPath = nullptr; + LoggerRef Log; + }; + + PostWriteCallbackData PostWriteData; + PostWriteData.PayloadString = &PayloadString; + PostWriteData.PayloadFile = &PayloadFile; + PostWriteData.CheckIfAbortFunction = m_CheckIfAbortFunction ? &m_CheckIfAbortFunction : nullptr; + PostWriteData.TempFolderPath = &TempFolderPath; + PostWriteData.Log = m_Log; + + auto WriteCb = [](char* Ptr, size_t Size, size_t Nmemb, void* UserData) -> size_t { + auto* Data = static_cast<PostWriteCallbackData*>(UserData); + size_t TotalBytes = Size * Nmemb; + + if (Data->CheckIfAbortFunction && *Data->CheckIfAbortFunction && (*Data->CheckIfAbortFunction)()) + { + return 0; + } + + if (*Data->PayloadFile) + { + std::error_code Ec = (*Data->PayloadFile)->Write(std::string_view(Ptr, TotalBytes)); + if (Ec) + { + auto Log = [&]() -> LoggerRef { return Data->Log; }; + ZEN_WARN("Failed to write to temp file in '{}' for HttpClient::Post. Reason: {}", + Data->TempFolderPath->string(), + Ec.message()); + return 0; + } + } + else + { + Data->PayloadString->append(Ptr, TotalBytes); + } + return TotalBytes; + }; + + curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, static_cast<size_t (*)(char*, size_t, size_t, void*)>(WriteCb)); + curl_easy_setopt(H, CURLOPT_WRITEDATA, &PostWriteData); + + CurlResult Res = Sess.Perform(); + Res.Headers = std::move(ResponseHeaders); + + if (!PayloadString.empty()) + { + Res.Body = std::move(PayloadString); + } + + return Res; + }, + PayloadFile); + + return CommonResponse(m_SessionId, std::move(Result), PayloadFile ? PayloadFile->DetachToIoBuffer() : IoBuffer{}, {}); +} + +CurlHttpClient::Response +CurlHttpClient::Post(std::string_view Url, CbPackage Pkg, const KeyValueMap& AdditionalHeader) +{ + return Post(Url, zen::FormatPackageMessageBuffer(Pkg), ZenContentType::kCbPackage, AdditionalHeader); +} + +CurlHttpClient::Response +CurlHttpClient::Post(std::string_view Url, const CompositeBuffer& Payload, ZenContentType ContentType, const KeyValueMap& AdditionalHeader) +{ + ZEN_TRACE_CPU("CurlHttpClient::Post"); + + return CommonResponse( + m_SessionId, + DoWithRetry(m_SessionId, + [&]() -> CurlResult { + Session Sess = AllocSession(Url, {}); + CURL* H = Sess.Get(); + + Sess.SetHeaders(BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), {HeaderContentType(ContentType)})); + + detail::CompositeBufferReadStream Reader(Payload, 512u * 1024u); + + StreamReadCallbackData ReadData{.Reader = &Reader, + .CheckIfAbortFunction = m_CheckIfAbortFunction ? &m_CheckIfAbortFunction : nullptr}; + + curl_easy_setopt(H, CURLOPT_POST, 1L); + curl_easy_setopt(H, CURLOPT_POSTFIELDSIZE_LARGE, static_cast<curl_off_t>(Payload.GetSize())); + curl_easy_setopt(H, CURLOPT_READFUNCTION, CurlStreamReadCallback); + curl_easy_setopt(H, CURLOPT_READDATA, &ReadData); + + return Sess.PerformWithResponseCallbacks(); + }), + {}); +} + +CurlHttpClient::Response +CurlHttpClient::Upload(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader) +{ + ZEN_TRACE_CPU("CurlHttpClient::Upload"); + + return CommonResponse( + m_SessionId, + DoWithRetry( + m_SessionId, + [&]() -> CurlResult { + Session Sess = AllocSession(Url, {}); + CURL* H = Sess.Get(); + + Sess.SetHeaders( + BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), {HeaderContentType(Payload.GetContentType())})); + + curl_easy_setopt(H, CURLOPT_UPLOAD, 1L); + curl_easy_setopt(H, CURLOPT_INFILESIZE_LARGE, static_cast<curl_off_t>(Payload.GetSize())); + + IoBufferFileReference FileRef = {nullptr, 0, 0}; + if (Payload.GetFileReference(FileRef)) + { + detail::BufferedReadFileStream Buffer(FileRef.FileHandle, FileRef.FileChunkOffset, FileRef.FileChunkSize, 512u * 1024u); + + FileReadCallbackData ReadData{.Buffer = &Buffer, + .TotalSize = Payload.GetSize(), + .CheckIfAbortFunction = m_CheckIfAbortFunction ? &m_CheckIfAbortFunction : nullptr}; + + curl_easy_setopt(H, CURLOPT_READFUNCTION, CurlFileReadCallback); + curl_easy_setopt(H, CURLOPT_READDATA, &ReadData); + + return Sess.PerformWithResponseCallbacks(); + } + + ReadCallbackData ReadData{.DataPtr = static_cast<const uint8_t*>(Payload.GetData()), + .DataSize = Payload.GetSize(), + .CheckIfAbortFunction = m_CheckIfAbortFunction ? &m_CheckIfAbortFunction : nullptr}; + curl_easy_setopt(H, CURLOPT_READFUNCTION, CurlReadCallback); + curl_easy_setopt(H, CURLOPT_READDATA, &ReadData); + + return Sess.PerformWithResponseCallbacks(); + }), + {}); +} + +CurlHttpClient::Response +CurlHttpClient::Upload(std::string_view Url, + const CompositeBuffer& Payload, + ZenContentType ContentType, + const KeyValueMap& AdditionalHeader) +{ + ZEN_TRACE_CPU("CurlHttpClient::Upload"); + + return CommonResponse( + m_SessionId, + DoWithRetry(m_SessionId, + [&]() -> CurlResult { + Session Sess = AllocSession(Url, {}); + CURL* H = Sess.Get(); + + Sess.SetHeaders(BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken(), {HeaderContentType(ContentType)})); + + curl_easy_setopt(H, CURLOPT_UPLOAD, 1L); + curl_easy_setopt(H, CURLOPT_INFILESIZE_LARGE, static_cast<curl_off_t>(Payload.GetSize())); + + detail::CompositeBufferReadStream Reader(Payload, 512u * 1024u); + + StreamReadCallbackData ReadData{.Reader = &Reader, + .CheckIfAbortFunction = m_CheckIfAbortFunction ? &m_CheckIfAbortFunction : nullptr}; + + curl_easy_setopt(H, CURLOPT_READFUNCTION, CurlStreamReadCallback); + curl_easy_setopt(H, CURLOPT_READDATA, &ReadData); + + return Sess.PerformWithResponseCallbacks(); + }), + {}); +} + +CurlHttpClient::Response +CurlHttpClient::Download(std::string_view Url, const std::filesystem::path& TempFolderPath, const KeyValueMap& AdditionalHeader) +{ + ZEN_TRACE_CPU("CurlHttpClient::Download"); + + std::string PayloadString; + std::unique_ptr<detail::TempPayloadFile> PayloadFile; + + HttpContentType ContentType = HttpContentType::kUnknownContentType; + detail::MultipartBoundaryParser BoundaryParser; + bool IsMultiRangeResponse = false; + + CurlResult Result = DoWithRetry( + m_SessionId, + [&]() -> CurlResult { + Session Sess = AllocSession(Url, {}); + CURL* H = Sess.Get(); + + Sess.SetHeaders(BuildHeaderList(AdditionalHeader, m_SessionId, GetAccessToken())); + curl_easy_setopt(H, CURLOPT_HTTPGET, 1L); + + // Reset state from any previous attempt + PayloadString.clear(); + PayloadFile.reset(); + BoundaryParser.Boundaries.clear(); + ContentType = HttpContentType::kUnknownContentType; + IsMultiRangeResponse = false; + + // Track requested content length from Range header (sum all ranges) + uint64_t RequestedContentLength = (uint64_t)-1; + if (auto RangeIt = AdditionalHeader.Entries.find("Range"); RangeIt != AdditionalHeader.Entries.end()) + { + if (RangeIt->second.starts_with("bytes")) + { + std::string_view RangeValue(RangeIt->second); + size_t RangeStartPos = RangeValue.find('=', 5); + if (RangeStartPos != std::string_view::npos) + { + RangeStartPos++; + while (RangeStartPos < RangeValue.length() && RangeValue[RangeStartPos] == ' ') + { + RangeStartPos++; + } + RequestedContentLength = 0; + + while (RangeStartPos < RangeValue.length()) + { + size_t RangeEnd = RangeValue.find_first_of(", \r\n", RangeStartPos); + if (RangeEnd == std::string_view::npos) + { + RangeEnd = RangeValue.length(); + } + + std::string_view RangeString = RangeValue.substr(RangeStartPos, RangeEnd - RangeStartPos); + size_t RangeSplitPos = RangeString.find('-'); + if (RangeSplitPos != std::string_view::npos) + { + std::optional<size_t> RequestedRangeStart = ParseInt<size_t>(RangeString.substr(0, RangeSplitPos)); + std::optional<size_t> RequestedRangeEnd = ParseInt<size_t>(RangeString.substr(RangeSplitPos + 1)); + if (RequestedRangeStart.has_value() && RequestedRangeEnd.has_value()) + { + RequestedContentLength += RequestedRangeEnd.value() - RequestedRangeStart.value() + 1; + } + } + RangeStartPos = RangeEnd; + while (RangeStartPos != RangeValue.length() && + (RangeValue[RangeStartPos] == ',' || RangeValue[RangeStartPos] == ' ')) + { + RangeStartPos++; + } + } + } + } + } + + // Header callback that detects Content-Length and switches to file-backed storage when needed + struct DownloadHeaderCallbackData + { + std::vector<std::pair<std::string, std::string>>* Headers = nullptr; + std::unique_ptr<detail::TempPayloadFile>* PayloadFile = nullptr; + std::string* PayloadString = nullptr; + const std::filesystem::path* TempFolderPath = nullptr; + uint64_t MaxInMemorySize = 0; + LoggerRef Log; + detail::MultipartBoundaryParser* BoundaryParser = nullptr; + bool* IsMultiRange = nullptr; + HttpContentType* ContentTypeOut = nullptr; + }; + + DownloadHeaderCallbackData DlHdrData; + std::vector<std::pair<std::string, std::string>> ResponseHeaders; + DlHdrData.Headers = &ResponseHeaders; + DlHdrData.PayloadFile = &PayloadFile; + DlHdrData.PayloadString = &PayloadString; + DlHdrData.TempFolderPath = &TempFolderPath; + DlHdrData.MaxInMemorySize = m_ConnectionSettings.MaximumInMemoryDownloadSize; + DlHdrData.Log = m_Log; + DlHdrData.BoundaryParser = &BoundaryParser; + DlHdrData.IsMultiRange = &IsMultiRangeResponse; + DlHdrData.ContentTypeOut = &ContentType; + + auto HeaderCb = [](char* Buffer, size_t Size, size_t Nmemb, void* UserData) -> size_t { + auto* Data = static_cast<DownloadHeaderCallbackData*>(UserData); + size_t TotalBytes = Size * Nmemb; + + if (auto Header = ParseHeaderLine(std::string_view(Buffer, TotalBytes))) + { + auto& [KeyView, Value] = *Header; + const std::string Key(KeyView); + + if (StrCaseCompare(Key, "Content-Length") == 0) + { + std::optional<size_t> ContentLength = ParseInt<size_t>(Value); + if (ContentLength.has_value()) + { + if (ContentLength.value() > Data->MaxInMemorySize) + { + *Data->PayloadFile = std::make_unique<detail::TempPayloadFile>(); + std::error_code Ec = (*Data->PayloadFile)->Open(*Data->TempFolderPath, ContentLength.value()); + if (Ec) + { + auto Log = [&]() -> LoggerRef { return Data->Log; }; + ZEN_WARN("Failed to create temp file in '{}' for HttpClient::Download. Reason: {}", + Data->TempFolderPath->string(), + Ec.message()); + Data->PayloadFile->reset(); + } + } + else + { + Data->PayloadString->reserve(ContentLength.value()); + } + } + } + else if (StrCaseCompare(Key, "Content-Type") == 0) + { + *Data->IsMultiRange = Data->BoundaryParser->Init(Value); + if (!*Data->IsMultiRange) + { + *Data->ContentTypeOut = ParseContentType(Value); + } + } + else if (StrCaseCompare(Key, "Content-Range") == 0) + { + if (!*Data->IsMultiRange) + { + std::pair<uint64_t, uint64_t> Range = detail::ParseContentRange(Value); + if (Range.second != 0) + { + Data->BoundaryParser->Boundaries.push_back( + HttpClient::Response::MultipartBoundary{.OffsetInPayload = 0, + .RangeOffset = Range.first, + .RangeLength = Range.second, + .ContentType = *Data->ContentTypeOut}); + } + } + } + + Data->Headers->emplace_back(Key, std::string(Value)); + } + + return TotalBytes; + }; + + curl_easy_setopt(H, CURLOPT_HEADERFUNCTION, static_cast<size_t (*)(char*, size_t, size_t, void*)>(HeaderCb)); + curl_easy_setopt(H, CURLOPT_HEADERDATA, &DlHdrData); + + // Write callback that directs data to file or string + struct DownloadWriteCallbackData + { + std::string* PayloadString = nullptr; + std::unique_ptr<detail::TempPayloadFile>* PayloadFile = nullptr; + std::function<bool()>* CheckIfAbortFunction = nullptr; + const std::filesystem::path* TempFolderPath = nullptr; + LoggerRef Log; + detail::MultipartBoundaryParser* BoundaryParser = nullptr; + bool* IsMultiRange = nullptr; + }; + + DownloadWriteCallbackData DlWriteData; + DlWriteData.PayloadString = &PayloadString; + DlWriteData.PayloadFile = &PayloadFile; + DlWriteData.CheckIfAbortFunction = m_CheckIfAbortFunction ? &m_CheckIfAbortFunction : nullptr; + DlWriteData.TempFolderPath = &TempFolderPath; + DlWriteData.Log = m_Log; + DlWriteData.BoundaryParser = &BoundaryParser; + DlWriteData.IsMultiRange = &IsMultiRangeResponse; + + auto WriteCb = [](char* Ptr, size_t Size, size_t Nmemb, void* UserData) -> size_t { + auto* Data = static_cast<DownloadWriteCallbackData*>(UserData); + size_t TotalBytes = Size * Nmemb; + + if (Data->CheckIfAbortFunction && *Data->CheckIfAbortFunction && (*Data->CheckIfAbortFunction)()) + { + return 0; + } + + if (*Data->IsMultiRange) + { + Data->BoundaryParser->ParseInput(std::string_view(Ptr, TotalBytes)); + } + + if (*Data->PayloadFile) + { + std::error_code Ec = (*Data->PayloadFile)->Write(std::string_view(Ptr, TotalBytes)); + if (Ec) + { + auto Log = [&]() -> LoggerRef { return Data->Log; }; + ZEN_WARN("Failed to write to temp file in '{}' for HttpClient::Download. Reason: {}", + Data->TempFolderPath->string(), + Ec.message()); + return 0; + } + } + else + { + Data->PayloadString->append(Ptr, TotalBytes); + } + return TotalBytes; + }; + + curl_easy_setopt(H, CURLOPT_WRITEFUNCTION, static_cast<size_t (*)(char*, size_t, size_t, void*)>(WriteCb)); + curl_easy_setopt(H, CURLOPT_WRITEDATA, &DlWriteData); + + CurlResult Res = Sess.Perform(); + Res.Headers = std::move(ResponseHeaders); + + // Handle resume logic + if (m_ConnectionSettings.AllowResume) + { + auto SupportsRanges = [](const CurlResult& R) -> bool { + for (const auto& [K, V] : R.Headers) + { + if (StrCaseCompare(K, "Content-Range") == 0) + { + return true; + } + if (StrCaseCompare(K, "Accept-Ranges") == 0) + { + return V == "bytes"sv; + } + } + return false; + }; + + auto ShouldResumeCheck = [&SupportsRanges, &IsMultiRangeResponse](const CurlResult& R) -> bool { + if (IsMultiRangeResponse) + { + return false; + } + if (ShouldRetry(R)) + { + return SupportsRanges(R); + } + return false; + }; + + if (ShouldResumeCheck(Res)) + { + // Find Content-Length + std::string ContentLengthValue; + for (const auto& [K, V] : Res.Headers) + { + if (StrCaseCompare(K, "Content-Length") == 0) + { + ContentLengthValue = V; + break; + } + } + + if (!ContentLengthValue.empty()) + { + uint64_t ContentLength = RequestedContentLength; + if (ContentLength == uint64_t(-1)) + { + if (auto ParsedContentLength = ParseInt<int64_t>(ContentLengthValue); ParsedContentLength.has_value()) + { + ContentLength = ParsedContentLength.value(); + } + } + + KeyValueMap HeadersWithRange(AdditionalHeader); + uint8_t ResumeAttempt = 0; + do + { + uint64_t DownloadedSize = PayloadFile ? PayloadFile->GetSize() : PayloadString.length(); + + std::string Range = fmt::format("bytes={}-{}", DownloadedSize, DownloadedSize + ContentLength - 1); + if (auto RangeIt = HeadersWithRange.Entries.find("Range"); RangeIt != HeadersWithRange.Entries.end()) + { + if (RangeIt->second == Range) + { + break; // No progress, abort + } + } + HeadersWithRange.Entries.insert_or_assign("Range", Range); + + Session ResumeSess = AllocSession(Url, {}); + CURL* ResumeH = ResumeSess.Get(); + + ResumeSess.SetHeaders(BuildHeaderList(HeadersWithRange, m_SessionId, GetAccessToken())); + curl_easy_setopt(ResumeH, CURLOPT_HTTPGET, 1L); + + std::vector<std::pair<std::string, std::string>> ResumeHeaders; + + struct ResumeHeaderCbData + { + std::vector<std::pair<std::string, std::string>>* Headers = nullptr; + std::unique_ptr<detail::TempPayloadFile>* PayloadFile = nullptr; + std::string* PayloadString = nullptr; + }; + + ResumeHeaderCbData ResumeHdrData; + ResumeHdrData.Headers = &ResumeHeaders; + ResumeHdrData.PayloadFile = &PayloadFile; + ResumeHdrData.PayloadString = &PayloadString; + + auto ResumeHeaderCb = [](char* Buffer, size_t Size, size_t Nmemb, void* UserData) -> size_t { + auto* Data = static_cast<ResumeHeaderCbData*>(UserData); + size_t TotalBytes = Size * Nmemb; + + auto Header = ParseHeaderLine(std::string_view(Buffer, TotalBytes)); + if (!Header) + { + return TotalBytes; + } + auto& [Key, Value] = *Header; + + if (StrCaseCompare(Key, "Content-Range") == 0) + { + if (Value.starts_with("bytes "sv)) + { + size_t RangeStartEnd = Value.find('-', 6); + if (RangeStartEnd != std::string_view::npos) + { + const std::optional<uint64_t> Start = ParseInt<uint64_t>(Value.substr(6, RangeStartEnd - 6)); + if (Start) + { + uint64_t DownloadedSize = + *Data->PayloadFile ? (*Data->PayloadFile)->GetSize() : Data->PayloadString->length(); + if (Start.value() == DownloadedSize) + { + Data->Headers->emplace_back(std::string(Key), std::string(Value)); + return TotalBytes; + } + else if (Start.value() > DownloadedSize) + { + return 0; + } + if (*Data->PayloadFile) + { + (*Data->PayloadFile)->ResetWritePos(Start.value()); + } + else + { + *Data->PayloadString = Data->PayloadString->substr(0, Start.value()); + } + Data->Headers->emplace_back(std::string(Key), std::string(Value)); + return TotalBytes; + } + } + } + return 0; + } + + Data->Headers->emplace_back(std::string(Key), std::string(Value)); + return TotalBytes; + }; + + curl_easy_setopt(ResumeH, + CURLOPT_HEADERFUNCTION, + static_cast<size_t (*)(char*, size_t, size_t, void*)>(ResumeHeaderCb)); + curl_easy_setopt(ResumeH, CURLOPT_HEADERDATA, &ResumeHdrData); + curl_easy_setopt(ResumeH, + CURLOPT_WRITEFUNCTION, + static_cast<size_t (*)(char*, size_t, size_t, void*)>(WriteCb)); + curl_easy_setopt(ResumeH, CURLOPT_WRITEDATA, &DlWriteData); + + Res = ResumeSess.Perform(); + Res.Headers = std::move(ResumeHeaders); + + ResumeAttempt++; + } while (ResumeAttempt < m_ConnectionSettings.RetryCount && ShouldResumeCheck(Res)); + } + } + } + + if (!PayloadString.empty()) + { + Res.Body = std::move(PayloadString); + } + + return Res; + }, + PayloadFile); + + return CommonResponse(m_SessionId, + std::move(Result), + PayloadFile ? PayloadFile->DetachToIoBuffer() : IoBuffer{}, + std::move(BoundaryParser.Boundaries)); +} + +} // namespace zen diff --git a/src/zenhttp/clients/httpclientcurl.h b/src/zenhttp/clients/httpclientcurl.h new file mode 100644 index 000000000..b7fa52e6c --- /dev/null +++ b/src/zenhttp/clients/httpclientcurl.h @@ -0,0 +1,137 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "httpclientcommon.h" + +#include <zencore/logging.h> +#include <zenhttp/httpclient.h> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <curl/curl.h> +ZEN_THIRD_PARTY_INCLUDES_END + +namespace zen { + +class CurlHttpClient : public HttpClientBase +{ +public: + CurlHttpClient(std::string_view BaseUri, const HttpClientSettings& ConnectionSettings, std::function<bool()>&& CheckIfAbortFunction); + ~CurlHttpClient(); + + // HttpClientBase + + [[nodiscard]] virtual Response Put(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader = {}) override; + [[nodiscard]] virtual Response Put(std::string_view Url, const KeyValueMap& Parameters = {}) override; + [[nodiscard]] virtual Response Get(std::string_view Url, + const KeyValueMap& AdditionalHeader = {}, + const KeyValueMap& Parameters = {}) override; + [[nodiscard]] virtual Response Head(std::string_view Url, const KeyValueMap& AdditionalHeader = {}) override; + [[nodiscard]] virtual Response Delete(std::string_view Url, const KeyValueMap& AdditionalHeader = {}) override; + [[nodiscard]] virtual Response Post(std::string_view Url, + const KeyValueMap& AdditionalHeader = {}, + const KeyValueMap& Parameters = {}) override; + [[nodiscard]] virtual Response Post(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader = {}) override; + [[nodiscard]] virtual Response Post(std::string_view Url, + const IoBuffer& Payload, + ZenContentType ContentType, + const KeyValueMap& AdditionalHeader = {}) override; + [[nodiscard]] virtual Response Post(std::string_view Url, + CbObject Payload, + const KeyValueMap& AdditionalHeader = {}, + const std::filesystem::path& TempFolderPath = {}) override; + [[nodiscard]] virtual Response Post(std::string_view Url, CbPackage Payload, const KeyValueMap& AdditionalHeader = {}) override; + [[nodiscard]] virtual Response Post(std::string_view Url, + const CompositeBuffer& Payload, + ZenContentType ContentType, + const KeyValueMap& AdditionalHeader = {}) override; + [[nodiscard]] virtual Response Upload(std::string_view Url, const IoBuffer& Payload, const KeyValueMap& AdditionalHeader = {}) override; + [[nodiscard]] virtual Response Upload(std::string_view Url, + const CompositeBuffer& Payload, + ZenContentType ContentType, + const KeyValueMap& AdditionalHeader = {}) override; + + [[nodiscard]] virtual Response Download(std::string_view Url, + const std::filesystem::path& TempFolderPath, + const KeyValueMap& AdditionalHeader = {}) override; + + [[nodiscard]] virtual Response TransactPackage(std::string_view Url, + CbPackage Package, + const KeyValueMap& AdditionalHeader = {}) override; + +private: + struct CurlResult + { + long StatusCode = 0; + std::string Body; + std::vector<std::pair<std::string, std::string>> Headers; + double ElapsedSeconds = 0; + int64_t UploadedBytes = 0; + int64_t DownloadedBytes = 0; + CURLcode ErrorCode = CURLE_OK; + std::string ErrorMessage; + }; + + struct Session + { + Session(CurlHttpClient* InOuter, CURL* InHandle) : Outer(InOuter), Handle(InHandle) {} + ~Session(); + + CURL* Get() const { return Handle; } + + // Takes ownership of the curl_slist and sets it on the handle. + // The list is freed automatically when the Session is destroyed. + void SetHeaders(curl_slist* Headers); + + // Low-level perform: executes the request and collects status/timing. + CurlResult Perform(); + + // Sets up standard write+header callbacks, performs the request, and + // moves the collected body and headers into the returned CurlResult. + CurlResult PerformWithResponseCallbacks(); + + LoggerRef Log() { return Outer->Log(); } + + private: + CurlHttpClient* Outer; + CURL* Handle; + curl_slist* HeaderList = nullptr; + + Session(Session&&) = delete; + Session& operator=(Session&&) = delete; + }; + + Session AllocSession(std::string_view ResourcePath, const KeyValueMap& Parameters); + + RwLock m_SessionLock; + std::vector<CURL*> m_Sessions; + + void ReleaseSession(CURL* Handle); + + CurlResult DoWithRetry(std::string_view SessionId, + std::function<CurlResult()>&& Func, + std::unique_ptr<detail::TempPayloadFile>& PayloadFile); + CurlResult DoWithRetry( + std::string_view SessionId, + std::function<CurlResult()>&& Func, + std::function<bool(CurlResult&)>&& Validate = [](CurlResult&) { return true; }); + + bool ValidatePayload(CurlResult& Result, std::unique_ptr<detail::TempPayloadFile>& PayloadFile); + + static bool ShouldRetry(const CurlResult& Result); + + bool ShouldLogErrorCode(HttpResponseCode ResponseCode) const; + + HttpClient::Response CommonResponse(std::string_view SessionId, + CurlResult&& Result, + IoBuffer&& Payload, + std::vector<HttpClient::Response::MultipartBoundary>&& BoundaryPositions = {}); + + HttpClient::Response ResponseWithPayload(std::string_view SessionId, + CurlResult&& Result, + const HttpResponseCode WorkResponseCode, + IoBuffer&& Payload, + std::vector<HttpClient::Response::MultipartBoundary>&& BoundaryPositions); +}; + +} // namespace zen diff --git a/src/zenhttp/clients/httpwsclient.cpp b/src/zenhttp/clients/httpwsclient.cpp new file mode 100644 index 000000000..fbae9f5fe --- /dev/null +++ b/src/zenhttp/clients/httpwsclient.cpp @@ -0,0 +1,641 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zenhttp/httpwsclient.h> + +#include "../servers/wsframecodec.h" + +#include <zencore/base64.h> +#include <zencore/filesystem.h> +#include <zencore/fmtutils.h> +#include <zencore/logging.h> +#include <zencore/string.h> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <asio.hpp> +#if defined(ASIO_HAS_LOCAL_SOCKETS) +# include <asio/local/stream_protocol.hpp> +#endif +ZEN_THIRD_PARTY_INCLUDES_END + +#include <deque> +#include <random> +#include <thread> + +namespace zen { + +////////////////////////////////////////////////////////////////////////// + +struct HttpWsClient::Impl +{ + Impl(std::string_view Url, IWsClientHandler& Handler, const HttpWsClientSettings& Settings) + : m_Handler(Handler) + , m_Settings(Settings) + , m_Log(logging::Get(Settings.LogCategory)) + , m_OwnedIoContext(std::make_unique<asio::io_context>()) + , m_IoContext(*m_OwnedIoContext) + { + ParseUrl(Url); + } + + Impl(std::string_view Url, IWsClientHandler& Handler, asio::io_context& IoContext, const HttpWsClientSettings& Settings) + : m_Handler(Handler) + , m_Settings(Settings) + , m_Log(logging::Get(Settings.LogCategory)) + , m_IoContext(IoContext) + { + ParseUrl(Url); + } + + ~Impl() + { + // Release work guard so io_context::run() can return + m_WorkGuard.reset(); + + // Close the socket to cancel pending async ops + CloseSocket(); + + if (m_IoThread.joinable()) + { + m_IoThread.join(); + } + } + + void CloseSocket() + { + asio::error_code Ec; +#if defined(ASIO_HAS_LOCAL_SOCKETS) + if (m_UnixSocket) + { + m_UnixSocket->close(Ec); + return; + } +#endif + if (m_TcpSocket) + { + m_TcpSocket->close(Ec); + } + } + + template<typename Fn> + void WithSocket(Fn&& Func) + { +#if defined(ASIO_HAS_LOCAL_SOCKETS) + if (m_UnixSocket) + { + Func(*m_UnixSocket); + return; + } +#endif + Func(*m_TcpSocket); + } + + void ParseUrl(std::string_view Url) + { + // Expected format: ws://host:port/path + if (Url.substr(0, 5) == "ws://") + { + Url.remove_prefix(5); + } + + auto SlashPos = Url.find('/'); + std::string_view HostPort; + if (SlashPos != std::string_view::npos) + { + HostPort = Url.substr(0, SlashPos); + m_Path = std::string(Url.substr(SlashPos)); + } + else + { + HostPort = Url; + m_Path = "/"; + } + + auto ColonPos = HostPort.find(':'); + if (ColonPos != std::string_view::npos) + { + m_Host = std::string(HostPort.substr(0, ColonPos)); + m_Port = std::string(HostPort.substr(ColonPos + 1)); + } + else + { + m_Host = std::string(HostPort); + m_Port = "80"; + } + } + + void Connect() + { + if (m_OwnedIoContext) + { + m_WorkGuard.emplace(m_IoContext.get_executor()); + m_IoThread = std::thread([this] { m_IoContext.run(); }); + } + +#if defined(ASIO_HAS_LOCAL_SOCKETS) + if (!m_Settings.UnixSocketPath.empty()) + { + asio::post(m_IoContext, [this] { DoConnectUnix(); }); + return; + } +#endif + + asio::post(m_IoContext, [this] { DoResolve(); }); + } + +#if defined(ASIO_HAS_LOCAL_SOCKETS) + void DoConnectUnix() + { + m_UnixSocket = std::make_unique<asio::local::stream_protocol::socket>(m_IoContext); + + // Start connect timeout timer + m_Timer = std::make_unique<asio::steady_timer>(m_IoContext, m_Settings.ConnectTimeout); + m_Timer->async_wait([this](const asio::error_code& Ec) { + if (!Ec && !m_IsOpen.load(std::memory_order_relaxed)) + { + ZEN_LOG_DEBUG(m_Log, "WebSocket unix connect timeout for {}", m_Settings.UnixSocketPath); + CloseSocket(); + } + }); + + asio::local::stream_protocol::endpoint Endpoint(PathToUtf8(m_Settings.UnixSocketPath)); + m_UnixSocket->async_connect(Endpoint, [this](const asio::error_code& Ec) { + if (Ec) + { + m_Timer->cancel(); + ZEN_LOG_DEBUG(m_Log, "WebSocket unix connect failed for {}: {}", m_Settings.UnixSocketPath, Ec.message()); + m_Handler.OnWsClose(1006, "connect failed"); + return; + } + + DoHandshake(); + }); + } +#endif + + void DoResolve() + { + m_Resolver = std::make_unique<asio::ip::tcp::resolver>(m_IoContext); + + m_Resolver->async_resolve(m_Host, m_Port, [this](const asio::error_code& Ec, asio::ip::tcp::resolver::results_type Results) { + if (Ec) + { + ZEN_LOG_DEBUG(m_Log, "WebSocket resolve failed for {}:{}: {}", m_Host, m_Port, Ec.message()); + m_Handler.OnWsClose(1006, "resolve failed"); + return; + } + + DoConnect(Results); + }); + } + + void DoConnect(const asio::ip::tcp::resolver::results_type& Endpoints) + { + m_TcpSocket = std::make_unique<asio::ip::tcp::socket>(m_IoContext); + + // Start connect timeout timer + m_Timer = std::make_unique<asio::steady_timer>(m_IoContext, m_Settings.ConnectTimeout); + m_Timer->async_wait([this](const asio::error_code& Ec) { + if (!Ec && !m_IsOpen.load(std::memory_order_relaxed)) + { + ZEN_LOG_DEBUG(m_Log, "WebSocket connect timeout for {}:{}", m_Host, m_Port); + CloseSocket(); + } + }); + + asio::async_connect(*m_TcpSocket, Endpoints, [this](const asio::error_code& Ec, const asio::ip::tcp::endpoint&) { + if (Ec) + { + m_Timer->cancel(); + ZEN_LOG_DEBUG(m_Log, "WebSocket connect failed for {}:{}: {}", m_Host, m_Port, Ec.message()); + m_Handler.OnWsClose(1006, "connect failed"); + return; + } + + DoHandshake(); + }); + } + + void DoHandshake() + { + // Generate random Sec-WebSocket-Key (16 random bytes, base64 encoded) + uint8_t KeyBytes[16]; + { + static thread_local std::mt19937 s_Rng(std::random_device{}()); + for (int i = 0; i < 4; ++i) + { + uint32_t Val = s_Rng(); + std::memcpy(KeyBytes + i * 4, &Val, 4); + } + } + + char KeyBase64[Base64::GetEncodedDataSize(16) + 1]; + uint32_t KeyLen = Base64::Encode(KeyBytes, 16, KeyBase64); + KeyBase64[KeyLen] = '\0'; + m_WebSocketKey = std::string(KeyBase64, KeyLen); + + // Build the HTTP upgrade request + ExtendableStringBuilder<512> Request; + Request << "GET " << m_Path << " HTTP/1.1\r\n" + << "Host: " << m_Host << ":" << m_Port << "\r\n" + << "Upgrade: websocket\r\n" + << "Connection: Upgrade\r\n" + << "Sec-WebSocket-Key: " << m_WebSocketKey << "\r\n" + << "Sec-WebSocket-Version: 13\r\n"; + + // Add Authorization header if access token provider is set + if (m_Settings.AccessTokenProvider) + { + HttpClientAccessToken Token = (*m_Settings.AccessTokenProvider)(); + if (Token.IsValid()) + { + Request << "Authorization: Bearer " << Token.Value << "\r\n"; + } + } + + Request << "\r\n"; + + std::string_view ReqStr = Request.ToView(); + + m_HandshakeBuffer = std::make_shared<std::string>(ReqStr); + + WithSocket([this](auto& Socket) { + asio::async_write(Socket, + asio::buffer(m_HandshakeBuffer->data(), m_HandshakeBuffer->size()), + [this](const asio::error_code& Ec, std::size_t) { + if (Ec) + { + m_Timer->cancel(); + ZEN_LOG_DEBUG(m_Log, "WebSocket handshake write failed: {}", Ec.message()); + m_Handler.OnWsClose(1006, "handshake write failed"); + return; + } + + DoReadHandshakeResponse(); + }); + }); + } + + void DoReadHandshakeResponse() + { + WithSocket([this](auto& Socket) { + asio::async_read_until(Socket, m_ReadBuffer, "\r\n\r\n", [this](const asio::error_code& Ec, std::size_t) { + m_Timer->cancel(); + + if (Ec) + { + ZEN_LOG_DEBUG(m_Log, "WebSocket handshake read failed: {}", Ec.message()); + m_Handler.OnWsClose(1006, "handshake read failed"); + return; + } + + // Parse the response + const auto& Data = m_ReadBuffer.data(); + std::string Response(asio::buffers_begin(Data), asio::buffers_end(Data)); + + // Consume the headers from the read buffer (any extra data stays for frame parsing) + auto HeaderEnd = Response.find("\r\n\r\n"); + if (HeaderEnd != std::string::npos) + { + m_ReadBuffer.consume(HeaderEnd + 4); + } + + // Validate 101 response + if (Response.find("101") == std::string::npos) + { + ZEN_LOG_DEBUG(m_Log, "WebSocket handshake rejected (no 101): {}", Response.substr(0, 80)); + m_Handler.OnWsClose(1006, "handshake rejected"); + return; + } + + // Validate Sec-WebSocket-Accept + std::string ExpectedAccept = WsFrameCodec::ComputeAcceptKey(m_WebSocketKey); + if (Response.find(ExpectedAccept) == std::string::npos) + { + ZEN_LOG_DEBUG(m_Log, "WebSocket handshake: invalid Sec-WebSocket-Accept"); + m_Handler.OnWsClose(1006, "invalid accept key"); + return; + } + + m_IsOpen.store(true); + m_Handler.OnWsOpen(); + EnqueueRead(); + }); + }); + } + + ////////////////////////////////////////////////////////////////////////// + // + // Read loop + // + + void EnqueueRead() + { + if (!m_IsOpen.load(std::memory_order_relaxed)) + { + return; + } + + WithSocket([this](auto& Socket) { + asio::async_read(Socket, m_ReadBuffer, asio::transfer_at_least(1), [this](const asio::error_code& Ec, std::size_t) { + OnDataReceived(Ec); + }); + }); + } + + void OnDataReceived(const asio::error_code& Ec) + { + if (Ec) + { + if (Ec != asio::error::eof && Ec != asio::error::operation_aborted) + { + ZEN_LOG_DEBUG(m_Log, "WebSocket read error: {}", Ec.message()); + } + + if (m_IsOpen.exchange(false)) + { + m_Handler.OnWsClose(1006, "connection lost"); + } + return; + } + + ProcessReceivedData(); + + if (m_IsOpen.load(std::memory_order_relaxed)) + { + EnqueueRead(); + } + } + + void ProcessReceivedData() + { + while (m_ReadBuffer.size() > 0) + { + const auto& InputBuffer = m_ReadBuffer.data(); + const auto* RawData = static_cast<const uint8_t*>(InputBuffer.data()); + const auto Size = InputBuffer.size(); + + WsFrameParseResult Frame = WsFrameCodec::TryParseFrame(RawData, Size); + if (!Frame.IsValid) + { + break; + } + + m_ReadBuffer.consume(Frame.BytesConsumed); + + switch (Frame.Opcode) + { + case WebSocketOpcode::kText: + case WebSocketOpcode::kBinary: + { + WebSocketMessage Msg; + Msg.Opcode = Frame.Opcode; + Msg.Payload = IoBuffer(IoBuffer::Clone, Frame.Payload.data(), Frame.Payload.size()); + m_Handler.OnWsMessage(Msg); + break; + } + + case WebSocketOpcode::kPing: + { + // Auto-respond with masked pong + std::vector<uint8_t> PongFrame = WsFrameCodec::BuildMaskedFrame(WebSocketOpcode::kPong, Frame.Payload); + EnqueueWrite(std::move(PongFrame)); + break; + } + + case WebSocketOpcode::kPong: + break; + + case WebSocketOpcode::kClose: + { + uint16_t Code = 1000; + std::string_view Reason; + + if (Frame.Payload.size() >= 2) + { + Code = (uint16_t(Frame.Payload[0]) << 8) | uint16_t(Frame.Payload[1]); + if (Frame.Payload.size() > 2) + { + Reason = + std::string_view(reinterpret_cast<const char*>(Frame.Payload.data() + 2), Frame.Payload.size() - 2); + } + } + + // Echo masked close frame if we haven't sent one yet + if (!m_CloseSent.exchange(true)) + { + std::vector<uint8_t> CloseFrame = WsFrameCodec::BuildMaskedCloseFrame(Code); + EnqueueWrite(std::move(CloseFrame)); + } + + m_IsOpen.store(false); + m_Handler.OnWsClose(Code, Reason); + return; + } + + default: + ZEN_LOG_WARN(m_Log, "Unknown WebSocket opcode: {:#x}", static_cast<uint8_t>(Frame.Opcode)); + break; + } + } + } + + ////////////////////////////////////////////////////////////////////////// + // + // Write queue + // + + void EnqueueWrite(std::vector<uint8_t> Frame) + { + bool ShouldFlush = false; + + m_WriteLock.WithExclusiveLock([&] { + m_WriteQueue.push_back(std::move(Frame)); + if (!m_IsWriting) + { + m_IsWriting = true; + ShouldFlush = true; + } + }); + + if (ShouldFlush) + { + FlushWriteQueue(); + } + } + + void FlushWriteQueue() + { + std::vector<uint8_t> Frame; + + m_WriteLock.WithExclusiveLock([&] { + if (m_WriteQueue.empty()) + { + m_IsWriting = false; + return; + } + Frame = std::move(m_WriteQueue.front()); + m_WriteQueue.pop_front(); + }); + + if (Frame.empty()) + { + return; + } + + auto OwnedFrame = std::make_shared<std::vector<uint8_t>>(std::move(Frame)); + + WithSocket([this, OwnedFrame](auto& Socket) { + asio::async_write(Socket, + asio::buffer(OwnedFrame->data(), OwnedFrame->size()), + [this, OwnedFrame](const asio::error_code& Ec, std::size_t) { OnWriteComplete(Ec); }); + }); + } + + void OnWriteComplete(const asio::error_code& Ec) + { + if (Ec) + { + if (Ec != asio::error::operation_aborted) + { + ZEN_LOG_DEBUG(m_Log, "WebSocket write error: {}", Ec.message()); + } + + m_WriteLock.WithExclusiveLock([&] { + m_IsWriting = false; + m_WriteQueue.clear(); + }); + + if (m_IsOpen.exchange(false)) + { + m_Handler.OnWsClose(1006, "write error"); + } + return; + } + + FlushWriteQueue(); + } + + ////////////////////////////////////////////////////////////////////////// + // + // Public operations + // + + void SendText(std::string_view Text) + { + if (!m_IsOpen.load(std::memory_order_relaxed)) + { + return; + } + + std::span<const uint8_t> Payload(reinterpret_cast<const uint8_t*>(Text.data()), Text.size()); + std::vector<uint8_t> Frame = WsFrameCodec::BuildMaskedFrame(WebSocketOpcode::kText, Payload); + EnqueueWrite(std::move(Frame)); + } + + void SendBinary(std::span<const uint8_t> Data) + { + if (!m_IsOpen.load(std::memory_order_relaxed)) + { + return; + } + + std::vector<uint8_t> Frame = WsFrameCodec::BuildMaskedFrame(WebSocketOpcode::kBinary, Data); + EnqueueWrite(std::move(Frame)); + } + + void DoClose(uint16_t Code, std::string_view Reason) + { + if (!m_IsOpen.exchange(false)) + { + return; + } + + if (!m_CloseSent.exchange(true)) + { + std::vector<uint8_t> CloseFrame = WsFrameCodec::BuildMaskedCloseFrame(Code, Reason); + EnqueueWrite(std::move(CloseFrame)); + } + } + + IWsClientHandler& m_Handler; + HttpWsClientSettings m_Settings; + LoggerRef m_Log; + + std::string m_Host; + std::string m_Port; + std::string m_Path; + + // io_context: owned (standalone) or external (shared) + std::unique_ptr<asio::io_context> m_OwnedIoContext; + asio::io_context& m_IoContext; + std::optional<asio::executor_work_guard<asio::io_context::executor_type>> m_WorkGuard; + std::thread m_IoThread; + + // Connection state + std::unique_ptr<asio::ip::tcp::resolver> m_Resolver; + std::unique_ptr<asio::ip::tcp::socket> m_TcpSocket; +#if defined(ASIO_HAS_LOCAL_SOCKETS) + std::unique_ptr<asio::local::stream_protocol::socket> m_UnixSocket; +#endif + std::unique_ptr<asio::steady_timer> m_Timer; + asio::streambuf m_ReadBuffer; + std::string m_WebSocketKey; + std::shared_ptr<std::string> m_HandshakeBuffer; + + // Write queue + RwLock m_WriteLock; + std::deque<std::vector<uint8_t>> m_WriteQueue; + bool m_IsWriting = false; + + std::atomic<bool> m_IsOpen{false}; + std::atomic<bool> m_CloseSent{false}; +}; + +////////////////////////////////////////////////////////////////////////// + +HttpWsClient::HttpWsClient(std::string_view Url, IWsClientHandler& Handler, const HttpWsClientSettings& Settings) +: m_Impl(std::make_unique<Impl>(Url, Handler, Settings)) +{ +} + +HttpWsClient::HttpWsClient(std::string_view Url, + IWsClientHandler& Handler, + asio::io_context& IoContext, + const HttpWsClientSettings& Settings) +: m_Impl(std::make_unique<Impl>(Url, Handler, IoContext, Settings)) +{ +} + +HttpWsClient::~HttpWsClient() = default; + +void +HttpWsClient::Connect() +{ + m_Impl->Connect(); +} + +void +HttpWsClient::SendText(std::string_view Text) +{ + m_Impl->SendText(Text); +} + +void +HttpWsClient::SendBinary(std::span<const uint8_t> Data) +{ + m_Impl->SendBinary(Data); +} + +void +HttpWsClient::Close(uint16_t Code, std::string_view Reason) +{ + m_Impl->DoClose(Code, Reason); +} + +bool +HttpWsClient::IsOpen() const +{ + return m_Impl->m_IsOpen.load(std::memory_order_relaxed); +} + +} // namespace zen diff --git a/src/zenhttp/httpclient.cpp b/src/zenhttp/httpclient.cpp index d3b59df2b..9f49802a0 100644 --- a/src/zenhttp/httpclient.cpp +++ b/src/zenhttp/httpclient.cpp @@ -21,6 +21,8 @@ #include "clients/httpclientcommon.h" +#include <numeric> + #if ZEN_WITH_TESTS # include <zencore/scopeguard.h> # include <zencore/testing.h> @@ -34,9 +36,43 @@ namespace zen { +#if ZEN_WITH_CPR extern HttpClientBase* CreateCprHttpClient(std::string_view BaseUri, const HttpClientSettings& ConnectionSettings, std::function<bool()>&& CheckIfAbortFunction); +#endif + +extern HttpClientBase* CreateCurlHttpClient(std::string_view BaseUri, + const HttpClientSettings& ConnectionSettings, + std::function<bool()>&& CheckIfAbortFunction); + +static HttpClientBackend g_DefaultHttpClientBackend = HttpClientBackend::kCurl; + +void +SetDefaultHttpClientBackend(HttpClientBackend Backend) +{ + g_DefaultHttpClientBackend = Backend; +} + +void +SetDefaultHttpClientBackend(std::string_view Backend) +{ +#if ZEN_WITH_CPR + if (Backend == "cpr") + { + g_DefaultHttpClientBackend = HttpClientBackend::kCpr; + } + else +#endif + if (Backend == "curl") + { + g_DefaultHttpClientBackend = HttpClientBackend::kCurl; + } + else + { + g_DefaultHttpClientBackend = HttpClientBackend::kDefault; + } +} using namespace std::literals; @@ -102,6 +138,109 @@ HttpClientBase::GetAccessToken() ////////////////////////////////////////////////////////////////////////// +HttpClientError::ResponseClass +HttpClientError::GetResponseClass() const +{ + if (m_Error != HttpClientErrorCode::kOK) + { + switch (m_Error) + { + case HttpClientErrorCode::kConnectionFailure: + return ResponseClass::kHttpCantConnectError; + case HttpClientErrorCode::kHostResolutionFailure: + case HttpClientErrorCode::kProxyResolutionFailure: + return ResponseClass::kHttpNoHost; + case HttpClientErrorCode::kInternalError: + case HttpClientErrorCode::kNetworkReceiveError: + case HttpClientErrorCode::kNetworkSendFailure: + case HttpClientErrorCode::kOperationTimedOut: + return ResponseClass::kHttpTimeout; + case HttpClientErrorCode::kSSLConnectError: + case HttpClientErrorCode::kSSLCertificateError: + case HttpClientErrorCode::kSSLCACertError: + case HttpClientErrorCode::kGenericSSLError: + return ResponseClass::kHttpSLLError; + default: + return ResponseClass::kHttpOtherClientError; + } + } + else if (IsHttpSuccessCode(m_ResponseCode)) + { + return ResponseClass::kSuccess; + } + else + { + switch (m_ResponseCode) + { + case HttpResponseCode::Unauthorized: + return ResponseClass::kHttpUnauthorized; + case HttpResponseCode::NotFound: + return ResponseClass::kHttpNotFound; + case HttpResponseCode::Forbidden: + return ResponseClass::kHttpForbidden; + case HttpResponseCode::Conflict: + return ResponseClass::kHttpConflict; + case HttpResponseCode::InternalServerError: + return ResponseClass::kHttpInternalServerError; + case HttpResponseCode::ServiceUnavailable: + return ResponseClass::kHttpServiceUnavailable; + case HttpResponseCode::BadGateway: + return ResponseClass::kHttpBadGateway; + case HttpResponseCode::GatewayTimeout: + return ResponseClass::kHttpGatewayTimeout; + default: + if (m_ResponseCode >= HttpResponseCode::InternalServerError) + { + return ResponseClass::kHttpOtherServerError; + } + else + { + return ResponseClass::kHttpOtherClientError; + } + } + } +} + +////////////////////////////////////////////////////////////////////////// + +std::vector<std::pair<uint64_t, uint64_t>> +HttpClient::Response::GetRanges(std::span<const std::pair<uint64_t, uint64_t>> OffsetAndLengthPairs) const +{ + if (Ranges.empty()) + { + return {}; + } + + std::vector<std::pair<uint64_t, uint64_t>> Result; + Result.reserve(OffsetAndLengthPairs.size()); + + auto BoundaryIt = Ranges.begin(); + auto OffsetAndLengthPairIt = OffsetAndLengthPairs.begin(); + while (OffsetAndLengthPairIt != OffsetAndLengthPairs.end()) + { + uint64_t Offset = OffsetAndLengthPairIt->first; + uint64_t Length = OffsetAndLengthPairIt->second; + while (Offset >= BoundaryIt->RangeOffset + BoundaryIt->RangeLength) + { + BoundaryIt++; + if (BoundaryIt == Ranges.end()) + { + throw std::runtime_error("HttpClient::Response can not fulfill requested range"); + } + } + if (Offset + Length > BoundaryIt->RangeOffset + BoundaryIt->RangeLength || Offset < BoundaryIt->RangeOffset) + { + throw std::runtime_error("HttpClient::Response can not fulfill requested range"); + } + uint64_t OffsetIntoRange = Offset - BoundaryIt->RangeOffset; + uint64_t RangePayloadOffset = BoundaryIt->OffsetInPayload + OffsetIntoRange; + Result.emplace_back(std::make_pair(RangePayloadOffset, Length)); + + OffsetAndLengthPairIt++; + } + return Result; +} + CbObject HttpClient::Response::AsObject() const { @@ -182,7 +321,11 @@ HttpClient::Response::ErrorMessage(std::string_view Prefix) const { if (Error.has_value()) { - return fmt::format("{}{}HTTP error ({}) '{}'", Prefix, Prefix.empty() ? ""sv : ": "sv, Error->ErrorCode, Error->ErrorMessage); + return fmt::format("{}{}HTTP error ({}) '{}'", + Prefix, + Prefix.empty() ? ""sv : ": "sv, + static_cast<int>(Error->ErrorCode), + Error->ErrorMessage); } else if (StatusCode != HttpResponseCode::ImATeapot && (int)StatusCode) { @@ -205,19 +348,36 @@ HttpClient::Response::ThrowError(std::string_view ErrorPrefix) { if (!IsSuccess()) { - throw HttpClientError(ErrorMessage(ErrorPrefix), Error.has_value() ? Error.value().ErrorCode : 0, StatusCode); + throw HttpClientError(ErrorMessage(ErrorPrefix), + Error.has_value() ? Error.value().ErrorCode : HttpClientErrorCode::kOK, + StatusCode); } } ////////////////////////////////////////////////////////////////////////// HttpClient::HttpClient(std::string_view BaseUri, const HttpClientSettings& ConnectionSettings, std::function<bool()>&& CheckIfAbortFunction) -: m_BaseUri(BaseUri) +: m_Log(zen::logging::Get(ConnectionSettings.LogCategory)) +, m_BaseUri(BaseUri) , m_ConnectionSettings(ConnectionSettings) { m_SessionId = GetSessionIdString(); - m_Inner = CreateCprHttpClient(BaseUri, ConnectionSettings, std::move(CheckIfAbortFunction)); + HttpClientBackend EffectiveBackend = + ConnectionSettings.Backend != HttpClientBackend::kDefault ? ConnectionSettings.Backend : g_DefaultHttpClientBackend; + + switch (EffectiveBackend) + { +#if ZEN_WITH_CPR + case HttpClientBackend::kCpr: + m_Inner = CreateCprHttpClient(BaseUri, ConnectionSettings, std::move(CheckIfAbortFunction)); + break; +#endif + case HttpClientBackend::kCurl: + default: + m_Inner = CreateCurlHttpClient(BaseUri, ConnectionSettings, std::move(CheckIfAbortFunction)); + break; + } } HttpClient::~HttpClient() @@ -287,9 +447,12 @@ HttpClient::Post(std::string_view Url, const IoBuffer& Payload, ZenContentType C } HttpClient::Response -HttpClient::Post(std::string_view Url, CbObject Payload, const HttpClient::KeyValueMap& AdditionalHeader) +HttpClient::Post(std::string_view Url, + CbObject Payload, + const HttpClient::KeyValueMap& AdditionalHeader, + const std::filesystem::path& TempFolderPath) { - return m_Inner->Post(Url, Payload, AdditionalHeader); + return m_Inner->Post(Url, Payload, AdditionalHeader, TempFolderPath); } HttpClient::Response @@ -340,10 +503,55 @@ HttpClient::Authenticate() return m_Inner->Authenticate(); } +LatencyTestResult +MeasureLatency(HttpClient& Client, std::string_view Url) +{ + std::vector<double> MeasurementTimes; + std::string ErrorMessage; + + for (uint32_t AttemptCount = 0; AttemptCount < 20 && MeasurementTimes.size() < 5; AttemptCount++) + { + HttpClient::Response MeasureResponse = Client.Get(Url); + if (MeasureResponse.IsSuccess()) + { + MeasurementTimes.push_back(MeasureResponse.ElapsedSeconds); + Sleep(5); + } + else + { + ErrorMessage = MeasureResponse.ErrorMessage(fmt::format("Unable to measure latency using {}", Url)); + + // Connection-level failures (timeout, refused, DNS) mean the endpoint is unreachable. + // Bail out immediately — retrying will just burn the connect timeout each time. + if (MeasureResponse.Error && MeasureResponse.Error->IsConnectionError()) + { + break; + } + } + } + + if (MeasurementTimes.empty()) + { + return {.Success = false, .FailureReason = ErrorMessage}; + } + + if (MeasurementTimes.size() > 2) + { + std::sort(MeasurementTimes.begin(), MeasurementTimes.end()); + MeasurementTimes.pop_back(); // Remove the worst time + } + + double AverageLatency = std::accumulate(MeasurementTimes.begin(), MeasurementTimes.end(), 0.0) / MeasurementTimes.size(); + + return {.Success = true, .LatencySeconds = AverageLatency}; +} + ////////////////////////////////////////////////////////////////////////// #if ZEN_WITH_TESTS +TEST_SUITE_BEGIN("http.httpclient"); + TEST_CASE("responseformat") { using namespace std::literals; @@ -753,6 +961,8 @@ TEST_CASE("httpclient.password") AsioServer->RequestExit(); } } +TEST_SUITE_END(); + void httpclient_forcelink() { diff --git a/src/zenhttp/httpclient_test.cpp b/src/zenhttp/httpclient_test.cpp new file mode 100644 index 000000000..5f3ad2455 --- /dev/null +++ b/src/zenhttp/httpclient_test.cpp @@ -0,0 +1,1701 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zenhttp/httpclient.h> +#include <zenhttp/httpserver.h> + +#if ZEN_WITH_TESTS + +# include <zencore/compactbinarybuilder.h> +# include <zencore/compactbinaryutil.h> +# include <zencore/compositebuffer.h> +# include <zencore/filesystem.h> +# include <zencore/iobuffer.h> +# include <zencore/logging.h> +# include <zencore/scopeguard.h> +# include <zencore/session.h> +# include <zencore/testing.h> +# include <zencore/testutils.h> + +# include "servers/httpasio.h" + +# include <atomic> +# include <thread> + +ZEN_THIRD_PARTY_INCLUDES_START +# include <asio.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + +namespace zen { + +using namespace std::literals; + +////////////////////////////////////////////////////////////////////////// +// Test service + +class HttpClientTestService : public HttpService +{ +public: + HttpClientTestService() + { + m_Router.AddMatcher("statuscode", [](std::string_view Str) -> bool { + for (char C : Str) + { + if (C < '0' || C > '9') + { + return false; + } + } + return !Str.empty(); + }); + + m_Router.RegisterRoute( + "hello", + [](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "hello world"); }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "echo", + [](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + IoBuffer Body = HttpReq.ReadPayload(); + HttpContentType CT = HttpReq.RequestContentType(); + HttpReq.WriteResponse(HttpResponseCode::OK, CT, Body); + }, + HttpVerb::kPost | HttpVerb::kPut); + + m_Router.RegisterRoute( + "echo/headers", + [](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + std::string_view Auth = HttpReq.GetAuthorizationHeader(); + CbObjectWriter Writer; + if (!Auth.empty()) + { + Writer.AddString("Authorization", Auth); + } + HttpReq.WriteResponse(HttpResponseCode::OK, Writer.Save()); + }, + HttpVerb::kGet | HttpVerb::kPost); + + m_Router.RegisterRoute( + "echo/method", + [](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + std::string_view Method = ToString(HttpReq.RequestVerb()); + HttpReq.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Method); + }, + HttpVerb::kGet | HttpVerb::kPost | HttpVerb::kPut | HttpVerb::kDelete | HttpVerb::kHead); + + m_Router.RegisterRoute( + "json", + [](HttpRouterRequest& Req) { + CbObjectWriter Obj; + Obj.AddBool("ok", true); + Obj.AddString("message", "test"); + Req.ServerRequest().WriteResponse(HttpResponseCode::OK, Obj.Save()); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "nocontent", + [](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::NoContent); }, + HttpVerb::kGet | HttpVerb::kPost | HttpVerb::kPut | HttpVerb::kDelete); + + m_Router.RegisterRoute( + "created", + [](HttpRouterRequest& Req) { + Req.ServerRequest().WriteResponse(HttpResponseCode::Created, HttpContentType::kText, "resource created"); + }, + HttpVerb::kPost | HttpVerb::kPut); + + m_Router.RegisterRoute( + "content-type/text", + [](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "plain text"); }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "content-type/json", + [](HttpRouterRequest& Req) { + Req.ServerRequest().WriteResponse(HttpResponseCode::OK, HttpContentType::kJSON, "{\"key\":\"value\"}"); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "content-type/binary", + [](HttpRouterRequest& Req) { + uint8_t Data[] = {0xDE, 0xAD, 0xBE, 0xEF}; + IoBuffer Buf(IoBuffer::Clone, Data, sizeof(Data)); + Req.ServerRequest().WriteResponse(HttpResponseCode::OK, HttpContentType::kBinary, Buf); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "content-type/cbobject", + [](HttpRouterRequest& Req) { + CbObjectWriter Obj; + Obj.AddString("type", "cbobject"); + Req.ServerRequest().WriteResponse(HttpResponseCode::OK, Obj.Save()); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "auth/bearer", + [](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + std::string_view Auth = HttpReq.GetAuthorizationHeader(); + if (Auth.starts_with("Bearer ") && Auth.size() > 7) + { + HttpReq.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "authenticated"); + } + else + { + HttpReq.WriteResponse(HttpResponseCode::Unauthorized, HttpContentType::kText, "unauthorized"); + } + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "slow", + [](HttpRouterRequest& Req) { + Req.ServerRequest().WriteResponseAsync([](HttpServerRequest& Request) { + Sleep(2000); + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "slow response"); + }); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "large", + [](HttpRouterRequest& Req) { + constexpr size_t Size = 64 * 1024; + IoBuffer Buf(Size); + uint8_t* Ptr = static_cast<uint8_t*>(Buf.MutableData()); + for (size_t i = 0; i < Size; ++i) + { + Ptr[i] = static_cast<uint8_t>(i & 0xFF); + } + Req.ServerRequest().WriteResponse(HttpResponseCode::OK, HttpContentType::kBinary, Buf); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "status/{statuscode}", + [](HttpRouterRequest& Req) { + std::string_view CodeStr = Req.GetCapture(1); + int Code = std::stoi(std::string{CodeStr}); + const HttpResponseCode ResponseCode = static_cast<HttpResponseCode>(Code); + Req.ServerRequest().WriteResponse(ResponseCode); + }, + HttpVerb::kGet | HttpVerb::kPost | HttpVerb::kPut | HttpVerb::kDelete | HttpVerb::kHead); + + m_Router.RegisterRoute( + "attempt-counter", + [this](HttpRouterRequest& Req) { + uint32_t Count = m_AttemptCounter.fetch_add(1); + if (Count < m_FailCount) + { + Req.ServerRequest().WriteResponse(HttpResponseCode::ServiceUnavailable); + } + else + { + Req.ServerRequest().WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "success after retries"); + } + }, + HttpVerb::kGet); + } + + virtual const char* BaseUri() const override { return "/api/test/"; } + virtual void HandleRequest(HttpServerRequest& Request) override { m_Router.HandleRequest(Request); } + + void ResetAttemptCounter(uint32_t FailCount) + { + m_AttemptCounter.store(0); + m_FailCount = FailCount; + } + +private: + HttpRequestRouter m_Router; + std::atomic<uint32_t> m_AttemptCounter{0}; + uint32_t m_FailCount = 2; +}; + +////////////////////////////////////////////////////////////////////////// +// Test server fixture + +struct TestServerFixture +{ + HttpClientTestService TestService; + ScopedTemporaryDirectory TmpDir; + Ref<HttpServer> Server; + std::thread ServerThread; + int Port = -1; + + TestServerFixture() + { + Server = CreateHttpAsioServer(AsioConfig{}); + Port = Server->Initialize(0, TmpDir.Path()); + ZEN_ASSERT(Port != -1); + Server->RegisterService(TestService); + ServerThread = std::thread([this]() { Server->Run(false); }); + } + + ~TestServerFixture() + { + Server->RequestExit(); + if (ServerThread.joinable()) + { + ServerThread.join(); + } + Server->Close(); + } + + HttpClient MakeClient(HttpClientSettings Settings = {}) + { + return HttpClient(fmt::format("127.0.0.1:{}", Port), Settings, /*CheckIfAbortFunction*/ {}); + } +}; + +////////////////////////////////////////////////////////////////////////// +// Tests + +TEST_SUITE_BEGIN("http.httpclient"); + +TEST_CASE("httpclient.verbs") +{ + TestServerFixture Fixture; + HttpClient Client = Fixture.MakeClient(); + + SUBCASE("GET returns 200 with expected body") + { + HttpClient::Response Resp = Client.Get("/api/test/echo/method"); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), "GET"); + } + + SUBCASE("POST dispatches correctly") + { + HttpClient::Response Resp = Client.Post("/api/test/echo/method"); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), "POST"); + } + + SUBCASE("PUT dispatches correctly") + { + HttpClient::Response Resp = Client.Put("/api/test/echo/method"); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), "PUT"); + } + + SUBCASE("DELETE dispatches correctly") + { + HttpClient::Response Resp = Client.Delete("/api/test/echo/method"); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), "DELETE"); + } + + SUBCASE("HEAD returns 200 with empty body") + { + HttpClient::Response Resp = Client.Head("/api/test/echo/method"); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), ""sv); + } +} + +TEST_CASE("httpclient.get") +{ + TestServerFixture Fixture; + HttpClient Client = Fixture.MakeClient(); + + SUBCASE("simple GET with text response") + { + HttpClient::Response Resp = Client.Get("/api/test/hello"); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.StatusCode, HttpResponseCode::OK); + CHECK_EQ(Resp.AsText(), "hello world"); + } + + SUBCASE("GET with auth header via echo") + { + HttpClient::Response Resp = + Client.Get("/api/test/echo/headers", std::pair<std::string, std::string>("Authorization", "Bearer test-token-123")); + CHECK(Resp.IsSuccess()); + CbObject Obj = Resp.AsObject(); + CHECK_EQ(Obj["Authorization"].AsString(), "Bearer test-token-123"); + } + + SUBCASE("GET returning CbObject") + { + HttpClient::Response Resp = Client.Get("/api/test/json"); + CHECK(Resp.IsSuccess()); + CbObject Obj = Resp.AsObject(); + CHECK(Obj["ok"].AsBool() == true); + CHECK_EQ(Obj["message"].AsString(), "test"); + } + + SUBCASE("GET large payload") + { + HttpClient::Response Resp = Client.Get("/api/test/large"); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.ResponsePayload.GetSize(), 64u * 1024u); + + const uint8_t* Data = static_cast<const uint8_t*>(Resp.ResponsePayload.GetData()); + bool Valid = true; + for (size_t i = 0; i < 64 * 1024; ++i) + { + if (Data[i] != static_cast<uint8_t>(i & 0xFF)) + { + Valid = false; + break; + } + } + CHECK(Valid); + } +} + +TEST_CASE("httpclient.post") +{ + TestServerFixture Fixture; + HttpClient Client = Fixture.MakeClient(); + + SUBCASE("POST with IoBuffer payload echo round-trip") + { + const char* Payload = "test payload data"; + IoBuffer Buf(IoBuffer::Clone, Payload, strlen(Payload)); + Buf.SetContentType(ZenContentType::kText); + + HttpClient::Response Resp = Client.Post("/api/test/echo", Buf); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), "test payload data"); + } + + SUBCASE("POST with IoBuffer and explicit content type") + { + const char* Payload = "{\"key\":\"value\"}"; + IoBuffer Buf(IoBuffer::Clone, Payload, strlen(Payload)); + + HttpClient::Response Resp = Client.Post("/api/test/echo", Buf, ZenContentType::kJSON); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), "{\"key\":\"value\"}"); + } + + SUBCASE("POST with CbObject payload round-trip") + { + CbObjectWriter Writer; + Writer.AddBool("enabled", true); + Writer.AddString("name", "testobj"); + CbObject Obj = Writer.Save(); + + HttpClient::Response Resp = Client.Post("/api/test/echo", Obj); + CHECK(Resp.IsSuccess()); + CbObject RoundTripped = Resp.AsObject(); + CHECK(RoundTripped["enabled"].AsBool() == true); + CHECK_EQ(RoundTripped["name"].AsString(), "testobj"); + } + + SUBCASE("POST with CompositeBuffer payload") + { + const char* Part1 = "hello "; + const char* Part2 = "composite"; + IoBuffer Buf1(IoBuffer::Clone, Part1, strlen(Part1)); + IoBuffer Buf2(IoBuffer::Clone, Part2, strlen(Part2)); + + SharedBuffer Seg1{Buf1}; + SharedBuffer Seg2{Buf2}; + CompositeBuffer Composite{std::move(Seg1), std::move(Seg2)}; + + HttpClient::Response Resp = Client.Post("/api/test/echo", Composite, ZenContentType::kText); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), "hello composite"); + } + + SUBCASE("POST with custom headers") + { + HttpClient::Response Resp = Client.Post("/api/test/echo/headers", HttpClient::KeyValueMap{}, HttpClient::KeyValueMap{}); + CHECK(Resp.IsSuccess()); + } + + SUBCASE("POST with empty body to nocontent endpoint") + { + HttpClient::Response Resp = Client.Post("/api/test/nocontent"); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.StatusCode, HttpResponseCode::NoContent); + } +} + +TEST_CASE("httpclient.put") +{ + TestServerFixture Fixture; + HttpClient Client = Fixture.MakeClient(); + + SUBCASE("PUT with IoBuffer payload echo round-trip") + { + const char* Payload = "put payload data"; + IoBuffer Buf(IoBuffer::Clone, Payload, strlen(Payload)); + Buf.SetContentType(ZenContentType::kText); + + HttpClient::Response Resp = Client.Put("/api/test/echo", Buf); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), "put payload data"); + } + + SUBCASE("PUT with parameters only") + { + HttpClient::Response Resp = Client.Put("/api/test/nocontent"); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.StatusCode, HttpResponseCode::NoContent); + } + + SUBCASE("PUT to created endpoint") + { + const char* Payload = "new resource"; + IoBuffer Buf(IoBuffer::Clone, Payload, strlen(Payload)); + Buf.SetContentType(ZenContentType::kText); + + HttpClient::Response Resp = Client.Put("/api/test/created", Buf); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.StatusCode, HttpResponseCode::Created); + CHECK_EQ(Resp.AsText(), "resource created"); + } +} + +TEST_CASE("httpclient.upload") +{ + TestServerFixture Fixture; + HttpClient Client = Fixture.MakeClient(); + + SUBCASE("Upload IoBuffer") + { + constexpr size_t Size = 128 * 1024; + IoBuffer Blob = CreateSemiRandomBlob(Size); + + HttpClient::Response Resp = Client.Upload("/api/test/echo", Blob); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.ResponsePayload.GetSize(), Size); + } + + SUBCASE("Upload CompositeBuffer") + { + IoBuffer Buf1 = CreateSemiRandomBlob(32 * 1024); + IoBuffer Buf2 = CreateSemiRandomBlob(32 * 1024); + + SharedBuffer Seg1{Buf1}; + SharedBuffer Seg2{Buf2}; + CompositeBuffer Composite{std::move(Seg1), std::move(Seg2)}; + + HttpClient::Response Resp = Client.Upload("/api/test/echo", Composite, ZenContentType::kBinary); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.ResponsePayload.GetSize(), 64u * 1024u); + } +} + +TEST_CASE("httpclient.download") +{ + TestServerFixture Fixture; + ScopedTemporaryDirectory DownloadDir; + + SUBCASE("Download small payload stays in memory") + { + HttpClient Client = Fixture.MakeClient(); + + HttpClient::Response Resp = Client.Download("/api/test/hello", DownloadDir.Path()); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), "hello world"); + } + + SUBCASE("Download with reduced MaximumInMemoryDownloadSize forces file spill") + { + HttpClientSettings Settings; + Settings.MaximumInMemoryDownloadSize = 4; + HttpClient Client = Fixture.MakeClient(Settings); + + HttpClient::Response Resp = Client.Download("/api/test/large", DownloadDir.Path()); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.ResponsePayload.GetSize(), 64u * 1024u); + } +} + +TEST_CASE("httpclient.post-streaming") +{ + TestServerFixture Fixture; + ScopedTemporaryDirectory PostDir; + + SUBCASE("POST CbObject with TempFolderPath stays in memory when response is small") + { + HttpClient Client = Fixture.MakeClient(); + + CbObjectWriter Writer; + Writer.AddBool("streaming", false); + Writer.AddString("mode", "memory"); + CbObject Obj = Writer.Save(); + + HttpClient::Response Resp = Client.Post("/api/test/echo", Obj, {}, PostDir.Path()); + CHECK(Resp.IsSuccess()); + IoBufferFileReference _; + CHECK(!Resp.ResponsePayload.GetFileReference(_)); + CbObject RoundTripped = Resp.AsObject(); + CHECK(RoundTripped["streaming"].AsBool() == false); + CHECK_EQ(RoundTripped["mode"].AsString(), "memory"); + } + + SUBCASE("POST CbObject with TempFolderPath streams to file when response exceeds MaximumInMemoryDownloadSize") + { + HttpClientSettings Settings; + Settings.MaximumInMemoryDownloadSize = 4; + HttpClient Client = Fixture.MakeClient(Settings); + + CbObjectWriter Writer; + Writer.AddBool("streaming", true); + Writer.AddString("mode", "file"); + CbObject Obj = Writer.Save(); + + HttpClient::Response Resp = Client.Post("/api/test/echo", Obj, {}, PostDir.Path()); + CHECK(Resp.IsSuccess()); + IoBufferFileReference _; + CHECK(Resp.ResponsePayload.GetFileReference(_)); + CbObject RoundTripped = Resp.AsObject(); + CHECK(RoundTripped["streaming"].AsBool() == true); + CHECK_EQ(RoundTripped["mode"].AsString(), "file"); + } +} + +TEST_CASE("httpclient.status-codes") +{ + TestServerFixture Fixture; + HttpClient Client = Fixture.MakeClient(); + + SUBCASE("2xx are success") + { + CHECK(Client.Get("/api/test/status/200").IsSuccess()); + CHECK(Client.Get("/api/test/status/201").IsSuccess()); + CHECK(Client.Get("/api/test/status/204").IsSuccess()); + } + + SUBCASE("4xx are not success") + { + CHECK(!Client.Get("/api/test/status/400").IsSuccess()); + CHECK(!Client.Get("/api/test/status/401").IsSuccess()); + CHECK(!Client.Get("/api/test/status/403").IsSuccess()); + CHECK(!Client.Get("/api/test/status/404").IsSuccess()); + CHECK(!Client.Get("/api/test/status/409").IsSuccess()); + } + + SUBCASE("5xx are not success") + { + CHECK(!Client.Get("/api/test/status/500").IsSuccess()); + CHECK(!Client.Get("/api/test/status/502").IsSuccess()); + CHECK(!Client.Get("/api/test/status/503").IsSuccess()); + } + + SUBCASE("status code values match") + { + CHECK_EQ(Client.Get("/api/test/status/200").StatusCode, HttpResponseCode::OK); + CHECK_EQ(Client.Get("/api/test/status/201").StatusCode, HttpResponseCode::Created); + CHECK_EQ(Client.Get("/api/test/status/204").StatusCode, HttpResponseCode::NoContent); + CHECK_EQ(Client.Get("/api/test/status/400").StatusCode, HttpResponseCode::BadRequest); + CHECK_EQ(Client.Get("/api/test/status/401").StatusCode, HttpResponseCode::Unauthorized); + CHECK_EQ(Client.Get("/api/test/status/403").StatusCode, HttpResponseCode::Forbidden); + CHECK_EQ(Client.Get("/api/test/status/404").StatusCode, HttpResponseCode::NotFound); + CHECK_EQ(Client.Get("/api/test/status/409").StatusCode, HttpResponseCode::Conflict); + CHECK_EQ(Client.Get("/api/test/status/500").StatusCode, HttpResponseCode::InternalServerError); + CHECK_EQ(Client.Get("/api/test/status/502").StatusCode, HttpResponseCode::BadGateway); + CHECK_EQ(Client.Get("/api/test/status/503").StatusCode, HttpResponseCode::ServiceUnavailable); + } +} + +TEST_CASE("httpclient.response") +{ + TestServerFixture Fixture; + HttpClient Client = Fixture.MakeClient(); + + SUBCASE("IsSuccess and operator bool for success") + { + HttpClient::Response Resp = Client.Get("/api/test/hello"); + CHECK(Resp.IsSuccess()); + CHECK(static_cast<bool>(Resp)); + } + + SUBCASE("IsSuccess and operator bool for failure") + { + HttpClient::Response Resp = Client.Get("/api/test/status/404"); + CHECK(!Resp.IsSuccess()); + CHECK(!static_cast<bool>(Resp)); + } + + SUBCASE("AsText returns body") + { + HttpClient::Response Resp = Client.Get("/api/test/hello"); + CHECK_EQ(Resp.AsText(), "hello world"); + } + + SUBCASE("AsText returns empty for no-content") + { + HttpClient::Response Resp = Client.Get("/api/test/nocontent"); + CHECK(Resp.AsText().empty()); + } + + SUBCASE("AsObject parses CbObject") + { + HttpClient::Response Resp = Client.Get("/api/test/json"); + CbObject Obj = Resp.AsObject(); + CHECK(Obj["ok"].AsBool() == true); + CHECK_EQ(Obj["message"].AsString(), "test"); + } + + SUBCASE("AsObject returns empty for non-CB content") + { + HttpClient::Response Resp = Client.Get("/api/test/hello"); + CbObject Obj = Resp.AsObject(); + CHECK(!Obj); + } + + SUBCASE("ToText for text content") + { + HttpClient::Response Resp = Client.Get("/api/test/content-type/text"); + CHECK_EQ(Resp.ToText(), "plain text"); + } + + SUBCASE("ToText for CbObject content") + { + HttpClient::Response Resp = Client.Get("/api/test/json"); + std::string Text = Resp.ToText(); + CHECK(!Text.empty()); + // ToText for CbObject converts to JSON string representation + CHECK(Text.find("ok") != std::string::npos); + CHECK(Text.find("test") != std::string::npos); + } + + SUBCASE("ErrorMessage includes status code on failure") + { + HttpClient::Response Resp = Client.Get("/api/test/status/404"); + std::string Msg = Resp.ErrorMessage("test-prefix"); + CHECK(Msg.find("test-prefix") != std::string::npos); + CHECK(Msg.find("404") != std::string::npos); + } + + SUBCASE("ThrowError throws on failure") + { + HttpClient::Response Resp = Client.Get("/api/test/status/500"); + CHECK_THROWS_AS(Resp.ThrowError("test"), HttpClientError); + } + + SUBCASE("ThrowError does not throw on success") + { + HttpClient::Response Resp = Client.Get("/api/test/hello"); + CHECK_NOTHROW(Resp.ThrowError("test")); + } + + SUBCASE("HttpClientError carries response code") + { + HttpClient::Response Resp = Client.Get("/api/test/status/403"); + try + { + Resp.ThrowError("test"); + CHECK(false); // should not reach + } + catch (const HttpClientError& Err) + { + CHECK_EQ(Err.GetHttpResponseCode(), HttpResponseCode::Forbidden); + } + } +} + +TEST_CASE("httpclient.error-handling") +{ + SUBCASE("Connection refused") + { + HttpClient Client("127.0.0.1:19999", HttpClientSettings{}, /*CheckIfAbortFunction*/ {}); + HttpClient::Response Resp = Client.Get("/api/test/hello"); + CHECK(!Resp.IsSuccess()); + CHECK(Resp.Error.has_value()); + } + + SUBCASE("Request timeout") + { + TestServerFixture Fixture; + HttpClientSettings Settings; + Settings.Timeout = std::chrono::milliseconds(500); + HttpClient Client = Fixture.MakeClient(Settings); + + HttpClient::Response Resp = Client.Get("/api/test/slow"); + CHECK(!Resp.IsSuccess()); + } + + SUBCASE("Nonexistent endpoint returns failure") + { + TestServerFixture Fixture; + HttpClient Client = Fixture.MakeClient(); + + HttpClient::Response Resp = Client.Get("/api/test/does-not-exist"); + CHECK(!Resp.IsSuccess()); + } +} + +TEST_CASE("httpclient.session") +{ + TestServerFixture Fixture; + + SUBCASE("Default session ID is non-empty") + { + HttpClient Client = Fixture.MakeClient(); + CHECK(!Client.GetSessionId().empty()); + } + + SUBCASE("SetSessionId changes ID") + { + HttpClient Client = Fixture.MakeClient(); + Oid NewId = Oid::NewOid(); + std::string OldId = std::string(Client.GetSessionId()); + Client.SetSessionId(NewId); + CHECK_EQ(Client.GetSessionId(), NewId.ToString()); + CHECK_NE(Client.GetSessionId(), OldId); + } + + SUBCASE("SetSessionId with Zero resets") + { + HttpClient Client = Fixture.MakeClient(); + Oid NewId = Oid::NewOid(); + Client.SetSessionId(NewId); + CHECK_EQ(Client.GetSessionId(), NewId.ToString()); + Client.SetSessionId(Oid::Zero); + // After resetting, should get a session string (not empty, not the custom one) + CHECK(!Client.GetSessionId().empty()); + CHECK_NE(Client.GetSessionId(), NewId.ToString()); + } +} + +TEST_CASE("httpclient.authentication") +{ + TestServerFixture Fixture; + + SUBCASE("Authenticate returns false without provider") + { + HttpClient Client = Fixture.MakeClient(); + CHECK(!Client.Authenticate()); + } + + SUBCASE("Authenticate returns true with valid token") + { + HttpClientSettings Settings; + Settings.AccessTokenProvider = []() -> HttpClientAccessToken { + return HttpClientAccessToken{ + .Value = "valid-token", + .ExpireTime = HttpClientAccessToken::Clock::now() + std::chrono::hours(1), + }; + }; + HttpClient Client = Fixture.MakeClient(Settings); + CHECK(Client.Authenticate()); + } + + SUBCASE("Authenticate returns false with expired token") + { + HttpClientSettings Settings; + Settings.AccessTokenProvider = []() -> HttpClientAccessToken { + return HttpClientAccessToken{ + .Value = "expired-token", + .ExpireTime = HttpClientAccessToken::Clock::now() - std::chrono::hours(1), + }; + }; + HttpClient Client = Fixture.MakeClient(Settings); + CHECK(!Client.Authenticate()); + } + + SUBCASE("Bearer token verified by auth endpoint") + { + HttpClient Client = Fixture.MakeClient(); + + HttpClient::Response AuthResp = + Client.Get("/api/test/auth/bearer", std::pair<std::string, std::string>("Authorization", "Bearer my-secret-token")); + CHECK(AuthResp.IsSuccess()); + CHECK_EQ(AuthResp.AsText(), "authenticated"); + } + + SUBCASE("Request without token to auth endpoint gets 401") + { + HttpClient Client = Fixture.MakeClient(); + + HttpClient::Response Resp = Client.Get("/api/test/auth/bearer"); + CHECK(!Resp.IsSuccess()); + CHECK_EQ(Resp.StatusCode, HttpResponseCode::Unauthorized); + } +} + +TEST_CASE("httpclient.content-types") +{ + TestServerFixture Fixture; + HttpClient Client = Fixture.MakeClient(); + + SUBCASE("text content type") + { + HttpClient::Response Resp = Client.Get("/api/test/content-type/text"); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.ResponsePayload.GetContentType(), ZenContentType::kText); + } + + SUBCASE("JSON content type") + { + HttpClient::Response Resp = Client.Get("/api/test/content-type/json"); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.ResponsePayload.GetContentType(), ZenContentType::kJSON); + } + + SUBCASE("binary content type") + { + HttpClient::Response Resp = Client.Get("/api/test/content-type/binary"); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.ResponsePayload.GetContentType(), ZenContentType::kBinary); + } + + SUBCASE("CbObject content type") + { + HttpClient::Response Resp = Client.Get("/api/test/content-type/cbobject"); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.ResponsePayload.GetContentType(), ZenContentType::kCbObject); + } +} + +TEST_CASE("httpclient.metadata") +{ + TestServerFixture Fixture; + HttpClient Client = Fixture.MakeClient(); + + SUBCASE("ElapsedSeconds is positive") + { + HttpClient::Response Resp = Client.Get("/api/test/hello"); + CHECK(Resp.IsSuccess()); + CHECK(Resp.ElapsedSeconds > 0.0); + } + + SUBCASE("DownloadedBytes populated for GET") + { + HttpClient::Response Resp = Client.Get("/api/test/hello"); + CHECK(Resp.IsSuccess()); + CHECK(Resp.DownloadedBytes > 0); + } + + SUBCASE("UploadedBytes populated for POST with payload") + { + const char* Payload = "some upload data"; + IoBuffer Buf(IoBuffer::Clone, Payload, strlen(Payload)); + Buf.SetContentType(ZenContentType::kText); + + HttpClient::Response Resp = Client.Post("/api/test/echo", Buf); + CHECK(Resp.IsSuccess()); + CHECK(Resp.UploadedBytes > 0); + } +} + +TEST_CASE("httpclient.retry") +{ + TestServerFixture Fixture; + + SUBCASE("Retry succeeds after transient failures") + { + Fixture.TestService.ResetAttemptCounter(2); + + HttpClientSettings Settings; + Settings.RetryCount = 3; + HttpClient Client = Fixture.MakeClient(Settings); + + HttpClient::Response Resp = Client.Get("/api/test/attempt-counter"); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), "success after retries"); + } + + SUBCASE("No retry returns 503 immediately") + { + Fixture.TestService.ResetAttemptCounter(2); + + HttpClientSettings Settings; + Settings.RetryCount = 0; + HttpClient Client = Fixture.MakeClient(Settings); + + HttpClient::Response Resp = Client.Get("/api/test/attempt-counter"); + CHECK(!Resp.IsSuccess()); + CHECK_EQ(Resp.StatusCode, HttpResponseCode::ServiceUnavailable); + } +} + +TEST_CASE("httpclient.measurelatency") +{ + SUBCASE("Successful measurement against live server") + { + TestServerFixture Fixture; + HttpClient Client = Fixture.MakeClient(); + + LatencyTestResult Result = MeasureLatency(Client, "/api/test/hello"); + CHECK(Result.Success); + CHECK(Result.LatencySeconds > 0.0); + } + + SUBCASE("Failed measurement against unreachable port") + { + HttpClient Client("127.0.0.1:19999", HttpClientSettings{}, /*CheckIfAbortFunction*/ {}); + LatencyTestResult Result = MeasureLatency(Client, "/api/test/hello"); + CHECK(!Result.Success); + CHECK(!Result.FailureReason.empty()); + } +} + +TEST_CASE("httpclient.keyvaluemap") +{ + SUBCASE("Default construction is empty") + { + HttpClient::KeyValueMap Map; + CHECK(Map->empty()); + } + + SUBCASE("Construction from pair") + { + HttpClient::KeyValueMap Map(std::pair<std::string, std::string>("key", "value")); + CHECK_EQ(Map->size(), 1u); + CHECK_EQ(Map->at("key"), "value"); + } + + SUBCASE("Construction from string_view pair") + { + HttpClient::KeyValueMap Map(std::pair<std::string_view, std::string_view>("key"sv, "value"sv)); + CHECK_EQ(Map->size(), 1u); + CHECK_EQ(Map->at("key"), "value"); + } + + SUBCASE("Construction from initializer list") + { + HttpClient::KeyValueMap Map({{"a"sv, "1"sv}, {"b"sv, "2"sv}}); + CHECK_EQ(Map->size(), 2u); + CHECK_EQ(Map->at("a"), "1"); + CHECK_EQ(Map->at("b"), "2"); + } +} + +////////////////////////////////////////////////////////////////////////// +// Transport fault testing + +static std::string +MakeRawHttpResponse(int StatusCode, std::string_view Body) +{ + return fmt::format( + "HTTP/1.1 {} OK\r\n" + "Content-Type: text/plain\r\n" + "Content-Length: {}\r\n" + "\r\n" + "{}", + StatusCode, + Body.size(), + Body); +} + +static std::string +MakeRawHttpHeaders(int StatusCode, size_t ContentLength) +{ + return fmt::format( + "HTTP/1.1 {} OK\r\n" + "Content-Type: application/octet-stream\r\n" + "Content-Length: {}\r\n" + "\r\n", + StatusCode, + ContentLength); +} + +static void +DrainHttpRequest(asio::ip::tcp::socket& Socket) +{ + asio::streambuf Buf; + std::error_code Ec; + asio::read_until(Socket, Buf, "\r\n\r\n", Ec); +} + +static void +DrainFullHttpRequest(asio::ip::tcp::socket& Socket) +{ + // Read until end of headers + asio::streambuf Buf; + std::error_code Ec; + asio::read_until(Socket, Buf, "\r\n\r\n", Ec); + if (Ec) + { + return; + } + + // Extract headers to find Content-Length + std::string Headers(asio::buffers_begin(Buf.data()), asio::buffers_end(Buf.data())); + + size_t ContentLength = 0; + auto Pos = Headers.find("Content-Length: "); + if (Pos == std::string::npos) + { + Pos = Headers.find("content-length: "); + } + if (Pos != std::string::npos) + { + size_t ValStart = Pos + 16; // length of "Content-Length: " + size_t ValEnd = Headers.find("\r\n", ValStart); + if (ValEnd != std::string::npos) + { + ContentLength = std::stoull(Headers.substr(ValStart, ValEnd - ValStart)); + } + } + + // Calculate how many body bytes were already read past the header boundary. + // asio::read_until may read past the delimiter, so Buf.data() contains everything read. + size_t HeaderEnd = Headers.find("\r\n\r\n") + 4; + size_t BodyBytesInBuf = Headers.size() > HeaderEnd ? Headers.size() - HeaderEnd : 0; + size_t Remaining = ContentLength > BodyBytesInBuf ? ContentLength - BodyBytesInBuf : 0; + + if (Remaining > 0) + { + std::vector<char> BodyBuf(Remaining); + asio::read(Socket, asio::buffer(BodyBuf), Ec); + } +} + +static void +DrainPartialBody(asio::ip::tcp::socket& Socket, size_t BytesToRead) +{ + // Read headers first + asio::streambuf Buf; + std::error_code Ec; + asio::read_until(Socket, Buf, "\r\n\r\n", Ec); + if (Ec) + { + return; + } + + // Determine how many body bytes were already buffered past headers + std::string All(asio::buffers_begin(Buf.data()), asio::buffers_end(Buf.data())); + size_t HeaderEnd = All.find("\r\n\r\n") + 4; + size_t BodyBytesInBuf = All.size() > HeaderEnd ? All.size() - HeaderEnd : 0; + + if (BodyBytesInBuf < BytesToRead) + { + size_t Remaining = BytesToRead - BodyBytesInBuf; + std::vector<char> BodyBuf(Remaining); + asio::read(Socket, asio::buffer(BodyBuf), Ec); + } +} + +struct FaultTcpServer +{ + using FaultHandler = std::function<void(asio::ip::tcp::socket&)>; + + asio::io_context m_IoContext; + asio::ip::tcp::acceptor m_Acceptor; + FaultHandler m_Handler; + std::thread m_Thread; + int m_Port; + + explicit FaultTcpServer(FaultHandler Handler) + : m_Acceptor(m_IoContext, asio::ip::tcp::endpoint(asio::ip::address_v4::loopback(), 0)) + , m_Handler(std::move(Handler)) + { + m_Port = m_Acceptor.local_endpoint().port(); + StartAccept(); + m_Thread = std::thread([this]() { + try + { + m_IoContext.run(); + } + catch (...) + { + } + }); + } + + ~FaultTcpServer() + { + // io_context::stop() is thread-safe; do NOT call m_Acceptor.close() from this + // thread — ASIO I/O objects are not safe for concurrent access and the io_context + // thread may be touching the acceptor in StartAccept(). + m_IoContext.stop(); + if (m_Thread.joinable()) + { + m_Thread.join(); + } + } + + FaultTcpServer(const FaultTcpServer&) = delete; + FaultTcpServer& operator=(const FaultTcpServer&) = delete; + + void StartAccept() + { + m_Acceptor.async_accept([this](std::error_code Ec, asio::ip::tcp::socket Socket) { + if (!Ec) + { + m_Handler(Socket); + } + if (m_Acceptor.is_open()) + { + StartAccept(); + } + }); + } + + HttpClient MakeClient(HttpClientSettings Settings = {}) + { + return HttpClient(fmt::format("127.0.0.1:{}", m_Port), Settings, /*CheckIfAbortFunction*/ {}); + } +}; + +TEST_CASE("httpclient.range-response") +{ + ScopedTemporaryDirectory DownloadDir; + + SUBCASE("single range 206 response populates Ranges") + { + std::string RangeBody(100, 'A'); + + FaultTcpServer Server([&](asio::ip::tcp::socket& Socket) { + DrainHttpRequest(Socket); + std::string Response = fmt::format( + "HTTP/1.1 206 Partial Content\r\n" + "Content-Type: application/octet-stream\r\n" + "Content-Range: bytes 200-299/1000\r\n" + "Content-Length: {}\r\n" + "\r\n" + "{}", + RangeBody.size(), + RangeBody); + std::error_code Ec; + asio::write(Socket, asio::buffer(Response), Ec); + }); + + HttpClient Client = Server.MakeClient(); + HttpClient::Response Resp = Client.Download("/test", DownloadDir.Path()); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.StatusCode, HttpResponseCode::PartialContent); + REQUIRE(Resp.Ranges.size() == 1); + CHECK_EQ(Resp.Ranges[0].RangeOffset, 200); + CHECK_EQ(Resp.Ranges[0].RangeLength, 100); + } + + SUBCASE("multipart byteranges 206 response populates Ranges") + { + std::string Part1Data(16, 'X'); + std::string Part2Data(12, 'Y'); + std::string Boundary = "testboundary123"; + + std::string MultipartBody = fmt::format( + "\r\n--{}\r\n" + "Content-Type: application/octet-stream\r\n" + "Content-Range: bytes 100-115/1000\r\n" + "\r\n" + "{}" + "\r\n--{}\r\n" + "Content-Type: application/octet-stream\r\n" + "Content-Range: bytes 500-511/1000\r\n" + "\r\n" + "{}" + "\r\n--{}--", + Boundary, + Part1Data, + Boundary, + Part2Data, + Boundary); + + FaultTcpServer Server([&](asio::ip::tcp::socket& Socket) { + DrainHttpRequest(Socket); + std::string Response = fmt::format( + "HTTP/1.1 206 Partial Content\r\n" + "Content-Type: multipart/byteranges; boundary={}\r\n" + "Content-Length: {}\r\n" + "\r\n" + "{}", + Boundary, + MultipartBody.size(), + MultipartBody); + std::error_code Ec; + asio::write(Socket, asio::buffer(Response), Ec); + }); + + HttpClient Client = Server.MakeClient(); + HttpClient::Response Resp = Client.Download("/test", DownloadDir.Path()); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.StatusCode, HttpResponseCode::PartialContent); + REQUIRE(Resp.Ranges.size() == 2); + // Ranges should be sorted by RangeOffset + CHECK_EQ(Resp.Ranges[0].RangeOffset, 100); + CHECK_EQ(Resp.Ranges[0].RangeLength, 16); + CHECK_EQ(Resp.Ranges[1].RangeOffset, 500); + CHECK_EQ(Resp.Ranges[1].RangeLength, 12); + } + + SUBCASE("non-range 200 response has empty Ranges") + { + FaultTcpServer Server([&](asio::ip::tcp::socket& Socket) { + DrainHttpRequest(Socket); + std::string Response = MakeRawHttpResponse(200, "full content"); + std::error_code Ec; + asio::write(Socket, asio::buffer(Response), Ec); + }); + + HttpClient Client = Server.MakeClient(); + HttpClient::Response Resp = Client.Download("/test", DownloadDir.Path()); + CHECK(Resp.IsSuccess()); + CHECK(Resp.Ranges.empty()); + } +} + +TEST_CASE("httpclient.transport-faults" * doctest::skip()) +{ + SUBCASE("connection reset before response") + { + FaultTcpServer Server([](asio::ip::tcp::socket& Socket) { + DrainHttpRequest(Socket); + std::error_code Ec; + Socket.set_option(asio::socket_base::linger(true, 0), Ec); + Socket.close(Ec); + }); + HttpClient Client = Server.MakeClient(); + HttpClient::Response Resp = Client.Get("/test"); + CHECK(!Resp.IsSuccess()); + CHECK(Resp.Error.has_value()); + } + + SUBCASE("connection closed before response") + { + FaultTcpServer Server([](asio::ip::tcp::socket& Socket) { + DrainHttpRequest(Socket); + std::error_code Ec; + Socket.shutdown(asio::ip::tcp::socket::shutdown_both, Ec); + Socket.close(Ec); + }); + HttpClient Client = Server.MakeClient(); + HttpClient::Response Resp = Client.Get("/test"); + CHECK(!Resp.IsSuccess()); + CHECK(Resp.Error.has_value()); + } + + SUBCASE("partial headers then close") + { + // libcurl parses the status line (200 OK) and accepts the response even though + // headers are truncated mid-field. It reports success with an empty body instead + // of an error. Ideally this should be detected as a transport failure. + FaultTcpServer Server([](asio::ip::tcp::socket& Socket) { + DrainHttpRequest(Socket); + std::string Partial = "HTTP/1.1 200 OK\r\nContent-"; + std::error_code Ec; + asio::write(Socket, asio::buffer(Partial), Ec); + Socket.close(Ec); + }); + HttpClient Client = Server.MakeClient(); + HttpClient::Response Resp = Client.Get("/test"); + WARN(!Resp.IsSuccess()); + WARN(Resp.Error.has_value()); + } + + SUBCASE("truncated body") + { + FaultTcpServer Server([](asio::ip::tcp::socket& Socket) { + DrainHttpRequest(Socket); + std::string Headers = MakeRawHttpHeaders(200, 1000); + std::error_code Ec; + asio::write(Socket, asio::buffer(Headers), Ec); + std::string PartialBody(100, 'x'); + asio::write(Socket, asio::buffer(PartialBody), Ec); + Socket.close(Ec); + }); + HttpClient Client = Server.MakeClient(); + HttpClient::Response Resp = Client.Get("/test"); + CHECK(!Resp.IsSuccess()); + CHECK(Resp.Error.has_value()); + } + + SUBCASE("connection reset mid-body") + { + FaultTcpServer Server([](asio::ip::tcp::socket& Socket) { + DrainHttpRequest(Socket); + std::string Headers = MakeRawHttpHeaders(200, 10000); + std::error_code Ec; + asio::write(Socket, asio::buffer(Headers), Ec); + std::string PartialBody(1000, 'x'); + asio::write(Socket, asio::buffer(PartialBody), Ec); + Socket.set_option(asio::socket_base::linger(true, 0), Ec); + Socket.close(Ec); + }); + HttpClient Client = Server.MakeClient(); + HttpClient::Response Resp = Client.Get("/test"); + CHECK(!Resp.IsSuccess()); + CHECK(Resp.Error.has_value()); + } + + SUBCASE("stalled response triggers timeout") + { + std::atomic<bool> StallActive{true}; + FaultTcpServer Server([&StallActive](asio::ip::tcp::socket& Socket) { + DrainHttpRequest(Socket); + std::string Headers = MakeRawHttpHeaders(200, 1000); + std::error_code Ec; + asio::write(Socket, asio::buffer(Headers), Ec); + while (StallActive.load()) + { + std::this_thread::sleep_for(std::chrono::milliseconds(50)); + } + }); + + HttpClientSettings Settings; + Settings.Timeout = std::chrono::milliseconds(500); + HttpClient Client = Server.MakeClient(Settings); + + HttpClient::Response Resp = Client.Get("/test"); + CHECK(!Resp.IsSuccess()); + CHECK(Resp.Error.has_value()); + StallActive.store(false); + } + + SUBCASE("retry succeeds after transient failures") + { + std::atomic<int> ConnCount{0}; + FaultTcpServer Server([&ConnCount](asio::ip::tcp::socket& Socket) { + int N = ConnCount.fetch_add(1); + DrainHttpRequest(Socket); + if (N < 2) + { + // Connection reset produces NETWORK_SEND_FAILURE which is retryable + std::error_code Ec; + Socket.set_option(asio::socket_base::linger(true, 0), Ec); + Socket.close(Ec); + } + else + { + std::string Response = MakeRawHttpResponse(200, "recovered"); + std::error_code Ec; + asio::write(Socket, asio::buffer(Response), Ec); + } + }); + + HttpClientSettings Settings; + Settings.RetryCount = 3; + HttpClient Client = Server.MakeClient(Settings); + + HttpClient::Response Resp = Client.Get("/test"); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), "recovered"); + } +} + +TEST_CASE("httpclient.transport-faults-post" * doctest::skip()) +{ + constexpr size_t kPostBodySize = 256 * 1024; + + auto MakePostBody = []() -> IoBuffer { + IoBuffer Buf(kPostBodySize); + uint8_t* Ptr = static_cast<uint8_t*>(Buf.MutableData()); + for (size_t i = 0; i < kPostBodySize; ++i) + { + Ptr[i] = static_cast<uint8_t>(i & 0xFF); + } + Buf.SetContentType(ZenContentType::kBinary); + return Buf; + }; + + SUBCASE("POST: server resets before consuming body") + { + FaultTcpServer Server([](asio::ip::tcp::socket& Socket) { + DrainHttpRequest(Socket); + std::error_code Ec; + Socket.set_option(asio::socket_base::linger(true, 0), Ec); + Socket.close(Ec); + }); + HttpClient Client = Server.MakeClient(); + IoBuffer Body = MakePostBody(); + HttpClient::Response Resp = Client.Post("/test", Body); + CHECK(!Resp.IsSuccess()); + CHECK(Resp.Error.has_value()); + } + + SUBCASE("POST: server closes before consuming body") + { + FaultTcpServer Server([](asio::ip::tcp::socket& Socket) { + DrainHttpRequest(Socket); + std::error_code Ec; + Socket.shutdown(asio::ip::tcp::socket::shutdown_both, Ec); + Socket.close(Ec); + }); + HttpClient Client = Server.MakeClient(); + IoBuffer Body = MakePostBody(); + HttpClient::Response Resp = Client.Post("/test", Body); + CHECK(!Resp.IsSuccess()); + CHECK(Resp.Error.has_value()); + } + + SUBCASE("POST: server resets mid-body") + { + FaultTcpServer Server([](asio::ip::tcp::socket& Socket) { + DrainPartialBody(Socket, 8 * 1024); + std::error_code Ec; + Socket.set_option(asio::socket_base::linger(true, 0), Ec); + Socket.close(Ec); + }); + HttpClient Client = Server.MakeClient(); + IoBuffer Body = MakePostBody(); + HttpClient::Response Resp = Client.Post("/test", Body); + CHECK(!Resp.IsSuccess()); + CHECK(Resp.Error.has_value()); + } + + SUBCASE("POST: early error response before consuming body") + { + FaultTcpServer Server([](asio::ip::tcp::socket& Socket) { + DrainHttpRequest(Socket); + std::string Response = MakeRawHttpResponse(503, "service busy"); + std::error_code Ec; + asio::write(Socket, asio::buffer(Response), Ec); + Socket.shutdown(asio::ip::tcp::socket::shutdown_both, Ec); + Socket.close(Ec); + }); + HttpClient Client = Server.MakeClient(); + IoBuffer Body = MakePostBody(); + HttpClient::Response Resp = Client.Post("/test", Body); + CHECK(!Resp.IsSuccess()); + // With a large upload body, the server may RST the connection before the client + // reads the 503 response. Either outcome is valid: the client sees the HTTP 503 + // status, or it sees a transport-level error from the RST. + CHECK((Resp.StatusCode == HttpResponseCode::ServiceUnavailable || Resp.Error.has_value())); + } + + SUBCASE("POST: stalled upload triggers timeout") + { + std::atomic<bool> StallActive{true}; + FaultTcpServer Server([&StallActive](asio::ip::tcp::socket& Socket) { + DrainHttpRequest(Socket); + // Stop reading body — TCP window will fill and client send will stall + while (StallActive.load()) + { + std::this_thread::sleep_for(std::chrono::milliseconds(50)); + } + }); + + HttpClientSettings Settings; + Settings.Timeout = std::chrono::milliseconds(2000); + HttpClient Client = Server.MakeClient(Settings); + + IoBuffer Body = MakePostBody(); + HttpClient::Response Resp = Client.Post("/test", Body); + CHECK(!Resp.IsSuccess()); + CHECK(Resp.Error.has_value()); + StallActive.store(false); + } + + SUBCASE("POST: retry with large body after transient failure") + { + std::atomic<int> ConnCount{0}; + FaultTcpServer Server([&ConnCount](asio::ip::tcp::socket& Socket) { + int N = ConnCount.fetch_add(1); + if (N < 2) + { + DrainHttpRequest(Socket); + std::error_code Ec; + Socket.set_option(asio::socket_base::linger(true, 0), Ec); + Socket.close(Ec); + } + else + { + DrainFullHttpRequest(Socket); + std::string Response = MakeRawHttpResponse(200, "upload-ok"); + std::error_code Ec; + asio::write(Socket, asio::buffer(Response), Ec); + } + }); + + HttpClientSettings Settings; + Settings.RetryCount = 3; + HttpClient Client = Server.MakeClient(Settings); + + IoBuffer Body = MakePostBody(); + HttpClient::Response Resp = Client.Post("/test", Body); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), "upload-ok"); + } +} + +TEST_CASE("httpclient.unixsocket") +{ + ScopedTemporaryDirectory TmpDir; + std::string SocketPath = (TmpDir.Path() / "zen.sock").string(); + + HttpClientTestService TestService; + + Ref<HttpServer> Server = CreateHttpAsioServer(AsioConfig{.UnixSocketPath = SocketPath}); + + int Port = Server->Initialize(0, TmpDir.Path()); + REQUIRE(Port != -1); + + Server->RegisterService(TestService); + + std::thread ServerThread([&]() { Server->Run(false); }); + + auto _ = MakeGuard([&]() { + Server->RequestExit(); + if (ServerThread.joinable()) + { + ServerThread.join(); + } + Server->Close(); + }); + + HttpClientSettings Settings; + Settings.UnixSocketPath = SocketPath; + + HttpClient Client("localhost", Settings, /*CheckIfAbortFunction*/ {}); + + SUBCASE("GET over unix socket") + { + HttpClient::Response Resp = Client.Get("/api/test/hello"); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), "hello world"); + } + + SUBCASE("POST echo over unix socket") + { + const char* Payload = "unix socket payload"; + IoBuffer Body(IoBuffer::Clone, Payload, strlen(Payload)); + Body.SetContentType(ZenContentType::kText); + + HttpClient::Response Resp = Client.Post("/api/test/echo", Body); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), "unix socket payload"); + } +} + +# if ZEN_USE_OPENSSL + +TEST_CASE("httpclient.https") +{ + // Self-signed test certificate for localhost/127.0.0.1, valid until 2036 + static constexpr std::string_view TestCertPem = + "-----BEGIN CERTIFICATE-----\n" + "MIIDJTCCAg2gAwIBAgIUEtJYMSUmJmvJ157We/qXNVJ7W8gwDQYJKoZIhvcNAQEL\n" + "BQAwFDESMBAGA1UEAwwJbG9jYWxob3N0MB4XDTI2MDMwOTIwMjU1M1oXDTM2MDMw\n" + "NjIwMjU1M1owFDESMBAGA1UEAwwJbG9jYWxob3N0MIIBIjANBgkqhkiG9w0BAQEF\n" + "AAOCAQ8AMIIBCgKCAQEAv9YvZ6WeBz3z/Zuxi6OIivWksDxDZZ5oAXKVwlUXaa7v\n" + "iDkm9P5ZsEhN+M5vZMe2Yb9i3cnTUaE6Avs1ddOwTAYNGrE/B5DmibrRWc23R0cv\n" + "gdnYQJ+gjsAeMvUWYLK58xW4YoMR5bmfpj1ruqobUNkG/oJYnAUcjgo4J149irW+\n" + "4n9uLJvxL+5fI/b/AIkv+4TMe70/d/BPmnixWrrzxUT6S5ghE2Mq7+XLScfpY2Sp\n" + "GQ/Xbnj9/ELYLpQnNLuVZwWZDpXj+FLbF1zxgjYdw1cCjbRcOIEW2/GJeJvGXQ6Y\n" + "Vld5pCBm9uKPPLWoFCoakK5YvP00h+8X+HghGVSscQIDAQABo28wbTAdBgNVHQ4E\n" + "FgQUgM6hjymi6g2EBUg2ENu0nIK8yhMwHwYDVR0jBBgwFoAUgM6hjymi6g2EBUg2\n" + "ENu0nIK8yhMwDwYDVR0TAQH/BAUwAwEB/zAaBgNVHREEEzARhwR/AAABgglsb2Nh\n" + "bGhvc3QwDQYJKoZIhvcNAQELBQADggEBABY1oaaWwL4RaK/epKvk/IrmVT2mlAai\n" + "uvGLfjhc6FGvXaxPGTSUPrVbFornaWZAg7bOWCexWnEm2sWd75V/usvZAPN4aIiD\n" + "H66YQipq3OD4F9Gowp01IU4AcGh7MerFpYPk76+wp2ANq71x8axtlZjVn3hSFMmN\n" + "i6m9S/eyCl9WjYBT5ZEC4fJV0nOSmNe/+gCAm11/js9zNfXKmUchJtuZpubY3A0k\n" + "X2II6qYWf1PH+JJkefNZtt2c66CrEN5eAg4/rGEgsp43zcd4ZHVkpBKFLDEls1ev\n" + "drQ45zc4Ht77pHfnHu7YsLcRZ9Wq3COMNZYx5lItqnomX2qBm1pkwjI=\n" + "-----END CERTIFICATE-----\n"; + + static constexpr std::string_view TestKeyPem = + "-----BEGIN PRIVATE KEY-----\n" + "MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQC/1i9npZ4HPfP9\n" + "m7GLo4iK9aSwPENlnmgBcpXCVRdpru+IOSb0/lmwSE34zm9kx7Zhv2LdydNRoToC\n" + "+zV107BMBg0asT8HkOaJutFZzbdHRy+B2dhAn6COwB4y9RZgsrnzFbhigxHluZ+m\n" + "PWu6qhtQ2Qb+glicBRyOCjgnXj2Ktb7if24sm/Ev7l8j9v8AiS/7hMx7vT938E+a\n" + "eLFauvPFRPpLmCETYyrv5ctJx+ljZKkZD9dueP38QtgulCc0u5VnBZkOleP4UtsX\n" + "XPGCNh3DVwKNtFw4gRbb8Yl4m8ZdDphWV3mkIGb24o88tagUKhqQrli8/TSH7xf4\n" + "eCEZVKxxAgMBAAECggEAILd9pDaZqfCF8SWhdQgx3Ekiii/s6qLGaCDLq7XpZUvB\n" + "bEEbBMNwNmFOcvV6B/0LfMYwLVUjZhOSGjoPlwXAVmbdy0SZVEgBGVI0LBWqgUyB\n" + "rKqjd/oBXvci71vfMiSpE+0LYjmqTryGnspw2gfy2qn4yGUgiZNRmGPjycsHweUL\n" + "V3FHm3cf0dyE4sJ0mjVqZzRT/unw2QOCE6FlY7M1XxZL88IWfn6G4lckdJTwoOP5\n" + "VPR2J3XbyhvCeXeDRCHKRXojWWR2HovWnDXQc95GRgCd0vYdHuIUM6RXVPZQvy3X\n" + "l0GwQKHNcVr1uwtYDgGKw0tNCUDvxdfQaWilTFuicQKBgQDvEYp+vL1hnF+AVdu3\n" + "elsYsHpFgExkTI8wnUMvGZrFiIQyCyVDU3jkG3kcKacI1bfwopXopaQCjrYk9epm\n" + "liOVm3/Xtr6e2ENa7w8TQbdK65PciQNOMxml6g8clRRBl0cwj+aI3nW/Kop1cdrR\n" + "A9Vo+8iPTO5gDcxTiIb45a6E3QKBgQDNbE009P6ewx9PU7Llkhb9VBgsb7oQN3EV\n" + "TCYd4taiN6FPnTuL/cdijAA8y04hiVT+Efo9TUN9NCl9HdHXQcjj7/n/eFLH0Pkw\n" + "OIK3QN49OfR88wivLMtwWxIog0tJjc9+7dR4bR4o1jTlIrasEIvUTuDJQ8MKGc9v\n" + "pBITua+SpQKBgE4raSKZqj7hd6Sp7kbnHiRLiB9znQbqtaNKuK4M7DuMsNUAKfYC\n" + "tDO5+/bGc9SCtTtcnjHM/3zKlyossrFKhGYlyz6IhXnA8v0nz8EXKsy3jMh+kHMg\n" + "aFGE394TrOTphyCM3O+B9fRE/7L5QHg5ja1fLqwUlpkXyejCaoe16kONAoGAYIz9\n" + "wN1B67cEOVG6rOI8QfdLoV8mEcctNHhlFfjvLrF89SGOwl6WX0A0QF7CK0sUEpK6\n" + "jiOJjAh/U5o3bbgyxsedNjEEn3weE0cMUTuA+UALJMtKEqO4PuffIgGL2ld35k28\n" + "ZpnK6iC8HdJyD297eV9VkeNygYXeFLgF8xV8ay0CgYEAh4fmVZt9YhgVByYny2kF\n" + "ZUIkGF5h9wxzVOPpQwpizIGFFb3i/ZdGQcuLTfIBVRKf50sT3IwJe65ATv6+Lz0f\n" + "wg/pMvosi0/F5KGbVRVdzBMQy58WyyGti4tNl+8EXGvo8+DCmjlTYwfjRoZGg/qJ\n" + "EMP3/hTN7dHDRxPK8E0Fh0Y=\n" + "-----END PRIVATE KEY-----\n"; + + ScopedTemporaryDirectory TmpDir; + + // Write cert and key to temp files + const auto CertPath = TmpDir.Path() / "test.crt"; + const auto KeyPath = TmpDir.Path() / "test.key"; + WriteFile(CertPath, IoBuffer(IoBuffer::Clone, TestCertPem.data(), TestCertPem.size())); + WriteFile(KeyPath, IoBuffer(IoBuffer::Clone, TestKeyPem.data(), TestKeyPem.size())); + + HttpClientTestService TestService; + + AsioConfig Config; + Config.CertFile = CertPath.string(); + Config.KeyFile = KeyPath.string(); + + Ref<HttpServer> Server = CreateHttpAsioServer(Config); + + int Port = Server->Initialize(0, TmpDir.Path()); + REQUIRE(Port != -1); + + Server->RegisterService(TestService); + + std::thread ServerThread([&]() { Server->Run(false); }); + + auto _ = MakeGuard([&]() { + Server->RequestExit(); + if (ServerThread.joinable()) + { + ServerThread.join(); + } + Server->Close(); + }); + + int HttpsPort = Server->GetEffectiveHttpsPort(); + REQUIRE(HttpsPort > 0); + + HttpClientSettings Settings; + Settings.InsecureSsl = true; + + HttpClient Client(fmt::format("https://127.0.0.1:{}", HttpsPort), Settings, /*CheckIfAbortFunction*/ {}); + + SUBCASE("GET over HTTPS") + { + HttpClient::Response Resp = Client.Get("/api/test/hello"); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), "hello world"); + } + + SUBCASE("POST echo over HTTPS") + { + const char* Payload = "https payload"; + IoBuffer Body(IoBuffer::Clone, Payload, strlen(Payload)); + Body.SetContentType(ZenContentType::kText); + + HttpClient::Response Resp = Client.Post("/api/test/echo", Body); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), "https payload"); + } + + SUBCASE("GET JSON over HTTPS") + { + HttpClient::Response Resp = Client.Get("/api/test/json"); + CHECK(Resp.IsSuccess()); + CbObject Obj = Resp.AsObject(); + CHECK_EQ(Obj["ok"].AsBool(), true); + CHECK_EQ(Obj["message"].AsString(), "test"); + } + + SUBCASE("Large payload over HTTPS") + { + HttpClient::Response Resp = Client.Get("/api/test/large"); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.ResponsePayload.GetSize(), 64u * 1024u); + } +} + +# endif // ZEN_USE_OPENSSL + +TEST_SUITE_END(); + +void +httpclient_test_forcelink() +{ +} + +} // namespace zen + +#endif diff --git a/src/zenhttp/httpserver.cpp b/src/zenhttp/httpserver.cpp index 761665c30..69000dd8e 100644 --- a/src/zenhttp/httpserver.cpp +++ b/src/zenhttp/httpserver.cpp @@ -2,6 +2,8 @@ #include <zenhttp/httpserver.h> +#include <zencore/filesystem.h> + #include "servers/httpasio.h" #include "servers/httpmulti.h" #include "servers/httpnull.h" @@ -23,10 +25,12 @@ #include <zencore/logging.h> #include <zencore/stream.h> #include <zencore/string.h> +#include <zencore/system.h> #include <zencore/testing.h> #include <zencore/thread.h> #include <zenhttp/packageformat.h> #include <zentelemetry/otlptrace.h> +#include <zentelemetry/stats.h> #include <charconv> #include <mutex> @@ -745,6 +749,10 @@ HttpRequestRouter::RegisterRoute(const char* UriPattern, HttpRequestRouter::Hand { if (UriPattern[i] == '}') { + if (i == PatternStart) + { + throw std::runtime_error(fmt::format("matcher pattern is empty in URI pattern '{}'", UriPattern)); + } std::string_view Pattern(&UriPattern[PatternStart], i - PatternStart); if (auto it = m_MatcherNameMap.find(std::string(Pattern)); it != m_MatcherNameMap.end()) { @@ -910,8 +918,9 @@ HttpRequestRouter::HandleRequest(zen::HttpServerRequest& Request) CapturedSegments.emplace_back(Uri); - for (int MatcherIndex : Matchers) + for (size_t MatcherOffset = 0; MatcherOffset < Matchers.size(); MatcherOffset++) { + int MatcherIndex = Matchers[MatcherOffset]; if (UriPos >= UriLen) { IsMatch = false; @@ -921,9 +930,9 @@ HttpRequestRouter::HandleRequest(zen::HttpServerRequest& Request) if (MatcherIndex < 0) { // Literal match - int LitIndex = -MatcherIndex - 1; - const std::string& LitStr = m_Literals[LitIndex]; - size_t LitLen = LitStr.length(); + int LitIndex = -MatcherIndex - 1; + std::string_view LitStr = m_Literals[LitIndex]; + size_t LitLen = LitStr.length(); if (Uri.substr(UriPos, LitLen) == LitStr) { @@ -939,9 +948,18 @@ HttpRequestRouter::HandleRequest(zen::HttpServerRequest& Request) { // Matcher function size_t SegmentStart = UriPos; - while (UriPos < UriLen && Uri[UriPos] != '/') + + if (MatcherOffset == (Matchers.size() - 1)) { - ++UriPos; + // Last matcher, use the remaining part of the uri + UriPos = UriLen; + } + else + { + while (UriPos < UriLen && Uri[UriPos] != '/') + { + ++UriPos; + } } std::string_view Segment = Uri.substr(SegmentStart, UriPos - SegmentStart); @@ -1014,7 +1032,31 @@ HttpRequestRouter::HandleRequest(zen::HttpServerRequest& Request) int HttpServer::Initialize(int BasePort, std::filesystem::path DataDir) { - return OnInitialize(BasePort, std::move(DataDir)); + m_EffectivePort = OnInitialize(BasePort, std::move(DataDir)); + m_ExternalHost = OnGetExternalHost(); + return m_EffectivePort; +} + +std::string +HttpServer::OnGetExternalHost() const +{ + return GetMachineName(); +} + +std::string +HttpServer::GetServiceUri(const HttpService* Service) const +{ + const char* Scheme = (m_EffectiveHttpsPort > 0) ? "https" : "http"; + int Port = (m_EffectiveHttpsPort > 0) ? m_EffectiveHttpsPort : m_EffectivePort; + + if (Service) + { + return fmt::format("{}://{}:{}{}", Scheme, m_ExternalHost, Port, Service->BaseUri()); + } + else + { + return fmt::format("{}://{}:{}", Scheme, m_ExternalHost, Port); + } } void @@ -1058,6 +1100,39 @@ HttpServer::SetHttpRequestFilter(IHttpRequestFilter* RequestFilter) OnSetHttpRequestFilter(RequestFilter); } +CbObject +HttpServer::CollectStats() +{ + CbObjectWriter Cbo; + + metrics::EmitSnapshot("requests", m_RequestMeter, Cbo); + + Cbo.BeginObject("bytes"); + { + Cbo << "received" << GetTotalBytesReceived(); + Cbo << "sent" << GetTotalBytesSent(); + } + Cbo.EndObject(); + + Cbo.BeginObject("websockets"); + { + Cbo << "active_connections" << GetActiveWebSocketConnectionCount(); + Cbo << "frames_received" << m_WsFramesReceived.load(std::memory_order_relaxed); + Cbo << "frames_sent" << m_WsFramesSent.load(std::memory_order_relaxed); + Cbo << "bytes_received" << m_WsBytesReceived.load(std::memory_order_relaxed); + Cbo << "bytes_sent" << m_WsBytesSent.load(std::memory_order_relaxed); + } + Cbo.EndObject(); + + return Cbo.Save(); +} + +void +HttpServer::HandleStatsRequest(HttpServerRequest& Request) +{ + Request.WriteResponse(HttpResponseCode::OK, CollectStats()); +} + ////////////////////////////////////////////////////////////////////////// HttpRpcHandler::HttpRpcHandler() @@ -1082,9 +1157,13 @@ CreateHttpServerClass(const std::string_view ServerClass, const HttpServerConfig if (ServerClass == "asio"sv) { ZEN_INFO("using asio HTTP server implementation") - return CreateHttpAsioServer(AsioConfig{.ThreadCount = Config.ThreadCount, - .ForceLoopback = Config.ForceLoopback, - .IsDedicatedServer = Config.IsDedicatedServer}); + return CreateHttpAsioServer(AsioConfig { + .ThreadCount = Config.ThreadCount, .ForceLoopback = Config.ForceLoopback, .IsDedicatedServer = Config.IsDedicatedServer, + .NoNetwork = Config.NoNetwork, .UnixSocketPath = PathToUtf8(Config.UnixSocketPath), +#if ZEN_USE_OPENSSL + .HttpsPort = Config.HttpsPort, .CertFile = Config.CertFile, .KeyFile = Config.KeyFile, +#endif + }); } #if ZEN_WITH_HTTPSYS else if (ServerClass == "httpsys"sv) @@ -1096,7 +1175,11 @@ CreateHttpServerClass(const std::string_view ServerClass, const HttpServerConfig .IsRequestLoggingEnabled = Config.HttpSys.IsRequestLoggingEnabled, .IsDedicatedServer = Config.IsDedicatedServer, .ForceLoopback = Config.ForceLoopback, - .UseExplicitIoThreadPool = Config.HttpSys.UseExplicitIoThreadPool})); + .UseExplicitIoThreadPool = Config.HttpSys.UseExplicitIoThreadPool, + .HttpsPort = Config.HttpSys.HttpsPort, + .CertThumbprint = Config.HttpSys.CertThumbprint, + .CertStoreName = Config.HttpSys.CertStoreName, + .HttpsOnly = Config.HttpSys.HttpsOnly})); } #endif else if (ServerClass == "null"sv) @@ -1301,6 +1384,8 @@ HandlePackageOffers(HttpService& Service, HttpServerRequest& Request, Ref<IHttpP #if ZEN_WITH_TESTS +TEST_SUITE_BEGIN("http.httpserver"); + TEST_CASE("http.common") { using namespace std::literals; @@ -1406,20 +1491,33 @@ TEST_CASE("http.common") SUBCASE("router-matcher") { - bool HandledA = false; - bool HandledAA = false; - bool HandledAB = false; - bool HandledAandB = false; + bool HandledA = false; + bool HandledAA = false; + bool HandledAB = false; + bool HandledAandB = false; + bool HandledAandPath = false; std::vector<std::string> Captures; auto Reset = [&] { - HandledA = HandledAA = HandledAB = HandledAandB = false; + HandledA = HandledAA = HandledAB = HandledAandB = HandledAandPath = false; Captures.clear(); }; TestHttpService Service; HttpRequestRouter r; - r.AddMatcher("a", [](std::string_view In) -> bool { return In.length() % 2 == 0; }); - r.AddMatcher("b", [](std::string_view In) -> bool { return In.length() % 3 == 0; }); + + r.AddMatcher("a", [](std::string_view In) -> bool { return In.length() % 2 == 0 && In.find('/') == std::string_view::npos; }); + r.AddMatcher("b", [](std::string_view In) -> bool { return In.length() % 3 == 0 && In.find('/') == std::string_view::npos; }); + static constexpr AsciiSet ValidPathCharactersSet{"abcdefghijklmnopqrstuvwxyz0123456789/_.,;$~{}+-[]%()]ABCDEFGHIJKLMNOPQRSTUVWXYZ"}; + r.AddMatcher("path", [](std::string_view Str) -> bool { return !Str.empty() && AsciiSet::HasOnly(Str, ValidPathCharactersSet); }); + + r.RegisterRoute( + "path/{a}/{path}", + [&](auto& Req) { + HandledAandPath = true; + Captures = {std::string(Req.GetCapture(1)), std::string(Req.GetCapture(2))}; + }, + HttpVerb::kGet); + r.RegisterRoute( "{a}", [&](auto& Req) { @@ -1448,7 +1546,6 @@ TEST_CASE("http.common") Captures = {std::string(Req.GetCapture(1)), std::string(Req.GetCapture(2))}; }, HttpVerb::kGet); - { Reset(); TestHttpServerRequest req{Service, "ab"sv}; @@ -1456,6 +1553,7 @@ TEST_CASE("http.common") CHECK(HandledA); CHECK(!HandledAA); CHECK(!HandledAB); + CHECK(!HandledAandPath); REQUIRE_EQ(Captures.size(), 1); CHECK_EQ(Captures[0], "ab"sv); @@ -1468,6 +1566,7 @@ TEST_CASE("http.common") CHECK(!HandledA); CHECK(!HandledAA); CHECK(HandledAB); + CHECK(!HandledAandPath); REQUIRE_EQ(Captures.size(), 2); CHECK_EQ(Captures[0], "ab"sv); CHECK_EQ(Captures[1], "def"sv); @@ -1481,6 +1580,7 @@ TEST_CASE("http.common") CHECK(!HandledAA); CHECK(!HandledAB); CHECK(HandledAandB); + CHECK(!HandledAandPath); REQUIRE_EQ(Captures.size(), 2); CHECK_EQ(Captures[0], "ab"sv); CHECK_EQ(Captures[1], "def"sv); @@ -1493,6 +1593,7 @@ TEST_CASE("http.common") CHECK(!HandledA); CHECK(!HandledAA); CHECK(!HandledAB); + CHECK(!HandledAandPath); } { @@ -1502,6 +1603,35 @@ TEST_CASE("http.common") CHECK(HandledA); CHECK(!HandledAA); CHECK(!HandledAB); + CHECK(!HandledAandPath); + REQUIRE_EQ(Captures.size(), 1); + CHECK_EQ(Captures[0], "a123"sv); + } + + { + Reset(); + TestHttpServerRequest req{Service, "path/ab/simple_path.txt"sv}; + r.HandleRequest(req); + CHECK(!HandledA); + CHECK(!HandledAA); + CHECK(!HandledAB); + CHECK(HandledAandPath); + REQUIRE_EQ(Captures.size(), 2); + CHECK_EQ(Captures[0], "ab"sv); + CHECK_EQ(Captures[1], "simple_path.txt"sv); + } + + { + Reset(); + TestHttpServerRequest req{Service, "path/ab/directory/and/path.txt"sv}; + r.HandleRequest(req); + CHECK(!HandledA); + CHECK(!HandledAA); + CHECK(!HandledAB); + CHECK(HandledAandPath); + REQUIRE_EQ(Captures.size(), 2); + CHECK_EQ(Captures[0], "ab"sv); + CHECK_EQ(Captures[1], "directory/and/path.txt"sv); } } @@ -1519,6 +1649,8 @@ TEST_CASE("http.common") } } +TEST_SUITE_END(); + void http_forcelink() { diff --git a/src/zenhttp/include/zenhttp/cprutils.h b/src/zenhttp/include/zenhttp/cprutils.h index c252a5d99..3cfe652c5 100644 --- a/src/zenhttp/include/zenhttp/cprutils.h +++ b/src/zenhttp/include/zenhttp/cprutils.h @@ -2,17 +2,19 @@ #pragma once -#include <zencore/compactbinary.h> -#include <zencore/compactbinaryvalidation.h> -#include <zencore/iobuffer.h> -#include <zencore/string.h> -#include <zenhttp/formatters.h> -#include <zenhttp/httpclient.h> -#include <zenhttp/httpcommon.h> +#if ZEN_WITH_CPR + +# include <zencore/compactbinary.h> +# include <zencore/compactbinaryvalidation.h> +# include <zencore/iobuffer.h> +# include <zencore/string.h> +# include <zenhttp/formatters.h> +# include <zenhttp/httpclient.h> +# include <zenhttp/httpcommon.h> ZEN_THIRD_PARTY_INCLUDES_START -#include <cpr/response.h> -#include <fmt/format.h> +# include <cpr/response.h> +# include <fmt/format.h> ZEN_THIRD_PARTY_INCLUDES_END template<> @@ -92,3 +94,5 @@ struct fmt::formatter<cpr::Response> } } }; + +#endif // ZEN_WITH_CPR diff --git a/src/zenhttp/include/zenhttp/formatters.h b/src/zenhttp/include/zenhttp/formatters.h index addb00cb8..90180391c 100644 --- a/src/zenhttp/include/zenhttp/formatters.h +++ b/src/zenhttp/include/zenhttp/formatters.h @@ -73,7 +73,7 @@ struct fmt::formatter<zen::HttpClient::Response> if (Response.IsSuccess()) { return fmt::format_to(Ctx.out(), - "OK: Status: {}, Bytes: {}/{} (Up/Down), Elapsed: {}s", + "OK: Status: {}, Bytes: {}/{} (Up/Down), Elapsed: {}", ToString(Response.StatusCode), Response.UploadedBytes, Response.DownloadedBytes, @@ -84,7 +84,7 @@ struct fmt::formatter<zen::HttpClient::Response> return fmt::format_to(Ctx.out(), "Failed: Elapsed: {}, Reason: ({}) '{}", NiceResponseTime, - Response.Error.value().ErrorCode, + static_cast<int>(Response.Error.value().ErrorCode), Response.Error.value().ErrorMessage); } else diff --git a/src/zenhttp/include/zenhttp/httpapiservice.h b/src/zenhttp/include/zenhttp/httpapiservice.h index 0270973bf..2d384d1d8 100644 --- a/src/zenhttp/include/zenhttp/httpapiservice.h +++ b/src/zenhttp/include/zenhttp/httpapiservice.h @@ -1,4 +1,5 @@ // Copyright Epic Games, Inc. All Rights Reserved. +#pragma once #include <zenhttp/httpserver.h> diff --git a/src/zenhttp/include/zenhttp/httpclient.h b/src/zenhttp/include/zenhttp/httpclient.h index 9a9b74d72..e878c900f 100644 --- a/src/zenhttp/include/zenhttp/httpclient.h +++ b/src/zenhttp/include/zenhttp/httpclient.h @@ -10,9 +10,11 @@ #include <zencore/uid.h> #include <zenhttp/httpcommon.h> +#include <filesystem> #include <functional> #include <optional> #include <unordered_map> +#include <vector> namespace zen { @@ -29,6 +31,36 @@ class CompositeBuffer; */ +enum class HttpClientErrorCode : int +{ + kOK = 0, + kConnectionFailure, + kHostResolutionFailure, + kProxyResolutionFailure, + kInternalError, + kNetworkReceiveError, + kNetworkSendFailure, + kOperationTimedOut, + kSSLConnectError, + kSSLCertificateError, + kSSLCACertError, + kGenericSSLError, + kRequestCancelled, + kOtherError, +}; + +enum class HttpClientBackend : uint8_t +{ + kDefault, +#if ZEN_WITH_CPR + kCpr, +#endif + kCurl, +}; + +void SetDefaultHttpClientBackend(std::string_view Backend); +void SetDefaultHttpClientBackend(HttpClientBackend Backend); + struct HttpClientAccessToken { using Clock = std::chrono::system_clock; @@ -58,6 +90,26 @@ struct HttpClientSettings Oid SessionId = Oid::Zero; bool Verbose = false; uint64_t MaximumInMemoryDownloadSize = 1024u * 1024u; + HttpClientBackend Backend = HttpClientBackend::kDefault; + + /// Unix domain socket path. When non-empty, the client connects via this + /// socket instead of TCP. BaseUri is still used for the Host header and URL. + std::filesystem::path UnixSocketPath; + + /// Disable HTTP keep-alive by closing the connection after each request. + /// Useful for testing per-connection overhead. + bool ForbidReuseConnection = false; + + /// Skip TLS certificate verification (for testing with self-signed certs). + bool InsecureSsl = false; + + /// CA certificate bundle path for TLS verification. When non-empty, overrides + /// the system default CA store. + std::string CaBundlePath; + + /// HTTP status codes that are expected and should not be logged as warnings. + /// 404 is always treated as expected regardless of this list. + std::vector<HttpResponseCode> ExpectedErrorCodes; }; class HttpClientError : public std::runtime_error @@ -65,22 +117,22 @@ class HttpClientError : public std::runtime_error public: using _Mybase = runtime_error; - HttpClientError(const std::string& Message, int Error, HttpResponseCode ResponseCode) + HttpClientError(const std::string& Message, HttpClientErrorCode Error, HttpResponseCode ResponseCode) : _Mybase(Message) , m_Error(Error) , m_ResponseCode(ResponseCode) { } - HttpClientError(const char* Message, int Error, HttpResponseCode ResponseCode) + HttpClientError(const char* Message, HttpClientErrorCode Error, HttpResponseCode ResponseCode) : _Mybase(Message) , m_Error(Error) , m_ResponseCode(ResponseCode) { } - inline int GetInternalErrorCode() const { return m_Error; } - inline HttpResponseCode GetHttpResponseCode() const { return m_ResponseCode; } + inline HttpClientErrorCode GetInternalErrorCode() const { return m_Error; } + inline HttpResponseCode GetHttpResponseCode() const { return m_ResponseCode; } enum class ResponseClass : std::int8_t { @@ -107,24 +159,51 @@ public: ResponseClass GetResponseClass() const; private: - const int m_Error = 0; - const HttpResponseCode m_ResponseCode = HttpResponseCode::ImATeapot; + const HttpClientErrorCode m_Error = HttpClientErrorCode::kOK; + const HttpResponseCode m_ResponseCode = HttpResponseCode::ImATeapot; }; class HttpClientBase; +/** HTTP Client + * + * This is safe for use on multiple threads simultaneously, as each + * instance maintains an internal connection pool and will synchronize + * access to it as needed. + * + * Uses libcurl under the hood. We currently only use HTTP 1.1 features. + * + */ class HttpClient { public: - HttpClient(std::string_view BaseUri, - const HttpClientSettings& Connectionsettings = {}, - std::function<bool()>&& CheckIfAbortFunction = {}); + explicit HttpClient(std::string_view BaseUri, + const HttpClientSettings& Connectionsettings = {}, + std::function<bool()>&& CheckIfAbortFunction = {}); ~HttpClient(); + HttpClient(const HttpClient&) = delete; + HttpClient& operator=(const HttpClient&) = delete; + struct ErrorContext { - int ErrorCode; - std::string ErrorMessage; + HttpClientErrorCode ErrorCode; + std::string ErrorMessage; + + /** True when the error is a transport-level connection failure (connect timeout, refused, DNS) */ + bool IsConnectionError() const + { + switch (ErrorCode) + { + case HttpClientErrorCode::kConnectionFailure: + case HttpClientErrorCode::kOperationTimedOut: + case HttpClientErrorCode::kHostResolutionFailure: + case HttpClientErrorCode::kProxyResolutionFailure: + return true; + default: + return false; + } + } }; struct KeyValueMap @@ -171,13 +250,29 @@ public: KeyValueMap Header; // The number of bytes sent as part of the request - int64_t UploadedBytes; + int64_t UploadedBytes = 0; // The number of bytes received as part of the response - int64_t DownloadedBytes; + int64_t DownloadedBytes = 0; // The elapsed time in seconds for the request to execute - double ElapsedSeconds; + double ElapsedSeconds = 0.0; + + struct MultipartBoundary + { + uint64_t OffsetInPayload = 0; + uint64_t RangeOffset = 0; + uint64_t RangeLength = 0; + HttpContentType ContentType; + }; + + // Ranges will map out all received ranges, both single and multi-range responses + // If no range was requested Ranges will be empty + std::vector<MultipartBoundary> Ranges; + + // Map the absolute OffsetAndLengthPairs into ResponsePayload from the ranges received (Ranges). + // If the response was not a partial response, an empty vector will be returned + std::vector<std::pair<uint64_t, uint64_t>> GetRanges(std::span<const std::pair<uint64_t, uint64_t>> OffsetAndLengthPairs) const; // This contains any errors from the HTTP stack. It won't contain information on // why the server responded with a non-success HTTP status, that may be gleaned @@ -226,7 +321,10 @@ public: const IoBuffer& Payload, ZenContentType ContentType, const KeyValueMap& AdditionalHeader = {}); - [[nodiscard]] Response Post(std::string_view Url, CbObject Payload, const KeyValueMap& AdditionalHeader = {}); + [[nodiscard]] Response Post(std::string_view Url, + CbObject Payload, + const KeyValueMap& AdditionalHeader = {}, + const std::filesystem::path& TempFolderPath = {}); [[nodiscard]] Response Post(std::string_view Url, CbPackage Payload, const KeyValueMap& AdditionalHeader = {}); [[nodiscard]] Response Post(std::string_view Url, const CompositeBuffer& Payload, @@ -260,6 +358,16 @@ private: const HttpClientSettings m_ConnectionSettings; }; -void httpclient_forcelink(); // internal +struct LatencyTestResult +{ + bool Success = false; + std::string FailureReason; + double LatencySeconds = -1.0; +}; + +LatencyTestResult MeasureLatency(HttpClient& Client, std::string_view Url); + +void httpclient_forcelink(); // internal +void httpclient_test_forcelink(); // internal } // namespace zen diff --git a/src/zenhttp/include/zenhttp/httpcommon.h b/src/zenhttp/include/zenhttp/httpcommon.h index bc18549c9..8fca35ac5 100644 --- a/src/zenhttp/include/zenhttp/httpcommon.h +++ b/src/zenhttp/include/zenhttp/httpcommon.h @@ -184,6 +184,13 @@ IsHttpSuccessCode(HttpResponseCode HttpCode) noexcept return IsHttpSuccessCode(int(HttpCode)); } +[[nodiscard]] inline bool +IsHttpOk(HttpResponseCode HttpCode) noexcept +{ + return HttpCode == HttpResponseCode::OK || HttpCode == HttpResponseCode::Created || HttpCode == HttpResponseCode::Accepted || + HttpCode == HttpResponseCode::NoContent; +} + std::string_view ToString(HttpResponseCode HttpCode); } // namespace zen diff --git a/src/zenhttp/include/zenhttp/httpserver.h b/src/zenhttp/include/zenhttp/httpserver.h index 7887beacd..42e5b1628 100644 --- a/src/zenhttp/include/zenhttp/httpserver.h +++ b/src/zenhttp/include/zenhttp/httpserver.h @@ -13,6 +13,9 @@ #include <zencore/uid.h> #include <zenhttp/httpcommon.h> +#include <zentelemetry/stats.h> + +#include <filesystem> #include <functional> #include <gsl/gsl-lite.hpp> #include <list> @@ -103,6 +106,7 @@ public: virtual bool IsLocalMachineRequest() const = 0; virtual std::string_view GetAuthorizationHeader() const = 0; + virtual std::string_view GetRemoteAddress() const { return {}; } /** Respond with payload @@ -202,12 +206,34 @@ private: int m_UriPrefixLength = 0; }; +struct IHttpStatsProvider +{ + /** Handle an HTTP stats request, writing the response directly. + * Implementations may inspect query parameters on the request + * to include optional detailed breakdowns. + */ + virtual void HandleStatsRequest(HttpServerRequest& Request) = 0; + + /** Return the provider's current stats as a CbObject snapshot. + * Used by the WebSocket push thread to broadcast live updates + * without requiring an HttpServerRequest. Providers that do + * not override this will be skipped in WebSocket broadcasts. + */ + virtual CbObject CollectStats() { return {}; } +}; + +struct IHttpStatsService +{ + virtual void RegisterHandler(std::string_view Id, IHttpStatsProvider& Provider) = 0; + virtual void UnregisterHandler(std::string_view Id, IHttpStatsProvider& Provider) = 0; +}; + /** HTTP server * * Implements the main event loop to service HTTP requests, and handles routing * requests to the appropriate handler as registered via RegisterService */ -class HttpServer : public RefCounted +class HttpServer : public RefCounted, public IHttpStatsProvider { public: void RegisterService(HttpService& Service); @@ -219,8 +245,65 @@ public: void RequestExit(); void Close(); + /** Returns a canonical http:// URI for the given service, using the external + * IP and the port the server is actually listening on. Only valid + * after Initialize() has returned successfully. + */ + std::string GetServiceUri(const HttpService* Service) const; + + /** Returns the external host string (IP or hostname) determined during Initialize(). + * Only valid after Initialize() has returned successfully. + */ + std::string_view GetExternalHost() const { return m_ExternalHost; } + + /** Returns the effective HTTPS port, or 0 if HTTPS is not enabled. Only valid after Initialize(). */ + int GetEffectiveHttpsPort() const { return m_EffectiveHttpsPort; } + + /** Returns total bytes received and sent across all connections since server start. */ + virtual uint64_t GetTotalBytesReceived() const { return 0; } + virtual uint64_t GetTotalBytesSent() const { return 0; } + + /** Mark that a request has been handled. Called by server implementations. */ + void MarkRequest() { m_RequestMeter.Mark(); } + + /** Set a default redirect path for root requests */ + void SetDefaultRedirect(std::string_view Path) { m_DefaultRedirect = Path; } + + std::string_view GetDefaultRedirect() const { return m_DefaultRedirect; } + + /** Track active WebSocket connections — called by server implementations on upgrade/close. */ + void OnWebSocketConnectionOpened() { m_ActiveWebSocketConnections.fetch_add(1, std::memory_order_relaxed); } + void OnWebSocketConnectionClosed() { m_ActiveWebSocketConnections.fetch_sub(1, std::memory_order_relaxed); } + uint64_t GetActiveWebSocketConnectionCount() const { return m_ActiveWebSocketConnections.load(std::memory_order_relaxed); } + + /** Track WebSocket frame and byte counters — called by WS connection implementations per frame. */ + void OnWebSocketFrameReceived(uint64_t Bytes) + { + m_WsFramesReceived.fetch_add(1, std::memory_order_relaxed); + m_WsBytesReceived.fetch_add(Bytes, std::memory_order_relaxed); + } + void OnWebSocketFrameSent(uint64_t Bytes) + { + m_WsFramesSent.fetch_add(1, std::memory_order_relaxed); + m_WsBytesSent.fetch_add(Bytes, std::memory_order_relaxed); + } + + // IHttpStatsProvider + virtual CbObject CollectStats() override; + virtual void HandleStatsRequest(HttpServerRequest& Request) override; + private: std::vector<HttpService*> m_KnownServices; + int m_EffectivePort = 0; + int m_EffectiveHttpsPort = 0; + std::string m_ExternalHost; + metrics::Meter m_RequestMeter; + std::string m_DefaultRedirect; + std::atomic<uint64_t> m_ActiveWebSocketConnections{0}; + std::atomic<uint64_t> m_WsFramesReceived{0}; + std::atomic<uint64_t> m_WsFramesSent{0}; + std::atomic<uint64_t> m_WsBytesReceived{0}; + std::atomic<uint64_t> m_WsBytesSent{0}; virtual void OnRegisterService(HttpService& Service) = 0; virtual int OnInitialize(int BasePort, std::filesystem::path DataDir) = 0; @@ -228,6 +311,10 @@ private: virtual void OnRun(bool IsInteractiveSession) = 0; virtual void OnRequestExit() = 0; virtual void OnClose() = 0; + +protected: + void SetEffectiveHttpsPort(int Port) { m_EffectiveHttpsPort = Port; } + virtual std::string OnGetExternalHost() const; }; struct HttpServerPluginConfig @@ -243,6 +330,11 @@ struct HttpServerConfig std::vector<HttpServerPluginConfig> PluginConfigs; bool ForceLoopback = false; unsigned int ThreadCount = 0; + std::filesystem::path UnixSocketPath; // Unix domain socket path (empty = disabled) + bool NoNetwork = false; // Disable TCP/HTTPS listeners; only accept connections via UnixSocketPath + int HttpsPort = 0; // HTTPS listen port (0 = disabled, ASIO backend) + std::string CertFile; // PEM certificate chain file path + std::string KeyFile; // PEM private key file path struct { @@ -250,6 +342,10 @@ struct HttpServerConfig bool IsAsyncResponseEnabled = true; bool IsRequestLoggingEnabled = false; bool UseExplicitIoThreadPool = false; + int HttpsPort = 0; // 0 = HTTPS disabled + std::string CertThumbprint; // Hex SHA-1 (40 chars) for auto SSL binding + std::string CertStoreName = "MY"; // Windows certificate store name + bool HttpsOnly = false; // When true, disable HTTP listener } HttpSys; }; @@ -420,7 +516,7 @@ public: ~HttpRpcHandler(); HttpRpcHandler(const HttpRpcHandler&) = delete; - HttpRpcHandler operator=(const HttpRpcHandler&) = delete; + HttpRpcHandler& operator=(const HttpRpcHandler&) = delete; void AddRpc(std::string_view RpcId, std::function<void(CbObject& RpcArgs)> HandlerFunction); @@ -436,17 +532,7 @@ private: bool HandlePackageOffers(HttpService& Service, HttpServerRequest& Request, Ref<IHttpPackageHandler>& PackageHandlerRef); -struct IHttpStatsProvider -{ - virtual void HandleStatsRequest(HttpServerRequest& Request) = 0; -}; - -struct IHttpStatsService -{ - virtual void RegisterHandler(std::string_view Id, IHttpStatsProvider& Provider) = 0; - virtual void UnregisterHandler(std::string_view Id, IHttpStatsProvider& Provider) = 0; -}; - -void http_forcelink(); // internal +void http_forcelink(); // internal +void websocket_forcelink(); // internal } // namespace zen diff --git a/src/zenhttp/include/zenhttp/httpstats.h b/src/zenhttp/include/zenhttp/httpstats.h index e6fea6765..460315faf 100644 --- a/src/zenhttp/include/zenhttp/httpstats.h +++ b/src/zenhttp/include/zenhttp/httpstats.h @@ -3,23 +3,50 @@ #pragma once #include <zencore/logging.h> +#include <zencore/thread.h> #include <zenhttp/httpserver.h> +#include <zenhttp/websocket.h> +#include <atomic> #include <map> +#include <memory> +#include <thread> +#include <vector> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <asio/io_context.hpp> +#include <asio/steady_timer.hpp> +ZEN_THIRD_PARTY_INCLUDES_END namespace zen { -class HttpStatsService : public HttpService, public IHttpStatsService +class HttpStatsService : public HttpService, public IHttpStatsService, public IWebSocketHandler { public: - HttpStatsService(); + /// Construct without an io_context — optionally uses a dedicated push thread + /// for WebSocket stats broadcasting. + explicit HttpStatsService(bool EnableWebSockets = false); + + /// Construct with an external io_context — uses an asio timer instead + /// of a dedicated thread for WebSocket stats broadcasting. + /// The caller must ensure the io_context outlives this service and that + /// its run loop is active. + HttpStatsService(asio::io_context& IoContext, bool EnableWebSockets = true); + ~HttpStatsService(); + void Shutdown(); + virtual const char* BaseUri() const override; virtual void HandleRequest(HttpServerRequest& Request) override; virtual void RegisterHandler(std::string_view Id, IHttpStatsProvider& Provider) override; virtual void UnregisterHandler(std::string_view Id, IHttpStatsProvider& Provider) override; + // IWebSocketHandler + void OnWebSocketOpen(Ref<WebSocketConnection> Connection) override; + void OnWebSocketMessage(WebSocketConnection& Conn, const WebSocketMessage& Msg) override; + void OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code, std::string_view Reason) override; + private: LoggerRef m_Log; HttpRequestRouter m_Router; @@ -28,6 +55,22 @@ private: RwLock m_Lock; std::map<std::string, IHttpStatsProvider*> m_Providers; + + // WebSocket push + RwLock m_WsConnectionsLock; + std::vector<Ref<WebSocketConnection>> m_WsConnections; + std::atomic<bool> m_PushEnabled{false}; + + void BroadcastStats(); + + // Thread-based push (when no io_context is provided) + std::thread m_PushThread; + Event m_PushEvent; + void PushThreadFunction(); + + // Timer-based push (when an io_context is provided) + std::unique_ptr<asio::steady_timer> m_PushTimer; + void EnqueuePushTimer(); }; } // namespace zen diff --git a/src/zenhttp/include/zenhttp/httpwsclient.h b/src/zenhttp/include/zenhttp/httpwsclient.h new file mode 100644 index 000000000..2ca9b7ab1 --- /dev/null +++ b/src/zenhttp/include/zenhttp/httpwsclient.h @@ -0,0 +1,83 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "zenhttp.h" + +#include <zenhttp/httpclient.h> +#include <zenhttp/websocket.h> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <asio/io_context.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + +#include <chrono> +#include <cstdint> +#include <functional> +#include <memory> +#include <optional> +#include <span> +#include <string> +#include <string_view> + +namespace zen { + +/** + * Callback interface for WebSocket client events + * + * Separate from the server-side IWebSocketHandler because the caller + * already owns the HttpWsClient — no Ref<WebSocketConnection> needed. + */ +class IWsClientHandler +{ +public: + virtual ~IWsClientHandler() = default; + + virtual void OnWsOpen() = 0; + virtual void OnWsMessage(const WebSocketMessage& Msg) = 0; + virtual void OnWsClose(uint16_t Code, std::string_view Reason) = 0; +}; + +struct HttpWsClientSettings +{ + std::string LogCategory = "wsclient"; + std::chrono::milliseconds ConnectTimeout{5000}; + std::optional<std::function<HttpClientAccessToken()>> AccessTokenProvider; + + /// Unix domain socket path. When non-empty, connects via this socket + /// instead of TCP. The URL host is still used for the Host header. + std::filesystem::path UnixSocketPath; +}; + +/** + * WebSocket client over TCP (ws:// scheme) + * + * Uses ASIO for async I/O. Two construction modes: + * - Internal io_context + background thread (standalone use) + * - External io_context (shared event loop, no internal thread) + * + * Thread-safe for SendText/SendBinary/Close. + */ +class HttpWsClient +{ +public: + HttpWsClient(std::string_view Url, IWsClientHandler& Handler, const HttpWsClientSettings& Settings = {}); + HttpWsClient(std::string_view Url, IWsClientHandler& Handler, asio::io_context& IoContext, const HttpWsClientSettings& Settings = {}); + + ~HttpWsClient(); + + HttpWsClient(const HttpWsClient&) = delete; + HttpWsClient& operator=(const HttpWsClient&) = delete; + + void Connect(); + void SendText(std::string_view Text); + void SendBinary(std::span<const uint8_t> Data); + void Close(uint16_t Code = 1000, std::string_view Reason = {}); + bool IsOpen() const; + +private: + struct Impl; + std::unique_ptr<Impl> m_Impl; +}; + +} // namespace zen diff --git a/src/zenhttp/include/zenhttp/packageformat.h b/src/zenhttp/include/zenhttp/packageformat.h index c90b840da..1a5068580 100644 --- a/src/zenhttp/include/zenhttp/packageformat.h +++ b/src/zenhttp/include/zenhttp/packageformat.h @@ -68,7 +68,7 @@ struct CbAttachmentEntry struct CbAttachmentReferenceHeader { uint64_t PayloadByteOffset = 0; - uint64_t PayloadByteSize = ~0u; + uint64_t PayloadByteSize = ~uint64_t(0); uint16_t AbsolutePathLength = 0; // This header will be followed by UTF8 encoded absolute path to backing file diff --git a/src/zenhttp/include/zenhttp/websocket.h b/src/zenhttp/include/zenhttp/websocket.h new file mode 100644 index 000000000..bc3293282 --- /dev/null +++ b/src/zenhttp/include/zenhttp/websocket.h @@ -0,0 +1,65 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zenbase/refcount.h> +#include <zencore/iobuffer.h> + +#include <cstdint> +#include <span> +#include <string_view> + +namespace zen { + +enum class WebSocketOpcode : uint8_t +{ + kText = 0x1, + kBinary = 0x2, + kClose = 0x8, + kPing = 0x9, + kPong = 0xA +}; + +struct WebSocketMessage +{ + WebSocketOpcode Opcode = WebSocketOpcode::kText; + IoBuffer Payload; + uint16_t CloseCode = 0; +}; + +/** + * Represents an active WebSocket connection + * + * Derived classes implement the actual transport (e.g. ASIO sockets). + * Instances are reference-counted so that both the service layer and + * the async read/write loop can share ownership. + */ +class WebSocketConnection : public RefCounted +{ +public: + virtual ~WebSocketConnection() = default; + + virtual void SendText(std::string_view Text) = 0; + virtual void SendBinary(std::span<const uint8_t> Data) = 0; + virtual void Close(uint16_t Code = 1000, std::string_view Reason = {}) = 0; + virtual bool IsOpen() const = 0; +}; + +/** + * Interface for services that accept WebSocket upgrades + * + * An HttpService may additionally implement this interface to indicate + * it supports WebSocket connections. The HTTP server checks for this + * via dynamic_cast when it sees an Upgrade: websocket request. + */ +class IWebSocketHandler +{ +public: + virtual ~IWebSocketHandler() = default; + + virtual void OnWebSocketOpen(Ref<WebSocketConnection> Connection) = 0; + virtual void OnWebSocketMessage(WebSocketConnection& Conn, const WebSocketMessage& Msg) = 0; + virtual void OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code, std::string_view Reason) = 0; +}; + +} // namespace zen diff --git a/src/zenhttp/monitoring/httpstats.cpp b/src/zenhttp/monitoring/httpstats.cpp index b097a0d3f..3877215a8 100644 --- a/src/zenhttp/monitoring/httpstats.cpp +++ b/src/zenhttp/monitoring/httpstats.cpp @@ -3,15 +3,57 @@ #include "zenhttp/httpstats.h" #include <zencore/compactbinarybuilder.h> +#include <zencore/string.h> +#include <zencore/thread.h> +#include <zencore/trace.h> namespace zen { -HttpStatsService::HttpStatsService() : m_Log(logging::Get("stats")) +HttpStatsService::HttpStatsService(bool EnableWebSockets) : m_Log(logging::Get("stats")) { + if (EnableWebSockets) + { + m_PushEnabled.store(true); + m_PushThread = std::thread([this] { PushThreadFunction(); }); + } +} + +HttpStatsService::HttpStatsService(asio::io_context& IoContext, bool EnableWebSockets) : m_Log(logging::Get("stats")) +{ + if (EnableWebSockets) + { + m_PushEnabled.store(true); + m_PushTimer = std::make_unique<asio::steady_timer>(IoContext); + EnqueuePushTimer(); + } } HttpStatsService::~HttpStatsService() { + Shutdown(); +} + +void +HttpStatsService::Shutdown() +{ + if (!m_PushEnabled.exchange(false)) + { + return; + } + + if (m_PushTimer) + { + m_PushTimer->cancel(); + m_PushTimer.reset(); + } + + if (m_PushThread.joinable()) + { + m_PushEvent.Set(); + m_PushThread.join(); + } + + m_WsConnectionsLock.WithExclusiveLock([&] { m_WsConnections.clear(); }); } const char* @@ -39,6 +81,7 @@ HttpStatsService::UnregisterHandler(std::string_view Id, IHttpStatsProvider& Pro void HttpStatsService::HandleRequest(HttpServerRequest& Request) { + ZEN_TRACE_CPU("HttpStatsService::HandleRequest"); using namespace std::literals; std::string_view Key = Request.RelativeUri(); @@ -89,4 +132,154 @@ HttpStatsService::HandleRequest(HttpServerRequest& Request) } } +////////////////////////////////////////////////////////////////////////// +// +// IWebSocketHandler +// + +void +HttpStatsService::OnWebSocketOpen(Ref<WebSocketConnection> Connection) +{ + ZEN_TRACE_CPU("HttpStatsService::OnWebSocketOpen"); + ZEN_INFO("Stats WebSocket client connected"); + + m_WsConnectionsLock.WithExclusiveLock([&] { m_WsConnections.push_back(std::move(Connection)); }); + + // Send initial state immediately + if (m_PushThread.joinable()) + { + m_PushEvent.Set(); + } +} + +void +HttpStatsService::OnWebSocketMessage(WebSocketConnection& /*Conn*/, const WebSocketMessage& /*Msg*/) +{ + // No client-to-server messages expected +} + +void +HttpStatsService::OnWebSocketClose(WebSocketConnection& Conn, [[maybe_unused]] uint16_t Code, [[maybe_unused]] std::string_view Reason) +{ + ZEN_TRACE_CPU("HttpStatsService::OnWebSocketClose"); + ZEN_INFO("Stats WebSocket client disconnected (code {})", Code); + + m_WsConnectionsLock.WithExclusiveLock([&] { + auto It = std::remove_if(m_WsConnections.begin(), m_WsConnections.end(), [&Conn](const Ref<WebSocketConnection>& C) { + return C.Get() == &Conn; + }); + m_WsConnections.erase(It, m_WsConnections.end()); + }); +} + +////////////////////////////////////////////////////////////////////////// +// +// Stats broadcast +// + +void +HttpStatsService::BroadcastStats() +{ + ZEN_TRACE_CPU("HttpStatsService::BroadcastStats"); + std::vector<Ref<WebSocketConnection>> Connections; + m_WsConnectionsLock.WithSharedLock([&] { Connections = m_WsConnections; }); + + if (Connections.empty()) + { + return; + } + + // Collect stats from all providers + ExtendableStringBuilder<4096> JsonBuilder; + JsonBuilder.Append("{"); + + bool First = true; + { + RwLock::SharedLockScope _(m_Lock); + for (auto& [Id, Provider] : m_Providers) + { + CbObject Stats = Provider->CollectStats(); + if (!Stats) + { + continue; + } + + if (!First) + { + JsonBuilder.Append(","); + } + First = false; + + // Emit as "provider_id": { ... } + JsonBuilder.Append("\""); + JsonBuilder.Append(Id); + JsonBuilder.Append("\":"); + + ExtendableStringBuilder<2048> StatsJson; + Stats.ToJson(StatsJson); + JsonBuilder.Append(StatsJson.ToView()); + } + } + + JsonBuilder.Append("}"); + + std::string_view Json = JsonBuilder.ToView(); + for (auto& Conn : Connections) + { + if (Conn->IsOpen()) + { + Conn->SendText(Json); + } + } +} + +////////////////////////////////////////////////////////////////////////// +// +// Thread-based push (fallback when no io_context) +// + +void +HttpStatsService::PushThreadFunction() +{ + SetCurrentThreadName("stats_ws_push"); + + while (m_PushEnabled.load()) + { + m_PushEvent.Wait(1000); + m_PushEvent.Reset(); + + if (!m_PushEnabled.load()) + { + break; + } + + BroadcastStats(); + } +} + +////////////////////////////////////////////////////////////////////////// +// +// Timer-based push (when io_context is provided) +// + +void +HttpStatsService::EnqueuePushTimer() +{ + if (!m_PushTimer) + { + return; + } + + m_PushTimer->expires_after(std::chrono::seconds(1)); + m_PushTimer->async_wait([this](const asio::error_code& Ec) { + if (Ec) + { + return; + } + + BroadcastStats(); + EnqueuePushTimer(); + }); +} + } // namespace zen diff --git a/src/zenhttp/packageformat.cpp b/src/zenhttp/packageformat.cpp index 708238224..9c62c1f2d 100644 --- a/src/zenhttp/packageformat.cpp +++ b/src/zenhttp/packageformat.cpp @@ -575,13 +575,21 @@ ParsePackageMessage(IoBuffer Payload, std::function<IoBuffer(const IoHash&, uint } else if (AttachmentSize > 0) { - // Make a copy of the buffer so the attachments don't reference the entire payload - IoBuffer AttachmentBufferCopy = CreateBuffer(Entry.AttachmentHash, AttachmentSize); - ZEN_ASSERT(AttachmentBufferCopy); - ZEN_ASSERT(AttachmentBufferCopy.Size() == AttachmentSize); - AttachmentBufferCopy.GetMutableView().CopyFrom(AttachmentBuffer.GetView()); + IoBufferFileReference TestIfFileRef; + if (AttachmentBuffer.GetFileReference(TestIfFileRef)) + { + Attachments.emplace_back(CbAttachment(SharedBuffer{std::move(AttachmentBuffer)}, Entry.AttachmentHash)); + } + else + { + // Make a copy of the buffer so the attachments don't reference the entire payload + IoBuffer AttachmentBufferCopy = CreateBuffer(Entry.AttachmentHash, AttachmentSize); + ZEN_ASSERT(AttachmentBufferCopy); + ZEN_ASSERT(AttachmentBufferCopy.Size() == AttachmentSize); + AttachmentBufferCopy.GetMutableView().CopyFrom(AttachmentBuffer.GetView()); - Attachments.emplace_back(SharedBuffer{AttachmentBufferCopy}); + Attachments.emplace_back(CbAttachment(SharedBuffer{AttachmentBufferCopy}, Entry.AttachmentHash)); + } } else { @@ -805,6 +813,8 @@ CbPackageReader::Finalize() #if ZEN_WITH_TESTS +TEST_SUITE_BEGIN("http.packageformat"); + TEST_CASE("CbPackage.Serialization") { // Make a test package @@ -926,6 +936,8 @@ TEST_CASE("CbPackage.LocalRef") Reader.Finalize(); } +TEST_SUITE_END(); + void forcelink_packageformat() { diff --git a/src/zenhttp/security/passwordsecurity.cpp b/src/zenhttp/security/passwordsecurity.cpp index a8fb9c3f5..0e3a743c3 100644 --- a/src/zenhttp/security/passwordsecurity.cpp +++ b/src/zenhttp/security/passwordsecurity.cpp @@ -76,6 +76,8 @@ PasswordSecurity::IsAllowed(std::string_view InPassword, std::string_view BaseUr #if ZEN_WITH_TESTS +TEST_SUITE_BEGIN("http.passwordsecurity"); + TEST_CASE("passwordsecurity.allowanything") { PasswordSecurity Anything({}); @@ -162,6 +164,9 @@ TEST_CASE("passwordsecurity.conflictingunprotecteduris") "uri #1 ('/free/access')")); } } + +TEST_SUITE_END(); + void passwordsecurity_forcelink() { diff --git a/src/zenhttp/servers/asio_socket_traits.h b/src/zenhttp/servers/asio_socket_traits.h new file mode 100644 index 000000000..25aeaa24e --- /dev/null +++ b/src/zenhttp/servers/asio_socket_traits.h @@ -0,0 +1,54 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +ZEN_THIRD_PARTY_INCLUDES_START +#include <asio.hpp> +#if ZEN_USE_OPENSSL +# include <asio/ssl.hpp> +#endif +ZEN_THIRD_PARTY_INCLUDES_END + +namespace zen::asio_http { + +/** + * Traits for abstracting socket shutdown/close across plain TCP, Unix domain, and SSL sockets. + * SSL sockets need lowest_layer() access and have different shutdown semantics. + */ +template<typename SocketType> +struct SocketTraits +{ + /// SSL sockets cannot use zero-copy file send (TransmitFile/sendfile) because + /// those bypass the encryption layer. This flag lets templated code fall back + /// to reading-into-memory for SSL connections. + static constexpr bool IsSslSocket = false; + + static void ShutdownReceive(SocketType& S, std::error_code& Ec) { S.shutdown(asio::socket_base::shutdown_receive, Ec); } + + static void ShutdownBoth(SocketType& S, std::error_code& Ec) { S.shutdown(asio::socket_base::shutdown_both, Ec); } + + static void Close(SocketType& S, std::error_code& Ec) { S.close(Ec); } +}; + +#if ZEN_USE_OPENSSL +using SslSocket = asio::ssl::stream<asio::ip::tcp::socket>; + +template<> +struct SocketTraits<SslSocket> +{ + static constexpr bool IsSslSocket = true; + + static void ShutdownReceive(SslSocket& S, std::error_code& Ec) { S.lowest_layer().shutdown(asio::socket_base::shutdown_receive, Ec); } + + static void ShutdownBoth(SslSocket& S, std::error_code& Ec) + { + // Best-effort SSL close_notify, then TCP shutdown + S.shutdown(Ec); + S.lowest_layer().shutdown(asio::socket_base::shutdown_both, Ec); + } + + static void Close(SslSocket& S, std::error_code& Ec) { S.lowest_layer().close(Ec); } +}; +#endif + +} // namespace zen::asio_http diff --git a/src/zenhttp/servers/httpasio.cpp b/src/zenhttp/servers/httpasio.cpp index 1c0ebef90..643f33618 100644 --- a/src/zenhttp/servers/httpasio.cpp +++ b/src/zenhttp/servers/httpasio.cpp @@ -1,18 +1,22 @@ // Copyright Epic Games, Inc. All Rights Reserved. #include "httpasio.h" +#include "asio_socket_traits.h" #include "httptracer.h" #include <zencore/except.h> #include <zencore/fmtutils.h> #include <zencore/logging.h> #include <zencore/memory/llm.h> +#include <zencore/system.h> #include <zencore/thread.h> #include <zencore/trace.h> #include <zencore/windows.h> #include <zenhttp/httpserver.h> #include "httpparser.h" +#include "wsasio.h" +#include "wsframecodec.h" #include <EASTL/fixed_vector.h> @@ -32,6 +36,12 @@ ZEN_THIRD_PARTY_INCLUDES_START #endif #include <asio.hpp> #include <asio/stream_file.hpp> +#if defined(ASIO_HAS_LOCAL_SOCKETS) +# include <asio/local/stream_protocol.hpp> +#endif +#if ZEN_USE_OPENSSL +# include <asio/ssl.hpp> +#endif ZEN_THIRD_PARTY_INCLUDES_END #define ASIO_VERBOSE_TRACE 0 @@ -89,10 +99,10 @@ IsIPv6AvailableSysctl(void) char buf[16]; if (fgets(buf, sizeof(buf), f)) { - fclose(f); // 0 means IPv6 enabled, 1 means disabled val = atoi(buf); } + fclose(f); } return val == 0; @@ -141,13 +151,23 @@ using namespace std::literals; struct HttpAcceptor; struct HttpResponse; -struct HttpServerConnection; +template<typename SocketType> +struct HttpServerConnectionT; +using HttpServerConnection = HttpServerConnectionT<asio::ip::tcp::socket>; +#if defined(ASIO_HAS_LOCAL_SOCKETS) +struct UnixAcceptor; +using UnixServerConnection = HttpServerConnectionT<asio::local::stream_protocol::socket>; +#endif +#if ZEN_USE_OPENSSL +struct HttpsAcceptor; +using HttpsSslServerConnection = HttpServerConnectionT<SslSocket>; +#endif inline LoggerRef InitLogger() { LoggerRef Logger = logging::Get("asio"); - // Logger.SetLogLevel(logging::level::Trace); + // Logger.SetLogLevel(logging::Trace); return Logger; } @@ -173,9 +193,9 @@ Log() #endif #if ZEN_USE_TRANSMITFILE -template<typename Handler> +template<typename Handler, typename SocketType> void -TransmitFileAsync(asio::ip::tcp::socket& Socket, HANDLE FileHandle, uint64_t ByteOffset, uint32_t ByteSize, Handler&& Cb) +TransmitFileAsync(SocketType& Socket, HANDLE FileHandle, uint64_t ByteOffset, uint32_t ByteSize, Handler&& Cb) { # if ZEN_BUILD_DEBUG const uint64_t FileSize = FileSizeFromHandle(FileHandle); @@ -506,11 +526,22 @@ public: HttpService* RouteRequest(std::string_view Url); IHttpRequestFilter::Result FilterRequest(HttpServerRequest& Request); - asio::io_service m_IoService; - asio::io_service::work m_Work{m_IoService}; - std::unique_ptr<asio_http::HttpAcceptor> m_Acceptor; - std::vector<std::thread> m_ThreadPool; - std::atomic<IHttpRequestFilter*> m_HttpRequestFilter = nullptr; + bool IsLoopbackOnly() const; + + int GetEffectiveHttpsPort() const; + + asio::io_context m_IoService; + asio::executor_work_guard<asio::io_context::executor_type> m_Work{m_IoService.get_executor()}; + std::unique_ptr<asio_http::HttpAcceptor> m_Acceptor; +#if defined(ASIO_HAS_LOCAL_SOCKETS) + std::unique_ptr<asio_http::UnixAcceptor> m_UnixAcceptor; +#endif +#if ZEN_USE_OPENSSL + std::unique_ptr<asio::ssl::context> m_SslContext; + std::unique_ptr<asio_http::HttpsAcceptor> m_HttpsAcceptor; +#endif + std::vector<std::thread> m_ThreadPool; + std::atomic<IHttpRequestFilter*> m_HttpRequestFilter = nullptr; LoggerRef m_RequestLog; HttpServerTracer m_RequestTracer; @@ -523,6 +554,11 @@ public: RwLock m_Lock; std::vector<ServiceEntry> m_UriHandlers; + + std::atomic<uint64_t> m_TotalBytesReceived{0}; + std::atomic<uint64_t> m_TotalBytesSent{0}; + + HttpServer* m_HttpServer = nullptr; }; /** @@ -536,7 +572,8 @@ public: HttpService& Service, IoBuffer PayloadBuffer, uint32_t RequestNumber, - bool IsLocalMachineRequest); + bool IsLocalMachineRequest, + std::string RemoteAddress); ~HttpAsioServerRequest(); virtual Oid ParseSessionId() const override; @@ -544,6 +581,7 @@ public: virtual bool IsLocalMachineRequest() const override; virtual std::string_view GetAuthorizationHeader() const override; + virtual std::string_view GetRemoteAddress() const override; virtual IoBuffer ReadPayload() override; virtual void WriteResponse(HttpResponseCode ResponseCode) override; @@ -561,6 +599,8 @@ public: uint32_t m_RequestNumber = 0; // Note: different to request ID which is derived from headers IoBuffer m_PayloadBuffer; bool m_IsLocalMachineRequest; + bool m_AllowZeroCopyFileSend = true; + std::string m_RemoteAddress; std::unique_ptr<HttpResponse> m_Response; }; @@ -582,6 +622,8 @@ public: ~HttpResponse() = default; + void SetAllowZeroCopyFileSend(bool Allow) { m_AllowZeroCopyFileSend = Allow; } + /** * Initialize the response for sending a payload made up of multiple blobs * @@ -623,7 +665,7 @@ public: bool ChunkHandled = false; #if ZEN_USE_TRANSMITFILE || ZEN_USE_ASYNC_SENDFILE - if (OwnedBuffer.IsWholeFile()) + if (m_AllowZeroCopyFileSend && OwnedBuffer.IsWholeFile()) { if (IoBufferFileReference FileRef; OwnedBuffer.GetFileReference(/* out */ FileRef)) { @@ -738,7 +780,8 @@ public: return m_Headers; } - void SendResponse(asio::ip::tcp::socket& TcpSocket, std::function<void(const asio::error_code& Ec, std::size_t ByteCount)>&& Token) + template<typename SocketType> + void SendResponse(SocketType& Socket, std::function<void(const asio::error_code& Ec, std::size_t ByteCount)>&& Token) { ZEN_ASSERT(m_State == State::kInitialized); @@ -748,10 +791,11 @@ public: m_SendCb = std::move(Token); m_State = State::kSending; - SendNextChunk(TcpSocket); + SendNextChunk(Socket); } - void SendNextChunk(asio::ip::tcp::socket& TcpSocket) + template<typename SocketType> + void SendNextChunk(SocketType& Socket) { ZEN_ASSERT(m_State == State::kSending); @@ -768,12 +812,12 @@ public: auto CompletionToken = [Self = this, Token = std::move(m_SendCb), TotalBytes = m_TotalBytesSent] { Token({}, TotalBytes); }; - asio::defer(TcpSocket.get_executor(), std::move(CompletionToken)); + asio::defer(Socket.get_executor(), std::move(CompletionToken)); return; } - auto OnCompletion = [this, &TcpSocket](const asio::error_code& Ec, std::size_t ByteCount) { + auto OnCompletion = [this, &Socket](const asio::error_code& Ec, std::size_t ByteCount) { ZEN_ASSERT(m_State == State::kSending); m_TotalBytesSent += ByteCount; @@ -784,7 +828,7 @@ public: } else { - SendNextChunk(TcpSocket); + SendNextChunk(Socket); } }; @@ -798,25 +842,21 @@ public: Io.Ref.FileRef.FileChunkSize); #if ZEN_USE_TRANSMITFILE - TransmitFileAsync(TcpSocket, + TransmitFileAsync(Socket, Io.Ref.FileRef.FileHandle, Io.Ref.FileRef.FileChunkOffset, gsl::narrow_cast<uint32_t>(Io.Ref.FileRef.FileChunkSize), OnCompletion); + return; #elif ZEN_USE_ASYNC_SENDFILE - SendFileAsync(TcpSocket, + SendFileAsync(Socket, Io.Ref.FileRef.FileHandle, Io.Ref.FileRef.FileChunkOffset, Io.Ref.FileRef.FileChunkSize, 64 * 1024, OnCompletion); -#else - // This should never occur unless we compile with one - // of the options above - ZEN_WARN("invalid file reference in response"); -#endif - return; +#endif } // Send as many consecutive non-file references as possible in one asio operation @@ -837,7 +877,7 @@ public: ++m_IoVecCursor; } - asio::async_write(TcpSocket, std::move(AsioBuffers), asio::transfer_all(), OnCompletion); + asio::async_write(Socket, std::move(AsioBuffers), asio::transfer_all(), OnCompletion); } private: @@ -850,12 +890,13 @@ private: kFailed }; - uint32_t m_RequestNumber = 0; - uint16_t m_ResponseCode = 0; - bool m_IsKeepAlive = true; - State m_State = State::kUninitialized; - HttpContentType m_ContentType = HttpContentType::kBinary; - uint64_t m_ContentLength = 0; + uint32_t m_RequestNumber = 0; + uint16_t m_ResponseCode = 0; + bool m_IsKeepAlive = true; + bool m_AllowZeroCopyFileSend = true; + State m_State = State::kUninitialized; + HttpContentType m_ContentType = HttpContentType::kBinary; + uint64_t m_ContentLength = 0; eastl::fixed_vector<IoBuffer, 8> m_DataBuffers; // This is here to keep the IoBuffer buffers/handles alive ExtendableStringBuilder<160> m_Headers; @@ -882,12 +923,13 @@ private: ////////////////////////////////////////////////////////////////////////// -struct HttpServerConnection : public HttpRequestParserCallbacks, std::enable_shared_from_this<HttpServerConnection> +template<typename SocketType> +struct HttpServerConnectionT : public HttpRequestParserCallbacks, std::enable_shared_from_this<HttpServerConnectionT<SocketType>> { - HttpServerConnection(HttpAsioServerImpl& Server, std::unique_ptr<asio::ip::tcp::socket>&& Socket); - ~HttpServerConnection(); + HttpServerConnectionT(HttpAsioServerImpl& Server, std::unique_ptr<SocketType>&& Socket); + ~HttpServerConnectionT(); - std::shared_ptr<HttpServerConnection> AsSharedPtr() { return shared_from_this(); } + std::shared_ptr<HttpServerConnectionT> AsSharedPtr() { return this->shared_from_this(); } // HttpConnectionBase implementation @@ -938,6 +980,7 @@ private: void OnDataReceived(const asio::error_code& Ec, std::size_t ByteCount); void OnResponseDataSent(const asio::error_code& Ec, std::size_t ByteCount, uint32_t RequestNumber, HttpResponse* ResponseToPop); void CloseConnection(); + void SendInlineResponse(uint32_t RequestNumber, std::string_view StatusLine, std::string_view Headers = {}, std::string_view Body = {}); HttpAsioServerImpl& m_Server; asio::streambuf m_RequestBuffer; @@ -948,12 +991,13 @@ private: RwLock m_ActiveResponsesLock; std::deque<std::unique_ptr<HttpResponse>> m_ActiveResponses; - std::unique_ptr<asio::ip::tcp::socket> m_Socket; + std::unique_ptr<SocketType> m_Socket; }; std::atomic<uint32_t> g_ConnectionIdCounter{0}; -HttpServerConnection::HttpServerConnection(HttpAsioServerImpl& Server, std::unique_ptr<asio::ip::tcp::socket>&& Socket) +template<typename SocketType> +HttpServerConnectionT<SocketType>::HttpServerConnectionT(HttpAsioServerImpl& Server, std::unique_ptr<SocketType>&& Socket) : m_Server(Server) , m_ConnectionId(g_ConnectionIdCounter.fetch_add(1)) , m_Socket(std::move(Socket)) @@ -961,21 +1005,24 @@ HttpServerConnection::HttpServerConnection(HttpAsioServerImpl& Server, std::uniq ZEN_TRACE_VERBOSE("new connection #{}", m_ConnectionId); } -HttpServerConnection::~HttpServerConnection() +template<typename SocketType> +HttpServerConnectionT<SocketType>::~HttpServerConnectionT() { RwLock::ExclusiveLockScope _(m_ActiveResponsesLock); ZEN_TRACE_VERBOSE("destroying connection #{}", m_ConnectionId); } +template<typename SocketType> void -HttpServerConnection::HandleNewRequest() +HttpServerConnectionT<SocketType>::HandleNewRequest() { EnqueueRead(); } +template<typename SocketType> void -HttpServerConnection::TerminateConnection() +HttpServerConnectionT<SocketType>::TerminateConnection() { if (m_RequestState == RequestState::kDone || m_RequestState == RequestState::kTerminated) { @@ -987,12 +1034,13 @@ HttpServerConnection::TerminateConnection() // Terminating, we don't care about any errors when closing socket std::error_code Ec; - m_Socket->shutdown(asio::socket_base::shutdown_both, Ec); - m_Socket->close(Ec); + SocketTraits<SocketType>::ShutdownBoth(*m_Socket, Ec); + SocketTraits<SocketType>::Close(*m_Socket, Ec); } +template<typename SocketType> void -HttpServerConnection::EnqueueRead() +HttpServerConnectionT<SocketType>::EnqueueRead() { ZEN_MEMSCOPE(GetHttpasioTag()); @@ -1013,8 +1061,9 @@ HttpServerConnection::EnqueueRead() [Conn = AsSharedPtr()](const asio::error_code& Ec, std::size_t ByteCount) { Conn->OnDataReceived(Ec, ByteCount); }); } +template<typename SocketType> void -HttpServerConnection::OnDataReceived(const asio::error_code& Ec, [[maybe_unused]] std::size_t ByteCount) +HttpServerConnectionT<SocketType>::OnDataReceived(const asio::error_code& Ec, [[maybe_unused]] std::size_t ByteCount) { ZEN_MEMSCOPE(GetHttpasioTag()); @@ -1038,6 +1087,8 @@ HttpServerConnection::OnDataReceived(const asio::error_code& Ec, [[maybe_unused] } } + m_Server.m_TotalBytesReceived.fetch_add(ByteCount, std::memory_order_relaxed); + ZEN_TRACE_VERBOSE("on data received, connection: {}, request: {}, thread: {}, bytes: {}", m_ConnectionId, m_RequestCounter.load(std::memory_order_relaxed), @@ -1070,11 +1121,12 @@ HttpServerConnection::OnDataReceived(const asio::error_code& Ec, [[maybe_unused] } } +template<typename SocketType> void -HttpServerConnection::OnResponseDataSent(const asio::error_code& Ec, - [[maybe_unused]] std::size_t ByteCount, - [[maybe_unused]] uint32_t RequestNumber, - HttpResponse* ResponseToPop) +HttpServerConnectionT<SocketType>::OnResponseDataSent(const asio::error_code& Ec, + [[maybe_unused]] std::size_t ByteCount, + [[maybe_unused]] uint32_t RequestNumber, + HttpResponse* ResponseToPop) { ZEN_MEMSCOPE(GetHttpasioTag()); @@ -1091,6 +1143,8 @@ HttpServerConnection::OnResponseDataSent(const asio::error_code& Ec, return; } + m_Server.m_TotalBytesSent.fetch_add(ByteCount, std::memory_order_relaxed); + ZEN_TRACE_VERBOSE("on data sent, connection: {}, request: {}, thread: {}, bytes: {}", m_ConnectionId, RequestNumber, @@ -1126,8 +1180,9 @@ HttpServerConnection::OnResponseDataSent(const asio::error_code& Ec, } } +template<typename SocketType> void -HttpServerConnection::CloseConnection() +HttpServerConnectionT<SocketType>::CloseConnection() { ZEN_MEMSCOPE(GetHttpasioTag()); @@ -1139,29 +1194,113 @@ HttpServerConnection::CloseConnection() m_RequestState = RequestState::kDone; std::error_code Ec; - m_Socket->shutdown(asio::socket_base::shutdown_receive, Ec); + SocketTraits<SocketType>::ShutdownReceive(*m_Socket, Ec); if (Ec) { ZEN_WARN("socket shutdown ERROR, reason '{}'", Ec.message()); } - m_Socket->close(Ec); + SocketTraits<SocketType>::Close(*m_Socket, Ec); if (Ec) { ZEN_WARN("socket close ERROR, reason '{}'", Ec.message()); } } +template<typename SocketType> +void +HttpServerConnectionT<SocketType>::SendInlineResponse(uint32_t RequestNumber, + std::string_view StatusLine, + std::string_view Headers, + std::string_view Body) +{ + ExtendableStringBuilder<256> ResponseBuilder; + ResponseBuilder << "HTTP/1.1 " << StatusLine << "\r\n"; + if (!Headers.empty()) + { + ResponseBuilder << Headers; + } + if (!m_RequestData.IsKeepAlive()) + { + ResponseBuilder << "Connection: close\r\n"; + } + ResponseBuilder << "\r\n"; + if (!Body.empty()) + { + ResponseBuilder << Body; + } + auto ResponseView = ResponseBuilder.ToView(); + IoBuffer ResponseData(IoBuffer::Clone, ResponseView.data(), ResponseView.size()); + auto Buffer = asio::buffer(ResponseData.GetData(), ResponseData.GetSize()); + asio::async_write( + *m_Socket, + Buffer, + [Conn = AsSharedPtr(), RequestNumber, Response = std::move(ResponseData)](const asio::error_code& Ec, std::size_t ByteCount) { + Conn->OnResponseDataSent(Ec, ByteCount, RequestNumber, /* ResponseToPop */ nullptr); + }); +} + +template<typename SocketType> void -HttpServerConnection::HandleRequest() +HttpServerConnectionT<SocketType>::HandleRequest() { ZEN_MEMSCOPE(GetHttpasioTag()); + // WebSocket upgrade detection must happen before the keep-alive check below, + // because Upgrade requests have "Connection: Upgrade" which the HTTP parser + // treats as non-keep-alive, causing a premature shutdown of the receive side. + if (m_RequestData.IsWebSocketUpgrade()) + { + if (HttpService* Service = m_Server.RouteRequest(m_RequestData.Url())) + { + IWebSocketHandler* WsHandler = dynamic_cast<IWebSocketHandler*>(Service); + if (WsHandler && !m_RequestData.SecWebSocketKey().empty()) + { + std::string AcceptKey = WsFrameCodec::ComputeAcceptKey(m_RequestData.SecWebSocketKey()); + + auto ResponseStr = std::make_shared<std::string>(); + ResponseStr->reserve(256); + ResponseStr->append( + "HTTP/1.1 101 Switching Protocols\r\n" + "Upgrade: websocket\r\n" + "Connection: Upgrade\r\n" + "Sec-WebSocket-Accept: "); + ResponseStr->append(AcceptKey); + ResponseStr->append("\r\n\r\n"); + + // Send the 101 response on the current socket, then hand the socket off + // to a WsAsioConnectionT for the WebSocket protocol. + asio::async_write( + *m_Socket, + asio::buffer(ResponseStr->data(), ResponseStr->size()), + [Conn = AsSharedPtr(), WsHandler, OwnedResponse = ResponseStr](const asio::error_code& Ec, std::size_t) { + if (Ec) + { + ZEN_WARN("WebSocket 101 send failed: {}", Ec.message()); + return; + } + + Conn->m_Server.m_HttpServer->OnWebSocketConnectionOpened(); + using WsConnType = WsAsioConnectionT<SocketType>; + Ref<WsConnType> WsConn(new WsConnType(std::move(Conn->m_Socket), *WsHandler, Conn->m_Server.m_HttpServer)); + Ref<WebSocketConnection> WsConnRef(WsConn.Get()); + + WsHandler->OnWebSocketOpen(std::move(WsConnRef)); + WsConn->Start(); + }); + + m_RequestState = RequestState::kDone; + return; + } + } + // Service doesn't support WebSocket or missing key — fall through to normal handling + } + if (!m_RequestData.IsKeepAlive()) { m_RequestState = RequestState::kWritingFinal; std::error_code Ec; - m_Socket->shutdown(asio::socket_base::shutdown_receive, Ec); + SocketTraits<SocketType>::ShutdownReceive(*m_Socket, Ec); if (Ec) { @@ -1179,16 +1318,45 @@ HttpServerConnection::HandleRequest() { ZEN_TRACE_CPU("asio::HandleRequest"); - bool IsLocalConnection = m_Socket->local_endpoint().address() == m_Socket->remote_endpoint().address(); + m_Server.m_HttpServer->MarkRequest(); + + bool IsLocalConnection = true; + std::string RemoteAddress; + + if constexpr (std::is_same_v<SocketType, asio::ip::tcp::socket>) + { + auto RemoteEndpoint = m_Socket->remote_endpoint(); + IsLocalConnection = m_Socket->local_endpoint().address() == RemoteEndpoint.address(); + RemoteAddress = RemoteEndpoint.address().to_string(); + } +#if ZEN_USE_OPENSSL + else if constexpr (std::is_same_v<SocketType, SslSocket>) + { + auto RemoteEndpoint = m_Socket->lowest_layer().remote_endpoint(); + IsLocalConnection = m_Socket->lowest_layer().local_endpoint().address() == RemoteEndpoint.address(); + RemoteAddress = RemoteEndpoint.address().to_string(); + } +#endif + else + { + RemoteAddress = "unix"; + } + + HttpAsioServerRequest Request(m_RequestData, + *Service, + m_RequestData.Body(), + RequestNumber, + IsLocalConnection, + std::move(RemoteAddress)); - HttpAsioServerRequest Request(m_RequestData, *Service, m_RequestData.Body(), RequestNumber, IsLocalConnection); + Request.m_AllowZeroCopyFileSend = !SocketTraits<SocketType>::IsSslSocket; ZEN_TRACE_VERBOSE("handle request, connection: {}, request: {}'", m_ConnectionId, RequestNumber); const HttpVerb RequestVerb = Request.RequestVerb(); const std::string_view Uri = Request.RelativeUri(); - if (m_Server.m_RequestLog.ShouldLog(logging::level::Trace)) + if (m_Server.m_RequestLog.ShouldLog(logging::Trace)) { ZEN_LOG_TRACE(m_Server.m_RequestLog, "connection #{} Handling Request: {} {} ({} bytes ({}), accept: {})", @@ -1310,63 +1478,45 @@ HttpServerConnection::HandleRequest() } } - if (m_RequestData.RequestVerb() == HttpVerb::kHead) + // If a default redirect is configured and the request is for the root path, send a 302 + std::string_view DefaultRedirect = m_Server.m_HttpServer->GetDefaultRedirect(); + if (!DefaultRedirect.empty() && (m_RequestData.Url() == "/" || m_RequestData.Url().empty())) { - std::string_view Response = - "HTTP/1.1 404 NOT FOUND\r\n" - "\r\n"sv; - - if (!m_RequestData.IsKeepAlive()) - { - Response = - "HTTP/1.1 404 NOT FOUND\r\n" - "Connection: close\r\n" - "\r\n"sv; - } - - asio::async_write(*m_Socket.get(), - asio::buffer(Response), - [Conn = AsSharedPtr(), RequestNumber](const asio::error_code& Ec, std::size_t ByteCount) { - Conn->OnResponseDataSent(Ec, ByteCount, RequestNumber, /* ResponseToPop */ nullptr); - }); + ExtendableStringBuilder<128> Headers; + Headers << "Location: " << DefaultRedirect << "\r\nContent-Length: 0\r\n"; + SendInlineResponse(RequestNumber, "302 Found"sv, Headers.ToView()); + } + else if (m_RequestData.RequestVerb() == HttpVerb::kHead) + { + SendInlineResponse(RequestNumber, "404 NOT FOUND"sv); } else { - std::string_view Response = - "HTTP/1.1 404 NOT FOUND\r\n" - "Content-Length: 23\r\n" - "Content-Type: text/plain\r\n" - "\r\n" - "No suitable route found"sv; - - if (!m_RequestData.IsKeepAlive()) - { - Response = - "HTTP/1.1 404 NOT FOUND\r\n" - "Content-Length: 23\r\n" - "Content-Type: text/plain\r\n" - "Connection: close\r\n" - "\r\n" - "No suitable route found"sv; - } - - asio::async_write(*m_Socket.get(), - asio::buffer(Response), - [Conn = AsSharedPtr(), RequestNumber](const asio::error_code& Ec, std::size_t ByteCount) { - Conn->OnResponseDataSent(Ec, ByteCount, RequestNumber, /* ResponseToPop */ nullptr); - }); + SendInlineResponse(RequestNumber, + "404 NOT FOUND"sv, + "Content-Length: 23\r\nContent-Type: text/plain\r\n"sv, + "No suitable route found"sv); } } ////////////////////////////////////////////////////////////////////////// +// Base class for TCP acceptors that handles socket setup, port binding +// with probing/retry, and dual-stack (IPv6+IPv4 loopback) support. +// Subclasses only need to implement OnAccept() to handle new connections. -struct HttpAcceptor +struct TcpAcceptorBase { - HttpAcceptor(HttpAsioServerImpl& Server, asio::io_service& IoService, uint16_t BasePort, bool ForceLoopback, bool AllowPortProbing) + TcpAcceptorBase(HttpAsioServerImpl& Server, + asio::io_context& IoService, + uint16_t BasePort, + bool ForceLoopback, + bool AllowPortProbing, + std::string_view Label) : m_Server(Server) , m_IoService(IoService) , m_Acceptor(m_IoService, asio::ip::tcp::v6()) , m_AlternateProtocolAcceptor(m_IoService, asio::ip::tcp::v4()) + , m_Label(Label) { const bool IsUsingIPv6 = IsIPv6Capable(); if (!IsUsingIPv6) @@ -1375,93 +1525,66 @@ struct HttpAcceptor } #if ZEN_PLATFORM_WINDOWS - // Special option for Windows settings as !asio::socket_base::reuse_address is not the same as exclusive access on Windows platforms typedef asio::detail::socket_option::boolean<ASIO_OS_DEF(SOL_SOCKET), SO_EXCLUSIVEADDRUSE> exclusive_address; m_Acceptor.set_option(exclusive_address(true)); m_AlternateProtocolAcceptor.set_option(exclusive_address(true)); #else // ZEN_PLATFORM_WINDOWS - m_Acceptor.set_option(asio::socket_base::reuse_address(false)); - m_AlternateProtocolAcceptor.set_option(asio::socket_base::reuse_address(false)); + // Allow binding to a port in TIME_WAIT so the server can restart immediately + // after a previous instance exits. On Linux this does not allow two processes + // to actively listen on the same port simultaneously. + m_Acceptor.set_option(asio::socket_base::reuse_address(true)); + m_AlternateProtocolAcceptor.set_option(asio::socket_base::reuse_address(true)); #endif // ZEN_PLATFORM_WINDOWS m_Acceptor.set_option(asio::ip::tcp::no_delay(true)); - m_Acceptor.set_option(asio::socket_base::receive_buffer_size(128 * 1024)); - m_Acceptor.set_option(asio::socket_base::send_buffer_size(256 * 1024)); - m_AlternateProtocolAcceptor.set_option(asio::ip::tcp::no_delay(true)); - m_AlternateProtocolAcceptor.set_option(asio::socket_base::receive_buffer_size(128 * 1024)); - m_AlternateProtocolAcceptor.set_option(asio::socket_base::send_buffer_size(256 * 1024)); - - std::string BoundBaseUrl; if (IsUsingIPv6) { - BoundBaseUrl = BindAcceptor<asio::ip::address_v6>(BasePort, ForceLoopback, AllowPortProbing); + BindAcceptor<asio::ip::address_v6>(BasePort, ForceLoopback, AllowPortProbing); } else { - ZEN_INFO("NOTE: ipv6 support is disabled, binding to ipv4 only"); - - BoundBaseUrl = BindAcceptor<asio::ip::address_v4>(BasePort, ForceLoopback, AllowPortProbing); + ZEN_INFO("{}: ipv6 support is disabled, binding to ipv4 only", m_Label); + BindAcceptor<asio::ip::address_v4>(BasePort, ForceLoopback, AllowPortProbing); } + } - if (!IsValid()) - { - return; - } - -#if ZEN_PLATFORM_WINDOWS - // On Windows, loopback connections can take advantage of a faster code path optionally with this flag. - // This must be used by both the client and server side, and is only effective in the absence of - // Windows Filtering Platform (WFP) callouts which can be installed by security software. - // https://docs.microsoft.com/en-us/windows/win32/winsock/sio-loopback-fast-path - SOCKET NativeSocket = m_Acceptor.native_handle(); - int LoopbackOptionValue = 1; - DWORD OptionNumberOfBytesReturned = 0; - WSAIoctl(NativeSocket, - SIO_LOOPBACK_FAST_PATH, - &LoopbackOptionValue, - sizeof(LoopbackOptionValue), - NULL, - 0, - &OptionNumberOfBytesReturned, - 0, - 0); - - if (m_UseAlternateProtocolAcceptor) - { - NativeSocket = m_AlternateProtocolAcceptor.native_handle(); - WSAIoctl(NativeSocket, - SIO_LOOPBACK_FAST_PATH, - &LoopbackOptionValue, - sizeof(LoopbackOptionValue), - NULL, - 0, - &OptionNumberOfBytesReturned, - 0, - 0); - } -#endif - m_Acceptor.listen(); + virtual ~TcpAcceptorBase() + { + m_Acceptor.close(); if (m_UseAlternateProtocolAcceptor) { - m_AlternateProtocolAcceptor.listen(); + m_AlternateProtocolAcceptor.close(); } - - ZEN_INFO("Started asio server at '{}", BoundBaseUrl); } - ~HttpAcceptor() + void Start() { - m_Acceptor.close(); + ZEN_ASSERT(!m_IsStopped); + InitAcceptLoop(m_Acceptor); if (m_UseAlternateProtocolAcceptor) { - m_AlternateProtocolAcceptor.close(); + InitAcceptLoop(m_AlternateProtocolAcceptor); } } + void StopAccepting() { m_IsStopped = true; } + + uint16_t GetPort() const { return m_Acceptor.local_endpoint().port(); } + bool IsLoopbackOnly() const { return m_Acceptor.local_endpoint().address().is_loopback(); } + bool IsValid() const { return m_IsValid; } + +protected: + /// Called for each accepted TCP socket. Subclasses create the appropriate connection type. + virtual void OnAccept(std::unique_ptr<asio::ip::tcp::socket> Socket) = 0; + + HttpAsioServerImpl& m_Server; + asio::io_context& m_IoService; + +private: template<typename AddressType> - std::string BindAcceptor(uint16_t BasePort, bool ForceLoopback, bool AllowPortProbing) + void BindAcceptor(uint16_t BasePort, bool ForceLoopback, bool AllowPortProbing) { uint16_t EffectivePort = BasePort; @@ -1488,7 +1611,7 @@ struct HttpAcceptor if (BindErrorCode == asio::error::access_denied && !BindAddress.is_loopback()) { - ZEN_INFO("Access denied for public port {}, falling back to loopback", BasePort); + ZEN_INFO("{}: Access denied for public port {}, falling back to loopback", m_Label, BasePort); BindAddress = AddressType::loopback(); @@ -1502,7 +1625,7 @@ struct HttpAcceptor if (BindErrorCode == asio::error::address_in_use) { - ZEN_INFO("Desired port {} is in use (bind returned '{}'), retrying", EffectivePort, BindErrorCode.message()); + ZEN_INFO("{}: Desired port {} is in use (bind returned '{}'), retrying", m_Label, EffectivePort, BindErrorCode.message()); Sleep(500); m_Acceptor.bind(asio::ip::tcp::endpoint(BindAddress, EffectivePort), BindErrorCode); } @@ -1518,7 +1641,8 @@ struct HttpAcceptor if (BindErrorCode) { - ZEN_INFO("Unable to bind to preferred port range, falling back to automatic assignment (bind returned '{}')", + ZEN_INFO("{}: Unable to bind to preferred port range, falling back to automatic assignment (bind returned '{}')", + m_Label, BindErrorCode.message()); EffectivePort = 0; @@ -1534,7 +1658,7 @@ struct HttpAcceptor { for (uint32_t Retries = 0; (BindErrorCode == asio::error::address_in_use) && (Retries < 3); Retries++) { - ZEN_INFO("Desired port {} is in use (bind returned '{}'), retrying", EffectivePort, BindErrorCode.message()); + ZEN_INFO("{}: Desired port {} is in use (bind returned '{}'), retrying", m_Label, EffectivePort, BindErrorCode.message()); Sleep(500); m_Acceptor.bind(asio::ip::tcp::endpoint(BindAddress, EffectivePort), BindErrorCode); } @@ -1542,14 +1666,13 @@ struct HttpAcceptor if (BindErrorCode) { - ZEN_WARN("Unable to initialize asio service, (bind returned '{}')", BindErrorCode.message()); - - return 0; + ZEN_WARN("{}: Unable to bind on port {} (bind returned '{}')", m_Label, BasePort, BindErrorCode.message()); + return; } if (EffectivePort != BasePort) { - ZEN_WARN("Desired port {} is in use, remapped to port {}", BasePort, EffectivePort); + ZEN_WARN("{}: Desired port {} is in use, remapped to port {}", m_Label, BasePort, EffectivePort); } if constexpr (std::is_same_v<asio::ip::address_v6, AddressType>) @@ -1559,54 +1682,64 @@ struct HttpAcceptor // IPv6 loopback will only respond on the IPv6 loopback address. Not everyone does // IPv6 though so we also bind to IPv4 loopback (localhost/127.0.0.1) - m_AlternateProtocolAcceptor.bind(asio::ip::tcp::endpoint(asio::ip::address_v4::loopback(), EffectivePort), BindErrorCode); + asio::error_code AltEc; + m_AlternateProtocolAcceptor.bind(asio::ip::tcp::endpoint(asio::ip::address_v4::loopback(), EffectivePort), AltEc); - if (BindErrorCode) + if (AltEc) { - ZEN_WARN("Failed to register secondary IPv4 local-only handler 'http://{}:{}/'", "localhost", EffectivePort); + ZEN_WARN("{}: Failed to register secondary IPv4 local-only handler on port {}", m_Label, EffectivePort); } else { m_UseAlternateProtocolAcceptor = true; - ZEN_INFO("Registered local-only handler 'http://{}:{}/' - this is not accessible from remote hosts", - "localhost", - EffectivePort); } } } - m_IsValid = true; +#if ZEN_PLATFORM_WINDOWS + // On Windows, loopback connections can take advantage of a faster code path optionally with this flag. + // This must be used by both the client and server side, and is only effective in the absence of + // Windows Filtering Platform (WFP) callouts which can be installed by security software. + // https://docs.microsoft.com/en-us/windows/win32/winsock/sio-loopback-fast-path + SOCKET NativeSocket = m_Acceptor.native_handle(); + int LoopbackOptionValue = 1; + DWORD OptionNumberOfBytesReturned = 0; + WSAIoctl(NativeSocket, + SIO_LOOPBACK_FAST_PATH, + &LoopbackOptionValue, + sizeof(LoopbackOptionValue), + NULL, + 0, + &OptionNumberOfBytesReturned, + 0, + 0); - if constexpr (std::is_same_v<asio::ip::address_v6, AddressType>) - { - return fmt::format("http://{}:{}'", BindAddress.is_loopback() ? "[::1]" : "*", EffectivePort); - } - else + if (m_UseAlternateProtocolAcceptor) { - return fmt::format("http://{}:{}'", BindAddress.is_loopback() ? "127.0.0.1" : "*", EffectivePort); + NativeSocket = m_AlternateProtocolAcceptor.native_handle(); + WSAIoctl(NativeSocket, + SIO_LOOPBACK_FAST_PATH, + &LoopbackOptionValue, + sizeof(LoopbackOptionValue), + NULL, + 0, + &OptionNumberOfBytesReturned, + 0, + 0); } - } - - void Start() - { - ZEN_MEMSCOPE(GetHttpasioTag()); +#endif - ZEN_ASSERT(!m_IsStopped); - InitAcceptInternal(m_Acceptor); + m_Acceptor.listen(); if (m_UseAlternateProtocolAcceptor) { - InitAcceptInternal(m_AlternateProtocolAcceptor); + m_AlternateProtocolAcceptor.listen(); } - } - void StopAccepting() { m_IsStopped = true; } - - int GetAcceptPort() { return m_Acceptor.local_endpoint().port(); } - - bool IsValid() const { return m_IsValid; } + m_IsValid = true; + ZEN_INFO("{}: Listening on port {}", m_Label, m_Acceptor.local_endpoint().port()); + } -private: - void InitAcceptInternal(asio::ip::tcp::acceptor& Acceptor) + void InitAcceptLoop(asio::ip::tcp::acceptor& Acceptor) { auto SocketPtr = std::make_unique<asio::ip::tcp::socket>(m_IoService); asio::ip::tcp::socket& SocketRef = *SocketPtr.get(); @@ -1614,29 +1747,19 @@ private: Acceptor.async_accept(SocketRef, [this, &Acceptor, Socket = std::move(SocketPtr)](const asio::error_code& Ec) mutable { if (Ec) { - ZEN_WARN("asio async_accept, connection failed to '{}:{}' reason '{}'", - Acceptor.local_endpoint().address().to_string(), - Acceptor.local_endpoint().port(), - Ec.message()); + if (!m_IsStopped.load()) + { + ZEN_WARN("{}: async_accept failed: '{}'", m_Label, Ec.message()); + } } else { - // New connection established, pass socket ownership into connection object - // and initiate request handling loop. The connection lifetime is - // managed by the async read/write loop by passing the shared - // reference to the callbacks. - - Socket->set_option(asio::ip::tcp::no_delay(true)); - Socket->set_option(asio::socket_base::receive_buffer_size(128 * 1024)); - Socket->set_option(asio::socket_base::send_buffer_size(256 * 1024)); - - auto Conn = std::make_shared<HttpServerConnection>(m_Server, std::move(Socket)); - Conn->HandleNewRequest(); + OnAccept(std::move(Socket)); } if (!m_IsStopped.load()) { - InitAcceptInternal(Acceptor); + InitAcceptLoop(Acceptor); } else { @@ -1644,33 +1767,218 @@ private: Acceptor.close(CloseEc); if (CloseEc) { - ZEN_WARN("acceptor close ERROR, reason '{}'", CloseEc.message()); + ZEN_WARN("{}: acceptor close error: '{}'", m_Label, CloseEc.message()); } } }); } - HttpAsioServerImpl& m_Server; - asio::io_service& m_IoService; asio::ip::tcp::acceptor m_Acceptor; asio::ip::tcp::acceptor m_AlternateProtocolAcceptor; bool m_UseAlternateProtocolAcceptor{false}; bool m_IsValid{false}; std::atomic<bool> m_IsStopped{false}; + std::string_view m_Label; +}; + +////////////////////////////////////////////////////////////////////////// + +struct HttpAcceptor final : TcpAcceptorBase +{ + HttpAcceptor(HttpAsioServerImpl& Server, asio::io_context& IoService, uint16_t BasePort, bool ForceLoopback, bool AllowPortProbing) + : TcpAcceptorBase(Server, IoService, BasePort, ForceLoopback, AllowPortProbing, "HTTP") + { + } + + int GetAcceptPort() const { return GetPort(); } + +protected: + void OnAccept(std::unique_ptr<asio::ip::tcp::socket> Socket) override + { + Socket->set_option(asio::ip::tcp::no_delay(true)); + Socket->set_option(asio::socket_base::receive_buffer_size(128 * 1024)); + Socket->set_option(asio::socket_base::send_buffer_size(256 * 1024)); + + auto Conn = std::make_shared<HttpServerConnection>(m_Server, std::move(Socket)); + Conn->HandleNewRequest(); + } +}; + +#if defined(ASIO_HAS_LOCAL_SOCKETS) + +////////////////////////////////////////////////////////////////////////// + +struct UnixAcceptor +{ + UnixAcceptor(HttpAsioServerImpl& Server, asio::io_context& IoService, const std::string& SocketPath) + : m_Server(Server) + , m_IoService(IoService) + , m_Acceptor(m_IoService) + , m_SocketPath(SocketPath) + { + // Remove any stale socket file from a previous run + std::filesystem::remove(m_SocketPath); + + asio::local::stream_protocol::endpoint Endpoint(m_SocketPath); + + asio::error_code Ec; + m_Acceptor.open(Endpoint.protocol(), Ec); + if (Ec) + { + ZEN_WARN("failed to open unix domain socket: {}", Ec.message()); + return; + } + + m_Acceptor.bind(Endpoint, Ec); + if (Ec) + { + ZEN_WARN("failed to bind unix domain socket at '{}': {}", m_SocketPath, Ec.message()); + return; + } + + m_Acceptor.listen(asio::socket_base::max_listen_connections, Ec); + if (Ec) + { + ZEN_WARN("failed to listen on unix domain socket at '{}': {}", m_SocketPath, Ec.message()); + return; + } + + m_IsValid = true; + ZEN_INFO("Started unix domain socket listener at '{}'", m_SocketPath); + } + + ~UnixAcceptor() + { + asio::error_code Ec; + m_Acceptor.close(Ec); + std::filesystem::remove(m_SocketPath); + } + + void Start() + { + ZEN_ASSERT(!m_IsStopped); + InitAccept(); + } + + void StopAccepting() { m_IsStopped = true; } + + bool IsValid() const { return m_IsValid; } + +private: + void InitAccept() + { + auto SocketPtr = std::make_unique<asio::local::stream_protocol::socket>(m_IoService); + asio::local::stream_protocol::socket& SocketRef = *SocketPtr.get(); + + m_Acceptor.async_accept(SocketRef, [this, Socket = std::move(SocketPtr)](const asio::error_code& Ec) mutable { + if (Ec) + { + if (!m_IsStopped.load()) + { + ZEN_WARN("unix domain socket async_accept failed: '{}'", Ec.message()); + } + } + else + { + auto Conn = std::make_shared<UnixServerConnection>(m_Server, std::move(Socket)); + Conn->HandleNewRequest(); + } + + if (!m_IsStopped.load()) + { + InitAccept(); + } + else + { + std::error_code CloseEc; + m_Acceptor.close(CloseEc); + } + }); + } + + HttpAsioServerImpl& m_Server; + asio::io_context& m_IoService; + asio::local::stream_protocol::acceptor m_Acceptor; + std::string m_SocketPath; + bool m_IsValid{false}; + std::atomic<bool> m_IsStopped{false}; +}; + +#endif // ASIO_HAS_LOCAL_SOCKETS + +#if ZEN_USE_OPENSSL + +////////////////////////////////////////////////////////////////////////// + +struct HttpsAcceptor final : TcpAcceptorBase +{ + HttpsAcceptor(HttpAsioServerImpl& Server, + asio::io_context& IoService, + asio::ssl::context& SslContext, + uint16_t Port, + bool ForceLoopback, + bool AllowPortProbing) + : TcpAcceptorBase(Server, IoService, Port, ForceLoopback, AllowPortProbing, "HTTPS") + , m_SslContext(SslContext) + { + } + +protected: + void OnAccept(std::unique_ptr<asio::ip::tcp::socket> Socket) override + { + Socket->set_option(asio::ip::tcp::no_delay(true)); + Socket->set_option(asio::socket_base::receive_buffer_size(128 * 1024)); + Socket->set_option(asio::socket_base::send_buffer_size(256 * 1024)); + + // Wrap accepted TCP socket in an SSL stream and perform the handshake + auto SslSocketPtr = std::make_unique<SslSocket>(std::move(*Socket), m_SslContext); + + SslSocket& SslRef = *SslSocketPtr; + SslRef.async_handshake(asio::ssl::stream_base::server, + [this, SslSocket = std::move(SslSocketPtr)](const asio::error_code& HandshakeEc) mutable { + if (HandshakeEc) + { + ZEN_WARN("SSL handshake failed: '{}'", HandshakeEc.message()); + std::error_code Ec; + SslSocket->lowest_layer().close(Ec); + return; + } + + auto Conn = std::make_shared<HttpsSslServerConnection>(m_Server, std::move(SslSocket)); + Conn->HandleNewRequest(); + }); + } + +private: + asio::ssl::context& m_SslContext; }; +#endif // ZEN_USE_OPENSSL + +int +HttpAsioServerImpl::GetEffectiveHttpsPort() const +{ +#if ZEN_USE_OPENSSL + return m_HttpsAcceptor ? m_HttpsAcceptor->GetPort() : 0; +#else + return 0; +#endif +} + ////////////////////////////////////////////////////////////////////////// HttpAsioServerRequest::HttpAsioServerRequest(HttpRequestParser& Request, HttpService& Service, IoBuffer PayloadBuffer, uint32_t RequestNumber, - bool IsLocalMachineRequest) + bool IsLocalMachineRequest, + std::string RemoteAddress) : HttpServerRequest(Service) , m_Request(Request) , m_RequestNumber(RequestNumber) , m_PayloadBuffer(std::move(PayloadBuffer)) , m_IsLocalMachineRequest(IsLocalMachineRequest) +, m_RemoteAddress(std::move(RemoteAddress)) { const int PrefixLength = Service.UriPrefixLength(); @@ -1749,6 +2057,12 @@ HttpAsioServerRequest::IsLocalMachineRequest() const } std::string_view +HttpAsioServerRequest::GetRemoteAddress() const +{ + return m_RemoteAddress; +} + +std::string_view HttpAsioServerRequest::GetAuthorizationHeader() const { return m_Request.AuthorizationHeader(); @@ -1768,6 +2082,7 @@ HttpAsioServerRequest::WriteResponse(HttpResponseCode ResponseCode) ZEN_ASSERT(!m_Response); m_Response.reset(new HttpResponse(HttpContentType::kBinary, m_RequestNumber)); + m_Response->SetAllowZeroCopyFileSend(m_AllowZeroCopyFileSend); std::array<IoBuffer, 0> Empty; m_Response->InitializeForPayload((uint16_t)ResponseCode, Empty); @@ -1781,6 +2096,7 @@ HttpAsioServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentT ZEN_ASSERT(!m_Response); m_Response.reset(new HttpResponse(ContentType, m_RequestNumber)); + m_Response->SetAllowZeroCopyFileSend(m_AllowZeroCopyFileSend); m_Response->InitializeForPayload((uint16_t)ResponseCode, Blobs); } @@ -1791,6 +2107,7 @@ HttpAsioServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentT ZEN_ASSERT(!m_Response); m_Response.reset(new HttpResponse(ContentType, m_RequestNumber)); + m_Response->SetAllowZeroCopyFileSend(m_AllowZeroCopyFileSend); IoBuffer MessageBuffer(IoBuffer::Wrap, ResponseString.data(), ResponseString.size()); std::array<IoBuffer, 1> SingleBufferList({MessageBuffer}); @@ -1840,15 +2157,63 @@ HttpAsioServerImpl::Start(uint16_t Port, const AsioConfig& Config) ZEN_INFO("starting asio http with {} service threads", Config.ThreadCount); - m_Acceptor.reset( - new asio_http::HttpAcceptor(*this, m_IoService, Port, Config.ForceLoopback, /*AllowPortProbing */ !Config.IsDedicatedServer)); + if (!Config.NoNetwork) + { + m_Acceptor.reset( + new asio_http::HttpAcceptor(*this, m_IoService, Port, Config.ForceLoopback, /*AllowPortProbing */ !Config.IsDedicatedServer)); + + if (!m_Acceptor->IsValid()) + { + return 0; + } - if (!m_Acceptor->IsValid()) + m_Acceptor->Start(); + } + +#if defined(ASIO_HAS_LOCAL_SOCKETS) + if (!Config.UnixSocketPath.empty()) { - return 0; + m_UnixAcceptor.reset(new asio_http::UnixAcceptor(*this, m_IoService, Config.UnixSocketPath)); + + if (m_UnixAcceptor->IsValid()) + { + m_UnixAcceptor->Start(); + } + else + { + m_UnixAcceptor.reset(); + } } +#endif + +#if ZEN_USE_OPENSSL + if (!Config.NoNetwork && !Config.CertFile.empty() && !Config.KeyFile.empty()) + { + m_SslContext = std::make_unique<asio::ssl::context>(asio::ssl::context::tlsv12_server); + m_SslContext->set_options(asio::ssl::context::default_workarounds | asio::ssl::context::no_sslv2 | asio::ssl::context::no_sslv3 | + asio::ssl::context::no_tlsv1 | asio::ssl::context::no_tlsv1_1); + m_SslContext->use_certificate_chain_file(Config.CertFile); + m_SslContext->use_private_key_file(Config.KeyFile, asio::ssl::context::pem); - m_Acceptor->Start(); + ZEN_INFO("SSL context initialized (cert: '{}', key: '{}')", Config.CertFile, Config.KeyFile); + + m_HttpsAcceptor.reset(new asio_http::HttpsAcceptor(*this, + m_IoService, + *m_SslContext, + gsl::narrow<uint16_t>(Config.HttpsPort), + Config.ForceLoopback, + /*AllowPortProbing*/ !Config.IsDedicatedServer)); + + if (m_HttpsAcceptor->IsValid()) + { + m_HttpsAcceptor->Start(); + } + else + { + m_HttpsAcceptor.reset(); + } + } +#endif // This should consist of a set of minimum threads and grow on demand to // meet concurrency needs? Right now we end up allocating a large number @@ -1881,12 +2246,18 @@ HttpAsioServerImpl::Start(uint16_t Port, const AsioConfig& Config) }); } - ZEN_INFO("asio http started in {} mode, using {} threads on port {}", - Config.IsDedicatedServer ? "DEDICATED" : "NORMAL", - Config.ThreadCount, - m_Acceptor->GetAcceptPort()); + if (m_Acceptor) + { + ZEN_INFO("asio http started in {} mode, using {} threads on port {}", + Config.IsDedicatedServer ? "DEDICATED" : "NORMAL", + Config.ThreadCount, + m_Acceptor->GetAcceptPort()); - return m_Acceptor->GetAcceptPort(); + return m_Acceptor->GetAcceptPort(); + } + + ZEN_INFO("asio http started in no-network mode, using {} threads (unix socket only)", Config.ThreadCount); + return Port; } void @@ -1898,6 +2269,18 @@ HttpAsioServerImpl::Stop() { m_Acceptor->StopAccepting(); } +#if defined(ASIO_HAS_LOCAL_SOCKETS) + if (m_UnixAcceptor) + { + m_UnixAcceptor->StopAccepting(); + } +#endif +#if ZEN_USE_OPENSSL + if (m_HttpsAcceptor) + { + m_HttpsAcceptor->StopAccepting(); + } +#endif m_IoService.stop(); for (auto& Thread : m_ThreadPool) { @@ -1907,7 +2290,23 @@ HttpAsioServerImpl::Stop() } } m_ThreadPool.clear(); + + // Drain remaining handlers (e.g. cancellation callbacks from active WebSocket + // connections) so that their captured Ref<> pointers are released while the + // io_context and its epoll reactor are still alive. Without this, sockets + // held by external code (e.g. IWebSocketHandler connection lists) can outlive + // the reactor and crash during deregistration. + m_IoService.restart(); + m_IoService.poll(); + m_Acceptor.reset(); +#if defined(ASIO_HAS_LOCAL_SOCKETS) + m_UnixAcceptor.reset(); +#endif +#if ZEN_USE_OPENSSL + m_HttpsAcceptor.reset(); + m_SslContext.reset(); +#endif } void @@ -1975,6 +2374,12 @@ HttpAsioServerImpl::FilterRequest(HttpServerRequest& Request) return RequestFilter->FilterRequest(Request); } +bool +HttpAsioServerImpl::IsLoopbackOnly() const +{ + return m_Acceptor && m_Acceptor->IsLoopbackOnly(); +} + } // namespace zen::asio_http ////////////////////////////////////////////////////////////////////////// @@ -1987,12 +2392,15 @@ public: HttpAsioServer(const AsioConfig& Config); ~HttpAsioServer(); - virtual void OnRegisterService(HttpService& Service) override; - virtual int OnInitialize(int BasePort, std::filesystem::path DataDir) override; - virtual void OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) override; - virtual void OnRun(bool IsInteractiveSession) override; - virtual void OnRequestExit() override; - virtual void OnClose() override; + virtual void OnRegisterService(HttpService& Service) override; + virtual int OnInitialize(int BasePort, std::filesystem::path DataDir) override; + virtual void OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) override; + virtual void OnRun(bool IsInteractiveSession) override; + virtual void OnRequestExit() override; + virtual void OnClose() override; + virtual std::string OnGetExternalHost() const override; + virtual uint64_t GetTotalBytesReceived() const override; + virtual uint64_t GetTotalBytesSent() const override; private: Event m_ShutdownEvent; @@ -2006,6 +2414,7 @@ HttpAsioServer::HttpAsioServer(const AsioConfig& Config) : m_InitialConfig(Config) , m_Impl(std::make_unique<asio_http::HttpAsioServerImpl>()) { + m_Impl->m_HttpServer = this; ZEN_DEBUG("Request object size: {} ({:#x})", sizeof(HttpRequestParser), sizeof(HttpRequestParser)); } @@ -2064,9 +2473,51 @@ HttpAsioServer::OnInitialize(int BasePort, std::filesystem::path DataDir) m_BasePort = m_Impl->Start(gsl::narrow<uint16_t>(BasePort), Config); +#if ZEN_USE_OPENSSL + if (int EffectiveHttpsPort = m_Impl->GetEffectiveHttpsPort(); EffectiveHttpsPort > 0) + { + SetEffectiveHttpsPort(EffectiveHttpsPort); + } +#endif + return m_BasePort; } +std::string +HttpAsioServer::OnGetExternalHost() const +{ + if (m_Impl->IsLoopbackOnly()) + { + return "127.0.0.1"; + } + + // Use the UDP connect trick: connecting a UDP socket to an external address + // causes the OS to select the appropriate local interface without sending any data. + try + { + asio::io_context IoService; + asio::ip::udp::socket Sock(IoService, asio::ip::udp::v4()); + Sock.connect(asio::ip::udp::endpoint(asio::ip::make_address("8.8.8.8"), 80)); + return Sock.local_endpoint().address().to_string(); + } + catch (const std::exception&) + { + return GetMachineName(); + } +} + +uint64_t +HttpAsioServer::GetTotalBytesReceived() const +{ + return m_Impl->m_TotalBytesReceived.load(std::memory_order_relaxed); +} + +uint64_t +HttpAsioServer::GetTotalBytesSent() const +{ + return m_Impl->m_TotalBytesSent.load(std::memory_order_relaxed); +} + void HttpAsioServer::OnRun(bool IsInteractive) { diff --git a/src/zenhttp/servers/httpasio.h b/src/zenhttp/servers/httpasio.h index 3ec1141a7..21d10170e 100644 --- a/src/zenhttp/servers/httpasio.h +++ b/src/zenhttp/servers/httpasio.h @@ -11,6 +11,13 @@ struct AsioConfig unsigned int ThreadCount = 0; bool ForceLoopback = false; bool IsDedicatedServer = false; + bool NoNetwork = false; + std::string UnixSocketPath; +#if ZEN_USE_OPENSSL + int HttpsPort = 0; // 0 = auto-assign; set CertFile/KeyFile to enable HTTPS + std::string CertFile; // PEM certificate chain file (empty = HTTPS disabled) + std::string KeyFile; // PEM private key file +#endif }; Ref<HttpServer> CreateHttpAsioServer(const AsioConfig& Config); diff --git a/src/zenhttp/servers/httpmulti.cpp b/src/zenhttp/servers/httpmulti.cpp index 310ac9dc0..584e06cbf 100644 --- a/src/zenhttp/servers/httpmulti.cpp +++ b/src/zenhttp/servers/httpmulti.cpp @@ -117,6 +117,16 @@ HttpMultiServer::OnClose() } } +std::string +HttpMultiServer::OnGetExternalHost() const +{ + if (!m_Servers.empty()) + { + return std::string(m_Servers.front()->GetExternalHost()); + } + return HttpServer::OnGetExternalHost(); +} + void HttpMultiServer::AddServer(Ref<HttpServer> Server) { diff --git a/src/zenhttp/servers/httpmulti.h b/src/zenhttp/servers/httpmulti.h index 1897587a9..97699828a 100644 --- a/src/zenhttp/servers/httpmulti.h +++ b/src/zenhttp/servers/httpmulti.h @@ -15,12 +15,13 @@ public: HttpMultiServer(); ~HttpMultiServer(); - virtual void OnRegisterService(HttpService& Service) override; - virtual void OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) override; - virtual int OnInitialize(int BasePort, std::filesystem::path DataDir) override; - virtual void OnRun(bool IsInteractiveSession) override; - virtual void OnRequestExit() override; - virtual void OnClose() override; + virtual void OnRegisterService(HttpService& Service) override; + virtual void OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) override; + virtual int OnInitialize(int BasePort, std::filesystem::path DataDir) override; + virtual void OnRun(bool IsInteractiveSession) override; + virtual void OnRequestExit() override; + virtual void OnClose() override; + virtual std::string OnGetExternalHost() const override; void AddServer(Ref<HttpServer> Server); diff --git a/src/zenhttp/servers/httpparser.cpp b/src/zenhttp/servers/httpparser.cpp index f0485aa25..918b55dc6 100644 --- a/src/zenhttp/servers/httpparser.cpp +++ b/src/zenhttp/servers/httpparser.cpp @@ -12,14 +12,17 @@ namespace zen { using namespace std::literals; -static constinit uint32_t HashContentLength = HashStringAsLowerDjb2("Content-Length"sv); -static constinit uint32_t HashContentType = HashStringAsLowerDjb2("Content-Type"sv); -static constinit uint32_t HashAccept = HashStringAsLowerDjb2("Accept"sv); -static constinit uint32_t HashExpect = HashStringAsLowerDjb2("Expect"sv); -static constinit uint32_t HashSession = HashStringAsLowerDjb2("UE-Session"sv); -static constinit uint32_t HashRequest = HashStringAsLowerDjb2("UE-Request"sv); -static constinit uint32_t HashRange = HashStringAsLowerDjb2("Range"sv); -static constinit uint32_t HashAuthorization = HashStringAsLowerDjb2("Authorization"sv); +static constexpr uint32_t HashContentLength = HashStringAsLowerDjb2("Content-Length"sv); +static constexpr uint32_t HashContentType = HashStringAsLowerDjb2("Content-Type"sv); +static constexpr uint32_t HashAccept = HashStringAsLowerDjb2("Accept"sv); +static constexpr uint32_t HashExpect = HashStringAsLowerDjb2("Expect"sv); +static constexpr uint32_t HashSession = HashStringAsLowerDjb2("UE-Session"sv); +static constexpr uint32_t HashRequest = HashStringAsLowerDjb2("UE-Request"sv); +static constexpr uint32_t HashRange = HashStringAsLowerDjb2("Range"sv); +static constexpr uint32_t HashAuthorization = HashStringAsLowerDjb2("Authorization"sv); +static constexpr uint32_t HashUpgrade = HashStringAsLowerDjb2("Upgrade"sv); +static constexpr uint32_t HashSecWebSocketKey = HashStringAsLowerDjb2("Sec-WebSocket-Key"sv); +static constexpr uint32_t HashSecWebSocketVersion = HashStringAsLowerDjb2("Sec-WebSocket-Version"sv); ////////////////////////////////////////////////////////////////////////// // @@ -143,45 +146,62 @@ HttpRequestParser::ParseCurrentHeader() const uint32_t HeaderHash = HashStringAsLowerDjb2(HeaderName); const int8_t CurrentHeaderIndex = int8_t(CurrentHeaderCount - 1); - if (HeaderHash == HashContentLength) + switch (HeaderHash) { - m_ContentLengthHeaderIndex = CurrentHeaderIndex; - } - else if (HeaderHash == HashAccept) - { - m_AcceptHeaderIndex = CurrentHeaderIndex; - } - else if (HeaderHash == HashContentType) - { - m_ContentTypeHeaderIndex = CurrentHeaderIndex; - } - else if (HeaderHash == HashAuthorization) - { - m_AuthorizationHeaderIndex = CurrentHeaderIndex; - } - else if (HeaderHash == HashSession) - { - m_SessionId = Oid::TryFromHexString(HeaderValue); - } - else if (HeaderHash == HashRequest) - { - std::from_chars(HeaderValue.data(), HeaderValue.data() + HeaderValue.size(), m_RequestId); - } - else if (HeaderHash == HashExpect) - { - if (HeaderValue == "100-continue"sv) - { - // We don't currently do anything with this - m_Expect100Continue = true; - } - else - { - ZEN_INFO("Unexpected expect - Expect: {}", HeaderValue); - } - } - else if (HeaderHash == HashRange) - { - m_RangeHeaderIndex = CurrentHeaderIndex; + case HashContentLength: + m_ContentLengthHeaderIndex = CurrentHeaderIndex; + break; + + case HashAccept: + m_AcceptHeaderIndex = CurrentHeaderIndex; + break; + + case HashContentType: + m_ContentTypeHeaderIndex = CurrentHeaderIndex; + break; + + case HashAuthorization: + m_AuthorizationHeaderIndex = CurrentHeaderIndex; + break; + + case HashSession: + m_SessionId = Oid::TryFromHexString(HeaderValue); + break; + + case HashRequest: + std::from_chars(HeaderValue.data(), HeaderValue.data() + HeaderValue.size(), m_RequestId); + break; + + case HashExpect: + if (HeaderValue == "100-continue"sv) + { + // We don't currently do anything with this + m_Expect100Continue = true; + } + else + { + ZEN_INFO("Unexpected expect - Expect: {}", HeaderValue); + } + break; + + case HashRange: + m_RangeHeaderIndex = CurrentHeaderIndex; + break; + + case HashUpgrade: + m_UpgradeHeaderIndex = CurrentHeaderIndex; + break; + + case HashSecWebSocketKey: + m_SecWebSocketKeyHeaderIndex = CurrentHeaderIndex; + break; + + case HashSecWebSocketVersion: + m_SecWebSocketVersionHeaderIndex = CurrentHeaderIndex; + break; + + default: + break; } } @@ -225,13 +245,6 @@ NormalizeUrlPath(std::string_view InUrl, std::string& NormalizedUrl) NormalizedUrl.reserve(UrlLength); NormalizedUrl.append(Url, UrlIndex); } - - // NOTE: this check is redundant given the enclosing if, - // need to verify the intent of this code - if (!LastCharWasSeparator) - { - NormalizedUrl.push_back('/'); - } } else if (!NormalizedUrl.empty()) { @@ -361,14 +374,18 @@ HttpRequestParser::ResetState() m_HeaderEntries.clear(); - m_ContentLengthHeaderIndex = -1; - m_AcceptHeaderIndex = -1; - m_ContentTypeHeaderIndex = -1; - m_RangeHeaderIndex = -1; - m_AuthorizationHeaderIndex = -1; - m_Expect100Continue = false; - m_BodyBuffer = {}; - m_BodyPosition = 0; + m_ContentLengthHeaderIndex = -1; + m_AcceptHeaderIndex = -1; + m_ContentTypeHeaderIndex = -1; + m_RangeHeaderIndex = -1; + m_AuthorizationHeaderIndex = -1; + m_UpgradeHeaderIndex = -1; + m_SecWebSocketKeyHeaderIndex = -1; + m_SecWebSocketVersionHeaderIndex = -1; + m_RequestVerb = HttpVerb::kGet; + m_Expect100Continue = false; + m_BodyBuffer = {}; + m_BodyPosition = 0; m_HeaderData.clear(); m_NormalizedUrl.clear(); @@ -425,4 +442,21 @@ HttpRequestParser::OnMessageComplete() } } +bool +HttpRequestParser::IsWebSocketUpgrade() const +{ + std::string_view Upgrade = GetHeaderValue(m_UpgradeHeaderIndex); + if (Upgrade.empty()) + { + return false; + } + + // Case-insensitive check for "websocket" + if (Upgrade.size() != 9) + { + return false; + } + return StrCaseCompare(Upgrade.data(), "websocket", 9) == 0; +} + } // namespace zen diff --git a/src/zenhttp/servers/httpparser.h b/src/zenhttp/servers/httpparser.h index ff56ca970..23ad9d8fb 100644 --- a/src/zenhttp/servers/httpparser.h +++ b/src/zenhttp/servers/httpparser.h @@ -48,6 +48,10 @@ struct HttpRequestParser std::string_view AuthorizationHeader() const { return GetHeaderValue(m_AuthorizationHeaderIndex); } + std::string_view UpgradeHeader() const { return GetHeaderValue(m_UpgradeHeaderIndex); } + std::string_view SecWebSocketKey() const { return GetHeaderValue(m_SecWebSocketKeyHeaderIndex); } + bool IsWebSocketUpgrade() const; + private: struct HeaderRange { @@ -86,7 +90,10 @@ private: int8_t m_ContentTypeHeaderIndex; int8_t m_RangeHeaderIndex; int8_t m_AuthorizationHeaderIndex; - HttpVerb m_RequestVerb; + int8_t m_UpgradeHeaderIndex; + int8_t m_SecWebSocketKeyHeaderIndex; + int8_t m_SecWebSocketVersionHeaderIndex; + HttpVerb m_RequestVerb = HttpVerb::kGet; std::atomic_bool m_KeepAlive{false}; bool m_Expect100Continue = false; int m_RequestId = -1; diff --git a/src/zenhttp/servers/httpplugin.cpp b/src/zenhttp/servers/httpplugin.cpp index 8564826d6..a1bb719c8 100644 --- a/src/zenhttp/servers/httpplugin.cpp +++ b/src/zenhttp/servers/httpplugin.cpp @@ -123,7 +123,7 @@ struct HttpPluginServerImpl : public HttpPluginServer, TransportServer bool m_IsRequestLoggingEnabled = false; LoggerRef m_RequestLog; std::atomic_uint32_t m_ConnectionIdCounter{0}; - int m_BasePort; + int m_BasePort = 0; HttpServerTracer m_RequestTracer; @@ -147,7 +147,7 @@ public: HttpPluginServerRequest& operator=(const HttpPluginServerRequest&) = delete; // As this is plugin transport connection used for specialized connections we assume it is not a machine local connection - virtual bool IsLocalMachineRequest() const /* override*/ { return false; } + bool IsLocalMachineRequest() const override { return false; } virtual std::string_view GetAuthorizationHeader() const override; virtual Oid ParseSessionId() const override; virtual uint32_t ParseRequestId() const override; @@ -294,7 +294,7 @@ HttpPluginConnectionHandler::Initialize(TransportConnection* Transport, HttpPlug ConnectionName = "anonymous"; } - ZEN_LOG_TRACE(m_Server->m_RequestLog, "NEW connection #{} ('')", m_ConnectionId, ConnectionName); + ZEN_LOG_TRACE(m_Server->m_RequestLog, "NEW connection #{} ('{}')", m_ConnectionId, ConnectionName); } uint32_t @@ -378,12 +378,14 @@ HttpPluginConnectionHandler::HandleRequest() { ZEN_TRACE_CPU("http_plugin::HandleRequest"); + m_Server->MarkRequest(); + HttpPluginServerRequest Request(m_RequestParser, *Service, m_RequestParser.Body()); const HttpVerb RequestVerb = Request.RequestVerb(); const std::string_view Uri = Request.RelativeUri(); - if (m_Server->m_RequestLog.ShouldLog(logging::level::Trace)) + if (m_Server->m_RequestLog.ShouldLog(logging::Trace)) { ZEN_LOG_TRACE(m_Server->m_RequestLog, "connection #{} Handling Request: {} {} ({} bytes ({}), accept: {})", @@ -480,7 +482,7 @@ HttpPluginConnectionHandler::HandleRequest() const std::vector<IoBuffer>& ResponseBuffers = Response->ResponseBuffers(); - if (m_Server->m_RequestLog.ShouldLog(logging::level::Trace)) + if (m_Server->m_RequestLog.ShouldLog(logging::Trace)) { m_Server->m_RequestTracer.WriteDebugPayload(fmt::format("response_{}_{}.bin", m_ConnectionId, RequestNumber), ResponseBuffers); diff --git a/src/zenhttp/servers/httpsys.cpp b/src/zenhttp/servers/httpsys.cpp index 4406d0619..eaf080960 100644 --- a/src/zenhttp/servers/httpsys.cpp +++ b/src/zenhttp/servers/httpsys.cpp @@ -12,6 +12,7 @@ #include <zencore/memory/llm.h> #include <zencore/scopeguard.h> #include <zencore/string.h> +#include <zencore/system.h> #include <zencore/timer.h> #include <zencore/trace.h> #include <zenhttp/packageformat.h> @@ -25,7 +26,9 @@ # include <zencore/workthreadpool.h> # include "iothreadpool.h" +# include <atomic> # include <http.h> +# include <asio.hpp> // for resolving addresses for GetExternalHost namespace zen { @@ -85,6 +88,8 @@ class HttpSysServerRequest; class HttpSysServer : public HttpServer { friend class HttpSysTransaction; + friend class HttpMessageResponseRequest; + friend struct InitialRequestHandler; public: explicit HttpSysServer(const HttpSysConfig& Config); @@ -92,12 +97,15 @@ public: // HttpServer interface implementation - virtual int OnInitialize(int BasePort, std::filesystem::path DataDir) override; - virtual void OnRun(bool TestMode) override; - virtual void OnRequestExit() override; - virtual void OnRegisterService(HttpService& Service) override; - virtual void OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) override; - virtual void OnClose() override; + virtual int OnInitialize(int BasePort, std::filesystem::path DataDir) override; + virtual void OnRun(bool TestMode) override; + virtual void OnRequestExit() override; + virtual void OnRegisterService(HttpService& Service) override; + virtual void OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) override; + virtual void OnClose() override; + virtual std::string OnGetExternalHost() const override; + virtual uint64_t GetTotalBytesReceived() const override; + virtual uint64_t GetTotalBytesSent() const override; WorkerThreadPool& WorkPool(); @@ -108,6 +116,12 @@ public: private: int InitializeServer(int BasePort); + bool CreateSessionAndUrlGroup(); + bool RegisterLocalUrls(std::u8string_view Scheme, int Port, std::vector<std::wstring>& OutUris); + int RegisterHttpUrls(int BasePort); + bool RegisterHttpsUrls(); + bool CreateRequestQueue(int EffectivePort); + bool SetupIoCompletionPort(); void Cleanup(); void StartServer(); @@ -117,6 +131,9 @@ private: void RegisterService(const char* Endpoint, HttpService& Service); void UnregisterService(const char* Endpoint, HttpService& Service); + bool BindSslCertificate(int Port); + void UnbindSslCertificate(); + private: LoggerRef m_Log; LoggerRef m_RequestLog; @@ -130,10 +147,13 @@ private: std::unique_ptr<WinIoThreadPool> m_IoThreadPool; bool m_IoThreadPoolIsWinTp = true; - RwLock m_AsyncWorkPoolInitLock; - WorkerThreadPool* m_AsyncWorkPool = nullptr; + RwLock m_AsyncWorkPoolInitLock; + std::atomic<WorkerThreadPool*> m_AsyncWorkPool = nullptr; - std::vector<std::wstring> m_BaseUris; // eg: http://*:nnnn/ + std::vector<std::wstring> m_BaseUris; // eg: http://*:nnnn/ + std::vector<std::wstring> m_HttpsBaseUris; // eg: https://*:nnnn/ + bool m_DidAutoBindCert = false; + int m_HttpsPort = 0; HTTP_SERVER_SESSION_ID m_HttpSessionId = 0; HTTP_URL_GROUP_ID m_HttpUrlGroupId = 0; HANDLE m_RequestQueueHandle = 0; @@ -146,6 +166,9 @@ private: RwLock m_RequestFilterLock; std::atomic<IHttpRequestFilter*> m_HttpRequestFilter = nullptr; + + std::atomic<uint64_t> m_TotalBytesReceived{0}; + std::atomic<uint64_t> m_TotalBytesSent{0}; }; } // namespace zen @@ -153,6 +176,10 @@ private: #if ZEN_WITH_HTTPSYS +# include "httpsys_iocontext.h" +# include "wshttpsys.h" +# include "wsframecodec.h" + # include <conio.h> # include <mstcpip.h> # pragma comment(lib, "httpapi.lib") @@ -322,8 +349,9 @@ public: virtual Oid ParseSessionId() const override; virtual uint32_t ParseRequestId() const override; - virtual bool IsLocalMachineRequest() const; + virtual bool IsLocalMachineRequest() const override; virtual std::string_view GetAuthorizationHeader() const override; + virtual std::string_view GetRemoteAddress() const override; virtual IoBuffer ReadPayload() override; virtual void WriteResponse(HttpResponseCode ResponseCode) override; @@ -339,11 +367,12 @@ public: HttpSysServerRequest(const HttpSysServerRequest&) = delete; HttpSysServerRequest& operator=(const HttpSysServerRequest&) = delete; - HttpSysTransaction& m_HttpTx; - HttpSysRequestHandler* m_NextCompletionHandler = nullptr; - IoBuffer m_PayloadBuffer; - ExtendableStringBuilder<128> m_UriUtf8; - ExtendableStringBuilder<128> m_QueryStringUtf8; + HttpSysTransaction& m_HttpTx; + HttpSysRequestHandler* m_NextCompletionHandler = nullptr; + IoBuffer m_PayloadBuffer; + ExtendableStringBuilder<128> m_UriUtf8; + ExtendableStringBuilder<128> m_QueryStringUtf8; + mutable ExtendableStringBuilder<64> m_RemoteAddress; }; /** HTTP transaction @@ -378,7 +407,7 @@ public: void StartIo(); void CancelIo(); HANDLE RequestQueueHandle(); - inline OVERLAPPED* Overlapped() { return &m_HttpOverlapped; } + inline OVERLAPPED* Overlapped() { return &m_IoContext.Overlapped; } inline HttpSysServer& Server() { return m_HttpServer; } inline HTTP_REQUEST* HttpRequest() { return m_InitialHttpHandler.HttpRequest(); } @@ -395,8 +424,8 @@ public: }; private: - OVERLAPPED m_HttpOverlapped{}; - HttpSysServer& m_HttpServer; + HttpSysIoContext m_IoContext{}; + HttpSysServer& m_HttpServer; // Tracks which handler is due to handle the next I/O completion event HttpSysRequestHandler* m_CompletionHandler = nullptr; @@ -436,6 +465,8 @@ public: inline uint16_t GetResponseCode() const { return m_ResponseCode; } inline int64_t GetResponseBodySize() const { return m_TotalDataSize; } + void SetLocationHeader(std::string_view Location) { m_LocationHeader = Location; } + private: eastl::fixed_vector<HTTP_DATA_CHUNK, 16> m_HttpDataChunks; uint64_t m_TotalDataSize = 0; // Sum of all chunk sizes @@ -445,6 +476,7 @@ private: bool m_IsInitialResponse = true; HttpContentType m_ContentType = HttpContentType::kBinary; eastl::fixed_vector<IoBuffer, 16> m_DataBuffers; + std::string m_LocationHeader; void InitializeForPayload(uint16_t ResponseCode, std::span<IoBuffer> Blobs); }; @@ -585,7 +617,7 @@ HttpMessageResponseRequest::SuppressResponseBody() HttpSysRequestHandler* HttpMessageResponseRequest::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) { - ZEN_UNUSED(NumberOfBytesTransferred); + Transaction().Server().m_TotalBytesSent.fetch_add(NumberOfBytesTransferred, std::memory_order_relaxed); if (IoResult != NO_ERROR) { @@ -699,6 +731,15 @@ HttpMessageResponseRequest::IssueRequest(std::error_code& ErrorCode) ContentTypeHeader->pRawValue = ContentTypeString.data(); ContentTypeHeader->RawValueLength = (USHORT)ContentTypeString.size(); + // Location header (for redirects) + + if (!m_LocationHeader.empty()) + { + PHTTP_KNOWN_HEADER LocationHeader = &HttpResponse.Headers.KnownHeaders[HttpHeaderLocation]; + LocationHeader->pRawValue = m_LocationHeader.data(); + LocationHeader->RawValueLength = (USHORT)m_LocationHeader.size(); + } + std::string_view ReasonString = ReasonStringForHttpResultCode(m_ResponseCode); HttpResponse.StatusCode = m_ResponseCode; @@ -900,7 +941,10 @@ HttpAsyncWorkRequest::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTr ZEN_UNUSED(IoResult, NumberOfBytesTransferred); - ZEN_WARN("Unexpected I/O completion during async work! IoResult: {}, NumberOfBytesTransferred: {}", IoResult, NumberOfBytesTransferred); + ZEN_WARN("Unexpected I/O completion during async work! IoResult: {} ({:#x}), NumberOfBytesTransferred: {}", + GetSystemErrorAsString(IoResult), + IoResult, + NumberOfBytesTransferred); return this; } @@ -1035,8 +1079,10 @@ HttpSysServer::~HttpSysServer() ZEN_ERROR("~HttpSysServer() called without calling Close() first"); } - delete m_AsyncWorkPool; + auto WorkPool = m_AsyncWorkPool.load(std::memory_order_relaxed); m_AsyncWorkPool = nullptr; + + delete WorkPool; } void @@ -1051,36 +1097,63 @@ HttpSysServer::OnClose() } } -int -HttpSysServer::InitializeServer(int BasePort) +bool +HttpSysServer::CreateSessionAndUrlGroup() { - ZEN_MEMSCOPE(GetHttpsysTag()); - - using namespace std::literals; - - WideStringBuilder<64> WildcardUrlPath; - WildcardUrlPath << u8"http://*:"sv << int64_t(BasePort) << u8"/"sv; - - m_IsOk = false; - ULONG Result = HttpCreateServerSession(HTTPAPI_VERSION_2, &m_HttpSessionId, 0); if (Result != NO_ERROR) { - ZEN_ERROR("Failed to create server session for '{}': {:#x}", WideToUtf8(WildcardUrlPath), Result); + ZEN_ERROR("Failed to create server session: {} ({:#x})", GetSystemErrorAsString(Result), Result); - return 0; + return false; } Result = HttpCreateUrlGroup(m_HttpSessionId, &m_HttpUrlGroupId, 0); if (Result != NO_ERROR) { - ZEN_ERROR("Failed to create URL group for '{}': {:#x}", WideToUtf8(WildcardUrlPath), Result); + ZEN_ERROR("Failed to create URL group: {} ({:#x})", GetSystemErrorAsString(Result), Result); - return 0; + return false; } + return true; +} + +bool +HttpSysServer::RegisterLocalUrls(std::u8string_view Scheme, int Port, std::vector<std::wstring>& OutUris) +{ + using namespace std::literals; + + const std::u8string_view Hosts[] = {u8"[::1]"sv, u8"localhost"sv, u8"127.0.0.1"sv}; + + for (const std::u8string_view Host : Hosts) + { + WideStringBuilder<64> LocalUrl; + LocalUrl << Scheme << u8"://"sv << Host << u8":"sv << int64_t(Port) << u8"/"sv; + + ULONG Result = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, LocalUrl.c_str(), HTTP_URL_CONTEXT(0), 0); + + if (Result == NO_ERROR) + { + ZEN_WARN("Registered local-only handler '{}' - this is not accessible from remote hosts", WideToUtf8(LocalUrl)); + OutUris.push_back(LocalUrl.c_str()); + } + else + { + break; + } + } + + return !OutUris.empty(); +} + +int +HttpSysServer::RegisterHttpUrls(int BasePort) +{ + using namespace std::literals; + m_BaseUris.clear(); const bool AllowPortProbing = !m_InitialConfig.IsDedicatedServer; @@ -1088,6 +1161,11 @@ HttpSysServer::InitializeServer(int BasePort) int EffectivePort = BasePort; + WideStringBuilder<64> WildcardUrlPath; + WildcardUrlPath << u8"http://*:"sv << int64_t(BasePort) << u8"/"sv; + + ULONG Result; + if (m_InitialConfig.ForceLoopback) { // Force trigger of opening using local port @@ -1100,7 +1178,9 @@ HttpSysServer::InitializeServer(int BasePort) if ((Result == ERROR_SHARING_VIOLATION)) { - ZEN_INFO("Desired port {} is in use (HttpAddUrlToUrlGroup returned: {}), retrying", EffectivePort, Result); + ZEN_INFO("Desired port {} is in use (HttpAddUrlToUrlGroup returned: {}), retrying", + EffectivePort, + GetSystemErrorAsString(Result)); Sleep(500); Result = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, WildcardUrlPath.c_str(), HTTP_URL_CONTEXT(0), 0); @@ -1122,7 +1202,9 @@ HttpSysServer::InitializeServer(int BasePort) { for (uint32_t Retries = 0; (Result == ERROR_SHARING_VIOLATION) && (Retries < 3); Retries++) { - ZEN_INFO("Desired port {} is in use (HttpAddUrlToUrlGroup returned: {}), retrying", EffectivePort, Result); + ZEN_INFO("Desired port {} is in use (HttpAddUrlToUrlGroup returned: {}), retrying", + EffectivePort, + GetSystemErrorAsString(Result)); Sleep(500); Result = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, WildcardUrlPath.c_str(), HTTP_URL_CONTEXT(0), 0); } @@ -1139,11 +1221,11 @@ HttpSysServer::InitializeServer(int BasePort) { if (AllowLocalOnly) { - // If we can't register the wildcard path, we fall back to local paths - // This local paths allow requests originating locally to function, but will not allow - // remote origin requests to function. This can be remedied by using netsh + // If we can't register the wildcard path, we fall back to local paths. + // Local paths allow requests originating locally to function, but will not allow + // remote origin requests to function. This can be remedied by using netsh // during an install process to grant permissions to route public access to the appropriate - // port for the current user. eg: + // port for the current user. eg: // netsh http add urlacl url=http://*:8558/ user=<some_user> if (!m_InitialConfig.ForceLoopback) @@ -1157,17 +1239,18 @@ HttpSysServer::InitializeServer(int BasePort) const std::u8string_view Hosts[] = {u8"[::1]"sv, u8"localhost"sv, u8"127.0.0.1"sv}; - ULONG InternalResult = ERROR_SHARING_VIOLATION; - for (int PortOffset = 0; (InternalResult == ERROR_SHARING_VIOLATION) && (PortOffset < 10); ++PortOffset) + bool ShouldRetryNextPort = true; + for (int PortOffset = 0; ShouldRetryNextPort && (PortOffset < 10); ++PortOffset) { - EffectivePort = BasePort + (PortOffset * 100); + EffectivePort = BasePort + (PortOffset * 100); + ShouldRetryNextPort = false; for (const std::u8string_view Host : Hosts) { WideStringBuilder<64> LocalUrlPath; LocalUrlPath << u8"http://"sv << Host << u8":"sv << int64_t(EffectivePort) << u8"/"sv; - InternalResult = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, LocalUrlPath.c_str(), HTTP_URL_CONTEXT(0), 0); + ULONG InternalResult = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, LocalUrlPath.c_str(), HTTP_URL_CONTEXT(0), 0); if (InternalResult == NO_ERROR) { @@ -1175,11 +1258,25 @@ HttpSysServer::InitializeServer(int BasePort) m_BaseUris.push_back(LocalUrlPath.c_str()); } + else if (InternalResult == ERROR_SHARING_VIOLATION || InternalResult == ERROR_ACCESS_DENIED) + { + // Port may be owned by another process's wildcard registration (access denied) + // or actively in use (sharing violation) — retry on a different port + ShouldRetryNextPort = true; + } else { - break; + ZEN_WARN("Failed to register local handler '{}': {} ({:#x})", + WideToUtf8(LocalUrlPath), + GetSystemErrorAsString(InternalResult), + InternalResult); } } + + if (!m_BaseUris.empty()) + { + break; + } } } else @@ -1193,29 +1290,123 @@ HttpSysServer::InitializeServer(int BasePort) } } - if (m_BaseUris.empty()) + if (m_BaseUris.empty() && m_InitialConfig.HttpsPort == 0) { - ZEN_ERROR("Failed to add base URL to URL group for '{}': {:#x}", WideToUtf8(WildcardUrlPath), Result); + ZEN_ERROR("Failed to add base URL to URL group for '{}': {} ({:#x})", + WideToUtf8(WildcardUrlPath), + GetSystemErrorAsString(Result), + Result); return 0; } + return EffectivePort; +} + +bool +HttpSysServer::RegisterHttpsUrls() +{ + using namespace std::literals; + + const bool AllowLocalOnly = !m_InitialConfig.IsDedicatedServer; + const int HttpsPort = m_InitialConfig.HttpsPort; + + // If HTTPS-only mode, remove HTTP URLs and clear base URIs + if (m_InitialConfig.HttpsOnly) + { + for (const std::wstring& Uri : m_BaseUris) + { + HttpRemoveUrlFromUrlGroup(m_HttpUrlGroupId, Uri.c_str(), 0); + } + m_BaseUris.clear(); + } + + // Auto-bind certificate if thumbprint is provided + if (!m_InitialConfig.CertThumbprint.empty()) + { + if (!BindSslCertificate(HttpsPort)) + { + return false; + } + } + else + { + ZEN_INFO("HTTPS port {} configured without thumbprint - assuming pre-registered SSL certificate", HttpsPort); + } + + // Register HTTPS URLs using same pattern as HTTP + + WideStringBuilder<64> HttpsWildcard; + HttpsWildcard << u8"https://*:"sv << int64_t(HttpsPort) << u8"/"sv; + + ULONG HttpsResult = NO_ERROR; + + if (m_InitialConfig.ForceLoopback) + { + HttpsResult = ERROR_ACCESS_DENIED; + } + else + { + HttpsResult = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, HttpsWildcard.c_str(), HTTP_URL_CONTEXT(0), 0); + } + + if (HttpsResult == NO_ERROR) + { + m_HttpsBaseUris.push_back(HttpsWildcard.c_str()); + } + else if (HttpsResult == ERROR_ACCESS_DENIED && AllowLocalOnly) + { + if (!m_InitialConfig.ForceLoopback) + { + ZEN_WARN( + "Unable to register HTTPS handler using '{}' - falling back to local-only. " + "Please ensure the appropriate netsh URL reservation and SSL certificate configuration is made.", + WideToUtf8(HttpsWildcard)); + } + + RegisterLocalUrls(u8"https", HttpsPort, m_HttpsBaseUris); + } + else if (HttpsResult != NO_ERROR) + { + ZEN_ERROR("Failed to register HTTPS URL '{}': {} ({:#x})", + WideToUtf8(HttpsWildcard), + GetSystemErrorAsString(HttpsResult), + HttpsResult); + return false; + } + + if (m_HttpsBaseUris.empty()) + { + ZEN_ERROR("Failed to register any HTTPS URL for port {}", HttpsPort); + return false; + } + + m_HttpsPort = HttpsPort; + return true; +} + +bool +HttpSysServer::CreateRequestQueue(int EffectivePort) +{ HTTP_BINDING_INFO HttpBindingInfo = {{0}, 0}; WideStringBuilder<64> QueueName; QueueName << "zenserver_" << EffectivePort; - Result = HttpCreateRequestQueue(HTTPAPI_VERSION_2, - /* Name */ QueueName.c_str(), - /* SecurityAttributes */ nullptr, - /* Flags */ 0, - &m_RequestQueueHandle); + ULONG Result = HttpCreateRequestQueue(HTTPAPI_VERSION_2, + /* Name */ QueueName.c_str(), + /* SecurityAttributes */ nullptr, + /* Flags */ 0, + &m_RequestQueueHandle); if (Result != NO_ERROR) { - ZEN_ERROR("Failed to create request queue for '{}': {:#x}", WideToUtf8(m_BaseUris.front()), Result); + ZEN_ERROR("Failed to create request queue for '{}': {} ({:#x})", + WideToUtf8(m_BaseUris.front()), + GetSystemErrorAsString(Result), + Result); - return 0; + return false; } HttpBindingInfo.Flags.Present = 1; @@ -1225,9 +1416,12 @@ HttpSysServer::InitializeServer(int BasePort) if (Result != NO_ERROR) { - ZEN_ERROR("Failed to set server binding property for '{}': {:#x}", WideToUtf8(m_BaseUris.front()), Result); + ZEN_ERROR("Failed to set server binding property for '{}': {} ({:#x})", + WideToUtf8(m_BaseUris.front()), + GetSystemErrorAsString(Result), + Result); - return 0; + return false; } // Configure rejection method. Default is to drop the connection, it's better if we @@ -1257,42 +1451,82 @@ HttpSysServer::InitializeServer(int BasePort) if (Result != NO_ERROR) { - ZEN_WARN("changing request queue length to {} failed: {}", QueueLength, Result); + ZEN_WARN("changing request queue length to {} failed: {} ({:#x})", QueueLength, GetSystemErrorAsString(Result), Result); } } - // Create I/O completion port + return true; +} +bool +HttpSysServer::SetupIoCompletionPort() +{ std::error_code ErrorCode; m_IoThreadPool->CreateIocp(m_RequestQueueHandle, HttpSysTransaction::IoCompletionCallback, /* Context */ this, /* out */ ErrorCode); if (ErrorCode) { - ZEN_ERROR("Failed to create IOCP for '{}': {}", WideToUtf8(m_BaseUris.front()), ErrorCode.message()); + ZEN_ERROR("Failed to create IOCP: {}", ErrorCode.message()); + return false; + } + + m_IsOk = true; + + if (!m_BaseUris.empty()) + { + ZEN_INFO("Started http.sys server at '{}'", WideToUtf8(m_BaseUris.front())); + } + if (!m_HttpsBaseUris.empty()) + { + ZEN_INFO("Started http.sys HTTPS server at '{}'", WideToUtf8(m_HttpsBaseUris.front())); + } + + return true; +} + +int +HttpSysServer::InitializeServer(int BasePort) +{ + ZEN_MEMSCOPE(GetHttpsysTag()); + + m_IsOk = false; + if (!CreateSessionAndUrlGroup()) + { return 0; } - else + + int EffectivePort = RegisterHttpUrls(BasePort); + + if (m_InitialConfig.HttpsPort > 0) { - m_IsOk = true; + if (!RegisterHttpsUrls()) + { + return 0; + } + } - ZEN_INFO("Started http.sys server at '{}'", WideToUtf8(m_BaseUris.front())); + if (m_BaseUris.empty() && m_HttpsBaseUris.empty()) + { + ZEN_ERROR("No HTTP or HTTPS listeners could be registered"); + return 0; } - // This is not available in all Windows SDK versions so for now we can't use recently - // released functionality. We should investigate how to get more recent SDK releases - // into the build + if (!CreateRequestQueue(EffectivePort)) + { + return 0; + } -# if 0 - if (HttpIsFeatureSupported(/* HttpFeatureHttp3 */ (HTTP_FEATURE_ID) 4)) + if (!SetupIoCompletionPort()) { - ZEN_DEBUG("HTTP3 is available"); + return 0; } - else + + // When HTTPS-only, return the HTTPS port as the effective port + if (m_InitialConfig.HttpsOnly && m_HttpsPort > 0) { - ZEN_DEBUG("HTTP3 is NOT available"); + return m_HttpsPort; } -# endif return EffectivePort; } @@ -1302,6 +1536,8 @@ HttpSysServer::Cleanup() { ++m_IsShuttingDown; + UnbindSslCertificate(); + if (m_RequestQueueHandle) { HttpCloseRequestQueue(m_RequestQueueHandle); @@ -1321,23 +1557,122 @@ HttpSysServer::Cleanup() } } +// {7E3F4B2A-1C8D-4A6E-B5F0-9D2E8C7A3B1F} - Fixed GUID for zenserver SSL bindings +static constexpr GUID ZenServerSslAppId = {0x7E3F4B2A, 0x1C8D, 0x4A6E, {0xB5, 0xF0, 0x9D, 0x2E, 0x8C, 0x7A, 0x3B, 0x1F}}; + +bool +HttpSysServer::BindSslCertificate(int Port) +{ + const std::string& Thumbprint = m_InitialConfig.CertThumbprint; + if (Thumbprint.size() != 40) + { + ZEN_ERROR("SSL certificate thumbprint must be exactly 40 hex characters, got {}", Thumbprint.size()); + return false; + } + + BYTE CertHash[20] = {}; + if (!ParseHexBytes(Thumbprint, CertHash)) + { + ZEN_ERROR("SSL certificate thumbprint contains invalid hex characters"); + return false; + } + + SOCKADDR_IN Address = {}; + Address.sin_family = AF_INET; + Address.sin_port = htons(static_cast<USHORT>(Port)); + Address.sin_addr.s_addr = INADDR_ANY; + + const std::wstring StoreNameW = UTF8_to_UTF16(m_InitialConfig.CertStoreName.c_str()); + + HTTP_SERVICE_CONFIG_SSL_SET SslConfig = {}; + SslConfig.KeyDesc.pIpPort = reinterpret_cast<SOCKADDR*>(&Address); + SslConfig.ParamDesc.pSslHash = CertHash; + SslConfig.ParamDesc.SslHashLength = sizeof(CertHash); + SslConfig.ParamDesc.pSslCertStoreName = const_cast<PWSTR>(StoreNameW.c_str()); + SslConfig.ParamDesc.AppId = ZenServerSslAppId; + + ULONG Result = HttpSetServiceConfiguration(0, HttpServiceConfigSSLCertInfo, &SslConfig, sizeof(SslConfig), nullptr); + + if (Result == ERROR_ALREADY_EXISTS) + { + // Remove existing binding and retry + HTTP_SERVICE_CONFIG_SSL_SET DeleteConfig = {}; + DeleteConfig.KeyDesc.pIpPort = reinterpret_cast<SOCKADDR*>(&Address); + + HttpDeleteServiceConfiguration(0, HttpServiceConfigSSLCertInfo, &DeleteConfig, sizeof(DeleteConfig), nullptr); + + Result = HttpSetServiceConfiguration(0, HttpServiceConfigSSLCertInfo, &SslConfig, sizeof(SslConfig), nullptr); + } + + if (Result != NO_ERROR) + { + ZEN_ERROR( + "Failed to bind SSL certificate to port {}: {} ({:#x}). " + "This operation may require running as administrator.", + Port, + GetSystemErrorAsString(Result), + Result); + return false; + } + + m_DidAutoBindCert = true; + m_HttpsPort = Port; + + ZEN_INFO("SSL certificate auto-bound for 0.0.0.0:{} (thumbprint: {}..., store: {})", + Port, + Thumbprint.substr(0, 8), + m_InitialConfig.CertStoreName); + + return true; +} + +void +HttpSysServer::UnbindSslCertificate() +{ + if (!m_DidAutoBindCert) + { + return; + } + + SOCKADDR_IN Address = {}; + Address.sin_family = AF_INET; + Address.sin_port = htons(static_cast<USHORT>(m_HttpsPort)); + Address.sin_addr.s_addr = INADDR_ANY; + + HTTP_SERVICE_CONFIG_SSL_SET SslConfig = {}; + SslConfig.KeyDesc.pIpPort = reinterpret_cast<SOCKADDR*>(&Address); + + ULONG Result = HttpDeleteServiceConfiguration(0, HttpServiceConfigSSLCertInfo, &SslConfig, sizeof(SslConfig), nullptr); + + if (Result != NO_ERROR) + { + ZEN_WARN("Failed to remove SSL certificate binding from port {}: {} ({:#x})", m_HttpsPort, GetSystemErrorAsString(Result), Result); + } + else + { + ZEN_INFO("SSL certificate binding removed from port {}", m_HttpsPort); + } + + m_DidAutoBindCert = false; +} + WorkerThreadPool& HttpSysServer::WorkPool() { ZEN_MEMSCOPE(GetHttpsysTag()); - if (!m_AsyncWorkPool) + if (!m_AsyncWorkPool.load(std::memory_order_acquire)) { RwLock::ExclusiveLockScope _(m_AsyncWorkPoolInitLock); - if (!m_AsyncWorkPool) + if (!m_AsyncWorkPool.load(std::memory_order_relaxed)) { m_AsyncWorkPool = new WorkerThreadPool(m_InitialConfig.AsyncWorkThreadCount, "http_async", m_InitialConfig.UseExplicitIoThreadPool); } } - return *m_AsyncWorkPool; + return *m_AsyncWorkPool.load(std::memory_order_relaxed); } void @@ -1449,19 +1784,23 @@ HttpSysServer::RegisterService(const char* UrlPath, HttpService& Service) // Convert to wide string - for (const std::wstring& BaseUri : m_BaseUris) - { - std::wstring Url16 = BaseUri + PathUtf16; - - ULONG Result = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, Url16.c_str(), HTTP_URL_CONTEXT(&Service), 0 /* Reserved */); - - if (Result != NO_ERROR) + auto RegisterWithBaseUris = [&](const std::vector<std::wstring>& BaseUris) { + for (const std::wstring& BaseUri : BaseUris) { - ZEN_ERROR("HttpAddUrlToUrlGroup failed with result: '{}'", GetSystemErrorAsString(Result)); + std::wstring Url16 = BaseUri + PathUtf16; - return; + ULONG Result = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, Url16.c_str(), HTTP_URL_CONTEXT(&Service), 0 /* Reserved */); + + if (Result != NO_ERROR) + { + ZEN_ERROR("HttpAddUrlToUrlGroup failed with result: '{}'", GetSystemErrorAsString(Result)); + return; + } } - } + }; + + RegisterWithBaseUris(m_BaseUris); + RegisterWithBaseUris(m_HttpsBaseUris); } void @@ -1476,19 +1815,22 @@ HttpSysServer::UnregisterService(const char* UrlPath, HttpService& Service) const std::wstring PathUtf16 = UTF8_to_UTF16(UrlPath); - // Convert to wide string - - for (const std::wstring& BaseUri : m_BaseUris) - { - std::wstring Url16 = BaseUri + PathUtf16; + auto UnregisterFromBaseUris = [&](const std::vector<std::wstring>& BaseUris) { + for (const std::wstring& BaseUri : BaseUris) + { + std::wstring Url16 = BaseUri + PathUtf16; - ULONG Result = HttpRemoveUrlFromUrlGroup(m_HttpUrlGroupId, Url16.c_str(), 0); + ULONG Result = HttpRemoveUrlFromUrlGroup(m_HttpUrlGroupId, Url16.c_str(), 0); - if (Result != NO_ERROR) - { - ZEN_ERROR("HttpRemoveUrlFromUrlGroup failed with result: '{}'", GetSystemErrorAsString(Result)); + if (Result != NO_ERROR) + { + ZEN_ERROR("HttpRemoveUrlFromUrlGroup failed with result: '{}'", GetSystemErrorAsString(Result)); + } } - } + }; + + UnregisterFromBaseUris(m_BaseUris); + UnregisterFromBaseUris(m_HttpsBaseUris); } ////////////////////////////////////////////////////////////////////////// @@ -1551,7 +1893,23 @@ HttpSysTransaction::IoCompletionCallback(PTP_CALLBACK_INSTANCE Instance, // than one thread at any given moment. This means we need to be careful about what // happens in here - HttpSysTransaction* Transaction = CONTAINING_RECORD(pOverlapped, HttpSysTransaction, m_HttpOverlapped); + HttpSysIoContext* IoContext = CONTAINING_RECORD(pOverlapped, HttpSysIoContext, Overlapped); + + switch (IoContext->ContextType) + { + case HttpSysIoContext::Type::kWebSocketRead: + static_cast<WsHttpSysConnection*>(IoContext->Owner)->OnReadCompletion(IoResult, NumberOfBytesTransferred); + return; + + case HttpSysIoContext::Type::kWebSocketWrite: + static_cast<WsHttpSysConnection*>(IoContext->Owner)->OnWriteCompletion(IoResult, NumberOfBytesTransferred); + return; + + case HttpSysIoContext::Type::kTransaction: + break; + } + + HttpSysTransaction* Transaction = CONTAINING_RECORD(IoContext, HttpSysTransaction, m_IoContext); // Assign names to threads for context (only needed when using Windows' native // thread pool) @@ -1675,6 +2033,8 @@ HttpSysTransaction::InvokeRequestHandler(HttpService& Service, IoBuffer Payload) { HttpSysServerRequest& ThisRequest = m_HandlerRequest.emplace(*this, Service, Payload); + m_HttpServer.MarkRequest(); + // Default request handling # if ZEN_WITH_OTEL @@ -1884,6 +2244,17 @@ HttpSysServerRequest::IsLocalMachineRequest() const } std::string_view +HttpSysServerRequest::GetRemoteAddress() const +{ + if (m_RemoteAddress.Size() == 0) + { + const SOCKADDR* SockAddr = m_HttpTx.HttpRequest()->Address.pRemoteAddress; + GetAddressString(m_RemoteAddress, SockAddr, /* IncludePort */ false); + } + return m_RemoteAddress.ToView(); +} + +std::string_view HttpSysServerRequest::GetAuthorizationHeader() const { const HTTP_REQUEST* HttpRequestPtr = m_HttpTx.HttpRequest(); @@ -2111,6 +2482,8 @@ InitialRequestHandler::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesT break; } + Transaction().Server().m_TotalBytesReceived.fetch_add(NumberOfBytesTransferred, std::memory_order_relaxed); + ZEN_TRACE_CPU("httpsys::HandleCompletion"); // Route request @@ -2119,64 +2492,122 @@ InitialRequestHandler::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesT { HTTP_REQUEST* HttpReq = HttpRequest(); -# if 0 - for (int i = 0; i < HttpReq->RequestInfoCount; ++i) + if (HttpService* Service = reinterpret_cast<HttpService*>(HttpReq->UrlContext)) { - auto& ReqInfo = HttpReq->pRequestInfo[i]; - - switch (ReqInfo.InfoType) + // WebSocket upgrade detection + if (m_IsInitialRequest) { - case HttpRequestInfoTypeRequestTiming: + const HTTP_KNOWN_HEADER& UpgradeHeader = HttpReq->Headers.KnownHeaders[HttpHeaderUpgrade]; + if (UpgradeHeader.RawValueLength > 0 && + StrCaseCompare(UpgradeHeader.pRawValue, "websocket", UpgradeHeader.RawValueLength) == 0) + { + if (IWebSocketHandler* WsHandler = dynamic_cast<IWebSocketHandler*>(Service)) { - const HTTP_REQUEST_TIMING_INFO* TimingInfo = reinterpret_cast<HTTP_REQUEST_TIMING_INFO*>(ReqInfo.pInfo); + // Extract Sec-WebSocket-Key from the unknown headers + // (http.sys has no known-header slot for it) + std::string_view SecWebSocketKey; + for (USHORT i = 0; i < HttpReq->Headers.UnknownHeaderCount; ++i) + { + const HTTP_UNKNOWN_HEADER& Hdr = HttpReq->Headers.pUnknownHeaders[i]; + if (Hdr.NameLength == 17 && _strnicmp(Hdr.pName, "Sec-WebSocket-Key", 17) == 0) + { + SecWebSocketKey = std::string_view(Hdr.pRawValue, Hdr.RawValueLength); + break; + } + } - ZEN_INFO(""); - } - break; - case HttpRequestInfoTypeAuth: - ZEN_INFO(""); - break; - case HttpRequestInfoTypeChannelBind: - ZEN_INFO(""); - break; - case HttpRequestInfoTypeSslProtocol: - ZEN_INFO(""); - break; - case HttpRequestInfoTypeSslTokenBindingDraft: - ZEN_INFO(""); - break; - case HttpRequestInfoTypeSslTokenBinding: - ZEN_INFO(""); - break; - case HttpRequestInfoTypeTcpInfoV0: - { - const TCP_INFO_v0* TcpInfo = reinterpret_cast<const TCP_INFO_v0*>(ReqInfo.pInfo); + if (SecWebSocketKey.empty()) + { + ZEN_WARN("WebSocket upgrade missing Sec-WebSocket-Key header"); + return nullptr; + } - ZEN_INFO(""); - } - break; - case HttpRequestInfoTypeRequestSizing: - { - const HTTP_REQUEST_SIZING_INFO* SizingInfo = reinterpret_cast<const HTTP_REQUEST_SIZING_INFO*>(ReqInfo.pInfo); - ZEN_INFO(""); - } - break; - case HttpRequestInfoTypeQuicStats: - ZEN_INFO(""); - break; - case HttpRequestInfoTypeTcpInfoV1: - { - const TCP_INFO_v1* TcpInfo = reinterpret_cast<const TCP_INFO_v1*>(ReqInfo.pInfo); + const std::string AcceptKey = WsFrameCodec::ComputeAcceptKey(SecWebSocketKey); + + HANDLE RequestQueueHandle = Transaction().RequestQueueHandle(); + HTTP_REQUEST_ID RequestId = HttpReq->RequestId; + + // Build the 101 Switching Protocols response + HTTP_RESPONSE Response = {}; + Response.StatusCode = 101; + Response.pReason = "Switching Protocols"; + Response.ReasonLength = (USHORT)strlen(Response.pReason); + + Response.Headers.KnownHeaders[HttpHeaderUpgrade].pRawValue = "websocket"; + Response.Headers.KnownHeaders[HttpHeaderUpgrade].RawValueLength = 9; + + eastl::fixed_vector<HTTP_UNKNOWN_HEADER, 8> UnknownHeaders; + + // IMPORTANT: Due to some quirk in HttpSendHttpResponse, this cannot use KnownHeaders + // despite there being an entry for it there (HttpHeaderConnection). If you try to do + // that you get an ERROR_INVALID_PARAMETERS error from HttpSendHttpResponse below + + UnknownHeaders.push_back({.NameLength = 10, .RawValueLength = 7, .pName = "Connection", .pRawValue = "Upgrade"}); + + UnknownHeaders.push_back({.NameLength = 20, + .RawValueLength = (USHORT)AcceptKey.size(), + .pName = "Sec-WebSocket-Accept", + .pRawValue = AcceptKey.c_str()}); + + Response.Headers.UnknownHeaderCount = (USHORT)UnknownHeaders.size(); + Response.Headers.pUnknownHeaders = UnknownHeaders.data(); + + const ULONG Flags = HTTP_SEND_RESPONSE_FLAG_OPAQUE | HTTP_SEND_RESPONSE_FLAG_MORE_DATA; + + // Use an OVERLAPPED with an event so we can wait synchronously. + // The request queue is IOCP-associated, so passing NULL for pOverlapped + // may return ERROR_IO_PENDING. Setting the low-order bit of hEvent + // prevents IOCP delivery and lets us wait on the event directly. + OVERLAPPED SendOverlapped = {}; + HANDLE SendEvent = CreateEventW(nullptr, TRUE, FALSE, nullptr); + SendOverlapped.hEvent = (HANDLE)((uintptr_t)SendEvent | 1); + + ULONG SendResult = HttpSendHttpResponse(RequestQueueHandle, + RequestId, + Flags, + &Response, + nullptr, // CachePolicy + nullptr, // BytesSent + nullptr, // Reserved1 + 0, // Reserved2 + &SendOverlapped, + nullptr // LogData + ); + + if (SendResult == ERROR_IO_PENDING) + { + WaitForSingleObject(SendEvent, INFINITE); + SendResult = (SendOverlapped.Internal == 0) ? NO_ERROR : ERROR_IO_INCOMPLETE; + } + + CloseHandle(SendEvent); + + if (SendResult == NO_ERROR) + { + Transaction().Server().OnWebSocketConnectionOpened(); + Ref<WsHttpSysConnection> WsConn(new WsHttpSysConnection(RequestQueueHandle, + RequestId, + *WsHandler, + Transaction().Iocp(), + &Transaction().Server())); + Ref<WebSocketConnection> WsConnRef(WsConn.Get()); + + WsHandler->OnWebSocketOpen(std::move(WsConnRef)); + WsConn->Start(); + + return nullptr; + } + + ZEN_WARN("WebSocket 101 send failed: {} ({:#x})", GetSystemErrorAsString(SendResult), SendResult); - ZEN_INFO(""); + // WebSocket upgrade failed — return nullptr since ServerRequest() + // was never populated (no InvokeRequestHandler call) + return nullptr; } - break; + // Service doesn't support WebSocket or missing key — fall through to normal handling + } } - } -# endif - if (HttpService* Service = reinterpret_cast<HttpService*>(HttpReq->UrlContext)) - { if (m_IsInitialRequest) { m_ContentLength = GetContentLength(HttpReq); @@ -2242,6 +2673,18 @@ InitialRequestHandler::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesT return new HttpMessageResponseRequest(Transaction(), 404, "Not found"sv); } } + else + { + // If a default redirect is configured and the request is for the root path, send a 302 + std::string_view DefaultRedirect = Transaction().Server().GetDefaultRedirect(); + std::string_view RawUrl(HttpReq->pRawUrl, HttpReq->RawUrlLength); + if (!DefaultRedirect.empty() && (RawUrl == "/" || RawUrl.empty())) + { + auto* Response = new HttpMessageResponseRequest(Transaction(), 302); + Response->SetLocationHeader(DefaultRedirect); + return Response; + } + } // Unable to route return new HttpMessageResponseRequest(Transaction(), 404, "No suitable route found"sv); @@ -2285,6 +2728,11 @@ HttpSysServer::OnInitialize(int BasePort, std::filesystem::path DataDir) ZEN_UNUSED(DataDir); if (int EffectivePort = InitializeServer(BasePort)) { + if (m_HttpsPort > 0) + { + SetEffectiveHttpsPort(m_HttpsPort); + } + StartServer(); return EffectivePort; @@ -2301,6 +2749,52 @@ HttpSysServer::OnRequestExit() m_ShutdownEvent.Set(); } +std::string +HttpSysServer::OnGetExternalHost() const +{ + // Check whether we registered a public wildcard URL (http://*:port/) or fell back to loopback + bool IsPublic = false; + for (const auto& Uri : m_BaseUris) + { + if (Uri.find(L'*') != std::wstring::npos) + { + IsPublic = true; + break; + } + } + + if (!IsPublic) + { + return "127.0.0.1"; + } + + // Use the UDP connect trick: connecting a UDP socket to an external address + // causes the OS to select the appropriate local interface without sending any data. + try + { + asio::io_context IoService; + asio::ip::udp::socket Sock(IoService, asio::ip::udp::v4()); + Sock.connect(asio::ip::udp::endpoint(asio::ip::make_address("8.8.8.8"), 80)); + return Sock.local_endpoint().address().to_string(); + } + catch (const std::exception&) + { + return GetMachineName(); + } +} + +uint64_t +HttpSysServer::GetTotalBytesReceived() const +{ + return m_TotalBytesReceived.load(std::memory_order_relaxed); +} + +uint64_t +HttpSysServer::GetTotalBytesSent() const +{ + return m_TotalBytesSent.load(std::memory_order_relaxed); +} + void HttpSysServer::OnRegisterService(HttpService& Service) { diff --git a/src/zenhttp/servers/httpsys.h b/src/zenhttp/servers/httpsys.h index 4ff6df1fa..0685b42b2 100644 --- a/src/zenhttp/servers/httpsys.h +++ b/src/zenhttp/servers/httpsys.h @@ -23,6 +23,10 @@ struct HttpSysConfig bool IsDedicatedServer = false; bool ForceLoopback = false; bool UseExplicitIoThreadPool = false; + int HttpsPort = 0; // 0 = HTTPS disabled + std::string CertThumbprint; // Hex SHA-1 (40 chars) for auto SSL binding + std::string CertStoreName = "MY"; // Windows certificate store name + bool HttpsOnly = false; // When true, disable HTTP listener }; Ref<HttpServer> CreateHttpSysServer(HttpSysConfig Config); diff --git a/src/zenhttp/servers/httpsys_iocontext.h b/src/zenhttp/servers/httpsys_iocontext.h new file mode 100644 index 000000000..4f8a97012 --- /dev/null +++ b/src/zenhttp/servers/httpsys_iocontext.h @@ -0,0 +1,40 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#if ZEN_WITH_HTTPSYS +# define _WINSOCKAPI_ +# include <zencore/windows.h> + +# include <cstdint> + +namespace zen { + +/** + * Tagged OVERLAPPED wrapper for http.sys IOCP dispatch + * + * Both HttpSysTransaction (for normal HTTP request I/O) and WsHttpSysConnection + * (for WebSocket read/write) embed this struct. The single IoCompletionCallback + * bound to the request queue uses the ContextType tag to dispatch to the correct + * handler. + * + * The Overlapped member must be first so that CONTAINING_RECORD works to recover + * the HttpSysIoContext from the OVERLAPPED pointer provided by the threadpool. + */ +struct HttpSysIoContext +{ + OVERLAPPED Overlapped{}; + + enum class Type : uint8_t + { + kTransaction, + kWebSocketRead, + kWebSocketWrite, + } ContextType = Type::kTransaction; + + void* Owner = nullptr; +}; + +} // namespace zen + +#endif // ZEN_WITH_HTTPSYS diff --git a/src/zenhttp/servers/httptracer.h b/src/zenhttp/servers/httptracer.h index da72c79c9..a9a45f162 100644 --- a/src/zenhttp/servers/httptracer.h +++ b/src/zenhttp/servers/httptracer.h @@ -1,9 +1,9 @@ // Copyright Epic Games, Inc. All Rights Reserved. -#include <zenhttp/httpserver.h> - #pragma once +#include <zenhttp/httpserver.h> + namespace zen { /** Helper class for HTTP server implementations diff --git a/src/zenhttp/servers/wsasio.cpp b/src/zenhttp/servers/wsasio.cpp new file mode 100644 index 000000000..5ae48f5b3 --- /dev/null +++ b/src/zenhttp/servers/wsasio.cpp @@ -0,0 +1,339 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "wsasio.h" +#include "asio_socket_traits.h" +#include "wsframecodec.h" + +#include <zencore/logging.h> +#include <zenhttp/httpserver.h> + +namespace zen::asio_http { + +static LoggerRef +WsLog() +{ + static LoggerRef g_Logger = logging::Get("ws"); + return g_Logger; +} + +////////////////////////////////////////////////////////////////////////// + +template<typename SocketType> +WsAsioConnectionT<SocketType>::WsAsioConnectionT(std::unique_ptr<SocketType> Socket, IWebSocketHandler& Handler, HttpServer* Server) +: m_Socket(std::move(Socket)) +, m_Handler(Handler) +, m_HttpServer(Server) +{ +} + +template<typename SocketType> +WsAsioConnectionT<SocketType>::~WsAsioConnectionT() +{ + m_IsOpen.store(false); + if (m_HttpServer) + { + m_HttpServer->OnWebSocketConnectionClosed(); + } +} + +template<typename SocketType> +void +WsAsioConnectionT<SocketType>::Start() +{ + EnqueueRead(); +} + +template<typename SocketType> +bool +WsAsioConnectionT<SocketType>::IsOpen() const +{ + return m_IsOpen.load(std::memory_order_relaxed); +} + +////////////////////////////////////////////////////////////////////////// +// +// Read loop +// + +template<typename SocketType> +void +WsAsioConnectionT<SocketType>::EnqueueRead() +{ + if (!m_IsOpen.load(std::memory_order_relaxed)) + { + return; + } + + Ref<WsAsioConnectionT> Self(this); + + asio::async_read(*m_Socket, m_ReadBuffer, asio::transfer_at_least(1), [Self](const asio::error_code& Ec, std::size_t ByteCount) { + Self->OnDataReceived(Ec, ByteCount); + }); +} + +template<typename SocketType> +void +WsAsioConnectionT<SocketType>::OnDataReceived(const asio::error_code& Ec, [[maybe_unused]] std::size_t ByteCount) +{ + if (Ec) + { + if (Ec != asio::error::eof && Ec != asio::error::operation_aborted) + { + ZEN_LOG_DEBUG(WsLog(), "WebSocket read error: {}", Ec.message()); + } + + if (m_IsOpen.exchange(false)) + { + m_Handler.OnWebSocketClose(*this, 1006, "connection lost"); + } + return; + } + + ProcessReceivedData(); + + if (m_IsOpen.load(std::memory_order_relaxed)) + { + EnqueueRead(); + } +} + +template<typename SocketType> +void +WsAsioConnectionT<SocketType>::ProcessReceivedData() +{ + while (m_ReadBuffer.size() > 0) + { + const auto& InputBuffer = m_ReadBuffer.data(); + const auto* Data = static_cast<const uint8_t*>(InputBuffer.data()); + const auto Size = InputBuffer.size(); + + WsFrameParseResult Frame = WsFrameCodec::TryParseFrame(Data, Size); + if (!Frame.IsValid) + { + break; // not enough data yet + } + + m_ReadBuffer.consume(Frame.BytesConsumed); + + if (m_HttpServer) + { + m_HttpServer->OnWebSocketFrameReceived(Frame.BytesConsumed); + } + + switch (Frame.Opcode) + { + case WebSocketOpcode::kText: + case WebSocketOpcode::kBinary: + { + WebSocketMessage Msg; + Msg.Opcode = Frame.Opcode; + Msg.Payload = IoBuffer(IoBuffer::Clone, Frame.Payload.data(), Frame.Payload.size()); + m_Handler.OnWebSocketMessage(*this, Msg); + break; + } + + case WebSocketOpcode::kPing: + { + // Auto-respond with pong carrying the same payload + std::vector<uint8_t> PongFrame = WsFrameCodec::BuildFrame(WebSocketOpcode::kPong, Frame.Payload); + EnqueueWrite(std::move(PongFrame)); + break; + } + + case WebSocketOpcode::kPong: + // Unsolicited pong — ignore per RFC 6455 + break; + + case WebSocketOpcode::kClose: + { + uint16_t Code = 1000; + std::string_view Reason; + + if (Frame.Payload.size() >= 2) + { + Code = (uint16_t(Frame.Payload[0]) << 8) | uint16_t(Frame.Payload[1]); + if (Frame.Payload.size() > 2) + { + Reason = std::string_view(reinterpret_cast<const char*>(Frame.Payload.data() + 2), Frame.Payload.size() - 2); + } + } + + // Echo close frame back if we haven't sent one yet + if (!m_CloseSent.exchange(true)) + { + std::vector<uint8_t> CloseFrame = WsFrameCodec::BuildCloseFrame(Code); + EnqueueWrite(std::move(CloseFrame)); + } + + m_IsOpen.store(false); + m_Handler.OnWebSocketClose(*this, Code, Reason); + + // Shut down the socket + std::error_code ShutdownEc; + SocketTraits<SocketType>::ShutdownBoth(*m_Socket, ShutdownEc); + SocketTraits<SocketType>::Close(*m_Socket, ShutdownEc); + return; + } + + default: + ZEN_LOG_WARN(WsLog(), "Unknown WebSocket opcode: {:#x}", static_cast<uint8_t>(Frame.Opcode)); + break; + } + } +} + +////////////////////////////////////////////////////////////////////////// +// +// Write queue +// + +template<typename SocketType> +void +WsAsioConnectionT<SocketType>::SendText(std::string_view Text) +{ + if (!m_IsOpen.load(std::memory_order_relaxed)) + { + return; + } + + std::span<const uint8_t> Payload(reinterpret_cast<const uint8_t*>(Text.data()), Text.size()); + std::vector<uint8_t> Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kText, Payload); + EnqueueWrite(std::move(Frame)); +} + +template<typename SocketType> +void +WsAsioConnectionT<SocketType>::SendBinary(std::span<const uint8_t> Data) +{ + if (!m_IsOpen.load(std::memory_order_relaxed)) + { + return; + } + + std::vector<uint8_t> Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kBinary, Data); + EnqueueWrite(std::move(Frame)); +} + +template<typename SocketType> +void +WsAsioConnectionT<SocketType>::Close(uint16_t Code, std::string_view Reason) +{ + DoClose(Code, Reason); +} + +template<typename SocketType> +void +WsAsioConnectionT<SocketType>::DoClose(uint16_t Code, std::string_view Reason) +{ + if (!m_IsOpen.exchange(false)) + { + return; + } + + if (!m_CloseSent.exchange(true)) + { + std::vector<uint8_t> CloseFrame = WsFrameCodec::BuildCloseFrame(Code, Reason); + EnqueueWrite(std::move(CloseFrame)); + } + + m_Handler.OnWebSocketClose(*this, Code, Reason); +} + +template<typename SocketType> +void +WsAsioConnectionT<SocketType>::EnqueueWrite(std::vector<uint8_t> Frame) +{ + if (m_HttpServer) + { + m_HttpServer->OnWebSocketFrameSent(Frame.size()); + } + + bool ShouldFlush = false; + + m_WriteLock.WithExclusiveLock([&] { + m_WriteQueue.push_back(std::move(Frame)); + if (!m_IsWriting) + { + m_IsWriting = true; + ShouldFlush = true; + } + }); + + if (ShouldFlush) + { + FlushWriteQueue(); + } +} + +template<typename SocketType> +void +WsAsioConnectionT<SocketType>::FlushWriteQueue() +{ + std::vector<uint8_t> Frame; + + m_WriteLock.WithExclusiveLock([&] { + if (m_WriteQueue.empty()) + { + m_IsWriting = false; + return; + } + Frame = std::move(m_WriteQueue.front()); + m_WriteQueue.pop_front(); + }); + + if (Frame.empty()) + { + return; + } + + Ref<WsAsioConnectionT> Self(this); + + // Move Frame into a shared_ptr so we can create the buffer and capture ownership + // in the same async_write call without evaluation order issues. + auto OwnedFrame = std::make_shared<std::vector<uint8_t>>(std::move(Frame)); + + asio::async_write(*m_Socket, + asio::buffer(OwnedFrame->data(), OwnedFrame->size()), + [Self, OwnedFrame](const asio::error_code& Ec, std::size_t ByteCount) { Self->OnWriteComplete(Ec, ByteCount); }); +} + +template<typename SocketType> +void +WsAsioConnectionT<SocketType>::OnWriteComplete(const asio::error_code& Ec, [[maybe_unused]] std::size_t ByteCount) +{ + if (Ec) + { + if (Ec != asio::error::operation_aborted) + { + ZEN_LOG_DEBUG(WsLog(), "WebSocket write error: {}", Ec.message()); + } + + m_WriteLock.WithExclusiveLock([&] { + m_IsWriting = false; + m_WriteQueue.clear(); + }); + + if (m_IsOpen.exchange(false)) + { + m_Handler.OnWebSocketClose(*this, 1006, "write error"); + } + return; + } + + FlushWriteQueue(); +} + +////////////////////////////////////////////////////////////////////////// +// Explicit template instantiations + +template class WsAsioConnectionT<asio::ip::tcp::socket>; + +#if defined(ASIO_HAS_LOCAL_SOCKETS) +template class WsAsioConnectionT<asio::local::stream_protocol::socket>; +#endif + +#if ZEN_USE_OPENSSL +template class WsAsioConnectionT<asio::ssl::stream<asio::ip::tcp::socket>>; +#endif + +} // namespace zen::asio_http diff --git a/src/zenhttp/servers/wsasio.h b/src/zenhttp/servers/wsasio.h new file mode 100644 index 000000000..64602ee46 --- /dev/null +++ b/src/zenhttp/servers/wsasio.h @@ -0,0 +1,94 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zenhttp/websocket.h> + +#include <zencore/thread.h> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <asio.hpp> +#if defined(ASIO_HAS_LOCAL_SOCKETS) +# include <asio/local/stream_protocol.hpp> +#endif +#if ZEN_USE_OPENSSL +# include <asio/ssl.hpp> +#endif +ZEN_THIRD_PARTY_INCLUDES_END + +#include <deque> +#include <memory> +#include <vector> + +namespace zen { +class HttpServer; +} // namespace zen + +namespace zen::asio_http { + +/** + * WebSocket connection over an ASIO stream socket + * + * Templated on SocketType to support both TCP and Unix domain sockets. + * Owns the socket (moved from HttpServerConnection after the 101 handshake) + * and runs an async read/write loop to exchange WebSocket frames. + * + * Lifetime is managed solely through intrusive reference counting (RefCounted). + * The async read/write callbacks capture Ref<> to keep the connection alive + * for the duration of the async operation. The service layer also holds a + * Ref<WebSocketConnection>. + */ +template<typename SocketType> +class WsAsioConnectionT : public WebSocketConnection +{ +public: + WsAsioConnectionT(std::unique_ptr<SocketType> Socket, IWebSocketHandler& Handler, HttpServer* Server); + ~WsAsioConnectionT() override; + + /** + * Start the async read loop. Must be called once after construction + * and the 101 response has been sent. + */ + void Start(); + + // WebSocketConnection interface + void SendText(std::string_view Text) override; + void SendBinary(std::span<const uint8_t> Data) override; + void Close(uint16_t Code, std::string_view Reason) override; + bool IsOpen() const override; + +private: + void EnqueueRead(); + void OnDataReceived(const asio::error_code& Ec, std::size_t ByteCount); + void ProcessReceivedData(); + + void EnqueueWrite(std::vector<uint8_t> Frame); + void FlushWriteQueue(); + void OnWriteComplete(const asio::error_code& Ec, std::size_t ByteCount); + + void DoClose(uint16_t Code, std::string_view Reason); + + std::unique_ptr<SocketType> m_Socket; + IWebSocketHandler& m_Handler; + zen::HttpServer* m_HttpServer; + asio::streambuf m_ReadBuffer; + + RwLock m_WriteLock; + std::deque<std::vector<uint8_t>> m_WriteQueue; + bool m_IsWriting = false; + + std::atomic<bool> m_IsOpen{true}; + std::atomic<bool> m_CloseSent{false}; +}; + +using WsAsioConnection = WsAsioConnectionT<asio::ip::tcp::socket>; + +#if defined(ASIO_HAS_LOCAL_SOCKETS) +using WsAsioUnixConnection = WsAsioConnectionT<asio::local::stream_protocol::socket>; +#endif + +#if ZEN_USE_OPENSSL +using WsAsioSslConnection = WsAsioConnectionT<asio::ssl::stream<asio::ip::tcp::socket>>; +#endif + +} // namespace zen::asio_http diff --git a/src/zenhttp/servers/wsframecodec.cpp b/src/zenhttp/servers/wsframecodec.cpp new file mode 100644 index 000000000..e452141fe --- /dev/null +++ b/src/zenhttp/servers/wsframecodec.cpp @@ -0,0 +1,236 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "wsframecodec.h" + +#include <zencore/base64.h> +#include <zencore/sha1.h> + +#include <cstring> +#include <random> + +namespace zen { + +////////////////////////////////////////////////////////////////////////// +// +// Frame parsing +// + +WsFrameParseResult +WsFrameCodec::TryParseFrame(const uint8_t* Data, size_t Size) +{ + // Minimum frame: 2 bytes header (unmasked server frames) or 6 bytes (masked client frames) + if (Size < 2) + { + return {}; + } + + const bool Fin = (Data[0] & 0x80) != 0; + const uint8_t OpcodeRaw = Data[0] & 0x0F; + const bool Masked = (Data[1] & 0x80) != 0; + uint64_t PayloadLen = Data[1] & 0x7F; + + size_t HeaderSize = 2; + + if (PayloadLen == 126) + { + if (Size < 4) + { + return {}; + } + PayloadLen = (uint64_t(Data[2]) << 8) | uint64_t(Data[3]); + HeaderSize = 4; + } + else if (PayloadLen == 127) + { + if (Size < 10) + { + return {}; + } + PayloadLen = (uint64_t(Data[2]) << 56) | (uint64_t(Data[3]) << 48) | (uint64_t(Data[4]) << 40) | (uint64_t(Data[5]) << 32) | + (uint64_t(Data[6]) << 24) | (uint64_t(Data[7]) << 16) | (uint64_t(Data[8]) << 8) | uint64_t(Data[9]); + HeaderSize = 10; + } + + // Reject frames with unreasonable payload sizes to prevent OOM + static constexpr uint64_t kMaxPayloadSize = 256 * 1024 * 1024; // 256 MB + if (PayloadLen > kMaxPayloadSize) + { + return {}; + } + + const size_t MaskSize = Masked ? 4 : 0; + const size_t TotalFrame = HeaderSize + MaskSize + PayloadLen; + + if (Size < TotalFrame) + { + return {}; + } + + const uint8_t* MaskKey = Masked ? (Data + HeaderSize) : nullptr; + const uint8_t* PayloadData = Data + HeaderSize + MaskSize; + + WsFrameParseResult Result; + Result.IsValid = true; + Result.BytesConsumed = TotalFrame; + Result.Opcode = static_cast<WebSocketOpcode>(OpcodeRaw); + Result.Fin = Fin; + + Result.Payload.resize(static_cast<size_t>(PayloadLen)); + if (PayloadLen > 0) + { + std::memcpy(Result.Payload.data(), PayloadData, static_cast<size_t>(PayloadLen)); + + if (Masked) + { + for (size_t i = 0; i < Result.Payload.size(); ++i) + { + Result.Payload[i] ^= MaskKey[i & 3]; + } + } + } + + return Result; +} + +////////////////////////////////////////////////////////////////////////// +// +// Frame building (server-to-client, no masking) +// + +std::vector<uint8_t> +WsFrameCodec::BuildFrame(WebSocketOpcode Opcode, std::span<const uint8_t> Payload) +{ + std::vector<uint8_t> Frame; + + const size_t PayloadLen = Payload.size(); + + // FIN + opcode + Frame.push_back(0x80 | static_cast<uint8_t>(Opcode)); + + // Payload length (no mask bit for server frames) + if (PayloadLen < 126) + { + Frame.push_back(static_cast<uint8_t>(PayloadLen)); + } + else if (PayloadLen <= 0xFFFF) + { + Frame.push_back(126); + Frame.push_back(static_cast<uint8_t>((PayloadLen >> 8) & 0xFF)); + Frame.push_back(static_cast<uint8_t>(PayloadLen & 0xFF)); + } + else + { + Frame.push_back(127); + for (int i = 7; i >= 0; --i) + { + Frame.push_back(static_cast<uint8_t>((PayloadLen >> (i * 8)) & 0xFF)); + } + } + + Frame.insert(Frame.end(), Payload.begin(), Payload.end()); + + return Frame; +} + +std::vector<uint8_t> +WsFrameCodec::BuildCloseFrame(uint16_t Code, std::string_view Reason) +{ + std::vector<uint8_t> Payload; + Payload.push_back(static_cast<uint8_t>((Code >> 8) & 0xFF)); + Payload.push_back(static_cast<uint8_t>(Code & 0xFF)); + Payload.insert(Payload.end(), Reason.begin(), Reason.end()); + + return BuildFrame(WebSocketOpcode::kClose, Payload); +} + +////////////////////////////////////////////////////////////////////////// +// +// Frame building (client-to-server, with masking) +// + +std::vector<uint8_t> +WsFrameCodec::BuildMaskedFrame(WebSocketOpcode Opcode, std::span<const uint8_t> Payload) +{ + std::vector<uint8_t> Frame; + + const size_t PayloadLen = Payload.size(); + + // FIN + opcode + Frame.push_back(0x80 | static_cast<uint8_t>(Opcode)); + + // Payload length with mask bit set + if (PayloadLen < 126) + { + Frame.push_back(0x80 | static_cast<uint8_t>(PayloadLen)); + } + else if (PayloadLen <= 0xFFFF) + { + Frame.push_back(0x80 | 126); + Frame.push_back(static_cast<uint8_t>((PayloadLen >> 8) & 0xFF)); + Frame.push_back(static_cast<uint8_t>(PayloadLen & 0xFF)); + } + else + { + Frame.push_back(0x80 | 127); + for (int i = 7; i >= 0; --i) + { + Frame.push_back(static_cast<uint8_t>((PayloadLen >> (i * 8)) & 0xFF)); + } + } + + // Generate random 4-byte mask key + static thread_local std::mt19937 s_Rng(std::random_device{}()); + uint32_t MaskValue = s_Rng(); + uint8_t MaskKey[4]; + std::memcpy(MaskKey, &MaskValue, 4); + + Frame.insert(Frame.end(), MaskKey, MaskKey + 4); + + // Masked payload + for (size_t i = 0; i < PayloadLen; ++i) + { + Frame.push_back(Payload[i] ^ MaskKey[i & 3]); + } + + return Frame; +} + +std::vector<uint8_t> +WsFrameCodec::BuildMaskedCloseFrame(uint16_t Code, std::string_view Reason) +{ + std::vector<uint8_t> Payload; + Payload.push_back(static_cast<uint8_t>((Code >> 8) & 0xFF)); + Payload.push_back(static_cast<uint8_t>(Code & 0xFF)); + Payload.insert(Payload.end(), Reason.begin(), Reason.end()); + + return BuildMaskedFrame(WebSocketOpcode::kClose, Payload); +} + +////////////////////////////////////////////////////////////////////////// +// +// Sec-WebSocket-Accept key computation (RFC 6455 section 4.2.2) +// + +static constexpr std::string_view kWebSocketMagicGuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; + +std::string +WsFrameCodec::ComputeAcceptKey(std::string_view ClientKey) +{ + // Concatenate client key with the magic GUID + std::string Combined; + Combined.reserve(ClientKey.size() + kWebSocketMagicGuid.size()); + Combined.append(ClientKey); + Combined.append(kWebSocketMagicGuid); + + // SHA1 hash + SHA1 Hash = SHA1::HashMemory(Combined.data(), Combined.size()); + + // Base64 encode the 20-byte hash + char Base64Buf[Base64::GetEncodedDataSize(20) + 1]; + uint32_t EncodedLen = Base64::Encode(Hash.Hash, 20, Base64Buf); + Base64Buf[EncodedLen] = '\0'; + + return std::string(Base64Buf, EncodedLen); +} + +} // namespace zen diff --git a/src/zenhttp/servers/wsframecodec.h b/src/zenhttp/servers/wsframecodec.h new file mode 100644 index 000000000..2d90b6fa1 --- /dev/null +++ b/src/zenhttp/servers/wsframecodec.h @@ -0,0 +1,74 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zenhttp/websocket.h> + +#include <cstddef> +#include <cstdint> +#include <optional> +#include <span> +#include <string> +#include <string_view> +#include <vector> + +namespace zen { + +/** + * Result of attempting to parse a single WebSocket frame from a byte buffer + */ +struct WsFrameParseResult +{ + bool IsValid = false; // true if a complete frame was successfully parsed + size_t BytesConsumed = 0; // number of bytes consumed from the input buffer + WebSocketOpcode Opcode = WebSocketOpcode::kText; + bool Fin = false; + std::vector<uint8_t> Payload; +}; + +/** + * RFC 6455 WebSocket frame codec + * + * Provides static helpers for parsing client-to-server frames (which are + * always masked) and building server-to-client frames (which are never masked). + */ +struct WsFrameCodec +{ + /** + * Try to parse one complete frame from the front of the buffer. + * + * Returns a result with IsValid == false and BytesConsumed == 0 when + * there is not enough data yet. The caller should accumulate more data + * and retry. + */ + static WsFrameParseResult TryParseFrame(const uint8_t* Data, size_t Size); + + /** + * Build a server-to-client frame (no masking) + */ + static std::vector<uint8_t> BuildFrame(WebSocketOpcode Opcode, std::span<const uint8_t> Payload); + + /** + * Build a close frame with a status code and optional reason string + */ + static std::vector<uint8_t> BuildCloseFrame(uint16_t Code, std::string_view Reason = {}); + + /** + * Build a client-to-server frame (with masking per RFC 6455) + */ + static std::vector<uint8_t> BuildMaskedFrame(WebSocketOpcode Opcode, std::span<const uint8_t> Payload); + + /** + * Build a masked close frame with status code and optional reason + */ + static std::vector<uint8_t> BuildMaskedCloseFrame(uint16_t Code, std::string_view Reason = {}); + + /** + * Compute the Sec-WebSocket-Accept value per RFC 6455 section 4.2.2 + * + * accept = Base64(SHA1(clientKey + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11")) + */ + static std::string ComputeAcceptKey(std::string_view ClientKey); +}; + +} // namespace zen diff --git a/src/zenhttp/servers/wshttpsys.cpp b/src/zenhttp/servers/wshttpsys.cpp new file mode 100644 index 000000000..af320172d --- /dev/null +++ b/src/zenhttp/servers/wshttpsys.cpp @@ -0,0 +1,485 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "wshttpsys.h" + +#if ZEN_WITH_HTTPSYS + +# include "wsframecodec.h" + +# include <zencore/logging.h> +# include <zenhttp/httpserver.h> + +namespace zen { + +static LoggerRef +WsHttpSysLog() +{ + static LoggerRef g_Logger = logging::Get("ws_httpsys"); + return g_Logger; +} + +////////////////////////////////////////////////////////////////////////// + +WsHttpSysConnection::WsHttpSysConnection(HANDLE RequestQueueHandle, + HTTP_REQUEST_ID RequestId, + IWebSocketHandler& Handler, + PTP_IO Iocp, + HttpServer* Server) +: m_RequestQueueHandle(RequestQueueHandle) +, m_RequestId(RequestId) +, m_Handler(Handler) +, m_Iocp(Iocp) +, m_HttpServer(Server) +, m_ReadBuffer(8192) +{ + m_ReadIoContext.ContextType = HttpSysIoContext::Type::kWebSocketRead; + m_ReadIoContext.Owner = this; + m_WriteIoContext.ContextType = HttpSysIoContext::Type::kWebSocketWrite; + m_WriteIoContext.Owner = this; +} + +WsHttpSysConnection::~WsHttpSysConnection() +{ + ZEN_ASSERT(m_OutstandingOps.load() == 0); + + if (m_IsOpen.exchange(false)) + { + Disconnect(); + } + + if (m_HttpServer) + { + m_HttpServer->OnWebSocketConnectionClosed(); + } +} + +void +WsHttpSysConnection::Start() +{ + m_SelfRef = Ref<WsHttpSysConnection>(this); + IssueAsyncRead(); +} + +void +WsHttpSysConnection::Shutdown() +{ + m_ShutdownRequested.store(true, std::memory_order_relaxed); + + if (!m_IsOpen.exchange(false)) + { + return; + } + + // Cancel pending I/O — completions will fire with ERROR_OPERATION_ABORTED + HttpCancelHttpRequest(m_RequestQueueHandle, m_RequestId, nullptr); +} + +bool +WsHttpSysConnection::IsOpen() const +{ + return m_IsOpen.load(std::memory_order_relaxed); +} + +////////////////////////////////////////////////////////////////////////// +// +// Async read path +// + +void +WsHttpSysConnection::IssueAsyncRead() +{ + if (!m_IsOpen.load(std::memory_order_relaxed) || m_ShutdownRequested.load(std::memory_order_relaxed)) + { + MaybeReleaseSelfRef(); + return; + } + + m_OutstandingOps.fetch_add(1, std::memory_order_relaxed); + + ZeroMemory(&m_ReadIoContext.Overlapped, sizeof(OVERLAPPED)); + + StartThreadpoolIo(m_Iocp); + + ULONG Result = HttpReceiveRequestEntityBody(m_RequestQueueHandle, + m_RequestId, + 0, // Flags + m_ReadBuffer.data(), + (ULONG)m_ReadBuffer.size(), + nullptr, // BytesRead (ignored for async) + &m_ReadIoContext.Overlapped); + + if (Result != NO_ERROR && Result != ERROR_IO_PENDING) + { + CancelThreadpoolIo(m_Iocp); + m_OutstandingOps.fetch_sub(1, std::memory_order_relaxed); + + if (m_IsOpen.exchange(false)) + { + m_Handler.OnWebSocketClose(*this, 1006, "read issue failed"); + } + + MaybeReleaseSelfRef(); + } +} + +void +WsHttpSysConnection::OnReadCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) +{ + // Hold a transient ref to prevent mid-callback destruction after MaybeReleaseSelfRef + Ref<WsHttpSysConnection> Guard(this); + + if (IoResult != NO_ERROR) + { + m_OutstandingOps.fetch_sub(1, std::memory_order_relaxed); + + if (m_IsOpen.exchange(false)) + { + if (IoResult == ERROR_HANDLE_EOF) + { + m_Handler.OnWebSocketClose(*this, 1006, "connection closed"); + } + else if (IoResult != ERROR_OPERATION_ABORTED) + { + m_Handler.OnWebSocketClose(*this, 1006, "connection lost"); + } + } + + MaybeReleaseSelfRef(); + return; + } + + if (NumberOfBytesTransferred > 0) + { + m_Accumulated.insert(m_Accumulated.end(), m_ReadBuffer.begin(), m_ReadBuffer.begin() + NumberOfBytesTransferred); + ProcessReceivedData(); + } + + m_OutstandingOps.fetch_sub(1, std::memory_order_relaxed); + + if (m_IsOpen.load(std::memory_order_relaxed)) + { + IssueAsyncRead(); + } + else + { + MaybeReleaseSelfRef(); + } +} + +////////////////////////////////////////////////////////////////////////// +// +// Frame parsing +// + +void +WsHttpSysConnection::ProcessReceivedData() +{ + while (!m_Accumulated.empty()) + { + WsFrameParseResult Frame = WsFrameCodec::TryParseFrame(m_Accumulated.data(), m_Accumulated.size()); + if (!Frame.IsValid) + { + break; // not enough data yet + } + + // Remove consumed bytes + m_Accumulated.erase(m_Accumulated.begin(), m_Accumulated.begin() + Frame.BytesConsumed); + + if (m_HttpServer) + { + m_HttpServer->OnWebSocketFrameReceived(Frame.BytesConsumed); + } + + switch (Frame.Opcode) + { + case WebSocketOpcode::kText: + case WebSocketOpcode::kBinary: + { + WebSocketMessage Msg; + Msg.Opcode = Frame.Opcode; + Msg.Payload = IoBuffer(IoBuffer::Clone, Frame.Payload.data(), Frame.Payload.size()); + m_Handler.OnWebSocketMessage(*this, Msg); + break; + } + + case WebSocketOpcode::kPing: + { + // Auto-respond with pong carrying the same payload + std::vector<uint8_t> PongFrame = WsFrameCodec::BuildFrame(WebSocketOpcode::kPong, Frame.Payload); + EnqueueWrite(std::move(PongFrame)); + break; + } + + case WebSocketOpcode::kPong: + // Unsolicited pong — ignore per RFC 6455 + break; + + case WebSocketOpcode::kClose: + { + uint16_t Code = 1000; + std::string_view Reason; + + if (Frame.Payload.size() >= 2) + { + Code = (uint16_t(Frame.Payload[0]) << 8) | uint16_t(Frame.Payload[1]); + if (Frame.Payload.size() > 2) + { + Reason = std::string_view(reinterpret_cast<const char*>(Frame.Payload.data() + 2), Frame.Payload.size() - 2); + } + } + + // Echo close frame back if we haven't sent one yet + { + bool ShouldSendClose = false; + { + RwLock::ExclusiveLockScope _(m_WriteLock); + if (!m_CloseSent.exchange(true)) + { + ShouldSendClose = true; + } + } + if (ShouldSendClose) + { + std::vector<uint8_t> CloseFrame = WsFrameCodec::BuildCloseFrame(Code); + EnqueueWrite(std::move(CloseFrame)); + } + } + + m_IsOpen.store(false); + m_Handler.OnWebSocketClose(*this, Code, Reason); + Disconnect(); + return; + } + + default: + ZEN_LOG_WARN(WsHttpSysLog(), "Unknown WebSocket opcode: {:#x}", static_cast<uint8_t>(Frame.Opcode)); + break; + } + } +} + +////////////////////////////////////////////////////////////////////////// +// +// Async write path +// + +void +WsHttpSysConnection::EnqueueWrite(std::vector<uint8_t> Frame) +{ + if (m_HttpServer) + { + m_HttpServer->OnWebSocketFrameSent(Frame.size()); + } + + bool ShouldFlush = false; + + { + RwLock::ExclusiveLockScope _(m_WriteLock); + m_WriteQueue.push_back(std::move(Frame)); + + if (!m_IsWriting) + { + m_IsWriting = true; + ShouldFlush = true; + } + } + + if (ShouldFlush) + { + FlushWriteQueue(); + } +} + +void +WsHttpSysConnection::FlushWriteQueue() +{ + { + RwLock::ExclusiveLockScope _(m_WriteLock); + + if (m_WriteQueue.empty()) + { + m_IsWriting = false; + return; + } + + m_CurrentWriteBuffer = std::move(m_WriteQueue.front()); + m_WriteQueue.pop_front(); + } + + m_OutstandingOps.fetch_add(1, std::memory_order_relaxed); + + ZeroMemory(&m_WriteChunk, sizeof(m_WriteChunk)); + m_WriteChunk.DataChunkType = HttpDataChunkFromMemory; + m_WriteChunk.FromMemory.pBuffer = m_CurrentWriteBuffer.data(); + m_WriteChunk.FromMemory.BufferLength = (ULONG)m_CurrentWriteBuffer.size(); + + ZeroMemory(&m_WriteIoContext.Overlapped, sizeof(OVERLAPPED)); + + StartThreadpoolIo(m_Iocp); + + ULONG Result = HttpSendResponseEntityBody(m_RequestQueueHandle, + m_RequestId, + HTTP_SEND_RESPONSE_FLAG_MORE_DATA, + 1, + &m_WriteChunk, + nullptr, + nullptr, + 0, + &m_WriteIoContext.Overlapped, + nullptr); + + if (Result != NO_ERROR && Result != ERROR_IO_PENDING) + { + CancelThreadpoolIo(m_Iocp); + m_OutstandingOps.fetch_sub(1, std::memory_order_relaxed); + + ZEN_LOG_DEBUG(WsHttpSysLog(), "WebSocket async write failed: {}", Result); + + { + RwLock::ExclusiveLockScope _(m_WriteLock); + m_WriteQueue.clear(); + m_IsWriting = false; + } + m_CurrentWriteBuffer.clear(); + + if (m_IsOpen.exchange(false)) + { + m_Handler.OnWebSocketClose(*this, 1006, "write error"); + } + + MaybeReleaseSelfRef(); + } +} + +void +WsHttpSysConnection::OnWriteCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) +{ + ZEN_UNUSED(NumberOfBytesTransferred); + + // Hold a transient ref to prevent mid-callback destruction + Ref<WsHttpSysConnection> Guard(this); + + m_OutstandingOps.fetch_sub(1, std::memory_order_relaxed); + m_CurrentWriteBuffer.clear(); + + if (IoResult != NO_ERROR) + { + ZEN_LOG_DEBUG(WsHttpSysLog(), "WebSocket write completion error: {}", IoResult); + + { + RwLock::ExclusiveLockScope _(m_WriteLock); + m_WriteQueue.clear(); + m_IsWriting = false; + } + + if (m_IsOpen.exchange(false)) + { + m_Handler.OnWebSocketClose(*this, 1006, "write error"); + } + + MaybeReleaseSelfRef(); + return; + } + + FlushWriteQueue(); +} + +////////////////////////////////////////////////////////////////////////// +// +// Send interface +// + +void +WsHttpSysConnection::SendText(std::string_view Text) +{ + if (!m_IsOpen.load(std::memory_order_relaxed)) + { + return; + } + + std::span<const uint8_t> Payload(reinterpret_cast<const uint8_t*>(Text.data()), Text.size()); + std::vector<uint8_t> Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kText, Payload); + EnqueueWrite(std::move(Frame)); +} + +void +WsHttpSysConnection::SendBinary(std::span<const uint8_t> Data) +{ + if (!m_IsOpen.load(std::memory_order_relaxed)) + { + return; + } + + std::vector<uint8_t> Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kBinary, Data); + EnqueueWrite(std::move(Frame)); +} + +void +WsHttpSysConnection::Close(uint16_t Code, std::string_view Reason) +{ + DoClose(Code, Reason); +} + +void +WsHttpSysConnection::DoClose(uint16_t Code, std::string_view Reason) +{ + if (!m_IsOpen.exchange(false)) + { + return; + } + + { + bool ShouldSendClose = false; + { + RwLock::ExclusiveLockScope _(m_WriteLock); + if (!m_CloseSent.exchange(true)) + { + ShouldSendClose = true; + } + } + if (ShouldSendClose) + { + std::vector<uint8_t> CloseFrame = WsFrameCodec::BuildCloseFrame(Code, Reason); + EnqueueWrite(std::move(CloseFrame)); + } + } + + m_Handler.OnWebSocketClose(*this, Code, Reason); + + // Cancel pending read I/O — completions drain via ERROR_OPERATION_ABORTED + HttpCancelHttpRequest(m_RequestQueueHandle, m_RequestId, nullptr); +} + +////////////////////////////////////////////////////////////////////////// +// +// Lifetime management +// + +void +WsHttpSysConnection::MaybeReleaseSelfRef() +{ + if (m_OutstandingOps.load(std::memory_order_relaxed) == 0 && !m_IsOpen.load(std::memory_order_relaxed)) + { + m_SelfRef = nullptr; + } +} + +void +WsHttpSysConnection::Disconnect() +{ + // Send final empty body with DISCONNECT to tell http.sys the connection is done + HttpSendResponseEntityBody(m_RequestQueueHandle, + m_RequestId, + HTTP_SEND_RESPONSE_FLAG_DISCONNECT, + 0, + nullptr, + nullptr, + nullptr, + 0, + nullptr, + nullptr); +} + +} // namespace zen + +#endif // ZEN_WITH_HTTPSYS diff --git a/src/zenhttp/servers/wshttpsys.h b/src/zenhttp/servers/wshttpsys.h new file mode 100644 index 000000000..6015e3873 --- /dev/null +++ b/src/zenhttp/servers/wshttpsys.h @@ -0,0 +1,107 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zenhttp/websocket.h> + +#include "httpsys_iocontext.h" + +#include <zencore/thread.h> + +#if ZEN_WITH_HTTPSYS +# define _WINSOCKAPI_ +# include <zencore/windows.h> +# include <http.h> + +# include <atomic> +# include <deque> +# include <vector> + +namespace zen { + +class HttpServer; + +/** + * WebSocket connection over an http.sys opaque-mode connection + * + * After the 101 Switching Protocols response is sent with + * HTTP_SEND_RESPONSE_FLAG_OPAQUE, http.sys stops parsing HTTP on the + * connection. Raw bytes are exchanged via HttpReceiveRequestEntityBody / + * HttpSendResponseEntityBody using the original RequestId. + * + * All I/O is performed asynchronously via the same IOCP threadpool used + * for normal http.sys traffic, eliminating per-connection threads. + * + * Lifetime is managed through intrusive reference counting (RefCounted). + * A self-reference (m_SelfRef) is held from Start() until all outstanding + * async operations have drained, preventing premature destruction. + */ +class WsHttpSysConnection : public WebSocketConnection +{ +public: + WsHttpSysConnection(HANDLE RequestQueueHandle, HTTP_REQUEST_ID RequestId, IWebSocketHandler& Handler, PTP_IO Iocp, HttpServer* Server); + ~WsHttpSysConnection() override; + + /** + * Start the async read loop. Must be called once after construction + * and after the 101 response has been sent. + */ + void Start(); + + /** + * Shut down the connection. Cancels pending I/O; IOCP completions + * will fire with ERROR_OPERATION_ABORTED and drain naturally. + */ + void Shutdown(); + + // WebSocketConnection interface + void SendText(std::string_view Text) override; + void SendBinary(std::span<const uint8_t> Data) override; + void Close(uint16_t Code, std::string_view Reason) override; + bool IsOpen() const override; + + // Called from IoCompletionCallback via tagged dispatch + void OnReadCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred); + void OnWriteCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred); + +private: + void IssueAsyncRead(); + void ProcessReceivedData(); + void EnqueueWrite(std::vector<uint8_t> Frame); + void FlushWriteQueue(); + void DoClose(uint16_t Code, std::string_view Reason); + void Disconnect(); + void MaybeReleaseSelfRef(); + + HANDLE m_RequestQueueHandle; + HTTP_REQUEST_ID m_RequestId; + IWebSocketHandler& m_Handler; + PTP_IO m_Iocp; + HttpServer* m_HttpServer; + + // Tagged OVERLAPPED contexts for concurrent read and write + HttpSysIoContext m_ReadIoContext{}; + HttpSysIoContext m_WriteIoContext{}; + + // Read state + std::vector<uint8_t> m_ReadBuffer; + std::vector<uint8_t> m_Accumulated; + + // Write state + RwLock m_WriteLock; + std::deque<std::vector<uint8_t>> m_WriteQueue; + std::vector<uint8_t> m_CurrentWriteBuffer; + HTTP_DATA_CHUNK m_WriteChunk{}; + bool m_IsWriting = false; + + // Lifetime management + std::atomic<int32_t> m_OutstandingOps{0}; + Ref<WsHttpSysConnection> m_SelfRef; + std::atomic<bool> m_ShutdownRequested{false}; + std::atomic<bool> m_IsOpen{true}; + std::atomic<bool> m_CloseSent{false}; +}; + +} // namespace zen + +#endif // ZEN_WITH_HTTPSYS diff --git a/src/zenhttp/servers/wstest.cpp b/src/zenhttp/servers/wstest.cpp new file mode 100644 index 000000000..59c46a418 --- /dev/null +++ b/src/zenhttp/servers/wstest.cpp @@ -0,0 +1,994 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#if ZEN_WITH_TESTS + +# include <zencore/scopeguard.h> +# include <zencore/testing.h> +# include <zencore/testutils.h> + +# include <zenhttp/httpserver.h> +# include <zenhttp/httpwsclient.h> +# include <zenhttp/websocket.h> + +# include "httpasio.h" +# include "wsframecodec.h" + +ZEN_THIRD_PARTY_INCLUDES_START +# if ZEN_PLATFORM_WINDOWS +# include <winsock2.h> +# else +# include <poll.h> +# include <sys/socket.h> +# endif +# include <asio.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + +# include <atomic> +# include <chrono> +# include <cstring> +# include <random> +# include <string> +# include <string_view> +# include <thread> +# include <vector> + +namespace zen { + +using namespace std::literals; + +////////////////////////////////////////////////////////////////////////// +// +// Unit tests: WsFrameCodec +// + +TEST_SUITE_BEGIN("http.wstest"); + +TEST_CASE("websocket.framecodec") +{ + SUBCASE("ComputeAcceptKey RFC 6455 test vector") + { + // RFC 6455 section 4.2.2 example + std::string AcceptKey = WsFrameCodec::ComputeAcceptKey("dGhlIHNhbXBsZSBub25jZQ=="); + CHECK_EQ(AcceptKey, "s3pPLMBiTxaQ9kYGzzhZRbK+xOo="); + } + + SUBCASE("BuildFrame and TryParseFrame roundtrip - text") + { + std::string_view Text = "Hello, WebSocket!"; + std::span<const uint8_t> Payload(reinterpret_cast<const uint8_t*>(Text.data()), Text.size()); + + std::vector<uint8_t> Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kText, Payload); + + // Server frames are unmasked — TryParseFrame should handle them + WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size()); + + CHECK(Result.IsValid); + CHECK_EQ(Result.BytesConsumed, Frame.size()); + CHECK(Result.Fin); + CHECK_EQ(Result.Opcode, WebSocketOpcode::kText); + CHECK_EQ(Result.Payload.size(), Text.size()); + CHECK_EQ(std::string_view(reinterpret_cast<const char*>(Result.Payload.data()), Result.Payload.size()), Text); + } + + SUBCASE("BuildFrame and TryParseFrame roundtrip - binary") + { + std::vector<uint8_t> BinaryData = {0x00, 0x01, 0x02, 0xFF, 0xFE, 0xFD}; + + std::vector<uint8_t> Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kBinary, BinaryData); + + WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size()); + + CHECK(Result.IsValid); + CHECK_EQ(Result.Opcode, WebSocketOpcode::kBinary); + CHECK_EQ(Result.Payload, BinaryData); + } + + SUBCASE("BuildFrame - medium payload (126-65535 bytes)") + { + std::vector<uint8_t> Payload(300, 0x42); + + std::vector<uint8_t> Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kBinary, Payload); + + WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size()); + + CHECK(Result.IsValid); + CHECK_EQ(Result.Payload.size(), 300u); + CHECK_EQ(Result.Payload, Payload); + } + + SUBCASE("BuildFrame - large payload (>65535 bytes)") + { + std::vector<uint8_t> Payload(70000, 0xAB); + + std::vector<uint8_t> Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kBinary, Payload); + + WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size()); + + CHECK(Result.IsValid); + CHECK_EQ(Result.Payload.size(), 70000u); + } + + SUBCASE("BuildCloseFrame roundtrip") + { + std::vector<uint8_t> Frame = WsFrameCodec::BuildCloseFrame(1000, "normal closure"); + + WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size()); + + CHECK(Result.IsValid); + CHECK_EQ(Result.Opcode, WebSocketOpcode::kClose); + REQUIRE(Result.Payload.size() >= 2); + + uint16_t Code = (uint16_t(Result.Payload[0]) << 8) | uint16_t(Result.Payload[1]); + CHECK_EQ(Code, 1000); + + std::string_view Reason(reinterpret_cast<const char*>(Result.Payload.data() + 2), Result.Payload.size() - 2); + CHECK_EQ(Reason, "normal closure"); + } + + SUBCASE("TryParseFrame - partial data returns invalid") + { + std::vector<uint8_t> Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kText, std::span<const uint8_t>{}); + + // Pass only 1 byte — not enough for a frame header + WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), 1); + CHECK_FALSE(Result.IsValid); + CHECK_EQ(Result.BytesConsumed, 0u); + } + + SUBCASE("TryParseFrame - empty payload") + { + std::vector<uint8_t> Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kText, std::span<const uint8_t>{}); + + WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size()); + + CHECK(Result.IsValid); + CHECK_EQ(Result.Opcode, WebSocketOpcode::kText); + CHECK(Result.Payload.empty()); + } + + SUBCASE("TryParseFrame - masked client frame") + { + // Build a masked frame manually as a client would send + // Frame: FIN=1, opcode=text, MASK=1, payload_len=5, mask_key=0x37FA213D, payload="Hello" + uint8_t MaskKey[4] = {0x37, 0xFA, 0x21, 0x3D}; + uint8_t MaskedPayload[5] = {}; + const char* Original = "Hello"; + for (int i = 0; i < 5; ++i) + { + MaskedPayload[i] = static_cast<uint8_t>(Original[i]) ^ MaskKey[i % 4]; + } + + std::vector<uint8_t> Frame; + Frame.push_back(0x81); // FIN + text + Frame.push_back(0x85); // MASK + len=5 + Frame.insert(Frame.end(), MaskKey, MaskKey + 4); + Frame.insert(Frame.end(), MaskedPayload, MaskedPayload + 5); + + WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size()); + + CHECK(Result.IsValid); + CHECK_EQ(Result.Opcode, WebSocketOpcode::kText); + CHECK_EQ(Result.Payload.size(), 5u); + CHECK_EQ(std::string_view(reinterpret_cast<const char*>(Result.Payload.data()), 5), "Hello"sv); + } + + SUBCASE("BuildMaskedFrame roundtrip - text") + { + std::string_view Text = "Hello, masked WebSocket!"; + std::span<const uint8_t> Payload(reinterpret_cast<const uint8_t*>(Text.data()), Text.size()); + + std::vector<uint8_t> Frame = WsFrameCodec::BuildMaskedFrame(WebSocketOpcode::kText, Payload); + + // Verify mask bit is set + CHECK((Frame[1] & 0x80) != 0); + + WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size()); + + CHECK(Result.IsValid); + CHECK_EQ(Result.BytesConsumed, Frame.size()); + CHECK(Result.Fin); + CHECK_EQ(Result.Opcode, WebSocketOpcode::kText); + CHECK_EQ(Result.Payload.size(), Text.size()); + CHECK_EQ(std::string_view(reinterpret_cast<const char*>(Result.Payload.data()), Result.Payload.size()), Text); + } + + SUBCASE("BuildMaskedFrame roundtrip - binary") + { + std::vector<uint8_t> BinaryData = {0x00, 0x01, 0x02, 0xFF, 0xFE, 0xFD}; + + std::vector<uint8_t> Frame = WsFrameCodec::BuildMaskedFrame(WebSocketOpcode::kBinary, BinaryData); + + CHECK((Frame[1] & 0x80) != 0); + + WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size()); + + CHECK(Result.IsValid); + CHECK_EQ(Result.Opcode, WebSocketOpcode::kBinary); + CHECK_EQ(Result.Payload, BinaryData); + } + + SUBCASE("BuildMaskedFrame - medium payload (126-65535 bytes)") + { + std::vector<uint8_t> Payload(300, 0x42); + + std::vector<uint8_t> Frame = WsFrameCodec::BuildMaskedFrame(WebSocketOpcode::kBinary, Payload); + + CHECK((Frame[1] & 0x80) != 0); + CHECK_EQ((Frame[1] & 0x7F), 126); // 16-bit extended length + + WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size()); + + CHECK(Result.IsValid); + CHECK_EQ(Result.Payload.size(), 300u); + CHECK_EQ(Result.Payload, Payload); + } + + SUBCASE("BuildMaskedFrame - large payload (>65535 bytes)") + { + std::vector<uint8_t> Payload(70000, 0xAB); + + std::vector<uint8_t> Frame = WsFrameCodec::BuildMaskedFrame(WebSocketOpcode::kBinary, Payload); + + CHECK((Frame[1] & 0x80) != 0); + CHECK_EQ((Frame[1] & 0x7F), 127); // 64-bit extended length + + WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size()); + + CHECK(Result.IsValid); + CHECK_EQ(Result.Payload.size(), 70000u); + } + + SUBCASE("BuildMaskedCloseFrame roundtrip") + { + std::vector<uint8_t> Frame = WsFrameCodec::BuildMaskedCloseFrame(1000, "normal closure"); + + CHECK((Frame[1] & 0x80) != 0); + + WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size()); + + CHECK(Result.IsValid); + CHECK_EQ(Result.Opcode, WebSocketOpcode::kClose); + REQUIRE(Result.Payload.size() >= 2); + + uint16_t Code = (uint16_t(Result.Payload[0]) << 8) | uint16_t(Result.Payload[1]); + CHECK_EQ(Code, 1000); + + std::string_view Reason(reinterpret_cast<const char*>(Result.Payload.data() + 2), Result.Payload.size() - 2); + CHECK_EQ(Reason, "normal closure"); + } +} + +////////////////////////////////////////////////////////////////////////// +// +// Integration tests: WebSocket over ASIO +// + +namespace { + + /** + * Helper: Build a masked client-to-server frame per RFC 6455 + */ + std::vector<uint8_t> BuildMaskedFrame(WebSocketOpcode Opcode, std::span<const uint8_t> Payload) + { + std::vector<uint8_t> Frame; + + // FIN + opcode + Frame.push_back(0x80 | static_cast<uint8_t>(Opcode)); + + // Payload length with mask bit set + if (Payload.size() < 126) + { + Frame.push_back(0x80 | static_cast<uint8_t>(Payload.size())); + } + else if (Payload.size() <= 0xFFFF) + { + Frame.push_back(0x80 | 126); + Frame.push_back(static_cast<uint8_t>((Payload.size() >> 8) & 0xFF)); + Frame.push_back(static_cast<uint8_t>(Payload.size() & 0xFF)); + } + else + { + Frame.push_back(0x80 | 127); + for (int i = 7; i >= 0; --i) + { + Frame.push_back(static_cast<uint8_t>((Payload.size() >> (i * 8)) & 0xFF)); + } + } + + // Mask key (use a fixed key for deterministic tests) + uint8_t MaskKey[4] = {0x12, 0x34, 0x56, 0x78}; + Frame.insert(Frame.end(), MaskKey, MaskKey + 4); + + // Masked payload + for (size_t i = 0; i < Payload.size(); ++i) + { + Frame.push_back(Payload[i] ^ MaskKey[i & 3]); + } + + return Frame; + } + + std::vector<uint8_t> BuildMaskedTextFrame(std::string_view Text) + { + std::span<const uint8_t> Payload(reinterpret_cast<const uint8_t*>(Text.data()), Text.size()); + return BuildMaskedFrame(WebSocketOpcode::kText, Payload); + } + + std::vector<uint8_t> BuildMaskedCloseFrame(uint16_t Code) + { + std::vector<uint8_t> Payload; + Payload.push_back(static_cast<uint8_t>((Code >> 8) & 0xFF)); + Payload.push_back(static_cast<uint8_t>(Code & 0xFF)); + return BuildMaskedFrame(WebSocketOpcode::kClose, Payload); + } + + /** + * Test service that implements IWebSocketHandler + */ + struct WsTestService : public HttpService, public IWebSocketHandler + { + const char* BaseUri() const override { return "/wstest/"; } + + void HandleRequest(HttpServerRequest& Request) override + { + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "hello from wstest"); + } + + // IWebSocketHandler + void OnWebSocketOpen(Ref<WebSocketConnection> Connection) override + { + m_OpenCount.fetch_add(1); + + m_ConnectionsLock.WithExclusiveLock([&] { m_Connections.push_back(Connection); }); + } + + void OnWebSocketMessage(WebSocketConnection& Conn, const WebSocketMessage& Msg) override + { + m_MessageCount.fetch_add(1); + + if (Msg.Opcode == WebSocketOpcode::kText) + { + std::string_view Text(static_cast<const char*>(Msg.Payload.Data()), Msg.Payload.Size()); + m_LastMessage = std::string(Text); + + // Echo the message back + Conn.SendText(Text); + } + } + + void OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code, [[maybe_unused]] std::string_view Reason) override + { + m_CloseCount.fetch_add(1); + m_LastCloseCode = Code; + + m_ConnectionsLock.WithExclusiveLock([&] { + auto It = std::remove_if(m_Connections.begin(), m_Connections.end(), [&Conn](const Ref<WebSocketConnection>& C) { + return C.Get() == &Conn; + }); + m_Connections.erase(It, m_Connections.end()); + }); + } + + void SendToAll(std::string_view Text) + { + RwLock::SharedLockScope _(m_ConnectionsLock); + for (auto& Conn : m_Connections) + { + if (Conn->IsOpen()) + { + Conn->SendText(Text); + } + } + } + + std::atomic<int> m_OpenCount{0}; + std::atomic<int> m_MessageCount{0}; + std::atomic<int> m_CloseCount{0}; + std::atomic<uint16_t> m_LastCloseCode{0}; + std::string m_LastMessage; + + RwLock m_ConnectionsLock; + std::vector<Ref<WebSocketConnection>> m_Connections; + }; + + /** + * Helper: Perform the WebSocket upgrade handshake on a raw TCP socket + * + * Returns true on success (101 response), false otherwise. + */ + bool DoWebSocketHandshake(asio::ip::tcp::socket& Sock, std::string_view Path, int Port) + { + // Send HTTP upgrade request + ExtendableStringBuilder<512> Request; + Request << "GET " << Path << " HTTP/1.1\r\n" + << "Host: 127.0.0.1:" << Port << "\r\n" + << "Upgrade: websocket\r\n" + << "Connection: Upgrade\r\n" + << "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n" + << "Sec-WebSocket-Version: 13\r\n" + << "\r\n"; + + std::string_view ReqStr = Request.ToView(); + + asio::write(Sock, asio::buffer(ReqStr.data(), ReqStr.size())); + + // Read the response (look for "101") + asio::streambuf ResponseBuf; + asio::read_until(Sock, ResponseBuf, "\r\n\r\n"); + + std::string Response(asio::buffers_begin(ResponseBuf.data()), asio::buffers_end(ResponseBuf.data())); + + return Response.find("101") != std::string::npos; + } + + /** + * Helper: Read a single server-to-client frame from a socket + * + * Uses a background thread with a synchronous ASIO read and a timeout. + */ + WsFrameParseResult ReadOneFrame(asio::ip::tcp::socket& Sock, int TimeoutMs = 5000) + { + std::vector<uint8_t> Buffer; + WsFrameParseResult Result; + std::atomic<bool> Done{false}; + + std::thread Reader([&] { + while (!Done.load()) + { + uint8_t Tmp[4096]; + asio::error_code Ec; + size_t BytesRead = Sock.read_some(asio::buffer(Tmp), Ec); + if (Ec || BytesRead == 0) + { + break; + } + + Buffer.insert(Buffer.end(), Tmp, Tmp + BytesRead); + + WsFrameParseResult Frame = WsFrameCodec::TryParseFrame(Buffer.data(), Buffer.size()); + if (Frame.IsValid) + { + Result = std::move(Frame); + Done.store(true); + return; + } + } + }); + + auto Deadline = std::chrono::steady_clock::now() + std::chrono::milliseconds(TimeoutMs); + while (!Done.load() && std::chrono::steady_clock::now() < Deadline) + { + Sleep(10); + } + + if (!Done.load()) + { + // Timeout — cancel the read + asio::error_code Ec; + Sock.cancel(Ec); + } + + if (Reader.joinable()) + { + Reader.join(); + } + + return Result; + } + +} // anonymous namespace + +TEST_CASE("websocket.integration") +{ + WsTestService TestService; + ScopedTemporaryDirectory TmpDir; + + Ref<HttpServer> Server = CreateHttpAsioServer(AsioConfig{}); + + int Port = Server->Initialize(0, TmpDir.Path()); + REQUIRE(Port != 0); + + Server->RegisterService(TestService); + + std::thread ServerThread([&]() { Server->Run(false); }); + + auto ServerGuard = MakeGuard([&]() { + Server->RequestExit(); + if (ServerThread.joinable()) + { + ServerThread.join(); + } + Server->Close(); + }); + + // Give server a moment to start accepting + Sleep(100); + + SUBCASE("handshake succeeds with 101") + { + asio::io_context IoCtx; + asio::ip::tcp::socket Sock(IoCtx); + Sock.connect(asio::ip::tcp::endpoint(asio::ip::make_address("127.0.0.1"), static_cast<uint16_t>(Port))); + + bool Ok = DoWebSocketHandshake(Sock, "/wstest/ws", Port); + CHECK(Ok); + + Sleep(50); + CHECK_EQ(TestService.m_OpenCount.load(), 1); + + Sock.close(); + } + + SUBCASE("normal HTTP still works alongside WebSocket service") + { + asio::io_context IoCtx; + asio::ip::tcp::socket Sock(IoCtx); + Sock.connect(asio::ip::tcp::endpoint(asio::ip::make_address("127.0.0.1"), static_cast<uint16_t>(Port))); + + // Send a normal HTTP GET (not upgrade) + std::string HttpReq = fmt::format( + "GET /wstest/hello HTTP/1.1\r\n" + "Host: 127.0.0.1:{}\r\n" + "Connection: close\r\n" + "\r\n", + Port); + + asio::write(Sock, asio::buffer(HttpReq)); + + asio::streambuf ResponseBuf; + asio::error_code Ec; + asio::read(Sock, ResponseBuf, asio::transfer_at_least(1), Ec); + + std::string Response(asio::buffers_begin(ResponseBuf.data()), asio::buffers_end(ResponseBuf.data())); + CHECK(Response.find("200") != std::string::npos); + } + + SUBCASE("echo message roundtrip") + { + asio::io_context IoCtx; + asio::ip::tcp::socket Sock(IoCtx); + Sock.connect(asio::ip::tcp::endpoint(asio::ip::make_address("127.0.0.1"), static_cast<uint16_t>(Port))); + + bool Ok = DoWebSocketHandshake(Sock, "/wstest/ws", Port); + REQUIRE(Ok); + Sleep(50); + + // Send a text message (masked, as client) + std::vector<uint8_t> Frame = BuildMaskedTextFrame("ping test"); + asio::write(Sock, asio::buffer(Frame)); + + // Read the echo reply + WsFrameParseResult Reply = ReadOneFrame(Sock); + REQUIRE(Reply.IsValid); + CHECK_EQ(Reply.Opcode, WebSocketOpcode::kText); + std::string_view ReplyText(reinterpret_cast<const char*>(Reply.Payload.data()), Reply.Payload.size()); + CHECK_EQ(ReplyText, "ping test"sv); + CHECK_EQ(TestService.m_MessageCount.load(), 1); + CHECK_EQ(TestService.m_LastMessage, "ping test"); + + Sock.close(); + } + + SUBCASE("server push to client") + { + asio::io_context IoCtx; + asio::ip::tcp::socket Sock(IoCtx); + Sock.connect(asio::ip::tcp::endpoint(asio::ip::make_address("127.0.0.1"), static_cast<uint16_t>(Port))); + + bool Ok = DoWebSocketHandshake(Sock, "/wstest/ws", Port); + REQUIRE(Ok); + Sleep(50); + + // Server pushes a message + TestService.SendToAll("server says hello"); + + WsFrameParseResult Frame = ReadOneFrame(Sock); + REQUIRE(Frame.IsValid); + CHECK_EQ(Frame.Opcode, WebSocketOpcode::kText); + std::string_view Text(reinterpret_cast<const char*>(Frame.Payload.data()), Frame.Payload.size()); + CHECK_EQ(Text, "server says hello"sv); + + Sock.close(); + } + + SUBCASE("client close handshake") + { + asio::io_context IoCtx; + asio::ip::tcp::socket Sock(IoCtx); + Sock.connect(asio::ip::tcp::endpoint(asio::ip::make_address("127.0.0.1"), static_cast<uint16_t>(Port))); + + bool Ok = DoWebSocketHandshake(Sock, "/wstest/ws", Port); + REQUIRE(Ok); + Sleep(50); + + // Send close frame + std::vector<uint8_t> CloseFrame = BuildMaskedCloseFrame(1000); + asio::write(Sock, asio::buffer(CloseFrame)); + + // Server should echo close back + WsFrameParseResult Reply = ReadOneFrame(Sock); + REQUIRE(Reply.IsValid); + CHECK_EQ(Reply.Opcode, WebSocketOpcode::kClose); + + Sleep(50); + CHECK_EQ(TestService.m_CloseCount.load(), 1); + CHECK_EQ(TestService.m_LastCloseCode.load(), 1000); + + Sock.close(); + } + + SUBCASE("multiple concurrent connections") + { + constexpr int NumClients = 5; + + asio::io_context IoCtx; + std::vector<asio::ip::tcp::socket> Sockets; + + for (int i = 0; i < NumClients; ++i) + { + Sockets.emplace_back(IoCtx); + Sockets.back().connect(asio::ip::tcp::endpoint(asio::ip::make_address("127.0.0.1"), static_cast<uint16_t>(Port))); + + bool Ok = DoWebSocketHandshake(Sockets.back(), "/wstest/ws", Port); + REQUIRE(Ok); + } + + Sleep(100); + CHECK_EQ(TestService.m_OpenCount.load(), NumClients); + + // Broadcast from server + TestService.SendToAll("broadcast"); + + // Each client should receive the message + for (int i = 0; i < NumClients; ++i) + { + WsFrameParseResult Frame = ReadOneFrame(Sockets[i]); + REQUIRE(Frame.IsValid); + CHECK_EQ(Frame.Opcode, WebSocketOpcode::kText); + std::string_view Text(reinterpret_cast<const char*>(Frame.Payload.data()), Frame.Payload.size()); + CHECK_EQ(Text, "broadcast"sv); + } + + // Close all + for (auto& S : Sockets) + { + S.close(); + } + } + + SUBCASE("service without IWebSocketHandler rejects upgrade") + { + // Register a plain HTTP service (no WebSocket) + struct PlainService : public HttpService + { + const char* BaseUri() const override { return "/plain/"; } + void HandleRequest(HttpServerRequest& Request) override { Request.WriteResponse(HttpResponseCode::OK); } + }; + + PlainService Plain; + Server->RegisterService(Plain); + + Sleep(50); + + asio::io_context IoCtx; + asio::ip::tcp::socket Sock(IoCtx); + Sock.connect(asio::ip::tcp::endpoint(asio::ip::make_address("127.0.0.1"), static_cast<uint16_t>(Port))); + + // Attempt WebSocket upgrade on the plain service + ExtendableStringBuilder<512> Request; + Request << "GET /plain/ws HTTP/1.1\r\n" + << "Host: 127.0.0.1:" << Port << "\r\n" + << "Upgrade: websocket\r\n" + << "Connection: Upgrade\r\n" + << "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n" + << "Sec-WebSocket-Version: 13\r\n" + << "\r\n"; + + std::string_view ReqStr = Request.ToView(); + asio::write(Sock, asio::buffer(ReqStr.data(), ReqStr.size())); + + asio::streambuf ResponseBuf; + asio::read_until(Sock, ResponseBuf, "\r\n\r\n"); + + std::string Response(asio::buffers_begin(ResponseBuf.data()), asio::buffers_end(ResponseBuf.data())); + + // Should NOT get 101 — should fall through to normal request handling + CHECK(Response.find("101") == std::string::npos); + + Sock.close(); + } + + SUBCASE("ping/pong auto-response") + { + asio::io_context IoCtx; + asio::ip::tcp::socket Sock(IoCtx); + Sock.connect(asio::ip::tcp::endpoint(asio::ip::make_address("127.0.0.1"), static_cast<uint16_t>(Port))); + + bool Ok = DoWebSocketHandshake(Sock, "/wstest/ws", Port); + REQUIRE(Ok); + Sleep(50); + + // Send a ping frame with payload "test" + std::string_view PingPayload = "test"; + std::span<const uint8_t> PingData(reinterpret_cast<const uint8_t*>(PingPayload.data()), PingPayload.size()); + std::vector<uint8_t> PingFrame = BuildMaskedFrame(WebSocketOpcode::kPing, PingData); + asio::write(Sock, asio::buffer(PingFrame)); + + // Should receive a pong with the same payload + WsFrameParseResult Reply = ReadOneFrame(Sock); + REQUIRE(Reply.IsValid); + CHECK_EQ(Reply.Opcode, WebSocketOpcode::kPong); + CHECK_EQ(Reply.Payload.size(), 4u); + std::string_view PongText(reinterpret_cast<const char*>(Reply.Payload.data()), Reply.Payload.size()); + CHECK_EQ(PongText, "test"sv); + + Sock.close(); + } + + SUBCASE("multiple messages in sequence") + { + asio::io_context IoCtx; + asio::ip::tcp::socket Sock(IoCtx); + Sock.connect(asio::ip::tcp::endpoint(asio::ip::make_address("127.0.0.1"), static_cast<uint16_t>(Port))); + + bool Ok = DoWebSocketHandshake(Sock, "/wstest/ws", Port); + REQUIRE(Ok); + Sleep(50); + + for (int i = 0; i < 10; ++i) + { + std::string Msg = fmt::format("message {}", i); + std::vector<uint8_t> Frame = BuildMaskedTextFrame(Msg); + asio::write(Sock, asio::buffer(Frame)); + + WsFrameParseResult Reply = ReadOneFrame(Sock); + REQUIRE(Reply.IsValid); + CHECK_EQ(Reply.Opcode, WebSocketOpcode::kText); + std::string_view ReplyText(reinterpret_cast<const char*>(Reply.Payload.data()), Reply.Payload.size()); + CHECK_EQ(ReplyText, Msg); + } + + CHECK_EQ(TestService.m_MessageCount.load(), 10); + + Sock.close(); + } +} + +////////////////////////////////////////////////////////////////////////// +// +// Integration tests: HttpWsClient +// + +namespace { + + struct TestWsClientHandler : public IWsClientHandler + { + void OnWsOpen() override { m_OpenCount.fetch_add(1); } + + void OnWsMessage(const WebSocketMessage& Msg) override + { + if (Msg.Opcode == WebSocketOpcode::kText) + { + std::string_view Text(static_cast<const char*>(Msg.Payload.Data()), Msg.Payload.Size()); + m_LastMessage = std::string(Text); + } + m_MessageCount.fetch_add(1); + } + + void OnWsClose(uint16_t Code, [[maybe_unused]] std::string_view Reason) override + { + m_CloseCount.fetch_add(1); + m_LastCloseCode = Code; + } + + std::atomic<int> m_OpenCount{0}; + std::atomic<int> m_MessageCount{0}; + std::atomic<int> m_CloseCount{0}; + std::atomic<uint16_t> m_LastCloseCode{0}; + std::string m_LastMessage; + }; + +} // anonymous namespace + +TEST_CASE("websocket.client") +{ + WsTestService TestService; + ScopedTemporaryDirectory TmpDir; + + Ref<HttpServer> Server = CreateHttpAsioServer(AsioConfig{}); + + int Port = Server->Initialize(0, TmpDir.Path()); + REQUIRE(Port != 0); + + Server->RegisterService(TestService); + + std::thread ServerThread([&]() { Server->Run(false); }); + + auto ServerGuard = MakeGuard([&]() { + Server->RequestExit(); + if (ServerThread.joinable()) + { + ServerThread.join(); + } + Server->Close(); + }); + + Sleep(100); + + SUBCASE("connect, echo, close") + { + TestWsClientHandler Handler; + std::string Url = fmt::format("ws://127.0.0.1:{}/wstest/ws", Port); + + HttpWsClient Client(Url, Handler); + Client.Connect(); + + // Wait for OnWsOpen + auto Deadline = std::chrono::steady_clock::now() + 5s; + while (Handler.m_OpenCount.load() == 0 && std::chrono::steady_clock::now() < Deadline) + { + Sleep(10); + } + REQUIRE_EQ(Handler.m_OpenCount.load(), 1); + CHECK(Client.IsOpen()); + + // Send text, expect echo + Client.SendText("hello from client"); + + Deadline = std::chrono::steady_clock::now() + 5s; + while (Handler.m_MessageCount.load() == 0 && std::chrono::steady_clock::now() < Deadline) + { + Sleep(10); + } + CHECK_EQ(Handler.m_MessageCount.load(), 1); + CHECK_EQ(Handler.m_LastMessage, "hello from client"); + + // Close + Client.Close(1000, "done"); + + Deadline = std::chrono::steady_clock::now() + 5s; + while (Handler.m_CloseCount.load() == 0 && std::chrono::steady_clock::now() < Deadline) + { + Sleep(10); + } + + // The server echoes the close frame, which triggers OnWsClose on the client side + // with the server's close code. Allow the connection to settle. + Sleep(50); + CHECK_FALSE(Client.IsOpen()); + } + + SUBCASE("connect to bad port") + { + TestWsClientHandler Handler; + std::string Url = "ws://127.0.0.1:1/wstest/ws"; + + HttpWsClient Client(Url, Handler, HttpWsClientSettings{.ConnectTimeout = std::chrono::milliseconds(2000)}); + Client.Connect(); + + auto Deadline = std::chrono::steady_clock::now() + 5s; + while (Handler.m_CloseCount.load() == 0 && std::chrono::steady_clock::now() < Deadline) + { + Sleep(10); + } + + CHECK_EQ(Handler.m_CloseCount.load(), 1); + CHECK_EQ(Handler.m_LastCloseCode.load(), 1006); + CHECK_EQ(Handler.m_OpenCount.load(), 0); + } + + SUBCASE("server-initiated close") + { + TestWsClientHandler Handler; + std::string Url = fmt::format("ws://127.0.0.1:{}/wstest/ws", Port); + + HttpWsClient Client(Url, Handler); + Client.Connect(); + + auto Deadline = std::chrono::steady_clock::now() + 5s; + while (Handler.m_OpenCount.load() == 0 && std::chrono::steady_clock::now() < Deadline) + { + Sleep(10); + } + REQUIRE_EQ(Handler.m_OpenCount.load(), 1); + + // Copy connections then close them outside the lock to avoid deadlocking + // with OnWebSocketClose which acquires an exclusive lock + std::vector<Ref<WebSocketConnection>> Conns; + TestService.m_ConnectionsLock.WithSharedLock([&] { Conns = TestService.m_Connections; }); + for (auto& Conn : Conns) + { + Conn->Close(1001, "going away"); + } + + Deadline = std::chrono::steady_clock::now() + 5s; + while (Handler.m_CloseCount.load() == 0 && std::chrono::steady_clock::now() < Deadline) + { + Sleep(10); + } + + CHECK_EQ(Handler.m_CloseCount.load(), 1); + CHECK_EQ(Handler.m_LastCloseCode.load(), 1001); + CHECK_FALSE(Client.IsOpen()); + } +} + +TEST_CASE("websocket.client.unixsocket") +{ + WsTestService TestService; + ScopedTemporaryDirectory TmpDir; + std::string SocketPath = (TmpDir.Path() / "ws.sock").string(); + + Ref<HttpServer> Server = CreateHttpAsioServer(AsioConfig{.UnixSocketPath = SocketPath}); + + int Port = Server->Initialize(0, TmpDir.Path()); + REQUIRE(Port != 0); + + Server->RegisterService(TestService); + + std::thread ServerThread([&]() { Server->Run(false); }); + + auto ServerGuard = MakeGuard([&]() { + Server->RequestExit(); + if (ServerThread.joinable()) + { + ServerThread.join(); + } + Server->Close(); + }); + + Sleep(100); + + SUBCASE("connect, echo, close over unix socket") + { + TestWsClientHandler Handler; + HttpWsClientSettings Settings; + Settings.UnixSocketPath = SocketPath; + + HttpWsClient Client("ws://localhost/wstest/ws", Handler, Settings); + Client.Connect(); + + // Wait for OnWsOpen + auto Deadline = std::chrono::steady_clock::now() + 5s; + while (Handler.m_OpenCount.load() == 0 && std::chrono::steady_clock::now() < Deadline) + { + Sleep(10); + } + REQUIRE_EQ(Handler.m_OpenCount.load(), 1); + CHECK(Client.IsOpen()); + + // Send text, expect echo + Client.SendText("hello over unix socket"); + + Deadline = std::chrono::steady_clock::now() + 5s; + while (Handler.m_MessageCount.load() == 0 && std::chrono::steady_clock::now() < Deadline) + { + Sleep(10); + } + CHECK_EQ(Handler.m_MessageCount.load(), 1); + CHECK_EQ(Handler.m_LastMessage, "hello over unix socket"); + + // Close + Client.Close(1000, "done"); + + Deadline = std::chrono::steady_clock::now() + 5s; + while (Handler.m_CloseCount.load() == 0 && std::chrono::steady_clock::now() < Deadline) + { + Sleep(10); + } + + Sleep(50); + CHECK_FALSE(Client.IsOpen()); + } +} + +TEST_SUITE_END(); + +void +websocket_forcelink() +{ +} + +} // namespace zen + +#endif // ZEN_WITH_TESTS diff --git a/src/zenhttp/transports/asiotransport.cpp b/src/zenhttp/transports/asiotransport.cpp index 23ac1bc8b..d5413b9af 100644 --- a/src/zenhttp/transports/asiotransport.cpp +++ b/src/zenhttp/transports/asiotransport.cpp @@ -47,10 +47,10 @@ private: uint16_t m_BasePort = 8558; int m_ThreadCount = 0; - asio::io_service m_IoService; - asio::io_service::work m_Work{m_IoService}; - std::unique_ptr<AsioTransportAcceptor> m_Acceptor; - std::vector<std::thread> m_ThreadPool; + asio::io_context m_IoService; + asio::executor_work_guard<asio::io_context::executor_type> m_Work{m_IoService.get_executor()}; + std::unique_ptr<AsioTransportAcceptor> m_Acceptor; + std::vector<std::thread> m_ThreadPool; }; struct AsioTransportConnection : public TransportConnection, std::enable_shared_from_this<AsioTransportConnection> @@ -85,7 +85,7 @@ private: struct AsioTransportAcceptor { - AsioTransportAcceptor(TransportServer* ServerInterface, asio::io_service& IoService, uint16_t BasePort); + AsioTransportAcceptor(TransportServer* ServerInterface, asio::io_context& IoService, uint16_t BasePort); ~AsioTransportAcceptor(); void Start(); @@ -95,7 +95,7 @@ struct AsioTransportAcceptor private: TransportServer* m_ServerInterface = nullptr; - asio::io_service& m_IoService; + asio::io_context& m_IoService; asio::ip::tcp::acceptor m_Acceptor; std::atomic<bool> m_IsStopped{false}; @@ -104,7 +104,7 @@ private: ////////////////////////////////////////////////////////////////////////// -AsioTransportAcceptor::AsioTransportAcceptor(TransportServer* ServerInterface, asio::io_service& IoService, uint16_t BasePort) +AsioTransportAcceptor::AsioTransportAcceptor(TransportServer* ServerInterface, asio::io_context& IoService, uint16_t BasePort) : m_ServerInterface(ServerInterface) , m_IoService(IoService) , m_Acceptor(m_IoService, asio::ip::tcp::v6()) diff --git a/src/zenhttp/transports/dlltransport.cpp b/src/zenhttp/transports/dlltransport.cpp index 9135d5425..489324aba 100644 --- a/src/zenhttp/transports/dlltransport.cpp +++ b/src/zenhttp/transports/dlltransport.cpp @@ -72,20 +72,36 @@ DllTransportLogger::DllTransportLogger(std::string_view PluginName) : m_PluginNa void DllTransportLogger::LogMessage(LogLevel PluginLogLevel, const char* Message) { - logging::level::LogLevel Level; - // clang-format off switch (PluginLogLevel) { - case LogLevel::Trace: Level = logging::level::Trace; break; - case LogLevel::Debug: Level = logging::level::Debug; break; - case LogLevel::Info: Level = logging::level::Info; break; - case LogLevel::Warn: Level = logging::level::Warn; break; - case LogLevel::Err: Level = logging::level::Err; break; - case LogLevel::Critical: Level = logging::level::Critical; break; - default: Level = logging::level::Off; break; + case LogLevel::Trace: + ZEN_TRACE("[{}] {}", m_PluginName, Message); + return; + + case LogLevel::Debug: + ZEN_DEBUG("[{}] {}", m_PluginName, Message); + return; + + case LogLevel::Info: + ZEN_INFO("[{}] {}", m_PluginName, Message); + return; + + case LogLevel::Warn: + ZEN_WARN("[{}] {}", m_PluginName, Message); + return; + + case LogLevel::Err: + ZEN_ERROR("[{}] {}", m_PluginName, Message); + return; + + case LogLevel::Critical: + ZEN_CRITICAL("[{}] {}", m_PluginName, Message); + return; + + default: + ZEN_UNUSED(Message); + break; } - // clang-format on - ZEN_LOG(Log(), Level, "[{}] {}", m_PluginName, Message) } uint32_t diff --git a/src/zenhttp/xmake.lua b/src/zenhttp/xmake.lua index 78876d21b..b4c65ea96 100644 --- a/src/zenhttp/xmake.lua +++ b/src/zenhttp/xmake.lua @@ -6,11 +6,22 @@ target('zenhttp') add_headerfiles("**.h") add_files("**.cpp") add_files("servers/httpsys.cpp", {unity_ignored=true}) + add_files("servers/wshttpsys.cpp", {unity_ignored=true}) add_includedirs("include", {public=true}) - add_deps("zencore", "zentelemetry", "transport-sdk", "asio", "cpr") + add_deps("zencore", "zentelemetry", "transport-sdk", "asio") + if has_config("zencpr") then + add_deps("cpr") + else + remove_files("clients/httpclientcpr.cpp") + end add_packages("http_parser", "json11") add_options("httpsys") + if is_plat("linux", "macosx") then + add_packages("openssl3") + end + if is_plat("linux") then add_syslinks("dl") -- TODO: is libdl needed? end + diff --git a/src/zenhttp/zenhttp.cpp b/src/zenhttp/zenhttp.cpp index 0b5408453..3ac8eea8d 100644 --- a/src/zenhttp/zenhttp.cpp +++ b/src/zenhttp/zenhttp.cpp @@ -16,8 +16,10 @@ zenhttp_forcelinktests() { http_forcelink(); httpclient_forcelink(); + httpclient_test_forcelink(); forcelink_packageformat(); passwordsecurity_forcelink(); + websocket_forcelink(); } } // namespace zen |