diff options
Diffstat (limited to 'src/zenhttp')
| -rw-r--r-- | src/zenhttp/httpasio.cpp | 1372 | ||||
| -rw-r--r-- | src/zenhttp/httpasio.h | 36 | ||||
| -rw-r--r-- | src/zenhttp/httpclient.cpp | 176 | ||||
| -rw-r--r-- | src/zenhttp/httpnull.cpp | 83 | ||||
| -rw-r--r-- | src/zenhttp/httpnull.h | 29 | ||||
| -rw-r--r-- | src/zenhttp/httpserver.cpp | 885 | ||||
| -rw-r--r-- | src/zenhttp/httpshared.cpp | 809 | ||||
| -rw-r--r-- | src/zenhttp/httpsys.cpp | 1674 | ||||
| -rw-r--r-- | src/zenhttp/httpsys.h | 90 | ||||
| -rw-r--r-- | src/zenhttp/include/zenhttp/httpclient.h | 47 | ||||
| -rw-r--r-- | src/zenhttp/include/zenhttp/httpcommon.h | 181 | ||||
| -rw-r--r-- | src/zenhttp/include/zenhttp/httpserver.h | 315 | ||||
| -rw-r--r-- | src/zenhttp/include/zenhttp/httpshared.h | 163 | ||||
| -rw-r--r-- | src/zenhttp/include/zenhttp/websocket.h | 256 | ||||
| -rw-r--r-- | src/zenhttp/include/zenhttp/zenhttp.h | 21 | ||||
| -rw-r--r-- | src/zenhttp/iothreadpool.cpp | 49 | ||||
| -rw-r--r-- | src/zenhttp/iothreadpool.h | 37 | ||||
| -rw-r--r-- | src/zenhttp/websocketasio.cpp | 1613 | ||||
| -rw-r--r-- | src/zenhttp/xmake.lua | 14 | ||||
| -rw-r--r-- | src/zenhttp/zenhttp.cpp | 22 |
20 files changed, 7872 insertions, 0 deletions
diff --git a/src/zenhttp/httpasio.cpp b/src/zenhttp/httpasio.cpp new file mode 100644 index 000000000..79b2c0a3d --- /dev/null +++ b/src/zenhttp/httpasio.cpp @@ -0,0 +1,1372 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "httpasio.h" + +#include <zencore/logging.h> +#include <zenhttp/httpserver.h> + +#include <deque> +#include <memory> +#include <string_view> +#include <vector> + +ZEN_THIRD_PARTY_INCLUDES_START +#if ZEN_PLATFORM_WINDOWS +# include <conio.h> +# include <mstcpip.h> +#endif +#include <http_parser.h> +#include <asio.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + +#define ASIO_VERBOSE_TRACE 0 + +#if ASIO_VERBOSE_TRACE +# define ZEN_TRACE_VERBOSE ZEN_TRACE +#else +# define ZEN_TRACE_VERBOSE(fmtstr, ...) +#endif + +namespace zen::asio_http { + +using namespace std::literals; + +struct HttpAcceptor; +struct HttpRequest; +struct HttpResponse; +struct HttpServerConnection; + +static constinit uint32_t HashContentLength = HashStringAsLowerDjb2("Content-Length"sv); +static constinit uint32_t HashContentType = HashStringAsLowerDjb2("Content-Type"sv); +static constinit uint32_t HashAccept = HashStringAsLowerDjb2("Accept"sv); +static constinit uint32_t HashExpect = HashStringAsLowerDjb2("Expect"sv); +static constinit uint32_t HashSession = HashStringAsLowerDjb2("UE-Session"sv); +static constinit uint32_t HashRequest = HashStringAsLowerDjb2("UE-Request"sv); +static constinit uint32_t HashRange = HashStringAsLowerDjb2("Range"sv); + +inline spdlog::logger& +InitLogger() +{ + spdlog::logger& Logger = logging::Get("asio"); + // Logger.set_level(spdlog::level::trace); + return Logger; +} + +inline spdlog::logger& +Log() +{ + static spdlog::logger& g_Logger = InitLogger(); + return g_Logger; +} + +////////////////////////////////////////////////////////////////////////// + +struct HttpAsioServerImpl +{ +public: + HttpAsioServerImpl(); + ~HttpAsioServerImpl(); + + int Start(uint16_t Port, int ThreadCount); + void Stop(); + void RegisterService(const char* UrlPath, HttpService& Service); + HttpService* RouteRequest(std::string_view Url); + + asio::io_service m_IoService; + asio::io_service::work m_Work{m_IoService}; + std::unique_ptr<asio_http::HttpAcceptor> m_Acceptor; + std::vector<std::thread> m_ThreadPool; + + struct ServiceEntry + { + std::string ServiceUrlPath; + HttpService* Service; + }; + + RwLock m_Lock; + std::vector<ServiceEntry> m_UriHandlers; +}; + +/** + * This is the class which request handlers use to interact with the server instance + */ + +class HttpAsioServerRequest : public HttpServerRequest +{ +public: + HttpAsioServerRequest(asio_http::HttpRequest& Request, HttpService& Service, IoBuffer PayloadBuffer); + ~HttpAsioServerRequest(); + + virtual Oid ParseSessionId() const override; + virtual uint32_t ParseRequestId() const override; + + virtual IoBuffer ReadPayload() override; + virtual void WriteResponse(HttpResponseCode ResponseCode) override; + virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span<IoBuffer> Blobs) override; + virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::u8string_view ResponseString) override; + virtual void WriteResponseAsync(std::function<void(HttpServerRequest&)>&& ContinuationHandler) override; + virtual bool TryGetRanges(HttpRanges& Ranges) override; + + using HttpServerRequest::WriteResponse; + + HttpAsioServerRequest(const HttpAsioServerRequest&) = delete; + HttpAsioServerRequest& operator=(const HttpAsioServerRequest&) = delete; + + asio_http::HttpRequest& m_Request; + IoBuffer m_PayloadBuffer; + std::unique_ptr<HttpResponse> m_Response; +}; + +struct HttpRequest +{ + explicit HttpRequest(HttpServerConnection& Connection) : m_Connection(Connection) {} + + void Initialize(); + size_t ConsumeData(const char* InputData, size_t DataSize); + void ResetState(); + + HttpVerb RequestVerb() const { return m_RequestVerb; } + bool IsKeepAlive() const { return m_KeepAlive; } + std::string_view Url() const { return m_NormalizedUrl.empty() ? std::string_view(m_Url, m_UrlLength) : m_NormalizedUrl; } + std::string_view QueryString() const { return std::string_view(m_QueryString, m_QueryLength); } + IoBuffer Body() { return m_BodyBuffer; } + + inline HttpContentType ContentType() + { + if (m_ContentTypeHeaderIndex < 0) + { + return HttpContentType::kUnknownContentType; + } + + return ParseContentType(m_Headers[m_ContentTypeHeaderIndex].Value); + } + + inline HttpContentType AcceptType() + { + if (m_AcceptHeaderIndex < 0) + { + return HttpContentType::kUnknownContentType; + } + + return ParseContentType(m_Headers[m_AcceptHeaderIndex].Value); + } + + Oid SessionId() const { return m_SessionId; } + int RequestId() const { return m_RequestId; } + + std::string_view RangeHeader() const { return m_RangeHeaderIndex != -1 ? m_Headers[m_RangeHeaderIndex].Value : std::string_view(); } + +private: + struct HeaderEntry + { + HeaderEntry() = default; + + HeaderEntry(std::string_view InName, std::string_view InValue) : Name(InName), Value(InValue) {} + + std::string_view Name; + std::string_view Value; + }; + + HttpServerConnection& m_Connection; + char* m_HeaderCursor = m_HeaderBuffer; + char* m_Url = nullptr; + size_t m_UrlLength = 0; + char* m_QueryString = nullptr; + size_t m_QueryLength = 0; + char* m_CurrentHeaderName = nullptr; // Used while parsing headers + size_t m_CurrentHeaderNameLength = 0; + char* m_CurrentHeaderValue = nullptr; // Used while parsing headers + size_t m_CurrentHeaderValueLength = 0; + std::vector<HeaderEntry> m_Headers; + int8_t m_ContentLengthHeaderIndex; + int8_t m_AcceptHeaderIndex; + int8_t m_ContentTypeHeaderIndex; + int8_t m_RangeHeaderIndex; + HttpVerb m_RequestVerb; + bool m_KeepAlive = false; + bool m_Expect100Continue = false; + int m_RequestId = -1; + Oid m_SessionId{}; + IoBuffer m_BodyBuffer; + uint64_t m_BodyPosition = 0; + http_parser m_Parser; + char m_HeaderBuffer[1024]; + std::string m_NormalizedUrl; + + void AppendCurrentHeader(); + + int OnMessageBegin(); + int OnUrl(const char* Data, size_t Bytes); + int OnHeader(const char* Data, size_t Bytes); + int OnHeaderValue(const char* Data, size_t Bytes); + int OnHeadersComplete(); + int OnBody(const char* Data, size_t Bytes); + int OnMessageComplete(); + + static HttpRequest* GetThis(http_parser* Parser) { return reinterpret_cast<HttpRequest*>(Parser->data); } + static http_parser_settings s_ParserSettings; +}; + +struct HttpResponse +{ +public: + HttpResponse() = default; + explicit HttpResponse(HttpContentType ContentType) : m_ContentType(ContentType) {} + + void InitializeForPayload(uint16_t ResponseCode, std::span<IoBuffer> BlobList) + { + m_ResponseCode = ResponseCode; + const uint32_t ChunkCount = gsl::narrow<uint32_t>(BlobList.size()); + + m_DataBuffers.reserve(ChunkCount); + + for (IoBuffer& Buffer : BlobList) + { +#if 1 + m_DataBuffers.emplace_back(std::move(Buffer)).MakeOwned(); +#else + IoBuffer TempBuffer = std::move(Buffer); + TempBuffer.MakeOwned(); + m_DataBuffers.emplace_back(IoBufferBuilder::ReadFromFileMaybe(TempBuffer)); +#endif + } + + uint64_t LocalDataSize = 0; + + m_AsioBuffers.push_back({}); // Placeholder for header + + for (IoBuffer& Buffer : m_DataBuffers) + { + uint64_t BufferDataSize = Buffer.Size(); + + ZEN_ASSERT(BufferDataSize); + + LocalDataSize += BufferDataSize; + + IoBufferFileReference FileRef; + if (Buffer.GetFileReference(/* out */ FileRef)) + { + // TODO: Use direct file transfer, via TransmitFile/sendfile + // + // this looks like it requires some custom asio plumbing however + + m_AsioBuffers.push_back({Buffer.Data(), Buffer.Size()}); + } + else + { + // Send from memory + + m_AsioBuffers.push_back({Buffer.Data(), Buffer.Size()}); + } + } + m_ContentLength = LocalDataSize; + + auto Headers = GetHeaders(); + m_AsioBuffers[0] = asio::const_buffer(Headers.data(), Headers.size()); + } + + uint16_t ResponseCode() const { return m_ResponseCode; } + uint64_t ContentLength() const { return m_ContentLength; } + + const std::vector<asio::const_buffer>& AsioBuffers() const { return m_AsioBuffers; } + + std::string_view GetHeaders() + { + m_Headers << "HTTP/1.1 " << ResponseCode() << " " << ReasonStringForHttpResultCode(ResponseCode()) << "\r\n" + << "Content-Type: " << MapContentTypeToString(m_ContentType) << "\r\n" + << "Content-Length: " << ContentLength() << "\r\n"sv; + + if (!m_IsKeepAlive) + { + m_Headers << "Connection: close\r\n"sv; + } + + m_Headers << "\r\n"sv; + + return m_Headers; + } + + void SuppressPayload() { m_AsioBuffers.resize(1); } + +private: + uint16_t m_ResponseCode = 0; + bool m_IsKeepAlive = true; + HttpContentType m_ContentType = HttpContentType::kBinary; + uint64_t m_ContentLength = 0; + std::vector<IoBuffer> m_DataBuffers; + std::vector<asio::const_buffer> m_AsioBuffers; + ExtendableStringBuilder<160> m_Headers; +}; + +////////////////////////////////////////////////////////////////////////// + +struct HttpServerConnection : std::enable_shared_from_this<HttpServerConnection> +{ + HttpServerConnection(HttpAsioServerImpl& Server, std::unique_ptr<asio::ip::tcp::socket>&& Socket); + ~HttpServerConnection(); + + void HandleNewRequest(); + void TerminateConnection(); + void HandleRequest(); + + std::shared_ptr<HttpServerConnection> AsSharedPtr() { return shared_from_this(); } + +private: + enum class RequestState + { + kInitialState, + kInitialRead, + kReadingMore, + kWriting, + kWritingFinal, + kDone, + kTerminated + }; + + RequestState m_RequestState = RequestState::kInitialState; + HttpRequest m_RequestData{*this}; + + void EnqueueRead(); + void OnDataReceived(const asio::error_code& Ec, std::size_t ByteCount); + void OnResponseDataSent(const asio::error_code& Ec, std::size_t ByteCount, bool Pop = false); + void OnError(); + + HttpAsioServerImpl& m_Server; + asio::streambuf m_RequestBuffer; + std::unique_ptr<asio::ip::tcp::socket> m_Socket; + std::atomic<uint32_t> m_RequestCounter{0}; + uint32_t m_ConnectionId = 0; + Ref<IHttpPackageHandler> m_PackageHandler; + + RwLock m_ResponsesLock; + std::deque<std::unique_ptr<HttpResponse>> m_Responses; +}; + +std::atomic<uint32_t> g_ConnectionIdCounter{0}; + +HttpServerConnection::HttpServerConnection(HttpAsioServerImpl& Server, std::unique_ptr<asio::ip::tcp::socket>&& Socket) +: m_Server(Server) +, m_Socket(std::move(Socket)) +, m_ConnectionId(g_ConnectionIdCounter.fetch_add(1)) +{ + ZEN_TRACE_VERBOSE("new connection #{}", m_ConnectionId); +} + +HttpServerConnection::~HttpServerConnection() +{ + ZEN_TRACE_VERBOSE("destroying connection #{}", m_ConnectionId); +} + +void +HttpServerConnection::HandleNewRequest() +{ + m_RequestData.Initialize(); + + EnqueueRead(); +} + +void +HttpServerConnection::TerminateConnection() +{ + m_RequestState = RequestState::kTerminated; + + std::error_code Ec; + m_Socket->close(Ec); +} + +void +HttpServerConnection::EnqueueRead() +{ + if (m_RequestState == RequestState::kInitialRead) + { + m_RequestState = RequestState::kReadingMore; + } + else + { + m_RequestState = RequestState::kInitialRead; + } + + m_RequestBuffer.prepare(64 * 1024); + + asio::async_read(*m_Socket.get(), + m_RequestBuffer, + asio::transfer_at_least(1), + [Conn = AsSharedPtr()](const asio::error_code& Ec, std::size_t ByteCount) { Conn->OnDataReceived(Ec, ByteCount); }); +} + +void +HttpServerConnection::OnDataReceived(const asio::error_code& Ec, [[maybe_unused]] std::size_t ByteCount) +{ + if (Ec) + { + if (m_RequestState == RequestState::kDone || m_RequestState == RequestState::kInitialRead) + { + ZEN_TRACE_VERBOSE("on data received ERROR (EXPECTED), connection '{}' reason '{}'", m_ConnectionId, Ec.message()); + return; + } + else + { + ZEN_WARN("on data received ERROR, connection '{}' reason '{}'", m_ConnectionId, Ec.message()); + return OnError(); + } + } + + ZEN_TRACE_VERBOSE("on data received, connection '{}', request '{}', thread '{}', bytes '{}'", + m_ConnectionId, + m_RequestCounter.load(std::memory_order_relaxed), + zen::GetCurrentThreadId(), + NiceBytes(ByteCount)); + + while (m_RequestBuffer.size()) + { + const asio::const_buffer& InputBuffer = m_RequestBuffer.data(); + + size_t Result = m_RequestData.ConsumeData((const char*)InputBuffer.data(), InputBuffer.size()); + if (Result == ~0ull) + { + return OnError(); + } + + m_RequestBuffer.consume(Result); + } + + switch (m_RequestState) + { + case RequestState::kDone: + case RequestState::kWritingFinal: + case RequestState::kTerminated: + break; + + default: + EnqueueRead(); + break; + } +} + +void +HttpServerConnection::OnResponseDataSent(const asio::error_code& Ec, [[maybe_unused]] std::size_t ByteCount, bool Pop) +{ + if (Ec) + { + ZEN_WARN("on data sent ERROR, connection '{}' reason '{}'", m_ConnectionId, Ec.message()); + OnError(); + } + else + { + ZEN_TRACE_VERBOSE("on data sent, connection '{}', request '{}', thread '{}', bytes '{}'", + m_ConnectionId, + m_RequestCounter.load(std::memory_order_relaxed), + zen::GetCurrentThreadId(), + NiceBytes(ByteCount)); + + if (!m_RequestData.IsKeepAlive()) + { + m_RequestState = RequestState::kDone; + + m_Socket->close(); + } + else + { + if (Pop) + { + RwLock::ExclusiveLockScope _(m_ResponsesLock); + m_Responses.pop_front(); + } + + m_RequestCounter.fetch_add(1); + } + } +} + +void +HttpServerConnection::OnError() +{ + m_Socket->close(); +} + +void +HttpServerConnection::HandleRequest() +{ + if (!m_RequestData.IsKeepAlive()) + { + m_RequestState = RequestState::kWritingFinal; + + std::error_code Ec; + m_Socket->shutdown(asio::socket_base::shutdown_receive, Ec); + + if (Ec) + { + ZEN_WARN("socket shutdown ERROR, reason '{}'", Ec.message()); + } + } + else + { + m_RequestState = RequestState::kWriting; + } + + if (HttpService* Service = m_Server.RouteRequest(m_RequestData.Url())) + { + HttpAsioServerRequest Request(m_RequestData, *Service, m_RequestData.Body()); + + ZEN_TRACE_VERBOSE("handle request, connection '{}' request '{}'", m_ConnectionId, m_RequestCounter.load(std::memory_order_relaxed)); + + if (!HandlePackageOffers(*Service, Request, m_PackageHandler)) + { + try + { + Service->HandleRequest(Request); + } + catch (std::exception& ex) + { + ZEN_ERROR("Caught exception while handling request: '{}'", ex.what()); + + Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, ex.what()); + } + } + + if (std::unique_ptr<HttpResponse> Response = std::move(Request.m_Response)) + { + // Transmit the response + + if (m_RequestData.RequestVerb() == HttpVerb::kHead) + { + Response->SuppressPayload(); + } + + auto ResponseBuffers = Response->AsioBuffers(); + + uint64_t ResponseLength = 0; + + for (auto& Buffer : ResponseBuffers) + { + ResponseLength += Buffer.size(); + } + + { + RwLock::ExclusiveLockScope _(m_ResponsesLock); + m_Responses.push_back(std::move(Response)); + } + + // TODO: should cork/uncork for Linux? + + asio::async_write(*m_Socket.get(), + ResponseBuffers, + asio::transfer_exactly(ResponseLength), + [Conn = AsSharedPtr()](const asio::error_code& Ec, std::size_t ByteCount) { + Conn->OnResponseDataSent(Ec, ByteCount, true); + }); + + return; + } + } + + if (m_RequestData.RequestVerb() == HttpVerb::kHead) + { + std::string_view Response = + "HTTP/1.1 404 NOT FOUND\r\n" + "\r\n"sv; + + if (!m_RequestData.IsKeepAlive()) + { + Response = + "HTTP/1.1 404 NOT FOUND\r\n" + "Connection: close\r\n" + "\r\n"sv; + } + + asio::async_write( + *m_Socket.get(), + asio::buffer(Response), + [Conn = AsSharedPtr()](const asio::error_code& Ec, std::size_t ByteCount) { Conn->OnResponseDataSent(Ec, ByteCount); }); + } + else + { + std::string_view Response = + "HTTP/1.1 404 NOT FOUND\r\n" + "Content-Length: 23\r\n" + "Content-Type: text/plain\r\n" + "\r\n" + "No suitable route found"sv; + + if (!m_RequestData.IsKeepAlive()) + { + Response = + "HTTP/1.1 404 NOT FOUND\r\n" + "Content-Length: 23\r\n" + "Content-Type: text/plain\r\n" + "Connection: close\r\n" + "\r\n" + "No suitable route found"sv; + } + + asio::async_write( + *m_Socket.get(), + asio::buffer(Response), + [Conn = AsSharedPtr()](const asio::error_code& Ec, std::size_t ByteCount) { Conn->OnResponseDataSent(Ec, ByteCount); }); + } +} + +////////////////////////////////////////////////////////////////////////// +// +// HttpRequest +// + +http_parser_settings HttpRequest::s_ParserSettings{ + .on_message_begin = [](http_parser* p) { return GetThis(p)->OnMessageBegin(); }, + .on_url = [](http_parser* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnUrl(Data, ByteCount); }, + .on_status = + [](http_parser* p, const char* Data, size_t ByteCount) { + ZEN_UNUSED(p, Data, ByteCount); + return 0; + }, + .on_header_field = [](http_parser* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnHeader(Data, ByteCount); }, + .on_header_value = [](http_parser* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnHeaderValue(Data, ByteCount); }, + .on_headers_complete = [](http_parser* p) { return GetThis(p)->OnHeadersComplete(); }, + .on_body = [](http_parser* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnBody(Data, ByteCount); }, + .on_message_complete = [](http_parser* p) { return GetThis(p)->OnMessageComplete(); }, + .on_chunk_header{}, + .on_chunk_complete{}}; + +void +HttpRequest::Initialize() +{ + http_parser_init(&m_Parser, HTTP_REQUEST); + m_Parser.data = this; + + ResetState(); +} + +size_t +HttpRequest::ConsumeData(const char* InputData, size_t DataSize) +{ + const size_t ConsumedBytes = http_parser_execute(&m_Parser, &s_ParserSettings, InputData, DataSize); + + http_errno HttpErrno = HTTP_PARSER_ERRNO((&m_Parser)); + + if (HttpErrno && HttpErrno != HPE_INVALID_EOF_STATE) + { + ZEN_WARN("HTTP parser error {} ('{}'). Closing connection", http_errno_name(HttpErrno), http_errno_description(HttpErrno)); + return ~0ull; + } + + return ConsumedBytes; +} + +int +HttpRequest::OnUrl(const char* Data, size_t Bytes) +{ + if (!m_Url) + { + ZEN_ASSERT_SLOW(m_UrlLength == 0); + m_Url = m_HeaderCursor; + } + + const size_t RemainingBufferSpace = sizeof m_HeaderBuffer + m_HeaderBuffer - m_HeaderCursor; + + if (RemainingBufferSpace < Bytes) + { + ZEN_WARN("HTTP parser does not have enough space for incoming request, need {} more bytes", Bytes - RemainingBufferSpace); + return 1; + } + + memcpy(m_HeaderCursor, Data, Bytes); + m_HeaderCursor += Bytes; + m_UrlLength += Bytes; + + return 0; +} + +int +HttpRequest::OnHeader(const char* Data, size_t Bytes) +{ + if (m_CurrentHeaderValueLength) + { + AppendCurrentHeader(); + + m_CurrentHeaderNameLength = 0; + m_CurrentHeaderValueLength = 0; + m_CurrentHeaderName = m_HeaderCursor; + } + else if (m_CurrentHeaderName == nullptr) + { + m_CurrentHeaderName = m_HeaderCursor; + } + + const size_t RemainingBufferSpace = sizeof m_HeaderBuffer + m_HeaderBuffer - m_HeaderCursor; + if (RemainingBufferSpace < Bytes) + { + ZEN_WARN("HTTP parser does not have enough space for incoming header name, need {} more bytes", Bytes - RemainingBufferSpace); + return 1; + } + + memcpy(m_HeaderCursor, Data, Bytes); + m_HeaderCursor += Bytes; + m_CurrentHeaderNameLength += Bytes; + + return 0; +} + +void +HttpRequest::AppendCurrentHeader() +{ + std::string_view HeaderName(m_CurrentHeaderName, m_CurrentHeaderNameLength); + std::string_view HeaderValue(m_CurrentHeaderValue, m_CurrentHeaderValueLength); + + const uint32_t HeaderHash = HashStringAsLowerDjb2(HeaderName); + + if (HeaderHash == HashContentLength) + { + m_ContentLengthHeaderIndex = (int8_t)m_Headers.size(); + } + else if (HeaderHash == HashAccept) + { + m_AcceptHeaderIndex = (int8_t)m_Headers.size(); + } + else if (HeaderHash == HashContentType) + { + m_ContentTypeHeaderIndex = (int8_t)m_Headers.size(); + } + else if (HeaderHash == HashSession) + { + m_SessionId = Oid::FromHexString(HeaderValue); + } + else if (HeaderHash == HashRequest) + { + std::from_chars(HeaderValue.data(), HeaderValue.data() + HeaderValue.size(), m_RequestId); + } + else if (HeaderHash == HashExpect) + { + if (HeaderValue == "100-continue"sv) + { + // We don't currently do anything with this + m_Expect100Continue = true; + } + else + { + ZEN_INFO("Unexpected expect - Expect: {}", HeaderValue); + } + } + else if (HeaderHash == HashRange) + { + m_RangeHeaderIndex = (int8_t)m_Headers.size(); + } + + m_Headers.emplace_back(HeaderName, HeaderValue); +} + +int +HttpRequest::OnHeaderValue(const char* Data, size_t Bytes) +{ + if (m_CurrentHeaderValueLength == 0) + { + m_CurrentHeaderValue = m_HeaderCursor; + } + + const size_t RemainingBufferSpace = sizeof m_HeaderBuffer + m_HeaderBuffer - m_HeaderCursor; + if (RemainingBufferSpace < Bytes) + { + ZEN_WARN("HTTP parser does not have enough space for incoming header value, need {} more bytes", Bytes - RemainingBufferSpace); + return 1; + } + + memcpy(m_HeaderCursor, Data, Bytes); + m_HeaderCursor += Bytes; + m_CurrentHeaderValueLength += Bytes; + + return 0; +} + +static void +NormalizeUrlPath(const char* Url, size_t UrlLength, std::string& NormalizedUrl) +{ + bool LastCharWasSeparator = false; + for (std::string_view::size_type UrlIndex = 0; UrlIndex < UrlLength; ++UrlIndex) + { + const char UrlChar = Url[UrlIndex]; + const bool IsSeparator = (UrlChar == '/'); + + if (IsSeparator && LastCharWasSeparator) + { + if (NormalizedUrl.empty()) + { + NormalizedUrl.reserve(UrlLength); + NormalizedUrl.append(Url, UrlIndex); + } + + if (!LastCharWasSeparator) + { + NormalizedUrl.push_back('/'); + } + } + else if (!NormalizedUrl.empty()) + { + NormalizedUrl.push_back(UrlChar); + } + + LastCharWasSeparator = IsSeparator; + } +} + +int +HttpRequest::OnHeadersComplete() +{ + if (m_CurrentHeaderValueLength) + { + AppendCurrentHeader(); + } + + if (m_ContentLengthHeaderIndex >= 0) + { + std::string_view& Value = m_Headers[m_ContentLengthHeaderIndex].Value; + uint64_t ContentLength = 0; + std::from_chars(Value.data(), Value.data() + Value.size(), ContentLength); + + if (ContentLength) + { + m_BodyBuffer = IoBuffer(ContentLength); + } + + m_BodyBuffer.SetContentType(ContentType()); + + m_BodyPosition = 0; + } + + m_KeepAlive = !!http_should_keep_alive(&m_Parser); + + switch (m_Parser.method) + { + case HTTP_GET: + m_RequestVerb = HttpVerb::kGet; + break; + + case HTTP_POST: + m_RequestVerb = HttpVerb::kPost; + break; + + case HTTP_PUT: + m_RequestVerb = HttpVerb::kPut; + break; + + case HTTP_DELETE: + m_RequestVerb = HttpVerb::kDelete; + break; + + case HTTP_HEAD: + m_RequestVerb = HttpVerb::kHead; + break; + + case HTTP_COPY: + m_RequestVerb = HttpVerb::kCopy; + break; + + case HTTP_OPTIONS: + m_RequestVerb = HttpVerb::kOptions; + break; + + default: + ZEN_WARN("invalid HTTP method: '{}'", http_method_str((http_method)m_Parser.method)); + break; + } + + std::string_view Url(m_Url, m_UrlLength); + + if (auto QuerySplit = Url.find_first_of('?'); QuerySplit != std::string_view::npos) + { + m_UrlLength = QuerySplit; + m_QueryString = m_Url + QuerySplit + 1; + m_QueryLength = Url.size() - QuerySplit - 1; + } + + NormalizeUrlPath(m_Url, m_UrlLength, m_NormalizedUrl); + + return 0; +} + +int +HttpRequest::OnBody(const char* Data, size_t Bytes) +{ + if (m_BodyPosition + Bytes > m_BodyBuffer.Size()) + { + ZEN_WARN("HTTP parser incoming body is larger than content size, need {} more bytes", + (m_BodyPosition + Bytes) - m_BodyBuffer.Size()); + return 1; + } + memcpy(reinterpret_cast<uint8_t*>(m_BodyBuffer.MutableData()) + m_BodyPosition, Data, Bytes); + m_BodyPosition += Bytes; + + if (http_body_is_final(&m_Parser)) + { + if (m_BodyPosition != m_BodyBuffer.Size()) + { + ZEN_WARN("Body mismatch! {} != {}", m_BodyPosition, m_BodyBuffer.Size()); + return 1; + } + } + + return 0; +} + +void +HttpRequest::ResetState() +{ + m_HeaderCursor = m_HeaderBuffer; + m_CurrentHeaderName = nullptr; + m_CurrentHeaderNameLength = 0; + m_CurrentHeaderValue = nullptr; + m_CurrentHeaderValueLength = 0; + m_CurrentHeaderName = nullptr; + m_Url = nullptr; + m_UrlLength = 0; + m_QueryString = nullptr; + m_QueryLength = 0; + m_ContentLengthHeaderIndex = -1; + m_AcceptHeaderIndex = -1; + m_ContentTypeHeaderIndex = -1; + m_RangeHeaderIndex = -1; + m_Expect100Continue = false; + m_BodyBuffer = {}; + m_BodyPosition = 0; + m_Headers.clear(); + m_NormalizedUrl.clear(); +} + +int +HttpRequest::OnMessageBegin() +{ + return 0; +} + +int +HttpRequest::OnMessageComplete() +{ + m_Connection.HandleRequest(); + + ResetState(); + + return 0; +} + +////////////////////////////////////////////////////////////////////////// + +struct HttpAcceptor +{ + HttpAcceptor(HttpAsioServerImpl& Server, asio::io_service& IoService, uint16_t BasePort) + : m_Server(Server) + , m_IoService(IoService) + , m_Acceptor(m_IoService, asio::ip::tcp::v6()) + { + m_Acceptor.set_option(asio::ip::v6_only(false)); +#if ZEN_PLATFORM_WINDOWS + // Special option for Windows settings as !asio::socket_base::reuse_address is not the same as exclusive access on Windows platforms + typedef asio::detail::socket_option::boolean<ASIO_OS_DEF(SOL_SOCKET), SO_EXCLUSIVEADDRUSE> excluse_address; + m_Acceptor.set_option(excluse_address(true)); +#else // ZEN_PLATFORM_WINDOWS + m_Acceptor.set_option(asio::socket_base::reuse_address(false)); +#endif // ZEN_PLATFORM_WINDOWS + + m_Acceptor.set_option(asio::ip::tcp::no_delay(true)); + m_Acceptor.set_option(asio::socket_base::receive_buffer_size(128 * 1024)); + m_Acceptor.set_option(asio::socket_base::send_buffer_size(256 * 1024)); + + uint16_t EffectivePort = BasePort; + + asio::error_code BindErrorCode; + m_Acceptor.bind(asio::ip::tcp::endpoint(asio::ip::address_v6::any(), EffectivePort), BindErrorCode); + // Sharing violation implies the port is being used by another process + for (uint16_t PortOffset = 1; (BindErrorCode == asio::error::address_in_use) && (PortOffset < 10); ++PortOffset) + { + EffectivePort = BasePort + (PortOffset * 100); + m_Acceptor.bind(asio::ip::tcp::endpoint(asio::ip::address_v6::any(), EffectivePort), BindErrorCode); + } + if (BindErrorCode == asio::error::access_denied) + { + EffectivePort = 0; + m_Acceptor.bind(asio::ip::tcp::endpoint(asio::ip::address_v6::any(), EffectivePort), BindErrorCode); + } + if (BindErrorCode) + { + ZEN_ERROR("Unable open asio service, error '{}'", BindErrorCode.message()); + } + +#if ZEN_PLATFORM_WINDOWS + // On Windows, loopback connections can take advantage of a faster code path optionally with this flag. + // This must be used by both the client and server side, and is only effective in the absence of + // Windows Filtering Platform (WFP) callouts which can be installed by security software. + // https://docs.microsoft.com/en-us/windows/win32/winsock/sio-loopback-fast-path + SOCKET NativeSocket = m_Acceptor.native_handle(); + int LoopbackOptionValue = 1; + DWORD OptionNumberOfBytesReturned = 0; + WSAIoctl(NativeSocket, + SIO_LOOPBACK_FAST_PATH, + &LoopbackOptionValue, + sizeof(LoopbackOptionValue), + NULL, + 0, + &OptionNumberOfBytesReturned, + 0, + 0); +#endif + m_Acceptor.listen(); + + ZEN_INFO("Started asio server at port '{}'", EffectivePort); + } + + void Start() + { + m_Acceptor.listen(); + InitAccept(); + } + + void Stop() { m_IsStopped = true; } + + void InitAccept() + { + auto SocketPtr = std::make_unique<asio::ip::tcp::socket>(m_IoService); + asio::ip::tcp::socket& SocketRef = *SocketPtr.get(); + + m_Acceptor.async_accept(SocketRef, [this, Socket = std::move(SocketPtr)](const asio::error_code& Ec) mutable { + if (Ec) + { + ZEN_WARN("asio async_accept, connection failed to '{}:{}' reason '{}'", + m_Acceptor.local_endpoint().address().to_string(), + m_Acceptor.local_endpoint().port(), + Ec.message()); + } + else + { + // New connection established, pass socket ownership into connection object + // and initiate request handling loop. The connection lifetime is + // managed by the async read/write loop by passing the shared + // reference to the callbacks. + + Socket->set_option(asio::ip::tcp::no_delay(true)); + Socket->set_option(asio::socket_base::receive_buffer_size(128 * 1024)); + Socket->set_option(asio::socket_base::send_buffer_size(256 * 1024)); + + auto Conn = std::make_shared<HttpServerConnection>(m_Server, std::move(Socket)); + Conn->HandleNewRequest(); + } + + if (!m_IsStopped.load()) + { + InitAccept(); + } + else + { + m_Acceptor.close(); + } + }); + } + + int GetAcceptPort() { return m_Acceptor.local_endpoint().port(); } + +private: + HttpAsioServerImpl& m_Server; + asio::io_service& m_IoService; + asio::ip::tcp::acceptor m_Acceptor; + std::atomic<bool> m_IsStopped{false}; +}; + +////////////////////////////////////////////////////////////////////////// + +HttpAsioServerRequest::HttpAsioServerRequest(asio_http::HttpRequest& Request, HttpService& Service, IoBuffer PayloadBuffer) +: m_Request(Request) +, m_PayloadBuffer(std::move(PayloadBuffer)) +{ + const int PrefixLength = Service.UriPrefixLength(); + + std::string_view Uri = Request.Url(); + Uri.remove_prefix(std::min(PrefixLength, static_cast<int>(Uri.size()))); + m_Uri = Uri; + m_UriWithExtension = Uri; + m_QueryString = Request.QueryString(); + + m_Verb = Request.RequestVerb(); + m_ContentLength = Request.Body().Size(); + m_ContentType = Request.ContentType(); + + HttpContentType AcceptContentType = HttpContentType::kUnknownContentType; + + // Parse any extension, to allow requesting a particular response encoding via the URL + + { + std::string_view UriSuffix8{m_Uri}; + + const size_t LastComponentIndex = UriSuffix8.find_last_of('/'); + + if (LastComponentIndex != std::string_view::npos) + { + UriSuffix8.remove_prefix(LastComponentIndex); + } + + const size_t LastDotIndex = UriSuffix8.find_last_of('.'); + + if (LastDotIndex != std::string_view::npos) + { + UriSuffix8.remove_prefix(LastDotIndex + 1); + + AcceptContentType = ParseContentType(UriSuffix8); + + if (AcceptContentType != HttpContentType::kUnknownContentType) + { + m_Uri.remove_suffix(uint32_t(UriSuffix8.size() + 1)); + } + } + } + + // It an explicit content type extension was specified then we'll use that over any + // Accept: header value that may be present + + if (AcceptContentType != HttpContentType::kUnknownContentType) + { + m_AcceptType = AcceptContentType; + } + else + { + m_AcceptType = Request.AcceptType(); + } +} + +HttpAsioServerRequest::~HttpAsioServerRequest() +{ +} + +Oid +HttpAsioServerRequest::ParseSessionId() const +{ + return m_Request.SessionId(); +} + +uint32_t +HttpAsioServerRequest::ParseRequestId() const +{ + return m_Request.RequestId(); +} + +IoBuffer +HttpAsioServerRequest::ReadPayload() +{ + return m_PayloadBuffer; +} + +void +HttpAsioServerRequest::WriteResponse(HttpResponseCode ResponseCode) +{ + ZEN_ASSERT(!m_Response); + + m_Response.reset(new HttpResponse(HttpContentType::kBinary)); + std::array<IoBuffer, 0> Empty; + + m_Response->InitializeForPayload((uint16_t)ResponseCode, Empty); +} + +void +HttpAsioServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span<IoBuffer> Blobs) +{ + ZEN_ASSERT(!m_Response); + + m_Response.reset(new HttpResponse(ContentType)); + m_Response->InitializeForPayload((uint16_t)ResponseCode, Blobs); +} + +void +HttpAsioServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::u8string_view ResponseString) +{ + ZEN_ASSERT(!m_Response); + m_Response.reset(new HttpResponse(ContentType)); + + IoBuffer MessageBuffer(IoBuffer::Wrap, ResponseString.data(), ResponseString.size()); + std::array<IoBuffer, 1> SingleBufferList({MessageBuffer}); + + m_Response->InitializeForPayload((uint16_t)ResponseCode, SingleBufferList); +} + +void +HttpAsioServerRequest::WriteResponseAsync(std::function<void(HttpServerRequest&)>&& ContinuationHandler) +{ + ZEN_ASSERT(!m_Response); + + // Not one bit async, innit + ContinuationHandler(*this); +} + +bool +HttpAsioServerRequest::TryGetRanges(HttpRanges& Ranges) +{ + return TryParseHttpRangeHeader(m_Request.RangeHeader(), Ranges); +} + +////////////////////////////////////////////////////////////////////////// + +HttpAsioServerImpl::HttpAsioServerImpl() +{ +} + +HttpAsioServerImpl::~HttpAsioServerImpl() +{ +} + +int +HttpAsioServerImpl::Start(uint16_t Port, int ThreadCount) +{ + ZEN_ASSERT(ThreadCount > 0); + + ZEN_INFO("starting asio http with {} service threads", ThreadCount); + + m_Acceptor.reset(new asio_http::HttpAcceptor(*this, m_IoService, Port)); + m_Acceptor->Start(); + + for (int i = 0; i < ThreadCount; ++i) + { + m_ThreadPool.emplace_back([this, Index = i + 1] { + SetCurrentThreadName(fmt::format("asio worker {}", Index)); + + try + { + m_IoService.run(); + } + catch (std::exception& e) + { + ZEN_ERROR("Exception caught in asio event loop: '{}'", e.what()); + } + }); + } + + ZEN_INFO("asio http started (port {})", m_Acceptor->GetAcceptPort()); + + return m_Acceptor->GetAcceptPort(); +} + +void +HttpAsioServerImpl::Stop() +{ + m_Acceptor->Stop(); + m_IoService.stop(); + for (auto& Thread : m_ThreadPool) + { + Thread.join(); + } +} + +void +HttpAsioServerImpl::RegisterService(const char* InUrlPath, HttpService& Service) +{ + std::string_view UrlPath(InUrlPath); + Service.SetUriPrefixLength(UrlPath.size()); + if (!UrlPath.empty() && UrlPath.back() == '/') + { + UrlPath.remove_suffix(1); + } + + RwLock::ExclusiveLockScope _(m_Lock); + m_UriHandlers.push_back({std::string(UrlPath), &Service}); +} + +HttpService* +HttpAsioServerImpl::RouteRequest(std::string_view Url) +{ + RwLock::SharedLockScope _(m_Lock); + + HttpService* CandidateService = nullptr; + std::string::size_type CandidateMatchSize = 0; + for (const ServiceEntry& SvcEntry : m_UriHandlers) + { + const std::string& SvcUrl = SvcEntry.ServiceUrlPath; + const std::string::size_type SvcUrlSize = SvcUrl.size(); + if ((SvcUrlSize >= CandidateMatchSize) && Url.compare(0, SvcUrlSize, SvcUrl) == 0 && + ((SvcUrlSize == Url.size()) || (Url[SvcUrlSize] == '/'))) + { + CandidateMatchSize = SvcUrl.size(); + CandidateService = SvcEntry.Service; + } + } + + return CandidateService; +} + +} // namespace zen::asio_http + +////////////////////////////////////////////////////////////////////////// + +namespace zen { +HttpAsioServer::HttpAsioServer() : m_Impl(std::make_unique<asio_http::HttpAsioServerImpl>()) +{ + ZEN_DEBUG("Request object size: {} ({:#x})", sizeof(asio_http::HttpRequest), sizeof(asio_http::HttpRequest)); +} + +HttpAsioServer::~HttpAsioServer() +{ + try + { + m_Impl->Stop(); + } + catch (std::exception& ex) + { + ZEN_WARN("Caught exception stopping http asio server: {}", ex.what()); + } +} + +void +HttpAsioServer::RegisterService(HttpService& Service) +{ + m_Impl->RegisterService(Service.BaseUri(), Service); +} + +int +HttpAsioServer::Initialize(int BasePort) +{ + m_BasePort = m_Impl->Start(gsl::narrow<uint16_t>(BasePort), Max(std::thread::hardware_concurrency(), 8u)); + return m_BasePort; +} + +void +HttpAsioServer::Run(bool IsInteractive) +{ + const bool TestMode = !IsInteractive; + + int WaitTimeout = -1; + if (!TestMode) + { + WaitTimeout = 1000; + } + +#if ZEN_PLATFORM_WINDOWS + if (TestMode == false) + { + zen::logging::ConsoleLog().info("Zen Server running (asio HTTP). Press ESC or Q to quit"); + } + + do + { + if (!TestMode && _kbhit() != 0) + { + char c = (char)_getch(); + + if (c == 27 || c == 'Q' || c == 'q') + { + RequestApplicationExit(0); + } + } + + m_ShutdownEvent.Wait(WaitTimeout); + } while (!IsApplicationExitRequested()); +#else + if (TestMode == false) + { + zen::logging::ConsoleLog().info("Zen Server running (asio HTTP). Ctrl-C to quit"); + } + + do + { + m_ShutdownEvent.Wait(WaitTimeout); + } while (!IsApplicationExitRequested()); +#endif +} + +void +HttpAsioServer::RequestExit() +{ + m_ShutdownEvent.Set(); +} + +} // namespace zen diff --git a/src/zenhttp/httpasio.h b/src/zenhttp/httpasio.h new file mode 100644 index 000000000..716145955 --- /dev/null +++ b/src/zenhttp/httpasio.h @@ -0,0 +1,36 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/thread.h> +#include <zenhttp/httpserver.h> + +#include <memory> + +namespace zen { + +namespace asio_http { + struct HttpServerConnection; + struct HttpAcceptor; + struct HttpAsioServerImpl; +} // namespace asio_http + +class HttpAsioServer : public HttpServer +{ +public: + HttpAsioServer(); + ~HttpAsioServer(); + + virtual void RegisterService(HttpService& Service) override; + virtual int Initialize(int BasePort) override; + virtual void Run(bool IsInteractiveSession) override; + virtual void RequestExit() override; + +private: + Event m_ShutdownEvent; + int m_BasePort = 0; + + std::unique_ptr<asio_http::HttpAsioServerImpl> m_Impl; +}; + +} // namespace zen diff --git a/src/zenhttp/httpclient.cpp b/src/zenhttp/httpclient.cpp new file mode 100644 index 000000000..e6813d407 --- /dev/null +++ b/src/zenhttp/httpclient.cpp @@ -0,0 +1,176 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zenhttp/httpclient.h> +#include <zenhttp/httpserver.h> + +#include <zencore/compactbinarybuilder.h> +#include <zencore/compactbinarypackage.h> +#include <zencore/iobuffer.h> +#include <zencore/logging.h> +#include <zencore/session.h> +#include <zencore/sharedbuffer.h> +#include <zencore/stream.h> +#include <zencore/testing.h> +#include <zenhttp/httpshared.h> + +static std::atomic<uint32_t> HttpClientRequestIdCounter{0}; + +namespace zen { + +using namespace std::literals; + +HttpClient::Response +FromCprResponse(cpr::Response& InResponse) +{ + return {.StatusCode = int(InResponse.status_code)}; +} + +////////////////////////////////////////////////////////////////////////// + +HttpClient::HttpClient(std::string_view BaseUri) : m_BaseUri(BaseUri) +{ + StringBuilder<32> SessionId; + GetSessionId().ToString(SessionId); + m_SessionId = SessionId; +} + +HttpClient::~HttpClient() +{ +} + +HttpClient::Response +HttpClient::TransactPackage(std::string_view Url, CbPackage Package) +{ + cpr::Session Sess; + Sess.SetUrl(m_BaseUri + std::string(Url)); + + // First, list of offered chunks for filtering on the server end + + std::vector<IoHash> AttachmentsToSend; + std::span<const CbAttachment> Attachments = Package.GetAttachments(); + + const uint32_t RequestId = ++HttpClientRequestIdCounter; + auto RequestIdString = fmt::to_string(RequestId); + + if (Attachments.empty() == false) + { + CbObjectWriter Writer; + Writer.BeginArray("offer"); + + for (const CbAttachment& Attachment : Attachments) + { + IoHash Hash = Attachment.GetHash(); + + Writer.AddHash(Hash); + } + + Writer.EndArray(); + + BinaryWriter MemWriter; + Writer.Save(MemWriter); + + Sess.SetHeader({{"Content-Type", "application/x-ue-offer"}, {"UE-Session", m_SessionId}, {"UE-Request", RequestIdString}}); + Sess.SetBody(cpr::Body{(const char*)MemWriter.Data(), MemWriter.Size()}); + + cpr::Response FilterResponse = Sess.Post(); + + if (FilterResponse.status_code == 200) + { + IoBuffer ResponseBuffer(IoBuffer::Wrap, FilterResponse.text.data(), FilterResponse.text.size()); + CbObject ResponseObject = LoadCompactBinaryObject(ResponseBuffer); + + for (auto& Entry : ResponseObject["need"]) + { + ZEN_ASSERT(Entry.IsHash()); + AttachmentsToSend.push_back(Entry.AsHash()); + } + } + } + + // Prepare package for send + + CbPackage SendPackage; + SendPackage.SetObject(Package.GetObject(), Package.GetObjectHash()); + + for (const IoHash& AttachmentCid : AttachmentsToSend) + { + const CbAttachment* Attachment = Package.FindAttachment(AttachmentCid); + + if (Attachment) + { + SendPackage.AddAttachment(*Attachment); + } + else + { + // This should be an error -- server asked to have something we can't find + } + } + + // Transmit package payload + + CompositeBuffer Message = FormatPackageMessageBuffer(SendPackage); + SharedBuffer FlatMessage = Message.Flatten(); + + Sess.SetHeader({{"Content-Type", "application/x-ue-cbpkg"}, {"UE-Session", m_SessionId}, {"UE-Request", RequestIdString}}); + Sess.SetBody(cpr::Body{(const char*)FlatMessage.GetData(), FlatMessage.GetSize()}); + + cpr::Response FilterResponse = Sess.Post(); + + if (!IsHttpSuccessCode(FilterResponse.status_code)) + { + return FromCprResponse(FilterResponse); + } + + IoBuffer ResponseBuffer(IoBuffer::Clone, FilterResponse.text.data(), FilterResponse.text.size()); + + if (auto It = FilterResponse.header.find("Content-Type"); It != FilterResponse.header.end()) + { + HttpContentType ContentType = ParseContentType(It->second); + + ResponseBuffer.SetContentType(ContentType); + } + + return {.StatusCode = int(FilterResponse.status_code), .ResponsePayload = ResponseBuffer}; +} + +HttpClient::Response +HttpClient::Put(std::string_view Url, IoBuffer Payload) +{ + ZEN_UNUSED(Url); + ZEN_UNUSED(Payload); + return {}; +} + +HttpClient::Response +HttpClient::Get(std::string_view Url) +{ + ZEN_UNUSED(Url); + return {}; +} + +HttpClient::Response +HttpClient::Delete(std::string_view Url) +{ + ZEN_UNUSED(Url); + return {}; +} + +////////////////////////////////////////////////////////////////////////// + +#if ZEN_WITH_TESTS + +TEST_CASE("httpclient") +{ + using namespace std::literals; + + SUBCASE("client") {} +} + +void +httpclient_forcelink() +{ +} + +#endif + +} // namespace zen diff --git a/src/zenhttp/httpnull.cpp b/src/zenhttp/httpnull.cpp new file mode 100644 index 000000000..a6e1d3567 --- /dev/null +++ b/src/zenhttp/httpnull.cpp @@ -0,0 +1,83 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "httpnull.h" + +#include <zencore/logging.h> + +#if ZEN_PLATFORM_WINDOWS +# include <conio.h> +#endif + +namespace zen { + +HttpNullServer::HttpNullServer() +{ +} + +HttpNullServer::~HttpNullServer() +{ +} + +void +HttpNullServer::RegisterService(HttpService& Service) +{ + ZEN_UNUSED(Service); +} + +int +HttpNullServer::Initialize(int BasePort) +{ + return BasePort; +} + +void +HttpNullServer::Run(bool IsInteractiveSession) +{ + const bool TestMode = !IsInteractiveSession; + + int WaitTimeout = -1; + if (!TestMode) + { + WaitTimeout = 1000; + } + +#if ZEN_PLATFORM_WINDOWS + if (TestMode == false) + { + zen::logging::ConsoleLog().info("Zen Server running (null HTTP). Press ESC or Q to quit"); + } + + do + { + if (!TestMode && _kbhit() != 0) + { + char c = (char)_getch(); + + if (c == 27 || c == 'Q' || c == 'q') + { + RequestApplicationExit(0); + } + } + + m_ShutdownEvent.Wait(WaitTimeout); + } while (!IsApplicationExitRequested()); +#else + if (TestMode == false) + { + zen::logging::ConsoleLog().info("Zen Server running (null HTTP). Ctrl-C to quit"); + } + + do + { + m_ShutdownEvent.Wait(WaitTimeout); + } while (!IsApplicationExitRequested()); +#endif +} + +void +HttpNullServer::RequestExit() +{ + m_ShutdownEvent.Set(); +} + +} // namespace zen diff --git a/src/zenhttp/httpnull.h b/src/zenhttp/httpnull.h new file mode 100644 index 000000000..74f021f6b --- /dev/null +++ b/src/zenhttp/httpnull.h @@ -0,0 +1,29 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/thread.h> +#include <zenhttp/httpserver.h> + +namespace zen { + +/** + * @brief Null implementation of "http" server. Does nothing + */ + +class HttpNullServer : public HttpServer +{ +public: + HttpNullServer(); + ~HttpNullServer(); + + virtual void RegisterService(HttpService& Service) override; + virtual int Initialize(int BasePort) override; + virtual void Run(bool IsInteractiveSession) override; + virtual void RequestExit() override; + +private: + Event m_ShutdownEvent; +}; + +} // namespace zen diff --git a/src/zenhttp/httpserver.cpp b/src/zenhttp/httpserver.cpp new file mode 100644 index 000000000..671cbd319 --- /dev/null +++ b/src/zenhttp/httpserver.cpp @@ -0,0 +1,885 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zenhttp/httpserver.h> + +#include "httpasio.h" +#include "httpnull.h" +#include "httpsys.h" + +#include <zencore/compactbinary.h> +#include <zencore/compactbinarybuilder.h> +#include <zencore/compactbinarypackage.h> +#include <zencore/iobuffer.h> +#include <zencore/logging.h> +#include <zencore/refcount.h> +#include <zencore/stream.h> +#include <zencore/string.h> +#include <zencore/testing.h> +#include <zencore/thread.h> +#include <zenhttp/httpshared.h> + +#include <charconv> +#include <mutex> +#include <span> +#include <string_view> + +namespace zen { + +using namespace std::literals; + +std::string_view +MapContentTypeToString(HttpContentType ContentType) +{ + switch (ContentType) + { + default: + case HttpContentType::kUnknownContentType: + case HttpContentType::kBinary: + return "application/octet-stream"sv; + + case HttpContentType::kText: + return "text/plain"sv; + + case HttpContentType::kJSON: + return "application/json"sv; + + case HttpContentType::kCbObject: + return "application/x-ue-cb"sv; + + case HttpContentType::kCbPackage: + return "application/x-ue-cbpkg"sv; + + case HttpContentType::kCbPackageOffer: + return "application/x-ue-offer"sv; + + case HttpContentType::kCompressedBinary: + return "application/x-ue-comp"sv; + + case HttpContentType::kYAML: + return "text/yaml"sv; + + case HttpContentType::kHTML: + return "text/html"sv; + + case HttpContentType::kJavaScript: + return "application/javascript"sv; + + case HttpContentType::kCSS: + return "text/css"sv; + + case HttpContentType::kPNG: + return "image/png"sv; + + case HttpContentType::kIcon: + return "image/x-icon"sv; + } +} + +////////////////////////////////////////////////////////////////////////// +// +// Note that in addition to MIME types we accept abbreviated versions, for +// use in suffix parsing as well as for convenience when using curl + +static constinit uint32_t HashBinary = HashStringDjb2("application/octet-stream"sv); +static constinit uint32_t HashJson = HashStringDjb2("json"sv); +static constinit uint32_t HashApplicationJson = HashStringDjb2("application/json"sv); +static constinit uint32_t HashYaml = HashStringDjb2("yaml"sv); +static constinit uint32_t HashTextYaml = HashStringDjb2("text/yaml"sv); +static constinit uint32_t HashText = HashStringDjb2("text/plain"sv); +static constinit uint32_t HashApplicationCompactBinary = HashStringDjb2("application/x-ue-cb"sv); +static constinit uint32_t HashCompactBinary = HashStringDjb2("ucb"sv); +static constinit uint32_t HashCompactBinaryPackage = HashStringDjb2("application/x-ue-cbpkg"sv); +static constinit uint32_t HashCompactBinaryPackageShort = HashStringDjb2("cbpkg"sv); +static constinit uint32_t HashCompactBinaryPackageOffer = HashStringDjb2("application/x-ue-offer"sv); +static constinit uint32_t HashCompressedBinary = HashStringDjb2("application/x-ue-comp"sv); +static constinit uint32_t HashHtml = HashStringDjb2("html"sv); +static constinit uint32_t HashTextHtml = HashStringDjb2("text/html"sv); +static constinit uint32_t HashJavaScript = HashStringDjb2("js"sv); +static constinit uint32_t HashApplicationJavaScript = HashStringDjb2("application/javascript"sv); +static constinit uint32_t HashCss = HashStringDjb2("css"sv); +static constinit uint32_t HashTextCss = HashStringDjb2("text/css"sv); +static constinit uint32_t HashPng = HashStringDjb2("png"sv); +static constinit uint32_t HashImagePng = HashStringDjb2("image/png"sv); +static constinit uint32_t HashIcon = HashStringDjb2("ico"sv); +static constinit uint32_t HashImageIcon = HashStringDjb2("image/x-icon"sv); + +std::once_flag InitContentTypeLookup; + +struct HashedTypeEntry +{ + uint32_t Hash; + HttpContentType Type; +} TypeHashTable[] = { + // clang-format off + {HashBinary, HttpContentType::kBinary}, + {HashApplicationCompactBinary, HttpContentType::kCbObject}, + {HashCompactBinary, HttpContentType::kCbObject}, + {HashCompactBinaryPackage, HttpContentType::kCbPackage}, + {HashCompactBinaryPackageShort, HttpContentType::kCbPackage}, + {HashCompactBinaryPackageOffer, HttpContentType::kCbPackageOffer}, + {HashJson, HttpContentType::kJSON}, + {HashApplicationJson, HttpContentType::kJSON}, + {HashYaml, HttpContentType::kYAML}, + {HashTextYaml, HttpContentType::kYAML}, + {HashText, HttpContentType::kText}, + {HashCompressedBinary, HttpContentType::kCompressedBinary}, + {HashHtml, HttpContentType::kHTML}, + {HashTextHtml, HttpContentType::kHTML}, + {HashJavaScript, HttpContentType::kJavaScript}, + {HashApplicationJavaScript, HttpContentType::kJavaScript}, + {HashCss, HttpContentType::kCSS}, + {HashTextCss, HttpContentType::kCSS}, + {HashPng, HttpContentType::kPNG}, + {HashImagePng, HttpContentType::kPNG}, + {HashIcon, HttpContentType::kIcon}, + {HashImageIcon, HttpContentType::kIcon}, + // clang-format on +}; + +HttpContentType +ParseContentTypeImpl(const std::string_view& ContentTypeString) +{ + if (!ContentTypeString.empty()) + { + const uint32_t CtHash = HashStringDjb2(ContentTypeString); + + if (auto It = std::lower_bound(std::begin(TypeHashTable), + std::end(TypeHashTable), + CtHash, + [](const HashedTypeEntry& Lhs, const uint32_t Rhs) { return Lhs.Hash < Rhs; }); + It != std::end(TypeHashTable)) + { + if (It->Hash == CtHash) + { + return It->Type; + } + } + } + + return HttpContentType::kUnknownContentType; +} + +HttpContentType +ParseContentTypeInit(const std::string_view& ContentTypeString) +{ + std::call_once(InitContentTypeLookup, [] { + std::sort(std::begin(TypeHashTable), std::end(TypeHashTable), [](const HashedTypeEntry& Lhs, const HashedTypeEntry& Rhs) { + return Lhs.Hash < Rhs.Hash; + }); + + // validate that there are no hash collisions + + uint32_t LastHash = 0; + + for (const auto& Item : TypeHashTable) + { + ZEN_ASSERT(LastHash != Item.Hash); + LastHash = Item.Hash; + } + }); + + ParseContentType = ParseContentTypeImpl; + + return ParseContentTypeImpl(ContentTypeString); +} + +HttpContentType (*ParseContentType)(const std::string_view& ContentTypeString) = &ParseContentTypeInit; + +bool +TryParseHttpRangeHeader(std::string_view RangeHeader, HttpRanges& Ranges) +{ + if (RangeHeader.empty()) + { + return false; + } + + const size_t Count = Ranges.size(); + + std::size_t UnitDelim = RangeHeader.find_first_of('='); + if (UnitDelim == std::string_view::npos) + { + return false; + } + + // only bytes for now + std::string_view Unit = RangeHeader.substr(0, UnitDelim); + if (Unit != "bytes"sv) + { + return false; + } + + std::string_view Tokens = RangeHeader.substr(UnitDelim); + while (!Tokens.empty()) + { + // Skip =, + Tokens = Tokens.substr(1); + + size_t Delim = Tokens.find_first_of(','); + if (Delim == std::string_view::npos) + { + Delim = Tokens.length(); + } + + std::string_view Token = Tokens.substr(0, Delim); + Tokens = Tokens.substr(Delim); + + Delim = Token.find_first_of('-'); + if (Delim == std::string_view::npos) + { + return false; + } + + const auto Start = ParseInt<uint32_t>(Token.substr(0, Delim)); + const auto End = ParseInt<uint32_t>(Token.substr(Delim + 1)); + + if (Start.has_value() && End.has_value() && End.value() > Start.value()) + { + Ranges.push_back({.Start = Start.value(), .End = End.value()}); + } + else if (Start) + { + Ranges.push_back({.Start = Start.value()}); + } + else if (End) + { + Ranges.push_back({.End = End.value()}); + } + } + + return Count != Ranges.size(); +} + +////////////////////////////////////////////////////////////////////////// + +const std::string_view +ToString(HttpVerb Verb) +{ + switch (Verb) + { + case HttpVerb::kGet: + return "GET"sv; + case HttpVerb::kPut: + return "PUT"sv; + case HttpVerb::kPost: + return "POST"sv; + case HttpVerb::kDelete: + return "DELETE"sv; + case HttpVerb::kHead: + return "HEAD"sv; + case HttpVerb::kCopy: + return "COPY"sv; + case HttpVerb::kOptions: + return "OPTIONS"sv; + default: + return "???"sv; + } +} + +std::string_view +ReasonStringForHttpResultCode(int HttpCode) +{ + switch (HttpCode) + { + // 1xx Informational + + case 100: + return "Continue"sv; + case 101: + return "Switching Protocols"sv; + + // 2xx Success + + case 200: + return "OK"sv; + case 201: + return "Created"sv; + case 202: + return "Accepted"sv; + case 204: + return "No Content"sv; + case 205: + return "Reset Content"sv; + case 206: + return "Partial Content"sv; + + // 3xx Redirection + + case 300: + return "Multiple Choices"sv; + case 301: + return "Moved Permanently"sv; + case 302: + return "Found"sv; + case 303: + return "See Other"sv; + case 304: + return "Not Modified"sv; + case 305: + return "Use Proxy"sv; + case 306: + return "Switch Proxy"sv; + case 307: + return "Temporary Redirect"sv; + case 308: + return "Permanent Redirect"sv; + + // 4xx Client errors + + case 400: + return "Bad Request"sv; + case 401: + return "Unauthorized"sv; + case 402: + return "Payment Required"sv; + case 403: + return "Forbidden"sv; + case 404: + return "Not Found"sv; + case 405: + return "Method Not Allowed"sv; + case 406: + return "Not Acceptable"sv; + case 407: + return "Proxy Authentication Required"sv; + case 408: + return "Request Timeout"sv; + case 409: + return "Conflict"sv; + case 410: + return "Gone"sv; + case 411: + return "Length Required"sv; + case 412: + return "Precondition Failed"sv; + case 413: + return "Payload Too Large"sv; + case 414: + return "URI Too Long"sv; + case 415: + return "Unsupported Media Type"sv; + case 416: + return "Range Not Satisifiable"sv; + case 417: + return "Expectation Failed"sv; + case 418: + return "I'm a teapot"sv; + case 421: + return "Misdirected Request"sv; + case 422: + return "Unprocessable Entity"sv; + case 423: + return "Locked"sv; + case 424: + return "Failed Dependency"sv; + case 425: + return "Too Early"sv; + case 426: + return "Upgrade Required"sv; + case 428: + return "Precondition Required"sv; + case 429: + return "Too Many Requests"sv; + case 431: + return "Request Header Fields Too Large"sv; + + // 5xx Server errors + + case 500: + return "Internal Server Error"sv; + case 501: + return "Not Implemented"sv; + case 502: + return "Bad Gateway"sv; + case 503: + return "Service Unavailable"sv; + case 504: + return "Gateway Timeout"sv; + case 505: + return "HTTP Version Not Supported"sv; + case 506: + return "Variant Also Negotiates"sv; + case 507: + return "Insufficient Storage"sv; + case 508: + return "Loop Detected"sv; + case 510: + return "Not Extended"sv; + case 511: + return "Network Authentication Required"sv; + + default: + return "Unknown Result"sv; + } +} + +////////////////////////////////////////////////////////////////////////// + +Ref<IHttpPackageHandler> +HttpService::HandlePackageRequest(HttpServerRequest& HttpServiceRequest) +{ + ZEN_UNUSED(HttpServiceRequest); + + return Ref<IHttpPackageHandler>(); +} + +////////////////////////////////////////////////////////////////////////// + +HttpServerRequest::HttpServerRequest() +{ +} + +HttpServerRequest::~HttpServerRequest() +{ +} + +void +HttpServerRequest::WriteResponse(HttpResponseCode ResponseCode, CbPackage Data) +{ + std::vector<IoBuffer> ResponseBuffers = FormatPackageMessage(Data); + return WriteResponse(ResponseCode, HttpContentType::kCbPackage, ResponseBuffers); +} + +void +HttpServerRequest::WriteResponse(HttpResponseCode ResponseCode, CbObject Data) +{ + if (m_AcceptType == HttpContentType::kJSON) + { + ExtendableStringBuilder<1024> Sb; + WriteResponse(ResponseCode, HttpContentType::kJSON, Data.ToJson(Sb).ToView()); + } + else + { + SharedBuffer Buf = Data.GetBuffer(); + std::array<IoBuffer, 1> Buffers{IoBufferBuilder::MakeCloneFromMemory(Buf.GetData(), Buf.GetSize())}; + return WriteResponse(ResponseCode, HttpContentType::kCbObject, Buffers); + } +} + +void +HttpServerRequest::WriteResponse(HttpResponseCode ResponseCode, CbArray Array) +{ + if (m_AcceptType == HttpContentType::kJSON) + { + ExtendableStringBuilder<1024> Sb; + WriteResponse(ResponseCode, HttpContentType::kJSON, Array.ToJson(Sb).ToView()); + } + else + { + SharedBuffer Buf = Array.GetBuffer(); + std::array<IoBuffer, 1> Buffers{IoBufferBuilder::MakeCloneFromMemory(Buf.GetData(), Buf.GetSize())}; + return WriteResponse(ResponseCode, HttpContentType::kCbObject, Buffers); + } +} + +void +HttpServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::string_view ResponseString) +{ + return WriteResponse(ResponseCode, ContentType, std::u8string_view{(char8_t*)ResponseString.data(), ResponseString.size()}); +} + +void +HttpServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, IoBuffer Blob) +{ + std::array<IoBuffer, 1> Buffers{Blob}; + return WriteResponse(ResponseCode, ContentType, Buffers); +} + +void +HttpServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, CompositeBuffer& Payload) +{ + std::span<const SharedBuffer> Segments = Payload.GetSegments(); + + std::vector<IoBuffer> Buffers; + + for (auto& Segment : Segments) + { + Buffers.push_back(Segment.AsIoBuffer()); + } + + WriteResponse(ResponseCode, ContentType, Buffers); +} + +HttpServerRequest::QueryParams +HttpServerRequest::GetQueryParams() +{ + QueryParams Params; + + const std::string_view QStr = QueryString(); + + const char* QueryIt = QStr.data(); + const char* QueryEnd = QueryIt + QStr.size(); + + while (QueryIt != QueryEnd) + { + if (*QueryIt == '&') + { + ++QueryIt; + continue; + } + + size_t QueryLen = ptrdiff_t(QueryEnd - QueryIt); + const std::string_view Query{QueryIt, QueryLen}; + + size_t DelimIndex = Query.find('&', 0); + + if (DelimIndex == std::string_view::npos) + { + DelimIndex = Query.size(); + } + + std::string_view ThisQuery{QueryIt, DelimIndex}; + + size_t EqIndex = ThisQuery.find('=', 0); + + if (EqIndex != std::string_view::npos) + { + std::string_view Param{ThisQuery.data(), EqIndex}; + ThisQuery.remove_prefix(EqIndex + 1); + + Params.KvPairs.emplace_back(Param, ThisQuery); + } + + QueryIt += DelimIndex; + } + + return Params; +} + +Oid +HttpServerRequest::SessionId() const +{ + if (m_Flags & kHaveSessionId) + { + return m_SessionId; + } + + m_SessionId = ParseSessionId(); + m_Flags |= kHaveSessionId; + return m_SessionId; +} + +uint32_t +HttpServerRequest::RequestId() const +{ + if (m_Flags & kHaveRequestId) + { + return m_RequestId; + } + + m_RequestId = ParseRequestId(); + m_Flags |= kHaveRequestId; + return m_RequestId; +} + +CbObject +HttpServerRequest::ReadPayloadObject() +{ + if (IoBuffer Payload = ReadPayload()) + { + return LoadCompactBinaryObject(std::move(Payload)); + } + + return {}; +} + +CbPackage +HttpServerRequest::ReadPayloadPackage() +{ + if (IoBuffer Payload = ReadPayload()) + { + return ParsePackageMessage(std::move(Payload)); + } + + return {}; +} + +////////////////////////////////////////////////////////////////////////// + +void +HttpRequestRouter::AddPattern(const char* Id, const char* Regex) +{ + ZEN_ASSERT(m_PatternMap.find(Id) == m_PatternMap.end()); + + m_PatternMap.insert({Id, Regex}); +} + +void +HttpRequestRouter::RegisterRoute(const char* Regex, HttpRequestRouter::HandlerFunc_t&& HandlerFunc, HttpVerb SupportedVerbs) +{ + ExtendableStringBuilder<128> ExpandedRegex; + ProcessRegexSubstitutions(Regex, ExpandedRegex); + + m_Handlers.emplace_back(ExpandedRegex.c_str(), SupportedVerbs, std::move(HandlerFunc), Regex); +} + +void +HttpRequestRouter::ProcessRegexSubstitutions(const char* Regex, StringBuilderBase& OutExpandedRegex) +{ + size_t RegexLen = strlen(Regex); + + for (size_t i = 0; i < RegexLen;) + { + bool matched = false; + + if (Regex[i] == '{' && ((i == 0) || (Regex[i - 1] != '\\'))) + { + // Might have a pattern reference - find closing brace + + for (size_t j = i + 1; j < RegexLen; ++j) + { + if (Regex[j] == '}') + { + std::string Pattern(&Regex[i + 1], j - i - 1); + + if (auto it = m_PatternMap.find(Pattern); it != m_PatternMap.end()) + { + OutExpandedRegex.Append(it->second.c_str()); + } + else + { + // Default to anything goes (or should this just be an error?) + + OutExpandedRegex.Append("(.+?)"); + } + + // skip ahead + i = j + 1; + + matched = true; + + break; + } + } + } + + if (!matched) + { + OutExpandedRegex.Append(Regex[i++]); + } + } +} + +bool +HttpRequestRouter::HandleRequest(zen::HttpServerRequest& Request) +{ + const HttpVerb Verb = Request.RequestVerb(); + + std::string_view Uri = Request.RelativeUri(); + HttpRouterRequest RouterRequest(Request); + + for (const auto& Handler : m_Handlers) + { + if ((Handler.Verbs & Verb) == Verb && regex_match(begin(Uri), end(Uri), RouterRequest.m_Match, Handler.RegEx)) + { + Handler.Handler(RouterRequest); + + return true; // Route matched + } + } + + return false; // No route matched +} + +////////////////////////////////////////////////////////////////////////// + +HttpRpcHandler::HttpRpcHandler() +{ +} + +HttpRpcHandler::~HttpRpcHandler() +{ +} + +void +HttpRpcHandler::AddRpc(std::string_view RpcId, std::function<void(CbObject& RpcArgs)> HandlerFunction) +{ + ZEN_UNUSED(RpcId, HandlerFunction); +} + +////////////////////////////////////////////////////////////////////////// + +enum class HttpServerClass +{ + kHttpAsio, + kHttpSys, + kHttpNull +}; + +// Implemented in httpsys.cpp +Ref<HttpServer> CreateHttpSysServer(int Concurrency, int BackgroundWorkerThreads); + +Ref<HttpServer> +CreateHttpServer(std::string_view ServerClass) +{ + using namespace std::literals; + + HttpServerClass Class = HttpServerClass::kHttpNull; + +#if ZEN_WITH_HTTPSYS + Class = HttpServerClass::kHttpSys; +#elif 1 + Class = HttpServerClass::kHttpAsio; +#endif + + if (ServerClass == "asio"sv) + { + Class = HttpServerClass::kHttpAsio; + } + else if (ServerClass == "httpsys"sv) + { + Class = HttpServerClass::kHttpSys; + } + else if (ServerClass == "null"sv) + { + Class = HttpServerClass::kHttpNull; + } + + switch (Class) + { + default: + case HttpServerClass::kHttpAsio: + ZEN_INFO("using asio HTTP server implementation"); + return Ref<HttpServer>(new HttpAsioServer()); + +#if ZEN_WITH_HTTPSYS + case HttpServerClass::kHttpSys: + ZEN_INFO("using http.sys server implementation"); + return Ref<HttpServer>(new HttpSysServer(std::thread::hardware_concurrency(), /* background worker threads */ 16)); +#endif + + case HttpServerClass::kHttpNull: + ZEN_INFO("using null HTTP server implementation"); + return Ref<HttpServer>(new HttpNullServer); + } +} + +////////////////////////////////////////////////////////////////////////// + +bool +HandlePackageOffers(HttpService& Service, HttpServerRequest& Request, Ref<IHttpPackageHandler>& PackageHandlerRef) +{ + if (Request.RequestVerb() == HttpVerb::kPost) + { + if (Request.RequestContentType() == HttpContentType::kCbPackageOffer) + { + // The client is presenting us with a package attachments offer, we need + // to filter it down to the list of attachments we need them to send in + // the follow-up request + + PackageHandlerRef = Service.HandlePackageRequest(Request); + + if (PackageHandlerRef) + { + CbObject OfferMessage = LoadCompactBinaryObject(Request.ReadPayload()); + + std::vector<IoHash> OfferCids; + + for (auto& CidEntry : OfferMessage["offer"]) + { + if (!CidEntry.IsHash()) + { + // Should yield bad request response? + + ZEN_WARN("found invalid entry in offer"); + + continue; + } + + OfferCids.push_back(CidEntry.AsHash()); + } + + ZEN_TRACE("request #{} -> filtering offer of {} entries", Request.RequestId(), OfferCids.size()); + + PackageHandlerRef->FilterOffer(OfferCids); + + ZEN_TRACE("request #{} -> filtered to {} entries", Request.RequestId(), OfferCids.size()); + + CbObjectWriter ResponseWriter; + ResponseWriter.BeginArray("need"); + + for (const IoHash& Cid : OfferCids) + { + ResponseWriter.AddHash(Cid); + } + + ResponseWriter.EndArray(); + + // Emit filter response + Request.WriteResponse(HttpResponseCode::OK, ResponseWriter.Save()); + return true; + } + } + else if (Request.RequestContentType() == HttpContentType::kCbPackage) + { + // Process chunks in package request + + PackageHandlerRef = Service.HandlePackageRequest(Request); + + // TODO: this should really be done in a streaming fashion, currently this emulates + // the intended flow from an API perspective + + if (PackageHandlerRef) + { + PackageHandlerRef->OnRequestBegin(); + + auto CreateBuffer = [&](const IoHash& Cid, uint64_t Size) -> IoBuffer { + return PackageHandlerRef->CreateTarget(Cid, Size); + }; + + CbPackage Package = ParsePackageMessage(Request.ReadPayload(), CreateBuffer); + + PackageHandlerRef->OnRequestComplete(); + } + } + } + return false; +} + +////////////////////////////////////////////////////////////////////////// + +#if ZEN_WITH_TESTS + +TEST_CASE("http.common") +{ + using namespace std::literals; + + SUBCASE("router") + { + HttpRequestRouter r; + r.AddPattern("a", "[[:alpha:]]+"); + r.RegisterRoute( + "{a}", + [&](auto) {}, + HttpVerb::kGet); + + // struct TestHttpServerRequest : public HttpServerRequest + //{ + // TestHttpServerRequest(std::string_view Uri) : m_uri{Uri} {} + //}; + + // TestHttpServerRequest req{}; + // r.HandleRequest(req); + } + + SUBCASE("content-type") + { + for (uint8_t i = 0; i < uint8_t(HttpContentType::kCOUNT); ++i) + { + HttpContentType Ct{i}; + + if (Ct != HttpContentType::kUnknownContentType) + { + CHECK_EQ(Ct, ParseContentType(MapContentTypeToString(Ct))); + } + } + } +} + +void +http_forcelink() +{ +} + +#endif + +} // namespace zen diff --git a/src/zenhttp/httpshared.cpp b/src/zenhttp/httpshared.cpp new file mode 100644 index 000000000..7aade56d2 --- /dev/null +++ b/src/zenhttp/httpshared.cpp @@ -0,0 +1,809 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zenhttp/httpshared.h> + +#include <zencore/compactbinarybuilder.h> +#include <zencore/compactbinarypackage.h> +#include <zencore/compositebuffer.h> +#include <zencore/filesystem.h> +#include <zencore/fmtutils.h> +#include <zencore/iobuffer.h> +#include <zencore/iohash.h> +#include <zencore/logging.h> +#include <zencore/scopeguard.h> +#include <zencore/stream.h> +#include <zencore/testing.h> +#include <zencore/testutils.h> + +#include <span> +#include <vector> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <tsl/robin_map.h> +ZEN_THIRD_PARTY_INCLUDES_END + +namespace zen { + +const std::string_view HandlePrefix(":?#:"); + +std::vector<IoBuffer> +FormatPackageMessage(const CbPackage& Data, int TargetProcessPid) +{ + return FormatPackageMessage(Data, FormatFlags::kDefault, TargetProcessPid); +} +CompositeBuffer +FormatPackageMessageBuffer(const CbPackage& Data, int TargetProcessPid) +{ + return FormatPackageMessageBuffer(Data, FormatFlags::kDefault, TargetProcessPid); +} + +CompositeBuffer +FormatPackageMessageBuffer(const CbPackage& Data, FormatFlags Flags, int TargetProcessPid) +{ + std::vector<IoBuffer> Message = FormatPackageMessage(Data, Flags, TargetProcessPid); + + std::vector<SharedBuffer> Buffers; + + for (IoBuffer& Buf : Message) + { + Buffers.push_back(SharedBuffer(Buf)); + } + + return CompositeBuffer(std::move(Buffers)); +} + +std::vector<IoBuffer> +FormatPackageMessage(const CbPackage& Data, FormatFlags Flags, int TargetProcessPid) +{ + void* TargetProcessHandle = nullptr; +#if ZEN_PLATFORM_WINDOWS + std::vector<HANDLE> DuplicatedHandles; + auto _ = MakeGuard([&DuplicatedHandles, &TargetProcessHandle]() { + if (TargetProcessHandle == nullptr) + { + return; + } + + for (HANDLE DuplicatedHandle : DuplicatedHandles) + { + HANDLE ClosingHandle; + if (::DuplicateHandle((HANDLE)TargetProcessHandle, + DuplicatedHandle, + GetCurrentProcess(), + &ClosingHandle, + 0, + FALSE, + DUPLICATE_CLOSE_SOURCE | DUPLICATE_SAME_ACCESS) == TRUE) + { + ::CloseHandle(ClosingHandle); + } + } + ::CloseHandle((HANDLE)TargetProcessHandle); + TargetProcessHandle = nullptr; + }); + + if (EnumHasAllFlags(Flags, FormatFlags::kAllowLocalReferences) && TargetProcessPid != 0) + { + TargetProcessHandle = OpenProcess(PROCESS_DUP_HANDLE, FALSE, TargetProcessPid); + } +#else + ZEN_UNUSED(TargetProcessPid); + void* DuplicatedHandles = nullptr; +#endif // ZEN_PLATFORM_WINDOWS + + const std::span<const CbAttachment>& Attachments = Data.GetAttachments(); + std::vector<IoBuffer> ResponseBuffers; + + ResponseBuffers.reserve(3 + Attachments.size()); // TODO: may want to use an additional fudge factor here to avoid growing since each + // attachment is likely to consist of several buffers + + // Fixed size header + + CbPackageHeader Hdr{.HeaderMagic = kCbPkgMagic, .AttachmentCount = gsl::narrow<uint32_t>(Attachments.size())}; + + ResponseBuffers.push_back(IoBufferBuilder::MakeCloneFromMemory(&Hdr, sizeof Hdr)); + + // Attachment metadata array + + IoBuffer AttachmentMetadataBuffer = IoBuffer{sizeof(CbAttachmentEntry) * (Attachments.size() + /* root */ 1)}; + CbAttachmentEntry* AttachmentInfo = reinterpret_cast<CbAttachmentEntry*>(AttachmentMetadataBuffer.MutableData()); + + ResponseBuffers.push_back(AttachmentMetadataBuffer); // Attachment metadata + + // Root object + + IoBuffer RootIoBuffer = Data.GetObject().GetBuffer().AsIoBuffer(); + ResponseBuffers.push_back(RootIoBuffer); // Root object + + *AttachmentInfo++ = {.PayloadSize = RootIoBuffer.Size(), .Flags = CbAttachmentEntry::kIsObject, .AttachmentHash = Data.GetObjectHash()}; + + // Attachment payloads + + auto MarshalLocal = [&AttachmentInfo, &ResponseBuffers](const std::string& Path8, + CbAttachmentReferenceHeader& LocalRef, + const IoHash& AttachmentHash, + bool IsCompressed) { + IoBuffer RefBuffer(sizeof(CbAttachmentReferenceHeader) + Path8.size()); + + CbAttachmentReferenceHeader* RefHdr = RefBuffer.MutableData<CbAttachmentReferenceHeader>(); + *RefHdr++ = LocalRef; + memcpy(RefHdr, Path8.data(), Path8.size()); + + *AttachmentInfo++ = {.PayloadSize = RefBuffer.GetSize(), + .Flags = (IsCompressed ? uint32_t(CbAttachmentEntry::kIsCompressed) : 0u) | CbAttachmentEntry::kIsLocalRef, + .AttachmentHash = AttachmentHash}; + + ResponseBuffers.push_back(std::move(RefBuffer)); + }; + + tsl::robin_map<void*, std::string> FileNameMap; + + auto IsLocalRef = [&FileNameMap, &DuplicatedHandles](const CompositeBuffer& AttachmentBinary, + bool DenyPartialLocalReferences, + void* TargetProcessHandle, + CbAttachmentReferenceHeader& LocalRef, + std::string& Path8) -> bool { + const SharedBuffer& Segment = AttachmentBinary.GetSegments().front(); + IoBufferFileReference Ref; + const IoBuffer& SegmentBuffer = Segment.AsIoBuffer(); + + if (!SegmentBuffer.GetFileReference(Ref)) + { + return false; + } + + if (DenyPartialLocalReferences && !SegmentBuffer.IsWholeFile()) + { + return false; + } + + if (auto It = FileNameMap.find(Ref.FileHandle); It != FileNameMap.end()) + { + Path8 = It->second; + } + else + { + bool UseFilePath = true; +#if ZEN_PLATFORM_WINDOWS + if (TargetProcessHandle != nullptr) + { + HANDLE TargetHandle = INVALID_HANDLE_VALUE; + BOOL OK = ::DuplicateHandle(GetCurrentProcess(), + Ref.FileHandle, + (HANDLE)TargetProcessHandle, + &TargetHandle, + FILE_GENERIC_READ, + FALSE, + 0); + if (OK) + { + DuplicatedHandles.push_back(TargetHandle); + Path8 = fmt::format("{}{}", HandlePrefix, reinterpret_cast<uint64_t>(TargetHandle)); + UseFilePath = false; + } + } +#else // ZEN_PLATFORM_WINDOWS + ZEN_UNUSED(TargetProcessHandle); + // Not supported on Linux/Mac. Could potentially use pidfd_getfd() but that requires a fairly new Linux kernel/includes and to + // deal with acceess rights etc. +#endif // ZEN_PLATFORM_WINDOWS + if (UseFilePath) + { + ExtendablePathBuilder<256> LocalRefFile; + LocalRefFile.Append(std::filesystem::absolute(PathFromHandle(Ref.FileHandle))); + Path8 = LocalRefFile.ToUtf8(); + } + FileNameMap.insert_or_assign(Ref.FileHandle, Path8); + } + + LocalRef.AbsolutePathLength = gsl::narrow<uint16_t>(Path8.size()); + LocalRef.PayloadByteOffset = Ref.FileChunkOffset; + LocalRef.PayloadByteSize = Ref.FileChunkSize; + + return true; + }; + + for (const CbAttachment& Attachment : Attachments) + { + if (Attachment.IsNull()) + { + ZEN_NOT_IMPLEMENTED("Null attachments are not supported"); + } + else if (CompressedBuffer AttachmentBuffer = Attachment.AsCompressedBinary()) + { + CompositeBuffer Compressed = AttachmentBuffer.GetCompressed(); + IoHash AttachmentHash = Attachment.GetHash(); + + // If the data is either not backed by a file, or there are multiple + // fragments then we cannot marshal it by local reference. We might + // want/need to extend this in the future to allow multiple chunk + // segments to be marshaled at once + + bool MarshalByLocalRef = EnumHasAllFlags(Flags, FormatFlags::kAllowLocalReferences) && (Compressed.GetSegments().size() == 1); + bool DenyPartialLocalReferences = EnumHasAllFlags(Flags, FormatFlags::kDenyPartialLocalReferences); + CbAttachmentReferenceHeader LocalRef; + std::string Path8; + + if (MarshalByLocalRef) + { + MarshalByLocalRef = IsLocalRef(Compressed, DenyPartialLocalReferences, TargetProcessHandle, LocalRef, Path8); + } + + if (MarshalByLocalRef) + { + const bool IsCompressed = true; + bool IsHandle = false; +#if ZEN_PLATFORM_WINDOWS + IsHandle = Path8.starts_with(HandlePrefix); +#endif + MarshalLocal(Path8, LocalRef, AttachmentHash, IsCompressed); + ZEN_DEBUG("Marshalled '{}' as file {} of {} bytes", Path8, IsHandle ? "handle" : "path", Compressed.GetSize()); + } + else + { + *AttachmentInfo++ = {.PayloadSize = AttachmentBuffer.GetCompressedSize(), + .Flags = CbAttachmentEntry::kIsCompressed, + .AttachmentHash = AttachmentHash}; + + for (const SharedBuffer& Segment : Compressed.GetSegments()) + { + ResponseBuffers.push_back(Segment.AsIoBuffer()); + } + } + } + else if (CbObject AttachmentObject = Attachment.AsObject()) + { + IoBuffer ObjIoBuffer = AttachmentObject.GetBuffer().AsIoBuffer(); + ResponseBuffers.push_back(ObjIoBuffer); + + *AttachmentInfo++ = {.PayloadSize = ObjIoBuffer.Size(), + .Flags = CbAttachmentEntry::kIsObject, + .AttachmentHash = Attachment.GetHash()}; + } + else if (CompositeBuffer AttachmentBinary = Attachment.AsCompositeBinary()) + { + IoHash AttachmentHash = Attachment.GetHash(); + bool MarshalByLocalRef = + EnumHasAllFlags(Flags, FormatFlags::kAllowLocalReferences) && (AttachmentBinary.GetSegments().size() == 1); + bool DenyPartialLocalReferences = EnumHasAllFlags(Flags, FormatFlags::kDenyPartialLocalReferences); + + CbAttachmentReferenceHeader LocalRef; + std::string Path8; + + if (MarshalByLocalRef) + { + MarshalByLocalRef = IsLocalRef(AttachmentBinary, DenyPartialLocalReferences, TargetProcessHandle, LocalRef, Path8); + } + + if (MarshalByLocalRef) + { + const bool IsCompressed = false; + bool IsHandle = false; +#if ZEN_PLATFORM_WINDOWS + IsHandle = Path8.starts_with(HandlePrefix); +#endif + MarshalLocal(Path8, LocalRef, AttachmentHash, IsCompressed); + ZEN_DEBUG("Marshalled '{}' as file {} of {} bytes", Path8, IsHandle ? "handle" : "path", AttachmentBinary.GetSize()); + } + else + { + *AttachmentInfo++ = {.PayloadSize = AttachmentBinary.GetSize(), .Flags = 0, .AttachmentHash = Attachment.GetHash()}; + + for (const SharedBuffer& Segment : AttachmentBinary.GetSegments()) + { + ResponseBuffers.push_back(Segment.AsIoBuffer()); + } + } + } + else + { + ZEN_NOT_IMPLEMENTED("Unknown attachment kind"); + } + } + FileNameMap.clear(); +#if ZEN_PLATFORM_WINDOWS + DuplicatedHandles.clear(); +#endif // ZEN_PLATFORM_WINDOWS + + return ResponseBuffers; +} + +bool +IsPackageMessage(IoBuffer Payload) +{ + if (!Payload) + { + return false; + } + + BinaryReader Reader(Payload); + + CbPackageHeader Hdr; + Reader.Read(&Hdr, sizeof Hdr); + + if (Hdr.HeaderMagic != kCbPkgMagic) + { + return false; + } + + return true; +} + +CbPackage +ParsePackageMessage(IoBuffer Payload, std::function<IoBuffer(const IoHash&, uint64_t)> CreateBuffer) +{ + if (!Payload) + { + return {}; + } + + BinaryReader Reader(Payload); + + CbPackageHeader Hdr; + Reader.Read(&Hdr, sizeof Hdr); + + if (Hdr.HeaderMagic != kCbPkgMagic) + { + throw std::runtime_error("invalid CbPackage header magic"); + } + + const uint32_t ChunkCount = Hdr.AttachmentCount + 1; + + std::unique_ptr<CbAttachmentEntry[]> AttachmentEntries{new CbAttachmentEntry[ChunkCount]}; + + Reader.Read(AttachmentEntries.get(), sizeof(CbAttachmentEntry) * ChunkCount); + + CbPackage Package; + + std::vector<CbAttachment> Attachments; + Attachments.reserve(ChunkCount); // Guessing here... + + tsl::robin_map<std::string, IoBuffer> PartialFileBuffers; + + // TODO: Throwing before this loop completes could result in leaking handles as we might not have picked up all the handles in the + // message + for (uint32_t i = 0; i < ChunkCount; ++i) + { + const CbAttachmentEntry& Entry = AttachmentEntries[i]; + const uint64_t AttachmentSize = Entry.PayloadSize; + + const IoBuffer AttachmentBuffer(Payload, Reader.CurrentOffset(), AttachmentSize); + Reader.Skip(AttachmentSize); + + if (Entry.Flags & CbAttachmentEntry::kIsLocalRef) + { + // Marshal local reference - a "pointer" to the chunk backing file + + ZEN_ASSERT(AttachmentBuffer.Size() >= sizeof(CbAttachmentReferenceHeader)); + + const CbAttachmentReferenceHeader* AttachRefHdr = AttachmentBuffer.Data<CbAttachmentReferenceHeader>(); + const char* PathPointer = reinterpret_cast<const char*>(AttachRefHdr + 1); + + ZEN_ASSERT(AttachmentBuffer.Size() >= (sizeof(CbAttachmentReferenceHeader) + AttachRefHdr->AbsolutePathLength)); + std::string_view PathView(PathPointer, AttachRefHdr->AbsolutePathLength); + + IoBuffer FullFileBuffer; + + std::filesystem::path Path(Utf8ToWide(PathView)); + if (auto It = PartialFileBuffers.find(Path.string()); It != PartialFileBuffers.end()) + { + FullFileBuffer = It->second; + } + else + { + if (PathView.starts_with(HandlePrefix)) + { +#if ZEN_PLATFORM_WINDOWS + std::string_view HandleString(PathView.substr(HandlePrefix.length())); + std::optional<uint64_t> HandleNumber(ParseInt<uint64_t>(HandleString)); + if (HandleNumber.has_value()) + { + HANDLE FileHandle = HANDLE(HandleNumber.value()); + ULARGE_INTEGER liFileSize; + liFileSize.LowPart = ::GetFileSize(FileHandle, &liFileSize.HighPart); + if (liFileSize.LowPart != INVALID_FILE_SIZE) + { + FullFileBuffer = IoBuffer(IoBuffer::File, (void*)FileHandle, 0, uint64_t(liFileSize.QuadPart)); + PartialFileBuffers.insert_or_assign(Path.string(), FullFileBuffer); + } + } +#else // ZEN_PLATFORM_WINDOWS + // Not supported on Linux/Mac. Could potentially use pidfd_getfd() but that requires a fairly new Linux kernel/includes + // and to deal with acceess rights etc. + ZEN_ASSERT(false); +#endif // ZEN_PLATFORM_WINDOWS + } + else + { + FullFileBuffer = PartialFileBuffers.insert_or_assign(Path.string(), IoBufferBuilder::MakeFromFile(Path)).first->second; + } + } + + if (!FullFileBuffer) + { + // Unable to open chunk reference + throw std::runtime_error(fmt::format("unable to resolve chunk #{} at '{}' (offset {}, size {})", + i, + Path, + AttachRefHdr->PayloadByteOffset, + AttachRefHdr->PayloadByteSize)); + } + + IoBuffer ChunkReference = AttachRefHdr->PayloadByteOffset == 0 && AttachRefHdr->PayloadByteSize == FullFileBuffer.GetSize() + ? FullFileBuffer + : IoBuffer(FullFileBuffer, AttachRefHdr->PayloadByteOffset, AttachRefHdr->PayloadByteSize); + + CompressedBuffer CompBuf(CompressedBuffer::FromCompressedNoValidate(std::move(ChunkReference))); + if (!CompBuf) + { + throw std::runtime_error(fmt::format("invalid format for chunk #{} at '{}' (offset {}, size {})", + i, + Path, + AttachRefHdr->PayloadByteOffset, + AttachRefHdr->PayloadByteSize)); + } + Attachments.emplace_back(CbAttachment(std::move(CompBuf), Entry.AttachmentHash)); + } + else if (Entry.Flags & CbAttachmentEntry::kIsCompressed) + { + if (Entry.Flags & CbAttachmentEntry::kIsObject) + { + if (i == 0) + { + CompressedBuffer CompBuf(CompressedBuffer::FromCompressedNoValidate(IoBuffer(AttachmentBuffer))); + if (!CompBuf) + { + throw std::runtime_error(fmt::format("invalid format for chunk #{} expected compressed buffer for CbObject", i)); + } + // First payload is always a compact binary object + Package.SetObject(LoadCompactBinaryObject(std::move(CompBuf))); + } + else + { + ZEN_NOT_IMPLEMENTED("Object attachments are not currently supported"); + } + } + else + { + CompressedBuffer CompBuf(CompressedBuffer::FromCompressedNoValidate(IoBuffer(AttachmentBuffer))); + if (!CompBuf) + { + throw std::runtime_error(fmt::format("invalid format for chunk #{} expected compressed buffer for attachment", i)); + } + Attachments.emplace_back(CbAttachment(std::move(CompBuf), Entry.AttachmentHash)); + } + } + else /* not compressed */ + { + if (Entry.Flags & CbAttachmentEntry::kIsObject) + { + if (i == 0) + { + Package.SetObject(LoadCompactBinaryObject(AttachmentBuffer)); + } + else + { + ZEN_NOT_IMPLEMENTED("Object attachments are not currently supported"); + } + } + else + { + // Make a copy of the buffer so we attachements don't reference the entire payload + IoBuffer AttachmentBufferCopy = CreateBuffer(Entry.AttachmentHash, AttachmentSize); + ZEN_ASSERT(AttachmentBufferCopy); + ZEN_ASSERT(AttachmentBufferCopy.Size() == AttachmentSize); + AttachmentBufferCopy.GetMutableView().CopyFrom(AttachmentBuffer.GetView()); + + CbAttachment Attachment(SharedBuffer{AttachmentBufferCopy}); + Attachments.emplace_back(SharedBuffer{AttachmentBufferCopy}); + } + } + } + PartialFileBuffers.clear(); + + Package.AddAttachments(Attachments); + + return Package; +} + +bool +ParsePackageMessageWithLegacyFallback(const IoBuffer& Response, CbPackage& OutPackage) +{ + if (IsPackageMessage(Response)) + { + OutPackage = ParsePackageMessage(Response); + return true; + } + return OutPackage.TryLoad(Response); +} + +CbPackageReader::CbPackageReader() : m_CreateBuffer([](const IoHash&, uint64_t Size) -> IoBuffer { return IoBuffer{Size}; }) +{ +} + +CbPackageReader::~CbPackageReader() +{ +} + +void +CbPackageReader::SetPayloadBufferCreator(std::function<IoBuffer(const IoHash& Cid, uint64_t Size)> CreateBuffer) +{ + m_CreateBuffer = CreateBuffer; +} + +uint64_t +CbPackageReader::ProcessPackageHeaderData(const void* Data, uint64_t DataBytes) +{ + ZEN_ASSERT(m_CurrentState != State::kReadingBuffers); + + switch (m_CurrentState) + { + case State::kInitialState: + ZEN_ASSERT(Data == nullptr); + m_CurrentState = State::kReadingHeader; + return sizeof m_PackageHeader; + + case State::kReadingHeader: + ZEN_ASSERT(DataBytes == sizeof m_PackageHeader); + memcpy(&m_PackageHeader, Data, sizeof m_PackageHeader); + ZEN_ASSERT(m_PackageHeader.HeaderMagic == kCbPkgMagic); + m_CurrentState = State::kReadingAttachmentEntries; + m_AttachmentEntries.resize(m_PackageHeader.AttachmentCount + 1); + return (m_PackageHeader.AttachmentCount + 1) * sizeof(CbAttachmentEntry); + + case State::kReadingAttachmentEntries: + ZEN_ASSERT(DataBytes == ((m_PackageHeader.AttachmentCount + 1) * sizeof(CbAttachmentEntry))); + memcpy(m_AttachmentEntries.data(), Data, DataBytes); + + for (CbAttachmentEntry& Entry : m_AttachmentEntries) + { + // This preallocates memory for payloads but note that for the local references + // the caller will need to handle the payload differently (i.e it's a + // CbAttachmentReferenceHeader not the actual payload) + + m_PayloadBuffers.push_back(IoBuffer{Entry.PayloadSize}); + } + + m_CurrentState = State::kReadingBuffers; + return 0; + + default: + ZEN_ASSERT(false); + return 0; + } +} + +IoBuffer +CbPackageReader::MarshalLocalChunkReference(IoBuffer AttachmentBuffer) +{ + // Marshal local reference - a "pointer" to the chunk backing file + + ZEN_ASSERT(AttachmentBuffer.Size() >= sizeof(CbAttachmentReferenceHeader)); + + const CbAttachmentReferenceHeader* AttachRefHdr = AttachmentBuffer.Data<CbAttachmentReferenceHeader>(); + const char8_t* PathPointer = reinterpret_cast<const char8_t*>(AttachRefHdr + 1); + + ZEN_ASSERT(AttachmentBuffer.Size() >= (sizeof(CbAttachmentReferenceHeader) + AttachRefHdr->AbsolutePathLength)); + + std::u8string_view PathView{PathPointer, AttachRefHdr->AbsolutePathLength}; + + std::filesystem::path Path{PathView}; + + IoBuffer ChunkReference = IoBufferBuilder::MakeFromFile(Path, AttachRefHdr->PayloadByteOffset, AttachRefHdr->PayloadByteSize); + + if (!ChunkReference) + { + // Unable to open chunk reference + + throw std::runtime_error(fmt::format("unable to resolve local reference to '{}' (offset {}, size {})", + PathToUtf8(Path), + AttachRefHdr->PayloadByteOffset, + AttachRefHdr->PayloadByteSize)); + } + + return ChunkReference; +}; + +void +CbPackageReader::Finalize() +{ + if (m_AttachmentEntries.empty()) + { + return; + } + + m_Attachments.reserve(m_AttachmentEntries.size() - 1); + + int CurrentAttachmentIndex = 0; + for (CbAttachmentEntry& Entry : m_AttachmentEntries) + { + IoBuffer AttachmentBuffer = m_PayloadBuffers[CurrentAttachmentIndex]; + + if (CurrentAttachmentIndex == 0) + { + // Root object + if (Entry.Flags & CbAttachmentEntry::kIsObject) + { + if (Entry.Flags & CbAttachmentEntry::kIsLocalRef) + { + m_RootObject = LoadCompactBinaryObject(MarshalLocalChunkReference(AttachmentBuffer)); + } + else if (Entry.Flags & CbAttachmentEntry::kIsCompressed) + { + IoHash RawHash; + uint64_t RawSize; + CompressedBuffer Compressed = CompressedBuffer::FromCompressed(SharedBuffer(AttachmentBuffer), RawHash, RawSize); + if (RawHash == Entry.AttachmentHash) + { + m_RootObject = LoadCompactBinaryObject(Compressed); + } + } + else + { + m_RootObject = LoadCompactBinaryObject(std::move(AttachmentBuffer)); + } + } + else + { + throw std::runtime_error("missing or invalid root object"); + } + } + else if (Entry.Flags & CbAttachmentEntry::kIsLocalRef) + { + IoBuffer ChunkReference = MarshalLocalChunkReference(AttachmentBuffer); + + if (Entry.Flags & CbAttachmentEntry::kIsCompressed) + { + IoHash RawHash; + uint64_t RawSize; + CompressedBuffer Compressed = CompressedBuffer::FromCompressed(SharedBuffer(ChunkReference), RawHash, RawSize); + if (RawHash == Entry.AttachmentHash) + { + m_Attachments.push_back(CbAttachment(Compressed, Entry.AttachmentHash)); + } + } + else + { + CompressedBuffer Compressed = + CompressedBuffer::Compress(SharedBuffer(ChunkReference), OodleCompressor::NotSet, OodleCompressionLevel::None); + m_Attachments.push_back(CbAttachment(std::move(Compressed), Compressed.DecodeRawHash())); + } + } + + ++CurrentAttachmentIndex; + } +} + +/** + ______________________ _____________________________ + \__ ___/\_ _____// _____/\__ ___/ _____/ + | | | __)_ \_____ \ | | \_____ \ + | | | \/ \ | | / \ + |____| /_______ /_______ / |____| /_______ / + \/ \/ \/ + */ + +#if ZEN_WITH_TESTS + +TEST_CASE("CbPackage.Serialization") +{ + // Make a test package + + CbAttachment Attach1{SharedBuffer::MakeView(MakeMemoryView("abcd"))}; + CbAttachment Attach2{SharedBuffer::MakeView(MakeMemoryView("efgh"))}; + + CbObjectWriter Cbo; + Cbo.AddAttachment("abcd", Attach1); + Cbo.AddAttachment("efgh", Attach2); + + CbPackage Pkg; + Pkg.AddAttachment(Attach1); + Pkg.AddAttachment(Attach2); + Pkg.SetObject(Cbo.Save()); + + SharedBuffer Buffer = FormatPackageMessageBuffer(Pkg).Flatten(); + const uint8_t* CursorPtr = reinterpret_cast<const uint8_t*>(Buffer.GetData()); + uint64_t RemainingBytes = Buffer.GetSize(); + + auto ConsumeBytes = [&](uint64_t ByteCount) { + ZEN_ASSERT(ByteCount <= RemainingBytes); + void* ReturnPtr = (void*)CursorPtr; + CursorPtr += ByteCount; + RemainingBytes -= ByteCount; + return ReturnPtr; + }; + + auto CopyBytes = [&](void* TargetBuffer, uint64_t ByteCount) { + ZEN_ASSERT(ByteCount <= RemainingBytes); + memcpy(TargetBuffer, CursorPtr, ByteCount); + CursorPtr += ByteCount; + RemainingBytes -= ByteCount; + }; + + CbPackageReader Reader; + uint64_t InitialRead = Reader.ProcessPackageHeaderData(nullptr, 0); + uint64_t NextBytes = Reader.ProcessPackageHeaderData(ConsumeBytes(InitialRead), InitialRead); + NextBytes = Reader.ProcessPackageHeaderData(ConsumeBytes(NextBytes), NextBytes); + auto Buffers = Reader.GetPayloadBuffers(); + + for (auto& PayloadBuffer : Buffers) + { + CopyBytes(PayloadBuffer.MutableData(), PayloadBuffer.GetSize()); + } + + Reader.Finalize(); +} + +TEST_CASE("CbPackage.LocalRef") +{ + ScopedTemporaryDirectory TempDir; + + auto Path1 = TempDir.Path() / "abcd"; + auto Path2 = TempDir.Path() / "efgh"; + + { + IoBuffer Buffer1 = IoBufferBuilder::MakeCloneFromMemory(MakeMemoryView("abcd")); + IoBuffer Buffer2 = IoBufferBuilder::MakeCloneFromMemory(MakeMemoryView("efgh")); + + WriteFile(Path1, Buffer1); + WriteFile(Path2, Buffer2); + } + + // Make a test package + + IoBuffer FileBuffer1 = IoBufferBuilder::MakeFromFile(Path1); + IoBuffer FileBuffer2 = IoBufferBuilder::MakeFromFile(Path2); + + CbAttachment Attach1{SharedBuffer(FileBuffer1)}; + CbAttachment Attach2{SharedBuffer(FileBuffer2)}; + + CbObjectWriter Cbo; + Cbo.AddAttachment("abcd", Attach1); + Cbo.AddAttachment("efgh", Attach2); + + CbPackage Pkg; + Pkg.AddAttachment(Attach1); + Pkg.AddAttachment(Attach2); + Pkg.SetObject(Cbo.Save()); + + SharedBuffer Buffer = FormatPackageMessageBuffer(Pkg, FormatFlags::kAllowLocalReferences).Flatten(); + const uint8_t* CursorPtr = reinterpret_cast<const uint8_t*>(Buffer.GetData()); + uint64_t RemainingBytes = Buffer.GetSize(); + + auto ConsumeBytes = [&](uint64_t ByteCount) { + ZEN_ASSERT(ByteCount <= RemainingBytes); + void* ReturnPtr = (void*)CursorPtr; + CursorPtr += ByteCount; + RemainingBytes -= ByteCount; + return ReturnPtr; + }; + + auto CopyBytes = [&](void* TargetBuffer, uint64_t ByteCount) { + ZEN_ASSERT(ByteCount <= RemainingBytes); + memcpy(TargetBuffer, CursorPtr, ByteCount); + CursorPtr += ByteCount; + RemainingBytes -= ByteCount; + }; + + CbPackageReader Reader; + uint64_t InitialRead = Reader.ProcessPackageHeaderData(nullptr, 0); + uint64_t NextBytes = Reader.ProcessPackageHeaderData(ConsumeBytes(InitialRead), InitialRead); + NextBytes = Reader.ProcessPackageHeaderData(ConsumeBytes(NextBytes), NextBytes); + auto Buffers = Reader.GetPayloadBuffers(); + + for (auto& PayloadBuffer : Buffers) + { + CopyBytes(PayloadBuffer.MutableData(), PayloadBuffer.GetSize()); + } + + Reader.Finalize(); +} + +void +forcelink_httpshared() +{ +} + +#endif + +} // namespace zen diff --git a/src/zenhttp/httpsys.cpp b/src/zenhttp/httpsys.cpp new file mode 100644 index 000000000..c733d618d --- /dev/null +++ b/src/zenhttp/httpsys.cpp @@ -0,0 +1,1674 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "httpsys.h" + +#include <zencore/compactbinary.h> +#include <zencore/compactbinarybuilder.h> +#include <zencore/compactbinarypackage.h> +#include <zencore/except.h> +#include <zencore/logging.h> +#include <zencore/scopeguard.h> +#include <zencore/string.h> +#include <zencore/timer.h> +#include <zenhttp/httpshared.h> + +#if ZEN_WITH_HTTPSYS + +# include <conio.h> +# include <mstcpip.h> +# pragma comment(lib, "httpapi.lib") + +std::wstring +UTF8_to_UTF16(const char* InPtr) +{ + std::wstring OutString; + unsigned int Codepoint; + + while (*InPtr != 0) + { + unsigned char InChar = static_cast<unsigned char>(*InPtr); + + if (InChar <= 0x7f) + Codepoint = InChar; + else if (InChar <= 0xbf) + Codepoint = (Codepoint << 6) | (InChar & 0x3f); + else if (InChar <= 0xdf) + Codepoint = InChar & 0x1f; + else if (InChar <= 0xef) + Codepoint = InChar & 0x0f; + else + Codepoint = InChar & 0x07; + + ++InPtr; + + if (((*InPtr & 0xc0) != 0x80) && (Codepoint <= 0x10ffff)) + { + if (Codepoint > 0xffff) + { + OutString.append(1, static_cast<wchar_t>(0xd800 + (Codepoint >> 10))); + OutString.append(1, static_cast<wchar_t>(0xdc00 + (Codepoint & 0x03ff))); + } + else if (Codepoint < 0xd800 || Codepoint >= 0xe000) + { + OutString.append(1, static_cast<wchar_t>(Codepoint)); + } + } + } + + return OutString; +} + +namespace zen { + +using namespace std::literals; + +class HttpSysServer; +class HttpSysTransaction; +class HttpMessageResponseRequest; + +////////////////////////////////////////////////////////////////////////// + +HttpVerb +TranslateHttpVerb(HTTP_VERB ReqVerb) +{ + switch (ReqVerb) + { + case HttpVerbOPTIONS: + return HttpVerb::kOptions; + + case HttpVerbGET: + return HttpVerb::kGet; + + case HttpVerbHEAD: + return HttpVerb::kHead; + + case HttpVerbPOST: + return HttpVerb::kPost; + + case HttpVerbPUT: + return HttpVerb::kPut; + + case HttpVerbDELETE: + return HttpVerb::kDelete; + + case HttpVerbCOPY: + return HttpVerb::kCopy; + + default: + // TODO: invalid request? + return (HttpVerb)0; + } +} + +uint64_t +GetContentLength(const HTTP_REQUEST* HttpRequest) +{ + const HTTP_KNOWN_HEADER& clh = HttpRequest->Headers.KnownHeaders[HttpHeaderContentLength]; + std::string_view cl(clh.pRawValue, clh.RawValueLength); + uint64_t ContentLength = 0; + std::from_chars(cl.data(), cl.data() + cl.size(), ContentLength); + return ContentLength; +}; + +HttpContentType +GetContentType(const HTTP_REQUEST* HttpRequest) +{ + const HTTP_KNOWN_HEADER& CtHdr = HttpRequest->Headers.KnownHeaders[HttpHeaderContentType]; + return ParseContentType({CtHdr.pRawValue, CtHdr.RawValueLength}); +}; + +HttpContentType +GetAcceptType(const HTTP_REQUEST* HttpRequest) +{ + const HTTP_KNOWN_HEADER& CtHdr = HttpRequest->Headers.KnownHeaders[HttpHeaderAccept]; + return ParseContentType({CtHdr.pRawValue, CtHdr.RawValueLength}); +}; + +/** + * @brief Base class for any pending or active HTTP transactions + */ +class HttpSysRequestHandler +{ +public: + explicit HttpSysRequestHandler(HttpSysTransaction& Transaction) : m_Transaction(Transaction) {} + virtual ~HttpSysRequestHandler() = default; + + virtual void IssueRequest(std::error_code& ErrorCode) = 0; + virtual HttpSysRequestHandler* HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) = 0; + HttpSysTransaction& Transaction() { return m_Transaction; } + + HttpSysRequestHandler(const HttpSysRequestHandler&) = delete; + HttpSysRequestHandler& operator=(const HttpSysRequestHandler&) = delete; + +private: + HttpSysTransaction& m_Transaction; +}; + +/** + * This is the handler for the initial HTTP I/O request which will receive the headers + * and however much of the remaining payload might fit in the embedded request buffer. + * + * It is also used to receive any entity body data relating to the request + * + */ +struct InitialRequestHandler : public HttpSysRequestHandler +{ + inline HTTP_REQUEST* HttpRequest() { return (HTTP_REQUEST*)m_RequestBuffer; } + inline uint32_t RequestBufferSize() const { return sizeof m_RequestBuffer; } + inline bool IsInitialRequest() const { return m_IsInitialRequest; } + + InitialRequestHandler(HttpSysTransaction& InRequest); + ~InitialRequestHandler(); + + virtual void IssueRequest(std::error_code& ErrorCode) override final; + virtual HttpSysRequestHandler* HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) override; + + bool m_IsInitialRequest = true; + uint64_t m_CurrentPayloadOffset = 0; + uint64_t m_ContentLength = ~uint64_t(0); + IoBuffer m_PayloadBuffer; + UCHAR m_RequestBuffer[4096 + sizeof(HTTP_REQUEST)]; +}; + +/** + * This is the class which request handlers use to interact with the server instance + */ + +class HttpSysServerRequest : public HttpServerRequest +{ +public: + HttpSysServerRequest(HttpSysTransaction& Tx, HttpService& Service, IoBuffer PayloadBuffer); + ~HttpSysServerRequest() = default; + + virtual Oid ParseSessionId() const override; + virtual uint32_t ParseRequestId() const override; + + virtual IoBuffer ReadPayload() override; + virtual void WriteResponse(HttpResponseCode ResponseCode) override; + virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span<IoBuffer> Blobs) override; + virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::u8string_view ResponseString) override; + virtual void WriteResponseAsync(std::function<void(HttpServerRequest&)>&& ContinuationHandler) override; + virtual bool TryGetRanges(HttpRanges& Ranges) override; + + using HttpServerRequest::WriteResponse; + + HttpSysServerRequest(const HttpSysServerRequest&) = delete; + HttpSysServerRequest& operator=(const HttpSysServerRequest&) = delete; + + HttpSysTransaction& m_HttpTx; + HttpSysRequestHandler* m_NextCompletionHandler = nullptr; + IoBuffer m_PayloadBuffer; + ExtendableStringBuilder<128> m_UriUtf8; + ExtendableStringBuilder<128> m_QueryStringUtf8; +}; + +/** HTTP transaction + + There will be an instance of this per pending and in-flight HTTP transaction + + */ +class HttpSysTransaction final +{ +public: + HttpSysTransaction(HttpSysServer& Server); + virtual ~HttpSysTransaction(); + + enum class Status + { + kDone, + kRequestPending + }; + + Status HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred); + + static void __stdcall IoCompletionCallback(PTP_CALLBACK_INSTANCE Instance, + PVOID pContext /* HttpSysServer */, + PVOID pOverlapped, + ULONG IoResult, + ULONG_PTR NumberOfBytesTransferred, + PTP_IO Io); + + void IssueInitialRequest(std::error_code& ErrorCode); + bool IssueNextRequest(HttpSysRequestHandler* NewCompletionHandler); + + PTP_IO Iocp(); + HANDLE RequestQueueHandle(); + inline OVERLAPPED* Overlapped() { return &m_HttpOverlapped; } + inline HttpSysServer& Server() { return m_HttpServer; } + inline HTTP_REQUEST* HttpRequest() { return m_InitialHttpHandler.HttpRequest(); } + + HttpSysServerRequest& InvokeRequestHandler(HttpService& Service, IoBuffer Payload); + + HttpSysServerRequest& ServerRequest() { return m_HandlerRequest.value(); } + +private: + OVERLAPPED m_HttpOverlapped{}; + HttpSysServer& m_HttpServer; + + // Tracks which handler is due to handle the next I/O completion event + HttpSysRequestHandler* m_CompletionHandler = nullptr; + RwLock m_CompletionMutex; + InitialRequestHandler m_InitialHttpHandler{*this}; + std::optional<HttpSysServerRequest> m_HandlerRequest; + Ref<IHttpPackageHandler> m_PackageHandler; +}; + +/** + * @brief HTTP request response I/O request handler + * + * Asynchronously streams out a response to an HTTP request via compound + * responses from memory or directly from file + */ + +class HttpMessageResponseRequest : public HttpSysRequestHandler +{ +public: + HttpMessageResponseRequest(HttpSysTransaction& InRequest, uint16_t ResponseCode); + HttpMessageResponseRequest(HttpSysTransaction& InRequest, uint16_t ResponseCode, std::string_view Message); + HttpMessageResponseRequest(HttpSysTransaction& InRequest, + uint16_t ResponseCode, + HttpContentType ContentType, + const void* Payload, + size_t PayloadSize); + HttpMessageResponseRequest(HttpSysTransaction& InRequest, + uint16_t ResponseCode, + HttpContentType ContentType, + std::span<IoBuffer> Blobs); + ~HttpMessageResponseRequest(); + + virtual void IssueRequest(std::error_code& ErrorCode) override final; + virtual HttpSysRequestHandler* HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) override; + void SuppressResponseBody(); // typically used for HEAD requests + +private: + std::vector<HTTP_DATA_CHUNK> m_HttpDataChunks; + uint64_t m_TotalDataSize = 0; // Sum of all chunk sizes + uint16_t m_ResponseCode = 0; + uint32_t m_NextDataChunkOffset = 0; // Cursor used for very large chunk lists + uint32_t m_RemainingChunkCount = 0; // Backlog for multi-call sends + bool m_IsInitialResponse = true; + HttpContentType m_ContentType = HttpContentType::kBinary; + std::vector<IoBuffer> m_DataBuffers; + + void InitializeForPayload(uint16_t ResponseCode, std::span<IoBuffer> Blobs); +}; + +HttpMessageResponseRequest::HttpMessageResponseRequest(HttpSysTransaction& InRequest, uint16_t ResponseCode) +: HttpSysRequestHandler(InRequest) +{ + std::array<IoBuffer, 0> EmptyBufferList; + + InitializeForPayload(ResponseCode, EmptyBufferList); +} + +HttpMessageResponseRequest::HttpMessageResponseRequest(HttpSysTransaction& InRequest, uint16_t ResponseCode, std::string_view Message) +: HttpSysRequestHandler(InRequest) +, m_ContentType(HttpContentType::kText) +{ + IoBuffer MessageBuffer(IoBuffer::Wrap, Message.data(), Message.size()); + std::array<IoBuffer, 1> SingleBufferList({MessageBuffer}); + + InitializeForPayload(ResponseCode, SingleBufferList); +} + +HttpMessageResponseRequest::HttpMessageResponseRequest(HttpSysTransaction& InRequest, + uint16_t ResponseCode, + HttpContentType ContentType, + const void* Payload, + size_t PayloadSize) +: HttpSysRequestHandler(InRequest) +, m_ContentType(ContentType) +{ + IoBuffer MessageBuffer(IoBuffer::Wrap, Payload, PayloadSize); + std::array<IoBuffer, 1> SingleBufferList({MessageBuffer}); + + InitializeForPayload(ResponseCode, SingleBufferList); +} + +HttpMessageResponseRequest::HttpMessageResponseRequest(HttpSysTransaction& InRequest, + uint16_t ResponseCode, + HttpContentType ContentType, + std::span<IoBuffer> BlobList) +: HttpSysRequestHandler(InRequest) +, m_ContentType(ContentType) +{ + InitializeForPayload(ResponseCode, BlobList); +} + +HttpMessageResponseRequest::~HttpMessageResponseRequest() +{ +} + +void +HttpMessageResponseRequest::InitializeForPayload(uint16_t ResponseCode, std::span<IoBuffer> BlobList) +{ + const uint32_t ChunkCount = gsl::narrow<uint32_t>(BlobList.size()); + + m_HttpDataChunks.reserve(ChunkCount); + m_DataBuffers.reserve(ChunkCount); + + for (IoBuffer& Buffer : BlobList) + { + m_DataBuffers.emplace_back(std::move(Buffer)).MakeOwned(); + } + + // Initialize the full array up front + + uint64_t LocalDataSize = 0; + + for (IoBuffer& Buffer : m_DataBuffers) + { + uint64_t BufferDataSize = Buffer.Size(); + + ZEN_ASSERT(BufferDataSize); + + LocalDataSize += BufferDataSize; + + IoBufferFileReference FileRef; + if (Buffer.GetFileReference(/* out */ FileRef)) + { + // Use direct file transfer + + m_HttpDataChunks.push_back({}); + auto& Chunk = m_HttpDataChunks.back(); + + Chunk.DataChunkType = HttpDataChunkFromFileHandle; + Chunk.FromFileHandle.FileHandle = FileRef.FileHandle; + Chunk.FromFileHandle.ByteRange.StartingOffset.QuadPart = FileRef.FileChunkOffset; + Chunk.FromFileHandle.ByteRange.Length.QuadPart = BufferDataSize; + } + else + { + // Send from memory, need to make sure we chunk the buffer up since + // the underlying data structure only accepts 32-bit chunk sizes for + // memory chunks. When this happens the vector will be reallocated, + // which is fine since this will be a pretty rare case and sending + // the data is going to take a lot longer than a memory allocation :) + + const uint8_t* WriteCursor = reinterpret_cast<const uint8_t*>(Buffer.Data()); + + while (BufferDataSize) + { + const ULONG ThisChunkSize = gsl::narrow<ULONG>(zen::Min(1 * 1024 * 1024 * 1024, BufferDataSize)); + + m_HttpDataChunks.push_back({}); + auto& Chunk = m_HttpDataChunks.back(); + + Chunk.DataChunkType = HttpDataChunkFromMemory; + Chunk.FromMemory.pBuffer = (void*)WriteCursor; + Chunk.FromMemory.BufferLength = ThisChunkSize; + + BufferDataSize -= ThisChunkSize; + WriteCursor += ThisChunkSize; + } + } + } + + m_RemainingChunkCount = gsl::narrow<uint32_t>(m_HttpDataChunks.size()); + m_TotalDataSize = LocalDataSize; + + if (m_TotalDataSize == 0 && ResponseCode == 200) + { + // Some HTTP clients really don't like empty responses unless a 204 response is sent + m_ResponseCode = uint16_t(HttpResponseCode::NoContent); + } + else + { + m_ResponseCode = ResponseCode; + } +} + +void +HttpMessageResponseRequest::SuppressResponseBody() +{ + m_RemainingChunkCount = 0; + m_HttpDataChunks.clear(); + m_DataBuffers.clear(); +} + +HttpSysRequestHandler* +HttpMessageResponseRequest::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) +{ + ZEN_UNUSED(NumberOfBytesTransferred); + + if (IoResult != NO_ERROR) + { + ZEN_WARN("response aborted due to error: '{}'", GetSystemErrorAsString(IoResult)); + + // if one transmit failed there's really no need to go on + return nullptr; + } + + if (m_RemainingChunkCount == 0) + { + return nullptr; // All done + } + + return this; +} + +void +HttpMessageResponseRequest::IssueRequest(std::error_code& ErrorCode) +{ + HttpSysTransaction& Tx = Transaction(); + HTTP_REQUEST* const HttpReq = Tx.HttpRequest(); + PTP_IO const Iocp = Tx.Iocp(); + + StartThreadpoolIo(Iocp); + + // Split payload into batches to play well with the underlying API + + const int MaxChunksPerCall = 9999; + + const int ThisRequestChunkCount = std::min<int>(m_RemainingChunkCount, MaxChunksPerCall); + const int ThisRequestChunkOffset = m_NextDataChunkOffset; + + m_RemainingChunkCount -= ThisRequestChunkCount; + m_NextDataChunkOffset += ThisRequestChunkCount; + + /* Should this code also use HTTP_SEND_RESPONSE_FLAG_BUFFER_DATA? + + From the docs: + + This flag enables buffering of data in the kernel on a per-response basis. It should + be used by an application doing synchronous I/O, or by a an application doing + asynchronous I/O with no more than one send outstanding at a time. + + Applications using asynchronous I/O which may have more than one send outstanding at + a time should not use this flag. + + When this flag is set, it should be used consistently in calls to the + HttpSendHttpResponse function as well. + */ + + ULONG SendFlags = HTTP_SEND_RESPONSE_FLAG_BUFFER_DATA; + + if (m_RemainingChunkCount) + { + // We need to make more calls to send the full amount of data + SendFlags |= HTTP_SEND_RESPONSE_FLAG_MORE_DATA; + } + + ULONG SendResult = 0; + + if (m_IsInitialResponse) + { + // Populate response structure + + HTTP_RESPONSE HttpResponse = {}; + + HttpResponse.EntityChunkCount = USHORT(ThisRequestChunkCount); + HttpResponse.pEntityChunks = m_HttpDataChunks.data() + ThisRequestChunkOffset; + + // Server header + // + // By default this will also add a suffix " Microsoft-HTTPAPI/2.0" to this header + // + // This is controlled via a registry key 'DisableServerHeader', at: + // + // Computer\HKEY_LOCAL_MACHINE\SYSTEM\CurrentControlSet\Services\HTTP\Parameters + // + // Set DisableServerHeader to 1 to disable suffix, or 2 to disable the header altogether + // (only the latter appears to do anything in my testing, on Windows 10). + // + // (reference https://docs.microsoft.com/en-us/archive/blogs/dsnotes/wswcf-remove-server-header) + // + + PHTTP_KNOWN_HEADER ServerHeader = &HttpResponse.Headers.KnownHeaders[HttpHeaderServer]; + ServerHeader->pRawValue = "Zen"; + ServerHeader->RawValueLength = (USHORT)3; + + // Content-length header + + char ContentLengthString[32]; + _ui64toa_s(m_TotalDataSize, ContentLengthString, sizeof ContentLengthString, 10); + + PHTTP_KNOWN_HEADER ContentLengthHeader = &HttpResponse.Headers.KnownHeaders[HttpHeaderContentLength]; + ContentLengthHeader->pRawValue = ContentLengthString; + ContentLengthHeader->RawValueLength = (USHORT)strlen(ContentLengthString); + + // Content-type header + + PHTTP_KNOWN_HEADER ContentTypeHeader = &HttpResponse.Headers.KnownHeaders[HttpHeaderContentType]; + + std::string_view ContentTypeString = MapContentTypeToString(m_ContentType); + + ContentTypeHeader->pRawValue = ContentTypeString.data(); + ContentTypeHeader->RawValueLength = (USHORT)ContentTypeString.size(); + + std::string_view ReasonString = ReasonStringForHttpResultCode(m_ResponseCode); + + HttpResponse.StatusCode = m_ResponseCode; + HttpResponse.pReason = ReasonString.data(); + HttpResponse.ReasonLength = (USHORT)ReasonString.size(); + + // Cache policy + + HTTP_CACHE_POLICY CachePolicy; + + CachePolicy.Policy = HttpCachePolicyNocache; // HttpCachePolicyUserInvalidates; + CachePolicy.SecondsToLive = 0; + + // Initial response API call + + SendResult = HttpSendHttpResponse(Tx.RequestQueueHandle(), + HttpReq->RequestId, + SendFlags, + &HttpResponse, + &CachePolicy, + NULL, + NULL, + 0, + Tx.Overlapped(), + NULL); + + m_IsInitialResponse = false; + } + else + { + // Subsequent response API calls + + SendResult = HttpSendResponseEntityBody(Tx.RequestQueueHandle(), + HttpReq->RequestId, + SendFlags, + (USHORT)ThisRequestChunkCount, // EntityChunkCount + &m_HttpDataChunks[ThisRequestChunkOffset], // EntityChunks + NULL, // BytesSent + NULL, // Reserved1 + 0, // Reserved2 + Tx.Overlapped(), // Overlapped + NULL // LogData + ); + } + + if (SendResult == NO_ERROR) + { + // Synchronous completion, but the completion event will still be posted to IOCP + + ErrorCode.clear(); + } + else if (SendResult == ERROR_IO_PENDING) + { + // Asynchronous completion, a completion notification will be posted to IOCP + + ErrorCode.clear(); + } + else + { + // An error occurred, no completion will be posted to IOCP + + CancelThreadpoolIo(Iocp); + + ZEN_WARN("failed to send HTTP response (error: '{}'), request URL: '{}', request id: {}", + GetSystemErrorAsString(SendResult), + HttpReq->pRawUrl, + HttpReq->RequestId); + + ErrorCode = MakeErrorCode(SendResult); + } +} + +/** HTTP completion handler for async work + + This is used to allow work to be taken off the request handler threads + and to support posting responses asynchronously. + */ + +class HttpAsyncWorkRequest : public HttpSysRequestHandler +{ +public: + HttpAsyncWorkRequest(HttpSysTransaction& Tx, std::function<void(HttpServerRequest&)>&& Response); + ~HttpAsyncWorkRequest(); + + virtual void IssueRequest(std::error_code& ErrorCode) override final; + virtual HttpSysRequestHandler* HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) override; + +private: + struct AsyncWorkItem : public IWork + { + virtual void Execute() override; + + AsyncWorkItem(HttpSysTransaction& InTx, std::function<void(HttpServerRequest&)>&& InHandler) + : Tx(InTx) + , Handler(std::move(InHandler)) + { + } + + HttpSysTransaction& Tx; + std::function<void(HttpServerRequest&)> Handler; + }; + + Ref<AsyncWorkItem> m_WorkItem; +}; + +HttpAsyncWorkRequest::HttpAsyncWorkRequest(HttpSysTransaction& Tx, std::function<void(HttpServerRequest&)>&& Response) +: HttpSysRequestHandler(Tx) +{ + m_WorkItem = new AsyncWorkItem(Tx, std::move(Response)); +} + +HttpAsyncWorkRequest::~HttpAsyncWorkRequest() +{ +} + +void +HttpAsyncWorkRequest::IssueRequest(std::error_code& ErrorCode) +{ + ErrorCode.clear(); + + Transaction().Server().WorkPool().ScheduleWork(m_WorkItem); +} + +HttpSysRequestHandler* +HttpAsyncWorkRequest::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) +{ + // This ought to not be called since there should be no outstanding I/O request + // when this completion handler is active + + ZEN_UNUSED(IoResult, NumberOfBytesTransferred); + + ZEN_WARN("Unexpected I/O completion during async work! IoResult: {}, NumberOfBytesTransferred: {}", IoResult, NumberOfBytesTransferred); + + return this; +} + +void +HttpAsyncWorkRequest::AsyncWorkItem::Execute() +{ + try + { + HttpSysServerRequest& ThisRequest = Tx.ServerRequest(); + + ThisRequest.m_NextCompletionHandler = nullptr; + + Handler(ThisRequest); + + // TODO: should Handler be destroyed at this point to ensure there + // are no outstanding references into state which could be + // deleted asynchronously as a result of issuing the response? + + if (HttpSysRequestHandler* NextHandler = ThisRequest.m_NextCompletionHandler) + { + return (void)Tx.IssueNextRequest(NextHandler); + } + else if (!ThisRequest.IsHandled()) + { + return (void)Tx.IssueNextRequest(new HttpMessageResponseRequest(Tx, 404, "Not found"sv)); + } + else + { + // "Handled" but no request handler? Shouldn't ever happen + return (void)Tx.IssueNextRequest( + new HttpMessageResponseRequest(Tx, 500, "Response generated but no request handler scheduled"sv)); + } + } + catch (std::exception& Ex) + { + return (void)Tx.IssueNextRequest( + new HttpMessageResponseRequest(Tx, 500, fmt::format("Exception thrown in async work: '{}'", Ex.what()))); + } +} + +/** + _________ + / _____/ ______________ __ ___________ + \_____ \_/ __ \_ __ \ \/ // __ \_ __ \ + / \ ___/| | \/\ /\ ___/| | \/ + /_______ /\___ >__| \_/ \___ >__| + \/ \/ \/ +*/ + +HttpSysServer::HttpSysServer(unsigned int ThreadCount, unsigned int AsyncWorkThreadCount) +: m_Log(logging::Get("http")) +, m_RequestLog(logging::Get("http_requests")) +, m_ThreadPool(ThreadCount) +, m_AsyncWorkPool(AsyncWorkThreadCount) +{ + ULONG Result = HttpInitialize(HTTPAPI_VERSION_2, HTTP_INITIALIZE_SERVER, nullptr); + + if (Result != NO_ERROR) + { + return; + } + + m_IsHttpInitialized = true; + m_IsOk = true; + + ZEN_INFO("http.sys server started, using {} I/O threads and {} async worker threads", ThreadCount, AsyncWorkThreadCount); +} + +HttpSysServer::~HttpSysServer() +{ + if (m_IsHttpInitialized) + { + Cleanup(); + + HttpTerminate(HTTP_INITIALIZE_SERVER, nullptr); + } +} + +int +HttpSysServer::InitializeServer(int BasePort) +{ + using namespace std::literals; + + WideStringBuilder<64> WildcardUrlPath; + WildcardUrlPath << u8"http://*:"sv << int64_t(BasePort) << u8"/"sv; + + m_IsOk = false; + + ULONG Result = HttpCreateServerSession(HTTPAPI_VERSION_2, &m_HttpSessionId, 0); + + if (Result != NO_ERROR) + { + ZEN_ERROR("Failed to create server session for '{}': {:#x}", WideToUtf8(WildcardUrlPath), Result); + + return BasePort; + } + + Result = HttpCreateUrlGroup(m_HttpSessionId, &m_HttpUrlGroupId, 0); + + if (Result != NO_ERROR) + { + ZEN_ERROR("Failed to create URL group for '{}': {:#x}", WideToUtf8(WildcardUrlPath), Result); + + return BasePort; + } + + int EffectivePort = BasePort; + + Result = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, WildcardUrlPath.c_str(), HTTP_URL_CONTEXT(0), 0); + + // Sharing violation implies the port is being used by another process + for (int PortOffset = 1; (Result == ERROR_SHARING_VIOLATION) && (PortOffset < 10); ++PortOffset) + { + EffectivePort = BasePort + (PortOffset * 100); + WildcardUrlPath.Reset(); + WildcardUrlPath << u8"http://*:"sv << int64_t(EffectivePort) << u8"/"sv; + + Result = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, WildcardUrlPath.c_str(), HTTP_URL_CONTEXT(0), 0); + } + + m_BaseUris.clear(); + if (Result == NO_ERROR) + { + m_BaseUris.push_back(WildcardUrlPath.c_str()); + } + else if (Result == ERROR_ACCESS_DENIED) + { + // If we can't register the wildcard path, we fall back to local paths + // This local paths allow requests originating locally to function, but will not allow + // remote origin requests to function. This can be remedied by using netsh + // during an install process to grant permissions to route public access to the appropriate + // port for the current user. eg: + // netsh http add urlacl url=http://*:1337/ user=<some_user> + + ZEN_WARN("Unable to register handler using '{}' - falling back to local-only", WideToUtf8(WildcardUrlPath)); + + const std::u8string_view Hosts[] = {u8"[::1]"sv, u8"localhost"sv, u8"127.0.0.1"sv}; + + ULONG InternalResult = ERROR_SHARING_VIOLATION; + for (int PortOffset = 0; (InternalResult == ERROR_SHARING_VIOLATION) && (PortOffset < 10); ++PortOffset) + { + EffectivePort = BasePort + (PortOffset * 100); + + for (const std::u8string_view Host : Hosts) + { + WideStringBuilder<64> LocalUrlPath; + LocalUrlPath << u8"http://"sv << Host << u8":"sv << int64_t(EffectivePort) << u8"/"sv; + + InternalResult = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, LocalUrlPath.c_str(), HTTP_URL_CONTEXT(0), 0); + + if (InternalResult == NO_ERROR) + { + ZEN_INFO("Registered local handler '{}'", WideToUtf8(LocalUrlPath)); + + m_BaseUris.push_back(LocalUrlPath.c_str()); + } + else + { + break; + } + } + } + } + + if (m_BaseUris.empty()) + { + ZEN_ERROR("Failed to add base URL to URL group for '{}': {:#x}", WideToUtf8(WildcardUrlPath), Result); + + return BasePort; + } + + HTTP_BINDING_INFO HttpBindingInfo = {{0}, 0}; + + Result = HttpCreateRequestQueue(HTTPAPI_VERSION_2, + /* Name */ nullptr, + /* SecurityAttributes */ nullptr, + /* Flags */ 0, + &m_RequestQueueHandle); + + if (Result != NO_ERROR) + { + ZEN_ERROR("Failed to create request queue for '{}': {:#x}", WideToUtf8(m_BaseUris.front()), Result); + + return EffectivePort; + } + + HttpBindingInfo.Flags.Present = 1; + HttpBindingInfo.RequestQueueHandle = m_RequestQueueHandle; + + Result = HttpSetUrlGroupProperty(m_HttpUrlGroupId, HttpServerBindingProperty, &HttpBindingInfo, sizeof(HttpBindingInfo)); + + if (Result != NO_ERROR) + { + ZEN_ERROR("Failed to set server binding property for '{}': {:#x}", WideToUtf8(m_BaseUris.front()), Result); + + return EffectivePort; + } + + // Create I/O completion port + + std::error_code ErrorCode; + m_ThreadPool.CreateIocp(m_RequestQueueHandle, HttpSysTransaction::IoCompletionCallback, /* Context */ this, /* out */ ErrorCode); + + if (ErrorCode) + { + ZEN_ERROR("Failed to create IOCP for '{}': {}", WideToUtf8(m_BaseUris.front()), ErrorCode.message()); + } + else + { + m_IsOk = true; + + ZEN_INFO("Started http.sys server at '{}'", WideToUtf8(m_BaseUris.front())); + } + + return EffectivePort; +} + +void +HttpSysServer::Cleanup() +{ + ++m_IsShuttingDown; + + if (m_RequestQueueHandle) + { + HttpCloseRequestQueue(m_RequestQueueHandle); + m_RequestQueueHandle = nullptr; + } + + if (m_HttpUrlGroupId) + { + HttpCloseUrlGroup(m_HttpUrlGroupId); + m_HttpUrlGroupId = 0; + } + + if (m_HttpSessionId) + { + HttpCloseServerSession(m_HttpSessionId); + m_HttpSessionId = 0; + } +} + +void +HttpSysServer::StartServer() +{ + const int InitialRequestCount = 32; + + for (int i = 0; i < InitialRequestCount; ++i) + { + IssueNewRequestMaybe(); + } +} + +void +HttpSysServer::Run(bool IsInteractive) +{ + if (IsInteractive) + { + zen::logging::ConsoleLog().info("Zen Server running. Press ESC or Q to quit"); + } + + do + { + // int WaitTimeout = -1; + int WaitTimeout = 100; + + if (IsInteractive) + { + WaitTimeout = 1000; + + if (_kbhit() != 0) + { + char c = (char)_getch(); + + if (c == 27 || c == 'Q' || c == 'q') + { + RequestApplicationExit(0); + } + } + } + + m_ShutdownEvent.Wait(WaitTimeout); + UpdateLofreqTimerValue(); + } while (!IsApplicationExitRequested()); +} + +void +HttpSysServer::OnHandlingRequest() +{ + if (--m_PendingRequests > m_MinPendingRequests) + { + // We have more than the minimum number of requests pending, just let someone else + // enqueue new requests + return; + } + + IssueNewRequestMaybe(); +} + +void +HttpSysServer::IssueNewRequestMaybe() +{ + if (m_IsShuttingDown.load(std::memory_order::acquire)) + { + return; + } + + if (m_PendingRequests.load(std::memory_order::relaxed) >= m_MaxPendingRequests) + { + return; + } + + std::unique_ptr<HttpSysTransaction> Request = std::make_unique<HttpSysTransaction>(*this); + + std::error_code ErrorCode; + Request->IssueInitialRequest(ErrorCode); + + if (ErrorCode) + { + // No request was actually issued. What is the appropriate response? + + return; + } + + // This may end up exceeding the MaxPendingRequests limit, but it's not + // really a problem. I'm doing it this way mostly to avoid dealing with + // exceptions here + ++m_PendingRequests; + + Request.release(); +} + +void +HttpSysServer::RegisterService(const char* UrlPath, HttpService& Service) +{ + if (UrlPath[0] == '/') + { + ++UrlPath; + } + + const std::wstring PathUtf16 = UTF8_to_UTF16(UrlPath); + Service.SetUriPrefixLength(PathUtf16.size() + 1 /* leading slash */); + + // Convert to wide string + + for (const std::wstring& BaseUri : m_BaseUris) + { + std::wstring Url16 = BaseUri + PathUtf16; + + ULONG Result = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, Url16.c_str(), HTTP_URL_CONTEXT(&Service), 0 /* Reserved */); + + if (Result != NO_ERROR) + { + ZEN_ERROR("HttpAddUrlToUrlGroup failed with result: '{}'", GetSystemErrorAsString(Result)); + + return; + } + } +} + +void +HttpSysServer::UnregisterService(const char* UrlPath, HttpService& Service) +{ + ZEN_UNUSED(Service); + + if (UrlPath[0] == '/') + { + ++UrlPath; + } + + const std::wstring PathUtf16 = UTF8_to_UTF16(UrlPath); + + // Convert to wide string + + for (const std::wstring& BaseUri : m_BaseUris) + { + std::wstring Url16 = BaseUri + PathUtf16; + + ULONG Result = HttpRemoveUrlFromUrlGroup(m_HttpUrlGroupId, Url16.c_str(), 0); + + if (Result != NO_ERROR) + { + ZEN_ERROR("HttpRemoveUrlFromUrlGroup failed with result: '{}'", GetSystemErrorAsString(Result)); + } + } +} + +////////////////////////////////////////////////////////////////////////// + +HttpSysTransaction::HttpSysTransaction(HttpSysServer& Server) : m_HttpServer(Server), m_CompletionHandler(&m_InitialHttpHandler) +{ +} + +HttpSysTransaction::~HttpSysTransaction() +{ +} + +PTP_IO +HttpSysTransaction::Iocp() +{ + return m_HttpServer.m_ThreadPool.Iocp(); +} + +HANDLE +HttpSysTransaction::RequestQueueHandle() +{ + return m_HttpServer.m_RequestQueueHandle; +} + +void +HttpSysTransaction::IssueInitialRequest(std::error_code& ErrorCode) +{ + m_InitialHttpHandler.IssueRequest(ErrorCode); +} + +void +HttpSysTransaction::IoCompletionCallback(PTP_CALLBACK_INSTANCE Instance, + PVOID pContext /* HttpSysServer */, + PVOID pOverlapped, + ULONG IoResult, + ULONG_PTR NumberOfBytesTransferred, + PTP_IO Io) +{ + UNREFERENCED_PARAMETER(Io); + UNREFERENCED_PARAMETER(Instance); + UNREFERENCED_PARAMETER(pContext); + + // Note that for a given transaction we may be in this completion function on more + // than one thread at any given moment. This means we need to be careful about what + // happens in here + + HttpSysTransaction* Transaction = CONTAINING_RECORD(pOverlapped, HttpSysTransaction, m_HttpOverlapped); + + if (Transaction->HandleCompletion(IoResult, NumberOfBytesTransferred) == HttpSysTransaction::Status::kDone) + { + delete Transaction; + } +} + +bool +HttpSysTransaction::IssueNextRequest(HttpSysRequestHandler* NewCompletionHandler) +{ + HttpSysRequestHandler* CurrentHandler = m_CompletionHandler; + m_CompletionHandler = NewCompletionHandler; + + auto _ = MakeGuard([this, CurrentHandler] { + if ((CurrentHandler != &m_InitialHttpHandler) && (CurrentHandler != m_CompletionHandler)) + { + delete CurrentHandler; + } + }); + + if (NewCompletionHandler == nullptr) + { + return false; + } + + try + { + std::error_code ErrorCode; + m_CompletionHandler->IssueRequest(ErrorCode); + + if (!ErrorCode) + { + return true; + } + + ZEN_WARN("IssueRequest() failed: '{}'", ErrorCode.message()); + } + catch (std::exception& Ex) + { + ZEN_ERROR("exception caught in IssueNextRequest(): '{}'", Ex.what()); + } + + // something went wrong, no request is pending + m_CompletionHandler = nullptr; + + return false; +} + +HttpSysTransaction::Status +HttpSysTransaction::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) +{ + // We use this to ensure sequential execution of completion handlers + // for any given transaction. It also ensures all member variables are + // in a consistent state for the current thread + + RwLock::ExclusiveLockScope _(m_CompletionMutex); + + bool IsRequestPending = false; + + if (HttpSysRequestHandler* CurrentHandler = m_CompletionHandler) + { + if ((CurrentHandler == &m_InitialHttpHandler) && m_InitialHttpHandler.IsInitialRequest()) + { + // Ensure we have a sufficient number of pending requests outstanding + m_HttpServer.OnHandlingRequest(); + } + + auto NewCompletionHandler = CurrentHandler->HandleCompletion(IoResult, NumberOfBytesTransferred); + + IsRequestPending = IssueNextRequest(NewCompletionHandler); + } + + // Ensure new requests are enqueued as necessary + m_HttpServer.IssueNewRequestMaybe(); + + if (IsRequestPending) + { + // There is another request pending on this transaction, so it needs to remain valid + return Status::kRequestPending; + } + + if (m_HttpServer.m_IsRequestLoggingEnabled) + { + if (m_HandlerRequest.has_value()) + { + m_HttpServer.m_RequestLog.info("{} {}", ToString(m_HandlerRequest->RequestVerb()), m_HandlerRequest->RelativeUri()); + } + } + + // Transaction done, caller should clean up (delete) this instance + return Status::kDone; +} + +HttpSysServerRequest& +HttpSysTransaction::InvokeRequestHandler(HttpService& Service, IoBuffer Payload) +{ + HttpSysServerRequest& ThisRequest = m_HandlerRequest.emplace(*this, Service, Payload); + + // Default request handling + + if (!HandlePackageOffers(Service, ThisRequest, m_PackageHandler)) + { + Service.HandleRequest(ThisRequest); + } + + return ThisRequest; +} + +////////////////////////////////////////////////////////////////////////// + +HttpSysServerRequest::HttpSysServerRequest(HttpSysTransaction& Tx, HttpService& Service, IoBuffer PayloadBuffer) +: m_HttpTx(Tx) +, m_PayloadBuffer(std::move(PayloadBuffer)) +{ + const HTTP_REQUEST* HttpRequestPtr = Tx.HttpRequest(); + + const int PrefixLength = Service.UriPrefixLength(); + const int AbsPathLength = HttpRequestPtr->CookedUrl.AbsPathLength / sizeof(wchar_t); + + HttpContentType AcceptContentType = HttpContentType::kUnknownContentType; + + if (AbsPathLength >= PrefixLength) + { + // We convert the URI immediately because most of the code involved prefers to deal + // with utf8. This is overhead which I'd prefer to avoid but for now we just have + // to live with it + + WideToUtf8({(wchar_t*)HttpRequestPtr->CookedUrl.pAbsPath + PrefixLength, gsl::narrow<size_t>(AbsPathLength - PrefixLength)}, + m_UriUtf8); + + std::string_view UriSuffix8{m_UriUtf8}; + + m_UriWithExtension = UriSuffix8; // Retain URI with extension for user access + m_Uri = UriSuffix8; + + const size_t LastComponentIndex = UriSuffix8.find_last_of('/'); + + if (LastComponentIndex != std::string_view::npos) + { + UriSuffix8.remove_prefix(LastComponentIndex); + } + + const size_t LastDotIndex = UriSuffix8.find_last_of('.'); + + if (LastDotIndex != std::string_view::npos) + { + UriSuffix8.remove_prefix(LastDotIndex + 1); + + AcceptContentType = ParseContentType(UriSuffix8); + if (AcceptContentType != HttpContentType::kUnknownContentType) + { + m_Uri.remove_suffix(UriSuffix8.size() + 1); + } + } + } + else + { + m_UriUtf8.Reset(); + m_Uri = {}; + m_UriWithExtension = {}; + } + + if (uint16_t QueryStringLength = HttpRequestPtr->CookedUrl.QueryStringLength) + { + --QueryStringLength; // We skip the leading question mark + + WideToUtf8({(wchar_t*)(HttpRequestPtr->CookedUrl.pQueryString) + 1, QueryStringLength / sizeof(wchar_t)}, m_QueryStringUtf8); + } + else + { + m_QueryStringUtf8.Reset(); + } + + m_QueryString = std::string_view(m_QueryStringUtf8); + m_Verb = TranslateHttpVerb(HttpRequestPtr->Verb); + m_ContentLength = GetContentLength(HttpRequestPtr); + m_ContentType = GetContentType(HttpRequestPtr); + + // It an explicit content type extension was specified then we'll use that over any + // Accept: header value that may be present + + if (AcceptContentType != HttpContentType::kUnknownContentType) + { + m_AcceptType = AcceptContentType; + } + else + { + m_AcceptType = GetAcceptType(HttpRequestPtr); + } + + if (m_Verb == HttpVerb::kHead) + { + SetSuppressResponseBody(); + } +} + +Oid +HttpSysServerRequest::ParseSessionId() const +{ + const HTTP_REQUEST* HttpRequestPtr = m_HttpTx.HttpRequest(); + + for (int i = 0; i < HttpRequestPtr->Headers.UnknownHeaderCount; ++i) + { + HTTP_UNKNOWN_HEADER& Header = HttpRequestPtr->Headers.pUnknownHeaders[i]; + std::string_view HeaderName{Header.pName, Header.NameLength}; + + if (HeaderName == "UE-Session"sv) + { + if (Header.RawValueLength == Oid::StringLength) + { + return Oid::FromHexString({Header.pRawValue, Header.RawValueLength}); + } + } + } + + return {}; +} + +uint32_t +HttpSysServerRequest::ParseRequestId() const +{ + const HTTP_REQUEST* HttpRequestPtr = m_HttpTx.HttpRequest(); + + for (int i = 0; i < HttpRequestPtr->Headers.UnknownHeaderCount; ++i) + { + HTTP_UNKNOWN_HEADER& Header = HttpRequestPtr->Headers.pUnknownHeaders[i]; + std::string_view HeaderName{Header.pName, Header.NameLength}; + + if (HeaderName == "UE-Request"sv) + { + std::string_view RequestValue{Header.pRawValue, Header.RawValueLength}; + uint32_t RequestId = 0; + std::from_chars(RequestValue.data(), RequestValue.data() + RequestValue.size(), RequestId); + return RequestId; + } + } + + return 0; +} + +IoBuffer +HttpSysServerRequest::ReadPayload() +{ + return m_PayloadBuffer; +} + +void +HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode) +{ + ZEN_ASSERT(IsHandled() == false); + + auto Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)ResponseCode); + + if (SuppressBody()) + { + Response->SuppressResponseBody(); + } + + m_NextCompletionHandler = Response; + + SetIsHandled(); +} + +void +HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span<IoBuffer> Blobs) +{ + ZEN_ASSERT(IsHandled() == false); + + auto Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)ResponseCode, ContentType, Blobs); + + if (SuppressBody()) + { + Response->SuppressResponseBody(); + } + + m_NextCompletionHandler = Response; + + SetIsHandled(); +} + +void +HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::u8string_view ResponseString) +{ + ZEN_ASSERT(IsHandled() == false); + + auto Response = + new HttpMessageResponseRequest(m_HttpTx, (uint16_t)ResponseCode, ContentType, ResponseString.data(), ResponseString.size()); + + if (SuppressBody()) + { + Response->SuppressResponseBody(); + } + + m_NextCompletionHandler = Response; + + SetIsHandled(); +} + +void +HttpSysServerRequest::WriteResponseAsync(std::function<void(HttpServerRequest&)>&& ContinuationHandler) +{ + if (m_HttpTx.Server().IsAsyncResponseEnabled()) + { + m_NextCompletionHandler = new HttpAsyncWorkRequest(m_HttpTx, std::move(ContinuationHandler)); + } + else + { + ContinuationHandler(m_HttpTx.ServerRequest()); + } +} + +bool +HttpSysServerRequest::TryGetRanges(HttpRanges& Ranges) +{ + HTTP_REQUEST* Req = m_HttpTx.HttpRequest(); + const HTTP_KNOWN_HEADER& RangeHeader = Req->Headers.KnownHeaders[HttpHeaderRange]; + + return TryParseHttpRangeHeader({RangeHeader.pRawValue, RangeHeader.RawValueLength}, Ranges); +} + +////////////////////////////////////////////////////////////////////////// + +InitialRequestHandler::InitialRequestHandler(HttpSysTransaction& InRequest) : HttpSysRequestHandler(InRequest) +{ +} + +InitialRequestHandler::~InitialRequestHandler() +{ +} + +void +InitialRequestHandler::IssueRequest(std::error_code& ErrorCode) +{ + HttpSysTransaction& Tx = Transaction(); + PTP_IO Iocp = Tx.Iocp(); + HTTP_REQUEST* HttpReq = Tx.HttpRequest(); + + StartThreadpoolIo(Iocp); + + ULONG HttpApiResult; + + if (IsInitialRequest()) + { + HttpApiResult = HttpReceiveHttpRequest(Tx.RequestQueueHandle(), + HTTP_NULL_ID, + HTTP_RECEIVE_REQUEST_FLAG_COPY_BODY, + HttpReq, + RequestBufferSize(), + NULL, + Tx.Overlapped()); + } + else + { + // The http.sys team recommends limiting the size to 128KB + static const uint64_t kMaxBytesPerApiCall = 128 * 1024; + + uint64_t BytesToRead = m_ContentLength - m_CurrentPayloadOffset; + const uint64_t BytesToReadThisCall = zen::Min(BytesToRead, kMaxBytesPerApiCall); + void* BufferWriteCursor = reinterpret_cast<uint8_t*>(m_PayloadBuffer.MutableData()) + m_CurrentPayloadOffset; + + HttpApiResult = HttpReceiveRequestEntityBody(Tx.RequestQueueHandle(), + HttpReq->RequestId, + 0, /* Flags */ + BufferWriteCursor, + gsl::narrow<ULONG>(BytesToReadThisCall), + nullptr, // BytesReturned + Tx.Overlapped()); + } + + if (HttpApiResult != ERROR_IO_PENDING && HttpApiResult != NO_ERROR) + { + CancelThreadpoolIo(Iocp); + + ErrorCode = MakeErrorCode(HttpApiResult); + + ZEN_WARN("HttpReceiveHttpRequest failed, error: '{}'", ErrorCode.message()); + + return; + } + + ErrorCode.clear(); +} + +HttpSysRequestHandler* +InitialRequestHandler::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) +{ + auto _ = MakeGuard([&] { m_IsInitialRequest = false; }); + + switch (IoResult) + { + default: + case ERROR_OPERATION_ABORTED: + return nullptr; + + case ERROR_MORE_DATA: // Insufficient buffer space + case NO_ERROR: + break; + } + + // Route request + + try + { + HTTP_REQUEST* HttpReq = HttpRequest(); + +# if 0 + for (int i = 0; i < HttpReq->RequestInfoCount; ++i) + { + auto& ReqInfo = HttpReq->pRequestInfo[i]; + + switch (ReqInfo.InfoType) + { + case HttpRequestInfoTypeRequestTiming: + { + const HTTP_REQUEST_TIMING_INFO* TimingInfo = reinterpret_cast<HTTP_REQUEST_TIMING_INFO*>(ReqInfo.pInfo); + + ZEN_INFO(""); + } + break; + case HttpRequestInfoTypeAuth: + ZEN_INFO(""); + break; + case HttpRequestInfoTypeChannelBind: + ZEN_INFO(""); + break; + case HttpRequestInfoTypeSslProtocol: + ZEN_INFO(""); + break; + case HttpRequestInfoTypeSslTokenBindingDraft: + ZEN_INFO(""); + break; + case HttpRequestInfoTypeSslTokenBinding: + ZEN_INFO(""); + break; + case HttpRequestInfoTypeTcpInfoV0: + { + const TCP_INFO_v0* TcpInfo = reinterpret_cast<const TCP_INFO_v0*>(ReqInfo.pInfo); + + ZEN_INFO(""); + } + break; + case HttpRequestInfoTypeRequestSizing: + { + const HTTP_REQUEST_SIZING_INFO* SizingInfo = reinterpret_cast<const HTTP_REQUEST_SIZING_INFO*>(ReqInfo.pInfo); + ZEN_INFO(""); + } + break; + case HttpRequestInfoTypeQuicStats: + ZEN_INFO(""); + break; + case HttpRequestInfoTypeTcpInfoV1: + { + const TCP_INFO_v1* TcpInfo = reinterpret_cast<const TCP_INFO_v1*>(ReqInfo.pInfo); + + ZEN_INFO(""); + } + break; + } + } +# endif + + if (HttpService* Service = reinterpret_cast<HttpService*>(HttpReq->UrlContext)) + { + if (m_IsInitialRequest) + { + m_ContentLength = GetContentLength(HttpReq); + const HttpContentType ContentType = GetContentType(HttpReq); + + if (m_ContentLength) + { + // Handle initial chunk read by copying any payload which has already been copied + // into our embedded request buffer + + m_PayloadBuffer = IoBuffer(m_ContentLength); + m_PayloadBuffer.SetContentType(ContentType); + + uint64_t BytesToRead = m_ContentLength; + uint8_t* const BufferBase = reinterpret_cast<uint8_t*>(m_PayloadBuffer.MutableData()); + uint8_t* BufferWriteCursor = BufferBase; + + const int EntityChunkCount = HttpReq->EntityChunkCount; + + for (int i = 0; i < EntityChunkCount; ++i) + { + HTTP_DATA_CHUNK& EntityChunk = HttpReq->pEntityChunks[i]; + + ZEN_ASSERT(EntityChunk.DataChunkType == HttpDataChunkFromMemory); + + const uint64_t BufferLength = EntityChunk.FromMemory.BufferLength; + + ZEN_ASSERT(BufferLength <= BytesToRead); + + memcpy(BufferWriteCursor, EntityChunk.FromMemory.pBuffer, BufferLength); + + BufferWriteCursor += BufferLength; + BytesToRead -= BufferLength; + } + + m_CurrentPayloadOffset = BufferWriteCursor - BufferBase; + } + } + else + { + m_CurrentPayloadOffset += NumberOfBytesTransferred; + } + + if (m_CurrentPayloadOffset != m_ContentLength) + { + // Body not complete, issue another read request to receive more body data + return this; + } + + // Request body received completely + + m_PayloadBuffer.MakeImmutable(); + + HttpSysServerRequest& ThisRequest = Transaction().InvokeRequestHandler(*Service, m_PayloadBuffer); + + if (HttpSysRequestHandler* Response = ThisRequest.m_NextCompletionHandler) + { + return Response; + } + + if (!ThisRequest.IsHandled()) + { + return new HttpMessageResponseRequest(Transaction(), 404, "Not found"sv); + } + } + + // Unable to route + return new HttpMessageResponseRequest(Transaction(), 404, "No suitable route found"sv); + } + catch (std::exception& ex) + { + ZEN_ERROR("Caught exception while handling request: '{}'", ex.what()); + + return new HttpMessageResponseRequest(Transaction(), 500, ex.what()); + } +} + +////////////////////////////////////////////////////////////////////////// +// +// HttpServer interface implementation +// + +int +HttpSysServer::Initialize(int BasePort) +{ + int EffectivePort = InitializeServer(BasePort); + StartServer(); + return EffectivePort; +} + +void +HttpSysServer::RequestExit() +{ + m_ShutdownEvent.Set(); +} +void +HttpSysServer::RegisterService(HttpService& Service) +{ + RegisterService(Service.BaseUri(), Service); +} + +Ref<HttpServer> +CreateHttpSysServer(int Concurrency, int BackgroundWorkerThreads) +{ + return Ref<HttpServer>(new HttpSysServer(Concurrency, BackgroundWorkerThreads)); +} + +} // namespace zen +#endif diff --git a/src/zenhttp/httpsys.h b/src/zenhttp/httpsys.h new file mode 100644 index 000000000..d6bd34890 --- /dev/null +++ b/src/zenhttp/httpsys.h @@ -0,0 +1,90 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zenhttp/httpserver.h> + +#ifndef ZEN_WITH_HTTPSYS +# if ZEN_PLATFORM_WINDOWS +# define ZEN_WITH_HTTPSYS 1 +# else +# define ZEN_WITH_HTTPSYS 0 +# endif +#endif + +#if ZEN_WITH_HTTPSYS +# define _WINSOCKAPI_ +# include <zencore/windows.h> +# include <zencore/workthreadpool.h> +# include "iothreadpool.h" + +# include <http.h> + +namespace spdlog { +class logger; +} + +namespace zen { + +/** + * @brief Windows implementation of HTTP server based on http.sys + * + * This requires elevation to function + */ +class HttpSysServer : public HttpServer +{ + friend class HttpSysTransaction; + +public: + explicit HttpSysServer(unsigned int ThreadCount, unsigned int AsyncWorkThreadCount); + ~HttpSysServer(); + + // HttpServer interface implementation + + virtual int Initialize(int BasePort) override; + virtual void Run(bool TestMode) override; + virtual void RequestExit() override; + virtual void RegisterService(HttpService& Service) override; + + WorkerThreadPool& WorkPool() { return m_AsyncWorkPool; } + + inline bool IsOk() const { return m_IsOk; } + inline bool IsAsyncResponseEnabled() const { return m_IsAsyncResponseEnabled; } + +private: + int InitializeServer(int BasePort); + void Cleanup(); + + void StartServer(); + void OnHandlingRequest(); + void IssueNewRequestMaybe(); + + void RegisterService(const char* Endpoint, HttpService& Service); + void UnregisterService(const char* Endpoint, HttpService& Service); + +private: + spdlog::logger& m_Log; + spdlog::logger& m_RequestLog; + spdlog::logger& Log() { return m_Log; } + + bool m_IsOk = false; + bool m_IsHttpInitialized = false; + bool m_IsRequestLoggingEnabled = false; + bool m_IsAsyncResponseEnabled = true; + + WinIoThreadPool m_ThreadPool; + WorkerThreadPool m_AsyncWorkPool; + + std::vector<std::wstring> m_BaseUris; // eg: http://*:nnnn/ + HTTP_SERVER_SESSION_ID m_HttpSessionId = 0; + HTTP_URL_GROUP_ID m_HttpUrlGroupId = 0; + HANDLE m_RequestQueueHandle = 0; + std::atomic_int32_t m_PendingRequests{0}; + std::atomic_int32_t m_IsShuttingDown{0}; + int32_t m_MinPendingRequests = 16; + int32_t m_MaxPendingRequests = 128; + Event m_ShutdownEvent; +}; + +} // namespace zen +#endif diff --git a/src/zenhttp/include/zenhttp/httpclient.h b/src/zenhttp/include/zenhttp/httpclient.h new file mode 100644 index 000000000..8316a9b9f --- /dev/null +++ b/src/zenhttp/include/zenhttp/httpclient.h @@ -0,0 +1,47 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "zenhttp.h" + +#include <zencore/iobuffer.h> +#include <zencore/uid.h> +#include <zenhttp/httpcommon.h> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <cpr/cpr.h> +ZEN_THIRD_PARTY_INCLUDES_END + +namespace zen { + +class CbPackage; + +/** HTTP client implementation for Zen use cases + + Currently simple and synchronous, should become lean and asynchronous + */ +class HttpClient +{ +public: + HttpClient(std::string_view BaseUri); + ~HttpClient(); + + struct Response + { + int StatusCode = 0; + IoBuffer ResponsePayload; // Note: this also includes the content type + }; + + [[nodiscard]] Response Put(std::string_view Url, IoBuffer Payload); + [[nodiscard]] Response Get(std::string_view Url); + [[nodiscard]] Response TransactPackage(std::string_view Url, CbPackage Package); + [[nodiscard]] Response Delete(std::string_view Url); + +private: + std::string m_BaseUri; + std::string m_SessionId; +}; + +} // namespace zen + +void httpclient_forcelink(); // internal diff --git a/src/zenhttp/include/zenhttp/httpcommon.h b/src/zenhttp/include/zenhttp/httpcommon.h new file mode 100644 index 000000000..19fda8db4 --- /dev/null +++ b/src/zenhttp/include/zenhttp/httpcommon.h @@ -0,0 +1,181 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/iobuffer.h> + +#include <string_view> + +#include <gsl/gsl-lite.hpp> + +namespace zen { + +using HttpContentType = ZenContentType; + +class IoBuffer; +class CbObject; +class CbPackage; +class StringBuilderBase; + +struct HttpRange +{ + uint32_t Start = ~uint32_t(0); + uint32_t End = ~uint32_t(0); +}; + +using HttpRanges = std::vector<HttpRange>; + +std::string_view MapContentTypeToString(HttpContentType ContentType); +extern HttpContentType (*ParseContentType)(const std::string_view& ContentTypeString); +std::string_view ReasonStringForHttpResultCode(int HttpCode); +bool TryParseHttpRangeHeader(std::string_view RangeHeader, HttpRanges& Ranges); + +[[nodiscard]] inline bool +IsHttpSuccessCode(int HttpCode) +{ + return (HttpCode >= 200) && (HttpCode < 300); +} + +enum class HttpVerb : uint8_t +{ + kGet = 1 << 0, + kPut = 1 << 1, + kPost = 1 << 2, + kDelete = 1 << 3, + kHead = 1 << 4, + kCopy = 1 << 5, + kOptions = 1 << 6 +}; + +gsl_DEFINE_ENUM_BITMASK_OPERATORS(HttpVerb); + +const std::string_view ToString(HttpVerb Verb); + +enum class HttpResponseCode +{ + // 1xx - Informational + + Continue = 100, //!< Indicates that the initial part of a request has been received and has not yet been rejected by the server. + SwitchingProtocols = 101, //!< Indicates that the server understands and is willing to comply with the client's request, via the + //!< Upgrade header field, for a change in the application protocol being used on this connection. + Processing = 102, //!< Is an interim response used to inform the client that the server has accepted the complete request, but has not + //!< yet completed it. + EarlyHints = 103, //!< Indicates to the client that the server is likely to send a final response with the header fields included in + //!< the informational response. + + // 2xx - Successful + + OK = 200, //!< Indicates that the request has succeeded. + Created = 201, //!< Indicates that the request has been fulfilled and has resulted in one or more new resources being created. + Accepted = 202, //!< Indicates that the request has been accepted for processing, but the processing has not been completed. + NonAuthoritativeInformation = 203, //!< Indicates that the request was successful but the enclosed payload has been modified from that + //!< of the origin server's 200 (OK) response by a transforming proxy. + NoContent = 204, //!< Indicates that the server has successfully fulfilled the request and that there is no additional content to send + //!< in the response payload body. + ResetContent = 205, //!< Indicates that the server has fulfilled the request and desires that the user agent reset the \"document + //!< view\", which caused the request to be sent, to its original state as received from the origin server. + PartialContent = 206, //!< Indicates that the server is successfully fulfilling a range request for the target resource by transferring + //!< one or more parts of the selected representation that correspond to the satisfiable ranges found in the + //!< requests's Range header field. + MultiStatus = 207, //!< Provides status for multiple independent operations. + AlreadyReported = 208, //!< Used inside a DAV:propstat response element to avoid enumerating the internal members of multiple bindings + //!< to the same collection repeatedly. [RFC 5842] + IMUsed = 226, //!< The server has fulfilled a GET request for the resource, and the response is a representation of the result of one + //!< or more instance-manipulations applied to the current instance. + + // 3xx - Redirection + + MultipleChoices = 300, //!< Indicates that the target resource has more than one representation, each with its own more specific + //!< identifier, and information about the alternatives is being provided so that the user (or user agent) can + //!< select a preferred representation by redirecting its request to one or more of those identifiers. + MovedPermanently = 301, //!< Indicates that the target resource has been assigned a new permanent URI and any future references to this + //!< resource ought to use one of the enclosed URIs. + Found = 302, //!< Indicates that the target resource resides temporarily under a different URI. + SeeOther = 303, //!< Indicates that the server is redirecting the user agent to a different resource, as indicated by a URI in the + //!< Location header field, that is intended to provide an indirect response to the original request. + NotModified = 304, //!< Indicates that a conditional GET request has been received and would have resulted in a 200 (OK) response if it + //!< were not for the fact that the condition has evaluated to false. + UseProxy = 305, //!< \deprecated \parblock Due to security concerns regarding in-band configuration of a proxy. \endparblock + //!< The requested resource MUST be accessed through the proxy given by the Location field. + TemporaryRedirect = 307, //!< Indicates that the target resource resides temporarily under a different URI and the user agent MUST NOT + //!< change the request method if it performs an automatic redirection to that URI. + PermanentRedirect = 308, //!< The target resource has been assigned a new permanent URI and any future references to this resource + //!< ought to use one of the enclosed URIs. [...] This status code is similar to 301 Moved Permanently + //!< (Section 7.3.2 of rfc7231), except that it does not allow rewriting the request method from POST to GET. + + // 4xx - Client Error + BadRequest = 400, //!< Indicates that the server cannot or will not process the request because the received syntax is invalid, + //!< nonsensical, or exceeds some limitation on what the server is willing to process. + Unauthorized = 401, //!< Indicates that the request has not been applied because it lacks valid authentication credentials for the + //!< target resource. + PaymentRequired = 402, //!< *Reserved* + Forbidden = 403, //!< Indicates that the server understood the request but refuses to authorize it. + NotFound = 404, //!< Indicates that the origin server did not find a current representation for the target resource or is not willing + //!< to disclose that one exists. + MethodNotAllowed = 405, //!< Indicates that the method specified in the request-line is known by the origin server but not supported by + //!< the target resource. + NotAcceptable = 406, //!< Indicates that the target resource does not have a current representation that would be acceptable to the + //!< user agent, according to the proactive negotiation header fields received in the request, and the server is + //!< unwilling to supply a default representation. + ProxyAuthenticationRequired = + 407, //!< Is similar to 401 (Unauthorized), but indicates that the client needs to authenticate itself in order to use a proxy. + RequestTimeout = + 408, //!< Indicates that the server did not receive a complete request message within the time that it was prepared to wait. + Conflict = 409, //!< Indicates that the request could not be completed due to a conflict with the current state of the resource. + Gone = 410, //!< Indicates that access to the target resource is no longer available at the origin server and that this condition is + //!< likely to be permanent. + LengthRequired = 411, //!< Indicates that the server refuses to accept the request without a defined Content-Length. + PreconditionFailed = + 412, //!< Indicates that one or more preconditions given in the request header fields evaluated to false when tested on the server. + PayloadTooLarge = 413, //!< Indicates that the server is refusing to process a request because the request payload is larger than the + //!< server is willing or able to process. + URITooLong = 414, //!< Indicates that the server is refusing to service the request because the request-target is longer than the + //!< server is willing to interpret. + UnsupportedMediaType = 415, //!< Indicates that the origin server is refusing to service the request because the payload is in a format + //!< not supported by the target resource for this method. + RangeNotSatisfiable = 416, //!< Indicates that none of the ranges in the request's Range header field overlap the current extent of the + //!< selected resource or that the set of ranges requested has been rejected due to invalid ranges or an + //!< excessive request of small or overlapping ranges. + ExpectationFailed = 417, //!< Indicates that the expectation given in the request's Expect header field could not be met by at least + //!< one of the inbound servers. + ImATeapot = 418, //!< Any attempt to brew coffee with a teapot should result in the error code 418 I'm a teapot. + UnprocessableEntity = 422, //!< Means the server understands the content type of the request entity (hence a 415(Unsupported Media + //!< Type) status code is inappropriate), and the syntax of the request entity is correct (thus a 400 (Bad + //!< Request) status code is inappropriate) but was unable to process the contained instructions. + Locked = 423, //!< Means the source or destination resource of a method is locked. + FailedDependency = 424, //!< Means that the method could not be performed on the resource because the requested action depended on + //!< another action and that action failed. + UpgradeRequired = 426, //!< Indicates that the server refuses to perform the request using the current protocol but might be willing to + //!< do so after the client upgrades to a different protocol. + PreconditionRequired = 428, //!< Indicates that the origin server requires the request to be conditional. + TooManyRequests = 429, //!< Indicates that the user has sent too many requests in a given amount of time (\"rate limiting\"). + RequestHeaderFieldsTooLarge = + 431, //!< Indicates that the server is unwilling to process the request because its header fields are too large. + UnavailableForLegalReasons = + 451, //!< This status code indicates that the server is denying access to the resource in response to a legal demand. + + // 5xx - Server Error + + InternalServerError = + 500, //!< Indicates that the server encountered an unexpected condition that prevented it from fulfilling the request. + NotImplemented = 501, //!< Indicates that the server does not support the functionality required to fulfill the request. + BadGateway = 502, //!< Indicates that the server, while acting as a gateway or proxy, received an invalid response from an inbound + //!< server it accessed while attempting to fulfill the request. + ServiceUnavailable = 503, //!< Indicates that the server is currently unable to handle the request due to a temporary overload or + //!< scheduled maintenance, which will likely be alleviated after some delay. + GatewayTimeout = 504, //!< Indicates that the server, while acting as a gateway or proxy, did not receive a timely response from an + //!< upstream server it needed to access in order to complete the request. + HTTPVersionNotSupported = 505, //!< Indicates that the server does not support, or refuses to support, the protocol version that was + //!< used in the request message. + VariantAlsoNegotiates = + 506, //!< Indicates that the server has an internal configuration error: the chosen variant resource is configured to engage in + //!< transparent content negotiation itself, and is therefore not a proper end point in the negotiation process. + InsufficientStorage = 507, //!< Means the method could not be performed on the resource because the server is unable to store the + //!< representation needed to successfully complete the request. + LoopDetected = 508, //!< Indicates that the server terminated an operation because it encountered an infinite loop while processing a + //!< request with "Depth: infinity". [RFC 5842] + NotExtended = 510, //!< The policy for accessing the resource has not been met in the request. [RFC 2774] + NetworkAuthenticationRequired = 511, //!< Indicates that the client needs to authenticate to gain network access. +}; + +} // namespace zen diff --git a/src/zenhttp/include/zenhttp/httpserver.h b/src/zenhttp/include/zenhttp/httpserver.h new file mode 100644 index 000000000..3b9fa50b4 --- /dev/null +++ b/src/zenhttp/include/zenhttp/httpserver.h @@ -0,0 +1,315 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "zenhttp.h" + +#include <zencore/compactbinary.h> +#include <zencore/enumflags.h> +#include <zencore/iobuffer.h> +#include <zencore/iohash.h> +#include <zencore/refcount.h> +#include <zencore/string.h> +#include <zencore/uid.h> +#include <zenhttp/httpcommon.h> + +#include <functional> +#include <gsl/gsl-lite.hpp> +#include <list> +#include <map> +#include <regex> +#include <span> +#include <unordered_map> + +namespace zen { + +/** HTTP Server Request + */ +class HttpServerRequest +{ +public: + HttpServerRequest(); + ~HttpServerRequest(); + + // Synchronous operations + + [[nodiscard]] inline std::string_view RelativeUri() const { return m_Uri; } // Returns URI without service prefix + [[nodiscard]] std::string_view RelativeUriWithExtension() const { return m_UriWithExtension; } + [[nodiscard]] inline std::string_view QueryString() const { return m_QueryString; } + + struct QueryParams + { + std::vector<std::pair<std::string_view, std::string_view>> KvPairs; + + std::string_view GetValue(std::string_view ParamName) const + { + for (const auto& Kv : KvPairs) + { + const std::string_view& Key = Kv.first; + + if (Key.size() == ParamName.size()) + { + if (0 == StrCaseCompare(Key.data(), ParamName.data(), Key.size())) + { + return Kv.second; + } + } + } + + return std::string_view(); + } + }; + + virtual bool TryGetRanges(HttpRanges&) { return false; } + + QueryParams GetQueryParams(); + + inline HttpVerb RequestVerb() const { return m_Verb; } + inline HttpContentType RequestContentType() { return m_ContentType; } + inline HttpContentType AcceptContentType() { return m_AcceptType; } + + inline uint64_t ContentLength() const { return m_ContentLength; } + Oid SessionId() const; + uint32_t RequestId() const; + + inline bool IsHandled() const { return !!(m_Flags & kIsHandled); } + inline bool SuppressBody() const { return !!(m_Flags & kSuppressBody); } + inline void SetSuppressResponseBody() { m_Flags |= kSuppressBody; } + + /** Read POST/PUT payload for request body, which is always available without delay + */ + virtual IoBuffer ReadPayload() = 0; + + ZENCORE_API CbObject ReadPayloadObject(); + ZENCORE_API CbPackage ReadPayloadPackage(); + + /** Respond with payload + + No data will have been sent when any of these functions return. Instead, the response will be transmitted + asynchronously, after returning from a request handler function. + + Note that this is destructive in the sense that the IoBuffer instances referred to by Blobs will be + moved into our response handler array where they are kept alive, in order to reduce ref-counting storms + */ + virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span<IoBuffer> Blobs) = 0; + virtual void WriteResponse(HttpResponseCode ResponseCode) = 0; + virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::u8string_view ResponseString) = 0; + virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, CompositeBuffer& Payload); + + void WriteResponse(HttpResponseCode ResponseCode, CbObject Data); + void WriteResponse(HttpResponseCode ResponseCode, CbArray Array); + void WriteResponse(HttpResponseCode ResponseCode, CbPackage Package); + void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::string_view ResponseString); + void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, IoBuffer Blob); + + virtual void WriteResponseAsync(std::function<void(HttpServerRequest&)>&& ContinuationHandler) = 0; + +protected: + enum + { + kIsHandled = 1 << 0, + kSuppressBody = 1 << 1, + kHaveRequestId = 1 << 2, + kHaveSessionId = 1 << 3, + }; + + mutable uint32_t m_Flags = 0; + HttpVerb m_Verb = HttpVerb::kGet; + HttpContentType m_ContentType = HttpContentType::kBinary; + HttpContentType m_AcceptType = HttpContentType::kUnknownContentType; + uint64_t m_ContentLength = ~0ull; + std::string_view m_Uri; + std::string_view m_UriWithExtension; + std::string_view m_QueryString; + mutable uint32_t m_RequestId = ~uint32_t(0); + mutable Oid m_SessionId = Oid::Zero; + + inline void SetIsHandled() { m_Flags |= kIsHandled; } + + virtual Oid ParseSessionId() const = 0; + virtual uint32_t ParseRequestId() const = 0; +}; + +class IHttpPackageHandler : public RefCounted +{ +public: + virtual void FilterOffer(std::vector<IoHash>& OfferCids) = 0; + virtual void OnRequestBegin() = 0; + virtual IoBuffer CreateTarget(const IoHash& Cid, uint64_t StorageSize) = 0; + virtual void OnRequestComplete() = 0; +}; + +/** + * Base class for implementing an HTTP "service" + * + * A service exposes one or more endpoints with a certain URI prefix + * + */ + +class HttpService +{ +public: + HttpService() = default; + virtual ~HttpService() = default; + + virtual const char* BaseUri() const = 0; + virtual void HandleRequest(HttpServerRequest& HttpServiceRequest) = 0; + virtual Ref<IHttpPackageHandler> HandlePackageRequest(HttpServerRequest& HttpServiceRequest); + + // Internals + + inline void SetUriPrefixLength(size_t PrefixLength) { m_UriPrefixLength = (int)PrefixLength; } + inline int UriPrefixLength() const { return m_UriPrefixLength; } + +private: + int m_UriPrefixLength = 0; +}; + +/** HTTP server + * + * Implements the main event loop to service HTTP requests, and handles routing + * requests to the appropriate handler as registered via RegisterService + */ +class HttpServer : public RefCounted +{ +public: + virtual void RegisterService(HttpService& Service) = 0; + virtual int Initialize(int BasePort) = 0; + virtual void Run(bool IsInteractiveSession) = 0; + virtual void RequestExit() = 0; +}; + +Ref<HttpServer> CreateHttpServer(std::string_view ServerClass); + +////////////////////////////////////////////////////////////////////////// + +class HttpRouterRequest +{ +public: + HttpRouterRequest(HttpServerRequest& Request) : m_HttpRequest(Request) {} + + ZENCORE_API std::string GetCapture(uint32_t Index) const; + inline HttpServerRequest& ServerRequest() { return m_HttpRequest; } + +private: + using MatchResults_t = std::match_results<std::string_view::const_iterator>; + + HttpServerRequest& m_HttpRequest; + MatchResults_t m_Match; + + friend class HttpRequestRouter; +}; + +inline std::string +HttpRouterRequest::GetCapture(uint32_t Index) const +{ + ZEN_ASSERT(Index < m_Match.size()); + + return m_Match[Index]; +} + +/** HTTP request router helper + * + * This helper class allows a service implementer to register one or more + * endpoints using pattern matching (currently using regex matching) + * + * This is intended to be initialized once only, there is no thread + * safety so you can absolutely not add or remove endpoints once the handler + * goes live + */ + +class HttpRequestRouter +{ +public: + typedef std::function<void(HttpRouterRequest&)> HandlerFunc_t; + + /** + * @brief Add pattern which can be referenced by name, commonly used for URL components + * @param Id String used to identify patterns for replacement + * @param Regex String which will replace the Id string in any registered URL paths + */ + void AddPattern(const char* Id, const char* Regex); + + /** + * @brief Register a an endpoint handler for the given route + * @param Regex Regular expression used to match the handler to a request. This may + * contain pattern aliases registered via AddPattern + * @param HandlerFunc Handler function to call for any matching request + * @param SupportedVerbs Supported HTTP verbs for this handler + */ + void RegisterRoute(const char* Regex, HandlerFunc_t&& HandlerFunc, HttpVerb SupportedVerbs); + + void ProcessRegexSubstitutions(const char* Regex, StringBuilderBase& ExpandedRegex); + + /** + * @brief HTTP request handling function - this should be called to route the + * request to a registered handler + * @param Request Request to route to a handler + * @return Function returns true if the request was routed successfully + */ + bool HandleRequest(zen::HttpServerRequest& Request); + +private: + struct HandlerEntry + { + HandlerEntry(const char* Regex, HttpVerb SupportedVerbs, HandlerFunc_t&& Handler, const char* Pattern) + : RegEx(Regex, std::regex::icase | std::regex::ECMAScript) + , Verbs(SupportedVerbs) + , Handler(std::move(Handler)) + , Pattern(Pattern) + { + } + + ~HandlerEntry() = default; + + std::regex RegEx; + HttpVerb Verbs; + HandlerFunc_t Handler; + const char* Pattern; + + private: + HandlerEntry& operator=(const HandlerEntry&) = delete; + HandlerEntry(const HandlerEntry&) = delete; + }; + + std::list<HandlerEntry> m_Handlers; + std::unordered_map<std::string, std::string> m_PatternMap; +}; + +/** HTTP RPC request helper + */ + +class RpcResult +{ + RpcResult(CbObject Result) : m_Result(std::move(Result)) {} + +private: + CbObject m_Result; +}; + +class HttpRpcHandler +{ +public: + HttpRpcHandler(); + ~HttpRpcHandler(); + + HttpRpcHandler(const HttpRpcHandler&) = delete; + HttpRpcHandler operator=(const HttpRpcHandler&) = delete; + + void AddRpc(std::string_view RpcId, std::function<void(CbObject& RpcArgs)> HandlerFunction); + +private: + struct RpcFunction + { + std::function<void(CbObject& RpcArgs)> Function; + std::string Identifier; + }; + + std::map<std::string, RpcFunction> m_Functions; +}; + +bool HandlePackageOffers(HttpService& Service, HttpServerRequest& Request, Ref<IHttpPackageHandler>& PackageHandlerRef); + +void http_forcelink(); // internal + +} // namespace zen diff --git a/src/zenhttp/include/zenhttp/httpshared.h b/src/zenhttp/include/zenhttp/httpshared.h new file mode 100644 index 000000000..d335572c5 --- /dev/null +++ b/src/zenhttp/include/zenhttp/httpshared.h @@ -0,0 +1,163 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/compactbinarypackage.h> +#include <zencore/iobuffer.h> +#include <zencore/iohash.h> + +#include <functional> +#include <gsl/gsl-lite.hpp> + +namespace zen { + +class IoBuffer; +class CbPackage; +class CompositeBuffer; + +/** _____ _ _____ _ + / ____| | | __ \ | | + | | | |__ | |__) |_ _ ___| | ____ _ __ _ ___ + | | | '_ \| ___/ _` |/ __| |/ / _` |/ _` |/ _ \ + | |____| |_) | | | (_| | (__| < (_| | (_| | __/ + \_____|_.__/|_| \__,_|\___|_|\_\__,_|\__, |\___| + __/ | + |___/ + + Structures and code related to handling CbPackage transactions + + CbPackage instances are marshaled across the wire using a distinct message + format. We don't use the CbPackage serialization format provided by the + CbPackage implementation itself since that does not provide much flexibility + in how the attachment payloads are transmitted. The scheme below separates + metadata cleanly from payloads and this enables us to more efficiently + transmit them either via sendfile/TransmitFile like mechanisms, or by + reference/memory mapping in the local case. + */ + +struct CbPackageHeader +{ + uint32_t HeaderMagic; + uint32_t AttachmentCount; // TODO: should add ability to opt out of implicit root document? + uint32_t Reserved1; + uint32_t Reserved2; +}; + +static_assert(sizeof(CbPackageHeader) == 16); + +enum : uint32_t +{ + kCbPkgMagic = 0xaa77aacc +}; + +struct CbAttachmentEntry +{ + uint64_t PayloadSize; // Size of the associated payload data in the message + uint32_t Flags; // See flags below + IoHash AttachmentHash; // Content Id for the attachment + + enum + { + kIsCompressed = (1u << 0), // Is marshaled using compressed buffer storage format + kIsObject = (1u << 1), // Is compact binary object + kIsError = (1u << 2), // Is error (compact binary formatted) object + kIsLocalRef = (1u << 3), // Is "local reference" + }; +}; + +struct CbAttachmentReferenceHeader +{ + uint64_t PayloadByteOffset = 0; + uint64_t PayloadByteSize = ~0u; + uint16_t AbsolutePathLength = 0; + + // This header will be followed by UTF8 encoded absolute path to backing file +}; + +static_assert(sizeof(CbAttachmentEntry) == 32); + +enum class FormatFlags +{ + kDefault = 0, + kAllowLocalReferences = (1u << 0), + kDenyPartialLocalReferences = (1u << 1) +}; + +gsl_DEFINE_ENUM_BITMASK_OPERATORS(FormatFlags); + +enum class RpcAcceptOptions : uint16_t +{ + kNone = 0, + kAllowLocalReferences = (1u << 0), + kAllowPartialLocalReferences = (1u << 1) +}; + +gsl_DEFINE_ENUM_BITMASK_OPERATORS(RpcAcceptOptions); + +std::vector<IoBuffer> FormatPackageMessage(const CbPackage& Data, FormatFlags Flags, int TargetProcessPid = 0); +CompositeBuffer FormatPackageMessageBuffer(const CbPackage& Data, FormatFlags Flags, int TargetProcessPid = 0); +CbPackage ParsePackageMessage( + IoBuffer Payload, + std::function<IoBuffer(const IoHash& Cid, uint64_t Size)> CreateBuffer = [](const IoHash&, uint64_t Size) -> IoBuffer { + return IoBuffer{Size}; + }); +bool IsPackageMessage(IoBuffer Payload); + +bool ParsePackageMessageWithLegacyFallback(const IoBuffer& Response, CbPackage& OutPackage); + +std::vector<IoBuffer> FormatPackageMessage(const CbPackage& Data, int TargetProcessPid = 0); +CompositeBuffer FormatPackageMessageBuffer(const CbPackage& Data, int TargetProcessPid = 0); + +/** Streaming reader for compact binary packages + + The goal is to ultimately support zero-copy I/O, but for now there'll be some + copying involved on some platforms at least. + + This approach to deserializing CbPackage data is more efficient than + `ParsePackageMessage` since it does not require the entire message to + be resident in a memory buffer + + */ +class CbPackageReader +{ +public: + CbPackageReader(); + ~CbPackageReader(); + + void SetPayloadBufferCreator(std::function<IoBuffer(const IoHash& Cid, uint64_t Size)> CreateBuffer); + + /** Process compact binary package data stream + + The data stream must be in the serialization format produced by FormatPackageMessage + + \return How many bytes must be fed to this function in the next call + */ + uint64_t ProcessPackageHeaderData(const void* Data, uint64_t DataBytes); + + void Finalize(); + const std::vector<CbAttachment>& GetAttachments() { return m_Attachments; } + CbObject GetRootObject() { return m_RootObject; } + std::span<IoBuffer> GetPayloadBuffers() { return m_PayloadBuffers; } + +private: + enum class State + { + kInitialState, + kReadingHeader, + kReadingAttachmentEntries, + kReadingBuffers + } m_CurrentState = State::kInitialState; + + std::function<IoBuffer(const IoHash& Cid, uint64_t Size)> m_CreateBuffer; + std::vector<IoBuffer> m_PayloadBuffers; + std::vector<CbAttachmentEntry> m_AttachmentEntries; + std::vector<CbAttachment> m_Attachments; + CbObject m_RootObject; + CbPackageHeader m_PackageHeader; + + IoBuffer MarshalLocalChunkReference(IoBuffer AttachmentBuffer); +}; + +void forcelink_httpshared(); + +} // namespace zen diff --git a/src/zenhttp/include/zenhttp/websocket.h b/src/zenhttp/include/zenhttp/websocket.h new file mode 100644 index 000000000..adca7e988 --- /dev/null +++ b/src/zenhttp/include/zenhttp/websocket.h @@ -0,0 +1,256 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zencore/compactbinarypackage.h> +#include <zencore/memory.h> + +#include <compare> +#include <functional> +#include <future> +#include <memory> +#include <optional> + +#pragma once + +namespace asio { +class io_context; +} + +namespace zen { + +class BinaryWriter; + +/** + * A unique socket ID. + */ +class WebSocketId +{ + static std::atomic_uint32_t NextId; + +public: + WebSocketId() = default; + + uint32_t Value() const { return m_Value; } + + auto operator<=>(const WebSocketId&) const = default; + + static WebSocketId New() { return WebSocketId(NextId.fetch_add(1)); } + +private: + WebSocketId(uint32_t Value) : m_Value(Value) {} + + uint32_t m_Value{}; +}; + +/** + * Type of web socket message. + */ +enum class WebSocketMessageType : uint8_t +{ + kInvalid, + kNotification, + kRequest, + kStreamRequest, + kResponse, + kStreamResponse, + kStreamCompleteResponse, + kCount +}; + +inline std::string_view +ToString(WebSocketMessageType Type) +{ + switch (Type) + { + case WebSocketMessageType::kInvalid: + return std::string_view("Invalid"); + case WebSocketMessageType::kNotification: + return std::string_view("Notification"); + case WebSocketMessageType::kRequest: + return std::string_view("Request"); + case WebSocketMessageType::kStreamRequest: + return std::string_view("StreamRequest"); + case WebSocketMessageType::kResponse: + return std::string_view("Response"); + case WebSocketMessageType::kStreamResponse: + return std::string_view("StreamResponse"); + case WebSocketMessageType::kStreamCompleteResponse: + return std::string_view("StreamCompleteResponse"); + default: + return std::string_view("Unknown"); + }; +} + +/** + * Web socket message. + */ +class WebSocketMessage +{ + struct Header + { + static constexpr uint32_t ExpectedMagic = 0x7a776d68; // zwmh + + uint64_t MessageSize{}; + uint32_t Magic{ExpectedMagic}; + uint32_t CorrelationId{}; + uint32_t StatusCode{200u}; + WebSocketMessageType MessageType{}; + uint8_t Reserved[3] = {0}; + + bool IsValid() const; + }; + + static_assert(sizeof(Header) == 24); + + static std::atomic_uint32_t NextCorrelationId; + +public: + static constexpr size_t HeaderSize = sizeof(Header); + + WebSocketMessage() = default; + + WebSocketId SocketId() const { return m_SocketId; } + void SetSocketId(WebSocketId Id) { m_SocketId = Id; } + uint64_t MessageSize() const { return m_Header.MessageSize; } + void SetMessageType(WebSocketMessageType MessageType); + void SetCorrelationId(uint32_t Id) { m_Header.CorrelationId = Id; } + uint32_t CorrelationId() const { return m_Header.CorrelationId; } + uint32_t StatusCode() const { return m_Header.StatusCode; } + void SetStatusCode(uint32_t StatusCode) { m_Header.StatusCode = StatusCode; } + WebSocketMessageType MessageType() const { return m_Header.MessageType; } + + const CbPackage& Body() const { return m_Body.value(); } + void SetBody(CbPackage&& Body); + void SetBody(CbObject&& Body); + bool HasBody() const { return m_Body.has_value(); } + + void Save(BinaryWriter& Writer); + bool TryLoadHeader(MemoryView Memory); + + bool IsValid() const { return m_Header.MessageType != WebSocketMessageType::kInvalid; } + +private: + Header m_Header{}; + WebSocketId m_SocketId{}; + std::optional<CbPackage> m_Body; +}; + +class WebSocketServer; + +/** + * Base class for handling web socket requests and notifications from connected client(s). + */ +class WebSocketService +{ +public: + virtual ~WebSocketService() = default; + + void Configure(WebSocketServer& Server); + + virtual bool HandleRequest(const WebSocketMessage&) { ZEN_ASSERT(false); } + virtual void HandleNotification(const WebSocketMessage&) { ZEN_ASSERT(false); } + +protected: + WebSocketService() = default; + + virtual void RegisterHandlers(WebSocketServer& Server) = 0; + void SendStreamResponse(WebSocketId SocketId, uint32_t CorrelationId, CbPackage&& StreamResponse, bool IsStreamComplete); + void SendStreamResponse(WebSocketId SocketId, uint32_t CorrelationId, CbObject&& StreamResponse, bool IsStreamComplete); + + WebSocketServer& SocketServer() + { + ZEN_ASSERT(m_SocketServer); + return *m_SocketServer; + } + +private: + WebSocketServer* m_SocketServer{}; +}; + +/** + * Server options. + */ +struct WebSocketServerOptions +{ + uint16_t Port = 2337; + uint32_t ThreadCount = 1; +}; + +/** + * The web socket server manages client connections and routing of requests and notifications. + */ +class WebSocketServer +{ +public: + virtual ~WebSocketServer() = default; + + virtual bool Run() = 0; + virtual void Shutdown() = 0; + + virtual void RegisterService(WebSocketService& Service) = 0; + virtual void RegisterNotificationHandler(std::string_view Key, WebSocketService& Service) = 0; + virtual void RegisterRequestHandler(std::string_view Key, WebSocketService& Service) = 0; + + virtual void SendNotification(WebSocketMessage&& Notification) = 0; + virtual void SendResponse(WebSocketMessage&& Response) = 0; + + static std::unique_ptr<WebSocketServer> Create(const WebSocketServerOptions& Options); +}; + +/** + * The state of the web socket. + */ +enum class WebSocketState : uint32_t +{ + kNone, + kHandshaking, + kConnected, + kDisconnected, + kError +}; + +/** + * Type of web socket client event. + */ +enum class WebSocketEvent : uint32_t +{ + kConnected, + kDisconnected, + kError +}; + +/** + * Web socket client connection info. + */ +struct WebSocketConnectInfo +{ + std::string Host; + int16_t Port{8848}; + std::string Endpoint; + std::vector<std::string> Protocols; + uint16_t Version{13}; +}; + +/** + * A connection to a web socket server for sending requests and listening for notifications. + */ +class WebSocketClient +{ +public: + using EventCallback = std::function<void()>; + using NotificationCallback = std::function<void(WebSocketMessage&&)>; + + virtual ~WebSocketClient() = default; + + virtual std::future<bool> Connect(const WebSocketConnectInfo& Info) = 0; + virtual void Disconnect() = 0; + virtual bool IsConnected() const = 0; + virtual WebSocketState State() const = 0; + + virtual std::future<WebSocketMessage> SendRequest(WebSocketMessage&& Request) = 0; + virtual void OnNotification(NotificationCallback&& Cb) = 0; + virtual void OnEvent(WebSocketEvent Evt, EventCallback&& Cb) = 0; + + static std::shared_ptr<WebSocketClient> Create(asio::io_context& IoCtx); +}; + +} // namespace zen diff --git a/src/zenhttp/include/zenhttp/zenhttp.h b/src/zenhttp/include/zenhttp/zenhttp.h new file mode 100644 index 000000000..59c64b31f --- /dev/null +++ b/src/zenhttp/include/zenhttp/zenhttp.h @@ -0,0 +1,21 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/zencore.h> + +#ifndef ZEN_WITH_HTTPSYS +# if ZEN_PLATFORM_WINDOWS +# define ZEN_WITH_HTTPSYS 1 +# else +# define ZEN_WITH_HTTPSYS 0 +# endif +#endif + +#define ZENHTTP_API // Placeholder to allow DLL configs in the future + +namespace zen { + +ZENHTTP_API void zenhttp_forcelinktests(); + +} diff --git a/src/zenhttp/iothreadpool.cpp b/src/zenhttp/iothreadpool.cpp new file mode 100644 index 000000000..6087e69ec --- /dev/null +++ b/src/zenhttp/iothreadpool.cpp @@ -0,0 +1,49 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "iothreadpool.h" + +#include <zencore/except.h> + +#if ZEN_PLATFORM_WINDOWS + +namespace zen { + +WinIoThreadPool::WinIoThreadPool(int InThreadCount) +{ + // Thread pool setup + + m_ThreadPool = CreateThreadpool(NULL); + + SetThreadpoolThreadMinimum(m_ThreadPool, InThreadCount); + SetThreadpoolThreadMaximum(m_ThreadPool, InThreadCount * 2); + + InitializeThreadpoolEnvironment(&m_CallbackEnvironment); + + m_CleanupGroup = CreateThreadpoolCleanupGroup(); + + SetThreadpoolCallbackPool(&m_CallbackEnvironment, m_ThreadPool); + + SetThreadpoolCallbackCleanupGroup(&m_CallbackEnvironment, m_CleanupGroup, NULL); +} + +WinIoThreadPool::~WinIoThreadPool() +{ + CloseThreadpool(m_ThreadPool); +} + +void +WinIoThreadPool::CreateIocp(HANDLE IoHandle, PTP_WIN32_IO_CALLBACK Callback, void* Context, std::error_code& ErrorCode) +{ + ZEN_ASSERT(!m_ThreadPoolIo); + + m_ThreadPoolIo = CreateThreadpoolIo(IoHandle, Callback, Context, &m_CallbackEnvironment); + + if (!m_ThreadPoolIo) + { + ErrorCode = MakeErrorCodeFromLastError(); + } +} + +} // namespace zen + +#endif diff --git a/src/zenhttp/iothreadpool.h b/src/zenhttp/iothreadpool.h new file mode 100644 index 000000000..8333964c3 --- /dev/null +++ b/src/zenhttp/iothreadpool.h @@ -0,0 +1,37 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/zencore.h> + +#if ZEN_PLATFORM_WINDOWS +# include <zencore/windows.h> + +# include <system_error> + +namespace zen { + +////////////////////////////////////////////////////////////////////////// +// +// Thread pool. Implemented in terms of Windows thread pool right now, will +// need a cross-platform implementation eventually +// + +class WinIoThreadPool +{ +public: + WinIoThreadPool(int InThreadCount); + ~WinIoThreadPool(); + + void CreateIocp(HANDLE IoHandle, PTP_WIN32_IO_CALLBACK Callback, void* Context, std::error_code& ErrorCode); + inline PTP_IO Iocp() const { return m_ThreadPoolIo; } + +private: + PTP_POOL m_ThreadPool = nullptr; + PTP_CLEANUP_GROUP m_CleanupGroup = nullptr; + PTP_IO m_ThreadPoolIo = nullptr; + TP_CALLBACK_ENVIRON m_CallbackEnvironment; +}; + +} // namespace zen +#endif diff --git a/src/zenhttp/websocketasio.cpp b/src/zenhttp/websocketasio.cpp new file mode 100644 index 000000000..bbe7e1ad8 --- /dev/null +++ b/src/zenhttp/websocketasio.cpp @@ -0,0 +1,1613 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zenhttp/websocket.h> + +#include <zencore/base64.h> +#include <zencore/compactbinarybuilder.h> +#include <zencore/compactbinaryvalidation.h> +#include <zencore/intmath.h> +#include <zencore/iobuffer.h> +#include <zencore/logging.h> +#include <zencore/memory.h> +#include <zencore/sha1.h> +#include <zencore/stream.h> +#include <zencore/string.h> +#include <zencore/trace.h> + +#include <chrono> +#include <optional> +#include <shared_mutex> +#include <span> +#include <system_error> +#include <thread> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <fmt/format.h> +#include <http_parser.h> +#include <asio.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + +#if ZEN_PLATFORM_WINDOWS +# include <mstcpip.h> +#endif + +namespace zen::websocket { + +using namespace std::literals; + +ZEN_DEFINE_LOG_CATEGORY_STATIC(LogWebSocket, "websocket"sv); + +ZEN_DEFINE_LOG_CATEGORY_STATIC(LogWsClient, "ws-client"sv); + +using Clock = std::chrono::steady_clock; +using TimePoint = Clock::time_point; + +/////////////////////////////////////////////////////////////////////////////// +namespace http_header { + static constexpr std::string_view SecWebSocketKey = "Sec-WebSocket-Key"sv; + static constexpr std::string_view SecWebSocketOrigin = "Sec-WebSocket-Origin"sv; + static constexpr std::string_view SecWebSocketProtocol = "Sec-WebSocket-Protocol"sv; + static constexpr std::string_view SecWebSocketVersion = "Sec-WebSocket-Version"sv; + static constexpr std::string_view SecWebSocketAccept = "Sec-WebSocket-Accept"sv; + static constexpr std::string_view Upgrade = "Upgrade"sv; +} // namespace http_header + +/////////////////////////////////////////////////////////////////////////////// +enum class ParseMessageStatus : uint32_t +{ + kError, + kContinue, + kDone, +}; + +struct ParseMessageResult +{ + ParseMessageStatus Status{}; + size_t ByteCount{}; + std::optional<std::string> Reason; +}; + +class MessageParser +{ +public: + virtual ~MessageParser() = default; + + ParseMessageResult ParseMessage(MemoryView Msg); + void Reset(); + +protected: + MessageParser() = default; + + virtual ParseMessageResult OnParseMessage(MemoryView Msg) = 0; + virtual void OnReset() = 0; + + BinaryWriter m_Stream; +}; + +ParseMessageResult +MessageParser::ParseMessage(MemoryView Msg) +{ + return OnParseMessage(Msg); +} + +void +MessageParser::Reset() +{ + OnReset(); + + m_Stream.Reset(); +} + +/////////////////////////////////////////////////////////////////////////////// +enum class HttpMessageParserType +{ + kRequest, + kResponse, + kBoth +}; + +class HttpMessageParser final : public MessageParser +{ +public: + using HttpHeaders = std::unordered_map<std::string_view, std::string_view>; + + HttpMessageParser(HttpMessageParserType Type) : MessageParser(), m_Type(Type) { Initialize(); } + + virtual ~HttpMessageParser() = default; + + int32_t StatusCode() const { return m_Parser.status_code; } + bool IsUpgrade() const { return m_Parser.upgrade != 0; } + HttpHeaders& Headers() { return m_Headers; } + MemoryView Body() const { return MemoryView(m_Stream.Data() + m_BodyEntry.Offset, m_BodyEntry.Size); } + + std::string_view StatusText() const + { + return std::string_view(reinterpret_cast<const char*>(m_Stream.Data() + m_StatusEntry.Offset), m_StatusEntry.Size); + } + + bool ValidateWebSocketHandshake(std::string& OutAcceptHash, std::string& OutReason); + +private: + void Initialize(); + virtual ParseMessageResult OnParseMessage(MemoryView Msg) override; + virtual void OnReset() override; + int OnMessageBegin(); + int OnUrl(MemoryView Url); + int OnStatus(MemoryView Status); + int OnHeaderField(MemoryView HeaderField); + int OnHeaderValue(MemoryView HeaderValue); + int OnHeadersComplete(); + int OnBody(MemoryView Body); + int OnMessageComplete(); + + struct StreamEntry + { + uint64_t Offset{}; + uint64_t Size{}; + }; + + struct HeaderStreamEntry + { + StreamEntry Field{}; + StreamEntry Value{}; + }; + + HttpMessageParserType m_Type; + http_parser m_Parser; + StreamEntry m_UrlEntry; + StreamEntry m_StatusEntry; + StreamEntry m_BodyEntry; + HeaderStreamEntry m_CurrentHeader; + std::vector<HeaderStreamEntry> m_HeaderEntries; + HttpHeaders m_Headers; + bool m_IsMsgComplete{false}; + + static http_parser_settings ParserSettings; +}; + +http_parser_settings HttpMessageParser::ParserSettings = { + .on_message_begin = [](http_parser* P) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnMessageBegin(); }, + + .on_url = [](http_parser* P, + const char* Data, + size_t Size) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnUrl(MemoryView(Data, Size)); }, + + .on_status = [](http_parser* P, + const char* Data, + size_t Size) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnStatus(MemoryView(Data, Size)); }, + + .on_header_field = [](http_parser* P, + const char* Data, + size_t Size) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnHeaderField(MemoryView(Data, Size)); }, + + .on_header_value = [](http_parser* P, + const char* Data, + size_t Size) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnHeaderValue(MemoryView(Data, Size)); }, + + .on_headers_complete = [](http_parser* P) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnHeadersComplete(); }, + + .on_body = [](http_parser* P, + const char* Data, + size_t Size) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnBody(MemoryView(Data, Size)); }, + + .on_message_complete = [](http_parser* P) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnMessageComplete(); }}; + +void +HttpMessageParser::Initialize() +{ + http_parser_init(&m_Parser, + m_Type == HttpMessageParserType::kRequest ? HTTP_REQUEST + : m_Type == HttpMessageParserType::kResponse ? HTTP_RESPONSE + : HTTP_BOTH); + m_Parser.data = this; + + m_UrlEntry = {}; + m_StatusEntry = {}; + m_CurrentHeader = {}; + m_BodyEntry = {}; + + m_IsMsgComplete = false; + + m_HeaderEntries.clear(); +} + +ParseMessageResult +HttpMessageParser::OnParseMessage(MemoryView Msg) +{ + const size_t ByteCount = http_parser_execute(&m_Parser, &ParserSettings, reinterpret_cast<const char*>(Msg.GetData()), Msg.GetSize()); + + auto Status = m_IsMsgComplete ? ParseMessageStatus::kDone : ParseMessageStatus::kContinue; + + if (m_Parser.http_errno != 0) + { + Status = ParseMessageStatus::kError; + } + + return {.Status = Status, .ByteCount = uint64_t(ByteCount)}; +} + +void +HttpMessageParser::OnReset() +{ + Initialize(); +} + +int +HttpMessageParser::OnMessageBegin() +{ + ZEN_ASSERT(m_IsMsgComplete == false); + ZEN_ASSERT(m_HeaderEntries.empty()); + ZEN_ASSERT(m_Headers.empty()); + + return 0; +} + +int +HttpMessageParser::OnStatus(MemoryView Status) +{ + m_StatusEntry = {m_Stream.CurrentOffset(), Status.GetSize()}; + + m_Stream.Write(Status); + + return 0; +} + +int +HttpMessageParser::OnUrl(MemoryView Url) +{ + m_UrlEntry = {m_Stream.CurrentOffset(), Url.GetSize()}; + + m_Stream.Write(Url); + + return 0; +} + +int +HttpMessageParser::OnHeaderField(MemoryView HeaderField) +{ + if (m_CurrentHeader.Value.Size > 0) + { + m_HeaderEntries.push_back(m_CurrentHeader); + m_CurrentHeader = {}; + } + + if (m_CurrentHeader.Field.Size == 0) + { + m_CurrentHeader.Field.Offset = m_Stream.CurrentOffset(); + } + + m_CurrentHeader.Field.Size += HeaderField.GetSize(); + + m_Stream.Write(HeaderField); + + return 0; +} + +int +HttpMessageParser::OnHeaderValue(MemoryView HeaderValue) +{ + if (m_CurrentHeader.Value.Size == 0) + { + m_CurrentHeader.Value.Offset = m_Stream.CurrentOffset(); + } + + m_CurrentHeader.Value.Size += HeaderValue.GetSize(); + + m_Stream.Write(HeaderValue); + + return 0; +} + +int +HttpMessageParser::OnHeadersComplete() +{ + if (m_CurrentHeader.Value.Size > 0) + { + m_HeaderEntries.push_back(m_CurrentHeader); + m_CurrentHeader = {}; + } + + m_Headers.clear(); + m_Headers.reserve(m_HeaderEntries.size()); + + const char* StreamData = reinterpret_cast<const char*>(m_Stream.Data()); + + for (const auto& Entry : m_HeaderEntries) + { + auto Field = std::string_view(StreamData + Entry.Field.Offset, Entry.Field.Size); + auto Value = std::string_view(StreamData + Entry.Value.Offset, Entry.Value.Size); + + m_Headers.try_emplace(std::move(Field), std::move(Value)); + } + + return 0; +} + +int +HttpMessageParser::OnBody(MemoryView Body) +{ + m_BodyEntry = {m_Stream.CurrentOffset(), Body.GetSize()}; + + m_Stream.Write(Body); + + return 0; +} + +int +HttpMessageParser::OnMessageComplete() +{ + m_IsMsgComplete = true; + + return 0; +} + +bool +HttpMessageParser::ValidateWebSocketHandshake(std::string& OutAcceptHash, std::string& OutReason) +{ + static constexpr std::string_view WebSocketGuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"sv; + + OutAcceptHash = std::string(); + + if (m_Headers.contains(http_header::SecWebSocketKey) == false) + { + OutReason = "Missing header Sec-WebSocket-Key"; + return false; + } + + if (m_Headers.contains(http_header::Upgrade) == false) + { + OutReason = "Missing header Upgrade"; + return false; + } + + ExtendableStringBuilder<128> Sb; + Sb << m_Headers[http_header::SecWebSocketKey] << WebSocketGuid; + + SHA1Stream HashStream; + HashStream.Append(Sb.Data(), Sb.Size()); + + SHA1 Hash = HashStream.GetHash(); + + OutAcceptHash.resize(Base64::GetEncodedDataSize(sizeof(SHA1::Hash))); + Base64::Encode(Hash.Hash, sizeof(SHA1::Hash), OutAcceptHash.data()); + + return true; +} + +/////////////////////////////////////////////////////////////////////////////// +class WebSocketMessageParser final : public MessageParser +{ +public: + WebSocketMessageParser() : MessageParser() {} + + WebSocketMessage ConsumeMessage(); + +private: + virtual ParseMessageResult OnParseMessage(MemoryView Msg) override; + virtual void OnReset() override; + + WebSocketMessage m_Message; +}; + +ParseMessageResult +WebSocketMessageParser::OnParseMessage(MemoryView Msg) +{ + ZEN_TRACE_CPU("WebSocketMessageParser::OnParseMessage"); + + const uint64_t PrevOffset = m_Stream.CurrentOffset(); + + if (m_Stream.CurrentOffset() < WebSocketMessage::HeaderSize) + { + const uint64_t RemaingHeaderSize = WebSocketMessage::HeaderSize - m_Stream.CurrentOffset(); + + m_Stream.Write(Msg.Left(RemaingHeaderSize)); + Msg += RemaingHeaderSize; + + if (m_Stream.CurrentOffset() < WebSocketMessage::HeaderSize) + { + return {.Status = ParseMessageStatus::kContinue, .ByteCount = m_Stream.CurrentOffset() - PrevOffset}; + } + + const bool IsValidHeader = m_Message.TryLoadHeader(m_Stream.GetView()); + + if (IsValidHeader == false) + { + OnReset(); + + return {.Status = ParseMessageStatus::kError, + .ByteCount = m_Stream.CurrentOffset() - PrevOffset, + .Reason = std::string("Invalid websocket message header")}; + } + + if (m_Message.MessageSize() == 0) + { + return {.Status = ParseMessageStatus::kDone, .ByteCount = m_Stream.CurrentOffset() - PrevOffset}; + } + } + + ZEN_ASSERT(m_Stream.CurrentOffset() >= WebSocketMessage::HeaderSize); + + if (Msg.IsEmpty() == false) + { + const uint64_t RemaingMessageSize = (WebSocketMessage::HeaderSize + m_Message.MessageSize()) - m_Stream.CurrentOffset(); + m_Stream.Write(Msg.Left(RemaingMessageSize)); + } + + auto Status = ParseMessageStatus::kContinue; + + if (m_Stream.CurrentOffset() == WebSocketMessage::HeaderSize + m_Message.MessageSize()) + { + Status = ParseMessageStatus::kDone; + + BinaryReader Reader(m_Stream.GetView().RightChop(WebSocketMessage::HeaderSize)); + + CbPackage Pkg; + if (Pkg.TryLoad(Reader) == false) + { + return {.Status = ParseMessageStatus::kError, + .ByteCount = m_Stream.CurrentOffset() - PrevOffset, + .Reason = std::string("Invalid websocket message")}; + } + + m_Message.SetBody(std::move(Pkg)); + } + + return {.Status = Status, .ByteCount = m_Stream.CurrentOffset() - PrevOffset}; +} + +void +WebSocketMessageParser::OnReset() +{ + m_Message = WebSocketMessage(); +} + +WebSocketMessage +WebSocketMessageParser::ConsumeMessage() +{ + WebSocketMessage Msg = std::move(m_Message); + m_Message = WebSocketMessage(); + + return Msg; +} + +/////////////////////////////////////////////////////////////////////////////// +class WsConnection : public std::enable_shared_from_this<WsConnection> +{ +public: + WsConnection(WebSocketId Id, std::unique_ptr<asio::ip::tcp::socket> Socket) + : m_Id(Id) + , m_Socket(std::move(Socket)) + , m_StartTime(Clock::now()) + , m_State() + { + } + + ~WsConnection() = default; + + std::shared_ptr<WsConnection> AsShared() { return shared_from_this(); } + + WebSocketId Id() const { return m_Id; } + asio::ip::tcp::socket& Socket() { return *m_Socket; } + TimePoint StartTime() const { return m_StartTime; } + WebSocketState State() const { return static_cast<WebSocketState>(m_State.load(std::memory_order_relaxed)); } + std::string RemoteAddr() const { return m_Socket->remote_endpoint().address().to_string(); } + asio::streambuf& ReadBuffer() { return m_ReadBuffer; } + WebSocketState SetState(WebSocketState NewState) { return static_cast<WebSocketState>(m_State.exchange(uint32_t(NewState))); } + WebSocketState Close(); + MessageParser* Parser() { return m_MsgParser.get(); } + void SetParser(std::unique_ptr<MessageParser>&& Parser) { m_MsgParser = std::move(Parser); } + std::mutex& WriteMutex() { return m_WriteMutex; } + +private: + WebSocketId m_Id; + std::unique_ptr<asio::ip::tcp::socket> m_Socket; + TimePoint m_StartTime; + std::atomic_uint32_t m_State; + std::unique_ptr<MessageParser> m_MsgParser; + asio::streambuf m_ReadBuffer; + std::mutex m_WriteMutex; +}; + +WebSocketState +WsConnection::Close() +{ + const auto PrevState = SetState(WebSocketState::kDisconnected); + + if (PrevState != WebSocketState::kDisconnected && m_Socket->is_open()) + { + m_Socket->close(); + } + + return PrevState; +} + +/////////////////////////////////////////////////////////////////////////////// +class WsThreadPool +{ +public: + WsThreadPool(asio::io_service& IoSvc) : m_IoSvc(IoSvc) {} + void Start(uint32_t ThreadCount); + void Stop(); + +private: + asio::io_service& m_IoSvc; + std::vector<std::thread> m_Threads; + std::atomic_bool m_Running{false}; +}; + +void +WsThreadPool::Start(uint32_t ThreadCount) +{ + ZEN_ASSERT(m_Threads.empty()); + + ZEN_LOG_DEBUG(LogWebSocket, "starting '{}' websocket I/O thread(s)", ThreadCount); + + m_Running = true; + + for (uint32_t Idx = 0; Idx < ThreadCount; Idx++) + { + m_Threads.emplace_back([this, ThreadId = Idx + 1] { + for (;;) + { + if (m_Running == false) + { + break; + } + + try + { + m_IoSvc.run(); + } + catch (std::exception& Err) + { + ZEN_LOG_ERROR(LogWebSocket, "process websocket I/O FAILED, reason '{}'", Err.what()); + } + } + + ZEN_LOG_TRACE(LogWebSocket, "websocket I/O thread '{}' exiting", ThreadId); + }); + } +} + +void +WsThreadPool::Stop() +{ + if (m_Running) + { + m_Running = false; + + for (std::thread& Thread : m_Threads) + { + if (Thread.joinable()) + { + Thread.join(); + } + } + + m_Threads.clear(); + } +} + +/////////////////////////////////////////////////////////////////////////////// +class WsServer final : public WebSocketServer +{ +public: + WsServer(const WebSocketServerOptions& Options) : m_Options(Options) {} + virtual ~WsServer() { Shutdown(); } + + virtual bool Run() override; + virtual void Shutdown() override; + + virtual void RegisterService(WebSocketService& Service) override; + virtual void RegisterNotificationHandler(std::string_view Key, WebSocketService& Service) override; + virtual void RegisterRequestHandler(std::string_view Key, WebSocketService& Service) override; + + virtual void SendNotification(WebSocketMessage&& Notification) override; + virtual void SendResponse(WebSocketMessage&& Response) override; + +private: + friend class WsConnection; + + void AcceptConnection(); + void CloseConnection(std::shared_ptr<WsConnection> Connection, const std::error_code& Ec); + + void ReadMessage(std::shared_ptr<WsConnection> Connection); + void RouteMessage(WebSocketMessage&& Msg); + void SendMessage(WebSocketMessage&& Msg); + + struct IdHasher + { + size_t operator()(WebSocketId Id) const { return size_t(Id.Value()); } + }; + + using ConnectionMap = std::unordered_map<WebSocketId, std::shared_ptr<WsConnection>, IdHasher>; + using RequestHandlerMap = std::unordered_map<std::string_view, WebSocketService*>; + using NotificationHandlerMap = std::unordered_map<std::string_view, std::vector<WebSocketService*>>; + + WebSocketServerOptions m_Options; + asio::io_service m_IoSvc; + std::unique_ptr<asio::ip::tcp::acceptor> m_Acceptor; + std::unique_ptr<WsThreadPool> m_ThreadPool; + ConnectionMap m_Connections; + std::shared_mutex m_ConnMutex; + std::vector<WebSocketService*> m_Services; + RequestHandlerMap m_RequestHandlers; + NotificationHandlerMap m_NotificationHandlers; + std::atomic_bool m_Running{}; +}; + +void +WsServer::RegisterService(WebSocketService& Service) +{ + m_Services.push_back(&Service); + + Service.Configure(*this); +} + +bool +WsServer::Run() +{ + static constexpr size_t ReceiveBufferSize = 256 << 10; + static constexpr size_t SendBufferSize = 256 << 10; + + m_Acceptor = std::make_unique<asio::ip::tcp::acceptor>(m_IoSvc, asio::ip::tcp::v6()); + + m_Acceptor->set_option(asio::ip::v6_only(false)); + m_Acceptor->set_option(asio::socket_base::reuse_address(true)); + m_Acceptor->set_option(asio::ip::tcp::no_delay(true)); + m_Acceptor->set_option(asio::socket_base::receive_buffer_size(ReceiveBufferSize)); + m_Acceptor->set_option(asio::socket_base::send_buffer_size(SendBufferSize)); + +#if ZEN_PLATFORM_WINDOWS + // On Windows, loopback connections can take advantage of a faster code path optionally with this flag. + // This must be used by both the client and server side, and is only effective in the absence of + // Windows Filtering Platform (WFP) callouts which can be installed by security software. + // https://docs.microsoft.com/en-us/windows/win32/winsock/sio-loopback-fast-path + SOCKET NativeSocket = m_Acceptor->native_handle(); + int LoopbackOptionValue = 1; + DWORD OptionNumberOfBytesReturned = 0; + WSAIoctl(NativeSocket, + SIO_LOOPBACK_FAST_PATH, + &LoopbackOptionValue, + sizeof(LoopbackOptionValue), + NULL, + 0, + &OptionNumberOfBytesReturned, + 0, + 0); +#endif + + asio::error_code Ec; + m_Acceptor->bind(asio::ip::tcp::endpoint(asio::ip::address_v6::any(), m_Options.Port), Ec); + + if (Ec) + { + ZEN_LOG_ERROR(LogWebSocket, "failed to bind websocket endpoint, error code '{}'", Ec.value()); + + return false; + } + + m_Acceptor->listen(); + m_Running = true; + + ZEN_LOG_INFO(LogWebSocket, "web socket server running on port '{}'", m_Options.Port); + + AcceptConnection(); + + m_ThreadPool = std::make_unique<WsThreadPool>(m_IoSvc); + m_ThreadPool->Start(m_Options.ThreadCount); + + return true; +} + +void +WsServer::Shutdown() +{ + if (m_Running) + { + ZEN_LOG_INFO(LogWebSocket, "websocket server shutting down"); + + m_Running = false; + + m_Acceptor->close(); + m_Acceptor.reset(); + m_IoSvc.stop(); + + m_ThreadPool->Stop(); + } +} + +void +WsServer::RegisterNotificationHandler(std::string_view Key, WebSocketService& Service) +{ + auto Result = m_NotificationHandlers.try_emplace(Key, std::vector<WebSocketService*>()); + Result.first->second.push_back(&Service); +} + +void +WsServer::RegisterRequestHandler(std::string_view Key, WebSocketService& Service) +{ + m_RequestHandlers[Key] = &Service; +} + +void +WsServer::SendNotification(WebSocketMessage&& Notification) +{ + ZEN_ASSERT(Notification.MessageType() == WebSocketMessageType::kNotification); + + SendMessage(std::move(Notification)); +} +void +WsServer::SendResponse(WebSocketMessage&& Response) +{ + ZEN_ASSERT(Response.MessageType() == WebSocketMessageType::kResponse || + Response.MessageType() == WebSocketMessageType::kStreamResponse || + Response.MessageType() == WebSocketMessageType::kStreamCompleteResponse); + + ZEN_ASSERT(Response.CorrelationId() != 0); + + SendMessage(std::move(Response)); +} + +void +WsServer::AcceptConnection() +{ + auto Socket = std::make_unique<asio::ip::tcp::socket>(m_IoSvc); + asio::ip::tcp::socket& SocketRef = *Socket.get(); + + m_Acceptor->async_accept(SocketRef, [this, ConnectedSocket = std::move(Socket)](const asio::error_code& Ec) mutable { + if (m_Running) + { + if (Ec) + { + ZEN_LOG_WARN(LogWebSocket, "accept connection FAILED, reason '{}'", Ec.message()); + } + else + { + auto Connection = std::make_shared<WsConnection>(WebSocketId::New(), std::move(ConnectedSocket)); + + ZEN_LOG_DEBUG(LogWebSocket, "accept connection '#{} {}' OK", Connection->Id().Value(), Connection->RemoteAddr()); + + { + std::unique_lock _(m_ConnMutex); + m_Connections[Connection->Id()] = Connection; + } + + Connection->SetParser(std::make_unique<HttpMessageParser>(HttpMessageParserType::kRequest)); + Connection->SetState(WebSocketState::kHandshaking); + + ReadMessage(Connection); + } + + AcceptConnection(); + } + }); +} + +void +WsServer::CloseConnection(std::shared_ptr<WsConnection> Connection, const std::error_code& Ec) +{ + if (const auto State = Connection->Close(); State != WebSocketState::kDisconnected) + { + if (Ec) + { + ZEN_LOG_INFO(LogWebSocket, "connection '{}' closed, reason '{} ({})'", Connection->Id().Value(), Ec.message(), Ec.value()); + } + else + { + ZEN_LOG_INFO(LogWebSocket, "connection '{}' closed", Connection->Id().Value()); + } + } + + const WebSocketId Id = Connection->Id(); + + { + std::unique_lock _(m_ConnMutex); + if (m_Connections.contains(Id)) + { + m_Connections.erase(Id); + } + } +} + +void +WsServer::ReadMessage(std::shared_ptr<WsConnection> Connection) +{ + Connection->ReadBuffer().prepare(64 << 10); + + asio::async_read( + Connection->Socket(), + Connection->ReadBuffer(), + asio::transfer_at_least(1), + [this, Connection](const asio::error_code& ReadEc, std::size_t) mutable { + if (ReadEc) + { + return CloseConnection(Connection, ReadEc); + } + + switch (Connection->State()) + { + case WebSocketState::kHandshaking: + { + HttpMessageParser& Parser = *reinterpret_cast<HttpMessageParser*>(Connection->Parser()); + asio::const_buffer Buffer = Connection->ReadBuffer().data(); + + ParseMessageResult Result = Parser.ParseMessage(MemoryView(Buffer.data(), Buffer.size())); + + Connection->ReadBuffer().consume(Result.ByteCount); + + if (Result.Status == ParseMessageStatus::kContinue) + { + return ReadMessage(Connection); + } + + if (Result.Status == ParseMessageStatus::kError) + { + ZEN_LOG_WARN(LogWebSocket, + "handshake with connection '#{} {}' FAILED, reason 'HTTP parse error'", + Connection->Id().Value(), + Connection->RemoteAddr()); + + return CloseConnection(Connection, std::error_code()); + } + + if (Parser.IsUpgrade() == false) + { + ZEN_LOG_DEBUG(LogWebSocket, + "handshake with connection '#{} {}' FAILED, reason 'invalid HTTP upgrade request'", + Connection->Id().Value(), + Connection->RemoteAddr()); + + constexpr auto UpgradeRequiredResponse = "HTTP/1.1 426 Upgrade Required\n\r\n\r"sv; + + return async_write(Connection->Socket(), + asio::buffer(UpgradeRequiredResponse), + [this, Connection](const asio::error_code& WriteEc, std::size_t) { + if (WriteEc) + { + return CloseConnection(Connection, WriteEc); + } + + Connection->Parser()->Reset(); + Connection->SetState(WebSocketState::kHandshaking); + + ReadMessage(Connection); + }); + } + + ZEN_ASSERT(Result.Status == ParseMessageStatus::kDone); + + std::string AcceptHash; + std::string Reason; + const bool ValidHandshake = Parser.ValidateWebSocketHandshake(AcceptHash, Reason); + + if (ValidHandshake == false) + { + ZEN_LOG_DEBUG(LogWebSocket, + "handshake with connection '{}' FAILED, reason '{}'", + Connection->Id().Value(), + Reason); + + constexpr auto UpgradeRequiredResponse = "HTTP/1.1 400 Bad Request\n\r\n\r"sv; + + return async_write(Connection->Socket(), + asio::buffer(UpgradeRequiredResponse), + [this, &Connection](const asio::error_code& WriteEc, std::size_t) { + if (WriteEc) + { + return CloseConnection(Connection, WriteEc); + } + + Connection->Parser()->Reset(); + Connection->SetState(WebSocketState::kHandshaking); + + ReadMessage(Connection); + }); + } + + ExtendableStringBuilder<128> Sb; + + Sb << "HTTP/1.1 101 Switching Protocols\r\n"sv; + Sb << "Upgrade: websocket\r\n"sv; + Sb << "Connection: Upgrade\r\n"sv; + + // TODO: Verify protocol + if (Parser.Headers().contains(http_header::SecWebSocketProtocol)) + { + Sb << http_header::SecWebSocketProtocol << ": " << Parser.Headers()[http_header::SecWebSocketProtocol] + << "\r\n"; + } + + Sb << http_header::SecWebSocketAccept << ": " << AcceptHash << "\r\n"; + Sb << "\r\n"sv; + + ZEN_LOG_DEBUG(LogWebSocket, + "accepting handshake from connection '#{} {}'", + Connection->Id().Value(), + Connection->RemoteAddr()); + + std::string Response = Sb.ToString(); + Buffer = asio::buffer(Response); + + async_write(Connection->Socket(), + Buffer, + [this, Connection, _ = std::move(Response)](const asio::error_code& WriteEc, std::size_t ByteCount) { + if (WriteEc) + { + ZEN_LOG_DEBUG(LogWebSocket, + "handshake with connection '{}' FAILED, reason '{}'", + Connection->Id().Value(), + WriteEc.message()); + + return CloseConnection(Connection, WriteEc); + } + + ZEN_LOG_DEBUG(LogWebSocket, + "handshake ({}B) with connection '#{} {}' OK", + ByteCount, + Connection->Id().Value(), + Connection->RemoteAddr()); + + Connection->SetParser(std::make_unique<WebSocketMessageParser>()); + Connection->SetState(WebSocketState::kConnected); + + ReadMessage(Connection); + }); + } + break; + + case WebSocketState::kConnected: + { + WebSocketMessageParser& Parser = *reinterpret_cast<WebSocketMessageParser*>(Connection->Parser()); + + uint64_t RemainingBytes = Connection->ReadBuffer().size(); + + while (RemainingBytes > 0) + { + MemoryView MessageData = MemoryView(Connection->ReadBuffer().data().data(), RemainingBytes); + const ParseMessageResult Result = Parser.ParseMessage(MessageData); + + Connection->ReadBuffer().consume(Result.ByteCount); + RemainingBytes = Connection->ReadBuffer().size(); + + if (Result.Status == ParseMessageStatus::kError) + { + ZEN_LOG_WARN(LogWebSocket, "parse websocket message FAILED, reason '{}'", Result.Reason.value()); + + return CloseConnection(Connection, std::error_code()); + } + + if (Result.Status == ParseMessageStatus::kContinue) + { + ZEN_ASSERT(RemainingBytes == 0); + continue; + } + + WebSocketMessage Message = Parser.ConsumeMessage(); + Parser.Reset(); + + Message.SetSocketId(Connection->Id()); + + RouteMessage(std::move(Message)); + } + + ReadMessage(Connection); + } + break; + + default: + break; + }; + }); +} + +void +WsServer::RouteMessage(WebSocketMessage&& RoutedMessage) +{ + switch (RoutedMessage.MessageType()) + { + case WebSocketMessageType::kRequest: + case WebSocketMessageType::kStreamRequest: + { + CbObjectView Request = RoutedMessage.Body().GetObject(); + std::string_view Method = Request["Method"].AsString(); + bool Handled = false; + bool Error = false; + std::exception Exception; + + if (auto It = m_RequestHandlers.find(Method); It != m_RequestHandlers.end()) + { + WebSocketService* Service = It->second; + ZEN_ASSERT(Service); + + try + { + Handled = Service->HandleRequest(std::move(RoutedMessage)); + } + catch (std::exception& Err) + { + Exception = std::move(Err); + Error = true; + } + } + + if (Error || Handled == false) + { + std::string ErrorText = Error ? Exception.what() : fmt::format("'{}' Not Found", Method); + + ZEN_LOG_WARN(LogWebSocket, "route request message FAILED, reason '{}'", ErrorText); + + CbObjectWriter Response; + Response << "Error"sv << ErrorText; + + WebSocketMessage ResponseMsg; + ResponseMsg.SetMessageType(WebSocketMessageType::kResponse); + ResponseMsg.SetCorrelationId(RoutedMessage.CorrelationId()); + ResponseMsg.SetSocketId(RoutedMessage.SocketId()); + ResponseMsg.SetBody(Response.Save()); + + SendResponse(std::move(ResponseMsg)); + } + } + break; + + case WebSocketMessageType::kNotification: + { + CbObjectView Notification = RoutedMessage.Body().GetObject(); + std::string_view Message = Notification["Message"].AsString(); + + if (auto It = m_NotificationHandlers.find(Message); It != m_NotificationHandlers.end()) + { + std::vector<WebSocketService*>& Handlers = It->second; + + for (WebSocketService* Handler : Handlers) + { + Handler->HandleNotification(RoutedMessage); + } + } + else + { + ZEN_LOG_WARN(LogWebSocket, "route notification message FAILED, unknown notification '{}'", Message); + } + } + break; + + default: + break; + }; +} + +void +WsServer::SendMessage(WebSocketMessage&& Msg) +{ + std::shared_ptr<WsConnection> Connection; + + { + std::unique_lock _(m_ConnMutex); + + if (auto It = m_Connections.find(Msg.SocketId()); It != m_Connections.end()) + { + Connection = It->second; + } + } + + if (Connection.get() == nullptr) + { + ZEN_LOG_WARN(LogWebSocket, "send message FAILED, reason 'unknown socket ID ({})'", Msg.SocketId().Value()); + return; + } + + if (Connection.get() != nullptr) + { + BinaryWriter Writer; + Msg.Save(Writer); + + ZEN_LOG_TRACE(LogWebSocket, + "sending '{}' message, receiver '{}', size '{}', ID '{}', total size {}", + ToString(Msg.MessageType()), + Connection->Id().Value(), + Msg.MessageSize(), + Msg.CorrelationId(), + NiceBytes(Writer.Size())); + + { + ZEN_TRACE_CPU("WS::SendMessage"); + std::unique_lock _(Connection->WriteMutex()); + ZEN_TRACE_CPU("WS::WriteSocketData"); + asio::write(Connection->Socket(), asio::buffer(Writer.Data(), Writer.Size()), asio::transfer_exactly(Writer.Size())); + } + } +} + +/////////////////////////////////////////////////////////////////////////////// +class WsClient final : public WebSocketClient, public std::enable_shared_from_this<WsClient> +{ +public: + WsClient(asio::io_context& IoCtx) : m_IoCtx(IoCtx), m_Id(WebSocketId::New()) {} + + virtual ~WsClient() { Disconnect(); } + + std::shared_ptr<WsClient> AsShared() { return shared_from_this(); } + + virtual std::future<bool> Connect(const WebSocketConnectInfo& Info) override; + virtual void Disconnect() override; + virtual bool IsConnected() const override { return false; } + virtual WebSocketState State() const override { return static_cast<WebSocketState>(m_State.load()); } + + virtual std::future<WebSocketMessage> SendRequest(WebSocketMessage&& Request) override; + virtual void OnNotification(NotificationCallback&& Cb) override; + virtual void OnEvent(WebSocketEvent Evt, EventCallback&& Cb) override; + +private: + WebSocketState SetState(WebSocketState NewState) { return static_cast<WebSocketState>(m_State.exchange(uint32_t(NewState))); } + MessageParser* Parser() { return m_MsgParser.get(); } + void SetParser(std::unique_ptr<MessageParser>&& Parser) { m_MsgParser = std::move(Parser); } + asio::streambuf& ReadBuffer() { return m_ReadBuffer; } + void TriggerEvent(WebSocketEvent Evt); + void ReadMessage(); + void RouteMessage(WebSocketMessage&& RoutedMessage); + + using PendingRequestMap = std::unordered_map<uint32_t, std::promise<WebSocketMessage>>; + + asio::io_context& m_IoCtx; + WebSocketId m_Id; + std::unique_ptr<asio::ip::tcp::socket> m_Socket; + std::unique_ptr<MessageParser> m_MsgParser; + asio::streambuf m_ReadBuffer; + EventCallback m_EventCallbacks[3]; + NotificationCallback m_NotificationCallback; + PendingRequestMap m_PendingRequests; + std::mutex m_RequestMutex; + std::promise<bool> m_ConnectPromise; + std::atomic_uint32_t m_State; + std::string m_Host; + int16_t m_Port{}; +}; + +std::future<bool> +WsClient::Connect(const WebSocketConnectInfo& Info) +{ + if (State() == WebSocketState::kHandshaking || State() == WebSocketState::kConnected) + { + return m_ConnectPromise.get_future(); + } + + SetState(WebSocketState::kHandshaking); + + try + { + asio::ip::tcp::endpoint Endpoint(asio::ip::address::from_string(Info.Host), Info.Port); + m_Socket = std::make_unique<asio::ip::tcp::socket>(m_IoCtx, Endpoint.protocol()); + + m_Socket->connect(Endpoint); + + m_Host = m_Socket->remote_endpoint().address().to_string(); + m_Port = Info.Port; + + ZEN_LOG_INFO(LogWsClient, "connected to websocket server '{}:{}'", m_Host, m_Port); + } + catch (std::exception& Err) + { + ZEN_LOG_WARN(LogWsClient, "connect to websocket server '{}:{}' FAILED, reason '{}'", Info.Host, Info.Port, Err.what()); + + SetState(WebSocketState::kError); + m_Socket.reset(); + + TriggerEvent(WebSocketEvent::kDisconnected); + + m_ConnectPromise.set_value(false); + + return m_ConnectPromise.get_future(); + } + + ExtendableStringBuilder<128> Sb; + Sb << "GET " << Info.Endpoint << " HTTP/1.1\r\n"sv; + Sb << "Host: " << Info.Host << "\r\n"sv; + Sb << "Upgrade: websocket\r\n"sv; + Sb << "Connection: upgrade\r\n"sv; + Sb << "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n"sv; + + if (Info.Protocols.empty() == false) + { + Sb << "Sec-WebSocket-Protocol: "sv; + for (size_t Idx = 0; const auto& Protocol : Info.Protocols) + { + if (Idx++) + { + Sb << ", "; + } + Sb << Protocol; + } + } + + Sb << "Sec-WebSocket-Version: "sv << Info.Version << "\r\n"sv; + Sb << "\r\n"; + + std::string HandshakeRequest = Sb.ToString(); + asio::const_buffer Buffer = asio::buffer(HandshakeRequest); + + ZEN_LOG_DEBUG(LogWsClient, "handshaking with '{}:{}'", m_Host, m_Port); + + m_MsgParser = std::make_unique<HttpMessageParser>(HttpMessageParserType::kResponse); + m_MsgParser->Reset(); + + async_write(*m_Socket, Buffer, [Self = AsShared(), _ = std::move(HandshakeRequest)](const asio::error_code& Ec, std::size_t) { + if (Ec) + { + ZEN_LOG_ERROR(LogWsClient, "write data FAILED, reason '{}'", Ec.message()); + + Self->Disconnect(); + } + else + { + Self->ReadMessage(); + } + }); + + return m_ConnectPromise.get_future(); +} + +void +WsClient::Disconnect() +{ + if (auto PrevState = SetState(WebSocketState::kDisconnected); PrevState != WebSocketState::kDisconnected) + { + ZEN_LOG_INFO(LogWsClient, "closing connection to '{}:{}'", m_Host, m_Port); + + if (m_Socket && m_Socket->is_open()) + { + m_Socket->close(); + m_Socket.reset(); + } + + TriggerEvent(WebSocketEvent::kDisconnected); + + { + std::unique_lock _(m_RequestMutex); + + for (auto& Kv : m_PendingRequests) + { + Kv.second.set_value(WebSocketMessage()); + } + + m_PendingRequests.clear(); + } + } +} + +std::future<WebSocketMessage> +WsClient::SendRequest(WebSocketMessage&& Request) +{ + ZEN_ASSERT(Request.MessageType() == WebSocketMessageType::kRequest); + + BinaryWriter Writer; + Request.Save(Writer); + + std::future<WebSocketMessage> FutureResponse; + + { + std::unique_lock _(m_RequestMutex); + + auto Result = m_PendingRequests.try_emplace(Request.CorrelationId(), std::promise<WebSocketMessage>()); + ZEN_ASSERT(Result.second); + + auto It = Result.first; + FutureResponse = It->second.get_future(); + } + + IoBuffer Buffer = IoBufferBuilder::MakeCloneFromMemory(Writer.Data(), Writer.Size()); + + async_write(*m_Socket, asio::buffer(Buffer.Data(), Buffer.Size()), [Self = AsShared()](const std::error_code& Ec, size_t) { + if (Ec) + { + ZEN_LOG_WARN(LogWsClient, "send request message FAILED, reason '{}'", Ec.message()); + + Self->Disconnect(); + } + }); + + return FutureResponse; +} + +void +WsClient::OnNotification(NotificationCallback&& Cb) +{ + m_NotificationCallback = std::move(Cb); +} + +void +WsClient::OnEvent(WebSocketEvent Evt, WebSocketClient::EventCallback&& Cb) +{ + m_EventCallbacks[static_cast<uint32_t>(Evt)] = std::move(Cb); +} + +void +WsClient::TriggerEvent(WebSocketEvent Evt) +{ + const uint32_t Index = static_cast<uint32_t>(Evt); + + if (m_EventCallbacks[Index]) + { + m_EventCallbacks[Index](); + } +} + +void +WsClient::ReadMessage() +{ + m_ReadBuffer.prepare(64 << 10); + + async_read(*m_Socket, + m_ReadBuffer, + asio::transfer_at_least(1), + [Self = AsShared()](const asio::error_code& Ec, std::size_t ByteCount) mutable { + const WebSocketState State = Self->State(); + + if (State == WebSocketState::kDisconnected) + { + return; + } + + if (Ec) + { + ZEN_LOG_WARN(LogWsClient, "read message FAILED, reason '{}'", Ec.message()); + + return Self->Disconnect(); + } + + switch (State) + { + case WebSocketState::kHandshaking: + { + HttpMessageParser& Parser = *reinterpret_cast<HttpMessageParser*>(Self->Parser()); + + MemoryView MessageData = MemoryView(Self->ReadBuffer().data().data(), ByteCount); + + ParseMessageResult Result = Parser.ParseMessage(MessageData); + + Self->ReadBuffer().consume(size_t(Result.ByteCount)); + + if (Result.Status == ParseMessageStatus::kError) + { + ZEN_LOG_WARN(LogWsClient, "handshake FAILED, status code '{}'", Parser.StatusCode()); + + Self->m_ConnectPromise.set_value(false); + + return Self->Disconnect(); + } + + if (Result.Status == ParseMessageStatus::kContinue) + { + return Self->ReadMessage(); + } + + ZEN_ASSERT(Result.Status == ParseMessageStatus::kDone); + + if (Parser.StatusCode() != 101) + { + ZEN_LOG_WARN(LogWsClient, + "handshake FAILED, status '{}', status code '{}'", + Parser.StatusText(), + Parser.StatusCode()); + + Self->m_ConnectPromise.set_value(false); + + return Self->Disconnect(); + } + + ZEN_LOG_INFO(LogWsClient, "handshake OK, status '{}'", Parser.StatusText()); + + Self->SetParser(std::make_unique<WebSocketMessageParser>()); + Self->SetState(WebSocketState::kConnected); + Self->ReadMessage(); + Self->TriggerEvent(WebSocketEvent::kConnected); + + Self->m_ConnectPromise.set_value(true); + } + break; + + case WebSocketState::kConnected: + { + WebSocketMessageParser& Parser = *reinterpret_cast<WebSocketMessageParser*>(Self->Parser()); + + uint64_t RemainingBytes = Self->ReadBuffer().size(); + + while (RemainingBytes > 0) + { + MemoryView MessageData = MemoryView(Self->ReadBuffer().data().data(), RemainingBytes); + const ParseMessageResult Result = Parser.ParseMessage(MessageData); + + Self->ReadBuffer().consume(Result.ByteCount); + RemainingBytes = Self->ReadBuffer().size(); + + if (Result.Status == ParseMessageStatus::kError) + { + ZEN_LOG_WARN(LogWsClient, "parse websocket message FAILED, reason '{}'", Result.Reason.value()); + + Parser.Reset(); + continue; + } + + if (Result.Status == ParseMessageStatus::kContinue) + { + ZEN_ASSERT(RemainingBytes == 0); + continue; + } + + WebSocketMessage Message = Parser.ConsumeMessage(); + Parser.Reset(); + + Self->RouteMessage(std::move(Message)); + } + + Self->ReadMessage(); + } + break; + + default: + break; + } + }); +} + +void +WsClient::RouteMessage(WebSocketMessage&& RoutedMessage) +{ + switch (RoutedMessage.MessageType()) + { + case WebSocketMessageType::kResponse: + { + std::unique_lock _(m_RequestMutex); + + if (auto It = m_PendingRequests.find(RoutedMessage.CorrelationId()); It != m_PendingRequests.end()) + { + It->second.set_value(std::move(RoutedMessage)); + m_PendingRequests.erase(It); + } + else + { + ZEN_LOG_WARN(LogWsClient, + "route request message FAILED, reason 'unknown correlation ID ({})'", + RoutedMessage.CorrelationId()); + } + } + break; + + case WebSocketMessageType::kNotification: + { + std::unique_lock _(m_RequestMutex); + + if (m_NotificationCallback) + { + m_NotificationCallback(std::move(RoutedMessage)); + } + } + break; + + default: + ZEN_LOG_WARN(LogWsClient, "route message FAILED, reason 'invalid message type ({})'", uint8_t(RoutedMessage.MessageType())); + break; + }; +} + +} // namespace zen::websocket + +namespace zen { + +std::atomic_uint32_t WebSocketId::NextId{1}; + +bool +WebSocketMessage::Header::IsValid() const +{ + return Magic == ExpectedMagic && StatusCode > 0 && uint8_t(MessageType) > uint8_t(WebSocketMessageType::kInvalid) && + uint8_t(MessageType) < uint8_t(WebSocketMessageType::kCount); +} + +std::atomic_uint32_t WebSocketMessage::NextCorrelationId{1}; + +void +WebSocketMessage::SetMessageType(WebSocketMessageType MessageType) +{ + m_Header.MessageType = MessageType; +} + +void +WebSocketMessage::SetBody(CbPackage&& Body) +{ + m_Body = std::move(Body); +} +void +WebSocketMessage::SetBody(CbObject&& Body) +{ + CbPackage Pkg; + Pkg.SetObject(Body); + + SetBody(std::move(Pkg)); +} + +void +WebSocketMessage::Save(BinaryWriter& Writer) +{ + Writer.Write(&m_Header, HeaderSize); + + if (m_Body.has_value()) + { + const CbObject& Obj = m_Body.value().GetObject(); + MemoryView View = Obj.GetBuffer().GetView(); + + const CbValidateError ValidationResult = ValidateCompactBinary(View, CbValidateMode::All); + ZEN_ASSERT(ValidationResult == CbValidateError::None); + + m_Body.value().Save(Writer); + } + + if (m_Header.CorrelationId == 0 && MessageType() == WebSocketMessageType::kRequest) + { + m_Header.CorrelationId = NextCorrelationId.fetch_add(1); + } + + m_Header.MessageSize = Writer.Size() - HeaderSize; + + Writer.GetMutableView().CopyFrom(MakeMemoryView(&m_Header, HeaderSize)); +} + +bool +WebSocketMessage::TryLoadHeader(MemoryView Memory) +{ + if (Memory.GetSize() < HeaderSize) + { + return false; + } + + MutableMemoryView HeaderView(&m_Header, HeaderSize); + + HeaderView.CopyFrom(Memory); + + return m_Header.IsValid(); +} + +void +WebSocketService::Configure(WebSocketServer& Server) +{ + ZEN_ASSERT(m_SocketServer == nullptr); + + m_SocketServer = &Server; + + RegisterHandlers(Server); +} + +void +WebSocketService::SendStreamResponse(WebSocketId SocketId, uint32_t CorrelationId, CbPackage&& StreamResponse, bool IsStreamComplete) +{ + WebSocketMessage Message; + + Message.SetMessageType(IsStreamComplete ? WebSocketMessageType::kStreamCompleteResponse : WebSocketMessageType::kStreamResponse); + Message.SetCorrelationId(CorrelationId); + Message.SetSocketId(SocketId); + Message.SetBody(std::move(StreamResponse)); + + SocketServer().SendResponse(std::move(Message)); +} + +void +WebSocketService::SendStreamResponse(WebSocketId SocketId, uint32_t CorrelationId, CbObject&& StreamResponse, bool IsStreamComplete) +{ + CbPackage Response; + Response.SetObject(std::move(StreamResponse)); + + SendStreamResponse(SocketId, CorrelationId, std::move(Response), IsStreamComplete); +} + +std::unique_ptr<WebSocketServer> +WebSocketServer::Create(const WebSocketServerOptions& Options) +{ + return std::make_unique<websocket::WsServer>(Options); +} + +std::shared_ptr<WebSocketClient> +WebSocketClient::Create(asio::io_context& IoCtx) +{ + return std::make_shared<websocket::WsClient>(IoCtx); +} + +} // namespace zen diff --git a/src/zenhttp/xmake.lua b/src/zenhttp/xmake.lua new file mode 100644 index 000000000..b0dbdbc79 --- /dev/null +++ b/src/zenhttp/xmake.lua @@ -0,0 +1,14 @@ +-- Copyright Epic Games, Inc. All Rights Reserved. + +target('zenhttp') + set_kind("static") + add_headerfiles("**.h") + add_files("**.cpp") + add_files("httpsys.cpp", {unity_ignored=true}) + add_includedirs("include", {public=true}) + add_deps("zencore") + add_packages( + "vcpkg::gsl-lite", + "vcpkg::http-parser" + ) + add_options("httpsys") diff --git a/src/zenhttp/zenhttp.cpp b/src/zenhttp/zenhttp.cpp new file mode 100644 index 000000000..4bd6a5697 --- /dev/null +++ b/src/zenhttp/zenhttp.cpp @@ -0,0 +1,22 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zenhttp/zenhttp.h> + +#if ZEN_WITH_TESTS + +# include <zenhttp/httpclient.h> +# include <zenhttp/httpserver.h> +# include <zenhttp/httpshared.h> + +namespace zen { + +void +zenhttp_forcelinktests() +{ + http_forcelink(); + forcelink_httpshared(); +} + +} // namespace zen + +#endif |