aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorStefan Boberg <[email protected]>2023-10-10 13:30:07 +0200
committerGitHub <[email protected]>2023-10-10 13:30:07 +0200
commit7905d0d21af95c6016768d7a8a81dd9204b34d24 (patch)
tree076329f5fad1c33503ee611f6399853da2490567 /src
parentcache reference tracking (#455) (diff)
downloadzen-7905d0d21af95c6016768d7a8a81dd9204b34d24.tar.xz
zen-7905d0d21af95c6016768d7a8a81dd9204b34d24.zip
experimental pluggable transport support (#436)
this change adds a `--http=plugin` mode where we support pluggable transports. Currently this defaults to a barebones blocking winsock implementation but there is also support for dynamic loading of transport plugins, which will be further developed in the near future.
Diffstat (limited to 'src')
-rw-r--r--src/plugins/winsock/winsock.cpp350
-rw-r--r--src/plugins/winsock/xmake.lua18
-rw-r--r--src/zen/xmake.lua2
-rw-r--r--src/zenhttp/dlltransport.cpp251
-rw-r--r--src/zenhttp/dlltransport.h37
-rw-r--r--src/zenhttp/httpasio.cpp7
-rw-r--r--src/zenhttp/httpasio.h2
-rw-r--r--src/zenhttp/httpparser.cpp7
-rw-r--r--src/zenhttp/httpparser.h63
-rw-r--r--src/zenhttp/httpplugin.cpp781
-rw-r--r--src/zenhttp/httpserver.cpp35
-rw-r--r--src/zenhttp/include/zenhttp/httpplugin.h49
-rw-r--r--src/zenhttp/include/zenhttp/transportplugin.h97
-rw-r--r--src/zenhttp/winsocktransport.cpp367
-rw-r--r--src/zenhttp/winsocktransport.h15
15 files changed, 2039 insertions, 42 deletions
diff --git a/src/plugins/winsock/winsock.cpp b/src/plugins/winsock/winsock.cpp
new file mode 100644
index 000000000..dca1fdbe7
--- /dev/null
+++ b/src/plugins/winsock/winsock.cpp
@@ -0,0 +1,350 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <inttypes.h>
+#include <atomic>
+#include <exception>
+#include <future>
+#include <memory>
+#include <thread>
+
+#include <zencore/refcount.h>
+#include <zencore/zencore.h>
+
+#ifndef _WIN32_WINNT
+# define _WIN32_WINNT 0x0A00
+#endif
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <WS2tcpip.h>
+#include <WinSock2.h>
+#include <windows.h>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+#include <transportplugin.h>
+
+//////////////////////////////////////////////////////////////////////////
+
+class SocketTransportPlugin : public TransportPluginInterface, zen::RefCounted
+{
+public:
+ SocketTransportPlugin(uint16_t BasePort, unsigned int ThreadCount);
+ ~SocketTransportPlugin();
+
+ // TransportPluginInterface implementation
+
+ virtual uint32_t AddRef() const override;
+ virtual uint32_t Release() const override;
+ virtual void Initialize(TransportServerInterface* ServerInterface) override;
+ virtual void Shutdown() override;
+ virtual bool IsAvailable() override;
+
+private:
+ TransportServerInterface* m_ServerInterface = nullptr;
+ bool m_IsOk = true;
+ uint16_t m_BasePort = 0;
+ int m_ThreadCount = 0;
+
+ SOCKET m_ListenSocket{};
+ std::thread m_AcceptThread;
+ std::atomic_flag m_KeepRunning;
+ std::vector<std::future<void>> m_Connections;
+};
+
+struct SocketTransportConnection : public TransportConnectionInterface
+{
+public:
+ SocketTransportConnection();
+ ~SocketTransportConnection();
+
+ void Initialize(TransportServerConnectionHandler* ServerConnection, SOCKET ClientSocket);
+ void HandleConnection();
+
+ // TransportConnectionInterface implementation
+
+ virtual int64_t WriteBytes(const void* Buffer, size_t DataSize) override;
+ virtual void Shutdown(bool Receive, bool Transmit) override;
+ virtual void CloseConnection() override;
+
+private:
+ zen::Ref<TransportServerConnectionHandler> m_ConnectionHandler;
+ SOCKET m_ClientSocket{};
+ bool m_IsTerminated = false;
+};
+
+//////////////////////////////////////////////////////////////////////////
+
+SocketTransportConnection::SocketTransportConnection()
+{
+}
+
+SocketTransportConnection::~SocketTransportConnection()
+{
+}
+
+void
+SocketTransportConnection::Initialize(TransportServerConnectionHandler* ServerConnection, SOCKET ClientSocket)
+{
+ // ZEN_ASSERT(!m_ConnectionHandler);
+
+ m_ConnectionHandler = ServerConnection;
+ m_ClientSocket = ClientSocket;
+}
+
+void
+SocketTransportConnection::HandleConnection()
+{
+ // ZEN_ASSERT(m_ConnectionHandler);
+
+ const int InputBufferSize = 64 * 1024;
+ std::unique_ptr<uint8_t[]> InputBuffer{new uint8_t[64 * 1024]};
+
+ do
+ {
+ const int RecvBytes = recv(m_ClientSocket, (char*)InputBuffer.get(), InputBufferSize, /* flags */ 0);
+
+ if (RecvBytes == 0)
+ {
+ // Connection closed
+ return CloseConnection();
+ }
+ else if (RecvBytes < 0)
+ {
+ // Error
+ return CloseConnection();
+ }
+
+ m_ConnectionHandler->OnBytesRead(InputBuffer.get(), RecvBytes);
+ } while (m_ClientSocket);
+}
+
+void
+SocketTransportConnection::CloseConnection()
+{
+ if (m_IsTerminated)
+ {
+ return;
+ }
+
+ // ZEN_ASSERT(m_ClientSocket);
+ m_IsTerminated = true;
+
+ shutdown(m_ClientSocket, SD_BOTH); // We won't be sending or receiving any more data
+
+ closesocket(m_ClientSocket);
+ m_ClientSocket = 0;
+}
+
+int64_t
+SocketTransportConnection::WriteBytes(const void* Buffer, size_t DataSize)
+{
+ const uint8_t* BufferCursor = reinterpret_cast<const uint8_t*>(Buffer);
+ int64_t TotalBytesSent = 0;
+
+ while (DataSize)
+ {
+ const int MaxBlockSize = 128 * 1024;
+ const int SendBlockSize = (DataSize > MaxBlockSize) ? MaxBlockSize : (int)DataSize;
+ const int SentBytes = send(m_ClientSocket, (const char*)BufferCursor, SendBlockSize, /* flags */ 0);
+
+ if (SentBytes < 0)
+ {
+ // Error
+ return SentBytes;
+ }
+
+ BufferCursor += SentBytes;
+ DataSize -= SentBytes;
+ TotalBytesSent += SentBytes;
+ }
+
+ return TotalBytesSent;
+}
+
+void
+SocketTransportConnection::Shutdown(bool Receive, bool Transmit)
+{
+ if (Receive)
+ {
+ if (Transmit)
+ {
+ shutdown(m_ClientSocket, SD_BOTH);
+ }
+ else
+ {
+ shutdown(m_ClientSocket, SD_RECEIVE);
+ }
+ }
+ else if (Transmit)
+ {
+ shutdown(m_ClientSocket, SD_SEND);
+ }
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+SocketTransportPlugin::SocketTransportPlugin(uint16_t BasePort, unsigned int ThreadCount)
+: m_BasePort(BasePort)
+, m_ThreadCount(ThreadCount != 0 ? ThreadCount : std::max(std::thread::hardware_concurrency(), 8u))
+{
+#if ZEN_PLATFORM_WINDOWS
+ WSADATA wsaData;
+ if (int Result = WSAStartup(0x202, &wsaData); Result != 0)
+ {
+ m_IsOk = false;
+ WSACleanup();
+ }
+#endif
+}
+
+SocketTransportPlugin::~SocketTransportPlugin()
+{
+ Shutdown();
+
+#if ZEN_PLATFORM_WINDOWS
+ if (m_IsOk)
+ {
+ WSACleanup();
+ }
+#endif
+}
+
+uint32_t
+SocketTransportPlugin::AddRef() const
+{
+ return RefCounted::AddRef();
+}
+
+uint32_t
+SocketTransportPlugin::Release() const
+{
+ return RefCounted::Release();
+}
+
+void
+SocketTransportPlugin::Initialize(TransportServerInterface* ServerInterface)
+{
+ uint16_t Port = m_BasePort;
+
+ m_ServerInterface = ServerInterface;
+ m_ListenSocket = socket(AF_INET6, SOCK_STREAM, 0);
+
+ if (m_ListenSocket == SOCKET_ERROR || m_ListenSocket == INVALID_SOCKET)
+ {
+ throw std::system_error(std::error_code(WSAGetLastError(), std::system_category()),
+ "socket creation failed in HTTP plugin server init");
+ }
+
+ sockaddr_in6 Server{};
+ Server.sin6_family = AF_INET6;
+ Server.sin6_port = htons(Port);
+ Server.sin6_addr = in6addr_any;
+
+ if (int Result = bind(m_ListenSocket, (sockaddr*)&Server, sizeof(Server)); Result == SOCKET_ERROR)
+ {
+ throw std::system_error(std::error_code(WSAGetLastError(), std::system_category()), "bind call failed in HTTP plugin server init");
+ }
+
+ if (int Result = listen(m_ListenSocket, AF_INET6); Result == SOCKET_ERROR)
+ {
+ throw std::system_error(std::error_code(WSAGetLastError(), std::system_category()),
+ "listen call failed in HTTP plugin server init");
+ }
+
+ m_KeepRunning.test_and_set();
+
+ m_AcceptThread = std::thread([&] {
+ // SetCurrentThreadName("http_plugin_acceptor");
+
+ // ZEN_INFO("HTTP plugin server waiting for connections");
+
+ do
+ {
+ if (SOCKET ClientSocket = accept(m_ListenSocket, NULL, NULL); ClientSocket != SOCKET_ERROR)
+ {
+ int Flag = 1;
+ setsockopt(ClientSocket, IPPROTO_TCP, TCP_NODELAY, (char*)&Flag, sizeof(Flag));
+
+ // Handle new connection
+ SocketTransportConnection* Connection = new SocketTransportConnection();
+ TransportServerConnectionHandler* ConnectionInterface{m_ServerInterface->CreateConnectionHandler(Connection)};
+ Connection->Initialize(ConnectionInterface, ClientSocket);
+
+ m_Connections.push_back(std::async(std::launch::async, [Connection] {
+ try
+ {
+ Connection->HandleConnection();
+ }
+ catch (std::exception&)
+ {
+ // ZEN_WARN("exception caught in connection loop: {}", Ex.what());
+ }
+
+ delete Connection;
+ }));
+ }
+ else
+ {
+ }
+ } while (m_KeepRunning.test());
+
+ // ZEN_INFO("HTTP plugin server accept thread exit");
+ });
+}
+
+void
+SocketTransportPlugin::Shutdown()
+{
+ // TODO: all pending/ongoing work should be drained here as well
+
+ m_KeepRunning.clear();
+
+ closesocket(m_ListenSocket);
+ m_ListenSocket = 0;
+
+ if (m_AcceptThread.joinable())
+ {
+ m_AcceptThread.join();
+ }
+}
+
+bool
+SocketTransportPlugin::IsAvailable()
+{
+ return true;
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+TransportPluginInterface*
+CreateTransportPluginInterface()
+{
+ return new SocketTransportPlugin(1337, 8);
+}
+
+BOOL WINAPI
+DllMain([[maybe_unused]] HINSTANCE hinstDLL, // handle to DLL module
+ DWORD fdwReason, // reason for calling function
+ LPVOID lpvReserved) // reserved
+{
+ // Perform actions based on the reason for calling.
+ switch (fdwReason)
+ {
+ case DLL_PROCESS_ATTACH:
+ break;
+
+ case DLL_THREAD_ATTACH:
+ break;
+
+ case DLL_THREAD_DETACH:
+ break;
+
+ case DLL_PROCESS_DETACH:
+ if (lpvReserved != nullptr)
+ {
+ break; // do not do cleanup if process termination scenario
+ }
+ break;
+ }
+
+ return TRUE;
+}
diff --git a/src/plugins/winsock/xmake.lua b/src/plugins/winsock/xmake.lua
new file mode 100644
index 000000000..a4ef02a98
--- /dev/null
+++ b/src/plugins/winsock/xmake.lua
@@ -0,0 +1,18 @@
+-- Copyright Epic Games, Inc. All Rights Reserved.
+
+target("winsock")
+ set_kind("shared")
+ add_headerfiles("**.h")
+ add_files("**.cpp")
+ add_includedirs(".", "../../zenhttp/include/zenhttp", "../../zencore/include")
+ set_symbols("debug")
+
+ add_cxxflags("/showIncludes")
+
+ if is_mode("release") then
+ set_optimize("fastest")
+ end
+
+ if is_plat("windows") then
+ add_links("Ws2_32")
+ end
diff --git a/src/zen/xmake.lua b/src/zen/xmake.lua
index b83999efc..fef48e7bc 100644
--- a/src/zen/xmake.lua
+++ b/src/zen/xmake.lua
@@ -17,7 +17,7 @@ target("zen")
add_files("zen.rc")
add_ldflags("/subsystem:console,5.02")
add_ldflags("/LTCG")
- add_ldflags("crypt32.lib", "wldap32.lib", "Ws2_32.lib")
+ add_links("crypt32", "wldap32", "Ws2_32")
end
if is_plat("macosx") then
diff --git a/src/zenhttp/dlltransport.cpp b/src/zenhttp/dlltransport.cpp
new file mode 100644
index 000000000..0bd5e3720
--- /dev/null
+++ b/src/zenhttp/dlltransport.cpp
@@ -0,0 +1,251 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "dlltransport.h"
+
+#include <zencore/except.h>
+#include <zencore/logging.h>
+#include <zencore/scopeguard.h>
+
+#include <exception>
+#include <thread>
+#include <vector>
+
+#if ZEN_WITH_PLUGINS
+
+namespace zen {
+
+struct DllTransportConnection : public TransportConnectionInterface
+{
+public:
+ DllTransportConnection();
+ ~DllTransportConnection();
+
+ void Initialize(TransportServerConnectionHandler& ServerConnection);
+ void HandleConnection();
+
+ // TransportConnectionInterface
+
+ virtual int64_t WriteBytes(const void* Buffer, size_t DataSize) override;
+ virtual void Shutdown(bool Receive, bool Transmit) override;
+ virtual void CloseConnection() override;
+
+private:
+ Ref<TransportServerConnectionHandler> m_ConnectionHandler;
+ bool m_IsTerminated = false;
+};
+
+DllTransportConnection::DllTransportConnection()
+{
+}
+
+DllTransportConnection::~DllTransportConnection()
+{
+}
+
+void
+DllTransportConnection::Initialize(TransportServerConnectionHandler& ServerConnection)
+{
+ m_ConnectionHandler = &ServerConnection; // TODO: this is awkward
+}
+
+void
+DllTransportConnection::HandleConnection()
+{
+}
+
+void
+DllTransportConnection::CloseConnection()
+{
+ if (m_IsTerminated)
+ {
+ return;
+ }
+
+ m_IsTerminated = true;
+}
+
+int64_t
+DllTransportConnection::WriteBytes(const void* Buffer, size_t DataSize)
+{
+ ZEN_UNUSED(Buffer, DataSize);
+ return DataSize;
+}
+
+void
+DllTransportConnection::Shutdown(bool Receive, bool Transmit)
+{
+ ZEN_UNUSED(Receive, Transmit);
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+struct LoadedDll
+{
+ std::string Name;
+ std::filesystem::path LoadedFromPath;
+ Ref<TransportPluginInterface> Plugin;
+};
+
+class DllTransportPluginImpl
+{
+public:
+ DllTransportPluginImpl(uint16_t BasePort, unsigned int ThreadCount);
+ ~DllTransportPluginImpl();
+
+ uint16_t Start(TransportServerInterface* ServerInterface);
+ void Stop();
+ bool IsAvailable();
+ void LoadDll(std::string_view Name);
+
+private:
+ TransportServerInterface* m_ServerInterface = nullptr;
+ RwLock m_Lock;
+ std::vector<LoadedDll> m_Transports;
+ uint16_t m_BasePort = 0;
+ int m_ThreadCount = 0;
+};
+
+DllTransportPluginImpl::DllTransportPluginImpl(uint16_t BasePort, unsigned int ThreadCount)
+: m_BasePort(BasePort)
+, m_ThreadCount(ThreadCount != 0 ? ThreadCount : Max(std::thread::hardware_concurrency(), 8u))
+{
+}
+
+DllTransportPluginImpl::~DllTransportPluginImpl()
+{
+}
+
+uint16_t
+DllTransportPluginImpl::Start(TransportServerInterface* ServerIface)
+{
+ m_ServerInterface = ServerIface;
+
+ RwLock::ExclusiveLockScope _(m_Lock);
+
+ for (LoadedDll& Transport : m_Transports)
+ {
+ try
+ {
+ Transport.Plugin->Initialize(ServerIface);
+ }
+ catch (const std::exception&)
+ {
+ // TODO: report
+ }
+ }
+
+ return m_BasePort;
+}
+
+void
+DllTransportPluginImpl::Stop()
+{
+ RwLock::ExclusiveLockScope _(m_Lock);
+
+ for (LoadedDll& Transport : m_Transports)
+ {
+ try
+ {
+ Transport.Plugin->Shutdown();
+ }
+ catch (const std::exception&)
+ {
+ // TODO: report
+ }
+ }
+}
+
+bool
+DllTransportPluginImpl::IsAvailable()
+{
+ return true;
+}
+
+void
+DllTransportPluginImpl::LoadDll(std::string_view Name)
+{
+ ExtendableStringBuilder<128> DllPath;
+ DllPath << Name << ".dll";
+ HMODULE DllHandle = LoadLibraryA(DllPath.c_str());
+
+ if (!DllHandle)
+ {
+ std::error_code Ec = MakeErrorCodeFromLastError();
+
+ throw std::system_error(Ec, fmt::format("failed to load transport DLL from '{}'", DllPath));
+ }
+
+ TransportPluginInterface* CreateTransportPluginInterface();
+
+ PfnCreateTransportPluginInterface CreatePlugin =
+ (PfnCreateTransportPluginInterface)GetProcAddress(DllHandle, "CreateTransportPluginInterface");
+
+ if (!CreatePlugin)
+ {
+ std::error_code Ec = MakeErrorCodeFromLastError();
+
+ FreeLibrary(DllHandle);
+
+ throw std::system_error(Ec, fmt::format("API mismatch detected in transport DLL loaded from '{}'", DllPath));
+ }
+
+ LoadedDll NewDll;
+
+ NewDll.Name = Name;
+ NewDll.LoadedFromPath = DllPath.c_str();
+ NewDll.Plugin = CreatePlugin();
+
+ m_Transports.emplace_back(std::move(NewDll));
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+DllTransportPlugin::DllTransportPlugin(uint16_t BasePort, unsigned int ThreadCount)
+: m_Impl(std::make_unique<DllTransportPluginImpl>(BasePort, ThreadCount))
+{
+}
+
+DllTransportPlugin::~DllTransportPlugin()
+{
+ m_Impl->Stop();
+}
+
+uint32_t
+DllTransportPlugin::AddRef() const
+{
+ return RefCounted::AddRef();
+}
+
+uint32_t
+DllTransportPlugin::Release() const
+{
+ return RefCounted::Release();
+}
+
+void
+DllTransportPlugin::Initialize(TransportServerInterface* ServerInterface)
+{
+ m_Impl->Start(ServerInterface);
+}
+
+void
+DllTransportPlugin::Shutdown()
+{
+ m_Impl->Stop();
+}
+
+bool
+DllTransportPlugin::IsAvailable()
+{
+ return m_Impl->IsAvailable();
+}
+
+void
+DllTransportPlugin::LoadDll(std::string_view Name)
+{
+ return m_Impl->LoadDll(Name);
+}
+
+} // namespace zen
+
+#endif
diff --git a/src/zenhttp/dlltransport.h b/src/zenhttp/dlltransport.h
new file mode 100644
index 000000000..b13bab804
--- /dev/null
+++ b/src/zenhttp/dlltransport.h
@@ -0,0 +1,37 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zenhttp/httpplugin.h>
+
+#if ZEN_WITH_PLUGINS
+
+namespace zen {
+
+class DllTransportPluginImpl;
+
+/** Transport plugin which supports dynamic loading of external transport
+ * provider modules
+ */
+class DllTransportPlugin : public TransportPluginInterface, RefCounted
+{
+public:
+ DllTransportPlugin(uint16_t BasePort, unsigned int ThreadCount);
+ ~DllTransportPlugin();
+
+ virtual uint32_t AddRef() const override;
+ virtual uint32_t Release() const override;
+
+ virtual void Initialize(TransportServerInterface* ServerInterface) override;
+ virtual void Shutdown() override;
+ virtual bool IsAvailable() override;
+
+ void LoadDll(std::string_view Name);
+
+private:
+ std::unique_ptr<DllTransportPluginImpl> m_Impl;
+};
+
+} // namespace zen
+
+#endif
diff --git a/src/zenhttp/httpasio.cpp b/src/zenhttp/httpasio.cpp
index 702ca11fd..562f75e3d 100644
--- a/src/zenhttp/httpasio.cpp
+++ b/src/zenhttp/httpasio.cpp
@@ -214,7 +214,7 @@ private:
//////////////////////////////////////////////////////////////////////////
-struct HttpServerConnection : public HttpConnectionBase, std::enable_shared_from_this<HttpServerConnection>
+struct HttpServerConnection : public HttpRequestParserCallbacks, std::enable_shared_from_this<HttpServerConnection>
{
HttpServerConnection(HttpAsioServerImpl& Server, std::unique_ptr<asio::ip::tcp::socket>&& Socket);
~HttpServerConnection();
@@ -223,10 +223,11 @@ struct HttpServerConnection : public HttpConnectionBase, std::enable_shared_from
// HttpConnectionBase implementation
- virtual void HandleNewRequest() override;
virtual void TerminateConnection() override;
virtual void HandleRequest() override;
+ void HandleNewRequest();
+
private:
enum class RequestState
{
@@ -276,8 +277,6 @@ HttpServerConnection::~HttpServerConnection()
void
HttpServerConnection::HandleNewRequest()
{
- m_RequestData.Initialize();
-
EnqueueRead();
}
diff --git a/src/zenhttp/httpasio.h b/src/zenhttp/httpasio.h
index 81aadfc23..e8d13a57f 100644
--- a/src/zenhttp/httpasio.h
+++ b/src/zenhttp/httpasio.h
@@ -10,8 +10,6 @@
namespace zen {
namespace asio_http {
- struct HttpServerConnection;
- struct HttpAcceptor;
struct HttpAsioServerImpl;
} // namespace asio_http
diff --git a/src/zenhttp/httpparser.cpp b/src/zenhttp/httpparser.cpp
index ebfe36227..6b987151a 100644
--- a/src/zenhttp/httpparser.cpp
+++ b/src/zenhttp/httpparser.cpp
@@ -38,8 +38,7 @@ http_parser_settings HttpRequestParser::s_ParserSettings{
.on_chunk_header{},
.on_chunk_complete{}};
-void
-HttpRequestParser::Initialize()
+HttpRequestParser::HttpRequestParser(HttpRequestParserCallbacks& Connection) : m_Connection(Connection)
{
http_parser_init(&m_Parser, HTTP_REQUEST);
m_Parser.data = this;
@@ -47,6 +46,10 @@ HttpRequestParser::Initialize()
ResetState();
}
+HttpRequestParser::~HttpRequestParser()
+{
+}
+
size_t
HttpRequestParser::ConsumeData(const char* InputData, size_t DataSize)
{
diff --git a/src/zenhttp/httpparser.h b/src/zenhttp/httpparser.h
index cce51fcca..219ac351d 100644
--- a/src/zenhttp/httpparser.h
+++ b/src/zenhttp/httpparser.h
@@ -11,20 +11,19 @@ ZEN_THIRD_PARTY_INCLUDES_END
namespace zen {
-class HttpConnectionBase
+class HttpRequestParserCallbacks
{
public:
- virtual ~HttpConnectionBase() = default;
- virtual void HandleNewRequest() = 0;
- virtual void TerminateConnection() = 0;
- virtual void HandleRequest() = 0;
+ virtual ~HttpRequestParserCallbacks() = default;
+ virtual void HandleRequest() = 0;
+ virtual void TerminateConnection() = 0;
};
struct HttpRequestParser
{
- explicit HttpRequestParser(HttpConnectionBase& Connection) : m_Connection(Connection) {}
+ explicit HttpRequestParser(HttpRequestParserCallbacks& Connection);
+ ~HttpRequestParser();
- void Initialize();
size_t ConsumeData(const char* InputData, size_t DataSize);
void ResetState();
@@ -70,31 +69,31 @@ private:
std::string_view Value;
};
- HttpConnectionBase& m_Connection;
- char* m_HeaderCursor = m_HeaderBuffer;
- char* m_Url = nullptr;
- size_t m_UrlLength = 0;
- char* m_QueryString = nullptr;
- size_t m_QueryLength = 0;
- char* m_CurrentHeaderName = nullptr; // Used while parsing headers
- size_t m_CurrentHeaderNameLength = 0;
- char* m_CurrentHeaderValue = nullptr; // Used while parsing headers
- size_t m_CurrentHeaderValueLength = 0;
- std::vector<HeaderEntry> m_Headers;
- int8_t m_ContentLengthHeaderIndex;
- int8_t m_AcceptHeaderIndex;
- int8_t m_ContentTypeHeaderIndex;
- int8_t m_RangeHeaderIndex;
- HttpVerb m_RequestVerb;
- bool m_KeepAlive = false;
- bool m_Expect100Continue = false;
- int m_RequestId = -1;
- Oid m_SessionId{};
- IoBuffer m_BodyBuffer;
- uint64_t m_BodyPosition = 0;
- http_parser m_Parser;
- char m_HeaderBuffer[1024];
- std::string m_NormalizedUrl;
+ HttpRequestParserCallbacks& m_Connection;
+ char* m_HeaderCursor = m_HeaderBuffer;
+ char* m_Url = nullptr;
+ size_t m_UrlLength = 0;
+ char* m_QueryString = nullptr;
+ size_t m_QueryLength = 0;
+ char* m_CurrentHeaderName = nullptr; // Used while parsing headers
+ size_t m_CurrentHeaderNameLength = 0;
+ char* m_CurrentHeaderValue = nullptr; // Used while parsing headers
+ size_t m_CurrentHeaderValueLength = 0;
+ std::vector<HeaderEntry> m_Headers;
+ int8_t m_ContentLengthHeaderIndex;
+ int8_t m_AcceptHeaderIndex;
+ int8_t m_ContentTypeHeaderIndex;
+ int8_t m_RangeHeaderIndex;
+ HttpVerb m_RequestVerb;
+ bool m_KeepAlive = false;
+ bool m_Expect100Continue = false;
+ int m_RequestId = -1;
+ Oid m_SessionId{};
+ IoBuffer m_BodyBuffer;
+ uint64_t m_BodyPosition = 0;
+ http_parser m_Parser;
+ char m_HeaderBuffer[1024];
+ std::string m_NormalizedUrl;
void AppendCurrentHeader();
diff --git a/src/zenhttp/httpplugin.cpp b/src/zenhttp/httpplugin.cpp
new file mode 100644
index 000000000..45a6c23bd
--- /dev/null
+++ b/src/zenhttp/httpplugin.cpp
@@ -0,0 +1,781 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zenhttp/httpplugin.h>
+
+#if ZEN_WITH_PLUGINS
+
+# include "httpparser.h"
+
+# include <zencore/except.h>
+# include <zencore/logging.h>
+# include <zencore/trace.h>
+# include <zencore/workthreadpool.h>
+# include <zenhttp/httpserver.h>
+
+# include <memory>
+# include <string_view>
+
+# if ZEN_PLATFORM_WINDOWS
+# include <conio.h>
+# endif
+
+# define PLUGIN_VERBOSE_TRACE 1
+
+# if PLUGIN_VERBOSE_TRACE
+# define ZEN_TRACE_VERBOSE ZEN_TRACE
+# else
+# define ZEN_TRACE_VERBOSE(fmtstr, ...)
+# endif
+
+namespace zen {
+
+struct HttpPluginServerImpl;
+struct HttpPluginResponse;
+
+using namespace std::literals;
+
+//////////////////////////////////////////////////////////////////////////
+
+struct HttpPluginConnectionHandler : public TransportServerConnectionHandler, public HttpRequestParserCallbacks, RefCounted
+{
+ virtual uint32_t AddRef() const override;
+ virtual uint32_t Release() const override;
+
+ virtual void OnBytesRead(const void* Buffer, size_t DataSize) override;
+
+ // HttpRequestParserCallbacks
+
+ virtual void HandleRequest() override;
+ virtual void TerminateConnection() override;
+
+ void Initialize(TransportConnectionInterface* Transport, HttpPluginServerImpl& Server);
+
+private:
+ enum class RequestState
+ {
+ kInitialState,
+ kInitialRead,
+ kReadingMore,
+ kWriting, // Currently writing response, connection will be re-used
+ kWritingFinal, // Writing response, connection will be closed
+ kDone,
+ kTerminated
+ };
+
+ RequestState m_RequestState = RequestState::kInitialState;
+ HttpRequestParser m_RequestParser{*this};
+
+ uint32_t m_ConnectionId = 0;
+ Ref<IHttpPackageHandler> m_PackageHandler;
+
+ TransportConnectionInterface* m_TransportConnection = nullptr;
+ HttpPluginServerImpl* m_Server = nullptr;
+};
+
+//////////////////////////////////////////////////////////////////////////
+
+struct HttpPluginServerImpl : public TransportServerInterface
+{
+ HttpPluginServerImpl();
+ ~HttpPluginServerImpl();
+
+ void AddPlugin(Ref<TransportPluginInterface> Plugin);
+ void RemovePlugin(Ref<TransportPluginInterface> Plugin);
+
+ void Start();
+ void Stop();
+
+ void RegisterService(const char* InUrlPath, HttpService& Service);
+ HttpService* RouteRequest(std::string_view Url);
+
+ struct ServiceEntry
+ {
+ std::string ServiceUrlPath;
+ HttpService* Service;
+ };
+
+ RwLock m_Lock;
+ std::vector<ServiceEntry> m_UriHandlers;
+ std::vector<Ref<TransportPluginInterface>> m_Plugins;
+
+ // TransportServerInterface
+
+ virtual TransportServerConnectionHandler* CreateConnectionHandler(TransportConnectionInterface* Connection) override;
+};
+
+/** This is the class which request handlers interface with when
+ generating responses
+ */
+
+class HttpPluginServerRequest : public HttpServerRequest
+{
+public:
+ HttpPluginServerRequest(HttpRequestParser& Request, HttpService& Service, IoBuffer PayloadBuffer);
+ ~HttpPluginServerRequest();
+
+ HttpPluginServerRequest(const HttpPluginServerRequest&) = delete;
+ HttpPluginServerRequest& operator=(const HttpPluginServerRequest&) = delete;
+
+ virtual Oid ParseSessionId() const override;
+ virtual uint32_t ParseRequestId() const override;
+
+ virtual IoBuffer ReadPayload() override;
+ virtual void WriteResponse(HttpResponseCode ResponseCode) override;
+ virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span<IoBuffer> Blobs) override;
+ virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::u8string_view ResponseString) override;
+ virtual void WriteResponseAsync(std::function<void(HttpServerRequest&)>&& ContinuationHandler) override;
+ virtual bool TryGetRanges(HttpRanges& Ranges) override;
+
+ using HttpServerRequest::WriteResponse;
+
+ HttpRequestParser& m_Request;
+ IoBuffer m_PayloadBuffer;
+ std::unique_ptr<HttpPluginResponse> m_Response;
+};
+
+//////////////////////////////////////////////////////////////////////////
+
+struct HttpPluginResponse
+{
+public:
+ HttpPluginResponse() = default;
+ explicit HttpPluginResponse(HttpContentType ContentType) : m_ContentType(ContentType) {}
+
+ void InitializeForPayload(uint16_t ResponseCode, std::span<IoBuffer> BlobList);
+
+ inline uint16_t ResponseCode() const { return m_ResponseCode; }
+ inline uint64_t ContentLength() const { return m_ContentLength; }
+
+ const std::vector<IoBuffer>& ResponseBuffers() const { return m_ResponseBuffers; }
+ void SuppressPayload() { m_ResponseBuffers.resize(1); }
+
+private:
+ uint16_t m_ResponseCode = 0;
+ bool m_IsKeepAlive = true;
+ HttpContentType m_ContentType = HttpContentType::kBinary;
+ uint64_t m_ContentLength = 0;
+ std::vector<IoBuffer> m_ResponseBuffers;
+ ExtendableStringBuilder<160> m_Headers;
+
+ std::string_view GetHeaders();
+};
+
+void
+HttpPluginResponse::InitializeForPayload(uint16_t ResponseCode, std::span<IoBuffer> BlobList)
+{
+ ZEN_TRACE_CPU("http_plugin::InitializeForPayload");
+
+ m_ResponseCode = ResponseCode;
+ const uint32_t ChunkCount = gsl::narrow<uint32_t>(BlobList.size());
+
+ m_ResponseBuffers.reserve(ChunkCount + 1);
+ m_ResponseBuffers.push_back({}); // Placeholder for header
+
+ uint64_t TotalDataSize = 0;
+
+ for (IoBuffer& Buffer : BlobList)
+ {
+ uint64_t BufferDataSize = Buffer.Size();
+
+ ZEN_ASSERT(BufferDataSize);
+
+ TotalDataSize += BufferDataSize;
+
+ IoBufferFileReference FileRef;
+ if (Buffer.GetFileReference(/* out */ FileRef))
+ {
+ // TODO: Use direct file transfer, via TransmitFile/sendfile
+
+ m_ResponseBuffers.emplace_back(std::move(Buffer)).MakeOwned();
+ }
+ else
+ {
+ // Send from memory
+
+ m_ResponseBuffers.emplace_back(std::move(Buffer)).MakeOwned();
+ }
+ }
+ m_ContentLength = TotalDataSize;
+
+ auto Headers = GetHeaders();
+ m_ResponseBuffers[0] = IoBufferBuilder::MakeCloneFromMemory(Headers.data(), Headers.size());
+}
+
+std::string_view
+HttpPluginResponse::GetHeaders()
+{
+ m_Headers << "HTTP/1.1 " << ResponseCode() << " " << ReasonStringForHttpResultCode(ResponseCode()) << "\r\n"
+ << "Content-Type: " << MapContentTypeToString(m_ContentType) << "\r\n"
+ << "Content-Length: " << ContentLength() << "\r\n"sv;
+
+ if (!m_IsKeepAlive)
+ {
+ m_Headers << "Connection: close\r\n"sv;
+ }
+
+ m_Headers << "\r\n"sv;
+
+ return m_Headers;
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+void
+HttpPluginConnectionHandler::Initialize(TransportConnectionInterface* Transport, HttpPluginServerImpl& Server)
+{
+ m_TransportConnection = Transport;
+ m_Server = &Server;
+}
+
+uint32_t
+HttpPluginConnectionHandler::AddRef() const
+{
+ return RefCounted::AddRef();
+}
+
+uint32_t
+HttpPluginConnectionHandler::Release() const
+{
+ return RefCounted::Release();
+}
+
+void
+HttpPluginConnectionHandler::OnBytesRead(const void* Buffer, size_t AvailableBytes)
+{
+ while (AvailableBytes)
+ {
+ const size_t ConsumedBytes = m_RequestParser.ConsumeData((const char*)Buffer, AvailableBytes);
+
+ if (ConsumedBytes == ~0ull)
+ {
+ // terminate connection
+
+ return TerminateConnection();
+ }
+
+ Buffer = reinterpret_cast<const uint8_t*>(Buffer) + ConsumedBytes;
+ AvailableBytes -= ConsumedBytes;
+ }
+}
+
+// HttpRequestParserCallbacks
+
+void
+HttpPluginConnectionHandler::HandleRequest()
+{
+ if (!m_RequestParser.IsKeepAlive())
+ {
+ // Once response has been written, connection is done
+ m_RequestState = RequestState::kWritingFinal;
+
+ // We're not going to read any more data from this socket
+
+ const bool Receive = true;
+ const bool Transmit = false;
+ m_TransportConnection->Shutdown(Receive, Transmit);
+ }
+ else
+ {
+ m_RequestState = RequestState::kWriting;
+ }
+
+ auto SendBuffer = [&](const IoBuffer& InBuffer) -> int64_t {
+ const char* Buffer = reinterpret_cast<const char*>(InBuffer.GetData());
+ size_t Bytes = InBuffer.GetSize();
+
+ return m_TransportConnection->WriteBytes(Buffer, Bytes);
+ };
+
+ // Generate response
+
+ if (HttpService* Service = m_Server->RouteRequest(m_RequestParser.Url()))
+ {
+ ZEN_TRACE_CPU("http_plugin::HandleRequest");
+
+ HttpPluginServerRequest Request(m_RequestParser, *Service, m_RequestParser.Body());
+
+ if (!HandlePackageOffers(*Service, Request, m_PackageHandler))
+ {
+ try
+ {
+ Service->HandleRequest(Request);
+ }
+ catch (std::system_error& SystemError)
+ {
+ // Drop any partially formatted response
+ Request.m_Response.reset();
+
+ if (IsOOM(SystemError.code()) || IsOOD(SystemError.code()))
+ {
+ Request.WriteResponse(HttpResponseCode::InsufficientStorage, HttpContentType::kText, SystemError.what());
+ }
+ else
+ {
+ ZEN_ERROR("Caught system error exception while handling request: {}", SystemError.what());
+ Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, SystemError.what());
+ }
+ }
+ catch (std::bad_alloc& BadAlloc)
+ {
+ // Drop any partially formatted response
+ Request.m_Response.reset();
+
+ Request.WriteResponse(HttpResponseCode::InsufficientStorage, HttpContentType::kText, BadAlloc.what());
+ }
+ catch (std::exception& ex)
+ {
+ // Drop any partially formatted response
+ Request.m_Response.reset();
+
+ ZEN_ERROR("Caught exception while handling request: {}", ex.what());
+ Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, ex.what());
+ }
+ }
+
+ if (std::unique_ptr<HttpPluginResponse> Response = std::move(Request.m_Response))
+ {
+ // Transmit the response
+
+ if (m_RequestParser.RequestVerb() == HttpVerb::kHead)
+ {
+ Response->SuppressPayload();
+ }
+
+ const std::vector<IoBuffer>& ResponseBuffers = Response->ResponseBuffers();
+
+ //// TODO: should cork/uncork for Linux?
+
+ for (const IoBuffer& Buffer : ResponseBuffers)
+ {
+ int64_t SentBytes = SendBuffer(Buffer);
+
+ if (SentBytes < 0)
+ {
+ TerminateConnection();
+
+ return;
+ }
+ }
+
+ return;
+ }
+ }
+
+ // No route found for request
+
+ std::string_view Response;
+
+ if (m_RequestParser.RequestVerb() == HttpVerb::kHead)
+ {
+ if (m_RequestParser.IsKeepAlive())
+ {
+ Response =
+ "HTTP/1.1 404 NOT FOUND\r\n"
+ "\r\n"sv;
+ }
+ else
+ {
+ Response =
+ "HTTP/1.1 404 NOT FOUND\r\n"
+ "Connection: close\r\n"
+ "\r\n"sv;
+ }
+ }
+ else
+ {
+ if (m_RequestParser.IsKeepAlive())
+ {
+ Response =
+ "HTTP/1.1 404 NOT FOUND\r\n"
+ "Content-Length: 23\r\n"
+ "Content-Type: text/plain\r\n"
+ "\r\n"
+ "No suitable route found"sv;
+ }
+ else
+ {
+ Response =
+ "HTTP/1.1 404 NOT FOUND\r\n"
+ "Content-Length: 23\r\n"
+ "Content-Type: text/plain\r\n"
+ "Connection: close\r\n"
+ "\r\n"
+ "No suitable route found"sv;
+ }
+ }
+
+ const int64_t SentBytes = SendBuffer(IoBufferBuilder::MakeFromMemory(MakeMemoryView(Response)));
+
+ if (SentBytes < 0)
+ {
+ TerminateConnection();
+
+ return;
+ }
+}
+
+void
+HttpPluginConnectionHandler::TerminateConnection()
+{
+ ZEN_ASSERT(m_TransportConnection);
+ m_TransportConnection->CloseConnection();
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+HttpPluginServerRequest::HttpPluginServerRequest(HttpRequestParser& Request, HttpService& Service, IoBuffer PayloadBuffer)
+: m_Request(Request)
+, m_PayloadBuffer(std::move(PayloadBuffer))
+{
+ const int PrefixLength = Service.UriPrefixLength();
+
+ std::string_view Uri = Request.Url();
+ Uri.remove_prefix(std::min(PrefixLength, static_cast<int>(Uri.size())));
+ m_Uri = Uri;
+ m_UriWithExtension = Uri;
+ m_QueryString = Request.QueryString();
+
+ m_Verb = Request.RequestVerb();
+ m_ContentLength = Request.Body().Size();
+ m_ContentType = Request.ContentType();
+
+ HttpContentType AcceptContentType = HttpContentType::kUnknownContentType;
+
+ // Parse any extension, to allow requesting a particular response encoding via the URL
+
+ {
+ std::string_view UriSuffix8{m_Uri};
+
+ const size_t LastComponentIndex = UriSuffix8.find_last_of('/');
+
+ if (LastComponentIndex != std::string_view::npos)
+ {
+ UriSuffix8.remove_prefix(LastComponentIndex);
+ }
+
+ const size_t LastDotIndex = UriSuffix8.find_last_of('.');
+
+ if (LastDotIndex != std::string_view::npos)
+ {
+ UriSuffix8.remove_prefix(LastDotIndex + 1);
+
+ AcceptContentType = ParseContentType(UriSuffix8);
+
+ if (AcceptContentType != HttpContentType::kUnknownContentType)
+ {
+ m_Uri.remove_suffix(uint32_t(UriSuffix8.size() + 1));
+ }
+ }
+ }
+
+ // It an explicit content type extension was specified then we'll use that over any
+ // Accept: header value that may be present
+
+ if (AcceptContentType != HttpContentType::kUnknownContentType)
+ {
+ m_AcceptType = AcceptContentType;
+ }
+ else
+ {
+ m_AcceptType = Request.AcceptType();
+ }
+}
+
+HttpPluginServerRequest::~HttpPluginServerRequest()
+{
+}
+
+Oid
+HttpPluginServerRequest::ParseSessionId() const
+{
+ return m_Request.SessionId();
+}
+
+uint32_t
+HttpPluginServerRequest::ParseRequestId() const
+{
+ return m_Request.RequestId();
+}
+
+IoBuffer
+HttpPluginServerRequest::ReadPayload()
+{
+ return m_PayloadBuffer;
+}
+
+void
+HttpPluginServerRequest::WriteResponse(HttpResponseCode ResponseCode)
+{
+ ZEN_ASSERT(!m_Response);
+
+ m_Response.reset(new HttpPluginResponse(HttpContentType::kBinary));
+ std::array<IoBuffer, 0> Empty;
+
+ m_Response->InitializeForPayload((uint16_t)ResponseCode, Empty);
+}
+
+void
+HttpPluginServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span<IoBuffer> Blobs)
+{
+ ZEN_ASSERT(!m_Response);
+
+ m_Response.reset(new HttpPluginResponse(ContentType));
+ m_Response->InitializeForPayload((uint16_t)ResponseCode, Blobs);
+}
+
+void
+HttpPluginServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::u8string_view ResponseString)
+{
+ ZEN_ASSERT(!m_Response);
+ m_Response.reset(new HttpPluginResponse(ContentType));
+
+ IoBuffer MessageBuffer(IoBuffer::Wrap, ResponseString.data(), ResponseString.size());
+ std::array<IoBuffer, 1> SingleBufferList({MessageBuffer});
+
+ m_Response->InitializeForPayload((uint16_t)ResponseCode, SingleBufferList);
+}
+
+void
+HttpPluginServerRequest::WriteResponseAsync(std::function<void(HttpServerRequest&)>&& ContinuationHandler)
+{
+ ZEN_ASSERT(!m_Response);
+
+ // Not one bit async, innit
+ ContinuationHandler(*this);
+}
+
+bool
+HttpPluginServerRequest::TryGetRanges(HttpRanges& Ranges)
+{
+ return TryParseHttpRangeHeader(m_Request.RangeHeader(), Ranges);
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+HttpPluginServerImpl::HttpPluginServerImpl()
+{
+}
+
+HttpPluginServerImpl::~HttpPluginServerImpl()
+{
+}
+
+TransportServerConnectionHandler*
+HttpPluginServerImpl::CreateConnectionHandler(TransportConnectionInterface* Connection)
+{
+ HttpPluginConnectionHandler* Handler{new HttpPluginConnectionHandler()};
+ Handler->Initialize(Connection, *this);
+ return Handler;
+}
+
+void
+HttpPluginServerImpl::Start()
+{
+ RwLock::ExclusiveLockScope _(m_Lock);
+
+ for (auto& Plugin : m_Plugins)
+ {
+ try
+ {
+ Plugin->Initialize(this);
+ }
+ catch (std::exception& Ex)
+ {
+ ZEN_WARN("exception caught during plugin initialization: {}", Ex.what());
+ }
+ }
+}
+
+void
+HttpPluginServerImpl::Stop()
+{
+ RwLock::ExclusiveLockScope _(m_Lock);
+
+ for (auto& Plugin : m_Plugins)
+ {
+ try
+ {
+ Plugin->Shutdown();
+ }
+ catch (std::exception& Ex)
+ {
+ ZEN_WARN("exception caught during plugin shutdown: {}", Ex.what());
+ }
+
+ Plugin = nullptr;
+ }
+
+ m_Plugins.clear();
+}
+
+void
+HttpPluginServerImpl::AddPlugin(Ref<TransportPluginInterface> Plugin)
+{
+ RwLock::ExclusiveLockScope _(m_Lock);
+ m_Plugins.emplace_back(std::move(Plugin));
+}
+
+void
+HttpPluginServerImpl::RemovePlugin(Ref<TransportPluginInterface> Plugin)
+{
+ RwLock::ExclusiveLockScope _(m_Lock);
+ auto It = std::find(begin(m_Plugins), end(m_Plugins), Plugin);
+ if (It != m_Plugins.end())
+ {
+ m_Plugins.erase(It);
+ }
+}
+
+void
+HttpPluginServerImpl::RegisterService(const char* InUrlPath, HttpService& Service)
+{
+ std::string_view UrlPath(InUrlPath);
+ Service.SetUriPrefixLength(UrlPath.size());
+ if (!UrlPath.empty() && UrlPath.back() == '/')
+ {
+ UrlPath.remove_suffix(1);
+ }
+
+ RwLock::ExclusiveLockScope _(m_Lock);
+ m_UriHandlers.push_back({std::string(UrlPath), &Service});
+}
+
+HttpService*
+HttpPluginServerImpl::RouteRequest(std::string_view Url)
+{
+ RwLock::SharedLockScope _(m_Lock);
+
+ HttpService* CandidateService = nullptr;
+ std::string::size_type CandidateMatchSize = 0;
+ for (const ServiceEntry& SvcEntry : m_UriHandlers)
+ {
+ const std::string& SvcUrl = SvcEntry.ServiceUrlPath;
+ const std::string::size_type SvcUrlSize = SvcUrl.size();
+ if ((SvcUrlSize >= CandidateMatchSize) && Url.compare(0, SvcUrlSize, SvcUrl) == 0 &&
+ ((SvcUrlSize == Url.size()) || (Url[SvcUrlSize] == '/')))
+ {
+ CandidateMatchSize = SvcUrl.size();
+ CandidateService = SvcEntry.Service;
+ }
+ }
+
+ return CandidateService;
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+HttpPluginServer::HttpPluginServer(unsigned int ThreadCount)
+: m_ThreadCount(ThreadCount != 0 ? ThreadCount : Max(std::thread::hardware_concurrency(), 8u))
+, m_Impl(new HttpPluginServerImpl)
+{
+}
+
+HttpPluginServer::~HttpPluginServer()
+{
+ if (m_Impl)
+ {
+ ZEN_ERROR("~HttpPluginServer() called without calling Close() first");
+ }
+}
+
+int
+HttpPluginServer::Initialize(int BasePort)
+{
+ try
+ {
+ m_Impl->Start();
+ }
+ catch (std::exception& ex)
+ {
+ ZEN_WARN("Caught exception starting http plugin server: {}", ex.what());
+ }
+
+ return BasePort;
+}
+
+void
+HttpPluginServer::Close()
+{
+ try
+ {
+ m_Impl->Stop();
+ }
+ catch (std::exception& ex)
+ {
+ ZEN_WARN("Caught exception stopping http plugin server: {}", ex.what());
+ }
+
+ delete m_Impl;
+ m_Impl = nullptr;
+}
+
+void
+HttpPluginServer::Run(bool IsInteractive)
+{
+ const bool TestMode = !IsInteractive;
+
+ int WaitTimeout = -1;
+ if (!TestMode)
+ {
+ WaitTimeout = 1000;
+ }
+
+# if ZEN_PLATFORM_WINDOWS
+ if (TestMode == false)
+ {
+ zen::logging::ConsoleLog().info("Zen Server running (plugin HTTP). Press ESC or Q to quit");
+ }
+
+ do
+ {
+ if (!TestMode && _kbhit() != 0)
+ {
+ char c = (char)_getch();
+
+ if (c == 27 || c == 'Q' || c == 'q')
+ {
+ RequestApplicationExit(0);
+ }
+ }
+
+ m_ShutdownEvent.Wait(WaitTimeout);
+ } while (!IsApplicationExitRequested());
+# else
+ if (TestMode == false)
+ {
+ zen::logging::ConsoleLog().info("Zen Server running (plugin HTTP). Ctrl-C to quit");
+ }
+
+ do
+ {
+ m_ShutdownEvent.Wait(WaitTimeout);
+ } while (!IsApplicationExitRequested());
+# endif
+}
+
+void
+HttpPluginServer::RequestExit()
+{
+ m_ShutdownEvent.Set();
+}
+
+void
+HttpPluginServer::RegisterService(HttpService& Service)
+{
+ m_Impl->RegisterService(Service.BaseUri(), Service);
+}
+
+void
+HttpPluginServer::AddPlugin(Ref<TransportPluginInterface> Plugin)
+{
+ m_Impl->AddPlugin(Plugin);
+}
+
+void
+HttpPluginServer::RemovePlugin(Ref<TransportPluginInterface> Plugin)
+{
+ m_Impl->RemovePlugin(Plugin);
+}
+
+} // namespace zen
+#endif
diff --git a/src/zenhttp/httpserver.cpp b/src/zenhttp/httpserver.cpp
index a98a3c9bb..523befa72 100644
--- a/src/zenhttp/httpserver.cpp
+++ b/src/zenhttp/httpserver.cpp
@@ -6,6 +6,13 @@
#include "httpnull.h"
#include "httpsys.h"
+#include "zenhttp/httpplugin.h"
+
+#if ZEN_WITH_PLUGINS
+# include "dlltransport.h"
+# include "winsocktransport.h"
+#endif
+
#include <zencore/compactbinary.h>
#include <zencore/compactbinarybuilder.h>
#include <zencore/compactbinarypackage.h>
@@ -711,6 +718,7 @@ enum class HttpServerClass
{
kHttpAsio,
kHttpSys,
+ kHttpPlugin,
kHttpNull
};
@@ -723,7 +731,7 @@ CreateHttpServer(const HttpServerConfig& Config)
#if ZEN_WITH_HTTPSYS
Class = HttpServerClass::kHttpSys;
-#elif 1
+#else
Class = HttpServerClass::kHttpAsio;
#endif
@@ -735,6 +743,10 @@ CreateHttpServer(const HttpServerConfig& Config)
{
Class = HttpServerClass::kHttpSys;
}
+ else if (Config.ServerClass == "plugin"sv)
+ {
+ Class = HttpServerClass::kHttpPlugin;
+ }
else if (Config.ServerClass == "null"sv)
{
Class = HttpServerClass::kHttpNull;
@@ -747,6 +759,27 @@ CreateHttpServer(const HttpServerConfig& Config)
ZEN_INFO("using asio HTTP server implementation");
return Ref<HttpServer>(new HttpAsioServer(Config.ThreadCount));
+#if ZEN_WITH_PLUGINS
+ case HttpServerClass::kHttpPlugin:
+ {
+ ZEN_INFO("using plugin HTTP server implementation");
+ Ref<HttpPluginServer> Server{new HttpPluginServer(Config.ThreadCount)};
+
+# if 1
+ Ref<TransportPluginInterface> WinsockPlugin{CreateSocketTransportPlugin(1337, Config.ThreadCount)};
+ Server->AddPlugin(WinsockPlugin);
+# endif
+
+# if 0
+ Ref<DllTransportPlugin> DllPlugin{new DllTransportPlugin(1337, Config.ThreadCount)};
+ DllPlugin->LoadDll("winsock");
+ Server->AddPlugin(DllPlugin);
+# endif
+
+ return Server;
+ }
+#endif
+
#if ZEN_WITH_HTTPSYS
case HttpServerClass::kHttpSys:
ZEN_INFO("using http.sys server implementation");
diff --git a/src/zenhttp/include/zenhttp/httpplugin.h b/src/zenhttp/include/zenhttp/httpplugin.h
new file mode 100644
index 000000000..409eb1b61
--- /dev/null
+++ b/src/zenhttp/include/zenhttp/httpplugin.h
@@ -0,0 +1,49 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/refcount.h>
+#include <zencore/thread.h>
+
+#if !defined(ZEN_WITH_PLUGINS)
+# if ZEN_PLATFORM_WINDOWS
+# define ZEN_WITH_PLUGINS 1
+# else
+# define ZEN_WITH_PLUGINS 0
+# endif
+#endif
+
+#if ZEN_WITH_PLUGINS
+# include "transportplugin.h"
+# include <zenhttp/httpserver.h>
+
+namespace zen {
+
+struct HttpPluginServerImpl;
+
+class HttpPluginServer : public HttpServer
+{
+public:
+ HttpPluginServer(unsigned int ThreadCount);
+ ~HttpPluginServer();
+
+ virtual void RegisterService(HttpService& Service) override;
+ virtual int Initialize(int BasePort) override;
+ virtual void Run(bool IsInteractiveSession) override;
+ virtual void RequestExit() override;
+ virtual void Close() override;
+
+ void AddPlugin(Ref<TransportPluginInterface> Plugin);
+ void RemovePlugin(Ref<TransportPluginInterface> Plugin);
+
+private:
+ Event m_ShutdownEvent;
+ int m_BasePort = 0;
+ unsigned int m_ThreadCount = 0;
+
+ HttpPluginServerImpl* m_Impl = nullptr;
+};
+
+} // namespace zen
+
+#endif
diff --git a/src/zenhttp/include/zenhttp/transportplugin.h b/src/zenhttp/include/zenhttp/transportplugin.h
new file mode 100644
index 000000000..38b07d471
--- /dev/null
+++ b/src/zenhttp/include/zenhttp/transportplugin.h
@@ -0,0 +1,97 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <stdint.h>
+
+// Important note: this header is meant to compile standalone
+// and should therefore not depend on anything from the Zen codebase
+
+class TransportConnectionInterface;
+class TransportPluginInterface;
+class TransportServerConnectionHandler;
+class TransportServerInterface;
+
+/*************************************************************************
+
+ The following interfaces are implemented on the server side, and instances
+ are provided to the plugins.
+
+*************************************************************************/
+
+/** Plugin-server interface for connection
+ *
+ * This is how the transport feeds data to the connection handler
+ * which will parse the incoming messages and dispatch to
+ * appropriate request handlers and ultimately call into functions
+ * which write data back to the client.
+ */
+class TransportServerConnectionHandler
+{
+public:
+ virtual uint32_t AddRef() const = 0;
+ virtual uint32_t Release() const = 0;
+ virtual void OnBytesRead(const void* Buffer, size_t DataSize) = 0;
+};
+
+/** Plugin-server interface
+ *
+ * There will be one instance of this per plugin, and the plugin
+ * should use this to manage lifetimes of connections and any
+ * other resources.
+ */
+class TransportServerInterface
+{
+public:
+ virtual TransportServerConnectionHandler* CreateConnectionHandler(TransportConnectionInterface* Connection) = 0;
+};
+
+/*************************************************************************
+
+ The following interfaces are to be implemented by transport plugins.
+
+*************************************************************************/
+
+/** Interface which needs to be implemented by a transport plugin
+ *
+ * This is responsible for setting up and running the communication
+ * for a given transport.
+ */
+class TransportPluginInterface
+{
+public:
+ virtual uint32_t AddRef() const = 0;
+ virtual uint32_t Release() const = 0;
+ virtual void Initialize(TransportServerInterface* ServerInterface) = 0;
+ virtual void Shutdown() = 0;
+
+ /** Check whether this transport is usable.
+ */
+ virtual bool IsAvailable() = 0;
+};
+
+/** A transport plugin provider needs to implement this interface
+ *
+ * There will be one instance of this per established connection and
+ * this interface is used to write response data back to the client.
+ */
+class TransportConnectionInterface
+{
+public:
+ virtual int64_t WriteBytes(const void* Buffer, size_t DataSize) = 0;
+ virtual void Shutdown(bool Receive, bool Transmit) = 0;
+ virtual void CloseConnection() = 0;
+};
+
+#if defined(_MSC_VER)
+# define DLL_TRANSPORT_API __declspec(dllexport)
+#else
+# define DLL_TRANSPORT_API
+#endif
+
+extern "C"
+{
+ DLL_TRANSPORT_API TransportPluginInterface* CreateTransportPluginInterface();
+}
+
+typedef TransportPluginInterface* (*PfnCreateTransportPluginInterface)();
diff --git a/src/zenhttp/winsocktransport.cpp b/src/zenhttp/winsocktransport.cpp
new file mode 100644
index 000000000..e86e4822e
--- /dev/null
+++ b/src/zenhttp/winsocktransport.cpp
@@ -0,0 +1,367 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "winsocktransport.h"
+
+#if ZEN_WITH_PLUGINS
+
+# include <zencore/logging.h>
+# include <zencore/scopeguard.h>
+# include <zencore/workthreadpool.h>
+
+# if ZEN_PLATFORM_WINDOWS
+# include <zencore/windows.h>
+ZEN_THIRD_PARTY_INCLUDES_START
+# include <WinSock2.h>
+# include <WS2tcpip.h>
+ZEN_THIRD_PARTY_INCLUDES_END
+# endif
+
+# include <thread>
+
+namespace zen {
+
+class SocketTransportPluginImpl;
+
+class SocketTransportPlugin : public TransportPluginInterface, RefCounted
+{
+public:
+ SocketTransportPlugin(uint16_t BasePort, unsigned int ThreadCount);
+ ~SocketTransportPlugin();
+
+ virtual uint32_t AddRef() const override;
+ virtual uint32_t Release() const override;
+ virtual void Initialize(TransportServerInterface* ServerInterface) override;
+ virtual void Shutdown() override;
+ virtual bool IsAvailable() override;
+
+private:
+ bool m_IsOk = true;
+ uint16_t m_BasePort = 0;
+ int m_ThreadCount = 0;
+ SocketTransportPluginImpl* m_Impl;
+};
+
+struct SocketTransportConnection : public TransportConnectionInterface
+{
+public:
+ SocketTransportConnection();
+ ~SocketTransportConnection();
+
+ void Initialize(TransportServerConnectionHandler* ServerConnection, SOCKET ClientSocket);
+ void HandleConnection();
+
+ // TransportConnectionInterface
+
+ virtual int64_t WriteBytes(const void* Buffer, size_t DataSize) override;
+ virtual void Shutdown(bool Receive, bool Transmit) override;
+ virtual void CloseConnection() override;
+
+private:
+ Ref<TransportServerConnectionHandler> m_ConnectionHandler;
+ SOCKET m_ClientSocket{};
+ bool m_IsTerminated = false;
+};
+
+SocketTransportConnection::SocketTransportConnection()
+{
+}
+
+SocketTransportConnection::~SocketTransportConnection()
+{
+}
+
+void
+SocketTransportConnection::Initialize(TransportServerConnectionHandler* ServerConnection, SOCKET ClientSocket)
+{
+ ZEN_ASSERT(!m_ConnectionHandler);
+
+ m_ConnectionHandler = ServerConnection;
+ m_ClientSocket = ClientSocket;
+}
+
+void
+SocketTransportConnection::HandleConnection()
+{
+ ZEN_ASSERT(m_ConnectionHandler);
+
+ const int InputBufferSize = 64 * 1024;
+ uint8_t* InputBuffer = new uint8_t[64 * 1024];
+ auto _ = MakeGuard([&] { delete[] InputBuffer; });
+
+ do
+ {
+ const int RecvBytes = recv(m_ClientSocket, (char*)InputBuffer, InputBufferSize, /* flags */ 0);
+
+ if (RecvBytes == 0)
+ {
+ // Connection closed
+ return CloseConnection();
+ }
+ else if (RecvBytes < 0)
+ {
+ // Error
+ return CloseConnection();
+ }
+
+ m_ConnectionHandler->OnBytesRead(InputBuffer, RecvBytes);
+ } while (m_ClientSocket);
+}
+
+void
+SocketTransportConnection::CloseConnection()
+{
+ if (m_IsTerminated)
+ {
+ return;
+ }
+
+ ZEN_ASSERT(m_ClientSocket);
+ m_IsTerminated = true;
+
+ shutdown(m_ClientSocket, SD_BOTH); // We won't be sending or receiving any more data
+
+ closesocket(m_ClientSocket);
+ m_ClientSocket = 0;
+}
+
+int64_t
+SocketTransportConnection::WriteBytes(const void* Buffer, size_t DataSize)
+{
+ const uint8_t* BufferCursor = reinterpret_cast<const uint8_t*>(Buffer);
+ int64_t TotalBytesSent = 0;
+
+ while (DataSize)
+ {
+ const int MaxBlockSize = 128 * 1024;
+ const int SendBlockSize = (DataSize > MaxBlockSize) ? MaxBlockSize : (int)DataSize;
+ const int SentBytes = send(m_ClientSocket, (const char*)BufferCursor, SendBlockSize, /* flags */ 0);
+
+ if (SentBytes < 0)
+ {
+ // Error
+ return SentBytes;
+ }
+
+ BufferCursor += SentBytes;
+ DataSize -= SentBytes;
+ TotalBytesSent += SentBytes;
+ }
+
+ return TotalBytesSent;
+}
+
+void
+SocketTransportConnection::Shutdown(bool Receive, bool Transmit)
+{
+ if (Receive)
+ {
+ if (Transmit)
+ {
+ shutdown(m_ClientSocket, SD_BOTH);
+ }
+ else
+ {
+ shutdown(m_ClientSocket, SD_RECEIVE);
+ }
+ }
+ else if (Transmit)
+ {
+ shutdown(m_ClientSocket, SD_SEND);
+ }
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+class SocketTransportPluginImpl
+{
+public:
+ SocketTransportPluginImpl(uint16_t BasePort, unsigned int ThreadCount);
+ ~SocketTransportPluginImpl();
+
+ uint16_t Start(uint16_t Port, TransportServerInterface* ServerInterface);
+ void Stop();
+
+private:
+ TransportServerInterface* m_ServerInterface = nullptr;
+ uint16_t m_BasePort = 0;
+ int m_ThreadCount = 0;
+ bool m_IsOk = true;
+
+ SOCKET m_ListenSocket{};
+ std::thread m_AcceptThread;
+ std::atomic_flag m_KeepRunning;
+ std::unique_ptr<WorkerThreadPool> m_WorkerThreadpool;
+};
+
+SocketTransportPluginImpl::SocketTransportPluginImpl(uint16_t BasePort, unsigned int ThreadCount)
+: m_BasePort(BasePort)
+, m_ThreadCount(ThreadCount)
+{
+# if ZEN_PLATFORM_WINDOWS
+ WSADATA wsaData;
+ if (int Result = WSAStartup(0x202, &wsaData); Result != 0)
+ {
+ m_IsOk = false;
+ WSACleanup();
+ }
+# endif
+
+ m_WorkerThreadpool = std::make_unique<WorkerThreadPool>(m_ThreadCount, "http_conn");
+}
+
+SocketTransportPluginImpl::~SocketTransportPluginImpl()
+{
+ Stop();
+
+# if ZEN_PLATFORM_WINDOWS
+ if (m_IsOk)
+ {
+ WSACleanup();
+ }
+# endif
+}
+
+uint16_t
+SocketTransportPluginImpl::Start(uint16_t Port, TransportServerInterface* ServerInterface)
+{
+ m_ServerInterface = ServerInterface;
+ m_ListenSocket = socket(AF_INET6, SOCK_STREAM, 0);
+
+ if (m_ListenSocket == SOCKET_ERROR || m_ListenSocket == INVALID_SOCKET)
+ {
+ ZEN_ERROR("socket creation failed in HTTP plugin server init: {}", WSAGetLastError());
+
+ return 0;
+ }
+
+ sockaddr_in6 Server{};
+ Server.sin6_family = AF_INET6;
+ Server.sin6_port = htons(Port);
+ Server.sin6_addr = in6addr_any;
+
+ if (int Result = bind(m_ListenSocket, (sockaddr*)&Server, sizeof(Server)); Result == SOCKET_ERROR)
+ {
+ ZEN_ERROR("bind call failed in HTTP plugin server init: {}", WSAGetLastError());
+
+ return 0;
+ }
+
+ if (int Result = listen(m_ListenSocket, AF_INET6); Result == SOCKET_ERROR)
+ {
+ ZEN_ERROR("listen call failed in HTTP plugin server init: {}", WSAGetLastError());
+
+ return 0;
+ }
+
+ m_KeepRunning.test_and_set();
+
+ m_AcceptThread = std::thread([&] {
+ SetCurrentThreadName("http_plugin_acceptor");
+
+ ZEN_INFO("HTTP plugin server waiting for connections");
+
+ do
+ {
+ if (SOCKET ClientSocket = accept(m_ListenSocket, NULL, NULL); ClientSocket != SOCKET_ERROR)
+ {
+ int Flag = 1;
+ setsockopt(ClientSocket, IPPROTO_TCP, TCP_NODELAY, (char*)&Flag, sizeof(Flag));
+
+ // Handle new connection
+ SocketTransportConnection* Connection = new SocketTransportConnection();
+ TransportServerConnectionHandler* ConnectionInterface{m_ServerInterface->CreateConnectionHandler(Connection)};
+ Connection->Initialize(ConnectionInterface, ClientSocket);
+
+ m_WorkerThreadpool->ScheduleWork([Connection] {
+ try
+ {
+ Connection->HandleConnection();
+ }
+ catch (std::exception& Ex)
+ {
+ ZEN_WARN("exception caught in connection loop: {}", Ex.what());
+ }
+
+ delete Connection;
+ });
+ }
+ else
+ {
+ }
+ } while (!IsApplicationExitRequested() && m_KeepRunning.test());
+
+ ZEN_INFO("HTTP plugin server accept thread exit");
+ });
+
+ return Port;
+}
+
+void
+SocketTransportPluginImpl::Stop()
+{
+ // TODO: all pending/ongoing work should be drained here as well
+
+ m_KeepRunning.clear();
+
+ closesocket(m_ListenSocket);
+ m_ListenSocket = 0;
+
+ if (m_AcceptThread.joinable())
+ {
+ m_AcceptThread.join();
+ }
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+SocketTransportPlugin::SocketTransportPlugin(uint16_t BasePort, unsigned int ThreadCount)
+: m_BasePort(BasePort)
+, m_ThreadCount(ThreadCount != 0 ? ThreadCount : Max(std::thread::hardware_concurrency(), 8u))
+, m_Impl(new SocketTransportPluginImpl(BasePort, m_ThreadCount))
+{
+}
+
+SocketTransportPlugin::~SocketTransportPlugin()
+{
+ delete m_Impl;
+}
+
+uint32_t
+SocketTransportPlugin::AddRef() const
+{
+ return RefCounted::AddRef();
+}
+
+uint32_t
+SocketTransportPlugin::Release() const
+{
+ return RefCounted::Release();
+}
+
+void
+SocketTransportPlugin::Initialize(TransportServerInterface* ServerInterface)
+{
+ m_Impl->Start(m_BasePort, ServerInterface);
+}
+
+void
+SocketTransportPlugin::Shutdown()
+{
+ m_Impl->Stop();
+}
+
+bool
+SocketTransportPlugin::IsAvailable()
+{
+ return true;
+}
+
+TransportPluginInterface*
+CreateSocketTransportPlugin(uint16_t BasePort, unsigned int ThreadCount)
+{
+ return new SocketTransportPlugin(BasePort, ThreadCount);
+}
+
+} // namespace zen
+
+#endif
diff --git a/src/zenhttp/winsocktransport.h b/src/zenhttp/winsocktransport.h
new file mode 100644
index 000000000..34809cddc
--- /dev/null
+++ b/src/zenhttp/winsocktransport.h
@@ -0,0 +1,15 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zenhttp/httpplugin.h>
+
+#if ZEN_WITH_PLUGINS
+
+namespace zen {
+
+TransportPluginInterface* CreateSocketTransportPlugin(uint16_t BasePort, unsigned int ThreadCount);
+
+} // namespace zen
+
+#endif