diff options
| author | Liam Mitchell <[email protected]> | 2026-03-09 19:06:36 -0700 |
|---|---|---|
| committer | Liam Mitchell <[email protected]> | 2026-03-09 19:06:36 -0700 |
| commit | d1abc50ee9d4fb72efc646e17decafea741caa34 (patch) | |
| tree | e4288e00f2f7ca0391b83d986efcb69d3ba66a83 /src/zenhttp | |
| parent | Allow requests with invalid content-types unless specified in command line or... (diff) | |
| parent | updated chunk–block analyser (#818) (diff) | |
| download | zen-d1abc50ee9d4fb72efc646e17decafea741caa34.tar.xz zen-d1abc50ee9d4fb72efc646e17decafea741caa34.zip | |
Merge branch 'main' into lm/restrict-content-type
Diffstat (limited to 'src/zenhttp')
49 files changed, 7522 insertions, 673 deletions
diff --git a/src/zenhttp/auth/oidc.cpp b/src/zenhttp/auth/oidc.cpp index 38e7586ad..23bbc17e8 100644 --- a/src/zenhttp/auth/oidc.cpp +++ b/src/zenhttp/auth/oidc.cpp @@ -32,6 +32,25 @@ namespace details { using namespace std::literals; +static std::string +FormUrlEncode(std::string_view Input) +{ + std::string Result; + Result.reserve(Input.size()); + for (char C : Input) + { + if ((C >= 'A' && C <= 'Z') || (C >= 'a' && C <= 'z') || (C >= '0' && C <= '9') || C == '-' || C == '_' || C == '.' || C == '~') + { + Result.push_back(C); + } + else + { + Result.append(fmt::format("%{:02X}", static_cast<uint8_t>(C))); + } + } + return Result; +} + OidcClient::OidcClient(const OidcClient::Options& Options) { m_BaseUrl = std::string(Options.BaseUrl); @@ -67,6 +86,8 @@ OidcClient::Initialize() .TokenEndpoint = Json["token_endpoint"].string_value(), .UserInfoEndpoint = Json["userinfo_endpoint"].string_value(), .RegistrationEndpoint = Json["registration_endpoint"].string_value(), + .EndSessionEndpoint = Json["end_session_endpoint"].string_value(), + .DeviceAuthorizationEndpoint = Json["device_authorization_endpoint"].string_value(), .JwksUri = Json["jwks_uri"].string_value(), .SupportedResponseTypes = details::ToStringArray(Json["response_types_supported"]), .SupportedResponseModes = details::ToStringArray(Json["response_modes_supported"]), @@ -81,7 +102,8 @@ OidcClient::Initialize() OidcClient::RefreshTokenResult OidcClient::RefreshToken(std::string_view RefreshToken) { - const std::string Body = fmt::format("grant_type=refresh_token&refresh_token={}&client_id={}", RefreshToken, m_ClientId); + const std::string Body = + fmt::format("grant_type=refresh_token&refresh_token={}&client_id={}", FormUrlEncode(RefreshToken), FormUrlEncode(m_ClientId)); HttpClient Http{m_Config.TokenEndpoint}; diff --git a/src/zenhttp/clients/httpclientcommon.cpp b/src/zenhttp/clients/httpclientcommon.cpp index 47425e014..6f4c67dd0 100644 --- a/src/zenhttp/clients/httpclientcommon.cpp +++ b/src/zenhttp/clients/httpclientcommon.cpp @@ -142,7 +142,10 @@ namespace detail { DataSize -= CopySize; if (m_CacheBufferOffset == CacheBufferSize) { - AppendData(m_CacheBuffer, CacheBufferSize); + if (std::error_code Ec = AppendData(m_CacheBuffer, CacheBufferSize)) + { + return Ec; + } if (DataSize > 0) { ZEN_ASSERT(DataSize < CacheBufferSize); @@ -382,6 +385,177 @@ namespace detail { return Result; } + MultipartBoundaryParser::MultipartBoundaryParser() : BoundaryEndMatcher("--"), HeaderEndMatcher("\r\n\r\n") {} + + bool MultipartBoundaryParser::Init(const std::string_view ContentTypeHeaderValue) + { + std::string LowerCaseValue = ToLower(ContentTypeHeaderValue); + if (LowerCaseValue.starts_with("multipart/byteranges")) + { + size_t BoundaryPos = LowerCaseValue.find("boundary="); + if (BoundaryPos != std::string::npos) + { + // Yes, we do a substring of the non-lowercase value string as we want the exact boundary string + std::string_view BoundaryName = std::string_view(ContentTypeHeaderValue).substr(BoundaryPos + 9); + size_t BoundaryEnd = std::string::npos; + while (!BoundaryName.empty() && BoundaryName[0] == ' ') + { + BoundaryName = BoundaryName.substr(1); + } + if (!BoundaryName.empty()) + { + if (BoundaryName.size() > 2 && BoundaryName.front() == '"' && BoundaryName.back() == '"') + { + BoundaryEnd = BoundaryName.find('"', 1); + if (BoundaryEnd != std::string::npos) + { + BoundaryBeginMatcher.Init(fmt::format("\r\n--{}", BoundaryName.substr(1, BoundaryEnd - 1))); + return true; + } + } + else + { + BoundaryEnd = BoundaryName.find_first_of(" \r\n"); + BoundaryBeginMatcher.Init(fmt::format("\r\n--{}", BoundaryName.substr(0, BoundaryEnd))); + return true; + } + } + } + } + return false; + } + + void MultipartBoundaryParser::ParseInput(std::string_view data) + { + const char* InputPtr = data.data(); + size_t InputLength = data.length(); + size_t ScanPos = 0; + while (ScanPos < InputLength) + { + const char ScanChar = InputPtr[ScanPos]; + if (BoundaryBeginMatcher.MatchState == IncrementalStringMatcher::EMatchState::Complete) + { + if (PayloadOffset + ScanPos < (BoundaryBeginMatcher.GetMatchEndOffset() + BoundaryEndMatcher.GetMatchString().length())) + { + BoundaryEndMatcher.Match(PayloadOffset + ScanPos, ScanChar); + if (BoundaryEndMatcher.MatchState == IncrementalStringMatcher::EMatchState::Complete) + { + BoundaryBeginMatcher.Reset(); + HeaderEndMatcher.Reset(); + BoundaryEndMatcher.Reset(); + BoundaryHeader.Reset(); + break; + } + } + + BoundaryHeader.Append(ScanChar); + + HeaderEndMatcher.Match(PayloadOffset + ScanPos, ScanChar); + + if (HeaderEndMatcher.MatchState == IncrementalStringMatcher::EMatchState::Complete) + { + const uint64_t HeaderStartOffset = BoundaryBeginMatcher.GetMatchEndOffset(); + const uint64_t HeaderEndOffset = HeaderEndMatcher.GetMatchStartOffset(); + const uint64_t HeaderLength = HeaderEndOffset - HeaderStartOffset; + std::string_view HeaderText(BoundaryHeader.ToView().substr(0, HeaderLength)); + + uint64_t OffsetInPayload = PayloadOffset + ScanPos + 1; + + uint64_t RangeOffset = 0; + uint64_t RangeLength = 0; + HttpContentType ContentType = HttpContentType::kBinary; + + ForEachStrTok(HeaderText, "\r\n", [&](std::string_view Line) { + const std::pair<std::string_view, std::string_view> KeyAndValue = GetHeaderKeyAndValue(Line); + const std::string_view Key = KeyAndValue.first; + const std::string_view Value = KeyAndValue.second; + if (Key == "Content-Range") + { + std::pair<uint64_t, uint64_t> ContentRange = ParseContentRange(Value); + if (ContentRange.second != 0) + { + RangeOffset = ContentRange.first; + RangeLength = ContentRange.second; + } + } + else if (Key == "Content-Type") + { + ContentType = ParseContentType(Value); + } + + return true; + }); + + if (RangeLength > 0) + { + Boundaries.push_back(HttpClient::Response::MultipartBoundary{.OffsetInPayload = OffsetInPayload, + .RangeOffset = RangeOffset, + .RangeLength = RangeLength, + .ContentType = ContentType}); + } + + BoundaryBeginMatcher.Reset(); + HeaderEndMatcher.Reset(); + BoundaryEndMatcher.Reset(); + BoundaryHeader.Reset(); + } + } + else + { + BoundaryBeginMatcher.Match(PayloadOffset + ScanPos, ScanChar); + } + ScanPos++; + } + PayloadOffset += InputLength; + } + + std::pair<std::string_view, std::string_view> GetHeaderKeyAndValue(std::string_view HeaderString) + { + size_t DelimiterPos = HeaderString.find(':'); + if (DelimiterPos != std::string::npos) + { + std::string_view Key = HeaderString.substr(0, DelimiterPos); + constexpr AsciiSet WhitespaceCharacters(" \v\f\t\r\n"); + Key = AsciiSet::TrimSuffixWith(Key, WhitespaceCharacters); + Key = AsciiSet::TrimPrefixWith(Key, WhitespaceCharacters); + + std::string_view Value = HeaderString.substr(DelimiterPos + 1); + Value = AsciiSet::TrimSuffixWith(Value, WhitespaceCharacters); + Value = AsciiSet::TrimPrefixWith(Value, WhitespaceCharacters); + return std::make_pair(Key, Value); + } + return std::make_pair(HeaderString, std::string_view{}); + } + + std::pair<uint64_t, uint64_t> ParseContentRange(std::string_view Value) + { + if (Value.starts_with("bytes ")) + { + size_t RangeSplitPos = Value.find('-', 6); + if (RangeSplitPos != std::string::npos) + { + size_t RangeEndLength = Value.find('/', RangeSplitPos + 1); + if (RangeEndLength == std::string::npos) + { + RangeEndLength = Value.length() - (RangeSplitPos + 1); + } + else + { + RangeEndLength = RangeEndLength - (RangeSplitPos + 1); + } + std::optional<size_t> RequestedRangeStart = ParseInt<size_t>(Value.substr(6, RangeSplitPos - 6)); + std::optional<size_t> RequestedRangeEnd = ParseInt<size_t>(Value.substr(RangeSplitPos + 1, RangeEndLength)); + if (RequestedRangeStart.has_value() && RequestedRangeEnd.has_value()) + { + uint64_t RangeOffset = RequestedRangeStart.value(); + uint64_t RangeLength = RequestedRangeEnd.value() - RangeOffset + 1; + return std::make_pair(RangeOffset, RangeLength); + } + } + } + return {0, 0}; + } + } // namespace detail } // namespace zen @@ -423,6 +597,8 @@ namespace testutil { } // namespace testutil +TEST_SUITE_BEGIN("http.httpclientcommon"); + TEST_CASE("BufferedReadFileStream") { ScopedTemporaryDirectory TmpDir; @@ -470,5 +646,150 @@ TEST_CASE("CompositeBufferReadStream") CHECK_EQ(IoHash::HashBuffer(Data), testutil::HashComposite(Data)); } +TEST_CASE("MultipartBoundaryParser") +{ + uint64_t Range1Offset = 2638; + uint64_t Range1Length = (5111437 - Range1Offset) + 1; + + uint64_t Range2Offset = 5118199; + uint64_t Range2Length = (9147741 - Range2Offset) + 1; + + std::string_view ContentTypeHeaderValue1 = "multipart/byteranges; boundary=00000000000000019229"; + std::string_view ContentTypeHeaderValue2 = "multipart/byteranges; boundary=\"00000000000000019229\""; + + { + std::string_view Example1 = + "\r\n--00000000000000019229\r\n" + "Content-Type: application/x-ue-comp\r\n" + "Content-Range: bytes 2638-5111437/44369878\r\n" + "\r\n" + "datadatadatadata" + "\r\n--00000000000000019229\r\n" + "Content-Type: application/x-ue-comp\r\n" + "Content-Range: bytes 5118199-9147741/44369878\r\n" + "\r\n" + "ditaditadita" + "\r\n--00000000000000019229--"; + + detail::MultipartBoundaryParser ParserExample1; + ParserExample1.Init(ContentTypeHeaderValue1); + + const size_t InputWindow = 7; + for (size_t Offset = 0; Offset < Example1.length(); Offset += InputWindow) + { + ParserExample1.ParseInput(Example1.substr(Offset, Min(Example1.length() - Offset, InputWindow))); + } + + CHECK(ParserExample1.Boundaries.size() == 2); + + CHECK(ParserExample1.Boundaries[0].RangeOffset == Range1Offset); + CHECK(ParserExample1.Boundaries[0].RangeLength == Range1Length); + CHECK(ParserExample1.Boundaries[1].RangeOffset == Range2Offset); + CHECK(ParserExample1.Boundaries[1].RangeLength == Range2Length); + } + + { + std::string_view Example2 = + "\r\n--00000000000000019229\r\n" + "Content-Type: application/x-ue-comp\r\n" + "Content-Range: bytes 2638-5111437/*\r\n" + "\r\n" + "datadatadatadata" + "\r\n--00000000000000019229\r\n" + "Content-Type: application/x-ue-comp\r\n" + "Content-Range: bytes 5118199-9147741/44369878\r\n" + "\r\n" + "ditaditadita" + "\r\n--00000000000000019229--"; + + detail::MultipartBoundaryParser ParserExample2; + ParserExample2.Init(ContentTypeHeaderValue1); + + const size_t InputWindow = 3; + for (size_t Offset = 0; Offset < Example2.length(); Offset += InputWindow) + { + std::string_view Window = Example2.substr(Offset, Min(Example2.length() - Offset, InputWindow)); + ParserExample2.ParseInput(Window); + } + + CHECK(ParserExample2.Boundaries.size() == 2); + + CHECK(ParserExample2.Boundaries[0].RangeOffset == Range1Offset); + CHECK(ParserExample2.Boundaries[0].RangeLength == Range1Length); + CHECK(ParserExample2.Boundaries[1].RangeOffset == Range2Offset); + CHECK(ParserExample2.Boundaries[1].RangeLength == Range2Length); + } + + { + std::string_view Example3 = + "\r\n--00000000000000019229\r\n" + "Content-Type: application/x-ue-comp\r\n" + "Content-Range: bytes 2638-5111437/*\r\n" + "\r\n" + "datadatadatadata" + "\r\n--00000000000000019229\r\n" + "Content-Type: application/x-ue-comp\r\n" + "Content-Range: bytes 5118199-9147741/44369878\r\n" + "\r\n" + "ditaditadita"; + + detail::MultipartBoundaryParser ParserExample3; + ParserExample3.Init(ContentTypeHeaderValue2); + + const size_t InputWindow = 31; + for (size_t Offset = 0; Offset < Example3.length(); Offset += InputWindow) + { + ParserExample3.ParseInput(Example3.substr(Offset, Min(Example3.length() - Offset, InputWindow))); + } + + CHECK(ParserExample3.Boundaries.size() == 2); + + CHECK(ParserExample3.Boundaries[0].RangeOffset == Range1Offset); + CHECK(ParserExample3.Boundaries[0].RangeLength == Range1Length); + CHECK(ParserExample3.Boundaries[1].RangeOffset == Range2Offset); + CHECK(ParserExample3.Boundaries[1].RangeLength == Range2Length); + } + + { + std::string_view Example4 = + "\r\n--00000000000000019229\r\n" + "Content-Type: application/x-ue-comp\r\n" + "Content-Range: bytes 2638-5111437/*\r\n" + "Not: really\r\n" + "\r\n" + "datadatadatadata" + "\r\n--000000000bait0019229\r\n" + "\r\n--00\r\n--000000000bait001922\r\n" + "\r\n\r\n\r\r\n--00000000000000019229\r\n" + "Content-Type: application/x-ue-comp\r\n" + "Content-Range: bytes 5118199-9147741/44369878\r\n" + "\r\n" + "ditaditadita" + "Content-Type: application/x-ue-comp\r\n" + "ditaditadita" + "Content-Range: bytes 5118199-9147741/44369878\r\n" + "\r\n---\r\n--00000000000000019229--"; + + detail::MultipartBoundaryParser ParserExample4; + ParserExample4.Init(ContentTypeHeaderValue1); + + const size_t InputWindow = 3; + for (size_t Offset = 0; Offset < Example4.length(); Offset += InputWindow) + { + std::string_view Window = Example4.substr(Offset, Min(Example4.length() - Offset, InputWindow)); + ParserExample4.ParseInput(Window); + } + + CHECK(ParserExample4.Boundaries.size() == 2); + + CHECK(ParserExample4.Boundaries[0].RangeOffset == Range1Offset); + CHECK(ParserExample4.Boundaries[0].RangeLength == Range1Length); + CHECK(ParserExample4.Boundaries[1].RangeOffset == Range2Offset); + CHECK(ParserExample4.Boundaries[1].RangeLength == Range2Length); + } +} + +TEST_SUITE_END(); + } // namespace zen #endif diff --git a/src/zenhttp/clients/httpclientcommon.h b/src/zenhttp/clients/httpclientcommon.h index 1d0b7f9ea..5ed946541 100644 --- a/src/zenhttp/clients/httpclientcommon.h +++ b/src/zenhttp/clients/httpclientcommon.h @@ -3,6 +3,7 @@ #pragma once #include <zencore/compositebuffer.h> +#include <zencore/string.h> #include <zencore/trace.h> #include <zenhttp/httpclient.h> @@ -87,7 +88,7 @@ namespace detail { std::error_code Write(std::string_view DataString); IoBuffer DetachToIoBuffer(); IoBuffer BorrowIoBuffer(); - inline uint64_t GetSize() const { return m_WriteOffset; } + inline uint64_t GetSize() const { return m_WriteOffset + m_CacheBufferOffset; } void ResetWritePos(uint64_t WriteOffset); private: @@ -143,6 +144,118 @@ namespace detail { uint64_t m_BytesLeftInSegment; }; + class IncrementalStringMatcher + { + public: + enum class EMatchState + { + None, + Partial, + Complete + }; + + EMatchState MatchState = EMatchState::None; + + IncrementalStringMatcher() {} + + IncrementalStringMatcher(std::string&& InMatchString) : MatchString(std::move(InMatchString)) + { + RawMatchString = MatchString.data(); + } + + void Init(std::string&& InMatchString) + { + MatchString = std::move(InMatchString); + RawMatchString = MatchString.data(); + } + + inline void Reset() + { + MatchLength = 0; + MatchStartOffset = 0; + MatchState = EMatchState::None; + } + + inline uint64_t GetMatchEndOffset() const + { + if (MatchState == EMatchState::Complete) + { + return MatchStartOffset + MatchString.length(); + } + return 0; + } + + inline uint64_t GetMatchStartOffset() const + { + ZEN_ASSERT(MatchState == EMatchState::Complete); + return MatchStartOffset; + } + + void Match(uint64_t Offset, char C) + { + ZEN_ASSERT_SLOW(RawMatchString != nullptr); + + if (MatchState == EMatchState::Complete) + { + Reset(); + } + if (C == RawMatchString[MatchLength]) + { + if (MatchLength == 0) + { + MatchStartOffset = Offset; + } + MatchLength++; + if (MatchLength == MatchString.length()) + { + MatchState = EMatchState::Complete; + } + else + { + MatchState = EMatchState::Partial; + } + } + else if (MatchLength != 0) + { + Reset(); + Match(Offset, C); + } + else + { + Reset(); + } + } + inline const std::string& GetMatchString() const { return MatchString; } + + private: + std::string MatchString; + const char* RawMatchString = nullptr; + uint64_t MatchLength = 0; + + uint64_t MatchStartOffset = 0; + }; + + class MultipartBoundaryParser + { + public: + std::vector<HttpClient::Response::MultipartBoundary> Boundaries; + + MultipartBoundaryParser(); + bool Init(const std::string_view ContentTypeHeaderValue); + void ParseInput(std::string_view data); + + private: + IncrementalStringMatcher BoundaryBeginMatcher; + IncrementalStringMatcher BoundaryEndMatcher; + IncrementalStringMatcher HeaderEndMatcher; + + ExtendableStringBuilder<64> BoundaryHeader; + uint64_t PayloadOffset = 0; + }; + + std::pair<std::string_view, std::string_view> GetHeaderKeyAndValue(std::string_view HeaderString); + std::pair<uint64_t, uint64_t> ParseContentRange(std::string_view Value); + } // namespace detail } // namespace zen diff --git a/src/zenhttp/clients/httpclientcpr.cpp b/src/zenhttp/clients/httpclientcpr.cpp index 5d92b3b6b..14e40b02a 100644 --- a/src/zenhttp/clients/httpclientcpr.cpp +++ b/src/zenhttp/clients/httpclientcpr.cpp @@ -12,6 +12,7 @@ #include <zencore/session.h> #include <zencore/stream.h> #include <zenhttp/packageformat.h> +#include <algorithm> namespace zen { @@ -23,6 +24,21 @@ CreateCprHttpClient(std::string_view BaseUri, const HttpClientSettings& Connecti static std::atomic<uint32_t> HttpClientRequestIdCounter{0}; +bool +HttpClient::ErrorContext::IsConnectionError() const +{ + switch (static_cast<cpr::ErrorCode>(ErrorCode)) + { + case cpr::ErrorCode::CONNECTION_FAILURE: + case cpr::ErrorCode::OPERATION_TIMEDOUT: + case cpr::ErrorCode::HOST_RESOLUTION_FAILURE: + case cpr::ErrorCode::PROXY_RESOLUTION_FAILURE: + return true; + default: + return false; + } +} + // If we want to support different HTTP client implementations then we'll need to make this more abstract HttpClientError::ResponseClass @@ -149,6 +165,18 @@ CprHttpClient::CprHttpClient(std::string_view BaseUri, { } +bool +CprHttpClient::ShouldLogErrorCode(HttpResponseCode ResponseCode) const +{ + if (m_CheckIfAbortFunction && m_CheckIfAbortFunction()) + { + // Quiet + return false; + } + const auto& Expected = m_ConnectionSettings.ExpectedErrorCodes; + return std::find(Expected.begin(), Expected.end(), ResponseCode) == Expected.end(); +} + CprHttpClient::~CprHttpClient() { ZEN_TRACE_CPU("CprHttpClient::~CprHttpClient"); @@ -162,10 +190,11 @@ CprHttpClient::~CprHttpClient() } HttpClient::Response -CprHttpClient::ResponseWithPayload(std::string_view SessionId, - cpr::Response&& HttpResponse, - const HttpResponseCode WorkResponseCode, - IoBuffer&& Payload) +CprHttpClient::ResponseWithPayload(std::string_view SessionId, + cpr::Response&& HttpResponse, + const HttpResponseCode WorkResponseCode, + IoBuffer&& Payload, + std::vector<HttpClient::Response::MultipartBoundary>&& BoundaryPositions) { // This ends up doing a memcpy, would be good to get rid of it by streaming results // into buffer directly @@ -174,30 +203,37 @@ CprHttpClient::ResponseWithPayload(std::string_view SessionId, if (auto It = HttpResponse.header.find("Content-Type"); It != HttpResponse.header.end()) { const HttpContentType ContentType = ParseContentType(It->second); - ResponseBuffer.SetContentType(ContentType); } - const bool Quiet = m_CheckIfAbortFunction && m_CheckIfAbortFunction(); - - if (!Quiet) + if (!IsHttpSuccessCode(WorkResponseCode) && WorkResponseCode != HttpResponseCode::NotFound) { - if (!IsHttpSuccessCode(WorkResponseCode) && WorkResponseCode != HttpResponseCode::NotFound) + if (ShouldLogErrorCode(WorkResponseCode)) { ZEN_WARN("HttpClient request failed (session: {}): {}", SessionId, HttpResponse); } } + std::sort(BoundaryPositions.begin(), + BoundaryPositions.end(), + [](const HttpClient::Response::MultipartBoundary& Lhs, const HttpClient::Response::MultipartBoundary& Rhs) { + return Lhs.RangeOffset < Rhs.RangeOffset; + }); + return HttpClient::Response{.StatusCode = WorkResponseCode, .ResponsePayload = std::move(ResponseBuffer), .Header = HttpClient::KeyValueMap(HttpResponse.header.begin(), HttpResponse.header.end()), .UploadedBytes = gsl::narrow<int64_t>(HttpResponse.uploaded_bytes), .DownloadedBytes = gsl::narrow<int64_t>(HttpResponse.downloaded_bytes), - .ElapsedSeconds = HttpResponse.elapsed}; + .ElapsedSeconds = HttpResponse.elapsed, + .Ranges = std::move(BoundaryPositions)}; } HttpClient::Response -CprHttpClient::CommonResponse(std::string_view SessionId, cpr::Response&& HttpResponse, IoBuffer&& Payload) +CprHttpClient::CommonResponse(std::string_view SessionId, + cpr::Response&& HttpResponse, + IoBuffer&& Payload, + std::vector<HttpClient::Response::MultipartBoundary>&& BoundaryPositions) { const HttpResponseCode WorkResponseCode = HttpResponseCode(HttpResponse.status_code); if (HttpResponse.error) @@ -235,7 +271,7 @@ CprHttpClient::CommonResponse(std::string_view SessionId, cpr::Response&& HttpRe } else { - return ResponseWithPayload(SessionId, std::move(HttpResponse), WorkResponseCode, std::move(Payload)); + return ResponseWithPayload(SessionId, std::move(HttpResponse), WorkResponseCode, std::move(Payload), std::move(BoundaryPositions)); } } @@ -346,8 +382,7 @@ CprHttpClient::DoWithRetry(std::string_view SessionId, } Sleep(100 * (Attempt + 1)); Attempt++; - const bool Quiet = m_CheckIfAbortFunction && m_CheckIfAbortFunction(); - if (!Quiet) + if (ShouldLogErrorCode(HttpResponseCode(Result.status_code))) { ZEN_INFO("{} Attempt {}/{}", CommonResponse(SessionId, std::move(Result), {}).ErrorMessage("Retry"), @@ -385,8 +420,7 @@ CprHttpClient::DoWithRetry(std::string_view SessionId, } Sleep(100 * (Attempt + 1)); Attempt++; - const bool Quiet = m_CheckIfAbortFunction && m_CheckIfAbortFunction(); - if (!Quiet) + if (ShouldLogErrorCode(HttpResponseCode(Result.status_code))) { ZEN_INFO("{} Attempt {}/{}", CommonResponse(SessionId, std::move(Result), {}).ErrorMessage("Retry"), @@ -621,7 +655,7 @@ CprHttpClient::TransactPackage(std::string_view Url, CbPackage Package, const Ke ResponseBuffer.SetContentType(ContentType); } - return {.StatusCode = HttpResponseCode(FilterResponse.status_code), .ResponsePayload = ResponseBuffer}; + return {.StatusCode = HttpResponseCode(FilterResponse.status_code), .ResponsePayload = std::move(ResponseBuffer)}; } ////////////////////////////////////////////////////////////////////////// @@ -896,236 +930,287 @@ CprHttpClient::Download(std::string_view Url, const std::filesystem::path& TempF std::string PayloadString; std::unique_ptr<detail::TempPayloadFile> PayloadFile; - cpr::Response Response = DoWithRetry( - m_SessionId, - [&]() { - auto GetHeader = [&](std::string header) -> std::pair<std::string, std::string> { - size_t DelimiterPos = header.find(':'); - if (DelimiterPos != std::string::npos) - { - std::string Key = header.substr(0, DelimiterPos); - constexpr AsciiSet WhitespaceCharacters(" \v\f\t\r\n"); - Key = AsciiSet::TrimSuffixWith(Key, WhitespaceCharacters); - Key = AsciiSet::TrimPrefixWith(Key, WhitespaceCharacters); - - std::string Value = header.substr(DelimiterPos + 1); - Value = AsciiSet::TrimSuffixWith(Value, WhitespaceCharacters); - Value = AsciiSet::TrimPrefixWith(Value, WhitespaceCharacters); - - return std::make_pair(Key, Value); - } - return std::make_pair(header, ""); - }; - - auto DownloadCallback = [&](std::string data, intptr_t) { - if (m_CheckIfAbortFunction && m_CheckIfAbortFunction()) - { - return false; - } - if (PayloadFile) - { - ZEN_ASSERT(PayloadString.empty()); - std::error_code Ec = PayloadFile->Write(data); - if (Ec) - { - ZEN_WARN("Failed to write to temp file in '{}' for HttpClient::Download. Reason: {}", - TempFolderPath.string(), - Ec.message()); - return false; - } - } - else - { - PayloadString.append(data); - } - return true; - }; - - uint64_t RequestedContentLength = (uint64_t)-1; - if (auto RangeIt = AdditionalHeader.Entries.find("Range"); RangeIt != AdditionalHeader.Entries.end()) - { - if (RangeIt->second.starts_with("bytes")) - { - size_t RangeStartPos = RangeIt->second.find('=', 5); - if (RangeStartPos != std::string::npos) - { - RangeStartPos++; - size_t RangeSplitPos = RangeIt->second.find('-', RangeStartPos); - if (RangeSplitPos != std::string::npos) - { - std::optional<size_t> RequestedRangeStart = - ParseInt<size_t>(RangeIt->second.substr(RangeStartPos, RangeSplitPos - RangeStartPos)); - std::optional<size_t> RequestedRangeEnd = ParseInt<size_t>(RangeIt->second.substr(RangeStartPos + 1)); - if (RequestedRangeStart.has_value() && RequestedRangeEnd.has_value()) - { - RequestedContentLength = RequestedRangeEnd.value() - 1; - } - } - } - } - } - - cpr::Response Response; - { - std::vector<std::pair<std::string, std::string>> ReceivedHeaders; - auto HeaderCallback = [&](std::string header, intptr_t) { - std::pair<std::string, std::string> Header = GetHeader(header); - if (Header.first == "Content-Length"sv) - { - std::optional<size_t> ContentLength = ParseInt<size_t>(Header.second); - if (ContentLength.has_value()) - { - if (ContentLength.value() > m_ConnectionSettings.MaximumInMemoryDownloadSize) - { - PayloadFile = std::make_unique<detail::TempPayloadFile>(); - std::error_code Ec = PayloadFile->Open(TempFolderPath, ContentLength.value()); - if (Ec) - { - ZEN_WARN("Failed to create temp file in '{}' for HttpClient::Download. Reason: {}", - TempFolderPath.string(), - Ec.message()); - PayloadFile.reset(); - } - } - else - { - PayloadString.reserve(ContentLength.value()); - } - } - } - if (!Header.first.empty()) - { - ReceivedHeaders.emplace_back(std::move(Header)); - } - return 1; - }; - - Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); - Response = Sess.Download(cpr::WriteCallback{DownloadCallback}, cpr::HeaderCallback{HeaderCallback}); - for (const std::pair<std::string, std::string>& H : ReceivedHeaders) - { - Response.header.insert_or_assign(H.first, H.second); - } - } - if (m_ConnectionSettings.AllowResume) - { - auto SupportsRanges = [](const cpr::Response& Response) -> bool { - if (Response.header.find("Content-Range") != Response.header.end()) - { - return true; - } - if (auto It = Response.header.find("Accept-Ranges"); It != Response.header.end()) - { - return It->second == "bytes"sv; - } - return false; - }; - - auto ShouldResume = [&SupportsRanges](const cpr::Response& Response) -> bool { - if (ShouldRetry(Response)) - { - return SupportsRanges(Response); - } - return false; - }; - - if (ShouldResume(Response)) - { - auto It = Response.header.find("Content-Length"); - if (It != Response.header.end()) - { - uint64_t ContentLength = RequestedContentLength; - if (ContentLength == uint64_t(-1)) - { - if (auto ParsedContentLength = ParseInt<int64_t>(It->second); ParsedContentLength.has_value()) - { - ContentLength = ParsedContentLength.value(); - } - } - - std::vector<std::pair<std::string, std::string>> ReceivedHeaders; - - auto HeaderCallback = [&](std::string header, intptr_t) { - std::pair<std::string, std::string> Header = GetHeader(header); - if (!Header.first.empty()) - { - ReceivedHeaders.emplace_back(std::move(Header)); - } - - if (Header.first == "Content-Range"sv) - { - if (Header.second.starts_with("bytes "sv)) - { - size_t RangeStartEnd = Header.second.find('-', 6); - if (RangeStartEnd != std::string::npos) - { - const auto Start = ParseInt<uint64_t>(Header.second.substr(6, RangeStartEnd - 6)); - if (Start) - { - uint64_t DownloadedSize = PayloadFile ? PayloadFile->GetSize() : PayloadString.length(); - if (Start.value() == DownloadedSize) - { - return 1; - } - else if (Start.value() > DownloadedSize) - { - return 0; - } - if (PayloadFile) - { - PayloadFile->ResetWritePos(Start.value()); - } - else - { - PayloadString = PayloadString.substr(0, Start.value()); - } - return 1; - } - } - } - return 0; - } - return 1; - }; - - KeyValueMap HeadersWithRange(AdditionalHeader); - do - { - uint64_t DownloadedSize = PayloadFile ? PayloadFile->GetSize() : PayloadString.length(); - - std::string Range = fmt::format("bytes={}-{}", DownloadedSize, DownloadedSize + ContentLength - 1); - if (auto RangeIt = HeadersWithRange.Entries.find("Range"); RangeIt != HeadersWithRange.Entries.end()) - { - if (RangeIt->second == Range) - { - // If we didn't make any progress, abort - break; - } - } - HeadersWithRange.Entries.insert_or_assign("Range", Range); - - Session Sess = - AllocSession(m_BaseUri, Url, m_ConnectionSettings, HeadersWithRange, {}, m_SessionId, GetAccessToken()); - Response = Sess.Download(cpr::WriteCallback{DownloadCallback}, cpr::HeaderCallback{HeaderCallback}); - for (const std::pair<std::string, std::string>& H : ReceivedHeaders) - { - Response.header.insert_or_assign(H.first, H.second); - } - ReceivedHeaders.clear(); - } while (ShouldResume(Response)); - } - } - } - - if (!PayloadString.empty()) - { - Response.text = std::move(PayloadString); - } - return Response; - }, - PayloadFile); - - return CommonResponse(m_SessionId, std::move(Response), PayloadFile ? PayloadFile->DetachToIoBuffer() : IoBuffer{}); + + HttpContentType ContentType = HttpContentType::kUnknownContentType; + detail::MultipartBoundaryParser BoundaryParser; + bool IsMultiRangeResponse = false; + + cpr::Response Response = DoWithRetry( + m_SessionId, + [&]() { + // Reset state from any previous attempt + PayloadString.clear(); + PayloadFile.reset(); + BoundaryParser.Boundaries.clear(); + ContentType = HttpContentType::kUnknownContentType; + IsMultiRangeResponse = false; + + auto DownloadCallback = [&](std::string data, intptr_t) { + if (m_CheckIfAbortFunction && m_CheckIfAbortFunction()) + { + return false; + } + + if (IsMultiRangeResponse) + { + BoundaryParser.ParseInput(data); + } + + if (PayloadFile) + { + ZEN_ASSERT(PayloadString.empty()); + std::error_code Ec = PayloadFile->Write(data); + if (Ec) + { + ZEN_WARN("Failed to write to temp file in '{}' for HttpClient::Download. Reason: {}", + TempFolderPath.string(), + Ec.message()); + return false; + } + } + else + { + PayloadString.append(data); + } + return true; + }; + + uint64_t RequestedContentLength = (uint64_t)-1; + if (auto RangeIt = AdditionalHeader.Entries.find("Range"); RangeIt != AdditionalHeader.Entries.end()) + { + if (RangeIt->second.starts_with("bytes")) + { + std::string_view RangeValue(RangeIt->second); + size_t RangeStartPos = RangeValue.find('=', 5); + if (RangeStartPos != std::string::npos) + { + RangeStartPos++; + while (RangeStartPos < RangeValue.length() && RangeValue[RangeStartPos] == ' ') + { + RangeStartPos++; + } + RequestedContentLength = 0; + + while (RangeStartPos < RangeValue.length()) + { + size_t RangeEnd = RangeValue.find_first_of(", \r\n", RangeStartPos); + if (RangeEnd == std::string::npos) + { + RangeEnd = RangeValue.length(); + } + + std::string_view RangeString = RangeValue.substr(RangeStartPos, RangeEnd - RangeStartPos); + size_t RangeSplitPos = RangeString.find('-'); + if (RangeSplitPos != std::string::npos) + { + std::optional<size_t> RequestedRangeStart = ParseInt<size_t>(RangeString.substr(0, RangeSplitPos)); + std::optional<size_t> RequestedRangeEnd = ParseInt<size_t>(RangeString.substr(RangeSplitPos + 1)); + if (RequestedRangeStart.has_value() && RequestedRangeEnd.has_value()) + { + RequestedContentLength += RequestedRangeEnd.value() - RequestedRangeStart.value() + 1; + } + } + RangeStartPos = RangeEnd; + while (RangeStartPos != RangeValue.length() && + (RangeValue[RangeStartPos] == ',' || RangeValue[RangeStartPos] == ' ')) + { + RangeStartPos++; + } + } + } + } + } + + cpr::Response Response; + { + std::vector<std::pair<std::string, std::string>> ReceivedHeaders; + auto HeaderCallback = [&](std::string header, intptr_t) { + const std::pair<std::string_view, std::string_view> Header = detail::GetHeaderKeyAndValue(header); + if (Header.first == "Content-Length"sv) + { + std::optional<size_t> ContentLength = ParseInt<size_t>(Header.second); + if (ContentLength.has_value()) + { + if (ContentLength.value() > m_ConnectionSettings.MaximumInMemoryDownloadSize) + { + PayloadFile = std::make_unique<detail::TempPayloadFile>(); + std::error_code Ec = PayloadFile->Open(TempFolderPath, ContentLength.value()); + if (Ec) + { + ZEN_WARN("Failed to create temp file in '{}' for HttpClient::Download. Reason: {}", + TempFolderPath.string(), + Ec.message()); + PayloadFile.reset(); + } + } + else + { + PayloadString.reserve(ContentLength.value()); + } + } + } + else if (Header.first == "Content-Type") + { + IsMultiRangeResponse = BoundaryParser.Init(Header.second); + if (!IsMultiRangeResponse) + { + ContentType = ParseContentType(Header.second); + } + } + else if (Header.first == "Content-Range") + { + if (!IsMultiRangeResponse) + { + std::pair<uint64_t, uint64_t> Range = detail::ParseContentRange(Header.second); + if (Range.second != 0) + { + BoundaryParser.Boundaries.push_back(HttpClient::Response::MultipartBoundary{.OffsetInPayload = 0, + .RangeOffset = Range.first, + .RangeLength = Range.second, + .ContentType = ContentType}); + } + } + } + if (!Header.first.empty()) + { + ReceivedHeaders.emplace_back(std::move(Header)); + } + return 1; + }; + + Session Sess = AllocSession(m_BaseUri, Url, m_ConnectionSettings, AdditionalHeader, {}, m_SessionId, GetAccessToken()); + Response = Sess.Download(cpr::WriteCallback{DownloadCallback}, cpr::HeaderCallback{HeaderCallback}); + for (const std::pair<std::string, std::string>& H : ReceivedHeaders) + { + Response.header.insert_or_assign(H.first, H.second); + } + } + if (m_ConnectionSettings.AllowResume) + { + auto SupportsRanges = [](const cpr::Response& Response) -> bool { + if (Response.header.find("Content-Range") != Response.header.end()) + { + return true; + } + if (auto It = Response.header.find("Accept-Ranges"); It != Response.header.end()) + { + return It->second == "bytes"sv; + } + return false; + }; + + auto ShouldResume = [&SupportsRanges, &IsMultiRangeResponse](const cpr::Response& Response) -> bool { + if (IsMultiRangeResponse) + { + return false; + } + if (ShouldRetry(Response)) + { + return SupportsRanges(Response); + } + return false; + }; + + if (ShouldResume(Response)) + { + auto It = Response.header.find("Content-Length"); + if (It != Response.header.end()) + { + uint64_t ContentLength = RequestedContentLength; + if (ContentLength == uint64_t(-1)) + { + if (auto ParsedContentLength = ParseInt<int64_t>(It->second); ParsedContentLength.has_value()) + { + ContentLength = ParsedContentLength.value(); + } + } + + std::vector<std::pair<std::string, std::string>> ReceivedHeaders; + + auto HeaderCallback = [&](std::string header, intptr_t) { + const std::pair<std::string_view, std::string_view> Header = detail::GetHeaderKeyAndValue(header); + if (!Header.first.empty()) + { + ReceivedHeaders.emplace_back(std::move(Header)); + } + + if (Header.first == "Content-Range"sv) + { + if (Header.second.starts_with("bytes "sv)) + { + size_t RangeStartEnd = Header.second.find('-', 6); + if (RangeStartEnd != std::string::npos) + { + const auto Start = ParseInt<uint64_t>(Header.second.substr(6, RangeStartEnd - 6)); + if (Start) + { + uint64_t DownloadedSize = PayloadFile ? PayloadFile->GetSize() : PayloadString.length(); + if (Start.value() == DownloadedSize) + { + return 1; + } + else if (Start.value() > DownloadedSize) + { + return 0; + } + if (PayloadFile) + { + PayloadFile->ResetWritePos(Start.value()); + } + else + { + PayloadString = PayloadString.substr(0, Start.value()); + } + return 1; + } + } + } + return 0; + } + return 1; + }; + + KeyValueMap HeadersWithRange(AdditionalHeader); + do + { + uint64_t DownloadedSize = PayloadFile ? PayloadFile->GetSize() : PayloadString.length(); + + std::string Range = fmt::format("bytes={}-{}", DownloadedSize, DownloadedSize + ContentLength - 1); + if (auto RangeIt = HeadersWithRange.Entries.find("Range"); RangeIt != HeadersWithRange.Entries.end()) + { + if (RangeIt->second == Range) + { + // If we didn't make any progress, abort + break; + } + } + HeadersWithRange.Entries.insert_or_assign("Range", Range); + + Session Sess = + AllocSession(m_BaseUri, Url, m_ConnectionSettings, HeadersWithRange, {}, m_SessionId, GetAccessToken()); + Response = Sess.Download(cpr::WriteCallback{DownloadCallback}, cpr::HeaderCallback{HeaderCallback}); + for (const std::pair<std::string, std::string>& H : ReceivedHeaders) + { + Response.header.insert_or_assign(H.first, H.second); + } + ReceivedHeaders.clear(); + } while (ShouldResume(Response)); + } + } + } + + if (!PayloadString.empty()) + { + Response.text = std::move(PayloadString); + } + return Response; + }, + PayloadFile); + + return CommonResponse(m_SessionId, + std::move(Response), + PayloadFile ? PayloadFile->DetachToIoBuffer() : IoBuffer{}, + std::move(BoundaryParser.Boundaries)); } } // namespace zen diff --git a/src/zenhttp/clients/httpclientcpr.h b/src/zenhttp/clients/httpclientcpr.h index 40af53b5d..752d91add 100644 --- a/src/zenhttp/clients/httpclientcpr.h +++ b/src/zenhttp/clients/httpclientcpr.h @@ -155,14 +155,19 @@ private: std::function<cpr::Response()>&& Func, std::function<bool(cpr::Response& Result)>&& Validate = [](cpr::Response&) { return true; }); + bool ShouldLogErrorCode(HttpResponseCode ResponseCode) const; bool ValidatePayload(cpr::Response& Response, std::unique_ptr<detail::TempPayloadFile>& PayloadFile); - HttpClient::Response CommonResponse(std::string_view SessionId, cpr::Response&& HttpResponse, IoBuffer&& Payload); + HttpClient::Response CommonResponse(std::string_view SessionId, + cpr::Response&& HttpResponse, + IoBuffer&& Payload, + std::vector<HttpClient::Response::MultipartBoundary>&& BoundaryPositions = {}); - HttpClient::Response ResponseWithPayload(std::string_view SessionId, - cpr::Response&& HttpResponse, - const HttpResponseCode WorkResponseCode, - IoBuffer&& Payload); + HttpClient::Response ResponseWithPayload(std::string_view SessionId, + cpr::Response&& HttpResponse, + const HttpResponseCode WorkResponseCode, + IoBuffer&& Payload, + std::vector<HttpClient::Response::MultipartBoundary>&& BoundaryPositions); }; } // namespace zen diff --git a/src/zenhttp/clients/httpwsclient.cpp b/src/zenhttp/clients/httpwsclient.cpp new file mode 100644 index 000000000..9497dadb8 --- /dev/null +++ b/src/zenhttp/clients/httpwsclient.cpp @@ -0,0 +1,566 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zenhttp/httpwsclient.h> + +#include "../servers/wsframecodec.h" + +#include <zencore/base64.h> +#include <zencore/logging.h> +#include <zencore/string.h> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <asio.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + +#include <deque> +#include <random> +#include <thread> + +namespace zen { + +////////////////////////////////////////////////////////////////////////// + +struct HttpWsClient::Impl +{ + Impl(std::string_view Url, IWsClientHandler& Handler, const HttpWsClientSettings& Settings) + : m_Handler(Handler) + , m_Settings(Settings) + , m_Log(logging::Get(Settings.LogCategory)) + , m_OwnedIoContext(std::make_unique<asio::io_context>()) + , m_IoContext(*m_OwnedIoContext) + { + ParseUrl(Url); + } + + Impl(std::string_view Url, IWsClientHandler& Handler, asio::io_context& IoContext, const HttpWsClientSettings& Settings) + : m_Handler(Handler) + , m_Settings(Settings) + , m_Log(logging::Get(Settings.LogCategory)) + , m_IoContext(IoContext) + { + ParseUrl(Url); + } + + ~Impl() + { + // Release work guard so io_context::run() can return + m_WorkGuard.reset(); + + // Close the socket to cancel pending async ops + if (m_Socket) + { + asio::error_code Ec; + m_Socket->close(Ec); + } + + if (m_IoThread.joinable()) + { + m_IoThread.join(); + } + } + + void ParseUrl(std::string_view Url) + { + // Expected format: ws://host:port/path + if (Url.substr(0, 5) == "ws://") + { + Url.remove_prefix(5); + } + + auto SlashPos = Url.find('/'); + std::string_view HostPort; + if (SlashPos != std::string_view::npos) + { + HostPort = Url.substr(0, SlashPos); + m_Path = std::string(Url.substr(SlashPos)); + } + else + { + HostPort = Url; + m_Path = "/"; + } + + auto ColonPos = HostPort.find(':'); + if (ColonPos != std::string_view::npos) + { + m_Host = std::string(HostPort.substr(0, ColonPos)); + m_Port = std::string(HostPort.substr(ColonPos + 1)); + } + else + { + m_Host = std::string(HostPort); + m_Port = "80"; + } + } + + void Connect() + { + if (m_OwnedIoContext) + { + m_WorkGuard = std::make_unique<asio::io_context::work>(m_IoContext); + m_IoThread = std::thread([this] { m_IoContext.run(); }); + } + + asio::post(m_IoContext, [this] { DoResolve(); }); + } + + void DoResolve() + { + m_Resolver = std::make_unique<asio::ip::tcp::resolver>(m_IoContext); + + m_Resolver->async_resolve(m_Host, m_Port, [this](const asio::error_code& Ec, asio::ip::tcp::resolver::results_type Results) { + if (Ec) + { + ZEN_LOG_DEBUG(m_Log, "WebSocket resolve failed for {}:{}: {}", m_Host, m_Port, Ec.message()); + m_Handler.OnWsClose(1006, "resolve failed"); + return; + } + + DoConnect(Results); + }); + } + + void DoConnect(const asio::ip::tcp::resolver::results_type& Endpoints) + { + m_Socket = std::make_unique<asio::ip::tcp::socket>(m_IoContext); + + // Start connect timeout timer + m_Timer = std::make_unique<asio::steady_timer>(m_IoContext, m_Settings.ConnectTimeout); + m_Timer->async_wait([this](const asio::error_code& Ec) { + if (!Ec && !m_IsOpen.load(std::memory_order_relaxed)) + { + ZEN_LOG_DEBUG(m_Log, "WebSocket connect timeout for {}:{}", m_Host, m_Port); + if (m_Socket) + { + asio::error_code CloseEc; + m_Socket->close(CloseEc); + } + } + }); + + asio::async_connect(*m_Socket, Endpoints, [this](const asio::error_code& Ec, const asio::ip::tcp::endpoint&) { + if (Ec) + { + m_Timer->cancel(); + ZEN_LOG_DEBUG(m_Log, "WebSocket connect failed for {}:{}: {}", m_Host, m_Port, Ec.message()); + m_Handler.OnWsClose(1006, "connect failed"); + return; + } + + DoHandshake(); + }); + } + + void DoHandshake() + { + // Generate random Sec-WebSocket-Key (16 random bytes, base64 encoded) + uint8_t KeyBytes[16]; + { + static thread_local std::mt19937 s_Rng(std::random_device{}()); + for (int i = 0; i < 4; ++i) + { + uint32_t Val = s_Rng(); + std::memcpy(KeyBytes + i * 4, &Val, 4); + } + } + + char KeyBase64[Base64::GetEncodedDataSize(16) + 1]; + uint32_t KeyLen = Base64::Encode(KeyBytes, 16, KeyBase64); + KeyBase64[KeyLen] = '\0'; + m_WebSocketKey = std::string(KeyBase64, KeyLen); + + // Build the HTTP upgrade request + ExtendableStringBuilder<512> Request; + Request << "GET " << m_Path << " HTTP/1.1\r\n" + << "Host: " << m_Host << ":" << m_Port << "\r\n" + << "Upgrade: websocket\r\n" + << "Connection: Upgrade\r\n" + << "Sec-WebSocket-Key: " << m_WebSocketKey << "\r\n" + << "Sec-WebSocket-Version: 13\r\n"; + + // Add Authorization header if access token provider is set + if (m_Settings.AccessTokenProvider) + { + HttpClientAccessToken Token = (*m_Settings.AccessTokenProvider)(); + if (Token.IsValid()) + { + Request << "Authorization: Bearer " << Token.Value << "\r\n"; + } + } + + Request << "\r\n"; + + std::string_view ReqStr = Request.ToView(); + + m_HandshakeBuffer = std::make_shared<std::string>(ReqStr); + + asio::async_write(*m_Socket, + asio::buffer(m_HandshakeBuffer->data(), m_HandshakeBuffer->size()), + [this](const asio::error_code& Ec, std::size_t) { + if (Ec) + { + m_Timer->cancel(); + ZEN_LOG_DEBUG(m_Log, "WebSocket handshake write failed: {}", Ec.message()); + m_Handler.OnWsClose(1006, "handshake write failed"); + return; + } + + DoReadHandshakeResponse(); + }); + } + + void DoReadHandshakeResponse() + { + asio::async_read_until(*m_Socket, m_ReadBuffer, "\r\n\r\n", [this](const asio::error_code& Ec, std::size_t) { + m_Timer->cancel(); + + if (Ec) + { + ZEN_LOG_DEBUG(m_Log, "WebSocket handshake read failed: {}", Ec.message()); + m_Handler.OnWsClose(1006, "handshake read failed"); + return; + } + + // Parse the response + const auto& Data = m_ReadBuffer.data(); + std::string Response(asio::buffers_begin(Data), asio::buffers_end(Data)); + + // Consume the headers from the read buffer (any extra data stays for frame parsing) + auto HeaderEnd = Response.find("\r\n\r\n"); + if (HeaderEnd != std::string::npos) + { + m_ReadBuffer.consume(HeaderEnd + 4); + } + + // Validate 101 response + if (Response.find("101") == std::string::npos) + { + ZEN_LOG_DEBUG(m_Log, "WebSocket handshake rejected (no 101): {}", Response.substr(0, 80)); + m_Handler.OnWsClose(1006, "handshake rejected"); + return; + } + + // Validate Sec-WebSocket-Accept + std::string ExpectedAccept = WsFrameCodec::ComputeAcceptKey(m_WebSocketKey); + if (Response.find(ExpectedAccept) == std::string::npos) + { + ZEN_LOG_DEBUG(m_Log, "WebSocket handshake: invalid Sec-WebSocket-Accept"); + m_Handler.OnWsClose(1006, "invalid accept key"); + return; + } + + m_IsOpen.store(true); + m_Handler.OnWsOpen(); + EnqueueRead(); + }); + } + + ////////////////////////////////////////////////////////////////////////// + // + // Read loop + // + + void EnqueueRead() + { + if (!m_IsOpen.load(std::memory_order_relaxed)) + { + return; + } + + asio::async_read(*m_Socket, m_ReadBuffer, asio::transfer_at_least(1), [this](const asio::error_code& Ec, std::size_t) { + OnDataReceived(Ec); + }); + } + + void OnDataReceived(const asio::error_code& Ec) + { + if (Ec) + { + if (Ec != asio::error::eof && Ec != asio::error::operation_aborted) + { + ZEN_LOG_DEBUG(m_Log, "WebSocket read error: {}", Ec.message()); + } + + if (m_IsOpen.exchange(false)) + { + m_Handler.OnWsClose(1006, "connection lost"); + } + return; + } + + ProcessReceivedData(); + + if (m_IsOpen.load(std::memory_order_relaxed)) + { + EnqueueRead(); + } + } + + void ProcessReceivedData() + { + while (m_ReadBuffer.size() > 0) + { + const auto& InputBuffer = m_ReadBuffer.data(); + const auto* RawData = static_cast<const uint8_t*>(InputBuffer.data()); + const auto Size = InputBuffer.size(); + + WsFrameParseResult Frame = WsFrameCodec::TryParseFrame(RawData, Size); + if (!Frame.IsValid) + { + break; + } + + m_ReadBuffer.consume(Frame.BytesConsumed); + + switch (Frame.Opcode) + { + case WebSocketOpcode::kText: + case WebSocketOpcode::kBinary: + { + WebSocketMessage Msg; + Msg.Opcode = Frame.Opcode; + Msg.Payload = IoBuffer(IoBuffer::Clone, Frame.Payload.data(), Frame.Payload.size()); + m_Handler.OnWsMessage(Msg); + break; + } + + case WebSocketOpcode::kPing: + { + // Auto-respond with masked pong + std::vector<uint8_t> PongFrame = WsFrameCodec::BuildMaskedFrame(WebSocketOpcode::kPong, Frame.Payload); + EnqueueWrite(std::move(PongFrame)); + break; + } + + case WebSocketOpcode::kPong: + break; + + case WebSocketOpcode::kClose: + { + uint16_t Code = 1000; + std::string_view Reason; + + if (Frame.Payload.size() >= 2) + { + Code = (uint16_t(Frame.Payload[0]) << 8) | uint16_t(Frame.Payload[1]); + if (Frame.Payload.size() > 2) + { + Reason = + std::string_view(reinterpret_cast<const char*>(Frame.Payload.data() + 2), Frame.Payload.size() - 2); + } + } + + // Echo masked close frame if we haven't sent one yet + if (!m_CloseSent.exchange(true)) + { + std::vector<uint8_t> CloseFrame = WsFrameCodec::BuildMaskedCloseFrame(Code); + EnqueueWrite(std::move(CloseFrame)); + } + + m_IsOpen.store(false); + m_Handler.OnWsClose(Code, Reason); + return; + } + + default: + ZEN_LOG_WARN(m_Log, "Unknown WebSocket opcode: {:#x}", static_cast<uint8_t>(Frame.Opcode)); + break; + } + } + } + + ////////////////////////////////////////////////////////////////////////// + // + // Write queue + // + + void EnqueueWrite(std::vector<uint8_t> Frame) + { + bool ShouldFlush = false; + + m_WriteLock.WithExclusiveLock([&] { + m_WriteQueue.push_back(std::move(Frame)); + if (!m_IsWriting) + { + m_IsWriting = true; + ShouldFlush = true; + } + }); + + if (ShouldFlush) + { + FlushWriteQueue(); + } + } + + void FlushWriteQueue() + { + std::vector<uint8_t> Frame; + + m_WriteLock.WithExclusiveLock([&] { + if (m_WriteQueue.empty()) + { + m_IsWriting = false; + return; + } + Frame = std::move(m_WriteQueue.front()); + m_WriteQueue.pop_front(); + }); + + if (Frame.empty()) + { + return; + } + + auto OwnedFrame = std::make_shared<std::vector<uint8_t>>(std::move(Frame)); + + asio::async_write(*m_Socket, + asio::buffer(OwnedFrame->data(), OwnedFrame->size()), + [this, OwnedFrame](const asio::error_code& Ec, std::size_t) { OnWriteComplete(Ec); }); + } + + void OnWriteComplete(const asio::error_code& Ec) + { + if (Ec) + { + if (Ec != asio::error::operation_aborted) + { + ZEN_LOG_DEBUG(m_Log, "WebSocket write error: {}", Ec.message()); + } + + m_WriteLock.WithExclusiveLock([&] { + m_IsWriting = false; + m_WriteQueue.clear(); + }); + + if (m_IsOpen.exchange(false)) + { + m_Handler.OnWsClose(1006, "write error"); + } + return; + } + + FlushWriteQueue(); + } + + ////////////////////////////////////////////////////////////////////////// + // + // Public operations + // + + void SendText(std::string_view Text) + { + if (!m_IsOpen.load(std::memory_order_relaxed)) + { + return; + } + + std::span<const uint8_t> Payload(reinterpret_cast<const uint8_t*>(Text.data()), Text.size()); + std::vector<uint8_t> Frame = WsFrameCodec::BuildMaskedFrame(WebSocketOpcode::kText, Payload); + EnqueueWrite(std::move(Frame)); + } + + void SendBinary(std::span<const uint8_t> Data) + { + if (!m_IsOpen.load(std::memory_order_relaxed)) + { + return; + } + + std::vector<uint8_t> Frame = WsFrameCodec::BuildMaskedFrame(WebSocketOpcode::kBinary, Data); + EnqueueWrite(std::move(Frame)); + } + + void DoClose(uint16_t Code, std::string_view Reason) + { + if (!m_IsOpen.exchange(false)) + { + return; + } + + if (!m_CloseSent.exchange(true)) + { + std::vector<uint8_t> CloseFrame = WsFrameCodec::BuildMaskedCloseFrame(Code, Reason); + EnqueueWrite(std::move(CloseFrame)); + } + } + + IWsClientHandler& m_Handler; + HttpWsClientSettings m_Settings; + LoggerRef m_Log; + + std::string m_Host; + std::string m_Port; + std::string m_Path; + + // io_context: owned (standalone) or external (shared) + std::unique_ptr<asio::io_context> m_OwnedIoContext; + asio::io_context& m_IoContext; + std::unique_ptr<asio::io_context::work> m_WorkGuard; + std::thread m_IoThread; + + // Connection state + std::unique_ptr<asio::ip::tcp::resolver> m_Resolver; + std::unique_ptr<asio::ip::tcp::socket> m_Socket; + std::unique_ptr<asio::steady_timer> m_Timer; + asio::streambuf m_ReadBuffer; + std::string m_WebSocketKey; + std::shared_ptr<std::string> m_HandshakeBuffer; + + // Write queue + RwLock m_WriteLock; + std::deque<std::vector<uint8_t>> m_WriteQueue; + bool m_IsWriting = false; + + std::atomic<bool> m_IsOpen{false}; + std::atomic<bool> m_CloseSent{false}; +}; + +////////////////////////////////////////////////////////////////////////// + +HttpWsClient::HttpWsClient(std::string_view Url, IWsClientHandler& Handler, const HttpWsClientSettings& Settings) +: m_Impl(std::make_unique<Impl>(Url, Handler, Settings)) +{ +} + +HttpWsClient::HttpWsClient(std::string_view Url, + IWsClientHandler& Handler, + asio::io_context& IoContext, + const HttpWsClientSettings& Settings) +: m_Impl(std::make_unique<Impl>(Url, Handler, IoContext, Settings)) +{ +} + +HttpWsClient::~HttpWsClient() = default; + +void +HttpWsClient::Connect() +{ + m_Impl->Connect(); +} + +void +HttpWsClient::SendText(std::string_view Text) +{ + m_Impl->SendText(Text); +} + +void +HttpWsClient::SendBinary(std::span<const uint8_t> Data) +{ + m_Impl->SendBinary(Data); +} + +void +HttpWsClient::Close(uint16_t Code, std::string_view Reason) +{ + m_Impl->DoClose(Code, Reason); +} + +bool +HttpWsClient::IsOpen() const +{ + return m_Impl->m_IsOpen.load(std::memory_order_relaxed); +} + +} // namespace zen diff --git a/src/zenhttp/httpclient.cpp b/src/zenhttp/httpclient.cpp index 43e9fb468..281d512cf 100644 --- a/src/zenhttp/httpclient.cpp +++ b/src/zenhttp/httpclient.cpp @@ -21,9 +21,17 @@ #include "clients/httpclientcommon.h" +#include <numeric> + #if ZEN_WITH_TESTS +# include <zencore/scopeguard.h> # include <zencore/testing.h> # include <zencore/testutils.h> +# include <zenhttp/security/passwordsecurityfilter.h> +# include "servers/httpasio.h" +# include "servers/httpsys.h" + +# include <thread> #endif // ZEN_WITH_TESTS namespace zen { @@ -96,6 +104,44 @@ HttpClientBase::GetAccessToken() ////////////////////////////////////////////////////////////////////////// +std::vector<std::pair<uint64_t, uint64_t>> +HttpClient::Response::GetRanges(std::span<const std::pair<uint64_t, uint64_t>> OffsetAndLengthPairs) const +{ + if (Ranges.empty()) + { + return {}; + } + + std::vector<std::pair<uint64_t, uint64_t>> Result; + Result.reserve(OffsetAndLengthPairs.size()); + + auto BoundaryIt = Ranges.begin(); + auto OffsetAndLengthPairIt = OffsetAndLengthPairs.begin(); + while (OffsetAndLengthPairIt != OffsetAndLengthPairs.end()) + { + uint64_t Offset = OffsetAndLengthPairIt->first; + uint64_t Length = OffsetAndLengthPairIt->second; + while (Offset >= BoundaryIt->RangeOffset + BoundaryIt->RangeLength) + { + BoundaryIt++; + if (BoundaryIt == Ranges.end()) + { + throw std::runtime_error("HttpClient::Response can not fulfill requested range"); + } + } + if (Offset + Length > BoundaryIt->RangeOffset + BoundaryIt->RangeLength || Offset < BoundaryIt->RangeOffset) + { + throw std::runtime_error("HttpClient::Response can not fulfill requested range"); + } + uint64_t OffsetIntoRange = Offset - BoundaryIt->RangeOffset; + uint64_t RangePayloadOffset = BoundaryIt->OffsetInPayload + OffsetIntoRange; + Result.emplace_back(std::make_pair(RangePayloadOffset, Length)); + + OffsetAndLengthPairIt++; + } + return Result; +} + CbObject HttpClient::Response::AsObject() const { @@ -334,10 +380,55 @@ HttpClient::Authenticate() return m_Inner->Authenticate(); } +LatencyTestResult +MeasureLatency(HttpClient& Client, std::string_view Url) +{ + std::vector<double> MeasurementTimes; + std::string ErrorMessage; + + for (uint32_t AttemptCount = 0; AttemptCount < 20 && MeasurementTimes.size() < 5; AttemptCount++) + { + HttpClient::Response MeasureResponse = Client.Get(Url); + if (MeasureResponse.IsSuccess()) + { + MeasurementTimes.push_back(MeasureResponse.ElapsedSeconds); + Sleep(5); + } + else + { + ErrorMessage = MeasureResponse.ErrorMessage(fmt::format("Unable to measure latency using {}", Url)); + + // Connection-level failures (timeout, refused, DNS) mean the endpoint is unreachable. + // Bail out immediately — retrying will just burn the connect timeout each time. + if (MeasureResponse.Error && MeasureResponse.Error->IsConnectionError()) + { + break; + } + } + } + + if (MeasurementTimes.empty()) + { + return {.Success = false, .FailureReason = ErrorMessage}; + } + + if (MeasurementTimes.size() > 2) + { + std::sort(MeasurementTimes.begin(), MeasurementTimes.end()); + MeasurementTimes.pop_back(); // Remove the worst time + } + + double AverageLatency = std::accumulate(MeasurementTimes.begin(), MeasurementTimes.end(), 0.0) / MeasurementTimes.size(); + + return {.Success = true, .LatencySeconds = AverageLatency}; +} + ////////////////////////////////////////////////////////////////////////// #if ZEN_WITH_TESTS +TEST_SUITE_BEGIN("http.httpclient"); + TEST_CASE("responseformat") { using namespace std::literals; @@ -388,8 +479,366 @@ TEST_CASE("httpclient") { using namespace std::literals; - SUBCASE("client") {} + struct TestHttpService : public HttpService + { + TestHttpService() = default; + + virtual const char* BaseUri() const override { return "/test/"; } + virtual void HandleRequest(HttpServerRequest& HttpServiceRequest) override + { + if (HttpServiceRequest.RelativeUri() == "yo") + { + if (HttpServiceRequest.IsLocalMachineRequest()) + { + return HttpServiceRequest.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "hey family"); + } + else + { + return HttpServiceRequest.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "hey stranger"); + } + } + return HttpServiceRequest.WriteResponse(HttpResponseCode::OK); + } + }; + + TestHttpService TestService; + ScopedTemporaryDirectory TmpDir; + + SUBCASE("asio") + { + Ref<HttpServer> AsioServer = CreateHttpAsioServer(AsioConfig{}); + + int Port = AsioServer->Initialize(7575, TmpDir.Path()); + REQUIRE(Port != -1); + + AsioServer->RegisterService(TestService); + + std::thread ServerThread([&]() { AsioServer->Run(false); }); + + { + auto _ = MakeGuard([&]() { + if (ServerThread.joinable()) + { + ServerThread.join(); + } + AsioServer->Close(); + }); + + { + HttpClient Client(fmt::format("127.0.0.1:{}", Port), + HttpClientSettings{}, + /*CheckIfAbortFunction*/ {}); + + ZEN_INFO("Request using {}", Client.GetBaseUri()); + + HttpClient::Response TestResponse = Client.Get("/test/yo"); + CHECK(TestResponse.IsSuccess()); + CHECK_EQ(TestResponse.AsText(), "hey family"); + } + + if (IsIPv6Capable()) + { + HttpClient Client(fmt::format("[::1]:{}", Port), + HttpClientSettings{}, + /*CheckIfAbortFunction*/ {}); + + ZEN_INFO("Request using {}", Client.GetBaseUri()); + + HttpClient::Response TestResponse = Client.Get("/test/yo"); + CHECK(TestResponse.IsSuccess()); + CHECK_EQ(TestResponse.AsText(), "hey family"); + } + + { + HttpClient Client(fmt::format("localhost:{}", Port), + HttpClientSettings{}, + /*CheckIfAbortFunction*/ {}); + + ZEN_INFO("Request using {}", Client.GetBaseUri()); + + HttpClient::Response TestResponse = Client.Get("/test/yo"); + CHECK(TestResponse.IsSuccess()); + CHECK_EQ(TestResponse.AsText(), "hey family"); + } +# if 0 + { + HttpClient Client(fmt::format("10.24.101.77:{}", Port), + HttpClientSettings{}, + /*CheckIfAbortFunction*/ {}); + + ZEN_INFO("Request using {}", Client.GetBaseUri()); + + HttpClient::Response TestResponse = Client.Get("/test/yo"); + CHECK(TestResponse.IsSuccess()); + CHECK_EQ(TestResponse.AsText(), "hey family"); + } + Sleep(20000); +# endif // 0 + AsioServer->RequestExit(); + } + } + +# if ZEN_PLATFORM_WINDOWS + SUBCASE("httpsys") + { + Ref<HttpServer> HttpSysServer = CreateHttpSysServer(HttpSysConfig{.ForceLoopback = false}); + + int Port = HttpSysServer->Initialize(7575, TmpDir.Path()); + REQUIRE(Port != -1); + + HttpSysServer->RegisterService(TestService); + + std::thread ServerThread([&]() { HttpSysServer->Run(false); }); + + { + auto _ = MakeGuard([&]() { + if (ServerThread.joinable()) + { + ServerThread.join(); + } + HttpSysServer->Close(); + }); + + if (true) + { + HttpClient Client(fmt::format("127.0.0.1:{}", Port), + HttpClientSettings{}, + /*CheckIfAbortFunction*/ {}); + + ZEN_INFO("Request using {}", Client.GetBaseUri()); + + HttpClient::Response TestResponse = Client.Get("/test/yo"); + CHECK(TestResponse.IsSuccess()); + CHECK_EQ(TestResponse.AsText(), "hey family"); + } + + if (IsIPv6Capable()) + { + HttpClient Client(fmt::format("[::1]:{}", Port), + HttpClientSettings{}, + /*CheckIfAbortFunction*/ {}); + + ZEN_INFO("Request using {}", Client.GetBaseUri()); + + HttpClient::Response TestResponse = Client.Get("/test/yo"); + CHECK(TestResponse.IsSuccess()); + CHECK_EQ(TestResponse.AsText(), "hey family"); + } + + { + HttpClient Client(fmt::format("localhost:{}", Port), + HttpClientSettings{}, + /*CheckIfAbortFunction*/ {}); + + ZEN_INFO("Request using {}", Client.GetBaseUri()); + + HttpClient::Response TestResponse = Client.Get("/test/yo"); + CHECK(TestResponse.IsSuccess()); + CHECK_EQ(TestResponse.AsText(), "hey family"); + } +# if 0 + { + HttpClient Client(fmt::format("10.24.101.77:{}", Port), + HttpClientSettings{}, + /*CheckIfAbortFunction*/ {}); + + ZEN_INFO("Request using {}", Client.GetBaseUri()); + + HttpClient::Response TestResponse = Client.Get("/test/yo"); + CHECK(TestResponse.IsSuccess()); + CHECK_EQ(TestResponse.AsText(), "hey family"); + } + Sleep(20000); +# endif // 0 + HttpSysServer->RequestExit(); + } + } +# endif // ZEN_PLATFORM_WINDOWS +} + +TEST_CASE("httpclient.requestfilter") +{ + using namespace std::literals; + + struct TestHttpService : public HttpService + { + TestHttpService() = default; + + virtual const char* BaseUri() const override { return "/test/"; } + virtual void HandleRequest(HttpServerRequest& HttpServiceRequest) override + { + if (HttpServiceRequest.RelativeUri() == "yo") + { + return HttpServiceRequest.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "hey family"); + } + + { + CHECK(HttpServiceRequest.RelativeUri() != "should_filter"); + return HttpServiceRequest.WriteResponse(HttpResponseCode::InternalServerError); + } + + { + CHECK(HttpServiceRequest.RelativeUri() != "should_forbid"); + return HttpServiceRequest.WriteResponse(HttpResponseCode::InternalServerError); + } + } + }; + + TestHttpService TestService; + ScopedTemporaryDirectory TmpDir; + + class MyFilterImpl : public IHttpRequestFilter + { + public: + virtual Result FilterRequest(HttpServerRequest& Request) + { + if (Request.RelativeUri() == "should_filter") + { + Request.WriteResponse(HttpResponseCode::MethodNotAllowed, HttpContentType::kText, "no thank you"); + return Result::ResponseSent; + } + else if (Request.RelativeUri() == "should_forbid") + { + return Result::Forbidden; + } + return Result::Accepted; + } + }; + + MyFilterImpl MyFilter; + + Ref<HttpServer> AsioServer = CreateHttpAsioServer(AsioConfig{}); + + AsioServer->SetHttpRequestFilter(&MyFilter); + + int Port = AsioServer->Initialize(7575, TmpDir.Path()); + REQUIRE(Port != -1); + + AsioServer->RegisterService(TestService); + + std::thread ServerThread([&]() { AsioServer->Run(false); }); + + { + auto _ = MakeGuard([&]() { + if (ServerThread.joinable()) + { + ServerThread.join(); + } + AsioServer->Close(); + }); + + HttpClient Client(fmt::format("localhost:{}", Port), + HttpClientSettings{}, + /*CheckIfAbortFunction*/ {}); + + ZEN_INFO("Request using {}", Client.GetBaseUri()); + + HttpClient::Response YoResponse = Client.Get("/test/yo"); + CHECK(YoResponse.IsSuccess()); + CHECK_EQ(YoResponse.AsText(), "hey family"); + + HttpClient::Response ShouldFilterResponse = Client.Get("/test/should_filter"); + CHECK_EQ(ShouldFilterResponse.StatusCode, HttpResponseCode::MethodNotAllowed); + CHECK_EQ(ShouldFilterResponse.AsText(), "no thank you"); + + HttpClient::Response ShouldForbitResponse = Client.Get("/test/should_forbid"); + CHECK_EQ(ShouldForbitResponse.StatusCode, HttpResponseCode::Forbidden); + + AsioServer->RequestExit(); + } +} + +TEST_CASE("httpclient.password") +{ + using namespace std::literals; + + struct TestHttpService : public HttpService + { + TestHttpService() = default; + + virtual const char* BaseUri() const override { return "/test/"; } + virtual void HandleRequest(HttpServerRequest& HttpServiceRequest) override + { + if (HttpServiceRequest.RelativeUri() == "yo") + { + return HttpServiceRequest.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "hey family"); + } + + { + CHECK(HttpServiceRequest.RelativeUri() != "should_filter"); + return HttpServiceRequest.WriteResponse(HttpResponseCode::InternalServerError); + } + + { + CHECK(HttpServiceRequest.RelativeUri() != "should_forbid"); + return HttpServiceRequest.WriteResponse(HttpResponseCode::InternalServerError); + } + } + }; + + TestHttpService TestService; + ScopedTemporaryDirectory TmpDir; + + Ref<HttpServer> AsioServer = CreateHttpAsioServer(AsioConfig{}); + + int Port = AsioServer->Initialize(7575, TmpDir.Path()); + REQUIRE(Port != -1); + + AsioServer->RegisterService(TestService); + + std::thread ServerThread([&]() { AsioServer->Run(false); }); + + { + auto _ = MakeGuard([&]() { + if (ServerThread.joinable()) + { + ServerThread.join(); + } + AsioServer->Close(); + }); + + SUBCASE("usernamepassword") + { + CbObjectWriter Writer; + { + Writer.BeginObject("basic"); + { + Writer << "username"sv + << "me"; + Writer << "password"sv + << "456123789"; + } + Writer.EndObject(); + Writer << "protect-machine-local-requests" << true; + } + + PasswordHttpFilter::Configuration PasswordFilterOptions = PasswordHttpFilter::ReadConfiguration(Writer.Save()); + + PasswordHttpFilter MyFilter(PasswordFilterOptions); + + AsioServer->SetHttpRequestFilter(&MyFilter); + + HttpClient Client(fmt::format("localhost:{}", Port), + HttpClientSettings{}, + /*CheckIfAbortFunction*/ {}); + + ZEN_INFO("Request using {}", Client.GetBaseUri()); + + HttpClient::Response ForbiddenResponse = Client.Get("/test/yo"); + CHECK(!ForbiddenResponse.IsSuccess()); + CHECK_EQ(ForbiddenResponse.StatusCode, HttpResponseCode::Forbidden); + + HttpClient::Response WithBasicResponse = + Client.Get("/test/yo", + std::pair<std::string, std::string>("Authorization", + fmt::format("Basic {}", PasswordFilterOptions.PasswordConfig.Password))); + CHECK(WithBasicResponse.IsSuccess()); + AsioServer->SetHttpRequestFilter(nullptr); + } + AsioServer->RequestExit(); + } } +TEST_SUITE_END(); void httpclient_forcelink() diff --git a/src/zenhttp/httpclient_test.cpp b/src/zenhttp/httpclient_test.cpp new file mode 100644 index 000000000..52bf149a7 --- /dev/null +++ b/src/zenhttp/httpclient_test.cpp @@ -0,0 +1,1366 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zenhttp/httpclient.h> +#include <zenhttp/httpserver.h> + +#if ZEN_WITH_TESTS + +# include <zencore/compactbinarybuilder.h> +# include <zencore/compactbinaryutil.h> +# include <zencore/compositebuffer.h> +# include <zencore/iobuffer.h> +# include <zencore/logging.h> +# include <zencore/scopeguard.h> +# include <zencore/session.h> +# include <zencore/testing.h> +# include <zencore/testutils.h> + +# include "servers/httpasio.h" + +# include <atomic> +# include <thread> + +ZEN_THIRD_PARTY_INCLUDES_START +# include <asio.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + +namespace zen { + +using namespace std::literals; + +////////////////////////////////////////////////////////////////////////// +// Test service + +class HttpClientTestService : public HttpService +{ +public: + HttpClientTestService() + { + m_Router.AddMatcher("statuscode", [](std::string_view Str) -> bool { + for (char C : Str) + { + if (C < '0' || C > '9') + { + return false; + } + } + return !Str.empty(); + }); + + m_Router.RegisterRoute( + "hello", + [](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "hello world"); }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "echo", + [](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + IoBuffer Body = HttpReq.ReadPayload(); + HttpContentType CT = HttpReq.RequestContentType(); + HttpReq.WriteResponse(HttpResponseCode::OK, CT, Body); + }, + HttpVerb::kPost | HttpVerb::kPut); + + m_Router.RegisterRoute( + "echo/headers", + [](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + std::string_view Auth = HttpReq.GetAuthorizationHeader(); + CbObjectWriter Writer; + if (!Auth.empty()) + { + Writer.AddString("Authorization", Auth); + } + HttpReq.WriteResponse(HttpResponseCode::OK, Writer.Save()); + }, + HttpVerb::kGet | HttpVerb::kPost); + + m_Router.RegisterRoute( + "echo/method", + [](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + std::string_view Method = ToString(HttpReq.RequestVerb()); + HttpReq.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Method); + }, + HttpVerb::kGet | HttpVerb::kPost | HttpVerb::kPut | HttpVerb::kDelete | HttpVerb::kHead); + + m_Router.RegisterRoute( + "json", + [](HttpRouterRequest& Req) { + CbObjectWriter Obj; + Obj.AddBool("ok", true); + Obj.AddString("message", "test"); + Req.ServerRequest().WriteResponse(HttpResponseCode::OK, Obj.Save()); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "nocontent", + [](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::NoContent); }, + HttpVerb::kGet | HttpVerb::kPost | HttpVerb::kPut | HttpVerb::kDelete); + + m_Router.RegisterRoute( + "created", + [](HttpRouterRequest& Req) { + Req.ServerRequest().WriteResponse(HttpResponseCode::Created, HttpContentType::kText, "resource created"); + }, + HttpVerb::kPost | HttpVerb::kPut); + + m_Router.RegisterRoute( + "content-type/text", + [](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "plain text"); }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "content-type/json", + [](HttpRouterRequest& Req) { + Req.ServerRequest().WriteResponse(HttpResponseCode::OK, HttpContentType::kJSON, "{\"key\":\"value\"}"); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "content-type/binary", + [](HttpRouterRequest& Req) { + uint8_t Data[] = {0xDE, 0xAD, 0xBE, 0xEF}; + IoBuffer Buf(IoBuffer::Clone, Data, sizeof(Data)); + Req.ServerRequest().WriteResponse(HttpResponseCode::OK, HttpContentType::kBinary, Buf); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "content-type/cbobject", + [](HttpRouterRequest& Req) { + CbObjectWriter Obj; + Obj.AddString("type", "cbobject"); + Req.ServerRequest().WriteResponse(HttpResponseCode::OK, Obj.Save()); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "auth/bearer", + [](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + std::string_view Auth = HttpReq.GetAuthorizationHeader(); + if (Auth.starts_with("Bearer ") && Auth.size() > 7) + { + HttpReq.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "authenticated"); + } + else + { + HttpReq.WriteResponse(HttpResponseCode::Unauthorized, HttpContentType::kText, "unauthorized"); + } + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "slow", + [](HttpRouterRequest& Req) { + Req.ServerRequest().WriteResponseAsync([](HttpServerRequest& Request) { + Sleep(2000); + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "slow response"); + }); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "large", + [](HttpRouterRequest& Req) { + constexpr size_t Size = 64 * 1024; + IoBuffer Buf(Size); + uint8_t* Ptr = static_cast<uint8_t*>(Buf.MutableData()); + for (size_t i = 0; i < Size; ++i) + { + Ptr[i] = static_cast<uint8_t>(i & 0xFF); + } + Req.ServerRequest().WriteResponse(HttpResponseCode::OK, HttpContentType::kBinary, Buf); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "status/{statuscode}", + [](HttpRouterRequest& Req) { + std::string_view CodeStr = Req.GetCapture(1); + int Code = std::stoi(std::string{CodeStr}); + const HttpResponseCode ResponseCode = static_cast<HttpResponseCode>(Code); + Req.ServerRequest().WriteResponse(ResponseCode); + }, + HttpVerb::kGet | HttpVerb::kPost | HttpVerb::kPut | HttpVerb::kDelete | HttpVerb::kHead); + + m_Router.RegisterRoute( + "attempt-counter", + [this](HttpRouterRequest& Req) { + uint32_t Count = m_AttemptCounter.fetch_add(1); + if (Count < m_FailCount) + { + Req.ServerRequest().WriteResponse(HttpResponseCode::ServiceUnavailable); + } + else + { + Req.ServerRequest().WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "success after retries"); + } + }, + HttpVerb::kGet); + } + + virtual const char* BaseUri() const override { return "/api/test/"; } + virtual void HandleRequest(HttpServerRequest& Request) override { m_Router.HandleRequest(Request); } + + void ResetAttemptCounter(uint32_t FailCount) + { + m_AttemptCounter.store(0); + m_FailCount = FailCount; + } + +private: + HttpRequestRouter m_Router; + std::atomic<uint32_t> m_AttemptCounter{0}; + uint32_t m_FailCount = 2; +}; + +////////////////////////////////////////////////////////////////////////// +// Test server fixture + +struct TestServerFixture +{ + HttpClientTestService TestService; + ScopedTemporaryDirectory TmpDir; + Ref<HttpServer> Server; + std::thread ServerThread; + int Port = -1; + + TestServerFixture() + { + Server = CreateHttpAsioServer(AsioConfig{}); + Port = Server->Initialize(7600, TmpDir.Path()); + ZEN_ASSERT(Port != -1); + Server->RegisterService(TestService); + ServerThread = std::thread([this]() { Server->Run(false); }); + } + + ~TestServerFixture() + { + Server->RequestExit(); + if (ServerThread.joinable()) + { + ServerThread.join(); + } + Server->Close(); + } + + HttpClient MakeClient(HttpClientSettings Settings = {}) + { + return HttpClient(fmt::format("127.0.0.1:{}", Port), Settings, /*CheckIfAbortFunction*/ {}); + } +}; + +////////////////////////////////////////////////////////////////////////// +// Tests + +TEST_SUITE_BEGIN("http.httpclient"); + +TEST_CASE("httpclient.verbs") +{ + TestServerFixture Fixture; + HttpClient Client = Fixture.MakeClient(); + + SUBCASE("GET returns 200 with expected body") + { + HttpClient::Response Resp = Client.Get("/api/test/echo/method"); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), "GET"); + } + + SUBCASE("POST dispatches correctly") + { + HttpClient::Response Resp = Client.Post("/api/test/echo/method"); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), "POST"); + } + + SUBCASE("PUT dispatches correctly") + { + HttpClient::Response Resp = Client.Put("/api/test/echo/method"); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), "PUT"); + } + + SUBCASE("DELETE dispatches correctly") + { + HttpClient::Response Resp = Client.Delete("/api/test/echo/method"); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), "DELETE"); + } + + SUBCASE("HEAD returns 200 with empty body") + { + HttpClient::Response Resp = Client.Head("/api/test/echo/method"); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), ""sv); + } +} + +TEST_CASE("httpclient.get") +{ + TestServerFixture Fixture; + HttpClient Client = Fixture.MakeClient(); + + SUBCASE("simple GET with text response") + { + HttpClient::Response Resp = Client.Get("/api/test/hello"); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.StatusCode, HttpResponseCode::OK); + CHECK_EQ(Resp.AsText(), "hello world"); + } + + SUBCASE("GET with auth header via echo") + { + HttpClient::Response Resp = + Client.Get("/api/test/echo/headers", std::pair<std::string, std::string>("Authorization", "Bearer test-token-123")); + CHECK(Resp.IsSuccess()); + CbObject Obj = Resp.AsObject(); + CHECK_EQ(Obj["Authorization"].AsString(), "Bearer test-token-123"); + } + + SUBCASE("GET returning CbObject") + { + HttpClient::Response Resp = Client.Get("/api/test/json"); + CHECK(Resp.IsSuccess()); + CbObject Obj = Resp.AsObject(); + CHECK(Obj["ok"].AsBool() == true); + CHECK_EQ(Obj["message"].AsString(), "test"); + } + + SUBCASE("GET large payload") + { + HttpClient::Response Resp = Client.Get("/api/test/large"); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.ResponsePayload.GetSize(), 64u * 1024u); + + const uint8_t* Data = static_cast<const uint8_t*>(Resp.ResponsePayload.GetData()); + bool Valid = true; + for (size_t i = 0; i < 64 * 1024; ++i) + { + if (Data[i] != static_cast<uint8_t>(i & 0xFF)) + { + Valid = false; + break; + } + } + CHECK(Valid); + } +} + +TEST_CASE("httpclient.post") +{ + TestServerFixture Fixture; + HttpClient Client = Fixture.MakeClient(); + + SUBCASE("POST with IoBuffer payload echo round-trip") + { + const char* Payload = "test payload data"; + IoBuffer Buf(IoBuffer::Clone, Payload, strlen(Payload)); + Buf.SetContentType(ZenContentType::kText); + + HttpClient::Response Resp = Client.Post("/api/test/echo", Buf); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), "test payload data"); + } + + SUBCASE("POST with IoBuffer and explicit content type") + { + const char* Payload = "{\"key\":\"value\"}"; + IoBuffer Buf(IoBuffer::Clone, Payload, strlen(Payload)); + + HttpClient::Response Resp = Client.Post("/api/test/echo", Buf, ZenContentType::kJSON); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), "{\"key\":\"value\"}"); + } + + SUBCASE("POST with CbObject payload round-trip") + { + CbObjectWriter Writer; + Writer.AddBool("enabled", true); + Writer.AddString("name", "testobj"); + CbObject Obj = Writer.Save(); + + HttpClient::Response Resp = Client.Post("/api/test/echo", Obj); + CHECK(Resp.IsSuccess()); + CbObject RoundTripped = Resp.AsObject(); + CHECK(RoundTripped["enabled"].AsBool() == true); + CHECK_EQ(RoundTripped["name"].AsString(), "testobj"); + } + + SUBCASE("POST with CompositeBuffer payload") + { + const char* Part1 = "hello "; + const char* Part2 = "composite"; + IoBuffer Buf1(IoBuffer::Clone, Part1, strlen(Part1)); + IoBuffer Buf2(IoBuffer::Clone, Part2, strlen(Part2)); + + SharedBuffer Seg1{Buf1}; + SharedBuffer Seg2{Buf2}; + CompositeBuffer Composite{std::move(Seg1), std::move(Seg2)}; + + HttpClient::Response Resp = Client.Post("/api/test/echo", Composite, ZenContentType::kText); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), "hello composite"); + } + + SUBCASE("POST with custom headers") + { + HttpClient::Response Resp = Client.Post("/api/test/echo/headers", HttpClient::KeyValueMap{}, HttpClient::KeyValueMap{}); + CHECK(Resp.IsSuccess()); + } + + SUBCASE("POST with empty body to nocontent endpoint") + { + HttpClient::Response Resp = Client.Post("/api/test/nocontent"); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.StatusCode, HttpResponseCode::NoContent); + } +} + +TEST_CASE("httpclient.put") +{ + TestServerFixture Fixture; + HttpClient Client = Fixture.MakeClient(); + + SUBCASE("PUT with IoBuffer payload echo round-trip") + { + const char* Payload = "put payload data"; + IoBuffer Buf(IoBuffer::Clone, Payload, strlen(Payload)); + Buf.SetContentType(ZenContentType::kText); + + HttpClient::Response Resp = Client.Put("/api/test/echo", Buf); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), "put payload data"); + } + + SUBCASE("PUT with parameters only") + { + HttpClient::Response Resp = Client.Put("/api/test/nocontent"); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.StatusCode, HttpResponseCode::NoContent); + } + + SUBCASE("PUT to created endpoint") + { + const char* Payload = "new resource"; + IoBuffer Buf(IoBuffer::Clone, Payload, strlen(Payload)); + Buf.SetContentType(ZenContentType::kText); + + HttpClient::Response Resp = Client.Put("/api/test/created", Buf); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.StatusCode, HttpResponseCode::Created); + CHECK_EQ(Resp.AsText(), "resource created"); + } +} + +TEST_CASE("httpclient.upload") +{ + TestServerFixture Fixture; + HttpClient Client = Fixture.MakeClient(); + + SUBCASE("Upload IoBuffer") + { + constexpr size_t Size = 128 * 1024; + IoBuffer Blob = CreateSemiRandomBlob(Size); + + HttpClient::Response Resp = Client.Upload("/api/test/echo", Blob); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.ResponsePayload.GetSize(), Size); + } + + SUBCASE("Upload CompositeBuffer") + { + IoBuffer Buf1 = CreateSemiRandomBlob(32 * 1024); + IoBuffer Buf2 = CreateSemiRandomBlob(32 * 1024); + + SharedBuffer Seg1{Buf1}; + SharedBuffer Seg2{Buf2}; + CompositeBuffer Composite{std::move(Seg1), std::move(Seg2)}; + + HttpClient::Response Resp = Client.Upload("/api/test/echo", Composite, ZenContentType::kBinary); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.ResponsePayload.GetSize(), 64u * 1024u); + } +} + +TEST_CASE("httpclient.download") +{ + TestServerFixture Fixture; + ScopedTemporaryDirectory DownloadDir; + + SUBCASE("Download small payload stays in memory") + { + HttpClient Client = Fixture.MakeClient(); + + HttpClient::Response Resp = Client.Download("/api/test/hello", DownloadDir.Path()); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), "hello world"); + } + + SUBCASE("Download with reduced MaximumInMemoryDownloadSize forces file spill") + { + HttpClientSettings Settings; + Settings.MaximumInMemoryDownloadSize = 4; + HttpClient Client = Fixture.MakeClient(Settings); + + HttpClient::Response Resp = Client.Download("/api/test/large", DownloadDir.Path()); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.ResponsePayload.GetSize(), 64u * 1024u); + } +} + +TEST_CASE("httpclient.status-codes") +{ + TestServerFixture Fixture; + HttpClient Client = Fixture.MakeClient(); + + SUBCASE("2xx are success") + { + CHECK(Client.Get("/api/test/status/200").IsSuccess()); + CHECK(Client.Get("/api/test/status/201").IsSuccess()); + CHECK(Client.Get("/api/test/status/204").IsSuccess()); + } + + SUBCASE("4xx are not success") + { + CHECK(!Client.Get("/api/test/status/400").IsSuccess()); + CHECK(!Client.Get("/api/test/status/401").IsSuccess()); + CHECK(!Client.Get("/api/test/status/403").IsSuccess()); + CHECK(!Client.Get("/api/test/status/404").IsSuccess()); + CHECK(!Client.Get("/api/test/status/409").IsSuccess()); + } + + SUBCASE("5xx are not success") + { + CHECK(!Client.Get("/api/test/status/500").IsSuccess()); + CHECK(!Client.Get("/api/test/status/502").IsSuccess()); + CHECK(!Client.Get("/api/test/status/503").IsSuccess()); + } + + SUBCASE("status code values match") + { + CHECK_EQ(Client.Get("/api/test/status/200").StatusCode, HttpResponseCode::OK); + CHECK_EQ(Client.Get("/api/test/status/201").StatusCode, HttpResponseCode::Created); + CHECK_EQ(Client.Get("/api/test/status/204").StatusCode, HttpResponseCode::NoContent); + CHECK_EQ(Client.Get("/api/test/status/400").StatusCode, HttpResponseCode::BadRequest); + CHECK_EQ(Client.Get("/api/test/status/401").StatusCode, HttpResponseCode::Unauthorized); + CHECK_EQ(Client.Get("/api/test/status/403").StatusCode, HttpResponseCode::Forbidden); + CHECK_EQ(Client.Get("/api/test/status/404").StatusCode, HttpResponseCode::NotFound); + CHECK_EQ(Client.Get("/api/test/status/409").StatusCode, HttpResponseCode::Conflict); + CHECK_EQ(Client.Get("/api/test/status/500").StatusCode, HttpResponseCode::InternalServerError); + CHECK_EQ(Client.Get("/api/test/status/502").StatusCode, HttpResponseCode::BadGateway); + CHECK_EQ(Client.Get("/api/test/status/503").StatusCode, HttpResponseCode::ServiceUnavailable); + } +} + +TEST_CASE("httpclient.response") +{ + TestServerFixture Fixture; + HttpClient Client = Fixture.MakeClient(); + + SUBCASE("IsSuccess and operator bool for success") + { + HttpClient::Response Resp = Client.Get("/api/test/hello"); + CHECK(Resp.IsSuccess()); + CHECK(static_cast<bool>(Resp)); + } + + SUBCASE("IsSuccess and operator bool for failure") + { + HttpClient::Response Resp = Client.Get("/api/test/status/404"); + CHECK(!Resp.IsSuccess()); + CHECK(!static_cast<bool>(Resp)); + } + + SUBCASE("AsText returns body") + { + HttpClient::Response Resp = Client.Get("/api/test/hello"); + CHECK_EQ(Resp.AsText(), "hello world"); + } + + SUBCASE("AsText returns empty for no-content") + { + HttpClient::Response Resp = Client.Get("/api/test/nocontent"); + CHECK(Resp.AsText().empty()); + } + + SUBCASE("AsObject parses CbObject") + { + HttpClient::Response Resp = Client.Get("/api/test/json"); + CbObject Obj = Resp.AsObject(); + CHECK(Obj["ok"].AsBool() == true); + CHECK_EQ(Obj["message"].AsString(), "test"); + } + + SUBCASE("AsObject returns empty for non-CB content") + { + HttpClient::Response Resp = Client.Get("/api/test/hello"); + CbObject Obj = Resp.AsObject(); + CHECK(!Obj); + } + + SUBCASE("ToText for text content") + { + HttpClient::Response Resp = Client.Get("/api/test/content-type/text"); + CHECK_EQ(Resp.ToText(), "plain text"); + } + + SUBCASE("ToText for CbObject content") + { + HttpClient::Response Resp = Client.Get("/api/test/json"); + std::string Text = Resp.ToText(); + CHECK(!Text.empty()); + // ToText for CbObject converts to JSON string representation + CHECK(Text.find("ok") != std::string::npos); + CHECK(Text.find("test") != std::string::npos); + } + + SUBCASE("ErrorMessage includes status code on failure") + { + HttpClient::Response Resp = Client.Get("/api/test/status/404"); + std::string Msg = Resp.ErrorMessage("test-prefix"); + CHECK(Msg.find("test-prefix") != std::string::npos); + CHECK(Msg.find("404") != std::string::npos); + } + + SUBCASE("ThrowError throws on failure") + { + HttpClient::Response Resp = Client.Get("/api/test/status/500"); + CHECK_THROWS_AS(Resp.ThrowError("test"), HttpClientError); + } + + SUBCASE("ThrowError does not throw on success") + { + HttpClient::Response Resp = Client.Get("/api/test/hello"); + CHECK_NOTHROW(Resp.ThrowError("test")); + } + + SUBCASE("HttpClientError carries response code") + { + HttpClient::Response Resp = Client.Get("/api/test/status/403"); + try + { + Resp.ThrowError("test"); + CHECK(false); // should not reach + } + catch (const HttpClientError& Err) + { + CHECK_EQ(Err.GetHttpResponseCode(), HttpResponseCode::Forbidden); + } + } +} + +TEST_CASE("httpclient.error-handling") +{ + SUBCASE("Connection refused") + { + HttpClient Client("127.0.0.1:19999", HttpClientSettings{}, /*CheckIfAbortFunction*/ {}); + HttpClient::Response Resp = Client.Get("/api/test/hello"); + CHECK(!Resp.IsSuccess()); + CHECK(Resp.Error.has_value()); + } + + SUBCASE("Request timeout") + { + TestServerFixture Fixture; + HttpClientSettings Settings; + Settings.Timeout = std::chrono::milliseconds(500); + HttpClient Client = Fixture.MakeClient(Settings); + + HttpClient::Response Resp = Client.Get("/api/test/slow"); + CHECK(!Resp.IsSuccess()); + } + + SUBCASE("Nonexistent endpoint returns failure") + { + TestServerFixture Fixture; + HttpClient Client = Fixture.MakeClient(); + + HttpClient::Response Resp = Client.Get("/api/test/does-not-exist"); + CHECK(!Resp.IsSuccess()); + } +} + +TEST_CASE("httpclient.session") +{ + TestServerFixture Fixture; + + SUBCASE("Default session ID is non-empty") + { + HttpClient Client = Fixture.MakeClient(); + CHECK(!Client.GetSessionId().empty()); + } + + SUBCASE("SetSessionId changes ID") + { + HttpClient Client = Fixture.MakeClient(); + Oid NewId = Oid::NewOid(); + std::string OldId = std::string(Client.GetSessionId()); + Client.SetSessionId(NewId); + CHECK_EQ(Client.GetSessionId(), NewId.ToString()); + CHECK_NE(Client.GetSessionId(), OldId); + } + + SUBCASE("SetSessionId with Zero resets") + { + HttpClient Client = Fixture.MakeClient(); + Oid NewId = Oid::NewOid(); + Client.SetSessionId(NewId); + CHECK_EQ(Client.GetSessionId(), NewId.ToString()); + Client.SetSessionId(Oid::Zero); + // After resetting, should get a session string (not empty, not the custom one) + CHECK(!Client.GetSessionId().empty()); + CHECK_NE(Client.GetSessionId(), NewId.ToString()); + } +} + +TEST_CASE("httpclient.authentication") +{ + TestServerFixture Fixture; + + SUBCASE("Authenticate returns false without provider") + { + HttpClient Client = Fixture.MakeClient(); + CHECK(!Client.Authenticate()); + } + + SUBCASE("Authenticate returns true with valid token") + { + HttpClientSettings Settings; + Settings.AccessTokenProvider = []() -> HttpClientAccessToken { + return HttpClientAccessToken{ + .Value = "valid-token", + .ExpireTime = HttpClientAccessToken::Clock::now() + std::chrono::hours(1), + }; + }; + HttpClient Client = Fixture.MakeClient(Settings); + CHECK(Client.Authenticate()); + } + + SUBCASE("Authenticate returns false with expired token") + { + HttpClientSettings Settings; + Settings.AccessTokenProvider = []() -> HttpClientAccessToken { + return HttpClientAccessToken{ + .Value = "expired-token", + .ExpireTime = HttpClientAccessToken::Clock::now() - std::chrono::hours(1), + }; + }; + HttpClient Client = Fixture.MakeClient(Settings); + CHECK(!Client.Authenticate()); + } + + SUBCASE("Bearer token verified by auth endpoint") + { + HttpClient Client = Fixture.MakeClient(); + + HttpClient::Response AuthResp = + Client.Get("/api/test/auth/bearer", std::pair<std::string, std::string>("Authorization", "Bearer my-secret-token")); + CHECK(AuthResp.IsSuccess()); + CHECK_EQ(AuthResp.AsText(), "authenticated"); + } + + SUBCASE("Request without token to auth endpoint gets 401") + { + HttpClient Client = Fixture.MakeClient(); + + HttpClient::Response Resp = Client.Get("/api/test/auth/bearer"); + CHECK(!Resp.IsSuccess()); + CHECK_EQ(Resp.StatusCode, HttpResponseCode::Unauthorized); + } +} + +TEST_CASE("httpclient.content-types") +{ + TestServerFixture Fixture; + HttpClient Client = Fixture.MakeClient(); + + SUBCASE("text content type") + { + HttpClient::Response Resp = Client.Get("/api/test/content-type/text"); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.ResponsePayload.GetContentType(), ZenContentType::kText); + } + + SUBCASE("JSON content type") + { + HttpClient::Response Resp = Client.Get("/api/test/content-type/json"); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.ResponsePayload.GetContentType(), ZenContentType::kJSON); + } + + SUBCASE("binary content type") + { + HttpClient::Response Resp = Client.Get("/api/test/content-type/binary"); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.ResponsePayload.GetContentType(), ZenContentType::kBinary); + } + + SUBCASE("CbObject content type") + { + HttpClient::Response Resp = Client.Get("/api/test/content-type/cbobject"); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.ResponsePayload.GetContentType(), ZenContentType::kCbObject); + } +} + +TEST_CASE("httpclient.metadata") +{ + TestServerFixture Fixture; + HttpClient Client = Fixture.MakeClient(); + + SUBCASE("ElapsedSeconds is positive") + { + HttpClient::Response Resp = Client.Get("/api/test/hello"); + CHECK(Resp.IsSuccess()); + CHECK(Resp.ElapsedSeconds > 0.0); + } + + SUBCASE("DownloadedBytes populated for GET") + { + HttpClient::Response Resp = Client.Get("/api/test/hello"); + CHECK(Resp.IsSuccess()); + CHECK(Resp.DownloadedBytes > 0); + } + + SUBCASE("UploadedBytes populated for POST with payload") + { + const char* Payload = "some upload data"; + IoBuffer Buf(IoBuffer::Clone, Payload, strlen(Payload)); + Buf.SetContentType(ZenContentType::kText); + + HttpClient::Response Resp = Client.Post("/api/test/echo", Buf); + CHECK(Resp.IsSuccess()); + CHECK(Resp.UploadedBytes > 0); + } +} + +TEST_CASE("httpclient.retry") +{ + TestServerFixture Fixture; + + SUBCASE("Retry succeeds after transient failures") + { + Fixture.TestService.ResetAttemptCounter(2); + + HttpClientSettings Settings; + Settings.RetryCount = 3; + HttpClient Client = Fixture.MakeClient(Settings); + + HttpClient::Response Resp = Client.Get("/api/test/attempt-counter"); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), "success after retries"); + } + + SUBCASE("No retry returns 503 immediately") + { + Fixture.TestService.ResetAttemptCounter(2); + + HttpClientSettings Settings; + Settings.RetryCount = 0; + HttpClient Client = Fixture.MakeClient(Settings); + + HttpClient::Response Resp = Client.Get("/api/test/attempt-counter"); + CHECK(!Resp.IsSuccess()); + CHECK_EQ(Resp.StatusCode, HttpResponseCode::ServiceUnavailable); + } +} + +TEST_CASE("httpclient.measurelatency") +{ + SUBCASE("Successful measurement against live server") + { + TestServerFixture Fixture; + HttpClient Client = Fixture.MakeClient(); + + LatencyTestResult Result = MeasureLatency(Client, "/api/test/hello"); + CHECK(Result.Success); + CHECK(Result.LatencySeconds > 0.0); + } + + SUBCASE("Failed measurement against unreachable port") + { + HttpClient Client("127.0.0.1:19999", HttpClientSettings{}, /*CheckIfAbortFunction*/ {}); + LatencyTestResult Result = MeasureLatency(Client, "/api/test/hello"); + CHECK(!Result.Success); + CHECK(!Result.FailureReason.empty()); + } +} + +TEST_CASE("httpclient.keyvaluemap") +{ + SUBCASE("Default construction is empty") + { + HttpClient::KeyValueMap Map; + CHECK(Map->empty()); + } + + SUBCASE("Construction from pair") + { + HttpClient::KeyValueMap Map(std::pair<std::string, std::string>("key", "value")); + CHECK_EQ(Map->size(), 1u); + CHECK_EQ(Map->at("key"), "value"); + } + + SUBCASE("Construction from string_view pair") + { + HttpClient::KeyValueMap Map(std::pair<std::string_view, std::string_view>("key"sv, "value"sv)); + CHECK_EQ(Map->size(), 1u); + CHECK_EQ(Map->at("key"), "value"); + } + + SUBCASE("Construction from initializer list") + { + HttpClient::KeyValueMap Map({{"a"sv, "1"sv}, {"b"sv, "2"sv}}); + CHECK_EQ(Map->size(), 2u); + CHECK_EQ(Map->at("a"), "1"); + CHECK_EQ(Map->at("b"), "2"); + } +} + +////////////////////////////////////////////////////////////////////////// +// Transport fault testing + +static std::string +MakeRawHttpResponse(int StatusCode, std::string_view Body) +{ + return fmt::format( + "HTTP/1.1 {} OK\r\n" + "Content-Type: text/plain\r\n" + "Content-Length: {}\r\n" + "\r\n" + "{}", + StatusCode, + Body.size(), + Body); +} + +static std::string +MakeRawHttpHeaders(int StatusCode, size_t ContentLength) +{ + return fmt::format( + "HTTP/1.1 {} OK\r\n" + "Content-Type: application/octet-stream\r\n" + "Content-Length: {}\r\n" + "\r\n", + StatusCode, + ContentLength); +} + +static void +DrainHttpRequest(asio::ip::tcp::socket& Socket) +{ + asio::streambuf Buf; + std::error_code Ec; + asio::read_until(Socket, Buf, "\r\n\r\n", Ec); +} + +static void +DrainFullHttpRequest(asio::ip::tcp::socket& Socket) +{ + // Read until end of headers + asio::streambuf Buf; + std::error_code Ec; + asio::read_until(Socket, Buf, "\r\n\r\n", Ec); + if (Ec) + { + return; + } + + // Extract headers to find Content-Length + std::string Headers(asio::buffers_begin(Buf.data()), asio::buffers_end(Buf.data())); + + size_t ContentLength = 0; + auto Pos = Headers.find("Content-Length: "); + if (Pos == std::string::npos) + { + Pos = Headers.find("content-length: "); + } + if (Pos != std::string::npos) + { + size_t ValStart = Pos + 16; // length of "Content-Length: " + size_t ValEnd = Headers.find("\r\n", ValStart); + if (ValEnd != std::string::npos) + { + ContentLength = std::stoull(Headers.substr(ValStart, ValEnd - ValStart)); + } + } + + // Calculate how many body bytes were already read past the header boundary. + // asio::read_until may read past the delimiter, so Buf.data() contains everything read. + size_t HeaderEnd = Headers.find("\r\n\r\n") + 4; + size_t BodyBytesInBuf = Headers.size() > HeaderEnd ? Headers.size() - HeaderEnd : 0; + size_t Remaining = ContentLength > BodyBytesInBuf ? ContentLength - BodyBytesInBuf : 0; + + if (Remaining > 0) + { + std::vector<char> BodyBuf(Remaining); + asio::read(Socket, asio::buffer(BodyBuf), Ec); + } +} + +static void +DrainPartialBody(asio::ip::tcp::socket& Socket, size_t BytesToRead) +{ + // Read headers first + asio::streambuf Buf; + std::error_code Ec; + asio::read_until(Socket, Buf, "\r\n\r\n", Ec); + if (Ec) + { + return; + } + + // Determine how many body bytes were already buffered past headers + std::string All(asio::buffers_begin(Buf.data()), asio::buffers_end(Buf.data())); + size_t HeaderEnd = All.find("\r\n\r\n") + 4; + size_t BodyBytesInBuf = All.size() > HeaderEnd ? All.size() - HeaderEnd : 0; + + if (BodyBytesInBuf < BytesToRead) + { + size_t Remaining = BytesToRead - BodyBytesInBuf; + std::vector<char> BodyBuf(Remaining); + asio::read(Socket, asio::buffer(BodyBuf), Ec); + } +} + +struct FaultTcpServer +{ + using FaultHandler = std::function<void(asio::ip::tcp::socket&)>; + + asio::io_context m_IoContext; + asio::ip::tcp::acceptor m_Acceptor; + FaultHandler m_Handler; + std::thread m_Thread; + int m_Port; + + explicit FaultTcpServer(FaultHandler Handler) + : m_Acceptor(m_IoContext, asio::ip::tcp::endpoint(asio::ip::address_v4::loopback(), 0)) + , m_Handler(std::move(Handler)) + { + m_Port = m_Acceptor.local_endpoint().port(); + StartAccept(); + m_Thread = std::thread([this]() { m_IoContext.run(); }); + } + + ~FaultTcpServer() + { + std::error_code Ec; + m_Acceptor.close(Ec); + m_IoContext.stop(); + if (m_Thread.joinable()) + { + m_Thread.join(); + } + } + + FaultTcpServer(const FaultTcpServer&) = delete; + FaultTcpServer& operator=(const FaultTcpServer&) = delete; + + void StartAccept() + { + m_Acceptor.async_accept([this](std::error_code Ec, asio::ip::tcp::socket Socket) { + if (!Ec) + { + m_Handler(Socket); + } + if (m_Acceptor.is_open()) + { + StartAccept(); + } + }); + } + + HttpClient MakeClient(HttpClientSettings Settings = {}) + { + return HttpClient(fmt::format("127.0.0.1:{}", m_Port), Settings, /*CheckIfAbortFunction*/ {}); + } +}; + +TEST_CASE("httpclient.transport-faults" * doctest::skip()) +{ + SUBCASE("connection reset before response") + { + FaultTcpServer Server([](asio::ip::tcp::socket& Socket) { + DrainHttpRequest(Socket); + std::error_code Ec; + Socket.set_option(asio::socket_base::linger(true, 0), Ec); + Socket.close(Ec); + }); + HttpClient Client = Server.MakeClient(); + HttpClient::Response Resp = Client.Get("/test"); + CHECK(!Resp.IsSuccess()); + CHECK(Resp.Error.has_value()); + } + + SUBCASE("connection closed before response") + { + FaultTcpServer Server([](asio::ip::tcp::socket& Socket) { + DrainHttpRequest(Socket); + std::error_code Ec; + Socket.shutdown(asio::ip::tcp::socket::shutdown_both, Ec); + Socket.close(Ec); + }); + HttpClient Client = Server.MakeClient(); + HttpClient::Response Resp = Client.Get("/test"); + CHECK(!Resp.IsSuccess()); + CHECK(Resp.Error.has_value()); + } + + SUBCASE("partial headers then close") + { + // libcurl parses the status line (200 OK) and accepts the response even though + // headers are truncated mid-field. It reports success with an empty body instead + // of an error. Ideally this should be detected as a transport failure. + FaultTcpServer Server([](asio::ip::tcp::socket& Socket) { + DrainHttpRequest(Socket); + std::string Partial = "HTTP/1.1 200 OK\r\nContent-"; + std::error_code Ec; + asio::write(Socket, asio::buffer(Partial), Ec); + Socket.close(Ec); + }); + HttpClient Client = Server.MakeClient(); + HttpClient::Response Resp = Client.Get("/test"); + WARN(!Resp.IsSuccess()); + WARN(Resp.Error.has_value()); + } + + SUBCASE("truncated body") + { + FaultTcpServer Server([](asio::ip::tcp::socket& Socket) { + DrainHttpRequest(Socket); + std::string Headers = MakeRawHttpHeaders(200, 1000); + std::error_code Ec; + asio::write(Socket, asio::buffer(Headers), Ec); + std::string PartialBody(100, 'x'); + asio::write(Socket, asio::buffer(PartialBody), Ec); + Socket.close(Ec); + }); + HttpClient Client = Server.MakeClient(); + HttpClient::Response Resp = Client.Get("/test"); + CHECK(!Resp.IsSuccess()); + CHECK(Resp.Error.has_value()); + } + + SUBCASE("connection reset mid-body") + { + FaultTcpServer Server([](asio::ip::tcp::socket& Socket) { + DrainHttpRequest(Socket); + std::string Headers = MakeRawHttpHeaders(200, 10000); + std::error_code Ec; + asio::write(Socket, asio::buffer(Headers), Ec); + std::string PartialBody(1000, 'x'); + asio::write(Socket, asio::buffer(PartialBody), Ec); + Socket.set_option(asio::socket_base::linger(true, 0), Ec); + Socket.close(Ec); + }); + HttpClient Client = Server.MakeClient(); + HttpClient::Response Resp = Client.Get("/test"); + CHECK(!Resp.IsSuccess()); + CHECK(Resp.Error.has_value()); + } + + SUBCASE("stalled response triggers timeout") + { + std::atomic<bool> StallActive{true}; + FaultTcpServer Server([&StallActive](asio::ip::tcp::socket& Socket) { + DrainHttpRequest(Socket); + std::string Headers = MakeRawHttpHeaders(200, 1000); + std::error_code Ec; + asio::write(Socket, asio::buffer(Headers), Ec); + while (StallActive.load()) + { + std::this_thread::sleep_for(std::chrono::milliseconds(50)); + } + }); + + HttpClientSettings Settings; + Settings.Timeout = std::chrono::milliseconds(500); + HttpClient Client = Server.MakeClient(Settings); + + HttpClient::Response Resp = Client.Get("/test"); + CHECK(!Resp.IsSuccess()); + CHECK(Resp.Error.has_value()); + StallActive.store(false); + } + + SUBCASE("retry succeeds after transient failures") + { + std::atomic<int> ConnCount{0}; + FaultTcpServer Server([&ConnCount](asio::ip::tcp::socket& Socket) { + int N = ConnCount.fetch_add(1); + DrainHttpRequest(Socket); + if (N < 2) + { + // Connection reset produces NETWORK_SEND_FAILURE which is retryable + std::error_code Ec; + Socket.set_option(asio::socket_base::linger(true, 0), Ec); + Socket.close(Ec); + } + else + { + std::string Response = MakeRawHttpResponse(200, "recovered"); + std::error_code Ec; + asio::write(Socket, asio::buffer(Response), Ec); + } + }); + + HttpClientSettings Settings; + Settings.RetryCount = 3; + HttpClient Client = Server.MakeClient(Settings); + + HttpClient::Response Resp = Client.Get("/test"); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), "recovered"); + } +} + +TEST_CASE("httpclient.transport-faults-post" * doctest::skip()) +{ + constexpr size_t kPostBodySize = 256 * 1024; + + auto MakePostBody = []() -> IoBuffer { + IoBuffer Buf(kPostBodySize); + uint8_t* Ptr = static_cast<uint8_t*>(Buf.MutableData()); + for (size_t i = 0; i < kPostBodySize; ++i) + { + Ptr[i] = static_cast<uint8_t>(i & 0xFF); + } + Buf.SetContentType(ZenContentType::kBinary); + return Buf; + }; + + SUBCASE("POST: server resets before consuming body") + { + FaultTcpServer Server([](asio::ip::tcp::socket& Socket) { + DrainHttpRequest(Socket); + std::error_code Ec; + Socket.set_option(asio::socket_base::linger(true, 0), Ec); + Socket.close(Ec); + }); + HttpClient Client = Server.MakeClient(); + IoBuffer Body = MakePostBody(); + HttpClient::Response Resp = Client.Post("/test", Body); + CHECK(!Resp.IsSuccess()); + CHECK(Resp.Error.has_value()); + } + + SUBCASE("POST: server closes before consuming body") + { + FaultTcpServer Server([](asio::ip::tcp::socket& Socket) { + DrainHttpRequest(Socket); + std::error_code Ec; + Socket.shutdown(asio::ip::tcp::socket::shutdown_both, Ec); + Socket.close(Ec); + }); + HttpClient Client = Server.MakeClient(); + IoBuffer Body = MakePostBody(); + HttpClient::Response Resp = Client.Post("/test", Body); + CHECK(!Resp.IsSuccess()); + CHECK(Resp.Error.has_value()); + } + + SUBCASE("POST: server resets mid-body") + { + FaultTcpServer Server([](asio::ip::tcp::socket& Socket) { + DrainPartialBody(Socket, 8 * 1024); + std::error_code Ec; + Socket.set_option(asio::socket_base::linger(true, 0), Ec); + Socket.close(Ec); + }); + HttpClient Client = Server.MakeClient(); + IoBuffer Body = MakePostBody(); + HttpClient::Response Resp = Client.Post("/test", Body); + CHECK(!Resp.IsSuccess()); + CHECK(Resp.Error.has_value()); + } + + SUBCASE("POST: early error response before consuming body") + { + FaultTcpServer Server([](asio::ip::tcp::socket& Socket) { + DrainHttpRequest(Socket); + std::string Response = MakeRawHttpResponse(503, "service busy"); + std::error_code Ec; + asio::write(Socket, asio::buffer(Response), Ec); + Socket.shutdown(asio::ip::tcp::socket::shutdown_both, Ec); + Socket.close(Ec); + }); + HttpClient Client = Server.MakeClient(); + IoBuffer Body = MakePostBody(); + HttpClient::Response Resp = Client.Post("/test", Body); + CHECK(!Resp.IsSuccess()); + // With a large upload body, the server may RST the connection before the client + // reads the 503 response. Either outcome is valid: the client sees the HTTP 503 + // status, or it sees a transport-level error from the RST. + CHECK((Resp.StatusCode == HttpResponseCode::ServiceUnavailable || Resp.Error.has_value())); + } + + SUBCASE("POST: stalled upload triggers timeout") + { + std::atomic<bool> StallActive{true}; + FaultTcpServer Server([&StallActive](asio::ip::tcp::socket& Socket) { + DrainHttpRequest(Socket); + // Stop reading body — TCP window will fill and client send will stall + while (StallActive.load()) + { + std::this_thread::sleep_for(std::chrono::milliseconds(50)); + } + }); + + HttpClientSettings Settings; + Settings.Timeout = std::chrono::milliseconds(2000); + HttpClient Client = Server.MakeClient(Settings); + + IoBuffer Body = MakePostBody(); + HttpClient::Response Resp = Client.Post("/test", Body); + CHECK(!Resp.IsSuccess()); + CHECK(Resp.Error.has_value()); + StallActive.store(false); + } + + SUBCASE("POST: retry with large body after transient failure") + { + std::atomic<int> ConnCount{0}; + FaultTcpServer Server([&ConnCount](asio::ip::tcp::socket& Socket) { + int N = ConnCount.fetch_add(1); + if (N < 2) + { + DrainHttpRequest(Socket); + std::error_code Ec; + Socket.set_option(asio::socket_base::linger(true, 0), Ec); + Socket.close(Ec); + } + else + { + DrainFullHttpRequest(Socket); + std::string Response = MakeRawHttpResponse(200, "upload-ok"); + std::error_code Ec; + asio::write(Socket, asio::buffer(Response), Ec); + } + }); + + HttpClientSettings Settings; + Settings.RetryCount = 3; + HttpClient Client = Server.MakeClient(Settings); + + IoBuffer Body = MakePostBody(); + HttpClient::Response Resp = Client.Post("/test", Body); + CHECK(Resp.IsSuccess()); + CHECK_EQ(Resp.AsText(), "upload-ok"); + } +} + +TEST_SUITE_END(); + +void +httpclient_test_forcelink() +{ +} + +} // namespace zen + +#endif diff --git a/src/zenhttp/httpclientauth.cpp b/src/zenhttp/httpclientauth.cpp index 72df12d02..02e1b57e2 100644 --- a/src/zenhttp/httpclientauth.cpp +++ b/src/zenhttp/httpclientauth.cpp @@ -170,7 +170,7 @@ namespace zen { namespace httpclientauth { time_t UTCTime = timegm(&Time); HttpClientAccessToken::TimePoint ExpireTime = std::chrono::system_clock::from_time_t(UTCTime); - ExpireTime += std::chrono::microseconds(Millisecond); + ExpireTime += std::chrono::milliseconds(Millisecond); return HttpClientAccessToken{.Value = fmt::format("Bearer {}"sv, Token), .ExpireTime = ExpireTime}; } diff --git a/src/zenhttp/httpserver.cpp b/src/zenhttp/httpserver.cpp index c4e67d4ed..9bae95690 100644 --- a/src/zenhttp/httpserver.cpp +++ b/src/zenhttp/httpserver.cpp @@ -23,10 +23,12 @@ #include <zencore/logging.h> #include <zencore/stream.h> #include <zencore/string.h> +#include <zencore/system.h> #include <zencore/testing.h> #include <zencore/thread.h> #include <zenhttp/packageformat.h> #include <zentelemetry/otlptrace.h> +#include <zentelemetry/stats.h> #include <charconv> #include <mutex> @@ -463,7 +465,7 @@ HttpService::HandlePackageRequest(HttpServerRequest& HttpServiceRequest) ////////////////////////////////////////////////////////////////////////// -HttpServerRequest::HttpServerRequest(HttpService& Service) : m_BaseUri(Service.BaseUri()) +HttpServerRequest::HttpServerRequest(HttpService& Service) : m_Service(Service) { } @@ -745,6 +747,10 @@ HttpRequestRouter::RegisterRoute(const char* UriPattern, HttpRequestRouter::Hand { if (UriPattern[i] == '}') { + if (i == PatternStart) + { + throw std::runtime_error(fmt::format("matcher pattern is empty in URI pattern '{}'", UriPattern)); + } std::string_view Pattern(&UriPattern[PatternStart], i - PatternStart); if (auto it = m_MatcherNameMap.find(std::string(Pattern)); it != m_MatcherNameMap.end()) { @@ -910,8 +916,9 @@ HttpRequestRouter::HandleRequest(zen::HttpServerRequest& Request) CapturedSegments.emplace_back(Uri); - for (int MatcherIndex : Matchers) + for (size_t MatcherOffset = 0; MatcherOffset < Matchers.size(); MatcherOffset++) { + int MatcherIndex = Matchers[MatcherOffset]; if (UriPos >= UriLen) { IsMatch = false; @@ -921,9 +928,9 @@ HttpRequestRouter::HandleRequest(zen::HttpServerRequest& Request) if (MatcherIndex < 0) { // Literal match - int LitIndex = -MatcherIndex - 1; - const std::string& LitStr = m_Literals[LitIndex]; - size_t LitLen = LitStr.length(); + int LitIndex = -MatcherIndex - 1; + std::string_view LitStr = m_Literals[LitIndex]; + size_t LitLen = LitStr.length(); if (Uri.substr(UriPos, LitLen) == LitStr) { @@ -939,9 +946,18 @@ HttpRequestRouter::HandleRequest(zen::HttpServerRequest& Request) { // Matcher function size_t SegmentStart = UriPos; - while (UriPos < UriLen && Uri[UriPos] != '/') + + if (MatcherOffset == (Matchers.size() - 1)) + { + // Last matcher, use the remaining part of the uri + UriPos = UriLen; + } + else { - ++UriPos; + while (UriPos < UriLen && Uri[UriPos] != '/') + { + ++UriPos; + } } std::string_view Segment = Uri.substr(SegmentStart, UriPos - SegmentStart); @@ -970,7 +986,7 @@ HttpRequestRouter::HandleRequest(zen::HttpServerRequest& Request) if (otel::Span* ActiveSpan = otel::Span::GetCurrentSpan()) { ExtendableStringBuilder<128> RoutePath; - RoutePath.Append(Request.BaseUri()); + RoutePath.Append(Request.Service().BaseUri()); RoutePath.Append(Handler.Pattern); ActiveSpan->AddAttribute("http.route"sv, RoutePath.ToView()); } @@ -994,7 +1010,7 @@ HttpRequestRouter::HandleRequest(zen::HttpServerRequest& Request) if (otel::Span* ActiveSpan = otel::Span::GetCurrentSpan()) { ExtendableStringBuilder<128> RoutePath; - RoutePath.Append(Request.BaseUri()); + RoutePath.Append(Request.Service().BaseUri()); RoutePath.Append(Handler.Pattern); ActiveSpan->AddAttribute("http.route"sv, RoutePath.ToView()); } @@ -1014,7 +1030,28 @@ HttpRequestRouter::HandleRequest(zen::HttpServerRequest& Request) int HttpServer::Initialize(int BasePort, std::filesystem::path DataDir) { - return OnInitialize(BasePort, std::move(DataDir)); + m_EffectivePort = OnInitialize(BasePort, std::move(DataDir)); + m_ExternalHost = OnGetExternalHost(); + return m_EffectivePort; +} + +std::string +HttpServer::OnGetExternalHost() const +{ + return GetMachineName(); +} + +std::string +HttpServer::GetServiceUri(const HttpService* Service) const +{ + if (Service) + { + return fmt::format("http://{}:{}{}", m_ExternalHost, m_EffectivePort, Service->BaseUri()); + } + else + { + return fmt::format("http://{}:{}", m_ExternalHost, m_EffectivePort); + } } void @@ -1052,6 +1089,45 @@ HttpServer::EnumerateServices(std::function<void(HttpService& Service)>&& Callba } } +void +HttpServer::SetHttpRequestFilter(IHttpRequestFilter* RequestFilter) +{ + OnSetHttpRequestFilter(RequestFilter); +} + +CbObject +HttpServer::CollectStats() +{ + CbObjectWriter Cbo; + + metrics::EmitSnapshot("requests", m_RequestMeter, Cbo); + + Cbo.BeginObject("bytes"); + { + Cbo << "received" << GetTotalBytesReceived(); + Cbo << "sent" << GetTotalBytesSent(); + } + Cbo.EndObject(); + + Cbo.BeginObject("websockets"); + { + Cbo << "active_connections" << GetActiveWebSocketConnectionCount(); + Cbo << "frames_received" << m_WsFramesReceived.load(std::memory_order_relaxed); + Cbo << "frames_sent" << m_WsFramesSent.load(std::memory_order_relaxed); + Cbo << "bytes_received" << m_WsBytesReceived.load(std::memory_order_relaxed); + Cbo << "bytes_sent" << m_WsBytesSent.load(std::memory_order_relaxed); + } + Cbo.EndObject(); + + return Cbo.Save(); +} + +void +HttpServer::HandleStatsRequest(HttpServerRequest& Request) +{ + Request.WriteResponse(HttpResponseCode::OK, CollectStats()); +} + ////////////////////////////////////////////////////////////////////////// HttpRpcHandler::HttpRpcHandler() @@ -1294,6 +1370,8 @@ HandlePackageOffers(HttpService& Service, HttpServerRequest& Request, Ref<IHttpP #if ZEN_WITH_TESTS +TEST_SUITE_BEGIN("http.httpserver"); + TEST_CASE("http.common") { using namespace std::literals; @@ -1310,7 +1388,11 @@ TEST_CASE("http.common") { TestHttpServerRequest(HttpService& Service, std::string_view Uri) : HttpServerRequest(Service) { m_Uri = Uri; } virtual IoBuffer ReadPayload() override { return IoBuffer(); } - virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span<IoBuffer> Blobs) override + + virtual bool IsLocalMachineRequest() const override { return false; } + virtual std::string_view GetAuthorizationHeader() const override { return {}; } + + virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span<IoBuffer> Blobs) override { ZEN_UNUSED(ResponseCode, ContentType, Blobs); } @@ -1395,20 +1477,33 @@ TEST_CASE("http.common") SUBCASE("router-matcher") { - bool HandledA = false; - bool HandledAA = false; - bool HandledAB = false; - bool HandledAandB = false; + bool HandledA = false; + bool HandledAA = false; + bool HandledAB = false; + bool HandledAandB = false; + bool HandledAandPath = false; std::vector<std::string> Captures; auto Reset = [&] { - HandledA = HandledAA = HandledAB = HandledAandB = false; + HandledA = HandledAA = HandledAB = HandledAandB = HandledAandPath = false; Captures.clear(); }; TestHttpService Service; HttpRequestRouter r; - r.AddMatcher("a", [](std::string_view In) -> bool { return In.length() % 2 == 0; }); - r.AddMatcher("b", [](std::string_view In) -> bool { return In.length() % 3 == 0; }); + + r.AddMatcher("a", [](std::string_view In) -> bool { return In.length() % 2 == 0 && In.find('/') == std::string_view::npos; }); + r.AddMatcher("b", [](std::string_view In) -> bool { return In.length() % 3 == 0 && In.find('/') == std::string_view::npos; }); + static constexpr AsciiSet ValidPathCharactersSet{"abcdefghijklmnopqrstuvwxyz0123456789/_.,;$~{}+-[]%()]ABCDEFGHIJKLMNOPQRSTUVWXYZ"}; + r.AddMatcher("path", [](std::string_view Str) -> bool { return !Str.empty() && AsciiSet::HasOnly(Str, ValidPathCharactersSet); }); + + r.RegisterRoute( + "path/{a}/{path}", + [&](auto& Req) { + HandledAandPath = true; + Captures = {std::string(Req.GetCapture(1)), std::string(Req.GetCapture(2))}; + }, + HttpVerb::kGet); + r.RegisterRoute( "{a}", [&](auto& Req) { @@ -1437,7 +1532,6 @@ TEST_CASE("http.common") Captures = {std::string(Req.GetCapture(1)), std::string(Req.GetCapture(2))}; }, HttpVerb::kGet); - { Reset(); TestHttpServerRequest req{Service, "ab"sv}; @@ -1445,6 +1539,7 @@ TEST_CASE("http.common") CHECK(HandledA); CHECK(!HandledAA); CHECK(!HandledAB); + CHECK(!HandledAandPath); REQUIRE_EQ(Captures.size(), 1); CHECK_EQ(Captures[0], "ab"sv); @@ -1457,6 +1552,7 @@ TEST_CASE("http.common") CHECK(!HandledA); CHECK(!HandledAA); CHECK(HandledAB); + CHECK(!HandledAandPath); REQUIRE_EQ(Captures.size(), 2); CHECK_EQ(Captures[0], "ab"sv); CHECK_EQ(Captures[1], "def"sv); @@ -1470,6 +1566,7 @@ TEST_CASE("http.common") CHECK(!HandledAA); CHECK(!HandledAB); CHECK(HandledAandB); + CHECK(!HandledAandPath); REQUIRE_EQ(Captures.size(), 2); CHECK_EQ(Captures[0], "ab"sv); CHECK_EQ(Captures[1], "def"sv); @@ -1482,6 +1579,7 @@ TEST_CASE("http.common") CHECK(!HandledA); CHECK(!HandledAA); CHECK(!HandledAB); + CHECK(!HandledAandPath); } { @@ -1491,6 +1589,35 @@ TEST_CASE("http.common") CHECK(HandledA); CHECK(!HandledAA); CHECK(!HandledAB); + CHECK(!HandledAandPath); + REQUIRE_EQ(Captures.size(), 1); + CHECK_EQ(Captures[0], "a123"sv); + } + + { + Reset(); + TestHttpServerRequest req{Service, "path/ab/simple_path.txt"sv}; + r.HandleRequest(req); + CHECK(!HandledA); + CHECK(!HandledAA); + CHECK(!HandledAB); + CHECK(HandledAandPath); + REQUIRE_EQ(Captures.size(), 2); + CHECK_EQ(Captures[0], "ab"sv); + CHECK_EQ(Captures[1], "simple_path.txt"sv); + } + + { + Reset(); + TestHttpServerRequest req{Service, "path/ab/directory/and/path.txt"sv}; + r.HandleRequest(req); + CHECK(!HandledA); + CHECK(!HandledAA); + CHECK(!HandledAB); + CHECK(HandledAandPath); + REQUIRE_EQ(Captures.size(), 2); + CHECK_EQ(Captures[0], "ab"sv); + CHECK_EQ(Captures[1], "directory/and/path.txt"sv); } } @@ -1508,6 +1635,8 @@ TEST_CASE("http.common") } } +TEST_SUITE_END(); + void http_forcelink() { diff --git a/src/zenhttp/include/zenhttp/cprutils.h b/src/zenhttp/include/zenhttp/cprutils.h index a988346e0..c252a5d99 100644 --- a/src/zenhttp/include/zenhttp/cprutils.h +++ b/src/zenhttp/include/zenhttp/cprutils.h @@ -66,10 +66,10 @@ struct fmt::formatter<cpr::Response> Response.url.str(), Response.status_code, zen::ToString(zen::HttpResponseCode(Response.status_code)), + Response.reason, Response.uploaded_bytes, Response.downloaded_bytes, NiceResponseTime.c_str(), - Response.reason, Json); } else @@ -82,10 +82,10 @@ struct fmt::formatter<cpr::Response> Response.url.str(), Response.status_code, zen::ToString(zen::HttpResponseCode(Response.status_code)), + Response.reason, Response.uploaded_bytes, Response.downloaded_bytes, NiceResponseTime.c_str(), - Response.reason, Body.GetText()); } } diff --git a/src/zenhttp/include/zenhttp/formatters.h b/src/zenhttp/include/zenhttp/formatters.h index addb00cb8..57ab01158 100644 --- a/src/zenhttp/include/zenhttp/formatters.h +++ b/src/zenhttp/include/zenhttp/formatters.h @@ -73,7 +73,7 @@ struct fmt::formatter<zen::HttpClient::Response> if (Response.IsSuccess()) { return fmt::format_to(Ctx.out(), - "OK: Status: {}, Bytes: {}/{} (Up/Down), Elapsed: {}s", + "OK: Status: {}, Bytes: {}/{} (Up/Down), Elapsed: {}", ToString(Response.StatusCode), Response.UploadedBytes, Response.DownloadedBytes, diff --git a/src/zenhttp/include/zenhttp/httpapiservice.h b/src/zenhttp/include/zenhttp/httpapiservice.h index 0270973bf..2d384d1d8 100644 --- a/src/zenhttp/include/zenhttp/httpapiservice.h +++ b/src/zenhttp/include/zenhttp/httpapiservice.h @@ -1,4 +1,5 @@ // Copyright Epic Games, Inc. All Rights Reserved. +#pragma once #include <zenhttp/httpserver.h> diff --git a/src/zenhttp/include/zenhttp/httpclient.h b/src/zenhttp/include/zenhttp/httpclient.h index 9a9b74d72..1bb36a298 100644 --- a/src/zenhttp/include/zenhttp/httpclient.h +++ b/src/zenhttp/include/zenhttp/httpclient.h @@ -13,6 +13,7 @@ #include <functional> #include <optional> #include <unordered_map> +#include <vector> namespace zen { @@ -58,6 +59,10 @@ struct HttpClientSettings Oid SessionId = Oid::Zero; bool Verbose = false; uint64_t MaximumInMemoryDownloadSize = 1024u * 1024u; + + /// HTTP status codes that are expected and should not be logged as warnings. + /// 404 is always treated as expected regardless of this list. + std::vector<HttpResponseCode> ExpectedErrorCodes; }; class HttpClientError : public std::runtime_error @@ -113,6 +118,15 @@ private: class HttpClientBase; +/** HTTP Client + * + * This is safe for use on multiple threads simultaneously, as each + * instance maintains an internal connection pool and will synchronize + * access to it as needed. + * + * Uses libcurl under the hood. We currently only use HTTP 1.1 features. + * + */ class HttpClient { public: @@ -123,8 +137,11 @@ public: struct ErrorContext { - int ErrorCode; + int ErrorCode = 0; std::string ErrorMessage; + + /** True when the error is a transport-level connection failure (connect timeout, refused, DNS) */ + bool IsConnectionError() const; }; struct KeyValueMap @@ -171,13 +188,29 @@ public: KeyValueMap Header; // The number of bytes sent as part of the request - int64_t UploadedBytes; + int64_t UploadedBytes = 0; // The number of bytes received as part of the response - int64_t DownloadedBytes; + int64_t DownloadedBytes = 0; // The elapsed time in seconds for the request to execute - double ElapsedSeconds; + double ElapsedSeconds = 0.0; + + struct MultipartBoundary + { + uint64_t OffsetInPayload = 0; + uint64_t RangeOffset = 0; + uint64_t RangeLength = 0; + HttpContentType ContentType; + }; + + // Ranges will map out all received ranges, both single and multi-range responses + // If no range was requested Ranges will be empty + std::vector<MultipartBoundary> Ranges; + + // Map the absolute OffsetAndLengthPairs into ResponsePayload from the ranges received (Ranges). + // If the response was not a partial response, an empty vector will be returned + std::vector<std::pair<uint64_t, uint64_t>> GetRanges(std::span<const std::pair<uint64_t, uint64_t>> OffsetAndLengthPairs) const; // This contains any errors from the HTTP stack. It won't contain information on // why the server responded with a non-success HTTP status, that may be gleaned @@ -260,6 +293,16 @@ private: const HttpClientSettings m_ConnectionSettings; }; -void httpclient_forcelink(); // internal +struct LatencyTestResult +{ + bool Success = false; + std::string FailureReason; + double LatencySeconds = -1.0; +}; + +LatencyTestResult MeasureLatency(HttpClient& Client, std::string_view Url); + +void httpclient_forcelink(); // internal +void httpclient_test_forcelink(); // internal } // namespace zen diff --git a/src/zenhttp/include/zenhttp/httpcommon.h b/src/zenhttp/include/zenhttp/httpcommon.h index bc18549c9..8fca35ac5 100644 --- a/src/zenhttp/include/zenhttp/httpcommon.h +++ b/src/zenhttp/include/zenhttp/httpcommon.h @@ -184,6 +184,13 @@ IsHttpSuccessCode(HttpResponseCode HttpCode) noexcept return IsHttpSuccessCode(int(HttpCode)); } +[[nodiscard]] inline bool +IsHttpOk(HttpResponseCode HttpCode) noexcept +{ + return HttpCode == HttpResponseCode::OK || HttpCode == HttpResponseCode::Created || HttpCode == HttpResponseCode::Accepted || + HttpCode == HttpResponseCode::NoContent; +} + std::string_view ToString(HttpResponseCode HttpCode); } // namespace zen diff --git a/src/zenhttp/include/zenhttp/httpserver.h b/src/zenhttp/include/zenhttp/httpserver.h index 3438a1471..0e1714669 100644 --- a/src/zenhttp/include/zenhttp/httpserver.h +++ b/src/zenhttp/include/zenhttp/httpserver.h @@ -13,6 +13,8 @@ #include <zencore/uid.h> #include <zenhttp/httpcommon.h> +#include <zentelemetry/stats.h> + #include <functional> #include <gsl/gsl-lite.hpp> #include <list> @@ -30,16 +32,18 @@ class HttpService; */ class HttpServerRequest { -public: +protected: explicit HttpServerRequest(HttpService& Service); + +public: ~HttpServerRequest(); // Synchronous operations [[nodiscard]] inline std::string_view RelativeUri() const { return m_Uri; } // Returns URI without service prefix - [[nodiscard]] std::string_view RelativeUriWithExtension() const { return m_UriWithExtension; } + [[nodiscard]] inline std::string_view RelativeUriWithExtension() const { return m_UriWithExtension; } [[nodiscard]] inline std::string_view QueryString() const { return m_QueryString; } - [[nodiscard]] inline std::string_view BaseUri() const { return m_BaseUri; } // Service prefix + [[nodiscard]] inline HttpService& Service() const { return m_Service; } struct QueryParams { @@ -79,6 +83,18 @@ public: inline bool IsHandled() const { return !!(m_Flags & kIsHandled); } inline bool SuppressBody() const { return !!(m_Flags & kSuppressBody); } inline void SetSuppressResponseBody() { m_Flags |= kSuppressBody; } + inline void SetLogRequest(bool ShouldLog) + { + if (ShouldLog) + { + m_Flags |= kLogRequest; + } + else + { + m_Flags &= ~kLogRequest; + } + } + inline bool ShouldLogRequest() const { return !!(m_Flags & kLogRequest); } /** Read POST/PUT payload for request body, which is always available without delay */ @@ -87,6 +103,10 @@ public: CbObject ReadPayloadObject(); CbPackage ReadPayloadPackage(); + virtual bool IsLocalMachineRequest() const = 0; + virtual std::string_view GetAuthorizationHeader() const = 0; + virtual std::string_view GetRemoteAddress() const { return {}; } + /** Respond with payload No data will have been sent when any of these functions return. Instead, the response will be transmitted @@ -115,15 +135,17 @@ protected: kSuppressBody = 1 << 1, kHaveRequestId = 1 << 2, kHaveSessionId = 1 << 3, + kLogRequest = 1 << 4, }; - mutable uint32_t m_Flags = 0; + mutable uint32_t m_Flags = 0; + + HttpService& m_Service; // Service handling this request HttpVerb m_Verb = HttpVerb::kGet; HttpContentType m_ContentType = HttpContentType::kBinary; HttpContentType m_AcceptType = HttpContentType::kUnknownContentType; uint64_t m_ContentLength = ~0ull; - std::string_view m_BaseUri; // Base URI path of the service handling this request - std::string_view m_Uri; // URI without service prefix + std::string_view m_Uri; // URI without service prefix std::string_view m_UriWithExtension; std::string_view m_QueryString; mutable uint32_t m_RequestId = ~uint32_t(0); @@ -144,6 +166,19 @@ public: virtual void OnRequestComplete() = 0; }; +class IHttpRequestFilter +{ +public: + virtual ~IHttpRequestFilter() {} + enum class Result + { + Forbidden, + ResponseSent, + Accepted + }; + virtual Result FilterRequest(HttpServerRequest& Request) = 0; +}; + /** * Base class for implementing an HTTP "service" * @@ -170,30 +205,110 @@ private: int m_UriPrefixLength = 0; }; +struct IHttpStatsProvider +{ + /** Handle an HTTP stats request, writing the response directly. + * Implementations may inspect query parameters on the request + * to include optional detailed breakdowns. + */ + virtual void HandleStatsRequest(HttpServerRequest& Request) = 0; + + /** Return the provider's current stats as a CbObject snapshot. + * Used by the WebSocket push thread to broadcast live updates + * without requiring an HttpServerRequest. Providers that do + * not override this will be skipped in WebSocket broadcasts. + */ + virtual CbObject CollectStats() { return {}; } +}; + +struct IHttpStatsService +{ + virtual void RegisterHandler(std::string_view Id, IHttpStatsProvider& Provider) = 0; + virtual void UnregisterHandler(std::string_view Id, IHttpStatsProvider& Provider) = 0; +}; + /** HTTP server * * Implements the main event loop to service HTTP requests, and handles routing * requests to the appropriate handler as registered via RegisterService */ -class HttpServer : public RefCounted +class HttpServer : public RefCounted, public IHttpStatsProvider { public: void RegisterService(HttpService& Service); void EnumerateServices(std::function<void(HttpService&)>&& Callback); + void SetHttpRequestFilter(IHttpRequestFilter* RequestFilter); int Initialize(int BasePort, std::filesystem::path DataDir); void Run(bool IsInteractiveSession); void RequestExit(); void Close(); + /** Returns a canonical http:// URI for the given service, using the external + * IP and the port the server is actually listening on. Only valid + * after Initialize() has returned successfully. + */ + std::string GetServiceUri(const HttpService* Service) const; + + /** Returns the external host string (IP or hostname) determined during Initialize(). + * Only valid after Initialize() has returned successfully. + */ + std::string_view GetExternalHost() const { return m_ExternalHost; } + + /** Returns total bytes received and sent across all connections since server start. */ + virtual uint64_t GetTotalBytesReceived() const { return 0; } + virtual uint64_t GetTotalBytesSent() const { return 0; } + + /** Mark that a request has been handled. Called by server implementations. */ + void MarkRequest() { m_RequestMeter.Mark(); } + + /** Set a default redirect path for root requests */ + void SetDefaultRedirect(std::string_view Path) { m_DefaultRedirect = Path; } + + std::string_view GetDefaultRedirect() const { return m_DefaultRedirect; } + + /** Track active WebSocket connections — called by server implementations on upgrade/close. */ + void OnWebSocketConnectionOpened() { m_ActiveWebSocketConnections.fetch_add(1, std::memory_order_relaxed); } + void OnWebSocketConnectionClosed() { m_ActiveWebSocketConnections.fetch_sub(1, std::memory_order_relaxed); } + uint64_t GetActiveWebSocketConnectionCount() const { return m_ActiveWebSocketConnections.load(std::memory_order_relaxed); } + + /** Track WebSocket frame and byte counters — called by WS connection implementations per frame. */ + void OnWebSocketFrameReceived(uint64_t Bytes) + { + m_WsFramesReceived.fetch_add(1, std::memory_order_relaxed); + m_WsBytesReceived.fetch_add(Bytes, std::memory_order_relaxed); + } + void OnWebSocketFrameSent(uint64_t Bytes) + { + m_WsFramesSent.fetch_add(1, std::memory_order_relaxed); + m_WsBytesSent.fetch_add(Bytes, std::memory_order_relaxed); + } + + // IHttpStatsProvider + virtual CbObject CollectStats() override; + virtual void HandleStatsRequest(HttpServerRequest& Request) override; + private: std::vector<HttpService*> m_KnownServices; + int m_EffectivePort = 0; + std::string m_ExternalHost; + metrics::Meter m_RequestMeter; + std::string m_DefaultRedirect; + std::atomic<uint64_t> m_ActiveWebSocketConnections{0}; + std::atomic<uint64_t> m_WsFramesReceived{0}; + std::atomic<uint64_t> m_WsFramesSent{0}; + std::atomic<uint64_t> m_WsBytesReceived{0}; + std::atomic<uint64_t> m_WsBytesSent{0}; virtual void OnRegisterService(HttpService& Service) = 0; virtual int OnInitialize(int BasePort, std::filesystem::path DataDir) = 0; + virtual void OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) = 0; virtual void OnRun(bool IsInteractiveSession) = 0; virtual void OnRequestExit() = 0; virtual void OnClose() = 0; + +protected: + virtual std::string OnGetExternalHost() const; }; struct HttpServerPluginConfig @@ -236,7 +351,7 @@ public: inline HttpServerRequest& ServerRequest() { return m_HttpRequest; } private: - HttpRouterRequest(HttpServerRequest& Request) : m_HttpRequest(Request) {} + explicit HttpRouterRequest(HttpServerRequest& Request) : m_HttpRequest(Request) {} ~HttpRouterRequest() = default; HttpRouterRequest(const HttpRouterRequest&) = delete; @@ -385,7 +500,7 @@ public: ~HttpRpcHandler(); HttpRpcHandler(const HttpRpcHandler&) = delete; - HttpRpcHandler operator=(const HttpRpcHandler&) = delete; + HttpRpcHandler& operator=(const HttpRpcHandler&) = delete; void AddRpc(std::string_view RpcId, std::function<void(CbObject& RpcArgs)> HandlerFunction); @@ -401,17 +516,7 @@ private: bool HandlePackageOffers(HttpService& Service, HttpServerRequest& Request, Ref<IHttpPackageHandler>& PackageHandlerRef); -struct IHttpStatsProvider -{ - virtual void HandleStatsRequest(HttpServerRequest& Request) = 0; -}; - -struct IHttpStatsService -{ - virtual void RegisterHandler(std::string_view Id, IHttpStatsProvider& Provider) = 0; - virtual void UnregisterHandler(std::string_view Id, IHttpStatsProvider& Provider) = 0; -}; - -void http_forcelink(); // internal +void http_forcelink(); // internal +void websocket_forcelink(); // internal } // namespace zen diff --git a/src/zenhttp/include/zenhttp/httpstats.h b/src/zenhttp/include/zenhttp/httpstats.h index e6fea6765..460315faf 100644 --- a/src/zenhttp/include/zenhttp/httpstats.h +++ b/src/zenhttp/include/zenhttp/httpstats.h @@ -3,23 +3,50 @@ #pragma once #include <zencore/logging.h> +#include <zencore/thread.h> #include <zenhttp/httpserver.h> +#include <zenhttp/websocket.h> +#include <atomic> #include <map> +#include <memory> +#include <thread> +#include <vector> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <asio/io_context.hpp> +#include <asio/steady_timer.hpp> +ZEN_THIRD_PARTY_INCLUDES_END namespace zen { -class HttpStatsService : public HttpService, public IHttpStatsService +class HttpStatsService : public HttpService, public IHttpStatsService, public IWebSocketHandler { public: - HttpStatsService(); + /// Construct without an io_context — optionally uses a dedicated push thread + /// for WebSocket stats broadcasting. + explicit HttpStatsService(bool EnableWebSockets = false); + + /// Construct with an external io_context — uses an asio timer instead + /// of a dedicated thread for WebSocket stats broadcasting. + /// The caller must ensure the io_context outlives this service and that + /// its run loop is active. + HttpStatsService(asio::io_context& IoContext, bool EnableWebSockets = true); + ~HttpStatsService(); + void Shutdown(); + virtual const char* BaseUri() const override; virtual void HandleRequest(HttpServerRequest& Request) override; virtual void RegisterHandler(std::string_view Id, IHttpStatsProvider& Provider) override; virtual void UnregisterHandler(std::string_view Id, IHttpStatsProvider& Provider) override; + // IWebSocketHandler + void OnWebSocketOpen(Ref<WebSocketConnection> Connection) override; + void OnWebSocketMessage(WebSocketConnection& Conn, const WebSocketMessage& Msg) override; + void OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code, std::string_view Reason) override; + private: LoggerRef m_Log; HttpRequestRouter m_Router; @@ -28,6 +55,22 @@ private: RwLock m_Lock; std::map<std::string, IHttpStatsProvider*> m_Providers; + + // WebSocket push + RwLock m_WsConnectionsLock; + std::vector<Ref<WebSocketConnection>> m_WsConnections; + std::atomic<bool> m_PushEnabled{false}; + + void BroadcastStats(); + + // Thread-based push (when no io_context is provided) + std::thread m_PushThread; + Event m_PushEvent; + void PushThreadFunction(); + + // Timer-based push (when an io_context is provided) + std::unique_ptr<asio::steady_timer> m_PushTimer; + void EnqueuePushTimer(); }; } // namespace zen diff --git a/src/zenhttp/include/zenhttp/httpwsclient.h b/src/zenhttp/include/zenhttp/httpwsclient.h new file mode 100644 index 000000000..926ec1e3d --- /dev/null +++ b/src/zenhttp/include/zenhttp/httpwsclient.h @@ -0,0 +1,79 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "zenhttp.h" + +#include <zenhttp/httpclient.h> +#include <zenhttp/websocket.h> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <asio/io_context.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + +#include <chrono> +#include <cstdint> +#include <functional> +#include <memory> +#include <optional> +#include <span> +#include <string> +#include <string_view> + +namespace zen { + +/** + * Callback interface for WebSocket client events + * + * Separate from the server-side IWebSocketHandler because the caller + * already owns the HttpWsClient — no Ref<WebSocketConnection> needed. + */ +class IWsClientHandler +{ +public: + virtual ~IWsClientHandler() = default; + + virtual void OnWsOpen() = 0; + virtual void OnWsMessage(const WebSocketMessage& Msg) = 0; + virtual void OnWsClose(uint16_t Code, std::string_view Reason) = 0; +}; + +struct HttpWsClientSettings +{ + std::string LogCategory = "wsclient"; + std::chrono::milliseconds ConnectTimeout{5000}; + std::optional<std::function<HttpClientAccessToken()>> AccessTokenProvider; +}; + +/** + * WebSocket client over TCP (ws:// scheme) + * + * Uses ASIO for async I/O. Two construction modes: + * - Internal io_context + background thread (standalone use) + * - External io_context (shared event loop, no internal thread) + * + * Thread-safe for SendText/SendBinary/Close. + */ +class HttpWsClient +{ +public: + HttpWsClient(std::string_view Url, IWsClientHandler& Handler, const HttpWsClientSettings& Settings = {}); + HttpWsClient(std::string_view Url, IWsClientHandler& Handler, asio::io_context& IoContext, const HttpWsClientSettings& Settings = {}); + + ~HttpWsClient(); + + HttpWsClient(const HttpWsClient&) = delete; + HttpWsClient& operator=(const HttpWsClient&) = delete; + + void Connect(); + void SendText(std::string_view Text); + void SendBinary(std::span<const uint8_t> Data); + void Close(uint16_t Code = 1000, std::string_view Reason = {}); + bool IsOpen() const; + +private: + struct Impl; + std::unique_ptr<Impl> m_Impl; +}; + +} // namespace zen diff --git a/src/zenhttp/include/zenhttp/packageformat.h b/src/zenhttp/include/zenhttp/packageformat.h index c90b840da..1a5068580 100644 --- a/src/zenhttp/include/zenhttp/packageformat.h +++ b/src/zenhttp/include/zenhttp/packageformat.h @@ -68,7 +68,7 @@ struct CbAttachmentEntry struct CbAttachmentReferenceHeader { uint64_t PayloadByteOffset = 0; - uint64_t PayloadByteSize = ~0u; + uint64_t PayloadByteSize = ~uint64_t(0); uint16_t AbsolutePathLength = 0; // This header will be followed by UTF8 encoded absolute path to backing file diff --git a/src/zenhttp/include/zenhttp/security/passwordsecurity.h b/src/zenhttp/include/zenhttp/security/passwordsecurity.h new file mode 100644 index 000000000..6b2b548a6 --- /dev/null +++ b/src/zenhttp/include/zenhttp/security/passwordsecurity.h @@ -0,0 +1,38 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/compactbinary.h> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <tsl/robin_map.h> +ZEN_THIRD_PARTY_INCLUDES_END + +namespace zen { + +class PasswordSecurity +{ +public: + struct Configuration + { + std::string Password; + bool ProtectMachineLocalRequests = false; + std::vector<std::string> UnprotectedUris; + }; + + explicit PasswordSecurity(const Configuration& Config); + + [[nodiscard]] inline std::string_view Password() const { return m_Config.Password; } + [[nodiscard]] inline bool ProtectMachineLocalRequests() const { return m_Config.ProtectMachineLocalRequests; } + [[nodiscard]] bool IsUnprotectedUri(std::string_view BaseUri, std::string_view RelativeUri) const; + + bool IsAllowed(std::string_view Password, std::string_view BaseUri, std::string_view RelativeUri, bool IsMachineLocalRequest); + +private: + const Configuration m_Config; + tsl::robin_map<uint32_t, uint32_t> m_UnprotectedUriHashes; +}; + +void passwordsecurity_forcelink(); // internal + +} // namespace zen diff --git a/src/zenhttp/include/zenhttp/security/passwordsecurityfilter.h b/src/zenhttp/include/zenhttp/security/passwordsecurityfilter.h new file mode 100644 index 000000000..c098f05ad --- /dev/null +++ b/src/zenhttp/include/zenhttp/security/passwordsecurityfilter.h @@ -0,0 +1,51 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zenhttp/httpserver.h> +#include <zenhttp/security/passwordsecurity.h> + +namespace zen { + +class PasswordHttpFilter : public IHttpRequestFilter +{ +public: + static constexpr std::string_view TypeName = "password"; + + struct Configuration + { + PasswordSecurity::Configuration PasswordConfig; + std::string AuthenticationTypeString; + }; + + /** + * Expected format (Json) + * { + * "password": { # "Authorization: Basic <username:password base64 encoded>" style + * "username": "<username>", + * "password": "<password>" + * }, + * "protect-machine-local-requests": false, + * "unprotected-uris": [ + * "/health/", + * "/health/info", + * "/health/version" + * ] + * } + */ + static Configuration ReadConfiguration(CbObjectView Config); + + explicit PasswordHttpFilter(const PasswordHttpFilter::Configuration& Config) + : m_PasswordSecurity(Config.PasswordConfig) + , m_AuthenticationTypeString(Config.AuthenticationTypeString) + { + } + + virtual Result FilterRequest(HttpServerRequest& Request) override; + +private: + PasswordSecurity m_PasswordSecurity; + const std::string m_AuthenticationTypeString; +}; + +} // namespace zen diff --git a/src/zenhttp/include/zenhttp/websocket.h b/src/zenhttp/include/zenhttp/websocket.h new file mode 100644 index 000000000..bc3293282 --- /dev/null +++ b/src/zenhttp/include/zenhttp/websocket.h @@ -0,0 +1,65 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zenbase/refcount.h> +#include <zencore/iobuffer.h> + +#include <cstdint> +#include <span> +#include <string_view> + +namespace zen { + +enum class WebSocketOpcode : uint8_t +{ + kText = 0x1, + kBinary = 0x2, + kClose = 0x8, + kPing = 0x9, + kPong = 0xA +}; + +struct WebSocketMessage +{ + WebSocketOpcode Opcode = WebSocketOpcode::kText; + IoBuffer Payload; + uint16_t CloseCode = 0; +}; + +/** + * Represents an active WebSocket connection + * + * Derived classes implement the actual transport (e.g. ASIO sockets). + * Instances are reference-counted so that both the service layer and + * the async read/write loop can share ownership. + */ +class WebSocketConnection : public RefCounted +{ +public: + virtual ~WebSocketConnection() = default; + + virtual void SendText(std::string_view Text) = 0; + virtual void SendBinary(std::span<const uint8_t> Data) = 0; + virtual void Close(uint16_t Code = 1000, std::string_view Reason = {}) = 0; + virtual bool IsOpen() const = 0; +}; + +/** + * Interface for services that accept WebSocket upgrades + * + * An HttpService may additionally implement this interface to indicate + * it supports WebSocket connections. The HTTP server checks for this + * via dynamic_cast when it sees an Upgrade: websocket request. + */ +class IWebSocketHandler +{ +public: + virtual ~IWebSocketHandler() = default; + + virtual void OnWebSocketOpen(Ref<WebSocketConnection> Connection) = 0; + virtual void OnWebSocketMessage(WebSocketConnection& Conn, const WebSocketMessage& Msg) = 0; + virtual void OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code, std::string_view Reason) = 0; +}; + +} // namespace zen diff --git a/src/zenhttp/monitoring/httpstats.cpp b/src/zenhttp/monitoring/httpstats.cpp index b097a0d3f..2370def0c 100644 --- a/src/zenhttp/monitoring/httpstats.cpp +++ b/src/zenhttp/monitoring/httpstats.cpp @@ -3,15 +3,57 @@ #include "zenhttp/httpstats.h" #include <zencore/compactbinarybuilder.h> +#include <zencore/string.h> +#include <zencore/thread.h> +#include <zencore/trace.h> namespace zen { -HttpStatsService::HttpStatsService() : m_Log(logging::Get("stats")) +HttpStatsService::HttpStatsService(bool EnableWebSockets) : m_Log(logging::Get("stats")) { + if (EnableWebSockets) + { + m_PushEnabled.store(true); + m_PushThread = std::thread([this] { PushThreadFunction(); }); + } +} + +HttpStatsService::HttpStatsService(asio::io_context& IoContext, bool EnableWebSockets) : m_Log(logging::Get("stats")) +{ + if (EnableWebSockets) + { + m_PushEnabled.store(true); + m_PushTimer = std::make_unique<asio::steady_timer>(IoContext); + EnqueuePushTimer(); + } } HttpStatsService::~HttpStatsService() { + Shutdown(); +} + +void +HttpStatsService::Shutdown() +{ + if (!m_PushEnabled.exchange(false)) + { + return; + } + + if (m_PushTimer) + { + m_PushTimer->cancel(); + m_PushTimer.reset(); + } + + if (m_PushThread.joinable()) + { + m_PushEvent.Set(); + m_PushThread.join(); + } + + m_WsConnectionsLock.WithExclusiveLock([&] { m_WsConnections.clear(); }); } const char* @@ -39,6 +81,7 @@ HttpStatsService::UnregisterHandler(std::string_view Id, IHttpStatsProvider& Pro void HttpStatsService::HandleRequest(HttpServerRequest& Request) { + ZEN_TRACE_CPU("HttpStatsService::HandleRequest"); using namespace std::literals; std::string_view Key = Request.RelativeUri(); @@ -89,4 +132,154 @@ HttpStatsService::HandleRequest(HttpServerRequest& Request) } } +////////////////////////////////////////////////////////////////////////// +// +// IWebSocketHandler +// + +void +HttpStatsService::OnWebSocketOpen(Ref<WebSocketConnection> Connection) +{ + ZEN_TRACE_CPU("HttpStatsService::OnWebSocketOpen"); + ZEN_INFO("Stats WebSocket client connected"); + + m_WsConnectionsLock.WithExclusiveLock([&] { m_WsConnections.push_back(std::move(Connection)); }); + + // Send initial state immediately + if (m_PushThread.joinable()) + { + m_PushEvent.Set(); + } +} + +void +HttpStatsService::OnWebSocketMessage(WebSocketConnection& /*Conn*/, const WebSocketMessage& /*Msg*/) +{ + // No client-to-server messages expected +} + +void +HttpStatsService::OnWebSocketClose(WebSocketConnection& Conn, [[maybe_unused]] uint16_t Code, [[maybe_unused]] std::string_view Reason) +{ + ZEN_TRACE_CPU("HttpStatsService::OnWebSocketClose"); + ZEN_INFO("Stats WebSocket client disconnected (code {})", Code); + + m_WsConnectionsLock.WithExclusiveLock([&] { + auto It = std::remove_if(m_WsConnections.begin(), m_WsConnections.end(), [&Conn](const Ref<WebSocketConnection>& C) { + return C.Get() == &Conn; + }); + m_WsConnections.erase(It, m_WsConnections.end()); + }); +} + +////////////////////////////////////////////////////////////////////////// +// +// Stats broadcast +// + +void +HttpStatsService::BroadcastStats() +{ + ZEN_TRACE_CPU("HttpStatsService::BroadcastStats"); + std::vector<Ref<WebSocketConnection>> Connections; + m_WsConnectionsLock.WithSharedLock([&] { Connections = m_WsConnections; }); + + if (Connections.empty()) + { + return; + } + + // Collect stats from all providers + ExtendableStringBuilder<4096> JsonBuilder; + JsonBuilder.Append("{"); + + bool First = true; + { + RwLock::SharedLockScope _(m_Lock); + for (auto& [Id, Provider] : m_Providers) + { + CbObject Stats = Provider->CollectStats(); + if (!Stats) + { + continue; + } + + if (!First) + { + JsonBuilder.Append(","); + } + First = false; + + // Emit as "provider_id": { ... } + JsonBuilder.Append("\""); + JsonBuilder.Append(Id); + JsonBuilder.Append("\":"); + + ExtendableStringBuilder<2048> StatsJson; + Stats.ToJson(StatsJson); + JsonBuilder.Append(StatsJson.ToView()); + } + } + + JsonBuilder.Append("}"); + + std::string_view Json = JsonBuilder.ToView(); + for (auto& Conn : Connections) + { + if (Conn->IsOpen()) + { + Conn->SendText(Json); + } + } +} + +////////////////////////////////////////////////////////////////////////// +// +// Thread-based push (fallback when no io_context) +// + +void +HttpStatsService::PushThreadFunction() +{ + SetCurrentThreadName("stats_ws_push"); + + while (m_PushEnabled.load()) + { + m_PushEvent.Wait(5000); + m_PushEvent.Reset(); + + if (!m_PushEnabled.load()) + { + break; + } + + BroadcastStats(); + } +} + +////////////////////////////////////////////////////////////////////////// +// +// Timer-based push (when io_context is provided) +// + +void +HttpStatsService::EnqueuePushTimer() +{ + if (!m_PushTimer) + { + return; + } + + m_PushTimer->expires_after(std::chrono::seconds(5)); + m_PushTimer->async_wait([this](const asio::error_code& Ec) { + if (Ec) + { + return; + } + + BroadcastStats(); + EnqueuePushTimer(); + }); +} + } // namespace zen diff --git a/src/zenhttp/packageformat.cpp b/src/zenhttp/packageformat.cpp index 708238224..cbfe4d889 100644 --- a/src/zenhttp/packageformat.cpp +++ b/src/zenhttp/packageformat.cpp @@ -581,7 +581,7 @@ ParsePackageMessage(IoBuffer Payload, std::function<IoBuffer(const IoHash&, uint ZEN_ASSERT(AttachmentBufferCopy.Size() == AttachmentSize); AttachmentBufferCopy.GetMutableView().CopyFrom(AttachmentBuffer.GetView()); - Attachments.emplace_back(SharedBuffer{AttachmentBufferCopy}); + Attachments.emplace_back(CbAttachment(SharedBuffer{AttachmentBufferCopy}, Entry.AttachmentHash)); } else { @@ -805,6 +805,8 @@ CbPackageReader::Finalize() #if ZEN_WITH_TESTS +TEST_SUITE_BEGIN("http.packageformat"); + TEST_CASE("CbPackage.Serialization") { // Make a test package @@ -926,6 +928,8 @@ TEST_CASE("CbPackage.LocalRef") Reader.Finalize(); } +TEST_SUITE_END(); + void forcelink_packageformat() { diff --git a/src/zenhttp/security/passwordsecurity.cpp b/src/zenhttp/security/passwordsecurity.cpp new file mode 100644 index 000000000..0e3a743c3 --- /dev/null +++ b/src/zenhttp/security/passwordsecurity.cpp @@ -0,0 +1,176 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "zenhttp/security/passwordsecurity.h" +#include <zencore/compactbinaryutil.h> +#include <zencore/fmtutils.h> +#include <zencore/string.h> + +#if ZEN_WITH_TESTS +# include <zencore/compactbinarybuilder.h> +# include <zencore/testing.h> +#endif // ZEN_WITH_TESTS + +namespace zen { +using namespace std::literals; + +PasswordSecurity::PasswordSecurity(const Configuration& Config) : m_Config(Config) +{ + m_UnprotectedUriHashes.reserve(m_Config.UnprotectedUris.size()); + for (uint32_t Index = 0; Index < m_Config.UnprotectedUris.size(); Index++) + { + const std::string& UnprotectedUri = m_Config.UnprotectedUris[Index]; + if (auto Result = m_UnprotectedUriHashes.insert({HashStringDjb2(UnprotectedUri), Index}); !Result.second) + { + throw std::runtime_error(fmt::format( + "password security unprotected uris does not generate unique hashes. Uri #{} ('{}') collides with uri #{} ('{}')", + Index + 1, + UnprotectedUri, + Result.first->second + 1, + m_Config.UnprotectedUris[Result.first->second])); + } + } +} + +bool +PasswordSecurity::IsUnprotectedUri(std::string_view BaseUri, std::string_view RelativeUri) const +{ + if (!m_Config.UnprotectedUris.empty()) + { + uint32_t UriHash = HashStringDjb2(std::array<const std::string_view, 2>{BaseUri, RelativeUri}); + if (auto It = m_UnprotectedUriHashes.find(UriHash); It != m_UnprotectedUriHashes.end()) + { + const std::string_view& UnprotectedUri = m_Config.UnprotectedUris[It->second]; + if (UnprotectedUri.length() == BaseUri.length() + RelativeUri.length()) + { + if (UnprotectedUri.substr(0, BaseUri.length()) == BaseUri && UnprotectedUri.substr(BaseUri.length()) == RelativeUri) + { + return true; + } + } + } + } + return false; +} + +bool +PasswordSecurity::IsAllowed(std::string_view InPassword, std::string_view BaseUri, std::string_view RelativeUri, bool IsMachineLocalRequest) +{ + if (IsUnprotectedUri(BaseUri, RelativeUri)) + { + return true; + } + if (!ProtectMachineLocalRequests() && IsMachineLocalRequest) + { + return true; + } + if (Password().empty()) + { + return true; + } + if (Password() == InPassword) + { + return true; + } + return false; +} + +#if ZEN_WITH_TESTS + +TEST_SUITE_BEGIN("http.passwordsecurity"); + +TEST_CASE("passwordsecurity.allowanything") +{ + PasswordSecurity Anything({}); + CHECK(Anything.IsAllowed(""sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false)); + CHECK(Anything.IsAllowed(""sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true)); + CHECK(Anything.IsAllowed("thewrongpassword"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false)); + CHECK(Anything.IsAllowed("thewrongpassword"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true)); +} + +TEST_CASE("passwordsecurity.allowalllocal") +{ + PasswordSecurity AllLocal({.Password = "123456"}); + CHECK(AllLocal.IsAllowed(""sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true)); + CHECK(!AllLocal.IsAllowed(""sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false)); + CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true)); + CHECK(!AllLocal.IsAllowed("thewrongpassword"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false)); + CHECK(AllLocal.IsAllowed("123456"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true)); +} + +TEST_CASE("passwordsecurity.allowonlypassword") +{ + PasswordSecurity AllLocal({.Password = "123456", .ProtectMachineLocalRequests = true}); + CHECK(!AllLocal.IsAllowed(""sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true)); + CHECK(AllLocal.IsAllowed("123456"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true)); + CHECK(!AllLocal.IsAllowed(""sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false)); + CHECK(!AllLocal.IsAllowed("thewrongpassword"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true)); + CHECK(!AllLocal.IsAllowed("thewrongpassword"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false)); + CHECK(AllLocal.IsAllowed("123456"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false)); +} + +TEST_CASE("passwordsecurity.allowsomeexternaluris") +{ + PasswordSecurity AllLocal( + {.Password = "123456", .ProtectMachineLocalRequests = false, .UnprotectedUris = std::vector<std::string>({"/free/access", "/ok"})}); + CHECK(AllLocal.IsAllowed(""sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true)); + CHECK(AllLocal.IsAllowed("123456"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true)); + CHECK(!AllLocal.IsAllowed(""sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false)); + CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true)); + CHECK(!AllLocal.IsAllowed("thewrongpassword"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false)); + CHECK(AllLocal.IsAllowed(""sv, "/free", "/access", /*IsMachineLocalRequest*/ true)); + CHECK(AllLocal.IsAllowed(""sv, "/ok", "", /*IsMachineLocalRequest*/ true)); + CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/free", "/access", /*IsMachineLocalRequest*/ true)); + CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/ok", "", /*IsMachineLocalRequest*/ true)); + CHECK(AllLocal.IsAllowed(""sv, "/free", "/access", /*IsMachineLocalRequest*/ false)); + CHECK(AllLocal.IsAllowed(""sv, "/ok", "", /*IsMachineLocalRequest*/ false)); + CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/free", "/access", /*IsMachineLocalRequest*/ false)); + CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/ok", "", /*IsMachineLocalRequest*/ false)); + CHECK(AllLocal.IsAllowed("123456"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false)); +} + +TEST_CASE("passwordsecurity.allowsomelocaluris") +{ + PasswordSecurity AllLocal( + {.Password = "123456", .ProtectMachineLocalRequests = true, .UnprotectedUris = std::vector<std::string>({"/free/access", "/ok"})}); + CHECK(!AllLocal.IsAllowed(""sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true)); + CHECK(AllLocal.IsAllowed("123456"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true)); + CHECK(!AllLocal.IsAllowed(""sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false)); + CHECK(!AllLocal.IsAllowed("thewrongpassword"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ true)); + CHECK(!AllLocal.IsAllowed("thewrongpassword"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false)); + CHECK(AllLocal.IsAllowed(""sv, "/free", "/access", /*IsMachineLocalRequest*/ true)); + CHECK(AllLocal.IsAllowed(""sv, "/ok", "", /*IsMachineLocalRequest*/ true)); + CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/free", "/access", /*IsMachineLocalRequest*/ true)); + CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/ok", "", /*IsMachineLocalRequest*/ true)); + CHECK(AllLocal.IsAllowed(""sv, "/free", "/access", /*IsMachineLocalRequest*/ false)); + CHECK(AllLocal.IsAllowed(""sv, "/ok", "", /*IsMachineLocalRequest*/ false)); + CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/free", "/access", /*IsMachineLocalRequest*/ false)); + CHECK(AllLocal.IsAllowed("thewrongpassword"sv, "/ok", "", /*IsMachineLocalRequest*/ false)); + CHECK(AllLocal.IsAllowed("123456"sv, "/supersecret", "/uri", /*IsMachineLocalRequest*/ false)); +} + +TEST_CASE("passwordsecurity.conflictingunprotecteduris") +{ + try + { + PasswordSecurity AllLocal({.Password = "123456", + .ProtectMachineLocalRequests = true, + .UnprotectedUris = std::vector<std::string>({"/free/access", "/free/access"})}); + CHECK(false); + } + catch (const std::runtime_error& Ex) + { + CHECK_EQ(Ex.what(), + std::string("password security unprotected uris does not generate unique hashes. Uri #2 ('/free/access') collides with " + "uri #1 ('/free/access')")); + } +} + +TEST_SUITE_END(); + +void +passwordsecurity_forcelink() +{ +} +#endif // ZEN_WITH_TESTS + +} // namespace zen diff --git a/src/zenhttp/security/passwordsecurityfilter.cpp b/src/zenhttp/security/passwordsecurityfilter.cpp new file mode 100644 index 000000000..87d8cc275 --- /dev/null +++ b/src/zenhttp/security/passwordsecurityfilter.cpp @@ -0,0 +1,56 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "zenhttp/security/passwordsecurityfilter.h" + +#include <zencore/base64.h> +#include <zencore/compactbinaryutil.h> +#include <zencore/fmtutils.h> + +namespace zen { + +using namespace std::literals; + +PasswordHttpFilter::Configuration +PasswordHttpFilter::ReadConfiguration(CbObjectView Config) +{ + Configuration Result; + if (CbObjectView PasswordType = Config["basic"sv].AsObjectView(); PasswordType) + { + Result.AuthenticationTypeString = "Basic "; + std::string_view Username = PasswordType["username"sv].AsString(); + std::string_view Password = PasswordType["password"sv].AsString(); + std::string UsernamePassword = fmt::format("{}:{}", Username, Password); + Result.PasswordConfig.Password.resize(Base64::GetEncodedDataSize(uint32_t(UsernamePassword.length()))); + Base64::Encode(reinterpret_cast<const uint8_t*>(UsernamePassword.data()), + uint32_t(UsernamePassword.size()), + const_cast<char*>(Result.PasswordConfig.Password.data())); + } + Result.PasswordConfig.ProtectMachineLocalRequests = Config["protect-machine-local-requests"sv].AsBool(); + Result.PasswordConfig.UnprotectedUris = compactbinary_helpers::ReadArray<std::string>("unprotected-uris"sv, Config); + return Result; +} + +IHttpRequestFilter::Result +PasswordHttpFilter::FilterRequest(HttpServerRequest& Request) +{ + std::string_view Password; + std::string_view AuthorizationHeader = Request.GetAuthorizationHeader(); + size_t AuthorizationHeaderLength = AuthorizationHeader.length(); + if (AuthorizationHeaderLength > m_AuthenticationTypeString.length()) + { + if (StrCaseCompare(AuthorizationHeader.data(), m_AuthenticationTypeString.c_str(), m_AuthenticationTypeString.length()) == 0) + { + Password = AuthorizationHeader.substr(m_AuthenticationTypeString.length()); + } + } + + bool IsAllowed = + m_PasswordSecurity.IsAllowed(Password, Request.Service().BaseUri(), Request.RelativeUri(), Request.IsLocalMachineRequest()); + if (IsAllowed) + { + return Result::Accepted; + } + return Result::Forbidden; +} + +} // namespace zen diff --git a/src/zenhttp/servers/httpasio.cpp b/src/zenhttp/servers/httpasio.cpp index 18a0f6a40..f5178ebe8 100644 --- a/src/zenhttp/servers/httpasio.cpp +++ b/src/zenhttp/servers/httpasio.cpp @@ -7,12 +7,15 @@ #include <zencore/fmtutils.h> #include <zencore/logging.h> #include <zencore/memory/llm.h> +#include <zencore/system.h> #include <zencore/thread.h> #include <zencore/trace.h> #include <zencore/windows.h> #include <zenhttp/httpserver.h> #include "httpparser.h" +#include "wsasio.h" +#include "wsframecodec.h" #include <EASTL/fixed_vector.h> @@ -89,15 +92,19 @@ IsIPv6AvailableSysctl(void) char buf[16]; if (fgets(buf, sizeof(buf), f)) { - fclose(f); // 0 means IPv6 enabled, 1 means disabled val = atoi(buf); } + fclose(f); } return val == 0; } +#endif // ZEN_PLATFORM_LINUX +namespace zen { + +#if ZEN_PLATFORM_LINUX bool IsIPv6Capable() { @@ -121,8 +128,6 @@ IsIPv6Capable() } #endif -namespace zen { - const FLLMTag& GetHttpasioTag() { @@ -145,7 +150,7 @@ inline LoggerRef InitLogger() { LoggerRef Logger = logging::Get("asio"); - // Logger.set_level(spdlog::level::trace); + // Logger.SetLogLevel(logging::Trace); return Logger; } @@ -496,16 +501,21 @@ public: HttpAsioServerImpl(); ~HttpAsioServerImpl(); - void Initialize(std::filesystem::path DataDir); - int Start(uint16_t Port, const AsioConfig& Config); - void Stop(); - void RegisterService(const char* UrlPath, HttpService& Service); - HttpService* RouteRequest(std::string_view Url); + void Initialize(std::filesystem::path DataDir); + int Start(uint16_t Port, const AsioConfig& Config); + void Stop(); + void RegisterService(const char* UrlPath, HttpService& Service); + void SetHttpRequestFilter(IHttpRequestFilter* RequestFilter); + HttpService* RouteRequest(std::string_view Url); + IHttpRequestFilter::Result FilterRequest(HttpServerRequest& Request); + + bool IsLoopbackOnly() const; asio::io_service m_IoService; asio::io_service::work m_Work{m_IoService}; std::unique_ptr<asio_http::HttpAcceptor> m_Acceptor; std::vector<std::thread> m_ThreadPool; + std::atomic<IHttpRequestFilter*> m_HttpRequestFilter = nullptr; LoggerRef m_RequestLog; HttpServerTracer m_RequestTracer; @@ -518,6 +528,11 @@ public: RwLock m_Lock; std::vector<ServiceEntry> m_UriHandlers; + + std::atomic<uint64_t> m_TotalBytesReceived{0}; + std::atomic<uint64_t> m_TotalBytesSent{0}; + + HttpServer* m_HttpServer = nullptr; }; /** @@ -527,12 +542,21 @@ public: class HttpAsioServerRequest : public HttpServerRequest { public: - HttpAsioServerRequest(HttpRequestParser& Request, HttpService& Service, IoBuffer PayloadBuffer, uint32_t RequestNumber); + HttpAsioServerRequest(HttpRequestParser& Request, + HttpService& Service, + IoBuffer PayloadBuffer, + uint32_t RequestNumber, + bool IsLocalMachineRequest, + std::string RemoteAddress); ~HttpAsioServerRequest(); virtual Oid ParseSessionId() const override; virtual uint32_t ParseRequestId() const override; + virtual bool IsLocalMachineRequest() const override; + virtual std::string_view GetAuthorizationHeader() const override; + virtual std::string_view GetRemoteAddress() const override; + virtual IoBuffer ReadPayload() override; virtual void WriteResponse(HttpResponseCode ResponseCode) override; virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span<IoBuffer> Blobs) override; @@ -548,6 +572,8 @@ public: HttpRequestParser& m_Request; uint32_t m_RequestNumber = 0; // Note: different to request ID which is derived from headers IoBuffer m_PayloadBuffer; + bool m_IsLocalMachineRequest; + std::string m_RemoteAddress; std::unique_ptr<HttpResponse> m_Response; }; @@ -925,6 +951,7 @@ private: void OnDataReceived(const asio::error_code& Ec, std::size_t ByteCount); void OnResponseDataSent(const asio::error_code& Ec, std::size_t ByteCount, uint32_t RequestNumber, HttpResponse* ResponseToPop); void CloseConnection(); + void SendInlineResponse(uint32_t RequestNumber, std::string_view StatusLine, std::string_view Headers = {}, std::string_view Body = {}); HttpAsioServerImpl& m_Server; asio::streambuf m_RequestBuffer; @@ -1025,6 +1052,8 @@ HttpServerConnection::OnDataReceived(const asio::error_code& Ec, [[maybe_unused] } } + m_Server.m_TotalBytesReceived.fetch_add(ByteCount, std::memory_order_relaxed); + ZEN_TRACE_VERBOSE("on data received, connection: {}, request: {}, thread: {}, bytes: {}", m_ConnectionId, m_RequestCounter.load(std::memory_order_relaxed), @@ -1078,6 +1107,8 @@ HttpServerConnection::OnResponseDataSent(const asio::error_code& Ec, return; } + m_Server.m_TotalBytesSent.fetch_add(ByteCount, std::memory_order_relaxed); + ZEN_TRACE_VERBOSE("on data sent, connection: {}, request: {}, thread: {}, bytes: {}", m_ConnectionId, RequestNumber, @@ -1139,10 +1170,91 @@ HttpServerConnection::CloseConnection() } void +HttpServerConnection::SendInlineResponse(uint32_t RequestNumber, + std::string_view StatusLine, + std::string_view Headers, + std::string_view Body) +{ + ExtendableStringBuilder<256> ResponseBuilder; + ResponseBuilder << "HTTP/1.1 " << StatusLine << "\r\n"; + if (!Headers.empty()) + { + ResponseBuilder << Headers; + } + if (!m_RequestData.IsKeepAlive()) + { + ResponseBuilder << "Connection: close\r\n"; + } + ResponseBuilder << "\r\n"; + if (!Body.empty()) + { + ResponseBuilder << Body; + } + auto ResponseView = ResponseBuilder.ToView(); + IoBuffer ResponseData(IoBuffer::Clone, ResponseView.data(), ResponseView.size()); + auto Buffer = asio::buffer(ResponseData.GetData(), ResponseData.GetSize()); + asio::async_write( + *m_Socket.get(), + Buffer, + [Conn = AsSharedPtr(), RequestNumber, Response = std::move(ResponseData)](const asio::error_code& Ec, std::size_t ByteCount) { + Conn->OnResponseDataSent(Ec, ByteCount, RequestNumber, /* ResponseToPop */ nullptr); + }); +} + +void HttpServerConnection::HandleRequest() { ZEN_MEMSCOPE(GetHttpasioTag()); + // WebSocket upgrade detection must happen before the keep-alive check below, + // because Upgrade requests have "Connection: Upgrade" which the HTTP parser + // treats as non-keep-alive, causing a premature shutdown of the receive side. + if (m_RequestData.IsWebSocketUpgrade()) + { + if (HttpService* Service = m_Server.RouteRequest(m_RequestData.Url())) + { + IWebSocketHandler* WsHandler = dynamic_cast<IWebSocketHandler*>(Service); + if (WsHandler && !m_RequestData.SecWebSocketKey().empty()) + { + std::string AcceptKey = WsFrameCodec::ComputeAcceptKey(m_RequestData.SecWebSocketKey()); + + auto ResponseStr = std::make_shared<std::string>(); + ResponseStr->reserve(256); + ResponseStr->append( + "HTTP/1.1 101 Switching Protocols\r\n" + "Upgrade: websocket\r\n" + "Connection: Upgrade\r\n" + "Sec-WebSocket-Accept: "); + ResponseStr->append(AcceptKey); + ResponseStr->append("\r\n\r\n"); + + // Send the 101 response on the current socket, then hand the socket off + // to a WsAsioConnection for the WebSocket protocol. + asio::async_write(*m_Socket, + asio::buffer(ResponseStr->data(), ResponseStr->size()), + [Conn = AsSharedPtr(), WsHandler, OwnedResponse = ResponseStr](const asio::error_code& Ec, std::size_t) { + if (Ec) + { + ZEN_WARN("WebSocket 101 send failed: {}", Ec.message()); + return; + } + + Conn->m_Server.m_HttpServer->OnWebSocketConnectionOpened(); + Ref<WsAsioConnection> WsConn( + new WsAsioConnection(std::move(Conn->m_Socket), *WsHandler, Conn->m_Server.m_HttpServer)); + Ref<WebSocketConnection> WsConnRef(WsConn.Get()); + + WsHandler->OnWebSocketOpen(std::move(WsConnRef)); + WsConn->Start(); + }); + + m_RequestState = RequestState::kDone; + return; + } + } + // Service doesn't support WebSocket or missing key — fall through to normal handling + } + if (!m_RequestData.IsKeepAlive()) { m_RequestState = RequestState::kWritingFinal; @@ -1166,14 +1278,24 @@ HttpServerConnection::HandleRequest() { ZEN_TRACE_CPU("asio::HandleRequest"); - HttpAsioServerRequest Request(m_RequestData, *Service, m_RequestData.Body(), RequestNumber); + m_Server.m_HttpServer->MarkRequest(); + + auto RemoteEndpoint = m_Socket->remote_endpoint(); + bool IsLocalConnection = m_Socket->local_endpoint().address() == RemoteEndpoint.address(); + + HttpAsioServerRequest Request(m_RequestData, + *Service, + m_RequestData.Body(), + RequestNumber, + IsLocalConnection, + RemoteEndpoint.address().to_string()); ZEN_TRACE_VERBOSE("handle request, connection: {}, request: {}'", m_ConnectionId, RequestNumber); const HttpVerb RequestVerb = Request.RequestVerb(); const std::string_view Uri = Request.RelativeUri(); - if (m_Server.m_RequestLog.ShouldLog(logging::level::Trace)) + if (m_Server.m_RequestLog.ShouldLog(logging::Trace)) { ZEN_LOG_TRACE(m_Server.m_RequestLog, "connection #{} Handling Request: {} {} ({} bytes ({}), accept: {})", @@ -1188,56 +1310,73 @@ HttpServerConnection::HandleRequest() std::vector<IoBuffer>{Request.ReadPayload()}); } - if (!HandlePackageOffers(*Service, Request, m_PackageHandler)) + IHttpRequestFilter::Result FilterResult = m_Server.FilterRequest(Request); + if (FilterResult == IHttpRequestFilter::Result::Accepted) { - try - { - Service->HandleRequest(Request); - } - catch (const AssertException& AssertEx) - { - // Drop any partially formatted response - Request.m_Response.reset(); - - ZEN_ERROR("Caught assert exception while handling request: {}", AssertEx.FullDescription()); - Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, AssertEx.FullDescription()); - } - catch (const std::system_error& SystemError) + if (!HandlePackageOffers(*Service, Request, m_PackageHandler)) { - // Drop any partially formatted response - Request.m_Response.reset(); - - if (IsOOM(SystemError.code()) || IsOOD(SystemError.code())) + try { - Request.WriteResponse(HttpResponseCode::InsufficientStorage, HttpContentType::kText, SystemError.what()); + Service->HandleRequest(Request); } - else + catch (const AssertException& AssertEx) { - ZEN_WARN("Caught system error exception while handling request: {}. ({})", - SystemError.what(), - SystemError.code().value()); - Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, SystemError.what()); + // Drop any partially formatted response + Request.m_Response.reset(); + + ZEN_ERROR("Caught assert exception while handling request: {}", AssertEx.FullDescription()); + Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, AssertEx.FullDescription()); } - } - catch (const std::bad_alloc& BadAlloc) - { - // Drop any partially formatted response - Request.m_Response.reset(); + catch (const std::system_error& SystemError) + { + // Drop any partially formatted response + Request.m_Response.reset(); - Request.WriteResponse(HttpResponseCode::InsufficientStorage, HttpContentType::kText, BadAlloc.what()); - } - catch (const std::exception& ex) - { - // Drop any partially formatted response - Request.m_Response.reset(); + if (IsOOM(SystemError.code()) || IsOOD(SystemError.code())) + { + Request.WriteResponse(HttpResponseCode::InsufficientStorage, HttpContentType::kText, SystemError.what()); + } + else + { + ZEN_WARN("Caught system error exception while handling request: {}. ({})", + SystemError.what(), + SystemError.code().value()); + Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, SystemError.what()); + } + } + catch (const std::bad_alloc& BadAlloc) + { + // Drop any partially formatted response + Request.m_Response.reset(); - ZEN_WARN("Caught exception while handling request: {}", ex.what()); - Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, ex.what()); + Request.WriteResponse(HttpResponseCode::InsufficientStorage, HttpContentType::kText, BadAlloc.what()); + } + catch (const std::exception& ex) + { + // Drop any partially formatted response + Request.m_Response.reset(); + + ZEN_WARN("Caught exception while handling request: {}", ex.what()); + Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, ex.what()); + } } } + else if (FilterResult == IHttpRequestFilter::Result::Forbidden) + { + Request.WriteResponse(HttpResponseCode::Forbidden); + } + else + { + ZEN_ASSERT(FilterResult == IHttpRequestFilter::Result::ResponseSent); + } if (std::unique_ptr<HttpResponse> Response = std::move(Request.m_Response)) { + if (Request.ShouldLogRequest()) + { + ZEN_INFO("{} {} {} -> {}", ToString(RequestVerb), Uri, Response->ResponseCode(), NiceBytes(Response->ContentLength())); + } + // Transmit the response if (m_RequestData.RequestVerb() == HttpVerb::kHead) @@ -1278,51 +1417,24 @@ HttpServerConnection::HandleRequest() } } - if (m_RequestData.RequestVerb() == HttpVerb::kHead) + // If a default redirect is configured and the request is for the root path, send a 302 + std::string_view DefaultRedirect = m_Server.m_HttpServer->GetDefaultRedirect(); + if (!DefaultRedirect.empty() && (m_RequestData.Url() == "/" || m_RequestData.Url().empty())) { - std::string_view Response = - "HTTP/1.1 404 NOT FOUND\r\n" - "\r\n"sv; - - if (!m_RequestData.IsKeepAlive()) - { - Response = - "HTTP/1.1 404 NOT FOUND\r\n" - "Connection: close\r\n" - "\r\n"sv; - } - - asio::async_write(*m_Socket.get(), - asio::buffer(Response), - [Conn = AsSharedPtr(), RequestNumber](const asio::error_code& Ec, std::size_t ByteCount) { - Conn->OnResponseDataSent(Ec, ByteCount, RequestNumber, /* ResponseToPop */ nullptr); - }); + ExtendableStringBuilder<128> Headers; + Headers << "Location: " << DefaultRedirect << "\r\nContent-Length: 0\r\n"; + SendInlineResponse(RequestNumber, "302 Found"sv, Headers.ToView()); + } + else if (m_RequestData.RequestVerb() == HttpVerb::kHead) + { + SendInlineResponse(RequestNumber, "404 NOT FOUND"sv); } else { - std::string_view Response = - "HTTP/1.1 404 NOT FOUND\r\n" - "Content-Length: 23\r\n" - "Content-Type: text/plain\r\n" - "\r\n" - "No suitable route found"sv; - - if (!m_RequestData.IsKeepAlive()) - { - Response = - "HTTP/1.1 404 NOT FOUND\r\n" - "Content-Length: 23\r\n" - "Content-Type: text/plain\r\n" - "Connection: close\r\n" - "\r\n" - "No suitable route found"sv; - } - - asio::async_write(*m_Socket.get(), - asio::buffer(Response), - [Conn = AsSharedPtr(), RequestNumber](const asio::error_code& Ec, std::size_t ByteCount) { - Conn->OnResponseDataSent(Ec, ByteCount, RequestNumber, /* ResponseToPop */ nullptr); - }); + SendInlineResponse(RequestNumber, + "404 NOT FOUND"sv, + "Content-Length: 23\r\nContent-Type: text/plain\r\n"sv, + "No suitable route found"sv); } } @@ -1348,8 +1460,11 @@ struct HttpAcceptor m_Acceptor.set_option(exclusive_address(true)); m_AlternateProtocolAcceptor.set_option(exclusive_address(true)); #else // ZEN_PLATFORM_WINDOWS - m_Acceptor.set_option(asio::socket_base::reuse_address(false)); - m_AlternateProtocolAcceptor.set_option(asio::socket_base::reuse_address(false)); + // Allow binding to a port in TIME_WAIT so the server can restart immediately + // after a previous instance exits. On Linux this does not allow two processes + // to actively listen on the same port simultaneously. + m_Acceptor.set_option(asio::socket_base::reuse_address(true)); + m_AlternateProtocolAcceptor.set_option(asio::socket_base::reuse_address(true)); #endif // ZEN_PLATFORM_WINDOWS m_Acceptor.set_option(asio::ip::tcp::no_delay(true)); @@ -1512,7 +1627,7 @@ struct HttpAcceptor { ZEN_WARN("Unable to initialize asio service, (bind returned '{}')", BindErrorCode.message()); - return 0; + return {}; } if (EffectivePort != BasePort) @@ -1569,7 +1684,8 @@ struct HttpAcceptor void StopAccepting() { m_IsStopped = true; } - int GetAcceptPort() { return m_Acceptor.local_endpoint().port(); } + int GetAcceptPort() const { return m_Acceptor.local_endpoint().port(); } + bool IsLoopbackOnly() const { return m_Acceptor.local_endpoint().address().is_loopback(); } bool IsValid() const { return m_IsValid; } @@ -1632,11 +1748,15 @@ private: HttpAsioServerRequest::HttpAsioServerRequest(HttpRequestParser& Request, HttpService& Service, IoBuffer PayloadBuffer, - uint32_t RequestNumber) + uint32_t RequestNumber, + bool IsLocalMachineRequest, + std::string RemoteAddress) : HttpServerRequest(Service) , m_Request(Request) , m_RequestNumber(RequestNumber) , m_PayloadBuffer(std::move(PayloadBuffer)) +, m_IsLocalMachineRequest(IsLocalMachineRequest) +, m_RemoteAddress(std::move(RemoteAddress)) { const int PrefixLength = Service.UriPrefixLength(); @@ -1708,6 +1828,24 @@ HttpAsioServerRequest::ParseRequestId() const return m_Request.RequestId(); } +bool +HttpAsioServerRequest::IsLocalMachineRequest() const +{ + return m_IsLocalMachineRequest; +} + +std::string_view +HttpAsioServerRequest::GetRemoteAddress() const +{ + return m_RemoteAddress; +} + +std::string_view +HttpAsioServerRequest::GetAuthorizationHeader() const +{ + return m_Request.AuthorizationHeader(); +} + IoBuffer HttpAsioServerRequest::ReadPayload() { @@ -1904,6 +2042,37 @@ HttpAsioServerImpl::RouteRequest(std::string_view Url) return CandidateService; } +void +HttpAsioServerImpl::SetHttpRequestFilter(IHttpRequestFilter* RequestFilter) +{ + ZEN_MEMSCOPE(GetHttpasioTag()); + RwLock::ExclusiveLockScope _(m_Lock); + m_HttpRequestFilter.store(RequestFilter); +} + +IHttpRequestFilter::Result +HttpAsioServerImpl::FilterRequest(HttpServerRequest& Request) +{ + if (!m_HttpRequestFilter.load()) + { + return IHttpRequestFilter::Result::Accepted; + } + RwLock::SharedLockScope _(m_Lock); + IHttpRequestFilter* RequestFilter = m_HttpRequestFilter.load(); + if (!RequestFilter) + { + return IHttpRequestFilter::Result::Accepted; + } + + return RequestFilter->FilterRequest(Request); +} + +bool +HttpAsioServerImpl::IsLoopbackOnly() const +{ + return m_Acceptor && m_Acceptor->IsLoopbackOnly(); +} + } // namespace zen::asio_http ////////////////////////////////////////////////////////////////////////// @@ -1916,11 +2085,15 @@ public: HttpAsioServer(const AsioConfig& Config); ~HttpAsioServer(); - virtual void OnRegisterService(HttpService& Service) override; - virtual int OnInitialize(int BasePort, std::filesystem::path DataDir) override; - virtual void OnRun(bool IsInteractiveSession) override; - virtual void OnRequestExit() override; - virtual void OnClose() override; + virtual void OnRegisterService(HttpService& Service) override; + virtual int OnInitialize(int BasePort, std::filesystem::path DataDir) override; + virtual void OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) override; + virtual void OnRun(bool IsInteractiveSession) override; + virtual void OnRequestExit() override; + virtual void OnClose() override; + virtual std::string OnGetExternalHost() const override; + virtual uint64_t GetTotalBytesReceived() const override; + virtual uint64_t GetTotalBytesSent() const override; private: Event m_ShutdownEvent; @@ -1934,6 +2107,7 @@ HttpAsioServer::HttpAsioServer(const AsioConfig& Config) : m_InitialConfig(Config) , m_Impl(std::make_unique<asio_http::HttpAsioServerImpl>()) { + m_Impl->m_HttpServer = this; ZEN_DEBUG("Request object size: {} ({:#x})", sizeof(HttpRequestParser), sizeof(HttpRequestParser)); } @@ -1965,6 +2139,12 @@ HttpAsioServer::OnRegisterService(HttpService& Service) m_Impl->RegisterService(Service.BaseUri(), Service); } +void +HttpAsioServer::OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) +{ + m_Impl->SetHttpRequestFilter(RequestFilter); +} + int HttpAsioServer::OnInitialize(int BasePort, std::filesystem::path DataDir) { @@ -1989,10 +2169,46 @@ HttpAsioServer::OnInitialize(int BasePort, std::filesystem::path DataDir) return m_BasePort; } +std::string +HttpAsioServer::OnGetExternalHost() const +{ + if (m_Impl->IsLoopbackOnly()) + { + return "127.0.0.1"; + } + + // Use the UDP connect trick: connecting a UDP socket to an external address + // causes the OS to select the appropriate local interface without sending any data. + try + { + asio::io_service IoService; + asio::ip::udp::socket Sock(IoService, asio::ip::udp::v4()); + Sock.connect(asio::ip::udp::endpoint(asio::ip::address::from_string("8.8.8.8"), 80)); + return Sock.local_endpoint().address().to_string(); + } + catch (const std::exception&) + { + return GetMachineName(); + } +} + +uint64_t +HttpAsioServer::GetTotalBytesReceived() const +{ + return m_Impl->m_TotalBytesReceived.load(std::memory_order_relaxed); +} + +uint64_t +HttpAsioServer::GetTotalBytesSent() const +{ + return m_Impl->m_TotalBytesSent.load(std::memory_order_relaxed); +} + void HttpAsioServer::OnRun(bool IsInteractive) { - const int WaitTimeout = 1000; + const int WaitTimeout = 1000; + bool ShutdownRequested = false; #if ZEN_PLATFORM_WINDOWS if (IsInteractive) @@ -2008,12 +2224,13 @@ HttpAsioServer::OnRun(bool IsInteractive) if (c == 27 || c == 'Q' || c == 'q') { + m_ShutdownEvent.Set(); RequestApplicationExit(0); } } - m_ShutdownEvent.Wait(WaitTimeout); - } while (!IsApplicationExitRequested()); + ShutdownRequested = m_ShutdownEvent.Wait(WaitTimeout); + } while (!ShutdownRequested); #else if (IsInteractive) { @@ -2022,8 +2239,8 @@ HttpAsioServer::OnRun(bool IsInteractive) do { - m_ShutdownEvent.Wait(WaitTimeout); - } while (!IsApplicationExitRequested()); + ShutdownRequested = m_ShutdownEvent.Wait(WaitTimeout); + } while (!ShutdownRequested); #endif } diff --git a/src/zenhttp/servers/httpasio.h b/src/zenhttp/servers/httpasio.h index c483dfc28..3ec1141a7 100644 --- a/src/zenhttp/servers/httpasio.h +++ b/src/zenhttp/servers/httpasio.h @@ -15,4 +15,6 @@ struct AsioConfig Ref<HttpServer> CreateHttpAsioServer(const AsioConfig& Config); +bool IsIPv6Capable(); + } // namespace zen diff --git a/src/zenhttp/servers/httpmulti.cpp b/src/zenhttp/servers/httpmulti.cpp index 31cb04be5..584e06cbf 100644 --- a/src/zenhttp/servers/httpmulti.cpp +++ b/src/zenhttp/servers/httpmulti.cpp @@ -54,9 +54,19 @@ HttpMultiServer::OnInitialize(int BasePort, std::filesystem::path DataDir) } void +HttpMultiServer::OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) +{ + for (auto& Server : m_Servers) + { + Server->SetHttpRequestFilter(RequestFilter); + } +} + +void HttpMultiServer::OnRun(bool IsInteractiveSession) { - const int WaitTimeout = 1000; + const int WaitTimeout = 1000; + bool ShutdownRequested = false; #if ZEN_PLATFORM_WINDOWS if (IsInteractiveSession) @@ -72,12 +82,13 @@ HttpMultiServer::OnRun(bool IsInteractiveSession) if (c == 27 || c == 'Q' || c == 'q') { + m_ShutdownEvent.Set(); RequestApplicationExit(0); } } - m_ShutdownEvent.Wait(WaitTimeout); - } while (!IsApplicationExitRequested()); + ShutdownRequested = m_ShutdownEvent.Wait(WaitTimeout); + } while (!ShutdownRequested); #else if (IsInteractiveSession) { @@ -86,8 +97,8 @@ HttpMultiServer::OnRun(bool IsInteractiveSession) do { - m_ShutdownEvent.Wait(WaitTimeout); - } while (!IsApplicationExitRequested()); + ShutdownRequested = m_ShutdownEvent.Wait(WaitTimeout); + } while (!ShutdownRequested); #endif } @@ -106,6 +117,16 @@ HttpMultiServer::OnClose() } } +std::string +HttpMultiServer::OnGetExternalHost() const +{ + if (!m_Servers.empty()) + { + return std::string(m_Servers.front()->GetExternalHost()); + } + return HttpServer::OnGetExternalHost(); +} + void HttpMultiServer::AddServer(Ref<HttpServer> Server) { diff --git a/src/zenhttp/servers/httpmulti.h b/src/zenhttp/servers/httpmulti.h index ae0ed74cf..97699828a 100644 --- a/src/zenhttp/servers/httpmulti.h +++ b/src/zenhttp/servers/httpmulti.h @@ -15,11 +15,13 @@ public: HttpMultiServer(); ~HttpMultiServer(); - virtual void OnRegisterService(HttpService& Service) override; - virtual int OnInitialize(int BasePort, std::filesystem::path DataDir) override; - virtual void OnRun(bool IsInteractiveSession) override; - virtual void OnRequestExit() override; - virtual void OnClose() override; + virtual void OnRegisterService(HttpService& Service) override; + virtual void OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) override; + virtual int OnInitialize(int BasePort, std::filesystem::path DataDir) override; + virtual void OnRun(bool IsInteractiveSession) override; + virtual void OnRequestExit() override; + virtual void OnClose() override; + virtual std::string OnGetExternalHost() const override; void AddServer(Ref<HttpServer> Server); diff --git a/src/zenhttp/servers/httpnull.cpp b/src/zenhttp/servers/httpnull.cpp index 0ec1cb3c4..9bb7ef3bc 100644 --- a/src/zenhttp/servers/httpnull.cpp +++ b/src/zenhttp/servers/httpnull.cpp @@ -24,6 +24,12 @@ HttpNullServer::OnRegisterService(HttpService& Service) ZEN_UNUSED(Service); } +void +HttpNullServer::OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) +{ + ZEN_UNUSED(RequestFilter); +} + int HttpNullServer::OnInitialize(int BasePort, std::filesystem::path DataDir) { @@ -34,7 +40,8 @@ HttpNullServer::OnInitialize(int BasePort, std::filesystem::path DataDir) void HttpNullServer::OnRun(bool IsInteractiveSession) { - const int WaitTimeout = 1000; + const int WaitTimeout = 1000; + bool ShutdownRequested = false; #if ZEN_PLATFORM_WINDOWS if (IsInteractiveSession) @@ -50,12 +57,13 @@ HttpNullServer::OnRun(bool IsInteractiveSession) if (c == 27 || c == 'Q' || c == 'q') { + m_ShutdownEvent.Set(); RequestApplicationExit(0); } } - m_ShutdownEvent.Wait(WaitTimeout); - } while (!IsApplicationExitRequested()); + ShutdownRequested = m_ShutdownEvent.Wait(WaitTimeout); + } while (!ShutdownRequested); #else if (IsInteractiveSession) { @@ -64,8 +72,8 @@ HttpNullServer::OnRun(bool IsInteractiveSession) do { - m_ShutdownEvent.Wait(WaitTimeout); - } while (!IsApplicationExitRequested()); + ShutdownRequested = m_ShutdownEvent.Wait(WaitTimeout); + } while (!ShutdownRequested); #endif } diff --git a/src/zenhttp/servers/httpnull.h b/src/zenhttp/servers/httpnull.h index ce7230938..52838f012 100644 --- a/src/zenhttp/servers/httpnull.h +++ b/src/zenhttp/servers/httpnull.h @@ -18,6 +18,7 @@ public: ~HttpNullServer(); virtual void OnRegisterService(HttpService& Service) override; + virtual void OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) override; virtual int OnInitialize(int BasePort, std::filesystem::path DataDir) override; virtual void OnRun(bool IsInteractiveSession) override; virtual void OnRequestExit() override; diff --git a/src/zenhttp/servers/httpparser.cpp b/src/zenhttp/servers/httpparser.cpp index 93094e21b..918b55dc6 100644 --- a/src/zenhttp/servers/httpparser.cpp +++ b/src/zenhttp/servers/httpparser.cpp @@ -12,13 +12,17 @@ namespace zen { using namespace std::literals; -static constinit uint32_t HashContentLength = HashStringAsLowerDjb2("Content-Length"sv); -static constinit uint32_t HashContentType = HashStringAsLowerDjb2("Content-Type"sv); -static constinit uint32_t HashAccept = HashStringAsLowerDjb2("Accept"sv); -static constinit uint32_t HashExpect = HashStringAsLowerDjb2("Expect"sv); -static constinit uint32_t HashSession = HashStringAsLowerDjb2("UE-Session"sv); -static constinit uint32_t HashRequest = HashStringAsLowerDjb2("UE-Request"sv); -static constinit uint32_t HashRange = HashStringAsLowerDjb2("Range"sv); +static constexpr uint32_t HashContentLength = HashStringAsLowerDjb2("Content-Length"sv); +static constexpr uint32_t HashContentType = HashStringAsLowerDjb2("Content-Type"sv); +static constexpr uint32_t HashAccept = HashStringAsLowerDjb2("Accept"sv); +static constexpr uint32_t HashExpect = HashStringAsLowerDjb2("Expect"sv); +static constexpr uint32_t HashSession = HashStringAsLowerDjb2("UE-Session"sv); +static constexpr uint32_t HashRequest = HashStringAsLowerDjb2("UE-Request"sv); +static constexpr uint32_t HashRange = HashStringAsLowerDjb2("Range"sv); +static constexpr uint32_t HashAuthorization = HashStringAsLowerDjb2("Authorization"sv); +static constexpr uint32_t HashUpgrade = HashStringAsLowerDjb2("Upgrade"sv); +static constexpr uint32_t HashSecWebSocketKey = HashStringAsLowerDjb2("Sec-WebSocket-Key"sv); +static constexpr uint32_t HashSecWebSocketVersion = HashStringAsLowerDjb2("Sec-WebSocket-Version"sv); ////////////////////////////////////////////////////////////////////////// // @@ -142,41 +146,62 @@ HttpRequestParser::ParseCurrentHeader() const uint32_t HeaderHash = HashStringAsLowerDjb2(HeaderName); const int8_t CurrentHeaderIndex = int8_t(CurrentHeaderCount - 1); - if (HeaderHash == HashContentLength) + switch (HeaderHash) { - m_ContentLengthHeaderIndex = CurrentHeaderIndex; - } - else if (HeaderHash == HashAccept) - { - m_AcceptHeaderIndex = CurrentHeaderIndex; - } - else if (HeaderHash == HashContentType) - { - m_ContentTypeHeaderIndex = CurrentHeaderIndex; - } - else if (HeaderHash == HashSession) - { - m_SessionId = Oid::TryFromHexString(HeaderValue); - } - else if (HeaderHash == HashRequest) - { - std::from_chars(HeaderValue.data(), HeaderValue.data() + HeaderValue.size(), m_RequestId); - } - else if (HeaderHash == HashExpect) - { - if (HeaderValue == "100-continue"sv) - { - // We don't currently do anything with this - m_Expect100Continue = true; - } - else - { - ZEN_INFO("Unexpected expect - Expect: {}", HeaderValue); - } - } - else if (HeaderHash == HashRange) - { - m_RangeHeaderIndex = CurrentHeaderIndex; + case HashContentLength: + m_ContentLengthHeaderIndex = CurrentHeaderIndex; + break; + + case HashAccept: + m_AcceptHeaderIndex = CurrentHeaderIndex; + break; + + case HashContentType: + m_ContentTypeHeaderIndex = CurrentHeaderIndex; + break; + + case HashAuthorization: + m_AuthorizationHeaderIndex = CurrentHeaderIndex; + break; + + case HashSession: + m_SessionId = Oid::TryFromHexString(HeaderValue); + break; + + case HashRequest: + std::from_chars(HeaderValue.data(), HeaderValue.data() + HeaderValue.size(), m_RequestId); + break; + + case HashExpect: + if (HeaderValue == "100-continue"sv) + { + // We don't currently do anything with this + m_Expect100Continue = true; + } + else + { + ZEN_INFO("Unexpected expect - Expect: {}", HeaderValue); + } + break; + + case HashRange: + m_RangeHeaderIndex = CurrentHeaderIndex; + break; + + case HashUpgrade: + m_UpgradeHeaderIndex = CurrentHeaderIndex; + break; + + case HashSecWebSocketKey: + m_SecWebSocketKeyHeaderIndex = CurrentHeaderIndex; + break; + + case HashSecWebSocketVersion: + m_SecWebSocketVersionHeaderIndex = CurrentHeaderIndex; + break; + + default: + break; } } @@ -220,11 +245,6 @@ NormalizeUrlPath(std::string_view InUrl, std::string& NormalizedUrl) NormalizedUrl.reserve(UrlLength); NormalizedUrl.append(Url, UrlIndex); } - - if (!LastCharWasSeparator) - { - NormalizedUrl.push_back('/'); - } } else if (!NormalizedUrl.empty()) { @@ -305,6 +325,7 @@ HttpRequestParser::OnHeadersComplete() if (ContentLength) { + // TODO: should sanity-check content length here m_BodyBuffer = IoBuffer(ContentLength); } @@ -324,9 +345,9 @@ HttpRequestParser::OnHeadersComplete() int HttpRequestParser::OnBody(const char* Data, size_t Bytes) { - if (m_BodyPosition + Bytes > m_BodyBuffer.Size()) + if ((m_BodyPosition + Bytes) > m_BodyBuffer.Size()) { - ZEN_WARN("HTTP parser incoming body is larger than content size, need {} more bytes", + ZEN_WARN("HTTP parser incoming body is larger than content size, need {} more buffer bytes", (m_BodyPosition + Bytes) - m_BodyBuffer.Size()); return 1; } @@ -337,7 +358,7 @@ HttpRequestParser::OnBody(const char* Data, size_t Bytes) { if (m_BodyPosition != m_BodyBuffer.Size()) { - ZEN_WARN("Body mismatch! {} != {}", m_BodyPosition, m_BodyBuffer.Size()); + ZEN_WARN("Body size mismatch! {} != {}", m_BodyPosition, m_BodyBuffer.Size()); return 1; } } @@ -353,13 +374,18 @@ HttpRequestParser::ResetState() m_HeaderEntries.clear(); - m_ContentLengthHeaderIndex = -1; - m_AcceptHeaderIndex = -1; - m_ContentTypeHeaderIndex = -1; - m_RangeHeaderIndex = -1; - m_Expect100Continue = false; - m_BodyBuffer = {}; - m_BodyPosition = 0; + m_ContentLengthHeaderIndex = -1; + m_AcceptHeaderIndex = -1; + m_ContentTypeHeaderIndex = -1; + m_RangeHeaderIndex = -1; + m_AuthorizationHeaderIndex = -1; + m_UpgradeHeaderIndex = -1; + m_SecWebSocketKeyHeaderIndex = -1; + m_SecWebSocketVersionHeaderIndex = -1; + m_RequestVerb = HttpVerb::kGet; + m_Expect100Continue = false; + m_BodyBuffer = {}; + m_BodyPosition = 0; m_HeaderData.clear(); m_NormalizedUrl.clear(); @@ -416,4 +442,21 @@ HttpRequestParser::OnMessageComplete() } } +bool +HttpRequestParser::IsWebSocketUpgrade() const +{ + std::string_view Upgrade = GetHeaderValue(m_UpgradeHeaderIndex); + if (Upgrade.empty()) + { + return false; + } + + // Case-insensitive check for "websocket" + if (Upgrade.size() != 9) + { + return false; + } + return StrCaseCompare(Upgrade.data(), "websocket", 9) == 0; +} + } // namespace zen diff --git a/src/zenhttp/servers/httpparser.h b/src/zenhttp/servers/httpparser.h index 0d2664ec5..23ad9d8fb 100644 --- a/src/zenhttp/servers/httpparser.h +++ b/src/zenhttp/servers/httpparser.h @@ -46,6 +46,12 @@ struct HttpRequestParser std::string_view RangeHeader() const { return GetHeaderValue(m_RangeHeaderIndex); } + std::string_view AuthorizationHeader() const { return GetHeaderValue(m_AuthorizationHeaderIndex); } + + std::string_view UpgradeHeader() const { return GetHeaderValue(m_UpgradeHeaderIndex); } + std::string_view SecWebSocketKey() const { return GetHeaderValue(m_SecWebSocketKeyHeaderIndex); } + bool IsWebSocketUpgrade() const; + private: struct HeaderRange { @@ -83,7 +89,11 @@ private: int8_t m_AcceptHeaderIndex; int8_t m_ContentTypeHeaderIndex; int8_t m_RangeHeaderIndex; - HttpVerb m_RequestVerb; + int8_t m_AuthorizationHeaderIndex; + int8_t m_UpgradeHeaderIndex; + int8_t m_SecWebSocketKeyHeaderIndex; + int8_t m_SecWebSocketVersionHeaderIndex; + HttpVerb m_RequestVerb = HttpVerb::kGet; std::atomic_bool m_KeepAlive{false}; bool m_Expect100Continue = false; int m_RequestId = -1; diff --git a/src/zenhttp/servers/httpplugin.cpp b/src/zenhttp/servers/httpplugin.cpp index b9217ed87..4bf8c61bb 100644 --- a/src/zenhttp/servers/httpplugin.cpp +++ b/src/zenhttp/servers/httpplugin.cpp @@ -96,6 +96,7 @@ struct HttpPluginServerImpl : public HttpPluginServer, TransportServer // HttpPluginServer virtual void OnRegisterService(HttpService& Service) override; + virtual void OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) override; virtual int OnInitialize(int BasePort, std::filesystem::path DataDir) override; virtual void OnRun(bool IsInteractiveSession) override; virtual void OnRequestExit() override; @@ -104,7 +105,8 @@ struct HttpPluginServerImpl : public HttpPluginServer, TransportServer virtual void AddPlugin(Ref<TransportPlugin> Plugin) override; virtual void RemovePlugin(Ref<TransportPlugin> Plugin) override; - HttpService* RouteRequest(std::string_view Url); + HttpService* RouteRequest(std::string_view Url); + IHttpRequestFilter::Result FilterRequest(HttpServerRequest& Request); struct ServiceEntry { @@ -112,7 +114,8 @@ struct HttpPluginServerImpl : public HttpPluginServer, TransportServer HttpService* Service; }; - bool m_IsInitialized = false; + std::atomic<IHttpRequestFilter*> m_HttpRequestFilter = nullptr; + bool m_IsInitialized = false; RwLock m_Lock; std::vector<ServiceEntry> m_UriHandlers; std::vector<Ref<TransportPlugin>> m_Plugins; @@ -120,7 +123,7 @@ struct HttpPluginServerImpl : public HttpPluginServer, TransportServer bool m_IsRequestLoggingEnabled = false; LoggerRef m_RequestLog; std::atomic_uint32_t m_ConnectionIdCounter{0}; - int m_BasePort; + int m_BasePort = 0; HttpServerTracer m_RequestTracer; @@ -143,8 +146,11 @@ public: HttpPluginServerRequest(const HttpPluginServerRequest&) = delete; HttpPluginServerRequest& operator=(const HttpPluginServerRequest&) = delete; - virtual Oid ParseSessionId() const override; - virtual uint32_t ParseRequestId() const override; + // As this is plugin transport connection used for specialized connections we assume it is not a machine local connection + virtual bool IsLocalMachineRequest() const /* override*/ { return false; } + virtual std::string_view GetAuthorizationHeader() const override; + virtual Oid ParseSessionId() const override; + virtual uint32_t ParseRequestId() const override; virtual IoBuffer ReadPayload() override; virtual void WriteResponse(HttpResponseCode ResponseCode) override; @@ -288,7 +294,7 @@ HttpPluginConnectionHandler::Initialize(TransportConnection* Transport, HttpPlug ConnectionName = "anonymous"; } - ZEN_LOG_TRACE(m_Server->m_RequestLog, "NEW connection #{} ('')", m_ConnectionId, ConnectionName); + ZEN_LOG_TRACE(m_Server->m_RequestLog, "NEW connection #{} ('{}')", m_ConnectionId, ConnectionName); } uint32_t @@ -372,12 +378,14 @@ HttpPluginConnectionHandler::HandleRequest() { ZEN_TRACE_CPU("http_plugin::HandleRequest"); + m_Server->MarkRequest(); + HttpPluginServerRequest Request(m_RequestParser, *Service, m_RequestParser.Body()); const HttpVerb RequestVerb = Request.RequestVerb(); const std::string_view Uri = Request.RelativeUri(); - if (m_Server->m_RequestLog.ShouldLog(logging::level::Trace)) + if (m_Server->m_RequestLog.ShouldLog(logging::Trace)) { ZEN_LOG_TRACE(m_Server->m_RequestLog, "connection #{} Handling Request: {} {} ({} bytes ({}), accept: {})", @@ -392,53 +400,65 @@ HttpPluginConnectionHandler::HandleRequest() std::vector<IoBuffer>{Request.ReadPayload()}); } - if (!HandlePackageOffers(*Service, Request, m_PackageHandler)) + IHttpRequestFilter::Result FilterResult = m_Server->FilterRequest(Request); + if (FilterResult == IHttpRequestFilter::Result::Accepted) { - try - { - Service->HandleRequest(Request); - } - catch (const AssertException& AssertEx) + if (!HandlePackageOffers(*Service, Request, m_PackageHandler)) { - // Drop any partially formatted response - Request.m_Response.reset(); - - ZEN_ERROR("Caught assert exception while handling request: {}", AssertEx.FullDescription()); - Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, AssertEx.FullDescription()); - } - catch (const std::system_error& SystemError) - { - // Drop any partially formatted response - Request.m_Response.reset(); - - if (IsOOM(SystemError.code()) || IsOOD(SystemError.code())) + try { - Request.WriteResponse(HttpResponseCode::InsufficientStorage, HttpContentType::kText, SystemError.what()); + Service->HandleRequest(Request); } - else + catch (const AssertException& AssertEx) { - ZEN_WARN("Caught system error exception while handling request: {}. ({})", - SystemError.what(), - SystemError.code().value()); - Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, SystemError.what()); + // Drop any partially formatted response + Request.m_Response.reset(); + + ZEN_ERROR("Caught assert exception while handling request: {}", AssertEx.FullDescription()); + Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, AssertEx.FullDescription()); } - } - catch (const std::bad_alloc& BadAlloc) - { - // Drop any partially formatted response - Request.m_Response.reset(); + catch (const std::system_error& SystemError) + { + // Drop any partially formatted response + Request.m_Response.reset(); + + if (IsOOM(SystemError.code()) || IsOOD(SystemError.code())) + { + Request.WriteResponse(HttpResponseCode::InsufficientStorage, HttpContentType::kText, SystemError.what()); + } + else + { + ZEN_WARN("Caught system error exception while handling request: {}. ({})", + SystemError.what(), + SystemError.code().value()); + Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, SystemError.what()); + } + } + catch (const std::bad_alloc& BadAlloc) + { + // Drop any partially formatted response + Request.m_Response.reset(); - Request.WriteResponse(HttpResponseCode::InsufficientStorage, HttpContentType::kText, BadAlloc.what()); - } - catch (const std::exception& ex) - { - // Drop any partially formatted response - Request.m_Response.reset(); + Request.WriteResponse(HttpResponseCode::InsufficientStorage, HttpContentType::kText, BadAlloc.what()); + } + catch (const std::exception& ex) + { + // Drop any partially formatted response + Request.m_Response.reset(); - ZEN_WARN("Caught exception while handling request: {}", ex.what()); - Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, ex.what()); + ZEN_WARN("Caught exception while handling request: {}", ex.what()); + Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, ex.what()); + } } } + else if (FilterResult == IHttpRequestFilter::Result::Forbidden) + { + Request.WriteResponse(HttpResponseCode::Forbidden); + } + else + { + ZEN_ASSERT(FilterResult == IHttpRequestFilter::Result::ResponseSent); + } if (std::unique_ptr<HttpPluginResponse> Response = std::move(Request.m_Response)) { @@ -462,7 +482,7 @@ HttpPluginConnectionHandler::HandleRequest() const std::vector<IoBuffer>& ResponseBuffers = Response->ResponseBuffers(); - if (m_Server->m_RequestLog.ShouldLog(logging::level::Trace)) + if (m_Server->m_RequestLog.ShouldLog(logging::Trace)) { m_Server->m_RequestTracer.WriteDebugPayload(fmt::format("response_{}_{}.bin", m_ConnectionId, RequestNumber), ResponseBuffers); @@ -618,6 +638,12 @@ HttpPluginServerRequest::~HttpPluginServerRequest() { } +std::string_view +HttpPluginServerRequest::GetAuthorizationHeader() const +{ + return m_Request.AuthorizationHeader(); +} + Oid HttpPluginServerRequest::ParseSessionId() const { @@ -750,6 +776,13 @@ HttpPluginServerImpl::OnInitialize(int InBasePort, std::filesystem::path DataDir } void +HttpPluginServerImpl::OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) +{ + RwLock::ExclusiveLockScope _(m_Lock); + m_HttpRequestFilter.store(RequestFilter); +} + +void HttpPluginServerImpl::OnClose() { if (!m_IsInitialized) @@ -806,6 +839,7 @@ HttpPluginServerImpl::OnRun(bool IsInteractive) if (c == 27 || c == 'Q' || c == 'q') { + m_ShutdownEvent.Set(); RequestApplicationExit(0); } } @@ -894,6 +928,22 @@ HttpPluginServerImpl::RouteRequest(std::string_view Url) return CandidateService; } +IHttpRequestFilter::Result +HttpPluginServerImpl::FilterRequest(HttpServerRequest& Request) +{ + if (!m_HttpRequestFilter.load()) + { + return IHttpRequestFilter::Result::Accepted; + } + RwLock::SharedLockScope _(m_Lock); + IHttpRequestFilter* RequestFilter = m_HttpRequestFilter.load(); + if (!RequestFilter) + { + return IHttpRequestFilter::Result::Accepted; + } + return RequestFilter->FilterRequest(Request); +} + ////////////////////////////////////////////////////////////////////////// struct HttpPluginServerImpl; diff --git a/src/zenhttp/servers/httpsys.cpp b/src/zenhttp/servers/httpsys.cpp index 54cc0c22d..dfe6bb6aa 100644 --- a/src/zenhttp/servers/httpsys.cpp +++ b/src/zenhttp/servers/httpsys.cpp @@ -12,6 +12,7 @@ #include <zencore/memory/llm.h> #include <zencore/scopeguard.h> #include <zencore/string.h> +#include <zencore/system.h> #include <zencore/timer.h> #include <zencore/trace.h> #include <zenhttp/packageformat.h> @@ -25,7 +26,9 @@ # include <zencore/workthreadpool.h> # include "iothreadpool.h" +# include <atomic> # include <http.h> +# include <asio.hpp> // for resolving addresses for GetExternalHost namespace zen { @@ -72,6 +75,8 @@ GetAddressString(StringBuilderBase& OutString, const SOCKADDR* SockAddr, bool In OutString.Append("unknown"); } +class HttpSysServerRequest; + /** * @brief Windows implementation of HTTP server based on http.sys * @@ -83,6 +88,8 @@ GetAddressString(StringBuilderBase& OutString, const SOCKADDR* SockAddr, bool In class HttpSysServer : public HttpServer { friend class HttpSysTransaction; + friend class HttpMessageResponseRequest; + friend struct InitialRequestHandler; public: explicit HttpSysServer(const HttpSysConfig& Config); @@ -90,17 +97,23 @@ public: // HttpServer interface implementation - virtual int OnInitialize(int BasePort, std::filesystem::path DataDir) override; - virtual void OnRun(bool TestMode) override; - virtual void OnRequestExit() override; - virtual void OnRegisterService(HttpService& Service) override; - virtual void OnClose() override; + virtual int OnInitialize(int BasePort, std::filesystem::path DataDir) override; + virtual void OnRun(bool TestMode) override; + virtual void OnRequestExit() override; + virtual void OnRegisterService(HttpService& Service) override; + virtual void OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) override; + virtual void OnClose() override; + virtual std::string OnGetExternalHost() const override; + virtual uint64_t GetTotalBytesReceived() const override; + virtual uint64_t GetTotalBytesSent() const override; WorkerThreadPool& WorkPool(); inline bool IsOk() const { return m_IsOk; } inline bool IsAsyncResponseEnabled() const { return m_IsAsyncResponseEnabled; } + IHttpRequestFilter::Result FilterRequest(HttpSysServerRequest& Request); + private: int InitializeServer(int BasePort); void Cleanup(); @@ -124,8 +137,8 @@ private: std::unique_ptr<WinIoThreadPool> m_IoThreadPool; - RwLock m_AsyncWorkPoolInitLock; - WorkerThreadPool* m_AsyncWorkPool = nullptr; + RwLock m_AsyncWorkPoolInitLock; + std::atomic<WorkerThreadPool*> m_AsyncWorkPool = nullptr; std::vector<std::wstring> m_BaseUris; // eg: http://*:nnnn/ HTTP_SERVER_SESSION_ID m_HttpSessionId = 0; @@ -137,6 +150,12 @@ private: int32_t m_MaxPendingRequests = 128; Event m_ShutdownEvent; HttpSysConfig m_InitialConfig; + + RwLock m_RequestFilterLock; + std::atomic<IHttpRequestFilter*> m_HttpRequestFilter = nullptr; + + std::atomic<uint64_t> m_TotalBytesReceived{0}; + std::atomic<uint64_t> m_TotalBytesSent{0}; }; } // namespace zen @@ -144,6 +163,10 @@ private: #if ZEN_WITH_HTTPSYS +# include "httpsys_iocontext.h" +# include "wshttpsys.h" +# include "wsframecodec.h" + # include <conio.h> # include <mstcpip.h> # pragma comment(lib, "httpapi.lib") @@ -313,6 +336,10 @@ public: virtual Oid ParseSessionId() const override; virtual uint32_t ParseRequestId() const override; + virtual bool IsLocalMachineRequest() const override; + virtual std::string_view GetAuthorizationHeader() const override; + virtual std::string_view GetRemoteAddress() const override; + virtual IoBuffer ReadPayload() override; virtual void WriteResponse(HttpResponseCode ResponseCode) override; virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span<IoBuffer> Blobs) override; @@ -320,16 +347,19 @@ public: virtual void WriteResponseAsync(std::function<void(HttpServerRequest&)>&& ContinuationHandler) override; virtual bool TryGetRanges(HttpRanges& Ranges) override; + void LogRequest(HttpMessageResponseRequest* Response); + using HttpServerRequest::WriteResponse; HttpSysServerRequest(const HttpSysServerRequest&) = delete; HttpSysServerRequest& operator=(const HttpSysServerRequest&) = delete; - HttpSysTransaction& m_HttpTx; - HttpSysRequestHandler* m_NextCompletionHandler = nullptr; - IoBuffer m_PayloadBuffer; - ExtendableStringBuilder<128> m_UriUtf8; - ExtendableStringBuilder<128> m_QueryStringUtf8; + HttpSysTransaction& m_HttpTx; + HttpSysRequestHandler* m_NextCompletionHandler = nullptr; + IoBuffer m_PayloadBuffer; + ExtendableStringBuilder<128> m_UriUtf8; + ExtendableStringBuilder<128> m_QueryStringUtf8; + mutable ExtendableStringBuilder<64> m_RemoteAddress; }; /** HTTP transaction @@ -363,7 +393,7 @@ public: PTP_IO Iocp(); HANDLE RequestQueueHandle(); - inline OVERLAPPED* Overlapped() { return &m_HttpOverlapped; } + inline OVERLAPPED* Overlapped() { return &m_IoContext.Overlapped; } inline HttpSysServer& Server() { return m_HttpServer; } inline HTTP_REQUEST* HttpRequest() { return m_InitialHttpHandler.HttpRequest(); } @@ -380,8 +410,8 @@ public: }; private: - OVERLAPPED m_HttpOverlapped{}; - HttpSysServer& m_HttpServer; + HttpSysIoContext m_IoContext{}; + HttpSysServer& m_HttpServer; // Tracks which handler is due to handle the next I/O completion event HttpSysRequestHandler* m_CompletionHandler = nullptr; @@ -418,7 +448,10 @@ public: virtual HttpSysRequestHandler* HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) override; void SuppressResponseBody(); // typically used for HEAD requests - inline int64_t GetResponseBodySize() const { return m_TotalDataSize; } + inline uint16_t GetResponseCode() const { return m_ResponseCode; } + inline int64_t GetResponseBodySize() const { return m_TotalDataSize; } + + void SetLocationHeader(std::string_view Location) { m_LocationHeader = Location; } private: eastl::fixed_vector<HTTP_DATA_CHUNK, 16> m_HttpDataChunks; @@ -429,6 +462,7 @@ private: bool m_IsInitialResponse = true; HttpContentType m_ContentType = HttpContentType::kBinary; eastl::fixed_vector<IoBuffer, 16> m_DataBuffers; + std::string m_LocationHeader; void InitializeForPayload(uint16_t ResponseCode, std::span<IoBuffer> Blobs); }; @@ -569,7 +603,7 @@ HttpMessageResponseRequest::SuppressResponseBody() HttpSysRequestHandler* HttpMessageResponseRequest::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) { - ZEN_UNUSED(NumberOfBytesTransferred); + Transaction().Server().m_TotalBytesSent.fetch_add(NumberOfBytesTransferred, std::memory_order_relaxed); if (IoResult != NO_ERROR) { @@ -684,6 +718,15 @@ HttpMessageResponseRequest::IssueRequest(std::error_code& ErrorCode) ContentTypeHeader->pRawValue = ContentTypeString.data(); ContentTypeHeader->RawValueLength = (USHORT)ContentTypeString.size(); + // Location header (for redirects) + + if (!m_LocationHeader.empty()) + { + PHTTP_KNOWN_HEADER LocationHeader = &HttpResponse.Headers.KnownHeaders[HttpHeaderLocation]; + LocationHeader->pRawValue = m_LocationHeader.data(); + LocationHeader->RawValueLength = (USHORT)m_LocationHeader.size(); + } + std::string_view ReasonString = ReasonStringForHttpResultCode(m_ResponseCode); HttpResponse.StatusCode = m_ResponseCode; @@ -694,21 +737,22 @@ HttpMessageResponseRequest::IssueRequest(std::error_code& ErrorCode) HTTP_CACHE_POLICY CachePolicy; - CachePolicy.Policy = HttpCachePolicyNocache; // HttpCachePolicyUserInvalidates; + CachePolicy.Policy = HttpCachePolicyNocache; CachePolicy.SecondsToLive = 0; // Initial response API call - SendResult = HttpSendHttpResponse(Tx.RequestQueueHandle(), - HttpReq->RequestId, - SendFlags, - &HttpResponse, - &CachePolicy, - NULL, - NULL, - 0, - Tx.Overlapped(), - NULL); + SendResult = HttpSendHttpResponse(Tx.RequestQueueHandle(), // RequestQueueHandle + HttpReq->RequestId, // RequestId + SendFlags, // Flags + &HttpResponse, // HttpResponse + &CachePolicy, // CachePolicy + NULL, // BytesSent + NULL, // Reserved1 + 0, // Reserved2 + Tx.Overlapped(), // Overlapped + NULL // LogData + ); m_IsInitialResponse = false; } @@ -716,9 +760,9 @@ HttpMessageResponseRequest::IssueRequest(std::error_code& ErrorCode) { // Subsequent response API calls - SendResult = HttpSendResponseEntityBody(Tx.RequestQueueHandle(), - HttpReq->RequestId, - SendFlags, + SendResult = HttpSendResponseEntityBody(Tx.RequestQueueHandle(), // RequestQueueHandle + HttpReq->RequestId, // RequestId + SendFlags, // Flags (USHORT)ThisRequestChunkCount, // EntityChunkCount &m_HttpDataChunks[ThisRequestChunkOffset], // EntityChunks NULL, // BytesSent @@ -884,7 +928,10 @@ HttpAsyncWorkRequest::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTr ZEN_UNUSED(IoResult, NumberOfBytesTransferred); - ZEN_WARN("Unexpected I/O completion during async work! IoResult: {}, NumberOfBytesTransferred: {}", IoResult, NumberOfBytesTransferred); + ZEN_WARN("Unexpected I/O completion during async work! IoResult: {} ({:#x}), NumberOfBytesTransferred: {}", + GetSystemErrorAsString(IoResult), + IoResult, + NumberOfBytesTransferred); return this; } @@ -1017,8 +1064,10 @@ HttpSysServer::~HttpSysServer() ZEN_ERROR("~HttpSysServer() called without calling Close() first"); } - delete m_AsyncWorkPool; + auto WorkPool = m_AsyncWorkPool.load(std::memory_order_relaxed); m_AsyncWorkPool = nullptr; + + delete WorkPool; } void @@ -1049,7 +1098,10 @@ HttpSysServer::InitializeServer(int BasePort) if (Result != NO_ERROR) { - ZEN_ERROR("Failed to create server session for '{}': {:#x}", WideToUtf8(WildcardUrlPath), Result); + ZEN_ERROR("Failed to create server session for '{}': {} ({:#x})", + WideToUtf8(WildcardUrlPath), + GetSystemErrorAsString(Result), + Result); return 0; } @@ -1058,7 +1110,7 @@ HttpSysServer::InitializeServer(int BasePort) if (Result != NO_ERROR) { - ZEN_ERROR("Failed to create URL group for '{}': {:#x}", WideToUtf8(WildcardUrlPath), Result); + ZEN_ERROR("Failed to create URL group for '{}': {} ({:#x})", WideToUtf8(WildcardUrlPath), GetSystemErrorAsString(Result), Result); return 0; } @@ -1082,7 +1134,9 @@ HttpSysServer::InitializeServer(int BasePort) if ((Result == ERROR_SHARING_VIOLATION)) { - ZEN_INFO("Desired port {} is in use (HttpAddUrlToUrlGroup returned: {}), retrying", EffectivePort, Result); + ZEN_INFO("Desired port {} is in use (HttpAddUrlToUrlGroup returned: {}), retrying", + EffectivePort, + GetSystemErrorAsString(Result)); Sleep(500); Result = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, WildcardUrlPath.c_str(), HTTP_URL_CONTEXT(0), 0); @@ -1104,7 +1158,9 @@ HttpSysServer::InitializeServer(int BasePort) { for (uint32_t Retries = 0; (Result == ERROR_SHARING_VIOLATION) && (Retries < 3); Retries++) { - ZEN_INFO("Desired port {} is in use (HttpAddUrlToUrlGroup returned: {}), retrying", EffectivePort, Result); + ZEN_INFO("Desired port {} is in use (HttpAddUrlToUrlGroup returned: {}), retrying", + EffectivePort, + GetSystemErrorAsString(Result)); Sleep(500); Result = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, WildcardUrlPath.c_str(), HTTP_URL_CONTEXT(0), 0); } @@ -1128,25 +1184,29 @@ HttpSysServer::InitializeServer(int BasePort) // port for the current user. eg: // netsh http add urlacl url=http://*:8558/ user=<some_user> - ZEN_WARN( - "Unable to register handler using '{}' - falling back to local-only. " - "Please ensure the appropriate netsh URL reservation configuration " - "is made to allow http.sys access (see https://github.com/EpicGames/zen/blob/main/README.md)", - WideToUtf8(WildcardUrlPath)); + if (!m_InitialConfig.ForceLoopback) + { + ZEN_WARN( + "Unable to register handler using '{}' - falling back to local-only. " + "Please ensure the appropriate netsh URL reservation configuration " + "is made to allow http.sys access (see https://github.com/EpicGames/zen/blob/main/README.md)", + WideToUtf8(WildcardUrlPath)); + } const std::u8string_view Hosts[] = {u8"[::1]"sv, u8"localhost"sv, u8"127.0.0.1"sv}; - ULONG InternalResult = ERROR_SHARING_VIOLATION; - for (int PortOffset = 0; (InternalResult == ERROR_SHARING_VIOLATION) && (PortOffset < 10); ++PortOffset) + bool ShouldRetryNextPort = true; + for (int PortOffset = 0; ShouldRetryNextPort && (PortOffset < 10); ++PortOffset) { - EffectivePort = BasePort + (PortOffset * 100); + EffectivePort = BasePort + (PortOffset * 100); + ShouldRetryNextPort = false; for (const std::u8string_view Host : Hosts) { WideStringBuilder<64> LocalUrlPath; LocalUrlPath << u8"http://"sv << Host << u8":"sv << int64_t(EffectivePort) << u8"/"sv; - InternalResult = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, LocalUrlPath.c_str(), HTTP_URL_CONTEXT(0), 0); + ULONG InternalResult = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, LocalUrlPath.c_str(), HTTP_URL_CONTEXT(0), 0); if (InternalResult == NO_ERROR) { @@ -1154,11 +1214,25 @@ HttpSysServer::InitializeServer(int BasePort) m_BaseUris.push_back(LocalUrlPath.c_str()); } + else if (InternalResult == ERROR_SHARING_VIOLATION || InternalResult == ERROR_ACCESS_DENIED) + { + // Port may be owned by another process's wildcard registration (access denied) + // or actively in use (sharing violation) — retry on a different port + ShouldRetryNextPort = true; + } else { - break; + ZEN_WARN("Failed to register local handler '{}': {} ({:#x})", + WideToUtf8(LocalUrlPath), + GetSystemErrorAsString(InternalResult), + InternalResult); } } + + if (!m_BaseUris.empty()) + { + break; + } } } else @@ -1174,7 +1248,10 @@ HttpSysServer::InitializeServer(int BasePort) if (m_BaseUris.empty()) { - ZEN_ERROR("Failed to add base URL to URL group for '{}': {:#x}", WideToUtf8(WildcardUrlPath), Result); + ZEN_ERROR("Failed to add base URL to URL group for '{}': {} ({:#x})", + WideToUtf8(WildcardUrlPath), + GetSystemErrorAsString(Result), + Result); return 0; } @@ -1192,7 +1269,10 @@ HttpSysServer::InitializeServer(int BasePort) if (Result != NO_ERROR) { - ZEN_ERROR("Failed to create request queue for '{}': {:#x}", WideToUtf8(m_BaseUris.front()), Result); + ZEN_ERROR("Failed to create request queue for '{}': {} ({:#x})", + WideToUtf8(m_BaseUris.front()), + GetSystemErrorAsString(Result), + Result); return 0; } @@ -1204,7 +1284,10 @@ HttpSysServer::InitializeServer(int BasePort) if (Result != NO_ERROR) { - ZEN_ERROR("Failed to set server binding property for '{}': {:#x}", WideToUtf8(m_BaseUris.front()), Result); + ZEN_ERROR("Failed to set server binding property for '{}': {} ({:#x})", + WideToUtf8(m_BaseUris.front()), + GetSystemErrorAsString(Result), + Result); return 0; } @@ -1236,7 +1319,7 @@ HttpSysServer::InitializeServer(int BasePort) if (Result != NO_ERROR) { - ZEN_WARN("changing request queue length to {} failed: {}", QueueLength, Result); + ZEN_WARN("changing request queue length to {} failed: {} ({:#x})", QueueLength, GetSystemErrorAsString(Result), Result); } } @@ -1258,21 +1341,6 @@ HttpSysServer::InitializeServer(int BasePort) ZEN_INFO("Started http.sys server at '{}'", WideToUtf8(m_BaseUris.front())); } - // This is not available in all Windows SDK versions so for now we can't use recently - // released functionality. We should investigate how to get more recent SDK releases - // into the build - -# if 0 - if (HttpIsFeatureSupported(/* HttpFeatureHttp3 */ (HTTP_FEATURE_ID) 4)) - { - ZEN_DEBUG("HTTP3 is available"); - } - else - { - ZEN_DEBUG("HTTP3 is NOT available"); - } -# endif - return EffectivePort; } @@ -1305,17 +1373,17 @@ HttpSysServer::WorkPool() { ZEN_MEMSCOPE(GetHttpsysTag()); - if (!m_AsyncWorkPool) + if (!m_AsyncWorkPool.load(std::memory_order_acquire)) { RwLock::ExclusiveLockScope _(m_AsyncWorkPoolInitLock); - if (!m_AsyncWorkPool) + if (!m_AsyncWorkPool.load(std::memory_order_relaxed)) { - m_AsyncWorkPool = new WorkerThreadPool(m_InitialConfig.AsyncWorkThreadCount, "http_async"); + m_AsyncWorkPool.store(new WorkerThreadPool(m_InitialConfig.AsyncWorkThreadCount, "http_async"), std::memory_order_release); } } - return *m_AsyncWorkPool; + return *m_AsyncWorkPool.load(std::memory_order_relaxed); } void @@ -1337,9 +1405,9 @@ HttpSysServer::OnRun(bool IsInteractive) ZEN_CONSOLE("Zen Server running (http.sys). Press ESC or Q to quit"); } + bool ShutdownRequested = false; do { - // int WaitTimeout = -1; int WaitTimeout = 100; if (IsInteractive) @@ -1352,14 +1420,15 @@ HttpSysServer::OnRun(bool IsInteractive) if (c == 27 || c == 'Q' || c == 'q') { + m_ShutdownEvent.Set(); RequestApplicationExit(0); } } } - m_ShutdownEvent.Wait(WaitTimeout); + ShutdownRequested = m_ShutdownEvent.Wait(WaitTimeout); UpdateLofreqTimerValue(); - } while (!IsApplicationExitRequested()); + } while (!ShutdownRequested); } void @@ -1530,7 +1599,23 @@ HttpSysTransaction::IoCompletionCallback(PTP_CALLBACK_INSTANCE Instance, // than one thread at any given moment. This means we need to be careful about what // happens in here - HttpSysTransaction* Transaction = CONTAINING_RECORD(pOverlapped, HttpSysTransaction, m_HttpOverlapped); + HttpSysIoContext* IoContext = CONTAINING_RECORD(pOverlapped, HttpSysIoContext, Overlapped); + + switch (IoContext->ContextType) + { + case HttpSysIoContext::Type::kWebSocketRead: + static_cast<WsHttpSysConnection*>(IoContext->Owner)->OnReadCompletion(IoResult, NumberOfBytesTransferred); + return; + + case HttpSysIoContext::Type::kWebSocketWrite: + static_cast<WsHttpSysConnection*>(IoContext->Owner)->OnWriteCompletion(IoResult, NumberOfBytesTransferred); + return; + + case HttpSysIoContext::Type::kTransaction: + break; + } + + HttpSysTransaction* Transaction = CONTAINING_RECORD(IoContext, HttpSysTransaction, m_IoContext); if (Transaction->HandleCompletion(IoResult, NumberOfBytesTransferred) == HttpSysTransaction::Status::kDone) { @@ -1641,6 +1726,8 @@ HttpSysTransaction::InvokeRequestHandler(HttpService& Service, IoBuffer Payload) { HttpSysServerRequest& ThisRequest = m_HandlerRequest.emplace(*this, Service, Payload); + m_HttpServer.MarkRequest(); + // Default request handling # if ZEN_WITH_OTEL @@ -1666,9 +1753,21 @@ HttpSysTransaction::InvokeRequestHandler(HttpService& Service, IoBuffer Payload) otel::ScopedSpan HttpSpan(SpanNamer, SpanAnnotator); # endif - if (!HandlePackageOffers(Service, ThisRequest, m_PackageHandler)) + IHttpRequestFilter::Result FilterResult = m_HttpServer.FilterRequest(ThisRequest); + if (FilterResult == IHttpRequestFilter::Result::Accepted) + { + if (!HandlePackageOffers(Service, ThisRequest, m_PackageHandler)) + { + Service.HandleRequest(ThisRequest); + } + } + else if (FilterResult == IHttpRequestFilter::Result::Forbidden) + { + ThisRequest.WriteResponse(HttpResponseCode::Forbidden); + } + else { - Service.HandleRequest(ThisRequest); + ZEN_ASSERT(FilterResult == IHttpRequestFilter::Result::ResponseSent); } return ThisRequest; @@ -1810,6 +1909,52 @@ HttpSysServerRequest::ParseRequestId() const return 0; } +bool +HttpSysServerRequest::IsLocalMachineRequest() const +{ + const PSOCKADDR LocalAddress = m_HttpTx.HttpRequest()->Address.pLocalAddress; + const PSOCKADDR RemoteAddress = m_HttpTx.HttpRequest()->Address.pRemoteAddress; + if (LocalAddress->sa_family != RemoteAddress->sa_family) + { + return false; + } + if (LocalAddress->sa_family == AF_INET) + { + const SOCKADDR_IN& LocalAddressv4 = (const SOCKADDR_IN&)(*LocalAddress); + const SOCKADDR_IN& RemoteAddressv4 = (const SOCKADDR_IN&)(*RemoteAddress); + return LocalAddressv4.sin_addr.S_un.S_addr == RemoteAddressv4.sin_addr.S_un.S_addr; + } + else if (LocalAddress->sa_family == AF_INET6) + { + const SOCKADDR_IN6& LocalAddressv6 = (const SOCKADDR_IN6&)(*LocalAddress); + const SOCKADDR_IN6& RemoteAddressv6 = (const SOCKADDR_IN6&)(*RemoteAddress); + return memcmp(&LocalAddressv6.sin6_addr, &RemoteAddressv6.sin6_addr, sizeof(in6_addr)) == 0; + } + else + { + return false; + } +} + +std::string_view +HttpSysServerRequest::GetRemoteAddress() const +{ + if (m_RemoteAddress.Size() == 0) + { + const SOCKADDR* SockAddr = m_HttpTx.HttpRequest()->Address.pRemoteAddress; + GetAddressString(m_RemoteAddress, SockAddr, /* IncludePort */ false); + } + return m_RemoteAddress.ToView(); +} + +std::string_view +HttpSysServerRequest::GetAuthorizationHeader() const +{ + const HTTP_REQUEST* HttpRequestPtr = m_HttpTx.HttpRequest(); + const HTTP_KNOWN_HEADER& AuthorizationHeader = HttpRequestPtr->Headers.KnownHeaders[HttpHeaderAuthorization]; + return std::string_view(AuthorizationHeader.pRawValue, AuthorizationHeader.RawValueLength); +} + IoBuffer HttpSysServerRequest::ReadPayload() { @@ -1823,7 +1968,7 @@ HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode) ZEN_ASSERT(IsHandled() == false); - auto Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)ResponseCode); + HttpMessageResponseRequest* Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)ResponseCode); if (SuppressBody()) { @@ -1841,6 +1986,7 @@ HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode) # endif SetIsHandled(); + LogRequest(Response); } void @@ -1850,7 +1996,7 @@ HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentTy ZEN_ASSERT(IsHandled() == false); - auto Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)ResponseCode, ContentType, Blobs); + HttpMessageResponseRequest* Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)ResponseCode, ContentType, Blobs); if (SuppressBody()) { @@ -1868,6 +2014,20 @@ HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentTy # endif SetIsHandled(); + LogRequest(Response); +} + +void +HttpSysServerRequest::LogRequest(HttpMessageResponseRequest* Response) +{ + if (ShouldLogRequest()) + { + ZEN_INFO("{} {} {} -> {}", + ToString(RequestVerb()), + m_UriUtf8.c_str(), + Response->GetResponseCode(), + NiceBytes(Response->GetResponseBodySize())); + } } void @@ -1896,6 +2056,7 @@ HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentTy # endif SetIsHandled(); + LogRequest(Response); } void @@ -2015,6 +2176,8 @@ InitialRequestHandler::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesT break; } + Transaction().Server().m_TotalBytesReceived.fetch_add(NumberOfBytesTransferred, std::memory_order_relaxed); + ZEN_TRACE_CPU("httpsys::HandleCompletion"); // Route request @@ -2023,64 +2186,122 @@ InitialRequestHandler::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesT { HTTP_REQUEST* HttpReq = HttpRequest(); -# if 0 - for (int i = 0; i < HttpReq->RequestInfoCount; ++i) + if (HttpService* Service = reinterpret_cast<HttpService*>(HttpReq->UrlContext)) { - auto& ReqInfo = HttpReq->pRequestInfo[i]; - - switch (ReqInfo.InfoType) + // WebSocket upgrade detection + if (m_IsInitialRequest) { - case HttpRequestInfoTypeRequestTiming: + const HTTP_KNOWN_HEADER& UpgradeHeader = HttpReq->Headers.KnownHeaders[HttpHeaderUpgrade]; + if (UpgradeHeader.RawValueLength > 0 && + StrCaseCompare(UpgradeHeader.pRawValue, "websocket", UpgradeHeader.RawValueLength) == 0) + { + if (IWebSocketHandler* WsHandler = dynamic_cast<IWebSocketHandler*>(Service)) { - const HTTP_REQUEST_TIMING_INFO* TimingInfo = reinterpret_cast<HTTP_REQUEST_TIMING_INFO*>(ReqInfo.pInfo); + // Extract Sec-WebSocket-Key from the unknown headers + // (http.sys has no known-header slot for it) + std::string_view SecWebSocketKey; + for (USHORT i = 0; i < HttpReq->Headers.UnknownHeaderCount; ++i) + { + const HTTP_UNKNOWN_HEADER& Hdr = HttpReq->Headers.pUnknownHeaders[i]; + if (Hdr.NameLength == 17 && _strnicmp(Hdr.pName, "Sec-WebSocket-Key", 17) == 0) + { + SecWebSocketKey = std::string_view(Hdr.pRawValue, Hdr.RawValueLength); + break; + } + } - ZEN_INFO(""); - } - break; - case HttpRequestInfoTypeAuth: - ZEN_INFO(""); - break; - case HttpRequestInfoTypeChannelBind: - ZEN_INFO(""); - break; - case HttpRequestInfoTypeSslProtocol: - ZEN_INFO(""); - break; - case HttpRequestInfoTypeSslTokenBindingDraft: - ZEN_INFO(""); - break; - case HttpRequestInfoTypeSslTokenBinding: - ZEN_INFO(""); - break; - case HttpRequestInfoTypeTcpInfoV0: - { - const TCP_INFO_v0* TcpInfo = reinterpret_cast<const TCP_INFO_v0*>(ReqInfo.pInfo); + if (SecWebSocketKey.empty()) + { + ZEN_WARN("WebSocket upgrade missing Sec-WebSocket-Key header"); + return nullptr; + } - ZEN_INFO(""); - } - break; - case HttpRequestInfoTypeRequestSizing: - { - const HTTP_REQUEST_SIZING_INFO* SizingInfo = reinterpret_cast<const HTTP_REQUEST_SIZING_INFO*>(ReqInfo.pInfo); - ZEN_INFO(""); - } - break; - case HttpRequestInfoTypeQuicStats: - ZEN_INFO(""); - break; - case HttpRequestInfoTypeTcpInfoV1: - { - const TCP_INFO_v1* TcpInfo = reinterpret_cast<const TCP_INFO_v1*>(ReqInfo.pInfo); + const std::string AcceptKey = WsFrameCodec::ComputeAcceptKey(SecWebSocketKey); + + HANDLE RequestQueueHandle = Transaction().RequestQueueHandle(); + HTTP_REQUEST_ID RequestId = HttpReq->RequestId; + + // Build the 101 Switching Protocols response + HTTP_RESPONSE Response = {}; + Response.StatusCode = 101; + Response.pReason = "Switching Protocols"; + Response.ReasonLength = (USHORT)strlen(Response.pReason); + + Response.Headers.KnownHeaders[HttpHeaderUpgrade].pRawValue = "websocket"; + Response.Headers.KnownHeaders[HttpHeaderUpgrade].RawValueLength = 9; + + eastl::fixed_vector<HTTP_UNKNOWN_HEADER, 8> UnknownHeaders; - ZEN_INFO(""); + // IMPORTANT: Due to some quirk in HttpSendHttpResponse, this cannot use KnownHeaders + // despite there being an entry for it there (HttpHeaderConnection). If you try to do + // that you get an ERROR_INVALID_PARAMETERS error from HttpSendHttpResponse below + + UnknownHeaders.push_back({.NameLength = 10, .RawValueLength = 7, .pName = "Connection", .pRawValue = "Upgrade"}); + + UnknownHeaders.push_back({.NameLength = 20, + .RawValueLength = (USHORT)AcceptKey.size(), + .pName = "Sec-WebSocket-Accept", + .pRawValue = AcceptKey.c_str()}); + + Response.Headers.UnknownHeaderCount = (USHORT)UnknownHeaders.size(); + Response.Headers.pUnknownHeaders = UnknownHeaders.data(); + + const ULONG Flags = HTTP_SEND_RESPONSE_FLAG_OPAQUE | HTTP_SEND_RESPONSE_FLAG_MORE_DATA; + + // Use an OVERLAPPED with an event so we can wait synchronously. + // The request queue is IOCP-associated, so passing NULL for pOverlapped + // may return ERROR_IO_PENDING. Setting the low-order bit of hEvent + // prevents IOCP delivery and lets us wait on the event directly. + OVERLAPPED SendOverlapped = {}; + HANDLE SendEvent = CreateEventW(nullptr, TRUE, FALSE, nullptr); + SendOverlapped.hEvent = (HANDLE)((uintptr_t)SendEvent | 1); + + ULONG SendResult = HttpSendHttpResponse(RequestQueueHandle, + RequestId, + Flags, + &Response, + nullptr, // CachePolicy + nullptr, // BytesSent + nullptr, // Reserved1 + 0, // Reserved2 + &SendOverlapped, + nullptr // LogData + ); + + if (SendResult == ERROR_IO_PENDING) + { + WaitForSingleObject(SendEvent, INFINITE); + SendResult = (SendOverlapped.Internal == 0) ? NO_ERROR : ERROR_IO_INCOMPLETE; + } + + CloseHandle(SendEvent); + + if (SendResult == NO_ERROR) + { + Transaction().Server().OnWebSocketConnectionOpened(); + Ref<WsHttpSysConnection> WsConn(new WsHttpSysConnection(RequestQueueHandle, + RequestId, + *WsHandler, + Transaction().Iocp(), + &Transaction().Server())); + Ref<WebSocketConnection> WsConnRef(WsConn.Get()); + + WsHandler->OnWebSocketOpen(std::move(WsConnRef)); + WsConn->Start(); + + return nullptr; + } + + ZEN_WARN("WebSocket 101 send failed: {} ({:#x})", GetSystemErrorAsString(SendResult), SendResult); + + // WebSocket upgrade failed — return nullptr since ServerRequest() + // was never populated (no InvokeRequestHandler call) + return nullptr; } - break; + // Service doesn't support WebSocket or missing key — fall through to normal handling + } } - } -# endif - if (HttpService* Service = reinterpret_cast<HttpService*>(HttpReq->UrlContext)) - { if (m_IsInitialRequest) { m_ContentLength = GetContentLength(HttpReq); @@ -2146,6 +2367,18 @@ InitialRequestHandler::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesT return new HttpMessageResponseRequest(Transaction(), 404, "Not found"sv); } } + else + { + // If a default redirect is configured and the request is for the root path, send a 302 + std::string_view DefaultRedirect = Transaction().Server().GetDefaultRedirect(); + std::string_view RawUrl(HttpReq->pRawUrl, HttpReq->RawUrlLength); + if (!DefaultRedirect.empty() && (RawUrl == "/" || RawUrl.empty())) + { + auto* Response = new HttpMessageResponseRequest(Transaction(), 302); + Response->SetLocationHeader(DefaultRedirect); + return Response; + } + } // Unable to route return new HttpMessageResponseRequest(Transaction(), 404, "No suitable route found"sv); @@ -2205,12 +2438,81 @@ HttpSysServer::OnRequestExit() m_ShutdownEvent.Set(); } +std::string +HttpSysServer::OnGetExternalHost() const +{ + // Check whether we registered a public wildcard URL (http://*:port/) or fell back to loopback + bool IsPublic = false; + for (const auto& Uri : m_BaseUris) + { + if (Uri.find(L'*') != std::wstring::npos) + { + IsPublic = true; + break; + } + } + + if (!IsPublic) + { + return "127.0.0.1"; + } + + // Use the UDP connect trick: connecting a UDP socket to an external address + // causes the OS to select the appropriate local interface without sending any data. + try + { + asio::io_service IoService; + asio::ip::udp::socket Sock(IoService, asio::ip::udp::v4()); + Sock.connect(asio::ip::udp::endpoint(asio::ip::address::from_string("8.8.8.8"), 80)); + return Sock.local_endpoint().address().to_string(); + } + catch (const std::exception&) + { + return GetMachineName(); + } +} + +uint64_t +HttpSysServer::GetTotalBytesReceived() const +{ + return m_TotalBytesReceived.load(std::memory_order_relaxed); +} + +uint64_t +HttpSysServer::GetTotalBytesSent() const +{ + return m_TotalBytesSent.load(std::memory_order_relaxed); +} + void HttpSysServer::OnRegisterService(HttpService& Service) { RegisterService(Service.BaseUri(), Service); } +void +HttpSysServer::OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) +{ + RwLock::ExclusiveLockScope _(m_RequestFilterLock); + m_HttpRequestFilter.store(RequestFilter); +} + +IHttpRequestFilter::Result +HttpSysServer::FilterRequest(HttpSysServerRequest& Request) +{ + if (!m_HttpRequestFilter.load()) + { + return IHttpRequestFilter::Result::Accepted; + } + RwLock::SharedLockScope _(m_RequestFilterLock); + IHttpRequestFilter* RequestFilter = m_HttpRequestFilter.load(); + if (!RequestFilter) + { + return IHttpRequestFilter::Result::Accepted; + } + return RequestFilter->FilterRequest(Request); +} + Ref<HttpServer> CreateHttpSysServer(HttpSysConfig Config) { diff --git a/src/zenhttp/servers/httpsys_iocontext.h b/src/zenhttp/servers/httpsys_iocontext.h new file mode 100644 index 000000000..4f8a97012 --- /dev/null +++ b/src/zenhttp/servers/httpsys_iocontext.h @@ -0,0 +1,40 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#if ZEN_WITH_HTTPSYS +# define _WINSOCKAPI_ +# include <zencore/windows.h> + +# include <cstdint> + +namespace zen { + +/** + * Tagged OVERLAPPED wrapper for http.sys IOCP dispatch + * + * Both HttpSysTransaction (for normal HTTP request I/O) and WsHttpSysConnection + * (for WebSocket read/write) embed this struct. The single IoCompletionCallback + * bound to the request queue uses the ContextType tag to dispatch to the correct + * handler. + * + * The Overlapped member must be first so that CONTAINING_RECORD works to recover + * the HttpSysIoContext from the OVERLAPPED pointer provided by the threadpool. + */ +struct HttpSysIoContext +{ + OVERLAPPED Overlapped{}; + + enum class Type : uint8_t + { + kTransaction, + kWebSocketRead, + kWebSocketWrite, + } ContextType = Type::kTransaction; + + void* Owner = nullptr; +}; + +} // namespace zen + +#endif // ZEN_WITH_HTTPSYS diff --git a/src/zenhttp/servers/httptracer.h b/src/zenhttp/servers/httptracer.h index da72c79c9..a9a45f162 100644 --- a/src/zenhttp/servers/httptracer.h +++ b/src/zenhttp/servers/httptracer.h @@ -1,9 +1,9 @@ // Copyright Epic Games, Inc. All Rights Reserved. -#include <zenhttp/httpserver.h> - #pragma once +#include <zenhttp/httpserver.h> + namespace zen { /** Helper class for HTTP server implementations diff --git a/src/zenhttp/servers/wsasio.cpp b/src/zenhttp/servers/wsasio.cpp new file mode 100644 index 000000000..b2543277a --- /dev/null +++ b/src/zenhttp/servers/wsasio.cpp @@ -0,0 +1,311 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "wsasio.h" +#include "wsframecodec.h" + +#include <zencore/logging.h> +#include <zenhttp/httpserver.h> + +namespace zen::asio_http { + +static LoggerRef +WsLog() +{ + static LoggerRef g_Logger = logging::Get("ws"); + return g_Logger; +} + +////////////////////////////////////////////////////////////////////////// + +WsAsioConnection::WsAsioConnection(std::unique_ptr<asio::ip::tcp::socket> Socket, IWebSocketHandler& Handler, HttpServer* Server) +: m_Socket(std::move(Socket)) +, m_Handler(Handler) +, m_HttpServer(Server) +{ +} + +WsAsioConnection::~WsAsioConnection() +{ + m_IsOpen.store(false); + if (m_HttpServer) + { + m_HttpServer->OnWebSocketConnectionClosed(); + } +} + +void +WsAsioConnection::Start() +{ + EnqueueRead(); +} + +bool +WsAsioConnection::IsOpen() const +{ + return m_IsOpen.load(std::memory_order_relaxed); +} + +////////////////////////////////////////////////////////////////////////// +// +// Read loop +// + +void +WsAsioConnection::EnqueueRead() +{ + if (!m_IsOpen.load(std::memory_order_relaxed)) + { + return; + } + + Ref<WsAsioConnection> Self(this); + + asio::async_read(*m_Socket, m_ReadBuffer, asio::transfer_at_least(1), [Self](const asio::error_code& Ec, std::size_t ByteCount) { + Self->OnDataReceived(Ec, ByteCount); + }); +} + +void +WsAsioConnection::OnDataReceived(const asio::error_code& Ec, [[maybe_unused]] std::size_t ByteCount) +{ + if (Ec) + { + if (Ec != asio::error::eof && Ec != asio::error::operation_aborted) + { + ZEN_LOG_DEBUG(WsLog(), "WebSocket read error: {}", Ec.message()); + } + + if (m_IsOpen.exchange(false)) + { + m_Handler.OnWebSocketClose(*this, 1006, "connection lost"); + } + return; + } + + ProcessReceivedData(); + + if (m_IsOpen.load(std::memory_order_relaxed)) + { + EnqueueRead(); + } +} + +void +WsAsioConnection::ProcessReceivedData() +{ + while (m_ReadBuffer.size() > 0) + { + const auto& InputBuffer = m_ReadBuffer.data(); + const auto* Data = static_cast<const uint8_t*>(InputBuffer.data()); + const auto Size = InputBuffer.size(); + + WsFrameParseResult Frame = WsFrameCodec::TryParseFrame(Data, Size); + if (!Frame.IsValid) + { + break; // not enough data yet + } + + m_ReadBuffer.consume(Frame.BytesConsumed); + + if (m_HttpServer) + { + m_HttpServer->OnWebSocketFrameReceived(Frame.BytesConsumed); + } + + switch (Frame.Opcode) + { + case WebSocketOpcode::kText: + case WebSocketOpcode::kBinary: + { + WebSocketMessage Msg; + Msg.Opcode = Frame.Opcode; + Msg.Payload = IoBuffer(IoBuffer::Clone, Frame.Payload.data(), Frame.Payload.size()); + m_Handler.OnWebSocketMessage(*this, Msg); + break; + } + + case WebSocketOpcode::kPing: + { + // Auto-respond with pong carrying the same payload + std::vector<uint8_t> PongFrame = WsFrameCodec::BuildFrame(WebSocketOpcode::kPong, Frame.Payload); + EnqueueWrite(std::move(PongFrame)); + break; + } + + case WebSocketOpcode::kPong: + // Unsolicited pong — ignore per RFC 6455 + break; + + case WebSocketOpcode::kClose: + { + uint16_t Code = 1000; + std::string_view Reason; + + if (Frame.Payload.size() >= 2) + { + Code = (uint16_t(Frame.Payload[0]) << 8) | uint16_t(Frame.Payload[1]); + if (Frame.Payload.size() > 2) + { + Reason = std::string_view(reinterpret_cast<const char*>(Frame.Payload.data() + 2), Frame.Payload.size() - 2); + } + } + + // Echo close frame back if we haven't sent one yet + if (!m_CloseSent.exchange(true)) + { + std::vector<uint8_t> CloseFrame = WsFrameCodec::BuildCloseFrame(Code); + EnqueueWrite(std::move(CloseFrame)); + } + + m_IsOpen.store(false); + m_Handler.OnWebSocketClose(*this, Code, Reason); + + // Shut down the socket + std::error_code ShutdownEc; + m_Socket->shutdown(asio::socket_base::shutdown_both, ShutdownEc); + m_Socket->close(ShutdownEc); + return; + } + + default: + ZEN_LOG_WARN(WsLog(), "Unknown WebSocket opcode: {:#x}", static_cast<uint8_t>(Frame.Opcode)); + break; + } + } +} + +////////////////////////////////////////////////////////////////////////// +// +// Write queue +// + +void +WsAsioConnection::SendText(std::string_view Text) +{ + if (!m_IsOpen.load(std::memory_order_relaxed)) + { + return; + } + + std::span<const uint8_t> Payload(reinterpret_cast<const uint8_t*>(Text.data()), Text.size()); + std::vector<uint8_t> Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kText, Payload); + EnqueueWrite(std::move(Frame)); +} + +void +WsAsioConnection::SendBinary(std::span<const uint8_t> Data) +{ + if (!m_IsOpen.load(std::memory_order_relaxed)) + { + return; + } + + std::vector<uint8_t> Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kBinary, Data); + EnqueueWrite(std::move(Frame)); +} + +void +WsAsioConnection::Close(uint16_t Code, std::string_view Reason) +{ + DoClose(Code, Reason); +} + +void +WsAsioConnection::DoClose(uint16_t Code, std::string_view Reason) +{ + if (!m_IsOpen.exchange(false)) + { + return; + } + + if (!m_CloseSent.exchange(true)) + { + std::vector<uint8_t> CloseFrame = WsFrameCodec::BuildCloseFrame(Code, Reason); + EnqueueWrite(std::move(CloseFrame)); + } + + m_Handler.OnWebSocketClose(*this, Code, Reason); +} + +void +WsAsioConnection::EnqueueWrite(std::vector<uint8_t> Frame) +{ + if (m_HttpServer) + { + m_HttpServer->OnWebSocketFrameSent(Frame.size()); + } + + bool ShouldFlush = false; + + m_WriteLock.WithExclusiveLock([&] { + m_WriteQueue.push_back(std::move(Frame)); + if (!m_IsWriting) + { + m_IsWriting = true; + ShouldFlush = true; + } + }); + + if (ShouldFlush) + { + FlushWriteQueue(); + } +} + +void +WsAsioConnection::FlushWriteQueue() +{ + std::vector<uint8_t> Frame; + + m_WriteLock.WithExclusiveLock([&] { + if (m_WriteQueue.empty()) + { + m_IsWriting = false; + return; + } + Frame = std::move(m_WriteQueue.front()); + m_WriteQueue.pop_front(); + }); + + if (Frame.empty()) + { + return; + } + + Ref<WsAsioConnection> Self(this); + + // Move Frame into a shared_ptr so we can create the buffer and capture ownership + // in the same async_write call without evaluation order issues. + auto OwnedFrame = std::make_shared<std::vector<uint8_t>>(std::move(Frame)); + + asio::async_write(*m_Socket, + asio::buffer(OwnedFrame->data(), OwnedFrame->size()), + [Self, OwnedFrame](const asio::error_code& Ec, std::size_t ByteCount) { Self->OnWriteComplete(Ec, ByteCount); }); +} + +void +WsAsioConnection::OnWriteComplete(const asio::error_code& Ec, [[maybe_unused]] std::size_t ByteCount) +{ + if (Ec) + { + if (Ec != asio::error::operation_aborted) + { + ZEN_LOG_DEBUG(WsLog(), "WebSocket write error: {}", Ec.message()); + } + + m_WriteLock.WithExclusiveLock([&] { + m_IsWriting = false; + m_WriteQueue.clear(); + }); + + if (m_IsOpen.exchange(false)) + { + m_Handler.OnWebSocketClose(*this, 1006, "write error"); + } + return; + } + + FlushWriteQueue(); +} + +} // namespace zen::asio_http diff --git a/src/zenhttp/servers/wsasio.h b/src/zenhttp/servers/wsasio.h new file mode 100644 index 000000000..e8bb3b1d2 --- /dev/null +++ b/src/zenhttp/servers/wsasio.h @@ -0,0 +1,77 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zenhttp/websocket.h> + +#include <zencore/thread.h> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <asio.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + +#include <deque> +#include <memory> +#include <vector> + +namespace zen { +class HttpServer; +} // namespace zen + +namespace zen::asio_http { + +/** + * WebSocket connection over an ASIO TCP socket + * + * Owns the TCP socket (moved from HttpServerConnection after the 101 handshake) + * and runs an async read/write loop to exchange WebSocket frames. + * + * Lifetime is managed solely through intrusive reference counting (RefCounted). + * The async read/write callbacks capture Ref<WsAsioConnection> to keep the + * connection alive for the duration of the async operation. The service layer + * also holds a Ref<WebSocketConnection>. + */ + +class WsAsioConnection : public WebSocketConnection +{ +public: + WsAsioConnection(std::unique_ptr<asio::ip::tcp::socket> Socket, IWebSocketHandler& Handler, HttpServer* Server); + ~WsAsioConnection() override; + + /** + * Start the async read loop. Must be called once after construction + * and the 101 response has been sent. + */ + void Start(); + + // WebSocketConnection interface + void SendText(std::string_view Text) override; + void SendBinary(std::span<const uint8_t> Data) override; + void Close(uint16_t Code, std::string_view Reason) override; + bool IsOpen() const override; + +private: + void EnqueueRead(); + void OnDataReceived(const asio::error_code& Ec, std::size_t ByteCount); + void ProcessReceivedData(); + + void EnqueueWrite(std::vector<uint8_t> Frame); + void FlushWriteQueue(); + void OnWriteComplete(const asio::error_code& Ec, std::size_t ByteCount); + + void DoClose(uint16_t Code, std::string_view Reason); + + std::unique_ptr<asio::ip::tcp::socket> m_Socket; + IWebSocketHandler& m_Handler; + zen::HttpServer* m_HttpServer; + asio::streambuf m_ReadBuffer; + + RwLock m_WriteLock; + std::deque<std::vector<uint8_t>> m_WriteQueue; + bool m_IsWriting = false; + + std::atomic<bool> m_IsOpen{true}; + std::atomic<bool> m_CloseSent{false}; +}; + +} // namespace zen::asio_http diff --git a/src/zenhttp/servers/wsframecodec.cpp b/src/zenhttp/servers/wsframecodec.cpp new file mode 100644 index 000000000..e452141fe --- /dev/null +++ b/src/zenhttp/servers/wsframecodec.cpp @@ -0,0 +1,236 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "wsframecodec.h" + +#include <zencore/base64.h> +#include <zencore/sha1.h> + +#include <cstring> +#include <random> + +namespace zen { + +////////////////////////////////////////////////////////////////////////// +// +// Frame parsing +// + +WsFrameParseResult +WsFrameCodec::TryParseFrame(const uint8_t* Data, size_t Size) +{ + // Minimum frame: 2 bytes header (unmasked server frames) or 6 bytes (masked client frames) + if (Size < 2) + { + return {}; + } + + const bool Fin = (Data[0] & 0x80) != 0; + const uint8_t OpcodeRaw = Data[0] & 0x0F; + const bool Masked = (Data[1] & 0x80) != 0; + uint64_t PayloadLen = Data[1] & 0x7F; + + size_t HeaderSize = 2; + + if (PayloadLen == 126) + { + if (Size < 4) + { + return {}; + } + PayloadLen = (uint64_t(Data[2]) << 8) | uint64_t(Data[3]); + HeaderSize = 4; + } + else if (PayloadLen == 127) + { + if (Size < 10) + { + return {}; + } + PayloadLen = (uint64_t(Data[2]) << 56) | (uint64_t(Data[3]) << 48) | (uint64_t(Data[4]) << 40) | (uint64_t(Data[5]) << 32) | + (uint64_t(Data[6]) << 24) | (uint64_t(Data[7]) << 16) | (uint64_t(Data[8]) << 8) | uint64_t(Data[9]); + HeaderSize = 10; + } + + // Reject frames with unreasonable payload sizes to prevent OOM + static constexpr uint64_t kMaxPayloadSize = 256 * 1024 * 1024; // 256 MB + if (PayloadLen > kMaxPayloadSize) + { + return {}; + } + + const size_t MaskSize = Masked ? 4 : 0; + const size_t TotalFrame = HeaderSize + MaskSize + PayloadLen; + + if (Size < TotalFrame) + { + return {}; + } + + const uint8_t* MaskKey = Masked ? (Data + HeaderSize) : nullptr; + const uint8_t* PayloadData = Data + HeaderSize + MaskSize; + + WsFrameParseResult Result; + Result.IsValid = true; + Result.BytesConsumed = TotalFrame; + Result.Opcode = static_cast<WebSocketOpcode>(OpcodeRaw); + Result.Fin = Fin; + + Result.Payload.resize(static_cast<size_t>(PayloadLen)); + if (PayloadLen > 0) + { + std::memcpy(Result.Payload.data(), PayloadData, static_cast<size_t>(PayloadLen)); + + if (Masked) + { + for (size_t i = 0; i < Result.Payload.size(); ++i) + { + Result.Payload[i] ^= MaskKey[i & 3]; + } + } + } + + return Result; +} + +////////////////////////////////////////////////////////////////////////// +// +// Frame building (server-to-client, no masking) +// + +std::vector<uint8_t> +WsFrameCodec::BuildFrame(WebSocketOpcode Opcode, std::span<const uint8_t> Payload) +{ + std::vector<uint8_t> Frame; + + const size_t PayloadLen = Payload.size(); + + // FIN + opcode + Frame.push_back(0x80 | static_cast<uint8_t>(Opcode)); + + // Payload length (no mask bit for server frames) + if (PayloadLen < 126) + { + Frame.push_back(static_cast<uint8_t>(PayloadLen)); + } + else if (PayloadLen <= 0xFFFF) + { + Frame.push_back(126); + Frame.push_back(static_cast<uint8_t>((PayloadLen >> 8) & 0xFF)); + Frame.push_back(static_cast<uint8_t>(PayloadLen & 0xFF)); + } + else + { + Frame.push_back(127); + for (int i = 7; i >= 0; --i) + { + Frame.push_back(static_cast<uint8_t>((PayloadLen >> (i * 8)) & 0xFF)); + } + } + + Frame.insert(Frame.end(), Payload.begin(), Payload.end()); + + return Frame; +} + +std::vector<uint8_t> +WsFrameCodec::BuildCloseFrame(uint16_t Code, std::string_view Reason) +{ + std::vector<uint8_t> Payload; + Payload.push_back(static_cast<uint8_t>((Code >> 8) & 0xFF)); + Payload.push_back(static_cast<uint8_t>(Code & 0xFF)); + Payload.insert(Payload.end(), Reason.begin(), Reason.end()); + + return BuildFrame(WebSocketOpcode::kClose, Payload); +} + +////////////////////////////////////////////////////////////////////////// +// +// Frame building (client-to-server, with masking) +// + +std::vector<uint8_t> +WsFrameCodec::BuildMaskedFrame(WebSocketOpcode Opcode, std::span<const uint8_t> Payload) +{ + std::vector<uint8_t> Frame; + + const size_t PayloadLen = Payload.size(); + + // FIN + opcode + Frame.push_back(0x80 | static_cast<uint8_t>(Opcode)); + + // Payload length with mask bit set + if (PayloadLen < 126) + { + Frame.push_back(0x80 | static_cast<uint8_t>(PayloadLen)); + } + else if (PayloadLen <= 0xFFFF) + { + Frame.push_back(0x80 | 126); + Frame.push_back(static_cast<uint8_t>((PayloadLen >> 8) & 0xFF)); + Frame.push_back(static_cast<uint8_t>(PayloadLen & 0xFF)); + } + else + { + Frame.push_back(0x80 | 127); + for (int i = 7; i >= 0; --i) + { + Frame.push_back(static_cast<uint8_t>((PayloadLen >> (i * 8)) & 0xFF)); + } + } + + // Generate random 4-byte mask key + static thread_local std::mt19937 s_Rng(std::random_device{}()); + uint32_t MaskValue = s_Rng(); + uint8_t MaskKey[4]; + std::memcpy(MaskKey, &MaskValue, 4); + + Frame.insert(Frame.end(), MaskKey, MaskKey + 4); + + // Masked payload + for (size_t i = 0; i < PayloadLen; ++i) + { + Frame.push_back(Payload[i] ^ MaskKey[i & 3]); + } + + return Frame; +} + +std::vector<uint8_t> +WsFrameCodec::BuildMaskedCloseFrame(uint16_t Code, std::string_view Reason) +{ + std::vector<uint8_t> Payload; + Payload.push_back(static_cast<uint8_t>((Code >> 8) & 0xFF)); + Payload.push_back(static_cast<uint8_t>(Code & 0xFF)); + Payload.insert(Payload.end(), Reason.begin(), Reason.end()); + + return BuildMaskedFrame(WebSocketOpcode::kClose, Payload); +} + +////////////////////////////////////////////////////////////////////////// +// +// Sec-WebSocket-Accept key computation (RFC 6455 section 4.2.2) +// + +static constexpr std::string_view kWebSocketMagicGuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; + +std::string +WsFrameCodec::ComputeAcceptKey(std::string_view ClientKey) +{ + // Concatenate client key with the magic GUID + std::string Combined; + Combined.reserve(ClientKey.size() + kWebSocketMagicGuid.size()); + Combined.append(ClientKey); + Combined.append(kWebSocketMagicGuid); + + // SHA1 hash + SHA1 Hash = SHA1::HashMemory(Combined.data(), Combined.size()); + + // Base64 encode the 20-byte hash + char Base64Buf[Base64::GetEncodedDataSize(20) + 1]; + uint32_t EncodedLen = Base64::Encode(Hash.Hash, 20, Base64Buf); + Base64Buf[EncodedLen] = '\0'; + + return std::string(Base64Buf, EncodedLen); +} + +} // namespace zen diff --git a/src/zenhttp/servers/wsframecodec.h b/src/zenhttp/servers/wsframecodec.h new file mode 100644 index 000000000..2d90b6fa1 --- /dev/null +++ b/src/zenhttp/servers/wsframecodec.h @@ -0,0 +1,74 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zenhttp/websocket.h> + +#include <cstddef> +#include <cstdint> +#include <optional> +#include <span> +#include <string> +#include <string_view> +#include <vector> + +namespace zen { + +/** + * Result of attempting to parse a single WebSocket frame from a byte buffer + */ +struct WsFrameParseResult +{ + bool IsValid = false; // true if a complete frame was successfully parsed + size_t BytesConsumed = 0; // number of bytes consumed from the input buffer + WebSocketOpcode Opcode = WebSocketOpcode::kText; + bool Fin = false; + std::vector<uint8_t> Payload; +}; + +/** + * RFC 6455 WebSocket frame codec + * + * Provides static helpers for parsing client-to-server frames (which are + * always masked) and building server-to-client frames (which are never masked). + */ +struct WsFrameCodec +{ + /** + * Try to parse one complete frame from the front of the buffer. + * + * Returns a result with IsValid == false and BytesConsumed == 0 when + * there is not enough data yet. The caller should accumulate more data + * and retry. + */ + static WsFrameParseResult TryParseFrame(const uint8_t* Data, size_t Size); + + /** + * Build a server-to-client frame (no masking) + */ + static std::vector<uint8_t> BuildFrame(WebSocketOpcode Opcode, std::span<const uint8_t> Payload); + + /** + * Build a close frame with a status code and optional reason string + */ + static std::vector<uint8_t> BuildCloseFrame(uint16_t Code, std::string_view Reason = {}); + + /** + * Build a client-to-server frame (with masking per RFC 6455) + */ + static std::vector<uint8_t> BuildMaskedFrame(WebSocketOpcode Opcode, std::span<const uint8_t> Payload); + + /** + * Build a masked close frame with status code and optional reason + */ + static std::vector<uint8_t> BuildMaskedCloseFrame(uint16_t Code, std::string_view Reason = {}); + + /** + * Compute the Sec-WebSocket-Accept value per RFC 6455 section 4.2.2 + * + * accept = Base64(SHA1(clientKey + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11")) + */ + static std::string ComputeAcceptKey(std::string_view ClientKey); +}; + +} // namespace zen diff --git a/src/zenhttp/servers/wshttpsys.cpp b/src/zenhttp/servers/wshttpsys.cpp new file mode 100644 index 000000000..af320172d --- /dev/null +++ b/src/zenhttp/servers/wshttpsys.cpp @@ -0,0 +1,485 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "wshttpsys.h" + +#if ZEN_WITH_HTTPSYS + +# include "wsframecodec.h" + +# include <zencore/logging.h> +# include <zenhttp/httpserver.h> + +namespace zen { + +static LoggerRef +WsHttpSysLog() +{ + static LoggerRef g_Logger = logging::Get("ws_httpsys"); + return g_Logger; +} + +////////////////////////////////////////////////////////////////////////// + +WsHttpSysConnection::WsHttpSysConnection(HANDLE RequestQueueHandle, + HTTP_REQUEST_ID RequestId, + IWebSocketHandler& Handler, + PTP_IO Iocp, + HttpServer* Server) +: m_RequestQueueHandle(RequestQueueHandle) +, m_RequestId(RequestId) +, m_Handler(Handler) +, m_Iocp(Iocp) +, m_HttpServer(Server) +, m_ReadBuffer(8192) +{ + m_ReadIoContext.ContextType = HttpSysIoContext::Type::kWebSocketRead; + m_ReadIoContext.Owner = this; + m_WriteIoContext.ContextType = HttpSysIoContext::Type::kWebSocketWrite; + m_WriteIoContext.Owner = this; +} + +WsHttpSysConnection::~WsHttpSysConnection() +{ + ZEN_ASSERT(m_OutstandingOps.load() == 0); + + if (m_IsOpen.exchange(false)) + { + Disconnect(); + } + + if (m_HttpServer) + { + m_HttpServer->OnWebSocketConnectionClosed(); + } +} + +void +WsHttpSysConnection::Start() +{ + m_SelfRef = Ref<WsHttpSysConnection>(this); + IssueAsyncRead(); +} + +void +WsHttpSysConnection::Shutdown() +{ + m_ShutdownRequested.store(true, std::memory_order_relaxed); + + if (!m_IsOpen.exchange(false)) + { + return; + } + + // Cancel pending I/O — completions will fire with ERROR_OPERATION_ABORTED + HttpCancelHttpRequest(m_RequestQueueHandle, m_RequestId, nullptr); +} + +bool +WsHttpSysConnection::IsOpen() const +{ + return m_IsOpen.load(std::memory_order_relaxed); +} + +////////////////////////////////////////////////////////////////////////// +// +// Async read path +// + +void +WsHttpSysConnection::IssueAsyncRead() +{ + if (!m_IsOpen.load(std::memory_order_relaxed) || m_ShutdownRequested.load(std::memory_order_relaxed)) + { + MaybeReleaseSelfRef(); + return; + } + + m_OutstandingOps.fetch_add(1, std::memory_order_relaxed); + + ZeroMemory(&m_ReadIoContext.Overlapped, sizeof(OVERLAPPED)); + + StartThreadpoolIo(m_Iocp); + + ULONG Result = HttpReceiveRequestEntityBody(m_RequestQueueHandle, + m_RequestId, + 0, // Flags + m_ReadBuffer.data(), + (ULONG)m_ReadBuffer.size(), + nullptr, // BytesRead (ignored for async) + &m_ReadIoContext.Overlapped); + + if (Result != NO_ERROR && Result != ERROR_IO_PENDING) + { + CancelThreadpoolIo(m_Iocp); + m_OutstandingOps.fetch_sub(1, std::memory_order_relaxed); + + if (m_IsOpen.exchange(false)) + { + m_Handler.OnWebSocketClose(*this, 1006, "read issue failed"); + } + + MaybeReleaseSelfRef(); + } +} + +void +WsHttpSysConnection::OnReadCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) +{ + // Hold a transient ref to prevent mid-callback destruction after MaybeReleaseSelfRef + Ref<WsHttpSysConnection> Guard(this); + + if (IoResult != NO_ERROR) + { + m_OutstandingOps.fetch_sub(1, std::memory_order_relaxed); + + if (m_IsOpen.exchange(false)) + { + if (IoResult == ERROR_HANDLE_EOF) + { + m_Handler.OnWebSocketClose(*this, 1006, "connection closed"); + } + else if (IoResult != ERROR_OPERATION_ABORTED) + { + m_Handler.OnWebSocketClose(*this, 1006, "connection lost"); + } + } + + MaybeReleaseSelfRef(); + return; + } + + if (NumberOfBytesTransferred > 0) + { + m_Accumulated.insert(m_Accumulated.end(), m_ReadBuffer.begin(), m_ReadBuffer.begin() + NumberOfBytesTransferred); + ProcessReceivedData(); + } + + m_OutstandingOps.fetch_sub(1, std::memory_order_relaxed); + + if (m_IsOpen.load(std::memory_order_relaxed)) + { + IssueAsyncRead(); + } + else + { + MaybeReleaseSelfRef(); + } +} + +////////////////////////////////////////////////////////////////////////// +// +// Frame parsing +// + +void +WsHttpSysConnection::ProcessReceivedData() +{ + while (!m_Accumulated.empty()) + { + WsFrameParseResult Frame = WsFrameCodec::TryParseFrame(m_Accumulated.data(), m_Accumulated.size()); + if (!Frame.IsValid) + { + break; // not enough data yet + } + + // Remove consumed bytes + m_Accumulated.erase(m_Accumulated.begin(), m_Accumulated.begin() + Frame.BytesConsumed); + + if (m_HttpServer) + { + m_HttpServer->OnWebSocketFrameReceived(Frame.BytesConsumed); + } + + switch (Frame.Opcode) + { + case WebSocketOpcode::kText: + case WebSocketOpcode::kBinary: + { + WebSocketMessage Msg; + Msg.Opcode = Frame.Opcode; + Msg.Payload = IoBuffer(IoBuffer::Clone, Frame.Payload.data(), Frame.Payload.size()); + m_Handler.OnWebSocketMessage(*this, Msg); + break; + } + + case WebSocketOpcode::kPing: + { + // Auto-respond with pong carrying the same payload + std::vector<uint8_t> PongFrame = WsFrameCodec::BuildFrame(WebSocketOpcode::kPong, Frame.Payload); + EnqueueWrite(std::move(PongFrame)); + break; + } + + case WebSocketOpcode::kPong: + // Unsolicited pong — ignore per RFC 6455 + break; + + case WebSocketOpcode::kClose: + { + uint16_t Code = 1000; + std::string_view Reason; + + if (Frame.Payload.size() >= 2) + { + Code = (uint16_t(Frame.Payload[0]) << 8) | uint16_t(Frame.Payload[1]); + if (Frame.Payload.size() > 2) + { + Reason = std::string_view(reinterpret_cast<const char*>(Frame.Payload.data() + 2), Frame.Payload.size() - 2); + } + } + + // Echo close frame back if we haven't sent one yet + { + bool ShouldSendClose = false; + { + RwLock::ExclusiveLockScope _(m_WriteLock); + if (!m_CloseSent.exchange(true)) + { + ShouldSendClose = true; + } + } + if (ShouldSendClose) + { + std::vector<uint8_t> CloseFrame = WsFrameCodec::BuildCloseFrame(Code); + EnqueueWrite(std::move(CloseFrame)); + } + } + + m_IsOpen.store(false); + m_Handler.OnWebSocketClose(*this, Code, Reason); + Disconnect(); + return; + } + + default: + ZEN_LOG_WARN(WsHttpSysLog(), "Unknown WebSocket opcode: {:#x}", static_cast<uint8_t>(Frame.Opcode)); + break; + } + } +} + +////////////////////////////////////////////////////////////////////////// +// +// Async write path +// + +void +WsHttpSysConnection::EnqueueWrite(std::vector<uint8_t> Frame) +{ + if (m_HttpServer) + { + m_HttpServer->OnWebSocketFrameSent(Frame.size()); + } + + bool ShouldFlush = false; + + { + RwLock::ExclusiveLockScope _(m_WriteLock); + m_WriteQueue.push_back(std::move(Frame)); + + if (!m_IsWriting) + { + m_IsWriting = true; + ShouldFlush = true; + } + } + + if (ShouldFlush) + { + FlushWriteQueue(); + } +} + +void +WsHttpSysConnection::FlushWriteQueue() +{ + { + RwLock::ExclusiveLockScope _(m_WriteLock); + + if (m_WriteQueue.empty()) + { + m_IsWriting = false; + return; + } + + m_CurrentWriteBuffer = std::move(m_WriteQueue.front()); + m_WriteQueue.pop_front(); + } + + m_OutstandingOps.fetch_add(1, std::memory_order_relaxed); + + ZeroMemory(&m_WriteChunk, sizeof(m_WriteChunk)); + m_WriteChunk.DataChunkType = HttpDataChunkFromMemory; + m_WriteChunk.FromMemory.pBuffer = m_CurrentWriteBuffer.data(); + m_WriteChunk.FromMemory.BufferLength = (ULONG)m_CurrentWriteBuffer.size(); + + ZeroMemory(&m_WriteIoContext.Overlapped, sizeof(OVERLAPPED)); + + StartThreadpoolIo(m_Iocp); + + ULONG Result = HttpSendResponseEntityBody(m_RequestQueueHandle, + m_RequestId, + HTTP_SEND_RESPONSE_FLAG_MORE_DATA, + 1, + &m_WriteChunk, + nullptr, + nullptr, + 0, + &m_WriteIoContext.Overlapped, + nullptr); + + if (Result != NO_ERROR && Result != ERROR_IO_PENDING) + { + CancelThreadpoolIo(m_Iocp); + m_OutstandingOps.fetch_sub(1, std::memory_order_relaxed); + + ZEN_LOG_DEBUG(WsHttpSysLog(), "WebSocket async write failed: {}", Result); + + { + RwLock::ExclusiveLockScope _(m_WriteLock); + m_WriteQueue.clear(); + m_IsWriting = false; + } + m_CurrentWriteBuffer.clear(); + + if (m_IsOpen.exchange(false)) + { + m_Handler.OnWebSocketClose(*this, 1006, "write error"); + } + + MaybeReleaseSelfRef(); + } +} + +void +WsHttpSysConnection::OnWriteCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) +{ + ZEN_UNUSED(NumberOfBytesTransferred); + + // Hold a transient ref to prevent mid-callback destruction + Ref<WsHttpSysConnection> Guard(this); + + m_OutstandingOps.fetch_sub(1, std::memory_order_relaxed); + m_CurrentWriteBuffer.clear(); + + if (IoResult != NO_ERROR) + { + ZEN_LOG_DEBUG(WsHttpSysLog(), "WebSocket write completion error: {}", IoResult); + + { + RwLock::ExclusiveLockScope _(m_WriteLock); + m_WriteQueue.clear(); + m_IsWriting = false; + } + + if (m_IsOpen.exchange(false)) + { + m_Handler.OnWebSocketClose(*this, 1006, "write error"); + } + + MaybeReleaseSelfRef(); + return; + } + + FlushWriteQueue(); +} + +////////////////////////////////////////////////////////////////////////// +// +// Send interface +// + +void +WsHttpSysConnection::SendText(std::string_view Text) +{ + if (!m_IsOpen.load(std::memory_order_relaxed)) + { + return; + } + + std::span<const uint8_t> Payload(reinterpret_cast<const uint8_t*>(Text.data()), Text.size()); + std::vector<uint8_t> Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kText, Payload); + EnqueueWrite(std::move(Frame)); +} + +void +WsHttpSysConnection::SendBinary(std::span<const uint8_t> Data) +{ + if (!m_IsOpen.load(std::memory_order_relaxed)) + { + return; + } + + std::vector<uint8_t> Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kBinary, Data); + EnqueueWrite(std::move(Frame)); +} + +void +WsHttpSysConnection::Close(uint16_t Code, std::string_view Reason) +{ + DoClose(Code, Reason); +} + +void +WsHttpSysConnection::DoClose(uint16_t Code, std::string_view Reason) +{ + if (!m_IsOpen.exchange(false)) + { + return; + } + + { + bool ShouldSendClose = false; + { + RwLock::ExclusiveLockScope _(m_WriteLock); + if (!m_CloseSent.exchange(true)) + { + ShouldSendClose = true; + } + } + if (ShouldSendClose) + { + std::vector<uint8_t> CloseFrame = WsFrameCodec::BuildCloseFrame(Code, Reason); + EnqueueWrite(std::move(CloseFrame)); + } + } + + m_Handler.OnWebSocketClose(*this, Code, Reason); + + // Cancel pending read I/O — completions drain via ERROR_OPERATION_ABORTED + HttpCancelHttpRequest(m_RequestQueueHandle, m_RequestId, nullptr); +} + +////////////////////////////////////////////////////////////////////////// +// +// Lifetime management +// + +void +WsHttpSysConnection::MaybeReleaseSelfRef() +{ + if (m_OutstandingOps.load(std::memory_order_relaxed) == 0 && !m_IsOpen.load(std::memory_order_relaxed)) + { + m_SelfRef = nullptr; + } +} + +void +WsHttpSysConnection::Disconnect() +{ + // Send final empty body with DISCONNECT to tell http.sys the connection is done + HttpSendResponseEntityBody(m_RequestQueueHandle, + m_RequestId, + HTTP_SEND_RESPONSE_FLAG_DISCONNECT, + 0, + nullptr, + nullptr, + nullptr, + 0, + nullptr, + nullptr); +} + +} // namespace zen + +#endif // ZEN_WITH_HTTPSYS diff --git a/src/zenhttp/servers/wshttpsys.h b/src/zenhttp/servers/wshttpsys.h new file mode 100644 index 000000000..6015e3873 --- /dev/null +++ b/src/zenhttp/servers/wshttpsys.h @@ -0,0 +1,107 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zenhttp/websocket.h> + +#include "httpsys_iocontext.h" + +#include <zencore/thread.h> + +#if ZEN_WITH_HTTPSYS +# define _WINSOCKAPI_ +# include <zencore/windows.h> +# include <http.h> + +# include <atomic> +# include <deque> +# include <vector> + +namespace zen { + +class HttpServer; + +/** + * WebSocket connection over an http.sys opaque-mode connection + * + * After the 101 Switching Protocols response is sent with + * HTTP_SEND_RESPONSE_FLAG_OPAQUE, http.sys stops parsing HTTP on the + * connection. Raw bytes are exchanged via HttpReceiveRequestEntityBody / + * HttpSendResponseEntityBody using the original RequestId. + * + * All I/O is performed asynchronously via the same IOCP threadpool used + * for normal http.sys traffic, eliminating per-connection threads. + * + * Lifetime is managed through intrusive reference counting (RefCounted). + * A self-reference (m_SelfRef) is held from Start() until all outstanding + * async operations have drained, preventing premature destruction. + */ +class WsHttpSysConnection : public WebSocketConnection +{ +public: + WsHttpSysConnection(HANDLE RequestQueueHandle, HTTP_REQUEST_ID RequestId, IWebSocketHandler& Handler, PTP_IO Iocp, HttpServer* Server); + ~WsHttpSysConnection() override; + + /** + * Start the async read loop. Must be called once after construction + * and after the 101 response has been sent. + */ + void Start(); + + /** + * Shut down the connection. Cancels pending I/O; IOCP completions + * will fire with ERROR_OPERATION_ABORTED and drain naturally. + */ + void Shutdown(); + + // WebSocketConnection interface + void SendText(std::string_view Text) override; + void SendBinary(std::span<const uint8_t> Data) override; + void Close(uint16_t Code, std::string_view Reason) override; + bool IsOpen() const override; + + // Called from IoCompletionCallback via tagged dispatch + void OnReadCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred); + void OnWriteCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred); + +private: + void IssueAsyncRead(); + void ProcessReceivedData(); + void EnqueueWrite(std::vector<uint8_t> Frame); + void FlushWriteQueue(); + void DoClose(uint16_t Code, std::string_view Reason); + void Disconnect(); + void MaybeReleaseSelfRef(); + + HANDLE m_RequestQueueHandle; + HTTP_REQUEST_ID m_RequestId; + IWebSocketHandler& m_Handler; + PTP_IO m_Iocp; + HttpServer* m_HttpServer; + + // Tagged OVERLAPPED contexts for concurrent read and write + HttpSysIoContext m_ReadIoContext{}; + HttpSysIoContext m_WriteIoContext{}; + + // Read state + std::vector<uint8_t> m_ReadBuffer; + std::vector<uint8_t> m_Accumulated; + + // Write state + RwLock m_WriteLock; + std::deque<std::vector<uint8_t>> m_WriteQueue; + std::vector<uint8_t> m_CurrentWriteBuffer; + HTTP_DATA_CHUNK m_WriteChunk{}; + bool m_IsWriting = false; + + // Lifetime management + std::atomic<int32_t> m_OutstandingOps{0}; + Ref<WsHttpSysConnection> m_SelfRef; + std::atomic<bool> m_ShutdownRequested{false}; + std::atomic<bool> m_IsOpen{true}; + std::atomic<bool> m_CloseSent{false}; +}; + +} // namespace zen + +#endif // ZEN_WITH_HTTPSYS diff --git a/src/zenhttp/servers/wstest.cpp b/src/zenhttp/servers/wstest.cpp new file mode 100644 index 000000000..2134e4ff1 --- /dev/null +++ b/src/zenhttp/servers/wstest.cpp @@ -0,0 +1,925 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#if ZEN_WITH_TESTS + +# include <zencore/scopeguard.h> +# include <zencore/testing.h> +# include <zencore/testutils.h> + +# include <zenhttp/httpserver.h> +# include <zenhttp/httpwsclient.h> +# include <zenhttp/websocket.h> + +# include "httpasio.h" +# include "wsframecodec.h" + +ZEN_THIRD_PARTY_INCLUDES_START +# if ZEN_PLATFORM_WINDOWS +# include <winsock2.h> +# else +# include <poll.h> +# include <sys/socket.h> +# endif +# include <asio.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + +# include <atomic> +# include <chrono> +# include <cstring> +# include <random> +# include <string> +# include <string_view> +# include <thread> +# include <vector> + +namespace zen { + +using namespace std::literals; + +////////////////////////////////////////////////////////////////////////// +// +// Unit tests: WsFrameCodec +// + +TEST_SUITE_BEGIN("http.wstest"); + +TEST_CASE("websocket.framecodec") +{ + SUBCASE("ComputeAcceptKey RFC 6455 test vector") + { + // RFC 6455 section 4.2.2 example + std::string AcceptKey = WsFrameCodec::ComputeAcceptKey("dGhlIHNhbXBsZSBub25jZQ=="); + CHECK_EQ(AcceptKey, "s3pPLMBiTxaQ9kYGzzhZRbK+xOo="); + } + + SUBCASE("BuildFrame and TryParseFrame roundtrip - text") + { + std::string_view Text = "Hello, WebSocket!"; + std::span<const uint8_t> Payload(reinterpret_cast<const uint8_t*>(Text.data()), Text.size()); + + std::vector<uint8_t> Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kText, Payload); + + // Server frames are unmasked — TryParseFrame should handle them + WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size()); + + CHECK(Result.IsValid); + CHECK_EQ(Result.BytesConsumed, Frame.size()); + CHECK(Result.Fin); + CHECK_EQ(Result.Opcode, WebSocketOpcode::kText); + CHECK_EQ(Result.Payload.size(), Text.size()); + CHECK_EQ(std::string_view(reinterpret_cast<const char*>(Result.Payload.data()), Result.Payload.size()), Text); + } + + SUBCASE("BuildFrame and TryParseFrame roundtrip - binary") + { + std::vector<uint8_t> BinaryData = {0x00, 0x01, 0x02, 0xFF, 0xFE, 0xFD}; + + std::vector<uint8_t> Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kBinary, BinaryData); + + WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size()); + + CHECK(Result.IsValid); + CHECK_EQ(Result.Opcode, WebSocketOpcode::kBinary); + CHECK_EQ(Result.Payload, BinaryData); + } + + SUBCASE("BuildFrame - medium payload (126-65535 bytes)") + { + std::vector<uint8_t> Payload(300, 0x42); + + std::vector<uint8_t> Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kBinary, Payload); + + WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size()); + + CHECK(Result.IsValid); + CHECK_EQ(Result.Payload.size(), 300u); + CHECK_EQ(Result.Payload, Payload); + } + + SUBCASE("BuildFrame - large payload (>65535 bytes)") + { + std::vector<uint8_t> Payload(70000, 0xAB); + + std::vector<uint8_t> Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kBinary, Payload); + + WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size()); + + CHECK(Result.IsValid); + CHECK_EQ(Result.Payload.size(), 70000u); + } + + SUBCASE("BuildCloseFrame roundtrip") + { + std::vector<uint8_t> Frame = WsFrameCodec::BuildCloseFrame(1000, "normal closure"); + + WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size()); + + CHECK(Result.IsValid); + CHECK_EQ(Result.Opcode, WebSocketOpcode::kClose); + REQUIRE(Result.Payload.size() >= 2); + + uint16_t Code = (uint16_t(Result.Payload[0]) << 8) | uint16_t(Result.Payload[1]); + CHECK_EQ(Code, 1000); + + std::string_view Reason(reinterpret_cast<const char*>(Result.Payload.data() + 2), Result.Payload.size() - 2); + CHECK_EQ(Reason, "normal closure"); + } + + SUBCASE("TryParseFrame - partial data returns invalid") + { + std::vector<uint8_t> Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kText, std::span<const uint8_t>{}); + + // Pass only 1 byte — not enough for a frame header + WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), 1); + CHECK_FALSE(Result.IsValid); + CHECK_EQ(Result.BytesConsumed, 0u); + } + + SUBCASE("TryParseFrame - empty payload") + { + std::vector<uint8_t> Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kText, std::span<const uint8_t>{}); + + WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size()); + + CHECK(Result.IsValid); + CHECK_EQ(Result.Opcode, WebSocketOpcode::kText); + CHECK(Result.Payload.empty()); + } + + SUBCASE("TryParseFrame - masked client frame") + { + // Build a masked frame manually as a client would send + // Frame: FIN=1, opcode=text, MASK=1, payload_len=5, mask_key=0x37FA213D, payload="Hello" + uint8_t MaskKey[4] = {0x37, 0xFA, 0x21, 0x3D}; + uint8_t MaskedPayload[5] = {}; + const char* Original = "Hello"; + for (int i = 0; i < 5; ++i) + { + MaskedPayload[i] = static_cast<uint8_t>(Original[i]) ^ MaskKey[i % 4]; + } + + std::vector<uint8_t> Frame; + Frame.push_back(0x81); // FIN + text + Frame.push_back(0x85); // MASK + len=5 + Frame.insert(Frame.end(), MaskKey, MaskKey + 4); + Frame.insert(Frame.end(), MaskedPayload, MaskedPayload + 5); + + WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size()); + + CHECK(Result.IsValid); + CHECK_EQ(Result.Opcode, WebSocketOpcode::kText); + CHECK_EQ(Result.Payload.size(), 5u); + CHECK_EQ(std::string_view(reinterpret_cast<const char*>(Result.Payload.data()), 5), "Hello"sv); + } + + SUBCASE("BuildMaskedFrame roundtrip - text") + { + std::string_view Text = "Hello, masked WebSocket!"; + std::span<const uint8_t> Payload(reinterpret_cast<const uint8_t*>(Text.data()), Text.size()); + + std::vector<uint8_t> Frame = WsFrameCodec::BuildMaskedFrame(WebSocketOpcode::kText, Payload); + + // Verify mask bit is set + CHECK((Frame[1] & 0x80) != 0); + + WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size()); + + CHECK(Result.IsValid); + CHECK_EQ(Result.BytesConsumed, Frame.size()); + CHECK(Result.Fin); + CHECK_EQ(Result.Opcode, WebSocketOpcode::kText); + CHECK_EQ(Result.Payload.size(), Text.size()); + CHECK_EQ(std::string_view(reinterpret_cast<const char*>(Result.Payload.data()), Result.Payload.size()), Text); + } + + SUBCASE("BuildMaskedFrame roundtrip - binary") + { + std::vector<uint8_t> BinaryData = {0x00, 0x01, 0x02, 0xFF, 0xFE, 0xFD}; + + std::vector<uint8_t> Frame = WsFrameCodec::BuildMaskedFrame(WebSocketOpcode::kBinary, BinaryData); + + CHECK((Frame[1] & 0x80) != 0); + + WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size()); + + CHECK(Result.IsValid); + CHECK_EQ(Result.Opcode, WebSocketOpcode::kBinary); + CHECK_EQ(Result.Payload, BinaryData); + } + + SUBCASE("BuildMaskedFrame - medium payload (126-65535 bytes)") + { + std::vector<uint8_t> Payload(300, 0x42); + + std::vector<uint8_t> Frame = WsFrameCodec::BuildMaskedFrame(WebSocketOpcode::kBinary, Payload); + + CHECK((Frame[1] & 0x80) != 0); + CHECK_EQ((Frame[1] & 0x7F), 126); // 16-bit extended length + + WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size()); + + CHECK(Result.IsValid); + CHECK_EQ(Result.Payload.size(), 300u); + CHECK_EQ(Result.Payload, Payload); + } + + SUBCASE("BuildMaskedFrame - large payload (>65535 bytes)") + { + std::vector<uint8_t> Payload(70000, 0xAB); + + std::vector<uint8_t> Frame = WsFrameCodec::BuildMaskedFrame(WebSocketOpcode::kBinary, Payload); + + CHECK((Frame[1] & 0x80) != 0); + CHECK_EQ((Frame[1] & 0x7F), 127); // 64-bit extended length + + WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size()); + + CHECK(Result.IsValid); + CHECK_EQ(Result.Payload.size(), 70000u); + } + + SUBCASE("BuildMaskedCloseFrame roundtrip") + { + std::vector<uint8_t> Frame = WsFrameCodec::BuildMaskedCloseFrame(1000, "normal closure"); + + CHECK((Frame[1] & 0x80) != 0); + + WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size()); + + CHECK(Result.IsValid); + CHECK_EQ(Result.Opcode, WebSocketOpcode::kClose); + REQUIRE(Result.Payload.size() >= 2); + + uint16_t Code = (uint16_t(Result.Payload[0]) << 8) | uint16_t(Result.Payload[1]); + CHECK_EQ(Code, 1000); + + std::string_view Reason(reinterpret_cast<const char*>(Result.Payload.data() + 2), Result.Payload.size() - 2); + CHECK_EQ(Reason, "normal closure"); + } +} + +////////////////////////////////////////////////////////////////////////// +// +// Integration tests: WebSocket over ASIO +// + +namespace { + + /** + * Helper: Build a masked client-to-server frame per RFC 6455 + */ + std::vector<uint8_t> BuildMaskedFrame(WebSocketOpcode Opcode, std::span<const uint8_t> Payload) + { + std::vector<uint8_t> Frame; + + // FIN + opcode + Frame.push_back(0x80 | static_cast<uint8_t>(Opcode)); + + // Payload length with mask bit set + if (Payload.size() < 126) + { + Frame.push_back(0x80 | static_cast<uint8_t>(Payload.size())); + } + else if (Payload.size() <= 0xFFFF) + { + Frame.push_back(0x80 | 126); + Frame.push_back(static_cast<uint8_t>((Payload.size() >> 8) & 0xFF)); + Frame.push_back(static_cast<uint8_t>(Payload.size() & 0xFF)); + } + else + { + Frame.push_back(0x80 | 127); + for (int i = 7; i >= 0; --i) + { + Frame.push_back(static_cast<uint8_t>((Payload.size() >> (i * 8)) & 0xFF)); + } + } + + // Mask key (use a fixed key for deterministic tests) + uint8_t MaskKey[4] = {0x12, 0x34, 0x56, 0x78}; + Frame.insert(Frame.end(), MaskKey, MaskKey + 4); + + // Masked payload + for (size_t i = 0; i < Payload.size(); ++i) + { + Frame.push_back(Payload[i] ^ MaskKey[i & 3]); + } + + return Frame; + } + + std::vector<uint8_t> BuildMaskedTextFrame(std::string_view Text) + { + std::span<const uint8_t> Payload(reinterpret_cast<const uint8_t*>(Text.data()), Text.size()); + return BuildMaskedFrame(WebSocketOpcode::kText, Payload); + } + + std::vector<uint8_t> BuildMaskedCloseFrame(uint16_t Code) + { + std::vector<uint8_t> Payload; + Payload.push_back(static_cast<uint8_t>((Code >> 8) & 0xFF)); + Payload.push_back(static_cast<uint8_t>(Code & 0xFF)); + return BuildMaskedFrame(WebSocketOpcode::kClose, Payload); + } + + /** + * Test service that implements IWebSocketHandler + */ + struct WsTestService : public HttpService, public IWebSocketHandler + { + const char* BaseUri() const override { return "/wstest/"; } + + void HandleRequest(HttpServerRequest& Request) override + { + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "hello from wstest"); + } + + // IWebSocketHandler + void OnWebSocketOpen(Ref<WebSocketConnection> Connection) override + { + m_OpenCount.fetch_add(1); + + m_ConnectionsLock.WithExclusiveLock([&] { m_Connections.push_back(Connection); }); + } + + void OnWebSocketMessage(WebSocketConnection& Conn, const WebSocketMessage& Msg) override + { + m_MessageCount.fetch_add(1); + + if (Msg.Opcode == WebSocketOpcode::kText) + { + std::string_view Text(static_cast<const char*>(Msg.Payload.Data()), Msg.Payload.Size()); + m_LastMessage = std::string(Text); + + // Echo the message back + Conn.SendText(Text); + } + } + + void OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code, [[maybe_unused]] std::string_view Reason) override + { + m_CloseCount.fetch_add(1); + m_LastCloseCode = Code; + + m_ConnectionsLock.WithExclusiveLock([&] { + auto It = std::remove_if(m_Connections.begin(), m_Connections.end(), [&Conn](const Ref<WebSocketConnection>& C) { + return C.Get() == &Conn; + }); + m_Connections.erase(It, m_Connections.end()); + }); + } + + void SendToAll(std::string_view Text) + { + RwLock::SharedLockScope _(m_ConnectionsLock); + for (auto& Conn : m_Connections) + { + if (Conn->IsOpen()) + { + Conn->SendText(Text); + } + } + } + + std::atomic<int> m_OpenCount{0}; + std::atomic<int> m_MessageCount{0}; + std::atomic<int> m_CloseCount{0}; + std::atomic<uint16_t> m_LastCloseCode{0}; + std::string m_LastMessage; + + RwLock m_ConnectionsLock; + std::vector<Ref<WebSocketConnection>> m_Connections; + }; + + /** + * Helper: Perform the WebSocket upgrade handshake on a raw TCP socket + * + * Returns true on success (101 response), false otherwise. + */ + bool DoWebSocketHandshake(asio::ip::tcp::socket& Sock, std::string_view Path, int Port) + { + // Send HTTP upgrade request + ExtendableStringBuilder<512> Request; + Request << "GET " << Path << " HTTP/1.1\r\n" + << "Host: 127.0.0.1:" << Port << "\r\n" + << "Upgrade: websocket\r\n" + << "Connection: Upgrade\r\n" + << "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n" + << "Sec-WebSocket-Version: 13\r\n" + << "\r\n"; + + std::string_view ReqStr = Request.ToView(); + + asio::write(Sock, asio::buffer(ReqStr.data(), ReqStr.size())); + + // Read the response (look for "101") + asio::streambuf ResponseBuf; + asio::read_until(Sock, ResponseBuf, "\r\n\r\n"); + + std::string Response(asio::buffers_begin(ResponseBuf.data()), asio::buffers_end(ResponseBuf.data())); + + return Response.find("101") != std::string::npos; + } + + /** + * Helper: Read a single server-to-client frame from a socket + * + * Uses a background thread with a synchronous ASIO read and a timeout. + */ + WsFrameParseResult ReadOneFrame(asio::ip::tcp::socket& Sock, int TimeoutMs = 5000) + { + std::vector<uint8_t> Buffer; + WsFrameParseResult Result; + std::atomic<bool> Done{false}; + + std::thread Reader([&] { + while (!Done.load()) + { + uint8_t Tmp[4096]; + asio::error_code Ec; + size_t BytesRead = Sock.read_some(asio::buffer(Tmp), Ec); + if (Ec || BytesRead == 0) + { + break; + } + + Buffer.insert(Buffer.end(), Tmp, Tmp + BytesRead); + + WsFrameParseResult Frame = WsFrameCodec::TryParseFrame(Buffer.data(), Buffer.size()); + if (Frame.IsValid) + { + Result = std::move(Frame); + Done.store(true); + return; + } + } + }); + + auto Deadline = std::chrono::steady_clock::now() + std::chrono::milliseconds(TimeoutMs); + while (!Done.load() && std::chrono::steady_clock::now() < Deadline) + { + Sleep(10); + } + + if (!Done.load()) + { + // Timeout — cancel the read + asio::error_code Ec; + Sock.cancel(Ec); + } + + if (Reader.joinable()) + { + Reader.join(); + } + + return Result; + } + +} // anonymous namespace + +TEST_CASE("websocket.integration") +{ + WsTestService TestService; + ScopedTemporaryDirectory TmpDir; + + Ref<HttpServer> Server = CreateHttpAsioServer(AsioConfig{}); + + int Port = Server->Initialize(7575, TmpDir.Path()); + REQUIRE(Port != 0); + + Server->RegisterService(TestService); + + std::thread ServerThread([&]() { Server->Run(false); }); + + auto ServerGuard = MakeGuard([&]() { + Server->RequestExit(); + if (ServerThread.joinable()) + { + ServerThread.join(); + } + Server->Close(); + }); + + // Give server a moment to start accepting + Sleep(100); + + SUBCASE("handshake succeeds with 101") + { + asio::io_context IoCtx; + asio::ip::tcp::socket Sock(IoCtx); + Sock.connect(asio::ip::tcp::endpoint(asio::ip::address::from_string("127.0.0.1"), static_cast<uint16_t>(Port))); + + bool Ok = DoWebSocketHandshake(Sock, "/wstest/ws", Port); + CHECK(Ok); + + Sleep(50); + CHECK_EQ(TestService.m_OpenCount.load(), 1); + + Sock.close(); + } + + SUBCASE("normal HTTP still works alongside WebSocket service") + { + asio::io_context IoCtx; + asio::ip::tcp::socket Sock(IoCtx); + Sock.connect(asio::ip::tcp::endpoint(asio::ip::address::from_string("127.0.0.1"), static_cast<uint16_t>(Port))); + + // Send a normal HTTP GET (not upgrade) + std::string HttpReq = fmt::format( + "GET /wstest/hello HTTP/1.1\r\n" + "Host: 127.0.0.1:{}\r\n" + "Connection: close\r\n" + "\r\n", + Port); + + asio::write(Sock, asio::buffer(HttpReq)); + + asio::streambuf ResponseBuf; + asio::error_code Ec; + asio::read(Sock, ResponseBuf, asio::transfer_at_least(1), Ec); + + std::string Response(asio::buffers_begin(ResponseBuf.data()), asio::buffers_end(ResponseBuf.data())); + CHECK(Response.find("200") != std::string::npos); + } + + SUBCASE("echo message roundtrip") + { + asio::io_context IoCtx; + asio::ip::tcp::socket Sock(IoCtx); + Sock.connect(asio::ip::tcp::endpoint(asio::ip::address::from_string("127.0.0.1"), static_cast<uint16_t>(Port))); + + bool Ok = DoWebSocketHandshake(Sock, "/wstest/ws", Port); + REQUIRE(Ok); + Sleep(50); + + // Send a text message (masked, as client) + std::vector<uint8_t> Frame = BuildMaskedTextFrame("ping test"); + asio::write(Sock, asio::buffer(Frame)); + + // Read the echo reply + WsFrameParseResult Reply = ReadOneFrame(Sock); + REQUIRE(Reply.IsValid); + CHECK_EQ(Reply.Opcode, WebSocketOpcode::kText); + std::string_view ReplyText(reinterpret_cast<const char*>(Reply.Payload.data()), Reply.Payload.size()); + CHECK_EQ(ReplyText, "ping test"sv); + CHECK_EQ(TestService.m_MessageCount.load(), 1); + CHECK_EQ(TestService.m_LastMessage, "ping test"); + + Sock.close(); + } + + SUBCASE("server push to client") + { + asio::io_context IoCtx; + asio::ip::tcp::socket Sock(IoCtx); + Sock.connect(asio::ip::tcp::endpoint(asio::ip::address::from_string("127.0.0.1"), static_cast<uint16_t>(Port))); + + bool Ok = DoWebSocketHandshake(Sock, "/wstest/ws", Port); + REQUIRE(Ok); + Sleep(50); + + // Server pushes a message + TestService.SendToAll("server says hello"); + + WsFrameParseResult Frame = ReadOneFrame(Sock); + REQUIRE(Frame.IsValid); + CHECK_EQ(Frame.Opcode, WebSocketOpcode::kText); + std::string_view Text(reinterpret_cast<const char*>(Frame.Payload.data()), Frame.Payload.size()); + CHECK_EQ(Text, "server says hello"sv); + + Sock.close(); + } + + SUBCASE("client close handshake") + { + asio::io_context IoCtx; + asio::ip::tcp::socket Sock(IoCtx); + Sock.connect(asio::ip::tcp::endpoint(asio::ip::address::from_string("127.0.0.1"), static_cast<uint16_t>(Port))); + + bool Ok = DoWebSocketHandshake(Sock, "/wstest/ws", Port); + REQUIRE(Ok); + Sleep(50); + + // Send close frame + std::vector<uint8_t> CloseFrame = BuildMaskedCloseFrame(1000); + asio::write(Sock, asio::buffer(CloseFrame)); + + // Server should echo close back + WsFrameParseResult Reply = ReadOneFrame(Sock); + REQUIRE(Reply.IsValid); + CHECK_EQ(Reply.Opcode, WebSocketOpcode::kClose); + + Sleep(50); + CHECK_EQ(TestService.m_CloseCount.load(), 1); + CHECK_EQ(TestService.m_LastCloseCode.load(), 1000); + + Sock.close(); + } + + SUBCASE("multiple concurrent connections") + { + constexpr int NumClients = 5; + + asio::io_context IoCtx; + std::vector<asio::ip::tcp::socket> Sockets; + + for (int i = 0; i < NumClients; ++i) + { + Sockets.emplace_back(IoCtx); + Sockets.back().connect(asio::ip::tcp::endpoint(asio::ip::address::from_string("127.0.0.1"), static_cast<uint16_t>(Port))); + + bool Ok = DoWebSocketHandshake(Sockets.back(), "/wstest/ws", Port); + REQUIRE(Ok); + } + + Sleep(100); + CHECK_EQ(TestService.m_OpenCount.load(), NumClients); + + // Broadcast from server + TestService.SendToAll("broadcast"); + + // Each client should receive the message + for (int i = 0; i < NumClients; ++i) + { + WsFrameParseResult Frame = ReadOneFrame(Sockets[i]); + REQUIRE(Frame.IsValid); + CHECK_EQ(Frame.Opcode, WebSocketOpcode::kText); + std::string_view Text(reinterpret_cast<const char*>(Frame.Payload.data()), Frame.Payload.size()); + CHECK_EQ(Text, "broadcast"sv); + } + + // Close all + for (auto& S : Sockets) + { + S.close(); + } + } + + SUBCASE("service without IWebSocketHandler rejects upgrade") + { + // Register a plain HTTP service (no WebSocket) + struct PlainService : public HttpService + { + const char* BaseUri() const override { return "/plain/"; } + void HandleRequest(HttpServerRequest& Request) override { Request.WriteResponse(HttpResponseCode::OK); } + }; + + PlainService Plain; + Server->RegisterService(Plain); + + Sleep(50); + + asio::io_context IoCtx; + asio::ip::tcp::socket Sock(IoCtx); + Sock.connect(asio::ip::tcp::endpoint(asio::ip::address::from_string("127.0.0.1"), static_cast<uint16_t>(Port))); + + // Attempt WebSocket upgrade on the plain service + ExtendableStringBuilder<512> Request; + Request << "GET /plain/ws HTTP/1.1\r\n" + << "Host: 127.0.0.1:" << Port << "\r\n" + << "Upgrade: websocket\r\n" + << "Connection: Upgrade\r\n" + << "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n" + << "Sec-WebSocket-Version: 13\r\n" + << "\r\n"; + + std::string_view ReqStr = Request.ToView(); + asio::write(Sock, asio::buffer(ReqStr.data(), ReqStr.size())); + + asio::streambuf ResponseBuf; + asio::read_until(Sock, ResponseBuf, "\r\n\r\n"); + + std::string Response(asio::buffers_begin(ResponseBuf.data()), asio::buffers_end(ResponseBuf.data())); + + // Should NOT get 101 — should fall through to normal request handling + CHECK(Response.find("101") == std::string::npos); + + Sock.close(); + } + + SUBCASE("ping/pong auto-response") + { + asio::io_context IoCtx; + asio::ip::tcp::socket Sock(IoCtx); + Sock.connect(asio::ip::tcp::endpoint(asio::ip::address::from_string("127.0.0.1"), static_cast<uint16_t>(Port))); + + bool Ok = DoWebSocketHandshake(Sock, "/wstest/ws", Port); + REQUIRE(Ok); + Sleep(50); + + // Send a ping frame with payload "test" + std::string_view PingPayload = "test"; + std::span<const uint8_t> PingData(reinterpret_cast<const uint8_t*>(PingPayload.data()), PingPayload.size()); + std::vector<uint8_t> PingFrame = BuildMaskedFrame(WebSocketOpcode::kPing, PingData); + asio::write(Sock, asio::buffer(PingFrame)); + + // Should receive a pong with the same payload + WsFrameParseResult Reply = ReadOneFrame(Sock); + REQUIRE(Reply.IsValid); + CHECK_EQ(Reply.Opcode, WebSocketOpcode::kPong); + CHECK_EQ(Reply.Payload.size(), 4u); + std::string_view PongText(reinterpret_cast<const char*>(Reply.Payload.data()), Reply.Payload.size()); + CHECK_EQ(PongText, "test"sv); + + Sock.close(); + } + + SUBCASE("multiple messages in sequence") + { + asio::io_context IoCtx; + asio::ip::tcp::socket Sock(IoCtx); + Sock.connect(asio::ip::tcp::endpoint(asio::ip::address::from_string("127.0.0.1"), static_cast<uint16_t>(Port))); + + bool Ok = DoWebSocketHandshake(Sock, "/wstest/ws", Port); + REQUIRE(Ok); + Sleep(50); + + for (int i = 0; i < 10; ++i) + { + std::string Msg = fmt::format("message {}", i); + std::vector<uint8_t> Frame = BuildMaskedTextFrame(Msg); + asio::write(Sock, asio::buffer(Frame)); + + WsFrameParseResult Reply = ReadOneFrame(Sock); + REQUIRE(Reply.IsValid); + CHECK_EQ(Reply.Opcode, WebSocketOpcode::kText); + std::string_view ReplyText(reinterpret_cast<const char*>(Reply.Payload.data()), Reply.Payload.size()); + CHECK_EQ(ReplyText, Msg); + } + + CHECK_EQ(TestService.m_MessageCount.load(), 10); + + Sock.close(); + } +} + +////////////////////////////////////////////////////////////////////////// +// +// Integration tests: HttpWsClient +// + +namespace { + + struct TestWsClientHandler : public IWsClientHandler + { + void OnWsOpen() override { m_OpenCount.fetch_add(1); } + + void OnWsMessage(const WebSocketMessage& Msg) override + { + if (Msg.Opcode == WebSocketOpcode::kText) + { + std::string_view Text(static_cast<const char*>(Msg.Payload.Data()), Msg.Payload.Size()); + m_LastMessage = std::string(Text); + } + m_MessageCount.fetch_add(1); + } + + void OnWsClose(uint16_t Code, [[maybe_unused]] std::string_view Reason) override + { + m_CloseCount.fetch_add(1); + m_LastCloseCode = Code; + } + + std::atomic<int> m_OpenCount{0}; + std::atomic<int> m_MessageCount{0}; + std::atomic<int> m_CloseCount{0}; + std::atomic<uint16_t> m_LastCloseCode{0}; + std::string m_LastMessage; + }; + +} // anonymous namespace + +TEST_CASE("websocket.client") +{ + WsTestService TestService; + ScopedTemporaryDirectory TmpDir; + + Ref<HttpServer> Server = CreateHttpAsioServer(AsioConfig{}); + + int Port = Server->Initialize(7576, TmpDir.Path()); + REQUIRE(Port != 0); + + Server->RegisterService(TestService); + + std::thread ServerThread([&]() { Server->Run(false); }); + + auto ServerGuard = MakeGuard([&]() { + Server->RequestExit(); + if (ServerThread.joinable()) + { + ServerThread.join(); + } + Server->Close(); + }); + + Sleep(100); + + SUBCASE("connect, echo, close") + { + TestWsClientHandler Handler; + std::string Url = fmt::format("ws://127.0.0.1:{}/wstest/ws", Port); + + HttpWsClient Client(Url, Handler); + Client.Connect(); + + // Wait for OnWsOpen + auto Deadline = std::chrono::steady_clock::now() + 5s; + while (Handler.m_OpenCount.load() == 0 && std::chrono::steady_clock::now() < Deadline) + { + Sleep(10); + } + REQUIRE_EQ(Handler.m_OpenCount.load(), 1); + CHECK(Client.IsOpen()); + + // Send text, expect echo + Client.SendText("hello from client"); + + Deadline = std::chrono::steady_clock::now() + 5s; + while (Handler.m_MessageCount.load() == 0 && std::chrono::steady_clock::now() < Deadline) + { + Sleep(10); + } + CHECK_EQ(Handler.m_MessageCount.load(), 1); + CHECK_EQ(Handler.m_LastMessage, "hello from client"); + + // Close + Client.Close(1000, "done"); + + Deadline = std::chrono::steady_clock::now() + 5s; + while (Handler.m_CloseCount.load() == 0 && std::chrono::steady_clock::now() < Deadline) + { + Sleep(10); + } + + // The server echoes the close frame, which triggers OnWsClose on the client side + // with the server's close code. Allow the connection to settle. + Sleep(50); + CHECK_FALSE(Client.IsOpen()); + } + + SUBCASE("connect to bad port") + { + TestWsClientHandler Handler; + std::string Url = "ws://127.0.0.1:1/wstest/ws"; + + HttpWsClient Client(Url, Handler, HttpWsClientSettings{.ConnectTimeout = std::chrono::milliseconds(2000)}); + Client.Connect(); + + auto Deadline = std::chrono::steady_clock::now() + 5s; + while (Handler.m_CloseCount.load() == 0 && std::chrono::steady_clock::now() < Deadline) + { + Sleep(10); + } + + CHECK_EQ(Handler.m_CloseCount.load(), 1); + CHECK_EQ(Handler.m_LastCloseCode.load(), 1006); + CHECK_EQ(Handler.m_OpenCount.load(), 0); + } + + SUBCASE("server-initiated close") + { + TestWsClientHandler Handler; + std::string Url = fmt::format("ws://127.0.0.1:{}/wstest/ws", Port); + + HttpWsClient Client(Url, Handler); + Client.Connect(); + + auto Deadline = std::chrono::steady_clock::now() + 5s; + while (Handler.m_OpenCount.load() == 0 && std::chrono::steady_clock::now() < Deadline) + { + Sleep(10); + } + REQUIRE_EQ(Handler.m_OpenCount.load(), 1); + + // Copy connections then close them outside the lock to avoid deadlocking + // with OnWebSocketClose which acquires an exclusive lock + std::vector<Ref<WebSocketConnection>> Conns; + TestService.m_ConnectionsLock.WithSharedLock([&] { Conns = TestService.m_Connections; }); + for (auto& Conn : Conns) + { + Conn->Close(1001, "going away"); + } + + Deadline = std::chrono::steady_clock::now() + 5s; + while (Handler.m_CloseCount.load() == 0 && std::chrono::steady_clock::now() < Deadline) + { + Sleep(10); + } + + CHECK_EQ(Handler.m_CloseCount.load(), 1); + CHECK_EQ(Handler.m_LastCloseCode.load(), 1001); + CHECK_FALSE(Client.IsOpen()); + } +} + +TEST_SUITE_END(); + +void +websocket_forcelink() +{ +} + +} // namespace zen + +#endif // ZEN_WITH_TESTS diff --git a/src/zenhttp/transports/dlltransport.cpp b/src/zenhttp/transports/dlltransport.cpp index 9135d5425..489324aba 100644 --- a/src/zenhttp/transports/dlltransport.cpp +++ b/src/zenhttp/transports/dlltransport.cpp @@ -72,20 +72,36 @@ DllTransportLogger::DllTransportLogger(std::string_view PluginName) : m_PluginNa void DllTransportLogger::LogMessage(LogLevel PluginLogLevel, const char* Message) { - logging::level::LogLevel Level; - // clang-format off switch (PluginLogLevel) { - case LogLevel::Trace: Level = logging::level::Trace; break; - case LogLevel::Debug: Level = logging::level::Debug; break; - case LogLevel::Info: Level = logging::level::Info; break; - case LogLevel::Warn: Level = logging::level::Warn; break; - case LogLevel::Err: Level = logging::level::Err; break; - case LogLevel::Critical: Level = logging::level::Critical; break; - default: Level = logging::level::Off; break; + case LogLevel::Trace: + ZEN_TRACE("[{}] {}", m_PluginName, Message); + return; + + case LogLevel::Debug: + ZEN_DEBUG("[{}] {}", m_PluginName, Message); + return; + + case LogLevel::Info: + ZEN_INFO("[{}] {}", m_PluginName, Message); + return; + + case LogLevel::Warn: + ZEN_WARN("[{}] {}", m_PluginName, Message); + return; + + case LogLevel::Err: + ZEN_ERROR("[{}] {}", m_PluginName, Message); + return; + + case LogLevel::Critical: + ZEN_CRITICAL("[{}] {}", m_PluginName, Message); + return; + + default: + ZEN_UNUSED(Message); + break; } - // clang-format on - ZEN_LOG(Log(), Level, "[{}] {}", m_PluginName, Message) } uint32_t diff --git a/src/zenhttp/transports/winsocktransport.cpp b/src/zenhttp/transports/winsocktransport.cpp index c06a50c95..0217ed44e 100644 --- a/src/zenhttp/transports/winsocktransport.cpp +++ b/src/zenhttp/transports/winsocktransport.cpp @@ -322,7 +322,7 @@ SocketTransportPluginImpl::Initialize(TransportServer* ServerInterface) else { } - } while (!IsApplicationExitRequested() && m_KeepRunning.test()); + } while (m_KeepRunning.test()); ZEN_INFO("HTTP plugin server accept thread exit"); }); diff --git a/src/zenhttp/xmake.lua b/src/zenhttp/xmake.lua index 78876d21b..e8f87b668 100644 --- a/src/zenhttp/xmake.lua +++ b/src/zenhttp/xmake.lua @@ -6,6 +6,7 @@ target('zenhttp') add_headerfiles("**.h") add_files("**.cpp") add_files("servers/httpsys.cpp", {unity_ignored=true}) + add_files("servers/wshttpsys.cpp", {unity_ignored=true}) add_includedirs("include", {public=true}) add_deps("zencore", "zentelemetry", "transport-sdk", "asio", "cpr") add_packages("http_parser", "json11") diff --git a/src/zenhttp/zenhttp.cpp b/src/zenhttp/zenhttp.cpp index a2679f92e..3ac8eea8d 100644 --- a/src/zenhttp/zenhttp.cpp +++ b/src/zenhttp/zenhttp.cpp @@ -7,6 +7,7 @@ # include <zenhttp/httpclient.h> # include <zenhttp/httpserver.h> # include <zenhttp/packageformat.h> +# include <zenhttp/security/passwordsecurity.h> namespace zen { @@ -15,7 +16,10 @@ zenhttp_forcelinktests() { http_forcelink(); httpclient_forcelink(); + httpclient_test_forcelink(); forcelink_packageformat(); + passwordsecurity_forcelink(); + websocket_forcelink(); } } // namespace zen |