// Copyright Epic Games, Inc. All Rights Reserved. #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include "clients/httpclientcommon.h" #if ZEN_WITH_TESTS # include # include # include # include # include "servers/httpasio.h" # include "servers/httpsys.h" # include #endif // ZEN_WITH_TESTS namespace zen { extern HttpClientBase* CreateCprHttpClient(std::string_view BaseUri, const HttpClientSettings& ConnectionSettings, std::function&& CheckIfAbortFunction); using namespace std::literals; ////////////////////////////////////////////////////////////////////////// HttpClientBase::HttpClientBase(std::string_view BaseUri, const HttpClientSettings& ConnectionSettings, std::function&& CheckIfAbortFunction) : m_Log(zen::logging::Get(ConnectionSettings.LogCategory)) , m_BaseUri(BaseUri) , m_ConnectionSettings(ConnectionSettings) , m_CheckIfAbortFunction(std::move(CheckIfAbortFunction)) { if (ConnectionSettings.SessionId == Oid::Zero) { m_SessionId = GetSessionIdString(); } else { m_SessionId = ConnectionSettings.SessionId.ToString(); } } HttpClientBase::~HttpClientBase() { } bool HttpClientBase::Authenticate() { ZEN_TRACE_CPU("HttpClientBase::Authenticate"); std::optional Token = GetAccessToken(); if (!Token) { return false; } return Token->IsValid(); } const std::optional HttpClientBase::GetAccessToken() { ZEN_TRACE_CPU("HttpClientBase::GetAccessToken"); if (!m_ConnectionSettings.AccessTokenProvider.has_value()) { return {}; } { RwLock::SharedLockScope _(m_AccessTokenLock); if (m_CachedAccessToken.IsValid()) { return m_CachedAccessToken; } } RwLock::ExclusiveLockScope _(m_AccessTokenLock); if (m_CachedAccessToken.IsValid()) { return m_CachedAccessToken; } m_CachedAccessToken = m_ConnectionSettings.AccessTokenProvider.value()(); return m_CachedAccessToken; } ////////////////////////////////////////////////////////////////////////// CbObject HttpClient::Response::AsObject() const { if (ResponsePayload) { CbValidateError ValidationError = CbValidateError::None; if (CbObject ResponseObject = ValidateAndReadCompactBinaryObject(IoBuffer(ResponsePayload), ValidationError); ValidationError == CbValidateError::None) { return ResponseObject; } } return {}; } CbPackage HttpClient::Response::AsPackage() const { // TODO: sanity checks and error handling if (ResponsePayload) { return ParsePackageMessage(ResponsePayload); } return {}; } std::string_view HttpClient::Response::AsText() const { if (ResponsePayload) { return std::string_view(reinterpret_cast(ResponsePayload.GetData()), ResponsePayload.GetSize()); } return {}; } std::string HttpClient::Response::ToText() const { if (!ResponsePayload) return {}; switch (ResponsePayload.GetContentType()) { case ZenContentType::kCbObject: { zen::ExtendableStringBuilder<1024> ObjStr; zen::CbObject Object{SharedBuffer(ResponsePayload)}; zen::CompactBinaryToJson(Object, ObjStr); return ObjStr.ToString(); } break; case ZenContentType::kCSS: case ZenContentType::kHTML: case ZenContentType::kJavaScript: case ZenContentType::kJSON: case ZenContentType::kText: case ZenContentType::kYAML: return std::string{AsText()}; default: return ""; } } bool HttpClient::Response::IsSuccess() const noexcept { return !Error && IsHttpSuccessCode(StatusCode); } std::string HttpClient::Response::ErrorMessage(std::string_view Prefix) const { if (Error.has_value()) { return fmt::format("{}{}HTTP error ({}) '{}'", Prefix, Prefix.empty() ? ""sv : ": "sv, Error->ErrorCode, Error->ErrorMessage); } else if (StatusCode != HttpResponseCode::ImATeapot && (int)StatusCode) { std::string TextResponse = ToText(); return fmt::format("{}{}HTTP status ({}) '{}'{}", Prefix, Prefix.empty() ? ""sv : ": "sv, (int)StatusCode, zen::ToString(StatusCode), TextResponse.empty() ? ""sv : fmt::format(" ({})", TextResponse)); } else { return fmt::format("{}{}Unknown error", Prefix, Prefix.empty() ? ""sv : ": "sv); } } void HttpClient::Response::ThrowError(std::string_view ErrorPrefix) { if (!IsSuccess()) { throw HttpClientError(ErrorMessage(ErrorPrefix), Error.has_value() ? Error.value().ErrorCode : 0, StatusCode); } } ////////////////////////////////////////////////////////////////////////// HttpClient::HttpClient(std::string_view BaseUri, const HttpClientSettings& ConnectionSettings, std::function&& CheckIfAbortFunction) : m_BaseUri(BaseUri) , m_ConnectionSettings(ConnectionSettings) { m_SessionId = GetSessionIdString(); m_Inner = CreateCprHttpClient(BaseUri, ConnectionSettings, std::move(CheckIfAbortFunction)); } HttpClient::~HttpClient() { delete m_Inner; } void HttpClient::SetSessionId(const Oid& SessionId) { if (SessionId == Oid::Zero) { m_SessionId = GetSessionIdString(); } else { m_SessionId = SessionId.ToString(); } } HttpClient::Response HttpClient::Put(std::string_view Url, const IoBuffer& Payload, const HttpClient::KeyValueMap& AdditionalHeader) { return m_Inner->Put(Url, Payload, AdditionalHeader); } HttpClient::Response HttpClient::Put(std::string_view Url, const HttpClient::KeyValueMap& Parameters) { return m_Inner->Put(Url, Parameters); } HttpClient::Response HttpClient::Get(std::string_view Url, const HttpClient::KeyValueMap& AdditionalHeader, const HttpClient::KeyValueMap& Parameters) { return m_Inner->Get(Url, AdditionalHeader, Parameters); } HttpClient::Response HttpClient::Head(std::string_view Url, const HttpClient::KeyValueMap& AdditionalHeader) { return m_Inner->Head(Url, AdditionalHeader); } HttpClient::Response HttpClient::Delete(std::string_view Url, const HttpClient::KeyValueMap& AdditionalHeader) { return m_Inner->Delete(Url, AdditionalHeader); } HttpClient::Response HttpClient::Post(std::string_view Url, const HttpClient::KeyValueMap& AdditionalHeader, const HttpClient::KeyValueMap& Parameters) { return m_Inner->Post(Url, AdditionalHeader, Parameters); } HttpClient::Response HttpClient::Post(std::string_view Url, const IoBuffer& Payload, const HttpClient::KeyValueMap& AdditionalHeader) { return m_Inner->Post(Url, Payload, AdditionalHeader); } HttpClient::Response HttpClient::Post(std::string_view Url, const IoBuffer& Payload, ZenContentType ContentType, const HttpClient::KeyValueMap& AdditionalHeader) { return m_Inner->Post(Url, Payload, ContentType, AdditionalHeader); } HttpClient::Response HttpClient::Post(std::string_view Url, CbObject Payload, const HttpClient::KeyValueMap& AdditionalHeader) { return m_Inner->Post(Url, Payload, AdditionalHeader); } HttpClient::Response HttpClient::Post(std::string_view Url, CbPackage Payload, const HttpClient::KeyValueMap& AdditionalHeader) { return m_Inner->Post(Url, Payload, AdditionalHeader); } HttpClient::Response HttpClient::Post(std::string_view Url, const CompositeBuffer& Payload, ZenContentType ContentType, const HttpClient::KeyValueMap& AdditionalHeader) { return m_Inner->Post(Url, Payload, ContentType, AdditionalHeader); } HttpClient::Response HttpClient::Upload(std::string_view Url, const IoBuffer& Payload, const HttpClient::KeyValueMap& AdditionalHeader) { return m_Inner->Upload(Url, Payload, AdditionalHeader); } HttpClient::Response HttpClient::Upload(std::string_view Url, const CompositeBuffer& Payload, ZenContentType ContentType, const HttpClient::KeyValueMap& AdditionalHeader) { return m_Inner->Upload(Url, Payload, ContentType, AdditionalHeader); } HttpClient::Response HttpClient::Download(std::string_view Url, const std::filesystem::path& TempFolderPath, const HttpClient::KeyValueMap& AdditionalHeader) { return m_Inner->Download(Url, TempFolderPath, AdditionalHeader); } HttpClient::Response HttpClient::TransactPackage(std::string_view Url, CbPackage Package, const HttpClient::KeyValueMap& AdditionalHeader) { return m_Inner->TransactPackage(Url, Package, AdditionalHeader); } bool HttpClient::Authenticate() { return m_Inner->Authenticate(); } ////////////////////////////////////////////////////////////////////////// #if ZEN_WITH_TESTS TEST_CASE("responseformat") { using namespace std::literals; SUBCASE("identity") { BodyLogFormatter _{"abcd"}; CHECK_EQ(_.GetText(), "abcd"sv); } SUBCASE("very long") { std::string_view LongView = "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz" "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz" "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz" "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz" "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz" "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz" "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz" "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz" "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz" "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz" "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz" "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz" "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz" "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz" "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz" "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz"; BodyLogFormatter _{LongView}; CHECK(_.GetText().size() < LongView.size()); CHECK(_.GetText().starts_with("[truncated"sv)); } SUBCASE("invalid text") { std::string_view BadText = "totobaba\xff\xfe"; BodyLogFormatter _{BadText}; CHECK_EQ(_.GetText(), "totobaba"); } } TEST_CASE("httpclient") { using namespace std::literals; struct TestHttpService : public HttpService { TestHttpService() = default; virtual const char* BaseUri() const override { return "/test/"; } virtual void HandleRequest(HttpServerRequest& HttpServiceRequest) override { if (HttpServiceRequest.RelativeUri() == "yo") { if (HttpServiceRequest.IsLocalMachineRequest()) { return HttpServiceRequest.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "hey family"); } else { return HttpServiceRequest.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "hey stranger"); } } return HttpServiceRequest.WriteResponse(HttpResponseCode::OK); } }; TestHttpService TestService; ScopedTemporaryDirectory TmpDir; SUBCASE("asio") { Ref AsioServer = CreateHttpAsioServer(AsioConfig{}); int Port = AsioServer->Initialize(7575, TmpDir.Path()); REQUIRE(Port != -1); AsioServer->RegisterService(TestService); std::thread ServerThread([&]() { AsioServer->Run(false); }); { auto _ = MakeGuard([&]() { if (ServerThread.joinable()) { ServerThread.join(); } AsioServer->Close(); }); { HttpClient Client(fmt::format("127.0.0.1:{}", Port), HttpClientSettings{}, /*CheckIfAbortFunction*/ {}); ZEN_INFO("Request using {}", Client.GetBaseUri()); HttpClient::Response TestResponse = Client.Get("/test/yo"); CHECK(TestResponse.IsSuccess()); CHECK_EQ(TestResponse.AsText(), "hey family"); } if (IsIPv6Capable()) { HttpClient Client(fmt::format("[::1]:{}", Port), HttpClientSettings{}, /*CheckIfAbortFunction*/ {}); ZEN_INFO("Request using {}", Client.GetBaseUri()); HttpClient::Response TestResponse = Client.Get("/test/yo"); CHECK(TestResponse.IsSuccess()); CHECK_EQ(TestResponse.AsText(), "hey family"); } { HttpClient Client(fmt::format("localhost:{}", Port), HttpClientSettings{}, /*CheckIfAbortFunction*/ {}); ZEN_INFO("Request using {}", Client.GetBaseUri()); HttpClient::Response TestResponse = Client.Get("/test/yo"); CHECK(TestResponse.IsSuccess()); CHECK_EQ(TestResponse.AsText(), "hey family"); } # if 0 { HttpClient Client(fmt::format("10.24.101.77:{}", Port), HttpClientSettings{}, /*CheckIfAbortFunction*/ {}); ZEN_INFO("Request using {}", Client.GetBaseUri()); HttpClient::Response TestResponse = Client.Get("/test/yo"); CHECK(TestResponse.IsSuccess()); CHECK_EQ(TestResponse.AsText(), "hey family"); } Sleep(20000); # endif // 0 AsioServer->RequestExit(); } } # if ZEN_PLATFORM_WINDOWS SUBCASE("httpsys") { Ref HttpSysServer = CreateHttpSysServer(HttpSysConfig{.ForceLoopback = false}); int Port = HttpSysServer->Initialize(7575, TmpDir.Path()); REQUIRE(Port != -1); HttpSysServer->RegisterService(TestService); std::thread ServerThread([&]() { HttpSysServer->Run(false); }); { auto _ = MakeGuard([&]() { if (ServerThread.joinable()) { ServerThread.join(); } HttpSysServer->Close(); }); if (true) { HttpClient Client(fmt::format("127.0.0.1:{}", Port), HttpClientSettings{}, /*CheckIfAbortFunction*/ {}); ZEN_INFO("Request using {}", Client.GetBaseUri()); HttpClient::Response TestResponse = Client.Get("/test/yo"); CHECK(TestResponse.IsSuccess()); CHECK_EQ(TestResponse.AsText(), "hey family"); } if (IsIPv6Capable()) { HttpClient Client(fmt::format("[::1]:{}", Port), HttpClientSettings{}, /*CheckIfAbortFunction*/ {}); ZEN_INFO("Request using {}", Client.GetBaseUri()); HttpClient::Response TestResponse = Client.Get("/test/yo"); CHECK(TestResponse.IsSuccess()); CHECK_EQ(TestResponse.AsText(), "hey family"); } { HttpClient Client(fmt::format("localhost:{}", Port), HttpClientSettings{}, /*CheckIfAbortFunction*/ {}); ZEN_INFO("Request using {}", Client.GetBaseUri()); HttpClient::Response TestResponse = Client.Get("/test/yo"); CHECK(TestResponse.IsSuccess()); CHECK_EQ(TestResponse.AsText(), "hey family"); } # if 0 { HttpClient Client(fmt::format("10.24.101.77:{}", Port), HttpClientSettings{}, /*CheckIfAbortFunction*/ {}); ZEN_INFO("Request using {}", Client.GetBaseUri()); HttpClient::Response TestResponse = Client.Get("/test/yo"); CHECK(TestResponse.IsSuccess()); CHECK_EQ(TestResponse.AsText(), "hey family"); } Sleep(20000); # endif // 0 HttpSysServer->RequestExit(); } } # endif // ZEN_PLATFORM_WINDOWS } TEST_CASE("httpclient.requestfilter") { using namespace std::literals; struct TestHttpService : public HttpService { TestHttpService() = default; virtual const char* BaseUri() const override { return "/test/"; } virtual void HandleRequest(HttpServerRequest& HttpServiceRequest) override { if (HttpServiceRequest.RelativeUri() == "yo") { return HttpServiceRequest.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "hey family"); } { CHECK(HttpServiceRequest.RelativeUri() != "should_filter"); return HttpServiceRequest.WriteResponse(HttpResponseCode::InternalServerError); } { CHECK(HttpServiceRequest.RelativeUri() != "should_forbid"); return HttpServiceRequest.WriteResponse(HttpResponseCode::InternalServerError); } } }; TestHttpService TestService; ScopedTemporaryDirectory TmpDir; class MyFilterImpl : public IHttpRequestFilter { public: virtual Result FilterRequest(HttpServerRequest& Request) { if (Request.RelativeUri() == "should_filter") { Request.WriteResponse(HttpResponseCode::MethodNotAllowed, HttpContentType::kText, "no thank you"); return Result::ResponseSent; } else if (Request.RelativeUri() == "should_forbid") { return Result::Forbidden; } return Result::Accepted; } }; MyFilterImpl MyFilter; Ref AsioServer = CreateHttpAsioServer(AsioConfig{}); AsioServer->SetHttpRequestFilter(&MyFilter); int Port = AsioServer->Initialize(7575, TmpDir.Path()); REQUIRE(Port != -1); AsioServer->RegisterService(TestService); std::thread ServerThread([&]() { AsioServer->Run(false); }); { auto _ = MakeGuard([&]() { if (ServerThread.joinable()) { ServerThread.join(); } AsioServer->Close(); }); HttpClient Client(fmt::format("localhost:{}", Port), HttpClientSettings{}, /*CheckIfAbortFunction*/ {}); ZEN_INFO("Request using {}", Client.GetBaseUri()); HttpClient::Response YoResponse = Client.Get("/test/yo"); CHECK(YoResponse.IsSuccess()); CHECK_EQ(YoResponse.AsText(), "hey family"); HttpClient::Response ShouldFilterResponse = Client.Get("/test/should_filter"); CHECK_EQ(ShouldFilterResponse.StatusCode, HttpResponseCode::MethodNotAllowed); CHECK_EQ(ShouldFilterResponse.AsText(), "no thank you"); HttpClient::Response ShouldForbitResponse = Client.Get("/test/should_forbid"); CHECK_EQ(ShouldForbitResponse.StatusCode, HttpResponseCode::Forbidden); AsioServer->RequestExit(); } } TEST_CASE("httpclient.password") { using namespace std::literals; struct TestHttpService : public HttpService { TestHttpService() = default; virtual const char* BaseUri() const override { return "/test/"; } virtual void HandleRequest(HttpServerRequest& HttpServiceRequest) override { if (HttpServiceRequest.RelativeUri() == "yo") { return HttpServiceRequest.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "hey family"); } { CHECK(HttpServiceRequest.RelativeUri() != "should_filter"); return HttpServiceRequest.WriteResponse(HttpResponseCode::InternalServerError); } { CHECK(HttpServiceRequest.RelativeUri() != "should_forbid"); return HttpServiceRequest.WriteResponse(HttpResponseCode::InternalServerError); } } }; TestHttpService TestService; ScopedTemporaryDirectory TmpDir; Ref AsioServer = CreateHttpAsioServer(AsioConfig{}); int Port = AsioServer->Initialize(7575, TmpDir.Path()); REQUIRE(Port != -1); AsioServer->RegisterService(TestService); std::thread ServerThread([&]() { AsioServer->Run(false); }); { auto _ = MakeGuard([&]() { if (ServerThread.joinable()) { ServerThread.join(); } AsioServer->Close(); }); SUBCASE("usernamepassword") { CbObjectWriter Writer; { Writer.BeginObject("basic"); { Writer << "username"sv << "me"; Writer << "password"sv << "456123789"; } Writer.EndObject(); Writer << "protect-machine-local-requests" << true; } PasswordHttpFilter::Configuration PasswordFilterOptions = PasswordHttpFilter::ReadConfiguration(Writer.Save()); PasswordHttpFilter MyFilter(PasswordFilterOptions); AsioServer->SetHttpRequestFilter(&MyFilter); HttpClient Client(fmt::format("localhost:{}", Port), HttpClientSettings{}, /*CheckIfAbortFunction*/ {}); ZEN_INFO("Request using {}", Client.GetBaseUri()); HttpClient::Response ForbiddenResponse = Client.Get("/test/yo"); CHECK(!ForbiddenResponse.IsSuccess()); CHECK_EQ(ForbiddenResponse.StatusCode, HttpResponseCode::Forbidden); HttpClient::Response WithBasicResponse = Client.Get("/test/yo", std::pair("Authorization", fmt::format("Basic {}", PasswordFilterOptions.PasswordConfig.Password))); CHECK(WithBasicResponse.IsSuccess()); AsioServer->SetHttpRequestFilter(nullptr); } AsioServer->RequestExit(); } } void httpclient_forcelink() { } #endif } // namespace zen