diff options
Diffstat (limited to 'src/zenhttp/httpclient.cpp')
| -rw-r--r-- | src/zenhttp/httpclient.cpp | 104 |
1 files changed, 98 insertions, 6 deletions
diff --git a/src/zenhttp/httpclient.cpp b/src/zenhttp/httpclient.cpp index c77be8624..16729ce38 100644 --- a/src/zenhttp/httpclient.cpp +++ b/src/zenhttp/httpclient.cpp @@ -427,13 +427,13 @@ TEST_CASE("httpclient") AsioServer->RegisterService(TestService); - std::thread SeverThread([&]() { AsioServer->Run(false); }); + std::thread ServerThread([&]() { AsioServer->Run(false); }); { auto _ = MakeGuard([&]() { - if (SeverThread.joinable()) + if (ServerThread.joinable()) { - SeverThread.join(); + ServerThread.join(); } AsioServer->Close(); }); @@ -502,13 +502,13 @@ TEST_CASE("httpclient") HttpSysServer->RegisterService(TestService); - std::thread SeverThread([&]() { HttpSysServer->Run(false); }); + std::thread ServerThread([&]() { HttpSysServer->Run(false); }); { auto _ = MakeGuard([&]() { - if (SeverThread.joinable()) + if (ServerThread.joinable()) { - SeverThread.join(); + ServerThread.join(); } HttpSysServer->Close(); }); @@ -570,6 +570,98 @@ TEST_CASE("httpclient") # endif // ZEN_PLATFORM_WINDOWS } +TEST_CASE("httpclient.requestfilter") +{ + using namespace std::literals; + + struct TestHttpService : public HttpService + { + TestHttpService() = default; + + virtual const char* BaseUri() const override { return "/test/"; } + virtual void HandleRequest(HttpServerRequest& HttpServiceRequest) override + { + if (HttpServiceRequest.RelativeUri() == "yo") + { + return HttpServiceRequest.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "hey family"); + } + + { + CHECK(HttpServiceRequest.RelativeUri() != "should_filter"); + return HttpServiceRequest.WriteResponse(HttpResponseCode::InternalServerError); + } + + { + CHECK(HttpServiceRequest.RelativeUri() != "should_forbid"); + return HttpServiceRequest.WriteResponse(HttpResponseCode::InternalServerError); + } + } + }; + + TestHttpService TestService; + ScopedTemporaryDirectory TmpDir; + + class MyFilterImpl : public IHttpRequestFilter + { + public: + virtual Result FilterRequest(HttpServerRequest& Request) + { + if (Request.RelativeUri() == "should_filter") + { + Request.WriteResponse(HttpResponseCode::MethodNotAllowed, HttpContentType::kText, "no thank you"); + return Result::ResponseSent; + } + else if (Request.RelativeUri() == "should_forbid") + { + return Result::Forbidden; + } + return Result::Accepted; + } + }; + + MyFilterImpl MyFilter; + + Ref<HttpServer> AsioServer = CreateHttpAsioServer(AsioConfig{}); + + AsioServer->SetHttpRequestFilter(&MyFilter); + + int Port = AsioServer->Initialize(7575, TmpDir.Path()); + REQUIRE(Port != -1); + + AsioServer->RegisterService(TestService); + + std::thread ServerThread([&]() { AsioServer->Run(false); }); + + { + auto _ = MakeGuard([&]() { + if (ServerThread.joinable()) + { + ServerThread.join(); + } + AsioServer->Close(); + }); + + HttpClient Client(fmt::format("localhost:{}", Port), + HttpClientSettings{}, + /*CheckIfAbortFunction*/ {}); + + ZEN_INFO("Request using {}", Client.GetBaseUri()); + + HttpClient::Response YoResponse = Client.Get("/test/yo"); + CHECK(YoResponse.IsSuccess()); + CHECK_EQ(YoResponse.AsText(), "hey family"); + + HttpClient::Response ShouldFilterResponse = Client.Get("/test/should_filter"); + CHECK_EQ(ShouldFilterResponse.StatusCode, HttpResponseCode::MethodNotAllowed); + CHECK_EQ(ShouldFilterResponse.AsText(), "no thank you"); + + HttpClient::Response ShouldForbitResponse = Client.Get("/test/should_forbid"); + CHECK_EQ(ShouldForbitResponse.StatusCode, HttpResponseCode::Forbidden); + + AsioServer->RequestExit(); + } +} + void httpclient_forcelink() { |