// Copyright Epic Games, Inc. All Rights Reserved. #include "httpasio.h" #include #include #include #include #include ZEN_THIRD_PARTY_INCLUDES_START #if ZEN_PLATFORM_WINDOWS # include # include #endif #include #include ZEN_THIRD_PARTY_INCLUDES_END #define ASIO_VERBOSE_TRACE 0 #if ASIO_VERBOSE_TRACE # define ZEN_TRACE_VERBOSE ZEN_TRACE #else # define ZEN_TRACE_VERBOSE(fmtstr, ...) #endif namespace zen::asio_http { using namespace std::literals; struct HttpAcceptor; struct HttpRequest; struct HttpResponse; struct HttpServerConnection; static constinit uint32_t HashContentLength = HashStringAsLowerDjb2("Content-Length"sv); static constinit uint32_t HashContentType = HashStringAsLowerDjb2("Content-Type"sv); static constinit uint32_t HashAccept = HashStringAsLowerDjb2("Accept"sv); static constinit uint32_t HashExpect = HashStringAsLowerDjb2("Expect"sv); static constinit uint32_t HashSession = HashStringAsLowerDjb2("UE-Session"sv); static constinit uint32_t HashRequest = HashStringAsLowerDjb2("UE-Request"sv); inline spdlog::logger& InitLogger() { spdlog::logger& Logger = logging::Get("asio"); // Logger.set_level(spdlog::level::trace); return Logger; } inline spdlog::logger& Log() { static spdlog::logger& g_Logger = InitLogger(); return g_Logger; } ////////////////////////////////////////////////////////////////////////// struct HttpAsioServerImpl { public: HttpAsioServerImpl(); ~HttpAsioServerImpl(); int Start(uint16_t Port, int ThreadCount); void Stop(); void RegisterService(const char* UrlPath, HttpService& Service); HttpService* RouteRequest(std::string_view Url); asio::io_service m_IoService; asio::io_service::work m_Work{m_IoService}; std::unique_ptr m_Acceptor; std::vector m_ThreadPool; struct ServiceEntry { std::string ServiceUrlPath; HttpService* Service; }; RwLock m_Lock; std::vector m_UriHandlers; }; /** * This is the class which request handlers use to interact with the server instance */ class HttpAsioServerRequest : public HttpServerRequest { public: HttpAsioServerRequest(asio_http::HttpRequest& Request, HttpService& Service, IoBuffer PayloadBuffer); ~HttpAsioServerRequest(); virtual Oid ParseSessionId() const override; virtual uint32_t ParseRequestId() const override; virtual IoBuffer ReadPayload() override; virtual void WriteResponse(HttpResponseCode ResponseCode) override; virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span Blobs) override; virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::u8string_view ResponseString) override; virtual void WriteResponseAsync(std::function&& ContinuationHandler) override; using HttpServerRequest::WriteResponse; HttpAsioServerRequest(const HttpAsioServerRequest&) = delete; HttpAsioServerRequest& operator=(const HttpAsioServerRequest&) = delete; asio_http::HttpRequest& m_Request; IoBuffer m_PayloadBuffer; std::unique_ptr m_Response; }; struct HttpRequest { explicit HttpRequest(HttpServerConnection& Connection) : m_Connection(Connection) {} void Initialize(); size_t ConsumeData(const char* InputData, size_t DataSize); void ResetState(); HttpVerb RequestVerb() const { return m_RequestVerb; } bool IsKeepAlive() const { return m_KeepAlive; } std::string_view Url() const { return m_NormalizedUrl.empty() ? std::string_view(m_Url, m_UrlLength) : m_NormalizedUrl; } std::string_view QueryString() const { return std::string_view(m_QueryString, m_QueryLength); } IoBuffer Body() { return m_BodyBuffer; } inline HttpContentType ContentType() { if (m_ContentTypeHeaderIndex < 0) { return HttpContentType::kUnknownContentType; } return ParseContentType(m_Headers[m_ContentTypeHeaderIndex].Value); } inline HttpContentType AcceptType() { if (m_AcceptHeaderIndex < 0) { return HttpContentType::kUnknownContentType; } return ParseContentType(m_Headers[m_AcceptHeaderIndex].Value); } Oid SessionId() const { return m_SessionId; } int RequestId() const { return m_RequestId; } private: struct HeaderEntry { HeaderEntry() = default; HeaderEntry(std::string_view InName, std::string_view InValue) : Name(InName), Value(InValue) {} std::string_view Name; std::string_view Value; }; HttpServerConnection& m_Connection; char* m_HeaderCursor = m_HeaderBuffer; char* m_Url = nullptr; size_t m_UrlLength = 0; char* m_QueryString = nullptr; size_t m_QueryLength = 0; char* m_CurrentHeaderName = nullptr; // Used while parsing headers size_t m_CurrentHeaderNameLength = 0; char* m_CurrentHeaderValue = nullptr; // Used while parsing headers size_t m_CurrentHeaderValueLength = 0; std::vector m_Headers; int8_t m_ContentLengthHeaderIndex; int8_t m_AcceptHeaderIndex; int8_t m_ContentTypeHeaderIndex; HttpVerb m_RequestVerb; bool m_KeepAlive = false; bool m_Expect100Continue = false; int m_RequestId = -1; Oid m_SessionId{}; IoBuffer m_BodyBuffer; uint64_t m_BodyPosition = 0; http_parser m_Parser; char m_HeaderBuffer[1024]; std::string m_NormalizedUrl; void AppendInputBytes(const char* Data, size_t Bytes); void AppendCurrentHeader(); int OnMessageBegin(); int OnUrl(const char* Data, size_t Bytes); int OnHeader(const char* Data, size_t Bytes); int OnHeaderValue(const char* Data, size_t Bytes); int OnHeadersComplete(); int OnBody(const char* Data, size_t Bytes); int OnMessageComplete(); static HttpRequest* GetThis(http_parser* Parser) { return reinterpret_cast(Parser->data); } static http_parser_settings s_ParserSettings; void TerminateConnection(); }; struct HttpResponse { public: HttpResponse() = default; explicit HttpResponse(HttpContentType ContentType) : m_ContentType(ContentType) {} void InitializeForPayload(uint16_t ResponseCode, std::span BlobList) { m_ResponseCode = ResponseCode; const uint32_t ChunkCount = gsl::narrow(BlobList.size()); m_DataBuffers.reserve(ChunkCount); for (IoBuffer& Buffer : BlobList) { #if 1 m_DataBuffers.emplace_back(std::move(Buffer)).MakeOwned(); #else IoBuffer TempBuffer = std::move(Buffer); TempBuffer.MakeOwned(); m_DataBuffers.emplace_back(IoBufferBuilder::ReadFromFileMaybe(TempBuffer)); #endif } uint64_t LocalDataSize = 0; m_AsioBuffers.push_back({}); // Placeholder for header for (IoBuffer& Buffer : m_DataBuffers) { uint64_t BufferDataSize = Buffer.Size(); ZEN_ASSERT(BufferDataSize); LocalDataSize += BufferDataSize; IoBufferFileReference FileRef; if (Buffer.GetFileReference(/* out */ FileRef)) { // TODO: Use direct file transfer, via TransmitFile/sendfile // // this looks like it requires some custom asio plumbing however m_AsioBuffers.push_back({Buffer.Data(), Buffer.Size()}); } else { // Send from memory m_AsioBuffers.push_back({Buffer.Data(), Buffer.Size()}); } } m_ContentLength = LocalDataSize; auto Headers = GetHeaders(); m_AsioBuffers[0] = asio::const_buffer(Headers.data(), Headers.size()); } uint16_t ResponseCode() const { return m_ResponseCode; } uint64_t ContentLength() const { return m_ContentLength; } const std::vector& AsioBuffers() const { return m_AsioBuffers; } std::string_view GetHeaders() { m_Headers << "HTTP/1.1 " << ResponseCode() << " " << ReasonStringForHttpResultCode(ResponseCode()) << "\r\n" << "Content-Type: " << MapContentTypeToString(m_ContentType) << "\r\n" << "Content-Length: " << ContentLength() << "\r\n"sv; if (!m_IsKeepAlive) { m_Headers << "Connection: close\r\n"sv; } m_Headers << "\r\n"sv; return m_Headers; } void SuppressPayload() { m_AsioBuffers.resize(1); } private: uint16_t m_ResponseCode = 0; bool m_IsKeepAlive = true; HttpContentType m_ContentType = HttpContentType::kBinary; uint64_t m_ContentLength = 0; std::vector m_DataBuffers; std::vector m_AsioBuffers; ExtendableStringBuilder<160> m_Headers; }; ////////////////////////////////////////////////////////////////////////// struct HttpServerConnection : std::enable_shared_from_this { HttpServerConnection(HttpAsioServerImpl& Server, std::unique_ptr&& Socket); ~HttpServerConnection(); void HandleNewRequest(); void TerminateConnection(); void HandleRequest(); std::shared_ptr AsSharedPtr() { return shared_from_this(); } private: enum class RequestState { kInitialState, kInitialRead, kReadingMore, kWriting, kWritingFinal, kDone, kTerminated }; RequestState m_RequestState = RequestState::kInitialState; HttpRequest m_RequestData{*this}; void EnqueueRead(); void OnDataReceived(const asio::error_code& Ec, std::size_t ByteCount); void OnResponseDataSent(const asio::error_code& Ec, std::size_t ByteCount, bool Pop = false); void OnError(); HttpAsioServerImpl& m_Server; asio::streambuf m_RequestBuffer; std::unique_ptr m_Socket; std::atomic m_RequestCounter{0}; uint32_t m_ConnectionId = 0; Ref m_PackageHandler; RwLock m_ResponsesLock; std::deque> m_Responses; }; std::atomic g_ConnectionIdCounter{0}; HttpServerConnection::HttpServerConnection(HttpAsioServerImpl& Server, std::unique_ptr&& Socket) : m_Server(Server) , m_Socket(std::move(Socket)) , m_ConnectionId(g_ConnectionIdCounter.fetch_add(1)) { ZEN_TRACE_VERBOSE("new connection #{}", m_ConnectionId); } HttpServerConnection::~HttpServerConnection() { ZEN_TRACE_VERBOSE("destroying connection #{}", m_ConnectionId); } void HttpServerConnection::HandleNewRequest() { m_RequestData.Initialize(); EnqueueRead(); } void HttpServerConnection::TerminateConnection() { m_RequestState = RequestState::kTerminated; std::error_code Ec; m_Socket->close(Ec); } void HttpServerConnection::EnqueueRead() { if (m_RequestState == RequestState::kInitialRead) { m_RequestState = RequestState::kReadingMore; } else { m_RequestState = RequestState::kInitialRead; } m_RequestBuffer.prepare(64 * 1024); asio::async_read(*m_Socket.get(), m_RequestBuffer, asio::transfer_at_least(1), [Conn = AsSharedPtr()](const asio::error_code& Ec, std::size_t ByteCount) { Conn->OnDataReceived(Ec, ByteCount); }); } void HttpServerConnection::OnDataReceived(const asio::error_code& Ec, [[maybe_unused]] std::size_t ByteCount) { if (Ec) { if (m_RequestState == RequestState::kDone || m_RequestState == RequestState::kInitialRead) { ZEN_TRACE_VERBOSE("on data received ERROR (EXPECTED), connection '{}' reason '{}'", m_ConnectionId, Ec.message()); return; } else { ZEN_WARN("on data received ERROR, connection '{}' reason '{}'", m_ConnectionId, Ec.message()); return OnError(); } } ZEN_TRACE_VERBOSE("on data received, connection '{}', request '{}', thread '{}', bytes '{}'", m_ConnectionId, m_RequestCounter.load(std::memory_order_relaxed), zen::GetCurrentThreadId(), NiceBytes(ByteCount)); while (m_RequestBuffer.size()) { const asio::const_buffer& InputBuffer = m_RequestBuffer.data(); size_t Result = m_RequestData.ConsumeData((const char*)InputBuffer.data(), InputBuffer.size()); m_RequestBuffer.consume(Result); } switch (m_RequestState) { case RequestState::kDone: case RequestState::kWritingFinal: case RequestState::kTerminated: break; default: EnqueueRead(); break; } } void HttpServerConnection::OnResponseDataSent(const asio::error_code& Ec, [[maybe_unused]] std::size_t ByteCount, bool Pop) { if (Ec) { ZEN_WARN("on data sent ERROR, connection '{}' reason '{}'", m_ConnectionId, Ec.message()); OnError(); } else { ZEN_TRACE_VERBOSE("on data sent, connection '{}', request '{}', thread '{}', bytes '{}'", m_ConnectionId, m_RequestCounter.load(std::memory_order_relaxed), zen::GetCurrentThreadId(), NiceBytes(ByteCount)); if (!m_RequestData.IsKeepAlive()) { m_RequestState = RequestState::kDone; m_Socket->close(); } else { if (Pop) { RwLock::ExclusiveLockScope _(m_ResponsesLock); m_Responses.pop_front(); } m_RequestCounter.fetch_add(1); } } } void HttpServerConnection::OnError() { m_Socket->close(); } void HttpServerConnection::HandleRequest() { if (!m_RequestData.IsKeepAlive()) { m_RequestState = RequestState::kWritingFinal; std::error_code Ec; m_Socket->shutdown(asio::socket_base::shutdown_receive, Ec); if (Ec) { ZEN_WARN("socket shutdown ERROR, reason '{}'", Ec.message()); } } else { m_RequestState = RequestState::kWriting; } if (HttpService* Service = m_Server.RouteRequest(m_RequestData.Url())) { HttpAsioServerRequest Request(m_RequestData, *Service, m_RequestData.Body()); ZEN_TRACE_VERBOSE("handle request, connection '{}' request '{}'", m_ConnectionId, m_RequestCounter.load(std::memory_order_relaxed)); if (!HandlePackageOffers(*Service, Request, m_PackageHandler)) { try { Service->HandleRequest(Request); } catch (std::exception& ex) { ZEN_ERROR("Caught exception while handling request: '{}'", ex.what()); Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, ex.what()); } } if (std::unique_ptr Response = std::move(Request.m_Response)) { // Transmit the response if (m_RequestData.RequestVerb() == HttpVerb::kHead) { Response->SuppressPayload(); } auto ResponseBuffers = Response->AsioBuffers(); uint64_t ResponseLength = 0; for (auto& Buffer : ResponseBuffers) { ResponseLength += Buffer.size(); } { RwLock::ExclusiveLockScope _(m_ResponsesLock); m_Responses.push_back(std::move(Response)); } // TODO: should cork/uncork for Linux? asio::async_write(*m_Socket.get(), ResponseBuffers, asio::transfer_exactly(ResponseLength), [Conn = AsSharedPtr()](const asio::error_code& Ec, std::size_t ByteCount) { Conn->OnResponseDataSent(Ec, ByteCount, true); }); return; } } if (m_RequestData.RequestVerb() == HttpVerb::kHead) { std::string_view Response = "HTTP/1.1 404 NOT FOUND\r\n" "\r\n"sv; if (!m_RequestData.IsKeepAlive()) { Response = "HTTP/1.1 404 NOT FOUND\r\n" "Connection: close\r\n" "\r\n"sv; } asio::async_write( *m_Socket.get(), asio::buffer(Response), [Conn = AsSharedPtr()](const asio::error_code& Ec, std::size_t ByteCount) { Conn->OnResponseDataSent(Ec, ByteCount); }); } else { std::string_view Response = "HTTP/1.1 404 NOT FOUND\r\n" "Content-Length: 23\r\n" "Content-Type: text/plain\r\n" "\r\n" "No suitable route found"sv; if (!m_RequestData.IsKeepAlive()) { Response = "HTTP/1.1 404 NOT FOUND\r\n" "Content-Length: 23\r\n" "Content-Type: text/plain\r\n" "Connection: close\r\n" "\r\n" "No suitable route found"sv; } asio::async_write( *m_Socket.get(), asio::buffer(Response), [Conn = AsSharedPtr()](const asio::error_code& Ec, std::size_t ByteCount) { Conn->OnResponseDataSent(Ec, ByteCount); }); } } ////////////////////////////////////////////////////////////////////////// // // HttpRequest // http_parser_settings HttpRequest::s_ParserSettings{ .on_message_begin = [](http_parser* p) { return GetThis(p)->OnMessageBegin(); }, .on_url = [](http_parser* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnUrl(Data, ByteCount); }, .on_status = [](http_parser* p, const char* Data, size_t ByteCount) { ZEN_UNUSED(p, Data, ByteCount); return 0; }, .on_header_field = [](http_parser* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnHeader(Data, ByteCount); }, .on_header_value = [](http_parser* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnHeaderValue(Data, ByteCount); }, .on_headers_complete = [](http_parser* p) { return GetThis(p)->OnHeadersComplete(); }, .on_body = [](http_parser* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnBody(Data, ByteCount); }, .on_message_complete = [](http_parser* p) { return GetThis(p)->OnMessageComplete(); }, .on_chunk_header{}, .on_chunk_complete{}}; void HttpRequest::Initialize() { http_parser_init(&m_Parser, HTTP_REQUEST); m_Parser.data = this; ResetState(); } size_t HttpRequest::ConsumeData(const char* InputData, size_t DataSize) { const size_t ConsumedBytes = http_parser_execute(&m_Parser, &s_ParserSettings, InputData, DataSize); http_errno HttpErrno = HTTP_PARSER_ERRNO((&m_Parser)); if (HttpErrno && HttpErrno != HPE_INVALID_EOF_STATE) { ZEN_WARN("HTTP parser error {} ('{}'). Closing connection", http_errno_name(HttpErrno), http_errno_description(HttpErrno)); TerminateConnection(); return DataSize; } return ConsumedBytes; } void HttpRequest::AppendInputBytes(const char* Data, size_t Bytes) { const size_t RemainingBufferSpace = sizeof m_HeaderBuffer + m_HeaderBuffer - m_HeaderCursor; if (RemainingBufferSpace >= Bytes) { memcpy(m_HeaderCursor, Data, Bytes); m_HeaderCursor += Bytes; return; } // Terribad, but better than buffer overflow TerminateConnection(); } int HttpRequest::OnUrl(const char* Data, size_t Bytes) { if (!m_Url) { ZEN_ASSERT_SLOW(m_UrlLength == 0); m_Url = m_HeaderCursor; } AppendInputBytes(Data, Bytes); m_UrlLength += Bytes; return 0; } int HttpRequest::OnHeader(const char* Data, size_t Bytes) { if (m_CurrentHeaderValueLength) { AppendCurrentHeader(); m_CurrentHeaderNameLength = 0; m_CurrentHeaderValueLength = 0; m_CurrentHeaderName = m_HeaderCursor; } else if (m_CurrentHeaderName == nullptr) { m_CurrentHeaderName = m_HeaderCursor; } memcpy(m_HeaderCursor, Data, Bytes); m_HeaderCursor += Bytes; m_CurrentHeaderNameLength += Bytes; return 0; } void HttpRequest::AppendCurrentHeader() { std::string_view HeaderName(m_CurrentHeaderName, m_CurrentHeaderNameLength); std::string_view HeaderValue(m_CurrentHeaderValue, m_CurrentHeaderValueLength); const uint32_t HeaderHash = HashStringAsLowerDjb2(HeaderName); if (HeaderHash == HashContentLength) { m_ContentLengthHeaderIndex = (int8_t)m_Headers.size(); } else if (HeaderHash == HashAccept) { m_AcceptHeaderIndex = (int8_t)m_Headers.size(); } else if (HeaderHash == HashContentType) { m_ContentTypeHeaderIndex = (int8_t)m_Headers.size(); } else if (HeaderHash == HashSession) { m_SessionId = Oid::FromHexString(HeaderValue); } else if (HeaderHash == HashRequest) { std::from_chars(HeaderValue.data(), HeaderValue.data() + HeaderValue.size(), m_RequestId); } else if (HeaderHash == HashExpect) { if (HeaderValue == "100-continue"sv) { // We don't currently do anything with this m_Expect100Continue = true; } else { ZEN_INFO("Unexpected expect - Expect: {}", HeaderValue); } } m_Headers.emplace_back(HeaderName, HeaderValue); } int HttpRequest::OnHeaderValue(const char* Data, size_t Bytes) { if (m_CurrentHeaderValueLength == 0) { m_CurrentHeaderValue = m_HeaderCursor; } memcpy(m_HeaderCursor, Data, Bytes); m_HeaderCursor += Bytes; m_CurrentHeaderValueLength += Bytes; return 0; } void HttpRequest::TerminateConnection() { m_Connection.TerminateConnection(); } static void NormalizeUrlPath(const char* Url, size_t UrlLength, std::string& NormalizedUrl) { bool LastCharWasSeparator = false; for (std::string_view::size_type UrlIndex = 0; UrlIndex < UrlLength; ++UrlIndex) { const char UrlChar = Url[UrlIndex]; const bool IsSeparator = (UrlChar == '/'); if (IsSeparator && LastCharWasSeparator) { if (NormalizedUrl.empty()) { NormalizedUrl.reserve(UrlLength); NormalizedUrl.append(Url, UrlIndex); } if (!LastCharWasSeparator) { NormalizedUrl.push_back('/'); } } else if (!NormalizedUrl.empty()) { NormalizedUrl.push_back(UrlChar); } LastCharWasSeparator = IsSeparator; } } int HttpRequest::OnHeadersComplete() { if (m_CurrentHeaderValueLength) { AppendCurrentHeader(); } if (m_ContentLengthHeaderIndex >= 0) { std::string_view& Value = m_Headers[m_ContentLengthHeaderIndex].Value; uint64_t ContentLength = 0; std::from_chars(Value.data(), Value.data() + Value.size(), ContentLength); if (ContentLength) { m_BodyBuffer = IoBuffer(ContentLength); } m_BodyBuffer.SetContentType(ContentType()); m_BodyPosition = 0; } m_KeepAlive = !!http_should_keep_alive(&m_Parser); switch (m_Parser.method) { case HTTP_GET: m_RequestVerb = HttpVerb::kGet; break; case HTTP_POST: m_RequestVerb = HttpVerb::kPost; break; case HTTP_PUT: m_RequestVerb = HttpVerb::kPut; break; case HTTP_DELETE: m_RequestVerb = HttpVerb::kDelete; break; case HTTP_HEAD: m_RequestVerb = HttpVerb::kHead; break; case HTTP_COPY: m_RequestVerb = HttpVerb::kCopy; break; case HTTP_OPTIONS: m_RequestVerb = HttpVerb::kOptions; break; default: ZEN_WARN("invalid HTTP method: '{}'", http_method_str((http_method)m_Parser.method)); break; } std::string_view Url(m_Url, m_UrlLength); if (auto QuerySplit = Url.find_first_of('?'); QuerySplit != std::string_view::npos) { m_UrlLength = QuerySplit; m_QueryString = m_Url + QuerySplit + 1; m_QueryLength = Url.size() - QuerySplit - 1; } NormalizeUrlPath(m_Url, m_UrlLength, m_NormalizedUrl); return 0; } int HttpRequest::OnBody(const char* Data, size_t Bytes) { memcpy(reinterpret_cast(m_BodyBuffer.MutableData()) + m_BodyPosition, Data, Bytes); m_BodyPosition += Bytes; if (http_body_is_final(&m_Parser)) { if (m_BodyPosition != m_BodyBuffer.Size()) { ZEN_WARN("Body mismatch! {} != {}", m_BodyPosition, m_BodyBuffer.Size()); } } return 0; } void HttpRequest::ResetState() { m_HeaderCursor = m_HeaderBuffer; m_CurrentHeaderName = nullptr; m_CurrentHeaderNameLength = 0; m_CurrentHeaderValue = nullptr; m_CurrentHeaderValueLength = 0; m_CurrentHeaderName = nullptr; m_Url = nullptr; m_UrlLength = 0; m_QueryString = nullptr; m_QueryLength = 0; m_ContentLengthHeaderIndex = -1; m_AcceptHeaderIndex = -1; m_ContentTypeHeaderIndex = -1; m_Expect100Continue = false; m_BodyBuffer = {}; m_BodyPosition = 0; m_Headers.clear(); m_NormalizedUrl.clear(); } int HttpRequest::OnMessageBegin() { return 0; } int HttpRequest::OnMessageComplete() { m_Connection.HandleRequest(); ResetState(); return 0; } ////////////////////////////////////////////////////////////////////////// struct HttpAcceptor { HttpAcceptor(HttpAsioServerImpl& Server, asio::io_service& IoService, uint16_t Port) : m_Server(Server) , m_IoService(IoService) , m_Acceptor(m_IoService, asio::ip::tcp::v6()) { m_Acceptor.set_option(asio::ip::v6_only(false)); m_Acceptor.set_option(asio::socket_base::reuse_address(true)); m_Acceptor.set_option(asio::ip::tcp::no_delay(true)); m_Acceptor.set_option(asio::socket_base::receive_buffer_size(128 * 1024)); m_Acceptor.set_option(asio::socket_base::send_buffer_size(256 * 1024)); asio::error_code BindErrorCode; m_Acceptor.bind(asio::ip::tcp::endpoint(asio::ip::address_v6::any(), Port), BindErrorCode); if (BindErrorCode == asio::error::access_denied) { m_Acceptor.bind(asio::ip::tcp::endpoint(asio::ip::address_v6::any(), 0)); } #if ZEN_PLATFORM_WINDOWS // On Windows, loopback connections can take advantage of a faster code path optionally with this flag. // This must be used by both the client and server side, and is only effective in the absence of // Windows Filtering Platform (WFP) callouts which can be installed by security software. // https://docs.microsoft.com/en-us/windows/win32/winsock/sio-loopback-fast-path SOCKET NativeSocket = m_Acceptor.native_handle(); int LoopbackOptionValue = 1; DWORD OptionNumberOfBytesReturned = 0; WSAIoctl(NativeSocket, SIO_LOOPBACK_FAST_PATH, &LoopbackOptionValue, sizeof(LoopbackOptionValue), NULL, 0, &OptionNumberOfBytesReturned, 0, 0); #endif m_Acceptor.listen(); } void Start() { m_Acceptor.listen(); InitAccept(); } void Stop() { m_IsStopped = true; } void InitAccept() { auto SocketPtr = std::make_unique(m_IoService); asio::ip::tcp::socket& SocketRef = *SocketPtr.get(); m_Acceptor.async_accept(SocketRef, [this, Socket = std::move(SocketPtr)](const asio::error_code& Ec) mutable { if (Ec) { ZEN_WARN("asio async_accept, connection failed to '{}:{}' reason '{}'", m_Acceptor.local_endpoint().address().to_string(), m_Acceptor.local_endpoint().port(), Ec.message()); } else { // New connection established, pass socket ownership into connection object // and initiate request handling loop. The connection lifetime is // managed by the async read/write loop by passing the shared // reference to the callbacks. Socket->set_option(asio::ip::tcp::no_delay(true)); Socket->set_option(asio::socket_base::receive_buffer_size(128 * 1024)); Socket->set_option(asio::socket_base::send_buffer_size(256 * 1024)); auto Conn = std::make_shared(m_Server, std::move(Socket)); Conn->HandleNewRequest(); } if (!m_IsStopped.load()) { InitAccept(); } else { m_Acceptor.close(); } }); } int GetAcceptPort() { return m_Acceptor.local_endpoint().port(); } private: HttpAsioServerImpl& m_Server; asio::io_service& m_IoService; asio::ip::tcp::acceptor m_Acceptor; std::atomic m_IsStopped{false}; }; ////////////////////////////////////////////////////////////////////////// HttpAsioServerRequest::HttpAsioServerRequest(asio_http::HttpRequest& Request, HttpService& Service, IoBuffer PayloadBuffer) : m_Request(Request) , m_PayloadBuffer(std::move(PayloadBuffer)) { const int PrefixLength = Service.UriPrefixLength(); std::string_view Uri = Request.Url(); Uri.remove_prefix(std::min(PrefixLength, static_cast(Uri.size()))); m_Uri = Uri; m_QueryString = Request.QueryString(); m_Verb = Request.RequestVerb(); m_ContentLength = Request.Body().Size(); m_ContentType = Request.ContentType(); HttpContentType AcceptContentType = HttpContentType::kUnknownContentType; // Parse any extension, to allow requesting a particular response encoding via the URL { std::string_view UriSuffix8{m_Uri}; const size_t LastComponentIndex = UriSuffix8.find_last_of('/'); if (LastComponentIndex != std::string_view::npos) { UriSuffix8.remove_prefix(LastComponentIndex); } const size_t LastDotIndex = UriSuffix8.find_last_of('.'); if (LastDotIndex != std::string_view::npos) { UriSuffix8.remove_prefix(LastDotIndex + 1); AcceptContentType = ParseContentType(UriSuffix8); if (AcceptContentType != HttpContentType::kUnknownContentType) { m_Uri.remove_suffix(uint32_t(UriSuffix8.size() + 1)); } } } // It an explicit content type extension was specified then we'll use that over any // Accept: header value that may be present if (AcceptContentType != HttpContentType::kUnknownContentType) { m_AcceptType = AcceptContentType; } else { m_AcceptType = Request.AcceptType(); } } HttpAsioServerRequest::~HttpAsioServerRequest() { } Oid HttpAsioServerRequest::ParseSessionId() const { return m_Request.SessionId(); } uint32_t HttpAsioServerRequest::ParseRequestId() const { return m_Request.RequestId(); } IoBuffer HttpAsioServerRequest::ReadPayload() { return m_PayloadBuffer; } void HttpAsioServerRequest::WriteResponse(HttpResponseCode ResponseCode) { ZEN_ASSERT(!m_Response); m_Response.reset(new HttpResponse(HttpContentType::kBinary)); std::array Empty; m_Response->InitializeForPayload((uint16_t)ResponseCode, Empty); } void HttpAsioServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span Blobs) { ZEN_ASSERT(!m_Response); m_Response.reset(new HttpResponse(ContentType)); m_Response->InitializeForPayload((uint16_t)ResponseCode, Blobs); } void HttpAsioServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::u8string_view ResponseString) { ZEN_ASSERT(!m_Response); m_Response.reset(new HttpResponse(ContentType)); IoBuffer MessageBuffer(IoBuffer::Wrap, ResponseString.data(), ResponseString.size()); std::array SingleBufferList({MessageBuffer}); m_Response->InitializeForPayload((uint16_t)ResponseCode, SingleBufferList); } void HttpAsioServerRequest::WriteResponseAsync(std::function&& ContinuationHandler) { ZEN_ASSERT(!m_Response); // Not one bit async, innit ContinuationHandler(*this); } ////////////////////////////////////////////////////////////////////////// HttpAsioServerImpl::HttpAsioServerImpl() { } HttpAsioServerImpl::~HttpAsioServerImpl() { } int HttpAsioServerImpl::Start(uint16_t Port, int ThreadCount) { ZEN_ASSERT(ThreadCount > 0); ZEN_INFO("starting asio http with {} service threads", ThreadCount); m_Acceptor.reset(new asio_http::HttpAcceptor(*this, m_IoService, Port)); m_Acceptor->Start(); for (int i = 0; i < ThreadCount; ++i) { m_ThreadPool.emplace_back([this, Index = i + 1] { SetCurrentThreadName(fmt::format("asio worker {}", Index)); try { m_IoService.run(); } catch (std::exception& e) { ZEN_ERROR("Exception caught in asio event loop: '{}'", e.what()); } }); } ZEN_INFO("asio http started (port {})", m_Acceptor->GetAcceptPort()); return m_Acceptor->GetAcceptPort(); } void HttpAsioServerImpl::Stop() { m_Acceptor->Stop(); m_IoService.stop(); for (auto& Thread : m_ThreadPool) { Thread.join(); } } void HttpAsioServerImpl::RegisterService(const char* InUrlPath, HttpService& Service) { std::string_view UrlPath(InUrlPath); Service.SetUriPrefixLength(UrlPath.size()); if (!UrlPath.empty() && UrlPath.back() == '/') { UrlPath.remove_suffix(1); } RwLock::ExclusiveLockScope _(m_Lock); m_UriHandlers.push_back({std::string(UrlPath), &Service}); } HttpService* HttpAsioServerImpl::RouteRequest(std::string_view Url) { RwLock::SharedLockScope _(m_Lock); HttpService* CandidateService = nullptr; std::string::size_type CandidateMatchSize = 0; for (const ServiceEntry& SvcEntry : m_UriHandlers) { const std::string& SvcUrl = SvcEntry.ServiceUrlPath; const std::string::size_type SvcUrlSize = SvcUrl.size(); if ((SvcUrlSize >= CandidateMatchSize) && Url.compare(0, SvcUrlSize, SvcUrl) == 0 && ((SvcUrlSize == Url.size()) || (Url[SvcUrlSize] == '/'))) { CandidateMatchSize = SvcUrl.size(); CandidateService = SvcEntry.Service; } } return CandidateService; } } // namespace zen::asio_http ////////////////////////////////////////////////////////////////////////// namespace zen { HttpAsioServer::HttpAsioServer() : m_Impl(std::make_unique()) { ZEN_DEBUG("Request object size: {} ({:#x})", sizeof(asio_http::HttpRequest), sizeof(asio_http::HttpRequest)); } HttpAsioServer::~HttpAsioServer() { m_Impl->Stop(); } void HttpAsioServer::RegisterService(HttpService& Service) { m_Impl->RegisterService(Service.BaseUri(), Service); } int HttpAsioServer::Initialize(int BasePort) { m_BasePort = m_Impl->Start(gsl::narrow(BasePort), Clamp(std::thread::hardware_concurrency(), 8u, 64u)); return m_BasePort; } void HttpAsioServer::Run(bool IsInteractive) { const bool TestMode = !IsInteractive; int WaitTimeout = -1; if (!TestMode) { WaitTimeout = 1000; } #if ZEN_PLATFORM_WINDOWS if (TestMode == false) { zen::logging::ConsoleLog().info("Zen Server running (asio HTTP). Press ESC or Q to quit"); } do { if (!TestMode && _kbhit() != 0) { char c = (char)_getch(); if (c == 27 || c == 'Q' || c == 'q') { RequestApplicationExit(0); } } m_ShutdownEvent.Wait(WaitTimeout); } while (!IsApplicationExitRequested()); #else if (TestMode == false) { zen::logging::ConsoleLog().info("Zen Server running (asio HTTP). Ctrl-C to quit"); } do { m_ShutdownEvent.Wait(WaitTimeout); } while (!IsApplicationExitRequested()); #endif } void HttpAsioServer::RequestExit() { m_ShutdownEvent.Set(); } } // namespace zen