aboutsummaryrefslogtreecommitdiff
path: root/src/zenhttp/httpasio.cpp
diff options
context:
space:
mode:
authorStefan Boberg <[email protected]>2023-05-02 10:01:47 +0200
committerGitHub <[email protected]>2023-05-02 10:01:47 +0200
commit075d17f8ada47e990fe94606c3d21df409223465 (patch)
treee50549b766a2f3c354798a54ff73404217b4c9af /src/zenhttp/httpasio.cpp
parentfix: bundle shouldn't append content zip to zen (diff)
downloadzen-075d17f8ada47e990fe94606c3d21df409223465.tar.xz
zen-075d17f8ada47e990fe94606c3d21df409223465.zip
moved source directories into `/src` (#264)
* moved source directories into `/src` * updated bundle.lua for new `src` path * moved some docs, icon * removed old test trees
Diffstat (limited to 'src/zenhttp/httpasio.cpp')
-rw-r--r--src/zenhttp/httpasio.cpp1372
1 files changed, 1372 insertions, 0 deletions
diff --git a/src/zenhttp/httpasio.cpp b/src/zenhttp/httpasio.cpp
new file mode 100644
index 000000000..79b2c0a3d
--- /dev/null
+++ b/src/zenhttp/httpasio.cpp
@@ -0,0 +1,1372 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "httpasio.h"
+
+#include <zencore/logging.h>
+#include <zenhttp/httpserver.h>
+
+#include <deque>
+#include <memory>
+#include <string_view>
+#include <vector>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#if ZEN_PLATFORM_WINDOWS
+# include <conio.h>
+# include <mstcpip.h>
+#endif
+#include <http_parser.h>
+#include <asio.hpp>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+#define ASIO_VERBOSE_TRACE 0
+
+#if ASIO_VERBOSE_TRACE
+# define ZEN_TRACE_VERBOSE ZEN_TRACE
+#else
+# define ZEN_TRACE_VERBOSE(fmtstr, ...)
+#endif
+
+namespace zen::asio_http {
+
+using namespace std::literals;
+
+struct HttpAcceptor;
+struct HttpRequest;
+struct HttpResponse;
+struct HttpServerConnection;
+
+static constinit uint32_t HashContentLength = HashStringAsLowerDjb2("Content-Length"sv);
+static constinit uint32_t HashContentType = HashStringAsLowerDjb2("Content-Type"sv);
+static constinit uint32_t HashAccept = HashStringAsLowerDjb2("Accept"sv);
+static constinit uint32_t HashExpect = HashStringAsLowerDjb2("Expect"sv);
+static constinit uint32_t HashSession = HashStringAsLowerDjb2("UE-Session"sv);
+static constinit uint32_t HashRequest = HashStringAsLowerDjb2("UE-Request"sv);
+static constinit uint32_t HashRange = HashStringAsLowerDjb2("Range"sv);
+
+inline spdlog::logger&
+InitLogger()
+{
+ spdlog::logger& Logger = logging::Get("asio");
+ // Logger.set_level(spdlog::level::trace);
+ return Logger;
+}
+
+inline spdlog::logger&
+Log()
+{
+ static spdlog::logger& g_Logger = InitLogger();
+ return g_Logger;
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+struct HttpAsioServerImpl
+{
+public:
+ HttpAsioServerImpl();
+ ~HttpAsioServerImpl();
+
+ int Start(uint16_t Port, int ThreadCount);
+ void Stop();
+ void RegisterService(const char* UrlPath, HttpService& Service);
+ HttpService* RouteRequest(std::string_view Url);
+
+ asio::io_service m_IoService;
+ asio::io_service::work m_Work{m_IoService};
+ std::unique_ptr<asio_http::HttpAcceptor> m_Acceptor;
+ std::vector<std::thread> m_ThreadPool;
+
+ struct ServiceEntry
+ {
+ std::string ServiceUrlPath;
+ HttpService* Service;
+ };
+
+ RwLock m_Lock;
+ std::vector<ServiceEntry> m_UriHandlers;
+};
+
+/**
+ * This is the class which request handlers use to interact with the server instance
+ */
+
+class HttpAsioServerRequest : public HttpServerRequest
+{
+public:
+ HttpAsioServerRequest(asio_http::HttpRequest& Request, HttpService& Service, IoBuffer PayloadBuffer);
+ ~HttpAsioServerRequest();
+
+ virtual Oid ParseSessionId() const override;
+ virtual uint32_t ParseRequestId() const override;
+
+ virtual IoBuffer ReadPayload() override;
+ virtual void WriteResponse(HttpResponseCode ResponseCode) override;
+ virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span<IoBuffer> Blobs) override;
+ virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::u8string_view ResponseString) override;
+ virtual void WriteResponseAsync(std::function<void(HttpServerRequest&)>&& ContinuationHandler) override;
+ virtual bool TryGetRanges(HttpRanges& Ranges) override;
+
+ using HttpServerRequest::WriteResponse;
+
+ HttpAsioServerRequest(const HttpAsioServerRequest&) = delete;
+ HttpAsioServerRequest& operator=(const HttpAsioServerRequest&) = delete;
+
+ asio_http::HttpRequest& m_Request;
+ IoBuffer m_PayloadBuffer;
+ std::unique_ptr<HttpResponse> m_Response;
+};
+
+struct HttpRequest
+{
+ explicit HttpRequest(HttpServerConnection& Connection) : m_Connection(Connection) {}
+
+ void Initialize();
+ size_t ConsumeData(const char* InputData, size_t DataSize);
+ void ResetState();
+
+ HttpVerb RequestVerb() const { return m_RequestVerb; }
+ bool IsKeepAlive() const { return m_KeepAlive; }
+ std::string_view Url() const { return m_NormalizedUrl.empty() ? std::string_view(m_Url, m_UrlLength) : m_NormalizedUrl; }
+ std::string_view QueryString() const { return std::string_view(m_QueryString, m_QueryLength); }
+ IoBuffer Body() { return m_BodyBuffer; }
+
+ inline HttpContentType ContentType()
+ {
+ if (m_ContentTypeHeaderIndex < 0)
+ {
+ return HttpContentType::kUnknownContentType;
+ }
+
+ return ParseContentType(m_Headers[m_ContentTypeHeaderIndex].Value);
+ }
+
+ inline HttpContentType AcceptType()
+ {
+ if (m_AcceptHeaderIndex < 0)
+ {
+ return HttpContentType::kUnknownContentType;
+ }
+
+ return ParseContentType(m_Headers[m_AcceptHeaderIndex].Value);
+ }
+
+ Oid SessionId() const { return m_SessionId; }
+ int RequestId() const { return m_RequestId; }
+
+ std::string_view RangeHeader() const { return m_RangeHeaderIndex != -1 ? m_Headers[m_RangeHeaderIndex].Value : std::string_view(); }
+
+private:
+ struct HeaderEntry
+ {
+ HeaderEntry() = default;
+
+ HeaderEntry(std::string_view InName, std::string_view InValue) : Name(InName), Value(InValue) {}
+
+ std::string_view Name;
+ std::string_view Value;
+ };
+
+ HttpServerConnection& m_Connection;
+ char* m_HeaderCursor = m_HeaderBuffer;
+ char* m_Url = nullptr;
+ size_t m_UrlLength = 0;
+ char* m_QueryString = nullptr;
+ size_t m_QueryLength = 0;
+ char* m_CurrentHeaderName = nullptr; // Used while parsing headers
+ size_t m_CurrentHeaderNameLength = 0;
+ char* m_CurrentHeaderValue = nullptr; // Used while parsing headers
+ size_t m_CurrentHeaderValueLength = 0;
+ std::vector<HeaderEntry> m_Headers;
+ int8_t m_ContentLengthHeaderIndex;
+ int8_t m_AcceptHeaderIndex;
+ int8_t m_ContentTypeHeaderIndex;
+ int8_t m_RangeHeaderIndex;
+ HttpVerb m_RequestVerb;
+ bool m_KeepAlive = false;
+ bool m_Expect100Continue = false;
+ int m_RequestId = -1;
+ Oid m_SessionId{};
+ IoBuffer m_BodyBuffer;
+ uint64_t m_BodyPosition = 0;
+ http_parser m_Parser;
+ char m_HeaderBuffer[1024];
+ std::string m_NormalizedUrl;
+
+ void AppendCurrentHeader();
+
+ int OnMessageBegin();
+ int OnUrl(const char* Data, size_t Bytes);
+ int OnHeader(const char* Data, size_t Bytes);
+ int OnHeaderValue(const char* Data, size_t Bytes);
+ int OnHeadersComplete();
+ int OnBody(const char* Data, size_t Bytes);
+ int OnMessageComplete();
+
+ static HttpRequest* GetThis(http_parser* Parser) { return reinterpret_cast<HttpRequest*>(Parser->data); }
+ static http_parser_settings s_ParserSettings;
+};
+
+struct HttpResponse
+{
+public:
+ HttpResponse() = default;
+ explicit HttpResponse(HttpContentType ContentType) : m_ContentType(ContentType) {}
+
+ void InitializeForPayload(uint16_t ResponseCode, std::span<IoBuffer> BlobList)
+ {
+ m_ResponseCode = ResponseCode;
+ const uint32_t ChunkCount = gsl::narrow<uint32_t>(BlobList.size());
+
+ m_DataBuffers.reserve(ChunkCount);
+
+ for (IoBuffer& Buffer : BlobList)
+ {
+#if 1
+ m_DataBuffers.emplace_back(std::move(Buffer)).MakeOwned();
+#else
+ IoBuffer TempBuffer = std::move(Buffer);
+ TempBuffer.MakeOwned();
+ m_DataBuffers.emplace_back(IoBufferBuilder::ReadFromFileMaybe(TempBuffer));
+#endif
+ }
+
+ uint64_t LocalDataSize = 0;
+
+ m_AsioBuffers.push_back({}); // Placeholder for header
+
+ for (IoBuffer& Buffer : m_DataBuffers)
+ {
+ uint64_t BufferDataSize = Buffer.Size();
+
+ ZEN_ASSERT(BufferDataSize);
+
+ LocalDataSize += BufferDataSize;
+
+ IoBufferFileReference FileRef;
+ if (Buffer.GetFileReference(/* out */ FileRef))
+ {
+ // TODO: Use direct file transfer, via TransmitFile/sendfile
+ //
+ // this looks like it requires some custom asio plumbing however
+
+ m_AsioBuffers.push_back({Buffer.Data(), Buffer.Size()});
+ }
+ else
+ {
+ // Send from memory
+
+ m_AsioBuffers.push_back({Buffer.Data(), Buffer.Size()});
+ }
+ }
+ m_ContentLength = LocalDataSize;
+
+ auto Headers = GetHeaders();
+ m_AsioBuffers[0] = asio::const_buffer(Headers.data(), Headers.size());
+ }
+
+ uint16_t ResponseCode() const { return m_ResponseCode; }
+ uint64_t ContentLength() const { return m_ContentLength; }
+
+ const std::vector<asio::const_buffer>& AsioBuffers() const { return m_AsioBuffers; }
+
+ std::string_view GetHeaders()
+ {
+ m_Headers << "HTTP/1.1 " << ResponseCode() << " " << ReasonStringForHttpResultCode(ResponseCode()) << "\r\n"
+ << "Content-Type: " << MapContentTypeToString(m_ContentType) << "\r\n"
+ << "Content-Length: " << ContentLength() << "\r\n"sv;
+
+ if (!m_IsKeepAlive)
+ {
+ m_Headers << "Connection: close\r\n"sv;
+ }
+
+ m_Headers << "\r\n"sv;
+
+ return m_Headers;
+ }
+
+ void SuppressPayload() { m_AsioBuffers.resize(1); }
+
+private:
+ uint16_t m_ResponseCode = 0;
+ bool m_IsKeepAlive = true;
+ HttpContentType m_ContentType = HttpContentType::kBinary;
+ uint64_t m_ContentLength = 0;
+ std::vector<IoBuffer> m_DataBuffers;
+ std::vector<asio::const_buffer> m_AsioBuffers;
+ ExtendableStringBuilder<160> m_Headers;
+};
+
+//////////////////////////////////////////////////////////////////////////
+
+struct HttpServerConnection : std::enable_shared_from_this<HttpServerConnection>
+{
+ HttpServerConnection(HttpAsioServerImpl& Server, std::unique_ptr<asio::ip::tcp::socket>&& Socket);
+ ~HttpServerConnection();
+
+ void HandleNewRequest();
+ void TerminateConnection();
+ void HandleRequest();
+
+ std::shared_ptr<HttpServerConnection> AsSharedPtr() { return shared_from_this(); }
+
+private:
+ enum class RequestState
+ {
+ kInitialState,
+ kInitialRead,
+ kReadingMore,
+ kWriting,
+ kWritingFinal,
+ kDone,
+ kTerminated
+ };
+
+ RequestState m_RequestState = RequestState::kInitialState;
+ HttpRequest m_RequestData{*this};
+
+ void EnqueueRead();
+ void OnDataReceived(const asio::error_code& Ec, std::size_t ByteCount);
+ void OnResponseDataSent(const asio::error_code& Ec, std::size_t ByteCount, bool Pop = false);
+ void OnError();
+
+ HttpAsioServerImpl& m_Server;
+ asio::streambuf m_RequestBuffer;
+ std::unique_ptr<asio::ip::tcp::socket> m_Socket;
+ std::atomic<uint32_t> m_RequestCounter{0};
+ uint32_t m_ConnectionId = 0;
+ Ref<IHttpPackageHandler> m_PackageHandler;
+
+ RwLock m_ResponsesLock;
+ std::deque<std::unique_ptr<HttpResponse>> m_Responses;
+};
+
+std::atomic<uint32_t> g_ConnectionIdCounter{0};
+
+HttpServerConnection::HttpServerConnection(HttpAsioServerImpl& Server, std::unique_ptr<asio::ip::tcp::socket>&& Socket)
+: m_Server(Server)
+, m_Socket(std::move(Socket))
+, m_ConnectionId(g_ConnectionIdCounter.fetch_add(1))
+{
+ ZEN_TRACE_VERBOSE("new connection #{}", m_ConnectionId);
+}
+
+HttpServerConnection::~HttpServerConnection()
+{
+ ZEN_TRACE_VERBOSE("destroying connection #{}", m_ConnectionId);
+}
+
+void
+HttpServerConnection::HandleNewRequest()
+{
+ m_RequestData.Initialize();
+
+ EnqueueRead();
+}
+
+void
+HttpServerConnection::TerminateConnection()
+{
+ m_RequestState = RequestState::kTerminated;
+
+ std::error_code Ec;
+ m_Socket->close(Ec);
+}
+
+void
+HttpServerConnection::EnqueueRead()
+{
+ if (m_RequestState == RequestState::kInitialRead)
+ {
+ m_RequestState = RequestState::kReadingMore;
+ }
+ else
+ {
+ m_RequestState = RequestState::kInitialRead;
+ }
+
+ m_RequestBuffer.prepare(64 * 1024);
+
+ asio::async_read(*m_Socket.get(),
+ m_RequestBuffer,
+ asio::transfer_at_least(1),
+ [Conn = AsSharedPtr()](const asio::error_code& Ec, std::size_t ByteCount) { Conn->OnDataReceived(Ec, ByteCount); });
+}
+
+void
+HttpServerConnection::OnDataReceived(const asio::error_code& Ec, [[maybe_unused]] std::size_t ByteCount)
+{
+ if (Ec)
+ {
+ if (m_RequestState == RequestState::kDone || m_RequestState == RequestState::kInitialRead)
+ {
+ ZEN_TRACE_VERBOSE("on data received ERROR (EXPECTED), connection '{}' reason '{}'", m_ConnectionId, Ec.message());
+ return;
+ }
+ else
+ {
+ ZEN_WARN("on data received ERROR, connection '{}' reason '{}'", m_ConnectionId, Ec.message());
+ return OnError();
+ }
+ }
+
+ ZEN_TRACE_VERBOSE("on data received, connection '{}', request '{}', thread '{}', bytes '{}'",
+ m_ConnectionId,
+ m_RequestCounter.load(std::memory_order_relaxed),
+ zen::GetCurrentThreadId(),
+ NiceBytes(ByteCount));
+
+ while (m_RequestBuffer.size())
+ {
+ const asio::const_buffer& InputBuffer = m_RequestBuffer.data();
+
+ size_t Result = m_RequestData.ConsumeData((const char*)InputBuffer.data(), InputBuffer.size());
+ if (Result == ~0ull)
+ {
+ return OnError();
+ }
+
+ m_RequestBuffer.consume(Result);
+ }
+
+ switch (m_RequestState)
+ {
+ case RequestState::kDone:
+ case RequestState::kWritingFinal:
+ case RequestState::kTerminated:
+ break;
+
+ default:
+ EnqueueRead();
+ break;
+ }
+}
+
+void
+HttpServerConnection::OnResponseDataSent(const asio::error_code& Ec, [[maybe_unused]] std::size_t ByteCount, bool Pop)
+{
+ if (Ec)
+ {
+ ZEN_WARN("on data sent ERROR, connection '{}' reason '{}'", m_ConnectionId, Ec.message());
+ OnError();
+ }
+ else
+ {
+ ZEN_TRACE_VERBOSE("on data sent, connection '{}', request '{}', thread '{}', bytes '{}'",
+ m_ConnectionId,
+ m_RequestCounter.load(std::memory_order_relaxed),
+ zen::GetCurrentThreadId(),
+ NiceBytes(ByteCount));
+
+ if (!m_RequestData.IsKeepAlive())
+ {
+ m_RequestState = RequestState::kDone;
+
+ m_Socket->close();
+ }
+ else
+ {
+ if (Pop)
+ {
+ RwLock::ExclusiveLockScope _(m_ResponsesLock);
+ m_Responses.pop_front();
+ }
+
+ m_RequestCounter.fetch_add(1);
+ }
+ }
+}
+
+void
+HttpServerConnection::OnError()
+{
+ m_Socket->close();
+}
+
+void
+HttpServerConnection::HandleRequest()
+{
+ if (!m_RequestData.IsKeepAlive())
+ {
+ m_RequestState = RequestState::kWritingFinal;
+
+ std::error_code Ec;
+ m_Socket->shutdown(asio::socket_base::shutdown_receive, Ec);
+
+ if (Ec)
+ {
+ ZEN_WARN("socket shutdown ERROR, reason '{}'", Ec.message());
+ }
+ }
+ else
+ {
+ m_RequestState = RequestState::kWriting;
+ }
+
+ if (HttpService* Service = m_Server.RouteRequest(m_RequestData.Url()))
+ {
+ HttpAsioServerRequest Request(m_RequestData, *Service, m_RequestData.Body());
+
+ ZEN_TRACE_VERBOSE("handle request, connection '{}' request '{}'", m_ConnectionId, m_RequestCounter.load(std::memory_order_relaxed));
+
+ if (!HandlePackageOffers(*Service, Request, m_PackageHandler))
+ {
+ try
+ {
+ Service->HandleRequest(Request);
+ }
+ catch (std::exception& ex)
+ {
+ ZEN_ERROR("Caught exception while handling request: '{}'", ex.what());
+
+ Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, ex.what());
+ }
+ }
+
+ if (std::unique_ptr<HttpResponse> Response = std::move(Request.m_Response))
+ {
+ // Transmit the response
+
+ if (m_RequestData.RequestVerb() == HttpVerb::kHead)
+ {
+ Response->SuppressPayload();
+ }
+
+ auto ResponseBuffers = Response->AsioBuffers();
+
+ uint64_t ResponseLength = 0;
+
+ for (auto& Buffer : ResponseBuffers)
+ {
+ ResponseLength += Buffer.size();
+ }
+
+ {
+ RwLock::ExclusiveLockScope _(m_ResponsesLock);
+ m_Responses.push_back(std::move(Response));
+ }
+
+ // TODO: should cork/uncork for Linux?
+
+ asio::async_write(*m_Socket.get(),
+ ResponseBuffers,
+ asio::transfer_exactly(ResponseLength),
+ [Conn = AsSharedPtr()](const asio::error_code& Ec, std::size_t ByteCount) {
+ Conn->OnResponseDataSent(Ec, ByteCount, true);
+ });
+
+ return;
+ }
+ }
+
+ if (m_RequestData.RequestVerb() == HttpVerb::kHead)
+ {
+ std::string_view Response =
+ "HTTP/1.1 404 NOT FOUND\r\n"
+ "\r\n"sv;
+
+ if (!m_RequestData.IsKeepAlive())
+ {
+ Response =
+ "HTTP/1.1 404 NOT FOUND\r\n"
+ "Connection: close\r\n"
+ "\r\n"sv;
+ }
+
+ asio::async_write(
+ *m_Socket.get(),
+ asio::buffer(Response),
+ [Conn = AsSharedPtr()](const asio::error_code& Ec, std::size_t ByteCount) { Conn->OnResponseDataSent(Ec, ByteCount); });
+ }
+ else
+ {
+ std::string_view Response =
+ "HTTP/1.1 404 NOT FOUND\r\n"
+ "Content-Length: 23\r\n"
+ "Content-Type: text/plain\r\n"
+ "\r\n"
+ "No suitable route found"sv;
+
+ if (!m_RequestData.IsKeepAlive())
+ {
+ Response =
+ "HTTP/1.1 404 NOT FOUND\r\n"
+ "Content-Length: 23\r\n"
+ "Content-Type: text/plain\r\n"
+ "Connection: close\r\n"
+ "\r\n"
+ "No suitable route found"sv;
+ }
+
+ asio::async_write(
+ *m_Socket.get(),
+ asio::buffer(Response),
+ [Conn = AsSharedPtr()](const asio::error_code& Ec, std::size_t ByteCount) { Conn->OnResponseDataSent(Ec, ByteCount); });
+ }
+}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// HttpRequest
+//
+
+http_parser_settings HttpRequest::s_ParserSettings{
+ .on_message_begin = [](http_parser* p) { return GetThis(p)->OnMessageBegin(); },
+ .on_url = [](http_parser* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnUrl(Data, ByteCount); },
+ .on_status =
+ [](http_parser* p, const char* Data, size_t ByteCount) {
+ ZEN_UNUSED(p, Data, ByteCount);
+ return 0;
+ },
+ .on_header_field = [](http_parser* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnHeader(Data, ByteCount); },
+ .on_header_value = [](http_parser* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnHeaderValue(Data, ByteCount); },
+ .on_headers_complete = [](http_parser* p) { return GetThis(p)->OnHeadersComplete(); },
+ .on_body = [](http_parser* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnBody(Data, ByteCount); },
+ .on_message_complete = [](http_parser* p) { return GetThis(p)->OnMessageComplete(); },
+ .on_chunk_header{},
+ .on_chunk_complete{}};
+
+void
+HttpRequest::Initialize()
+{
+ http_parser_init(&m_Parser, HTTP_REQUEST);
+ m_Parser.data = this;
+
+ ResetState();
+}
+
+size_t
+HttpRequest::ConsumeData(const char* InputData, size_t DataSize)
+{
+ const size_t ConsumedBytes = http_parser_execute(&m_Parser, &s_ParserSettings, InputData, DataSize);
+
+ http_errno HttpErrno = HTTP_PARSER_ERRNO((&m_Parser));
+
+ if (HttpErrno && HttpErrno != HPE_INVALID_EOF_STATE)
+ {
+ ZEN_WARN("HTTP parser error {} ('{}'). Closing connection", http_errno_name(HttpErrno), http_errno_description(HttpErrno));
+ return ~0ull;
+ }
+
+ return ConsumedBytes;
+}
+
+int
+HttpRequest::OnUrl(const char* Data, size_t Bytes)
+{
+ if (!m_Url)
+ {
+ ZEN_ASSERT_SLOW(m_UrlLength == 0);
+ m_Url = m_HeaderCursor;
+ }
+
+ const size_t RemainingBufferSpace = sizeof m_HeaderBuffer + m_HeaderBuffer - m_HeaderCursor;
+
+ if (RemainingBufferSpace < Bytes)
+ {
+ ZEN_WARN("HTTP parser does not have enough space for incoming request, need {} more bytes", Bytes - RemainingBufferSpace);
+ return 1;
+ }
+
+ memcpy(m_HeaderCursor, Data, Bytes);
+ m_HeaderCursor += Bytes;
+ m_UrlLength += Bytes;
+
+ return 0;
+}
+
+int
+HttpRequest::OnHeader(const char* Data, size_t Bytes)
+{
+ if (m_CurrentHeaderValueLength)
+ {
+ AppendCurrentHeader();
+
+ m_CurrentHeaderNameLength = 0;
+ m_CurrentHeaderValueLength = 0;
+ m_CurrentHeaderName = m_HeaderCursor;
+ }
+ else if (m_CurrentHeaderName == nullptr)
+ {
+ m_CurrentHeaderName = m_HeaderCursor;
+ }
+
+ const size_t RemainingBufferSpace = sizeof m_HeaderBuffer + m_HeaderBuffer - m_HeaderCursor;
+ if (RemainingBufferSpace < Bytes)
+ {
+ ZEN_WARN("HTTP parser does not have enough space for incoming header name, need {} more bytes", Bytes - RemainingBufferSpace);
+ return 1;
+ }
+
+ memcpy(m_HeaderCursor, Data, Bytes);
+ m_HeaderCursor += Bytes;
+ m_CurrentHeaderNameLength += Bytes;
+
+ return 0;
+}
+
+void
+HttpRequest::AppendCurrentHeader()
+{
+ std::string_view HeaderName(m_CurrentHeaderName, m_CurrentHeaderNameLength);
+ std::string_view HeaderValue(m_CurrentHeaderValue, m_CurrentHeaderValueLength);
+
+ const uint32_t HeaderHash = HashStringAsLowerDjb2(HeaderName);
+
+ if (HeaderHash == HashContentLength)
+ {
+ m_ContentLengthHeaderIndex = (int8_t)m_Headers.size();
+ }
+ else if (HeaderHash == HashAccept)
+ {
+ m_AcceptHeaderIndex = (int8_t)m_Headers.size();
+ }
+ else if (HeaderHash == HashContentType)
+ {
+ m_ContentTypeHeaderIndex = (int8_t)m_Headers.size();
+ }
+ else if (HeaderHash == HashSession)
+ {
+ m_SessionId = Oid::FromHexString(HeaderValue);
+ }
+ else if (HeaderHash == HashRequest)
+ {
+ std::from_chars(HeaderValue.data(), HeaderValue.data() + HeaderValue.size(), m_RequestId);
+ }
+ else if (HeaderHash == HashExpect)
+ {
+ if (HeaderValue == "100-continue"sv)
+ {
+ // We don't currently do anything with this
+ m_Expect100Continue = true;
+ }
+ else
+ {
+ ZEN_INFO("Unexpected expect - Expect: {}", HeaderValue);
+ }
+ }
+ else if (HeaderHash == HashRange)
+ {
+ m_RangeHeaderIndex = (int8_t)m_Headers.size();
+ }
+
+ m_Headers.emplace_back(HeaderName, HeaderValue);
+}
+
+int
+HttpRequest::OnHeaderValue(const char* Data, size_t Bytes)
+{
+ if (m_CurrentHeaderValueLength == 0)
+ {
+ m_CurrentHeaderValue = m_HeaderCursor;
+ }
+
+ const size_t RemainingBufferSpace = sizeof m_HeaderBuffer + m_HeaderBuffer - m_HeaderCursor;
+ if (RemainingBufferSpace < Bytes)
+ {
+ ZEN_WARN("HTTP parser does not have enough space for incoming header value, need {} more bytes", Bytes - RemainingBufferSpace);
+ return 1;
+ }
+
+ memcpy(m_HeaderCursor, Data, Bytes);
+ m_HeaderCursor += Bytes;
+ m_CurrentHeaderValueLength += Bytes;
+
+ return 0;
+}
+
+static void
+NormalizeUrlPath(const char* Url, size_t UrlLength, std::string& NormalizedUrl)
+{
+ bool LastCharWasSeparator = false;
+ for (std::string_view::size_type UrlIndex = 0; UrlIndex < UrlLength; ++UrlIndex)
+ {
+ const char UrlChar = Url[UrlIndex];
+ const bool IsSeparator = (UrlChar == '/');
+
+ if (IsSeparator && LastCharWasSeparator)
+ {
+ if (NormalizedUrl.empty())
+ {
+ NormalizedUrl.reserve(UrlLength);
+ NormalizedUrl.append(Url, UrlIndex);
+ }
+
+ if (!LastCharWasSeparator)
+ {
+ NormalizedUrl.push_back('/');
+ }
+ }
+ else if (!NormalizedUrl.empty())
+ {
+ NormalizedUrl.push_back(UrlChar);
+ }
+
+ LastCharWasSeparator = IsSeparator;
+ }
+}
+
+int
+HttpRequest::OnHeadersComplete()
+{
+ if (m_CurrentHeaderValueLength)
+ {
+ AppendCurrentHeader();
+ }
+
+ if (m_ContentLengthHeaderIndex >= 0)
+ {
+ std::string_view& Value = m_Headers[m_ContentLengthHeaderIndex].Value;
+ uint64_t ContentLength = 0;
+ std::from_chars(Value.data(), Value.data() + Value.size(), ContentLength);
+
+ if (ContentLength)
+ {
+ m_BodyBuffer = IoBuffer(ContentLength);
+ }
+
+ m_BodyBuffer.SetContentType(ContentType());
+
+ m_BodyPosition = 0;
+ }
+
+ m_KeepAlive = !!http_should_keep_alive(&m_Parser);
+
+ switch (m_Parser.method)
+ {
+ case HTTP_GET:
+ m_RequestVerb = HttpVerb::kGet;
+ break;
+
+ case HTTP_POST:
+ m_RequestVerb = HttpVerb::kPost;
+ break;
+
+ case HTTP_PUT:
+ m_RequestVerb = HttpVerb::kPut;
+ break;
+
+ case HTTP_DELETE:
+ m_RequestVerb = HttpVerb::kDelete;
+ break;
+
+ case HTTP_HEAD:
+ m_RequestVerb = HttpVerb::kHead;
+ break;
+
+ case HTTP_COPY:
+ m_RequestVerb = HttpVerb::kCopy;
+ break;
+
+ case HTTP_OPTIONS:
+ m_RequestVerb = HttpVerb::kOptions;
+ break;
+
+ default:
+ ZEN_WARN("invalid HTTP method: '{}'", http_method_str((http_method)m_Parser.method));
+ break;
+ }
+
+ std::string_view Url(m_Url, m_UrlLength);
+
+ if (auto QuerySplit = Url.find_first_of('?'); QuerySplit != std::string_view::npos)
+ {
+ m_UrlLength = QuerySplit;
+ m_QueryString = m_Url + QuerySplit + 1;
+ m_QueryLength = Url.size() - QuerySplit - 1;
+ }
+
+ NormalizeUrlPath(m_Url, m_UrlLength, m_NormalizedUrl);
+
+ return 0;
+}
+
+int
+HttpRequest::OnBody(const char* Data, size_t Bytes)
+{
+ if (m_BodyPosition + Bytes > m_BodyBuffer.Size())
+ {
+ ZEN_WARN("HTTP parser incoming body is larger than content size, need {} more bytes",
+ (m_BodyPosition + Bytes) - m_BodyBuffer.Size());
+ return 1;
+ }
+ memcpy(reinterpret_cast<uint8_t*>(m_BodyBuffer.MutableData()) + m_BodyPosition, Data, Bytes);
+ m_BodyPosition += Bytes;
+
+ if (http_body_is_final(&m_Parser))
+ {
+ if (m_BodyPosition != m_BodyBuffer.Size())
+ {
+ ZEN_WARN("Body mismatch! {} != {}", m_BodyPosition, m_BodyBuffer.Size());
+ return 1;
+ }
+ }
+
+ return 0;
+}
+
+void
+HttpRequest::ResetState()
+{
+ m_HeaderCursor = m_HeaderBuffer;
+ m_CurrentHeaderName = nullptr;
+ m_CurrentHeaderNameLength = 0;
+ m_CurrentHeaderValue = nullptr;
+ m_CurrentHeaderValueLength = 0;
+ m_CurrentHeaderName = nullptr;
+ m_Url = nullptr;
+ m_UrlLength = 0;
+ m_QueryString = nullptr;
+ m_QueryLength = 0;
+ m_ContentLengthHeaderIndex = -1;
+ m_AcceptHeaderIndex = -1;
+ m_ContentTypeHeaderIndex = -1;
+ m_RangeHeaderIndex = -1;
+ m_Expect100Continue = false;
+ m_BodyBuffer = {};
+ m_BodyPosition = 0;
+ m_Headers.clear();
+ m_NormalizedUrl.clear();
+}
+
+int
+HttpRequest::OnMessageBegin()
+{
+ return 0;
+}
+
+int
+HttpRequest::OnMessageComplete()
+{
+ m_Connection.HandleRequest();
+
+ ResetState();
+
+ return 0;
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+struct HttpAcceptor
+{
+ HttpAcceptor(HttpAsioServerImpl& Server, asio::io_service& IoService, uint16_t BasePort)
+ : m_Server(Server)
+ , m_IoService(IoService)
+ , m_Acceptor(m_IoService, asio::ip::tcp::v6())
+ {
+ m_Acceptor.set_option(asio::ip::v6_only(false));
+#if ZEN_PLATFORM_WINDOWS
+ // Special option for Windows settings as !asio::socket_base::reuse_address is not the same as exclusive access on Windows platforms
+ typedef asio::detail::socket_option::boolean<ASIO_OS_DEF(SOL_SOCKET), SO_EXCLUSIVEADDRUSE> excluse_address;
+ m_Acceptor.set_option(excluse_address(true));
+#else // ZEN_PLATFORM_WINDOWS
+ m_Acceptor.set_option(asio::socket_base::reuse_address(false));
+#endif // ZEN_PLATFORM_WINDOWS
+
+ m_Acceptor.set_option(asio::ip::tcp::no_delay(true));
+ m_Acceptor.set_option(asio::socket_base::receive_buffer_size(128 * 1024));
+ m_Acceptor.set_option(asio::socket_base::send_buffer_size(256 * 1024));
+
+ uint16_t EffectivePort = BasePort;
+
+ asio::error_code BindErrorCode;
+ m_Acceptor.bind(asio::ip::tcp::endpoint(asio::ip::address_v6::any(), EffectivePort), BindErrorCode);
+ // Sharing violation implies the port is being used by another process
+ for (uint16_t PortOffset = 1; (BindErrorCode == asio::error::address_in_use) && (PortOffset < 10); ++PortOffset)
+ {
+ EffectivePort = BasePort + (PortOffset * 100);
+ m_Acceptor.bind(asio::ip::tcp::endpoint(asio::ip::address_v6::any(), EffectivePort), BindErrorCode);
+ }
+ if (BindErrorCode == asio::error::access_denied)
+ {
+ EffectivePort = 0;
+ m_Acceptor.bind(asio::ip::tcp::endpoint(asio::ip::address_v6::any(), EffectivePort), BindErrorCode);
+ }
+ if (BindErrorCode)
+ {
+ ZEN_ERROR("Unable open asio service, error '{}'", BindErrorCode.message());
+ }
+
+#if ZEN_PLATFORM_WINDOWS
+ // On Windows, loopback connections can take advantage of a faster code path optionally with this flag.
+ // This must be used by both the client and server side, and is only effective in the absence of
+ // Windows Filtering Platform (WFP) callouts which can be installed by security software.
+ // https://docs.microsoft.com/en-us/windows/win32/winsock/sio-loopback-fast-path
+ SOCKET NativeSocket = m_Acceptor.native_handle();
+ int LoopbackOptionValue = 1;
+ DWORD OptionNumberOfBytesReturned = 0;
+ WSAIoctl(NativeSocket,
+ SIO_LOOPBACK_FAST_PATH,
+ &LoopbackOptionValue,
+ sizeof(LoopbackOptionValue),
+ NULL,
+ 0,
+ &OptionNumberOfBytesReturned,
+ 0,
+ 0);
+#endif
+ m_Acceptor.listen();
+
+ ZEN_INFO("Started asio server at port '{}'", EffectivePort);
+ }
+
+ void Start()
+ {
+ m_Acceptor.listen();
+ InitAccept();
+ }
+
+ void Stop() { m_IsStopped = true; }
+
+ void InitAccept()
+ {
+ auto SocketPtr = std::make_unique<asio::ip::tcp::socket>(m_IoService);
+ asio::ip::tcp::socket& SocketRef = *SocketPtr.get();
+
+ m_Acceptor.async_accept(SocketRef, [this, Socket = std::move(SocketPtr)](const asio::error_code& Ec) mutable {
+ if (Ec)
+ {
+ ZEN_WARN("asio async_accept, connection failed to '{}:{}' reason '{}'",
+ m_Acceptor.local_endpoint().address().to_string(),
+ m_Acceptor.local_endpoint().port(),
+ Ec.message());
+ }
+ else
+ {
+ // New connection established, pass socket ownership into connection object
+ // and initiate request handling loop. The connection lifetime is
+ // managed by the async read/write loop by passing the shared
+ // reference to the callbacks.
+
+ Socket->set_option(asio::ip::tcp::no_delay(true));
+ Socket->set_option(asio::socket_base::receive_buffer_size(128 * 1024));
+ Socket->set_option(asio::socket_base::send_buffer_size(256 * 1024));
+
+ auto Conn = std::make_shared<HttpServerConnection>(m_Server, std::move(Socket));
+ Conn->HandleNewRequest();
+ }
+
+ if (!m_IsStopped.load())
+ {
+ InitAccept();
+ }
+ else
+ {
+ m_Acceptor.close();
+ }
+ });
+ }
+
+ int GetAcceptPort() { return m_Acceptor.local_endpoint().port(); }
+
+private:
+ HttpAsioServerImpl& m_Server;
+ asio::io_service& m_IoService;
+ asio::ip::tcp::acceptor m_Acceptor;
+ std::atomic<bool> m_IsStopped{false};
+};
+
+//////////////////////////////////////////////////////////////////////////
+
+HttpAsioServerRequest::HttpAsioServerRequest(asio_http::HttpRequest& Request, HttpService& Service, IoBuffer PayloadBuffer)
+: m_Request(Request)
+, m_PayloadBuffer(std::move(PayloadBuffer))
+{
+ const int PrefixLength = Service.UriPrefixLength();
+
+ std::string_view Uri = Request.Url();
+ Uri.remove_prefix(std::min(PrefixLength, static_cast<int>(Uri.size())));
+ m_Uri = Uri;
+ m_UriWithExtension = Uri;
+ m_QueryString = Request.QueryString();
+
+ m_Verb = Request.RequestVerb();
+ m_ContentLength = Request.Body().Size();
+ m_ContentType = Request.ContentType();
+
+ HttpContentType AcceptContentType = HttpContentType::kUnknownContentType;
+
+ // Parse any extension, to allow requesting a particular response encoding via the URL
+
+ {
+ std::string_view UriSuffix8{m_Uri};
+
+ const size_t LastComponentIndex = UriSuffix8.find_last_of('/');
+
+ if (LastComponentIndex != std::string_view::npos)
+ {
+ UriSuffix8.remove_prefix(LastComponentIndex);
+ }
+
+ const size_t LastDotIndex = UriSuffix8.find_last_of('.');
+
+ if (LastDotIndex != std::string_view::npos)
+ {
+ UriSuffix8.remove_prefix(LastDotIndex + 1);
+
+ AcceptContentType = ParseContentType(UriSuffix8);
+
+ if (AcceptContentType != HttpContentType::kUnknownContentType)
+ {
+ m_Uri.remove_suffix(uint32_t(UriSuffix8.size() + 1));
+ }
+ }
+ }
+
+ // It an explicit content type extension was specified then we'll use that over any
+ // Accept: header value that may be present
+
+ if (AcceptContentType != HttpContentType::kUnknownContentType)
+ {
+ m_AcceptType = AcceptContentType;
+ }
+ else
+ {
+ m_AcceptType = Request.AcceptType();
+ }
+}
+
+HttpAsioServerRequest::~HttpAsioServerRequest()
+{
+}
+
+Oid
+HttpAsioServerRequest::ParseSessionId() const
+{
+ return m_Request.SessionId();
+}
+
+uint32_t
+HttpAsioServerRequest::ParseRequestId() const
+{
+ return m_Request.RequestId();
+}
+
+IoBuffer
+HttpAsioServerRequest::ReadPayload()
+{
+ return m_PayloadBuffer;
+}
+
+void
+HttpAsioServerRequest::WriteResponse(HttpResponseCode ResponseCode)
+{
+ ZEN_ASSERT(!m_Response);
+
+ m_Response.reset(new HttpResponse(HttpContentType::kBinary));
+ std::array<IoBuffer, 0> Empty;
+
+ m_Response->InitializeForPayload((uint16_t)ResponseCode, Empty);
+}
+
+void
+HttpAsioServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span<IoBuffer> Blobs)
+{
+ ZEN_ASSERT(!m_Response);
+
+ m_Response.reset(new HttpResponse(ContentType));
+ m_Response->InitializeForPayload((uint16_t)ResponseCode, Blobs);
+}
+
+void
+HttpAsioServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::u8string_view ResponseString)
+{
+ ZEN_ASSERT(!m_Response);
+ m_Response.reset(new HttpResponse(ContentType));
+
+ IoBuffer MessageBuffer(IoBuffer::Wrap, ResponseString.data(), ResponseString.size());
+ std::array<IoBuffer, 1> SingleBufferList({MessageBuffer});
+
+ m_Response->InitializeForPayload((uint16_t)ResponseCode, SingleBufferList);
+}
+
+void
+HttpAsioServerRequest::WriteResponseAsync(std::function<void(HttpServerRequest&)>&& ContinuationHandler)
+{
+ ZEN_ASSERT(!m_Response);
+
+ // Not one bit async, innit
+ ContinuationHandler(*this);
+}
+
+bool
+HttpAsioServerRequest::TryGetRanges(HttpRanges& Ranges)
+{
+ return TryParseHttpRangeHeader(m_Request.RangeHeader(), Ranges);
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+HttpAsioServerImpl::HttpAsioServerImpl()
+{
+}
+
+HttpAsioServerImpl::~HttpAsioServerImpl()
+{
+}
+
+int
+HttpAsioServerImpl::Start(uint16_t Port, int ThreadCount)
+{
+ ZEN_ASSERT(ThreadCount > 0);
+
+ ZEN_INFO("starting asio http with {} service threads", ThreadCount);
+
+ m_Acceptor.reset(new asio_http::HttpAcceptor(*this, m_IoService, Port));
+ m_Acceptor->Start();
+
+ for (int i = 0; i < ThreadCount; ++i)
+ {
+ m_ThreadPool.emplace_back([this, Index = i + 1] {
+ SetCurrentThreadName(fmt::format("asio worker {}", Index));
+
+ try
+ {
+ m_IoService.run();
+ }
+ catch (std::exception& e)
+ {
+ ZEN_ERROR("Exception caught in asio event loop: '{}'", e.what());
+ }
+ });
+ }
+
+ ZEN_INFO("asio http started (port {})", m_Acceptor->GetAcceptPort());
+
+ return m_Acceptor->GetAcceptPort();
+}
+
+void
+HttpAsioServerImpl::Stop()
+{
+ m_Acceptor->Stop();
+ m_IoService.stop();
+ for (auto& Thread : m_ThreadPool)
+ {
+ Thread.join();
+ }
+}
+
+void
+HttpAsioServerImpl::RegisterService(const char* InUrlPath, HttpService& Service)
+{
+ std::string_view UrlPath(InUrlPath);
+ Service.SetUriPrefixLength(UrlPath.size());
+ if (!UrlPath.empty() && UrlPath.back() == '/')
+ {
+ UrlPath.remove_suffix(1);
+ }
+
+ RwLock::ExclusiveLockScope _(m_Lock);
+ m_UriHandlers.push_back({std::string(UrlPath), &Service});
+}
+
+HttpService*
+HttpAsioServerImpl::RouteRequest(std::string_view Url)
+{
+ RwLock::SharedLockScope _(m_Lock);
+
+ HttpService* CandidateService = nullptr;
+ std::string::size_type CandidateMatchSize = 0;
+ for (const ServiceEntry& SvcEntry : m_UriHandlers)
+ {
+ const std::string& SvcUrl = SvcEntry.ServiceUrlPath;
+ const std::string::size_type SvcUrlSize = SvcUrl.size();
+ if ((SvcUrlSize >= CandidateMatchSize) && Url.compare(0, SvcUrlSize, SvcUrl) == 0 &&
+ ((SvcUrlSize == Url.size()) || (Url[SvcUrlSize] == '/')))
+ {
+ CandidateMatchSize = SvcUrl.size();
+ CandidateService = SvcEntry.Service;
+ }
+ }
+
+ return CandidateService;
+}
+
+} // namespace zen::asio_http
+
+//////////////////////////////////////////////////////////////////////////
+
+namespace zen {
+HttpAsioServer::HttpAsioServer() : m_Impl(std::make_unique<asio_http::HttpAsioServerImpl>())
+{
+ ZEN_DEBUG("Request object size: {} ({:#x})", sizeof(asio_http::HttpRequest), sizeof(asio_http::HttpRequest));
+}
+
+HttpAsioServer::~HttpAsioServer()
+{
+ try
+ {
+ m_Impl->Stop();
+ }
+ catch (std::exception& ex)
+ {
+ ZEN_WARN("Caught exception stopping http asio server: {}", ex.what());
+ }
+}
+
+void
+HttpAsioServer::RegisterService(HttpService& Service)
+{
+ m_Impl->RegisterService(Service.BaseUri(), Service);
+}
+
+int
+HttpAsioServer::Initialize(int BasePort)
+{
+ m_BasePort = m_Impl->Start(gsl::narrow<uint16_t>(BasePort), Max(std::thread::hardware_concurrency(), 8u));
+ return m_BasePort;
+}
+
+void
+HttpAsioServer::Run(bool IsInteractive)
+{
+ const bool TestMode = !IsInteractive;
+
+ int WaitTimeout = -1;
+ if (!TestMode)
+ {
+ WaitTimeout = 1000;
+ }
+
+#if ZEN_PLATFORM_WINDOWS
+ if (TestMode == false)
+ {
+ zen::logging::ConsoleLog().info("Zen Server running (asio HTTP). Press ESC or Q to quit");
+ }
+
+ do
+ {
+ if (!TestMode && _kbhit() != 0)
+ {
+ char c = (char)_getch();
+
+ if (c == 27 || c == 'Q' || c == 'q')
+ {
+ RequestApplicationExit(0);
+ }
+ }
+
+ m_ShutdownEvent.Wait(WaitTimeout);
+ } while (!IsApplicationExitRequested());
+#else
+ if (TestMode == false)
+ {
+ zen::logging::ConsoleLog().info("Zen Server running (asio HTTP). Ctrl-C to quit");
+ }
+
+ do
+ {
+ m_ShutdownEvent.Wait(WaitTimeout);
+ } while (!IsApplicationExitRequested());
+#endif
+}
+
+void
+HttpAsioServer::RequestExit()
+{
+ m_ShutdownEvent.Set();
+}
+
+} // namespace zen