diff options
Diffstat (limited to 'zenhttp')
| -rw-r--r-- | zenhttp/httpasio.cpp | 1372 | ||||
| -rw-r--r-- | zenhttp/httpasio.h | 36 | ||||
| -rw-r--r-- | zenhttp/httpclient.cpp | 176 | ||||
| -rw-r--r-- | zenhttp/httpnull.cpp | 83 | ||||
| -rw-r--r-- | zenhttp/httpnull.h | 29 | ||||
| -rw-r--r-- | zenhttp/httpserver.cpp | 885 | ||||
| -rw-r--r-- | zenhttp/httpshared.cpp | 809 | ||||
| -rw-r--r-- | zenhttp/httpsys.cpp | 1674 | ||||
| -rw-r--r-- | zenhttp/httpsys.h | 90 | ||||
| -rw-r--r-- | zenhttp/include/zenhttp/httpclient.h | 47 | ||||
| -rw-r--r-- | zenhttp/include/zenhttp/httpcommon.h | 181 | ||||
| -rw-r--r-- | zenhttp/include/zenhttp/httpserver.h | 315 | ||||
| -rw-r--r-- | zenhttp/include/zenhttp/httpshared.h | 163 | ||||
| -rw-r--r-- | zenhttp/include/zenhttp/websocket.h | 256 | ||||
| -rw-r--r-- | zenhttp/include/zenhttp/zenhttp.h | 21 | ||||
| -rw-r--r-- | zenhttp/iothreadpool.cpp | 49 | ||||
| -rw-r--r-- | zenhttp/iothreadpool.h | 37 | ||||
| -rw-r--r-- | zenhttp/websocketasio.cpp | 1613 | ||||
| -rw-r--r-- | zenhttp/xmake.lua | 14 | ||||
| -rw-r--r-- | zenhttp/zenhttp.cpp | 22 |
20 files changed, 0 insertions, 7872 deletions
diff --git a/zenhttp/httpasio.cpp b/zenhttp/httpasio.cpp deleted file mode 100644 index 79b2c0a3d..000000000 --- a/zenhttp/httpasio.cpp +++ /dev/null @@ -1,1372 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. - -#include "httpasio.h" - -#include <zencore/logging.h> -#include <zenhttp/httpserver.h> - -#include <deque> -#include <memory> -#include <string_view> -#include <vector> - -ZEN_THIRD_PARTY_INCLUDES_START -#if ZEN_PLATFORM_WINDOWS -# include <conio.h> -# include <mstcpip.h> -#endif -#include <http_parser.h> -#include <asio.hpp> -ZEN_THIRD_PARTY_INCLUDES_END - -#define ASIO_VERBOSE_TRACE 0 - -#if ASIO_VERBOSE_TRACE -# define ZEN_TRACE_VERBOSE ZEN_TRACE -#else -# define ZEN_TRACE_VERBOSE(fmtstr, ...) -#endif - -namespace zen::asio_http { - -using namespace std::literals; - -struct HttpAcceptor; -struct HttpRequest; -struct HttpResponse; -struct HttpServerConnection; - -static constinit uint32_t HashContentLength = HashStringAsLowerDjb2("Content-Length"sv); -static constinit uint32_t HashContentType = HashStringAsLowerDjb2("Content-Type"sv); -static constinit uint32_t HashAccept = HashStringAsLowerDjb2("Accept"sv); -static constinit uint32_t HashExpect = HashStringAsLowerDjb2("Expect"sv); -static constinit uint32_t HashSession = HashStringAsLowerDjb2("UE-Session"sv); -static constinit uint32_t HashRequest = HashStringAsLowerDjb2("UE-Request"sv); -static constinit uint32_t HashRange = HashStringAsLowerDjb2("Range"sv); - -inline spdlog::logger& -InitLogger() -{ - spdlog::logger& Logger = logging::Get("asio"); - // Logger.set_level(spdlog::level::trace); - return Logger; -} - -inline spdlog::logger& -Log() -{ - static spdlog::logger& g_Logger = InitLogger(); - return g_Logger; -} - -////////////////////////////////////////////////////////////////////////// - -struct HttpAsioServerImpl -{ -public: - HttpAsioServerImpl(); - ~HttpAsioServerImpl(); - - int Start(uint16_t Port, int ThreadCount); - void Stop(); - void RegisterService(const char* UrlPath, HttpService& Service); - HttpService* RouteRequest(std::string_view Url); - - asio::io_service m_IoService; - asio::io_service::work m_Work{m_IoService}; - std::unique_ptr<asio_http::HttpAcceptor> m_Acceptor; - std::vector<std::thread> m_ThreadPool; - - struct ServiceEntry - { - std::string ServiceUrlPath; - HttpService* Service; - }; - - RwLock m_Lock; - std::vector<ServiceEntry> m_UriHandlers; -}; - -/** - * This is the class which request handlers use to interact with the server instance - */ - -class HttpAsioServerRequest : public HttpServerRequest -{ -public: - HttpAsioServerRequest(asio_http::HttpRequest& Request, HttpService& Service, IoBuffer PayloadBuffer); - ~HttpAsioServerRequest(); - - virtual Oid ParseSessionId() const override; - virtual uint32_t ParseRequestId() const override; - - virtual IoBuffer ReadPayload() override; - virtual void WriteResponse(HttpResponseCode ResponseCode) override; - virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span<IoBuffer> Blobs) override; - virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::u8string_view ResponseString) override; - virtual void WriteResponseAsync(std::function<void(HttpServerRequest&)>&& ContinuationHandler) override; - virtual bool TryGetRanges(HttpRanges& Ranges) override; - - using HttpServerRequest::WriteResponse; - - HttpAsioServerRequest(const HttpAsioServerRequest&) = delete; - HttpAsioServerRequest& operator=(const HttpAsioServerRequest&) = delete; - - asio_http::HttpRequest& m_Request; - IoBuffer m_PayloadBuffer; - std::unique_ptr<HttpResponse> m_Response; -}; - -struct HttpRequest -{ - explicit HttpRequest(HttpServerConnection& Connection) : m_Connection(Connection) {} - - void Initialize(); - size_t ConsumeData(const char* InputData, size_t DataSize); - void ResetState(); - - HttpVerb RequestVerb() const { return m_RequestVerb; } - bool IsKeepAlive() const { return m_KeepAlive; } - std::string_view Url() const { return m_NormalizedUrl.empty() ? std::string_view(m_Url, m_UrlLength) : m_NormalizedUrl; } - std::string_view QueryString() const { return std::string_view(m_QueryString, m_QueryLength); } - IoBuffer Body() { return m_BodyBuffer; } - - inline HttpContentType ContentType() - { - if (m_ContentTypeHeaderIndex < 0) - { - return HttpContentType::kUnknownContentType; - } - - return ParseContentType(m_Headers[m_ContentTypeHeaderIndex].Value); - } - - inline HttpContentType AcceptType() - { - if (m_AcceptHeaderIndex < 0) - { - return HttpContentType::kUnknownContentType; - } - - return ParseContentType(m_Headers[m_AcceptHeaderIndex].Value); - } - - Oid SessionId() const { return m_SessionId; } - int RequestId() const { return m_RequestId; } - - std::string_view RangeHeader() const { return m_RangeHeaderIndex != -1 ? m_Headers[m_RangeHeaderIndex].Value : std::string_view(); } - -private: - struct HeaderEntry - { - HeaderEntry() = default; - - HeaderEntry(std::string_view InName, std::string_view InValue) : Name(InName), Value(InValue) {} - - std::string_view Name; - std::string_view Value; - }; - - HttpServerConnection& m_Connection; - char* m_HeaderCursor = m_HeaderBuffer; - char* m_Url = nullptr; - size_t m_UrlLength = 0; - char* m_QueryString = nullptr; - size_t m_QueryLength = 0; - char* m_CurrentHeaderName = nullptr; // Used while parsing headers - size_t m_CurrentHeaderNameLength = 0; - char* m_CurrentHeaderValue = nullptr; // Used while parsing headers - size_t m_CurrentHeaderValueLength = 0; - std::vector<HeaderEntry> m_Headers; - int8_t m_ContentLengthHeaderIndex; - int8_t m_AcceptHeaderIndex; - int8_t m_ContentTypeHeaderIndex; - int8_t m_RangeHeaderIndex; - HttpVerb m_RequestVerb; - bool m_KeepAlive = false; - bool m_Expect100Continue = false; - int m_RequestId = -1; - Oid m_SessionId{}; - IoBuffer m_BodyBuffer; - uint64_t m_BodyPosition = 0; - http_parser m_Parser; - char m_HeaderBuffer[1024]; - std::string m_NormalizedUrl; - - void AppendCurrentHeader(); - - int OnMessageBegin(); - int OnUrl(const char* Data, size_t Bytes); - int OnHeader(const char* Data, size_t Bytes); - int OnHeaderValue(const char* Data, size_t Bytes); - int OnHeadersComplete(); - int OnBody(const char* Data, size_t Bytes); - int OnMessageComplete(); - - static HttpRequest* GetThis(http_parser* Parser) { return reinterpret_cast<HttpRequest*>(Parser->data); } - static http_parser_settings s_ParserSettings; -}; - -struct HttpResponse -{ -public: - HttpResponse() = default; - explicit HttpResponse(HttpContentType ContentType) : m_ContentType(ContentType) {} - - void InitializeForPayload(uint16_t ResponseCode, std::span<IoBuffer> BlobList) - { - m_ResponseCode = ResponseCode; - const uint32_t ChunkCount = gsl::narrow<uint32_t>(BlobList.size()); - - m_DataBuffers.reserve(ChunkCount); - - for (IoBuffer& Buffer : BlobList) - { -#if 1 - m_DataBuffers.emplace_back(std::move(Buffer)).MakeOwned(); -#else - IoBuffer TempBuffer = std::move(Buffer); - TempBuffer.MakeOwned(); - m_DataBuffers.emplace_back(IoBufferBuilder::ReadFromFileMaybe(TempBuffer)); -#endif - } - - uint64_t LocalDataSize = 0; - - m_AsioBuffers.push_back({}); // Placeholder for header - - for (IoBuffer& Buffer : m_DataBuffers) - { - uint64_t BufferDataSize = Buffer.Size(); - - ZEN_ASSERT(BufferDataSize); - - LocalDataSize += BufferDataSize; - - IoBufferFileReference FileRef; - if (Buffer.GetFileReference(/* out */ FileRef)) - { - // TODO: Use direct file transfer, via TransmitFile/sendfile - // - // this looks like it requires some custom asio plumbing however - - m_AsioBuffers.push_back({Buffer.Data(), Buffer.Size()}); - } - else - { - // Send from memory - - m_AsioBuffers.push_back({Buffer.Data(), Buffer.Size()}); - } - } - m_ContentLength = LocalDataSize; - - auto Headers = GetHeaders(); - m_AsioBuffers[0] = asio::const_buffer(Headers.data(), Headers.size()); - } - - uint16_t ResponseCode() const { return m_ResponseCode; } - uint64_t ContentLength() const { return m_ContentLength; } - - const std::vector<asio::const_buffer>& AsioBuffers() const { return m_AsioBuffers; } - - std::string_view GetHeaders() - { - m_Headers << "HTTP/1.1 " << ResponseCode() << " " << ReasonStringForHttpResultCode(ResponseCode()) << "\r\n" - << "Content-Type: " << MapContentTypeToString(m_ContentType) << "\r\n" - << "Content-Length: " << ContentLength() << "\r\n"sv; - - if (!m_IsKeepAlive) - { - m_Headers << "Connection: close\r\n"sv; - } - - m_Headers << "\r\n"sv; - - return m_Headers; - } - - void SuppressPayload() { m_AsioBuffers.resize(1); } - -private: - uint16_t m_ResponseCode = 0; - bool m_IsKeepAlive = true; - HttpContentType m_ContentType = HttpContentType::kBinary; - uint64_t m_ContentLength = 0; - std::vector<IoBuffer> m_DataBuffers; - std::vector<asio::const_buffer> m_AsioBuffers; - ExtendableStringBuilder<160> m_Headers; -}; - -////////////////////////////////////////////////////////////////////////// - -struct HttpServerConnection : std::enable_shared_from_this<HttpServerConnection> -{ - HttpServerConnection(HttpAsioServerImpl& Server, std::unique_ptr<asio::ip::tcp::socket>&& Socket); - ~HttpServerConnection(); - - void HandleNewRequest(); - void TerminateConnection(); - void HandleRequest(); - - std::shared_ptr<HttpServerConnection> AsSharedPtr() { return shared_from_this(); } - -private: - enum class RequestState - { - kInitialState, - kInitialRead, - kReadingMore, - kWriting, - kWritingFinal, - kDone, - kTerminated - }; - - RequestState m_RequestState = RequestState::kInitialState; - HttpRequest m_RequestData{*this}; - - void EnqueueRead(); - void OnDataReceived(const asio::error_code& Ec, std::size_t ByteCount); - void OnResponseDataSent(const asio::error_code& Ec, std::size_t ByteCount, bool Pop = false); - void OnError(); - - HttpAsioServerImpl& m_Server; - asio::streambuf m_RequestBuffer; - std::unique_ptr<asio::ip::tcp::socket> m_Socket; - std::atomic<uint32_t> m_RequestCounter{0}; - uint32_t m_ConnectionId = 0; - Ref<IHttpPackageHandler> m_PackageHandler; - - RwLock m_ResponsesLock; - std::deque<std::unique_ptr<HttpResponse>> m_Responses; -}; - -std::atomic<uint32_t> g_ConnectionIdCounter{0}; - -HttpServerConnection::HttpServerConnection(HttpAsioServerImpl& Server, std::unique_ptr<asio::ip::tcp::socket>&& Socket) -: m_Server(Server) -, m_Socket(std::move(Socket)) -, m_ConnectionId(g_ConnectionIdCounter.fetch_add(1)) -{ - ZEN_TRACE_VERBOSE("new connection #{}", m_ConnectionId); -} - -HttpServerConnection::~HttpServerConnection() -{ - ZEN_TRACE_VERBOSE("destroying connection #{}", m_ConnectionId); -} - -void -HttpServerConnection::HandleNewRequest() -{ - m_RequestData.Initialize(); - - EnqueueRead(); -} - -void -HttpServerConnection::TerminateConnection() -{ - m_RequestState = RequestState::kTerminated; - - std::error_code Ec; - m_Socket->close(Ec); -} - -void -HttpServerConnection::EnqueueRead() -{ - if (m_RequestState == RequestState::kInitialRead) - { - m_RequestState = RequestState::kReadingMore; - } - else - { - m_RequestState = RequestState::kInitialRead; - } - - m_RequestBuffer.prepare(64 * 1024); - - asio::async_read(*m_Socket.get(), - m_RequestBuffer, - asio::transfer_at_least(1), - [Conn = AsSharedPtr()](const asio::error_code& Ec, std::size_t ByteCount) { Conn->OnDataReceived(Ec, ByteCount); }); -} - -void -HttpServerConnection::OnDataReceived(const asio::error_code& Ec, [[maybe_unused]] std::size_t ByteCount) -{ - if (Ec) - { - if (m_RequestState == RequestState::kDone || m_RequestState == RequestState::kInitialRead) - { - ZEN_TRACE_VERBOSE("on data received ERROR (EXPECTED), connection '{}' reason '{}'", m_ConnectionId, Ec.message()); - return; - } - else - { - ZEN_WARN("on data received ERROR, connection '{}' reason '{}'", m_ConnectionId, Ec.message()); - return OnError(); - } - } - - ZEN_TRACE_VERBOSE("on data received, connection '{}', request '{}', thread '{}', bytes '{}'", - m_ConnectionId, - m_RequestCounter.load(std::memory_order_relaxed), - zen::GetCurrentThreadId(), - NiceBytes(ByteCount)); - - while (m_RequestBuffer.size()) - { - const asio::const_buffer& InputBuffer = m_RequestBuffer.data(); - - size_t Result = m_RequestData.ConsumeData((const char*)InputBuffer.data(), InputBuffer.size()); - if (Result == ~0ull) - { - return OnError(); - } - - m_RequestBuffer.consume(Result); - } - - switch (m_RequestState) - { - case RequestState::kDone: - case RequestState::kWritingFinal: - case RequestState::kTerminated: - break; - - default: - EnqueueRead(); - break; - } -} - -void -HttpServerConnection::OnResponseDataSent(const asio::error_code& Ec, [[maybe_unused]] std::size_t ByteCount, bool Pop) -{ - if (Ec) - { - ZEN_WARN("on data sent ERROR, connection '{}' reason '{}'", m_ConnectionId, Ec.message()); - OnError(); - } - else - { - ZEN_TRACE_VERBOSE("on data sent, connection '{}', request '{}', thread '{}', bytes '{}'", - m_ConnectionId, - m_RequestCounter.load(std::memory_order_relaxed), - zen::GetCurrentThreadId(), - NiceBytes(ByteCount)); - - if (!m_RequestData.IsKeepAlive()) - { - m_RequestState = RequestState::kDone; - - m_Socket->close(); - } - else - { - if (Pop) - { - RwLock::ExclusiveLockScope _(m_ResponsesLock); - m_Responses.pop_front(); - } - - m_RequestCounter.fetch_add(1); - } - } -} - -void -HttpServerConnection::OnError() -{ - m_Socket->close(); -} - -void -HttpServerConnection::HandleRequest() -{ - if (!m_RequestData.IsKeepAlive()) - { - m_RequestState = RequestState::kWritingFinal; - - std::error_code Ec; - m_Socket->shutdown(asio::socket_base::shutdown_receive, Ec); - - if (Ec) - { - ZEN_WARN("socket shutdown ERROR, reason '{}'", Ec.message()); - } - } - else - { - m_RequestState = RequestState::kWriting; - } - - if (HttpService* Service = m_Server.RouteRequest(m_RequestData.Url())) - { - HttpAsioServerRequest Request(m_RequestData, *Service, m_RequestData.Body()); - - ZEN_TRACE_VERBOSE("handle request, connection '{}' request '{}'", m_ConnectionId, m_RequestCounter.load(std::memory_order_relaxed)); - - if (!HandlePackageOffers(*Service, Request, m_PackageHandler)) - { - try - { - Service->HandleRequest(Request); - } - catch (std::exception& ex) - { - ZEN_ERROR("Caught exception while handling request: '{}'", ex.what()); - - Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, ex.what()); - } - } - - if (std::unique_ptr<HttpResponse> Response = std::move(Request.m_Response)) - { - // Transmit the response - - if (m_RequestData.RequestVerb() == HttpVerb::kHead) - { - Response->SuppressPayload(); - } - - auto ResponseBuffers = Response->AsioBuffers(); - - uint64_t ResponseLength = 0; - - for (auto& Buffer : ResponseBuffers) - { - ResponseLength += Buffer.size(); - } - - { - RwLock::ExclusiveLockScope _(m_ResponsesLock); - m_Responses.push_back(std::move(Response)); - } - - // TODO: should cork/uncork for Linux? - - asio::async_write(*m_Socket.get(), - ResponseBuffers, - asio::transfer_exactly(ResponseLength), - [Conn = AsSharedPtr()](const asio::error_code& Ec, std::size_t ByteCount) { - Conn->OnResponseDataSent(Ec, ByteCount, true); - }); - - return; - } - } - - if (m_RequestData.RequestVerb() == HttpVerb::kHead) - { - std::string_view Response = - "HTTP/1.1 404 NOT FOUND\r\n" - "\r\n"sv; - - if (!m_RequestData.IsKeepAlive()) - { - Response = - "HTTP/1.1 404 NOT FOUND\r\n" - "Connection: close\r\n" - "\r\n"sv; - } - - asio::async_write( - *m_Socket.get(), - asio::buffer(Response), - [Conn = AsSharedPtr()](const asio::error_code& Ec, std::size_t ByteCount) { Conn->OnResponseDataSent(Ec, ByteCount); }); - } - else - { - std::string_view Response = - "HTTP/1.1 404 NOT FOUND\r\n" - "Content-Length: 23\r\n" - "Content-Type: text/plain\r\n" - "\r\n" - "No suitable route found"sv; - - if (!m_RequestData.IsKeepAlive()) - { - Response = - "HTTP/1.1 404 NOT FOUND\r\n" - "Content-Length: 23\r\n" - "Content-Type: text/plain\r\n" - "Connection: close\r\n" - "\r\n" - "No suitable route found"sv; - } - - asio::async_write( - *m_Socket.get(), - asio::buffer(Response), - [Conn = AsSharedPtr()](const asio::error_code& Ec, std::size_t ByteCount) { Conn->OnResponseDataSent(Ec, ByteCount); }); - } -} - -////////////////////////////////////////////////////////////////////////// -// -// HttpRequest -// - -http_parser_settings HttpRequest::s_ParserSettings{ - .on_message_begin = [](http_parser* p) { return GetThis(p)->OnMessageBegin(); }, - .on_url = [](http_parser* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnUrl(Data, ByteCount); }, - .on_status = - [](http_parser* p, const char* Data, size_t ByteCount) { - ZEN_UNUSED(p, Data, ByteCount); - return 0; - }, - .on_header_field = [](http_parser* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnHeader(Data, ByteCount); }, - .on_header_value = [](http_parser* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnHeaderValue(Data, ByteCount); }, - .on_headers_complete = [](http_parser* p) { return GetThis(p)->OnHeadersComplete(); }, - .on_body = [](http_parser* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnBody(Data, ByteCount); }, - .on_message_complete = [](http_parser* p) { return GetThis(p)->OnMessageComplete(); }, - .on_chunk_header{}, - .on_chunk_complete{}}; - -void -HttpRequest::Initialize() -{ - http_parser_init(&m_Parser, HTTP_REQUEST); - m_Parser.data = this; - - ResetState(); -} - -size_t -HttpRequest::ConsumeData(const char* InputData, size_t DataSize) -{ - const size_t ConsumedBytes = http_parser_execute(&m_Parser, &s_ParserSettings, InputData, DataSize); - - http_errno HttpErrno = HTTP_PARSER_ERRNO((&m_Parser)); - - if (HttpErrno && HttpErrno != HPE_INVALID_EOF_STATE) - { - ZEN_WARN("HTTP parser error {} ('{}'). Closing connection", http_errno_name(HttpErrno), http_errno_description(HttpErrno)); - return ~0ull; - } - - return ConsumedBytes; -} - -int -HttpRequest::OnUrl(const char* Data, size_t Bytes) -{ - if (!m_Url) - { - ZEN_ASSERT_SLOW(m_UrlLength == 0); - m_Url = m_HeaderCursor; - } - - const size_t RemainingBufferSpace = sizeof m_HeaderBuffer + m_HeaderBuffer - m_HeaderCursor; - - if (RemainingBufferSpace < Bytes) - { - ZEN_WARN("HTTP parser does not have enough space for incoming request, need {} more bytes", Bytes - RemainingBufferSpace); - return 1; - } - - memcpy(m_HeaderCursor, Data, Bytes); - m_HeaderCursor += Bytes; - m_UrlLength += Bytes; - - return 0; -} - -int -HttpRequest::OnHeader(const char* Data, size_t Bytes) -{ - if (m_CurrentHeaderValueLength) - { - AppendCurrentHeader(); - - m_CurrentHeaderNameLength = 0; - m_CurrentHeaderValueLength = 0; - m_CurrentHeaderName = m_HeaderCursor; - } - else if (m_CurrentHeaderName == nullptr) - { - m_CurrentHeaderName = m_HeaderCursor; - } - - const size_t RemainingBufferSpace = sizeof m_HeaderBuffer + m_HeaderBuffer - m_HeaderCursor; - if (RemainingBufferSpace < Bytes) - { - ZEN_WARN("HTTP parser does not have enough space for incoming header name, need {} more bytes", Bytes - RemainingBufferSpace); - return 1; - } - - memcpy(m_HeaderCursor, Data, Bytes); - m_HeaderCursor += Bytes; - m_CurrentHeaderNameLength += Bytes; - - return 0; -} - -void -HttpRequest::AppendCurrentHeader() -{ - std::string_view HeaderName(m_CurrentHeaderName, m_CurrentHeaderNameLength); - std::string_view HeaderValue(m_CurrentHeaderValue, m_CurrentHeaderValueLength); - - const uint32_t HeaderHash = HashStringAsLowerDjb2(HeaderName); - - if (HeaderHash == HashContentLength) - { - m_ContentLengthHeaderIndex = (int8_t)m_Headers.size(); - } - else if (HeaderHash == HashAccept) - { - m_AcceptHeaderIndex = (int8_t)m_Headers.size(); - } - else if (HeaderHash == HashContentType) - { - m_ContentTypeHeaderIndex = (int8_t)m_Headers.size(); - } - else if (HeaderHash == HashSession) - { - m_SessionId = Oid::FromHexString(HeaderValue); - } - else if (HeaderHash == HashRequest) - { - std::from_chars(HeaderValue.data(), HeaderValue.data() + HeaderValue.size(), m_RequestId); - } - else if (HeaderHash == HashExpect) - { - if (HeaderValue == "100-continue"sv) - { - // We don't currently do anything with this - m_Expect100Continue = true; - } - else - { - ZEN_INFO("Unexpected expect - Expect: {}", HeaderValue); - } - } - else if (HeaderHash == HashRange) - { - m_RangeHeaderIndex = (int8_t)m_Headers.size(); - } - - m_Headers.emplace_back(HeaderName, HeaderValue); -} - -int -HttpRequest::OnHeaderValue(const char* Data, size_t Bytes) -{ - if (m_CurrentHeaderValueLength == 0) - { - m_CurrentHeaderValue = m_HeaderCursor; - } - - const size_t RemainingBufferSpace = sizeof m_HeaderBuffer + m_HeaderBuffer - m_HeaderCursor; - if (RemainingBufferSpace < Bytes) - { - ZEN_WARN("HTTP parser does not have enough space for incoming header value, need {} more bytes", Bytes - RemainingBufferSpace); - return 1; - } - - memcpy(m_HeaderCursor, Data, Bytes); - m_HeaderCursor += Bytes; - m_CurrentHeaderValueLength += Bytes; - - return 0; -} - -static void -NormalizeUrlPath(const char* Url, size_t UrlLength, std::string& NormalizedUrl) -{ - bool LastCharWasSeparator = false; - for (std::string_view::size_type UrlIndex = 0; UrlIndex < UrlLength; ++UrlIndex) - { - const char UrlChar = Url[UrlIndex]; - const bool IsSeparator = (UrlChar == '/'); - - if (IsSeparator && LastCharWasSeparator) - { - if (NormalizedUrl.empty()) - { - NormalizedUrl.reserve(UrlLength); - NormalizedUrl.append(Url, UrlIndex); - } - - if (!LastCharWasSeparator) - { - NormalizedUrl.push_back('/'); - } - } - else if (!NormalizedUrl.empty()) - { - NormalizedUrl.push_back(UrlChar); - } - - LastCharWasSeparator = IsSeparator; - } -} - -int -HttpRequest::OnHeadersComplete() -{ - if (m_CurrentHeaderValueLength) - { - AppendCurrentHeader(); - } - - if (m_ContentLengthHeaderIndex >= 0) - { - std::string_view& Value = m_Headers[m_ContentLengthHeaderIndex].Value; - uint64_t ContentLength = 0; - std::from_chars(Value.data(), Value.data() + Value.size(), ContentLength); - - if (ContentLength) - { - m_BodyBuffer = IoBuffer(ContentLength); - } - - m_BodyBuffer.SetContentType(ContentType()); - - m_BodyPosition = 0; - } - - m_KeepAlive = !!http_should_keep_alive(&m_Parser); - - switch (m_Parser.method) - { - case HTTP_GET: - m_RequestVerb = HttpVerb::kGet; - break; - - case HTTP_POST: - m_RequestVerb = HttpVerb::kPost; - break; - - case HTTP_PUT: - m_RequestVerb = HttpVerb::kPut; - break; - - case HTTP_DELETE: - m_RequestVerb = HttpVerb::kDelete; - break; - - case HTTP_HEAD: - m_RequestVerb = HttpVerb::kHead; - break; - - case HTTP_COPY: - m_RequestVerb = HttpVerb::kCopy; - break; - - case HTTP_OPTIONS: - m_RequestVerb = HttpVerb::kOptions; - break; - - default: - ZEN_WARN("invalid HTTP method: '{}'", http_method_str((http_method)m_Parser.method)); - break; - } - - std::string_view Url(m_Url, m_UrlLength); - - if (auto QuerySplit = Url.find_first_of('?'); QuerySplit != std::string_view::npos) - { - m_UrlLength = QuerySplit; - m_QueryString = m_Url + QuerySplit + 1; - m_QueryLength = Url.size() - QuerySplit - 1; - } - - NormalizeUrlPath(m_Url, m_UrlLength, m_NormalizedUrl); - - return 0; -} - -int -HttpRequest::OnBody(const char* Data, size_t Bytes) -{ - if (m_BodyPosition + Bytes > m_BodyBuffer.Size()) - { - ZEN_WARN("HTTP parser incoming body is larger than content size, need {} more bytes", - (m_BodyPosition + Bytes) - m_BodyBuffer.Size()); - return 1; - } - memcpy(reinterpret_cast<uint8_t*>(m_BodyBuffer.MutableData()) + m_BodyPosition, Data, Bytes); - m_BodyPosition += Bytes; - - if (http_body_is_final(&m_Parser)) - { - if (m_BodyPosition != m_BodyBuffer.Size()) - { - ZEN_WARN("Body mismatch! {} != {}", m_BodyPosition, m_BodyBuffer.Size()); - return 1; - } - } - - return 0; -} - -void -HttpRequest::ResetState() -{ - m_HeaderCursor = m_HeaderBuffer; - m_CurrentHeaderName = nullptr; - m_CurrentHeaderNameLength = 0; - m_CurrentHeaderValue = nullptr; - m_CurrentHeaderValueLength = 0; - m_CurrentHeaderName = nullptr; - m_Url = nullptr; - m_UrlLength = 0; - m_QueryString = nullptr; - m_QueryLength = 0; - m_ContentLengthHeaderIndex = -1; - m_AcceptHeaderIndex = -1; - m_ContentTypeHeaderIndex = -1; - m_RangeHeaderIndex = -1; - m_Expect100Continue = false; - m_BodyBuffer = {}; - m_BodyPosition = 0; - m_Headers.clear(); - m_NormalizedUrl.clear(); -} - -int -HttpRequest::OnMessageBegin() -{ - return 0; -} - -int -HttpRequest::OnMessageComplete() -{ - m_Connection.HandleRequest(); - - ResetState(); - - return 0; -} - -////////////////////////////////////////////////////////////////////////// - -struct HttpAcceptor -{ - HttpAcceptor(HttpAsioServerImpl& Server, asio::io_service& IoService, uint16_t BasePort) - : m_Server(Server) - , m_IoService(IoService) - , m_Acceptor(m_IoService, asio::ip::tcp::v6()) - { - m_Acceptor.set_option(asio::ip::v6_only(false)); -#if ZEN_PLATFORM_WINDOWS - // Special option for Windows settings as !asio::socket_base::reuse_address is not the same as exclusive access on Windows platforms - typedef asio::detail::socket_option::boolean<ASIO_OS_DEF(SOL_SOCKET), SO_EXCLUSIVEADDRUSE> excluse_address; - m_Acceptor.set_option(excluse_address(true)); -#else // ZEN_PLATFORM_WINDOWS - m_Acceptor.set_option(asio::socket_base::reuse_address(false)); -#endif // ZEN_PLATFORM_WINDOWS - - m_Acceptor.set_option(asio::ip::tcp::no_delay(true)); - m_Acceptor.set_option(asio::socket_base::receive_buffer_size(128 * 1024)); - m_Acceptor.set_option(asio::socket_base::send_buffer_size(256 * 1024)); - - uint16_t EffectivePort = BasePort; - - asio::error_code BindErrorCode; - m_Acceptor.bind(asio::ip::tcp::endpoint(asio::ip::address_v6::any(), EffectivePort), BindErrorCode); - // Sharing violation implies the port is being used by another process - for (uint16_t PortOffset = 1; (BindErrorCode == asio::error::address_in_use) && (PortOffset < 10); ++PortOffset) - { - EffectivePort = BasePort + (PortOffset * 100); - m_Acceptor.bind(asio::ip::tcp::endpoint(asio::ip::address_v6::any(), EffectivePort), BindErrorCode); - } - if (BindErrorCode == asio::error::access_denied) - { - EffectivePort = 0; - m_Acceptor.bind(asio::ip::tcp::endpoint(asio::ip::address_v6::any(), EffectivePort), BindErrorCode); - } - if (BindErrorCode) - { - ZEN_ERROR("Unable open asio service, error '{}'", BindErrorCode.message()); - } - -#if ZEN_PLATFORM_WINDOWS - // On Windows, loopback connections can take advantage of a faster code path optionally with this flag. - // This must be used by both the client and server side, and is only effective in the absence of - // Windows Filtering Platform (WFP) callouts which can be installed by security software. - // https://docs.microsoft.com/en-us/windows/win32/winsock/sio-loopback-fast-path - SOCKET NativeSocket = m_Acceptor.native_handle(); - int LoopbackOptionValue = 1; - DWORD OptionNumberOfBytesReturned = 0; - WSAIoctl(NativeSocket, - SIO_LOOPBACK_FAST_PATH, - &LoopbackOptionValue, - sizeof(LoopbackOptionValue), - NULL, - 0, - &OptionNumberOfBytesReturned, - 0, - 0); -#endif - m_Acceptor.listen(); - - ZEN_INFO("Started asio server at port '{}'", EffectivePort); - } - - void Start() - { - m_Acceptor.listen(); - InitAccept(); - } - - void Stop() { m_IsStopped = true; } - - void InitAccept() - { - auto SocketPtr = std::make_unique<asio::ip::tcp::socket>(m_IoService); - asio::ip::tcp::socket& SocketRef = *SocketPtr.get(); - - m_Acceptor.async_accept(SocketRef, [this, Socket = std::move(SocketPtr)](const asio::error_code& Ec) mutable { - if (Ec) - { - ZEN_WARN("asio async_accept, connection failed to '{}:{}' reason '{}'", - m_Acceptor.local_endpoint().address().to_string(), - m_Acceptor.local_endpoint().port(), - Ec.message()); - } - else - { - // New connection established, pass socket ownership into connection object - // and initiate request handling loop. The connection lifetime is - // managed by the async read/write loop by passing the shared - // reference to the callbacks. - - Socket->set_option(asio::ip::tcp::no_delay(true)); - Socket->set_option(asio::socket_base::receive_buffer_size(128 * 1024)); - Socket->set_option(asio::socket_base::send_buffer_size(256 * 1024)); - - auto Conn = std::make_shared<HttpServerConnection>(m_Server, std::move(Socket)); - Conn->HandleNewRequest(); - } - - if (!m_IsStopped.load()) - { - InitAccept(); - } - else - { - m_Acceptor.close(); - } - }); - } - - int GetAcceptPort() { return m_Acceptor.local_endpoint().port(); } - -private: - HttpAsioServerImpl& m_Server; - asio::io_service& m_IoService; - asio::ip::tcp::acceptor m_Acceptor; - std::atomic<bool> m_IsStopped{false}; -}; - -////////////////////////////////////////////////////////////////////////// - -HttpAsioServerRequest::HttpAsioServerRequest(asio_http::HttpRequest& Request, HttpService& Service, IoBuffer PayloadBuffer) -: m_Request(Request) -, m_PayloadBuffer(std::move(PayloadBuffer)) -{ - const int PrefixLength = Service.UriPrefixLength(); - - std::string_view Uri = Request.Url(); - Uri.remove_prefix(std::min(PrefixLength, static_cast<int>(Uri.size()))); - m_Uri = Uri; - m_UriWithExtension = Uri; - m_QueryString = Request.QueryString(); - - m_Verb = Request.RequestVerb(); - m_ContentLength = Request.Body().Size(); - m_ContentType = Request.ContentType(); - - HttpContentType AcceptContentType = HttpContentType::kUnknownContentType; - - // Parse any extension, to allow requesting a particular response encoding via the URL - - { - std::string_view UriSuffix8{m_Uri}; - - const size_t LastComponentIndex = UriSuffix8.find_last_of('/'); - - if (LastComponentIndex != std::string_view::npos) - { - UriSuffix8.remove_prefix(LastComponentIndex); - } - - const size_t LastDotIndex = UriSuffix8.find_last_of('.'); - - if (LastDotIndex != std::string_view::npos) - { - UriSuffix8.remove_prefix(LastDotIndex + 1); - - AcceptContentType = ParseContentType(UriSuffix8); - - if (AcceptContentType != HttpContentType::kUnknownContentType) - { - m_Uri.remove_suffix(uint32_t(UriSuffix8.size() + 1)); - } - } - } - - // It an explicit content type extension was specified then we'll use that over any - // Accept: header value that may be present - - if (AcceptContentType != HttpContentType::kUnknownContentType) - { - m_AcceptType = AcceptContentType; - } - else - { - m_AcceptType = Request.AcceptType(); - } -} - -HttpAsioServerRequest::~HttpAsioServerRequest() -{ -} - -Oid -HttpAsioServerRequest::ParseSessionId() const -{ - return m_Request.SessionId(); -} - -uint32_t -HttpAsioServerRequest::ParseRequestId() const -{ - return m_Request.RequestId(); -} - -IoBuffer -HttpAsioServerRequest::ReadPayload() -{ - return m_PayloadBuffer; -} - -void -HttpAsioServerRequest::WriteResponse(HttpResponseCode ResponseCode) -{ - ZEN_ASSERT(!m_Response); - - m_Response.reset(new HttpResponse(HttpContentType::kBinary)); - std::array<IoBuffer, 0> Empty; - - m_Response->InitializeForPayload((uint16_t)ResponseCode, Empty); -} - -void -HttpAsioServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span<IoBuffer> Blobs) -{ - ZEN_ASSERT(!m_Response); - - m_Response.reset(new HttpResponse(ContentType)); - m_Response->InitializeForPayload((uint16_t)ResponseCode, Blobs); -} - -void -HttpAsioServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::u8string_view ResponseString) -{ - ZEN_ASSERT(!m_Response); - m_Response.reset(new HttpResponse(ContentType)); - - IoBuffer MessageBuffer(IoBuffer::Wrap, ResponseString.data(), ResponseString.size()); - std::array<IoBuffer, 1> SingleBufferList({MessageBuffer}); - - m_Response->InitializeForPayload((uint16_t)ResponseCode, SingleBufferList); -} - -void -HttpAsioServerRequest::WriteResponseAsync(std::function<void(HttpServerRequest&)>&& ContinuationHandler) -{ - ZEN_ASSERT(!m_Response); - - // Not one bit async, innit - ContinuationHandler(*this); -} - -bool -HttpAsioServerRequest::TryGetRanges(HttpRanges& Ranges) -{ - return TryParseHttpRangeHeader(m_Request.RangeHeader(), Ranges); -} - -////////////////////////////////////////////////////////////////////////// - -HttpAsioServerImpl::HttpAsioServerImpl() -{ -} - -HttpAsioServerImpl::~HttpAsioServerImpl() -{ -} - -int -HttpAsioServerImpl::Start(uint16_t Port, int ThreadCount) -{ - ZEN_ASSERT(ThreadCount > 0); - - ZEN_INFO("starting asio http with {} service threads", ThreadCount); - - m_Acceptor.reset(new asio_http::HttpAcceptor(*this, m_IoService, Port)); - m_Acceptor->Start(); - - for (int i = 0; i < ThreadCount; ++i) - { - m_ThreadPool.emplace_back([this, Index = i + 1] { - SetCurrentThreadName(fmt::format("asio worker {}", Index)); - - try - { - m_IoService.run(); - } - catch (std::exception& e) - { - ZEN_ERROR("Exception caught in asio event loop: '{}'", e.what()); - } - }); - } - - ZEN_INFO("asio http started (port {})", m_Acceptor->GetAcceptPort()); - - return m_Acceptor->GetAcceptPort(); -} - -void -HttpAsioServerImpl::Stop() -{ - m_Acceptor->Stop(); - m_IoService.stop(); - for (auto& Thread : m_ThreadPool) - { - Thread.join(); - } -} - -void -HttpAsioServerImpl::RegisterService(const char* InUrlPath, HttpService& Service) -{ - std::string_view UrlPath(InUrlPath); - Service.SetUriPrefixLength(UrlPath.size()); - if (!UrlPath.empty() && UrlPath.back() == '/') - { - UrlPath.remove_suffix(1); - } - - RwLock::ExclusiveLockScope _(m_Lock); - m_UriHandlers.push_back({std::string(UrlPath), &Service}); -} - -HttpService* -HttpAsioServerImpl::RouteRequest(std::string_view Url) -{ - RwLock::SharedLockScope _(m_Lock); - - HttpService* CandidateService = nullptr; - std::string::size_type CandidateMatchSize = 0; - for (const ServiceEntry& SvcEntry : m_UriHandlers) - { - const std::string& SvcUrl = SvcEntry.ServiceUrlPath; - const std::string::size_type SvcUrlSize = SvcUrl.size(); - if ((SvcUrlSize >= CandidateMatchSize) && Url.compare(0, SvcUrlSize, SvcUrl) == 0 && - ((SvcUrlSize == Url.size()) || (Url[SvcUrlSize] == '/'))) - { - CandidateMatchSize = SvcUrl.size(); - CandidateService = SvcEntry.Service; - } - } - - return CandidateService; -} - -} // namespace zen::asio_http - -////////////////////////////////////////////////////////////////////////// - -namespace zen { -HttpAsioServer::HttpAsioServer() : m_Impl(std::make_unique<asio_http::HttpAsioServerImpl>()) -{ - ZEN_DEBUG("Request object size: {} ({:#x})", sizeof(asio_http::HttpRequest), sizeof(asio_http::HttpRequest)); -} - -HttpAsioServer::~HttpAsioServer() -{ - try - { - m_Impl->Stop(); - } - catch (std::exception& ex) - { - ZEN_WARN("Caught exception stopping http asio server: {}", ex.what()); - } -} - -void -HttpAsioServer::RegisterService(HttpService& Service) -{ - m_Impl->RegisterService(Service.BaseUri(), Service); -} - -int -HttpAsioServer::Initialize(int BasePort) -{ - m_BasePort = m_Impl->Start(gsl::narrow<uint16_t>(BasePort), Max(std::thread::hardware_concurrency(), 8u)); - return m_BasePort; -} - -void -HttpAsioServer::Run(bool IsInteractive) -{ - const bool TestMode = !IsInteractive; - - int WaitTimeout = -1; - if (!TestMode) - { - WaitTimeout = 1000; - } - -#if ZEN_PLATFORM_WINDOWS - if (TestMode == false) - { - zen::logging::ConsoleLog().info("Zen Server running (asio HTTP). Press ESC or Q to quit"); - } - - do - { - if (!TestMode && _kbhit() != 0) - { - char c = (char)_getch(); - - if (c == 27 || c == 'Q' || c == 'q') - { - RequestApplicationExit(0); - } - } - - m_ShutdownEvent.Wait(WaitTimeout); - } while (!IsApplicationExitRequested()); -#else - if (TestMode == false) - { - zen::logging::ConsoleLog().info("Zen Server running (asio HTTP). Ctrl-C to quit"); - } - - do - { - m_ShutdownEvent.Wait(WaitTimeout); - } while (!IsApplicationExitRequested()); -#endif -} - -void -HttpAsioServer::RequestExit() -{ - m_ShutdownEvent.Set(); -} - -} // namespace zen diff --git a/zenhttp/httpasio.h b/zenhttp/httpasio.h deleted file mode 100644 index 716145955..000000000 --- a/zenhttp/httpasio.h +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. - -#pragma once - -#include <zencore/thread.h> -#include <zenhttp/httpserver.h> - -#include <memory> - -namespace zen { - -namespace asio_http { - struct HttpServerConnection; - struct HttpAcceptor; - struct HttpAsioServerImpl; -} // namespace asio_http - -class HttpAsioServer : public HttpServer -{ -public: - HttpAsioServer(); - ~HttpAsioServer(); - - virtual void RegisterService(HttpService& Service) override; - virtual int Initialize(int BasePort) override; - virtual void Run(bool IsInteractiveSession) override; - virtual void RequestExit() override; - -private: - Event m_ShutdownEvent; - int m_BasePort = 0; - - std::unique_ptr<asio_http::HttpAsioServerImpl> m_Impl; -}; - -} // namespace zen diff --git a/zenhttp/httpclient.cpp b/zenhttp/httpclient.cpp deleted file mode 100644 index e6813d407..000000000 --- a/zenhttp/httpclient.cpp +++ /dev/null @@ -1,176 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. - -#include <zenhttp/httpclient.h> -#include <zenhttp/httpserver.h> - -#include <zencore/compactbinarybuilder.h> -#include <zencore/compactbinarypackage.h> -#include <zencore/iobuffer.h> -#include <zencore/logging.h> -#include <zencore/session.h> -#include <zencore/sharedbuffer.h> -#include <zencore/stream.h> -#include <zencore/testing.h> -#include <zenhttp/httpshared.h> - -static std::atomic<uint32_t> HttpClientRequestIdCounter{0}; - -namespace zen { - -using namespace std::literals; - -HttpClient::Response -FromCprResponse(cpr::Response& InResponse) -{ - return {.StatusCode = int(InResponse.status_code)}; -} - -////////////////////////////////////////////////////////////////////////// - -HttpClient::HttpClient(std::string_view BaseUri) : m_BaseUri(BaseUri) -{ - StringBuilder<32> SessionId; - GetSessionId().ToString(SessionId); - m_SessionId = SessionId; -} - -HttpClient::~HttpClient() -{ -} - -HttpClient::Response -HttpClient::TransactPackage(std::string_view Url, CbPackage Package) -{ - cpr::Session Sess; - Sess.SetUrl(m_BaseUri + std::string(Url)); - - // First, list of offered chunks for filtering on the server end - - std::vector<IoHash> AttachmentsToSend; - std::span<const CbAttachment> Attachments = Package.GetAttachments(); - - const uint32_t RequestId = ++HttpClientRequestIdCounter; - auto RequestIdString = fmt::to_string(RequestId); - - if (Attachments.empty() == false) - { - CbObjectWriter Writer; - Writer.BeginArray("offer"); - - for (const CbAttachment& Attachment : Attachments) - { - IoHash Hash = Attachment.GetHash(); - - Writer.AddHash(Hash); - } - - Writer.EndArray(); - - BinaryWriter MemWriter; - Writer.Save(MemWriter); - - Sess.SetHeader({{"Content-Type", "application/x-ue-offer"}, {"UE-Session", m_SessionId}, {"UE-Request", RequestIdString}}); - Sess.SetBody(cpr::Body{(const char*)MemWriter.Data(), MemWriter.Size()}); - - cpr::Response FilterResponse = Sess.Post(); - - if (FilterResponse.status_code == 200) - { - IoBuffer ResponseBuffer(IoBuffer::Wrap, FilterResponse.text.data(), FilterResponse.text.size()); - CbObject ResponseObject = LoadCompactBinaryObject(ResponseBuffer); - - for (auto& Entry : ResponseObject["need"]) - { - ZEN_ASSERT(Entry.IsHash()); - AttachmentsToSend.push_back(Entry.AsHash()); - } - } - } - - // Prepare package for send - - CbPackage SendPackage; - SendPackage.SetObject(Package.GetObject(), Package.GetObjectHash()); - - for (const IoHash& AttachmentCid : AttachmentsToSend) - { - const CbAttachment* Attachment = Package.FindAttachment(AttachmentCid); - - if (Attachment) - { - SendPackage.AddAttachment(*Attachment); - } - else - { - // This should be an error -- server asked to have something we can't find - } - } - - // Transmit package payload - - CompositeBuffer Message = FormatPackageMessageBuffer(SendPackage); - SharedBuffer FlatMessage = Message.Flatten(); - - Sess.SetHeader({{"Content-Type", "application/x-ue-cbpkg"}, {"UE-Session", m_SessionId}, {"UE-Request", RequestIdString}}); - Sess.SetBody(cpr::Body{(const char*)FlatMessage.GetData(), FlatMessage.GetSize()}); - - cpr::Response FilterResponse = Sess.Post(); - - if (!IsHttpSuccessCode(FilterResponse.status_code)) - { - return FromCprResponse(FilterResponse); - } - - IoBuffer ResponseBuffer(IoBuffer::Clone, FilterResponse.text.data(), FilterResponse.text.size()); - - if (auto It = FilterResponse.header.find("Content-Type"); It != FilterResponse.header.end()) - { - HttpContentType ContentType = ParseContentType(It->second); - - ResponseBuffer.SetContentType(ContentType); - } - - return {.StatusCode = int(FilterResponse.status_code), .ResponsePayload = ResponseBuffer}; -} - -HttpClient::Response -HttpClient::Put(std::string_view Url, IoBuffer Payload) -{ - ZEN_UNUSED(Url); - ZEN_UNUSED(Payload); - return {}; -} - -HttpClient::Response -HttpClient::Get(std::string_view Url) -{ - ZEN_UNUSED(Url); - return {}; -} - -HttpClient::Response -HttpClient::Delete(std::string_view Url) -{ - ZEN_UNUSED(Url); - return {}; -} - -////////////////////////////////////////////////////////////////////////// - -#if ZEN_WITH_TESTS - -TEST_CASE("httpclient") -{ - using namespace std::literals; - - SUBCASE("client") {} -} - -void -httpclient_forcelink() -{ -} - -#endif - -} // namespace zen diff --git a/zenhttp/httpnull.cpp b/zenhttp/httpnull.cpp deleted file mode 100644 index a6e1d3567..000000000 --- a/zenhttp/httpnull.cpp +++ /dev/null @@ -1,83 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. - -#include "httpnull.h" - -#include <zencore/logging.h> - -#if ZEN_PLATFORM_WINDOWS -# include <conio.h> -#endif - -namespace zen { - -HttpNullServer::HttpNullServer() -{ -} - -HttpNullServer::~HttpNullServer() -{ -} - -void -HttpNullServer::RegisterService(HttpService& Service) -{ - ZEN_UNUSED(Service); -} - -int -HttpNullServer::Initialize(int BasePort) -{ - return BasePort; -} - -void -HttpNullServer::Run(bool IsInteractiveSession) -{ - const bool TestMode = !IsInteractiveSession; - - int WaitTimeout = -1; - if (!TestMode) - { - WaitTimeout = 1000; - } - -#if ZEN_PLATFORM_WINDOWS - if (TestMode == false) - { - zen::logging::ConsoleLog().info("Zen Server running (null HTTP). Press ESC or Q to quit"); - } - - do - { - if (!TestMode && _kbhit() != 0) - { - char c = (char)_getch(); - - if (c == 27 || c == 'Q' || c == 'q') - { - RequestApplicationExit(0); - } - } - - m_ShutdownEvent.Wait(WaitTimeout); - } while (!IsApplicationExitRequested()); -#else - if (TestMode == false) - { - zen::logging::ConsoleLog().info("Zen Server running (null HTTP). Ctrl-C to quit"); - } - - do - { - m_ShutdownEvent.Wait(WaitTimeout); - } while (!IsApplicationExitRequested()); -#endif -} - -void -HttpNullServer::RequestExit() -{ - m_ShutdownEvent.Set(); -} - -} // namespace zen diff --git a/zenhttp/httpnull.h b/zenhttp/httpnull.h deleted file mode 100644 index 74f021f6b..000000000 --- a/zenhttp/httpnull.h +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. - -#pragma once - -#include <zencore/thread.h> -#include <zenhttp/httpserver.h> - -namespace zen { - -/** - * @brief Null implementation of "http" server. Does nothing - */ - -class HttpNullServer : public HttpServer -{ -public: - HttpNullServer(); - ~HttpNullServer(); - - virtual void RegisterService(HttpService& Service) override; - virtual int Initialize(int BasePort) override; - virtual void Run(bool IsInteractiveSession) override; - virtual void RequestExit() override; - -private: - Event m_ShutdownEvent; -}; - -} // namespace zen diff --git a/zenhttp/httpserver.cpp b/zenhttp/httpserver.cpp deleted file mode 100644 index 671cbd319..000000000 --- a/zenhttp/httpserver.cpp +++ /dev/null @@ -1,885 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. - -#include <zenhttp/httpserver.h> - -#include "httpasio.h" -#include "httpnull.h" -#include "httpsys.h" - -#include <zencore/compactbinary.h> -#include <zencore/compactbinarybuilder.h> -#include <zencore/compactbinarypackage.h> -#include <zencore/iobuffer.h> -#include <zencore/logging.h> -#include <zencore/refcount.h> -#include <zencore/stream.h> -#include <zencore/string.h> -#include <zencore/testing.h> -#include <zencore/thread.h> -#include <zenhttp/httpshared.h> - -#include <charconv> -#include <mutex> -#include <span> -#include <string_view> - -namespace zen { - -using namespace std::literals; - -std::string_view -MapContentTypeToString(HttpContentType ContentType) -{ - switch (ContentType) - { - default: - case HttpContentType::kUnknownContentType: - case HttpContentType::kBinary: - return "application/octet-stream"sv; - - case HttpContentType::kText: - return "text/plain"sv; - - case HttpContentType::kJSON: - return "application/json"sv; - - case HttpContentType::kCbObject: - return "application/x-ue-cb"sv; - - case HttpContentType::kCbPackage: - return "application/x-ue-cbpkg"sv; - - case HttpContentType::kCbPackageOffer: - return "application/x-ue-offer"sv; - - case HttpContentType::kCompressedBinary: - return "application/x-ue-comp"sv; - - case HttpContentType::kYAML: - return "text/yaml"sv; - - case HttpContentType::kHTML: - return "text/html"sv; - - case HttpContentType::kJavaScript: - return "application/javascript"sv; - - case HttpContentType::kCSS: - return "text/css"sv; - - case HttpContentType::kPNG: - return "image/png"sv; - - case HttpContentType::kIcon: - return "image/x-icon"sv; - } -} - -////////////////////////////////////////////////////////////////////////// -// -// Note that in addition to MIME types we accept abbreviated versions, for -// use in suffix parsing as well as for convenience when using curl - -static constinit uint32_t HashBinary = HashStringDjb2("application/octet-stream"sv); -static constinit uint32_t HashJson = HashStringDjb2("json"sv); -static constinit uint32_t HashApplicationJson = HashStringDjb2("application/json"sv); -static constinit uint32_t HashYaml = HashStringDjb2("yaml"sv); -static constinit uint32_t HashTextYaml = HashStringDjb2("text/yaml"sv); -static constinit uint32_t HashText = HashStringDjb2("text/plain"sv); -static constinit uint32_t HashApplicationCompactBinary = HashStringDjb2("application/x-ue-cb"sv); -static constinit uint32_t HashCompactBinary = HashStringDjb2("ucb"sv); -static constinit uint32_t HashCompactBinaryPackage = HashStringDjb2("application/x-ue-cbpkg"sv); -static constinit uint32_t HashCompactBinaryPackageShort = HashStringDjb2("cbpkg"sv); -static constinit uint32_t HashCompactBinaryPackageOffer = HashStringDjb2("application/x-ue-offer"sv); -static constinit uint32_t HashCompressedBinary = HashStringDjb2("application/x-ue-comp"sv); -static constinit uint32_t HashHtml = HashStringDjb2("html"sv); -static constinit uint32_t HashTextHtml = HashStringDjb2("text/html"sv); -static constinit uint32_t HashJavaScript = HashStringDjb2("js"sv); -static constinit uint32_t HashApplicationJavaScript = HashStringDjb2("application/javascript"sv); -static constinit uint32_t HashCss = HashStringDjb2("css"sv); -static constinit uint32_t HashTextCss = HashStringDjb2("text/css"sv); -static constinit uint32_t HashPng = HashStringDjb2("png"sv); -static constinit uint32_t HashImagePng = HashStringDjb2("image/png"sv); -static constinit uint32_t HashIcon = HashStringDjb2("ico"sv); -static constinit uint32_t HashImageIcon = HashStringDjb2("image/x-icon"sv); - -std::once_flag InitContentTypeLookup; - -struct HashedTypeEntry -{ - uint32_t Hash; - HttpContentType Type; -} TypeHashTable[] = { - // clang-format off - {HashBinary, HttpContentType::kBinary}, - {HashApplicationCompactBinary, HttpContentType::kCbObject}, - {HashCompactBinary, HttpContentType::kCbObject}, - {HashCompactBinaryPackage, HttpContentType::kCbPackage}, - {HashCompactBinaryPackageShort, HttpContentType::kCbPackage}, - {HashCompactBinaryPackageOffer, HttpContentType::kCbPackageOffer}, - {HashJson, HttpContentType::kJSON}, - {HashApplicationJson, HttpContentType::kJSON}, - {HashYaml, HttpContentType::kYAML}, - {HashTextYaml, HttpContentType::kYAML}, - {HashText, HttpContentType::kText}, - {HashCompressedBinary, HttpContentType::kCompressedBinary}, - {HashHtml, HttpContentType::kHTML}, - {HashTextHtml, HttpContentType::kHTML}, - {HashJavaScript, HttpContentType::kJavaScript}, - {HashApplicationJavaScript, HttpContentType::kJavaScript}, - {HashCss, HttpContentType::kCSS}, - {HashTextCss, HttpContentType::kCSS}, - {HashPng, HttpContentType::kPNG}, - {HashImagePng, HttpContentType::kPNG}, - {HashIcon, HttpContentType::kIcon}, - {HashImageIcon, HttpContentType::kIcon}, - // clang-format on -}; - -HttpContentType -ParseContentTypeImpl(const std::string_view& ContentTypeString) -{ - if (!ContentTypeString.empty()) - { - const uint32_t CtHash = HashStringDjb2(ContentTypeString); - - if (auto It = std::lower_bound(std::begin(TypeHashTable), - std::end(TypeHashTable), - CtHash, - [](const HashedTypeEntry& Lhs, const uint32_t Rhs) { return Lhs.Hash < Rhs; }); - It != std::end(TypeHashTable)) - { - if (It->Hash == CtHash) - { - return It->Type; - } - } - } - - return HttpContentType::kUnknownContentType; -} - -HttpContentType -ParseContentTypeInit(const std::string_view& ContentTypeString) -{ - std::call_once(InitContentTypeLookup, [] { - std::sort(std::begin(TypeHashTable), std::end(TypeHashTable), [](const HashedTypeEntry& Lhs, const HashedTypeEntry& Rhs) { - return Lhs.Hash < Rhs.Hash; - }); - - // validate that there are no hash collisions - - uint32_t LastHash = 0; - - for (const auto& Item : TypeHashTable) - { - ZEN_ASSERT(LastHash != Item.Hash); - LastHash = Item.Hash; - } - }); - - ParseContentType = ParseContentTypeImpl; - - return ParseContentTypeImpl(ContentTypeString); -} - -HttpContentType (*ParseContentType)(const std::string_view& ContentTypeString) = &ParseContentTypeInit; - -bool -TryParseHttpRangeHeader(std::string_view RangeHeader, HttpRanges& Ranges) -{ - if (RangeHeader.empty()) - { - return false; - } - - const size_t Count = Ranges.size(); - - std::size_t UnitDelim = RangeHeader.find_first_of('='); - if (UnitDelim == std::string_view::npos) - { - return false; - } - - // only bytes for now - std::string_view Unit = RangeHeader.substr(0, UnitDelim); - if (Unit != "bytes"sv) - { - return false; - } - - std::string_view Tokens = RangeHeader.substr(UnitDelim); - while (!Tokens.empty()) - { - // Skip =, - Tokens = Tokens.substr(1); - - size_t Delim = Tokens.find_first_of(','); - if (Delim == std::string_view::npos) - { - Delim = Tokens.length(); - } - - std::string_view Token = Tokens.substr(0, Delim); - Tokens = Tokens.substr(Delim); - - Delim = Token.find_first_of('-'); - if (Delim == std::string_view::npos) - { - return false; - } - - const auto Start = ParseInt<uint32_t>(Token.substr(0, Delim)); - const auto End = ParseInt<uint32_t>(Token.substr(Delim + 1)); - - if (Start.has_value() && End.has_value() && End.value() > Start.value()) - { - Ranges.push_back({.Start = Start.value(), .End = End.value()}); - } - else if (Start) - { - Ranges.push_back({.Start = Start.value()}); - } - else if (End) - { - Ranges.push_back({.End = End.value()}); - } - } - - return Count != Ranges.size(); -} - -////////////////////////////////////////////////////////////////////////// - -const std::string_view -ToString(HttpVerb Verb) -{ - switch (Verb) - { - case HttpVerb::kGet: - return "GET"sv; - case HttpVerb::kPut: - return "PUT"sv; - case HttpVerb::kPost: - return "POST"sv; - case HttpVerb::kDelete: - return "DELETE"sv; - case HttpVerb::kHead: - return "HEAD"sv; - case HttpVerb::kCopy: - return "COPY"sv; - case HttpVerb::kOptions: - return "OPTIONS"sv; - default: - return "???"sv; - } -} - -std::string_view -ReasonStringForHttpResultCode(int HttpCode) -{ - switch (HttpCode) - { - // 1xx Informational - - case 100: - return "Continue"sv; - case 101: - return "Switching Protocols"sv; - - // 2xx Success - - case 200: - return "OK"sv; - case 201: - return "Created"sv; - case 202: - return "Accepted"sv; - case 204: - return "No Content"sv; - case 205: - return "Reset Content"sv; - case 206: - return "Partial Content"sv; - - // 3xx Redirection - - case 300: - return "Multiple Choices"sv; - case 301: - return "Moved Permanently"sv; - case 302: - return "Found"sv; - case 303: - return "See Other"sv; - case 304: - return "Not Modified"sv; - case 305: - return "Use Proxy"sv; - case 306: - return "Switch Proxy"sv; - case 307: - return "Temporary Redirect"sv; - case 308: - return "Permanent Redirect"sv; - - // 4xx Client errors - - case 400: - return "Bad Request"sv; - case 401: - return "Unauthorized"sv; - case 402: - return "Payment Required"sv; - case 403: - return "Forbidden"sv; - case 404: - return "Not Found"sv; - case 405: - return "Method Not Allowed"sv; - case 406: - return "Not Acceptable"sv; - case 407: - return "Proxy Authentication Required"sv; - case 408: - return "Request Timeout"sv; - case 409: - return "Conflict"sv; - case 410: - return "Gone"sv; - case 411: - return "Length Required"sv; - case 412: - return "Precondition Failed"sv; - case 413: - return "Payload Too Large"sv; - case 414: - return "URI Too Long"sv; - case 415: - return "Unsupported Media Type"sv; - case 416: - return "Range Not Satisifiable"sv; - case 417: - return "Expectation Failed"sv; - case 418: - return "I'm a teapot"sv; - case 421: - return "Misdirected Request"sv; - case 422: - return "Unprocessable Entity"sv; - case 423: - return "Locked"sv; - case 424: - return "Failed Dependency"sv; - case 425: - return "Too Early"sv; - case 426: - return "Upgrade Required"sv; - case 428: - return "Precondition Required"sv; - case 429: - return "Too Many Requests"sv; - case 431: - return "Request Header Fields Too Large"sv; - - // 5xx Server errors - - case 500: - return "Internal Server Error"sv; - case 501: - return "Not Implemented"sv; - case 502: - return "Bad Gateway"sv; - case 503: - return "Service Unavailable"sv; - case 504: - return "Gateway Timeout"sv; - case 505: - return "HTTP Version Not Supported"sv; - case 506: - return "Variant Also Negotiates"sv; - case 507: - return "Insufficient Storage"sv; - case 508: - return "Loop Detected"sv; - case 510: - return "Not Extended"sv; - case 511: - return "Network Authentication Required"sv; - - default: - return "Unknown Result"sv; - } -} - -////////////////////////////////////////////////////////////////////////// - -Ref<IHttpPackageHandler> -HttpService::HandlePackageRequest(HttpServerRequest& HttpServiceRequest) -{ - ZEN_UNUSED(HttpServiceRequest); - - return Ref<IHttpPackageHandler>(); -} - -////////////////////////////////////////////////////////////////////////// - -HttpServerRequest::HttpServerRequest() -{ -} - -HttpServerRequest::~HttpServerRequest() -{ -} - -void -HttpServerRequest::WriteResponse(HttpResponseCode ResponseCode, CbPackage Data) -{ - std::vector<IoBuffer> ResponseBuffers = FormatPackageMessage(Data); - return WriteResponse(ResponseCode, HttpContentType::kCbPackage, ResponseBuffers); -} - -void -HttpServerRequest::WriteResponse(HttpResponseCode ResponseCode, CbObject Data) -{ - if (m_AcceptType == HttpContentType::kJSON) - { - ExtendableStringBuilder<1024> Sb; - WriteResponse(ResponseCode, HttpContentType::kJSON, Data.ToJson(Sb).ToView()); - } - else - { - SharedBuffer Buf = Data.GetBuffer(); - std::array<IoBuffer, 1> Buffers{IoBufferBuilder::MakeCloneFromMemory(Buf.GetData(), Buf.GetSize())}; - return WriteResponse(ResponseCode, HttpContentType::kCbObject, Buffers); - } -} - -void -HttpServerRequest::WriteResponse(HttpResponseCode ResponseCode, CbArray Array) -{ - if (m_AcceptType == HttpContentType::kJSON) - { - ExtendableStringBuilder<1024> Sb; - WriteResponse(ResponseCode, HttpContentType::kJSON, Array.ToJson(Sb).ToView()); - } - else - { - SharedBuffer Buf = Array.GetBuffer(); - std::array<IoBuffer, 1> Buffers{IoBufferBuilder::MakeCloneFromMemory(Buf.GetData(), Buf.GetSize())}; - return WriteResponse(ResponseCode, HttpContentType::kCbObject, Buffers); - } -} - -void -HttpServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::string_view ResponseString) -{ - return WriteResponse(ResponseCode, ContentType, std::u8string_view{(char8_t*)ResponseString.data(), ResponseString.size()}); -} - -void -HttpServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, IoBuffer Blob) -{ - std::array<IoBuffer, 1> Buffers{Blob}; - return WriteResponse(ResponseCode, ContentType, Buffers); -} - -void -HttpServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, CompositeBuffer& Payload) -{ - std::span<const SharedBuffer> Segments = Payload.GetSegments(); - - std::vector<IoBuffer> Buffers; - - for (auto& Segment : Segments) - { - Buffers.push_back(Segment.AsIoBuffer()); - } - - WriteResponse(ResponseCode, ContentType, Buffers); -} - -HttpServerRequest::QueryParams -HttpServerRequest::GetQueryParams() -{ - QueryParams Params; - - const std::string_view QStr = QueryString(); - - const char* QueryIt = QStr.data(); - const char* QueryEnd = QueryIt + QStr.size(); - - while (QueryIt != QueryEnd) - { - if (*QueryIt == '&') - { - ++QueryIt; - continue; - } - - size_t QueryLen = ptrdiff_t(QueryEnd - QueryIt); - const std::string_view Query{QueryIt, QueryLen}; - - size_t DelimIndex = Query.find('&', 0); - - if (DelimIndex == std::string_view::npos) - { - DelimIndex = Query.size(); - } - - std::string_view ThisQuery{QueryIt, DelimIndex}; - - size_t EqIndex = ThisQuery.find('=', 0); - - if (EqIndex != std::string_view::npos) - { - std::string_view Param{ThisQuery.data(), EqIndex}; - ThisQuery.remove_prefix(EqIndex + 1); - - Params.KvPairs.emplace_back(Param, ThisQuery); - } - - QueryIt += DelimIndex; - } - - return Params; -} - -Oid -HttpServerRequest::SessionId() const -{ - if (m_Flags & kHaveSessionId) - { - return m_SessionId; - } - - m_SessionId = ParseSessionId(); - m_Flags |= kHaveSessionId; - return m_SessionId; -} - -uint32_t -HttpServerRequest::RequestId() const -{ - if (m_Flags & kHaveRequestId) - { - return m_RequestId; - } - - m_RequestId = ParseRequestId(); - m_Flags |= kHaveRequestId; - return m_RequestId; -} - -CbObject -HttpServerRequest::ReadPayloadObject() -{ - if (IoBuffer Payload = ReadPayload()) - { - return LoadCompactBinaryObject(std::move(Payload)); - } - - return {}; -} - -CbPackage -HttpServerRequest::ReadPayloadPackage() -{ - if (IoBuffer Payload = ReadPayload()) - { - return ParsePackageMessage(std::move(Payload)); - } - - return {}; -} - -////////////////////////////////////////////////////////////////////////// - -void -HttpRequestRouter::AddPattern(const char* Id, const char* Regex) -{ - ZEN_ASSERT(m_PatternMap.find(Id) == m_PatternMap.end()); - - m_PatternMap.insert({Id, Regex}); -} - -void -HttpRequestRouter::RegisterRoute(const char* Regex, HttpRequestRouter::HandlerFunc_t&& HandlerFunc, HttpVerb SupportedVerbs) -{ - ExtendableStringBuilder<128> ExpandedRegex; - ProcessRegexSubstitutions(Regex, ExpandedRegex); - - m_Handlers.emplace_back(ExpandedRegex.c_str(), SupportedVerbs, std::move(HandlerFunc), Regex); -} - -void -HttpRequestRouter::ProcessRegexSubstitutions(const char* Regex, StringBuilderBase& OutExpandedRegex) -{ - size_t RegexLen = strlen(Regex); - - for (size_t i = 0; i < RegexLen;) - { - bool matched = false; - - if (Regex[i] == '{' && ((i == 0) || (Regex[i - 1] != '\\'))) - { - // Might have a pattern reference - find closing brace - - for (size_t j = i + 1; j < RegexLen; ++j) - { - if (Regex[j] == '}') - { - std::string Pattern(&Regex[i + 1], j - i - 1); - - if (auto it = m_PatternMap.find(Pattern); it != m_PatternMap.end()) - { - OutExpandedRegex.Append(it->second.c_str()); - } - else - { - // Default to anything goes (or should this just be an error?) - - OutExpandedRegex.Append("(.+?)"); - } - - // skip ahead - i = j + 1; - - matched = true; - - break; - } - } - } - - if (!matched) - { - OutExpandedRegex.Append(Regex[i++]); - } - } -} - -bool -HttpRequestRouter::HandleRequest(zen::HttpServerRequest& Request) -{ - const HttpVerb Verb = Request.RequestVerb(); - - std::string_view Uri = Request.RelativeUri(); - HttpRouterRequest RouterRequest(Request); - - for (const auto& Handler : m_Handlers) - { - if ((Handler.Verbs & Verb) == Verb && regex_match(begin(Uri), end(Uri), RouterRequest.m_Match, Handler.RegEx)) - { - Handler.Handler(RouterRequest); - - return true; // Route matched - } - } - - return false; // No route matched -} - -////////////////////////////////////////////////////////////////////////// - -HttpRpcHandler::HttpRpcHandler() -{ -} - -HttpRpcHandler::~HttpRpcHandler() -{ -} - -void -HttpRpcHandler::AddRpc(std::string_view RpcId, std::function<void(CbObject& RpcArgs)> HandlerFunction) -{ - ZEN_UNUSED(RpcId, HandlerFunction); -} - -////////////////////////////////////////////////////////////////////////// - -enum class HttpServerClass -{ - kHttpAsio, - kHttpSys, - kHttpNull -}; - -// Implemented in httpsys.cpp -Ref<HttpServer> CreateHttpSysServer(int Concurrency, int BackgroundWorkerThreads); - -Ref<HttpServer> -CreateHttpServer(std::string_view ServerClass) -{ - using namespace std::literals; - - HttpServerClass Class = HttpServerClass::kHttpNull; - -#if ZEN_WITH_HTTPSYS - Class = HttpServerClass::kHttpSys; -#elif 1 - Class = HttpServerClass::kHttpAsio; -#endif - - if (ServerClass == "asio"sv) - { - Class = HttpServerClass::kHttpAsio; - } - else if (ServerClass == "httpsys"sv) - { - Class = HttpServerClass::kHttpSys; - } - else if (ServerClass == "null"sv) - { - Class = HttpServerClass::kHttpNull; - } - - switch (Class) - { - default: - case HttpServerClass::kHttpAsio: - ZEN_INFO("using asio HTTP server implementation"); - return Ref<HttpServer>(new HttpAsioServer()); - -#if ZEN_WITH_HTTPSYS - case HttpServerClass::kHttpSys: - ZEN_INFO("using http.sys server implementation"); - return Ref<HttpServer>(new HttpSysServer(std::thread::hardware_concurrency(), /* background worker threads */ 16)); -#endif - - case HttpServerClass::kHttpNull: - ZEN_INFO("using null HTTP server implementation"); - return Ref<HttpServer>(new HttpNullServer); - } -} - -////////////////////////////////////////////////////////////////////////// - -bool -HandlePackageOffers(HttpService& Service, HttpServerRequest& Request, Ref<IHttpPackageHandler>& PackageHandlerRef) -{ - if (Request.RequestVerb() == HttpVerb::kPost) - { - if (Request.RequestContentType() == HttpContentType::kCbPackageOffer) - { - // The client is presenting us with a package attachments offer, we need - // to filter it down to the list of attachments we need them to send in - // the follow-up request - - PackageHandlerRef = Service.HandlePackageRequest(Request); - - if (PackageHandlerRef) - { - CbObject OfferMessage = LoadCompactBinaryObject(Request.ReadPayload()); - - std::vector<IoHash> OfferCids; - - for (auto& CidEntry : OfferMessage["offer"]) - { - if (!CidEntry.IsHash()) - { - // Should yield bad request response? - - ZEN_WARN("found invalid entry in offer"); - - continue; - } - - OfferCids.push_back(CidEntry.AsHash()); - } - - ZEN_TRACE("request #{} -> filtering offer of {} entries", Request.RequestId(), OfferCids.size()); - - PackageHandlerRef->FilterOffer(OfferCids); - - ZEN_TRACE("request #{} -> filtered to {} entries", Request.RequestId(), OfferCids.size()); - - CbObjectWriter ResponseWriter; - ResponseWriter.BeginArray("need"); - - for (const IoHash& Cid : OfferCids) - { - ResponseWriter.AddHash(Cid); - } - - ResponseWriter.EndArray(); - - // Emit filter response - Request.WriteResponse(HttpResponseCode::OK, ResponseWriter.Save()); - return true; - } - } - else if (Request.RequestContentType() == HttpContentType::kCbPackage) - { - // Process chunks in package request - - PackageHandlerRef = Service.HandlePackageRequest(Request); - - // TODO: this should really be done in a streaming fashion, currently this emulates - // the intended flow from an API perspective - - if (PackageHandlerRef) - { - PackageHandlerRef->OnRequestBegin(); - - auto CreateBuffer = [&](const IoHash& Cid, uint64_t Size) -> IoBuffer { - return PackageHandlerRef->CreateTarget(Cid, Size); - }; - - CbPackage Package = ParsePackageMessage(Request.ReadPayload(), CreateBuffer); - - PackageHandlerRef->OnRequestComplete(); - } - } - } - return false; -} - -////////////////////////////////////////////////////////////////////////// - -#if ZEN_WITH_TESTS - -TEST_CASE("http.common") -{ - using namespace std::literals; - - SUBCASE("router") - { - HttpRequestRouter r; - r.AddPattern("a", "[[:alpha:]]+"); - r.RegisterRoute( - "{a}", - [&](auto) {}, - HttpVerb::kGet); - - // struct TestHttpServerRequest : public HttpServerRequest - //{ - // TestHttpServerRequest(std::string_view Uri) : m_uri{Uri} {} - //}; - - // TestHttpServerRequest req{}; - // r.HandleRequest(req); - } - - SUBCASE("content-type") - { - for (uint8_t i = 0; i < uint8_t(HttpContentType::kCOUNT); ++i) - { - HttpContentType Ct{i}; - - if (Ct != HttpContentType::kUnknownContentType) - { - CHECK_EQ(Ct, ParseContentType(MapContentTypeToString(Ct))); - } - } - } -} - -void -http_forcelink() -{ -} - -#endif - -} // namespace zen diff --git a/zenhttp/httpshared.cpp b/zenhttp/httpshared.cpp deleted file mode 100644 index 7aade56d2..000000000 --- a/zenhttp/httpshared.cpp +++ /dev/null @@ -1,809 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. - -#include <zenhttp/httpshared.h> - -#include <zencore/compactbinarybuilder.h> -#include <zencore/compactbinarypackage.h> -#include <zencore/compositebuffer.h> -#include <zencore/filesystem.h> -#include <zencore/fmtutils.h> -#include <zencore/iobuffer.h> -#include <zencore/iohash.h> -#include <zencore/logging.h> -#include <zencore/scopeguard.h> -#include <zencore/stream.h> -#include <zencore/testing.h> -#include <zencore/testutils.h> - -#include <span> -#include <vector> - -ZEN_THIRD_PARTY_INCLUDES_START -#include <tsl/robin_map.h> -ZEN_THIRD_PARTY_INCLUDES_END - -namespace zen { - -const std::string_view HandlePrefix(":?#:"); - -std::vector<IoBuffer> -FormatPackageMessage(const CbPackage& Data, int TargetProcessPid) -{ - return FormatPackageMessage(Data, FormatFlags::kDefault, TargetProcessPid); -} -CompositeBuffer -FormatPackageMessageBuffer(const CbPackage& Data, int TargetProcessPid) -{ - return FormatPackageMessageBuffer(Data, FormatFlags::kDefault, TargetProcessPid); -} - -CompositeBuffer -FormatPackageMessageBuffer(const CbPackage& Data, FormatFlags Flags, int TargetProcessPid) -{ - std::vector<IoBuffer> Message = FormatPackageMessage(Data, Flags, TargetProcessPid); - - std::vector<SharedBuffer> Buffers; - - for (IoBuffer& Buf : Message) - { - Buffers.push_back(SharedBuffer(Buf)); - } - - return CompositeBuffer(std::move(Buffers)); -} - -std::vector<IoBuffer> -FormatPackageMessage(const CbPackage& Data, FormatFlags Flags, int TargetProcessPid) -{ - void* TargetProcessHandle = nullptr; -#if ZEN_PLATFORM_WINDOWS - std::vector<HANDLE> DuplicatedHandles; - auto _ = MakeGuard([&DuplicatedHandles, &TargetProcessHandle]() { - if (TargetProcessHandle == nullptr) - { - return; - } - - for (HANDLE DuplicatedHandle : DuplicatedHandles) - { - HANDLE ClosingHandle; - if (::DuplicateHandle((HANDLE)TargetProcessHandle, - DuplicatedHandle, - GetCurrentProcess(), - &ClosingHandle, - 0, - FALSE, - DUPLICATE_CLOSE_SOURCE | DUPLICATE_SAME_ACCESS) == TRUE) - { - ::CloseHandle(ClosingHandle); - } - } - ::CloseHandle((HANDLE)TargetProcessHandle); - TargetProcessHandle = nullptr; - }); - - if (EnumHasAllFlags(Flags, FormatFlags::kAllowLocalReferences) && TargetProcessPid != 0) - { - TargetProcessHandle = OpenProcess(PROCESS_DUP_HANDLE, FALSE, TargetProcessPid); - } -#else - ZEN_UNUSED(TargetProcessPid); - void* DuplicatedHandles = nullptr; -#endif // ZEN_PLATFORM_WINDOWS - - const std::span<const CbAttachment>& Attachments = Data.GetAttachments(); - std::vector<IoBuffer> ResponseBuffers; - - ResponseBuffers.reserve(3 + Attachments.size()); // TODO: may want to use an additional fudge factor here to avoid growing since each - // attachment is likely to consist of several buffers - - // Fixed size header - - CbPackageHeader Hdr{.HeaderMagic = kCbPkgMagic, .AttachmentCount = gsl::narrow<uint32_t>(Attachments.size())}; - - ResponseBuffers.push_back(IoBufferBuilder::MakeCloneFromMemory(&Hdr, sizeof Hdr)); - - // Attachment metadata array - - IoBuffer AttachmentMetadataBuffer = IoBuffer{sizeof(CbAttachmentEntry) * (Attachments.size() + /* root */ 1)}; - CbAttachmentEntry* AttachmentInfo = reinterpret_cast<CbAttachmentEntry*>(AttachmentMetadataBuffer.MutableData()); - - ResponseBuffers.push_back(AttachmentMetadataBuffer); // Attachment metadata - - // Root object - - IoBuffer RootIoBuffer = Data.GetObject().GetBuffer().AsIoBuffer(); - ResponseBuffers.push_back(RootIoBuffer); // Root object - - *AttachmentInfo++ = {.PayloadSize = RootIoBuffer.Size(), .Flags = CbAttachmentEntry::kIsObject, .AttachmentHash = Data.GetObjectHash()}; - - // Attachment payloads - - auto MarshalLocal = [&AttachmentInfo, &ResponseBuffers](const std::string& Path8, - CbAttachmentReferenceHeader& LocalRef, - const IoHash& AttachmentHash, - bool IsCompressed) { - IoBuffer RefBuffer(sizeof(CbAttachmentReferenceHeader) + Path8.size()); - - CbAttachmentReferenceHeader* RefHdr = RefBuffer.MutableData<CbAttachmentReferenceHeader>(); - *RefHdr++ = LocalRef; - memcpy(RefHdr, Path8.data(), Path8.size()); - - *AttachmentInfo++ = {.PayloadSize = RefBuffer.GetSize(), - .Flags = (IsCompressed ? uint32_t(CbAttachmentEntry::kIsCompressed) : 0u) | CbAttachmentEntry::kIsLocalRef, - .AttachmentHash = AttachmentHash}; - - ResponseBuffers.push_back(std::move(RefBuffer)); - }; - - tsl::robin_map<void*, std::string> FileNameMap; - - auto IsLocalRef = [&FileNameMap, &DuplicatedHandles](const CompositeBuffer& AttachmentBinary, - bool DenyPartialLocalReferences, - void* TargetProcessHandle, - CbAttachmentReferenceHeader& LocalRef, - std::string& Path8) -> bool { - const SharedBuffer& Segment = AttachmentBinary.GetSegments().front(); - IoBufferFileReference Ref; - const IoBuffer& SegmentBuffer = Segment.AsIoBuffer(); - - if (!SegmentBuffer.GetFileReference(Ref)) - { - return false; - } - - if (DenyPartialLocalReferences && !SegmentBuffer.IsWholeFile()) - { - return false; - } - - if (auto It = FileNameMap.find(Ref.FileHandle); It != FileNameMap.end()) - { - Path8 = It->second; - } - else - { - bool UseFilePath = true; -#if ZEN_PLATFORM_WINDOWS - if (TargetProcessHandle != nullptr) - { - HANDLE TargetHandle = INVALID_HANDLE_VALUE; - BOOL OK = ::DuplicateHandle(GetCurrentProcess(), - Ref.FileHandle, - (HANDLE)TargetProcessHandle, - &TargetHandle, - FILE_GENERIC_READ, - FALSE, - 0); - if (OK) - { - DuplicatedHandles.push_back(TargetHandle); - Path8 = fmt::format("{}{}", HandlePrefix, reinterpret_cast<uint64_t>(TargetHandle)); - UseFilePath = false; - } - } -#else // ZEN_PLATFORM_WINDOWS - ZEN_UNUSED(TargetProcessHandle); - // Not supported on Linux/Mac. Could potentially use pidfd_getfd() but that requires a fairly new Linux kernel/includes and to - // deal with acceess rights etc. -#endif // ZEN_PLATFORM_WINDOWS - if (UseFilePath) - { - ExtendablePathBuilder<256> LocalRefFile; - LocalRefFile.Append(std::filesystem::absolute(PathFromHandle(Ref.FileHandle))); - Path8 = LocalRefFile.ToUtf8(); - } - FileNameMap.insert_or_assign(Ref.FileHandle, Path8); - } - - LocalRef.AbsolutePathLength = gsl::narrow<uint16_t>(Path8.size()); - LocalRef.PayloadByteOffset = Ref.FileChunkOffset; - LocalRef.PayloadByteSize = Ref.FileChunkSize; - - return true; - }; - - for (const CbAttachment& Attachment : Attachments) - { - if (Attachment.IsNull()) - { - ZEN_NOT_IMPLEMENTED("Null attachments are not supported"); - } - else if (CompressedBuffer AttachmentBuffer = Attachment.AsCompressedBinary()) - { - CompositeBuffer Compressed = AttachmentBuffer.GetCompressed(); - IoHash AttachmentHash = Attachment.GetHash(); - - // If the data is either not backed by a file, or there are multiple - // fragments then we cannot marshal it by local reference. We might - // want/need to extend this in the future to allow multiple chunk - // segments to be marshaled at once - - bool MarshalByLocalRef = EnumHasAllFlags(Flags, FormatFlags::kAllowLocalReferences) && (Compressed.GetSegments().size() == 1); - bool DenyPartialLocalReferences = EnumHasAllFlags(Flags, FormatFlags::kDenyPartialLocalReferences); - CbAttachmentReferenceHeader LocalRef; - std::string Path8; - - if (MarshalByLocalRef) - { - MarshalByLocalRef = IsLocalRef(Compressed, DenyPartialLocalReferences, TargetProcessHandle, LocalRef, Path8); - } - - if (MarshalByLocalRef) - { - const bool IsCompressed = true; - bool IsHandle = false; -#if ZEN_PLATFORM_WINDOWS - IsHandle = Path8.starts_with(HandlePrefix); -#endif - MarshalLocal(Path8, LocalRef, AttachmentHash, IsCompressed); - ZEN_DEBUG("Marshalled '{}' as file {} of {} bytes", Path8, IsHandle ? "handle" : "path", Compressed.GetSize()); - } - else - { - *AttachmentInfo++ = {.PayloadSize = AttachmentBuffer.GetCompressedSize(), - .Flags = CbAttachmentEntry::kIsCompressed, - .AttachmentHash = AttachmentHash}; - - for (const SharedBuffer& Segment : Compressed.GetSegments()) - { - ResponseBuffers.push_back(Segment.AsIoBuffer()); - } - } - } - else if (CbObject AttachmentObject = Attachment.AsObject()) - { - IoBuffer ObjIoBuffer = AttachmentObject.GetBuffer().AsIoBuffer(); - ResponseBuffers.push_back(ObjIoBuffer); - - *AttachmentInfo++ = {.PayloadSize = ObjIoBuffer.Size(), - .Flags = CbAttachmentEntry::kIsObject, - .AttachmentHash = Attachment.GetHash()}; - } - else if (CompositeBuffer AttachmentBinary = Attachment.AsCompositeBinary()) - { - IoHash AttachmentHash = Attachment.GetHash(); - bool MarshalByLocalRef = - EnumHasAllFlags(Flags, FormatFlags::kAllowLocalReferences) && (AttachmentBinary.GetSegments().size() == 1); - bool DenyPartialLocalReferences = EnumHasAllFlags(Flags, FormatFlags::kDenyPartialLocalReferences); - - CbAttachmentReferenceHeader LocalRef; - std::string Path8; - - if (MarshalByLocalRef) - { - MarshalByLocalRef = IsLocalRef(AttachmentBinary, DenyPartialLocalReferences, TargetProcessHandle, LocalRef, Path8); - } - - if (MarshalByLocalRef) - { - const bool IsCompressed = false; - bool IsHandle = false; -#if ZEN_PLATFORM_WINDOWS - IsHandle = Path8.starts_with(HandlePrefix); -#endif - MarshalLocal(Path8, LocalRef, AttachmentHash, IsCompressed); - ZEN_DEBUG("Marshalled '{}' as file {} of {} bytes", Path8, IsHandle ? "handle" : "path", AttachmentBinary.GetSize()); - } - else - { - *AttachmentInfo++ = {.PayloadSize = AttachmentBinary.GetSize(), .Flags = 0, .AttachmentHash = Attachment.GetHash()}; - - for (const SharedBuffer& Segment : AttachmentBinary.GetSegments()) - { - ResponseBuffers.push_back(Segment.AsIoBuffer()); - } - } - } - else - { - ZEN_NOT_IMPLEMENTED("Unknown attachment kind"); - } - } - FileNameMap.clear(); -#if ZEN_PLATFORM_WINDOWS - DuplicatedHandles.clear(); -#endif // ZEN_PLATFORM_WINDOWS - - return ResponseBuffers; -} - -bool -IsPackageMessage(IoBuffer Payload) -{ - if (!Payload) - { - return false; - } - - BinaryReader Reader(Payload); - - CbPackageHeader Hdr; - Reader.Read(&Hdr, sizeof Hdr); - - if (Hdr.HeaderMagic != kCbPkgMagic) - { - return false; - } - - return true; -} - -CbPackage -ParsePackageMessage(IoBuffer Payload, std::function<IoBuffer(const IoHash&, uint64_t)> CreateBuffer) -{ - if (!Payload) - { - return {}; - } - - BinaryReader Reader(Payload); - - CbPackageHeader Hdr; - Reader.Read(&Hdr, sizeof Hdr); - - if (Hdr.HeaderMagic != kCbPkgMagic) - { - throw std::runtime_error("invalid CbPackage header magic"); - } - - const uint32_t ChunkCount = Hdr.AttachmentCount + 1; - - std::unique_ptr<CbAttachmentEntry[]> AttachmentEntries{new CbAttachmentEntry[ChunkCount]}; - - Reader.Read(AttachmentEntries.get(), sizeof(CbAttachmentEntry) * ChunkCount); - - CbPackage Package; - - std::vector<CbAttachment> Attachments; - Attachments.reserve(ChunkCount); // Guessing here... - - tsl::robin_map<std::string, IoBuffer> PartialFileBuffers; - - // TODO: Throwing before this loop completes could result in leaking handles as we might not have picked up all the handles in the - // message - for (uint32_t i = 0; i < ChunkCount; ++i) - { - const CbAttachmentEntry& Entry = AttachmentEntries[i]; - const uint64_t AttachmentSize = Entry.PayloadSize; - - const IoBuffer AttachmentBuffer(Payload, Reader.CurrentOffset(), AttachmentSize); - Reader.Skip(AttachmentSize); - - if (Entry.Flags & CbAttachmentEntry::kIsLocalRef) - { - // Marshal local reference - a "pointer" to the chunk backing file - - ZEN_ASSERT(AttachmentBuffer.Size() >= sizeof(CbAttachmentReferenceHeader)); - - const CbAttachmentReferenceHeader* AttachRefHdr = AttachmentBuffer.Data<CbAttachmentReferenceHeader>(); - const char* PathPointer = reinterpret_cast<const char*>(AttachRefHdr + 1); - - ZEN_ASSERT(AttachmentBuffer.Size() >= (sizeof(CbAttachmentReferenceHeader) + AttachRefHdr->AbsolutePathLength)); - std::string_view PathView(PathPointer, AttachRefHdr->AbsolutePathLength); - - IoBuffer FullFileBuffer; - - std::filesystem::path Path(Utf8ToWide(PathView)); - if (auto It = PartialFileBuffers.find(Path.string()); It != PartialFileBuffers.end()) - { - FullFileBuffer = It->second; - } - else - { - if (PathView.starts_with(HandlePrefix)) - { -#if ZEN_PLATFORM_WINDOWS - std::string_view HandleString(PathView.substr(HandlePrefix.length())); - std::optional<uint64_t> HandleNumber(ParseInt<uint64_t>(HandleString)); - if (HandleNumber.has_value()) - { - HANDLE FileHandle = HANDLE(HandleNumber.value()); - ULARGE_INTEGER liFileSize; - liFileSize.LowPart = ::GetFileSize(FileHandle, &liFileSize.HighPart); - if (liFileSize.LowPart != INVALID_FILE_SIZE) - { - FullFileBuffer = IoBuffer(IoBuffer::File, (void*)FileHandle, 0, uint64_t(liFileSize.QuadPart)); - PartialFileBuffers.insert_or_assign(Path.string(), FullFileBuffer); - } - } -#else // ZEN_PLATFORM_WINDOWS - // Not supported on Linux/Mac. Could potentially use pidfd_getfd() but that requires a fairly new Linux kernel/includes - // and to deal with acceess rights etc. - ZEN_ASSERT(false); -#endif // ZEN_PLATFORM_WINDOWS - } - else - { - FullFileBuffer = PartialFileBuffers.insert_or_assign(Path.string(), IoBufferBuilder::MakeFromFile(Path)).first->second; - } - } - - if (!FullFileBuffer) - { - // Unable to open chunk reference - throw std::runtime_error(fmt::format("unable to resolve chunk #{} at '{}' (offset {}, size {})", - i, - Path, - AttachRefHdr->PayloadByteOffset, - AttachRefHdr->PayloadByteSize)); - } - - IoBuffer ChunkReference = AttachRefHdr->PayloadByteOffset == 0 && AttachRefHdr->PayloadByteSize == FullFileBuffer.GetSize() - ? FullFileBuffer - : IoBuffer(FullFileBuffer, AttachRefHdr->PayloadByteOffset, AttachRefHdr->PayloadByteSize); - - CompressedBuffer CompBuf(CompressedBuffer::FromCompressedNoValidate(std::move(ChunkReference))); - if (!CompBuf) - { - throw std::runtime_error(fmt::format("invalid format for chunk #{} at '{}' (offset {}, size {})", - i, - Path, - AttachRefHdr->PayloadByteOffset, - AttachRefHdr->PayloadByteSize)); - } - Attachments.emplace_back(CbAttachment(std::move(CompBuf), Entry.AttachmentHash)); - } - else if (Entry.Flags & CbAttachmentEntry::kIsCompressed) - { - if (Entry.Flags & CbAttachmentEntry::kIsObject) - { - if (i == 0) - { - CompressedBuffer CompBuf(CompressedBuffer::FromCompressedNoValidate(IoBuffer(AttachmentBuffer))); - if (!CompBuf) - { - throw std::runtime_error(fmt::format("invalid format for chunk #{} expected compressed buffer for CbObject", i)); - } - // First payload is always a compact binary object - Package.SetObject(LoadCompactBinaryObject(std::move(CompBuf))); - } - else - { - ZEN_NOT_IMPLEMENTED("Object attachments are not currently supported"); - } - } - else - { - CompressedBuffer CompBuf(CompressedBuffer::FromCompressedNoValidate(IoBuffer(AttachmentBuffer))); - if (!CompBuf) - { - throw std::runtime_error(fmt::format("invalid format for chunk #{} expected compressed buffer for attachment", i)); - } - Attachments.emplace_back(CbAttachment(std::move(CompBuf), Entry.AttachmentHash)); - } - } - else /* not compressed */ - { - if (Entry.Flags & CbAttachmentEntry::kIsObject) - { - if (i == 0) - { - Package.SetObject(LoadCompactBinaryObject(AttachmentBuffer)); - } - else - { - ZEN_NOT_IMPLEMENTED("Object attachments are not currently supported"); - } - } - else - { - // Make a copy of the buffer so we attachements don't reference the entire payload - IoBuffer AttachmentBufferCopy = CreateBuffer(Entry.AttachmentHash, AttachmentSize); - ZEN_ASSERT(AttachmentBufferCopy); - ZEN_ASSERT(AttachmentBufferCopy.Size() == AttachmentSize); - AttachmentBufferCopy.GetMutableView().CopyFrom(AttachmentBuffer.GetView()); - - CbAttachment Attachment(SharedBuffer{AttachmentBufferCopy}); - Attachments.emplace_back(SharedBuffer{AttachmentBufferCopy}); - } - } - } - PartialFileBuffers.clear(); - - Package.AddAttachments(Attachments); - - return Package; -} - -bool -ParsePackageMessageWithLegacyFallback(const IoBuffer& Response, CbPackage& OutPackage) -{ - if (IsPackageMessage(Response)) - { - OutPackage = ParsePackageMessage(Response); - return true; - } - return OutPackage.TryLoad(Response); -} - -CbPackageReader::CbPackageReader() : m_CreateBuffer([](const IoHash&, uint64_t Size) -> IoBuffer { return IoBuffer{Size}; }) -{ -} - -CbPackageReader::~CbPackageReader() -{ -} - -void -CbPackageReader::SetPayloadBufferCreator(std::function<IoBuffer(const IoHash& Cid, uint64_t Size)> CreateBuffer) -{ - m_CreateBuffer = CreateBuffer; -} - -uint64_t -CbPackageReader::ProcessPackageHeaderData(const void* Data, uint64_t DataBytes) -{ - ZEN_ASSERT(m_CurrentState != State::kReadingBuffers); - - switch (m_CurrentState) - { - case State::kInitialState: - ZEN_ASSERT(Data == nullptr); - m_CurrentState = State::kReadingHeader; - return sizeof m_PackageHeader; - - case State::kReadingHeader: - ZEN_ASSERT(DataBytes == sizeof m_PackageHeader); - memcpy(&m_PackageHeader, Data, sizeof m_PackageHeader); - ZEN_ASSERT(m_PackageHeader.HeaderMagic == kCbPkgMagic); - m_CurrentState = State::kReadingAttachmentEntries; - m_AttachmentEntries.resize(m_PackageHeader.AttachmentCount + 1); - return (m_PackageHeader.AttachmentCount + 1) * sizeof(CbAttachmentEntry); - - case State::kReadingAttachmentEntries: - ZEN_ASSERT(DataBytes == ((m_PackageHeader.AttachmentCount + 1) * sizeof(CbAttachmentEntry))); - memcpy(m_AttachmentEntries.data(), Data, DataBytes); - - for (CbAttachmentEntry& Entry : m_AttachmentEntries) - { - // This preallocates memory for payloads but note that for the local references - // the caller will need to handle the payload differently (i.e it's a - // CbAttachmentReferenceHeader not the actual payload) - - m_PayloadBuffers.push_back(IoBuffer{Entry.PayloadSize}); - } - - m_CurrentState = State::kReadingBuffers; - return 0; - - default: - ZEN_ASSERT(false); - return 0; - } -} - -IoBuffer -CbPackageReader::MarshalLocalChunkReference(IoBuffer AttachmentBuffer) -{ - // Marshal local reference - a "pointer" to the chunk backing file - - ZEN_ASSERT(AttachmentBuffer.Size() >= sizeof(CbAttachmentReferenceHeader)); - - const CbAttachmentReferenceHeader* AttachRefHdr = AttachmentBuffer.Data<CbAttachmentReferenceHeader>(); - const char8_t* PathPointer = reinterpret_cast<const char8_t*>(AttachRefHdr + 1); - - ZEN_ASSERT(AttachmentBuffer.Size() >= (sizeof(CbAttachmentReferenceHeader) + AttachRefHdr->AbsolutePathLength)); - - std::u8string_view PathView{PathPointer, AttachRefHdr->AbsolutePathLength}; - - std::filesystem::path Path{PathView}; - - IoBuffer ChunkReference = IoBufferBuilder::MakeFromFile(Path, AttachRefHdr->PayloadByteOffset, AttachRefHdr->PayloadByteSize); - - if (!ChunkReference) - { - // Unable to open chunk reference - - throw std::runtime_error(fmt::format("unable to resolve local reference to '{}' (offset {}, size {})", - PathToUtf8(Path), - AttachRefHdr->PayloadByteOffset, - AttachRefHdr->PayloadByteSize)); - } - - return ChunkReference; -}; - -void -CbPackageReader::Finalize() -{ - if (m_AttachmentEntries.empty()) - { - return; - } - - m_Attachments.reserve(m_AttachmentEntries.size() - 1); - - int CurrentAttachmentIndex = 0; - for (CbAttachmentEntry& Entry : m_AttachmentEntries) - { - IoBuffer AttachmentBuffer = m_PayloadBuffers[CurrentAttachmentIndex]; - - if (CurrentAttachmentIndex == 0) - { - // Root object - if (Entry.Flags & CbAttachmentEntry::kIsObject) - { - if (Entry.Flags & CbAttachmentEntry::kIsLocalRef) - { - m_RootObject = LoadCompactBinaryObject(MarshalLocalChunkReference(AttachmentBuffer)); - } - else if (Entry.Flags & CbAttachmentEntry::kIsCompressed) - { - IoHash RawHash; - uint64_t RawSize; - CompressedBuffer Compressed = CompressedBuffer::FromCompressed(SharedBuffer(AttachmentBuffer), RawHash, RawSize); - if (RawHash == Entry.AttachmentHash) - { - m_RootObject = LoadCompactBinaryObject(Compressed); - } - } - else - { - m_RootObject = LoadCompactBinaryObject(std::move(AttachmentBuffer)); - } - } - else - { - throw std::runtime_error("missing or invalid root object"); - } - } - else if (Entry.Flags & CbAttachmentEntry::kIsLocalRef) - { - IoBuffer ChunkReference = MarshalLocalChunkReference(AttachmentBuffer); - - if (Entry.Flags & CbAttachmentEntry::kIsCompressed) - { - IoHash RawHash; - uint64_t RawSize; - CompressedBuffer Compressed = CompressedBuffer::FromCompressed(SharedBuffer(ChunkReference), RawHash, RawSize); - if (RawHash == Entry.AttachmentHash) - { - m_Attachments.push_back(CbAttachment(Compressed, Entry.AttachmentHash)); - } - } - else - { - CompressedBuffer Compressed = - CompressedBuffer::Compress(SharedBuffer(ChunkReference), OodleCompressor::NotSet, OodleCompressionLevel::None); - m_Attachments.push_back(CbAttachment(std::move(Compressed), Compressed.DecodeRawHash())); - } - } - - ++CurrentAttachmentIndex; - } -} - -/** - ______________________ _____________________________ - \__ ___/\_ _____// _____/\__ ___/ _____/ - | | | __)_ \_____ \ | | \_____ \ - | | | \/ \ | | / \ - |____| /_______ /_______ / |____| /_______ / - \/ \/ \/ - */ - -#if ZEN_WITH_TESTS - -TEST_CASE("CbPackage.Serialization") -{ - // Make a test package - - CbAttachment Attach1{SharedBuffer::MakeView(MakeMemoryView("abcd"))}; - CbAttachment Attach2{SharedBuffer::MakeView(MakeMemoryView("efgh"))}; - - CbObjectWriter Cbo; - Cbo.AddAttachment("abcd", Attach1); - Cbo.AddAttachment("efgh", Attach2); - - CbPackage Pkg; - Pkg.AddAttachment(Attach1); - Pkg.AddAttachment(Attach2); - Pkg.SetObject(Cbo.Save()); - - SharedBuffer Buffer = FormatPackageMessageBuffer(Pkg).Flatten(); - const uint8_t* CursorPtr = reinterpret_cast<const uint8_t*>(Buffer.GetData()); - uint64_t RemainingBytes = Buffer.GetSize(); - - auto ConsumeBytes = [&](uint64_t ByteCount) { - ZEN_ASSERT(ByteCount <= RemainingBytes); - void* ReturnPtr = (void*)CursorPtr; - CursorPtr += ByteCount; - RemainingBytes -= ByteCount; - return ReturnPtr; - }; - - auto CopyBytes = [&](void* TargetBuffer, uint64_t ByteCount) { - ZEN_ASSERT(ByteCount <= RemainingBytes); - memcpy(TargetBuffer, CursorPtr, ByteCount); - CursorPtr += ByteCount; - RemainingBytes -= ByteCount; - }; - - CbPackageReader Reader; - uint64_t InitialRead = Reader.ProcessPackageHeaderData(nullptr, 0); - uint64_t NextBytes = Reader.ProcessPackageHeaderData(ConsumeBytes(InitialRead), InitialRead); - NextBytes = Reader.ProcessPackageHeaderData(ConsumeBytes(NextBytes), NextBytes); - auto Buffers = Reader.GetPayloadBuffers(); - - for (auto& PayloadBuffer : Buffers) - { - CopyBytes(PayloadBuffer.MutableData(), PayloadBuffer.GetSize()); - } - - Reader.Finalize(); -} - -TEST_CASE("CbPackage.LocalRef") -{ - ScopedTemporaryDirectory TempDir; - - auto Path1 = TempDir.Path() / "abcd"; - auto Path2 = TempDir.Path() / "efgh"; - - { - IoBuffer Buffer1 = IoBufferBuilder::MakeCloneFromMemory(MakeMemoryView("abcd")); - IoBuffer Buffer2 = IoBufferBuilder::MakeCloneFromMemory(MakeMemoryView("efgh")); - - WriteFile(Path1, Buffer1); - WriteFile(Path2, Buffer2); - } - - // Make a test package - - IoBuffer FileBuffer1 = IoBufferBuilder::MakeFromFile(Path1); - IoBuffer FileBuffer2 = IoBufferBuilder::MakeFromFile(Path2); - - CbAttachment Attach1{SharedBuffer(FileBuffer1)}; - CbAttachment Attach2{SharedBuffer(FileBuffer2)}; - - CbObjectWriter Cbo; - Cbo.AddAttachment("abcd", Attach1); - Cbo.AddAttachment("efgh", Attach2); - - CbPackage Pkg; - Pkg.AddAttachment(Attach1); - Pkg.AddAttachment(Attach2); - Pkg.SetObject(Cbo.Save()); - - SharedBuffer Buffer = FormatPackageMessageBuffer(Pkg, FormatFlags::kAllowLocalReferences).Flatten(); - const uint8_t* CursorPtr = reinterpret_cast<const uint8_t*>(Buffer.GetData()); - uint64_t RemainingBytes = Buffer.GetSize(); - - auto ConsumeBytes = [&](uint64_t ByteCount) { - ZEN_ASSERT(ByteCount <= RemainingBytes); - void* ReturnPtr = (void*)CursorPtr; - CursorPtr += ByteCount; - RemainingBytes -= ByteCount; - return ReturnPtr; - }; - - auto CopyBytes = [&](void* TargetBuffer, uint64_t ByteCount) { - ZEN_ASSERT(ByteCount <= RemainingBytes); - memcpy(TargetBuffer, CursorPtr, ByteCount); - CursorPtr += ByteCount; - RemainingBytes -= ByteCount; - }; - - CbPackageReader Reader; - uint64_t InitialRead = Reader.ProcessPackageHeaderData(nullptr, 0); - uint64_t NextBytes = Reader.ProcessPackageHeaderData(ConsumeBytes(InitialRead), InitialRead); - NextBytes = Reader.ProcessPackageHeaderData(ConsumeBytes(NextBytes), NextBytes); - auto Buffers = Reader.GetPayloadBuffers(); - - for (auto& PayloadBuffer : Buffers) - { - CopyBytes(PayloadBuffer.MutableData(), PayloadBuffer.GetSize()); - } - - Reader.Finalize(); -} - -void -forcelink_httpshared() -{ -} - -#endif - -} // namespace zen diff --git a/zenhttp/httpsys.cpp b/zenhttp/httpsys.cpp deleted file mode 100644 index c733d618d..000000000 --- a/zenhttp/httpsys.cpp +++ /dev/null @@ -1,1674 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. - -#include "httpsys.h" - -#include <zencore/compactbinary.h> -#include <zencore/compactbinarybuilder.h> -#include <zencore/compactbinarypackage.h> -#include <zencore/except.h> -#include <zencore/logging.h> -#include <zencore/scopeguard.h> -#include <zencore/string.h> -#include <zencore/timer.h> -#include <zenhttp/httpshared.h> - -#if ZEN_WITH_HTTPSYS - -# include <conio.h> -# include <mstcpip.h> -# pragma comment(lib, "httpapi.lib") - -std::wstring -UTF8_to_UTF16(const char* InPtr) -{ - std::wstring OutString; - unsigned int Codepoint; - - while (*InPtr != 0) - { - unsigned char InChar = static_cast<unsigned char>(*InPtr); - - if (InChar <= 0x7f) - Codepoint = InChar; - else if (InChar <= 0xbf) - Codepoint = (Codepoint << 6) | (InChar & 0x3f); - else if (InChar <= 0xdf) - Codepoint = InChar & 0x1f; - else if (InChar <= 0xef) - Codepoint = InChar & 0x0f; - else - Codepoint = InChar & 0x07; - - ++InPtr; - - if (((*InPtr & 0xc0) != 0x80) && (Codepoint <= 0x10ffff)) - { - if (Codepoint > 0xffff) - { - OutString.append(1, static_cast<wchar_t>(0xd800 + (Codepoint >> 10))); - OutString.append(1, static_cast<wchar_t>(0xdc00 + (Codepoint & 0x03ff))); - } - else if (Codepoint < 0xd800 || Codepoint >= 0xe000) - { - OutString.append(1, static_cast<wchar_t>(Codepoint)); - } - } - } - - return OutString; -} - -namespace zen { - -using namespace std::literals; - -class HttpSysServer; -class HttpSysTransaction; -class HttpMessageResponseRequest; - -////////////////////////////////////////////////////////////////////////// - -HttpVerb -TranslateHttpVerb(HTTP_VERB ReqVerb) -{ - switch (ReqVerb) - { - case HttpVerbOPTIONS: - return HttpVerb::kOptions; - - case HttpVerbGET: - return HttpVerb::kGet; - - case HttpVerbHEAD: - return HttpVerb::kHead; - - case HttpVerbPOST: - return HttpVerb::kPost; - - case HttpVerbPUT: - return HttpVerb::kPut; - - case HttpVerbDELETE: - return HttpVerb::kDelete; - - case HttpVerbCOPY: - return HttpVerb::kCopy; - - default: - // TODO: invalid request? - return (HttpVerb)0; - } -} - -uint64_t -GetContentLength(const HTTP_REQUEST* HttpRequest) -{ - const HTTP_KNOWN_HEADER& clh = HttpRequest->Headers.KnownHeaders[HttpHeaderContentLength]; - std::string_view cl(clh.pRawValue, clh.RawValueLength); - uint64_t ContentLength = 0; - std::from_chars(cl.data(), cl.data() + cl.size(), ContentLength); - return ContentLength; -}; - -HttpContentType -GetContentType(const HTTP_REQUEST* HttpRequest) -{ - const HTTP_KNOWN_HEADER& CtHdr = HttpRequest->Headers.KnownHeaders[HttpHeaderContentType]; - return ParseContentType({CtHdr.pRawValue, CtHdr.RawValueLength}); -}; - -HttpContentType -GetAcceptType(const HTTP_REQUEST* HttpRequest) -{ - const HTTP_KNOWN_HEADER& CtHdr = HttpRequest->Headers.KnownHeaders[HttpHeaderAccept]; - return ParseContentType({CtHdr.pRawValue, CtHdr.RawValueLength}); -}; - -/** - * @brief Base class for any pending or active HTTP transactions - */ -class HttpSysRequestHandler -{ -public: - explicit HttpSysRequestHandler(HttpSysTransaction& Transaction) : m_Transaction(Transaction) {} - virtual ~HttpSysRequestHandler() = default; - - virtual void IssueRequest(std::error_code& ErrorCode) = 0; - virtual HttpSysRequestHandler* HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) = 0; - HttpSysTransaction& Transaction() { return m_Transaction; } - - HttpSysRequestHandler(const HttpSysRequestHandler&) = delete; - HttpSysRequestHandler& operator=(const HttpSysRequestHandler&) = delete; - -private: - HttpSysTransaction& m_Transaction; -}; - -/** - * This is the handler for the initial HTTP I/O request which will receive the headers - * and however much of the remaining payload might fit in the embedded request buffer. - * - * It is also used to receive any entity body data relating to the request - * - */ -struct InitialRequestHandler : public HttpSysRequestHandler -{ - inline HTTP_REQUEST* HttpRequest() { return (HTTP_REQUEST*)m_RequestBuffer; } - inline uint32_t RequestBufferSize() const { return sizeof m_RequestBuffer; } - inline bool IsInitialRequest() const { return m_IsInitialRequest; } - - InitialRequestHandler(HttpSysTransaction& InRequest); - ~InitialRequestHandler(); - - virtual void IssueRequest(std::error_code& ErrorCode) override final; - virtual HttpSysRequestHandler* HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) override; - - bool m_IsInitialRequest = true; - uint64_t m_CurrentPayloadOffset = 0; - uint64_t m_ContentLength = ~uint64_t(0); - IoBuffer m_PayloadBuffer; - UCHAR m_RequestBuffer[4096 + sizeof(HTTP_REQUEST)]; -}; - -/** - * This is the class which request handlers use to interact with the server instance - */ - -class HttpSysServerRequest : public HttpServerRequest -{ -public: - HttpSysServerRequest(HttpSysTransaction& Tx, HttpService& Service, IoBuffer PayloadBuffer); - ~HttpSysServerRequest() = default; - - virtual Oid ParseSessionId() const override; - virtual uint32_t ParseRequestId() const override; - - virtual IoBuffer ReadPayload() override; - virtual void WriteResponse(HttpResponseCode ResponseCode) override; - virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span<IoBuffer> Blobs) override; - virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::u8string_view ResponseString) override; - virtual void WriteResponseAsync(std::function<void(HttpServerRequest&)>&& ContinuationHandler) override; - virtual bool TryGetRanges(HttpRanges& Ranges) override; - - using HttpServerRequest::WriteResponse; - - HttpSysServerRequest(const HttpSysServerRequest&) = delete; - HttpSysServerRequest& operator=(const HttpSysServerRequest&) = delete; - - HttpSysTransaction& m_HttpTx; - HttpSysRequestHandler* m_NextCompletionHandler = nullptr; - IoBuffer m_PayloadBuffer; - ExtendableStringBuilder<128> m_UriUtf8; - ExtendableStringBuilder<128> m_QueryStringUtf8; -}; - -/** HTTP transaction - - There will be an instance of this per pending and in-flight HTTP transaction - - */ -class HttpSysTransaction final -{ -public: - HttpSysTransaction(HttpSysServer& Server); - virtual ~HttpSysTransaction(); - - enum class Status - { - kDone, - kRequestPending - }; - - Status HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred); - - static void __stdcall IoCompletionCallback(PTP_CALLBACK_INSTANCE Instance, - PVOID pContext /* HttpSysServer */, - PVOID pOverlapped, - ULONG IoResult, - ULONG_PTR NumberOfBytesTransferred, - PTP_IO Io); - - void IssueInitialRequest(std::error_code& ErrorCode); - bool IssueNextRequest(HttpSysRequestHandler* NewCompletionHandler); - - PTP_IO Iocp(); - HANDLE RequestQueueHandle(); - inline OVERLAPPED* Overlapped() { return &m_HttpOverlapped; } - inline HttpSysServer& Server() { return m_HttpServer; } - inline HTTP_REQUEST* HttpRequest() { return m_InitialHttpHandler.HttpRequest(); } - - HttpSysServerRequest& InvokeRequestHandler(HttpService& Service, IoBuffer Payload); - - HttpSysServerRequest& ServerRequest() { return m_HandlerRequest.value(); } - -private: - OVERLAPPED m_HttpOverlapped{}; - HttpSysServer& m_HttpServer; - - // Tracks which handler is due to handle the next I/O completion event - HttpSysRequestHandler* m_CompletionHandler = nullptr; - RwLock m_CompletionMutex; - InitialRequestHandler m_InitialHttpHandler{*this}; - std::optional<HttpSysServerRequest> m_HandlerRequest; - Ref<IHttpPackageHandler> m_PackageHandler; -}; - -/** - * @brief HTTP request response I/O request handler - * - * Asynchronously streams out a response to an HTTP request via compound - * responses from memory or directly from file - */ - -class HttpMessageResponseRequest : public HttpSysRequestHandler -{ -public: - HttpMessageResponseRequest(HttpSysTransaction& InRequest, uint16_t ResponseCode); - HttpMessageResponseRequest(HttpSysTransaction& InRequest, uint16_t ResponseCode, std::string_view Message); - HttpMessageResponseRequest(HttpSysTransaction& InRequest, - uint16_t ResponseCode, - HttpContentType ContentType, - const void* Payload, - size_t PayloadSize); - HttpMessageResponseRequest(HttpSysTransaction& InRequest, - uint16_t ResponseCode, - HttpContentType ContentType, - std::span<IoBuffer> Blobs); - ~HttpMessageResponseRequest(); - - virtual void IssueRequest(std::error_code& ErrorCode) override final; - virtual HttpSysRequestHandler* HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) override; - void SuppressResponseBody(); // typically used for HEAD requests - -private: - std::vector<HTTP_DATA_CHUNK> m_HttpDataChunks; - uint64_t m_TotalDataSize = 0; // Sum of all chunk sizes - uint16_t m_ResponseCode = 0; - uint32_t m_NextDataChunkOffset = 0; // Cursor used for very large chunk lists - uint32_t m_RemainingChunkCount = 0; // Backlog for multi-call sends - bool m_IsInitialResponse = true; - HttpContentType m_ContentType = HttpContentType::kBinary; - std::vector<IoBuffer> m_DataBuffers; - - void InitializeForPayload(uint16_t ResponseCode, std::span<IoBuffer> Blobs); -}; - -HttpMessageResponseRequest::HttpMessageResponseRequest(HttpSysTransaction& InRequest, uint16_t ResponseCode) -: HttpSysRequestHandler(InRequest) -{ - std::array<IoBuffer, 0> EmptyBufferList; - - InitializeForPayload(ResponseCode, EmptyBufferList); -} - -HttpMessageResponseRequest::HttpMessageResponseRequest(HttpSysTransaction& InRequest, uint16_t ResponseCode, std::string_view Message) -: HttpSysRequestHandler(InRequest) -, m_ContentType(HttpContentType::kText) -{ - IoBuffer MessageBuffer(IoBuffer::Wrap, Message.data(), Message.size()); - std::array<IoBuffer, 1> SingleBufferList({MessageBuffer}); - - InitializeForPayload(ResponseCode, SingleBufferList); -} - -HttpMessageResponseRequest::HttpMessageResponseRequest(HttpSysTransaction& InRequest, - uint16_t ResponseCode, - HttpContentType ContentType, - const void* Payload, - size_t PayloadSize) -: HttpSysRequestHandler(InRequest) -, m_ContentType(ContentType) -{ - IoBuffer MessageBuffer(IoBuffer::Wrap, Payload, PayloadSize); - std::array<IoBuffer, 1> SingleBufferList({MessageBuffer}); - - InitializeForPayload(ResponseCode, SingleBufferList); -} - -HttpMessageResponseRequest::HttpMessageResponseRequest(HttpSysTransaction& InRequest, - uint16_t ResponseCode, - HttpContentType ContentType, - std::span<IoBuffer> BlobList) -: HttpSysRequestHandler(InRequest) -, m_ContentType(ContentType) -{ - InitializeForPayload(ResponseCode, BlobList); -} - -HttpMessageResponseRequest::~HttpMessageResponseRequest() -{ -} - -void -HttpMessageResponseRequest::InitializeForPayload(uint16_t ResponseCode, std::span<IoBuffer> BlobList) -{ - const uint32_t ChunkCount = gsl::narrow<uint32_t>(BlobList.size()); - - m_HttpDataChunks.reserve(ChunkCount); - m_DataBuffers.reserve(ChunkCount); - - for (IoBuffer& Buffer : BlobList) - { - m_DataBuffers.emplace_back(std::move(Buffer)).MakeOwned(); - } - - // Initialize the full array up front - - uint64_t LocalDataSize = 0; - - for (IoBuffer& Buffer : m_DataBuffers) - { - uint64_t BufferDataSize = Buffer.Size(); - - ZEN_ASSERT(BufferDataSize); - - LocalDataSize += BufferDataSize; - - IoBufferFileReference FileRef; - if (Buffer.GetFileReference(/* out */ FileRef)) - { - // Use direct file transfer - - m_HttpDataChunks.push_back({}); - auto& Chunk = m_HttpDataChunks.back(); - - Chunk.DataChunkType = HttpDataChunkFromFileHandle; - Chunk.FromFileHandle.FileHandle = FileRef.FileHandle; - Chunk.FromFileHandle.ByteRange.StartingOffset.QuadPart = FileRef.FileChunkOffset; - Chunk.FromFileHandle.ByteRange.Length.QuadPart = BufferDataSize; - } - else - { - // Send from memory, need to make sure we chunk the buffer up since - // the underlying data structure only accepts 32-bit chunk sizes for - // memory chunks. When this happens the vector will be reallocated, - // which is fine since this will be a pretty rare case and sending - // the data is going to take a lot longer than a memory allocation :) - - const uint8_t* WriteCursor = reinterpret_cast<const uint8_t*>(Buffer.Data()); - - while (BufferDataSize) - { - const ULONG ThisChunkSize = gsl::narrow<ULONG>(zen::Min(1 * 1024 * 1024 * 1024, BufferDataSize)); - - m_HttpDataChunks.push_back({}); - auto& Chunk = m_HttpDataChunks.back(); - - Chunk.DataChunkType = HttpDataChunkFromMemory; - Chunk.FromMemory.pBuffer = (void*)WriteCursor; - Chunk.FromMemory.BufferLength = ThisChunkSize; - - BufferDataSize -= ThisChunkSize; - WriteCursor += ThisChunkSize; - } - } - } - - m_RemainingChunkCount = gsl::narrow<uint32_t>(m_HttpDataChunks.size()); - m_TotalDataSize = LocalDataSize; - - if (m_TotalDataSize == 0 && ResponseCode == 200) - { - // Some HTTP clients really don't like empty responses unless a 204 response is sent - m_ResponseCode = uint16_t(HttpResponseCode::NoContent); - } - else - { - m_ResponseCode = ResponseCode; - } -} - -void -HttpMessageResponseRequest::SuppressResponseBody() -{ - m_RemainingChunkCount = 0; - m_HttpDataChunks.clear(); - m_DataBuffers.clear(); -} - -HttpSysRequestHandler* -HttpMessageResponseRequest::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) -{ - ZEN_UNUSED(NumberOfBytesTransferred); - - if (IoResult != NO_ERROR) - { - ZEN_WARN("response aborted due to error: '{}'", GetSystemErrorAsString(IoResult)); - - // if one transmit failed there's really no need to go on - return nullptr; - } - - if (m_RemainingChunkCount == 0) - { - return nullptr; // All done - } - - return this; -} - -void -HttpMessageResponseRequest::IssueRequest(std::error_code& ErrorCode) -{ - HttpSysTransaction& Tx = Transaction(); - HTTP_REQUEST* const HttpReq = Tx.HttpRequest(); - PTP_IO const Iocp = Tx.Iocp(); - - StartThreadpoolIo(Iocp); - - // Split payload into batches to play well with the underlying API - - const int MaxChunksPerCall = 9999; - - const int ThisRequestChunkCount = std::min<int>(m_RemainingChunkCount, MaxChunksPerCall); - const int ThisRequestChunkOffset = m_NextDataChunkOffset; - - m_RemainingChunkCount -= ThisRequestChunkCount; - m_NextDataChunkOffset += ThisRequestChunkCount; - - /* Should this code also use HTTP_SEND_RESPONSE_FLAG_BUFFER_DATA? - - From the docs: - - This flag enables buffering of data in the kernel on a per-response basis. It should - be used by an application doing synchronous I/O, or by a an application doing - asynchronous I/O with no more than one send outstanding at a time. - - Applications using asynchronous I/O which may have more than one send outstanding at - a time should not use this flag. - - When this flag is set, it should be used consistently in calls to the - HttpSendHttpResponse function as well. - */ - - ULONG SendFlags = HTTP_SEND_RESPONSE_FLAG_BUFFER_DATA; - - if (m_RemainingChunkCount) - { - // We need to make more calls to send the full amount of data - SendFlags |= HTTP_SEND_RESPONSE_FLAG_MORE_DATA; - } - - ULONG SendResult = 0; - - if (m_IsInitialResponse) - { - // Populate response structure - - HTTP_RESPONSE HttpResponse = {}; - - HttpResponse.EntityChunkCount = USHORT(ThisRequestChunkCount); - HttpResponse.pEntityChunks = m_HttpDataChunks.data() + ThisRequestChunkOffset; - - // Server header - // - // By default this will also add a suffix " Microsoft-HTTPAPI/2.0" to this header - // - // This is controlled via a registry key 'DisableServerHeader', at: - // - // Computer\HKEY_LOCAL_MACHINE\SYSTEM\CurrentControlSet\Services\HTTP\Parameters - // - // Set DisableServerHeader to 1 to disable suffix, or 2 to disable the header altogether - // (only the latter appears to do anything in my testing, on Windows 10). - // - // (reference https://docs.microsoft.com/en-us/archive/blogs/dsnotes/wswcf-remove-server-header) - // - - PHTTP_KNOWN_HEADER ServerHeader = &HttpResponse.Headers.KnownHeaders[HttpHeaderServer]; - ServerHeader->pRawValue = "Zen"; - ServerHeader->RawValueLength = (USHORT)3; - - // Content-length header - - char ContentLengthString[32]; - _ui64toa_s(m_TotalDataSize, ContentLengthString, sizeof ContentLengthString, 10); - - PHTTP_KNOWN_HEADER ContentLengthHeader = &HttpResponse.Headers.KnownHeaders[HttpHeaderContentLength]; - ContentLengthHeader->pRawValue = ContentLengthString; - ContentLengthHeader->RawValueLength = (USHORT)strlen(ContentLengthString); - - // Content-type header - - PHTTP_KNOWN_HEADER ContentTypeHeader = &HttpResponse.Headers.KnownHeaders[HttpHeaderContentType]; - - std::string_view ContentTypeString = MapContentTypeToString(m_ContentType); - - ContentTypeHeader->pRawValue = ContentTypeString.data(); - ContentTypeHeader->RawValueLength = (USHORT)ContentTypeString.size(); - - std::string_view ReasonString = ReasonStringForHttpResultCode(m_ResponseCode); - - HttpResponse.StatusCode = m_ResponseCode; - HttpResponse.pReason = ReasonString.data(); - HttpResponse.ReasonLength = (USHORT)ReasonString.size(); - - // Cache policy - - HTTP_CACHE_POLICY CachePolicy; - - CachePolicy.Policy = HttpCachePolicyNocache; // HttpCachePolicyUserInvalidates; - CachePolicy.SecondsToLive = 0; - - // Initial response API call - - SendResult = HttpSendHttpResponse(Tx.RequestQueueHandle(), - HttpReq->RequestId, - SendFlags, - &HttpResponse, - &CachePolicy, - NULL, - NULL, - 0, - Tx.Overlapped(), - NULL); - - m_IsInitialResponse = false; - } - else - { - // Subsequent response API calls - - SendResult = HttpSendResponseEntityBody(Tx.RequestQueueHandle(), - HttpReq->RequestId, - SendFlags, - (USHORT)ThisRequestChunkCount, // EntityChunkCount - &m_HttpDataChunks[ThisRequestChunkOffset], // EntityChunks - NULL, // BytesSent - NULL, // Reserved1 - 0, // Reserved2 - Tx.Overlapped(), // Overlapped - NULL // LogData - ); - } - - if (SendResult == NO_ERROR) - { - // Synchronous completion, but the completion event will still be posted to IOCP - - ErrorCode.clear(); - } - else if (SendResult == ERROR_IO_PENDING) - { - // Asynchronous completion, a completion notification will be posted to IOCP - - ErrorCode.clear(); - } - else - { - // An error occurred, no completion will be posted to IOCP - - CancelThreadpoolIo(Iocp); - - ZEN_WARN("failed to send HTTP response (error: '{}'), request URL: '{}', request id: {}", - GetSystemErrorAsString(SendResult), - HttpReq->pRawUrl, - HttpReq->RequestId); - - ErrorCode = MakeErrorCode(SendResult); - } -} - -/** HTTP completion handler for async work - - This is used to allow work to be taken off the request handler threads - and to support posting responses asynchronously. - */ - -class HttpAsyncWorkRequest : public HttpSysRequestHandler -{ -public: - HttpAsyncWorkRequest(HttpSysTransaction& Tx, std::function<void(HttpServerRequest&)>&& Response); - ~HttpAsyncWorkRequest(); - - virtual void IssueRequest(std::error_code& ErrorCode) override final; - virtual HttpSysRequestHandler* HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) override; - -private: - struct AsyncWorkItem : public IWork - { - virtual void Execute() override; - - AsyncWorkItem(HttpSysTransaction& InTx, std::function<void(HttpServerRequest&)>&& InHandler) - : Tx(InTx) - , Handler(std::move(InHandler)) - { - } - - HttpSysTransaction& Tx; - std::function<void(HttpServerRequest&)> Handler; - }; - - Ref<AsyncWorkItem> m_WorkItem; -}; - -HttpAsyncWorkRequest::HttpAsyncWorkRequest(HttpSysTransaction& Tx, std::function<void(HttpServerRequest&)>&& Response) -: HttpSysRequestHandler(Tx) -{ - m_WorkItem = new AsyncWorkItem(Tx, std::move(Response)); -} - -HttpAsyncWorkRequest::~HttpAsyncWorkRequest() -{ -} - -void -HttpAsyncWorkRequest::IssueRequest(std::error_code& ErrorCode) -{ - ErrorCode.clear(); - - Transaction().Server().WorkPool().ScheduleWork(m_WorkItem); -} - -HttpSysRequestHandler* -HttpAsyncWorkRequest::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) -{ - // This ought to not be called since there should be no outstanding I/O request - // when this completion handler is active - - ZEN_UNUSED(IoResult, NumberOfBytesTransferred); - - ZEN_WARN("Unexpected I/O completion during async work! IoResult: {}, NumberOfBytesTransferred: {}", IoResult, NumberOfBytesTransferred); - - return this; -} - -void -HttpAsyncWorkRequest::AsyncWorkItem::Execute() -{ - try - { - HttpSysServerRequest& ThisRequest = Tx.ServerRequest(); - - ThisRequest.m_NextCompletionHandler = nullptr; - - Handler(ThisRequest); - - // TODO: should Handler be destroyed at this point to ensure there - // are no outstanding references into state which could be - // deleted asynchronously as a result of issuing the response? - - if (HttpSysRequestHandler* NextHandler = ThisRequest.m_NextCompletionHandler) - { - return (void)Tx.IssueNextRequest(NextHandler); - } - else if (!ThisRequest.IsHandled()) - { - return (void)Tx.IssueNextRequest(new HttpMessageResponseRequest(Tx, 404, "Not found"sv)); - } - else - { - // "Handled" but no request handler? Shouldn't ever happen - return (void)Tx.IssueNextRequest( - new HttpMessageResponseRequest(Tx, 500, "Response generated but no request handler scheduled"sv)); - } - } - catch (std::exception& Ex) - { - return (void)Tx.IssueNextRequest( - new HttpMessageResponseRequest(Tx, 500, fmt::format("Exception thrown in async work: '{}'", Ex.what()))); - } -} - -/** - _________ - / _____/ ______________ __ ___________ - \_____ \_/ __ \_ __ \ \/ // __ \_ __ \ - / \ ___/| | \/\ /\ ___/| | \/ - /_______ /\___ >__| \_/ \___ >__| - \/ \/ \/ -*/ - -HttpSysServer::HttpSysServer(unsigned int ThreadCount, unsigned int AsyncWorkThreadCount) -: m_Log(logging::Get("http")) -, m_RequestLog(logging::Get("http_requests")) -, m_ThreadPool(ThreadCount) -, m_AsyncWorkPool(AsyncWorkThreadCount) -{ - ULONG Result = HttpInitialize(HTTPAPI_VERSION_2, HTTP_INITIALIZE_SERVER, nullptr); - - if (Result != NO_ERROR) - { - return; - } - - m_IsHttpInitialized = true; - m_IsOk = true; - - ZEN_INFO("http.sys server started, using {} I/O threads and {} async worker threads", ThreadCount, AsyncWorkThreadCount); -} - -HttpSysServer::~HttpSysServer() -{ - if (m_IsHttpInitialized) - { - Cleanup(); - - HttpTerminate(HTTP_INITIALIZE_SERVER, nullptr); - } -} - -int -HttpSysServer::InitializeServer(int BasePort) -{ - using namespace std::literals; - - WideStringBuilder<64> WildcardUrlPath; - WildcardUrlPath << u8"http://*:"sv << int64_t(BasePort) << u8"/"sv; - - m_IsOk = false; - - ULONG Result = HttpCreateServerSession(HTTPAPI_VERSION_2, &m_HttpSessionId, 0); - - if (Result != NO_ERROR) - { - ZEN_ERROR("Failed to create server session for '{}': {:#x}", WideToUtf8(WildcardUrlPath), Result); - - return BasePort; - } - - Result = HttpCreateUrlGroup(m_HttpSessionId, &m_HttpUrlGroupId, 0); - - if (Result != NO_ERROR) - { - ZEN_ERROR("Failed to create URL group for '{}': {:#x}", WideToUtf8(WildcardUrlPath), Result); - - return BasePort; - } - - int EffectivePort = BasePort; - - Result = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, WildcardUrlPath.c_str(), HTTP_URL_CONTEXT(0), 0); - - // Sharing violation implies the port is being used by another process - for (int PortOffset = 1; (Result == ERROR_SHARING_VIOLATION) && (PortOffset < 10); ++PortOffset) - { - EffectivePort = BasePort + (PortOffset * 100); - WildcardUrlPath.Reset(); - WildcardUrlPath << u8"http://*:"sv << int64_t(EffectivePort) << u8"/"sv; - - Result = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, WildcardUrlPath.c_str(), HTTP_URL_CONTEXT(0), 0); - } - - m_BaseUris.clear(); - if (Result == NO_ERROR) - { - m_BaseUris.push_back(WildcardUrlPath.c_str()); - } - else if (Result == ERROR_ACCESS_DENIED) - { - // If we can't register the wildcard path, we fall back to local paths - // This local paths allow requests originating locally to function, but will not allow - // remote origin requests to function. This can be remedied by using netsh - // during an install process to grant permissions to route public access to the appropriate - // port for the current user. eg: - // netsh http add urlacl url=http://*:1337/ user=<some_user> - - ZEN_WARN("Unable to register handler using '{}' - falling back to local-only", WideToUtf8(WildcardUrlPath)); - - const std::u8string_view Hosts[] = {u8"[::1]"sv, u8"localhost"sv, u8"127.0.0.1"sv}; - - ULONG InternalResult = ERROR_SHARING_VIOLATION; - for (int PortOffset = 0; (InternalResult == ERROR_SHARING_VIOLATION) && (PortOffset < 10); ++PortOffset) - { - EffectivePort = BasePort + (PortOffset * 100); - - for (const std::u8string_view Host : Hosts) - { - WideStringBuilder<64> LocalUrlPath; - LocalUrlPath << u8"http://"sv << Host << u8":"sv << int64_t(EffectivePort) << u8"/"sv; - - InternalResult = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, LocalUrlPath.c_str(), HTTP_URL_CONTEXT(0), 0); - - if (InternalResult == NO_ERROR) - { - ZEN_INFO("Registered local handler '{}'", WideToUtf8(LocalUrlPath)); - - m_BaseUris.push_back(LocalUrlPath.c_str()); - } - else - { - break; - } - } - } - } - - if (m_BaseUris.empty()) - { - ZEN_ERROR("Failed to add base URL to URL group for '{}': {:#x}", WideToUtf8(WildcardUrlPath), Result); - - return BasePort; - } - - HTTP_BINDING_INFO HttpBindingInfo = {{0}, 0}; - - Result = HttpCreateRequestQueue(HTTPAPI_VERSION_2, - /* Name */ nullptr, - /* SecurityAttributes */ nullptr, - /* Flags */ 0, - &m_RequestQueueHandle); - - if (Result != NO_ERROR) - { - ZEN_ERROR("Failed to create request queue for '{}': {:#x}", WideToUtf8(m_BaseUris.front()), Result); - - return EffectivePort; - } - - HttpBindingInfo.Flags.Present = 1; - HttpBindingInfo.RequestQueueHandle = m_RequestQueueHandle; - - Result = HttpSetUrlGroupProperty(m_HttpUrlGroupId, HttpServerBindingProperty, &HttpBindingInfo, sizeof(HttpBindingInfo)); - - if (Result != NO_ERROR) - { - ZEN_ERROR("Failed to set server binding property for '{}': {:#x}", WideToUtf8(m_BaseUris.front()), Result); - - return EffectivePort; - } - - // Create I/O completion port - - std::error_code ErrorCode; - m_ThreadPool.CreateIocp(m_RequestQueueHandle, HttpSysTransaction::IoCompletionCallback, /* Context */ this, /* out */ ErrorCode); - - if (ErrorCode) - { - ZEN_ERROR("Failed to create IOCP for '{}': {}", WideToUtf8(m_BaseUris.front()), ErrorCode.message()); - } - else - { - m_IsOk = true; - - ZEN_INFO("Started http.sys server at '{}'", WideToUtf8(m_BaseUris.front())); - } - - return EffectivePort; -} - -void -HttpSysServer::Cleanup() -{ - ++m_IsShuttingDown; - - if (m_RequestQueueHandle) - { - HttpCloseRequestQueue(m_RequestQueueHandle); - m_RequestQueueHandle = nullptr; - } - - if (m_HttpUrlGroupId) - { - HttpCloseUrlGroup(m_HttpUrlGroupId); - m_HttpUrlGroupId = 0; - } - - if (m_HttpSessionId) - { - HttpCloseServerSession(m_HttpSessionId); - m_HttpSessionId = 0; - } -} - -void -HttpSysServer::StartServer() -{ - const int InitialRequestCount = 32; - - for (int i = 0; i < InitialRequestCount; ++i) - { - IssueNewRequestMaybe(); - } -} - -void -HttpSysServer::Run(bool IsInteractive) -{ - if (IsInteractive) - { - zen::logging::ConsoleLog().info("Zen Server running. Press ESC or Q to quit"); - } - - do - { - // int WaitTimeout = -1; - int WaitTimeout = 100; - - if (IsInteractive) - { - WaitTimeout = 1000; - - if (_kbhit() != 0) - { - char c = (char)_getch(); - - if (c == 27 || c == 'Q' || c == 'q') - { - RequestApplicationExit(0); - } - } - } - - m_ShutdownEvent.Wait(WaitTimeout); - UpdateLofreqTimerValue(); - } while (!IsApplicationExitRequested()); -} - -void -HttpSysServer::OnHandlingRequest() -{ - if (--m_PendingRequests > m_MinPendingRequests) - { - // We have more than the minimum number of requests pending, just let someone else - // enqueue new requests - return; - } - - IssueNewRequestMaybe(); -} - -void -HttpSysServer::IssueNewRequestMaybe() -{ - if (m_IsShuttingDown.load(std::memory_order::acquire)) - { - return; - } - - if (m_PendingRequests.load(std::memory_order::relaxed) >= m_MaxPendingRequests) - { - return; - } - - std::unique_ptr<HttpSysTransaction> Request = std::make_unique<HttpSysTransaction>(*this); - - std::error_code ErrorCode; - Request->IssueInitialRequest(ErrorCode); - - if (ErrorCode) - { - // No request was actually issued. What is the appropriate response? - - return; - } - - // This may end up exceeding the MaxPendingRequests limit, but it's not - // really a problem. I'm doing it this way mostly to avoid dealing with - // exceptions here - ++m_PendingRequests; - - Request.release(); -} - -void -HttpSysServer::RegisterService(const char* UrlPath, HttpService& Service) -{ - if (UrlPath[0] == '/') - { - ++UrlPath; - } - - const std::wstring PathUtf16 = UTF8_to_UTF16(UrlPath); - Service.SetUriPrefixLength(PathUtf16.size() + 1 /* leading slash */); - - // Convert to wide string - - for (const std::wstring& BaseUri : m_BaseUris) - { - std::wstring Url16 = BaseUri + PathUtf16; - - ULONG Result = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, Url16.c_str(), HTTP_URL_CONTEXT(&Service), 0 /* Reserved */); - - if (Result != NO_ERROR) - { - ZEN_ERROR("HttpAddUrlToUrlGroup failed with result: '{}'", GetSystemErrorAsString(Result)); - - return; - } - } -} - -void -HttpSysServer::UnregisterService(const char* UrlPath, HttpService& Service) -{ - ZEN_UNUSED(Service); - - if (UrlPath[0] == '/') - { - ++UrlPath; - } - - const std::wstring PathUtf16 = UTF8_to_UTF16(UrlPath); - - // Convert to wide string - - for (const std::wstring& BaseUri : m_BaseUris) - { - std::wstring Url16 = BaseUri + PathUtf16; - - ULONG Result = HttpRemoveUrlFromUrlGroup(m_HttpUrlGroupId, Url16.c_str(), 0); - - if (Result != NO_ERROR) - { - ZEN_ERROR("HttpRemoveUrlFromUrlGroup failed with result: '{}'", GetSystemErrorAsString(Result)); - } - } -} - -////////////////////////////////////////////////////////////////////////// - -HttpSysTransaction::HttpSysTransaction(HttpSysServer& Server) : m_HttpServer(Server), m_CompletionHandler(&m_InitialHttpHandler) -{ -} - -HttpSysTransaction::~HttpSysTransaction() -{ -} - -PTP_IO -HttpSysTransaction::Iocp() -{ - return m_HttpServer.m_ThreadPool.Iocp(); -} - -HANDLE -HttpSysTransaction::RequestQueueHandle() -{ - return m_HttpServer.m_RequestQueueHandle; -} - -void -HttpSysTransaction::IssueInitialRequest(std::error_code& ErrorCode) -{ - m_InitialHttpHandler.IssueRequest(ErrorCode); -} - -void -HttpSysTransaction::IoCompletionCallback(PTP_CALLBACK_INSTANCE Instance, - PVOID pContext /* HttpSysServer */, - PVOID pOverlapped, - ULONG IoResult, - ULONG_PTR NumberOfBytesTransferred, - PTP_IO Io) -{ - UNREFERENCED_PARAMETER(Io); - UNREFERENCED_PARAMETER(Instance); - UNREFERENCED_PARAMETER(pContext); - - // Note that for a given transaction we may be in this completion function on more - // than one thread at any given moment. This means we need to be careful about what - // happens in here - - HttpSysTransaction* Transaction = CONTAINING_RECORD(pOverlapped, HttpSysTransaction, m_HttpOverlapped); - - if (Transaction->HandleCompletion(IoResult, NumberOfBytesTransferred) == HttpSysTransaction::Status::kDone) - { - delete Transaction; - } -} - -bool -HttpSysTransaction::IssueNextRequest(HttpSysRequestHandler* NewCompletionHandler) -{ - HttpSysRequestHandler* CurrentHandler = m_CompletionHandler; - m_CompletionHandler = NewCompletionHandler; - - auto _ = MakeGuard([this, CurrentHandler] { - if ((CurrentHandler != &m_InitialHttpHandler) && (CurrentHandler != m_CompletionHandler)) - { - delete CurrentHandler; - } - }); - - if (NewCompletionHandler == nullptr) - { - return false; - } - - try - { - std::error_code ErrorCode; - m_CompletionHandler->IssueRequest(ErrorCode); - - if (!ErrorCode) - { - return true; - } - - ZEN_WARN("IssueRequest() failed: '{}'", ErrorCode.message()); - } - catch (std::exception& Ex) - { - ZEN_ERROR("exception caught in IssueNextRequest(): '{}'", Ex.what()); - } - - // something went wrong, no request is pending - m_CompletionHandler = nullptr; - - return false; -} - -HttpSysTransaction::Status -HttpSysTransaction::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) -{ - // We use this to ensure sequential execution of completion handlers - // for any given transaction. It also ensures all member variables are - // in a consistent state for the current thread - - RwLock::ExclusiveLockScope _(m_CompletionMutex); - - bool IsRequestPending = false; - - if (HttpSysRequestHandler* CurrentHandler = m_CompletionHandler) - { - if ((CurrentHandler == &m_InitialHttpHandler) && m_InitialHttpHandler.IsInitialRequest()) - { - // Ensure we have a sufficient number of pending requests outstanding - m_HttpServer.OnHandlingRequest(); - } - - auto NewCompletionHandler = CurrentHandler->HandleCompletion(IoResult, NumberOfBytesTransferred); - - IsRequestPending = IssueNextRequest(NewCompletionHandler); - } - - // Ensure new requests are enqueued as necessary - m_HttpServer.IssueNewRequestMaybe(); - - if (IsRequestPending) - { - // There is another request pending on this transaction, so it needs to remain valid - return Status::kRequestPending; - } - - if (m_HttpServer.m_IsRequestLoggingEnabled) - { - if (m_HandlerRequest.has_value()) - { - m_HttpServer.m_RequestLog.info("{} {}", ToString(m_HandlerRequest->RequestVerb()), m_HandlerRequest->RelativeUri()); - } - } - - // Transaction done, caller should clean up (delete) this instance - return Status::kDone; -} - -HttpSysServerRequest& -HttpSysTransaction::InvokeRequestHandler(HttpService& Service, IoBuffer Payload) -{ - HttpSysServerRequest& ThisRequest = m_HandlerRequest.emplace(*this, Service, Payload); - - // Default request handling - - if (!HandlePackageOffers(Service, ThisRequest, m_PackageHandler)) - { - Service.HandleRequest(ThisRequest); - } - - return ThisRequest; -} - -////////////////////////////////////////////////////////////////////////// - -HttpSysServerRequest::HttpSysServerRequest(HttpSysTransaction& Tx, HttpService& Service, IoBuffer PayloadBuffer) -: m_HttpTx(Tx) -, m_PayloadBuffer(std::move(PayloadBuffer)) -{ - const HTTP_REQUEST* HttpRequestPtr = Tx.HttpRequest(); - - const int PrefixLength = Service.UriPrefixLength(); - const int AbsPathLength = HttpRequestPtr->CookedUrl.AbsPathLength / sizeof(wchar_t); - - HttpContentType AcceptContentType = HttpContentType::kUnknownContentType; - - if (AbsPathLength >= PrefixLength) - { - // We convert the URI immediately because most of the code involved prefers to deal - // with utf8. This is overhead which I'd prefer to avoid but for now we just have - // to live with it - - WideToUtf8({(wchar_t*)HttpRequestPtr->CookedUrl.pAbsPath + PrefixLength, gsl::narrow<size_t>(AbsPathLength - PrefixLength)}, - m_UriUtf8); - - std::string_view UriSuffix8{m_UriUtf8}; - - m_UriWithExtension = UriSuffix8; // Retain URI with extension for user access - m_Uri = UriSuffix8; - - const size_t LastComponentIndex = UriSuffix8.find_last_of('/'); - - if (LastComponentIndex != std::string_view::npos) - { - UriSuffix8.remove_prefix(LastComponentIndex); - } - - const size_t LastDotIndex = UriSuffix8.find_last_of('.'); - - if (LastDotIndex != std::string_view::npos) - { - UriSuffix8.remove_prefix(LastDotIndex + 1); - - AcceptContentType = ParseContentType(UriSuffix8); - if (AcceptContentType != HttpContentType::kUnknownContentType) - { - m_Uri.remove_suffix(UriSuffix8.size() + 1); - } - } - } - else - { - m_UriUtf8.Reset(); - m_Uri = {}; - m_UriWithExtension = {}; - } - - if (uint16_t QueryStringLength = HttpRequestPtr->CookedUrl.QueryStringLength) - { - --QueryStringLength; // We skip the leading question mark - - WideToUtf8({(wchar_t*)(HttpRequestPtr->CookedUrl.pQueryString) + 1, QueryStringLength / sizeof(wchar_t)}, m_QueryStringUtf8); - } - else - { - m_QueryStringUtf8.Reset(); - } - - m_QueryString = std::string_view(m_QueryStringUtf8); - m_Verb = TranslateHttpVerb(HttpRequestPtr->Verb); - m_ContentLength = GetContentLength(HttpRequestPtr); - m_ContentType = GetContentType(HttpRequestPtr); - - // It an explicit content type extension was specified then we'll use that over any - // Accept: header value that may be present - - if (AcceptContentType != HttpContentType::kUnknownContentType) - { - m_AcceptType = AcceptContentType; - } - else - { - m_AcceptType = GetAcceptType(HttpRequestPtr); - } - - if (m_Verb == HttpVerb::kHead) - { - SetSuppressResponseBody(); - } -} - -Oid -HttpSysServerRequest::ParseSessionId() const -{ - const HTTP_REQUEST* HttpRequestPtr = m_HttpTx.HttpRequest(); - - for (int i = 0; i < HttpRequestPtr->Headers.UnknownHeaderCount; ++i) - { - HTTP_UNKNOWN_HEADER& Header = HttpRequestPtr->Headers.pUnknownHeaders[i]; - std::string_view HeaderName{Header.pName, Header.NameLength}; - - if (HeaderName == "UE-Session"sv) - { - if (Header.RawValueLength == Oid::StringLength) - { - return Oid::FromHexString({Header.pRawValue, Header.RawValueLength}); - } - } - } - - return {}; -} - -uint32_t -HttpSysServerRequest::ParseRequestId() const -{ - const HTTP_REQUEST* HttpRequestPtr = m_HttpTx.HttpRequest(); - - for (int i = 0; i < HttpRequestPtr->Headers.UnknownHeaderCount; ++i) - { - HTTP_UNKNOWN_HEADER& Header = HttpRequestPtr->Headers.pUnknownHeaders[i]; - std::string_view HeaderName{Header.pName, Header.NameLength}; - - if (HeaderName == "UE-Request"sv) - { - std::string_view RequestValue{Header.pRawValue, Header.RawValueLength}; - uint32_t RequestId = 0; - std::from_chars(RequestValue.data(), RequestValue.data() + RequestValue.size(), RequestId); - return RequestId; - } - } - - return 0; -} - -IoBuffer -HttpSysServerRequest::ReadPayload() -{ - return m_PayloadBuffer; -} - -void -HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode) -{ - ZEN_ASSERT(IsHandled() == false); - - auto Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)ResponseCode); - - if (SuppressBody()) - { - Response->SuppressResponseBody(); - } - - m_NextCompletionHandler = Response; - - SetIsHandled(); -} - -void -HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span<IoBuffer> Blobs) -{ - ZEN_ASSERT(IsHandled() == false); - - auto Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)ResponseCode, ContentType, Blobs); - - if (SuppressBody()) - { - Response->SuppressResponseBody(); - } - - m_NextCompletionHandler = Response; - - SetIsHandled(); -} - -void -HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::u8string_view ResponseString) -{ - ZEN_ASSERT(IsHandled() == false); - - auto Response = - new HttpMessageResponseRequest(m_HttpTx, (uint16_t)ResponseCode, ContentType, ResponseString.data(), ResponseString.size()); - - if (SuppressBody()) - { - Response->SuppressResponseBody(); - } - - m_NextCompletionHandler = Response; - - SetIsHandled(); -} - -void -HttpSysServerRequest::WriteResponseAsync(std::function<void(HttpServerRequest&)>&& ContinuationHandler) -{ - if (m_HttpTx.Server().IsAsyncResponseEnabled()) - { - m_NextCompletionHandler = new HttpAsyncWorkRequest(m_HttpTx, std::move(ContinuationHandler)); - } - else - { - ContinuationHandler(m_HttpTx.ServerRequest()); - } -} - -bool -HttpSysServerRequest::TryGetRanges(HttpRanges& Ranges) -{ - HTTP_REQUEST* Req = m_HttpTx.HttpRequest(); - const HTTP_KNOWN_HEADER& RangeHeader = Req->Headers.KnownHeaders[HttpHeaderRange]; - - return TryParseHttpRangeHeader({RangeHeader.pRawValue, RangeHeader.RawValueLength}, Ranges); -} - -////////////////////////////////////////////////////////////////////////// - -InitialRequestHandler::InitialRequestHandler(HttpSysTransaction& InRequest) : HttpSysRequestHandler(InRequest) -{ -} - -InitialRequestHandler::~InitialRequestHandler() -{ -} - -void -InitialRequestHandler::IssueRequest(std::error_code& ErrorCode) -{ - HttpSysTransaction& Tx = Transaction(); - PTP_IO Iocp = Tx.Iocp(); - HTTP_REQUEST* HttpReq = Tx.HttpRequest(); - - StartThreadpoolIo(Iocp); - - ULONG HttpApiResult; - - if (IsInitialRequest()) - { - HttpApiResult = HttpReceiveHttpRequest(Tx.RequestQueueHandle(), - HTTP_NULL_ID, - HTTP_RECEIVE_REQUEST_FLAG_COPY_BODY, - HttpReq, - RequestBufferSize(), - NULL, - Tx.Overlapped()); - } - else - { - // The http.sys team recommends limiting the size to 128KB - static const uint64_t kMaxBytesPerApiCall = 128 * 1024; - - uint64_t BytesToRead = m_ContentLength - m_CurrentPayloadOffset; - const uint64_t BytesToReadThisCall = zen::Min(BytesToRead, kMaxBytesPerApiCall); - void* BufferWriteCursor = reinterpret_cast<uint8_t*>(m_PayloadBuffer.MutableData()) + m_CurrentPayloadOffset; - - HttpApiResult = HttpReceiveRequestEntityBody(Tx.RequestQueueHandle(), - HttpReq->RequestId, - 0, /* Flags */ - BufferWriteCursor, - gsl::narrow<ULONG>(BytesToReadThisCall), - nullptr, // BytesReturned - Tx.Overlapped()); - } - - if (HttpApiResult != ERROR_IO_PENDING && HttpApiResult != NO_ERROR) - { - CancelThreadpoolIo(Iocp); - - ErrorCode = MakeErrorCode(HttpApiResult); - - ZEN_WARN("HttpReceiveHttpRequest failed, error: '{}'", ErrorCode.message()); - - return; - } - - ErrorCode.clear(); -} - -HttpSysRequestHandler* -InitialRequestHandler::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) -{ - auto _ = MakeGuard([&] { m_IsInitialRequest = false; }); - - switch (IoResult) - { - default: - case ERROR_OPERATION_ABORTED: - return nullptr; - - case ERROR_MORE_DATA: // Insufficient buffer space - case NO_ERROR: - break; - } - - // Route request - - try - { - HTTP_REQUEST* HttpReq = HttpRequest(); - -# if 0 - for (int i = 0; i < HttpReq->RequestInfoCount; ++i) - { - auto& ReqInfo = HttpReq->pRequestInfo[i]; - - switch (ReqInfo.InfoType) - { - case HttpRequestInfoTypeRequestTiming: - { - const HTTP_REQUEST_TIMING_INFO* TimingInfo = reinterpret_cast<HTTP_REQUEST_TIMING_INFO*>(ReqInfo.pInfo); - - ZEN_INFO(""); - } - break; - case HttpRequestInfoTypeAuth: - ZEN_INFO(""); - break; - case HttpRequestInfoTypeChannelBind: - ZEN_INFO(""); - break; - case HttpRequestInfoTypeSslProtocol: - ZEN_INFO(""); - break; - case HttpRequestInfoTypeSslTokenBindingDraft: - ZEN_INFO(""); - break; - case HttpRequestInfoTypeSslTokenBinding: - ZEN_INFO(""); - break; - case HttpRequestInfoTypeTcpInfoV0: - { - const TCP_INFO_v0* TcpInfo = reinterpret_cast<const TCP_INFO_v0*>(ReqInfo.pInfo); - - ZEN_INFO(""); - } - break; - case HttpRequestInfoTypeRequestSizing: - { - const HTTP_REQUEST_SIZING_INFO* SizingInfo = reinterpret_cast<const HTTP_REQUEST_SIZING_INFO*>(ReqInfo.pInfo); - ZEN_INFO(""); - } - break; - case HttpRequestInfoTypeQuicStats: - ZEN_INFO(""); - break; - case HttpRequestInfoTypeTcpInfoV1: - { - const TCP_INFO_v1* TcpInfo = reinterpret_cast<const TCP_INFO_v1*>(ReqInfo.pInfo); - - ZEN_INFO(""); - } - break; - } - } -# endif - - if (HttpService* Service = reinterpret_cast<HttpService*>(HttpReq->UrlContext)) - { - if (m_IsInitialRequest) - { - m_ContentLength = GetContentLength(HttpReq); - const HttpContentType ContentType = GetContentType(HttpReq); - - if (m_ContentLength) - { - // Handle initial chunk read by copying any payload which has already been copied - // into our embedded request buffer - - m_PayloadBuffer = IoBuffer(m_ContentLength); - m_PayloadBuffer.SetContentType(ContentType); - - uint64_t BytesToRead = m_ContentLength; - uint8_t* const BufferBase = reinterpret_cast<uint8_t*>(m_PayloadBuffer.MutableData()); - uint8_t* BufferWriteCursor = BufferBase; - - const int EntityChunkCount = HttpReq->EntityChunkCount; - - for (int i = 0; i < EntityChunkCount; ++i) - { - HTTP_DATA_CHUNK& EntityChunk = HttpReq->pEntityChunks[i]; - - ZEN_ASSERT(EntityChunk.DataChunkType == HttpDataChunkFromMemory); - - const uint64_t BufferLength = EntityChunk.FromMemory.BufferLength; - - ZEN_ASSERT(BufferLength <= BytesToRead); - - memcpy(BufferWriteCursor, EntityChunk.FromMemory.pBuffer, BufferLength); - - BufferWriteCursor += BufferLength; - BytesToRead -= BufferLength; - } - - m_CurrentPayloadOffset = BufferWriteCursor - BufferBase; - } - } - else - { - m_CurrentPayloadOffset += NumberOfBytesTransferred; - } - - if (m_CurrentPayloadOffset != m_ContentLength) - { - // Body not complete, issue another read request to receive more body data - return this; - } - - // Request body received completely - - m_PayloadBuffer.MakeImmutable(); - - HttpSysServerRequest& ThisRequest = Transaction().InvokeRequestHandler(*Service, m_PayloadBuffer); - - if (HttpSysRequestHandler* Response = ThisRequest.m_NextCompletionHandler) - { - return Response; - } - - if (!ThisRequest.IsHandled()) - { - return new HttpMessageResponseRequest(Transaction(), 404, "Not found"sv); - } - } - - // Unable to route - return new HttpMessageResponseRequest(Transaction(), 404, "No suitable route found"sv); - } - catch (std::exception& ex) - { - ZEN_ERROR("Caught exception while handling request: '{}'", ex.what()); - - return new HttpMessageResponseRequest(Transaction(), 500, ex.what()); - } -} - -////////////////////////////////////////////////////////////////////////// -// -// HttpServer interface implementation -// - -int -HttpSysServer::Initialize(int BasePort) -{ - int EffectivePort = InitializeServer(BasePort); - StartServer(); - return EffectivePort; -} - -void -HttpSysServer::RequestExit() -{ - m_ShutdownEvent.Set(); -} -void -HttpSysServer::RegisterService(HttpService& Service) -{ - RegisterService(Service.BaseUri(), Service); -} - -Ref<HttpServer> -CreateHttpSysServer(int Concurrency, int BackgroundWorkerThreads) -{ - return Ref<HttpServer>(new HttpSysServer(Concurrency, BackgroundWorkerThreads)); -} - -} // namespace zen -#endif diff --git a/zenhttp/httpsys.h b/zenhttp/httpsys.h deleted file mode 100644 index d6bd34890..000000000 --- a/zenhttp/httpsys.h +++ /dev/null @@ -1,90 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. - -#pragma once - -#include <zenhttp/httpserver.h> - -#ifndef ZEN_WITH_HTTPSYS -# if ZEN_PLATFORM_WINDOWS -# define ZEN_WITH_HTTPSYS 1 -# else -# define ZEN_WITH_HTTPSYS 0 -# endif -#endif - -#if ZEN_WITH_HTTPSYS -# define _WINSOCKAPI_ -# include <zencore/windows.h> -# include <zencore/workthreadpool.h> -# include "iothreadpool.h" - -# include <http.h> - -namespace spdlog { -class logger; -} - -namespace zen { - -/** - * @brief Windows implementation of HTTP server based on http.sys - * - * This requires elevation to function - */ -class HttpSysServer : public HttpServer -{ - friend class HttpSysTransaction; - -public: - explicit HttpSysServer(unsigned int ThreadCount, unsigned int AsyncWorkThreadCount); - ~HttpSysServer(); - - // HttpServer interface implementation - - virtual int Initialize(int BasePort) override; - virtual void Run(bool TestMode) override; - virtual void RequestExit() override; - virtual void RegisterService(HttpService& Service) override; - - WorkerThreadPool& WorkPool() { return m_AsyncWorkPool; } - - inline bool IsOk() const { return m_IsOk; } - inline bool IsAsyncResponseEnabled() const { return m_IsAsyncResponseEnabled; } - -private: - int InitializeServer(int BasePort); - void Cleanup(); - - void StartServer(); - void OnHandlingRequest(); - void IssueNewRequestMaybe(); - - void RegisterService(const char* Endpoint, HttpService& Service); - void UnregisterService(const char* Endpoint, HttpService& Service); - -private: - spdlog::logger& m_Log; - spdlog::logger& m_RequestLog; - spdlog::logger& Log() { return m_Log; } - - bool m_IsOk = false; - bool m_IsHttpInitialized = false; - bool m_IsRequestLoggingEnabled = false; - bool m_IsAsyncResponseEnabled = true; - - WinIoThreadPool m_ThreadPool; - WorkerThreadPool m_AsyncWorkPool; - - std::vector<std::wstring> m_BaseUris; // eg: http://*:nnnn/ - HTTP_SERVER_SESSION_ID m_HttpSessionId = 0; - HTTP_URL_GROUP_ID m_HttpUrlGroupId = 0; - HANDLE m_RequestQueueHandle = 0; - std::atomic_int32_t m_PendingRequests{0}; - std::atomic_int32_t m_IsShuttingDown{0}; - int32_t m_MinPendingRequests = 16; - int32_t m_MaxPendingRequests = 128; - Event m_ShutdownEvent; -}; - -} // namespace zen -#endif diff --git a/zenhttp/include/zenhttp/httpclient.h b/zenhttp/include/zenhttp/httpclient.h deleted file mode 100644 index 8316a9b9f..000000000 --- a/zenhttp/include/zenhttp/httpclient.h +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. - -#pragma once - -#include "zenhttp.h" - -#include <zencore/iobuffer.h> -#include <zencore/uid.h> -#include <zenhttp/httpcommon.h> - -ZEN_THIRD_PARTY_INCLUDES_START -#include <cpr/cpr.h> -ZEN_THIRD_PARTY_INCLUDES_END - -namespace zen { - -class CbPackage; - -/** HTTP client implementation for Zen use cases - - Currently simple and synchronous, should become lean and asynchronous - */ -class HttpClient -{ -public: - HttpClient(std::string_view BaseUri); - ~HttpClient(); - - struct Response - { - int StatusCode = 0; - IoBuffer ResponsePayload; // Note: this also includes the content type - }; - - [[nodiscard]] Response Put(std::string_view Url, IoBuffer Payload); - [[nodiscard]] Response Get(std::string_view Url); - [[nodiscard]] Response TransactPackage(std::string_view Url, CbPackage Package); - [[nodiscard]] Response Delete(std::string_view Url); - -private: - std::string m_BaseUri; - std::string m_SessionId; -}; - -} // namespace zen - -void httpclient_forcelink(); // internal diff --git a/zenhttp/include/zenhttp/httpcommon.h b/zenhttp/include/zenhttp/httpcommon.h deleted file mode 100644 index 19fda8db4..000000000 --- a/zenhttp/include/zenhttp/httpcommon.h +++ /dev/null @@ -1,181 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. - -#pragma once - -#include <zencore/iobuffer.h> - -#include <string_view> - -#include <gsl/gsl-lite.hpp> - -namespace zen { - -using HttpContentType = ZenContentType; - -class IoBuffer; -class CbObject; -class CbPackage; -class StringBuilderBase; - -struct HttpRange -{ - uint32_t Start = ~uint32_t(0); - uint32_t End = ~uint32_t(0); -}; - -using HttpRanges = std::vector<HttpRange>; - -std::string_view MapContentTypeToString(HttpContentType ContentType); -extern HttpContentType (*ParseContentType)(const std::string_view& ContentTypeString); -std::string_view ReasonStringForHttpResultCode(int HttpCode); -bool TryParseHttpRangeHeader(std::string_view RangeHeader, HttpRanges& Ranges); - -[[nodiscard]] inline bool -IsHttpSuccessCode(int HttpCode) -{ - return (HttpCode >= 200) && (HttpCode < 300); -} - -enum class HttpVerb : uint8_t -{ - kGet = 1 << 0, - kPut = 1 << 1, - kPost = 1 << 2, - kDelete = 1 << 3, - kHead = 1 << 4, - kCopy = 1 << 5, - kOptions = 1 << 6 -}; - -gsl_DEFINE_ENUM_BITMASK_OPERATORS(HttpVerb); - -const std::string_view ToString(HttpVerb Verb); - -enum class HttpResponseCode -{ - // 1xx - Informational - - Continue = 100, //!< Indicates that the initial part of a request has been received and has not yet been rejected by the server. - SwitchingProtocols = 101, //!< Indicates that the server understands and is willing to comply with the client's request, via the - //!< Upgrade header field, for a change in the application protocol being used on this connection. - Processing = 102, //!< Is an interim response used to inform the client that the server has accepted the complete request, but has not - //!< yet completed it. - EarlyHints = 103, //!< Indicates to the client that the server is likely to send a final response with the header fields included in - //!< the informational response. - - // 2xx - Successful - - OK = 200, //!< Indicates that the request has succeeded. - Created = 201, //!< Indicates that the request has been fulfilled and has resulted in one or more new resources being created. - Accepted = 202, //!< Indicates that the request has been accepted for processing, but the processing has not been completed. - NonAuthoritativeInformation = 203, //!< Indicates that the request was successful but the enclosed payload has been modified from that - //!< of the origin server's 200 (OK) response by a transforming proxy. - NoContent = 204, //!< Indicates that the server has successfully fulfilled the request and that there is no additional content to send - //!< in the response payload body. - ResetContent = 205, //!< Indicates that the server has fulfilled the request and desires that the user agent reset the \"document - //!< view\", which caused the request to be sent, to its original state as received from the origin server. - PartialContent = 206, //!< Indicates that the server is successfully fulfilling a range request for the target resource by transferring - //!< one or more parts of the selected representation that correspond to the satisfiable ranges found in the - //!< requests's Range header field. - MultiStatus = 207, //!< Provides status for multiple independent operations. - AlreadyReported = 208, //!< Used inside a DAV:propstat response element to avoid enumerating the internal members of multiple bindings - //!< to the same collection repeatedly. [RFC 5842] - IMUsed = 226, //!< The server has fulfilled a GET request for the resource, and the response is a representation of the result of one - //!< or more instance-manipulations applied to the current instance. - - // 3xx - Redirection - - MultipleChoices = 300, //!< Indicates that the target resource has more than one representation, each with its own more specific - //!< identifier, and information about the alternatives is being provided so that the user (or user agent) can - //!< select a preferred representation by redirecting its request to one or more of those identifiers. - MovedPermanently = 301, //!< Indicates that the target resource has been assigned a new permanent URI and any future references to this - //!< resource ought to use one of the enclosed URIs. - Found = 302, //!< Indicates that the target resource resides temporarily under a different URI. - SeeOther = 303, //!< Indicates that the server is redirecting the user agent to a different resource, as indicated by a URI in the - //!< Location header field, that is intended to provide an indirect response to the original request. - NotModified = 304, //!< Indicates that a conditional GET request has been received and would have resulted in a 200 (OK) response if it - //!< were not for the fact that the condition has evaluated to false. - UseProxy = 305, //!< \deprecated \parblock Due to security concerns regarding in-band configuration of a proxy. \endparblock - //!< The requested resource MUST be accessed through the proxy given by the Location field. - TemporaryRedirect = 307, //!< Indicates that the target resource resides temporarily under a different URI and the user agent MUST NOT - //!< change the request method if it performs an automatic redirection to that URI. - PermanentRedirect = 308, //!< The target resource has been assigned a new permanent URI and any future references to this resource - //!< ought to use one of the enclosed URIs. [...] This status code is similar to 301 Moved Permanently - //!< (Section 7.3.2 of rfc7231), except that it does not allow rewriting the request method from POST to GET. - - // 4xx - Client Error - BadRequest = 400, //!< Indicates that the server cannot or will not process the request because the received syntax is invalid, - //!< nonsensical, or exceeds some limitation on what the server is willing to process. - Unauthorized = 401, //!< Indicates that the request has not been applied because it lacks valid authentication credentials for the - //!< target resource. - PaymentRequired = 402, //!< *Reserved* - Forbidden = 403, //!< Indicates that the server understood the request but refuses to authorize it. - NotFound = 404, //!< Indicates that the origin server did not find a current representation for the target resource or is not willing - //!< to disclose that one exists. - MethodNotAllowed = 405, //!< Indicates that the method specified in the request-line is known by the origin server but not supported by - //!< the target resource. - NotAcceptable = 406, //!< Indicates that the target resource does not have a current representation that would be acceptable to the - //!< user agent, according to the proactive negotiation header fields received in the request, and the server is - //!< unwilling to supply a default representation. - ProxyAuthenticationRequired = - 407, //!< Is similar to 401 (Unauthorized), but indicates that the client needs to authenticate itself in order to use a proxy. - RequestTimeout = - 408, //!< Indicates that the server did not receive a complete request message within the time that it was prepared to wait. - Conflict = 409, //!< Indicates that the request could not be completed due to a conflict with the current state of the resource. - Gone = 410, //!< Indicates that access to the target resource is no longer available at the origin server and that this condition is - //!< likely to be permanent. - LengthRequired = 411, //!< Indicates that the server refuses to accept the request without a defined Content-Length. - PreconditionFailed = - 412, //!< Indicates that one or more preconditions given in the request header fields evaluated to false when tested on the server. - PayloadTooLarge = 413, //!< Indicates that the server is refusing to process a request because the request payload is larger than the - //!< server is willing or able to process. - URITooLong = 414, //!< Indicates that the server is refusing to service the request because the request-target is longer than the - //!< server is willing to interpret. - UnsupportedMediaType = 415, //!< Indicates that the origin server is refusing to service the request because the payload is in a format - //!< not supported by the target resource for this method. - RangeNotSatisfiable = 416, //!< Indicates that none of the ranges in the request's Range header field overlap the current extent of the - //!< selected resource or that the set of ranges requested has been rejected due to invalid ranges or an - //!< excessive request of small or overlapping ranges. - ExpectationFailed = 417, //!< Indicates that the expectation given in the request's Expect header field could not be met by at least - //!< one of the inbound servers. - ImATeapot = 418, //!< Any attempt to brew coffee with a teapot should result in the error code 418 I'm a teapot. - UnprocessableEntity = 422, //!< Means the server understands the content type of the request entity (hence a 415(Unsupported Media - //!< Type) status code is inappropriate), and the syntax of the request entity is correct (thus a 400 (Bad - //!< Request) status code is inappropriate) but was unable to process the contained instructions. - Locked = 423, //!< Means the source or destination resource of a method is locked. - FailedDependency = 424, //!< Means that the method could not be performed on the resource because the requested action depended on - //!< another action and that action failed. - UpgradeRequired = 426, //!< Indicates that the server refuses to perform the request using the current protocol but might be willing to - //!< do so after the client upgrades to a different protocol. - PreconditionRequired = 428, //!< Indicates that the origin server requires the request to be conditional. - TooManyRequests = 429, //!< Indicates that the user has sent too many requests in a given amount of time (\"rate limiting\"). - RequestHeaderFieldsTooLarge = - 431, //!< Indicates that the server is unwilling to process the request because its header fields are too large. - UnavailableForLegalReasons = - 451, //!< This status code indicates that the server is denying access to the resource in response to a legal demand. - - // 5xx - Server Error - - InternalServerError = - 500, //!< Indicates that the server encountered an unexpected condition that prevented it from fulfilling the request. - NotImplemented = 501, //!< Indicates that the server does not support the functionality required to fulfill the request. - BadGateway = 502, //!< Indicates that the server, while acting as a gateway or proxy, received an invalid response from an inbound - //!< server it accessed while attempting to fulfill the request. - ServiceUnavailable = 503, //!< Indicates that the server is currently unable to handle the request due to a temporary overload or - //!< scheduled maintenance, which will likely be alleviated after some delay. - GatewayTimeout = 504, //!< Indicates that the server, while acting as a gateway or proxy, did not receive a timely response from an - //!< upstream server it needed to access in order to complete the request. - HTTPVersionNotSupported = 505, //!< Indicates that the server does not support, or refuses to support, the protocol version that was - //!< used in the request message. - VariantAlsoNegotiates = - 506, //!< Indicates that the server has an internal configuration error: the chosen variant resource is configured to engage in - //!< transparent content negotiation itself, and is therefore not a proper end point in the negotiation process. - InsufficientStorage = 507, //!< Means the method could not be performed on the resource because the server is unable to store the - //!< representation needed to successfully complete the request. - LoopDetected = 508, //!< Indicates that the server terminated an operation because it encountered an infinite loop while processing a - //!< request with "Depth: infinity". [RFC 5842] - NotExtended = 510, //!< The policy for accessing the resource has not been met in the request. [RFC 2774] - NetworkAuthenticationRequired = 511, //!< Indicates that the client needs to authenticate to gain network access. -}; - -} // namespace zen diff --git a/zenhttp/include/zenhttp/httpserver.h b/zenhttp/include/zenhttp/httpserver.h deleted file mode 100644 index 3b9fa50b4..000000000 --- a/zenhttp/include/zenhttp/httpserver.h +++ /dev/null @@ -1,315 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. - -#pragma once - -#include "zenhttp.h" - -#include <zencore/compactbinary.h> -#include <zencore/enumflags.h> -#include <zencore/iobuffer.h> -#include <zencore/iohash.h> -#include <zencore/refcount.h> -#include <zencore/string.h> -#include <zencore/uid.h> -#include <zenhttp/httpcommon.h> - -#include <functional> -#include <gsl/gsl-lite.hpp> -#include <list> -#include <map> -#include <regex> -#include <span> -#include <unordered_map> - -namespace zen { - -/** HTTP Server Request - */ -class HttpServerRequest -{ -public: - HttpServerRequest(); - ~HttpServerRequest(); - - // Synchronous operations - - [[nodiscard]] inline std::string_view RelativeUri() const { return m_Uri; } // Returns URI without service prefix - [[nodiscard]] std::string_view RelativeUriWithExtension() const { return m_UriWithExtension; } - [[nodiscard]] inline std::string_view QueryString() const { return m_QueryString; } - - struct QueryParams - { - std::vector<std::pair<std::string_view, std::string_view>> KvPairs; - - std::string_view GetValue(std::string_view ParamName) const - { - for (const auto& Kv : KvPairs) - { - const std::string_view& Key = Kv.first; - - if (Key.size() == ParamName.size()) - { - if (0 == StrCaseCompare(Key.data(), ParamName.data(), Key.size())) - { - return Kv.second; - } - } - } - - return std::string_view(); - } - }; - - virtual bool TryGetRanges(HttpRanges&) { return false; } - - QueryParams GetQueryParams(); - - inline HttpVerb RequestVerb() const { return m_Verb; } - inline HttpContentType RequestContentType() { return m_ContentType; } - inline HttpContentType AcceptContentType() { return m_AcceptType; } - - inline uint64_t ContentLength() const { return m_ContentLength; } - Oid SessionId() const; - uint32_t RequestId() const; - - inline bool IsHandled() const { return !!(m_Flags & kIsHandled); } - inline bool SuppressBody() const { return !!(m_Flags & kSuppressBody); } - inline void SetSuppressResponseBody() { m_Flags |= kSuppressBody; } - - /** Read POST/PUT payload for request body, which is always available without delay - */ - virtual IoBuffer ReadPayload() = 0; - - ZENCORE_API CbObject ReadPayloadObject(); - ZENCORE_API CbPackage ReadPayloadPackage(); - - /** Respond with payload - - No data will have been sent when any of these functions return. Instead, the response will be transmitted - asynchronously, after returning from a request handler function. - - Note that this is destructive in the sense that the IoBuffer instances referred to by Blobs will be - moved into our response handler array where they are kept alive, in order to reduce ref-counting storms - */ - virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span<IoBuffer> Blobs) = 0; - virtual void WriteResponse(HttpResponseCode ResponseCode) = 0; - virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::u8string_view ResponseString) = 0; - virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, CompositeBuffer& Payload); - - void WriteResponse(HttpResponseCode ResponseCode, CbObject Data); - void WriteResponse(HttpResponseCode ResponseCode, CbArray Array); - void WriteResponse(HttpResponseCode ResponseCode, CbPackage Package); - void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::string_view ResponseString); - void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, IoBuffer Blob); - - virtual void WriteResponseAsync(std::function<void(HttpServerRequest&)>&& ContinuationHandler) = 0; - -protected: - enum - { - kIsHandled = 1 << 0, - kSuppressBody = 1 << 1, - kHaveRequestId = 1 << 2, - kHaveSessionId = 1 << 3, - }; - - mutable uint32_t m_Flags = 0; - HttpVerb m_Verb = HttpVerb::kGet; - HttpContentType m_ContentType = HttpContentType::kBinary; - HttpContentType m_AcceptType = HttpContentType::kUnknownContentType; - uint64_t m_ContentLength = ~0ull; - std::string_view m_Uri; - std::string_view m_UriWithExtension; - std::string_view m_QueryString; - mutable uint32_t m_RequestId = ~uint32_t(0); - mutable Oid m_SessionId = Oid::Zero; - - inline void SetIsHandled() { m_Flags |= kIsHandled; } - - virtual Oid ParseSessionId() const = 0; - virtual uint32_t ParseRequestId() const = 0; -}; - -class IHttpPackageHandler : public RefCounted -{ -public: - virtual void FilterOffer(std::vector<IoHash>& OfferCids) = 0; - virtual void OnRequestBegin() = 0; - virtual IoBuffer CreateTarget(const IoHash& Cid, uint64_t StorageSize) = 0; - virtual void OnRequestComplete() = 0; -}; - -/** - * Base class for implementing an HTTP "service" - * - * A service exposes one or more endpoints with a certain URI prefix - * - */ - -class HttpService -{ -public: - HttpService() = default; - virtual ~HttpService() = default; - - virtual const char* BaseUri() const = 0; - virtual void HandleRequest(HttpServerRequest& HttpServiceRequest) = 0; - virtual Ref<IHttpPackageHandler> HandlePackageRequest(HttpServerRequest& HttpServiceRequest); - - // Internals - - inline void SetUriPrefixLength(size_t PrefixLength) { m_UriPrefixLength = (int)PrefixLength; } - inline int UriPrefixLength() const { return m_UriPrefixLength; } - -private: - int m_UriPrefixLength = 0; -}; - -/** HTTP server - * - * Implements the main event loop to service HTTP requests, and handles routing - * requests to the appropriate handler as registered via RegisterService - */ -class HttpServer : public RefCounted -{ -public: - virtual void RegisterService(HttpService& Service) = 0; - virtual int Initialize(int BasePort) = 0; - virtual void Run(bool IsInteractiveSession) = 0; - virtual void RequestExit() = 0; -}; - -Ref<HttpServer> CreateHttpServer(std::string_view ServerClass); - -////////////////////////////////////////////////////////////////////////// - -class HttpRouterRequest -{ -public: - HttpRouterRequest(HttpServerRequest& Request) : m_HttpRequest(Request) {} - - ZENCORE_API std::string GetCapture(uint32_t Index) const; - inline HttpServerRequest& ServerRequest() { return m_HttpRequest; } - -private: - using MatchResults_t = std::match_results<std::string_view::const_iterator>; - - HttpServerRequest& m_HttpRequest; - MatchResults_t m_Match; - - friend class HttpRequestRouter; -}; - -inline std::string -HttpRouterRequest::GetCapture(uint32_t Index) const -{ - ZEN_ASSERT(Index < m_Match.size()); - - return m_Match[Index]; -} - -/** HTTP request router helper - * - * This helper class allows a service implementer to register one or more - * endpoints using pattern matching (currently using regex matching) - * - * This is intended to be initialized once only, there is no thread - * safety so you can absolutely not add or remove endpoints once the handler - * goes live - */ - -class HttpRequestRouter -{ -public: - typedef std::function<void(HttpRouterRequest&)> HandlerFunc_t; - - /** - * @brief Add pattern which can be referenced by name, commonly used for URL components - * @param Id String used to identify patterns for replacement - * @param Regex String which will replace the Id string in any registered URL paths - */ - void AddPattern(const char* Id, const char* Regex); - - /** - * @brief Register a an endpoint handler for the given route - * @param Regex Regular expression used to match the handler to a request. This may - * contain pattern aliases registered via AddPattern - * @param HandlerFunc Handler function to call for any matching request - * @param SupportedVerbs Supported HTTP verbs for this handler - */ - void RegisterRoute(const char* Regex, HandlerFunc_t&& HandlerFunc, HttpVerb SupportedVerbs); - - void ProcessRegexSubstitutions(const char* Regex, StringBuilderBase& ExpandedRegex); - - /** - * @brief HTTP request handling function - this should be called to route the - * request to a registered handler - * @param Request Request to route to a handler - * @return Function returns true if the request was routed successfully - */ - bool HandleRequest(zen::HttpServerRequest& Request); - -private: - struct HandlerEntry - { - HandlerEntry(const char* Regex, HttpVerb SupportedVerbs, HandlerFunc_t&& Handler, const char* Pattern) - : RegEx(Regex, std::regex::icase | std::regex::ECMAScript) - , Verbs(SupportedVerbs) - , Handler(std::move(Handler)) - , Pattern(Pattern) - { - } - - ~HandlerEntry() = default; - - std::regex RegEx; - HttpVerb Verbs; - HandlerFunc_t Handler; - const char* Pattern; - - private: - HandlerEntry& operator=(const HandlerEntry&) = delete; - HandlerEntry(const HandlerEntry&) = delete; - }; - - std::list<HandlerEntry> m_Handlers; - std::unordered_map<std::string, std::string> m_PatternMap; -}; - -/** HTTP RPC request helper - */ - -class RpcResult -{ - RpcResult(CbObject Result) : m_Result(std::move(Result)) {} - -private: - CbObject m_Result; -}; - -class HttpRpcHandler -{ -public: - HttpRpcHandler(); - ~HttpRpcHandler(); - - HttpRpcHandler(const HttpRpcHandler&) = delete; - HttpRpcHandler operator=(const HttpRpcHandler&) = delete; - - void AddRpc(std::string_view RpcId, std::function<void(CbObject& RpcArgs)> HandlerFunction); - -private: - struct RpcFunction - { - std::function<void(CbObject& RpcArgs)> Function; - std::string Identifier; - }; - - std::map<std::string, RpcFunction> m_Functions; -}; - -bool HandlePackageOffers(HttpService& Service, HttpServerRequest& Request, Ref<IHttpPackageHandler>& PackageHandlerRef); - -void http_forcelink(); // internal - -} // namespace zen diff --git a/zenhttp/include/zenhttp/httpshared.h b/zenhttp/include/zenhttp/httpshared.h deleted file mode 100644 index d335572c5..000000000 --- a/zenhttp/include/zenhttp/httpshared.h +++ /dev/null @@ -1,163 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. - -#pragma once - -#include <zencore/compactbinarypackage.h> -#include <zencore/iobuffer.h> -#include <zencore/iohash.h> - -#include <functional> -#include <gsl/gsl-lite.hpp> - -namespace zen { - -class IoBuffer; -class CbPackage; -class CompositeBuffer; - -/** _____ _ _____ _ - / ____| | | __ \ | | - | | | |__ | |__) |_ _ ___| | ____ _ __ _ ___ - | | | '_ \| ___/ _` |/ __| |/ / _` |/ _` |/ _ \ - | |____| |_) | | | (_| | (__| < (_| | (_| | __/ - \_____|_.__/|_| \__,_|\___|_|\_\__,_|\__, |\___| - __/ | - |___/ - - Structures and code related to handling CbPackage transactions - - CbPackage instances are marshaled across the wire using a distinct message - format. We don't use the CbPackage serialization format provided by the - CbPackage implementation itself since that does not provide much flexibility - in how the attachment payloads are transmitted. The scheme below separates - metadata cleanly from payloads and this enables us to more efficiently - transmit them either via sendfile/TransmitFile like mechanisms, or by - reference/memory mapping in the local case. - */ - -struct CbPackageHeader -{ - uint32_t HeaderMagic; - uint32_t AttachmentCount; // TODO: should add ability to opt out of implicit root document? - uint32_t Reserved1; - uint32_t Reserved2; -}; - -static_assert(sizeof(CbPackageHeader) == 16); - -enum : uint32_t -{ - kCbPkgMagic = 0xaa77aacc -}; - -struct CbAttachmentEntry -{ - uint64_t PayloadSize; // Size of the associated payload data in the message - uint32_t Flags; // See flags below - IoHash AttachmentHash; // Content Id for the attachment - - enum - { - kIsCompressed = (1u << 0), // Is marshaled using compressed buffer storage format - kIsObject = (1u << 1), // Is compact binary object - kIsError = (1u << 2), // Is error (compact binary formatted) object - kIsLocalRef = (1u << 3), // Is "local reference" - }; -}; - -struct CbAttachmentReferenceHeader -{ - uint64_t PayloadByteOffset = 0; - uint64_t PayloadByteSize = ~0u; - uint16_t AbsolutePathLength = 0; - - // This header will be followed by UTF8 encoded absolute path to backing file -}; - -static_assert(sizeof(CbAttachmentEntry) == 32); - -enum class FormatFlags -{ - kDefault = 0, - kAllowLocalReferences = (1u << 0), - kDenyPartialLocalReferences = (1u << 1) -}; - -gsl_DEFINE_ENUM_BITMASK_OPERATORS(FormatFlags); - -enum class RpcAcceptOptions : uint16_t -{ - kNone = 0, - kAllowLocalReferences = (1u << 0), - kAllowPartialLocalReferences = (1u << 1) -}; - -gsl_DEFINE_ENUM_BITMASK_OPERATORS(RpcAcceptOptions); - -std::vector<IoBuffer> FormatPackageMessage(const CbPackage& Data, FormatFlags Flags, int TargetProcessPid = 0); -CompositeBuffer FormatPackageMessageBuffer(const CbPackage& Data, FormatFlags Flags, int TargetProcessPid = 0); -CbPackage ParsePackageMessage( - IoBuffer Payload, - std::function<IoBuffer(const IoHash& Cid, uint64_t Size)> CreateBuffer = [](const IoHash&, uint64_t Size) -> IoBuffer { - return IoBuffer{Size}; - }); -bool IsPackageMessage(IoBuffer Payload); - -bool ParsePackageMessageWithLegacyFallback(const IoBuffer& Response, CbPackage& OutPackage); - -std::vector<IoBuffer> FormatPackageMessage(const CbPackage& Data, int TargetProcessPid = 0); -CompositeBuffer FormatPackageMessageBuffer(const CbPackage& Data, int TargetProcessPid = 0); - -/** Streaming reader for compact binary packages - - The goal is to ultimately support zero-copy I/O, but for now there'll be some - copying involved on some platforms at least. - - This approach to deserializing CbPackage data is more efficient than - `ParsePackageMessage` since it does not require the entire message to - be resident in a memory buffer - - */ -class CbPackageReader -{ -public: - CbPackageReader(); - ~CbPackageReader(); - - void SetPayloadBufferCreator(std::function<IoBuffer(const IoHash& Cid, uint64_t Size)> CreateBuffer); - - /** Process compact binary package data stream - - The data stream must be in the serialization format produced by FormatPackageMessage - - \return How many bytes must be fed to this function in the next call - */ - uint64_t ProcessPackageHeaderData(const void* Data, uint64_t DataBytes); - - void Finalize(); - const std::vector<CbAttachment>& GetAttachments() { return m_Attachments; } - CbObject GetRootObject() { return m_RootObject; } - std::span<IoBuffer> GetPayloadBuffers() { return m_PayloadBuffers; } - -private: - enum class State - { - kInitialState, - kReadingHeader, - kReadingAttachmentEntries, - kReadingBuffers - } m_CurrentState = State::kInitialState; - - std::function<IoBuffer(const IoHash& Cid, uint64_t Size)> m_CreateBuffer; - std::vector<IoBuffer> m_PayloadBuffers; - std::vector<CbAttachmentEntry> m_AttachmentEntries; - std::vector<CbAttachment> m_Attachments; - CbObject m_RootObject; - CbPackageHeader m_PackageHeader; - - IoBuffer MarshalLocalChunkReference(IoBuffer AttachmentBuffer); -}; - -void forcelink_httpshared(); - -} // namespace zen diff --git a/zenhttp/include/zenhttp/websocket.h b/zenhttp/include/zenhttp/websocket.h deleted file mode 100644 index adca7e988..000000000 --- a/zenhttp/include/zenhttp/websocket.h +++ /dev/null @@ -1,256 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. - -#include <zencore/compactbinarypackage.h> -#include <zencore/memory.h> - -#include <compare> -#include <functional> -#include <future> -#include <memory> -#include <optional> - -#pragma once - -namespace asio { -class io_context; -} - -namespace zen { - -class BinaryWriter; - -/** - * A unique socket ID. - */ -class WebSocketId -{ - static std::atomic_uint32_t NextId; - -public: - WebSocketId() = default; - - uint32_t Value() const { return m_Value; } - - auto operator<=>(const WebSocketId&) const = default; - - static WebSocketId New() { return WebSocketId(NextId.fetch_add(1)); } - -private: - WebSocketId(uint32_t Value) : m_Value(Value) {} - - uint32_t m_Value{}; -}; - -/** - * Type of web socket message. - */ -enum class WebSocketMessageType : uint8_t -{ - kInvalid, - kNotification, - kRequest, - kStreamRequest, - kResponse, - kStreamResponse, - kStreamCompleteResponse, - kCount -}; - -inline std::string_view -ToString(WebSocketMessageType Type) -{ - switch (Type) - { - case WebSocketMessageType::kInvalid: - return std::string_view("Invalid"); - case WebSocketMessageType::kNotification: - return std::string_view("Notification"); - case WebSocketMessageType::kRequest: - return std::string_view("Request"); - case WebSocketMessageType::kStreamRequest: - return std::string_view("StreamRequest"); - case WebSocketMessageType::kResponse: - return std::string_view("Response"); - case WebSocketMessageType::kStreamResponse: - return std::string_view("StreamResponse"); - case WebSocketMessageType::kStreamCompleteResponse: - return std::string_view("StreamCompleteResponse"); - default: - return std::string_view("Unknown"); - }; -} - -/** - * Web socket message. - */ -class WebSocketMessage -{ - struct Header - { - static constexpr uint32_t ExpectedMagic = 0x7a776d68; // zwmh - - uint64_t MessageSize{}; - uint32_t Magic{ExpectedMagic}; - uint32_t CorrelationId{}; - uint32_t StatusCode{200u}; - WebSocketMessageType MessageType{}; - uint8_t Reserved[3] = {0}; - - bool IsValid() const; - }; - - static_assert(sizeof(Header) == 24); - - static std::atomic_uint32_t NextCorrelationId; - -public: - static constexpr size_t HeaderSize = sizeof(Header); - - WebSocketMessage() = default; - - WebSocketId SocketId() const { return m_SocketId; } - void SetSocketId(WebSocketId Id) { m_SocketId = Id; } - uint64_t MessageSize() const { return m_Header.MessageSize; } - void SetMessageType(WebSocketMessageType MessageType); - void SetCorrelationId(uint32_t Id) { m_Header.CorrelationId = Id; } - uint32_t CorrelationId() const { return m_Header.CorrelationId; } - uint32_t StatusCode() const { return m_Header.StatusCode; } - void SetStatusCode(uint32_t StatusCode) { m_Header.StatusCode = StatusCode; } - WebSocketMessageType MessageType() const { return m_Header.MessageType; } - - const CbPackage& Body() const { return m_Body.value(); } - void SetBody(CbPackage&& Body); - void SetBody(CbObject&& Body); - bool HasBody() const { return m_Body.has_value(); } - - void Save(BinaryWriter& Writer); - bool TryLoadHeader(MemoryView Memory); - - bool IsValid() const { return m_Header.MessageType != WebSocketMessageType::kInvalid; } - -private: - Header m_Header{}; - WebSocketId m_SocketId{}; - std::optional<CbPackage> m_Body; -}; - -class WebSocketServer; - -/** - * Base class for handling web socket requests and notifications from connected client(s). - */ -class WebSocketService -{ -public: - virtual ~WebSocketService() = default; - - void Configure(WebSocketServer& Server); - - virtual bool HandleRequest(const WebSocketMessage&) { ZEN_ASSERT(false); } - virtual void HandleNotification(const WebSocketMessage&) { ZEN_ASSERT(false); } - -protected: - WebSocketService() = default; - - virtual void RegisterHandlers(WebSocketServer& Server) = 0; - void SendStreamResponse(WebSocketId SocketId, uint32_t CorrelationId, CbPackage&& StreamResponse, bool IsStreamComplete); - void SendStreamResponse(WebSocketId SocketId, uint32_t CorrelationId, CbObject&& StreamResponse, bool IsStreamComplete); - - WebSocketServer& SocketServer() - { - ZEN_ASSERT(m_SocketServer); - return *m_SocketServer; - } - -private: - WebSocketServer* m_SocketServer{}; -}; - -/** - * Server options. - */ -struct WebSocketServerOptions -{ - uint16_t Port = 2337; - uint32_t ThreadCount = 1; -}; - -/** - * The web socket server manages client connections and routing of requests and notifications. - */ -class WebSocketServer -{ -public: - virtual ~WebSocketServer() = default; - - virtual bool Run() = 0; - virtual void Shutdown() = 0; - - virtual void RegisterService(WebSocketService& Service) = 0; - virtual void RegisterNotificationHandler(std::string_view Key, WebSocketService& Service) = 0; - virtual void RegisterRequestHandler(std::string_view Key, WebSocketService& Service) = 0; - - virtual void SendNotification(WebSocketMessage&& Notification) = 0; - virtual void SendResponse(WebSocketMessage&& Response) = 0; - - static std::unique_ptr<WebSocketServer> Create(const WebSocketServerOptions& Options); -}; - -/** - * The state of the web socket. - */ -enum class WebSocketState : uint32_t -{ - kNone, - kHandshaking, - kConnected, - kDisconnected, - kError -}; - -/** - * Type of web socket client event. - */ -enum class WebSocketEvent : uint32_t -{ - kConnected, - kDisconnected, - kError -}; - -/** - * Web socket client connection info. - */ -struct WebSocketConnectInfo -{ - std::string Host; - int16_t Port{8848}; - std::string Endpoint; - std::vector<std::string> Protocols; - uint16_t Version{13}; -}; - -/** - * A connection to a web socket server for sending requests and listening for notifications. - */ -class WebSocketClient -{ -public: - using EventCallback = std::function<void()>; - using NotificationCallback = std::function<void(WebSocketMessage&&)>; - - virtual ~WebSocketClient() = default; - - virtual std::future<bool> Connect(const WebSocketConnectInfo& Info) = 0; - virtual void Disconnect() = 0; - virtual bool IsConnected() const = 0; - virtual WebSocketState State() const = 0; - - virtual std::future<WebSocketMessage> SendRequest(WebSocketMessage&& Request) = 0; - virtual void OnNotification(NotificationCallback&& Cb) = 0; - virtual void OnEvent(WebSocketEvent Evt, EventCallback&& Cb) = 0; - - static std::shared_ptr<WebSocketClient> Create(asio::io_context& IoCtx); -}; - -} // namespace zen diff --git a/zenhttp/include/zenhttp/zenhttp.h b/zenhttp/include/zenhttp/zenhttp.h deleted file mode 100644 index 59c64b31f..000000000 --- a/zenhttp/include/zenhttp/zenhttp.h +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. - -#pragma once - -#include <zencore/zencore.h> - -#ifndef ZEN_WITH_HTTPSYS -# if ZEN_PLATFORM_WINDOWS -# define ZEN_WITH_HTTPSYS 1 -# else -# define ZEN_WITH_HTTPSYS 0 -# endif -#endif - -#define ZENHTTP_API // Placeholder to allow DLL configs in the future - -namespace zen { - -ZENHTTP_API void zenhttp_forcelinktests(); - -} diff --git a/zenhttp/iothreadpool.cpp b/zenhttp/iothreadpool.cpp deleted file mode 100644 index 6087e69ec..000000000 --- a/zenhttp/iothreadpool.cpp +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. - -#include "iothreadpool.h" - -#include <zencore/except.h> - -#if ZEN_PLATFORM_WINDOWS - -namespace zen { - -WinIoThreadPool::WinIoThreadPool(int InThreadCount) -{ - // Thread pool setup - - m_ThreadPool = CreateThreadpool(NULL); - - SetThreadpoolThreadMinimum(m_ThreadPool, InThreadCount); - SetThreadpoolThreadMaximum(m_ThreadPool, InThreadCount * 2); - - InitializeThreadpoolEnvironment(&m_CallbackEnvironment); - - m_CleanupGroup = CreateThreadpoolCleanupGroup(); - - SetThreadpoolCallbackPool(&m_CallbackEnvironment, m_ThreadPool); - - SetThreadpoolCallbackCleanupGroup(&m_CallbackEnvironment, m_CleanupGroup, NULL); -} - -WinIoThreadPool::~WinIoThreadPool() -{ - CloseThreadpool(m_ThreadPool); -} - -void -WinIoThreadPool::CreateIocp(HANDLE IoHandle, PTP_WIN32_IO_CALLBACK Callback, void* Context, std::error_code& ErrorCode) -{ - ZEN_ASSERT(!m_ThreadPoolIo); - - m_ThreadPoolIo = CreateThreadpoolIo(IoHandle, Callback, Context, &m_CallbackEnvironment); - - if (!m_ThreadPoolIo) - { - ErrorCode = MakeErrorCodeFromLastError(); - } -} - -} // namespace zen - -#endif diff --git a/zenhttp/iothreadpool.h b/zenhttp/iothreadpool.h deleted file mode 100644 index 8333964c3..000000000 --- a/zenhttp/iothreadpool.h +++ /dev/null @@ -1,37 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. - -#pragma once - -#include <zencore/zencore.h> - -#if ZEN_PLATFORM_WINDOWS -# include <zencore/windows.h> - -# include <system_error> - -namespace zen { - -////////////////////////////////////////////////////////////////////////// -// -// Thread pool. Implemented in terms of Windows thread pool right now, will -// need a cross-platform implementation eventually -// - -class WinIoThreadPool -{ -public: - WinIoThreadPool(int InThreadCount); - ~WinIoThreadPool(); - - void CreateIocp(HANDLE IoHandle, PTP_WIN32_IO_CALLBACK Callback, void* Context, std::error_code& ErrorCode); - inline PTP_IO Iocp() const { return m_ThreadPoolIo; } - -private: - PTP_POOL m_ThreadPool = nullptr; - PTP_CLEANUP_GROUP m_CleanupGroup = nullptr; - PTP_IO m_ThreadPoolIo = nullptr; - TP_CALLBACK_ENVIRON m_CallbackEnvironment; -}; - -} // namespace zen -#endif diff --git a/zenhttp/websocketasio.cpp b/zenhttp/websocketasio.cpp deleted file mode 100644 index bbe7e1ad8..000000000 --- a/zenhttp/websocketasio.cpp +++ /dev/null @@ -1,1613 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. - -#include <zenhttp/websocket.h> - -#include <zencore/base64.h> -#include <zencore/compactbinarybuilder.h> -#include <zencore/compactbinaryvalidation.h> -#include <zencore/intmath.h> -#include <zencore/iobuffer.h> -#include <zencore/logging.h> -#include <zencore/memory.h> -#include <zencore/sha1.h> -#include <zencore/stream.h> -#include <zencore/string.h> -#include <zencore/trace.h> - -#include <chrono> -#include <optional> -#include <shared_mutex> -#include <span> -#include <system_error> -#include <thread> - -ZEN_THIRD_PARTY_INCLUDES_START -#include <fmt/format.h> -#include <http_parser.h> -#include <asio.hpp> -ZEN_THIRD_PARTY_INCLUDES_END - -#if ZEN_PLATFORM_WINDOWS -# include <mstcpip.h> -#endif - -namespace zen::websocket { - -using namespace std::literals; - -ZEN_DEFINE_LOG_CATEGORY_STATIC(LogWebSocket, "websocket"sv); - -ZEN_DEFINE_LOG_CATEGORY_STATIC(LogWsClient, "ws-client"sv); - -using Clock = std::chrono::steady_clock; -using TimePoint = Clock::time_point; - -/////////////////////////////////////////////////////////////////////////////// -namespace http_header { - static constexpr std::string_view SecWebSocketKey = "Sec-WebSocket-Key"sv; - static constexpr std::string_view SecWebSocketOrigin = "Sec-WebSocket-Origin"sv; - static constexpr std::string_view SecWebSocketProtocol = "Sec-WebSocket-Protocol"sv; - static constexpr std::string_view SecWebSocketVersion = "Sec-WebSocket-Version"sv; - static constexpr std::string_view SecWebSocketAccept = "Sec-WebSocket-Accept"sv; - static constexpr std::string_view Upgrade = "Upgrade"sv; -} // namespace http_header - -/////////////////////////////////////////////////////////////////////////////// -enum class ParseMessageStatus : uint32_t -{ - kError, - kContinue, - kDone, -}; - -struct ParseMessageResult -{ - ParseMessageStatus Status{}; - size_t ByteCount{}; - std::optional<std::string> Reason; -}; - -class MessageParser -{ -public: - virtual ~MessageParser() = default; - - ParseMessageResult ParseMessage(MemoryView Msg); - void Reset(); - -protected: - MessageParser() = default; - - virtual ParseMessageResult OnParseMessage(MemoryView Msg) = 0; - virtual void OnReset() = 0; - - BinaryWriter m_Stream; -}; - -ParseMessageResult -MessageParser::ParseMessage(MemoryView Msg) -{ - return OnParseMessage(Msg); -} - -void -MessageParser::Reset() -{ - OnReset(); - - m_Stream.Reset(); -} - -/////////////////////////////////////////////////////////////////////////////// -enum class HttpMessageParserType -{ - kRequest, - kResponse, - kBoth -}; - -class HttpMessageParser final : public MessageParser -{ -public: - using HttpHeaders = std::unordered_map<std::string_view, std::string_view>; - - HttpMessageParser(HttpMessageParserType Type) : MessageParser(), m_Type(Type) { Initialize(); } - - virtual ~HttpMessageParser() = default; - - int32_t StatusCode() const { return m_Parser.status_code; } - bool IsUpgrade() const { return m_Parser.upgrade != 0; } - HttpHeaders& Headers() { return m_Headers; } - MemoryView Body() const { return MemoryView(m_Stream.Data() + m_BodyEntry.Offset, m_BodyEntry.Size); } - - std::string_view StatusText() const - { - return std::string_view(reinterpret_cast<const char*>(m_Stream.Data() + m_StatusEntry.Offset), m_StatusEntry.Size); - } - - bool ValidateWebSocketHandshake(std::string& OutAcceptHash, std::string& OutReason); - -private: - void Initialize(); - virtual ParseMessageResult OnParseMessage(MemoryView Msg) override; - virtual void OnReset() override; - int OnMessageBegin(); - int OnUrl(MemoryView Url); - int OnStatus(MemoryView Status); - int OnHeaderField(MemoryView HeaderField); - int OnHeaderValue(MemoryView HeaderValue); - int OnHeadersComplete(); - int OnBody(MemoryView Body); - int OnMessageComplete(); - - struct StreamEntry - { - uint64_t Offset{}; - uint64_t Size{}; - }; - - struct HeaderStreamEntry - { - StreamEntry Field{}; - StreamEntry Value{}; - }; - - HttpMessageParserType m_Type; - http_parser m_Parser; - StreamEntry m_UrlEntry; - StreamEntry m_StatusEntry; - StreamEntry m_BodyEntry; - HeaderStreamEntry m_CurrentHeader; - std::vector<HeaderStreamEntry> m_HeaderEntries; - HttpHeaders m_Headers; - bool m_IsMsgComplete{false}; - - static http_parser_settings ParserSettings; -}; - -http_parser_settings HttpMessageParser::ParserSettings = { - .on_message_begin = [](http_parser* P) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnMessageBegin(); }, - - .on_url = [](http_parser* P, - const char* Data, - size_t Size) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnUrl(MemoryView(Data, Size)); }, - - .on_status = [](http_parser* P, - const char* Data, - size_t Size) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnStatus(MemoryView(Data, Size)); }, - - .on_header_field = [](http_parser* P, - const char* Data, - size_t Size) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnHeaderField(MemoryView(Data, Size)); }, - - .on_header_value = [](http_parser* P, - const char* Data, - size_t Size) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnHeaderValue(MemoryView(Data, Size)); }, - - .on_headers_complete = [](http_parser* P) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnHeadersComplete(); }, - - .on_body = [](http_parser* P, - const char* Data, - size_t Size) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnBody(MemoryView(Data, Size)); }, - - .on_message_complete = [](http_parser* P) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnMessageComplete(); }}; - -void -HttpMessageParser::Initialize() -{ - http_parser_init(&m_Parser, - m_Type == HttpMessageParserType::kRequest ? HTTP_REQUEST - : m_Type == HttpMessageParserType::kResponse ? HTTP_RESPONSE - : HTTP_BOTH); - m_Parser.data = this; - - m_UrlEntry = {}; - m_StatusEntry = {}; - m_CurrentHeader = {}; - m_BodyEntry = {}; - - m_IsMsgComplete = false; - - m_HeaderEntries.clear(); -} - -ParseMessageResult -HttpMessageParser::OnParseMessage(MemoryView Msg) -{ - const size_t ByteCount = http_parser_execute(&m_Parser, &ParserSettings, reinterpret_cast<const char*>(Msg.GetData()), Msg.GetSize()); - - auto Status = m_IsMsgComplete ? ParseMessageStatus::kDone : ParseMessageStatus::kContinue; - - if (m_Parser.http_errno != 0) - { - Status = ParseMessageStatus::kError; - } - - return {.Status = Status, .ByteCount = uint64_t(ByteCount)}; -} - -void -HttpMessageParser::OnReset() -{ - Initialize(); -} - -int -HttpMessageParser::OnMessageBegin() -{ - ZEN_ASSERT(m_IsMsgComplete == false); - ZEN_ASSERT(m_HeaderEntries.empty()); - ZEN_ASSERT(m_Headers.empty()); - - return 0; -} - -int -HttpMessageParser::OnStatus(MemoryView Status) -{ - m_StatusEntry = {m_Stream.CurrentOffset(), Status.GetSize()}; - - m_Stream.Write(Status); - - return 0; -} - -int -HttpMessageParser::OnUrl(MemoryView Url) -{ - m_UrlEntry = {m_Stream.CurrentOffset(), Url.GetSize()}; - - m_Stream.Write(Url); - - return 0; -} - -int -HttpMessageParser::OnHeaderField(MemoryView HeaderField) -{ - if (m_CurrentHeader.Value.Size > 0) - { - m_HeaderEntries.push_back(m_CurrentHeader); - m_CurrentHeader = {}; - } - - if (m_CurrentHeader.Field.Size == 0) - { - m_CurrentHeader.Field.Offset = m_Stream.CurrentOffset(); - } - - m_CurrentHeader.Field.Size += HeaderField.GetSize(); - - m_Stream.Write(HeaderField); - - return 0; -} - -int -HttpMessageParser::OnHeaderValue(MemoryView HeaderValue) -{ - if (m_CurrentHeader.Value.Size == 0) - { - m_CurrentHeader.Value.Offset = m_Stream.CurrentOffset(); - } - - m_CurrentHeader.Value.Size += HeaderValue.GetSize(); - - m_Stream.Write(HeaderValue); - - return 0; -} - -int -HttpMessageParser::OnHeadersComplete() -{ - if (m_CurrentHeader.Value.Size > 0) - { - m_HeaderEntries.push_back(m_CurrentHeader); - m_CurrentHeader = {}; - } - - m_Headers.clear(); - m_Headers.reserve(m_HeaderEntries.size()); - - const char* StreamData = reinterpret_cast<const char*>(m_Stream.Data()); - - for (const auto& Entry : m_HeaderEntries) - { - auto Field = std::string_view(StreamData + Entry.Field.Offset, Entry.Field.Size); - auto Value = std::string_view(StreamData + Entry.Value.Offset, Entry.Value.Size); - - m_Headers.try_emplace(std::move(Field), std::move(Value)); - } - - return 0; -} - -int -HttpMessageParser::OnBody(MemoryView Body) -{ - m_BodyEntry = {m_Stream.CurrentOffset(), Body.GetSize()}; - - m_Stream.Write(Body); - - return 0; -} - -int -HttpMessageParser::OnMessageComplete() -{ - m_IsMsgComplete = true; - - return 0; -} - -bool -HttpMessageParser::ValidateWebSocketHandshake(std::string& OutAcceptHash, std::string& OutReason) -{ - static constexpr std::string_view WebSocketGuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"sv; - - OutAcceptHash = std::string(); - - if (m_Headers.contains(http_header::SecWebSocketKey) == false) - { - OutReason = "Missing header Sec-WebSocket-Key"; - return false; - } - - if (m_Headers.contains(http_header::Upgrade) == false) - { - OutReason = "Missing header Upgrade"; - return false; - } - - ExtendableStringBuilder<128> Sb; - Sb << m_Headers[http_header::SecWebSocketKey] << WebSocketGuid; - - SHA1Stream HashStream; - HashStream.Append(Sb.Data(), Sb.Size()); - - SHA1 Hash = HashStream.GetHash(); - - OutAcceptHash.resize(Base64::GetEncodedDataSize(sizeof(SHA1::Hash))); - Base64::Encode(Hash.Hash, sizeof(SHA1::Hash), OutAcceptHash.data()); - - return true; -} - -/////////////////////////////////////////////////////////////////////////////// -class WebSocketMessageParser final : public MessageParser -{ -public: - WebSocketMessageParser() : MessageParser() {} - - WebSocketMessage ConsumeMessage(); - -private: - virtual ParseMessageResult OnParseMessage(MemoryView Msg) override; - virtual void OnReset() override; - - WebSocketMessage m_Message; -}; - -ParseMessageResult -WebSocketMessageParser::OnParseMessage(MemoryView Msg) -{ - ZEN_TRACE_CPU("WebSocketMessageParser::OnParseMessage"); - - const uint64_t PrevOffset = m_Stream.CurrentOffset(); - - if (m_Stream.CurrentOffset() < WebSocketMessage::HeaderSize) - { - const uint64_t RemaingHeaderSize = WebSocketMessage::HeaderSize - m_Stream.CurrentOffset(); - - m_Stream.Write(Msg.Left(RemaingHeaderSize)); - Msg += RemaingHeaderSize; - - if (m_Stream.CurrentOffset() < WebSocketMessage::HeaderSize) - { - return {.Status = ParseMessageStatus::kContinue, .ByteCount = m_Stream.CurrentOffset() - PrevOffset}; - } - - const bool IsValidHeader = m_Message.TryLoadHeader(m_Stream.GetView()); - - if (IsValidHeader == false) - { - OnReset(); - - return {.Status = ParseMessageStatus::kError, - .ByteCount = m_Stream.CurrentOffset() - PrevOffset, - .Reason = std::string("Invalid websocket message header")}; - } - - if (m_Message.MessageSize() == 0) - { - return {.Status = ParseMessageStatus::kDone, .ByteCount = m_Stream.CurrentOffset() - PrevOffset}; - } - } - - ZEN_ASSERT(m_Stream.CurrentOffset() >= WebSocketMessage::HeaderSize); - - if (Msg.IsEmpty() == false) - { - const uint64_t RemaingMessageSize = (WebSocketMessage::HeaderSize + m_Message.MessageSize()) - m_Stream.CurrentOffset(); - m_Stream.Write(Msg.Left(RemaingMessageSize)); - } - - auto Status = ParseMessageStatus::kContinue; - - if (m_Stream.CurrentOffset() == WebSocketMessage::HeaderSize + m_Message.MessageSize()) - { - Status = ParseMessageStatus::kDone; - - BinaryReader Reader(m_Stream.GetView().RightChop(WebSocketMessage::HeaderSize)); - - CbPackage Pkg; - if (Pkg.TryLoad(Reader) == false) - { - return {.Status = ParseMessageStatus::kError, - .ByteCount = m_Stream.CurrentOffset() - PrevOffset, - .Reason = std::string("Invalid websocket message")}; - } - - m_Message.SetBody(std::move(Pkg)); - } - - return {.Status = Status, .ByteCount = m_Stream.CurrentOffset() - PrevOffset}; -} - -void -WebSocketMessageParser::OnReset() -{ - m_Message = WebSocketMessage(); -} - -WebSocketMessage -WebSocketMessageParser::ConsumeMessage() -{ - WebSocketMessage Msg = std::move(m_Message); - m_Message = WebSocketMessage(); - - return Msg; -} - -/////////////////////////////////////////////////////////////////////////////// -class WsConnection : public std::enable_shared_from_this<WsConnection> -{ -public: - WsConnection(WebSocketId Id, std::unique_ptr<asio::ip::tcp::socket> Socket) - : m_Id(Id) - , m_Socket(std::move(Socket)) - , m_StartTime(Clock::now()) - , m_State() - { - } - - ~WsConnection() = default; - - std::shared_ptr<WsConnection> AsShared() { return shared_from_this(); } - - WebSocketId Id() const { return m_Id; } - asio::ip::tcp::socket& Socket() { return *m_Socket; } - TimePoint StartTime() const { return m_StartTime; } - WebSocketState State() const { return static_cast<WebSocketState>(m_State.load(std::memory_order_relaxed)); } - std::string RemoteAddr() const { return m_Socket->remote_endpoint().address().to_string(); } - asio::streambuf& ReadBuffer() { return m_ReadBuffer; } - WebSocketState SetState(WebSocketState NewState) { return static_cast<WebSocketState>(m_State.exchange(uint32_t(NewState))); } - WebSocketState Close(); - MessageParser* Parser() { return m_MsgParser.get(); } - void SetParser(std::unique_ptr<MessageParser>&& Parser) { m_MsgParser = std::move(Parser); } - std::mutex& WriteMutex() { return m_WriteMutex; } - -private: - WebSocketId m_Id; - std::unique_ptr<asio::ip::tcp::socket> m_Socket; - TimePoint m_StartTime; - std::atomic_uint32_t m_State; - std::unique_ptr<MessageParser> m_MsgParser; - asio::streambuf m_ReadBuffer; - std::mutex m_WriteMutex; -}; - -WebSocketState -WsConnection::Close() -{ - const auto PrevState = SetState(WebSocketState::kDisconnected); - - if (PrevState != WebSocketState::kDisconnected && m_Socket->is_open()) - { - m_Socket->close(); - } - - return PrevState; -} - -/////////////////////////////////////////////////////////////////////////////// -class WsThreadPool -{ -public: - WsThreadPool(asio::io_service& IoSvc) : m_IoSvc(IoSvc) {} - void Start(uint32_t ThreadCount); - void Stop(); - -private: - asio::io_service& m_IoSvc; - std::vector<std::thread> m_Threads; - std::atomic_bool m_Running{false}; -}; - -void -WsThreadPool::Start(uint32_t ThreadCount) -{ - ZEN_ASSERT(m_Threads.empty()); - - ZEN_LOG_DEBUG(LogWebSocket, "starting '{}' websocket I/O thread(s)", ThreadCount); - - m_Running = true; - - for (uint32_t Idx = 0; Idx < ThreadCount; Idx++) - { - m_Threads.emplace_back([this, ThreadId = Idx + 1] { - for (;;) - { - if (m_Running == false) - { - break; - } - - try - { - m_IoSvc.run(); - } - catch (std::exception& Err) - { - ZEN_LOG_ERROR(LogWebSocket, "process websocket I/O FAILED, reason '{}'", Err.what()); - } - } - - ZEN_LOG_TRACE(LogWebSocket, "websocket I/O thread '{}' exiting", ThreadId); - }); - } -} - -void -WsThreadPool::Stop() -{ - if (m_Running) - { - m_Running = false; - - for (std::thread& Thread : m_Threads) - { - if (Thread.joinable()) - { - Thread.join(); - } - } - - m_Threads.clear(); - } -} - -/////////////////////////////////////////////////////////////////////////////// -class WsServer final : public WebSocketServer -{ -public: - WsServer(const WebSocketServerOptions& Options) : m_Options(Options) {} - virtual ~WsServer() { Shutdown(); } - - virtual bool Run() override; - virtual void Shutdown() override; - - virtual void RegisterService(WebSocketService& Service) override; - virtual void RegisterNotificationHandler(std::string_view Key, WebSocketService& Service) override; - virtual void RegisterRequestHandler(std::string_view Key, WebSocketService& Service) override; - - virtual void SendNotification(WebSocketMessage&& Notification) override; - virtual void SendResponse(WebSocketMessage&& Response) override; - -private: - friend class WsConnection; - - void AcceptConnection(); - void CloseConnection(std::shared_ptr<WsConnection> Connection, const std::error_code& Ec); - - void ReadMessage(std::shared_ptr<WsConnection> Connection); - void RouteMessage(WebSocketMessage&& Msg); - void SendMessage(WebSocketMessage&& Msg); - - struct IdHasher - { - size_t operator()(WebSocketId Id) const { return size_t(Id.Value()); } - }; - - using ConnectionMap = std::unordered_map<WebSocketId, std::shared_ptr<WsConnection>, IdHasher>; - using RequestHandlerMap = std::unordered_map<std::string_view, WebSocketService*>; - using NotificationHandlerMap = std::unordered_map<std::string_view, std::vector<WebSocketService*>>; - - WebSocketServerOptions m_Options; - asio::io_service m_IoSvc; - std::unique_ptr<asio::ip::tcp::acceptor> m_Acceptor; - std::unique_ptr<WsThreadPool> m_ThreadPool; - ConnectionMap m_Connections; - std::shared_mutex m_ConnMutex; - std::vector<WebSocketService*> m_Services; - RequestHandlerMap m_RequestHandlers; - NotificationHandlerMap m_NotificationHandlers; - std::atomic_bool m_Running{}; -}; - -void -WsServer::RegisterService(WebSocketService& Service) -{ - m_Services.push_back(&Service); - - Service.Configure(*this); -} - -bool -WsServer::Run() -{ - static constexpr size_t ReceiveBufferSize = 256 << 10; - static constexpr size_t SendBufferSize = 256 << 10; - - m_Acceptor = std::make_unique<asio::ip::tcp::acceptor>(m_IoSvc, asio::ip::tcp::v6()); - - m_Acceptor->set_option(asio::ip::v6_only(false)); - m_Acceptor->set_option(asio::socket_base::reuse_address(true)); - m_Acceptor->set_option(asio::ip::tcp::no_delay(true)); - m_Acceptor->set_option(asio::socket_base::receive_buffer_size(ReceiveBufferSize)); - m_Acceptor->set_option(asio::socket_base::send_buffer_size(SendBufferSize)); - -#if ZEN_PLATFORM_WINDOWS - // On Windows, loopback connections can take advantage of a faster code path optionally with this flag. - // This must be used by both the client and server side, and is only effective in the absence of - // Windows Filtering Platform (WFP) callouts which can be installed by security software. - // https://docs.microsoft.com/en-us/windows/win32/winsock/sio-loopback-fast-path - SOCKET NativeSocket = m_Acceptor->native_handle(); - int LoopbackOptionValue = 1; - DWORD OptionNumberOfBytesReturned = 0; - WSAIoctl(NativeSocket, - SIO_LOOPBACK_FAST_PATH, - &LoopbackOptionValue, - sizeof(LoopbackOptionValue), - NULL, - 0, - &OptionNumberOfBytesReturned, - 0, - 0); -#endif - - asio::error_code Ec; - m_Acceptor->bind(asio::ip::tcp::endpoint(asio::ip::address_v6::any(), m_Options.Port), Ec); - - if (Ec) - { - ZEN_LOG_ERROR(LogWebSocket, "failed to bind websocket endpoint, error code '{}'", Ec.value()); - - return false; - } - - m_Acceptor->listen(); - m_Running = true; - - ZEN_LOG_INFO(LogWebSocket, "web socket server running on port '{}'", m_Options.Port); - - AcceptConnection(); - - m_ThreadPool = std::make_unique<WsThreadPool>(m_IoSvc); - m_ThreadPool->Start(m_Options.ThreadCount); - - return true; -} - -void -WsServer::Shutdown() -{ - if (m_Running) - { - ZEN_LOG_INFO(LogWebSocket, "websocket server shutting down"); - - m_Running = false; - - m_Acceptor->close(); - m_Acceptor.reset(); - m_IoSvc.stop(); - - m_ThreadPool->Stop(); - } -} - -void -WsServer::RegisterNotificationHandler(std::string_view Key, WebSocketService& Service) -{ - auto Result = m_NotificationHandlers.try_emplace(Key, std::vector<WebSocketService*>()); - Result.first->second.push_back(&Service); -} - -void -WsServer::RegisterRequestHandler(std::string_view Key, WebSocketService& Service) -{ - m_RequestHandlers[Key] = &Service; -} - -void -WsServer::SendNotification(WebSocketMessage&& Notification) -{ - ZEN_ASSERT(Notification.MessageType() == WebSocketMessageType::kNotification); - - SendMessage(std::move(Notification)); -} -void -WsServer::SendResponse(WebSocketMessage&& Response) -{ - ZEN_ASSERT(Response.MessageType() == WebSocketMessageType::kResponse || - Response.MessageType() == WebSocketMessageType::kStreamResponse || - Response.MessageType() == WebSocketMessageType::kStreamCompleteResponse); - - ZEN_ASSERT(Response.CorrelationId() != 0); - - SendMessage(std::move(Response)); -} - -void -WsServer::AcceptConnection() -{ - auto Socket = std::make_unique<asio::ip::tcp::socket>(m_IoSvc); - asio::ip::tcp::socket& SocketRef = *Socket.get(); - - m_Acceptor->async_accept(SocketRef, [this, ConnectedSocket = std::move(Socket)](const asio::error_code& Ec) mutable { - if (m_Running) - { - if (Ec) - { - ZEN_LOG_WARN(LogWebSocket, "accept connection FAILED, reason '{}'", Ec.message()); - } - else - { - auto Connection = std::make_shared<WsConnection>(WebSocketId::New(), std::move(ConnectedSocket)); - - ZEN_LOG_DEBUG(LogWebSocket, "accept connection '#{} {}' OK", Connection->Id().Value(), Connection->RemoteAddr()); - - { - std::unique_lock _(m_ConnMutex); - m_Connections[Connection->Id()] = Connection; - } - - Connection->SetParser(std::make_unique<HttpMessageParser>(HttpMessageParserType::kRequest)); - Connection->SetState(WebSocketState::kHandshaking); - - ReadMessage(Connection); - } - - AcceptConnection(); - } - }); -} - -void -WsServer::CloseConnection(std::shared_ptr<WsConnection> Connection, const std::error_code& Ec) -{ - if (const auto State = Connection->Close(); State != WebSocketState::kDisconnected) - { - if (Ec) - { - ZEN_LOG_INFO(LogWebSocket, "connection '{}' closed, reason '{} ({})'", Connection->Id().Value(), Ec.message(), Ec.value()); - } - else - { - ZEN_LOG_INFO(LogWebSocket, "connection '{}' closed", Connection->Id().Value()); - } - } - - const WebSocketId Id = Connection->Id(); - - { - std::unique_lock _(m_ConnMutex); - if (m_Connections.contains(Id)) - { - m_Connections.erase(Id); - } - } -} - -void -WsServer::ReadMessage(std::shared_ptr<WsConnection> Connection) -{ - Connection->ReadBuffer().prepare(64 << 10); - - asio::async_read( - Connection->Socket(), - Connection->ReadBuffer(), - asio::transfer_at_least(1), - [this, Connection](const asio::error_code& ReadEc, std::size_t) mutable { - if (ReadEc) - { - return CloseConnection(Connection, ReadEc); - } - - switch (Connection->State()) - { - case WebSocketState::kHandshaking: - { - HttpMessageParser& Parser = *reinterpret_cast<HttpMessageParser*>(Connection->Parser()); - asio::const_buffer Buffer = Connection->ReadBuffer().data(); - - ParseMessageResult Result = Parser.ParseMessage(MemoryView(Buffer.data(), Buffer.size())); - - Connection->ReadBuffer().consume(Result.ByteCount); - - if (Result.Status == ParseMessageStatus::kContinue) - { - return ReadMessage(Connection); - } - - if (Result.Status == ParseMessageStatus::kError) - { - ZEN_LOG_WARN(LogWebSocket, - "handshake with connection '#{} {}' FAILED, reason 'HTTP parse error'", - Connection->Id().Value(), - Connection->RemoteAddr()); - - return CloseConnection(Connection, std::error_code()); - } - - if (Parser.IsUpgrade() == false) - { - ZEN_LOG_DEBUG(LogWebSocket, - "handshake with connection '#{} {}' FAILED, reason 'invalid HTTP upgrade request'", - Connection->Id().Value(), - Connection->RemoteAddr()); - - constexpr auto UpgradeRequiredResponse = "HTTP/1.1 426 Upgrade Required\n\r\n\r"sv; - - return async_write(Connection->Socket(), - asio::buffer(UpgradeRequiredResponse), - [this, Connection](const asio::error_code& WriteEc, std::size_t) { - if (WriteEc) - { - return CloseConnection(Connection, WriteEc); - } - - Connection->Parser()->Reset(); - Connection->SetState(WebSocketState::kHandshaking); - - ReadMessage(Connection); - }); - } - - ZEN_ASSERT(Result.Status == ParseMessageStatus::kDone); - - std::string AcceptHash; - std::string Reason; - const bool ValidHandshake = Parser.ValidateWebSocketHandshake(AcceptHash, Reason); - - if (ValidHandshake == false) - { - ZEN_LOG_DEBUG(LogWebSocket, - "handshake with connection '{}' FAILED, reason '{}'", - Connection->Id().Value(), - Reason); - - constexpr auto UpgradeRequiredResponse = "HTTP/1.1 400 Bad Request\n\r\n\r"sv; - - return async_write(Connection->Socket(), - asio::buffer(UpgradeRequiredResponse), - [this, &Connection](const asio::error_code& WriteEc, std::size_t) { - if (WriteEc) - { - return CloseConnection(Connection, WriteEc); - } - - Connection->Parser()->Reset(); - Connection->SetState(WebSocketState::kHandshaking); - - ReadMessage(Connection); - }); - } - - ExtendableStringBuilder<128> Sb; - - Sb << "HTTP/1.1 101 Switching Protocols\r\n"sv; - Sb << "Upgrade: websocket\r\n"sv; - Sb << "Connection: Upgrade\r\n"sv; - - // TODO: Verify protocol - if (Parser.Headers().contains(http_header::SecWebSocketProtocol)) - { - Sb << http_header::SecWebSocketProtocol << ": " << Parser.Headers()[http_header::SecWebSocketProtocol] - << "\r\n"; - } - - Sb << http_header::SecWebSocketAccept << ": " << AcceptHash << "\r\n"; - Sb << "\r\n"sv; - - ZEN_LOG_DEBUG(LogWebSocket, - "accepting handshake from connection '#{} {}'", - Connection->Id().Value(), - Connection->RemoteAddr()); - - std::string Response = Sb.ToString(); - Buffer = asio::buffer(Response); - - async_write(Connection->Socket(), - Buffer, - [this, Connection, _ = std::move(Response)](const asio::error_code& WriteEc, std::size_t ByteCount) { - if (WriteEc) - { - ZEN_LOG_DEBUG(LogWebSocket, - "handshake with connection '{}' FAILED, reason '{}'", - Connection->Id().Value(), - WriteEc.message()); - - return CloseConnection(Connection, WriteEc); - } - - ZEN_LOG_DEBUG(LogWebSocket, - "handshake ({}B) with connection '#{} {}' OK", - ByteCount, - Connection->Id().Value(), - Connection->RemoteAddr()); - - Connection->SetParser(std::make_unique<WebSocketMessageParser>()); - Connection->SetState(WebSocketState::kConnected); - - ReadMessage(Connection); - }); - } - break; - - case WebSocketState::kConnected: - { - WebSocketMessageParser& Parser = *reinterpret_cast<WebSocketMessageParser*>(Connection->Parser()); - - uint64_t RemainingBytes = Connection->ReadBuffer().size(); - - while (RemainingBytes > 0) - { - MemoryView MessageData = MemoryView(Connection->ReadBuffer().data().data(), RemainingBytes); - const ParseMessageResult Result = Parser.ParseMessage(MessageData); - - Connection->ReadBuffer().consume(Result.ByteCount); - RemainingBytes = Connection->ReadBuffer().size(); - - if (Result.Status == ParseMessageStatus::kError) - { - ZEN_LOG_WARN(LogWebSocket, "parse websocket message FAILED, reason '{}'", Result.Reason.value()); - - return CloseConnection(Connection, std::error_code()); - } - - if (Result.Status == ParseMessageStatus::kContinue) - { - ZEN_ASSERT(RemainingBytes == 0); - continue; - } - - WebSocketMessage Message = Parser.ConsumeMessage(); - Parser.Reset(); - - Message.SetSocketId(Connection->Id()); - - RouteMessage(std::move(Message)); - } - - ReadMessage(Connection); - } - break; - - default: - break; - }; - }); -} - -void -WsServer::RouteMessage(WebSocketMessage&& RoutedMessage) -{ - switch (RoutedMessage.MessageType()) - { - case WebSocketMessageType::kRequest: - case WebSocketMessageType::kStreamRequest: - { - CbObjectView Request = RoutedMessage.Body().GetObject(); - std::string_view Method = Request["Method"].AsString(); - bool Handled = false; - bool Error = false; - std::exception Exception; - - if (auto It = m_RequestHandlers.find(Method); It != m_RequestHandlers.end()) - { - WebSocketService* Service = It->second; - ZEN_ASSERT(Service); - - try - { - Handled = Service->HandleRequest(std::move(RoutedMessage)); - } - catch (std::exception& Err) - { - Exception = std::move(Err); - Error = true; - } - } - - if (Error || Handled == false) - { - std::string ErrorText = Error ? Exception.what() : fmt::format("'{}' Not Found", Method); - - ZEN_LOG_WARN(LogWebSocket, "route request message FAILED, reason '{}'", ErrorText); - - CbObjectWriter Response; - Response << "Error"sv << ErrorText; - - WebSocketMessage ResponseMsg; - ResponseMsg.SetMessageType(WebSocketMessageType::kResponse); - ResponseMsg.SetCorrelationId(RoutedMessage.CorrelationId()); - ResponseMsg.SetSocketId(RoutedMessage.SocketId()); - ResponseMsg.SetBody(Response.Save()); - - SendResponse(std::move(ResponseMsg)); - } - } - break; - - case WebSocketMessageType::kNotification: - { - CbObjectView Notification = RoutedMessage.Body().GetObject(); - std::string_view Message = Notification["Message"].AsString(); - - if (auto It = m_NotificationHandlers.find(Message); It != m_NotificationHandlers.end()) - { - std::vector<WebSocketService*>& Handlers = It->second; - - for (WebSocketService* Handler : Handlers) - { - Handler->HandleNotification(RoutedMessage); - } - } - else - { - ZEN_LOG_WARN(LogWebSocket, "route notification message FAILED, unknown notification '{}'", Message); - } - } - break; - - default: - break; - }; -} - -void -WsServer::SendMessage(WebSocketMessage&& Msg) -{ - std::shared_ptr<WsConnection> Connection; - - { - std::unique_lock _(m_ConnMutex); - - if (auto It = m_Connections.find(Msg.SocketId()); It != m_Connections.end()) - { - Connection = It->second; - } - } - - if (Connection.get() == nullptr) - { - ZEN_LOG_WARN(LogWebSocket, "send message FAILED, reason 'unknown socket ID ({})'", Msg.SocketId().Value()); - return; - } - - if (Connection.get() != nullptr) - { - BinaryWriter Writer; - Msg.Save(Writer); - - ZEN_LOG_TRACE(LogWebSocket, - "sending '{}' message, receiver '{}', size '{}', ID '{}', total size {}", - ToString(Msg.MessageType()), - Connection->Id().Value(), - Msg.MessageSize(), - Msg.CorrelationId(), - NiceBytes(Writer.Size())); - - { - ZEN_TRACE_CPU("WS::SendMessage"); - std::unique_lock _(Connection->WriteMutex()); - ZEN_TRACE_CPU("WS::WriteSocketData"); - asio::write(Connection->Socket(), asio::buffer(Writer.Data(), Writer.Size()), asio::transfer_exactly(Writer.Size())); - } - } -} - -/////////////////////////////////////////////////////////////////////////////// -class WsClient final : public WebSocketClient, public std::enable_shared_from_this<WsClient> -{ -public: - WsClient(asio::io_context& IoCtx) : m_IoCtx(IoCtx), m_Id(WebSocketId::New()) {} - - virtual ~WsClient() { Disconnect(); } - - std::shared_ptr<WsClient> AsShared() { return shared_from_this(); } - - virtual std::future<bool> Connect(const WebSocketConnectInfo& Info) override; - virtual void Disconnect() override; - virtual bool IsConnected() const override { return false; } - virtual WebSocketState State() const override { return static_cast<WebSocketState>(m_State.load()); } - - virtual std::future<WebSocketMessage> SendRequest(WebSocketMessage&& Request) override; - virtual void OnNotification(NotificationCallback&& Cb) override; - virtual void OnEvent(WebSocketEvent Evt, EventCallback&& Cb) override; - -private: - WebSocketState SetState(WebSocketState NewState) { return static_cast<WebSocketState>(m_State.exchange(uint32_t(NewState))); } - MessageParser* Parser() { return m_MsgParser.get(); } - void SetParser(std::unique_ptr<MessageParser>&& Parser) { m_MsgParser = std::move(Parser); } - asio::streambuf& ReadBuffer() { return m_ReadBuffer; } - void TriggerEvent(WebSocketEvent Evt); - void ReadMessage(); - void RouteMessage(WebSocketMessage&& RoutedMessage); - - using PendingRequestMap = std::unordered_map<uint32_t, std::promise<WebSocketMessage>>; - - asio::io_context& m_IoCtx; - WebSocketId m_Id; - std::unique_ptr<asio::ip::tcp::socket> m_Socket; - std::unique_ptr<MessageParser> m_MsgParser; - asio::streambuf m_ReadBuffer; - EventCallback m_EventCallbacks[3]; - NotificationCallback m_NotificationCallback; - PendingRequestMap m_PendingRequests; - std::mutex m_RequestMutex; - std::promise<bool> m_ConnectPromise; - std::atomic_uint32_t m_State; - std::string m_Host; - int16_t m_Port{}; -}; - -std::future<bool> -WsClient::Connect(const WebSocketConnectInfo& Info) -{ - if (State() == WebSocketState::kHandshaking || State() == WebSocketState::kConnected) - { - return m_ConnectPromise.get_future(); - } - - SetState(WebSocketState::kHandshaking); - - try - { - asio::ip::tcp::endpoint Endpoint(asio::ip::address::from_string(Info.Host), Info.Port); - m_Socket = std::make_unique<asio::ip::tcp::socket>(m_IoCtx, Endpoint.protocol()); - - m_Socket->connect(Endpoint); - - m_Host = m_Socket->remote_endpoint().address().to_string(); - m_Port = Info.Port; - - ZEN_LOG_INFO(LogWsClient, "connected to websocket server '{}:{}'", m_Host, m_Port); - } - catch (std::exception& Err) - { - ZEN_LOG_WARN(LogWsClient, "connect to websocket server '{}:{}' FAILED, reason '{}'", Info.Host, Info.Port, Err.what()); - - SetState(WebSocketState::kError); - m_Socket.reset(); - - TriggerEvent(WebSocketEvent::kDisconnected); - - m_ConnectPromise.set_value(false); - - return m_ConnectPromise.get_future(); - } - - ExtendableStringBuilder<128> Sb; - Sb << "GET " << Info.Endpoint << " HTTP/1.1\r\n"sv; - Sb << "Host: " << Info.Host << "\r\n"sv; - Sb << "Upgrade: websocket\r\n"sv; - Sb << "Connection: upgrade\r\n"sv; - Sb << "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n"sv; - - if (Info.Protocols.empty() == false) - { - Sb << "Sec-WebSocket-Protocol: "sv; - for (size_t Idx = 0; const auto& Protocol : Info.Protocols) - { - if (Idx++) - { - Sb << ", "; - } - Sb << Protocol; - } - } - - Sb << "Sec-WebSocket-Version: "sv << Info.Version << "\r\n"sv; - Sb << "\r\n"; - - std::string HandshakeRequest = Sb.ToString(); - asio::const_buffer Buffer = asio::buffer(HandshakeRequest); - - ZEN_LOG_DEBUG(LogWsClient, "handshaking with '{}:{}'", m_Host, m_Port); - - m_MsgParser = std::make_unique<HttpMessageParser>(HttpMessageParserType::kResponse); - m_MsgParser->Reset(); - - async_write(*m_Socket, Buffer, [Self = AsShared(), _ = std::move(HandshakeRequest)](const asio::error_code& Ec, std::size_t) { - if (Ec) - { - ZEN_LOG_ERROR(LogWsClient, "write data FAILED, reason '{}'", Ec.message()); - - Self->Disconnect(); - } - else - { - Self->ReadMessage(); - } - }); - - return m_ConnectPromise.get_future(); -} - -void -WsClient::Disconnect() -{ - if (auto PrevState = SetState(WebSocketState::kDisconnected); PrevState != WebSocketState::kDisconnected) - { - ZEN_LOG_INFO(LogWsClient, "closing connection to '{}:{}'", m_Host, m_Port); - - if (m_Socket && m_Socket->is_open()) - { - m_Socket->close(); - m_Socket.reset(); - } - - TriggerEvent(WebSocketEvent::kDisconnected); - - { - std::unique_lock _(m_RequestMutex); - - for (auto& Kv : m_PendingRequests) - { - Kv.second.set_value(WebSocketMessage()); - } - - m_PendingRequests.clear(); - } - } -} - -std::future<WebSocketMessage> -WsClient::SendRequest(WebSocketMessage&& Request) -{ - ZEN_ASSERT(Request.MessageType() == WebSocketMessageType::kRequest); - - BinaryWriter Writer; - Request.Save(Writer); - - std::future<WebSocketMessage> FutureResponse; - - { - std::unique_lock _(m_RequestMutex); - - auto Result = m_PendingRequests.try_emplace(Request.CorrelationId(), std::promise<WebSocketMessage>()); - ZEN_ASSERT(Result.second); - - auto It = Result.first; - FutureResponse = It->second.get_future(); - } - - IoBuffer Buffer = IoBufferBuilder::MakeCloneFromMemory(Writer.Data(), Writer.Size()); - - async_write(*m_Socket, asio::buffer(Buffer.Data(), Buffer.Size()), [Self = AsShared()](const std::error_code& Ec, size_t) { - if (Ec) - { - ZEN_LOG_WARN(LogWsClient, "send request message FAILED, reason '{}'", Ec.message()); - - Self->Disconnect(); - } - }); - - return FutureResponse; -} - -void -WsClient::OnNotification(NotificationCallback&& Cb) -{ - m_NotificationCallback = std::move(Cb); -} - -void -WsClient::OnEvent(WebSocketEvent Evt, WebSocketClient::EventCallback&& Cb) -{ - m_EventCallbacks[static_cast<uint32_t>(Evt)] = std::move(Cb); -} - -void -WsClient::TriggerEvent(WebSocketEvent Evt) -{ - const uint32_t Index = static_cast<uint32_t>(Evt); - - if (m_EventCallbacks[Index]) - { - m_EventCallbacks[Index](); - } -} - -void -WsClient::ReadMessage() -{ - m_ReadBuffer.prepare(64 << 10); - - async_read(*m_Socket, - m_ReadBuffer, - asio::transfer_at_least(1), - [Self = AsShared()](const asio::error_code& Ec, std::size_t ByteCount) mutable { - const WebSocketState State = Self->State(); - - if (State == WebSocketState::kDisconnected) - { - return; - } - - if (Ec) - { - ZEN_LOG_WARN(LogWsClient, "read message FAILED, reason '{}'", Ec.message()); - - return Self->Disconnect(); - } - - switch (State) - { - case WebSocketState::kHandshaking: - { - HttpMessageParser& Parser = *reinterpret_cast<HttpMessageParser*>(Self->Parser()); - - MemoryView MessageData = MemoryView(Self->ReadBuffer().data().data(), ByteCount); - - ParseMessageResult Result = Parser.ParseMessage(MessageData); - - Self->ReadBuffer().consume(size_t(Result.ByteCount)); - - if (Result.Status == ParseMessageStatus::kError) - { - ZEN_LOG_WARN(LogWsClient, "handshake FAILED, status code '{}'", Parser.StatusCode()); - - Self->m_ConnectPromise.set_value(false); - - return Self->Disconnect(); - } - - if (Result.Status == ParseMessageStatus::kContinue) - { - return Self->ReadMessage(); - } - - ZEN_ASSERT(Result.Status == ParseMessageStatus::kDone); - - if (Parser.StatusCode() != 101) - { - ZEN_LOG_WARN(LogWsClient, - "handshake FAILED, status '{}', status code '{}'", - Parser.StatusText(), - Parser.StatusCode()); - - Self->m_ConnectPromise.set_value(false); - - return Self->Disconnect(); - } - - ZEN_LOG_INFO(LogWsClient, "handshake OK, status '{}'", Parser.StatusText()); - - Self->SetParser(std::make_unique<WebSocketMessageParser>()); - Self->SetState(WebSocketState::kConnected); - Self->ReadMessage(); - Self->TriggerEvent(WebSocketEvent::kConnected); - - Self->m_ConnectPromise.set_value(true); - } - break; - - case WebSocketState::kConnected: - { - WebSocketMessageParser& Parser = *reinterpret_cast<WebSocketMessageParser*>(Self->Parser()); - - uint64_t RemainingBytes = Self->ReadBuffer().size(); - - while (RemainingBytes > 0) - { - MemoryView MessageData = MemoryView(Self->ReadBuffer().data().data(), RemainingBytes); - const ParseMessageResult Result = Parser.ParseMessage(MessageData); - - Self->ReadBuffer().consume(Result.ByteCount); - RemainingBytes = Self->ReadBuffer().size(); - - if (Result.Status == ParseMessageStatus::kError) - { - ZEN_LOG_WARN(LogWsClient, "parse websocket message FAILED, reason '{}'", Result.Reason.value()); - - Parser.Reset(); - continue; - } - - if (Result.Status == ParseMessageStatus::kContinue) - { - ZEN_ASSERT(RemainingBytes == 0); - continue; - } - - WebSocketMessage Message = Parser.ConsumeMessage(); - Parser.Reset(); - - Self->RouteMessage(std::move(Message)); - } - - Self->ReadMessage(); - } - break; - - default: - break; - } - }); -} - -void -WsClient::RouteMessage(WebSocketMessage&& RoutedMessage) -{ - switch (RoutedMessage.MessageType()) - { - case WebSocketMessageType::kResponse: - { - std::unique_lock _(m_RequestMutex); - - if (auto It = m_PendingRequests.find(RoutedMessage.CorrelationId()); It != m_PendingRequests.end()) - { - It->second.set_value(std::move(RoutedMessage)); - m_PendingRequests.erase(It); - } - else - { - ZEN_LOG_WARN(LogWsClient, - "route request message FAILED, reason 'unknown correlation ID ({})'", - RoutedMessage.CorrelationId()); - } - } - break; - - case WebSocketMessageType::kNotification: - { - std::unique_lock _(m_RequestMutex); - - if (m_NotificationCallback) - { - m_NotificationCallback(std::move(RoutedMessage)); - } - } - break; - - default: - ZEN_LOG_WARN(LogWsClient, "route message FAILED, reason 'invalid message type ({})'", uint8_t(RoutedMessage.MessageType())); - break; - }; -} - -} // namespace zen::websocket - -namespace zen { - -std::atomic_uint32_t WebSocketId::NextId{1}; - -bool -WebSocketMessage::Header::IsValid() const -{ - return Magic == ExpectedMagic && StatusCode > 0 && uint8_t(MessageType) > uint8_t(WebSocketMessageType::kInvalid) && - uint8_t(MessageType) < uint8_t(WebSocketMessageType::kCount); -} - -std::atomic_uint32_t WebSocketMessage::NextCorrelationId{1}; - -void -WebSocketMessage::SetMessageType(WebSocketMessageType MessageType) -{ - m_Header.MessageType = MessageType; -} - -void -WebSocketMessage::SetBody(CbPackage&& Body) -{ - m_Body = std::move(Body); -} -void -WebSocketMessage::SetBody(CbObject&& Body) -{ - CbPackage Pkg; - Pkg.SetObject(Body); - - SetBody(std::move(Pkg)); -} - -void -WebSocketMessage::Save(BinaryWriter& Writer) -{ - Writer.Write(&m_Header, HeaderSize); - - if (m_Body.has_value()) - { - const CbObject& Obj = m_Body.value().GetObject(); - MemoryView View = Obj.GetBuffer().GetView(); - - const CbValidateError ValidationResult = ValidateCompactBinary(View, CbValidateMode::All); - ZEN_ASSERT(ValidationResult == CbValidateError::None); - - m_Body.value().Save(Writer); - } - - if (m_Header.CorrelationId == 0 && MessageType() == WebSocketMessageType::kRequest) - { - m_Header.CorrelationId = NextCorrelationId.fetch_add(1); - } - - m_Header.MessageSize = Writer.Size() - HeaderSize; - - Writer.GetMutableView().CopyFrom(MakeMemoryView(&m_Header, HeaderSize)); -} - -bool -WebSocketMessage::TryLoadHeader(MemoryView Memory) -{ - if (Memory.GetSize() < HeaderSize) - { - return false; - } - - MutableMemoryView HeaderView(&m_Header, HeaderSize); - - HeaderView.CopyFrom(Memory); - - return m_Header.IsValid(); -} - -void -WebSocketService::Configure(WebSocketServer& Server) -{ - ZEN_ASSERT(m_SocketServer == nullptr); - - m_SocketServer = &Server; - - RegisterHandlers(Server); -} - -void -WebSocketService::SendStreamResponse(WebSocketId SocketId, uint32_t CorrelationId, CbPackage&& StreamResponse, bool IsStreamComplete) -{ - WebSocketMessage Message; - - Message.SetMessageType(IsStreamComplete ? WebSocketMessageType::kStreamCompleteResponse : WebSocketMessageType::kStreamResponse); - Message.SetCorrelationId(CorrelationId); - Message.SetSocketId(SocketId); - Message.SetBody(std::move(StreamResponse)); - - SocketServer().SendResponse(std::move(Message)); -} - -void -WebSocketService::SendStreamResponse(WebSocketId SocketId, uint32_t CorrelationId, CbObject&& StreamResponse, bool IsStreamComplete) -{ - CbPackage Response; - Response.SetObject(std::move(StreamResponse)); - - SendStreamResponse(SocketId, CorrelationId, std::move(Response), IsStreamComplete); -} - -std::unique_ptr<WebSocketServer> -WebSocketServer::Create(const WebSocketServerOptions& Options) -{ - return std::make_unique<websocket::WsServer>(Options); -} - -std::shared_ptr<WebSocketClient> -WebSocketClient::Create(asio::io_context& IoCtx) -{ - return std::make_shared<websocket::WsClient>(IoCtx); -} - -} // namespace zen diff --git a/zenhttp/xmake.lua b/zenhttp/xmake.lua deleted file mode 100644 index b0dbdbc79..000000000 --- a/zenhttp/xmake.lua +++ /dev/null @@ -1,14 +0,0 @@ --- Copyright Epic Games, Inc. All Rights Reserved. - -target('zenhttp') - set_kind("static") - add_headerfiles("**.h") - add_files("**.cpp") - add_files("httpsys.cpp", {unity_ignored=true}) - add_includedirs("include", {public=true}) - add_deps("zencore") - add_packages( - "vcpkg::gsl-lite", - "vcpkg::http-parser" - ) - add_options("httpsys") diff --git a/zenhttp/zenhttp.cpp b/zenhttp/zenhttp.cpp deleted file mode 100644 index 4bd6a5697..000000000 --- a/zenhttp/zenhttp.cpp +++ /dev/null @@ -1,22 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. - -#include <zenhttp/zenhttp.h> - -#if ZEN_WITH_TESTS - -# include <zenhttp/httpclient.h> -# include <zenhttp/httpserver.h> -# include <zenhttp/httpshared.h> - -namespace zen { - -void -zenhttp_forcelinktests() -{ - http_forcelink(); - forcelink_httpshared(); -} - -} // namespace zen - -#endif |