aboutsummaryrefslogtreecommitdiff
path: root/src/zenhttp
diff options
context:
space:
mode:
Diffstat (limited to 'src/zenhttp')
-rw-r--r--src/zenhttp/httpasio.cpp1372
-rw-r--r--src/zenhttp/httpasio.h36
-rw-r--r--src/zenhttp/httpclient.cpp176
-rw-r--r--src/zenhttp/httpnull.cpp83
-rw-r--r--src/zenhttp/httpnull.h29
-rw-r--r--src/zenhttp/httpserver.cpp885
-rw-r--r--src/zenhttp/httpshared.cpp809
-rw-r--r--src/zenhttp/httpsys.cpp1674
-rw-r--r--src/zenhttp/httpsys.h90
-rw-r--r--src/zenhttp/include/zenhttp/httpclient.h47
-rw-r--r--src/zenhttp/include/zenhttp/httpcommon.h181
-rw-r--r--src/zenhttp/include/zenhttp/httpserver.h315
-rw-r--r--src/zenhttp/include/zenhttp/httpshared.h163
-rw-r--r--src/zenhttp/include/zenhttp/websocket.h256
-rw-r--r--src/zenhttp/include/zenhttp/zenhttp.h21
-rw-r--r--src/zenhttp/iothreadpool.cpp49
-rw-r--r--src/zenhttp/iothreadpool.h37
-rw-r--r--src/zenhttp/websocketasio.cpp1613
-rw-r--r--src/zenhttp/xmake.lua14
-rw-r--r--src/zenhttp/zenhttp.cpp22
20 files changed, 7872 insertions, 0 deletions
diff --git a/src/zenhttp/httpasio.cpp b/src/zenhttp/httpasio.cpp
new file mode 100644
index 000000000..79b2c0a3d
--- /dev/null
+++ b/src/zenhttp/httpasio.cpp
@@ -0,0 +1,1372 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "httpasio.h"
+
+#include <zencore/logging.h>
+#include <zenhttp/httpserver.h>
+
+#include <deque>
+#include <memory>
+#include <string_view>
+#include <vector>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#if ZEN_PLATFORM_WINDOWS
+# include <conio.h>
+# include <mstcpip.h>
+#endif
+#include <http_parser.h>
+#include <asio.hpp>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+#define ASIO_VERBOSE_TRACE 0
+
+#if ASIO_VERBOSE_TRACE
+# define ZEN_TRACE_VERBOSE ZEN_TRACE
+#else
+# define ZEN_TRACE_VERBOSE(fmtstr, ...)
+#endif
+
+namespace zen::asio_http {
+
+using namespace std::literals;
+
+struct HttpAcceptor;
+struct HttpRequest;
+struct HttpResponse;
+struct HttpServerConnection;
+
+static constinit uint32_t HashContentLength = HashStringAsLowerDjb2("Content-Length"sv);
+static constinit uint32_t HashContentType = HashStringAsLowerDjb2("Content-Type"sv);
+static constinit uint32_t HashAccept = HashStringAsLowerDjb2("Accept"sv);
+static constinit uint32_t HashExpect = HashStringAsLowerDjb2("Expect"sv);
+static constinit uint32_t HashSession = HashStringAsLowerDjb2("UE-Session"sv);
+static constinit uint32_t HashRequest = HashStringAsLowerDjb2("UE-Request"sv);
+static constinit uint32_t HashRange = HashStringAsLowerDjb2("Range"sv);
+
+inline spdlog::logger&
+InitLogger()
+{
+ spdlog::logger& Logger = logging::Get("asio");
+ // Logger.set_level(spdlog::level::trace);
+ return Logger;
+}
+
+inline spdlog::logger&
+Log()
+{
+ static spdlog::logger& g_Logger = InitLogger();
+ return g_Logger;
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+struct HttpAsioServerImpl
+{
+public:
+ HttpAsioServerImpl();
+ ~HttpAsioServerImpl();
+
+ int Start(uint16_t Port, int ThreadCount);
+ void Stop();
+ void RegisterService(const char* UrlPath, HttpService& Service);
+ HttpService* RouteRequest(std::string_view Url);
+
+ asio::io_service m_IoService;
+ asio::io_service::work m_Work{m_IoService};
+ std::unique_ptr<asio_http::HttpAcceptor> m_Acceptor;
+ std::vector<std::thread> m_ThreadPool;
+
+ struct ServiceEntry
+ {
+ std::string ServiceUrlPath;
+ HttpService* Service;
+ };
+
+ RwLock m_Lock;
+ std::vector<ServiceEntry> m_UriHandlers;
+};
+
+/**
+ * This is the class which request handlers use to interact with the server instance
+ */
+
+class HttpAsioServerRequest : public HttpServerRequest
+{
+public:
+ HttpAsioServerRequest(asio_http::HttpRequest& Request, HttpService& Service, IoBuffer PayloadBuffer);
+ ~HttpAsioServerRequest();
+
+ virtual Oid ParseSessionId() const override;
+ virtual uint32_t ParseRequestId() const override;
+
+ virtual IoBuffer ReadPayload() override;
+ virtual void WriteResponse(HttpResponseCode ResponseCode) override;
+ virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span<IoBuffer> Blobs) override;
+ virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::u8string_view ResponseString) override;
+ virtual void WriteResponseAsync(std::function<void(HttpServerRequest&)>&& ContinuationHandler) override;
+ virtual bool TryGetRanges(HttpRanges& Ranges) override;
+
+ using HttpServerRequest::WriteResponse;
+
+ HttpAsioServerRequest(const HttpAsioServerRequest&) = delete;
+ HttpAsioServerRequest& operator=(const HttpAsioServerRequest&) = delete;
+
+ asio_http::HttpRequest& m_Request;
+ IoBuffer m_PayloadBuffer;
+ std::unique_ptr<HttpResponse> m_Response;
+};
+
+struct HttpRequest
+{
+ explicit HttpRequest(HttpServerConnection& Connection) : m_Connection(Connection) {}
+
+ void Initialize();
+ size_t ConsumeData(const char* InputData, size_t DataSize);
+ void ResetState();
+
+ HttpVerb RequestVerb() const { return m_RequestVerb; }
+ bool IsKeepAlive() const { return m_KeepAlive; }
+ std::string_view Url() const { return m_NormalizedUrl.empty() ? std::string_view(m_Url, m_UrlLength) : m_NormalizedUrl; }
+ std::string_view QueryString() const { return std::string_view(m_QueryString, m_QueryLength); }
+ IoBuffer Body() { return m_BodyBuffer; }
+
+ inline HttpContentType ContentType()
+ {
+ if (m_ContentTypeHeaderIndex < 0)
+ {
+ return HttpContentType::kUnknownContentType;
+ }
+
+ return ParseContentType(m_Headers[m_ContentTypeHeaderIndex].Value);
+ }
+
+ inline HttpContentType AcceptType()
+ {
+ if (m_AcceptHeaderIndex < 0)
+ {
+ return HttpContentType::kUnknownContentType;
+ }
+
+ return ParseContentType(m_Headers[m_AcceptHeaderIndex].Value);
+ }
+
+ Oid SessionId() const { return m_SessionId; }
+ int RequestId() const { return m_RequestId; }
+
+ std::string_view RangeHeader() const { return m_RangeHeaderIndex != -1 ? m_Headers[m_RangeHeaderIndex].Value : std::string_view(); }
+
+private:
+ struct HeaderEntry
+ {
+ HeaderEntry() = default;
+
+ HeaderEntry(std::string_view InName, std::string_view InValue) : Name(InName), Value(InValue) {}
+
+ std::string_view Name;
+ std::string_view Value;
+ };
+
+ HttpServerConnection& m_Connection;
+ char* m_HeaderCursor = m_HeaderBuffer;
+ char* m_Url = nullptr;
+ size_t m_UrlLength = 0;
+ char* m_QueryString = nullptr;
+ size_t m_QueryLength = 0;
+ char* m_CurrentHeaderName = nullptr; // Used while parsing headers
+ size_t m_CurrentHeaderNameLength = 0;
+ char* m_CurrentHeaderValue = nullptr; // Used while parsing headers
+ size_t m_CurrentHeaderValueLength = 0;
+ std::vector<HeaderEntry> m_Headers;
+ int8_t m_ContentLengthHeaderIndex;
+ int8_t m_AcceptHeaderIndex;
+ int8_t m_ContentTypeHeaderIndex;
+ int8_t m_RangeHeaderIndex;
+ HttpVerb m_RequestVerb;
+ bool m_KeepAlive = false;
+ bool m_Expect100Continue = false;
+ int m_RequestId = -1;
+ Oid m_SessionId{};
+ IoBuffer m_BodyBuffer;
+ uint64_t m_BodyPosition = 0;
+ http_parser m_Parser;
+ char m_HeaderBuffer[1024];
+ std::string m_NormalizedUrl;
+
+ void AppendCurrentHeader();
+
+ int OnMessageBegin();
+ int OnUrl(const char* Data, size_t Bytes);
+ int OnHeader(const char* Data, size_t Bytes);
+ int OnHeaderValue(const char* Data, size_t Bytes);
+ int OnHeadersComplete();
+ int OnBody(const char* Data, size_t Bytes);
+ int OnMessageComplete();
+
+ static HttpRequest* GetThis(http_parser* Parser) { return reinterpret_cast<HttpRequest*>(Parser->data); }
+ static http_parser_settings s_ParserSettings;
+};
+
+struct HttpResponse
+{
+public:
+ HttpResponse() = default;
+ explicit HttpResponse(HttpContentType ContentType) : m_ContentType(ContentType) {}
+
+ void InitializeForPayload(uint16_t ResponseCode, std::span<IoBuffer> BlobList)
+ {
+ m_ResponseCode = ResponseCode;
+ const uint32_t ChunkCount = gsl::narrow<uint32_t>(BlobList.size());
+
+ m_DataBuffers.reserve(ChunkCount);
+
+ for (IoBuffer& Buffer : BlobList)
+ {
+#if 1
+ m_DataBuffers.emplace_back(std::move(Buffer)).MakeOwned();
+#else
+ IoBuffer TempBuffer = std::move(Buffer);
+ TempBuffer.MakeOwned();
+ m_DataBuffers.emplace_back(IoBufferBuilder::ReadFromFileMaybe(TempBuffer));
+#endif
+ }
+
+ uint64_t LocalDataSize = 0;
+
+ m_AsioBuffers.push_back({}); // Placeholder for header
+
+ for (IoBuffer& Buffer : m_DataBuffers)
+ {
+ uint64_t BufferDataSize = Buffer.Size();
+
+ ZEN_ASSERT(BufferDataSize);
+
+ LocalDataSize += BufferDataSize;
+
+ IoBufferFileReference FileRef;
+ if (Buffer.GetFileReference(/* out */ FileRef))
+ {
+ // TODO: Use direct file transfer, via TransmitFile/sendfile
+ //
+ // this looks like it requires some custom asio plumbing however
+
+ m_AsioBuffers.push_back({Buffer.Data(), Buffer.Size()});
+ }
+ else
+ {
+ // Send from memory
+
+ m_AsioBuffers.push_back({Buffer.Data(), Buffer.Size()});
+ }
+ }
+ m_ContentLength = LocalDataSize;
+
+ auto Headers = GetHeaders();
+ m_AsioBuffers[0] = asio::const_buffer(Headers.data(), Headers.size());
+ }
+
+ uint16_t ResponseCode() const { return m_ResponseCode; }
+ uint64_t ContentLength() const { return m_ContentLength; }
+
+ const std::vector<asio::const_buffer>& AsioBuffers() const { return m_AsioBuffers; }
+
+ std::string_view GetHeaders()
+ {
+ m_Headers << "HTTP/1.1 " << ResponseCode() << " " << ReasonStringForHttpResultCode(ResponseCode()) << "\r\n"
+ << "Content-Type: " << MapContentTypeToString(m_ContentType) << "\r\n"
+ << "Content-Length: " << ContentLength() << "\r\n"sv;
+
+ if (!m_IsKeepAlive)
+ {
+ m_Headers << "Connection: close\r\n"sv;
+ }
+
+ m_Headers << "\r\n"sv;
+
+ return m_Headers;
+ }
+
+ void SuppressPayload() { m_AsioBuffers.resize(1); }
+
+private:
+ uint16_t m_ResponseCode = 0;
+ bool m_IsKeepAlive = true;
+ HttpContentType m_ContentType = HttpContentType::kBinary;
+ uint64_t m_ContentLength = 0;
+ std::vector<IoBuffer> m_DataBuffers;
+ std::vector<asio::const_buffer> m_AsioBuffers;
+ ExtendableStringBuilder<160> m_Headers;
+};
+
+//////////////////////////////////////////////////////////////////////////
+
+struct HttpServerConnection : std::enable_shared_from_this<HttpServerConnection>
+{
+ HttpServerConnection(HttpAsioServerImpl& Server, std::unique_ptr<asio::ip::tcp::socket>&& Socket);
+ ~HttpServerConnection();
+
+ void HandleNewRequest();
+ void TerminateConnection();
+ void HandleRequest();
+
+ std::shared_ptr<HttpServerConnection> AsSharedPtr() { return shared_from_this(); }
+
+private:
+ enum class RequestState
+ {
+ kInitialState,
+ kInitialRead,
+ kReadingMore,
+ kWriting,
+ kWritingFinal,
+ kDone,
+ kTerminated
+ };
+
+ RequestState m_RequestState = RequestState::kInitialState;
+ HttpRequest m_RequestData{*this};
+
+ void EnqueueRead();
+ void OnDataReceived(const asio::error_code& Ec, std::size_t ByteCount);
+ void OnResponseDataSent(const asio::error_code& Ec, std::size_t ByteCount, bool Pop = false);
+ void OnError();
+
+ HttpAsioServerImpl& m_Server;
+ asio::streambuf m_RequestBuffer;
+ std::unique_ptr<asio::ip::tcp::socket> m_Socket;
+ std::atomic<uint32_t> m_RequestCounter{0};
+ uint32_t m_ConnectionId = 0;
+ Ref<IHttpPackageHandler> m_PackageHandler;
+
+ RwLock m_ResponsesLock;
+ std::deque<std::unique_ptr<HttpResponse>> m_Responses;
+};
+
+std::atomic<uint32_t> g_ConnectionIdCounter{0};
+
+HttpServerConnection::HttpServerConnection(HttpAsioServerImpl& Server, std::unique_ptr<asio::ip::tcp::socket>&& Socket)
+: m_Server(Server)
+, m_Socket(std::move(Socket))
+, m_ConnectionId(g_ConnectionIdCounter.fetch_add(1))
+{
+ ZEN_TRACE_VERBOSE("new connection #{}", m_ConnectionId);
+}
+
+HttpServerConnection::~HttpServerConnection()
+{
+ ZEN_TRACE_VERBOSE("destroying connection #{}", m_ConnectionId);
+}
+
+void
+HttpServerConnection::HandleNewRequest()
+{
+ m_RequestData.Initialize();
+
+ EnqueueRead();
+}
+
+void
+HttpServerConnection::TerminateConnection()
+{
+ m_RequestState = RequestState::kTerminated;
+
+ std::error_code Ec;
+ m_Socket->close(Ec);
+}
+
+void
+HttpServerConnection::EnqueueRead()
+{
+ if (m_RequestState == RequestState::kInitialRead)
+ {
+ m_RequestState = RequestState::kReadingMore;
+ }
+ else
+ {
+ m_RequestState = RequestState::kInitialRead;
+ }
+
+ m_RequestBuffer.prepare(64 * 1024);
+
+ asio::async_read(*m_Socket.get(),
+ m_RequestBuffer,
+ asio::transfer_at_least(1),
+ [Conn = AsSharedPtr()](const asio::error_code& Ec, std::size_t ByteCount) { Conn->OnDataReceived(Ec, ByteCount); });
+}
+
+void
+HttpServerConnection::OnDataReceived(const asio::error_code& Ec, [[maybe_unused]] std::size_t ByteCount)
+{
+ if (Ec)
+ {
+ if (m_RequestState == RequestState::kDone || m_RequestState == RequestState::kInitialRead)
+ {
+ ZEN_TRACE_VERBOSE("on data received ERROR (EXPECTED), connection '{}' reason '{}'", m_ConnectionId, Ec.message());
+ return;
+ }
+ else
+ {
+ ZEN_WARN("on data received ERROR, connection '{}' reason '{}'", m_ConnectionId, Ec.message());
+ return OnError();
+ }
+ }
+
+ ZEN_TRACE_VERBOSE("on data received, connection '{}', request '{}', thread '{}', bytes '{}'",
+ m_ConnectionId,
+ m_RequestCounter.load(std::memory_order_relaxed),
+ zen::GetCurrentThreadId(),
+ NiceBytes(ByteCount));
+
+ while (m_RequestBuffer.size())
+ {
+ const asio::const_buffer& InputBuffer = m_RequestBuffer.data();
+
+ size_t Result = m_RequestData.ConsumeData((const char*)InputBuffer.data(), InputBuffer.size());
+ if (Result == ~0ull)
+ {
+ return OnError();
+ }
+
+ m_RequestBuffer.consume(Result);
+ }
+
+ switch (m_RequestState)
+ {
+ case RequestState::kDone:
+ case RequestState::kWritingFinal:
+ case RequestState::kTerminated:
+ break;
+
+ default:
+ EnqueueRead();
+ break;
+ }
+}
+
+void
+HttpServerConnection::OnResponseDataSent(const asio::error_code& Ec, [[maybe_unused]] std::size_t ByteCount, bool Pop)
+{
+ if (Ec)
+ {
+ ZEN_WARN("on data sent ERROR, connection '{}' reason '{}'", m_ConnectionId, Ec.message());
+ OnError();
+ }
+ else
+ {
+ ZEN_TRACE_VERBOSE("on data sent, connection '{}', request '{}', thread '{}', bytes '{}'",
+ m_ConnectionId,
+ m_RequestCounter.load(std::memory_order_relaxed),
+ zen::GetCurrentThreadId(),
+ NiceBytes(ByteCount));
+
+ if (!m_RequestData.IsKeepAlive())
+ {
+ m_RequestState = RequestState::kDone;
+
+ m_Socket->close();
+ }
+ else
+ {
+ if (Pop)
+ {
+ RwLock::ExclusiveLockScope _(m_ResponsesLock);
+ m_Responses.pop_front();
+ }
+
+ m_RequestCounter.fetch_add(1);
+ }
+ }
+}
+
+void
+HttpServerConnection::OnError()
+{
+ m_Socket->close();
+}
+
+void
+HttpServerConnection::HandleRequest()
+{
+ if (!m_RequestData.IsKeepAlive())
+ {
+ m_RequestState = RequestState::kWritingFinal;
+
+ std::error_code Ec;
+ m_Socket->shutdown(asio::socket_base::shutdown_receive, Ec);
+
+ if (Ec)
+ {
+ ZEN_WARN("socket shutdown ERROR, reason '{}'", Ec.message());
+ }
+ }
+ else
+ {
+ m_RequestState = RequestState::kWriting;
+ }
+
+ if (HttpService* Service = m_Server.RouteRequest(m_RequestData.Url()))
+ {
+ HttpAsioServerRequest Request(m_RequestData, *Service, m_RequestData.Body());
+
+ ZEN_TRACE_VERBOSE("handle request, connection '{}' request '{}'", m_ConnectionId, m_RequestCounter.load(std::memory_order_relaxed));
+
+ if (!HandlePackageOffers(*Service, Request, m_PackageHandler))
+ {
+ try
+ {
+ Service->HandleRequest(Request);
+ }
+ catch (std::exception& ex)
+ {
+ ZEN_ERROR("Caught exception while handling request: '{}'", ex.what());
+
+ Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, ex.what());
+ }
+ }
+
+ if (std::unique_ptr<HttpResponse> Response = std::move(Request.m_Response))
+ {
+ // Transmit the response
+
+ if (m_RequestData.RequestVerb() == HttpVerb::kHead)
+ {
+ Response->SuppressPayload();
+ }
+
+ auto ResponseBuffers = Response->AsioBuffers();
+
+ uint64_t ResponseLength = 0;
+
+ for (auto& Buffer : ResponseBuffers)
+ {
+ ResponseLength += Buffer.size();
+ }
+
+ {
+ RwLock::ExclusiveLockScope _(m_ResponsesLock);
+ m_Responses.push_back(std::move(Response));
+ }
+
+ // TODO: should cork/uncork for Linux?
+
+ asio::async_write(*m_Socket.get(),
+ ResponseBuffers,
+ asio::transfer_exactly(ResponseLength),
+ [Conn = AsSharedPtr()](const asio::error_code& Ec, std::size_t ByteCount) {
+ Conn->OnResponseDataSent(Ec, ByteCount, true);
+ });
+
+ return;
+ }
+ }
+
+ if (m_RequestData.RequestVerb() == HttpVerb::kHead)
+ {
+ std::string_view Response =
+ "HTTP/1.1 404 NOT FOUND\r\n"
+ "\r\n"sv;
+
+ if (!m_RequestData.IsKeepAlive())
+ {
+ Response =
+ "HTTP/1.1 404 NOT FOUND\r\n"
+ "Connection: close\r\n"
+ "\r\n"sv;
+ }
+
+ asio::async_write(
+ *m_Socket.get(),
+ asio::buffer(Response),
+ [Conn = AsSharedPtr()](const asio::error_code& Ec, std::size_t ByteCount) { Conn->OnResponseDataSent(Ec, ByteCount); });
+ }
+ else
+ {
+ std::string_view Response =
+ "HTTP/1.1 404 NOT FOUND\r\n"
+ "Content-Length: 23\r\n"
+ "Content-Type: text/plain\r\n"
+ "\r\n"
+ "No suitable route found"sv;
+
+ if (!m_RequestData.IsKeepAlive())
+ {
+ Response =
+ "HTTP/1.1 404 NOT FOUND\r\n"
+ "Content-Length: 23\r\n"
+ "Content-Type: text/plain\r\n"
+ "Connection: close\r\n"
+ "\r\n"
+ "No suitable route found"sv;
+ }
+
+ asio::async_write(
+ *m_Socket.get(),
+ asio::buffer(Response),
+ [Conn = AsSharedPtr()](const asio::error_code& Ec, std::size_t ByteCount) { Conn->OnResponseDataSent(Ec, ByteCount); });
+ }
+}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// HttpRequest
+//
+
+http_parser_settings HttpRequest::s_ParserSettings{
+ .on_message_begin = [](http_parser* p) { return GetThis(p)->OnMessageBegin(); },
+ .on_url = [](http_parser* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnUrl(Data, ByteCount); },
+ .on_status =
+ [](http_parser* p, const char* Data, size_t ByteCount) {
+ ZEN_UNUSED(p, Data, ByteCount);
+ return 0;
+ },
+ .on_header_field = [](http_parser* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnHeader(Data, ByteCount); },
+ .on_header_value = [](http_parser* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnHeaderValue(Data, ByteCount); },
+ .on_headers_complete = [](http_parser* p) { return GetThis(p)->OnHeadersComplete(); },
+ .on_body = [](http_parser* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnBody(Data, ByteCount); },
+ .on_message_complete = [](http_parser* p) { return GetThis(p)->OnMessageComplete(); },
+ .on_chunk_header{},
+ .on_chunk_complete{}};
+
+void
+HttpRequest::Initialize()
+{
+ http_parser_init(&m_Parser, HTTP_REQUEST);
+ m_Parser.data = this;
+
+ ResetState();
+}
+
+size_t
+HttpRequest::ConsumeData(const char* InputData, size_t DataSize)
+{
+ const size_t ConsumedBytes = http_parser_execute(&m_Parser, &s_ParserSettings, InputData, DataSize);
+
+ http_errno HttpErrno = HTTP_PARSER_ERRNO((&m_Parser));
+
+ if (HttpErrno && HttpErrno != HPE_INVALID_EOF_STATE)
+ {
+ ZEN_WARN("HTTP parser error {} ('{}'). Closing connection", http_errno_name(HttpErrno), http_errno_description(HttpErrno));
+ return ~0ull;
+ }
+
+ return ConsumedBytes;
+}
+
+int
+HttpRequest::OnUrl(const char* Data, size_t Bytes)
+{
+ if (!m_Url)
+ {
+ ZEN_ASSERT_SLOW(m_UrlLength == 0);
+ m_Url = m_HeaderCursor;
+ }
+
+ const size_t RemainingBufferSpace = sizeof m_HeaderBuffer + m_HeaderBuffer - m_HeaderCursor;
+
+ if (RemainingBufferSpace < Bytes)
+ {
+ ZEN_WARN("HTTP parser does not have enough space for incoming request, need {} more bytes", Bytes - RemainingBufferSpace);
+ return 1;
+ }
+
+ memcpy(m_HeaderCursor, Data, Bytes);
+ m_HeaderCursor += Bytes;
+ m_UrlLength += Bytes;
+
+ return 0;
+}
+
+int
+HttpRequest::OnHeader(const char* Data, size_t Bytes)
+{
+ if (m_CurrentHeaderValueLength)
+ {
+ AppendCurrentHeader();
+
+ m_CurrentHeaderNameLength = 0;
+ m_CurrentHeaderValueLength = 0;
+ m_CurrentHeaderName = m_HeaderCursor;
+ }
+ else if (m_CurrentHeaderName == nullptr)
+ {
+ m_CurrentHeaderName = m_HeaderCursor;
+ }
+
+ const size_t RemainingBufferSpace = sizeof m_HeaderBuffer + m_HeaderBuffer - m_HeaderCursor;
+ if (RemainingBufferSpace < Bytes)
+ {
+ ZEN_WARN("HTTP parser does not have enough space for incoming header name, need {} more bytes", Bytes - RemainingBufferSpace);
+ return 1;
+ }
+
+ memcpy(m_HeaderCursor, Data, Bytes);
+ m_HeaderCursor += Bytes;
+ m_CurrentHeaderNameLength += Bytes;
+
+ return 0;
+}
+
+void
+HttpRequest::AppendCurrentHeader()
+{
+ std::string_view HeaderName(m_CurrentHeaderName, m_CurrentHeaderNameLength);
+ std::string_view HeaderValue(m_CurrentHeaderValue, m_CurrentHeaderValueLength);
+
+ const uint32_t HeaderHash = HashStringAsLowerDjb2(HeaderName);
+
+ if (HeaderHash == HashContentLength)
+ {
+ m_ContentLengthHeaderIndex = (int8_t)m_Headers.size();
+ }
+ else if (HeaderHash == HashAccept)
+ {
+ m_AcceptHeaderIndex = (int8_t)m_Headers.size();
+ }
+ else if (HeaderHash == HashContentType)
+ {
+ m_ContentTypeHeaderIndex = (int8_t)m_Headers.size();
+ }
+ else if (HeaderHash == HashSession)
+ {
+ m_SessionId = Oid::FromHexString(HeaderValue);
+ }
+ else if (HeaderHash == HashRequest)
+ {
+ std::from_chars(HeaderValue.data(), HeaderValue.data() + HeaderValue.size(), m_RequestId);
+ }
+ else if (HeaderHash == HashExpect)
+ {
+ if (HeaderValue == "100-continue"sv)
+ {
+ // We don't currently do anything with this
+ m_Expect100Continue = true;
+ }
+ else
+ {
+ ZEN_INFO("Unexpected expect - Expect: {}", HeaderValue);
+ }
+ }
+ else if (HeaderHash == HashRange)
+ {
+ m_RangeHeaderIndex = (int8_t)m_Headers.size();
+ }
+
+ m_Headers.emplace_back(HeaderName, HeaderValue);
+}
+
+int
+HttpRequest::OnHeaderValue(const char* Data, size_t Bytes)
+{
+ if (m_CurrentHeaderValueLength == 0)
+ {
+ m_CurrentHeaderValue = m_HeaderCursor;
+ }
+
+ const size_t RemainingBufferSpace = sizeof m_HeaderBuffer + m_HeaderBuffer - m_HeaderCursor;
+ if (RemainingBufferSpace < Bytes)
+ {
+ ZEN_WARN("HTTP parser does not have enough space for incoming header value, need {} more bytes", Bytes - RemainingBufferSpace);
+ return 1;
+ }
+
+ memcpy(m_HeaderCursor, Data, Bytes);
+ m_HeaderCursor += Bytes;
+ m_CurrentHeaderValueLength += Bytes;
+
+ return 0;
+}
+
+static void
+NormalizeUrlPath(const char* Url, size_t UrlLength, std::string& NormalizedUrl)
+{
+ bool LastCharWasSeparator = false;
+ for (std::string_view::size_type UrlIndex = 0; UrlIndex < UrlLength; ++UrlIndex)
+ {
+ const char UrlChar = Url[UrlIndex];
+ const bool IsSeparator = (UrlChar == '/');
+
+ if (IsSeparator && LastCharWasSeparator)
+ {
+ if (NormalizedUrl.empty())
+ {
+ NormalizedUrl.reserve(UrlLength);
+ NormalizedUrl.append(Url, UrlIndex);
+ }
+
+ if (!LastCharWasSeparator)
+ {
+ NormalizedUrl.push_back('/');
+ }
+ }
+ else if (!NormalizedUrl.empty())
+ {
+ NormalizedUrl.push_back(UrlChar);
+ }
+
+ LastCharWasSeparator = IsSeparator;
+ }
+}
+
+int
+HttpRequest::OnHeadersComplete()
+{
+ if (m_CurrentHeaderValueLength)
+ {
+ AppendCurrentHeader();
+ }
+
+ if (m_ContentLengthHeaderIndex >= 0)
+ {
+ std::string_view& Value = m_Headers[m_ContentLengthHeaderIndex].Value;
+ uint64_t ContentLength = 0;
+ std::from_chars(Value.data(), Value.data() + Value.size(), ContentLength);
+
+ if (ContentLength)
+ {
+ m_BodyBuffer = IoBuffer(ContentLength);
+ }
+
+ m_BodyBuffer.SetContentType(ContentType());
+
+ m_BodyPosition = 0;
+ }
+
+ m_KeepAlive = !!http_should_keep_alive(&m_Parser);
+
+ switch (m_Parser.method)
+ {
+ case HTTP_GET:
+ m_RequestVerb = HttpVerb::kGet;
+ break;
+
+ case HTTP_POST:
+ m_RequestVerb = HttpVerb::kPost;
+ break;
+
+ case HTTP_PUT:
+ m_RequestVerb = HttpVerb::kPut;
+ break;
+
+ case HTTP_DELETE:
+ m_RequestVerb = HttpVerb::kDelete;
+ break;
+
+ case HTTP_HEAD:
+ m_RequestVerb = HttpVerb::kHead;
+ break;
+
+ case HTTP_COPY:
+ m_RequestVerb = HttpVerb::kCopy;
+ break;
+
+ case HTTP_OPTIONS:
+ m_RequestVerb = HttpVerb::kOptions;
+ break;
+
+ default:
+ ZEN_WARN("invalid HTTP method: '{}'", http_method_str((http_method)m_Parser.method));
+ break;
+ }
+
+ std::string_view Url(m_Url, m_UrlLength);
+
+ if (auto QuerySplit = Url.find_first_of('?'); QuerySplit != std::string_view::npos)
+ {
+ m_UrlLength = QuerySplit;
+ m_QueryString = m_Url + QuerySplit + 1;
+ m_QueryLength = Url.size() - QuerySplit - 1;
+ }
+
+ NormalizeUrlPath(m_Url, m_UrlLength, m_NormalizedUrl);
+
+ return 0;
+}
+
+int
+HttpRequest::OnBody(const char* Data, size_t Bytes)
+{
+ if (m_BodyPosition + Bytes > m_BodyBuffer.Size())
+ {
+ ZEN_WARN("HTTP parser incoming body is larger than content size, need {} more bytes",
+ (m_BodyPosition + Bytes) - m_BodyBuffer.Size());
+ return 1;
+ }
+ memcpy(reinterpret_cast<uint8_t*>(m_BodyBuffer.MutableData()) + m_BodyPosition, Data, Bytes);
+ m_BodyPosition += Bytes;
+
+ if (http_body_is_final(&m_Parser))
+ {
+ if (m_BodyPosition != m_BodyBuffer.Size())
+ {
+ ZEN_WARN("Body mismatch! {} != {}", m_BodyPosition, m_BodyBuffer.Size());
+ return 1;
+ }
+ }
+
+ return 0;
+}
+
+void
+HttpRequest::ResetState()
+{
+ m_HeaderCursor = m_HeaderBuffer;
+ m_CurrentHeaderName = nullptr;
+ m_CurrentHeaderNameLength = 0;
+ m_CurrentHeaderValue = nullptr;
+ m_CurrentHeaderValueLength = 0;
+ m_CurrentHeaderName = nullptr;
+ m_Url = nullptr;
+ m_UrlLength = 0;
+ m_QueryString = nullptr;
+ m_QueryLength = 0;
+ m_ContentLengthHeaderIndex = -1;
+ m_AcceptHeaderIndex = -1;
+ m_ContentTypeHeaderIndex = -1;
+ m_RangeHeaderIndex = -1;
+ m_Expect100Continue = false;
+ m_BodyBuffer = {};
+ m_BodyPosition = 0;
+ m_Headers.clear();
+ m_NormalizedUrl.clear();
+}
+
+int
+HttpRequest::OnMessageBegin()
+{
+ return 0;
+}
+
+int
+HttpRequest::OnMessageComplete()
+{
+ m_Connection.HandleRequest();
+
+ ResetState();
+
+ return 0;
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+struct HttpAcceptor
+{
+ HttpAcceptor(HttpAsioServerImpl& Server, asio::io_service& IoService, uint16_t BasePort)
+ : m_Server(Server)
+ , m_IoService(IoService)
+ , m_Acceptor(m_IoService, asio::ip::tcp::v6())
+ {
+ m_Acceptor.set_option(asio::ip::v6_only(false));
+#if ZEN_PLATFORM_WINDOWS
+ // Special option for Windows settings as !asio::socket_base::reuse_address is not the same as exclusive access on Windows platforms
+ typedef asio::detail::socket_option::boolean<ASIO_OS_DEF(SOL_SOCKET), SO_EXCLUSIVEADDRUSE> excluse_address;
+ m_Acceptor.set_option(excluse_address(true));
+#else // ZEN_PLATFORM_WINDOWS
+ m_Acceptor.set_option(asio::socket_base::reuse_address(false));
+#endif // ZEN_PLATFORM_WINDOWS
+
+ m_Acceptor.set_option(asio::ip::tcp::no_delay(true));
+ m_Acceptor.set_option(asio::socket_base::receive_buffer_size(128 * 1024));
+ m_Acceptor.set_option(asio::socket_base::send_buffer_size(256 * 1024));
+
+ uint16_t EffectivePort = BasePort;
+
+ asio::error_code BindErrorCode;
+ m_Acceptor.bind(asio::ip::tcp::endpoint(asio::ip::address_v6::any(), EffectivePort), BindErrorCode);
+ // Sharing violation implies the port is being used by another process
+ for (uint16_t PortOffset = 1; (BindErrorCode == asio::error::address_in_use) && (PortOffset < 10); ++PortOffset)
+ {
+ EffectivePort = BasePort + (PortOffset * 100);
+ m_Acceptor.bind(asio::ip::tcp::endpoint(asio::ip::address_v6::any(), EffectivePort), BindErrorCode);
+ }
+ if (BindErrorCode == asio::error::access_denied)
+ {
+ EffectivePort = 0;
+ m_Acceptor.bind(asio::ip::tcp::endpoint(asio::ip::address_v6::any(), EffectivePort), BindErrorCode);
+ }
+ if (BindErrorCode)
+ {
+ ZEN_ERROR("Unable open asio service, error '{}'", BindErrorCode.message());
+ }
+
+#if ZEN_PLATFORM_WINDOWS
+ // On Windows, loopback connections can take advantage of a faster code path optionally with this flag.
+ // This must be used by both the client and server side, and is only effective in the absence of
+ // Windows Filtering Platform (WFP) callouts which can be installed by security software.
+ // https://docs.microsoft.com/en-us/windows/win32/winsock/sio-loopback-fast-path
+ SOCKET NativeSocket = m_Acceptor.native_handle();
+ int LoopbackOptionValue = 1;
+ DWORD OptionNumberOfBytesReturned = 0;
+ WSAIoctl(NativeSocket,
+ SIO_LOOPBACK_FAST_PATH,
+ &LoopbackOptionValue,
+ sizeof(LoopbackOptionValue),
+ NULL,
+ 0,
+ &OptionNumberOfBytesReturned,
+ 0,
+ 0);
+#endif
+ m_Acceptor.listen();
+
+ ZEN_INFO("Started asio server at port '{}'", EffectivePort);
+ }
+
+ void Start()
+ {
+ m_Acceptor.listen();
+ InitAccept();
+ }
+
+ void Stop() { m_IsStopped = true; }
+
+ void InitAccept()
+ {
+ auto SocketPtr = std::make_unique<asio::ip::tcp::socket>(m_IoService);
+ asio::ip::tcp::socket& SocketRef = *SocketPtr.get();
+
+ m_Acceptor.async_accept(SocketRef, [this, Socket = std::move(SocketPtr)](const asio::error_code& Ec) mutable {
+ if (Ec)
+ {
+ ZEN_WARN("asio async_accept, connection failed to '{}:{}' reason '{}'",
+ m_Acceptor.local_endpoint().address().to_string(),
+ m_Acceptor.local_endpoint().port(),
+ Ec.message());
+ }
+ else
+ {
+ // New connection established, pass socket ownership into connection object
+ // and initiate request handling loop. The connection lifetime is
+ // managed by the async read/write loop by passing the shared
+ // reference to the callbacks.
+
+ Socket->set_option(asio::ip::tcp::no_delay(true));
+ Socket->set_option(asio::socket_base::receive_buffer_size(128 * 1024));
+ Socket->set_option(asio::socket_base::send_buffer_size(256 * 1024));
+
+ auto Conn = std::make_shared<HttpServerConnection>(m_Server, std::move(Socket));
+ Conn->HandleNewRequest();
+ }
+
+ if (!m_IsStopped.load())
+ {
+ InitAccept();
+ }
+ else
+ {
+ m_Acceptor.close();
+ }
+ });
+ }
+
+ int GetAcceptPort() { return m_Acceptor.local_endpoint().port(); }
+
+private:
+ HttpAsioServerImpl& m_Server;
+ asio::io_service& m_IoService;
+ asio::ip::tcp::acceptor m_Acceptor;
+ std::atomic<bool> m_IsStopped{false};
+};
+
+//////////////////////////////////////////////////////////////////////////
+
+HttpAsioServerRequest::HttpAsioServerRequest(asio_http::HttpRequest& Request, HttpService& Service, IoBuffer PayloadBuffer)
+: m_Request(Request)
+, m_PayloadBuffer(std::move(PayloadBuffer))
+{
+ const int PrefixLength = Service.UriPrefixLength();
+
+ std::string_view Uri = Request.Url();
+ Uri.remove_prefix(std::min(PrefixLength, static_cast<int>(Uri.size())));
+ m_Uri = Uri;
+ m_UriWithExtension = Uri;
+ m_QueryString = Request.QueryString();
+
+ m_Verb = Request.RequestVerb();
+ m_ContentLength = Request.Body().Size();
+ m_ContentType = Request.ContentType();
+
+ HttpContentType AcceptContentType = HttpContentType::kUnknownContentType;
+
+ // Parse any extension, to allow requesting a particular response encoding via the URL
+
+ {
+ std::string_view UriSuffix8{m_Uri};
+
+ const size_t LastComponentIndex = UriSuffix8.find_last_of('/');
+
+ if (LastComponentIndex != std::string_view::npos)
+ {
+ UriSuffix8.remove_prefix(LastComponentIndex);
+ }
+
+ const size_t LastDotIndex = UriSuffix8.find_last_of('.');
+
+ if (LastDotIndex != std::string_view::npos)
+ {
+ UriSuffix8.remove_prefix(LastDotIndex + 1);
+
+ AcceptContentType = ParseContentType(UriSuffix8);
+
+ if (AcceptContentType != HttpContentType::kUnknownContentType)
+ {
+ m_Uri.remove_suffix(uint32_t(UriSuffix8.size() + 1));
+ }
+ }
+ }
+
+ // It an explicit content type extension was specified then we'll use that over any
+ // Accept: header value that may be present
+
+ if (AcceptContentType != HttpContentType::kUnknownContentType)
+ {
+ m_AcceptType = AcceptContentType;
+ }
+ else
+ {
+ m_AcceptType = Request.AcceptType();
+ }
+}
+
+HttpAsioServerRequest::~HttpAsioServerRequest()
+{
+}
+
+Oid
+HttpAsioServerRequest::ParseSessionId() const
+{
+ return m_Request.SessionId();
+}
+
+uint32_t
+HttpAsioServerRequest::ParseRequestId() const
+{
+ return m_Request.RequestId();
+}
+
+IoBuffer
+HttpAsioServerRequest::ReadPayload()
+{
+ return m_PayloadBuffer;
+}
+
+void
+HttpAsioServerRequest::WriteResponse(HttpResponseCode ResponseCode)
+{
+ ZEN_ASSERT(!m_Response);
+
+ m_Response.reset(new HttpResponse(HttpContentType::kBinary));
+ std::array<IoBuffer, 0> Empty;
+
+ m_Response->InitializeForPayload((uint16_t)ResponseCode, Empty);
+}
+
+void
+HttpAsioServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span<IoBuffer> Blobs)
+{
+ ZEN_ASSERT(!m_Response);
+
+ m_Response.reset(new HttpResponse(ContentType));
+ m_Response->InitializeForPayload((uint16_t)ResponseCode, Blobs);
+}
+
+void
+HttpAsioServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::u8string_view ResponseString)
+{
+ ZEN_ASSERT(!m_Response);
+ m_Response.reset(new HttpResponse(ContentType));
+
+ IoBuffer MessageBuffer(IoBuffer::Wrap, ResponseString.data(), ResponseString.size());
+ std::array<IoBuffer, 1> SingleBufferList({MessageBuffer});
+
+ m_Response->InitializeForPayload((uint16_t)ResponseCode, SingleBufferList);
+}
+
+void
+HttpAsioServerRequest::WriteResponseAsync(std::function<void(HttpServerRequest&)>&& ContinuationHandler)
+{
+ ZEN_ASSERT(!m_Response);
+
+ // Not one bit async, innit
+ ContinuationHandler(*this);
+}
+
+bool
+HttpAsioServerRequest::TryGetRanges(HttpRanges& Ranges)
+{
+ return TryParseHttpRangeHeader(m_Request.RangeHeader(), Ranges);
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+HttpAsioServerImpl::HttpAsioServerImpl()
+{
+}
+
+HttpAsioServerImpl::~HttpAsioServerImpl()
+{
+}
+
+int
+HttpAsioServerImpl::Start(uint16_t Port, int ThreadCount)
+{
+ ZEN_ASSERT(ThreadCount > 0);
+
+ ZEN_INFO("starting asio http with {} service threads", ThreadCount);
+
+ m_Acceptor.reset(new asio_http::HttpAcceptor(*this, m_IoService, Port));
+ m_Acceptor->Start();
+
+ for (int i = 0; i < ThreadCount; ++i)
+ {
+ m_ThreadPool.emplace_back([this, Index = i + 1] {
+ SetCurrentThreadName(fmt::format("asio worker {}", Index));
+
+ try
+ {
+ m_IoService.run();
+ }
+ catch (std::exception& e)
+ {
+ ZEN_ERROR("Exception caught in asio event loop: '{}'", e.what());
+ }
+ });
+ }
+
+ ZEN_INFO("asio http started (port {})", m_Acceptor->GetAcceptPort());
+
+ return m_Acceptor->GetAcceptPort();
+}
+
+void
+HttpAsioServerImpl::Stop()
+{
+ m_Acceptor->Stop();
+ m_IoService.stop();
+ for (auto& Thread : m_ThreadPool)
+ {
+ Thread.join();
+ }
+}
+
+void
+HttpAsioServerImpl::RegisterService(const char* InUrlPath, HttpService& Service)
+{
+ std::string_view UrlPath(InUrlPath);
+ Service.SetUriPrefixLength(UrlPath.size());
+ if (!UrlPath.empty() && UrlPath.back() == '/')
+ {
+ UrlPath.remove_suffix(1);
+ }
+
+ RwLock::ExclusiveLockScope _(m_Lock);
+ m_UriHandlers.push_back({std::string(UrlPath), &Service});
+}
+
+HttpService*
+HttpAsioServerImpl::RouteRequest(std::string_view Url)
+{
+ RwLock::SharedLockScope _(m_Lock);
+
+ HttpService* CandidateService = nullptr;
+ std::string::size_type CandidateMatchSize = 0;
+ for (const ServiceEntry& SvcEntry : m_UriHandlers)
+ {
+ const std::string& SvcUrl = SvcEntry.ServiceUrlPath;
+ const std::string::size_type SvcUrlSize = SvcUrl.size();
+ if ((SvcUrlSize >= CandidateMatchSize) && Url.compare(0, SvcUrlSize, SvcUrl) == 0 &&
+ ((SvcUrlSize == Url.size()) || (Url[SvcUrlSize] == '/')))
+ {
+ CandidateMatchSize = SvcUrl.size();
+ CandidateService = SvcEntry.Service;
+ }
+ }
+
+ return CandidateService;
+}
+
+} // namespace zen::asio_http
+
+//////////////////////////////////////////////////////////////////////////
+
+namespace zen {
+HttpAsioServer::HttpAsioServer() : m_Impl(std::make_unique<asio_http::HttpAsioServerImpl>())
+{
+ ZEN_DEBUG("Request object size: {} ({:#x})", sizeof(asio_http::HttpRequest), sizeof(asio_http::HttpRequest));
+}
+
+HttpAsioServer::~HttpAsioServer()
+{
+ try
+ {
+ m_Impl->Stop();
+ }
+ catch (std::exception& ex)
+ {
+ ZEN_WARN("Caught exception stopping http asio server: {}", ex.what());
+ }
+}
+
+void
+HttpAsioServer::RegisterService(HttpService& Service)
+{
+ m_Impl->RegisterService(Service.BaseUri(), Service);
+}
+
+int
+HttpAsioServer::Initialize(int BasePort)
+{
+ m_BasePort = m_Impl->Start(gsl::narrow<uint16_t>(BasePort), Max(std::thread::hardware_concurrency(), 8u));
+ return m_BasePort;
+}
+
+void
+HttpAsioServer::Run(bool IsInteractive)
+{
+ const bool TestMode = !IsInteractive;
+
+ int WaitTimeout = -1;
+ if (!TestMode)
+ {
+ WaitTimeout = 1000;
+ }
+
+#if ZEN_PLATFORM_WINDOWS
+ if (TestMode == false)
+ {
+ zen::logging::ConsoleLog().info("Zen Server running (asio HTTP). Press ESC or Q to quit");
+ }
+
+ do
+ {
+ if (!TestMode && _kbhit() != 0)
+ {
+ char c = (char)_getch();
+
+ if (c == 27 || c == 'Q' || c == 'q')
+ {
+ RequestApplicationExit(0);
+ }
+ }
+
+ m_ShutdownEvent.Wait(WaitTimeout);
+ } while (!IsApplicationExitRequested());
+#else
+ if (TestMode == false)
+ {
+ zen::logging::ConsoleLog().info("Zen Server running (asio HTTP). Ctrl-C to quit");
+ }
+
+ do
+ {
+ m_ShutdownEvent.Wait(WaitTimeout);
+ } while (!IsApplicationExitRequested());
+#endif
+}
+
+void
+HttpAsioServer::RequestExit()
+{
+ m_ShutdownEvent.Set();
+}
+
+} // namespace zen
diff --git a/src/zenhttp/httpasio.h b/src/zenhttp/httpasio.h
new file mode 100644
index 000000000..716145955
--- /dev/null
+++ b/src/zenhttp/httpasio.h
@@ -0,0 +1,36 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/thread.h>
+#include <zenhttp/httpserver.h>
+
+#include <memory>
+
+namespace zen {
+
+namespace asio_http {
+ struct HttpServerConnection;
+ struct HttpAcceptor;
+ struct HttpAsioServerImpl;
+} // namespace asio_http
+
+class HttpAsioServer : public HttpServer
+{
+public:
+ HttpAsioServer();
+ ~HttpAsioServer();
+
+ virtual void RegisterService(HttpService& Service) override;
+ virtual int Initialize(int BasePort) override;
+ virtual void Run(bool IsInteractiveSession) override;
+ virtual void RequestExit() override;
+
+private:
+ Event m_ShutdownEvent;
+ int m_BasePort = 0;
+
+ std::unique_ptr<asio_http::HttpAsioServerImpl> m_Impl;
+};
+
+} // namespace zen
diff --git a/src/zenhttp/httpclient.cpp b/src/zenhttp/httpclient.cpp
new file mode 100644
index 000000000..e6813d407
--- /dev/null
+++ b/src/zenhttp/httpclient.cpp
@@ -0,0 +1,176 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zenhttp/httpclient.h>
+#include <zenhttp/httpserver.h>
+
+#include <zencore/compactbinarybuilder.h>
+#include <zencore/compactbinarypackage.h>
+#include <zencore/iobuffer.h>
+#include <zencore/logging.h>
+#include <zencore/session.h>
+#include <zencore/sharedbuffer.h>
+#include <zencore/stream.h>
+#include <zencore/testing.h>
+#include <zenhttp/httpshared.h>
+
+static std::atomic<uint32_t> HttpClientRequestIdCounter{0};
+
+namespace zen {
+
+using namespace std::literals;
+
+HttpClient::Response
+FromCprResponse(cpr::Response& InResponse)
+{
+ return {.StatusCode = int(InResponse.status_code)};
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+HttpClient::HttpClient(std::string_view BaseUri) : m_BaseUri(BaseUri)
+{
+ StringBuilder<32> SessionId;
+ GetSessionId().ToString(SessionId);
+ m_SessionId = SessionId;
+}
+
+HttpClient::~HttpClient()
+{
+}
+
+HttpClient::Response
+HttpClient::TransactPackage(std::string_view Url, CbPackage Package)
+{
+ cpr::Session Sess;
+ Sess.SetUrl(m_BaseUri + std::string(Url));
+
+ // First, list of offered chunks for filtering on the server end
+
+ std::vector<IoHash> AttachmentsToSend;
+ std::span<const CbAttachment> Attachments = Package.GetAttachments();
+
+ const uint32_t RequestId = ++HttpClientRequestIdCounter;
+ auto RequestIdString = fmt::to_string(RequestId);
+
+ if (Attachments.empty() == false)
+ {
+ CbObjectWriter Writer;
+ Writer.BeginArray("offer");
+
+ for (const CbAttachment& Attachment : Attachments)
+ {
+ IoHash Hash = Attachment.GetHash();
+
+ Writer.AddHash(Hash);
+ }
+
+ Writer.EndArray();
+
+ BinaryWriter MemWriter;
+ Writer.Save(MemWriter);
+
+ Sess.SetHeader({{"Content-Type", "application/x-ue-offer"}, {"UE-Session", m_SessionId}, {"UE-Request", RequestIdString}});
+ Sess.SetBody(cpr::Body{(const char*)MemWriter.Data(), MemWriter.Size()});
+
+ cpr::Response FilterResponse = Sess.Post();
+
+ if (FilterResponse.status_code == 200)
+ {
+ IoBuffer ResponseBuffer(IoBuffer::Wrap, FilterResponse.text.data(), FilterResponse.text.size());
+ CbObject ResponseObject = LoadCompactBinaryObject(ResponseBuffer);
+
+ for (auto& Entry : ResponseObject["need"])
+ {
+ ZEN_ASSERT(Entry.IsHash());
+ AttachmentsToSend.push_back(Entry.AsHash());
+ }
+ }
+ }
+
+ // Prepare package for send
+
+ CbPackage SendPackage;
+ SendPackage.SetObject(Package.GetObject(), Package.GetObjectHash());
+
+ for (const IoHash& AttachmentCid : AttachmentsToSend)
+ {
+ const CbAttachment* Attachment = Package.FindAttachment(AttachmentCid);
+
+ if (Attachment)
+ {
+ SendPackage.AddAttachment(*Attachment);
+ }
+ else
+ {
+ // This should be an error -- server asked to have something we can't find
+ }
+ }
+
+ // Transmit package payload
+
+ CompositeBuffer Message = FormatPackageMessageBuffer(SendPackage);
+ SharedBuffer FlatMessage = Message.Flatten();
+
+ Sess.SetHeader({{"Content-Type", "application/x-ue-cbpkg"}, {"UE-Session", m_SessionId}, {"UE-Request", RequestIdString}});
+ Sess.SetBody(cpr::Body{(const char*)FlatMessage.GetData(), FlatMessage.GetSize()});
+
+ cpr::Response FilterResponse = Sess.Post();
+
+ if (!IsHttpSuccessCode(FilterResponse.status_code))
+ {
+ return FromCprResponse(FilterResponse);
+ }
+
+ IoBuffer ResponseBuffer(IoBuffer::Clone, FilterResponse.text.data(), FilterResponse.text.size());
+
+ if (auto It = FilterResponse.header.find("Content-Type"); It != FilterResponse.header.end())
+ {
+ HttpContentType ContentType = ParseContentType(It->second);
+
+ ResponseBuffer.SetContentType(ContentType);
+ }
+
+ return {.StatusCode = int(FilterResponse.status_code), .ResponsePayload = ResponseBuffer};
+}
+
+HttpClient::Response
+HttpClient::Put(std::string_view Url, IoBuffer Payload)
+{
+ ZEN_UNUSED(Url);
+ ZEN_UNUSED(Payload);
+ return {};
+}
+
+HttpClient::Response
+HttpClient::Get(std::string_view Url)
+{
+ ZEN_UNUSED(Url);
+ return {};
+}
+
+HttpClient::Response
+HttpClient::Delete(std::string_view Url)
+{
+ ZEN_UNUSED(Url);
+ return {};
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+#if ZEN_WITH_TESTS
+
+TEST_CASE("httpclient")
+{
+ using namespace std::literals;
+
+ SUBCASE("client") {}
+}
+
+void
+httpclient_forcelink()
+{
+}
+
+#endif
+
+} // namespace zen
diff --git a/src/zenhttp/httpnull.cpp b/src/zenhttp/httpnull.cpp
new file mode 100644
index 000000000..a6e1d3567
--- /dev/null
+++ b/src/zenhttp/httpnull.cpp
@@ -0,0 +1,83 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "httpnull.h"
+
+#include <zencore/logging.h>
+
+#if ZEN_PLATFORM_WINDOWS
+# include <conio.h>
+#endif
+
+namespace zen {
+
+HttpNullServer::HttpNullServer()
+{
+}
+
+HttpNullServer::~HttpNullServer()
+{
+}
+
+void
+HttpNullServer::RegisterService(HttpService& Service)
+{
+ ZEN_UNUSED(Service);
+}
+
+int
+HttpNullServer::Initialize(int BasePort)
+{
+ return BasePort;
+}
+
+void
+HttpNullServer::Run(bool IsInteractiveSession)
+{
+ const bool TestMode = !IsInteractiveSession;
+
+ int WaitTimeout = -1;
+ if (!TestMode)
+ {
+ WaitTimeout = 1000;
+ }
+
+#if ZEN_PLATFORM_WINDOWS
+ if (TestMode == false)
+ {
+ zen::logging::ConsoleLog().info("Zen Server running (null HTTP). Press ESC or Q to quit");
+ }
+
+ do
+ {
+ if (!TestMode && _kbhit() != 0)
+ {
+ char c = (char)_getch();
+
+ if (c == 27 || c == 'Q' || c == 'q')
+ {
+ RequestApplicationExit(0);
+ }
+ }
+
+ m_ShutdownEvent.Wait(WaitTimeout);
+ } while (!IsApplicationExitRequested());
+#else
+ if (TestMode == false)
+ {
+ zen::logging::ConsoleLog().info("Zen Server running (null HTTP). Ctrl-C to quit");
+ }
+
+ do
+ {
+ m_ShutdownEvent.Wait(WaitTimeout);
+ } while (!IsApplicationExitRequested());
+#endif
+}
+
+void
+HttpNullServer::RequestExit()
+{
+ m_ShutdownEvent.Set();
+}
+
+} // namespace zen
diff --git a/src/zenhttp/httpnull.h b/src/zenhttp/httpnull.h
new file mode 100644
index 000000000..74f021f6b
--- /dev/null
+++ b/src/zenhttp/httpnull.h
@@ -0,0 +1,29 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/thread.h>
+#include <zenhttp/httpserver.h>
+
+namespace zen {
+
+/**
+ * @brief Null implementation of "http" server. Does nothing
+ */
+
+class HttpNullServer : public HttpServer
+{
+public:
+ HttpNullServer();
+ ~HttpNullServer();
+
+ virtual void RegisterService(HttpService& Service) override;
+ virtual int Initialize(int BasePort) override;
+ virtual void Run(bool IsInteractiveSession) override;
+ virtual void RequestExit() override;
+
+private:
+ Event m_ShutdownEvent;
+};
+
+} // namespace zen
diff --git a/src/zenhttp/httpserver.cpp b/src/zenhttp/httpserver.cpp
new file mode 100644
index 000000000..671cbd319
--- /dev/null
+++ b/src/zenhttp/httpserver.cpp
@@ -0,0 +1,885 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zenhttp/httpserver.h>
+
+#include "httpasio.h"
+#include "httpnull.h"
+#include "httpsys.h"
+
+#include <zencore/compactbinary.h>
+#include <zencore/compactbinarybuilder.h>
+#include <zencore/compactbinarypackage.h>
+#include <zencore/iobuffer.h>
+#include <zencore/logging.h>
+#include <zencore/refcount.h>
+#include <zencore/stream.h>
+#include <zencore/string.h>
+#include <zencore/testing.h>
+#include <zencore/thread.h>
+#include <zenhttp/httpshared.h>
+
+#include <charconv>
+#include <mutex>
+#include <span>
+#include <string_view>
+
+namespace zen {
+
+using namespace std::literals;
+
+std::string_view
+MapContentTypeToString(HttpContentType ContentType)
+{
+ switch (ContentType)
+ {
+ default:
+ case HttpContentType::kUnknownContentType:
+ case HttpContentType::kBinary:
+ return "application/octet-stream"sv;
+
+ case HttpContentType::kText:
+ return "text/plain"sv;
+
+ case HttpContentType::kJSON:
+ return "application/json"sv;
+
+ case HttpContentType::kCbObject:
+ return "application/x-ue-cb"sv;
+
+ case HttpContentType::kCbPackage:
+ return "application/x-ue-cbpkg"sv;
+
+ case HttpContentType::kCbPackageOffer:
+ return "application/x-ue-offer"sv;
+
+ case HttpContentType::kCompressedBinary:
+ return "application/x-ue-comp"sv;
+
+ case HttpContentType::kYAML:
+ return "text/yaml"sv;
+
+ case HttpContentType::kHTML:
+ return "text/html"sv;
+
+ case HttpContentType::kJavaScript:
+ return "application/javascript"sv;
+
+ case HttpContentType::kCSS:
+ return "text/css"sv;
+
+ case HttpContentType::kPNG:
+ return "image/png"sv;
+
+ case HttpContentType::kIcon:
+ return "image/x-icon"sv;
+ }
+}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Note that in addition to MIME types we accept abbreviated versions, for
+// use in suffix parsing as well as for convenience when using curl
+
+static constinit uint32_t HashBinary = HashStringDjb2("application/octet-stream"sv);
+static constinit uint32_t HashJson = HashStringDjb2("json"sv);
+static constinit uint32_t HashApplicationJson = HashStringDjb2("application/json"sv);
+static constinit uint32_t HashYaml = HashStringDjb2("yaml"sv);
+static constinit uint32_t HashTextYaml = HashStringDjb2("text/yaml"sv);
+static constinit uint32_t HashText = HashStringDjb2("text/plain"sv);
+static constinit uint32_t HashApplicationCompactBinary = HashStringDjb2("application/x-ue-cb"sv);
+static constinit uint32_t HashCompactBinary = HashStringDjb2("ucb"sv);
+static constinit uint32_t HashCompactBinaryPackage = HashStringDjb2("application/x-ue-cbpkg"sv);
+static constinit uint32_t HashCompactBinaryPackageShort = HashStringDjb2("cbpkg"sv);
+static constinit uint32_t HashCompactBinaryPackageOffer = HashStringDjb2("application/x-ue-offer"sv);
+static constinit uint32_t HashCompressedBinary = HashStringDjb2("application/x-ue-comp"sv);
+static constinit uint32_t HashHtml = HashStringDjb2("html"sv);
+static constinit uint32_t HashTextHtml = HashStringDjb2("text/html"sv);
+static constinit uint32_t HashJavaScript = HashStringDjb2("js"sv);
+static constinit uint32_t HashApplicationJavaScript = HashStringDjb2("application/javascript"sv);
+static constinit uint32_t HashCss = HashStringDjb2("css"sv);
+static constinit uint32_t HashTextCss = HashStringDjb2("text/css"sv);
+static constinit uint32_t HashPng = HashStringDjb2("png"sv);
+static constinit uint32_t HashImagePng = HashStringDjb2("image/png"sv);
+static constinit uint32_t HashIcon = HashStringDjb2("ico"sv);
+static constinit uint32_t HashImageIcon = HashStringDjb2("image/x-icon"sv);
+
+std::once_flag InitContentTypeLookup;
+
+struct HashedTypeEntry
+{
+ uint32_t Hash;
+ HttpContentType Type;
+} TypeHashTable[] = {
+ // clang-format off
+ {HashBinary, HttpContentType::kBinary},
+ {HashApplicationCompactBinary, HttpContentType::kCbObject},
+ {HashCompactBinary, HttpContentType::kCbObject},
+ {HashCompactBinaryPackage, HttpContentType::kCbPackage},
+ {HashCompactBinaryPackageShort, HttpContentType::kCbPackage},
+ {HashCompactBinaryPackageOffer, HttpContentType::kCbPackageOffer},
+ {HashJson, HttpContentType::kJSON},
+ {HashApplicationJson, HttpContentType::kJSON},
+ {HashYaml, HttpContentType::kYAML},
+ {HashTextYaml, HttpContentType::kYAML},
+ {HashText, HttpContentType::kText},
+ {HashCompressedBinary, HttpContentType::kCompressedBinary},
+ {HashHtml, HttpContentType::kHTML},
+ {HashTextHtml, HttpContentType::kHTML},
+ {HashJavaScript, HttpContentType::kJavaScript},
+ {HashApplicationJavaScript, HttpContentType::kJavaScript},
+ {HashCss, HttpContentType::kCSS},
+ {HashTextCss, HttpContentType::kCSS},
+ {HashPng, HttpContentType::kPNG},
+ {HashImagePng, HttpContentType::kPNG},
+ {HashIcon, HttpContentType::kIcon},
+ {HashImageIcon, HttpContentType::kIcon},
+ // clang-format on
+};
+
+HttpContentType
+ParseContentTypeImpl(const std::string_view& ContentTypeString)
+{
+ if (!ContentTypeString.empty())
+ {
+ const uint32_t CtHash = HashStringDjb2(ContentTypeString);
+
+ if (auto It = std::lower_bound(std::begin(TypeHashTable),
+ std::end(TypeHashTable),
+ CtHash,
+ [](const HashedTypeEntry& Lhs, const uint32_t Rhs) { return Lhs.Hash < Rhs; });
+ It != std::end(TypeHashTable))
+ {
+ if (It->Hash == CtHash)
+ {
+ return It->Type;
+ }
+ }
+ }
+
+ return HttpContentType::kUnknownContentType;
+}
+
+HttpContentType
+ParseContentTypeInit(const std::string_view& ContentTypeString)
+{
+ std::call_once(InitContentTypeLookup, [] {
+ std::sort(std::begin(TypeHashTable), std::end(TypeHashTable), [](const HashedTypeEntry& Lhs, const HashedTypeEntry& Rhs) {
+ return Lhs.Hash < Rhs.Hash;
+ });
+
+ // validate that there are no hash collisions
+
+ uint32_t LastHash = 0;
+
+ for (const auto& Item : TypeHashTable)
+ {
+ ZEN_ASSERT(LastHash != Item.Hash);
+ LastHash = Item.Hash;
+ }
+ });
+
+ ParseContentType = ParseContentTypeImpl;
+
+ return ParseContentTypeImpl(ContentTypeString);
+}
+
+HttpContentType (*ParseContentType)(const std::string_view& ContentTypeString) = &ParseContentTypeInit;
+
+bool
+TryParseHttpRangeHeader(std::string_view RangeHeader, HttpRanges& Ranges)
+{
+ if (RangeHeader.empty())
+ {
+ return false;
+ }
+
+ const size_t Count = Ranges.size();
+
+ std::size_t UnitDelim = RangeHeader.find_first_of('=');
+ if (UnitDelim == std::string_view::npos)
+ {
+ return false;
+ }
+
+ // only bytes for now
+ std::string_view Unit = RangeHeader.substr(0, UnitDelim);
+ if (Unit != "bytes"sv)
+ {
+ return false;
+ }
+
+ std::string_view Tokens = RangeHeader.substr(UnitDelim);
+ while (!Tokens.empty())
+ {
+ // Skip =,
+ Tokens = Tokens.substr(1);
+
+ size_t Delim = Tokens.find_first_of(',');
+ if (Delim == std::string_view::npos)
+ {
+ Delim = Tokens.length();
+ }
+
+ std::string_view Token = Tokens.substr(0, Delim);
+ Tokens = Tokens.substr(Delim);
+
+ Delim = Token.find_first_of('-');
+ if (Delim == std::string_view::npos)
+ {
+ return false;
+ }
+
+ const auto Start = ParseInt<uint32_t>(Token.substr(0, Delim));
+ const auto End = ParseInt<uint32_t>(Token.substr(Delim + 1));
+
+ if (Start.has_value() && End.has_value() && End.value() > Start.value())
+ {
+ Ranges.push_back({.Start = Start.value(), .End = End.value()});
+ }
+ else if (Start)
+ {
+ Ranges.push_back({.Start = Start.value()});
+ }
+ else if (End)
+ {
+ Ranges.push_back({.End = End.value()});
+ }
+ }
+
+ return Count != Ranges.size();
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+const std::string_view
+ToString(HttpVerb Verb)
+{
+ switch (Verb)
+ {
+ case HttpVerb::kGet:
+ return "GET"sv;
+ case HttpVerb::kPut:
+ return "PUT"sv;
+ case HttpVerb::kPost:
+ return "POST"sv;
+ case HttpVerb::kDelete:
+ return "DELETE"sv;
+ case HttpVerb::kHead:
+ return "HEAD"sv;
+ case HttpVerb::kCopy:
+ return "COPY"sv;
+ case HttpVerb::kOptions:
+ return "OPTIONS"sv;
+ default:
+ return "???"sv;
+ }
+}
+
+std::string_view
+ReasonStringForHttpResultCode(int HttpCode)
+{
+ switch (HttpCode)
+ {
+ // 1xx Informational
+
+ case 100:
+ return "Continue"sv;
+ case 101:
+ return "Switching Protocols"sv;
+
+ // 2xx Success
+
+ case 200:
+ return "OK"sv;
+ case 201:
+ return "Created"sv;
+ case 202:
+ return "Accepted"sv;
+ case 204:
+ return "No Content"sv;
+ case 205:
+ return "Reset Content"sv;
+ case 206:
+ return "Partial Content"sv;
+
+ // 3xx Redirection
+
+ case 300:
+ return "Multiple Choices"sv;
+ case 301:
+ return "Moved Permanently"sv;
+ case 302:
+ return "Found"sv;
+ case 303:
+ return "See Other"sv;
+ case 304:
+ return "Not Modified"sv;
+ case 305:
+ return "Use Proxy"sv;
+ case 306:
+ return "Switch Proxy"sv;
+ case 307:
+ return "Temporary Redirect"sv;
+ case 308:
+ return "Permanent Redirect"sv;
+
+ // 4xx Client errors
+
+ case 400:
+ return "Bad Request"sv;
+ case 401:
+ return "Unauthorized"sv;
+ case 402:
+ return "Payment Required"sv;
+ case 403:
+ return "Forbidden"sv;
+ case 404:
+ return "Not Found"sv;
+ case 405:
+ return "Method Not Allowed"sv;
+ case 406:
+ return "Not Acceptable"sv;
+ case 407:
+ return "Proxy Authentication Required"sv;
+ case 408:
+ return "Request Timeout"sv;
+ case 409:
+ return "Conflict"sv;
+ case 410:
+ return "Gone"sv;
+ case 411:
+ return "Length Required"sv;
+ case 412:
+ return "Precondition Failed"sv;
+ case 413:
+ return "Payload Too Large"sv;
+ case 414:
+ return "URI Too Long"sv;
+ case 415:
+ return "Unsupported Media Type"sv;
+ case 416:
+ return "Range Not Satisifiable"sv;
+ case 417:
+ return "Expectation Failed"sv;
+ case 418:
+ return "I'm a teapot"sv;
+ case 421:
+ return "Misdirected Request"sv;
+ case 422:
+ return "Unprocessable Entity"sv;
+ case 423:
+ return "Locked"sv;
+ case 424:
+ return "Failed Dependency"sv;
+ case 425:
+ return "Too Early"sv;
+ case 426:
+ return "Upgrade Required"sv;
+ case 428:
+ return "Precondition Required"sv;
+ case 429:
+ return "Too Many Requests"sv;
+ case 431:
+ return "Request Header Fields Too Large"sv;
+
+ // 5xx Server errors
+
+ case 500:
+ return "Internal Server Error"sv;
+ case 501:
+ return "Not Implemented"sv;
+ case 502:
+ return "Bad Gateway"sv;
+ case 503:
+ return "Service Unavailable"sv;
+ case 504:
+ return "Gateway Timeout"sv;
+ case 505:
+ return "HTTP Version Not Supported"sv;
+ case 506:
+ return "Variant Also Negotiates"sv;
+ case 507:
+ return "Insufficient Storage"sv;
+ case 508:
+ return "Loop Detected"sv;
+ case 510:
+ return "Not Extended"sv;
+ case 511:
+ return "Network Authentication Required"sv;
+
+ default:
+ return "Unknown Result"sv;
+ }
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+Ref<IHttpPackageHandler>
+HttpService::HandlePackageRequest(HttpServerRequest& HttpServiceRequest)
+{
+ ZEN_UNUSED(HttpServiceRequest);
+
+ return Ref<IHttpPackageHandler>();
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+HttpServerRequest::HttpServerRequest()
+{
+}
+
+HttpServerRequest::~HttpServerRequest()
+{
+}
+
+void
+HttpServerRequest::WriteResponse(HttpResponseCode ResponseCode, CbPackage Data)
+{
+ std::vector<IoBuffer> ResponseBuffers = FormatPackageMessage(Data);
+ return WriteResponse(ResponseCode, HttpContentType::kCbPackage, ResponseBuffers);
+}
+
+void
+HttpServerRequest::WriteResponse(HttpResponseCode ResponseCode, CbObject Data)
+{
+ if (m_AcceptType == HttpContentType::kJSON)
+ {
+ ExtendableStringBuilder<1024> Sb;
+ WriteResponse(ResponseCode, HttpContentType::kJSON, Data.ToJson(Sb).ToView());
+ }
+ else
+ {
+ SharedBuffer Buf = Data.GetBuffer();
+ std::array<IoBuffer, 1> Buffers{IoBufferBuilder::MakeCloneFromMemory(Buf.GetData(), Buf.GetSize())};
+ return WriteResponse(ResponseCode, HttpContentType::kCbObject, Buffers);
+ }
+}
+
+void
+HttpServerRequest::WriteResponse(HttpResponseCode ResponseCode, CbArray Array)
+{
+ if (m_AcceptType == HttpContentType::kJSON)
+ {
+ ExtendableStringBuilder<1024> Sb;
+ WriteResponse(ResponseCode, HttpContentType::kJSON, Array.ToJson(Sb).ToView());
+ }
+ else
+ {
+ SharedBuffer Buf = Array.GetBuffer();
+ std::array<IoBuffer, 1> Buffers{IoBufferBuilder::MakeCloneFromMemory(Buf.GetData(), Buf.GetSize())};
+ return WriteResponse(ResponseCode, HttpContentType::kCbObject, Buffers);
+ }
+}
+
+void
+HttpServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::string_view ResponseString)
+{
+ return WriteResponse(ResponseCode, ContentType, std::u8string_view{(char8_t*)ResponseString.data(), ResponseString.size()});
+}
+
+void
+HttpServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, IoBuffer Blob)
+{
+ std::array<IoBuffer, 1> Buffers{Blob};
+ return WriteResponse(ResponseCode, ContentType, Buffers);
+}
+
+void
+HttpServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, CompositeBuffer& Payload)
+{
+ std::span<const SharedBuffer> Segments = Payload.GetSegments();
+
+ std::vector<IoBuffer> Buffers;
+
+ for (auto& Segment : Segments)
+ {
+ Buffers.push_back(Segment.AsIoBuffer());
+ }
+
+ WriteResponse(ResponseCode, ContentType, Buffers);
+}
+
+HttpServerRequest::QueryParams
+HttpServerRequest::GetQueryParams()
+{
+ QueryParams Params;
+
+ const std::string_view QStr = QueryString();
+
+ const char* QueryIt = QStr.data();
+ const char* QueryEnd = QueryIt + QStr.size();
+
+ while (QueryIt != QueryEnd)
+ {
+ if (*QueryIt == '&')
+ {
+ ++QueryIt;
+ continue;
+ }
+
+ size_t QueryLen = ptrdiff_t(QueryEnd - QueryIt);
+ const std::string_view Query{QueryIt, QueryLen};
+
+ size_t DelimIndex = Query.find('&', 0);
+
+ if (DelimIndex == std::string_view::npos)
+ {
+ DelimIndex = Query.size();
+ }
+
+ std::string_view ThisQuery{QueryIt, DelimIndex};
+
+ size_t EqIndex = ThisQuery.find('=', 0);
+
+ if (EqIndex != std::string_view::npos)
+ {
+ std::string_view Param{ThisQuery.data(), EqIndex};
+ ThisQuery.remove_prefix(EqIndex + 1);
+
+ Params.KvPairs.emplace_back(Param, ThisQuery);
+ }
+
+ QueryIt += DelimIndex;
+ }
+
+ return Params;
+}
+
+Oid
+HttpServerRequest::SessionId() const
+{
+ if (m_Flags & kHaveSessionId)
+ {
+ return m_SessionId;
+ }
+
+ m_SessionId = ParseSessionId();
+ m_Flags |= kHaveSessionId;
+ return m_SessionId;
+}
+
+uint32_t
+HttpServerRequest::RequestId() const
+{
+ if (m_Flags & kHaveRequestId)
+ {
+ return m_RequestId;
+ }
+
+ m_RequestId = ParseRequestId();
+ m_Flags |= kHaveRequestId;
+ return m_RequestId;
+}
+
+CbObject
+HttpServerRequest::ReadPayloadObject()
+{
+ if (IoBuffer Payload = ReadPayload())
+ {
+ return LoadCompactBinaryObject(std::move(Payload));
+ }
+
+ return {};
+}
+
+CbPackage
+HttpServerRequest::ReadPayloadPackage()
+{
+ if (IoBuffer Payload = ReadPayload())
+ {
+ return ParsePackageMessage(std::move(Payload));
+ }
+
+ return {};
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+void
+HttpRequestRouter::AddPattern(const char* Id, const char* Regex)
+{
+ ZEN_ASSERT(m_PatternMap.find(Id) == m_PatternMap.end());
+
+ m_PatternMap.insert({Id, Regex});
+}
+
+void
+HttpRequestRouter::RegisterRoute(const char* Regex, HttpRequestRouter::HandlerFunc_t&& HandlerFunc, HttpVerb SupportedVerbs)
+{
+ ExtendableStringBuilder<128> ExpandedRegex;
+ ProcessRegexSubstitutions(Regex, ExpandedRegex);
+
+ m_Handlers.emplace_back(ExpandedRegex.c_str(), SupportedVerbs, std::move(HandlerFunc), Regex);
+}
+
+void
+HttpRequestRouter::ProcessRegexSubstitutions(const char* Regex, StringBuilderBase& OutExpandedRegex)
+{
+ size_t RegexLen = strlen(Regex);
+
+ for (size_t i = 0; i < RegexLen;)
+ {
+ bool matched = false;
+
+ if (Regex[i] == '{' && ((i == 0) || (Regex[i - 1] != '\\')))
+ {
+ // Might have a pattern reference - find closing brace
+
+ for (size_t j = i + 1; j < RegexLen; ++j)
+ {
+ if (Regex[j] == '}')
+ {
+ std::string Pattern(&Regex[i + 1], j - i - 1);
+
+ if (auto it = m_PatternMap.find(Pattern); it != m_PatternMap.end())
+ {
+ OutExpandedRegex.Append(it->second.c_str());
+ }
+ else
+ {
+ // Default to anything goes (or should this just be an error?)
+
+ OutExpandedRegex.Append("(.+?)");
+ }
+
+ // skip ahead
+ i = j + 1;
+
+ matched = true;
+
+ break;
+ }
+ }
+ }
+
+ if (!matched)
+ {
+ OutExpandedRegex.Append(Regex[i++]);
+ }
+ }
+}
+
+bool
+HttpRequestRouter::HandleRequest(zen::HttpServerRequest& Request)
+{
+ const HttpVerb Verb = Request.RequestVerb();
+
+ std::string_view Uri = Request.RelativeUri();
+ HttpRouterRequest RouterRequest(Request);
+
+ for (const auto& Handler : m_Handlers)
+ {
+ if ((Handler.Verbs & Verb) == Verb && regex_match(begin(Uri), end(Uri), RouterRequest.m_Match, Handler.RegEx))
+ {
+ Handler.Handler(RouterRequest);
+
+ return true; // Route matched
+ }
+ }
+
+ return false; // No route matched
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+HttpRpcHandler::HttpRpcHandler()
+{
+}
+
+HttpRpcHandler::~HttpRpcHandler()
+{
+}
+
+void
+HttpRpcHandler::AddRpc(std::string_view RpcId, std::function<void(CbObject& RpcArgs)> HandlerFunction)
+{
+ ZEN_UNUSED(RpcId, HandlerFunction);
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+enum class HttpServerClass
+{
+ kHttpAsio,
+ kHttpSys,
+ kHttpNull
+};
+
+// Implemented in httpsys.cpp
+Ref<HttpServer> CreateHttpSysServer(int Concurrency, int BackgroundWorkerThreads);
+
+Ref<HttpServer>
+CreateHttpServer(std::string_view ServerClass)
+{
+ using namespace std::literals;
+
+ HttpServerClass Class = HttpServerClass::kHttpNull;
+
+#if ZEN_WITH_HTTPSYS
+ Class = HttpServerClass::kHttpSys;
+#elif 1
+ Class = HttpServerClass::kHttpAsio;
+#endif
+
+ if (ServerClass == "asio"sv)
+ {
+ Class = HttpServerClass::kHttpAsio;
+ }
+ else if (ServerClass == "httpsys"sv)
+ {
+ Class = HttpServerClass::kHttpSys;
+ }
+ else if (ServerClass == "null"sv)
+ {
+ Class = HttpServerClass::kHttpNull;
+ }
+
+ switch (Class)
+ {
+ default:
+ case HttpServerClass::kHttpAsio:
+ ZEN_INFO("using asio HTTP server implementation");
+ return Ref<HttpServer>(new HttpAsioServer());
+
+#if ZEN_WITH_HTTPSYS
+ case HttpServerClass::kHttpSys:
+ ZEN_INFO("using http.sys server implementation");
+ return Ref<HttpServer>(new HttpSysServer(std::thread::hardware_concurrency(), /* background worker threads */ 16));
+#endif
+
+ case HttpServerClass::kHttpNull:
+ ZEN_INFO("using null HTTP server implementation");
+ return Ref<HttpServer>(new HttpNullServer);
+ }
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+bool
+HandlePackageOffers(HttpService& Service, HttpServerRequest& Request, Ref<IHttpPackageHandler>& PackageHandlerRef)
+{
+ if (Request.RequestVerb() == HttpVerb::kPost)
+ {
+ if (Request.RequestContentType() == HttpContentType::kCbPackageOffer)
+ {
+ // The client is presenting us with a package attachments offer, we need
+ // to filter it down to the list of attachments we need them to send in
+ // the follow-up request
+
+ PackageHandlerRef = Service.HandlePackageRequest(Request);
+
+ if (PackageHandlerRef)
+ {
+ CbObject OfferMessage = LoadCompactBinaryObject(Request.ReadPayload());
+
+ std::vector<IoHash> OfferCids;
+
+ for (auto& CidEntry : OfferMessage["offer"])
+ {
+ if (!CidEntry.IsHash())
+ {
+ // Should yield bad request response?
+
+ ZEN_WARN("found invalid entry in offer");
+
+ continue;
+ }
+
+ OfferCids.push_back(CidEntry.AsHash());
+ }
+
+ ZEN_TRACE("request #{} -> filtering offer of {} entries", Request.RequestId(), OfferCids.size());
+
+ PackageHandlerRef->FilterOffer(OfferCids);
+
+ ZEN_TRACE("request #{} -> filtered to {} entries", Request.RequestId(), OfferCids.size());
+
+ CbObjectWriter ResponseWriter;
+ ResponseWriter.BeginArray("need");
+
+ for (const IoHash& Cid : OfferCids)
+ {
+ ResponseWriter.AddHash(Cid);
+ }
+
+ ResponseWriter.EndArray();
+
+ // Emit filter response
+ Request.WriteResponse(HttpResponseCode::OK, ResponseWriter.Save());
+ return true;
+ }
+ }
+ else if (Request.RequestContentType() == HttpContentType::kCbPackage)
+ {
+ // Process chunks in package request
+
+ PackageHandlerRef = Service.HandlePackageRequest(Request);
+
+ // TODO: this should really be done in a streaming fashion, currently this emulates
+ // the intended flow from an API perspective
+
+ if (PackageHandlerRef)
+ {
+ PackageHandlerRef->OnRequestBegin();
+
+ auto CreateBuffer = [&](const IoHash& Cid, uint64_t Size) -> IoBuffer {
+ return PackageHandlerRef->CreateTarget(Cid, Size);
+ };
+
+ CbPackage Package = ParsePackageMessage(Request.ReadPayload(), CreateBuffer);
+
+ PackageHandlerRef->OnRequestComplete();
+ }
+ }
+ }
+ return false;
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+#if ZEN_WITH_TESTS
+
+TEST_CASE("http.common")
+{
+ using namespace std::literals;
+
+ SUBCASE("router")
+ {
+ HttpRequestRouter r;
+ r.AddPattern("a", "[[:alpha:]]+");
+ r.RegisterRoute(
+ "{a}",
+ [&](auto) {},
+ HttpVerb::kGet);
+
+ // struct TestHttpServerRequest : public HttpServerRequest
+ //{
+ // TestHttpServerRequest(std::string_view Uri) : m_uri{Uri} {}
+ //};
+
+ // TestHttpServerRequest req{};
+ // r.HandleRequest(req);
+ }
+
+ SUBCASE("content-type")
+ {
+ for (uint8_t i = 0; i < uint8_t(HttpContentType::kCOUNT); ++i)
+ {
+ HttpContentType Ct{i};
+
+ if (Ct != HttpContentType::kUnknownContentType)
+ {
+ CHECK_EQ(Ct, ParseContentType(MapContentTypeToString(Ct)));
+ }
+ }
+ }
+}
+
+void
+http_forcelink()
+{
+}
+
+#endif
+
+} // namespace zen
diff --git a/src/zenhttp/httpshared.cpp b/src/zenhttp/httpshared.cpp
new file mode 100644
index 000000000..7aade56d2
--- /dev/null
+++ b/src/zenhttp/httpshared.cpp
@@ -0,0 +1,809 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zenhttp/httpshared.h>
+
+#include <zencore/compactbinarybuilder.h>
+#include <zencore/compactbinarypackage.h>
+#include <zencore/compositebuffer.h>
+#include <zencore/filesystem.h>
+#include <zencore/fmtutils.h>
+#include <zencore/iobuffer.h>
+#include <zencore/iohash.h>
+#include <zencore/logging.h>
+#include <zencore/scopeguard.h>
+#include <zencore/stream.h>
+#include <zencore/testing.h>
+#include <zencore/testutils.h>
+
+#include <span>
+#include <vector>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <tsl/robin_map.h>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+namespace zen {
+
+const std::string_view HandlePrefix(":?#:");
+
+std::vector<IoBuffer>
+FormatPackageMessage(const CbPackage& Data, int TargetProcessPid)
+{
+ return FormatPackageMessage(Data, FormatFlags::kDefault, TargetProcessPid);
+}
+CompositeBuffer
+FormatPackageMessageBuffer(const CbPackage& Data, int TargetProcessPid)
+{
+ return FormatPackageMessageBuffer(Data, FormatFlags::kDefault, TargetProcessPid);
+}
+
+CompositeBuffer
+FormatPackageMessageBuffer(const CbPackage& Data, FormatFlags Flags, int TargetProcessPid)
+{
+ std::vector<IoBuffer> Message = FormatPackageMessage(Data, Flags, TargetProcessPid);
+
+ std::vector<SharedBuffer> Buffers;
+
+ for (IoBuffer& Buf : Message)
+ {
+ Buffers.push_back(SharedBuffer(Buf));
+ }
+
+ return CompositeBuffer(std::move(Buffers));
+}
+
+std::vector<IoBuffer>
+FormatPackageMessage(const CbPackage& Data, FormatFlags Flags, int TargetProcessPid)
+{
+ void* TargetProcessHandle = nullptr;
+#if ZEN_PLATFORM_WINDOWS
+ std::vector<HANDLE> DuplicatedHandles;
+ auto _ = MakeGuard([&DuplicatedHandles, &TargetProcessHandle]() {
+ if (TargetProcessHandle == nullptr)
+ {
+ return;
+ }
+
+ for (HANDLE DuplicatedHandle : DuplicatedHandles)
+ {
+ HANDLE ClosingHandle;
+ if (::DuplicateHandle((HANDLE)TargetProcessHandle,
+ DuplicatedHandle,
+ GetCurrentProcess(),
+ &ClosingHandle,
+ 0,
+ FALSE,
+ DUPLICATE_CLOSE_SOURCE | DUPLICATE_SAME_ACCESS) == TRUE)
+ {
+ ::CloseHandle(ClosingHandle);
+ }
+ }
+ ::CloseHandle((HANDLE)TargetProcessHandle);
+ TargetProcessHandle = nullptr;
+ });
+
+ if (EnumHasAllFlags(Flags, FormatFlags::kAllowLocalReferences) && TargetProcessPid != 0)
+ {
+ TargetProcessHandle = OpenProcess(PROCESS_DUP_HANDLE, FALSE, TargetProcessPid);
+ }
+#else
+ ZEN_UNUSED(TargetProcessPid);
+ void* DuplicatedHandles = nullptr;
+#endif // ZEN_PLATFORM_WINDOWS
+
+ const std::span<const CbAttachment>& Attachments = Data.GetAttachments();
+ std::vector<IoBuffer> ResponseBuffers;
+
+ ResponseBuffers.reserve(3 + Attachments.size()); // TODO: may want to use an additional fudge factor here to avoid growing since each
+ // attachment is likely to consist of several buffers
+
+ // Fixed size header
+
+ CbPackageHeader Hdr{.HeaderMagic = kCbPkgMagic, .AttachmentCount = gsl::narrow<uint32_t>(Attachments.size())};
+
+ ResponseBuffers.push_back(IoBufferBuilder::MakeCloneFromMemory(&Hdr, sizeof Hdr));
+
+ // Attachment metadata array
+
+ IoBuffer AttachmentMetadataBuffer = IoBuffer{sizeof(CbAttachmentEntry) * (Attachments.size() + /* root */ 1)};
+ CbAttachmentEntry* AttachmentInfo = reinterpret_cast<CbAttachmentEntry*>(AttachmentMetadataBuffer.MutableData());
+
+ ResponseBuffers.push_back(AttachmentMetadataBuffer); // Attachment metadata
+
+ // Root object
+
+ IoBuffer RootIoBuffer = Data.GetObject().GetBuffer().AsIoBuffer();
+ ResponseBuffers.push_back(RootIoBuffer); // Root object
+
+ *AttachmentInfo++ = {.PayloadSize = RootIoBuffer.Size(), .Flags = CbAttachmentEntry::kIsObject, .AttachmentHash = Data.GetObjectHash()};
+
+ // Attachment payloads
+
+ auto MarshalLocal = [&AttachmentInfo, &ResponseBuffers](const std::string& Path8,
+ CbAttachmentReferenceHeader& LocalRef,
+ const IoHash& AttachmentHash,
+ bool IsCompressed) {
+ IoBuffer RefBuffer(sizeof(CbAttachmentReferenceHeader) + Path8.size());
+
+ CbAttachmentReferenceHeader* RefHdr = RefBuffer.MutableData<CbAttachmentReferenceHeader>();
+ *RefHdr++ = LocalRef;
+ memcpy(RefHdr, Path8.data(), Path8.size());
+
+ *AttachmentInfo++ = {.PayloadSize = RefBuffer.GetSize(),
+ .Flags = (IsCompressed ? uint32_t(CbAttachmentEntry::kIsCompressed) : 0u) | CbAttachmentEntry::kIsLocalRef,
+ .AttachmentHash = AttachmentHash};
+
+ ResponseBuffers.push_back(std::move(RefBuffer));
+ };
+
+ tsl::robin_map<void*, std::string> FileNameMap;
+
+ auto IsLocalRef = [&FileNameMap, &DuplicatedHandles](const CompositeBuffer& AttachmentBinary,
+ bool DenyPartialLocalReferences,
+ void* TargetProcessHandle,
+ CbAttachmentReferenceHeader& LocalRef,
+ std::string& Path8) -> bool {
+ const SharedBuffer& Segment = AttachmentBinary.GetSegments().front();
+ IoBufferFileReference Ref;
+ const IoBuffer& SegmentBuffer = Segment.AsIoBuffer();
+
+ if (!SegmentBuffer.GetFileReference(Ref))
+ {
+ return false;
+ }
+
+ if (DenyPartialLocalReferences && !SegmentBuffer.IsWholeFile())
+ {
+ return false;
+ }
+
+ if (auto It = FileNameMap.find(Ref.FileHandle); It != FileNameMap.end())
+ {
+ Path8 = It->second;
+ }
+ else
+ {
+ bool UseFilePath = true;
+#if ZEN_PLATFORM_WINDOWS
+ if (TargetProcessHandle != nullptr)
+ {
+ HANDLE TargetHandle = INVALID_HANDLE_VALUE;
+ BOOL OK = ::DuplicateHandle(GetCurrentProcess(),
+ Ref.FileHandle,
+ (HANDLE)TargetProcessHandle,
+ &TargetHandle,
+ FILE_GENERIC_READ,
+ FALSE,
+ 0);
+ if (OK)
+ {
+ DuplicatedHandles.push_back(TargetHandle);
+ Path8 = fmt::format("{}{}", HandlePrefix, reinterpret_cast<uint64_t>(TargetHandle));
+ UseFilePath = false;
+ }
+ }
+#else // ZEN_PLATFORM_WINDOWS
+ ZEN_UNUSED(TargetProcessHandle);
+ // Not supported on Linux/Mac. Could potentially use pidfd_getfd() but that requires a fairly new Linux kernel/includes and to
+ // deal with acceess rights etc.
+#endif // ZEN_PLATFORM_WINDOWS
+ if (UseFilePath)
+ {
+ ExtendablePathBuilder<256> LocalRefFile;
+ LocalRefFile.Append(std::filesystem::absolute(PathFromHandle(Ref.FileHandle)));
+ Path8 = LocalRefFile.ToUtf8();
+ }
+ FileNameMap.insert_or_assign(Ref.FileHandle, Path8);
+ }
+
+ LocalRef.AbsolutePathLength = gsl::narrow<uint16_t>(Path8.size());
+ LocalRef.PayloadByteOffset = Ref.FileChunkOffset;
+ LocalRef.PayloadByteSize = Ref.FileChunkSize;
+
+ return true;
+ };
+
+ for (const CbAttachment& Attachment : Attachments)
+ {
+ if (Attachment.IsNull())
+ {
+ ZEN_NOT_IMPLEMENTED("Null attachments are not supported");
+ }
+ else if (CompressedBuffer AttachmentBuffer = Attachment.AsCompressedBinary())
+ {
+ CompositeBuffer Compressed = AttachmentBuffer.GetCompressed();
+ IoHash AttachmentHash = Attachment.GetHash();
+
+ // If the data is either not backed by a file, or there are multiple
+ // fragments then we cannot marshal it by local reference. We might
+ // want/need to extend this in the future to allow multiple chunk
+ // segments to be marshaled at once
+
+ bool MarshalByLocalRef = EnumHasAllFlags(Flags, FormatFlags::kAllowLocalReferences) && (Compressed.GetSegments().size() == 1);
+ bool DenyPartialLocalReferences = EnumHasAllFlags(Flags, FormatFlags::kDenyPartialLocalReferences);
+ CbAttachmentReferenceHeader LocalRef;
+ std::string Path8;
+
+ if (MarshalByLocalRef)
+ {
+ MarshalByLocalRef = IsLocalRef(Compressed, DenyPartialLocalReferences, TargetProcessHandle, LocalRef, Path8);
+ }
+
+ if (MarshalByLocalRef)
+ {
+ const bool IsCompressed = true;
+ bool IsHandle = false;
+#if ZEN_PLATFORM_WINDOWS
+ IsHandle = Path8.starts_with(HandlePrefix);
+#endif
+ MarshalLocal(Path8, LocalRef, AttachmentHash, IsCompressed);
+ ZEN_DEBUG("Marshalled '{}' as file {} of {} bytes", Path8, IsHandle ? "handle" : "path", Compressed.GetSize());
+ }
+ else
+ {
+ *AttachmentInfo++ = {.PayloadSize = AttachmentBuffer.GetCompressedSize(),
+ .Flags = CbAttachmentEntry::kIsCompressed,
+ .AttachmentHash = AttachmentHash};
+
+ for (const SharedBuffer& Segment : Compressed.GetSegments())
+ {
+ ResponseBuffers.push_back(Segment.AsIoBuffer());
+ }
+ }
+ }
+ else if (CbObject AttachmentObject = Attachment.AsObject())
+ {
+ IoBuffer ObjIoBuffer = AttachmentObject.GetBuffer().AsIoBuffer();
+ ResponseBuffers.push_back(ObjIoBuffer);
+
+ *AttachmentInfo++ = {.PayloadSize = ObjIoBuffer.Size(),
+ .Flags = CbAttachmentEntry::kIsObject,
+ .AttachmentHash = Attachment.GetHash()};
+ }
+ else if (CompositeBuffer AttachmentBinary = Attachment.AsCompositeBinary())
+ {
+ IoHash AttachmentHash = Attachment.GetHash();
+ bool MarshalByLocalRef =
+ EnumHasAllFlags(Flags, FormatFlags::kAllowLocalReferences) && (AttachmentBinary.GetSegments().size() == 1);
+ bool DenyPartialLocalReferences = EnumHasAllFlags(Flags, FormatFlags::kDenyPartialLocalReferences);
+
+ CbAttachmentReferenceHeader LocalRef;
+ std::string Path8;
+
+ if (MarshalByLocalRef)
+ {
+ MarshalByLocalRef = IsLocalRef(AttachmentBinary, DenyPartialLocalReferences, TargetProcessHandle, LocalRef, Path8);
+ }
+
+ if (MarshalByLocalRef)
+ {
+ const bool IsCompressed = false;
+ bool IsHandle = false;
+#if ZEN_PLATFORM_WINDOWS
+ IsHandle = Path8.starts_with(HandlePrefix);
+#endif
+ MarshalLocal(Path8, LocalRef, AttachmentHash, IsCompressed);
+ ZEN_DEBUG("Marshalled '{}' as file {} of {} bytes", Path8, IsHandle ? "handle" : "path", AttachmentBinary.GetSize());
+ }
+ else
+ {
+ *AttachmentInfo++ = {.PayloadSize = AttachmentBinary.GetSize(), .Flags = 0, .AttachmentHash = Attachment.GetHash()};
+
+ for (const SharedBuffer& Segment : AttachmentBinary.GetSegments())
+ {
+ ResponseBuffers.push_back(Segment.AsIoBuffer());
+ }
+ }
+ }
+ else
+ {
+ ZEN_NOT_IMPLEMENTED("Unknown attachment kind");
+ }
+ }
+ FileNameMap.clear();
+#if ZEN_PLATFORM_WINDOWS
+ DuplicatedHandles.clear();
+#endif // ZEN_PLATFORM_WINDOWS
+
+ return ResponseBuffers;
+}
+
+bool
+IsPackageMessage(IoBuffer Payload)
+{
+ if (!Payload)
+ {
+ return false;
+ }
+
+ BinaryReader Reader(Payload);
+
+ CbPackageHeader Hdr;
+ Reader.Read(&Hdr, sizeof Hdr);
+
+ if (Hdr.HeaderMagic != kCbPkgMagic)
+ {
+ return false;
+ }
+
+ return true;
+}
+
+CbPackage
+ParsePackageMessage(IoBuffer Payload, std::function<IoBuffer(const IoHash&, uint64_t)> CreateBuffer)
+{
+ if (!Payload)
+ {
+ return {};
+ }
+
+ BinaryReader Reader(Payload);
+
+ CbPackageHeader Hdr;
+ Reader.Read(&Hdr, sizeof Hdr);
+
+ if (Hdr.HeaderMagic != kCbPkgMagic)
+ {
+ throw std::runtime_error("invalid CbPackage header magic");
+ }
+
+ const uint32_t ChunkCount = Hdr.AttachmentCount + 1;
+
+ std::unique_ptr<CbAttachmentEntry[]> AttachmentEntries{new CbAttachmentEntry[ChunkCount]};
+
+ Reader.Read(AttachmentEntries.get(), sizeof(CbAttachmentEntry) * ChunkCount);
+
+ CbPackage Package;
+
+ std::vector<CbAttachment> Attachments;
+ Attachments.reserve(ChunkCount); // Guessing here...
+
+ tsl::robin_map<std::string, IoBuffer> PartialFileBuffers;
+
+ // TODO: Throwing before this loop completes could result in leaking handles as we might not have picked up all the handles in the
+ // message
+ for (uint32_t i = 0; i < ChunkCount; ++i)
+ {
+ const CbAttachmentEntry& Entry = AttachmentEntries[i];
+ const uint64_t AttachmentSize = Entry.PayloadSize;
+
+ const IoBuffer AttachmentBuffer(Payload, Reader.CurrentOffset(), AttachmentSize);
+ Reader.Skip(AttachmentSize);
+
+ if (Entry.Flags & CbAttachmentEntry::kIsLocalRef)
+ {
+ // Marshal local reference - a "pointer" to the chunk backing file
+
+ ZEN_ASSERT(AttachmentBuffer.Size() >= sizeof(CbAttachmentReferenceHeader));
+
+ const CbAttachmentReferenceHeader* AttachRefHdr = AttachmentBuffer.Data<CbAttachmentReferenceHeader>();
+ const char* PathPointer = reinterpret_cast<const char*>(AttachRefHdr + 1);
+
+ ZEN_ASSERT(AttachmentBuffer.Size() >= (sizeof(CbAttachmentReferenceHeader) + AttachRefHdr->AbsolutePathLength));
+ std::string_view PathView(PathPointer, AttachRefHdr->AbsolutePathLength);
+
+ IoBuffer FullFileBuffer;
+
+ std::filesystem::path Path(Utf8ToWide(PathView));
+ if (auto It = PartialFileBuffers.find(Path.string()); It != PartialFileBuffers.end())
+ {
+ FullFileBuffer = It->second;
+ }
+ else
+ {
+ if (PathView.starts_with(HandlePrefix))
+ {
+#if ZEN_PLATFORM_WINDOWS
+ std::string_view HandleString(PathView.substr(HandlePrefix.length()));
+ std::optional<uint64_t> HandleNumber(ParseInt<uint64_t>(HandleString));
+ if (HandleNumber.has_value())
+ {
+ HANDLE FileHandle = HANDLE(HandleNumber.value());
+ ULARGE_INTEGER liFileSize;
+ liFileSize.LowPart = ::GetFileSize(FileHandle, &liFileSize.HighPart);
+ if (liFileSize.LowPart != INVALID_FILE_SIZE)
+ {
+ FullFileBuffer = IoBuffer(IoBuffer::File, (void*)FileHandle, 0, uint64_t(liFileSize.QuadPart));
+ PartialFileBuffers.insert_or_assign(Path.string(), FullFileBuffer);
+ }
+ }
+#else // ZEN_PLATFORM_WINDOWS
+ // Not supported on Linux/Mac. Could potentially use pidfd_getfd() but that requires a fairly new Linux kernel/includes
+ // and to deal with acceess rights etc.
+ ZEN_ASSERT(false);
+#endif // ZEN_PLATFORM_WINDOWS
+ }
+ else
+ {
+ FullFileBuffer = PartialFileBuffers.insert_or_assign(Path.string(), IoBufferBuilder::MakeFromFile(Path)).first->second;
+ }
+ }
+
+ if (!FullFileBuffer)
+ {
+ // Unable to open chunk reference
+ throw std::runtime_error(fmt::format("unable to resolve chunk #{} at '{}' (offset {}, size {})",
+ i,
+ Path,
+ AttachRefHdr->PayloadByteOffset,
+ AttachRefHdr->PayloadByteSize));
+ }
+
+ IoBuffer ChunkReference = AttachRefHdr->PayloadByteOffset == 0 && AttachRefHdr->PayloadByteSize == FullFileBuffer.GetSize()
+ ? FullFileBuffer
+ : IoBuffer(FullFileBuffer, AttachRefHdr->PayloadByteOffset, AttachRefHdr->PayloadByteSize);
+
+ CompressedBuffer CompBuf(CompressedBuffer::FromCompressedNoValidate(std::move(ChunkReference)));
+ if (!CompBuf)
+ {
+ throw std::runtime_error(fmt::format("invalid format for chunk #{} at '{}' (offset {}, size {})",
+ i,
+ Path,
+ AttachRefHdr->PayloadByteOffset,
+ AttachRefHdr->PayloadByteSize));
+ }
+ Attachments.emplace_back(CbAttachment(std::move(CompBuf), Entry.AttachmentHash));
+ }
+ else if (Entry.Flags & CbAttachmentEntry::kIsCompressed)
+ {
+ if (Entry.Flags & CbAttachmentEntry::kIsObject)
+ {
+ if (i == 0)
+ {
+ CompressedBuffer CompBuf(CompressedBuffer::FromCompressedNoValidate(IoBuffer(AttachmentBuffer)));
+ if (!CompBuf)
+ {
+ throw std::runtime_error(fmt::format("invalid format for chunk #{} expected compressed buffer for CbObject", i));
+ }
+ // First payload is always a compact binary object
+ Package.SetObject(LoadCompactBinaryObject(std::move(CompBuf)));
+ }
+ else
+ {
+ ZEN_NOT_IMPLEMENTED("Object attachments are not currently supported");
+ }
+ }
+ else
+ {
+ CompressedBuffer CompBuf(CompressedBuffer::FromCompressedNoValidate(IoBuffer(AttachmentBuffer)));
+ if (!CompBuf)
+ {
+ throw std::runtime_error(fmt::format("invalid format for chunk #{} expected compressed buffer for attachment", i));
+ }
+ Attachments.emplace_back(CbAttachment(std::move(CompBuf), Entry.AttachmentHash));
+ }
+ }
+ else /* not compressed */
+ {
+ if (Entry.Flags & CbAttachmentEntry::kIsObject)
+ {
+ if (i == 0)
+ {
+ Package.SetObject(LoadCompactBinaryObject(AttachmentBuffer));
+ }
+ else
+ {
+ ZEN_NOT_IMPLEMENTED("Object attachments are not currently supported");
+ }
+ }
+ else
+ {
+ // Make a copy of the buffer so we attachements don't reference the entire payload
+ IoBuffer AttachmentBufferCopy = CreateBuffer(Entry.AttachmentHash, AttachmentSize);
+ ZEN_ASSERT(AttachmentBufferCopy);
+ ZEN_ASSERT(AttachmentBufferCopy.Size() == AttachmentSize);
+ AttachmentBufferCopy.GetMutableView().CopyFrom(AttachmentBuffer.GetView());
+
+ CbAttachment Attachment(SharedBuffer{AttachmentBufferCopy});
+ Attachments.emplace_back(SharedBuffer{AttachmentBufferCopy});
+ }
+ }
+ }
+ PartialFileBuffers.clear();
+
+ Package.AddAttachments(Attachments);
+
+ return Package;
+}
+
+bool
+ParsePackageMessageWithLegacyFallback(const IoBuffer& Response, CbPackage& OutPackage)
+{
+ if (IsPackageMessage(Response))
+ {
+ OutPackage = ParsePackageMessage(Response);
+ return true;
+ }
+ return OutPackage.TryLoad(Response);
+}
+
+CbPackageReader::CbPackageReader() : m_CreateBuffer([](const IoHash&, uint64_t Size) -> IoBuffer { return IoBuffer{Size}; })
+{
+}
+
+CbPackageReader::~CbPackageReader()
+{
+}
+
+void
+CbPackageReader::SetPayloadBufferCreator(std::function<IoBuffer(const IoHash& Cid, uint64_t Size)> CreateBuffer)
+{
+ m_CreateBuffer = CreateBuffer;
+}
+
+uint64_t
+CbPackageReader::ProcessPackageHeaderData(const void* Data, uint64_t DataBytes)
+{
+ ZEN_ASSERT(m_CurrentState != State::kReadingBuffers);
+
+ switch (m_CurrentState)
+ {
+ case State::kInitialState:
+ ZEN_ASSERT(Data == nullptr);
+ m_CurrentState = State::kReadingHeader;
+ return sizeof m_PackageHeader;
+
+ case State::kReadingHeader:
+ ZEN_ASSERT(DataBytes == sizeof m_PackageHeader);
+ memcpy(&m_PackageHeader, Data, sizeof m_PackageHeader);
+ ZEN_ASSERT(m_PackageHeader.HeaderMagic == kCbPkgMagic);
+ m_CurrentState = State::kReadingAttachmentEntries;
+ m_AttachmentEntries.resize(m_PackageHeader.AttachmentCount + 1);
+ return (m_PackageHeader.AttachmentCount + 1) * sizeof(CbAttachmentEntry);
+
+ case State::kReadingAttachmentEntries:
+ ZEN_ASSERT(DataBytes == ((m_PackageHeader.AttachmentCount + 1) * sizeof(CbAttachmentEntry)));
+ memcpy(m_AttachmentEntries.data(), Data, DataBytes);
+
+ for (CbAttachmentEntry& Entry : m_AttachmentEntries)
+ {
+ // This preallocates memory for payloads but note that for the local references
+ // the caller will need to handle the payload differently (i.e it's a
+ // CbAttachmentReferenceHeader not the actual payload)
+
+ m_PayloadBuffers.push_back(IoBuffer{Entry.PayloadSize});
+ }
+
+ m_CurrentState = State::kReadingBuffers;
+ return 0;
+
+ default:
+ ZEN_ASSERT(false);
+ return 0;
+ }
+}
+
+IoBuffer
+CbPackageReader::MarshalLocalChunkReference(IoBuffer AttachmentBuffer)
+{
+ // Marshal local reference - a "pointer" to the chunk backing file
+
+ ZEN_ASSERT(AttachmentBuffer.Size() >= sizeof(CbAttachmentReferenceHeader));
+
+ const CbAttachmentReferenceHeader* AttachRefHdr = AttachmentBuffer.Data<CbAttachmentReferenceHeader>();
+ const char8_t* PathPointer = reinterpret_cast<const char8_t*>(AttachRefHdr + 1);
+
+ ZEN_ASSERT(AttachmentBuffer.Size() >= (sizeof(CbAttachmentReferenceHeader) + AttachRefHdr->AbsolutePathLength));
+
+ std::u8string_view PathView{PathPointer, AttachRefHdr->AbsolutePathLength};
+
+ std::filesystem::path Path{PathView};
+
+ IoBuffer ChunkReference = IoBufferBuilder::MakeFromFile(Path, AttachRefHdr->PayloadByteOffset, AttachRefHdr->PayloadByteSize);
+
+ if (!ChunkReference)
+ {
+ // Unable to open chunk reference
+
+ throw std::runtime_error(fmt::format("unable to resolve local reference to '{}' (offset {}, size {})",
+ PathToUtf8(Path),
+ AttachRefHdr->PayloadByteOffset,
+ AttachRefHdr->PayloadByteSize));
+ }
+
+ return ChunkReference;
+};
+
+void
+CbPackageReader::Finalize()
+{
+ if (m_AttachmentEntries.empty())
+ {
+ return;
+ }
+
+ m_Attachments.reserve(m_AttachmentEntries.size() - 1);
+
+ int CurrentAttachmentIndex = 0;
+ for (CbAttachmentEntry& Entry : m_AttachmentEntries)
+ {
+ IoBuffer AttachmentBuffer = m_PayloadBuffers[CurrentAttachmentIndex];
+
+ if (CurrentAttachmentIndex == 0)
+ {
+ // Root object
+ if (Entry.Flags & CbAttachmentEntry::kIsObject)
+ {
+ if (Entry.Flags & CbAttachmentEntry::kIsLocalRef)
+ {
+ m_RootObject = LoadCompactBinaryObject(MarshalLocalChunkReference(AttachmentBuffer));
+ }
+ else if (Entry.Flags & CbAttachmentEntry::kIsCompressed)
+ {
+ IoHash RawHash;
+ uint64_t RawSize;
+ CompressedBuffer Compressed = CompressedBuffer::FromCompressed(SharedBuffer(AttachmentBuffer), RawHash, RawSize);
+ if (RawHash == Entry.AttachmentHash)
+ {
+ m_RootObject = LoadCompactBinaryObject(Compressed);
+ }
+ }
+ else
+ {
+ m_RootObject = LoadCompactBinaryObject(std::move(AttachmentBuffer));
+ }
+ }
+ else
+ {
+ throw std::runtime_error("missing or invalid root object");
+ }
+ }
+ else if (Entry.Flags & CbAttachmentEntry::kIsLocalRef)
+ {
+ IoBuffer ChunkReference = MarshalLocalChunkReference(AttachmentBuffer);
+
+ if (Entry.Flags & CbAttachmentEntry::kIsCompressed)
+ {
+ IoHash RawHash;
+ uint64_t RawSize;
+ CompressedBuffer Compressed = CompressedBuffer::FromCompressed(SharedBuffer(ChunkReference), RawHash, RawSize);
+ if (RawHash == Entry.AttachmentHash)
+ {
+ m_Attachments.push_back(CbAttachment(Compressed, Entry.AttachmentHash));
+ }
+ }
+ else
+ {
+ CompressedBuffer Compressed =
+ CompressedBuffer::Compress(SharedBuffer(ChunkReference), OodleCompressor::NotSet, OodleCompressionLevel::None);
+ m_Attachments.push_back(CbAttachment(std::move(Compressed), Compressed.DecodeRawHash()));
+ }
+ }
+
+ ++CurrentAttachmentIndex;
+ }
+}
+
+/**
+ ______________________ _____________________________
+ \__ ___/\_ _____// _____/\__ ___/ _____/
+ | | | __)_ \_____ \ | | \_____ \
+ | | | \/ \ | | / \
+ |____| /_______ /_______ / |____| /_______ /
+ \/ \/ \/
+ */
+
+#if ZEN_WITH_TESTS
+
+TEST_CASE("CbPackage.Serialization")
+{
+ // Make a test package
+
+ CbAttachment Attach1{SharedBuffer::MakeView(MakeMemoryView("abcd"))};
+ CbAttachment Attach2{SharedBuffer::MakeView(MakeMemoryView("efgh"))};
+
+ CbObjectWriter Cbo;
+ Cbo.AddAttachment("abcd", Attach1);
+ Cbo.AddAttachment("efgh", Attach2);
+
+ CbPackage Pkg;
+ Pkg.AddAttachment(Attach1);
+ Pkg.AddAttachment(Attach2);
+ Pkg.SetObject(Cbo.Save());
+
+ SharedBuffer Buffer = FormatPackageMessageBuffer(Pkg).Flatten();
+ const uint8_t* CursorPtr = reinterpret_cast<const uint8_t*>(Buffer.GetData());
+ uint64_t RemainingBytes = Buffer.GetSize();
+
+ auto ConsumeBytes = [&](uint64_t ByteCount) {
+ ZEN_ASSERT(ByteCount <= RemainingBytes);
+ void* ReturnPtr = (void*)CursorPtr;
+ CursorPtr += ByteCount;
+ RemainingBytes -= ByteCount;
+ return ReturnPtr;
+ };
+
+ auto CopyBytes = [&](void* TargetBuffer, uint64_t ByteCount) {
+ ZEN_ASSERT(ByteCount <= RemainingBytes);
+ memcpy(TargetBuffer, CursorPtr, ByteCount);
+ CursorPtr += ByteCount;
+ RemainingBytes -= ByteCount;
+ };
+
+ CbPackageReader Reader;
+ uint64_t InitialRead = Reader.ProcessPackageHeaderData(nullptr, 0);
+ uint64_t NextBytes = Reader.ProcessPackageHeaderData(ConsumeBytes(InitialRead), InitialRead);
+ NextBytes = Reader.ProcessPackageHeaderData(ConsumeBytes(NextBytes), NextBytes);
+ auto Buffers = Reader.GetPayloadBuffers();
+
+ for (auto& PayloadBuffer : Buffers)
+ {
+ CopyBytes(PayloadBuffer.MutableData(), PayloadBuffer.GetSize());
+ }
+
+ Reader.Finalize();
+}
+
+TEST_CASE("CbPackage.LocalRef")
+{
+ ScopedTemporaryDirectory TempDir;
+
+ auto Path1 = TempDir.Path() / "abcd";
+ auto Path2 = TempDir.Path() / "efgh";
+
+ {
+ IoBuffer Buffer1 = IoBufferBuilder::MakeCloneFromMemory(MakeMemoryView("abcd"));
+ IoBuffer Buffer2 = IoBufferBuilder::MakeCloneFromMemory(MakeMemoryView("efgh"));
+
+ WriteFile(Path1, Buffer1);
+ WriteFile(Path2, Buffer2);
+ }
+
+ // Make a test package
+
+ IoBuffer FileBuffer1 = IoBufferBuilder::MakeFromFile(Path1);
+ IoBuffer FileBuffer2 = IoBufferBuilder::MakeFromFile(Path2);
+
+ CbAttachment Attach1{SharedBuffer(FileBuffer1)};
+ CbAttachment Attach2{SharedBuffer(FileBuffer2)};
+
+ CbObjectWriter Cbo;
+ Cbo.AddAttachment("abcd", Attach1);
+ Cbo.AddAttachment("efgh", Attach2);
+
+ CbPackage Pkg;
+ Pkg.AddAttachment(Attach1);
+ Pkg.AddAttachment(Attach2);
+ Pkg.SetObject(Cbo.Save());
+
+ SharedBuffer Buffer = FormatPackageMessageBuffer(Pkg, FormatFlags::kAllowLocalReferences).Flatten();
+ const uint8_t* CursorPtr = reinterpret_cast<const uint8_t*>(Buffer.GetData());
+ uint64_t RemainingBytes = Buffer.GetSize();
+
+ auto ConsumeBytes = [&](uint64_t ByteCount) {
+ ZEN_ASSERT(ByteCount <= RemainingBytes);
+ void* ReturnPtr = (void*)CursorPtr;
+ CursorPtr += ByteCount;
+ RemainingBytes -= ByteCount;
+ return ReturnPtr;
+ };
+
+ auto CopyBytes = [&](void* TargetBuffer, uint64_t ByteCount) {
+ ZEN_ASSERT(ByteCount <= RemainingBytes);
+ memcpy(TargetBuffer, CursorPtr, ByteCount);
+ CursorPtr += ByteCount;
+ RemainingBytes -= ByteCount;
+ };
+
+ CbPackageReader Reader;
+ uint64_t InitialRead = Reader.ProcessPackageHeaderData(nullptr, 0);
+ uint64_t NextBytes = Reader.ProcessPackageHeaderData(ConsumeBytes(InitialRead), InitialRead);
+ NextBytes = Reader.ProcessPackageHeaderData(ConsumeBytes(NextBytes), NextBytes);
+ auto Buffers = Reader.GetPayloadBuffers();
+
+ for (auto& PayloadBuffer : Buffers)
+ {
+ CopyBytes(PayloadBuffer.MutableData(), PayloadBuffer.GetSize());
+ }
+
+ Reader.Finalize();
+}
+
+void
+forcelink_httpshared()
+{
+}
+
+#endif
+
+} // namespace zen
diff --git a/src/zenhttp/httpsys.cpp b/src/zenhttp/httpsys.cpp
new file mode 100644
index 000000000..c733d618d
--- /dev/null
+++ b/src/zenhttp/httpsys.cpp
@@ -0,0 +1,1674 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "httpsys.h"
+
+#include <zencore/compactbinary.h>
+#include <zencore/compactbinarybuilder.h>
+#include <zencore/compactbinarypackage.h>
+#include <zencore/except.h>
+#include <zencore/logging.h>
+#include <zencore/scopeguard.h>
+#include <zencore/string.h>
+#include <zencore/timer.h>
+#include <zenhttp/httpshared.h>
+
+#if ZEN_WITH_HTTPSYS
+
+# include <conio.h>
+# include <mstcpip.h>
+# pragma comment(lib, "httpapi.lib")
+
+std::wstring
+UTF8_to_UTF16(const char* InPtr)
+{
+ std::wstring OutString;
+ unsigned int Codepoint;
+
+ while (*InPtr != 0)
+ {
+ unsigned char InChar = static_cast<unsigned char>(*InPtr);
+
+ if (InChar <= 0x7f)
+ Codepoint = InChar;
+ else if (InChar <= 0xbf)
+ Codepoint = (Codepoint << 6) | (InChar & 0x3f);
+ else if (InChar <= 0xdf)
+ Codepoint = InChar & 0x1f;
+ else if (InChar <= 0xef)
+ Codepoint = InChar & 0x0f;
+ else
+ Codepoint = InChar & 0x07;
+
+ ++InPtr;
+
+ if (((*InPtr & 0xc0) != 0x80) && (Codepoint <= 0x10ffff))
+ {
+ if (Codepoint > 0xffff)
+ {
+ OutString.append(1, static_cast<wchar_t>(0xd800 + (Codepoint >> 10)));
+ OutString.append(1, static_cast<wchar_t>(0xdc00 + (Codepoint & 0x03ff)));
+ }
+ else if (Codepoint < 0xd800 || Codepoint >= 0xe000)
+ {
+ OutString.append(1, static_cast<wchar_t>(Codepoint));
+ }
+ }
+ }
+
+ return OutString;
+}
+
+namespace zen {
+
+using namespace std::literals;
+
+class HttpSysServer;
+class HttpSysTransaction;
+class HttpMessageResponseRequest;
+
+//////////////////////////////////////////////////////////////////////////
+
+HttpVerb
+TranslateHttpVerb(HTTP_VERB ReqVerb)
+{
+ switch (ReqVerb)
+ {
+ case HttpVerbOPTIONS:
+ return HttpVerb::kOptions;
+
+ case HttpVerbGET:
+ return HttpVerb::kGet;
+
+ case HttpVerbHEAD:
+ return HttpVerb::kHead;
+
+ case HttpVerbPOST:
+ return HttpVerb::kPost;
+
+ case HttpVerbPUT:
+ return HttpVerb::kPut;
+
+ case HttpVerbDELETE:
+ return HttpVerb::kDelete;
+
+ case HttpVerbCOPY:
+ return HttpVerb::kCopy;
+
+ default:
+ // TODO: invalid request?
+ return (HttpVerb)0;
+ }
+}
+
+uint64_t
+GetContentLength(const HTTP_REQUEST* HttpRequest)
+{
+ const HTTP_KNOWN_HEADER& clh = HttpRequest->Headers.KnownHeaders[HttpHeaderContentLength];
+ std::string_view cl(clh.pRawValue, clh.RawValueLength);
+ uint64_t ContentLength = 0;
+ std::from_chars(cl.data(), cl.data() + cl.size(), ContentLength);
+ return ContentLength;
+};
+
+HttpContentType
+GetContentType(const HTTP_REQUEST* HttpRequest)
+{
+ const HTTP_KNOWN_HEADER& CtHdr = HttpRequest->Headers.KnownHeaders[HttpHeaderContentType];
+ return ParseContentType({CtHdr.pRawValue, CtHdr.RawValueLength});
+};
+
+HttpContentType
+GetAcceptType(const HTTP_REQUEST* HttpRequest)
+{
+ const HTTP_KNOWN_HEADER& CtHdr = HttpRequest->Headers.KnownHeaders[HttpHeaderAccept];
+ return ParseContentType({CtHdr.pRawValue, CtHdr.RawValueLength});
+};
+
+/**
+ * @brief Base class for any pending or active HTTP transactions
+ */
+class HttpSysRequestHandler
+{
+public:
+ explicit HttpSysRequestHandler(HttpSysTransaction& Transaction) : m_Transaction(Transaction) {}
+ virtual ~HttpSysRequestHandler() = default;
+
+ virtual void IssueRequest(std::error_code& ErrorCode) = 0;
+ virtual HttpSysRequestHandler* HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) = 0;
+ HttpSysTransaction& Transaction() { return m_Transaction; }
+
+ HttpSysRequestHandler(const HttpSysRequestHandler&) = delete;
+ HttpSysRequestHandler& operator=(const HttpSysRequestHandler&) = delete;
+
+private:
+ HttpSysTransaction& m_Transaction;
+};
+
+/**
+ * This is the handler for the initial HTTP I/O request which will receive the headers
+ * and however much of the remaining payload might fit in the embedded request buffer.
+ *
+ * It is also used to receive any entity body data relating to the request
+ *
+ */
+struct InitialRequestHandler : public HttpSysRequestHandler
+{
+ inline HTTP_REQUEST* HttpRequest() { return (HTTP_REQUEST*)m_RequestBuffer; }
+ inline uint32_t RequestBufferSize() const { return sizeof m_RequestBuffer; }
+ inline bool IsInitialRequest() const { return m_IsInitialRequest; }
+
+ InitialRequestHandler(HttpSysTransaction& InRequest);
+ ~InitialRequestHandler();
+
+ virtual void IssueRequest(std::error_code& ErrorCode) override final;
+ virtual HttpSysRequestHandler* HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) override;
+
+ bool m_IsInitialRequest = true;
+ uint64_t m_CurrentPayloadOffset = 0;
+ uint64_t m_ContentLength = ~uint64_t(0);
+ IoBuffer m_PayloadBuffer;
+ UCHAR m_RequestBuffer[4096 + sizeof(HTTP_REQUEST)];
+};
+
+/**
+ * This is the class which request handlers use to interact with the server instance
+ */
+
+class HttpSysServerRequest : public HttpServerRequest
+{
+public:
+ HttpSysServerRequest(HttpSysTransaction& Tx, HttpService& Service, IoBuffer PayloadBuffer);
+ ~HttpSysServerRequest() = default;
+
+ virtual Oid ParseSessionId() const override;
+ virtual uint32_t ParseRequestId() const override;
+
+ virtual IoBuffer ReadPayload() override;
+ virtual void WriteResponse(HttpResponseCode ResponseCode) override;
+ virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span<IoBuffer> Blobs) override;
+ virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::u8string_view ResponseString) override;
+ virtual void WriteResponseAsync(std::function<void(HttpServerRequest&)>&& ContinuationHandler) override;
+ virtual bool TryGetRanges(HttpRanges& Ranges) override;
+
+ using HttpServerRequest::WriteResponse;
+
+ HttpSysServerRequest(const HttpSysServerRequest&) = delete;
+ HttpSysServerRequest& operator=(const HttpSysServerRequest&) = delete;
+
+ HttpSysTransaction& m_HttpTx;
+ HttpSysRequestHandler* m_NextCompletionHandler = nullptr;
+ IoBuffer m_PayloadBuffer;
+ ExtendableStringBuilder<128> m_UriUtf8;
+ ExtendableStringBuilder<128> m_QueryStringUtf8;
+};
+
+/** HTTP transaction
+
+ There will be an instance of this per pending and in-flight HTTP transaction
+
+ */
+class HttpSysTransaction final
+{
+public:
+ HttpSysTransaction(HttpSysServer& Server);
+ virtual ~HttpSysTransaction();
+
+ enum class Status
+ {
+ kDone,
+ kRequestPending
+ };
+
+ Status HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred);
+
+ static void __stdcall IoCompletionCallback(PTP_CALLBACK_INSTANCE Instance,
+ PVOID pContext /* HttpSysServer */,
+ PVOID pOverlapped,
+ ULONG IoResult,
+ ULONG_PTR NumberOfBytesTransferred,
+ PTP_IO Io);
+
+ void IssueInitialRequest(std::error_code& ErrorCode);
+ bool IssueNextRequest(HttpSysRequestHandler* NewCompletionHandler);
+
+ PTP_IO Iocp();
+ HANDLE RequestQueueHandle();
+ inline OVERLAPPED* Overlapped() { return &m_HttpOverlapped; }
+ inline HttpSysServer& Server() { return m_HttpServer; }
+ inline HTTP_REQUEST* HttpRequest() { return m_InitialHttpHandler.HttpRequest(); }
+
+ HttpSysServerRequest& InvokeRequestHandler(HttpService& Service, IoBuffer Payload);
+
+ HttpSysServerRequest& ServerRequest() { return m_HandlerRequest.value(); }
+
+private:
+ OVERLAPPED m_HttpOverlapped{};
+ HttpSysServer& m_HttpServer;
+
+ // Tracks which handler is due to handle the next I/O completion event
+ HttpSysRequestHandler* m_CompletionHandler = nullptr;
+ RwLock m_CompletionMutex;
+ InitialRequestHandler m_InitialHttpHandler{*this};
+ std::optional<HttpSysServerRequest> m_HandlerRequest;
+ Ref<IHttpPackageHandler> m_PackageHandler;
+};
+
+/**
+ * @brief HTTP request response I/O request handler
+ *
+ * Asynchronously streams out a response to an HTTP request via compound
+ * responses from memory or directly from file
+ */
+
+class HttpMessageResponseRequest : public HttpSysRequestHandler
+{
+public:
+ HttpMessageResponseRequest(HttpSysTransaction& InRequest, uint16_t ResponseCode);
+ HttpMessageResponseRequest(HttpSysTransaction& InRequest, uint16_t ResponseCode, std::string_view Message);
+ HttpMessageResponseRequest(HttpSysTransaction& InRequest,
+ uint16_t ResponseCode,
+ HttpContentType ContentType,
+ const void* Payload,
+ size_t PayloadSize);
+ HttpMessageResponseRequest(HttpSysTransaction& InRequest,
+ uint16_t ResponseCode,
+ HttpContentType ContentType,
+ std::span<IoBuffer> Blobs);
+ ~HttpMessageResponseRequest();
+
+ virtual void IssueRequest(std::error_code& ErrorCode) override final;
+ virtual HttpSysRequestHandler* HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) override;
+ void SuppressResponseBody(); // typically used for HEAD requests
+
+private:
+ std::vector<HTTP_DATA_CHUNK> m_HttpDataChunks;
+ uint64_t m_TotalDataSize = 0; // Sum of all chunk sizes
+ uint16_t m_ResponseCode = 0;
+ uint32_t m_NextDataChunkOffset = 0; // Cursor used for very large chunk lists
+ uint32_t m_RemainingChunkCount = 0; // Backlog for multi-call sends
+ bool m_IsInitialResponse = true;
+ HttpContentType m_ContentType = HttpContentType::kBinary;
+ std::vector<IoBuffer> m_DataBuffers;
+
+ void InitializeForPayload(uint16_t ResponseCode, std::span<IoBuffer> Blobs);
+};
+
+HttpMessageResponseRequest::HttpMessageResponseRequest(HttpSysTransaction& InRequest, uint16_t ResponseCode)
+: HttpSysRequestHandler(InRequest)
+{
+ std::array<IoBuffer, 0> EmptyBufferList;
+
+ InitializeForPayload(ResponseCode, EmptyBufferList);
+}
+
+HttpMessageResponseRequest::HttpMessageResponseRequest(HttpSysTransaction& InRequest, uint16_t ResponseCode, std::string_view Message)
+: HttpSysRequestHandler(InRequest)
+, m_ContentType(HttpContentType::kText)
+{
+ IoBuffer MessageBuffer(IoBuffer::Wrap, Message.data(), Message.size());
+ std::array<IoBuffer, 1> SingleBufferList({MessageBuffer});
+
+ InitializeForPayload(ResponseCode, SingleBufferList);
+}
+
+HttpMessageResponseRequest::HttpMessageResponseRequest(HttpSysTransaction& InRequest,
+ uint16_t ResponseCode,
+ HttpContentType ContentType,
+ const void* Payload,
+ size_t PayloadSize)
+: HttpSysRequestHandler(InRequest)
+, m_ContentType(ContentType)
+{
+ IoBuffer MessageBuffer(IoBuffer::Wrap, Payload, PayloadSize);
+ std::array<IoBuffer, 1> SingleBufferList({MessageBuffer});
+
+ InitializeForPayload(ResponseCode, SingleBufferList);
+}
+
+HttpMessageResponseRequest::HttpMessageResponseRequest(HttpSysTransaction& InRequest,
+ uint16_t ResponseCode,
+ HttpContentType ContentType,
+ std::span<IoBuffer> BlobList)
+: HttpSysRequestHandler(InRequest)
+, m_ContentType(ContentType)
+{
+ InitializeForPayload(ResponseCode, BlobList);
+}
+
+HttpMessageResponseRequest::~HttpMessageResponseRequest()
+{
+}
+
+void
+HttpMessageResponseRequest::InitializeForPayload(uint16_t ResponseCode, std::span<IoBuffer> BlobList)
+{
+ const uint32_t ChunkCount = gsl::narrow<uint32_t>(BlobList.size());
+
+ m_HttpDataChunks.reserve(ChunkCount);
+ m_DataBuffers.reserve(ChunkCount);
+
+ for (IoBuffer& Buffer : BlobList)
+ {
+ m_DataBuffers.emplace_back(std::move(Buffer)).MakeOwned();
+ }
+
+ // Initialize the full array up front
+
+ uint64_t LocalDataSize = 0;
+
+ for (IoBuffer& Buffer : m_DataBuffers)
+ {
+ uint64_t BufferDataSize = Buffer.Size();
+
+ ZEN_ASSERT(BufferDataSize);
+
+ LocalDataSize += BufferDataSize;
+
+ IoBufferFileReference FileRef;
+ if (Buffer.GetFileReference(/* out */ FileRef))
+ {
+ // Use direct file transfer
+
+ m_HttpDataChunks.push_back({});
+ auto& Chunk = m_HttpDataChunks.back();
+
+ Chunk.DataChunkType = HttpDataChunkFromFileHandle;
+ Chunk.FromFileHandle.FileHandle = FileRef.FileHandle;
+ Chunk.FromFileHandle.ByteRange.StartingOffset.QuadPart = FileRef.FileChunkOffset;
+ Chunk.FromFileHandle.ByteRange.Length.QuadPart = BufferDataSize;
+ }
+ else
+ {
+ // Send from memory, need to make sure we chunk the buffer up since
+ // the underlying data structure only accepts 32-bit chunk sizes for
+ // memory chunks. When this happens the vector will be reallocated,
+ // which is fine since this will be a pretty rare case and sending
+ // the data is going to take a lot longer than a memory allocation :)
+
+ const uint8_t* WriteCursor = reinterpret_cast<const uint8_t*>(Buffer.Data());
+
+ while (BufferDataSize)
+ {
+ const ULONG ThisChunkSize = gsl::narrow<ULONG>(zen::Min(1 * 1024 * 1024 * 1024, BufferDataSize));
+
+ m_HttpDataChunks.push_back({});
+ auto& Chunk = m_HttpDataChunks.back();
+
+ Chunk.DataChunkType = HttpDataChunkFromMemory;
+ Chunk.FromMemory.pBuffer = (void*)WriteCursor;
+ Chunk.FromMemory.BufferLength = ThisChunkSize;
+
+ BufferDataSize -= ThisChunkSize;
+ WriteCursor += ThisChunkSize;
+ }
+ }
+ }
+
+ m_RemainingChunkCount = gsl::narrow<uint32_t>(m_HttpDataChunks.size());
+ m_TotalDataSize = LocalDataSize;
+
+ if (m_TotalDataSize == 0 && ResponseCode == 200)
+ {
+ // Some HTTP clients really don't like empty responses unless a 204 response is sent
+ m_ResponseCode = uint16_t(HttpResponseCode::NoContent);
+ }
+ else
+ {
+ m_ResponseCode = ResponseCode;
+ }
+}
+
+void
+HttpMessageResponseRequest::SuppressResponseBody()
+{
+ m_RemainingChunkCount = 0;
+ m_HttpDataChunks.clear();
+ m_DataBuffers.clear();
+}
+
+HttpSysRequestHandler*
+HttpMessageResponseRequest::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred)
+{
+ ZEN_UNUSED(NumberOfBytesTransferred);
+
+ if (IoResult != NO_ERROR)
+ {
+ ZEN_WARN("response aborted due to error: '{}'", GetSystemErrorAsString(IoResult));
+
+ // if one transmit failed there's really no need to go on
+ return nullptr;
+ }
+
+ if (m_RemainingChunkCount == 0)
+ {
+ return nullptr; // All done
+ }
+
+ return this;
+}
+
+void
+HttpMessageResponseRequest::IssueRequest(std::error_code& ErrorCode)
+{
+ HttpSysTransaction& Tx = Transaction();
+ HTTP_REQUEST* const HttpReq = Tx.HttpRequest();
+ PTP_IO const Iocp = Tx.Iocp();
+
+ StartThreadpoolIo(Iocp);
+
+ // Split payload into batches to play well with the underlying API
+
+ const int MaxChunksPerCall = 9999;
+
+ const int ThisRequestChunkCount = std::min<int>(m_RemainingChunkCount, MaxChunksPerCall);
+ const int ThisRequestChunkOffset = m_NextDataChunkOffset;
+
+ m_RemainingChunkCount -= ThisRequestChunkCount;
+ m_NextDataChunkOffset += ThisRequestChunkCount;
+
+ /* Should this code also use HTTP_SEND_RESPONSE_FLAG_BUFFER_DATA?
+
+ From the docs:
+
+ This flag enables buffering of data in the kernel on a per-response basis. It should
+ be used by an application doing synchronous I/O, or by a an application doing
+ asynchronous I/O with no more than one send outstanding at a time.
+
+ Applications using asynchronous I/O which may have more than one send outstanding at
+ a time should not use this flag.
+
+ When this flag is set, it should be used consistently in calls to the
+ HttpSendHttpResponse function as well.
+ */
+
+ ULONG SendFlags = HTTP_SEND_RESPONSE_FLAG_BUFFER_DATA;
+
+ if (m_RemainingChunkCount)
+ {
+ // We need to make more calls to send the full amount of data
+ SendFlags |= HTTP_SEND_RESPONSE_FLAG_MORE_DATA;
+ }
+
+ ULONG SendResult = 0;
+
+ if (m_IsInitialResponse)
+ {
+ // Populate response structure
+
+ HTTP_RESPONSE HttpResponse = {};
+
+ HttpResponse.EntityChunkCount = USHORT(ThisRequestChunkCount);
+ HttpResponse.pEntityChunks = m_HttpDataChunks.data() + ThisRequestChunkOffset;
+
+ // Server header
+ //
+ // By default this will also add a suffix " Microsoft-HTTPAPI/2.0" to this header
+ //
+ // This is controlled via a registry key 'DisableServerHeader', at:
+ //
+ // Computer\HKEY_LOCAL_MACHINE\SYSTEM\CurrentControlSet\Services\HTTP\Parameters
+ //
+ // Set DisableServerHeader to 1 to disable suffix, or 2 to disable the header altogether
+ // (only the latter appears to do anything in my testing, on Windows 10).
+ //
+ // (reference https://docs.microsoft.com/en-us/archive/blogs/dsnotes/wswcf-remove-server-header)
+ //
+
+ PHTTP_KNOWN_HEADER ServerHeader = &HttpResponse.Headers.KnownHeaders[HttpHeaderServer];
+ ServerHeader->pRawValue = "Zen";
+ ServerHeader->RawValueLength = (USHORT)3;
+
+ // Content-length header
+
+ char ContentLengthString[32];
+ _ui64toa_s(m_TotalDataSize, ContentLengthString, sizeof ContentLengthString, 10);
+
+ PHTTP_KNOWN_HEADER ContentLengthHeader = &HttpResponse.Headers.KnownHeaders[HttpHeaderContentLength];
+ ContentLengthHeader->pRawValue = ContentLengthString;
+ ContentLengthHeader->RawValueLength = (USHORT)strlen(ContentLengthString);
+
+ // Content-type header
+
+ PHTTP_KNOWN_HEADER ContentTypeHeader = &HttpResponse.Headers.KnownHeaders[HttpHeaderContentType];
+
+ std::string_view ContentTypeString = MapContentTypeToString(m_ContentType);
+
+ ContentTypeHeader->pRawValue = ContentTypeString.data();
+ ContentTypeHeader->RawValueLength = (USHORT)ContentTypeString.size();
+
+ std::string_view ReasonString = ReasonStringForHttpResultCode(m_ResponseCode);
+
+ HttpResponse.StatusCode = m_ResponseCode;
+ HttpResponse.pReason = ReasonString.data();
+ HttpResponse.ReasonLength = (USHORT)ReasonString.size();
+
+ // Cache policy
+
+ HTTP_CACHE_POLICY CachePolicy;
+
+ CachePolicy.Policy = HttpCachePolicyNocache; // HttpCachePolicyUserInvalidates;
+ CachePolicy.SecondsToLive = 0;
+
+ // Initial response API call
+
+ SendResult = HttpSendHttpResponse(Tx.RequestQueueHandle(),
+ HttpReq->RequestId,
+ SendFlags,
+ &HttpResponse,
+ &CachePolicy,
+ NULL,
+ NULL,
+ 0,
+ Tx.Overlapped(),
+ NULL);
+
+ m_IsInitialResponse = false;
+ }
+ else
+ {
+ // Subsequent response API calls
+
+ SendResult = HttpSendResponseEntityBody(Tx.RequestQueueHandle(),
+ HttpReq->RequestId,
+ SendFlags,
+ (USHORT)ThisRequestChunkCount, // EntityChunkCount
+ &m_HttpDataChunks[ThisRequestChunkOffset], // EntityChunks
+ NULL, // BytesSent
+ NULL, // Reserved1
+ 0, // Reserved2
+ Tx.Overlapped(), // Overlapped
+ NULL // LogData
+ );
+ }
+
+ if (SendResult == NO_ERROR)
+ {
+ // Synchronous completion, but the completion event will still be posted to IOCP
+
+ ErrorCode.clear();
+ }
+ else if (SendResult == ERROR_IO_PENDING)
+ {
+ // Asynchronous completion, a completion notification will be posted to IOCP
+
+ ErrorCode.clear();
+ }
+ else
+ {
+ // An error occurred, no completion will be posted to IOCP
+
+ CancelThreadpoolIo(Iocp);
+
+ ZEN_WARN("failed to send HTTP response (error: '{}'), request URL: '{}', request id: {}",
+ GetSystemErrorAsString(SendResult),
+ HttpReq->pRawUrl,
+ HttpReq->RequestId);
+
+ ErrorCode = MakeErrorCode(SendResult);
+ }
+}
+
+/** HTTP completion handler for async work
+
+ This is used to allow work to be taken off the request handler threads
+ and to support posting responses asynchronously.
+ */
+
+class HttpAsyncWorkRequest : public HttpSysRequestHandler
+{
+public:
+ HttpAsyncWorkRequest(HttpSysTransaction& Tx, std::function<void(HttpServerRequest&)>&& Response);
+ ~HttpAsyncWorkRequest();
+
+ virtual void IssueRequest(std::error_code& ErrorCode) override final;
+ virtual HttpSysRequestHandler* HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) override;
+
+private:
+ struct AsyncWorkItem : public IWork
+ {
+ virtual void Execute() override;
+
+ AsyncWorkItem(HttpSysTransaction& InTx, std::function<void(HttpServerRequest&)>&& InHandler)
+ : Tx(InTx)
+ , Handler(std::move(InHandler))
+ {
+ }
+
+ HttpSysTransaction& Tx;
+ std::function<void(HttpServerRequest&)> Handler;
+ };
+
+ Ref<AsyncWorkItem> m_WorkItem;
+};
+
+HttpAsyncWorkRequest::HttpAsyncWorkRequest(HttpSysTransaction& Tx, std::function<void(HttpServerRequest&)>&& Response)
+: HttpSysRequestHandler(Tx)
+{
+ m_WorkItem = new AsyncWorkItem(Tx, std::move(Response));
+}
+
+HttpAsyncWorkRequest::~HttpAsyncWorkRequest()
+{
+}
+
+void
+HttpAsyncWorkRequest::IssueRequest(std::error_code& ErrorCode)
+{
+ ErrorCode.clear();
+
+ Transaction().Server().WorkPool().ScheduleWork(m_WorkItem);
+}
+
+HttpSysRequestHandler*
+HttpAsyncWorkRequest::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred)
+{
+ // This ought to not be called since there should be no outstanding I/O request
+ // when this completion handler is active
+
+ ZEN_UNUSED(IoResult, NumberOfBytesTransferred);
+
+ ZEN_WARN("Unexpected I/O completion during async work! IoResult: {}, NumberOfBytesTransferred: {}", IoResult, NumberOfBytesTransferred);
+
+ return this;
+}
+
+void
+HttpAsyncWorkRequest::AsyncWorkItem::Execute()
+{
+ try
+ {
+ HttpSysServerRequest& ThisRequest = Tx.ServerRequest();
+
+ ThisRequest.m_NextCompletionHandler = nullptr;
+
+ Handler(ThisRequest);
+
+ // TODO: should Handler be destroyed at this point to ensure there
+ // are no outstanding references into state which could be
+ // deleted asynchronously as a result of issuing the response?
+
+ if (HttpSysRequestHandler* NextHandler = ThisRequest.m_NextCompletionHandler)
+ {
+ return (void)Tx.IssueNextRequest(NextHandler);
+ }
+ else if (!ThisRequest.IsHandled())
+ {
+ return (void)Tx.IssueNextRequest(new HttpMessageResponseRequest(Tx, 404, "Not found"sv));
+ }
+ else
+ {
+ // "Handled" but no request handler? Shouldn't ever happen
+ return (void)Tx.IssueNextRequest(
+ new HttpMessageResponseRequest(Tx, 500, "Response generated but no request handler scheduled"sv));
+ }
+ }
+ catch (std::exception& Ex)
+ {
+ return (void)Tx.IssueNextRequest(
+ new HttpMessageResponseRequest(Tx, 500, fmt::format("Exception thrown in async work: '{}'", Ex.what())));
+ }
+}
+
+/**
+ _________
+ / _____/ ______________ __ ___________
+ \_____ \_/ __ \_ __ \ \/ // __ \_ __ \
+ / \ ___/| | \/\ /\ ___/| | \/
+ /_______ /\___ >__| \_/ \___ >__|
+ \/ \/ \/
+*/
+
+HttpSysServer::HttpSysServer(unsigned int ThreadCount, unsigned int AsyncWorkThreadCount)
+: m_Log(logging::Get("http"))
+, m_RequestLog(logging::Get("http_requests"))
+, m_ThreadPool(ThreadCount)
+, m_AsyncWorkPool(AsyncWorkThreadCount)
+{
+ ULONG Result = HttpInitialize(HTTPAPI_VERSION_2, HTTP_INITIALIZE_SERVER, nullptr);
+
+ if (Result != NO_ERROR)
+ {
+ return;
+ }
+
+ m_IsHttpInitialized = true;
+ m_IsOk = true;
+
+ ZEN_INFO("http.sys server started, using {} I/O threads and {} async worker threads", ThreadCount, AsyncWorkThreadCount);
+}
+
+HttpSysServer::~HttpSysServer()
+{
+ if (m_IsHttpInitialized)
+ {
+ Cleanup();
+
+ HttpTerminate(HTTP_INITIALIZE_SERVER, nullptr);
+ }
+}
+
+int
+HttpSysServer::InitializeServer(int BasePort)
+{
+ using namespace std::literals;
+
+ WideStringBuilder<64> WildcardUrlPath;
+ WildcardUrlPath << u8"http://*:"sv << int64_t(BasePort) << u8"/"sv;
+
+ m_IsOk = false;
+
+ ULONG Result = HttpCreateServerSession(HTTPAPI_VERSION_2, &m_HttpSessionId, 0);
+
+ if (Result != NO_ERROR)
+ {
+ ZEN_ERROR("Failed to create server session for '{}': {:#x}", WideToUtf8(WildcardUrlPath), Result);
+
+ return BasePort;
+ }
+
+ Result = HttpCreateUrlGroup(m_HttpSessionId, &m_HttpUrlGroupId, 0);
+
+ if (Result != NO_ERROR)
+ {
+ ZEN_ERROR("Failed to create URL group for '{}': {:#x}", WideToUtf8(WildcardUrlPath), Result);
+
+ return BasePort;
+ }
+
+ int EffectivePort = BasePort;
+
+ Result = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, WildcardUrlPath.c_str(), HTTP_URL_CONTEXT(0), 0);
+
+ // Sharing violation implies the port is being used by another process
+ for (int PortOffset = 1; (Result == ERROR_SHARING_VIOLATION) && (PortOffset < 10); ++PortOffset)
+ {
+ EffectivePort = BasePort + (PortOffset * 100);
+ WildcardUrlPath.Reset();
+ WildcardUrlPath << u8"http://*:"sv << int64_t(EffectivePort) << u8"/"sv;
+
+ Result = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, WildcardUrlPath.c_str(), HTTP_URL_CONTEXT(0), 0);
+ }
+
+ m_BaseUris.clear();
+ if (Result == NO_ERROR)
+ {
+ m_BaseUris.push_back(WildcardUrlPath.c_str());
+ }
+ else if (Result == ERROR_ACCESS_DENIED)
+ {
+ // If we can't register the wildcard path, we fall back to local paths
+ // This local paths allow requests originating locally to function, but will not allow
+ // remote origin requests to function. This can be remedied by using netsh
+ // during an install process to grant permissions to route public access to the appropriate
+ // port for the current user. eg:
+ // netsh http add urlacl url=http://*:1337/ user=<some_user>
+
+ ZEN_WARN("Unable to register handler using '{}' - falling back to local-only", WideToUtf8(WildcardUrlPath));
+
+ const std::u8string_view Hosts[] = {u8"[::1]"sv, u8"localhost"sv, u8"127.0.0.1"sv};
+
+ ULONG InternalResult = ERROR_SHARING_VIOLATION;
+ for (int PortOffset = 0; (InternalResult == ERROR_SHARING_VIOLATION) && (PortOffset < 10); ++PortOffset)
+ {
+ EffectivePort = BasePort + (PortOffset * 100);
+
+ for (const std::u8string_view Host : Hosts)
+ {
+ WideStringBuilder<64> LocalUrlPath;
+ LocalUrlPath << u8"http://"sv << Host << u8":"sv << int64_t(EffectivePort) << u8"/"sv;
+
+ InternalResult = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, LocalUrlPath.c_str(), HTTP_URL_CONTEXT(0), 0);
+
+ if (InternalResult == NO_ERROR)
+ {
+ ZEN_INFO("Registered local handler '{}'", WideToUtf8(LocalUrlPath));
+
+ m_BaseUris.push_back(LocalUrlPath.c_str());
+ }
+ else
+ {
+ break;
+ }
+ }
+ }
+ }
+
+ if (m_BaseUris.empty())
+ {
+ ZEN_ERROR("Failed to add base URL to URL group for '{}': {:#x}", WideToUtf8(WildcardUrlPath), Result);
+
+ return BasePort;
+ }
+
+ HTTP_BINDING_INFO HttpBindingInfo = {{0}, 0};
+
+ Result = HttpCreateRequestQueue(HTTPAPI_VERSION_2,
+ /* Name */ nullptr,
+ /* SecurityAttributes */ nullptr,
+ /* Flags */ 0,
+ &m_RequestQueueHandle);
+
+ if (Result != NO_ERROR)
+ {
+ ZEN_ERROR("Failed to create request queue for '{}': {:#x}", WideToUtf8(m_BaseUris.front()), Result);
+
+ return EffectivePort;
+ }
+
+ HttpBindingInfo.Flags.Present = 1;
+ HttpBindingInfo.RequestQueueHandle = m_RequestQueueHandle;
+
+ Result = HttpSetUrlGroupProperty(m_HttpUrlGroupId, HttpServerBindingProperty, &HttpBindingInfo, sizeof(HttpBindingInfo));
+
+ if (Result != NO_ERROR)
+ {
+ ZEN_ERROR("Failed to set server binding property for '{}': {:#x}", WideToUtf8(m_BaseUris.front()), Result);
+
+ return EffectivePort;
+ }
+
+ // Create I/O completion port
+
+ std::error_code ErrorCode;
+ m_ThreadPool.CreateIocp(m_RequestQueueHandle, HttpSysTransaction::IoCompletionCallback, /* Context */ this, /* out */ ErrorCode);
+
+ if (ErrorCode)
+ {
+ ZEN_ERROR("Failed to create IOCP for '{}': {}", WideToUtf8(m_BaseUris.front()), ErrorCode.message());
+ }
+ else
+ {
+ m_IsOk = true;
+
+ ZEN_INFO("Started http.sys server at '{}'", WideToUtf8(m_BaseUris.front()));
+ }
+
+ return EffectivePort;
+}
+
+void
+HttpSysServer::Cleanup()
+{
+ ++m_IsShuttingDown;
+
+ if (m_RequestQueueHandle)
+ {
+ HttpCloseRequestQueue(m_RequestQueueHandle);
+ m_RequestQueueHandle = nullptr;
+ }
+
+ if (m_HttpUrlGroupId)
+ {
+ HttpCloseUrlGroup(m_HttpUrlGroupId);
+ m_HttpUrlGroupId = 0;
+ }
+
+ if (m_HttpSessionId)
+ {
+ HttpCloseServerSession(m_HttpSessionId);
+ m_HttpSessionId = 0;
+ }
+}
+
+void
+HttpSysServer::StartServer()
+{
+ const int InitialRequestCount = 32;
+
+ for (int i = 0; i < InitialRequestCount; ++i)
+ {
+ IssueNewRequestMaybe();
+ }
+}
+
+void
+HttpSysServer::Run(bool IsInteractive)
+{
+ if (IsInteractive)
+ {
+ zen::logging::ConsoleLog().info("Zen Server running. Press ESC or Q to quit");
+ }
+
+ do
+ {
+ // int WaitTimeout = -1;
+ int WaitTimeout = 100;
+
+ if (IsInteractive)
+ {
+ WaitTimeout = 1000;
+
+ if (_kbhit() != 0)
+ {
+ char c = (char)_getch();
+
+ if (c == 27 || c == 'Q' || c == 'q')
+ {
+ RequestApplicationExit(0);
+ }
+ }
+ }
+
+ m_ShutdownEvent.Wait(WaitTimeout);
+ UpdateLofreqTimerValue();
+ } while (!IsApplicationExitRequested());
+}
+
+void
+HttpSysServer::OnHandlingRequest()
+{
+ if (--m_PendingRequests > m_MinPendingRequests)
+ {
+ // We have more than the minimum number of requests pending, just let someone else
+ // enqueue new requests
+ return;
+ }
+
+ IssueNewRequestMaybe();
+}
+
+void
+HttpSysServer::IssueNewRequestMaybe()
+{
+ if (m_IsShuttingDown.load(std::memory_order::acquire))
+ {
+ return;
+ }
+
+ if (m_PendingRequests.load(std::memory_order::relaxed) >= m_MaxPendingRequests)
+ {
+ return;
+ }
+
+ std::unique_ptr<HttpSysTransaction> Request = std::make_unique<HttpSysTransaction>(*this);
+
+ std::error_code ErrorCode;
+ Request->IssueInitialRequest(ErrorCode);
+
+ if (ErrorCode)
+ {
+ // No request was actually issued. What is the appropriate response?
+
+ return;
+ }
+
+ // This may end up exceeding the MaxPendingRequests limit, but it's not
+ // really a problem. I'm doing it this way mostly to avoid dealing with
+ // exceptions here
+ ++m_PendingRequests;
+
+ Request.release();
+}
+
+void
+HttpSysServer::RegisterService(const char* UrlPath, HttpService& Service)
+{
+ if (UrlPath[0] == '/')
+ {
+ ++UrlPath;
+ }
+
+ const std::wstring PathUtf16 = UTF8_to_UTF16(UrlPath);
+ Service.SetUriPrefixLength(PathUtf16.size() + 1 /* leading slash */);
+
+ // Convert to wide string
+
+ for (const std::wstring& BaseUri : m_BaseUris)
+ {
+ std::wstring Url16 = BaseUri + PathUtf16;
+
+ ULONG Result = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, Url16.c_str(), HTTP_URL_CONTEXT(&Service), 0 /* Reserved */);
+
+ if (Result != NO_ERROR)
+ {
+ ZEN_ERROR("HttpAddUrlToUrlGroup failed with result: '{}'", GetSystemErrorAsString(Result));
+
+ return;
+ }
+ }
+}
+
+void
+HttpSysServer::UnregisterService(const char* UrlPath, HttpService& Service)
+{
+ ZEN_UNUSED(Service);
+
+ if (UrlPath[0] == '/')
+ {
+ ++UrlPath;
+ }
+
+ const std::wstring PathUtf16 = UTF8_to_UTF16(UrlPath);
+
+ // Convert to wide string
+
+ for (const std::wstring& BaseUri : m_BaseUris)
+ {
+ std::wstring Url16 = BaseUri + PathUtf16;
+
+ ULONG Result = HttpRemoveUrlFromUrlGroup(m_HttpUrlGroupId, Url16.c_str(), 0);
+
+ if (Result != NO_ERROR)
+ {
+ ZEN_ERROR("HttpRemoveUrlFromUrlGroup failed with result: '{}'", GetSystemErrorAsString(Result));
+ }
+ }
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+HttpSysTransaction::HttpSysTransaction(HttpSysServer& Server) : m_HttpServer(Server), m_CompletionHandler(&m_InitialHttpHandler)
+{
+}
+
+HttpSysTransaction::~HttpSysTransaction()
+{
+}
+
+PTP_IO
+HttpSysTransaction::Iocp()
+{
+ return m_HttpServer.m_ThreadPool.Iocp();
+}
+
+HANDLE
+HttpSysTransaction::RequestQueueHandle()
+{
+ return m_HttpServer.m_RequestQueueHandle;
+}
+
+void
+HttpSysTransaction::IssueInitialRequest(std::error_code& ErrorCode)
+{
+ m_InitialHttpHandler.IssueRequest(ErrorCode);
+}
+
+void
+HttpSysTransaction::IoCompletionCallback(PTP_CALLBACK_INSTANCE Instance,
+ PVOID pContext /* HttpSysServer */,
+ PVOID pOverlapped,
+ ULONG IoResult,
+ ULONG_PTR NumberOfBytesTransferred,
+ PTP_IO Io)
+{
+ UNREFERENCED_PARAMETER(Io);
+ UNREFERENCED_PARAMETER(Instance);
+ UNREFERENCED_PARAMETER(pContext);
+
+ // Note that for a given transaction we may be in this completion function on more
+ // than one thread at any given moment. This means we need to be careful about what
+ // happens in here
+
+ HttpSysTransaction* Transaction = CONTAINING_RECORD(pOverlapped, HttpSysTransaction, m_HttpOverlapped);
+
+ if (Transaction->HandleCompletion(IoResult, NumberOfBytesTransferred) == HttpSysTransaction::Status::kDone)
+ {
+ delete Transaction;
+ }
+}
+
+bool
+HttpSysTransaction::IssueNextRequest(HttpSysRequestHandler* NewCompletionHandler)
+{
+ HttpSysRequestHandler* CurrentHandler = m_CompletionHandler;
+ m_CompletionHandler = NewCompletionHandler;
+
+ auto _ = MakeGuard([this, CurrentHandler] {
+ if ((CurrentHandler != &m_InitialHttpHandler) && (CurrentHandler != m_CompletionHandler))
+ {
+ delete CurrentHandler;
+ }
+ });
+
+ if (NewCompletionHandler == nullptr)
+ {
+ return false;
+ }
+
+ try
+ {
+ std::error_code ErrorCode;
+ m_CompletionHandler->IssueRequest(ErrorCode);
+
+ if (!ErrorCode)
+ {
+ return true;
+ }
+
+ ZEN_WARN("IssueRequest() failed: '{}'", ErrorCode.message());
+ }
+ catch (std::exception& Ex)
+ {
+ ZEN_ERROR("exception caught in IssueNextRequest(): '{}'", Ex.what());
+ }
+
+ // something went wrong, no request is pending
+ m_CompletionHandler = nullptr;
+
+ return false;
+}
+
+HttpSysTransaction::Status
+HttpSysTransaction::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred)
+{
+ // We use this to ensure sequential execution of completion handlers
+ // for any given transaction. It also ensures all member variables are
+ // in a consistent state for the current thread
+
+ RwLock::ExclusiveLockScope _(m_CompletionMutex);
+
+ bool IsRequestPending = false;
+
+ if (HttpSysRequestHandler* CurrentHandler = m_CompletionHandler)
+ {
+ if ((CurrentHandler == &m_InitialHttpHandler) && m_InitialHttpHandler.IsInitialRequest())
+ {
+ // Ensure we have a sufficient number of pending requests outstanding
+ m_HttpServer.OnHandlingRequest();
+ }
+
+ auto NewCompletionHandler = CurrentHandler->HandleCompletion(IoResult, NumberOfBytesTransferred);
+
+ IsRequestPending = IssueNextRequest(NewCompletionHandler);
+ }
+
+ // Ensure new requests are enqueued as necessary
+ m_HttpServer.IssueNewRequestMaybe();
+
+ if (IsRequestPending)
+ {
+ // There is another request pending on this transaction, so it needs to remain valid
+ return Status::kRequestPending;
+ }
+
+ if (m_HttpServer.m_IsRequestLoggingEnabled)
+ {
+ if (m_HandlerRequest.has_value())
+ {
+ m_HttpServer.m_RequestLog.info("{} {}", ToString(m_HandlerRequest->RequestVerb()), m_HandlerRequest->RelativeUri());
+ }
+ }
+
+ // Transaction done, caller should clean up (delete) this instance
+ return Status::kDone;
+}
+
+HttpSysServerRequest&
+HttpSysTransaction::InvokeRequestHandler(HttpService& Service, IoBuffer Payload)
+{
+ HttpSysServerRequest& ThisRequest = m_HandlerRequest.emplace(*this, Service, Payload);
+
+ // Default request handling
+
+ if (!HandlePackageOffers(Service, ThisRequest, m_PackageHandler))
+ {
+ Service.HandleRequest(ThisRequest);
+ }
+
+ return ThisRequest;
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+HttpSysServerRequest::HttpSysServerRequest(HttpSysTransaction& Tx, HttpService& Service, IoBuffer PayloadBuffer)
+: m_HttpTx(Tx)
+, m_PayloadBuffer(std::move(PayloadBuffer))
+{
+ const HTTP_REQUEST* HttpRequestPtr = Tx.HttpRequest();
+
+ const int PrefixLength = Service.UriPrefixLength();
+ const int AbsPathLength = HttpRequestPtr->CookedUrl.AbsPathLength / sizeof(wchar_t);
+
+ HttpContentType AcceptContentType = HttpContentType::kUnknownContentType;
+
+ if (AbsPathLength >= PrefixLength)
+ {
+ // We convert the URI immediately because most of the code involved prefers to deal
+ // with utf8. This is overhead which I'd prefer to avoid but for now we just have
+ // to live with it
+
+ WideToUtf8({(wchar_t*)HttpRequestPtr->CookedUrl.pAbsPath + PrefixLength, gsl::narrow<size_t>(AbsPathLength - PrefixLength)},
+ m_UriUtf8);
+
+ std::string_view UriSuffix8{m_UriUtf8};
+
+ m_UriWithExtension = UriSuffix8; // Retain URI with extension for user access
+ m_Uri = UriSuffix8;
+
+ const size_t LastComponentIndex = UriSuffix8.find_last_of('/');
+
+ if (LastComponentIndex != std::string_view::npos)
+ {
+ UriSuffix8.remove_prefix(LastComponentIndex);
+ }
+
+ const size_t LastDotIndex = UriSuffix8.find_last_of('.');
+
+ if (LastDotIndex != std::string_view::npos)
+ {
+ UriSuffix8.remove_prefix(LastDotIndex + 1);
+
+ AcceptContentType = ParseContentType(UriSuffix8);
+ if (AcceptContentType != HttpContentType::kUnknownContentType)
+ {
+ m_Uri.remove_suffix(UriSuffix8.size() + 1);
+ }
+ }
+ }
+ else
+ {
+ m_UriUtf8.Reset();
+ m_Uri = {};
+ m_UriWithExtension = {};
+ }
+
+ if (uint16_t QueryStringLength = HttpRequestPtr->CookedUrl.QueryStringLength)
+ {
+ --QueryStringLength; // We skip the leading question mark
+
+ WideToUtf8({(wchar_t*)(HttpRequestPtr->CookedUrl.pQueryString) + 1, QueryStringLength / sizeof(wchar_t)}, m_QueryStringUtf8);
+ }
+ else
+ {
+ m_QueryStringUtf8.Reset();
+ }
+
+ m_QueryString = std::string_view(m_QueryStringUtf8);
+ m_Verb = TranslateHttpVerb(HttpRequestPtr->Verb);
+ m_ContentLength = GetContentLength(HttpRequestPtr);
+ m_ContentType = GetContentType(HttpRequestPtr);
+
+ // It an explicit content type extension was specified then we'll use that over any
+ // Accept: header value that may be present
+
+ if (AcceptContentType != HttpContentType::kUnknownContentType)
+ {
+ m_AcceptType = AcceptContentType;
+ }
+ else
+ {
+ m_AcceptType = GetAcceptType(HttpRequestPtr);
+ }
+
+ if (m_Verb == HttpVerb::kHead)
+ {
+ SetSuppressResponseBody();
+ }
+}
+
+Oid
+HttpSysServerRequest::ParseSessionId() const
+{
+ const HTTP_REQUEST* HttpRequestPtr = m_HttpTx.HttpRequest();
+
+ for (int i = 0; i < HttpRequestPtr->Headers.UnknownHeaderCount; ++i)
+ {
+ HTTP_UNKNOWN_HEADER& Header = HttpRequestPtr->Headers.pUnknownHeaders[i];
+ std::string_view HeaderName{Header.pName, Header.NameLength};
+
+ if (HeaderName == "UE-Session"sv)
+ {
+ if (Header.RawValueLength == Oid::StringLength)
+ {
+ return Oid::FromHexString({Header.pRawValue, Header.RawValueLength});
+ }
+ }
+ }
+
+ return {};
+}
+
+uint32_t
+HttpSysServerRequest::ParseRequestId() const
+{
+ const HTTP_REQUEST* HttpRequestPtr = m_HttpTx.HttpRequest();
+
+ for (int i = 0; i < HttpRequestPtr->Headers.UnknownHeaderCount; ++i)
+ {
+ HTTP_UNKNOWN_HEADER& Header = HttpRequestPtr->Headers.pUnknownHeaders[i];
+ std::string_view HeaderName{Header.pName, Header.NameLength};
+
+ if (HeaderName == "UE-Request"sv)
+ {
+ std::string_view RequestValue{Header.pRawValue, Header.RawValueLength};
+ uint32_t RequestId = 0;
+ std::from_chars(RequestValue.data(), RequestValue.data() + RequestValue.size(), RequestId);
+ return RequestId;
+ }
+ }
+
+ return 0;
+}
+
+IoBuffer
+HttpSysServerRequest::ReadPayload()
+{
+ return m_PayloadBuffer;
+}
+
+void
+HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode)
+{
+ ZEN_ASSERT(IsHandled() == false);
+
+ auto Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)ResponseCode);
+
+ if (SuppressBody())
+ {
+ Response->SuppressResponseBody();
+ }
+
+ m_NextCompletionHandler = Response;
+
+ SetIsHandled();
+}
+
+void
+HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span<IoBuffer> Blobs)
+{
+ ZEN_ASSERT(IsHandled() == false);
+
+ auto Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)ResponseCode, ContentType, Blobs);
+
+ if (SuppressBody())
+ {
+ Response->SuppressResponseBody();
+ }
+
+ m_NextCompletionHandler = Response;
+
+ SetIsHandled();
+}
+
+void
+HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::u8string_view ResponseString)
+{
+ ZEN_ASSERT(IsHandled() == false);
+
+ auto Response =
+ new HttpMessageResponseRequest(m_HttpTx, (uint16_t)ResponseCode, ContentType, ResponseString.data(), ResponseString.size());
+
+ if (SuppressBody())
+ {
+ Response->SuppressResponseBody();
+ }
+
+ m_NextCompletionHandler = Response;
+
+ SetIsHandled();
+}
+
+void
+HttpSysServerRequest::WriteResponseAsync(std::function<void(HttpServerRequest&)>&& ContinuationHandler)
+{
+ if (m_HttpTx.Server().IsAsyncResponseEnabled())
+ {
+ m_NextCompletionHandler = new HttpAsyncWorkRequest(m_HttpTx, std::move(ContinuationHandler));
+ }
+ else
+ {
+ ContinuationHandler(m_HttpTx.ServerRequest());
+ }
+}
+
+bool
+HttpSysServerRequest::TryGetRanges(HttpRanges& Ranges)
+{
+ HTTP_REQUEST* Req = m_HttpTx.HttpRequest();
+ const HTTP_KNOWN_HEADER& RangeHeader = Req->Headers.KnownHeaders[HttpHeaderRange];
+
+ return TryParseHttpRangeHeader({RangeHeader.pRawValue, RangeHeader.RawValueLength}, Ranges);
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+InitialRequestHandler::InitialRequestHandler(HttpSysTransaction& InRequest) : HttpSysRequestHandler(InRequest)
+{
+}
+
+InitialRequestHandler::~InitialRequestHandler()
+{
+}
+
+void
+InitialRequestHandler::IssueRequest(std::error_code& ErrorCode)
+{
+ HttpSysTransaction& Tx = Transaction();
+ PTP_IO Iocp = Tx.Iocp();
+ HTTP_REQUEST* HttpReq = Tx.HttpRequest();
+
+ StartThreadpoolIo(Iocp);
+
+ ULONG HttpApiResult;
+
+ if (IsInitialRequest())
+ {
+ HttpApiResult = HttpReceiveHttpRequest(Tx.RequestQueueHandle(),
+ HTTP_NULL_ID,
+ HTTP_RECEIVE_REQUEST_FLAG_COPY_BODY,
+ HttpReq,
+ RequestBufferSize(),
+ NULL,
+ Tx.Overlapped());
+ }
+ else
+ {
+ // The http.sys team recommends limiting the size to 128KB
+ static const uint64_t kMaxBytesPerApiCall = 128 * 1024;
+
+ uint64_t BytesToRead = m_ContentLength - m_CurrentPayloadOffset;
+ const uint64_t BytesToReadThisCall = zen::Min(BytesToRead, kMaxBytesPerApiCall);
+ void* BufferWriteCursor = reinterpret_cast<uint8_t*>(m_PayloadBuffer.MutableData()) + m_CurrentPayloadOffset;
+
+ HttpApiResult = HttpReceiveRequestEntityBody(Tx.RequestQueueHandle(),
+ HttpReq->RequestId,
+ 0, /* Flags */
+ BufferWriteCursor,
+ gsl::narrow<ULONG>(BytesToReadThisCall),
+ nullptr, // BytesReturned
+ Tx.Overlapped());
+ }
+
+ if (HttpApiResult != ERROR_IO_PENDING && HttpApiResult != NO_ERROR)
+ {
+ CancelThreadpoolIo(Iocp);
+
+ ErrorCode = MakeErrorCode(HttpApiResult);
+
+ ZEN_WARN("HttpReceiveHttpRequest failed, error: '{}'", ErrorCode.message());
+
+ return;
+ }
+
+ ErrorCode.clear();
+}
+
+HttpSysRequestHandler*
+InitialRequestHandler::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred)
+{
+ auto _ = MakeGuard([&] { m_IsInitialRequest = false; });
+
+ switch (IoResult)
+ {
+ default:
+ case ERROR_OPERATION_ABORTED:
+ return nullptr;
+
+ case ERROR_MORE_DATA: // Insufficient buffer space
+ case NO_ERROR:
+ break;
+ }
+
+ // Route request
+
+ try
+ {
+ HTTP_REQUEST* HttpReq = HttpRequest();
+
+# if 0
+ for (int i = 0; i < HttpReq->RequestInfoCount; ++i)
+ {
+ auto& ReqInfo = HttpReq->pRequestInfo[i];
+
+ switch (ReqInfo.InfoType)
+ {
+ case HttpRequestInfoTypeRequestTiming:
+ {
+ const HTTP_REQUEST_TIMING_INFO* TimingInfo = reinterpret_cast<HTTP_REQUEST_TIMING_INFO*>(ReqInfo.pInfo);
+
+ ZEN_INFO("");
+ }
+ break;
+ case HttpRequestInfoTypeAuth:
+ ZEN_INFO("");
+ break;
+ case HttpRequestInfoTypeChannelBind:
+ ZEN_INFO("");
+ break;
+ case HttpRequestInfoTypeSslProtocol:
+ ZEN_INFO("");
+ break;
+ case HttpRequestInfoTypeSslTokenBindingDraft:
+ ZEN_INFO("");
+ break;
+ case HttpRequestInfoTypeSslTokenBinding:
+ ZEN_INFO("");
+ break;
+ case HttpRequestInfoTypeTcpInfoV0:
+ {
+ const TCP_INFO_v0* TcpInfo = reinterpret_cast<const TCP_INFO_v0*>(ReqInfo.pInfo);
+
+ ZEN_INFO("");
+ }
+ break;
+ case HttpRequestInfoTypeRequestSizing:
+ {
+ const HTTP_REQUEST_SIZING_INFO* SizingInfo = reinterpret_cast<const HTTP_REQUEST_SIZING_INFO*>(ReqInfo.pInfo);
+ ZEN_INFO("");
+ }
+ break;
+ case HttpRequestInfoTypeQuicStats:
+ ZEN_INFO("");
+ break;
+ case HttpRequestInfoTypeTcpInfoV1:
+ {
+ const TCP_INFO_v1* TcpInfo = reinterpret_cast<const TCP_INFO_v1*>(ReqInfo.pInfo);
+
+ ZEN_INFO("");
+ }
+ break;
+ }
+ }
+# endif
+
+ if (HttpService* Service = reinterpret_cast<HttpService*>(HttpReq->UrlContext))
+ {
+ if (m_IsInitialRequest)
+ {
+ m_ContentLength = GetContentLength(HttpReq);
+ const HttpContentType ContentType = GetContentType(HttpReq);
+
+ if (m_ContentLength)
+ {
+ // Handle initial chunk read by copying any payload which has already been copied
+ // into our embedded request buffer
+
+ m_PayloadBuffer = IoBuffer(m_ContentLength);
+ m_PayloadBuffer.SetContentType(ContentType);
+
+ uint64_t BytesToRead = m_ContentLength;
+ uint8_t* const BufferBase = reinterpret_cast<uint8_t*>(m_PayloadBuffer.MutableData());
+ uint8_t* BufferWriteCursor = BufferBase;
+
+ const int EntityChunkCount = HttpReq->EntityChunkCount;
+
+ for (int i = 0; i < EntityChunkCount; ++i)
+ {
+ HTTP_DATA_CHUNK& EntityChunk = HttpReq->pEntityChunks[i];
+
+ ZEN_ASSERT(EntityChunk.DataChunkType == HttpDataChunkFromMemory);
+
+ const uint64_t BufferLength = EntityChunk.FromMemory.BufferLength;
+
+ ZEN_ASSERT(BufferLength <= BytesToRead);
+
+ memcpy(BufferWriteCursor, EntityChunk.FromMemory.pBuffer, BufferLength);
+
+ BufferWriteCursor += BufferLength;
+ BytesToRead -= BufferLength;
+ }
+
+ m_CurrentPayloadOffset = BufferWriteCursor - BufferBase;
+ }
+ }
+ else
+ {
+ m_CurrentPayloadOffset += NumberOfBytesTransferred;
+ }
+
+ if (m_CurrentPayloadOffset != m_ContentLength)
+ {
+ // Body not complete, issue another read request to receive more body data
+ return this;
+ }
+
+ // Request body received completely
+
+ m_PayloadBuffer.MakeImmutable();
+
+ HttpSysServerRequest& ThisRequest = Transaction().InvokeRequestHandler(*Service, m_PayloadBuffer);
+
+ if (HttpSysRequestHandler* Response = ThisRequest.m_NextCompletionHandler)
+ {
+ return Response;
+ }
+
+ if (!ThisRequest.IsHandled())
+ {
+ return new HttpMessageResponseRequest(Transaction(), 404, "Not found"sv);
+ }
+ }
+
+ // Unable to route
+ return new HttpMessageResponseRequest(Transaction(), 404, "No suitable route found"sv);
+ }
+ catch (std::exception& ex)
+ {
+ ZEN_ERROR("Caught exception while handling request: '{}'", ex.what());
+
+ return new HttpMessageResponseRequest(Transaction(), 500, ex.what());
+ }
+}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// HttpServer interface implementation
+//
+
+int
+HttpSysServer::Initialize(int BasePort)
+{
+ int EffectivePort = InitializeServer(BasePort);
+ StartServer();
+ return EffectivePort;
+}
+
+void
+HttpSysServer::RequestExit()
+{
+ m_ShutdownEvent.Set();
+}
+void
+HttpSysServer::RegisterService(HttpService& Service)
+{
+ RegisterService(Service.BaseUri(), Service);
+}
+
+Ref<HttpServer>
+CreateHttpSysServer(int Concurrency, int BackgroundWorkerThreads)
+{
+ return Ref<HttpServer>(new HttpSysServer(Concurrency, BackgroundWorkerThreads));
+}
+
+} // namespace zen
+#endif
diff --git a/src/zenhttp/httpsys.h b/src/zenhttp/httpsys.h
new file mode 100644
index 000000000..d6bd34890
--- /dev/null
+++ b/src/zenhttp/httpsys.h
@@ -0,0 +1,90 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zenhttp/httpserver.h>
+
+#ifndef ZEN_WITH_HTTPSYS
+# if ZEN_PLATFORM_WINDOWS
+# define ZEN_WITH_HTTPSYS 1
+# else
+# define ZEN_WITH_HTTPSYS 0
+# endif
+#endif
+
+#if ZEN_WITH_HTTPSYS
+# define _WINSOCKAPI_
+# include <zencore/windows.h>
+# include <zencore/workthreadpool.h>
+# include "iothreadpool.h"
+
+# include <http.h>
+
+namespace spdlog {
+class logger;
+}
+
+namespace zen {
+
+/**
+ * @brief Windows implementation of HTTP server based on http.sys
+ *
+ * This requires elevation to function
+ */
+class HttpSysServer : public HttpServer
+{
+ friend class HttpSysTransaction;
+
+public:
+ explicit HttpSysServer(unsigned int ThreadCount, unsigned int AsyncWorkThreadCount);
+ ~HttpSysServer();
+
+ // HttpServer interface implementation
+
+ virtual int Initialize(int BasePort) override;
+ virtual void Run(bool TestMode) override;
+ virtual void RequestExit() override;
+ virtual void RegisterService(HttpService& Service) override;
+
+ WorkerThreadPool& WorkPool() { return m_AsyncWorkPool; }
+
+ inline bool IsOk() const { return m_IsOk; }
+ inline bool IsAsyncResponseEnabled() const { return m_IsAsyncResponseEnabled; }
+
+private:
+ int InitializeServer(int BasePort);
+ void Cleanup();
+
+ void StartServer();
+ void OnHandlingRequest();
+ void IssueNewRequestMaybe();
+
+ void RegisterService(const char* Endpoint, HttpService& Service);
+ void UnregisterService(const char* Endpoint, HttpService& Service);
+
+private:
+ spdlog::logger& m_Log;
+ spdlog::logger& m_RequestLog;
+ spdlog::logger& Log() { return m_Log; }
+
+ bool m_IsOk = false;
+ bool m_IsHttpInitialized = false;
+ bool m_IsRequestLoggingEnabled = false;
+ bool m_IsAsyncResponseEnabled = true;
+
+ WinIoThreadPool m_ThreadPool;
+ WorkerThreadPool m_AsyncWorkPool;
+
+ std::vector<std::wstring> m_BaseUris; // eg: http://*:nnnn/
+ HTTP_SERVER_SESSION_ID m_HttpSessionId = 0;
+ HTTP_URL_GROUP_ID m_HttpUrlGroupId = 0;
+ HANDLE m_RequestQueueHandle = 0;
+ std::atomic_int32_t m_PendingRequests{0};
+ std::atomic_int32_t m_IsShuttingDown{0};
+ int32_t m_MinPendingRequests = 16;
+ int32_t m_MaxPendingRequests = 128;
+ Event m_ShutdownEvent;
+};
+
+} // namespace zen
+#endif
diff --git a/src/zenhttp/include/zenhttp/httpclient.h b/src/zenhttp/include/zenhttp/httpclient.h
new file mode 100644
index 000000000..8316a9b9f
--- /dev/null
+++ b/src/zenhttp/include/zenhttp/httpclient.h
@@ -0,0 +1,47 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include "zenhttp.h"
+
+#include <zencore/iobuffer.h>
+#include <zencore/uid.h>
+#include <zenhttp/httpcommon.h>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <cpr/cpr.h>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+namespace zen {
+
+class CbPackage;
+
+/** HTTP client implementation for Zen use cases
+
+ Currently simple and synchronous, should become lean and asynchronous
+ */
+class HttpClient
+{
+public:
+ HttpClient(std::string_view BaseUri);
+ ~HttpClient();
+
+ struct Response
+ {
+ int StatusCode = 0;
+ IoBuffer ResponsePayload; // Note: this also includes the content type
+ };
+
+ [[nodiscard]] Response Put(std::string_view Url, IoBuffer Payload);
+ [[nodiscard]] Response Get(std::string_view Url);
+ [[nodiscard]] Response TransactPackage(std::string_view Url, CbPackage Package);
+ [[nodiscard]] Response Delete(std::string_view Url);
+
+private:
+ std::string m_BaseUri;
+ std::string m_SessionId;
+};
+
+} // namespace zen
+
+void httpclient_forcelink(); // internal
diff --git a/src/zenhttp/include/zenhttp/httpcommon.h b/src/zenhttp/include/zenhttp/httpcommon.h
new file mode 100644
index 000000000..19fda8db4
--- /dev/null
+++ b/src/zenhttp/include/zenhttp/httpcommon.h
@@ -0,0 +1,181 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/iobuffer.h>
+
+#include <string_view>
+
+#include <gsl/gsl-lite.hpp>
+
+namespace zen {
+
+using HttpContentType = ZenContentType;
+
+class IoBuffer;
+class CbObject;
+class CbPackage;
+class StringBuilderBase;
+
+struct HttpRange
+{
+ uint32_t Start = ~uint32_t(0);
+ uint32_t End = ~uint32_t(0);
+};
+
+using HttpRanges = std::vector<HttpRange>;
+
+std::string_view MapContentTypeToString(HttpContentType ContentType);
+extern HttpContentType (*ParseContentType)(const std::string_view& ContentTypeString);
+std::string_view ReasonStringForHttpResultCode(int HttpCode);
+bool TryParseHttpRangeHeader(std::string_view RangeHeader, HttpRanges& Ranges);
+
+[[nodiscard]] inline bool
+IsHttpSuccessCode(int HttpCode)
+{
+ return (HttpCode >= 200) && (HttpCode < 300);
+}
+
+enum class HttpVerb : uint8_t
+{
+ kGet = 1 << 0,
+ kPut = 1 << 1,
+ kPost = 1 << 2,
+ kDelete = 1 << 3,
+ kHead = 1 << 4,
+ kCopy = 1 << 5,
+ kOptions = 1 << 6
+};
+
+gsl_DEFINE_ENUM_BITMASK_OPERATORS(HttpVerb);
+
+const std::string_view ToString(HttpVerb Verb);
+
+enum class HttpResponseCode
+{
+ // 1xx - Informational
+
+ Continue = 100, //!< Indicates that the initial part of a request has been received and has not yet been rejected by the server.
+ SwitchingProtocols = 101, //!< Indicates that the server understands and is willing to comply with the client's request, via the
+ //!< Upgrade header field, for a change in the application protocol being used on this connection.
+ Processing = 102, //!< Is an interim response used to inform the client that the server has accepted the complete request, but has not
+ //!< yet completed it.
+ EarlyHints = 103, //!< Indicates to the client that the server is likely to send a final response with the header fields included in
+ //!< the informational response.
+
+ // 2xx - Successful
+
+ OK = 200, //!< Indicates that the request has succeeded.
+ Created = 201, //!< Indicates that the request has been fulfilled and has resulted in one or more new resources being created.
+ Accepted = 202, //!< Indicates that the request has been accepted for processing, but the processing has not been completed.
+ NonAuthoritativeInformation = 203, //!< Indicates that the request was successful but the enclosed payload has been modified from that
+ //!< of the origin server's 200 (OK) response by a transforming proxy.
+ NoContent = 204, //!< Indicates that the server has successfully fulfilled the request and that there is no additional content to send
+ //!< in the response payload body.
+ ResetContent = 205, //!< Indicates that the server has fulfilled the request and desires that the user agent reset the \"document
+ //!< view\", which caused the request to be sent, to its original state as received from the origin server.
+ PartialContent = 206, //!< Indicates that the server is successfully fulfilling a range request for the target resource by transferring
+ //!< one or more parts of the selected representation that correspond to the satisfiable ranges found in the
+ //!< requests's Range header field.
+ MultiStatus = 207, //!< Provides status for multiple independent operations.
+ AlreadyReported = 208, //!< Used inside a DAV:propstat response element to avoid enumerating the internal members of multiple bindings
+ //!< to the same collection repeatedly. [RFC 5842]
+ IMUsed = 226, //!< The server has fulfilled a GET request for the resource, and the response is a representation of the result of one
+ //!< or more instance-manipulations applied to the current instance.
+
+ // 3xx - Redirection
+
+ MultipleChoices = 300, //!< Indicates that the target resource has more than one representation, each with its own more specific
+ //!< identifier, and information about the alternatives is being provided so that the user (or user agent) can
+ //!< select a preferred representation by redirecting its request to one or more of those identifiers.
+ MovedPermanently = 301, //!< Indicates that the target resource has been assigned a new permanent URI and any future references to this
+ //!< resource ought to use one of the enclosed URIs.
+ Found = 302, //!< Indicates that the target resource resides temporarily under a different URI.
+ SeeOther = 303, //!< Indicates that the server is redirecting the user agent to a different resource, as indicated by a URI in the
+ //!< Location header field, that is intended to provide an indirect response to the original request.
+ NotModified = 304, //!< Indicates that a conditional GET request has been received and would have resulted in a 200 (OK) response if it
+ //!< were not for the fact that the condition has evaluated to false.
+ UseProxy = 305, //!< \deprecated \parblock Due to security concerns regarding in-band configuration of a proxy. \endparblock
+ //!< The requested resource MUST be accessed through the proxy given by the Location field.
+ TemporaryRedirect = 307, //!< Indicates that the target resource resides temporarily under a different URI and the user agent MUST NOT
+ //!< change the request method if it performs an automatic redirection to that URI.
+ PermanentRedirect = 308, //!< The target resource has been assigned a new permanent URI and any future references to this resource
+ //!< ought to use one of the enclosed URIs. [...] This status code is similar to 301 Moved Permanently
+ //!< (Section 7.3.2 of rfc7231), except that it does not allow rewriting the request method from POST to GET.
+
+ // 4xx - Client Error
+ BadRequest = 400, //!< Indicates that the server cannot or will not process the request because the received syntax is invalid,
+ //!< nonsensical, or exceeds some limitation on what the server is willing to process.
+ Unauthorized = 401, //!< Indicates that the request has not been applied because it lacks valid authentication credentials for the
+ //!< target resource.
+ PaymentRequired = 402, //!< *Reserved*
+ Forbidden = 403, //!< Indicates that the server understood the request but refuses to authorize it.
+ NotFound = 404, //!< Indicates that the origin server did not find a current representation for the target resource or is not willing
+ //!< to disclose that one exists.
+ MethodNotAllowed = 405, //!< Indicates that the method specified in the request-line is known by the origin server but not supported by
+ //!< the target resource.
+ NotAcceptable = 406, //!< Indicates that the target resource does not have a current representation that would be acceptable to the
+ //!< user agent, according to the proactive negotiation header fields received in the request, and the server is
+ //!< unwilling to supply a default representation.
+ ProxyAuthenticationRequired =
+ 407, //!< Is similar to 401 (Unauthorized), but indicates that the client needs to authenticate itself in order to use a proxy.
+ RequestTimeout =
+ 408, //!< Indicates that the server did not receive a complete request message within the time that it was prepared to wait.
+ Conflict = 409, //!< Indicates that the request could not be completed due to a conflict with the current state of the resource.
+ Gone = 410, //!< Indicates that access to the target resource is no longer available at the origin server and that this condition is
+ //!< likely to be permanent.
+ LengthRequired = 411, //!< Indicates that the server refuses to accept the request without a defined Content-Length.
+ PreconditionFailed =
+ 412, //!< Indicates that one or more preconditions given in the request header fields evaluated to false when tested on the server.
+ PayloadTooLarge = 413, //!< Indicates that the server is refusing to process a request because the request payload is larger than the
+ //!< server is willing or able to process.
+ URITooLong = 414, //!< Indicates that the server is refusing to service the request because the request-target is longer than the
+ //!< server is willing to interpret.
+ UnsupportedMediaType = 415, //!< Indicates that the origin server is refusing to service the request because the payload is in a format
+ //!< not supported by the target resource for this method.
+ RangeNotSatisfiable = 416, //!< Indicates that none of the ranges in the request's Range header field overlap the current extent of the
+ //!< selected resource or that the set of ranges requested has been rejected due to invalid ranges or an
+ //!< excessive request of small or overlapping ranges.
+ ExpectationFailed = 417, //!< Indicates that the expectation given in the request's Expect header field could not be met by at least
+ //!< one of the inbound servers.
+ ImATeapot = 418, //!< Any attempt to brew coffee with a teapot should result in the error code 418 I'm a teapot.
+ UnprocessableEntity = 422, //!< Means the server understands the content type of the request entity (hence a 415(Unsupported Media
+ //!< Type) status code is inappropriate), and the syntax of the request entity is correct (thus a 400 (Bad
+ //!< Request) status code is inappropriate) but was unable to process the contained instructions.
+ Locked = 423, //!< Means the source or destination resource of a method is locked.
+ FailedDependency = 424, //!< Means that the method could not be performed on the resource because the requested action depended on
+ //!< another action and that action failed.
+ UpgradeRequired = 426, //!< Indicates that the server refuses to perform the request using the current protocol but might be willing to
+ //!< do so after the client upgrades to a different protocol.
+ PreconditionRequired = 428, //!< Indicates that the origin server requires the request to be conditional.
+ TooManyRequests = 429, //!< Indicates that the user has sent too many requests in a given amount of time (\"rate limiting\").
+ RequestHeaderFieldsTooLarge =
+ 431, //!< Indicates that the server is unwilling to process the request because its header fields are too large.
+ UnavailableForLegalReasons =
+ 451, //!< This status code indicates that the server is denying access to the resource in response to a legal demand.
+
+ // 5xx - Server Error
+
+ InternalServerError =
+ 500, //!< Indicates that the server encountered an unexpected condition that prevented it from fulfilling the request.
+ NotImplemented = 501, //!< Indicates that the server does not support the functionality required to fulfill the request.
+ BadGateway = 502, //!< Indicates that the server, while acting as a gateway or proxy, received an invalid response from an inbound
+ //!< server it accessed while attempting to fulfill the request.
+ ServiceUnavailable = 503, //!< Indicates that the server is currently unable to handle the request due to a temporary overload or
+ //!< scheduled maintenance, which will likely be alleviated after some delay.
+ GatewayTimeout = 504, //!< Indicates that the server, while acting as a gateway or proxy, did not receive a timely response from an
+ //!< upstream server it needed to access in order to complete the request.
+ HTTPVersionNotSupported = 505, //!< Indicates that the server does not support, or refuses to support, the protocol version that was
+ //!< used in the request message.
+ VariantAlsoNegotiates =
+ 506, //!< Indicates that the server has an internal configuration error: the chosen variant resource is configured to engage in
+ //!< transparent content negotiation itself, and is therefore not a proper end point in the negotiation process.
+ InsufficientStorage = 507, //!< Means the method could not be performed on the resource because the server is unable to store the
+ //!< representation needed to successfully complete the request.
+ LoopDetected = 508, //!< Indicates that the server terminated an operation because it encountered an infinite loop while processing a
+ //!< request with "Depth: infinity". [RFC 5842]
+ NotExtended = 510, //!< The policy for accessing the resource has not been met in the request. [RFC 2774]
+ NetworkAuthenticationRequired = 511, //!< Indicates that the client needs to authenticate to gain network access.
+};
+
+} // namespace zen
diff --git a/src/zenhttp/include/zenhttp/httpserver.h b/src/zenhttp/include/zenhttp/httpserver.h
new file mode 100644
index 000000000..3b9fa50b4
--- /dev/null
+++ b/src/zenhttp/include/zenhttp/httpserver.h
@@ -0,0 +1,315 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include "zenhttp.h"
+
+#include <zencore/compactbinary.h>
+#include <zencore/enumflags.h>
+#include <zencore/iobuffer.h>
+#include <zencore/iohash.h>
+#include <zencore/refcount.h>
+#include <zencore/string.h>
+#include <zencore/uid.h>
+#include <zenhttp/httpcommon.h>
+
+#include <functional>
+#include <gsl/gsl-lite.hpp>
+#include <list>
+#include <map>
+#include <regex>
+#include <span>
+#include <unordered_map>
+
+namespace zen {
+
+/** HTTP Server Request
+ */
+class HttpServerRequest
+{
+public:
+ HttpServerRequest();
+ ~HttpServerRequest();
+
+ // Synchronous operations
+
+ [[nodiscard]] inline std::string_view RelativeUri() const { return m_Uri; } // Returns URI without service prefix
+ [[nodiscard]] std::string_view RelativeUriWithExtension() const { return m_UriWithExtension; }
+ [[nodiscard]] inline std::string_view QueryString() const { return m_QueryString; }
+
+ struct QueryParams
+ {
+ std::vector<std::pair<std::string_view, std::string_view>> KvPairs;
+
+ std::string_view GetValue(std::string_view ParamName) const
+ {
+ for (const auto& Kv : KvPairs)
+ {
+ const std::string_view& Key = Kv.first;
+
+ if (Key.size() == ParamName.size())
+ {
+ if (0 == StrCaseCompare(Key.data(), ParamName.data(), Key.size()))
+ {
+ return Kv.second;
+ }
+ }
+ }
+
+ return std::string_view();
+ }
+ };
+
+ virtual bool TryGetRanges(HttpRanges&) { return false; }
+
+ QueryParams GetQueryParams();
+
+ inline HttpVerb RequestVerb() const { return m_Verb; }
+ inline HttpContentType RequestContentType() { return m_ContentType; }
+ inline HttpContentType AcceptContentType() { return m_AcceptType; }
+
+ inline uint64_t ContentLength() const { return m_ContentLength; }
+ Oid SessionId() const;
+ uint32_t RequestId() const;
+
+ inline bool IsHandled() const { return !!(m_Flags & kIsHandled); }
+ inline bool SuppressBody() const { return !!(m_Flags & kSuppressBody); }
+ inline void SetSuppressResponseBody() { m_Flags |= kSuppressBody; }
+
+ /** Read POST/PUT payload for request body, which is always available without delay
+ */
+ virtual IoBuffer ReadPayload() = 0;
+
+ ZENCORE_API CbObject ReadPayloadObject();
+ ZENCORE_API CbPackage ReadPayloadPackage();
+
+ /** Respond with payload
+
+ No data will have been sent when any of these functions return. Instead, the response will be transmitted
+ asynchronously, after returning from a request handler function.
+
+ Note that this is destructive in the sense that the IoBuffer instances referred to by Blobs will be
+ moved into our response handler array where they are kept alive, in order to reduce ref-counting storms
+ */
+ virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span<IoBuffer> Blobs) = 0;
+ virtual void WriteResponse(HttpResponseCode ResponseCode) = 0;
+ virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::u8string_view ResponseString) = 0;
+ virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, CompositeBuffer& Payload);
+
+ void WriteResponse(HttpResponseCode ResponseCode, CbObject Data);
+ void WriteResponse(HttpResponseCode ResponseCode, CbArray Array);
+ void WriteResponse(HttpResponseCode ResponseCode, CbPackage Package);
+ void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::string_view ResponseString);
+ void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, IoBuffer Blob);
+
+ virtual void WriteResponseAsync(std::function<void(HttpServerRequest&)>&& ContinuationHandler) = 0;
+
+protected:
+ enum
+ {
+ kIsHandled = 1 << 0,
+ kSuppressBody = 1 << 1,
+ kHaveRequestId = 1 << 2,
+ kHaveSessionId = 1 << 3,
+ };
+
+ mutable uint32_t m_Flags = 0;
+ HttpVerb m_Verb = HttpVerb::kGet;
+ HttpContentType m_ContentType = HttpContentType::kBinary;
+ HttpContentType m_AcceptType = HttpContentType::kUnknownContentType;
+ uint64_t m_ContentLength = ~0ull;
+ std::string_view m_Uri;
+ std::string_view m_UriWithExtension;
+ std::string_view m_QueryString;
+ mutable uint32_t m_RequestId = ~uint32_t(0);
+ mutable Oid m_SessionId = Oid::Zero;
+
+ inline void SetIsHandled() { m_Flags |= kIsHandled; }
+
+ virtual Oid ParseSessionId() const = 0;
+ virtual uint32_t ParseRequestId() const = 0;
+};
+
+class IHttpPackageHandler : public RefCounted
+{
+public:
+ virtual void FilterOffer(std::vector<IoHash>& OfferCids) = 0;
+ virtual void OnRequestBegin() = 0;
+ virtual IoBuffer CreateTarget(const IoHash& Cid, uint64_t StorageSize) = 0;
+ virtual void OnRequestComplete() = 0;
+};
+
+/**
+ * Base class for implementing an HTTP "service"
+ *
+ * A service exposes one or more endpoints with a certain URI prefix
+ *
+ */
+
+class HttpService
+{
+public:
+ HttpService() = default;
+ virtual ~HttpService() = default;
+
+ virtual const char* BaseUri() const = 0;
+ virtual void HandleRequest(HttpServerRequest& HttpServiceRequest) = 0;
+ virtual Ref<IHttpPackageHandler> HandlePackageRequest(HttpServerRequest& HttpServiceRequest);
+
+ // Internals
+
+ inline void SetUriPrefixLength(size_t PrefixLength) { m_UriPrefixLength = (int)PrefixLength; }
+ inline int UriPrefixLength() const { return m_UriPrefixLength; }
+
+private:
+ int m_UriPrefixLength = 0;
+};
+
+/** HTTP server
+ *
+ * Implements the main event loop to service HTTP requests, and handles routing
+ * requests to the appropriate handler as registered via RegisterService
+ */
+class HttpServer : public RefCounted
+{
+public:
+ virtual void RegisterService(HttpService& Service) = 0;
+ virtual int Initialize(int BasePort) = 0;
+ virtual void Run(bool IsInteractiveSession) = 0;
+ virtual void RequestExit() = 0;
+};
+
+Ref<HttpServer> CreateHttpServer(std::string_view ServerClass);
+
+//////////////////////////////////////////////////////////////////////////
+
+class HttpRouterRequest
+{
+public:
+ HttpRouterRequest(HttpServerRequest& Request) : m_HttpRequest(Request) {}
+
+ ZENCORE_API std::string GetCapture(uint32_t Index) const;
+ inline HttpServerRequest& ServerRequest() { return m_HttpRequest; }
+
+private:
+ using MatchResults_t = std::match_results<std::string_view::const_iterator>;
+
+ HttpServerRequest& m_HttpRequest;
+ MatchResults_t m_Match;
+
+ friend class HttpRequestRouter;
+};
+
+inline std::string
+HttpRouterRequest::GetCapture(uint32_t Index) const
+{
+ ZEN_ASSERT(Index < m_Match.size());
+
+ return m_Match[Index];
+}
+
+/** HTTP request router helper
+ *
+ * This helper class allows a service implementer to register one or more
+ * endpoints using pattern matching (currently using regex matching)
+ *
+ * This is intended to be initialized once only, there is no thread
+ * safety so you can absolutely not add or remove endpoints once the handler
+ * goes live
+ */
+
+class HttpRequestRouter
+{
+public:
+ typedef std::function<void(HttpRouterRequest&)> HandlerFunc_t;
+
+ /**
+ * @brief Add pattern which can be referenced by name, commonly used for URL components
+ * @param Id String used to identify patterns for replacement
+ * @param Regex String which will replace the Id string in any registered URL paths
+ */
+ void AddPattern(const char* Id, const char* Regex);
+
+ /**
+ * @brief Register a an endpoint handler for the given route
+ * @param Regex Regular expression used to match the handler to a request. This may
+ * contain pattern aliases registered via AddPattern
+ * @param HandlerFunc Handler function to call for any matching request
+ * @param SupportedVerbs Supported HTTP verbs for this handler
+ */
+ void RegisterRoute(const char* Regex, HandlerFunc_t&& HandlerFunc, HttpVerb SupportedVerbs);
+
+ void ProcessRegexSubstitutions(const char* Regex, StringBuilderBase& ExpandedRegex);
+
+ /**
+ * @brief HTTP request handling function - this should be called to route the
+ * request to a registered handler
+ * @param Request Request to route to a handler
+ * @return Function returns true if the request was routed successfully
+ */
+ bool HandleRequest(zen::HttpServerRequest& Request);
+
+private:
+ struct HandlerEntry
+ {
+ HandlerEntry(const char* Regex, HttpVerb SupportedVerbs, HandlerFunc_t&& Handler, const char* Pattern)
+ : RegEx(Regex, std::regex::icase | std::regex::ECMAScript)
+ , Verbs(SupportedVerbs)
+ , Handler(std::move(Handler))
+ , Pattern(Pattern)
+ {
+ }
+
+ ~HandlerEntry() = default;
+
+ std::regex RegEx;
+ HttpVerb Verbs;
+ HandlerFunc_t Handler;
+ const char* Pattern;
+
+ private:
+ HandlerEntry& operator=(const HandlerEntry&) = delete;
+ HandlerEntry(const HandlerEntry&) = delete;
+ };
+
+ std::list<HandlerEntry> m_Handlers;
+ std::unordered_map<std::string, std::string> m_PatternMap;
+};
+
+/** HTTP RPC request helper
+ */
+
+class RpcResult
+{
+ RpcResult(CbObject Result) : m_Result(std::move(Result)) {}
+
+private:
+ CbObject m_Result;
+};
+
+class HttpRpcHandler
+{
+public:
+ HttpRpcHandler();
+ ~HttpRpcHandler();
+
+ HttpRpcHandler(const HttpRpcHandler&) = delete;
+ HttpRpcHandler operator=(const HttpRpcHandler&) = delete;
+
+ void AddRpc(std::string_view RpcId, std::function<void(CbObject& RpcArgs)> HandlerFunction);
+
+private:
+ struct RpcFunction
+ {
+ std::function<void(CbObject& RpcArgs)> Function;
+ std::string Identifier;
+ };
+
+ std::map<std::string, RpcFunction> m_Functions;
+};
+
+bool HandlePackageOffers(HttpService& Service, HttpServerRequest& Request, Ref<IHttpPackageHandler>& PackageHandlerRef);
+
+void http_forcelink(); // internal
+
+} // namespace zen
diff --git a/src/zenhttp/include/zenhttp/httpshared.h b/src/zenhttp/include/zenhttp/httpshared.h
new file mode 100644
index 000000000..d335572c5
--- /dev/null
+++ b/src/zenhttp/include/zenhttp/httpshared.h
@@ -0,0 +1,163 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/compactbinarypackage.h>
+#include <zencore/iobuffer.h>
+#include <zencore/iohash.h>
+
+#include <functional>
+#include <gsl/gsl-lite.hpp>
+
+namespace zen {
+
+class IoBuffer;
+class CbPackage;
+class CompositeBuffer;
+
+/** _____ _ _____ _
+ / ____| | | __ \ | |
+ | | | |__ | |__) |_ _ ___| | ____ _ __ _ ___
+ | | | '_ \| ___/ _` |/ __| |/ / _` |/ _` |/ _ \
+ | |____| |_) | | | (_| | (__| < (_| | (_| | __/
+ \_____|_.__/|_| \__,_|\___|_|\_\__,_|\__, |\___|
+ __/ |
+ |___/
+
+ Structures and code related to handling CbPackage transactions
+
+ CbPackage instances are marshaled across the wire using a distinct message
+ format. We don't use the CbPackage serialization format provided by the
+ CbPackage implementation itself since that does not provide much flexibility
+ in how the attachment payloads are transmitted. The scheme below separates
+ metadata cleanly from payloads and this enables us to more efficiently
+ transmit them either via sendfile/TransmitFile like mechanisms, or by
+ reference/memory mapping in the local case.
+ */
+
+struct CbPackageHeader
+{
+ uint32_t HeaderMagic;
+ uint32_t AttachmentCount; // TODO: should add ability to opt out of implicit root document?
+ uint32_t Reserved1;
+ uint32_t Reserved2;
+};
+
+static_assert(sizeof(CbPackageHeader) == 16);
+
+enum : uint32_t
+{
+ kCbPkgMagic = 0xaa77aacc
+};
+
+struct CbAttachmentEntry
+{
+ uint64_t PayloadSize; // Size of the associated payload data in the message
+ uint32_t Flags; // See flags below
+ IoHash AttachmentHash; // Content Id for the attachment
+
+ enum
+ {
+ kIsCompressed = (1u << 0), // Is marshaled using compressed buffer storage format
+ kIsObject = (1u << 1), // Is compact binary object
+ kIsError = (1u << 2), // Is error (compact binary formatted) object
+ kIsLocalRef = (1u << 3), // Is "local reference"
+ };
+};
+
+struct CbAttachmentReferenceHeader
+{
+ uint64_t PayloadByteOffset = 0;
+ uint64_t PayloadByteSize = ~0u;
+ uint16_t AbsolutePathLength = 0;
+
+ // This header will be followed by UTF8 encoded absolute path to backing file
+};
+
+static_assert(sizeof(CbAttachmentEntry) == 32);
+
+enum class FormatFlags
+{
+ kDefault = 0,
+ kAllowLocalReferences = (1u << 0),
+ kDenyPartialLocalReferences = (1u << 1)
+};
+
+gsl_DEFINE_ENUM_BITMASK_OPERATORS(FormatFlags);
+
+enum class RpcAcceptOptions : uint16_t
+{
+ kNone = 0,
+ kAllowLocalReferences = (1u << 0),
+ kAllowPartialLocalReferences = (1u << 1)
+};
+
+gsl_DEFINE_ENUM_BITMASK_OPERATORS(RpcAcceptOptions);
+
+std::vector<IoBuffer> FormatPackageMessage(const CbPackage& Data, FormatFlags Flags, int TargetProcessPid = 0);
+CompositeBuffer FormatPackageMessageBuffer(const CbPackage& Data, FormatFlags Flags, int TargetProcessPid = 0);
+CbPackage ParsePackageMessage(
+ IoBuffer Payload,
+ std::function<IoBuffer(const IoHash& Cid, uint64_t Size)> CreateBuffer = [](const IoHash&, uint64_t Size) -> IoBuffer {
+ return IoBuffer{Size};
+ });
+bool IsPackageMessage(IoBuffer Payload);
+
+bool ParsePackageMessageWithLegacyFallback(const IoBuffer& Response, CbPackage& OutPackage);
+
+std::vector<IoBuffer> FormatPackageMessage(const CbPackage& Data, int TargetProcessPid = 0);
+CompositeBuffer FormatPackageMessageBuffer(const CbPackage& Data, int TargetProcessPid = 0);
+
+/** Streaming reader for compact binary packages
+
+ The goal is to ultimately support zero-copy I/O, but for now there'll be some
+ copying involved on some platforms at least.
+
+ This approach to deserializing CbPackage data is more efficient than
+ `ParsePackageMessage` since it does not require the entire message to
+ be resident in a memory buffer
+
+ */
+class CbPackageReader
+{
+public:
+ CbPackageReader();
+ ~CbPackageReader();
+
+ void SetPayloadBufferCreator(std::function<IoBuffer(const IoHash& Cid, uint64_t Size)> CreateBuffer);
+
+ /** Process compact binary package data stream
+
+ The data stream must be in the serialization format produced by FormatPackageMessage
+
+ \return How many bytes must be fed to this function in the next call
+ */
+ uint64_t ProcessPackageHeaderData(const void* Data, uint64_t DataBytes);
+
+ void Finalize();
+ const std::vector<CbAttachment>& GetAttachments() { return m_Attachments; }
+ CbObject GetRootObject() { return m_RootObject; }
+ std::span<IoBuffer> GetPayloadBuffers() { return m_PayloadBuffers; }
+
+private:
+ enum class State
+ {
+ kInitialState,
+ kReadingHeader,
+ kReadingAttachmentEntries,
+ kReadingBuffers
+ } m_CurrentState = State::kInitialState;
+
+ std::function<IoBuffer(const IoHash& Cid, uint64_t Size)> m_CreateBuffer;
+ std::vector<IoBuffer> m_PayloadBuffers;
+ std::vector<CbAttachmentEntry> m_AttachmentEntries;
+ std::vector<CbAttachment> m_Attachments;
+ CbObject m_RootObject;
+ CbPackageHeader m_PackageHeader;
+
+ IoBuffer MarshalLocalChunkReference(IoBuffer AttachmentBuffer);
+};
+
+void forcelink_httpshared();
+
+} // namespace zen
diff --git a/src/zenhttp/include/zenhttp/websocket.h b/src/zenhttp/include/zenhttp/websocket.h
new file mode 100644
index 000000000..adca7e988
--- /dev/null
+++ b/src/zenhttp/include/zenhttp/websocket.h
@@ -0,0 +1,256 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zencore/compactbinarypackage.h>
+#include <zencore/memory.h>
+
+#include <compare>
+#include <functional>
+#include <future>
+#include <memory>
+#include <optional>
+
+#pragma once
+
+namespace asio {
+class io_context;
+}
+
+namespace zen {
+
+class BinaryWriter;
+
+/**
+ * A unique socket ID.
+ */
+class WebSocketId
+{
+ static std::atomic_uint32_t NextId;
+
+public:
+ WebSocketId() = default;
+
+ uint32_t Value() const { return m_Value; }
+
+ auto operator<=>(const WebSocketId&) const = default;
+
+ static WebSocketId New() { return WebSocketId(NextId.fetch_add(1)); }
+
+private:
+ WebSocketId(uint32_t Value) : m_Value(Value) {}
+
+ uint32_t m_Value{};
+};
+
+/**
+ * Type of web socket message.
+ */
+enum class WebSocketMessageType : uint8_t
+{
+ kInvalid,
+ kNotification,
+ kRequest,
+ kStreamRequest,
+ kResponse,
+ kStreamResponse,
+ kStreamCompleteResponse,
+ kCount
+};
+
+inline std::string_view
+ToString(WebSocketMessageType Type)
+{
+ switch (Type)
+ {
+ case WebSocketMessageType::kInvalid:
+ return std::string_view("Invalid");
+ case WebSocketMessageType::kNotification:
+ return std::string_view("Notification");
+ case WebSocketMessageType::kRequest:
+ return std::string_view("Request");
+ case WebSocketMessageType::kStreamRequest:
+ return std::string_view("StreamRequest");
+ case WebSocketMessageType::kResponse:
+ return std::string_view("Response");
+ case WebSocketMessageType::kStreamResponse:
+ return std::string_view("StreamResponse");
+ case WebSocketMessageType::kStreamCompleteResponse:
+ return std::string_view("StreamCompleteResponse");
+ default:
+ return std::string_view("Unknown");
+ };
+}
+
+/**
+ * Web socket message.
+ */
+class WebSocketMessage
+{
+ struct Header
+ {
+ static constexpr uint32_t ExpectedMagic = 0x7a776d68; // zwmh
+
+ uint64_t MessageSize{};
+ uint32_t Magic{ExpectedMagic};
+ uint32_t CorrelationId{};
+ uint32_t StatusCode{200u};
+ WebSocketMessageType MessageType{};
+ uint8_t Reserved[3] = {0};
+
+ bool IsValid() const;
+ };
+
+ static_assert(sizeof(Header) == 24);
+
+ static std::atomic_uint32_t NextCorrelationId;
+
+public:
+ static constexpr size_t HeaderSize = sizeof(Header);
+
+ WebSocketMessage() = default;
+
+ WebSocketId SocketId() const { return m_SocketId; }
+ void SetSocketId(WebSocketId Id) { m_SocketId = Id; }
+ uint64_t MessageSize() const { return m_Header.MessageSize; }
+ void SetMessageType(WebSocketMessageType MessageType);
+ void SetCorrelationId(uint32_t Id) { m_Header.CorrelationId = Id; }
+ uint32_t CorrelationId() const { return m_Header.CorrelationId; }
+ uint32_t StatusCode() const { return m_Header.StatusCode; }
+ void SetStatusCode(uint32_t StatusCode) { m_Header.StatusCode = StatusCode; }
+ WebSocketMessageType MessageType() const { return m_Header.MessageType; }
+
+ const CbPackage& Body() const { return m_Body.value(); }
+ void SetBody(CbPackage&& Body);
+ void SetBody(CbObject&& Body);
+ bool HasBody() const { return m_Body.has_value(); }
+
+ void Save(BinaryWriter& Writer);
+ bool TryLoadHeader(MemoryView Memory);
+
+ bool IsValid() const { return m_Header.MessageType != WebSocketMessageType::kInvalid; }
+
+private:
+ Header m_Header{};
+ WebSocketId m_SocketId{};
+ std::optional<CbPackage> m_Body;
+};
+
+class WebSocketServer;
+
+/**
+ * Base class for handling web socket requests and notifications from connected client(s).
+ */
+class WebSocketService
+{
+public:
+ virtual ~WebSocketService() = default;
+
+ void Configure(WebSocketServer& Server);
+
+ virtual bool HandleRequest(const WebSocketMessage&) { ZEN_ASSERT(false); }
+ virtual void HandleNotification(const WebSocketMessage&) { ZEN_ASSERT(false); }
+
+protected:
+ WebSocketService() = default;
+
+ virtual void RegisterHandlers(WebSocketServer& Server) = 0;
+ void SendStreamResponse(WebSocketId SocketId, uint32_t CorrelationId, CbPackage&& StreamResponse, bool IsStreamComplete);
+ void SendStreamResponse(WebSocketId SocketId, uint32_t CorrelationId, CbObject&& StreamResponse, bool IsStreamComplete);
+
+ WebSocketServer& SocketServer()
+ {
+ ZEN_ASSERT(m_SocketServer);
+ return *m_SocketServer;
+ }
+
+private:
+ WebSocketServer* m_SocketServer{};
+};
+
+/**
+ * Server options.
+ */
+struct WebSocketServerOptions
+{
+ uint16_t Port = 2337;
+ uint32_t ThreadCount = 1;
+};
+
+/**
+ * The web socket server manages client connections and routing of requests and notifications.
+ */
+class WebSocketServer
+{
+public:
+ virtual ~WebSocketServer() = default;
+
+ virtual bool Run() = 0;
+ virtual void Shutdown() = 0;
+
+ virtual void RegisterService(WebSocketService& Service) = 0;
+ virtual void RegisterNotificationHandler(std::string_view Key, WebSocketService& Service) = 0;
+ virtual void RegisterRequestHandler(std::string_view Key, WebSocketService& Service) = 0;
+
+ virtual void SendNotification(WebSocketMessage&& Notification) = 0;
+ virtual void SendResponse(WebSocketMessage&& Response) = 0;
+
+ static std::unique_ptr<WebSocketServer> Create(const WebSocketServerOptions& Options);
+};
+
+/**
+ * The state of the web socket.
+ */
+enum class WebSocketState : uint32_t
+{
+ kNone,
+ kHandshaking,
+ kConnected,
+ kDisconnected,
+ kError
+};
+
+/**
+ * Type of web socket client event.
+ */
+enum class WebSocketEvent : uint32_t
+{
+ kConnected,
+ kDisconnected,
+ kError
+};
+
+/**
+ * Web socket client connection info.
+ */
+struct WebSocketConnectInfo
+{
+ std::string Host;
+ int16_t Port{8848};
+ std::string Endpoint;
+ std::vector<std::string> Protocols;
+ uint16_t Version{13};
+};
+
+/**
+ * A connection to a web socket server for sending requests and listening for notifications.
+ */
+class WebSocketClient
+{
+public:
+ using EventCallback = std::function<void()>;
+ using NotificationCallback = std::function<void(WebSocketMessage&&)>;
+
+ virtual ~WebSocketClient() = default;
+
+ virtual std::future<bool> Connect(const WebSocketConnectInfo& Info) = 0;
+ virtual void Disconnect() = 0;
+ virtual bool IsConnected() const = 0;
+ virtual WebSocketState State() const = 0;
+
+ virtual std::future<WebSocketMessage> SendRequest(WebSocketMessage&& Request) = 0;
+ virtual void OnNotification(NotificationCallback&& Cb) = 0;
+ virtual void OnEvent(WebSocketEvent Evt, EventCallback&& Cb) = 0;
+
+ static std::shared_ptr<WebSocketClient> Create(asio::io_context& IoCtx);
+};
+
+} // namespace zen
diff --git a/src/zenhttp/include/zenhttp/zenhttp.h b/src/zenhttp/include/zenhttp/zenhttp.h
new file mode 100644
index 000000000..59c64b31f
--- /dev/null
+++ b/src/zenhttp/include/zenhttp/zenhttp.h
@@ -0,0 +1,21 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/zencore.h>
+
+#ifndef ZEN_WITH_HTTPSYS
+# if ZEN_PLATFORM_WINDOWS
+# define ZEN_WITH_HTTPSYS 1
+# else
+# define ZEN_WITH_HTTPSYS 0
+# endif
+#endif
+
+#define ZENHTTP_API // Placeholder to allow DLL configs in the future
+
+namespace zen {
+
+ZENHTTP_API void zenhttp_forcelinktests();
+
+}
diff --git a/src/zenhttp/iothreadpool.cpp b/src/zenhttp/iothreadpool.cpp
new file mode 100644
index 000000000..6087e69ec
--- /dev/null
+++ b/src/zenhttp/iothreadpool.cpp
@@ -0,0 +1,49 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "iothreadpool.h"
+
+#include <zencore/except.h>
+
+#if ZEN_PLATFORM_WINDOWS
+
+namespace zen {
+
+WinIoThreadPool::WinIoThreadPool(int InThreadCount)
+{
+ // Thread pool setup
+
+ m_ThreadPool = CreateThreadpool(NULL);
+
+ SetThreadpoolThreadMinimum(m_ThreadPool, InThreadCount);
+ SetThreadpoolThreadMaximum(m_ThreadPool, InThreadCount * 2);
+
+ InitializeThreadpoolEnvironment(&m_CallbackEnvironment);
+
+ m_CleanupGroup = CreateThreadpoolCleanupGroup();
+
+ SetThreadpoolCallbackPool(&m_CallbackEnvironment, m_ThreadPool);
+
+ SetThreadpoolCallbackCleanupGroup(&m_CallbackEnvironment, m_CleanupGroup, NULL);
+}
+
+WinIoThreadPool::~WinIoThreadPool()
+{
+ CloseThreadpool(m_ThreadPool);
+}
+
+void
+WinIoThreadPool::CreateIocp(HANDLE IoHandle, PTP_WIN32_IO_CALLBACK Callback, void* Context, std::error_code& ErrorCode)
+{
+ ZEN_ASSERT(!m_ThreadPoolIo);
+
+ m_ThreadPoolIo = CreateThreadpoolIo(IoHandle, Callback, Context, &m_CallbackEnvironment);
+
+ if (!m_ThreadPoolIo)
+ {
+ ErrorCode = MakeErrorCodeFromLastError();
+ }
+}
+
+} // namespace zen
+
+#endif
diff --git a/src/zenhttp/iothreadpool.h b/src/zenhttp/iothreadpool.h
new file mode 100644
index 000000000..8333964c3
--- /dev/null
+++ b/src/zenhttp/iothreadpool.h
@@ -0,0 +1,37 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/zencore.h>
+
+#if ZEN_PLATFORM_WINDOWS
+# include <zencore/windows.h>
+
+# include <system_error>
+
+namespace zen {
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Thread pool. Implemented in terms of Windows thread pool right now, will
+// need a cross-platform implementation eventually
+//
+
+class WinIoThreadPool
+{
+public:
+ WinIoThreadPool(int InThreadCount);
+ ~WinIoThreadPool();
+
+ void CreateIocp(HANDLE IoHandle, PTP_WIN32_IO_CALLBACK Callback, void* Context, std::error_code& ErrorCode);
+ inline PTP_IO Iocp() const { return m_ThreadPoolIo; }
+
+private:
+ PTP_POOL m_ThreadPool = nullptr;
+ PTP_CLEANUP_GROUP m_CleanupGroup = nullptr;
+ PTP_IO m_ThreadPoolIo = nullptr;
+ TP_CALLBACK_ENVIRON m_CallbackEnvironment;
+};
+
+} // namespace zen
+#endif
diff --git a/src/zenhttp/websocketasio.cpp b/src/zenhttp/websocketasio.cpp
new file mode 100644
index 000000000..bbe7e1ad8
--- /dev/null
+++ b/src/zenhttp/websocketasio.cpp
@@ -0,0 +1,1613 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zenhttp/websocket.h>
+
+#include <zencore/base64.h>
+#include <zencore/compactbinarybuilder.h>
+#include <zencore/compactbinaryvalidation.h>
+#include <zencore/intmath.h>
+#include <zencore/iobuffer.h>
+#include <zencore/logging.h>
+#include <zencore/memory.h>
+#include <zencore/sha1.h>
+#include <zencore/stream.h>
+#include <zencore/string.h>
+#include <zencore/trace.h>
+
+#include <chrono>
+#include <optional>
+#include <shared_mutex>
+#include <span>
+#include <system_error>
+#include <thread>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <fmt/format.h>
+#include <http_parser.h>
+#include <asio.hpp>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+#if ZEN_PLATFORM_WINDOWS
+# include <mstcpip.h>
+#endif
+
+namespace zen::websocket {
+
+using namespace std::literals;
+
+ZEN_DEFINE_LOG_CATEGORY_STATIC(LogWebSocket, "websocket"sv);
+
+ZEN_DEFINE_LOG_CATEGORY_STATIC(LogWsClient, "ws-client"sv);
+
+using Clock = std::chrono::steady_clock;
+using TimePoint = Clock::time_point;
+
+///////////////////////////////////////////////////////////////////////////////
+namespace http_header {
+ static constexpr std::string_view SecWebSocketKey = "Sec-WebSocket-Key"sv;
+ static constexpr std::string_view SecWebSocketOrigin = "Sec-WebSocket-Origin"sv;
+ static constexpr std::string_view SecWebSocketProtocol = "Sec-WebSocket-Protocol"sv;
+ static constexpr std::string_view SecWebSocketVersion = "Sec-WebSocket-Version"sv;
+ static constexpr std::string_view SecWebSocketAccept = "Sec-WebSocket-Accept"sv;
+ static constexpr std::string_view Upgrade = "Upgrade"sv;
+} // namespace http_header
+
+///////////////////////////////////////////////////////////////////////////////
+enum class ParseMessageStatus : uint32_t
+{
+ kError,
+ kContinue,
+ kDone,
+};
+
+struct ParseMessageResult
+{
+ ParseMessageStatus Status{};
+ size_t ByteCount{};
+ std::optional<std::string> Reason;
+};
+
+class MessageParser
+{
+public:
+ virtual ~MessageParser() = default;
+
+ ParseMessageResult ParseMessage(MemoryView Msg);
+ void Reset();
+
+protected:
+ MessageParser() = default;
+
+ virtual ParseMessageResult OnParseMessage(MemoryView Msg) = 0;
+ virtual void OnReset() = 0;
+
+ BinaryWriter m_Stream;
+};
+
+ParseMessageResult
+MessageParser::ParseMessage(MemoryView Msg)
+{
+ return OnParseMessage(Msg);
+}
+
+void
+MessageParser::Reset()
+{
+ OnReset();
+
+ m_Stream.Reset();
+}
+
+///////////////////////////////////////////////////////////////////////////////
+enum class HttpMessageParserType
+{
+ kRequest,
+ kResponse,
+ kBoth
+};
+
+class HttpMessageParser final : public MessageParser
+{
+public:
+ using HttpHeaders = std::unordered_map<std::string_view, std::string_view>;
+
+ HttpMessageParser(HttpMessageParserType Type) : MessageParser(), m_Type(Type) { Initialize(); }
+
+ virtual ~HttpMessageParser() = default;
+
+ int32_t StatusCode() const { return m_Parser.status_code; }
+ bool IsUpgrade() const { return m_Parser.upgrade != 0; }
+ HttpHeaders& Headers() { return m_Headers; }
+ MemoryView Body() const { return MemoryView(m_Stream.Data() + m_BodyEntry.Offset, m_BodyEntry.Size); }
+
+ std::string_view StatusText() const
+ {
+ return std::string_view(reinterpret_cast<const char*>(m_Stream.Data() + m_StatusEntry.Offset), m_StatusEntry.Size);
+ }
+
+ bool ValidateWebSocketHandshake(std::string& OutAcceptHash, std::string& OutReason);
+
+private:
+ void Initialize();
+ virtual ParseMessageResult OnParseMessage(MemoryView Msg) override;
+ virtual void OnReset() override;
+ int OnMessageBegin();
+ int OnUrl(MemoryView Url);
+ int OnStatus(MemoryView Status);
+ int OnHeaderField(MemoryView HeaderField);
+ int OnHeaderValue(MemoryView HeaderValue);
+ int OnHeadersComplete();
+ int OnBody(MemoryView Body);
+ int OnMessageComplete();
+
+ struct StreamEntry
+ {
+ uint64_t Offset{};
+ uint64_t Size{};
+ };
+
+ struct HeaderStreamEntry
+ {
+ StreamEntry Field{};
+ StreamEntry Value{};
+ };
+
+ HttpMessageParserType m_Type;
+ http_parser m_Parser;
+ StreamEntry m_UrlEntry;
+ StreamEntry m_StatusEntry;
+ StreamEntry m_BodyEntry;
+ HeaderStreamEntry m_CurrentHeader;
+ std::vector<HeaderStreamEntry> m_HeaderEntries;
+ HttpHeaders m_Headers;
+ bool m_IsMsgComplete{false};
+
+ static http_parser_settings ParserSettings;
+};
+
+http_parser_settings HttpMessageParser::ParserSettings = {
+ .on_message_begin = [](http_parser* P) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnMessageBegin(); },
+
+ .on_url = [](http_parser* P,
+ const char* Data,
+ size_t Size) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnUrl(MemoryView(Data, Size)); },
+
+ .on_status = [](http_parser* P,
+ const char* Data,
+ size_t Size) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnStatus(MemoryView(Data, Size)); },
+
+ .on_header_field = [](http_parser* P,
+ const char* Data,
+ size_t Size) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnHeaderField(MemoryView(Data, Size)); },
+
+ .on_header_value = [](http_parser* P,
+ const char* Data,
+ size_t Size) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnHeaderValue(MemoryView(Data, Size)); },
+
+ .on_headers_complete = [](http_parser* P) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnHeadersComplete(); },
+
+ .on_body = [](http_parser* P,
+ const char* Data,
+ size_t Size) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnBody(MemoryView(Data, Size)); },
+
+ .on_message_complete = [](http_parser* P) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnMessageComplete(); }};
+
+void
+HttpMessageParser::Initialize()
+{
+ http_parser_init(&m_Parser,
+ m_Type == HttpMessageParserType::kRequest ? HTTP_REQUEST
+ : m_Type == HttpMessageParserType::kResponse ? HTTP_RESPONSE
+ : HTTP_BOTH);
+ m_Parser.data = this;
+
+ m_UrlEntry = {};
+ m_StatusEntry = {};
+ m_CurrentHeader = {};
+ m_BodyEntry = {};
+
+ m_IsMsgComplete = false;
+
+ m_HeaderEntries.clear();
+}
+
+ParseMessageResult
+HttpMessageParser::OnParseMessage(MemoryView Msg)
+{
+ const size_t ByteCount = http_parser_execute(&m_Parser, &ParserSettings, reinterpret_cast<const char*>(Msg.GetData()), Msg.GetSize());
+
+ auto Status = m_IsMsgComplete ? ParseMessageStatus::kDone : ParseMessageStatus::kContinue;
+
+ if (m_Parser.http_errno != 0)
+ {
+ Status = ParseMessageStatus::kError;
+ }
+
+ return {.Status = Status, .ByteCount = uint64_t(ByteCount)};
+}
+
+void
+HttpMessageParser::OnReset()
+{
+ Initialize();
+}
+
+int
+HttpMessageParser::OnMessageBegin()
+{
+ ZEN_ASSERT(m_IsMsgComplete == false);
+ ZEN_ASSERT(m_HeaderEntries.empty());
+ ZEN_ASSERT(m_Headers.empty());
+
+ return 0;
+}
+
+int
+HttpMessageParser::OnStatus(MemoryView Status)
+{
+ m_StatusEntry = {m_Stream.CurrentOffset(), Status.GetSize()};
+
+ m_Stream.Write(Status);
+
+ return 0;
+}
+
+int
+HttpMessageParser::OnUrl(MemoryView Url)
+{
+ m_UrlEntry = {m_Stream.CurrentOffset(), Url.GetSize()};
+
+ m_Stream.Write(Url);
+
+ return 0;
+}
+
+int
+HttpMessageParser::OnHeaderField(MemoryView HeaderField)
+{
+ if (m_CurrentHeader.Value.Size > 0)
+ {
+ m_HeaderEntries.push_back(m_CurrentHeader);
+ m_CurrentHeader = {};
+ }
+
+ if (m_CurrentHeader.Field.Size == 0)
+ {
+ m_CurrentHeader.Field.Offset = m_Stream.CurrentOffset();
+ }
+
+ m_CurrentHeader.Field.Size += HeaderField.GetSize();
+
+ m_Stream.Write(HeaderField);
+
+ return 0;
+}
+
+int
+HttpMessageParser::OnHeaderValue(MemoryView HeaderValue)
+{
+ if (m_CurrentHeader.Value.Size == 0)
+ {
+ m_CurrentHeader.Value.Offset = m_Stream.CurrentOffset();
+ }
+
+ m_CurrentHeader.Value.Size += HeaderValue.GetSize();
+
+ m_Stream.Write(HeaderValue);
+
+ return 0;
+}
+
+int
+HttpMessageParser::OnHeadersComplete()
+{
+ if (m_CurrentHeader.Value.Size > 0)
+ {
+ m_HeaderEntries.push_back(m_CurrentHeader);
+ m_CurrentHeader = {};
+ }
+
+ m_Headers.clear();
+ m_Headers.reserve(m_HeaderEntries.size());
+
+ const char* StreamData = reinterpret_cast<const char*>(m_Stream.Data());
+
+ for (const auto& Entry : m_HeaderEntries)
+ {
+ auto Field = std::string_view(StreamData + Entry.Field.Offset, Entry.Field.Size);
+ auto Value = std::string_view(StreamData + Entry.Value.Offset, Entry.Value.Size);
+
+ m_Headers.try_emplace(std::move(Field), std::move(Value));
+ }
+
+ return 0;
+}
+
+int
+HttpMessageParser::OnBody(MemoryView Body)
+{
+ m_BodyEntry = {m_Stream.CurrentOffset(), Body.GetSize()};
+
+ m_Stream.Write(Body);
+
+ return 0;
+}
+
+int
+HttpMessageParser::OnMessageComplete()
+{
+ m_IsMsgComplete = true;
+
+ return 0;
+}
+
+bool
+HttpMessageParser::ValidateWebSocketHandshake(std::string& OutAcceptHash, std::string& OutReason)
+{
+ static constexpr std::string_view WebSocketGuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"sv;
+
+ OutAcceptHash = std::string();
+
+ if (m_Headers.contains(http_header::SecWebSocketKey) == false)
+ {
+ OutReason = "Missing header Sec-WebSocket-Key";
+ return false;
+ }
+
+ if (m_Headers.contains(http_header::Upgrade) == false)
+ {
+ OutReason = "Missing header Upgrade";
+ return false;
+ }
+
+ ExtendableStringBuilder<128> Sb;
+ Sb << m_Headers[http_header::SecWebSocketKey] << WebSocketGuid;
+
+ SHA1Stream HashStream;
+ HashStream.Append(Sb.Data(), Sb.Size());
+
+ SHA1 Hash = HashStream.GetHash();
+
+ OutAcceptHash.resize(Base64::GetEncodedDataSize(sizeof(SHA1::Hash)));
+ Base64::Encode(Hash.Hash, sizeof(SHA1::Hash), OutAcceptHash.data());
+
+ return true;
+}
+
+///////////////////////////////////////////////////////////////////////////////
+class WebSocketMessageParser final : public MessageParser
+{
+public:
+ WebSocketMessageParser() : MessageParser() {}
+
+ WebSocketMessage ConsumeMessage();
+
+private:
+ virtual ParseMessageResult OnParseMessage(MemoryView Msg) override;
+ virtual void OnReset() override;
+
+ WebSocketMessage m_Message;
+};
+
+ParseMessageResult
+WebSocketMessageParser::OnParseMessage(MemoryView Msg)
+{
+ ZEN_TRACE_CPU("WebSocketMessageParser::OnParseMessage");
+
+ const uint64_t PrevOffset = m_Stream.CurrentOffset();
+
+ if (m_Stream.CurrentOffset() < WebSocketMessage::HeaderSize)
+ {
+ const uint64_t RemaingHeaderSize = WebSocketMessage::HeaderSize - m_Stream.CurrentOffset();
+
+ m_Stream.Write(Msg.Left(RemaingHeaderSize));
+ Msg += RemaingHeaderSize;
+
+ if (m_Stream.CurrentOffset() < WebSocketMessage::HeaderSize)
+ {
+ return {.Status = ParseMessageStatus::kContinue, .ByteCount = m_Stream.CurrentOffset() - PrevOffset};
+ }
+
+ const bool IsValidHeader = m_Message.TryLoadHeader(m_Stream.GetView());
+
+ if (IsValidHeader == false)
+ {
+ OnReset();
+
+ return {.Status = ParseMessageStatus::kError,
+ .ByteCount = m_Stream.CurrentOffset() - PrevOffset,
+ .Reason = std::string("Invalid websocket message header")};
+ }
+
+ if (m_Message.MessageSize() == 0)
+ {
+ return {.Status = ParseMessageStatus::kDone, .ByteCount = m_Stream.CurrentOffset() - PrevOffset};
+ }
+ }
+
+ ZEN_ASSERT(m_Stream.CurrentOffset() >= WebSocketMessage::HeaderSize);
+
+ if (Msg.IsEmpty() == false)
+ {
+ const uint64_t RemaingMessageSize = (WebSocketMessage::HeaderSize + m_Message.MessageSize()) - m_Stream.CurrentOffset();
+ m_Stream.Write(Msg.Left(RemaingMessageSize));
+ }
+
+ auto Status = ParseMessageStatus::kContinue;
+
+ if (m_Stream.CurrentOffset() == WebSocketMessage::HeaderSize + m_Message.MessageSize())
+ {
+ Status = ParseMessageStatus::kDone;
+
+ BinaryReader Reader(m_Stream.GetView().RightChop(WebSocketMessage::HeaderSize));
+
+ CbPackage Pkg;
+ if (Pkg.TryLoad(Reader) == false)
+ {
+ return {.Status = ParseMessageStatus::kError,
+ .ByteCount = m_Stream.CurrentOffset() - PrevOffset,
+ .Reason = std::string("Invalid websocket message")};
+ }
+
+ m_Message.SetBody(std::move(Pkg));
+ }
+
+ return {.Status = Status, .ByteCount = m_Stream.CurrentOffset() - PrevOffset};
+}
+
+void
+WebSocketMessageParser::OnReset()
+{
+ m_Message = WebSocketMessage();
+}
+
+WebSocketMessage
+WebSocketMessageParser::ConsumeMessage()
+{
+ WebSocketMessage Msg = std::move(m_Message);
+ m_Message = WebSocketMessage();
+
+ return Msg;
+}
+
+///////////////////////////////////////////////////////////////////////////////
+class WsConnection : public std::enable_shared_from_this<WsConnection>
+{
+public:
+ WsConnection(WebSocketId Id, std::unique_ptr<asio::ip::tcp::socket> Socket)
+ : m_Id(Id)
+ , m_Socket(std::move(Socket))
+ , m_StartTime(Clock::now())
+ , m_State()
+ {
+ }
+
+ ~WsConnection() = default;
+
+ std::shared_ptr<WsConnection> AsShared() { return shared_from_this(); }
+
+ WebSocketId Id() const { return m_Id; }
+ asio::ip::tcp::socket& Socket() { return *m_Socket; }
+ TimePoint StartTime() const { return m_StartTime; }
+ WebSocketState State() const { return static_cast<WebSocketState>(m_State.load(std::memory_order_relaxed)); }
+ std::string RemoteAddr() const { return m_Socket->remote_endpoint().address().to_string(); }
+ asio::streambuf& ReadBuffer() { return m_ReadBuffer; }
+ WebSocketState SetState(WebSocketState NewState) { return static_cast<WebSocketState>(m_State.exchange(uint32_t(NewState))); }
+ WebSocketState Close();
+ MessageParser* Parser() { return m_MsgParser.get(); }
+ void SetParser(std::unique_ptr<MessageParser>&& Parser) { m_MsgParser = std::move(Parser); }
+ std::mutex& WriteMutex() { return m_WriteMutex; }
+
+private:
+ WebSocketId m_Id;
+ std::unique_ptr<asio::ip::tcp::socket> m_Socket;
+ TimePoint m_StartTime;
+ std::atomic_uint32_t m_State;
+ std::unique_ptr<MessageParser> m_MsgParser;
+ asio::streambuf m_ReadBuffer;
+ std::mutex m_WriteMutex;
+};
+
+WebSocketState
+WsConnection::Close()
+{
+ const auto PrevState = SetState(WebSocketState::kDisconnected);
+
+ if (PrevState != WebSocketState::kDisconnected && m_Socket->is_open())
+ {
+ m_Socket->close();
+ }
+
+ return PrevState;
+}
+
+///////////////////////////////////////////////////////////////////////////////
+class WsThreadPool
+{
+public:
+ WsThreadPool(asio::io_service& IoSvc) : m_IoSvc(IoSvc) {}
+ void Start(uint32_t ThreadCount);
+ void Stop();
+
+private:
+ asio::io_service& m_IoSvc;
+ std::vector<std::thread> m_Threads;
+ std::atomic_bool m_Running{false};
+};
+
+void
+WsThreadPool::Start(uint32_t ThreadCount)
+{
+ ZEN_ASSERT(m_Threads.empty());
+
+ ZEN_LOG_DEBUG(LogWebSocket, "starting '{}' websocket I/O thread(s)", ThreadCount);
+
+ m_Running = true;
+
+ for (uint32_t Idx = 0; Idx < ThreadCount; Idx++)
+ {
+ m_Threads.emplace_back([this, ThreadId = Idx + 1] {
+ for (;;)
+ {
+ if (m_Running == false)
+ {
+ break;
+ }
+
+ try
+ {
+ m_IoSvc.run();
+ }
+ catch (std::exception& Err)
+ {
+ ZEN_LOG_ERROR(LogWebSocket, "process websocket I/O FAILED, reason '{}'", Err.what());
+ }
+ }
+
+ ZEN_LOG_TRACE(LogWebSocket, "websocket I/O thread '{}' exiting", ThreadId);
+ });
+ }
+}
+
+void
+WsThreadPool::Stop()
+{
+ if (m_Running)
+ {
+ m_Running = false;
+
+ for (std::thread& Thread : m_Threads)
+ {
+ if (Thread.joinable())
+ {
+ Thread.join();
+ }
+ }
+
+ m_Threads.clear();
+ }
+}
+
+///////////////////////////////////////////////////////////////////////////////
+class WsServer final : public WebSocketServer
+{
+public:
+ WsServer(const WebSocketServerOptions& Options) : m_Options(Options) {}
+ virtual ~WsServer() { Shutdown(); }
+
+ virtual bool Run() override;
+ virtual void Shutdown() override;
+
+ virtual void RegisterService(WebSocketService& Service) override;
+ virtual void RegisterNotificationHandler(std::string_view Key, WebSocketService& Service) override;
+ virtual void RegisterRequestHandler(std::string_view Key, WebSocketService& Service) override;
+
+ virtual void SendNotification(WebSocketMessage&& Notification) override;
+ virtual void SendResponse(WebSocketMessage&& Response) override;
+
+private:
+ friend class WsConnection;
+
+ void AcceptConnection();
+ void CloseConnection(std::shared_ptr<WsConnection> Connection, const std::error_code& Ec);
+
+ void ReadMessage(std::shared_ptr<WsConnection> Connection);
+ void RouteMessage(WebSocketMessage&& Msg);
+ void SendMessage(WebSocketMessage&& Msg);
+
+ struct IdHasher
+ {
+ size_t operator()(WebSocketId Id) const { return size_t(Id.Value()); }
+ };
+
+ using ConnectionMap = std::unordered_map<WebSocketId, std::shared_ptr<WsConnection>, IdHasher>;
+ using RequestHandlerMap = std::unordered_map<std::string_view, WebSocketService*>;
+ using NotificationHandlerMap = std::unordered_map<std::string_view, std::vector<WebSocketService*>>;
+
+ WebSocketServerOptions m_Options;
+ asio::io_service m_IoSvc;
+ std::unique_ptr<asio::ip::tcp::acceptor> m_Acceptor;
+ std::unique_ptr<WsThreadPool> m_ThreadPool;
+ ConnectionMap m_Connections;
+ std::shared_mutex m_ConnMutex;
+ std::vector<WebSocketService*> m_Services;
+ RequestHandlerMap m_RequestHandlers;
+ NotificationHandlerMap m_NotificationHandlers;
+ std::atomic_bool m_Running{};
+};
+
+void
+WsServer::RegisterService(WebSocketService& Service)
+{
+ m_Services.push_back(&Service);
+
+ Service.Configure(*this);
+}
+
+bool
+WsServer::Run()
+{
+ static constexpr size_t ReceiveBufferSize = 256 << 10;
+ static constexpr size_t SendBufferSize = 256 << 10;
+
+ m_Acceptor = std::make_unique<asio::ip::tcp::acceptor>(m_IoSvc, asio::ip::tcp::v6());
+
+ m_Acceptor->set_option(asio::ip::v6_only(false));
+ m_Acceptor->set_option(asio::socket_base::reuse_address(true));
+ m_Acceptor->set_option(asio::ip::tcp::no_delay(true));
+ m_Acceptor->set_option(asio::socket_base::receive_buffer_size(ReceiveBufferSize));
+ m_Acceptor->set_option(asio::socket_base::send_buffer_size(SendBufferSize));
+
+#if ZEN_PLATFORM_WINDOWS
+ // On Windows, loopback connections can take advantage of a faster code path optionally with this flag.
+ // This must be used by both the client and server side, and is only effective in the absence of
+ // Windows Filtering Platform (WFP) callouts which can be installed by security software.
+ // https://docs.microsoft.com/en-us/windows/win32/winsock/sio-loopback-fast-path
+ SOCKET NativeSocket = m_Acceptor->native_handle();
+ int LoopbackOptionValue = 1;
+ DWORD OptionNumberOfBytesReturned = 0;
+ WSAIoctl(NativeSocket,
+ SIO_LOOPBACK_FAST_PATH,
+ &LoopbackOptionValue,
+ sizeof(LoopbackOptionValue),
+ NULL,
+ 0,
+ &OptionNumberOfBytesReturned,
+ 0,
+ 0);
+#endif
+
+ asio::error_code Ec;
+ m_Acceptor->bind(asio::ip::tcp::endpoint(asio::ip::address_v6::any(), m_Options.Port), Ec);
+
+ if (Ec)
+ {
+ ZEN_LOG_ERROR(LogWebSocket, "failed to bind websocket endpoint, error code '{}'", Ec.value());
+
+ return false;
+ }
+
+ m_Acceptor->listen();
+ m_Running = true;
+
+ ZEN_LOG_INFO(LogWebSocket, "web socket server running on port '{}'", m_Options.Port);
+
+ AcceptConnection();
+
+ m_ThreadPool = std::make_unique<WsThreadPool>(m_IoSvc);
+ m_ThreadPool->Start(m_Options.ThreadCount);
+
+ return true;
+}
+
+void
+WsServer::Shutdown()
+{
+ if (m_Running)
+ {
+ ZEN_LOG_INFO(LogWebSocket, "websocket server shutting down");
+
+ m_Running = false;
+
+ m_Acceptor->close();
+ m_Acceptor.reset();
+ m_IoSvc.stop();
+
+ m_ThreadPool->Stop();
+ }
+}
+
+void
+WsServer::RegisterNotificationHandler(std::string_view Key, WebSocketService& Service)
+{
+ auto Result = m_NotificationHandlers.try_emplace(Key, std::vector<WebSocketService*>());
+ Result.first->second.push_back(&Service);
+}
+
+void
+WsServer::RegisterRequestHandler(std::string_view Key, WebSocketService& Service)
+{
+ m_RequestHandlers[Key] = &Service;
+}
+
+void
+WsServer::SendNotification(WebSocketMessage&& Notification)
+{
+ ZEN_ASSERT(Notification.MessageType() == WebSocketMessageType::kNotification);
+
+ SendMessage(std::move(Notification));
+}
+void
+WsServer::SendResponse(WebSocketMessage&& Response)
+{
+ ZEN_ASSERT(Response.MessageType() == WebSocketMessageType::kResponse ||
+ Response.MessageType() == WebSocketMessageType::kStreamResponse ||
+ Response.MessageType() == WebSocketMessageType::kStreamCompleteResponse);
+
+ ZEN_ASSERT(Response.CorrelationId() != 0);
+
+ SendMessage(std::move(Response));
+}
+
+void
+WsServer::AcceptConnection()
+{
+ auto Socket = std::make_unique<asio::ip::tcp::socket>(m_IoSvc);
+ asio::ip::tcp::socket& SocketRef = *Socket.get();
+
+ m_Acceptor->async_accept(SocketRef, [this, ConnectedSocket = std::move(Socket)](const asio::error_code& Ec) mutable {
+ if (m_Running)
+ {
+ if (Ec)
+ {
+ ZEN_LOG_WARN(LogWebSocket, "accept connection FAILED, reason '{}'", Ec.message());
+ }
+ else
+ {
+ auto Connection = std::make_shared<WsConnection>(WebSocketId::New(), std::move(ConnectedSocket));
+
+ ZEN_LOG_DEBUG(LogWebSocket, "accept connection '#{} {}' OK", Connection->Id().Value(), Connection->RemoteAddr());
+
+ {
+ std::unique_lock _(m_ConnMutex);
+ m_Connections[Connection->Id()] = Connection;
+ }
+
+ Connection->SetParser(std::make_unique<HttpMessageParser>(HttpMessageParserType::kRequest));
+ Connection->SetState(WebSocketState::kHandshaking);
+
+ ReadMessage(Connection);
+ }
+
+ AcceptConnection();
+ }
+ });
+}
+
+void
+WsServer::CloseConnection(std::shared_ptr<WsConnection> Connection, const std::error_code& Ec)
+{
+ if (const auto State = Connection->Close(); State != WebSocketState::kDisconnected)
+ {
+ if (Ec)
+ {
+ ZEN_LOG_INFO(LogWebSocket, "connection '{}' closed, reason '{} ({})'", Connection->Id().Value(), Ec.message(), Ec.value());
+ }
+ else
+ {
+ ZEN_LOG_INFO(LogWebSocket, "connection '{}' closed", Connection->Id().Value());
+ }
+ }
+
+ const WebSocketId Id = Connection->Id();
+
+ {
+ std::unique_lock _(m_ConnMutex);
+ if (m_Connections.contains(Id))
+ {
+ m_Connections.erase(Id);
+ }
+ }
+}
+
+void
+WsServer::ReadMessage(std::shared_ptr<WsConnection> Connection)
+{
+ Connection->ReadBuffer().prepare(64 << 10);
+
+ asio::async_read(
+ Connection->Socket(),
+ Connection->ReadBuffer(),
+ asio::transfer_at_least(1),
+ [this, Connection](const asio::error_code& ReadEc, std::size_t) mutable {
+ if (ReadEc)
+ {
+ return CloseConnection(Connection, ReadEc);
+ }
+
+ switch (Connection->State())
+ {
+ case WebSocketState::kHandshaking:
+ {
+ HttpMessageParser& Parser = *reinterpret_cast<HttpMessageParser*>(Connection->Parser());
+ asio::const_buffer Buffer = Connection->ReadBuffer().data();
+
+ ParseMessageResult Result = Parser.ParseMessage(MemoryView(Buffer.data(), Buffer.size()));
+
+ Connection->ReadBuffer().consume(Result.ByteCount);
+
+ if (Result.Status == ParseMessageStatus::kContinue)
+ {
+ return ReadMessage(Connection);
+ }
+
+ if (Result.Status == ParseMessageStatus::kError)
+ {
+ ZEN_LOG_WARN(LogWebSocket,
+ "handshake with connection '#{} {}' FAILED, reason 'HTTP parse error'",
+ Connection->Id().Value(),
+ Connection->RemoteAddr());
+
+ return CloseConnection(Connection, std::error_code());
+ }
+
+ if (Parser.IsUpgrade() == false)
+ {
+ ZEN_LOG_DEBUG(LogWebSocket,
+ "handshake with connection '#{} {}' FAILED, reason 'invalid HTTP upgrade request'",
+ Connection->Id().Value(),
+ Connection->RemoteAddr());
+
+ constexpr auto UpgradeRequiredResponse = "HTTP/1.1 426 Upgrade Required\n\r\n\r"sv;
+
+ return async_write(Connection->Socket(),
+ asio::buffer(UpgradeRequiredResponse),
+ [this, Connection](const asio::error_code& WriteEc, std::size_t) {
+ if (WriteEc)
+ {
+ return CloseConnection(Connection, WriteEc);
+ }
+
+ Connection->Parser()->Reset();
+ Connection->SetState(WebSocketState::kHandshaking);
+
+ ReadMessage(Connection);
+ });
+ }
+
+ ZEN_ASSERT(Result.Status == ParseMessageStatus::kDone);
+
+ std::string AcceptHash;
+ std::string Reason;
+ const bool ValidHandshake = Parser.ValidateWebSocketHandshake(AcceptHash, Reason);
+
+ if (ValidHandshake == false)
+ {
+ ZEN_LOG_DEBUG(LogWebSocket,
+ "handshake with connection '{}' FAILED, reason '{}'",
+ Connection->Id().Value(),
+ Reason);
+
+ constexpr auto UpgradeRequiredResponse = "HTTP/1.1 400 Bad Request\n\r\n\r"sv;
+
+ return async_write(Connection->Socket(),
+ asio::buffer(UpgradeRequiredResponse),
+ [this, &Connection](const asio::error_code& WriteEc, std::size_t) {
+ if (WriteEc)
+ {
+ return CloseConnection(Connection, WriteEc);
+ }
+
+ Connection->Parser()->Reset();
+ Connection->SetState(WebSocketState::kHandshaking);
+
+ ReadMessage(Connection);
+ });
+ }
+
+ ExtendableStringBuilder<128> Sb;
+
+ Sb << "HTTP/1.1 101 Switching Protocols\r\n"sv;
+ Sb << "Upgrade: websocket\r\n"sv;
+ Sb << "Connection: Upgrade\r\n"sv;
+
+ // TODO: Verify protocol
+ if (Parser.Headers().contains(http_header::SecWebSocketProtocol))
+ {
+ Sb << http_header::SecWebSocketProtocol << ": " << Parser.Headers()[http_header::SecWebSocketProtocol]
+ << "\r\n";
+ }
+
+ Sb << http_header::SecWebSocketAccept << ": " << AcceptHash << "\r\n";
+ Sb << "\r\n"sv;
+
+ ZEN_LOG_DEBUG(LogWebSocket,
+ "accepting handshake from connection '#{} {}'",
+ Connection->Id().Value(),
+ Connection->RemoteAddr());
+
+ std::string Response = Sb.ToString();
+ Buffer = asio::buffer(Response);
+
+ async_write(Connection->Socket(),
+ Buffer,
+ [this, Connection, _ = std::move(Response)](const asio::error_code& WriteEc, std::size_t ByteCount) {
+ if (WriteEc)
+ {
+ ZEN_LOG_DEBUG(LogWebSocket,
+ "handshake with connection '{}' FAILED, reason '{}'",
+ Connection->Id().Value(),
+ WriteEc.message());
+
+ return CloseConnection(Connection, WriteEc);
+ }
+
+ ZEN_LOG_DEBUG(LogWebSocket,
+ "handshake ({}B) with connection '#{} {}' OK",
+ ByteCount,
+ Connection->Id().Value(),
+ Connection->RemoteAddr());
+
+ Connection->SetParser(std::make_unique<WebSocketMessageParser>());
+ Connection->SetState(WebSocketState::kConnected);
+
+ ReadMessage(Connection);
+ });
+ }
+ break;
+
+ case WebSocketState::kConnected:
+ {
+ WebSocketMessageParser& Parser = *reinterpret_cast<WebSocketMessageParser*>(Connection->Parser());
+
+ uint64_t RemainingBytes = Connection->ReadBuffer().size();
+
+ while (RemainingBytes > 0)
+ {
+ MemoryView MessageData = MemoryView(Connection->ReadBuffer().data().data(), RemainingBytes);
+ const ParseMessageResult Result = Parser.ParseMessage(MessageData);
+
+ Connection->ReadBuffer().consume(Result.ByteCount);
+ RemainingBytes = Connection->ReadBuffer().size();
+
+ if (Result.Status == ParseMessageStatus::kError)
+ {
+ ZEN_LOG_WARN(LogWebSocket, "parse websocket message FAILED, reason '{}'", Result.Reason.value());
+
+ return CloseConnection(Connection, std::error_code());
+ }
+
+ if (Result.Status == ParseMessageStatus::kContinue)
+ {
+ ZEN_ASSERT(RemainingBytes == 0);
+ continue;
+ }
+
+ WebSocketMessage Message = Parser.ConsumeMessage();
+ Parser.Reset();
+
+ Message.SetSocketId(Connection->Id());
+
+ RouteMessage(std::move(Message));
+ }
+
+ ReadMessage(Connection);
+ }
+ break;
+
+ default:
+ break;
+ };
+ });
+}
+
+void
+WsServer::RouteMessage(WebSocketMessage&& RoutedMessage)
+{
+ switch (RoutedMessage.MessageType())
+ {
+ case WebSocketMessageType::kRequest:
+ case WebSocketMessageType::kStreamRequest:
+ {
+ CbObjectView Request = RoutedMessage.Body().GetObject();
+ std::string_view Method = Request["Method"].AsString();
+ bool Handled = false;
+ bool Error = false;
+ std::exception Exception;
+
+ if (auto It = m_RequestHandlers.find(Method); It != m_RequestHandlers.end())
+ {
+ WebSocketService* Service = It->second;
+ ZEN_ASSERT(Service);
+
+ try
+ {
+ Handled = Service->HandleRequest(std::move(RoutedMessage));
+ }
+ catch (std::exception& Err)
+ {
+ Exception = std::move(Err);
+ Error = true;
+ }
+ }
+
+ if (Error || Handled == false)
+ {
+ std::string ErrorText = Error ? Exception.what() : fmt::format("'{}' Not Found", Method);
+
+ ZEN_LOG_WARN(LogWebSocket, "route request message FAILED, reason '{}'", ErrorText);
+
+ CbObjectWriter Response;
+ Response << "Error"sv << ErrorText;
+
+ WebSocketMessage ResponseMsg;
+ ResponseMsg.SetMessageType(WebSocketMessageType::kResponse);
+ ResponseMsg.SetCorrelationId(RoutedMessage.CorrelationId());
+ ResponseMsg.SetSocketId(RoutedMessage.SocketId());
+ ResponseMsg.SetBody(Response.Save());
+
+ SendResponse(std::move(ResponseMsg));
+ }
+ }
+ break;
+
+ case WebSocketMessageType::kNotification:
+ {
+ CbObjectView Notification = RoutedMessage.Body().GetObject();
+ std::string_view Message = Notification["Message"].AsString();
+
+ if (auto It = m_NotificationHandlers.find(Message); It != m_NotificationHandlers.end())
+ {
+ std::vector<WebSocketService*>& Handlers = It->second;
+
+ for (WebSocketService* Handler : Handlers)
+ {
+ Handler->HandleNotification(RoutedMessage);
+ }
+ }
+ else
+ {
+ ZEN_LOG_WARN(LogWebSocket, "route notification message FAILED, unknown notification '{}'", Message);
+ }
+ }
+ break;
+
+ default:
+ break;
+ };
+}
+
+void
+WsServer::SendMessage(WebSocketMessage&& Msg)
+{
+ std::shared_ptr<WsConnection> Connection;
+
+ {
+ std::unique_lock _(m_ConnMutex);
+
+ if (auto It = m_Connections.find(Msg.SocketId()); It != m_Connections.end())
+ {
+ Connection = It->second;
+ }
+ }
+
+ if (Connection.get() == nullptr)
+ {
+ ZEN_LOG_WARN(LogWebSocket, "send message FAILED, reason 'unknown socket ID ({})'", Msg.SocketId().Value());
+ return;
+ }
+
+ if (Connection.get() != nullptr)
+ {
+ BinaryWriter Writer;
+ Msg.Save(Writer);
+
+ ZEN_LOG_TRACE(LogWebSocket,
+ "sending '{}' message, receiver '{}', size '{}', ID '{}', total size {}",
+ ToString(Msg.MessageType()),
+ Connection->Id().Value(),
+ Msg.MessageSize(),
+ Msg.CorrelationId(),
+ NiceBytes(Writer.Size()));
+
+ {
+ ZEN_TRACE_CPU("WS::SendMessage");
+ std::unique_lock _(Connection->WriteMutex());
+ ZEN_TRACE_CPU("WS::WriteSocketData");
+ asio::write(Connection->Socket(), asio::buffer(Writer.Data(), Writer.Size()), asio::transfer_exactly(Writer.Size()));
+ }
+ }
+}
+
+///////////////////////////////////////////////////////////////////////////////
+class WsClient final : public WebSocketClient, public std::enable_shared_from_this<WsClient>
+{
+public:
+ WsClient(asio::io_context& IoCtx) : m_IoCtx(IoCtx), m_Id(WebSocketId::New()) {}
+
+ virtual ~WsClient() { Disconnect(); }
+
+ std::shared_ptr<WsClient> AsShared() { return shared_from_this(); }
+
+ virtual std::future<bool> Connect(const WebSocketConnectInfo& Info) override;
+ virtual void Disconnect() override;
+ virtual bool IsConnected() const override { return false; }
+ virtual WebSocketState State() const override { return static_cast<WebSocketState>(m_State.load()); }
+
+ virtual std::future<WebSocketMessage> SendRequest(WebSocketMessage&& Request) override;
+ virtual void OnNotification(NotificationCallback&& Cb) override;
+ virtual void OnEvent(WebSocketEvent Evt, EventCallback&& Cb) override;
+
+private:
+ WebSocketState SetState(WebSocketState NewState) { return static_cast<WebSocketState>(m_State.exchange(uint32_t(NewState))); }
+ MessageParser* Parser() { return m_MsgParser.get(); }
+ void SetParser(std::unique_ptr<MessageParser>&& Parser) { m_MsgParser = std::move(Parser); }
+ asio::streambuf& ReadBuffer() { return m_ReadBuffer; }
+ void TriggerEvent(WebSocketEvent Evt);
+ void ReadMessage();
+ void RouteMessage(WebSocketMessage&& RoutedMessage);
+
+ using PendingRequestMap = std::unordered_map<uint32_t, std::promise<WebSocketMessage>>;
+
+ asio::io_context& m_IoCtx;
+ WebSocketId m_Id;
+ std::unique_ptr<asio::ip::tcp::socket> m_Socket;
+ std::unique_ptr<MessageParser> m_MsgParser;
+ asio::streambuf m_ReadBuffer;
+ EventCallback m_EventCallbacks[3];
+ NotificationCallback m_NotificationCallback;
+ PendingRequestMap m_PendingRequests;
+ std::mutex m_RequestMutex;
+ std::promise<bool> m_ConnectPromise;
+ std::atomic_uint32_t m_State;
+ std::string m_Host;
+ int16_t m_Port{};
+};
+
+std::future<bool>
+WsClient::Connect(const WebSocketConnectInfo& Info)
+{
+ if (State() == WebSocketState::kHandshaking || State() == WebSocketState::kConnected)
+ {
+ return m_ConnectPromise.get_future();
+ }
+
+ SetState(WebSocketState::kHandshaking);
+
+ try
+ {
+ asio::ip::tcp::endpoint Endpoint(asio::ip::address::from_string(Info.Host), Info.Port);
+ m_Socket = std::make_unique<asio::ip::tcp::socket>(m_IoCtx, Endpoint.protocol());
+
+ m_Socket->connect(Endpoint);
+
+ m_Host = m_Socket->remote_endpoint().address().to_string();
+ m_Port = Info.Port;
+
+ ZEN_LOG_INFO(LogWsClient, "connected to websocket server '{}:{}'", m_Host, m_Port);
+ }
+ catch (std::exception& Err)
+ {
+ ZEN_LOG_WARN(LogWsClient, "connect to websocket server '{}:{}' FAILED, reason '{}'", Info.Host, Info.Port, Err.what());
+
+ SetState(WebSocketState::kError);
+ m_Socket.reset();
+
+ TriggerEvent(WebSocketEvent::kDisconnected);
+
+ m_ConnectPromise.set_value(false);
+
+ return m_ConnectPromise.get_future();
+ }
+
+ ExtendableStringBuilder<128> Sb;
+ Sb << "GET " << Info.Endpoint << " HTTP/1.1\r\n"sv;
+ Sb << "Host: " << Info.Host << "\r\n"sv;
+ Sb << "Upgrade: websocket\r\n"sv;
+ Sb << "Connection: upgrade\r\n"sv;
+ Sb << "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n"sv;
+
+ if (Info.Protocols.empty() == false)
+ {
+ Sb << "Sec-WebSocket-Protocol: "sv;
+ for (size_t Idx = 0; const auto& Protocol : Info.Protocols)
+ {
+ if (Idx++)
+ {
+ Sb << ", ";
+ }
+ Sb << Protocol;
+ }
+ }
+
+ Sb << "Sec-WebSocket-Version: "sv << Info.Version << "\r\n"sv;
+ Sb << "\r\n";
+
+ std::string HandshakeRequest = Sb.ToString();
+ asio::const_buffer Buffer = asio::buffer(HandshakeRequest);
+
+ ZEN_LOG_DEBUG(LogWsClient, "handshaking with '{}:{}'", m_Host, m_Port);
+
+ m_MsgParser = std::make_unique<HttpMessageParser>(HttpMessageParserType::kResponse);
+ m_MsgParser->Reset();
+
+ async_write(*m_Socket, Buffer, [Self = AsShared(), _ = std::move(HandshakeRequest)](const asio::error_code& Ec, std::size_t) {
+ if (Ec)
+ {
+ ZEN_LOG_ERROR(LogWsClient, "write data FAILED, reason '{}'", Ec.message());
+
+ Self->Disconnect();
+ }
+ else
+ {
+ Self->ReadMessage();
+ }
+ });
+
+ return m_ConnectPromise.get_future();
+}
+
+void
+WsClient::Disconnect()
+{
+ if (auto PrevState = SetState(WebSocketState::kDisconnected); PrevState != WebSocketState::kDisconnected)
+ {
+ ZEN_LOG_INFO(LogWsClient, "closing connection to '{}:{}'", m_Host, m_Port);
+
+ if (m_Socket && m_Socket->is_open())
+ {
+ m_Socket->close();
+ m_Socket.reset();
+ }
+
+ TriggerEvent(WebSocketEvent::kDisconnected);
+
+ {
+ std::unique_lock _(m_RequestMutex);
+
+ for (auto& Kv : m_PendingRequests)
+ {
+ Kv.second.set_value(WebSocketMessage());
+ }
+
+ m_PendingRequests.clear();
+ }
+ }
+}
+
+std::future<WebSocketMessage>
+WsClient::SendRequest(WebSocketMessage&& Request)
+{
+ ZEN_ASSERT(Request.MessageType() == WebSocketMessageType::kRequest);
+
+ BinaryWriter Writer;
+ Request.Save(Writer);
+
+ std::future<WebSocketMessage> FutureResponse;
+
+ {
+ std::unique_lock _(m_RequestMutex);
+
+ auto Result = m_PendingRequests.try_emplace(Request.CorrelationId(), std::promise<WebSocketMessage>());
+ ZEN_ASSERT(Result.second);
+
+ auto It = Result.first;
+ FutureResponse = It->second.get_future();
+ }
+
+ IoBuffer Buffer = IoBufferBuilder::MakeCloneFromMemory(Writer.Data(), Writer.Size());
+
+ async_write(*m_Socket, asio::buffer(Buffer.Data(), Buffer.Size()), [Self = AsShared()](const std::error_code& Ec, size_t) {
+ if (Ec)
+ {
+ ZEN_LOG_WARN(LogWsClient, "send request message FAILED, reason '{}'", Ec.message());
+
+ Self->Disconnect();
+ }
+ });
+
+ return FutureResponse;
+}
+
+void
+WsClient::OnNotification(NotificationCallback&& Cb)
+{
+ m_NotificationCallback = std::move(Cb);
+}
+
+void
+WsClient::OnEvent(WebSocketEvent Evt, WebSocketClient::EventCallback&& Cb)
+{
+ m_EventCallbacks[static_cast<uint32_t>(Evt)] = std::move(Cb);
+}
+
+void
+WsClient::TriggerEvent(WebSocketEvent Evt)
+{
+ const uint32_t Index = static_cast<uint32_t>(Evt);
+
+ if (m_EventCallbacks[Index])
+ {
+ m_EventCallbacks[Index]();
+ }
+}
+
+void
+WsClient::ReadMessage()
+{
+ m_ReadBuffer.prepare(64 << 10);
+
+ async_read(*m_Socket,
+ m_ReadBuffer,
+ asio::transfer_at_least(1),
+ [Self = AsShared()](const asio::error_code& Ec, std::size_t ByteCount) mutable {
+ const WebSocketState State = Self->State();
+
+ if (State == WebSocketState::kDisconnected)
+ {
+ return;
+ }
+
+ if (Ec)
+ {
+ ZEN_LOG_WARN(LogWsClient, "read message FAILED, reason '{}'", Ec.message());
+
+ return Self->Disconnect();
+ }
+
+ switch (State)
+ {
+ case WebSocketState::kHandshaking:
+ {
+ HttpMessageParser& Parser = *reinterpret_cast<HttpMessageParser*>(Self->Parser());
+
+ MemoryView MessageData = MemoryView(Self->ReadBuffer().data().data(), ByteCount);
+
+ ParseMessageResult Result = Parser.ParseMessage(MessageData);
+
+ Self->ReadBuffer().consume(size_t(Result.ByteCount));
+
+ if (Result.Status == ParseMessageStatus::kError)
+ {
+ ZEN_LOG_WARN(LogWsClient, "handshake FAILED, status code '{}'", Parser.StatusCode());
+
+ Self->m_ConnectPromise.set_value(false);
+
+ return Self->Disconnect();
+ }
+
+ if (Result.Status == ParseMessageStatus::kContinue)
+ {
+ return Self->ReadMessage();
+ }
+
+ ZEN_ASSERT(Result.Status == ParseMessageStatus::kDone);
+
+ if (Parser.StatusCode() != 101)
+ {
+ ZEN_LOG_WARN(LogWsClient,
+ "handshake FAILED, status '{}', status code '{}'",
+ Parser.StatusText(),
+ Parser.StatusCode());
+
+ Self->m_ConnectPromise.set_value(false);
+
+ return Self->Disconnect();
+ }
+
+ ZEN_LOG_INFO(LogWsClient, "handshake OK, status '{}'", Parser.StatusText());
+
+ Self->SetParser(std::make_unique<WebSocketMessageParser>());
+ Self->SetState(WebSocketState::kConnected);
+ Self->ReadMessage();
+ Self->TriggerEvent(WebSocketEvent::kConnected);
+
+ Self->m_ConnectPromise.set_value(true);
+ }
+ break;
+
+ case WebSocketState::kConnected:
+ {
+ WebSocketMessageParser& Parser = *reinterpret_cast<WebSocketMessageParser*>(Self->Parser());
+
+ uint64_t RemainingBytes = Self->ReadBuffer().size();
+
+ while (RemainingBytes > 0)
+ {
+ MemoryView MessageData = MemoryView(Self->ReadBuffer().data().data(), RemainingBytes);
+ const ParseMessageResult Result = Parser.ParseMessage(MessageData);
+
+ Self->ReadBuffer().consume(Result.ByteCount);
+ RemainingBytes = Self->ReadBuffer().size();
+
+ if (Result.Status == ParseMessageStatus::kError)
+ {
+ ZEN_LOG_WARN(LogWsClient, "parse websocket message FAILED, reason '{}'", Result.Reason.value());
+
+ Parser.Reset();
+ continue;
+ }
+
+ if (Result.Status == ParseMessageStatus::kContinue)
+ {
+ ZEN_ASSERT(RemainingBytes == 0);
+ continue;
+ }
+
+ WebSocketMessage Message = Parser.ConsumeMessage();
+ Parser.Reset();
+
+ Self->RouteMessage(std::move(Message));
+ }
+
+ Self->ReadMessage();
+ }
+ break;
+
+ default:
+ break;
+ }
+ });
+}
+
+void
+WsClient::RouteMessage(WebSocketMessage&& RoutedMessage)
+{
+ switch (RoutedMessage.MessageType())
+ {
+ case WebSocketMessageType::kResponse:
+ {
+ std::unique_lock _(m_RequestMutex);
+
+ if (auto It = m_PendingRequests.find(RoutedMessage.CorrelationId()); It != m_PendingRequests.end())
+ {
+ It->second.set_value(std::move(RoutedMessage));
+ m_PendingRequests.erase(It);
+ }
+ else
+ {
+ ZEN_LOG_WARN(LogWsClient,
+ "route request message FAILED, reason 'unknown correlation ID ({})'",
+ RoutedMessage.CorrelationId());
+ }
+ }
+ break;
+
+ case WebSocketMessageType::kNotification:
+ {
+ std::unique_lock _(m_RequestMutex);
+
+ if (m_NotificationCallback)
+ {
+ m_NotificationCallback(std::move(RoutedMessage));
+ }
+ }
+ break;
+
+ default:
+ ZEN_LOG_WARN(LogWsClient, "route message FAILED, reason 'invalid message type ({})'", uint8_t(RoutedMessage.MessageType()));
+ break;
+ };
+}
+
+} // namespace zen::websocket
+
+namespace zen {
+
+std::atomic_uint32_t WebSocketId::NextId{1};
+
+bool
+WebSocketMessage::Header::IsValid() const
+{
+ return Magic == ExpectedMagic && StatusCode > 0 && uint8_t(MessageType) > uint8_t(WebSocketMessageType::kInvalid) &&
+ uint8_t(MessageType) < uint8_t(WebSocketMessageType::kCount);
+}
+
+std::atomic_uint32_t WebSocketMessage::NextCorrelationId{1};
+
+void
+WebSocketMessage::SetMessageType(WebSocketMessageType MessageType)
+{
+ m_Header.MessageType = MessageType;
+}
+
+void
+WebSocketMessage::SetBody(CbPackage&& Body)
+{
+ m_Body = std::move(Body);
+}
+void
+WebSocketMessage::SetBody(CbObject&& Body)
+{
+ CbPackage Pkg;
+ Pkg.SetObject(Body);
+
+ SetBody(std::move(Pkg));
+}
+
+void
+WebSocketMessage::Save(BinaryWriter& Writer)
+{
+ Writer.Write(&m_Header, HeaderSize);
+
+ if (m_Body.has_value())
+ {
+ const CbObject& Obj = m_Body.value().GetObject();
+ MemoryView View = Obj.GetBuffer().GetView();
+
+ const CbValidateError ValidationResult = ValidateCompactBinary(View, CbValidateMode::All);
+ ZEN_ASSERT(ValidationResult == CbValidateError::None);
+
+ m_Body.value().Save(Writer);
+ }
+
+ if (m_Header.CorrelationId == 0 && MessageType() == WebSocketMessageType::kRequest)
+ {
+ m_Header.CorrelationId = NextCorrelationId.fetch_add(1);
+ }
+
+ m_Header.MessageSize = Writer.Size() - HeaderSize;
+
+ Writer.GetMutableView().CopyFrom(MakeMemoryView(&m_Header, HeaderSize));
+}
+
+bool
+WebSocketMessage::TryLoadHeader(MemoryView Memory)
+{
+ if (Memory.GetSize() < HeaderSize)
+ {
+ return false;
+ }
+
+ MutableMemoryView HeaderView(&m_Header, HeaderSize);
+
+ HeaderView.CopyFrom(Memory);
+
+ return m_Header.IsValid();
+}
+
+void
+WebSocketService::Configure(WebSocketServer& Server)
+{
+ ZEN_ASSERT(m_SocketServer == nullptr);
+
+ m_SocketServer = &Server;
+
+ RegisterHandlers(Server);
+}
+
+void
+WebSocketService::SendStreamResponse(WebSocketId SocketId, uint32_t CorrelationId, CbPackage&& StreamResponse, bool IsStreamComplete)
+{
+ WebSocketMessage Message;
+
+ Message.SetMessageType(IsStreamComplete ? WebSocketMessageType::kStreamCompleteResponse : WebSocketMessageType::kStreamResponse);
+ Message.SetCorrelationId(CorrelationId);
+ Message.SetSocketId(SocketId);
+ Message.SetBody(std::move(StreamResponse));
+
+ SocketServer().SendResponse(std::move(Message));
+}
+
+void
+WebSocketService::SendStreamResponse(WebSocketId SocketId, uint32_t CorrelationId, CbObject&& StreamResponse, bool IsStreamComplete)
+{
+ CbPackage Response;
+ Response.SetObject(std::move(StreamResponse));
+
+ SendStreamResponse(SocketId, CorrelationId, std::move(Response), IsStreamComplete);
+}
+
+std::unique_ptr<WebSocketServer>
+WebSocketServer::Create(const WebSocketServerOptions& Options)
+{
+ return std::make_unique<websocket::WsServer>(Options);
+}
+
+std::shared_ptr<WebSocketClient>
+WebSocketClient::Create(asio::io_context& IoCtx)
+{
+ return std::make_shared<websocket::WsClient>(IoCtx);
+}
+
+} // namespace zen
diff --git a/src/zenhttp/xmake.lua b/src/zenhttp/xmake.lua
new file mode 100644
index 000000000..b0dbdbc79
--- /dev/null
+++ b/src/zenhttp/xmake.lua
@@ -0,0 +1,14 @@
+-- Copyright Epic Games, Inc. All Rights Reserved.
+
+target('zenhttp')
+ set_kind("static")
+ add_headerfiles("**.h")
+ add_files("**.cpp")
+ add_files("httpsys.cpp", {unity_ignored=true})
+ add_includedirs("include", {public=true})
+ add_deps("zencore")
+ add_packages(
+ "vcpkg::gsl-lite",
+ "vcpkg::http-parser"
+ )
+ add_options("httpsys")
diff --git a/src/zenhttp/zenhttp.cpp b/src/zenhttp/zenhttp.cpp
new file mode 100644
index 000000000..4bd6a5697
--- /dev/null
+++ b/src/zenhttp/zenhttp.cpp
@@ -0,0 +1,22 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zenhttp/zenhttp.h>
+
+#if ZEN_WITH_TESTS
+
+# include <zenhttp/httpclient.h>
+# include <zenhttp/httpserver.h>
+# include <zenhttp/httpshared.h>
+
+namespace zen {
+
+void
+zenhttp_forcelinktests()
+{
+ http_forcelink();
+ forcelink_httpshared();
+}
+
+} // namespace zen
+
+#endif