aboutsummaryrefslogtreecommitdiff
path: root/zenhttp
diff options
context:
space:
mode:
Diffstat (limited to 'zenhttp')
-rw-r--r--zenhttp/httpasio.cpp1372
-rw-r--r--zenhttp/httpasio.h36
-rw-r--r--zenhttp/httpclient.cpp176
-rw-r--r--zenhttp/httpnull.cpp83
-rw-r--r--zenhttp/httpnull.h29
-rw-r--r--zenhttp/httpserver.cpp885
-rw-r--r--zenhttp/httpshared.cpp809
-rw-r--r--zenhttp/httpsys.cpp1674
-rw-r--r--zenhttp/httpsys.h90
-rw-r--r--zenhttp/include/zenhttp/httpclient.h47
-rw-r--r--zenhttp/include/zenhttp/httpcommon.h181
-rw-r--r--zenhttp/include/zenhttp/httpserver.h315
-rw-r--r--zenhttp/include/zenhttp/httpshared.h163
-rw-r--r--zenhttp/include/zenhttp/websocket.h256
-rw-r--r--zenhttp/include/zenhttp/zenhttp.h21
-rw-r--r--zenhttp/iothreadpool.cpp49
-rw-r--r--zenhttp/iothreadpool.h37
-rw-r--r--zenhttp/websocketasio.cpp1613
-rw-r--r--zenhttp/xmake.lua14
-rw-r--r--zenhttp/zenhttp.cpp22
20 files changed, 0 insertions, 7872 deletions
diff --git a/zenhttp/httpasio.cpp b/zenhttp/httpasio.cpp
deleted file mode 100644
index 79b2c0a3d..000000000
--- a/zenhttp/httpasio.cpp
+++ /dev/null
@@ -1,1372 +0,0 @@
-// Copyright Epic Games, Inc. All Rights Reserved.
-
-#include "httpasio.h"
-
-#include <zencore/logging.h>
-#include <zenhttp/httpserver.h>
-
-#include <deque>
-#include <memory>
-#include <string_view>
-#include <vector>
-
-ZEN_THIRD_PARTY_INCLUDES_START
-#if ZEN_PLATFORM_WINDOWS
-# include <conio.h>
-# include <mstcpip.h>
-#endif
-#include <http_parser.h>
-#include <asio.hpp>
-ZEN_THIRD_PARTY_INCLUDES_END
-
-#define ASIO_VERBOSE_TRACE 0
-
-#if ASIO_VERBOSE_TRACE
-# define ZEN_TRACE_VERBOSE ZEN_TRACE
-#else
-# define ZEN_TRACE_VERBOSE(fmtstr, ...)
-#endif
-
-namespace zen::asio_http {
-
-using namespace std::literals;
-
-struct HttpAcceptor;
-struct HttpRequest;
-struct HttpResponse;
-struct HttpServerConnection;
-
-static constinit uint32_t HashContentLength = HashStringAsLowerDjb2("Content-Length"sv);
-static constinit uint32_t HashContentType = HashStringAsLowerDjb2("Content-Type"sv);
-static constinit uint32_t HashAccept = HashStringAsLowerDjb2("Accept"sv);
-static constinit uint32_t HashExpect = HashStringAsLowerDjb2("Expect"sv);
-static constinit uint32_t HashSession = HashStringAsLowerDjb2("UE-Session"sv);
-static constinit uint32_t HashRequest = HashStringAsLowerDjb2("UE-Request"sv);
-static constinit uint32_t HashRange = HashStringAsLowerDjb2("Range"sv);
-
-inline spdlog::logger&
-InitLogger()
-{
- spdlog::logger& Logger = logging::Get("asio");
- // Logger.set_level(spdlog::level::trace);
- return Logger;
-}
-
-inline spdlog::logger&
-Log()
-{
- static spdlog::logger& g_Logger = InitLogger();
- return g_Logger;
-}
-
-//////////////////////////////////////////////////////////////////////////
-
-struct HttpAsioServerImpl
-{
-public:
- HttpAsioServerImpl();
- ~HttpAsioServerImpl();
-
- int Start(uint16_t Port, int ThreadCount);
- void Stop();
- void RegisterService(const char* UrlPath, HttpService& Service);
- HttpService* RouteRequest(std::string_view Url);
-
- asio::io_service m_IoService;
- asio::io_service::work m_Work{m_IoService};
- std::unique_ptr<asio_http::HttpAcceptor> m_Acceptor;
- std::vector<std::thread> m_ThreadPool;
-
- struct ServiceEntry
- {
- std::string ServiceUrlPath;
- HttpService* Service;
- };
-
- RwLock m_Lock;
- std::vector<ServiceEntry> m_UriHandlers;
-};
-
-/**
- * This is the class which request handlers use to interact with the server instance
- */
-
-class HttpAsioServerRequest : public HttpServerRequest
-{
-public:
- HttpAsioServerRequest(asio_http::HttpRequest& Request, HttpService& Service, IoBuffer PayloadBuffer);
- ~HttpAsioServerRequest();
-
- virtual Oid ParseSessionId() const override;
- virtual uint32_t ParseRequestId() const override;
-
- virtual IoBuffer ReadPayload() override;
- virtual void WriteResponse(HttpResponseCode ResponseCode) override;
- virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span<IoBuffer> Blobs) override;
- virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::u8string_view ResponseString) override;
- virtual void WriteResponseAsync(std::function<void(HttpServerRequest&)>&& ContinuationHandler) override;
- virtual bool TryGetRanges(HttpRanges& Ranges) override;
-
- using HttpServerRequest::WriteResponse;
-
- HttpAsioServerRequest(const HttpAsioServerRequest&) = delete;
- HttpAsioServerRequest& operator=(const HttpAsioServerRequest&) = delete;
-
- asio_http::HttpRequest& m_Request;
- IoBuffer m_PayloadBuffer;
- std::unique_ptr<HttpResponse> m_Response;
-};
-
-struct HttpRequest
-{
- explicit HttpRequest(HttpServerConnection& Connection) : m_Connection(Connection) {}
-
- void Initialize();
- size_t ConsumeData(const char* InputData, size_t DataSize);
- void ResetState();
-
- HttpVerb RequestVerb() const { return m_RequestVerb; }
- bool IsKeepAlive() const { return m_KeepAlive; }
- std::string_view Url() const { return m_NormalizedUrl.empty() ? std::string_view(m_Url, m_UrlLength) : m_NormalizedUrl; }
- std::string_view QueryString() const { return std::string_view(m_QueryString, m_QueryLength); }
- IoBuffer Body() { return m_BodyBuffer; }
-
- inline HttpContentType ContentType()
- {
- if (m_ContentTypeHeaderIndex < 0)
- {
- return HttpContentType::kUnknownContentType;
- }
-
- return ParseContentType(m_Headers[m_ContentTypeHeaderIndex].Value);
- }
-
- inline HttpContentType AcceptType()
- {
- if (m_AcceptHeaderIndex < 0)
- {
- return HttpContentType::kUnknownContentType;
- }
-
- return ParseContentType(m_Headers[m_AcceptHeaderIndex].Value);
- }
-
- Oid SessionId() const { return m_SessionId; }
- int RequestId() const { return m_RequestId; }
-
- std::string_view RangeHeader() const { return m_RangeHeaderIndex != -1 ? m_Headers[m_RangeHeaderIndex].Value : std::string_view(); }
-
-private:
- struct HeaderEntry
- {
- HeaderEntry() = default;
-
- HeaderEntry(std::string_view InName, std::string_view InValue) : Name(InName), Value(InValue) {}
-
- std::string_view Name;
- std::string_view Value;
- };
-
- HttpServerConnection& m_Connection;
- char* m_HeaderCursor = m_HeaderBuffer;
- char* m_Url = nullptr;
- size_t m_UrlLength = 0;
- char* m_QueryString = nullptr;
- size_t m_QueryLength = 0;
- char* m_CurrentHeaderName = nullptr; // Used while parsing headers
- size_t m_CurrentHeaderNameLength = 0;
- char* m_CurrentHeaderValue = nullptr; // Used while parsing headers
- size_t m_CurrentHeaderValueLength = 0;
- std::vector<HeaderEntry> m_Headers;
- int8_t m_ContentLengthHeaderIndex;
- int8_t m_AcceptHeaderIndex;
- int8_t m_ContentTypeHeaderIndex;
- int8_t m_RangeHeaderIndex;
- HttpVerb m_RequestVerb;
- bool m_KeepAlive = false;
- bool m_Expect100Continue = false;
- int m_RequestId = -1;
- Oid m_SessionId{};
- IoBuffer m_BodyBuffer;
- uint64_t m_BodyPosition = 0;
- http_parser m_Parser;
- char m_HeaderBuffer[1024];
- std::string m_NormalizedUrl;
-
- void AppendCurrentHeader();
-
- int OnMessageBegin();
- int OnUrl(const char* Data, size_t Bytes);
- int OnHeader(const char* Data, size_t Bytes);
- int OnHeaderValue(const char* Data, size_t Bytes);
- int OnHeadersComplete();
- int OnBody(const char* Data, size_t Bytes);
- int OnMessageComplete();
-
- static HttpRequest* GetThis(http_parser* Parser) { return reinterpret_cast<HttpRequest*>(Parser->data); }
- static http_parser_settings s_ParserSettings;
-};
-
-struct HttpResponse
-{
-public:
- HttpResponse() = default;
- explicit HttpResponse(HttpContentType ContentType) : m_ContentType(ContentType) {}
-
- void InitializeForPayload(uint16_t ResponseCode, std::span<IoBuffer> BlobList)
- {
- m_ResponseCode = ResponseCode;
- const uint32_t ChunkCount = gsl::narrow<uint32_t>(BlobList.size());
-
- m_DataBuffers.reserve(ChunkCount);
-
- for (IoBuffer& Buffer : BlobList)
- {
-#if 1
- m_DataBuffers.emplace_back(std::move(Buffer)).MakeOwned();
-#else
- IoBuffer TempBuffer = std::move(Buffer);
- TempBuffer.MakeOwned();
- m_DataBuffers.emplace_back(IoBufferBuilder::ReadFromFileMaybe(TempBuffer));
-#endif
- }
-
- uint64_t LocalDataSize = 0;
-
- m_AsioBuffers.push_back({}); // Placeholder for header
-
- for (IoBuffer& Buffer : m_DataBuffers)
- {
- uint64_t BufferDataSize = Buffer.Size();
-
- ZEN_ASSERT(BufferDataSize);
-
- LocalDataSize += BufferDataSize;
-
- IoBufferFileReference FileRef;
- if (Buffer.GetFileReference(/* out */ FileRef))
- {
- // TODO: Use direct file transfer, via TransmitFile/sendfile
- //
- // this looks like it requires some custom asio plumbing however
-
- m_AsioBuffers.push_back({Buffer.Data(), Buffer.Size()});
- }
- else
- {
- // Send from memory
-
- m_AsioBuffers.push_back({Buffer.Data(), Buffer.Size()});
- }
- }
- m_ContentLength = LocalDataSize;
-
- auto Headers = GetHeaders();
- m_AsioBuffers[0] = asio::const_buffer(Headers.data(), Headers.size());
- }
-
- uint16_t ResponseCode() const { return m_ResponseCode; }
- uint64_t ContentLength() const { return m_ContentLength; }
-
- const std::vector<asio::const_buffer>& AsioBuffers() const { return m_AsioBuffers; }
-
- std::string_view GetHeaders()
- {
- m_Headers << "HTTP/1.1 " << ResponseCode() << " " << ReasonStringForHttpResultCode(ResponseCode()) << "\r\n"
- << "Content-Type: " << MapContentTypeToString(m_ContentType) << "\r\n"
- << "Content-Length: " << ContentLength() << "\r\n"sv;
-
- if (!m_IsKeepAlive)
- {
- m_Headers << "Connection: close\r\n"sv;
- }
-
- m_Headers << "\r\n"sv;
-
- return m_Headers;
- }
-
- void SuppressPayload() { m_AsioBuffers.resize(1); }
-
-private:
- uint16_t m_ResponseCode = 0;
- bool m_IsKeepAlive = true;
- HttpContentType m_ContentType = HttpContentType::kBinary;
- uint64_t m_ContentLength = 0;
- std::vector<IoBuffer> m_DataBuffers;
- std::vector<asio::const_buffer> m_AsioBuffers;
- ExtendableStringBuilder<160> m_Headers;
-};
-
-//////////////////////////////////////////////////////////////////////////
-
-struct HttpServerConnection : std::enable_shared_from_this<HttpServerConnection>
-{
- HttpServerConnection(HttpAsioServerImpl& Server, std::unique_ptr<asio::ip::tcp::socket>&& Socket);
- ~HttpServerConnection();
-
- void HandleNewRequest();
- void TerminateConnection();
- void HandleRequest();
-
- std::shared_ptr<HttpServerConnection> AsSharedPtr() { return shared_from_this(); }
-
-private:
- enum class RequestState
- {
- kInitialState,
- kInitialRead,
- kReadingMore,
- kWriting,
- kWritingFinal,
- kDone,
- kTerminated
- };
-
- RequestState m_RequestState = RequestState::kInitialState;
- HttpRequest m_RequestData{*this};
-
- void EnqueueRead();
- void OnDataReceived(const asio::error_code& Ec, std::size_t ByteCount);
- void OnResponseDataSent(const asio::error_code& Ec, std::size_t ByteCount, bool Pop = false);
- void OnError();
-
- HttpAsioServerImpl& m_Server;
- asio::streambuf m_RequestBuffer;
- std::unique_ptr<asio::ip::tcp::socket> m_Socket;
- std::atomic<uint32_t> m_RequestCounter{0};
- uint32_t m_ConnectionId = 0;
- Ref<IHttpPackageHandler> m_PackageHandler;
-
- RwLock m_ResponsesLock;
- std::deque<std::unique_ptr<HttpResponse>> m_Responses;
-};
-
-std::atomic<uint32_t> g_ConnectionIdCounter{0};
-
-HttpServerConnection::HttpServerConnection(HttpAsioServerImpl& Server, std::unique_ptr<asio::ip::tcp::socket>&& Socket)
-: m_Server(Server)
-, m_Socket(std::move(Socket))
-, m_ConnectionId(g_ConnectionIdCounter.fetch_add(1))
-{
- ZEN_TRACE_VERBOSE("new connection #{}", m_ConnectionId);
-}
-
-HttpServerConnection::~HttpServerConnection()
-{
- ZEN_TRACE_VERBOSE("destroying connection #{}", m_ConnectionId);
-}
-
-void
-HttpServerConnection::HandleNewRequest()
-{
- m_RequestData.Initialize();
-
- EnqueueRead();
-}
-
-void
-HttpServerConnection::TerminateConnection()
-{
- m_RequestState = RequestState::kTerminated;
-
- std::error_code Ec;
- m_Socket->close(Ec);
-}
-
-void
-HttpServerConnection::EnqueueRead()
-{
- if (m_RequestState == RequestState::kInitialRead)
- {
- m_RequestState = RequestState::kReadingMore;
- }
- else
- {
- m_RequestState = RequestState::kInitialRead;
- }
-
- m_RequestBuffer.prepare(64 * 1024);
-
- asio::async_read(*m_Socket.get(),
- m_RequestBuffer,
- asio::transfer_at_least(1),
- [Conn = AsSharedPtr()](const asio::error_code& Ec, std::size_t ByteCount) { Conn->OnDataReceived(Ec, ByteCount); });
-}
-
-void
-HttpServerConnection::OnDataReceived(const asio::error_code& Ec, [[maybe_unused]] std::size_t ByteCount)
-{
- if (Ec)
- {
- if (m_RequestState == RequestState::kDone || m_RequestState == RequestState::kInitialRead)
- {
- ZEN_TRACE_VERBOSE("on data received ERROR (EXPECTED), connection '{}' reason '{}'", m_ConnectionId, Ec.message());
- return;
- }
- else
- {
- ZEN_WARN("on data received ERROR, connection '{}' reason '{}'", m_ConnectionId, Ec.message());
- return OnError();
- }
- }
-
- ZEN_TRACE_VERBOSE("on data received, connection '{}', request '{}', thread '{}', bytes '{}'",
- m_ConnectionId,
- m_RequestCounter.load(std::memory_order_relaxed),
- zen::GetCurrentThreadId(),
- NiceBytes(ByteCount));
-
- while (m_RequestBuffer.size())
- {
- const asio::const_buffer& InputBuffer = m_RequestBuffer.data();
-
- size_t Result = m_RequestData.ConsumeData((const char*)InputBuffer.data(), InputBuffer.size());
- if (Result == ~0ull)
- {
- return OnError();
- }
-
- m_RequestBuffer.consume(Result);
- }
-
- switch (m_RequestState)
- {
- case RequestState::kDone:
- case RequestState::kWritingFinal:
- case RequestState::kTerminated:
- break;
-
- default:
- EnqueueRead();
- break;
- }
-}
-
-void
-HttpServerConnection::OnResponseDataSent(const asio::error_code& Ec, [[maybe_unused]] std::size_t ByteCount, bool Pop)
-{
- if (Ec)
- {
- ZEN_WARN("on data sent ERROR, connection '{}' reason '{}'", m_ConnectionId, Ec.message());
- OnError();
- }
- else
- {
- ZEN_TRACE_VERBOSE("on data sent, connection '{}', request '{}', thread '{}', bytes '{}'",
- m_ConnectionId,
- m_RequestCounter.load(std::memory_order_relaxed),
- zen::GetCurrentThreadId(),
- NiceBytes(ByteCount));
-
- if (!m_RequestData.IsKeepAlive())
- {
- m_RequestState = RequestState::kDone;
-
- m_Socket->close();
- }
- else
- {
- if (Pop)
- {
- RwLock::ExclusiveLockScope _(m_ResponsesLock);
- m_Responses.pop_front();
- }
-
- m_RequestCounter.fetch_add(1);
- }
- }
-}
-
-void
-HttpServerConnection::OnError()
-{
- m_Socket->close();
-}
-
-void
-HttpServerConnection::HandleRequest()
-{
- if (!m_RequestData.IsKeepAlive())
- {
- m_RequestState = RequestState::kWritingFinal;
-
- std::error_code Ec;
- m_Socket->shutdown(asio::socket_base::shutdown_receive, Ec);
-
- if (Ec)
- {
- ZEN_WARN("socket shutdown ERROR, reason '{}'", Ec.message());
- }
- }
- else
- {
- m_RequestState = RequestState::kWriting;
- }
-
- if (HttpService* Service = m_Server.RouteRequest(m_RequestData.Url()))
- {
- HttpAsioServerRequest Request(m_RequestData, *Service, m_RequestData.Body());
-
- ZEN_TRACE_VERBOSE("handle request, connection '{}' request '{}'", m_ConnectionId, m_RequestCounter.load(std::memory_order_relaxed));
-
- if (!HandlePackageOffers(*Service, Request, m_PackageHandler))
- {
- try
- {
- Service->HandleRequest(Request);
- }
- catch (std::exception& ex)
- {
- ZEN_ERROR("Caught exception while handling request: '{}'", ex.what());
-
- Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, ex.what());
- }
- }
-
- if (std::unique_ptr<HttpResponse> Response = std::move(Request.m_Response))
- {
- // Transmit the response
-
- if (m_RequestData.RequestVerb() == HttpVerb::kHead)
- {
- Response->SuppressPayload();
- }
-
- auto ResponseBuffers = Response->AsioBuffers();
-
- uint64_t ResponseLength = 0;
-
- for (auto& Buffer : ResponseBuffers)
- {
- ResponseLength += Buffer.size();
- }
-
- {
- RwLock::ExclusiveLockScope _(m_ResponsesLock);
- m_Responses.push_back(std::move(Response));
- }
-
- // TODO: should cork/uncork for Linux?
-
- asio::async_write(*m_Socket.get(),
- ResponseBuffers,
- asio::transfer_exactly(ResponseLength),
- [Conn = AsSharedPtr()](const asio::error_code& Ec, std::size_t ByteCount) {
- Conn->OnResponseDataSent(Ec, ByteCount, true);
- });
-
- return;
- }
- }
-
- if (m_RequestData.RequestVerb() == HttpVerb::kHead)
- {
- std::string_view Response =
- "HTTP/1.1 404 NOT FOUND\r\n"
- "\r\n"sv;
-
- if (!m_RequestData.IsKeepAlive())
- {
- Response =
- "HTTP/1.1 404 NOT FOUND\r\n"
- "Connection: close\r\n"
- "\r\n"sv;
- }
-
- asio::async_write(
- *m_Socket.get(),
- asio::buffer(Response),
- [Conn = AsSharedPtr()](const asio::error_code& Ec, std::size_t ByteCount) { Conn->OnResponseDataSent(Ec, ByteCount); });
- }
- else
- {
- std::string_view Response =
- "HTTP/1.1 404 NOT FOUND\r\n"
- "Content-Length: 23\r\n"
- "Content-Type: text/plain\r\n"
- "\r\n"
- "No suitable route found"sv;
-
- if (!m_RequestData.IsKeepAlive())
- {
- Response =
- "HTTP/1.1 404 NOT FOUND\r\n"
- "Content-Length: 23\r\n"
- "Content-Type: text/plain\r\n"
- "Connection: close\r\n"
- "\r\n"
- "No suitable route found"sv;
- }
-
- asio::async_write(
- *m_Socket.get(),
- asio::buffer(Response),
- [Conn = AsSharedPtr()](const asio::error_code& Ec, std::size_t ByteCount) { Conn->OnResponseDataSent(Ec, ByteCount); });
- }
-}
-
-//////////////////////////////////////////////////////////////////////////
-//
-// HttpRequest
-//
-
-http_parser_settings HttpRequest::s_ParserSettings{
- .on_message_begin = [](http_parser* p) { return GetThis(p)->OnMessageBegin(); },
- .on_url = [](http_parser* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnUrl(Data, ByteCount); },
- .on_status =
- [](http_parser* p, const char* Data, size_t ByteCount) {
- ZEN_UNUSED(p, Data, ByteCount);
- return 0;
- },
- .on_header_field = [](http_parser* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnHeader(Data, ByteCount); },
- .on_header_value = [](http_parser* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnHeaderValue(Data, ByteCount); },
- .on_headers_complete = [](http_parser* p) { return GetThis(p)->OnHeadersComplete(); },
- .on_body = [](http_parser* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnBody(Data, ByteCount); },
- .on_message_complete = [](http_parser* p) { return GetThis(p)->OnMessageComplete(); },
- .on_chunk_header{},
- .on_chunk_complete{}};
-
-void
-HttpRequest::Initialize()
-{
- http_parser_init(&m_Parser, HTTP_REQUEST);
- m_Parser.data = this;
-
- ResetState();
-}
-
-size_t
-HttpRequest::ConsumeData(const char* InputData, size_t DataSize)
-{
- const size_t ConsumedBytes = http_parser_execute(&m_Parser, &s_ParserSettings, InputData, DataSize);
-
- http_errno HttpErrno = HTTP_PARSER_ERRNO((&m_Parser));
-
- if (HttpErrno && HttpErrno != HPE_INVALID_EOF_STATE)
- {
- ZEN_WARN("HTTP parser error {} ('{}'). Closing connection", http_errno_name(HttpErrno), http_errno_description(HttpErrno));
- return ~0ull;
- }
-
- return ConsumedBytes;
-}
-
-int
-HttpRequest::OnUrl(const char* Data, size_t Bytes)
-{
- if (!m_Url)
- {
- ZEN_ASSERT_SLOW(m_UrlLength == 0);
- m_Url = m_HeaderCursor;
- }
-
- const size_t RemainingBufferSpace = sizeof m_HeaderBuffer + m_HeaderBuffer - m_HeaderCursor;
-
- if (RemainingBufferSpace < Bytes)
- {
- ZEN_WARN("HTTP parser does not have enough space for incoming request, need {} more bytes", Bytes - RemainingBufferSpace);
- return 1;
- }
-
- memcpy(m_HeaderCursor, Data, Bytes);
- m_HeaderCursor += Bytes;
- m_UrlLength += Bytes;
-
- return 0;
-}
-
-int
-HttpRequest::OnHeader(const char* Data, size_t Bytes)
-{
- if (m_CurrentHeaderValueLength)
- {
- AppendCurrentHeader();
-
- m_CurrentHeaderNameLength = 0;
- m_CurrentHeaderValueLength = 0;
- m_CurrentHeaderName = m_HeaderCursor;
- }
- else if (m_CurrentHeaderName == nullptr)
- {
- m_CurrentHeaderName = m_HeaderCursor;
- }
-
- const size_t RemainingBufferSpace = sizeof m_HeaderBuffer + m_HeaderBuffer - m_HeaderCursor;
- if (RemainingBufferSpace < Bytes)
- {
- ZEN_WARN("HTTP parser does not have enough space for incoming header name, need {} more bytes", Bytes - RemainingBufferSpace);
- return 1;
- }
-
- memcpy(m_HeaderCursor, Data, Bytes);
- m_HeaderCursor += Bytes;
- m_CurrentHeaderNameLength += Bytes;
-
- return 0;
-}
-
-void
-HttpRequest::AppendCurrentHeader()
-{
- std::string_view HeaderName(m_CurrentHeaderName, m_CurrentHeaderNameLength);
- std::string_view HeaderValue(m_CurrentHeaderValue, m_CurrentHeaderValueLength);
-
- const uint32_t HeaderHash = HashStringAsLowerDjb2(HeaderName);
-
- if (HeaderHash == HashContentLength)
- {
- m_ContentLengthHeaderIndex = (int8_t)m_Headers.size();
- }
- else if (HeaderHash == HashAccept)
- {
- m_AcceptHeaderIndex = (int8_t)m_Headers.size();
- }
- else if (HeaderHash == HashContentType)
- {
- m_ContentTypeHeaderIndex = (int8_t)m_Headers.size();
- }
- else if (HeaderHash == HashSession)
- {
- m_SessionId = Oid::FromHexString(HeaderValue);
- }
- else if (HeaderHash == HashRequest)
- {
- std::from_chars(HeaderValue.data(), HeaderValue.data() + HeaderValue.size(), m_RequestId);
- }
- else if (HeaderHash == HashExpect)
- {
- if (HeaderValue == "100-continue"sv)
- {
- // We don't currently do anything with this
- m_Expect100Continue = true;
- }
- else
- {
- ZEN_INFO("Unexpected expect - Expect: {}", HeaderValue);
- }
- }
- else if (HeaderHash == HashRange)
- {
- m_RangeHeaderIndex = (int8_t)m_Headers.size();
- }
-
- m_Headers.emplace_back(HeaderName, HeaderValue);
-}
-
-int
-HttpRequest::OnHeaderValue(const char* Data, size_t Bytes)
-{
- if (m_CurrentHeaderValueLength == 0)
- {
- m_CurrentHeaderValue = m_HeaderCursor;
- }
-
- const size_t RemainingBufferSpace = sizeof m_HeaderBuffer + m_HeaderBuffer - m_HeaderCursor;
- if (RemainingBufferSpace < Bytes)
- {
- ZEN_WARN("HTTP parser does not have enough space for incoming header value, need {} more bytes", Bytes - RemainingBufferSpace);
- return 1;
- }
-
- memcpy(m_HeaderCursor, Data, Bytes);
- m_HeaderCursor += Bytes;
- m_CurrentHeaderValueLength += Bytes;
-
- return 0;
-}
-
-static void
-NormalizeUrlPath(const char* Url, size_t UrlLength, std::string& NormalizedUrl)
-{
- bool LastCharWasSeparator = false;
- for (std::string_view::size_type UrlIndex = 0; UrlIndex < UrlLength; ++UrlIndex)
- {
- const char UrlChar = Url[UrlIndex];
- const bool IsSeparator = (UrlChar == '/');
-
- if (IsSeparator && LastCharWasSeparator)
- {
- if (NormalizedUrl.empty())
- {
- NormalizedUrl.reserve(UrlLength);
- NormalizedUrl.append(Url, UrlIndex);
- }
-
- if (!LastCharWasSeparator)
- {
- NormalizedUrl.push_back('/');
- }
- }
- else if (!NormalizedUrl.empty())
- {
- NormalizedUrl.push_back(UrlChar);
- }
-
- LastCharWasSeparator = IsSeparator;
- }
-}
-
-int
-HttpRequest::OnHeadersComplete()
-{
- if (m_CurrentHeaderValueLength)
- {
- AppendCurrentHeader();
- }
-
- if (m_ContentLengthHeaderIndex >= 0)
- {
- std::string_view& Value = m_Headers[m_ContentLengthHeaderIndex].Value;
- uint64_t ContentLength = 0;
- std::from_chars(Value.data(), Value.data() + Value.size(), ContentLength);
-
- if (ContentLength)
- {
- m_BodyBuffer = IoBuffer(ContentLength);
- }
-
- m_BodyBuffer.SetContentType(ContentType());
-
- m_BodyPosition = 0;
- }
-
- m_KeepAlive = !!http_should_keep_alive(&m_Parser);
-
- switch (m_Parser.method)
- {
- case HTTP_GET:
- m_RequestVerb = HttpVerb::kGet;
- break;
-
- case HTTP_POST:
- m_RequestVerb = HttpVerb::kPost;
- break;
-
- case HTTP_PUT:
- m_RequestVerb = HttpVerb::kPut;
- break;
-
- case HTTP_DELETE:
- m_RequestVerb = HttpVerb::kDelete;
- break;
-
- case HTTP_HEAD:
- m_RequestVerb = HttpVerb::kHead;
- break;
-
- case HTTP_COPY:
- m_RequestVerb = HttpVerb::kCopy;
- break;
-
- case HTTP_OPTIONS:
- m_RequestVerb = HttpVerb::kOptions;
- break;
-
- default:
- ZEN_WARN("invalid HTTP method: '{}'", http_method_str((http_method)m_Parser.method));
- break;
- }
-
- std::string_view Url(m_Url, m_UrlLength);
-
- if (auto QuerySplit = Url.find_first_of('?'); QuerySplit != std::string_view::npos)
- {
- m_UrlLength = QuerySplit;
- m_QueryString = m_Url + QuerySplit + 1;
- m_QueryLength = Url.size() - QuerySplit - 1;
- }
-
- NormalizeUrlPath(m_Url, m_UrlLength, m_NormalizedUrl);
-
- return 0;
-}
-
-int
-HttpRequest::OnBody(const char* Data, size_t Bytes)
-{
- if (m_BodyPosition + Bytes > m_BodyBuffer.Size())
- {
- ZEN_WARN("HTTP parser incoming body is larger than content size, need {} more bytes",
- (m_BodyPosition + Bytes) - m_BodyBuffer.Size());
- return 1;
- }
- memcpy(reinterpret_cast<uint8_t*>(m_BodyBuffer.MutableData()) + m_BodyPosition, Data, Bytes);
- m_BodyPosition += Bytes;
-
- if (http_body_is_final(&m_Parser))
- {
- if (m_BodyPosition != m_BodyBuffer.Size())
- {
- ZEN_WARN("Body mismatch! {} != {}", m_BodyPosition, m_BodyBuffer.Size());
- return 1;
- }
- }
-
- return 0;
-}
-
-void
-HttpRequest::ResetState()
-{
- m_HeaderCursor = m_HeaderBuffer;
- m_CurrentHeaderName = nullptr;
- m_CurrentHeaderNameLength = 0;
- m_CurrentHeaderValue = nullptr;
- m_CurrentHeaderValueLength = 0;
- m_CurrentHeaderName = nullptr;
- m_Url = nullptr;
- m_UrlLength = 0;
- m_QueryString = nullptr;
- m_QueryLength = 0;
- m_ContentLengthHeaderIndex = -1;
- m_AcceptHeaderIndex = -1;
- m_ContentTypeHeaderIndex = -1;
- m_RangeHeaderIndex = -1;
- m_Expect100Continue = false;
- m_BodyBuffer = {};
- m_BodyPosition = 0;
- m_Headers.clear();
- m_NormalizedUrl.clear();
-}
-
-int
-HttpRequest::OnMessageBegin()
-{
- return 0;
-}
-
-int
-HttpRequest::OnMessageComplete()
-{
- m_Connection.HandleRequest();
-
- ResetState();
-
- return 0;
-}
-
-//////////////////////////////////////////////////////////////////////////
-
-struct HttpAcceptor
-{
- HttpAcceptor(HttpAsioServerImpl& Server, asio::io_service& IoService, uint16_t BasePort)
- : m_Server(Server)
- , m_IoService(IoService)
- , m_Acceptor(m_IoService, asio::ip::tcp::v6())
- {
- m_Acceptor.set_option(asio::ip::v6_only(false));
-#if ZEN_PLATFORM_WINDOWS
- // Special option for Windows settings as !asio::socket_base::reuse_address is not the same as exclusive access on Windows platforms
- typedef asio::detail::socket_option::boolean<ASIO_OS_DEF(SOL_SOCKET), SO_EXCLUSIVEADDRUSE> excluse_address;
- m_Acceptor.set_option(excluse_address(true));
-#else // ZEN_PLATFORM_WINDOWS
- m_Acceptor.set_option(asio::socket_base::reuse_address(false));
-#endif // ZEN_PLATFORM_WINDOWS
-
- m_Acceptor.set_option(asio::ip::tcp::no_delay(true));
- m_Acceptor.set_option(asio::socket_base::receive_buffer_size(128 * 1024));
- m_Acceptor.set_option(asio::socket_base::send_buffer_size(256 * 1024));
-
- uint16_t EffectivePort = BasePort;
-
- asio::error_code BindErrorCode;
- m_Acceptor.bind(asio::ip::tcp::endpoint(asio::ip::address_v6::any(), EffectivePort), BindErrorCode);
- // Sharing violation implies the port is being used by another process
- for (uint16_t PortOffset = 1; (BindErrorCode == asio::error::address_in_use) && (PortOffset < 10); ++PortOffset)
- {
- EffectivePort = BasePort + (PortOffset * 100);
- m_Acceptor.bind(asio::ip::tcp::endpoint(asio::ip::address_v6::any(), EffectivePort), BindErrorCode);
- }
- if (BindErrorCode == asio::error::access_denied)
- {
- EffectivePort = 0;
- m_Acceptor.bind(asio::ip::tcp::endpoint(asio::ip::address_v6::any(), EffectivePort), BindErrorCode);
- }
- if (BindErrorCode)
- {
- ZEN_ERROR("Unable open asio service, error '{}'", BindErrorCode.message());
- }
-
-#if ZEN_PLATFORM_WINDOWS
- // On Windows, loopback connections can take advantage of a faster code path optionally with this flag.
- // This must be used by both the client and server side, and is only effective in the absence of
- // Windows Filtering Platform (WFP) callouts which can be installed by security software.
- // https://docs.microsoft.com/en-us/windows/win32/winsock/sio-loopback-fast-path
- SOCKET NativeSocket = m_Acceptor.native_handle();
- int LoopbackOptionValue = 1;
- DWORD OptionNumberOfBytesReturned = 0;
- WSAIoctl(NativeSocket,
- SIO_LOOPBACK_FAST_PATH,
- &LoopbackOptionValue,
- sizeof(LoopbackOptionValue),
- NULL,
- 0,
- &OptionNumberOfBytesReturned,
- 0,
- 0);
-#endif
- m_Acceptor.listen();
-
- ZEN_INFO("Started asio server at port '{}'", EffectivePort);
- }
-
- void Start()
- {
- m_Acceptor.listen();
- InitAccept();
- }
-
- void Stop() { m_IsStopped = true; }
-
- void InitAccept()
- {
- auto SocketPtr = std::make_unique<asio::ip::tcp::socket>(m_IoService);
- asio::ip::tcp::socket& SocketRef = *SocketPtr.get();
-
- m_Acceptor.async_accept(SocketRef, [this, Socket = std::move(SocketPtr)](const asio::error_code& Ec) mutable {
- if (Ec)
- {
- ZEN_WARN("asio async_accept, connection failed to '{}:{}' reason '{}'",
- m_Acceptor.local_endpoint().address().to_string(),
- m_Acceptor.local_endpoint().port(),
- Ec.message());
- }
- else
- {
- // New connection established, pass socket ownership into connection object
- // and initiate request handling loop. The connection lifetime is
- // managed by the async read/write loop by passing the shared
- // reference to the callbacks.
-
- Socket->set_option(asio::ip::tcp::no_delay(true));
- Socket->set_option(asio::socket_base::receive_buffer_size(128 * 1024));
- Socket->set_option(asio::socket_base::send_buffer_size(256 * 1024));
-
- auto Conn = std::make_shared<HttpServerConnection>(m_Server, std::move(Socket));
- Conn->HandleNewRequest();
- }
-
- if (!m_IsStopped.load())
- {
- InitAccept();
- }
- else
- {
- m_Acceptor.close();
- }
- });
- }
-
- int GetAcceptPort() { return m_Acceptor.local_endpoint().port(); }
-
-private:
- HttpAsioServerImpl& m_Server;
- asio::io_service& m_IoService;
- asio::ip::tcp::acceptor m_Acceptor;
- std::atomic<bool> m_IsStopped{false};
-};
-
-//////////////////////////////////////////////////////////////////////////
-
-HttpAsioServerRequest::HttpAsioServerRequest(asio_http::HttpRequest& Request, HttpService& Service, IoBuffer PayloadBuffer)
-: m_Request(Request)
-, m_PayloadBuffer(std::move(PayloadBuffer))
-{
- const int PrefixLength = Service.UriPrefixLength();
-
- std::string_view Uri = Request.Url();
- Uri.remove_prefix(std::min(PrefixLength, static_cast<int>(Uri.size())));
- m_Uri = Uri;
- m_UriWithExtension = Uri;
- m_QueryString = Request.QueryString();
-
- m_Verb = Request.RequestVerb();
- m_ContentLength = Request.Body().Size();
- m_ContentType = Request.ContentType();
-
- HttpContentType AcceptContentType = HttpContentType::kUnknownContentType;
-
- // Parse any extension, to allow requesting a particular response encoding via the URL
-
- {
- std::string_view UriSuffix8{m_Uri};
-
- const size_t LastComponentIndex = UriSuffix8.find_last_of('/');
-
- if (LastComponentIndex != std::string_view::npos)
- {
- UriSuffix8.remove_prefix(LastComponentIndex);
- }
-
- const size_t LastDotIndex = UriSuffix8.find_last_of('.');
-
- if (LastDotIndex != std::string_view::npos)
- {
- UriSuffix8.remove_prefix(LastDotIndex + 1);
-
- AcceptContentType = ParseContentType(UriSuffix8);
-
- if (AcceptContentType != HttpContentType::kUnknownContentType)
- {
- m_Uri.remove_suffix(uint32_t(UriSuffix8.size() + 1));
- }
- }
- }
-
- // It an explicit content type extension was specified then we'll use that over any
- // Accept: header value that may be present
-
- if (AcceptContentType != HttpContentType::kUnknownContentType)
- {
- m_AcceptType = AcceptContentType;
- }
- else
- {
- m_AcceptType = Request.AcceptType();
- }
-}
-
-HttpAsioServerRequest::~HttpAsioServerRequest()
-{
-}
-
-Oid
-HttpAsioServerRequest::ParseSessionId() const
-{
- return m_Request.SessionId();
-}
-
-uint32_t
-HttpAsioServerRequest::ParseRequestId() const
-{
- return m_Request.RequestId();
-}
-
-IoBuffer
-HttpAsioServerRequest::ReadPayload()
-{
- return m_PayloadBuffer;
-}
-
-void
-HttpAsioServerRequest::WriteResponse(HttpResponseCode ResponseCode)
-{
- ZEN_ASSERT(!m_Response);
-
- m_Response.reset(new HttpResponse(HttpContentType::kBinary));
- std::array<IoBuffer, 0> Empty;
-
- m_Response->InitializeForPayload((uint16_t)ResponseCode, Empty);
-}
-
-void
-HttpAsioServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span<IoBuffer> Blobs)
-{
- ZEN_ASSERT(!m_Response);
-
- m_Response.reset(new HttpResponse(ContentType));
- m_Response->InitializeForPayload((uint16_t)ResponseCode, Blobs);
-}
-
-void
-HttpAsioServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::u8string_view ResponseString)
-{
- ZEN_ASSERT(!m_Response);
- m_Response.reset(new HttpResponse(ContentType));
-
- IoBuffer MessageBuffer(IoBuffer::Wrap, ResponseString.data(), ResponseString.size());
- std::array<IoBuffer, 1> SingleBufferList({MessageBuffer});
-
- m_Response->InitializeForPayload((uint16_t)ResponseCode, SingleBufferList);
-}
-
-void
-HttpAsioServerRequest::WriteResponseAsync(std::function<void(HttpServerRequest&)>&& ContinuationHandler)
-{
- ZEN_ASSERT(!m_Response);
-
- // Not one bit async, innit
- ContinuationHandler(*this);
-}
-
-bool
-HttpAsioServerRequest::TryGetRanges(HttpRanges& Ranges)
-{
- return TryParseHttpRangeHeader(m_Request.RangeHeader(), Ranges);
-}
-
-//////////////////////////////////////////////////////////////////////////
-
-HttpAsioServerImpl::HttpAsioServerImpl()
-{
-}
-
-HttpAsioServerImpl::~HttpAsioServerImpl()
-{
-}
-
-int
-HttpAsioServerImpl::Start(uint16_t Port, int ThreadCount)
-{
- ZEN_ASSERT(ThreadCount > 0);
-
- ZEN_INFO("starting asio http with {} service threads", ThreadCount);
-
- m_Acceptor.reset(new asio_http::HttpAcceptor(*this, m_IoService, Port));
- m_Acceptor->Start();
-
- for (int i = 0; i < ThreadCount; ++i)
- {
- m_ThreadPool.emplace_back([this, Index = i + 1] {
- SetCurrentThreadName(fmt::format("asio worker {}", Index));
-
- try
- {
- m_IoService.run();
- }
- catch (std::exception& e)
- {
- ZEN_ERROR("Exception caught in asio event loop: '{}'", e.what());
- }
- });
- }
-
- ZEN_INFO("asio http started (port {})", m_Acceptor->GetAcceptPort());
-
- return m_Acceptor->GetAcceptPort();
-}
-
-void
-HttpAsioServerImpl::Stop()
-{
- m_Acceptor->Stop();
- m_IoService.stop();
- for (auto& Thread : m_ThreadPool)
- {
- Thread.join();
- }
-}
-
-void
-HttpAsioServerImpl::RegisterService(const char* InUrlPath, HttpService& Service)
-{
- std::string_view UrlPath(InUrlPath);
- Service.SetUriPrefixLength(UrlPath.size());
- if (!UrlPath.empty() && UrlPath.back() == '/')
- {
- UrlPath.remove_suffix(1);
- }
-
- RwLock::ExclusiveLockScope _(m_Lock);
- m_UriHandlers.push_back({std::string(UrlPath), &Service});
-}
-
-HttpService*
-HttpAsioServerImpl::RouteRequest(std::string_view Url)
-{
- RwLock::SharedLockScope _(m_Lock);
-
- HttpService* CandidateService = nullptr;
- std::string::size_type CandidateMatchSize = 0;
- for (const ServiceEntry& SvcEntry : m_UriHandlers)
- {
- const std::string& SvcUrl = SvcEntry.ServiceUrlPath;
- const std::string::size_type SvcUrlSize = SvcUrl.size();
- if ((SvcUrlSize >= CandidateMatchSize) && Url.compare(0, SvcUrlSize, SvcUrl) == 0 &&
- ((SvcUrlSize == Url.size()) || (Url[SvcUrlSize] == '/')))
- {
- CandidateMatchSize = SvcUrl.size();
- CandidateService = SvcEntry.Service;
- }
- }
-
- return CandidateService;
-}
-
-} // namespace zen::asio_http
-
-//////////////////////////////////////////////////////////////////////////
-
-namespace zen {
-HttpAsioServer::HttpAsioServer() : m_Impl(std::make_unique<asio_http::HttpAsioServerImpl>())
-{
- ZEN_DEBUG("Request object size: {} ({:#x})", sizeof(asio_http::HttpRequest), sizeof(asio_http::HttpRequest));
-}
-
-HttpAsioServer::~HttpAsioServer()
-{
- try
- {
- m_Impl->Stop();
- }
- catch (std::exception& ex)
- {
- ZEN_WARN("Caught exception stopping http asio server: {}", ex.what());
- }
-}
-
-void
-HttpAsioServer::RegisterService(HttpService& Service)
-{
- m_Impl->RegisterService(Service.BaseUri(), Service);
-}
-
-int
-HttpAsioServer::Initialize(int BasePort)
-{
- m_BasePort = m_Impl->Start(gsl::narrow<uint16_t>(BasePort), Max(std::thread::hardware_concurrency(), 8u));
- return m_BasePort;
-}
-
-void
-HttpAsioServer::Run(bool IsInteractive)
-{
- const bool TestMode = !IsInteractive;
-
- int WaitTimeout = -1;
- if (!TestMode)
- {
- WaitTimeout = 1000;
- }
-
-#if ZEN_PLATFORM_WINDOWS
- if (TestMode == false)
- {
- zen::logging::ConsoleLog().info("Zen Server running (asio HTTP). Press ESC or Q to quit");
- }
-
- do
- {
- if (!TestMode && _kbhit() != 0)
- {
- char c = (char)_getch();
-
- if (c == 27 || c == 'Q' || c == 'q')
- {
- RequestApplicationExit(0);
- }
- }
-
- m_ShutdownEvent.Wait(WaitTimeout);
- } while (!IsApplicationExitRequested());
-#else
- if (TestMode == false)
- {
- zen::logging::ConsoleLog().info("Zen Server running (asio HTTP). Ctrl-C to quit");
- }
-
- do
- {
- m_ShutdownEvent.Wait(WaitTimeout);
- } while (!IsApplicationExitRequested());
-#endif
-}
-
-void
-HttpAsioServer::RequestExit()
-{
- m_ShutdownEvent.Set();
-}
-
-} // namespace zen
diff --git a/zenhttp/httpasio.h b/zenhttp/httpasio.h
deleted file mode 100644
index 716145955..000000000
--- a/zenhttp/httpasio.h
+++ /dev/null
@@ -1,36 +0,0 @@
-// Copyright Epic Games, Inc. All Rights Reserved.
-
-#pragma once
-
-#include <zencore/thread.h>
-#include <zenhttp/httpserver.h>
-
-#include <memory>
-
-namespace zen {
-
-namespace asio_http {
- struct HttpServerConnection;
- struct HttpAcceptor;
- struct HttpAsioServerImpl;
-} // namespace asio_http
-
-class HttpAsioServer : public HttpServer
-{
-public:
- HttpAsioServer();
- ~HttpAsioServer();
-
- virtual void RegisterService(HttpService& Service) override;
- virtual int Initialize(int BasePort) override;
- virtual void Run(bool IsInteractiveSession) override;
- virtual void RequestExit() override;
-
-private:
- Event m_ShutdownEvent;
- int m_BasePort = 0;
-
- std::unique_ptr<asio_http::HttpAsioServerImpl> m_Impl;
-};
-
-} // namespace zen
diff --git a/zenhttp/httpclient.cpp b/zenhttp/httpclient.cpp
deleted file mode 100644
index e6813d407..000000000
--- a/zenhttp/httpclient.cpp
+++ /dev/null
@@ -1,176 +0,0 @@
-// Copyright Epic Games, Inc. All Rights Reserved.
-
-#include <zenhttp/httpclient.h>
-#include <zenhttp/httpserver.h>
-
-#include <zencore/compactbinarybuilder.h>
-#include <zencore/compactbinarypackage.h>
-#include <zencore/iobuffer.h>
-#include <zencore/logging.h>
-#include <zencore/session.h>
-#include <zencore/sharedbuffer.h>
-#include <zencore/stream.h>
-#include <zencore/testing.h>
-#include <zenhttp/httpshared.h>
-
-static std::atomic<uint32_t> HttpClientRequestIdCounter{0};
-
-namespace zen {
-
-using namespace std::literals;
-
-HttpClient::Response
-FromCprResponse(cpr::Response& InResponse)
-{
- return {.StatusCode = int(InResponse.status_code)};
-}
-
-//////////////////////////////////////////////////////////////////////////
-
-HttpClient::HttpClient(std::string_view BaseUri) : m_BaseUri(BaseUri)
-{
- StringBuilder<32> SessionId;
- GetSessionId().ToString(SessionId);
- m_SessionId = SessionId;
-}
-
-HttpClient::~HttpClient()
-{
-}
-
-HttpClient::Response
-HttpClient::TransactPackage(std::string_view Url, CbPackage Package)
-{
- cpr::Session Sess;
- Sess.SetUrl(m_BaseUri + std::string(Url));
-
- // First, list of offered chunks for filtering on the server end
-
- std::vector<IoHash> AttachmentsToSend;
- std::span<const CbAttachment> Attachments = Package.GetAttachments();
-
- const uint32_t RequestId = ++HttpClientRequestIdCounter;
- auto RequestIdString = fmt::to_string(RequestId);
-
- if (Attachments.empty() == false)
- {
- CbObjectWriter Writer;
- Writer.BeginArray("offer");
-
- for (const CbAttachment& Attachment : Attachments)
- {
- IoHash Hash = Attachment.GetHash();
-
- Writer.AddHash(Hash);
- }
-
- Writer.EndArray();
-
- BinaryWriter MemWriter;
- Writer.Save(MemWriter);
-
- Sess.SetHeader({{"Content-Type", "application/x-ue-offer"}, {"UE-Session", m_SessionId}, {"UE-Request", RequestIdString}});
- Sess.SetBody(cpr::Body{(const char*)MemWriter.Data(), MemWriter.Size()});
-
- cpr::Response FilterResponse = Sess.Post();
-
- if (FilterResponse.status_code == 200)
- {
- IoBuffer ResponseBuffer(IoBuffer::Wrap, FilterResponse.text.data(), FilterResponse.text.size());
- CbObject ResponseObject = LoadCompactBinaryObject(ResponseBuffer);
-
- for (auto& Entry : ResponseObject["need"])
- {
- ZEN_ASSERT(Entry.IsHash());
- AttachmentsToSend.push_back(Entry.AsHash());
- }
- }
- }
-
- // Prepare package for send
-
- CbPackage SendPackage;
- SendPackage.SetObject(Package.GetObject(), Package.GetObjectHash());
-
- for (const IoHash& AttachmentCid : AttachmentsToSend)
- {
- const CbAttachment* Attachment = Package.FindAttachment(AttachmentCid);
-
- if (Attachment)
- {
- SendPackage.AddAttachment(*Attachment);
- }
- else
- {
- // This should be an error -- server asked to have something we can't find
- }
- }
-
- // Transmit package payload
-
- CompositeBuffer Message = FormatPackageMessageBuffer(SendPackage);
- SharedBuffer FlatMessage = Message.Flatten();
-
- Sess.SetHeader({{"Content-Type", "application/x-ue-cbpkg"}, {"UE-Session", m_SessionId}, {"UE-Request", RequestIdString}});
- Sess.SetBody(cpr::Body{(const char*)FlatMessage.GetData(), FlatMessage.GetSize()});
-
- cpr::Response FilterResponse = Sess.Post();
-
- if (!IsHttpSuccessCode(FilterResponse.status_code))
- {
- return FromCprResponse(FilterResponse);
- }
-
- IoBuffer ResponseBuffer(IoBuffer::Clone, FilterResponse.text.data(), FilterResponse.text.size());
-
- if (auto It = FilterResponse.header.find("Content-Type"); It != FilterResponse.header.end())
- {
- HttpContentType ContentType = ParseContentType(It->second);
-
- ResponseBuffer.SetContentType(ContentType);
- }
-
- return {.StatusCode = int(FilterResponse.status_code), .ResponsePayload = ResponseBuffer};
-}
-
-HttpClient::Response
-HttpClient::Put(std::string_view Url, IoBuffer Payload)
-{
- ZEN_UNUSED(Url);
- ZEN_UNUSED(Payload);
- return {};
-}
-
-HttpClient::Response
-HttpClient::Get(std::string_view Url)
-{
- ZEN_UNUSED(Url);
- return {};
-}
-
-HttpClient::Response
-HttpClient::Delete(std::string_view Url)
-{
- ZEN_UNUSED(Url);
- return {};
-}
-
-//////////////////////////////////////////////////////////////////////////
-
-#if ZEN_WITH_TESTS
-
-TEST_CASE("httpclient")
-{
- using namespace std::literals;
-
- SUBCASE("client") {}
-}
-
-void
-httpclient_forcelink()
-{
-}
-
-#endif
-
-} // namespace zen
diff --git a/zenhttp/httpnull.cpp b/zenhttp/httpnull.cpp
deleted file mode 100644
index a6e1d3567..000000000
--- a/zenhttp/httpnull.cpp
+++ /dev/null
@@ -1,83 +0,0 @@
-// Copyright Epic Games, Inc. All Rights Reserved.
-
-#include "httpnull.h"
-
-#include <zencore/logging.h>
-
-#if ZEN_PLATFORM_WINDOWS
-# include <conio.h>
-#endif
-
-namespace zen {
-
-HttpNullServer::HttpNullServer()
-{
-}
-
-HttpNullServer::~HttpNullServer()
-{
-}
-
-void
-HttpNullServer::RegisterService(HttpService& Service)
-{
- ZEN_UNUSED(Service);
-}
-
-int
-HttpNullServer::Initialize(int BasePort)
-{
- return BasePort;
-}
-
-void
-HttpNullServer::Run(bool IsInteractiveSession)
-{
- const bool TestMode = !IsInteractiveSession;
-
- int WaitTimeout = -1;
- if (!TestMode)
- {
- WaitTimeout = 1000;
- }
-
-#if ZEN_PLATFORM_WINDOWS
- if (TestMode == false)
- {
- zen::logging::ConsoleLog().info("Zen Server running (null HTTP). Press ESC or Q to quit");
- }
-
- do
- {
- if (!TestMode && _kbhit() != 0)
- {
- char c = (char)_getch();
-
- if (c == 27 || c == 'Q' || c == 'q')
- {
- RequestApplicationExit(0);
- }
- }
-
- m_ShutdownEvent.Wait(WaitTimeout);
- } while (!IsApplicationExitRequested());
-#else
- if (TestMode == false)
- {
- zen::logging::ConsoleLog().info("Zen Server running (null HTTP). Ctrl-C to quit");
- }
-
- do
- {
- m_ShutdownEvent.Wait(WaitTimeout);
- } while (!IsApplicationExitRequested());
-#endif
-}
-
-void
-HttpNullServer::RequestExit()
-{
- m_ShutdownEvent.Set();
-}
-
-} // namespace zen
diff --git a/zenhttp/httpnull.h b/zenhttp/httpnull.h
deleted file mode 100644
index 74f021f6b..000000000
--- a/zenhttp/httpnull.h
+++ /dev/null
@@ -1,29 +0,0 @@
-// Copyright Epic Games, Inc. All Rights Reserved.
-
-#pragma once
-
-#include <zencore/thread.h>
-#include <zenhttp/httpserver.h>
-
-namespace zen {
-
-/**
- * @brief Null implementation of "http" server. Does nothing
- */
-
-class HttpNullServer : public HttpServer
-{
-public:
- HttpNullServer();
- ~HttpNullServer();
-
- virtual void RegisterService(HttpService& Service) override;
- virtual int Initialize(int BasePort) override;
- virtual void Run(bool IsInteractiveSession) override;
- virtual void RequestExit() override;
-
-private:
- Event m_ShutdownEvent;
-};
-
-} // namespace zen
diff --git a/zenhttp/httpserver.cpp b/zenhttp/httpserver.cpp
deleted file mode 100644
index 671cbd319..000000000
--- a/zenhttp/httpserver.cpp
+++ /dev/null
@@ -1,885 +0,0 @@
-// Copyright Epic Games, Inc. All Rights Reserved.
-
-#include <zenhttp/httpserver.h>
-
-#include "httpasio.h"
-#include "httpnull.h"
-#include "httpsys.h"
-
-#include <zencore/compactbinary.h>
-#include <zencore/compactbinarybuilder.h>
-#include <zencore/compactbinarypackage.h>
-#include <zencore/iobuffer.h>
-#include <zencore/logging.h>
-#include <zencore/refcount.h>
-#include <zencore/stream.h>
-#include <zencore/string.h>
-#include <zencore/testing.h>
-#include <zencore/thread.h>
-#include <zenhttp/httpshared.h>
-
-#include <charconv>
-#include <mutex>
-#include <span>
-#include <string_view>
-
-namespace zen {
-
-using namespace std::literals;
-
-std::string_view
-MapContentTypeToString(HttpContentType ContentType)
-{
- switch (ContentType)
- {
- default:
- case HttpContentType::kUnknownContentType:
- case HttpContentType::kBinary:
- return "application/octet-stream"sv;
-
- case HttpContentType::kText:
- return "text/plain"sv;
-
- case HttpContentType::kJSON:
- return "application/json"sv;
-
- case HttpContentType::kCbObject:
- return "application/x-ue-cb"sv;
-
- case HttpContentType::kCbPackage:
- return "application/x-ue-cbpkg"sv;
-
- case HttpContentType::kCbPackageOffer:
- return "application/x-ue-offer"sv;
-
- case HttpContentType::kCompressedBinary:
- return "application/x-ue-comp"sv;
-
- case HttpContentType::kYAML:
- return "text/yaml"sv;
-
- case HttpContentType::kHTML:
- return "text/html"sv;
-
- case HttpContentType::kJavaScript:
- return "application/javascript"sv;
-
- case HttpContentType::kCSS:
- return "text/css"sv;
-
- case HttpContentType::kPNG:
- return "image/png"sv;
-
- case HttpContentType::kIcon:
- return "image/x-icon"sv;
- }
-}
-
-//////////////////////////////////////////////////////////////////////////
-//
-// Note that in addition to MIME types we accept abbreviated versions, for
-// use in suffix parsing as well as for convenience when using curl
-
-static constinit uint32_t HashBinary = HashStringDjb2("application/octet-stream"sv);
-static constinit uint32_t HashJson = HashStringDjb2("json"sv);
-static constinit uint32_t HashApplicationJson = HashStringDjb2("application/json"sv);
-static constinit uint32_t HashYaml = HashStringDjb2("yaml"sv);
-static constinit uint32_t HashTextYaml = HashStringDjb2("text/yaml"sv);
-static constinit uint32_t HashText = HashStringDjb2("text/plain"sv);
-static constinit uint32_t HashApplicationCompactBinary = HashStringDjb2("application/x-ue-cb"sv);
-static constinit uint32_t HashCompactBinary = HashStringDjb2("ucb"sv);
-static constinit uint32_t HashCompactBinaryPackage = HashStringDjb2("application/x-ue-cbpkg"sv);
-static constinit uint32_t HashCompactBinaryPackageShort = HashStringDjb2("cbpkg"sv);
-static constinit uint32_t HashCompactBinaryPackageOffer = HashStringDjb2("application/x-ue-offer"sv);
-static constinit uint32_t HashCompressedBinary = HashStringDjb2("application/x-ue-comp"sv);
-static constinit uint32_t HashHtml = HashStringDjb2("html"sv);
-static constinit uint32_t HashTextHtml = HashStringDjb2("text/html"sv);
-static constinit uint32_t HashJavaScript = HashStringDjb2("js"sv);
-static constinit uint32_t HashApplicationJavaScript = HashStringDjb2("application/javascript"sv);
-static constinit uint32_t HashCss = HashStringDjb2("css"sv);
-static constinit uint32_t HashTextCss = HashStringDjb2("text/css"sv);
-static constinit uint32_t HashPng = HashStringDjb2("png"sv);
-static constinit uint32_t HashImagePng = HashStringDjb2("image/png"sv);
-static constinit uint32_t HashIcon = HashStringDjb2("ico"sv);
-static constinit uint32_t HashImageIcon = HashStringDjb2("image/x-icon"sv);
-
-std::once_flag InitContentTypeLookup;
-
-struct HashedTypeEntry
-{
- uint32_t Hash;
- HttpContentType Type;
-} TypeHashTable[] = {
- // clang-format off
- {HashBinary, HttpContentType::kBinary},
- {HashApplicationCompactBinary, HttpContentType::kCbObject},
- {HashCompactBinary, HttpContentType::kCbObject},
- {HashCompactBinaryPackage, HttpContentType::kCbPackage},
- {HashCompactBinaryPackageShort, HttpContentType::kCbPackage},
- {HashCompactBinaryPackageOffer, HttpContentType::kCbPackageOffer},
- {HashJson, HttpContentType::kJSON},
- {HashApplicationJson, HttpContentType::kJSON},
- {HashYaml, HttpContentType::kYAML},
- {HashTextYaml, HttpContentType::kYAML},
- {HashText, HttpContentType::kText},
- {HashCompressedBinary, HttpContentType::kCompressedBinary},
- {HashHtml, HttpContentType::kHTML},
- {HashTextHtml, HttpContentType::kHTML},
- {HashJavaScript, HttpContentType::kJavaScript},
- {HashApplicationJavaScript, HttpContentType::kJavaScript},
- {HashCss, HttpContentType::kCSS},
- {HashTextCss, HttpContentType::kCSS},
- {HashPng, HttpContentType::kPNG},
- {HashImagePng, HttpContentType::kPNG},
- {HashIcon, HttpContentType::kIcon},
- {HashImageIcon, HttpContentType::kIcon},
- // clang-format on
-};
-
-HttpContentType
-ParseContentTypeImpl(const std::string_view& ContentTypeString)
-{
- if (!ContentTypeString.empty())
- {
- const uint32_t CtHash = HashStringDjb2(ContentTypeString);
-
- if (auto It = std::lower_bound(std::begin(TypeHashTable),
- std::end(TypeHashTable),
- CtHash,
- [](const HashedTypeEntry& Lhs, const uint32_t Rhs) { return Lhs.Hash < Rhs; });
- It != std::end(TypeHashTable))
- {
- if (It->Hash == CtHash)
- {
- return It->Type;
- }
- }
- }
-
- return HttpContentType::kUnknownContentType;
-}
-
-HttpContentType
-ParseContentTypeInit(const std::string_view& ContentTypeString)
-{
- std::call_once(InitContentTypeLookup, [] {
- std::sort(std::begin(TypeHashTable), std::end(TypeHashTable), [](const HashedTypeEntry& Lhs, const HashedTypeEntry& Rhs) {
- return Lhs.Hash < Rhs.Hash;
- });
-
- // validate that there are no hash collisions
-
- uint32_t LastHash = 0;
-
- for (const auto& Item : TypeHashTable)
- {
- ZEN_ASSERT(LastHash != Item.Hash);
- LastHash = Item.Hash;
- }
- });
-
- ParseContentType = ParseContentTypeImpl;
-
- return ParseContentTypeImpl(ContentTypeString);
-}
-
-HttpContentType (*ParseContentType)(const std::string_view& ContentTypeString) = &ParseContentTypeInit;
-
-bool
-TryParseHttpRangeHeader(std::string_view RangeHeader, HttpRanges& Ranges)
-{
- if (RangeHeader.empty())
- {
- return false;
- }
-
- const size_t Count = Ranges.size();
-
- std::size_t UnitDelim = RangeHeader.find_first_of('=');
- if (UnitDelim == std::string_view::npos)
- {
- return false;
- }
-
- // only bytes for now
- std::string_view Unit = RangeHeader.substr(0, UnitDelim);
- if (Unit != "bytes"sv)
- {
- return false;
- }
-
- std::string_view Tokens = RangeHeader.substr(UnitDelim);
- while (!Tokens.empty())
- {
- // Skip =,
- Tokens = Tokens.substr(1);
-
- size_t Delim = Tokens.find_first_of(',');
- if (Delim == std::string_view::npos)
- {
- Delim = Tokens.length();
- }
-
- std::string_view Token = Tokens.substr(0, Delim);
- Tokens = Tokens.substr(Delim);
-
- Delim = Token.find_first_of('-');
- if (Delim == std::string_view::npos)
- {
- return false;
- }
-
- const auto Start = ParseInt<uint32_t>(Token.substr(0, Delim));
- const auto End = ParseInt<uint32_t>(Token.substr(Delim + 1));
-
- if (Start.has_value() && End.has_value() && End.value() > Start.value())
- {
- Ranges.push_back({.Start = Start.value(), .End = End.value()});
- }
- else if (Start)
- {
- Ranges.push_back({.Start = Start.value()});
- }
- else if (End)
- {
- Ranges.push_back({.End = End.value()});
- }
- }
-
- return Count != Ranges.size();
-}
-
-//////////////////////////////////////////////////////////////////////////
-
-const std::string_view
-ToString(HttpVerb Verb)
-{
- switch (Verb)
- {
- case HttpVerb::kGet:
- return "GET"sv;
- case HttpVerb::kPut:
- return "PUT"sv;
- case HttpVerb::kPost:
- return "POST"sv;
- case HttpVerb::kDelete:
- return "DELETE"sv;
- case HttpVerb::kHead:
- return "HEAD"sv;
- case HttpVerb::kCopy:
- return "COPY"sv;
- case HttpVerb::kOptions:
- return "OPTIONS"sv;
- default:
- return "???"sv;
- }
-}
-
-std::string_view
-ReasonStringForHttpResultCode(int HttpCode)
-{
- switch (HttpCode)
- {
- // 1xx Informational
-
- case 100:
- return "Continue"sv;
- case 101:
- return "Switching Protocols"sv;
-
- // 2xx Success
-
- case 200:
- return "OK"sv;
- case 201:
- return "Created"sv;
- case 202:
- return "Accepted"sv;
- case 204:
- return "No Content"sv;
- case 205:
- return "Reset Content"sv;
- case 206:
- return "Partial Content"sv;
-
- // 3xx Redirection
-
- case 300:
- return "Multiple Choices"sv;
- case 301:
- return "Moved Permanently"sv;
- case 302:
- return "Found"sv;
- case 303:
- return "See Other"sv;
- case 304:
- return "Not Modified"sv;
- case 305:
- return "Use Proxy"sv;
- case 306:
- return "Switch Proxy"sv;
- case 307:
- return "Temporary Redirect"sv;
- case 308:
- return "Permanent Redirect"sv;
-
- // 4xx Client errors
-
- case 400:
- return "Bad Request"sv;
- case 401:
- return "Unauthorized"sv;
- case 402:
- return "Payment Required"sv;
- case 403:
- return "Forbidden"sv;
- case 404:
- return "Not Found"sv;
- case 405:
- return "Method Not Allowed"sv;
- case 406:
- return "Not Acceptable"sv;
- case 407:
- return "Proxy Authentication Required"sv;
- case 408:
- return "Request Timeout"sv;
- case 409:
- return "Conflict"sv;
- case 410:
- return "Gone"sv;
- case 411:
- return "Length Required"sv;
- case 412:
- return "Precondition Failed"sv;
- case 413:
- return "Payload Too Large"sv;
- case 414:
- return "URI Too Long"sv;
- case 415:
- return "Unsupported Media Type"sv;
- case 416:
- return "Range Not Satisifiable"sv;
- case 417:
- return "Expectation Failed"sv;
- case 418:
- return "I'm a teapot"sv;
- case 421:
- return "Misdirected Request"sv;
- case 422:
- return "Unprocessable Entity"sv;
- case 423:
- return "Locked"sv;
- case 424:
- return "Failed Dependency"sv;
- case 425:
- return "Too Early"sv;
- case 426:
- return "Upgrade Required"sv;
- case 428:
- return "Precondition Required"sv;
- case 429:
- return "Too Many Requests"sv;
- case 431:
- return "Request Header Fields Too Large"sv;
-
- // 5xx Server errors
-
- case 500:
- return "Internal Server Error"sv;
- case 501:
- return "Not Implemented"sv;
- case 502:
- return "Bad Gateway"sv;
- case 503:
- return "Service Unavailable"sv;
- case 504:
- return "Gateway Timeout"sv;
- case 505:
- return "HTTP Version Not Supported"sv;
- case 506:
- return "Variant Also Negotiates"sv;
- case 507:
- return "Insufficient Storage"sv;
- case 508:
- return "Loop Detected"sv;
- case 510:
- return "Not Extended"sv;
- case 511:
- return "Network Authentication Required"sv;
-
- default:
- return "Unknown Result"sv;
- }
-}
-
-//////////////////////////////////////////////////////////////////////////
-
-Ref<IHttpPackageHandler>
-HttpService::HandlePackageRequest(HttpServerRequest& HttpServiceRequest)
-{
- ZEN_UNUSED(HttpServiceRequest);
-
- return Ref<IHttpPackageHandler>();
-}
-
-//////////////////////////////////////////////////////////////////////////
-
-HttpServerRequest::HttpServerRequest()
-{
-}
-
-HttpServerRequest::~HttpServerRequest()
-{
-}
-
-void
-HttpServerRequest::WriteResponse(HttpResponseCode ResponseCode, CbPackage Data)
-{
- std::vector<IoBuffer> ResponseBuffers = FormatPackageMessage(Data);
- return WriteResponse(ResponseCode, HttpContentType::kCbPackage, ResponseBuffers);
-}
-
-void
-HttpServerRequest::WriteResponse(HttpResponseCode ResponseCode, CbObject Data)
-{
- if (m_AcceptType == HttpContentType::kJSON)
- {
- ExtendableStringBuilder<1024> Sb;
- WriteResponse(ResponseCode, HttpContentType::kJSON, Data.ToJson(Sb).ToView());
- }
- else
- {
- SharedBuffer Buf = Data.GetBuffer();
- std::array<IoBuffer, 1> Buffers{IoBufferBuilder::MakeCloneFromMemory(Buf.GetData(), Buf.GetSize())};
- return WriteResponse(ResponseCode, HttpContentType::kCbObject, Buffers);
- }
-}
-
-void
-HttpServerRequest::WriteResponse(HttpResponseCode ResponseCode, CbArray Array)
-{
- if (m_AcceptType == HttpContentType::kJSON)
- {
- ExtendableStringBuilder<1024> Sb;
- WriteResponse(ResponseCode, HttpContentType::kJSON, Array.ToJson(Sb).ToView());
- }
- else
- {
- SharedBuffer Buf = Array.GetBuffer();
- std::array<IoBuffer, 1> Buffers{IoBufferBuilder::MakeCloneFromMemory(Buf.GetData(), Buf.GetSize())};
- return WriteResponse(ResponseCode, HttpContentType::kCbObject, Buffers);
- }
-}
-
-void
-HttpServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::string_view ResponseString)
-{
- return WriteResponse(ResponseCode, ContentType, std::u8string_view{(char8_t*)ResponseString.data(), ResponseString.size()});
-}
-
-void
-HttpServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, IoBuffer Blob)
-{
- std::array<IoBuffer, 1> Buffers{Blob};
- return WriteResponse(ResponseCode, ContentType, Buffers);
-}
-
-void
-HttpServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, CompositeBuffer& Payload)
-{
- std::span<const SharedBuffer> Segments = Payload.GetSegments();
-
- std::vector<IoBuffer> Buffers;
-
- for (auto& Segment : Segments)
- {
- Buffers.push_back(Segment.AsIoBuffer());
- }
-
- WriteResponse(ResponseCode, ContentType, Buffers);
-}
-
-HttpServerRequest::QueryParams
-HttpServerRequest::GetQueryParams()
-{
- QueryParams Params;
-
- const std::string_view QStr = QueryString();
-
- const char* QueryIt = QStr.data();
- const char* QueryEnd = QueryIt + QStr.size();
-
- while (QueryIt != QueryEnd)
- {
- if (*QueryIt == '&')
- {
- ++QueryIt;
- continue;
- }
-
- size_t QueryLen = ptrdiff_t(QueryEnd - QueryIt);
- const std::string_view Query{QueryIt, QueryLen};
-
- size_t DelimIndex = Query.find('&', 0);
-
- if (DelimIndex == std::string_view::npos)
- {
- DelimIndex = Query.size();
- }
-
- std::string_view ThisQuery{QueryIt, DelimIndex};
-
- size_t EqIndex = ThisQuery.find('=', 0);
-
- if (EqIndex != std::string_view::npos)
- {
- std::string_view Param{ThisQuery.data(), EqIndex};
- ThisQuery.remove_prefix(EqIndex + 1);
-
- Params.KvPairs.emplace_back(Param, ThisQuery);
- }
-
- QueryIt += DelimIndex;
- }
-
- return Params;
-}
-
-Oid
-HttpServerRequest::SessionId() const
-{
- if (m_Flags & kHaveSessionId)
- {
- return m_SessionId;
- }
-
- m_SessionId = ParseSessionId();
- m_Flags |= kHaveSessionId;
- return m_SessionId;
-}
-
-uint32_t
-HttpServerRequest::RequestId() const
-{
- if (m_Flags & kHaveRequestId)
- {
- return m_RequestId;
- }
-
- m_RequestId = ParseRequestId();
- m_Flags |= kHaveRequestId;
- return m_RequestId;
-}
-
-CbObject
-HttpServerRequest::ReadPayloadObject()
-{
- if (IoBuffer Payload = ReadPayload())
- {
- return LoadCompactBinaryObject(std::move(Payload));
- }
-
- return {};
-}
-
-CbPackage
-HttpServerRequest::ReadPayloadPackage()
-{
- if (IoBuffer Payload = ReadPayload())
- {
- return ParsePackageMessage(std::move(Payload));
- }
-
- return {};
-}
-
-//////////////////////////////////////////////////////////////////////////
-
-void
-HttpRequestRouter::AddPattern(const char* Id, const char* Regex)
-{
- ZEN_ASSERT(m_PatternMap.find(Id) == m_PatternMap.end());
-
- m_PatternMap.insert({Id, Regex});
-}
-
-void
-HttpRequestRouter::RegisterRoute(const char* Regex, HttpRequestRouter::HandlerFunc_t&& HandlerFunc, HttpVerb SupportedVerbs)
-{
- ExtendableStringBuilder<128> ExpandedRegex;
- ProcessRegexSubstitutions(Regex, ExpandedRegex);
-
- m_Handlers.emplace_back(ExpandedRegex.c_str(), SupportedVerbs, std::move(HandlerFunc), Regex);
-}
-
-void
-HttpRequestRouter::ProcessRegexSubstitutions(const char* Regex, StringBuilderBase& OutExpandedRegex)
-{
- size_t RegexLen = strlen(Regex);
-
- for (size_t i = 0; i < RegexLen;)
- {
- bool matched = false;
-
- if (Regex[i] == '{' && ((i == 0) || (Regex[i - 1] != '\\')))
- {
- // Might have a pattern reference - find closing brace
-
- for (size_t j = i + 1; j < RegexLen; ++j)
- {
- if (Regex[j] == '}')
- {
- std::string Pattern(&Regex[i + 1], j - i - 1);
-
- if (auto it = m_PatternMap.find(Pattern); it != m_PatternMap.end())
- {
- OutExpandedRegex.Append(it->second.c_str());
- }
- else
- {
- // Default to anything goes (or should this just be an error?)
-
- OutExpandedRegex.Append("(.+?)");
- }
-
- // skip ahead
- i = j + 1;
-
- matched = true;
-
- break;
- }
- }
- }
-
- if (!matched)
- {
- OutExpandedRegex.Append(Regex[i++]);
- }
- }
-}
-
-bool
-HttpRequestRouter::HandleRequest(zen::HttpServerRequest& Request)
-{
- const HttpVerb Verb = Request.RequestVerb();
-
- std::string_view Uri = Request.RelativeUri();
- HttpRouterRequest RouterRequest(Request);
-
- for (const auto& Handler : m_Handlers)
- {
- if ((Handler.Verbs & Verb) == Verb && regex_match(begin(Uri), end(Uri), RouterRequest.m_Match, Handler.RegEx))
- {
- Handler.Handler(RouterRequest);
-
- return true; // Route matched
- }
- }
-
- return false; // No route matched
-}
-
-//////////////////////////////////////////////////////////////////////////
-
-HttpRpcHandler::HttpRpcHandler()
-{
-}
-
-HttpRpcHandler::~HttpRpcHandler()
-{
-}
-
-void
-HttpRpcHandler::AddRpc(std::string_view RpcId, std::function<void(CbObject& RpcArgs)> HandlerFunction)
-{
- ZEN_UNUSED(RpcId, HandlerFunction);
-}
-
-//////////////////////////////////////////////////////////////////////////
-
-enum class HttpServerClass
-{
- kHttpAsio,
- kHttpSys,
- kHttpNull
-};
-
-// Implemented in httpsys.cpp
-Ref<HttpServer> CreateHttpSysServer(int Concurrency, int BackgroundWorkerThreads);
-
-Ref<HttpServer>
-CreateHttpServer(std::string_view ServerClass)
-{
- using namespace std::literals;
-
- HttpServerClass Class = HttpServerClass::kHttpNull;
-
-#if ZEN_WITH_HTTPSYS
- Class = HttpServerClass::kHttpSys;
-#elif 1
- Class = HttpServerClass::kHttpAsio;
-#endif
-
- if (ServerClass == "asio"sv)
- {
- Class = HttpServerClass::kHttpAsio;
- }
- else if (ServerClass == "httpsys"sv)
- {
- Class = HttpServerClass::kHttpSys;
- }
- else if (ServerClass == "null"sv)
- {
- Class = HttpServerClass::kHttpNull;
- }
-
- switch (Class)
- {
- default:
- case HttpServerClass::kHttpAsio:
- ZEN_INFO("using asio HTTP server implementation");
- return Ref<HttpServer>(new HttpAsioServer());
-
-#if ZEN_WITH_HTTPSYS
- case HttpServerClass::kHttpSys:
- ZEN_INFO("using http.sys server implementation");
- return Ref<HttpServer>(new HttpSysServer(std::thread::hardware_concurrency(), /* background worker threads */ 16));
-#endif
-
- case HttpServerClass::kHttpNull:
- ZEN_INFO("using null HTTP server implementation");
- return Ref<HttpServer>(new HttpNullServer);
- }
-}
-
-//////////////////////////////////////////////////////////////////////////
-
-bool
-HandlePackageOffers(HttpService& Service, HttpServerRequest& Request, Ref<IHttpPackageHandler>& PackageHandlerRef)
-{
- if (Request.RequestVerb() == HttpVerb::kPost)
- {
- if (Request.RequestContentType() == HttpContentType::kCbPackageOffer)
- {
- // The client is presenting us with a package attachments offer, we need
- // to filter it down to the list of attachments we need them to send in
- // the follow-up request
-
- PackageHandlerRef = Service.HandlePackageRequest(Request);
-
- if (PackageHandlerRef)
- {
- CbObject OfferMessage = LoadCompactBinaryObject(Request.ReadPayload());
-
- std::vector<IoHash> OfferCids;
-
- for (auto& CidEntry : OfferMessage["offer"])
- {
- if (!CidEntry.IsHash())
- {
- // Should yield bad request response?
-
- ZEN_WARN("found invalid entry in offer");
-
- continue;
- }
-
- OfferCids.push_back(CidEntry.AsHash());
- }
-
- ZEN_TRACE("request #{} -> filtering offer of {} entries", Request.RequestId(), OfferCids.size());
-
- PackageHandlerRef->FilterOffer(OfferCids);
-
- ZEN_TRACE("request #{} -> filtered to {} entries", Request.RequestId(), OfferCids.size());
-
- CbObjectWriter ResponseWriter;
- ResponseWriter.BeginArray("need");
-
- for (const IoHash& Cid : OfferCids)
- {
- ResponseWriter.AddHash(Cid);
- }
-
- ResponseWriter.EndArray();
-
- // Emit filter response
- Request.WriteResponse(HttpResponseCode::OK, ResponseWriter.Save());
- return true;
- }
- }
- else if (Request.RequestContentType() == HttpContentType::kCbPackage)
- {
- // Process chunks in package request
-
- PackageHandlerRef = Service.HandlePackageRequest(Request);
-
- // TODO: this should really be done in a streaming fashion, currently this emulates
- // the intended flow from an API perspective
-
- if (PackageHandlerRef)
- {
- PackageHandlerRef->OnRequestBegin();
-
- auto CreateBuffer = [&](const IoHash& Cid, uint64_t Size) -> IoBuffer {
- return PackageHandlerRef->CreateTarget(Cid, Size);
- };
-
- CbPackage Package = ParsePackageMessage(Request.ReadPayload(), CreateBuffer);
-
- PackageHandlerRef->OnRequestComplete();
- }
- }
- }
- return false;
-}
-
-//////////////////////////////////////////////////////////////////////////
-
-#if ZEN_WITH_TESTS
-
-TEST_CASE("http.common")
-{
- using namespace std::literals;
-
- SUBCASE("router")
- {
- HttpRequestRouter r;
- r.AddPattern("a", "[[:alpha:]]+");
- r.RegisterRoute(
- "{a}",
- [&](auto) {},
- HttpVerb::kGet);
-
- // struct TestHttpServerRequest : public HttpServerRequest
- //{
- // TestHttpServerRequest(std::string_view Uri) : m_uri{Uri} {}
- //};
-
- // TestHttpServerRequest req{};
- // r.HandleRequest(req);
- }
-
- SUBCASE("content-type")
- {
- for (uint8_t i = 0; i < uint8_t(HttpContentType::kCOUNT); ++i)
- {
- HttpContentType Ct{i};
-
- if (Ct != HttpContentType::kUnknownContentType)
- {
- CHECK_EQ(Ct, ParseContentType(MapContentTypeToString(Ct)));
- }
- }
- }
-}
-
-void
-http_forcelink()
-{
-}
-
-#endif
-
-} // namespace zen
diff --git a/zenhttp/httpshared.cpp b/zenhttp/httpshared.cpp
deleted file mode 100644
index 7aade56d2..000000000
--- a/zenhttp/httpshared.cpp
+++ /dev/null
@@ -1,809 +0,0 @@
-// Copyright Epic Games, Inc. All Rights Reserved.
-
-#include <zenhttp/httpshared.h>
-
-#include <zencore/compactbinarybuilder.h>
-#include <zencore/compactbinarypackage.h>
-#include <zencore/compositebuffer.h>
-#include <zencore/filesystem.h>
-#include <zencore/fmtutils.h>
-#include <zencore/iobuffer.h>
-#include <zencore/iohash.h>
-#include <zencore/logging.h>
-#include <zencore/scopeguard.h>
-#include <zencore/stream.h>
-#include <zencore/testing.h>
-#include <zencore/testutils.h>
-
-#include <span>
-#include <vector>
-
-ZEN_THIRD_PARTY_INCLUDES_START
-#include <tsl/robin_map.h>
-ZEN_THIRD_PARTY_INCLUDES_END
-
-namespace zen {
-
-const std::string_view HandlePrefix(":?#:");
-
-std::vector<IoBuffer>
-FormatPackageMessage(const CbPackage& Data, int TargetProcessPid)
-{
- return FormatPackageMessage(Data, FormatFlags::kDefault, TargetProcessPid);
-}
-CompositeBuffer
-FormatPackageMessageBuffer(const CbPackage& Data, int TargetProcessPid)
-{
- return FormatPackageMessageBuffer(Data, FormatFlags::kDefault, TargetProcessPid);
-}
-
-CompositeBuffer
-FormatPackageMessageBuffer(const CbPackage& Data, FormatFlags Flags, int TargetProcessPid)
-{
- std::vector<IoBuffer> Message = FormatPackageMessage(Data, Flags, TargetProcessPid);
-
- std::vector<SharedBuffer> Buffers;
-
- for (IoBuffer& Buf : Message)
- {
- Buffers.push_back(SharedBuffer(Buf));
- }
-
- return CompositeBuffer(std::move(Buffers));
-}
-
-std::vector<IoBuffer>
-FormatPackageMessage(const CbPackage& Data, FormatFlags Flags, int TargetProcessPid)
-{
- void* TargetProcessHandle = nullptr;
-#if ZEN_PLATFORM_WINDOWS
- std::vector<HANDLE> DuplicatedHandles;
- auto _ = MakeGuard([&DuplicatedHandles, &TargetProcessHandle]() {
- if (TargetProcessHandle == nullptr)
- {
- return;
- }
-
- for (HANDLE DuplicatedHandle : DuplicatedHandles)
- {
- HANDLE ClosingHandle;
- if (::DuplicateHandle((HANDLE)TargetProcessHandle,
- DuplicatedHandle,
- GetCurrentProcess(),
- &ClosingHandle,
- 0,
- FALSE,
- DUPLICATE_CLOSE_SOURCE | DUPLICATE_SAME_ACCESS) == TRUE)
- {
- ::CloseHandle(ClosingHandle);
- }
- }
- ::CloseHandle((HANDLE)TargetProcessHandle);
- TargetProcessHandle = nullptr;
- });
-
- if (EnumHasAllFlags(Flags, FormatFlags::kAllowLocalReferences) && TargetProcessPid != 0)
- {
- TargetProcessHandle = OpenProcess(PROCESS_DUP_HANDLE, FALSE, TargetProcessPid);
- }
-#else
- ZEN_UNUSED(TargetProcessPid);
- void* DuplicatedHandles = nullptr;
-#endif // ZEN_PLATFORM_WINDOWS
-
- const std::span<const CbAttachment>& Attachments = Data.GetAttachments();
- std::vector<IoBuffer> ResponseBuffers;
-
- ResponseBuffers.reserve(3 + Attachments.size()); // TODO: may want to use an additional fudge factor here to avoid growing since each
- // attachment is likely to consist of several buffers
-
- // Fixed size header
-
- CbPackageHeader Hdr{.HeaderMagic = kCbPkgMagic, .AttachmentCount = gsl::narrow<uint32_t>(Attachments.size())};
-
- ResponseBuffers.push_back(IoBufferBuilder::MakeCloneFromMemory(&Hdr, sizeof Hdr));
-
- // Attachment metadata array
-
- IoBuffer AttachmentMetadataBuffer = IoBuffer{sizeof(CbAttachmentEntry) * (Attachments.size() + /* root */ 1)};
- CbAttachmentEntry* AttachmentInfo = reinterpret_cast<CbAttachmentEntry*>(AttachmentMetadataBuffer.MutableData());
-
- ResponseBuffers.push_back(AttachmentMetadataBuffer); // Attachment metadata
-
- // Root object
-
- IoBuffer RootIoBuffer = Data.GetObject().GetBuffer().AsIoBuffer();
- ResponseBuffers.push_back(RootIoBuffer); // Root object
-
- *AttachmentInfo++ = {.PayloadSize = RootIoBuffer.Size(), .Flags = CbAttachmentEntry::kIsObject, .AttachmentHash = Data.GetObjectHash()};
-
- // Attachment payloads
-
- auto MarshalLocal = [&AttachmentInfo, &ResponseBuffers](const std::string& Path8,
- CbAttachmentReferenceHeader& LocalRef,
- const IoHash& AttachmentHash,
- bool IsCompressed) {
- IoBuffer RefBuffer(sizeof(CbAttachmentReferenceHeader) + Path8.size());
-
- CbAttachmentReferenceHeader* RefHdr = RefBuffer.MutableData<CbAttachmentReferenceHeader>();
- *RefHdr++ = LocalRef;
- memcpy(RefHdr, Path8.data(), Path8.size());
-
- *AttachmentInfo++ = {.PayloadSize = RefBuffer.GetSize(),
- .Flags = (IsCompressed ? uint32_t(CbAttachmentEntry::kIsCompressed) : 0u) | CbAttachmentEntry::kIsLocalRef,
- .AttachmentHash = AttachmentHash};
-
- ResponseBuffers.push_back(std::move(RefBuffer));
- };
-
- tsl::robin_map<void*, std::string> FileNameMap;
-
- auto IsLocalRef = [&FileNameMap, &DuplicatedHandles](const CompositeBuffer& AttachmentBinary,
- bool DenyPartialLocalReferences,
- void* TargetProcessHandle,
- CbAttachmentReferenceHeader& LocalRef,
- std::string& Path8) -> bool {
- const SharedBuffer& Segment = AttachmentBinary.GetSegments().front();
- IoBufferFileReference Ref;
- const IoBuffer& SegmentBuffer = Segment.AsIoBuffer();
-
- if (!SegmentBuffer.GetFileReference(Ref))
- {
- return false;
- }
-
- if (DenyPartialLocalReferences && !SegmentBuffer.IsWholeFile())
- {
- return false;
- }
-
- if (auto It = FileNameMap.find(Ref.FileHandle); It != FileNameMap.end())
- {
- Path8 = It->second;
- }
- else
- {
- bool UseFilePath = true;
-#if ZEN_PLATFORM_WINDOWS
- if (TargetProcessHandle != nullptr)
- {
- HANDLE TargetHandle = INVALID_HANDLE_VALUE;
- BOOL OK = ::DuplicateHandle(GetCurrentProcess(),
- Ref.FileHandle,
- (HANDLE)TargetProcessHandle,
- &TargetHandle,
- FILE_GENERIC_READ,
- FALSE,
- 0);
- if (OK)
- {
- DuplicatedHandles.push_back(TargetHandle);
- Path8 = fmt::format("{}{}", HandlePrefix, reinterpret_cast<uint64_t>(TargetHandle));
- UseFilePath = false;
- }
- }
-#else // ZEN_PLATFORM_WINDOWS
- ZEN_UNUSED(TargetProcessHandle);
- // Not supported on Linux/Mac. Could potentially use pidfd_getfd() but that requires a fairly new Linux kernel/includes and to
- // deal with acceess rights etc.
-#endif // ZEN_PLATFORM_WINDOWS
- if (UseFilePath)
- {
- ExtendablePathBuilder<256> LocalRefFile;
- LocalRefFile.Append(std::filesystem::absolute(PathFromHandle(Ref.FileHandle)));
- Path8 = LocalRefFile.ToUtf8();
- }
- FileNameMap.insert_or_assign(Ref.FileHandle, Path8);
- }
-
- LocalRef.AbsolutePathLength = gsl::narrow<uint16_t>(Path8.size());
- LocalRef.PayloadByteOffset = Ref.FileChunkOffset;
- LocalRef.PayloadByteSize = Ref.FileChunkSize;
-
- return true;
- };
-
- for (const CbAttachment& Attachment : Attachments)
- {
- if (Attachment.IsNull())
- {
- ZEN_NOT_IMPLEMENTED("Null attachments are not supported");
- }
- else if (CompressedBuffer AttachmentBuffer = Attachment.AsCompressedBinary())
- {
- CompositeBuffer Compressed = AttachmentBuffer.GetCompressed();
- IoHash AttachmentHash = Attachment.GetHash();
-
- // If the data is either not backed by a file, or there are multiple
- // fragments then we cannot marshal it by local reference. We might
- // want/need to extend this in the future to allow multiple chunk
- // segments to be marshaled at once
-
- bool MarshalByLocalRef = EnumHasAllFlags(Flags, FormatFlags::kAllowLocalReferences) && (Compressed.GetSegments().size() == 1);
- bool DenyPartialLocalReferences = EnumHasAllFlags(Flags, FormatFlags::kDenyPartialLocalReferences);
- CbAttachmentReferenceHeader LocalRef;
- std::string Path8;
-
- if (MarshalByLocalRef)
- {
- MarshalByLocalRef = IsLocalRef(Compressed, DenyPartialLocalReferences, TargetProcessHandle, LocalRef, Path8);
- }
-
- if (MarshalByLocalRef)
- {
- const bool IsCompressed = true;
- bool IsHandle = false;
-#if ZEN_PLATFORM_WINDOWS
- IsHandle = Path8.starts_with(HandlePrefix);
-#endif
- MarshalLocal(Path8, LocalRef, AttachmentHash, IsCompressed);
- ZEN_DEBUG("Marshalled '{}' as file {} of {} bytes", Path8, IsHandle ? "handle" : "path", Compressed.GetSize());
- }
- else
- {
- *AttachmentInfo++ = {.PayloadSize = AttachmentBuffer.GetCompressedSize(),
- .Flags = CbAttachmentEntry::kIsCompressed,
- .AttachmentHash = AttachmentHash};
-
- for (const SharedBuffer& Segment : Compressed.GetSegments())
- {
- ResponseBuffers.push_back(Segment.AsIoBuffer());
- }
- }
- }
- else if (CbObject AttachmentObject = Attachment.AsObject())
- {
- IoBuffer ObjIoBuffer = AttachmentObject.GetBuffer().AsIoBuffer();
- ResponseBuffers.push_back(ObjIoBuffer);
-
- *AttachmentInfo++ = {.PayloadSize = ObjIoBuffer.Size(),
- .Flags = CbAttachmentEntry::kIsObject,
- .AttachmentHash = Attachment.GetHash()};
- }
- else if (CompositeBuffer AttachmentBinary = Attachment.AsCompositeBinary())
- {
- IoHash AttachmentHash = Attachment.GetHash();
- bool MarshalByLocalRef =
- EnumHasAllFlags(Flags, FormatFlags::kAllowLocalReferences) && (AttachmentBinary.GetSegments().size() == 1);
- bool DenyPartialLocalReferences = EnumHasAllFlags(Flags, FormatFlags::kDenyPartialLocalReferences);
-
- CbAttachmentReferenceHeader LocalRef;
- std::string Path8;
-
- if (MarshalByLocalRef)
- {
- MarshalByLocalRef = IsLocalRef(AttachmentBinary, DenyPartialLocalReferences, TargetProcessHandle, LocalRef, Path8);
- }
-
- if (MarshalByLocalRef)
- {
- const bool IsCompressed = false;
- bool IsHandle = false;
-#if ZEN_PLATFORM_WINDOWS
- IsHandle = Path8.starts_with(HandlePrefix);
-#endif
- MarshalLocal(Path8, LocalRef, AttachmentHash, IsCompressed);
- ZEN_DEBUG("Marshalled '{}' as file {} of {} bytes", Path8, IsHandle ? "handle" : "path", AttachmentBinary.GetSize());
- }
- else
- {
- *AttachmentInfo++ = {.PayloadSize = AttachmentBinary.GetSize(), .Flags = 0, .AttachmentHash = Attachment.GetHash()};
-
- for (const SharedBuffer& Segment : AttachmentBinary.GetSegments())
- {
- ResponseBuffers.push_back(Segment.AsIoBuffer());
- }
- }
- }
- else
- {
- ZEN_NOT_IMPLEMENTED("Unknown attachment kind");
- }
- }
- FileNameMap.clear();
-#if ZEN_PLATFORM_WINDOWS
- DuplicatedHandles.clear();
-#endif // ZEN_PLATFORM_WINDOWS
-
- return ResponseBuffers;
-}
-
-bool
-IsPackageMessage(IoBuffer Payload)
-{
- if (!Payload)
- {
- return false;
- }
-
- BinaryReader Reader(Payload);
-
- CbPackageHeader Hdr;
- Reader.Read(&Hdr, sizeof Hdr);
-
- if (Hdr.HeaderMagic != kCbPkgMagic)
- {
- return false;
- }
-
- return true;
-}
-
-CbPackage
-ParsePackageMessage(IoBuffer Payload, std::function<IoBuffer(const IoHash&, uint64_t)> CreateBuffer)
-{
- if (!Payload)
- {
- return {};
- }
-
- BinaryReader Reader(Payload);
-
- CbPackageHeader Hdr;
- Reader.Read(&Hdr, sizeof Hdr);
-
- if (Hdr.HeaderMagic != kCbPkgMagic)
- {
- throw std::runtime_error("invalid CbPackage header magic");
- }
-
- const uint32_t ChunkCount = Hdr.AttachmentCount + 1;
-
- std::unique_ptr<CbAttachmentEntry[]> AttachmentEntries{new CbAttachmentEntry[ChunkCount]};
-
- Reader.Read(AttachmentEntries.get(), sizeof(CbAttachmentEntry) * ChunkCount);
-
- CbPackage Package;
-
- std::vector<CbAttachment> Attachments;
- Attachments.reserve(ChunkCount); // Guessing here...
-
- tsl::robin_map<std::string, IoBuffer> PartialFileBuffers;
-
- // TODO: Throwing before this loop completes could result in leaking handles as we might not have picked up all the handles in the
- // message
- for (uint32_t i = 0; i < ChunkCount; ++i)
- {
- const CbAttachmentEntry& Entry = AttachmentEntries[i];
- const uint64_t AttachmentSize = Entry.PayloadSize;
-
- const IoBuffer AttachmentBuffer(Payload, Reader.CurrentOffset(), AttachmentSize);
- Reader.Skip(AttachmentSize);
-
- if (Entry.Flags & CbAttachmentEntry::kIsLocalRef)
- {
- // Marshal local reference - a "pointer" to the chunk backing file
-
- ZEN_ASSERT(AttachmentBuffer.Size() >= sizeof(CbAttachmentReferenceHeader));
-
- const CbAttachmentReferenceHeader* AttachRefHdr = AttachmentBuffer.Data<CbAttachmentReferenceHeader>();
- const char* PathPointer = reinterpret_cast<const char*>(AttachRefHdr + 1);
-
- ZEN_ASSERT(AttachmentBuffer.Size() >= (sizeof(CbAttachmentReferenceHeader) + AttachRefHdr->AbsolutePathLength));
- std::string_view PathView(PathPointer, AttachRefHdr->AbsolutePathLength);
-
- IoBuffer FullFileBuffer;
-
- std::filesystem::path Path(Utf8ToWide(PathView));
- if (auto It = PartialFileBuffers.find(Path.string()); It != PartialFileBuffers.end())
- {
- FullFileBuffer = It->second;
- }
- else
- {
- if (PathView.starts_with(HandlePrefix))
- {
-#if ZEN_PLATFORM_WINDOWS
- std::string_view HandleString(PathView.substr(HandlePrefix.length()));
- std::optional<uint64_t> HandleNumber(ParseInt<uint64_t>(HandleString));
- if (HandleNumber.has_value())
- {
- HANDLE FileHandle = HANDLE(HandleNumber.value());
- ULARGE_INTEGER liFileSize;
- liFileSize.LowPart = ::GetFileSize(FileHandle, &liFileSize.HighPart);
- if (liFileSize.LowPart != INVALID_FILE_SIZE)
- {
- FullFileBuffer = IoBuffer(IoBuffer::File, (void*)FileHandle, 0, uint64_t(liFileSize.QuadPart));
- PartialFileBuffers.insert_or_assign(Path.string(), FullFileBuffer);
- }
- }
-#else // ZEN_PLATFORM_WINDOWS
- // Not supported on Linux/Mac. Could potentially use pidfd_getfd() but that requires a fairly new Linux kernel/includes
- // and to deal with acceess rights etc.
- ZEN_ASSERT(false);
-#endif // ZEN_PLATFORM_WINDOWS
- }
- else
- {
- FullFileBuffer = PartialFileBuffers.insert_or_assign(Path.string(), IoBufferBuilder::MakeFromFile(Path)).first->second;
- }
- }
-
- if (!FullFileBuffer)
- {
- // Unable to open chunk reference
- throw std::runtime_error(fmt::format("unable to resolve chunk #{} at '{}' (offset {}, size {})",
- i,
- Path,
- AttachRefHdr->PayloadByteOffset,
- AttachRefHdr->PayloadByteSize));
- }
-
- IoBuffer ChunkReference = AttachRefHdr->PayloadByteOffset == 0 && AttachRefHdr->PayloadByteSize == FullFileBuffer.GetSize()
- ? FullFileBuffer
- : IoBuffer(FullFileBuffer, AttachRefHdr->PayloadByteOffset, AttachRefHdr->PayloadByteSize);
-
- CompressedBuffer CompBuf(CompressedBuffer::FromCompressedNoValidate(std::move(ChunkReference)));
- if (!CompBuf)
- {
- throw std::runtime_error(fmt::format("invalid format for chunk #{} at '{}' (offset {}, size {})",
- i,
- Path,
- AttachRefHdr->PayloadByteOffset,
- AttachRefHdr->PayloadByteSize));
- }
- Attachments.emplace_back(CbAttachment(std::move(CompBuf), Entry.AttachmentHash));
- }
- else if (Entry.Flags & CbAttachmentEntry::kIsCompressed)
- {
- if (Entry.Flags & CbAttachmentEntry::kIsObject)
- {
- if (i == 0)
- {
- CompressedBuffer CompBuf(CompressedBuffer::FromCompressedNoValidate(IoBuffer(AttachmentBuffer)));
- if (!CompBuf)
- {
- throw std::runtime_error(fmt::format("invalid format for chunk #{} expected compressed buffer for CbObject", i));
- }
- // First payload is always a compact binary object
- Package.SetObject(LoadCompactBinaryObject(std::move(CompBuf)));
- }
- else
- {
- ZEN_NOT_IMPLEMENTED("Object attachments are not currently supported");
- }
- }
- else
- {
- CompressedBuffer CompBuf(CompressedBuffer::FromCompressedNoValidate(IoBuffer(AttachmentBuffer)));
- if (!CompBuf)
- {
- throw std::runtime_error(fmt::format("invalid format for chunk #{} expected compressed buffer for attachment", i));
- }
- Attachments.emplace_back(CbAttachment(std::move(CompBuf), Entry.AttachmentHash));
- }
- }
- else /* not compressed */
- {
- if (Entry.Flags & CbAttachmentEntry::kIsObject)
- {
- if (i == 0)
- {
- Package.SetObject(LoadCompactBinaryObject(AttachmentBuffer));
- }
- else
- {
- ZEN_NOT_IMPLEMENTED("Object attachments are not currently supported");
- }
- }
- else
- {
- // Make a copy of the buffer so we attachements don't reference the entire payload
- IoBuffer AttachmentBufferCopy = CreateBuffer(Entry.AttachmentHash, AttachmentSize);
- ZEN_ASSERT(AttachmentBufferCopy);
- ZEN_ASSERT(AttachmentBufferCopy.Size() == AttachmentSize);
- AttachmentBufferCopy.GetMutableView().CopyFrom(AttachmentBuffer.GetView());
-
- CbAttachment Attachment(SharedBuffer{AttachmentBufferCopy});
- Attachments.emplace_back(SharedBuffer{AttachmentBufferCopy});
- }
- }
- }
- PartialFileBuffers.clear();
-
- Package.AddAttachments(Attachments);
-
- return Package;
-}
-
-bool
-ParsePackageMessageWithLegacyFallback(const IoBuffer& Response, CbPackage& OutPackage)
-{
- if (IsPackageMessage(Response))
- {
- OutPackage = ParsePackageMessage(Response);
- return true;
- }
- return OutPackage.TryLoad(Response);
-}
-
-CbPackageReader::CbPackageReader() : m_CreateBuffer([](const IoHash&, uint64_t Size) -> IoBuffer { return IoBuffer{Size}; })
-{
-}
-
-CbPackageReader::~CbPackageReader()
-{
-}
-
-void
-CbPackageReader::SetPayloadBufferCreator(std::function<IoBuffer(const IoHash& Cid, uint64_t Size)> CreateBuffer)
-{
- m_CreateBuffer = CreateBuffer;
-}
-
-uint64_t
-CbPackageReader::ProcessPackageHeaderData(const void* Data, uint64_t DataBytes)
-{
- ZEN_ASSERT(m_CurrentState != State::kReadingBuffers);
-
- switch (m_CurrentState)
- {
- case State::kInitialState:
- ZEN_ASSERT(Data == nullptr);
- m_CurrentState = State::kReadingHeader;
- return sizeof m_PackageHeader;
-
- case State::kReadingHeader:
- ZEN_ASSERT(DataBytes == sizeof m_PackageHeader);
- memcpy(&m_PackageHeader, Data, sizeof m_PackageHeader);
- ZEN_ASSERT(m_PackageHeader.HeaderMagic == kCbPkgMagic);
- m_CurrentState = State::kReadingAttachmentEntries;
- m_AttachmentEntries.resize(m_PackageHeader.AttachmentCount + 1);
- return (m_PackageHeader.AttachmentCount + 1) * sizeof(CbAttachmentEntry);
-
- case State::kReadingAttachmentEntries:
- ZEN_ASSERT(DataBytes == ((m_PackageHeader.AttachmentCount + 1) * sizeof(CbAttachmentEntry)));
- memcpy(m_AttachmentEntries.data(), Data, DataBytes);
-
- for (CbAttachmentEntry& Entry : m_AttachmentEntries)
- {
- // This preallocates memory for payloads but note that for the local references
- // the caller will need to handle the payload differently (i.e it's a
- // CbAttachmentReferenceHeader not the actual payload)
-
- m_PayloadBuffers.push_back(IoBuffer{Entry.PayloadSize});
- }
-
- m_CurrentState = State::kReadingBuffers;
- return 0;
-
- default:
- ZEN_ASSERT(false);
- return 0;
- }
-}
-
-IoBuffer
-CbPackageReader::MarshalLocalChunkReference(IoBuffer AttachmentBuffer)
-{
- // Marshal local reference - a "pointer" to the chunk backing file
-
- ZEN_ASSERT(AttachmentBuffer.Size() >= sizeof(CbAttachmentReferenceHeader));
-
- const CbAttachmentReferenceHeader* AttachRefHdr = AttachmentBuffer.Data<CbAttachmentReferenceHeader>();
- const char8_t* PathPointer = reinterpret_cast<const char8_t*>(AttachRefHdr + 1);
-
- ZEN_ASSERT(AttachmentBuffer.Size() >= (sizeof(CbAttachmentReferenceHeader) + AttachRefHdr->AbsolutePathLength));
-
- std::u8string_view PathView{PathPointer, AttachRefHdr->AbsolutePathLength};
-
- std::filesystem::path Path{PathView};
-
- IoBuffer ChunkReference = IoBufferBuilder::MakeFromFile(Path, AttachRefHdr->PayloadByteOffset, AttachRefHdr->PayloadByteSize);
-
- if (!ChunkReference)
- {
- // Unable to open chunk reference
-
- throw std::runtime_error(fmt::format("unable to resolve local reference to '{}' (offset {}, size {})",
- PathToUtf8(Path),
- AttachRefHdr->PayloadByteOffset,
- AttachRefHdr->PayloadByteSize));
- }
-
- return ChunkReference;
-};
-
-void
-CbPackageReader::Finalize()
-{
- if (m_AttachmentEntries.empty())
- {
- return;
- }
-
- m_Attachments.reserve(m_AttachmentEntries.size() - 1);
-
- int CurrentAttachmentIndex = 0;
- for (CbAttachmentEntry& Entry : m_AttachmentEntries)
- {
- IoBuffer AttachmentBuffer = m_PayloadBuffers[CurrentAttachmentIndex];
-
- if (CurrentAttachmentIndex == 0)
- {
- // Root object
- if (Entry.Flags & CbAttachmentEntry::kIsObject)
- {
- if (Entry.Flags & CbAttachmentEntry::kIsLocalRef)
- {
- m_RootObject = LoadCompactBinaryObject(MarshalLocalChunkReference(AttachmentBuffer));
- }
- else if (Entry.Flags & CbAttachmentEntry::kIsCompressed)
- {
- IoHash RawHash;
- uint64_t RawSize;
- CompressedBuffer Compressed = CompressedBuffer::FromCompressed(SharedBuffer(AttachmentBuffer), RawHash, RawSize);
- if (RawHash == Entry.AttachmentHash)
- {
- m_RootObject = LoadCompactBinaryObject(Compressed);
- }
- }
- else
- {
- m_RootObject = LoadCompactBinaryObject(std::move(AttachmentBuffer));
- }
- }
- else
- {
- throw std::runtime_error("missing or invalid root object");
- }
- }
- else if (Entry.Flags & CbAttachmentEntry::kIsLocalRef)
- {
- IoBuffer ChunkReference = MarshalLocalChunkReference(AttachmentBuffer);
-
- if (Entry.Flags & CbAttachmentEntry::kIsCompressed)
- {
- IoHash RawHash;
- uint64_t RawSize;
- CompressedBuffer Compressed = CompressedBuffer::FromCompressed(SharedBuffer(ChunkReference), RawHash, RawSize);
- if (RawHash == Entry.AttachmentHash)
- {
- m_Attachments.push_back(CbAttachment(Compressed, Entry.AttachmentHash));
- }
- }
- else
- {
- CompressedBuffer Compressed =
- CompressedBuffer::Compress(SharedBuffer(ChunkReference), OodleCompressor::NotSet, OodleCompressionLevel::None);
- m_Attachments.push_back(CbAttachment(std::move(Compressed), Compressed.DecodeRawHash()));
- }
- }
-
- ++CurrentAttachmentIndex;
- }
-}
-
-/**
- ______________________ _____________________________
- \__ ___/\_ _____// _____/\__ ___/ _____/
- | | | __)_ \_____ \ | | \_____ \
- | | | \/ \ | | / \
- |____| /_______ /_______ / |____| /_______ /
- \/ \/ \/
- */
-
-#if ZEN_WITH_TESTS
-
-TEST_CASE("CbPackage.Serialization")
-{
- // Make a test package
-
- CbAttachment Attach1{SharedBuffer::MakeView(MakeMemoryView("abcd"))};
- CbAttachment Attach2{SharedBuffer::MakeView(MakeMemoryView("efgh"))};
-
- CbObjectWriter Cbo;
- Cbo.AddAttachment("abcd", Attach1);
- Cbo.AddAttachment("efgh", Attach2);
-
- CbPackage Pkg;
- Pkg.AddAttachment(Attach1);
- Pkg.AddAttachment(Attach2);
- Pkg.SetObject(Cbo.Save());
-
- SharedBuffer Buffer = FormatPackageMessageBuffer(Pkg).Flatten();
- const uint8_t* CursorPtr = reinterpret_cast<const uint8_t*>(Buffer.GetData());
- uint64_t RemainingBytes = Buffer.GetSize();
-
- auto ConsumeBytes = [&](uint64_t ByteCount) {
- ZEN_ASSERT(ByteCount <= RemainingBytes);
- void* ReturnPtr = (void*)CursorPtr;
- CursorPtr += ByteCount;
- RemainingBytes -= ByteCount;
- return ReturnPtr;
- };
-
- auto CopyBytes = [&](void* TargetBuffer, uint64_t ByteCount) {
- ZEN_ASSERT(ByteCount <= RemainingBytes);
- memcpy(TargetBuffer, CursorPtr, ByteCount);
- CursorPtr += ByteCount;
- RemainingBytes -= ByteCount;
- };
-
- CbPackageReader Reader;
- uint64_t InitialRead = Reader.ProcessPackageHeaderData(nullptr, 0);
- uint64_t NextBytes = Reader.ProcessPackageHeaderData(ConsumeBytes(InitialRead), InitialRead);
- NextBytes = Reader.ProcessPackageHeaderData(ConsumeBytes(NextBytes), NextBytes);
- auto Buffers = Reader.GetPayloadBuffers();
-
- for (auto& PayloadBuffer : Buffers)
- {
- CopyBytes(PayloadBuffer.MutableData(), PayloadBuffer.GetSize());
- }
-
- Reader.Finalize();
-}
-
-TEST_CASE("CbPackage.LocalRef")
-{
- ScopedTemporaryDirectory TempDir;
-
- auto Path1 = TempDir.Path() / "abcd";
- auto Path2 = TempDir.Path() / "efgh";
-
- {
- IoBuffer Buffer1 = IoBufferBuilder::MakeCloneFromMemory(MakeMemoryView("abcd"));
- IoBuffer Buffer2 = IoBufferBuilder::MakeCloneFromMemory(MakeMemoryView("efgh"));
-
- WriteFile(Path1, Buffer1);
- WriteFile(Path2, Buffer2);
- }
-
- // Make a test package
-
- IoBuffer FileBuffer1 = IoBufferBuilder::MakeFromFile(Path1);
- IoBuffer FileBuffer2 = IoBufferBuilder::MakeFromFile(Path2);
-
- CbAttachment Attach1{SharedBuffer(FileBuffer1)};
- CbAttachment Attach2{SharedBuffer(FileBuffer2)};
-
- CbObjectWriter Cbo;
- Cbo.AddAttachment("abcd", Attach1);
- Cbo.AddAttachment("efgh", Attach2);
-
- CbPackage Pkg;
- Pkg.AddAttachment(Attach1);
- Pkg.AddAttachment(Attach2);
- Pkg.SetObject(Cbo.Save());
-
- SharedBuffer Buffer = FormatPackageMessageBuffer(Pkg, FormatFlags::kAllowLocalReferences).Flatten();
- const uint8_t* CursorPtr = reinterpret_cast<const uint8_t*>(Buffer.GetData());
- uint64_t RemainingBytes = Buffer.GetSize();
-
- auto ConsumeBytes = [&](uint64_t ByteCount) {
- ZEN_ASSERT(ByteCount <= RemainingBytes);
- void* ReturnPtr = (void*)CursorPtr;
- CursorPtr += ByteCount;
- RemainingBytes -= ByteCount;
- return ReturnPtr;
- };
-
- auto CopyBytes = [&](void* TargetBuffer, uint64_t ByteCount) {
- ZEN_ASSERT(ByteCount <= RemainingBytes);
- memcpy(TargetBuffer, CursorPtr, ByteCount);
- CursorPtr += ByteCount;
- RemainingBytes -= ByteCount;
- };
-
- CbPackageReader Reader;
- uint64_t InitialRead = Reader.ProcessPackageHeaderData(nullptr, 0);
- uint64_t NextBytes = Reader.ProcessPackageHeaderData(ConsumeBytes(InitialRead), InitialRead);
- NextBytes = Reader.ProcessPackageHeaderData(ConsumeBytes(NextBytes), NextBytes);
- auto Buffers = Reader.GetPayloadBuffers();
-
- for (auto& PayloadBuffer : Buffers)
- {
- CopyBytes(PayloadBuffer.MutableData(), PayloadBuffer.GetSize());
- }
-
- Reader.Finalize();
-}
-
-void
-forcelink_httpshared()
-{
-}
-
-#endif
-
-} // namespace zen
diff --git a/zenhttp/httpsys.cpp b/zenhttp/httpsys.cpp
deleted file mode 100644
index c733d618d..000000000
--- a/zenhttp/httpsys.cpp
+++ /dev/null
@@ -1,1674 +0,0 @@
-// Copyright Epic Games, Inc. All Rights Reserved.
-
-#include "httpsys.h"
-
-#include <zencore/compactbinary.h>
-#include <zencore/compactbinarybuilder.h>
-#include <zencore/compactbinarypackage.h>
-#include <zencore/except.h>
-#include <zencore/logging.h>
-#include <zencore/scopeguard.h>
-#include <zencore/string.h>
-#include <zencore/timer.h>
-#include <zenhttp/httpshared.h>
-
-#if ZEN_WITH_HTTPSYS
-
-# include <conio.h>
-# include <mstcpip.h>
-# pragma comment(lib, "httpapi.lib")
-
-std::wstring
-UTF8_to_UTF16(const char* InPtr)
-{
- std::wstring OutString;
- unsigned int Codepoint;
-
- while (*InPtr != 0)
- {
- unsigned char InChar = static_cast<unsigned char>(*InPtr);
-
- if (InChar <= 0x7f)
- Codepoint = InChar;
- else if (InChar <= 0xbf)
- Codepoint = (Codepoint << 6) | (InChar & 0x3f);
- else if (InChar <= 0xdf)
- Codepoint = InChar & 0x1f;
- else if (InChar <= 0xef)
- Codepoint = InChar & 0x0f;
- else
- Codepoint = InChar & 0x07;
-
- ++InPtr;
-
- if (((*InPtr & 0xc0) != 0x80) && (Codepoint <= 0x10ffff))
- {
- if (Codepoint > 0xffff)
- {
- OutString.append(1, static_cast<wchar_t>(0xd800 + (Codepoint >> 10)));
- OutString.append(1, static_cast<wchar_t>(0xdc00 + (Codepoint & 0x03ff)));
- }
- else if (Codepoint < 0xd800 || Codepoint >= 0xe000)
- {
- OutString.append(1, static_cast<wchar_t>(Codepoint));
- }
- }
- }
-
- return OutString;
-}
-
-namespace zen {
-
-using namespace std::literals;
-
-class HttpSysServer;
-class HttpSysTransaction;
-class HttpMessageResponseRequest;
-
-//////////////////////////////////////////////////////////////////////////
-
-HttpVerb
-TranslateHttpVerb(HTTP_VERB ReqVerb)
-{
- switch (ReqVerb)
- {
- case HttpVerbOPTIONS:
- return HttpVerb::kOptions;
-
- case HttpVerbGET:
- return HttpVerb::kGet;
-
- case HttpVerbHEAD:
- return HttpVerb::kHead;
-
- case HttpVerbPOST:
- return HttpVerb::kPost;
-
- case HttpVerbPUT:
- return HttpVerb::kPut;
-
- case HttpVerbDELETE:
- return HttpVerb::kDelete;
-
- case HttpVerbCOPY:
- return HttpVerb::kCopy;
-
- default:
- // TODO: invalid request?
- return (HttpVerb)0;
- }
-}
-
-uint64_t
-GetContentLength(const HTTP_REQUEST* HttpRequest)
-{
- const HTTP_KNOWN_HEADER& clh = HttpRequest->Headers.KnownHeaders[HttpHeaderContentLength];
- std::string_view cl(clh.pRawValue, clh.RawValueLength);
- uint64_t ContentLength = 0;
- std::from_chars(cl.data(), cl.data() + cl.size(), ContentLength);
- return ContentLength;
-};
-
-HttpContentType
-GetContentType(const HTTP_REQUEST* HttpRequest)
-{
- const HTTP_KNOWN_HEADER& CtHdr = HttpRequest->Headers.KnownHeaders[HttpHeaderContentType];
- return ParseContentType({CtHdr.pRawValue, CtHdr.RawValueLength});
-};
-
-HttpContentType
-GetAcceptType(const HTTP_REQUEST* HttpRequest)
-{
- const HTTP_KNOWN_HEADER& CtHdr = HttpRequest->Headers.KnownHeaders[HttpHeaderAccept];
- return ParseContentType({CtHdr.pRawValue, CtHdr.RawValueLength});
-};
-
-/**
- * @brief Base class for any pending or active HTTP transactions
- */
-class HttpSysRequestHandler
-{
-public:
- explicit HttpSysRequestHandler(HttpSysTransaction& Transaction) : m_Transaction(Transaction) {}
- virtual ~HttpSysRequestHandler() = default;
-
- virtual void IssueRequest(std::error_code& ErrorCode) = 0;
- virtual HttpSysRequestHandler* HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) = 0;
- HttpSysTransaction& Transaction() { return m_Transaction; }
-
- HttpSysRequestHandler(const HttpSysRequestHandler&) = delete;
- HttpSysRequestHandler& operator=(const HttpSysRequestHandler&) = delete;
-
-private:
- HttpSysTransaction& m_Transaction;
-};
-
-/**
- * This is the handler for the initial HTTP I/O request which will receive the headers
- * and however much of the remaining payload might fit in the embedded request buffer.
- *
- * It is also used to receive any entity body data relating to the request
- *
- */
-struct InitialRequestHandler : public HttpSysRequestHandler
-{
- inline HTTP_REQUEST* HttpRequest() { return (HTTP_REQUEST*)m_RequestBuffer; }
- inline uint32_t RequestBufferSize() const { return sizeof m_RequestBuffer; }
- inline bool IsInitialRequest() const { return m_IsInitialRequest; }
-
- InitialRequestHandler(HttpSysTransaction& InRequest);
- ~InitialRequestHandler();
-
- virtual void IssueRequest(std::error_code& ErrorCode) override final;
- virtual HttpSysRequestHandler* HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) override;
-
- bool m_IsInitialRequest = true;
- uint64_t m_CurrentPayloadOffset = 0;
- uint64_t m_ContentLength = ~uint64_t(0);
- IoBuffer m_PayloadBuffer;
- UCHAR m_RequestBuffer[4096 + sizeof(HTTP_REQUEST)];
-};
-
-/**
- * This is the class which request handlers use to interact with the server instance
- */
-
-class HttpSysServerRequest : public HttpServerRequest
-{
-public:
- HttpSysServerRequest(HttpSysTransaction& Tx, HttpService& Service, IoBuffer PayloadBuffer);
- ~HttpSysServerRequest() = default;
-
- virtual Oid ParseSessionId() const override;
- virtual uint32_t ParseRequestId() const override;
-
- virtual IoBuffer ReadPayload() override;
- virtual void WriteResponse(HttpResponseCode ResponseCode) override;
- virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span<IoBuffer> Blobs) override;
- virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::u8string_view ResponseString) override;
- virtual void WriteResponseAsync(std::function<void(HttpServerRequest&)>&& ContinuationHandler) override;
- virtual bool TryGetRanges(HttpRanges& Ranges) override;
-
- using HttpServerRequest::WriteResponse;
-
- HttpSysServerRequest(const HttpSysServerRequest&) = delete;
- HttpSysServerRequest& operator=(const HttpSysServerRequest&) = delete;
-
- HttpSysTransaction& m_HttpTx;
- HttpSysRequestHandler* m_NextCompletionHandler = nullptr;
- IoBuffer m_PayloadBuffer;
- ExtendableStringBuilder<128> m_UriUtf8;
- ExtendableStringBuilder<128> m_QueryStringUtf8;
-};
-
-/** HTTP transaction
-
- There will be an instance of this per pending and in-flight HTTP transaction
-
- */
-class HttpSysTransaction final
-{
-public:
- HttpSysTransaction(HttpSysServer& Server);
- virtual ~HttpSysTransaction();
-
- enum class Status
- {
- kDone,
- kRequestPending
- };
-
- Status HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred);
-
- static void __stdcall IoCompletionCallback(PTP_CALLBACK_INSTANCE Instance,
- PVOID pContext /* HttpSysServer */,
- PVOID pOverlapped,
- ULONG IoResult,
- ULONG_PTR NumberOfBytesTransferred,
- PTP_IO Io);
-
- void IssueInitialRequest(std::error_code& ErrorCode);
- bool IssueNextRequest(HttpSysRequestHandler* NewCompletionHandler);
-
- PTP_IO Iocp();
- HANDLE RequestQueueHandle();
- inline OVERLAPPED* Overlapped() { return &m_HttpOverlapped; }
- inline HttpSysServer& Server() { return m_HttpServer; }
- inline HTTP_REQUEST* HttpRequest() { return m_InitialHttpHandler.HttpRequest(); }
-
- HttpSysServerRequest& InvokeRequestHandler(HttpService& Service, IoBuffer Payload);
-
- HttpSysServerRequest& ServerRequest() { return m_HandlerRequest.value(); }
-
-private:
- OVERLAPPED m_HttpOverlapped{};
- HttpSysServer& m_HttpServer;
-
- // Tracks which handler is due to handle the next I/O completion event
- HttpSysRequestHandler* m_CompletionHandler = nullptr;
- RwLock m_CompletionMutex;
- InitialRequestHandler m_InitialHttpHandler{*this};
- std::optional<HttpSysServerRequest> m_HandlerRequest;
- Ref<IHttpPackageHandler> m_PackageHandler;
-};
-
-/**
- * @brief HTTP request response I/O request handler
- *
- * Asynchronously streams out a response to an HTTP request via compound
- * responses from memory or directly from file
- */
-
-class HttpMessageResponseRequest : public HttpSysRequestHandler
-{
-public:
- HttpMessageResponseRequest(HttpSysTransaction& InRequest, uint16_t ResponseCode);
- HttpMessageResponseRequest(HttpSysTransaction& InRequest, uint16_t ResponseCode, std::string_view Message);
- HttpMessageResponseRequest(HttpSysTransaction& InRequest,
- uint16_t ResponseCode,
- HttpContentType ContentType,
- const void* Payload,
- size_t PayloadSize);
- HttpMessageResponseRequest(HttpSysTransaction& InRequest,
- uint16_t ResponseCode,
- HttpContentType ContentType,
- std::span<IoBuffer> Blobs);
- ~HttpMessageResponseRequest();
-
- virtual void IssueRequest(std::error_code& ErrorCode) override final;
- virtual HttpSysRequestHandler* HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) override;
- void SuppressResponseBody(); // typically used for HEAD requests
-
-private:
- std::vector<HTTP_DATA_CHUNK> m_HttpDataChunks;
- uint64_t m_TotalDataSize = 0; // Sum of all chunk sizes
- uint16_t m_ResponseCode = 0;
- uint32_t m_NextDataChunkOffset = 0; // Cursor used for very large chunk lists
- uint32_t m_RemainingChunkCount = 0; // Backlog for multi-call sends
- bool m_IsInitialResponse = true;
- HttpContentType m_ContentType = HttpContentType::kBinary;
- std::vector<IoBuffer> m_DataBuffers;
-
- void InitializeForPayload(uint16_t ResponseCode, std::span<IoBuffer> Blobs);
-};
-
-HttpMessageResponseRequest::HttpMessageResponseRequest(HttpSysTransaction& InRequest, uint16_t ResponseCode)
-: HttpSysRequestHandler(InRequest)
-{
- std::array<IoBuffer, 0> EmptyBufferList;
-
- InitializeForPayload(ResponseCode, EmptyBufferList);
-}
-
-HttpMessageResponseRequest::HttpMessageResponseRequest(HttpSysTransaction& InRequest, uint16_t ResponseCode, std::string_view Message)
-: HttpSysRequestHandler(InRequest)
-, m_ContentType(HttpContentType::kText)
-{
- IoBuffer MessageBuffer(IoBuffer::Wrap, Message.data(), Message.size());
- std::array<IoBuffer, 1> SingleBufferList({MessageBuffer});
-
- InitializeForPayload(ResponseCode, SingleBufferList);
-}
-
-HttpMessageResponseRequest::HttpMessageResponseRequest(HttpSysTransaction& InRequest,
- uint16_t ResponseCode,
- HttpContentType ContentType,
- const void* Payload,
- size_t PayloadSize)
-: HttpSysRequestHandler(InRequest)
-, m_ContentType(ContentType)
-{
- IoBuffer MessageBuffer(IoBuffer::Wrap, Payload, PayloadSize);
- std::array<IoBuffer, 1> SingleBufferList({MessageBuffer});
-
- InitializeForPayload(ResponseCode, SingleBufferList);
-}
-
-HttpMessageResponseRequest::HttpMessageResponseRequest(HttpSysTransaction& InRequest,
- uint16_t ResponseCode,
- HttpContentType ContentType,
- std::span<IoBuffer> BlobList)
-: HttpSysRequestHandler(InRequest)
-, m_ContentType(ContentType)
-{
- InitializeForPayload(ResponseCode, BlobList);
-}
-
-HttpMessageResponseRequest::~HttpMessageResponseRequest()
-{
-}
-
-void
-HttpMessageResponseRequest::InitializeForPayload(uint16_t ResponseCode, std::span<IoBuffer> BlobList)
-{
- const uint32_t ChunkCount = gsl::narrow<uint32_t>(BlobList.size());
-
- m_HttpDataChunks.reserve(ChunkCount);
- m_DataBuffers.reserve(ChunkCount);
-
- for (IoBuffer& Buffer : BlobList)
- {
- m_DataBuffers.emplace_back(std::move(Buffer)).MakeOwned();
- }
-
- // Initialize the full array up front
-
- uint64_t LocalDataSize = 0;
-
- for (IoBuffer& Buffer : m_DataBuffers)
- {
- uint64_t BufferDataSize = Buffer.Size();
-
- ZEN_ASSERT(BufferDataSize);
-
- LocalDataSize += BufferDataSize;
-
- IoBufferFileReference FileRef;
- if (Buffer.GetFileReference(/* out */ FileRef))
- {
- // Use direct file transfer
-
- m_HttpDataChunks.push_back({});
- auto& Chunk = m_HttpDataChunks.back();
-
- Chunk.DataChunkType = HttpDataChunkFromFileHandle;
- Chunk.FromFileHandle.FileHandle = FileRef.FileHandle;
- Chunk.FromFileHandle.ByteRange.StartingOffset.QuadPart = FileRef.FileChunkOffset;
- Chunk.FromFileHandle.ByteRange.Length.QuadPart = BufferDataSize;
- }
- else
- {
- // Send from memory, need to make sure we chunk the buffer up since
- // the underlying data structure only accepts 32-bit chunk sizes for
- // memory chunks. When this happens the vector will be reallocated,
- // which is fine since this will be a pretty rare case and sending
- // the data is going to take a lot longer than a memory allocation :)
-
- const uint8_t* WriteCursor = reinterpret_cast<const uint8_t*>(Buffer.Data());
-
- while (BufferDataSize)
- {
- const ULONG ThisChunkSize = gsl::narrow<ULONG>(zen::Min(1 * 1024 * 1024 * 1024, BufferDataSize));
-
- m_HttpDataChunks.push_back({});
- auto& Chunk = m_HttpDataChunks.back();
-
- Chunk.DataChunkType = HttpDataChunkFromMemory;
- Chunk.FromMemory.pBuffer = (void*)WriteCursor;
- Chunk.FromMemory.BufferLength = ThisChunkSize;
-
- BufferDataSize -= ThisChunkSize;
- WriteCursor += ThisChunkSize;
- }
- }
- }
-
- m_RemainingChunkCount = gsl::narrow<uint32_t>(m_HttpDataChunks.size());
- m_TotalDataSize = LocalDataSize;
-
- if (m_TotalDataSize == 0 && ResponseCode == 200)
- {
- // Some HTTP clients really don't like empty responses unless a 204 response is sent
- m_ResponseCode = uint16_t(HttpResponseCode::NoContent);
- }
- else
- {
- m_ResponseCode = ResponseCode;
- }
-}
-
-void
-HttpMessageResponseRequest::SuppressResponseBody()
-{
- m_RemainingChunkCount = 0;
- m_HttpDataChunks.clear();
- m_DataBuffers.clear();
-}
-
-HttpSysRequestHandler*
-HttpMessageResponseRequest::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred)
-{
- ZEN_UNUSED(NumberOfBytesTransferred);
-
- if (IoResult != NO_ERROR)
- {
- ZEN_WARN("response aborted due to error: '{}'", GetSystemErrorAsString(IoResult));
-
- // if one transmit failed there's really no need to go on
- return nullptr;
- }
-
- if (m_RemainingChunkCount == 0)
- {
- return nullptr; // All done
- }
-
- return this;
-}
-
-void
-HttpMessageResponseRequest::IssueRequest(std::error_code& ErrorCode)
-{
- HttpSysTransaction& Tx = Transaction();
- HTTP_REQUEST* const HttpReq = Tx.HttpRequest();
- PTP_IO const Iocp = Tx.Iocp();
-
- StartThreadpoolIo(Iocp);
-
- // Split payload into batches to play well with the underlying API
-
- const int MaxChunksPerCall = 9999;
-
- const int ThisRequestChunkCount = std::min<int>(m_RemainingChunkCount, MaxChunksPerCall);
- const int ThisRequestChunkOffset = m_NextDataChunkOffset;
-
- m_RemainingChunkCount -= ThisRequestChunkCount;
- m_NextDataChunkOffset += ThisRequestChunkCount;
-
- /* Should this code also use HTTP_SEND_RESPONSE_FLAG_BUFFER_DATA?
-
- From the docs:
-
- This flag enables buffering of data in the kernel on a per-response basis. It should
- be used by an application doing synchronous I/O, or by a an application doing
- asynchronous I/O with no more than one send outstanding at a time.
-
- Applications using asynchronous I/O which may have more than one send outstanding at
- a time should not use this flag.
-
- When this flag is set, it should be used consistently in calls to the
- HttpSendHttpResponse function as well.
- */
-
- ULONG SendFlags = HTTP_SEND_RESPONSE_FLAG_BUFFER_DATA;
-
- if (m_RemainingChunkCount)
- {
- // We need to make more calls to send the full amount of data
- SendFlags |= HTTP_SEND_RESPONSE_FLAG_MORE_DATA;
- }
-
- ULONG SendResult = 0;
-
- if (m_IsInitialResponse)
- {
- // Populate response structure
-
- HTTP_RESPONSE HttpResponse = {};
-
- HttpResponse.EntityChunkCount = USHORT(ThisRequestChunkCount);
- HttpResponse.pEntityChunks = m_HttpDataChunks.data() + ThisRequestChunkOffset;
-
- // Server header
- //
- // By default this will also add a suffix " Microsoft-HTTPAPI/2.0" to this header
- //
- // This is controlled via a registry key 'DisableServerHeader', at:
- //
- // Computer\HKEY_LOCAL_MACHINE\SYSTEM\CurrentControlSet\Services\HTTP\Parameters
- //
- // Set DisableServerHeader to 1 to disable suffix, or 2 to disable the header altogether
- // (only the latter appears to do anything in my testing, on Windows 10).
- //
- // (reference https://docs.microsoft.com/en-us/archive/blogs/dsnotes/wswcf-remove-server-header)
- //
-
- PHTTP_KNOWN_HEADER ServerHeader = &HttpResponse.Headers.KnownHeaders[HttpHeaderServer];
- ServerHeader->pRawValue = "Zen";
- ServerHeader->RawValueLength = (USHORT)3;
-
- // Content-length header
-
- char ContentLengthString[32];
- _ui64toa_s(m_TotalDataSize, ContentLengthString, sizeof ContentLengthString, 10);
-
- PHTTP_KNOWN_HEADER ContentLengthHeader = &HttpResponse.Headers.KnownHeaders[HttpHeaderContentLength];
- ContentLengthHeader->pRawValue = ContentLengthString;
- ContentLengthHeader->RawValueLength = (USHORT)strlen(ContentLengthString);
-
- // Content-type header
-
- PHTTP_KNOWN_HEADER ContentTypeHeader = &HttpResponse.Headers.KnownHeaders[HttpHeaderContentType];
-
- std::string_view ContentTypeString = MapContentTypeToString(m_ContentType);
-
- ContentTypeHeader->pRawValue = ContentTypeString.data();
- ContentTypeHeader->RawValueLength = (USHORT)ContentTypeString.size();
-
- std::string_view ReasonString = ReasonStringForHttpResultCode(m_ResponseCode);
-
- HttpResponse.StatusCode = m_ResponseCode;
- HttpResponse.pReason = ReasonString.data();
- HttpResponse.ReasonLength = (USHORT)ReasonString.size();
-
- // Cache policy
-
- HTTP_CACHE_POLICY CachePolicy;
-
- CachePolicy.Policy = HttpCachePolicyNocache; // HttpCachePolicyUserInvalidates;
- CachePolicy.SecondsToLive = 0;
-
- // Initial response API call
-
- SendResult = HttpSendHttpResponse(Tx.RequestQueueHandle(),
- HttpReq->RequestId,
- SendFlags,
- &HttpResponse,
- &CachePolicy,
- NULL,
- NULL,
- 0,
- Tx.Overlapped(),
- NULL);
-
- m_IsInitialResponse = false;
- }
- else
- {
- // Subsequent response API calls
-
- SendResult = HttpSendResponseEntityBody(Tx.RequestQueueHandle(),
- HttpReq->RequestId,
- SendFlags,
- (USHORT)ThisRequestChunkCount, // EntityChunkCount
- &m_HttpDataChunks[ThisRequestChunkOffset], // EntityChunks
- NULL, // BytesSent
- NULL, // Reserved1
- 0, // Reserved2
- Tx.Overlapped(), // Overlapped
- NULL // LogData
- );
- }
-
- if (SendResult == NO_ERROR)
- {
- // Synchronous completion, but the completion event will still be posted to IOCP
-
- ErrorCode.clear();
- }
- else if (SendResult == ERROR_IO_PENDING)
- {
- // Asynchronous completion, a completion notification will be posted to IOCP
-
- ErrorCode.clear();
- }
- else
- {
- // An error occurred, no completion will be posted to IOCP
-
- CancelThreadpoolIo(Iocp);
-
- ZEN_WARN("failed to send HTTP response (error: '{}'), request URL: '{}', request id: {}",
- GetSystemErrorAsString(SendResult),
- HttpReq->pRawUrl,
- HttpReq->RequestId);
-
- ErrorCode = MakeErrorCode(SendResult);
- }
-}
-
-/** HTTP completion handler for async work
-
- This is used to allow work to be taken off the request handler threads
- and to support posting responses asynchronously.
- */
-
-class HttpAsyncWorkRequest : public HttpSysRequestHandler
-{
-public:
- HttpAsyncWorkRequest(HttpSysTransaction& Tx, std::function<void(HttpServerRequest&)>&& Response);
- ~HttpAsyncWorkRequest();
-
- virtual void IssueRequest(std::error_code& ErrorCode) override final;
- virtual HttpSysRequestHandler* HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) override;
-
-private:
- struct AsyncWorkItem : public IWork
- {
- virtual void Execute() override;
-
- AsyncWorkItem(HttpSysTransaction& InTx, std::function<void(HttpServerRequest&)>&& InHandler)
- : Tx(InTx)
- , Handler(std::move(InHandler))
- {
- }
-
- HttpSysTransaction& Tx;
- std::function<void(HttpServerRequest&)> Handler;
- };
-
- Ref<AsyncWorkItem> m_WorkItem;
-};
-
-HttpAsyncWorkRequest::HttpAsyncWorkRequest(HttpSysTransaction& Tx, std::function<void(HttpServerRequest&)>&& Response)
-: HttpSysRequestHandler(Tx)
-{
- m_WorkItem = new AsyncWorkItem(Tx, std::move(Response));
-}
-
-HttpAsyncWorkRequest::~HttpAsyncWorkRequest()
-{
-}
-
-void
-HttpAsyncWorkRequest::IssueRequest(std::error_code& ErrorCode)
-{
- ErrorCode.clear();
-
- Transaction().Server().WorkPool().ScheduleWork(m_WorkItem);
-}
-
-HttpSysRequestHandler*
-HttpAsyncWorkRequest::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred)
-{
- // This ought to not be called since there should be no outstanding I/O request
- // when this completion handler is active
-
- ZEN_UNUSED(IoResult, NumberOfBytesTransferred);
-
- ZEN_WARN("Unexpected I/O completion during async work! IoResult: {}, NumberOfBytesTransferred: {}", IoResult, NumberOfBytesTransferred);
-
- return this;
-}
-
-void
-HttpAsyncWorkRequest::AsyncWorkItem::Execute()
-{
- try
- {
- HttpSysServerRequest& ThisRequest = Tx.ServerRequest();
-
- ThisRequest.m_NextCompletionHandler = nullptr;
-
- Handler(ThisRequest);
-
- // TODO: should Handler be destroyed at this point to ensure there
- // are no outstanding references into state which could be
- // deleted asynchronously as a result of issuing the response?
-
- if (HttpSysRequestHandler* NextHandler = ThisRequest.m_NextCompletionHandler)
- {
- return (void)Tx.IssueNextRequest(NextHandler);
- }
- else if (!ThisRequest.IsHandled())
- {
- return (void)Tx.IssueNextRequest(new HttpMessageResponseRequest(Tx, 404, "Not found"sv));
- }
- else
- {
- // "Handled" but no request handler? Shouldn't ever happen
- return (void)Tx.IssueNextRequest(
- new HttpMessageResponseRequest(Tx, 500, "Response generated but no request handler scheduled"sv));
- }
- }
- catch (std::exception& Ex)
- {
- return (void)Tx.IssueNextRequest(
- new HttpMessageResponseRequest(Tx, 500, fmt::format("Exception thrown in async work: '{}'", Ex.what())));
- }
-}
-
-/**
- _________
- / _____/ ______________ __ ___________
- \_____ \_/ __ \_ __ \ \/ // __ \_ __ \
- / \ ___/| | \/\ /\ ___/| | \/
- /_______ /\___ >__| \_/ \___ >__|
- \/ \/ \/
-*/
-
-HttpSysServer::HttpSysServer(unsigned int ThreadCount, unsigned int AsyncWorkThreadCount)
-: m_Log(logging::Get("http"))
-, m_RequestLog(logging::Get("http_requests"))
-, m_ThreadPool(ThreadCount)
-, m_AsyncWorkPool(AsyncWorkThreadCount)
-{
- ULONG Result = HttpInitialize(HTTPAPI_VERSION_2, HTTP_INITIALIZE_SERVER, nullptr);
-
- if (Result != NO_ERROR)
- {
- return;
- }
-
- m_IsHttpInitialized = true;
- m_IsOk = true;
-
- ZEN_INFO("http.sys server started, using {} I/O threads and {} async worker threads", ThreadCount, AsyncWorkThreadCount);
-}
-
-HttpSysServer::~HttpSysServer()
-{
- if (m_IsHttpInitialized)
- {
- Cleanup();
-
- HttpTerminate(HTTP_INITIALIZE_SERVER, nullptr);
- }
-}
-
-int
-HttpSysServer::InitializeServer(int BasePort)
-{
- using namespace std::literals;
-
- WideStringBuilder<64> WildcardUrlPath;
- WildcardUrlPath << u8"http://*:"sv << int64_t(BasePort) << u8"/"sv;
-
- m_IsOk = false;
-
- ULONG Result = HttpCreateServerSession(HTTPAPI_VERSION_2, &m_HttpSessionId, 0);
-
- if (Result != NO_ERROR)
- {
- ZEN_ERROR("Failed to create server session for '{}': {:#x}", WideToUtf8(WildcardUrlPath), Result);
-
- return BasePort;
- }
-
- Result = HttpCreateUrlGroup(m_HttpSessionId, &m_HttpUrlGroupId, 0);
-
- if (Result != NO_ERROR)
- {
- ZEN_ERROR("Failed to create URL group for '{}': {:#x}", WideToUtf8(WildcardUrlPath), Result);
-
- return BasePort;
- }
-
- int EffectivePort = BasePort;
-
- Result = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, WildcardUrlPath.c_str(), HTTP_URL_CONTEXT(0), 0);
-
- // Sharing violation implies the port is being used by another process
- for (int PortOffset = 1; (Result == ERROR_SHARING_VIOLATION) && (PortOffset < 10); ++PortOffset)
- {
- EffectivePort = BasePort + (PortOffset * 100);
- WildcardUrlPath.Reset();
- WildcardUrlPath << u8"http://*:"sv << int64_t(EffectivePort) << u8"/"sv;
-
- Result = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, WildcardUrlPath.c_str(), HTTP_URL_CONTEXT(0), 0);
- }
-
- m_BaseUris.clear();
- if (Result == NO_ERROR)
- {
- m_BaseUris.push_back(WildcardUrlPath.c_str());
- }
- else if (Result == ERROR_ACCESS_DENIED)
- {
- // If we can't register the wildcard path, we fall back to local paths
- // This local paths allow requests originating locally to function, but will not allow
- // remote origin requests to function. This can be remedied by using netsh
- // during an install process to grant permissions to route public access to the appropriate
- // port for the current user. eg:
- // netsh http add urlacl url=http://*:1337/ user=<some_user>
-
- ZEN_WARN("Unable to register handler using '{}' - falling back to local-only", WideToUtf8(WildcardUrlPath));
-
- const std::u8string_view Hosts[] = {u8"[::1]"sv, u8"localhost"sv, u8"127.0.0.1"sv};
-
- ULONG InternalResult = ERROR_SHARING_VIOLATION;
- for (int PortOffset = 0; (InternalResult == ERROR_SHARING_VIOLATION) && (PortOffset < 10); ++PortOffset)
- {
- EffectivePort = BasePort + (PortOffset * 100);
-
- for (const std::u8string_view Host : Hosts)
- {
- WideStringBuilder<64> LocalUrlPath;
- LocalUrlPath << u8"http://"sv << Host << u8":"sv << int64_t(EffectivePort) << u8"/"sv;
-
- InternalResult = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, LocalUrlPath.c_str(), HTTP_URL_CONTEXT(0), 0);
-
- if (InternalResult == NO_ERROR)
- {
- ZEN_INFO("Registered local handler '{}'", WideToUtf8(LocalUrlPath));
-
- m_BaseUris.push_back(LocalUrlPath.c_str());
- }
- else
- {
- break;
- }
- }
- }
- }
-
- if (m_BaseUris.empty())
- {
- ZEN_ERROR("Failed to add base URL to URL group for '{}': {:#x}", WideToUtf8(WildcardUrlPath), Result);
-
- return BasePort;
- }
-
- HTTP_BINDING_INFO HttpBindingInfo = {{0}, 0};
-
- Result = HttpCreateRequestQueue(HTTPAPI_VERSION_2,
- /* Name */ nullptr,
- /* SecurityAttributes */ nullptr,
- /* Flags */ 0,
- &m_RequestQueueHandle);
-
- if (Result != NO_ERROR)
- {
- ZEN_ERROR("Failed to create request queue for '{}': {:#x}", WideToUtf8(m_BaseUris.front()), Result);
-
- return EffectivePort;
- }
-
- HttpBindingInfo.Flags.Present = 1;
- HttpBindingInfo.RequestQueueHandle = m_RequestQueueHandle;
-
- Result = HttpSetUrlGroupProperty(m_HttpUrlGroupId, HttpServerBindingProperty, &HttpBindingInfo, sizeof(HttpBindingInfo));
-
- if (Result != NO_ERROR)
- {
- ZEN_ERROR("Failed to set server binding property for '{}': {:#x}", WideToUtf8(m_BaseUris.front()), Result);
-
- return EffectivePort;
- }
-
- // Create I/O completion port
-
- std::error_code ErrorCode;
- m_ThreadPool.CreateIocp(m_RequestQueueHandle, HttpSysTransaction::IoCompletionCallback, /* Context */ this, /* out */ ErrorCode);
-
- if (ErrorCode)
- {
- ZEN_ERROR("Failed to create IOCP for '{}': {}", WideToUtf8(m_BaseUris.front()), ErrorCode.message());
- }
- else
- {
- m_IsOk = true;
-
- ZEN_INFO("Started http.sys server at '{}'", WideToUtf8(m_BaseUris.front()));
- }
-
- return EffectivePort;
-}
-
-void
-HttpSysServer::Cleanup()
-{
- ++m_IsShuttingDown;
-
- if (m_RequestQueueHandle)
- {
- HttpCloseRequestQueue(m_RequestQueueHandle);
- m_RequestQueueHandle = nullptr;
- }
-
- if (m_HttpUrlGroupId)
- {
- HttpCloseUrlGroup(m_HttpUrlGroupId);
- m_HttpUrlGroupId = 0;
- }
-
- if (m_HttpSessionId)
- {
- HttpCloseServerSession(m_HttpSessionId);
- m_HttpSessionId = 0;
- }
-}
-
-void
-HttpSysServer::StartServer()
-{
- const int InitialRequestCount = 32;
-
- for (int i = 0; i < InitialRequestCount; ++i)
- {
- IssueNewRequestMaybe();
- }
-}
-
-void
-HttpSysServer::Run(bool IsInteractive)
-{
- if (IsInteractive)
- {
- zen::logging::ConsoleLog().info("Zen Server running. Press ESC or Q to quit");
- }
-
- do
- {
- // int WaitTimeout = -1;
- int WaitTimeout = 100;
-
- if (IsInteractive)
- {
- WaitTimeout = 1000;
-
- if (_kbhit() != 0)
- {
- char c = (char)_getch();
-
- if (c == 27 || c == 'Q' || c == 'q')
- {
- RequestApplicationExit(0);
- }
- }
- }
-
- m_ShutdownEvent.Wait(WaitTimeout);
- UpdateLofreqTimerValue();
- } while (!IsApplicationExitRequested());
-}
-
-void
-HttpSysServer::OnHandlingRequest()
-{
- if (--m_PendingRequests > m_MinPendingRequests)
- {
- // We have more than the minimum number of requests pending, just let someone else
- // enqueue new requests
- return;
- }
-
- IssueNewRequestMaybe();
-}
-
-void
-HttpSysServer::IssueNewRequestMaybe()
-{
- if (m_IsShuttingDown.load(std::memory_order::acquire))
- {
- return;
- }
-
- if (m_PendingRequests.load(std::memory_order::relaxed) >= m_MaxPendingRequests)
- {
- return;
- }
-
- std::unique_ptr<HttpSysTransaction> Request = std::make_unique<HttpSysTransaction>(*this);
-
- std::error_code ErrorCode;
- Request->IssueInitialRequest(ErrorCode);
-
- if (ErrorCode)
- {
- // No request was actually issued. What is the appropriate response?
-
- return;
- }
-
- // This may end up exceeding the MaxPendingRequests limit, but it's not
- // really a problem. I'm doing it this way mostly to avoid dealing with
- // exceptions here
- ++m_PendingRequests;
-
- Request.release();
-}
-
-void
-HttpSysServer::RegisterService(const char* UrlPath, HttpService& Service)
-{
- if (UrlPath[0] == '/')
- {
- ++UrlPath;
- }
-
- const std::wstring PathUtf16 = UTF8_to_UTF16(UrlPath);
- Service.SetUriPrefixLength(PathUtf16.size() + 1 /* leading slash */);
-
- // Convert to wide string
-
- for (const std::wstring& BaseUri : m_BaseUris)
- {
- std::wstring Url16 = BaseUri + PathUtf16;
-
- ULONG Result = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, Url16.c_str(), HTTP_URL_CONTEXT(&Service), 0 /* Reserved */);
-
- if (Result != NO_ERROR)
- {
- ZEN_ERROR("HttpAddUrlToUrlGroup failed with result: '{}'", GetSystemErrorAsString(Result));
-
- return;
- }
- }
-}
-
-void
-HttpSysServer::UnregisterService(const char* UrlPath, HttpService& Service)
-{
- ZEN_UNUSED(Service);
-
- if (UrlPath[0] == '/')
- {
- ++UrlPath;
- }
-
- const std::wstring PathUtf16 = UTF8_to_UTF16(UrlPath);
-
- // Convert to wide string
-
- for (const std::wstring& BaseUri : m_BaseUris)
- {
- std::wstring Url16 = BaseUri + PathUtf16;
-
- ULONG Result = HttpRemoveUrlFromUrlGroup(m_HttpUrlGroupId, Url16.c_str(), 0);
-
- if (Result != NO_ERROR)
- {
- ZEN_ERROR("HttpRemoveUrlFromUrlGroup failed with result: '{}'", GetSystemErrorAsString(Result));
- }
- }
-}
-
-//////////////////////////////////////////////////////////////////////////
-
-HttpSysTransaction::HttpSysTransaction(HttpSysServer& Server) : m_HttpServer(Server), m_CompletionHandler(&m_InitialHttpHandler)
-{
-}
-
-HttpSysTransaction::~HttpSysTransaction()
-{
-}
-
-PTP_IO
-HttpSysTransaction::Iocp()
-{
- return m_HttpServer.m_ThreadPool.Iocp();
-}
-
-HANDLE
-HttpSysTransaction::RequestQueueHandle()
-{
- return m_HttpServer.m_RequestQueueHandle;
-}
-
-void
-HttpSysTransaction::IssueInitialRequest(std::error_code& ErrorCode)
-{
- m_InitialHttpHandler.IssueRequest(ErrorCode);
-}
-
-void
-HttpSysTransaction::IoCompletionCallback(PTP_CALLBACK_INSTANCE Instance,
- PVOID pContext /* HttpSysServer */,
- PVOID pOverlapped,
- ULONG IoResult,
- ULONG_PTR NumberOfBytesTransferred,
- PTP_IO Io)
-{
- UNREFERENCED_PARAMETER(Io);
- UNREFERENCED_PARAMETER(Instance);
- UNREFERENCED_PARAMETER(pContext);
-
- // Note that for a given transaction we may be in this completion function on more
- // than one thread at any given moment. This means we need to be careful about what
- // happens in here
-
- HttpSysTransaction* Transaction = CONTAINING_RECORD(pOverlapped, HttpSysTransaction, m_HttpOverlapped);
-
- if (Transaction->HandleCompletion(IoResult, NumberOfBytesTransferred) == HttpSysTransaction::Status::kDone)
- {
- delete Transaction;
- }
-}
-
-bool
-HttpSysTransaction::IssueNextRequest(HttpSysRequestHandler* NewCompletionHandler)
-{
- HttpSysRequestHandler* CurrentHandler = m_CompletionHandler;
- m_CompletionHandler = NewCompletionHandler;
-
- auto _ = MakeGuard([this, CurrentHandler] {
- if ((CurrentHandler != &m_InitialHttpHandler) && (CurrentHandler != m_CompletionHandler))
- {
- delete CurrentHandler;
- }
- });
-
- if (NewCompletionHandler == nullptr)
- {
- return false;
- }
-
- try
- {
- std::error_code ErrorCode;
- m_CompletionHandler->IssueRequest(ErrorCode);
-
- if (!ErrorCode)
- {
- return true;
- }
-
- ZEN_WARN("IssueRequest() failed: '{}'", ErrorCode.message());
- }
- catch (std::exception& Ex)
- {
- ZEN_ERROR("exception caught in IssueNextRequest(): '{}'", Ex.what());
- }
-
- // something went wrong, no request is pending
- m_CompletionHandler = nullptr;
-
- return false;
-}
-
-HttpSysTransaction::Status
-HttpSysTransaction::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred)
-{
- // We use this to ensure sequential execution of completion handlers
- // for any given transaction. It also ensures all member variables are
- // in a consistent state for the current thread
-
- RwLock::ExclusiveLockScope _(m_CompletionMutex);
-
- bool IsRequestPending = false;
-
- if (HttpSysRequestHandler* CurrentHandler = m_CompletionHandler)
- {
- if ((CurrentHandler == &m_InitialHttpHandler) && m_InitialHttpHandler.IsInitialRequest())
- {
- // Ensure we have a sufficient number of pending requests outstanding
- m_HttpServer.OnHandlingRequest();
- }
-
- auto NewCompletionHandler = CurrentHandler->HandleCompletion(IoResult, NumberOfBytesTransferred);
-
- IsRequestPending = IssueNextRequest(NewCompletionHandler);
- }
-
- // Ensure new requests are enqueued as necessary
- m_HttpServer.IssueNewRequestMaybe();
-
- if (IsRequestPending)
- {
- // There is another request pending on this transaction, so it needs to remain valid
- return Status::kRequestPending;
- }
-
- if (m_HttpServer.m_IsRequestLoggingEnabled)
- {
- if (m_HandlerRequest.has_value())
- {
- m_HttpServer.m_RequestLog.info("{} {}", ToString(m_HandlerRequest->RequestVerb()), m_HandlerRequest->RelativeUri());
- }
- }
-
- // Transaction done, caller should clean up (delete) this instance
- return Status::kDone;
-}
-
-HttpSysServerRequest&
-HttpSysTransaction::InvokeRequestHandler(HttpService& Service, IoBuffer Payload)
-{
- HttpSysServerRequest& ThisRequest = m_HandlerRequest.emplace(*this, Service, Payload);
-
- // Default request handling
-
- if (!HandlePackageOffers(Service, ThisRequest, m_PackageHandler))
- {
- Service.HandleRequest(ThisRequest);
- }
-
- return ThisRequest;
-}
-
-//////////////////////////////////////////////////////////////////////////
-
-HttpSysServerRequest::HttpSysServerRequest(HttpSysTransaction& Tx, HttpService& Service, IoBuffer PayloadBuffer)
-: m_HttpTx(Tx)
-, m_PayloadBuffer(std::move(PayloadBuffer))
-{
- const HTTP_REQUEST* HttpRequestPtr = Tx.HttpRequest();
-
- const int PrefixLength = Service.UriPrefixLength();
- const int AbsPathLength = HttpRequestPtr->CookedUrl.AbsPathLength / sizeof(wchar_t);
-
- HttpContentType AcceptContentType = HttpContentType::kUnknownContentType;
-
- if (AbsPathLength >= PrefixLength)
- {
- // We convert the URI immediately because most of the code involved prefers to deal
- // with utf8. This is overhead which I'd prefer to avoid but for now we just have
- // to live with it
-
- WideToUtf8({(wchar_t*)HttpRequestPtr->CookedUrl.pAbsPath + PrefixLength, gsl::narrow<size_t>(AbsPathLength - PrefixLength)},
- m_UriUtf8);
-
- std::string_view UriSuffix8{m_UriUtf8};
-
- m_UriWithExtension = UriSuffix8; // Retain URI with extension for user access
- m_Uri = UriSuffix8;
-
- const size_t LastComponentIndex = UriSuffix8.find_last_of('/');
-
- if (LastComponentIndex != std::string_view::npos)
- {
- UriSuffix8.remove_prefix(LastComponentIndex);
- }
-
- const size_t LastDotIndex = UriSuffix8.find_last_of('.');
-
- if (LastDotIndex != std::string_view::npos)
- {
- UriSuffix8.remove_prefix(LastDotIndex + 1);
-
- AcceptContentType = ParseContentType(UriSuffix8);
- if (AcceptContentType != HttpContentType::kUnknownContentType)
- {
- m_Uri.remove_suffix(UriSuffix8.size() + 1);
- }
- }
- }
- else
- {
- m_UriUtf8.Reset();
- m_Uri = {};
- m_UriWithExtension = {};
- }
-
- if (uint16_t QueryStringLength = HttpRequestPtr->CookedUrl.QueryStringLength)
- {
- --QueryStringLength; // We skip the leading question mark
-
- WideToUtf8({(wchar_t*)(HttpRequestPtr->CookedUrl.pQueryString) + 1, QueryStringLength / sizeof(wchar_t)}, m_QueryStringUtf8);
- }
- else
- {
- m_QueryStringUtf8.Reset();
- }
-
- m_QueryString = std::string_view(m_QueryStringUtf8);
- m_Verb = TranslateHttpVerb(HttpRequestPtr->Verb);
- m_ContentLength = GetContentLength(HttpRequestPtr);
- m_ContentType = GetContentType(HttpRequestPtr);
-
- // It an explicit content type extension was specified then we'll use that over any
- // Accept: header value that may be present
-
- if (AcceptContentType != HttpContentType::kUnknownContentType)
- {
- m_AcceptType = AcceptContentType;
- }
- else
- {
- m_AcceptType = GetAcceptType(HttpRequestPtr);
- }
-
- if (m_Verb == HttpVerb::kHead)
- {
- SetSuppressResponseBody();
- }
-}
-
-Oid
-HttpSysServerRequest::ParseSessionId() const
-{
- const HTTP_REQUEST* HttpRequestPtr = m_HttpTx.HttpRequest();
-
- for (int i = 0; i < HttpRequestPtr->Headers.UnknownHeaderCount; ++i)
- {
- HTTP_UNKNOWN_HEADER& Header = HttpRequestPtr->Headers.pUnknownHeaders[i];
- std::string_view HeaderName{Header.pName, Header.NameLength};
-
- if (HeaderName == "UE-Session"sv)
- {
- if (Header.RawValueLength == Oid::StringLength)
- {
- return Oid::FromHexString({Header.pRawValue, Header.RawValueLength});
- }
- }
- }
-
- return {};
-}
-
-uint32_t
-HttpSysServerRequest::ParseRequestId() const
-{
- const HTTP_REQUEST* HttpRequestPtr = m_HttpTx.HttpRequest();
-
- for (int i = 0; i < HttpRequestPtr->Headers.UnknownHeaderCount; ++i)
- {
- HTTP_UNKNOWN_HEADER& Header = HttpRequestPtr->Headers.pUnknownHeaders[i];
- std::string_view HeaderName{Header.pName, Header.NameLength};
-
- if (HeaderName == "UE-Request"sv)
- {
- std::string_view RequestValue{Header.pRawValue, Header.RawValueLength};
- uint32_t RequestId = 0;
- std::from_chars(RequestValue.data(), RequestValue.data() + RequestValue.size(), RequestId);
- return RequestId;
- }
- }
-
- return 0;
-}
-
-IoBuffer
-HttpSysServerRequest::ReadPayload()
-{
- return m_PayloadBuffer;
-}
-
-void
-HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode)
-{
- ZEN_ASSERT(IsHandled() == false);
-
- auto Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)ResponseCode);
-
- if (SuppressBody())
- {
- Response->SuppressResponseBody();
- }
-
- m_NextCompletionHandler = Response;
-
- SetIsHandled();
-}
-
-void
-HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span<IoBuffer> Blobs)
-{
- ZEN_ASSERT(IsHandled() == false);
-
- auto Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)ResponseCode, ContentType, Blobs);
-
- if (SuppressBody())
- {
- Response->SuppressResponseBody();
- }
-
- m_NextCompletionHandler = Response;
-
- SetIsHandled();
-}
-
-void
-HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::u8string_view ResponseString)
-{
- ZEN_ASSERT(IsHandled() == false);
-
- auto Response =
- new HttpMessageResponseRequest(m_HttpTx, (uint16_t)ResponseCode, ContentType, ResponseString.data(), ResponseString.size());
-
- if (SuppressBody())
- {
- Response->SuppressResponseBody();
- }
-
- m_NextCompletionHandler = Response;
-
- SetIsHandled();
-}
-
-void
-HttpSysServerRequest::WriteResponseAsync(std::function<void(HttpServerRequest&)>&& ContinuationHandler)
-{
- if (m_HttpTx.Server().IsAsyncResponseEnabled())
- {
- m_NextCompletionHandler = new HttpAsyncWorkRequest(m_HttpTx, std::move(ContinuationHandler));
- }
- else
- {
- ContinuationHandler(m_HttpTx.ServerRequest());
- }
-}
-
-bool
-HttpSysServerRequest::TryGetRanges(HttpRanges& Ranges)
-{
- HTTP_REQUEST* Req = m_HttpTx.HttpRequest();
- const HTTP_KNOWN_HEADER& RangeHeader = Req->Headers.KnownHeaders[HttpHeaderRange];
-
- return TryParseHttpRangeHeader({RangeHeader.pRawValue, RangeHeader.RawValueLength}, Ranges);
-}
-
-//////////////////////////////////////////////////////////////////////////
-
-InitialRequestHandler::InitialRequestHandler(HttpSysTransaction& InRequest) : HttpSysRequestHandler(InRequest)
-{
-}
-
-InitialRequestHandler::~InitialRequestHandler()
-{
-}
-
-void
-InitialRequestHandler::IssueRequest(std::error_code& ErrorCode)
-{
- HttpSysTransaction& Tx = Transaction();
- PTP_IO Iocp = Tx.Iocp();
- HTTP_REQUEST* HttpReq = Tx.HttpRequest();
-
- StartThreadpoolIo(Iocp);
-
- ULONG HttpApiResult;
-
- if (IsInitialRequest())
- {
- HttpApiResult = HttpReceiveHttpRequest(Tx.RequestQueueHandle(),
- HTTP_NULL_ID,
- HTTP_RECEIVE_REQUEST_FLAG_COPY_BODY,
- HttpReq,
- RequestBufferSize(),
- NULL,
- Tx.Overlapped());
- }
- else
- {
- // The http.sys team recommends limiting the size to 128KB
- static const uint64_t kMaxBytesPerApiCall = 128 * 1024;
-
- uint64_t BytesToRead = m_ContentLength - m_CurrentPayloadOffset;
- const uint64_t BytesToReadThisCall = zen::Min(BytesToRead, kMaxBytesPerApiCall);
- void* BufferWriteCursor = reinterpret_cast<uint8_t*>(m_PayloadBuffer.MutableData()) + m_CurrentPayloadOffset;
-
- HttpApiResult = HttpReceiveRequestEntityBody(Tx.RequestQueueHandle(),
- HttpReq->RequestId,
- 0, /* Flags */
- BufferWriteCursor,
- gsl::narrow<ULONG>(BytesToReadThisCall),
- nullptr, // BytesReturned
- Tx.Overlapped());
- }
-
- if (HttpApiResult != ERROR_IO_PENDING && HttpApiResult != NO_ERROR)
- {
- CancelThreadpoolIo(Iocp);
-
- ErrorCode = MakeErrorCode(HttpApiResult);
-
- ZEN_WARN("HttpReceiveHttpRequest failed, error: '{}'", ErrorCode.message());
-
- return;
- }
-
- ErrorCode.clear();
-}
-
-HttpSysRequestHandler*
-InitialRequestHandler::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred)
-{
- auto _ = MakeGuard([&] { m_IsInitialRequest = false; });
-
- switch (IoResult)
- {
- default:
- case ERROR_OPERATION_ABORTED:
- return nullptr;
-
- case ERROR_MORE_DATA: // Insufficient buffer space
- case NO_ERROR:
- break;
- }
-
- // Route request
-
- try
- {
- HTTP_REQUEST* HttpReq = HttpRequest();
-
-# if 0
- for (int i = 0; i < HttpReq->RequestInfoCount; ++i)
- {
- auto& ReqInfo = HttpReq->pRequestInfo[i];
-
- switch (ReqInfo.InfoType)
- {
- case HttpRequestInfoTypeRequestTiming:
- {
- const HTTP_REQUEST_TIMING_INFO* TimingInfo = reinterpret_cast<HTTP_REQUEST_TIMING_INFO*>(ReqInfo.pInfo);
-
- ZEN_INFO("");
- }
- break;
- case HttpRequestInfoTypeAuth:
- ZEN_INFO("");
- break;
- case HttpRequestInfoTypeChannelBind:
- ZEN_INFO("");
- break;
- case HttpRequestInfoTypeSslProtocol:
- ZEN_INFO("");
- break;
- case HttpRequestInfoTypeSslTokenBindingDraft:
- ZEN_INFO("");
- break;
- case HttpRequestInfoTypeSslTokenBinding:
- ZEN_INFO("");
- break;
- case HttpRequestInfoTypeTcpInfoV0:
- {
- const TCP_INFO_v0* TcpInfo = reinterpret_cast<const TCP_INFO_v0*>(ReqInfo.pInfo);
-
- ZEN_INFO("");
- }
- break;
- case HttpRequestInfoTypeRequestSizing:
- {
- const HTTP_REQUEST_SIZING_INFO* SizingInfo = reinterpret_cast<const HTTP_REQUEST_SIZING_INFO*>(ReqInfo.pInfo);
- ZEN_INFO("");
- }
- break;
- case HttpRequestInfoTypeQuicStats:
- ZEN_INFO("");
- break;
- case HttpRequestInfoTypeTcpInfoV1:
- {
- const TCP_INFO_v1* TcpInfo = reinterpret_cast<const TCP_INFO_v1*>(ReqInfo.pInfo);
-
- ZEN_INFO("");
- }
- break;
- }
- }
-# endif
-
- if (HttpService* Service = reinterpret_cast<HttpService*>(HttpReq->UrlContext))
- {
- if (m_IsInitialRequest)
- {
- m_ContentLength = GetContentLength(HttpReq);
- const HttpContentType ContentType = GetContentType(HttpReq);
-
- if (m_ContentLength)
- {
- // Handle initial chunk read by copying any payload which has already been copied
- // into our embedded request buffer
-
- m_PayloadBuffer = IoBuffer(m_ContentLength);
- m_PayloadBuffer.SetContentType(ContentType);
-
- uint64_t BytesToRead = m_ContentLength;
- uint8_t* const BufferBase = reinterpret_cast<uint8_t*>(m_PayloadBuffer.MutableData());
- uint8_t* BufferWriteCursor = BufferBase;
-
- const int EntityChunkCount = HttpReq->EntityChunkCount;
-
- for (int i = 0; i < EntityChunkCount; ++i)
- {
- HTTP_DATA_CHUNK& EntityChunk = HttpReq->pEntityChunks[i];
-
- ZEN_ASSERT(EntityChunk.DataChunkType == HttpDataChunkFromMemory);
-
- const uint64_t BufferLength = EntityChunk.FromMemory.BufferLength;
-
- ZEN_ASSERT(BufferLength <= BytesToRead);
-
- memcpy(BufferWriteCursor, EntityChunk.FromMemory.pBuffer, BufferLength);
-
- BufferWriteCursor += BufferLength;
- BytesToRead -= BufferLength;
- }
-
- m_CurrentPayloadOffset = BufferWriteCursor - BufferBase;
- }
- }
- else
- {
- m_CurrentPayloadOffset += NumberOfBytesTransferred;
- }
-
- if (m_CurrentPayloadOffset != m_ContentLength)
- {
- // Body not complete, issue another read request to receive more body data
- return this;
- }
-
- // Request body received completely
-
- m_PayloadBuffer.MakeImmutable();
-
- HttpSysServerRequest& ThisRequest = Transaction().InvokeRequestHandler(*Service, m_PayloadBuffer);
-
- if (HttpSysRequestHandler* Response = ThisRequest.m_NextCompletionHandler)
- {
- return Response;
- }
-
- if (!ThisRequest.IsHandled())
- {
- return new HttpMessageResponseRequest(Transaction(), 404, "Not found"sv);
- }
- }
-
- // Unable to route
- return new HttpMessageResponseRequest(Transaction(), 404, "No suitable route found"sv);
- }
- catch (std::exception& ex)
- {
- ZEN_ERROR("Caught exception while handling request: '{}'", ex.what());
-
- return new HttpMessageResponseRequest(Transaction(), 500, ex.what());
- }
-}
-
-//////////////////////////////////////////////////////////////////////////
-//
-// HttpServer interface implementation
-//
-
-int
-HttpSysServer::Initialize(int BasePort)
-{
- int EffectivePort = InitializeServer(BasePort);
- StartServer();
- return EffectivePort;
-}
-
-void
-HttpSysServer::RequestExit()
-{
- m_ShutdownEvent.Set();
-}
-void
-HttpSysServer::RegisterService(HttpService& Service)
-{
- RegisterService(Service.BaseUri(), Service);
-}
-
-Ref<HttpServer>
-CreateHttpSysServer(int Concurrency, int BackgroundWorkerThreads)
-{
- return Ref<HttpServer>(new HttpSysServer(Concurrency, BackgroundWorkerThreads));
-}
-
-} // namespace zen
-#endif
diff --git a/zenhttp/httpsys.h b/zenhttp/httpsys.h
deleted file mode 100644
index d6bd34890..000000000
--- a/zenhttp/httpsys.h
+++ /dev/null
@@ -1,90 +0,0 @@
-// Copyright Epic Games, Inc. All Rights Reserved.
-
-#pragma once
-
-#include <zenhttp/httpserver.h>
-
-#ifndef ZEN_WITH_HTTPSYS
-# if ZEN_PLATFORM_WINDOWS
-# define ZEN_WITH_HTTPSYS 1
-# else
-# define ZEN_WITH_HTTPSYS 0
-# endif
-#endif
-
-#if ZEN_WITH_HTTPSYS
-# define _WINSOCKAPI_
-# include <zencore/windows.h>
-# include <zencore/workthreadpool.h>
-# include "iothreadpool.h"
-
-# include <http.h>
-
-namespace spdlog {
-class logger;
-}
-
-namespace zen {
-
-/**
- * @brief Windows implementation of HTTP server based on http.sys
- *
- * This requires elevation to function
- */
-class HttpSysServer : public HttpServer
-{
- friend class HttpSysTransaction;
-
-public:
- explicit HttpSysServer(unsigned int ThreadCount, unsigned int AsyncWorkThreadCount);
- ~HttpSysServer();
-
- // HttpServer interface implementation
-
- virtual int Initialize(int BasePort) override;
- virtual void Run(bool TestMode) override;
- virtual void RequestExit() override;
- virtual void RegisterService(HttpService& Service) override;
-
- WorkerThreadPool& WorkPool() { return m_AsyncWorkPool; }
-
- inline bool IsOk() const { return m_IsOk; }
- inline bool IsAsyncResponseEnabled() const { return m_IsAsyncResponseEnabled; }
-
-private:
- int InitializeServer(int BasePort);
- void Cleanup();
-
- void StartServer();
- void OnHandlingRequest();
- void IssueNewRequestMaybe();
-
- void RegisterService(const char* Endpoint, HttpService& Service);
- void UnregisterService(const char* Endpoint, HttpService& Service);
-
-private:
- spdlog::logger& m_Log;
- spdlog::logger& m_RequestLog;
- spdlog::logger& Log() { return m_Log; }
-
- bool m_IsOk = false;
- bool m_IsHttpInitialized = false;
- bool m_IsRequestLoggingEnabled = false;
- bool m_IsAsyncResponseEnabled = true;
-
- WinIoThreadPool m_ThreadPool;
- WorkerThreadPool m_AsyncWorkPool;
-
- std::vector<std::wstring> m_BaseUris; // eg: http://*:nnnn/
- HTTP_SERVER_SESSION_ID m_HttpSessionId = 0;
- HTTP_URL_GROUP_ID m_HttpUrlGroupId = 0;
- HANDLE m_RequestQueueHandle = 0;
- std::atomic_int32_t m_PendingRequests{0};
- std::atomic_int32_t m_IsShuttingDown{0};
- int32_t m_MinPendingRequests = 16;
- int32_t m_MaxPendingRequests = 128;
- Event m_ShutdownEvent;
-};
-
-} // namespace zen
-#endif
diff --git a/zenhttp/include/zenhttp/httpclient.h b/zenhttp/include/zenhttp/httpclient.h
deleted file mode 100644
index 8316a9b9f..000000000
--- a/zenhttp/include/zenhttp/httpclient.h
+++ /dev/null
@@ -1,47 +0,0 @@
-// Copyright Epic Games, Inc. All Rights Reserved.
-
-#pragma once
-
-#include "zenhttp.h"
-
-#include <zencore/iobuffer.h>
-#include <zencore/uid.h>
-#include <zenhttp/httpcommon.h>
-
-ZEN_THIRD_PARTY_INCLUDES_START
-#include <cpr/cpr.h>
-ZEN_THIRD_PARTY_INCLUDES_END
-
-namespace zen {
-
-class CbPackage;
-
-/** HTTP client implementation for Zen use cases
-
- Currently simple and synchronous, should become lean and asynchronous
- */
-class HttpClient
-{
-public:
- HttpClient(std::string_view BaseUri);
- ~HttpClient();
-
- struct Response
- {
- int StatusCode = 0;
- IoBuffer ResponsePayload; // Note: this also includes the content type
- };
-
- [[nodiscard]] Response Put(std::string_view Url, IoBuffer Payload);
- [[nodiscard]] Response Get(std::string_view Url);
- [[nodiscard]] Response TransactPackage(std::string_view Url, CbPackage Package);
- [[nodiscard]] Response Delete(std::string_view Url);
-
-private:
- std::string m_BaseUri;
- std::string m_SessionId;
-};
-
-} // namespace zen
-
-void httpclient_forcelink(); // internal
diff --git a/zenhttp/include/zenhttp/httpcommon.h b/zenhttp/include/zenhttp/httpcommon.h
deleted file mode 100644
index 19fda8db4..000000000
--- a/zenhttp/include/zenhttp/httpcommon.h
+++ /dev/null
@@ -1,181 +0,0 @@
-// Copyright Epic Games, Inc. All Rights Reserved.
-
-#pragma once
-
-#include <zencore/iobuffer.h>
-
-#include <string_view>
-
-#include <gsl/gsl-lite.hpp>
-
-namespace zen {
-
-using HttpContentType = ZenContentType;
-
-class IoBuffer;
-class CbObject;
-class CbPackage;
-class StringBuilderBase;
-
-struct HttpRange
-{
- uint32_t Start = ~uint32_t(0);
- uint32_t End = ~uint32_t(0);
-};
-
-using HttpRanges = std::vector<HttpRange>;
-
-std::string_view MapContentTypeToString(HttpContentType ContentType);
-extern HttpContentType (*ParseContentType)(const std::string_view& ContentTypeString);
-std::string_view ReasonStringForHttpResultCode(int HttpCode);
-bool TryParseHttpRangeHeader(std::string_view RangeHeader, HttpRanges& Ranges);
-
-[[nodiscard]] inline bool
-IsHttpSuccessCode(int HttpCode)
-{
- return (HttpCode >= 200) && (HttpCode < 300);
-}
-
-enum class HttpVerb : uint8_t
-{
- kGet = 1 << 0,
- kPut = 1 << 1,
- kPost = 1 << 2,
- kDelete = 1 << 3,
- kHead = 1 << 4,
- kCopy = 1 << 5,
- kOptions = 1 << 6
-};
-
-gsl_DEFINE_ENUM_BITMASK_OPERATORS(HttpVerb);
-
-const std::string_view ToString(HttpVerb Verb);
-
-enum class HttpResponseCode
-{
- // 1xx - Informational
-
- Continue = 100, //!< Indicates that the initial part of a request has been received and has not yet been rejected by the server.
- SwitchingProtocols = 101, //!< Indicates that the server understands and is willing to comply with the client's request, via the
- //!< Upgrade header field, for a change in the application protocol being used on this connection.
- Processing = 102, //!< Is an interim response used to inform the client that the server has accepted the complete request, but has not
- //!< yet completed it.
- EarlyHints = 103, //!< Indicates to the client that the server is likely to send a final response with the header fields included in
- //!< the informational response.
-
- // 2xx - Successful
-
- OK = 200, //!< Indicates that the request has succeeded.
- Created = 201, //!< Indicates that the request has been fulfilled and has resulted in one or more new resources being created.
- Accepted = 202, //!< Indicates that the request has been accepted for processing, but the processing has not been completed.
- NonAuthoritativeInformation = 203, //!< Indicates that the request was successful but the enclosed payload has been modified from that
- //!< of the origin server's 200 (OK) response by a transforming proxy.
- NoContent = 204, //!< Indicates that the server has successfully fulfilled the request and that there is no additional content to send
- //!< in the response payload body.
- ResetContent = 205, //!< Indicates that the server has fulfilled the request and desires that the user agent reset the \"document
- //!< view\", which caused the request to be sent, to its original state as received from the origin server.
- PartialContent = 206, //!< Indicates that the server is successfully fulfilling a range request for the target resource by transferring
- //!< one or more parts of the selected representation that correspond to the satisfiable ranges found in the
- //!< requests's Range header field.
- MultiStatus = 207, //!< Provides status for multiple independent operations.
- AlreadyReported = 208, //!< Used inside a DAV:propstat response element to avoid enumerating the internal members of multiple bindings
- //!< to the same collection repeatedly. [RFC 5842]
- IMUsed = 226, //!< The server has fulfilled a GET request for the resource, and the response is a representation of the result of one
- //!< or more instance-manipulations applied to the current instance.
-
- // 3xx - Redirection
-
- MultipleChoices = 300, //!< Indicates that the target resource has more than one representation, each with its own more specific
- //!< identifier, and information about the alternatives is being provided so that the user (or user agent) can
- //!< select a preferred representation by redirecting its request to one or more of those identifiers.
- MovedPermanently = 301, //!< Indicates that the target resource has been assigned a new permanent URI and any future references to this
- //!< resource ought to use one of the enclosed URIs.
- Found = 302, //!< Indicates that the target resource resides temporarily under a different URI.
- SeeOther = 303, //!< Indicates that the server is redirecting the user agent to a different resource, as indicated by a URI in the
- //!< Location header field, that is intended to provide an indirect response to the original request.
- NotModified = 304, //!< Indicates that a conditional GET request has been received and would have resulted in a 200 (OK) response if it
- //!< were not for the fact that the condition has evaluated to false.
- UseProxy = 305, //!< \deprecated \parblock Due to security concerns regarding in-band configuration of a proxy. \endparblock
- //!< The requested resource MUST be accessed through the proxy given by the Location field.
- TemporaryRedirect = 307, //!< Indicates that the target resource resides temporarily under a different URI and the user agent MUST NOT
- //!< change the request method if it performs an automatic redirection to that URI.
- PermanentRedirect = 308, //!< The target resource has been assigned a new permanent URI and any future references to this resource
- //!< ought to use one of the enclosed URIs. [...] This status code is similar to 301 Moved Permanently
- //!< (Section 7.3.2 of rfc7231), except that it does not allow rewriting the request method from POST to GET.
-
- // 4xx - Client Error
- BadRequest = 400, //!< Indicates that the server cannot or will not process the request because the received syntax is invalid,
- //!< nonsensical, or exceeds some limitation on what the server is willing to process.
- Unauthorized = 401, //!< Indicates that the request has not been applied because it lacks valid authentication credentials for the
- //!< target resource.
- PaymentRequired = 402, //!< *Reserved*
- Forbidden = 403, //!< Indicates that the server understood the request but refuses to authorize it.
- NotFound = 404, //!< Indicates that the origin server did not find a current representation for the target resource or is not willing
- //!< to disclose that one exists.
- MethodNotAllowed = 405, //!< Indicates that the method specified in the request-line is known by the origin server but not supported by
- //!< the target resource.
- NotAcceptable = 406, //!< Indicates that the target resource does not have a current representation that would be acceptable to the
- //!< user agent, according to the proactive negotiation header fields received in the request, and the server is
- //!< unwilling to supply a default representation.
- ProxyAuthenticationRequired =
- 407, //!< Is similar to 401 (Unauthorized), but indicates that the client needs to authenticate itself in order to use a proxy.
- RequestTimeout =
- 408, //!< Indicates that the server did not receive a complete request message within the time that it was prepared to wait.
- Conflict = 409, //!< Indicates that the request could not be completed due to a conflict with the current state of the resource.
- Gone = 410, //!< Indicates that access to the target resource is no longer available at the origin server and that this condition is
- //!< likely to be permanent.
- LengthRequired = 411, //!< Indicates that the server refuses to accept the request without a defined Content-Length.
- PreconditionFailed =
- 412, //!< Indicates that one or more preconditions given in the request header fields evaluated to false when tested on the server.
- PayloadTooLarge = 413, //!< Indicates that the server is refusing to process a request because the request payload is larger than the
- //!< server is willing or able to process.
- URITooLong = 414, //!< Indicates that the server is refusing to service the request because the request-target is longer than the
- //!< server is willing to interpret.
- UnsupportedMediaType = 415, //!< Indicates that the origin server is refusing to service the request because the payload is in a format
- //!< not supported by the target resource for this method.
- RangeNotSatisfiable = 416, //!< Indicates that none of the ranges in the request's Range header field overlap the current extent of the
- //!< selected resource or that the set of ranges requested has been rejected due to invalid ranges or an
- //!< excessive request of small or overlapping ranges.
- ExpectationFailed = 417, //!< Indicates that the expectation given in the request's Expect header field could not be met by at least
- //!< one of the inbound servers.
- ImATeapot = 418, //!< Any attempt to brew coffee with a teapot should result in the error code 418 I'm a teapot.
- UnprocessableEntity = 422, //!< Means the server understands the content type of the request entity (hence a 415(Unsupported Media
- //!< Type) status code is inappropriate), and the syntax of the request entity is correct (thus a 400 (Bad
- //!< Request) status code is inappropriate) but was unable to process the contained instructions.
- Locked = 423, //!< Means the source or destination resource of a method is locked.
- FailedDependency = 424, //!< Means that the method could not be performed on the resource because the requested action depended on
- //!< another action and that action failed.
- UpgradeRequired = 426, //!< Indicates that the server refuses to perform the request using the current protocol but might be willing to
- //!< do so after the client upgrades to a different protocol.
- PreconditionRequired = 428, //!< Indicates that the origin server requires the request to be conditional.
- TooManyRequests = 429, //!< Indicates that the user has sent too many requests in a given amount of time (\"rate limiting\").
- RequestHeaderFieldsTooLarge =
- 431, //!< Indicates that the server is unwilling to process the request because its header fields are too large.
- UnavailableForLegalReasons =
- 451, //!< This status code indicates that the server is denying access to the resource in response to a legal demand.
-
- // 5xx - Server Error
-
- InternalServerError =
- 500, //!< Indicates that the server encountered an unexpected condition that prevented it from fulfilling the request.
- NotImplemented = 501, //!< Indicates that the server does not support the functionality required to fulfill the request.
- BadGateway = 502, //!< Indicates that the server, while acting as a gateway or proxy, received an invalid response from an inbound
- //!< server it accessed while attempting to fulfill the request.
- ServiceUnavailable = 503, //!< Indicates that the server is currently unable to handle the request due to a temporary overload or
- //!< scheduled maintenance, which will likely be alleviated after some delay.
- GatewayTimeout = 504, //!< Indicates that the server, while acting as a gateway or proxy, did not receive a timely response from an
- //!< upstream server it needed to access in order to complete the request.
- HTTPVersionNotSupported = 505, //!< Indicates that the server does not support, or refuses to support, the protocol version that was
- //!< used in the request message.
- VariantAlsoNegotiates =
- 506, //!< Indicates that the server has an internal configuration error: the chosen variant resource is configured to engage in
- //!< transparent content negotiation itself, and is therefore not a proper end point in the negotiation process.
- InsufficientStorage = 507, //!< Means the method could not be performed on the resource because the server is unable to store the
- //!< representation needed to successfully complete the request.
- LoopDetected = 508, //!< Indicates that the server terminated an operation because it encountered an infinite loop while processing a
- //!< request with "Depth: infinity". [RFC 5842]
- NotExtended = 510, //!< The policy for accessing the resource has not been met in the request. [RFC 2774]
- NetworkAuthenticationRequired = 511, //!< Indicates that the client needs to authenticate to gain network access.
-};
-
-} // namespace zen
diff --git a/zenhttp/include/zenhttp/httpserver.h b/zenhttp/include/zenhttp/httpserver.h
deleted file mode 100644
index 3b9fa50b4..000000000
--- a/zenhttp/include/zenhttp/httpserver.h
+++ /dev/null
@@ -1,315 +0,0 @@
-// Copyright Epic Games, Inc. All Rights Reserved.
-
-#pragma once
-
-#include "zenhttp.h"
-
-#include <zencore/compactbinary.h>
-#include <zencore/enumflags.h>
-#include <zencore/iobuffer.h>
-#include <zencore/iohash.h>
-#include <zencore/refcount.h>
-#include <zencore/string.h>
-#include <zencore/uid.h>
-#include <zenhttp/httpcommon.h>
-
-#include <functional>
-#include <gsl/gsl-lite.hpp>
-#include <list>
-#include <map>
-#include <regex>
-#include <span>
-#include <unordered_map>
-
-namespace zen {
-
-/** HTTP Server Request
- */
-class HttpServerRequest
-{
-public:
- HttpServerRequest();
- ~HttpServerRequest();
-
- // Synchronous operations
-
- [[nodiscard]] inline std::string_view RelativeUri() const { return m_Uri; } // Returns URI without service prefix
- [[nodiscard]] std::string_view RelativeUriWithExtension() const { return m_UriWithExtension; }
- [[nodiscard]] inline std::string_view QueryString() const { return m_QueryString; }
-
- struct QueryParams
- {
- std::vector<std::pair<std::string_view, std::string_view>> KvPairs;
-
- std::string_view GetValue(std::string_view ParamName) const
- {
- for (const auto& Kv : KvPairs)
- {
- const std::string_view& Key = Kv.first;
-
- if (Key.size() == ParamName.size())
- {
- if (0 == StrCaseCompare(Key.data(), ParamName.data(), Key.size()))
- {
- return Kv.second;
- }
- }
- }
-
- return std::string_view();
- }
- };
-
- virtual bool TryGetRanges(HttpRanges&) { return false; }
-
- QueryParams GetQueryParams();
-
- inline HttpVerb RequestVerb() const { return m_Verb; }
- inline HttpContentType RequestContentType() { return m_ContentType; }
- inline HttpContentType AcceptContentType() { return m_AcceptType; }
-
- inline uint64_t ContentLength() const { return m_ContentLength; }
- Oid SessionId() const;
- uint32_t RequestId() const;
-
- inline bool IsHandled() const { return !!(m_Flags & kIsHandled); }
- inline bool SuppressBody() const { return !!(m_Flags & kSuppressBody); }
- inline void SetSuppressResponseBody() { m_Flags |= kSuppressBody; }
-
- /** Read POST/PUT payload for request body, which is always available without delay
- */
- virtual IoBuffer ReadPayload() = 0;
-
- ZENCORE_API CbObject ReadPayloadObject();
- ZENCORE_API CbPackage ReadPayloadPackage();
-
- /** Respond with payload
-
- No data will have been sent when any of these functions return. Instead, the response will be transmitted
- asynchronously, after returning from a request handler function.
-
- Note that this is destructive in the sense that the IoBuffer instances referred to by Blobs will be
- moved into our response handler array where they are kept alive, in order to reduce ref-counting storms
- */
- virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span<IoBuffer> Blobs) = 0;
- virtual void WriteResponse(HttpResponseCode ResponseCode) = 0;
- virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::u8string_view ResponseString) = 0;
- virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, CompositeBuffer& Payload);
-
- void WriteResponse(HttpResponseCode ResponseCode, CbObject Data);
- void WriteResponse(HttpResponseCode ResponseCode, CbArray Array);
- void WriteResponse(HttpResponseCode ResponseCode, CbPackage Package);
- void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::string_view ResponseString);
- void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, IoBuffer Blob);
-
- virtual void WriteResponseAsync(std::function<void(HttpServerRequest&)>&& ContinuationHandler) = 0;
-
-protected:
- enum
- {
- kIsHandled = 1 << 0,
- kSuppressBody = 1 << 1,
- kHaveRequestId = 1 << 2,
- kHaveSessionId = 1 << 3,
- };
-
- mutable uint32_t m_Flags = 0;
- HttpVerb m_Verb = HttpVerb::kGet;
- HttpContentType m_ContentType = HttpContentType::kBinary;
- HttpContentType m_AcceptType = HttpContentType::kUnknownContentType;
- uint64_t m_ContentLength = ~0ull;
- std::string_view m_Uri;
- std::string_view m_UriWithExtension;
- std::string_view m_QueryString;
- mutable uint32_t m_RequestId = ~uint32_t(0);
- mutable Oid m_SessionId = Oid::Zero;
-
- inline void SetIsHandled() { m_Flags |= kIsHandled; }
-
- virtual Oid ParseSessionId() const = 0;
- virtual uint32_t ParseRequestId() const = 0;
-};
-
-class IHttpPackageHandler : public RefCounted
-{
-public:
- virtual void FilterOffer(std::vector<IoHash>& OfferCids) = 0;
- virtual void OnRequestBegin() = 0;
- virtual IoBuffer CreateTarget(const IoHash& Cid, uint64_t StorageSize) = 0;
- virtual void OnRequestComplete() = 0;
-};
-
-/**
- * Base class for implementing an HTTP "service"
- *
- * A service exposes one or more endpoints with a certain URI prefix
- *
- */
-
-class HttpService
-{
-public:
- HttpService() = default;
- virtual ~HttpService() = default;
-
- virtual const char* BaseUri() const = 0;
- virtual void HandleRequest(HttpServerRequest& HttpServiceRequest) = 0;
- virtual Ref<IHttpPackageHandler> HandlePackageRequest(HttpServerRequest& HttpServiceRequest);
-
- // Internals
-
- inline void SetUriPrefixLength(size_t PrefixLength) { m_UriPrefixLength = (int)PrefixLength; }
- inline int UriPrefixLength() const { return m_UriPrefixLength; }
-
-private:
- int m_UriPrefixLength = 0;
-};
-
-/** HTTP server
- *
- * Implements the main event loop to service HTTP requests, and handles routing
- * requests to the appropriate handler as registered via RegisterService
- */
-class HttpServer : public RefCounted
-{
-public:
- virtual void RegisterService(HttpService& Service) = 0;
- virtual int Initialize(int BasePort) = 0;
- virtual void Run(bool IsInteractiveSession) = 0;
- virtual void RequestExit() = 0;
-};
-
-Ref<HttpServer> CreateHttpServer(std::string_view ServerClass);
-
-//////////////////////////////////////////////////////////////////////////
-
-class HttpRouterRequest
-{
-public:
- HttpRouterRequest(HttpServerRequest& Request) : m_HttpRequest(Request) {}
-
- ZENCORE_API std::string GetCapture(uint32_t Index) const;
- inline HttpServerRequest& ServerRequest() { return m_HttpRequest; }
-
-private:
- using MatchResults_t = std::match_results<std::string_view::const_iterator>;
-
- HttpServerRequest& m_HttpRequest;
- MatchResults_t m_Match;
-
- friend class HttpRequestRouter;
-};
-
-inline std::string
-HttpRouterRequest::GetCapture(uint32_t Index) const
-{
- ZEN_ASSERT(Index < m_Match.size());
-
- return m_Match[Index];
-}
-
-/** HTTP request router helper
- *
- * This helper class allows a service implementer to register one or more
- * endpoints using pattern matching (currently using regex matching)
- *
- * This is intended to be initialized once only, there is no thread
- * safety so you can absolutely not add or remove endpoints once the handler
- * goes live
- */
-
-class HttpRequestRouter
-{
-public:
- typedef std::function<void(HttpRouterRequest&)> HandlerFunc_t;
-
- /**
- * @brief Add pattern which can be referenced by name, commonly used for URL components
- * @param Id String used to identify patterns for replacement
- * @param Regex String which will replace the Id string in any registered URL paths
- */
- void AddPattern(const char* Id, const char* Regex);
-
- /**
- * @brief Register a an endpoint handler for the given route
- * @param Regex Regular expression used to match the handler to a request. This may
- * contain pattern aliases registered via AddPattern
- * @param HandlerFunc Handler function to call for any matching request
- * @param SupportedVerbs Supported HTTP verbs for this handler
- */
- void RegisterRoute(const char* Regex, HandlerFunc_t&& HandlerFunc, HttpVerb SupportedVerbs);
-
- void ProcessRegexSubstitutions(const char* Regex, StringBuilderBase& ExpandedRegex);
-
- /**
- * @brief HTTP request handling function - this should be called to route the
- * request to a registered handler
- * @param Request Request to route to a handler
- * @return Function returns true if the request was routed successfully
- */
- bool HandleRequest(zen::HttpServerRequest& Request);
-
-private:
- struct HandlerEntry
- {
- HandlerEntry(const char* Regex, HttpVerb SupportedVerbs, HandlerFunc_t&& Handler, const char* Pattern)
- : RegEx(Regex, std::regex::icase | std::regex::ECMAScript)
- , Verbs(SupportedVerbs)
- , Handler(std::move(Handler))
- , Pattern(Pattern)
- {
- }
-
- ~HandlerEntry() = default;
-
- std::regex RegEx;
- HttpVerb Verbs;
- HandlerFunc_t Handler;
- const char* Pattern;
-
- private:
- HandlerEntry& operator=(const HandlerEntry&) = delete;
- HandlerEntry(const HandlerEntry&) = delete;
- };
-
- std::list<HandlerEntry> m_Handlers;
- std::unordered_map<std::string, std::string> m_PatternMap;
-};
-
-/** HTTP RPC request helper
- */
-
-class RpcResult
-{
- RpcResult(CbObject Result) : m_Result(std::move(Result)) {}
-
-private:
- CbObject m_Result;
-};
-
-class HttpRpcHandler
-{
-public:
- HttpRpcHandler();
- ~HttpRpcHandler();
-
- HttpRpcHandler(const HttpRpcHandler&) = delete;
- HttpRpcHandler operator=(const HttpRpcHandler&) = delete;
-
- void AddRpc(std::string_view RpcId, std::function<void(CbObject& RpcArgs)> HandlerFunction);
-
-private:
- struct RpcFunction
- {
- std::function<void(CbObject& RpcArgs)> Function;
- std::string Identifier;
- };
-
- std::map<std::string, RpcFunction> m_Functions;
-};
-
-bool HandlePackageOffers(HttpService& Service, HttpServerRequest& Request, Ref<IHttpPackageHandler>& PackageHandlerRef);
-
-void http_forcelink(); // internal
-
-} // namespace zen
diff --git a/zenhttp/include/zenhttp/httpshared.h b/zenhttp/include/zenhttp/httpshared.h
deleted file mode 100644
index d335572c5..000000000
--- a/zenhttp/include/zenhttp/httpshared.h
+++ /dev/null
@@ -1,163 +0,0 @@
-// Copyright Epic Games, Inc. All Rights Reserved.
-
-#pragma once
-
-#include <zencore/compactbinarypackage.h>
-#include <zencore/iobuffer.h>
-#include <zencore/iohash.h>
-
-#include <functional>
-#include <gsl/gsl-lite.hpp>
-
-namespace zen {
-
-class IoBuffer;
-class CbPackage;
-class CompositeBuffer;
-
-/** _____ _ _____ _
- / ____| | | __ \ | |
- | | | |__ | |__) |_ _ ___| | ____ _ __ _ ___
- | | | '_ \| ___/ _` |/ __| |/ / _` |/ _` |/ _ \
- | |____| |_) | | | (_| | (__| < (_| | (_| | __/
- \_____|_.__/|_| \__,_|\___|_|\_\__,_|\__, |\___|
- __/ |
- |___/
-
- Structures and code related to handling CbPackage transactions
-
- CbPackage instances are marshaled across the wire using a distinct message
- format. We don't use the CbPackage serialization format provided by the
- CbPackage implementation itself since that does not provide much flexibility
- in how the attachment payloads are transmitted. The scheme below separates
- metadata cleanly from payloads and this enables us to more efficiently
- transmit them either via sendfile/TransmitFile like mechanisms, or by
- reference/memory mapping in the local case.
- */
-
-struct CbPackageHeader
-{
- uint32_t HeaderMagic;
- uint32_t AttachmentCount; // TODO: should add ability to opt out of implicit root document?
- uint32_t Reserved1;
- uint32_t Reserved2;
-};
-
-static_assert(sizeof(CbPackageHeader) == 16);
-
-enum : uint32_t
-{
- kCbPkgMagic = 0xaa77aacc
-};
-
-struct CbAttachmentEntry
-{
- uint64_t PayloadSize; // Size of the associated payload data in the message
- uint32_t Flags; // See flags below
- IoHash AttachmentHash; // Content Id for the attachment
-
- enum
- {
- kIsCompressed = (1u << 0), // Is marshaled using compressed buffer storage format
- kIsObject = (1u << 1), // Is compact binary object
- kIsError = (1u << 2), // Is error (compact binary formatted) object
- kIsLocalRef = (1u << 3), // Is "local reference"
- };
-};
-
-struct CbAttachmentReferenceHeader
-{
- uint64_t PayloadByteOffset = 0;
- uint64_t PayloadByteSize = ~0u;
- uint16_t AbsolutePathLength = 0;
-
- // This header will be followed by UTF8 encoded absolute path to backing file
-};
-
-static_assert(sizeof(CbAttachmentEntry) == 32);
-
-enum class FormatFlags
-{
- kDefault = 0,
- kAllowLocalReferences = (1u << 0),
- kDenyPartialLocalReferences = (1u << 1)
-};
-
-gsl_DEFINE_ENUM_BITMASK_OPERATORS(FormatFlags);
-
-enum class RpcAcceptOptions : uint16_t
-{
- kNone = 0,
- kAllowLocalReferences = (1u << 0),
- kAllowPartialLocalReferences = (1u << 1)
-};
-
-gsl_DEFINE_ENUM_BITMASK_OPERATORS(RpcAcceptOptions);
-
-std::vector<IoBuffer> FormatPackageMessage(const CbPackage& Data, FormatFlags Flags, int TargetProcessPid = 0);
-CompositeBuffer FormatPackageMessageBuffer(const CbPackage& Data, FormatFlags Flags, int TargetProcessPid = 0);
-CbPackage ParsePackageMessage(
- IoBuffer Payload,
- std::function<IoBuffer(const IoHash& Cid, uint64_t Size)> CreateBuffer = [](const IoHash&, uint64_t Size) -> IoBuffer {
- return IoBuffer{Size};
- });
-bool IsPackageMessage(IoBuffer Payload);
-
-bool ParsePackageMessageWithLegacyFallback(const IoBuffer& Response, CbPackage& OutPackage);
-
-std::vector<IoBuffer> FormatPackageMessage(const CbPackage& Data, int TargetProcessPid = 0);
-CompositeBuffer FormatPackageMessageBuffer(const CbPackage& Data, int TargetProcessPid = 0);
-
-/** Streaming reader for compact binary packages
-
- The goal is to ultimately support zero-copy I/O, but for now there'll be some
- copying involved on some platforms at least.
-
- This approach to deserializing CbPackage data is more efficient than
- `ParsePackageMessage` since it does not require the entire message to
- be resident in a memory buffer
-
- */
-class CbPackageReader
-{
-public:
- CbPackageReader();
- ~CbPackageReader();
-
- void SetPayloadBufferCreator(std::function<IoBuffer(const IoHash& Cid, uint64_t Size)> CreateBuffer);
-
- /** Process compact binary package data stream
-
- The data stream must be in the serialization format produced by FormatPackageMessage
-
- \return How many bytes must be fed to this function in the next call
- */
- uint64_t ProcessPackageHeaderData(const void* Data, uint64_t DataBytes);
-
- void Finalize();
- const std::vector<CbAttachment>& GetAttachments() { return m_Attachments; }
- CbObject GetRootObject() { return m_RootObject; }
- std::span<IoBuffer> GetPayloadBuffers() { return m_PayloadBuffers; }
-
-private:
- enum class State
- {
- kInitialState,
- kReadingHeader,
- kReadingAttachmentEntries,
- kReadingBuffers
- } m_CurrentState = State::kInitialState;
-
- std::function<IoBuffer(const IoHash& Cid, uint64_t Size)> m_CreateBuffer;
- std::vector<IoBuffer> m_PayloadBuffers;
- std::vector<CbAttachmentEntry> m_AttachmentEntries;
- std::vector<CbAttachment> m_Attachments;
- CbObject m_RootObject;
- CbPackageHeader m_PackageHeader;
-
- IoBuffer MarshalLocalChunkReference(IoBuffer AttachmentBuffer);
-};
-
-void forcelink_httpshared();
-
-} // namespace zen
diff --git a/zenhttp/include/zenhttp/websocket.h b/zenhttp/include/zenhttp/websocket.h
deleted file mode 100644
index adca7e988..000000000
--- a/zenhttp/include/zenhttp/websocket.h
+++ /dev/null
@@ -1,256 +0,0 @@
-// Copyright Epic Games, Inc. All Rights Reserved.
-
-#include <zencore/compactbinarypackage.h>
-#include <zencore/memory.h>
-
-#include <compare>
-#include <functional>
-#include <future>
-#include <memory>
-#include <optional>
-
-#pragma once
-
-namespace asio {
-class io_context;
-}
-
-namespace zen {
-
-class BinaryWriter;
-
-/**
- * A unique socket ID.
- */
-class WebSocketId
-{
- static std::atomic_uint32_t NextId;
-
-public:
- WebSocketId() = default;
-
- uint32_t Value() const { return m_Value; }
-
- auto operator<=>(const WebSocketId&) const = default;
-
- static WebSocketId New() { return WebSocketId(NextId.fetch_add(1)); }
-
-private:
- WebSocketId(uint32_t Value) : m_Value(Value) {}
-
- uint32_t m_Value{};
-};
-
-/**
- * Type of web socket message.
- */
-enum class WebSocketMessageType : uint8_t
-{
- kInvalid,
- kNotification,
- kRequest,
- kStreamRequest,
- kResponse,
- kStreamResponse,
- kStreamCompleteResponse,
- kCount
-};
-
-inline std::string_view
-ToString(WebSocketMessageType Type)
-{
- switch (Type)
- {
- case WebSocketMessageType::kInvalid:
- return std::string_view("Invalid");
- case WebSocketMessageType::kNotification:
- return std::string_view("Notification");
- case WebSocketMessageType::kRequest:
- return std::string_view("Request");
- case WebSocketMessageType::kStreamRequest:
- return std::string_view("StreamRequest");
- case WebSocketMessageType::kResponse:
- return std::string_view("Response");
- case WebSocketMessageType::kStreamResponse:
- return std::string_view("StreamResponse");
- case WebSocketMessageType::kStreamCompleteResponse:
- return std::string_view("StreamCompleteResponse");
- default:
- return std::string_view("Unknown");
- };
-}
-
-/**
- * Web socket message.
- */
-class WebSocketMessage
-{
- struct Header
- {
- static constexpr uint32_t ExpectedMagic = 0x7a776d68; // zwmh
-
- uint64_t MessageSize{};
- uint32_t Magic{ExpectedMagic};
- uint32_t CorrelationId{};
- uint32_t StatusCode{200u};
- WebSocketMessageType MessageType{};
- uint8_t Reserved[3] = {0};
-
- bool IsValid() const;
- };
-
- static_assert(sizeof(Header) == 24);
-
- static std::atomic_uint32_t NextCorrelationId;
-
-public:
- static constexpr size_t HeaderSize = sizeof(Header);
-
- WebSocketMessage() = default;
-
- WebSocketId SocketId() const { return m_SocketId; }
- void SetSocketId(WebSocketId Id) { m_SocketId = Id; }
- uint64_t MessageSize() const { return m_Header.MessageSize; }
- void SetMessageType(WebSocketMessageType MessageType);
- void SetCorrelationId(uint32_t Id) { m_Header.CorrelationId = Id; }
- uint32_t CorrelationId() const { return m_Header.CorrelationId; }
- uint32_t StatusCode() const { return m_Header.StatusCode; }
- void SetStatusCode(uint32_t StatusCode) { m_Header.StatusCode = StatusCode; }
- WebSocketMessageType MessageType() const { return m_Header.MessageType; }
-
- const CbPackage& Body() const { return m_Body.value(); }
- void SetBody(CbPackage&& Body);
- void SetBody(CbObject&& Body);
- bool HasBody() const { return m_Body.has_value(); }
-
- void Save(BinaryWriter& Writer);
- bool TryLoadHeader(MemoryView Memory);
-
- bool IsValid() const { return m_Header.MessageType != WebSocketMessageType::kInvalid; }
-
-private:
- Header m_Header{};
- WebSocketId m_SocketId{};
- std::optional<CbPackage> m_Body;
-};
-
-class WebSocketServer;
-
-/**
- * Base class for handling web socket requests and notifications from connected client(s).
- */
-class WebSocketService
-{
-public:
- virtual ~WebSocketService() = default;
-
- void Configure(WebSocketServer& Server);
-
- virtual bool HandleRequest(const WebSocketMessage&) { ZEN_ASSERT(false); }
- virtual void HandleNotification(const WebSocketMessage&) { ZEN_ASSERT(false); }
-
-protected:
- WebSocketService() = default;
-
- virtual void RegisterHandlers(WebSocketServer& Server) = 0;
- void SendStreamResponse(WebSocketId SocketId, uint32_t CorrelationId, CbPackage&& StreamResponse, bool IsStreamComplete);
- void SendStreamResponse(WebSocketId SocketId, uint32_t CorrelationId, CbObject&& StreamResponse, bool IsStreamComplete);
-
- WebSocketServer& SocketServer()
- {
- ZEN_ASSERT(m_SocketServer);
- return *m_SocketServer;
- }
-
-private:
- WebSocketServer* m_SocketServer{};
-};
-
-/**
- * Server options.
- */
-struct WebSocketServerOptions
-{
- uint16_t Port = 2337;
- uint32_t ThreadCount = 1;
-};
-
-/**
- * The web socket server manages client connections and routing of requests and notifications.
- */
-class WebSocketServer
-{
-public:
- virtual ~WebSocketServer() = default;
-
- virtual bool Run() = 0;
- virtual void Shutdown() = 0;
-
- virtual void RegisterService(WebSocketService& Service) = 0;
- virtual void RegisterNotificationHandler(std::string_view Key, WebSocketService& Service) = 0;
- virtual void RegisterRequestHandler(std::string_view Key, WebSocketService& Service) = 0;
-
- virtual void SendNotification(WebSocketMessage&& Notification) = 0;
- virtual void SendResponse(WebSocketMessage&& Response) = 0;
-
- static std::unique_ptr<WebSocketServer> Create(const WebSocketServerOptions& Options);
-};
-
-/**
- * The state of the web socket.
- */
-enum class WebSocketState : uint32_t
-{
- kNone,
- kHandshaking,
- kConnected,
- kDisconnected,
- kError
-};
-
-/**
- * Type of web socket client event.
- */
-enum class WebSocketEvent : uint32_t
-{
- kConnected,
- kDisconnected,
- kError
-};
-
-/**
- * Web socket client connection info.
- */
-struct WebSocketConnectInfo
-{
- std::string Host;
- int16_t Port{8848};
- std::string Endpoint;
- std::vector<std::string> Protocols;
- uint16_t Version{13};
-};
-
-/**
- * A connection to a web socket server for sending requests and listening for notifications.
- */
-class WebSocketClient
-{
-public:
- using EventCallback = std::function<void()>;
- using NotificationCallback = std::function<void(WebSocketMessage&&)>;
-
- virtual ~WebSocketClient() = default;
-
- virtual std::future<bool> Connect(const WebSocketConnectInfo& Info) = 0;
- virtual void Disconnect() = 0;
- virtual bool IsConnected() const = 0;
- virtual WebSocketState State() const = 0;
-
- virtual std::future<WebSocketMessage> SendRequest(WebSocketMessage&& Request) = 0;
- virtual void OnNotification(NotificationCallback&& Cb) = 0;
- virtual void OnEvent(WebSocketEvent Evt, EventCallback&& Cb) = 0;
-
- static std::shared_ptr<WebSocketClient> Create(asio::io_context& IoCtx);
-};
-
-} // namespace zen
diff --git a/zenhttp/include/zenhttp/zenhttp.h b/zenhttp/include/zenhttp/zenhttp.h
deleted file mode 100644
index 59c64b31f..000000000
--- a/zenhttp/include/zenhttp/zenhttp.h
+++ /dev/null
@@ -1,21 +0,0 @@
-// Copyright Epic Games, Inc. All Rights Reserved.
-
-#pragma once
-
-#include <zencore/zencore.h>
-
-#ifndef ZEN_WITH_HTTPSYS
-# if ZEN_PLATFORM_WINDOWS
-# define ZEN_WITH_HTTPSYS 1
-# else
-# define ZEN_WITH_HTTPSYS 0
-# endif
-#endif
-
-#define ZENHTTP_API // Placeholder to allow DLL configs in the future
-
-namespace zen {
-
-ZENHTTP_API void zenhttp_forcelinktests();
-
-}
diff --git a/zenhttp/iothreadpool.cpp b/zenhttp/iothreadpool.cpp
deleted file mode 100644
index 6087e69ec..000000000
--- a/zenhttp/iothreadpool.cpp
+++ /dev/null
@@ -1,49 +0,0 @@
-// Copyright Epic Games, Inc. All Rights Reserved.
-
-#include "iothreadpool.h"
-
-#include <zencore/except.h>
-
-#if ZEN_PLATFORM_WINDOWS
-
-namespace zen {
-
-WinIoThreadPool::WinIoThreadPool(int InThreadCount)
-{
- // Thread pool setup
-
- m_ThreadPool = CreateThreadpool(NULL);
-
- SetThreadpoolThreadMinimum(m_ThreadPool, InThreadCount);
- SetThreadpoolThreadMaximum(m_ThreadPool, InThreadCount * 2);
-
- InitializeThreadpoolEnvironment(&m_CallbackEnvironment);
-
- m_CleanupGroup = CreateThreadpoolCleanupGroup();
-
- SetThreadpoolCallbackPool(&m_CallbackEnvironment, m_ThreadPool);
-
- SetThreadpoolCallbackCleanupGroup(&m_CallbackEnvironment, m_CleanupGroup, NULL);
-}
-
-WinIoThreadPool::~WinIoThreadPool()
-{
- CloseThreadpool(m_ThreadPool);
-}
-
-void
-WinIoThreadPool::CreateIocp(HANDLE IoHandle, PTP_WIN32_IO_CALLBACK Callback, void* Context, std::error_code& ErrorCode)
-{
- ZEN_ASSERT(!m_ThreadPoolIo);
-
- m_ThreadPoolIo = CreateThreadpoolIo(IoHandle, Callback, Context, &m_CallbackEnvironment);
-
- if (!m_ThreadPoolIo)
- {
- ErrorCode = MakeErrorCodeFromLastError();
- }
-}
-
-} // namespace zen
-
-#endif
diff --git a/zenhttp/iothreadpool.h b/zenhttp/iothreadpool.h
deleted file mode 100644
index 8333964c3..000000000
--- a/zenhttp/iothreadpool.h
+++ /dev/null
@@ -1,37 +0,0 @@
-// Copyright Epic Games, Inc. All Rights Reserved.
-
-#pragma once
-
-#include <zencore/zencore.h>
-
-#if ZEN_PLATFORM_WINDOWS
-# include <zencore/windows.h>
-
-# include <system_error>
-
-namespace zen {
-
-//////////////////////////////////////////////////////////////////////////
-//
-// Thread pool. Implemented in terms of Windows thread pool right now, will
-// need a cross-platform implementation eventually
-//
-
-class WinIoThreadPool
-{
-public:
- WinIoThreadPool(int InThreadCount);
- ~WinIoThreadPool();
-
- void CreateIocp(HANDLE IoHandle, PTP_WIN32_IO_CALLBACK Callback, void* Context, std::error_code& ErrorCode);
- inline PTP_IO Iocp() const { return m_ThreadPoolIo; }
-
-private:
- PTP_POOL m_ThreadPool = nullptr;
- PTP_CLEANUP_GROUP m_CleanupGroup = nullptr;
- PTP_IO m_ThreadPoolIo = nullptr;
- TP_CALLBACK_ENVIRON m_CallbackEnvironment;
-};
-
-} // namespace zen
-#endif
diff --git a/zenhttp/websocketasio.cpp b/zenhttp/websocketasio.cpp
deleted file mode 100644
index bbe7e1ad8..000000000
--- a/zenhttp/websocketasio.cpp
+++ /dev/null
@@ -1,1613 +0,0 @@
-// Copyright Epic Games, Inc. All Rights Reserved.
-
-#include <zenhttp/websocket.h>
-
-#include <zencore/base64.h>
-#include <zencore/compactbinarybuilder.h>
-#include <zencore/compactbinaryvalidation.h>
-#include <zencore/intmath.h>
-#include <zencore/iobuffer.h>
-#include <zencore/logging.h>
-#include <zencore/memory.h>
-#include <zencore/sha1.h>
-#include <zencore/stream.h>
-#include <zencore/string.h>
-#include <zencore/trace.h>
-
-#include <chrono>
-#include <optional>
-#include <shared_mutex>
-#include <span>
-#include <system_error>
-#include <thread>
-
-ZEN_THIRD_PARTY_INCLUDES_START
-#include <fmt/format.h>
-#include <http_parser.h>
-#include <asio.hpp>
-ZEN_THIRD_PARTY_INCLUDES_END
-
-#if ZEN_PLATFORM_WINDOWS
-# include <mstcpip.h>
-#endif
-
-namespace zen::websocket {
-
-using namespace std::literals;
-
-ZEN_DEFINE_LOG_CATEGORY_STATIC(LogWebSocket, "websocket"sv);
-
-ZEN_DEFINE_LOG_CATEGORY_STATIC(LogWsClient, "ws-client"sv);
-
-using Clock = std::chrono::steady_clock;
-using TimePoint = Clock::time_point;
-
-///////////////////////////////////////////////////////////////////////////////
-namespace http_header {
- static constexpr std::string_view SecWebSocketKey = "Sec-WebSocket-Key"sv;
- static constexpr std::string_view SecWebSocketOrigin = "Sec-WebSocket-Origin"sv;
- static constexpr std::string_view SecWebSocketProtocol = "Sec-WebSocket-Protocol"sv;
- static constexpr std::string_view SecWebSocketVersion = "Sec-WebSocket-Version"sv;
- static constexpr std::string_view SecWebSocketAccept = "Sec-WebSocket-Accept"sv;
- static constexpr std::string_view Upgrade = "Upgrade"sv;
-} // namespace http_header
-
-///////////////////////////////////////////////////////////////////////////////
-enum class ParseMessageStatus : uint32_t
-{
- kError,
- kContinue,
- kDone,
-};
-
-struct ParseMessageResult
-{
- ParseMessageStatus Status{};
- size_t ByteCount{};
- std::optional<std::string> Reason;
-};
-
-class MessageParser
-{
-public:
- virtual ~MessageParser() = default;
-
- ParseMessageResult ParseMessage(MemoryView Msg);
- void Reset();
-
-protected:
- MessageParser() = default;
-
- virtual ParseMessageResult OnParseMessage(MemoryView Msg) = 0;
- virtual void OnReset() = 0;
-
- BinaryWriter m_Stream;
-};
-
-ParseMessageResult
-MessageParser::ParseMessage(MemoryView Msg)
-{
- return OnParseMessage(Msg);
-}
-
-void
-MessageParser::Reset()
-{
- OnReset();
-
- m_Stream.Reset();
-}
-
-///////////////////////////////////////////////////////////////////////////////
-enum class HttpMessageParserType
-{
- kRequest,
- kResponse,
- kBoth
-};
-
-class HttpMessageParser final : public MessageParser
-{
-public:
- using HttpHeaders = std::unordered_map<std::string_view, std::string_view>;
-
- HttpMessageParser(HttpMessageParserType Type) : MessageParser(), m_Type(Type) { Initialize(); }
-
- virtual ~HttpMessageParser() = default;
-
- int32_t StatusCode() const { return m_Parser.status_code; }
- bool IsUpgrade() const { return m_Parser.upgrade != 0; }
- HttpHeaders& Headers() { return m_Headers; }
- MemoryView Body() const { return MemoryView(m_Stream.Data() + m_BodyEntry.Offset, m_BodyEntry.Size); }
-
- std::string_view StatusText() const
- {
- return std::string_view(reinterpret_cast<const char*>(m_Stream.Data() + m_StatusEntry.Offset), m_StatusEntry.Size);
- }
-
- bool ValidateWebSocketHandshake(std::string& OutAcceptHash, std::string& OutReason);
-
-private:
- void Initialize();
- virtual ParseMessageResult OnParseMessage(MemoryView Msg) override;
- virtual void OnReset() override;
- int OnMessageBegin();
- int OnUrl(MemoryView Url);
- int OnStatus(MemoryView Status);
- int OnHeaderField(MemoryView HeaderField);
- int OnHeaderValue(MemoryView HeaderValue);
- int OnHeadersComplete();
- int OnBody(MemoryView Body);
- int OnMessageComplete();
-
- struct StreamEntry
- {
- uint64_t Offset{};
- uint64_t Size{};
- };
-
- struct HeaderStreamEntry
- {
- StreamEntry Field{};
- StreamEntry Value{};
- };
-
- HttpMessageParserType m_Type;
- http_parser m_Parser;
- StreamEntry m_UrlEntry;
- StreamEntry m_StatusEntry;
- StreamEntry m_BodyEntry;
- HeaderStreamEntry m_CurrentHeader;
- std::vector<HeaderStreamEntry> m_HeaderEntries;
- HttpHeaders m_Headers;
- bool m_IsMsgComplete{false};
-
- static http_parser_settings ParserSettings;
-};
-
-http_parser_settings HttpMessageParser::ParserSettings = {
- .on_message_begin = [](http_parser* P) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnMessageBegin(); },
-
- .on_url = [](http_parser* P,
- const char* Data,
- size_t Size) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnUrl(MemoryView(Data, Size)); },
-
- .on_status = [](http_parser* P,
- const char* Data,
- size_t Size) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnStatus(MemoryView(Data, Size)); },
-
- .on_header_field = [](http_parser* P,
- const char* Data,
- size_t Size) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnHeaderField(MemoryView(Data, Size)); },
-
- .on_header_value = [](http_parser* P,
- const char* Data,
- size_t Size) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnHeaderValue(MemoryView(Data, Size)); },
-
- .on_headers_complete = [](http_parser* P) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnHeadersComplete(); },
-
- .on_body = [](http_parser* P,
- const char* Data,
- size_t Size) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnBody(MemoryView(Data, Size)); },
-
- .on_message_complete = [](http_parser* P) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnMessageComplete(); }};
-
-void
-HttpMessageParser::Initialize()
-{
- http_parser_init(&m_Parser,
- m_Type == HttpMessageParserType::kRequest ? HTTP_REQUEST
- : m_Type == HttpMessageParserType::kResponse ? HTTP_RESPONSE
- : HTTP_BOTH);
- m_Parser.data = this;
-
- m_UrlEntry = {};
- m_StatusEntry = {};
- m_CurrentHeader = {};
- m_BodyEntry = {};
-
- m_IsMsgComplete = false;
-
- m_HeaderEntries.clear();
-}
-
-ParseMessageResult
-HttpMessageParser::OnParseMessage(MemoryView Msg)
-{
- const size_t ByteCount = http_parser_execute(&m_Parser, &ParserSettings, reinterpret_cast<const char*>(Msg.GetData()), Msg.GetSize());
-
- auto Status = m_IsMsgComplete ? ParseMessageStatus::kDone : ParseMessageStatus::kContinue;
-
- if (m_Parser.http_errno != 0)
- {
- Status = ParseMessageStatus::kError;
- }
-
- return {.Status = Status, .ByteCount = uint64_t(ByteCount)};
-}
-
-void
-HttpMessageParser::OnReset()
-{
- Initialize();
-}
-
-int
-HttpMessageParser::OnMessageBegin()
-{
- ZEN_ASSERT(m_IsMsgComplete == false);
- ZEN_ASSERT(m_HeaderEntries.empty());
- ZEN_ASSERT(m_Headers.empty());
-
- return 0;
-}
-
-int
-HttpMessageParser::OnStatus(MemoryView Status)
-{
- m_StatusEntry = {m_Stream.CurrentOffset(), Status.GetSize()};
-
- m_Stream.Write(Status);
-
- return 0;
-}
-
-int
-HttpMessageParser::OnUrl(MemoryView Url)
-{
- m_UrlEntry = {m_Stream.CurrentOffset(), Url.GetSize()};
-
- m_Stream.Write(Url);
-
- return 0;
-}
-
-int
-HttpMessageParser::OnHeaderField(MemoryView HeaderField)
-{
- if (m_CurrentHeader.Value.Size > 0)
- {
- m_HeaderEntries.push_back(m_CurrentHeader);
- m_CurrentHeader = {};
- }
-
- if (m_CurrentHeader.Field.Size == 0)
- {
- m_CurrentHeader.Field.Offset = m_Stream.CurrentOffset();
- }
-
- m_CurrentHeader.Field.Size += HeaderField.GetSize();
-
- m_Stream.Write(HeaderField);
-
- return 0;
-}
-
-int
-HttpMessageParser::OnHeaderValue(MemoryView HeaderValue)
-{
- if (m_CurrentHeader.Value.Size == 0)
- {
- m_CurrentHeader.Value.Offset = m_Stream.CurrentOffset();
- }
-
- m_CurrentHeader.Value.Size += HeaderValue.GetSize();
-
- m_Stream.Write(HeaderValue);
-
- return 0;
-}
-
-int
-HttpMessageParser::OnHeadersComplete()
-{
- if (m_CurrentHeader.Value.Size > 0)
- {
- m_HeaderEntries.push_back(m_CurrentHeader);
- m_CurrentHeader = {};
- }
-
- m_Headers.clear();
- m_Headers.reserve(m_HeaderEntries.size());
-
- const char* StreamData = reinterpret_cast<const char*>(m_Stream.Data());
-
- for (const auto& Entry : m_HeaderEntries)
- {
- auto Field = std::string_view(StreamData + Entry.Field.Offset, Entry.Field.Size);
- auto Value = std::string_view(StreamData + Entry.Value.Offset, Entry.Value.Size);
-
- m_Headers.try_emplace(std::move(Field), std::move(Value));
- }
-
- return 0;
-}
-
-int
-HttpMessageParser::OnBody(MemoryView Body)
-{
- m_BodyEntry = {m_Stream.CurrentOffset(), Body.GetSize()};
-
- m_Stream.Write(Body);
-
- return 0;
-}
-
-int
-HttpMessageParser::OnMessageComplete()
-{
- m_IsMsgComplete = true;
-
- return 0;
-}
-
-bool
-HttpMessageParser::ValidateWebSocketHandshake(std::string& OutAcceptHash, std::string& OutReason)
-{
- static constexpr std::string_view WebSocketGuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"sv;
-
- OutAcceptHash = std::string();
-
- if (m_Headers.contains(http_header::SecWebSocketKey) == false)
- {
- OutReason = "Missing header Sec-WebSocket-Key";
- return false;
- }
-
- if (m_Headers.contains(http_header::Upgrade) == false)
- {
- OutReason = "Missing header Upgrade";
- return false;
- }
-
- ExtendableStringBuilder<128> Sb;
- Sb << m_Headers[http_header::SecWebSocketKey] << WebSocketGuid;
-
- SHA1Stream HashStream;
- HashStream.Append(Sb.Data(), Sb.Size());
-
- SHA1 Hash = HashStream.GetHash();
-
- OutAcceptHash.resize(Base64::GetEncodedDataSize(sizeof(SHA1::Hash)));
- Base64::Encode(Hash.Hash, sizeof(SHA1::Hash), OutAcceptHash.data());
-
- return true;
-}
-
-///////////////////////////////////////////////////////////////////////////////
-class WebSocketMessageParser final : public MessageParser
-{
-public:
- WebSocketMessageParser() : MessageParser() {}
-
- WebSocketMessage ConsumeMessage();
-
-private:
- virtual ParseMessageResult OnParseMessage(MemoryView Msg) override;
- virtual void OnReset() override;
-
- WebSocketMessage m_Message;
-};
-
-ParseMessageResult
-WebSocketMessageParser::OnParseMessage(MemoryView Msg)
-{
- ZEN_TRACE_CPU("WebSocketMessageParser::OnParseMessage");
-
- const uint64_t PrevOffset = m_Stream.CurrentOffset();
-
- if (m_Stream.CurrentOffset() < WebSocketMessage::HeaderSize)
- {
- const uint64_t RemaingHeaderSize = WebSocketMessage::HeaderSize - m_Stream.CurrentOffset();
-
- m_Stream.Write(Msg.Left(RemaingHeaderSize));
- Msg += RemaingHeaderSize;
-
- if (m_Stream.CurrentOffset() < WebSocketMessage::HeaderSize)
- {
- return {.Status = ParseMessageStatus::kContinue, .ByteCount = m_Stream.CurrentOffset() - PrevOffset};
- }
-
- const bool IsValidHeader = m_Message.TryLoadHeader(m_Stream.GetView());
-
- if (IsValidHeader == false)
- {
- OnReset();
-
- return {.Status = ParseMessageStatus::kError,
- .ByteCount = m_Stream.CurrentOffset() - PrevOffset,
- .Reason = std::string("Invalid websocket message header")};
- }
-
- if (m_Message.MessageSize() == 0)
- {
- return {.Status = ParseMessageStatus::kDone, .ByteCount = m_Stream.CurrentOffset() - PrevOffset};
- }
- }
-
- ZEN_ASSERT(m_Stream.CurrentOffset() >= WebSocketMessage::HeaderSize);
-
- if (Msg.IsEmpty() == false)
- {
- const uint64_t RemaingMessageSize = (WebSocketMessage::HeaderSize + m_Message.MessageSize()) - m_Stream.CurrentOffset();
- m_Stream.Write(Msg.Left(RemaingMessageSize));
- }
-
- auto Status = ParseMessageStatus::kContinue;
-
- if (m_Stream.CurrentOffset() == WebSocketMessage::HeaderSize + m_Message.MessageSize())
- {
- Status = ParseMessageStatus::kDone;
-
- BinaryReader Reader(m_Stream.GetView().RightChop(WebSocketMessage::HeaderSize));
-
- CbPackage Pkg;
- if (Pkg.TryLoad(Reader) == false)
- {
- return {.Status = ParseMessageStatus::kError,
- .ByteCount = m_Stream.CurrentOffset() - PrevOffset,
- .Reason = std::string("Invalid websocket message")};
- }
-
- m_Message.SetBody(std::move(Pkg));
- }
-
- return {.Status = Status, .ByteCount = m_Stream.CurrentOffset() - PrevOffset};
-}
-
-void
-WebSocketMessageParser::OnReset()
-{
- m_Message = WebSocketMessage();
-}
-
-WebSocketMessage
-WebSocketMessageParser::ConsumeMessage()
-{
- WebSocketMessage Msg = std::move(m_Message);
- m_Message = WebSocketMessage();
-
- return Msg;
-}
-
-///////////////////////////////////////////////////////////////////////////////
-class WsConnection : public std::enable_shared_from_this<WsConnection>
-{
-public:
- WsConnection(WebSocketId Id, std::unique_ptr<asio::ip::tcp::socket> Socket)
- : m_Id(Id)
- , m_Socket(std::move(Socket))
- , m_StartTime(Clock::now())
- , m_State()
- {
- }
-
- ~WsConnection() = default;
-
- std::shared_ptr<WsConnection> AsShared() { return shared_from_this(); }
-
- WebSocketId Id() const { return m_Id; }
- asio::ip::tcp::socket& Socket() { return *m_Socket; }
- TimePoint StartTime() const { return m_StartTime; }
- WebSocketState State() const { return static_cast<WebSocketState>(m_State.load(std::memory_order_relaxed)); }
- std::string RemoteAddr() const { return m_Socket->remote_endpoint().address().to_string(); }
- asio::streambuf& ReadBuffer() { return m_ReadBuffer; }
- WebSocketState SetState(WebSocketState NewState) { return static_cast<WebSocketState>(m_State.exchange(uint32_t(NewState))); }
- WebSocketState Close();
- MessageParser* Parser() { return m_MsgParser.get(); }
- void SetParser(std::unique_ptr<MessageParser>&& Parser) { m_MsgParser = std::move(Parser); }
- std::mutex& WriteMutex() { return m_WriteMutex; }
-
-private:
- WebSocketId m_Id;
- std::unique_ptr<asio::ip::tcp::socket> m_Socket;
- TimePoint m_StartTime;
- std::atomic_uint32_t m_State;
- std::unique_ptr<MessageParser> m_MsgParser;
- asio::streambuf m_ReadBuffer;
- std::mutex m_WriteMutex;
-};
-
-WebSocketState
-WsConnection::Close()
-{
- const auto PrevState = SetState(WebSocketState::kDisconnected);
-
- if (PrevState != WebSocketState::kDisconnected && m_Socket->is_open())
- {
- m_Socket->close();
- }
-
- return PrevState;
-}
-
-///////////////////////////////////////////////////////////////////////////////
-class WsThreadPool
-{
-public:
- WsThreadPool(asio::io_service& IoSvc) : m_IoSvc(IoSvc) {}
- void Start(uint32_t ThreadCount);
- void Stop();
-
-private:
- asio::io_service& m_IoSvc;
- std::vector<std::thread> m_Threads;
- std::atomic_bool m_Running{false};
-};
-
-void
-WsThreadPool::Start(uint32_t ThreadCount)
-{
- ZEN_ASSERT(m_Threads.empty());
-
- ZEN_LOG_DEBUG(LogWebSocket, "starting '{}' websocket I/O thread(s)", ThreadCount);
-
- m_Running = true;
-
- for (uint32_t Idx = 0; Idx < ThreadCount; Idx++)
- {
- m_Threads.emplace_back([this, ThreadId = Idx + 1] {
- for (;;)
- {
- if (m_Running == false)
- {
- break;
- }
-
- try
- {
- m_IoSvc.run();
- }
- catch (std::exception& Err)
- {
- ZEN_LOG_ERROR(LogWebSocket, "process websocket I/O FAILED, reason '{}'", Err.what());
- }
- }
-
- ZEN_LOG_TRACE(LogWebSocket, "websocket I/O thread '{}' exiting", ThreadId);
- });
- }
-}
-
-void
-WsThreadPool::Stop()
-{
- if (m_Running)
- {
- m_Running = false;
-
- for (std::thread& Thread : m_Threads)
- {
- if (Thread.joinable())
- {
- Thread.join();
- }
- }
-
- m_Threads.clear();
- }
-}
-
-///////////////////////////////////////////////////////////////////////////////
-class WsServer final : public WebSocketServer
-{
-public:
- WsServer(const WebSocketServerOptions& Options) : m_Options(Options) {}
- virtual ~WsServer() { Shutdown(); }
-
- virtual bool Run() override;
- virtual void Shutdown() override;
-
- virtual void RegisterService(WebSocketService& Service) override;
- virtual void RegisterNotificationHandler(std::string_view Key, WebSocketService& Service) override;
- virtual void RegisterRequestHandler(std::string_view Key, WebSocketService& Service) override;
-
- virtual void SendNotification(WebSocketMessage&& Notification) override;
- virtual void SendResponse(WebSocketMessage&& Response) override;
-
-private:
- friend class WsConnection;
-
- void AcceptConnection();
- void CloseConnection(std::shared_ptr<WsConnection> Connection, const std::error_code& Ec);
-
- void ReadMessage(std::shared_ptr<WsConnection> Connection);
- void RouteMessage(WebSocketMessage&& Msg);
- void SendMessage(WebSocketMessage&& Msg);
-
- struct IdHasher
- {
- size_t operator()(WebSocketId Id) const { return size_t(Id.Value()); }
- };
-
- using ConnectionMap = std::unordered_map<WebSocketId, std::shared_ptr<WsConnection>, IdHasher>;
- using RequestHandlerMap = std::unordered_map<std::string_view, WebSocketService*>;
- using NotificationHandlerMap = std::unordered_map<std::string_view, std::vector<WebSocketService*>>;
-
- WebSocketServerOptions m_Options;
- asio::io_service m_IoSvc;
- std::unique_ptr<asio::ip::tcp::acceptor> m_Acceptor;
- std::unique_ptr<WsThreadPool> m_ThreadPool;
- ConnectionMap m_Connections;
- std::shared_mutex m_ConnMutex;
- std::vector<WebSocketService*> m_Services;
- RequestHandlerMap m_RequestHandlers;
- NotificationHandlerMap m_NotificationHandlers;
- std::atomic_bool m_Running{};
-};
-
-void
-WsServer::RegisterService(WebSocketService& Service)
-{
- m_Services.push_back(&Service);
-
- Service.Configure(*this);
-}
-
-bool
-WsServer::Run()
-{
- static constexpr size_t ReceiveBufferSize = 256 << 10;
- static constexpr size_t SendBufferSize = 256 << 10;
-
- m_Acceptor = std::make_unique<asio::ip::tcp::acceptor>(m_IoSvc, asio::ip::tcp::v6());
-
- m_Acceptor->set_option(asio::ip::v6_only(false));
- m_Acceptor->set_option(asio::socket_base::reuse_address(true));
- m_Acceptor->set_option(asio::ip::tcp::no_delay(true));
- m_Acceptor->set_option(asio::socket_base::receive_buffer_size(ReceiveBufferSize));
- m_Acceptor->set_option(asio::socket_base::send_buffer_size(SendBufferSize));
-
-#if ZEN_PLATFORM_WINDOWS
- // On Windows, loopback connections can take advantage of a faster code path optionally with this flag.
- // This must be used by both the client and server side, and is only effective in the absence of
- // Windows Filtering Platform (WFP) callouts which can be installed by security software.
- // https://docs.microsoft.com/en-us/windows/win32/winsock/sio-loopback-fast-path
- SOCKET NativeSocket = m_Acceptor->native_handle();
- int LoopbackOptionValue = 1;
- DWORD OptionNumberOfBytesReturned = 0;
- WSAIoctl(NativeSocket,
- SIO_LOOPBACK_FAST_PATH,
- &LoopbackOptionValue,
- sizeof(LoopbackOptionValue),
- NULL,
- 0,
- &OptionNumberOfBytesReturned,
- 0,
- 0);
-#endif
-
- asio::error_code Ec;
- m_Acceptor->bind(asio::ip::tcp::endpoint(asio::ip::address_v6::any(), m_Options.Port), Ec);
-
- if (Ec)
- {
- ZEN_LOG_ERROR(LogWebSocket, "failed to bind websocket endpoint, error code '{}'", Ec.value());
-
- return false;
- }
-
- m_Acceptor->listen();
- m_Running = true;
-
- ZEN_LOG_INFO(LogWebSocket, "web socket server running on port '{}'", m_Options.Port);
-
- AcceptConnection();
-
- m_ThreadPool = std::make_unique<WsThreadPool>(m_IoSvc);
- m_ThreadPool->Start(m_Options.ThreadCount);
-
- return true;
-}
-
-void
-WsServer::Shutdown()
-{
- if (m_Running)
- {
- ZEN_LOG_INFO(LogWebSocket, "websocket server shutting down");
-
- m_Running = false;
-
- m_Acceptor->close();
- m_Acceptor.reset();
- m_IoSvc.stop();
-
- m_ThreadPool->Stop();
- }
-}
-
-void
-WsServer::RegisterNotificationHandler(std::string_view Key, WebSocketService& Service)
-{
- auto Result = m_NotificationHandlers.try_emplace(Key, std::vector<WebSocketService*>());
- Result.first->second.push_back(&Service);
-}
-
-void
-WsServer::RegisterRequestHandler(std::string_view Key, WebSocketService& Service)
-{
- m_RequestHandlers[Key] = &Service;
-}
-
-void
-WsServer::SendNotification(WebSocketMessage&& Notification)
-{
- ZEN_ASSERT(Notification.MessageType() == WebSocketMessageType::kNotification);
-
- SendMessage(std::move(Notification));
-}
-void
-WsServer::SendResponse(WebSocketMessage&& Response)
-{
- ZEN_ASSERT(Response.MessageType() == WebSocketMessageType::kResponse ||
- Response.MessageType() == WebSocketMessageType::kStreamResponse ||
- Response.MessageType() == WebSocketMessageType::kStreamCompleteResponse);
-
- ZEN_ASSERT(Response.CorrelationId() != 0);
-
- SendMessage(std::move(Response));
-}
-
-void
-WsServer::AcceptConnection()
-{
- auto Socket = std::make_unique<asio::ip::tcp::socket>(m_IoSvc);
- asio::ip::tcp::socket& SocketRef = *Socket.get();
-
- m_Acceptor->async_accept(SocketRef, [this, ConnectedSocket = std::move(Socket)](const asio::error_code& Ec) mutable {
- if (m_Running)
- {
- if (Ec)
- {
- ZEN_LOG_WARN(LogWebSocket, "accept connection FAILED, reason '{}'", Ec.message());
- }
- else
- {
- auto Connection = std::make_shared<WsConnection>(WebSocketId::New(), std::move(ConnectedSocket));
-
- ZEN_LOG_DEBUG(LogWebSocket, "accept connection '#{} {}' OK", Connection->Id().Value(), Connection->RemoteAddr());
-
- {
- std::unique_lock _(m_ConnMutex);
- m_Connections[Connection->Id()] = Connection;
- }
-
- Connection->SetParser(std::make_unique<HttpMessageParser>(HttpMessageParserType::kRequest));
- Connection->SetState(WebSocketState::kHandshaking);
-
- ReadMessage(Connection);
- }
-
- AcceptConnection();
- }
- });
-}
-
-void
-WsServer::CloseConnection(std::shared_ptr<WsConnection> Connection, const std::error_code& Ec)
-{
- if (const auto State = Connection->Close(); State != WebSocketState::kDisconnected)
- {
- if (Ec)
- {
- ZEN_LOG_INFO(LogWebSocket, "connection '{}' closed, reason '{} ({})'", Connection->Id().Value(), Ec.message(), Ec.value());
- }
- else
- {
- ZEN_LOG_INFO(LogWebSocket, "connection '{}' closed", Connection->Id().Value());
- }
- }
-
- const WebSocketId Id = Connection->Id();
-
- {
- std::unique_lock _(m_ConnMutex);
- if (m_Connections.contains(Id))
- {
- m_Connections.erase(Id);
- }
- }
-}
-
-void
-WsServer::ReadMessage(std::shared_ptr<WsConnection> Connection)
-{
- Connection->ReadBuffer().prepare(64 << 10);
-
- asio::async_read(
- Connection->Socket(),
- Connection->ReadBuffer(),
- asio::transfer_at_least(1),
- [this, Connection](const asio::error_code& ReadEc, std::size_t) mutable {
- if (ReadEc)
- {
- return CloseConnection(Connection, ReadEc);
- }
-
- switch (Connection->State())
- {
- case WebSocketState::kHandshaking:
- {
- HttpMessageParser& Parser = *reinterpret_cast<HttpMessageParser*>(Connection->Parser());
- asio::const_buffer Buffer = Connection->ReadBuffer().data();
-
- ParseMessageResult Result = Parser.ParseMessage(MemoryView(Buffer.data(), Buffer.size()));
-
- Connection->ReadBuffer().consume(Result.ByteCount);
-
- if (Result.Status == ParseMessageStatus::kContinue)
- {
- return ReadMessage(Connection);
- }
-
- if (Result.Status == ParseMessageStatus::kError)
- {
- ZEN_LOG_WARN(LogWebSocket,
- "handshake with connection '#{} {}' FAILED, reason 'HTTP parse error'",
- Connection->Id().Value(),
- Connection->RemoteAddr());
-
- return CloseConnection(Connection, std::error_code());
- }
-
- if (Parser.IsUpgrade() == false)
- {
- ZEN_LOG_DEBUG(LogWebSocket,
- "handshake with connection '#{} {}' FAILED, reason 'invalid HTTP upgrade request'",
- Connection->Id().Value(),
- Connection->RemoteAddr());
-
- constexpr auto UpgradeRequiredResponse = "HTTP/1.1 426 Upgrade Required\n\r\n\r"sv;
-
- return async_write(Connection->Socket(),
- asio::buffer(UpgradeRequiredResponse),
- [this, Connection](const asio::error_code& WriteEc, std::size_t) {
- if (WriteEc)
- {
- return CloseConnection(Connection, WriteEc);
- }
-
- Connection->Parser()->Reset();
- Connection->SetState(WebSocketState::kHandshaking);
-
- ReadMessage(Connection);
- });
- }
-
- ZEN_ASSERT(Result.Status == ParseMessageStatus::kDone);
-
- std::string AcceptHash;
- std::string Reason;
- const bool ValidHandshake = Parser.ValidateWebSocketHandshake(AcceptHash, Reason);
-
- if (ValidHandshake == false)
- {
- ZEN_LOG_DEBUG(LogWebSocket,
- "handshake with connection '{}' FAILED, reason '{}'",
- Connection->Id().Value(),
- Reason);
-
- constexpr auto UpgradeRequiredResponse = "HTTP/1.1 400 Bad Request\n\r\n\r"sv;
-
- return async_write(Connection->Socket(),
- asio::buffer(UpgradeRequiredResponse),
- [this, &Connection](const asio::error_code& WriteEc, std::size_t) {
- if (WriteEc)
- {
- return CloseConnection(Connection, WriteEc);
- }
-
- Connection->Parser()->Reset();
- Connection->SetState(WebSocketState::kHandshaking);
-
- ReadMessage(Connection);
- });
- }
-
- ExtendableStringBuilder<128> Sb;
-
- Sb << "HTTP/1.1 101 Switching Protocols\r\n"sv;
- Sb << "Upgrade: websocket\r\n"sv;
- Sb << "Connection: Upgrade\r\n"sv;
-
- // TODO: Verify protocol
- if (Parser.Headers().contains(http_header::SecWebSocketProtocol))
- {
- Sb << http_header::SecWebSocketProtocol << ": " << Parser.Headers()[http_header::SecWebSocketProtocol]
- << "\r\n";
- }
-
- Sb << http_header::SecWebSocketAccept << ": " << AcceptHash << "\r\n";
- Sb << "\r\n"sv;
-
- ZEN_LOG_DEBUG(LogWebSocket,
- "accepting handshake from connection '#{} {}'",
- Connection->Id().Value(),
- Connection->RemoteAddr());
-
- std::string Response = Sb.ToString();
- Buffer = asio::buffer(Response);
-
- async_write(Connection->Socket(),
- Buffer,
- [this, Connection, _ = std::move(Response)](const asio::error_code& WriteEc, std::size_t ByteCount) {
- if (WriteEc)
- {
- ZEN_LOG_DEBUG(LogWebSocket,
- "handshake with connection '{}' FAILED, reason '{}'",
- Connection->Id().Value(),
- WriteEc.message());
-
- return CloseConnection(Connection, WriteEc);
- }
-
- ZEN_LOG_DEBUG(LogWebSocket,
- "handshake ({}B) with connection '#{} {}' OK",
- ByteCount,
- Connection->Id().Value(),
- Connection->RemoteAddr());
-
- Connection->SetParser(std::make_unique<WebSocketMessageParser>());
- Connection->SetState(WebSocketState::kConnected);
-
- ReadMessage(Connection);
- });
- }
- break;
-
- case WebSocketState::kConnected:
- {
- WebSocketMessageParser& Parser = *reinterpret_cast<WebSocketMessageParser*>(Connection->Parser());
-
- uint64_t RemainingBytes = Connection->ReadBuffer().size();
-
- while (RemainingBytes > 0)
- {
- MemoryView MessageData = MemoryView(Connection->ReadBuffer().data().data(), RemainingBytes);
- const ParseMessageResult Result = Parser.ParseMessage(MessageData);
-
- Connection->ReadBuffer().consume(Result.ByteCount);
- RemainingBytes = Connection->ReadBuffer().size();
-
- if (Result.Status == ParseMessageStatus::kError)
- {
- ZEN_LOG_WARN(LogWebSocket, "parse websocket message FAILED, reason '{}'", Result.Reason.value());
-
- return CloseConnection(Connection, std::error_code());
- }
-
- if (Result.Status == ParseMessageStatus::kContinue)
- {
- ZEN_ASSERT(RemainingBytes == 0);
- continue;
- }
-
- WebSocketMessage Message = Parser.ConsumeMessage();
- Parser.Reset();
-
- Message.SetSocketId(Connection->Id());
-
- RouteMessage(std::move(Message));
- }
-
- ReadMessage(Connection);
- }
- break;
-
- default:
- break;
- };
- });
-}
-
-void
-WsServer::RouteMessage(WebSocketMessage&& RoutedMessage)
-{
- switch (RoutedMessage.MessageType())
- {
- case WebSocketMessageType::kRequest:
- case WebSocketMessageType::kStreamRequest:
- {
- CbObjectView Request = RoutedMessage.Body().GetObject();
- std::string_view Method = Request["Method"].AsString();
- bool Handled = false;
- bool Error = false;
- std::exception Exception;
-
- if (auto It = m_RequestHandlers.find(Method); It != m_RequestHandlers.end())
- {
- WebSocketService* Service = It->second;
- ZEN_ASSERT(Service);
-
- try
- {
- Handled = Service->HandleRequest(std::move(RoutedMessage));
- }
- catch (std::exception& Err)
- {
- Exception = std::move(Err);
- Error = true;
- }
- }
-
- if (Error || Handled == false)
- {
- std::string ErrorText = Error ? Exception.what() : fmt::format("'{}' Not Found", Method);
-
- ZEN_LOG_WARN(LogWebSocket, "route request message FAILED, reason '{}'", ErrorText);
-
- CbObjectWriter Response;
- Response << "Error"sv << ErrorText;
-
- WebSocketMessage ResponseMsg;
- ResponseMsg.SetMessageType(WebSocketMessageType::kResponse);
- ResponseMsg.SetCorrelationId(RoutedMessage.CorrelationId());
- ResponseMsg.SetSocketId(RoutedMessage.SocketId());
- ResponseMsg.SetBody(Response.Save());
-
- SendResponse(std::move(ResponseMsg));
- }
- }
- break;
-
- case WebSocketMessageType::kNotification:
- {
- CbObjectView Notification = RoutedMessage.Body().GetObject();
- std::string_view Message = Notification["Message"].AsString();
-
- if (auto It = m_NotificationHandlers.find(Message); It != m_NotificationHandlers.end())
- {
- std::vector<WebSocketService*>& Handlers = It->second;
-
- for (WebSocketService* Handler : Handlers)
- {
- Handler->HandleNotification(RoutedMessage);
- }
- }
- else
- {
- ZEN_LOG_WARN(LogWebSocket, "route notification message FAILED, unknown notification '{}'", Message);
- }
- }
- break;
-
- default:
- break;
- };
-}
-
-void
-WsServer::SendMessage(WebSocketMessage&& Msg)
-{
- std::shared_ptr<WsConnection> Connection;
-
- {
- std::unique_lock _(m_ConnMutex);
-
- if (auto It = m_Connections.find(Msg.SocketId()); It != m_Connections.end())
- {
- Connection = It->second;
- }
- }
-
- if (Connection.get() == nullptr)
- {
- ZEN_LOG_WARN(LogWebSocket, "send message FAILED, reason 'unknown socket ID ({})'", Msg.SocketId().Value());
- return;
- }
-
- if (Connection.get() != nullptr)
- {
- BinaryWriter Writer;
- Msg.Save(Writer);
-
- ZEN_LOG_TRACE(LogWebSocket,
- "sending '{}' message, receiver '{}', size '{}', ID '{}', total size {}",
- ToString(Msg.MessageType()),
- Connection->Id().Value(),
- Msg.MessageSize(),
- Msg.CorrelationId(),
- NiceBytes(Writer.Size()));
-
- {
- ZEN_TRACE_CPU("WS::SendMessage");
- std::unique_lock _(Connection->WriteMutex());
- ZEN_TRACE_CPU("WS::WriteSocketData");
- asio::write(Connection->Socket(), asio::buffer(Writer.Data(), Writer.Size()), asio::transfer_exactly(Writer.Size()));
- }
- }
-}
-
-///////////////////////////////////////////////////////////////////////////////
-class WsClient final : public WebSocketClient, public std::enable_shared_from_this<WsClient>
-{
-public:
- WsClient(asio::io_context& IoCtx) : m_IoCtx(IoCtx), m_Id(WebSocketId::New()) {}
-
- virtual ~WsClient() { Disconnect(); }
-
- std::shared_ptr<WsClient> AsShared() { return shared_from_this(); }
-
- virtual std::future<bool> Connect(const WebSocketConnectInfo& Info) override;
- virtual void Disconnect() override;
- virtual bool IsConnected() const override { return false; }
- virtual WebSocketState State() const override { return static_cast<WebSocketState>(m_State.load()); }
-
- virtual std::future<WebSocketMessage> SendRequest(WebSocketMessage&& Request) override;
- virtual void OnNotification(NotificationCallback&& Cb) override;
- virtual void OnEvent(WebSocketEvent Evt, EventCallback&& Cb) override;
-
-private:
- WebSocketState SetState(WebSocketState NewState) { return static_cast<WebSocketState>(m_State.exchange(uint32_t(NewState))); }
- MessageParser* Parser() { return m_MsgParser.get(); }
- void SetParser(std::unique_ptr<MessageParser>&& Parser) { m_MsgParser = std::move(Parser); }
- asio::streambuf& ReadBuffer() { return m_ReadBuffer; }
- void TriggerEvent(WebSocketEvent Evt);
- void ReadMessage();
- void RouteMessage(WebSocketMessage&& RoutedMessage);
-
- using PendingRequestMap = std::unordered_map<uint32_t, std::promise<WebSocketMessage>>;
-
- asio::io_context& m_IoCtx;
- WebSocketId m_Id;
- std::unique_ptr<asio::ip::tcp::socket> m_Socket;
- std::unique_ptr<MessageParser> m_MsgParser;
- asio::streambuf m_ReadBuffer;
- EventCallback m_EventCallbacks[3];
- NotificationCallback m_NotificationCallback;
- PendingRequestMap m_PendingRequests;
- std::mutex m_RequestMutex;
- std::promise<bool> m_ConnectPromise;
- std::atomic_uint32_t m_State;
- std::string m_Host;
- int16_t m_Port{};
-};
-
-std::future<bool>
-WsClient::Connect(const WebSocketConnectInfo& Info)
-{
- if (State() == WebSocketState::kHandshaking || State() == WebSocketState::kConnected)
- {
- return m_ConnectPromise.get_future();
- }
-
- SetState(WebSocketState::kHandshaking);
-
- try
- {
- asio::ip::tcp::endpoint Endpoint(asio::ip::address::from_string(Info.Host), Info.Port);
- m_Socket = std::make_unique<asio::ip::tcp::socket>(m_IoCtx, Endpoint.protocol());
-
- m_Socket->connect(Endpoint);
-
- m_Host = m_Socket->remote_endpoint().address().to_string();
- m_Port = Info.Port;
-
- ZEN_LOG_INFO(LogWsClient, "connected to websocket server '{}:{}'", m_Host, m_Port);
- }
- catch (std::exception& Err)
- {
- ZEN_LOG_WARN(LogWsClient, "connect to websocket server '{}:{}' FAILED, reason '{}'", Info.Host, Info.Port, Err.what());
-
- SetState(WebSocketState::kError);
- m_Socket.reset();
-
- TriggerEvent(WebSocketEvent::kDisconnected);
-
- m_ConnectPromise.set_value(false);
-
- return m_ConnectPromise.get_future();
- }
-
- ExtendableStringBuilder<128> Sb;
- Sb << "GET " << Info.Endpoint << " HTTP/1.1\r\n"sv;
- Sb << "Host: " << Info.Host << "\r\n"sv;
- Sb << "Upgrade: websocket\r\n"sv;
- Sb << "Connection: upgrade\r\n"sv;
- Sb << "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n"sv;
-
- if (Info.Protocols.empty() == false)
- {
- Sb << "Sec-WebSocket-Protocol: "sv;
- for (size_t Idx = 0; const auto& Protocol : Info.Protocols)
- {
- if (Idx++)
- {
- Sb << ", ";
- }
- Sb << Protocol;
- }
- }
-
- Sb << "Sec-WebSocket-Version: "sv << Info.Version << "\r\n"sv;
- Sb << "\r\n";
-
- std::string HandshakeRequest = Sb.ToString();
- asio::const_buffer Buffer = asio::buffer(HandshakeRequest);
-
- ZEN_LOG_DEBUG(LogWsClient, "handshaking with '{}:{}'", m_Host, m_Port);
-
- m_MsgParser = std::make_unique<HttpMessageParser>(HttpMessageParserType::kResponse);
- m_MsgParser->Reset();
-
- async_write(*m_Socket, Buffer, [Self = AsShared(), _ = std::move(HandshakeRequest)](const asio::error_code& Ec, std::size_t) {
- if (Ec)
- {
- ZEN_LOG_ERROR(LogWsClient, "write data FAILED, reason '{}'", Ec.message());
-
- Self->Disconnect();
- }
- else
- {
- Self->ReadMessage();
- }
- });
-
- return m_ConnectPromise.get_future();
-}
-
-void
-WsClient::Disconnect()
-{
- if (auto PrevState = SetState(WebSocketState::kDisconnected); PrevState != WebSocketState::kDisconnected)
- {
- ZEN_LOG_INFO(LogWsClient, "closing connection to '{}:{}'", m_Host, m_Port);
-
- if (m_Socket && m_Socket->is_open())
- {
- m_Socket->close();
- m_Socket.reset();
- }
-
- TriggerEvent(WebSocketEvent::kDisconnected);
-
- {
- std::unique_lock _(m_RequestMutex);
-
- for (auto& Kv : m_PendingRequests)
- {
- Kv.second.set_value(WebSocketMessage());
- }
-
- m_PendingRequests.clear();
- }
- }
-}
-
-std::future<WebSocketMessage>
-WsClient::SendRequest(WebSocketMessage&& Request)
-{
- ZEN_ASSERT(Request.MessageType() == WebSocketMessageType::kRequest);
-
- BinaryWriter Writer;
- Request.Save(Writer);
-
- std::future<WebSocketMessage> FutureResponse;
-
- {
- std::unique_lock _(m_RequestMutex);
-
- auto Result = m_PendingRequests.try_emplace(Request.CorrelationId(), std::promise<WebSocketMessage>());
- ZEN_ASSERT(Result.second);
-
- auto It = Result.first;
- FutureResponse = It->second.get_future();
- }
-
- IoBuffer Buffer = IoBufferBuilder::MakeCloneFromMemory(Writer.Data(), Writer.Size());
-
- async_write(*m_Socket, asio::buffer(Buffer.Data(), Buffer.Size()), [Self = AsShared()](const std::error_code& Ec, size_t) {
- if (Ec)
- {
- ZEN_LOG_WARN(LogWsClient, "send request message FAILED, reason '{}'", Ec.message());
-
- Self->Disconnect();
- }
- });
-
- return FutureResponse;
-}
-
-void
-WsClient::OnNotification(NotificationCallback&& Cb)
-{
- m_NotificationCallback = std::move(Cb);
-}
-
-void
-WsClient::OnEvent(WebSocketEvent Evt, WebSocketClient::EventCallback&& Cb)
-{
- m_EventCallbacks[static_cast<uint32_t>(Evt)] = std::move(Cb);
-}
-
-void
-WsClient::TriggerEvent(WebSocketEvent Evt)
-{
- const uint32_t Index = static_cast<uint32_t>(Evt);
-
- if (m_EventCallbacks[Index])
- {
- m_EventCallbacks[Index]();
- }
-}
-
-void
-WsClient::ReadMessage()
-{
- m_ReadBuffer.prepare(64 << 10);
-
- async_read(*m_Socket,
- m_ReadBuffer,
- asio::transfer_at_least(1),
- [Self = AsShared()](const asio::error_code& Ec, std::size_t ByteCount) mutable {
- const WebSocketState State = Self->State();
-
- if (State == WebSocketState::kDisconnected)
- {
- return;
- }
-
- if (Ec)
- {
- ZEN_LOG_WARN(LogWsClient, "read message FAILED, reason '{}'", Ec.message());
-
- return Self->Disconnect();
- }
-
- switch (State)
- {
- case WebSocketState::kHandshaking:
- {
- HttpMessageParser& Parser = *reinterpret_cast<HttpMessageParser*>(Self->Parser());
-
- MemoryView MessageData = MemoryView(Self->ReadBuffer().data().data(), ByteCount);
-
- ParseMessageResult Result = Parser.ParseMessage(MessageData);
-
- Self->ReadBuffer().consume(size_t(Result.ByteCount));
-
- if (Result.Status == ParseMessageStatus::kError)
- {
- ZEN_LOG_WARN(LogWsClient, "handshake FAILED, status code '{}'", Parser.StatusCode());
-
- Self->m_ConnectPromise.set_value(false);
-
- return Self->Disconnect();
- }
-
- if (Result.Status == ParseMessageStatus::kContinue)
- {
- return Self->ReadMessage();
- }
-
- ZEN_ASSERT(Result.Status == ParseMessageStatus::kDone);
-
- if (Parser.StatusCode() != 101)
- {
- ZEN_LOG_WARN(LogWsClient,
- "handshake FAILED, status '{}', status code '{}'",
- Parser.StatusText(),
- Parser.StatusCode());
-
- Self->m_ConnectPromise.set_value(false);
-
- return Self->Disconnect();
- }
-
- ZEN_LOG_INFO(LogWsClient, "handshake OK, status '{}'", Parser.StatusText());
-
- Self->SetParser(std::make_unique<WebSocketMessageParser>());
- Self->SetState(WebSocketState::kConnected);
- Self->ReadMessage();
- Self->TriggerEvent(WebSocketEvent::kConnected);
-
- Self->m_ConnectPromise.set_value(true);
- }
- break;
-
- case WebSocketState::kConnected:
- {
- WebSocketMessageParser& Parser = *reinterpret_cast<WebSocketMessageParser*>(Self->Parser());
-
- uint64_t RemainingBytes = Self->ReadBuffer().size();
-
- while (RemainingBytes > 0)
- {
- MemoryView MessageData = MemoryView(Self->ReadBuffer().data().data(), RemainingBytes);
- const ParseMessageResult Result = Parser.ParseMessage(MessageData);
-
- Self->ReadBuffer().consume(Result.ByteCount);
- RemainingBytes = Self->ReadBuffer().size();
-
- if (Result.Status == ParseMessageStatus::kError)
- {
- ZEN_LOG_WARN(LogWsClient, "parse websocket message FAILED, reason '{}'", Result.Reason.value());
-
- Parser.Reset();
- continue;
- }
-
- if (Result.Status == ParseMessageStatus::kContinue)
- {
- ZEN_ASSERT(RemainingBytes == 0);
- continue;
- }
-
- WebSocketMessage Message = Parser.ConsumeMessage();
- Parser.Reset();
-
- Self->RouteMessage(std::move(Message));
- }
-
- Self->ReadMessage();
- }
- break;
-
- default:
- break;
- }
- });
-}
-
-void
-WsClient::RouteMessage(WebSocketMessage&& RoutedMessage)
-{
- switch (RoutedMessage.MessageType())
- {
- case WebSocketMessageType::kResponse:
- {
- std::unique_lock _(m_RequestMutex);
-
- if (auto It = m_PendingRequests.find(RoutedMessage.CorrelationId()); It != m_PendingRequests.end())
- {
- It->second.set_value(std::move(RoutedMessage));
- m_PendingRequests.erase(It);
- }
- else
- {
- ZEN_LOG_WARN(LogWsClient,
- "route request message FAILED, reason 'unknown correlation ID ({})'",
- RoutedMessage.CorrelationId());
- }
- }
- break;
-
- case WebSocketMessageType::kNotification:
- {
- std::unique_lock _(m_RequestMutex);
-
- if (m_NotificationCallback)
- {
- m_NotificationCallback(std::move(RoutedMessage));
- }
- }
- break;
-
- default:
- ZEN_LOG_WARN(LogWsClient, "route message FAILED, reason 'invalid message type ({})'", uint8_t(RoutedMessage.MessageType()));
- break;
- };
-}
-
-} // namespace zen::websocket
-
-namespace zen {
-
-std::atomic_uint32_t WebSocketId::NextId{1};
-
-bool
-WebSocketMessage::Header::IsValid() const
-{
- return Magic == ExpectedMagic && StatusCode > 0 && uint8_t(MessageType) > uint8_t(WebSocketMessageType::kInvalid) &&
- uint8_t(MessageType) < uint8_t(WebSocketMessageType::kCount);
-}
-
-std::atomic_uint32_t WebSocketMessage::NextCorrelationId{1};
-
-void
-WebSocketMessage::SetMessageType(WebSocketMessageType MessageType)
-{
- m_Header.MessageType = MessageType;
-}
-
-void
-WebSocketMessage::SetBody(CbPackage&& Body)
-{
- m_Body = std::move(Body);
-}
-void
-WebSocketMessage::SetBody(CbObject&& Body)
-{
- CbPackage Pkg;
- Pkg.SetObject(Body);
-
- SetBody(std::move(Pkg));
-}
-
-void
-WebSocketMessage::Save(BinaryWriter& Writer)
-{
- Writer.Write(&m_Header, HeaderSize);
-
- if (m_Body.has_value())
- {
- const CbObject& Obj = m_Body.value().GetObject();
- MemoryView View = Obj.GetBuffer().GetView();
-
- const CbValidateError ValidationResult = ValidateCompactBinary(View, CbValidateMode::All);
- ZEN_ASSERT(ValidationResult == CbValidateError::None);
-
- m_Body.value().Save(Writer);
- }
-
- if (m_Header.CorrelationId == 0 && MessageType() == WebSocketMessageType::kRequest)
- {
- m_Header.CorrelationId = NextCorrelationId.fetch_add(1);
- }
-
- m_Header.MessageSize = Writer.Size() - HeaderSize;
-
- Writer.GetMutableView().CopyFrom(MakeMemoryView(&m_Header, HeaderSize));
-}
-
-bool
-WebSocketMessage::TryLoadHeader(MemoryView Memory)
-{
- if (Memory.GetSize() < HeaderSize)
- {
- return false;
- }
-
- MutableMemoryView HeaderView(&m_Header, HeaderSize);
-
- HeaderView.CopyFrom(Memory);
-
- return m_Header.IsValid();
-}
-
-void
-WebSocketService::Configure(WebSocketServer& Server)
-{
- ZEN_ASSERT(m_SocketServer == nullptr);
-
- m_SocketServer = &Server;
-
- RegisterHandlers(Server);
-}
-
-void
-WebSocketService::SendStreamResponse(WebSocketId SocketId, uint32_t CorrelationId, CbPackage&& StreamResponse, bool IsStreamComplete)
-{
- WebSocketMessage Message;
-
- Message.SetMessageType(IsStreamComplete ? WebSocketMessageType::kStreamCompleteResponse : WebSocketMessageType::kStreamResponse);
- Message.SetCorrelationId(CorrelationId);
- Message.SetSocketId(SocketId);
- Message.SetBody(std::move(StreamResponse));
-
- SocketServer().SendResponse(std::move(Message));
-}
-
-void
-WebSocketService::SendStreamResponse(WebSocketId SocketId, uint32_t CorrelationId, CbObject&& StreamResponse, bool IsStreamComplete)
-{
- CbPackage Response;
- Response.SetObject(std::move(StreamResponse));
-
- SendStreamResponse(SocketId, CorrelationId, std::move(Response), IsStreamComplete);
-}
-
-std::unique_ptr<WebSocketServer>
-WebSocketServer::Create(const WebSocketServerOptions& Options)
-{
- return std::make_unique<websocket::WsServer>(Options);
-}
-
-std::shared_ptr<WebSocketClient>
-WebSocketClient::Create(asio::io_context& IoCtx)
-{
- return std::make_shared<websocket::WsClient>(IoCtx);
-}
-
-} // namespace zen
diff --git a/zenhttp/xmake.lua b/zenhttp/xmake.lua
deleted file mode 100644
index b0dbdbc79..000000000
--- a/zenhttp/xmake.lua
+++ /dev/null
@@ -1,14 +0,0 @@
--- Copyright Epic Games, Inc. All Rights Reserved.
-
-target('zenhttp')
- set_kind("static")
- add_headerfiles("**.h")
- add_files("**.cpp")
- add_files("httpsys.cpp", {unity_ignored=true})
- add_includedirs("include", {public=true})
- add_deps("zencore")
- add_packages(
- "vcpkg::gsl-lite",
- "vcpkg::http-parser"
- )
- add_options("httpsys")
diff --git a/zenhttp/zenhttp.cpp b/zenhttp/zenhttp.cpp
deleted file mode 100644
index 4bd6a5697..000000000
--- a/zenhttp/zenhttp.cpp
+++ /dev/null
@@ -1,22 +0,0 @@
-// Copyright Epic Games, Inc. All Rights Reserved.
-
-#include <zenhttp/zenhttp.h>
-
-#if ZEN_WITH_TESTS
-
-# include <zenhttp/httpclient.h>
-# include <zenhttp/httpserver.h>
-# include <zenhttp/httpshared.h>
-
-namespace zen {
-
-void
-zenhttp_forcelinktests()
-{
- http_forcelink();
- forcelink_httpshared();
-}
-
-} // namespace zen
-
-#endif