diff options
| author | Stefan Boberg <[email protected]> | 2021-10-14 19:07:14 +0200 |
|---|---|---|
| committer | GitHub <[email protected]> | 2021-10-14 19:07:14 +0200 |
| commit | 2b71d6a8d57c773bc7734b253a1ffd1e47162184 (patch) | |
| tree | c0c70f9f2f8b9dc895080aac9f7de1140c56ebf0 /zenhttp/httpasio.cpp | |
| parent | Merge branch 'main' of https://github.com/EpicGames/zen (diff) | |
| download | zen-2b71d6a8d57c773bc7734b253a1ffd1e47162184.tar.xz zen-2b71d6a8d57c773bc7734b253a1ffd1e47162184.zip | |
asio HTTP implementation (#23)
asio-based HTTP implementation
Diffstat (limited to 'zenhttp/httpasio.cpp')
| -rw-r--r-- | zenhttp/httpasio.cpp | 1125 |
1 files changed, 1125 insertions, 0 deletions
diff --git a/zenhttp/httpasio.cpp b/zenhttp/httpasio.cpp new file mode 100644 index 000000000..015d32633 --- /dev/null +++ b/zenhttp/httpasio.cpp @@ -0,0 +1,1125 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "httpasio.h" + +#include <zenhttp/httpserver.h> + +#include <conio.h> +#include <zencore/logging.h> + +#include <http_parser.h> +#include <asio.hpp> +#include <deque> + +namespace zen::asio_http { + +using namespace std::literals; +using namespace fmt::literals; + +struct HttpAcceptor; +struct HttpRequest; +struct HttpResponse; +struct HttpServerConnection; + +static constinit uint32_t HashContentLength = HashStringAsLowerDjb2("Content-Length"sv); +static constinit uint32_t HashContentType = HashStringAsLowerDjb2("Content-Type"sv); +static constinit uint32_t HashAccept = HashStringAsLowerDjb2("Accept"sv); +static constinit uint32_t HashSession = HashStringAsLowerDjb2("UE-Session"sv); +static constinit uint32_t HashRequest = HashStringAsLowerDjb2("UE-Request"sv); + +inline spdlog::logger& +InitLogger() +{ + spdlog::logger& Logger = logging::Get("asio"); +// Logger.set_level(spdlog::level::trace); + return Logger; +} + +inline spdlog::logger& +Log() +{ + static spdlog::logger& g_Logger = InitLogger(); + return g_Logger; +} + +////////////////////////////////////////////////////////////////////////// + +struct HttpAsioServerImpl +{ +public: + HttpAsioServerImpl(); + ~HttpAsioServerImpl(); + + void Start(uint16_t Port, int ThreadCount); + void Stop(); + void RegisterService(const char* UrlPath, HttpService& Service); + HttpService* RouteRequest(std::string_view Url); + + asio::io_service m_IoService; + asio::io_service::work m_Work{m_IoService}; + std::unique_ptr<asio_http::HttpAcceptor> m_Acceptor; + std::vector<std::thread> m_ThreadPool; + + struct ServiceEntry + { + std::string ServiceUrlPath; + HttpService* Service; + }; + + RwLock m_Lock; + std::vector<ServiceEntry> m_UriHandlers; +}; + +/** + * This is the class which request handlers use to interact with the server instance + */ + +class HttpAsioServerRequest : public HttpServerRequest +{ +public: + HttpAsioServerRequest(asio_http::HttpRequest& Request, HttpService& Service, IoBuffer PayloadBuffer); + ~HttpAsioServerRequest(); + + virtual Oid ParseSessionId() const override; + virtual uint32_t ParseRequestId() const override; + + virtual IoBuffer ReadPayload() override; + virtual void WriteResponse(HttpResponseCode ResponseCode) override; + virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span<IoBuffer> Blobs) override; + virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::u8string_view ResponseString) override; + virtual void WriteResponseAsync(std::function<void(HttpServerRequest&)>&& ContinuationHandler) override; + + using HttpServerRequest::WriteResponse; + + HttpAsioServerRequest(const HttpAsioServerRequest&) = delete; + HttpAsioServerRequest& operator=(const HttpAsioServerRequest&) = delete; + + asio_http::HttpRequest& m_Request; + IoBuffer m_PayloadBuffer; + std::string_view m_QueryString; + std::unique_ptr<HttpResponse> m_Response; +}; + +struct HttpRequest +{ + explicit HttpRequest(HttpServerConnection& Connection) : m_Connection(Connection) {} + + void Initialize(); + size_t ConsumeData(const char* InputData, size_t DataSize); + + void ResetState(); + + HttpVerb RequestVerb() const { return m_RequestVerb; } + bool IsKeepAlive() const { return m_KeepAlive; } + std::string_view Url() const { return std::string_view(m_Url, m_UrlLength); } + IoBuffer Body() { return m_BodyBuffer; } + + inline HttpContentType ContentType() + { + if (m_ContentTypeHeaderIndex < 0) + { + return HttpContentType::kUnknownContentType; + } + + return ParseContentType(m_Headers[m_ContentTypeHeaderIndex].Value); + } + + inline HttpContentType AcceptType() + { + if (m_AcceptHeaderIndex < 0) + { + return HttpContentType::kUnknownContentType; + } + + return ParseContentType(m_Headers[m_AcceptHeaderIndex].Value); + } + + Oid SessionId() const { return m_SessionId; } + int RequestId() const { return m_RequestId; } + +private: + struct HeaderEntry + { + std::string_view Name; + std::string_view Value; + }; + + HttpServerConnection& m_Connection; + char m_HeaderBuffer[512]; + char* m_HeaderCursor = m_HeaderBuffer; + char* m_Url = nullptr; + size_t m_UrlLength = 0; + char* m_CurrentHeaderName = nullptr; // Used while parsing headers + size_t m_CurrentHeaderNameLength = 0; + char* m_CurrentHeaderValue = nullptr; // Used while parsing headers + size_t m_CurrentHeaderValueLength = 0; + std::vector<HeaderEntry> m_Headers; + int8_t m_ContentLengthHeaderIndex; + int8_t m_AcceptHeaderIndex; + int8_t m_ContentTypeHeaderIndex; + int m_RequestId = -1; + Oid m_SessionId{}; + IoBuffer m_BodyBuffer; + uint64_t m_BodyPosition; + http_parser m_Parser; + HttpVerb m_RequestVerb; + bool m_KeepAlive = false; + + void AppendInputBytes(const char* Data, size_t Bytes); + void AppendCurrentHeader(); + + int OnMessageBegin(); + int OnUrl(const char* Data, size_t Bytes); + int OnHeader(const char* Data, size_t Bytes); + int OnHeaderValue(const char* Data, size_t Bytes); + int OnHeadersComplete(); + int OnBody(const char* Data, size_t Bytes); + int OnMessageComplete(); + + static HttpRequest* GetThis(http_parser* Parser) { return reinterpret_cast<HttpRequest*>(Parser->data); } + static http_parser_settings s_ParserSettings; + + void TerminateConnection(); +}; + +struct HttpResponse +{ +public: + HttpResponse() = default; + explicit HttpResponse(HttpContentType ContentType) : m_ContentType(ContentType) {} + + void InitializeForPayload(uint16_t ResponseCode, std::span<IoBuffer> BlobList) + { + m_ResponseCode = ResponseCode; + const uint32_t ChunkCount = gsl::narrow<uint32_t>(BlobList.size()); + + m_DataBuffers.reserve(ChunkCount); + + for (IoBuffer& Buffer : BlobList) + { +#if 1 + m_DataBuffers.emplace_back(std::move(Buffer)).MakeOwned(); +#else + IoBuffer TempBuffer = std::move(Buffer); + TempBuffer.MakeOwned(); + m_DataBuffers.emplace_back(IoBufferBuilder::ReadFromFileMaybe(TempBuffer)); +#endif + } + + uint64_t LocalDataSize = 0; + + m_AsioBuffers.push_back({}); // Placeholder for header + + for (IoBuffer& Buffer : m_DataBuffers) + { + uint64_t BufferDataSize = Buffer.Size(); + + ZEN_ASSERT(BufferDataSize); + + LocalDataSize += BufferDataSize; + + IoBufferFileReference FileRef; + if (Buffer.GetFileReference(/* out */ FileRef)) + { + // TODO: Use direct file transfer, via TransmitFile/sendfile + // + // this looks like it requires some custom asio plumbing however + + m_AsioBuffers.push_back({Buffer.Data(), Buffer.Size()}); + } + else + { + // Send from memory + + m_AsioBuffers.push_back({Buffer.Data(), Buffer.Size()}); + } + } + m_ContentLength = LocalDataSize; + + auto Headers = GetHeaders(); + m_AsioBuffers[0] = asio::const_buffer(Headers.data(), Headers.size()); + } + + uint16_t ResponseCode() const { return m_ResponseCode; } + uint64_t ContentLength() const { return m_ContentLength; } + + const std::vector<asio::const_buffer>& AsioBuffers() const { return m_AsioBuffers; } + + std::string_view GetHeaders() + { + m_Headers << "HTTP/1.1 " << ResponseCode() << " " << ReasonStringForHttpResultCode(ResponseCode()) << "\r\n" + << "Content-Type: " << MapContentTypeToString(m_ContentType) << "\r\n" + << "Content-Length: " << ContentLength() << "\r\n"; + + if (!m_IsKeepAlive) + { + m_Headers << "Connection: close\r\n"; + } + + m_Headers << "Date: Mon, 11 Oct 2021 15:06:32 GMT\r\n\r\n"; // TODO: produce more believable data + + return m_Headers; + } + + void SuppressPayload() { m_AsioBuffers.resize(1); } + +private: + uint16_t m_ResponseCode = 0; + bool m_IsKeepAlive = true; + HttpContentType m_ContentType = HttpContentType::kBinary; + uint64_t m_ContentLength = 0; + std::vector<IoBuffer> m_DataBuffers; + std::vector<asio::const_buffer> m_AsioBuffers; + ExtendableStringBuilder<160> m_Headers; +}; + +////////////////////////////////////////////////////////////////////////// + +struct HttpServerConnection +{ + HttpServerConnection(HttpAsioServerImpl& Server, std::unique_ptr<asio::ip::tcp::socket>&& Socket); + ~HttpServerConnection(); + + void HandleNewRequest(); + void TerminateConnection(); + void HandleRequest(); + +private: + enum class RequestState + { + kInitialState, + kInitialRead, + kReadingMore, + kWriting, + kWritingFinal, + kDone, + kTerminated + }; + + RequestState m_RequestState = RequestState::kInitialState; + HttpRequest m_RequestData{*this}; + + void EnqueueRead(); + void OnDataReceived(const asio::error_code& Ec, std::size_t ByteCount); + void OnResponseDataSent(const asio::error_code& Ec, std::size_t ByteCount, bool Pop = false); + void OnError(); + + HttpAsioServerImpl& m_Server; + asio::streambuf m_RequestBuffer; + std::unique_ptr<asio::ip::tcp::socket> m_Socket; + std::atomic<uint32_t> m_RequestCounter{0}; + uint32_t m_ConnectionId = 0; + + RwLock m_ResponsesLock; + std::deque<std::unique_ptr<HttpResponse>> m_Responses; +}; + +std::atomic<uint32_t> g_ConnectionIdCounter{0}; + +HttpServerConnection::HttpServerConnection(HttpAsioServerImpl& Server, std::unique_ptr<asio::ip::tcp::socket>&& Socket) +: m_Server(Server) +, m_Socket(std::move(Socket)) +, m_ConnectionId(g_ConnectionIdCounter.fetch_add(1)) +{ + ZEN_TRACE("new connection #{}", m_ConnectionId); +} + +HttpServerConnection::~HttpServerConnection() +{ +} + +void +HttpServerConnection::HandleNewRequest() +{ + m_RequestData.Initialize(); + + EnqueueRead(); +} + +void +HttpServerConnection::TerminateConnection() +{ + m_RequestState = RequestState::kTerminated; + + std::error_code Ec; + m_Socket->close(Ec); +} + +void +HttpServerConnection::EnqueueRead() +{ + if (m_RequestState == RequestState::kInitialRead) + { + m_RequestState = RequestState::kReadingMore; + } + else + { + m_RequestState = RequestState::kInitialRead; + } + + m_RequestBuffer.prepare(64 * 1024); + + asio::async_read(*m_Socket.get(), + m_RequestBuffer, + asio::transfer_at_least(16), + [this](const asio::error_code& Ec, std::size_t ByteCount) { + if (Ec) + { + if (m_RequestState == RequestState::kDone || m_RequestState == RequestState::kInitialRead) + { + // Expected, just silently handle the condition + // + // if we get an I/O error on the boundary between two messages + // it should be fine to just not say anything + + ZEN_TRACE("(expected) socket read error: conn#{} '{}'", m_ConnectionId, Ec.message()); + } + else + { + ZEN_WARN("unexpected socket read error: conn#{} {}", m_ConnectionId, Ec.message()); + } + + delete this; + } + else + { + ZEN_TRACE("read: conn#:{} seq#:{} t:{} bytes:{}", + m_ConnectionId, + m_RequestCounter.load(std::memory_order_relaxed), + GetCurrentThreadId(), + ByteCount); + + OnDataReceived(Ec, ByteCount); + } + }); +} + +void +HttpServerConnection::OnDataReceived(const asio::error_code& Ec, std::size_t ByteCount) +{ + ZEN_UNUSED(ByteCount); + + if (Ec) + { + ZEN_ERROR("OnDataReceived Error: {}", Ec.message()); + + return OnError(); + } + + while (m_RequestBuffer.size()) + { + const asio::const_buffer& InputBuffer = m_RequestBuffer.data(); + + size_t Result = m_RequestData.ConsumeData((const char*)InputBuffer.data(), InputBuffer.size()); + + m_RequestBuffer.consume(Result); + } + + switch (m_RequestState) + { + case RequestState::kDone: + case RequestState::kWritingFinal: + case RequestState::kTerminated: + break; + + default: + EnqueueRead(); + break; + } +} + +void +HttpServerConnection::OnResponseDataSent(const asio::error_code& Ec, std::size_t ByteCount, bool Pop) +{ + ZEN_UNUSED(ByteCount); + if (Ec) + { + ZEN_ERROR("OnResponseDataSent Error: {}", Ec.message()); + return OnError(); + } + else + { + if (!m_RequestData.IsKeepAlive()) + { + m_RequestState = RequestState::kDone; + + m_Socket->close(); + } + else if (Pop) + { + RwLock::ExclusiveLockScope _(m_ResponsesLock); + m_Responses.pop_front(); + } + } +} + +void +HttpServerConnection::OnError() +{ + m_Socket->close(); +} + +void +HttpServerConnection::HandleRequest() +{ + if (!m_RequestData.IsKeepAlive()) + { + m_RequestState = RequestState::kWritingFinal; + + std::error_code Ec; + m_Socket->shutdown(asio::socket_base::shutdown_receive, Ec); + + if (Ec) + { + ZEN_WARN("socket shutdown reported error: {}", Ec.message()); + } + } + else + { + m_RequestState = RequestState::kWriting; + } + + const uint32_t RequestNum = m_RequestCounter.load(std::memory_order_relaxed); + m_RequestCounter.fetch_add(1); + + if (HttpService* Service = m_Server.RouteRequest(m_RequestData.Url())) + { + HttpAsioServerRequest Request(m_RequestData, *Service, m_RequestData.Body()); + + ZEN_TRACE("Handling request: Conn#{} Req#{}", m_ConnectionId, RequestNum); + + Service->HandleRequest(Request); + + if (std::unique_ptr<HttpResponse> Response = std::move(Request.m_Response)) + { + // Transmit the response + + if (m_RequestData.RequestVerb() == HttpVerb::kHead) + { + Response->SuppressPayload(); + } + + auto ResponseBuffers = Response->AsioBuffers(); + + uint64_t ResponseLength = 0; + + for (auto& Buffer : ResponseBuffers) + { + ResponseLength += Buffer.size(); + } + + { + RwLock::ExclusiveLockScope _(m_ResponsesLock); + m_Responses.push_back(std::move(Response)); + } + + // TODO: should cork/uncork for Linux? + + asio::async_write(*m_Socket.get(), + ResponseBuffers, + asio::transfer_exactly(ResponseLength), + [this, RequestNum](const asio::error_code& Ec, std::size_t ByteCount) { + ZEN_TRACE("Response sent: Conn#{} Req#{} ({})", m_ConnectionId, RequestNum, NiceBytes(ByteCount)); + OnResponseDataSent(Ec, ByteCount, true); + }); + + return; + } + } + + if (m_RequestData.RequestVerb() == HttpVerb::kHead) + { + std::string_view Response = + "HTTP/1.1 404 NOT FOUND\r\n" + "Date: Mon, 11 Oct 2021 15:06:32 GMT\r\n\r\n"sv; + + if (!m_RequestData.IsKeepAlive()) + { + Response = + "HTTP/1.1 404 NOT FOUND\r\n" + "Date: Mon, 11 Oct 2021 15:06:32 GMT\r\nConnection: close\r\n\r\n"sv; + } + + asio::async_write(*m_Socket.get(), asio::buffer(Response), [this](const asio::error_code& Ec, std::size_t ByteCount) { + OnResponseDataSent(Ec, ByteCount); + }); + } + else + { + std::string_view Response = + "HTTP/1.1 404 NOT FOUND\r\n" + "Content-Length: 23\r\nContent-Type: text/plain\r\nDate: Mon, 11 Oct 2021 15:06:32 GMT\r\n\r\nNo suitable route found"sv; + + if (!m_RequestData.IsKeepAlive()) + { + Response = + "HTTP/1.1 404 NOT FOUND\r\n" + "Content-Length: 23\r\nContent-Type: text/plain\r\nDate: Mon, 11 Oct 2021 15:06:32 GMT\r\nConnection: close\r\n\r\nNo suitable route found"sv; + } + + asio::async_write(*m_Socket.get(), asio::buffer(Response), [this](const asio::error_code& Ec, std::size_t ByteCount) { + OnResponseDataSent(Ec, ByteCount); + }); + } +} + +////////////////////////////////////////////////////////////////////////// +// +// HttpRequest +// + +http_parser_settings HttpRequest::s_ParserSettings{ + .on_message_begin = [](http_parser* p) { return GetThis(p)->OnMessageBegin(); }, + .on_url = [](http_parser* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnUrl(Data, ByteCount); }, + .on_status = + [](http_parser* p, const char* Data, size_t ByteCount) { + ZEN_UNUSED(p, Data, ByteCount); + return 0; + }, + .on_header_field = [](http_parser* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnHeader(Data, ByteCount); }, + .on_header_value = [](http_parser* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnHeaderValue(Data, ByteCount); }, + .on_headers_complete = [](http_parser* p) { return GetThis(p)->OnHeadersComplete(); }, + .on_body = [](http_parser* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnBody(Data, ByteCount); }, + .on_message_complete = [](http_parser* p) { return GetThis(p)->OnMessageComplete(); }, + .on_chunk_header{}, + .on_chunk_complete{}}; + +void +HttpRequest::Initialize() +{ + http_parser_init(&m_Parser, HTTP_REQUEST); + m_Parser.data = this; + + ResetState(); +} + +size_t +HttpRequest::ConsumeData(const char* InputData, size_t DataSize) +{ + const size_t ConsumedBytes = http_parser_execute(&m_Parser, &s_ParserSettings, InputData, DataSize); + + if (m_Parser.http_errno) + { + ZEN_WARN("HTTP parser error {}", (uint32_t)m_Parser.http_errno); + + // TODO: we need to kill the connection since we're most likely + // out of sync and can't make progress + } + + return ConsumedBytes; +} + +void +HttpRequest::AppendInputBytes(const char* Data, size_t Bytes) +{ + const size_t RemainingBufferSpace = sizeof m_HeaderBuffer + m_HeaderBuffer - m_HeaderCursor; + + if (RemainingBufferSpace >= Bytes) + { + memcpy(m_HeaderCursor, Data, Bytes); + m_HeaderCursor += Bytes; + + return; + } + + // Terribad, but better than buffer overflow + TerminateConnection(); +} + +int +HttpRequest::OnUrl(const char* Data, size_t Bytes) +{ + if (!m_Url) + { + ZEN_ASSERT_SLOW(m_UrlLength == 0); + m_Url = m_HeaderCursor; + } + + AppendInputBytes(Data, Bytes); + m_UrlLength += Bytes; + + return 0; +} + +int +HttpRequest::OnHeader(const char* Data, size_t Bytes) +{ + if (m_CurrentHeaderValueLength) + { + AppendCurrentHeader(); + + m_CurrentHeaderNameLength = 0; + m_CurrentHeaderValueLength = 0; + m_CurrentHeaderName = m_HeaderCursor; + } + else if (m_CurrentHeaderName == nullptr) + { + m_CurrentHeaderName = m_HeaderCursor; + } + + memcpy(m_HeaderCursor, Data, Bytes); + m_HeaderCursor += Bytes; + m_CurrentHeaderNameLength += Bytes; + + return 0; +} + +void +HttpRequest::AppendCurrentHeader() +{ + std::string_view HeaderName(m_CurrentHeaderName, m_CurrentHeaderNameLength); + std::string_view HeaderValue(m_CurrentHeaderValue, m_CurrentHeaderValueLength); + + const uint32_t HeaderHash = HashStringAsLowerDjb2(HeaderName); + + if (HeaderHash == HashContentLength) + { + m_ContentLengthHeaderIndex = (int8_t)m_Headers.size(); + } + else if (HeaderHash == HashAccept) + { + m_AcceptHeaderIndex = (int8_t)m_Headers.size(); + } + else if (HeaderHash == HashContentType) + { + m_ContentTypeHeaderIndex = (int8_t)m_Headers.size(); + } + else if (HeaderHash == HashSession) + { + m_SessionId = Oid::FromHexString(HeaderValue); + } + else if (HeaderHash == HashRequest) + { + std::from_chars(HeaderValue.data(), HeaderValue.data() + HeaderValue.size(), m_RequestId); + } + + m_Headers.emplace_back(HeaderName, HeaderValue); +} + +int +HttpRequest::OnHeaderValue(const char* Data, size_t Bytes) +{ + if (m_CurrentHeaderValueLength == 0) + { + m_CurrentHeaderValue = m_HeaderCursor; + } + + memcpy(m_HeaderCursor, Data, Bytes); + m_HeaderCursor += Bytes; + m_CurrentHeaderValueLength += Bytes; + + return 0; +} + +void +HttpRequest::TerminateConnection() +{ + m_Connection.TerminateConnection(); +} + +int +HttpRequest::OnHeadersComplete() +{ + if (m_CurrentHeaderValueLength) + { + AppendCurrentHeader(); + } + + if (m_ContentLengthHeaderIndex >= 0) + { + std::string_view& Value = m_Headers[m_ContentLengthHeaderIndex].Value; + uint64_t ContentLength = 0; + std::from_chars(Value.data(), Value.data() + Value.size(), ContentLength); + + if (ContentLength) + { + m_BodyBuffer = IoBuffer(ContentLength); + } + else + { + m_BodyBuffer = {}; + } + + m_BodyBuffer.SetContentType(ContentType()); + + m_BodyPosition = 0; + } + + m_KeepAlive = !!http_should_keep_alive(&m_Parser); + + return 0; +} + +int +HttpRequest::OnBody(const char* Data, size_t Bytes) +{ + memcpy(reinterpret_cast<uint8_t*>(m_BodyBuffer.MutableData()) + m_BodyPosition, Data, Bytes); + m_BodyPosition += Bytes; + + if (http_body_is_final(&m_Parser)) + { + if (m_BodyPosition != m_BodyBuffer.Size()) + { + ZEN_WARN("Body mismatch! {} != {}", m_BodyPosition, m_BodyBuffer.Size()); + } + } + + return 0; +} + +void +HttpRequest::ResetState() +{ + m_HeaderCursor = m_HeaderBuffer; + m_CurrentHeaderName = nullptr; + m_CurrentHeaderNameLength = 0; + m_CurrentHeaderValue = nullptr; + m_CurrentHeaderValueLength = 0; + m_CurrentHeaderName = nullptr; + m_Url = nullptr; + m_UrlLength = 0; + m_ContentLengthHeaderIndex = -1; + m_AcceptHeaderIndex = -1; + m_ContentTypeHeaderIndex = -1; + + m_BodyBuffer = {}; + m_BodyPosition = 0; + + m_Headers.clear(); +} + +int +HttpRequest::OnMessageBegin() +{ + return 0; +} + +int +HttpRequest::OnMessageComplete() +{ + switch (m_Parser.method) + { + case HTTP_GET: + m_RequestVerb = HttpVerb::kGet; + break; + + case HTTP_POST: + m_RequestVerb = HttpVerb::kPost; + break; + + case HTTP_PUT: + m_RequestVerb = HttpVerb::kPut; + break; + + case HTTP_DELETE: + m_RequestVerb = HttpVerb::kDelete; + break; + + case HTTP_HEAD: + m_RequestVerb = HttpVerb::kHead; + break; + + case HTTP_COPY: + m_RequestVerb = HttpVerb::kCopy; + break; + + case HTTP_OPTIONS: + m_RequestVerb = HttpVerb::kOptions; + break; + } + + m_Connection.HandleRequest(); + + ResetState(); + + return 0; +} + +////////////////////////////////////////////////////////////////////////// + +struct HttpAcceptor +{ + HttpAcceptor(HttpAsioServerImpl& Server, asio::io_service& IoService, uint16_t Port) + : m_Server(Server) + , m_IoService(IoService) + , m_Acceptor(m_IoService, asio::ip::tcp::endpoint(asio::ip::address_v4::any(), Port)) + { + } + + void Start() + { + m_Acceptor.listen(); + InitAccept(); + } + + void Stop() { m_IsStopped = true; } + + void InitAccept() + { + auto SocketPtr = std::make_unique<asio::ip::tcp::socket>(m_IoService); + asio::ip::tcp::socket& SocketRef = *SocketPtr.get(); + + m_Acceptor.async_accept(SocketRef, [this, Socket = std::move(SocketPtr)](const asio::error_code& Ec) mutable { + if (Ec) + { + // TODO: Error condition - please handle and report properly + } + else + { + // New connection established, pass socket ownership into connection object + // and initiate request handling loop + + Socket->set_option(asio::ip::tcp::no_delay(true)); + + HttpServerConnection* Conn = new HttpServerConnection(m_Server, std::move(Socket)); + Conn->HandleNewRequest(); + + // note: the connection object is responsible for deleting itself + } + + if (!m_IsStopped.load()) + { + InitAccept(); + } + else + { + m_Acceptor.close(); + } + }); + } + +private: + HttpAsioServerImpl& m_Server; + asio::io_service& m_IoService; + asio::ip::tcp::acceptor m_Acceptor; + std::atomic<bool> m_IsStopped{false}; +}; + +////////////////////////////////////////////////////////////////////////// + +HttpAsioServerRequest::HttpAsioServerRequest(asio_http::HttpRequest& Request, HttpService& Service, IoBuffer PayloadBuffer) +: m_Request(Request) +, m_PayloadBuffer(std::move(PayloadBuffer)) +{ + const int PrefixLength = Service.UriPrefixLength(); + + std::string_view Uri = Request.Url(); + Uri.remove_prefix(PrefixLength); + m_Uri = Uri; + + m_Verb = Request.RequestVerb(); + m_ContentLength = Request.Body().Size(); + m_ContentType = Request.ContentType(); + + // It an explicit content type extension was specified then we'll use that over any + // Accept: header value that may be present + // + // TODO! + + m_AcceptType = Request.AcceptType(); +} + +HttpAsioServerRequest::~HttpAsioServerRequest() +{ +} + +Oid +HttpAsioServerRequest::ParseSessionId() const +{ + return m_Request.SessionId(); +} + +uint32_t +HttpAsioServerRequest::ParseRequestId() const +{ + return m_Request.RequestId(); +} + +IoBuffer +HttpAsioServerRequest::ReadPayload() +{ + return m_PayloadBuffer; +} + +void +HttpAsioServerRequest::WriteResponse(HttpResponseCode ResponseCode) +{ + ZEN_ASSERT(!m_Response); + + m_Response.reset(new HttpResponse(HttpContentType::kBinary)); + std::array<IoBuffer, 0> Empty; + + m_Response->InitializeForPayload((UINT16)ResponseCode, Empty); +} + +void +HttpAsioServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span<IoBuffer> Blobs) +{ + ZEN_ASSERT(!m_Response); + + m_Response.reset(new HttpResponse(ContentType)); + m_Response->InitializeForPayload((UINT16)ResponseCode, Blobs); +} + +void +HttpAsioServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::u8string_view ResponseString) +{ + ZEN_ASSERT(!m_Response); + m_Response.reset(new HttpResponse(ContentType)); + + IoBuffer MessageBuffer(IoBuffer::Wrap, ResponseString.data(), ResponseString.size()); + std::array<IoBuffer, 1> SingleBufferList({MessageBuffer}); + + m_Response->InitializeForPayload((uint16_t)ResponseCode, SingleBufferList); +} + +void +HttpAsioServerRequest::WriteResponseAsync(std::function<void(HttpServerRequest&)>&& ContinuationHandler) +{ + ZEN_ASSERT(!m_Response); + + // Not one bit async, innit + ContinuationHandler(*this); +} + +////////////////////////////////////////////////////////////////////////// + +HttpAsioServerImpl::HttpAsioServerImpl() +{ +} + +HttpAsioServerImpl::~HttpAsioServerImpl() +{ +} + +void +HttpAsioServerImpl::Start(uint16_t Port, int ThreadCount) +{ + ZEN_ASSERT(ThreadCount > 0); + + m_Acceptor.reset(new asio_http::HttpAcceptor(*this, m_IoService, Port)); + m_Acceptor->Start(); + + for (int i = 0; i < ThreadCount; ++i) + { + m_ThreadPool.emplace_back([this, Index = i + 1] { + SetCurrentThreadName("asio worker {}"_format(Index)); + + try + { + m_IoService.run(); + } + catch (std::exception& e) + { + ZEN_ERROR("Exception caught in asio event loop: '{}'", e.what()); + } + }); + } +} + +void +HttpAsioServerImpl::Stop() +{ + m_Acceptor->Stop(); + m_IoService.stop(); + for (auto& Thread : m_ThreadPool) + { + Thread.join(); + } +} + +void +HttpAsioServerImpl::RegisterService(const char* InUrlPath, HttpService& Service) +{ + std::string_view UrlPath(InUrlPath); + Service.SetUriPrefixLength(UrlPath.size()); + + RwLock::ExclusiveLockScope _(m_Lock); + m_UriHandlers.push_back({std::string(UrlPath), &Service}); +} + +HttpService* +HttpAsioServerImpl::RouteRequest(std::string_view Url) +{ + RwLock::SharedLockScope _(m_Lock); + + for (const ServiceEntry& SvcEntry : m_UriHandlers) + { + const std::string& SvcUrl = SvcEntry.ServiceUrlPath; + if (Url.compare(0, SvcUrl.size(), SvcUrl) == 0) + { + return SvcEntry.Service; + } + } + + return nullptr; +} + +} // namespace zen::asio_http + +////////////////////////////////////////////////////////////////////////// + +namespace zen { +HttpAsioServer::HttpAsioServer() : m_Impl(std::make_unique<asio_http::HttpAsioServerImpl>()) +{ +} + +HttpAsioServer::~HttpAsioServer() +{ + m_Impl->Stop(); +} + +void +HttpAsioServer::RegisterService(HttpService& Service) +{ + m_Impl->RegisterService(Service.BaseUri(), Service); +} + +void +HttpAsioServer::Initialize(int BasePort) +{ + m_BasePort = BasePort; + + m_Impl->Start(gsl::narrow<uint16_t>(m_BasePort), Max(std::thread::hardware_concurrency(), 8u)); +} + +void +HttpAsioServer::Run(bool IsInteractive) +{ + const bool TestMode = !IsInteractive; + + if (TestMode == false) + { + zen::logging::ConsoleLog().info("Zen Server running (asio HTTP). Press ESC or Q to quit"); + } + + do + { + int WaitTimeout = -1; + + if (!TestMode) + { + WaitTimeout = 1000; + } + + if (!TestMode && _kbhit() != 0) + { + char c = (char)_getch(); + + if (c == 27 || c == 'Q' || c == 'q') + { + RequestApplicationExit(0); + } + } + + m_ShutdownEvent.Wait(WaitTimeout); + } while (!IsApplicationExitRequested()); +} + +void +HttpAsioServer::RequestExit() +{ + m_ShutdownEvent.Set(); +} + +} // namespace zen |